import io
import json
import re
import tarfile
from collections import defaultdict
from dataclasses import dataclass, field
from pathlib import Path
from typing import TYPE_CHECKING, Dict, List, Optional, Union

import numpy as np
import torch
import yaml

from ._utils import (DictConversion, pad_vocab_size, release_gc,
                     str_dtype_to_torch, torch_to_numpy)
from .layers.linear import ColumnLinear
from .mapping import Mapping
from .models.convert_utils import (get_model_path, load_state_dict,
                                   split_matrix_tp)

if TYPE_CHECKING:
    from .runtime import ModelConfig


def get_all_nemo_lora_weights(lora_weights):
    layer_weights = defaultdict(dict)
    adapter_key = "self_attention.adapter_layer.lora_kqv_adapter"
    layer_pattern = re.compile(r'.*\.layers\.(\d+)\..*')
    for key, weights in lora_weights.items():
        if adapter_key in key:
            if key.endswith('linear_in.weight'):
                inout = 'in'
            elif key.endswith('linear_out.weight'):
                inout = 'out'
            else:
                continue
            m = layer_pattern.match(key)
            if m is None:
                raise KeyError(
                    f"Failed to extract layer index from key {key} using pattern {layer_pattern.pattern}"
                )
            layer_idx = int(m.group(1))
            layer_weights[layer_idx][inout] = weights
        else:
            raise KeyError(f"unsupported key {key} from Nemo LoRA weights")
    return layer_weights


# The pattern is:
#   {layer_prefix:1}.{layer_idx:2}.{module_prefix:3}.{module_name:4}.lora_{A|B:8}.weight
# where group 4 may itself be {expert_name:5}.{expert_idx:6}.{module_name:7} for MoE
# weights, and the suffix may instead name a DoRA magnitude vector.
HF_LORA_PATTERN = re.compile(
    r'(.*)\.(\d+)\.(\w+)\.(\w+|\w+\.\w+|(\w+)\.(\d+)\.(\w+))\.(?:lora_(?:(A|B)\.weight|(magnitude)_vector)|weight_(m_wdecomp).weight)'
)
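
# Illustrative sketch (not part of the loader): shows how HF_LORA_PATTERN decomposes a
# typical HF LoRA state-dict key into the groups consumed by iterate_hf_lora() below.
# The sample key and the helper name `_example_parse_hf_lora_key` are hypothetical.
def _example_parse_hf_lora_key():
    key = "base_model.model.model.layers.0.self_attn.q_proj.lora_A.weight"
    m = HF_LORA_PATTERN.match(key)
    assert m is not None
    layer_idx = int(m.group(2))  # 0
    hf_module = f"{m.group(3)}.{m.group(4)}"  # "self_attn.q_proj"
    lora_a_or_b = m.group(8)  # "A"
    return layer_idx, hf_module, lora_a_or_b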


def iterate_hf_lora(iter_fn, lora_weights, hf_modules, component=None):
    all_weights = defaultdict(lambda: defaultdict(dict))
    pattern = HF_LORA_PATTERN
    for key, weights in lora_weights.items():
        m = pattern.match(key)
        if not m:
            if "lm_head" not in key and "embed_tokens" not in key:
                raise KeyError(f"unsupported key {key} from HF LoRA weights")
            continue
        if component is not None and component not in m.group(1):
            continue
        layer_idx = int(m.group(2))
        expert_idx = m.group(6)
        if expert_idx is not None:
            expert_idx = int(expert_idx)
        is_moe = expert_idx is not None
        if is_moe:
            expert_name = m.group(5)
            module_name = m.group(7)
            hf_module = m.group(3) + "." + expert_name + "." + module_name
        else:
            module_name = m.group(4)
            hf_module = m.group(3) + "." + module_name
            if hf_module not in hf_modules:
                hf_module = module_name
        assert hf_module in hf_modules, \
            f"hf_module {hf_module} is not in supported list {hf_modules}"
        is_lora_a_or_b = m.group(8) is not None
        if is_lora_a_or_b:
            inout_or_mag = "in" if m.group(8) == "A" else "out"
        else:
            inout_or_mag = "magnitude"
        iter_fn(layer_idx, hf_module, expert_idx, inout_or_mag, weights)
        if not is_moe:
            all_weights[layer_idx][hf_module][inout_or_mag] = weights
        else:
            all_weights[layer_idx][hf_module].setdefault(expert_idx, {})
            all_weights[layer_idx][hf_module][expert_idx][inout_or_mag] = weights
    return all_weights


def get_all_hf_lora_weights(lora_weights, hf_modules, component=None):

    def iter_fn(layer_idx, hf_module, expert_idx, inout, weights):
        if expert_idx is None:
            all_weights[layer_idx][hf_module][inout] = weights
        else:
            all_weights[layer_idx][hf_module].setdefault(expert_idx, {})
            all_weights[layer_idx][hf_module][expert_idx][inout] = weights

    all_weights = defaultdict(lambda: defaultdict(dict))
    iterate_hf_lora(iter_fn, lora_weights, hf_modules, component)
    return all_weights


def get_hf_target_modules(lora_weights, hf_modules):

    def iter_fn(layer_idx, hf_module, expert_idx, inout, weights):
        hf_target_modules.add(hf_module)

    hf_target_modules = set()
    iterate_hf_lora(iter_fn, lora_weights, hf_modules)
    return hf_target_modules


def invert_module_mapping(trtllm_modules_to_hf_modules):
    hf_modules_to_trtllm_modules = {}
    for k, hf_modules in trtllm_modules_to_hf_modules.items():
        if isinstance(hf_modules, list):
            for hf_module in hf_modules:
                hf_modules_to_trtllm_modules[hf_module] = k
        else:
            hf_modules_to_trtllm_modules[hf_modules] = k
    return hf_modules_to_trtllm_modules


def norm_dora_magnitude(W0: torch.Tensor,
                        A: torch.Tensor,
                        B: torch.Tensor,
                        m: torch.Tensor,
                        scaling: float = 1.0):
    new_weight_v = W0 + (B @ A) * scaling
    norm_m = m.view(-1) / (torch.linalg.norm(new_weight_v, dim=1)).detach()
    return norm_m


@dataclass
class LoraConfig(DictConversion):
    lora_dir: List[str] = field(default_factory=list)
    lora_ckpt_source: str = 'hf'
    max_lora_rank: int = 64
    lora_target_modules: List[str] = field(default_factory=list)
    trtllm_modules_to_hf_modules: Dict[str, str] = field(default_factory=dict)
    max_loras: int = 4
    max_cpu_loras: int = 4

    def __post_init__(self):
        assert self.lora_ckpt_source in [
            'hf', 'nemo'
        ], f"lora_ckpt_source must be one of 'hf' or 'nemo', got {self.lora_ckpt_source}"

    @property
    def missing_qkv_modules(self) -> List[str]:
        return LoraManager.get_missing_qkv_modules(self.lora_target_modules)


@dataclass
class LoraModelConfig:
    lora_target_modules: list[str]
    trtllm_modules_to_hf_modules: dict[str, str]
    hidden_size: int
    dtype: str
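
# Illustrative sketch: how norm_dora_magnitude() relates to the DoRA formulation. Shapes
# are hypothetical; W0 is the frozen base weight (out_features x in_features), A
# (rank x in_features) and B (out_features x rank) are the LoRA factors, and m is the
# learned per-output-channel magnitude. The result is m normalized by the row norms of
# the merged weight W0 + scaling * B @ A.
def _example_norm_dora_magnitude():
    out_features, in_features, rank = 8, 16, 4
    W0 = torch.randn(out_features, in_features)
    A = torch.randn(rank, in_features)
    B = torch.randn(out_features, rank)
    m = torch.randn(out_features)
    norm_m = norm_dora_magnitude(W0, A, B, m, scaling=1.0)
    assert norm_m.shape == (out_features, )
    return norm_m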


class HfLoraLoader:

    def __init__(self, lora_dirs: List[str]):
        self.lora_target_modules = []
        self.is_valid = False
        self.lm_head = None
        self.embed_tokens = None
        self.vocab_size = 0

        if len(lora_dirs) == 0:
            return

        for lora_dir in lora_dirs:
            model_path = get_model_path(lora_dir, "adapter_model")
            if model_path is None:
                raise ValueError(
                    f"adapter_model file does not exist in {lora_dir}")
            config_file = Path(f"{lora_dir}/adapter_config.json")
            if not config_file.exists():
                raise ValueError(f"{config_file} does not exist")
            if not config_file.is_file():
                raise ValueError(f"{config_file} is not a file")
        self.is_valid = True

        lora_dir = lora_dirs[0]
        with open(f"{lora_dir}/adapter_config.json") as f:
            adapter_config = json.load(f)
        model_path = get_model_path(lora_dir, "adapter_model")
        if model_path is None:
            raise ValueError(f"adapter_model file does not exist in {lora_dir}")
        lora_weight = load_state_dict(model_path)

        self.lora_weight = lora_weight
        if adapter_config["modules_to_save"] is not None:
            if "lm_head" in adapter_config["modules_to_save"]:
                self.lm_head = lora_weight["base_model.model.lm_head.weight"]
                self.vocab_size = self.lm_head.shape[0]
            if "embed_tokens" in adapter_config["modules_to_save"]:
                self.embed_tokens = lora_weight[
                    "base_model.model.model.embed_tokens.weight"]

    def get_target_modules(self, trtllm_modules_to_hf_modules):
        hf_modules_to_trtllm_modules = invert_module_mapping(
            trtllm_modules_to_hf_modules)
        lora_target_modules = set()
        if self.is_valid:
            hf_target_modules = get_hf_target_modules(
                self.lora_weight,
                hf_modules=set(hf_modules_to_trtllm_modules.keys()),
            )
            for m in hf_target_modules:
                trtllm_module = hf_modules_to_trtllm_modules[m]
                lora_target_modules.add(trtllm_module)
        return list(lora_target_modules)


class NemoLoraLoader:

    def __init__(self, lora_dirs: List[str]):
        self.lora_target_modules = []
        self.is_valid = False

        if len(lora_dirs) == 0:
            return

        for lora_file in lora_dirs:
            path = Path(lora_file)
            if not path.exists():
                raise ValueError(f"{path} does not exist")
            if not path.is_file():
                raise ValueError(f"{path} is not a file")
        self.is_valid = True
        # Hardcoded since LoraManager only supports this case now
        self.lora_target_modules = ["attn_qkv"]


def load_nemo_lora(model, lora_config: LoraConfig):
    lora_loader = NemoLoraLoader(lora_config.lora_dir)
    if len(lora_config.lora_target_modules) == 0:
        lora_config.lora_target_modules = lora_loader.lora_target_modules


def get_default_trtllm_modules_to_hf_modules():
    return {
        "attn_q": "q_proj",
        "attn_k": "k_proj",
        "attn_v": "v_proj",
        "attn_dense": "o_proj",
        "mlp_h_to_4h": "gate_proj",
        "mlp_4h_to_h": "down_proj",
        "mlp_gate": "up_proj",
        "mlp_gate_up": "gate_up_proj",
        "moe_h_to_4h": "w1",
        "moe_4h_to_h": "w2",
        "moe_gate": "w3",
        "moe_router": "gate",
    }


def load_torch_hf_lora(lora_config: LoraConfig):
    """This is a shortened version of load_hf_lora used for torch models.

    The main problem is that model.config in the legacy code is a custom config
    (defined in the legacy code), whereas the pivot model config is the
    transformers one.
    """
    # TODO smor: need to combine with load_hf_lora
    lora_config.trtllm_modules_to_hf_modules = get_default_trtllm_modules_to_hf_modules(
    )
    assert len(lora_config.lora_dir) == 1, "Expecting only a single lora dir"
    lora_loader = HfLoraLoader(lora_config.lora_dir)

    if len(lora_config.lora_target_modules) == 0:
        lora_config.lora_target_modules = lora_loader.get_target_modules(
            lora_config.trtllm_modules_to_hf_modules)
    if len(lora_config.lora_target_modules) == 0:
        raise ValueError(
            "lora_target_modules is empty. "
            "Please specify lora_target_modules or provide lora_dir to infer lora_target_modules."
        )

    missing_qkv_modules = LoraManager.get_missing_qkv_modules(
        lora_config.lora_target_modules)
    lora_config.lora_target_modules.extend(missing_qkv_modules)
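
# Illustrative sketch: inverting the TRT-LLM -> HF module-name mapping, as the HF loaders
# above and below do before matching state-dict keys. Purely for illustration; the helper
# is never called by the loader.
def _example_invert_default_mapping():
    mapping = get_default_trtllm_modules_to_hf_modules()
    inverted = invert_module_mapping(mapping)
    assert inverted["q_proj"] == "attn_q"
    assert inverted["down_proj"] == "mlp_4h_to_h"
    return inverted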


def load_hf_lora(
    model,
    lora_config: LoraConfig,
    trtllm_modules_to_hf_modules: Optional[Dict[str, str]] = None,
):
    trtllm_modules_to_hf_modules = trtllm_modules_to_hf_modules or \
        get_default_trtllm_modules_to_hf_modules()
    lora_config.trtllm_modules_to_hf_modules = trtllm_modules_to_hf_modules

    lora_loader = HfLoraLoader(lora_config.lora_dir)

    if len(lora_config.lora_target_modules) == 0:
        lora_config.lora_target_modules = lora_loader.get_target_modules(
            trtllm_modules_to_hf_modules)
    if len(lora_config.lora_target_modules) == 0:
        raise ValueError(
            "lora_target_modules is empty. "
            "Please specify lora_target_modules or provide lora_dir to infer lora_target_modules."
        )

    missing_qkv_modules = LoraManager.get_missing_qkv_modules(
        lora_config.lora_target_modules)
    lora_config.lora_target_modules.extend(missing_qkv_modules)

    if lora_loader.is_valid:
        config = model.config
        torch_dtype = str_dtype_to_torch(config.dtype)
        # the lora checkpoint might finetune the embedding
        if lora_loader.vocab_size != 0:
            config.vocab_size = lora_loader.vocab_size
        mapping = config.mapping
        if mapping.is_first_pp_rank() and lora_loader.embed_tokens is not None:
            weight = lora_loader.embed_tokens
            if config.use_parallel_embedding:
                weight = split_matrix_tp(
                    weight,
                    mapping.tp_size,
                    mapping.tp_rank,
                    dim=config.embedding_sharding_dim,
                )
            if model.transformer.vocab_embedding.weight.raw_value.shape != weight.shape:
                model.transformer.vocab_embedding = model.transformer.vocab_embedding.__class__(
                    num_embeddings=config.vocab_size,
                    embedding_dim=config.hidden_size,
                    dtype=config.dtype,
                    tp_size=mapping.tp_size
                    if config.use_parallel_embedding else 1,
                    tp_group=mapping.tp_group
                    if config.use_parallel_embedding else None,
                    sharding_dim=config.embedding_sharding_dim,
                    tp_rank=mapping.tp_rank,
                )
            model.transformer.vocab_embedding.weight.value = weight.to(
                torch_dtype)
        if mapping.is_last_pp_rank() and lora_loader.lm_head is not None:
            weight = lora_loader.lm_head
            vocab_size = lora_loader.vocab_size
            if vocab_size % mapping.tp_size != 0:
                # padding
                vocab_size_padded = pad_vocab_size(vocab_size, mapping.tp_size)
                pad_width = vocab_size_padded - vocab_size
                weight = torch.from_numpy(
                    np.pad(torch_to_numpy(weight), ((0, pad_width), (0, 0)),
                           'constant',
                           constant_values=0))
            else:
                vocab_size_padded = vocab_size
            if model.lm_head.weight.raw_value.shape != weight.shape:
                model.lm_head = ColumnLinear(
                    config.hidden_size,
                    vocab_size_padded,
                    bias=False,
                    dtype=config.dtype,
                    tp_group=mapping.tp_group,
                    tp_size=mapping.tp_size,
                    gather_output=True,
                )
            model.lm_head.weight.value = split_matrix_tp(
                weight,
                mapping.tp_size,
                mapping.tp_rank,
                dim=0,
            ).to(torch_dtype)


def use_lora(
    model,
    lora_config: LoraConfig,
    trtllm_modules_to_hf_modules: Optional[Dict[str, str]] = None,
):
    if lora_config.lora_ckpt_source == "nemo":
        load_nemo_lora(model, lora_config)
    elif lora_config.lora_ckpt_source == "hf":
        load_hf_lora(model, lora_config, trtllm_modules_to_hf_modules)
    else:
        raise ValueError(
            f"Unsupported lora_ckpt_source: {lora_config.lora_ckpt_source}")
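
# Illustrative sketch of the lm_head padding performed in load_hf_lora() above: when the
# fine-tuned vocab size is not divisible by tp_size, the lm_head rows are zero-padded up
# to the padded vocab size before the tensor-parallel column split. The helper name and
# its standalone use are hypothetical; it mirrors the in-line logic above.
def _example_pad_lm_head(lm_head: torch.Tensor, tp_size: int) -> torch.Tensor:
    vocab_size = lm_head.shape[0]
    vocab_size_padded = pad_vocab_size(vocab_size, tp_size)
    pad_width = vocab_size_padded - vocab_size
    if pad_width == 0:
        return lm_head
    return torch.from_numpy(
        np.pad(torch_to_numpy(lm_head), ((0, pad_width), (0, 0)),
               'constant',
               constant_values=0))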


def unpack_nemo_weights(nemo_archive_path):
    with tarfile.open(nemo_archive_path) as tar:
        try:
            model_weights_file = tar.extractfile("model_weights.ckpt")
            model_config_file = tar.extractfile("model_config.yaml")
        except KeyError:
            try:
                model_weights_file = tar.extractfile("./model_weights.ckpt")
                model_config_file = tar.extractfile("./model_config.yaml")
            except KeyError:
                err_str = "Could not find model_weights.ckpt or model_config.yaml in the tar archive."
                raise Exception(err_str)

        if model_weights_file is None or model_config_file is None:
            raise Exception("Could not extract model weights or config files")

        model_config_content = model_config_file.read()
        model_config_dict = yaml.safe_load(model_config_content)

        model_weights_bytes = model_weights_file.read()
        model_weights_dict = torch.load(io.BytesIO(model_weights_bytes),
                                        map_location=torch.device("cpu"))

        return model_config_dict, model_weights_dict


class LoraManager(object):
    LORA_MODULE_IDS = {
        "attn_qkv": 0,
        "attn_q": 1,
        "attn_k": 2,
        "attn_v": 3,
        "attn_dense": 4,
        "mlp_h_to_4h": 5,
        "mlp_4h_to_h": 6,
        "mlp_gate": 7,
        "cross_attn_qkv": 8,
        "cross_attn_q": 9,
        "cross_attn_k": 10,
        "cross_attn_v": 11,
        "cross_attn_dense": 12,
        "moe_h_to_4h": 13,
        "moe_4h_to_h": 14,
        "moe_gate": 15,
        "moe_router": 16,
        "mlp_router": 17,
        "mlp_gate_up": 18,
    }

    def __init__(self):
        '''
        _lora_uid_to_low_ranks: dict[str -> dict[int -> dict[str -> int]]]
        {
            uid: {
                0: {lora_module: int},  # layer_0_rank,
                1: {lora_module: int},  # layer_1_rank,
                ...
            }
        }

        _lora_weights_pointers_list: dict[str -> dict[int -> dict[str -> [Tensor, Tensor]]]]
        {
            uid: {
                0: {lora_module: [t_in, t_out]},  # layer_0,
                1: {lora_module: [t_in, t_out]},  # layer_1,
                ...
            }
        }
        '''
        self._lora_uid_counter = 0
        self._lora_uid_to_low_ranks: Dict[str, Dict[int, Dict[str, int]]] = {}
        # hold the torch tensors and prevent them from being freed
        # TODO(enweiz): free device tensors if it's used for c++ runtime only
        self._lora_weights: List[torch.Tensor] = []
        self._lora_weights_pointers_list: Dict[str,
                                               Dict[int,
                                                    Dict[str,
                                                         List[int]]]] = {}
        self._cpp_lora_weights: Dict[str, torch.Tensor] = {}  # on cpu
        self._cpp_lora_config: Dict[str, torch.Tensor] = {}  # on cpu
        self.lora_target_modules: List[str] = []
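
    # Illustrative sketch (hypothetical values) of the nested layout documented in
    # __init__ above, for a single adapter uid "0" with rank 8 on layer 0 of "attn_qkv":
    #   _lora_uid_to_low_ranks      == {"0": {0: {"attn_qkv": 8}}}
    #   _lora_weights_pointers_list == {"0": {0: {"attn_qkv": [in_ptr, out_ptr, 0]}}}
    # where in_ptr / out_ptr are data_ptr() addresses of the LoRA in/out tensors and the
    # third slot holds the DoRA magnitude pointer (0 for plain LoRA).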

    @staticmethod
    def get_missing_qkv_modules(lora_target_modules):
        # In the current design, q_lora_params, k_lora_params and v_lora_params
        # must either all be enabled or all be disabled at the same time.
        # However, some LoRA checkpoints (e.g. BART) only contain two of them,
        # so we use zero tensors to fill in the missing ones.
        missing_qkv_modules = []
        if any(x in lora_target_modules
               for x in ["attn_q", "attn_k", "attn_v"]):
            for lora_module in ["attn_q", "attn_k", "attn_v"]:
                if lora_module not in lora_target_modules:
                    missing_qkv_modules.append(lora_module)
        if any(x in lora_target_modules
               for x in ["cross_attn_q", "cross_attn_k", "cross_attn_v"]):
            for lora_module in ["cross_attn_q", "cross_attn_k", "cross_attn_v"]:
                if lora_module not in lora_target_modules:
                    missing_qkv_modules.append(lora_module)
        return missing_qkv_modules

    @property
    def missing_qkv_modules(self) -> List[str]:
        return LoraManager.get_missing_qkv_modules(self.lora_target_modules)

    def load_from_ckpt(self,
                       model_dirs_or_files: List[str],
                       model_config: Union['ModelConfig', LoraModelConfig],
                       runtime_mapping: Optional[Mapping] = None,
                       uids: Optional[List[str]] = None,
                       ckpt_source: str = 'hf'):
        if ckpt_source == 'hf':
            self.load_from_hf(model_dirs=model_dirs_or_files,
                              model_config=model_config,
                              runtime_mapping=runtime_mapping,
                              uids=uids)
        elif ckpt_source == 'nemo':
            self.load_from_nemo(model_files=model_dirs_or_files,
                                model_config=model_config,
                                runtime_mapping=runtime_mapping,
                                uids=uids)
        else:
            assert False, f"{self.__class__.__name__} does not support source {ckpt_source}"
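
    # Illustrative sketch: for NeMo checkpoints the same entry point is used with
    # ckpt_source="nemo" and .nemo archive paths instead of HF adapter directories
    # (hypothetical path; not executed anywhere):
    #   manager.load_from_ckpt(["/path/to/adapter.nemo"],
    #                          model_config=lora_model_config,
    #                          runtime_mapping=Mapping(),
    #                          ckpt_source="nemo")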

    def load_from_nemo(self,
                       model_files: List[str],
                       model_config: Union['ModelConfig', LoraModelConfig],
                       runtime_mapping: Optional[Mapping] = None,
                       uids: Optional[List[str]] = None):
        if runtime_mapping is None:
            runtime_mapping = Mapping()
        tp_size = runtime_mapping.tp_size
        tp_rank = runtime_mapping.tp_rank

        if uids is None:
            uids = [self._generate_uid() for _ in range(len(model_files))]
        assert len(uids) == len(model_files)

        new_uids, new_model_files = [], []
        for uid, model_file in zip(uids, model_files):
            if uid in self._lora_uid_to_low_ranks:
                continue
            new_uids.append(uid)
            new_model_files.append(model_file)

        if len(new_uids) == 0:
            return

        self.lora_target_modules = model_config.lora_target_modules

        def load_from_model_file(uid, model_file):
            if uid not in self._cpp_lora_weights:
                # Will be converted to tensor later
                self._cpp_lora_weights[uid] = []
            if uid not in self._cpp_lora_config:
                # Will be converted to tensor later
                self._cpp_lora_config[uid] = []

            _, nemo_weights = unpack_nemo_weights(model_file)
            all_lora_weights = get_all_nemo_lora_weights(nemo_weights)

            self._lora_uid_to_low_ranks[uid] = {}
            self._lora_weights_pointers_list[uid] = {}
            for layer_idx in sorted(all_lora_weights.keys()):
                self._lora_uid_to_low_ranks[uid][layer_idx] = {}
                self._lora_weights_pointers_list[uid][layer_idx] = {}

                for lora_module in self.lora_target_modules:
                    if lora_module != "attn_qkv":
                        self._lora_uid_to_low_ranks[uid][layer_idx][
                            lora_module] = 0
                        continue

                    if lora_module == "attn_qkv":
                        t_in = all_lora_weights[layer_idx]["in"]
                        t_out = all_lora_weights[layer_idx]["out"]
                        assert t_out.shape[0] % tp_size == 0
                        t_out = torch.split(t_out,
                                            t_out.shape[0] // tp_size,
                                            dim=0)[tp_rank].contiguous()
                    else:
                        t_in = None
                        t_out = None

                    if t_in is not None and t_out is not None:
                        t_in = t_in.cuda().to(
                            str_dtype_to_torch(
                                model_config.dtype)).contiguous()
                        t_out = t_out.cuda().to(
                            str_dtype_to_torch(
                                model_config.dtype)).contiguous()
                        rank = t_in.shape[0]
                        self._lora_uid_to_low_ranks[uid][layer_idx][
                            lora_module] = int(rank)
                        self._lora_weights_pointers_list[uid][layer_idx][
                            lora_module] = [
                                t_in.data_ptr(),
                                t_out.data_ptr(), 0
                            ]

                        # prevent torch from freeing these buffers
                        self._lora_weights.append(t_in)
                        self._lora_weights.append(t_out)

                        self._cpp_lora_weights[uid].append(
                            torch.concatenate(
                                [t_in.flatten().cpu(),
                                 t_out.flatten().cpu()]))
                        self._cpp_lora_config[uid].append(
                            torch.tensor([
                                self.LORA_MODULE_IDS[lora_module], layer_idx,
                                int(rank)
                            ],
                                         dtype=torch.int32))

            max_weight_size = max(
                w.size(0) for w in self._cpp_lora_weights[uid])
            self._cpp_lora_weights[uid] = torch.stack([
                torch.nn.functional.pad(w, (0, max_weight_size - w.size(0)))
                for w in self._cpp_lora_weights[uid]
            ])
            self._cpp_lora_config[uid] = torch.stack(
                [c for c in self._cpp_lora_config[uid]])

        for uid, model_file in zip(new_uids, new_model_files):
            load_from_model_file(uid, model_file)

        release_gc()
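
    # Layout note for the C++ runtime tensors built above (and by load_from_hf below):
    # for each (layer, module) entry, _cpp_lora_weights[uid] holds one row of
    # [flattened t_in | flattened t_out (| flattened t_mag for DoRA)], zero-padded to the
    # longest row, and _cpp_lora_config[uid] holds the matching int32 row
    # [module_id, layer_idx, rank] (plus an is_dora flag in the HF path).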

    def load_from_hf(self,
                     model_dirs: List[str],
                     model_config: Union['ModelConfig', LoraModelConfig],
                     runtime_mapping: Optional[Mapping] = None,
                     uids: Optional[List[str]] = None,
                     component: Optional[str] = None):
        '''
        lora config of https://huggingface.co/hfl/chinese-alpaca-2-lora-7b
        {
            "base_model_name_or_path": "/Llama-2-7b-hf",
            "bias": "none",
            "enable_lora": null,
            "fan_in_fan_out": false,
            "inference_mode": true,
            "lora_alpha": 128.0,
            "lora_dropout": 0.05,
            "merge_weights": false,
            "modules_to_save": [
                "embed_tokens",
                "lm_head"
            ],
            "peft_type": "LORA",
            "r": 64,
            "target_modules": [
                "q_proj",
                "v_proj",
                "k_proj",
                "o_proj",
                "gate_proj",
                "down_proj",
                "up_proj"
            ],
            "task_type": "CAUSAL_LM"
        }

        keys in adapter_model.bin:
            base_model.model.model.layers.0.self_attn.q_proj.lora_A.weight torch.Size([64, 4096])
            base_model.model.model.layers.0.self_attn.q_proj.lora_B.weight torch.Size([4096, 64])
            base_model.model.model.layers.0.self_attn.k_proj.lora_A.weight torch.Size([64, 4096])
            base_model.model.model.layers.0.self_attn.k_proj.lora_B.weight torch.Size([4096, 64])
            base_model.model.model.layers.0.self_attn.v_proj.lora_A.weight torch.Size([64, 4096])
            base_model.model.model.layers.0.self_attn.v_proj.lora_B.weight torch.Size([4096, 64])
            base_model.model.model.layers.0.self_attn.o_proj.lora_A.weight torch.Size([64, 4096])
            base_model.model.model.layers.0.self_attn.o_proj.lora_B.weight torch.Size([4096, 64])
            base_model.model.model.layers.0.mlp.gate_proj.lora_A.weight torch.Size([64, 4096])
            base_model.model.model.layers.0.mlp.gate_proj.lora_B.weight torch.Size([11008, 64])
            base_model.model.model.layers.0.mlp.up_proj.lora_A.weight torch.Size([64, 4096])
            base_model.model.model.layers.0.mlp.up_proj.lora_B.weight torch.Size([11008, 64])
            base_model.model.model.layers.0.mlp.down_proj.lora_A.weight torch.Size([64, 11008])
            base_model.model.model.layers.0.mlp.down_proj.lora_B.weight torch.Size([4096, 64])
            ...
        '''
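        # Worked example of the LoRA scaling applied further below (values taken from
        # the adapter_config.json shown in the docstring above): with lora_alpha=128.0
        # and r=64, the "out" (lora_B) factor is scaled by alpha / r = 2.0; if the
        # adapter sets "use_rslora": true, the scale is alpha / sqrt(r) = 128 / 8 = 16.0.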
        if runtime_mapping is None:
            runtime_mapping = Mapping()
        tp_size = runtime_mapping.tp_size
        tp_rank = runtime_mapping.tp_rank

        if uids is None:
            uids = [self._generate_uid() for _ in range(len(model_dirs))]
        assert len(uids) == len(model_dirs)

        new_uids, new_model_dirs = [], []
        for uid, model_dir in zip(uids, model_dirs):
            if uid in self._lora_uid_to_low_ranks:
                continue
            new_uids.append(uid)
            new_model_dirs.append(model_dir)

        if len(new_uids) == 0:
            return

        lora_hf_configs = []
        for model_dir in new_model_dirs:
            with open(f"{model_dir}/adapter_config.json", 'r') as f:
                config = json.load(f)
                lora_hf_configs.append(config)

        self.lora_target_modules = model_config.lora_target_modules
        hf_modules_to_trtllm_modules = invert_module_mapping(
            model_config.trtllm_modules_to_hf_modules)
        hf_modules = set(hf_modules_to_trtllm_modules.keys())

        def preprocess_lora_weights(lora_model):
            # Swap the two halves of the gate_up_proj lora_B weights
            for key, value in lora_model.items():
                if "gate_up_proj.lora_B.weight" in key:
                    original_weights = value.contiguous().clone()
                    half_split = original_weights.shape[0] // 2
                    first_half = original_weights[:half_split, :]
                    second_half = original_weights[half_split:, :]
                    value = torch.cat((second_half, first_half), dim=0)
                    lora_model[key] = value
            return lora_model

        def load_from_model_dir(uid, model_dir, hf_config):
            if uid not in self._cpp_lora_weights:
                # Will be converted to tensor later
                self._cpp_lora_weights[uid] = []
            if uid not in self._cpp_lora_config:
                # Will be converted to tensor later
                self._cpp_lora_config[uid] = []

            lora_model = load_state_dict(
                get_model_path(model_dir, "adapter_model"))
            if lora_model is None:
                raise ValueError(
                    f"Failed to load adapter_model from {model_dir}")
            lora_model = preprocess_lora_weights(lora_model)
            all_weights = get_all_hf_lora_weights(lora_model, hf_modules,
                                                  component)
            rank = int(hf_config["r"])
            rs_lora = bool(hf_config.get("use_rslora", False))

            self._lora_uid_to_low_ranks[uid] = {}
            self._lora_weights_pointers_list[uid] = {}
            for layer_idx in sorted(all_weights.keys()):
                layer_weights = all_weights[layer_idx]
                self._lora_uid_to_low_ranks[uid][layer_idx] = {}
                self._lora_weights_pointers_list[uid][layer_idx] = {}

                for lora_module in self.missing_qkv_modules:
                    hf_module = model_config.trtllm_modules_to_hf_modules[
                        lora_module]
                    if isinstance(hf_module, list):
                        hf_module = hf_module[0]
                    layer_weights[hf_module] = {
                        "in": torch.zeros(rank, model_config.hidden_size),
                        "out": torch.zeros(model_config.hidden_size, rank),
                    }

                for hf_module, module_weights in layer_weights.items():
                    lora_module = hf_modules_to_trtllm_modules[hf_module]
                    if lora_module not in self.lora_target_modules:
                        self._lora_uid_to_low_ranks[uid][layer_idx][
                            lora_module] = 0
                        continue

                    if "in" not in module_weights:
                        is_moe = True
                        t_in = torch.stack([
                            module_weights[expert_idx]["in"]
                            for expert_idx in sorted(module_weights.keys())
                        ])
                        t_out = torch.stack([
                            module_weights[expert_idx]["out"]
                            for expert_idx in sorted(module_weights.keys())
                        ])
                        for weights in module_weights.values():
                            if "mag" in weights:
                                # TODO(oargov): this might work, but I had no MoE DoRA models to test
                                raise ValueError(
                                    "DoRA with MoE is not supported")
                        t_mag = None
                    else:
                        is_moe = False
                        t_in = module_weights["in"]
                        t_out = module_weights["out"]
                        t_mag = module_weights.get("magnitude", None)

                    is_dora = t_mag is not None

                    if lora_module in ["moe_router", "mlp_router"]:
                        pass
                    elif "moe" in lora_module and runtime_mapping.has_moe_ep():
                        pass
                    elif lora_module in [
                            "attn_dense",
                            "cross_attn_dense",
                            "mlp_4h_to_h",
                            "moe_4h_to_h",
                    ]:
                        # split by row
                        dim = 2 if is_moe else 1
                        assert t_in.shape[dim] % tp_size == 0
                        t_in = torch.split(t_in,
                                           t_in.shape[dim] // tp_size,
                                           dim=dim)[tp_rank].contiguous()
                    else:
                        # split by column
                        dim = 1 if is_moe else 0
                        assert t_out.shape[dim] % tp_size == 0
                        t_out = torch.split(t_out,
                                            t_out.shape[dim] // tp_size,
                                            dim=dim)[tp_rank].contiguous()
                        if dim == 0 and is_dora and t_mag is not None:
                            t_mag = torch.split(
                                t_mag, t_mag.shape[0] // tp_size,
                                dim=0)[tp_rank].contiguous()

                    rank_dim = 1 if is_moe else 0
                    effective_rank = t_in.shape[rank_dim]

                    t_in = t_in.cuda().contiguous()
                    t_out = t_out.cuda().contiguous()
                    if is_dora and t_mag is not None:
                        t_mag = t_mag.cuda().contiguous()

                    if rs_lora:
                        scale = float(
                            hf_config["lora_alpha"]) / np.sqrt(effective_rank)
                    else:
                        scale = float(hf_config["lora_alpha"]) / effective_rank
                    t_out = t_out * scale

                    t_in = t_in.to(str_dtype_to_torch(model_config.dtype))
                    t_out = t_out.to(str_dtype_to_torch(model_config.dtype))
                    if is_dora and t_mag is not None:
                        t_mag = t_mag.to(
                            str_dtype_to_torch(model_config.dtype))

                    self._lora_uid_to_low_ranks[uid][layer_idx][
                        lora_module] = effective_rank
                    self._lora_weights_pointers_list[uid][layer_idx][
                        lora_module] = [
                            t_in.data_ptr(),
                            t_out.data_ptr(),
                            t_mag.data_ptr() if
                            (is_dora and t_mag is not None) else 0
                        ]

                    # prevent torch from freeing these buffers
                    self._lora_weights.append(t_in)
                    self._lora_weights.append(t_out)
                    if is_dora and t_mag is not None:
                        self._lora_weights.append(t_mag)

                    t_in_cpu = t_in.flatten().cpu()
                    t_out_cpu = t_out.flatten().cpu()

                    weights_to_concat = [t_in_cpu, t_out_cpu]
                    if is_dora and t_mag is not None:
                        t_mag_cpu = t_mag.flatten().cpu()
                        weights_to_concat.append(t_mag_cpu)

                    self._cpp_lora_weights[uid].append(
                        torch.cat(weights_to_concat))
                    self._cpp_lora_config[uid].append(
                        torch.tensor([
                            self.LORA_MODULE_IDS[lora_module], layer_idx,
                            effective_rank, is_dora
                        ],
                                     dtype=torch.int32))

            max_weight_size = max(
                w.size(0) for w in self._cpp_lora_weights[uid])
            self._cpp_lora_weights[uid] = torch.stack([
                torch.nn.functional.pad(w, (0, max_weight_size - w.size(0)))
                for w in self._cpp_lora_weights[uid]
            ])
            self._cpp_lora_config[uid] = torch.stack(
                [c for c in self._cpp_lora_config[uid]])

        for uid, model_dir, hf_config in zip(new_uids, new_model_dirs,
                                             lora_hf_configs):
            load_from_model_dir(uid, model_dir, hf_config)

        release_gc()
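
    # Summary of the tensor-parallel splits applied by load_from_hf() above:
    #   * "attn_dense", "cross_attn_dense", "mlp_4h_to_h", "moe_4h_to_h": the LoRA "in"
    #     factor is split along its input dimension (row split).
    #   * router modules ("moe_router", "mlp_router") and MoE modules under expert
    #     parallelism are kept whole.
    #   * all other modules: the LoRA "out" factor (and the DoRA magnitude, if present)
    #     is split along its output dimension (column split).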

    @property
    def lora_weights(self):
        return self._lora_weights

    @property
    def lora_weights_pointers_list(self):
        return self._lora_weights_pointers_list

    @property
    def cpp_lora_weights(self):
        return self._cpp_lora_weights

    @property
    def cpp_lora_config(self):
        return self._cpp_lora_config

    def uid_to_low_ranks(self, uid: str):
        assert isinstance(uid, str)
        return self._lora_uid_to_low_ranks[uid]

    def _generate_uid(self):
        while str(self._lora_uid_counter) in self._lora_uid_to_low_ranks:
            self._lora_uid_counter += 1
        uid = str(self._lora_uid_counter)
        self._lora_uid_counter += 1
        return uid

    @property
    def num_lora_adapters(self):
        return len([uid for uid in self._lora_uid_to_low_ranks if uid != '-1'])

    def save_lora_weights_to_bin(self, out_dir):

        def save_val(val, dir, key, tp_num=None, write_npy=False):
            ext = "npy" if write_npy else "bin"
            suffix = ext if tp_num is None else f"{tp_num}.{ext}"
            if write_npy:
                np.save(dir / f"model.{key}.{suffix}", val)
            else:
                val.tofile(dir / f"model.{key}.{suffix}")

        if isinstance(out_dir, str):
            out_dir_path = Path(out_dir)
        elif isinstance(out_dir, Path):
            out_dir_path = out_dir
        else:
            assert False
        for uid in self.cpp_lora_weights:
            if uid == '-1':
                continue

            all_weights = np.expand_dims(
                torch_to_numpy(self.cpp_lora_weights[uid]), 0)
            all_configs = np.expand_dims(
                torch_to_numpy(self.cpp_lora_config[uid]), 0)

            uid_path = out_dir_path / f"{uid}"
            uid_path.mkdir(parents=True, exist_ok=True)
            save_val(all_weights,
                     uid_path,
                     "lora_weights",
                     tp_num=None,
                     write_npy=True)
            save_val(all_configs,
                     uid_path,
                     "lora_config",
                     tp_num=None,
                     write_npy=True)

    def input_buffers(self, lora_uids, mapping: Mapping, num_layers: int):
        inputs = {}

        for layer_idx in mapping.pp_layers(num_layers):
            for lora_module in (self.lora_target_modules +
                                self.missing_qkv_modules):
                lora_ranks_ = []
                lora_ptrs_ = []
                for lora_uid in lora_uids:
                    lora_rank = 0
                    lora_ptrs = [0, 0, 0]

                    if lora_uid != "-1":
                        low_ranks = self.uid_to_low_ranks(lora_uid)

                        if (layer_idx in low_ranks
                                and lora_module in low_ranks[layer_idx].keys()
                                and low_ranks[layer_idx][lora_module] != 0):
                            lora_rank = low_ranks[layer_idx][lora_module]
                            lora_ptrs = self.lora_weights_pointers_list[
                                lora_uid][layer_idx][lora_module]

                    lora_ranks_.append(lora_rank)
                    lora_ptrs_.append(lora_ptrs)

                inputs[
                    f'{lora_module}_lora_ranks_{layer_idx}'] = torch.IntTensor(
                        lora_ranks_)
                inputs[
                    f'{lora_module}_lora_weights_pointers_{layer_idx}'] = torch.LongTensor(
                        lora_ptrs_)

        return inputs
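

# Illustrative end-to-end usage sketch (hypothetical paths, shapes and layer count;
# nothing here runs on import). It shows the intended call sequence: build a
# LoraModelConfig, load an HF adapter directory into a LoraManager, then build the
# per-layer runtime input buffers.
def _example_lora_manager_usage():
    lora_model_config = LoraModelConfig(
        lora_target_modules=["attn_q", "attn_k", "attn_v"],
        trtllm_modules_to_hf_modules=get_default_trtllm_modules_to_hf_modules(),
        hidden_size=4096,
        dtype="float16",
    )
    manager = LoraManager()
    manager.load_from_ckpt(
        ["/path/to/hf_lora_dir"],  # hypothetical dir with adapter_config.json
        model_config=lora_model_config,
        runtime_mapping=Mapping(),  # single-rank mapping
        uids=["0"],
        ckpt_source="hf")
    # One rank/pointer entry per requested uid, per module, per local layer;
    # uid "-1" means "no adapter" for that request.
    return manager.input_buffers(lora_uids=["0", "-1"],
                                 mapping=Mapping(),
                                 num_layers=32)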