[#11455][bug] Use the torch_dtype set by ModelOpt (#11525)

Signed-off-by: Tal Cherckez <127761168+tcherckez-nvidia@users.noreply.github.com>
Authored by tcherckez-nvidia on 2026-02-15 19:37:59 +02:00, committed by GitHub
parent 361ff36784
commit fcb7bea07f
2 changed files with 11 additions and 32 deletions


@@ -11,6 +11,8 @@ import os
 from abc import ABC, abstractmethod
 from typing import Any, Callable, Dict, Optional, Tuple, Type
+from tensorrt_llm._torch.auto_deploy.utils.logger import ad_logger
 class QuantConfigReader(ABC):
     """Base class for reading and parsing quantization config."""
@@ -84,6 +86,8 @@ class QuantConfigReaderRegistry:
 @QuantConfigReaderRegistry.register("modelopt")
 class ModelOPTQuantConfigReader(QuantConfigReader):
     _ALWAYS_EXCLUDE = ("lm_head", "model.embed_tokens", "*.mixer.gate*", "*.mlp.gate")
+    DEFAULT_TORCH_DTYPE = "float16"
+    DEFAULT_KV_CACHE_DTYPE = "fp8"
     def read_config(self, config: Dict) -> Dict:
         producer = config.get("producer", {}).get("name")
@@ -97,10 +101,12 @@ class ModelOPTQuantConfigReader(QuantConfigReader):
         quant_config["exclude_modules"] = excludes + [
             n for n in self._ALWAYS_EXCLUDE if n not in excludes
         ]
-        # Update dtype
-        if quant_config.get("quant_algo") == "NVFP4":
-            quant_config["torch_dtype"] = "float16"
+        if "torch_dtype" not in quant_config:
+            ad_logger.warning(
+                f"torch_dtype not found in quant_config, using default {self.DEFAULT_TORCH_DTYPE}"
+            )
+            quant_config["torch_dtype"] = self.DEFAULT_TORCH_DTYPE
         # Handle kv cache
         kv_algo = quant_config.get("kv_cache_quant_algo")
         if kv_algo:
@@ -110,11 +116,7 @@ class ModelOPTQuantConfigReader(QuantConfigReader):
         self._quant_config = quant_config
-        extra_model_kwargs: Dict[str, Any] = {}
-        if quant_config.get("quant_algo", None) == "NVFP4":
-            extra_model_kwargs["torch_dtype"] = "float16"
-        return extra_model_kwargs
+        return {}
     @classmethod
     def from_file(
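
The net effect in the reader: read_config no longer forces "float16" whenever quant_algo is NVFP4 (and no longer passes that choice out through extra_model_kwargs); it keeps whatever torch_dtype ModelOpt recorded in the checkpoint's quant config and only falls back to DEFAULT_TORCH_DTYPE when the field is missing. A minimal standalone sketch of that fallback, assuming plain dicts in and out; the helper name resolve_torch_dtype is illustrative and not part of the codebase, and the real reader logs through ad_logger rather than print:

from typing import Any, Dict

DEFAULT_TORCH_DTYPE = "float16"  # mirrors ModelOPTQuantConfigReader.DEFAULT_TORCH_DTYPE

def resolve_torch_dtype(quant_config: Dict[str, Any]) -> str:
    """Return the dtype ModelOpt recorded, falling back to the default if absent."""
    if "torch_dtype" not in quant_config:
        # The real reader emits ad_logger.warning here before applying the default.
        print(f"torch_dtype not found in quant_config, using default {DEFAULT_TORCH_DTYPE}")
        return DEFAULT_TORCH_DTYPE
    return quant_config["torch_dtype"]

# An NVFP4 checkpoint exported with bfloat16 now keeps bfloat16 instead of
# being silently downgraded to float16.
assert resolve_torch_dtype({"quant_algo": "NVFP4", "torch_dtype": "bfloat16"}) == "bfloat16"
# A config without the field still ends up on float16, now with a warning.
assert resolve_torch_dtype({"quant_algo": "NVFP4"}) == "float16"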


@@ -1,12 +1,6 @@
 """A set of transforms to handle SSM cache transforms."""
-from typing import Tuple
-from torch.fx import GraphModule
-from ...models.factory import ModelFactory
-from ...shim.interface import CachedSequenceInterface
-from ..interface import SharedConfig, TransformInfo, TransformRegistry
+from ..interface import TransformRegistry
 from .kvcache import _InsertCachedOperator
@@ -15,23 +9,6 @@ from .kvcache import _InsertCachedOperator
 class SSMCacheTransform(_InsertCachedOperator):
     """A transform to handle SSM cache operations."""
-    def _apply(
-        self,
-        gm: GraphModule,
-        cm: CachedSequenceInterface,
-        factory: ModelFactory,
-        shared_config: SharedConfig,
-    ) -> Tuple[GraphModule, TransformInfo]:
-        qcfg = factory.get_quant_config()
-        is_nvfp4 = qcfg.get("quant_algo", "").upper() == "NVFP4"
-        if is_nvfp4 and self.config.backend == "flashinfer_ssm":
-            self._log_warning(
-                f"SSM backend '{self.config.backend}' is not compatible with NVFP4 quantization. "
-                f"Falling back to triton_ssm."
-            )
-            self.config.backend = "triton_ssm"
-        return super()._apply(gm, cm, factory, shared_config)
 @TransformRegistry.register("insert_cached_causal_conv")
 class InitializeCausalConvCache(_InsertCachedOperator):
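
With the NVFP4-specific override deleted, SSMCacheTransform no longer needs its own _apply (nor the GraphModule, ModelFactory, and SharedConfig imports), and the forced fallback from the flashinfer_ssm backend to triton_ssm for NVFP4 checkpoints is gone; cache insertion is handled entirely by the inherited base implementation. A minimal sketch of the resulting class, with _InsertCachedOperator stubbed out since the real base class lives in the .kvcache module and the transform's registry decorator is outside the hunk shown above:

class _InsertCachedOperator:  # stand-in for the real base class from .kvcache
    pass

class SSMCacheTransform(_InsertCachedOperator):
    """A transform to handle SSM cache operations."""
    # No _apply override: cache insertion and backend selection are inherited
    # unchanged from _InsertCachedOperator; the NVFP4-specific switch from
    # flashinfer_ssm to triton_ssm that previously lived here has been removed.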