mirror of
https://github.com/NVIDIA/TensorRT-LLM.git
synced 2026-01-13 22:18:36 +08:00
[None][feat] update deepgemm to the DeepGEMM/nv_dev branch (#9898)
Signed-off-by: Fanrong Li <23290157+lfr-0531@users.noreply.github.com>
This commit is contained in:
parent
d272f1a9bc
commit
4931c5eb3a
4
3rdparty/CMakeLists.txt
vendored
4
3rdparty/CMakeLists.txt
vendored
@ -38,8 +38,8 @@ FetchContent_Declare(
|
||||
|
||||
FetchContent_Declare(
|
||||
deepgemm
|
||||
GIT_REPOSITORY https://github.com/ruoqianguo/DeepGEMM
|
||||
GIT_TAG 6cb8161516302550785d9af924d2778afef1f3f6 # swapab_sm100 branch
|
||||
GIT_REPOSITORY https://github.com/deepseek-ai/DeepGEMM
|
||||
GIT_TAG 4ff3f54d9b7ed3129e4f36f9871232ea7ecab86b # nv_dev branch
|
||||
GIT_SUBMODULES_RECURSE
|
||||
ON
|
||||
SOURCE_SUBDIR
|
||||
|
||||
@ -38,7 +38,13 @@ foreach(SOURCE_FILE ${DEEP_GEMM_ALL_FILES})
|
||||
if(FILE_EXT STREQUAL ".py")
|
||||
# Read file content and replace module imports for Python files
|
||||
file(READ ${SOURCE_FILE} _content)
|
||||
string(REPLACE "deep_gemm_cpp" "tensorrt_llm.deep_gemm_cpp_tllm" _content
|
||||
string(REPLACE "from . import _C" "import tensorrt_llm.deep_gemm_cpp_tllm"
|
||||
_content "${_content}")
|
||||
string(REPLACE ".._C" "tensorrt_llm.deep_gemm_cpp_tllm" _content
|
||||
"${_content}")
|
||||
string(REPLACE "._C" "tensorrt_llm.deep_gemm_cpp_tllm" _content
|
||||
"${_content}")
|
||||
string(REPLACE "_C." "tensorrt_llm.deep_gemm_cpp_tllm." _content
|
||||
"${_content}")
|
||||
|
||||
# Add adaptation header
|
||||
|
||||
@ -90,7 +90,6 @@ To quickly run DeepSeek-V3, [examples/llm-api/quickstart_advanced.py](../llm-api
|
||||
cd examples/llm-api
|
||||
python quickstart_advanced.py --model_dir <YOUR_MODEL_DIR> --tp_size 8
|
||||
```
|
||||
Please include `--tokens_per_block 64` when running DeepSeek-V3.2-Exp, as this model uses the deep_gemm.fp8_paged_mqa_logits kernel, which requires a KV cache block size of 64.
|
||||
|
||||
The model will be run by PyTorch backend and generate outputs like:
|
||||
```
|
||||
@ -108,7 +107,7 @@ cd examples/llm-api
|
||||
python quickstart_advanced.py --model_dir <YOUR_MODEL_DIR> --spec_decode_algo MTP --spec_decode_max_draft_len N
|
||||
```
|
||||
|
||||
`N` is the number of MTP modules. When `N` is equal to `0`, which means that MTP is not used (default). When `N` is greater than `0`, which means that `N` MTP modules are enabled. In the current implementation, the weight of each MTP module is shared. Please include `--tokens_per_block 64` when running DeepSeek-V3.2-Exp.
|
||||
`N` is the number of MTP modules. When `N` is equal to `0`, which means that MTP is not used (default). When `N` is greater than `0`, which means that `N` MTP modules are enabled. In the current implementation, the weight of each MTP module is shared.
|
||||
|
||||
#### Relaxed acceptance
|
||||
**NOTE: This feature can only be used for DeepSeek R1.**
|
||||
|
||||
@ -785,7 +785,6 @@ class DSAtrtllmAttentionMetadata(TrtllmAttentionMetadata):
|
||||
# After changing the kv_lens/kv_lens_cuda, we may need to update other metadatas.
|
||||
# Especially for the changes in the _preprocess_inputs() of model_engine.py.
|
||||
if self.num_generations > 0:
|
||||
tokens_per_block = self.kv_cache_manager.indexer_k_cache_tokens_per_block
|
||||
torch.cumsum(
|
||||
self.kv_lens_cuda[self.num_contexts:self.
|
||||
num_seqs], # num_contexts should be 0
|
||||
@ -800,7 +799,7 @@ class DSAtrtllmAttentionMetadata(TrtllmAttentionMetadata):
|
||||
out=self.gen_cached_token_indptr[1:self.num_generations + 1])
|
||||
scheduler_metadata_buffer = get_paged_mqa_logits_metadata(
|
||||
self.kv_lens_cuda[self.num_contexts:self.num_seqs],
|
||||
tokens_per_block, self.num_sms)
|
||||
self.kv_cache_manager.tokens_per_block, self.num_sms)
|
||||
self.scheduler_metadata_buffer.copy_(scheduler_metadata_buffer,
|
||||
non_blocking=True)
|
||||
if self.use_expanded_buffers_for_mtp:
|
||||
@ -827,7 +826,6 @@ class DSAtrtllmAttentionMetadata(TrtllmAttentionMetadata):
|
||||
|
||||
def update_for_spec_dec(self):
|
||||
super().update_for_spec_dec()
|
||||
self.kv_cache_manager.indexer_k_cache_tokens_per_block
|
||||
# host
|
||||
self.max_ctx_kv_len = 0
|
||||
self.num_ctx_cached_tokens = 0
|
||||
@ -1030,7 +1028,7 @@ class Indexer(nn.Module):
|
||||
request_ids = metadata.request_ids
|
||||
seq_lens = metadata.seq_lens
|
||||
head_dim = metadata.kv_cache_manager.index_head_dim
|
||||
tokens_per_block = metadata.kv_cache_manager.indexer_k_cache_tokens_per_block
|
||||
tokens_per_block = metadata.kv_cache_manager.tokens_per_block
|
||||
quant_block_size = metadata.kv_cache_manager.quant_block_size
|
||||
cached_tokens = metadata.kv_cache_params.num_cached_tokens_per_seq
|
||||
total_tokens = seq_lens.sum().item()
|
||||
@ -1750,9 +1748,6 @@ class DSACacheManager(KVCacheManager):
|
||||
) -> None:
|
||||
self.quant_block_size = 128
|
||||
self.index_head_dim = sparse_attn_config.index_head_dim
|
||||
# Use a fixed tokens_per_block for indexer k cache due to DG kernel constraints
|
||||
self.indexer_k_cache_tokens_per_block = 64
|
||||
assert self.indexer_k_cache_tokens_per_block == tokens_per_block, "tokens_per_block must be set to 64 for DeepSeek v3.2"
|
||||
|
||||
super().__init__(
|
||||
kv_cache_config,
|
||||
@ -1778,7 +1773,7 @@ class DSACacheManager(KVCacheManager):
|
||||
self.num_blocks = self.blocks_in_primary_pool
|
||||
|
||||
# Indexer K cache pool for DSA attention
|
||||
# Shape: [num_blocks, self.indexer_k_cache_tokens_per_block * (index_head_dim + scale_size)]
|
||||
# Shape: [num_blocks, self.tokens_per_block * (index_head_dim + scale_size)]
|
||||
# Non-interleaved layout: [fp8_tok0 | fp8_tok1 | ... | scale_tok0 | scale_tok1 | ...]
|
||||
# Store FP8-quantized k values from the indexer
|
||||
self.indexer_k_cache_pool_per_layer = [
|
||||
@ -1805,9 +1800,7 @@ class DSACacheManager(KVCacheManager):
|
||||
config = model_config.pretrained_config
|
||||
sparse_attn_config = model_config.sparse_attention_config
|
||||
index_head_dim = sparse_attn_config.index_head_dim
|
||||
tokens_per_block = kwargs['tokens_per_block']
|
||||
quant_block_size = 128
|
||||
indexer_k_cache_tokens_per_block = 64
|
||||
|
||||
# get kv cache dtype bytes
|
||||
mem_per_token = 2
|
||||
@ -1827,8 +1820,7 @@ class DSACacheManager(KVCacheManager):
|
||||
# 1 for K, others for indexer K cache
|
||||
head_dim_factor = (index_head_dim +
|
||||
index_head_dim // quant_block_size * 4) / head_dim
|
||||
tokens_per_block_factor = indexer_k_cache_tokens_per_block / tokens_per_block
|
||||
kv_factor = 1 + head_dim_factor * tokens_per_block_factor
|
||||
kv_factor = 1 + head_dim_factor
|
||||
mem_per_token *= kv_factor
|
||||
return mem_per_token
|
||||
|
||||
@ -1836,8 +1828,7 @@ class DSACacheManager(KVCacheManager):
|
||||
# self.kv_factor for K, others for indexer K cache
|
||||
head_dim_factor = (self.index_head_dim + self.index_head_dim //
|
||||
self.quant_block_size * 4) / self.head_dim
|
||||
tokens_per_block_factor = self.indexer_k_cache_tokens_per_block / self.tokens_per_block
|
||||
kv_factor = self.kv_factor + head_dim_factor * tokens_per_block_factor
|
||||
kv_factor = self.kv_factor + head_dim_factor
|
||||
cache_size_per_token = math.ceil(
|
||||
kv_factor * sum(self.num_kv_heads_per_layer) * self.head_dim)
|
||||
|
||||
|
||||
@ -2633,14 +2633,12 @@ class TestDeepSeekV32(LlmapiAccuracyTestHarness):
|
||||
if get_sm_version() == 100 or get_sm_version() == 103:
|
||||
moe_backend = "DEEPGEMM" if moe_backend == "_DEFAULT" else moe_backend
|
||||
moe_config = MoeConfig(backend=moe_backend, max_num_tokens=16384)
|
||||
kv_cache_config = KvCacheConfig(free_gpu_memory_fraction=0.6,
|
||||
tokens_per_block=64)
|
||||
kv_cache_config = KvCacheConfig(free_gpu_memory_fraction=0.6)
|
||||
else:
|
||||
if moe_backend != "_DEFAULT":
|
||||
pytest.skip("Not supported MoE backend!")
|
||||
moe_config = MoeConfig()
|
||||
kv_cache_config = KvCacheConfig(free_gpu_memory_fraction=0.7,
|
||||
tokens_per_block=64)
|
||||
kv_cache_config = KvCacheConfig(free_gpu_memory_fraction=0.7)
|
||||
|
||||
pytorch_config = dict(
|
||||
disable_overlap_scheduler=not overlap_scheduler,
|
||||
@ -2711,8 +2709,7 @@ class TestDeepSeekV32(LlmapiAccuracyTestHarness):
|
||||
pytest.skip(f"{moe_backend} backend does not support SM 120 or 121")
|
||||
|
||||
moe_config = MoeConfig(backend=moe_backend, max_num_tokens=16384)
|
||||
kv_cache_config = KvCacheConfig(free_gpu_memory_fraction=0.7,
|
||||
tokens_per_block=64)
|
||||
kv_cache_config = KvCacheConfig(free_gpu_memory_fraction=0.7)
|
||||
cuda_graph_config = CudaGraphConfig(
|
||||
enable_padding=True,
|
||||
max_batch_size=max_batch_size) if cuda_graph else None
|
||||
@ -2775,8 +2772,7 @@ class TestDeepSeekV32(LlmapiAccuracyTestHarness):
|
||||
pytest.skip(f"{moe_backend} backend does not support SM 120 or 121")
|
||||
|
||||
moe_config = MoeConfig(backend=moe_backend, max_num_tokens=16384)
|
||||
kv_cache_config = KvCacheConfig(free_gpu_memory_fraction=0.7,
|
||||
tokens_per_block=64)
|
||||
kv_cache_config = KvCacheConfig(free_gpu_memory_fraction=0.7)
|
||||
cuda_graph_config = CudaGraphConfig(
|
||||
enable_padding=True,
|
||||
max_batch_size=max_batch_size) if cuda_graph else None
|
||||
|
||||
Loading…
Reference in New Issue
Block a user