TensorRT-LLM/tensorrt_llm/runtime/lora_manager.py

import json
import os
from pathlib import Path
import numpy as np
import torch
from .._utils import (fromfile, numpy_to_torch, str_dtype_to_np,
str_dtype_to_torch, torch_to_numpy)
class LoraConfig(object):
LORA_MODULE_IDS = {
"attn_qkv": 0,
"attn_q": 1,
"attn_k": 2,
"attn_v": 3,
"attn_dense": 4,
"mlp_h_to_4h": 5,
"mlp_4h_to_h": 6,
"mlp_gate": 7,
}
def __init__(self,
hf_lora_dir: str = None,
adapter_config: dict = {},
tokenizer_config: dict = {},
lora_target_modules: list = [],
is_valid: bool = False,
has_tokenizer: bool = False,
lm_head_weight=None,
embedding_weight=None,
hf_modules_to_trtllm_modules: dict = {},
trtllm_modules_to_hf_modules: dict = {}):
self.hf_lora_dir = hf_lora_dir
self.adapter_config = adapter_config
self.tokenizer_config = tokenizer_config
self.hf_lora_target_modules = lora_target_modules
self.lora_target_modules = [
hf_modules_to_trtllm_modules[m] for m in lora_target_modules
]
self.is_valid = is_valid
self.has_tokenizer = has_tokenizer
self.lm_head_weight = lm_head_weight
self.embedding_weight = embedding_weight
self.vocab_size, self.hidden_size = self.lm_head_weight.shape if self.lm_head_weight is not None else (
0, 0)
self.hf_modules_to_trtllm_modules = hf_modules_to_trtllm_modules
self.trtllm_modules_to_hf_modules = trtllm_modules_to_hf_modules
@classmethod
def from_hf(cls, hf_lora_dir, hf_modules_to_trtllm_modules,
trtllm_modules_to_hf_modules):
        lora_target_modules = []
        adapter_config = None
        tokenizer_config = None
        is_valid = True
        has_tokenizer = True
if os.path.exists(f"{hf_lora_dir}/adapter_config.json"):
with open(f"{hf_lora_dir}/adapter_config.json") as f:
adapter_config = json.load(f)
lora_target_modules = adapter_config["target_modules"]
else:
is_valid = False
if os.path.exists(f"{hf_lora_dir}/tokenizer_config.json"):
with open(f"{hf_lora_dir}/tokenizer_config.json") as f:
tokenizer_config = json.load(f)
else:
has_tokenizer = False
lm_head_weight = None
embedding_weight = None
if os.path.exists(f"{hf_lora_dir}/adapter_model.bin"):
lora_weight = torch.load(f"{hf_lora_dir}/adapter_model.bin")
            if adapter_config is not None and adapter_config.get(
                    "modules_to_save") is not None:
if "lm_head" in adapter_config["modules_to_save"]:
lm_head_weight = lora_weight[
"base_model.model.lm_head.weight"]
if "embed_tokens" in adapter_config["modules_to_save"]:
embedding_weight = lora_weight[
"base_model.model.model.embed_tokens.weight"]
return cls(hf_lora_dir, adapter_config, tokenizer_config,
lora_target_modules, is_valid, has_tokenizer, lm_head_weight,
embedding_weight, hf_modules_to_trtllm_modules,
trtllm_modules_to_hf_modules)
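    # A minimal usage sketch for from_hf (illustrative only: the directory
    # path and the module-name maps below are hypothetical placeholders):
    #
    #   config = LoraConfig.from_hf(
    #       "/path/to/hf_lora_dir",
    #       hf_modules_to_trtllm_modules={"q_proj": "attn_q", "v_proj": "attn_v"},
    #       trtllm_modules_to_hf_modules={"attn_q": "q_proj", "attn_v": "v_proj"})
    #   if config.is_valid:
    #       print(config.lora_target_modules)  # TRT-LLM names, e.g. ["attn_q", "attn_v"]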
class LoraManager(object):
def __init__(self):
self._lora_uid_to_key = {}
'''
_lora_uid_to_low_ranks: dict[str -> List[dict[str -> int]]]
{
uid:
[
{
lora_module: int
}, # layer_0_rank,
{
lora_module: int
}, # layer_1_rank,
...
]
}
_lora_weights_pointers_list:
[
{
uid:
{
lora_module_1: [t_in, t_out]
lora_module_2: [t_in, t_out]
}
}, # layer_0
{
}, # layer_1
...
]
'''
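        # Hypothetical example of the layout documented above, for a single
        # adapter (uid "0", rank 8) on a 2-layer model; the pointers are
        # placeholders for t_in.data_ptr() / t_out.data_ptr():
        #   _lora_uid_to_low_ranks = {
        #       "0": [{"attn_qkv": 8},    # layer 0
        #             {"attn_qkv": 8}],   # layer 1
        #   }
        #   _lora_weights_pointers_list = [
        #       {"0": {"attn_qkv": [<t_in ptr>, <t_out ptr>]}},  # layer 0
        #       {"0": {"attn_qkv": [<t_in ptr>, <t_out ptr>]}},  # layer 1
        #   ]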
self._lora_uid_to_low_ranks = {}
self._lora_weights = []
self._lora_weights_pointers_list = []
self._lora_cpp_weights = {}
self._lora_weight_config = {}
def load_from_ckpt(self, model_dir, model_config, runtime_mapping,
ckpt_source):
if ckpt_source == "hf":
self.load_from_hf(model_dir, model_config, runtime_mapping)
elif ckpt_source == "nemo":
self.load_from_nemo(model_dir, model_config, runtime_mapping)
else:
assert False, f"LoraManager does not support source {ckpt_source}"
def load_from_nemo(self, model_dirs, model_config, runtime_mapping):
'''
        Load LoRA modules from a NeMo checkpoint; this could be moved to the client side.
'''
self._model_config = model_config
model_dir = Path(model_dirs[0])
with open(model_dir / "lora_weights.json", 'r') as f:
config = json.load(f)
lora_config = config['lora_config']
precision = config.get('precision', 'float16')
for key in lora_config['lora_kqv_adapter']:
self._lora_uid_to_key[lora_config['lora_kqv_adapter'][key]
['key']] = key
lora_target_modules = model_config.lora_target_modules
dtype = model_config.dtype
for layer_idx in range(model_config.num_layers):
self._lora_weights_pointers_list.append({})
for uid, key in self._lora_uid_to_key.items():
self._lora_weights_pointers_list[layer_idx].update({uid: {}})
low_rank = int(lora_config['lora_kqv_adapter'][key]['low_rank'])
if uid not in self._lora_cpp_weights:
self._lora_cpp_weights[uid] = []
if uid not in self._lora_weight_config:
self._lora_weight_config[uid] = []
for lora_module in lora_target_modules:
if uid not in self._lora_uid_to_low_ranks:
self._lora_uid_to_low_ranks.update(
{uid: [{} for _ in range(model_config.num_layers)]})
self._lora_uid_to_low_ranks[uid][layer_idx][
lora_module] = low_rank
prefix = f"model.model.language_model.encoder.layers.{layer_idx}.self_attention.adapter_layer.lora_kqv_adapter.{key}"
t_in = numpy_to_torch(
np.ascontiguousarray(
fromfile(
model_dir, f'{prefix}.linear_in.weight.bin',
[model_config.hidden_size, low_rank],
str_dtype_to_np(precision)).transpose(
1,
0))).cuda() # t_in: [low_rank, hidden_size]
t_out = numpy_to_torch(
np.ascontiguousarray(
fromfile(model_dir,
f'{prefix}.linear_out.weight.bin',
[low_rank, model_config.hidden_size * 3],
str_dtype_to_np(precision)).transpose(
1, 0))).cuda(
                    )  # t_out: [hidden_size * 3, low_rank]
t_in = t_in.float().to(str_dtype_to_torch(dtype))
t_out = t_out.float().to(str_dtype_to_torch(dtype))
self._lora_weights_pointers_list[layer_idx][uid].update({
lora_module: [
t_in.contiguous().data_ptr(),
t_out.contiguous().data_ptr()
]
})
self._lora_weights.append(t_in)
self._lora_weights.append(t_out)
self._lora_cpp_weights[uid].append(
torch.concatenate([t_in.flatten(),
t_out.flatten()]))
self._lora_weight_config[uid].append(
np.array([
LoraConfig.LORA_MODULE_IDS[lora_module], layer_idx,
int(low_rank)
],
dtype=np.int32))
if "-1" not in self._lora_uid_to_low_ranks:
self._lora_uid_to_low_ranks.update(
{"-1": [{} for _ in range(model_config.num_layers)]})
self._lora_uid_to_low_ranks["-1"][layer_idx][lora_module] = 0
def load_from_hf(self, model_dirs, model_config, runtime_mapping):
'''
lora config of https://huggingface.co/hfl/chinese-alpaca-2-lora-7b
{
"base_model_name_or_path": "/Llama-2-7b-hf",
"bias": "none",
"enable_lora": null,
"fan_in_fan_out": false,
"inference_mode": true,
"lora_alpha": 128.0,
"lora_dropout": 0.05,
"merge_weights": false,
"modules_to_save": [
"embed_tokens",
"lm_head"
],
"peft_type": "LORA",
"r": 64,
"target_modules": [
"q_proj",
"v_proj",
"k_proj",
"o_proj",
"gate_proj",
"down_proj",
"up_proj"
],
"task_type": "CAUSAL_LM"
}
keys in adapter_model.bin:
base_model.model.model.layers.0.self_attn.q_proj.lora_A.weight torch.Size([64, 4096])
base_model.model.model.layers.0.self_attn.q_proj.lora_B.weight torch.Size([4096, 64])
base_model.model.model.layers.0.self_attn.k_proj.lora_A.weight torch.Size([64, 4096])
base_model.model.model.layers.0.self_attn.k_proj.lora_B.weight torch.Size([4096, 64])
base_model.model.model.layers.0.self_attn.v_proj.lora_A.weight torch.Size([64, 4096])
base_model.model.model.layers.0.self_attn.v_proj.lora_B.weight torch.Size([4096, 64])
base_model.model.model.layers.0.self_attn.o_proj.lora_A.weight torch.Size([64, 4096])
base_model.model.model.layers.0.self_attn.o_proj.lora_B.weight torch.Size([4096, 64])
base_model.model.model.layers.0.mlp.gate_proj.lora_A.weight torch.Size([64, 4096])
base_model.model.model.layers.0.mlp.gate_proj.lora_B.weight torch.Size([11008, 64])
base_model.model.model.layers.0.mlp.up_proj.lora_A.weight torch.Size([64, 4096])
base_model.model.model.layers.0.mlp.up_proj.lora_B.weight torch.Size([11008, 64])
base_model.model.model.layers.0.mlp.down_proj.lora_A.weight torch.Size([64, 11008])
base_model.model.model.layers.0.mlp.down_proj.lora_B.weight torch.Size([4096, 64])
...
'''
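        # For reference: a LoRA pair (lora_A: [r, in_features],
        # lora_B: [out_features, r]) approximates a weight update
        # W' = W + (lora_alpha / r) * lora_B @ lora_A; the (lora_alpha / r)
        # factor is what gets folded into t_out as `scale` below.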
tp_size = runtime_mapping.tp_size
tp_rank = runtime_mapping.tp_rank
lora_hf_configs = [{}]
ranks = [0]
uids = ["-1"]
for i, model_dir in enumerate(model_dirs):
with open(f"{model_dir}/adapter_config.json", 'r') as f:
config = json.load(f)
lora_hf_configs.append(config)
ranks.append(config["r"])
uids.append(str(i))
new_model_dirs = [""] + model_dirs
lora_target_modules = model_config.lora_target_modules
dtype = model_config.dtype
for uid, rank, model_dir, hf_config in zip(uids, ranks, new_model_dirs,
lora_hf_configs):
if uid not in self._lora_cpp_weights:
self._lora_cpp_weights[uid] = []
if uid not in self._lora_weight_config:
self._lora_weight_config[uid] = []
if model_dir != "":
lora_model = torch.load(f"{model_dir}/adapter_model.bin")
else:
lora_model = None
self._lora_uid_to_low_ranks[uid] = []
for layer_idx in range(model_config.num_layers):
self._lora_weights_pointers_list.append({})
self._lora_weights_pointers_list[layer_idx].update({uid: {}})
self._lora_uid_to_low_ranks[uid].append({})
prefix = "base_model.model.model.layers"
for lora_module in lora_target_modules:
if uid == "-1" or model_config.trtllm_modules_to_hf_modules[
lora_module] not in hf_config["target_modules"]:
self._lora_uid_to_low_ranks[uid][layer_idx][
lora_module] = 0
continue
if lora_module == "attn_q" or lora_module == "attn_k" or lora_module == "attn_v":
name = f"{prefix}.{layer_idx}.{lora_module.replace('attn_', 'self_attn.')}_proj"
# not split
t_in = lora_model[f"{name}.lora_A.weight"]
# split by column
t_out = lora_model[f"{name}.lora_B.weight"]
assert t_out.shape[0] % tp_size == 0
t_out = torch.split(t_out,
t_out.shape[0] // tp_size,
dim=0)[tp_rank].contiguous()
elif lora_module == "attn_dense":
# split by row
t_in = lora_model[
f"{prefix}.{layer_idx}.self_attn.o_proj.lora_A.weight"]
assert t_in.shape[1] % tp_size == 0
t_in = torch.split(t_in,
t_in.shape[1] // tp_size,
dim=1)[tp_rank].contiguous()
# not split
t_out = lora_model[
f"{prefix}.{layer_idx}.self_attn.o_proj.lora_B.weight"]
elif lora_module == "mlp_h_to_4h":
# not split
t_in = lora_model[
f"{prefix}.{layer_idx}.mlp.gate_proj.lora_A.weight"]
# split by column
t_out = lora_model[
f"{prefix}.{layer_idx}.mlp.gate_proj.lora_B.weight"]
assert t_out.shape[0] % tp_size == 0
t_out = torch.split(t_out,
t_out.shape[0] // tp_size,
dim=0)[tp_rank].contiguous()
elif lora_module == "mlp_gate":
# not split
t_in = lora_model[
f"{prefix}.{layer_idx}.mlp.up_proj.lora_A.weight"]
# split by column
t_out = lora_model[
f"{prefix}.{layer_idx}.mlp.up_proj.lora_B.weight"]
assert t_out.shape[0] % tp_size == 0
t_out = torch.split(t_out,
t_out.shape[0] // tp_size,
dim=0)[tp_rank].contiguous()
elif lora_module == "mlp_4h_to_h":
# split by row
t_in = lora_model[
f"{prefix}.{layer_idx}.mlp.down_proj.lora_A.weight"]
                        assert t_in.shape[1] % tp_size == 0
t_in = torch.split(t_in,
t_in.shape[1] // tp_size,
dim=1)[tp_rank].contiguous()
# not split
t_out = lora_model[
f"{prefix}.{layer_idx}.mlp.down_proj.lora_B.weight"]
t_in = t_in.cuda().contiguous()
t_out = t_out.cuda().contiguous()
scale = float(hf_config["lora_alpha"] / hf_config["r"])
t_out = t_out * scale
t_in = t_in.float().to(str_dtype_to_torch(dtype))
t_out = t_out.float().to(str_dtype_to_torch(dtype))
self._lora_weights_pointers_list[layer_idx][uid].update(
{lora_module: [t_in.data_ptr(),
t_out.data_ptr()]})
assert t_in.shape[0] == int(hf_config["r"])
self._lora_uid_to_low_ranks[uid][layer_idx][
lora_module] = int(hf_config["r"])
                    # keep references so torch does not free these buffers
self._lora_weights.append(t_in)
self._lora_weights.append(t_out)
self._lora_cpp_weights[uid].append(
torch.concatenate([t_in.flatten(),
t_out.flatten()]))
self._lora_weight_config[uid].append(
np.array([
LoraConfig.LORA_MODULE_IDS[lora_module], layer_idx,
int(hf_config['r'])
],
dtype=np.int32))
del lora_model
def save_lora_weights_to_bin(self, out_dir):
def save_val(val, dir, key, tp_num=None, write_npy=False):
ext = "npy" if write_npy else "bin"
suffix = ext if tp_num is None else f"{tp_num}.{ext}"
if write_npy:
np.save(dir / f"model.{key}.{suffix}", val)
else:
val.tofile(dir / f"model.{key}.{suffix}")
if isinstance(out_dir, str):
out_dir_path = Path(out_dir)
elif isinstance(out_dir, Path):
out_dir_path = out_dir
else:
assert False
for uid in self._lora_cpp_weights:
if uid == '-1':
continue
all_weights = np.expand_dims(
np.stack([
torch_to_numpy(w.flatten().contiguous())
for w in self._lora_cpp_weights[uid]
]), 0)
all_configs = np.expand_dims(
np.stack(self._lora_weight_config[uid]), 0)
uid_path = out_dir_path / f"{uid}"
uid_path.mkdir(parents=True, exist_ok=True)
save_val(all_weights,
uid_path,
"lora_weights",
tp_num=None,
write_npy=True)
save_val(all_configs,
uid_path,
"lora_config",
tp_num=None,
write_npy=True)
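    # save_lora_weights_to_bin writes two .npy files per adapter uid (sketch,
    # assuming adapters "0" and "1" were loaded; the dummy uid "-1" is skipped):
    #   out_dir/0/model.lora_weights.npy
    #   out_dir/0/model.lora_config.npy
    #   out_dir/1/model.lora_weights.npy
    #   out_dir/1/model.lora_config.npy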
def uid_to_key(self, uid: str):
assert isinstance(uid, str)
return self._lora_uid_to_key[uid]
def uid_to_low_ranks(self, uid: str):
assert isinstance(uid, str)
return self._lora_uid_to_low_ranks[uid]
@property
def lora_weights(self):
return self._lora_weights
@property
def lora_weights_pointers_list(self):
return self._lora_weights_pointers_list
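# A minimal end-to-end usage sketch (illustrative only: the checkpoint path is
# a hypothetical placeholder, and model_config / runtime_mapping are assumed to
# be provided by the caller, e.g. the runtime session setup):
#
#   manager = LoraManager()
#   manager.load_from_ckpt(["/lora/adapter_0"], model_config, runtime_mapping,
#                          ckpt_source="hf")
#   per_layer_ranks = manager.uid_to_low_ranks("0")  # list of {module: rank} per layer
#   manager.save_lora_weights_to_bin("/lora/converted")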