Mirror of https://github.com/NVIDIA/TensorRT-LLM.git (synced 2026-01-14 06:27:45 +08:00)
[https://nvbugs/5502352][fix] Fix 2-model CDL path (#7543)
Signed-off-by: Mike Iovine <6158008+mikeiovine@users.noreply.github.com>
commit 45390402fc (parent: 99b98f1374)
@@ -20,9 +20,12 @@ from tensorrt_llm._torch.speculative.interface import SpecMetadata
 @contextmanager
 def save_metadata_state(attn_metadata: AttentionMetadata,
                         spec_metadata: SpecMetadata) -> None:
-    attn_metadata.prepare_for_spec_dec("_seq_lens", "_seq_lens_cuda",
-                                       "kv_lens_cuda")
+    attn_metadata.prepare_for_spec_dec("_seq_lens", "_seq_lens_cuda")
+    batch_size = attn_metadata.num_seqs
+    # Do not use prepare_for_spec_dec for this special field.
+    # TRTLLM attention uses views of this tensor internally and prepare_for_spec_dec
+    # creates a copy. If you write to the copy, TRTLLM attention won't see the updates.
+    kv_lens = attn_metadata.kv_lens_cuda[:batch_size].clone()

     if attn_metadata.is_cuda_graph:
         assert spec_metadata.is_cuda_graph
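The new comment encodes a general PyTorch rule worth spelling out: consumers that hold a view of a tensor see in-place writes to the underlying storage, but never writes made to a clone. A minimal illustration of that behavior (plain PyTorch with hypothetical values, independent of the TRT-LLM internals):

import torch

# Hypothetical stand-in for attn_metadata.kv_lens_cuda.
kv_lens = torch.tensor([3, 5, 7, 0])
internal_view = kv_lens[:3]          # e.g. what an attention kernel holds

# A clone is separate storage: writes to it are invisible through the view.
copied = kv_lens.clone()
copied[0] = 99
assert internal_view[0] == 3         # view unchanged

# In-place copy_ writes into the original storage, so the view sees it.
saved = kv_lens[:3].clone()          # snapshot, as in the hunk above
kv_lens[:3].copy_(torch.tensor([9, 9, 9]))
assert internal_view[0] == 9         # view observes the update
kv_lens[:3].copy_(saved)             # in-place restore
assert internal_view[0] == 3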
@@ -39,6 +42,8 @@ def save_metadata_state(attn_metadata: AttentionMetadata,
         yield
     finally:
         attn_metadata.restore_from_spec_dec()
+        attn_metadata.kv_lens_cuda[:batch_size].copy_(kv_lens)
+
         if attn_metadata.is_cuda_graph:
             spec_metadata.num_tokens = num_tokens
             if isinstance(spec_metadata, Eagle3SpecMetadata):
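Together, the two hunks form a snapshot/restore pattern: clone the slice on entry, mutate freely while speculating, and copy_ the snapshot back in a finally block so restoration happens even if an exception is raised, and so existing views stay valid. A self-contained sketch of the same pattern, using hypothetical names rather than the real AttentionMetadata API:

from contextlib import contextmanager
import torch

@contextmanager
def save_tensor_state(t: torch.Tensor, n: int):
    """Snapshot t[:n] on entry and restore it in place on exit,
    so external views of t keep observing the live values."""
    saved = t[:n].clone()
    try:
        yield
    finally:
        t[:n].copy_(saved)   # in-place restore; views stay valid

kv_lens = torch.tensor([3, 5, 7, 0])
with save_tensor_state(kv_lens, 3):
    kv_lens[:3].fill_(9)     # scratch mutation during speculation
assert kv_lens[0] == 3       # restored afterwards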
@@ -16,7 +16,6 @@ from tensorrt_llm.llmapi import (CudaGraphConfig, EagleDecodingConfig,
 sys.path.append(os.path.join(os.path.dirname(__file__), '..'))


-@pytest.mark.skip(reason="https://nvbugs/5461761")
 @pytest.mark.parametrize(
     "use_cuda_graph,attn_backend,disable_overlap_scheduler,enable_block_reuse,use_one_model,enable_chunked_prefill",
     [
@@ -27,7 +26,8 @@ sys.path.append(os.path.join(os.path.dirname(__file__), '..'))
         [False, "TRTLLM", False, True, True, False],
         [True, "TRTLLM", False, True, True, False],
         [True, "TRTLLM", True, False, True, True],
+        [True, "TRTLLM", True, False, False, True],
         # TODO: nvbugs/5461761
         # [True, "TRTLLM", True, False, False, True],
     ])
 @pytest.mark.high_cuda_memory
 def test_llama_eagle3(use_cuda_graph: bool, attn_backend: str,
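For readers unfamiliar with the test's structure: pytest.mark.parametrize takes a single comma-separated string of argument names plus a list of value rows, and runs the test once per row with those names bound. A toy, runnable sketch mirroring the shape of the table above (the parameter subset and test body are illustrative only, not the real test):

import pytest

@pytest.mark.parametrize(
    "use_cuda_graph,attn_backend,use_one_model",
    [
        [False, "TRTLLM", True],
        [True, "TRTLLM", False],   # e.g. a 2-model (draft + target) case
    ])
def test_config_combo(use_cuda_graph: bool, attn_backend: str,
                      use_one_model: bool):
    # Each row above becomes its own test case with these names bound.
    assert attn_backend == "TRTLLM"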