From de6931bbfd524a20de319bce71aa6dd4ad029f77 Mon Sep 17 00:00:00 2001 From: Gal Hubara-Agam <96368689+galagam@users.noreply.github.com> Date: Wed, 4 Feb 2026 09:01:34 +0200 Subject: [PATCH] [None][fix] Fix selective_state_update perf regression for T=1 decode path (#11194) Signed-off-by: Gal Hubara Agam <96368689+galagam@users.noreply.github.com> --- tensorrt_llm/_torch/modules/mamba/selective_state_update.py | 2 +- tests/integration/defs/accuracy/test_llm_api_pytorch.py | 3 +-- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/tensorrt_llm/_torch/modules/mamba/selective_state_update.py b/tensorrt_llm/_torch/modules/mamba/selective_state_update.py index a1ba2dfa0a..4f639bbd2d 100644 --- a/tensorrt_llm/_torch/modules/mamba/selective_state_update.py +++ b/tensorrt_llm/_torch/modules/mamba/selective_state_update.py @@ -50,7 +50,7 @@ from .softplus import softplus "HAS_INTERMEDIATE_STATE_INDICES": lambda args: args["intermediate_state_indices_ptr"] is not None }) -@triton.jit(do_not_specialize=["T"]) +@triton.jit() def _selective_scan_update_kernel( # Pointers to matrices state_ptr, diff --git a/tests/integration/defs/accuracy/test_llm_api_pytorch.py b/tests/integration/defs/accuracy/test_llm_api_pytorch.py index cef8dbdd27..8be3b35091 100644 --- a/tests/integration/defs/accuracy/test_llm_api_pytorch.py +++ b/tests/integration/defs/accuracy/test_llm_api_pytorch.py @@ -5767,7 +5767,6 @@ class TestNemotronV3Super(LlmapiAccuracyTestHarness): task.evaluate(llm, extra_evaluator_kwargs=self.EXTRA_EVALUATOR_KWARGS) - @pytest.mark.skip(reason="Skip MTP test due to no model path file in CI") @skip_pre_blackwell @pytest.mark.skip_less_mpi_world_size(8) def test_nvfp4_8gpus_mtp(self): @@ -5777,7 +5776,7 @@ class TestNemotronV3Super(LlmapiAccuracyTestHarness): num_nextn_predict_layers=3, mtp_eagle_one_model=True, ) - model_path = f"{llm_models_root()}/nemotron-super-sft-repeated-mtp-iter-0010600-nvfp4-fp8kv" + model_path = f"{llm_models_root()}/NVIDIA-Nemotron-3-Super-120B-NVFP4-FP8KV-011526" with LLM( model_path, kv_cache_config=KvCacheConfig(