[https://nvbugs/5779534][fix] fix buffer reuse for CUDA graph attention metadata (#10393)

Signed-off-by: Fanrong Li <23290157+lfr-0531@users.noreply.github.com>
Fanrong Li 2026-01-05 09:43:44 +08:00 committed by GitHub
parent da0830670a
commit b5a1e10bc0
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
12 changed files with 25 additions and 26 deletions

View File

@@ -130,7 +130,7 @@ class FlashInferAttentionMetadata(AttentionMetadata):
self._post_init_with_buffers(self.cuda_graph_buffers)
def _post_init_with_buffers(self, buffers) -> None:
-capture_graph = torch.cuda.is_current_stream_capturing()
+capture_graph = self.is_cuda_graph
if self.workspace_buffer is None:
# Note: even though flashinfer only recommends 128 MB, we have to push it
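Every metadata class touched by this commit makes the same one-line change, so the pattern is worth spelling out once. torch.cuda.is_current_stream_capturing() is only True while a CUDA graph is actively being captured, but _post_init_with_buffers runs when the metadata object is constructed, outside capture, so the old check left capture_graph False and, per the commit title, broke reuse of the persistent cuda_graph_buffers pool. Keying the flag on self.is_cuda_graph ties buffer handling to what the metadata is for rather than to when it is built. A minimal sketch of the idea, with an assumed, simplified get_empty helper (the real TRT-LLM signature may differ):

import torch

class AttentionMetadataSketch:
    # Simplified stand-in for the metadata classes in this diff.

    def __init__(self, is_cuda_graph: bool, max_num_sequences: int,
                 cuda_graph_buffers: dict):
        self.is_cuda_graph = is_cuda_graph
        self.cuda_graph_buffers = cuda_graph_buffers

        # Old: capture_graph = torch.cuda.is_current_stream_capturing()
        #      -> always False here, because __init__ runs outside capture.
        capture_graph = self.is_cuda_graph

        self.prompt_lens_cuda = self.get_empty(
            self.cuda_graph_buffers, (max_num_sequences,),
            cache_name="prompt_lens_cuda", capture_graph=capture_graph)

    @staticmethod
    def get_empty(buffers, shape, *, cache_name, capture_graph,
                  dtype=torch.int32, device="cuda"):
        # Assumed helper: for CUDA-graph metadata, hand out one persistent
        # buffer per cache name so repeated captures reuse the same storage;
        # otherwise a throwaway allocation is fine.
        if capture_graph:
            if cache_name not in buffers:
                buffers[cache_name] = torch.zeros(shape, dtype=dtype, device=device)
            return buffers[cache_name]
        return torch.zeros(shape, dtype=dtype, device=device)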

View File

@@ -328,7 +328,7 @@ class DSAtrtllmAttentionMetadata(TrtllmAttentionMetadata):
self.sparse_mla_topk = self.sparse_attention_config.index_topk
self.enable_indexer_skip = self.sparse_attention_config.skip_indexer_for_short_seqs
-capture_graph = torch.cuda.is_current_stream_capturing()
+capture_graph = self.is_cuda_graph
self.indexer_k_cache_block_offsets = self.get_empty(
self.cuda_graph_buffers,
@@ -550,7 +550,7 @@ class DSAtrtllmAttentionMetadata(TrtllmAttentionMetadata):
self.max_draft_tokens = max_draft_len
init_shape = self.kv_lens_expanded_host.shape[0]
if self.max_num_sequences * (1 + self.max_draft_tokens) != init_shape:
-capture_graph = torch.cuda.is_current_stream_capturing()
+capture_graph = self.is_cuda_graph
self.create_expanded_buffers(capture_graph=capture_graph)
def prepare_dense_topk_indices(self,
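The second hunk guards the speculative-decoding buffers the same way: the expanded kv-lens tensors hold one slot per (sequence, draft-token) pair, so they are rebuilt only when max_draft_tokens changes the required size, and the rebuild now passes the CUDA-graph intent rather than the capture state. A self-contained sketch of that guard (simplified; names mirror the diff but the body is an assumed stand-in):

import torch

max_num_sequences = 16
kv_lens_expanded_host = torch.zeros(max_num_sequences, dtype=torch.int32)  # no draft tokens yet

def update_spec_dec_param(max_draft_len: int) -> None:
    global kv_lens_expanded_host
    required = max_num_sequences * (1 + max_draft_len)
    if kv_lens_expanded_host.shape[0] != required:
        # Reallocate only on a size change; the real create_expanded_buffers
        # also routes through the shared pool with capture_graph=self.is_cuda_graph.
        kv_lens_expanded_host = torch.zeros(required, dtype=torch.int32)

update_spec_dec_param(max_draft_len=3)
assert kv_lens_expanded_host.shape[0] == 64  # 16 sequences * (1 + 3 draft tokens)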

View File

@@ -48,7 +48,7 @@ class RocketTrtllmAttentionMetadata(TrtllmAttentionMetadata):
assert self.page_size == next_power_of_2(
self.page_size), "Page size must be a power of 2"
-capture_graph = torch.cuda.is_current_stream_capturing()
+capture_graph = self.is_cuda_graph
# Cumulative valid sequence lengths for query and key
self.q_cu_seqlens_cuda = self.get_empty(
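Aside from the one-line fix, the assertion in this hunk requires the page size to be a power of two; a value passes exactly when a next-power-of-two helper maps it to itself. A small sketch of such a helper (the actual next_power_of_2 used here lives elsewhere in the codebase and may be implemented differently):

def next_power_of_2(n: int) -> int:
    # Smallest power of two >= n.
    return 1 if n <= 1 else 1 << (n - 1).bit_length()

assert next_power_of_2(64) == 64  # a power-of-two page size satisfies the assert
assert next_power_of_2(48) == 64  # 48 would trip it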

View File

@@ -737,7 +737,7 @@ class TrtllmAttentionMetadata(AttentionMetadata):
if self.max_num_sequences is None:
self.max_num_sequences = self.max_num_requests
-capture_graph = torch.cuda.is_current_stream_capturing()
+capture_graph = self.is_cuda_graph
self.prompt_lens_cuda = self.get_empty(
buffers,
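For context on why these metadata tensors must be persistent at all: CUDA graph replay re-issues the captured kernels against the exact device addresses recorded at capture time, so inputs like prompt_lens_cuda have to be allocated once and updated in place between replays. A minimal capture/replay sketch (illustrative only; real code also warms the ops up before capturing):

import torch

lens = torch.zeros(8, dtype=torch.int32, device="cuda")   # persistent input buffer
out = torch.zeros(8, dtype=torch.float32, device="cuda")  # persistent output buffer

graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(graph):
    # Only inside this block is torch.cuda.is_current_stream_capturing() True;
    # metadata construction happens earlier, outside capture.
    out.copy_(lens.float() * 2)

lens.copy_(torch.arange(8, dtype=torch.int32, device="cuda"))  # update in place
graph.replay()  # replayed kernels read the new values through the same buffer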

View File

@@ -427,7 +427,7 @@ class ModelConfig(Generic[TConfig]):
index_head_dim = pretrained_config.index_head_dim
index_topk = pretrained_config.index_topk
indexer_max_chunk_size = None
-skip_indexer_for_short_seqs = False
+skip_indexer_for_short_seqs = True
kwargs[
'sparse_attention_config'] = DeepSeekSparseAttentionConfig(
index_n_heads=index_n_heads,

View File

@@ -287,9 +287,8 @@ class DeepSeekSparseAttentionConfig(BaseSparseAttentionConfig):
description="The topk for the indexer.")
indexer_max_chunk_size: Optional[int] = Field(
default=None, description="The maximum chunk size for the indexer.")
-# TODO: enable this by default once the memory usage in attention metadata is optimized
skip_indexer_for_short_seqs: bool = Field(
-default=False,
+default=True,
description=
"Whether to skip the MQA and Top-K in the indexer for short sequences.")

View File

@@ -2610,7 +2610,7 @@ class TestDeepSeekV32(LlmapiAccuracyTestHarness):
@skip_pre_hopper
@pytest.mark.skip_less_device_memory(140000)
@pytest.mark.parametrize(
"tp_size,pp_size,ep_size,mtp_nextn,fp8kv,attention_dp,cuda_graph,overlap_scheduler,max_batch_size,moe_backend,skip_indexer",
"tp_size,pp_size,ep_size,mtp_nextn,fp8kv,attention_dp,cuda_graph,overlap_scheduler,max_batch_size,moe_backend,disable_skip_indexer",
[
(8, 1, 8, 0, False, True, True, True, 24, "_DEFAULT", False),
(8, 1, 8, 1, False, True, True, True, 24, "_DEFAULT", False),
@@ -2621,11 +2621,11 @@ class TestDeepSeekV32(LlmapiAccuracyTestHarness):
],
ids=[
"baseline", "baseline_mtp1", "baseline_fp8kv", "latency",
"latency_default", "skip_indexer"
"latency_default", "disable_skip_indexer"
])
def test_fp8_blockscale(self, tp_size, pp_size, ep_size, mtp_nextn, fp8kv,
attention_dp, cuda_graph, overlap_scheduler,
-max_batch_size, moe_backend, skip_indexer):
+max_batch_size, moe_backend, disable_skip_indexer):
if get_sm_version() == 100 or get_sm_version() == 103:
moe_backend = "DEEPGEMM" if moe_backend == "_DEFAULT" else moe_backend
moe_config = MoeConfig(backend=moe_backend, max_num_tokens=16384)
@@ -2652,9 +2652,9 @@ class TestDeepSeekV32(LlmapiAccuracyTestHarness):
kv_cache_config.dtype = "fp8"
dsa_config = None
-if skip_indexer:
+if disable_skip_indexer:
dsa_config = DeepSeekSparseAttentionConfig(
-skip_indexer_for_short_seqs=True)
+skip_indexer_for_short_seqs=False)
mtp_config = None
if mtp_nextn > 0:
@@ -2686,7 +2686,7 @@ class TestDeepSeekV32(LlmapiAccuracyTestHarness):
@pytest.mark.skip_less_mpi_world_size(8)
@skip_pre_blackwell
@pytest.mark.parametrize(
"tp_size,pp_size,ep_size,mtp_nextn,fp8kv,attention_dp,cuda_graph,overlap_scheduler,max_batch_size,moe_backend,skip_indexer",
"tp_size,pp_size,ep_size,mtp_nextn,fp8kv,attention_dp,cuda_graph,overlap_scheduler,max_batch_size,moe_backend,disable_skip_indexer",
[
(8, 1, 8, 0, False, True, True, True, 24, "CUTLASS", False),
(8, 1, 8, 1, False, True, True, True, 24, "CUTLASS", False),
@@ -2696,11 +2696,12 @@ class TestDeepSeekV32(LlmapiAccuracyTestHarness):
],
ids=[
"baseline", "baseline_mtp1", "baseline_fp8kv", "latency",
"skip_indexer"
"disable_skip_indexer"
])
def test_nvfp4_multi_gpus(self, tp_size, pp_size, ep_size, mtp_nextn, fp8kv,
attention_dp, cuda_graph, overlap_scheduler,
-max_batch_size, moe_backend, skip_indexer):
+max_batch_size, moe_backend,
+disable_skip_indexer):
sm_version = get_sm_version()
if moe_backend == "TRTLLM" and sm_version in (120, 121):
pytest.skip(f"{moe_backend} backend does not support SM 120 or 121")
@@ -2721,9 +2721,9 @@ class TestDeepSeekV32(LlmapiAccuracyTestHarness):
kv_cache_config.dtype = "fp8"
dsa_config = None
-if skip_indexer:
+if disable_skip_indexer:
dsa_config = DeepSeekSparseAttentionConfig(
-skip_indexer_for_short_seqs=True)
+skip_indexer_for_short_seqs=False)
mtp_config = None
if mtp_nextn > 0:

View File

@@ -511,12 +511,12 @@ accuracy/test_llm_api_pytorch.py::TestDeepSeekV32::test_fp8_blockscale[baseline]
accuracy/test_llm_api_pytorch.py::TestDeepSeekV32::test_fp8_blockscale[baseline_mtp1]
accuracy/test_llm_api_pytorch.py::TestDeepSeekV32::test_fp8_blockscale[baseline_fp8kv]
accuracy/test_llm_api_pytorch.py::TestDeepSeekV32::test_fp8_blockscale[latency]
-accuracy/test_llm_api_pytorch.py::TestDeepSeekV32::test_fp8_blockscale[skip_indexer]
+accuracy/test_llm_api_pytorch.py::TestDeepSeekV32::test_fp8_blockscale[disable_skip_indexer]
accuracy/test_llm_api_pytorch.py::TestDeepSeekV32::test_nvfp4_multi_gpus[baseline]
accuracy/test_llm_api_pytorch.py::TestDeepSeekV32::test_nvfp4_multi_gpus[baseline_mtp1]
accuracy/test_llm_api_pytorch.py::TestDeepSeekV32::test_nvfp4_multi_gpus[baseline_fp8kv]
accuracy/test_llm_api_pytorch.py::TestDeepSeekV32::test_nvfp4_multi_gpus[latency]
-accuracy/test_llm_api_pytorch.py::TestDeepSeekV32::test_nvfp4_multi_gpus[skip_indexer]
+accuracy/test_llm_api_pytorch.py::TestDeepSeekV32::test_nvfp4_multi_gpus[disable_skip_indexer]
accuracy/test_llm_api_pytorch.py::TestDeepSeekV32::test_nvfp4_multi_gpus_chunked_prefill[baseline_fp8kv]
accuracy/test_llm_api_pytorch.py::TestDeepSeekV32::test_nvfp4_multi_gpus_chunked_prefill[latency]
accuracy/test_llm_api_pytorch.py::TestGLM4_6::test_nvfp4_multi_gpus[throughput]

View File

@@ -52,12 +52,12 @@ accuracy/test_llm_api_pytorch.py::TestDeepSeekV32::test_fp8_blockscale[baseline]
accuracy/test_llm_api_pytorch.py::TestDeepSeekV32::test_fp8_blockscale[baseline_mtp1]
accuracy/test_llm_api_pytorch.py::TestDeepSeekV32::test_fp8_blockscale[baseline_fp8kv]
accuracy/test_llm_api_pytorch.py::TestDeepSeekV32::test_fp8_blockscale[latency]
-accuracy/test_llm_api_pytorch.py::TestDeepSeekV32::test_fp8_blockscale[skip_indexer]
+accuracy/test_llm_api_pytorch.py::TestDeepSeekV32::test_fp8_blockscale[disable_skip_indexer]
accuracy/test_llm_api_pytorch.py::TestDeepSeekV32::test_nvfp4_multi_gpus[baseline]
accuracy/test_llm_api_pytorch.py::TestDeepSeekV32::test_nvfp4_multi_gpus[baseline_mtp1]
accuracy/test_llm_api_pytorch.py::TestDeepSeekV32::test_nvfp4_multi_gpus[baseline_fp8kv]
accuracy/test_llm_api_pytorch.py::TestDeepSeekV32::test_nvfp4_multi_gpus[latency]
-accuracy/test_llm_api_pytorch.py::TestDeepSeekV32::test_nvfp4_multi_gpus[skip_indexer]
+accuracy/test_llm_api_pytorch.py::TestDeepSeekV32::test_nvfp4_multi_gpus[disable_skip_indexer]
accuracy/test_llm_api_pytorch.py::TestDeepSeekV32::test_nvfp4_multi_gpus_chunked_prefill[baseline_fp8kv]
accuracy/test_llm_api_pytorch.py::TestDeepSeekV32::test_nvfp4_multi_gpus_chunked_prefill[latency]
accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_bfloat16_4gpus_online_eplb[mtp_nextn=0-moe_backend=WIDEEP]

View File

@@ -102,10 +102,10 @@ l0_dgx_b200:
- accuracy/test_llm_api_pytorch.py::TestDeepSeekR1::test_nvfp4_multi_gpus_corner_case TIMEOUT (60)
- accuracy/test_llm_api_pytorch.py::TestDeepSeekV32::test_fp8_blockscale[baseline_fp8kv] TIMEOUT (60)
- accuracy/test_llm_api_pytorch.py::TestDeepSeekV32::test_fp8_blockscale[latency] TIMEOUT (60)
-- accuracy/test_llm_api_pytorch.py::TestDeepSeekV32::test_fp8_blockscale[skip_indexer] TIMEOUT (60)
+- accuracy/test_llm_api_pytorch.py::TestDeepSeekV32::test_fp8_blockscale[disable_skip_indexer] TIMEOUT (60)
- accuracy/test_llm_api_pytorch.py::TestDeepSeekV32::test_nvfp4_multi_gpus[baseline_fp8kv] TIMEOUT (60)
- accuracy/test_llm_api_pytorch.py::TestDeepSeekV32::test_nvfp4_multi_gpus[latency] TIMEOUT (60)
-- accuracy/test_llm_api_pytorch.py::TestDeepSeekV32::test_nvfp4_multi_gpus[skip_indexer] TIMEOUT (60)
+- accuracy/test_llm_api_pytorch.py::TestDeepSeekV32::test_nvfp4_multi_gpus[disable_skip_indexer] TIMEOUT (60)
- accuracy/test_llm_api_pytorch.py::TestDeepSeekV32::test_nvfp4_multi_gpus_chunked_prefill[baseline_fp8kv] TIMEOUT (60)
- accuracy/test_llm_api_pytorch.py::TestMistralLarge3_675B::test_fp8[latency_moe_deepgemm] TIMEOUT (60)
- condition:

View File

@@ -19,7 +19,7 @@ l0_dgx_h200:
- accuracy/test_llm_api_pytorch.py::TestDeepSeekR1::test_fp8_blockscale[latency] # 1h
- accuracy/test_llm_api_pytorch.py::TestDeepSeekV32::test_fp8_blockscale[baseline]
- accuracy/test_llm_api_pytorch.py::TestDeepSeekV32::test_fp8_blockscale[latency_default]
-- accuracy/test_llm_api_pytorch.py::TestDeepSeekV32::test_fp8_blockscale[skip_indexer]
+- accuracy/test_llm_api_pytorch.py::TestDeepSeekV32::test_fp8_blockscale[disable_skip_indexer]
- accuracy/test_disaggregated_serving.py::TestLlama4ScoutInstruct::test_auto_dtype[True]
- accuracy/test_disaggregated_serving.py::TestLlama4ScoutInstruct::test_auto_dtype[False]
- accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype[mtp_nextn=0-overlap_scheduler=True]

View File

@@ -537,7 +537,6 @@ accuracy/test_llm_api_pytorch.py::TestLlama3_3_70BInstruct::test_nvfp4_tp4[torch
accuracy/test_llm_api_pytorch.py::TestLlama3_3_70BInstruct::test_nvfp4_tp4[torch_compile=True] SKIP (https://nvbugs/5777044)
accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_eagle3_4gpus[cutlass-one_model-overlap_scheduler] SKIP (https://nvbugs/5777044)
accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_4gpus[tp4-trtllm-auto] SKIP (https://nvbugs/5777044)
-accuracy/test_llm_api_pytorch.py::TestDeepSeekV32::test_fp8_blockscale[skip_indexer] SKIP (https://nvbugs/5779534)
disaggregated/test_disaggregated.py::test_disaggregated_benchmark_on_diff_backends[llama-3.1-8b-instruct-hf-fp8] SKIP (https://nvbugs/5769890)
disaggregated/test_disaggregated.py::test_disaggregated_benchmark_on_diff_backends[DeepSeek-V3-Lite-bf16] SKIP (https://nvbugs/5769890)
disaggregated/test_disaggregated.py::test_disaggregated_benchmark_on_diff_backends[llama-v3-8b-hf] SKIP (https://nvbugs/5769890,https://nvbugs/5748683)