mirror of https://github.com/NVIDIA/TensorRT-LLM.git (synced 2026-02-16 15:55:08 +08:00)
[https://nvbugs/5756028][fix] Fix VSWA initialization with spec-dec and boundary condition in context input preparation (#10798)
Signed-off-by: eopXD <yuehtingc@nvidia.com>
This commit is contained in:
parent 09807918c7
commit 383c5921c2
@@ -691,7 +691,7 @@ def _create_kv_cache_manager(
execution_stream=execution_stream,
)
else:
# NOTE: this is a workaround for VSWA to switch to calculate_max_num_blocks_from_cpp in KVCahceManager
# NOTE: this is a workaround for VSWA to switch to calculate_max_num_blocks_for_vswa in KVCahceManager
is_vswa = kv_cache_config.max_attention_window is not None and len(
    set(kv_cache_config.max_attention_window)) > 1
binding_model_config = model_engine.model.model_config.get_bindings_model_config(
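The `is_vswa` check above treats a model as VSWA (variable sliding-window attention) whenever the configured attention windows are not all identical. A minimal, self-contained restatement of that check, with made-up window vectors:

def is_vswa(max_attention_window):
    """True when the model mixes more than one attention window size."""
    return (max_attention_window is not None
            and len(set(max_attention_window)) > 1)


assert is_vswa([512, 512, 32768])        # mixed window sizes -> VSWA path
assert not is_vswa([4096, 4096, 4096])   # uniform windows -> standard path
assert not is_vswa(None)                 # unset -> standard path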
@@ -1,6 +1,7 @@
import copy
import enum
import math
import os
from abc import ABC, abstractmethod
from collections import OrderedDict, defaultdict, deque
from typing import (TYPE_CHECKING, Dict, Iterable, List, Optional, Sequence,
@@ -8,17 +9,18 @@ from typing import (TYPE_CHECKING, Dict, Iterable, List, Optional, Sequence,

import numpy as np
import torch
from mpi4py import MPI

import tensorrt_llm
import tensorrt_llm.bindings
from tensorrt_llm._torch.distributed.communicator import Distributed, ReduceOp
from tensorrt_llm._utils import (TensorWrapper, convert_to_torch_tensor,
                                 get_size_in_bytes)
                                 get_size_in_bytes, mpi_comm, mpi_disabled,
                                 torch_comm)
from tensorrt_llm.bindings.internal.batch_manager.kv_cache_manager_v2_utils import (
    IndexMapper, copy_batch_block_offsets_to_device)
from tensorrt_llm.bindings.internal.runtime import TaskLayerModuleConfig
from tensorrt_llm.llmapi.llm_args import (KvCacheConfig, PeftCacheConfig,
                                          PybindMirror)
from tensorrt_llm.llmapi.llm_args import KvCacheConfig, PeftCacheConfig
from tensorrt_llm.lora_helper import LoraConfig
from tensorrt_llm.lora_manager import LoraManager, LoraModelConfig
from tensorrt_llm.math_utils import ceil_div
@@ -331,12 +333,44 @@ class KVCacheManager(BaseResourceManager):
)
assert isinstance(
    kv_cache_config, KvCacheConfig
), "calculate_max_num_blocks_from_cpp only accepts KvCacheConfig"
blocks_per_window = self.calculate_max_num_blocks_from_cpp(
), "calculate_max_num_blocks_for_vswa only accepts KvCacheConfig"
blocks_per_window = self.calculate_max_num_blocks_for_vswa(
    kv_cache_config=kv_cache_config,
    model_config=model_config,
    extra_cost_memory=0,
)
if mapping.world_size > 1:
    # make sure all ranks use the same number of primary/secondary blocks
    if mpi_disabled():
        for window_size, (
                primary_blocks,
                secondary_blocks) in blocks_per_window.items():
            reduced_primary_blocks = torch_comm().allreduce(
                primary_blocks,
                op=torch.distributed.ReduceOp.MIN)
            reduced_secondary_blocks = torch_comm().allreduce(
                secondary_blocks,
                op=torch.distributed.ReduceOp.MIN)
            blocks_per_window[window_size] = (
                reduced_primary_blocks,
                reduced_secondary_blocks)
    else:
        for window_size, (
                primary_blocks,
                secondary_blocks) in blocks_per_window.items():
            reduced_primary_blocks = mpi_comm().allreduce(
                primary_blocks, op=MPI.MIN)
            reduced_secondary_blocks = mpi_comm().allreduce(
                secondary_blocks, op=MPI.MIN)
            blocks_per_window[window_size] = (
                reduced_primary_blocks,
                reduced_secondary_blocks)
    logger.info(
        f"[MPI rank={mapping.rank}] Original blocks_per_window: {blocks_per_window}"
    )
    logger.info(
        f"[MPI rank={mapping.rank}] Reduced blocks_per_window: {blocks_per_window}"
    )
else:
    # Standard case: use original Python implementation
    self.blocks_in_primary_pool, self.blocks_in_secondary_pool = self.calculate_max_num_blocks(
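When running multi-rank, the per-rank block counts can disagree because each rank sizes its pools from its own memory reading, so the new code reduces them with MIN so that every rank ends up with identical pool sizes. A rough standalone sketch of that synchronization using plain mpi4py (the change itself goes through the `mpi_comm()` / `torch_comm()` wrappers):

from mpi4py import MPI


def sync_blocks_per_window(blocks_per_window):
    """Reduce each rank's (primary, secondary) block counts to the global minimum."""
    comm = MPI.COMM_WORLD
    synced = {}
    for window_size, (primary, secondary) in blocks_per_window.items():
        # All ranks must agree on pool sizes, so take the smallest count seen anywhere.
        synced[window_size] = (comm.allreduce(primary, op=MPI.MIN),
                               comm.allreduce(secondary, op=MPI.MIN))
    return synced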
@@ -1089,8 +1123,8 @@ class KVCacheManager(BaseResourceManager):
    window_size_to_layers_map[window_size].append(local_layer_idx)
return window_size_to_layers_map

@staticmethod
def adjust_window_sizes_for_vswa(
    self,
    window_size_to_layers: Dict[int, List[int]],
    max_attention_window_vec: List[int],
    kv_cache_config: KvCacheConfig,
@@ -1107,8 +1141,7 @@ class KVCacheManager(BaseResourceManager):

def calculate_cache_size_per_token(layers: Set[int]) -> int:
    # Same as BaseKVCacheManager::calculateCacheSizePerTokenForSingleWindowSize
    total_kv_heads = sum(model_config.num_kv_heads_per_layer[i]
                         for i in layers)
    total_kv_heads = sum(self.num_kv_heads_per_layer[i] for i in layers)
    return total_kv_heads * kv_factor * model_config.head_size

# Calculate the required memory bytes per sequence.
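`calculate_cache_size_per_token` mirrors the C++ `calculateCacheSizePerTokenForSingleWindowSize`: KV heads summed over the layers that share a window size, times the KV factor (2 for key plus value), times the head size. A worked example with illustrative numbers:

def cache_size_per_token(num_kv_heads_per_layer, layers, kv_factor, head_size):
    # Sum KV heads only over the layers that share this window size.
    total_kv_heads = sum(num_kv_heads_per_layer[i] for i in layers)
    return total_kv_heads * kv_factor * head_size


# Example: 4 layers with 8 KV heads each, key+value cached (kv_factor=2), head_size=128
# -> 4 * 8 * 2 * 128 = 8192 elements per token for this window's layer group.
assert cache_size_per_token([8, 8, 8, 8], {0, 1, 2, 3}, 2, 128) == 8192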
@@ -1124,15 +1157,17 @@ class KVCacheManager(BaseResourceManager):
    quant_vector_size=16,
    scaling_factor_dtype=DataType.FP8)
required_mem_bytes_per_seq += window_size * cache_size_bytes_per_token
logger.debug(
logger.info(
    f'Required memory per sequence: {required_mem_bytes_per_seq} bytes')
logger.info(f"Memory bytes in pool: {pool_memory_bytes}")

if required_mem_bytes_per_seq < pool_memory_bytes:
    # No need to adjust the window sizes.
    logger.info("No need to adjust the window sizes, returning")
    return (copy.deepcopy(window_size_to_layers),
            max_attention_window_vec)

logger.debug(
logger.info(
    f'Adjusting the window sizes {list(window_size_to_layers)} to fit '
    f'the memory {pool_memory_bytes} bytes.')
adjusted_window_size_to_layers = {}
@@ -1206,14 +1241,12 @@ class KVCacheManager(BaseResourceManager):
    return (adjusted_window_size_to_layers,
            adjusted_max_attention_window_vec)

def calculate_max_num_blocks_from_cpp(
def calculate_max_num_blocks_for_vswa(
        self,
        kv_cache_config: KvCacheConfig,
        model_config: ModelConfigCpp,
        extra_cost_memory: int = 0) -> dict[int, tuple[int, int]]:
    """
    This function is a wrapper of KVCacheManagerCpp.calculate_max_num_blocks.
    The final goal is to switch to the C++ implementation of calculate_max_num_blocks.
    Currently, this function is added to support *ONLY* VSWA.

    Args:
@@ -1223,6 +1256,15 @@ class KVCacheManager(BaseResourceManager):

    Returns:
        A dict of (max_attention_window, (blocks_in_primary_pool, blocks_in_secondary_pool)).

    Environment variable TRTLLM_WINDOW_SIZE_SHARES is used to adjust the memory
    share of each window size. By default, we allocate equal proportion shares of
    memory for all window sizes (see the else case). With TRTLLM_WINDOW_SIZE_SHARES,
    we can override this behavior to adjust the memory share of each window size.

    For example, if we have window size of [512, 32768], then setting
    TRTLLM_WINDOW_SIZE_SHARES=0.4,0.6 will be allocating 40% of the memory to
    window size 512 and 60% of the memory to window size 32768.
    """

    # VSWA on Torch backend has not supported the cross attention.
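To make the docstring's TRTLLM_WINDOW_SIZE_SHARES example concrete, here is a rough sketch of how the shares split the primary pool into per-window block counts; the pool size and per-token cost below are made-up numbers, not values from the change:

import os

# Illustrative numbers only; the real values come from the KV cache pools and model config.
primary_pool_bytes = 8 * 1024**3      # assumed 8 GiB primary pool
cache_bytes_per_token = 64 * 1024     # assumed per-token KV cost for one window's layers
tokens_per_block = 32

os.environ["TRTLLM_WINDOW_SIZE_SHARES"] = "0.4,0.6"
shares = [float(s) for s in os.environ["TRTLLM_WINDOW_SIZE_SHARES"].split(",")]

for window_size, share in zip([512, 32768], shares):
    tokens = primary_pool_bytes * share / cache_bytes_per_token
    blocks = int(tokens // tokens_per_block)
    print(f"window {window_size}: share {share:.0%} -> {blocks} primary blocks")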
@@ -1266,19 +1308,70 @@ class KVCacheManager(BaseResourceManager):
    )
    self.max_attention_window_vec = max_attention_window_vec

    blocks_per_window = KVCacheManagerCpp.calculate_max_num_blocks(
        config=PybindMirror.maybe_to_pybind(kv_cache_config),
        # TODO: support cross attention
        is_cross_attention=is_cross_attention,
        dtype=self.dtype,
        model_config=model_config,
        world_config=world_config_cpp,
        window_size_to_layers=window_size_to_layers,
        allotted_primary_mem_bytes=self._primary_pool_memory_bytes,
        allotted_secondary_mem_bytes=self._secondary_pool_memory_bytes,
        extra_cost_memory=extra_cost_memory,
        kv_factor=self.kv_factor,
    )
    def calculate_cache_size_per_token(layers: Set[int]) -> int:
        # Same as BaseKVCacheManager::calculateCacheSizePerTokenForSingleWindowSize
        total_kv_heads = sum(self.num_kv_heads_per_layer[i] for i in layers)
        return total_kv_heads * self.kv_factor * model_config.head_size

    logger.info(
        f"Primary pool memory bytes: {self._primary_pool_memory_bytes}")
    logger.info(
        f"Secondary pool memory bytes: {self._secondary_pool_memory_bytes}")

    if os.getenv("TRTLLM_WINDOW_SIZE_SHARES") is not None:
        logger.info("Environment variable TRTLLM_WINDOW_SIZE_SHARES is set")
        window_size_shares = os.getenv("TRTLLM_WINDOW_SIZE_SHARES").split(
            ",")
        window_size_shares = [float(share) for share in window_size_shares]
        assert len(window_size_shares) == len(
            window_size_to_layers
        ), "Number of shares in TRTLLM_WINDOW_SIZE_SHARES must match number of window sizes"
        assert sum(
            window_size_shares
        ) == 1.0, "Sum of shares in TRTLLM_WINDOW_SIZE_SHARES must be 1.0"
    else:
        logger.info(
            "Using default allocation of equal proportion of memory to each window size"
        )
        window_size_shares = [
            1.0 / len(window_size_to_layers) for _ in window_size_to_layers
        ]

    logger.info(f"Derived window_size_shares: {window_size_shares}")

    blocks_per_window = {}
    for window_idx, (window_size, layers) in enumerate(
            sorted(window_size_to_layers.items())):
        cache_size_per_token = calculate_cache_size_per_token(layers)
        cache_size_bytes_per_token = get_size_in_bytes(
            cache_size_per_token, self.dtype)

        primary_tokens = self._primary_pool_memory_bytes * window_size_shares[
            window_idx] / cache_size_bytes_per_token
        secondary_tokens = self._secondary_pool_memory_bytes * window_size_shares[
            window_idx] / cache_size_bytes_per_token

        if kv_cache_config.max_tokens is not None:
            if self.is_vswa:
                logger.info(
                    f"kv_cache_config.max_tokens is not None ({kv_cache_config.max_tokens}) but we are operating on VSWA scheme. Ignoring the configuration."
                )
            if not self.is_vswa:
                logger.info(
                    f"kv_cache_config.max_tokens is {kv_cache_config.max_tokens}"
                )
                if kv_cache_config.max_tokens < primary_tokens:
                    logger.info(
                        f"kv_cache_config.max_tokens {kv_cache_config.max_tokens} is less than primary_tokens {primary_tokens}. Reducing primary_tokens to {kv_cache_config.max_tokens}"
                    )
                    primary_tokens = kv_cache_config.max_tokens

        primary_blocks = int(primary_tokens // self.tokens_per_block)
        secondary_blocks = int(secondary_tokens // self.tokens_per_block)
        logger.info(
            f"Window size = {window_size}, primary_blocks: {primary_blocks}, secondary_blocks: {secondary_blocks}"
        )
        blocks_per_window[window_size] = (primary_blocks, secondary_blocks)
    return blocks_per_window

def _validate_and_adjust_attention_windows(
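One detail worth calling out in the block above: `kv_cache_config.max_tokens` only caps the primary-pool token budget in the uniform-window case; under VSWA it is logged and ignored. A small sketch of just that decision, with illustrative numbers:

def clamp_primary_tokens(primary_tokens, max_tokens, is_vswa):
    """max_tokens caps the primary token budget only for the uniform-window case."""
    if max_tokens is None or is_vswa:
        return primary_tokens          # VSWA (or unset max_tokens) keeps the computed budget
    return min(primary_tokens, max_tokens)


assert clamp_primary_tokens(50_000, 8_192, is_vswa=False) == 8_192   # clamped
assert clamp_primary_tokens(50_000, 8_192, is_vswa=True) == 50_000   # ignored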
@@ -588,14 +588,17 @@ class SpecWorkerBase(nn.Module, ABC):
    Returns:
        input_ids_ctx: Prepared context input IDs
    """
    input_prompt_ids = input_ids[:num_ctx_tokens]
    input_ids_ctx = torch.empty_like(input_prompt_ids,
                                     dtype=torch.int32,
                                     device="cuda")
    input_ids_ctx[:-1].copy_(input_prompt_ids[1:])
    input_ids_ctx[
        gather_ids[:num_contexts]] = accepted_tokens[:num_contexts, 0]
    return input_ids_ctx
    if num_ctx_tokens > 0:
        input_prompt_ids = input_ids[:num_ctx_tokens]
        input_ids_ctx = torch.empty_like(input_prompt_ids,
                                         dtype=torch.int32,
                                         device="cuda")
        input_ids_ctx[:-1].copy_(input_prompt_ids[1:])
        input_ids_ctx[
            gather_ids[:num_contexts]] = accepted_tokens[:num_contexts, 0]
        return input_ids_ctx
    else:
        return torch.empty(0, dtype=torch.int32, device="cuda")

def _sample_tokens_for_batch(
        self,
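The second half of the fix guards the context-input preparation so that a generation-only batch (zero context tokens) returns an empty tensor instead of entering the shifted-copy path. A self-contained sketch of the guarded helper, using CPU tensors so it runs without a GPU (the real method allocates on `cuda`):

import torch


def prepare_context_input_ids(input_ids, num_ctx_tokens, num_contexts,
                              gather_ids, accepted_tokens):
    """Shift prompt ids left by one and patch in each context request's accepted token."""
    if num_ctx_tokens > 0:
        input_prompt_ids = input_ids[:num_ctx_tokens]
        input_ids_ctx = torch.empty_like(input_prompt_ids, dtype=torch.int32)
        input_ids_ctx[:-1].copy_(input_prompt_ids[1:])
        input_ids_ctx[gather_ids[:num_contexts]] = accepted_tokens[:num_contexts, 0]
        return input_ids_ctx
    # Generation-only batch: nothing to shift, return an empty tensor.
    return torch.empty(0, dtype=torch.int32)


# Boundary case the fix targets: a batch with no context tokens at all.
out = prepare_context_input_ids(torch.empty(0, dtype=torch.int32), 0, 0,
                                torch.empty(0, dtype=torch.long),
                                torch.empty(0, 1, dtype=torch.int32))
assert out.numel() == 0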
@@ -456,9 +456,11 @@ class TestResourceManager(unittest.TestCase):

model_config.num_hidden_layers = len(total_layers)
model_config.num_attention_layers = len(total_layers)
model_config.layer_types = [LayerType.ATTENTION
                            ] * model_config.num_attention_layers

kv_factor = 2
cache_bytes_per_token_per_layer = 8
cache_bytes_per_token_per_layer = 32

# Define test cases:
# (memory_bytes, expected_window_sizes, max_tokens, description)
@@ -466,7 +468,7 @@ class TestResourceManager(unittest.TestCase):
test_cases = [
    (
        # Case 1: Limited memory - windows get clamped
        cache_bytes_per_token_per_layer * (100 * 9 + 30 * 5) + 4,
        cache_bytes_per_token_per_layer * (100 * 4 + 130 * 5) + 4,
        {
            100: [0, 1, 2, 3],
            130: [4, 5, 6, 7, 8],
@@ -477,7 +479,7 @@ class TestResourceManager(unittest.TestCase):
    (
        # Case 2: Less limited memory - the largest window get clamped
        cache_bytes_per_token_per_layer *
        (100 * 9 + 100 * 5 + 817 * 2) + 4,
        (100 * 4 + 200 * 3 + 1017 * 2) + 4,
        {
            100: [0, 1, 2, 3],
            200: [4, 5, 6],
@@ -510,8 +512,7 @@ class TestResourceManager(unittest.TestCase):
    (
        # Case 5: Less limited memory but max_tokens is given.
        # memory is enough for 1017 tokens, it will be clamped by max_tokens=134.
        cache_bytes_per_token_per_layer *
        (100 * 9 + 100 * 5 + 817 * 2) + 4,
        cache_bytes_per_token_per_layer * (100 * 4 + 134 * 5) + 4,
        {
            100: [0, 1, 2, 3],
            134: [4, 5, 6, 7, 8],
@@ -523,8 +524,33 @@ class TestResourceManager(unittest.TestCase):

for memory_bytes, expected_window_sizes, expected_max_attention_window_vec, max_tokens, description in test_cases:
    with self.subTest(case=description, memory_bytes=memory_bytes):
        kv_cache_config = tllm.KvCacheConfig(max_tokens=max_tokens)
        adjusted, adjusted_max_attention_window_vec = KVCacheManager.adjust_window_sizes_for_vswa(
        kv_cache_config_params = {
            "max_attention_window": max_attention_window_vec,
            "free_gpu_memory_fraction": 1.0,
            "host_cache_size": 0,
            "max_gpu_total_bytes": memory_bytes,
        }
        kv_cache_config = TestResourceManager._create_kv_cache_config_for_kv_cache_manager(
            kv_cache_config_params)
        mapping = Mapping(world_size=1, tp_size=1, pp_size=1)

        manager = KVCacheManager(
            kv_cache_config=kv_cache_config,
            kv_cache_type=tensorrt_llm.bindings.internal.batch_manager.
            CacheType.SELF,
            num_layers=model_config.num_attention_layers,
            num_kv_heads=8,
            head_dim=model_config.head_size,
            tokens_per_block=32,
            max_seq_len=max(max_attention_window_vec),
            max_batch_size=1,
            mapping=mapping,
            dtype=model_config.data_type,
            model_config=model_config,
            max_beam_width=1,
        )

        adjusted, adjusted_max_attention_window_vec = manager.adjust_window_sizes_for_vswa(
            window_size_to_layers=window_size_to_layers,
            max_attention_window_vec=max_attention_window_vec,
            model_config=model_config,
@@ -581,7 +607,7 @@ class TestResourceManager(unittest.TestCase):
    """
    return KvCacheConfig(**params)

def test_calculate_max_num_blocks_from_cpp(self):
def test_calculate_max_num_blocks_for_vswa(self):
    # Construct a minimal mapping (single-rank, no TP/PP)
    mapping = Mapping(world_size=1, tp_size=1, pp_size=1)
@@ -651,7 +677,7 @@ class TestResourceManager(unittest.TestCase):
    kv_cache_config_params)
with patch('torch.cuda.mem_get_info',
           return_value=(fixed_free_mem, fixed_total_mem)):
    # Create a real KVCacheManager, it will run calculate_max_num_blocks_from_cpp in __init__
    # Create a real KVCacheManager, it will run calculate_max_num_blocks_for_vswa in __init__
    manager = KVCacheManager(
        kv_cache_config=kv_cache_config,
        kv_cache_type=tensorrt_llm.bindings.internal.