Mirror of https://github.com/NVIDIA/TensorRT-LLM.git, synced 2026-01-14 06:27:45 +08:00
[https://nvbugs/5534705][fix] Skip unnecessary CUDA graph capture (#8… (#8344)
Signed-off-by: ziyixiong-nv <219238287+ziyixiong-nv@users.noreply.github.com>
parent 838958c631
commit 4ad7ef1497
@@ -712,19 +712,25 @@ class PyTorchModelEngine(ModelEngine):
         cuda_graph_batch_sizes = sorted(self._cuda_graph_batch_sizes,
                                         reverse=True)
         # Create CUDA graphs for different draft lengths
-        draft_lengths = [self.max_draft_len]
-        # For non-draft model, we also capture the CUDA graph instance for draft length 0,
-        # so that when we disable spec decode at runtime, we can still run the captured graph.
-        # Note that for one engine mode, we are not able to turn off spec decode at runtime.
-        if (not self.is_draft_model and self.max_draft_len > 0
-                and not self.spec_config.spec_dec_mode.use_one_engine()
-                # Assume that speculation is always on if the user didn't give us a max_concurrency
-                # value. This will save on memory.
-                and self.spec_config.max_concurrency is not None):
-            draft_lengths.append(0)
-        if self.is_spec_decode and self.is_draft_model and spec_resource_manager is not None and isinstance(
-                spec_resource_manager, Eagle3ResourceManager):
-            draft_lengths.append(self.original_max_draft_len)
+        draft_lengths = []
+        if self.is_draft_model:
+            if self.model_is_wrapped and self.is_spec_decode and spec_resource_manager is not None and isinstance(
+                    spec_resource_manager, Eagle3ResourceManager):
+                # The CDL path uses draft_len > 0 for the number of iterations in the drafting loop.
+                draft_lengths.append(self.original_max_draft_len)
+            else:
+                draft_lengths.append(self.max_draft_len)
+        else:
+            # For non-draft model, we also capture the CUDA graph instance for draft length 0,
+            # so that when we disable spec decode at runtime, we can still run the captured graph.
+            # Note that for one engine mode, we are not able to turn off spec decode at runtime.
+            if (self.max_draft_len > 0
+                    and not self.spec_config.spec_dec_mode.use_one_engine()
+                    # Assume that speculation is always on if the user didn't give us a max_concurrency
+                    # value. This will save on memory.
+                    and self.spec_config.max_concurrency is not None):
+                draft_lengths.append(0)
+            draft_lengths = [self.max_draft_len]
 
         for bs in cuda_graph_batch_sizes:
             if bs > self.batch_size:
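For readers skimming the diff, the net effect of the new block is a per-engine choice of which draft lengths get a captured CUDA graph. Below is a rough standalone sketch of that selection, following the comments in the hunk above; the EngineInfo dataclass, the select_draft_lengths helper, and its flag names are illustrative stand-ins, not TensorRT-LLM APIs.

from dataclasses import dataclass
from typing import List, Optional


@dataclass
class EngineInfo:
    # Stand-ins for the PyTorchModelEngine attributes referenced in the diff (assumed names).
    is_draft_model: bool
    cdl_wrapped_eagle3: bool  # stands in for the model_is_wrapped + Eagle3ResourceManager check
    max_draft_len: int
    original_max_draft_len: int
    uses_one_engine: bool  # stands in for spec_config.spec_dec_mode.use_one_engine()
    max_concurrency: Optional[int]


def select_draft_lengths(e: EngineInfo) -> List[int]:
    """Pick the draft lengths to capture CUDA graphs for, per the comments in the diff."""
    if e.is_draft_model:
        # Draft model: capture only the length it will actually run with.
        return [e.original_max_draft_len] if e.cdl_wrapped_eagle3 else [e.max_draft_len]
    # Target model: capture max_draft_len, and additionally draft_len=0 when spec decode
    # may be disabled at runtime (two-engine mode with a max_concurrency threshold set).
    lengths = [e.max_draft_len]
    if (e.max_draft_len > 0 and not e.uses_one_engine
            and e.max_concurrency is not None):
        lengths.append(0)
    return lengths

# e.g. a target model with max_draft_len=3 and a max_concurrency threshold -> graphs for [3, 0]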
@@ -740,6 +746,7 @@ class PyTorchModelEngine(ModelEngine):
                 logger.info(
                     f"Run generation only CUDA graph warmup for batch size={bs}, draft_len={draft_len}"
                 )
 
+                self.enable_spec_decode = draft_len > 0 or self.is_draft_model
 
         def _update_draft_inference_state(is_first_draft: bool,
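The one-line addition in the second hunk ties speculation on or off to each warmup capture, so a draft length of 0 records a graph with spec decode disabled. A minimal sketch of that loop shape, under the assumption that the capture happens per (batch size, draft length) pair; warmup_cuda_graphs, engine, and capture_fn are hypothetical names, not the actual TensorRT-LLM implementation.

def warmup_cuda_graphs(engine, cuda_graph_batch_sizes, draft_lengths, capture_fn):
    """Capture one CUDA graph per (batch_size, draft_len) pair (illustrative only)."""
    for bs in cuda_graph_batch_sizes:
        if bs > engine.batch_size:
            # Graph sizes above the engine's max batch size are skipped.
            continue
        for draft_len in draft_lengths:
            # Mirrors the added line: speculation is enabled for this capture only if
            # the graph carries draft tokens or this engine is itself the draft model.
            engine.enable_spec_decode = draft_len > 0 or engine.is_draft_model
            capture_fn(bs, draft_len)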