[TRTLLM-8477][chore] Replace KvCacheConfigCpp with KvCacheConfig inside PyExecutor (#8259)

Signed-off-by: leslie-fang25 <leslief@nvidia.com>
Leslie Fang 2025-10-13 14:55:36 +08:00 committed by GitHub
parent 1a9044949f
commit 8d1b068b1a
6 changed files with 23 additions and 44 deletions
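
The overall pattern of the change: code inside PyExecutor now carries the Python-level `KvCacheConfig` from `tensorrt_llm.llmapi.llm_args` instead of the pybind `KvCacheConfigCpp`, and the conversion to the C++ type is deferred to the point where a C++ binding is actually invoked, via `PybindMirror.maybe_to_pybind`. A minimal sketch of that boundary pattern (field values below are illustrative, not taken from this diff):

```python
from tensorrt_llm.llmapi.llm_args import KvCacheConfig, PybindMirror

# Python-level config object threaded through PyExecutor after this change.
kv_cache_config = KvCacheConfig(free_gpu_memory_fraction=0.9,
                                enable_block_reuse=True)

# Conversion to the pybind KvCacheConfig is deferred to the C++ boundary,
# e.g. where KVCacheManagerCpp.calculate_max_num_blocks is called.
kv_cache_config_cpp = PybindMirror.maybe_to_pybind(kv_cache_config)
```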

View File

@ -10,7 +10,7 @@ import tensorrt_llm.bindings.executor as trtllm
from tensorrt_llm._torch.model_config import ModelConfig
from tensorrt_llm._utils import str_dtype_to_binding, torch_dtype_to_str
from tensorrt_llm.bindings.executor import DecodingMode
-from tensorrt_llm.llmapi.llm_args import (EagleDecodingConfig,
+from tensorrt_llm.llmapi.llm_args import (EagleDecodingConfig, KvCacheConfig,
MTPDecodingConfig, PeftCacheConfig,
SamplerType, SpeculativeConfig,
TorchLlmArgs)
@ -58,7 +58,7 @@ class KvCacheCreator:
tokens_per_block: int,
max_seq_len: int,
max_batch_size: int,
-kv_cache_config: trtllm.KvCacheConfig,
+kv_cache_config: KvCacheConfig,
pytorch_backend_config: PyTorchConfig,
speculative_config: SpeculativeConfig,
):
@ -790,7 +790,7 @@ def instantiate_sampler(engine: PyTorchModelEngine,
max_seq_len: int, mm_encoder_only: bool,
speculative_config: SpeculativeConfig,
decoding_config: trtllm.DecodingConfig,
-kv_cache_config: trtllm.KvCacheConfig):
+kv_cache_config: KvCacheConfig):
sampler_args = create_torch_sampler_args(
mapping,
max_seq_len=engine.max_seq_len,

View File

@ -19,9 +19,9 @@ import torch
from tensorrt_llm._torch.pyexecutor.llm_request import LlmRequest
from tensorrt_llm._torch.pyexecutor.resource_manager import (
-BaseResourceManager, CacheTypeCpp, DataType, KvCacheConfigCpp,
-KVCacheManager, get_pp_layers)
+BaseResourceManager, CacheTypeCpp, DataType, KVCacheManager, get_pp_layers)
from tensorrt_llm._torch.pyexecutor.scheduler import ScheduledRequests
+from tensorrt_llm.llmapi.llm_args import KvCacheConfig
from tensorrt_llm.mapping import Mapping
@ -180,7 +180,7 @@ class MambaHybridCacheManager(KVCacheManager, MambaCacheManager):
mamba_ssm_cache_dtype: torch.dtype,
# kv cache parameters
-kv_cache_config: KvCacheConfigCpp,
+kv_cache_config: KvCacheConfig,
kv_cache_type: CacheTypeCpp,
*,
num_layers: int,

View File

@ -223,7 +223,7 @@ def create_py_executor(
llm_args.peft_cache_config)
assert llm_args.kv_cache_config, "Expect llm_args.kv_cache_config is not None"
-kv_cache_config = PybindMirror.maybe_to_pybind(llm_args.kv_cache_config)
+kv_cache_config = llm_args.kv_cache_config
if os.getenv("FORCE_DETERMINISTIC", "0") == "1":
# Disable KV cache reuse for deterministic mode
kv_cache_config.enable_block_reuse = False
@ -251,7 +251,7 @@ def create_py_executor(
if max_num_tokens is None:
max_num_tokens = 8192
-tokens_per_block = llm_args.kv_cache_config.tokens_per_block
+tokens_per_block = kv_cache_config.tokens_per_block
if pytorch_backend_config.attn_backend in [
"FLASHINFER", "FLASHINFER_STAR_ATTENTION"

View File

@ -11,6 +11,7 @@ import tensorrt_llm
import tensorrt_llm.bindings
from tensorrt_llm._utils import mpi_disabled
from tensorrt_llm.bindings.BuildInfo import ENABLE_MULTI_DEVICE
+from tensorrt_llm.llmapi.llm_args import KvCacheConfig, PybindMirror
from tensorrt_llm.lora_helper import LoraConfig
from tensorrt_llm.lora_manager import LoraManager, LoraModelConfig
from tensorrt_llm.runtime import ModelConfig as ModelConfigPython
@ -31,7 +32,6 @@ if ENABLE_MULTI_DEVICE:
BufferManagerCpp = tensorrt_llm.bindings.internal.runtime.BufferManager
KVCacheManagerCpp = tensorrt_llm.bindings.internal.batch_manager.KVCacheManager
-KvCacheConfigCpp = tensorrt_llm.bindings.executor.KvCacheConfig
CacheTypeCpp = tensorrt_llm.bindings.internal.batch_manager.CacheType
ModelConfigCpp = tensorrt_llm.bindings.ModelConfig
DataType = tensorrt_llm.bindings.DataType
@ -145,7 +145,7 @@ class KVCacheManager(BaseResourceManager):
def __init__(
self,
-kv_cache_config: KvCacheConfigCpp,
+kv_cache_config: KvCacheConfig,
kv_cache_type: CacheTypeCpp,
*,
num_layers: int,
@ -268,8 +268,8 @@ class KVCacheManager(BaseResourceManager):
)
# kv cache config check
assert isinstance(
-kv_cache_config, KvCacheConfigCpp
-), "calculate_max_num_blocks_from_cpp only accepts KvCacheConfigCpp"
+kv_cache_config, KvCacheConfig
+), "calculate_max_num_blocks_from_cpp only accepts KvCacheConfig"
blocks_per_window = self.calculate_max_num_blocks_from_cpp(
kv_cache_config=kv_cache_config,
model_config=model_config,
@ -370,28 +370,6 @@ class KVCacheManager(BaseResourceManager):
def shutdown(self):
self.impl.release_pools()
-@classmethod
-def from_model_config(cls,
-model_config: ModelConfigCpp,
-kv_cache_config: KvCacheConfigCpp,
-mapping: Mapping,
-kv_cache_type: CacheTypeCpp = CacheTypeCpp.SELF,
-dtype: DataType = DataType.HALF) -> "KVCacheManager":
-return cls(
-kv_cache_config,
-kv_cache_type,
-num_layers=model_config.num_attention_layers(mapping.pp_size),
-# NOTE: this preserves existing behavior in KV cache manager.
-# But we should change this to pass a list at some point.
-# We're assuming the KV cache is homogeneous here.
-num_kv_heads=model_config.num_kv_heads(0),
-head_dim=model_config.size_per_head,
-tokens_per_block=model_config.tokens_per_block,
-max_seq_len=model_config.max_seq_len,
-max_batch_size=model_config.max_batch_size,
-mapping=mapping,
-dtype=dtype)
def get_max_resource_count(self) -> int:
return self.impl.max_num_blocks
@ -566,7 +544,7 @@ class KVCacheManager(BaseResourceManager):
scaling_factor_dtype)
def calculate_max_num_blocks(self,
-kv_cache_config: KvCacheConfigCpp,
+kv_cache_config: KvCacheConfig,
head_dim: int,
tokens_per_block: int,
mapping: Mapping,
@ -772,7 +750,7 @@ class KVCacheManager(BaseResourceManager):
def adjust_window_sizes_for_vswa(
window_size_to_layers: Dict[int, List[int]],
max_attention_window_vec: List[int],
-kv_cache_config: KvCacheConfigCpp,
+kv_cache_config: KvCacheConfig,
model_config: ModelConfigCpp,
pool_memory_bytes: int,
kv_factor: int,
@ -887,7 +865,7 @@ class KVCacheManager(BaseResourceManager):
def calculate_max_num_blocks_from_cpp(
self,
-kv_cache_config: KvCacheConfigCpp,
+kv_cache_config: KvCacheConfig,
model_config: ModelConfigCpp,
extra_cost_memory: int = 0) -> dict[int, tuple[int, int]]:
"""
@ -945,7 +923,7 @@ class KVCacheManager(BaseResourceManager):
self.max_attention_window_vec = max_attention_window_vec
blocks_per_window = KVCacheManagerCpp.calculate_max_num_blocks(
-config=kv_cache_config,
+config=PybindMirror.maybe_to_pybind(kv_cache_config),
# TODO: support cross attention
is_cross_attention=is_cross_attention,
dtype=self.dtype,
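
KVCacheManager now asserts that it received the Python-level KvCacheConfig and converts it with PybindMirror.maybe_to_pybind only when handing it to the C++ KVCacheManagerCpp.calculate_max_num_blocks call. A condensed, illustrative sketch of that boundary (the helper function below is not part of the codebase):

```python
from tensorrt_llm.llmapi.llm_args import KvCacheConfig, PybindMirror


def to_cpp_kv_cache_config(kv_cache_config: KvCacheConfig):
    """Illustrative helper: validate the Python-level config, then convert it
    to the pybind type right before a C++ binding consumes it."""
    assert isinstance(
        kv_cache_config, KvCacheConfig
    ), "calculate_max_num_blocks_from_cpp only accepts KvCacheConfig"
    return PybindMirror.maybe_to_pybind(kv_cache_config)


cpp_config = to_cpp_kv_cache_config(KvCacheConfig(free_gpu_memory_fraction=0.9))
```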

View File

@ -17,7 +17,7 @@ from tensorrt_llm._utils import mpi_disabled, nvtx_range, torch_dtype_to_binding
from tensorrt_llm.bindings import (CudaStream, DataType, ModelConfig,
WorldConfig, make_sampling_config)
from tensorrt_llm.bindings.executor import (DecodingConfig, DecodingMode,
-FinishReason, KvCacheConfig)
+FinishReason)
from tensorrt_llm.bindings.internal.algorithms import CreateNewDecoderRequests
from tensorrt_llm.bindings.internal.batch_manager import (
DecoderInputBuffers, add_new_tokens_to_requests, make_decoding_batch_input)
@ -25,6 +25,7 @@ from tensorrt_llm.bindings.internal.runtime import (BufferManager, CudaEvent,
DecoderState,
GptDecoderBatched)
from tensorrt_llm.executor.result import Logprob
+from tensorrt_llm.llmapi.llm_args import KvCacheConfig
from tensorrt_llm.mapping import Mapping
from tensorrt_llm.sampling_params import SamplingParams
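
Worth noting when reading this hunk: two distinct classes share the name KvCacheConfig. The import swap moves sampler.py from the pybind type in tensorrt_llm.bindings.executor to the Python-level llmapi type; a tiny sketch of the distinction:

```python
# Same name, different types; PyExecutor code now annotates with the latter.
from tensorrt_llm.bindings.executor import KvCacheConfig as KvCacheConfigCpp  # pybind (C++) class
from tensorrt_llm.llmapi.llm_args import KvCacheConfig  # Python-level llmapi class

assert KvCacheConfig is not KvCacheConfigCpp
```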

View File

@ -20,6 +20,7 @@ from tensorrt_llm.bindings import ModelConfig as ModelConfigCpp
from tensorrt_llm.bindings import executor as tllm
from tensorrt_llm.bindings.internal.batch_manager import \
PeftTaskNotCachedException
+from tensorrt_llm.llmapi.llm_args import KvCacheConfig
from tensorrt_llm.lora_helper import LoraConfig
from tensorrt_llm.mapping import Mapping
@ -574,11 +575,11 @@ class TestResourceManager(unittest.TestCase):
@staticmethod
def _create_kv_cache_config_for_kv_cache_manager(
-params: dict) -> tllm.KvCacheConfig:
+params: dict) -> KvCacheConfig:
"""
Create a KV cache config for KVCacheManager test.
"""
-return tllm.KvCacheConfig(**params)
+return KvCacheConfig(**params)
def test_calculate_max_num_blocks_from_cpp(self):
# Construct a minimal mapping (single-rank, no TP/PP)
@ -633,9 +634,8 @@ class TestResourceManager(unittest.TestCase):
"free_gpu_memory_fraction": free_gpu_memory_fraction,
"enable_block_reuse": enable_block_reuse,
},
-# NOTE: use np.float32 to avoid float precision issue between python(double in most cases) and cpp binding(float)
-expected_memory_bytes=(int(
-fixed_free_mem * np.float32(free_gpu_memory_fraction)), 0),
+expected_memory_bytes=(int(fixed_free_mem *
+free_gpu_memory_fraction), 0),
),
]
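
The test helper now builds the llmapi KvCacheConfig straight from a params dict, and the expected memory is computed with the Python double-precision fraction since the np.float32 workaround is dropped. A hedged usage sketch with illustrative values:

```python
from tensorrt_llm.llmapi.llm_args import KvCacheConfig

# Illustrative params; the real test derives these from the case under test.
params = {
    "free_gpu_memory_fraction": 0.5,
    "enable_block_reuse": False,
    "tokens_per_block": 32,
}
kv_cache_config = KvCacheConfig(**params)

fixed_free_mem = 8 * 1024**3  # pretend 8 GiB of free GPU memory
expected_memory_bytes = int(fixed_free_mem *
                            kv_cache_config.free_gpu_memory_fraction)
```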