From b4b2755a178c2cfefc90fa95d6d673237ef480fe Mon Sep 17 00:00:00 2001 From: Ilya Lavrenov Date: Mon, 10 Jun 2024 15:17:40 +0200 Subject: [PATCH 1/9] [Hardware][Intel] OpenVINO vLLM backend --- .buildkite/run-openvino-test.sh | 14 + .buildkite/test-template.j2 | 4 + Dockerfile.openvino | 26 ++ benchmarks/benchmark_latency.py | 7 +- benchmarks/benchmark_throughput.py | 7 +- .../getting_started/openvino-installation.rst | 95 +++++ docs/source/index.rst | 1 + requirements-openvino.txt | 9 + setup.py | 13 +- vllm/attention/backends/openvino.py | 69 ++++ vllm/attention/selector.py | 10 +- vllm/config.py | 6 +- vllm/engine/arg_utils.py | 11 +- vllm/engine/async_llm_engine.py | 5 + vllm/engine/llm_engine.py | 3 + vllm/envs.py | 3 +- vllm/executor/openvino_executor.py | 162 ++++++++ vllm/model_executor/layers/sampler.py | 4 +- vllm/model_executor/model_loader/openvino.py | 221 +++++++++++ vllm/utils.py | 13 +- vllm/worker/openvino_model_runner.py | 340 +++++++++++++++++ vllm/worker/openvino_worker.py | 349 ++++++++++++++++++ 22 files changed, 1352 insertions(+), 20 deletions(-) create mode 100755 .buildkite/run-openvino-test.sh create mode 100644 Dockerfile.openvino create mode 100644 docs/source/getting_started/openvino-installation.rst create mode 100644 requirements-openvino.txt create mode 100644 vllm/attention/backends/openvino.py create mode 100644 vllm/executor/openvino_executor.py create mode 100644 vllm/model_executor/model_loader/openvino.py create mode 100644 vllm/worker/openvino_model_runner.py create mode 100644 vllm/worker/openvino_worker.py diff --git a/.buildkite/run-openvino-test.sh b/.buildkite/run-openvino-test.sh new file mode 100755 index 0000000000000..70e56596c4a86 --- /dev/null +++ b/.buildkite/run-openvino-test.sh @@ -0,0 +1,14 @@ +# This script build the OpenVINO docker image and run the offline inference inside the container. +# It serves a sanity check for compilation and basic model usage. +set -ex + +# Try building the docker image +docker build -t openvino-test -f Dockerfile.openvino . + +# Setup cleanup +remove_docker_container() { docker rm -f openvino-test || true; } +trap remove_docker_container EXIT +remove_docker_container + +# Run the image and launch offline inference +docker run --network host --env VLLM_OPENVINO_KVCACHE_SPACE=1 --name openvino-test openvino-test python3 /workspace/vllm/examples/offline_inference.py diff --git a/.buildkite/test-template.j2 b/.buildkite/test-template.j2 index 4a20a462b98ec..3e3348a384400 100644 --- a/.buildkite/test-template.j2 +++ b/.buildkite/test-template.j2 @@ -45,6 +45,10 @@ steps: queue: intel command: bash .buildkite/run-cpu-test.sh + - label: "OpenVINO Test" + depends_on: ~ + command: bash .buildkite/run-openvino-test.sh + {% for step in steps %} - label: "{{ step.label }}" agents: diff --git a/Dockerfile.openvino b/Dockerfile.openvino new file mode 100644 index 0000000000000..9861997b451a9 --- /dev/null +++ b/Dockerfile.openvino @@ -0,0 +1,26 @@ +# The vLLM Dockerfile is used to construct vLLM image that can be directly used +# to run the OpenAI compatible server. + +FROM ubuntu:22.04 AS dev + +RUN apt-get update -y && \ + apt-get install -y python3-pip git +WORKDIR /workspace + +# copy requirements +COPY requirements-build.txt /workspace/vllm/ +COPY requirements-common.txt /workspace/vllm/ +COPY requirements-openvino.txt /workspace/vllm/ + +COPY vllm/ /workspace/vllm/vllm +COPY setup.py /workspace/vllm/ + +# install build requirements +RUN PIP_EXTRA_INDEX_URL="https://download.pytorch.org/whl/cpu" python3 -m pip install -r /workspace/vllm/requirements-build.txt +# build vLLM with OpenVINO backend +RUN PIP_PRE=1 PIP_EXTRA_INDEX_URL="https://download.pytorch.org/whl/cpu https://storage.openvinotoolkit.org/simple/wheels/nightly/" VLLM_TARGET_DEVICE="openvino" python3 -m pip install /workspace/vllm/ + +COPY examples/ /workspace/vllm/examples +COPY benchmarks/ /workspace/vllm/benchmarks + +CMD ["/bin/bash"] diff --git a/benchmarks/benchmark_latency.py b/benchmarks/benchmark_latency.py index 1a41b66b38824..bf8d360959dd9 100644 --- a/benchmarks/benchmark_latency.py +++ b/benchmarks/benchmark_latency.py @@ -188,9 +188,10 @@ def run_to_completion(profile_dir: Optional[str] = None): parser.add_argument( "--device", type=str, - default="cuda", - choices=["cuda", "cpu"], - help='device type for vLLM execution, supporting CUDA and CPU.') + default="auto", + choices=["auto", "cuda", "cpu", "openvino"], + help='device type for vLLM execution, supporting CUDA, OpenVINO and ' + 'CPU.') parser.add_argument('--block-size', type=int, default=16, diff --git a/benchmarks/benchmark_throughput.py b/benchmarks/benchmark_throughput.py index 90f7433e0ae28..04abb5e14c1d2 100644 --- a/benchmarks/benchmark_throughput.py +++ b/benchmarks/benchmark_throughput.py @@ -345,9 +345,10 @@ def main(args: argparse.Namespace): parser.add_argument( "--device", type=str, - default="cuda", - choices=["cuda", "cpu"], - help='device type for vLLM execution, supporting CUDA and CPU.') + default="auto", + choices=["auto", "cuda", "cpu", "openvino"], + help='device type for vLLM execution, supporting CUDA, OpenVINO and ' + 'CPU.') parser.add_argument( "--enable-prefix-caching", action='store_true', diff --git a/docs/source/getting_started/openvino-installation.rst b/docs/source/getting_started/openvino-installation.rst new file mode 100644 index 0000000000000..71b7807d241cb --- /dev/null +++ b/docs/source/getting_started/openvino-installation.rst @@ -0,0 +1,95 @@ +.. _installation_openvino: + +Installation with OpenVINO +======================== + +vLLM powered by OpenVINO supports all LLM models from [vLLM supported models list](../dev/models/supported_models.rst) and can perform optimal model serving on all x86-64 CPUs with, at least, AVX2 support. OpenVINO vLLM backend supports the following advanced vLLM features: + +- Prefix caching (``--enable-prefix-caching``) +- Chunked prefill (``--enable-chunked-prefill``) + +Table of contents: + +#. :ref:`Requirements ` +#. :ref:`Quick start using Dockerfile ` +#. :ref:`Build from source ` +#. :ref:`Performance tips ` +#. :ref:`Limitations ` + +.. _openvino_backend_requirements: + +Requirements +------------ + +* OS: Linux +* Instruction set architecture (ISA) requirement: at least AVX2. + +.. _openvino_backend_quick_start_dockerfile: + +Quick start using Dockerfile +---------------------------- + +.. code-block:: console + + $ docker build -f Dockerfile.openvino -t vllm-openvino-env . + $ docker run -it --rm vllm-openvino-env + +.. _install_openvino_backend_from_source: + +Install from source +----------------- + +- First, install Python. For example, on Ubuntu 22.04, you can run: + +.. code-block:: console + + $ sudo apt-get update -y + $ sudo apt-get install python3 + +- Second, install prerequisites vLLM OpenVINO backend installation: + +.. code-block:: console + + $ pip install --upgrade pip + $ pip install -r requirements-build.txt --extra-index-url https://download.pytorch.org/whl/cpu + +- Finally, install vLLM with OpenVINO backend: + +.. code-block:: console + + $ PIP_PRE=1 PIP_EXTRA_INDEX_URL="https://download.pytorch.org/whl/cpu https://storage.openvinotoolkit.org/simple/wheels/nightly/" VLLM_TARGET_DEVICE=openvino python install -v . + +.. _openvino_backend_performance_tips: + +Performance tips +----------------- + +vLLM OpenVINO backend uses the following environment variables to control behavior: + +- ``VLLM_OPENVINO_KVCACHE_SPACE`` to specify the KV Cache size (e.g, ``VLLM_OPENVINO_KVCACHE_SPACE=40`` means 40 GB space for KV cache), larger setting will allow vLLM running more requests in parallel. This parameter should be set based on the hardware configuration and memory management pattern of users. + +- ``VLLM_OPENVINO_CPU_KV_CACHE_PRECISION=u8`` to control KV cache precision. By default, FP16 / BF16 is used depending on platform. + +- ``VLLM_OPENVINO_ENABLE_QUANTIZED_WEIGHTS=ON`` to enable U8 weights compression during model loading stage. By default, compression is turned off. + +To enable better TPOT / TTFT latency, you can use vLLM's chunked prefill feature (``--enable-chunked-prefill``). Based on the experiments, the recommended batch size is ``256`` (``--max-num-batched-tokens``) + +OpenVINO best known configuration is: + +.. code-block:: console + + $ VLLM_OPENVINO_KVCACHE_SPACE=100 VLLM_OPENVINO_CPU_KV_CACHE_PRECISION=u8 VLLM_OPENVINO_ENABLE_QUANTIZED_WEIGHTS=ON \ + python3 vllm/benchmarks/benchmark_throughput.py --model meta-llama/Llama-2-7b-chat-hf --dataset vllm/benchmarks/ShareGPT_V3_unfiltered_cleaned_split.json --enable-chunked-prefill --max-num-batched-tokens 256 + +.. _openvino_backend_limitations: + +Limitations +----------------- + +- LoRA serving is not supported. + +- Only LLM models are currently supported. LLaVa and encoder-decoder models are not currently enabled in vLLM OpenVINO integration. + +- Tensor and pipeline parallelism are not currently enabled in vLLM integration. + +- Speculative sampling is not tested within vLLM integration. diff --git a/docs/source/index.rst b/docs/source/index.rst index fad3c3b05b0c0..32eb4ea22a117 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -64,6 +64,7 @@ Documentation getting_started/installation getting_started/amd-installation getting_started/neuron-installation + getting_started/openvino-installation getting_started/cpu-installation getting_started/quickstart getting_started/examples/examples_index diff --git a/requirements-openvino.txt b/requirements-openvino.txt new file mode 100644 index 0000000000000..e555d52572541 --- /dev/null +++ b/requirements-openvino.txt @@ -0,0 +1,9 @@ +# Common dependencies +-r requirements-common.txt + +# OpenVINO dependencies +torch >= 2.1.2 +openvino ~= 2024.3.0.dev +optimum-intel[openvino] >= 1.17.2 + +triton >= 2.2.0 # FIXME(woosuk): This is a hack to avoid import error. diff --git a/setup.py b/setup.py index 339b0ad6de2d1..9fde3552a93e5 100644 --- a/setup.py +++ b/setup.py @@ -229,6 +229,10 @@ def _is_cpu() -> bool: return VLLM_TARGET_DEVICE == "cpu" +def _is_openvino() -> bool: + return VLLM_TARGET_DEVICE == "openvino" + + def _install_punica() -> bool: return envs.VLLM_INSTALL_PUNICA_KERNELS @@ -325,6 +329,8 @@ def get_vllm_version() -> str: if neuron_version != MAIN_CUDA_VERSION: neuron_version_str = neuron_version.replace(".", "")[:3] version += f"+neuron{neuron_version_str}" + elif _is_openvino(): + version += "+openvino" elif _is_cpu(): version += "+cpu" else: @@ -372,11 +378,14 @@ def _read_requirements(filename: str) -> List[str]: requirements = _read_requirements("requirements-rocm.txt") elif _is_neuron(): requirements = _read_requirements("requirements-neuron.txt") + elif _is_openvino(): + requirements = _read_requirements("requirements-openvino.txt") elif _is_cpu(): requirements = _read_requirements("requirements-cpu.txt") else: raise ValueError( - "Unsupported platform, please use CUDA, ROCm, Neuron, or CPU.") + "Unsupported platform, please use CUDA, ROCm, Neuron, " + "OpenVINO, or CPU.") return requirements @@ -385,7 +394,7 @@ def _read_requirements(filename: str) -> List[str]: if _is_cuda() or _is_hip(): ext_modules.append(CMakeExtension(name="vllm._moe_C")) -if not _is_neuron(): +if not (_is_neuron() or _is_openvino()): ext_modules.append(CMakeExtension(name="vllm._C")) if _install_punica(): diff --git a/vllm/attention/backends/openvino.py b/vllm/attention/backends/openvino.py new file mode 100644 index 0000000000000..d75cd0ad0daaa --- /dev/null +++ b/vllm/attention/backends/openvino.py @@ -0,0 +1,69 @@ +from dataclasses import dataclass +from typing import List, Tuple + +import openvino as ov +import torch + +from vllm.attention.backends.abstract import (AttentionBackend, + AttentionMetadata) + + +class OpenVINOAttentionBackend(AttentionBackend): + + @staticmethod + def get_name() -> str: + return "openvino" + + @staticmethod + def get_impl_cls(): + # OpenVINO implements PagedAttention as part of the Optimum + # exported model + raise NotImplementedError + + @staticmethod + def make_metadata(*args, **kwargs) -> "AttentionMetadata": + raise NotImplementedError + + @staticmethod + def make_openvino_metadata(*args, **kwargs) -> "OpenVINOAttentionMetadata": + return OpenVINOAttentionMetadata(*args, **kwargs) + + @staticmethod + def get_kv_cache_shape( + num_blocks: int, + block_size: int, + num_kv_heads: int, + head_size: int, + ) -> Tuple[int, ...]: + return (2, num_blocks, num_kv_heads, block_size, head_size) + + @staticmethod + def swap_blocks( + src_kv_cache: ov.Tensor, + dst_kv_cache: ov.Tensor, + src_to_dst: torch.Tensor, + ) -> None: + # OpenVINO currently supports only CPU, which does not require + # swap of KV cache blocks + raise NotImplementedError + + @staticmethod + def copy_blocks( + kv_caches: List[Tuple[ov.Tensor, ov.Tensor]], + src_to_dists: List[Tuple[int, int]], + ) -> None: + for src, dst in src_to_dists: + for key_cache, value_cache in kv_caches: + key_cache.data[dst, :] = key_cache.data[src, :] + value_cache.data[dst, :] = value_cache.data[src, :] + + +@dataclass +class OpenVINOAttentionMetadata: + """Metadata for OpenVINOAttentionBackend. + """ + past_lens: torch.Tensor + subsequence_begins: torch.Tensor + block_indices: torch.Tensor + block_indices_begins: torch.Tensor + max_context_len: torch.Tensor diff --git a/vllm/attention/selector.py b/vllm/attention/selector.py index 7253483f9a0b8..12141e821c084 100644 --- a/vllm/attention/selector.py +++ b/vllm/attention/selector.py @@ -7,7 +7,7 @@ import vllm.envs as envs from vllm.attention.backends.abstract import AttentionBackend from vllm.logger import init_logger -from vllm.utils import is_cpu, is_hip +from vllm.utils import is_cpu, is_hip, is_openvino logger = init_logger(__name__) @@ -17,6 +17,7 @@ class _Backend(enum.Enum): XFORMERS = enum.auto() ROCM_FLASH = enum.auto() TORCH_SDPA = enum.auto() + OPENVINO = enum.auto() FLASHINFER = enum.auto() @@ -60,6 +61,10 @@ def get_attn_backend( logger.info("Using Torch SDPA backend.") from vllm.attention.backends.torch_sdpa import TorchSDPABackend return TorchSDPABackend + elif backend == _Backend.OPENVINO: + logger.info("Using OpenVINO Attention backend.") + from vllm.attention.backends.openvino import OpenVINOAttentionBackend + return OpenVINOAttentionBackend elif backend == _Backend.FLASHINFER: logger.info("Using Flashinfer backend.") logger.warning("Eager mode is required for the Flashinfer backend. " @@ -100,6 +105,9 @@ def which_attn_to_use( logger.info("Cannot use %s backend on CPU.", selected_backend) return _Backend.TORCH_SDPA + if is_openvino(): + return _Backend.OPENVINO + if is_hip(): # AMD GPUs. selected_backend = (_Backend.ROCM_FLASH if selected_backend diff --git a/vllm/config.py b/vllm/config.py index fa296cd626f17..3f739ca95e63c 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -11,7 +11,7 @@ from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS from vllm.model_executor.models import ModelRegistry from vllm.transformers_utils.config import get_config, get_hf_text_config -from vllm.utils import get_cpu_memory, is_cpu, is_hip, is_neuron +from vllm.utils import get_cpu_memory, is_cpu, is_hip, is_neuron, is_openvino if TYPE_CHECKING: from ray.util.placement_group import PlacementGroup @@ -730,6 +730,8 @@ def __init__(self, device: str = "auto") -> None: # Automated device type detection if is_neuron(): self.device_type = "neuron" + elif is_openvino(): + self.device_type = "openvino" elif is_cpu(): self.device_type = "cpu" else: @@ -741,7 +743,7 @@ def __init__(self, device: str = "auto") -> None: self.device_type = device # Some device types require processing inputs on CPU - if self.device_type in ["neuron"]: + if self.device_type in ["neuron", "openvino"]: self.device = torch.device("cpu") else: # Set device with device type diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index b7e815db12eb4..e1dc93d90671f 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -494,11 +494,12 @@ def add_cli_args( 'Enabling this will use the fully sharded layers. ' 'At high sequence length, max rank or ' 'tensor parallel size, this is likely faster.')) - parser.add_argument("--device", - type=str, - default=EngineArgs.device, - choices=["auto", "cuda", "neuron", "cpu"], - help='Device type for vLLM execution.') + parser.add_argument( + "--device", + type=str, + default=EngineArgs.device, + choices=["auto", "cuda", "neuron", "openvino", "cpu"], + help='Device type for vLLM execution.') # Related to Vision-language models such as llava parser = EngineArgs.add_cli_args_for_vlm(parser) diff --git a/vllm/engine/async_llm_engine.py b/vllm/engine/async_llm_engine.py index aa1f07b5bdc24..625845ffcbdb3 100644 --- a/vllm/engine/async_llm_engine.py +++ b/vllm/engine/async_llm_engine.py @@ -380,6 +380,11 @@ def from_engine_args( "Distributed execution is not supported with the CPU backend.") from vllm.executor.cpu_executor import CPUExecutorAsync executor_class = CPUExecutorAsync + elif engine_config.device_config.device_type == "openvino": + assert not engine_config.parallel_config.worker_use_ray, ( + "Ray is not supported with the OpenVINO backend.") + from vllm.executor.openvino_executor import OpenVINOExecutorAsync + executor_class = OpenVINOExecutorAsync elif distributed_executor_backend == "ray": initialize_ray_cluster(engine_config.parallel_config) from vllm.executor.ray_gpu_executor import RayGPUExecutorAsync diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index cb5893e707c8b..8f68672ae11a7 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -343,6 +343,9 @@ def from_engine_args( elif engine_config.device_config.device_type == "cpu": from vllm.executor.cpu_executor import CPUExecutor executor_class = CPUExecutor + elif engine_config.device_config.device_type == "openvino": + from vllm.executor.openvino_executor import OpenVINOExecutor + executor_class = OpenVINOExecutor elif distributed_executor_backend == "ray": initialize_ray_cluster(engine_config.parallel_config) from vllm.executor.ray_gpu_executor import RayGPUExecutor diff --git a/vllm/envs.py b/vllm/envs.py index b140aa6d658e6..29113162597e9 100644 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -48,7 +48,8 @@ # ================== Installation Time Env Vars ================== - # Target device of vLLM, supporting [cuda (by default), rocm, neuron, cpu] + # Target device of vLLM, supporting [cuda (by default), + # rocm, neuron, cpu, openvino] "VLLM_TARGET_DEVICE": lambda: os.getenv("VLLM_TARGET_DEVICE", "cuda"), diff --git a/vllm/executor/openvino_executor.py b/vllm/executor/openvino_executor.py new file mode 100644 index 0000000000000..028933be247b6 --- /dev/null +++ b/vllm/executor/openvino_executor.py @@ -0,0 +1,162 @@ +import os +from typing import List, Set, Tuple + +import openvino as ov +import openvino.properties.hint as hints +import torch + +from vllm.config import CacheConfig, ModelConfig +from vllm.executor.executor_base import ExecutorAsyncBase, ExecutorBase +from vllm.logger import init_logger +from vllm.lora.request import LoRARequest +from vllm.sequence import ExecuteModelRequest, SamplerOutput +from vllm.utils import (get_distributed_init_method, get_ip, get_open_port, + make_async) + +logger = init_logger(__name__) + + +class OpenVINOExecutor(ExecutorBase): + + def _init_executor(self) -> None: + assert self.device_config.device_type == "openvino" + assert self.lora_config is None, "OpenVINO backend doesn't support LoRA" + self.model_config = _verify_and_get_model_config(self.model_config) + self.cache_config = _verify_and_get_cache_config(self.cache_config) + + # Instantiate the worker and load the model to CPU. + self._init_worker() + + def _init_worker(self): + from vllm.worker.openvino_worker import OpenVINOWorker + + assert ( + self.parallel_config.world_size == 1 + ), "OpenVINOExecutor only supports single CPU socket currently." + + distributed_init_method = get_distributed_init_method( + get_ip(), get_open_port()) + self.driver_worker = OpenVINOWorker( + model_config=self.model_config, + parallel_config=self.parallel_config, + scheduler_config=self.scheduler_config, + device_config=self.device_config, + cache_config=self.cache_config, + load_config=self.load_config, + local_rank=0, + rank=0, + distributed_init_method=distributed_init_method, + lora_config=self.lora_config, + vision_language_config=self.vision_language_config, + kv_cache_dtype=self.cache_config.cache_dtype, + is_driver_worker=True, + ) + self.driver_worker.init_device() + self.driver_worker.load_model() + + def determine_num_available_blocks(self) -> Tuple[int, int]: + """Determine the number of available KV blocks by invoking the + underlying worker. + """ + return self.driver_worker.determine_num_available_blocks() + + def initialize_cache(self, num_gpu_blocks: int, + num_cpu_blocks: int) -> None: + """Initialize the KV cache by invoking the underlying worker.""" + # NOTE: We log here to avoid multiple logs when number of workers is + # greater than one. We could log in the engine, but not all executors + # have GPUs. + # NOTE: `cpu block` for OpenVINO backend is located on CPU memory but is + # referred as `gpu block`. Because we want to reuse the existing block + # management procedure. + logger.info("# CPU blocks: %d", num_gpu_blocks) + self.driver_worker.initialize_cache(num_gpu_blocks, num_cpu_blocks) + + def execute_model( + self, + execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]: + output = self.driver_worker.execute_model(execute_model_req) + return output + + def add_lora(self, lora_request: LoRARequest) -> bool: + return self.driver_worker.add_lora(lora_request) + + def remove_lora(self, lora_id: int) -> bool: + return self.driver_worker.remove_lora(lora_id) + + def list_loras(self) -> Set[int]: + return self.driver_worker.list_loras() + + def check_health(self) -> None: + # OpenVINOExecutor will always be healthy as long as + # it's running. + return + + +class OpenVINOExecutorAsync(OpenVINOExecutor, ExecutorAsyncBase): + + async def execute_model_async( + self, + execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]: + output = await make_async(self.driver_worker.execute_model + )(execute_model_req=execute_model_req, ) + return output + + async def check_health_async(self) -> None: + # OpenVINOExecutor will always be healthy as long as + # it's running. + return + + +def _verify_and_get_model_config(config: ModelConfig) -> ModelConfig: + if config.dtype != torch.float32: + logger.warning( + f"Only float32 dtype is supported on OpenVINO, casting from {config.dtype}." # noqa: G004, E501 + ) + config.dtype = torch.float32 + if not config.enforce_eager: + logger.warning( + "CUDA graph is not supported on OpenVINO backend, fallback to the " + "eager mode.") + config.enforce_eager = True + return config + + +def _verify_and_get_cache_config(config: CacheConfig) -> CacheConfig: + if os.environ.get("VLLM_OPENVINO_CPU_KV_CACHE_PRECISION", "") == "u8": + logger.warning("KV cache type is overried to u8 via " + "VLLM_OPENVINO_CPU_KV_CACHE_PRECISION env var.") + config.cache_dtype = ov.Type.u8 + else: + core = ov.Core() + inference_precision = core.get_property("CPU", + hints.inference_precision) + if inference_precision == ov.Type.bf16: + config.cache_dtype = ov.Type.bf16 + else: + config.cache_dtype = ov.Type.f16 + + kv_cache_space_str = os.getenv("VLLM_OPENVINO_KVCACHE_SPACE", "0") + kv_cache_space = int(kv_cache_space_str) + + if config.block_size != 32: + logger.warning( + f"OpenVINO optimal block size is 32, overriding currently set {config.block_size}" # noqa: G004, E501 + ) + config.block_size = 32 + + if kv_cache_space >= 0: + _GB = 1 << 30 + if kv_cache_space == 0: + config.openvino_kvcache_space_bytes = 4 * _GB # type: ignore + logger.warning( + "Environment variable VLLM_OPENVINO_KVCACHE_SPACE (GB) " + "for OpenVINO backend is not set, using 4 by default.") + else: + config.openvino_kvcache_space_bytes = kv_cache_space * _GB # type: ignore + else: + raise RuntimeError( + "Invalid environment variable VLLM_OPENVINO_KVCACHE_SPACE" + f" {kv_cache_space}, expect a positive integer value.") + + return config diff --git a/vllm/model_executor/layers/sampler.py b/vllm/model_executor/layers/sampler.py index a84f562909d50..d2f09eeb9f2c8 100644 --- a/vllm/model_executor/layers/sampler.py +++ b/vllm/model_executor/layers/sampler.py @@ -676,7 +676,7 @@ def _get_ranks(x: torch.Tensor, indices: torch.Tensor) -> torch.Tensor: Returns: torch.Tensor: 1D tensor of shape (N,) where N is the no. of tokens. - Each element in the returned tensor represents the rank + Each element in the returned tensor represents the rank of the chosen token in the input logprob tensor. """ vals = x[torch.arange(0, len(x), device=x.device, dtype=indices.dtype), @@ -962,7 +962,7 @@ def _modify_greedy_probs_inplace(logprobs: torch.Tensor, probs: torch.Tensor, distribution. - Greedy sampling performs `argmax` to obtain the token with the highest likelihood. - + Ignoring greedy sampling for a moment, we find that the computed probability distribution has the following property: we can sample from it independently and find that the token sampled by the Sampler has a frequency corresponding diff --git a/vllm/model_executor/model_loader/openvino.py b/vllm/model_executor/model_loader/openvino.py new file mode 100644 index 0000000000000..cf3893ddc19f3 --- /dev/null +++ b/vllm/model_executor/model_loader/openvino.py @@ -0,0 +1,221 @@ +# ruff: noqa: SIM117 +import os +from pathlib import Path +from typing import List, Optional, Tuple + +import openvino as ov +import torch +from huggingface_hub import HfApi +from openvino._offline_transformations import paged_attention_transformation +from optimum.intel import OVModelForCausalLM +from torch import nn + +from vllm.attention.backends.openvino import OpenVINOAttentionMetadata +from vllm.config import DeviceConfig, ModelConfig +from vllm.logger import init_logger +from vllm.model_executor.layers.logits_processor import (LogitsProcessor, + _prune_hidden_states) +from vllm.model_executor.layers.sampler import Sampler +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.sequence import SamplerOutput + +logger = init_logger(__name__) + + +def _flattenize_inputs(inputs): + """ + Helper function for making nested inputs flattens + """ + flatten_inputs = [] + for input_data in inputs: + if input_data is None: + continue + if isinstance(input_data, (list, tuple)): + flatten_inputs.extend(_flattenize_inputs(input_data)) + elif isinstance(input_data, dict): + flatten_inputs.extend(_flattenize_inputs(list( + input_data.values()))) + else: + flatten_inputs.append(input_data) + return flatten_inputs + + +def _modify_cache_parameters(model: ov.Model, kv_cache_dtype: ov.Type, + is_cpu: bool): + # Apply hardware dependent modifications to KV tensors + for parameter in model.get_parameters(): + input = parameter.get_output_tensor(0) + input_names = input.get_names() + if len(input_names) != 1: + continue + input_name = next(iter(input_names)) + shape = parameter.get_partial_shape() + # use real block size if available, just a placeholder + # to provide the expected rank + x_size = 1 + num_blocks = ov.Dimension() + block_size = ov.Dimension() + head_size = ov.Dimension() + # TODO: Negotiate required layout with plugins (CPU is ~OK, GPU is TBD), + # pass more parameters to this function to set more static dimensions + if input_name.startswith("key_cache."): + cpu_shape = [num_blocks, shape[1], block_size, head_size] + gpu_shape = [ + num_blocks, + shape[1], + shape[2].get_length() // + x_size if shape[2].is_static else ov.Dimension(), + block_size, + x_size, + ] + elif input_name.startswith("value_cache."): + cpu_shape = [num_blocks, shape[1], block_size, head_size] + gpu_shape = [num_blocks, shape[1], shape[2], block_size] + else: + continue + parameter.set_partial_shape( + ov.PartialShape(cpu_shape if is_cpu else gpu_shape)) + parameter.set_element_type(kv_cache_dtype) + model.validate_nodes_and_infer_types() + + +def _require_model_export(model_id, revision=None, subfolder=None): + # Stored IR may not be suitable for vLLM purposes (too old, + # not stateful, not compressed etc.). This is an option to override + # IR usage logic and always do model conversion. + if os.environ.get("VLLM_OPENVINO_OPTIMUM_FORCE_CONVERSION", "0") == "1": + return True + model_dir = Path(model_id) + if subfolder is not None: + model_dir = model_dir / subfolder + if model_dir.is_dir(): + return (not (model_dir / "openvino_model.xml").exists() + or not (model_dir / "openvino_model.bin").exists()) + + hf_api = HfApi() + try: + model_info = hf_api.model_info(model_id, revision=revision or "main") + normalized_subfolder = (None if subfolder is None else + Path(subfolder).as_posix()) + model_files = [ + file.rfilename for file in model_info.siblings + if normalized_subfolder is None + or file.rfilename.startswith(normalized_subfolder) + ] + ov_model_path = ("openvino_model.xml" if normalized_subfolder is None + else f"{normalized_subfolder}/openvino_model.xml") + return (ov_model_path not in model_files + or ov_model_path.replace(".xml", ".bin") not in model_files) + except Exception: + return True + + +class OpenVINOCasualLM(nn.Module): + + def __init__( + self, + model_config: ModelConfig, + device_config: DeviceConfig, + kv_cache_dtype: ov.Type, + ) -> None: + super().__init__() + self.logits_processor = LogitsProcessor( + model_config.hf_config.vocab_size, logits_as_input=True) + self.sampler = Sampler() + + export = _require_model_export(model_config.model) + if export: + logger.warning( + f"[ INFO ] Provided model id {model_config.model} does not " # noqa: G004 + "contain OpenVINO IR, the model will be converted to IR with " + "default options. If you need to use specific options for " + "model conversion, use optimum-cli export openvino with " + "desired options.") + else: + logger.warning( + "[ INFO ] OpenVINO IR is available for provided model id " # noqa: G004 + f"{model_config.model}. This IR will be used for inference " + "as-is, all possible options that may affect model conversion " + "are ignored.") + + load_in_8bit = os.environ.get("VLLM_OPENVINO_ENABLE_QUANTIZED_WEIGHTS", + "0") == "1" # noqa: E501 + pt_model = OVModelForCausalLM.from_pretrained( + model_config.model, + export=export, + compile=False, + load_in_8bit=load_in_8bit, + trust_remote_code=model_config.trust_remote_code, + ) + + paged_attention_transformation(pt_model.model) + _modify_cache_parameters(pt_model.model, kv_cache_dtype, + device_config.device.type == "cpu") + + # For deployment outside vLLM + model_file_name = os.environ.get("VLLM_OPENVINO_EXPORTED_IR_NAME", "") + if model_file_name: + ov.save_model(pt_model.model, model_file_name) + + core = ov.Core() + ov_compiled = core.compile_model(pt_model.model, "CPU") + self.ov_request = ov_compiled.create_infer_request() + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + kv_caches: List[Tuple[ov.Tensor, ov.Tensor]], + attn_metadata: OpenVINOAttentionMetadata, + ) -> torch.Tensor: + flatten_kv_cache = _flattenize_inputs(kv_caches) + + inputs = [ + input_ids, + positions, + *flatten_kv_cache, + attn_metadata.past_lens, + attn_metadata.subsequence_begins, + attn_metadata.block_indices, + attn_metadata.block_indices_begins, + attn_metadata.max_context_len, + ] + + self.ov_request.start_async(inputs, share_inputs=True) + self.ov_request.wait() + + logits = torch.from_numpy(self.ov_request.get_tensor("logits").data) + + # TODO: remove 'view' once OpenVINO PA will drop 'seq_len' dimension + return logits.view(-1, logits.shape[-1]) + + def compute_logits(self, hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata) -> torch.Tensor: + hidden_states = _prune_hidden_states(hidden_states, sampling_metadata) + logits = self.logits_processor(None, hidden_states, sampling_metadata) + return logits + + def sample( + self, + logits: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[SamplerOutput]: + next_tokens = self.sampler(logits, sampling_metadata) + return next_tokens + + +def get_model( + model_config: ModelConfig, + device_config: DeviceConfig, + kv_cache_dtype: ov.Type, + **kwargs, +) -> torch.nn.Module: + lora_config = kwargs.get("lora_config", None) + if lora_config: + raise ValueError( + "OpenVINO modeling does not support LoRA, " + "but LoRA is enabled. Support for this model may " + "be added in the future. If this is important to you, " + "please open an issue on github.") + + return OpenVINOCasualLM(model_config, device_config, kv_cache_dtype) diff --git a/vllm/utils.py b/vllm/utils.py index 54d446b23350a..bbd3395bf4b31 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -31,10 +31,12 @@ STR_DTYPE_TO_TORCH_DTYPE = { "half": torch.half, "bfloat16": torch.bfloat16, + "float16": torch.float16, "float": torch.float, "fp8": torch.uint8, "fp8_e4m3": torch.uint8, "fp8_e5m2": torch.uint8, + "u8": torch.uint8 } @@ -137,6 +139,15 @@ def is_cpu() -> bool: return False +@lru_cache(maxsize=None) +def is_openvino() -> bool: + from importlib.metadata import PackageNotFoundError, version + try: + return "openvino" in version("vllm") + except PackageNotFoundError: + return False + + @lru_cache(maxsize=None) def is_neuron() -> bool: try: @@ -468,7 +479,7 @@ def is_pin_memory_available() -> bool: elif is_neuron(): print_warning_once("Pin memory is not supported on Neuron.") return False - elif is_cpu(): + elif is_cpu() or is_openvino(): return False return True diff --git a/vllm/worker/openvino_model_runner.py b/vllm/worker/openvino_model_runner.py new file mode 100644 index 0000000000000..d4399cf1d9ae6 --- /dev/null +++ b/vllm/worker/openvino_model_runner.py @@ -0,0 +1,340 @@ +from typing import List, NamedTuple, Optional, Tuple + +import openvino as ov +import torch +from torch import nn + +from vllm.attention import get_attn_backend +from vllm.attention.backends.openvino import OpenVINOAttentionMetadata +from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig, + ModelConfig, ParallelConfig, SchedulerConfig, + VisionLanguageConfig) +from vllm.logger import init_logger +from vllm.model_executor import SamplingMetadata +from vllm.model_executor.model_loader.openvino import get_model +from vllm.sequence import SamplerOutput, SequenceGroupMetadata + +logger = init_logger(__name__) + + +class ModelInput(NamedTuple): + input_tokens: torch.Tensor + input_positions: torch.Tensor + attn_metadata: Optional[OpenVINOAttentionMetadata] + seq_lens: List[int] + query_lens: List[int] + multi_modal_input: Optional[torch.Tensor] + + @classmethod + def empty(cls, device): + return ModelInput(input_tokens=torch.empty(0, device=device), + input_positions=torch.empty(0, device=device), + attn_metadata=None, + seq_lens=[], + query_lens=[], + multi_modal_input=None) + + +class OpenVINOModelRunner: + + def __init__( + self, + model_config: ModelConfig, + parallel_config: ParallelConfig, + scheduler_config: SchedulerConfig, + device_config: DeviceConfig, + cache_config: CacheConfig, + load_config: LoadConfig, + lora_config: Optional[LoRAConfig], + vision_language_config: Optional[VisionLanguageConfig], + kv_cache_dtype: Optional[str] = "auto", + is_driver_worker: bool = False, + *args, + **kwargs, + ): + self.model_config = model_config + self.parallel_config = parallel_config + self.scheduler_config = scheduler_config + self.device_config = device_config + self.cache_config = cache_config + self.lora_config = lora_config + self.vision_language_config = vision_language_config + self.load_config = load_config + self.is_driver_worker = is_driver_worker + + self.device = self.device_config.device + + self.kv_cache_dtype = kv_cache_dtype + self.sliding_window = model_config.get_sliding_window() + self.block_size = cache_config.block_size + + self.attn_backend = get_attn_backend( + self.model_config.get_num_attention_heads(self.parallel_config), + self.model_config.get_head_size(), + self.model_config.get_num_kv_heads(self.parallel_config), + self.model_config.get_sliding_window(), + self.model_config.dtype, + self.kv_cache_dtype, + self.block_size, + ) + + # Lazy initialization. + self.model: nn.Module # Set after init_Model + + def load_model(self) -> None: + self.model = get_model( + model_config=self.model_config, + device_config=self.device_config, + kv_cache_dtype=self.kv_cache_dtype, + ) + + def _prepare_model_input( + self, + seq_group_metadata_list: List[SequenceGroupMetadata], + ) -> ModelInput: + """Prepare the model input based on a given sequence group. + + The API assumes seq_group_metadata_list is sorted by prefill -> decode. + + The result tensors and data structure also batches input in prefill + -> decode order. For example, + + - input_tokens[:num_prefill_tokens] contains prefill tokens. + - input_tokens[num_prefill_tokens:] contains decode tokens. + """ + input_tokens: List[int] = [] + input_positions: List[int] = [] + + seq_lens: List[int] = [] + past_lens: List[int] = [] + query_lens: List[int] = [] + subsequence_begins: List[int] = [] + block_indices: List[int] = [] + block_indices_begins: List[int] = [] + multi_modal_input_list: List[torch.Tensor] = [] + + # initialize beginning of prefix sums + subsequence_begins.append(0) + block_indices_begins.append(0) + + if len(seq_group_metadata_list) == 0: + return ModelInput.empty(self.device) + + for seq_group_metadata in seq_group_metadata_list: + seq_ids = list(seq_group_metadata.seq_data.keys()) + is_prompt = seq_group_metadata.is_prompt + + for seq_id in seq_ids: + computed_block_nums = seq_group_metadata.computed_block_nums + if (self.scheduler_config is not None + and self.scheduler_config.chunked_prefill_enabled + and not (computed_block_nums is None + or computed_block_nums == [])): + raise RuntimeError( + "chunked prefill cannot be used with prefix caching " + "now.") + + seq_data = seq_group_metadata.seq_data[seq_id] + if is_prompt: + computed_len = seq_data.get_num_computed_tokens() + else: + # get_num_computed_tokens is incorrect for spec decoding. + # So, we should have a special logic here. + # TODO(sang): Fix it. + computed_len = seq_data.get_len() - 1 + + seq_len = min( + seq_data.get_len(), + computed_len + seq_group_metadata.token_chunk_size, + ) + if is_prompt: + tokens = seq_data.get_token_ids()[computed_len:seq_len] + else: + # Optimization. get_token_ids requires the entire copy of + # tokens. + tokens = [seq_data.get_last_token_id()] + + # Prefix cache was hit. + # Prefix is not supported with sliding_window + prefix_cache_hit = (computed_block_nums is not None + and len(computed_block_nums) > 0 + and self.sliding_window is None + and is_prompt) + + block_table = seq_group_metadata.block_tables[seq_id] + # TODO(sang): Combine chunked prefill and prefix caching by + # only allowing multiple of block_size chunk size. + # NOTE: This only works for oooooooxxx style attention. + if prefix_cache_hit: + assert computed_block_nums is not None + computed_len = len(computed_block_nums) * self.block_size + tokens = tokens[computed_len:] + elif (self.scheduler_config.chunked_prefill_enabled + or not is_prompt): + if seq_group_metadata.block_tables is not None: + # chunked prefill or decode + block_table = seq_group_metadata.block_tables[seq_id] + if self.sliding_window is not None: + # chunked prefill doesn't support sliding window. + assert not self.scheduler_config.chunked_prefill_enabled # noqa: E501 + sliding_window_blocks = (self.sliding_window // + self.block_size) + block_table = block_table[-sliding_window_blocks:] + else: + # Only happens when memory profiling runs. + block_table = [] + else: + # prompt phase w/o prefix_caching, chunked_prefill + pass + + block_indices.extend(block_table) + block_indices_begins.append(block_indices_begins[-1] + + len(block_table)) + + # TODO(sang): This is a hack to make sliding window work with + # paged attn. We can remove it if we make paged attn kernel + # to properly handle slinding window attn. + if self.sliding_window is not None and not is_prompt: + seq_len = min(seq_len, self.sliding_window) + computed_len = seq_len - 1 + + seq_lens.append(seq_len) + + query_len = seq_len - computed_len + query_lens.append(query_len) + + input_tokens.extend(tokens) + input_positions.extend(list(range(computed_len, seq_len))) + + past_lens.append(computed_len) + subsequence_begins.append(subsequence_begins[-1] + query_len) + + if is_prompt: + assert len(seq_ids) == 1 + else: + assert ( + query_len == 1 + ), "seq_len: {}, computed_len: {}, query_len: {}".format( + seq_len, computed_len, query_len) + + max_query_len = max(query_lens) + assert max_query_len > 0, "query_lens: {}".format(query_lens) + + if multi_modal_input_list: + assert self.vision_language_config, ( + "Multi-modal inputs are only supported by " + "vision language models.") + multi_modal_input = torch.cat(multi_modal_input_list, + dim=0).to(self.device) + else: + multi_modal_input = None + + input_tokens = torch.tensor(input_tokens, + dtype=torch.long, + device=self.device) # type: ignore + input_positions = torch.tensor(input_positions, + dtype=torch.long, + device=self.device) # type: ignore + + past_lens_tensor = torch.tensor(past_lens, + dtype=torch.int32, + device=self.device) # type: ignore + subsequence_begins_tensor = torch.tensor( + subsequence_begins, dtype=torch.int32, + device=self.device) # type: ignore + block_indices_tensor = torch.tensor(block_indices, + dtype=torch.int32, + device=self.device) # type: ignore + block_indices_begins_tensor = torch.tensor( + block_indices_begins, dtype=torch.int32, + device=self.device) # type: ignore + + max_context_len = max(seq_lens) + max_context_len_tensor = torch.tensor( + max_context_len, dtype=torch.int32, + device=self.device) # type: ignore + + attn_metadata = self.attn_backend.make_openvino_metadata( + past_lens=past_lens_tensor, + subsequence_begins=subsequence_begins_tensor, + block_indices=block_indices_tensor, + block_indices_begins=block_indices_begins_tensor, + max_context_len=max_context_len_tensor, + ) + return ModelInput( + input_tokens, + input_positions, + attn_metadata, + seq_lens, + query_lens, + multi_modal_input, + ) + + def prepare_input_tensors( + self, + seq_group_metadata_list: List[SequenceGroupMetadata], + ) -> Tuple[torch.Tensor, torch.Tensor, OpenVINOAttentionMetadata, + SamplingMetadata, Optional[torch.Tensor], ]: + multi_modal_input = None + + # Prepare input tensors. + ( + input_tokens, + input_positions, + attn_metadata, + seq_lens, + query_lens, + multi_modal_input, + ) = self._prepare_model_input(seq_group_metadata_list) + + sampling_metadata = SamplingMetadata.prepare( + seq_group_metadata_list, + seq_lens, + query_lens, + self.device, + pin_memory=False, + ) + + return ( + input_tokens, + input_positions, + attn_metadata, + sampling_metadata, + multi_modal_input, + ) + + @torch.inference_mode() + def execute_model( + self, + seq_group_metadata_list: List[SequenceGroupMetadata], + kv_caches: List[Tuple["ov.Tensor", "ov.Tensor"]], + ) -> Optional[SamplerOutput]: + ( + input_tokens, + input_positions, + attn_metadata, + sampling_metadata, + multi_modal_input, + ) = self.prepare_input_tensors(seq_group_metadata_list) + + model_executable = self.model + execute_model_kwargs = { + "input_ids": input_tokens, + "positions": input_positions, + "kv_caches": kv_caches, + "attn_metadata": attn_metadata, + } + if self.vision_language_config: + execute_model_kwargs.update({"image_input": multi_modal_input}) + + hidden_states = model_executable(**execute_model_kwargs) + + # Compute the logits. + logits = self.model.compute_logits(hidden_states, sampling_metadata) + + # Sample the next token. + output = self.model.sample( + logits=logits, + sampling_metadata=sampling_metadata, + ) + return output diff --git a/vllm/worker/openvino_worker.py b/vllm/worker/openvino_worker.py new file mode 100644 index 0000000000000..ceb7db4a49b15 --- /dev/null +++ b/vllm/worker/openvino_worker.py @@ -0,0 +1,349 @@ +"""An OpenVINO worker class.""" +from typing import Any, Dict, List, Optional, Tuple + +import openvino as ov +import torch +import torch.distributed + +from vllm.attention import get_attn_backend +from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig, + ModelConfig, ParallelConfig, SchedulerConfig, + VisionLanguageConfig) +from vllm.distributed import (broadcast_tensor_dict, + ensure_model_parallel_initialized, + init_distributed_environment) +from vllm.logger import init_logger +from vllm.model_executor import set_random_seed +from vllm.sequence import ExecuteModelRequest, SamplerOutput +from vllm.worker.openvino_model_runner import OpenVINOModelRunner +from vllm.worker.worker_base import LoraNotSupportedWorkerBase + +logger = init_logger(__name__) + + +class OpenVINOCacheEngine: + """Manages the KV cache for OpenVINO backend. + + This class is responsible for initializing and managing CPU KV + caches. It also provides methods for performing KV cache operations, such + as copying. + """ + + def __init__( + self, + cache_config: CacheConfig, + model_config: ModelConfig, + parallel_config: ParallelConfig, + device_config: DeviceConfig, + ) -> None: + assert device_config.device_type == "openvino" + self.cache_config = cache_config + self.model_config = model_config + self.parallel_config = parallel_config + + self.head_size = model_config.get_head_size() + if device_config.device.type == "cpu" and \ + cache_config.cache_dtype == ov.Type.u8: + # Scale, zero point and quantized data will be stored together. + # The layout for per token per head: + # |scale(f32)|zeropoint(f32)|quantized data(u8,idx_1)|quantized data(u8,idx_2)|...|quantized data(u8,idx_head_size)| # noqa: E501 + self.head_size += 8 + self.num_layers = model_config.get_num_layers(parallel_config) + self.num_kv_heads = model_config.get_num_kv_heads(parallel_config) + + self.block_size = cache_config.block_size + # Note: In CacheConfig, num_gpu_blocks actual is num_cpu_blocks + # for OpenVINO backend, because we want to reuse KV cache management + # in the scheduler. + self.num_cpu_blocks = cache_config.num_gpu_blocks + + # Get attention backend. + self.attn_backend = get_attn_backend( + self.model_config.get_num_attention_heads(self.parallel_config), + self.model_config.get_head_size(), + self.model_config.get_num_kv_heads(self.parallel_config), + self.model_config.get_sliding_window(), + self.model_config.dtype, + self.cache_config.cache_dtype, + self.block_size, + ) + + # Initialize the cache. + self.kv_cache: List[Tuple[ov.Tensor, + ov.Tensor]] = self._allocate_kv_cache( + self.num_cpu_blocks) + + def _allocate_kv_cache( + self, + num_blocks: int, + ) -> List[Tuple[ov.Tensor, ov.Tensor]]: + """Allocates KV cache.""" + k_block_shape = v_block_shape = self.attn_backend.get_kv_cache_shape( + num_blocks, self.block_size, self.num_kv_heads, self.head_size)[1:] + kv_cache: List[Tuple[ov.Tensor, ov.Tensor]] = [] + for _ in range(self.num_layers): + key_blocks = ov.Tensor(self.cache_config.cache_dtype, + k_block_shape) + value_blocks = ov.Tensor(self.cache_config.cache_dtype, + v_block_shape) + kv_cache.append((key_blocks, value_blocks)) + return kv_cache + + def swap_in(self, src_to_dst: Dict[int, int]) -> None: + raise NotImplementedError( + "Swap is not supported in OpenVINOCacheEngine.") + + def swap_out(self, src_to_dst: Dict[int, int]) -> None: + raise NotImplementedError( + "Swap is not supported in OpenVINOCacheEngine.") + + def copy(self, src_to_dsts: Dict[int, List[int]]) -> None: + self.attn_backend.copy_blocks(self.kv_cache, src_to_dsts) + + @staticmethod + def get_cache_block_size( + block_size: int, + cache_dtype: ov.Type, + model_config: ModelConfig, + parallel_config: ParallelConfig, + ) -> int: + head_size = model_config.get_head_size() + num_kv_heads = model_config.get_num_kv_heads(parallel_config) + num_layers = model_config.get_num_layers(parallel_config) + + if cache_dtype == ov.Type.u8: + # Scale, zero point and quantized data will be stored together. + # The layout for per token per head: + # |scale(f32)|zeropoint(f32)|quantized data(u8,idx_1)|quantized data(u8,idx_2)|...|quantized data(u8,idx_head_size)| # noqa: E501 + head_size += 8 + + key_cache_block = block_size * num_kv_heads * head_size + value_cache_block = key_cache_block + total = num_layers * (key_cache_block + value_cache_block) + dtype_size = cache_dtype.size + return dtype_size * total + + +class OpenVINOWorker(LoraNotSupportedWorkerBase): + """A worker class that executes the model on OpenVINO backend. + + Each worker is associated with a single OpenVINO device. The worker is + responsible for maintaining the KV cache and executing the model on the + OpenVINO backend. + """ + + def __init__( + self, + model_config: ModelConfig, + parallel_config: ParallelConfig, + scheduler_config: SchedulerConfig, + device_config: DeviceConfig, + cache_config: CacheConfig, + load_config: LoadConfig, + local_rank: int, + rank: int, + distributed_init_method: str, + lora_config: Optional[LoRAConfig] = None, + vision_language_config: Optional[VisionLanguageConfig] = None, + kv_cache_dtype: Optional[ov.Type] = ov.Type.undefined, + is_driver_worker: bool = False, + ) -> None: + self.model_config = model_config + self.parallel_config = parallel_config + self.scheduler_config = scheduler_config + self.device_config = device_config + self.cache_config = cache_config + self.load_config = load_config + self.local_rank = local_rank + self.rank = rank + self.distributed_init_method = distributed_init_method + self.lora_config = lora_config + self.vision_language_config = vision_language_config + self.is_driver_worker = is_driver_worker + if self.is_driver_worker: + assert self.rank == 0, "The driver worker must have rank 0." + + if self.model_config.trust_remote_code: + # note: lazy import to avoid importing torch before initializing + from vllm.utils import init_cached_hf_modules + + init_cached_hf_modules() + self.model_runner = OpenVINOModelRunner( + model_config, + parallel_config, + scheduler_config, + device_config, + cache_config, + load_config=self.load_config, + lora_config=self.lora_config, + vision_language_config=self.vision_language_config, + kv_cache_dtype=kv_cache_dtype, + is_driver_worker=is_driver_worker, + ) + # Uninitialized cache engine. Will be initialized by + # initialize_cache. + self.cache_engine: OpenVINOCacheEngine + self.kv_cache: List[Tuple[ov.Tensor, ov.Tensor]] + + def init_device(self) -> None: + self.init_distributed_environment() + # Set random seed. + set_random_seed(self.model_config.seed) + + def load_model(self): + self.model_runner.load_model() + + def determine_num_available_blocks(self) -> Tuple[int, int]: + """Determine the number of blocks available for the KV cache. + + This determines how many KV blocks can fit into the configured + KV cache space. + + Note that since vLLM assumes a block resides on GPU if it can be + modified, we return num_gpu_blocks=num_cpu_blocks and num_cpu_blocks=0. + This allows us to reuse the scheduler of vLLM without generalizing it + to different devices. + """ + # For OpenVINO backend, the block number will be calculated based on the + # openvino_kvcache_space_bytes. + cache_block_size = self.get_cache_block_size_bytes() + num_cpu_blocks = int(self.cache_config.openvino_kvcache_space_bytes // + cache_block_size) + num_cpu_blocks = max(num_cpu_blocks, 0) + + # Note: To reuse the cache management procedure, + # use cpu cache as 'gpu cache'. + num_gpu_blocks = num_cpu_blocks + num_cpu_blocks = 0 + return num_gpu_blocks, num_cpu_blocks + + def initialize_cache(self, num_gpu_blocks: int, + num_cpu_blocks: int) -> None: + """Initialize the KV cache. Currently, swappable CPU memory is not + supported. + + Since this worker does not support GPUs, we use the num_gpu_blocks to + determine how many non-swappable CPU blocks to allocate. + """ + assert (num_cpu_blocks == 0 + ), f"{type(self)} does not support swappable cache" + + # Note: To reuse the cache management procedure, + # use cpu cache as 'gpu cache'. + num_cpu_blocks = num_gpu_blocks + + self._validate_num_cpu_blocks(num_cpu_blocks) + self.cache_config.num_gpu_blocks = num_cpu_blocks + self.cache_config.num_cpu_blocks = 0 + + # Initialize the cache. + self._init_cache_engine() + + def _validate_num_cpu_blocks(self, num_cpu_blocks: int) -> None: + """Raise errors if the num_cpu_blocks is invalid.""" + if num_cpu_blocks <= 0: + raise ValueError( + "No available memory for the cache blocks. " + "Try increasing `VLLM_OPENVINO_KVCACHE_SPACE` when " + "initializing the engine.") + + max_seq_len = self.cache_config.block_size * num_cpu_blocks + if self.model_config.max_model_len > max_seq_len: + raise ValueError( + f"The model's max seq len ({self.model_config.max_model_len}) " + "is larger than the maximum number of tokens that can be " + f"stored in KV cache ({max_seq_len}). Try increasing " + "`VLLM_OPENVINO_KVCACHE_SPACE` or decreasing `max_model_len` " + "when initializing the engine.") + + def _init_cache_engine(self) -> None: + self.cache_engine = OpenVINOCacheEngine( + self.cache_config, + self.model_config, + self.parallel_config, + self.device_config, + ) + self.kv_cache = self.cache_engine.kv_cache + self.model_runner.block_size = self.cache_engine.block_size + + assert self.kv_cache is not None + + # Populate the cache to warmup the memory + for key_cache, value_cache in self.kv_cache: + key_cache.data[:] = 0 + value_cache.data[:] = 0 + + def cache_copy( + self, + blocks_to_copy: List[Tuple[int, int]], + ) -> None: + self.cache_engine.copy(blocks_to_copy) # type: ignore + + @torch.inference_mode() + def execute_model( + self, + execute_model_req: Optional[ExecuteModelRequest] = None, + ) -> List[SamplerOutput]: + if execute_model_req is None: + seq_group_metadata_list = None + else: + seq_group_metadata_list = execute_model_req.seq_group_metadata_list + + if self.is_driver_worker: + assert seq_group_metadata_list is not None + num_seq_groups: int = len(seq_group_metadata_list) + assert execute_model_req is not None + blocks_to_copy = execute_model_req.blocks_to_copy + assert len(execute_model_req.blocks_to_swap_in) == 0 + assert len(execute_model_req.blocks_to_swap_out) == 0 + data: Dict[str, Any] = { + "num_seq_groups": num_seq_groups, + "blocks_to_copy": execute_model_req.blocks_to_copy, + } + broadcast_tensor_dict(data, src=0) + else: + data = broadcast_tensor_dict(src=0) + num_seq_groups = data["num_seq_groups"] + blocks_to_copy = data["blocks_to_copy"] + + self.cache_copy(blocks_to_copy) + + # If there is no input, we don't need to execute the model. + if num_seq_groups == 0: + return [] + + output = self.model_runner.execute_model(seq_group_metadata_list, + self.kv_cache) + + # OpenVINO worker only supports single-step execution. + return [output] + + def init_distributed_environment(self) -> None: + """Initialize the distributed environment.""" + + parallel_config = self.parallel_config + rank = self.rank + distributed_init_method = self.distributed_init_method + init_distributed_environment( + world_size=parallel_config.world_size, + rank=rank, + distributed_init_method=distributed_init_method, + backend="gloo", + ) + + # A small all_reduce for warmup. + torch.distributed.all_reduce(torch.zeros(1).cpu()) + + ensure_model_parallel_initialized( + parallel_config.tensor_parallel_size, + parallel_config.pipeline_parallel_size, + ) + + def get_cache_block_size_bytes(self) -> int: + """Return the size in bytes of a single KV cache block.""" + return OpenVINOCacheEngine.get_cache_block_size( + self.cache_config.block_size, + self.cache_config.cache_dtype, + self.model_config, + self.parallel_config, + ) From a29ed93151de613f94db0b1bd1a34d6aa33b84ef Mon Sep 17 00:00:00 2001 From: Ilya Lavrenov Date: Tue, 25 Jun 2024 21:50:50 +0200 Subject: [PATCH 2/9] Review comments --- .../getting_started/openvino-installation.rst | 2 +- tests/kernels/test_attention_selector.py | 9 ++++-- vllm/attention/backends/openvino.py | 32 +++++++++++++++++++ vllm/envs.py | 18 +++++++++++ vllm/executor/openvino_executor.py | 16 ++++++---- vllm/model_executor/model_loader/openvino.py | 13 +++----- vllm/utils.py | 2 -- vllm/worker/openvino_model_runner.py | 12 +------ vllm/worker/openvino_worker.py | 4 ++- 9 files changed, 75 insertions(+), 33 deletions(-) diff --git a/docs/source/getting_started/openvino-installation.rst b/docs/source/getting_started/openvino-installation.rst index 71b7807d241cb..8a745b83e3b22 100644 --- a/docs/source/getting_started/openvino-installation.rst +++ b/docs/source/getting_started/openvino-installation.rst @@ -57,7 +57,7 @@ Install from source .. code-block:: console - $ PIP_PRE=1 PIP_EXTRA_INDEX_URL="https://download.pytorch.org/whl/cpu https://storage.openvinotoolkit.org/simple/wheels/nightly/" VLLM_TARGET_DEVICE=openvino python install -v . + $ PIP_PRE=1 PIP_EXTRA_INDEX_URL="https://download.pytorch.org/whl/cpu https://storage.openvinotoolkit.org/simple/wheels/nightly/" VLLM_TARGET_DEVICE=openvino python -m pip install -v . .. _openvino_backend_performance_tips: diff --git a/tests/kernels/test_attention_selector.py b/tests/kernels/test_attention_selector.py index 79e03c7478de0..8e6c50666e70c 100644 --- a/tests/kernels/test_attention_selector.py +++ b/tests/kernels/test_attention_selector.py @@ -9,8 +9,8 @@ @pytest.mark.parametrize( - "name", ["TORCH_SDPA", "ROCM_FLASH", "XFORMERS", "FLASHINFER"]) -@pytest.mark.parametrize("device", ["cpu", "hip", "cuda"]) + "name", ["TORCH_SDPA", "ROCM_FLASH", "XFORMERS", "FLASHINFER", "OPENVINO"]) +@pytest.mark.parametrize("device", ["cpu", "openvino", "hip", "cuda"]) def test_env(name: str, device: str, monkeypatch): """Test that the attention selector can be set via environment variable. Note that we do not test FlashAttn because it is the default backend. @@ -28,6 +28,11 @@ def test_env(name: str, device: str, monkeypatch): backend = which_attn_to_use(8, 16, 8, None, torch.float16, torch.float16, 16) assert backend.name == "ROCM_FLASH" + elif device == "openvino": + with patch("vllm.attention.selector.is_openvino", return_value=True): + backend = which_attn_to_use(8, 16, 8, None, torch.float16, + torch.float16, 16) + assert backend.name == "OPENVINO" else: backend = which_attn_to_use(8, 16, 8, None, torch.float16, torch.float16, 16) diff --git a/vllm/attention/backends/openvino.py b/vllm/attention/backends/openvino.py index d75cd0ad0daaa..0f21b50ad4dc7 100644 --- a/vllm/attention/backends/openvino.py +++ b/vllm/attention/backends/openvino.py @@ -61,9 +61,41 @@ def copy_blocks( @dataclass class OpenVINOAttentionMetadata: """Metadata for OpenVINOAttentionBackend. + + Basic terms used below: + - batch_size_in_sequences - total number of sequences to execute​ + - prompt_lens – per sequence size number of scheduled tokens​ + - batch_size_in_tokens = sum(prompt_lens)​ + - max_context_len = max(context_lens)​ + - max_num_blocks = div_up(max_context_len / BLOCK_SIZE)​ + - num_blocks – total number of blocks in block_indices​ """ + + # Describes past KV cache size for each sequence within a batch + # Shape: [batch_size_in_sequences] + # Type: i32​ past_lens: torch.Tensor + + # Describes start indices of input / speculative tokens from + # current sequences within a batch sequence​ + # Shape: [batch_size_in_sequences + 1]​ + # Type: i32 subsequence_begins: torch.Tensor + + # Describes block tables for each sequence within a batch​ - + # indices along 0th dimension in key_cache and value_cache inputs​ + # Shape: [num_blocks] + # Type: i32​ block_indices: torch.Tensor + + # Describes block tables for each sequence within a batch​ - + # for i-th element, it is an index in block_indices with the + # first block belonging to i-th sequence​ + # Shape: [batch_size_in_sequences + 1] + # Type: i32​ block_indices_begins: torch.Tensor + + # Describes max context length + # Shape: scalar + # Type: i32 max_context_len: torch.Tensor diff --git a/vllm/envs.py b/vllm/envs.py index b88c4b5934aad..29d37ae478e46 100644 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -28,6 +28,9 @@ VLLM_TRACE_FUNCTION: int = 0 VLLM_ATTENTION_BACKEND: Optional[str] = None VLLM_CPU_KVCACHE_SPACE: int = 0 + VLLM_OPENVINO_KVCACHE_SPACE: int = 0 + VLLM_OPENVINO_CPU_KV_CACHE_PRECISION: Optional[str] = None + VLLM_OPENVINO_ENABLE_QUANTIZED_WEIGHTS: bool = False VLLM_XLA_CACHE_PATH: str = "~/.vllm/xla_cache/" VLLM_USE_RAY_COMPILED_DAG: bool = False VLLM_WORKER_MULTIPROC_METHOD: str = "fork" @@ -209,6 +212,21 @@ "VLLM_CPU_KVCACHE_SPACE": lambda: int(os.getenv("VLLM_CPU_KVCACHE_SPACE", "0")), + # OpenVINO key-value cache space + # default is 4GB + "VLLM_OPENVINO_KVCACHE_SPACE": + lambda: int(os.getenv("VLLM_OPENVINO_KVCACHE_SPACE", "0")), + + # OpenVINO KV cache precision + # default is bf16 if natively supported by platform, otherwise f16 + "VLLM_OPENVINO_CPU_KV_CACHE_PRECISION": + lambda: os.getenv("VLLM_OPENVINO_CPU_KV_CACHE_PRECISION", None), + + # OpenVINO key-value cache space + # default is False + "VLLM_OPENVINO_ENABLE_QUANTIZED_WEIGHTS": + lambda: bool(os.getenv("VLLM_OPENVINO_ENABLE_QUANTIZED_WEIGHTS", False)), + # If the env var is set, it uses the Ray's compiled DAG API # which optimizes the control plane overhead. # Run vLLM with VLLM_USE_RAY_COMPILED_DAG=1 to enable it. diff --git a/vllm/executor/openvino_executor.py b/vllm/executor/openvino_executor.py index 028933be247b6..1b011208eadb9 100644 --- a/vllm/executor/openvino_executor.py +++ b/vllm/executor/openvino_executor.py @@ -5,6 +5,7 @@ import openvino.properties.hint as hints import torch +import vllm.envs as envs from vllm.config import CacheConfig, ModelConfig from vllm.executor.executor_base import ExecutorAsyncBase, ExecutorBase from vllm.logger import init_logger @@ -84,6 +85,9 @@ def add_lora(self, lora_request: LoRARequest) -> bool: def remove_lora(self, lora_id: int) -> bool: return self.driver_worker.remove_lora(lora_id) + def pin_lora(self, lora_id: int) -> bool: + return self.driver_worker.pin_lora(lora_id) + def list_loras(self) -> Set[int]: return self.driver_worker.list_loras() @@ -123,9 +127,9 @@ def _verify_and_get_model_config(config: ModelConfig) -> ModelConfig: def _verify_and_get_cache_config(config: CacheConfig) -> CacheConfig: - if os.environ.get("VLLM_OPENVINO_CPU_KV_CACHE_PRECISION", "") == "u8": - logger.warning("KV cache type is overried to u8 via " - "VLLM_OPENVINO_CPU_KV_CACHE_PRECISION env var.") + if envs.VLLM_OPENVINO_CPU_KV_CACHE_PRECISION == "u8": + logger.info("KV cache type is overried to u8 via " + "VLLM_OPENVINO_CPU_KV_CACHE_PRECISION env var.") config.cache_dtype = ov.Type.u8 else: core = ov.Core() @@ -136,15 +140,13 @@ def _verify_and_get_cache_config(config: CacheConfig) -> CacheConfig: else: config.cache_dtype = ov.Type.f16 - kv_cache_space_str = os.getenv("VLLM_OPENVINO_KVCACHE_SPACE", "0") - kv_cache_space = int(kv_cache_space_str) - if config.block_size != 32: - logger.warning( + logger.info( f"OpenVINO optimal block size is 32, overriding currently set {config.block_size}" # noqa: G004, E501 ) config.block_size = 32 + kv_cache_space = envs.VLLM_OPENVINO_KVCACHE_SPACE if kv_cache_space >= 0: _GB = 1 << 30 if kv_cache_space == 0: diff --git a/vllm/model_executor/model_loader/openvino.py b/vllm/model_executor/model_loader/openvino.py index cf3893ddc19f3..b65f853d426c5 100644 --- a/vllm/model_executor/model_loader/openvino.py +++ b/vllm/model_executor/model_loader/openvino.py @@ -10,6 +10,7 @@ from optimum.intel import OVModelForCausalLM from torch import nn +import vllm.envs as envs from vllm.attention.backends.openvino import OpenVINOAttentionMetadata from vllm.config import DeviceConfig, ModelConfig from vllm.logger import init_logger @@ -126,20 +127,19 @@ def __init__( export = _require_model_export(model_config.model) if export: logger.warning( - f"[ INFO ] Provided model id {model_config.model} does not " # noqa: G004 + f"Provided model id {model_config.model} does not " # noqa: G004 "contain OpenVINO IR, the model will be converted to IR with " "default options. If you need to use specific options for " "model conversion, use optimum-cli export openvino with " "desired options.") else: logger.warning( - "[ INFO ] OpenVINO IR is available for provided model id " # noqa: G004 + "OpenVINO IR is available for provided model id " # noqa: G004 f"{model_config.model}. This IR will be used for inference " "as-is, all possible options that may affect model conversion " "are ignored.") - load_in_8bit = os.environ.get("VLLM_OPENVINO_ENABLE_QUANTIZED_WEIGHTS", - "0") == "1" # noqa: E501 + load_in_8bit = envs.VLLM_OPENVINO_ENABLE_QUANTIZED_WEIGHTS pt_model = OVModelForCausalLM.from_pretrained( model_config.model, export=export, @@ -152,11 +152,6 @@ def __init__( _modify_cache_parameters(pt_model.model, kv_cache_dtype, device_config.device.type == "cpu") - # For deployment outside vLLM - model_file_name = os.environ.get("VLLM_OPENVINO_EXPORTED_IR_NAME", "") - if model_file_name: - ov.save_model(pt_model.model, model_file_name) - core = ov.Core() ov_compiled = core.compile_model(pt_model.model, "CPU") self.ov_request = ov_compiled.create_infer_request() diff --git a/vllm/utils.py b/vllm/utils.py index a858d62d5705c..03c16d4d6bba9 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -33,12 +33,10 @@ STR_DTYPE_TO_TORCH_DTYPE = { "half": torch.half, "bfloat16": torch.bfloat16, - "float16": torch.float16, "float": torch.float, "fp8": torch.uint8, "fp8_e4m3": torch.uint8, "fp8_e5m2": torch.uint8, - "u8": torch.uint8 } P = ParamSpec('P') diff --git a/vllm/worker/openvino_model_runner.py b/vllm/worker/openvino_model_runner.py index d4399cf1d9ae6..336eaf814fb3f 100644 --- a/vllm/worker/openvino_model_runner.py +++ b/vllm/worker/openvino_model_runner.py @@ -111,7 +111,6 @@ def _prepare_model_input( subsequence_begins: List[int] = [] block_indices: List[int] = [] block_indices_begins: List[int] = [] - multi_modal_input_list: List[torch.Tensor] = [] # initialize beginning of prefix sums subsequence_begins.append(0) @@ -220,15 +219,6 @@ def _prepare_model_input( max_query_len = max(query_lens) assert max_query_len > 0, "query_lens: {}".format(query_lens) - if multi_modal_input_list: - assert self.vision_language_config, ( - "Multi-modal inputs are only supported by " - "vision language models.") - multi_modal_input = torch.cat(multi_modal_input_list, - dim=0).to(self.device) - else: - multi_modal_input = None - input_tokens = torch.tensor(input_tokens, dtype=torch.long, device=self.device) # type: ignore @@ -267,7 +257,7 @@ def _prepare_model_input( attn_metadata, seq_lens, query_lens, - multi_modal_input, + None, ) def prepare_input_tensors( diff --git a/vllm/worker/openvino_worker.py b/vllm/worker/openvino_worker.py index ceb7db4a49b15..a5e7adb5fac72 100644 --- a/vllm/worker/openvino_worker.py +++ b/vllm/worker/openvino_worker.py @@ -47,6 +47,7 @@ def __init__( # Scale, zero point and quantized data will be stored together. # The layout for per token per head: # |scale(f32)|zeropoint(f32)|quantized data(u8,idx_1)|quantized data(u8,idx_2)|...|quantized data(u8,idx_head_size)| # noqa: E501 + # so, we have to extend head_size by 8, which is sizeof(float) for scale and sizeof(float) for zeropoint self.head_size += 8 self.num_layers = model_config.get_num_layers(parallel_config) self.num_kv_heads = model_config.get_num_kv_heads(parallel_config) @@ -60,7 +61,7 @@ def __init__( # Get attention backend. self.attn_backend = get_attn_backend( self.model_config.get_num_attention_heads(self.parallel_config), - self.model_config.get_head_size(), + self.head_size, self.model_config.get_num_kv_heads(self.parallel_config), self.model_config.get_sliding_window(), self.model_config.dtype, @@ -115,6 +116,7 @@ def get_cache_block_size( # Scale, zero point and quantized data will be stored together. # The layout for per token per head: # |scale(f32)|zeropoint(f32)|quantized data(u8,idx_1)|quantized data(u8,idx_2)|...|quantized data(u8,idx_head_size)| # noqa: E501 + # so, we have to extend head_size by 8, which is sizeof(float) for scale and sizeof(float) for zeropoint head_size += 8 key_cache_block = block_size * num_kv_heads * head_size From 9e6ed8d7381ac48911869b8a47bf083145c962c9 Mon Sep 17 00:00:00 2001 From: Ilya Lavrenov Date: Tue, 25 Jun 2024 23:02:06 +0200 Subject: [PATCH 3/9] Dropped VLLM_OPENVINO_OPTIMUM_FORCE_CONVERSION env var --- vllm/model_executor/model_loader/openvino.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/vllm/model_executor/model_loader/openvino.py b/vllm/model_executor/model_loader/openvino.py index b65f853d426c5..dcae1cf997092 100644 --- a/vllm/model_executor/model_loader/openvino.py +++ b/vllm/model_executor/model_loader/openvino.py @@ -81,11 +81,6 @@ def _modify_cache_parameters(model: ov.Model, kv_cache_dtype: ov.Type, def _require_model_export(model_id, revision=None, subfolder=None): - # Stored IR may not be suitable for vLLM purposes (too old, - # not stateful, not compressed etc.). This is an option to override - # IR usage logic and always do model conversion. - if os.environ.get("VLLM_OPENVINO_OPTIMUM_FORCE_CONVERSION", "0") == "1": - return True model_dir = Path(model_id) if subfolder is not None: model_dir = model_dir / subfolder From d902872516ce92d1941e7ea1bdf42de6b7a64704 Mon Sep 17 00:00:00 2001 From: Ilya Lavrenov Date: Tue, 25 Jun 2024 23:12:56 +0200 Subject: [PATCH 4/9] Fixed code style --- vllm/engine/async_llm_engine.py | 3 ++- vllm/executor/openvino_executor.py | 1 - vllm/model_executor/model_loader/openvino.py | 1 - vllm/worker/openvino_worker.py | 6 ++++-- 4 files changed, 6 insertions(+), 5 deletions(-) diff --git a/vllm/engine/async_llm_engine.py b/vllm/engine/async_llm_engine.py index 5d9b44aebb30d..1575c2739f08d 100644 --- a/vllm/engine/async_llm_engine.py +++ b/vllm/engine/async_llm_engine.py @@ -393,7 +393,8 @@ def from_engine_args( executor_class = CPUExecutorAsync elif engine_config.device_config.device_type == "openvino": assert distributed_executor_backend is None, ( - "Distributed execution is not supported with the OpenVINO backend.") + "Distributed execution is not supported with " + "the OpenVINO backend.") from vllm.executor.openvino_executor import OpenVINOExecutorAsync executor_class = OpenVINOExecutorAsync elif engine_config.device_config.device_type == "xpu": diff --git a/vllm/executor/openvino_executor.py b/vllm/executor/openvino_executor.py index 1b011208eadb9..8af375371f2f0 100644 --- a/vllm/executor/openvino_executor.py +++ b/vllm/executor/openvino_executor.py @@ -1,4 +1,3 @@ -import os from typing import List, Set, Tuple import openvino as ov diff --git a/vllm/model_executor/model_loader/openvino.py b/vllm/model_executor/model_loader/openvino.py index dcae1cf997092..5c522a61732a4 100644 --- a/vllm/model_executor/model_loader/openvino.py +++ b/vllm/model_executor/model_loader/openvino.py @@ -1,5 +1,4 @@ # ruff: noqa: SIM117 -import os from pathlib import Path from typing import List, Optional, Tuple diff --git a/vllm/worker/openvino_worker.py b/vllm/worker/openvino_worker.py index a5e7adb5fac72..7a462ce5d0b66 100644 --- a/vllm/worker/openvino_worker.py +++ b/vllm/worker/openvino_worker.py @@ -47,7 +47,8 @@ def __init__( # Scale, zero point and quantized data will be stored together. # The layout for per token per head: # |scale(f32)|zeropoint(f32)|quantized data(u8,idx_1)|quantized data(u8,idx_2)|...|quantized data(u8,idx_head_size)| # noqa: E501 - # so, we have to extend head_size by 8, which is sizeof(float) for scale and sizeof(float) for zeropoint + # so, we have to extend head_size by 8, which is sizeof(float) + # for scale and sizeof(float) for zeropoint self.head_size += 8 self.num_layers = model_config.get_num_layers(parallel_config) self.num_kv_heads = model_config.get_num_kv_heads(parallel_config) @@ -116,7 +117,8 @@ def get_cache_block_size( # Scale, zero point and quantized data will be stored together. # The layout for per token per head: # |scale(f32)|zeropoint(f32)|quantized data(u8,idx_1)|quantized data(u8,idx_2)|...|quantized data(u8,idx_head_size)| # noqa: E501 - # so, we have to extend head_size by 8, which is sizeof(float) for scale and sizeof(float) for zeropoint + # so, we have to extend head_size by 8, which is sizeof(float) + # for scale and sizeof(float) for zeropoint head_size += 8 key_cache_block = block_size * num_kv_heads * head_size From 4bee0666e8822f1e728310bcc0775c61071dc0c9 Mon Sep 17 00:00:00 2001 From: Ilya Lavrenov Date: Tue, 25 Jun 2024 23:14:28 +0200 Subject: [PATCH 5/9] Fixed isort code style --- vllm/config.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/config.py b/vllm/config.py index b9a7e0ff7dbcb..7b906dc151b12 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -13,7 +13,7 @@ from vllm.tracing import is_otel_installed from vllm.transformers_utils.config import get_config, get_hf_text_config from vllm.utils import (cuda_device_count_stateless, get_cpu_memory, is_cpu, - is_hip, is_openvino, is_neuron, is_tpu, is_xpu) + is_hip, is_neuron, is_openvino, is_tpu, is_xpu) if TYPE_CHECKING: from ray.util.placement_group import PlacementGroup From a9c85eb5e083e580565a78217c1ad9dfc7f3c8c4 Mon Sep 17 00:00:00 2001 From: Ilya Lavrenov Date: Tue, 25 Jun 2024 23:24:53 +0200 Subject: [PATCH 6/9] Fixed yapf code style --- vllm/engine/arg_utils.py | 2 ++ vllm/envs.py | 2 +- 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 9e129d28bf7df..6257c98fc558c 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -504,12 +504,14 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: 'Enabling this will use the fully sharded layers. ' 'At high sequence length, max rank or ' 'tensor parallel size, this is likely faster.')) + # yapf: disable parser.add_argument( "--device", type=str, default=EngineArgs.device, choices=["auto", "cuda", "neuron", "cpu", "openvino", "tpu", "xpu"], help='Device type for vLLM execution.') + # yapf: enable # Related to Vision-language models such as llava parser = EngineArgs.add_cli_args_for_vlm(parser) diff --git a/vllm/envs.py b/vllm/envs.py index 29d37ae478e46..e5e79d2c0f482 100644 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -216,7 +216,7 @@ # default is 4GB "VLLM_OPENVINO_KVCACHE_SPACE": lambda: int(os.getenv("VLLM_OPENVINO_KVCACHE_SPACE", "0")), - + # OpenVINO KV cache precision # default is bf16 if natively supported by platform, otherwise f16 "VLLM_OPENVINO_CPU_KV_CACHE_PRECISION": From 3de627ceb932fbe9b985a1a2dfa3b9e7f653e1c7 Mon Sep 17 00:00:00 2001 From: Ilya Lavrenov Date: Thu, 27 Jun 2024 16:35:10 +0200 Subject: [PATCH 7/9] Fixed next portion of comments --- docs/source/getting_started/openvino-installation.rst | 2 +- vllm/envs.py | 3 ++- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/docs/source/getting_started/openvino-installation.rst b/docs/source/getting_started/openvino-installation.rst index 8a745b83e3b22..2942d6deb7ae2 100644 --- a/docs/source/getting_started/openvino-installation.rst +++ b/docs/source/getting_started/openvino-installation.rst @@ -3,7 +3,7 @@ Installation with OpenVINO ======================== -vLLM powered by OpenVINO supports all LLM models from [vLLM supported models list](../dev/models/supported_models.rst) and can perform optimal model serving on all x86-64 CPUs with, at least, AVX2 support. OpenVINO vLLM backend supports the following advanced vLLM features: +vLLM powered by OpenVINO supports all LLM models from :ref:`vLLM supported models list <_supported_models>` and can perform optimal model serving on all x86-64 CPUs with, at least, AVX2 support. OpenVINO vLLM backend supports the following advanced vLLM features: - Prefix caching (``--enable-prefix-caching``) - Chunked prefill (``--enable-chunked-prefill``) diff --git a/vllm/envs.py b/vllm/envs.py index e5e79d2c0f482..e8257535f1bf5 100644 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -219,10 +219,11 @@ # OpenVINO KV cache precision # default is bf16 if natively supported by platform, otherwise f16 + # To enable KV cache compression, please, explicitly specify u8 "VLLM_OPENVINO_CPU_KV_CACHE_PRECISION": lambda: os.getenv("VLLM_OPENVINO_CPU_KV_CACHE_PRECISION", None), - # OpenVINO key-value cache space + # Enables weights compression during model export via HF Optimum # default is False "VLLM_OPENVINO_ENABLE_QUANTIZED_WEIGHTS": lambda: bool(os.getenv("VLLM_OPENVINO_ENABLE_QUANTIZED_WEIGHTS", False)), From 295e494656b0f778f04f8899ff0a7c90c53b2aae Mon Sep 17 00:00:00 2001 From: Ilya Lavrenov Date: Thu, 27 Jun 2024 21:35:21 +0200 Subject: [PATCH 8/9] Fixed docs compilation --- docs/source/getting_started/openvino-installation.rst | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/docs/source/getting_started/openvino-installation.rst b/docs/source/getting_started/openvino-installation.rst index 2942d6deb7ae2..78cb84853bb8b 100644 --- a/docs/source/getting_started/openvino-installation.rst +++ b/docs/source/getting_started/openvino-installation.rst @@ -1,7 +1,7 @@ .. _installation_openvino: Installation with OpenVINO -======================== +========================== vLLM powered by OpenVINO supports all LLM models from :ref:`vLLM supported models list <_supported_models>` and can perform optimal model serving on all x86-64 CPUs with, at least, AVX2 support. OpenVINO vLLM backend supports the following advanced vLLM features: @@ -37,7 +37,7 @@ Quick start using Dockerfile .. _install_openvino_backend_from_source: Install from source ------------------ +------------------- - First, install Python. For example, on Ubuntu 22.04, you can run: @@ -62,7 +62,7 @@ Install from source .. _openvino_backend_performance_tips: Performance tips ------------------ +---------------- vLLM OpenVINO backend uses the following environment variables to control behavior: @@ -84,7 +84,7 @@ OpenVINO best known configuration is: .. _openvino_backend_limitations: Limitations ------------------ +----------- - LoRA serving is not supported. From 4f0be9641cd26d20ee31cc5559706d0edf9988ac Mon Sep 17 00:00:00 2001 From: Ilya Lavrenov Date: Thu, 27 Jun 2024 22:00:57 +0200 Subject: [PATCH 9/9] Fixed docs: attempt 2 --- .../getting_started/openvino-installation.rst | 30 +++++++++---------- 1 file changed, 15 insertions(+), 15 deletions(-) diff --git a/docs/source/getting_started/openvino-installation.rst b/docs/source/getting_started/openvino-installation.rst index 78cb84853bb8b..0d8e0b680ff0d 100644 --- a/docs/source/getting_started/openvino-installation.rst +++ b/docs/source/getting_started/openvino-installation.rst @@ -3,18 +3,18 @@ Installation with OpenVINO ========================== -vLLM powered by OpenVINO supports all LLM models from :ref:`vLLM supported models list <_supported_models>` and can perform optimal model serving on all x86-64 CPUs with, at least, AVX2 support. OpenVINO vLLM backend supports the following advanced vLLM features: +vLLM powered by OpenVINO supports all LLM models from :doc:`vLLM supported models list <../models/supported_models>` and can perform optimal model serving on all x86-64 CPUs with, at least, AVX2 support. OpenVINO vLLM backend supports the following advanced vLLM features: - Prefix caching (``--enable-prefix-caching``) - Chunked prefill (``--enable-chunked-prefill``) -Table of contents: +**Table of contents**: -#. :ref:`Requirements ` -#. :ref:`Quick start using Dockerfile ` -#. :ref:`Build from source ` -#. :ref:`Performance tips ` -#. :ref:`Limitations ` +- :ref:`Requirements ` +- :ref:`Quick start using Dockerfile ` +- :ref:`Build from source ` +- :ref:`Performance tips ` +- :ref:`Limitations ` .. _openvino_backend_requirements: @@ -41,23 +41,23 @@ Install from source - First, install Python. For example, on Ubuntu 22.04, you can run: -.. code-block:: console + .. code-block:: console - $ sudo apt-get update -y - $ sudo apt-get install python3 + $ sudo apt-get update -y + $ sudo apt-get install python3 - Second, install prerequisites vLLM OpenVINO backend installation: -.. code-block:: console + .. code-block:: console - $ pip install --upgrade pip - $ pip install -r requirements-build.txt --extra-index-url https://download.pytorch.org/whl/cpu + $ pip install --upgrade pip + $ pip install -r requirements-build.txt --extra-index-url https://download.pytorch.org/whl/cpu - Finally, install vLLM with OpenVINO backend: -.. code-block:: console + .. code-block:: console - $ PIP_PRE=1 PIP_EXTRA_INDEX_URL="https://download.pytorch.org/whl/cpu https://storage.openvinotoolkit.org/simple/wheels/nightly/" VLLM_TARGET_DEVICE=openvino python -m pip install -v . + $ PIP_PRE=1 PIP_EXTRA_INDEX_URL="https://download.pytorch.org/whl/cpu https://storage.openvinotoolkit.org/simple/wheels/nightly/" VLLM_TARGET_DEVICE=openvino python -m pip install -v . .. _openvino_backend_performance_tips: