[https://nvbugs/5517404][fix] Use the correct cuda graph for dynamic spec dec (#7728)

Signed-off-by: ziyixiong-nv <219238287+ziyixiong-nv@users.noreply.github.com>
Ziyi Xiong 2025-09-21 08:20:48 +08:00 committed by GitHub
parent 4509d97780
commit 897c4dd23b
9 changed files with 170 additions and 125 deletions
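The changes below replace uses of the static max_possible_draft_len with the engine's per-iteration draft_len, so that when speculative decoding is toggled on or off dynamically the runner captures and replays a CUDA graph whose shape matches the current iteration. A minimal sketch of that keying behavior, assuming a simplified cache with hypothetical names (not the exact API in the diff):

import torch
from typing import Dict, Optional, Tuple

class _SketchGraphCache:
    """Hypothetical, simplified cache keyed by (batch_size, draft_len)."""

    def __init__(self) -> None:
        self.graphs: Dict[Tuple[int, int], torch.cuda.CUDAGraph] = {}

    def lookup(self, batch_size: int, draft_len: int) -> Optional[torch.cuda.CUDAGraph]:
        # draft_len is 0 when spec decode is disabled for this iteration, so a
        # decode-only graph is selected instead of a spec-decode-sized one.
        return self.graphs.get((batch_size, draft_len))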

View File

@@ -40,9 +40,6 @@ class CUDAGraphRunner:
self.max_beam_width = engine.max_beam_width
self.spec_config = engine.spec_config
self.max_possible_draft_len = (self.spec_config.max_draft_len
if self.enable_spec_decode else 0)
self.graphs: Dict[Tuple[int, int], torch.cuda.CUDAGraph] = {}
self.graph_outputs: Dict[Tuple[int, int],
Callable[[], Optional[torch.Tensor]]] = {}
@@ -58,7 +55,7 @@ class CUDAGraphRunner:
"""Allocates static tensors sized for the largest possible batch."""
engine = self._get_engine()
token_per_request = self.max_possible_draft_len + 1
token_per_request = self.draft_len + 1
max_total_tokens = (self.max_supported_batch_size *
self.max_beam_width * token_per_request)
max_total_tokens = min(max_total_tokens, engine.max_num_tokens)
@@ -78,7 +75,7 @@ class CUDAGraphRunner:
@property
def enable_spec_decode(self):
return self._get_engine().is_spec_decode
return self._get_engine().enable_spec_decode
@property
def draft_len(self):
@@ -174,7 +171,7 @@ class CUDAGraphRunner:
# [CUDA graph spec decode padding]
# We pad input IDs/position IDs to the maximum draft length (token per request).
# We're forced to do this because we cannot reallocate inputs over many graph runs.
token_per_request = self.max_possible_draft_len + 1
token_per_request = self.draft_len + 1
num_tokens_for_capture = (batch_size * self.max_beam_width *
token_per_request)
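A rough illustration of the token budget implied by the padding comment above, with illustrative values that are not taken from the diff: using the per-iteration draft_len, a graph captured while spec decode is off sees one token per request instead of max_draft_len + 1.

def num_tokens_for_capture(batch_size: int, max_beam_width: int, draft_len: int) -> int:
    # One target token per request plus the draft tokens active this iteration.
    token_per_request = draft_len + 1
    return batch_size * max_beam_width * token_per_request

# e.g. batch_size=8, beam width 1: 8 tokens with spec decode off, 40 with draft_len=4
assert num_tokens_for_capture(8, 1, 0) == 8
assert num_tokens_for_capture(8, 1, 4) == 40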

View File

@@ -1511,7 +1511,6 @@ class PyTorchModelEngine(ModelEngine):
prompt_lengths.append(1 + self.runtime_draft_len)
else:
prompt_lengths.append(request.py_prompt_len)
for request in generation_requests:
request_ids.append(request.py_request_id)
beam_width = request.sampling_config.beam_width
@@ -1534,7 +1533,6 @@ class PyTorchModelEngine(ModelEngine):
if beam == first_beam:
previous_batch_indices.append(request.py_batch_idx)
past_seen_token_num = request.max_beam_num_tokens
position_ids.append(past_seen_token_num)
num_cached_tokens_per_seq.append(past_seen_token_num)
prompt_lengths.append(request.py_prompt_len)

View File

@@ -1198,13 +1198,18 @@ class PyExecutor:
previous_tensors = self.previous_batch and self.previous_batch.sample_state
target_inputs = None
draft_outputs = None
if self.drafter is not None and self.use_spec_decode:
# If there are previous draft tokens, we need to update the target requests to accept some draft tokens.
# When there are accepted tokens, we can't directly use the previous batch's outputs in this iteration for the target model,
# so we'll set the target model's input to None and skip updating the target requests after target model forward.
use_previous_draft_tokens = self.has_previous_draft_tokens
if self.drafter is not None and (self.use_spec_decode or
use_previous_draft_tokens):
target_inputs, draft_outputs, draft_batch = self._handle_speculative_decoding(
scheduled_batch, previous_tensors)
# Use the draft_model's outputs if we've launched the draft model.
# Otherwise, use the previous batch's outputs.
if target_inputs is not None:
if target_inputs is not None or use_previous_draft_tokens:
previous_tensors_device = target_inputs
else:
previous_tensors_device = self.previous_batch and self.previous_batch.sample_state and self.previous_batch.sample_state.device
@@ -1215,7 +1220,7 @@ class PyExecutor:
if target_inputs is not None:
self._process_draft_results(scheduled_batch,
draft_outputs, draft_batch)
elif self.previous_batch is not None:
elif self.previous_batch is not None and not use_previous_draft_tokens:
self._update_requests(self.previous_batch.sample_state)
if self.guided_decoder is not None:
@@ -1968,19 +1973,21 @@ class PyExecutor:
self.inflight_req_ids.erase(req.request_id)
def _handle_speculative_decoding(self, scheduled_batch, previous_tensors):
with request_context(is_draft=True, scheduled_requests=scheduled_batch):
with request_context(is_draft=self.draft_model_engine is not None,
scheduled_requests=scheduled_batch):
# Do an early check to see if we need to forward the draft model.
# If needed, the overlap should happen between the target requests and the draft requests.
# Otherwise, we can still overlap the previous target requests with the current target requests.
has_draft_batch = (
self.previous_batch is not None
self.previous_batch is not None and self.use_spec_decode
and self.drafter.should_forward_draft_model(scheduled_batch))
if has_draft_batch:
if has_draft_batch or self.has_previous_draft_tokens:
self._update_requests(self.previous_batch.sample_state)
if self.has_previous_draft_tokens:
self._prepare_draft_requests()
if has_draft_batch:
target_inputs, draft_outputs, draft_batch = self.drafter.generate_draft_tokens_with_overlap(
scheduled_batch, self.resource_manager,
previous_tensors.device if previous_tensors else None)
@@ -1997,26 +2004,27 @@ class PyExecutor:
"""
Append the draft tokens to the target requests, and clean up the draft resources.
"""
req_id_to_old_request = {
req.py_request_id: req
for req in scheduled_batch.all_requests()
}
with request_context(is_draft=self.draft_model_engine is not None,
scheduled_requests=scheduled_batch):
req_id_to_old_request = {
req.py_request_id: req
for req in scheduled_batch.all_requests()
}
if self.drafter.use_static_draft_loop:
self.drafter.process_static_draft_outputs(draft_outputs,
draft_batch,
req_id_to_old_request)
elif draft_outputs is not None:
self.drafter.process_dynamic_draft_outputs(draft_outputs,
req_id_to_old_request)
if self.drafter.use_static_draft_loop:
self.drafter.process_static_draft_outputs(
draft_outputs, draft_batch, req_id_to_old_request)
elif draft_outputs is not None:
self.drafter.process_dynamic_draft_outputs(
draft_outputs, req_id_to_old_request)
# Pad draft tokens to the max draft length. This is for CUDA graph compatibility.
self.drafter.pad_draft_tokens_for_cuda_graph(scheduled_batch)
# add_batch must be called again to restore the target requests with updated draft tokens.
if self.guided_decoder is not None:
self.guided_decoder.add_batch(scheduled_batch)
if hasattr(self.drafter, "guided_decoder"):
self.guided_decoder.rollback_draft_tokens()
# Pad draft tokens to the max draft length. This is for CUDA graph compatibility.
self.drafter.pad_draft_tokens_for_cuda_graph(scheduled_batch)
# add_batch must be called again to restore the target requests with updated draft tokens.
if self.guided_decoder is not None:
self.guided_decoder.add_batch(scheduled_batch)
if hasattr(self.drafter, "guided_decoder"):
self.guided_decoder.rollback_draft_tokens()
class DisaggPPTerminationHandler:
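A hedged sketch of the decision the executor makes above when the overlap scheduler is active, with simplified arguments (the real code passes full batch and sample-state objects):

from typing import Any, Optional

def pick_previous_tensors_device(target_inputs: Optional[Any],
                                 use_previous_draft_tokens: bool,
                                 previous_batch: Optional[Any]) -> Optional[Any]:
    # If the draft model ran this iteration (target_inputs is set), or the previous
    # iteration left draft tokens that were just accepted, the previous batch's
    # device outputs cannot be reused directly for the target model.
    if target_inputs is not None or use_previous_draft_tokens:
        return target_inputs
    if previous_batch is not None and previous_batch.sample_state is not None:
        return previous_batch.sample_state.device
    return None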

View File

@@ -41,11 +41,10 @@ from .model_engine import PyTorchModelEngine
from .py_executor import PyExecutor
# Development flag to control chain drafter feature
# Development function to control chain drafter feature.
# It's here so that unit tests can mock it and turn it off.
def _get_allow_chain_drafter() -> bool:
"""Get the chain drafter flag from environment variable."""
# Use environment variable for cross-process compatibility
return os.getenv("TRTLLM_ALLOW_CHAIN_DRAFTER", "0") == "1"
return True
class _ExecutorCreationStage(enum.Enum):
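With _get_allow_chain_drafter now returning True unconditionally, unit tests disable the chain drafter by patching the function rather than setting TRTLLM_ALLOW_CHAIN_DRAFTER. A minimal usage sketch, using the module path that appears in the test diff below:

from unittest.mock import patch

with patch(
        'tensorrt_llm._torch.pyexecutor.py_executor_creator._get_allow_chain_drafter',
        return_value=False):
    # Any executor built inside this block skips the chain drafter path.
    ...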

View File

@@ -563,6 +563,9 @@ class TorchSampler(Sampler):
if get_draft_token_length(req) > 0:
req.py_num_accepted_draft_tokens = num_accepted
req.py_rewind_len = req.py_draft_pages_allocated - num_accepted
else:
req.py_num_accepted_draft_tokens = 0
req.py_rewind_len = 0
processed += num_accepted
self.handle_logprobs(req, state, beam=self.BEAM, count=processed)
req.py_decoding_iter += 1
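The new else branch matters when spec decode was active in an earlier iteration but is off now; a small, hypothetical illustration of the bookkeeping (stand-in request object, not the real LlmRequest):

class _FakeRequest:
    # Stale values left over from a spec-decode iteration.
    py_num_accepted_draft_tokens = 3
    py_rewind_len = 1

def update_draft_bookkeeping(req, num_accepted: int, draft_token_length: int,
                             draft_pages_allocated: int) -> None:
    if draft_token_length > 0:
        req.py_num_accepted_draft_tokens = num_accepted
        req.py_rewind_len = draft_pages_allocated - num_accepted
    else:
        # No draft tokens this iteration: clear the counters so nothing is
        # rewound based on a previous iteration's draft tokens.
        req.py_num_accepted_draft_tokens = 0
        req.py_rewind_len = 0

req = _FakeRequest()
update_draft_bookkeeping(req, num_accepted=0, draft_token_length=0, draft_pages_allocated=0)
assert req.py_rewind_len == 0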

View File

@@ -396,6 +396,7 @@ class ModelDrafter(Drafter):
new_tokens_lens = None
next_draft_tokens = None
has_draft_tokens = False
batch_size = new_tokens.shape[1]
# Iterate through generation requests and copy tokens based on accepted draft tokens
for request in scheduled_batch.all_requests():
idx = request.py_seq_slot
@@ -411,9 +412,8 @@ class ModelDrafter(Drafter):
if has_draft_tokens:
# We already updated the target state, so the new_tokens_lens should be all ones.
new_tokens_lens = torch.ones(scheduled_batch.batch_size,
device=device)
next_draft_tokens = torch.zeros(scheduled_batch.batch_size,
new_tokens_lens = torch.ones(batch_size, device=device)
next_draft_tokens = torch.zeros(batch_size,
self.max_draft_tokens,
device=device)
@@ -438,15 +438,15 @@
Update target inputs with new draft tokens from sample state.
"""
if draft_tensors is not None:
for request in draft_batch.all_requests():
for req_idx, request in enumerate(draft_batch.all_requests()):
# Skip prefill requests
if target_inputs.next_draft_tokens is None:
continue
# Get the index of the draft/target tokens in the device tensor
draft_idx = request.py_seq_slot
draft_idx = req_idx if self.use_static_draft_loop else request.py_batch_idx
target_idx = req_id_to_old_request[
request.py_request_id].py_seq_slot
request.py_request_id].py_batch_idx
target_inputs.new_tokens[draft_position + 1:draft_position +
draft_length + 1, target_idx,
0] = draft_tensors[0:draft_length,
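A hedged sketch of the index mapping this hunk changes: rows are copied from the draft batch position (the enumeration index for the static draft loop, py_batch_idx otherwise) into the column the corresponding target request occupies, now identified by its py_batch_idx rather than py_seq_slot. Shapes and names below are illustrative only:

import torch

def scatter_draft_tokens(new_tokens: torch.Tensor,     # [1 + max_draft_len, batch, beam]
                         draft_tensors: torch.Tensor,  # [draft_length, draft_batch, beam]
                         draft_position: int, draft_length: int,
                         draft_idx: int, target_idx: int) -> None:
    # Copy the freshly drafted tokens for one request into the target model's
    # input slot for the next overlap iteration.
    new_tokens[draft_position + 1:draft_position + draft_length + 1, target_idx, 0] = \
        draft_tensors[0:draft_length, draft_idx, 0]

# Illustrative shapes: 2 requests, max_draft_len = 4, beam width 1
new_tokens = torch.zeros(5, 2, 1, dtype=torch.long)
draft = torch.arange(1, 9, dtype=torch.long).reshape(4, 2, 1)
scatter_draft_tokens(new_tokens, draft, draft_position=0, draft_length=4,
                     draft_idx=0, target_idx=1)
assert new_tokens[1:5, 1, 0].tolist() == draft[:, 0, 0].tolist()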

View File

@@ -188,6 +188,7 @@ def create_mock_engine(batch_size: int):
max_beam_width=1,
max_num_tokens=8192,
is_spec_decode=False,
enable_spec_decode=False,
spec_config=None,
_cuda_graph_mem_pool=None,
use_mrope=False,

View File

@@ -1,22 +1,32 @@
import os
import sys
import unittest
from unittest.mock import patch
from unittest.mock import Mock, patch
import pytest
import torch
from utils.llm_data import llm_models_root
from tensorrt_llm import LLM, SamplingParams
from tensorrt_llm._torch.pyexecutor.llm_request import LlmRequestState
from tensorrt_llm.llmapi import (CudaGraphConfig, EagleDecodingConfig,
KvCacheConfig)
sys.path.append(os.path.join(os.path.dirname(__file__), '..'))
@pytest.fixture(scope="function")
def enforce_single_worker(monkeypatch):
monkeypatch.setenv("TLLM_WORKER_USE_SINGLE_PROCESS", "1")
yield
@pytest.mark.parametrize("disable_overlap_scheduler", [True, False])
@pytest.mark.high_cuda_memory
def test_dynamic_spec_decode(disable_overlap_scheduler: bool):
def test_dynamic_spec_decode(enforce_single_worker,
disable_overlap_scheduler: bool):
# mock_should_use_spec_decode doesn't work with multiple processes,
# so we use the enforce_single_worker fixture to set the environment variable.
total_mem_gb = torch.cuda.get_device_properties(0).total_memory / 1e9
if total_mem_gb < 35:
pytest.skip("Not enough memory to load target + draft model")
@@ -51,32 +61,42 @@ def test_dynamic_spec_decode(disable_overlap_scheduler: bool):
eagle3_one_model=False,
)
# Mock should_use_spec_decode to return True for first two calls, then False
llm_spec = LLM(**llm_common_config, speculative_config=spec_config)
sampling_params = SamplingParams(max_tokens=128, temperature=0)
# Output tests
prompts = [
"The president of the United States is",
]
sampling_params = SamplingParams(max_tokens=20, temperature=0)
# Mock should_use_spec_decode to turn on/off spec decode dynamically.
def mock_should_use_spec_decode(requests, max_batch_size, max_num_tokens,
max_draft_len):
if not hasattr(mock_should_use_spec_decode, 'call_count'):
mock_should_use_spec_decode.call_count = 0
mock_should_use_spec_decode.call_count += 1
return mock_should_use_spec_decode.call_count <= 2
for req in requests:
if req.state != LlmRequestState.GENERATION_IN_PROGRESS:
continue
mock_should_use_spec_decode.call_count += 1
# Turn off spec decode after the 5th call.
# In the current case, at the 5th call there are 2 accepted draft tokens,
# which gives better coverage of switching spec decode on and off.
if mock_should_use_spec_decode.call_count > 5:
return False
return True
# Create a Mock object with the mock function as side_effect
mock_should_use_spec_decode = Mock(side_effect=mock_should_use_spec_decode)
# Reset mock state before using it
mock_should_use_spec_decode.reset_mock()
mock_should_use_spec_decode.call_count = 0
with patch(
'tensorrt_llm._torch.speculative.model_drafter.ModelDrafter.should_use_spec_decode',
side_effect=mock_should_use_spec_decode):
llm_spec = LLM(**llm_common_config, speculative_config=spec_config)
sampling_params = SamplingParams(max_tokens=128, temperature=0)
# Output tests
prompts = [
"The capital of France is",
"The president of the United States is",
]
sampling_params = SamplingParams(max_tokens=10, temperature=0)
mock_should_use_spec_decode):
results_spec = llm_spec.generate(prompts, sampling_params)
generated_text_spec = [
result.outputs[0].text for result in results_spec
]
llm_spec.shutdown()
generated_text_spec = [result.outputs[0].text for result in results_spec]
llm_spec.shutdown()
llm_ref = LLM(**llm_common_config)
results_ref = llm_ref.generate(prompts, sampling_params)

View File

@@ -4,6 +4,7 @@ import sys
import tempfile
import unittest
from pathlib import Path
from unittest.mock import patch
import pytest
import torch
@@ -16,35 +17,50 @@ from tensorrt_llm.llmapi import (CudaGraphConfig, EagleDecodingConfig,
sys.path.append(os.path.join(os.path.dirname(__file__), '..'))
@pytest.fixture(scope="function")
def enforce_single_worker(monkeypatch):
monkeypatch.setenv("TLLM_WORKER_USE_SINGLE_PROCESS", "1")
yield
@pytest.mark.parametrize(
"use_cuda_graph,attn_backend,disable_overlap_scheduler,enable_block_reuse,use_one_model,enable_chunked_prefill,use_chain_drafter",
"use_cuda_graph,attn_backend,disable_overlap_scheduler,enable_block_reuse,use_one_model,enable_chunked_prefill,use_chain_drafter,multi_batch",
[
[True, "TRTLLM", True, False, False, False, True],
[True, "TRTLLM", True, False, False, False, False],
[False, "TRTLLM", True, False, False, False, True],
[False, "TRTLLM", True, False, False, False, False],
[True, "FLASHINFER", True, False, False, False, True],
[False, "FLASHINFER", True, False, False, False, True],
[False, "TRTLLM", False, True, True, False, True],
[True, "TRTLLM", False, True, True, False, True],
[True, "TRTLLM", True, False, True, True, True],
[True, "TRTLLM", True, False, True, False, True],
[True, "TRTLLM", True, False, False, False, True, False],
[True, "TRTLLM", True, False, False, False, False, False],
[False, "TRTLLM", True, False, False, False, True, False],
[False, "TRTLLM", True, False, False, False, False, False],
[True, "FLASHINFER", True, False, False, False, True, False],
[False, "FLASHINFER", True, False, False, False, True, False],
[False, "TRTLLM", False, True, True, False, True, False],
[True, "TRTLLM", False, True, True, False, True, False],
[True, "TRTLLM", True, False, True, True, True, False],
[True, "TRTLLM", True, False, True, False, True, False],
# TODO: nvbugs/5461761
# [True, "TRTLLM", True, False, False, True, True],
[True, "TRTLLM", False, False, False, False, True],
[False, "TRTLLM", False, False, False, False, True],
[True, "TRTLLM", False, False, False, False, False],
[False, "TRTLLM", False, False, False, False, False],
[True, "TRTLLM", False, False, False, True, True],
[True, "TRTLLM", False, False, False, True, False],
[True, "FLASHINFER", False, False, False, False, True],
[False, "FLASHINFER", False, False, False, False, True],
# [True, "TRTLLM", True, False, False, True, True, False],
[True, "TRTLLM", False, False, False, False, True, False],
[False, "TRTLLM", False, False, False, False, True, False],
[True, "TRTLLM", False, False, False, False, False, True],
[False, "TRTLLM", False, False, False, False, False, True],
[True, "TRTLLM", False, False, False, False, True, True],
[False, "TRTLLM", False, False, False, False, True, True],
[True, "TRTLLM", False, False, False, False, False, False],
[False, "TRTLLM", False, False, False, False, False, False],
[True, "TRTLLM", False, False, False, True, True, False],
[True, "TRTLLM", False, False, False, True, False, False],
[True, "FLASHINFER", False, False, False, False, True, False],
[False, "FLASHINFER", False, False, False, False, True, False],
])
@pytest.mark.high_cuda_memory
def test_llama_eagle3(use_cuda_graph: bool, attn_backend: str,
disable_overlap_scheduler: bool, enable_block_reuse: bool,
use_one_model: bool, enable_chunked_prefill: bool,
use_chain_drafter: bool):
use_chain_drafter: bool, multi_batch: bool, request):
# Use the enforce_single_worker fixture only when use_chain_drafter is False.
# Otherwise, we can't patch the return value of _get_allow_chain_drafter across worker processes.
if not use_chain_drafter:
request.getfixturevalue('enforce_single_worker')
# Eagle3 one model works with overlap scheduler and block reuse.
total_mem_gb = torch.cuda.get_device_properties(0).total_memory / 1e9
if total_mem_gb < 35:
@@ -54,46 +70,52 @@ def test_llama_eagle3(use_cuda_graph: bool, attn_backend: str,
eagle_model_dir = f"{models_path}/EAGLE3-LLaMA3.1-Instruct-8B"
target_model_dir = f"{models_path}/llama-3.1-model/Llama-3.1-8B-Instruct"
# bs > 1 gives non-deterministic results when doing IFB. There is a slight chance
# that ref and spec do not match 100%.
max_batch_size = 1
max_draft_len = 4
kv_cache_config = KvCacheConfig(enable_block_reuse=enable_block_reuse,
max_tokens=8192)
cuda_graph_config = CudaGraphConfig(
batch_sizes=[1]) if use_cuda_graph else None
# Mock _get_allow_chain_drafter to return False when use_chain_drafter is False
if not use_chain_drafter:
patch_context = patch(
'tensorrt_llm._torch.pyexecutor.py_executor_creator._get_allow_chain_drafter',
return_value=False)
else:
patch_context = patch(
'tensorrt_llm._torch.pyexecutor.py_executor_creator._get_allow_chain_drafter',
return_value=True)
llm_common_config = dict(
model=target_model_dir,
attn_backend=attn_backend,
disable_overlap_scheduler=disable_overlap_scheduler,
cuda_graph_config=cuda_graph_config,
max_batch_size=max_batch_size,
kv_cache_config=kv_cache_config,
# This max_seq_len is larger than the one specified
# in the llama 3 8B eagle's config. We want to make sure
# that the draft model won't go above its max in warmup
# in this test.
max_seq_len=8192,
enable_chunked_prefill=enable_chunked_prefill,
)
if enable_chunked_prefill:
# Use a small max_num_tokens so that the chunked prefill path gets exercised.
llm_common_config['max_num_tokens'] = 64
with patch_context:
# bs > 1 gives non-deterministic results when doing IFB. There is a slight chance
# that ref and spec do not match 100%.
max_batch_size = 4 if multi_batch else 1
max_draft_len = 4
kv_cache_config = KvCacheConfig(enable_block_reuse=enable_block_reuse,
max_tokens=8192)
cuda_graph_config = CudaGraphConfig(
batch_sizes=[1]) if use_cuda_graph else None
spec_config = EagleDecodingConfig(
max_draft_len=max_draft_len,
speculative_model_dir=eagle_model_dir,
# Llama 3 does not support one model eagle.
eagle3_one_model=use_one_model,
)
llm_common_config = dict(
model=target_model_dir,
attn_backend=attn_backend,
disable_overlap_scheduler=disable_overlap_scheduler,
cuda_graph_config=cuda_graph_config,
max_batch_size=max_batch_size,
kv_cache_config=kv_cache_config,
# This max_seq_len is larger than the one specified
# in the llama 3 8B eagle's config. We want to make sure
# that the draft model won't go above its max in warmup
# in this test.
max_seq_len=8192,
enable_chunked_prefill=enable_chunked_prefill,
)
if enable_chunked_prefill:
# Use a small max_num_tokens so that the chunked prefill path gets exercised.
llm_common_config['max_num_tokens'] = 64
# Set the development flag to control use_chain_drafter behavior
original_env_value = os.environ.get("TRTLLM_ALLOW_CHAIN_DRAFTER", "0")
try:
os.environ[
"TRTLLM_ALLOW_CHAIN_DRAFTER"] = "1" if use_chain_drafter else "0"
# Create the LLM instance with the mocked flag controlling use_chain_drafter
spec_config = EagleDecodingConfig(
max_draft_len=max_draft_len,
speculative_model_dir=eagle_model_dir,
# Llama 3 does not support one model eagle.
eagle3_one_model=use_one_model,
)
# Create the LLM instance
llm_spec = LLM(**llm_common_config, speculative_config=spec_config)
# Acceptance rate tests
@@ -142,9 +164,6 @@ def test_llama_eagle3(use_cuda_graph: bool, attn_backend: str,
for text_spec, text_ref in zip(generated_text_spec, generated_text_ref):
# The spec decode algorithm currently guarantees identical results
assert text_spec == text_ref
finally:
# Restore the original environment variable value
os.environ["TRTLLM_ALLOW_CHAIN_DRAFTER"] = original_env_value
def test_deepseek_eagle3():