From d0e7ba102e2c616f96461a759f5c7846a544e6ae Mon Sep 17 00:00:00 2001
From: Gal Hubara-Agam <96368689+galagam@users.noreply.github.com>
Date: Fri, 13 Feb 2026 07:38:37 +0200
Subject: [PATCH] [#11455][fix] Fallback to triton_ssm for nvfp4 quantization
 (#11456)

Signed-off-by: Gal Hubara Agam <96368689+galagam@users.noreply.github.com>
---
 .../transform/library/ssm_cache.py            | 25 ++++++++++++++++++-
 1 file changed, 24 insertions(+), 1 deletion(-)

diff --git a/tensorrt_llm/_torch/auto_deploy/transform/library/ssm_cache.py b/tensorrt_llm/_torch/auto_deploy/transform/library/ssm_cache.py
index a16a8b73bc..ec71be1abe 100644
--- a/tensorrt_llm/_torch/auto_deploy/transform/library/ssm_cache.py
+++ b/tensorrt_llm/_torch/auto_deploy/transform/library/ssm_cache.py
@@ -1,6 +1,12 @@
 """A set of transforms to handle SSM cache transforms."""
 
-from ..interface import TransformRegistry
+from typing import Tuple
+
+from torch.fx import GraphModule
+
+from ...models.factory import ModelFactory
+from ...shim.interface import CachedSequenceInterface
+from ..interface import SharedConfig, TransformInfo, TransformRegistry
 
 from .kvcache import _InsertCachedOperator
 
@@ -9,6 +15,23 @@ from .kvcache import _InsertCachedOperator
 class SSMCacheTransform(_InsertCachedOperator):
     """A transform to handle SSM cache operations."""
 
+    def _apply(
+        self,
+        gm: GraphModule,
+        cm: CachedSequenceInterface,
+        factory: ModelFactory,
+        shared_config: SharedConfig,
+    ) -> Tuple[GraphModule, TransformInfo]:
+        qcfg = factory.get_quant_config()
+        is_nvfp4 = qcfg.get("quant_algo", "").upper() == "NVFP4"
+        if is_nvfp4 and self.config.backend == "flashinfer_ssm":
+            self._log_warning(
+                f"SSM backend '{self.config.backend}' is not compatible with NVFP4 quantization. "
+                f"Falling back to triton_ssm."
+            )
+            self.config.backend = "triton_ssm"
+        return super()._apply(gm, cm, factory, shared_config)
+
 
 @TransformRegistry.register("insert_cached_causal_conv")
 class InitializeCausalConvCache(_InsertCachedOperator):