[Perf] Set IR Op Priority Once at Worker Init (#42631)

Signed-off-by: BadrBasowid <badr.basowid@gmail.com>
This commit is contained in:
BadrBasowid
2026-05-15 23:56:13 +08:00
committed by GitHub
parent ee58665aac
commit fb5bd03f51
6 changed files with 132 additions and 56 deletions
+45
View File
@@ -272,6 +272,31 @@ class TestIrOpImplDispatch:
# Restored to empty
assert _custom_add.get_priority() == []
@pytest.mark.parametrize(
    "default,override",
    [
        (["impl_even", "impl_b"], ["impl_a"]),
        (["impl_a"], ["impl_even", "impl_b"]),
    ],
)
def test_set_default_priority(
    self, custom_add_op, default: list[str], override: list[str]
):
    """set_default installs a priority that outlives scoped overrides."""
    op = custom_add_op
    # The fixture op is expected to start with an empty priority list.
    assert op.get_priority() == []

    op.set_default(default)
    assert op.get_priority() == default

    # A scoped set_priority override is visible only inside the block ...
    with op.set_priority(override):
        assert op.get_priority() == override
    # ... and the installed default is reported again once it exits.
    assert op.get_priority() == default

    # A second set_default call replaces the earlier default.
    op.set_default(override)
    assert op.get_priority() == override
@pytest.mark.parametrize("overload", ["default", "maybe_inplace"])
def test_dispatch_priority_order(self, custom_add_op, overload: str):
_custom_add = custom_add_op
@@ -381,6 +406,26 @@ class TestIrOpImplDispatch:
assert "priority not set" in message
@pytest.mark.parametrize("default", [True, False])
def test_set_default_torch_wrap(default: bool):
    """set_default_torch_wrap sets the global flag; scoped overrides revert to it."""
    flipped = not default
    saved_flag = vllm.ir.op._ENABLE_TORCH_WRAP
    try:
        vllm.ir.set_default_torch_wrap(default)
        assert vllm.ir.op._ENABLE_TORCH_WRAP is default

        # enable_torch_wrap changes the flag only inside its block ...
        with vllm.ir.enable_torch_wrap(flipped):
            assert vllm.ir.op._ENABLE_TORCH_WRAP is flipped
        # ... and the installed default is back afterwards.
        assert vllm.ir.op._ENABLE_TORCH_WRAP is default

        # A second call replaces the earlier default.
        vllm.ir.set_default_torch_wrap(flipped)
        assert vllm.ir.op._ENABLE_TORCH_WRAP is flipped
    finally:
        # Put the module-level flag back so other tests see the value
        # that was in place before this test ran.
        vllm.ir.op._ENABLE_TORCH_WRAP = saved_flag
@pytest.fixture
def custom_mm_op(fake_vllm_ir):
"""Fixture that registers ``_custom_mm`` (isolated by ``fake_vllm_ir``)."""
+28 -17
View File
@@ -20,8 +20,8 @@ logger = init_logger(__name__)
class IrOpPriorityConfig:
"""
Configuration for vLLM IR op priority for dispatching/lowering during the
forward pass. Each member is a list of strings, which will be passed to
vllm.ir.ops.<op_name>.set_priority() for the duration of the forward pass.
forward pass. Each member is a list of strings, which will be installed
in worker init via vllm.ir.ops.<op_name>.set_default().
A single comma-separated string is accepted as well.
If specified manually, platform defaults will be appended to the lists.
@@ -67,6 +67,31 @@ class IrOpPriorityConfig:
assert all(isinstance(v, str) for v in value)
return value
def _iter_op_priorities(self):
    """
    Lazily yield an ``(IrOp, priority_list)`` pair for every field.

    This is a generator, so nothing runs until iteration starts. At that
    point the IR kernels for the current platform are imported, and then,
    field by field, the configured priority is asserted to be set (not
    None) and the matching op is looked up in ``IrOp.registry`` by the
    field's name.
    """
    from vllm.ir.op import IrOp
    from vllm.platforms import current_platform

    # Import the platform's kernel implementations first so that all of
    # them are available before any priority is handed out.
    current_platform.import_ir_kernels()

    for spec in fields(self):  # type: ignore[arg-type]
        name = spec.name
        priority = getattr(self, name)
        assert priority is not None, f"IR op priority for {name} must be set"
        logger.debug("Setting IR op priority for %s to %s", name, priority)
        yield IrOp.registry[name], priority
def set_default(self) -> None:
    """
    Install every configured op priority as that op's default.

    Calls ``set_default`` on each op yielded by ``_iter_op_priorities``;
    unlike ``set_priority`` there is no context to exit, so nothing is
    restored afterwards.
    """
    for op, priority in self._iter_op_priorities():
        op.set_default(priority)
@contextlib.contextmanager
def set_priority(self):
"""
@@ -74,23 +99,9 @@ class IrOpPriorityConfig:
It also imports IR kernel implementations for the current platform
to ensure all implementations are made available.
"""
from vllm.ir.op import IrOp
from vllm.platforms import current_platform
current_platform.import_ir_kernels()
with contextlib.ExitStack() as stack:
for field in fields(self): # type: ignore[arg-type]
op_priority = getattr(self, field.name)
assert op_priority is not None, (
f"IR op priority for {field.name} must be set"
)
logger.debug(
"Setting IR op priority for %s to %s", field.name, op_priority
)
ir_op = IrOp.registry[field.name]
for ir_op, op_priority in self._iter_op_priorities():
stack.enter_context(ir_op.set_priority(op_priority))
yield
@classmethod
+1 -8
View File
@@ -10,7 +10,6 @@ from typing import Any
import torch
import vllm.envs as envs
import vllm.ir
from vllm.config import CUDAGraphMode, ParallelConfig, VllmConfig
from vllm.logger import init_logger
from vllm.platforms import current_platform
@@ -320,13 +319,7 @@ def set_forward_context(
)
try:
with (
override_forward_context(forward_context),
vllm_config.kernel_config.ir_op_priority.set_priority(),
vllm.ir.enable_torch_wrap(
vllm_config.compilation_config.ir_enable_torch_wrap
),
):
with override_forward_context(forward_context):
yield
finally:
global last_logging_time, batchsize_logging_interval
+2 -2
View File
@@ -1,6 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from . import ops
from .op import enable_torch_wrap, register_op
from .op import enable_torch_wrap, register_op, set_default_torch_wrap
__all__ = ["enable_torch_wrap", "register_op", "ops"]
__all__ = ["enable_torch_wrap", "register_op", "set_default_torch_wrap", "ops"]
+48 -29
View File
@@ -51,6 +51,14 @@ _ENABLE_TORCH_WRAP: bool = True
"""Global override flag to control torch op layer wrapping."""
def set_default_torch_wrap(enable: bool = True) -> None:
    """
    Set the module-level torch wrap flag to ``enable``.

    The previous value is not remembered or restored; for a scoped
    override use the ``enable_torch_wrap`` context manager instead.
    """
    global _ENABLE_TORCH_WRAP
    _ENABLE_TORCH_WRAP = enable
@contextlib.contextmanager
def enable_torch_wrap(enable: bool = True):
"""
@@ -372,42 +380,53 @@ class IrOp:
"""Get the current dispatch priority for implementations for this op."""
return [p.provider for p in self._priority_impls]
def _filter_priority_impls(self, priority: list[str]) -> list["IrOpImpl"]:
    """
    Resolve provider names into the implementations to dispatch over.

    Providers are visited in the given order. Implementations whose
    ``supported`` flag is false are dropped, and the walk stops at the
    first kept implementation whose ``supports_all_args`` is true. If no
    such implementation is reached, a one-time warning is logged and the
    ``"native"`` implementation is appended as the final fallback.

    Every name in ``priority`` must be a key of ``self.impls``; this is
    checked with an assertion.
    """
    assert all(p in self.impls for p in priority), (
        "All providers in priority must be registered implementations."
    )

    selected: list[IrOpImpl] = []
    for provider in priority:
        candidate = self.impls[provider]
        if candidate.supported:
            selected.append(candidate)
            # Nothing after an impl that handles all args is needed.
            if candidate.supports_all_args:
                return selected

    # The loop finished without reaching an impl that handles all args.
    logger.warning_once(
        "Op %s: No implementation in priority list supports all args, "
        "execution fallback to native is possible. To silence this warning, "
        "explicitly add 'native' to the end of the priority list",
        self.name,
    )
    selected.append(self.impls["native"])
    return selected
def set_default(self, priority: list[str]) -> None:
    """
    Install ``priority`` as this op's dispatch priority without restoring it.

    Meant for process-lifetime setup (e.g., worker startup); reach for
    ``set_priority`` when the override should be scoped. The names are
    passed through ``_filter_priority_impls`` before being stored.
    """
    resolved = self._filter_priority_impls(priority)
    self._priority_impls = resolved
    logger.debug(
        "Priority for vllm.ir.%s set to %s",
        self.name,
        # NOTE(review): `lazy` presumably defers building the provider
        # list until the message is actually formatted -- confirm.
        lazy(lambda: [impl.provider for impl in self._priority_impls]),
    )
@contextlib.contextmanager
def set_priority(self, priority: list[str]):
"""
Context manager to set the dispatch priority for implementations for this op.
"""
assert all(p in self.impls for p in priority), (
"All providers in priority must be registered implementations."
)
def filter_priority_impls(p_list: list[str]) -> list[IrOpImpl]:
filtered_impls = []
for p in p_list:
impl = self.impls[p]
if not impl.supported:
# Skip unsupported implementations
continue
filtered_impls.append(impl)
# If all args are supported, skip other implementations
if impl.supports_all_args:
return filtered_impls
logger.warning_once(
"Op %s: No implementation in priority list supports all args, "
"execution fallback to native is possible. To silence this warning, "
"explicitly add 'native' to the end of the priority list",
self.name,
)
filtered_impls.append(self.impls["native"])
return filtered_impls
# Temporarily set priority
old_priority_impls = self._priority_impls
try:
self._priority_impls = filter_priority_impls(priority)
self._priority_impls = self._filter_priority_impls(priority)
logger.debug(
"Priority for vllm.ir.%s set to %s",
self.name,
+8
View File
@@ -7,6 +7,7 @@ from typing import TYPE_CHECKING, Any, NamedTuple, TypeVar
import torch
import torch.nn as nn
import vllm.ir
from vllm.config import VllmConfig, set_current_vllm_config
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
@@ -87,6 +88,13 @@ class WorkerBase:
self.device: torch.device | None = None
self.model_runner: nn.Module | None = None
# IR op priority and torch-wrap state are constant for the worker's
# lifetime.
vllm_config.kernel_config.ir_op_priority.set_default()
vllm.ir.set_default_torch_wrap(
vllm_config.compilation_config.ir_enable_torch_wrap
)
def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]:
"""Get specifications for KV cache implementation."""
raise NotImplementedError