[https://nvbugs/5558117][fix] Allow per-layer quant config from hf_quant_config.json (#8617)

Signed-off-by: Anthony Chang <27950904+rosenrodt@users.noreply.github.com>
Anthony Chang 2025-10-31 19:41:44 +08:00 committed by GitHub
parent 98453d2bb7
commit 852e5060aa
3 changed files with 147 additions and 25 deletions
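
The change lets a MIXED_PRECISION checkpoint carry its per-layer quantization table directly in hf_quant_config.json (under "quantized_layers"), with a separate quant_cfg.json now optional. Below is a minimal sketch of the consuming call, mirroring the new tests added in this commit; the temporary directory and layer names are illustrative only.

    import json
    import tempfile
    from pathlib import Path

    from tensorrt_llm._torch.model_config import ModelConfig

    with tempfile.TemporaryDirectory() as tmp_dir:
        ckpt_dir = Path(tmp_dir)
        # The per-layer table is now read straight from hf_quant_config.json.
        (ckpt_dir / "hf_quant_config.json").write_text(
            json.dumps({
                "quantization": {
                    "quant_algo": "MIXED_PRECISION",
                    "kv_cache_quant_algo": "FP8",
                    "quantized_layers": {
                        "model.layers.0.self_attn.q_proj": {
                            "quant_algo": "FP8"
                        },
                        "model.layers.1.mlp.gate_proj": {
                            "quant_algo": "W4A8_AWQ",
                            "group_size": 128
                        }
                    }
                }
            }))
        # Returns the global QuantConfig plus a {layer name: QuantConfig} mapping.
        quant_config, layer_quant_config = ModelConfig.load_modelopt_quant_config(
            ckpt_dir / "hf_quant_config.json", ckpt_dir, None)
        print(layer_quant_config["model.layers.1.mlp.gate_proj"].group_size)  # 128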

tensorrt_llm/_torch/model_config.py

@@ -113,9 +113,9 @@ class ModelConfig(Generic[TConfig]):
     pretrained_config: Optional[TConfig] = None
     mapping: Mapping = field(default_factory=Mapping)
 
-    # quantization configs
+    # Quantization configs
     quant_config: QuantConfig = field(default_factory=QuantConfig)
-    # TODO(qijun): support per linear layer quantization
+    # Per linear layer quantization in quant_cfg.json or hf_quant_config.json
     quant_config_dict: Optional[Dict[str, QuantConfig]] = None
     # Delay weights creation to DecoderModelForCausalLM.__post_init__
     # to support mixed quantization.
@@ -278,28 +278,41 @@ class ModelConfig(Generic[TConfig]):
             'exclude_modules', None)
 
         if quant_config.quant_algo == QuantAlgo.MIXED_PRECISION:
-            mixed_quant_config_file = transformers.utils.hub.cached_file(
-                checkpoint_dir, 'quant_cfg.json')
-            with open(mixed_quant_config_file) as fm:
-                mixed_quant_configs = json.load(fm)
-            # kv_cache_quant_algo is global regardless of MIXED_PRECISION
-            kv_cache_quant_algo = mixed_quant_configs['kv_cache_quant_algo']
-            mixed_quant_configs = mixed_quant_configs['quantized_layers']
-            if kv_cache_quant_algo is not None and quant_config.kv_cache_quant_algo is not None:
-                if kv_cache_quant_algo != quant_config.kv_cache_quant_algo:
-                    raise RuntimeError(
-                        f"The kvcache config in 'quant_cfg.json', {kv_cache_quant_algo},"
-                        f"is different from 'hf_quant_config.json', {quant_config.kv_cache_quant_algo}!"
-                    )
-            kv_cache_quant_algo = kv_cache_quant_algo or quant_config.kv_cache_quant_algo
-            for layer in mixed_quant_configs:
-                config = QuantConfig()
-                config.kv_cache_quant_algo = kv_cache_quant_algo
-                config.quant_algo = mixed_quant_configs[layer]['quant_algo']
-                config.group_size = mixed_quant_configs[layer].get(
-                    'group_size', None)
-                mixed_quant_configs[layer] = config
+            json_extended_quant_configs: dict = {}
+            # See tests/unittest/llmapi/test_llm_quant.py
+            try:
+                mixed_quant_config_file = transformers.utils.hub.cached_file(
+                    checkpoint_dir, 'quant_cfg.json')
+                with open(mixed_quant_config_file) as fm:
+                    json_extended_quant_configs = json.load(fm)
+            except Exception:
+                logger.info(
+                    f"No quant_cfg.json found for layer quant info, using hf_quant_config.json."
+                )
+            json_quant_configs.update(json_extended_quant_configs)
+            # kv_cache_quant_algo is global regardless of MIXED_PRECISION
+            kv_cache_quant_algo = json_quant_configs.get(
+                'kv_cache_quant_algo', None)
+            mixed_quant_configs = json_quant_configs.get(
+                'quantized_layers', None)
+            if (kv_quant_lhs := json_extended_quant_configs.get(
+                    "kv_cache_quant_algo", None)) is not None and (
+                        kv_quant_rhs :=
+                        quant_config.kv_cache_quant_algo) is not None:
+                if kv_quant_lhs != kv_quant_rhs:
+                    raise RuntimeError(
+                        f"The kvcache config in 'quant_cfg.json', {kv_quant_lhs},"
+                        f"is different from 'hf_quant_config.json', {kv_quant_rhs}!"
+                    )
+            quant_config.kv_cache_quant_algo = json_quant_configs[
+                "kv_cache_quant_algo"]
+            for layer in mixed_quant_configs:
+                config = QuantConfig()
+                config.kv_cache_quant_algo = kv_cache_quant_algo
+                config.quant_algo = mixed_quant_configs[layer]['quant_algo']
+                config.group_size = mixed_quant_configs[layer].get(
+                    'group_size', None)
+                mixed_quant_configs[layer] = config
             layer_quant_config = mixed_quant_configs
         elif quant_config.quant_algo == QuantAlgo.FP8_BLOCK_SCALES:
             if quant_config.group_size is None:
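
The try/except plus json_quant_configs.update(json_extended_quant_configs) above gives quant_cfg.json precedence when it exists, and otherwise leaves whatever hf_quant_config.json already provided in place. A standalone sketch of that dict-merge precedence, with illustrative values; note that a top-level key such as "quantized_layers" is replaced wholesale, not deep-merged per layer.

    # Parsed 'quantization' section of hf_quant_config.json (illustrative values).
    json_quant_configs = {
        "quant_algo": "MIXED_PRECISION",
        "kv_cache_quant_algo": None,
        "quantized_layers": {
            "model.layers.0.self_attn.q_proj": {"quant_algo": "FP8"}
        },
    }

    # Contents of quant_cfg.json; stays {} when the file is absent.
    json_extended_quant_configs = {
        "kv_cache_quant_algo": "FP8",
        "quantized_layers": {
            "model.layers.1.mlp.gate_proj": {
                "quant_algo": "W4A8_AWQ",
                "group_size": 128
            }
        },
    }

    json_quant_configs.update(json_extended_quant_configs)

    # Keys present in quant_cfg.json win ...
    assert json_quant_configs["kv_cache_quant_algo"] == "FP8"
    # ... and the per-layer table is replaced as a whole.
    assert "model.layers.0.self_attn.q_proj" not in json_quant_configs["quantized_layers"]
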
@@ -459,6 +472,9 @@ class ModelConfig(Generic[TConfig]):
         except OSError:
             return None
 
+        # Some checkpoints lack torch_dtype, populate with dtype
+        pretrained_config.torch_dtype = getattr(pretrained_config, 'dtype',
+                                                None)
         quant_config = QuantConfig()
         layer_quant_config = None
         moe_backend = kwargs.get('moe_backend', 'CUTLASS')
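
On the torch_dtype backfill in the hunk just above: a config that only exposes dtype (the new comment notes some checkpoints lack torch_dtype) gets torch_dtype populated via getattr with a None default. A tiny sketch of that fallback, using a stand-in config object rather than a real transformers config.

    from types import SimpleNamespace

    # Stand-in for a pretrained config that carries `dtype` but not `torch_dtype`.
    pretrained_config = SimpleNamespace(dtype="bfloat16")

    # Same pattern as the diff: copy dtype into torch_dtype, defaulting to None.
    pretrained_config.torch_dtype = getattr(pretrained_config, 'dtype', None)
    assert pretrained_config.torch_dtype == "bfloat16"

    # A config with neither attribute simply ends up with torch_dtype = None.
    bare_config = SimpleNamespace()
    bare_config.torch_dtype = getattr(bare_config, 'dtype', None)
    assert bare_config.torch_dtype is None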

tensorrt_llm/llmapi/llm_utils.py

@@ -390,7 +390,9 @@ class ModelLoader:
             )
 
             for key, value in hf_quant_config.items():
-                logger.info(f"Setting {key}={value} from HF quant config.")
+                logger.info(
+                    f"Setting {key}={str(value)[:100]}{'...' if len(str(value)) > 100 else ''} from HF quant config."
+                )
                 setattr(quant_config, key, value)
 
             # Update the quant_config in llm_args for pytorch
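
Because quantized_layers can now reach this loop as a large per-layer dict, the reworked log line truncates the value's string form to 100 characters. A standalone sketch of the same expression; _truncate_for_log is a hypothetical helper, not part of TensorRT-LLM.

    def _truncate_for_log(value, limit: int = 100) -> str:
        # Hypothetical helper wrapping the expression used in the new log line.
        text = str(value)
        return f"{text[:limit]}{'...' if len(text) > limit else ''}"


    # A large per-layer table would otherwise flood the log.
    quantized_layers = {
        f"model.layers.{i}.mlp.up_proj": {
            "quant_algo": "FP8"
        } for i in range(32)
    }
    print(f"Setting quantized_layers={_truncate_for_log(quantized_layers)} "
          "from HF quant config.")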

tests/unittest/llmapi/test_llm_quant.py

@@ -1,6 +1,11 @@
+import json
+import tempfile
+from pathlib import Path
+
 import pytest
 
 from tensorrt_llm._tensorrt_engine import LLM
+from tensorrt_llm._torch.model_config import ModelConfig
 from tensorrt_llm.llmapi import KvCacheConfig, SamplingParams
 from tensorrt_llm.llmapi.llm_utils import CalibConfig, QuantAlgo, QuantConfig
@@ -71,6 +76,105 @@ def test_llm_fp8_quantization_modelOpt_ckpt():
         assert output.outputs[0].text == " D E F G H I"
 
 
+def test_quant_cfg_from_quant_cfg_json():
+    """
+    Test loading MIXED_PRECISION config from quant_cfg.json with per-layer quantization.
+    This supports the workflow from examples/quantization/quantize_mixed_precision_moe.py.
+    """
+    with tempfile.TemporaryDirectory() as tmp_dir:
+        model_dir = Path(tmp_dir)
+
+        # Create dummy quant_cfg.json
+        quant_cfg_content = {
+            "quant_algo": "MIXED_PRECISION",
+            "kv_cache_quant_algo": "FP8",
+            "quantized_layers": {
+                "model.layers.0.self_attn.q_proj": {
+                    "quant_algo": "FP8"
+                },
+                "model.layers.0.self_attn.k_proj": {
+                    "quant_algo": "FP8"
+                },
+                "model.layers.1.mlp.gate_proj": {
+                    "quant_algo": "W4A8_AWQ",
+                    "group_size": 128
+                }
+            }
+        }
+        quant_cfg_file = model_dir / "quant_cfg.json"
+        with open(quant_cfg_file, 'w') as f:
+            json.dump(quant_cfg_content, f)
+
+        # Create dummy hf_quant_config.json
+        hf_quant_config_content = {
+            "quantization": {
+                "quant_algo": "MIXED_PRECISION",
+                "kv_cache_quant_algo": None,
+            }
+        }
+        hf_quant_config_file = model_dir / "hf_quant_config.json"
+        with open(hf_quant_config_file, 'w') as f:
+            json.dump(hf_quant_config_content, f)
+
+        quant_config, layer_quant_config = ModelConfig.load_modelopt_quant_config(
+            hf_quant_config_file, model_dir, None)
+
+        # Verify quant_cfg.json was loaded
+        assert quant_config.quant_algo == QuantAlgo.MIXED_PRECISION
+        assert quant_config.kv_cache_quant_algo == "FP8"
+
+        # Verify layer configs were created correctly
+        assert layer_quant_config[
+            "model.layers.0.self_attn.q_proj"].quant_algo == "FP8"
+        assert layer_quant_config[
+            "model.layers.0.self_attn.q_proj"].kv_cache_quant_algo == "FP8"
+        assert layer_quant_config[
+            "model.layers.1.mlp.gate_proj"].quant_algo == "W4A8_AWQ"
+        assert layer_quant_config[
+            "model.layers.1.mlp.gate_proj"].group_size == 128
+
+
+def test_quant_cfg_from_hf_quant_config():
+    """Test fallback to hf_quant_config.json when quant_cfg.json is missing."""
+    with tempfile.TemporaryDirectory() as tmp_dir:
+        model_dir = Path(tmp_dir)
+
+        # Create dummy hf_quant_config.json
+        hf_quant_config_content = {
+            "quantization": {
+                "quant_algo": "MIXED_PRECISION",
+                "kv_cache_quant_algo": "FP8",
+                "quantized_layers": {
+                    "model.layers.0.self_attn.q_proj": {
+                        "quant_algo": "FP8"
+                    },
+                    "model.layers.0.mlp.up_proj": {
+                        "quant_algo": "W4A16_AWQ",
+                        "group_size": 64
+                    }
+                }
+            }
+        }
+        hf_quant_config_file = model_dir / "hf_quant_config.json"
+        with open(hf_quant_config_file, 'w') as f:
+            json.dump(hf_quant_config_content, f)
+
+        quant_config, layer_quant_config = ModelConfig.load_modelopt_quant_config(
+            hf_quant_config_file, model_dir, None)
+
+        # Verify layer configs
+        assert quant_config.quant_algo == QuantAlgo.MIXED_PRECISION
+        assert quant_config.kv_cache_quant_algo == "FP8"
+        assert layer_quant_config[
+            "model.layers.0.self_attn.q_proj"].quant_algo == "FP8"
+        assert layer_quant_config[
+            "model.layers.0.mlp.up_proj"].quant_algo == "W4A16_AWQ"
+        assert layer_quant_config["model.layers.0.mlp.up_proj"].group_size == 64
+
+
 if __name__ == "__main__":
     test_llm_int4_awq_quantization()
     test_llm_fp8_quantization_modelOpt_ckpt()
+    test_quant_cfg_from_quant_cfg_json()
+    test_quant_cfg_from_hf_quant_config()
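
Downstream, the returned layer_quant_config presumably feeds ModelConfig.quant_config_dict from the first hunk, so an individual layer can resolve its own settings and fall back to the global config. A hedged sketch of that lookup; resolve_layer_quant_config is a hypothetical helper for illustration, not TensorRT-LLM API.

    from typing import Dict, Optional

    from tensorrt_llm.llmapi.llm_utils import QuantConfig


    def resolve_layer_quant_config(
            layer_name: str,
            quant_config: QuantConfig,
            quant_config_dict: Optional[Dict[str, QuantConfig]]) -> QuantConfig:
        """Hypothetical lookup: per-layer entry when present, else the global config."""
        if quant_config_dict and layer_name in quant_config_dict:
            return quant_config_dict[layer_name]
        return quant_config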