Spawn worker in custom environment #1739

Merged · 15 commits · May 17, 2024
22 changes: 18 additions & 4 deletions docs/user-guide/custom.md
@@ -224,12 +224,26 @@ To load a custom environment, [parallel inference](./parallel-inference)
```

```{warning}
-When loading custom environments, MLServer will always use the same Python
-interpreter that is used to run the main process.
-In other words, all custom environments will use the same version of Python
-than the main MLServer process.
+The main MLServer process communicates with workers in custom environments via
+[`multiprocessing.Queue`](https://docs.python.org/3/library/multiprocessing.html#multiprocessing.Queue)
+using pickled objects. Custom environments therefore **must** use the same
+version of MLServer and a compatible version of Python with the same [default
+pickle protocol](https://docs.python.org/3/library/pickle.html#pickle.DEFAULT_PROTOCOL)
+as the main process. Consult the tables below for environment compatibility.
```

| Status | Description |
| ------ | ------------ |
| 🔴 | Unsupported |
| 🟢 | Supported |
| 🔵 | Untested |

| Worker Python \ Server Python | 3.9 | 3.10 | 3.11 |
| ----------------------------- | --- | ---- | ---- |
| 3.9 | 🟢 | 🟢 | 🔵 |
| 3.10 | 🟢 | 🟢 | 🔵 |
| 3.11 | 🔵 | 🔵 | 🔵 |
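
As a quick sanity check of the pickle requirement described in the warning above, the default pickle protocol of two interpreters can be compared directly. A minimal sketch, assuming the worker environment's interpreter lives at `./env/bin/python` (a hypothetical path):

```python
import pickle
import subprocess

def default_pickle_protocol(python_bin: str) -> int:
    # Ask the given interpreter for its default pickle protocol.
    out = subprocess.check_output(
        [python_bin, "-c", "import pickle; print(pickle.DEFAULT_PROTOCOL)"]
    )
    return int(out)

# All of Python 3.8-3.12 default to protocol 5, which is why the 3.9/3.10
# combinations in the table above interoperate.
assert pickle.DEFAULT_PROTOCOL == default_pickle_protocol("./env/bin/python")
```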

If we take the [previous example](#loading-a-custom-mlserver-runtime) above as
a reference, we could extend it to include our custom environment as:

10 changes: 10 additions & 0 deletions mlserver/env.py
@@ -1,4 +1,5 @@
import asyncio
import multiprocessing
import os
import sys
import tarfile
@@ -99,6 +100,13 @@ def _bin_path(self) -> str:
"""
return os.path.join(self._env_path, "bin")

@cached_property
def _exec_path(self) -> str:
"""
Path to python executable in our custom environment.
"""
return os.path.join(self._bin_path, "python")

@cached_property
def _lib_path(self) -> str:
"""
@@ -118,11 +126,13 @@ def __enter__(self) -> "Environment":
self._prev_sys_path = sys.path
self._prev_bin_path = os.environ["PATH"]

multiprocessing.set_executable(self._exec_path)
sys.path = [*self._sys_path, *self._prev_sys_path]
os.environ["PATH"] = os.pathsep.join([self._bin_path, self._prev_bin_path])

return self

def __exit__(self, *exc_details) -> None:
multiprocessing.set_executable(sys.executable)
sys.path = self._prev_sys_path
os.environ["PATH"] = self._prev_bin_path
17 changes: 12 additions & 5 deletions mlserver/parallel/pool.py
@@ -1,5 +1,6 @@
import asyncio

from contextlib import nullcontext
from multiprocessing import Queue
from typing import Awaitable, Callable, Dict, Optional, List, Iterable

@@ -24,6 +25,14 @@
InferencePoolHook = Callable[[Worker], Awaitable[None]]


def _spawn_worker(settings: Settings, responses: Queue, env: Optional[Environment]):
Member: add return type hint

with env or nullcontext():
worker = Worker(settings, responses, env)
worker.start()

return worker


class WorkerRegistry:
"""
Simple registry to keep track of which models have been loaded.
@@ -78,9 +87,8 @@ def __init__(
self._worker_registry = WorkerRegistry()
self._settings = settings
self._responses: Queue[ModelResponseMessage] = Queue()
-        for idx in range(self._settings.parallel_workers):
-            worker = Worker(self._settings, self._responses, self._env)
-            worker.start()
+        for _ in range(self._settings.parallel_workers):
+            worker = _spawn_worker(self._settings, self._responses, self._env)
self._workers[worker.pid] = worker # type: ignore

self._dispatcher = Dispatcher(self._workers, self._responses)
@@ -123,8 +131,7 @@ async def on_worker_stop(self, pid: int, exit_code: int):
await self._start_worker()

async def _start_worker(self) -> Worker:
-        worker = Worker(self._settings, self._responses, self._env)
-        worker.start()
+        worker = _spawn_worker(self._settings, self._responses, self._env)
logger.info(f"Starting new worker with PID {worker.pid} on {self.name}...")

# Add to dispatcher so that it can receive load requests and reload all
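
`_spawn_worker` consolidates both call sites: `env or nullcontext()` activates the custom environment when one is given and is a no-op otherwise, so the worker process is always spawned under the right interpreter. A minimal sketch of the idiom (illustrative only):

```python
from contextlib import nullcontext

def run_with_optional_env(env=None):
    # `env or nullcontext()` evaluates to `env` when it is not None and to a
    # no-op context manager otherwise, so one `with` block covers both cases.
    with env or nullcontext():
        ...  # spawn the worker here
```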
2 changes: 1 addition & 1 deletion mlserver/parallel/registry.py
@@ -72,7 +72,7 @@ async def _handle_worker_stop(self, signum, frame):

await self._default_pool.on_worker_stop(pid, exit_code)
await asyncio.gather(
-            *[pool.on_worker_stop(pid) for pool in self._pools.values()]
+            *[pool.on_worker_stop(pid, exit_code) for pool in self._pools.values()]
)

async def _get_or_create(self, model: MLModel) -> InferencePool:
26 changes: 22 additions & 4 deletions tests/conftest.py
@@ -6,7 +6,7 @@
import json

from filelock import FileLock
-from typing import Dict, Any
+from typing import Dict, Any, Tuple
from starlette_exporter import PrometheusMiddleware
from prometheus_client.registry import REGISTRY, CollectorRegistry
from unittest.mock import Mock
@@ -29,6 +29,13 @@
from .fixtures import SumModel, ErrorModel, SimpleModel
from .utils import RESTClient, get_available_ports, _pack, _get_tarball_name

MIN_PYTHON_VERSION = (3, 9)
MAX_PYTHON_VERSION = (3, 10)
PYTHON_VERSIONS = [
(major, minor)
for major in range(MIN_PYTHON_VERSION[0], MAX_PYTHON_VERSION[0] + 1)
for minor in range(MIN_PYTHON_VERSION[1], MAX_PYTHON_VERSION[1] + 1)
]
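# (illustration) With the bounds above, PYTHON_VERSIONS expands to
# [(3, 9), (3, 10)], which pytest renders as the ids "py39" and "py310" below.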
TESTS_PATH = os.path.dirname(__file__)
TESTDATA_PATH = os.path.join(TESTS_PATH, "testdata")
TESTDATA_CACHE_PATH = os.path.join(TESTDATA_PATH, ".cache")
@@ -59,17 +66,28 @@ def testdata_cache_path() -> str:
return TESTDATA_CACHE_PATH


@pytest.fixture(
params=PYTHON_VERSIONS,
ids=[f"py{major}{minor}" for (major, minor) in PYTHON_VERSIONS],
)
def env_python_version(request: pytest.FixtureRequest) -> Tuple[int, int]:
return request.param


@pytest.fixture
-async def env_tarball(testdata_cache_path: str) -> str:
-    tarball_name = _get_tarball_name()
+async def env_tarball(
+    env_python_version: Tuple[int, int],
+    testdata_cache_path: str,
+) -> str:
+    tarball_name = _get_tarball_name(env_python_version)
tarball_path = os.path.join(testdata_cache_path, tarball_name)
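    # The file lock below ensures each per-version tarball is built at most
    # once, even with tests running in parallel (tox runs pytest with -n auto).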

with FileLock(f"{tarball_path}.lock"):
if os.path.isfile(tarball_path):
return tarball_path

env_yml = os.path.join(TESTDATA_PATH, "environment.yml")
-        await _pack(env_yml, tarball_path)
+        await _pack(env_python_version, env_yml, tarball_path)

return tarball_path

6 changes: 2 additions & 4 deletions tests/parallel/conftest.py
@@ -10,7 +10,7 @@
from mlserver.env import Environment
from mlserver.parallel.dispatcher import Dispatcher
from mlserver.parallel.model import ModelMethods
-from mlserver.parallel.pool import InferencePool
+from mlserver.parallel.pool import InferencePool, _spawn_worker
from mlserver.parallel.worker import Worker
from mlserver.parallel.utils import configure_inference_pool, cancel_task
from mlserver.parallel.messages import (
@@ -170,9 +170,7 @@ async def worker_with_env(
):
# NOTE: This fixture will start an actual worker running on a separate
# process.
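    # When `env` is provided, _spawn_worker activates it first, so the worker
    # process runs under the custom environment's interpreter.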
-    worker = Worker(settings, responses, env)
-
-    worker.start()
+    worker = _spawn_worker(settings, responses, env)

load_message = ModelUpdateMessage(
update_type=ModelUpdateType.Load, model_settings=env_model_settings
15 changes: 12 additions & 3 deletions tests/test_env.py
@@ -3,13 +3,15 @@
import os
import shutil

from typing import Tuple

from mlserver.env import Environment, compute_hash


@pytest.fixture
-def expected_python_folder() -> str:
-    v = sys.version_info
-    return f"python{v.major}.{v.minor}"
+def expected_python_folder(env_python_version: Tuple[int, int]) -> str:
+    major, minor = env_python_version
+    return f"python{major}.{minor}"


async def test_compute_hash(env_tarball: str):
@@ -68,6 +70,10 @@ def test_bin_path(env: Environment):
assert env._bin_path == f"{env._env_path}/bin"


def test_exec_path(env: Environment):
assert env._exec_path == f"{env._env_path}/bin/python"


def test_activate_env(env: Environment):
assert env._env_path not in ",".join(sys.path)
assert env._env_path not in os.environ["PATH"]
@@ -82,5 +88,8 @@ def test_activate_env(env: Environment):
bin_paths = os.environ["PATH"].split(os.pathsep)
assert bin_paths[0] == env._bin_path

exec_path = os.path.join(bin_paths[0], "python")
assert exec_path == env._exec_path

assert env._env_path not in ",".join(sys.path)
assert env._env_path not in os.environ["PATH"]
4 changes: 2 additions & 2 deletions tests/testdata/environment.yml
@@ -2,7 +2,7 @@ name: custom-runtime-environment
channels:
- conda-forge
dependencies:
-  - python == 3.8
+  - python == 3.9
Member: good spot

- scikit-learn == 1.0.2
- pip:
-      - mlserver == 1.3.0.dev2
+      - git+${GITHUB_SERVER_URL}/${GITHUB_REPOSITORY}.git@${GITHUB_REF}
Contributor Author: I switched to installing mlserver from git here, since the worker environment has to match the main process for this PR. GitHub Actions should set these environment variables to the fork/branch so that the worker environment installs the same mlserver; otherwise they default to SeldonIO/MLServer's master branch via tox.ini.

Member: @lhnwrk is it possible to install from the local mlserver directory? That would keep the logic simpler, and when running locally we might want to test local changes.

Contributor Author: I did go with that option first, but it turns out it breaks the Docker build tests: the local mlserver directory is not available inside the container, and the template Dockerfile only copies the environment file. It has been a hassle to test locally, though. What do you think about using a separate YAML with a pinned mlserver version for the CLI build test and installing the local directory for the others?

32 changes: 16 additions & 16 deletions tests/utils.py
@@ -3,13 +3,12 @@
import socket
import yaml

-import sys
import os

from itertools import filterfalse

from asyncio import subprocess
-from typing import List
+from typing import List, Tuple

from aiohttp.client_exceptions import (
ClientConnectorError,
@@ -65,29 +64,30 @@ def _is_python(dep: str) -> bool:
return "python" in dep


-def _inject_python_version(env_yml: str, tarball_path: str) -> str:
+def _inject_python_version(
+    version: Tuple[int, int],
+    env_yml: str,
+    tarball_path: str,
+) -> str:
"""
-    To ensure the same environment.yml fixture we've got works across
-    environments, we inject dynamically the current Python version.
-    That way, we ensure tests using the tarball to load a dynamic custom
-    environment are using the same Python version used to run the tests.
+    To test the same environment.yml fixture we've got with different Python
+    versions across environments, we inject dynamically the requested version.
"""
env = _read_env(env_yml)

-    v = sys.version_info
+    major, minor = version
without_python = list(filterfalse(_is_python, env["dependencies"]))
with_env_python = [f"python == {v.major}.{v.minor}", *without_python]
with_env_python = [f"python == {major}.{minor}", *without_python]
env["dependencies"] = with_env_python

dst_folder = os.path.dirname(tarball_path)
-    new_env_yml = os.path.join(dst_folder, f"environment-py{v.major}{v.minor}.yml")
+    new_env_yml = os.path.join(dst_folder, f"environment-py{major}{minor}.yml")
_write_env(env, new_env_yml)
return new_env_yml


-async def _pack(env_yml: str, tarball_path: str):
+async def _pack(version: Tuple[int, int], env_yml: str, tarball_path: str):
uuid = generate_uuid()
-    fixed_env_yml = _inject_python_version(env_yml, tarball_path)
+    fixed_env_yml = _inject_python_version(version, env_yml, tarball_path)
env_name = f"mlserver-{uuid}"
try:
await _run(f"conda env create -n {env_name} -f {fixed_env_yml}")
@@ -104,9 +104,9 @@ async def _pack(env_yml: str, tarball_path: str):
await _run(f"conda env remove -n {env_name}")


-def _get_tarball_name() -> str:
-    v = sys.version_info
-    return f"environment-py{v.major}{v.minor}.tar.gz"
+def _get_tarball_name(version: Tuple[int, int]) -> str:
+    major, minor = version
+    return f"environment-py{major}{minor}.tar.gz"


class RESTClient:
4 changes: 4 additions & 0 deletions tox.ini
@@ -21,6 +21,10 @@ commands_pre =
commands =
python -m pytest {posargs} -n auto \
{toxinidir}/tests
set_env =
GITHUB_SERVER_URL = {env:GITHUB_SERVER_URL:https\://github.com}
GITHUB_REPOSITORY = {env:GITHUB_REPOSITORY:SeldonIO/MLServer}
GITHUB_REF = {env:GITHUB_REF:/refs/heads/master}
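# (note) When these GitHub variables are unset, e.g. when running tox locally,
# the pip install in tests/testdata/environment.yml falls back to the
# SeldonIO/MLServer master branch.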

[testenv:all-runtimes]
commands_pre =