Mirror of https://github.com/NVIDIA/TensorRT-LLM.git (synced 2026-01-13 22:18:36 +08:00)
[https://nvbugs/5762336][fix] Support parsing the keyword modules_to_not_convert of the HF model config (#10527)
Signed-off-by: xxi <xxi@nvidia.com>
This commit is contained in:
parent 48b09e5a25
commit ba1037ca4a
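For context, the modules_to_not_convert keyword typically lives in the quantization_config section of a Hugging Face checkpoint's config.json and lists module names that should stay unquantized. The sketch below shows the kind of dict this commit starts to parse; the module names in modules_to_not_convert are hypothetical placeholders, not taken from any real checkpoint:

# Illustrative hf_quant_config as it might be read from a checkpoint's config.json.
# The "modules_to_not_convert" entries are hypothetical examples.
hf_quant_config = {
    "quant_method": "fp8",
    "weight_block_size": [128, 128],
    "modules_to_not_convert": ["lm_head", "model.embed_tokens"],
}

# The new code reads the keyword defensively, falling back to None when absent:
hf_exclude_modules = hf_quant_config.get('modules_to_not_convert', None)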
@@ -316,32 +316,50 @@ class ModelConfig(Generic[TConfig]):
         quant_config = QuantConfig()
         layer_quant_config = None
 
+        # Read exclude_modules from HF config if present (HF format module names)
+        hf_exclude_modules = hf_quant_config.get('modules_to_not_convert', None)
+
         # DeepSeek V3 FP8 ckpt
         if hf_quant_config.get("quant_method") == "fp8" and hf_quant_config.get(
                 "weight_block_size", []):
             quant_config.quant_algo = QuantAlgo.FP8_BLOCK_SCALES
-            if moe_backend == 'TRTLLM':
-                # TODO: This is a hack. Remove after fp8 bmm is integrated.
-                quant_config.exclude_modules = [
-                    "*kv_b_proj*", "*k_b_proj*", "*eh_proj"
-                ]
-            else:
-                quant_config.exclude_modules = ["*eh_proj"]
 
             block_size = hf_quant_config.get("weight_block_size", [])
             assert tuple(block_size) == (
                 128, 128), "FP8_BLOCK_SCALES only supports block_size=(128,128)"
             quant_config.group_size = block_size[0]
 
+            # Set default exclude_modules for FP8_BLOCK_SCALES
+            if moe_backend == 'TRTLLM':
+                default_exclude = ["*kv_b_proj*", "*k_b_proj*", "*eh_proj"]
+            else:
+                default_exclude = ["*eh_proj"]
+
+            # Merge HF config's modules_to_not_convert with default exclude_modules
+            if hf_exclude_modules is not None:
+                quant_config.exclude_modules = list(
+                    set(hf_exclude_modules + default_exclude))
+            else:
+                quant_config.exclude_modules = default_exclude
         # MXFP4 checkpoints.
         elif hf_quant_config.get("quant_method") == "mxfp4":
             quant_config.quant_algo = ModelConfig.get_mxfp4_quant_algo(
                 moe_backend)
             quant_config.group_size = 32
-            quant_config.exclude_modules = [
+
+            # Default exclude_modules for MXFP4 (TRTLLM internal format)
+            default_exclude = [
                 'block.*.attn.out', 'block.*.mlp.gate', 'block.*.attn.qkv',
                 'embedding', 'unembedding'
             ]
+
+            # Merge HF config's modules_to_not_convert with default exclude_modules
+            if hf_exclude_modules is not None:
+                quant_config.exclude_modules = list(
+                    set(hf_exclude_modules + default_exclude))
+            else:
+                quant_config.exclude_modules = default_exclude
 
         return quant_config, layer_quant_config
 
     @staticmethod
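As a quick illustration of the merge semantics introduced above (not part of the commit itself), the snippet below replays the list(set(...)) union on made-up inputs; note that set() deduplicates but does not preserve order:

# Standalone sketch of the exclude_modules merge, with hypothetical inputs.
default_exclude = ["*eh_proj"]
hf_exclude_modules = ["lm_head", "*eh_proj"]  # hypothetical HF-provided names

if hf_exclude_modules is not None:
    # Union of HF-provided exclusions and backend defaults; duplicates collapse.
    exclude_modules = list(set(hf_exclude_modules + default_exclude))
else:
    exclude_modules = default_exclude

print(sorted(exclude_modules))  # ['*eh_proj', 'lm_head']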