19
19
import java .io .IOException ;
20
20
import java .util .List ;
21
21
import java .util .Locale ;
22
+ import java .util .Objects ;
22
23
import java .util .Optional ;
23
24
24
25
import org .opensearch .common .util .concurrent .ThreadContext ;
@@ -83,27 +84,30 @@ public List<Route> routes() {
83
84
84
85
@ Override
85
86
public RestChannelConsumer prepareRequest (RestRequest request , NodeClient client ) throws IOException {
86
- String algorithm = request .param (PARAMETER_ALGORITHM );
87
+ String userAlgorithm = request .param (PARAMETER_ALGORITHM );
87
88
String modelId = getParameterId (request , PARAMETER_MODEL_ID );
88
89
Optional <FunctionName > functionName = modelManager .getOptionalModelFunctionName (modelId );
89
90
90
- if (algorithm == null && functionName .isPresent ()) {
91
- algorithm = functionName .get ().name ();
92
- }
93
-
94
- if (algorithm != null ) {
95
- MLPredictionTaskRequest mlPredictionTaskRequest = getRequest (modelId , algorithm , request );
96
- return channel -> client
97
- .execute (MLPredictionTaskAction .INSTANCE , mlPredictionTaskRequest , new RestToXContentListener <>(channel ));
91
+ // check if the model is in cache
92
+ if (functionName .isPresent ()) {
93
+ MLPredictionTaskRequest predictionRequest = getRequest (
94
+ modelId ,
95
+ functionName .get ().name (),
96
+ Objects .requireNonNullElse (userAlgorithm , functionName .get ().name ()),
97
+ request
98
+ );
99
+ return channel -> client .execute (MLPredictionTaskAction .INSTANCE , predictionRequest , new RestToXContentListener <>(channel ));
98
100
}
99
101
102
+ // If the model isn't in cache
100
103
return channel -> {
101
104
ActionListener <MLModel > listener = ActionListener .wrap (mlModel -> {
102
- String algoName = mlModel .getAlgorithm ().name ();
105
+ String modelType = mlModel .getAlgorithm ().name ();
106
+ String modelAlgorithm = Objects .requireNonNullElse (userAlgorithm , mlModel .getAlgorithm ().name ());
103
107
client
104
108
.execute (
105
109
MLPredictionTaskAction .INSTANCE ,
106
- getRequest (modelId , algoName , request ),
110
+ getRequest (modelId , modelType , modelAlgorithm , request ),
107
111
new RestToXContentListener <>(channel )
108
112
);
109
113
}, e -> {
@@ -126,18 +130,22 @@ public RestChannelConsumer prepareRequest(RestRequest request, NodeClient client
126
130
}
127
131
128
132
/**
129
- * Creates a MLPredictionTaskRequest from a RestRequest
133
+ * Creates a MLPredictionTaskRequest from a RestRequest. This method validates the request based on
134
+ * enabled features and model types, and parses the input data for prediction.
130
135
*
131
- * @param request RestRequest
132
- * @return MLPredictionTaskRequest
136
+ * @param modelId The ID of the ML model to use for prediction
137
+ * @param modelType The type of the ML model , extracted from the model cache to specify whether it is a remote model or a local model
138
+ * @param userAlgorithm The algorithm specified by the user for prediction ; this is used to determine the interface of the model
139
+ * @param request The REST request containing prediction input data
140
+ * @return MLPredictionTaskRequest configured with the model and input parameters
133
141
*/
134
142
@ VisibleForTesting
135
- MLPredictionTaskRequest getRequest (String modelId , String algorithm , RestRequest request ) throws IOException {
143
+ MLPredictionTaskRequest getRequest (String modelId , String modelType , String userAlgorithm , RestRequest request ) throws IOException {
136
144
String tenantId = getTenantID (mlFeatureEnabledSetting .isMultiTenancyEnabled (), request );
137
145
ActionType actionType = ActionType .from (getActionTypeFromRestRequest (request ));
138
- if (FunctionName .REMOTE .name ().equals (algorithm ) && !mlFeatureEnabledSetting .isRemoteInferenceEnabled ()) {
146
+ if (FunctionName .REMOTE .name ().equals (modelType ) && !mlFeatureEnabledSetting .isRemoteInferenceEnabled ()) {
139
147
throw new IllegalStateException (REMOTE_INFERENCE_DISABLED_ERR_MSG );
140
- } else if (FunctionName .isDLModel (FunctionName .from (algorithm .toUpperCase (Locale .ROOT )))
148
+ } else if (FunctionName .isDLModel (FunctionName .from (modelType .toUpperCase (Locale .ROOT )))
141
149
&& !mlFeatureEnabledSetting .isLocalModelEnabled ()) {
142
150
throw new IllegalStateException (LOCAL_MODEL_DISABLED_ERR_MSG );
143
151
} else if (ActionType .BATCH_PREDICT == actionType && !mlFeatureEnabledSetting .isOfflineBatchInferenceEnabled ()) {
@@ -148,7 +156,7 @@ MLPredictionTaskRequest getRequest(String modelId, String algorithm, RestRequest
148
156
149
157
XContentParser parser = request .contentParser ();
150
158
ensureExpectedToken (XContentParser .Token .START_OBJECT , parser .nextToken (), parser );
151
- MLInput mlInput = MLInput .parse (parser , algorithm , actionType );
159
+ MLInput mlInput = MLInput .parse (parser , userAlgorithm , actionType );
152
160
return new MLPredictionTaskRequest (modelId , mlInput , null , tenantId );
153
161
}
154
162
0 commit comments