
Commit 07357ac

Merge pull request #1213 from RafalSkolasinski/batch-bytes

fix mlserver infer with BYTES

2 parents: 58c52b9 + 9542138

5 files changed: +97 -6 lines

mlserver/batch_processing.py (+8 -4)

@@ -5,6 +5,7 @@
 import tritonclient.http.aio as httpclient
 
 import asyncio
+import numpy as np
 import aiofiles
 import logging
 import click
@@ -101,17 +102,21 @@ def from_inference_request(
         cls, inference_request: InferenceRequest, binary_data: bool
     ) -> "TritonRequest":
         inputs = []
-
         for request_input in inference_request.inputs or []:
             new_input = httpclient.InferInput(
                 request_input.name, request_input.shape, request_input.datatype
             )
+            request_input_np = NumpyCodec.decode_input(request_input)
+
+            # Change datatype if BYTES to satisfy Tritonclient checks
+            if request_input.datatype == "BYTES":
+                request_input_np = request_input_np.astype(np.object_)
+
             new_input.set_data_from_numpy(
-                NumpyCodec.decode_input(request_input),
+                request_input_np,
                 binary_data=binary_data,
             )
             inputs.append(new_input)
-
         outputs = []
         for request_output in inference_request.outputs or []:
             new_output = httpclient.InferRequestedOutput(
@@ -208,7 +213,6 @@ def preprocess_items(
             )
             invalid_inputs.append(_serialize_validation_error(item.index, e))
     batched = BatchedRequests(inference_requests)
-
     # Set `id` for batched requests - if only single request use its own id
     if len(inference_requests) == 1:
         batched.merged_request.id = inference_request.id
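Note: the heart of the fix is the decode-and-cast step added above. As a standalone illustration (not part of the commit), the snippet below feeds the same BYTES payload used by the new bytes.txt test case through that conversion; the request contents come from that test data, the rest mirrors the diff.

import numpy as np
import tritonclient.http.aio as httpclient

from mlserver.codecs import NumpyCodec
from mlserver.types import InferenceRequest, RequestInput

# Same BYTES payload as the new bytes.txt test case.
inference_request = InferenceRequest(
    inputs=[
        RequestInput(name="input-0", shape=[1, 3], datatype="BYTES", data=["a", "b", "c"])
    ]
)

for request_input in inference_request.inputs or []:
    new_input = httpclient.InferInput(
        request_input.name, request_input.shape, request_input.datatype
    )
    request_input_np = NumpyCodec.decode_input(request_input)

    # NumpyCodec can decode BYTES into a non-object dtype, which tritonclient's
    # set_data_from_numpy() rejects for BYTES tensors - hence the astype cast.
    if request_input.datatype == "BYTES":
        request_input_np = request_input_np.astype(np.object_)

    new_input.set_data_from_numpy(request_input_np, binary_data=False)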

tests/batch_processing/conftest.py (+14)

@@ -1,14 +1,22 @@
 import pytest
 import os
 
+from mlserver import MLModel, MLServer, ModelSettings
+
 from ..conftest import TESTDATA_PATH
+from ..fixtures import EchoModel
 
 
 @pytest.fixture()
 def single_input():
     return os.path.join(TESTDATA_PATH, "batch_processing", "single.txt")
 
 
+@pytest.fixture()
+def bytes_input():
+    return os.path.join(TESTDATA_PATH, "batch_processing", "bytes.txt")
+
+
 @pytest.fixture()
 def invalid_input():
     return os.path.join(TESTDATA_PATH, "batch_processing", "invalid.txt")
@@ -27,3 +35,9 @@ def many_input():
 @pytest.fixture()
 def single_input_with_id():
     return os.path.join(TESTDATA_PATH, "batch_processing", "single_with_id.txt")
+
+
+@pytest.fixture
+async def echo_model(mlserver: MLServer) -> MLModel:
+    model_settings = ModelSettings(name="echo-model", implementation=EchoModel)
+    return await mlserver._model_registry.load(model_settings)

tests/batch_processing/test_rest.py (+45 -1)

@@ -5,7 +5,7 @@
 
 from mlserver.batch_processing import process_batch
 from mlserver.settings import Settings
-
+from mlserver import MLModel
 
 from ..utils import RESTClient
 
@@ -53,6 +53,50 @@ async def test_single(
         raise RuntimeError(f"Response id is not a valid UUID; got {response['id']}")
 
 
+async def test_bytes(
+    tmp_path: str,
+    echo_model: MLModel,
+    rest_client: RESTClient,
+    settings: Settings,
+    bytes_input: str,
+):
+    await rest_client.wait_until_ready()
+    model_name = "echo-model"
+    url = f"{settings.host}:{settings.http_port}"
+    output_file = os.path.join(tmp_path, "output.txt")
+
+    await process_batch(
+        model_name=model_name,
+        url=url,
+        workers=1,
+        retries=1,
+        input_data_path=bytes_input,
+        output_data_path=output_file,
+        binary_data=False,
+        batch_size=1,
+        transport="rest",
+        request_headers={},
+        batch_interval=0,
+        batch_jitter=0,
+        timeout=60,
+        use_ssl=False,
+        insecure=False,
+        verbose=True,
+        extra_verbose=True,
+    )
+
+    with open(output_file) as f:
+        response = json.load(f)
+
+    assert response["outputs"][0]["data"] == ["a", "b", "c"]
+    assert response["id"] is not None and response["id"] != ""
+    assert response["parameters"]["batch_index"] == 0
+    try:
+        _ = UUID(response["id"])
+    except ValueError:
+        raise RuntimeError(f"Response id is not a valid UUID; got {response['id']}")
+
+
 async def test_invalid(
     tmp_path: str,
     rest_client: RESTClient,

tests/fixtures.py (+29 -1)

@@ -15,7 +15,12 @@
 from typing import Dict, List
 
 from mlserver import MLModel
-from mlserver.types import InferenceRequest, InferenceResponse, Parameters
+from mlserver.types import (
+    InferenceRequest,
+    InferenceResponse,
+    ResponseOutput,
+    Parameters,
+)
 from mlserver.codecs import NumpyCodec, decode_args, StringCodec
 from mlserver.handlers.custom import custom_handler
 from mlserver.errors import MLServerError
@@ -100,3 +105,26 @@ async def predict(self, inference_request: InferenceRequest) -> InferenceResponse:
                 StringCodec.encode_output("sklearn_version", [self._sklearn_version]),
             ],
         )
+
+
+class EchoModel(MLModel):
+    async def load(self) -> bool:
+        print("Echo Model Initialized")
+        return await super().load()
+
+    async def predict(self, payload: InferenceRequest) -> InferenceResponse:
+        return InferenceResponse(
+            id=payload.id,
+            model_name=self.name,
+            model_version=self.version,
+            outputs=[
+                ResponseOutput(
+                    name=input.name,
+                    shape=input.shape,
+                    datatype=input.datatype,
+                    data=input.data,
+                    parameters=input.parameters,
+                )
+                for input in payload.inputs
+            ],
+        )
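Note: EchoModel simply mirrors every request input back as a response output, which is what lets test_bytes assert on the exact BYTES payload it sent. A minimal sketch of exercising it directly, outside pytest (the import path and model settings are assumptions based on the fixtures above):

import asyncio

from mlserver import ModelSettings
from mlserver.types import InferenceRequest, RequestInput

from tests.fixtures import EchoModel  # assumes the tests package is importable


async def main() -> None:
    # Instantiate and load the echo model without a running MLServer instance.
    model = EchoModel(ModelSettings(name="echo-model", implementation=EchoModel))
    await model.load()

    # Same BYTES payload as bytes.txt.
    request = InferenceRequest(
        inputs=[
            RequestInput(name="input-0", shape=[1, 3], datatype="BYTES", data=["a", "b", "c"])
        ]
    )
    response = await model.predict(request)
    print(response.outputs[0].data)  # echoes back ["a", "b", "c"]


asyncio.run(main())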
bytes.txt (new file, +1) - the batch_processing test-data file referenced by the bytes_input fixture above

@@ -0,0 +1 @@
+{"inputs":[{"name":"input-0","shape":[1,3],"datatype":"BYTES","data":["a","b","c"]}]}
