@@ -4,8 +4,9 @@
 """
 
-from typing import Any, Dict, List, Optional
+from typing import Any, Dict, List, Literal, Optional, TypeAlias, Union
 
 from model_engine_server.common.dtos.core import HttpUrlStr
+from model_engine_server.common.dtos.llms.sglang import SGLangEndpointAdditionalArgs
 from model_engine_server.common.dtos.llms.vllm import VLLMEndpointAdditionalArgs
 from model_engine_server.common.dtos.model_endpoints import (
     CpuSpecificationType,
@@ -25,21 +26,11 @@
     ModelEndpointStatus,
     Quantization,
 )
+from pydantic import Discriminator, Tag
+from typing_extensions import Annotated
 
 
-class CreateLLMModelEndpointV1Request(VLLMEndpointAdditionalArgs, BaseModel):
-    name: str
-
-    # LLM specific fields
-    model_name: str
-    source: LLMSource = LLMSource.HUGGING_FACE
-    inference_framework: LLMInferenceFramework = LLMInferenceFramework.VLLM
-    inference_framework_image_tag: str = "latest"
-    num_shards: int = 1
-    """
-    Number of shards to distribute the model onto GPUs.
-    """
-
+class LLMModelEndpointCommonArgs(BaseModel):
     quantize: Optional[Quantization] = None
     """
     Whether to quantize the model.
@@ -51,20 +42,14 @@ class CreateLLMModelEndpointV1Request(VLLMEndpointAdditionalArgs, BaseModel):
     """
 
     # General endpoint fields
-    metadata: Dict[str, Any]  # TODO: JSON type
     post_inference_hooks: Optional[List[str]] = None
-    endpoint_type: ModelEndpointType = ModelEndpointType.SYNC
     cpus: Optional[CpuSpecificationType] = None
     gpus: Optional[int] = None
     memory: Optional[StorageSpecificationType] = None
     gpu_type: Optional[GpuType] = None
     storage: Optional[StorageSpecificationType] = None
     nodes_per_worker: Optional[int] = None
     optimize_costs: Optional[bool] = None
-    min_workers: int
-    max_workers: int
-    per_worker: int
-    labels: Dict[str, str]
     prewarm: Optional[bool] = None
     high_priority: Optional[bool] = None
     billing_tags: Optional[Dict[str, Any]] = None
@@ -77,6 +62,77 @@ class CreateLLMModelEndpointV1Request(VLLMEndpointAdditionalArgs, BaseModel):
     )
 
 
+class CreateLLMModelEndpointArgs(LLMModelEndpointCommonArgs):
+    name: str
+    model_name: str
+    metadata: Dict[str, Any]  # TODO: JSON type
+    min_workers: int
+    max_workers: int
+    per_worker: int
+    labels: Dict[str, str]
+    source: LLMSource = LLMSource.HUGGING_FACE
+    inference_framework_image_tag: str = "latest"
+    num_shards: int = 1
+    """
+    Number of shards to distribute the model onto GPUs.
+    """
+    endpoint_type: ModelEndpointType = ModelEndpointType.SYNC
+
+
+class CreateVLLMModelEndpointRequest(
+    VLLMEndpointAdditionalArgs, CreateLLMModelEndpointArgs, BaseModel
+):
+    inference_framework: Literal[LLMInferenceFramework.VLLM] = LLMInferenceFramework.VLLM
+
+
+class CreateSGLangModelEndpointRequest(
+    SGLangEndpointAdditionalArgs, CreateLLMModelEndpointArgs, BaseModel
+):
+    inference_framework: Literal[LLMInferenceFramework.SGLANG] = LLMInferenceFramework.SGLANG
+
+
+class CreateDeepSpeedModelEndpointRequest(CreateLLMModelEndpointArgs, BaseModel):
+    inference_framework: Literal[LLMInferenceFramework.DEEPSPEED] = LLMInferenceFramework.DEEPSPEED
+
+
+class CreateTextGenerationInferenceModelEndpointRequest(CreateLLMModelEndpointArgs, BaseModel):
+    inference_framework: Literal[LLMInferenceFramework.TEXT_GENERATION_INFERENCE] = (
+        LLMInferenceFramework.TEXT_GENERATION_INFERENCE
+    )
+
+
+class CreateLightLLMModelEndpointRequest(CreateLLMModelEndpointArgs, BaseModel):
+    inference_framework: Literal[LLMInferenceFramework.LIGHTLLM] = LLMInferenceFramework.LIGHTLLM
+
+
+class CreateTensorRTLLMModelEndpointRequest(CreateLLMModelEndpointArgs, BaseModel):
+    inference_framework: Literal[LLMInferenceFramework.TENSORRT_LLM] = (
+        LLMInferenceFramework.TENSORRT_LLM
+    )
+
+
+def get_inference_framework(v: Any) -> str:
+    if isinstance(v, dict):
+        return v.get("inference_framework", LLMInferenceFramework.VLLM)
+    return getattr(v, "inference_framework", LLMInferenceFramework.VLLM)
+
+
+CreateLLMModelEndpointV1Request: TypeAlias = Annotated[
+    Union[
+        Annotated[CreateVLLMModelEndpointRequest, Tag(LLMInferenceFramework.VLLM)],
+        Annotated[CreateSGLangModelEndpointRequest, Tag(LLMInferenceFramework.SGLANG)],
+        Annotated[CreateDeepSpeedModelEndpointRequest, Tag(LLMInferenceFramework.DEEPSPEED)],
+        Annotated[
+            CreateTextGenerationInferenceModelEndpointRequest,
+            Tag(LLMInferenceFramework.TEXT_GENERATION_INFERENCE),
+        ],
+        Annotated[CreateLightLLMModelEndpointRequest, Tag(LLMInferenceFramework.LIGHTLLM)],
+        Annotated[CreateTensorRTLLMModelEndpointRequest, Tag(LLMInferenceFramework.TENSORRT_LLM)],
+    ],
+    Discriminator(get_inference_framework),
+]
+
+
 
 class CreateLLMModelEndpointV1Response(BaseModel):
     endpoint_creation_task_id: str
@@ -107,57 +169,75 @@ class ListLLMModelEndpointsV1Response(BaseModel):
     model_endpoints: List[GetLLMModelEndpointV1Response]
 
 
-class UpdateLLMModelEndpointV1Request(VLLMEndpointAdditionalArgs, BaseModel):
-    # LLM specific fields
+class UpdateLLMModelEndpointArgs(LLMModelEndpointCommonArgs):
     model_name: Optional[str] = None
     source: Optional[LLMSource] = None
+    inference_framework: Optional[LLMInferenceFramework] = None
     inference_framework_image_tag: Optional[str] = None
     num_shards: Optional[int] = None
     """
     Number of shards to distribute the model onto GPUs.
     """
-
-    quantize: Optional[Quantization] = None
-    """
-    Whether to quantize the model.
+    metadata: Optional[Dict[str, Any]] = None
+    force_bundle_recreation: Optional[bool] = False
     """
+    Whether to force recreate the underlying bundle.
 
-    checkpoint_path: Optional[str] = None
-    """
-    Path to the checkpoint to load the model from.
+    If True, the underlying bundle will be recreated. This is useful if there are underlying
+    implementation changes with how bundles are created that we would like to pick up for existing endpoints.
     """
-
-    # General endpoint fields
-    metadata: Optional[Dict[str, Any]] = None
-    post_inference_hooks: Optional[List[str]] = None
-    cpus: Optional[CpuSpecificationType] = None
-    gpus: Optional[int] = None
-    memory: Optional[StorageSpecificationType] = None
-    gpu_type: Optional[GpuType] = None
-    storage: Optional[StorageSpecificationType] = None
-    optimize_costs: Optional[bool] = None
     min_workers: Optional[int] = None
     max_workers: Optional[int] = None
     per_worker: Optional[int] = None
     labels: Optional[Dict[str, str]] = None
-    prewarm: Optional[bool] = None
-    high_priority: Optional[bool] = None
-    billing_tags: Optional[Dict[str, Any]] = None
-    default_callback_url: Optional[HttpUrlStr] = None
-    default_callback_auth: Optional[CallbackAuth] = None
-    public_inference: Optional[bool] = None
-    chat_template_override: Optional[str] = Field(
-        default=None,
-        description="A Jinja template to use for this endpoint. If not provided, will use the chat template from the checkpoint",
+
+
+class UpdateVLLMModelEndpointRequest(
+    VLLMEndpointAdditionalArgs, UpdateLLMModelEndpointArgs, BaseModel
+):
+    inference_framework: Literal[LLMInferenceFramework.VLLM] = LLMInferenceFramework.VLLM
+
+
+class UpdateSGLangModelEndpointRequest(
+    SGLangEndpointAdditionalArgs, UpdateLLMModelEndpointArgs, BaseModel
+):
+    inference_framework: Literal[LLMInferenceFramework.SGLANG] = LLMInferenceFramework.SGLANG
+
+
+class UpdateDeepSpeedModelEndpointRequest(UpdateLLMModelEndpointArgs, BaseModel):
+    inference_framework: Literal[LLMInferenceFramework.DEEPSPEED] = LLMInferenceFramework.DEEPSPEED
+
+
+class UpdateTextGenerationInferenceModelEndpointRequest(UpdateLLMModelEndpointArgs, BaseModel):
+    inference_framework: Literal[LLMInferenceFramework.TEXT_GENERATION_INFERENCE] = (
+        LLMInferenceFramework.TEXT_GENERATION_INFERENCE
    )
 
-    force_bundle_recreation: Optional[bool] = False
-    """
-    Whether to force recreate the underlying bundle.
 
-    If True, the underlying bundle will be recreated. This is useful if there are underlying implementation changes with how bundles are created
-    that we would like to pick up for existing endpoints
-    """
+class UpdateLightLLMModelEndpointRequest(UpdateLLMModelEndpointArgs, BaseModel):
+    inference_framework: Literal[LLMInferenceFramework.LIGHTLLM] = LLMInferenceFramework.LIGHTLLM
+
+
+class UpdateTensorRTLLMModelEndpointRequest(UpdateLLMModelEndpointArgs, BaseModel):
+    inference_framework: Literal[LLMInferenceFramework.TENSORRT_LLM] = (
+        LLMInferenceFramework.TENSORRT_LLM
+    )
+
+
+UpdateLLMModelEndpointV1Request: TypeAlias = Annotated[
+    Union[
+        Annotated[UpdateVLLMModelEndpointRequest, Tag(LLMInferenceFramework.VLLM)],
+        Annotated[UpdateSGLangModelEndpointRequest, Tag(LLMInferenceFramework.SGLANG)],
+        Annotated[UpdateDeepSpeedModelEndpointRequest, Tag(LLMInferenceFramework.DEEPSPEED)],
+        Annotated[
+            UpdateTextGenerationInferenceModelEndpointRequest,
+            Tag(LLMInferenceFramework.TEXT_GENERATION_INFERENCE),
+        ],
+        Annotated[UpdateLightLLMModelEndpointRequest, Tag(LLMInferenceFramework.LIGHTLLM)],
+        Annotated[UpdateTensorRTLLMModelEndpointRequest, Tag(LLMInferenceFramework.TENSORRT_LLM)],
+    ],
+    Discriminator(get_inference_framework),
+]
 
 
 class UpdateLLMModelEndpointV1Response(BaseModel):
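
This commit replaces the monolithic create/update request models with tagged unions dispatched by a callable discriminator. As a sanity check, here is a minimal sketch of how validation routes a payload, assuming pydantic v2's `TypeAdapter`, that `LLMInferenceFramework` is a string enum whose values match the tags (e.g. `"sglang"`), and that the framework-specific mixin fields are all optional; the payload values are illustrative:

```python
from pydantic import TypeAdapter

# Names below come from the module diffed above.
adapter = TypeAdapter(CreateLLMModelEndpointV1Request)
request = adapter.validate_python(
    {
        "name": "my-endpoint",            # illustrative values throughout
        "model_name": "llama-3-8b",
        "metadata": {},
        "min_workers": 1,
        "max_workers": 2,
        "per_worker": 1,
        "labels": {"team": "infra"},
        "inference_framework": "sglang",  # get_inference_framework reads this key
    }
)
# The discriminator returned "sglang", so pydantic validated only against the
# SGLang variant rather than trying every union member in turn.
assert isinstance(request, CreateSGLangModelEndpointRequest)
```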
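Because `get_inference_framework` falls back to `LLMInferenceFramework.VLLM`, payloads that predate this commit and omit `inference_framework` should still validate as the vLLM variant. Continuing the sketch above, under the same assumptions:

```python
legacy = adapter.validate_python(
    {
        "name": "my-endpoint",
        "model_name": "llama-3-8b",
        "metadata": {},
        "min_workers": 1,
        "max_workers": 2,
        "per_worker": 1,
        "labels": {"team": "infra"},
        # no "inference_framework": the discriminator defaults to VLLM
    }
)
assert isinstance(legacy, CreateVLLMModelEndpointRequest)
```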