15
15
import java .io .IOException ;
16
16
import java .util .List ;
17
17
import java .util .Locale ;
18
+ import java .util .Objects ;
18
19
import java .util .Optional ;
19
20
20
21
import org .opensearch .client .node .NodeClient ;
@@ -65,42 +66,45 @@ public String getName() {
65
66
@ Override
66
67
public List <Route > routes () {
67
68
return ImmutableList
68
- .of (
69
- new Route (
70
- RestRequest .Method .POST ,
71
- String .format (Locale .ROOT , "%s/_predict/{%s}/{%s}" , ML_BASE_URI , PARAMETER_ALGORITHM , PARAMETER_MODEL_ID )
72
- ),
73
- new Route (RestRequest .Method .POST , String .format (Locale .ROOT , "%s/models/{%s}/_predict" , ML_BASE_URI , PARAMETER_MODEL_ID ))
74
- );
69
+ .of (
70
+ new Route (
71
+ RestRequest .Method .POST ,
72
+ String .format (Locale .ROOT , "%s/_predict/{%s}/{%s}" , ML_BASE_URI , PARAMETER_ALGORITHM , PARAMETER_MODEL_ID )
73
+ ),
74
+ new Route (RestRequest .Method .POST , String .format (Locale .ROOT , "%s/models/{%s}/_predict" , ML_BASE_URI , PARAMETER_MODEL_ID ))
75
+ );
75
76
}
76
77
77
78
@ Override
78
79
public RestChannelConsumer prepareRequest (RestRequest request , NodeClient client ) throws IOException {
79
- String algorithm = request .param (PARAMETER_ALGORITHM );
80
+ String userAlgorithm = request .param (PARAMETER_ALGORITHM );
80
81
String modelId = getParameterId (request , PARAMETER_MODEL_ID );
81
82
Optional <FunctionName > functionName = modelManager .getOptionalModelFunctionName (modelId );
82
83
83
- if (algorithm == null && functionName .isPresent ()) {
84
- algorithm = functionName .get ().name ();
85
- }
86
-
87
- if (algorithm != null ) {
88
- MLPredictionTaskRequest mlPredictionTaskRequest = getRequest (modelId , algorithm , request );
89
- return channel -> client
90
- .execute (MLPredictionTaskAction .INSTANCE , mlPredictionTaskRequest , new RestToXContentListener <>(channel ));
84
+ // check if the model is in cache
85
+ if (functionName .isPresent ()) {
86
+ MLPredictionTaskRequest predictionRequest = getRequest (
87
+ modelId ,
88
+ functionName .get ().name (),
89
+ Objects .requireNonNullElse (userAlgorithm , functionName .get ().name ()),
90
+ request
91
+ );
92
+ return channel -> client .execute (MLPredictionTaskAction .INSTANCE , predictionRequest , new RestToXContentListener <>(channel ));
91
93
}
92
94
95
+ // If the model isn't in cache
93
96
return channel -> {
94
97
MLModelGetRequest getModelRequest = new MLModelGetRequest (modelId , false );
95
98
ActionListener <MLModelGetResponse > listener = ActionListener .wrap (r -> {
96
99
MLModel mlModel = r .getMlModel ();
97
- String algoName = mlModel .getAlgorithm ().name ();
100
+ String modelType = mlModel .getAlgorithm ().name ();
101
+ String modelAlgorithm = Objects .requireNonNullElse (userAlgorithm , mlModel .getAlgorithm ().name ());
98
102
client
99
- .execute (
100
- MLPredictionTaskAction .INSTANCE ,
101
- getRequest (modelId , algoName , request ),
102
- new RestToXContentListener <>(channel )
103
- );
103
+ .execute (
104
+ MLPredictionTaskAction .INSTANCE ,
105
+ getRequest (modelId , modelType , modelAlgorithm , request ),
106
+ new RestToXContentListener <>(channel )
107
+ );
104
108
}, e -> {
105
109
log .error ("Failed to get ML model" , e );
106
110
try {
@@ -115,20 +119,25 @@ public RestChannelConsumer prepareRequest(RestRequest request, NodeClient client
115
119
}
116
120
117
121
/**
118
- * Creates a MLPredictionTaskRequest from a RestRequest
122
+ * Creates a MLPredictionTaskRequest from a RestRequest. This method validates the request based on
123
+ * enabled features and model types, and parses the input data for prediction.
119
124
*
120
- * @param request RestRequest
121
- * @return MLPredictionTaskRequest
125
+ * @param modelId The ID of the ML model to use for prediction
126
+ * @param modelType The type of the ML model, extracted from model cache to specify if its a remote model or a local model
127
+ * @param userAlgorithm The algorithm specified by the user for prediction, this is used todetermine the interface of the model
128
+ * @param request The REST request containing prediction input data
129
+ * @return MLPredictionTaskRequest configured with the model and input parameters
122
130
*/
123
131
@ VisibleForTesting
124
- MLPredictionTaskRequest getRequest (String modelId , String algorithm , RestRequest request ) throws IOException {
125
- if (FunctionName .REMOTE .name ().equals (algorithm ) && !mlFeatureEnabledSetting .isRemoteInferenceEnabled ()) {
132
+ MLPredictionTaskRequest getRequest (String modelId , String modelType , String userAlgorithm , RestRequest request ) throws IOException {
133
+ if (FunctionName .REMOTE .name ().equals (modelType ) && !mlFeatureEnabledSetting .isRemoteInferenceEnabled ()) {
126
134
throw new IllegalStateException (REMOTE_INFERENCE_DISABLED_ERR_MSG );
127
135
}
128
136
XContentParser parser = request .contentParser ();
129
137
ensureExpectedToken (XContentParser .Token .START_OBJECT , parser .nextToken (), parser );
130
- MLInput mlInput = MLInput .parse (parser , algorithm );
138
+ MLInput mlInput = MLInput .parse (parser , userAlgorithm );
131
139
return new MLPredictionTaskRequest (modelId , mlInput , null );
132
140
}
133
141
134
142
}
143
+
0 commit comments