
Commit ad5b1e5

b4sjoo authored and github-actions[bot] committed
Use model type to check local or remote model (#3597)
* use model type to check local or remote model
* spotless
* Ignore test resource
* Add java doc
* Handle when model not in cache
* Handle when model not in cache

Signed-off-by: Sicheng Song <sicheng.song@outlook.com>

(cherry picked from commit 696b1e1)
1 parent 6697c20 commit ad5b1e5

File tree: 2 files changed (+35 -25 lines changed)
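Summary of the change: the REST predict action now distinguishes the model's type (REMOTE versus a local model, taken from the model cache or the model index) from the algorithm name the caller supplied. Feature-flag checks are keyed on the model type, while the caller's algorithm only controls how the prediction input is parsed, with the model type as a fallback when no algorithm is given. A minimal standalone sketch of that gating rule follows (simplified enum and flags, not the plugin's actual classes; FunctionName.isDLModel is approximated here by treating every non-REMOTE type as local):

    // Minimal sketch, not ml-commons code: feature gating keyed on the model type rather than
    // on the algorithm string supplied with the request.
    public class ModelTypeGateSketch {
        enum FunctionName { REMOTE, TEXT_EMBEDDING, KMEANS }

        static void checkEnabled(FunctionName modelType, boolean remoteInferenceEnabled, boolean localModelEnabled) {
            if (modelType == FunctionName.REMOTE && !remoteInferenceEnabled) {
                throw new IllegalStateException("Remote inference is disabled");
            }
            // Simplification of FunctionName.isDLModel(...): treat every non-REMOTE type as a local model.
            if (modelType != FunctionName.REMOTE && !localModelEnabled) {
                throw new IllegalStateException("Local model inference is disabled");
            }
        }

        public static void main(String[] args) {
            // A remote model invoked with "text_embedding" as the algorithm request parameter:
            // before this change that string decided which feature flag was checked; now the
            // model type recovered from the cache (or model index) does.
            checkEnabled(FunctionName.REMOTE, true, false);  // passes: remote inference is enabled
            checkEnabled(FunctionName.KMEANS, false, true);  // passes: local models are enabled
        }
    }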

plugin/src/main/java/org/opensearch/ml/rest/RestMLPredictionAction.java (+26 -18)
@@ -19,6 +19,7 @@
 import java.io.IOException;
 import java.util.List;
 import java.util.Locale;
+import java.util.Objects;
 import java.util.Optional;
 
 import org.opensearch.client.node.NodeClient;
@@ -83,27 +84,30 @@ public List<Route> routes() {
 
     @Override
     public RestChannelConsumer prepareRequest(RestRequest request, NodeClient client) throws IOException {
-        String algorithm = request.param(PARAMETER_ALGORITHM);
+        String userAlgorithm = request.param(PARAMETER_ALGORITHM);
         String modelId = getParameterId(request, PARAMETER_MODEL_ID);
         Optional<FunctionName> functionName = modelManager.getOptionalModelFunctionName(modelId);
 
-        if (algorithm == null && functionName.isPresent()) {
-            algorithm = functionName.get().name();
-        }
-
-        if (algorithm != null) {
-            MLPredictionTaskRequest mlPredictionTaskRequest = getRequest(modelId, algorithm, request);
-            return channel -> client
-                .execute(MLPredictionTaskAction.INSTANCE, mlPredictionTaskRequest, new RestToXContentListener<>(channel));
+        // check if the model is in cache
+        if (functionName.isPresent()) {
+            MLPredictionTaskRequest predictionRequest = getRequest(
+                modelId,
+                functionName.get().name(),
+                Objects.requireNonNullElse(userAlgorithm, functionName.get().name()),
+                request
+            );
+            return channel -> client.execute(MLPredictionTaskAction.INSTANCE, predictionRequest, new RestToXContentListener<>(channel));
         }
 
+        // If the model isn't in cache
        return channel -> {
             ActionListener<MLModel> listener = ActionListener.wrap(mlModel -> {
-                String algoName = mlModel.getAlgorithm().name();
+                String modelType = mlModel.getAlgorithm().name();
+                String modelAlgorithm = Objects.requireNonNullElse(userAlgorithm, mlModel.getAlgorithm().name());
                 client
                     .execute(
                         MLPredictionTaskAction.INSTANCE,
-                        getRequest(modelId, algoName, request),
+                        getRequest(modelId, modelType, modelAlgorithm, request),
                         new RestToXContentListener<>(channel)
                     );
             }, e -> {
@@ -126,18 +130,22 @@ public RestChannelConsumer prepareRequest(RestRequest request, NodeClient client
     }
 
     /**
-     * Creates a MLPredictionTaskRequest from a RestRequest
+     * Creates a MLPredictionTaskRequest from a RestRequest. This method validates the request based on
+     * enabled features and model types, and parses the input data for prediction.
      *
-     * @param request RestRequest
-     * @return MLPredictionTaskRequest
+     * @param modelId The ID of the ML model to use for prediction
+     * @param modelType The type of the ML model, extracted from the model cache to specify whether it is a remote or a local model
+     * @param userAlgorithm The algorithm specified by the user for prediction; this is used to determine the interface of the model
+     * @param request The REST request containing prediction input data
+     * @return MLPredictionTaskRequest configured with the model and input parameters
      */
     @VisibleForTesting
-    MLPredictionTaskRequest getRequest(String modelId, String algorithm, RestRequest request) throws IOException {
+    MLPredictionTaskRequest getRequest(String modelId, String modelType, String userAlgorithm, RestRequest request) throws IOException {
         String tenantId = getTenantID(mlFeatureEnabledSetting.isMultiTenancyEnabled(), request);
         ActionType actionType = ActionType.from(getActionTypeFromRestRequest(request));
-        if (FunctionName.REMOTE.name().equals(algorithm) && !mlFeatureEnabledSetting.isRemoteInferenceEnabled()) {
+        if (FunctionName.REMOTE.name().equals(modelType) && !mlFeatureEnabledSetting.isRemoteInferenceEnabled()) {
            throw new IllegalStateException(REMOTE_INFERENCE_DISABLED_ERR_MSG);
-        } else if (FunctionName.isDLModel(FunctionName.from(algorithm.toUpperCase(Locale.ROOT)))
+        } else if (FunctionName.isDLModel(FunctionName.from(modelType.toUpperCase(Locale.ROOT)))
            && !mlFeatureEnabledSetting.isLocalModelEnabled()) {
            throw new IllegalStateException(LOCAL_MODEL_DISABLED_ERR_MSG);
        } else if (ActionType.BATCH_PREDICT == actionType && !mlFeatureEnabledSetting.isOfflineBatchInferenceEnabled()) {
@@ -148,7 +156,7 @@ MLPredictionTaskRequest getRequest(String modelId, String algorithm, RestRequest
 
         XContentParser parser = request.contentParser();
         ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser);
-        MLInput mlInput = MLInput.parse(parser, algorithm, actionType);
+        MLInput mlInput = MLInput.parse(parser, userAlgorithm, actionType);
         return new MLPredictionTaskRequest(modelId, mlInput, null, tenantId);
     }
 
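A note on the fallback used in both branches above: Objects.requireNonNullElse (java.util, since Java 9) returns its first argument unless it is null, in which case it returns the second. A standalone snippet, not plugin code, showing the behavior the diff relies on:

    import java.util.Objects;

    // Illustration of the fallback: prefer the algorithm from the request,
    // otherwise fall back to the model type recovered from the cache or model index.
    public class RequireNonNullElseDemo {
        public static void main(String[] args) {
            String fromRequest = null;      // request carried no algorithm parameter
            String modelType = "REMOTE";    // FunctionName taken from the model cache
            System.out.println(Objects.requireNonNullElse(fromRequest, modelType));       // prints REMOTE
            System.out.println(Objects.requireNonNullElse("text_embedding", modelType));  // prints text_embedding
        }
    }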

plugin/src/test/java/org/opensearch/ml/rest/RestMLPredictionActionTests.java (+9 -7)
@@ -70,7 +70,7 @@ public class RestMLPredictionActionTests extends OpenSearchTestCase {
     @Before
     public void setup() {
         MockitoAnnotations.openMocks(this);
-        when(modelManager.getOptionalModelFunctionName(anyString())).thenReturn(Optional.empty());
+        when(modelManager.getOptionalModelFunctionName(anyString())).thenReturn(Optional.of(FunctionName.REMOTE));
         when(mlFeatureEnabledSetting.isRemoteInferenceEnabled()).thenReturn(true);
         when(mlFeatureEnabledSetting.isLocalModelEnabled()).thenReturn(true);
         restMLPredictionAction = new RestMLPredictionAction(modelManager, mlFeatureEnabledSetting);
@@ -127,7 +127,8 @@ public void testRoutes_Batch() {
     @Test
     public void testGetRequest() throws IOException {
         RestRequest request = getRestRequest_PredictModel();
-        MLPredictionTaskRequest mlPredictionTaskRequest = restMLPredictionAction.getRequest("modelId", FunctionName.KMEANS.name(), request);
+        MLPredictionTaskRequest mlPredictionTaskRequest = restMLPredictionAction
+            .getRequest("modelId", FunctionName.KMEANS.name(), FunctionName.KMEANS.name(), request);
 
         MLInput mlInput = mlPredictionTaskRequest.getMlInput();
         verifyParsedKMeansMLInput(mlInput);
@@ -140,7 +141,8 @@ public void testGetRequest_RemoteInferenceDisabled() throws IOException {
 
         when(mlFeatureEnabledSetting.isRemoteInferenceEnabled()).thenReturn(false);
         RestRequest request = getRestRequest_PredictModel();
-        MLPredictionTaskRequest mlPredictionTaskRequest = restMLPredictionAction.getRequest("modelId", FunctionName.REMOTE.name(), request);
+        MLPredictionTaskRequest mlPredictionTaskRequest = restMLPredictionAction
+            .getRequest("modelId", FunctionName.REMOTE.name(), "text_embedding", request);
     }
 
     @Test
@@ -151,7 +153,7 @@ public void testGetRequest_LocalModelInferenceDisabled() throws IOException {
         when(mlFeatureEnabledSetting.isLocalModelEnabled()).thenReturn(false);
         RestRequest request = getRestRequest_PredictModel();
         MLPredictionTaskRequest mlPredictionTaskRequest = restMLPredictionAction
-            .getRequest("modelId", FunctionName.TEXT_EMBEDDING.name(), request);
+            .getRequest("modelId", FunctionName.TEXT_EMBEDDING.name(), "text_embedding", request);
     }
 
     @Test
@@ -194,7 +196,7 @@ public void testPrepareBatchRequest_WrongActionType() throws Exception {
         thrown.expectMessage("Wrong Action Type");
 
         RestRequest request = getBatchRestRequest_WrongActionType();
-        restMLPredictionAction.getRequest("model id", "remote", request);
+        restMLPredictionAction.getRequest("model id", "remote", "text_embedding", request);
     }
 
     @Ignore
@@ -232,7 +234,7 @@ public void testGetRequest_InvalidActionType() throws IOException {
         thrown.expectMessage("Wrong Action Type of models");
 
         RestRequest request = getBatchRestRequest_WrongActionType();
-        restMLPredictionAction.getRequest("model_id", FunctionName.REMOTE.name(), request);
+        restMLPredictionAction.getRequest("model_id", FunctionName.REMOTE.name(), "text_embedding", request);
     }
 
     @Test
@@ -242,7 +244,7 @@ public void testGetRequest_UnsupportedAlgorithm() throws IOException {
 
         // Create a RestRequest with an unsupported algorithm
         RestRequest request = getRestRequest_PredictModel();
-        restMLPredictionAction.getRequest("model_id", "INVALID_ALGO", request);
+        restMLPredictionAction.getRequest("model_id", "INVALID_ALGO", "text_embedding", request);
     }
 
     private RestRequest getRestRequest_PredictModel() {
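Because setup() now stubs getOptionalModelFunctionName to return Optional.of(FunctionName.REMOTE), the cache-hit branch is the default in these tests; a test that wants the "model not in cache" branch has to re-stub that call per test. A standalone Mockito sketch of that stubbing pattern, using a hypothetical ModelManagerLike interface rather than the real MLModelManager:

    import static org.mockito.ArgumentMatchers.anyString;
    import static org.mockito.Mockito.mock;
    import static org.mockito.Mockito.when;

    import java.util.Optional;

    // Standalone illustration of the stubbing change above: the default mock reports a cached
    // REMOTE model, and an individual test re-stubs the call to exercise the cache-miss path.
    public class MockStubSketch {
        enum FunctionName { REMOTE }

        interface ModelManagerLike {
            Optional<FunctionName> getOptionalModelFunctionName(String modelId);
        }

        public static void main(String[] args) {
            ModelManagerLike modelManager = mock(ModelManagerLike.class);

            // Default stub, as in setup(): the model is found in the cache.
            when(modelManager.getOptionalModelFunctionName(anyString())).thenReturn(Optional.of(FunctionName.REMOTE));
            System.out.println(modelManager.getOptionalModelFunctionName("any-id")); // Optional[REMOTE]

            // Per-test override to exercise the "model not in cache" branch.
            when(modelManager.getOptionalModelFunctionName(anyString())).thenReturn(Optional.empty());
            System.out.println(modelManager.getOptionalModelFunctionName("any-id")); // Optional.empty
        }
    }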
