[None][fix] Fix selective_state_update perf regression for T=1 decode path (#11194)

Signed-off-by: Gal Hubara Agam <96368689+galagam@users.noreply.github.com>
This commit is contained in:
Gal Hubara-Agam 2026-02-04 09:01:34 +02:00 committed by GitHub
parent 04b7db3ab5
commit de6931bbfd
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 2 additions and 3 deletions

View File

@ -50,7 +50,7 @@ from .softplus import softplus
"HAS_INTERMEDIATE_STATE_INDICES":
lambda args: args["intermediate_state_indices_ptr"] is not None
})
@triton.jit(do_not_specialize=["T"])
@triton.jit()
def _selective_scan_update_kernel(
# Pointers to matrices
state_ptr,

View File

@ -5767,7 +5767,6 @@ class TestNemotronV3Super(LlmapiAccuracyTestHarness):
task.evaluate(llm,
extra_evaluator_kwargs=self.EXTRA_EVALUATOR_KWARGS)
@pytest.mark.skip(reason="Skip MTP test due to no model path file in CI")
@skip_pre_blackwell
@pytest.mark.skip_less_mpi_world_size(8)
def test_nvfp4_8gpus_mtp(self):
@ -5777,7 +5776,7 @@ class TestNemotronV3Super(LlmapiAccuracyTestHarness):
num_nextn_predict_layers=3,
mtp_eagle_one_model=True,
)
model_path = f"{llm_models_root()}/nemotron-super-sft-repeated-mtp-iter-0010600-nvfp4-fp8kv"
model_path = f"{llm_models_root()}/NVIDIA-Nemotron-3-Super-120B-NVFP4-FP8KV-011526"
with LLM(
model_path,
kv_cache_config=KvCacheConfig(