diff --git a/tensorrt_llm/_torch/auto_deploy/models/quant_config_reader.py b/tensorrt_llm/_torch/auto_deploy/models/quant_config_reader.py
index 898b46c769..baa6833983 100644
--- a/tensorrt_llm/_torch/auto_deploy/models/quant_config_reader.py
+++ b/tensorrt_llm/_torch/auto_deploy/models/quant_config_reader.py
@@ -11,6 +11,8 @@ import os
 from abc import ABC, abstractmethod
 from typing import Any, Callable, Dict, Optional, Tuple, Type
 
+from tensorrt_llm._torch.auto_deploy.utils.logger import ad_logger
+
 
 class QuantConfigReader(ABC):
     """Base class for reading and parsing quantization config."""
@@ -84,6 +86,8 @@ class QuantConfigReaderRegistry:
 @QuantConfigReaderRegistry.register("modelopt")
 class ModelOPTQuantConfigReader(QuantConfigReader):
     _ALWAYS_EXCLUDE = ("lm_head", "model.embed_tokens", "*.mixer.gate*", "*.mlp.gate")
+    DEFAULT_TORCH_DTYPE = "float16"
+    DEFAULT_KV_CACHE_DTYPE = "fp8"
 
     def read_config(self, config: Dict) -> Dict:
         producer = config.get("producer", {}).get("name")
@@ -97,10 +101,12 @@ class ModelOPTQuantConfigReader(QuantConfigReader):
         quant_config["exclude_modules"] = excludes + [
             n for n in self._ALWAYS_EXCLUDE if n not in excludes
         ]
-        # Update dtype
-        if quant_config.get("quant_algo") == "NVFP4":
-            quant_config["torch_dtype"] = "float16"
+        if "torch_dtype" not in quant_config:
+            ad_logger.warning(
+                f"torch_dtype not found in quant_config, using default {self.DEFAULT_TORCH_DTYPE}"
+            )
+            quant_config["torch_dtype"] = self.DEFAULT_TORCH_DTYPE
 
         # Handle kv cache
         kv_algo = quant_config.get("kv_cache_quant_algo")
         if kv_algo:
@@ -110,11 +116,7 @@ class ModelOPTQuantConfigReader(QuantConfigReader):
 
         self._quant_config = quant_config
 
-        extra_model_kwargs: Dict[str, Any] = {}
-        if quant_config.get("quant_algo", None) == "NVFP4":
-            extra_model_kwargs["torch_dtype"] = "float16"
-
-        return extra_model_kwargs
+        return {}
 
     @classmethod
     def from_file(
diff --git a/tensorrt_llm/_torch/auto_deploy/transform/library/ssm_cache.py b/tensorrt_llm/_torch/auto_deploy/transform/library/ssm_cache.py
index ec71be1abe..a16a8b73bc 100644
--- a/tensorrt_llm/_torch/auto_deploy/transform/library/ssm_cache.py
+++ b/tensorrt_llm/_torch/auto_deploy/transform/library/ssm_cache.py
@@ -1,12 +1,6 @@
 """A set of transforms to handle SSM cache transforms."""
 
-from typing import Tuple
-
-from torch.fx import GraphModule
-
-from ...models.factory import ModelFactory
-from ...shim.interface import CachedSequenceInterface
-from ..interface import SharedConfig, TransformInfo, TransformRegistry
+from ..interface import TransformRegistry
 from .kvcache import _InsertCachedOperator
 
 
@@ -15,23 +9,6 @@ from .kvcache import _InsertCachedOperator
 class SSMCacheTransform(_InsertCachedOperator):
     """A transform to handle SSM cache operations."""
 
-    def _apply(
-        self,
-        gm: GraphModule,
-        cm: CachedSequenceInterface,
-        factory: ModelFactory,
-        shared_config: SharedConfig,
-    ) -> Tuple[GraphModule, TransformInfo]:
-        qcfg = factory.get_quant_config()
-        is_nvfp4 = qcfg.get("quant_algo", "").upper() == "NVFP4"
-        if is_nvfp4 and self.config.backend == "flashinfer_ssm":
-            self._log_warning(
-                f"SSM backend '{self.config.backend}' is not compatible with NVFP4 quantization. "
-                f"Falling back to triton_ssm."
-            )
-            self.config.backend = "triton_ssm"
-        return super()._apply(gm, cm, factory, shared_config)
-
 
 @TransformRegistry.register("insert_cached_causal_conv")
 class InitializeCausalConvCache(_InsertCachedOperator):
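
Review note: the functional change in quant_config_reader.py is that read_config no longer special-cases NVFP4 (forcing torch_dtype to float16 and returning it as an extra model kwarg); instead, any checkpoint config that omits torch_dtype falls back to DEFAULT_TORCH_DTYPE with a warning, and read_config now always returns {}. Below is a minimal standalone sketch of that fallback behavior, not part of the patch: the helper name apply_dtype_default is hypothetical, and the stdlib logger stands in for ad_logger.

# Illustrative sketch only; mirrors the new fallback logic in
# ModelOPTQuantConfigReader.read_config with stdlib logging.
import logging
from typing import Any, Dict

logger = logging.getLogger(__name__)

DEFAULT_TORCH_DTYPE = "float16"


def apply_dtype_default(quant_config: Dict[str, Any]) -> Dict[str, Any]:
    """Fill in torch_dtype with a warning when the checkpoint config omits it."""
    if "torch_dtype" not in quant_config:
        logger.warning(
            "torch_dtype not found in quant_config, using default %s",
            DEFAULT_TORCH_DTYPE,
        )
        quant_config["torch_dtype"] = DEFAULT_TORCH_DTYPE
    return quant_config


if __name__ == "__main__":
    # NVFP4 checkpoints are no longer special-cased: any config missing
    # torch_dtype picks up the default, and no extra model kwargs are returned.
    print(apply_dtype_default({"quant_algo": "NVFP4"}))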