diff --git a/tensorrt_llm/_torch/speculative/drafting_loops.py b/tensorrt_llm/_torch/speculative/drafting_loops.py
index a54fb0cbfc..bf0c8e0f6d 100644
--- a/tensorrt_llm/_torch/speculative/drafting_loops.py
+++ b/tensorrt_llm/_torch/speculative/drafting_loops.py
@@ -20,9 +20,12 @@ from tensorrt_llm._torch.speculative.interface import SpecMetadata
 @contextmanager
 def save_metadata_state(attn_metadata: AttentionMetadata,
                         spec_metadata: SpecMetadata) -> None:
-    attn_metadata.prepare_for_spec_dec("_seq_lens", "_seq_lens_cuda",
-                                       "kv_lens_cuda")
+    attn_metadata.prepare_for_spec_dec("_seq_lens", "_seq_lens_cuda")
     batch_size = attn_metadata.num_seqs
+    # Do not use prepare_for_spec_dec for this special field.
+    # TRTLLM attention uses views of this tensor internally and prepare_for_spec_dec
+    # creates a copy. If you write to the copy, TRTLLM attention won't see the updates.
+    kv_lens = attn_metadata.kv_lens_cuda[:batch_size].clone()
 
     if attn_metadata.is_cuda_graph:
         assert spec_metadata.is_cuda_graph
@@ -39,6 +42,8 @@ def save_metadata_state(attn_metadata: AttentionMetadata,
         yield
     finally:
         attn_metadata.restore_from_spec_dec()
+        attn_metadata.kv_lens_cuda[:batch_size].copy_(kv_lens)
+
         if attn_metadata.is_cuda_graph:
             spec_metadata.num_tokens = num_tokens
             if isinstance(spec_metadata, Eagle3SpecMetadata):
diff --git a/tests/unittest/_torch/speculative/test_eagle3.py b/tests/unittest/_torch/speculative/test_eagle3.py
index 8de0ac8642..bf69917ef2 100644
--- a/tests/unittest/_torch/speculative/test_eagle3.py
+++ b/tests/unittest/_torch/speculative/test_eagle3.py
@@ -16,7 +16,6 @@ from tensorrt_llm.llmapi import (CudaGraphConfig, EagleDecodingConfig,
 sys.path.append(os.path.join(os.path.dirname(__file__), '..'))
 
 
-@pytest.mark.skip(reason="https://nvbugs/5461761")
 @pytest.mark.parametrize(
     "use_cuda_graph,attn_backend,disable_overlap_scheduler,enable_block_reuse,use_one_model,enable_chunked_prefill",
     [
@@ -27,7 +26,8 @@ sys.path.append(os.path.join(os.path.dirname(__file__), '..'))
         [False, "TRTLLM", False, True, True, False],
         [True, "TRTLLM", False, True, True, False],
         [True, "TRTLLM", True, False, True, True],
-        [True, "TRTLLM", True, False, False, True],
+        # TODO: nvbugs/5461761
+        # [True, "TRTLLM", True, False, False, True],
     ])
 @pytest.mark.high_cuda_memory
 def test_llama_eagle3(use_cuda_graph: bool, attn_backend: str,
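
For context, the comment added in save_metadata_state hinges on PyTorch view-vs-copy semantics: a kernel that holds a view of a tensor only observes writes that land in that tensor's original storage, so the save/restore must go through clone() and an in-place copy_() rather than a swapped-in copy. Below is a minimal, self-contained sketch of this pattern, not the actual TRT-LLM code; the tensor names mirror the diff but are illustrative, and CPU tensors are used so it runs without a GPU.

    import torch

    # Stand-in for attn_metadata.kv_lens_cuda; a consumer (like TRTLLM
    # attention in the diff) holds a view into its storage.
    kv_lens = torch.tensor([3, 5, 7])
    consumer_view = kv_lens[:2]

    # Wrong: mutate a detached copy (what a prepare_for_spec_dec-style
    # snapshot-and-swap would do). The view never sees the update.
    copied = kv_lens.clone()
    copied[:2] += 1
    assert consumer_view.tolist() == [3, 5]  # unchanged

    # Right (the pattern in this diff): snapshot with clone(), mutate the
    # live tensor in place, then restore with copy_() into the same storage.
    saved = kv_lens[:2].clone()     # save_metadata_state: snapshot on entry
    kv_lens[:2] += 1                # drafting loop writes to the live tensor
    assert consumer_view.tolist() == [4, 6]  # the view sees the update
    kv_lens[:2].copy_(saved)        # restore on exit, storage unchanged
    assert consumer_view.tolist() == [3, 5]

This is why the restore in the finally block uses kv_lens_cuda[:batch_size].copy_(kv_lens) instead of rebinding the attribute: copy_() writes back into the storage the attention kernels already reference.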