Change from correctness check to functional check and unwaive the test.

Signed-off-by: Zheyu Fu <zheyuf@NVIDIA.com>
This commit is contained in:
Zheyu Fu 2025-12-18 01:09:14 +00:00
parent 7175d89b48
commit 8922ca839f
2 changed files with 86 additions and 31 deletions

View File

@ -403,7 +403,6 @@ accuracy/test_llm_api_pytorch.py::TestNemotronH_56B_Base::test_auto_dtype[tp8-cu
accuracy/test_llm_api_pytorch.py::TestNemotronUltra::test_fp8_prequantized[tp8ep4-cuda_graph=True] SKIP (https://nvbugs/5707145)
accuracy/test_llm_api_pytorch.py::TestNemotronUltra::test_fp8_prequantized[tp8-cuda_graph=True] SKIP (https://nvbugs/5707145)
accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_chunked_prefill[cutlass-auto] SKIP (https://nvbugs/5596343)
unittest/_torch/speculative/test_spec_gate.py::test_spec_gate_e2e SKIP (https://nvbugs/5710045)
accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_ngram SKIP (https://nvbugs/5569696)
accuracy/test_llm_api_pytorch.py::TestDeepSeekR1::test_fp8_blockscale[throughput_mtp_trtllm] SKIP (https://nvbugs/5715568)
accuracy/test_llm_api_pytorch.py::TestDeepSeekR1::test_fp8_blockscale[throughput_mtp] SKIP (https://nvbugs/5715568)

View File

@ -1,28 +1,34 @@
import os
import sys
import unittest
from unittest.mock import patch
import pytest
import torch
from utils.llm_data import llm_models_root
from utils.util import similar, skip_blackwell
from tensorrt_llm import LLM, SamplingParams
from tensorrt_llm._torch.speculative.speculation_gate import SpeculationGate
from tensorrt_llm.llmapi import (CudaGraphConfig, EagleDecodingConfig,
KvCacheConfig)
from tensorrt_llm.logger import logger
sys.path.append(os.path.join(os.path.dirname(__file__), '..'))
# This test exercises the end-to-end functionality of the SpeculationGate,
# which turns off spec decode when the average acceptance length falls below the threshold.
# The gate is configured through the acceptance window and acceptance threshold in spec_config.
# The test sets max_concurrency to a large value so that spec decode is not turned off because the number of effective requests exceeds max_concurrency,
# which lets us focus solely on the turn-off effect of the SpeculationGate.
@skip_blackwell # TODO: Remove after fixing TRTLLM-GEN FMHA segfault on Blackwell. NVBugs: https://nvbugspro.nvidia.com/bug/5698292
# Function-scoped fixture: sets TLLM_WORKER_USE_SINGLE_PROCESS=1 in the
# environment for the duration of one test.
@pytest.fixture(scope="function")
def enforce_single_worker(monkeypatch):
"""Mock functions don't work with multiple processes, so we enforce single worker."""
# NOTE(review): presumably this variable is read by tensorrt_llm when it
# launches its worker, keeping it in the test process — confirm against
# the executor code.
monkeypatch.setenv("TLLM_WORKER_USE_SINGLE_PROCESS", "1")
# Hand control to the test; pytest's monkeypatch fixture undoes the setenv
# at teardown, so no explicit cleanup follows the yield.
yield
# Tests that the SpeculationGate correctly disables speculative decoding
# when the average acceptance length drops below the configured threshold.
# This test uses a mock to simulate a low acceptance length and verifies
# that the spec gate triggers and disables speculation.
@pytest.mark.high_cuda_memory
def test_spec_gate_e2e():
def test_spec_gate_e2e(enforce_single_worker):
total_mem_gb = torch.cuda.get_device_properties(0).total_memory / 1e9
if total_mem_gb < 35:
pytest.skip("Not enough memory to load target + draft model")
@ -32,6 +38,8 @@ def test_spec_gate_e2e():
max_batch_size = 2
max_draft_len = 4
acceptance_window = 3
acceptance_threshold = 0.6
kv_cache_config = KvCacheConfig(enable_block_reuse=True, max_tokens=8192)
cuda_graph_config = CudaGraphConfig(batch_sizes=[1])
@ -48,40 +56,88 @@ def test_spec_gate_e2e():
spec_config = EagleDecodingConfig(
max_draft_len=max_draft_len,
speculative_model_dir=eagle_model_dir,
# Llama 3 does not support one model eagle.
eagle3_one_model=False,
max_concurrency=10000,
acceptance_window=5,
acceptance_length_threshold=0.6,
acceptance_window=acceptance_window,
acceptance_length_threshold=acceptance_threshold,
)
llm_spec = LLM(**llm_common_config, speculative_config=spec_config)
# Output tests
prompts = [
"The capital of France is",
"The president of the United States is",
"What is the capital of Australia?",
"Explain in one sentence why the sky is blue.",
"Who wrote the book 'Pride and Prejudice'?",
"List three U.S. national holidays in the year 2025.",
"What is the currency of Japan?",
"How many players are on a basketball court for one team?",
"List three primary colors.",
]
sampling_params = SamplingParams(max_tokens=32, temperature=0)
sampling_params = SamplingParams(max_tokens=20, temperature=0)
# Track calls to record_avg_decoded and the disabled state
gate_state = {"record_calls": [], "gate_disabled": False}
original_record_avg_decoded = SpeculationGate.record_avg_decoded
def mock_record_avg_decoded(self,
avg_decoded_tokens_per_iter,
request_id=None):
"""
Mock that simulates low acceptance rate (1.2 tokens/iter = 0.2 accepted).
This is below the threshold of 0.6, so the gate should trigger after the window fills.
"""
# Simulate low acceptance: avg_decoded = 1.2 means accepted_len = 0.2
# This is below threshold (0.6), so gate should trigger
simulated_low_avg = 1.2
disabled_now, avg = original_record_avg_decoded(self, simulated_low_avg,
request_id)
gate_state["record_calls"].append({
"original_avg": avg_decoded_tokens_per_iter,
"simulated_avg": simulated_low_avg,
"disabled_now": disabled_now,
"avg_accept": avg,
"request_id": request_id,
})
if disabled_now:
gate_state["gate_disabled"] = True
return disabled_now, avg
llm_spec = LLM(**llm_common_config, speculative_config=spec_config)
with patch.object(SpeculationGate, 'record_avg_decoded',
mock_record_avg_decoded):
llm_spec.generate(prompts, sampling_params)
# Verify the mock was called (requests completed)
assert len(gate_state["record_calls"]
) > 0, "record_avg_decoded should have been called"
# Verify the gate was disabled after enough requests with low acceptance
assert gate_state["gate_disabled"], \
f"Gate should have been disabled with simulated low acceptance. Calls: {gate_state['record_calls']}"
# Verify the gate triggered at the right time (after window is filled)
# The gate should trigger on the `acceptance_window`-th call (index = window - 1)
disable_indices = [
i for i, call in enumerate(gate_state["record_calls"])
if call["disabled_now"]
]
assert len(disable_indices) == 1, \
f"Gate should have triggered exactly once, but triggered at indices: {disable_indices}"
assert disable_indices[0] >= acceptance_window - 1, \
f"Gate should trigger after window ({acceptance_window}) is filled, but triggered at index {disable_indices[0]}"
# Verify the average acceptance was below threshold when disabled
disable_call = gate_state["record_calls"][disable_indices[0]]
assert disable_call["avg_accept"] is not None
assert disable_call["avg_accept"] < acceptance_threshold, \
f"Avg acceptance ({disable_call['avg_accept']}) should be below threshold ({acceptance_threshold})"
logger.debug(
f"Gate correctly triggered after {disable_indices[0] + 1} requests")
logger.debug(
f"Final avg acceptance: {disable_call['avg_accept']:.3f} < threshold {acceptance_threshold}"
)
results_spec = llm_spec.generate(prompts, sampling_params)
generated_text_spec = [result.outputs[0].text for result in results_spec]
llm_spec.shutdown()
llm_ref = LLM(**llm_common_config)
results_ref = llm_ref.generate(prompts, sampling_params)
generated_text_ref = [result.outputs[0].text for result in results_ref]
llm_ref.shutdown()
for text_spec, text_ref in zip(generated_text_spec, generated_text_ref):
assert similar(text_spec, text_ref)
def test_returns_none_until_window_and_enabled_when_above_threshold():
gate = SpeculationGate(window=3, threshold=0.5)