mirror of https://github.com/NVIDIA/TensorRT-LLM.git
[https://nvbugs/5474169][fix] Adjust max seq len for kvcache for memory estimation (#7391)
Signed-off-by: Hui Gao <huig@nvidia.com>
Signed-off-by: Wangshanshan <30051912+dominicshanshan@users.noreply.github.com>
parent 293d9fb612
commit 123f5cbbf0
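In brief, as the hunks below show, the change: moves the eager creation of dummy context requests in KvCacheCreator.__init__ after the remaining attribute assignments; adds a lazy (re)creation path during KV cache memory estimation, flooring the dummy request length at one token via max(1, self.net_max_seq_len - 1); reads max_seq_len back from the constructed KV cache manager, since sliding-window attention (SWA) may update it internally; and extends the KVCacheManager warning emitted when both free_gpu_memory_fraction and max_tokens are set so that it reports the effective token budget together with free and total memory.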
@@ -67,14 +67,14 @@ class KvCacheCreator:
         self._max_kv_tokens_in = self._kv_cache_config.max_tokens
         self._max_num_tokens = max_num_tokens
         self._max_beam_width = max_beam_width
-        self._dummy_reqs = self._create_dummy_context_requests(net_max_seq_len -
-                                                               1)
         self._kv_connector_manager = kv_connector_manager
         self._pytorch_backend_config = pytorch_backend_config
         self._speculative_config = speculative_config
         self._tokens_per_block = tokens_per_block
         self._max_seq_len = max_seq_len
         self._max_batch_size = max_batch_size
+        self._dummy_reqs = self._create_dummy_context_requests(net_max_seq_len -
+                                                               1)

     @staticmethod
     def _get_cache_size_per_token(model_config: ModelConfig,
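A note on this first hunk: the reorder presumably ensures that _create_dummy_context_requests runs only after every attribute it may read (batch size, tokens per block, sequence lengths) has been assigned; the next hunk additionally guards the estimation path with a lazy creation.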
@@ -196,6 +196,10 @@ class KvCacheCreator:
         if spec_cfg is not None:
             num_extra_tokens_per_seq += spec_cfg.max_draft_len
             num_extra_tokens_per_seq += get_num_extra_kv_tokens(spec_cfg)
+
+        if self._dummy_reqs is None:
+            self._dummy_reqs = self._create_dummy_context_requests(
+                max(1, self.net_max_seq_len - 1))
         for req in self._dummy_reqs:
             num_req_tokens = len(req.input_token_ids) + num_extra_tokens_per_seq
             # Requests cannot share KV cache blocks. Round up to nearest integer multiple of block size.
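The comment on the last context line states the rounding rule this estimation relies on. A minimal sketch of that rule (my illustration, not TensorRT-LLM code; the function name is invented): because requests cannot share KV cache blocks, each request's token count is rounded up to a whole number of blocks before it is charged against the cache budget.

    def tokens_reserved_for_request(num_req_tokens: int, tokens_per_block: int) -> int:
        # Requests cannot share blocks, so round up to a whole block count.
        num_blocks = (num_req_tokens + tokens_per_block - 1) // tokens_per_block
        return num_blocks * tokens_per_block

    # e.g. 130 tokens with 64-token blocks reserve 3 blocks = 192 token slots
    assert tokens_reserved_for_request(130, 64) == 192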
@@ -480,6 +484,10 @@ class KvCacheCreator:
         if model_engine.kv_cache_manager_key == ResourceManagerType.KV_CACHE_MANAGER:
             self._max_seq_len = kv_cache_manager.max_seq_len

+        # When SWA is enabled, max_seq_len is updated inside kv_cache_manager.
+        if kv_cache_manager is not None:
+            self._max_seq_len = kv_cache_manager.max_seq_len
+
         return kv_cache_manager

     def build_managers(self,
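For context on the SWA comment: once the manager is built, the creator can no longer trust its own copy of _max_seq_len, because the manager may recompute the value internally. A hedged illustration of the read-back pattern (the class, numbers, and clamping logic are assumptions for this note, not the library's API):

    class FakeKvCacheManager:
        def __init__(self, requested_max_seq_len: int, num_blocks: int,
                     tokens_per_block: int):
            # Assumed behavior for the sketch: clamp the sequence length to
            # what the allocated blocks can actually hold.
            capacity = num_blocks * tokens_per_block
            self.max_seq_len = min(requested_max_seq_len, capacity)

    mgr = FakeKvCacheManager(requested_max_seq_len=8192, num_blocks=64,
                             tokens_per_block=64)
    max_seq_len = mgr.max_seq_len  # 4096: read back rather than assume 8192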
@@ -582,7 +582,7 @@ class KVCacheManager(BaseResourceManager):
         if kv_cache_config.free_gpu_memory_fraction is not None:
             max_tokens = min(kv_cache_config.max_tokens, max_tokens)
             logger.warning(
-                f'Both free_gpu_memory_fraction and max_tokens are set (to {free_mem_fraction} and {kv_cache_config.max_tokens}, respectively). The smaller value will be used.'
+                f'Both free_gpu_memory_fraction and max_tokens are set (to {free_mem_fraction} and {max_tokens} with free memory {free_mem / (1 << 32)} of total memory {total_mem / (1<<32)}, respectively). The smaller value will be used.'
             )
         else:
             max_tokens = kv_cache_config.max_tokens
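The warning documents the resolution rule when both knobs are configured: the token budget derived from free_gpu_memory_fraction competes with the explicit max_tokens, and the smaller value wins. A minimal sketch of that rule (helper and parameter names are mine, not the library's):

    def effective_max_tokens(config_max_tokens: int, free_mem_bytes: int,
                             free_mem_fraction: float, bytes_per_token: int) -> int:
        # Tokens that fit in the requested fraction of free GPU memory.
        tokens_from_fraction = int(free_mem_bytes * free_mem_fraction) // bytes_per_token
        # When both limits are configured, the smaller value is used.
        return min(config_max_tokens, tokens_from_fraction)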