From 6e712dd1cc2e0959a524cc72d7441f450708c73b Mon Sep 17 00:00:00 2001
From: tomeras91 <57313761+tomeras91@users.noreply.github.com>
Date: Tue, 9 Sep 2025 11:42:22 +0300
Subject: [PATCH] [None][fix] enable NvFP4/FP8 quantization for Nemotron-H
 architecture (#7589)

Signed-off-by: Tomer Asida <57313761+tomeras91@users.noreply.github.com>
---
 .../hf/nemotron_h_weight_mapper.py            |  2 +-
 .../_torch/models/modeling_nemotron_h.py      |  3 +-
 tensorrt_llm/_torch/models/modeling_utils.py  | 23 ++++++++++--
 .../_torch/modules/mamba/mamba2_mixer.py      | 36 ++++++++++---------
 4 files changed, 44 insertions(+), 20 deletions(-)

diff --git a/tensorrt_llm/_torch/models/checkpoints/hf/nemotron_h_weight_mapper.py b/tensorrt_llm/_torch/models/checkpoints/hf/nemotron_h_weight_mapper.py
index e5a5245ee8..170f57d42c 100644
--- a/tensorrt_llm/_torch/models/checkpoints/hf/nemotron_h_weight_mapper.py
+++ b/tensorrt_llm/_torch/models/checkpoints/hf/nemotron_h_weight_mapper.py
@@ -34,7 +34,7 @@ class NemotronHHfWeightMapper(HfWeightMapper):
             if "A_log" in key:
                 key = key.replace("A_log", "A")
 
-            if "_scale" in key and weights[name].dim() == 0:
+            if "_scale" in key:
                 new_weights[key] = weights[name]
             elif "A" in key:
                 w = split(weights[name], tp_size, tp_rank)
diff --git a/tensorrt_llm/_torch/models/modeling_nemotron_h.py b/tensorrt_llm/_torch/models/modeling_nemotron_h.py
index e548d09a08..d271b30b8b 100644
--- a/tensorrt_llm/_torch/models/modeling_nemotron_h.py
+++ b/tensorrt_llm/_torch/models/modeling_nemotron_h.py
@@ -13,6 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import re
 from typing import Optional
 
 import torch
@@ -255,7 +256,7 @@ class NemotronHForCausalLM(DecoderModelForCausalLM[NemotronHModel,
 
         if model_config.quant_config.exclude_modules is not None:
             model_config.quant_config.exclude_modules = [
-                k.replace('model.layers.backbone', 'model')
+                re.sub(r'(model\.layers\.)?backbone', 'model', k)
                 for k in model_config.quant_config.exclude_modules
             ]
 
diff --git a/tensorrt_llm/_torch/models/modeling_utils.py b/tensorrt_llm/_torch/models/modeling_utils.py
index ad200f1a3e..284a31c26a 100755
--- a/tensorrt_llm/_torch/models/modeling_utils.py
+++ b/tensorrt_llm/_torch/models/modeling_utils.py
@@ -482,8 +482,27 @@ class DecoderModelForCausalLM(nn.Module,
         if quant_config is not None:
             if quant_config.exclude_modules is not None:
                 for name, module in self.named_modules():
-                    is_excluded = quant_config.is_module_excluded_from_quantization(
-                        name)
+                    candidates = [name]
+                    if isinstance(module, Linear):
+                        weight_mode = module.weights_loading_config.weight_mode
+                        if weight_mode == WeightMode.FUSED_GATE_UP_LINEAR:
+                            # sometimes gate and up proj are not packed in the checkpoint,
+                            # but they still share the same exclusion rule
+                            candidates += [
+                                name.replace('gate_up_proj', 'gate_proj'),
+                                name.replace('gate_up_proj', 'up_proj')
+                            ]
+                        elif weight_mode == WeightMode.FUSED_QKV_LINEAR:
+                            # sometimes q_proj, k_proj and v_proj are not packed in the checkpoint,
+                            # but they still share the same exclusion rule
+                            candidates += [
+                                name.replace('qkv_proj', 'q_proj'),
+                                name.replace('qkv_proj', 'k_proj'),
+                                name.replace('qkv_proj', 'v_proj')
+                            ]
+                    is_excluded = any(
+                        quant_config.is_module_excluded_from_quantization(n)
+                        for n in candidates)
                     if is_excluded and getattr(module, "quant_config",
                                                None) is not None:
                         module.quant_config = new_config
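
Note on the modeling_utils.py hunk above: checkpoints sometimes record exclusion
rules per unfused projection (gate_proj/up_proj, q_proj/k_proj/v_proj) while the
runtime fuses them into one module, so the exclusion check now also matches the
unfused aliases. A minimal standalone sketch of that candidate expansion, where
the fnmatch-style matching and the names below are illustrative assumptions, not
the TensorRT-LLM API:

    from fnmatch import fnmatch

    def is_excluded(name: str, patterns: list[str]) -> bool:
        # Hypothetical stand-in for QuantConfig.is_module_excluded_from_quantization;
        # assumes wildcard-style exclusion patterns.
        return any(fnmatch(name, p) for p in patterns)

    # A fused module name expands to its unfused aliases before the check.
    name = "model.layers.0.mlp.gate_up_proj"
    candidates = [
        name,
        name.replace("gate_up_proj", "gate_proj"),
        name.replace("gate_up_proj", "up_proj"),
    ]
    patterns = ["model.layers.*.mlp.gate_proj"]  # illustrative exclusion rule
    print(any(is_excluded(n, patterns) for n in candidates))  # True
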
diff --git a/tensorrt_llm/_torch/modules/mamba/mamba2_mixer.py b/tensorrt_llm/_torch/modules/mamba/mamba2_mixer.py
index d5a3e3996a..41872af46f 100644
--- a/tensorrt_llm/_torch/modules/mamba/mamba2_mixer.py
+++ b/tensorrt_llm/_torch/modules/mamba/mamba2_mixer.py
@@ -89,14 +89,16 @@ class Mamba2Mixer(nn.Module):
         self.is_paged_state = False
 
         # in_proj
-        self.in_proj = Linear(d_model,
-                              d_in_proj,
-                              bias=bias,
-                              dtype=dtype,
-                              mapping=self.mapping,
-                              tensor_parallel_mode=TensorParallelMode.COLUMN,
-                              quant_config=config.get_quant_config(),
-                              allreduce_strategy=config.allreduce_strategy)
+        self.in_proj = Linear(
+            d_model,
+            d_in_proj,
+            bias=bias,
+            dtype=dtype,
+            mapping=self.mapping,
+            tensor_parallel_mode=TensorParallelMode.COLUMN,
+            quant_config=config.get_quant_config(),
+            skip_create_weights_in_init=config.skip_create_weights_in_init,
+            allreduce_strategy=config.allreduce_strategy)
 
         # conv1d, reuse Linear to store weights since it has support for TP > 1 already
         self.conv1d = Linear(
@@ -138,14 +140,16 @@ class Mamba2Mixer(nn.Module):
         )
 
         # out_proj
-        self.out_proj = Linear(d_inner,
-                               d_model,
-                               bias=bias,
-                               dtype=dtype,
-                               mapping=self.mapping,
-                               tensor_parallel_mode=TensorParallelMode.ROW,
-                               quant_config=config.get_quant_config(),
-                               allreduce_strategy=config.allreduce_strategy)
+        self.out_proj = Linear(
+            d_inner,
+            d_model,
+            bias=bias,
+            dtype=dtype,
+            mapping=self.mapping,
+            tensor_parallel_mode=TensorParallelMode.ROW,
+            quant_config=config.get_quant_config(),
+            skip_create_weights_in_init=config.skip_create_weights_in_init,
+            allreduce_strategy=config.allreduce_strategy)
 
         self._mamba_ssm_cache_dtype = config.quant_config.mamba_ssm_cache_dtype
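
Note on the modeling_nemotron_h.py hunk: switching from str.replace to re.sub
lets one rule cover both checkpoint naming schemes, since the optional group
matches bare 'backbone.*' prefixes as well as 'model.layers.backbone.*'. A quick
standalone check of the regex, with illustrative module paths:

    import re

    for k in ["backbone.layers.0.mixer.in_proj",
              "model.layers.backbone.layers.0.mixer.in_proj"]:
        # Both naming schemes normalize to the runtime's 'model.*' prefix.
        print(re.sub(r'(model\.layers\.)?backbone', 'model', k))
    # -> model.layers.0.mixer.in_proj (printed twice)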