From d31482686cc8e137e9a2692c6babc1f83acbb437 Mon Sep 17 00:00:00 2001
From: Zheyu Fu
Date: Thu, 22 Jan 2026 22:47:51 -0800
Subject: [PATCH] [https://nvbugs/5680911][fix] Remove @cache decorator to enhance CI stability for unit tests using single process mode (#10730)

Without @cache, enable_worker_single_process_for_tp1() re-reads
TLLM_WORKER_USE_SINGLE_PROCESS on every call; previously the first result was
memoized for the whole pytest process, so tests that set the variable through
a per-test fixture could not reliably switch the worker into single-process
mode. The previously waived spec-decode unit tests are re-enabled, and the
single-worker fixture is now function-scoped and opt-in instead of
module-wide.

Signed-off-by: Zheyu Fu
Signed-off-by: Wangshanshan <30051912+dominicshanshan@users.noreply.github.com>
---
 tensorrt_llm/llmapi/utils.py                |   3 +-
 tests/integration/test_lists/waives.txt     |   3 +-
 .../speculative/test_draft_len_schedule.py  |  20 ++-
 .../speculative/test_dynamic_spec_decode.py |   1 +
 .../_torch/speculative/test_spec_gate.py    | 115 +++++++++++++-----
 5 files changed, 97 insertions(+), 45 deletions(-)

diff --git a/tensorrt_llm/llmapi/utils.py b/tensorrt_llm/llmapi/utils.py
index 88e22fc639..f79823d844 100644
--- a/tensorrt_llm/llmapi/utils.py
+++ b/tensorrt_llm/llmapi/utils.py
@@ -15,7 +15,7 @@ import time
 import traceback
 import warnings
 import weakref
-from functools import cache, wraps
+from functools import wraps
 from pathlib import Path
 from queue import Queue
 from typing import (Any, Callable, Iterable, List, Optional, Tuple, Type,
@@ -355,7 +355,6 @@ def enable_llmapi_debug() -> bool:
     return _enable_llmapi_debug_
 
 
-@cache
 def enable_worker_single_process_for_tp1() -> bool:
     ''' Tell whether to make worker use single process for TP1.
    This is helpful for return-logits performance and debugging. '''
diff --git a/tests/integration/test_lists/waives.txt b/tests/integration/test_lists/waives.txt
index 191cb49be8..26db2ea7d9 100644
--- a/tests/integration/test_lists/waives.txt
+++ b/tests/integration/test_lists/waives.txt
@@ -197,7 +197,6 @@ accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_bfloat16_4gpus[t
 unittest/_torch/modules/test_fused_moe.py::test_fused_moe_alltoall[DeepEPLowLatency] SKIP (https://nvbugs/5808500)
 unittest/_torch/auto_deploy/unit/multigpu/test_ad_build_small_multi.py::test_build_ad[meta-llama/Meta-Llama-3.1-8B-Instruct-llm_extra_args0-2] SKIP (https://nvbugs/5680755)
 full:H100_PCIe/unittest/llmapi/test_llm_pytorch.py::test_llama_7b_multi_lora_evict_and_reload_lora_gpu_cache SKIP (https://nvbugs/5682551)
-unittest/_torch/speculative/test_draft_len_schedule.py::test_correctness_across_batch_sizes[model_drafter-schedule1] SKIP (https://nvbugs/5680911)
 test_e2e.py::test_openai_completions_example[trt] SKIP (https://nvbugs/5701450)
 triton_server/test_triton_llm.py::test_llmapi_backend[4-0-disableDecoupleMode-tensorrt_llm] SKIP (https://nvbugs/5701480)
 unittest/_torch/modules/tests_lora_modules/test_lora_attention_pytorch_flow_vs_trt.py::TestLoraAttentionPytorchFlowVsTRT::test_lora_attention SKIP (https://nvbugs/5701421)
@@ -207,7 +206,6 @@ accuracy/test_cli_flow.py::TestGpt2::test_int8_kv_cache SKIP (https://nvbugs/570
 accuracy/test_cli_flow.py::TestTinyLlama1_1BChat::test_weight_only_int8_kv_cache[int8] SKIP (https://nvbugs/5666826)
 disaggregated/test_disaggregated.py::test_disaggregated_ctxtp2pp2_gentp2pp2[TinyLlama-1.1B-Chat-v1.0] SKIP (https://nvbugs/5705199)
 accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_chunked_prefill[cutlass-auto] SKIP (https://nvbugs/5596343)
-unittest/_torch/speculative/test_spec_gate.py::test_spec_gate_e2e SKIP (https://nvbugs/5710045)
 accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_nvfp4[moe_backend=CUTLASS-mtp_nextn=0-fp8kv=True-attention_dp=True-cuda_graph=True-overlap_scheduler=True-torch_compile=False] SKIP (https://nvbugs/5722629)
 full:RTXPro6000D/accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_guided_decoding[xgrammar-mtp_nextn=0] SKIP (https://nvbugs/5748600)
 full:RTXPro6000D/accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_guided_decoding[xgrammar-mtp_nextn=2] SKIP (https://nvbugs/5748600)
@@ -225,6 +223,7 @@ triton_server/test_triton.py::test_opt[opt] SKIP (https://nvbugs/5739981)
 cpp/test_multi_gpu.py::TestDisagg::test_symmetric_executor[gpt-2proc-mpi_kvcache-90] SKIP (https://nvbugs/5755941)
 examples/test_granite.py::test_llm_granite[granite-3.0-1b-a400m-instruct-bfloat16] SKIP (https://nvbugs/5608979)
 examples/test_granite.py::test_llm_granite[granite-3.0-2b-instruct-bfloat16] SKIP (https://nvbugs/5608979)
+unittest/_torch/speculative/test_dynamic_spec_decode.py::test_dynamic_spec_decode SKIP (https://nvbugs/5758449)
 triton_server/test_triton.py::test_gpt_disaggregated_serving_bls[gpt-disaggregated-serving-bls] SKIP (https://nvbugs/5582118)
 triton_server/test_triton.py::test_gpt_speculative_decoding[gpt-speculative-decoding] SKIP (https://nvbugs/5762854)
 accuracy/test_llm_api_pytorch.py::TestLlama3_1_8B_Instruct_RocketKV::test_auto_dtype SKIP (https://nvbugs/5762822)
diff --git a/tests/unittest/_torch/speculative/test_draft_len_schedule.py b/tests/unittest/_torch/speculative/test_draft_len_schedule.py
index e64ca7fa53..32c491460f 100644
--- a/tests/unittest/_torch/speculative/test_draft_len_schedule.py
+++ b/tests/unittest/_torch/speculative/test_draft_len_schedule.py
@@ -13,23 +13,18 @@ from utils.util import similar
 
 
 # # ============================================================================
-# # Fixture: Force single-worker mode for all tests in this module
+# # Fixture: Force single-worker mode (only for tests that use mocking)
 # # ============================================================================
-@pytest.fixture(scope="module", autouse=True)
-def enforce_single_worker():
-    """Force single-worker mode for all tests in this module."""
-    import os
-
-    os.environ["TLLM_WORKER_USE_SINGLE_PROCESS"] = "1"
+@pytest.fixture(scope="function")
+def enforce_single_worker(monkeypatch):
+    """Mock functions don't work with multiple processes, so we enforce single worker."""
+    monkeypatch.setenv("TLLM_WORKER_USE_SINGLE_PROCESS", "1")
     yield
-    if "TLLM_WORKER_USE_SINGLE_PROCESS" in os.environ:
-        del os.environ["TLLM_WORKER_USE_SINGLE_PROCESS"]
 
 
 # # ============================================================================
 # # test 1: Generation correctness check
 # # ============================================================================
-@pytest.mark.skip("https://nvbugspro.nvidia.com/bug/5680911")
 @pytest.mark.parametrize(
     "drafter_type,schedule",
     [
@@ -151,8 +146,9 @@ def test_correctness_across_batch_sizes(drafter_type: str, schedule: dict):
     ],
 )
 @pytest.mark.high_cuda_memory
-@pytest.mark.skip("https://nvbugspro.nvidia.com/bug/5680911")
-def test_draft_len_schedule_functionality(drafter_type: str, draft_schedule: dict):
+def test_draft_len_schedule_functionality(
+    enforce_single_worker, drafter_type: str, draft_schedule: dict
+):
     if not torch.cuda.is_available():
         pytest.skip("CUDA not available")
 
diff --git a/tests/unittest/_torch/speculative/test_dynamic_spec_decode.py b/tests/unittest/_torch/speculative/test_dynamic_spec_decode.py
index 1140f53c62..9ec061f8fa 100644
--- a/tests/unittest/_torch/speculative/test_dynamic_spec_decode.py
+++ b/tests/unittest/_torch/speculative/test_dynamic_spec_decode.py
@@ -21,6 +21,7 @@ def enforce_single_worker(monkeypatch):
     yield
 
 
+@pytest.mark.skip("https://nvbugs/5758449")
 @pytest.mark.parametrize("disable_overlap_scheduler", [True, False])
 @pytest.mark.high_cuda_memory
 def test_dynamic_spec_decode(enforce_single_worker,
diff --git a/tests/unittest/_torch/speculative/test_spec_gate.py b/tests/unittest/_torch/speculative/test_spec_gate.py
index 2a3dd9b99e..49ae058c41 100644
--- a/tests/unittest/_torch/speculative/test_spec_gate.py
+++ b/tests/unittest/_torch/speculative/test_spec_gate.py
@@ -1,28 +1,34 @@
 import os
 import sys
 import unittest
+from unittest.mock import patch
 
 import pytest
 import torch
 from utils.llm_data import llm_models_root
-from utils.util import similar, skip_blackwell
 
 from tensorrt_llm import LLM, SamplingParams
 from tensorrt_llm._torch.speculative.speculation_gate import SpeculationGate
 from tensorrt_llm.llmapi import (CudaGraphConfig, Eagle3DecodingConfig,
                                  KvCacheConfig)
+from tensorrt_llm.logger import logger
 
 sys.path.append(os.path.join(os.path.dirname(__file__), '..'))
 
 
-# It tests the end-to-end functionality of the SpeculationGate,
-# which will turn off spec decode when the average acceptance length is below the threshold.
-# It is set with acceptance window and acceptance threshold in spec_config.
-# This test set the max_concurrency to a large value to prevent spec decode turned off due to number of effective requests > max_concurrency,
-# So that we can only focus on the turning off effect from the SpeculationGate.
-@skip_blackwell  # TODO: Remove after fixing TRTLLM-GEN FMHA segfault on Blackwell. NVBugs: https://nvbugspro.nvidia.com/bug/5698292
+@pytest.fixture(scope="function")
+def enforce_single_worker(monkeypatch):
+    """Mock functions don't work with multiple processes, so we enforce single worker."""
+    monkeypatch.setenv("TLLM_WORKER_USE_SINGLE_PROCESS", "1")
+    yield
+
+
+# Tests that the SpeculationGate correctly disables speculative decoding
+# when the average acceptance rate drops below the threshold.
+# This test uses a mock to simulate low acceptance rates and verifies
+# that the spec gate triggers and disables speculation.
 @pytest.mark.high_cuda_memory
-def test_spec_gate_e2e():
+def test_spec_gate_e2e(enforce_single_worker):
     total_mem_gb = torch.cuda.get_device_properties(0).total_memory / 1e9
     if total_mem_gb < 35:
         pytest.skip("Not enough memory to load target + draft model")
@@ -32,6 +38,8 @@ def test_spec_gate_e2e():
 
     max_batch_size = 2
     max_draft_len = 4
+    acceptance_window = 3
+    acceptance_threshold = 0.6
     kv_cache_config = KvCacheConfig(enable_block_reuse=True, max_tokens=8192)
     cuda_graph_config = CudaGraphConfig(batch_sizes=[1])
 
@@ -48,39 +56,88 @@ def test_spec_gate_e2e():
     spec_config = Eagle3DecodingConfig(
         max_draft_len=max_draft_len,
         speculative_model=eagle_model_dir,
-        # Llama 3 does not support one model eagle.
         eagle3_one_model=False,
         max_concurrency=10000,
-        acceptance_window=5,
-        acceptance_length_threshold=0.6,
+        acceptance_window=acceptance_window,
+        acceptance_length_threshold=acceptance_threshold,
     )
-    llm_spec = LLM(**llm_common_config, speculative_config=spec_config)
 
-    # Output tests
     prompts = [
         "The capital of France is",
         "The president of the United States is",
         "What is the capital of Australia?",
-        "Explain in one sentence why the sky is blue.",
-        "Who wrote the book 'Pride and Prejudice'?",
-        "List three U.S. national holidays in the year 2025.",
-        "What is the currency of Japan?",
-        "How many players are on a basketball court for one team?",
-        "List three primary colors.",
     ]
-    sampling_params = SamplingParams(max_tokens=32, temperature=0)
+    sampling_params = SamplingParams(max_tokens=20, temperature=0)
 
-    results_spec = llm_spec.generate(prompts, sampling_params)
-    generated_text_spec = [result.outputs[0].text for result in results_spec]
-    llm_spec.shutdown()
+    # Track calls to record_avg_decoded and the disabled state
+    gate_state = {"record_calls": [], "gate_disabled": False}
 
-    llm_ref = LLM(**llm_common_config)
-    results_ref = llm_ref.generate(prompts, sampling_params)
-    generated_text_ref = [result.outputs[0].text for result in results_ref]
-    llm_ref.shutdown()
+    original_record_avg_decoded = SpeculationGate.record_avg_decoded
 
-    for text_spec, text_ref in zip(generated_text_spec, generated_text_ref):
-        assert similar(text_spec, text_ref)
+    def mock_record_avg_decoded(self,
+                                avg_decoded_tokens_per_iter,
+                                request_id=None):
+        """
+        Mock that simulates low acceptance rate (1.2 tokens/iter = 0.2 accepted).
+        This is below the threshold of 0.6, so the gate should trigger after the window fills.
+        """
+        # Simulate low acceptance: avg_decoded = 1.2 means accepted_len = 0.2
+        # This is below threshold (0.6), so gate should trigger
+        simulated_low_avg = 1.2
+        disabled_now, avg = original_record_avg_decoded(self, simulated_low_avg,
+                                                        request_id)
+
+        gate_state["record_calls"].append({
+            "original_avg": avg_decoded_tokens_per_iter,
+            "simulated_avg": simulated_low_avg,
+            "disabled_now": disabled_now,
+            "avg_accept": avg,
+            "request_id": request_id,
+        })
+        if disabled_now:
+            gate_state["gate_disabled"] = True
+
+        return disabled_now, avg
+
+    llm_spec = LLM(**llm_common_config, speculative_config=spec_config)
+
+    try:
+        with patch.object(SpeculationGate, 'record_avg_decoded',
+                          mock_record_avg_decoded):
+            llm_spec.generate(prompts, sampling_params)
+
+        # Verify the mock was called (requests completed)
+        assert len(gate_state["record_calls"]
+                   ) > 0, "record_avg_decoded should have been called"
+
+        # Verify the gate was disabled after enough requests with low acceptance
+        assert gate_state["gate_disabled"], \
+            f"Gate should have been disabled with simulated low acceptance. Calls: {gate_state['record_calls']}"
+
+        # Verify the gate triggered at the right time (after window is filled)
+        # The gate should trigger on the `acceptance_window`-th call (index = window - 1)
+        disable_indices = [
+            i for i, call in enumerate(gate_state["record_calls"])
+            if call["disabled_now"]
+        ]
+        assert len(disable_indices) == 1, \
+            f"Gate should have triggered exactly once, but triggered at indices: {disable_indices}"
+        assert disable_indices[0] >= acceptance_window - 1, \
+            f"Gate should trigger after window ({acceptance_window}) is filled, but triggered at index {disable_indices[0]}"
+
+        # Verify the average acceptance was below threshold when disabled
+        disable_call = gate_state["record_calls"][disable_indices[0]]
+        assert disable_call["avg_accept"] is not None
+        assert disable_call["avg_accept"] < acceptance_threshold, \
+            f"Avg acceptance ({disable_call['avg_accept']}) should be below threshold ({acceptance_threshold})"
+
+        logger.debug(
+            f"Gate correctly triggered after {disable_indices[0] + 1} requests")
+        logger.debug(
+            f"Final avg acceptance: {disable_call['avg_accept']:.3f} < threshold {acceptance_threshold}"
+        )
+    finally:
+        llm_spec.shutdown()
 
 
 def test_returns_none_until_window_and_enabled_when_above_threshold():