Mirror of https://github.com/NVIDIA/TensorRT-LLM.git
[None][chore] remove executor config in kv cache creator (#7526)
Signed-off-by: leslie-fang25 <leslief@nvidia.com>
Commit: d219a4f225
Parent: a4312ba743
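In short, this change drops the ExecutorConfig dependency from KvCacheCreator: instead of holding the whole executor_config, the creator now receives the individual values it actually uses (max_num_tokens, max_beam_width, tokens_per_block, max_seq_len, max_batch_size, kv_cache_config, pytorch_backend_config, speculative_config) as explicit keyword arguments, and create_py_executor forwards them from executor_config. A rough sketch of the new construction, mirroring the updated call site shown in the diff below (illustrative, not verbatim):

    # Sketch of the post-patch call site in create_py_executor (illustrative).
    kv_cache_creator = KvCacheCreator(
        model_engine=model_engine,
        draft_model_engine=draft_model_engine,
        mapping=mapping,
        net_max_seq_len=net_max_seq_len,
        kv_connector_manager=kv_connector_manager,
        max_num_tokens=executor_config.max_num_tokens,
        max_beam_width=executor_config.max_beam_width,
        tokens_per_block=executor_config.tokens_per_block,
        max_seq_len=executor_config.max_seq_len,
        max_batch_size=executor_config.max_batch_size,
        kv_cache_config=executor_config.kv_cache_config,
        pytorch_backend_config=executor_config.pytorch_backend_config,
        speculative_config=executor_config.speculative_config,
    )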
@@ -9,7 +9,7 @@ import tensorrt_llm
 import tensorrt_llm.bindings.executor as trtllm
 from tensorrt_llm._torch.model_config import ModelConfig
 from tensorrt_llm._utils import str_dtype_to_binding, torch_dtype_to_str
-from tensorrt_llm.bindings.executor import DecodingMode, ExecutorConfig
+from tensorrt_llm.bindings.executor import DecodingMode
 from tensorrt_llm.llmapi.llm_args import (PeftCacheConfig, SamplerType,
                                           SpeculativeConfig)
 from tensorrt_llm.logger import logger
@@ -43,19 +43,38 @@ GB = 1 << 30
 class KvCacheCreator:
     """Groups together logic related to KV cache construction."""

-    def __init__(self, *, executor_config: ExecutorConfig,
-                 model_engine: PyTorchModelEngine,
-                 draft_model_engine: Optional[PyTorchModelEngine],
-                 mapping: Mapping, net_max_seq_len: int,
-                 kv_connector_manager: Optional[KvCacheConnectorManager]):
-        self._executor_config = executor_config
+    def __init__(
+        self,
+        *,
+        model_engine: PyTorchModelEngine,
+        draft_model_engine: Optional[PyTorchModelEngine],
+        mapping: Mapping,
+        net_max_seq_len: int,
+        kv_connector_manager: Optional[KvCacheConnectorManager],
+        max_num_tokens: int,
+        max_beam_width: int,
+        tokens_per_block: int,
+        max_seq_len: int,
+        max_batch_size: int,
+        kv_cache_config: trtllm.KvCacheConfig,
+        pytorch_backend_config: PyTorchConfig,
+        speculative_config: SpeculativeConfig,
+    ):
         self._model_engine = model_engine
         self._draft_model_engine = draft_model_engine
         self._mapping = mapping
-        self._max_kv_tokens_in = self._executor_config.kv_cache_config.max_tokens
+        self._kv_cache_config = kv_cache_config
+        self._max_kv_tokens_in = self._kv_cache_config.max_tokens
+        self._max_num_tokens = max_num_tokens
+        self._max_beam_width = max_beam_width
         self._dummy_reqs = self._create_dummy_context_requests(net_max_seq_len -
                                                                1)
         self._kv_connector_manager = kv_connector_manager
+        self._pytorch_backend_config = pytorch_backend_config
+        self._speculative_config = speculative_config
+        self._tokens_per_block = tokens_per_block
+        self._max_seq_len = max_seq_len
+        self._max_batch_size = max_batch_size

     @staticmethod
     def _get_cache_size_per_token(model_config: ModelConfig,
@@ -97,7 +116,7 @@ class KvCacheCreator:
         return mem_per_token

     def _get_free_gpu_memory_fraction(self) -> float:
-        fraction = self._executor_config.kv_cache_config.free_gpu_memory_fraction
+        fraction = self._kv_cache_config.free_gpu_memory_fraction
         if fraction is None:
             fraction = 0.9
         return fraction
@@ -134,8 +153,8 @@ class KvCacheCreator:
     def _create_dummy_context_requests(
             self, input_seq_len: int) -> List[trtllm.Request]:
         vocab_size = self._model_engine.model.model_config.pretrained_config.vocab_size
-        max_num_tokens = self._executor_config.max_num_tokens
-        max_beam_width = self._executor_config.max_beam_width
+        max_num_tokens = self._max_num_tokens
+        max_beam_width = self._max_beam_width

         requests = []
         input_seq_len = min(max_num_tokens, input_seq_len)
@@ -160,7 +179,6 @@ class KvCacheCreator:

     def _get_token_num_for_estimation(self) -> int:
         """Compute KV cache capacity required for estimate_max_kv_cache_tokens to succeed."""
-        executor_config = self._executor_config
         if 'cp_type' in self._mapping.cp_config:
             raise ValueError(
                 "KV cache size estimation not supported with context parallelism."
@@ -168,8 +186,8 @@ class KvCacheCreator:
         # estimate_max_kv_cache_tokens submits self._dummy_reqs
         num_cache_blocks = 0
         num_extra_tokens_per_seq = 1  # account for generated tokens
-        pytorch_backend_config = executor_config.pytorch_backend_config
-        spec_cfg = executor_config.speculative_config
+        pytorch_backend_config = self._pytorch_backend_config
+        spec_cfg = self._speculative_config
         if not pytorch_backend_config.disable_overlap_scheduler:
             num_extra_tokens_per_seq = num_extra_tokens_per_seq + 1
         if spec_cfg is not None:
@@ -181,11 +199,10 @@ class KvCacheCreator:
         for req in self._dummy_reqs:
             num_req_tokens = len(req.input_token_ids) + num_extra_tokens_per_seq
             # Requests cannot share KV cache blocks. Round up to nearest integer multiple of block size.
-            num_cache_blocks += (num_req_tokens +
-                                 executor_config.tokens_per_block -
-                                 1) // executor_config.tokens_per_block
+            num_cache_blocks += (num_req_tokens + self._tokens_per_block -
+                                 1) // self._tokens_per_block
         # Multiply by beam width, to prevent rescaling of the max_seq_len caused by the influence of beam width during the preparation for kv_cache_estimation
-        return num_cache_blocks * executor_config.tokens_per_block * self._dummy_reqs[
+        return num_cache_blocks * self._tokens_per_block * self._dummy_reqs[
             0].sampling_config.beam_width

     def try_prepare_estimation(self) -> bool:
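For intuition, the estimation above rounds every dummy request up to whole KV cache blocks and then scales by beam width. A small worked sketch with hypothetical numbers (values chosen only for illustration, not taken from the patch):

    # Hypothetical values, for illustration only.
    tokens_per_block = 32
    num_extra_tokens_per_seq = 2      # e.g. one generated token + one overlap-scheduler token
    num_req_tokens = 100 + num_extra_tokens_per_seq
    # Requests cannot share blocks, so round up to a whole number of blocks.
    num_cache_blocks = (num_req_tokens + tokens_per_block - 1) // tokens_per_block  # -> 4
    beam_width = 1
    estimated_kv_tokens = num_cache_blocks * tokens_per_block * beam_width          # -> 128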
@@ -197,7 +214,7 @@ class KvCacheCreator:
         estimating_kv_cache = False
         if 'cp_type' not in self._mapping.cp_config:
             estimating_kv_cache = True
-            self._executor_config.kv_cache_config.max_tokens = self._get_token_num_for_estimation(
+            self._kv_cache_config.max_tokens = self._get_token_num_for_estimation(
             )
         return estimating_kv_cache

@@ -207,7 +224,6 @@ class KvCacheCreator:

         This updates `kv_cache_config`.
         """
-        executor_config = self._executor_config
         mapping = self._mapping

         # TODO: support CP by generating dummy requests for it.
@@ -286,7 +302,7 @@ class KvCacheCreator:
                                                  total_gpu_memory, fraction,
                                                  allocated_bytes)

-        max_attention_window = executor_config.kv_cache_config.max_attention_window
+        max_attention_window = self._kv_cache_config.max_attention_window
        is_vswa = max_attention_window and len(set(max_attention_window)) > 1

         # NOTE:
@@ -314,41 +330,39 @@ class KvCacheCreator:
             )
         if is_vswa:
             # For VSWA KvCacheManager now it can only use max_gpu_total_bytes
-            executor_config.kv_cache_config.max_tokens = None
+            self._kv_cache_config.max_tokens = None
         else:
             # For non-VSWA KvCacheManager, its logic still relies on max_tokens, need to improve in the future.
-            executor_config.kv_cache_config.max_tokens = int(
+            self._kv_cache_config.max_tokens = int(
                 kv_cache_max_memory // self._get_kv_size_per_token())
         # ---------------------------handle max_tokens---------------------------------

         # ---------------------------handle max_gpu_total_bytes---------------------------------
         # if user provided max_gpu_total_bytes, set max memory from max_gpu_total_bytes
-        if executor_config.kv_cache_config.max_gpu_total_bytes > 0:
-            kv_cache_max_memory = min(
-                kv_cache_max_memory,
-                executor_config.kv_cache_config.max_gpu_total_bytes)
+        if self._kv_cache_config.max_gpu_total_bytes > 0:
+            kv_cache_max_memory = min(kv_cache_max_memory,
+                                      self._kv_cache_config.max_gpu_total_bytes)
             logger.info(
-                f"max_gpu_total_bytes={executor_config.kv_cache_config.max_gpu_total_bytes / (GB):.2f} GiB is provided, max_memory is set to {kv_cache_max_memory / (GB):.2f} GiB"
+                f"max_gpu_total_bytes={self._kv_cache_config.max_gpu_total_bytes / (GB):.2f} GiB is provided, max_memory is set to {kv_cache_max_memory / (GB):.2f} GiB"
             )

         logger.info(
             f"Estimated max memory in KV cache : {kv_cache_max_memory / (GB):.2f} GiB"
         )
         # set max_gpu_total_bytes
-        executor_config.kv_cache_config.max_gpu_total_bytes = kv_cache_max_memory
+        self._kv_cache_config.max_gpu_total_bytes = kv_cache_max_memory
         # ---------------------------handle max_gpu_total_bytes---------------------------------

     def _create_kv_cache_manager(
             self,
             model_engine: PyTorchModelEngine,
             estimating_kv_cache: bool = False) -> KVCacheManager:
-        executor_config = self._executor_config
         mapping = self._mapping
         assert model_engine.model.model_config.is_generation, "Only construct KV cache for generation models."

         config = model_engine.model.model_config.pretrained_config
         quant_config = model_engine.model.model_config.quant_config
-        spec_config = executor_config.speculative_config
+        spec_config = self._speculative_config

         hidden_size = config.hidden_size
         num_attention_heads = config.num_attention_heads
@@ -372,25 +386,25 @@ class KvCacheCreator:

         if is_mla(config):
             kv_cache_manager = KVCacheManager(
-                executor_config.kv_cache_config,
+                self._kv_cache_config,
                 tensorrt_llm.bindings.internal.batch_manager.CacheType.
                 SELFKONLY,
                 num_layers=num_hidden_layers,
                 num_kv_heads=1,
                 head_dim=config.kv_lora_rank + config.qk_rope_head_dim,
-                tokens_per_block=executor_config.tokens_per_block,
-                max_seq_len=executor_config.max_seq_len,
-                max_batch_size=executor_config.max_batch_size,
+                tokens_per_block=self._tokens_per_block,
+                max_seq_len=self._max_seq_len,
+                max_batch_size=self._max_batch_size,
                 mapping=mapping,
                 dtype=kv_cache_dtype,
                 spec_config=spec_config,
-                max_beam_width=executor_config.max_beam_width,
+                max_beam_width=self._max_beam_width,
                 is_draft=model_engine.is_draft_model,
                 kv_connector_manager=self._kv_connector_manager
                 if not estimating_kv_cache else None,
             )
         elif is_nemotron_hybrid(config):
-            if executor_config.max_beam_width > 1:
+            if self._max_beam_width > 1:
                 raise ValueError(
                     "MambaHybridCacheManager + beam search is not supported yet."
                 )
@@ -423,49 +437,48 @@ class KvCacheCreator:
                 model_engine.model.model_config.quant_config.
                 mamba_ssm_cache_dtype,
                 # kv cache parameters
-                executor_config.kv_cache_config,
+                self._kv_cache_config,
                 tensorrt_llm.bindings.internal.batch_manager.CacheType.SELF,
                 num_layers=num_layers,
                 layer_mask=layer_mask,
                 num_kv_heads=num_key_value_heads,
                 head_dim=head_dim,
-                tokens_per_block=executor_config.tokens_per_block,
-                max_seq_len=executor_config.max_seq_len,
-                max_batch_size=executor_config.max_batch_size,
+                tokens_per_block=self._tokens_per_block,
+                max_seq_len=self._max_seq_len,
+                max_batch_size=self._max_batch_size,
                 mapping=mapping,
                 dtype=kv_cache_dtype,
                 spec_config=spec_config,
             )
         else:
             # NOTE: this is a workaround for VSWA to switch to calculate_max_num_blocks_from_cpp in KVCahceManager
-            is_vswa = executor_config.kv_cache_config.max_attention_window is not None and len(
-                set(executor_config.kv_cache_config.max_attention_window)) > 1
+            is_vswa = self._kv_cache_config.max_attention_window is not None and len(
+                set(self._kv_cache_config.max_attention_window)) > 1
             binding_model_config = model_engine.model.model_config.get_bindings_model_config(
-                tokens_per_block=executor_config.tokens_per_block
-            ) if is_vswa else None
+                tokens_per_block=self._tokens_per_block) if is_vswa else None

             kv_cache_manager = KVCacheManager(
-                executor_config.kv_cache_config,
+                self._kv_cache_config,
                 tensorrt_llm.bindings.internal.batch_manager.CacheType.SELF,
                 num_layers=num_hidden_layers,
                 num_kv_heads=num_key_value_heads,
                 head_dim=head_dim,
-                tokens_per_block=executor_config.tokens_per_block,
-                max_seq_len=executor_config.max_seq_len,
-                max_batch_size=executor_config.max_batch_size,
+                tokens_per_block=self._tokens_per_block,
+                max_seq_len=self._max_seq_len,
+                max_batch_size=self._max_batch_size,
                 mapping=mapping,
                 dtype=kv_cache_dtype,
                 spec_config=spec_config,
-                max_num_tokens=executor_config.max_num_tokens,
+                max_num_tokens=self._max_num_tokens,
                 model_config=binding_model_config,
-                max_beam_width=executor_config.max_beam_width,
+                max_beam_width=self._max_beam_width,
                 is_draft=model_engine.is_draft_model,
                 kv_connector_manager=self._kv_connector_manager
                 if not estimating_kv_cache else None,
             )
-        # KVCacheManager (Non-draft) modifies the max_seq_len field, update it to executor_config
+        # KVCacheManager (Non-draft) modifies the max_seq_len field, update it to self
         if model_engine.kv_cache_manager_key == ResourceManagerType.KV_CACHE_MANAGER:
-            executor_config.max_seq_len = kv_cache_manager.max_seq_len
+            self._max_seq_len = kv_cache_manager.max_seq_len

         return kv_cache_manager

@@ -493,17 +493,29 @@ def create_py_executor(
     if model_engine.model.model_config.is_generation:
         #NOTE: non-generation models do not have kv cache
         kv_cache_creator = KvCacheCreator(
-            executor_config=executor_config,
             model_engine=model_engine,
             draft_model_engine=draft_model_engine,
             mapping=mapping,
             net_max_seq_len=net_max_seq_len,
-            kv_connector_manager=kv_connector_manager)
+            kv_connector_manager=kv_connector_manager,
+            max_num_tokens=executor_config.max_num_tokens,
+            max_beam_width=executor_config.max_beam_width,
+            tokens_per_block=executor_config.tokens_per_block,
+            max_seq_len=executor_config.max_seq_len,
+            max_batch_size=executor_config.max_batch_size,
+            kv_cache_config=executor_config.kv_cache_config,
+            pytorch_backend_config=executor_config.pytorch_backend_config,
+            speculative_config=executor_config.speculative_config,
+        )
         estimating_kv_cache = kv_cache_creator.try_prepare_estimation()
         with mem_monitor.observe_creation_stage(
                 _ExecutorCreationStage.INIT_KV_CACHE
                 if estimating_kv_cache else _ExecutorCreationStage.KV_CACHE):
             kv_cache_creator.build_managers(resources, estimating_kv_cache)
+            # Originally, executor_config.max_seq_len might be changed inside build_managers and used
+            # below in create_py_executor_instance. Since now, we are changing
+            # kv_cache_creator._max_seq_len instead, restore executor_config.max_seq_len.
+            executor_config.max_seq_len = kv_cache_creator._max_seq_len
             update_sampler_max_seq_len(executor_config.max_seq_len, sampler)

     # Resource managers for speculative decoding
@@ -564,10 +576,14 @@ def create_py_executor(
         with mem_monitor.observe_creation_stage(
                 _ExecutorCreationStage.KV_CACHE):
             # Before estimating KV cache size, a minimal KV cache has been allocated using
-            # create_kv_cache_manager above, which caps executor_config.max_seq_len. Restoring
+            # create_kv_cache_manager above, which caps kv_cache_creator.max_seq_len. Restoring
             # the original value before creating the final KV cache.
-            executor_config.max_seq_len = max_seq_len
+            kv_cache_creator._max_seq_len = max_seq_len
             kv_cache_creator.build_managers(resources, False)
+            # Originally, executor_config.max_seq_len might be changed again inside build_managers
+            # Since now, we are changing kv_cache_creator.max_seq_len instead.
+            # Restore executor_config.max_seq_len which has been used in create_py_executor_instance
+            executor_config.max_seq_len = kv_cache_creator._max_seq_len
             update_sampler_max_seq_len(executor_config.max_seq_len, sampler)

     for eng in [model_engine, draft_model_engine]:
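One subtlety the added comments call out: KVCacheManager can still lower the effective max_seq_len while building the cache. That adjustment is now recorded on the creator (self._max_seq_len) rather than on executor_config, so create_py_executor copies it back afterwards for downstream consumers such as create_py_executor_instance and the sampler. A condensed sketch of that flow, paraphrasing the patched call sites (not verbatim):

    kv_cache_creator.build_managers(resources, estimating_kv_cache)
    # build_managers may have capped kv_cache_creator._max_seq_len; propagate it.
    executor_config.max_seq_len = kv_cache_creator._max_seq_len
    update_sampler_max_seq_len(executor_config.max_seq_len, sampler)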