Mirror of https://github.com/NVIDIA/TensorRT-LLM.git (synced 2026-01-14 06:27:45 +08:00)
[None][fix] enable NvFP4/FP8 quantization for Nemotron-H architecture (#7589)
Signed-off-by: Tomer Asida <57313761+tomeras91@users.noreply.github.com>
parent: 9cb5410067
commit: 6e712dd1cc
@@ -34,7 +34,7 @@ class NemotronHHfWeightMapper(HfWeightMapper):
         if "A_log" in key:
             key = key.replace("A_log", "A")

-        if "_scale" in key and weights[name].dim() == 0:
+        if "_scale" in key:
             new_weights[key] = weights[name]
         elif "A" in key:
             w = split(weights[name], tp_size, tp_rank)
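For context, a minimal standalone sketch of why the guard was relaxed (the tensor names and the condensed mapper loop below are illustrative, not the actual NemotronHHfWeightMapper code): FP8 checkpoints carry scalar per-tensor scales, while NVFP4 checkpoints also carry block-wise scale tensors that are not 0-dimensional, which the old dim() == 0 check silently skipped.

import torch

def map_scales(weights: dict) -> dict:
    # Condensed from the hunk above; only the "_scale" branch is shown.
    new_weights = {}
    for name, tensor in weights.items():
        key = name.replace("A_log", "A")
        if "_scale" in key:  # the old code additionally required tensor.dim() == 0
            new_weights[key] = tensor
    return new_weights

demo = {
    "mixer.in_proj.weight_scale": torch.tensor(1.0),    # FP8 per-tensor scale (0-D)
    "mixer.in_proj.weight_scale_2": torch.ones(64, 8),  # hypothetical NVFP4 block scales (2-D)
}
print(sorted(map_scales(demo)))  # both scales are now forwarded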
@@ -13,6 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+import re
 from typing import Optional

 import torch
@@ -255,7 +256,7 @@ class NemotronHForCausalLM(DecoderModelForCausalLM[NemotronHModel,

         if model_config.quant_config.exclude_modules is not None:
             model_config.quant_config.exclude_modules = [
-                k.replace('model.layers.backbone', 'model')
+                re.sub(r'(model\.layers\.)?backbone', 'model', k)
                 for k in model_config.quant_config.exclude_modules
             ]

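A quick sketch of what the regex change buys (the exclude_modules entries below are made up for illustration): str.replace only rewrote names carrying the full model.layers.backbone prefix, while re.sub with an optional group also handles entries that start directly with backbone, so both spellings collapse to the rewritten model.* form.

import re

excluded = [
    "model.layers.backbone.layers.0.mixer.in_proj",
    "backbone.layers.1.mixer.out_proj",
]

old = [k.replace('model.layers.backbone', 'model') for k in excluded]
new = [re.sub(r'(model\.layers\.)?backbone', 'model', k) for k in excluded]

print(old)  # the second entry is left untouched, so it would never match a module name
print(new)  # both entries are rewritten to the 'model.layers...' form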
@@ -482,8 +482,27 @@ class DecoderModelForCausalLM(nn.Module,
         if quant_config is not None:
             if quant_config.exclude_modules is not None:
                 for name, module in self.named_modules():
-                    is_excluded = quant_config.is_module_excluded_from_quantization(
-                        name)
+                    candidates = [name]
+                    if isinstance(module, Linear):
+                        weight_mode = module.weights_loading_config.weight_mode
+                        if weight_mode == WeightMode.FUSED_GATE_UP_LINEAR:
+                            # sometimes gate and up proj are not packed in the checkpoint,
+                            # but they still share the same exclusion rule
+                            candidates += [
+                                name.replace('gate_up_proj', 'gate_proj'),
+                                name.replace('gate_up_proj', 'up_proj')
+                            ]
+                        elif weight_mode == WeightMode.FUSED_QKV_LINEAR:
+                            # sometimes q_proj, k_proj and v_proj are not packed in the checkpoint,
+                            # but they still share the same exclusion rule
+                            candidates += [
+                                name.replace('qkv_proj', 'q_proj'),
+                                name.replace('qkv_proj', 'k_proj'),
+                                name.replace('qkv_proj', 'v_proj')
+                            ]
+                    is_excluded = any(
+                        quant_config.is_module_excluded_from_quantization(n)
+                        for n in candidates)
                     if is_excluded and getattr(module, "quant_config",
                                                None) is not None:
                         module.quant_config = new_config
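The idea behind the candidate list, as a self-contained sketch (the pattern matching and module names below are stand-ins; the real check is quant_config.is_module_excluded_from_quantization): a checkpoint may list gate_proj/up_proj or q_proj/k_proj/v_proj in exclude_modules, while the runtime module is the fused gate_up_proj or qkv_proj, so the fused name alone would never match.

import fnmatch

# Hypothetical exclusion list taken from a checkpoint's quant config.
exclude_modules = ["model.layers.0.mlp.gate_proj", "model.layers.0.mlp.up_proj"]

def is_module_excluded(name: str) -> bool:
    # Stand-in for quant_config.is_module_excluded_from_quantization().
    return any(fnmatch.fnmatch(name, pattern) for pattern in exclude_modules)

fused = "model.layers.0.mlp.gate_up_proj"
candidates = [
    fused,
    fused.replace("gate_up_proj", "gate_proj"),
    fused.replace("gate_up_proj", "up_proj"),
]

print(is_module_excluded(fused))                       # False: the fused name never matches
print(any(is_module_excluded(n) for n in candidates))  # True: the unfused aliases do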
@@ -89,14 +89,16 @@ class Mamba2Mixer(nn.Module):
         self.is_paged_state = False

         # in_proj
-        self.in_proj = Linear(d_model,
-                              d_in_proj,
-                              bias=bias,
-                              dtype=dtype,
-                              mapping=self.mapping,
-                              tensor_parallel_mode=TensorParallelMode.COLUMN,
-                              quant_config=config.get_quant_config(),
-                              allreduce_strategy=config.allreduce_strategy)
+        self.in_proj = Linear(
+            d_model,
+            d_in_proj,
+            bias=bias,
+            dtype=dtype,
+            mapping=self.mapping,
+            tensor_parallel_mode=TensorParallelMode.COLUMN,
+            quant_config=config.get_quant_config(),
+            skip_create_weights_in_init=config.skip_create_weights_in_init,
+            allreduce_strategy=config.allreduce_strategy)

         # conv1d, reuse Linear to store weights since it has support for TP > 1 already
         self.conv1d = Linear(
@@ -138,14 +140,16 @@ class Mamba2Mixer(nn.Module):
         )

         # out_proj
-        self.out_proj = Linear(d_inner,
-                               d_model,
-                               bias=bias,
-                               dtype=dtype,
-                               mapping=self.mapping,
-                               tensor_parallel_mode=TensorParallelMode.ROW,
-                               quant_config=config.get_quant_config(),
-                               allreduce_strategy=config.allreduce_strategy)
+        self.out_proj = Linear(
+            d_inner,
+            d_model,
+            bias=bias,
+            dtype=dtype,
+            mapping=self.mapping,
+            tensor_parallel_mode=TensorParallelMode.ROW,
+            quant_config=config.get_quant_config(),
+            skip_create_weights_in_init=config.skip_create_weights_in_init,
+            allreduce_strategy=config.allreduce_strategy)

         self._mamba_ssm_cache_dtype = config.quant_config.mamba_ssm_cache_dtype
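Both Mamba2Mixer projections now forward skip_create_weights_in_init, so weight allocation can be deferred until the per-module quantization config (including the exclusion handling above) is settled. A conceptual sketch of that pattern follows; the class and parameter names are assumed for illustration and are not the actual TensorRT-LLM Linear implementation.

import torch
import torch.nn as nn

class DeferredLinear(nn.Module):
    # Illustrative only: shows why a module might skip weight creation in __init__.
    def __init__(self, in_features, out_features, skip_create_weights_in_init=False):
        super().__init__()
        self.in_features, self.out_features = in_features, out_features
        self.weight = None
        if not skip_create_weights_in_init:
            self.create_weights()

    def create_weights(self):
        # In the real module the final quant recipe decides the weight dtype/layout,
        # so deferring this call lets exclusion rules be applied first.
        self.weight = nn.Parameter(torch.empty(self.out_features, self.in_features))

layer = DeferredLinear(16, 32, skip_create_weights_in_init=True)
# ... quant_config is adjusted here ...
layer.create_weights()
print(layer.weight.shape)  # torch.Size([32, 16])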