[None][refactor] refactor the CUDA graph runner to manage all CUDA graphs (#6846)

Signed-off-by: junq <22017000+QiJune@users.noreply.github.com>
Authored by QI JUN on 2025-08-25 20:52:05 +08:00; committed by GitHub
parent b32e00e9fd
commit bea5e07fb7
14 changed files with 513 additions and 384 deletions

View File

@ -1,122 +1,309 @@
from typing import Any, Callable, Dict, Optional, Tuple
import bisect
import contextlib
import weakref
from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Tuple
import torch
from ..attention_backend.interface import AttentionMetadata
from ..expert_statistic import ExpertStatistic
from ..modules.multi_stream_utils import with_multi_stream
from ..speculative.interface import SpecMetadata
from ..utils import make_weak_ref, piecewise_cuda_graph
from .resource_manager import ResourceManager, ResourceManagerType
from .scheduler import ScheduledRequests
if TYPE_CHECKING:
from .model_engine import PyTorchModelEngine
# A large sentinel value (the unsigned 64-bit maximum) used for dummy request IDs to avoid collisions with real request IDs
CUDA_GRAPH_DUMMY_REQUEST_ID = (1 << 64) - 1
class DecodingCUDAGraphRunner:
class CUDAGraphRunner:
"""
Manages the lifecycle and execution of CUDA graphs for the model engine.
def __init__(
self,
batch_size: int,
device: str,
attn_metadata: AttentionMetadata,
spec_metadata: Optional[SpecMetadata] = None,
use_mrope: bool = False,
max_beam_width: int = 1,
) -> None:
This unified class handles high-level orchestration (padding, eligibility)
and low-level execution (capturing, resource management, replaying) for
multiple graphs, keyed by (batch size, draft_len).
"""
WARMUP_STEPS = 2
def __init__(self, engine: "PyTorchModelEngine"):
self.engine_ref = weakref.ref(engine)
# High-level configuration
config = engine.pytorch_backend_config
self.enabled = config.use_cuda_graph
self.padding_enabled = config.cuda_graph_padding_enabled
self.supported_batch_sizes = engine._cuda_graph_batch_sizes
self.max_supported_batch_size = engine._max_cuda_graph_batch_size
self.max_beam_width = engine.max_beam_width
self.spec_config = engine.spec_config
self.graphs: Dict[Tuple[int, int], torch.cuda.CUDAGraph] = {}
self.static_inputs: Dict[Tuple[int, int], Dict[str, torch.Tensor]] = {}
self.graph_outputs: Dict[Tuple[int, int],
Callable[[], Optional[torch.Tensor]]] = {}
self.graph_metadata: Dict[Tuple[int, int], Dict[str, Any]] = {}
self.memory_pool = engine._cuda_graph_mem_pool
self.padding_dummy_request: Optional["Request"] = None
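All of the per-graph state above is keyed by a (batch_size, draft_len) tuple, so a speculative-decoding run never shares graphs or static buffers with a plain decode run of the same batch size. A tiny illustration of the keying; the stored values here are placeholder strings rather than real graph objects:

from typing import Dict, Tuple

graphs: Dict[Tuple[int, int], str] = {}
graphs[(8, 0)] = "decode graph, batch 8, no draft tokens"
graphs[(8, 3)] = "decode graph, batch 8, 3 draft tokens per request"

# Same batch size, different draft length -> different graph and buffers.
assert (8, 0) in graphs and (8, 3) in graphs and len(graphs) == 2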
@property
def enable_spec_decode(self):
return self._get_engine().is_spec_decode
@property
def draft_len(self):
return self.spec_config.max_draft_len if self.enable_spec_decode else 0
@property
def spec_metadata(self):
return self._get_engine().spec_metadata
@property
def draft_tokens_cuda(self):
return self._get_engine().draft_tokens_cuda
@property
def attn_metadata(self):
return self._get_engine().attn_metadata
def __del__(self):
self.clear()
def _get_engine(self) -> "PyTorchModelEngine":
"""Safely dereferences the weak reference to the engine."""
engine = self.engine_ref()
if engine is None:
raise RuntimeError(
"The parent PyTorchModelEngine has been garbage collected.")
return engine
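The runner keeps only a weak reference to the engine, so owning a CUDAGraphRunner does not keep a PyTorchModelEngine alive; since the engine also owns the runner, a strong back-reference would create a reference cycle. A minimal, self-contained sketch of the same pattern, where the Engine and Runner classes are illustrative stand-ins:

import weakref

class Engine:
    def __init__(self):
        self.iter_counter = 0
        self.runner = Runner(self)  # the engine owns the runner

class Runner:
    def __init__(self, engine):
        # Store only a weak reference so the runner does not keep the
        # engine alive after the engine itself is dropped.
        self._engine_ref = weakref.ref(engine)

    def _get_engine(self):
        engine = self._engine_ref()
        if engine is None:
            raise RuntimeError("parent engine has been garbage collected")
        return engine

engine = Engine()
assert engine.runner._get_engine() is engine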
def maybe_get_cuda_graph(self, batch: ScheduledRequests):
"""
Stores a CUDA graph and its associated input buffers.
Determines if the current batch can be run with a CUDA graph.
Each CUDA graph runner is associated with an AttentionMetadata object
if flashinfer is being used. Make sure to call attn_metadata.prepare()
before run()!
Note that torch.compile w/ mode reduce-overhead supports CUDA graphs
with memory pool sharing. However, we have our own manager here because,
at the time of writing this, torch.compile takes way too long to warmup
graphs compared to doing it manually (not to mention, custom ops from
e.g. FlashInfer cause graph breaks).
Returns a tuple containing:
- A boolean indicating if a graph can be used.
- The attn_metadata for the graph, if applicable.
- The spec_metadata for the graph, if applicable.
"""
self.batch_size = batch_size
self.max_beam_width = max_beam_width
engine = self._get_engine()
# Disable CUDA graphs while collecting expert statistics.
if hasattr(engine, 'iter_counter') and ExpertStatistic.set_iter(
engine.iter_counter):
return False, None, None
can_run_cuda_graph = batch.can_run_cuda_graph
batch_size = batch.batch_size
if self.enabled and engine.enable_attention_dp and engine.mapping.tp_size > 1:
all_can_graph_batch = engine.dist.tp_allgather(
[can_run_cuda_graph, batch_size])
is_all_gen_only = all(all_can_graph[0]
for all_can_graph in all_can_graph_batch)
all_batch_size_equal = all(
all_gen_only[1] == all_can_graph_batch[0][1]
for all_gen_only in all_can_graph_batch)
if not is_all_gen_only or not all_batch_size_equal:
return False, None, None
if not self.enabled or not can_run_cuda_graph:
return False, None, None
key = (batch_size, self.draft_len)
if key in self.graphs:
return True, self.graph_metadata[key][
"attn_metadata"], self.graph_metadata[key]["spec_metadata"]
if batch_size not in self.supported_batch_sizes:
return False, None, None
num_sequences_in_batch = batch_size * self.max_beam_width
attn_metadata = self.attn_metadata.create_cuda_graph_metadata(
num_sequences_in_batch, False, self.draft_len)
assert attn_metadata.is_cuda_graph
if self.enable_spec_decode:
spec_metadata = self.spec_metadata.create_cuda_graph_metadata(
num_sequences_in_batch)
spec_metadata.draft_tokens = self.draft_tokens_cuda
else:
spec_metadata = None
return True, attn_metadata, spec_metadata
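When attention data parallelism is enabled with tp_size > 1, the decision is made collectively: every rank must report a generation-only batch of the same size, otherwise no rank uses a graph. A pure-Python sketch of that agreement check over the allgathered (can_run_cuda_graph, batch_size) pairs; the lists below are hard-coded for illustration:

from typing import List, Tuple

def all_ranks_agree(gathered: List[Tuple[bool, int]]) -> bool:
    # gathered[i] = (can_run_cuda_graph, batch_size) reported by rank i.
    is_all_gen_only = all(can_graph for can_graph, _ in gathered)
    all_batch_size_equal = all(bs == gathered[0][1] for _, bs in gathered)
    return is_all_gen_only and all_batch_size_equal

assert all_ranks_agree([(True, 8), (True, 8)])
assert not all_ranks_agree([(True, 8), (False, 8)])  # one rank still has context requests
assert not all_ranks_agree([(True, 8), (True, 4)])   # batch sizes diverge across ranks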
def needs_capture(self, batch_size: int):
return (batch_size, self.draft_len) not in self.graph_outputs
def capture(self, batch_size: int, forward_fn: Callable,
initial_inputs: Dict[str, Any]):
"""Captures the forward pass for a given batch size."""
engine = self._get_engine()
key = (batch_size, self.draft_len)
spec_metadata = initial_inputs.get("spec_metadata", None)
# [CUDA graph spec decode padding]
# We pad input IDs/position IDs to the maximum draft length (token per request).
# We're forced to do this because we cannot reallocate inputs over many graph runs.
token_per_request = spec_metadata.max_draft_len + 1 if spec_metadata is not None else 1
# Using ones instead of zeros prevents NaNs in e.g. Deepseek
self.input_ids = torch.ones(
(batch_size * max_beam_width * token_per_request, ),
device=device,
dtype=torch.int32)
self.position_ids = torch.zeros(
(1, batch_size * max_beam_width * token_per_request),
device=device,
dtype=torch.int32)
self.mrope_position_deltas = torch.zeros(
(batch_size,
1), device=device, dtype=torch.int32) if use_mrope else None
static_tensors = {
"input_ids":
torch.ones((batch_size * self.max_beam_width * token_per_request, ),
device="cuda",
dtype=torch.int32),
"position_ids":
torch.zeros((
1,
batch_size * self.max_beam_width * token_per_request,
),
device="cuda",
dtype=torch.int32),
}
if engine.use_mrope:
static_tensors["mrope_position_deltas"] = torch.zeros(
(batch_size, 1), device="cuda", dtype=torch.int32)
self.static_inputs[key] = static_tensors
self.attn_metadata = attn_metadata
self.spec_metadata = spec_metadata
self._output = None
self._graph = None
self.optional_extra_model_inputs = ["mrope_position_deltas"]
capture_inputs = initial_inputs.copy()
capture_inputs.update(static_tensors)
def __del__(self):
self._graph.reset()
def capture(
self,
forward_fn: Callable[[Dict[str, Any]], torch.Tensor],
pool: Optional[Tuple[int, int]] = None,
) -> Tuple[int, int]:
self._graph = torch.cuda.CUDAGraph()
inputs = {
"attn_metadata": self.attn_metadata,
"input_ids": self.input_ids,
"position_ids": self.position_ids,
"inputs_embeds": None,
"spec_metadata": self.spec_metadata,
"mrope_position_deltas": self.mrope_position_deltas,
self.graph_metadata[key] = {
"attn_metadata": initial_inputs["attn_metadata"],
"spec_metadata": spec_metadata,
}
# We have to do warm up runs to initialize PyTorch's
# internal states according to the docs:
# https://pytorch.org/docs/stable/notes/cuda.html#cuda-graph-semantics
# This also lets us initialize states in the attn_metadata.
graph = torch.cuda.CUDAGraph()
with with_multi_stream(True), piecewise_cuda_graph(False):
for _ in range(2):
forward_fn(inputs)
with torch.cuda.graph(self._graph, pool=pool):
output = forward_fn(inputs)
# Mark weak ref here. The output tensor should be freed properly.
self._output = make_weak_ref(output)
return self._graph.pool()
for _ in range(self.WARMUP_STEPS):
forward_fn(capture_inputs)
with torch.cuda.graph(graph, pool=self.memory_pool):
output = forward_fn(capture_inputs)
def needs_capture(self) -> bool:
return self._output is None
self.graphs[key] = graph
self.graph_outputs[key] = make_weak_ref(output)
self.memory_pool = graph.pool()
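Outside the engine, the capture recipe above follows the standard PyTorch pattern: allocate static input buffers once, run a couple of warm-up iterations (mirroring WARMUP_STEPS), then record the forward pass into a torch.cuda.CUDAGraph, handing the resulting pool to later captures so graphs share memory. A minimal standalone sketch; it needs a CUDA device, and the toy model and sizes are placeholders:

import torch

model = torch.nn.Linear(16, 16).cuda()
static_x = torch.ones(8, 16, device="cuda")   # reused for every replay

# Warm-up runs initialize backend and allocator state before capture
# (see the CUDA graph semantics notes in the PyTorch docs).
for _ in range(2):
    model(static_x)
torch.cuda.synchronize()

graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(graph):          # pass pool=<shared pool> here to share memory across graphs
    static_out = model(static_x)       # the output buffer is captured and reused too
shared_pool = graph.pool()             # hand this pool to the next capture, as the runner does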
def run(self, inputs: Dict[str, Any]) -> torch.Tensor:
assert "input_ids" in inputs
assert "position_ids" in inputs
assert "attn_metadata" in inputs
def replay(self, batch_size: int,
current_inputs: Dict[str, Any]) -> Optional[torch.Tensor]:
"""Replays a previously captured graph."""
key = (batch_size, self.draft_len)
stored_meta = self.graph_metadata[key]
assert current_inputs["attn_metadata"] is stored_meta["attn_metadata"]
if stored_meta["spec_metadata"] is not None:
assert current_inputs.get(
"spec_metadata") is stored_meta["spec_metadata"]
attn_metadata = inputs["attn_metadata"]
assert attn_metadata is self.attn_metadata, (
"attn_metadata does not match the attn_metadata instance that was used to "
"capture this graph.")
static_tensors = self.static_inputs[key]
if "spec_metadata" in inputs:
spec_metadata = inputs["spec_metadata"]
assert spec_metadata is self.spec_metadata, (
"spec_metadata does not match the spec_metadata instance that was used to "
"capture this graph.")
input_ids = inputs["input_ids"]
position_ids = inputs["position_ids"]
input_ids = current_inputs["input_ids"]
seqlen = input_ids.shape[0]
self.input_ids[:seqlen].copy_(input_ids)
self.position_ids[:, :seqlen].copy_(position_ids)
if "mrope_position_deltas" in inputs:
self.mrope_position_deltas[:self.batch_size].copy_(
inputs["mrope_position_deltas"])
static_tensors["input_ids"][:seqlen].copy_(input_ids)
assert self._output is not None and self._graph is not None
self._graph.replay()
return self._output
position_ids = current_inputs["position_ids"]
static_tensors["position_ids"][:, :seqlen].copy_(position_ids)
if "mrope_position_deltas" in current_inputs:
assert "mrope_position_deltas" in static_tensors
static_tensors["mrope_position_deltas"][:batch_size].copy_(
current_inputs["mrope_position_deltas"])
self.graphs[key].replay()
output_ref = self.graph_outputs[key]
return output_ref
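Replay never feeds fresh tensors into the graph; it copies the current step's values into the captured static buffers (only the first seqlen positions, since the graph is captured at a fixed padded size) and then replays. Continuing the standalone capture sketch above, reusing graph, static_x, and static_out:

def run_step(new_x: torch.Tensor) -> torch.Tensor:
    n = new_x.shape[0]
    static_x[:n].copy_(new_x)      # fill only the live rows; padded rows keep stale values
    graph.replay()
    # The output buffer is overwritten by the next replay, so clone anything
    # the caller needs to keep.
    return static_out[:n].clone()

out = run_step(torch.randn(4, 16, device="cuda"))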
def _get_padded_batch(self, batch: ScheduledRequests,
resource_manager: ResourceManager) -> int:
engine = self._get_engine()
kv_cache_manager = resource_manager.get_resource_manager(
engine.kv_cache_manager_key)
can_run_cuda_graph = batch.can_run_cuda_graph
batch_size = batch.batch_size
new_batch_size = batch_size
if self.enabled and engine.enable_attention_dp and engine.mapping.tp_size > 1:
graph_batch_size = engine.dist.tp_allgather(
[can_run_cuda_graph, batch_size])
all_can_graph = all(graph_batch[0]
for graph_batch in graph_batch_size)
if all_can_graph:
new_batch_size = max(gen_only_batch[1]
for gen_only_batch in graph_batch_size)
if (not self.enabled or not self.padding_enabled
or not can_run_cuda_graph
or new_batch_size > self.max_supported_batch_size):
return 0
padded_batch_size = self._round_up_batch_size(new_batch_size)
if batch_size == padded_batch_size:
return 0
padding_size = padded_batch_size - batch_size
if padding_size + batch.batch_size > engine.batch_size:
return 0
# No padding if it would create too many concurrent requests.
# This is not strictly required, but we should probably
# respect the requirement just in case that changes in the future.
if self.padding_dummy_request is None:
available_blocks = kv_cache_manager.get_num_free_blocks()
# No padding if not enough KV cache space
if available_blocks < 1:
return 0
self.padding_dummy_request = kv_cache_manager.add_dummy_requests(
[CUDA_GRAPH_DUMMY_REQUEST_ID],
is_gen=True,
max_num_draft_tokens=engine.max_draft_len,
use_mrope=engine.use_mrope,
max_beam_width=engine.max_beam_width)[0]
self.padding_dummy_request.is_cuda_graph_dummy = True
spec_res_mgr = resource_manager.get_resource_manager(
ResourceManagerType.SPEC_RESOURCE_MANAGER)
if spec_res_mgr:
spec_res_mgr.add_dummy_requests([CUDA_GRAPH_DUMMY_REQUEST_ID])
batch.generation_requests.extend([self.padding_dummy_request] *
padding_size)
return padding_size
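Under attention data parallelism, the padding target is the largest generation-only batch size across ranks, so every rank can pad up to a common size and replay the graph captured for it; if any rank still has context requests, each rank keeps its own size. A pure-Python sketch of that step with a hard-coded gather result:

from typing import List, Tuple

def padding_target_batch_size(local_batch_size: int,
                              gathered: List[Tuple[bool, int]]) -> int:
    # gathered[i] = (can_run_cuda_graph, batch_size) reported by rank i.
    if all(can_graph for can_graph, _ in gathered):
        # All ranks are generation-only: pad everyone up to the largest batch.
        return max(bs for _, bs in gathered)
    return local_batch_size

assert padding_target_batch_size(3, [(True, 3), (True, 7)]) == 7
assert padding_target_batch_size(3, [(True, 3), (False, 7)]) == 3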
def _round_up_batch_size(self, batch_size: int) -> int:
"""Finds the smallest supported graph batch size >= the given size."""
if not self.supported_batch_sizes:
return 0
idx = bisect.bisect_left(self.supported_batch_sizes, batch_size)
if idx == len(self.supported_batch_sizes):
return 0
return self.supported_batch_sizes[idx]
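The round-up assumes supported_batch_sizes is sorted ascending, which bisect_left requires; it returns the smallest captured size that can hold the batch, or 0 when the batch is larger than anything captured. A quick illustration with made-up sizes:

import bisect

def round_up(batch_size: int, supported_sizes: list) -> int:
    idx = bisect.bisect_left(supported_sizes, batch_size)
    return supported_sizes[idx] if idx < len(supported_sizes) else 0

sizes = [1, 2, 4, 8, 16]
assert round_up(5, sizes) == 8    # padded up to the next captured size
assert round_up(8, sizes) == 8    # exact match, no padding needed
assert round_up(17, sizes) == 0   # too large for any captured graph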
@contextlib.contextmanager
def pad_batch(self, scheduled_requests: ScheduledRequests,
resource_manager: ResourceManager):
"""Context manager to pad a batch to a graph-compatible size."""
padding_size = self._get_padded_batch(scheduled_requests,
resource_manager)
try:
yield scheduled_requests
finally:
if padding_size > 0:
scheduled_requests.generation_requests = scheduled_requests.generation_requests[:
-padding_size]
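The context manager guarantees that dummy requests never leak out of the forward step: they are appended before the body runs and sliced off again in the finally block, even if the forward pass raises. A stripped-down, self-contained sketch of the same pattern, where the Batch class stands in for ScheduledRequests and the padding size is passed in directly:

import contextlib

class Batch:
    def __init__(self, generation_requests):
        self.generation_requests = generation_requests

@contextlib.contextmanager
def pad_batch(batch: Batch, padding_size: int, dummy="DUMMY"):
    batch.generation_requests.extend([dummy] * padding_size)
    try:
        yield batch
    finally:
        if padding_size > 0:
            # Trim the padding off again so the scheduler never sees it.
            batch.generation_requests = batch.generation_requests[:-padding_size]

batch = Batch(["req0", "req1", "req2"])
with pad_batch(batch, padding_size=1) as padded:
    assert len(padded.generation_requests) == 4
assert len(batch.generation_requests) == 3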
def clear(self):
"""Releases all captured graphs and the associated memory pool."""
for graph in self.graphs.values():
graph.reset()
self.graphs.clear()
self.static_inputs.clear()
self.graph_outputs.clear()
self.graph_metadata.clear()
del self.memory_pool
self.memory_pool = None
torch.cuda.empty_cache()
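Taken together, the engine drives the runner through the same sequence every step: pad the batch, check graph eligibility, capture on the first use of a (batch size, draft length) key, and replay afterwards. A schematic sketch of that control flow, assuming a runner instance and an eager forward_fn; input preparation and metadata wiring are omitted here and appear in the engine diff below:

def forward_step(runner, batch, resource_manager, forward_fn, inputs):
    with runner.pad_batch(batch, resource_manager) as padded:
        use_graph, graph_attn_metadata, graph_spec_metadata = \
            runner.maybe_get_cuda_graph(padded)
        # In the real engine, graph_attn_metadata/graph_spec_metadata feed input preparation.
        if not use_graph:
            return forward_fn(inputs)  # eager fallback
        batch_size = len(padded.generation_requests)
        if runner.needs_capture(batch_size):
            # First occurrence of this (batch_size, draft_len) key: record the graph.
            runner.capture(batch_size, forward_fn, inputs)
        # Subsequent steps with the same key are cheap replays.
        return runner.replay(batch_size, inputs)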

View File

@ -1,4 +1,3 @@
import bisect
import contextlib
import copy
import functools
@ -57,7 +56,7 @@ from ..utils import (get_model_extra_attrs, set_torch_compiling,
with_model_extra_attrs)
from .config import LoadFormat, PyTorchConfig
from .config_utils import is_mla
from .cuda_graph_runner import DecodingCUDAGraphRunner
from .cuda_graph_runner import CUDAGraphRunner
from .layerwise_nvtx_marker import LayerwiseNvtxMarker
from .llm_request import get_draft_token_length
from .resource_manager import (BaseResourceManager, KVCacheManager,
@ -422,7 +421,6 @@ class PyTorchModelEngine(ModelEngine):
self.iter_states = {}
self._cuda_graphs = {}
self._cuda_graph_mem_pool = self._torch_compile_backend._graph_pool_handle if self._torch_compile_enabled else None
self._run_cuda_graphs = pytorch_backend_config.use_cuda_graph
self._cuda_graph_padding_enabled = pytorch_backend_config.cuda_graph_padding_enabled
@ -451,7 +449,7 @@ class PyTorchModelEngine(ModelEngine):
# with different KV cache managers.
self.kv_cache_manager_key = ResourceManagerType.KV_CACHE_MANAGER
self.lora_model_config: Optional[LoraModelConfig] = None
self.cuda_graph_dummy_request = None
self.cuda_graph_runner = CUDAGraphRunner(self)
# Setup the local cache indirection buffer only once and reuse it.
# This way it can also be used for CUDA graphs.
@ -541,12 +539,12 @@ class PyTorchModelEngine(ModelEngine):
@contextlib.contextmanager
def no_cuda_graph(self):
_run_cuda_graphs = self._run_cuda_graphs
self._run_cuda_graphs = False
_run_cuda_graphs = self.cuda_graph_runner.enabled
self.cuda_graph_runner.enabled = False
try:
yield
finally:
self._run_cuda_graphs = _run_cuda_graphs
self.cuda_graph_runner.enabled = _run_cuda_graphs
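no_cuda_graph() is a plain save/disable/restore around the runner's enabled flag, so callers that need a one-off eager pass always leave the flag as they found it, even when the body raises. A self-contained sketch of the same pattern with a stand-in runner object:

import contextlib

class FakeRunner:
    enabled = True

runner = FakeRunner()

@contextlib.contextmanager
def no_cuda_graph():
    previous = runner.enabled
    runner.enabled = False
    try:
        yield
    finally:
        runner.enabled = previous   # restored even if the body raises

with no_cuda_graph():
    assert runner.enabled is False
assert runner.enabled is True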
@with_warmup_flag
def warmup(self, resource_manager: ResourceManager) -> None:
@ -561,7 +559,7 @@ class PyTorchModelEngine(ModelEngine):
# The lifetime of model engine and kv cache manager can be different.
# Reset the global cuda graph dummy request to None in warmup.
self.cuda_graph_dummy_request = None
self.cuda_graph_runner.padding_dummy_request = None
def get_cuda_graph_warmup_request(batch_size, draft_len):
# Divide by max_beam_width to get an approximation of the number of requests that can be run in parallel.
@ -756,7 +754,7 @@ class PyTorchModelEngine(ModelEngine):
AutoTuner.get().print_profiling_cache()
if not (self._run_cuda_graphs
if not (self.cuda_graph_runner.enabled
or self._torch_compile_piecewise_cuda_graph):
return
@ -889,152 +887,6 @@ class PyTorchModelEngine(ModelEngine):
is_draft_model=self.is_draft_model)
return self.spec_metadata
def _get_padded_batch(
self,
scheduled_requests: ScheduledRequests,
kv_cache_manager,
spec_resource_manager: Optional[BaseResourceManager] = None) -> int:
can_run_cuda_graph = scheduled_requests.can_run_cuda_graph
batch_size = scheduled_requests.batch_size
new_batch_size = batch_size
if self._run_cuda_graphs and self.enable_attention_dp and self.mapping.tp_size > 1:
graph_batch_size = self.dist.tp_allgather(
[can_run_cuda_graph, batch_size])
all_can_graph = all(graph_batch[0]
for graph_batch in graph_batch_size)
if all_can_graph:
new_batch_size = max(gen_only_batch[1]
for gen_only_batch in graph_batch_size)
if (not self._run_cuda_graphs or not self._cuda_graph_padding_enabled
or not can_run_cuda_graph
or new_batch_size > self._max_cuda_graph_batch_size):
return 0
padded_batch_size = self._round_up_batch_size(new_batch_size)
if batch_size == padded_batch_size:
return 0
padding_size = padded_batch_size - batch_size
if padding_size + scheduled_requests.batch_size > self.batch_size:
return 0
# No padding if it would create too many concurrent requests.
# This is not strictly required, but we should probably
# respect the requirement just in case that changes in the future.
if self.cuda_graph_dummy_request is None:
available_blocks = kv_cache_manager.get_num_free_blocks()
# No padding if not enough KV cache space
if available_blocks < 1:
return 0
cuda_graph_dummy_request_ids = [MAX_UINT64 - 1]
self.cuda_graph_dummy_request = kv_cache_manager.add_dummy_requests(
cuda_graph_dummy_request_ids,
is_gen=True,
max_num_draft_tokens=self.runtime_draft_len,
use_mrope=self.use_mrope,
max_beam_width=self.max_beam_width)[0]
self.cuda_graph_dummy_request.is_cuda_graph_dummy = True
if spec_resource_manager is not None:
spec_resource_manager.add_dummy_requests(
request_ids=cuda_graph_dummy_request_ids)
scheduled_requests.generation_requests.extend(
[self.cuda_graph_dummy_request] * padding_size)
return padding_size
@contextlib.contextmanager
def _maybe_pad_batch(
self,
scheduled_requests: ScheduledRequests,
kv_cache_manager,
spec_resource_manager: Optional[BaseResourceManager] = None):
"""
CUDA graphs can only be used for specific batch sizes.
If using CUDA graphs, this method will add dummy requests to the given
batch so we can always use a CUDA graph. It is a context manager
because the padded requests will be removed from scheduled requests.
"""
padding_size = self._get_padded_batch(scheduled_requests,
kv_cache_manager,
spec_resource_manager)
try:
yield scheduled_requests
finally:
if padding_size > 0:
scheduled_requests.generation_requests = scheduled_requests.generation_requests[:
-padding_size]
def _round_up_batch_size(self, batch_size: int) -> int:
"""
Round up the given batch size to the nearest batch size that is
associated with a CUDA graph.
"""
idx = bisect.bisect_left(self._cuda_graph_batch_sizes, batch_size)
return self._cuda_graph_batch_sizes[idx]
def _maybe_get_cuda_graph(
self,
batch: ScheduledRequests,
) -> Optional[DecodingCUDAGraphRunner]:
"""
Get a CUDA graph runner or return None (e.g. if CUDA graphs are disabled
or if the batch size is too big).
"""
# disable when doing statistic
if ExpertStatistic.set_iter(self.iter_counter):
return None
draft_len = self.spec_config.max_draft_len if self.enable_spec_decode else 0
can_run_cuda_graph = batch.can_run_cuda_graph
batch_size = len(batch.generation_requests)
if self._run_cuda_graphs and self.enable_attention_dp and self.mapping.tp_size > 1:
all_can_graph_batch = self.dist.tp_allgather(
[can_run_cuda_graph, batch_size])
is_all_gen_only = all(all_can_graph[0]
for all_can_graph in all_can_graph_batch)
all_batch_size_equal = all(
all_gen_only[1] == all_can_graph_batch[0][1]
for all_gen_only in all_can_graph_batch)
if not is_all_gen_only or not all_batch_size_equal:
return None
if not self._run_cuda_graphs or not can_run_cuda_graph:
return None
if batch_size in self._cuda_graphs and draft_len in self._cuda_graphs[
batch_size]:
return self._cuda_graphs[batch_size][draft_len]
if batch_size not in self._cuda_graph_batch_sizes:
return None
num_sequences_in_batch = batch_size * self.max_beam_width
attn_metadata = self.attn_metadata.create_cuda_graph_metadata(
num_sequences_in_batch, False, draft_len)
assert attn_metadata.is_cuda_graph
if self.enable_spec_decode:
spec_metadata = self.spec_metadata.create_cuda_graph_metadata(
num_sequences_in_batch)
spec_metadata.draft_tokens = self.draft_tokens_cuda
else:
spec_metadata = None
# Initialize nested dictionary if needed
if batch_size not in self._cuda_graphs:
self._cuda_graphs[batch_size] = {}
self._cuda_graphs[batch_size][draft_len] = DecodingCUDAGraphRunner(
batch_size, "cuda", attn_metadata, spec_metadata, self.use_mrope,
self.max_beam_width)
return self._cuda_graphs[batch_size][draft_len]
def __del__(self) -> None:
if getattr(self, 'ub_buffers', None):
for u in self.ub_buffers:
@ -1244,13 +1096,7 @@ class PyTorchModelEngine(ModelEngine):
self._init_max_num_tokens()
def _release_cuda_graphs(self):
for batch_size, draft_graphs in self._cuda_graphs.items():
for draft_len, graph in draft_graphs.items():
del graph
self._cuda_graphs.clear()
torch.cuda.empty_cache()
del self._cuda_graph_mem_pool
self._cuda_graph_mem_pool = None
self.cuda_graph_runner.clear()
def get_max_num_sequences(self) -> int:
"""
@ -2237,12 +2083,14 @@ class PyTorchModelEngine(ModelEngine):
else:
return self._forward_step(inputs, gather_ids,
gather_context_logits)
with self._maybe_pad_batch(scheduled_requests, kv_cache_manager,
spec_resource_manager) as scheduled_requests:
maybe_graph = self._maybe_get_cuda_graph(scheduled_requests)
if maybe_graph is not None:
attn_metadata = maybe_graph.attn_metadata
spec_metadata = maybe_graph.spec_metadata
with self.cuda_graph_runner.pad_batch(
scheduled_requests, resource_manager) as padded_requests:
maybe_graph, maybe_attn_metadata, maybe_spec_metadata = self.cuda_graph_runner.maybe_get_cuda_graph(
padded_requests)
if maybe_graph:
attn_metadata = maybe_attn_metadata
spec_metadata = maybe_spec_metadata
else:
attn_metadata = self.attn_metadata
if self.enable_spec_decode:
@ -2251,17 +2099,19 @@ class PyTorchModelEngine(ModelEngine):
spec_metadata = None
inputs, gather_ids = self._prepare_inputs(
scheduled_requests, kv_cache_manager, attn_metadata,
spec_metadata, new_tensors_device, cache_indirection_buffer)
padded_requests, kv_cache_manager, attn_metadata, spec_metadata,
new_tensors_device, cache_indirection_buffer)
self.iter_counter += 1
if maybe_graph is None:
if not maybe_graph:
# Fallback to eager execution if graph was not used
with MoeLoadBalancerIterContext(moe_load_balancer):
outputs = self._forward_step(inputs, gather_ids,
gather_context_logits)
else:
if maybe_graph.needs_capture():
batch_size = len(padded_requests.generation_requests)
if self.cuda_graph_runner.needs_capture(batch_size):
def capture_forward_fn(inputs: Dict[str, Any]):
with MoeLoadBalancerIterContext(moe_load_balancer):
@ -2270,18 +2120,16 @@ class PyTorchModelEngine(ModelEngine):
gather_ids=gather_ids,
gather_context_logits=gather_context_logits)
pool = maybe_graph.capture(
capture_forward_fn,
self._cuda_graph_mem_pool,
)
self._cuda_graph_mem_pool = pool
self.cuda_graph_runner.capture(batch_size,
capture_forward_fn, inputs)
# The MoeLoadBalancerIterContext is not needed here because CUDA graph capture
# records kernels without actually running them.
# TODO: find a cleaner way to handle this.
outputs = maybe_graph.run(inputs)
outputs = self.cuda_graph_runner.replay(batch_size, inputs)
else:
with MoeLoadBalancerIterContext(moe_load_balancer):
outputs = maybe_graph.run(inputs)
outputs = self.cuda_graph_runner.replay(
batch_size, inputs)
self._execute_logit_post_processors(scheduled_requests, outputs)

View File

@ -140,6 +140,8 @@ class PyTorchModelEngineTestCase(unittest.TestCase):
def test_pad_generation_requests(self) -> None:
model_engine, kv_cache_manager = create_model_engine_and_kvcache()
resource_manager = ResourceManager(
{ResourceManagerType.KV_CACHE_MANAGER: kv_cache_manager})
seqlens_and_batch_sizes = [
(5, 1),
@ -155,8 +157,8 @@ class PyTorchModelEngineTestCase(unittest.TestCase):
batch.generation_requests = []
pages_before = kv_cache_manager.get_num_free_blocks()
with model_engine._maybe_pad_batch(
batch, kv_cache_manager) as padded_batch:
with model_engine.cuda_graph_runner.pad_batch(
batch, resource_manager) as padded_batch:
# No padding for prefill
self.assertIs(batch, padded_batch)
self.assertEqual(kv_cache_manager.get_num_free_blocks(),
@ -166,9 +168,9 @@ class PyTorchModelEngineTestCase(unittest.TestCase):
batch.context_requests = []
batch.generation_requests = requests
pages_before = kv_cache_manager.get_num_free_blocks()
new_dummy_block = 1 if model_engine.cuda_graph_dummy_request is None else 0
with model_engine._maybe_pad_batch(
batch, kv_cache_manager) as padded_batch:
new_dummy_block = 1 if model_engine.cuda_graph_runner.padding_dummy_request is None else 0
with model_engine.cuda_graph_runner.pad_batch(
batch, resource_manager) as padded_batch:
if batch_size < 8 and max_seq_len < 25:
self.assertEqual(
len(padded_batch.generation_requests) % 8, 0)

View File

@ -162,3 +162,32 @@ def reference_block_scale_moe_torch(
results[batch_idx] += final_scales[batch_idx, nth_expert, None] * output
return results.view_as(x)
class MockPytorchBackendConfig:
def __init__(self, use_cuda_graph, cuda_graph_padding_enabled):
self.use_cuda_graph = use_cuda_graph
self.cuda_graph_padding_enabled = cuda_graph_padding_enabled
class MockEngine:
"""A replacement for SimpleNamespace that supports weak references."""
def __init__(self, **kwargs):
self.__dict__.update(kwargs)
def create_mock_engine(batch_size: int):
return MockEngine(
pytorch_backend_config=MockPytorchBackendConfig(
use_cuda_graph=True, cuda_graph_padding_enabled=False),
_cuda_graph_batch_sizes=[batch_size],
_max_cuda_graph_batch_size=batch_size,
max_beam_width=1,
is_spec_decode=False,
spec_config=None,
_cuda_graph_mem_pool=None,
use_mrope=False,
)
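In the unit tests the runner no longer takes (batch_size, device, attn_metadata) directly; a lightweight mock engine supplies the configuration that __init__ reads. The intended test-side flow, as used in the model tests below, looks roughly like this sketch, where model, input_ids, position_ids, and attn_metadata stand in for objects built by the surrounding test:

mock_engine = create_mock_engine(1)            # one supported batch size, padding disabled
graph_runner = CUDAGraphRunner(mock_engine)

inputs = {
    "input_ids": input_ids,
    "position_ids": position_ids,
    "attn_metadata": attn_metadata,
}
graph_runner.capture(1, lambda inp: model.forward(**inp), inputs)
attn_metadata.prepare()                        # prepare metadata before every replay
logits = graph_runner.replay(1, inputs)
graph_runner.clear()                           # release graphs and the shared memory pool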

View File

@ -22,6 +22,7 @@ except ImportError:
# TODO: Remove this once we have a proper config for Exaone4
SKIP_EXAONE4_HF_ACCURACY_TEST = True
from _torch.helpers import create_mock_engine
from transformers.cache_utils import HybridCache
from utils.util import getSMVersion
@ -30,8 +31,7 @@ from tensorrt_llm._torch.attention_backend.utils import get_attention_backend
from tensorrt_llm._torch.metadata import KVCacheParams
from tensorrt_llm._torch.model_config import ModelConfig
from tensorrt_llm._torch.models.modeling_exaone4 import Exaone4ForCausalLM
from tensorrt_llm._torch.pyexecutor.cuda_graph_runner import \
DecodingCUDAGraphRunner
from tensorrt_llm._torch.pyexecutor.cuda_graph_runner import CUDAGraphRunner
from tensorrt_llm._torch.pyexecutor.resource_manager import KVCacheManager
from tensorrt_llm.bindings.executor import KvCacheConfig
from tensorrt_llm.mapping import Mapping
@ -338,6 +338,11 @@ class TestEXAONE4(unittest.TestCase):
]
gen_position_ids = torch.cat(gen_position_ids).unsqueeze(0).cuda()
graph_runner = None
if scenario.use_cuda_graph:
mock_engine = create_mock_engine(1)
graph_runner = CUDAGraphRunner(mock_engine)
def run_forward(input_ids, position_ids, attn_metadata):
attn_metadata.prepare()
if not scenario.use_cuda_graph:
@ -345,19 +350,20 @@ class TestEXAONE4(unittest.TestCase):
position_ids=position_ids,
attn_metadata=attn_metadata)
else:
graph_runner = DecodingCUDAGraphRunner(
attn_metadata.max_num_requests, "cuda", attn_metadata)
graph_runner.capture(lambda inputs: exaone4.forward(**inputs))
inputs = {
"input_ids": input_ids,
"position_ids": position_ids,
"attn_metadata": attn_metadata,
}
graph_runner.capture(1,
lambda inputs: exaone4.forward(**inputs),
inputs)
for _ in range(2):
# Run it twice. This helps us catch problems if buffers are accidentally reallocated
# in prepare().
attn_metadata.prepare()
logits = graph_runner.run({
"input_ids": input_ids,
"position_ids": position_ids,
"attn_metadata": attn_metadata,
})
logits = graph_runner.replay(1, inputs)
return logits
if scenario.use_cuda_graph:
@ -380,5 +386,6 @@ class TestEXAONE4(unittest.TestCase):
ref.logits[:, -1].float(),
atol=0.4,
rtol=0.4)
if graph_runner is not None:
graph_runner.clear()
kv_cache_manager.shutdown()

View File

@ -4,6 +4,7 @@ from dataclasses import dataclass
from typing import Any
import torch
from _torch.helpers import create_mock_engine
from parameterized import parameterized
from transformers import LlamaConfig
from transformers import LlamaForCausalLM as HFLlamaForCausalLM
@ -14,8 +15,7 @@ from tensorrt_llm._torch.attention_backend.utils import get_attention_backend
from tensorrt_llm._torch.metadata import KVCacheParams
from tensorrt_llm._torch.model_config import ModelConfig
from tensorrt_llm._torch.models.modeling_llama import LlamaForCausalLM
from tensorrt_llm._torch.pyexecutor.cuda_graph_runner import \
DecodingCUDAGraphRunner
from tensorrt_llm._torch.pyexecutor.cuda_graph_runner import CUDAGraphRunner
from tensorrt_llm._torch.pyexecutor.resource_manager import KVCacheManager
from tensorrt_llm.bindings.executor import KvCacheConfig
from tensorrt_llm.mapping import Mapping
@ -326,6 +326,11 @@ class TestLlama(unittest.TestCase):
]
gen_position_ids = torch.cat(gen_position_ids).unsqueeze(0).cuda()
graph_runner = None
if scenario.use_cuda_graph:
mock_engine = create_mock_engine(1)
graph_runner = CUDAGraphRunner(mock_engine)
def run_forward(input_ids, position_ids, attn_metadata):
attn_metadata.prepare()
if not scenario.use_cuda_graph:
@ -333,19 +338,18 @@ class TestLlama(unittest.TestCase):
position_ids=position_ids,
attn_metadata=attn_metadata)
else:
graph_runner = DecodingCUDAGraphRunner(
attn_metadata.max_num_requests, "cuda", attn_metadata)
graph_runner.capture(lambda inputs: llama.forward(**inputs))
inputs = {
"input_ids": input_ids,
"position_ids": position_ids,
"attn_metadata": attn_metadata,
}
graph_runner.capture(1, lambda inputs: llama.forward(**inputs),
inputs)
for _ in range(2):
# Run it twice. This helps us catch problems if buffers are accidentally reallocated
# in prepare().
attn_metadata.prepare()
logits = graph_runner.run({
"input_ids": input_ids,
"position_ids": position_ids,
"attn_metadata": attn_metadata,
})
logits = graph_runner.replay(1, inputs)
return logits
if scenario.use_cuda_graph:
@ -364,5 +368,6 @@ class TestLlama(unittest.TestCase):
ref.logits[:, -1].float(),
atol=0.4,
rtol=0.4)
if graph_runner is not None:
graph_runner.clear()
kv_cache_manager.shutdown()

View File

@ -4,6 +4,7 @@ from dataclasses import dataclass
import torch
import transformers
from _torch.helpers import create_mock_engine
from parameterized import parameterized
from transformers import Llama4Config
from transformers import \
@ -20,8 +21,7 @@ from tensorrt_llm._torch.models.checkpoints.hf.llama4_weight_mapper import \
from tensorrt_llm._torch.models.modeling_llama import \
Llama4ForConditionalGeneration
from tensorrt_llm._torch.pyexecutor.config import PyTorchConfig
from tensorrt_llm._torch.pyexecutor.cuda_graph_runner import \
DecodingCUDAGraphRunner
from tensorrt_llm._torch.pyexecutor.cuda_graph_runner import CUDAGraphRunner
from tensorrt_llm._torch.pyexecutor.resource_manager import KVCacheManager
from tensorrt_llm.bindings.executor import KvCacheConfig
from tensorrt_llm.mapping import Mapping
@ -403,6 +403,10 @@ class TestLlama4MinLatency(unittest.TestCase):
input_ids.size(-1) + gen_input_ids.size(-1))
]
gen_position_ids = torch.cat(gen_position_ids).unsqueeze(0).cuda()
graph_runner = None
if scenario.use_cuda_graph:
mock_engine = create_mock_engine(1)
graph_runner = CUDAGraphRunner(mock_engine)
def run_forward(input_ids, position_ids, attn_metadata):
attn_metadata.prepare()
@ -411,19 +415,19 @@ class TestLlama4MinLatency(unittest.TestCase):
position_ids=position_ids,
attn_metadata=attn_metadata)
else:
graph_runner = DecodingCUDAGraphRunner(
attn_metadata.max_num_requests, "cuda", attn_metadata)
graph_runner.capture(lambda inputs: llama.forward(**inputs))
inputs = {
"input_ids": input_ids,
"position_ids": position_ids,
"attn_metadata": attn_metadata,
}
graph_runner.capture(1, lambda inputs: llama.forward(**inputs),
inputs)
for _ in range(2):
# Run it twice. This helps us catch problems if buffers are accidentally reallocated
# in prepare().
attn_metadata.prepare()
logits = graph_runner.run({
"input_ids": input_ids,
"position_ids": position_ids,
"attn_metadata": attn_metadata,
})
logits = graph_runner.replay(1, inputs)
return logits
if scenario.use_cuda_graph:

View File

@ -7,6 +7,7 @@ from unittest import mock
import pytest
import torch
import transformers
from _torch.helpers import create_mock_engine
from utils.util import getSMVersion
import tensorrt_llm
@ -15,7 +16,8 @@ from tensorrt_llm._torch import metadata as metadata_lib
from tensorrt_llm._torch import model_config as model_config_lib
from tensorrt_llm._torch.attention_backend import utils as attention_utils
from tensorrt_llm._torch.models import modeling_mistral
from tensorrt_llm._torch.pyexecutor import cuda_graph_runner, resource_manager
from tensorrt_llm._torch.pyexecutor import resource_manager
from tensorrt_llm._torch.pyexecutor.cuda_graph_runner import CUDAGraphRunner
from tensorrt_llm.bindings import executor as executor_lib
from tensorrt_llm.models import modeling_utils
@ -398,6 +400,11 @@ def test_mistral_3_vlm_allclose_to_hf(mistral_small_3_1_24b_config, backend, use
]
gen_position_ids = torch.cat(gen_position_ids).unsqueeze(0).cuda()
graph_runner = None
if use_cuda_graph:
mock_engine = create_mock_engine(1)
graph_runner = CUDAGraphRunner(mock_engine)
def run_forward(input_ids, position_ids, attn_metadata):
attn_metadata.prepare()
if not use_cuda_graph:
@ -405,22 +412,18 @@ def test_mistral_3_vlm_allclose_to_hf(mistral_small_3_1_24b_config, backend, use
input_ids=input_ids, position_ids=position_ids, attn_metadata=attn_metadata
)
else:
graph_runner = cuda_graph_runner.DecodingCUDAGraphRunner(
attn_metadata.max_num_requests, "cuda", attn_metadata
)
graph_runner.capture(lambda inputs: mistral.forward(**inputs))
inputs = {
"input_ids": input_ids,
"position_ids": position_ids,
"attn_metadata": attn_metadata,
}
graph_runner.capture(1, lambda inputs: mistral.forward(**inputs), inputs)
for _ in range(2):
# Run it twice. This helps us catch problems if buffers are accidentally reallocated
# in prepare().
attn_metadata.prepare()
logits = graph_runner.run(
{
"input_ids": input_ids,
"position_ids": position_ids,
"attn_metadata": attn_metadata,
}
)
logits = graph_runner.replay(1, inputs)
return logits
if use_cuda_graph:
@ -438,3 +441,5 @@ def test_mistral_3_vlm_allclose_to_hf(mistral_small_3_1_24b_config, backend, use
)
torch.testing.assert_close(logits, ref.logits[:, -1].float(), atol=0.4, rtol=0.4)
if graph_runner is not None:
graph_runner.clear()

View File

@ -3,6 +3,7 @@ from copy import deepcopy
from dataclasses import dataclass
import torch
from _torch.helpers import create_mock_engine
from parameterized import parameterized
from transformers import MixtralConfig
from transformers import MixtralForCausalLM as HFMixtralForCausalLM
@ -15,8 +16,7 @@ from tensorrt_llm._torch.model_config import ModelConfig
from tensorrt_llm._torch.models.checkpoints.hf.mixtral_weight_mapper import \
MixtralHfWeightMapper
from tensorrt_llm._torch.models.modeling_mixtral import MixtralForCausalLM
from tensorrt_llm._torch.pyexecutor.cuda_graph_runner import \
DecodingCUDAGraphRunner
from tensorrt_llm._torch.pyexecutor.cuda_graph_runner import CUDAGraphRunner
from tensorrt_llm._torch.pyexecutor.resource_manager import KVCacheManager
from tensorrt_llm.bindings.executor import KvCacheConfig
from tensorrt_llm.mapping import Mapping
@ -310,6 +310,11 @@ class TestMixtral(unittest.TestCase):
]
gen_position_ids = torch.cat(gen_position_ids).unsqueeze(0).cuda()
graph_runner = None
if scenario.use_cuda_graph:
mock_engine = create_mock_engine(1)
graph_runner = CUDAGraphRunner(mock_engine)
def run_forward(input_ids, position_ids, attn_metadata):
attn_metadata.prepare()
if not scenario.use_cuda_graph:
@ -317,19 +322,20 @@ class TestMixtral(unittest.TestCase):
position_ids=position_ids,
attn_metadata=attn_metadata)
else:
graph_runner = DecodingCUDAGraphRunner(
attn_metadata.max_num_requests, "cuda", attn_metadata)
graph_runner.capture(lambda inputs: mixtral.forward(**inputs))
inputs = {
"input_ids": input_ids,
"position_ids": position_ids,
"attn_metadata": attn_metadata,
}
graph_runner.capture(1,
lambda inputs: mixtral.forward(**inputs),
inputs)
for _ in range(2):
# Run it twice. This helps us catch problems if buffers are accidentally reallocated
# in prepare().
attn_metadata.prepare()
logits = graph_runner.run({
"input_ids": input_ids,
"position_ids": position_ids,
"attn_metadata": attn_metadata,
})
logits = graph_runner.replay(1, inputs)
return logits
if scenario.use_cuda_graph:
@ -348,5 +354,6 @@ class TestMixtral(unittest.TestCase):
ref.logits[:, -1].float(),
atol=0.1,
rtol=0.1)
if graph_runner is not None:
graph_runner.clear()
kv_cache_manager.shutdown()

View File

@ -3,6 +3,7 @@ import unittest
from copy import deepcopy
import torch
from _torch.helpers import create_mock_engine
from parameterized import parameterized
from test_modeling_llama import Scenario, reduce_llama_config
from transformers import MllamaConfig
@ -15,8 +16,7 @@ from tensorrt_llm._torch.metadata import KVCacheParams
from tensorrt_llm._torch.model_config import ModelConfig
from tensorrt_llm._torch.models.modeling_mllama import \
MllamaForConditionalGeneration
from tensorrt_llm._torch.pyexecutor.cuda_graph_runner import \
DecodingCUDAGraphRunner
from tensorrt_llm._torch.pyexecutor.cuda_graph_runner import CUDAGraphRunner
from tensorrt_llm._torch.pyexecutor.resource_manager import KVCacheManager
from tensorrt_llm.bindings.executor import KvCacheConfig
from tensorrt_llm.mapping import Mapping
@ -417,6 +417,11 @@ class TestMLlama(unittest.TestCase):
]
gen_position_ids = torch.cat(gen_position_ids).unsqueeze(0).cuda()
graph_runner = None
if scenario.use_cuda_graph:
mock_engine = create_mock_engine(1)
graph_runner = CUDAGraphRunner(mock_engine)
def run_forward(input_ids, position_ids, attn_metadata):
attn_metadata.prepare()
if not scenario.use_cuda_graph:
@ -424,19 +429,19 @@ class TestMLlama(unittest.TestCase):
position_ids=position_ids,
attn_metadata=attn_metadata)
else:
graph_runner = DecodingCUDAGraphRunner(
attn_metadata.max_num_requests, "cuda", attn_metadata)
graph_runner.capture(lambda inputs: mllama.forward(**inputs))
inputs = {
"input_ids": input_ids,
"position_ids": position_ids,
"attn_metadata": attn_metadata,
}
graph_runner.capture(1, lambda inputs: mllama.forward(**inputs),
inputs)
for _ in range(2):
# Run it twice. This helps us catch problems if buffers are accidentally reallocated
# in prepare().
attn_metadata.prepare()
logits = graph_runner.run({
"input_ids": input_ids,
"position_ids": position_ids,
"attn_metadata": attn_metadata,
})
logits = graph_runner.replay(1, inputs)
return logits
if scenario.use_cuda_graph:
@ -455,3 +460,6 @@ class TestMLlama(unittest.TestCase):
ref.logits[:, -1].float(),
atol=0.3,
rtol=0.3)
if graph_runner is not None:
graph_runner.clear()
kv_cache_manager.shutdown()

View File

@ -4,6 +4,7 @@ from dataclasses import dataclass
from typing import Any
import torch
from _torch.helpers import create_mock_engine
from parameterized import parameterized
from transformers import NemotronConfig
from transformers import NemotronForCausalLM as HFNemotronForCausalLM
@ -14,8 +15,7 @@ from tensorrt_llm._torch.attention_backend.utils import get_attention_backend
from tensorrt_llm._torch.metadata import KVCacheParams
from tensorrt_llm._torch.model_config import ModelConfig
from tensorrt_llm._torch.models.modeling_nemotron import NemotronForCausalLM
from tensorrt_llm._torch.pyexecutor.cuda_graph_runner import \
DecodingCUDAGraphRunner
from tensorrt_llm._torch.pyexecutor.cuda_graph_runner import CUDAGraphRunner
from tensorrt_llm._torch.pyexecutor.resource_manager import KVCacheManager
from tensorrt_llm.bindings.executor import KvCacheConfig
from tensorrt_llm.mapping import Mapping
@ -318,6 +318,11 @@ class TestNemotron(unittest.TestCase):
]
gen_position_ids = torch.cat(gen_position_ids).unsqueeze(0).cuda()
graph_runner = None
if scenario.use_cuda_graph:
mock_engine = create_mock_engine(1)
graph_runner = CUDAGraphRunner(mock_engine)
def run_forward(input_ids, position_ids, attn_metadata):
attn_metadata.prepare()
if not scenario.use_cuda_graph:
@ -325,19 +330,20 @@ class TestNemotron(unittest.TestCase):
position_ids=position_ids,
attn_metadata=attn_metadata)
else:
graph_runner = DecodingCUDAGraphRunner(
attn_metadata.max_num_requests, "cuda", attn_metadata)
graph_runner.capture(lambda inputs: nemotron.forward(**inputs))
inputs = {
"input_ids": input_ids,
"position_ids": position_ids,
"attn_metadata": attn_metadata,
}
graph_runner.capture(1,
lambda inputs: nemotron.forward(**inputs),
inputs)
for _ in range(2):
# Run it twice. This helps us catch problems if buffers are accidentally reallocated
# in prepare().
attn_metadata.prepare()
logits = graph_runner.run({
"input_ids": input_ids,
"position_ids": position_ids,
"attn_metadata": attn_metadata,
})
logits = graph_runner.replay(1, inputs)
return logits
if scenario.use_cuda_graph:
@ -357,4 +363,6 @@ class TestNemotron(unittest.TestCase):
atol=0.4,
rtol=0.4)
if graph_runner is not None:
graph_runner.clear()
kv_cache_manager.shutdown()

View File

@ -4,6 +4,7 @@ from dataclasses import dataclass
from typing import Any
import torch
from _torch.helpers import create_mock_engine
from transformers import Phi3Config
from transformers import Phi3ForCausalLM as HFPhi3ForCausalLM
from utils.util import default_dtype
@ -13,8 +14,7 @@ from tensorrt_llm._torch.attention_backend.utils import get_attention_backend
from tensorrt_llm._torch.metadata import KVCacheParams
from tensorrt_llm._torch.model_config import ModelConfig
from tensorrt_llm._torch.models.modeling_phi3 import Phi3ForCausalLM
from tensorrt_llm._torch.pyexecutor.cuda_graph_runner import \
DecodingCUDAGraphRunner
from tensorrt_llm._torch.pyexecutor.cuda_graph_runner import CUDAGraphRunner
from tensorrt_llm._torch.pyexecutor.resource_manager import KVCacheManager
from tensorrt_llm.bindings.executor import KvCacheConfig
from tensorrt_llm.mapping import Mapping
@ -310,6 +310,11 @@ class TestPhi3(unittest.TestCase):
]
gen_position_ids = torch.cat(gen_position_ids).unsqueeze(0).cuda()
graph_runner = None
if scenario.use_cuda_graph:
mock_engine = create_mock_engine(1)
graph_runner = CUDAGraphRunner(mock_engine)
def run_forward(input_ids, position_ids, attn_metadata):
attn_metadata.prepare()
if not scenario.use_cuda_graph:
@ -317,19 +322,19 @@ class TestPhi3(unittest.TestCase):
position_ids=position_ids,
attn_metadata=attn_metadata)
else:
graph_runner = DecodingCUDAGraphRunner(
attn_metadata.max_num_requests, "cuda", attn_metadata)
graph_runner.capture(lambda inputs: phi3.forward(**inputs))
inputs = {
"input_ids": input_ids,
"position_ids": position_ids,
"attn_metadata": attn_metadata,
}
graph_runner.capture(1, lambda inputs: phi3.forward(**inputs),
inputs)
for _ in range(2):
# Run it twice. This helps us catch problems if buffers are accidentally reallocated
# in prepare().
attn_metadata.prepare()
logits = graph_runner.run({
"input_ids": input_ids,
"position_ids": position_ids,
"attn_metadata": attn_metadata,
})
logits = graph_runner.replay(1, inputs)
return logits
if scenario.use_cuda_graph:
@ -348,5 +353,6 @@ class TestPhi3(unittest.TestCase):
ref.logits[:, -1].float(),
atol=0.4,
rtol=0.4)
if graph_runner is not None:
graph_runner.clear()
kv_cache_manager.shutdown()

View File

@ -17,12 +17,12 @@ from tensorrt_llm._torch.model_config import ModelConfig
from tensorrt_llm._torch.models.modeling_qwen import (
Qwen2ForCausalLM, Qwen2ForProcessRewardModel)
# yapf: enable
from tensorrt_llm._torch.pyexecutor.cuda_graph_runner import \
DecodingCUDAGraphRunner
from _torch.helpers import create_mock_engine
from tensorrt_llm._torch.pyexecutor.resource_manager import KVCacheManager
from tensorrt_llm.bindings.executor import KvCacheConfig
from tensorrt_llm.mapping import Mapping
from tensorrt_llm.models.modeling_utils import QuantConfig
from tensorrt_llm._torch.pyexecutor.cuda_graph_runner import CUDAGraphRunner
from utils.llm_data import llm_models_root
from utils.util import getSMVersion
@ -265,6 +265,11 @@ class TestQwen(unittest.TestCase):
]
gen_position_ids = torch.cat(gen_position_ids).unsqueeze(0).cuda()
graph_runner = None
if scenario.use_cuda_graph:
mock_engine = create_mock_engine(1)
graph_runner = CUDAGraphRunner(mock_engine)
def run_forward(input_ids, position_ids, attn_metadata):
attn_metadata.prepare()
if not scenario.use_cuda_graph:
@ -272,19 +277,19 @@ class TestQwen(unittest.TestCase):
position_ids=position_ids,
attn_metadata=attn_metadata)
else:
graph_runner = DecodingCUDAGraphRunner(
attn_metadata.max_num_requests, "cuda", attn_metadata)
graph_runner.capture(lambda inputs: qwen.forward(**inputs))
inputs = {
"input_ids": input_ids,
"position_ids": position_ids,
"attn_metadata": attn_metadata,
}
graph_runner.capture(1, lambda inputs: qwen.forward(**inputs),
inputs)
for _ in range(2):
# Run it twice. This helps us catch problems if buffers are accidentally reallocated
# in prepare().
attn_metadata.prepare()
logits = graph_runner.run({
"input_ids": input_ids,
"position_ids": position_ids,
"attn_metadata": attn_metadata,
})
logits = graph_runner.replay(1, inputs)
return logits
if scenario.use_cuda_graph:
@ -303,7 +308,8 @@ class TestQwen(unittest.TestCase):
ref.logits[:, -1].float(),
atol=0.4,
rtol=0.4)
if graph_runner is not None:
graph_runner.clear()
kv_cache_manager.shutdown()
@parameterized.expand(

View File

@ -3,6 +3,7 @@ from copy import deepcopy
from dataclasses import dataclass
import torch
from _torch.helpers import create_mock_engine
from parameterized import parameterized
from transformers import Qwen2MoeConfig
from transformers import Qwen2MoeForCausalLM as HFQwen2MoeForCausalLM
@ -15,8 +16,7 @@ from tensorrt_llm._torch.model_config import ModelConfig
from tensorrt_llm._torch.models.checkpoints.hf.qwen2_moe_weight_mapper import \
Qwen2MoeHfWeightMapper
from tensorrt_llm._torch.models.modeling_qwen_moe import Qwen2MoeForCausalLM
from tensorrt_llm._torch.pyexecutor.cuda_graph_runner import \
DecodingCUDAGraphRunner
from tensorrt_llm._torch.pyexecutor.cuda_graph_runner import CUDAGraphRunner
from tensorrt_llm._torch.pyexecutor.resource_manager import KVCacheManager
from tensorrt_llm.bindings.executor import KvCacheConfig
from tensorrt_llm.mapping import Mapping
@ -315,6 +315,11 @@ class TestQwenMoe(unittest.TestCase):
]
gen_position_ids = torch.cat(gen_position_ids).unsqueeze(0).cuda()
graph_runner = None
if scenario.use_cuda_graph:
mock_engine = create_mock_engine(1)
graph_runner = CUDAGraphRunner(mock_engine)
def run_forward(input_ids, position_ids, attn_metadata):
attn_metadata.prepare()
if not scenario.use_cuda_graph:
@ -322,19 +327,20 @@ class TestQwenMoe(unittest.TestCase):
position_ids=position_ids,
attn_metadata=attn_metadata)
else:
graph_runner = DecodingCUDAGraphRunner(
attn_metadata.max_num_requests, "cuda", attn_metadata)
graph_runner.capture(lambda inputs: qwen_moe.forward(**inputs))
inputs = {
"input_ids": input_ids,
"position_ids": position_ids,
"attn_metadata": attn_metadata,
}
graph_runner.capture(1,
lambda inputs: qwen_moe.forward(**inputs),
inputs)
for _ in range(2):
# Run it twice. This helps us catch problems if buffers are accidentally reallocated
# in prepare().
attn_metadata.prepare()
logits = graph_runner.run({
"input_ids": input_ids,
"position_ids": position_ids,
"attn_metadata": attn_metadata,
})
logits = graph_runner.replay(1, inputs)
return logits
if scenario.use_cuda_graph:
@ -353,5 +359,6 @@ class TestQwenMoe(unittest.TestCase):
ref.logits[:, -1].float(),
atol=0.1,
rtol=0.1)
if graph_runner is not None:
graph_runner.clear()
kv_cache_manager.shutdown()