From 4931c5eb3af3bb65ee13dfdf582e3791e738b643 Mon Sep 17 00:00:00 2001 From: Fanrong Li <23290157+lfr-0531@users.noreply.github.com> Date: Mon, 5 Jan 2026 16:43:42 +0800 Subject: [PATCH] [None][feat] update deepgemm to the DeepGEMM/nv_dev branch (#9898) Signed-off-by: Fanrong Li <23290157+lfr-0531@users.noreply.github.com> --- 3rdparty/CMakeLists.txt | 4 ++-- cpp/tensorrt_llm/deep_gemm/CMakeLists.txt | 8 +++++++- examples/models/core/deepseek_v3/README.md | 3 +-- .../_torch/attention_backend/sparse/dsa.py | 19 +++++-------------- .../defs/accuracy/test_llm_api_pytorch.py | 12 ++++-------- 5 files changed, 19 insertions(+), 27 deletions(-) diff --git a/3rdparty/CMakeLists.txt b/3rdparty/CMakeLists.txt index 59076e14c9..93565ae099 100644 --- a/3rdparty/CMakeLists.txt +++ b/3rdparty/CMakeLists.txt @@ -38,8 +38,8 @@ FetchContent_Declare( FetchContent_Declare( deepgemm - GIT_REPOSITORY https://github.com/ruoqianguo/DeepGEMM - GIT_TAG 6cb8161516302550785d9af924d2778afef1f3f6 # swapab_sm100 branch + GIT_REPOSITORY https://github.com/deepseek-ai/DeepGEMM + GIT_TAG 4ff3f54d9b7ed3129e4f36f9871232ea7ecab86b # nv_dev branch GIT_SUBMODULES_RECURSE ON SOURCE_SUBDIR diff --git a/cpp/tensorrt_llm/deep_gemm/CMakeLists.txt b/cpp/tensorrt_llm/deep_gemm/CMakeLists.txt index d2a900bf05..1a884befca 100644 --- a/cpp/tensorrt_llm/deep_gemm/CMakeLists.txt +++ b/cpp/tensorrt_llm/deep_gemm/CMakeLists.txt @@ -38,7 +38,13 @@ foreach(SOURCE_FILE ${DEEP_GEMM_ALL_FILES}) if(FILE_EXT STREQUAL ".py") # Read file content and replace module imports for Python files file(READ ${SOURCE_FILE} _content) - string(REPLACE "deep_gemm_cpp" "tensorrt_llm.deep_gemm_cpp_tllm" _content + string(REPLACE "from . import _C" "import tensorrt_llm.deep_gemm_cpp_tllm" + _content "${_content}") + string(REPLACE ".._C" "tensorrt_llm.deep_gemm_cpp_tllm" _content + "${_content}") + string(REPLACE "._C" "tensorrt_llm.deep_gemm_cpp_tllm" _content + "${_content}") + string(REPLACE "_C." "tensorrt_llm.deep_gemm_cpp_tllm." _content "${_content}") # Add adaptation header diff --git a/examples/models/core/deepseek_v3/README.md b/examples/models/core/deepseek_v3/README.md index b25c58fbea..81a0e95f50 100644 --- a/examples/models/core/deepseek_v3/README.md +++ b/examples/models/core/deepseek_v3/README.md @@ -90,7 +90,6 @@ To quickly run DeepSeek-V3, [examples/llm-api/quickstart_advanced.py](../llm-api cd examples/llm-api python quickstart_advanced.py --model_dir --tp_size 8 ``` -Please include `--tokens_per_block 64` when running DeepSeek-V3.2-Exp, as this model uses the deep_gemm.fp8_paged_mqa_logits kernel, which requires a KV cache block size of 64. The model will be run by PyTorch backend and generate outputs like: ``` @@ -108,7 +107,7 @@ cd examples/llm-api python quickstart_advanced.py --model_dir --spec_decode_algo MTP --spec_decode_max_draft_len N ``` -`N` is the number of MTP modules. When `N` is equal to `0`, which means that MTP is not used (default). When `N` is greater than `0`, which means that `N` MTP modules are enabled. In the current implementation, the weight of each MTP module is shared. Please include `--tokens_per_block 64` when running DeepSeek-V3.2-Exp. +`N` is the number of MTP modules. When `N` is equal to `0`, which means that MTP is not used (default). When `N` is greater than `0`, which means that `N` MTP modules are enabled. In the current implementation, the weight of each MTP module is shared. 
#### Relaxed acceptance **NOTE: This feature can only be used for DeepSeek R1.** diff --git a/tensorrt_llm/_torch/attention_backend/sparse/dsa.py b/tensorrt_llm/_torch/attention_backend/sparse/dsa.py index b30edc1aa6..80e37bba7d 100644 --- a/tensorrt_llm/_torch/attention_backend/sparse/dsa.py +++ b/tensorrt_llm/_torch/attention_backend/sparse/dsa.py @@ -785,7 +785,6 @@ class DSAtrtllmAttentionMetadata(TrtllmAttentionMetadata): # After changing the kv_lens/kv_lens_cuda, we may need to update other metadatas. # Especially for the changes in the _preprocess_inputs() of model_engine.py. if self.num_generations > 0: - tokens_per_block = self.kv_cache_manager.indexer_k_cache_tokens_per_block torch.cumsum( self.kv_lens_cuda[self.num_contexts:self. num_seqs], # num_contexts should be 0 @@ -800,7 +799,7 @@ class DSAtrtllmAttentionMetadata(TrtllmAttentionMetadata): out=self.gen_cached_token_indptr[1:self.num_generations + 1]) scheduler_metadata_buffer = get_paged_mqa_logits_metadata( self.kv_lens_cuda[self.num_contexts:self.num_seqs], - tokens_per_block, self.num_sms) + self.kv_cache_manager.tokens_per_block, self.num_sms) self.scheduler_metadata_buffer.copy_(scheduler_metadata_buffer, non_blocking=True) if self.use_expanded_buffers_for_mtp: @@ -827,7 +826,6 @@ class DSAtrtllmAttentionMetadata(TrtllmAttentionMetadata): def update_for_spec_dec(self): super().update_for_spec_dec() - self.kv_cache_manager.indexer_k_cache_tokens_per_block # host self.max_ctx_kv_len = 0 self.num_ctx_cached_tokens = 0 @@ -1030,7 +1028,7 @@ class Indexer(nn.Module): request_ids = metadata.request_ids seq_lens = metadata.seq_lens head_dim = metadata.kv_cache_manager.index_head_dim - tokens_per_block = metadata.kv_cache_manager.indexer_k_cache_tokens_per_block + tokens_per_block = metadata.kv_cache_manager.tokens_per_block quant_block_size = metadata.kv_cache_manager.quant_block_size cached_tokens = metadata.kv_cache_params.num_cached_tokens_per_seq total_tokens = seq_lens.sum().item() @@ -1750,9 +1748,6 @@ class DSACacheManager(KVCacheManager): ) -> None: self.quant_block_size = 128 self.index_head_dim = sparse_attn_config.index_head_dim - # Use a fixed tokens_per_block for indexer k cache due to DG kernel constraints - self.indexer_k_cache_tokens_per_block = 64 - assert self.indexer_k_cache_tokens_per_block == tokens_per_block, "tokens_per_block must be set to 64 for DeepSeek v3.2" super().__init__( kv_cache_config, @@ -1778,7 +1773,7 @@ class DSACacheManager(KVCacheManager): self.num_blocks = self.blocks_in_primary_pool # Indexer K cache pool for DSA attention - # Shape: [num_blocks, self.indexer_k_cache_tokens_per_block * (index_head_dim + scale_size)] + # Shape: [num_blocks, self.tokens_per_block * (index_head_dim + scale_size)] # Non-interleaved layout: [fp8_tok0 | fp8_tok1 | ... | scale_tok0 | scale_tok1 | ...] 
# Store FP8-quantized k values from the indexer self.indexer_k_cache_pool_per_layer = [ @@ -1805,9 +1800,7 @@ class DSACacheManager(KVCacheManager): config = model_config.pretrained_config sparse_attn_config = model_config.sparse_attention_config index_head_dim = sparse_attn_config.index_head_dim - tokens_per_block = kwargs['tokens_per_block'] quant_block_size = 128 - indexer_k_cache_tokens_per_block = 64 # get kv cache dtype bytes mem_per_token = 2 @@ -1827,8 +1820,7 @@ class DSACacheManager(KVCacheManager): # 1 for K, others for indexer K cache head_dim_factor = (index_head_dim + index_head_dim // quant_block_size * 4) / head_dim - tokens_per_block_factor = indexer_k_cache_tokens_per_block / tokens_per_block - kv_factor = 1 + head_dim_factor * tokens_per_block_factor + kv_factor = 1 + head_dim_factor mem_per_token *= kv_factor return mem_per_token @@ -1836,8 +1828,7 @@ class DSACacheManager(KVCacheManager): # self.kv_factor for K, others for indexer K cache head_dim_factor = (self.index_head_dim + self.index_head_dim // self.quant_block_size * 4) / self.head_dim - tokens_per_block_factor = self.indexer_k_cache_tokens_per_block / self.tokens_per_block - kv_factor = self.kv_factor + head_dim_factor * tokens_per_block_factor + kv_factor = self.kv_factor + head_dim_factor cache_size_per_token = math.ceil( kv_factor * sum(self.num_kv_heads_per_layer) * self.head_dim) diff --git a/tests/integration/defs/accuracy/test_llm_api_pytorch.py b/tests/integration/defs/accuracy/test_llm_api_pytorch.py index fd274e30b0..8555407c6e 100644 --- a/tests/integration/defs/accuracy/test_llm_api_pytorch.py +++ b/tests/integration/defs/accuracy/test_llm_api_pytorch.py @@ -2633,14 +2633,12 @@ class TestDeepSeekV32(LlmapiAccuracyTestHarness): if get_sm_version() == 100 or get_sm_version() == 103: moe_backend = "DEEPGEMM" if moe_backend == "_DEFAULT" else moe_backend moe_config = MoeConfig(backend=moe_backend, max_num_tokens=16384) - kv_cache_config = KvCacheConfig(free_gpu_memory_fraction=0.6, - tokens_per_block=64) + kv_cache_config = KvCacheConfig(free_gpu_memory_fraction=0.6) else: if moe_backend != "_DEFAULT": pytest.skip("Not supported MoE backend!") moe_config = MoeConfig() - kv_cache_config = KvCacheConfig(free_gpu_memory_fraction=0.7, - tokens_per_block=64) + kv_cache_config = KvCacheConfig(free_gpu_memory_fraction=0.7) pytorch_config = dict( disable_overlap_scheduler=not overlap_scheduler, @@ -2711,8 +2709,7 @@ class TestDeepSeekV32(LlmapiAccuracyTestHarness): pytest.skip(f"{moe_backend} backend does not support SM 120 or 121") moe_config = MoeConfig(backend=moe_backend, max_num_tokens=16384) - kv_cache_config = KvCacheConfig(free_gpu_memory_fraction=0.7, - tokens_per_block=64) + kv_cache_config = KvCacheConfig(free_gpu_memory_fraction=0.7) cuda_graph_config = CudaGraphConfig( enable_padding=True, max_batch_size=max_batch_size) if cuda_graph else None @@ -2775,8 +2772,7 @@ class TestDeepSeekV32(LlmapiAccuracyTestHarness): pytest.skip(f"{moe_backend} backend does not support SM 120 or 121") moe_config = MoeConfig(backend=moe_backend, max_num_tokens=16384) - kv_cache_config = KvCacheConfig(free_gpu_memory_fraction=0.7, - tokens_per_block=64) + kv_cache_config = KvCacheConfig(free_gpu_memory_fraction=0.7) cuda_graph_config = CudaGraphConfig( enable_padding=True, max_batch_size=max_batch_size) if cuda_graph else None
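
---

For reference, the new `string(REPLACE ...)` rules in `cpp/tensorrt_llm/deep_gemm/CMakeLists.txt` remap DeepGEMM's relative `_C` extension imports onto `tensorrt_llm.deep_gemm_cpp_tllm` when the Python sources are copied into the wheel. Below is a minimal Python sketch of the same substitutions applied in the same order; the sample input lines are illustrative only, not taken from the DeepGEMM sources.

```python
# Emulates the four replacement rules from the CMake file, in order.
# The sample lines below are hypothetical and for illustration only.
REPLACEMENTS = [
    ("from . import _C", "import tensorrt_llm.deep_gemm_cpp_tllm"),
    (".._C", "tensorrt_llm.deep_gemm_cpp_tllm"),
    ("._C", "tensorrt_llm.deep_gemm_cpp_tllm"),
    ("_C.", "tensorrt_llm.deep_gemm_cpp_tllm."),
]


def rewrite(line: str) -> str:
    for old, new in REPLACEMENTS:
        line = line.replace(old, new)
    return line


for sample in ("from . import _C", "_C.some_kernel(args)"):
    print(rewrite(sample))
# -> import tensorrt_llm.deep_gemm_cpp_tllm
# -> tensorrt_llm.deep_gemm_cpp_tllm.some_kernel(args)
```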
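
The `DSACacheManager` change removes the fixed 64-token block for the indexer K cache, so the per-token memory estimate loses its `tokens_per_block_factor` term (which was `64 / tokens_per_block`, and the removed assert forced `tokens_per_block == 64` anyway). A minimal sketch of the simplified estimate follows, assuming DeepSeek-V3.2-like dimensions (`index_head_dim = 128`, `quant_block_size = 128`, `head_dim = 576`); these values are illustrative and not part of this patch.

```python
# Sketch of the simplified per-token KV factor from DSACacheManager.
# The dimensions passed in below are assumed, not taken from this patch.


def estimate_kv_factor(index_head_dim: int,
                       quant_block_size: int,
                       head_dim: int) -> float:
    # Indexer K cache holds FP8 keys plus one float32 scale per quant block,
    # expressed as a fraction of the main KV head dimension.
    head_dim_factor = (index_head_dim +
                       index_head_dim // quant_block_size * 4) / head_dim
    # Previously: 1 + head_dim_factor * (64 / tokens_per_block); the ratio is
    # gone now that the indexer cache shares tokens_per_block with the KV cache.
    return 1 + head_dim_factor


print(f"kv_factor ~= {estimate_kv_factor(128, 128, 576):.3f}")  # ~= 1.229
```

The user-visible effect is that `--tokens_per_block 64` is no longer required for DeepSeek-V3.2, which is why the README note and the `tokens_per_block=64` arguments in the accuracy tests are removed.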