[TRTLLM-9771][feat] Allow overriding quantization configs (#11062)

Signed-off-by: shuyixiong <219646547+shuyixiong@users.noreply.github.com>
Author: shuyixiong
Date: 2026-01-31 23:48:51 +08:00 (committed by GitHub)
Parent: d1e4527c06
Commit: 278ced972b
3 changed files with 199 additions and 184 deletions
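In effect, a quantization_config supplied through the LLM API's model_kwargs now overrides (or stands in for) the one read from the checkpoint's config.json; the new test uses this to push FP8 weights via update_weights on top of an unquantized checkpoint definition. A minimal usage sketch, mirroring that test (the model path and FP8 settings are illustrative, not a required configuration):

    from tensorrt_llm import LLM

    # Build the engine from an unquantized checkpoint, but tell TRT-LLM to treat it as an
    # FP8 block-scaled model; the real FP8 weights are supplied afterwards (e.g. via
    # update_weights, as in the new test at the end of this diff).
    llm = LLM(
        model="/path/to/Qwen3-8B",  # placeholder path
        load_format="dummy",  # skip loading the original BF16 weights
        model_kwargs={
            "quantization_config": {
                "activation_scheme": "dynamic",
                "fmt": "e4m3",
                "quant_method": "fp8",
                "weight_block_size": [128, 128],
            },
        },
    )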


@@ -554,13 +554,7 @@ class ModelConfig(Generic[TConfig]):
             update_dict: Dictionary with values to update in the config
         """
         for key, value_new in update_dict.items():
-            if not hasattr(config, key):
-                logger.warning(
-                    f"model_kwargs key '{key}' not found in pretrained_config, ignoring."
-                )
-                continue
-            target_value = getattr(config, key)
+            target_value = getattr(config, key, None)
             # Handle nested PretrainedConfig objects when value is a dict
             if isinstance(value_new, dict) and isinstance(

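With the hasattr guard removed above, keys from model_kwargs are applied to the pretrained config even when the attribute does not exist yet (previously they were warned about and skipped), and getattr(config, key, None) keeps the nested-dict handling working for missing keys. This is what allows model_kwargs to introduce keys, such as quantization_config, that the base checkpoint's config does not define. A standalone sketch of the update pattern, using SimpleNamespace in place of a HF PretrainedConfig and a hypothetical update_config helper (not the actual TensorRT-LLM function):

    from types import SimpleNamespace

    def update_config(config, update_dict):
        # Hypothetical helper: apply a dict of overrides onto a config object,
        # recursing into nested config objects when the value is a dict.
        for key, value_new in update_dict.items():
            target_value = getattr(config, key, None)  # missing attributes are no longer skipped
            if isinstance(value_new, dict) and isinstance(target_value, SimpleNamespace):
                update_config(target_value, value_new)
            else:
                setattr(config, key, value_new)

    config = SimpleNamespace(num_hidden_layers=32, rope_scaling=SimpleNamespace(factor=1.0))
    update_config(config, {
        "num_hidden_layers": 1,                          # existing key: overridden
        "rope_scaling": {"factor": 8.0},                 # nested config: updated in place
        "quantization_config": {"quant_method": "fp8"},  # new key: now added instead of ignored
    })
    print(config.num_hidden_layers, config.rope_scaling.factor, config.quantization_config)
    # 1 8.0 {'quant_method': 'fp8'}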

@@ -386,82 +386,91 @@ class ModelLoader:
             return True
 
         hf_config_path = f"{self._model_dir}/config.json"
+        hf_quant_config = None
         if os.path.exists(hf_config_path):
             with open(hf_config_path, "r") as f:
                 hf_config = json.load(f)
                 hf_quant_config = hf_config.get("quantization_config", None)
-                if hf_quant_config is not None:
-                    logger.info(
-                        f"Found quantization_config field in {hf_config_path}, pre-quantized checkpoint is used."
-                    )
+        if self.llm_args.model_kwargs is not None and "quantization_config" in self.llm_args.model_kwargs:
+            logger.info(
+                f"Update hf_quant_config from model_kwargs: quantization_config={self.llm_args.model_kwargs['quantization_config']} (previous value: {hf_quant_config})"
+            )
+            hf_quant_config = self.llm_args.model_kwargs["quantization_config"]
+        elif hf_quant_config is not None:
+            logger.info(
+                f"Use quantization_config from {hf_config_path}: quantization_config={hf_quant_config}"
+            )
         if hf_quant_config is not None:
             # DeepSeek V3 FP8 ckpt
             if hf_quant_config.get(
                     "quant_method") == "fp8" and hf_quant_config.get(
                         "weight_block_size"):
                 quant_config.quant_algo = QuantAlgo.FP8_BLOCK_SCALES
                 quant_config.exclude_modules = ["*eh_proj"]
             elif hf_quant_config.get("quant_method") == "mxfp4":
                 from .._torch.model_config import ModelConfig
                 quant_config.quant_algo = ModelConfig.get_mxfp4_quant_algo(
                     self.llm_args.moe_config.backend)
                 quant_config.group_size = 32
                 quant_config.exclude_modules = [
                     'block.*.attn.out', 'block.*.mlp.gate', 'block.*.attn.qkv',
                     'embedding', 'unembedding'
                 ]
             # NOTE: This is for llm-compressor's quantized checkpoints.
             elif hf_quant_config.get("quant_method") == "compressed-tensors":
                 config_groups = hf_quant_config.get("config_groups")
                 if config_groups is None:
                     raise ValueError(
                         f"config_groups is not set in {hf_quant_config}.")
                 weights_quant_config = config_groups["group_0"]["weights"]
                 inputs_quant_config = config_groups["group_0"][
                     "input_activations"]
                 weights_quant_strategy = weights_quant_config["strategy"]
                 inputs_quant_strategy = inputs_quant_config["strategy"]
 
                 if weights_quant_config["num_bits"] == 8:
                     if weights_quant_strategy == "channel":
                         if inputs_quant_strategy != "token":
                             raise ValueError(
                                 f"Unsupported inputs_quant_strategy: {inputs_quant_strategy}."
                             )
                         quant_config.quant_algo = QuantAlgo.FP8_PER_CHANNEL_PER_TOKEN
                     elif weights_quant_strategy == "block":
                         if inputs_quant_strategy != "group":
                             raise ValueError(
                                 f"Unsupported inputs_quant_strategy: {inputs_quant_strategy}."
                             )
                         quant_config.quant_algo = QuantAlgo.FP8_BLOCK_SCALES
                         group_size = inputs_quant_config["group_size"]
                         # NOTE: TRT-LLM only supports group_size=128 for FP8_BLOCK_SCALES.
                         if group_size != 128:
                             raise ValueError(
                                 f"Unsupported group_size: {group_size}. Supported: 128."
                             )
                         quant_config.group_size = group_size
                     else:
                         raise ValueError(
                             f"Unsupported weights_quant_strategy: {weights_quant_strategy}. "
                             "Supported strategies: 'channel', 'block'.")
                 else:
                     raise ValueError(
                         f"Unsupported quant_bits: {weights_quant_config['num_bits']}. "
                         "Supported: 8.")
                 quant_config.exclude_modules = hf_quant_config.get("ignore", [])
             else:
                 raise NotImplementedError(
                     f"Unsupported quantization_config: {hf_quant_config}.")
             return True
         return False

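The net behavior of the block above: hf_quant_config starts from config.json (if present), is then replaced by model_kwargs["quantization_config"] when that is supplied, and only afterwards is quant_method dispatched to a QuantAlgo. A condensed, self-contained sketch of that precedence (resolve_quant_config is an illustrative helper, not the actual ModelLoader method):

    import json
    import os
    from typing import Optional

    def resolve_quant_config(model_dir: str, model_kwargs: Optional[dict]) -> Optional[dict]:
        """Illustrative helper mirroring the precedence above, not the actual ModelLoader code."""
        hf_quant_config = None
        config_path = os.path.join(model_dir, "config.json")
        if os.path.exists(config_path):
            with open(config_path, "r") as f:
                hf_quant_config = json.load(f).get("quantization_config", None)
        # An override passed through model_kwargs wins over whatever config.json declares.
        if model_kwargs is not None and "quantization_config" in model_kwargs:
            hf_quant_config = model_kwargs["quantization_config"]
        return hf_quant_config

    # Forcing FP8 block scales onto an unquantized checkpoint (values taken from the new test below).
    override = {
        "quant_method": "fp8",
        "weight_block_size": [128, 128],
        "activation_scheme": "dynamic",
        "fmt": "e4m3",
    }
    print(resolve_quant_config("/path/to/unquantized/checkpoint", {"quantization_config": override}))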

@@ -1,10 +1,6 @@
 import base64
-import json
-import os
 import pickle
 import re
-import shutil
-from tempfile import TemporaryDirectory
 from typing import Callable, List, Optional, Tuple
 
 import pytest
@@ -123,34 +119,6 @@ def run_generate(
     return llm_logits, ref_logits
 
 
-def process_and_copy_folder(src_folder, dst_folder, num_hidden_layers: int = 4):
-    if os.path.exists(dst_folder):
-        shutil.rmtree(dst_folder)
-    os.makedirs(dst_folder)
-
-    for root, dirs, files in os.walk(src_folder):
-        rel_path = os.path.relpath(root, src_folder)
-        dest_dir = os.path.join(dst_folder, rel_path)
-        if not os.path.exists(dest_dir):
-            os.makedirs(dest_dir)
-
-        for file in files:
-            src_path = os.path.join(root, file)
-            dest_path = os.path.join(dest_dir, file)
-            if "safetensor" in file:
-                continue
-            if file == "config.json":
-                with open(src_path, "r", encoding="utf-8") as f:
-                    config = json.load(f)
-                config["num_hidden_layers"] = num_hidden_layers
-                with open(dest_path, "w", encoding="utf-8") as f:
-                    json.dump(config, f, indent=2, ensure_ascii=False)
-            else:
-                shutil.copy2(src_path, dest_path)
-
-
 @pytest.mark.parametrize(
     "model_dir",
     [
@@ -163,44 +131,40 @@ def process_and_copy_folder(src_folder, dst_folder, num_hidden_layers: int = 4):
     ],
 )
 def test_llm_update_weights(model_dir):
     """Test LLM update_weights with both serialized and direct IPC handle formats."""
     model_dir = str(llm_models_root() / model_dir)
-    with TemporaryDirectory() as tmp_model_dir:
-        process_and_copy_folder(model_dir, tmp_model_dir, num_hidden_layers=num_hidden_layers)
     num_hidden_layers = 1
     hf_model = RefHFModelWithIPCHandles(model_dir, num_hidden_layers=num_hidden_layers)
     tokenizer = AutoTokenizer.from_pretrained(model_dir)
     kv_cache_config = KvCacheConfig(enable_block_reuse=True, free_gpu_memory_fraction=0.1)
     llm = LLM(
-        model=tmp_model_dir,
+        model=model_dir,
         ray_worker_extension_cls="tensorrt_llm.llmapi.rlhf_utils.WorkerExtension",
         tensor_parallel_size=1,
         load_format="dummy",
         pipeline_parallel_size=1,
         kv_cache_config=kv_cache_config,
+        model_kwargs={"num_hidden_layers": num_hidden_layers},
     )
 
     # Generate texts from the prompts.
     prompts_texts = [
         "Hello, my name is",
         "The president of the United States is",
         "The capital of France is",
         "The future of AI is",
     ]
     prompts = [tokenizer.encode(prompt) for prompt in prompts_texts]
     del tokenizer
-    sampling_params = SamplingParams(
-        temperature=0, return_generation_logits=True, max_tokens=1024
-    )
+    sampling_params = SamplingParams(temperature=0, return_generation_logits=True, max_tokens=1024)
 
     ipc_handles = hf_model.get_weight_ipc_handles_serialized([0])
     llm._collective_rpc("update_weights", (ipc_handles,))
     # Finalize the update weights
     llm._collective_rpc("update_weights", (None,))
 
     llm_logits, ref_logits = run_generate(llm, hf_model, prompts, sampling_params)
     compare_logits(llm_logits, ref_logits)
 
 
 @pytest.mark.parametrize(
@@ -216,59 +180,107 @@ def test_llm_update_weights(model_dir):
 )
 def test_llm_partial_update_weights(model_dir):
     model_dir = str(llm_models_root() / model_dir)
-    with TemporaryDirectory() as tmp_model_dir:
-        process_and_copy_folder(model_dir, tmp_model_dir, num_hidden_layers=num_hidden_layers)
     num_hidden_layers = 1
     hf_model = RefHFModelWithIPCHandles(model_dir, num_hidden_layers=num_hidden_layers)
     tokenizer = AutoTokenizer.from_pretrained(model_dir)
     kv_cache_config = KvCacheConfig(enable_block_reuse=True, free_gpu_memory_fraction=0.1)
     llm = LLM(
-        model=tmp_model_dir,
+        model=model_dir,
         ray_worker_extension_cls="tensorrt_llm.llmapi.rlhf_utils.WorkerExtension",
         tensor_parallel_size=1,
         load_format="dummy",
         pipeline_parallel_size=1,
         kv_cache_config=kv_cache_config,
+        model_kwargs={"num_hidden_layers": num_hidden_layers},
     )
 
     # Generate texts from the prompts.
     prompts_texts = [
         "Hello, my name is",
         "The president of the United States is",
         "The capital of France is",
         "The future of AI is",
     ]
     prompts = [tokenizer.encode(prompt) for prompt in prompts_texts]
     del tokenizer
-    sampling_params = SamplingParams(
-        temperature=0, return_generation_logits=True, max_tokens=1024
-    )
+    sampling_params = SamplingParams(temperature=0, return_generation_logits=True, max_tokens=1024)
 
     def common_filter(filter_name: str) -> Callable[[str], bool]:
         def filter_fn(name: str) -> bool:
             return name.endswith(filter_name)
 
         return filter_fn
 
     # Generate filter_list from model weight keys by removing layer prefix
     # e.g., "model.layers.41.input_layernorm.weight" -> "input_layernorm.weight"
     layer_prefix_pattern = re.compile(r"^model\.layers\.\d+\.")
     filter_set = set()
     for name, _ in hf_model.all_weights[hf_model.device_id]:
         suffix = layer_prefix_pattern.sub("", name)
         filter_set.add(suffix)
     filter_list = list(filter_set)
 
     for filter_name in filter_list:
         weight_filter = common_filter(filter_name=filter_name)
-        ipc_handles = hf_model.get_weight_ipc_handles_serialized(
-            [0], weight_filter=weight_filter
-        )
+        ipc_handles = hf_model.get_weight_ipc_handles_serialized([0], weight_filter=weight_filter)
         llm._collective_rpc("update_weights", (ipc_handles,))
         # Finalize the update weights
         llm._collective_rpc("update_weights", (None,))
 
     llm_logits, ref_logits = run_generate(llm, hf_model, prompts, sampling_params)
     compare_logits(llm_logits, ref_logits)
+
+
+@pytest.mark.parametrize(
+    "model_dir, fp8_model_dir",
+    [
+        ("Qwen3/Qwen3-8B", "Qwen3/Qwen3-8B-FP8"),
+        ("Qwen3/Qwen3-30B-A3B", "Qwen3/Qwen3-30B-A3B-FP8"),
+    ],
+)
+def test_llm_update_weights_with_quant_config(model_dir, fp8_model_dir):
+    model_dir = str(llm_models_root() / model_dir)
+    fp8_model_dir = str(llm_models_root() / fp8_model_dir)
+    num_hidden_layers = 1
+    hf_model = RefHFModelWithIPCHandles(fp8_model_dir, num_hidden_layers=num_hidden_layers)
+    tokenizer = AutoTokenizer.from_pretrained(fp8_model_dir)
+    kv_cache_config = KvCacheConfig(enable_block_reuse=True, free_gpu_memory_fraction=0.1)
+    llm = LLM(
+        model=model_dir,
+        ray_worker_extension_cls="tensorrt_llm.llmapi.rlhf_utils.WorkerExtension",
+        tensor_parallel_size=1,
+        load_format="dummy",
+        pipeline_parallel_size=1,
+        kv_cache_config=kv_cache_config,
+        model_kwargs={
+            "num_hidden_layers": num_hidden_layers,
+            "quantization_config": {
+                "activation_scheme": "dynamic",
+                "fmt": "e4m3",
+                "quant_method": "fp8",
+                "weight_block_size": [128, 128],
+            },
+        },
+    )
+
+    # Generate texts from the prompts.
+    prompts_texts = [
+        "Hello, my name is",
+        "The president of the United States is",
+        "The capital of France is",
+        "The future of AI is",
+    ]
+    prompts = [tokenizer.encode(prompt) for prompt in prompts_texts]
+    del tokenizer
+    sampling_params = SamplingParams(temperature=0, return_generation_logits=True, max_tokens=1024)
+
+    ipc_handles = hf_model.get_weight_ipc_handles_serialized([0])
+    llm._collective_rpc("update_weights", (ipc_handles,))
+    # Finalize the update weights
+    llm._collective_rpc("update_weights", (None,))
+
+    llm_logits, ref_logits = run_generate(llm, hf_model, prompts, sampling_params)
+    compare_logits(llm_logits, ref_logits)