[https://nvbugs/5779534][fix] fix buffer reuse for CUDA graph attention metadata (#10393)
Signed-off-by: Fanrong Li <23290157+lfr-0531@users.noreply.github.com>
Parent: da0830670a
Commit: b5a1e10bc0
@@ -130,7 +130,7 @@ class FlashInferAttentionMetadata(AttentionMetadata):
         self._post_init_with_buffers(self.cuda_graph_buffers)

     def _post_init_with_buffers(self, buffers) -> None:
-        capture_graph = torch.cuda.is_current_stream_capturing()
+        capture_graph = self.is_cuda_graph

         if self.workspace_buffer is None:
             # Note: even though flashinfer only recommends 128 MB, we have to push it
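This one-line change is the heart of the fix, repeated across every metadata class below. torch.cuda.is_current_stream_capturing() only turns True once graph capture has actually begun, so the warmup iterations that run before capture see False and take the non-graph buffer path; when capture then starts, the metadata reuses buffers inconsistently with warmup. self.is_cuda_graph is already set when the metadata object is created for a CUDA graph, so keying on it keeps warmup and capture on the same path. A minimal sketch of the failure mode, where BufferPool and its get_empty are simplified stand-ins rather than TensorRT-LLM's actual pool:

import torch

class BufferPool:
    """Simplified stand-in for a reuse pool like cuda_graph_buffers."""

    def __init__(self):
        self._buffers = {}

    def get_empty(self, shape, key, for_cuda_graph):
        # Buffers baked into a CUDA graph are replayed by fixed address, so
        # they must not be shared with eager-mode scratch buffers.
        pool_key = (key, tuple(shape), for_cuda_graph)
        if pool_key not in self._buffers:
            self._buffers[pool_key] = torch.empty(shape, device="cuda")
        return self._buffers[pool_key]

pool = BufferPool()

# Buggy keying: False during warmup, True during capture, so the two phases
# receive *different* buffers and the captured graph reads stale data.
flag_buggy = torch.cuda.is_current_stream_capturing()

# Fixed keying: a property of the metadata object itself, identical in both
# warmup and capture (analogous to self.is_cuda_graph).
flag_fixed = True
buf = pool.get_empty((1024,), key="prompt_lens", for_cuda_graph=flag_fixed)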
@@ -328,7 +328,7 @@ class DSAtrtllmAttentionMetadata(TrtllmAttentionMetadata):

         self.sparse_mla_topk = self.sparse_attention_config.index_topk
         self.enable_indexer_skip = self.sparse_attention_config.skip_indexer_for_short_seqs
-        capture_graph = torch.cuda.is_current_stream_capturing()
+        capture_graph = self.is_cuda_graph

         self.indexer_k_cache_block_offsets = self.get_empty(
             self.cuda_graph_buffers,
@@ -550,7 +550,7 @@ class DSAtrtllmAttentionMetadata(TrtllmAttentionMetadata):
         self.max_draft_tokens = max_draft_len
         init_shape = self.kv_lens_expanded_host.shape[0]
         if self.max_num_sequences * (1 + self.max_draft_tokens) != init_shape:
-            capture_graph = torch.cuda.is_current_stream_capturing()
+            capture_graph = self.is_cuda_graph
             self.create_expanded_buffers(capture_graph=capture_graph)

     def prepare_dense_topk_indices(self,
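For speculative decoding the expanded host buffer holds one entry per sequence per position (one target token plus max_draft_tokens drafts), so the resize check compares against max_num_sequences * (1 + max_draft_tokens). A hedged sketch of that arithmetic, with illustrative numbers and a plain reallocation standing in for create_expanded_buffers():

import torch

# Hedged sketch of the resize check above; kv_lens_expanded_host holds one
# entry per sequence per position (1 target + max_draft_tokens drafts).
max_num_sequences = 24
max_draft_tokens = 1          # e.g. MTP with nextn=1
kv_lens_expanded_host = torch.zeros(16, dtype=torch.int32)  # stale size

needed = max_num_sequences * (1 + max_draft_tokens)          # 24 * 2 = 48
if kv_lens_expanded_host.shape[0] != needed:
    # Mirrors the intent of create_expanded_buffers(): reallocate at the
    # expanded size so CUDA graph capture sees the final buffer, not a
    # stale smaller one.
    kv_lens_expanded_host = torch.zeros(needed, dtype=torch.int32)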
@@ -48,7 +48,7 @@ class RocketTrtllmAttentionMetadata(TrtllmAttentionMetadata):
         assert self.page_size == next_power_of_2(
             self.page_size), "Page size must be a power of 2"

-        capture_graph = torch.cuda.is_current_stream_capturing()
+        capture_graph = self.is_cuda_graph

         # Cumulative valid sequence lengths for query and key
         self.q_cu_seqlens_cuda = self.get_empty(
@@ -737,7 +737,7 @@ class TrtllmAttentionMetadata(AttentionMetadata):
         if self.max_num_sequences is None:
             self.max_num_sequences = self.max_num_requests

-        capture_graph = torch.cuda.is_current_stream_capturing()
+        capture_graph = self.is_cuda_graph

         self.prompt_lens_cuda = self.get_empty(
             buffers,
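The same one-line fix lands in the TrtllmAttentionMetadata base class. The timing issue is easy to reproduce with stock PyTorch: the capture query is False during the warmup pass that CUDA graph users run on a side stream, and only flips to True inside torch.cuda.graph(). A self-contained demonstration (requires a CUDA device):

import torch

assert torch.cuda.is_available()
static_x = torch.ones(4, device="cuda")

# Warmup on a side stream (standard CUDA graph practice): not capturing yet.
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
    print(torch.cuda.is_current_stream_capturing())  # False: warmup phase
    y = static_x * 2
torch.cuda.current_stream().wait_stream(s)

g = torch.cuda.CUDAGraph()
with torch.cuda.graph(g):
    print(torch.cuda.is_current_stream_capturing())  # True: only during capture
    y = static_x * 2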
@@ -427,7 +427,7 @@ class ModelConfig(Generic[TConfig]):
                 index_head_dim = pretrained_config.index_head_dim
                 index_topk = pretrained_config.index_topk
                 indexer_max_chunk_size = None
-                skip_indexer_for_short_seqs = False
+                skip_indexer_for_short_seqs = True
                 kwargs[
                     'sparse_attention_config'] = DeepSeekSparseAttentionConfig(
                         index_n_heads=index_n_heads,
@@ -287,9 +287,8 @@ class DeepSeekSparseAttentionConfig(BaseSparseAttentionConfig):
         description="The topk for the indexer.")
     indexer_max_chunk_size: Optional[int] = Field(
         default=None, description="The maximum chunk size for the indexer.")
-    # TODO: enable this by default once the memory usage in attention metadata is optimized
     skip_indexer_for_short_seqs: bool = Field(
-        default=False,
+        default=True,
         description=
         "Whether to skip the MQA and Top-K in the indexer for short sequences.")

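With the Field default flipped, every DeepSeekSparseAttentionConfig constructed without an explicit value now enables indexer skipping, which is why the TODO comment is removed and the test flag below is inverted to opt out. A hedged, stand-alone Pydantic sketch of the new behavior (the real class derives from BaseSparseAttentionConfig and carries more fields):

from typing import Optional
from pydantic import BaseModel, Field

class DeepSeekSparseAttentionConfig(BaseModel):
    # Stand-in for the real config class, reduced to the fields in this hunk.
    indexer_max_chunk_size: Optional[int] = Field(
        default=None, description="The maximum chunk size for the indexer.")
    skip_indexer_for_short_seqs: bool = Field(
        default=True,
        description=
        "Whether to skip the MQA and Top-K in the indexer for short sequences.")

print(DeepSeekSparseAttentionConfig().skip_indexer_for_short_seqs)  # True
# Callers that relied on the old default must now opt out explicitly:
print(DeepSeekSparseAttentionConfig(
    skip_indexer_for_short_seqs=False).skip_indexer_for_short_seqs)  # False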
@@ -2610,7 +2610,7 @@ class TestDeepSeekV32(LlmapiAccuracyTestHarness):
     @skip_pre_hopper
     @pytest.mark.skip_less_device_memory(140000)
     @pytest.mark.parametrize(
-        "tp_size,pp_size,ep_size,mtp_nextn,fp8kv,attention_dp,cuda_graph,overlap_scheduler,max_batch_size,moe_backend,skip_indexer",
+        "tp_size,pp_size,ep_size,mtp_nextn,fp8kv,attention_dp,cuda_graph,overlap_scheduler,max_batch_size,moe_backend,disable_skip_indexer",
         [
             (8, 1, 8, 0, False, True, True, True, 24, "_DEFAULT", False),
             (8, 1, 8, 1, False, True, True, True, 24, "_DEFAULT", False),
@@ -2621,11 +2621,11 @@ class TestDeepSeekV32(LlmapiAccuracyTestHarness):
         ],
         ids=[
             "baseline", "baseline_mtp1", "baseline_fp8kv", "latency",
-            "latency_default", "skip_indexer"
+            "latency_default", "disable_skip_indexer"
         ])
     def test_fp8_blockscale(self, tp_size, pp_size, ep_size, mtp_nextn, fp8kv,
                             attention_dp, cuda_graph, overlap_scheduler,
-                            max_batch_size, moe_backend, skip_indexer):
+                            max_batch_size, moe_backend, disable_skip_indexer):
         if get_sm_version() == 100 or get_sm_version() == 103:
             moe_backend = "DEEPGEMM" if moe_backend == "_DEFAULT" else moe_backend
         moe_config = MoeConfig(backend=moe_backend, max_num_tokens=16384)
@@ -2652,9 +2652,9 @@ class TestDeepSeekV32(LlmapiAccuracyTestHarness):
             kv_cache_config.dtype = "fp8"

         dsa_config = None
-        if skip_indexer:
+        if disable_skip_indexer:
             dsa_config = DeepSeekSparseAttentionConfig(
-                skip_indexer_for_short_seqs=True)
+                skip_indexer_for_short_seqs=False)

         mtp_config = None
         if mtp_nextn > 0:
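Because the feature is now on by default, the renamed flag opts out rather than in: the branch builds a config only to turn the indexer skip off. A minimal sketch of the inverted logic, reusing the stand-in config class from the Pydantic sketch above:

disable_skip_indexer = True  # parametrized value behind the test id

dsa_config = None
if disable_skip_indexer:
    # Pre-rename this branch enabled the feature; post-rename it disables
    # what is already the default.
    dsa_config = DeepSeekSparseAttentionConfig(
        skip_indexer_for_short_seqs=False)
print(dsa_config)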
@@ -2686,7 +2686,7 @@ class TestDeepSeekV32(LlmapiAccuracyTestHarness):
     @pytest.mark.skip_less_mpi_world_size(8)
     @skip_pre_blackwell
     @pytest.mark.parametrize(
-        "tp_size,pp_size,ep_size,mtp_nextn,fp8kv,attention_dp,cuda_graph,overlap_scheduler,max_batch_size,moe_backend,skip_indexer",
+        "tp_size,pp_size,ep_size,mtp_nextn,fp8kv,attention_dp,cuda_graph,overlap_scheduler,max_batch_size,moe_backend,disable_skip_indexer",
         [
             (8, 1, 8, 0, False, True, True, True, 24, "CUTLASS", False),
             (8, 1, 8, 1, False, True, True, True, 24, "CUTLASS", False),
@@ -2696,11 +2696,12 @@ class TestDeepSeekV32(LlmapiAccuracyTestHarness):
         ],
         ids=[
             "baseline", "baseline_mtp1", "baseline_fp8kv", "latency",
-            "skip_indexer"
+            "disable_skip_indexer"
         ])
     def test_nvfp4_multi_gpus(self, tp_size, pp_size, ep_size, mtp_nextn, fp8kv,
                               attention_dp, cuda_graph, overlap_scheduler,
-                              max_batch_size, moe_backend, skip_indexer):
+                              max_batch_size, moe_backend,
+                              disable_skip_indexer):
         sm_version = get_sm_version()
         if moe_backend == "TRTLLM" and sm_version in (120, 121):
             pytest.skip(f"{moe_backend} backend does not support SM 120 or 121")
@@ -2721,9 +2722,9 @@ class TestDeepSeekV32(LlmapiAccuracyTestHarness):
             kv_cache_config.dtype = "fp8"

         dsa_config = None
-        if skip_indexer:
+        if disable_skip_indexer:
             dsa_config = DeepSeekSparseAttentionConfig(
-                skip_indexer_for_short_seqs=True)
+                skip_indexer_for_short_seqs=False)

         mtp_config = None
         if mtp_nextn > 0:
@@ -511,12 +511,12 @@ accuracy/test_llm_api_pytorch.py::TestDeepSeekV32::test_fp8_blockscale[baseline]
 accuracy/test_llm_api_pytorch.py::TestDeepSeekV32::test_fp8_blockscale[baseline_mtp1]
 accuracy/test_llm_api_pytorch.py::TestDeepSeekV32::test_fp8_blockscale[baseline_fp8kv]
 accuracy/test_llm_api_pytorch.py::TestDeepSeekV32::test_fp8_blockscale[latency]
-accuracy/test_llm_api_pytorch.py::TestDeepSeekV32::test_fp8_blockscale[skip_indexer]
+accuracy/test_llm_api_pytorch.py::TestDeepSeekV32::test_fp8_blockscale[disable_skip_indexer]
 accuracy/test_llm_api_pytorch.py::TestDeepSeekV32::test_nvfp4_multi_gpus[baseline]
 accuracy/test_llm_api_pytorch.py::TestDeepSeekV32::test_nvfp4_multi_gpus[baseline_mtp1]
 accuracy/test_llm_api_pytorch.py::TestDeepSeekV32::test_nvfp4_multi_gpus[baseline_fp8kv]
 accuracy/test_llm_api_pytorch.py::TestDeepSeekV32::test_nvfp4_multi_gpus[latency]
-accuracy/test_llm_api_pytorch.py::TestDeepSeekV32::test_nvfp4_multi_gpus[skip_indexer]
+accuracy/test_llm_api_pytorch.py::TestDeepSeekV32::test_nvfp4_multi_gpus[disable_skip_indexer]
 accuracy/test_llm_api_pytorch.py::TestDeepSeekV32::test_nvfp4_multi_gpus_chunked_prefill[baseline_fp8kv]
 accuracy/test_llm_api_pytorch.py::TestDeepSeekV32::test_nvfp4_multi_gpus_chunked_prefill[latency]
 accuracy/test_llm_api_pytorch.py::TestGLM4_6::test_nvfp4_multi_gpus[throughput]
@@ -52,12 +52,12 @@ accuracy/test_llm_api_pytorch.py::TestDeepSeekV32::test_fp8_blockscale[baseline]
 accuracy/test_llm_api_pytorch.py::TestDeepSeekV32::test_fp8_blockscale[baseline_mtp1]
 accuracy/test_llm_api_pytorch.py::TestDeepSeekV32::test_fp8_blockscale[baseline_fp8kv]
 accuracy/test_llm_api_pytorch.py::TestDeepSeekV32::test_fp8_blockscale[latency]
-accuracy/test_llm_api_pytorch.py::TestDeepSeekV32::test_fp8_blockscale[skip_indexer]
+accuracy/test_llm_api_pytorch.py::TestDeepSeekV32::test_fp8_blockscale[disable_skip_indexer]
 accuracy/test_llm_api_pytorch.py::TestDeepSeekV32::test_nvfp4_multi_gpus[baseline]
 accuracy/test_llm_api_pytorch.py::TestDeepSeekV32::test_nvfp4_multi_gpus[baseline_mtp1]
 accuracy/test_llm_api_pytorch.py::TestDeepSeekV32::test_nvfp4_multi_gpus[baseline_fp8kv]
 accuracy/test_llm_api_pytorch.py::TestDeepSeekV32::test_nvfp4_multi_gpus[latency]
-accuracy/test_llm_api_pytorch.py::TestDeepSeekV32::test_nvfp4_multi_gpus[skip_indexer]
+accuracy/test_llm_api_pytorch.py::TestDeepSeekV32::test_nvfp4_multi_gpus[disable_skip_indexer]
 accuracy/test_llm_api_pytorch.py::TestDeepSeekV32::test_nvfp4_multi_gpus_chunked_prefill[baseline_fp8kv]
 accuracy/test_llm_api_pytorch.py::TestDeepSeekV32::test_nvfp4_multi_gpus_chunked_prefill[latency]
 accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_bfloat16_4gpus_online_eplb[mtp_nextn=0-moe_backend=WIDEEP]
@@ -102,10 +102,10 @@ l0_dgx_b200:
 - accuracy/test_llm_api_pytorch.py::TestDeepSeekR1::test_nvfp4_multi_gpus_corner_case TIMEOUT (60)
 - accuracy/test_llm_api_pytorch.py::TestDeepSeekV32::test_fp8_blockscale[baseline_fp8kv] TIMEOUT (60)
 - accuracy/test_llm_api_pytorch.py::TestDeepSeekV32::test_fp8_blockscale[latency] TIMEOUT (60)
-- accuracy/test_llm_api_pytorch.py::TestDeepSeekV32::test_fp8_blockscale[skip_indexer] TIMEOUT (60)
+- accuracy/test_llm_api_pytorch.py::TestDeepSeekV32::test_fp8_blockscale[disable_skip_indexer] TIMEOUT (60)
 - accuracy/test_llm_api_pytorch.py::TestDeepSeekV32::test_nvfp4_multi_gpus[baseline_fp8kv] TIMEOUT (60)
 - accuracy/test_llm_api_pytorch.py::TestDeepSeekV32::test_nvfp4_multi_gpus[latency] TIMEOUT (60)
-- accuracy/test_llm_api_pytorch.py::TestDeepSeekV32::test_nvfp4_multi_gpus[skip_indexer] TIMEOUT (60)
+- accuracy/test_llm_api_pytorch.py::TestDeepSeekV32::test_nvfp4_multi_gpus[disable_skip_indexer] TIMEOUT (60)
 - accuracy/test_llm_api_pytorch.py::TestDeepSeekV32::test_nvfp4_multi_gpus_chunked_prefill[baseline_fp8kv] TIMEOUT (60)
 - accuracy/test_llm_api_pytorch.py::TestMistralLarge3_675B::test_fp8[latency_moe_deepgemm] TIMEOUT (60)
 - condition:
@@ -19,7 +19,7 @@ l0_dgx_h200:
 - accuracy/test_llm_api_pytorch.py::TestDeepSeekR1::test_fp8_blockscale[latency] # 1h
 - accuracy/test_llm_api_pytorch.py::TestDeepSeekV32::test_fp8_blockscale[baseline]
 - accuracy/test_llm_api_pytorch.py::TestDeepSeekV32::test_fp8_blockscale[latency_default]
-- accuracy/test_llm_api_pytorch.py::TestDeepSeekV32::test_fp8_blockscale[skip_indexer]
+- accuracy/test_llm_api_pytorch.py::TestDeepSeekV32::test_fp8_blockscale[disable_skip_indexer]
 - accuracy/test_disaggregated_serving.py::TestLlama4ScoutInstruct::test_auto_dtype[True]
 - accuracy/test_disaggregated_serving.py::TestLlama4ScoutInstruct::test_auto_dtype[False]
 - accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype[mtp_nextn=0-overlap_scheduler=True]
@@ -537,7 +537,6 @@ accuracy/test_llm_api_pytorch.py::TestLlama3_3_70BInstruct::test_nvfp4_tp4[torch
 accuracy/test_llm_api_pytorch.py::TestLlama3_3_70BInstruct::test_nvfp4_tp4[torch_compile=True] SKIP (https://nvbugs/5777044)
 accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_eagle3_4gpus[cutlass-one_model-overlap_scheduler] SKIP (https://nvbugs/5777044)
 accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_4gpus[tp4-trtllm-auto] SKIP (https://nvbugs/5777044)
-accuracy/test_llm_api_pytorch.py::TestDeepSeekV32::test_fp8_blockscale[skip_indexer] SKIP (https://nvbugs/5779534)
 disaggregated/test_disaggregated.py::test_disaggregated_benchmark_on_diff_backends[llama-3.1-8b-instruct-hf-fp8] SKIP (https://nvbugs/5769890)
 disaggregated/test_disaggregated.py::test_disaggregated_benchmark_on_diff_backends[DeepSeek-V3-Lite-bf16] SKIP (https://nvbugs/5769890)
 disaggregated/test_disaggregated.py::test_disaggregated_benchmark_on_diff_backends[llama-v3-8b-hf] SKIP (https://nvbugs/5769890,https://nvbugs/5748683)