Mirror of https://github.com/NVIDIA/TensorRT-LLM.git, synced 2026-01-14 06:27:45 +08:00
[https://nvbugs/5404046][fix] Fix Nemotron-H flaky CUDA graph / overlap scheduler test (#6485)
Signed-off-by: Tomer Asida <57313761+tomeras91@users.noreply.github.com>
parent 0c42f54a39
commit 6d5da9f7c2
@@ -1,4 +1,3 @@
-import pytest
 import torch
 from utils.llm_data import llm_models_root
 from utils.util import skip_gpu_memory_less_than
@@ -238,15 +237,15 @@ def test_nemotron_h_correctness():
     nemotron_h.shutdown()


-@pytest.mark.skip(reason="https://nvbugs/5404046")
 def test_nemotron_h_cuda_graph_overlap_scheduler():
     prompts = [
         "Tell me something I don't know about the future of AI",
         "The president of the United States is",
         "The capital of France is",
         "Hello, this is a beautiful day and I'm eager to start my day and",
         "The sky is blue because",
         "The sum of two and two is",
         "The largest mammal is the",
         "The chemical symbol for water is",
     ]
-    sampling_config = SamplingParams(max_tokens=12,
+    sampling_config = SamplingParams(max_tokens=10,
                                      temperature=0.0,
                                      return_generation_logits=True)
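For context, the sampling setup above feeds the LLM API's generate() call that appears in the next hunk. A minimal sketch of that flow, assuming tensorrt_llm.LLM and a hypothetical model path (the real test resolves the checkpoint via llm_models_root() and builds three LLM instances with CUDA graphs and the overlap scheduler toggled):

    from tensorrt_llm import LLM, SamplingParams

    llm = LLM(model="/path/to/Nemotron-H")  # hypothetical path, not the test's

    # temperature=0.0 makes decoding greedy, so the three configurations are
    # directly comparable; return_generation_logits=True exposes per-step
    # logits for the numeric checks below.
    sampling_config = SamplingParams(max_tokens=10,
                                     temperature=0.0,
                                     return_generation_logits=True)

    outputs = llm.generate(["The capital of France is"],
                           sampling_params=sampling_config,
                           use_tqdm=True)

    print(outputs[0].outputs[0].text)
    print(outputs[0].outputs[0].generation_logits.shape)  # roughly [max_tokens, vocab_size]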
@@ -273,32 +272,46 @@ def test_nemotron_h_cuda_graph_overlap_scheduler():
         prompts, sampling_params=sampling_config, use_tqdm=True)

     # Verify outputs are consistent
-    for (no_cg_no_overlap, with_cg_no_overlap,
-         with_cg_with_overlap) in zip(outputs_no_cg_no_overlap,
-                                      outputs_with_cg_no_overlap,
-                                      outputs_with_cg_with_overlap):
+    for i, (no_cg_no_overlap, with_cg_no_overlap,
+            with_cg_with_overlap) in enumerate(
+                zip(outputs_no_cg_no_overlap, outputs_with_cg_no_overlap,
+                    outputs_with_cg_with_overlap)):

-        assert (no_cg_no_overlap.outputs[0].text ==
-                with_cg_no_overlap.outputs[0].text)
-        assert (with_cg_no_overlap.outputs[0].text ==
-                with_cg_with_overlap.outputs[0].text)
+        assert (
+            no_cg_no_overlap.outputs[0].text ==
+            with_cg_no_overlap.outputs[0].text
+        ), f"Prompt {i}: no CG no overlap generated text != with CG no overlap generated text"
+        assert (
+            with_cg_no_overlap.outputs[0].text ==
+            with_cg_with_overlap.outputs[0].text
+        ), f"Prompt {i}: with CG no overlap generated text != with CG with overlap generated text"

         # similar to other unittests comparing with / without CG, compare logits of first generation step (2nd generated token)
         torch.testing.assert_close(
             no_cg_no_overlap.outputs[0].generation_logits[1, :],
             with_cg_no_overlap.outputs[0].generation_logits[1, :],
             atol=0.2,
-            rtol=0.2)
+            rtol=0.2,
+            msg=lambda x:
+            f"Prompt {i}: with/without CG (no overlap) logits for first generated step {x}"
+        )

         # compare logprobs of all generated tokens
-        torch.testing.assert_close(extract_decode_logprobs(no_cg_no_overlap),
-                                   extract_decode_logprobs(with_cg_no_overlap),
-                                   atol=0.2,
-                                   rtol=0.2)
+        torch.testing.assert_close(
+            extract_decode_logprobs(no_cg_no_overlap),
+            extract_decode_logprobs(with_cg_no_overlap),
+            atol=0.2,
+            rtol=0.2,
+            msg=lambda x:
+            f"Prompt {i}: with/without CG (no overlap) logprobs for all selected tokens {x}"
+        )

         # overlap scheduler should have no effect on all logits - low tolerance
         torch.testing.assert_close(
             with_cg_no_overlap.outputs[0].generation_logits,
             with_cg_with_overlap.outputs[0].generation_logits,
             atol=0.05,
-            rtol=0.05)
+            rtol=0.05,
+            msg=lambda x:
+            f"Prompt {i}: with/without overlap (with CG) all generation logits {x}"
+        )
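The extract_decode_logprobs helper compared above is defined earlier in the test file and is not part of this diff. A plausible minimal sketch of what it computes, assuming it gathers the log-probability of each token the model actually selected from the returned generation logits (the helper name and output attributes mirror the diff; the body is an assumption):

    import torch
    import torch.nn.functional as F

    def extract_decode_logprobs(request_output):
        # Hypothetical reconstruction, not the repository's implementation.
        completion = request_output.outputs[0]
        logits = completion.generation_logits           # [num_generated_tokens, vocab_size]
        token_ids = torch.tensor(completion.token_ids)  # token chosen at each decode step
        logprobs = F.log_softmax(logits.float(), dim=-1)
        # one log-prob per generated step: logprobs[step, token_ids[step]]
        return logprobs[torch.arange(logits.shape[0]), token_ids]

Log-probabilities are invariant to a constant shift of the logits at a given step, so this check complements the raw-logit comparison rather than duplicating it.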
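The msg=lambda x: ... arguments added in this commit are what make a flaky failure attributable to a specific prompt: torch.testing.assert_close accepts msg either as a string or as a callable that receives the generated mismatch report and returns the final error message. A self-contained illustration with made-up tensors:

    import torch

    a = torch.zeros(3)
    b = torch.tensor([0.0, 0.0, 1.0])

    try:
        torch.testing.assert_close(
            a, b, atol=0.05, rtol=0.05,
            # assert_close invokes the callable with its own diagnostic text (x)
            # and raises an AssertionError carrying whatever string it returns.
            msg=lambda x: f"Prompt 2: all generation logits {x}")
    except AssertionError as err:
        print(err)  # Prompt 2: all generation logits Tensor-likes are not close! ...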