Mirror of https://github.com/NVIDIA/TensorRT-LLM.git (synced 2026-01-14 06:27:45 +08:00)
[TRTLLM-6650][feat] Enhance beam search support with CUDA graph integration (#6217)
Signed-off-by: Stefan Niebler <82932102+stnie@users.noreply.github.com>
parent ff72ca90de
commit 0df758ec9f
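
For orientation, a minimal usage sketch of what this change enables, based on the LLM API calls exercised by the updated test below (the model path, prompt, and parameter values here are illustrative, not taken from the commit):

from tensorrt_llm import LLM, SamplingParams
from tensorrt_llm.llmapi import CudaGraphConfig

# Beam search (max_beam_width > 1) can now be combined with CUDA graph capture.
llm = LLM(
    model="TinyLlama/TinyLlama-1.1B-Chat-v1.0",  # illustrative; the test uses a local TinyLlama checkpoint
    enable_trtllm_sampler=True,
    max_beam_width=2,
    disable_overlap_scheduler=False,
    cuda_graph_config=CudaGraphConfig(enabled=True),
)

# Beam search is requested per call through SamplingParams.
sampling_params = SamplingParams(n=2, best_of=2, use_beam_search=True)
outputs = llm.generate(["The capital of France is"], sampling_params=sampling_params)
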
@@ -135,6 +135,9 @@ class AttentionMetadata:
     _num_ctx_tokens: int = field(init=False, default=0, repr=False)
     _num_tokens: int = field(init=False, default=0, repr=False)
 
+    # This buffer is currently only used for TrtllmAttentionMetadata.
+    cache_indirection: Optional[torch.Tensor] = None
+
     def __post_init__(self) -> None:
         if self.is_cross:
             assert self.cross is None or self.cross is self, "Cross attention metadata should not have sub metadata"
@@ -517,10 +517,9 @@ class TrtllmAttentionWrapper:
 class TrtllmAttentionMetadata(AttentionMetadata):
     workspace: Optional[torch.Tensor] = None
 
-    # TrtllmAttention needs to know the beam width and access to the cache indirection buffer,
+    # TrtllmAttention needs to know the beam width to access to the cache indirection buffer,
     # when beam search is enabled.
     beam_width: int = 1
-    cache_indirection: Optional[torch.Tensor] = None
 
     # TrtllmAttention needs to know the max sequence length.
     # Implemented as a property to support no cache mode.
@@ -392,9 +392,6 @@ class PyTorchModelEngine(ModelEngine):
         self._cuda_graphs = {}
         self._cuda_graph_mem_pool = self._torch_compile_backend._graph_pool_handle if self._torch_compile_enabled else None
         self._run_cuda_graphs = pytorch_backend_config.use_cuda_graph
-        if self._run_cuda_graphs and self.max_beam_width > 1:
-            raise NotImplementedError(
-                "CUDA Graph + beam search is not implemented yet.")
 
         self._cuda_graph_padding_enabled = pytorch_backend_config.cuda_graph_padding_enabled
 
@@ -425,6 +422,17 @@ class PyTorchModelEngine(ModelEngine):
         self.lora_model_config: Optional[LoraModelConfig] = None
         self.cuda_graph_dummy_request = None
 
+        # Setup the local cache indirection buffer only once and reuse it.
+        # This way it can also be used for CUDA graphs.
+        if self.use_beam_search:
+            self.cache_indirection_attention = torch.zeros(
+                (self.batch_size, self.max_beam_width, self.max_seq_len +
+                 (0 if self._disable_overlap_scheduler else 1)),
+                device="cuda",
+                dtype=torch.int32)
+        else:
+            self.cache_indirection_attention = None
+
     def set_lora_model_config(self, lora_target_modules: list[str],
                               trtllm_modules_to_hf_modules: dict[str, str]):
         self.lora_model_config = LoraModelConfig(
@@ -444,6 +452,10 @@ class PyTorchModelEngine(ModelEngine):
         logger.info(f"Detected use_mrope: {use_mrope}")
         return use_mrope
 
+    @property
+    def use_beam_search(self):
+        return self.max_beam_width > 1
+
     @contextmanager
     def set_warmup_flag(self):
         self.in_warmup = True
@@ -487,7 +499,9 @@ class PyTorchModelEngine(ModelEngine):
         self.cuda_graph_dummy_request = None
 
         def get_cuda_graph_warmup_request(batch_size):
-            available_blocks = kv_cache_manager.get_num_free_blocks()
+            # Divide by max_beam_width to get an approximation of the number of requests that can be run in parallel.
+            available_blocks = kv_cache_manager.get_num_free_blocks(
+            ) // self.max_beam_width
             if available_blocks >= batch_size:
                 result = ScheduledRequests()
                 result.context_requests = []
@@ -498,9 +512,10 @@ class PyTorchModelEngine(ModelEngine):
                     is_gen=True,
                     max_num_draft_tokens=self.max_draft_len,
                     use_mrope=use_mrope,
-                )
+                    max_beam_width=self.max_beam_width)
+                # Divide by max_beam_width to get an approximation of the number of tokens that can be added to the final request.
                 available_tokens = kv_cache_manager.get_num_available_tokens(
-                    self.max_draft_len)
+                    self.max_draft_len) // self.max_beam_width
 
                 # Add one dummy request with the maximum possible sequence length.
                 # The sequence length is limited by both the max_seq_len and the number of available blocks.
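
A quick arithmetic illustration of the two divisions above (hypothetical numbers, not from the commit): with beam search, each request roughly keeps one KV-cache sequence per beam, so the free-block and free-token budgets shrink by the beam width when sizing warmup requests.

# Hypothetical free resources reported by the KV-cache manager during warmup.
free_blocks = 96
free_tokens = 4096
max_beam_width = 4

max_parallel_warmup_requests = free_blocks // max_beam_width   # 24
max_tokens_for_longest_request = free_tokens // max_beam_width  # 1024
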
@@ -511,7 +526,7 @@ class PyTorchModelEngine(ModelEngine):
                     is_gen=True,
                     max_num_draft_tokens=self.max_draft_len,
                     use_mrope=use_mrope,
-                )[0]
+                    max_beam_width=self.max_beam_width)[0]
                 # Add the longest request before all other seq_len=1 request to simulate the padding CUDA graph case.
                 # This batch contains both the longest request and the shortest requests,
                 # it also contains the maximum number of requests and the maximum token number,
@@ -739,6 +754,7 @@ class PyTorchModelEngine(ModelEngine):
             self.model.model_config.pretrained_config) and (
                 self.attn_runtime_features.cache_reuse
                 or self.attn_runtime_features.chunked_prefill)
+        cache_indirection = self.cache_indirection_attention if self.attn_backend.Metadata is TrtllmAttentionMetadata else None
         if kv_cache_manager is None:
             return self.attn_backend.Metadata(
                 max_num_requests=self.batch_size,
@@ -748,7 +764,8 @@ class PyTorchModelEngine(ModelEngine):
                 mapping=self.mapping,
                 runtime_features=self.attn_runtime_features,
                 enable_flash_mla=self.model.model_config.enable_flash_mla,
-                enable_paged_context_mla=enable_paged_context_mla)
+                enable_paged_context_mla=enable_paged_context_mla,
+                cache_indirection=cache_indirection)
 
         if self.attn_metadata is not None:
             # This assertion can be relaxed if needed: just create a new metadata
@@ -764,7 +781,9 @@ class PyTorchModelEngine(ModelEngine):
             mapping=self.mapping,
             runtime_features=self.attn_runtime_features,
             enable_flash_mla=self.model.model_config.enable_flash_mla,
-            enable_paged_context_mla=enable_paged_context_mla)
+            enable_paged_context_mla=enable_paged_context_mla,
+            cache_indirection=cache_indirection)
 
         return self.attn_metadata
 
     def _set_up_spec_metadata(
@@ -795,7 +814,8 @@ class PyTorchModelEngine(ModelEngine):
                           kv_cache_manager) -> int:
         can_run_cuda_graph = scheduled_requests.can_run_cuda_graph
         batch_size = scheduled_requests.batch_size
-        new_batch_size = batch_size
+        # The number of sequences in the batch is the number of prompts times the beam width.
+        new_batch_size = batch_size * self.max_beam_width
         if self._run_cuda_graphs and self.enable_attention_dp and self.mapping.tp_size > 1:
             graph_batch_size = self.dist.tp_allgather(
                 [can_run_cuda_graph, batch_size])
@@ -831,7 +851,8 @@ class PyTorchModelEngine(ModelEngine):
                 [MAX_UINT64 - 1],
                 is_gen=True,
                 max_num_draft_tokens=self.max_draft_len,
-                use_mrope=self.use_mrope)[0]
+                use_mrope=self.use_mrope,
+                max_beam_width=self.max_beam_width)[0]
             self.cuda_graph_dummy_request.is_cuda_graph_dummy = True
 
             scheduled_requests.generation_requests.extend(
@@ -903,19 +924,21 @@ class PyTorchModelEngine(ModelEngine):
         if batch_size not in self._cuda_graph_batch_sizes:
             return None
 
+        num_sequences_in_batch = batch_size * self.max_beam_width
         attn_metadata = self.attn_metadata.create_cuda_graph_metadata(
-            batch_size, False, spec_max_draft_tokens)
+            num_sequences_in_batch, False, spec_max_draft_tokens)
         assert attn_metadata.is_cuda_graph
 
         if self.is_spec_decode:
             spec_metadata = self.spec_metadata.create_cuda_graph_metadata(
-                batch_size)
+                num_sequences_in_batch)
             spec_metadata.draft_tokens = self.draft_tokens_cuda
         else:
             spec_metadata = None
 
         self._cuda_graphs[batch_size] = DecodingCUDAGraphRunner(
-            batch_size, "cuda", attn_metadata, spec_metadata, self.use_mrope)
+            num_sequences_in_batch, "cuda", attn_metadata, spec_metadata,
+            self.use_mrope)
         return self._cuda_graphs[batch_size]
 
     def __del__(self) -> None:
@@ -1439,16 +1462,16 @@ class PyTorchModelEngine(ModelEngine):
 
         num_generation_requests = len(scheduled_requests.generation_requests)
         # Cache indirection is only used for beam search on generation requests
-        if self.max_beam_width > 1 and num_generation_requests > 0 and cache_indirection_buffer is not None:
-            cache_indirection_attention = torch.zeros_like(
-                cache_indirection_buffer)
-            #Copy cache indirection to local buffer with offsets changing: seq_slots[i] -> i
-            cache_indirection_attention[:num_generation_requests].copy_(
-                cache_indirection_buffer[gen_request_seq_slots])
-            attn_metadata.cache_indirection = cache_indirection_attention
-            attn_metadata.beam_width = self.max_beam_width
+        if self.use_beam_search and num_generation_requests > 0:
+            # CUDA Graph needs to set beam width during warmup (where the graph is captured), to ensure that cache indirection buffer is correctly picked up by the CUDA graph
+            is_cuda_graph_during_warmup = self.in_warmup and attn_metadata.is_cuda_graph
+            if cache_indirection_buffer is not None:
+                #Copy cache indirection to local buffer with offsets changing: seq_slots[i] -> i
+                self.cache_indirection_attention[:num_generation_requests].copy_(
+                    cache_indirection_buffer[gen_request_seq_slots])
+            if cache_indirection_buffer is not None or is_cuda_graph_during_warmup:
+                attn_metadata.beam_width = self.max_beam_width
         else:
-            attn_metadata.cache_indirection = None
             attn_metadata.beam_width = 1
 
         attn_metadata.request_ids = request_ids
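
To make the seq_slots[i] -> i copy above concrete, a small standalone sketch with assumed shapes and slot values (the names mirror the diff, the numbers are made up):

import torch

device = "cuda" if torch.cuda.is_available() else "cpu"

# Assumed shapes: the decoder-side buffer is indexed by sequence slot,
# the attention-side buffer by position within the current batch.
max_num_sequences, max_beam_width, max_seq_len = 8, 2, 16
cache_indirection_buffer = torch.zeros(
    (max_num_sequences, max_beam_width, max_seq_len), dtype=torch.int32, device=device)
cache_indirection_attention = torch.zeros_like(cache_indirection_buffer)

# Generation requests may sit in arbitrary sequence slots.
gen_request_seq_slots = torch.tensor([5, 2, 7], device=device)
num_generation_requests = len(gen_request_seq_slots)

# seq_slots[i] -> i: row i of the local buffer now holds the indirection of the i-th request,
# so the same preallocated buffer (and hence the captured CUDA graph) can be reused every step.
cache_indirection_attention[:num_generation_requests].copy_(
    cache_indirection_buffer[gen_request_seq_slots])
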
@@ -375,11 +375,15 @@ class KVCacheManager(BaseResourceManager):
         prepare_resource: bool = True,
         max_num_draft_tokens: int = 0,
         use_mrope: bool = False,
+        max_beam_width: int = 1,
     ):
-        beam_width = 1  # TODO: more than 1 beam?
+        beam_width = max_beam_width
         requests = []
         for i, req_id in enumerate(request_ids):
-            sampling_params = SamplingParams()
+            # exact choice of n can be ignored for dummy requests
+            sampling_params = SamplingParams(n=beam_width,
+                                             best_of=beam_width,
+                                             use_beam_search=beam_width > 1)
             # Here 1+max_num_draft_tokens is used to extend the prompt length to
             # a non-zero number to skip illegal memory access issue in MLA kernel
             # during warmup.
@@ -5,7 +5,7 @@ from utils.llm_data import llm_models_root
 from utils.util import force_ampere, similar
 
 from tensorrt_llm import LLM, SamplingParams
-from tensorrt_llm.llmapi.llm_utils import KvCacheConfig
+from tensorrt_llm.llmapi import CudaGraphConfig, KvCacheConfig
 
 
 @pytest.fixture(scope="module")
@@ -46,13 +46,12 @@ def llm(fixed_params, input_prompts):
         enable_trtllm_sampler=True,
         max_beam_width=fixed_params["max_beam_width"],
         disable_overlap_scheduler=True,
-        #TODO: remove this once we have a proper fix for CUDA graph in beam search
         cuda_graph_config=None,
     )
 
 
 @pytest.fixture(scope="module")
-def llm_overlap(fixed_params, input_prompts):
+def llm_cuda_graph(fixed_params, input_prompts):
     return LLM(
         model=os.path.join(llm_models_root(), "llama-models-v2",
                            "TinyLlama-1.1B-Chat-v1.0"),
@@ -64,8 +63,7 @@ def llm_overlap(fixed_params, input_prompts):
         enable_trtllm_sampler=True,
         max_beam_width=fixed_params["max_beam_width"],
         disable_overlap_scheduler=False,
-        #TODO: remove this once we have a proper fix for CUDA graph in beam search
-        cuda_graph_config=None,
+        cuda_graph_config=CudaGraphConfig(enabled=True),
     )
 
 
@@ -132,10 +130,10 @@ def test_beam_search_output_shapes(gather_context_logits: bool,
 @pytest.mark.parametrize("num_output_beams", [1, 2])
 @pytest.mark.parametrize("num_prompts", [1, 2])
 @pytest.mark.threadleak(enabled=False)
-def test_beam_search_output_shapes_overlap(
+def test_beam_search_output_shapes_cuda_graph_and_overlap(
         gather_context_logits: bool, gather_generation_logits: bool,
         return_log_probs: bool, num_output_beams: int, num_prompts: int,
-        llm_overlap, fixed_params, input_prompts, expected_outputs):
+        llm_cuda_graph, fixed_params, input_prompts, expected_outputs):
     if return_log_probs and num_prompts > 1:
         pytest.skip(
             "Beam search currently does not support return_log_probs with multiple prompts"
@@ -149,8 +147,8 @@ def test_beam_search_output_shapes_overlap(
         return_generation_logits=gather_generation_logits,
         logprobs=return_log_probs,
     )
-    outputs = llm_overlap.generate(input_prompts[:num_prompts],
-                                   sampling_params=sampling_params)
+    outputs = llm_cuda_graph.generate(input_prompts[:num_prompts],
+                                      sampling_params=sampling_params)
     assert len(outputs) == num_prompts
     for output_idx, output in enumerate(outputs):
         if gather_context_logits: