Mirror of https://github.com/NVIDIA/TensorRT-LLM.git, synced 2026-01-14 06:27:45 +08:00
[TRTLLM-6650][fix] Enhance CUDA graph + Beam search to correctly handle padding (#6665)
Signed-off-by: Stefan Niebler <82932102+stnie@users.noreply.github.com>
parent e251f7c00b
commit b8f036f264
@@ -34,6 +34,7 @@ class DecodingCUDAGraphRunner:
         attn_metadata: AttentionMetadata,
         spec_metadata: Optional[SpecMetadata] = None,
         use_mrope: bool = False,
+        max_beam_width: int = 1,
     ) -> None:
         """
         Stores a CUDA graph and its associated input buffers.
@@ -49,19 +50,21 @@ class DecodingCUDAGraphRunner:
         e.g. FlashInfer cause graph breaks).
         """
         self.batch_size = batch_size
-
+        self.max_beam_width = max_beam_width
         # [CUDA graph spec decode padding]
         # We pad input IDs/position IDs to the maximum draft length (token per request).
         # We're forced to do this because we cannot reallocate inputs over many graph runs.
         token_per_request = spec_metadata.max_draft_len + 1 if spec_metadata is not None else 1

         # Using ones instead of zeros prevents NaNs in e.g. Deepseek
-        self.input_ids = torch.ones((batch_size * token_per_request, ),
-                                    device=device,
-                                    dtype=torch.int32)
-        self.position_ids = torch.zeros((1, batch_size * token_per_request),
-                                        device=device,
-                                        dtype=torch.int32)
+        self.input_ids = torch.ones(
+            (batch_size * max_beam_width * token_per_request, ),
+            device=device,
+            dtype=torch.int32)
+        self.position_ids = torch.zeros(
+            (1, batch_size * max_beam_width * token_per_request),
+            device=device,
+            dtype=torch.int32)
         self.mrope_position_deltas = torch.zeros(
             (batch_size,
              1), device=device, dtype=torch.int32) if use_mrope else None

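For illustration only (not part of the commit): a minimal sketch of the buffer-size arithmetic in the hunk above, using hypothetical values for batch size, beam width, and draft length, and CPU tensors so it runs anywhere.

import torch

# Hypothetical capture configuration (not taken from the commit).
batch_size = 8        # request slots captured in the graph
max_beam_width = 4    # beams kept alive per request
max_draft_len = 3     # speculative draft tokens per request

# Same padding rule as in the diff: one target token plus the draft tokens.
token_per_request = max_draft_len + 1

# The flattened buffers must cover every (request, beam, token) slot, because
# CUDA graphs cannot reallocate their input tensors between replays.
input_ids = torch.ones((batch_size * max_beam_width * token_per_request, ),
                       dtype=torch.int32)
position_ids = torch.zeros((1, batch_size * max_beam_width * token_per_request),
                           dtype=torch.int32)

assert input_ids.shape == (128, )
assert position_ids.shape == (1, 128)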
@@ -846,8 +846,8 @@ class PyTorchModelEngine(ModelEngine):
             spec_resource_manager: Optional[BaseResourceManager] = None) -> int:
         can_run_cuda_graph = scheduled_requests.can_run_cuda_graph
         batch_size = scheduled_requests.batch_size
-        new_batch_size = batch_size
-
+        # The number of sequences in the batch is the number of prompts times the beam width.
+        new_batch_size = batch_size * self.max_beam_width
         if self._run_cuda_graphs and self.enable_attention_dp and self.mapping.tp_size > 1:
             graph_batch_size = self.dist.tp_allgather(
                 [can_run_cuda_graph, batch_size])
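A small worked example of the comment above (sequences = prompts times beam width), again with hypothetical numbers:

# Hypothetical values, for illustration only.
num_prompts = 3       # scheduled_requests.batch_size in the hunk above
max_beam_width = 2

# During generation every prompt expands into max_beam_width sequences, so the
# CUDA graph batch size has to be chosen for prompts * beams, not prompts.
num_sequences = num_prompts * max_beam_width
assert num_sequences == 6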
@@ -981,8 +981,8 @@ class PyTorchModelEngine(ModelEngine):
             self._cuda_graphs[batch_size] = {}

         self._cuda_graphs[batch_size][draft_len] = DecodingCUDAGraphRunner(
-            num_sequences_in_batch, "cuda", attn_metadata, spec_metadata,
-            self.use_mrope)
+            batch_size, "cuda", attn_metadata, spec_metadata, self.use_mrope,
+            self.max_beam_width)
         return self._cuda_graphs[batch_size][draft_len]

     def __del__(self) -> None:
@@ -1376,8 +1376,11 @@ class PyTorchModelEngine(ModelEngine):
                 gather_ids.append(len(position_ids) - 1)

             request_ids.append(request.py_request_id)
-            gen_request_seq_slots.append(request.py_seq_slot)
             request.py_batch_idx = request.py_seq_slot
+            # Do not add a gen_request_seq_slot for CUDA graph dummy requests
+            # to prevent access errors due to None values
+            if not request.is_cuda_graph_dummy:
+                gen_request_seq_slots.append(request.py_seq_slot)

         previous_batch_len = len(previous_batch_indices)

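For illustration, a simplified stand-in for the dummy-request filtering above; the real request objects in TensorRT-LLM have many more fields, and the assumption that a CUDA graph dummy carries a None seq slot follows the comment in the diff.

from dataclasses import dataclass
from typing import Optional

@dataclass
class _FakeRequest:
    # Minimal stand-in for a generation request, for illustration only.
    py_seq_slot: Optional[int]
    is_cuda_graph_dummy: bool = False

requests = [
    _FakeRequest(py_seq_slot=0),
    _FakeRequest(py_seq_slot=1),
    _FakeRequest(py_seq_slot=None, is_cuda_graph_dummy=True),  # graph padding
]

# Mirrors the fix: padding dummies contribute no seq slot, so downstream
# indexing never sees a None value.
gen_request_seq_slots = [
    r.py_seq_slot for r in requests if not r.is_cuda_graph_dummy
]
assert gen_request_seq_slots == [0, 1]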
@@ -1506,7 +1509,7 @@ class PyTorchModelEngine(ModelEngine):
                 pin_memory=True,
             )

-        num_generation_requests = len(scheduled_requests.generation_requests)
+        num_generation_requests = len(gen_request_seq_slots)
         # Cache indirection is only used for beam search on generation requests
         if self.use_beam_search and num_generation_requests > 0:
             # CUDA Graph needs to set beam width during warmup (where the graph is captured), to ensure that cache indirection buffer is correctly picked up by the CUDA graph
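A hedged sketch of why those None slots matter for the cache-indirection lookup above; the buffer layout used here (roughly [slots, beams, max_seq_len]) is an assumption for illustration, not the exact layout in TensorRT-LLM.

import torch

# Assumed layout, for illustration only.
num_slots, beam_width, max_seq_len = 4, 2, 32
cache_indirection = torch.zeros((num_slots, beam_width, max_seq_len),
                                dtype=torch.int32)

# Seq slots of real generation requests; CUDA graph dummies are excluded
# upstream (see the previous hunk), so no None can reach this index.
gen_request_seq_slots = [0, 1]
selected = cache_indirection[torch.tensor(gen_request_seq_slots)]
assert selected.shape == (2, beam_width, max_seq_len)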
@@ -61,7 +61,8 @@ def llm_cuda_graph(fixed_params, input_prompts):
         max_seq_len=32,
         max_beam_width=fixed_params["max_beam_width"],
         disable_overlap_scheduler=False,
-        cuda_graph_config=CudaGraphConfig(),
+        cuda_graph_config=CudaGraphConfig(batch_sizes=[1, 2, 4, 8],
+                                          enable_padding=True),
     )

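Outside the test fixture, a sketch of the same configuration through the LLM API; the model path is a placeholder, and the exact import locations and sampling arguments may differ between TensorRT-LLM versions.

from tensorrt_llm import LLM, SamplingParams
from tensorrt_llm.llmapi import CudaGraphConfig

llm = LLM(
    model="/path/to/model",          # placeholder checkpoint path
    max_batch_size=8,
    max_seq_len=32,
    max_beam_width=2,
    cuda_graph_config=CudaGraphConfig(batch_sizes=[1, 2, 4, 8],
                                      enable_padding=True),
)

# Beam search: n beams per prompt, bounded by the engine's max_beam_width.
sampling_params = SamplingParams(max_tokens=8, n=2, use_beam_search=True)
outputs = llm.generate(["Hello, my name is"], sampling_params=sampling_params)
print(len(outputs))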
@@ -126,7 +127,7 @@ def test_beam_search_output_shapes(gather_context_logits: bool,
 @pytest.mark.parametrize("gather_generation_logits", [True, False])
 @pytest.mark.parametrize("gather_context_logits", [True, False])
 @pytest.mark.parametrize("num_output_beams", [1, 2])
-@pytest.mark.parametrize("num_prompts", [1, 2])
+@pytest.mark.parametrize("num_prompts", [1, 2, 3])
 @pytest.mark.threadleak(enabled=False)
 def test_beam_search_output_shapes_cuda_graph_and_overlap(
         gather_context_logits: bool, gather_generation_logits: bool,
@@ -145,6 +146,10 @@ def test_beam_search_output_shapes_cuda_graph_and_overlap(
         return_generation_logits=gather_generation_logits,
         logprobs=return_log_probs,
     )
+    # test padding of cuda graph with 3 prompts
+    # replicate the prompts to have more than 2 prompts available
+    if (num_prompts == 3 and len(input_prompts) == 2):
+        input_prompts = [input_prompts[0]] * 3
     outputs = llm_cuda_graph.generate(input_prompts[:num_prompts],
                                       sampling_params=sampling_params)
     assert len(outputs) == num_prompts