[https://nvbugs/5502352][fix] Fix 2-model CDL path (#7543)

Signed-off-by: Mike Iovine <6158008+mikeiovine@users.noreply.github.com>
Author: Mike Iovine
Date: 2025-09-06 23:53:27 -04:00 (committed by GitHub)
parent 99b98f1374
commit 45390402fc
2 changed files with 9 additions and 4 deletions

@@ -20,9 +20,12 @@ from tensorrt_llm._torch.speculative.interface import SpecMetadata
 @contextmanager
 def save_metadata_state(attn_metadata: AttentionMetadata,
                         spec_metadata: SpecMetadata) -> None:
-    attn_metadata.prepare_for_spec_dec("_seq_lens", "_seq_lens_cuda",
-                                       "kv_lens_cuda")
+    attn_metadata.prepare_for_spec_dec("_seq_lens", "_seq_lens_cuda")
+    batch_size = attn_metadata.num_seqs
+    # Do not use prepare_for_spec_dec for this special field.
+    # TRTLLM attention uses views of this tensor internally and prepare_for_spec_dec
+    # creates a copy. If you write to the copy, TRTLLM attention won't see the updates.
+    kv_lens = attn_metadata.kv_lens_cuda[:batch_size].clone()
     if attn_metadata.is_cuda_graph:
         assert spec_metadata.is_cuda_graph
@@ -39,6 +42,8 @@ def save_metadata_state(attn_metadata: AttentionMetadata,
         yield
     finally:
         attn_metadata.restore_from_spec_dec()
+        attn_metadata.kv_lens_cuda[:batch_size].copy_(kv_lens)
         if attn_metadata.is_cuda_graph:
             spec_metadata.num_tokens = num_tokens
             if isinstance(spec_metadata, Eagle3SpecMetadata):
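The inline comment above is the core of the fix: TRTLLM attention holds views of kv_lens_cuda, and prepare_for_spec_dec snapshots a field by copying it, so updates written during drafting would be invisible to those views. Below is a standalone PyTorch sketch of the pitfall and of the clone-then-copy_ save/restore pattern the hunk adopts; it is not part of the diff, the tensor names are illustrative, and CPU tensors stand in for the real CUDA field.

import torch

kv_lens = torch.tensor([3, 5, 7])  # stands in for attn_metadata.kv_lens_cuda
attention_view = kv_lens[:2]       # TRTLLM attention holds views like this

# Pitfall: writes to a copy are invisible to holders of the original's views.
snapshot = kv_lens.clone()
snapshot[0] = 99
assert attention_view[0].item() == 3  # the view never sees the update

# Pattern used by the fix: save old values, mutate in place, restore in place.
saved = kv_lens[:2].clone()
kv_lens[:2] = torch.tensor([10, 20])  # in-place write; views observe it
assert attention_view[0].item() == 10
kv_lens[:2].copy_(saved)              # restore without rebinding the tensor
assert attention_view[0].item() == 3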

@@ -16,7 +16,6 @@ from tensorrt_llm.llmapi import (CudaGraphConfig, EagleDecodingConfig,
 sys.path.append(os.path.join(os.path.dirname(__file__), '..'))
-@pytest.mark.skip(reason="https://nvbugs/5461761")
 @pytest.mark.parametrize(
     "use_cuda_graph,attn_backend,disable_overlap_scheduler,enable_block_reuse,use_one_model,enable_chunked_prefill",
     [
@@ -27,7 +26,8 @@ sys.path.append(os.path.join(os.path.dirname(__file__), '..'))
         [False, "TRTLLM", False, True, True, False],
         [True, "TRTLLM", False, True, True, False],
         [True, "TRTLLM", True, False, True, True],
-        [True, "TRTLLM", True, False, False, True],
+        # TODO: nvbugs/5461761
+        # [True, "TRTLLM", True, False, False, True],
     ])
 @pytest.mark.high_cuda_memory
 def test_llama_eagle3(use_cuda_graph: bool, attn_backend: str,
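For reference, the commit disables the failing case by commenting it out with a TODO. An alternative sketch, not what the commit does, would keep the case visible in test reports as a skip by using pytest.param with a skip mark; the parametrization and reason URL below are copied from the diff, and the test body is elided.

import pytest

@pytest.mark.parametrize(
    "use_cuda_graph,attn_backend,disable_overlap_scheduler,enable_block_reuse,use_one_model,enable_chunked_prefill",
    [
        [True, "TRTLLM", True, False, True, True],
        # Reported as "skipped" instead of silently disappearing from the run.
        pytest.param(True, "TRTLLM", True, False, False, True,
                     marks=pytest.mark.skip(reason="https://nvbugs/5461761")),
    ])
def test_llama_eagle3(use_cuda_graph, attn_backend, disable_overlap_scheduler,
                      enable_block_reuse, use_one_model, enable_chunked_prefill):
    ...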