diff --git a/tensorrt_llm/_torch/model_config.py b/tensorrt_llm/_torch/model_config.py
index b98aa10055..5da1629876 100644
--- a/tensorrt_llm/_torch/model_config.py
+++ b/tensorrt_llm/_torch/model_config.py
@@ -554,13 +554,7 @@ class ModelConfig(Generic[TConfig]):
             update_dict: Dictionary with values to update in the config
         """
         for key, value_new in update_dict.items():
-            if not hasattr(config, key):
-                logger.warning(
-                    f"model_kwargs key '{key}' not found in pretrained_config, ignoring."
-                )
-                continue
-
-            target_value = getattr(config, key)
+            target_value = getattr(config, key, None)
 
             # Handle nested PretrainedConfig objects when value is a dict
             if isinstance(value_new, dict) and isinstance(
diff --git a/tensorrt_llm/llmapi/llm_utils.py b/tensorrt_llm/llmapi/llm_utils.py
index 83e26ce93e..0a908ff31e 100644
--- a/tensorrt_llm/llmapi/llm_utils.py
+++ b/tensorrt_llm/llmapi/llm_utils.py
@@ -386,82 +386,91 @@ class ModelLoader:
                 return True
 
         hf_config_path = f"{self._model_dir}/config.json"
+        hf_quant_config = None
         if os.path.exists(hf_config_path):
             with open(hf_config_path, "r") as f:
                 hf_config = json.load(f)
                 hf_quant_config = hf_config.get("quantization_config", None)
+                if hf_quant_config is not None:
+                    logger.info(
+                        f"Found quantization_config field in {hf_config_path}, pre-quantized checkpoint is used."
+                    )
+
+        if self.llm_args.model_kwargs is not None and "quantization_config" in self.llm_args.model_kwargs:
+            logger.info(
+                f"Update hf_quant_config from model_kwargs: quantization_config={self.llm_args.model_kwargs['quantization_config']} (previous value: {hf_quant_config})"
+            )
+            hf_quant_config = self.llm_args.model_kwargs["quantization_config"]
+        elif hf_quant_config is not None:
+            logger.info(
+                f"Use quantization_config from {hf_config_path}: quantization_config={hf_quant_config}"
+            )
 
-            if hf_quant_config is not None:
-                logger.info(
-                    f"Found quantization_config field in {hf_config_path}, pre-quantized checkpoint is used."
-                )
-                # DeepSeek V3 FP8 ckpt
-                if hf_quant_config.get(
-                        "quant_method") == "fp8" and hf_quant_config.get(
-                            "weight_block_size"):
-                    quant_config.quant_algo = QuantAlgo.FP8_BLOCK_SCALES
-                    quant_config.exclude_modules = ["*eh_proj"]
-                elif hf_quant_config.get("quant_method") == "mxfp4":
-                    from .._torch.model_config import ModelConfig
-                    quant_config.quant_algo = ModelConfig.get_mxfp4_quant_algo(
-                        self.llm_args.moe_config.backend)
-                    quant_config.group_size = 32
-                    quant_config.exclude_modules = [
-                        'block.*.attn.out', 'block.*.mlp.gate',
-                        'block.*.attn.qkv', 'embedding', 'unembedding'
-                    ]
-                # NOTE: This is for llm-compressor's quantized checkpoints.
-                elif hf_quant_config.get(
-                        "quant_method") == "compressed-tensors":
-                    config_groups = hf_quant_config.get("config_groups")
-                    if config_groups is None:
-                        raise ValueError(
-                            f"config_groups is not set in {hf_quant_config}.")
+        if hf_quant_config is not None:
+            # DeepSeek V3 FP8 ckpt
+            if hf_quant_config.get(
+                    "quant_method") == "fp8" and hf_quant_config.get(
+                        "weight_block_size"):
+                quant_config.quant_algo = QuantAlgo.FP8_BLOCK_SCALES
+                quant_config.exclude_modules = ["*eh_proj"]
+            elif hf_quant_config.get("quant_method") == "mxfp4":
+                from .._torch.model_config import ModelConfig
+                quant_config.quant_algo = ModelConfig.get_mxfp4_quant_algo(
+                    self.llm_args.moe_config.backend)
+                quant_config.group_size = 32
+                quant_config.exclude_modules = [
+                    'block.*.attn.out', 'block.*.mlp.gate', 'block.*.attn.qkv',
+                    'embedding', 'unembedding'
+                ]
+            # NOTE: This is for llm-compressor's quantized checkpoints.
+            elif hf_quant_config.get("quant_method") == "compressed-tensors":
+                config_groups = hf_quant_config.get("config_groups")
+                if config_groups is None:
+                    raise ValueError(
+                        f"config_groups is not set in {hf_quant_config}.")
 
-                    weights_quant_config = config_groups["group_0"]["weights"]
-                    inputs_quant_config = config_groups["group_0"][
-                        "input_activations"]
-                    weights_quant_strategy = weights_quant_config["strategy"]
-                    inputs_quant_strategy = inputs_quant_config["strategy"]
+                weights_quant_config = config_groups["group_0"]["weights"]
+                inputs_quant_config = config_groups["group_0"][
+                    "input_activations"]
+                weights_quant_strategy = weights_quant_config["strategy"]
+                inputs_quant_strategy = inputs_quant_config["strategy"]
 
-                    if weights_quant_config["num_bits"] == 8:
-                        if weights_quant_strategy == "channel":
-                            if inputs_quant_strategy != "token":
-                                raise ValueError(
-                                    f"Unsupported inputs_quant_strategy: {inputs_quant_strategy}."
-                                )
-                            quant_config.quant_algo = QuantAlgo.FP8_PER_CHANNEL_PER_TOKEN
-                        elif weights_quant_strategy == "block":
-                            if inputs_quant_strategy != "group":
-                                raise ValueError(
-                                    f"Unsupported inputs_quant_strategy: {inputs_quant_strategy}."
-                                )
-                            quant_config.quant_algo = QuantAlgo.FP8_BLOCK_SCALES
-                            group_size = inputs_quant_config["group_size"]
-
-                            # NOTE: TRT-LLM only supports group_size=128 for FP8_BLOCK_SCALES.
-                            if group_size != 128:
-                                raise ValueError(
-                                    f"Unsupported group_size: {group_size}. Supported: 128."
-                                )
-                            quant_config.group_size = group_size
-
-                        else:
-                            raise ValueError(
-                                f"Unsupported weights_quant_strategy: {weights_quant_strategy}. "
-                                "Supported strategies: 'channel', 'block'.")
-                    else:
-                        raise ValueError(
-                            f"Unsupported quant_bits: {weights_quant_config['num_bits']}. "
-                            "Supported: 8.")
-
-                    quant_config.exclude_modules = hf_quant_config.get(
-                        "ignore", [])
-                else:
-                    raise NotImplementedError(
-                        f"Unsupported quantization_config: {hf_quant_config}.")
-
-                return True
+                if weights_quant_config["num_bits"] == 8:
+                    if weights_quant_strategy == "channel":
+                        if inputs_quant_strategy != "token":
+                            raise ValueError(
+                                f"Unsupported inputs_quant_strategy: {inputs_quant_strategy}."
+                            )
+                        quant_config.quant_algo = QuantAlgo.FP8_PER_CHANNEL_PER_TOKEN
+                    elif weights_quant_strategy == "block":
+                        if inputs_quant_strategy != "group":
+                            raise ValueError(
+                                f"Unsupported inputs_quant_strategy: {inputs_quant_strategy}."
+                            )
+                        quant_config.quant_algo = QuantAlgo.FP8_BLOCK_SCALES
+                        group_size = inputs_quant_config["group_size"]
+
+                        # NOTE: TRT-LLM only supports group_size=128 for FP8_BLOCK_SCALES.
+                        if group_size != 128:
+                            raise ValueError(
+                                f"Unsupported group_size: {group_size}. Supported: 128."
+                            )
+                        quant_config.group_size = group_size
+                    else:
+                        raise ValueError(
+                            f"Unsupported weights_quant_strategy: {weights_quant_strategy}. "
+                            "Supported strategies: 'channel', 'block'.")
+                else:
+                    raise ValueError(
+                        f"Unsupported quant_bits: {weights_quant_config['num_bits']}. "
+                        "Supported: 8.")
+
+                quant_config.exclude_modules = hf_quant_config.get("ignore", [])
+            else:
+                raise NotImplementedError(
+                    f"Unsupported quantization_config: {hf_quant_config}.")
+
+            return True
 
         return False
 
diff --git a/tests/unittest/_torch/ray_orchestrator/single_gpu/test_llm_update_weights.py b/tests/unittest/_torch/ray_orchestrator/single_gpu/test_llm_update_weights.py
index 1955e5c509..073067907a 100644
--- a/tests/unittest/_torch/ray_orchestrator/single_gpu/test_llm_update_weights.py
+++ b/tests/unittest/_torch/ray_orchestrator/single_gpu/test_llm_update_weights.py
@@ -1,10 +1,6 @@
 import base64
-import json
-import os
 import pickle
 import re
-import shutil
-from tempfile import TemporaryDirectory
 from typing import Callable, List, Optional, Tuple
 
 import pytest
@@ -123,34 +119,6 @@ def run_generate(
     return llm_logits, ref_logits
 
 
-def process_and_copy_folder(src_folder, dst_folder, num_hidden_layers: int = 4):
-    if os.path.exists(dst_folder):
-        shutil.rmtree(dst_folder)
-    os.makedirs(dst_folder)
-
-    for root, dirs, files in os.walk(src_folder):
-        rel_path = os.path.relpath(root, src_folder)
-        dest_dir = os.path.join(dst_folder, rel_path)
-
-        if not os.path.exists(dest_dir):
-            os.makedirs(dest_dir)
-
-        for file in files:
-            src_path = os.path.join(root, file)
-            dest_path = os.path.join(dest_dir, file)
-            if "safetensor" in file:
-                continue
-
-            if file == "config.json":
-                with open(src_path, "r", encoding="utf-8") as f:
-                    config = json.load(f)
-                config["num_hidden_layers"] = num_hidden_layers
-                with open(dest_path, "w", encoding="utf-8") as f:
-                    json.dump(config, f, indent=2, ensure_ascii=False)
-            else:
-                shutil.copy2(src_path, dest_path)
-
-
 @pytest.mark.parametrize(
     "model_dir",
     [
@@ -163,44 +131,40 @@ def process_and_copy_folder(src_folder, dst_folder, num_hidden_layers: int = 4):
     ],
 )
 def test_llm_update_weights(model_dir):
-    """Test LLM update_weights with both serialized and direct IPC handle formats."""
     model_dir = str(llm_models_root() / model_dir)
-    with TemporaryDirectory() as tmp_model_dir:
-        num_hidden_layers = 1
-        process_and_copy_folder(model_dir, tmp_model_dir, num_hidden_layers=num_hidden_layers)
-        hf_model = RefHFModelWithIPCHandles(model_dir, num_hidden_layers=num_hidden_layers)
-        tokenizer = AutoTokenizer.from_pretrained(model_dir)
-        kv_cache_config = KvCacheConfig(enable_block_reuse=True, free_gpu_memory_fraction=0.1)
-        llm = LLM(
-            model=tmp_model_dir,
-            ray_worker_extension_cls="tensorrt_llm.llmapi.rlhf_utils.WorkerExtension",
-            tensor_parallel_size=1,
-            load_format="dummy",
-            pipeline_parallel_size=1,
-            kv_cache_config=kv_cache_config,
-        )
+    num_hidden_layers = 1
+    hf_model = RefHFModelWithIPCHandles(model_dir, num_hidden_layers=num_hidden_layers)
+    tokenizer = AutoTokenizer.from_pretrained(model_dir)
+    kv_cache_config = KvCacheConfig(enable_block_reuse=True, free_gpu_memory_fraction=0.1)
+    llm = LLM(
+        model=model_dir,
+        ray_worker_extension_cls="tensorrt_llm.llmapi.rlhf_utils.WorkerExtension",
+        tensor_parallel_size=1,
+        load_format="dummy",
+        pipeline_parallel_size=1,
+        kv_cache_config=kv_cache_config,
+        model_kwargs={"num_hidden_layers": num_hidden_layers},
+    )
 
-        # Generate texts from the prompts.
-        prompts_texts = [
-            "Hello, my name is",
-            "The president of the United States is",
-            "The capital of France is",
-            "The future of AI is",
-        ]
-        prompts = [tokenizer.encode(prompt) for prompt in prompts_texts]
-        del tokenizer
-        sampling_params = SamplingParams(
-            temperature=0, return_generation_logits=True, max_tokens=1024
-        )
+    # Generate texts from the prompts.
+    prompts_texts = [
+        "Hello, my name is",
+        "The president of the United States is",
+        "The capital of France is",
+        "The future of AI is",
+    ]
+    prompts = [tokenizer.encode(prompt) for prompt in prompts_texts]
+    del tokenizer
+    sampling_params = SamplingParams(temperature=0, return_generation_logits=True, max_tokens=1024)
 
-        ipc_handles = hf_model.get_weight_ipc_handles_serialized([0])
+    ipc_handles = hf_model.get_weight_ipc_handles_serialized([0])
 
-        llm._collective_rpc("update_weights", (ipc_handles,))
-        # Finalize the update weights
-        llm._collective_rpc("update_weights", (None,))
+    llm._collective_rpc("update_weights", (ipc_handles,))
+    # Finalize the update weights
+    llm._collective_rpc("update_weights", (None,))
 
-        llm_logits, ref_logits = run_generate(llm, hf_model, prompts, sampling_params)
-        compare_logits(llm_logits, ref_logits)
+    llm_logits, ref_logits = run_generate(llm, hf_model, prompts, sampling_params)
+    compare_logits(llm_logits, ref_logits)
 
 
 @pytest.mark.parametrize(
@@ -216,59 +180,107 @@ def test_llm_update_weights(model_dir):
 )
 def test_llm_partial_update_weights(model_dir):
     model_dir = str(llm_models_root() / model_dir)
-    with TemporaryDirectory() as tmp_model_dir:
-        num_hidden_layers = 1
-        process_and_copy_folder(model_dir, tmp_model_dir, num_hidden_layers=num_hidden_layers)
-        hf_model = RefHFModelWithIPCHandles(model_dir, num_hidden_layers=num_hidden_layers)
-        tokenizer = AutoTokenizer.from_pretrained(model_dir)
-        kv_cache_config = KvCacheConfig(enable_block_reuse=True, free_gpu_memory_fraction=0.1)
+    num_hidden_layers = 1
+    hf_model = RefHFModelWithIPCHandles(model_dir, num_hidden_layers=num_hidden_layers)
+    tokenizer = AutoTokenizer.from_pretrained(model_dir)
+    kv_cache_config = KvCacheConfig(enable_block_reuse=True, free_gpu_memory_fraction=0.1)
 
-        llm = LLM(
-            model=tmp_model_dir,
-            ray_worker_extension_cls="tensorrt_llm.llmapi.rlhf_utils.WorkerExtension",
-            tensor_parallel_size=1,
-            load_format="dummy",
-            pipeline_parallel_size=1,
-            kv_cache_config=kv_cache_config,
-        )
+    llm = LLM(
+        model=model_dir,
+        ray_worker_extension_cls="tensorrt_llm.llmapi.rlhf_utils.WorkerExtension",
+        tensor_parallel_size=1,
+        load_format="dummy",
+        pipeline_parallel_size=1,
+        kv_cache_config=kv_cache_config,
+        model_kwargs={"num_hidden_layers": num_hidden_layers},
+    )
 
-        # Generate texts from the prompts.
-        prompts_texts = [
-            "Hello, my name is",
-            "The president of the United States is",
-            "The capital of France is",
-            "The future of AI is",
-        ]
-        prompts = [tokenizer.encode(prompt) for prompt in prompts_texts]
-        del tokenizer
+    # Generate texts from the prompts.
+    prompts_texts = [
+        "Hello, my name is",
+        "The president of the United States is",
+        "The capital of France is",
+        "The future of AI is",
+    ]
+    prompts = [tokenizer.encode(prompt) for prompt in prompts_texts]
+    del tokenizer
 
-        sampling_params = SamplingParams(
-            temperature=0, return_generation_logits=True, max_tokens=1024
-        )
+    sampling_params = SamplingParams(temperature=0, return_generation_logits=True, max_tokens=1024)
 
-        def common_filter(filter_name: str) -> Callable[[str], bool]:
-            def filter_fn(name: str) -> bool:
-                return name.endswith(filter_name)
+    def common_filter(filter_name: str) -> Callable[[str], bool]:
+        def filter_fn(name: str) -> bool:
+            return name.endswith(filter_name)
 
-            return filter_fn
+        return filter_fn
 
-        # Generate filter_list from model weight keys by removing layer prefix
-        # e.g., "model.layers.41.input_layernorm.weight" -> "input_layernorm.weight"
-        layer_prefix_pattern = re.compile(r"^model\.layers\.\d+\.")
-        filter_set = set()
-        for name, _ in hf_model.all_weights[hf_model.device_id]:
-            suffix = layer_prefix_pattern.sub("", name)
-            filter_set.add(suffix)
-        filter_list = list(filter_set)
+    # Generate filter_list from model weight keys by removing layer prefix
+    # e.g., "model.layers.41.input_layernorm.weight" -> "input_layernorm.weight"
+    layer_prefix_pattern = re.compile(r"^model\.layers\.\d+\.")
+    filter_set = set()
+    for name, _ in hf_model.all_weights[hf_model.device_id]:
+        suffix = layer_prefix_pattern.sub("", name)
+        filter_set.add(suffix)
+    filter_list = list(filter_set)
 
-        for filter_name in filter_list:
-            weight_filter = common_filter(filter_name=filter_name)
-            ipc_handles = hf_model.get_weight_ipc_handles_serialized(
-                [0], weight_filter=weight_filter
-            )
-            llm._collective_rpc("update_weights", (ipc_handles,))
-            # Finalize the update weights
-            llm._collective_rpc("update_weights", (None,))
+    for filter_name in filter_list:
+        weight_filter = common_filter(filter_name=filter_name)
+        ipc_handles = hf_model.get_weight_ipc_handles_serialized([0], weight_filter=weight_filter)
+        llm._collective_rpc("update_weights", (ipc_handles,))
+        # Finalize the update weights
+        llm._collective_rpc("update_weights", (None,))
 
-        llm_logits, ref_logits = run_generate(llm, hf_model, prompts, sampling_params)
-        compare_logits(llm_logits, ref_logits)
+    llm_logits, ref_logits = run_generate(llm, hf_model, prompts, sampling_params)
+    compare_logits(llm_logits, ref_logits)
+
+
+@pytest.mark.parametrize(
+    "model_dir, fp8_model_dir",
+    [
+        ("Qwen3/Qwen3-8B", "Qwen3/Qwen3-8B-FP8"),
+        ("Qwen3/Qwen3-30B-A3B", "Qwen3/Qwen3-30B-A3B-FP8"),
+    ],
+)
+def test_llm_update_weights_with_quant_config(model_dir, fp8_model_dir):
+    model_dir = str(llm_models_root() / model_dir)
+    fp8_model_dir = str(llm_models_root() / fp8_model_dir)
+    num_hidden_layers = 1
+    hf_model = RefHFModelWithIPCHandles(fp8_model_dir, num_hidden_layers=num_hidden_layers)
+    tokenizer = AutoTokenizer.from_pretrained(fp8_model_dir)
+    kv_cache_config = KvCacheConfig(enable_block_reuse=True, free_gpu_memory_fraction=0.1)
+    llm = LLM(
+        model=model_dir,
+        ray_worker_extension_cls="tensorrt_llm.llmapi.rlhf_utils.WorkerExtension",
+        tensor_parallel_size=1,
+        load_format="dummy",
+        pipeline_parallel_size=1,
+        kv_cache_config=kv_cache_config,
+        model_kwargs={
+            "num_hidden_layers": num_hidden_layers,
+            "quantization_config": {
+                "activation_scheme": "dynamic",
+                "fmt": "e4m3",
+                "quant_method": "fp8",
+                "weight_block_size": [128, 128],
+            },
+        },
+    )
+
+    # Generate texts from the prompts.
+    prompts_texts = [
+        "Hello, my name is",
+        "The president of the United States is",
+        "The capital of France is",
+        "The future of AI is",
+    ]
+    prompts = [tokenizer.encode(prompt) for prompt in prompts_texts]
+    del tokenizer
+    sampling_params = SamplingParams(temperature=0, return_generation_logits=True, max_tokens=1024)
+
+    ipc_handles = hf_model.get_weight_ipc_handles_serialized([0])
+
+    llm._collective_rpc("update_weights", (ipc_handles,))
+    # Finalize the update weights
+    llm._collective_rpc("update_weights", (None,))
+
+    llm_logits, ref_logits = run_generate(llm, hf_model, prompts, sampling_params)
+    compare_logits(llm_logits, ref_logits)
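Usage note (not part of the patch): a minimal sketch of the behavior the llm_utils.py change enables. A quantization_config passed through model_kwargs now takes precedence over the quantization_config found in the checkpoint's config.json; the constructor arguments below mirror the new test_llm_update_weights_with_quant_config test, while the checkpoint path is illustrative only.

from tensorrt_llm import LLM  # assumed import path; the tests use the same LLM API

# Sketch: the dict under "quantization_config" overrides whatever the
# checkpoint's config.json declares, so a BF16 checkpoint directory can be
# brought up with FP8 block-scale settings. The test combines this with
# load_format="dummy" and streams real weights in later via update_weights
# IPC handles.
llm = LLM(
    model="/path/to/bf16/checkpoint",  # illustrative path
    load_format="dummy",
    model_kwargs={
        "quantization_config": {
            "activation_scheme": "dynamic",
            "fmt": "e4m3",
            "quant_method": "fp8",
            "weight_block_size": [128, 128],
        },
    },
)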