Mirror of https://github.com/NVIDIA/TensorRT-LLM.git, synced 2026-01-13 22:18:36 +08:00
[None][feat] Add support for KVCache reuse for DSv32 (#9383)
Signed-off-by: Iman Tabrizian <10105175+tabrizian@users.noreply.github.com>
Parent: dcf5c86720
Commit: 356a52edf5
@@ -183,6 +183,10 @@ private:
         auto windowSize = cacheManager.getBlockManager().getPoolWindowSize(poolIdx);
         mPoolsPerWindow[windowSize].push_back(cacheManager.getBlockManager().getPrimaryPool(poolIdx));
     }
+    if (cacheManager.isEnableIndexerKCache())
+    {
+        mIndexerKCachePool = cacheManager.getIndexerKCachePool();
+    }
 }

 BlockRange(BaseKVCacheManager const& cacheManager, LlmRequest::RequestIdType requestId)
@@ -806,7 +806,7 @@ public:

     RequestInfo requestInfo(requestId, mSelfState);

-    if (mFormatter->getCacheManager()->getBlockManager().getNumPools() == 1)
+    if (!mFormatter->getCacheManager()->getBlockManager().isVariableWindow())
     {
         auto* cacheManager = mFormatter->getCacheManager();
         auto beam = 0;
@@ -876,14 +876,7 @@ void WindowBlockManager::allocatePools(bool useUvm)
     }

-    nvinfer1::Dims cacheShape;
-    if (pool.containsIndexerKCache)
-    {
-        cacheShape = ITensor::makeShape({mNumPrimaryBlocks, pool.numLayers, blockSize});
-    }
-    else
-    {
-        cacheShape = ITensor::makeShape({mNumPrimaryBlocks, pool.numLayers, mKVFactor, blockSize});
-    }
+    cacheShape = ITensor::makeShape({mNumPrimaryBlocks, pool.numLayers, mKVFactor, blockSize});

     TLLM_LOG_DEBUG("[%s] Allocating primary pool with %d blocks for %d layers with %d kv heads", mLogPrefix.c_str(),
         mNumPrimaryBlocks, pool.numLayers, pool.numKvHeads);
@@ -881,12 +881,3 @@ python quickstart_advanced.py --model_dir <YOUR_MODEL_DIR> --enable_chunked_pref
 - **GPU Memory:** Adjust `--max_batch_size` and `--max_num_tokens` if you encounter out-of-memory errors.
 - **Logs:** Check `/workspace/trt_bench.log` for detailed performance information and troubleshooting messages.
 - **Configuration Files:** Verify that the configuration files are correctly formatted to avoid runtime issues.
-
-## Known Issues
-- Support for KV Cache Reuse and Chunked Prefill in DeepSeek-V3.2-Exp is currently under development. When running `quickstart_advanced.py`, please include `--disable_kv_cache_reuse` to disable KV Cache Reuse. When using `trtllm-eval`/`trtllm-serve`/`trtllm-bench`, please include the following configuration in the extra llm_api options:
-  ```
-  kv_cache_config:
-    enable_block_reuse: false
-    tokens_per_block: 64
-  enable_chunked_prefill: false
-  ```
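With this commit the workaround described in the removed section above is no longer required for KV cache reuse. A minimal LLM API sketch of running DeepSeek-V3.2-Exp with block reuse left enabled; `<YOUR_MODEL_DIR>` is a placeholder, and the memory fraction and block size mirror the test configs touched by this commit rather than required values:

```python
from tensorrt_llm import LLM
from tensorrt_llm.llmapi import KvCacheConfig

# Values mirror the tests in this commit; <YOUR_MODEL_DIR> is a placeholder.
kv_cache_config = KvCacheConfig(
    tokens_per_block=64,            # the DSv32 configs in this commit keep 64-token blocks
    free_gpu_memory_fraction=0.7,
)  # enable_block_reuse is no longer forced off

llm = LLM(model="<YOUR_MODEL_DIR>", kv_cache_config=kv_cache_config)
print(llm.generate(["Hello, my name is"])[0].outputs[0].text)
```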
@@ -930,7 +930,8 @@ class Indexer(nn.Module):
                 start_idx=0,
             )

-            if len(chunk_groups) > 1:
+            if len(chunk_groups
+                   ) > 1 or metadata.enable_context_mla_with_cached_kv:
                 metadata.indexer_prefill_chunks = [
                     Indexer.prepare_one_prefill_chunk(
                         metadata,
@@ -938,7 +939,6 @@ class Indexer(nn.Module):
                     ) for chunk_specs in chunk_groups
                 ]
             else:
-                # Single chunk - use non-chunked fallback path
                 metadata.indexer_prefill_chunks = None

             host_cu_seqlen_ks, host_cu_seqlen_ke = compute_cu_seqlen_kv_bounds_with_cache(
@@ -1018,9 +1018,9 @@ class Indexer(nn.Module):
             metadata.slot_mapping_scale[:total_tokens].copy_(
                 metadata.host_slot_mapping_scale[:total_tokens], non_blocking=True)

-            # Only when MLA chunked prefill is enabled, we need to gather the full KV for indexer's logit computation.
+            # When chunked prefill or KVCache reuse is enabled, we need to gather the full KV for indexer's logit computation.
             # Indexer's own chunking does not need full KV gathering, instead it gathers only the current chunk with loop-based gathering.
-            _need_full_kv_gathering = num_contexts > 0 and has_mla_chunked_prefill
+            _need_full_kv_gathering = num_contexts > 0 and metadata.enable_context_mla_with_cached_kv
             if _need_full_kv_gathering:
                 total_kv_len = metadata.host_ctx_kv_indptr[num_contexts].item()
                 total_kv_per_request = seq_lens[:
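As a rough illustration of the comment above: when context requests run against cached KV, the indexer's logit computation needs the full per-request KV (cached plus current tokens), which the `host_cu_seqlen_ks`/`host_cu_seqlen_ke` bounds describe. A self-contained sketch, assuming a hypothetical flat row layout and a hypothetical `gather_full_kv` helper (not the TRT-LLM kernels):

```python
import torch

def gather_full_kv(kv_rows: torch.Tensor, cu_seqlen_ks: torch.Tensor,
                   cu_seqlen_ke: torch.Tensor) -> torch.Tensor:
    """Gather the full per-request KV ranges [ks, ke) from a flat row buffer."""
    return torch.cat([
        kv_rows[s:e]
        for s, e in zip(cu_seqlen_ks.tolist(), cu_seqlen_ke.tolist())
    ], dim=0)

# Toy usage: two context requests whose full KV (cached + new tokens)
# occupies rows 0:5 and 5:12 of the flat buffer.
kv_rows = torch.randn(12, 128)
full_kv = gather_full_kv(kv_rows, torch.tensor([0, 5]), torch.tensor([5, 12]))
assert full_kv.shape[0] == 12
```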
@@ -1589,10 +1589,6 @@ class DSACacheManager(KVCacheManager):
         sparse_attn_config: "SparseAttentionConfig",
         **kwargs,
     ) -> None:
-
-        if kv_cache_config.enable_block_reuse:
-            raise NotImplementedError(
-                "DSA indexer K-cache manager does not support block reuse yet")
         self.quant_block_size = 128
         self.index_head_dim = sparse_attn_config.index_head_dim
         # Use a fixed tokens_per_block for indexer k cache due to DG kernel constraints
@@ -1055,7 +1055,6 @@ class TestDeepSeekV32Exp(LlmapiAccuracyTestHarness):
         ctx_server_config["cache_transceiver_config"] = {"backend": "DEFAULT"}
         gen_server_config["cache_transceiver_config"] = {"backend": "DEFAULT"}
         ctx_server_config["kv_cache_config"] = {
-            "enable_block_reuse": False,
             "free_gpu_memory_fraction": 0.7,
             "tokens_per_block": 64,
             "dtype": "fp8"
@@ -1072,7 +1071,6 @@ class TestDeepSeekV32Exp(LlmapiAccuracyTestHarness):
         ctx_server_config["enable_attention_dp"] = True
         ctx_server_config["enable_autotuner"] = False
         gen_server_config["kv_cache_config"] = {
-            "enable_block_reuse": False,
             "tokens_per_block": 64,
             "free_gpu_memory_fraction": 0.7,
             "dtype": "fp8"
@@ -2597,17 +2597,13 @@ class TestDeepSeekV32(LlmapiAccuracyTestHarness):
         if get_sm_version() == 100 or get_sm_version() == 103:
             moe_backend = "DEEPGEMM" if moe_backend == "_DEFAULT" else moe_backend
             moe_config = MoeConfig(backend=moe_backend, max_num_tokens=16384)
-            # TODO: Support block reuse for DeepSeek-V3.2
-            kv_cache_config = KvCacheConfig(enable_block_reuse=False,
-                                            free_gpu_memory_fraction=0.6,
+            kv_cache_config = KvCacheConfig(free_gpu_memory_fraction=0.6,
                                             tokens_per_block=64)
         else:
             if moe_backend != "_DEFAULT":
                 pytest.skip("Not supported MoE backend!")
             moe_config = MoeConfig()
-            # TODO: Support block reuse for DeepSeek-V3.2
-            kv_cache_config = KvCacheConfig(enable_block_reuse=False,
-                                            free_gpu_memory_fraction=0.7,
+            kv_cache_config = KvCacheConfig(free_gpu_memory_fraction=0.7,
                                             tokens_per_block=64)

         pytorch_config = dict(
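Dropping `enable_block_reuse=False` here means the accuracy tests now run with the `KvCacheConfig` default, so the new DSv32 reuse path is actually exercised. A quick sketch, assuming the LLM API default of block reuse enabled:

```python
from tensorrt_llm.llmapi import KvCacheConfig

cfg = KvCacheConfig(free_gpu_memory_fraction=0.6, tokens_per_block=64)
# enable_block_reuse is not passed; with the library default (reuse enabled)
# the test covers KV cache block reuse for DeepSeek-V3.2.
print(cfg.enable_block_reuse)
```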
@@ -2670,8 +2666,7 @@ class TestDeepSeekV32(LlmapiAccuracyTestHarness):
                 "MOE TRTLLM backend does not support SM version 120 or 121")

         moe_config = MoeConfig(backend=moe_backend, max_num_tokens=16384)
-        kv_cache_config = KvCacheConfig(enable_block_reuse=False,
-                                        free_gpu_memory_fraction=0.7,
+        kv_cache_config = KvCacheConfig(free_gpu_memory_fraction=0.7,
                                         tokens_per_block=64)
         cuda_graph_config = CudaGraphConfig(
             enable_padding=True,
@@ -2730,8 +2725,7 @@ class TestDeepSeekV32(LlmapiAccuracyTestHarness):
                 "MOE TRTLLM backend does not support SM version 120 or 121")

         moe_config = MoeConfig(backend=moe_backend, max_num_tokens=16384)
-        kv_cache_config = KvCacheConfig(enable_block_reuse=False,
-                                        free_gpu_memory_fraction=0.7,
+        kv_cache_config = KvCacheConfig(free_gpu_memory_fraction=0.7,
                                         tokens_per_block=64)
         cuda_graph_config = CudaGraphConfig(
             enable_padding=True,