[#11455][fix] Fallback to triton_ssm for nvfp4 quantization (#11456)

Signed-off-by: Gal Hubara Agam <96368689+galagam@users.noreply.github.com>
Gal Hubara-Agam, 2026-02-13 07:38:37 +02:00, committed by GitHub
parent db35119c7c
commit d0e7ba102e


@@ -1,6 +1,12 @@
 """A set of transforms to handle SSM cache transforms."""
 
-from ..interface import TransformRegistry
+from typing import Tuple
+
+from torch.fx import GraphModule
+
+from ...models.factory import ModelFactory
+from ...shim.interface import CachedSequenceInterface
+from ..interface import SharedConfig, TransformInfo, TransformRegistry
 from .kvcache import _InsertCachedOperator
 
 
@@ -9,6 +15,23 @@ from .kvcache import _InsertCachedOperator
 class SSMCacheTransform(_InsertCachedOperator):
     """A transform to handle SSM cache operations."""
 
+    def _apply(
+        self,
+        gm: GraphModule,
+        cm: CachedSequenceInterface,
+        factory: ModelFactory,
+        shared_config: SharedConfig,
+    ) -> Tuple[GraphModule, TransformInfo]:
+        qcfg = factory.get_quant_config()
+        is_nvfp4 = qcfg.get("quant_algo", "").upper() == "NVFP4"
+        if is_nvfp4 and self.config.backend == "flashinfer_ssm":
+            self._log_warning(
+                f"SSM backend '{self.config.backend}' is not compatible with NVFP4 quantization. "
+                f"Falling back to triton_ssm."
+            )
+            self.config.backend = "triton_ssm"
+        return super()._apply(gm, cm, factory, shared_config)
+
 
 @TransformRegistry.register("insert_cached_causal_conv")
 class InitializeCausalConvCache(_InsertCachedOperator):
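
For readers skimming the diff, here is a minimal standalone sketch of the fallback decision as a pure function. Everything beyond the NVFP4/flashinfer_ssm check is an illustrative assumption: the helper name resolve_ssm_backend, the module-level logger, and the asserts are not part of the commit; only the quant-config shape (qcfg.get("quant_algo", "")) and the backend strings come from the diff above.

# Minimal sketch (not the commit's code): the same fallback decision as a
# pure function. Helper name, logger, and call sites are hypothetical; the
# quant-config shape mirrors qcfg.get("quant_algo", "") in the diff above.
import logging

logger = logging.getLogger(__name__)


def resolve_ssm_backend(quant_config: dict, backend: str) -> str:
    """Return the SSM backend to use, downgrading flashinfer_ssm under NVFP4."""
    is_nvfp4 = quant_config.get("quant_algo", "").upper() == "NVFP4"
    if is_nvfp4 and backend == "flashinfer_ssm":
        logger.warning(
            "SSM backend '%s' is not compatible with NVFP4 quantization. "
            "Falling back to triton_ssm.",
            backend,
        )
        return "triton_ssm"
    return backend


# NVFP4 + flashinfer_ssm downgrades; any other combination passes through.
assert resolve_ssm_backend({"quant_algo": "NVFP4"}, "flashinfer_ssm") == "triton_ssm"
assert resolve_ssm_backend({"quant_algo": "FP8"}, "flashinfer_ssm") == "flashinfer_ssm"
assert resolve_ssm_backend({}, "triton_ssm") == "triton_ssm"

In the commit itself, the override mutates self.config.backend in place before delegating to the parent _apply, so the rest of the cached-operator insertion machinery never sees the incompatible backend.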