Mirror of https://github.com/NVIDIA/TensorRT-LLM.git (synced 2026-01-14 06:27:45 +08:00)

[None][refactor] Refactor the CUDA graph runner to manage all CUDA graphs (#6846)

Signed-off-by: junq <22017000+QiJune@users.noreply.github.com>

parent: b32e00e9fd
commit: bea5e07fb7
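The diff below folds the per-graph `DecodingCUDAGraphRunner` objects into a single `CUDAGraphRunner` owned by the model engine. Both versions rest on the same underlying mechanism: capture a forward pass into a `torch.cuda.CUDAGraph` against fixed-address input buffers, then replay it after copying fresh data into those buffers. The sketch below is illustrative only (toy module, made-up names, requires a CUDA device); it is not code from this repository.

```python
import torch

# Minimal sketch of the static-buffer capture/replay pattern the runner uses.
assert torch.cuda.is_available()

model = torch.nn.Linear(16, 16).cuda()
model.requires_grad_(False)
static_input = torch.zeros(8, 16, device="cuda")   # fixed-address input buffer

# Warm-up runs on a side stream, as the PyTorch CUDA-graph notes recommend.
side_stream = torch.cuda.Stream()
side_stream.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(side_stream):
    for _ in range(2):
        model(static_input)
torch.cuda.current_stream().wait_stream(side_stream)

graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(graph):
    static_output = model(static_input)   # output buffer is also fixed

# Replay with new data: copy into the static buffer, then replay the graph.
new_batch = torch.randn(8, 16, device="cuda")
static_input.copy_(new_batch)
graph.replay()
print(static_output)   # now holds the result for new_batch
```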
@@ -1,122 +1,309 @@
from typing import Any, Callable, Dict, Optional, Tuple
import bisect
import contextlib
import weakref
from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Tuple

import torch

from ..attention_backend.interface import AttentionMetadata
from ..expert_statistic import ExpertStatistic
from ..modules.multi_stream_utils import with_multi_stream
from ..speculative.interface import SpecMetadata
from ..utils import make_weak_ref, piecewise_cuda_graph
from .resource_manager import ResourceManager, ResourceManagerType
from .scheduler import ScheduledRequests

if TYPE_CHECKING:
    from .model_engine import PyTorchModelEngine

# A reserved sentinel value (the largest uint64) used for dummy request IDs to avoid collisions
CUDA_GRAPH_DUMMY_REQUEST_ID = (1 << 64) - 1


class DecodingCUDAGraphRunner:
class CUDAGraphRunner:
    """
    Manages the lifecycle and execution of CUDA graphs for the model engine.

    def __init__(
        self,
        batch_size: int,
        device: str,
        attn_metadata: AttentionMetadata,
        spec_metadata: Optional[SpecMetadata] = None,
        use_mrope: bool = False,
        max_beam_width: int = 1,
    ) -> None:
    This unified class handles high-level orchestration (padding, eligibility)
    and low-level execution (capturing, resource management, replaying) for
    multiple graphs, keyed by (batch size, draft_len).
    """
    WARMUP_STEPS = 2

    def __init__(self, engine: "PyTorchModelEngine"):
        self.engine_ref = weakref.ref(engine)

        # High-level configuration
        config = engine.pytorch_backend_config
        self.enabled = config.use_cuda_graph
        self.padding_enabled = config.cuda_graph_padding_enabled
        self.supported_batch_sizes = engine._cuda_graph_batch_sizes
        self.max_supported_batch_size = engine._max_cuda_graph_batch_size
        self.max_beam_width = engine.max_beam_width
        self.spec_config = engine.spec_config

        self.graphs: Dict[Tuple[int, int], torch.cuda.CUDAGraph] = {}
        self.static_inputs: Dict[Tuple[int, int], Dict[str, torch.Tensor]] = {}
        self.graph_outputs: Dict[Tuple[int, int],
                                 Callable[[], Optional[torch.Tensor]]] = {}
        self.graph_metadata: Dict[Tuple[int, int], Dict[str, Any]] = {}
        self.memory_pool = engine._cuda_graph_mem_pool
        self.padding_dummy_request: Optional["Request"] = None

    @property
    def enable_spec_decode(self):
        return self._get_engine().is_spec_decode

    @property
    def draft_len(self):
        return self.spec_config.max_draft_len if self.enable_spec_decode else 0

    @property
    def spec_metadata(self):
        return self._get_engine().spec_metadata

    @property
    def draft_tokens_cuda(self):
        return self._get_engine().draft_tokens_cuda

    @property
    def attn_metadata(self):
        return self._get_engine().attn_metadata

    def __del__(self):
        self.clear()

    def _get_engine(self) -> "PyTorchModelEngine":
        """Safely dereferences the weak reference to the engine."""
        engine = self.engine_ref()
        if engine is None:
            raise RuntimeError(
                "The parent PyTorchModelEngine has been garbage collected.")
        return engine

    def maybe_get_cuda_graph(self, batch: ScheduledRequests):
        """
        Stores a CUDA graph and its associated input buffers.
        Determines if the current batch can be run with a CUDA graph.

        Each CUDA graph runner is associated with an AttentionMetadata object
        if flashinfer is being used. Make sure to call attn_metadata.prepare()
        before run()!

        Note that torch.compile w/ mode reduce-overhead supports CUDA graphs
        with memory pool sharing. However, we have our own manager here because,
        at the time of writing this, torch.compile takes way too long to warmup
        graphs compared to doing it manually (not to mention, custom ops from
        e.g. FlashInfer cause graph breaks).
        Returns a tuple containing:
        - A boolean indicating if a graph can be used.
        - The attn_metadata for the graph, if applicable.
        - The spec_metadata for the graph, if applicable.
        """
        self.batch_size = batch_size
        self.max_beam_width = max_beam_width
        engine = self._get_engine()

        # disable when doing statistic
        if hasattr(engine, 'iter_counter') and ExpertStatistic.set_iter(
                engine.iter_counter):
            return False, None, None

        can_run_cuda_graph = batch.can_run_cuda_graph
        batch_size = batch.batch_size
        if self.enabled and engine.enable_attention_dp and engine.mapping.tp_size > 1:
            all_can_graph_batch = engine.dist.tp_allgather(
                [can_run_cuda_graph, batch_size])
            is_all_gen_only = all(all_can_graph[0]
                                  for all_can_graph in all_can_graph_batch)
            all_batch_size_equal = all(
                all_gen_only[1] == all_can_graph_batch[0][1]
                for all_gen_only in all_can_graph_batch)

            if not is_all_gen_only or not all_batch_size_equal:
                return False, None, None

        if not self.enabled or not can_run_cuda_graph:
            return False, None, None

        key = (batch_size, self.draft_len)
        if key in self.graphs:
            return True, self.graph_metadata[key][
                "attn_metadata"], self.graph_metadata[key]["spec_metadata"]

        if batch_size not in self.supported_batch_sizes:
            return False, None, None

        num_sequences_in_batch = batch_size * self.max_beam_width
        attn_metadata = self.attn_metadata.create_cuda_graph_metadata(
            num_sequences_in_batch, False, self.draft_len)
        assert attn_metadata.is_cuda_graph

        if self.enable_spec_decode:
            spec_metadata = self.spec_metadata.create_cuda_graph_metadata(
                num_sequences_in_batch)
            spec_metadata.draft_tokens = self.draft_tokens_cuda
        else:
            spec_metadata = None
        return True, attn_metadata, spec_metadata

    def needs_capture(self, batch_size: int):
        return (batch_size, self.draft_len) not in self.graph_outputs

    def capture(self, batch_size: int, forward_fn: Callable,
                initial_inputs: Dict[str, Any]):
        """Captures the forward pass for a given batch size."""
        engine = self._get_engine()
        key = (batch_size, self.draft_len)
        spec_metadata = initial_inputs.get("spec_metadata", None)
        # [CUDA graph spec decode padding]
        # We pad input IDs/position IDs to the maximum draft length (token per request).
        # We're forced to do this because we cannot reallocate inputs over many graph runs.
        token_per_request = spec_metadata.max_draft_len + 1 if spec_metadata is not None else 1

        # Using ones instead of zeros prevents NaNs in e.g. Deepseek
        self.input_ids = torch.ones(
            (batch_size * max_beam_width * token_per_request, ),
            device=device,
            dtype=torch.int32)
        self.position_ids = torch.zeros(
            (1, batch_size * max_beam_width * token_per_request),
            device=device,
            dtype=torch.int32)
        self.mrope_position_deltas = torch.zeros(
            (batch_size,
             1), device=device, dtype=torch.int32) if use_mrope else None
        static_tensors = {
            "input_ids":
            torch.ones((batch_size * self.max_beam_width * token_per_request, ),
                       device="cuda",
                       dtype=torch.int32),
            "position_ids":
            torch.zeros((
                1,
                batch_size * self.max_beam_width * token_per_request,
            ),
                        device="cuda",
                        dtype=torch.int32),
        }
        if engine.use_mrope:
            static_tensors["mrope_position_deltas"] = torch.zeros(
                (batch_size, 1), device="cuda", dtype=torch.int32)
        self.static_inputs[key] = static_tensors

        self.attn_metadata = attn_metadata
        self.spec_metadata = spec_metadata
        self._output = None
        self._graph = None
        self.optional_extra_model_inputs = ["mrope_position_deltas"]
        capture_inputs = initial_inputs.copy()
        capture_inputs.update(static_tensors)

    def __del__(self):
        self._graph.reset()

    def capture(
        self,
        forward_fn: Callable[[Dict[str, Any]], torch.Tensor],
        pool: Optional[Tuple[int, int]] = None,
    ) -> Tuple[int, int]:
        self._graph = torch.cuda.CUDAGraph()
        inputs = {
            "attn_metadata": self.attn_metadata,
            "input_ids": self.input_ids,
            "position_ids": self.position_ids,
            "inputs_embeds": None,
            "spec_metadata": self.spec_metadata,
            "mrope_position_deltas": self.mrope_position_deltas,
        self.graph_metadata[key] = {
            "attn_metadata": initial_inputs["attn_metadata"],
            "spec_metadata": spec_metadata,
        }

        # We have to do warm up runs to initialize PyTorch's
        # internal states according to the docs:
        # https://pytorch.org/docs/stable/notes/cuda.html#cuda-graph-semantics
        # This also lets us initialize states in the attn_metadata.
        graph = torch.cuda.CUDAGraph()
        with with_multi_stream(True), piecewise_cuda_graph(False):
            for _ in range(2):
                forward_fn(inputs)
            with torch.cuda.graph(self._graph, pool=pool):
                output = forward_fn(inputs)
                # Mark weak ref here. The output tensor should be freed properly.
                self._output = make_weak_ref(output)
            return self._graph.pool()
            for _ in range(self.WARMUP_STEPS):
                forward_fn(capture_inputs)
            with torch.cuda.graph(graph, pool=self.memory_pool):
                output = forward_fn(capture_inputs)

    def needs_capture(self) -> bool:
        return self._output is None
        self.graphs[key] = graph
        self.graph_outputs[key] = make_weak_ref(output)
        self.memory_pool = graph.pool()

    def run(self, inputs: Dict[str, Any]) -> torch.Tensor:
        assert "input_ids" in inputs
        assert "position_ids" in inputs
        assert "attn_metadata" in inputs
    def replay(self, batch_size: int,
               current_inputs: Dict[str, Any]) -> Optional[torch.Tensor]:
        """Replays a previously captured graph."""
        key = (batch_size, self.draft_len)
        stored_meta = self.graph_metadata[key]
        assert current_inputs["attn_metadata"] is stored_meta["attn_metadata"]
        if stored_meta["spec_metadata"] is not None:
            assert current_inputs.get(
                "spec_metadata") is stored_meta["spec_metadata"]

        attn_metadata = inputs["attn_metadata"]
        assert attn_metadata is self.attn_metadata, (
            "attn_metadata does not match the attn_metadata instance that was used to "
            "capture this graph.")
        static_tensors = self.static_inputs[key]

        if "spec_metadata" in inputs:
            spec_metadata = inputs["spec_metadata"]
            assert spec_metadata is self.spec_metadata, (
                "spec_metadata does not match the spec_metadata instance that was used to "
                "capture this graph.")

        input_ids = inputs["input_ids"]
        position_ids = inputs["position_ids"]
        input_ids = current_inputs["input_ids"]
        seqlen = input_ids.shape[0]
        self.input_ids[:seqlen].copy_(input_ids)
        self.position_ids[:, :seqlen].copy_(position_ids)
        if "mrope_position_deltas" in inputs:
            self.mrope_position_deltas[:self.batch_size].copy_(
                inputs["mrope_position_deltas"])
        static_tensors["input_ids"][:seqlen].copy_(input_ids)

        assert self._output is not None and self._graph is not None
        self._graph.replay()
        return self._output
        position_ids = current_inputs["position_ids"]
        static_tensors["position_ids"][:, :seqlen].copy_(position_ids)

        if "mrope_position_deltas" in current_inputs:
            assert "mrope_position_deltas" in static_tensors
            static_tensors["mrope_position_deltas"][:batch_size].copy_(
                current_inputs["mrope_position_deltas"])

        self.graphs[key].replay()
        output_ref = self.graph_outputs[key]

        return output_ref

    def _get_padded_batch(self, batch: ScheduledRequests,
                          resource_manager: ResourceManager) -> int:
        engine = self._get_engine()
        kv_cache_manager = resource_manager.get_resource_manager(
            engine.kv_cache_manager_key)
        can_run_cuda_graph = batch.can_run_cuda_graph
        batch_size = batch.batch_size
        new_batch_size = batch_size

        if self.enabled and engine.enable_attention_dp and engine.mapping.tp_size > 1:
            graph_batch_size = engine.dist.tp_allgather(
                [can_run_cuda_graph, batch_size])
            all_can_graph = all(graph_batch[0]
                                for graph_batch in graph_batch_size)
            if all_can_graph:
                new_batch_size = max(gen_only_batch[1]
                                     for gen_only_batch in graph_batch_size)

        if (not self.enabled or not self.padding_enabled
                or not can_run_cuda_graph
                or new_batch_size > self.max_supported_batch_size):
            return 0

        padded_batch_size = self._round_up_batch_size(new_batch_size)
        if batch_size == padded_batch_size:
            return 0

        padding_size = padded_batch_size - batch_size
        if padding_size + batch.batch_size > engine.batch_size:
            return 0

        # No padding if it would create too many concurrent requests.
        # This is not strictly required, but we should probably
        # respect the requirement just in case that changes in the future.
        if self.padding_dummy_request is None:
            available_blocks = kv_cache_manager.get_num_free_blocks()
            # No padding if not enough KV cache space
            if available_blocks < 1:
                return 0

            self.padding_dummy_request = kv_cache_manager.add_dummy_requests(
                [CUDA_GRAPH_DUMMY_REQUEST_ID],
                is_gen=True,
                max_num_draft_tokens=engine.max_draft_len,
                use_mrope=engine.use_mrope,
                max_beam_width=engine.max_beam_width)[0]
            self.padding_dummy_request.is_cuda_graph_dummy = True
            spec_res_mgr = resource_manager.get_resource_manager(
                ResourceManagerType.SPEC_RESOURCE_MANAGER)
            if spec_res_mgr:
                spec_res_mgr.add_dummy_requests([CUDA_GRAPH_DUMMY_REQUEST_ID])

        batch.generation_requests.extend([self.padding_dummy_request] *
                                         padding_size)
        return padding_size

    def _round_up_batch_size(self, batch_size: int) -> int:
        """Finds the smallest supported graph batch size >= the given size."""
        if not self.supported_batch_sizes:
            return 0
        idx = bisect.bisect_left(self.supported_batch_sizes, batch_size)
        if idx == len(self.supported_batch_sizes):
            return 0
        return self.supported_batch_sizes[idx]

    @contextlib.contextmanager
    def pad_batch(self, scheduled_requests: ScheduledRequests,
                  resource_manager: ResourceManager):
        """Context manager to pad a batch to a graph-compatible size."""

        padding_size = self._get_padded_batch(scheduled_requests,
                                              resource_manager)
        try:
            yield scheduled_requests
        finally:
            if padding_size > 0:
                scheduled_requests.generation_requests = scheduled_requests.generation_requests[:
                                                                                                -padding_size]

    def clear(self):
        """Releases all captured graphs and the associated memory pool."""
        for graph in self.graphs.values():
            graph.reset()
        self.graphs.clear()
        self.static_inputs.clear()
        self.graph_outputs.clear()
        self.graph_metadata.clear()
        del self.memory_pool
        self.memory_pool = None
        torch.cuda.empty_cache()

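Two details of the class above worth calling out: graphs, static buffers, outputs, and metadata are all cached in plain dicts keyed by `(batch_size, draft_len)`, and `_round_up_batch_size` rounds a batch up to the nearest captured size with `bisect.bisect_left` over the sorted list of supported sizes. A self-contained sketch of that rounding rule (the list of sizes below is hypothetical):

```python
import bisect

def round_up_batch_size(supported_sizes, batch_size):
    """Smallest supported CUDA-graph batch size >= batch_size, or 0 if none."""
    if not supported_sizes:
        return 0
    idx = bisect.bisect_left(supported_sizes, batch_size)
    if idx == len(supported_sizes):
        return 0  # larger than every captured size -> caller falls back to eager
    return supported_sizes[idx]

# Example with a hypothetical set of captured sizes (must be sorted for bisect).
sizes = [1, 2, 4, 8, 16]
assert round_up_batch_size(sizes, 3) == 4
assert round_up_batch_size(sizes, 16) == 16
assert round_up_batch_size(sizes, 17) == 0
```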
@ -1,4 +1,3 @@
|
||||
import bisect
|
||||
import contextlib
|
||||
import copy
|
||||
import functools
|
||||
@ -57,7 +56,7 @@ from ..utils import (get_model_extra_attrs, set_torch_compiling,
|
||||
with_model_extra_attrs)
|
||||
from .config import LoadFormat, PyTorchConfig
|
||||
from .config_utils import is_mla
|
||||
from .cuda_graph_runner import DecodingCUDAGraphRunner
|
||||
from .cuda_graph_runner import CUDAGraphRunner
|
||||
from .layerwise_nvtx_marker import LayerwiseNvtxMarker
|
||||
from .llm_request import get_draft_token_length
|
||||
from .resource_manager import (BaseResourceManager, KVCacheManager,
|
||||
@ -422,7 +421,6 @@ class PyTorchModelEngine(ModelEngine):
|
||||
self.iter_states = {}
|
||||
self._cuda_graphs = {}
|
||||
self._cuda_graph_mem_pool = self._torch_compile_backend._graph_pool_handle if self._torch_compile_enabled else None
|
||||
self._run_cuda_graphs = pytorch_backend_config.use_cuda_graph
|
||||
|
||||
self._cuda_graph_padding_enabled = pytorch_backend_config.cuda_graph_padding_enabled
|
||||
|
||||
@ -451,7 +449,7 @@ class PyTorchModelEngine(ModelEngine):
|
||||
# with different KV cache managers.
|
||||
self.kv_cache_manager_key = ResourceManagerType.KV_CACHE_MANAGER
|
||||
self.lora_model_config: Optional[LoraModelConfig] = None
|
||||
self.cuda_graph_dummy_request = None
|
||||
self.cuda_graph_runner = CUDAGraphRunner(self)
|
||||
|
||||
# Setup the local cache indirection buffer only once and reuse it.
|
||||
# This way it can also be used for CUDA graphs.
|
||||
@ -541,12 +539,12 @@ class PyTorchModelEngine(ModelEngine):
|
||||
|
||||
@contextlib.contextmanager
|
||||
def no_cuda_graph(self):
|
||||
_run_cuda_graphs = self._run_cuda_graphs
|
||||
self._run_cuda_graphs = False
|
||||
_run_cuda_graphs = self.cuda_graph_runner.enabled
|
||||
self.cuda_graph_runner.enabled = False
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
self._run_cuda_graphs = _run_cuda_graphs
|
||||
self.cuda_graph_runner.enabled = _run_cuda_graphs
|
||||
|
||||
@with_warmup_flag
|
||||
def warmup(self, resource_manager: ResourceManager) -> None:
|
||||
@ -561,7 +559,7 @@ class PyTorchModelEngine(ModelEngine):
|
||||
|
||||
# The lifetime of model engine and kv cache manager can be different.
|
||||
# Reset the global cuda graph dummy request to None in warmup.
|
||||
self.cuda_graph_dummy_request = None
|
||||
self.cuda_graph_runner.padding_dummy_request = None
|
||||
|
||||
def get_cuda_graph_warmup_request(batch_size, draft_len):
|
||||
# Divide by max_beam_width to get an approximation of the number of requests that can be run in parallel.
|
||||
@ -756,7 +754,7 @@ class PyTorchModelEngine(ModelEngine):
|
||||
|
||||
AutoTuner.get().print_profiling_cache()
|
||||
|
||||
if not (self._run_cuda_graphs
|
||||
if not (self.cuda_graph_runner.enabled
|
||||
or self._torch_compile_piecewise_cuda_graph):
|
||||
return
|
||||
|
||||
@ -889,152 +887,6 @@ class PyTorchModelEngine(ModelEngine):
|
||||
is_draft_model=self.is_draft_model)
|
||||
return self.spec_metadata
|
||||
|
||||
def _get_padded_batch(
|
||||
self,
|
||||
scheduled_requests: ScheduledRequests,
|
||||
kv_cache_manager,
|
||||
spec_resource_manager: Optional[BaseResourceManager] = None) -> int:
|
||||
can_run_cuda_graph = scheduled_requests.can_run_cuda_graph
|
||||
batch_size = scheduled_requests.batch_size
|
||||
new_batch_size = batch_size
|
||||
|
||||
if self._run_cuda_graphs and self.enable_attention_dp and self.mapping.tp_size > 1:
|
||||
graph_batch_size = self.dist.tp_allgather(
|
||||
[can_run_cuda_graph, batch_size])
|
||||
all_can_graph = all(graph_batch[0]
|
||||
for graph_batch in graph_batch_size)
|
||||
if all_can_graph:
|
||||
new_batch_size = max(gen_only_batch[1]
|
||||
for gen_only_batch in graph_batch_size)
|
||||
|
||||
if (not self._run_cuda_graphs or not self._cuda_graph_padding_enabled
|
||||
or not can_run_cuda_graph
|
||||
or new_batch_size > self._max_cuda_graph_batch_size):
|
||||
return 0
|
||||
|
||||
padded_batch_size = self._round_up_batch_size(new_batch_size)
|
||||
if batch_size == padded_batch_size:
|
||||
return 0
|
||||
|
||||
padding_size = padded_batch_size - batch_size
|
||||
if padding_size + scheduled_requests.batch_size > self.batch_size:
|
||||
return 0
|
||||
|
||||
# No padding if it would create too many concurrent requests.
|
||||
# This is not strictly required, but we should probably
|
||||
# respect the requirement just in case that changes in the future.
|
||||
if self.cuda_graph_dummy_request is None:
|
||||
available_blocks = kv_cache_manager.get_num_free_blocks()
|
||||
# No padding if not enough KV cache space
|
||||
if available_blocks < 1:
|
||||
return 0
|
||||
|
||||
cuda_graph_dummy_request_ids = [MAX_UINT64 - 1]
|
||||
self.cuda_graph_dummy_request = kv_cache_manager.add_dummy_requests(
|
||||
cuda_graph_dummy_request_ids,
|
||||
is_gen=True,
|
||||
max_num_draft_tokens=self.runtime_draft_len,
|
||||
use_mrope=self.use_mrope,
|
||||
max_beam_width=self.max_beam_width)[0]
|
||||
self.cuda_graph_dummy_request.is_cuda_graph_dummy = True
|
||||
if spec_resource_manager is not None:
|
||||
spec_resource_manager.add_dummy_requests(
|
||||
request_ids=cuda_graph_dummy_request_ids)
|
||||
|
||||
scheduled_requests.generation_requests.extend(
|
||||
[self.cuda_graph_dummy_request] * padding_size)
|
||||
|
||||
return padding_size
|
||||
|
||||
@contextlib.contextmanager
|
||||
def _maybe_pad_batch(
|
||||
self,
|
||||
scheduled_requests: ScheduledRequests,
|
||||
kv_cache_manager,
|
||||
spec_resource_manager: Optional[BaseResourceManager] = None):
|
||||
"""
|
||||
CUDA graphs can only be used for specific batch sizes.
|
||||
|
||||
If using CUDA graphs, this method will add dummy requests to the given
|
||||
batch so we can always use a CUDA graph. It is a context manager
|
||||
because the padded requests will be removed from scheduled requests.
|
||||
"""
|
||||
padding_size = self._get_padded_batch(scheduled_requests,
|
||||
kv_cache_manager,
|
||||
spec_resource_manager)
|
||||
try:
|
||||
yield scheduled_requests
|
||||
finally:
|
||||
if padding_size > 0:
|
||||
scheduled_requests.generation_requests = scheduled_requests.generation_requests[:
|
||||
-padding_size]
|
||||
|
||||
def _round_up_batch_size(self, batch_size: int) -> int:
|
||||
"""
|
||||
Round up the given batch size to the nearest batch size that is
|
||||
associated with a CUDA graph.
|
||||
"""
|
||||
idx = bisect.bisect_left(self._cuda_graph_batch_sizes, batch_size)
|
||||
return self._cuda_graph_batch_sizes[idx]
|
||||
|
||||
def _maybe_get_cuda_graph(
|
||||
self,
|
||||
batch: ScheduledRequests,
|
||||
) -> Optional[DecodingCUDAGraphRunner]:
|
||||
"""
|
||||
Get a CUDA graph runner or return None (e.g. if CUDA graphs are disabled
|
||||
or if the batch size is too big).
|
||||
"""
|
||||
# disable when doing statistic
|
||||
if ExpertStatistic.set_iter(self.iter_counter):
|
||||
return None
|
||||
|
||||
draft_len = self.spec_config.max_draft_len if self.enable_spec_decode else 0
|
||||
can_run_cuda_graph = batch.can_run_cuda_graph
|
||||
batch_size = len(batch.generation_requests)
|
||||
if self._run_cuda_graphs and self.enable_attention_dp and self.mapping.tp_size > 1:
|
||||
all_can_graph_batch = self.dist.tp_allgather(
|
||||
[can_run_cuda_graph, batch_size])
|
||||
is_all_gen_only = all(all_can_graph[0]
|
||||
for all_can_graph in all_can_graph_batch)
|
||||
all_batch_size_equal = all(
|
||||
all_gen_only[1] == all_can_graph_batch[0][1]
|
||||
for all_gen_only in all_can_graph_batch)
|
||||
|
||||
if not is_all_gen_only or not all_batch_size_equal:
|
||||
return None
|
||||
|
||||
if not self._run_cuda_graphs or not can_run_cuda_graph:
|
||||
return None
|
||||
|
||||
if batch_size in self._cuda_graphs and draft_len in self._cuda_graphs[
|
||||
batch_size]:
|
||||
return self._cuda_graphs[batch_size][draft_len]
|
||||
|
||||
if batch_size not in self._cuda_graph_batch_sizes:
|
||||
return None
|
||||
|
||||
num_sequences_in_batch = batch_size * self.max_beam_width
|
||||
attn_metadata = self.attn_metadata.create_cuda_graph_metadata(
|
||||
num_sequences_in_batch, False, draft_len)
|
||||
assert attn_metadata.is_cuda_graph
|
||||
|
||||
if self.enable_spec_decode:
|
||||
spec_metadata = self.spec_metadata.create_cuda_graph_metadata(
|
||||
num_sequences_in_batch)
|
||||
spec_metadata.draft_tokens = self.draft_tokens_cuda
|
||||
else:
|
||||
spec_metadata = None
|
||||
|
||||
# Initialize nested dictionary if needed
|
||||
if batch_size not in self._cuda_graphs:
|
||||
self._cuda_graphs[batch_size] = {}
|
||||
|
||||
self._cuda_graphs[batch_size][draft_len] = DecodingCUDAGraphRunner(
|
||||
batch_size, "cuda", attn_metadata, spec_metadata, self.use_mrope,
|
||||
self.max_beam_width)
|
||||
return self._cuda_graphs[batch_size][draft_len]
|
||||
|
||||
def __del__(self) -> None:
|
||||
if getattr(self, 'ub_buffers', None):
|
||||
for u in self.ub_buffers:
|
||||
@ -1244,13 +1096,7 @@ class PyTorchModelEngine(ModelEngine):
|
||||
self._init_max_num_tokens()
|
||||
|
||||
def _release_cuda_graphs(self):
|
||||
for batch_size, draft_graphs in self._cuda_graphs.items():
|
||||
for draft_len, graph in draft_graphs.items():
|
||||
del graph
|
||||
self._cuda_graphs.clear()
|
||||
torch.cuda.empty_cache()
|
||||
del self._cuda_graph_mem_pool
|
||||
self._cuda_graph_mem_pool = None
|
||||
self.cuda_graph_runner.clear()
|
||||
|
||||
def get_max_num_sequences(self) -> int:
|
||||
"""
|
||||
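Both the removed `_maybe_get_cuda_graph` above and the new `CUDAGraphRunner.maybe_get_cuda_graph` gate graph use, under attention data parallelism, on every TP rank reporting a generation-only batch of the same size via `dist.tp_allgather([can_run_cuda_graph, batch_size])`. A pure-Python sketch of that decision (the gathered values below are fabricated for illustration):

```python
def all_ranks_can_use_graph(gathered):
    """gathered: one [can_run_cuda_graph, batch_size] pair per TP rank."""
    is_all_gen_only = all(entry[0] for entry in gathered)
    all_batch_sizes_equal = all(entry[1] == gathered[0][1] for entry in gathered)
    return is_all_gen_only and all_batch_sizes_equal

# Hypothetical allgather results from 4 ranks.
assert all_ranks_can_use_graph([[True, 8], [True, 8], [True, 8], [True, 8]])
assert not all_ranks_can_use_graph([[True, 8], [False, 8], [True, 8], [True, 8]])
assert not all_ranks_can_use_graph([[True, 8], [True, 4], [True, 8], [True, 8]])
```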
@ -2237,12 +2083,14 @@ class PyTorchModelEngine(ModelEngine):
|
||||
else:
|
||||
return self._forward_step(inputs, gather_ids,
|
||||
gather_context_logits)
|
||||
with self._maybe_pad_batch(scheduled_requests, kv_cache_manager,
|
||||
spec_resource_manager) as scheduled_requests:
|
||||
maybe_graph = self._maybe_get_cuda_graph(scheduled_requests)
|
||||
if maybe_graph is not None:
|
||||
attn_metadata = maybe_graph.attn_metadata
|
||||
spec_metadata = maybe_graph.spec_metadata
|
||||
with self.cuda_graph_runner.pad_batch(
|
||||
scheduled_requests, resource_manager) as padded_requests:
|
||||
|
||||
maybe_graph, maybe_attn_metadata, maybe_spec_metadata = self.cuda_graph_runner.maybe_get_cuda_graph(
|
||||
padded_requests)
|
||||
if maybe_graph:
|
||||
attn_metadata = maybe_attn_metadata
|
||||
spec_metadata = maybe_spec_metadata
|
||||
else:
|
||||
attn_metadata = self.attn_metadata
|
||||
if self.enable_spec_decode:
|
||||
@ -2251,17 +2099,19 @@ class PyTorchModelEngine(ModelEngine):
|
||||
spec_metadata = None
|
||||
|
||||
inputs, gather_ids = self._prepare_inputs(
|
||||
scheduled_requests, kv_cache_manager, attn_metadata,
|
||||
spec_metadata, new_tensors_device, cache_indirection_buffer)
|
||||
padded_requests, kv_cache_manager, attn_metadata, spec_metadata,
|
||||
new_tensors_device, cache_indirection_buffer)
|
||||
|
||||
self.iter_counter += 1
|
||||
|
||||
if maybe_graph is None:
|
||||
if not maybe_graph:
|
||||
# Fallback to eager execution if graph was not used
|
||||
with MoeLoadBalancerIterContext(moe_load_balancer):
|
||||
outputs = self._forward_step(inputs, gather_ids,
|
||||
gather_context_logits)
|
||||
else:
|
||||
if maybe_graph.needs_capture():
|
||||
batch_size = len(padded_requests.generation_requests)
|
||||
if self.cuda_graph_runner.needs_capture(batch_size):
|
||||
|
||||
def capture_forward_fn(inputs: Dict[str, Any]):
|
||||
with MoeLoadBalancerIterContext(moe_load_balancer):
|
||||
@ -2270,18 +2120,16 @@ class PyTorchModelEngine(ModelEngine):
|
||||
gather_ids=gather_ids,
|
||||
gather_context_logits=gather_context_logits)
|
||||
|
||||
pool = maybe_graph.capture(
|
||||
capture_forward_fn,
|
||||
self._cuda_graph_mem_pool,
|
||||
)
|
||||
self._cuda_graph_mem_pool = pool
|
||||
self.cuda_graph_runner.capture(batch_size,
|
||||
capture_forward_fn, inputs)
|
||||
|
||||
# here we don't need to use context since cuda graph capture didn't run kernel.
|
||||
# maybe we need a cleaner way to do this.
|
||||
outputs = maybe_graph.run(inputs)
|
||||
outputs = self.cuda_graph_runner.replay(batch_size, inputs)
|
||||
else:
|
||||
with MoeLoadBalancerIterContext(moe_load_balancer):
|
||||
outputs = maybe_graph.run(inputs)
|
||||
outputs = self.cuda_graph_runner.replay(
|
||||
batch_size, inputs)
|
||||
|
||||
self._execute_logit_post_processors(scheduled_requests, outputs)
|
||||
|
||||
|
||||
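The engine-side call sites now wrap the forward pass in `cuda_graph_runner.pad_batch(...)`, a context manager that appends dummy generation requests so the batch reaches a captured size and strips them again on exit, so callers never see the padding. A simplified stand-alone sketch of that shape (the batch type and dummy value here are stand-ins, not the repo's classes):

```python
import contextlib

class FakeBatch:
    """Stand-in for ScheduledRequests: only the field pad_batch touches."""
    def __init__(self, generation_requests):
        self.generation_requests = generation_requests

@contextlib.contextmanager
def pad_batch(batch, supported_sizes, dummy):
    """Pad generation_requests up to a supported size, then trim on exit."""
    target = next((s for s in supported_sizes
                   if s >= len(batch.generation_requests)), None)
    padding = 0 if target is None else target - len(batch.generation_requests)
    batch.generation_requests.extend([dummy] * padding)
    try:
        yield batch
    finally:
        if padding > 0:
            batch.generation_requests = batch.generation_requests[:-padding]

batch = FakeBatch(generation_requests=["req0", "req1", "req2"])
with pad_batch(batch, supported_sizes=[1, 2, 4, 8], dummy="dummy") as padded:
    assert len(padded.generation_requests) == 4   # padded up to a graph size
assert len(batch.generation_requests) == 3        # padding removed afterwards
```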
@ -140,6 +140,8 @@ class PyTorchModelEngineTestCase(unittest.TestCase):
|
||||
|
||||
def test_pad_generation_requests(self) -> None:
|
||||
model_engine, kv_cache_manager = create_model_engine_and_kvcache()
|
||||
resource_manager = ResourceManager(
|
||||
{ResourceManagerType.KV_CACHE_MANAGER: kv_cache_manager})
|
||||
|
||||
seqlens_and_batch_sizes = [
|
||||
(5, 1),
|
||||
@ -155,8 +157,8 @@ class PyTorchModelEngineTestCase(unittest.TestCase):
|
||||
batch.generation_requests = []
|
||||
|
||||
pages_before = kv_cache_manager.get_num_free_blocks()
|
||||
with model_engine._maybe_pad_batch(
|
||||
batch, kv_cache_manager) as padded_batch:
|
||||
with model_engine.cuda_graph_runner.pad_batch(
|
||||
batch, resource_manager) as padded_batch:
|
||||
# No padding for prefill
|
||||
self.assertIs(batch, padded_batch)
|
||||
self.assertEqual(kv_cache_manager.get_num_free_blocks(),
|
||||
@ -166,9 +168,9 @@ class PyTorchModelEngineTestCase(unittest.TestCase):
|
||||
batch.context_requests = []
|
||||
batch.generation_requests = requests
|
||||
pages_before = kv_cache_manager.get_num_free_blocks()
|
||||
new_dummy_block = 1 if model_engine.cuda_graph_dummy_request is None else 0
|
||||
with model_engine._maybe_pad_batch(
|
||||
batch, kv_cache_manager) as padded_batch:
|
||||
new_dummy_block = 1 if model_engine.cuda_graph_runner.padding_dummy_request is None else 0
|
||||
with model_engine.cuda_graph_runner.pad_batch(
|
||||
batch, resource_manager) as padded_batch:
|
||||
if batch_size < 8 and max_seq_len < 25:
|
||||
self.assertEqual(
|
||||
len(padded_batch.generation_requests) % 8, 0)
|
||||
|
||||
@@ -162,3 +162,32 @@ def reference_block_scale_moe_torch(
        results[batch_idx] += final_scales[batch_idx, nth_expert, None] * output

    return results.view_as(x)


class MockPytorchBackendConfig:

    def __init__(self, use_cuda_graph, cuda_graph_padding_enabled):
        self.use_cuda_graph = use_cuda_graph
        self.cuda_graph_padding_enabled = cuda_graph_padding_enabled


class MockEngine:
    """A replacement for SimpleNamespace that supports weak references."""

    def __init__(self, **kwargs):
        self.__dict__.update(kwargs)


def create_mock_engine(batch_size: int):

    return MockEngine(
        pytorch_backend_config=MockPytorchBackendConfig(
            use_cuda_graph=True, cuda_graph_padding_enabled=False),
        _cuda_graph_batch_sizes=[batch_size],
        _max_cuda_graph_batch_size=batch_size,
        max_beam_width=1,
        is_spec_decode=False,
        spec_config=None,
        _cuda_graph_mem_pool=None,
        use_mrope=False,
    )

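`create_mock_engine` builds a `MockEngine` rather than a `types.SimpleNamespace` because `CUDAGraphRunner` holds the engine through `weakref.ref`, and SimpleNamespace instances do not support weak references (which is what the helper's docstring is getting at). A quick hedged demonstration of that difference:

```python
import weakref
from types import SimpleNamespace

class MockEngine:
    """Plain classes get a __weakref__ slot, so weakref.ref works on them."""
    def __init__(self, **kwargs):
        self.__dict__.update(kwargs)

engine = MockEngine(max_beam_width=1)
ref = weakref.ref(engine)      # same pattern as CUDAGraphRunner.engine_ref
assert ref() is engine

try:
    weakref.ref(SimpleNamespace(max_beam_width=1))
except TypeError as exc:
    # SimpleNamespace instances are not weak-referenceable, hence MockEngine.
    print(f"SimpleNamespace cannot back a weakref: {exc}")
```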
@ -22,6 +22,7 @@ except ImportError:
|
||||
# TODO: Remove this once we have a proper config for Exaone4
|
||||
SKIP_EXAONE4_HF_ACCURACY_TEST = True
|
||||
|
||||
from _torch.helpers import create_mock_engine
|
||||
from transformers.cache_utils import HybridCache
|
||||
from utils.util import getSMVersion
|
||||
|
||||
@ -30,8 +31,7 @@ from tensorrt_llm._torch.attention_backend.utils import get_attention_backend
|
||||
from tensorrt_llm._torch.metadata import KVCacheParams
|
||||
from tensorrt_llm._torch.model_config import ModelConfig
|
||||
from tensorrt_llm._torch.models.modeling_exaone4 import Exaone4ForCausalLM
|
||||
from tensorrt_llm._torch.pyexecutor.cuda_graph_runner import \
|
||||
DecodingCUDAGraphRunner
|
||||
from tensorrt_llm._torch.pyexecutor.cuda_graph_runner import CUDAGraphRunner
|
||||
from tensorrt_llm._torch.pyexecutor.resource_manager import KVCacheManager
|
||||
from tensorrt_llm.bindings.executor import KvCacheConfig
|
||||
from tensorrt_llm.mapping import Mapping
|
||||
@ -338,6 +338,11 @@ class TestEXAONE4(unittest.TestCase):
|
||||
]
|
||||
gen_position_ids = torch.cat(gen_position_ids).unsqueeze(0).cuda()
|
||||
|
||||
graph_runner = None
|
||||
if scenario.use_cuda_graph:
|
||||
mock_engine = create_mock_engine(1)
|
||||
graph_runner = CUDAGraphRunner(mock_engine)
|
||||
|
||||
def run_forward(input_ids, position_ids, attn_metadata):
|
||||
attn_metadata.prepare()
|
||||
if not scenario.use_cuda_graph:
|
||||
@ -345,19 +350,20 @@ class TestEXAONE4(unittest.TestCase):
|
||||
position_ids=position_ids,
|
||||
attn_metadata=attn_metadata)
|
||||
else:
|
||||
graph_runner = DecodingCUDAGraphRunner(
|
||||
attn_metadata.max_num_requests, "cuda", attn_metadata)
|
||||
graph_runner.capture(lambda inputs: exaone4.forward(**inputs))
|
||||
inputs = {
|
||||
"input_ids": input_ids,
|
||||
"position_ids": position_ids,
|
||||
"attn_metadata": attn_metadata,
|
||||
}
|
||||
graph_runner.capture(1,
|
||||
lambda inputs: exaone4.forward(**inputs),
|
||||
inputs)
|
||||
|
||||
for _ in range(2):
|
||||
# Run it twice. This helps us catch problems if buffers are accidentally reallocated
|
||||
# in prepare().
|
||||
attn_metadata.prepare()
|
||||
logits = graph_runner.run({
|
||||
"input_ids": input_ids,
|
||||
"position_ids": position_ids,
|
||||
"attn_metadata": attn_metadata,
|
||||
})
|
||||
logits = graph_runner.replay(1, inputs)
|
||||
return logits
|
||||
|
||||
if scenario.use_cuda_graph:
|
||||
@ -380,5 +386,6 @@ class TestEXAONE4(unittest.TestCase):
|
||||
ref.logits[:, -1].float(),
|
||||
atol=0.4,
|
||||
rtol=0.4)
|
||||
|
||||
if graph_runner is not None:
|
||||
graph_runner.clear()
|
||||
kv_cache_manager.shutdown()
|
||||
|
||||
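The per-model tests now all follow the same sequence: build a mock engine, call `capture(1, forward_fn, inputs)` once, then replay twice so a buffer accidentally reallocated inside `attn_metadata.prepare()` would surface as a mismatch. Reduced to plain PyTorch, the pattern looks roughly like this (toy module, names are mine, requires a CUDA device):

```python
import torch

assert torch.cuda.is_available()
model = torch.nn.Linear(32, 32).cuda().eval()
model.requires_grad_(False)

static_ids = torch.zeros(4, 32, device="cuda")    # static input buffer

# Warm up on a side stream, then capture against the static buffer.
stream = torch.cuda.Stream()
stream.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(stream):
    for _ in range(2):
        model(static_ids)
torch.cuda.current_stream().wait_stream(stream)

graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(graph):
    static_out = model(static_ids)

# "Run it twice": replay with fresh data each time and compare to eager,
# mirroring the two-iteration loop in the tests above.
for _ in range(2):
    fresh = torch.randn(4, 32, device="cuda")
    static_ids.copy_(fresh)
    graph.replay()
    torch.testing.assert_close(static_out, model(fresh), atol=1e-3, rtol=1e-3)
```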
@ -4,6 +4,7 @@ from dataclasses import dataclass
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
from _torch.helpers import create_mock_engine
|
||||
from parameterized import parameterized
|
||||
from transformers import LlamaConfig
|
||||
from transformers import LlamaForCausalLM as HFLlamaForCausalLM
|
||||
@ -14,8 +15,7 @@ from tensorrt_llm._torch.attention_backend.utils import get_attention_backend
|
||||
from tensorrt_llm._torch.metadata import KVCacheParams
|
||||
from tensorrt_llm._torch.model_config import ModelConfig
|
||||
from tensorrt_llm._torch.models.modeling_llama import LlamaForCausalLM
|
||||
from tensorrt_llm._torch.pyexecutor.cuda_graph_runner import \
|
||||
DecodingCUDAGraphRunner
|
||||
from tensorrt_llm._torch.pyexecutor.cuda_graph_runner import CUDAGraphRunner
|
||||
from tensorrt_llm._torch.pyexecutor.resource_manager import KVCacheManager
|
||||
from tensorrt_llm.bindings.executor import KvCacheConfig
|
||||
from tensorrt_llm.mapping import Mapping
|
||||
@ -326,6 +326,11 @@ class TestLlama(unittest.TestCase):
|
||||
]
|
||||
gen_position_ids = torch.cat(gen_position_ids).unsqueeze(0).cuda()
|
||||
|
||||
graph_runner = None
|
||||
if scenario.use_cuda_graph:
|
||||
mock_engine = create_mock_engine(1)
|
||||
graph_runner = CUDAGraphRunner(mock_engine)
|
||||
|
||||
def run_forward(input_ids, position_ids, attn_metadata):
|
||||
attn_metadata.prepare()
|
||||
if not scenario.use_cuda_graph:
|
||||
@ -333,19 +338,18 @@ class TestLlama(unittest.TestCase):
|
||||
position_ids=position_ids,
|
||||
attn_metadata=attn_metadata)
|
||||
else:
|
||||
graph_runner = DecodingCUDAGraphRunner(
|
||||
attn_metadata.max_num_requests, "cuda", attn_metadata)
|
||||
graph_runner.capture(lambda inputs: llama.forward(**inputs))
|
||||
|
||||
inputs = {
|
||||
"input_ids": input_ids,
|
||||
"position_ids": position_ids,
|
||||
"attn_metadata": attn_metadata,
|
||||
}
|
||||
graph_runner.capture(1, lambda inputs: llama.forward(**inputs),
|
||||
inputs)
|
||||
for _ in range(2):
|
||||
# Run it twice. This helps us catch problems if buffers are accidentally reallocated
|
||||
# in prepare().
|
||||
attn_metadata.prepare()
|
||||
logits = graph_runner.run({
|
||||
"input_ids": input_ids,
|
||||
"position_ids": position_ids,
|
||||
"attn_metadata": attn_metadata,
|
||||
})
|
||||
logits = graph_runner.replay(1, inputs)
|
||||
return logits
|
||||
|
||||
if scenario.use_cuda_graph:
|
||||
@ -364,5 +368,6 @@ class TestLlama(unittest.TestCase):
|
||||
ref.logits[:, -1].float(),
|
||||
atol=0.4,
|
||||
rtol=0.4)
|
||||
|
||||
if graph_runner is not None:
|
||||
graph_runner.clear()
|
||||
kv_cache_manager.shutdown()
|
||||
|
||||
@ -4,6 +4,7 @@ from dataclasses import dataclass
|
||||
|
||||
import torch
|
||||
import transformers
|
||||
from _torch.helpers import create_mock_engine
|
||||
from parameterized import parameterized
|
||||
from transformers import Llama4Config
|
||||
from transformers import \
|
||||
@ -20,8 +21,7 @@ from tensorrt_llm._torch.models.checkpoints.hf.llama4_weight_mapper import \
|
||||
from tensorrt_llm._torch.models.modeling_llama import \
|
||||
Llama4ForConditionalGeneration
|
||||
from tensorrt_llm._torch.pyexecutor.config import PyTorchConfig
|
||||
from tensorrt_llm._torch.pyexecutor.cuda_graph_runner import \
|
||||
DecodingCUDAGraphRunner
|
||||
from tensorrt_llm._torch.pyexecutor.cuda_graph_runner import CUDAGraphRunner
|
||||
from tensorrt_llm._torch.pyexecutor.resource_manager import KVCacheManager
|
||||
from tensorrt_llm.bindings.executor import KvCacheConfig
|
||||
from tensorrt_llm.mapping import Mapping
|
||||
@ -403,6 +403,10 @@ class TestLlama4MinLatency(unittest.TestCase):
|
||||
input_ids.size(-1) + gen_input_ids.size(-1))
|
||||
]
|
||||
gen_position_ids = torch.cat(gen_position_ids).unsqueeze(0).cuda()
|
||||
graph_runner = None
|
||||
if scenario.use_cuda_graph:
|
||||
mock_engine = create_mock_engine(1)
|
||||
graph_runner = CUDAGraphRunner(mock_engine)
|
||||
|
||||
def run_forward(input_ids, position_ids, attn_metadata):
|
||||
attn_metadata.prepare()
|
||||
@ -411,19 +415,19 @@ class TestLlama4MinLatency(unittest.TestCase):
|
||||
position_ids=position_ids,
|
||||
attn_metadata=attn_metadata)
|
||||
else:
|
||||
graph_runner = DecodingCUDAGraphRunner(
|
||||
attn_metadata.max_num_requests, "cuda", attn_metadata)
|
||||
graph_runner.capture(lambda inputs: llama.forward(**inputs))
|
||||
inputs = {
|
||||
"input_ids": input_ids,
|
||||
"position_ids": position_ids,
|
||||
"attn_metadata": attn_metadata,
|
||||
}
|
||||
graph_runner.capture(1, lambda inputs: llama.forward(**inputs),
|
||||
inputs)
|
||||
|
||||
for _ in range(2):
|
||||
# Run it twice. This helps us catch problems if buffers are accidentally reallocated
|
||||
# in prepare().
|
||||
attn_metadata.prepare()
|
||||
logits = graph_runner.run({
|
||||
"input_ids": input_ids,
|
||||
"position_ids": position_ids,
|
||||
"attn_metadata": attn_metadata,
|
||||
})
|
||||
logits = graph_runner.replay(1, inputs)
|
||||
return logits
|
||||
|
||||
if scenario.use_cuda_graph:
|
||||
|
||||
@ -7,6 +7,7 @@ from unittest import mock
|
||||
import pytest
|
||||
import torch
|
||||
import transformers
|
||||
from _torch.helpers import create_mock_engine
|
||||
from utils.util import getSMVersion
|
||||
|
||||
import tensorrt_llm
|
||||
@ -15,7 +16,8 @@ from tensorrt_llm._torch import metadata as metadata_lib
|
||||
from tensorrt_llm._torch import model_config as model_config_lib
|
||||
from tensorrt_llm._torch.attention_backend import utils as attention_utils
|
||||
from tensorrt_llm._torch.models import modeling_mistral
|
||||
from tensorrt_llm._torch.pyexecutor import cuda_graph_runner, resource_manager
|
||||
from tensorrt_llm._torch.pyexecutor import resource_manager
|
||||
from tensorrt_llm._torch.pyexecutor.cuda_graph_runner import CUDAGraphRunner
|
||||
from tensorrt_llm.bindings import executor as executor_lib
|
||||
from tensorrt_llm.models import modeling_utils
|
||||
|
||||
@ -398,6 +400,11 @@ def test_mistral_3_vlm_allclose_to_hf(mistral_small_3_1_24b_config, backend, use
|
||||
]
|
||||
gen_position_ids = torch.cat(gen_position_ids).unsqueeze(0).cuda()
|
||||
|
||||
graph_runner = None
|
||||
if use_cuda_graph:
|
||||
mock_engine = create_mock_engine(1)
|
||||
graph_runner = CUDAGraphRunner(mock_engine)
|
||||
|
||||
def run_forward(input_ids, position_ids, attn_metadata):
|
||||
attn_metadata.prepare()
|
||||
if not use_cuda_graph:
|
||||
@ -405,22 +412,18 @@ def test_mistral_3_vlm_allclose_to_hf(mistral_small_3_1_24b_config, backend, use
|
||||
input_ids=input_ids, position_ids=position_ids, attn_metadata=attn_metadata
|
||||
)
|
||||
else:
|
||||
graph_runner = cuda_graph_runner.DecodingCUDAGraphRunner(
|
||||
attn_metadata.max_num_requests, "cuda", attn_metadata
|
||||
)
|
||||
graph_runner.capture(lambda inputs: mistral.forward(**inputs))
|
||||
inputs = {
|
||||
"input_ids": input_ids,
|
||||
"position_ids": position_ids,
|
||||
"attn_metadata": attn_metadata,
|
||||
}
|
||||
graph_runner.capture(1, lambda inputs: mistral.forward(**inputs), inputs)
|
||||
|
||||
for _ in range(2):
|
||||
# Run it twice. This helps us catch problems if buffers are accidentally reallocated
|
||||
# in prepare().
|
||||
attn_metadata.prepare()
|
||||
logits = graph_runner.run(
|
||||
{
|
||||
"input_ids": input_ids,
|
||||
"position_ids": position_ids,
|
||||
"attn_metadata": attn_metadata,
|
||||
}
|
||||
)
|
||||
logits = graph_runner.replay(1, inputs)
|
||||
return logits
|
||||
|
||||
if use_cuda_graph:
|
||||
@ -438,3 +441,5 @@ def test_mistral_3_vlm_allclose_to_hf(mistral_small_3_1_24b_config, backend, use
|
||||
)
|
||||
|
||||
torch.testing.assert_close(logits, ref.logits[:, -1].float(), atol=0.4, rtol=0.4)
|
||||
if graph_runner is not None:
|
||||
graph_runner.clear()
|
||||
|
||||
@ -3,6 +3,7 @@ from copy import deepcopy
|
||||
from dataclasses import dataclass
|
||||
|
||||
import torch
|
||||
from _torch.helpers import create_mock_engine
|
||||
from parameterized import parameterized
|
||||
from transformers import MixtralConfig
|
||||
from transformers import MixtralForCausalLM as HFMixtralForCausalLM
|
||||
@ -15,8 +16,7 @@ from tensorrt_llm._torch.model_config import ModelConfig
|
||||
from tensorrt_llm._torch.models.checkpoints.hf.mixtral_weight_mapper import \
|
||||
MixtralHfWeightMapper
|
||||
from tensorrt_llm._torch.models.modeling_mixtral import MixtralForCausalLM
|
||||
from tensorrt_llm._torch.pyexecutor.cuda_graph_runner import \
|
||||
DecodingCUDAGraphRunner
|
||||
from tensorrt_llm._torch.pyexecutor.cuda_graph_runner import CUDAGraphRunner
|
||||
from tensorrt_llm._torch.pyexecutor.resource_manager import KVCacheManager
|
||||
from tensorrt_llm.bindings.executor import KvCacheConfig
|
||||
from tensorrt_llm.mapping import Mapping
|
||||
@ -310,6 +310,11 @@ class TestMixtral(unittest.TestCase):
|
||||
]
|
||||
gen_position_ids = torch.cat(gen_position_ids).unsqueeze(0).cuda()
|
||||
|
||||
graph_runner = None
|
||||
if scenario.use_cuda_graph:
|
||||
mock_engine = create_mock_engine(1)
|
||||
graph_runner = CUDAGraphRunner(mock_engine)
|
||||
|
||||
def run_forward(input_ids, position_ids, attn_metadata):
|
||||
attn_metadata.prepare()
|
||||
if not scenario.use_cuda_graph:
|
||||
@ -317,19 +322,20 @@ class TestMixtral(unittest.TestCase):
|
||||
position_ids=position_ids,
|
||||
attn_metadata=attn_metadata)
|
||||
else:
|
||||
graph_runner = DecodingCUDAGraphRunner(
|
||||
attn_metadata.max_num_requests, "cuda", attn_metadata)
|
||||
graph_runner.capture(lambda inputs: mixtral.forward(**inputs))
|
||||
inputs = {
|
||||
"input_ids": input_ids,
|
||||
"position_ids": position_ids,
|
||||
"attn_metadata": attn_metadata,
|
||||
}
|
||||
graph_runner.capture(1,
|
||||
lambda inputs: mixtral.forward(**inputs),
|
||||
inputs)
|
||||
|
||||
for _ in range(2):
|
||||
# Run it twice. This helps us catch problems if buffers are accidentally reallocated
|
||||
# in prepare().
|
||||
attn_metadata.prepare()
|
||||
logits = graph_runner.run({
|
||||
"input_ids": input_ids,
|
||||
"position_ids": position_ids,
|
||||
"attn_metadata": attn_metadata,
|
||||
})
|
||||
logits = graph_runner.replay(1, inputs)
|
||||
return logits
|
||||
|
||||
if scenario.use_cuda_graph:
|
||||
@ -348,5 +354,6 @@ class TestMixtral(unittest.TestCase):
|
||||
ref.logits[:, -1].float(),
|
||||
atol=0.1,
|
||||
rtol=0.1)
|
||||
|
||||
if graph_runner is not None:
|
||||
graph_runner.clear()
|
||||
kv_cache_manager.shutdown()
|
||||
|
||||
@ -3,6 +3,7 @@ import unittest
|
||||
from copy import deepcopy
|
||||
|
||||
import torch
|
||||
from _torch.helpers import create_mock_engine
|
||||
from parameterized import parameterized
|
||||
from test_modeling_llama import Scenario, reduce_llama_config
|
||||
from transformers import MllamaConfig
|
||||
@ -15,8 +16,7 @@ from tensorrt_llm._torch.metadata import KVCacheParams
|
||||
from tensorrt_llm._torch.model_config import ModelConfig
|
||||
from tensorrt_llm._torch.models.modeling_mllama import \
|
||||
MllamaForConditionalGeneration
|
||||
from tensorrt_llm._torch.pyexecutor.cuda_graph_runner import \
|
||||
DecodingCUDAGraphRunner
|
||||
from tensorrt_llm._torch.pyexecutor.cuda_graph_runner import CUDAGraphRunner
|
||||
from tensorrt_llm._torch.pyexecutor.resource_manager import KVCacheManager
|
||||
from tensorrt_llm.bindings.executor import KvCacheConfig
|
||||
from tensorrt_llm.mapping import Mapping
|
||||
@ -417,6 +417,11 @@ class TestMLlama(unittest.TestCase):
|
||||
]
|
||||
gen_position_ids = torch.cat(gen_position_ids).unsqueeze(0).cuda()
|
||||
|
||||
graph_runner = None
|
||||
if scenario.use_cuda_graph:
|
||||
mock_engine = create_mock_engine(1)
|
||||
graph_runner = CUDAGraphRunner(mock_engine)
|
||||
|
||||
def run_forward(input_ids, position_ids, attn_metadata):
|
||||
attn_metadata.prepare()
|
||||
if not scenario.use_cuda_graph:
|
||||
@ -424,19 +429,19 @@ class TestMLlama(unittest.TestCase):
|
||||
position_ids=position_ids,
|
||||
attn_metadata=attn_metadata)
|
||||
else:
|
||||
graph_runner = DecodingCUDAGraphRunner(
|
||||
attn_metadata.max_num_requests, "cuda", attn_metadata)
|
||||
graph_runner.capture(lambda inputs: mllama.forward(**inputs))
|
||||
inputs = {
|
||||
"input_ids": input_ids,
|
||||
"position_ids": position_ids,
|
||||
"attn_metadata": attn_metadata,
|
||||
}
|
||||
graph_runner.capture(1, lambda inputs: mllama.forward(**inputs),
|
||||
inputs)
|
||||
|
||||
for _ in range(2):
|
||||
# Run it twice. This helps us catch problems if buffers are accidentally reallocated
|
||||
# in prepare().
|
||||
attn_metadata.prepare()
|
||||
logits = graph_runner.run({
|
||||
"input_ids": input_ids,
|
||||
"position_ids": position_ids,
|
||||
"attn_metadata": attn_metadata,
|
||||
})
|
||||
logits = graph_runner.replay(1, inputs)
|
||||
return logits
|
||||
|
||||
if scenario.use_cuda_graph:
|
||||
@ -455,3 +460,6 @@ class TestMLlama(unittest.TestCase):
|
||||
ref.logits[:, -1].float(),
|
||||
atol=0.3,
|
||||
rtol=0.3)
|
||||
if graph_runner is not None:
|
||||
graph_runner.clear()
|
||||
kv_cache_manager.shutdown()
|
||||
|
||||
@ -4,6 +4,7 @@ from dataclasses import dataclass
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
from _torch.helpers import create_mock_engine
|
||||
from parameterized import parameterized
|
||||
from transformers import NemotronConfig
|
||||
from transformers import NemotronForCausalLM as HFNemotronForCausalLM
|
||||
@ -14,8 +15,7 @@ from tensorrt_llm._torch.attention_backend.utils import get_attention_backend
|
||||
from tensorrt_llm._torch.metadata import KVCacheParams
|
||||
from tensorrt_llm._torch.model_config import ModelConfig
|
||||
from tensorrt_llm._torch.models.modeling_nemotron import NemotronForCausalLM
|
||||
from tensorrt_llm._torch.pyexecutor.cuda_graph_runner import \
|
||||
DecodingCUDAGraphRunner
|
||||
from tensorrt_llm._torch.pyexecutor.cuda_graph_runner import CUDAGraphRunner
|
||||
from tensorrt_llm._torch.pyexecutor.resource_manager import KVCacheManager
|
||||
from tensorrt_llm.bindings.executor import KvCacheConfig
|
||||
from tensorrt_llm.mapping import Mapping
|
||||
@ -318,6 +318,11 @@ class TestNemotron(unittest.TestCase):
|
||||
]
|
||||
gen_position_ids = torch.cat(gen_position_ids).unsqueeze(0).cuda()
|
||||
|
||||
graph_runner = None
|
||||
if scenario.use_cuda_graph:
|
||||
mock_engine = create_mock_engine(1)
|
||||
graph_runner = CUDAGraphRunner(mock_engine)
|
||||
|
||||
def run_forward(input_ids, position_ids, attn_metadata):
|
||||
attn_metadata.prepare()
|
||||
if not scenario.use_cuda_graph:
|
||||
@ -325,19 +330,20 @@ class TestNemotron(unittest.TestCase):
|
||||
position_ids=position_ids,
|
||||
attn_metadata=attn_metadata)
|
||||
else:
|
||||
graph_runner = DecodingCUDAGraphRunner(
|
||||
attn_metadata.max_num_requests, "cuda", attn_metadata)
|
||||
graph_runner.capture(lambda inputs: nemotron.forward(**inputs))
|
||||
inputs = {
|
||||
"input_ids": input_ids,
|
||||
"position_ids": position_ids,
|
||||
"attn_metadata": attn_metadata,
|
||||
}
|
||||
graph_runner.capture(1,
|
||||
lambda inputs: nemotron.forward(**inputs),
|
||||
inputs)
|
||||
|
||||
for _ in range(2):
|
||||
# Run it twice. This helps us catch problems if buffers are accidentally reallocated
|
||||
# in prepare().
|
||||
attn_metadata.prepare()
|
||||
logits = graph_runner.run({
|
||||
"input_ids": input_ids,
|
||||
"position_ids": position_ids,
|
||||
"attn_metadata": attn_metadata,
|
||||
})
|
||||
logits = graph_runner.replay(1, inputs)
|
||||
return logits
|
||||
|
||||
if scenario.use_cuda_graph:
|
||||
@ -357,4 +363,6 @@ class TestNemotron(unittest.TestCase):
|
||||
atol=0.4,
|
||||
rtol=0.4)
|
||||
|
||||
if graph_runner is not None:
|
||||
graph_runner.clear()
|
||||
kv_cache_manager.shutdown()
|
||||
|
||||
@ -4,6 +4,7 @@ from dataclasses import dataclass
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
from _torch.helpers import create_mock_engine
|
||||
from transformers import Phi3Config
|
||||
from transformers import Phi3ForCausalLM as HFPhi3ForCausalLM
|
||||
from utils.util import default_dtype
|
||||
@ -13,8 +14,7 @@ from tensorrt_llm._torch.attention_backend.utils import get_attention_backend
|
||||
from tensorrt_llm._torch.metadata import KVCacheParams
|
||||
from tensorrt_llm._torch.model_config import ModelConfig
|
||||
from tensorrt_llm._torch.models.modeling_phi3 import Phi3ForCausalLM
|
||||
from tensorrt_llm._torch.pyexecutor.cuda_graph_runner import \
|
||||
DecodingCUDAGraphRunner
|
||||
from tensorrt_llm._torch.pyexecutor.cuda_graph_runner import CUDAGraphRunner
|
||||
from tensorrt_llm._torch.pyexecutor.resource_manager import KVCacheManager
|
||||
from tensorrt_llm.bindings.executor import KvCacheConfig
|
||||
from tensorrt_llm.mapping import Mapping
|
||||
@ -310,6 +310,11 @@ class TestPhi3(unittest.TestCase):
|
||||
]
|
||||
gen_position_ids = torch.cat(gen_position_ids).unsqueeze(0).cuda()
|
||||
|
||||
graph_runner = None
|
||||
if scenario.use_cuda_graph:
|
||||
mock_engine = create_mock_engine(1)
|
||||
graph_runner = CUDAGraphRunner(mock_engine)
|
||||
|
||||
def run_forward(input_ids, position_ids, attn_metadata):
|
||||
attn_metadata.prepare()
|
||||
if not scenario.use_cuda_graph:
|
||||
@ -317,19 +322,19 @@ class TestPhi3(unittest.TestCase):
|
||||
position_ids=position_ids,
|
||||
attn_metadata=attn_metadata)
|
||||
else:
|
||||
graph_runner = DecodingCUDAGraphRunner(
|
||||
attn_metadata.max_num_requests, "cuda", attn_metadata)
|
||||
graph_runner.capture(lambda inputs: phi3.forward(**inputs))
|
||||
inputs = {
|
||||
"input_ids": input_ids,
|
||||
"position_ids": position_ids,
|
||||
"attn_metadata": attn_metadata,
|
||||
}
|
||||
graph_runner.capture(1, lambda inputs: phi3.forward(**inputs),
|
||||
inputs)
|
||||
|
||||
for _ in range(2):
|
||||
# Run it twice. This helps us catch problems if buffers are accidentally reallocated
|
||||
# in prepare().
|
||||
attn_metadata.prepare()
|
||||
logits = graph_runner.run({
|
||||
"input_ids": input_ids,
|
||||
"position_ids": position_ids,
|
||||
"attn_metadata": attn_metadata,
|
||||
})
|
||||
logits = graph_runner.replay(1, inputs)
|
||||
return logits
|
||||
|
||||
if scenario.use_cuda_graph:
|
||||
@ -348,5 +353,6 @@ class TestPhi3(unittest.TestCase):
|
||||
ref.logits[:, -1].float(),
|
||||
atol=0.4,
|
||||
rtol=0.4)
|
||||
|
||||
if graph_runner is not None:
|
||||
graph_runner.clear()
|
||||
kv_cache_manager.shutdown()
|
||||
|
||||
@ -17,12 +17,12 @@ from tensorrt_llm._torch.model_config import ModelConfig
|
||||
from tensorrt_llm._torch.models.modeling_qwen import (
|
||||
Qwen2ForCausalLM, Qwen2ForProcessRewardModel)
|
||||
# yapf: enable
|
||||
from tensorrt_llm._torch.pyexecutor.cuda_graph_runner import \
|
||||
DecodingCUDAGraphRunner
|
||||
from _torch.helpers import create_mock_engine
|
||||
from tensorrt_llm._torch.pyexecutor.resource_manager import KVCacheManager
|
||||
from tensorrt_llm.bindings.executor import KvCacheConfig
|
||||
from tensorrt_llm.mapping import Mapping
|
||||
from tensorrt_llm.models.modeling_utils import QuantConfig
|
||||
from tensorrt_llm._torch.pyexecutor.cuda_graph_runner import CUDAGraphRunner
|
||||
|
||||
from utils.llm_data import llm_models_root
|
||||
from utils.util import getSMVersion
|
||||
@ -265,6 +265,11 @@ class TestQwen(unittest.TestCase):
|
||||
]
|
||||
gen_position_ids = torch.cat(gen_position_ids).unsqueeze(0).cuda()
|
||||
|
||||
graph_runner = None
|
||||
if scenario.use_cuda_graph:
|
||||
mock_engine = create_mock_engine(1)
|
||||
graph_runner = CUDAGraphRunner(mock_engine)
|
||||
|
||||
def run_forward(input_ids, position_ids, attn_metadata):
|
||||
attn_metadata.prepare()
|
||||
if not scenario.use_cuda_graph:
|
||||
@ -272,19 +277,19 @@ class TestQwen(unittest.TestCase):
|
||||
position_ids=position_ids,
|
||||
attn_metadata=attn_metadata)
|
||||
else:
|
||||
graph_runner = DecodingCUDAGraphRunner(
|
||||
attn_metadata.max_num_requests, "cuda", attn_metadata)
|
||||
graph_runner.capture(lambda inputs: qwen.forward(**inputs))
|
||||
inputs = {
|
||||
"input_ids": input_ids,
|
||||
"position_ids": position_ids,
|
||||
"attn_metadata": attn_metadata,
|
||||
}
|
||||
graph_runner.capture(1, lambda inputs: qwen.forward(**inputs),
|
||||
inputs)
|
||||
|
||||
for _ in range(2):
|
||||
# Run it twice. This helps us catch problems if buffers are accidentally reallocated
|
||||
# in prepare().
|
||||
attn_metadata.prepare()
|
||||
logits = graph_runner.run({
|
||||
"input_ids": input_ids,
|
||||
"position_ids": position_ids,
|
||||
"attn_metadata": attn_metadata,
|
||||
})
|
||||
logits = graph_runner.replay(1, inputs)
|
||||
return logits
|
||||
|
||||
if scenario.use_cuda_graph:
|
||||
@ -303,7 +308,8 @@ class TestQwen(unittest.TestCase):
|
||||
ref.logits[:, -1].float(),
|
||||
atol=0.4,
|
||||
rtol=0.4)
|
||||
|
||||
if graph_runner is not None:
|
||||
graph_runner.clear()
|
||||
kv_cache_manager.shutdown()
|
||||
|
||||
@parameterized.expand(
|
||||
|
||||
@ -3,6 +3,7 @@ from copy import deepcopy
|
||||
from dataclasses import dataclass
|
||||
|
||||
import torch
|
||||
from _torch.helpers import create_mock_engine
|
||||
from parameterized import parameterized
|
||||
from transformers import Qwen2MoeConfig
|
||||
from transformers import Qwen2MoeForCausalLM as HFQwen2MoeForCausalLM
|
||||
@ -15,8 +16,7 @@ from tensorrt_llm._torch.model_config import ModelConfig
|
||||
from tensorrt_llm._torch.models.checkpoints.hf.qwen2_moe_weight_mapper import \
|
||||
Qwen2MoeHfWeightMapper
|
||||
from tensorrt_llm._torch.models.modeling_qwen_moe import Qwen2MoeForCausalLM
|
||||
from tensorrt_llm._torch.pyexecutor.cuda_graph_runner import \
|
||||
DecodingCUDAGraphRunner
|
||||
from tensorrt_llm._torch.pyexecutor.cuda_graph_runner import CUDAGraphRunner
|
||||
from tensorrt_llm._torch.pyexecutor.resource_manager import KVCacheManager
|
||||
from tensorrt_llm.bindings.executor import KvCacheConfig
|
||||
from tensorrt_llm.mapping import Mapping
|
||||
@ -315,6 +315,11 @@ class TestQwenMoe(unittest.TestCase):
|
||||
]
|
||||
gen_position_ids = torch.cat(gen_position_ids).unsqueeze(0).cuda()
|
||||
|
||||
graph_runner = None
|
||||
if scenario.use_cuda_graph:
|
||||
mock_engine = create_mock_engine(1)
|
||||
graph_runner = CUDAGraphRunner(mock_engine)
|
||||
|
||||
def run_forward(input_ids, position_ids, attn_metadata):
|
||||
attn_metadata.prepare()
|
||||
if not scenario.use_cuda_graph:
|
||||
@ -322,19 +327,20 @@ class TestQwenMoe(unittest.TestCase):
|
||||
position_ids=position_ids,
|
||||
attn_metadata=attn_metadata)
|
||||
else:
|
||||
graph_runner = DecodingCUDAGraphRunner(
|
||||
attn_metadata.max_num_requests, "cuda", attn_metadata)
|
||||
graph_runner.capture(lambda inputs: qwen_moe.forward(**inputs))
|
||||
inputs = {
|
||||
"input_ids": input_ids,
|
||||
"position_ids": position_ids,
|
||||
"attn_metadata": attn_metadata,
|
||||
}
|
||||
graph_runner.capture(1,
|
||||
lambda inputs: qwen_moe.forward(**inputs),
|
||||
inputs)
|
||||
|
||||
for _ in range(2):
|
||||
# Run it twice. This helps us catch problems if buffers are accidentally reallocated
|
||||
# in prepare().
|
||||
attn_metadata.prepare()
|
||||
logits = graph_runner.run({
|
||||
"input_ids": input_ids,
|
||||
"position_ids": position_ids,
|
||||
"attn_metadata": attn_metadata,
|
||||
})
|
||||
logits = graph_runner.replay(1, inputs)
|
||||
return logits
|
||||
|
||||
if scenario.use_cuda_graph:
|
||||
@ -353,5 +359,6 @@ class TestQwenMoe(unittest.TestCase):
|
||||
ref.logits[:, -1].float(),
|
||||
atol=0.1,
|
||||
rtol=0.1)
|
||||
|
||||
if graph_runner is not None:
|
||||
graph_runner.clear()
|
||||
kv_cache_manager.shutdown()
|
||||
|
||||
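One last detail that recurs throughout the diffs above: the static input buffers are sized for the worst case, `batch_size * max_beam_width * (max_draft_len + 1)` tokens, because they cannot be reallocated between replays; shorter real batches are copied into the front of them, and the buffers are initialized with ones rather than zeros (per the comment in `capture`). A small sketch of that arithmetic (the sizes below are hypothetical):

```python
import torch

def static_buffer_len(batch_size, max_beam_width, max_draft_len):
    # One slot per request per beam, each holding the target token plus
    # up to max_draft_len draft tokens (the spec-decode padding rule).
    token_per_request = max_draft_len + 1
    return batch_size * max_beam_width * token_per_request

buf_len = static_buffer_len(batch_size=8, max_beam_width=1, max_draft_len=3)
static_input_ids = torch.ones(buf_len, dtype=torch.int32)   # ones, not zeros
assert static_input_ids.shape[0] == 32

# A shorter real batch is copied into the front of the buffer before replay.
real_ids = torch.arange(20, dtype=torch.int32)
static_input_ids[: real_ids.shape[0]].copy_(real_ids)
```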