From d219a4f2251e98f36dc042f475dab7294f447e72 Mon Sep 17 00:00:00 2001
From: Leslie Fang
Date: Wed, 10 Sep 2025 21:14:44 +0800
Subject: [PATCH] [None][chore] remove executor config in kv cache creator
 (#7526)

Signed-off-by: leslie-fang25
---
 tensorrt_llm/_torch/pyexecutor/_util.py       | 119 ++++++++++--------
 .../_torch/pyexecutor/py_executor_creator.py  |  24 +++-
 2 files changed, 86 insertions(+), 57 deletions(-)

diff --git a/tensorrt_llm/_torch/pyexecutor/_util.py b/tensorrt_llm/_torch/pyexecutor/_util.py
index fd8ffd33b2..684158a806 100644
--- a/tensorrt_llm/_torch/pyexecutor/_util.py
+++ b/tensorrt_llm/_torch/pyexecutor/_util.py
@@ -9,7 +9,7 @@ import tensorrt_llm
 import tensorrt_llm.bindings.executor as trtllm
 from tensorrt_llm._torch.model_config import ModelConfig
 from tensorrt_llm._utils import str_dtype_to_binding, torch_dtype_to_str
-from tensorrt_llm.bindings.executor import DecodingMode, ExecutorConfig
+from tensorrt_llm.bindings.executor import DecodingMode
 from tensorrt_llm.llmapi.llm_args import (PeftCacheConfig, SamplerType,
                                           SpeculativeConfig)
 from tensorrt_llm.logger import logger
@@ -43,19 +43,38 @@ GB = 1 << 30
 class KvCacheCreator:
     """Groups together logic related to KV cache construction."""

-    def __init__(self, *, executor_config: ExecutorConfig,
-                 model_engine: PyTorchModelEngine,
-                 draft_model_engine: Optional[PyTorchModelEngine],
-                 mapping: Mapping, net_max_seq_len: int,
-                 kv_connector_manager: Optional[KvCacheConnectorManager]):
-        self._executor_config = executor_config
+    def __init__(
+        self,
+        *,
+        model_engine: PyTorchModelEngine,
+        draft_model_engine: Optional[PyTorchModelEngine],
+        mapping: Mapping,
+        net_max_seq_len: int,
+        kv_connector_manager: Optional[KvCacheConnectorManager],
+        max_num_tokens: int,
+        max_beam_width: int,
+        tokens_per_block: int,
+        max_seq_len: int,
+        max_batch_size: int,
+        kv_cache_config: trtllm.KvCacheConfig,
+        pytorch_backend_config: PyTorchConfig,
+        speculative_config: SpeculativeConfig,
+    ):
         self._model_engine = model_engine
         self._draft_model_engine = draft_model_engine
         self._mapping = mapping
-        self._max_kv_tokens_in = self._executor_config.kv_cache_config.max_tokens
+        self._kv_cache_config = kv_cache_config
+        self._max_kv_tokens_in = self._kv_cache_config.max_tokens
+        self._max_num_tokens = max_num_tokens
+        self._max_beam_width = max_beam_width
         self._dummy_reqs = self._create_dummy_context_requests(net_max_seq_len -
                                                                1)
         self._kv_connector_manager = kv_connector_manager
+        self._pytorch_backend_config = pytorch_backend_config
+        self._speculative_config = speculative_config
+        self._tokens_per_block = tokens_per_block
+        self._max_seq_len = max_seq_len
+        self._max_batch_size = max_batch_size

     @staticmethod
     def _get_cache_size_per_token(model_config: ModelConfig,
@@ -97,7 +116,7 @@ class KvCacheCreator:
         return mem_per_token

     def _get_free_gpu_memory_fraction(self) -> float:
-        fraction = self._executor_config.kv_cache_config.free_gpu_memory_fraction
+        fraction = self._kv_cache_config.free_gpu_memory_fraction
         if fraction is None:
             fraction = 0.9
         return fraction
@@ -134,8 +153,8 @@ class KvCacheCreator:
     def _create_dummy_context_requests(
             self, input_seq_len: int) -> List[trtllm.Request]:
         vocab_size = self._model_engine.model.model_config.pretrained_config.vocab_size
-        max_num_tokens = self._executor_config.max_num_tokens
-        max_beam_width = self._executor_config.max_beam_width
+        max_num_tokens = self._max_num_tokens
+        max_beam_width = self._max_beam_width
         requests = []

         input_seq_len = min(max_num_tokens, input_seq_len)
@@ -160,7 +179,6 @@ class KvCacheCreator:

     def _get_token_num_for_estimation(self) -> int:
         """Compute KV cache capacity required for estimate_max_kv_cache_tokens to succeed."""
-        executor_config = self._executor_config
         if 'cp_type' in self._mapping.cp_config:
             raise ValueError(
                 "KV cache size estimation not supported with context parallelism."
@@ -168,8 +186,8 @@ class KvCacheCreator:
         # estimate_max_kv_cache_tokens submits self._dummy_reqs
         num_cache_blocks = 0
         num_extra_tokens_per_seq = 1  # account for generated tokens
-        pytorch_backend_config = executor_config.pytorch_backend_config
-        spec_cfg = executor_config.speculative_config
+        pytorch_backend_config = self._pytorch_backend_config
+        spec_cfg = self._speculative_config
         if not pytorch_backend_config.disable_overlap_scheduler:
             num_extra_tokens_per_seq = num_extra_tokens_per_seq + 1
         if spec_cfg is not None:
@@ -181,11 +199,10 @@ class KvCacheCreator:
         for req in self._dummy_reqs:
             num_req_tokens = len(req.input_token_ids) + num_extra_tokens_per_seq
             # Requests cannot share KV cache blocks. Round up to nearest integer multiple of block size.
-            num_cache_blocks += (num_req_tokens +
-                                 executor_config.tokens_per_block -
-                                 1) // executor_config.tokens_per_block
+            num_cache_blocks += (num_req_tokens + self._tokens_per_block -
+                                 1) // self._tokens_per_block
         # Multiply by beam width, to prevent rescaling of the max_seq_len caused by the influence of beam width during the preparation for kv_cache_estimation
-        return num_cache_blocks * executor_config.tokens_per_block * self._dummy_reqs[
+        return num_cache_blocks * self._tokens_per_block * self._dummy_reqs[
             0].sampling_config.beam_width

     def try_prepare_estimation(self) -> bool:
@@ -197,7 +214,7 @@ class KvCacheCreator:
         estimating_kv_cache = False
         if 'cp_type' not in self._mapping.cp_config:
             estimating_kv_cache = True
-            self._executor_config.kv_cache_config.max_tokens = self._get_token_num_for_estimation(
+            self._kv_cache_config.max_tokens = self._get_token_num_for_estimation(
             )
         return estimating_kv_cache

@@ -207,7 +224,6 @@ class KvCacheCreator:

         This updates `kv_cache_config`.
         """
-        executor_config = self._executor_config
         mapping = self._mapping

         # TODO: support CP by generating dummy requests for it.
@@ -286,7 +302,7 @@ class KvCacheCreator:
                                                total_gpu_memory,
                                                fraction,
                                                allocated_bytes)
-        max_attention_window = executor_config.kv_cache_config.max_attention_window
+        max_attention_window = self._kv_cache_config.max_attention_window
         is_vswa = max_attention_window and len(set(max_attention_window)) > 1

         # NOTE:
@@ -314,41 +330,39 @@ class KvCacheCreator:
             )
             if is_vswa:
                 # For VSWA KvCacheManager now it can only use max_gpu_total_bytes
-                executor_config.kv_cache_config.max_tokens = None
+                self._kv_cache_config.max_tokens = None
             else:
                 # For non-VSWA KvCacheManager, its logic still relies on max_tokens, need to improve in the future.
-                executor_config.kv_cache_config.max_tokens = int(
+                self._kv_cache_config.max_tokens = int(
                     kv_cache_max_memory // self._get_kv_size_per_token())
         # ---------------------------handle max_tokens---------------------------------

         # ---------------------------handle max_gpu_total_bytes---------------------------------
         # if user provided max_gpu_total_bytes, set max memory from max_gpu_total_bytes
-        if executor_config.kv_cache_config.max_gpu_total_bytes > 0:
-            kv_cache_max_memory = min(
-                kv_cache_max_memory,
-                executor_config.kv_cache_config.max_gpu_total_bytes)
+        if self._kv_cache_config.max_gpu_total_bytes > 0:
+            kv_cache_max_memory = min(kv_cache_max_memory,
+                                      self._kv_cache_config.max_gpu_total_bytes)
             logger.info(
-                f"max_gpu_total_bytes={executor_config.kv_cache_config.max_gpu_total_bytes / (GB):.2f} GiB is provided, max_memory is set to {kv_cache_max_memory / (GB):.2f} GiB"
+                f"max_gpu_total_bytes={self._kv_cache_config.max_gpu_total_bytes / (GB):.2f} GiB is provided, max_memory is set to {kv_cache_max_memory / (GB):.2f} GiB"
             )
         logger.info(
             f"Estimated max memory in KV cache : {kv_cache_max_memory / (GB):.2f} GiB"
         )
         # set max_gpu_total_bytes
-        executor_config.kv_cache_config.max_gpu_total_bytes = kv_cache_max_memory
+        self._kv_cache_config.max_gpu_total_bytes = kv_cache_max_memory
         # ---------------------------handle max_gpu_total_bytes---------------------------------

     def _create_kv_cache_manager(
             self,
             model_engine: PyTorchModelEngine,
             estimating_kv_cache: bool = False) -> KVCacheManager:
-        executor_config = self._executor_config
         mapping = self._mapping

         assert model_engine.model.model_config.is_generation, "Only construct KV cache for generation models."

         config = model_engine.model.model_config.pretrained_config
         quant_config = model_engine.model.model_config.quant_config
-        spec_config = executor_config.speculative_config
+        spec_config = self._speculative_config

         hidden_size = config.hidden_size
         num_attention_heads = config.num_attention_heads
@@ -372,25 +386,25 @@ class KvCacheCreator:

         if is_mla(config):
             kv_cache_manager = KVCacheManager(
-                executor_config.kv_cache_config,
+                self._kv_cache_config,
                 tensorrt_llm.bindings.internal.batch_manager.CacheType.
                 SELFKONLY,
                 num_layers=num_hidden_layers,
                 num_kv_heads=1,
                 head_dim=config.kv_lora_rank + config.qk_rope_head_dim,
-                tokens_per_block=executor_config.tokens_per_block,
-                max_seq_len=executor_config.max_seq_len,
-                max_batch_size=executor_config.max_batch_size,
+                tokens_per_block=self._tokens_per_block,
+                max_seq_len=self._max_seq_len,
+                max_batch_size=self._max_batch_size,
                 mapping=mapping,
                 dtype=kv_cache_dtype,
                 spec_config=spec_config,
-                max_beam_width=executor_config.max_beam_width,
+                max_beam_width=self._max_beam_width,
                 is_draft=model_engine.is_draft_model,
                 kv_connector_manager=self._kv_connector_manager
                 if not estimating_kv_cache else None,
             )
         elif is_nemotron_hybrid(config):
-            if executor_config.max_beam_width > 1:
+            if self._max_beam_width > 1:
                 raise ValueError(
                     "MambaHybridCacheManager + beam search is not supported yet."
                 )
@@ -423,49 +437,48 @@ class KvCacheCreator:
                 model_engine.model.model_config.quant_config.
                 mamba_ssm_cache_dtype,
                 # kv cache parameters
-                executor_config.kv_cache_config,
+                self._kv_cache_config,
                 tensorrt_llm.bindings.internal.batch_manager.CacheType.SELF,
                 num_layers=num_layers,
                 layer_mask=layer_mask,
                 num_kv_heads=num_key_value_heads,
                 head_dim=head_dim,
-                tokens_per_block=executor_config.tokens_per_block,
-                max_seq_len=executor_config.max_seq_len,
-                max_batch_size=executor_config.max_batch_size,
+                tokens_per_block=self._tokens_per_block,
+                max_seq_len=self._max_seq_len,
+                max_batch_size=self._max_batch_size,
                 mapping=mapping,
                 dtype=kv_cache_dtype,
                 spec_config=spec_config,
             )
         else:
             # NOTE: this is a workaround for VSWA to switch to calculate_max_num_blocks_from_cpp in KVCahceManager
-            is_vswa = executor_config.kv_cache_config.max_attention_window is not None and len(
-                set(executor_config.kv_cache_config.max_attention_window)) > 1
+            is_vswa = self._kv_cache_config.max_attention_window is not None and len(
+                set(self._kv_cache_config.max_attention_window)) > 1
             binding_model_config = model_engine.model.model_config.get_bindings_model_config(
-                tokens_per_block=executor_config.tokens_per_block
-            ) if is_vswa else None
+                tokens_per_block=self._tokens_per_block) if is_vswa else None

             kv_cache_manager = KVCacheManager(
-                executor_config.kv_cache_config,
+                self._kv_cache_config,
                 tensorrt_llm.bindings.internal.batch_manager.CacheType.SELF,
                 num_layers=num_hidden_layers,
                 num_kv_heads=num_key_value_heads,
                 head_dim=head_dim,
-                tokens_per_block=executor_config.tokens_per_block,
-                max_seq_len=executor_config.max_seq_len,
-                max_batch_size=executor_config.max_batch_size,
+                tokens_per_block=self._tokens_per_block,
+                max_seq_len=self._max_seq_len,
+                max_batch_size=self._max_batch_size,
                 mapping=mapping,
                 dtype=kv_cache_dtype,
                 spec_config=spec_config,
-                max_num_tokens=executor_config.max_num_tokens,
+                max_num_tokens=self._max_num_tokens,
                 model_config=binding_model_config,
-                max_beam_width=executor_config.max_beam_width,
+                max_beam_width=self._max_beam_width,
                 is_draft=model_engine.is_draft_model,
                 kv_connector_manager=self._kv_connector_manager
                 if not estimating_kv_cache else None,
             )
-        # KVCacheManager (Non-draft) modifies the max_seq_len field, update it to executor_config
+        # The non-draft KVCacheManager modifies the max_seq_len field, so keep self._max_seq_len in sync
         if model_engine.kv_cache_manager_key == ResourceManagerType.KV_CACHE_MANAGER:
-            executor_config.max_seq_len = kv_cache_manager.max_seq_len
+            self._max_seq_len = kv_cache_manager.max_seq_len

         return kv_cache_manager

diff --git a/tensorrt_llm/_torch/pyexecutor/py_executor_creator.py b/tensorrt_llm/_torch/pyexecutor/py_executor_creator.py
index 9cc3b3fb5e..cb045f5379 100644
--- a/tensorrt_llm/_torch/pyexecutor/py_executor_creator.py
+++ b/tensorrt_llm/_torch/pyexecutor/py_executor_creator.py
@@ -493,17 +493,29 @@ def create_py_executor(
     if model_engine.model.model_config.is_generation:
         #NOTE: non-generation models do not have kv cache
         kv_cache_creator = KvCacheCreator(
-            executor_config=executor_config,
             model_engine=model_engine,
             draft_model_engine=draft_model_engine,
             mapping=mapping,
             net_max_seq_len=net_max_seq_len,
-            kv_connector_manager=kv_connector_manager)
+            kv_connector_manager=kv_connector_manager,
+            max_num_tokens=executor_config.max_num_tokens,
+            max_beam_width=executor_config.max_beam_width,
+            tokens_per_block=executor_config.tokens_per_block,
+            max_seq_len=executor_config.max_seq_len,
+            max_batch_size=executor_config.max_batch_size,
+            kv_cache_config=executor_config.kv_cache_config,
+            pytorch_backend_config=executor_config.pytorch_backend_config,
+            speculative_config=executor_config.speculative_config,
+        )
         estimating_kv_cache = kv_cache_creator.try_prepare_estimation()
         with mem_monitor.observe_creation_stage(
                 _ExecutorCreationStage.INIT_KV_CACHE
                 if estimating_kv_cache else _ExecutorCreationStage.KV_CACHE):
             kv_cache_creator.build_managers(resources, estimating_kv_cache)
+            # build_managers may adjust max_seq_len, which create_py_executor_instance
+            # uses below. That adjustment now lands on kv_cache_creator._max_seq_len,
+            # so copy it back to executor_config.max_seq_len.
+            executor_config.max_seq_len = kv_cache_creator._max_seq_len
             update_sampler_max_seq_len(executor_config.max_seq_len, sampler)

     # Resource managers for speculative decoding
@@ -564,10 +576,14 @@ def create_py_executor(
             with mem_monitor.observe_creation_stage(
                     _ExecutorCreationStage.KV_CACHE):
                 # Before estimating KV cache size, a minimal KV cache has been allocated using
-                # create_kv_cache_manager above, which caps executor_config.max_seq_len. Restoring
+                # create_kv_cache_manager above, which caps kv_cache_creator._max_seq_len. Restoring
                 # the original value before creating the final KV cache.
-                executor_config.max_seq_len = max_seq_len
+                kv_cache_creator._max_seq_len = max_seq_len
                 kv_cache_creator.build_managers(resources, False)
+                # build_managers may change max_seq_len again. It now updates
+                # kv_cache_creator._max_seq_len rather than executor_config.max_seq_len,
+                # so copy the value back for the sampler and for create_py_executor_instance.
+                executor_config.max_seq_len = kv_cache_creator._max_seq_len
                 update_sampler_max_seq_len(executor_config.max_seq_len, sampler)

     for eng in [model_engine, draft_model_engine]:
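
Reviewer note (illustration only, not part of the patch): the block-rounding arithmetic in _get_token_num_for_estimation, now driven by the explicitly passed tokens_per_block instead of executor_config, can be sanity-checked with a small standalone sketch. The block size, beam width, request lengths, and extra-token count below are made-up example values, not taken from this change.

    # Standalone sketch of the rounding performed in _get_token_num_for_estimation.
    # All numbers here are hypothetical; the real code takes them from the dummy
    # requests, the PyTorch backend config, and the speculative config.
    tokens_per_block = 32
    beam_width = 2
    num_extra_tokens_per_seq = 2    # e.g. one generated token + one for the overlap scheduler
    dummy_request_lens = [100, 37]  # lengths of the dummy requests' input_token_ids

    num_cache_blocks = 0
    for req_len in dummy_request_lens:
        num_req_tokens = req_len + num_extra_tokens_per_seq
        # Requests cannot share blocks, so each request rounds up to whole blocks.
        num_cache_blocks += (num_req_tokens + tokens_per_block - 1) // tokens_per_block

    # 102 tokens -> 4 blocks, 39 tokens -> 2 blocks; 6 blocks * 32 tokens * beam 2 = 384.
    max_tokens_for_estimation = num_cache_blocks * tokens_per_block * beam_width
    print(max_tokens_for_estimation)  # 384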