mirror of https://github.com/NVIDIA/TensorRT-LLM.git (synced 2026-01-14 06:27:45 +08:00)

[TRTLLM-8690][feat] add more tensors to share buffers (#8691)

Signed-off-by: Hui Gao <huig@nvidia.com>

This commit is contained in:
parent ed297d7c2e
commit 97674c3114
@@ -124,14 +124,22 @@ class FlashInferAttentionMetadata(AttentionMetadata):

     def __post_init__(self) -> None:
         super().__post_init__()
+        self._post_init_with_buffers(self.cuda_graph_buffers)
+
+    def _post_init_with_buffers(self, buffers) -> None:
         capture_graph = torch.cuda.is_current_stream_capturing()
+
         if self.workspace_buffer is None:
             # Note: even though flashinfer only recommends 128 MB, we have to push it
             # a bit higher to cover all possible CUDA graph cases. If it's too small,
             # warmup will crash.
-            self.workspace_buffer = torch.empty(320 * 1024 * 1024,
-                                                dtype=torch.uint8,
-                                                device="cuda")
+            self.workspace_buffer = self.get_empty(
+                buffers,
+                (320 * 1024 * 1024, ),
+                dtype=torch.uint8,
+                cache_name="workspace_buffer",
+                capture_graph=capture_graph,
+            )

         self.paged_kv_indptr_decode = torch.empty((self.max_num_requests + 1, ),
                                                   device='cuda',
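For context on the new capture_graph argument: torch.cuda.is_current_stream_capturing() is True only while the current stream is being recorded into a CUDA graph, and the shared pool uses it to keep the returned block reserved so graph replay always addresses the same memory. A minimal standalone illustration of that flag, not part of the patch:

import torch

x = torch.zeros(8, device="cuda")
assert not torch.cuda.is_current_stream_capturing()

# Warm up on a side stream before capture, as recommended for CUDA graphs.
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
    x += 1
torch.cuda.current_stream().wait_stream(s)

g = torch.cuda.CUDAGraph()
with torch.cuda.graph(g):
    # Inside capture the flag flips to True; buffers handed out here must stay
    # alive (reserved) for as long as the graph may be replayed.
    assert torch.cuda.is_current_stream_capturing()
    x += 1

g.replay()  # replays the captured kernels against the same buffers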
@@ -163,9 +171,13 @@ class FlashInferAttentionMetadata(AttentionMetadata):

         if self.kv_cache_manager is not None:
             max_num_pages = self.kv_cache_manager.blocks_in_primary_pool
-            self._paged_kv_indices = torch.empty((max_num_pages, ),
-                                                 device='cuda',
-                                                 dtype=torch.int)
+            self._paged_kv_indices = self.get_empty(
+                buffers,
+                (max_num_pages, ),
+                dtype=torch.int,
+                cache_name="_paged_kv_indices",
+                capture_graph=capture_graph,
+            )

     def create_cuda_graph_metadata(self,
                                    max_batch_size: int,
@@ -17,6 +17,7 @@ from tensorrt_llm.functional import (PositionEmbeddingType, RopeEmbeddingUtils,
 from tensorrt_llm.mapping import Mapping
 from tensorrt_llm.models.modeling_utils import QuantConfig

+from ..memory_buffer_utils import Buffers
 from ..metadata import KVCacheParams
 from ..pyexecutor.resource_manager import KVCacheManager
 from ..utils import get_model_extra_attrs
@@ -349,6 +350,49 @@ class AttentionMetadata:
         Hook to be called during forward when using spec-dec one-model mode.
         """

+    @staticmethod
+    def get_empty(buffers: Buffers,
+                  tensor_shape: list[int],
+                  dtype: torch.dtype,
+                  cache_name: str,
+                  capture_graph: bool = False) -> torch.Tensor:
+        """
+        Finds a compatible, reusable buffer from a cache or creates a new one.
+
+        This function searches for a pre-allocated tensor (buffer) that can be
+        reused for an operation involving a tensor of shape `tensor_shape`.
+
+        The compatibility rule is: the buffer's total number of elements must be
+        greater than or equal to the number of elements in `tensor_shape`.
+
+        If a compatible buffer is found, it is returned immediately. Otherwise, a new
+        buffer is allocated on the 'cuda' device with the given `tensor_shape` and `dtype`.
+
+        Args:
+            tensor_shape: The required shape.
+            dtype: The required dtype.
+            cache_name: The key for the specific list of buffers to search in.
+        Returns:
+            An existing compatible buffer or a newly created one.
+        """
+        if buffers is None:
+            return torch.zeros(tensor_shape, device='cuda', dtype=dtype)
+
+        return buffers.get_buffer(tensor_shape, dtype, cache_name,
+                                  capture_graph)
+
+    @staticmethod
+    def get_empty_like(buffers,
+                       like_tensor: torch.Tensor,
+                       cache_name: str,
+                       capture_graph: bool = False) -> torch.Tensor:
+        return AttentionMetadata.get_empty(
+            buffers,
+            like_tensor.shape,
+            dtype=like_tensor.dtype,
+            cache_name=cache_name,
+            capture_graph=capture_graph,
+        )
+

 class PositionalEmbedder(Protocol):
     """
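The helper's contract can be reproduced outside the class: keep a list of previously allocated tensors per cache_name, hand back the first one whose element count covers the request (viewed to the requested shape), and fall back to a plain allocation when no pool is supplied. A simplified sketch of that rule follows; it is not the real Buffers implementation (which stores raw uint8 blocks and honors capture_graph), and it defaults to CPU so it runs without a GPU:

import math
import torch

_cache: dict[str, list[torch.Tensor]] = {}

def get_empty_sketch(tensor_shape, dtype, cache_name, device="cpu"):
    # Simplified stand-in for the rule documented above: reuse any cached tensor
    # whose total element count covers the request, otherwise allocate and remember it.
    required = math.prod(tensor_shape)
    for buf in _cache.setdefault(cache_name, []):
        if buf.dtype == dtype and buf.numel() >= required:
            return buf.flatten()[:required].view(tensor_shape)
    fresh = torch.zeros(tensor_shape, device=device, dtype=dtype)
    _cache[cache_name].append(fresh)
    return fresh

a = get_empty_sketch((4, 8), torch.int32, "indptr")   # first call allocates 32 elements
b = get_empty_sketch((16,), torch.int32, "indptr")    # second call reuses the same storage
assert a.data_ptr() == b.data_ptr()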
@@ -304,17 +304,12 @@ class DSAtrtllmAttentionMetadata(TrtllmAttentionMetadata):

         capture_graph = torch.cuda.is_current_stream_capturing()

-        def get_empty(tensor_shape: list[int], dtype: torch.dtype,
-                      cache_name: str) -> torch.Tensor:
-            if self.cuda_graph_buffers is None:
-                return torch.zeros(tensor_shape, device='cuda', dtype=dtype)
-            return self.cuda_graph_buffers.get_buffer(tensor_shape, dtype,
-                                                      cache_name, capture_graph)
-
-        self.indexer_k_cache_block_offsets = get_empty(
+        self.indexer_k_cache_block_offsets = self.get_empty(
+            self.cuda_graph_buffers,
             [self.max_num_sequences, self.kv_cache_manager.max_blocks_per_seq],
             cache_name="indexer_k_cache_block_offsets",
             dtype=torch.int32,
+            capture_graph=capture_graph,
         )
         self.host_indexer_k_cache_block_offsets = torch.zeros_like(
             self.indexer_k_cache_block_offsets,
@@ -324,20 +319,24 @@ class DSAtrtllmAttentionMetadata(TrtllmAttentionMetadata):

         # For mla_rope_append_paged_kv_assign_q
         if not self.enable_context_mla_with_cached_kv:
-            self.ctx_cached_token_indptr = get_empty(
+            self.ctx_cached_token_indptr = self.get_empty(
+                self.cuda_graph_buffers,
                 (self.max_num_requests + 1, ),
                 cache_name="ctx_cached_token_indptr",
                 dtype=torch.int64,
+                capture_graph=capture_graph,
             )
             self.host_ctx_cached_token_indptr = torch.zeros_like(
                 self.ctx_cached_token_indptr,
                 device='cpu',
                 pin_memory=True,
             )
-            self.ctx_kv_indptr = get_empty(
+            self.ctx_kv_indptr = self.get_empty(
+                self.cuda_graph_buffers,
                 (self.max_num_requests + 1, ),
                 cache_name="ctx_kv_indptr",
                 dtype=torch.int64,
+                capture_graph=capture_graph,
             )
             self.host_ctx_kv_indptr = torch.zeros_like(
                 self.ctx_kv_indptr,
@@ -345,20 +344,24 @@ class DSAtrtllmAttentionMetadata(TrtllmAttentionMetadata):
                 pin_memory=True,
             )
         # New generation buffers for dsa
-        self.gen_cached_token_indptr = get_empty(
+        self.gen_cached_token_indptr = self.get_empty(
+            self.cuda_graph_buffers,
             (self.max_num_requests + 1, ),
             cache_name="gen_cached_token_indptr",
             dtype=torch.int64,
+            capture_graph=capture_graph,
         )
         self.host_gen_cached_token_indptr = torch.zeros_like(
             self.gen_cached_token_indptr,
             device='cpu',
             pin_memory=True,
         )
-        self.gen_kv_indptr = get_empty(
+        self.gen_kv_indptr = self.get_empty(
+            self.cuda_graph_buffers,
             (self.max_num_requests + 1, ),
             cache_name="gen_kv_indptr",
             dtype=torch.int64,
+            capture_graph=capture_graph,
         )
         self.host_gen_kv_indptr = torch.zeros_like(
             self.gen_kv_indptr,
@@ -367,20 +370,24 @@ class DSAtrtllmAttentionMetadata(TrtllmAttentionMetadata):
         )
         # Indexer metadata
         # Separate slot mappings for non-interleaved layout (flat byte indices)
-        self.slot_mapping_fp8 = get_empty(
+        self.slot_mapping_fp8 = self.get_empty(
+            self.cuda_graph_buffers,
             (self.max_num_tokens, ),
             cache_name="slot_mapping_fp8",
             dtype=torch.int64,
+            capture_graph=capture_graph,
         )
         self.host_slot_mapping_fp8 = torch.zeros_like(
             self.slot_mapping_fp8,
             device='cpu',
             pin_memory=True,
         )
-        self.slot_mapping_scale = get_empty(
+        self.slot_mapping_scale = self.get_empty(
+            self.cuda_graph_buffers,
             (self.max_num_tokens, ),
             cache_name="slot_mapping_scale",
             dtype=torch.int64,
+            capture_graph=capture_graph,
         )
         self.host_slot_mapping_scale = torch.zeros_like(
             self.slot_mapping_scale,
@@ -388,31 +395,41 @@ class DSAtrtllmAttentionMetadata(TrtllmAttentionMetadata):
             pin_memory=True,
         )
         # Per-token request index buffer for topk_indices conversion
-        self.req_idx_per_token = get_empty(
+        self.req_idx_per_token = self.get_empty(
+            self.cuda_graph_buffers,
             (self.max_num_tokens, ),
             cache_name="req_idx_per_token",
             dtype=torch.int32,
+            capture_graph=capture_graph,
         )
         # Block table for topk_indices conversion (shared for context and generation)
-        self.block_table = get_empty(
+        self.block_table = self.get_empty(
+            self.cuda_graph_buffers,
             (self.max_num_requests, self.kv_cache_manager.max_blocks_per_seq),
             cache_name="block_table",
             dtype=torch.int32,
+            capture_graph=capture_graph,
         )
-        self.scheduler_metadata_buffer = get_empty(
+        self.scheduler_metadata_buffer = self.get_empty(
+            self.cuda_graph_buffers,
             (self.num_sms + 1, 2),
             cache_name="scheduler_metadata_buffer",
             dtype=torch.int32,
+            capture_graph=capture_graph,
         )
-        self.cu_seqlen_ks = get_empty(
+        self.cu_seqlen_ks = self.get_empty(
+            self.cuda_graph_buffers,
             (self.max_num_tokens, ),
             cache_name="cu_seqlen_ks",
             dtype=torch.int32,
+            capture_graph=capture_graph,
        )
-        self.cu_seqlen_ke = get_empty(
+        self.cu_seqlen_ke = self.get_empty(
+            self.cuda_graph_buffers,
             (self.max_num_tokens, ),
             cache_name="cu_seqlen_ke",
             dtype=torch.int32,
+            capture_graph=capture_graph,
         )

     def prepare(self):
@@ -35,14 +35,19 @@ class RocketTrtllmAttentionMetadata(TrtllmAttentionMetadata):
         if self.sparse_attention_config is None:
             raise ValueError("Sparse attention config is not set")
         self.prompt_budget = self.sparse_attention_config.prompt_budget
-        self.kt_cache_block_offsets = torch.empty(
+
+        capture_graph = torch.cuda.is_current_stream_capturing()
+        self.kt_cache_block_offsets = self.get_empty(
+            self.cuda_graph_buffers,
             [
                 self.max_num_sequences,
                 self.kv_cache_manager.max_kt_blocks_per_seq
             ],
             dtype=torch.int32,
-            device='cuda',
+            cache_name="kt_cache_block_offsets",
+            capture_graph=capture_graph,
         )
+
         self.host_kt_cache_block_offsets = torch.zeros_like(
             self.kt_cache_block_offsets,
             device='cpu',
@@ -649,50 +649,24 @@ class TrtllmAttentionMetadata(AttentionMetadata):

         capture_graph = torch.cuda.is_current_stream_capturing()

-        def get_empty(tensor_shape: list[int], dtype: torch.dtype,
-                      cache_name: str) -> torch.Tensor:
-            """
-            Finds a compatible, reusable buffer from a cache or creates a new one.
-
-            This function searches for a pre-allocated tensor (buffer) that can be
-            reused for an operation involving a tensor with the shape of `tensor_shape`.
-
-            The compatibility rules are: The buffer's total elements must be >= tensor_shape's.
-
-            If a compatible buffer is found, it's returned immediately. Otherwise, a new
-            buffer is allocated on the 'cuda' device with the give properties of 'tensor_shape' and 'dtype'.
-
-            Args:
-                tensor_shape: The required shape.
-                dtype: The required dtype.
-                cache_name: The key for the specific list of buffers to search in.
-            Returns:
-                An existing compatible buffer or a newly created one.
-            """
-            if buffers is None:
-                return torch.zeros(tensor_shape, device='cuda', dtype=dtype)
-
-            return buffers.get_buffer(tensor_shape, dtype, cache_name,
-                                      capture_graph)
-
-        def get_empty_like(like_tensor: torch.Tensor,
-                           cache_name: str) -> torch.Tensor:
-            return get_empty(like_tensor.shape,
-                             cache_name=cache_name,
-                             dtype=like_tensor.dtype)
-
-        self.prompt_lens_cuda = get_empty(
+        self.prompt_lens_cuda = self.get_empty(
+            buffers,
             (self.max_num_sequences, ),
             cache_name="prompt_lens_cuda",
             dtype=torch.int,
+            capture_graph=capture_graph,
         )
         self.prompt_lens_cpu = torch.empty_like(
             self.prompt_lens_cuda,
             device='cpu',
             pin_memory=True,
         )
-        self.kv_lens_cuda = get_empty_like(self.prompt_lens_cuda,
-                                           cache_name="kv_lens_cuda")
+        self.kv_lens_cuda = self.get_empty_like(
+            buffers,
+            self.prompt_lens_cuda,
+            cache_name="kv_lens_cuda",
+            capture_graph=capture_graph,
+        )
         self.kv_lens = torch.empty_like(self.kv_lens_cuda,
                                         device='cpu',
                                         pin_memory=True)
@@ -707,13 +681,15 @@ class TrtllmAttentionMetadata(AttentionMetadata):
             dtype=torch.int8,
         )
         if self.kv_cache_manager is not None:
-            self.kv_cache_block_offsets = get_empty(
+            self.kv_cache_block_offsets = self.get_empty(
+                buffers,
                 [
                     self.kv_cache_manager.num_pools, self.max_num_sequences, 2,
                     self.kv_cache_manager.max_blocks_per_seq
                 ],
                 cache_name="kv_cache_block_offsets",
                 dtype=torch.int32,
+                capture_graph=capture_graph,
             )
             self.host_kv_cache_block_offsets = torch.empty_like(
                 self.kv_cache_block_offsets,
@@ -723,38 +699,46 @@ class TrtllmAttentionMetadata(AttentionMetadata):
         self.block_ids_per_seq = None
         self.kv_block_ids_per_seq = None
         if self.enable_flash_mla:
-            self.block_ids_per_seq = get_empty(
+            self.block_ids_per_seq = self.get_empty(
+                buffers,
                 [
                     self.kv_cache_manager.max_batch_size,
                     self.kv_cache_manager.max_blocks_per_seq
                 ],
                 cache_name="block_ids_per_seq",
                 dtype=torch.int32,
+                capture_graph=capture_graph,
             )
-            self.kv_block_ids_per_seq = get_empty(
+            self.kv_block_ids_per_seq = self.get_empty(
+                buffers,
                 [
                     self.kv_cache_manager.max_batch_size,
                     self.kv_cache_manager.max_blocks_per_seq
                 ],
                 cache_name="kv_block_ids_per_seq",
                 dtype=torch.int32,
+                capture_graph=capture_graph,
             )
         if self.enable_context_mla_with_cached_kv:
             # for kv cache reuse/chunked context in MLA
-            self.ctx_cached_token_indptr = get_empty(
+            self.ctx_cached_token_indptr = self.get_empty(
+                buffers,
                 (self.max_num_requests + 1, ),
                 cache_name="ctx_cached_token_indptr",
                 dtype=torch.int64,
+                capture_graph=capture_graph,
             )
             self.host_ctx_cached_token_indptr = torch.zeros_like(
                 self.ctx_cached_token_indptr,
                 device='cpu',
                 pin_memory=True,
             )
-            self.ctx_uncached_token_indptr = get_empty(
+            self.ctx_uncached_token_indptr = self.get_empty(
+                buffers,
                 (self.max_num_requests + 1, ),
                 cache_name="ctx_uncached_token_indptr",
                 dtype=torch.int64,
+                capture_graph=capture_graph,
             )
             self.host_ctx_uncached_token_indptr = torch.zeros_like(
                 self.ctx_uncached_token_indptr,
@@ -762,10 +746,12 @@ class TrtllmAttentionMetadata(AttentionMetadata):
                 pin_memory=True,
             )
             # context full seqlens include cached tokens and uncached tokens
-            self.ctx_kv_indptr = get_empty(
+            self.ctx_kv_indptr = self.get_empty(
+                buffers,
                 (self.max_num_requests + 1, ),
                 cache_name="ctx_kv_indptr",
                 dtype=torch.int64,
+                capture_graph=capture_graph,
             )
             self.host_ctx_kv_indptr = torch.zeros_like(
                 self.ctx_kv_indptr,
@@ -8,6 +8,10 @@ import torch
 from tensorrt_llm.logger import logger


+def get_size_in_byte(target_shape: list[int], target_dtype: torch.dtype):
+    return math.prod(target_shape) * target_dtype.itemsize
+
+
 @dataclass
 class BufferBlock:
     """A container for a buffer tensor and its state."""
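The new get_size_in_byte helper is plain arithmetic: the product of the dimensions times the element size of the dtype. A self-contained check (the function is re-declared here so the snippet runs on its own):

import math
import torch

def get_size_in_byte(target_shape, target_dtype):
    return math.prod(target_shape) * target_dtype.itemsize

# A (4, 256) int64 tensor needs 4 * 256 * 8 = 8192 bytes of backing storage.
assert get_size_in_byte([4, 256], torch.int64) == 8192
# The same shape in int32 needs half of that.
assert get_size_in_byte([4, 256], torch.int32) == 4096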
@@ -36,13 +40,13 @@ class Buffers:
                      target_dtype: torch.dtype) -> torch.Tensor:
        """Safely creates a view of a raw byte buffer with the desired shape and dtype."""
        # The buffer is stored as uint8, so its numel is its size in bytes.
-        required_size_in_bytes = math.prod(target_shape) * target_dtype.itemsize
-        if buffer.numel() < required_size_in_bytes:
+        required_memory_size = get_size_in_byte(target_shape, target_dtype)
+        if buffer.numel() < required_memory_size:
            raise ValueError(
                "Buffer is too small for the requested shape and dtype.")

        # Slice the buffer to the exact required size, then view it with the correct type and shape.
-        return buffer[:required_size_in_bytes].view(target_dtype).view(
+        return buffer[:required_memory_size].view(target_dtype).view(
            target_shape)

    def get_buffer(self, tensor_shape: list[int], dtype: torch.dtype,
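The slicing-and-viewing trick in _create_view is standard PyTorch: a flat uint8 block can be reinterpreted as any dtype and shape whose byte size fits, without copying. A standalone sketch of the same pattern (pool name and sizes are illustrative, not from the patch):

import torch

# One raw byte pool (1 KiB); the pool element type is uint8, so numel == bytes.
raw = torch.zeros(1024, dtype=torch.uint8)

# Carve out a (2, 2) int64 view: 2 * 2 * 8 = 32 bytes from the front of the pool.
required = 2 * 2 * torch.int64.itemsize
view = raw[:required].view(torch.int64).view(2, 2)

view.fill_(7)                        # writes land in the shared byte pool
assert raw[:8].tolist() != [0] * 8   # the underlying bytes changed; no copy was made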
@@ -50,6 +54,7 @@ class Buffers:

        # all buffers are allocated with 1 byte element size
        required_memory_size = math.prod(tensor_shape) * dtype.itemsize

        candidate_blocks = self.buffers.get(buffer_name, [])

        # Find the best-fit available buffer.
@@ -16,6 +16,7 @@ from typing import TYPE_CHECKING, Dict, List, Optional

 import torch

+from ....memory_buffer_utils import get_memory_buffers
 from .moe_op import MoEOp

 if TYPE_CHECKING:
@@ -24,6 +25,7 @@ if TYPE_CHECKING:

 class DeepGemmMoEOp(MoEOp):
     """DeepGemm-based MoE op for GB200 block FP8."""
+    buffers = get_memory_buffers()

     def __init__(self):
         """Initialize DeepGemm op."""
@@ -85,23 +87,27 @@ class DeepGemmMoEOp(MoEOp):
         workspace = {}

         # Workspace for FP8 activations
-        workspace["workspace_0"] = torch.empty(
+        capture_graph = torch.cuda.is_current_stream_capturing()
+        workspace["workspace_0"] = DeepGemmMoEOp.buffers.get_buffer(
             (expert_size_per_partition * m_max * fp8_dim),
             dtype=torch.float8_e4m3fn,
-            device='cuda')
+            buffer_name='workspace_0',
+            reserve_buffer=capture_graph)

         # Workspace for intermediate results
-        workspace["workspace_1"] = torch.empty(
+        workspace["workspace_1"] = DeepGemmMoEOp.buffers.get_buffer(
             (expert_size_per_partition * m_max *
              max(intermediate_size * 2, hidden_size)),
             dtype=torch.bfloat16,
-            device='cuda')
+            buffer_name='workspace_1',
+            reserve_buffer=capture_graph)

         # Workspace for scaling factors
-        workspace["workspace_sf"] = torch.empty(
+        workspace["workspace_sf"] = DeepGemmMoEOp.buffers.get_buffer(
             expert_size_per_partition * (scale_k_padded // 4) * m_padded,
             dtype=torch.int32,
-            device='cuda')
+            buffer_name='workspace_sf',
+            reserve_buffer=capture_graph)

         return workspace

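To make the workspace sizing concrete: the three buffers scale with the number of local experts and the padded token count, and only their element types differ. The sketch below just evaluates the formulas from the hunk; the numeric values are hypothetical, chosen only to illustrate the arithmetic, and are not taken from the patch:

import torch

# Hypothetical values for illustration only.
expert_size_per_partition = 8       # local experts on this rank
m_max = 128                         # padded tokens routed per expert
fp8_dim = 7168                      # hidden size as seen by the FP8 GEMM
hidden_size = 7168
intermediate_size = 2048
scale_k_padded = 576
m_padded = 128

ws0 = expert_size_per_partition * m_max * fp8_dim                                   # float8_e4m3fn activations
ws1 = expert_size_per_partition * m_max * max(intermediate_size * 2, hidden_size)   # bfloat16 intermediates
wssf = expert_size_per_partition * (scale_k_padded // 4) * m_padded                 # int32 packed scale factors

print(ws0 * torch.float8_e4m3fn.itemsize,   # bytes for workspace_0
      ws1 * torch.bfloat16.itemsize,        # bytes for workspace_1
      wssf * torch.int32.itemsize)          # bytes for workspace_sf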