Commit 570edaf

authored Jan 24, 2025
Check before delete (opensearch-project#3209)
* add logic to detect agent before deleting Signed-off-by: xinyual <xinyual@amazon.com>
* add logic to detect agent before deleting Signed-off-by: xinyual <xinyual@amazon.com>
* add logic to detect pipelines before delete model Signed-off-by: xinyual <xinyual@amazon.com>
* check pipeline before deleting Signed-off-by: xinyual <xinyual@amazon.com>
* apply spotless Signed-off-by: xinyual <xinyual@amazon.com>
* remove useless file Signed-off-by: xinyual <xinyual@amazon.com>
* rename functions Signed-off-by: xinyual <xinyual@amazon.com>
* fix failure test Signed-off-by: xinyual <xinyual@amazon.com>
* add UT Signed-off-by: xinyual <xinyual@amazon.com>
* apply spotless Signed-off-by: xinyual <xinyual@amazon.com>
* renam Signed-off-by: xinyual <xinyual@amazon.com>
* refactor to parallel check Signed-off-by: xinyual <xinyual@amazon.com>
* concate error message Signed-off-by: xinyual <xinyual@amazon.com>
* move logic after user access check Signed-off-by: xinyual <xinyual@amazon.com>
* change agent model searcher map to set Signed-off-by: xinyual <xinyual@amazon.com>
* rename and remove useless method Signed-off-by: xinyual <xinyual@amazon.com>
* fix bug to fetch all pipelines Signed-off-by: xinyual <xinyual@amazon.com>
* apply spotless Signed-off-by: xinyual <xinyual@amazon.com>
* apply spotless Signed-off-by: xinyual <xinyual@amazon.com>
* remove and add comment Signed-off-by: xinyual <xinyual@amazon.com>
* rename and add more UTs Signed-off-by: xinyual <xinyual@amazon.com>
* use correct key Signed-off-by: xinyual <xinyual@amazon.com>
* simplify function Signed-off-by: xinyual <xinyual@amazon.com>
* change to a better class Signed-off-by: xinyual <xinyual@amazon.com>
* apply spotless Signed-off-by: xinyual <xinyual@amazon.com>
* change compareAndSet to set Signed-off-by: xinyual <xinyual@amazon.com>
* apply comment Signed-off-by: xinyual <xinyual@amazon.com>
* change name and reformat logic Signed-off-by: xinyual <xinyual@amazon.com>
* change name Signed-off-by: xinyual <xinyual@amazon.com>
* remove useless line Signed-off-by: xinyual <xinyual@amazon.com>
* change to a better method Signed-off-by: xinyual <xinyual@amazon.com>
* change name Signed-off-by: xinyual <xinyual@amazon.com>
* apply spotless Signed-off-by: xinyual <xinyual@amazon.com>
* add java doc for function Signed-off-by: xinyual <xinyual@amazon.com>
* add another interface Signed-off-by: xinyual <xinyual@amazon.com>
* apply java spotless Signed-off-by: xinyual <xinyual@amazon.com>
* change interface to with model Signed-off-by: xinyual <xinyual@amazon.com>
* apply spot less Signed-off-by: xinyual <xinyual@amazon.com>
* add settings Signed-off-by: xinyual <xinyual@amazon.com>
* apply spot less Signed-off-by: xinyual <xinyual@amazon.com>
* add test for cluster setting Signed-off-by: xinyual <xinyual@amazon.com>
* apply spotless Signed-off-by: xinyual <xinyual@amazon.com>
* recover useless change Signed-off-by: xinyual <xinyual@amazon.com>
* change default value of cluster setting Signed-off-by: xinyual <xinyual@amazon.com>
* rename setting and add comment Signed-off-by: xinyual <xinyual@amazon.com>
* apply spot Signed-off-by: xinyual <xinyual@amazon.com>
* remove logic for hidden model Signed-off-by: xinyual <xinyual@amazon.com>
* reorder code Signed-off-by: xinyual <xinyual@amazon.com>
* reorder code Signed-off-by: xinyual <xinyual@amazon.com>
* reorder code Signed-off-by: xinyual <xinyual@amazon.com>
* apply spot Signed-off-by: xinyual <xinyual@amazon.com>
* add UT Signed-off-by: xinyual <xinyual@amazon.com>
* add more UT Signed-off-by: xinyual <xinyual@amazon.com>
* remove search for hidden agent Signed-off-by: xinyual <xinyual@amazon.com>
* fix logic and apply spot Signed-off-by: xinyual <xinyual@amazon.com>
* add exist for UT Signed-off-by: xinyual <xinyual@amazon.com>
* change dsl to query index Signed-off-by: xinyual <xinyual@amazon.com>
* change query logic Signed-off-by: xinyual <xinyual@amazon.com>
* remove useless ut Signed-off-by: xinyual <xinyual@amazon.com>
* rebert Signed-off-by: xinyual <xinyual@amazon.com>
* apply spot Signed-off-by: xinyual <xinyual@amazon.com>
* rechange code Signed-off-by: xinyual <xinyual@amazon.com>
* apply spot Signed-off-by: xinyual <xinyual@amazon.com>
* remove useless should Signed-off-by: xinyual <xinyual@amazon.com>
* apply spot Signed-off-by: xinyual <xinyual@amazon.com>
* fix final dsl logic and ut Signed-off-by: xinyual <xinyual@amazon.com>
---------
Signed-off-by: xinyual <xinyual@amazon.com>
1 parent af96fe0 commit 570edaf

10 files changed: +825 -47 lines changed

common/src/main/java/org/opensearch/ml/common/CommonValue.java (+1)
@@ -45,6 +45,7 @@ public class CommonValue {
     public static final String ML_MEMORY_MESSAGE_INDEX = ".plugins-ml-memory-message";
     public static final String ML_STOP_WORDS_INDEX = ".plugins-ml-stop-words";
     public static final Set<String> stopWordsIndices = ImmutableSet.of(".plugins-ml-stop-words");
+    public static final String TOOL_PARAMETERS_PREFIX = "tools.parameters.";

     // Index mapping paths
     public static final String ML_MODEL_GROUP_INDEX_MAPPING_PATH = "index-mappings/ml_model_group.json";

ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/MLModelTool.java (+8 -3)
@@ -17,8 +17,8 @@
 import org.opensearch.ml.common.output.model.ModelTensorOutput;
 import org.opensearch.ml.common.output.model.ModelTensors;
 import org.opensearch.ml.common.spi.tools.Parser;
-import org.opensearch.ml.common.spi.tools.Tool;
 import org.opensearch.ml.common.spi.tools.ToolAnnotation;
+import org.opensearch.ml.common.spi.tools.WithModelTool;
 import org.opensearch.ml.common.transport.prediction.MLPredictionTaskAction;
 import org.opensearch.ml.common.transport.prediction.MLPredictionTaskRequest;
 import org.opensearch.ml.common.utils.StringUtils;
@@ -33,7 +33,7 @@
  */
 @Log4j2
 @ToolAnnotation(MLModelTool.TYPE)
-public class MLModelTool implements Tool {
+public class MLModelTool implements WithModelTool {
     public static final String TYPE = "MLModelTool";
     public static final String RESPONSE_FIELD = "response_field";
     public static final String MODEL_ID_FIELD = "model_id";
@@ -127,7 +127,7 @@ public boolean validate(Map<String, String> parameters) {
         return true;
     }

-    public static class Factory implements Tool.Factory<MLModelTool> {
+    public static class Factory implements WithModelTool.Factory<MLModelTool> {
         private Client client;

         private static Factory INSTANCE;
@@ -172,5 +172,10 @@ public String getDefaultType() {
         public String getDefaultVersion() {
             return null;
         }
+
+        @Override
+        public List<String> getAllModelKeys() {
+            return List.of(MODEL_ID_FIELD);
+        }
     }
 }

@@ -0,0 +1,66 @@
+package org.opensearch.ml.engine.utils;
+
+import static org.opensearch.ml.common.CommonValue.ML_AGENT_INDEX;
+import static org.opensearch.ml.common.CommonValue.TOOL_PARAMETERS_PREFIX;
+
+import java.util.HashSet;
+import java.util.Map;
+import java.util.Set;
+
+import org.opensearch.action.search.SearchRequest;
+import org.opensearch.index.query.BoolQueryBuilder;
+import org.opensearch.index.query.QueryBuilders;
+import org.opensearch.ml.common.agent.MLAgent;
+import org.opensearch.ml.common.spi.tools.Tool;
+import org.opensearch.ml.common.spi.tools.WithModelTool;
+import org.opensearch.search.builder.SearchSourceBuilder;
+
+public class AgentModelsSearcher {
+    private final Set<String> relatedModelIdSet;
+
+    public AgentModelsSearcher(Map<String, Tool.Factory> toolFactories) {
+        relatedModelIdSet = new HashSet<>();
+        for (Map.Entry<String, Tool.Factory> entry : toolFactories.entrySet()) {
+            Tool.Factory toolFactory = entry.getValue();
+            if (toolFactory instanceof WithModelTool.Factory) {
+                WithModelTool.Factory withModelTool = (WithModelTool.Factory) toolFactory;
+                relatedModelIdSet.addAll(withModelTool.getAllModelKeys());
+            }
+        }
+    }
+
+    /**
+     * Construct a should query to search all agent which containing candidate model Id
+
+     @param candidateModelId the candidate model Id
+     @return a should search request towards agent index.
+     */
+    public SearchRequest constructQueryRequestToSearchModelIdInsideAgent(String candidateModelId) {
+        SearchRequest searchRequest = new SearchRequest(ML_AGENT_INDEX);
+        // Two conditions here
+        // 1. {[(exists hidden field) and (hidden field = false)] or (not exist hidden field)} and
+        // 2. Any model field contains candidate ID
+        BoolQueryBuilder searchAgentQuery = QueryBuilders.boolQuery();
+
+        BoolQueryBuilder hiddenFieldQuery = QueryBuilders.boolQuery();
+        // not exist hidden
+        hiddenFieldQuery.should(QueryBuilders.boolQuery().mustNot(QueryBuilders.existsQuery(MLAgent.IS_HIDDEN_FIELD)));
+        // exist but equal to false
+        BoolQueryBuilder existHiddenFieldQuery = QueryBuilders.boolQuery();
+        existHiddenFieldQuery.must(QueryBuilders.termsQuery(MLAgent.IS_HIDDEN_FIELD, false));
+        existHiddenFieldQuery.must(QueryBuilders.existsQuery(MLAgent.IS_HIDDEN_FIELD));
+        hiddenFieldQuery.should(existHiddenFieldQuery);
+
+        //
+        BoolQueryBuilder modelIdQuery = QueryBuilders.boolQuery();
+        for (String keyField : relatedModelIdSet) {
+            modelIdQuery.should(QueryBuilders.termsQuery(TOOL_PARAMETERS_PREFIX + keyField, candidateModelId));
+        }
+
+        searchAgentQuery.must(hiddenFieldQuery);
+        searchAgentQuery.must(modelIdQuery);
+        searchRequest.source(new SearchSourceBuilder().query(searchAgentQuery));
+        return searchRequest;
+    }
+
+}
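
Note on the agent query in `AgentModelsSearcher`: the two conditions named in its comments (the agent is not hidden, and some model-key field holds the candidate id) can be modelled without any OpenSearch classes. The following is a minimal sketch, not the plugin's code. It assumes each agent is flattened to a map from field path to a single value, the class `AgentFilterSketch` and its method names are hypothetical, and the literal `"is_hidden"` is an assumed stand-in for `MLAgent.IS_HIDDEN_FIELD`, whose value the diff does not show. Edge cases such as an empty key set are not modelled.

```java
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;

// Plain-Java model of the agent filter; no OpenSearch classes are involved.
final class AgentFilterSketch {
    // Same literal as CommonValue.TOOL_PARAMETERS_PREFIX in this commit.
    static final String TOOL_PARAMETERS_PREFIX = "tools.parameters.";
    // Assumption: stands in for MLAgent.IS_HIDDEN_FIELD (value not shown in the diff).
    static final String IS_HIDDEN = "is_hidden";

    // Union of the keys that each model-aware factory reports via getAllModelKeys().
    static Set<String> collectModelKeys(List<List<String>> keysPerFactory) {
        Set<String> keys = new LinkedHashSet<>();
        for (List<String> factoryKeys : keysPerFactory) {
            keys.addAll(factoryKeys);
        }
        return keys;
    }

    // Condition 1: the hidden field is missing, or present and equal to false.
    static boolean isVisible(Map<String, Object> agent) {
        return !agent.containsKey(IS_HIDDEN) || Boolean.FALSE.equals(agent.get(IS_HIDDEN));
    }

    // Condition 2: at least one "tools.parameters.<key>" field equals the candidate id.
    static boolean referencesModel(Map<String, Object> agent, Set<String> modelKeys, String modelId) {
        for (String key : modelKeys) {
            if (modelId.equals(agent.get(TOOL_PARAMETERS_PREFIX + key))) {
                return true;
            }
        }
        return false;
    }

    // An agent blocks deletion only when both conditions hold.
    static boolean blocksDeletion(Map<String, Object> agent, Set<String> modelKeys, String modelId) {
        return isVisible(agent) && referencesModel(agent, modelKeys, modelId);
    }

    public static void main(String[] args) {
        Set<String> keys = collectModelKeys(List.of(List.of("model_id")));
        Map<String, Object> agent = Map.of("tools.parameters.model_id", "m1");
        System.out.println(blocksDeletion(agent, keys, "m1"));
    }
}
```

Keeping the key collection separate from the predicate mirrors the split in the commit, where the constructor gathers keys once and the query is built per candidate model id.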

ml-algorithms/src/test/java/org/opensearch/ml/engine/tools/MLModelToolTests.java (+2)
@@ -14,6 +14,7 @@
 import static org.mockito.Mockito.doAnswer;
 import static org.mockito.Mockito.verify;
 import static org.opensearch.ml.engine.tools.MLModelTool.DEFAULT_DESCRIPTION;
+import static org.opensearch.ml.engine.tools.MLModelTool.MODEL_ID_FIELD;

 import java.util.Arrays;
 import java.util.Collections;
@@ -218,5 +219,6 @@ public void testTool() {
         assertTrue(tool.validate(otherParams));
         assertFalse(tool.validate(emptyParams));
         assertEquals(DEFAULT_DESCRIPTION, tool.getDescription());
+        assertEquals(List.of(MODEL_ID_FIELD), MLModelTool.Factory.getInstance().getAllModelKeys());
     }
 }

@@ -0,0 +1,96 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+package org.opensearch.ml.engine.utils;
+
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertTrue;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.when;
+
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.Map;
+
+import org.junit.Test;
+import org.opensearch.action.search.SearchRequest;
+import org.opensearch.index.query.BoolQueryBuilder;
+import org.opensearch.index.query.ExistsQueryBuilder;
+import org.opensearch.index.query.QueryBuilder;
+import org.opensearch.index.query.TermsQueryBuilder;
+import org.opensearch.ml.common.agent.MLAgent;
+import org.opensearch.ml.common.spi.tools.Tool;
+import org.opensearch.ml.common.spi.tools.WithModelTool;
+
+public class AgentModelSearcherTests {
+
+    @Test
+    public void testConstructor_CollectsModelIds() {
+        // Arrange
+        WithModelTool.Factory withModelToolFactory1 = mock(WithModelTool.Factory.class);
+        when(withModelToolFactory1.getAllModelKeys()).thenReturn(Arrays.asList("modelKey1", "modelKey2"));
+
+        WithModelTool.Factory withModelToolFactory2 = mock(WithModelTool.Factory.class);
+        when(withModelToolFactory2.getAllModelKeys()).thenReturn(Collections.singletonList("anotherModelKey"));
+
+        // This tool factory does not implement WithModelTool.Factory
+        Tool.Factory regularToolFactory = mock(Tool.Factory.class);
+
+        Map<String, Tool.Factory> toolFactories = new HashMap<>();
+        toolFactories.put("withModelTool1", withModelToolFactory1);
+        toolFactories.put("withModelTool2", withModelToolFactory2);
+        toolFactories.put("regularTool", regularToolFactory);
+
+        // Act
+        AgentModelsSearcher searcher = new AgentModelsSearcher(toolFactories);
+
+        // (Optional) We can't directly access relatedModelIdSet,
+        // but we can test the behavior indirectly using the search call:
+        SearchRequest request = searcher.constructQueryRequestToSearchModelIdInsideAgent("candidateId");
+
+        // Assert
+        // Verify the searchRequest uses all keys from the WithModelTool factories
+        BoolQueryBuilder boolQueryBuilder = (BoolQueryBuilder) request.source().query();
+        // We expect modelKey1, modelKey2, anotherModelKey => total 3 "should" clauses
+        assertEquals(2, boolQueryBuilder.must().size());
+        for (QueryBuilder query : boolQueryBuilder.must()) {
+            BoolQueryBuilder subBoolQueryBuilder = (BoolQueryBuilder) query;
+            assertTrue(subBoolQueryBuilder.should().size() == 2 || subBoolQueryBuilder.should().size() == 3);
+            if (subBoolQueryBuilder.should().size() == 3) {
+                boolQueryBuilder.should().forEach(subQuery -> {
+                    assertTrue(subQuery instanceof TermsQueryBuilder);
+                    TermsQueryBuilder termsQuery = (TermsQueryBuilder) subQuery;
+                    // Each TermsQueryBuilder should contain candidateModelId
+                    assertTrue(termsQuery.values().contains("candidateId"));
+                });
+            } else {
+                boolQueryBuilder.should().forEach(subQuery -> {
+                    assertTrue(subQuery instanceof BoolQueryBuilder);
+                    BoolQueryBuilder boolQuery = (BoolQueryBuilder) subQuery;
+                    assertTrue(boolQuery.must().size() == 2 || boolQuery.mustNot().size() == 1);
+                    if (boolQuery.must().size() == 2) {
+                        boolQuery.must().forEach(existSubQuery -> {
+                            assertTrue(existSubQuery instanceof ExistsQueryBuilder || existSubQuery instanceof TermsQueryBuilder);
+                            if (existSubQuery instanceof TermsQueryBuilder) {
+                                TermsQueryBuilder termsQuery = (TermsQueryBuilder) existSubQuery;
+                                assertTrue(termsQuery.fieldName().equals(MLAgent.IS_HIDDEN_FIELD));
+                                assertTrue(termsQuery.values().contains(false));
+                            } else {
+                                ExistsQueryBuilder existsQuery = (ExistsQueryBuilder) existSubQuery;
+                                assertTrue(existsQuery.fieldName().equals(MLAgent.IS_HIDDEN_FIELD));
+                            }
+                        });
+                    } else {
+                        QueryBuilder mustNotQuery = boolQuery.mustNot().get(0);
+                        assertTrue(mustNotQuery instanceof ExistsQueryBuilder);
+                        assertEquals(MLAgent.IS_HIDDEN_FIELD, ((ExistsQueryBuilder) mustNotQuery).fieldName());
+                    }
+                });
+            }
+        }
+
+    }
+}

plugin/src/main/java/org/opensearch/ml/action/models/DeleteModelTransportAction.java (+211 -1)
@@ -15,20 +15,36 @@
 import static org.opensearch.ml.common.MLModel.IS_HIDDEN_FIELD;
 import static org.opensearch.ml.common.MLModel.MODEL_ID_FIELD;
 import static org.opensearch.ml.common.utils.StringUtils.getErrorMessage;
+import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_SAFE_DELETE_WITH_USAGE_CHECK;
 import static org.opensearch.ml.utils.RestActionUtils.getFetchSourceContext;

+import java.util.ArrayDeque;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Deque;
+import java.util.List;
+import java.util.Locale;
 import java.util.Map;
 import java.util.Objects;
+import java.util.concurrent.ConcurrentLinkedQueue;
 import java.util.concurrent.CountDownLatch;
 import java.util.concurrent.atomic.AtomicBoolean;
+import java.util.function.Supplier;

+import org.apache.commons.lang3.tuple.Pair;
 import org.opensearch.ExceptionsHelper;
 import org.opensearch.OpenSearchStatusException;
 import org.opensearch.ResourceNotFoundException;
 import org.opensearch.action.ActionRequest;
+import org.opensearch.action.ActionType;
 import org.opensearch.action.delete.DeleteRequest;
 import org.opensearch.action.delete.DeleteResponse;
 import org.opensearch.action.get.GetResponse;
+import org.opensearch.action.ingest.GetPipelineAction;
+import org.opensearch.action.ingest.GetPipelineRequest;
+import org.opensearch.action.search.GetSearchPipelineAction;
+import org.opensearch.action.search.GetSearchPipelineRequest;
+import org.opensearch.action.search.SearchRequest;
 import org.opensearch.action.support.ActionFilters;
 import org.opensearch.action.support.HandledTransportAction;
 import org.opensearch.client.Client;
@@ -37,11 +53,14 @@
 import org.opensearch.common.settings.Settings;
 import org.opensearch.common.util.concurrent.ThreadContext;
 import org.opensearch.common.xcontent.LoggingDeprecationHandler;
+import org.opensearch.common.xcontent.XContentHelper;
+import org.opensearch.common.xcontent.json.JsonXContent;
 import org.opensearch.commons.authuser.User;
 import org.opensearch.core.action.ActionListener;
 import org.opensearch.core.rest.RestStatus;
 import org.opensearch.core.xcontent.NamedXContentRegistry;
 import org.opensearch.core.xcontent.XContentParser;
+import org.opensearch.index.IndexNotFoundException;
 import org.opensearch.index.query.BoolQueryBuilder;
 import org.opensearch.index.query.TermQueryBuilder;
 import org.opensearch.index.query.TermsQueryBuilder;
@@ -54,6 +73,7 @@
 import org.opensearch.ml.common.transport.model.MLModelDeleteAction;
 import org.opensearch.ml.common.transport.model.MLModelDeleteRequest;
 import org.opensearch.ml.common.transport.model.MLModelGetRequest;
+import org.opensearch.ml.engine.utils.AgentModelsSearcher;
 import org.opensearch.ml.helper.ModelAccessControlHelper;
 import org.opensearch.ml.settings.MLFeatureEnabledSetting;
 import org.opensearch.ml.utils.RestActionUtils;
@@ -62,6 +82,7 @@
 import org.opensearch.remote.metadata.client.GetDataObjectRequest;
 import org.opensearch.remote.metadata.client.SdkClient;
 import org.opensearch.remote.metadata.common.SdkClientUtils;
+import org.opensearch.search.SearchHit;
 import org.opensearch.search.fetch.subphase.FetchSourceContext;
 import org.opensearch.tasks.Task;
 import org.opensearch.transport.TransportService;
@@ -80,6 +101,10 @@ public class DeleteModelTransportAction extends HandledTransportAction<ActionReq
     static final String BULK_FAILURE_MSG = "Bulk failure while deleting model of ";
     static final String SEARCH_FAILURE_MSG = "Search failure while deleting model of ";
     static final String OS_STATUS_EXCEPTION_MESSAGE = "Failed to delete all model chunks";
+    static final String PIPELINE_TARGET_MODEL_KEY = "model_id";
+
+    Boolean isSafeDelete;
+
     final Client client;
     final SdkClient sdkClient;
     final NamedXContentRegistry xContentRegistry;
@@ -90,6 +115,8 @@ public class DeleteModelTransportAction extends HandledTransportAction<ActionReq
     final ModelAccessControlHelper modelAccessControlHelper;
     private final MLFeatureEnabledSetting mlFeatureEnabledSetting;

+    final AgentModelsSearcher agentModelsSearcher;
+
     @Inject
     public DeleteModelTransportAction(
         TransportService transportService,
@@ -100,6 +127,7 @@ public DeleteModelTransportAction(
         NamedXContentRegistry xContentRegistry,
         ClusterService clusterService,
         ModelAccessControlHelper modelAccessControlHelper,
+        AgentModelsSearcher agentModelsSearcher,
         MLFeatureEnabledSetting mlFeatureEnabledSetting
     ) {
         super(MLModelDeleteAction.NAME, transportService, actionFilters, MLModelDeleteRequest::new);
@@ -108,6 +136,10 @@ public DeleteModelTransportAction(
         this.xContentRegistry = xContentRegistry;
         this.clusterService = clusterService;
         this.modelAccessControlHelper = modelAccessControlHelper;
+        this.agentModelsSearcher = agentModelsSearcher;
+        this.settings = settings;
+        isSafeDelete = ML_COMMONS_SAFE_DELETE_WITH_USAGE_CHECK.get(settings);
+        clusterService.getClusterSettings().addSettingsUpdateConsumer(ML_COMMONS_SAFE_DELETE_WITH_USAGE_CHECK, it -> isSafeDelete = it);
         this.mlFeatureEnabledSetting = mlFeatureEnabledSetting;
     }

@@ -193,7 +225,19 @@ protected void doExecute(Task task, ActionRequest request, ActionListener<Delete
                                         )
                                     );
                             } else if (isModelNotDeployed(mlModelState)) {
-                                deleteModel(modelId, tenantId, mlModel.getAlgorithm().name(), isHidden, actionListener);
+                                if (isSafeDelete) {
+                                    // We only check downstream task when it's not hidden and cluster setting is true.
+                                    checkDownstreamTaskBeforeDeleteModel(
+                                        modelId,
+                                        tenantId,
+                                        mlModel.getAlgorithm().name(),
+                                        isHidden,
+                                        actionListener
+                                    );
+                                } else {
+                                    deleteModel(modelId, tenantId, mlModel.getAlgorithm().name(), isHidden, actionListener);
+                                }
+                                // deleteModel(modelId, tenantId, mlModel.getAlgorithm().name(), isHidden, actionListener);
                             } else {
                                 wrappedListener
                                     .onFailure(
@@ -305,6 +349,107 @@ private void deleteModel(
         });
     }

+    private void checkDownstreamTaskBeforeDeleteModel(
+        String modelId,
+        String tenantId,
+        String algorithm,
+        Boolean isHidden,
+        ActionListener<DeleteResponse> actionListener
+    ) {
+        // Now checks 3 resources associated with the model id 1. Agent 2. Search pipeline 3. ingest pipeline
+        CountDownLatch countDownLatch = new CountDownLatch(3);
+        AtomicBoolean noneBlocked = new AtomicBoolean(true);
+        ConcurrentLinkedQueue<String> errorMessages = new ConcurrentLinkedQueue<>();
+        ActionListener<Boolean> countDownActionListener = ActionListener.wrap(b -> {
+            countDownLatch.countDown();
+            noneBlocked.compareAndSet(true, b);
+            if (countDownLatch.getCount() == 0) {
+                if (noneBlocked.get()) {
+                    deleteModel(modelId, tenantId, algorithm, isHidden, actionListener);
+                }
+            }
+        }, e -> {
+            countDownLatch.countDown();
+            noneBlocked.set(false);
+            errorMessages.add(e.getMessage());
+            actionListener.onFailure(new OpenSearchStatusException(e.getMessage(), RestStatus.CONFLICT));
+
+        });
+        checkAgentBeforeDeleteModel(modelId, countDownActionListener);
+        checkIngestPipelineBeforeDeleteModel(modelId, countDownActionListener);
+        checkSearchPipelineBeforeDeleteModel(modelId, countDownActionListener);
+    }
+
+    private void checkAgentBeforeDeleteModel(String modelId, ActionListener<Boolean> actionListener) {
+        // check whether agent are using them
+        SearchRequest searchAgentRequest = agentModelsSearcher.constructQueryRequestToSearchModelIdInsideAgent(modelId);
+        client.search(searchAgentRequest, ActionListener.wrap(searchResponse -> {
+            SearchHit[] searchHits = searchResponse.getHits().getHits();
+            if (searchHits.length == 0) {
+                actionListener.onResponse(true);
+            } else {
+                String errorMessage = formatAgentErrorMessage(searchHits);
+                actionListener.onFailure(new OpenSearchStatusException(errorMessage, RestStatus.CONFLICT));
+            }
+
+        }, e -> {
+            if (e instanceof IndexNotFoundException) {
+                actionListener.onResponse(true);
+                return;
+            }
+            log.error("Failed to delete ML Model: " + modelId, e);
+            actionListener.onFailure(e);
+
+        }));
+    }
+
+    private void checkIngestPipelineBeforeDeleteModel(String modelId, ActionListener<Boolean> actionListener) {
+        checkPipelineBeforeDeleteModel(modelId, actionListener, "ingest", GetPipelineRequest::new, GetPipelineAction.INSTANCE);
+
+    }
+
+    private void checkSearchPipelineBeforeDeleteModel(String modelId, ActionListener<Boolean> actionListener) {
+        checkPipelineBeforeDeleteModel(modelId, actionListener, "search", GetSearchPipelineRequest::new, GetSearchPipelineAction.INSTANCE);
+
+    }
+
+    private void checkPipelineBeforeDeleteModel(
+        String modelId,
+        ActionListener<Boolean> actionListener,
+        String pipelineType,
+        Supplier<ActionRequest> requestSupplier,
+        ActionType actionType
+    ) {
+        ActionRequest request = requestSupplier.get();
+        client.execute(actionType, request, ActionListener.wrap(pipelineResponse -> {
+            Map<String, Object> allConfigMap = XContentHelper.convertToMap(JsonXContent.jsonXContent, pipelineResponse.toString(), true);
+            List<String> allDependentPipelineIds = findDependentPipelinesEasy(allConfigMap, modelId);
+            if (allDependentPipelineIds.isEmpty()) {
+                actionListener.onResponse(true);
+            } else {
+                actionListener
+                    .onFailure(
+                        new OpenSearchStatusException(
+                            String
+                                .format(
+                                    Locale.ROOT,
+                                    "%d %s pipelines are still using this model, please delete or update the pipelines first: %s",
+                                    allDependentPipelineIds.size(),
+                                    pipelineType,
+                                    Arrays.toString(allDependentPipelineIds.toArray(new String[0]))
+                                ),
+                            RestStatus.CONFLICT
+                        )
+                    );
+            }
+        }, e -> {
+            log.error("Failed to delete ML Model: " + modelId, e);
+            actionListener.onFailure(e);
+
+        }));
+
+    }
+
     private void deleteModelChunksAndController(
         ActionListener<DeleteResponse> actionListener,
         String modelId,
@@ -410,6 +555,71 @@ private Boolean isModelNotDeployed(MLModelState mlModelState) {
             && !mlModelState.equals(MLModelState.PARTIALLY_DEPLOYED);
     }

+    private List<String> findDependentPipelinesEasy(Map<String, Object> allConfigMap, String candidateModelId) {
+        List<String> dependentPipelineConfigurations = new ArrayList<>();
+        for (Map.Entry<String, Object> entry : allConfigMap.entrySet()) {
+            String id = entry.getKey();
+            Map<String, Object> config = (Map<String, Object>) entry.getValue();
+            if (searchThroughConfig(config, candidateModelId)) {
+                dependentPipelineConfigurations.add(id);
+            }
+        }
+        return dependentPipelineConfigurations;
+    }
+
+    // This method is to go through the pipeline configs and the configuration is a map of string to objects.
+    // Objects can be a list or a map. we will search exhaustively through the configuration for any match of the candidateId.
+    private Boolean searchThroughConfig(Object searchCandidate, String candidateId) {
+        // Use a stack to store the elements to be processed
+        Deque<Pair<String, Object>> stack = new ArrayDeque<>();
+        stack.push(Pair.of("", searchCandidate));
+
+        while (!stack.isEmpty()) {
+            // Pop an item from the stack
+            Pair<String, Object> current = stack.pop();
+            String currentKey = current.getLeft();
+            Object currentCandidate = current.getRight();
+
+            if (currentCandidate instanceof String && candidateId.equals(currentCandidate)) {
+                // Check for a match
+                if (PIPELINE_TARGET_MODEL_KEY.equals(currentKey)) {
+                    return true;
+                }
+            } else if (currentCandidate instanceof List<?>) {
+                // Push all elements in the list onto the stack
+                for (Object v : (List<?>) currentCandidate) {
+                    stack.push(Pair.of(currentKey, v));
+                }
+            } else if (currentCandidate instanceof Map<?, ?>) {
+                // Push all values in the map onto the stack
+                for (Map.Entry<?, ?> entry : ((Map<?, ?>) currentCandidate).entrySet()) {
+                    String key = (String) entry.getKey();
+                    Object value = entry.getValue();
+                    stack.push(Pair.of(key, value));
+                }
+            }
+        }
+
+        // If no match is found
+        return false;
+    }
+
+    private String formatAgentErrorMessage(SearchHit[] hits) {
+        List<String> agentIds = new ArrayList<>();
+        for (SearchHit hit : hits) {
+            Map<String, Object> sourceAsMap = hit.getSourceAsMap();
+            agentIds.add(hit.getId());
+        }
+        return String
+            .format(
+                Locale.ROOT,
+                "%d agents are still using this model, please delete or update the agents first, all visible agents are: %s",
+                hits.length,
+                Arrays.toString(agentIds.toArray(new String[0]))
+            );
+
+    }
+
     // this method is only to stub static method.
     @VisibleForTesting
     boolean isSuperAdminUserWrapper(ClusterService clusterService, Client client) {
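
Note on `checkDownstreamTaskBeforeDeleteModel`: in the hunk shown, each failing check forwards its own error to the caller's listener, and the `errorMessages` queue is filled but never read. One possible alternative, sketched here in plain Java with no OpenSearch types, waits for all checks to finish and then reports exactly once with a combined message. This is not the commit's implementation; `UsageCheckAggregator` and its method names are hypothetical.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.Queue;
import java.util.concurrent.ConcurrentLinkedQueue;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.function.Consumer;

// Collects the outcome of a fixed number of asynchronous checks and fires one callback.
final class UsageCheckAggregator {
    private final AtomicInteger remaining;
    private final Queue<String> blockers = new ConcurrentLinkedQueue<>();
    private final Runnable onAllClear;
    private final Consumer<String> onBlocked;

    UsageCheckAggregator(int checks, Runnable onAllClear, Consumer<String> onBlocked) {
        this.remaining = new AtomicInteger(checks);
        this.onAllClear = onAllClear;
        this.onBlocked = onBlocked;
    }

    // A check found no dependent resource.
    void pass() {
        finishOne();
    }

    // A check found a dependent resource (or failed); remember why.
    void block(String reason) {
        blockers.add(reason);
        finishOne();
    }

    // decrementAndGet() returns 0 for exactly one caller, so only one callback runs.
    private void finishOne() {
        if (remaining.decrementAndGet() == 0) {
            if (blockers.isEmpty()) {
                onAllClear.run();
            } else {
                onBlocked.accept(String.join("; ", blockers));
            }
        }
    }

    public static void main(String[] args) {
        List<String> events = new ArrayList<>();
        UsageCheckAggregator aggregator =
            new UsageCheckAggregator(3, () -> events.add("delete"), msg -> events.add("blocked: " + msg));
        aggregator.pass();                       // agents: none found
        aggregator.block("1 ingest pipeline");   // ingest pipelines: one dependent
        aggregator.pass();                       // search pipelines: none found
        System.out.println(events);
    }
}
```

The design choice here is to trade early failure for a single, complete error message; whether that fits the delete API's contract is a judgment the commit itself does not settle.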

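Note on `searchThroughConfig`: the stack-based traversal can be tried outside the plugin. The sketch below is a dependency-free approximation of that method, assuming the pipeline configuration has already been parsed into nested `Map` and `List` values. It swaps `Pair` from commons-lang3 for `Map.Entry`, and the class name `PipelineModelSearchSketch` and the sample pipeline body are illustrative assumptions rather than anything taken from the commit.

```java
import java.util.AbstractMap.SimpleImmutableEntry;
import java.util.ArrayDeque;
import java.util.Deque;
import java.util.List;
import java.util.Map;

// Iterative depth-first search for a "model_id" entry whose value equals the candidate id.
final class PipelineModelSearchSketch {
    static final String TARGET_KEY = "model_id";

    static boolean referencesModel(Object config, String candidateId) {
        // Each stack element pairs a value with the key it was found under.
        Deque<Map.Entry<String, Object>> stack = new ArrayDeque<>();
        stack.push(new SimpleImmutableEntry<>("", config));

        while (!stack.isEmpty()) {
            Map.Entry<String, Object> current = stack.pop();
            String key = current.getKey();
            Object value = current.getValue();

            if (value instanceof String) {
                // Only a string stored under "model_id" counts as a reference.
                if (TARGET_KEY.equals(key) && candidateId.equals(value)) {
                    return true;
                }
            } else if (value instanceof List<?>) {
                // List items inherit the key of the list that contains them.
                for (Object item : (List<?>) value) {
                    stack.push(new SimpleImmutableEntry<>(key, item));
                }
            } else if (value instanceof Map<?, ?>) {
                for (Map.Entry<?, ?> entry : ((Map<?, ?>) value).entrySet()) {
                    stack.push(new SimpleImmutableEntry<>(String.valueOf(entry.getKey()), entry.getValue()));
                }
            }
        }
        return false;
    }

    public static void main(String[] args) {
        // Hypothetical ingest pipeline body with one embedding processor.
        Object pipeline = Map.of(
            "description", "demo pipeline",
            "processors", List.of(Map.of("text_embedding", Map.of("model_id", "m1")))
        );
        System.out.println(referencesModel(pipeline, "m1"));
    }
}
```

Because list items inherit their parent's key, a configuration such as `"model_id": ["m1", "m2"]` also counts as a reference, which matches how the method in the diff pushes list elements with `currentKey`.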
plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java (+6)
@@ -197,6 +197,7 @@
 import org.opensearch.ml.engine.tools.MLModelTool;
 import org.opensearch.ml.engine.tools.SearchIndexTool;
 import org.opensearch.ml.engine.tools.VisualizationsTool;
+import org.opensearch.ml.engine.utils.AgentModelsSearcher;
 import org.opensearch.ml.helper.ConnectorAccessControlHelper;
 import org.opensearch.ml.helper.ModelAccessControlHelper;
 import org.opensearch.ml.memory.ConversationalMemoryHandler;
@@ -366,6 +367,7 @@ public class MachineLearningPlugin extends Plugin
     private IndexUtils indexUtils;
     private ModelHelper modelHelper;
     private DiscoveryNodeHelper nodeHelper;
+    private AgentModelsSearcher agentModelsSearcher;

     private MLModelChunkUploader mlModelChunkUploader;
     private MLEngine mlEngine;
@@ -657,6 +659,8 @@ public Collection<Object> createComponents(
             toolFactories.putAll(externalToolFactories);
         }

+        agentModelsSearcher = new AgentModelsSearcher(toolFactories);
+
         MLMemoryManager memoryManager = new MLMemoryManager(client, clusterService, new ConversationMetaIndex(client, clusterService));
         Map<String, Memory.Factory> memoryFactoryMap = new HashMap<>();
         ConversationIndexMemory.Factory conversationIndexMemoryFactory = new ConversationIndexMemory.Factory();
@@ -723,6 +727,7 @@ public Collection<Object> createComponents(
                 mlStats,
                 mlTaskManager,
                 mlModelManager,
+                agentModelsSearcher,
                 mlIndicesHandler,
                 mlInputDatasetHandler,
                 mlTrainingTaskRunner,
@@ -1033,6 +1038,7 @@ public List<Setting<?>> getSettings() {
                 MLCommonsSettings.ML_COMMONS_MAX_BATCH_INFERENCE_TASKS,
                 MLCommonsSettings.ML_COMMONS_MAX_BATCH_INGESTION_TASKS,
                 MLCommonsSettings.ML_COMMONS_BATCH_INGESTION_BULK_SIZE,
+                MLCommonsSettings.ML_COMMONS_SAFE_DELETE_WITH_USAGE_CHECK,
                 MLCommonsSettings.ML_COMMONS_MULTI_TENANCY_ENABLED,
                 MLCommonsSettings.REMOTE_METADATA_TYPE,
                 MLCommonsSettings.REMOTE_METADATA_ENDPOINT,

plugin/src/main/java/org/opensearch/ml/settings/MLCommonsSettings.java (+3)
@@ -269,6 +269,9 @@ private MLCommonsSettings() {}
     public static final Setting<Boolean> ML_COMMONS_CONTROLLER_ENABLED = Setting
         .boolSetting("plugins.ml_commons.controller_enabled", true, Setting.Property.NodeScope, Setting.Property.Dynamic);

+    // This flag is the determine whether we need to check downstream task before deleting a model.
+    public static final Setting<Boolean> ML_COMMONS_SAFE_DELETE_WITH_USAGE_CHECK = Setting
+        .boolSetting("plugins.ml_commons.safe_delete_model", false, Setting.Property.NodeScope, Setting.Property.Dynamic);
     /**
      * Indicates whether multi-tenancy is enabled in ML Commons.
      *

plugin/src/test/java/org/opensearch/ml/action/models/DeleteModelTransportActionTests.java (+406 -43)
Large diffs are not rendered by default.

@@ -0,0 +1,26 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+package org.opensearch.ml.common.spi.tools;
+
+
+import java.util.List;
+
+/**
+ * General tool interface.
+ */
+public interface WithModelTool extends Tool {
+    /**
+     * Tool factory which can create instance of {@link Tool}.
+     * @param <T> The subclass this factory produces
+     */
+    interface Factory<T extends WithModelTool> extends Tool.Factory<T> {
+        /**
+         * Get model id related field names
+         * @return the list of all model id related field names
+         */
+        List<String> getAllModelKeys();
+    }
+}
