Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use model type to check local or remote model #3597

Merged
merged 6 commits into from
Mar 3, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import java.io.IOException;
import java.util.List;
import java.util.Locale;
import java.util.Objects;
import java.util.Optional;

import org.opensearch.common.util.concurrent.ThreadContext;
Expand Down Expand Up @@ -83,27 +84,30 @@ public List<Route> routes() {

@Override
public RestChannelConsumer prepareRequest(RestRequest request, NodeClient client) throws IOException {
String algorithm = request.param(PARAMETER_ALGORITHM);
String userAlgorithm = request.param(PARAMETER_ALGORITHM);
String modelId = getParameterId(request, PARAMETER_MODEL_ID);
Optional<FunctionName> functionName = modelManager.getOptionalModelFunctionName(modelId);

if (algorithm == null && functionName.isPresent()) {
algorithm = functionName.get().name();
}

if (algorithm != null) {
MLPredictionTaskRequest mlPredictionTaskRequest = getRequest(modelId, algorithm, request);
return channel -> client
.execute(MLPredictionTaskAction.INSTANCE, mlPredictionTaskRequest, new RestToXContentListener<>(channel));
// check if the model is in cache
if (functionName.isPresent()) {
MLPredictionTaskRequest predictionRequest = getRequest(
modelId,
functionName.get().name(),
Objects.requireNonNullElse(userAlgorithm, functionName.get().name()),
request
);
return channel -> client.execute(MLPredictionTaskAction.INSTANCE, predictionRequest, new RestToXContentListener<>(channel));
}

// If the model isn't in cache
return channel -> {
ActionListener<MLModel> listener = ActionListener.wrap(mlModel -> {
String algoName = mlModel.getAlgorithm().name();
String modelType = mlModel.getAlgorithm().name();
String modelAlgorithm = Objects.requireNonNullElse(userAlgorithm, mlModel.getAlgorithm().name());
client
.execute(
MLPredictionTaskAction.INSTANCE,
getRequest(modelId, algoName, request),
getRequest(modelId, modelType, modelAlgorithm, request),
new RestToXContentListener<>(channel)
);
}, e -> {
Expand All @@ -126,18 +130,22 @@ public RestChannelConsumer prepareRequest(RestRequest request, NodeClient client
}

/**
* Creates a MLPredictionTaskRequest from a RestRequest
* Creates a MLPredictionTaskRequest from a RestRequest. This method validates the request based on
* enabled features and model types, and parses the input data for prediction.
*
* @param request RestRequest
* @return MLPredictionTaskRequest
* @param modelId The ID of the ML model to use for prediction
* @param modelType The type of the ML model, extracted from the model cache, used to specify whether it is a remote model or a local model
* @param userAlgorithm The algorithm specified by the user for prediction; this is used to determine the interface of the model
* @param request The REST request containing prediction input data
* @return MLPredictionTaskRequest configured with the model and input parameters
*/
@VisibleForTesting
MLPredictionTaskRequest getRequest(String modelId, String algorithm, RestRequest request) throws IOException {
MLPredictionTaskRequest getRequest(String modelId, String modelType, String userAlgorithm, RestRequest request) throws IOException {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's add java doc for this method. I'm afraid this could be confusing later if we don't have enough documentation about this in the code.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added

String tenantId = getTenantID(mlFeatureEnabledSetting.isMultiTenancyEnabled(), request);
ActionType actionType = ActionType.from(getActionTypeFromRestRequest(request));
if (FunctionName.REMOTE.name().equals(algorithm) && !mlFeatureEnabledSetting.isRemoteInferenceEnabled()) {
if (FunctionName.REMOTE.name().equals(modelType) && !mlFeatureEnabledSetting.isRemoteInferenceEnabled()) {
throw new IllegalStateException(REMOTE_INFERENCE_DISABLED_ERR_MSG);
} else if (FunctionName.isDLModel(FunctionName.from(algorithm.toUpperCase(Locale.ROOT)))
} else if (FunctionName.isDLModel(FunctionName.from(modelType.toUpperCase(Locale.ROOT)))
&& !mlFeatureEnabledSetting.isLocalModelEnabled()) {
throw new IllegalStateException(LOCAL_MODEL_DISABLED_ERR_MSG);
} else if (ActionType.BATCH_PREDICT == actionType && !mlFeatureEnabledSetting.isOfflineBatchInferenceEnabled()) {
Expand All @@ -148,7 +156,7 @@ MLPredictionTaskRequest getRequest(String modelId, String algorithm, RestRequest

XContentParser parser = request.contentParser();
ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser);
MLInput mlInput = MLInput.parse(parser, algorithm, actionType);
MLInput mlInput = MLInput.parse(parser, userAlgorithm, actionType);
return new MLPredictionTaskRequest(modelId, mlInput, null, tenantId);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ public class RestMLPredictionActionTests extends OpenSearchTestCase {
@Before
public void setup() {
MockitoAnnotations.openMocks(this);
when(modelManager.getOptionalModelFunctionName(anyString())).thenReturn(Optional.empty());
when(modelManager.getOptionalModelFunctionName(anyString())).thenReturn(Optional.of(FunctionName.REMOTE));
when(mlFeatureEnabledSetting.isRemoteInferenceEnabled()).thenReturn(true);
when(mlFeatureEnabledSetting.isLocalModelEnabled()).thenReturn(true);
restMLPredictionAction = new RestMLPredictionAction(modelManager, mlFeatureEnabledSetting);
Expand Down Expand Up @@ -127,7 +127,8 @@ public void testRoutes_Batch() {
@Test
public void testGetRequest() throws IOException {
RestRequest request = getRestRequest_PredictModel();
MLPredictionTaskRequest mlPredictionTaskRequest = restMLPredictionAction.getRequest("modelId", FunctionName.KMEANS.name(), request);
MLPredictionTaskRequest mlPredictionTaskRequest = restMLPredictionAction
.getRequest("modelId", FunctionName.KMEANS.name(), FunctionName.KMEANS.name(), request);

MLInput mlInput = mlPredictionTaskRequest.getMlInput();
verifyParsedKMeansMLInput(mlInput);
Expand All @@ -140,7 +141,8 @@ public void testGetRequest_RemoteInferenceDisabled() throws IOException {

when(mlFeatureEnabledSetting.isRemoteInferenceEnabled()).thenReturn(false);
RestRequest request = getRestRequest_PredictModel();
MLPredictionTaskRequest mlPredictionTaskRequest = restMLPredictionAction.getRequest("modelId", FunctionName.REMOTE.name(), request);
MLPredictionTaskRequest mlPredictionTaskRequest = restMLPredictionAction
.getRequest("modelId", FunctionName.REMOTE.name(), "text_embedding", request);
}

@Test
Expand All @@ -151,7 +153,7 @@ public void testGetRequest_LocalModelInferenceDisabled() throws IOException {
when(mlFeatureEnabledSetting.isLocalModelEnabled()).thenReturn(false);
RestRequest request = getRestRequest_PredictModel();
MLPredictionTaskRequest mlPredictionTaskRequest = restMLPredictionAction
.getRequest("modelId", FunctionName.TEXT_EMBEDDING.name(), request);
.getRequest("modelId", FunctionName.TEXT_EMBEDDING.name(), "text_embedding", request);
}

@Test
Expand Down Expand Up @@ -194,7 +196,7 @@ public void testPrepareBatchRequest_WrongActionType() throws Exception {
thrown.expectMessage("Wrong Action Type");

RestRequest request = getBatchRestRequest_WrongActionType();
restMLPredictionAction.getRequest("model id", "remote", request);
restMLPredictionAction.getRequest("model id", "remote", "text_embedding", request);
}

@Ignore
Expand Down Expand Up @@ -232,7 +234,7 @@ public void testGetRequest_InvalidActionType() throws IOException {
thrown.expectMessage("Wrong Action Type of models");

RestRequest request = getBatchRestRequest_WrongActionType();
restMLPredictionAction.getRequest("model_id", FunctionName.REMOTE.name(), request);
restMLPredictionAction.getRequest("model_id", FunctionName.REMOTE.name(), "text_embedding", request);
}

@Test
Expand All @@ -242,7 +244,7 @@ public void testGetRequest_UnsupportedAlgorithm() throws IOException {

// Create a RestRequest with an unsupported algorithm
RestRequest request = getRestRequest_PredictModel();
restMLPredictionAction.getRequest("model_id", "INVALID_ALGO", request);
restMLPredictionAction.getRequest("model_id", "INVALID_ALGO", "text_embedding", request);
}

private RestRequest getRestRequest_PredictModel() {
Expand Down
Loading