mirror of
https://github.com/vllm-project/vllm.git
synced 2026-06-06 00:16:14 +00:00
[Perf] Set IR Op Priority Once at Worker Init (#42631)
Signed-off-by: BadrBasowid <badr.basowid@gmail.com>
This commit is contained in:
@@ -272,6 +272,31 @@ class TestIrOpImplDispatch:
|
||||
# Restored to empty
|
||||
assert _custom_add.get_priority() == []
|
||||
|
||||
@pytest.mark.parametrize(
    "default,override",
    [
        (["impl_even", "impl_b"], ["impl_a"]),
        (["impl_a"], ["impl_even", "impl_b"]),
    ],
)
def test_set_default_priority(
    self, custom_add_op, default: list[str], override: list[str]
):
    """
    Check that ``set_default`` installs a priority that survives a scoped
    ``set_priority`` override and can itself be replaced by a later
    ``set_default`` call.
    """
    op = custom_add_op

    # The op starts out with no priority configured.
    assert op.get_priority() == []

    op.set_default(default)
    assert op.get_priority() == default

    # The override is visible only inside the ``with`` body; on exit the
    # priority is expected to fall back to the installed default.
    with op.set_priority(override):
        assert op.get_priority() == override
    assert op.get_priority() == default

    # A second set_default call replaces the earlier default.
    op.set_default(override)
    assert op.get_priority() == override
|
||||
|
||||
@pytest.mark.parametrize("overload", ["default", "maybe_inplace"])
|
||||
def test_dispatch_priority_order(self, custom_add_op, overload: str):
|
||||
_custom_add = custom_add_op
|
||||
@@ -381,6 +406,26 @@ class TestIrOpImplDispatch:
|
||||
assert "priority not set" in message
|
||||
|
||||
|
||||
@pytest.mark.parametrize("default", [True, False])
def test_set_default_torch_wrap(default: bool):
    """
    Check that ``set_default_torch_wrap`` changes the module-level flag,
    that a scoped ``enable_torch_wrap`` override does not outlive its
    context, and that a later ``set_default_torch_wrap`` call wins.
    """
    saved_flag = vllm.ir.op._ENABLE_TORCH_WRAP
    flipped = not default
    try:
        vllm.ir.set_default_torch_wrap(default)
        assert vllm.ir.op._ENABLE_TORCH_WRAP is default

        # The scoped override applies only inside the ``with`` body; the
        # flag is expected to return to the installed default afterwards.
        with vllm.ir.enable_torch_wrap(flipped):
            assert vllm.ir.op._ENABLE_TORCH_WRAP is flipped
        assert vllm.ir.op._ENABLE_TORCH_WRAP is default

        # Installing a new default replaces the previous one.
        vllm.ir.set_default_torch_wrap(flipped)
        assert vllm.ir.op._ENABLE_TORCH_WRAP is flipped
    finally:
        # Put the global flag back so other tests see the value they started with.
        vllm.ir.op._ENABLE_TORCH_WRAP = saved_flag
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def custom_mm_op(fake_vllm_ir):
|
||||
"""Fixture that registers ``_custom_mm`` (isolated by ``fake_vllm_ir``)."""
|
||||
|
||||
+28
-17
@@ -20,8 +20,8 @@ logger = init_logger(__name__)
|
||||
class IrOpPriorityConfig:
|
||||
"""
|
||||
Configuration for vLLM IR op priority for dispatching/lowering during the
|
||||
forward pass. Each member is a list of strings, which will be passed to
|
||||
vllm.ir.ops.<op_name>.set_priority() for the duration of the forward pass.
|
||||
forward pass. Each member is a list of strings, which will be installed
|
||||
in worker init via vllm.ir.ops.<op_name>.set_default().
|
||||
A single comma-separated string is accepted as well,
|
||||
|
||||
If specified manually, platform defaults will be appended to the lists.
|
||||
@@ -67,6 +67,31 @@ class IrOpPriorityConfig:
|
||||
assert all(isinstance(v, str) for v in value)
|
||||
return value
|
||||
|
||||
def _iter_op_priorities(self):
    """
    Generate one ``(IrOp, priority_list)`` pair per field of this config.

    This is a generator, so nothing runs until the first item is requested.
    At that point the current platform's IR kernels are imported, and then
    each field is looked up in ``IrOp.registry`` by its field name.

    Raises an ``AssertionError`` if a field's priority is ``None``.
    """
    from vllm.ir.op import IrOp
    from vllm.platforms import current_platform

    # Import platform kernels before touching the registry.
    current_platform.import_ir_kernels()

    for field in fields(self):  # type: ignore[arg-type]
        priority_list = getattr(self, field.name)
        assert priority_list is not None, (
            f"IR op priority for {field.name} must be set"
        )
        logger.debug(
            "Setting IR op priority for %s to %s", field.name, priority_list
        )
        registered_op = IrOp.registry[field.name]
        yield registered_op, priority_list
|
||||
|
||||
def set_default(self) -> None:
    """
    Install every configured priority list as the default of its IR op.

    Each ``(op, priority)`` pair produced by ``_iter_op_priorities`` is
    forwarded to that op's own ``set_default``; nothing is restored
    afterwards.
    """
    for op, priority_list in self._iter_op_priorities():
        op.set_default(priority_list)
|
||||
|
||||
@contextlib.contextmanager
|
||||
def set_priority(self):
|
||||
"""
|
||||
@@ -74,23 +99,9 @@ class IrOpPriorityConfig:
|
||||
It also imports IR kernel implementations for the current platform
|
||||
to ensure all implementations are made available.
|
||||
"""
|
||||
from vllm.ir.op import IrOp
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
current_platform.import_ir_kernels()
|
||||
|
||||
with contextlib.ExitStack() as stack:
|
||||
for field in fields(self): # type: ignore[arg-type]
|
||||
op_priority = getattr(self, field.name)
|
||||
assert op_priority is not None, (
|
||||
f"IR op priority for {field.name} must be set"
|
||||
)
|
||||
logger.debug(
|
||||
"Setting IR op priority for %s to %s", field.name, op_priority
|
||||
)
|
||||
ir_op = IrOp.registry[field.name]
|
||||
for ir_op, op_priority in self._iter_op_priorities():
|
||||
stack.enter_context(ir_op.set_priority(op_priority))
|
||||
|
||||
yield
|
||||
|
||||
@classmethod
|
||||
|
||||
@@ -10,7 +10,6 @@ from typing import Any
|
||||
import torch
|
||||
|
||||
import vllm.envs as envs
|
||||
import vllm.ir
|
||||
from vllm.config import CUDAGraphMode, ParallelConfig, VllmConfig
|
||||
from vllm.logger import init_logger
|
||||
from vllm.platforms import current_platform
|
||||
@@ -320,13 +319,7 @@ def set_forward_context(
|
||||
)
|
||||
|
||||
try:
|
||||
with (
|
||||
override_forward_context(forward_context),
|
||||
vllm_config.kernel_config.ir_op_priority.set_priority(),
|
||||
vllm.ir.enable_torch_wrap(
|
||||
vllm_config.compilation_config.ir_enable_torch_wrap
|
||||
),
|
||||
):
|
||||
with override_forward_context(forward_context):
|
||||
yield
|
||||
finally:
|
||||
global last_logging_time, batchsize_logging_interval
|
||||
|
||||
+2
-2
@@ -1,6 +1,6 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from . import ops
|
||||
from .op import enable_torch_wrap, register_op
|
||||
from .op import enable_torch_wrap, register_op, set_default_torch_wrap
|
||||
|
||||
__all__ = ["enable_torch_wrap", "register_op", "ops"]
|
||||
__all__ = ["enable_torch_wrap", "register_op", "set_default_torch_wrap", "ops"]
|
||||
|
||||
+48
-29
@@ -51,6 +51,14 @@ _ENABLE_TORCH_WRAP: bool = True
|
||||
"""Global override flag to control torch op layer wrapping."""
|
||||
|
||||
|
||||
def set_default_torch_wrap(enable: bool = True) -> None:
    """
    Overwrite the module-level ``_ENABLE_TORCH_WRAP`` flag with ``enable``.

    The assignment is unconditional and nothing here restores the previous
    value, so the new setting stays in effect until the flag is written
    again.
    """
    global _ENABLE_TORCH_WRAP
    _ENABLE_TORCH_WRAP = enable
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
def enable_torch_wrap(enable: bool = True):
|
||||
"""
|
||||
@@ -372,42 +380,53 @@ class IrOp:
|
||||
"""Get the current dispatch priority for implementations for this op."""
|
||||
return [p.provider for p in self._priority_impls]
|
||||
|
||||
def _filter_priority_impls(self, priority: list[str]) -> list["IrOpImpl"]:
|
||||
assert all(p in self.impls for p in priority), (
|
||||
"All providers in priority must be registered implementations."
|
||||
)
|
||||
filtered_impls: list[IrOpImpl] = []
|
||||
for p in priority:
|
||||
impl = self.impls[p]
|
||||
if not impl.supported:
|
||||
# Skip unsupported implementations
|
||||
continue
|
||||
|
||||
filtered_impls.append(impl)
|
||||
|
||||
# If all args are supported, skip other implementations
|
||||
if impl.supports_all_args:
|
||||
return filtered_impls
|
||||
|
||||
logger.warning_once(
|
||||
"Op %s: No implementation in priority list supports all args, "
|
||||
"execution fallback to native is possible. To silence this warning, "
|
||||
"explicitly add 'native' to the end of the priority list",
|
||||
self.name,
|
||||
)
|
||||
filtered_impls.append(self.impls["native"])
|
||||
return filtered_impls
|
||||
|
||||
def set_default(self, priority: list[str]) -> None:
    """
    Replace this op's dispatch priority with ``priority`` and leave it set.

    The provider names are resolved through ``_filter_priority_impls`` and
    the result is stored on the op; no previous value is saved or restored.
    Intended for one-time, process-lifetime setup such as worker startup.
    For a temporary, scoped override use ``set_priority`` instead.
    """
    self._priority_impls = self._filter_priority_impls(priority)
    logger.debug(
        "Priority for vllm.ir.%s set to %s",
        self.name,
        # Wrapped in ``lazy`` so the provider list is built only if the
        # message is actually rendered.
        lazy(lambda: [impl.provider for impl in self._priority_impls]),
    )
|
||||
|
||||
@contextlib.contextmanager
|
||||
def set_priority(self, priority: list[str]):
|
||||
"""
|
||||
Context manager to set the dispatch priority for implementations for this op.
|
||||
"""
|
||||
assert all(p in self.impls for p in priority), (
|
||||
"All providers in priority must be registered implementations."
|
||||
)
|
||||
|
||||
def filter_priority_impls(p_list: list[str]) -> list[IrOpImpl]:
|
||||
filtered_impls = []
|
||||
for p in p_list:
|
||||
impl = self.impls[p]
|
||||
if not impl.supported:
|
||||
# Skip unsupported implementations
|
||||
continue
|
||||
|
||||
filtered_impls.append(impl)
|
||||
|
||||
# If all args are supported, skip other implementations
|
||||
if impl.supports_all_args:
|
||||
return filtered_impls
|
||||
|
||||
logger.warning_once(
|
||||
"Op %s: No implementation in priority list supports all args, "
|
||||
"execution fallback to native is possible. To silence this warning, "
|
||||
"explicitly add 'native' to the end of the priority list",
|
||||
self.name,
|
||||
)
|
||||
filtered_impls.append(self.impls["native"])
|
||||
return filtered_impls
|
||||
|
||||
# Temporarily set priority
|
||||
old_priority_impls = self._priority_impls
|
||||
try:
|
||||
self._priority_impls = filter_priority_impls(priority)
|
||||
self._priority_impls = self._filter_priority_impls(priority)
|
||||
logger.debug(
|
||||
"Priority for vllm.ir.%s set to %s",
|
||||
self.name,
|
||||
|
||||
@@ -7,6 +7,7 @@ from typing import TYPE_CHECKING, Any, NamedTuple, TypeVar
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
import vllm.ir
|
||||
from vllm.config import VllmConfig, set_current_vllm_config
|
||||
from vllm.logger import init_logger
|
||||
from vllm.lora.request import LoRARequest
|
||||
@@ -87,6 +88,13 @@ class WorkerBase:
|
||||
self.device: torch.device | None = None
|
||||
self.model_runner: nn.Module | None = None
|
||||
|
||||
# IR op priority and torch-wrap state are constant for the worker's
|
||||
# lifetime.
|
||||
vllm_config.kernel_config.ir_op_priority.set_default()
|
||||
vllm.ir.set_default_torch_wrap(
|
||||
vllm_config.compilation_config.ir_enable_torch_wrap
|
||||
)
|
||||
|
||||
def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]:
|
||||
"""Get specifications for KV cache implementation."""
|
||||
raise NotImplementedError
|
||||
|
||||
Reference in New Issue
Block a user