From 8a56da3845270837424ef4b7ee83ca97a7883025 Mon Sep 17 00:00:00 2001 From: Jiangyun Zhu Date: Sat, 16 May 2026 22:04:12 +0800 Subject: [PATCH] [Experimental] Breakable CUDA graph (#42304) Signed-off-by: zjy0516 --- .buildkite/test_areas/cuda.yaml | 3 +- .../v1/cudagraph/test_breakable_cudagraph.py | 367 +++++++++++++++ vllm/compilation/breakable_cudagraph.py | 424 ++++++++++++++++++ vllm/config/vllm.py | 13 +- vllm/envs.py | 5 + .../layers/attention/mla_attention.py | 2 + .../layers/deepseek_v4_attention.py | 2 + .../layers/sparse_attn_indexer.py | 2 + .../v1/attention/ops/rocm_aiter_mla_sparse.py | 2 + vllm/v1/cudagraph_dispatcher.py | 5 + vllm/v1/worker/gpu_model_runner.py | 25 +- 11 files changed, 845 insertions(+), 5 deletions(-) create mode 100644 tests/v1/cudagraph/test_breakable_cudagraph.py create mode 100644 vllm/compilation/breakable_cudagraph.py diff --git a/.buildkite/test_areas/cuda.yaml b/.buildkite/test_areas/cuda.yaml index 6254a6ba3dd..b56e635bea6 100644 --- a/.buildkite/test_areas/cuda.yaml +++ b/.buildkite/test_areas/cuda.yaml @@ -27,4 +27,5 @@ steps: - vllm/compilation commands: - pytest -v -s v1/cudagraph/test_cudagraph_dispatch.py - - pytest -v -s v1/cudagraph/test_cudagraph_mode.py \ No newline at end of file + - pytest -v -s v1/cudagraph/test_cudagraph_mode.py + - pytest -v -s v1/cudagraph/test_breakable_cudagraph.py \ No newline at end of file diff --git a/tests/v1/cudagraph/test_breakable_cudagraph.py b/tests/v1/cudagraph/test_breakable_cudagraph.py new file mode 100644 index 00000000000..f856d91b639 --- /dev/null +++ b/tests/v1/cudagraph/test_breakable_cudagraph.py @@ -0,0 +1,367 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +""" +Unit tests for the breakable cudagraph primitives. +""" + +from __future__ import annotations + +import os +import threading + +import pytest +import torch + +os.environ["VLLM_USE_BREAKABLE_CUDAGRAPH"] = "1" + + +@pytest.fixture(autouse=True) +def _reset_breakable_tls(): + """Defensively clear thread-local capture state between tests so a + failure in one test can't leak "nested capture" errors into the next.""" + from vllm.compilation.breakable_cudagraph import BreakableCUDAGraphCapture + + BreakableCUDAGraphCapture._tls.active = None + yield + BreakableCUDAGraphCapture._tls.active = None + + +@pytest.fixture +def cuda_capture_stream(): + """A non-default CUDA stream suitable for cudagraph capture. + + ``CUDAGraph.capture_begin`` refuses to capture from the default + stream, so all capture-using tests need to run under + ``torch.cuda.stream(...)`` for a separate stream. + """ + if not torch.cuda.is_available(): + pytest.skip("CUDA required") + stream = torch.cuda.Stream() + with torch.cuda.stream(stream): + yield stream + torch.cuda.current_stream().wait_stream(stream) + + +# --------------------------------------------------------------------------- +# eager_break_during_capture: outside capture +# --------------------------------------------------------------------------- + + +def test_decorator_passthrough_outside_capture(): + from vllm.compilation.breakable_cudagraph import eager_break_during_capture + + calls = [] + + @eager_break_during_capture + def f(x): + calls.append(x) + return x * 2 + + assert f(3) == 6 + assert calls == [3] + + +# --------------------------------------------------------------------------- +# BreakableCUDAGraphCapture: thread-local + nested rejection +# --------------------------------------------------------------------------- + + +def test_current_is_none_when_inactive(): + from vllm.compilation.breakable_cudagraph import BreakableCUDAGraphCapture + + assert BreakableCUDAGraphCapture.current() is None + assert BreakableCUDAGraphCapture.is_active() is False + + +def test_thread_local_active_during_context(cuda_capture_stream): + from vllm.compilation.breakable_cudagraph import BreakableCUDAGraphCapture + + cap = BreakableCUDAGraphCapture() + with cap: + assert BreakableCUDAGraphCapture.current() is cap + assert BreakableCUDAGraphCapture.is_active() is True + assert BreakableCUDAGraphCapture.current() is None + + +def test_nested_capture_raises(cuda_capture_stream): + from vllm.compilation.breakable_cudagraph import BreakableCUDAGraphCapture + + outer = BreakableCUDAGraphCapture() + inner = BreakableCUDAGraphCapture() + with outer, pytest.raises(RuntimeError, match="Nested.*not supported"), inner: + pass + + +def test_active_state_isolated_across_threads(cuda_capture_stream): + """Verify the thread-local 'active capture' slot is per-thread. + + We don't run concurrent captures here -- CUDA only supports one + in-flight capture per stream and we keep tests cheap. We just check + that the worker thread sees its own slot as None while the main + thread has a capture active. + """ + from vllm.compilation.breakable_cudagraph import BreakableCUDAGraphCapture + + worker_view: dict[str, BreakableCUDAGraphCapture | None] = {} + + def worker(): + worker_view["state"] = BreakableCUDAGraphCapture.current() + + main_cap = BreakableCUDAGraphCapture() + with main_cap: + # Main thread has a live capture. + assert BreakableCUDAGraphCapture.current() is main_cap + t = threading.Thread(target=worker) + t.start() + t.join() + + # Worker thread saw None -- thread-local separation works. + assert worker_view["state"] is None + # Main thread's slot is cleared on exit. + assert BreakableCUDAGraphCapture.current() is None + + +# --------------------------------------------------------------------------- +# Segment list construction +# --------------------------------------------------------------------------- + + +def test_capture_with_no_eager_break_records_one_graph(cuda_capture_stream): + from vllm.compilation.breakable_cudagraph import BreakableCUDAGraphCapture + + x = torch.zeros(4, device="cuda") + cap = BreakableCUDAGraphCapture() + with cap: + x.add_(1.0) + assert len(cap.segments) == 1 + assert cap.num_graphs == 1 + assert cap.num_eager_breaks == 0 + + +def test_add_eager_creates_alternating_graph_eager_graph(cuda_capture_stream): + from vllm.compilation.breakable_cudagraph import BreakableCUDAGraphCapture + + x = torch.zeros(4, device="cuda") + counter = {"eager_calls": 0} + + def eager_step(): + counter["eager_calls"] += 1 + x.add_(10.0) + + cap = BreakableCUDAGraphCapture() + with cap: + x.add_(1.0) + cap.add_eager(eager_step) + x.add_(1.0) + cap.add_eager(eager_step) + x.add_(1.0) + # 3 graph segments + 2 eager segments, interleaved as G E G E G. + assert len(cap.segments) == 5 + assert cap.num_graphs == 3 + assert cap.num_eager_breaks == 2 + # Eager fn is stored as-is in the segment list, so we can confirm + # the alternation pattern by identity check. + assert cap.segments[1] is eager_step + assert cap.segments[3] is eager_step + assert counter["eager_calls"] == 2 # only the in-capture invocation + + +# --------------------------------------------------------------------------- +# Capture vs eager numerical equivalence +# --------------------------------------------------------------------------- + + +def test_capture_replay_matches_eager_simple(cuda_capture_stream): + """Verify that replay reproduces the same end-state as a single eager + forward, with an eager break in the middle. + + Note: during capture, the *captured* kernels are recorded but NOT + executed (that's CUDA-graph semantics). Only the eager segments + actually mutate state at capture time. So we check correctness after + ``replay()``, not after ``with cap:`` exits. + """ + from vllm.compilation.breakable_cudagraph import BreakableCUDAGraphCapture + + x = torch.zeros(8, device="cuda") + log: list[str] = [] + + def eager_break_op(): + x.mul_(2.0) + log.append("eager") + + cap = BreakableCUDAGraphCapture() + with cap: + x.add_(1.0) # recorded into graph[0] + cap.add_eager(eager_break_op) # runs eagerly: x *= 2 + x.add_(5.0) # recorded into graph[1] + + # Capture-time: graph kernels were recorded only; eager segment ran + # once on x == 0, leaving x == 0. + torch.accelerator.synchronize() + assert torch.equal(x, torch.zeros(8, device="cuda")) + assert log == ["eager"] + + # Replay with a fresh input: 10 -> 11 -> 22 -> 27. + x.fill_(10.0) + cap.replay() + torch.accelerator.synchronize() + assert torch.equal(x, torch.full((8,), 27.0, device="cuda")) + assert log == ["eager", "eager"] + + # Replay again with another input: 100 -> 101 -> 202 -> 207. + x.fill_(100.0) + cap.replay() + torch.accelerator.synchronize() + assert torch.equal(x, torch.full((8,), 207.0, device="cuda")) + assert log == ["eager", "eager", "eager"] + + +def test_decorator_breaks_when_invoked_inside_capture(cuda_capture_stream): + """Verify @eager_break_during_capture correctly routes through + add_eager when inside a capture context, and runs straight through + when there's no active capture.""" + from vllm.compilation.breakable_cudagraph import ( + BreakableCUDAGraphCapture, + eager_break_during_capture, + ) + + @eager_break_during_capture + def attention_like(t: torch.Tensor) -> None: + # In-place double; stands in for "real" attention work. + t.mul_(2.0) + + x = torch.zeros(4, device="cuda") + + # Outside capture: decorator should just call through. + x.fill_(3.0) + attention_like(x) + torch.accelerator.synchronize() + assert torch.equal(x, torch.full((4,), 6.0, device="cuda")) + + # Inside capture: decorator should split the graph. Only the eager + # segment actually mutates state during capture. + x.fill_(0.0) + cap = BreakableCUDAGraphCapture() + with cap: + x.add_(5.0) # recorded + attention_like(x) # eager: x *= 2 (on x == 0, no-op) + x.add_(1.0) # recorded + torch.accelerator.synchronize() + assert torch.equal(x, torch.zeros(4, device="cuda")) + # 2 graph segments + 1 eager segment, ordered G E G; the arithmetic + # equivalence check below verifies the ordering. + assert len(cap.segments) == 3 + assert cap.num_graphs == 2 + assert cap.num_eager_breaks == 1 + + # Replay: 2 -> 7 -> 14 -> 15. + x.fill_(2.0) + cap.replay() + torch.accelerator.synchronize() + assert torch.equal(x, torch.full((4,), 15.0, device="cuda")) + + +# --------------------------------------------------------------------------- +# Replay ordering +# --------------------------------------------------------------------------- + + +def test_replay_invokes_eager_segments_in_order(cuda_capture_stream): + from vllm.compilation.breakable_cudagraph import BreakableCUDAGraphCapture + + log: list[str] = [] + x = torch.zeros(1, device="cuda") + + def make_eager(name): + def step(): + log.append(name) + x.add_(1.0) + + return step + + cap = BreakableCUDAGraphCapture() + with cap: + x.add_(1.0) + cap.add_eager(make_eager("A")) + x.add_(1.0) + cap.add_eager(make_eager("B")) + x.add_(1.0) + cap.add_eager(make_eager("C")) + x.add_(1.0) + + # Capture-time invocation order + assert log == ["A", "B", "C"] + + log.clear() + cap.replay() + torch.accelerator.synchronize() + assert log == ["A", "B", "C"] + + +# --------------------------------------------------------------------------- +# Capture cleanup releases thread-local even if body raises +# --------------------------------------------------------------------------- + + +def test_exception_in_body_clears_active(cuda_capture_stream): + from vllm.compilation.breakable_cudagraph import BreakableCUDAGraphCapture + + cap = BreakableCUDAGraphCapture() + with pytest.raises(RuntimeError, match="boom"), cap: + raise RuntimeError("boom") + + # active must be reset even after an exception inside the body + assert BreakableCUDAGraphCapture.current() is None + + +# --------------------------------------------------------------------------- +# Nested decorated ops: inner op must not trigger a recursive eager break +# --------------------------------------------------------------------------- + + +def test_nested_decorated_op_runs_inline(cuda_capture_stream): + """A decorated op invoked from inside another decorated op's eager + body must execute inline -- starting a second eager break mid-flight + corrupts the segment state and explodes ``_begin_segment``'s assert. + + This mirrors the deepseek_v4_attention case where the outer attention + op's impl internally dispatches sparse_attn_indexer (also decorated). + """ + from vllm.compilation.breakable_cudagraph import ( + BreakableCUDAGraphCapture, + eager_break_during_capture, + ) + + x = torch.zeros(4, device="cuda") + inner_calls = 0 + + @eager_break_during_capture + def inner_op(t: torch.Tensor) -> None: + nonlocal inner_calls + inner_calls += 1 + t.add_(1.0) + + @eager_break_during_capture + def outer_op(t: torch.Tensor) -> None: + # outer body calls another decorated op -- this is the case that + # used to assert in _begin_segment. + inner_op(t) + t.add_(10.0) + + cap = BreakableCUDAGraphCapture() + with cap: + x.add_(2.0) # recorded in graph[0] + outer_op(x) # one eager break, inner runs inline + x.add_(100.0) # recorded in graph[1] + + # Exactly one eager break (the outer); inner must NOT add a second. + assert cap.num_graphs == 2 + assert cap.num_eager_breaks == 1 + assert inner_calls == 1 # only the capture-time invocation + + x.fill_(0.0) + cap.replay() + torch.accelerator.synchronize() + # 0 -> +2 -> +1 (inner) -> +10 (outer) -> +100 = 113 + assert torch.equal(x, torch.full((4,), 113.0, device="cuda")) + assert inner_calls == 2 # replay invokes the outer's lambda again diff --git a/vllm/compilation/breakable_cudagraph.py b/vllm/compilation/breakable_cudagraph.py new file mode 100644 index 00000000000..6da3ec71786 --- /dev/null +++ b/vllm/compilation/breakable_cudagraph.py @@ -0,0 +1,424 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Breakable CUDA graph capture/replay. + +This is an alternative to :class:`CUDAGraphWrapper` that replaces vLLM's +torch.compile-based FX graph splitting with runtime stream-capture +breaks. + +The idea (inspired by sgl-project/sglang#19102): instead of pre-splitting +the model into many pieces at attention boundaries, a +single capture context drives the whole forward and intercepts +attention / kv-cache custom ops at the dispatcher to end the current +stream capture, run the op eagerly, and resume capture. + +The captured artifact is a list of zero-arg callables -- the bound +``CUDAGraph.replay`` for graph segments, or the user fn for eager +segments -- replayed in order at inference time. + +Eager segments must operate on the same static buffers used during +capture so subsequent graph segments read the same memory addresses. +""" + +from __future__ import annotations + +import dataclasses +import functools +import gc +import threading +import weakref +from collections.abc import Callable +from typing import Any, ClassVar, TypeVar + +import torch + +import vllm.envs as envs +from vllm.compilation.monitor import validate_cudagraph_capturing_enabled +from vllm.config import CUDAGraphMode, VllmConfig +from vllm.distributed.device_communicators.pynccl_allocator import set_graph_pool_id +from vllm.forward_context import ( + BatchDescriptor, + get_forward_context, + is_forward_context_available, +) +from vllm.logger import init_logger +from vllm.model_executor.offloader.base import get_offloader +from vllm.platforms import current_platform +from vllm.utils.torch_utils import weak_ref_tensor, weak_ref_tensors + +logger = init_logger(__name__) + + +def is_breakable_cudagraph_enabled() -> bool: + return bool(envs.VLLM_USE_BREAKABLE_CUDAGRAPH) + + +F = TypeVar("F", bound=Callable[..., Any]) + + +def eager_break_during_capture(fn: F) -> F: + """Decorator that turns a custom-op Python kernel into a "break point" + for the breakable cudagraph capture. + + When the decorated function is invoked outside of a + :class:`BreakableCUDAGraphCapture` context, it executes normally. + + When invoked inside a capture context, it ends the current cudagraph + segment, runs the function eagerly on the capture stream, records the + callable for replay, and starts a fresh segment. + + **In-place output buffer required.** Decorated ops must write into a + caller-provided output tensor; a fresh tensor returned by ``fn`` would + change address each replay and break downstream graph segments. + + **Decorator order matters.** Apply as the *outermost* decorator if + there are other decorators that introduce host-side side effects + around the call -- the canonical example is + ``@maybe_transfer_kv_layer`` for PD-disaggregation, whose + ``wait_for_layer_load`` and ``save_kv_layer`` calls must run in the + eager segment, not inside the captured cudagraph. Putting + ``@eager_break_during_capture`` *inside* such a decorator would + record those side effects into the graph and hang on replay. + + The correct order is:: + + @eager_break_during_capture # outermost + @maybe_transfer_kv_layer + def unified_attention_with_output(...): + ... + """ + if not is_breakable_cudagraph_enabled(): + return fn + + @functools.wraps(fn) + def wrapper(*args: Any, **kwargs: Any) -> Any: + capture = BreakableCUDAGraphCapture.current() + if capture is None: + return fn(*args, **kwargs) + if not capture._capturing: + return fn(*args, **kwargs) + if is_forward_context_available(): + mode = get_forward_context().cudagraph_runtime_mode + if mode == CUDAGraphMode.FULL: + return fn(*args, **kwargs) + + # Weak-ref args: strong refs in the replay lambda pin cudagraph-pool + # slots across batch descriptors. cudagraph owns the slot, so the + # weak_ref is safe to deref on replay. + weak_args = tuple( + weak_ref_tensor(a) if isinstance(a, torch.Tensor) else a for a in args + ) + weak_kwargs = { + k: weak_ref_tensor(v) if isinstance(v, torch.Tensor) else v + for k, v in kwargs.items() + } + return capture.add_eager(lambda: fn(*weak_args, **weak_kwargs)) + + return wrapper # type: ignore[return-value] + + +# --------------------------------------------------------------------------- +# Capture context +# --------------------------------------------------------------------------- + + +class BreakableCUDAGraphCapture: + """Stream-capture context that supports eager breaks via :meth:`add_eager`. + + Usage:: + + cap = BreakableCUDAGraphCapture(pool=...) + with cap: + output = model(*static_inputs) + # Later, after copying new inputs into the static buffers: + cap.replay() + # Output tensors live at the same addresses as during capture. + + Thread-local: only one capture may be active per thread. + """ + + _tls = threading.local() + + @classmethod + def current(cls) -> BreakableCUDAGraphCapture | None: + return getattr(cls._tls, "active", None) + + @classmethod + def is_active(cls) -> bool: + return cls.current() is not None + + def __init__(self, pool: Any | None = None) -> None: + self.pool = pool + self.segments: list[Callable[[], Any]] = [] + self._num_graphs: int = 0 + self._num_eager_breaks: int = 0 + self._current_graph: torch.cuda.CUDAGraph | None = None + self._capturing: bool = False + + # --- context manager protocol ---------------------------------------- + + def __enter__(self) -> BreakableCUDAGraphCapture: + if getattr(BreakableCUDAGraphCapture._tls, "active", None) is not None: + raise RuntimeError("Nested BreakableCUDAGraphCapture is not supported.") + BreakableCUDAGraphCapture._tls.active = self + self._begin_segment() + return self + + def __exit__(self, exc_type, exc, tb) -> None: + try: + self._end_segment() + finally: + BreakableCUDAGraphCapture._tls.active = None + + # --- segment management ---------------------------------------------- + + def _begin_segment(self) -> None: + assert not self._capturing + g = torch.cuda.CUDAGraph() + if self.pool is not None: + g.capture_begin(pool=self.pool) + else: + g.capture_begin() + self._current_graph = g + self._capturing = True + + def _end_segment(self) -> None: + if not self._capturing: + return + assert self._current_graph is not None + self._current_graph.capture_end() + self.segments.append(self._current_graph.replay) + self._num_graphs += 1 + self._current_graph = None + self._capturing = False + + def add_eager(self, fn: Callable[[], Any]) -> Any: + """End the current capture segment, run ``fn`` eagerly on the + capture stream, record ``fn`` for replay, and start a new segment. + + Returns whatever ``fn`` returned during this (capture-time) call. + Replay does not return values; callers should propagate any + downstream dependencies via static output buffers. + """ + self._end_segment() + result = fn() + self.segments.append(fn) + self._num_eager_breaks += 1 + self._begin_segment() + return result + + # --- replay ---------------------------------------------------------- + + def replay(self) -> None: + for r in self.segments: + r() + + # --- introspection --------------------------------------------------- + + @property + def num_graphs(self) -> int: + return self._num_graphs + + @property + def num_eager_breaks(self) -> int: + return self._num_eager_breaks + + def __repr__(self) -> str: + return ( + f"BreakableCUDAGraphCapture(graphs={self.num_graphs}, " + f"eager_breaks={self.num_eager_breaks})" + ) + + +# --------------------------------------------------------------------------- +# Wrapper that mirrors CUDAGraphWrapper's interface +# --------------------------------------------------------------------------- + + +@dataclasses.dataclass +class _BreakableEntry: + batch_descriptor: BatchDescriptor + capture: BreakableCUDAGraphCapture | None = None + output: Any = None + input_addresses: list[int] | None = None + + +class BreakableCUDAGraphWrapper: + """Drop-in replacement for :class:`CUDAGraphWrapper` that uses + :class:`BreakableCUDAGraphCapture` instead of a single monolithic + ``torch.cuda.graph()`` capture. + + Same dispatch contract as ``CUDAGraphWrapper``: + * If no ``forward_context`` is available, run the underlying + callable eagerly. + * If runtime mode mismatch / NONE, run eagerly. + * Otherwise, lazily capture per ``batch_descriptor`` and replay + on subsequent invocations with the same descriptor. + """ + + _all_instances: ClassVar[weakref.WeakSet[BreakableCUDAGraphWrapper]] = ( + weakref.WeakSet() + ) + + @classmethod + def clear_all_graphs(cls) -> None: + for instance in list(cls._all_instances): + instance.clear_graphs() + + def __init__( + self, + runnable: Callable[..., Any], + vllm_config: VllmConfig, + ) -> None: + # Unlike the original CUDAGraphWrapper which strictly matches a + # single runtime_mode, this wrapper captures whatever the + # dispatcher emits (any non-NONE runtime_mode) -- breakable's + # capture is identical for prefill and decode, so there's nothing + # to dispatch on at the runtime_mode level. Entries are keyed by + # BatchDescriptor which already encodes batch shape / uniformity. + self.runnable = runnable + self.vllm_config = vllm_config + self.compilation_config = vllm_config.compilation_config + self.graph_pool = current_platform.get_global_graph_pool() + self.is_debugging_mode = envs.VLLM_LOGGING_LEVEL == "DEBUG" + + self.entries: dict[BatchDescriptor, _BreakableEntry] = {} + BreakableCUDAGraphWrapper._all_instances.add(self) + + logger.info_once("Breakable CUDA graph enabled") + + # --- vllm-style attribute forwarding --------------------------------- + + def __getattr__(self, key: str) -> Any: + runnable = self.__dict__.get("runnable") + if runnable is not None and hasattr(runnable, key): + return getattr(runnable, key) + raise AttributeError(key) + + def unwrap(self) -> Callable[..., Any]: + return self.runnable + + @property + def cudagraph_wrapper(self) -> BreakableCUDAGraphWrapper: + return self + + def clear_graphs(self) -> None: + self.entries.clear() + + # --- dispatch -------------------------------------------------------- + + def __call__(self, *args: Any, **kwargs: Any) -> Any: + if not is_forward_context_available(): + return self.runnable(*args, **kwargs) + + forward_context = get_forward_context() + batch_descriptor = forward_context.batch_descriptor + cudagraph_runtime_mode = forward_context.cudagraph_runtime_mode + + # Capture whenever the dispatcher says "some cudagraph mode" -- + # breakable produces the same artifact regardless of PIECEWISE + # vs FULL, so we match either. Entries are keyed by batch + # descriptor, which already encodes prefill/decode distinctions. + if cudagraph_runtime_mode == CUDAGraphMode.NONE: + return self.runnable(*args, **kwargs) + + assert batch_descriptor is not None + entry = self.entries.get(batch_descriptor) + if entry is None: + entry = _BreakableEntry(batch_descriptor=batch_descriptor) + self.entries[batch_descriptor] = entry + + if entry.capture is None: + return self._capture(entry, args, kwargs) + return self._replay(entry, args, kwargs) + + # --- capture / replay paths ----------------------------------------- + + @staticmethod + def _collect_tensor_addresses( + args: tuple[Any, ...], kwargs: dict[str, Any] + ) -> list[int]: + """Flatten tensor data_ptrs from positional and keyword args in a + stable order (positionals first, then kwargs in insertion order). + + Used for the DEBUG-mode address-stability check; covers both call + styles since vLLM models are typically invoked with kwargs. + """ + addrs = [x.data_ptr() for x in args if isinstance(x, torch.Tensor)] + addrs.extend( + v.data_ptr() for v in kwargs.values() if isinstance(v, torch.Tensor) + ) + return addrs + + def _capture( + self, + entry: _BreakableEntry, + args: tuple[Any, ...], + kwargs: dict[str, Any], + ) -> Any: + validate_cudagraph_capturing_enabled() + + entry.input_addresses = self._collect_tensor_addresses(args, kwargs) + + if self.graph_pool is not None: + set_graph_pool_id(self.graph_pool) + else: + set_graph_pool_id(current_platform.graph_pool_handle()) + + # Match torch.cuda.graph()'s pre-capture cleanup once per descriptor. + # We drive capture_begin/end directly and bypass torch.cuda.graph(), + # so its built-in gc + empty_cache never fire. Run them here once + # per _capture call -- NOT inside _begin_segment, since this capture + # session may issue many begin/end pairs (one per layer's break), + # and repeated gc would tank capture time the way it did for the + # pre-`gc_disable` piecewise path. + gc.collect() + torch.accelerator.empty_cache() + # Sync the offloader's copy stream before capture so any in-flight + # pre-capture prefetches are complete and don't leak into the graph. + get_offloader().sync_prev_onload() + + capture = BreakableCUDAGraphCapture(pool=self.graph_pool) + with capture: + output = self.runnable(*args, **kwargs) + # Join the offloader's copy stream while we still hold the last + # segment open, so the join is captured into the graph (otherwise + # we get an "unjoined stream" error on subsequent forwards). + get_offloader().join_after_forward() + # Convert output to a weak ref *inside* the capture context so the + # strong ref is dropped before the last segment closes, letting + # the cudagraph pool reclaim/reuse that memory immediately for + # the next batch descriptor's capture. + output = weak_ref_tensors(output) + + entry.capture = capture + entry.output = weak_ref_tensors(output) + + logger.debug( + "Captured breakable cudagraph for %s: %r", + entry.batch_descriptor, + capture, + ) + # Return the (already-weak) output from the captured run so the + # caller of model(...) gets a tensor pointing at the cudagraph pool's memory + return output + + def _replay( + self, + entry: _BreakableEntry, + args: tuple[Any, ...], + kwargs: dict[str, Any], + ) -> Any: + if self.is_debugging_mode and entry.input_addresses is not None: + new_addresses = self._collect_tensor_addresses(args, kwargs) + assert new_addresses == entry.input_addresses, ( + "Input tensor addresses changed between capture and replay " + f"for {entry.batch_descriptor}. Expected " + f"{entry.input_addresses}, got {new_addresses}." + ) + # Sync the offloader's copy stream before replay so any external + # dependencies from pre-capture prefetches are satisfied. + get_offloader().sync_prev_onload() + assert entry.capture is not None + entry.capture.replay() + return entry.output diff --git a/vllm/config/vllm.py b/vllm/config/vllm.py index f9090ff588b..5d45de6ff33 100644 --- a/vllm/config/vllm.py +++ b/vllm/config/vllm.py @@ -1035,6 +1035,13 @@ class VllmConfig: ) self.compilation_config.mode = CompilationMode.NONE + if envs.VLLM_USE_BREAKABLE_CUDAGRAPH: + logger.warning_once( + "VLLM_USE_BREAKABLE_CUDAGRAPH is set, disabling vLLM's " + "torch.compile pipeline. Equivalent to -cc.mode=none." + ) + self.compilation_config.mode = CompilationMode.NONE + if self.compilation_config.backend == "eager" or ( self.compilation_config.mode is not None and self.compilation_config.mode != CompilationMode.VLLM_COMPILE @@ -1102,6 +1109,7 @@ class VllmConfig: if ( self.compilation_config.cudagraph_mode.requires_piecewise_compilation() and self.compilation_config.mode != CompilationMode.VLLM_COMPILE + and not envs.VLLM_USE_BREAKABLE_CUDAGRAPH ): logger.info( "Cudagraph mode %s is not compatible with compilation mode %s." @@ -1335,7 +1343,10 @@ class VllmConfig: ) if self.compilation_config.cudagraph_mode.requires_piecewise_compilation(): - assert self.compilation_config.mode == CompilationMode.VLLM_COMPILE, ( + assert ( + self.compilation_config.mode == CompilationMode.VLLM_COMPILE + or envs.VLLM_USE_BREAKABLE_CUDAGRAPH + ), ( "Compilation mode should be CompilationMode.VLLM_COMPILE " "when cudagraph_mode piecewise cudagraphs is used, " f"cudagraph_mode={self.compilation_config.cudagraph_mode}" diff --git a/vllm/envs.py b/vllm/envs.py index 7c9ee0ea9df..9d3542c1bee 100755 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -145,6 +145,7 @@ if TYPE_CHECKING: VLLM_DP_SIZE: int = 1 VLLM_USE_STANDALONE_COMPILE: bool = True VLLM_ENABLE_PREGRAD_PASSES: bool = True + VLLM_USE_BREAKABLE_CUDAGRAPH: bool = False VLLM_DP_MASTER_IP: str = "" VLLM_DP_MASTER_PORT: int = 0 VLLM_RANDOMIZE_DP_DUMMY_INPUTS: bool = False @@ -637,6 +638,10 @@ environment_variables: dict[str, Callable[[], Any]] = { "VLLM_ENABLE_PREGRAD_PASSES": lambda: ( os.environ.get("VLLM_ENABLE_PREGRAD_PASSES", "1") == "1" ), + # Experimental: breakable cudagraph does not rely on torch.compile + "VLLM_USE_BREAKABLE_CUDAGRAPH": lambda: ( + os.environ.get("VLLM_USE_BREAKABLE_CUDAGRAPH", "0") == "1" + ), # Debug pattern matching inside custom passes. # Should be set to the fx.Node name (e.g. 'getitem_34' or 'scaled_mm_3'). "VLLM_PATTERN_MATCH_DEBUG": lambda: os.environ.get( diff --git a/vllm/model_executor/layers/attention/mla_attention.py b/vllm/model_executor/layers/attention/mla_attention.py index 3d6efc38563..71fd297a7ed 100644 --- a/vllm/model_executor/layers/attention/mla_attention.py +++ b/vllm/model_executor/layers/attention/mla_attention.py @@ -200,6 +200,7 @@ from tqdm import tqdm import vllm.envs as envs from vllm import _custom_ops as ops from vllm._aiter_ops import rocm_aiter_ops +from vllm.compilation.breakable_cudagraph import eager_break_during_capture from vllm.config import ( CacheConfig, ModelConfig, @@ -1036,6 +1037,7 @@ direct_register_custom_op( ) +@eager_break_during_capture @maybe_transfer_kv_layer def unified_mla_attention_with_output( q: torch.Tensor, diff --git a/vllm/model_executor/layers/deepseek_v4_attention.py b/vllm/model_executor/layers/deepseek_v4_attention.py index e9a6ec9a587..996534b05f1 100644 --- a/vllm/model_executor/layers/deepseek_v4_attention.py +++ b/vllm/model_executor/layers/deepseek_v4_attention.py @@ -14,6 +14,7 @@ import torch.nn.functional as F from transformers import DeepseekV2Config, DeepseekV3Config import vllm.envs as envs +from vllm.compilation.breakable_cudagraph import eager_break_during_capture from vllm.model_executor.layers.linear import ( ReplicatedLinear, ) @@ -553,6 +554,7 @@ class DeepseekV4MultiHeadLatentAttentionWrapper(PluggableLayer): ) +@eager_break_during_capture def deepseek_v4_attention( hidden_states: torch.Tensor, positions: torch.Tensor, diff --git a/vllm/model_executor/layers/sparse_attn_indexer.py b/vllm/model_executor/layers/sparse_attn_indexer.py index 4bf52a49c43..9f202435a73 100644 --- a/vllm/model_executor/layers/sparse_attn_indexer.py +++ b/vllm/model_executor/layers/sparse_attn_indexer.py @@ -6,6 +6,7 @@ import torch import vllm.envs as envs from vllm._aiter_ops import rocm_aiter_ops +from vllm.compilation.breakable_cudagraph import eager_break_during_capture from vllm.forward_context import get_forward_context from vllm.logger import init_logger from vllm.model_executor.custom_op import CustomOp @@ -81,6 +82,7 @@ def kv_cache_as_quant_view( return kv_cache.unsqueeze(-2) +@eager_break_during_capture def sparse_attn_indexer( hidden_states: torch.Tensor, k_cache_prefix: LayerNameType, diff --git a/vllm/v1/attention/ops/rocm_aiter_mla_sparse.py b/vllm/v1/attention/ops/rocm_aiter_mla_sparse.py index c653f051ebc..c3f1b29465e 100644 --- a/vllm/v1/attention/ops/rocm_aiter_mla_sparse.py +++ b/vllm/v1/attention/ops/rocm_aiter_mla_sparse.py @@ -8,6 +8,7 @@ from importlib.util import find_spec import torch import torch.nn.functional as F +from vllm.compilation.breakable_cudagraph import eager_break_during_capture from vllm.forward_context import get_forward_context from vllm.platforms import current_platform from vllm.triton_utils import tl, triton @@ -836,6 +837,7 @@ def rocm_aiter_sparse_attn_indexer_native( return topk_indices_buffer +@eager_break_during_capture def rocm_aiter_sparse_attn_indexer( hidden_states: torch.Tensor, k_cache_prefix: LayerNameType, diff --git a/vllm/v1/cudagraph_dispatcher.py b/vllm/v1/cudagraph_dispatcher.py index e27b5ee3883..cf0c1d41772 100644 --- a/vllm/v1/cudagraph_dispatcher.py +++ b/vllm/v1/cudagraph_dispatcher.py @@ -46,9 +46,14 @@ class CudagraphDispatcher: CUDAGraphMode.FULL: set(), } + from vllm.compilation.breakable_cudagraph import ( + is_breakable_cudagraph_enabled, + ) + assert ( not self.compilation_config.cudagraph_mode.requires_piecewise_compilation() or self.compilation_config.is_attention_compiled_piecewise() + or is_breakable_cudagraph_enabled() ), ( "Compilation mode should be CompilationMode.VLLM_COMPILE when " "cudagraph_mode piecewise cudagraphs is used, " diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index a30e9275b18..2c010040bc2 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -21,6 +21,10 @@ import torch.nn as nn from tqdm import tqdm import vllm.envs as envs +from vllm.compilation.breakable_cudagraph import ( + BreakableCUDAGraphWrapper, + is_breakable_cudagraph_enabled, +) from vllm.compilation.counter import compilation_counter from vllm.compilation.cuda_graph import CUDAGraphStat, CUDAGraphWrapper from vllm.compilation.monitor import set_cudagraph_capturing_enabled @@ -3127,7 +3131,9 @@ class GPUModelRunner( def get_model(self) -> nn.Module: if not hasattr(self, "model"): raise ValueError("Cannot get model before model has been initialized") - if isinstance(self.model, (CUDAGraphWrapper, UBatchWrapper)): + if isinstance( + self.model, (CUDAGraphWrapper, UBatchWrapper, BreakableCUDAGraphWrapper) + ): # get raw model out of the cudagraph wrapper. return self.model.unwrap() return self.model @@ -5130,6 +5136,12 @@ class GPUModelRunner( cudagraph_mode = self.compilation_config.cudagraph_mode assert cudagraph_mode is not None if ( + is_breakable_cudagraph_enabled() + and cudagraph_mode != CUDAGraphMode.NONE + and not self.parallel_config.use_ubatching + ): + self.model = BreakableCUDAGraphWrapper(self.model, self.vllm_config) + elif ( cudagraph_mode.has_full_cudagraphs() and not self.parallel_config.use_ubatching ): @@ -6206,7 +6218,10 @@ class GPUModelRunner( # Use a temporary pool for profiling to avoid fragmentation in the main pool. profiling_pool = current_platform.graph_pool_handle() original_pools: dict[int, Any] = {} - for instance in list(CUDAGraphWrapper._all_instances): + all_wrappers = list(CUDAGraphWrapper._all_instances) + list( + BreakableCUDAGraphWrapper._all_instances + ) + for instance in all_wrappers: original_pools[id(instance)] = instance.graph_pool instance.graph_pool = profiling_pool @@ -6257,7 +6272,11 @@ class GPUModelRunner( set_cudagraph_capturing_enabled(False) CUDAGraphWrapper.clear_all_graphs() - for instance in list(CUDAGraphWrapper._all_instances): + BreakableCUDAGraphWrapper.clear_all_graphs() + all_wrappers = list(CUDAGraphWrapper._all_instances) + list( + BreakableCUDAGraphWrapper._all_instances + ) + for instance in all_wrappers: if id(instance) in original_pools: instance.graph_pool = original_pools[id(instance)] for key_set in self.cudagraph_dispatcher.cudagraph_keys.values():