[TRTLLM-8690][feat] add more tensors to share buffers (#8691)

Signed-off-by: Hui Gao <huig@nvidia.com>
Author: HuiGao-NV (committed 2025-11-04 13:08:01 +08:00 via GitHub)
Parent: ed297d7c2e
Commit: 97674c3114
7 changed files with 152 additions and 77 deletions
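For orientation, every call site touched by this commit follows one pattern: a per-metadata torch.empty allocation is replaced by a lookup into the shared, name-keyed cuda_graph_buffers pool, with a plain allocation as the fallback when no pool is attached. A minimal sketch of that pattern (the build_indptr_* helpers and the pool argument are illustrative, not code from this repository):

import torch


def build_indptr_before(max_num_requests: int) -> torch.Tensor:
    # Old style: every metadata object owns its scratch tensor.
    return torch.empty((max_num_requests + 1, ),
                       device='cuda',
                       dtype=torch.int64)


def build_indptr_after(pool, max_num_requests: int,
                       capture_graph: bool) -> torch.Tensor:
    # New style: fetch a reusable, name-keyed buffer from a shared pool,
    # mirroring the AttentionMetadata.get_empty() helper added in this diff.
    if pool is None:
        return torch.zeros((max_num_requests + 1, ),
                           device='cuda',
                           dtype=torch.int64)
    return pool.get_buffer((max_num_requests + 1, ), torch.int64,
                           "ctx_kv_indptr", capture_graph)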

View File

@@ -124,14 +124,22 @@ class FlashInferAttentionMetadata(AttentionMetadata):
def __post_init__(self) -> None:
super().__post_init__()
self._post_init_with_buffers(self.cuda_graph_buffers)
def _post_init_with_buffers(self, buffers) -> None:
capture_graph = torch.cuda.is_current_stream_capturing()
if self.workspace_buffer is None:
# Note: even though flashinfer only recommends 128 MB, we have to push it
# a bit higher to cover all possible CUDA graph cases. If it's too small,
# warmup will crash.
self.workspace_buffer = torch.empty(320 * 1024 * 1024,
dtype=torch.uint8,
device="cuda")
self.workspace_buffer = self.get_empty(
buffers,
(320 * 1024 * 1024, ),
dtype=torch.uint8,
cache_name="workspace_buffer",
capture_graph=capture_graph,
)
self.paged_kv_indptr_decode = torch.empty((self.max_num_requests + 1, ),
device='cuda',
@@ -163,9 +171,13 @@ class FlashInferAttentionMetadata(AttentionMetadata):
if self.kv_cache_manager is not None:
max_num_pages = self.kv_cache_manager.blocks_in_primary_pool
self._paged_kv_indices = torch.empty((max_num_pages, ),
device='cuda',
dtype=torch.int)
self._paged_kv_indices = self.get_empty(
buffers,
(max_num_pages, ),
dtype=torch.int,
cache_name="_paged_kv_indices",
capture_graph=capture_graph,
)
def create_cuda_graph_metadata(self,
max_batch_size: int,

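Because the workspace is requested as a flat uint8 tensor, its element count equals its size in bytes; a quick arithmetic check on the figure used above (plain numbers, nothing repo-specific):

import torch

workspace_elems = 320 * 1024 * 1024                      # value used above
workspace_bytes = workspace_elems * torch.uint8.itemsize
assert workspace_bytes == 320 * 2**20                    # 320 MiB, well above
                                                         # FlashInfer's ~128 MB
                                                         # recommendation

Keying the tensor under "workspace_buffer" means metadata objects built for different CUDA graph batch sizes can, presumably, draw on this one block through the shared pool instead of each holding its own 320 MiB allocation.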
View File

@@ -17,6 +17,7 @@ from tensorrt_llm.functional import (PositionEmbeddingType, RopeEmbeddingUtils,
from tensorrt_llm.mapping import Mapping
from tensorrt_llm.models.modeling_utils import QuantConfig
from ..memory_buffer_utils import Buffers
from ..metadata import KVCacheParams
from ..pyexecutor.resource_manager import KVCacheManager
from ..utils import get_model_extra_attrs
@@ -349,6 +350,49 @@ class AttentionMetadata:
Hook to be called during forward when using spec-dec one-model mode.
"""
@staticmethod
def get_empty(buffers: Buffers,
tensor_shape: list[int],
dtype: torch.dtype,
cache_name: str,
capture_graph: bool = False) -> torch.Tensor:
"""
Finds a compatible, reusable buffer from a cache or creates a new one.
This function searches for a pre-allocated tensor (buffer) that can be
reused for an operation involving a tensor with the shape of `tensor_shape`.
The compatibility rules are: The buffer's total elements must be >= tensor_shape's.
If a compatible buffer is found, it's returned immediately. Otherwise, a new
buffer is allocated on the 'cuda' device with the give properties of 'tensor_shape' and 'dtype'.
Args:
tensor_shape: The required shape.
dtype: The required dtype.
cache_name: The key for the specific list of buffers to search in.
Returns:
An existing compatible buffer or a newly created one.
"""
if buffers is None:
return torch.zeros(tensor_shape, device='cuda', dtype=dtype)
return buffers.get_buffer(tensor_shape, dtype, cache_name,
capture_graph)
@staticmethod
def get_empty_like(buffers: Buffers,
like_tensor: torch.Tensor,
cache_name: str,
capture_graph: bool = False) -> torch.Tensor:
return AttentionMetadata.get_empty(
buffers,
like_tensor.shape,
dtype=like_tensor.dtype,
cache_name=cache_name,
capture_graph=capture_graph,
)
class PositionalEmbedder(Protocol):
"""

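As a usage sketch, a backend metadata subclass can now build its device tensors through these shared static helpers instead of defining a local get_empty. The class name, the field names, and the import path below are assumptions for illustration; only the two static methods come from this diff:

import torch

from tensorrt_llm._torch.attention_backend.interface import AttentionMetadata


class ToyMetadata(AttentionMetadata):

    def __post_init__(self) -> None:
        super().__post_init__()
        capture_graph = torch.cuda.is_current_stream_capturing()
        # Device buffer drawn from the shared pool (or freshly zeroed when
        # self.cuda_graph_buffers is None).
        self.seq_lens_cuda = self.get_empty(
            self.cuda_graph_buffers,
            (self.max_num_requests, ),
            dtype=torch.int32,
            cache_name="seq_lens_cuda",
            capture_graph=capture_graph,
        )
        # Same shape and dtype under a different key, so the two tensors
        # do not alias each other.
        self.kv_lens_cuda = self.get_empty_like(
            self.cuda_graph_buffers,
            self.seq_lens_cuda,
            cache_name="kv_lens_cuda",
            capture_graph=capture_graph,
        )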
View File

@@ -304,17 +304,12 @@ class DSAtrtllmAttentionMetadata(TrtllmAttentionMetadata):
capture_graph = torch.cuda.is_current_stream_capturing()
def get_empty(tensor_shape: list[int], dtype: torch.dtype,
cache_name: str) -> torch.Tensor:
if self.cuda_graph_buffers is None:
return torch.zeros(tensor_shape, device='cuda', dtype=dtype)
return self.cuda_graph_buffers.get_buffer(tensor_shape, dtype,
cache_name, capture_graph)
self.indexer_k_cache_block_offsets = get_empty(
self.indexer_k_cache_block_offsets = self.get_empty(
self.cuda_graph_buffers,
[self.max_num_sequences, self.kv_cache_manager.max_blocks_per_seq],
cache_name="indexer_k_cache_block_offsets",
dtype=torch.int32,
capture_graph=capture_graph,
)
self.host_indexer_k_cache_block_offsets = torch.zeros_like(
self.indexer_k_cache_block_offsets,
@@ -324,20 +319,24 @@ class DSAtrtllmAttentionMetadata(TrtllmAttentionMetadata):
# For mla_rope_append_paged_kv_assign_q
if not self.enable_context_mla_with_cached_kv:
self.ctx_cached_token_indptr = get_empty(
self.ctx_cached_token_indptr = self.get_empty(
self.cuda_graph_buffers,
(self.max_num_requests + 1, ),
cache_name="ctx_cached_token_indptr",
dtype=torch.int64,
capture_graph=capture_graph,
)
self.host_ctx_cached_token_indptr = torch.zeros_like(
self.ctx_cached_token_indptr,
device='cpu',
pin_memory=True,
)
self.ctx_kv_indptr = get_empty(
self.ctx_kv_indptr = self.get_empty(
self.cuda_graph_buffers,
(self.max_num_requests + 1, ),
cache_name="ctx_kv_indptr",
dtype=torch.int64,
capture_graph=capture_graph,
)
self.host_ctx_kv_indptr = torch.zeros_like(
self.ctx_kv_indptr,
@@ -345,20 +344,24 @@ class DSAtrtllmAttentionMetadata(TrtllmAttentionMetadata):
pin_memory=True,
)
# New generation buffers for dsa
self.gen_cached_token_indptr = get_empty(
self.gen_cached_token_indptr = self.get_empty(
self.cuda_graph_buffers,
(self.max_num_requests + 1, ),
cache_name="gen_cached_token_indptr",
dtype=torch.int64,
capture_graph=capture_graph,
)
self.host_gen_cached_token_indptr = torch.zeros_like(
self.gen_cached_token_indptr,
device='cpu',
pin_memory=True,
)
self.gen_kv_indptr = get_empty(
self.gen_kv_indptr = self.get_empty(
self.cuda_graph_buffers,
(self.max_num_requests + 1, ),
cache_name="gen_kv_indptr",
dtype=torch.int64,
capture_graph=capture_graph,
)
self.host_gen_kv_indptr = torch.zeros_like(
self.gen_kv_indptr,
@@ -367,20 +370,24 @@ class DSAtrtllmAttentionMetadata(TrtllmAttentionMetadata):
)
# Indexer metadata
# Separate slot mappings for non-interleaved layout (flat byte indices)
self.slot_mapping_fp8 = get_empty(
self.slot_mapping_fp8 = self.get_empty(
self.cuda_graph_buffers,
(self.max_num_tokens, ),
cache_name="slot_mapping_fp8",
dtype=torch.int64,
capture_graph=capture_graph,
)
self.host_slot_mapping_fp8 = torch.zeros_like(
self.slot_mapping_fp8,
device='cpu',
pin_memory=True,
)
self.slot_mapping_scale = get_empty(
self.slot_mapping_scale = self.get_empty(
self.cuda_graph_buffers,
(self.max_num_tokens, ),
cache_name="slot_mapping_scale",
dtype=torch.int64,
capture_graph=capture_graph,
)
self.host_slot_mapping_scale = torch.zeros_like(
self.slot_mapping_scale,
@@ -388,31 +395,41 @@ class DSAtrtllmAttentionMetadata(TrtllmAttentionMetadata):
pin_memory=True,
)
# Per-token request index buffer for topk_indices conversion
self.req_idx_per_token = get_empty(
self.req_idx_per_token = self.get_empty(
self.cuda_graph_buffers,
(self.max_num_tokens, ),
cache_name="req_idx_per_token",
dtype=torch.int32,
capture_graph=capture_graph,
)
# Block table for topk_indices conversion (shared for context and generation)
self.block_table = get_empty(
self.block_table = self.get_empty(
self.cuda_graph_buffers,
(self.max_num_requests, self.kv_cache_manager.max_blocks_per_seq),
cache_name="block_table",
dtype=torch.int32,
capture_graph=capture_graph,
)
self.scheduler_metadata_buffer = get_empty(
self.scheduler_metadata_buffer = self.get_empty(
self.cuda_graph_buffers,
(self.num_sms + 1, 2),
cache_name="scheduler_metadata_buffer",
dtype=torch.int32,
capture_graph=capture_graph,
)
self.cu_seqlen_ks = get_empty(
self.cu_seqlen_ks = self.get_empty(
self.cuda_graph_buffers,
(self.max_num_tokens, ),
cache_name="cu_seqlen_ks",
dtype=torch.int32,
capture_graph=capture_graph,
)
self.cu_seqlen_ke = get_empty(
self.cu_seqlen_ke = self.get_empty(
self.cuda_graph_buffers,
(self.max_num_tokens, ),
cache_name="cu_seqlen_ke",
dtype=torch.int32,
capture_graph=capture_graph,
)
def prepare(self):

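A recurring pairing in this file is a device buffer obtained from the shared pool plus a pinned-CPU mirror created with torch.zeros_like(..., device='cpu', pin_memory=True); the host side is filled on the CPU and then copied over without blocking. A self-contained sketch of that pattern (plain torch, illustrative sizes, no repository objects):

import torch

max_num_requests = 8  # illustrative size
# Device-side indptr; in the real code this comes from self.get_empty().
gen_kv_indptr = torch.zeros((max_num_requests + 1, ),
                            device='cuda',
                            dtype=torch.int64)
# Pinned host mirror: page-locked memory makes the H2D copy below eligible
# for a truly asynchronous transfer.
host_gen_kv_indptr = torch.zeros_like(gen_kv_indptr,
                                      device='cpu',
                                      pin_memory=True)

# Typical prepare() flow: fill the host tensor, then copy without blocking.
indptr_values = torch.tensor([0, 3, 7, 7, 12, 12, 12, 12, 12],
                             dtype=torch.int64)
host_gen_kv_indptr.copy_(indptr_values)
gen_kv_indptr.copy_(host_gen_kv_indptr, non_blocking=True)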
View File

@@ -35,14 +35,19 @@ class RocketTrtllmAttentionMetadata(TrtllmAttentionMetadata):
if self.sparse_attention_config is None:
raise ValueError("Sparse attention config is not set")
self.prompt_budget = self.sparse_attention_config.prompt_budget
self.kt_cache_block_offsets = torch.empty(
capture_graph = torch.cuda.is_current_stream_capturing()
self.kt_cache_block_offsets = self.get_empty(
self.cuda_graph_buffers,
[
self.max_num_sequences,
self.kv_cache_manager.max_kt_blocks_per_seq
],
dtype=torch.int32,
device='cuda',
cache_name="kt_cache_block_offsets",
capture_graph=capture_graph,
)
self.host_kt_cache_block_offsets = torch.zeros_like(
self.kt_cache_block_offsets,
device='cpu',

View File

@@ -649,50 +649,24 @@ class TrtllmAttentionMetadata(AttentionMetadata):
capture_graph = torch.cuda.is_current_stream_capturing()
def get_empty(tensor_shape: list[int], dtype: torch.dtype,
cache_name: str) -> torch.Tensor:
"""
Finds a compatible, reusable buffer from a cache or creates a new one.
This function searches for a pre-allocated tensor (buffer) that can be
reused for an operation involving a tensor with the shape of `tensor_shape`.
The compatibility rules are: The buffer's total elements must be >= tensor_shape's.
If a compatible buffer is found, it's returned immediately. Otherwise, a new
buffer is allocated on the 'cuda' device with the give properties of 'tensor_shape' and 'dtype'.
Args:
tensor_shape: The required shape.
dtype: The required dtype.
cache_name: The key for the specific list of buffers to search in.
Returns:
An existing compatible buffer or a newly created one.
"""
if buffers is None:
return torch.zeros(tensor_shape, device='cuda', dtype=dtype)
return buffers.get_buffer(tensor_shape, dtype, cache_name,
capture_graph)
def get_empty_like(like_tensor: torch.Tensor,
cache_name: str) -> torch.Tensor:
return get_empty(like_tensor.shape,
cache_name=cache_name,
dtype=like_tensor.dtype)
self.prompt_lens_cuda = get_empty(
self.prompt_lens_cuda = self.get_empty(
buffers,
(self.max_num_sequences, ),
cache_name="prompt_lens_cuda",
dtype=torch.int,
capture_graph=capture_graph,
)
self.prompt_lens_cpu = torch.empty_like(
self.prompt_lens_cuda,
device='cpu',
pin_memory=True,
)
self.kv_lens_cuda = get_empty_like(self.prompt_lens_cuda,
cache_name="kv_lens_cuda")
self.kv_lens_cuda = self.get_empty_like(
buffers,
self.prompt_lens_cuda,
cache_name="kv_lens_cuda",
capture_graph=capture_graph,
)
self.kv_lens = torch.empty_like(self.kv_lens_cuda,
device='cpu',
pin_memory=True)
@@ -707,13 +681,15 @@ class TrtllmAttentionMetadata(AttentionMetadata):
dtype=torch.int8,
)
if self.kv_cache_manager is not None:
self.kv_cache_block_offsets = get_empty(
self.kv_cache_block_offsets = self.get_empty(
buffers,
[
self.kv_cache_manager.num_pools, self.max_num_sequences, 2,
self.kv_cache_manager.max_blocks_per_seq
],
cache_name="kv_cache_block_offsets",
dtype=torch.int32,
capture_graph=capture_graph,
)
self.host_kv_cache_block_offsets = torch.empty_like(
self.kv_cache_block_offsets,
@@ -723,38 +699,46 @@ class TrtllmAttentionMetadata(AttentionMetadata):
self.block_ids_per_seq = None
self.kv_block_ids_per_seq = None
if self.enable_flash_mla:
self.block_ids_per_seq = get_empty(
self.block_ids_per_seq = self.get_empty(
buffers,
[
self.kv_cache_manager.max_batch_size,
self.kv_cache_manager.max_blocks_per_seq
],
cache_name="block_ids_per_seq",
dtype=torch.int32,
capture_graph=capture_graph,
)
self.kv_block_ids_per_seq = get_empty(
self.kv_block_ids_per_seq = self.get_empty(
buffers,
[
self.kv_cache_manager.max_batch_size,
self.kv_cache_manager.max_blocks_per_seq
],
cache_name="kv_block_ids_per_seq",
dtype=torch.int32,
capture_graph=capture_graph,
)
if self.enable_context_mla_with_cached_kv:
# for kv cache reuse/chunked context in MLA
self.ctx_cached_token_indptr = get_empty(
self.ctx_cached_token_indptr = self.get_empty(
buffers,
(self.max_num_requests + 1, ),
cache_name="ctx_cached_token_indptr",
dtype=torch.int64,
capture_graph=capture_graph,
)
self.host_ctx_cached_token_indptr = torch.zeros_like(
self.ctx_cached_token_indptr,
device='cpu',
pin_memory=True,
)
self.ctx_uncached_token_indptr = get_empty(
self.ctx_uncached_token_indptr = self.get_empty(
buffers,
(self.max_num_requests + 1, ),
cache_name="ctx_uncached_token_indptr",
dtype=torch.int64,
capture_graph=capture_graph,
)
self.host_ctx_uncached_token_indptr = torch.zeros_like(
self.ctx_uncached_token_indptr,
@@ -762,10 +746,12 @@ class TrtllmAttentionMetadata(AttentionMetadata):
pin_memory=True,
)
# context full seqlens include cached tokens and uncached tokens
self.ctx_kv_indptr = get_empty(
self.ctx_kv_indptr = self.get_empty(
buffers,
(self.max_num_requests + 1, ),
cache_name="ctx_kv_indptr",
dtype=torch.int64,
capture_graph=capture_graph,
)
self.host_ctx_kv_indptr = torch.zeros_like(
self.ctx_kv_indptr,

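The reuse rule documented in get_empty (an existing block is compatible when its byte size covers the request) is what lets one named entry serve captures of different sizes. A worked example with illustrative dimensions, not values taken from the repository:

import math

import torch

# Hypothetical kv_cache_block_offsets requests:
# [num_pools, max_num_sequences, 2, max_blocks_per_seq]
shape_large = [1, 64, 2, 128]
shape_small = [1, 16, 2, 128]
int32_bytes = torch.int32.itemsize  # 4

bytes_large = math.prod(shape_large) * int32_bytes  # 16384 * 4 = 65536
bytes_small = math.prod(shape_small) * int32_bytes  # 4096 * 4 = 16384

# A block cached under "kv_cache_block_offsets" for the large request can be
# viewed as the smaller shape later, because 65536 >= 16384; the reverse
# order would force a new, larger allocation.
assert bytes_large >= bytes_small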
View File

@@ -8,6 +8,10 @@ import torch
from tensorrt_llm.logger import logger
def get_size_in_byte(target_shape: list[int], target_dtype: torch.dtype):
return math.prod(target_shape) * target_dtype.itemsize
@dataclass
class BufferBlock:
"""A container for a buffer tensor and its state."""
@@ -36,13 +40,13 @@ class Buffers:
target_dtype: torch.dtype) -> torch.Tensor:
"""Safely creates a view of a raw byte buffer with the desired shape and dtype."""
# The buffer is stored as uint8, so its numel is its size in bytes.
required_size_in_bytes = math.prod(target_shape) * target_dtype.itemsize
if buffer.numel() < required_size_in_bytes:
required_memory_size = get_size_in_byte(target_shape, target_dtype)
if buffer.numel() < required_memory_size:
raise ValueError(
"Buffer is too small for the requested shape and dtype.")
# Slice the buffer to the exact required size, then view it with the correct type and shape.
return buffer[:required_size_in_bytes].view(target_dtype).view(
return buffer[:required_memory_size].view(target_dtype).view(
target_shape)
def get_buffer(self, tensor_shape: list[int], dtype: torch.dtype,
@@ -50,6 +54,7 @@ class Buffers:
# all buffers are allocated with 1 byte element size
required_memory_size = math.prod(tensor_shape) * dtype.itemsize
candidate_blocks = self.buffers.get(buffer_name, [])
# Find the best-fit available buffer.

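The helpers in this file boil down to plain tensor reinterpretation: the pool stores flat uint8 blocks and hands out typed, shaped views over their first get_size_in_byte(...) bytes. A standalone sketch of that mechanism (pure torch, mirroring the view-creation helper shown above rather than calling it):

import math

import torch

target_shape = (5, 3)
target_dtype = torch.int32
required = math.prod(target_shape) * target_dtype.itemsize  # 60 bytes

# A raw byte block, as the pool would allocate it (oversized on purpose).
block = torch.zeros(256, dtype=torch.uint8, device='cuda')

# Slice to the exact byte count, reinterpret the dtype, then reshape.
view = block[:required].view(target_dtype).view(target_shape)
assert view.shape == target_shape and view.dtype == target_dtype

# Writes through the view land in the underlying block's storage.
view.fill_(7)
assert block[:4].tolist() == [7, 0, 0, 0]  # little-endian int32 value 7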
View File

@@ -16,6 +16,7 @@ from typing import TYPE_CHECKING, Dict, List, Optional
import torch
from ....memory_buffer_utils import get_memory_buffers
from .moe_op import MoEOp
if TYPE_CHECKING:
@@ -24,6 +25,7 @@ if TYPE_CHECKING:
class DeepGemmMoEOp(MoEOp):
"""DeepGemm-based MoE op for GB200 block FP8."""
buffers = get_memory_buffers()
def __init__(self):
"""Initialize DeepGemm op."""
@@ -85,23 +87,27 @@ class DeepGemmMoEOp(MoEOp):
workspace = {}
# Workspace for FP8 activations
workspace["workspace_0"] = torch.empty(
capture_graph = torch.cuda.is_current_stream_capturing()
workspace["workspace_0"] = DeepGemmMoEOp.buffers.get_buffer(
(expert_size_per_partition * m_max * fp8_dim),
dtype=torch.float8_e4m3fn,
device='cuda')
buffer_name='workspace_0',
reserve_buffer=capture_graph)
# Workspace for intermediate results
workspace["workspace_1"] = torch.empty(
workspace["workspace_1"] = DeepGemmMoEOp.buffers.get_buffer(
(expert_size_per_partition * m_max *
max(intermediate_size * 2, hidden_size)),
dtype=torch.bfloat16,
device='cuda')
buffer_name='workspace_1',
reserve_buffer=capture_graph)
# Workspace for scaling factors
workspace["workspace_sf"] = torch.empty(
workspace["workspace_sf"] = DeepGemmMoEOp.buffers.get_buffer(
expert_size_per_partition * (scale_k_padded // 4) * m_padded,
dtype=torch.int32,
device='cuda')
buffer_name='workspace_sf',
reserve_buffer=capture_graph)
return workspace
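To see why routing these workspaces through the shared pool matters, here is the byte math for one configuration; every dimension below is a made-up example value, not a default from the repository:

import torch

# Illustrative MoE dimensions.
expert_size_per_partition = 8
m_max = 1024
hidden_size = 4096
intermediate_size = 14336
fp8_dim = hidden_size  # assumption for this sketch

ws0 = (expert_size_per_partition * m_max * fp8_dim *
       torch.float8_e4m3fn.itemsize)
ws1 = (expert_size_per_partition * m_max *
       max(intermediate_size * 2, hidden_size)) * torch.bfloat16.itemsize
print(f"workspace_0: {ws0 / 2**20:.0f} MiB, "
      f"workspace_1: {ws1 / 2**20:.0f} MiB")
# Roughly 32 MiB + 448 MiB in this example; drawing the tensors from the
# shared, name-keyed pool means repeated calls and, presumably, multiple
# CUDA graph captures can reuse the same storage instead of each holding
# its own copy.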