mirror of
https://github.com/NVIDIA/TensorRT-LLM.git
synced 2026-02-16 15:55:08 +08:00
Signed-off-by: Gal Hubara Agam <96368689+galagam@users.noreply.github.com>
This commit is contained in:
parent
db35119c7c
commit
d0e7ba102e
@ -1,6 +1,12 @@
|
||||
"""A set of transforms to handle SSM cache transforms."""
|
||||
|
||||
from ..interface import TransformRegistry
|
||||
from typing import Tuple
|
||||
|
||||
from torch.fx import GraphModule
|
||||
|
||||
from ...models.factory import ModelFactory
|
||||
from ...shim.interface import CachedSequenceInterface
|
||||
from ..interface import SharedConfig, TransformInfo, TransformRegistry
|
||||
from .kvcache import _InsertCachedOperator
|
||||
|
||||
|
||||
@ -9,6 +15,23 @@ from .kvcache import _InsertCachedOperator
|
||||
class SSMCacheTransform(_InsertCachedOperator):
    """A transform to handle SSM cache operations.

    Extends the generic cached-operator insertion with a quantization-aware
    backend check: the ``flashinfer_ssm`` backend is incompatible with NVFP4
    quantization, so we fall back to ``triton_ssm`` in that case.
    """

    def _apply(
        self,
        gm: GraphModule,
        cm: CachedSequenceInterface,
        factory: ModelFactory,
        shared_config: SharedConfig,
    ) -> Tuple[GraphModule, TransformInfo]:
        """Insert cached SSM operators into ``gm``.

        Args:
            gm: The graph module to transform.
            cm: Interface providing cached-sequence state for the transform.
            factory: Model factory; queried for the quantization config.
            shared_config: Config shared across transforms in the pipeline.

        Returns:
            The transformed graph module and the transform's info record.
        """
        # get_quant_config() may return None/empty for unquantized models;
        # treat that the same as "no quantization" instead of raising.
        qcfg = factory.get_quant_config() or {}
        # Guard against an explicit ``quant_algo: None`` entry as well.
        is_nvfp4 = (qcfg.get("quant_algo") or "").upper() == "NVFP4"
        if is_nvfp4 and self.config.backend == "flashinfer_ssm":
            self._log_warning(
                f"SSM backend '{self.config.backend}' is not compatible with NVFP4 quantization. "
                "Falling back to triton_ssm."
            )
            self.config.backend = "triton_ssm"
        # Delegate the actual operator insertion to the base transform.
        return super()._apply(gm, cm, factory, shared_config)
|
||||
|
||||
|
||||
@TransformRegistry.register("insert_cached_causal_conv")
|
||||
class InitializeCausalConvCache(_InsertCachedOperator):
|
||||
|
||||
Loading…
Reference in New Issue
Block a user