diff --git a/tensorrt_llm/_torch/auto_deploy/transform/library/ssm_cache.py b/tensorrt_llm/_torch/auto_deploy/transform/library/ssm_cache.py
index a16a8b73bc..ec71be1abe 100644
--- a/tensorrt_llm/_torch/auto_deploy/transform/library/ssm_cache.py
+++ b/tensorrt_llm/_torch/auto_deploy/transform/library/ssm_cache.py
@@ -1,6 +1,12 @@
 """A set of transforms to handle SSM cache transforms."""

-from ..interface import TransformRegistry
+from typing import Tuple
+
+from torch.fx import GraphModule
+
+from ...models.factory import ModelFactory
+from ...shim.interface import CachedSequenceInterface
+from ..interface import SharedConfig, TransformInfo, TransformRegistry
 from .kvcache import _InsertCachedOperator


@@ -9,6 +15,23 @@ from .kvcache import _InsertCachedOperator
 class SSMCacheTransform(_InsertCachedOperator):
     """A transform to handle SSM cache operations."""

+    def _apply(
+        self,
+        gm: GraphModule,
+        cm: CachedSequenceInterface,
+        factory: ModelFactory,
+        shared_config: SharedConfig,
+    ) -> Tuple[GraphModule, TransformInfo]:
+        qcfg = factory.get_quant_config()
+        is_nvfp4 = qcfg.get("quant_algo", "").upper() == "NVFP4"
+        if is_nvfp4 and self.config.backend == "flashinfer_ssm":
+            self._log_warning(
+                f"SSM backend '{self.config.backend}' is not compatible with NVFP4 quantization. "
+                f"Falling back to triton_ssm."
+            )
+            self.config.backend = "triton_ssm"
+        return super()._apply(gm, cm, factory, shared_config)
+

 @TransformRegistry.register("insert_cached_causal_conv")
 class InitializeCausalConvCache(_InsertCachedOperator):
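
For reference, below is a minimal, self-contained sketch of the fallback behavior this patch adds to `SSMCacheTransform._apply`: the transform inspects the factory's quantization config and, when the checkpoint is NVFP4-quantized, swaps the `flashinfer_ssm` backend for `triton_ssm` before delegating to the base insertion logic. The helper name `select_ssm_backend` and the plain-dict quant config are hypothetical stand-ins for illustration; the real code subclasses `_InsertCachedOperator` and reads the config via `ModelFactory.get_quant_config()`.

```python
# Hypothetical, simplified illustration of the NVFP4 backend-fallback check.
# select_ssm_backend and the plain quant-config dict are stand-ins, not the
# actual TensorRT-LLM auto_deploy API.
from typing import Optional


def select_ssm_backend(quant_config: Optional[dict], requested_backend: str) -> str:
    """Return the SSM backend to use, falling back when NVFP4 is active.

    Per the patch, flashinfer_ssm is treated as incompatible with NVFP4-quantized
    checkpoints, so triton_ssm is substituted in that case.
    """
    quant_algo = (quant_config or {}).get("quant_algo") or ""
    if quant_algo.upper() == "NVFP4" and requested_backend == "flashinfer_ssm":
        # The real transform emits this via self._log_warning().
        print(
            f"SSM backend '{requested_backend}' is not compatible with NVFP4 "
            "quantization. Falling back to triton_ssm."
        )
        return "triton_ssm"
    return requested_backend


if __name__ == "__main__":
    # NVFP4 checkpoint + flashinfer backend -> falls back to triton_ssm.
    assert select_ssm_backend({"quant_algo": "NVFP4"}, "flashinfer_ssm") == "triton_ssm"
    # Unquantized model keeps the requested backend.
    assert select_ssm_backend(None, "flashinfer_ssm") == "flashinfer_ssm"
    # The triton backend is unaffected either way.
    assert select_ssm_backend({"quant_algo": "NVFP4"}, "triton_ssm") == "triton_ssm"
```

Keeping the check inside `_apply` (rather than at config-parse time) means the fallback only triggers for models that actually report NVFP4 in their quant config, and the overridden backend is then used by the inherited cached-operator insertion.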