diff --git a/tensorrt_llm/_torch/pyexecutor/_util.py b/tensorrt_llm/_torch/pyexecutor/_util.py
index 4fea1e0b4e..497e92d861 100644
--- a/tensorrt_llm/_torch/pyexecutor/_util.py
+++ b/tensorrt_llm/_torch/pyexecutor/_util.py
@@ -691,7 +691,7 @@ def _create_kv_cache_manager(
             execution_stream=execution_stream,
         )
     else:
-        # NOTE: this is a workaround for VSWA to switch to calculate_max_num_blocks_from_cpp in KVCahceManager
+        # NOTE: this is a workaround for VSWA to switch to calculate_max_num_blocks_for_vswa in KVCacheManager
         is_vswa = kv_cache_config.max_attention_window is not None and len(
             set(kv_cache_config.max_attention_window)) > 1
         binding_model_config = model_engine.model.model_config.get_bindings_model_config(
diff --git a/tensorrt_llm/_torch/pyexecutor/resource_manager.py b/tensorrt_llm/_torch/pyexecutor/resource_manager.py
index a5548946d8..f3cb048140 100644
--- a/tensorrt_llm/_torch/pyexecutor/resource_manager.py
+++ b/tensorrt_llm/_torch/pyexecutor/resource_manager.py
@@ -1,6 +1,7 @@
 import copy
 import enum
 import math
+import os
 from abc import ABC, abstractmethod
 from collections import OrderedDict, defaultdict, deque
 from typing import (TYPE_CHECKING, Dict, Iterable, List, Optional, Sequence,
@@ -8,17 +9,18 @@ from typing import (TYPE_CHECKING, Dict, Iterable, List, Optional, Sequence,

 import numpy as np
 import torch
+from mpi4py import MPI

 import tensorrt_llm
 import tensorrt_llm.bindings
 from tensorrt_llm._torch.distributed.communicator import Distributed, ReduceOp
 from tensorrt_llm._utils import (TensorWrapper, convert_to_torch_tensor,
-                                 get_size_in_bytes)
+                                 get_size_in_bytes, mpi_comm, mpi_disabled,
+                                 torch_comm)
 from tensorrt_llm.bindings.internal.batch_manager.kv_cache_manager_v2_utils import (
     IndexMapper, copy_batch_block_offsets_to_device)
 from tensorrt_llm.bindings.internal.runtime import TaskLayerModuleConfig
-from tensorrt_llm.llmapi.llm_args import (KvCacheConfig, PeftCacheConfig,
-                                          PybindMirror)
+from tensorrt_llm.llmapi.llm_args import KvCacheConfig, PeftCacheConfig
 from tensorrt_llm.lora_helper import LoraConfig
 from tensorrt_llm.lora_manager import LoraManager, LoraModelConfig
 from tensorrt_llm.math_utils import ceil_div
@@ -331,12 +333,44 @@ class KVCacheManager(BaseResourceManager):
             )
             assert isinstance(
                 kv_cache_config, KvCacheConfig
-            ), "calculate_max_num_blocks_from_cpp only accepts KvCacheConfig"
-            blocks_per_window = self.calculate_max_num_blocks_from_cpp(
+            ), "calculate_max_num_blocks_for_vswa only accepts KvCacheConfig"
+            blocks_per_window = self.calculate_max_num_blocks_for_vswa(
                 kv_cache_config=kv_cache_config,
                 model_config=model_config,
                 extra_cost_memory=0,
             )
+            if mapping.world_size > 1:
+                # Make sure all ranks use the same number of primary/secondary blocks.
+                logger.info(
+                    f"[rank={mapping.rank}] Original blocks_per_window: {blocks_per_window}"
+                )
+                if mpi_disabled():
+                    for window_size, (
+                            primary_blocks,
+                            secondary_blocks) in blocks_per_window.items():
+                        reduced_primary_blocks = torch_comm().allreduce(
+                            primary_blocks,
+                            op=torch.distributed.ReduceOp.MIN)
+                        reduced_secondary_blocks = torch_comm().allreduce(
+                            secondary_blocks,
+                            op=torch.distributed.ReduceOp.MIN)
+                        blocks_per_window[window_size] = (
+                            reduced_primary_blocks,
+                            reduced_secondary_blocks)
+                else:
+                    for window_size, (
+                            primary_blocks,
+                            secondary_blocks) in blocks_per_window.items():
+                        reduced_primary_blocks = mpi_comm().allreduce(
+                            primary_blocks, op=MPI.MIN)
+                        reduced_secondary_blocks = mpi_comm().allreduce(
+                            secondary_blocks, op=MPI.MIN)
+                        blocks_per_window[window_size] = (
+                            reduced_primary_blocks,
+                            reduced_secondary_blocks)
+                logger.info(
+                    f"[rank={mapping.rank}] Reduced blocks_per_window: {blocks_per_window}"
+                )
         else:
             # Standard case: use original Python implementation
             self.blocks_in_primary_pool, self.blocks_in_secondary_pool = self.calculate_max_num_blocks(
@@ -1089,8 +1123,8 @@ class KVCacheManager(BaseResourceManager):
                 window_size_to_layers_map[window_size].append(local_layer_idx)

         return window_size_to_layers_map

-    @staticmethod
     def adjust_window_sizes_for_vswa(
+        self,
         window_size_to_layers: Dict[int, List[int]],
         max_attention_window_vec: List[int],
         kv_cache_config: KvCacheConfig,
@@ -1107,8 +1141,7 @@ class KVCacheManager(BaseResourceManager):

         def calculate_cache_size_per_token(layers: Set[int]) -> int:
             # Same as BaseKVCacheManager::calculateCacheSizePerTokenForSingleWindowSize
-            total_kv_heads = sum(model_config.num_kv_heads_per_layer[i]
-                                 for i in layers)
+            total_kv_heads = sum(self.num_kv_heads_per_layer[i] for i in layers)
             return total_kv_heads * kv_factor * model_config.head_size

         # Calculate the required memory bytes per sequence.
@@ -1124,15 +1157,17 @@ class KVCacheManager(BaseResourceManager):
                     quant_vector_size=16,
                     scaling_factor_dtype=DataType.FP8)
             required_mem_bytes_per_seq += window_size * cache_size_bytes_per_token
-        logger.debug(
+        logger.info(
             f'Required memory per sequence: {required_mem_bytes_per_seq} bytes')
+        logger.info(f"Memory bytes in pool: {pool_memory_bytes}")

         if required_mem_bytes_per_seq < pool_memory_bytes:
             # No need to adjust the window sizes.
+            logger.info("No need to adjust the window sizes, returning")
             return (copy.deepcopy(window_size_to_layers),
                     max_attention_window_vec)

-        logger.debug(
+        logger.info(
             f'Adjusting the window sizes {list(window_size_to_layers)} to fit '
             f'the memory {pool_memory_bytes} bytes.')
         adjusted_window_size_to_layers = {}
@@ -1206,14 +1241,12 @@ class KVCacheManager(BaseResourceManager):
         return (adjusted_window_size_to_layers,
                 adjusted_max_attention_window_vec)

-    def calculate_max_num_blocks_from_cpp(
+    def calculate_max_num_blocks_for_vswa(
             self,
             kv_cache_config: KvCacheConfig,
             model_config: ModelConfigCpp,
             extra_cost_memory: int = 0) -> dict[int, tuple[int, int]]:
         """
-        This function is a wrapper of KVCacheManagerCpp.calculate_max_num_blocks.
-        The final goal is to switch to the C++ implementation of calculate_max_num_blocks.
         Currently, this function is added to support *ONLY* VSWA.

         Args:
@@ -1223,6 +1256,15 @@ class KVCacheManager(BaseResourceManager):

         Returns:
             A dict of (max_attention_window, (blocks_in_primary_pool, blocks_in_secondary_pool)).
+
+        The environment variable TRTLLM_WINDOW_SIZE_SHARES adjusts the memory
+        share of each window size. By default, an equal share of memory is
+        allocated to every window size (see the else case); TRTLLM_WINDOW_SIZE_SHARES
+        overrides this behavior with an explicit share per window size.
+
+        For example, with window sizes [512, 32768], setting
+        TRTLLM_WINDOW_SIZE_SHARES=0.4,0.6 allocates 40% of the memory to
+        window size 512 and 60% to window size 32768.
         """

         # VSWA on Torch backend has not supported the cross attention.
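To make the share-splitting rule described in the docstring above concrete, here is a small standalone sketch of the same allocation scheme. The helper name `derive_window_size_shares`, the pool size, and the per-token byte count below are illustrative assumptions and are not part of this patch; the actual logic lives in `calculate_max_num_blocks_for_vswa`.

```python
import os


def derive_window_size_shares(window_sizes, env_value=None):
    """Split the KV-cache pool across window sizes (sketch of the rule above)."""
    if env_value is None:
        env_value = os.getenv("TRTLLM_WINDOW_SIZE_SHARES")
    if env_value is not None:
        shares = [float(s) for s in env_value.split(",")]
        assert len(shares) == len(window_sizes)
        assert abs(sum(shares) - 1.0) < 1e-6
    else:
        # Default: every window size gets an equal share of the pool.
        shares = [1.0 / len(window_sizes)] * len(window_sizes)
    return shares


# Example: 40% of an (illustrative) 8 GiB primary pool for window 512, 60% for 32768.
window_sizes = sorted([512, 32768])
shares = derive_window_size_shares(window_sizes, env_value="0.4,0.6")
pool_bytes = 8 * 1024**3            # primary pool size, illustrative
bytes_per_token = 2 * 8 * 128 * 2   # kv_factor * kv_heads * head_size * 2 bytes (fp16), illustrative
tokens_per_block = 32
for share, window_size in zip(shares, window_sizes):
    tokens = pool_bytes * share / bytes_per_token
    print(window_size, int(tokens // tokens_per_block), "primary blocks")
```

Because each rank computes its block counts from its own free memory, the MIN-allreduce added in `__init__` above then keeps the per-rank primary/secondary block counts consistent across ranks.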
@@ -1266,19 +1308,70 @@ class KVCacheManager(BaseResourceManager):
             )
         self.max_attention_window_vec = max_attention_window_vec

-        blocks_per_window = KVCacheManagerCpp.calculate_max_num_blocks(
-            config=PybindMirror.maybe_to_pybind(kv_cache_config),
-            # TODO: support cross attention
-            is_cross_attention=is_cross_attention,
-            dtype=self.dtype,
-            model_config=model_config,
-            world_config=world_config_cpp,
-            window_size_to_layers=window_size_to_layers,
-            allotted_primary_mem_bytes=self._primary_pool_memory_bytes,
-            allotted_secondary_mem_bytes=self._secondary_pool_memory_bytes,
-            extra_cost_memory=extra_cost_memory,
-            kv_factor=self.kv_factor,
-        )
+        def calculate_cache_size_per_token(layers: Set[int]) -> int:
+            # Same as BaseKVCacheManager::calculateCacheSizePerTokenForSingleWindowSize
+            total_kv_heads = sum(self.num_kv_heads_per_layer[i] for i in layers)
+            return total_kv_heads * self.kv_factor * model_config.head_size
+
+        logger.info(
+            f"Primary pool memory bytes: {self._primary_pool_memory_bytes}")
+        logger.info(
+            f"Secondary pool memory bytes: {self._secondary_pool_memory_bytes}")
+
+        if os.getenv("TRTLLM_WINDOW_SIZE_SHARES") is not None:
+            logger.info("Environment variable TRTLLM_WINDOW_SIZE_SHARES is set")
+            window_size_shares = os.getenv("TRTLLM_WINDOW_SIZE_SHARES").split(
+                ",")
+            window_size_shares = [float(share) for share in window_size_shares]
+            assert len(window_size_shares) == len(
+                window_size_to_layers
+            ), "Number of shares in TRTLLM_WINDOW_SIZE_SHARES must match number of window sizes"
+            assert math.isclose(
+                sum(window_size_shares), 1.0
+            ), "Sum of shares in TRTLLM_WINDOW_SIZE_SHARES must be 1.0"
+        else:
+            logger.info(
+                "Using default allocation of an equal share of memory to each window size"
+            )
+            window_size_shares = [
+                1.0 / len(window_size_to_layers) for _ in window_size_to_layers
+            ]
+
+        logger.info(f"Derived window_size_shares: {window_size_shares}")
+
+        blocks_per_window = {}
+        for window_idx, (window_size, layers) in enumerate(
+                sorted(window_size_to_layers.items())):
+            cache_size_per_token = calculate_cache_size_per_token(layers)
+            cache_size_bytes_per_token = get_size_in_bytes(
+                cache_size_per_token, self.dtype)
+
+            primary_tokens = self._primary_pool_memory_bytes * window_size_shares[
+                window_idx] / cache_size_bytes_per_token
+            secondary_tokens = self._secondary_pool_memory_bytes * window_size_shares[
+                window_idx] / cache_size_bytes_per_token
+
+            if kv_cache_config.max_tokens is not None:
+                if self.is_vswa:
+                    logger.info(
+                        f"kv_cache_config.max_tokens is set ({kv_cache_config.max_tokens}) but we are operating in VSWA mode. Ignoring the configuration."
+                    )
+                if not self.is_vswa:
+                    logger.info(
+                        f"kv_cache_config.max_tokens is {kv_cache_config.max_tokens}"
+                    )
+                    if kv_cache_config.max_tokens < primary_tokens:
+                        logger.info(
+                            f"kv_cache_config.max_tokens {kv_cache_config.max_tokens} is less than primary_tokens {primary_tokens}. Reducing primary_tokens to {kv_cache_config.max_tokens}"
+                        )
+                        primary_tokens = kv_cache_config.max_tokens
+
+            primary_blocks = int(primary_tokens // self.tokens_per_block)
+            secondary_blocks = int(secondary_tokens // self.tokens_per_block)
+            logger.info(
+                f"Window size = {window_size}, primary_blocks: {primary_blocks}, secondary_blocks: {secondary_blocks}"
+            )
+            blocks_per_window[window_size] = (primary_blocks, secondary_blocks)

         return blocks_per_window

     def _validate_and_adjust_attention_windows(
diff --git a/tensorrt_llm/_torch/speculative/interface.py b/tensorrt_llm/_torch/speculative/interface.py
index a155bac5a0..eeee063a7f 100644
--- a/tensorrt_llm/_torch/speculative/interface.py
+++ b/tensorrt_llm/_torch/speculative/interface.py
@@ -588,14 +588,17 @@ class SpecWorkerBase(nn.Module, ABC):
         Returns:
             input_ids_ctx: Prepared context input IDs
         """
-        input_prompt_ids = input_ids[:num_ctx_tokens]
-        input_ids_ctx = torch.empty_like(input_prompt_ids,
-                                         dtype=torch.int32,
-                                         device="cuda")
-        input_ids_ctx[:-1].copy_(input_prompt_ids[1:])
-        input_ids_ctx[
-            gather_ids[:num_contexts]] = accepted_tokens[:num_contexts, 0]
-        return input_ids_ctx
+        if num_ctx_tokens > 0:
+            input_prompt_ids = input_ids[:num_ctx_tokens]
+            input_ids_ctx = torch.empty_like(input_prompt_ids,
+                                             dtype=torch.int32,
+                                             device="cuda")
+            input_ids_ctx[:-1].copy_(input_prompt_ids[1:])
+            input_ids_ctx[
+                gather_ids[:num_contexts]] = accepted_tokens[:num_contexts, 0]
+            return input_ids_ctx
+        else:
+            return torch.empty(0, dtype=torch.int32, device="cuda")

     def _sample_tokens_for_batch(
             self,
diff --git a/tests/unittest/_torch/executor/test_resource_manager.py b/tests/unittest/_torch/executor/test_resource_manager.py
index c31c08fa22..273abca39e 100644
--- a/tests/unittest/_torch/executor/test_resource_manager.py
+++ b/tests/unittest/_torch/executor/test_resource_manager.py
@@ -456,9 +456,11 @@ class TestResourceManager(unittest.TestCase):
         model_config.num_hidden_layers = len(total_layers)
         model_config.num_attention_layers = len(total_layers)
+        model_config.layer_types = [LayerType.ATTENTION
+                                    ] * model_config.num_attention_layers

         kv_factor = 2
-        cache_bytes_per_token_per_layer = 8
+        cache_bytes_per_token_per_layer = 32

         # Define test cases:
         # (memory_bytes, expected_window_sizes, max_tokens, description)
@@ -466,7 +468,7 @@
         test_cases = [
             (
                 # Case 1: Limited memory - windows get clamped
-                cache_bytes_per_token_per_layer * (100 * 9 + 30 * 5) + 4,
+                cache_bytes_per_token_per_layer * (100 * 4 + 130 * 5) + 4,
                 {
                     100: [0, 1, 2, 3],
                     130: [4, 5, 6, 7, 8],
                 },
@@ -477,7 +479,7 @@
             (
                 # Case 2: Less limited memory - the largest window get clamped
                 cache_bytes_per_token_per_layer *
-                (100 * 9 + 100 * 5 + 817 * 2) + 4,
+                (100 * 4 + 200 * 3 + 1017 * 2) + 4,
                 {
                     100: [0, 1, 2, 3],
                     200: [4, 5, 6],
@@ -510,8 +512,7 @@
             (
                 # Case 5: Less limited memory but max_tokens is given.
                 # memory is enough for 1017 tokens, it will be clamped by max_tokens=134.
-                cache_bytes_per_token_per_layer *
-                (100 * 9 + 100 * 5 + 817 * 2) + 4,
+                cache_bytes_per_token_per_layer * (100 * 4 + 134 * 5) + 4,
                 {
                     100: [0, 1, 2, 3],
                     134: [4, 5, 6, 7, 8],
                 },
@@ -523,8 +524,33 @@
         for memory_bytes, expected_window_sizes, expected_max_attention_window_vec, max_tokens, description in test_cases:
             with self.subTest(case=description, memory_bytes=memory_bytes):
-                kv_cache_config = tllm.KvCacheConfig(max_tokens=max_tokens)
-                adjusted, adjusted_max_attention_window_vec = KVCacheManager.adjust_window_sizes_for_vswa(
+                kv_cache_config_params = {
+                    "max_attention_window": max_attention_window_vec,
+                    "free_gpu_memory_fraction": 1.0,
+                    "host_cache_size": 0,
+                    "max_gpu_total_bytes": memory_bytes,
+                }
+                kv_cache_config = TestResourceManager._create_kv_cache_config_for_kv_cache_manager(
+                    kv_cache_config_params)
+                mapping = Mapping(world_size=1, tp_size=1, pp_size=1)
+
+                manager = KVCacheManager(
+                    kv_cache_config=kv_cache_config,
+                    kv_cache_type=tensorrt_llm.bindings.internal.batch_manager.
+                    CacheType.SELF,
+                    num_layers=model_config.num_attention_layers,
+                    num_kv_heads=8,
+                    head_dim=model_config.head_size,
+                    tokens_per_block=32,
+                    max_seq_len=max(max_attention_window_vec),
+                    max_batch_size=1,
+                    mapping=mapping,
+                    dtype=model_config.data_type,
+                    model_config=model_config,
+                    max_beam_width=1,
+                )
+
+                adjusted, adjusted_max_attention_window_vec = manager.adjust_window_sizes_for_vswa(
                     window_size_to_layers=window_size_to_layers,
                     max_attention_window_vec=max_attention_window_vec,
                     model_config=model_config,
@@ -581,7 +607,7 @@ class TestResourceManager(unittest.TestCase):
         """
         return KvCacheConfig(**params)

-    def test_calculate_max_num_blocks_from_cpp(self):
+    def test_calculate_max_num_blocks_for_vswa(self):
         # Construct a minimal mapping (single-rank, no TP/PP)
         mapping = Mapping(world_size=1, tp_size=1, pp_size=1)

@@ -651,7 +677,7 @@ class TestResourceManager(unittest.TestCase):
                 kv_cache_config_params)
             with patch('torch.cuda.mem_get_info',
                        return_value=(fixed_free_mem, fixed_total_mem)):
-                # Create a real KVCacheManager, it will run calculate_max_num_blocks_from_cpp in __init__
+                # Create a real KVCacheManager, it will run calculate_max_num_blocks_for_vswa in __init__
                 manager = KVCacheManager(
                     kv_cache_config=kv_cache_config,
                     kv_cache_type=tensorrt_llm.bindings.internal.
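For reference, the memory budgets in the test table appear to be constructed as the per-sequence cost of the expected (clamped) window layout plus a few spare bytes. The quick arithmetic check below for Case 1 uses the updated constant from this diff; the per-window breakdown is an illustration and is not part of the test file.

```python
# Case 1 from the table: windows [100, 130] over layers 0-3 and 4-8.
cache_bytes_per_token_per_layer = 32
budget = cache_bytes_per_token_per_layer * (100 * 4 + 130 * 5) + 4   # value used in the diff
needed = (cache_bytes_per_token_per_layer * 100 * 4 +
          cache_bytes_per_token_per_layer * 130 * 5)                  # cost of the clamped layout
assert needed <= budget < needed + cache_bytes_per_token_per_layer
print(budget, needed)  # 33604 33600
```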