Skip to content

Commit fc2bfc9

Browse files
b4sjoo and rithin-pullela-aws
authored and committed
Use model type to check local or remote model (opensearch-project#3597)
* use model type to check local or remote model
* spotless
* Ignore test resource
* Add java doc
* Handle when model not in cache
* Handle when model not in cache

---------

Signed-off-by: Sicheng Song <sicheng.song@outlook.com>
(cherry picked from commit 696b1e1)
1 parent 78868e6 commit fc2bfc9

File tree

2 files changed

+43
-31
lines changed

2 files changed

+43
-31
lines changed

plugin/src/main/java/org/opensearch/ml/rest/RestMLPredictionAction.java

+37-28
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
import java.io.IOException;
1616
import java.util.List;
1717
import java.util.Locale;
18+
import java.util.Objects;
1819
import java.util.Optional;
1920

2021
import org.opensearch.client.node.NodeClient;
@@ -65,42 +66,45 @@ public String getName() {
6566
@Override
6667
public List<Route> routes() {
6768
return ImmutableList
68-
.of(
69-
new Route(
70-
RestRequest.Method.POST,
71-
String.format(Locale.ROOT, "%s/_predict/{%s}/{%s}", ML_BASE_URI, PARAMETER_ALGORITHM, PARAMETER_MODEL_ID)
72-
),
73-
new Route(RestRequest.Method.POST, String.format(Locale.ROOT, "%s/models/{%s}/_predict", ML_BASE_URI, PARAMETER_MODEL_ID))
74-
);
69+
.of(
70+
new Route(
71+
RestRequest.Method.POST,
72+
String.format(Locale.ROOT, "%s/_predict/{%s}/{%s}", ML_BASE_URI, PARAMETER_ALGORITHM, PARAMETER_MODEL_ID)
73+
),
74+
new Route(RestRequest.Method.POST, String.format(Locale.ROOT, "%s/models/{%s}/_predict", ML_BASE_URI, PARAMETER_MODEL_ID))
75+
);
7576
}
7677

7778
@Override
7879
public RestChannelConsumer prepareRequest(RestRequest request, NodeClient client) throws IOException {
79-
String algorithm = request.param(PARAMETER_ALGORITHM);
80+
String userAlgorithm = request.param(PARAMETER_ALGORITHM);
8081
String modelId = getParameterId(request, PARAMETER_MODEL_ID);
8182
Optional<FunctionName> functionName = modelManager.getOptionalModelFunctionName(modelId);
8283

83-
if (algorithm == null && functionName.isPresent()) {
84-
algorithm = functionName.get().name();
85-
}
86-
87-
if (algorithm != null) {
88-
MLPredictionTaskRequest mlPredictionTaskRequest = getRequest(modelId, algorithm, request);
89-
return channel -> client
90-
.execute(MLPredictionTaskAction.INSTANCE, mlPredictionTaskRequest, new RestToXContentListener<>(channel));
84+
// check if the model is in cache
85+
if (functionName.isPresent()) {
86+
MLPredictionTaskRequest predictionRequest = getRequest(
87+
modelId,
88+
functionName.get().name(),
89+
Objects.requireNonNullElse(userAlgorithm, functionName.get().name()),
90+
request
91+
);
92+
return channel -> client.execute(MLPredictionTaskAction.INSTANCE, predictionRequest, new RestToXContentListener<>(channel));
9193
}
9294

95+
// If the model isn't in cache
9396
return channel -> {
9497
MLModelGetRequest getModelRequest = new MLModelGetRequest(modelId, false);
9598
ActionListener<MLModelGetResponse> listener = ActionListener.wrap(r -> {
9699
MLModel mlModel = r.getMlModel();
97-
String algoName = mlModel.getAlgorithm().name();
100+
String modelType = mlModel.getAlgorithm().name();
101+
String modelAlgorithm = Objects.requireNonNullElse(userAlgorithm, mlModel.getAlgorithm().name());
98102
client
99-
.execute(
100-
MLPredictionTaskAction.INSTANCE,
101-
getRequest(modelId, algoName, request),
102-
new RestToXContentListener<>(channel)
103-
);
103+
.execute(
104+
MLPredictionTaskAction.INSTANCE,
105+
getRequest(modelId, modelType, modelAlgorithm, request),
106+
new RestToXContentListener<>(channel)
107+
);
104108
}, e -> {
105109
log.error("Failed to get ML model", e);
106110
try {
@@ -115,20 +119,25 @@ public RestChannelConsumer prepareRequest(RestRequest request, NodeClient client
115119
}
116120

117121
/**
118-
* Creates a MLPredictionTaskRequest from a RestRequest
122+
* Creates a MLPredictionTaskRequest from a RestRequest. This method validates the request based on
123+
* enabled features and model types, and parses the input data for prediction.
119124
*
120-
* @param request RestRequest
121-
* @return MLPredictionTaskRequest
125+
* @param modelId The ID of the ML model to use for prediction
126+
* @param modelType The type of the ML model, extracted from model cache to specify if it is a remote model or a local model
127+
* @param userAlgorithm The algorithm specified by the user for prediction, this is used to determine the interface of the model
128+
* @param request The REST request containing prediction input data
129+
* @return MLPredictionTaskRequest configured with the model and input parameters
122130
*/
123131
@VisibleForTesting
124-
MLPredictionTaskRequest getRequest(String modelId, String algorithm, RestRequest request) throws IOException {
125-
if (FunctionName.REMOTE.name().equals(algorithm) && !mlFeatureEnabledSetting.isRemoteInferenceEnabled()) {
132+
MLPredictionTaskRequest getRequest(String modelId, String modelType, String userAlgorithm, RestRequest request) throws IOException {
133+
if (FunctionName.REMOTE.name().equals(modelType) && !mlFeatureEnabledSetting.isRemoteInferenceEnabled()) {
126134
throw new IllegalStateException(REMOTE_INFERENCE_DISABLED_ERR_MSG);
127135
}
128136
XContentParser parser = request.contentParser();
129137
ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser);
130-
MLInput mlInput = MLInput.parse(parser, algorithm);
138+
MLInput mlInput = MLInput.parse(parser, userAlgorithm);
131139
return new MLPredictionTaskRequest(modelId, mlInput, null);
132140
}
133141

134142
}
143+

plugin/src/test/java/org/opensearch/ml/rest/RestMLPredictionActionTests.java

+6-3
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,7 @@ public class RestMLPredictionActionTests extends OpenSearchTestCase {
6565
@Before
6666
public void setup() {
6767
MockitoAnnotations.openMocks(this);
68-
when(modelManager.getOptionalModelFunctionName(anyString())).thenReturn(Optional.empty());
68+
when(modelManager.getOptionalModelFunctionName(anyString())).thenReturn(Optional.of(FunctionName.REMOTE));
6969
when(mlFeatureEnabledSetting.isRemoteInferenceEnabled()).thenReturn(true);
7070
restMLPredictionAction = new RestMLPredictionAction(modelManager, mlFeatureEnabledSetting);
7171

@@ -107,7 +107,8 @@ public void testRoutes() {
107107

108108
public void testGetRequest() throws IOException {
109109
RestRequest request = getRestRequest_PredictModel();
110-
MLPredictionTaskRequest mlPredictionTaskRequest = restMLPredictionAction.getRequest("modelId", FunctionName.KMEANS.name(), request);
110+
MLPredictionTaskRequest mlPredictionTaskRequest = restMLPredictionAction
111+
.getRequest("modelId", FunctionName.KMEANS.name(), FunctionName.KMEANS.name(), request);
111112

112113
MLInput mlInput = mlPredictionTaskRequest.getMlInput();
113114
verifyParsedKMeansMLInput(mlInput);
@@ -119,7 +120,8 @@ public void testGetRequest_RemoteInferenceDisabled() throws IOException {
119120

120121
when(mlFeatureEnabledSetting.isRemoteInferenceEnabled()).thenReturn(false);
121122
RestRequest request = getRestRequest_PredictModel();
122-
MLPredictionTaskRequest mlPredictionTaskRequest = restMLPredictionAction.getRequest("modelId", FunctionName.REMOTE.name(), request);
123+
MLPredictionTaskRequest mlPredictionTaskRequest = restMLPredictionAction
124+
.getRequest("modelId", FunctionName.REMOTE.name(), "text_embedding", request);
123125
}
124126

125127
public void testPrepareRequest() throws Exception {
@@ -165,3 +167,4 @@ private RestRequest getRestRequest_PredictModel() {
165167
return request;
166168
}
167169
}
170+

0 commit comments

Comments (0)