[https://nvbugs/5558117][fix] Allow per-layer quant config from hf_quant_config.json (#8617)
Signed-off-by: Anthony Chang <27950904+rosenrodt@users.noreply.github.com>
Parent: 98453d2bb7
Commit: 852e5060aa
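
For context, the per-layer quantization metadata that this fix consumes can now come straight from hf_quant_config.json (with an optional quant_cfg.json overriding it when present). The sketch below, which mirrors the unit tests added in this commit, writes a minimal config of that shape and loads it through ModelConfig.load_modelopt_quant_config; the layer names and algorithms are illustrative placeholders only.

# Minimal sketch (assumes a local TensorRT-LLM install); layer names are placeholders.
import json
import tempfile
from pathlib import Path

from tensorrt_llm._torch.model_config import ModelConfig

with tempfile.TemporaryDirectory() as tmp_dir:
    model_dir = Path(tmp_dir)
    hf_quant_config = {
        "quantization": {
            "quant_algo": "MIXED_PRECISION",
            "kv_cache_quant_algo": "FP8",
            "quantized_layers": {
                "model.layers.0.self_attn.q_proj": {"quant_algo": "FP8"},
                "model.layers.0.mlp.up_proj": {"quant_algo": "W4A16_AWQ", "group_size": 64},
            },
        }
    }
    hf_quant_config_file = model_dir / "hf_quant_config.json"
    hf_quant_config_file.write_text(json.dumps(hf_quant_config))

    # Returns the global QuantConfig plus a {layer_name: QuantConfig} mapping.
    quant_config, layer_quant_config = ModelConfig.load_modelopt_quant_config(
        hf_quant_config_file, model_dir, None)
    print(layer_quant_config["model.layers.0.mlp.up_proj"].group_size)  # 64
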
@@ -113,9 +113,9 @@ class ModelConfig(Generic[TConfig]):
     pretrained_config: Optional[TConfig] = None
     mapping: Mapping = field(default_factory=Mapping)
 
-    # quantization configs
+    # Quantization configs
     quant_config: QuantConfig = field(default_factory=QuantConfig)
-    # TODO(qijun): support per linear layer quantization
+    # Per linear layer quantization in quant_cfg.json or hf_quant_config.json
     quant_config_dict: Optional[Dict[str, QuantConfig]] = None
     # Delay weights creation to DecoderModelForCausalLM.__post_init__
     # to support mixed quantization.
@@ -278,28 +278,41 @@ class ModelConfig(Generic[TConfig]):
             'exclude_modules', None)
 
         if quant_config.quant_algo == QuantAlgo.MIXED_PRECISION:
-            mixed_quant_config_file = transformers.utils.hub.cached_file(
-                checkpoint_dir, 'quant_cfg.json')
-            with open(mixed_quant_config_file) as fm:
-                mixed_quant_configs = json.load(fm)
-            # kv_cache_quant_algo is global regardless of MIXED_PRECISION
-            kv_cache_quant_algo = mixed_quant_configs['kv_cache_quant_algo']
-            mixed_quant_configs = mixed_quant_configs['quantized_layers']
-            if kv_cache_quant_algo is not None and quant_config.kv_cache_quant_algo is not None:
-                if kv_cache_quant_algo != quant_config.kv_cache_quant_algo:
-                    raise RuntimeError(
-                        f"The kvcache config in 'quant_cfg.json', {kv_cache_quant_algo},"
-                        f"is different from 'hf_quant_config.json', {quant_config.kv_cache_quant_algo}!"
-                    )
-            kv_cache_quant_algo = kv_cache_quant_algo or quant_config.kv_cache_quant_algo
-
-            for layer in mixed_quant_configs:
-                config = QuantConfig()
-                config.kv_cache_quant_algo = kv_cache_quant_algo
-                config.quant_algo = mixed_quant_configs[layer]['quant_algo']
-                config.group_size = mixed_quant_configs[layer].get(
-                    'group_size', None)
-                mixed_quant_configs[layer] = config
+            json_extended_quant_configs: dict = {}
+            # See tests/unittest/llmapi/test_llm_quant.py
+            try:
+                mixed_quant_config_file = transformers.utils.hub.cached_file(
+                    checkpoint_dir, 'quant_cfg.json')
+                with open(mixed_quant_config_file) as fm:
+                    json_extended_quant_configs = json.load(fm)
+            except Exception:
+                logger.info(
+                    f"No quant_cfg.json found for layer quant info, using hf_quant_config.json."
+                )
+            json_quant_configs.update(json_extended_quant_configs)
+            # kv_cache_quant_algo is global regardless of MIXED_PRECISION
+            kv_cache_quant_algo = json_quant_configs.get(
+                'kv_cache_quant_algo', None)
+            mixed_quant_configs = json_quant_configs.get(
+                'quantized_layers', None)
+            if (kv_quant_lhs := json_extended_quant_configs.get(
+                    "kv_cache_quant_algo", None)) is not None and (
+                        kv_quant_rhs :=
+                        quant_config.kv_cache_quant_algo) is not None:
+                if kv_quant_lhs != kv_quant_rhs:
+                    raise RuntimeError(
+                        f"The kvcache config in 'quant_cfg.json', {kv_quant_lhs},"
+                        f"is different from 'hf_quant_config.json', {kv_quant_rhs}!"
+                    )
+            quant_config.kv_cache_quant_algo = json_quant_configs[
+                "kv_cache_quant_algo"]
+            for layer in mixed_quant_configs:
+                config = QuantConfig()
+                config.kv_cache_quant_algo = kv_cache_quant_algo
+                config.quant_algo = mixed_quant_configs[layer]['quant_algo']
+                config.group_size = mixed_quant_configs[layer].get(
+                    'group_size', None)
+                mixed_quant_configs[layer] = config
             layer_quant_config = mixed_quant_configs
         elif quant_config.quant_algo == QuantAlgo.FP8_BLOCK_SCALES:
             if quant_config.group_size is None:
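
To make the precedence in the new block explicit: the layer table is seeded from hf_quant_config.json, then overlaid with whatever quant_cfg.json provides, and an error is raised only when both files name a kv-cache algorithm and the two disagree. Below is a minimal standalone sketch of that merge, using plain dicts instead of the real QuantConfig objects and assuming the same JSON shapes as the tests added later in this commit.

from typing import Optional


def merge_quant_configs(base: dict, extended: Optional[dict]) -> dict:
    # `base` stands in for the "quantization" section of hf_quant_config.json,
    # `extended` for an optional quant_cfg.json.
    merged = dict(base)
    if extended:
        merged.update(extended)  # quant_cfg.json keys win, as in the new code path
    kv_lhs = (extended or {}).get("kv_cache_quant_algo")
    kv_rhs = base.get("kv_cache_quant_algo")
    if kv_lhs is not None and kv_rhs is not None and kv_lhs != kv_rhs:
        raise RuntimeError(
            f"The kvcache config in 'quant_cfg.json', {kv_lhs}, "
            f"is different from 'hf_quant_config.json', {kv_rhs}!")
    return merged


base = {"quant_algo": "MIXED_PRECISION", "kv_cache_quant_algo": None,
        "quantized_layers": {}}
extended = {"kv_cache_quant_algo": "FP8",
            "quantized_layers": {"model.layers.0.mlp.up_proj": {"quant_algo": "FP8"}}}
print(merge_quant_configs(base, extended)["kv_cache_quant_algo"])  # FP8
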
@@ -459,6 +472,9 @@ class ModelConfig(Generic[TConfig]):
         except OSError:
             return None
 
+        # Some checkpoints lack torch_dtype, populate with dtype
+        pretrained_config.torch_dtype = getattr(pretrained_config, 'dtype',
+                                                None)
         quant_config = QuantConfig()
         layer_quant_config = None
         moe_backend = kwargs.get('moe_backend', 'CUTLASS')
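
The three added lines are a small robustness tweak: some checkpoint configs expose dtype rather than torch_dtype, so the loader backfills the attribute. A tiny sketch of the same getattr fallback, where DummyConfig is a stand-in for a transformers PretrainedConfig:

class DummyConfig:
    dtype = "bfloat16"  # some checkpoints only carry `dtype`


cfg = DummyConfig()
# Populate torch_dtype from dtype (None if neither is present).
cfg.torch_dtype = getattr(cfg, 'dtype', None)
print(cfg.torch_dtype)  # bfloat16
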
@@ -390,7 +390,9 @@ class ModelLoader:
             )
 
             for key, value in hf_quant_config.items():
-                logger.info(f"Setting {key}={value} from HF quant config.")
+                logger.info(
+                    f"Setting {key}={str(value)[:100]}{'...' if len(str(value)) > 100 else ''} from HF quant config."
+                )
                 setattr(quant_config, key, value)
 
             # Update the quant_config in llm_args for pytorch
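
The ModelLoader change only affects log readability: with per-layer configs, value can be a large quantized_layers mapping, so it is now truncated to 100 characters. A quick illustration of the truncation expression, with a made-up oversized value:

value = {f"model.layers.{i}.mlp.up_proj": {"quant_algo": "FP8"} for i in range(32)}
text = str(value)
# Prints the first 100 characters followed by '...' when the value is long.
print(f"Setting quantized_layers={text[:100]}{'...' if len(text) > 100 else ''} "
      "from HF quant config.")
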
@@ -1,6 +1,11 @@
+import json
+import tempfile
+from pathlib import Path
+
 import pytest
 
 from tensorrt_llm._tensorrt_engine import LLM
+from tensorrt_llm._torch.model_config import ModelConfig
 from tensorrt_llm.llmapi import KvCacheConfig, SamplingParams
 from tensorrt_llm.llmapi.llm_utils import CalibConfig, QuantAlgo, QuantConfig
 
@@ -71,6 +76,105 @@ def test_llm_fp8_quantization_modelOpt_ckpt():
     assert output.outputs[0].text == " D E F G H I"
 
 
+def test_quant_cfg_from_quant_cfg_json():
+    """
+    Test loading MIXED_PRECISION config from quant_cfg.json with per-layer quantization.
+    This supports the workflow from examples/quantization/quantize_mixed_precision_moe.py.
+    """
+    with tempfile.TemporaryDirectory() as tmp_dir:
+        model_dir = Path(tmp_dir)
+
+        # Create dummy quant_cfg.json
+        quant_cfg_content = {
+            "quant_algo": "MIXED_PRECISION",
+            "kv_cache_quant_algo": "FP8",
+            "quantized_layers": {
+                "model.layers.0.self_attn.q_proj": {
+                    "quant_algo": "FP8"
+                },
+                "model.layers.0.self_attn.k_proj": {
+                    "quant_algo": "FP8"
+                },
+                "model.layers.1.mlp.gate_proj": {
+                    "quant_algo": "W4A8_AWQ",
+                    "group_size": 128
+                }
+            }
+        }
+
+        quant_cfg_file = model_dir / "quant_cfg.json"
+        with open(quant_cfg_file, 'w') as f:
+            json.dump(quant_cfg_content, f)
+
+        # Create dummy hf_quant_config.json
+        hf_quant_config_content = {
+            "quantization": {
+                "quant_algo": "MIXED_PRECISION",
+                "kv_cache_quant_algo": None,
+            }
+        }
+
+        hf_quant_config_file = model_dir / "hf_quant_config.json"
+        with open(hf_quant_config_file, 'w') as f:
+            json.dump(hf_quant_config_content, f)
+
+        quant_config, layer_quant_config = ModelConfig.load_modelopt_quant_config(
+            hf_quant_config_file, model_dir, None)
+
+        # Verify quant_cfg.json was loaded
+        assert quant_config.quant_algo == QuantAlgo.MIXED_PRECISION
+        assert quant_config.kv_cache_quant_algo == "FP8"
+
+        # Verify layer configs were created correctly
+        assert layer_quant_config[
+            "model.layers.0.self_attn.q_proj"].quant_algo == "FP8"
+        assert layer_quant_config[
+            "model.layers.0.self_attn.q_proj"].kv_cache_quant_algo == "FP8"
+        assert layer_quant_config[
+            "model.layers.1.mlp.gate_proj"].quant_algo == "W4A8_AWQ"
+        assert layer_quant_config[
+            "model.layers.1.mlp.gate_proj"].group_size == 128
+
+
+def test_quant_cfg_from_hf_quant_config():
+    """Test fallback to hf_quant_config.json when quant_cfg.json is missing."""
+    with tempfile.TemporaryDirectory() as tmp_dir:
+        model_dir = Path(tmp_dir)
+
+        # Create dummy hf_quant_config.json
+        hf_quant_config_content = {
+            "quantization": {
+                "quant_algo": "MIXED_PRECISION",
+                "kv_cache_quant_algo": "FP8",
+                "quantized_layers": {
+                    "model.layers.0.self_attn.q_proj": {
+                        "quant_algo": "FP8"
+                    },
+                    "model.layers.0.mlp.up_proj": {
+                        "quant_algo": "W4A16_AWQ",
+                        "group_size": 64
+                    }
+                }
+            }
+        }
+        hf_quant_config_file = model_dir / "hf_quant_config.json"
+        with open(hf_quant_config_file, 'w') as f:
+            json.dump(hf_quant_config_content, f)
+        quant_config, layer_quant_config = ModelConfig.load_modelopt_quant_config(
+            hf_quant_config_file, model_dir, None)
+
+        # Verify layer configs
+        assert quant_config.quant_algo == QuantAlgo.MIXED_PRECISION
+        assert quant_config.kv_cache_quant_algo == "FP8"
+        assert layer_quant_config[
+            "model.layers.0.self_attn.q_proj"].quant_algo == "FP8"
+        assert layer_quant_config[
+            "model.layers.0.mlp.up_proj"].quant_algo == "W4A16_AWQ"
+        assert layer_quant_config["model.layers.0.mlp.up_proj"].group_size == 64
+
+
 if __name__ == "__main__":
     test_llm_int4_awq_quantization()
     test_llm_fp8_quantization_modelOpt_ckpt()
+    test_quant_cfg_from_quant_cfg_json()
+    test_quant_cfg_from_hf_quant_config()