Commit 9b5f5ca

SGLang support (#681)
* SGLang support
* fix
* fix
* fix
* fix
* fix
* fix
* workaround for s5cmd
* unblock by downloading from huggingface directly
* black format
* fix import
* Fix dto + tests
* skip coverage
* fix tests
* Update pytorch version for integration
* Revert "Update pytorch version for integration"
  This reverts commit 01bb400.
* revert black formatting
1 parent 17bb585 commit 9b5f5ca

15 files changed: +757 -270 lines

.black.toml

+1 -1

@@ -1,7 +1,7 @@
 [tool.black]
 # Independently enforced in .flake8 and .isort.cfg
 line-length = 100
-target-version = ['py38']
+target-version = ['py310']
 include = '\.pyi?$'
 exclude = '''
 (

model-engine/model_engine_server/api/llms_v1.py

+5 -2

@@ -87,6 +87,7 @@
     UpdateLLMModelEndpointV1UseCase,
 )
 from model_engine_server.domain.use_cases.model_bundle_use_cases import CreateModelBundleV2UseCase
+from pydantic import RootModel
 from sse_starlette.sse import EventSourceResponse
 
 
@@ -147,10 +148,11 @@ def handle_streaming_exception(
 
 @llm_router_v1.post("/model-endpoints", response_model=CreateLLMModelEndpointV1Response)
 async def create_model_endpoint(
-    request: CreateLLMModelEndpointV1Request,
+    wrapped_request: RootModel[CreateLLMModelEndpointV1Request],
     auth: User = Depends(verify_authentication),
     external_interfaces: ExternalInterfaces = Depends(get_external_interfaces),
 ) -> CreateLLMModelEndpointV1Response:
+    request = wrapped_request.root
     """
     Creates an LLM endpoint for the current user.
     """
@@ -261,13 +263,14 @@ async def get_model_endpoint(
 )
 async def update_model_endpoint(
     model_endpoint_name: str,
-    request: UpdateLLMModelEndpointV1Request,
+    wrapped_request: RootModel[UpdateLLMModelEndpointV1Request],
    auth: User = Depends(verify_authentication),
     external_interfaces: ExternalInterfaces = Depends(get_external_interfaces),
 ) -> UpdateLLMModelEndpointV1Response:
     """
     Updates an LLM endpoint for the current user.
     """
+    request = wrapped_request.root
     logger.info(f"PUT /llm/model-endpoints/{model_endpoint_name} with {request} for {auth}")
     try:
         create_model_bundle_use_case = CreateModelBundleV2UseCase(
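Why the RootModel wrapper: this commit turns CreateLLMModelEndpointV1Request and UpdateLLMModelEndpointV1Request into tagged unions (see the DTO changes below), and a union type alias is not a BaseModel subclass, so FastAPI cannot use it directly as a request-body annotation. Wrapping it in pydantic's RootModel restores a concrete model class, and the handler unwraps it with .root. A minimal, self-contained sketch of the mechanism — the model names here are illustrative, not the repo's:

# Sketch only: VLLMCreate/SGLangCreate stand in for the real request DTOs.
from typing import Literal, Union

from pydantic import BaseModel, Field, RootModel
from typing_extensions import Annotated

class VLLMCreate(BaseModel):
    inference_framework: Literal["vllm"] = "vllm"
    name: str

class SGLangCreate(BaseModel):
    inference_framework: Literal["sglang"] = "sglang"
    name: str

# Discriminated union, analogous to CreateLLMModelEndpointV1Request.
AnyCreate = Annotated[
    Union[VLLMCreate, SGLangCreate],
    Field(discriminator="inference_framework"),
]

# RootModel[AnyCreate] is a real pydantic model, usable as a FastAPI body param.
wrapped = RootModel[AnyCreate].model_validate(
    {"inference_framework": "sglang", "name": "my-endpoint"}
)
request = wrapped.root  # unwrap, exactly as the route handlers above do
assert isinstance(request, SGLangCreate)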

model-engine/model_engine_server/common/config.py

+1 -0

@@ -76,6 +76,7 @@ class HostedModelInferenceServiceConfig:
     cache_redis_aws_secret_name: Optional[str] = (
         None  # Not an env var because the redis cache info is already here
     )
+    sglang_repository: Optional[str] = None
 
     @classmethod
     def from_json(cls, json):
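Because sglang_repository defaults to None, existing service-config files keep parsing unchanged; the field only matters once an SGLang image repository is configured. A hypothetical, self-contained sketch of that pattern — this is not the repo's HostedModelInferenceServiceConfig, and the field values are made up:

# Sketch of an optional config field threaded through a from_json-style loader.
from dataclasses import dataclass, fields
from typing import Optional

@dataclass
class ServiceConfigSketch:
    cache_redis_aws_secret_name: Optional[str] = None
    sglang_repository: Optional[str] = None  # Docker repo for SGLang images

    @classmethod
    def from_json(cls, json: dict) -> "ServiceConfigSketch":
        # Missing keys fall back to the None defaults, so configs written
        # before the field existed remain valid.
        known = {f.name for f in fields(cls)}
        return cls(**{k: v for k, v in json.items() if k in known})

cfg = ServiceConfigSketch.from_json({"sglang_repository": "myorg/sglang"})
assert cfg.sglang_repository == "myorg/sglang"
assert ServiceConfigSketch.from_json({}).sglang_repository is None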

model-engine/model_engine_server/common/dtos/llms/model_endpoints.py

+132 -54

@@ -3,9 +3,10 @@
 
 """
 
-from typing import Any, Dict, List, Optional
+from typing import Any, Dict, List, Literal, Optional, TypeAlias, Union
 
 from model_engine_server.common.dtos.core import HttpUrlStr
+from model_engine_server.common.dtos.llms.sglang import SGLangEndpointAdditionalArgs
 from model_engine_server.common.dtos.llms.vllm import VLLMEndpointAdditionalArgs
 from model_engine_server.common.dtos.model_endpoints import (
     CpuSpecificationType,
@@ -25,21 +26,11 @@
     ModelEndpointStatus,
     Quantization,
 )
+from pydantic import Discriminator, Tag
+from typing_extensions import Annotated
 
 
-class CreateLLMModelEndpointV1Request(VLLMEndpointAdditionalArgs, BaseModel):
-    name: str
-
-    # LLM specific fields
-    model_name: str
-    source: LLMSource = LLMSource.HUGGING_FACE
-    inference_framework: LLMInferenceFramework = LLMInferenceFramework.VLLM
-    inference_framework_image_tag: str = "latest"
-    num_shards: int = 1
-    """
-    Number of shards to distribute the model onto GPUs.
-    """
-
+class LLMModelEndpointCommonArgs(BaseModel):
     quantize: Optional[Quantization] = None
     """
     Whether to quantize the model.
@@ -51,20 +42,14 @@ class CreateLLMModelEndpointV1Request(VLLMEndpointAdditionalArgs, BaseModel):
     """
 
     # General endpoint fields
-    metadata: Dict[str, Any]  # TODO: JSON type
     post_inference_hooks: Optional[List[str]] = None
-    endpoint_type: ModelEndpointType = ModelEndpointType.SYNC
     cpus: Optional[CpuSpecificationType] = None
     gpus: Optional[int] = None
     memory: Optional[StorageSpecificationType] = None
     gpu_type: Optional[GpuType] = None
     storage: Optional[StorageSpecificationType] = None
     nodes_per_worker: Optional[int] = None
     optimize_costs: Optional[bool] = None
-    min_workers: int
-    max_workers: int
-    per_worker: int
-    labels: Dict[str, str]
     prewarm: Optional[bool] = None
     high_priority: Optional[bool] = None
     billing_tags: Optional[Dict[str, Any]] = None
@@ -77,6 +62,83 @@ class CreateLLMModelEndpointV1Request(VLLMEndpointAdditionalArgs, BaseModel):
     )
 
 
+class CreateLLMModelEndpointArgs(LLMModelEndpointCommonArgs):
+    name: str
+    model_name: str
+    """
+    Number of shards to distribute the model onto GPUs.
+    """
+    metadata: Dict[str, Any]  # TODO: JSON type
+    min_workers: int
+    max_workers: int
+    per_worker: int
+    labels: Dict[str, str]
+    source: LLMSource = LLMSource.HUGGING_FACE
+    inference_framework_image_tag: str = "latest"
+    num_shards: int = 1
+    endpoint_type: ModelEndpointType = ModelEndpointType.SYNC
+
+
+class CreateVLLMModelEndpointRequest(
+    VLLMEndpointAdditionalArgs, CreateLLMModelEndpointArgs, BaseModel
+):
+    inference_framework: Literal[LLMInferenceFramework.VLLM] = LLMInferenceFramework.VLLM
+    pass
+
+
+class CreateSGLangModelEndpointRequest(
+    SGLangEndpointAdditionalArgs, CreateLLMModelEndpointArgs, BaseModel
+):
+    inference_framework: Literal[LLMInferenceFramework.SGLANG] = LLMInferenceFramework.SGLANG
+    pass
+
+
+class CreateDeepSpeedModelEndpointRequest(CreateLLMModelEndpointArgs, BaseModel):
+    inference_framework: Literal[LLMInferenceFramework.DEEPSPEED] = LLMInferenceFramework.DEEPSPEED
+    pass
+
+
+class CreateTextGenerationInferenceModelEndpointRequest(CreateLLMModelEndpointArgs, BaseModel):
+    inference_framework: Literal[LLMInferenceFramework.TEXT_GENERATION_INFERENCE] = (
+        LLMInferenceFramework.TEXT_GENERATION_INFERENCE
+    )
+    pass
+
+
+class CreateLightLLMModelEndpointRequest(CreateLLMModelEndpointArgs, BaseModel):
+    inference_framework: Literal[LLMInferenceFramework.LIGHTLLM] = LLMInferenceFramework.LIGHTLLM
+    pass
+
+
+class CreateTensorRTLLMModelEndpointRequest(CreateLLMModelEndpointArgs, BaseModel):
+    inference_framework: Literal[LLMInferenceFramework.TENSORRT_LLM] = (
+        LLMInferenceFramework.TENSORRT_LLM
+    )
+    pass
+
+
+def get_inference_framework(v: Any) -> str:
+    if isinstance(v, dict):
+        return v.get("inference_framework", LLMInferenceFramework.VLLM)
+    return getattr(v, "inference_framework", LLMInferenceFramework.VLLM)
+
+
+CreateLLMModelEndpointV1Request: TypeAlias = Annotated[
+    Union[
+        Annotated[CreateVLLMModelEndpointRequest, Tag(LLMInferenceFramework.VLLM)],
+        Annotated[CreateSGLangModelEndpointRequest, Tag(LLMInferenceFramework.SGLANG)],
+        Annotated[CreateDeepSpeedModelEndpointRequest, Tag(LLMInferenceFramework.DEEPSPEED)],
+        Annotated[
+            CreateTextGenerationInferenceModelEndpointRequest,
+            Tag(LLMInferenceFramework.TEXT_GENERATION_INFERENCE),
+        ],
+        Annotated[CreateLightLLMModelEndpointRequest, Tag(LLMInferenceFramework.LIGHTLLM)],
+        Annotated[CreateTensorRTLLMModelEndpointRequest, Tag(LLMInferenceFramework.TENSORRT_LLM)],
+    ],
+    Discriminator(get_inference_framework),
+]
+
+
 class CreateLLMModelEndpointV1Response(BaseModel):
     endpoint_creation_task_id: str
 
@@ -107,57 +169,73 @@ class ListLLMModelEndpointsV1Response(BaseModel):
     model_endpoints: List[GetLLMModelEndpointV1Response]
 
 
-class UpdateLLMModelEndpointV1Request(VLLMEndpointAdditionalArgs, BaseModel):
-    # LLM specific fields
+class UpdateLLMModelEndpointArgs(LLMModelEndpointCommonArgs):
     model_name: Optional[str] = None
     source: Optional[LLMSource] = None
+    inference_framework: Optional[LLMInferenceFramework] = None
     inference_framework_image_tag: Optional[str] = None
     num_shards: Optional[int] = None
     """
     Number of shards to distribute the model onto GPUs.
     """
-
-    quantize: Optional[Quantization] = None
-    """
-    Whether to quantize the model.
+    metadata: Optional[Dict[str, Any]] = None
+    force_bundle_recreation: Optional[bool] = False
     """
+    Whether to force recreate the underlying bundle.
 
-    checkpoint_path: Optional[str] = None
-    """
-    Path to the checkpoint to load the model from.
+    If True, the underlying bundle will be recreated. This is useful if there are underlying implementation changes with how bundles are created
+    that we would like to pick up for existing endpoints
     """
-
-    # General endpoint fields
-    metadata: Optional[Dict[str, Any]] = None
-    post_inference_hooks: Optional[List[str]] = None
-    cpus: Optional[CpuSpecificationType] = None
-    gpus: Optional[int] = None
-    memory: Optional[StorageSpecificationType] = None
-    gpu_type: Optional[GpuType] = None
-    storage: Optional[StorageSpecificationType] = None
-    optimize_costs: Optional[bool] = None
     min_workers: Optional[int] = None
     max_workers: Optional[int] = None
     per_worker: Optional[int] = None
     labels: Optional[Dict[str, str]] = None
-    prewarm: Optional[bool] = None
-    high_priority: Optional[bool] = None
-    billing_tags: Optional[Dict[str, Any]] = None
-    default_callback_url: Optional[HttpUrlStr] = None
-    default_callback_auth: Optional[CallbackAuth] = None
-    public_inference: Optional[bool] = None
-    chat_template_override: Optional[str] = Field(
-        default=None,
-        description="A Jinja template to use for this endpoint. If not provided, will use the chat template from the checkpoint",
+
+
+class UpdateVLLMModelEndpointRequest(
+    VLLMEndpointAdditionalArgs, UpdateLLMModelEndpointArgs, BaseModel
+):
+    inference_framework: Literal[LLMInferenceFramework.VLLM] = LLMInferenceFramework.VLLM
+
+
+class UpdateSGLangModelEndpointRequest(
+    SGLangEndpointAdditionalArgs, UpdateLLMModelEndpointArgs, BaseModel
+):
+    inference_framework: Literal[LLMInferenceFramework.SGLANG] = LLMInferenceFramework.SGLANG
+
+
+class UpdateDeepSpeedModelEndpointRequest(UpdateLLMModelEndpointArgs, BaseModel):
+    inference_framework: Literal[LLMInferenceFramework.DEEPSPEED] = LLMInferenceFramework.DEEPSPEED
+
+
+class UpdateTextGenerationInferenceModelEndpointRequest(UpdateLLMModelEndpointArgs, BaseModel):
+    inference_framework: Literal[LLMInferenceFramework.TEXT_GENERATION_INFERENCE] = (
+        LLMInferenceFramework.TEXT_GENERATION_INFERENCE
     )
 
-    force_bundle_recreation: Optional[bool] = False
-    """
-    Whether to force recreate the underlying bundle.
 
-    If True, the underlying bundle will be recreated. This is useful if there are underlying implementation changes with how bundles are created
-    that we would like to pick up for existing endpoints
-    """
+class UpdateLightLLMModelEndpointRequest(UpdateLLMModelEndpointArgs, BaseModel):
+    inference_framework: Literal[LLMInferenceFramework.LIGHTLLM] = LLMInferenceFramework.LIGHTLLM
+
+
+class UpdateTensorRTLLMModelEndpointRequest(UpdateLLMModelEndpointArgs, BaseModel):
+    inference_framework: Literal[LLMInferenceFramework.TENSORRT_LLM] = (
+        LLMInferenceFramework.TENSORRT_LLM
+    )
+
+
+UpdateLLMModelEndpointV1Request: TypeAlias = Annotated[
+    Union[
+        Annotated[UpdateVLLMModelEndpointRequest, Tag(LLMInferenceFramework.VLLM)],
+        Annotated[UpdateSGLangModelEndpointRequest, Tag(LLMInferenceFramework.SGLANG)],
+        Annotated[UpdateDeepSpeedModelEndpointRequest, Tag(LLMInferenceFramework.DEEPSPEED)],
+        Annotated[
+            UpdateTextGenerationInferenceModelEndpointRequest,
+            Tag(LLMInferenceFramework.TEXT_GENERATION_INFERENCE),
+        ],
+    ],
+    Discriminator(get_inference_framework),
+]
 
 
 class UpdateLLMModelEndpointV1Response(BaseModel):
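How the Tag/Discriminator machinery above behaves: Discriminator calls get_inference_framework on the raw input, the returned tag selects which request model validates the payload, and a missing inference_framework field falls back to vLLM. A self-contained sketch of the same callable-discriminator pattern, with illustrative string tags standing in for the LLMInferenceFramework enum:

# Sketch only: "vllm"/"sglang" strings replace the repo's enum values.
from typing import Any, Literal, Union

from pydantic import BaseModel, Discriminator, Tag, TypeAdapter
from typing_extensions import Annotated

class VLLMCreate(BaseModel):
    inference_framework: Literal["vllm"] = "vllm"
    name: str

class SGLangCreate(BaseModel):
    inference_framework: Literal["sglang"] = "sglang"
    name: str

def get_inference_framework(v: Any) -> str:
    # Mirrors the discriminator above: read the field from a dict or an
    # object, defaulting to vllm so legacy payloads still validate.
    if isinstance(v, dict):
        return v.get("inference_framework", "vllm")
    return getattr(v, "inference_framework", "vllm")

CreateRequest = Annotated[
    Union[
        Annotated[VLLMCreate, Tag("vllm")],
        Annotated[SGLangCreate, Tag("sglang")],
    ],
    Discriminator(get_inference_framework),
]

adapter = TypeAdapter(CreateRequest)
assert isinstance(adapter.validate_python({"name": "a"}), VLLMCreate)  # default
assert isinstance(
    adapter.validate_python({"inference_framework": "sglang", "name": "b"}),
    SGLangCreate,
)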
