Skip to content

Commit 6ef2f1e

Browse files
authored
Allow MLflow URI with scheme (#839)
1 parent 8e1daa0 commit 6ef2f1e

File tree

3 files changed

+30
-15
lines changed

3 files changed

+30
-15
lines changed

mlserver/utils.py

+13-8
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import os
22
import uuid
33
import asyncio
4+
import urllib.parse
45

56
from asyncio import Task
67
from typing import Callable, Dict, Optional, List, Type
@@ -21,22 +22,26 @@ async def get_model_uri(
2122
if not model_uri:
2223
raise InvalidModelURI(settings.name)
2324

24-
full_model_uri = _to_absolute_path(settings._source, model_uri)
25-
if os.path.isfile(full_model_uri):
26-
return full_model_uri
25+
model_uri_components = urllib.parse.urlparse(model_uri, scheme="file")
26+
if model_uri_components.scheme != "file":
27+
return model_uri
28+
29+
full_model_path = _to_absolute_path(settings._source, model_uri_components.path)
30+
if os.path.isfile(full_model_path):
31+
return full_model_path
2732

28-
if os.path.isdir(full_model_uri):
29-
# If full_model_uri is a folder, search for a well-known model filename
33+
if os.path.isdir(full_model_path):
34+
# If full_model_path is a folder, search for a well-known model filename
3035
for fname in wellknown_filenames:
31-
model_path = os.path.join(full_model_uri, fname)
36+
model_path = os.path.join(full_model_path, fname)
3237
if os.path.isfile(model_path):
3338
return model_path
3439

3540
# If none, return the folder
36-
return full_model_uri
41+
return full_model_path
3742

3843
# Otherwise, the uri is neither a file nor a folder
39-
raise InvalidModelURI(settings.name, full_model_uri)
44+
raise InvalidModelURI(settings.name, full_model_path)
4045

4146

4247
def _to_absolute_path(source: Optional[str], model_uri: str) -> str:

runtimes/mlflow/tests/conftest.py

+4-2
Original file line numberDiff line numberDiff line change
@@ -87,8 +87,10 @@ def pytorch_model_uri() -> str:
8787
return model_path
8888

8989

90-
@pytest.fixture
91-
def model_settings(model_uri: str) -> ModelSettings:
90+
@pytest.fixture(params=["", "file:"])
91+
def model_settings(model_uri: str, request: pytest.FixtureRequest) -> ModelSettings:
92+
scheme = request.param
93+
model_uri = scheme + model_uri
9294
return ModelSettings(
9395
name="mlflow-model",
9496
implementation=MLflowRuntime,

tests/test_utils.py

+13-5
Original file line numberDiff line numberDiff line change
@@ -16,10 +16,12 @@
1616
from mlserver.types import InferenceRequest, InferenceResponse, Parameters
1717
from mlserver.settings import ModelSettings, ModelParameters
1818

19-
20-
@pytest.mark.parametrize(
21-
"uri, source, expected",
22-
[
19+
test_get_model_uri_paramaters = [
20+
("s3://bucket/key", None, "s3://bucket/key"),
21+
("s3://bucket/key", "/mnt/models/model-settings.json", "s3://bucket/key"),
22+
]
23+
for scheme in ["", "file:"]:
24+
for uri, source, expected in [
2325
("my-model.bin", None, "my-model.bin"),
2426
(
2527
"my-model.bin",
@@ -36,7 +38,13 @@
3638
"/mnt/models/model-settings.json",
3739
"/an/absolute/path/my-model.bin",
3840
),
39-
],
41+
]:
42+
test_get_model_uri_paramaters.append((scheme + uri, source, expected))
43+
44+
45+
@pytest.mark.parametrize(
46+
"uri, source, expected",
47+
test_get_model_uri_paramaters,
4048
)
4149
async def test_get_model_uri(uri: str, source: Optional[str], expected: str):
4250
model_settings = ModelSettings(

0 commit comments

Comments
 (0)