12 KiB
TensorRT-LLM Model Weights Loader
Overview
The weights loader is designed for easily converting and loading external weight checkpoints into TensorRT-LLM models.
Workflow
Weight checkpoints can be generated from all sources, and may have different naming and data layouts compared to TRT-LLM's requirements. E.g.:
# HuggingFace LLaMA checkpoints
{
"model.embed_tokens.weight": torch.Tensor([vocab_size, hidden_size])
"model.layers.0.input_layernorm.weight": torch.Tensor([hidden_size]),
"model.layers.0.mlp.down_proj.weight": torch.Tensor([hidden_size, inter_size]),
"model.layers.0.mlp.gate_proj.weight": torch.Tensor([inter_size, hidden_size]),
"model.layers.0.mlp.up_proj.weight": torch.Tensor([inter_size, hidden_size]),
"model.layers.0.post_attention_layernorm.weight": torch.Tensor([hidden_size]),
"model.layers.0.self_attn.q_proj.weight": torch.Tensor([hidden_size, hidden_size]),
"model.layers.0.self_attn.k_proj.weight": torch.Tensor([hidden_size, hidden_size]),
"model.layers.0.self_attn.v_proj.weight": torch.Tensor([hidden_size, hidden_size]),
"model.layers.0.self_attn.o_proj.weight": torch.Tensor([hidden_size, hidden_size]),
...,
}
# TensorRT-LLM expected weights
{
"transformer.vocab_embedding.weight": torch.Tensor([vocab_size, hidden_size])
"transformer.layers.0.input_layernorm.weight": torch.Tensor([hidden_size]),
"transformer.layers.0.mlp.down_proj.weight": torch.Tensor([hidden_size, inter_size]),
"transformer.layers.0.mlp.gate_proj.weight": torch.Tensor([inter_size, hidden_size]),
"transformer.layers.0.mlp.up_proj.weight": torch.Tensor([inter_size, hidden_size]),
"transformer.layers.0.post_layernorm.weight": torch.Tensor([hidden_size]),
"transformer.layers.0.attention.qkv.weight": torch.Tensor([hidden_size * 3, hidden_size]), # Different layout
"transformer.layers.0.attention.dense.weight": torch.Tensor([hidden_size, hidden_size]),
...,
}
Conversion means converting the dictionary of {external_keys:external_weights} into {tllm_keys:tllm_weights}, it includes changing the naming logic and data layouts, and is contains of the following parts:
- Translate a TRT-LLM parameter name into external-format name(s).
- Loading tensor slice(s) according to the translated names.
- Postprocess the tensor(s) into target layout.
Translator
TRT-LLM parameter names are translated in units of sections divided by dots. E.g.:
| TensorRT-LLM key | transformer |
. | layers |
. | 0 |
. | attention |
. | dense |
. | weight |
|---|---|---|---|---|---|---|---|---|---|---|---|
| Translated external key | model |
. | layers |
. | 0 |
. | self_attn |
. | o_proj |
. | weight |
The mapping between TRT-LLM keywords and HF keywords are described in tllm_to_externel_key_dict of ModelWeightsLoader class object.
If any of the mappings has one-to-multiple corresponding, the translated key will get multiplied accordingly. E.g.:
| TensorRT-LLM key and related keyword mapping | Translated external keys |
|---|---|
transformer.layers.0.attention.qkv.weight{"qkv":[q_proj, k_proj, v_proj]} |
model.layers.0.self_attn.q_proj.weightsmodel.layers.0.self_attn.k_proj.weightsmodel.layers.0.self_attn.v_proj.weights |
transformer.layers.0.mlp.fc.weight{"weight":[qweight, qzeros, scales]} |
model.layers.0.mlp.gate_proj.qweightmodel.layers.0.mlp.gate_proj.qzerosmodel.layers.0.mlp.gate_proj.scales |
The default tllm_to_externel_key_dict is based on HF LLaMA as:
class ModelWeightsLoader:
def __init__(self, model_dir, customized_key_dict: dict = {}) -> None:
...
self.tllm_to_externel_key_dict = {
"transformer": "model",
"vocab_embedding": "embed_tokens",
"lm_head": "lm_head",
"ln_f": "norm",
"attention": "self_attn",
"qkv": ["q_proj", "k_proj", "v_proj"],
"dense": "o_proj",
"gate": "up_proj",
"proj": "down_proj",
"fc": "gate_proj",
"input_layernorm": "input_layernorm",
"post_layernorm": "post_attention_layernorm",
}
self.tllm_to_externel_key_dict.update(customized_key_dict)
...
It can be updated through passing customized_key_dict when initializing ModelWeightsLoader.
The dictionary will also get updated according to the layer classes. When iterating over parameters,
if the layer class has attribute tllm_to_externel_key_dict, for keywords exist both in the default one and the layer-specified one,
the weight loader will translate according to the layer attribute with higher priority.
This can enable the support for different quantization precisions automatically.
Loading function
The loading function can load an arbitrary tensor slice according to its key, tp_size, tp_dim and tp_rank.
The template for loading function is as following.
def load_tensor(self, key, tp_size, tp_dim, tp_rank):
# Retrieve file pointer index
if key in self.shard_map:
ptr_idx = self.shard_map[key]
else:
return None
# Load tensor from the corresponding shard
if self.format == ModelWeightsFormat.SAFETENSORS:
tensor = self.shards[ptr_idx].get_slice(key)
tensor_shape = tensor.get_shape()
else:
...
# Shard and return a tensor slice
slice_shape = ...
return tensor[slice_shape]
When initializing the ModelWeightsLoader object, the file format will be derived from model_dir through detect_format. The following formats are supported for now:
- Directory contains or file named
*.safetensors(Recommended, has better performance) - Directory contains or file named
*.bin - Directory contains or file named
*.pth
To support other formats or in-memory loaded models, the format need to be claimed in ModelWeightsFormat, detect_format(), preload() and load_tensor().
Postprocessing functions
After translation and loading, a TRT-LLM key will become a tensor or a list of tensors, which is the input of postprocessing functions.
Operations including QKV concatenating, MoE weight stacking and weight-only quantization can be handled here.
The template of postprocessing function is:
# Example for 1-1 weights mapping
class CustomizedModuleA(Module):
def __init__(...):
super().__init__(...)
...
self.tp_dim = 0 # Need to set or inherit from parent class
def postprocess(self, tllm_key, weights, **kwargs):
weights = proc(weights)
return {tllm_key: weights}
# Example for multiple-multiple weights mapping
class CustomizedModuleB(Module):
def __init__(...):
super().__init__(...)
...
self.tp_dim = 0 # Need to set or inherit from parent class
# The default value of "weights" in tllm_to_externel_key_dict will be override
self.tllm_to_externel_key_dict = {"weight": ["qweight", "qzeros", "scales"]}
def postprocess(self, tllm_key, weights, **kwargs):
# Skipped the postprocess of zeros and weights_scaling_factor
# They are loaded in the postprocess of weight
config = kwargs.get("config", None) # Passed through kwargs by default
if not tllm_key.endswith("weight"):
return {}
# The order in weights is defined in tllm_to_externel_key_dict
qweight, qzeros, scales = weights
proccessed_weight, proccessed_zeros = proc(qweight, qzeros, config.num_heads)
return {
tllm_key: proccessed_weight,
tllm_key.replace("weight", "zeros"): proccessed_zeros,
tllm_key.replace("weight", "weights_scaling_factor"): scales,
}
Examples
The ModelWeightsLoader class can support different models with the following levels:
Natively supported models
For models with native support, users can call the default weight loader without any other operations.
# Using the model weights loader for LLaMA
from tensorrt_llm.models.model_weights_loader import ModelWeightsLoader
loader = ModelWeightsLoader(external_checkpoint_dir)
loader.generate_tllm_weights(trtllm_model)
For calibration-free quantization precisions, passing a properly quantized trtllm_model will let the weight loader load at the given precision accordingly. The configurations will be read from trtllm_model.config automatically. For now, LLaMA family models using the default tllm_to_externel_key_dict is supported natively.
Models with customized key names
For models with different naming logic, users can still call the default weight loader with customized_key_dict specified.
# Using the model weights loader for the LLM part of LLaVA
from tensorrt_llm.models.model_weights_loader import ModelWeightsLoader
llava_dict = {
"transformer": "language_model.model",
"lm_head": "language_model.lm_head"
}
loader = ModelWeightsLoader(external_checkpoint_dir, llava_dict)
loader.generate_tllm_weights(trtllm_model)
Users need to specify the different part from the default tllm_to_externel_key_dict. The loader still have support across different precisions.
The support for LLaVA and Exaone is in LLaMAForCausalLM.from_hugging_face() of model.py, and can also be taken as examples.
Models with customized weight layout
For models with different weight layout, users can write the conversion loop explicitly and do customized operations.
# Using the model weights loader for BLOOM
from tensorrt_llm.models.model_weights_loader import ModelWeightsLoader
bloom_dict = {
"transformer": "",
"layers": "h",
"ln_f": "ln_f",
"lm_head": "word_embeddings",
"ln_embed": "word_embeddings_layernorm",
"vocab_embedding": "word_embeddings",
"attention": "self_attention",
"qkv": "query_key_value",
"dense": "dense",
"fc": "dense_h_to_4h",
"proj": "dense_4h_to_h",
"post_layernorm": "post_attention_layernorm",
}
loader = ModelWeightsLoader(external_checkpoint_dir, bloom_dict)
# See ModelWeightsLoader.generate_tllm_weights()
loader.update_key_mapping(trtllm_model)
tllm_weights = {}
for tllm_key, _ in tqdm(trtllm_model.named_parameters()):
if tllm_key.endswith("qkv"):
# Passing the callable handle
tllm_weights.update(loader.load(tllm_key, preprocess=customized_preprocess))
else:
tllm_weights.update(loader.load(tllm_key))
loader.fill(tllm_weights)
This will apply preprocess after load_tensor() and before postprocess, and demonstrates how to convert the loaded shard into default HF layout. The loader still have support for precisions quantized from FP16/BF16 (e.g. INT8-wo/INT4-wo), the other precisions may require special operations, and can be addressed inside the preprocess function.
The support for Qwen-1 is in QWenForCausalLM.from_hugging_face() of model.py, and can also be taken as example.
Fully customized
If the model weights loader cannot satisfy the requirements, users can write the conversion loop totally on their own.
tllm_weights = {}
for tllm_key, param in tqdm(trtllm_model.named_parameters()):
# Load from external checkpoints
# The load_tensor() function can also be called here
tensor = ...
# Convert tensor and set the values according to the config
if trtllm_model.config.quantization.quant_algo == xxx:
...
else:
...
param.value = tensor
In this mode, every precision require user's own support.
Trouble shooting
The weights loader is enabled for LLaMA family models and Qwen models by default with TensorRT flow only.
If users are encountered with failure caused by ModelWeightsLoader, a workaround is passing environmental variable TRTLLM_DISABLE_UNIFIED_CONVERTER=1 to disable the model weights loader and fallback to the legacy path.
This workaround will be removed in future version after the LLaMA/Qwen weights conversion is stable.