Skip to content

Commit 62ce55b

Browse files
authored
updating bulk update to use sdkclient (opensearch-project#3546)
* updating bulk update to use sdkclient Signed-off-by: Dhrubo Saha <dhrubo@amazon.com> * addressed comments Signed-off-by: Dhrubo Saha <dhrubo@amazon.com> --------- Signed-off-by: Dhrubo Saha <dhrubo@amazon.com>
1 parent dd58f18 commit 62ce55b

File tree

3 files changed

+279
-91
lines changed

3 files changed

+279
-91
lines changed

plugin/src/main/java/org/opensearch/ml/action/undeploy/TransportUndeployModelAction.java

+45-9
Original file line numberDiff line numberDiff line change
@@ -16,13 +16,13 @@
1616
import java.util.Map;
1717
import java.util.stream.Collectors;
1818

19+
import org.opensearch.OpenSearchStatusException;
1920
import org.opensearch.action.FailedNodeException;
20-
import org.opensearch.action.bulk.BulkRequest;
21+
import org.opensearch.action.bulk.BulkItemResponse;
2122
import org.opensearch.action.bulk.BulkResponse;
2223
import org.opensearch.action.support.ActionFilters;
2324
import org.opensearch.action.support.WriteRequest;
2425
import org.opensearch.action.support.nodes.TransportNodesAction;
25-
import org.opensearch.action.update.UpdateRequest;
2626
import org.opensearch.cluster.service.ClusterService;
2727
import org.opensearch.common.inject.Inject;
2828
import org.opensearch.common.util.concurrent.ThreadContext;
@@ -43,6 +43,10 @@
4343
import org.opensearch.ml.model.MLModelManager;
4444
import org.opensearch.ml.stats.MLNodeLevelStat;
4545
import org.opensearch.ml.stats.MLStats;
46+
import org.opensearch.remote.metadata.client.BulkDataObjectRequest;
47+
import org.opensearch.remote.metadata.client.SdkClient;
48+
import org.opensearch.remote.metadata.client.UpdateDataObjectRequest;
49+
import org.opensearch.remote.metadata.common.SdkClientUtils;
4650
import org.opensearch.tasks.Task;
4751
import org.opensearch.threadpool.ThreadPool;
4852
import org.opensearch.transport.TransportService;
@@ -58,6 +62,7 @@ public class TransportUndeployModelAction extends
5862
private final MLModelManager mlModelManager;
5963
private final ClusterService clusterService;
6064
private final Client client;
65+
private final SdkClient sdkClient;
6166
private final DiscoveryNodeHelper nodeFilter;
6267
private final MLStats mlStats;
6368

@@ -69,6 +74,7 @@ public TransportUndeployModelAction(
6974
ClusterService clusterService,
7075
ThreadPool threadPool,
7176
Client client,
77+
SdkClient sdkClient,
7278
DiscoveryNodeHelper nodeFilter,
7379
MLStats mlStats
7480
) {
@@ -87,19 +93,21 @@ public TransportUndeployModelAction(
8793

8894
this.clusterService = clusterService;
8995
this.client = client;
96+
this.sdkClient = sdkClient;
9097
this.nodeFilter = nodeFilter;
9198
this.mlStats = mlStats;
9299
}
93100

94101
@Override
95102
protected void doExecute(Task task, MLUndeployModelNodesRequest request, ActionListener<MLUndeployModelNodesResponse> listener) {
96103
ActionListener<MLUndeployModelNodesResponse> wrappedListener = ActionListener.wrap(undeployModelNodesResponse -> {
97-
processUndeployModelResponseAndUpdate(undeployModelNodesResponse, listener);
104+
processUndeployModelResponseAndUpdate(request.getTenantId(), undeployModelNodesResponse, listener);
98105
}, listener::onFailure);
99106
super.doExecute(task, request, wrappedListener);
100107
}
101108

102109
void processUndeployModelResponseAndUpdate(
110+
String tenantId,
103111
MLUndeployModelNodesResponse undeployModelNodesResponse,
104112
ActionListener<MLUndeployModelNodesResponse> listener
105113
) {
@@ -145,11 +153,10 @@ void processUndeployModelResponseAndUpdate(
145153

146154
MLSyncUpNodesRequest syncUpRequest = new MLSyncUpNodesRequest(nodeFilter.getAllNodes(), syncUpInput);
147155
try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) {
148-
if (actualRemovedNodesMap.size() > 0) {
149-
BulkRequest bulkRequest = new BulkRequest();
156+
if (!actualRemovedNodesMap.isEmpty()) {
157+
BulkDataObjectRequest bulkRequest = BulkDataObjectRequest.builder().globalIndex(ML_MODEL_INDEX).build();
150158
Map<String, Boolean> deployToAllNodes = new HashMap<>();
151159
for (String modelId : actualRemovedNodesMap.keySet()) {
152-
UpdateRequest updateRequest = new UpdateRequest();
153160
List<String> removedNodes = actualRemovedNodesMap.get(modelId);
154161
int removedNodeCount = removedNodes.size();
155162
/**
@@ -178,7 +185,13 @@ void processUndeployModelResponseAndUpdate(
178185
updateDocument.put(MLModel.CURRENT_WORKER_NODE_COUNT_FIELD, newPlanningWorkerNodes.size());
179186
deployToAllNodes.put(modelId, false);
180187
}
181-
updateRequest.index(ML_MODEL_INDEX).id(modelId).doc(updateDocument);
188+
189+
UpdateDataObjectRequest updateRequest = UpdateDataObjectRequest
190+
.builder()
191+
.id(modelId)
192+
.tenantId(tenantId)
193+
.dataObject(updateDocument)
194+
.build();
182195
bulkRequest.add(updateRequest).setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE);
183196
}
184197
syncUpInput.setDeployToAllNodes(deployToAllNodes);
@@ -189,10 +202,33 @@ void processUndeployModelResponseAndUpdate(
189202
Arrays.toString(actualRemovedNodesMap.keySet().toArray(new String[0]))
190203
);
191204
}, e -> { log.error("Failed to update model state as undeployed", e); });
192-
client.bulk(bulkRequest, ActionListener.runAfter(actionListener, () -> {
205+
ActionListener<BulkResponse> wrappedListener = ActionListener.runAfter(actionListener, () -> {
193206
syncUpUndeployedModels(syncUpRequest);
194207
listener.onResponse(undeployModelNodesResponse);
195-
}));
208+
});
209+
sdkClient.bulkDataObjectAsync(bulkRequest).whenComplete((r, throwable) -> {
210+
if (throwable != null) {
211+
Exception cause = SdkClientUtils.unwrapAndConvertToException(throwable, OpenSearchStatusException.class);
212+
log.error("Failed to execute BulkDataObject request", cause);
213+
wrappedListener.onFailure(cause);
214+
} else {
215+
try {
216+
BulkResponse bulkResponse = BulkResponse.fromXContent(r.parser());
217+
log
218+
.info(
219+
"Executed {} bulk operations with {} failures, Took: {}",
220+
bulkResponse.getItems().length,
221+
bulkResponse.hasFailures()
222+
? Arrays.stream(bulkResponse.getItems()).filter(BulkItemResponse::isFailed).count()
223+
: 0,
224+
bulkResponse.getTook()
225+
);
226+
wrappedListener.onResponse(bulkResponse);
227+
} catch (Exception e) {
228+
wrappedListener.onFailure(e);
229+
}
230+
}
231+
});
196232
} else {
197233
syncUpUndeployedModels(syncUpRequest);
198234
listener.onResponse(undeployModelNodesResponse);

0 commit comments

Comments
 (0)