[https://nvbugs/5762336][fix] support parsing the modules_to_not_convert keyword in the HF model config (#10527)

Signed-off-by: xxi <xxi@nvidia.com>
xxi 2026-01-13 09:21:01 +08:00 committed by GitHub
parent 48b09e5a25
commit ba1037ca4a

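For context, modules_to_not_convert is a keyword in the quantization_config section of a Hugging Face checkpoint's config.json; it lists modules that should be left unquantized. A minimal illustrative example of such a config as a Python dict follows (the concrete module names are hypothetical and not taken from this PR):

# Illustrative HF quantization_config, e.g. json.load(open("config.json"))["quantization_config"].
# The keys shown are the ones the parsing code below reads; the module names are hypothetical.
hf_quant_config = {
    "quant_method": "fp8",
    "weight_block_size": [128, 128],
    "modules_to_not_convert": ["lm_head", "model.embed_tokens"],
}
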
@@ -316,32 +316,50 @@ class ModelConfig(Generic[TConfig]):
        quant_config = QuantConfig()
        layer_quant_config = None
        # Read exclude_modules from HF config if present (HF format module names)
        hf_exclude_modules = hf_quant_config.get('modules_to_not_convert', None)
        # DeepSeek V3 FP8 ckpt
        if hf_quant_config.get("quant_method") == "fp8" and hf_quant_config.get(
                "weight_block_size", []):
            quant_config.quant_algo = QuantAlgo.FP8_BLOCK_SCALES
            if moe_backend == 'TRTLLM':
                # TODO: This is a hack. Remove after fp8 bmm is integrated.
                quant_config.exclude_modules = [
                    "*kv_b_proj*", "*k_b_proj*", "*eh_proj"
                ]
            else:
                quant_config.exclude_modules = ["*eh_proj"]
            block_size = hf_quant_config.get("weight_block_size", [])
            assert tuple(block_size) == (
                128, 128), "FP8_BLOCK_SCALES only supports block_size=(128,128)"
            quant_config.group_size = block_size[0]
            # Set default exclude_modules for FP8_BLOCK_SCALES
            if moe_backend == 'TRTLLM':
                default_exclude = ["*kv_b_proj*", "*k_b_proj*", "*eh_proj"]
            else:
                default_exclude = ["*eh_proj"]
            # Merge HF config's modules_to_not_convert with default exclude_modules
            if hf_exclude_modules is not None:
                quant_config.exclude_modules = list(
                    set(hf_exclude_modules + default_exclude))
            else:
                quant_config.exclude_modules = default_exclude
        # MXFP4 checkpoints.
        elif hf_quant_config.get("quant_method") == "mxfp4":
            quant_config.quant_algo = ModelConfig.get_mxfp4_quant_algo(
                moe_backend)
            quant_config.group_size = 32
            quant_config.exclude_modules = [
            # Default exclude_modules for MXFP4 (TRTLLM internal format)
            default_exclude = [
                'block.*.attn.out', 'block.*.mlp.gate', 'block.*.attn.qkv',
                'embedding', 'unembedding'
            ]
            # Merge HF config's modules_to_not_convert with default exclude_modules
            if hf_exclude_modules is not None:
                quant_config.exclude_modules = list(
                    set(hf_exclude_modules + default_exclude))
            else:
                quant_config.exclude_modules = default_exclude
        return quant_config, layer_quant_config

    @staticmethod
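
For reference, a standalone sketch of the merge behavior introduced above: HF-provided modules_to_not_convert entries are unioned with the backend-specific default exclude_modules, and set() drops duplicates (so the resulting order is not deterministic). The helper name and example values are illustrative, not part of the PR:

from typing import List, Optional


def merge_exclude_modules(hf_exclude_modules: Optional[List[str]],
                          default_exclude: List[str]) -> List[str]:
    """Union HF modules_to_not_convert with the default exclude list."""
    if hf_exclude_modules is None:
        return default_exclude
    return list(set(hf_exclude_modules + default_exclude))


# Example: an FP8 block-scales checkpoint on the TRTLLM MoE backend.
hf_exclude = ["lm_head", "*eh_proj"]  # from modules_to_not_convert
defaults = ["*kv_b_proj*", "*k_b_proj*", "*eh_proj"]  # backend defaults
print(sorted(merge_exclude_modules(hf_exclude, defaults)))
# ['*eh_proj', '*k_b_proj*', '*kv_b_proj*', 'lm_head']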