import json
import os
import re
from pathlib import Path

import numpy as np
import torch

from .._utils import str_dtype_to_torch, torch_to_numpy, unpack_nemo_weights


def get_all_nemo_lora_weights(num_layers, lora_weights):
    layer_weights = [{} for _ in range(2 * num_layers)]
    adapter_key = "self_attention.adapter_layer.lora_kqv_adapter"
    layer_pattern = re.compile(r'.*\.layers\.([0-9]+)\..*')
    for key, weights in lora_weights.items():
        if adapter_key in key:
            if key.endswith('linear_in.weight'):
                inout = 'in'
            elif key.endswith('linear_out.weight'):
                inout = 'out'
            else:
                continue
            m = layer_pattern.match(key)
            layer_idx = int(m.group(1))
            layer_weights[layer_idx][inout] = weights
    return layer_weights


class LoraConfig(object):
    LORA_MODULE_IDS = {
        "attn_qkv": 0,
        "attn_q": 1,
        "attn_k": 2,
        "attn_v": 3,
        "attn_dense": 4,
        "mlp_h_to_4h": 5,
        "mlp_4h_to_h": 6,
        "mlp_gate": 7,
    }

    def __init__(self,
                 hf_lora_dir: str = None,
                 adapter_config: dict = {},
                 tokenizer_config: dict = {},
                 lora_target_modules: list = [],
                 is_valid: bool = False,
                 has_tokenizer: bool = False,
                 lm_head_weight=None,
                 embedding_weight=None,
                 hf_modules_to_trtllm_modules: dict = {},
                 trtllm_modules_to_hf_modules: dict = {}):
        self.hf_lora_dir = hf_lora_dir
        self.adapter_config = adapter_config
        self.tokenizer_config = tokenizer_config
        self.hf_lora_target_modules = lora_target_modules
        self.lora_target_modules = [
            hf_modules_to_trtllm_modules[m] for m in lora_target_modules
        ]
        self.is_valid = is_valid
        self.has_tokenizer = has_tokenizer
        self.lm_head_weight = lm_head_weight
        self.embedding_weight = embedding_weight
        self.vocab_size, self.hidden_size = self.lm_head_weight.shape if self.lm_head_weight is not None else (
            0, 0)
        self.hf_modules_to_trtllm_modules = hf_modules_to_trtllm_modules
        self.trtllm_modules_to_hf_modules = trtllm_modules_to_hf_modules

    @classmethod
    def from_hf(cls, hf_lora_dir, hf_modules_to_trtllm_modules,
                trtllm_modules_to_hf_modules):
        lora_target_modules = {}
        adapter_config = None
        tokenizer_config = None
        is_valid = True
        has_tokenizer = True

        if os.path.exists(f"{hf_lora_dir}/adapter_config.json"):
            with open(f"{hf_lora_dir}/adapter_config.json") as f:
                adapter_config = json.load(f)
            lora_target_modules = adapter_config["target_modules"]
        else:
            is_valid = False

        if os.path.exists(f"{hf_lora_dir}/tokenizer_config.json"):
            with open(f"{hf_lora_dir}/tokenizer_config.json") as f:
                tokenizer_config = json.load(f)
        else:
            has_tokenizer = False

        lm_head_weight = None
        embedding_weight = None
        if os.path.exists(f"{hf_lora_dir}/adapter_model.bin"):
            lora_weight = torch.load(f"{hf_lora_dir}/adapter_model.bin")
            if adapter_config["modules_to_save"] is not None:
                if "lm_head" in adapter_config["modules_to_save"]:
                    lm_head_weight = lora_weight[
                        "base_model.model.lm_head.weight"]
                if "embed_tokens" in adapter_config["modules_to_save"]:
                    embedding_weight = lora_weight[
                        "base_model.model.model.embed_tokens.weight"]

        return cls(hf_lora_dir, adapter_config, tokenizer_config,
                   lora_target_modules, is_valid, has_tokenizer,
                   lm_head_weight, embedding_weight,
                   hf_modules_to_trtllm_modules, trtllm_modules_to_hf_modules)
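
# Illustrative sketch (added for documentation, not part of the original API):
# one possible pair of module-name mappings for a LLaMA-style checkpoint and how
# they would be passed to LoraConfig.from_hf. The exact dict contents depend on
# the target model and are an assumption here; they mirror the HF module names
# used by LoraManager.load_from_hf below.
def _example_llama_module_mapping(hf_lora_dir):
    hf_modules_to_trtllm_modules = {
        "q_proj": "attn_q",
        "k_proj": "attn_k",
        "v_proj": "attn_v",
        "o_proj": "attn_dense",
        "gate_proj": "mlp_h_to_4h",
        "down_proj": "mlp_4h_to_h",
        "up_proj": "mlp_gate",
    }
    trtllm_modules_to_hf_modules = {
        v: k
        for k, v in hf_modules_to_trtllm_modules.items()
    }
    return LoraConfig.from_hf(hf_lora_dir, hf_modules_to_trtllm_modules,
                              trtllm_modules_to_hf_modules)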


class LoraManager(object):

    def __init__(self):
        '''
        _lora_uid_to_low_ranks: dict[str -> List[dict[str -> int]]]
        {
            uid: [
                {
                    lora_module: int
                },  # layer_0_rank
                {
                    lora_module: int
                },  # layer_1_rank
                ...
            ]
        }

        _lora_weights_pointers_list:
        [
            {
                uid: {
                    lora_module_1: [t_in, t_out],
                    lora_module_2: [t_in, t_out],
                }
            },  # layer_0
            {
            },  # layer_1
            ...
        ]
        '''
        self._lora_uid_to_key = {}
        self._lora_uid_to_low_ranks = {}
        self._lora_weights = []
        self._lora_weights_pointers_list = []
        self._lora_cpp_weights = {}
        self._lora_weight_config = {}

    def load_from_ckpt(self, model_dir, model_config, runtime_mapping,
                       ckpt_source):
        if ckpt_source == "hf":
            self.load_from_hf(model_dir, model_config, runtime_mapping)
        elif ckpt_source == "nemo":
            self.load_from_nemo(model_dir, model_config, runtime_mapping)
        else:
            assert False, f"LoraManager does not support source {ckpt_source}"

    def load_from_nemo(self, model_files, model_config, runtime_mapping):
        tp_size = runtime_mapping.tp_size
        tp_rank = runtime_mapping.tp_rank

        lora_target_modules = model_config.lora_target_modules
        dtype = model_config.dtype

        uids = ["-1"]
        for i in range(len(model_files)):
            uids.append(str(i))
        model_files = [""] + model_files

        for uid, model_file in zip(uids, model_files):
            if uid not in self._lora_cpp_weights:
                self._lora_cpp_weights[uid] = []
            if uid not in self._lora_weight_config:
                self._lora_weight_config[uid] = []

            if model_file != "":
                _, nemo_weights = unpack_nemo_weights(model_file)
                all_lora_weights = get_all_nemo_lora_weights(
                    model_config.num_layers, nemo_weights)
            else:
                all_lora_weights = None
                nemo_weights = None

            self._lora_uid_to_low_ranks[uid] = []
            for layer_idx in range(model_config.num_layers):
                self._lora_weights_pointers_list.append({})
                self._lora_weights_pointers_list[layer_idx].update({uid: {}})

                self._lora_uid_to_low_ranks[uid].append({})

                for lora_module in lora_target_modules:
                    if uid == "-1" or lora_module != "attn_qkv" or all_lora_weights is None:
                        self._lora_uid_to_low_ranks[uid][layer_idx][
                            lora_module] = 0
                        continue

                    if lora_module == "attn_qkv":
                        t_in = all_lora_weights[layer_idx]["in"]
                        t_out = all_lora_weights[layer_idx]["out"]
                        assert t_out.shape[0] % tp_size == 0
                        t_out = torch.split(t_out,
                                            t_out.shape[0] // tp_size,
                                            dim=0)[tp_rank].contiguous()
                    else:
                        t_in = None
                        t_out = None

                    if t_in is not None and t_out is not None:
                        t_in = t_in.cuda().to(
                            str_dtype_to_torch(dtype)).contiguous()
                        t_out = t_out.cuda().to(
                            str_dtype_to_torch(dtype)).contiguous()
                        rank = t_in.shape[0]
                        self._lora_weights_pointers_list[layer_idx][uid].update(
                            {lora_module: [t_in.data_ptr(),
                                           t_out.data_ptr()]})
                        self._lora_uid_to_low_ranks[uid][layer_idx][
                            lora_module] = int(rank)

                        # prevent torch from freeing these buffers
                        self._lora_weights.append(t_in)
                        self._lora_weights.append(t_out)
                        self._lora_cpp_weights[uid].append(
                            torch.concatenate(
                                [t_in.flatten(), t_out.flatten()]))
                        self._lora_weight_config[uid].append(
                            np.array([
                                LoraConfig.LORA_MODULE_IDS[lora_module],
                                layer_idx,
                                int(rank)
                            ],
                                     dtype=np.int32))

            del nemo_weights
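
    # Illustrative note (added, not in the original module): in the NeMo path
    # above, t_in is the fused kqv adapter's linear_in weight ([rank, hidden])
    # and is kept whole on every rank, while t_out (linear_out) is split by
    # column across tensor-parallel ranks. A minimal sketch of that split,
    # assuming hypothetical shapes:
    #
    #   t_out = torch.zeros(3 * 4096, 64)   # [qkv_out_features, rank]
    #   shard = torch.split(t_out, t_out.shape[0] // tp_size,
    #                       dim=0)[tp_rank].contiguous()
    #   # shard.shape == (3 * 4096 // tp_size, 64)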
    def load_from_hf(self, model_dirs, model_config, runtime_mapping):
        '''
        lora config of https://huggingface.co/hfl/chinese-alpaca-2-lora-7b
        {
            "base_model_name_or_path": "/Llama-2-7b-hf",
            "bias": "none",
            "enable_lora": null,
            "fan_in_fan_out": false,
            "inference_mode": true,
            "lora_alpha": 128.0,
            "lora_dropout": 0.05,
            "merge_weights": false,
            "modules_to_save": [
                "embed_tokens",
                "lm_head"
            ],
            "peft_type": "LORA",
            "r": 64,
            "target_modules": [
                "q_proj",
                "v_proj",
                "k_proj",
                "o_proj",
                "gate_proj",
                "down_proj",
                "up_proj"
            ],
            "task_type": "CAUSAL_LM"
        }

        keys in adapter_model.bin:
            base_model.model.model.layers.0.self_attn.q_proj.lora_A.weight torch.Size([64, 4096])
            base_model.model.model.layers.0.self_attn.q_proj.lora_B.weight torch.Size([4096, 64])
            base_model.model.model.layers.0.self_attn.k_proj.lora_A.weight torch.Size([64, 4096])
            base_model.model.model.layers.0.self_attn.k_proj.lora_B.weight torch.Size([4096, 64])
            base_model.model.model.layers.0.self_attn.v_proj.lora_A.weight torch.Size([64, 4096])
            base_model.model.model.layers.0.self_attn.v_proj.lora_B.weight torch.Size([4096, 64])
            base_model.model.model.layers.0.self_attn.o_proj.lora_A.weight torch.Size([64, 4096])
            base_model.model.model.layers.0.self_attn.o_proj.lora_B.weight torch.Size([4096, 64])
            base_model.model.model.layers.0.mlp.gate_proj.lora_A.weight torch.Size([64, 4096])
            base_model.model.model.layers.0.mlp.gate_proj.lora_B.weight torch.Size([11008, 64])
            base_model.model.model.layers.0.mlp.up_proj.lora_A.weight torch.Size([64, 4096])
            base_model.model.model.layers.0.mlp.up_proj.lora_B.weight torch.Size([11008, 64])
            base_model.model.model.layers.0.mlp.down_proj.lora_A.weight torch.Size([64, 11008])
            base_model.model.model.layers.0.mlp.down_proj.lora_B.weight torch.Size([4096, 64])
            ...
        '''
        tp_size = runtime_mapping.tp_size
        tp_rank = runtime_mapping.tp_rank

        lora_hf_configs = [{}]
        ranks = [0]
        uids = ["-1"]
        for i, model_dir in enumerate(model_dirs):
            with open(f"{model_dir}/adapter_config.json", 'r') as f:
                config = json.load(f)
                lora_hf_configs.append(config)
                ranks.append(config["r"])
            uids.append(str(i))
        new_model_dirs = [""] + model_dirs

        lora_target_modules = model_config.lora_target_modules
        dtype = model_config.dtype

        for uid, rank, model_dir, hf_config in zip(uids, ranks, new_model_dirs,
                                                   lora_hf_configs):
            if uid not in self._lora_cpp_weights:
                self._lora_cpp_weights[uid] = []
            if uid not in self._lora_weight_config:
                self._lora_weight_config[uid] = []

            if model_dir != "":
                lora_model = torch.load(f"{model_dir}/adapter_model.bin")
            else:
                lora_model = None

            self._lora_uid_to_low_ranks[uid] = []
            for layer_idx in range(model_config.num_layers):
                self._lora_weights_pointers_list.append({})
                self._lora_weights_pointers_list[layer_idx].update({uid: {}})

                self._lora_uid_to_low_ranks[uid].append({})

                prefix = "base_model.model.model.layers"
                for lora_module in lora_target_modules:
                    if uid == "-1" or model_config.trtllm_modules_to_hf_modules[
                            lora_module] not in hf_config["target_modules"]:
                        self._lora_uid_to_low_ranks[uid][layer_idx][
                            lora_module] = 0
                        continue

                    if lora_module == "attn_q" or lora_module == "attn_k" or lora_module == "attn_v":
                        name = f"{prefix}.{layer_idx}.{lora_module.replace('attn_', 'self_attn.')}_proj"
                        # not split
                        t_in = lora_model[f"{name}.lora_A.weight"]
                        # split by column
                        t_out = lora_model[f"{name}.lora_B.weight"]
                        assert t_out.shape[0] % tp_size == 0
                        t_out = torch.split(t_out,
                                            t_out.shape[0] // tp_size,
                                            dim=0)[tp_rank].contiguous()

                    elif lora_module == "attn_dense":
                        # split by row
                        t_in = lora_model[
                            f"{prefix}.{layer_idx}.self_attn.o_proj.lora_A.weight"]
                        assert t_in.shape[1] % tp_size == 0
                        t_in = torch.split(t_in,
                                           t_in.shape[1] // tp_size,
                                           dim=1)[tp_rank].contiguous()
                        # not split
                        t_out = lora_model[
                            f"{prefix}.{layer_idx}.self_attn.o_proj.lora_B.weight"]

                    elif lora_module == "mlp_h_to_4h":
                        # not split
                        t_in = lora_model[
                            f"{prefix}.{layer_idx}.mlp.gate_proj.lora_A.weight"]
                        # split by column
                        t_out = lora_model[
                            f"{prefix}.{layer_idx}.mlp.gate_proj.lora_B.weight"]
                        assert t_out.shape[0] % tp_size == 0
                        t_out = torch.split(t_out,
                                            t_out.shape[0] // tp_size,
                                            dim=0)[tp_rank].contiguous()

                    elif lora_module == "mlp_gate":
                        # not split
                        t_in = lora_model[
                            f"{prefix}.{layer_idx}.mlp.up_proj.lora_A.weight"]
                        # split by column
                        t_out = lora_model[
                            f"{prefix}.{layer_idx}.mlp.up_proj.lora_B.weight"]
                        assert t_out.shape[0] % tp_size == 0
                        t_out = torch.split(t_out,
                                            t_out.shape[0] // tp_size,
                                            dim=0)[tp_rank].contiguous()

                    elif lora_module == "mlp_4h_to_h":
                        # split by row
                        t_in = lora_model[
                            f"{prefix}.{layer_idx}.mlp.down_proj.lora_A.weight"]
                        assert t_in.shape[1] % tp_size == 0
                        t_in = torch.split(t_in,
                                           t_in.shape[1] // tp_size,
                                           dim=1)[tp_rank].contiguous()
                        # not split
                        t_out = lora_model[
                            f"{prefix}.{layer_idx}.mlp.down_proj.lora_B.weight"]

                    t_in = t_in.cuda().contiguous()
                    t_out = t_out.cuda().contiguous()
                    scale = float(hf_config["lora_alpha"] / hf_config["r"])
                    t_out = t_out * scale
                    t_in = t_in.float().to(str_dtype_to_torch(dtype))
                    t_out = t_out.float().to(str_dtype_to_torch(dtype))
                    self._lora_weights_pointers_list[layer_idx][uid].update(
                        {lora_module: [t_in.data_ptr(),
                                       t_out.data_ptr()]})

                    assert t_in.shape[0] == int(hf_config["r"])
                    self._lora_uid_to_low_ranks[uid][layer_idx][
                        lora_module] = int(hf_config["r"])

                    # prevent torch from freeing these buffers
                    self._lora_weights.append(t_in)
                    self._lora_weights.append(t_out)
                    self._lora_cpp_weights[uid].append(
                        torch.concatenate([t_in.flatten(),
                                           t_out.flatten()]))
                    self._lora_weight_config[uid].append(
                        np.array([
                            LoraConfig.LORA_MODULE_IDS[lora_module], layer_idx,
                            int(hf_config['r'])
                        ],
                                 dtype=np.int32))

            del lora_model
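
    # Illustrative note (added, not part of the original module): for each HF
    # module above, t_in holds lora_A ([r, in_features], split along dim=1 for
    # row-parallel layers) and t_out holds lora_B ([out_features, r], split
    # along dim=0 for column-parallel layers), with lora_alpha / r folded into
    # lora_B. A minimal sketch of the update these two tensors represent,
    # assuming hypothetical shapes:
    #
    #   x = torch.randn(1, 4096)            # [tokens, in_features]
    #   delta = (x @ t_in.T) @ t_out.T      # == x @ (scale * lora_B @ lora_A).T
    #   y = x @ base_weight.T + delta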
    def save_lora_weights_to_bin(self, out_dir):

        def save_val(val, dir, key, tp_num=None, write_npy=False):
            ext = "npy" if write_npy else "bin"
            suffix = ext if tp_num is None else f"{tp_num}.{ext}"
            if write_npy:
                np.save(dir / f"model.{key}.{suffix}", val)
            else:
                val.tofile(dir / f"model.{key}.{suffix}")

        if isinstance(out_dir, str):
            out_dir_path = Path(out_dir)
        elif isinstance(out_dir, Path):
            out_dir_path = out_dir
        else:
            assert False

        for uid in self._lora_cpp_weights:
            if uid == '-1':
                continue

            all_weights = np.expand_dims(
                np.stack([
                    torch_to_numpy(w.flatten().contiguous())
                    for w in self._lora_cpp_weights[uid]
                ]), 0)
            all_configs = np.expand_dims(
                np.stack(self._lora_weight_config[uid]), 0)

            uid_path = out_dir_path / f"{uid}"
            uid_path.mkdir(parents=True, exist_ok=True)
            save_val(all_weights,
                     uid_path,
                     "lora_weights",
                     tp_num=None,
                     write_npy=True)
            save_val(all_configs,
                     uid_path,
                     "lora_config",
                     tp_num=None,
                     write_npy=True)

    def uid_to_key(self, uid: str):
        assert isinstance(uid, str)
        return self._lora_uid_to_key[uid]

    def uid_to_low_ranks(self, uid: str):
        assert isinstance(uid, str)
        return self._lora_uid_to_low_ranks[uid]

    @property
    def lora_weights(self):
        return self._lora_weights

    @property
    def lora_weights_pointers_list(self):
        return self._lora_weights_pointers_list
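

# Illustrative usage sketch (added for documentation, not part of the original
# module): how a caller might load HF LoRA checkpoints and export the flattened
# per-uid weight/config arrays. `model_config` and `runtime_mapping` are assumed
# to come from the surrounding runtime; the function name and default output
# directory are hypothetical.
def _example_export_hf_lora(lora_dirs, model_config, runtime_mapping,
                            out_dir="/tmp/lora_export"):
    manager = LoraManager()
    # Load every HF LoRA directory; uid "-1" is reserved for "no LoRA".
    manager.load_from_ckpt(lora_dirs,
                           model_config,
                           runtime_mapping,
                           ckpt_source="hf")
    # Per-layer {module: rank} table for the first adapter.
    _ = manager.uid_to_low_ranks("0")
    # Dump model.lora_weights.npy / model.lora_config.npy under out_dir/<uid>/.
    manager.save_lora_weights_to_bin(out_dir)
    return manager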