[None][feat] update deepgemm to the DeepGEMM/nv_dev branch (#9898)

Signed-off-by: Fanrong Li <23290157+lfr-0531@users.noreply.github.com>
Fanrong Li 2026-01-05 16:43:42 +08:00 committed by GitHub
parent d272f1a9bc
commit 4931c5eb3a
5 changed files with 19 additions and 27 deletions


@@ -38,8 +38,8 @@ FetchContent_Declare(
 FetchContent_Declare(
   deepgemm
-  GIT_REPOSITORY https://github.com/ruoqianguo/DeepGEMM
-  GIT_TAG 6cb8161516302550785d9af924d2778afef1f3f6 # swapab_sm100 branch
+  GIT_REPOSITORY https://github.com/deepseek-ai/DeepGEMM
+  GIT_TAG 4ff3f54d9b7ed3129e4f36f9871232ea7ecab86b # nv_dev branch
   GIT_SUBMODULES_RECURSE
   ON
   SOURCE_SUBDIR


@@ -38,7 +38,13 @@ foreach(SOURCE_FILE ${DEEP_GEMM_ALL_FILES})
   if(FILE_EXT STREQUAL ".py")
     # Read file content and replace module imports for Python files
     file(READ ${SOURCE_FILE} _content)
-    string(REPLACE "deep_gemm_cpp" "tensorrt_llm.deep_gemm_cpp_tllm" _content
-           "${_content}")
+    string(REPLACE "from . import _C" "import tensorrt_llm.deep_gemm_cpp_tllm"
+           _content "${_content}")
+    string(REPLACE ".._C" "tensorrt_llm.deep_gemm_cpp_tllm" _content
+           "${_content}")
+    string(REPLACE "._C" "tensorrt_llm.deep_gemm_cpp_tllm" _content
+           "${_content}")
+    string(REPLACE "_C." "tensorrt_llm.deep_gemm_cpp_tllm." _content
+           "${_content}")
     # Add adaptation header
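In plain Python, the REPLACE chain above amounts to an ordered textual rewrite that points the vendored DeepGEMM sources' `_C` extension imports at TensorRT-LLM's bundled binding. A rough sketch of the same logic (the helper below is hypothetical, not part of the build):

```python
from pathlib import Path

# Most specific patterns first, mirroring the CMake order: replacing "._C"
# before ".._C" would leave a stray dot behind in relative imports.
REWRITES = [
    ("from . import _C", "import tensorrt_llm.deep_gemm_cpp_tllm"),
    (".._C", "tensorrt_llm.deep_gemm_cpp_tllm"),
    ("._C", "tensorrt_llm.deep_gemm_cpp_tllm"),
    ("_C.", "tensorrt_llm.deep_gemm_cpp_tllm."),
]

def rewrite_imports(path: Path) -> None:
    """Apply the ordered replacements to one vendored .py file."""
    content = path.read_text()
    for old, new in REWRITES:
        content = content.replace(old, new)
    path.write_text(content)
```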


@@ -90,7 +90,6 @@ To quickly run DeepSeek-V3, [examples/llm-api/quickstart_advanced.py](../llm-api
 cd examples/llm-api
 python quickstart_advanced.py --model_dir <YOUR_MODEL_DIR> --tp_size 8
 ```
-Please include `--tokens_per_block 64` when running DeepSeek-V3.2-Exp, as this model uses the deep_gemm.fp8_paged_mqa_logits kernel, which requires a KV cache block size of 64.
 The model will be run by the PyTorch backend and generate outputs like:
 ```
@@ -108,7 +107,7 @@ cd examples/llm-api
 python quickstart_advanced.py --model_dir <YOUR_MODEL_DIR> --spec_decode_algo MTP --spec_decode_max_draft_len N
 ```
-`N` is the number of MTP modules. When `N` equals `0`, MTP is not used (default). When `N` is greater than `0`, `N` MTP modules are enabled. In the current implementation, the weights of all MTP modules are shared. Please include `--tokens_per_block 64` when running DeepSeek-V3.2-Exp.
+`N` is the number of MTP modules. When `N` equals `0`, MTP is not used (default). When `N` is greater than `0`, `N` MTP modules are enabled. In the current implementation, the weights of all MTP modules are shared.
 #### Relaxed acceptance
 **NOTE: This feature can only be used for DeepSeek R1.**
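The same MTP setup can also be expressed through the LLM API instead of the quickstart flags. A minimal sketch, assuming `MTPDecodingConfig` and its `num_nextn_predict_layers` field from `tensorrt_llm.llmapi` (treat the exact names as illustrative if your TensorRT-LLM version differs):

```python
from tensorrt_llm import LLM
from tensorrt_llm.llmapi import MTPDecodingConfig

# N = 3 enables three MTP modules; their weights are shared, as noted above.
llm = LLM(
    model="<YOUR_MODEL_DIR>",
    tensor_parallel_size=8,
    speculative_config=MTPDecodingConfig(num_nextn_predict_layers=3),
)
```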


@@ -785,7 +785,6 @@ class DSAtrtllmAttentionMetadata(TrtllmAttentionMetadata):
         # After changing the kv_lens/kv_lens_cuda, we may need to update other metadatas.
         # Especially for the changes in the _preprocess_inputs() of model_engine.py.
         if self.num_generations > 0:
-            tokens_per_block = self.kv_cache_manager.indexer_k_cache_tokens_per_block
             torch.cumsum(
                 self.kv_lens_cuda[self.num_contexts:self.
                                   num_seqs],  # num_contexts should be 0
@@ -800,7 +799,7 @@ class DSAtrtllmAttentionMetadata(TrtllmAttentionMetadata):
                 out=self.gen_cached_token_indptr[1:self.num_generations + 1])
             scheduler_metadata_buffer = get_paged_mqa_logits_metadata(
                 self.kv_lens_cuda[self.num_contexts:self.num_seqs],
-                tokens_per_block, self.num_sms)
+                self.kv_cache_manager.tokens_per_block, self.num_sms)
             self.scheduler_metadata_buffer.copy_(scheduler_metadata_buffer,
                                                  non_blocking=True)
         if self.use_expanded_buffers_for_mtp:
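On the nv_dev branch, the scheduler metadata is derived from the cache manager's regular block size instead of a fixed indexer-specific 64. A standalone sketch of the call, assuming DeepGEMM exposes `get_paged_mqa_logits_metadata(context_lens, block_kv, num_sms)` as it is used here (block size and lengths are illustrative values):

```python
import torch
from deep_gemm import get_paged_mqa_logits_metadata

num_sms = torch.cuda.get_device_properties(0).multi_processor_count
tokens_per_block = 32  # whatever the KV cache manager was configured with
kv_lens = torch.tensor([100, 37, 256], dtype=torch.int32, device="cuda")

# Buffer later consumed by the fp8_paged_mqa_logits indexer kernel.
metadata = get_paged_mqa_logits_metadata(kv_lens, tokens_per_block, num_sms)
```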
@@ -827,7 +826,6 @@ class DSAtrtllmAttentionMetadata(TrtllmAttentionMetadata):
     def update_for_spec_dec(self):
         super().update_for_spec_dec()
-        self.kv_cache_manager.indexer_k_cache_tokens_per_block
         # host
         self.max_ctx_kv_len = 0
         self.num_ctx_cached_tokens = 0
@@ -1030,7 +1028,7 @@ class Indexer(nn.Module):
         request_ids = metadata.request_ids
         seq_lens = metadata.seq_lens
         head_dim = metadata.kv_cache_manager.index_head_dim
-        tokens_per_block = metadata.kv_cache_manager.indexer_k_cache_tokens_per_block
+        tokens_per_block = metadata.kv_cache_manager.tokens_per_block
         quant_block_size = metadata.kv_cache_manager.quant_block_size
         cached_tokens = metadata.kv_cache_params.num_cached_tokens_per_seq
         total_tokens = seq_lens.sum().item()
@@ -1750,9 +1748,6 @@ class DSACacheManager(KVCacheManager):
     ) -> None:
         self.quant_block_size = 128
         self.index_head_dim = sparse_attn_config.index_head_dim
-        # Use a fixed tokens_per_block for indexer k cache due to DG kernel constraints
-        self.indexer_k_cache_tokens_per_block = 64
-        assert self.indexer_k_cache_tokens_per_block == tokens_per_block, "tokens_per_block must be set to 64 for DeepSeek v3.2"

         super().__init__(
             kv_cache_config,
@@ -1778,7 +1773,7 @@ class DSACacheManager(KVCacheManager):
         self.num_blocks = self.blocks_in_primary_pool
         # Indexer K cache pool for DSA attention
-        # Shape: [num_blocks, self.indexer_k_cache_tokens_per_block * (index_head_dim + scale_size)]
+        # Shape: [num_blocks, self.tokens_per_block * (index_head_dim + scale_size)]
         # Non-interleaved layout: [fp8_tok0 | fp8_tok1 | ... | scale_tok0 | scale_tok1 | ...]
         # Store FP8-quantized k values from the indexer
         self.indexer_k_cache_pool_per_layer = [
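The pool comment spells out a non-interleaved block layout: all FP8 token vectors first, then all per-token scales. A small sketch of the sizing this implies, with assumed dimensions (all values below are for illustration only):

```python
import torch

tokens_per_block = 64   # assumed block size, illustration only
index_head_dim = 128    # assumed indexer head dim
scale_size = 4          # one 4-byte scale per token per 128-wide quant block

# [fp8_tok0 | fp8_tok1 | ... | scale_tok0 | scale_tok1 | ...]
bytes_per_block = tokens_per_block * (index_head_dim + scale_size)
fp8_bytes = tokens_per_block * index_head_dim  # FP8 region comes first
scale_offset = fp8_bytes                       # scales start right after it

pool = torch.empty(1024, bytes_per_block, dtype=torch.uint8, device="cuda")
```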
@@ -1805,9 +1800,7 @@ class DSACacheManager(KVCacheManager):
         config = model_config.pretrained_config
         sparse_attn_config = model_config.sparse_attention_config
         index_head_dim = sparse_attn_config.index_head_dim
-        tokens_per_block = kwargs['tokens_per_block']
         quant_block_size = 128
-        indexer_k_cache_tokens_per_block = 64

         # get kv cache dtype bytes
         mem_per_token = 2
@@ -1827,8 +1820,7 @@ class DSACacheManager(KVCacheManager):
         # 1 for K, others for indexer K cache
         head_dim_factor = (index_head_dim +
                            index_head_dim // quant_block_size * 4) / head_dim
-        tokens_per_block_factor = indexer_k_cache_tokens_per_block / tokens_per_block
-        kv_factor = 1 + head_dim_factor * tokens_per_block_factor
+        kv_factor = 1 + head_dim_factor
         mem_per_token *= kv_factor
         return mem_per_token
@@ -1836,8 +1828,7 @@ class DSACacheManager(KVCacheManager):
         # self.kv_factor for K, others for indexer K cache
         head_dim_factor = (self.index_head_dim + self.index_head_dim //
                            self.quant_block_size * 4) / self.head_dim
-        tokens_per_block_factor = self.indexer_k_cache_tokens_per_block / self.tokens_per_block
-        kv_factor = self.kv_factor + head_dim_factor * tokens_per_block_factor
+        kv_factor = self.kv_factor + head_dim_factor
         cache_size_per_token = math.ceil(
             kv_factor * sum(self.num_kv_heads_per_layer) * self.head_dim)
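With the indexer K cache now sharing the manager's `tokens_per_block`, the block-size ratio drops out and the per-token estimate reduces to `1 + head_dim_factor`. A worked example with assumed DeepSeek-V3.2-style dimensions (`index_head_dim=128`, `quant_block_size=128`, `head_dim=576`; all three are assumptions for illustration):

```python
index_head_dim = 128
quant_block_size = 128
head_dim = 576  # assumed MLA KV head dim

# FP8 indexer K bytes plus 4-byte scales per quant block, relative to head_dim.
head_dim_factor = (index_head_dim + index_head_dim // quant_block_size * 4) / head_dim
kv_factor = 1 + head_dim_factor

print(f"head_dim_factor={head_dim_factor:.3f}, kv_factor={kv_factor:.3f}")
# head_dim_factor=0.229, kv_factor=1.229
```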


@@ -2633,14 +2633,12 @@ class TestDeepSeekV32(LlmapiAccuracyTestHarness):
         if get_sm_version() == 100 or get_sm_version() == 103:
             moe_backend = "DEEPGEMM" if moe_backend == "_DEFAULT" else moe_backend
             moe_config = MoeConfig(backend=moe_backend, max_num_tokens=16384)
-            kv_cache_config = KvCacheConfig(free_gpu_memory_fraction=0.6,
-                                            tokens_per_block=64)
+            kv_cache_config = KvCacheConfig(free_gpu_memory_fraction=0.6)
         else:
             if moe_backend != "_DEFAULT":
                 pytest.skip("Not supported MoE backend!")
             moe_config = MoeConfig()
-            kv_cache_config = KvCacheConfig(free_gpu_memory_fraction=0.7,
-                                            tokens_per_block=64)
+            kv_cache_config = KvCacheConfig(free_gpu_memory_fraction=0.7)

         pytorch_config = dict(
             disable_overlap_scheduler=not overlap_scheduler,
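The tests drop the explicit `tokens_per_block=64` pin because the updated kernels accept the default block size, so only the memory fraction is set. The equivalent LLM-API usage, as a minimal sketch (model path hypothetical):

```python
from tensorrt_llm import LLM
from tensorrt_llm.llmapi import KvCacheConfig

# No tokens_per_block=64 pin is needed for DeepSeek-V3.2 anymore.
llm = LLM(
    model="<YOUR_DEEPSEEK_V32_DIR>",
    kv_cache_config=KvCacheConfig(free_gpu_memory_fraction=0.6),
)
```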
@@ -2711,8 +2709,7 @@ class TestDeepSeekV32(LlmapiAccuracyTestHarness):
             pytest.skip(f"{moe_backend} backend does not support SM 120 or 121")
         moe_config = MoeConfig(backend=moe_backend, max_num_tokens=16384)
-        kv_cache_config = KvCacheConfig(free_gpu_memory_fraction=0.7,
-                                        tokens_per_block=64)
+        kv_cache_config = KvCacheConfig(free_gpu_memory_fraction=0.7)
         cuda_graph_config = CudaGraphConfig(
             enable_padding=True,
             max_batch_size=max_batch_size) if cuda_graph else None
@@ -2775,8 +2772,7 @@ class TestDeepSeekV32(LlmapiAccuracyTestHarness):
             pytest.skip(f"{moe_backend} backend does not support SM 120 or 121")
         moe_config = MoeConfig(backend=moe_backend, max_num_tokens=16384)
-        kv_cache_config = KvCacheConfig(free_gpu_memory_fraction=0.7,
-                                        tokens_per_block=64)
+        kv_cache_config = KvCacheConfig(free_gpu_memory_fraction=0.7)
         cuda_graph_config = CudaGraphConfig(
             enable_padding=True,
             max_batch_size=max_batch_size) if cuda_graph else None