Mirror of https://github.com/NVIDIA/TensorRT-LLM.git
Signed-off-by: Tal Cherckez <127761168+tcherckez-nvidia@users.noreply.github.com>
parent 361ff36784
commit fcb7bea07f
@@ -11,6 +11,8 @@ import os
 from abc import ABC, abstractmethod
 from typing import Any, Callable, Dict, Optional, Tuple, Type
+
+from tensorrt_llm._torch.auto_deploy.utils.logger import ad_logger
 
 
 class QuantConfigReader(ABC):
     """Base class for reading and parsing quantization config."""
@@ -84,6 +86,8 @@ class QuantConfigReaderRegistry:
 @QuantConfigReaderRegistry.register("modelopt")
 class ModelOPTQuantConfigReader(QuantConfigReader):
     _ALWAYS_EXCLUDE = ("lm_head", "model.embed_tokens", "*.mixer.gate*", "*.mlp.gate")
+    DEFAULT_TORCH_DTYPE = "float16"
+    DEFAULT_KV_CACHE_DTYPE = "fp8"
 
     def read_config(self, config: Dict) -> Dict:
         producer = config.get("producer", {}).get("name")
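Aside: `QuantConfigReaderRegistry` itself is not touched by this commit and its body is not part of this diff. For orientation, a decorator-based name-to-class registry of this shape commonly reduces to the following self-contained sketch; the class name and method set below are illustrative assumptions, not the repository's actual implementation:

from typing import Callable, Dict, Type

class ReaderRegistrySketch:
    """Hypothetical minimal registry mapping a producer name to a reader class."""

    _registry: Dict[str, Type] = {}

    @classmethod
    def register(cls, name: str) -> Callable[[Type], Type]:
        # Decorator form, as in @QuantConfigReaderRegistry.register("modelopt"):
        # store the decorated class under `name` and return it unchanged.
        def decorator(reader_cls: Type) -> Type:
            cls._registry[name] = reader_cls
            return reader_cls

        return decorator

    @classmethod
    def get(cls, name: str) -> Type:
        # Look up a reader class by producer name; raises KeyError if unregistered.
        return cls._registry[name]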
@@ -97,10 +101,12 @@ class ModelOPTQuantConfigReader(QuantConfigReader):
         quant_config["exclude_modules"] = excludes + [
             n for n in self._ALWAYS_EXCLUDE if n not in excludes
         ]
-        # Update dtype
-        if quant_config.get("quant_algo") == "NVFP4":
-            quant_config["torch_dtype"] = "float16"
 
+        if "torch_dtype" not in quant_config:
+            ad_logger.warning(
+                f"torch_dtype not found in quant_config, using default {self.DEFAULT_TORCH_DTYPE}"
+            )
+            quant_config["torch_dtype"] = self.DEFAULT_TORCH_DTYPE
         # Handle kv cache
         kv_algo = quant_config.get("kv_cache_quant_algo")
         if kv_algo:
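Net effect of this hunk: the NVFP4-only hard-coding of `torch_dtype` is replaced by a general default plus a logged warning, while the always-exclude patterns are merged without duplicating entries the config already lists. A standalone sketch of the resulting behavior; the module-level names and the sample config here are illustrative, not the file's real structure:

import logging

logger = logging.getLogger("sketch")

_ALWAYS_EXCLUDE = ("lm_head", "model.embed_tokens", "*.mixer.gate*", "*.mlp.gate")
DEFAULT_TORCH_DTYPE = "float16"

def normalize_quant_config(quant_config: dict) -> dict:
    """Illustrative stand-in for the reader's exclude-merge and dtype-default steps."""
    excludes = quant_config.get("exclude_modules", [])
    # Append always-excluded patterns, skipping any the config already lists.
    quant_config["exclude_modules"] = excludes + [
        n for n in _ALWAYS_EXCLUDE if n not in excludes
    ]
    # General default with a warning, replacing the old NVFP4-only hard-coding.
    if "torch_dtype" not in quant_config:
        logger.warning(
            "torch_dtype not found in quant_config, using default %s", DEFAULT_TORCH_DTYPE
        )
        quant_config["torch_dtype"] = DEFAULT_TORCH_DTYPE
    return quant_config

cfg = normalize_quant_config({"quant_algo": "NVFP4", "exclude_modules": ["lm_head"]})
assert cfg["torch_dtype"] == "float16"               # default applied (warning logged)
assert cfg["exclude_modules"].count("lm_head") == 1  # no duplicate exclude entries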
@@ -110,11 +116,7 @@ class ModelOPTQuantConfigReader(QuantConfigReader):
 
         self._quant_config = quant_config
 
-        extra_model_kwargs: Dict[str, Any] = {}
-        if quant_config.get("quant_algo", None) == "NVFP4":
-            extra_model_kwargs["torch_dtype"] = "float16"
-
-        return extra_model_kwargs
+        return {}
 
     @classmethod
     def from_file(
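Caller-side consequence of the two hunks above: `read_config` no longer returns NVFP4-specific extra model kwargs but always returns `{}`, so any dtype override now has to come from `quant_config["torch_dtype"]` itself. A hypothetical merge helper illustrating the new contract; the function name and kwargs are assumed for illustration:

def build_model_kwargs(base_kwargs: dict, extra_model_kwargs: dict) -> dict:
    # After this commit extra_model_kwargs is always {}, so the update is a no-op
    # and base_kwargs (including any torch_dtype choice) passes through untouched.
    merged = dict(base_kwargs)
    merged.update(extra_model_kwargs)
    return merged

assert build_model_kwargs({"torch_dtype": "bfloat16"}, {}) == {"torch_dtype": "bfloat16"}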
@@ -1,12 +1,6 @@
 """A set of transforms to handle SSM cache transforms."""
 
-from typing import Tuple
-
-from torch.fx import GraphModule
-
-from ...models.factory import ModelFactory
-from ...shim.interface import CachedSequenceInterface
-from ..interface import SharedConfig, TransformInfo, TransformRegistry
+from ..interface import TransformRegistry
 from .kvcache import _InsertCachedOperator
 
 
@@ -15,23 +9,6 @@ from .kvcache import _InsertCachedOperator
 class SSMCacheTransform(_InsertCachedOperator):
     """A transform to handle SSM cache operations."""
-
-    def _apply(
-        self,
-        gm: GraphModule,
-        cm: CachedSequenceInterface,
-        factory: ModelFactory,
-        shared_config: SharedConfig,
-    ) -> Tuple[GraphModule, TransformInfo]:
-        qcfg = factory.get_quant_config()
-        is_nvfp4 = qcfg.get("quant_algo", "").upper() == "NVFP4"
-        if is_nvfp4 and self.config.backend == "flashinfer_ssm":
-            self._log_warning(
-                f"SSM backend '{self.config.backend}' is not compatible with NVFP4 quantization. "
-                f"Falling back to triton_ssm."
-            )
-            self.config.backend = "triton_ssm"
-        return super()._apply(gm, cm, factory, shared_config)
 
 
 @TransformRegistry.register("insert_cached_causal_conv")
 class InitializeCausalConvCache(_InsertCachedOperator):
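With the `_apply` override deleted, `SSMCacheTransform` inherits `_InsertCachedOperator._apply` unchanged, so the transform no longer silently downgrades `flashinfer_ssm` to `triton_ssm` under NVFP4. For reference, the deleted guard reduces to the following pattern; this is a standalone sketch with illustrative names, not the transform's actual code:

import logging

logger = logging.getLogger("sketch")

def resolve_ssm_backend(backend: str, quant_algo: str) -> str:
    """Illustrative stand-in for the deleted guard: downgrade an incompatible backend."""
    if quant_algo.upper() == "NVFP4" and backend == "flashinfer_ssm":
        logger.warning(
            "SSM backend '%s' is not compatible with NVFP4 quantization. "
            "Falling back to triton_ssm.",
            backend,
        )
        return "triton_ssm"
    return backend

assert resolve_ssm_backend("flashinfer_ssm", "NVFP4") == "triton_ssm"
assert resolve_ssm_backend("flashinfer_ssm", "FP8") == "flashinfer_ssm"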