[https://nvbugs/5534705][fix] Skip unnecessary CUDA graph capture (#8… (#8344)

Signed-off-by: ziyixiong-nv <219238287+ziyixiong-nv@users.noreply.github.com>
Ziyi Xiong 2025-10-16 10:27:19 +08:00 committed by GitHub
parent 838958c631
commit 4ad7ef1497

@@ -712,19 +712,25 @@ class PyTorchModelEngine(ModelEngine):
         cuda_graph_batch_sizes = sorted(self._cuda_graph_batch_sizes,
                                         reverse=True)
         # Create CUDA graphs for different draft lengths
-        draft_lengths = [self.max_draft_len]
-        # For non-draft model, we also capture the CUDA graph instance for draft length 0,
-        # so that when we disable spec decode at runtime, we can still run the captured graph.
-        # Note that for one engine mode, we are not able to turn off spec decode at runtime.
-        if (not self.is_draft_model and self.max_draft_len > 0
-                and not self.spec_config.spec_dec_mode.use_one_engine()
-                # Assume that speculation is always on if the user didn't give us a max_concurrency
-                # value. This will save on memory.
-                and self.spec_config.max_concurrency is not None):
-            draft_lengths.append(0)
-        if self.is_spec_decode and self.is_draft_model and spec_resource_manager is not None and isinstance(
-                spec_resource_manager, Eagle3ResourceManager):
-            draft_lengths.append(self.original_max_draft_len)
+        draft_lengths = []
+        if self.is_draft_model:
+            if self.model_is_wrapped and self.is_spec_decode and spec_resource_manager is not None and isinstance(
+                    spec_resource_manager, Eagle3ResourceManager):
+                # The CDL path uses draft_len > 0 for the number of iterations in the drafting loop.
+                draft_lengths.append(self.original_max_draft_len)
+            else:
+                draft_lengths.append(self.max_draft_len)
+        else:
+            # For non-draft model, we also capture the CUDA graph instance for draft length 0,
+            # so that when we disable spec decode at runtime, we can still run the captured graph.
+            # Note that for one engine mode, we are not able to turn off spec decode at runtime.
+            if (self.max_draft_len > 0
+                    and not self.spec_config.spec_dec_mode.use_one_engine()
+                    # Assume that speculation is always on if the user didn't give us a max_concurrency
+                    # value. This will save on memory.
+                    and self.spec_config.max_concurrency is not None):
+                draft_lengths.append(0)
+            draft_lengths.append(self.max_draft_len)

         for bs in cuda_graph_batch_sizes:
             if bs > self.batch_size:
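
Unrolled from the hunk above, the new selection logic reads as a small standalone function. The sketch below is illustrative only: EngineState, has_eagle3_resources, and use_one_engine are hypothetical stand-ins for the engine attributes, the Eagle3ResourceManager isinstance check, and the spec_dec_mode.use_one_engine() call.

# Illustrative sketch only; EngineState and its fields are hypothetical
# stand-ins for the real engine attributes used in the hunk above.
from dataclasses import dataclass
from typing import List, Optional


@dataclass
class EngineState:
    is_draft_model: bool
    model_is_wrapped: bool       # CDL drafting loop wraps the model
    is_spec_decode: bool
    has_eagle3_resources: bool   # stands in for the Eagle3ResourceManager check
    max_draft_len: int
    original_max_draft_len: int
    use_one_engine: bool         # stands in for spec_dec_mode.use_one_engine()
    max_concurrency: Optional[int]


def select_draft_lengths(s: EngineState) -> List[int]:
    draft_lengths: List[int] = []
    if s.is_draft_model:
        if s.model_is_wrapped and s.is_spec_decode and s.has_eagle3_resources:
            # CDL path: draft_len > 0 is the drafting-loop iteration count.
            draft_lengths.append(s.original_max_draft_len)
        else:
            draft_lengths.append(s.max_draft_len)
    else:
        # Capture an extra draft_len == 0 graph only when spec decode can be
        # disabled at runtime: two-engine mode with a max_concurrency cap.
        if (s.max_draft_len > 0 and not s.use_one_engine
                and s.max_concurrency is not None):
            draft_lengths.append(0)
        draft_lengths.append(s.max_draft_len)
    return draft_lengths

For a target model with max_draft_len=3, two-engine mode, and max_concurrency set, this yields [0, 3]: one graph for speculation disabled and one for the full draft length. A draft model now only ever captures a single length, which is what lets the warmup skip the unnecessary captures.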
@@ -740,6 +746,7 @@ class PyTorchModelEngine(ModelEngine):
                 logger.info(
                     f"Run generation only CUDA graph warmup for batch size={bs}, draft_len={draft_len}"
                 )
+                self.enable_spec_decode = draft_len > 0 or self.is_draft_model
         def _update_draft_inference_state(is_first_draft: bool,
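
The one-line addition in the second hunk matters because the warmup loop now iterates draft lengths that can include 0: the spec-decode flag has to track the graph being captured rather than stay pinned to its old value. A hedged sketch of the loop shape, reusing the definitions from the previous sketch; the function name and parameters here are made up for illustration.

import logging

logger = logging.getLogger(__name__)


def warmup_generation_graphs(state: EngineState,
                             cuda_graph_batch_sizes: List[int],
                             max_batch_size: int) -> None:
    # Sketch of the warmup loop; the real engine captures a CUDA graph for
    # each (batch size, draft length) pair that survives the filters.
    for bs in sorted(cuda_graph_batch_sizes, reverse=True):
        if bs > max_batch_size:
            continue  # this size can never be replayed; skip capturing it
        for draft_len in select_draft_lengths(state):
            logger.info("Run generation only CUDA graph warmup for "
                        "batch size=%d, draft_len=%d", bs, draft_len)
            # Mirrors the added line: speculation is on for this capture only
            # if the graph carries draft tokens, or if this engine is itself
            # the draft model.
            enable_spec_decode = draft_len > 0 or state.is_draft_model
            # ... capture the CUDA graph with enable_spec_decode applied ...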