[TRTLLM-6650][fix] Enhance CUDA graph + Beam search to correctly handle padding (#6665)

Signed-off-by: Stefan Niebler <82932102+stnie@users.noreply.github.com>
Stefan Niebler 2025-08-08 14:00:33 +02:00 committed by GitHub
parent e251f7c00b
commit b8f036f264
3 changed files with 26 additions and 15 deletions


@@ -34,6 +34,7 @@ class DecodingCUDAGraphRunner:
         attn_metadata: AttentionMetadata,
         spec_metadata: Optional[SpecMetadata] = None,
         use_mrope: bool = False,
+        max_beam_width: int = 1,
     ) -> None:
         """
         Stores a CUDA graph and its associated input buffers.
@@ -49,19 +50,21 @@
             e.g. FlashInfer cause graph breaks).
         """
         self.batch_size = batch_size
+        self.max_beam_width = max_beam_width
         # [CUDA graph spec decode padding]
         # We pad input IDs/position IDs to the maximum draft length (token per request).
         # We're forced to do this because we cannot reallocate inputs over many graph runs.
         token_per_request = spec_metadata.max_draft_len + 1 if spec_metadata is not None else 1
         # Using ones instead of zeros prevents NaNs in e.g. Deepseek
-        self.input_ids = torch.ones((batch_size * token_per_request, ),
-                                    device=device,
-                                    dtype=torch.int32)
-        self.position_ids = torch.zeros((1, batch_size * token_per_request),
-                                        device=device,
-                                        dtype=torch.int32)
+        self.input_ids = torch.ones(
+            (batch_size * max_beam_width * token_per_request, ),
+            device=device,
+            dtype=torch.int32)
+        self.position_ids = torch.zeros(
+            (1, batch_size * max_beam_width * token_per_request),
+            device=device,
+            dtype=torch.int32)
         self.mrope_position_deltas = torch.zeros(
             (batch_size,
              1), device=device, dtype=torch.int32) if use_mrope else None
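
The net effect of this hunk is that the graph's static input buffers now cover every beam of every request instead of one sequence per request. A minimal standalone sketch of the resulting buffer shapes (the helper name allocate_graph_inputs is illustrative, not part of the patch):

import torch

def allocate_graph_inputs(batch_size: int, max_beam_width: int,
                          max_draft_len: int, device: str = "cpu"):
    # One token per sequence per step, plus the draft tokens when
    # speculative decoding is enabled.
    token_per_request = max_draft_len + 1
    num_tokens = batch_size * max_beam_width * token_per_request
    # Ones instead of zeros for input IDs, mirroring the NaN workaround above.
    input_ids = torch.ones((num_tokens, ), device=device, dtype=torch.int32)
    position_ids = torch.zeros((1, num_tokens), device=device, dtype=torch.int32)
    return input_ids, position_ids

# e.g. 4 requests with beam width 2 and no speculative decoding
# flatten to 4 * 2 * 1 = 8 tokens per graph run
input_ids, position_ids = allocate_graph_inputs(4, 2, max_draft_len=0)
assert input_ids.shape == (8, ) and position_ids.shape == (1, 8)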


@@ -846,8 +846,8 @@ class PyTorchModelEngine(ModelEngine):
             spec_resource_manager: Optional[BaseResourceManager] = None) -> int:
         can_run_cuda_graph = scheduled_requests.can_run_cuda_graph
         batch_size = scheduled_requests.batch_size
         # The number of sequences in the batch is the number of prompts times the beam width.
-        new_batch_size = batch_size * self.max_beam_width
+        new_batch_size = batch_size
         if self._run_cuda_graphs and self.enable_attention_dp and self.mapping.tp_size > 1:
             graph_batch_size = self.dist.tp_allgather(
                 [can_run_cuda_graph, batch_size])
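
Because the runner now multiplies by the beam width internally, the engine computes the padded batch size in prompt space (requests) rather than sequence space. Roughly, under the assumption that padding rounds a batch up to the nearest captured graph size (the helper below is hypothetical; the real rounding logic lives elsewhere in this file):

from bisect import bisect_left

def round_up_to_graph_batch_size(num_prompts, captured_batch_sizes=(1, 2, 4, 8)):
    # Hypothetical helper: smallest captured CUDA graph batch size that fits
    # the current number of prompts (requests, not sequences).
    sizes = sorted(captured_batch_sizes)
    idx = bisect_left(sizes, num_prompts)
    return sizes[idx] if idx < len(sizes) else None  # None: no graph fits

# 3 prompts with beam width 2: pad 3 -> 4 prompts (one dummy request); the
# graph runner's buffers then cover 4 * 2 = 8 sequences internally.
assert round_up_to_graph_batch_size(3) == 4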
@@ -981,8 +981,8 @@
             self._cuda_graphs[batch_size] = {}
         self._cuda_graphs[batch_size][draft_len] = DecodingCUDAGraphRunner(
-            num_sequences_in_batch, "cuda", attn_metadata, spec_metadata,
-            self.use_mrope)
+            batch_size, "cuda", attn_metadata, spec_metadata, self.use_mrope,
+            self.max_beam_width)
         return self._cuda_graphs[batch_size][draft_len]

     def __del__(self) -> None:
@@ -1376,8 +1376,11 @@
             gather_ids.append(len(position_ids) - 1)
             request_ids.append(request.py_request_id)
-            gen_request_seq_slots.append(request.py_seq_slot)
             request.py_batch_idx = request.py_seq_slot
+            # Do not add a gen_request_seq_slot for CUDA graph dummy requests
+            # to prevent access errors due to None values
+            if not request.is_cuda_graph_dummy:
+                gen_request_seq_slots.append(request.py_seq_slot)

         previous_batch_len = len(previous_batch_indices)
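
CUDA graph dummy requests exist only to pad the batch and do not own a real sequence slot, so appending their (None) py_seq_slot would break any later tensor indexing over gen_request_seq_slots. A self-contained illustration of the failure mode the guard prevents (the Request dataclass is a stand-in, not the engine's request type):

from dataclasses import dataclass
from typing import Optional

import torch

@dataclass
class Request:  # stand-in for the engine's generation request
    py_seq_slot: Optional[int]
    is_cuda_graph_dummy: bool = False

requests = [Request(3), Request(7), Request(None, is_cuda_graph_dummy=True)]
gen_request_seq_slots = [
    r.py_seq_slot for r in requests if not r.is_cuda_graph_dummy
]

cache = torch.arange(16).reshape(8, 2)
rows = cache[gen_request_seq_slots]  # only real slots are gathered
assert rows.shape == (2, 2)
# cache[[3, 7, None]] would fail instead, which is what the guard avoids.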
@@ -1506,7 +1509,7 @@
             pin_memory=True,
         )
-        num_generation_requests = len(scheduled_requests.generation_requests)
+        num_generation_requests = len(gen_request_seq_slots)
         # Cache indirection is only used for beam search on generation requests
         if self.use_beam_search and num_generation_requests > 0:
             # CUDA Graph needs to set beam width during warmup (where the graph is captured), to ensure that cache indirection buffer is correctly picked up by the CUDA graph
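
Counting only the real generation requests here means the beam-search-only cache-indirection handling is skipped when the generation part of the batch consists solely of CUDA graph dummies. The warmup comment reflects a general CUDA graph property: a captured graph replays against fixed buffer addresses and shapes, so any beam-width-dependent buffer must already have its final shape when the graph is captured. A generic PyTorch sketch of that property (requires a GPU; this is not TensorRT-LLM's runner code):

import torch

def capture_times_two(static_src: torch.Tensor):
    # Capture a tiny graph that reads `static_src` and writes `static_out`.
    # After capture, replay always uses these exact buffers, which is why
    # shape decisions (e.g. beam width) must be final before warmup/capture.
    static_out = torch.empty_like(static_src)
    side = torch.cuda.Stream()
    side.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(side):  # warmup run before capture
        static_out.copy_(static_src * 2)
    torch.cuda.current_stream().wait_stream(side)
    graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(graph):
        static_out.copy_(static_src * 2)
    return graph, static_out

if torch.cuda.is_available():
    # e.g. a (batch, beam, tokens)-shaped buffer captured with beam width 2
    cache_indirection = torch.zeros(4, 2, 16, dtype=torch.int32, device="cuda")
    graph, out = capture_times_two(cache_indirection)
    cache_indirection.fill_(1)  # update the static buffer in place, then replay
    graph.replay()
    assert int(out.sum()) == 2 * cache_indirection.numel()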


@@ -61,7 +61,8 @@ def llm_cuda_graph(fixed_params, input_prompts):
         max_seq_len=32,
         max_beam_width=fixed_params["max_beam_width"],
         disable_overlap_scheduler=False,
-        cuda_graph_config=CudaGraphConfig(),
+        cuda_graph_config=CudaGraphConfig(batch_sizes=[1, 2, 4, 8],
+                                          enable_padding=True),
     )
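
With explicit batch_sizes and enable_padding=True, the 3-prompt case added below has no dedicated graph and is instead padded up to the captured size 4, which is exactly the path this fix exercises. For reference, roughly the same configuration through the LLM API looks like the sketch below; the model path and prompts are placeholders, and the SamplingParams beam-search arguments are my reading of the current API, not taken from this test:

from tensorrt_llm import LLM, SamplingParams
from tensorrt_llm.llmapi import CudaGraphConfig

llm = LLM(
    model="/path/to/checkpoint",  # placeholder
    max_beam_width=2,
    disable_overlap_scheduler=False,
    cuda_graph_config=CudaGraphConfig(batch_sizes=[1, 2, 4, 8],
                                      enable_padding=True),
)
sampling_params = SamplingParams(n=2, best_of=2, use_beam_search=True,
                                 max_tokens=8)
outputs = llm.generate(["Hello, my name is"] * 3,
                       sampling_params=sampling_params)
assert len(outputs) == 3  # one result per prompt, each with 2 beams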
@@ -126,7 +127,7 @@ def test_beam_search_output_shapes(gather_context_logits: bool,
 @pytest.mark.parametrize("gather_generation_logits", [True, False])
 @pytest.mark.parametrize("gather_context_logits", [True, False])
 @pytest.mark.parametrize("num_output_beams", [1, 2])
-@pytest.mark.parametrize("num_prompts", [1, 2])
+@pytest.mark.parametrize("num_prompts", [1, 2, 3])
 @pytest.mark.threadleak(enabled=False)
 def test_beam_search_output_shapes_cuda_graph_and_overlap(
         gather_context_logits: bool, gather_generation_logits: bool,
@@ -145,6 +146,10 @@ def test_beam_search_output_shapes_cuda_graph_and_overlap(
         return_generation_logits=gather_generation_logits,
         logprobs=return_log_probs,
     )
+    # test padding of cuda graph with 3 prompts
+    # replicate the prompts to have more than 2 prompts available
+    if (num_prompts == 3 and len(input_prompts) == 2):
+        input_prompts = [input_prompts[0]] * 3
     outputs = llm_cuda_graph.generate(input_prompts[:num_prompts],
                                       sampling_params=sampling_params)
     assert len(outputs) == num_prompts