[https://nvbugs/5756028][fix] Fix VSWA initialization with spec-dec and boundary condition in context input preparation (#10798)

Signed-off-by: eopXD <yuehtingc@nvidia.com>
Yueh-Ting (eop) Chen authored 2026-02-06 14:28:47 +08:00, committed by GitHub
parent 09807918c7
commit 383c5921c2
4 changed files with 166 additions and 44 deletions


@ -691,7 +691,7 @@ def _create_kv_cache_manager(
execution_stream=execution_stream,
)
else:
# NOTE: this is a workaround for VSWA to switch to calculate_max_num_blocks_from_cpp in KVCacheManager
# NOTE: this is a workaround for VSWA to switch to calculate_max_num_blocks_for_vswa in KVCacheManager
is_vswa = kv_cache_config.max_attention_window is not None and len(
set(kv_cache_config.max_attention_window)) > 1
binding_model_config = model_engine.model.model_config.get_bindings_model_config(
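
For context, VSWA (variable sliding window attention) means different layers use different attention window sizes, so this path is only taken when more than one distinct window size is configured. A minimal sketch of the is_vswa check above, with hypothetical window values:

# Hypothetical example: layers with a 512-token sliding window plus a full 32768 window.
max_attention_window = [512, 512, 32768]
is_vswa = max_attention_window is not None and len(set(max_attention_window)) > 1
assert is_vswa                                    # two distinct window sizes -> VSWA workaround path
assert not (len(set([4096, 4096, 4096])) > 1)     # a uniform window is not VSWA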


@ -1,6 +1,7 @@
import copy
import enum
import math
import os
from abc import ABC, abstractmethod
from collections import OrderedDict, defaultdict, deque
from typing import (TYPE_CHECKING, Dict, Iterable, List, Optional, Sequence,
@ -8,17 +9,18 @@ from typing import (TYPE_CHECKING, Dict, Iterable, List, Optional, Sequence,
import numpy as np
import torch
from mpi4py import MPI
import tensorrt_llm
import tensorrt_llm.bindings
from tensorrt_llm._torch.distributed.communicator import Distributed, ReduceOp
from tensorrt_llm._utils import (TensorWrapper, convert_to_torch_tensor,
get_size_in_bytes)
get_size_in_bytes, mpi_comm, mpi_disabled,
torch_comm)
from tensorrt_llm.bindings.internal.batch_manager.kv_cache_manager_v2_utils import (
IndexMapper, copy_batch_block_offsets_to_device)
from tensorrt_llm.bindings.internal.runtime import TaskLayerModuleConfig
from tensorrt_llm.llmapi.llm_args import (KvCacheConfig, PeftCacheConfig,
PybindMirror)
from tensorrt_llm.llmapi.llm_args import KvCacheConfig, PeftCacheConfig
from tensorrt_llm.lora_helper import LoraConfig
from tensorrt_llm.lora_manager import LoraManager, LoraModelConfig
from tensorrt_llm.math_utils import ceil_div
@ -331,12 +333,44 @@ class KVCacheManager(BaseResourceManager):
)
assert isinstance(
kv_cache_config, KvCacheConfig
), "calculate_max_num_blocks_from_cpp only accepts KvCacheConfig"
blocks_per_window = self.calculate_max_num_blocks_from_cpp(
), "calculate_max_num_blocks_for_vswa only accepts KvCacheConfig"
blocks_per_window = self.calculate_max_num_blocks_for_vswa(
kv_cache_config=kv_cache_config,
model_config=model_config,
extra_cost_memory=0,
)
if mapping.world_size > 1:
# make sure all ranks use the same number of primary/secondary blocks
if mpi_disabled():
for window_size, (
primary_blocks,
secondary_blocks) in blocks_per_window.items():
reduced_primary_blocks = torch_comm().allreduce(
primary_blocks,
op=torch.distributed.ReduceOp.MIN)
reduced_secondary_blocks = torch_comm().allreduce(
secondary_blocks,
op=torch.distributed.ReduceOp.MIN)
blocks_per_window[window_size] = (
reduced_primary_blocks,
reduced_secondary_blocks)
else:
for window_size, (
primary_blocks,
secondary_blocks) in blocks_per_window.items():
reduced_primary_blocks = mpi_comm().allreduce(
primary_blocks, op=MPI.MIN)
reduced_secondary_blocks = mpi_comm().allreduce(
secondary_blocks, op=MPI.MIN)
blocks_per_window[window_size] = (
reduced_primary_blocks,
reduced_secondary_blocks)
logger.info(
f"[MPI rank={mapping.rank}] Original blocks_per_window: {blocks_per_window}"
)
logger.info(
f"[MPI rank={mapping.rank}] Reduced blocks_per_window: {blocks_per_window}"
)
else:
# Standard case: use original Python implementation
self.blocks_in_primary_pool, self.blocks_in_secondary_pool = self.calculate_max_num_blocks(
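
The MIN reduction in the multi-rank branch above exists because each rank can compute a slightly different block budget (for example, from differing free memory), and taking the element-wise minimum guarantees every rank settles on a budget that all ranks can actually allocate. A self-contained sketch of that agreement step with made-up numbers, independent of mpi4py or torch.distributed:

# Hypothetical per-rank outputs of calculate_max_num_blocks_for_vswa
# for two window sizes; the numbers are illustrative only.
rank_results = [
    {512: (1000, 200), 32768: (400, 80)},   # rank 0
    {512: (950, 210), 32768: (420, 75)},    # rank 1
]
blocks_per_window = {
    window: (
        min(r[window][0] for r in rank_results),   # primary blocks: MIN across ranks
        min(r[window][1] for r in rank_results),   # secondary blocks: MIN across ranks
    )
    for window in rank_results[0]
}
assert blocks_per_window == {512: (950, 200), 32768: (400, 75)}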
@ -1089,8 +1123,8 @@ class KVCacheManager(BaseResourceManager):
window_size_to_layers_map[window_size].append(local_layer_idx)
return window_size_to_layers_map
@staticmethod
def adjust_window_sizes_for_vswa(
self,
window_size_to_layers: Dict[int, List[int]],
max_attention_window_vec: List[int],
kv_cache_config: KvCacheConfig,
@ -1107,8 +1141,7 @@ class KVCacheManager(BaseResourceManager):
def calculate_cache_size_per_token(layers: Set[int]) -> int:
# Same as BaseKVCacheManager::calculateCacheSizePerTokenForSingleWindowSize
total_kv_heads = sum(model_config.num_kv_heads_per_layer[i]
for i in layers)
total_kv_heads = sum(self.num_kv_heads_per_layer[i] for i in layers)
return total_kv_heads * kv_factor * model_config.head_size
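
As a worked example of calculate_cache_size_per_token above (shapes are assumed, not taken from any real model): with 5 attention layers of 8 KV heads each, kv_factor = 2 (keys plus values) and head_size = 64, a single token costs 8 * 5 * 2 * 64 = 5120 cache elements before the dtype size is applied:

# Illustrative stand-alone version of the helper, with assumed shapes.
num_kv_heads_per_layer = [8, 8, 8, 8, 8]
kv_factor = 2        # keys + values
head_size = 64
layers = {0, 1, 2, 3, 4}
total_kv_heads = sum(num_kv_heads_per_layer[i] for i in layers)
cache_size_per_token = total_kv_heads * kv_factor * head_size
assert cache_size_per_token == 5120   # elements per token for this window's layers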
# Calculate the required memory bytes per sequence.
@ -1124,15 +1157,17 @@ class KVCacheManager(BaseResourceManager):
quant_vector_size=16,
scaling_factor_dtype=DataType.FP8)
required_mem_bytes_per_seq += window_size * cache_size_bytes_per_token
logger.debug(
logger.info(
f'Required memory per sequence: {required_mem_bytes_per_seq} bytes')
logger.info(f"Memory bytes in pool: {pool_memory_bytes}")
if required_mem_bytes_per_seq < pool_memory_bytes:
# No need to adjust the window sizes.
logger.info("No need to adjust the window sizes, returning")
return (copy.deepcopy(window_size_to_layers),
max_attention_window_vec)
logger.debug(
logger.info(
f'Adjusting the window sizes {list(window_size_to_layers)} to fit '
f'the memory {pool_memory_bytes} bytes.')
adjusted_window_size_to_layers = {}
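
To make the comparison above concrete, here is a hand-computed sketch with assumed numbers (the FP8 scaling-factor overhead is omitted for brevity): with windows of 1024 and 4096 tokens whose layers cost 2048 and 1024 bytes per token respectively, one sequence needs 1024*2048 + 4096*1024 = 6,291,456 bytes, and the windows are only clamped when the pool is smaller than that:

# Assumed, simplified version of the check above.
window_to_bytes_per_token = {1024: 2048, 4096: 1024}     # hypothetical per-token costs
required_mem_bytes_per_seq = sum(
    window_size * bytes_per_token
    for window_size, bytes_per_token in window_to_bytes_per_token.items()
)                                                        # 6_291_456 bytes
pool_memory_bytes = 8 * 1024 ** 2                        # hypothetical 8 MiB pool
# 6_291_456 < 8_388_608, so in this sketch the window sizes are returned unchanged.
assert required_mem_bytes_per_seq < pool_memory_bytes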
@ -1206,14 +1241,12 @@ class KVCacheManager(BaseResourceManager):
return (adjusted_window_size_to_layers,
adjusted_max_attention_window_vec)
def calculate_max_num_blocks_from_cpp(
def calculate_max_num_blocks_for_vswa(
self,
kv_cache_config: KvCacheConfig,
model_config: ModelConfigCpp,
extra_cost_memory: int = 0) -> dict[int, tuple[int, int]]:
"""
This function is a wrapper of KVCacheManagerCpp.calculate_max_num_blocks.
The final goal is to switch to the C++ implementation of calculate_max_num_blocks.
Currently, this function is added to support *ONLY* VSWA.
Args:
@ -1223,6 +1256,15 @@ class KVCacheManager(BaseResourceManager):
Returns:
A dict of (max_attention_window, (blocks_in_primary_pool, blocks_in_secondary_pool)).
The environment variable TRTLLM_WINDOW_SIZE_SHARES adjusts the memory share of each window size.
By default, every window size receives an equal share of the memory (see the else case);
TRTLLM_WINDOW_SIZE_SHARES overrides this default. For example, with window sizes [512, 32768],
setting TRTLLM_WINDOW_SIZE_SHARES=0.4,0.6 allocates 40% of the memory to window size 512 and
60% to window size 32768.
"""
# VSWA on the Torch backend does not support cross attention yet.
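
The override described in the docstring can be exercised with a small stand-alone sketch; it loosely mirrors the parsing added further down, the shares are applied to the window sizes in ascending order, and the 0.4/0.6 split is the docstring's own example:

import os

os.environ["TRTLLM_WINDOW_SIZE_SHARES"] = "0.4,0.6"       # docstring example
window_sizes = sorted([512, 32768])

shares_env = os.getenv("TRTLLM_WINDOW_SIZE_SHARES")
if shares_env is not None:
    window_size_shares = [float(s) for s in shares_env.split(",")]
else:
    # Default: an equal proportion of memory per window size.
    window_size_shares = [1.0 / len(window_sizes)] * len(window_sizes)
assert len(window_size_shares) == len(window_sizes)
assert abs(sum(window_size_shares) - 1.0) < 1e-9

primary_pool_memory_bytes = 1 << 30                        # hypothetical 1 GiB pool
per_window_bytes = {
    w: primary_pool_memory_bytes * share
    for w, share in zip(window_sizes, window_size_shares)
}
# window 512 receives ~0.4 GiB and window 32768 receives ~0.6 GiB of the pool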
@ -1266,19 +1308,70 @@ class KVCacheManager(BaseResourceManager):
)
self.max_attention_window_vec = max_attention_window_vec
blocks_per_window = KVCacheManagerCpp.calculate_max_num_blocks(
config=PybindMirror.maybe_to_pybind(kv_cache_config),
# TODO: support cross attention
is_cross_attention=is_cross_attention,
dtype=self.dtype,
model_config=model_config,
world_config=world_config_cpp,
window_size_to_layers=window_size_to_layers,
allotted_primary_mem_bytes=self._primary_pool_memory_bytes,
allotted_secondary_mem_bytes=self._secondary_pool_memory_bytes,
extra_cost_memory=extra_cost_memory,
kv_factor=self.kv_factor,
)
def calculate_cache_size_per_token(layers: Set[int]) -> int:
# Same as BaseKVCacheManager::calculateCacheSizePerTokenForSingleWindowSize
total_kv_heads = sum(self.num_kv_heads_per_layer[i] for i in layers)
return total_kv_heads * self.kv_factor * model_config.head_size
logger.info(
f"Primary pool memory bytes: {self._primary_pool_memory_bytes}")
logger.info(
f"Secondary pool memory bytes: {self._secondary_pool_memory_bytes}")
if os.getenv("TRTLLM_WINDOW_SIZE_SHARES") is not None:
logger.info("Environment variable TRTLLM_WINDOW_SIZE_SHARES is set")
window_size_shares = os.getenv("TRTLLM_WINDOW_SIZE_SHARES").split(
",")
window_size_shares = [float(share) for share in window_size_shares]
assert len(window_size_shares) == len(
window_size_to_layers
), "Number of shares in TRTLLM_WINDOW_SIZE_SHARES must match number of window sizes"
assert sum(
window_size_shares
) == 1.0, "Sum of shares in TRTLLM_WINDOW_SIZE_SHARES must be 1.0"
else:
logger.info(
"Using default allocation of equal proportion of memory to each window size"
)
window_size_shares = [
1.0 / len(window_size_to_layers) for _ in window_size_to_layers
]
logger.info(f"Derived window_size_shares: {window_size_shares}")
blocks_per_window = {}
for window_idx, (window_size, layers) in enumerate(
sorted(window_size_to_layers.items())):
cache_size_per_token = calculate_cache_size_per_token(layers)
cache_size_bytes_per_token = get_size_in_bytes(
cache_size_per_token, self.dtype)
primary_tokens = self._primary_pool_memory_bytes * window_size_shares[
window_idx] / cache_size_bytes_per_token
secondary_tokens = self._secondary_pool_memory_bytes * window_size_shares[
window_idx] / cache_size_bytes_per_token
if kv_cache_config.max_tokens is not None:
if self.is_vswa:
logger.info(
f"kv_cache_config.max_tokens is not None ({kv_cache_config.max_tokens}) but we are operating on VSWA scheme. Ignoring the configuration."
)
if not self.is_vswa:
logger.info(
f"kv_cache_config.max_tokens is {kv_cache_config.max_tokens}"
)
if kv_cache_config.max_tokens < primary_tokens:
logger.info(
f"kv_cache_config.max_tokens {kv_cache_config.max_tokens} is less than primary_tokens {primary_tokens}. Reducing primary_tokens to {kv_cache_config.max_tokens}"
)
primary_tokens = kv_cache_config.max_tokens
primary_blocks = int(primary_tokens // self.tokens_per_block)
secondary_blocks = int(secondary_tokens // self.tokens_per_block)
logger.info(
f"Window size = {window_size}, primary_blocks: {primary_blocks}, secondary_blocks: {secondary_blocks}"
)
blocks_per_window[window_size] = (primary_blocks, secondary_blocks)
return blocks_per_window
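
Putting the per-window computation together, a hand-checked numeric sketch under assumed sizes (an FP16 cache at 2 bytes per element, the 5120-elements-per-token figure from the earlier example, and the same byte cost for both windows, which in the real code depends on which layers map to each window):

# Assumed values, for illustration only.
tokens_per_block = 32
cache_size_bytes_per_token = 5120 * 2                      # FP16 -> 10240 bytes/token
primary_pool_memory_bytes = 1 << 30                        # 1 GiB
secondary_pool_memory_bytes = 0
window_size_shares = [0.5, 0.5]                            # two windows, equal split

blocks_per_window = {}
for window_idx, window_size in enumerate(sorted([512, 32768])):
    primary_tokens = (primary_pool_memory_bytes * window_size_shares[window_idx]
                      / cache_size_bytes_per_token)        # ~52428.8 tokens
    secondary_tokens = (secondary_pool_memory_bytes * window_size_shares[window_idx]
                        / cache_size_bytes_per_token)
    blocks_per_window[window_size] = (int(primary_tokens // tokens_per_block),
                                      int(secondary_tokens // tokens_per_block))
assert blocks_per_window == {512: (1638, 0), 32768: (1638, 0)}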
def _validate_and_adjust_attention_windows(


@ -588,14 +588,17 @@ class SpecWorkerBase(nn.Module, ABC):
Returns:
input_ids_ctx: Prepared context input IDs
"""
input_prompt_ids = input_ids[:num_ctx_tokens]
input_ids_ctx = torch.empty_like(input_prompt_ids,
dtype=torch.int32,
device="cuda")
input_ids_ctx[:-1].copy_(input_prompt_ids[1:])
input_ids_ctx[
gather_ids[:num_contexts]] = accepted_tokens[:num_contexts, 0]
return input_ids_ctx
if num_ctx_tokens > 0:
input_prompt_ids = input_ids[:num_ctx_tokens]
input_ids_ctx = torch.empty_like(input_prompt_ids,
dtype=torch.int32,
device="cuda")
input_ids_ctx[:-1].copy_(input_prompt_ids[1:])
input_ids_ctx[
gather_ids[:num_contexts]] = accepted_tokens[:num_contexts, 0]
return input_ids_ctx
else:
return torch.empty(0, dtype=torch.int32, device="cuda")
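
The guarded logic above can be illustrated with a tiny stand-alone sketch (prepare_context_inputs is a hypothetical helper name, CPU tensors are used instead of device="cuda", and all values are made up): each context position receives the next prompt token, the final position of every context request receives its accepted token, and the new num_ctx_tokens == 0 branch covers generation-only iterations by returning an empty tensor instead of running the shift/scatter on an empty slice:

import torch

def prepare_context_inputs(input_ids, num_ctx_tokens, num_contexts,
                           gather_ids, accepted_tokens):
    # Hypothetical stand-alone version of the guarded logic above.
    if num_ctx_tokens > 0:
        input_prompt_ids = input_ids[:num_ctx_tokens]
        input_ids_ctx = torch.empty_like(input_prompt_ids, dtype=torch.int32)
        input_ids_ctx[:-1].copy_(input_prompt_ids[1:])           # shift left by one
        input_ids_ctx[gather_ids[:num_contexts]] = accepted_tokens[:num_contexts, 0]
        return input_ids_ctx
    return torch.empty(0, dtype=torch.int32)                     # no context tokens

# One context request of 4 prompt tokens [11, 12, 13, 14]; its accepted token is 99.
out = prepare_context_inputs(torch.tensor([11, 12, 13, 14], dtype=torch.int32),
                             num_ctx_tokens=4, num_contexts=1,
                             gather_ids=torch.tensor([3]),
                             accepted_tokens=torch.tensor([[99]], dtype=torch.int32))
assert out.tolist() == [12, 13, 14, 99]

# Boundary case addressed by this commit: no context tokens at all.
empty = prepare_context_inputs(torch.empty(0, dtype=torch.int32), 0, 0,
                               torch.tensor([], dtype=torch.long),
                               torch.empty(0, 1, dtype=torch.int32))
assert empty.numel() == 0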
def _sample_tokens_for_batch(
self,


@ -456,9 +456,11 @@ class TestResourceManager(unittest.TestCase):
model_config.num_hidden_layers = len(total_layers)
model_config.num_attention_layers = len(total_layers)
model_config.layer_types = [LayerType.ATTENTION
] * model_config.num_attention_layers
kv_factor = 2
cache_bytes_per_token_per_layer = 8
cache_bytes_per_token_per_layer = 32
# Define test cases:
# (memory_bytes, expected_window_sizes, max_tokens, description)
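
For orientation, the memory_bytes value of each case below is built from the cache_bytes_per_token_per_layer constant so that it encodes exactly how many tokens each layer's window can keep; a hand-check of Case 1 as a sketch (constants taken from the lines above and that case's expected result):

cache_bytes_per_token_per_layer = 32
# 4 layers keep their 100-token window and 5 layers are clamped to 130 tokens,
# matching Case 1's expected_window_sizes of {100: [0, 1, 2, 3], 130: [4, 5, 6, 7, 8]}.
memory_bytes = cache_bytes_per_token_per_layer * (100 * 4 + 130 * 5) + 4
assert memory_bytes == 33_604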
@ -466,7 +468,7 @@ class TestResourceManager(unittest.TestCase):
test_cases = [
(
# Case 1: Limited memory - windows get clamped
cache_bytes_per_token_per_layer * (100 * 9 + 30 * 5) + 4,
cache_bytes_per_token_per_layer * (100 * 4 + 130 * 5) + 4,
{
100: [0, 1, 2, 3],
130: [4, 5, 6, 7, 8],
@ -477,7 +479,7 @@ class TestResourceManager(unittest.TestCase):
(
# Case 2: Less limited memory - the largest window gets clamped
cache_bytes_per_token_per_layer *
(100 * 9 + 100 * 5 + 817 * 2) + 4,
(100 * 4 + 200 * 3 + 1017 * 2) + 4,
{
100: [0, 1, 2, 3],
200: [4, 5, 6],
@ -510,8 +512,7 @@ class TestResourceManager(unittest.TestCase):
(
# Case 5: Less limited memory but max_tokens is given.
# memory is enough for 1017 tokens, but it will be clamped by max_tokens=134.
cache_bytes_per_token_per_layer *
(100 * 9 + 100 * 5 + 817 * 2) + 4,
cache_bytes_per_token_per_layer * (100 * 4 + 134 * 5) + 4,
{
100: [0, 1, 2, 3],
134: [4, 5, 6, 7, 8],
@ -523,8 +524,33 @@ class TestResourceManager(unittest.TestCase):
for memory_bytes, expected_window_sizes, expected_max_attention_window_vec, max_tokens, description in test_cases:
with self.subTest(case=description, memory_bytes=memory_bytes):
kv_cache_config = tllm.KvCacheConfig(max_tokens=max_tokens)
adjusted, adjusted_max_attention_window_vec = KVCacheManager.adjust_window_sizes_for_vswa(
kv_cache_config_params = {
"max_attention_window": max_attention_window_vec,
"free_gpu_memory_fraction": 1.0,
"host_cache_size": 0,
"max_gpu_total_bytes": memory_bytes,
}
kv_cache_config = TestResourceManager._create_kv_cache_config_for_kv_cache_manager(
kv_cache_config_params)
mapping = Mapping(world_size=1, tp_size=1, pp_size=1)
manager = KVCacheManager(
kv_cache_config=kv_cache_config,
kv_cache_type=tensorrt_llm.bindings.internal.batch_manager.
CacheType.SELF,
num_layers=model_config.num_attention_layers,
num_kv_heads=8,
head_dim=model_config.head_size,
tokens_per_block=32,
max_seq_len=max(max_attention_window_vec),
max_batch_size=1,
mapping=mapping,
dtype=model_config.data_type,
model_config=model_config,
max_beam_width=1,
)
adjusted, adjusted_max_attention_window_vec = manager.adjust_window_sizes_for_vswa(
window_size_to_layers=window_size_to_layers,
max_attention_window_vec=max_attention_window_vec,
model_config=model_config,
@ -581,7 +607,7 @@ class TestResourceManager(unittest.TestCase):
"""
return KvCacheConfig(**params)
def test_calculate_max_num_blocks_from_cpp(self):
def test_calculate_max_num_blocks_for_vswa(self):
# Construct a minimal mapping (single-rank, no TP/PP)
mapping = Mapping(world_size=1, tp_size=1, pp_size=1)
@ -651,7 +677,7 @@ class TestResourceManager(unittest.TestCase):
kv_cache_config_params)
with patch('torch.cuda.mem_get_info',
return_value=(fixed_free_mem, fixed_total_mem)):
# Create a real KVCacheManager, it will run calculate_max_num_blocks_from_cpp in __init__
# Create a real KVCacheManager, it will run calculate_max_num_blocks_for_vswa in __init__
manager = KVCacheManager(
kv_cache_config=kv_cache_config,
kv_cache_type=tensorrt_llm.bindings.internal.