[TRTLLM-6650][feat] Enhance beam search support with CUDA graph integration (#6217)

Signed-off-by: Stefan Niebler <82932102+stnie@users.noreply.github.com>
Stefan Niebler 2025-07-24 18:04:41 +02:00 committed by GitHub
parent ff72ca90de
commit 0df758ec9f
5 changed files with 63 additions and 36 deletions


@ -135,6 +135,9 @@ class AttentionMetadata:
_num_ctx_tokens: int = field(init=False, default=0, repr=False)
_num_tokens: int = field(init=False, default=0, repr=False)
# This buffer is currently only used for TrtllmAttentionMetadata.
cache_indirection: Optional[torch.Tensor] = None
def __post_init__(self) -> None:
if self.is_cross:
assert self.cross is None or self.cross is self, "Cross attention metadata should not have sub metadata"
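For context on the new cache_indirection field: a rough, self-contained sketch of how a cache indirection tensor is typically interpreted during beam search. Shapes, names, and values here are illustrative; the real TRT-LLM kernels consume the buffer inside the attention op.

import torch

# Entry [b, beam, t] names the beam whose KV-cache entry should be read at
# step t when computing attention for sequence (b, beam).
B, BEAM, T, D = 1, 2, 4, 8
kv = torch.randn(B, BEAM, T, D)                     # toy per-beam KV cache
cache_indirection = torch.zeros(B, BEAM, T, dtype=torch.long)
cache_indirection[:, 1, 2:] = 1                     # beam 1 reuses beam 0's entries for steps 0-1, its own afterwards

idx = cache_indirection.unsqueeze(-1).expand(B, BEAM, T, D)
effective_kv = torch.gather(kv, dim=1, index=idx)   # the KV each beam actually attends to
assert torch.equal(effective_kv[:, 1, :2], kv[:, 0, :2])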


@ -517,10 +517,9 @@ class TrtllmAttentionWrapper:
class TrtllmAttentionMetadata(AttentionMetadata):
workspace: Optional[torch.Tensor] = None
# TrtllmAttention needs to know the beam width and access to the cache indirection buffer,
# TrtllmAttention needs to know the beam width to access the cache indirection buffer,
# when beam search is enabled.
beam_width: int = 1
cache_indirection: Optional[torch.Tensor] = None
# TrtllmAttention needs to know the max sequence length.
# Implemented as a property to support no cache mode.


@ -392,9 +392,6 @@ class PyTorchModelEngine(ModelEngine):
self._cuda_graphs = {}
self._cuda_graph_mem_pool = self._torch_compile_backend._graph_pool_handle if self._torch_compile_enabled else None
self._run_cuda_graphs = pytorch_backend_config.use_cuda_graph
if self._run_cuda_graphs and self.max_beam_width > 1:
raise NotImplementedError(
"CUDA Graph + beam search is not implemented yet.")
self._cuda_graph_padding_enabled = pytorch_backend_config.cuda_graph_padding_enabled
@ -425,6 +422,17 @@ class PyTorchModelEngine(ModelEngine):
self.lora_model_config: Optional[LoraModelConfig] = None
self.cuda_graph_dummy_request = None
# Set up the local cache indirection buffer only once and reuse it.
# This way it can also be used for CUDA graphs.
if self.use_beam_search:
self.cache_indirection_attention = torch.zeros(
(self.batch_size, self.max_beam_width, self.max_seq_len +
(0 if self._disable_overlap_scheduler else 1)),
device="cuda",
dtype=torch.int32)
else:
self.cache_indirection_attention = None
def set_lora_model_config(self, lora_target_modules: list[str],
trtllm_modules_to_hf_modules: dict[str, str]):
self.lora_model_config = LoraModelConfig(
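The cache_indirection_attention buffer above is allocated once up front (rather than per iteration) because a CUDA graph replays fixed device addresses; inputs have to be refreshed in place between replays. A minimal sketch of that constraint, assuming a CUDA device and using only standard PyTorch graph-capture APIs, not TRT-LLM code:

import torch

static_in = torch.zeros(2, 4, 8, device="cuda", dtype=torch.int32)  # toy (batch, beam, seq) buffer
static_out = torch.empty_like(static_in)

g = torch.cuda.CUDAGraph()
with torch.cuda.graph(g):
    static_out.copy_(static_in + 1)  # placeholder work that reads the captured address

static_in.fill_(41)                  # refresh the contents in place ...
g.replay()                           # ... and the replayed graph picks up the new values
assert int(static_out[0, 0, 0]) == 42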
@ -444,6 +452,10 @@ class PyTorchModelEngine(ModelEngine):
logger.info(f"Detected use_mrope: {use_mrope}")
return use_mrope
@property
def use_beam_search(self):
return self.max_beam_width > 1
@contextmanager
def set_warmup_flag(self):
self.in_warmup = True
@ -487,7 +499,9 @@ class PyTorchModelEngine(ModelEngine):
self.cuda_graph_dummy_request = None
def get_cuda_graph_warmup_request(batch_size):
available_blocks = kv_cache_manager.get_num_free_blocks()
# Divide by max_beam_width to get an approximation of the number of requests that can be run in parallel.
available_blocks = kv_cache_manager.get_num_free_blocks(
) // self.max_beam_width
if available_blocks >= batch_size:
result = ScheduledRequests()
result.context_requests = []
@ -498,9 +512,10 @@ class PyTorchModelEngine(ModelEngine):
is_gen=True,
max_num_draft_tokens=self.max_draft_len,
use_mrope=use_mrope,
)
max_beam_width=self.max_beam_width)
# Divide by max_beam_width to get an approximation of the number of tokens that can be added to the final request.
available_tokens = kv_cache_manager.get_num_available_tokens(
self.max_draft_len)
self.max_draft_len) // self.max_beam_width
# Add one dummy request with the maximum possible sequence length.
# The sequence length is limited by both the max_seq_len and the number of available blocks.
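The divisions by max_beam_width above are simple capacity arithmetic: with beam search, each warmup request reserves roughly max_beam_width times the KV-cache resources of a single-beam request. A small illustrative helper with made-up numbers:

def warmup_capacity(free_blocks: int, available_tokens: int, max_beam_width: int):
    # Approximate how many dummy requests fit and how long the longest one may be.
    requests_that_fit = free_blocks // max_beam_width
    tokens_for_longest_request = available_tokens // max_beam_width
    return requests_that_fit, tokens_for_longest_request

print(warmup_capacity(free_blocks=1024, available_tokens=32768, max_beam_width=4))
# (256, 8192)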
@ -511,7 +526,7 @@ class PyTorchModelEngine(ModelEngine):
is_gen=True,
max_num_draft_tokens=self.max_draft_len,
use_mrope=use_mrope,
)[0]
max_beam_width=self.max_beam_width)[0]
# Add the longest request before all other seq_len=1 requests to simulate the CUDA graph padding case.
# This batch contains both the longest request and the shortest requests;
# it also contains the maximum number of requests and the maximum number of tokens,
@ -739,6 +754,7 @@ class PyTorchModelEngine(ModelEngine):
self.model.model_config.pretrained_config) and (
self.attn_runtime_features.cache_reuse
or self.attn_runtime_features.chunked_prefill)
cache_indirection = self.cache_indirection_attention if self.attn_backend.Metadata is TrtllmAttentionMetadata else None
if kv_cache_manager is None:
return self.attn_backend.Metadata(
max_num_requests=self.batch_size,
@ -748,7 +764,8 @@ class PyTorchModelEngine(ModelEngine):
mapping=self.mapping,
runtime_features=self.attn_runtime_features,
enable_flash_mla=self.model.model_config.enable_flash_mla,
enable_paged_context_mla=enable_paged_context_mla)
enable_paged_context_mla=enable_paged_context_mla,
cache_indirection=cache_indirection)
if self.attn_metadata is not None:
# This assertion can be relaxed if needed: just create a new metadata
@ -764,7 +781,9 @@ class PyTorchModelEngine(ModelEngine):
mapping=self.mapping,
runtime_features=self.attn_runtime_features,
enable_flash_mla=self.model.model_config.enable_flash_mla,
enable_paged_context_mla=enable_paged_context_mla)
enable_paged_context_mla=enable_paged_context_mla,
cache_indirection=cache_indirection)
return self.attn_metadata
def _set_up_spec_metadata(
@ -795,7 +814,8 @@ class PyTorchModelEngine(ModelEngine):
kv_cache_manager) -> int:
can_run_cuda_graph = scheduled_requests.can_run_cuda_graph
batch_size = scheduled_requests.batch_size
new_batch_size = batch_size
# The number of sequences in the batch is the number of prompts times the beam width.
new_batch_size = batch_size * self.max_beam_width
if self._run_cuda_graphs and self.enable_attention_dp and self.mapping.tp_size > 1:
graph_batch_size = self.dist.tp_allgather(
[can_run_cuda_graph, batch_size])
@ -831,7 +851,8 @@ class PyTorchModelEngine(ModelEngine):
[MAX_UINT64 - 1],
is_gen=True,
max_num_draft_tokens=self.max_draft_len,
use_mrope=self.use_mrope)[0]
use_mrope=self.use_mrope,
max_beam_width=self.max_beam_width)[0]
self.cuda_graph_dummy_request.is_cuda_graph_dummy = True
scheduled_requests.generation_requests.extend(
@ -903,19 +924,21 @@ class PyTorchModelEngine(ModelEngine):
if batch_size not in self._cuda_graph_batch_sizes:
return None
num_sequences_in_batch = batch_size * self.max_beam_width
attn_metadata = self.attn_metadata.create_cuda_graph_metadata(
batch_size, False, spec_max_draft_tokens)
num_sequences_in_batch, False, spec_max_draft_tokens)
assert attn_metadata.is_cuda_graph
if self.is_spec_decode:
spec_metadata = self.spec_metadata.create_cuda_graph_metadata(
batch_size)
num_sequences_in_batch)
spec_metadata.draft_tokens = self.draft_tokens_cuda
else:
spec_metadata = None
self._cuda_graphs[batch_size] = DecodingCUDAGraphRunner(
batch_size, "cuda", attn_metadata, spec_metadata, self.use_mrope)
num_sequences_in_batch, "cuda", attn_metadata, spec_metadata,
self.use_mrope)
return self._cuda_graphs[batch_size]
def __del__(self) -> None:
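A note on the bookkeeping above: the graph cache stays keyed by the prompt-level batch size, while the runner and its attention/spec metadata are created for batch_size * max_beam_width sequence slots, since each prompt occupies beam-width rows during decoding. A hedged, self-contained sketch of that keying (stand-in types, not TRT-LLM code):

from typing import Dict

class _Runner:  # stand-in for DecodingCUDAGraphRunner
    def __init__(self, num_sequences: int):
        self.num_sequences = num_sequences

def get_or_create(graphs: Dict[int, _Runner], batch_size: int, max_beam_width: int) -> _Runner:
    if batch_size not in graphs:
        graphs[batch_size] = _Runner(num_sequences=batch_size * max_beam_width)
    return graphs[batch_size]

runner = get_or_create({}, batch_size=8, max_beam_width=2)
assert runner.num_sequences == 16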
@ -1439,16 +1462,16 @@ class PyTorchModelEngine(ModelEngine):
num_generation_requests = len(scheduled_requests.generation_requests)
# Cache indirection is only used for beam search on generation requests
if self.max_beam_width > 1 and num_generation_requests > 0 and cache_indirection_buffer is not None:
cache_indirection_attention = torch.zeros_like(
cache_indirection_buffer)
#Copy cache indirection to local buffer with offsets changing: seq_slots[i] -> i
cache_indirection_attention[:num_generation_requests].copy_(
cache_indirection_buffer[gen_request_seq_slots])
attn_metadata.cache_indirection = cache_indirection_attention
attn_metadata.beam_width = self.max_beam_width
if self.use_beam_search and num_generation_requests > 0:
# CUDA Graph needs to set the beam width during warmup (where the graph is captured) to ensure that the cache indirection buffer is correctly picked up by the CUDA graph.
is_cuda_graph_during_warmup = self.in_warmup and attn_metadata.is_cuda_graph
if cache_indirection_buffer is not None:
# Copy cache indirection to the local buffer, remapping offsets: seq_slots[i] -> i
self.cache_indirection_attention[:num_generation_requests].copy_(
cache_indirection_buffer[gen_request_seq_slots])
if cache_indirection_buffer is not None or is_cuda_graph_during_warmup:
attn_metadata.beam_width = self.max_beam_width
else:
attn_metadata.cache_indirection = None
attn_metadata.beam_width = 1
attn_metadata.request_ids = request_ids
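To make the seq_slots[i] -> i remapping concrete, a small self-contained sketch of the copy from the decoder's seq_slot-indexed cache indirection into the engine-local, batch-position-indexed buffer (shapes and slot values are illustrative):

import torch

max_slots, beam, seq = 16, 4, 32
global_cache_indirection = torch.randint(0, beam, (max_slots, beam, seq), dtype=torch.int32)  # indexed by seq_slot
local_cache_indirection = torch.zeros(8, beam, seq, dtype=torch.int32)                        # indexed by batch position

gen_request_seq_slots = torch.tensor([5, 2, 9])  # seq_slot of each generation request in the batch
num_gen = len(gen_request_seq_slots)

# Row i of the pre-allocated local buffer now holds request i's indirection, so the
# attention kernels (and a captured CUDA graph) always read from the same address.
local_cache_indirection[:num_gen].copy_(global_cache_indirection[gen_request_seq_slots])
assert torch.equal(local_cache_indirection[0], global_cache_indirection[5])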


@ -375,11 +375,15 @@ class KVCacheManager(BaseResourceManager):
prepare_resource: bool = True,
max_num_draft_tokens: int = 0,
use_mrope: bool = False,
max_beam_width: int = 1,
):
beam_width = 1 # TODO: more than 1 beam?
beam_width = max_beam_width
requests = []
for i, req_id in enumerate(request_ids):
sampling_params = SamplingParams()
# The exact choice of n can be ignored for dummy requests.
sampling_params = SamplingParams(n=beam_width,
best_of=beam_width,
use_beam_search=beam_width > 1)
# Here 1+max_num_draft_tokens is used to extend the prompt length to
# a non-zero number to skip illegal memory access issue in MLA kernel
# during warmup.
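The dummy warmup requests now mirror real beam-search requests, so their KV-cache footprint matches what real beam-search requests will need. A minimal sketch of the equivalent SamplingParams construction (the beam width here is an arbitrary example, not a value from this file):

from tensorrt_llm import SamplingParams

beam_width = 4  # example value; the engine passes max_beam_width here
sampling_params = SamplingParams(n=beam_width,
                                 best_of=beam_width,
                                 use_beam_search=beam_width > 1)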


@ -5,7 +5,7 @@ from utils.llm_data import llm_models_root
from utils.util import force_ampere, similar
from tensorrt_llm import LLM, SamplingParams
from tensorrt_llm.llmapi.llm_utils import KvCacheConfig
from tensorrt_llm.llmapi import CudaGraphConfig, KvCacheConfig
@pytest.fixture(scope="module")
@ -46,13 +46,12 @@ def llm(fixed_params, input_prompts):
enable_trtllm_sampler=True,
max_beam_width=fixed_params["max_beam_width"],
disable_overlap_scheduler=True,
#TODO: remove this once we have a proper fix for CUDA graph in beam search
cuda_graph_config=None,
)
@pytest.fixture(scope="module")
def llm_overlap(fixed_params, input_prompts):
def llm_cuda_graph(fixed_params, input_prompts):
return LLM(
model=os.path.join(llm_models_root(), "llama-models-v2",
"TinyLlama-1.1B-Chat-v1.0"),
@ -64,8 +63,7 @@ def llm_overlap(fixed_params, input_prompts):
enable_trtllm_sampler=True,
max_beam_width=fixed_params["max_beam_width"],
disable_overlap_scheduler=False,
#TODO: remove this once we have a proper fix for CUDA graph in beam search
cuda_graph_config=None,
cuda_graph_config=CudaGraphConfig(enabled=True),
)
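Putting the pieces together: with this change beam search can run with CUDA graphs enabled through the LLM API, as the fixture above does. A hedged usage sketch based on that fixture (the model path, prompt, and token budget are placeholders, not values from this test):

from tensorrt_llm import LLM, SamplingParams
from tensorrt_llm.llmapi import CudaGraphConfig

llm = LLM(model="/path/to/TinyLlama-1.1B-Chat-v1.0",  # placeholder path
          enable_trtllm_sampler=True,
          max_beam_width=2,
          disable_overlap_scheduler=False,
          cuda_graph_config=CudaGraphConfig(enabled=True))

sampling_params = SamplingParams(max_tokens=32,
                                 n=2,
                                 best_of=2,
                                 use_beam_search=True)
outputs = llm.generate(["Born in north-east France, Soyer trained as a"],
                       sampling_params=sampling_params)
print([beam.text for beam in outputs[0].outputs])      # one entry per returned beam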
@ -132,10 +130,10 @@ def test_beam_search_output_shapes(gather_context_logits: bool,
@pytest.mark.parametrize("num_output_beams", [1, 2])
@pytest.mark.parametrize("num_prompts", [1, 2])
@pytest.mark.threadleak(enabled=False)
def test_beam_search_output_shapes_overlap(
def test_beam_search_output_shapes_cuda_graph_and_overlap(
gather_context_logits: bool, gather_generation_logits: bool,
return_log_probs: bool, num_output_beams: int, num_prompts: int,
llm_overlap, fixed_params, input_prompts, expected_outputs):
llm_cuda_graph, fixed_params, input_prompts, expected_outputs):
if return_log_probs and num_prompts > 1:
pytest.skip(
"Beam search currently does not support return_log_probs with multiple prompts"
@ -149,8 +147,8 @@ def test_beam_search_output_shapes_overlap(
return_generation_logits=gather_generation_logits,
logprobs=return_log_probs,
)
outputs = llm_overlap.generate(input_prompts[:num_prompts],
sampling_params=sampling_params)
outputs = llm_cuda_graph.generate(input_prompts[:num_prompts],
sampling_params=sampling_params)
assert len(outputs) == num_prompts
for output_idx, output in enumerate(outputs):
if gather_context_logits: