Mirror of https://github.com/NVIDIA/TensorRT-LLM.git (synced 2026-02-04 18:21:52 +08:00)
[TRTLLM-9771][feat] Allow overriding quantization configs (#11062)
Signed-off-by: shuyixiong <219646547+shuyixiong@users.noreply.github.com>
This commit is contained in:
parent d1e4527c06
commit 278ced972b
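The change lets `model_kwargs` supply or replace the `quantization_config` that would otherwise be read from a checkpoint's `config.json`. A minimal usage sketch, modeled on the new `test_llm_update_weights_with_quant_config` below; the import path and model path are assumptions, not part of this commit:

```python
# Hedged sketch: load an unquantized checkpoint but force the FP8 block-scales
# quantization config through model_kwargs, overriding anything in config.json.
from tensorrt_llm import LLM

llm = LLM(
    model="Qwen3/Qwen3-8B",            # unquantized checkpoint (assumed path)
    load_format="dummy",               # real weights are pushed later, e.g. via update_weights
    model_kwargs={
        "quantization_config": {       # overrides any value found in config.json
            "quant_method": "fp8",
            "weight_block_size": [128, 128],
            "fmt": "e4m3",
            "activation_scheme": "dynamic",
        },
    },
)
```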
@@ -554,13 +554,7 @@ class ModelConfig(Generic[TConfig]):
             update_dict: Dictionary with values to update in the config
         """
         for key, value_new in update_dict.items():
-            if not hasattr(config, key):
-                logger.warning(
-                    f"model_kwargs key '{key}' not found in pretrained_config, ignoring."
-                )
-                continue
-
-            target_value = getattr(config, key)
+            target_value = getattr(config, key, None)
 
             # Handle nested PretrainedConfig objects when value is a dict
             if isinstance(value_new, dict) and isinstance(
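With the `hasattr` guard gone and `getattr(..., None)` in its place, a `model_kwargs` key that does not yet exist on the pretrained config (such as `quantization_config` for an unquantized model) is now applied rather than skipped. A standalone sketch of that semantics, not the TRT-LLM code itself:

```python
# Minimal sketch of the override loop: absent attributes get None instead of
# triggering a warning-and-skip, so new keys can still be set on the config.
from types import SimpleNamespace

def apply_overrides(config, update_dict):
    for key, value_new in update_dict.items():
        target_value = getattr(config, key, None)  # None when the attribute is absent
        if isinstance(value_new, dict) and isinstance(target_value, SimpleNamespace):
            apply_overrides(target_value, value_new)  # merge into nested config objects
        else:
            setattr(config, key, value_new)

cfg = SimpleNamespace(num_hidden_layers=32)  # no quantization_config attribute yet
apply_overrides(cfg, {"num_hidden_layers": 1,
                      "quantization_config": {"quant_method": "fp8"}})
assert cfg.num_hidden_layers == 1
assert cfg.quantization_config == {"quant_method": "fp8"}
```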
@@ -386,82 +386,91 @@ class ModelLoader:
             return True
 
         hf_config_path = f"{self._model_dir}/config.json"
+        hf_quant_config = None
         if os.path.exists(hf_config_path):
             with open(hf_config_path, "r") as f:
                 hf_config = json.load(f)
             hf_quant_config = hf_config.get("quantization_config", None)
-            if hf_quant_config is not None:
-                logger.info(
-                    f"Found quantization_config field in {hf_config_path}, pre-quantized checkpoint is used."
-                )
+        if self.llm_args.model_kwargs is not None and "quantization_config" in self.llm_args.model_kwargs:
+            logger.info(
+                f"Update hf_quant_config from model_kwargs: quantization_config={self.llm_args.model_kwargs['quantization_config']} (previous value: {hf_quant_config})"
+            )
+            hf_quant_config = self.llm_args.model_kwargs["quantization_config"]
+        elif hf_quant_config is not None:
+            logger.info(
+                f"Use quantization_config from {hf_config_path}: quantization_config={hf_quant_config}"
+            )
+
+        if hf_quant_config is not None:
+            logger.info(
+                f"Found quantization_config field in {hf_config_path}, pre-quantized checkpoint is used."
+            )
             # DeepSeek V3 FP8 ckpt
             if hf_quant_config.get(
                     "quant_method") == "fp8" and hf_quant_config.get(
                         "weight_block_size"):
                 quant_config.quant_algo = QuantAlgo.FP8_BLOCK_SCALES
                 quant_config.exclude_modules = ["*eh_proj"]
             elif hf_quant_config.get("quant_method") == "mxfp4":
                 from .._torch.model_config import ModelConfig
                 quant_config.quant_algo = ModelConfig.get_mxfp4_quant_algo(
                     self.llm_args.moe_config.backend)
                 quant_config.group_size = 32
                 quant_config.exclude_modules = [
                     'block.*.attn.out', 'block.*.mlp.gate', 'block.*.attn.qkv',
                     'embedding', 'unembedding'
                 ]
             # NOTE: This is for llm-compressor's quantized checkpoints.
             elif hf_quant_config.get("quant_method") == "compressed-tensors":
                 config_groups = hf_quant_config.get("config_groups")
                 if config_groups is None:
                     raise ValueError(
                         f"config_groups is not set in {hf_quant_config}.")
 
                 weights_quant_config = config_groups["group_0"]["weights"]
                 inputs_quant_config = config_groups["group_0"][
                     "input_activations"]
                 weights_quant_strategy = weights_quant_config["strategy"]
                 inputs_quant_strategy = inputs_quant_config["strategy"]
 
                 if weights_quant_config["num_bits"] == 8:
                     if weights_quant_strategy == "channel":
                         if inputs_quant_strategy != "token":
                             raise ValueError(
                                 f"Unsupported inputs_quant_strategy: {inputs_quant_strategy}."
                             )
                         quant_config.quant_algo = QuantAlgo.FP8_PER_CHANNEL_PER_TOKEN
                     elif weights_quant_strategy == "block":
                         if inputs_quant_strategy != "group":
                             raise ValueError(
                                 f"Unsupported inputs_quant_strategy: {inputs_quant_strategy}."
                             )
                         quant_config.quant_algo = QuantAlgo.FP8_BLOCK_SCALES
                         group_size = inputs_quant_config["group_size"]
 
                         # NOTE: TRT-LLM only supports group_size=128 for FP8_BLOCK_SCALES.
                         if group_size != 128:
                             raise ValueError(
                                 f"Unsupported group_size: {group_size}. Supported: 128."
                             )
                         quant_config.group_size = group_size
 
                     else:
                         raise ValueError(
                             f"Unsupported weights_quant_strategy: {weights_quant_strategy}. "
                             "Supported strategies: 'channel', 'block'.")
                 else:
                     raise ValueError(
                         f"Unsupported quant_bits: {weights_quant_config['num_bits']}. "
                         "Supported: 8.")
 
                 quant_config.exclude_modules = hf_quant_config.get("ignore", [])
             else:
                 raise NotImplementedError(
                     f"Unsupported quantization_config: {hf_quant_config}.")
 
             return True
 
         return False
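Whichever source `hf_quant_config` came from (the checkpoint's `config.json` or the `model_kwargs` override), the same dispatch above decides the quantization algorithm. A standalone sketch of that mapping, using plain strings in place of the real `QuantAlgo` enum:

```python
# Hedged sketch of the quant_method dispatch; the real loader sets fields on a
# QuantConfig and resolves MXFP4 per MoE backend, which is simplified away here.
def pick_quant_algo(hf_quant_config: dict) -> str:
    method = hf_quant_config.get("quant_method")
    if method == "fp8" and hf_quant_config.get("weight_block_size"):
        return "FP8_BLOCK_SCALES"          # DeepSeek-V3-style FP8 checkpoints
    if method == "mxfp4":
        return "MXFP4"                     # resolved per MoE backend in the real code
    if method == "compressed-tensors":     # llm-compressor checkpoints
        weights = hf_quant_config["config_groups"]["group_0"]["weights"]
        inputs = hf_quant_config["config_groups"]["group_0"]["input_activations"]
        if weights["num_bits"] != 8:
            raise ValueError("only 8-bit compressed-tensors configs are handled")
        if weights["strategy"] == "channel" and inputs["strategy"] == "token":
            return "FP8_PER_CHANNEL_PER_TOKEN"
        if weights["strategy"] == "block" and inputs["strategy"] == "group":
            return "FP8_BLOCK_SCALES"      # group_size must be 128
        raise ValueError("unsupported compressed-tensors strategy combination")
    raise NotImplementedError(f"Unsupported quantization_config: {hf_quant_config}")

print(pick_quant_algo({"quant_method": "fp8", "weight_block_size": [128, 128]}))
# -> FP8_BLOCK_SCALES
```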
@@ -1,10 +1,6 @@
 import base64
-import json
-import os
 import pickle
 import re
-import shutil
-from tempfile import TemporaryDirectory
 from typing import Callable, List, Optional, Tuple
 
 import pytest
@@ -123,34 +119,6 @@ def run_generate(
     return llm_logits, ref_logits
 
 
-def process_and_copy_folder(src_folder, dst_folder, num_hidden_layers: int = 4):
-    if os.path.exists(dst_folder):
-        shutil.rmtree(dst_folder)
-    os.makedirs(dst_folder)
-
-    for root, dirs, files in os.walk(src_folder):
-        rel_path = os.path.relpath(root, src_folder)
-        dest_dir = os.path.join(dst_folder, rel_path)
-
-        if not os.path.exists(dest_dir):
-            os.makedirs(dest_dir)
-
-        for file in files:
-            src_path = os.path.join(root, file)
-            dest_path = os.path.join(dest_dir, file)
-            if "safetensor" in file:
-                continue
-
-            if file == "config.json":
-                with open(src_path, "r", encoding="utf-8") as f:
-                    config = json.load(f)
-                config["num_hidden_layers"] = num_hidden_layers
-                with open(dest_path, "w", encoding="utf-8") as f:
-                    json.dump(config, f, indent=2, ensure_ascii=False)
-            else:
-                shutil.copy2(src_path, dest_path)
-
-
 @pytest.mark.parametrize(
     "model_dir",
     [
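The removed helper copied a checkpoint into a temporary folder and rewrote `num_hidden_layers` in its `config.json`; the updated tests request the same truncation through `model_kwargs` instead. A hedged sketch of the replacement pattern (import path and model path are assumptions):

```python
# Sketch of the pattern that replaces process_and_copy_folder in these tests:
# truncate the model via model_kwargs rather than editing a copied config.json.
from tensorrt_llm import LLM

llm = LLM(
    model="Qwen3/Qwen3-8B",                    # original, unmodified checkpoint
    load_format="dummy",
    model_kwargs={"num_hidden_layers": 1},     # truncate without touching files on disk
)
```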
@@ -163,44 +131,40 @@ def process_and_copy_folder(src_folder, dst_folder, num_hidden_layers: int = 4):
     ],
 )
 def test_llm_update_weights(model_dir):
     """Test LLM update_weights with both serialized and direct IPC handle formats."""
     model_dir = str(llm_models_root() / model_dir)
-    with TemporaryDirectory() as tmp_model_dir:
-        num_hidden_layers = 1
-        process_and_copy_folder(model_dir, tmp_model_dir, num_hidden_layers=num_hidden_layers)
+    num_hidden_layers = 1
     hf_model = RefHFModelWithIPCHandles(model_dir, num_hidden_layers=num_hidden_layers)
     tokenizer = AutoTokenizer.from_pretrained(model_dir)
     kv_cache_config = KvCacheConfig(enable_block_reuse=True, free_gpu_memory_fraction=0.1)
     llm = LLM(
-        model=tmp_model_dir,
+        model=model_dir,
         ray_worker_extension_cls="tensorrt_llm.llmapi.rlhf_utils.WorkerExtension",
         tensor_parallel_size=1,
         load_format="dummy",
         pipeline_parallel_size=1,
         kv_cache_config=kv_cache_config,
+        model_kwargs={"num_hidden_layers": num_hidden_layers},
     )
 
     # Generate texts from the prompts.
     prompts_texts = [
         "Hello, my name is",
         "The president of the United States is",
         "The capital of France is",
         "The future of AI is",
     ]
     prompts = [tokenizer.encode(prompt) for prompt in prompts_texts]
     del tokenizer
     sampling_params = SamplingParams(temperature=0, return_generation_logits=True, max_tokens=1024)
 
     ipc_handles = hf_model.get_weight_ipc_handles_serialized([0])
 
     llm._collective_rpc("update_weights", (ipc_handles,))
     # Finalize the update weights
     llm._collective_rpc("update_weights", (None,))
 
     llm_logits, ref_logits = run_generate(llm, hf_model, prompts, sampling_params)
     compare_logits(llm_logits, ref_logits)
 
 
 @pytest.mark.parametrize(
@@ -216,59 +180,107 @@ def test_llm_update_weights(model_dir):
 )
 def test_llm_partial_update_weights(model_dir):
     model_dir = str(llm_models_root() / model_dir)
-    with TemporaryDirectory() as tmp_model_dir:
-        num_hidden_layers = 1
-        process_and_copy_folder(model_dir, tmp_model_dir, num_hidden_layers=num_hidden_layers)
+    num_hidden_layers = 1
     hf_model = RefHFModelWithIPCHandles(model_dir, num_hidden_layers=num_hidden_layers)
     tokenizer = AutoTokenizer.from_pretrained(model_dir)
     kv_cache_config = KvCacheConfig(enable_block_reuse=True, free_gpu_memory_fraction=0.1)
 
     llm = LLM(
-        model=tmp_model_dir,
+        model=model_dir,
         ray_worker_extension_cls="tensorrt_llm.llmapi.rlhf_utils.WorkerExtension",
         tensor_parallel_size=1,
         load_format="dummy",
         pipeline_parallel_size=1,
         kv_cache_config=kv_cache_config,
+        model_kwargs={"num_hidden_layers": num_hidden_layers},
     )
 
     # Generate texts from the prompts.
     prompts_texts = [
         "Hello, my name is",
         "The president of the United States is",
         "The capital of France is",
         "The future of AI is",
     ]
     prompts = [tokenizer.encode(prompt) for prompt in prompts_texts]
     del tokenizer
 
     sampling_params = SamplingParams(temperature=0, return_generation_logits=True, max_tokens=1024)
 
     def common_filter(filter_name: str) -> Callable[[str], bool]:
         def filter_fn(name: str) -> bool:
             return name.endswith(filter_name)
 
         return filter_fn
 
     # Generate filter_list from model weight keys by removing layer prefix
     # e.g., "model.layers.41.input_layernorm.weight" -> "input_layernorm.weight"
     layer_prefix_pattern = re.compile(r"^model\.layers\.\d+\.")
     filter_set = set()
     for name, _ in hf_model.all_weights[hf_model.device_id]:
         suffix = layer_prefix_pattern.sub("", name)
         filter_set.add(suffix)
     filter_list = list(filter_set)
 
     for filter_name in filter_list:
         weight_filter = common_filter(filter_name=filter_name)
         ipc_handles = hf_model.get_weight_ipc_handles_serialized([0], weight_filter=weight_filter)
         llm._collective_rpc("update_weights", (ipc_handles,))
         # Finalize the update weights
         llm._collective_rpc("update_weights", (None,))
 
     llm_logits, ref_logits = run_generate(llm, hf_model, prompts, sampling_params)
     compare_logits(llm_logits, ref_logits)
 
 
+@pytest.mark.parametrize(
+    "model_dir, fp8_model_dir",
+    [
+        ("Qwen3/Qwen3-8B", "Qwen3/Qwen3-8B-FP8"),
+        ("Qwen3/Qwen3-30B-A3B", "Qwen3/Qwen3-30B-A3B-FP8"),
+    ],
+)
+def test_llm_update_weights_with_quant_config(model_dir, fp8_model_dir):
+    model_dir = str(llm_models_root() / model_dir)
+    fp8_model_dir = str(llm_models_root() / fp8_model_dir)
+    num_hidden_layers = 1
+    hf_model = RefHFModelWithIPCHandles(fp8_model_dir, num_hidden_layers=num_hidden_layers)
+    tokenizer = AutoTokenizer.from_pretrained(fp8_model_dir)
+    kv_cache_config = KvCacheConfig(enable_block_reuse=True, free_gpu_memory_fraction=0.1)
+    llm = LLM(
+        model=model_dir,
+        ray_worker_extension_cls="tensorrt_llm.llmapi.rlhf_utils.WorkerExtension",
+        tensor_parallel_size=1,
+        load_format="dummy",
+        pipeline_parallel_size=1,
+        kv_cache_config=kv_cache_config,
+        model_kwargs={
+            "num_hidden_layers": num_hidden_layers,
+            "quantization_config": {
+                "activation_scheme": "dynamic",
+                "fmt": "e4m3",
+                "quant_method": "fp8",
+                "weight_block_size": [128, 128],
+            },
+        },
+    )
+
+    # Generate texts from the prompts.
+    prompts_texts = [
+        "Hello, my name is",
+        "The president of the United States is",
+        "The capital of France is",
+        "The future of AI is",
+    ]
+    prompts = [tokenizer.encode(prompt) for prompt in prompts_texts]
+    del tokenizer
+    sampling_params = SamplingParams(temperature=0, return_generation_logits=True, max_tokens=1024)
+
+    ipc_handles = hf_model.get_weight_ipc_handles_serialized([0])
+
+    llm._collective_rpc("update_weights", (ipc_handles,))
+    # Finalize the update weights
+    llm._collective_rpc("update_weights", (None,))
+
+    llm_logits, ref_logits = run_generate(llm, hf_model, prompts, sampling_params)
+    compare_logits(llm_logits, ref_logits)