TensorRT-LLM/examples/hf_lora_convert.py
#! /usr/bin/env python3
# SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
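"""Convert a Hugging Face (PEFT) LoRA/DoRA adapter into TensorRT-LLM's .npy LoRA format.

The script reads ``adapter_config.json`` and the ``adapter_model`` weights from the
input directory and writes ``model.lora_weights.npy`` and ``model.lora_config.npy``
to the output directory.

Example invocation (paths are placeholders):

    python hf_lora_convert.py -i /path/to/hf_lora_dir -o /path/to/out_dir --storage-type float16
"""
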
import argparse
import datetime
import json
import logging
import re
from collections import defaultdict
from pathlib import Path

import numpy as np
import torch

from tensorrt_llm._utils import str_dtype_to_torch, torch_to_numpy
from tensorrt_llm.lora_manager import LoraManager
from tensorrt_llm.models.convert_utils import get_model_path, load_state_dict

log_format = "%(asctime)s %(name)s [%(levelname)s] %(message)s"
logging.basicConfig(format=log_format)
LOGGER = logging.getLogger(__name__)

def save_val(val, dir, key, tp_num=None, write_npy=False):
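    """Write ``val`` to ``dir / "model.<key>[.<tp_num>].npy"`` (or ``.bin`` via ``tofile``)."""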
ext = "npy" if write_npy else "bin"
suffix = ext if tp_num is None else f"{tp_num}.{ext}"
if write_npy:
np.save(dir / f"model.{key}.{suffix}", val)
else:
val.tofile(dir / f"model.{key}.{suffix}")
def get_all_lora_weights(lora_weights):
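    """Group LoRA tensors by layer index and HF module name.

    Returns a nested dict ``{layer_idx: {hf_module: {"in": A, "out": B[, "magnitude": m]}}}``,
    where "in"/"out" are the ``lora_A``/``lora_B`` matrices and "magnitude" is the
    optional DoRA magnitude vector. Keys matching neither the attention/MLP pattern
    nor the MoE pattern are reported and skipped.
    """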
    all_weights = defaultdict(lambda: defaultdict(dict))
    pattern = re.compile(
        r'(.*\.layers\.([0-9]+)\.(self_attn|mlp)\.([a-z_]+))\.(?:lora_(?:(A|B)\.weight|(magnitude)_vector)|weight_(m_wdecomp).weight).*'
    )
    moe_pattern = re.compile(
        r'(.*\.layers\.([0-9]+)\.(block_sparse_moe)\.((experts)\.([0-9]+)\.|)([a-zA-Z0-9_]+))\.(?:lora_(?:(A|B)\.weight|(magnitude)_vector)|weight_(m_wdecomp).weight).*'
    )
    for key, weights in lora_weights.items():
        m = pattern.match(key)
        m_moe = moe_pattern.match(key)
        if m:
            layer_idx = int(m.group(2))
            hf_module = m.group(4)
            inout = m.group(5)
            dora_magnitude = m.group(6) or m.group(7)
            if inout:
                inout = "in" if inout == "A" else "out"
                all_weights[layer_idx][hf_module][inout] = weights
            elif dora_magnitude:
                LOGGER.warning(
                    "Detected DoRA magnitude vector, make sure it was preprocessed and normalized using the proper base model weights"
                )
                all_weights[layer_idx][hf_module]["magnitude"] = weights.view(
                    -1)
        elif m_moe:
            layer_idx = int(m_moe.group(2))
            hf_module = m_moe.group(7)
            inout = m_moe.group(8)
            dora_magnitude = m_moe.group(9) or m_moe.group(10)
            if inout:
                inout = "in" if inout == "A" else "out"
                all_weights[layer_idx][hf_module][inout] = weights
            elif dora_magnitude:
                LOGGER.warning(
                    "Detected DoRA magnitude vector, make sure it was preprocessed and normalized using the proper base model weights"
                )
                all_weights[layer_idx][hf_module]["magnitude"] = weights.view(
                    -1)
        else:
            print(f"no match {key}")
            continue
    return all_weights

def preprocess_lora_weights(lora_model):
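    """Swap the two output halves of ``gate_up_proj`` lora_B weights.

    Assumes the fused gate/up projection in the HF checkpoint stores its two halves
    in the opposite order from what TensorRT-LLM expects, so the first and second
    halves of the lora_B output dimension are exchanged.
    """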
    # Swap weights of gate_up_proj
    for key, value in lora_model.items():
        if "gate_up_proj.lora_B.weight" in key:
            print("Swap {}".format(key))
            original_weights = value.contiguous().clone()
            half_split = original_weights.shape[0] // 2
            first_half = original_weights[:half_split, :]
            second_half = original_weights[half_split:, :]
            value = torch.cat((second_half, first_half), dim=0)
            lora_model[key] = value
    return lora_model

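# Mapping from HF (PEFT) module names to TensorRT-LLM LoRA module names; the
# TRT-LLM names are resolved to integer module IDs via LoraManager.LORA_MODULE_IDS
# below.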
hf_modules_to_trtllm_modules = {
    "q_proj": "attn_q",
    "v_proj": "attn_v",
    "k_proj": "attn_k",
    "qkv_proj": "attn_qkv",
    "query_key_value": "attn_qkv",
    "o_proj": "attn_dense",
    "dense": "attn_dense",
    "gate_proj": "mlp_h_to_4h",
    "down_proj": "mlp_4h_to_h",
    "up_proj": "mlp_gate",
    "gate_up_proj": "mlp_h_to_4h",
    "c_fc": "mlp_h_to_4h",
    "c_proj": "mlp_4h_to_h",
    "w1": "moe_h_to_4h",
    "w2": "moe_4h_to_h",
    "w3": "moe_gate",
    "gate": "moe_router",
}  # lora modules on llama

hf_modules_to_module_id = {
    k: LoraManager.LORA_MODULE_IDS[v]
    for k, v in hf_modules_to_trtllm_modules.items()
}

def convert_hf_model(model_dir, dtype, out_dir):
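    """Convert the HF LoRA checkpoint in ``model_dir`` and write the result to ``out_dir``.

    Produces two arrays (saved at the end of this function):
      * ``model.lora_weights.npy``: shape ``[1, num_modules, max_weights_size]``,
        flattened in/out (and optional DoRA magnitude) weights, zero-padded per row.
      * ``model.lora_config.npy``: shape ``[1, num_modules, 4]``, int32 rows of
        ``[module_id, layer_idx, adapter_size, is_dora]``.
    """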
    saved_dir = Path(out_dir)
    saved_dir.mkdir(parents=True, exist_ok=True)

    with open(f"{model_dir}/adapter_config.json", "r") as f:
        config = json.load(f)
    alpha = config.get("lora_alpha")
    use_rslora = config.get("use_rslora", False)

    lora_model = load_state_dict(get_model_path(model_dir, "adapter_model"))
    lora_model = preprocess_lora_weights(lora_model)
    all_weights = get_all_lora_weights(lora_model)

    converted_weights = []
    converted_config = []
    def derive_adapter_size(inout_weight: torch.Tensor) -> int:
        assert len(inout_weight.shape) == 2
        dim0, dim1 = inout_weight.shape
        # assume the hidden dim is the larger of the 2
        adapter_size = min(dim0, dim1)
        return adapter_size

    def derive_weights_scale(adapter_size: int, alpha: float,
                             use_rslora: bool) -> float:
        if use_rslora:
            return alpha / np.sqrt(adapter_size)
        return alpha / adapter_size
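
    # For every (layer, module) pair: flatten A ("in") and B ("out"), fold the LoRA
    # scale alpha / r (or alpha / sqrt(r) for rsLoRA) into the "out" weights, append
    # the DoRA magnitude vector when present, and record
    # [module_id, layer_idx, adapter_size, is_dora] for the config array.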
    for layer_idx, layer_weights in all_weights.items():
        for hf_module, module_weights in layer_weights.items():
            in_weights = module_weights['in']
            out_weights = module_weights['out']
            magnitude = module_weights.get("magnitude", None)
            is_dora = magnitude is not None
            processed_weights = []

            assert len(in_weights.shape) == 2
            assert len(out_weights.shape) == 2
            assert not is_dora or len(magnitude.shape) == 1

            adapter_size = derive_adapter_size(in_weights)
            assert adapter_size == derive_adapter_size(
                out_weights), "adapter size of A does not match adapter size of B"
            scale = derive_weights_scale(adapter_size, alpha, use_rslora)

            for w, inout in ((in_weights, "in"), (out_weights, "out")):
                dim0 = w.shape[0]
                dim1 = w.shape[1]
                # in_weights should have shape [adapter_size, hidden]
                if dim1 < dim0 and inout == "in":
                    w = w.transpose(1, 0)
                # out_weights should have shape [hidden, adapter_size]
                elif dim0 < dim1 and inout == "out":
                    w = w.transpose(1, 0)
                if inout == "out":
                    w = w * scale
                w = w.contiguous().flatten().to(dtype=str_dtype_to_torch(dtype))
                processed_weights.append(w)

            if is_dora:
                processed_weights.append(magnitude.contiguous().flatten().to(
                    dtype=str_dtype_to_torch(dtype)))

            processed_weights = torch.concatenate(processed_weights).flatten()
            converted_weights.append(processed_weights)
            converted_config.append([
                hf_modules_to_module_id[hf_module], layer_idx, adapter_size,
                1 if is_dora else 0
            ])

    max_row_size = 0
    for t in converted_weights:
        max_row_size = max(max_row_size, t.shape[0])

    for i in range(len(converted_weights)):
        converted_weights[i] = torch.nn.functional.pad(
            converted_weights[i],
            (0, max_row_size - converted_weights[i].shape[0])).unsqueeze(0)

    converted_weights = torch_to_numpy(
        torch.concatenate(
            converted_weights,
            dim=0).unsqueeze(0).to(dtype=str_dtype_to_torch(dtype)).cpu())
    converted_config = torch.tensor(converted_config,
                                    dtype=torch.int32,
                                    device='cpu').unsqueeze(0).numpy()
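
    # Final shapes: converted_weights [1, num_modules, max_row_size],
    # converted_config [1, num_modules, 4] (int32). Both are saved as .npy files.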
    save_val(converted_weights,
             saved_dir,
             "lora_weights",
             tp_num=None,
             write_npy=True)
    save_val(converted_config,
             saved_dir,
             "lora_config",
             tp_num=None,
             write_npy=True)

def main(args):
    start_time = datetime.datetime.now()
    convert_hf_model(args.in_file, args.storage_type, args.out_dir)
    LOGGER.info("Spent %s (h:m:s) to convert the LoRA model",
                datetime.datetime.now() - start_time)

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--out-dir',
        '-o',
        type=Path,
        help='directory where the converted LoRA weights and config are written in .npy format',
        required=True)
    parser.add_argument('--in-file',
                        '-i',
                        type=Path,
                        help='path to the input HF LoRA adapter directory',
                        required=True)
parser.add_argument("--verbose",
action="store_true",
help="Provide verbose messages")
parser.add_argument("--storage-type",
type=str,
default="float16",
choices=["float32", "float16", "bfloat16"])
args = parser.parse_args()
LOGGER.setLevel(logging.DEBUG if args.verbose else logging.INFO)
print("\n=============== Argument ===============")
for key in vars(args):
print(f"{key}: {vars(args)[key]}")
print("========================================")
main(args)