From 8922ca839f18b72264421cb2d06eaa24f706b586 Mon Sep 17 00:00:00 2001 From: Zheyu Fu Date: Thu, 18 Dec 2025 01:09:14 +0000 Subject: [PATCH] Change from correctness check to functional check and unwaive the test. Signed-off-by: Zheyu Fu --- tests/integration/test_lists/waives.txt | 1 - .../_torch/speculative/test_spec_gate.py | 116 +++++++++++++----- 2 files changed, 86 insertions(+), 31 deletions(-) diff --git a/tests/integration/test_lists/waives.txt b/tests/integration/test_lists/waives.txt index 0d5cee3216..d8ceef7084 100644 --- a/tests/integration/test_lists/waives.txt +++ b/tests/integration/test_lists/waives.txt @@ -403,7 +403,6 @@ accuracy/test_llm_api_pytorch.py::TestNemotronH_56B_Base::test_auto_dtype[tp8-cu accuracy/test_llm_api_pytorch.py::TestNemotronUltra::test_fp8_prequantized[tp8ep4-cuda_graph=True] SKIP (https://nvbugs/5707145) accuracy/test_llm_api_pytorch.py::TestNemotronUltra::test_fp8_prequantized[tp8-cuda_graph=True] SKIP (https://nvbugs/5707145) accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_chunked_prefill[cutlass-auto] SKIP (https://nvbugs/5596343) -unittest/_torch/speculative/test_spec_gate.py::test_spec_gate_e2e SKIP (https://nvbugs/5710045) accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_ngram SKIP (https://nvbugs/5569696) accuracy/test_llm_api_pytorch.py::TestDeepSeekR1::test_fp8_blockscale[throughput_mtp_trtllm] SKIP (https://nvbugs/5715568) accuracy/test_llm_api_pytorch.py::TestDeepSeekR1::test_fp8_blockscale[throughput_mtp] SKIP (https://nvbugs/5715568) diff --git a/tests/unittest/_torch/speculative/test_spec_gate.py b/tests/unittest/_torch/speculative/test_spec_gate.py index b1720f5923..bc9e2b95f1 100644 --- a/tests/unittest/_torch/speculative/test_spec_gate.py +++ b/tests/unittest/_torch/speculative/test_spec_gate.py @@ -1,28 +1,34 @@ import os import sys import unittest +from unittest.mock import patch import pytest import torch from utils.llm_data import llm_models_root -from utils.util import similar, skip_blackwell from tensorrt_llm import LLM, SamplingParams from tensorrt_llm._torch.speculative.speculation_gate import SpeculationGate from tensorrt_llm.llmapi import (CudaGraphConfig, EagleDecodingConfig, KvCacheConfig) +from tensorrt_llm.logger import logger sys.path.append(os.path.join(os.path.dirname(__file__), '..')) -# It tests the end-to-end functionality of the SpeculationGate, -# which will turn off spec decode when the average acceptance length is below the threshold. -# It is set with acceptance window and acceptance threshold in spec_config. -# This test set the max_concurrency to a large value to prevent spec decode turned off due to number of effective requests > max_concurrency, -# So that we can only focus on the turning off effect from the SpeculationGate. -@skip_blackwell # TODO: Remove after fixing TRTLLM-GEN FMHA segfault on Blackwell. NVBugs: https://nvbugspro.nvidia.com/bug/5698292 +@pytest.fixture(scope="function") +def enforce_single_worker(monkeypatch): + """Mock functions don't work with multiple processes, so we enforce single worker.""" + monkeypatch.setenv("TLLM_WORKER_USE_SINGLE_PROCESS", "1") + yield + + +# Tests that the SpeculationGate correctly disables speculative decoding +# when the average acceptance rate drops below the threshold. +# This test uses a mock to simulate low acceptance rates and verifies +# that the spec gate triggers and disables speculation. @pytest.mark.high_cuda_memory -def test_spec_gate_e2e(): +def test_spec_gate_e2e(enforce_single_worker): total_mem_gb = torch.cuda.get_device_properties(0).total_memory / 1e9 if total_mem_gb < 35: pytest.skip("Not enough memory to load target + draft model") @@ -32,6 +38,8 @@ def test_spec_gate_e2e(): max_batch_size = 2 max_draft_len = 4 + acceptance_window = 3 + acceptance_threshold = 0.6 kv_cache_config = KvCacheConfig(enable_block_reuse=True, max_tokens=8192) cuda_graph_config = CudaGraphConfig(batch_sizes=[1]) @@ -48,40 +56,88 @@ def test_spec_gate_e2e(): spec_config = EagleDecodingConfig( max_draft_len=max_draft_len, speculative_model_dir=eagle_model_dir, - # Llama 3 does not support one model eagle. eagle3_one_model=False, max_concurrency=10000, - acceptance_window=5, - acceptance_length_threshold=0.6, + acceptance_window=acceptance_window, + acceptance_length_threshold=acceptance_threshold, ) - llm_spec = LLM(**llm_common_config, speculative_config=spec_config) - # Output tests prompts = [ "The capital of France is", "The president of the United States is", "What is the capital of Australia?", - "Explain in one sentence why the sky is blue.", - "Who wrote the book 'Pride and Prejudice'?", - "List three U.S. national holidays in the year 2025.", - "What is the currency of Japan?", - "How many players are on a basketball court for one team?", - "List three primary colors.", ] - sampling_params = SamplingParams(max_tokens=32, temperature=0) + sampling_params = SamplingParams(max_tokens=20, temperature=0) + + # Track calls to record_avg_decoded and the disabled state + gate_state = {"record_calls": [], "gate_disabled": False} + + original_record_avg_decoded = SpeculationGate.record_avg_decoded + + def mock_record_avg_decoded(self, + avg_decoded_tokens_per_iter, + request_id=None): + """ + Mock that simulates low acceptance rate (1.2 tokens/iter = 0.2 accepted). + This is below the threshold of 0.6, so the gate should trigger after the window fills. + """ + # Simulate low acceptance: avg_decoded = 1.2 means accepted_len = 0.2 + # This is below threshold (0.6), so gate should trigger + simulated_low_avg = 1.2 + disabled_now, avg = original_record_avg_decoded(self, simulated_low_avg, + request_id) + + gate_state["record_calls"].append({ + "original_avg": avg_decoded_tokens_per_iter, + "simulated_avg": simulated_low_avg, + "disabled_now": disabled_now, + "avg_accept": avg, + "request_id": request_id, + }) + if disabled_now: + gate_state["gate_disabled"] = True + + return disabled_now, avg + + llm_spec = LLM(**llm_common_config, speculative_config=spec_config) + + with patch.object(SpeculationGate, 'record_avg_decoded', + mock_record_avg_decoded): + llm_spec.generate(prompts, sampling_params) + + # Verify the mock was called (requests completed) + assert len(gate_state["record_calls"] + ) > 0, "record_avg_decoded should have been called" + + # Verify the gate was disabled after enough requests with low acceptance + assert gate_state["gate_disabled"], \ + f"Gate should have been disabled with simulated low acceptance. Calls: {gate_state['record_calls']}" + + # Verify the gate triggered at the right time (after window is filled) + # The gate should trigger on the `acceptance_window`-th call (index = window - 1) + disable_indices = [ + i for i, call in enumerate(gate_state["record_calls"]) + if call["disabled_now"] + ] + assert len(disable_indices) == 1, \ + f"Gate should have triggered exactly once, but triggered at indices: {disable_indices}" + assert disable_indices[0] >= acceptance_window - 1, \ + f"Gate should trigger after window ({acceptance_window}) is filled, but triggered at index {disable_indices[0]}" + + # Verify the average acceptance was below threshold when disabled + disable_call = gate_state["record_calls"][disable_indices[0]] + assert disable_call["avg_accept"] is not None + assert disable_call["avg_accept"] < acceptance_threshold, \ + f"Avg acceptance ({disable_call['avg_accept']}) should be below threshold ({acceptance_threshold})" + + logger.debug( + f"Gate correctly triggered after {disable_indices[0] + 1} requests") + logger.debug( + f"Final avg acceptance: {disable_call['avg_accept']:.3f} < threshold {acceptance_threshold}" + ) - results_spec = llm_spec.generate(prompts, sampling_params) - generated_text_spec = [result.outputs[0].text for result in results_spec] llm_spec.shutdown() - llm_ref = LLM(**llm_common_config) - results_ref = llm_ref.generate(prompts, sampling_params) - generated_text_ref = [result.outputs[0].text for result in results_ref] - llm_ref.shutdown() - - for text_spec, text_ref in zip(generated_text_spec, generated_text_ref): - assert similar(text_spec, text_ref) - def test_returns_none_until_window_and_enabled_when_above_threshold(): gate = SpeculationGate(window=3, threshold=0.5)