[None][fix] enable NvFP4/FP8 quantization for Nemotron-H architecture (#7589)

Signed-off-by: Tomer Asida <57313761+tomeras91@users.noreply.github.com>
Author: tomeras91
Date: 2025-09-09 11:42:22 +03:00 (committed by GitHub)
Parent: 9cb5410067
Commit: 6e712dd1cc
4 changed files with 44 additions and 20 deletions


@@ -34,7 +34,7 @@ class NemotronHHfWeightMapper(HfWeightMapper):
             if "A_log" in key:
                 key = key.replace("A_log", "A")

-            if "_scale" in key and weights[name].dim() == 0:
+            if "_scale" in key:
                 new_weights[key] = weights[name]
             elif "A" in key:
                 w = split(weights[name], tp_size, tp_rank)
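
With the relaxed condition, every checkpoint entry whose key contains "_scale" is now forwarded unchanged, not only the 0-dim scalar scales that FP8 checkpoints carry; NVFP4 checkpoints also ship non-scalar per-block scale tensors. A minimal sketch of the resulting pass-through (the key names and shapes below are made up for illustration, not taken from a real checkpoint):

import torch

def forward_scales(weights: dict) -> dict:
    # Illustrative only: keep any "_scale" entry verbatim, whatever its rank.
    new_weights = {}
    for name, tensor in weights.items():
        key = name.replace("A_log", "A")
        if "_scale" in key:
            new_weights[key] = tensor
    return new_weights

weights = {
    "backbone.layers.0.mixer.in_proj.weight_scale": torch.tensor(0.5),     # 0-dim scale, matched before and after
    "backbone.layers.1.mixer.in_proj.weight_scale": torch.ones(4096, 32),  # non-scalar scale, only matched now
}
print({k: tuple(v.shape) for k, v in forward_scales(weights).items()})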


@@ -13,6 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+import re
 from typing import Optional

 import torch
@@ -255,7 +256,7 @@ class NemotronHForCausalLM(DecoderModelForCausalLM[NemotronHModel,
         if model_config.quant_config.exclude_modules is not None:
             model_config.quant_config.exclude_modules = [
-                k.replace('model.layers.backbone', 'model')
+                re.sub(r'(model\.layers\.)?backbone', 'model', k)
                 for k in model_config.quant_config.exclude_modules
             ]
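
Switching from str.replace to re.sub makes the "model.layers." prefix optional, so exclusion entries written against either the prefixed ("model.layers.backbone....") or the bare ("backbone....") module naming are remapped onto the runtime's "model...." hierarchy. The module names below are made up for illustration:

import re

exclude_modules = [
    "model.layers.backbone.layers.0.mixer.in_proj",  # prefixed form, already handled by the old str.replace
    "backbone.layers.0.mixer.out_proj",              # bare form, missed by the old str.replace
]
remapped = [re.sub(r'(model\.layers\.)?backbone', 'model', k) for k in exclude_modules]
print(remapped)
# ['model.layers.0.mixer.in_proj', 'model.layers.0.mixer.out_proj']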


@@ -482,8 +482,27 @@ class DecoderModelForCausalLM(nn.Module,
         if quant_config is not None:
             if quant_config.exclude_modules is not None:
                 for name, module in self.named_modules():
-                    is_excluded = quant_config.is_module_excluded_from_quantization(
-                        name)
+                    candidates = [name]
+                    if isinstance(module, Linear):
+                        weight_mode = module.weights_loading_config.weight_mode
+                        if weight_mode == WeightMode.FUSED_GATE_UP_LINEAR:
+                            # sometimes gate and up proj are not packed in the checkpoint,
+                            # but they still share the same exclusion rule
+                            candidates += [
+                                name.replace('gate_up_proj', 'gate_proj'),
+                                name.replace('gate_up_proj', 'up_proj')
+                            ]
+                        elif weight_mode == WeightMode.FUSED_QKV_LINEAR:
+                            # sometimes q_proj, k_proj and v_proj are not packed in the checkpoint,
+                            # but they still share the same exclusion rule
+                            candidates += [
+                                name.replace('qkv_proj', 'q_proj'),
+                                name.replace('qkv_proj', 'k_proj'),
+                                name.replace('qkv_proj', 'v_proj')
+                            ]
+                    is_excluded = any(
+                        quant_config.is_module_excluded_from_quantization(n)
+                        for n in candidates)
                     if is_excluded and getattr(module, "quant_config",
                                                None) is not None:
                         module.quant_config = new_config
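
The exclusion check now also tries the unfused projection names, so a fused gate_up_proj or qkv_proj module inherits an exclusion rule that the checkpoint states for its individual parts. A rough sketch of the idea, with fnmatch standing in for quant_config.is_module_excluded_from_quantization and hypothetical module names:

from fnmatch import fnmatch

# Hypothetical exclusion list written against unfused names, as some checkpoints do.
exclude_modules = [
    "model.layers.0.self_attn.q_proj",
    "model.layers.0.self_attn.k_proj",
    "model.layers.0.self_attn.v_proj",
]

def is_module_excluded(name: str) -> bool:
    # Stand-in for quant_config.is_module_excluded_from_quantization(name).
    return any(fnmatch(name, pattern) for pattern in exclude_modules)

fused = "model.layers.0.self_attn.qkv_proj"
candidates = [fused] + [fused.replace("qkv_proj", p) for p in ("q_proj", "k_proj", "v_proj")]
print(any(is_module_excluded(n) for n in candidates))  # True: the fused module is kept unquantized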


@@ -89,14 +89,16 @@ class Mamba2Mixer(nn.Module):
         self.is_paged_state = False

         # in_proj
-        self.in_proj = Linear(d_model,
-                              d_in_proj,
-                              bias=bias,
-                              dtype=dtype,
-                              mapping=self.mapping,
-                              tensor_parallel_mode=TensorParallelMode.COLUMN,
-                              quant_config=config.get_quant_config(),
-                              allreduce_strategy=config.allreduce_strategy)
+        self.in_proj = Linear(
+            d_model,
+            d_in_proj,
+            bias=bias,
+            dtype=dtype,
+            mapping=self.mapping,
+            tensor_parallel_mode=TensorParallelMode.COLUMN,
+            quant_config=config.get_quant_config(),
+            skip_create_weights_in_init=config.skip_create_weights_in_init,
+            allreduce_strategy=config.allreduce_strategy)

         # conv1d, reuse Linear to store weights since it has support for TP > 1 already
         self.conv1d = Linear(
@@ -138,14 +140,16 @@ class Mamba2Mixer(nn.Module):
         )

         # out_proj
-        self.out_proj = Linear(d_inner,
-                               d_model,
-                               bias=bias,
-                               dtype=dtype,
-                               mapping=self.mapping,
-                               tensor_parallel_mode=TensorParallelMode.ROW,
-                               quant_config=config.get_quant_config(),
-                               allreduce_strategy=config.allreduce_strategy)
+        self.out_proj = Linear(
+            d_inner,
+            d_model,
+            bias=bias,
+            dtype=dtype,
+            mapping=self.mapping,
+            tensor_parallel_mode=TensorParallelMode.ROW,
+            quant_config=config.get_quant_config(),
+            skip_create_weights_in_init=config.skip_create_weights_in_init,
+            allreduce_strategy=config.allreduce_strategy)

         self._mamba_ssm_cache_dtype = config.quant_config.mamba_ssm_cache_dtype
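
Passing skip_create_weights_in_init through to the mixer's in_proj and out_proj lets their weight allocation be deferred until after the quantization config has been applied, presumably matching how the attention and MLP Linear layers already behave. A sketch of that deferred-allocation pattern under this assumption (DeferredLinear is illustrative, not TensorRT-LLM's Linear):

import torch
from torch import nn

class DeferredLinear(nn.Module):
    """Illustrative pattern only: optionally postpone weight allocation so a
    quantization recipe applied after __init__ can pick the storage dtype."""

    def __init__(self, in_features, out_features, skip_create_weights_in_init=False):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.weight = None
        if not skip_create_weights_in_init:
            self.create_weights(torch.bfloat16)

    def create_weights(self, dtype):
        # Allocate storage in whatever dtype the (possibly quantized) recipe asks for.
        self.weight = nn.Parameter(
            torch.empty(self.out_features, self.in_features, dtype=dtype),
            requires_grad=False)

proj = DeferredLinear(1024, 4096, skip_create_weights_in_init=True)
proj.create_weights(torch.float8_e4m3fn)  # dtype chosen once the quant recipe is known
print(proj.weight.dtype)  # torch.float8_e4m3fn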