#! /usr/bin/env python3
# SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import datetime
import json
import logging
import re
from collections import defaultdict
from pathlib import Path

import numpy as np
import torch

from tensorrt_llm._utils import str_dtype_to_torch, torch_to_numpy
from tensorrt_llm.lora_manager import LoraManager
from tensorrt_llm.models.convert_utils import get_model_path, load_state_dict

log_format = "%(asctime)s %(name)s [%(levelname)s] %(message)s"
logging.basicConfig(format=log_format)
LOGGER = logging.getLogger(__name__)


def save_val(val, dir, key, tp_num=None, write_npy=False):
    ext = "npy" if write_npy else "bin"
    suffix = ext if tp_num is None else f"{tp_num}.{ext}"
    if write_npy:
        np.save(dir / f"model.{key}.{suffix}", val)
    else:
        val.tofile(dir / f"model.{key}.{suffix}")


def get_all_lora_weights(lora_weights):
    all_weights = defaultdict(lambda: defaultdict(dict))
    pattern = re.compile(
        r'(.*\.layers\.([0-9]+)\.(self_attn|mlp)\.([a-z_]+))\.(?:lora_(?:(A|B)\.weight|(magnitude)_vector)|weight_(m_wdecomp).weight).*'
    )
    moe_pattern = re.compile(
        r'(.*\.layers\.([0-9]+)\.(block_sparse_moe)\.((experts)\.([0-9]+)\.|)([a-zA-Z0-9_]+))\.(?:lora_(?:(A|B)\.weight|(magnitude)_vector)|weight_(m_wdecomp).weight).*'
    )
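    # Illustrative key names the two patterns above are meant to match
    # (example names, not taken from a specific checkpoint):
    #   base_model.model.model.layers.0.self_attn.q_proj.lora_A.weight
    #     -> layer 0, module "q_proj", lora_A ("in") weights
    #   base_model.model.model.layers.3.block_sparse_moe.experts.2.w1.lora_B.weight
    #     -> layer 3, module "w1", lora_B ("out") weights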
    for key, weights in lora_weights.items():
        m = pattern.match(key)
        m_moe = moe_pattern.match(key)
        if m:
            layer_idx = int(m.group(2))
            hf_module = m.group(4)
            inout = m.group(5)
            dora_magnitude = m.group(6) or m.group(7)

            if inout:
                inout = "in" if inout == "A" else "out"
                all_weights[layer_idx][hf_module][inout] = weights
            elif dora_magnitude:
                LOGGER.warning(
                    "Detected DoRA magnitude vector, make sure it was preprocessed and normalized using the proper base model weights"
                )
                all_weights[layer_idx][hf_module]["magnitude"] = weights.view(
                    -1)

        elif m_moe:
            layer_idx = int(m_moe.group(2))
            hf_module = m_moe.group(7)
            inout = m_moe.group(8)
            dora_magnitude = m_moe.group(9) or m_moe.group(10)

            if inout:
                inout = "in" if inout == "A" else "out"
                all_weights[layer_idx][hf_module][inout] = weights
            elif dora_magnitude:
                LOGGER.warning(
                    "Detected DoRA magnitude vector, make sure it was preprocessed and normalized using the proper base model weights"
                )
                all_weights[layer_idx][hf_module]["magnitude"] = weights.view(
                    -1)
        else:
            print(f"no match {key}")
            continue
    return all_weights


def preprocess_lora_weights(lora_model):
    # Swap weights of gate_up_proj
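    # Note (assumption): for checkpoints that fuse the gate and up projections
    # into a single gate_up_proj, the two row-halves of lora_B are swapped,
    # presumably because the fused ordering in the HF checkpoint differs from
    # the ordering TensorRT-LLM expects for the fused MLP projection.
    # The corresponding lora_A weights are left untouched.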
    for key, value in lora_model.items():
        if "gate_up_proj.lora_B.weight" in key:
            print("Swap {}".format(key))
            original_weights = value.contiguous().clone()
            half_split = original_weights.shape[0] // 2
            first_half = original_weights[:half_split, :]
            second_half = original_weights[half_split:, :]
            value = torch.cat((second_half, first_half), dim=0)
            lora_model[key] = value
    return lora_model


hf_modules_to_trtllm_modules = {
    "q_proj": "attn_q",
    "v_proj": "attn_v",
    "k_proj": "attn_k",
    "qkv_proj": "attn_qkv",
    "query_key_value": "attn_qkv",
    "o_proj": "attn_dense",
    "dense": "attn_dense",
    "gate_proj": "mlp_h_to_4h",
    "down_proj": "mlp_4h_to_h",
    "up_proj": "mlp_gate",
    "gate_up_proj": "mlp_h_to_4h",
    "c_fc": "mlp_h_to_4h",
    "c_proj": "mlp_4h_to_h",
    "w1": "moe_h_to_4h",
    "w2": "moe_4h_to_h",
    "w3": "moe_gate",
    "gate": "moe_router",
}  # LoRA target modules for LLaMA-family models
hf_modules_to_module_id = {
    k: LoraManager.LORA_MODULE_IDS[v]
    for k, v in hf_modules_to_trtllm_modules.items()
}
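# hf_modules_to_module_id maps each HF module name to the integer module ID the
# TensorRT-LLM runtime uses (LoraManager.LORA_MODULE_IDS); that ID becomes the
# first entry of each row written to lora_config below.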


def convert_hf_model(model_dir, dtype, out_dir):
    saved_dir = Path(out_dir)
    saved_dir.mkdir(parents=True, exist_ok=True)
    with open(f"{model_dir}/adapter_config.json", "r") as f:
        config = json.load(f)

    alpha = config.get("lora_alpha")
    use_rslora = config.get("use_rslora", False)

    lora_model = load_state_dict(get_model_path(model_dir, "adapter_model"))
    lora_model = preprocess_lora_weights(lora_model)
    all_weights = get_all_lora_weights(lora_model)
    converted_weights = []
    converted_config = []

    def derive_adapter_size(inout_weight: torch.Tensor) -> int:
        assert len(inout_weight.shape) == 2
        dim0, dim1 = inout_weight.shape
        # assume the hidden dim is the larger of the 2
        adapter_size = min(dim0, dim1)
        return adapter_size

    def derive_weights_scale(adapter_size: int, alpha: float,
                             use_rslora: bool) -> float:
        if use_rslora:
            return alpha / np.sqrt(adapter_size)
        return alpha / adapter_size

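    # Scaling note: with vanilla LoRA the effective update is
    # W + (alpha / r) * B @ A, whereas rsLoRA (rank-stabilized LoRA) scales by
    # alpha / sqrt(r) so the update magnitude stays stable as the rank r grows.
    # derive_weights_scale mirrors this; the scale is folded into the "out" (B)
    # weights in the loop below.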
    for layer_idx, layer_weights in all_weights.items():
        for hf_module, module_weights in layer_weights.items():
            in_weights = module_weights['in']
            out_weights = module_weights['out']
            magnitude = module_weights.get("magnitude", None)
            is_dora = magnitude is not None

            processed_weights = []

            assert len(in_weights.shape) == 2
            assert len(out_weights.shape) == 2
            assert not is_dora or len(magnitude.shape) == 1

            adapter_size = derive_adapter_size(in_weights)
            assert adapter_size == derive_adapter_size(
                out_weights), "adapter size of A does not match adapter size of B"
            scale = derive_weights_scale(adapter_size, alpha, use_rslora)

            for w, inout in ((in_weights, "in"), (out_weights, "out")):
                dim0 = w.shape[0]
                dim1 = w.shape[1]
                # in_weights should have shape [adapter_size, hidden]
                if dim1 < dim0 and inout == "in":
                    w = w.transpose(1, 0)
                # out_weights should have shape [hidden, adapter_size]
                elif dim0 < dim1 and inout == "out":
                    w = w.transpose(1, 0)
                if inout == "out":
                    w = w * scale
                w = w.contiguous().flatten().to(dtype=str_dtype_to_torch(dtype))
                processed_weights.append(w)

            if is_dora:
                processed_weights.append(magnitude.contiguous().flatten().to(
                    dtype=str_dtype_to_torch(dtype)))

            processed_weights = torch.concatenate(processed_weights).flatten()
            converted_weights.append(processed_weights)
            converted_config.append([
                hf_modules_to_module_id[hf_module], layer_idx, adapter_size,
                1 if is_dora else 0
            ])
    max_row_size = 0
    for t in converted_weights:
        max_row_size = max(max_row_size, t.shape[0])
    for i in range(len(converted_weights)):
        converted_weights[i] = torch.nn.functional.pad(
            converted_weights[i],
            (0, max_row_size - converted_weights[i].shape[0])).unsqueeze(0)
    converted_weights = torch_to_numpy(
        torch.concatenate(
            converted_weights,
            dim=0).unsqueeze(0).to(dtype=str_dtype_to_torch(dtype)).cpu())
    converted_config = torch.tensor(converted_config,
                                    dtype=torch.int32,
                                    device='cpu').unsqueeze(0).numpy()

    save_val(converted_weights,
             saved_dir,
             "lora_weights",
             tp_num=None,
             write_npy=True)
    save_val(converted_config,
             saved_dir,
             "lora_config",
             tp_num=None,
             write_npy=True)
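    # Output artifacts (shapes follow from the code above):
    #   model.lora_weights.npy : [1, num_modules, max_row_size], per-module
    #       flattened in/out (and optional DoRA magnitude) weights, zero-padded
    #       to max_row_size
    #   model.lora_config.npy  : [1, num_modules, 4], rows of
    #       [module_id, layer_idx, adapter_size, is_dora]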


def main(args):
    start_time = datetime.datetime.now()
    convert_hf_model(args.in_file, args.storage_type, args.out_dir)

    LOGGER.info("Spent %s (h:m:s) to convert the LoRA model",
                datetime.datetime.now() - start_time)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--out-dir',
        '-o',
        type=Path,
        help='path to the output directory for the converted LoRA weight and config files (.npy format)',
        required=True)
    parser.add_argument('--in-file',
                        '-i',
                        type=Path,
                        help='path to the input HF LoRA checkpoint directory',
                        required=True)
    parser.add_argument("--verbose",
                        action="store_true",
                        help="Provide verbose messages")
    parser.add_argument("--storage-type",
                        type=str,
                        default="float16",
                        choices=["float32", "float16", "bfloat16"])
    args = parser.parse_args()

    LOGGER.setLevel(logging.DEBUG if args.verbose else logging.INFO)

    print("\n=============== Argument ===============")
    for key in vars(args):
        print(f"{key}: {vars(args)[key]}")
    print("========================================")

    main(args)
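
# Example invocation (all paths illustrative; point --in-file at a HF PEFT LoRA
# checkpoint directory containing adapter_config.json and adapter_model.*):
#   python3 <this_script> --in-file /path/to/hf_lora_checkpoint \
#       --out-dir /path/to/converted_lora --storage-type float16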