Support NemotronH FP8 Quantization

(1) Match quant exclude module names to TRTLLM names (see the sketch below).
(2) No special weight loading is needed for quantization scale weights; see the sketch after the diff (#3891).
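
A minimal, self-contained sketch of the rename in (1). The module names below are hypothetical; the real entries come from the checkpoint's quant config, and only the 'model.layers.backbone' prefix handling is taken from this change:

    # Hypothetical exclude_modules entries as they might appear in a checkpoint's quant config.
    exclude_modules = [
        "model.layers.backbone.layers.0.mixer.in_proj",
        "model.layers.backbone.embeddings",
    ]
    # Rewrite the prefix so the entries match TRTLLM module names.
    remapped = [k.replace("model.layers.backbone", "model") for k in exclude_modules]
    print(remapped)
    # ['model.layers.0.mixer.in_proj', 'model.embeddings']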

Signed-off-by: Tomer Asida <57313761+tomeras91@users.noreply.github.com>
tomeras91 2025-04-29 18:51:43 +03:00 committed by GitHub
parent 68a19a33d4
commit 35010e8073


@@ -232,6 +232,13 @@ class NemotronHForCausalLM(DecoderModelForCausalLM[NemotronHModel,
     ):
         if not model_config.mapping.tp_size in [1, 2, 4, 8]:
             raise ValueError("TP has to be either 1, 2, 4 or 8")
+        if model_config.quant_config.exclude_modules is not None:
+            model_config.quant_config.exclude_modules = [
+                k.replace('model.layers.backbone', 'model')
+                for k in model_config.quant_config.exclude_modules
+            ]
         super().__init__(
             NemotronHModel(model_config),
             config=model_config,
@@ -263,7 +270,9 @@ class NemotronHForCausalLM(DecoderModelForCausalLM[NemotronHModel,
             if "A_log" in key:
                 key = key.replace("A_log", "A")
-            if "A" in key:
+            if "_scale" in key and weights[name].dim() == 0:
+                new_weights[key] = weights[name]
+            elif "A" in key:
                 w = split(weights[name], tp_size, tp_rank)
                 w = w.to(torch.float32)
                 w = -torch.exp(w)
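
For (2), a minimal sketch of the resulting weight-loading dispatch, assuming per-tensor quantization scales are stored as 0-dim tensors. The weight names below are hypothetical, and the TP split of the A parameter is omitted:

    import torch

    weights = {
        "backbone.layers.0.mixer.A_log": torch.zeros(128),                  # Mamba A_log parameter
        "backbone.layers.0.mixer.in_proj.weight_scale": torch.tensor(0.5),  # scalar FP8 scale
    }
    new_weights = {}
    for name in weights:
        key = name.replace("A_log", "A")
        if "_scale" in key and weights[name].dim() == 0:
            # Scalar quantization scale: copied through unchanged (no TP split, no transform).
            new_weights[key] = weights[name]
        elif "A" in key:
            # Mamba A parameter: converted as before (TP split omitted in this sketch).
            new_weights[key] = -torch.exp(weights[name].to(torch.float32))
    print({k: tuple(v.shape) for k, v in new_weights.items()})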