import io
import itertools
import json
import logging
import re
import tarfile
import warnings
from collections import defaultdict
from dataclasses import dataclass
from functools import lru_cache
from pathlib import Path
from typing import TYPE_CHECKING, Dict, List, Optional, Set, Tuple, Union

import numpy as np
import torch
import yaml

from tensorrt_llm.bindings import internal as tb_internal

from ._utils import pad_vocab_size, release_gc, str_dtype_to_torch, torch_to_numpy
from .layers.linear import ColumnLinear
from .lora_helper import (
    LoraConfig,
    get_default_trtllm_modules_to_hf_modules,
    get_missing_qkv_modules_from_lora_modules,
)
from .mapping import Mapping
from .models.convert_utils import get_model_path, load_state_dict, split_matrix_tp

if TYPE_CHECKING:
    from .runtime import ModelConfig

NEMO_SUPPORTED_LORA_MODULES = {"attn_qkv"}

logger = logging.getLogger(__name__)


def _check_lora_in_out(
    layer_idx: int, lora_module: str, available_matrices: Dict, source_identifier: str
) -> None:
    """Check that 'in' and 'out' matrices are present."""
    missing = []
    if "in" not in available_matrices:
        missing.append("'in' matrix (lora_A equivalent)")
    if "out" not in available_matrices:
        missing.append("'out' matrix (lora_B equivalent)")

    if missing:
        raise ValueError(
            f"Layer {layer_idx} is missing required {' and '.join(missing)} for {lora_module} "
            f"in LoRA weights from {source_identifier}. "
            f"LoRA adapters must contain both 'in' and 'out' matrices for all layers. "
            f"Please check if the LoRA checkpoint is complete or was corrupted during loading."
        )


def _is_moe_module_weights(module_weights: Dict) -> bool:
    """Check if module weights represent MoE (integer expert indices with nested dicts)."""
    if not module_weights:
        return False

    # All keys should be integers (expert indices) and values should be dicts
    return all(isinstance(k, int) for k in module_weights.keys()) and all(
        isinstance(v, dict) for v in module_weights.values()
    )


def get_all_nemo_lora_weights(
    lora_weights: Dict[str, torch.Tensor],
) -> Dict[int, Dict[str, torch.Tensor]]:
    """Extract and organize NeMo LoRA weights by layer and direction.

    Args:
        lora_weights: Dictionary mapping weight keys to tensors from NeMo checkpoint

    Returns:
        Dictionary mapping layer_idx -> {direction -> tensor} where direction is 'in' or 'out'

    Raises:
        KeyError: If unsupported keys are found or layer extraction fails
    """
    layer_weights: Dict[int, Dict[str, torch.Tensor]] = defaultdict(dict)
    adapter_key = "self_attention.adapter_layer.lora_kqv_adapter"
    layer_pattern = re.compile(r".*\.layers\.(\d+)\..*")
    for key, weights in lora_weights.items():
        if adapter_key in key:
            if key.endswith("linear_in.weight"):
                inout = "in"
            elif key.endswith("linear_out.weight"):
                inout = "out"
            else:
                continue
            m = layer_pattern.match(key)
            if m is None:
                raise KeyError(
                    f"Failed to extract layer index from key {key} using pattern {layer_pattern.pattern}"
                )
            layer_idx = int(m.group(1))
            layer_weights[layer_idx][inout] = weights
        else:
            raise KeyError(f"unsupported key {key} from NeMo LoRA weights")
    return layer_weights


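# Example (illustrative sketch, not taken from a real checkpoint): the NeMo key layout that
# get_all_nemo_lora_weights expects. The "model.language_model.encoder" prefix is an assumed
# example; the function only relies on the ".layers.<idx>." segment, the
# "self_attention.adapter_layer.lora_kqv_adapter" marker, and the linear_in/linear_out suffixes.
def _example_get_all_nemo_lora_weights() -> None:
    prefix = "model.language_model.encoder.layers.0.self_attention.adapter_layer.lora_kqv_adapter"
    fake_ckpt = {
        f"{prefix}.linear_in.weight": torch.zeros(8, 4096),  # lora_A-like projection, rank 8
        f"{prefix}.linear_out.weight": torch.zeros(12288, 8),  # lora_B-like projection, fused QKV
    }
    weights = get_all_nemo_lora_weights(fake_ckpt)
    assert set(weights[0].keys()) == {"in", "out"}

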
# The pattern is {layer_prefix:1}.{layer_idx:2}.{module_prefix:3}.{module_name or {expert_name:5}.{expert_idx:6}.{module_name:7} :4}.lora_{A|B:8}.weight  # noqa: E501
HF_LORA_PATTERN = re.compile(
    r"(.*)\.(\d+)\.(\w+)\.(\w+|\w+\.\w+|(\w+)\.(\d+)\.(\w+))\.(?:lora_(?:(A|B)\.weight|(magnitude)_vector)|weight_(m_wdecomp).weight)"
)


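# Example (illustrative sketch): what HF_LORA_PATTERN extracts from a typical HF PEFT key; the key
# below is the q_proj example shown in the load_from_hf docstring. Group 2 is the layer index,
# groups 3/4 form the module path, and group 8 distinguishes lora_A ("in") from lora_B ("out").
def _example_hf_lora_pattern() -> None:
    key = "base_model.model.model.layers.0.self_attn.q_proj.lora_A.weight"
    m = HF_LORA_PATTERN.match(key)
    assert m is not None
    assert int(m.group(2)) == 0  # layer index
    assert m.group(3) == "self_attn"  # module prefix
    assert m.group(4) == "q_proj"  # module name
    assert m.group(8) == "A"  # lora_A -> "in", lora_B -> "out"

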
def iterate_hf_lora(
    iter_fn,
    lora_weights: Dict[str, torch.Tensor],
    hf_modules: Set[str],
    component: Optional[str] = None,
):
    """Iterate over HuggingFace LoRA weights and call iterator function for each weight.

    Args:
        iter_fn: Function to call for each weight with signature
            (layer_idx, hf_module, expert_idx, inout_or_mag, weights)
        lora_weights: Dictionary mapping weight keys to tensors from HF checkpoint
        hf_modules: Set of supported HF module names
        component: Optional component name to filter by (e.g., 'decoder')

    Returns:
        Nested dictionary structure organizing the weights

    Raises:
        KeyError: If unsupported keys are found
        AssertionError: If HF module is not in supported list
    """
    all_weights = defaultdict(lambda: defaultdict(dict))
    pattern = HF_LORA_PATTERN
    for key, weights in lora_weights.items():
        m = pattern.match(key)
        if not m:
            if "lm_head" not in key and "embed_tokens" not in key:
                raise KeyError(f"unsupported key {key} from HF LoRA weights")
            continue
        if component is not None and component not in m.group(1):
            continue
        layer_idx = int(m.group(2))
        expert_idx = m.group(6)
        if expert_idx is not None:
            expert_idx = int(expert_idx)
        is_moe = expert_idx is not None
        if is_moe:
            expert_name = m.group(5)
            module_name = m.group(7)
            hf_module = m.group(3) + "." + expert_name + "." + module_name
        else:
            module_name = m.group(4)
            hf_module = m.group(3) + "." + module_name
        if hf_module not in hf_modules:
            hf_module = module_name
        assert hf_module in hf_modules, (
            f"hf_module {hf_module} is not in supported list {hf_modules}"
        )

        is_lora_a_or_b = m.group(8) is not None
        if is_lora_a_or_b:
            inout_or_mag = "in" if m.group(8) == "A" else "out"
        else:
            inout_or_mag = "magnitude"

        iter_fn(layer_idx, hf_module, expert_idx, inout_or_mag, weights)
        if not is_moe:
            all_weights[layer_idx][hf_module][inout_or_mag] = weights
        else:
            all_weights[layer_idx][hf_module].setdefault(expert_idx, {})
            all_weights[layer_idx][hf_module][expert_idx][inout_or_mag] = weights
    return all_weights


def get_all_hf_lora_weights(
    lora_weights: Dict[str, torch.Tensor], hf_modules: Set[str], component: Optional[str] = None
):
    """Extract and organize all HuggingFace LoRA weights by layer and module.

    Args:
        lora_weights: Dictionary mapping weight keys to tensors from HF checkpoint
        hf_modules: Set of supported HF module names
        component: Optional component name to filter by (e.g., 'decoder')

    Returns:
        Nested dictionary organizing weights by layer, module, and potentially expert
    """

    def iter_fn(layer_idx, hf_module, expert_idx, inout, weights):
        if expert_idx is None:
            all_weights[layer_idx][hf_module][inout] = weights
        else:
            all_weights[layer_idx][hf_module].setdefault(expert_idx, {})
            all_weights[layer_idx][hf_module][expert_idx][inout] = weights

    all_weights = defaultdict(lambda: defaultdict(dict))
    iterate_hf_lora(iter_fn, lora_weights, hf_modules, component)
    return all_weights


def get_hf_target_modules(lora_weights, hf_modules):
    def iter_fn(layer_idx, hf_module, expert_idx, inout, weights):
        hf_target_modules.add(hf_module)

    hf_target_modules = set()
    iterate_hf_lora(iter_fn, lora_weights, hf_modules)
    return hf_target_modules


def invert_module_mapping(
    trtllm_modules_to_hf_modules: Dict[str, Union[str, List[str]]],
) -> Dict[str, str]:
    """Invert a module mapping from TensorRT-LLM -> HF into HF -> TensorRT-LLM.

    Args:
        trtllm_modules_to_hf_modules: Mapping from TensorRT-LLM module names to HF module names
            (values can be strings or lists of strings)

    Returns:
        Dictionary mapping HF module names to TensorRT-LLM module names
    """
    hf_modules_to_trtllm_modules: Dict[str, str] = {}
    for k, hf_modules in trtllm_modules_to_hf_modules.items():
        if isinstance(hf_modules, list):
            for hf_module in hf_modules:
                hf_modules_to_trtllm_modules[hf_module] = k
        else:
            hf_modules_to_trtllm_modules[hf_modules] = k
    return hf_modules_to_trtllm_modules


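# Example (illustrative sketch): inverting a small TensorRT-LLM -> HF module mapping. The entries
# below are hypothetical inputs, not the actual default mapping returned by
# get_default_trtllm_modules_to_hf_modules.
def _example_invert_module_mapping() -> None:
    mapping = {"attn_q": "q_proj", "mlp_h_to_4h": ["gate_proj", "up_proj"]}
    inverted = invert_module_mapping(mapping)
    assert inverted == {"q_proj": "attn_q", "gate_proj": "mlp_h_to_4h", "up_proj": "mlp_h_to_4h"}

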
def norm_dora_magnitude(
    W0: torch.Tensor, A: torch.Tensor, B: torch.Tensor, m: torch.Tensor, scaling: float = 1.0
):
    new_weight_v = W0 + (B @ A) * scaling
    norm_m = m.view(-1) / (torch.linalg.norm(new_weight_v, dim=1)).detach()
    return norm_m


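# Example (illustrative sketch): the tensor shapes expected by norm_dora_magnitude, which rescales
# the DoRA magnitude vector m by the per-row L2 norm of the adapted weight W0 + scaling * (B @ A).
# The sizes below are arbitrary toy values.
def _example_norm_dora_magnitude() -> None:
    out_features, in_features, rank = 16, 32, 4
    W0 = torch.randn(out_features, in_features)
    A = torch.randn(rank, in_features)  # lora_A ("in")
    B = torch.randn(out_features, rank)  # lora_B ("out")
    m = torch.randn(out_features)  # DoRA magnitude vector
    norm_m = norm_dora_magnitude(W0, A, B, m, scaling=2.0)
    assert norm_m.shape == (out_features,)

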
@dataclass
class LoraModelConfig:
    lora_target_modules: list[str]
    trtllm_modules_to_hf_modules: dict[str, str]
    hidden_size: int
    dtype: str
    swap_gate_up_proj_lora_b_weight: bool = True


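# Example (illustrative sketch): a hypothetical LoraModelConfig; the module names, mapping values,
# hidden size, and dtype below are example inputs only.
def _example_lora_model_config() -> LoraModelConfig:
    return LoraModelConfig(
        lora_target_modules=["attn_q", "attn_k", "attn_v"],
        trtllm_modules_to_hf_modules={"attn_q": "q_proj", "attn_k": "k_proj", "attn_v": "v_proj"},
        hidden_size=4096,
        dtype="float16",
    )

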
class HfLoraLoader:
    def __init__(self, lora_dirs: List[str]):
        self.lora_target_modules = []
        self.is_valid = False
        self.lm_head = None
        self.embed_tokens = None
        self.vocab_size = 0

        if len(lora_dirs) == 0:
            return

        for lora_dir in lora_dirs:
            model_path = get_model_path(lora_dir, "adapter_model")
            if model_path is None:
                raise ValueError(f"adapter_model file does not exist in {lora_dir}")
            config_file = Path(f"{lora_dir}/adapter_config.json")
            if not config_file.exists():
                raise ValueError(f"{config_file} does not exist")
            if not config_file.is_file():
                raise ValueError(f"{config_file} is not a file")
        self.is_valid = True

        lora_dir = lora_dirs[0]
        with open(f"{lora_dir}/adapter_config.json") as f:
            adapter_config = json.load(f)

        model_path = get_model_path(lora_dir, "adapter_model")
        if model_path is None:
            raise ValueError(f"adapter_model file does not exist in {lora_dir}")
        lora_weight = load_state_dict(model_path)
        self.lora_weight = lora_weight
        if adapter_config.get("modules_to_save") is not None:
            if "lm_head" in adapter_config["modules_to_save"]:
                self.lm_head = lora_weight["base_model.model.lm_head.weight"]
                self.vocab_size = self.lm_head.shape[0]

            if "embed_tokens" in adapter_config["modules_to_save"]:
                self.embed_tokens = lora_weight["base_model.model.model.embed_tokens.weight"]

    def get_target_modules(self, trtllm_modules_to_hf_modules):
        hf_modules_to_trtllm_modules = invert_module_mapping(trtllm_modules_to_hf_modules)
        lora_target_modules = set()
        if self.is_valid:
            hf_target_modules = get_hf_target_modules(
                self.lora_weight,
                hf_modules=set(hf_modules_to_trtllm_modules.keys()),
            )
            for m in hf_target_modules:
                trtllm_module = hf_modules_to_trtllm_modules[m]
                lora_target_modules.add(trtllm_module)
        return list(lora_target_modules)


@lru_cache(maxsize=128)
def _find_nemo_files_single_path(lora_path: str) -> List[str]:
    """Find .nemo files from a single path (file or directory).

    This function is cached per individual path to maximize cache efficiency
    when the same paths appear in different collections.

    Args:
        lora_path: A single path that can be either:
            - Direct path to a .nemo file
            - Directory containing .nemo files (will auto-detect *.nemo)

    Returns:
        List[str]: List of paths to .nemo files found in this single path

    Raises:
        ValueError: If path doesn't exist, no .nemo files found, or invalid file type
    """
    path = Path(lora_path)
    if not path.exists():
        raise ValueError(f"{path} does not exist")

    if path.is_file():
        if path.suffix == ".nemo":
            return [str(path)]
        else:
            raise ValueError(f"{path} is not a .nemo file")
    elif path.is_dir():
        nemo_files_in_dir = list(path.glob("*.nemo"))
        if not nemo_files_in_dir:
            raise ValueError(f"No .nemo files found in directory {path}")
        return [str(f) for f in nemo_files_in_dir]
    else:
        raise ValueError(f"{path} is neither a file nor a directory")


def find_nemo_files(lora_dirs: List[str]) -> List[str]:
    """Find all .nemo files from a list of directories or file paths.

    This function is optimized for repeated calls at generation time by using an internal
    LRU cache on individual paths, which maximizes cache efficiency when the same paths
    appear in different collections.

    Args:
        lora_dirs: List of paths that can be either:
            - Direct paths to .nemo files
            - Directories containing .nemo files (will auto-detect *.nemo)

    Returns:
        List[str]: List of paths to .nemo files

    Raises:
        ValueError: If a path doesn't exist, a directory contains no .nemo files,
            or a file path has an invalid file type
    """
    if len(lora_dirs) == 0:
        return []

    all_nemo_files: List[str] = []
    for lora_path in lora_dirs:
        nemo_files_for_path = _find_nemo_files_single_path(lora_path)
        all_nemo_files.extend(nemo_files_for_path)

    if not all_nemo_files:
        raise ValueError("No .nemo files found in the provided paths")

    return all_nemo_files


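# Example (illustrative sketch): find_nemo_files accepts a mix of direct .nemo paths and
# directories. The paths below are hypothetical and must exist on disk for the call to succeed.
def _example_find_nemo_files() -> List[str]:
    return find_nemo_files(
        [
            "/checkpoints/adapter_a.nemo",  # direct path to a .nemo file
            "/checkpoints/nemo_adapters",  # directory scanned for *.nemo
        ]
    )

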
class NemoLoraLoader:
    def __init__(self, lora_dirs: List[str]):
        """Initialize NemoLoraLoader with paths to .nemo files or directories.

        Args:
            lora_dirs: List of paths that can be either:
                - Direct paths to .nemo files
                - Directories containing .nemo files (will auto-detect *.nemo)

        Note: The parameter name 'lora_dirs' is misleading - it can accept both
            directories and files. This is a design flaw that should be fixed
            in a future version (e.g., rename to 'lora_paths').
        """
        self.lora_target_modules = []
        self.is_valid = False

        if len(lora_dirs) == 0:
            return

        for lora_file in lora_dirs:
            path = Path(lora_file)
            if not path.exists():
                raise ValueError(f"{path} does not exist")
        self.is_valid = True
        self.lora_target_modules = list(NEMO_SUPPORTED_LORA_MODULES)

    def get_target_modules(self):
        """Get target modules for NeMo LoRA.

        Unlike the HF loader, this method does not accept trtllm_modules_to_hf_modules
        as an argument since the module mapping is hardcoded for NeMo LoRA support.

        Returns:
            List[str]: List of target module names supported by NeMo LoRA
        """
        return self.lora_target_modules


def load_nemo_lora(model, lora_config: LoraConfig):
    lora_loader = NemoLoraLoader(lora_config.lora_dir)

    if not lora_loader.is_valid:
        raise ValueError(f"Failed to load NeMo LoRA from {lora_config.lora_dir}")

    if len(lora_config.lora_target_modules) == 0:
        lora_config.lora_target_modules = lora_loader.lora_target_modules


def load_torch_hf_lora(lora_config: LoraConfig):
    """This is a shortened version of load_hf_lora that is used for torch models.

    The main problem is that model.config in the legacy code is custom (defined in the legacy
    code), whereas the pivot model config is the transformers one.
    """
    # TODO smor: needs to be combined with load_hf_lora
    if not lora_config.trtllm_modules_to_hf_modules:
        lora_config.trtllm_modules_to_hf_modules = get_default_trtllm_modules_to_hf_modules()

    assert len(lora_config.lora_dir) == 1, "Expecting only a single lora dir"
    lora_loader = HfLoraLoader(lora_config.lora_dir)

    if len(lora_config.lora_target_modules) == 0:
        lora_config.lora_target_modules = lora_loader.get_target_modules(
            lora_config.trtllm_modules_to_hf_modules
        )

    if len(lora_config.lora_target_modules) == 0:
        raise ValueError(
            "lora_target_modules is empty. "
            "Please specify lora_target_modules or provide lora_dir to infer lora_target_modules."
        )

    missing_qkv_modules = LoraManager.get_missing_qkv_modules(lora_config.lora_target_modules)
    lora_config.lora_target_modules.extend(missing_qkv_modules)


def load_torch_nemo_lora(lora_config: LoraConfig):
    """Load NeMo LoRA checkpoint for PyTorch workflow.

    This is a PyTorch-specific loader for NeMo LoRA checkpoints, similar to
    load_torch_hf_lora but handling NeMo checkpoint format. NeMo uses a combined
    "attn_qkv" module rather than separate Q, K, V modules, so no missing QKV
    module handling is needed.

    Note: This function only sets up the configuration. For PyTorch workflow,
    the actual weight loading happens later via LoraManager when requests are
    made with LoRA UIDs.

    Args:
        lora_config: LoRA configuration with lora_ckpt_source="nemo"

    Raises:
        ValueError: If NeMo LoRA directory is invalid or unsupported modules are specified
    """
    lora_config.trtllm_modules_to_hf_modules = {"attn_qkv": "attn_qkv"}

    assert len(lora_config.lora_dir) == 1, "Expecting only a single lora dir"
    lora_loader = NemoLoraLoader(lora_config.lora_dir)

    if not lora_loader.is_valid:
        raise ValueError(f"Failed to load NeMo LoRA from {lora_config.lora_dir}")

    if len(lora_config.lora_target_modules) == 0:
        lora_config.lora_target_modules = lora_loader.get_target_modules()

    if len(lora_config.lora_target_modules) == 0:
        raise ValueError(
            "lora_target_modules is empty. "
            "Please specify lora_target_modules or provide lora_dir to infer lora_target_modules."
        )

    unsupported_modules = set(lora_config.lora_target_modules) - NEMO_SUPPORTED_LORA_MODULES
    if unsupported_modules:
        raise ValueError(
            f"NeMo LoRA only supports {NEMO_SUPPORTED_LORA_MODULES} modules, "
            f"but got unsupported modules: {unsupported_modules}. "
            f"NeMo LoRA does not support embedding, lm_head, or MLP adapters."
        )


def load_torch_lora(lora_config: LoraConfig):
    """Load LoRA checkpoint for PyTorch workflow.

    This function routes to the appropriate loader based on lora_ckpt_source.

    Args:
        lora_config: LoRA configuration with lora_ckpt_source set to "hf" or "nemo"

    Raises:
        ValueError: If lora_ckpt_source is not supported
    """
    if lora_config.lora_ckpt_source == "nemo":
        load_torch_nemo_lora(lora_config)
    elif lora_config.lora_ckpt_source == "hf":
        load_torch_hf_lora(lora_config)
    else:
        raise ValueError(
            f"Unsupported lora_ckpt_source: {lora_config.lora_ckpt_source}. "
            f"Supported sources: 'hf', 'nemo'"
        )


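# Example (illustrative sketch): routing an adapter through load_torch_lora. The adapter directory
# is hypothetical, and the call assumes LoraConfig accepts lora_dir and lora_ckpt_source as keyword
# arguments with defaults for its remaining fields; lora_target_modules is left empty so it is
# inferred from the adapter weights.
def _example_load_torch_lora() -> LoraConfig:
    lora_config = LoraConfig(
        lora_dir=["/checkpoints/my-hf-lora-adapter"],
        lora_ckpt_source="hf",
    )
    load_torch_lora(lora_config)  # fills in lora_target_modules and trtllm_modules_to_hf_modules
    return lora_config

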
def load_hf_lora(
    model,
    lora_config: LoraConfig,
    trtllm_modules_to_hf_modules: Optional[Dict[str, str]] = None,
):
    trtllm_modules_to_hf_modules = (
        trtllm_modules_to_hf_modules or get_default_trtllm_modules_to_hf_modules()
    )
    lora_config.trtllm_modules_to_hf_modules = trtllm_modules_to_hf_modules

    lora_loader = HfLoraLoader(lora_config.lora_dir)

    if len(lora_config.lora_target_modules) == 0:
        lora_config.lora_target_modules = lora_loader.get_target_modules(
            trtllm_modules_to_hf_modules
        )
    if len(lora_config.lora_target_modules) == 0:
        raise ValueError(
            "lora_target_modules is empty. "
            "Please specify lora_target_modules or provide lora_dir to infer lora_target_modules."
        )

    missing_qkv_modules = LoraManager.get_missing_qkv_modules(lora_config.lora_target_modules)
    lora_config.lora_target_modules.extend(missing_qkv_modules)

    if lora_loader.is_valid:
        config = model.config
        torch_dtype = str_dtype_to_torch(config.dtype)
        # the lora checkpoint might finetune the embedding
        if lora_loader.vocab_size != 0:
            config.vocab_size = lora_loader.vocab_size
        mapping = config.mapping
        if mapping.is_first_pp_rank() and lora_loader.embed_tokens is not None:
            weight = lora_loader.embed_tokens
            if config.use_parallel_embedding:
                weight = split_matrix_tp(
                    weight,
                    mapping.tp_size,
                    mapping.tp_rank,
                    dim=config.embedding_sharding_dim,
                )
            if model.transformer.vocab_embedding.weight.raw_value.shape != weight.shape:
                model.transformer.vocab_embedding = model.transformer.vocab_embedding.__class__(
                    num_embeddings=config.vocab_size,
                    embedding_dim=config.hidden_size,
                    dtype=config.dtype,
                    tp_size=mapping.tp_size if config.use_parallel_embedding else 1,
                    tp_group=mapping.tp_group if config.use_parallel_embedding else None,
                    sharding_dim=config.embedding_sharding_dim,
                    tp_rank=mapping.tp_rank,
                )
            model.transformer.vocab_embedding.weight.value = weight.to(torch_dtype)
        if mapping.is_last_pp_rank() and lora_loader.lm_head is not None:
            weight = lora_loader.lm_head
            vocab_size = lora_loader.vocab_size
            if vocab_size % mapping.tp_size != 0:
                # padding
                vocab_size_padded = pad_vocab_size(vocab_size, mapping.tp_size)
                pad_width = vocab_size_padded - vocab_size

                weight = torch.from_numpy(
                    np.pad(
                        torch_to_numpy(weight),
                        ((0, pad_width), (0, 0)),
                        "constant",
                        constant_values=0,
                    )
                )
            else:
                vocab_size_padded = vocab_size
            if model.lm_head.weight.raw_value.shape != weight.shape:
                model.lm_head = ColumnLinear(
                    config.hidden_size,
                    vocab_size_padded,
                    bias=False,
                    dtype=config.dtype,
                    tp_group=mapping.tp_group,
                    tp_size=mapping.tp_size,
                    gather_output=True,
                )
            model.lm_head.weight.value = split_matrix_tp(
                weight,
                mapping.tp_size,
                mapping.tp_rank,
                dim=0,
            ).to(torch_dtype)


def unpack_nemo_weights(nemo_archive_path: str) -> Tuple[Dict, Dict[str, torch.Tensor]]:
    """Unpack model config and weights from a NeMo .nemo archive file.

    Args:
        nemo_archive_path: Path to the .nemo archive file

    Returns:
        Tuple of (model_config_dict, model_weights_dict)

    Raises:
        Exception: If required files cannot be extracted from the archive
    """
    with tarfile.open(nemo_archive_path) as tar:
        try:
            model_weights_file = tar.extractfile("model_weights.ckpt")
            model_config_file = tar.extractfile("model_config.yaml")
        except KeyError:
            try:
                model_weights_file = tar.extractfile("./model_weights.ckpt")
                model_config_file = tar.extractfile("./model_config.yaml")
            except KeyError:
                err_str = "Could not find model_weights.ckpt or ./model_weights.ckpt in the tar archive."
                raise Exception(err_str)

        if model_weights_file is None or model_config_file is None:
            raise Exception("Could not extract model weights or config files")

        model_config_content = model_config_file.read()
        model_config_dict = yaml.safe_load(model_config_content)

        model_weights_bytes = model_weights_file.read()
        model_weights_dict = torch.load(
            io.BytesIO(model_weights_bytes), map_location=torch.device("cpu")
        )

        return model_config_dict, model_weights_dict


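# Example (illustrative sketch): reading a NeMo LoRA archive with unpack_nemo_weights. The path is
# hypothetical; the archive must contain model_config.yaml and model_weights.ckpt.
def _example_unpack_nemo_weights():
    model_config_dict, model_weights_dict = unpack_nemo_weights("/checkpoints/adapter_a.nemo")
    # model_config_dict is the parsed model_config.yaml;
    # model_weights_dict maps checkpoint keys to CPU tensors.
    return model_config_dict, model_weights_dict

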
class LoraManager(object):
    LORA_MODULE_IDS = {
        "attn_qkv": 0,
        "attn_q": 1,
        "attn_k": 2,
        "attn_v": 3,
        "attn_dense": 4,
        "mlp_h_to_4h": 5,
        "mlp_4h_to_h": 6,
        "mlp_gate": 7,
        "cross_attn_qkv": 8,
        "cross_attn_q": 9,
        "cross_attn_k": 10,
        "cross_attn_v": 11,
        "cross_attn_dense": 12,
        "moe_h_to_4h": 13,
        "moe_4h_to_h": 14,
        "moe_gate": 15,
        "moe_router": 16,
        "mlp_router": 17,
        "mlp_gate_up": 18,
    }

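    # The integer IDs above are written into the per-adapter config rows built by load_from_nemo
    # and load_from_hf below: [module_id, layer_idx, adapter_rank] for NeMo adapters and
    # [module_id, layer_idx, adapter_rank, is_dora] for HF adapters. For example, a rank-8
    # "attn_q" adapter on layer 3 produces the int32 row [1, 3, 8, 0].
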
    def __init__(
        self,
        *,
        mapping: Mapping,
        model_config: "ModelConfig",
        cpp_peft_cache_manager: tb_internal.batch_manager.PeftCacheManager | None = None,
    ):
        """Constructor.

        Args:
            mapping (Mapping): Parallelism related information.
            model_config (ModelConfig): model configuration python class instance.
            cpp_peft_cache_manager (PeftCacheManager, optional): used by the is_adapter_in_cpu_cache
                method for a LoRA performance optimization: the LoRA adapter weights are not sent
                with every LLM request when the adapter is already loaded in the LoRA CPU cache.
        """
        # _lora_uid_to_low_ranks: dict[str -> dict[int -> dict[str -> int]]]
        # {
        #     uid: {
        #         0: {
        #             lora_module: int
        #         },  # layer_0_rank,
        #         1: {
        #             lora_module: int
        #         },  # layer_1_rank,
        #         ...
        #     }
        # }

        # _lora_weights_pointers_list: dict[str -> dict[int -> dict[str -> [Tensor, Tensor]]]]
        # {
        #     uid: {
        #         0: {
        #             lora_module: [t_in, t_out]
        #         },  # layer_0,
        #         1: {
        #             lora_module: [t_in, t_out]
        #         },  # layer_1,
        #         ...
        #     }
        # }

        self._lora_uid_counter = 0
        self._lora_uid_to_low_ranks: Dict[str, Dict[int, Dict[str, int]]] = {}
        # hold the torch tensors and prevent them from being freed
        # TODO(enweiz): free device tensors if it's used for c++ runtime only
        self._lora_weights: List[torch.Tensor] = []
        self._lora_weights_pointers_list: Dict[str, Dict[int, Dict[str, List[int]]]] = {}
        self._cpp_lora_weights: Dict[str, torch.Tensor] = {}  # on cpu
        self._cpp_lora_config: Dict[str, torch.Tensor] = {}  # on cpu
        self.lora_target_modules: List[str] = []
        self._mapping = mapping
        self._model_config = model_config
        self._cpp_peft_cache_manager = cpp_peft_cache_manager

    def is_adapter_in_cpu_cache(self, adapter_uid: int) -> bool:
        """Best effort to check if a LoRA adapter is in the LoRA CPU cache.

        If no cpp_peft_cache_manager instance was given at the construction of this LoraManager
        instance, then False is returned.
        """
        return (
            self._cpp_peft_cache_manager.is_task_cached(adapter_uid)
            if self._cpp_peft_cache_manager
            else False
        )

    @staticmethod
    def get_missing_qkv_modules(lora_target_modules: List[str]) -> List[str]:
        return get_missing_qkv_modules_from_lora_modules(lora_target_modules)

    @property
    def missing_qkv_modules(self) -> List[str]:
        return LoraManager.get_missing_qkv_modules(self.lora_target_modules)

    def load_from_ckpt(
        self,
        model_dirs_or_files: List[str],
        model_config: Union["ModelConfig", LoraModelConfig],
        uids: Optional[List[str]] = None,
        ckpt_source: str = "hf",
    ) -> List[str]:
        """Returns the adapter UIDs that were loaded by this call.

        Note that an adapter that was already loaded before this call is not included in the
        returned list of UIDs.
        """
        if ckpt_source == "hf":
            return self.load_from_hf(
                model_dirs=model_dirs_or_files,
                model_config=model_config,
                uids=uids,
            )
        elif ckpt_source == "nemo":
            # Find all .nemo files from directories or files
            nemo_files = find_nemo_files(model_dirs_or_files)

            # Pass the actual .nemo files to the loader
            return self.load_from_nemo(
                model_files=nemo_files,
                model_config=model_config,
                uids=uids,
            )
        else:
            assert False, f"{self.__class__.__name__} does not support source {ckpt_source}"

    def load_from_nemo(
        self,
        model_files: List[str],
        model_config: Union["ModelConfig", LoraModelConfig],
        uids: Optional[List[str]] = None,
    ) -> List[str]:
        """Returns the adapter UIDs that were loaded by this call.

        Note that an adapter that was already loaded before this call is not included in the
        returned list of UIDs.
        """
        if uids is None:
            uids = [self._generate_uid() for _ in range(len(model_files))]
        assert len(uids) == len(model_files)

        new_uids, new_model_files = [], []
        for uid, model_file in zip(uids, model_files):
            if uid in self._lora_uid_to_low_ranks:
                continue
            new_uids.append(uid)
            new_model_files.append(model_file)

        if len(new_uids) == 0:
            return new_uids

        self.lora_target_modules = model_config.lora_target_modules

        def load_from_model_file(uid, model_file):
            if uid not in self._cpp_lora_weights:
                self._cpp_lora_weights[uid] = []  # Will be converted to tensor later
            if uid not in self._cpp_lora_config:
                self._cpp_lora_config[uid] = []  # Will be converted to tensor later

            _, nemo_weights = unpack_nemo_weights(model_file)
            all_lora_weights = get_all_nemo_lora_weights(nemo_weights)

            self._lora_uid_to_low_ranks[uid] = {}
            self._lora_weights_pointers_list[uid] = {}
            for layer_idx in sorted(all_lora_weights.keys()):
                self._lora_uid_to_low_ranks[uid][layer_idx] = {}
                self._lora_weights_pointers_list[uid][layer_idx] = {}

                for lora_module in self.lora_target_modules:
                    if lora_module not in NEMO_SUPPORTED_LORA_MODULES:
                        warnings.warn(
                            f"LoRA module '{lora_module}' not supported in NeMo loading for "
                            f"layer {layer_idx}, skipping. NeMo LoRA currently only supports "
                            f"{NEMO_SUPPORTED_LORA_MODULES} modules."
                        )
                        self._lora_uid_to_low_ranks[uid][layer_idx][lora_module] = 0
                        continue

                    if lora_module == "attn_qkv":
                        # Validate required matrices are present
                        _check_lora_in_out(
                            layer_idx=layer_idx,
                            lora_module=lora_module,
                            available_matrices=all_lora_weights[layer_idx],
                            source_identifier=f"file {model_file}",
                        )

                        t_in = all_lora_weights[layer_idx]["in"]
                        t_out = all_lora_weights[layer_idx]["out"]
                    else:
                        t_in = None
                        t_out = None

                    if t_in is not None and t_out is not None:
                        t_in = t_in.cuda().to(str_dtype_to_torch(model_config.dtype)).contiguous()
                        t_out = t_out.cuda().to(str_dtype_to_torch(model_config.dtype)).contiguous()
                        rank = t_in.shape[0]
                        self._lora_uid_to_low_ranks[uid][layer_idx][lora_module] = int(rank)
                        self._lora_weights_pointers_list[uid][layer_idx][lora_module] = [
                            t_in.data_ptr(),
                            t_out.data_ptr(),
                            0,
                        ]

                        # prevent torch from freeing this buffer
                        self._lora_weights.append(t_in)
                        self._lora_weights.append(t_out)
                        self._cpp_lora_weights[uid].append(
                            torch.concatenate([t_in.flatten().cpu(), t_out.flatten().cpu()])
                        )
                        self._cpp_lora_config[uid].append(
                            torch.tensor(
                                [self.LORA_MODULE_IDS[lora_module], layer_idx, int(rank)],
                                dtype=torch.int32,
                            )
                        )

            max_weight_size = max(w.size(0) for w in self._cpp_lora_weights[uid])
            self._cpp_lora_weights[uid] = torch.stack(
                [
                    torch.nn.functional.pad(w, (0, max_weight_size - w.size(0)))
                    for w in self._cpp_lora_weights[uid]
                ]
            )
            self._cpp_lora_config[uid] = torch.stack([c for c in self._cpp_lora_config[uid]])

        for uid, model_file in zip(new_uids, new_model_files):
            load_from_model_file(uid, model_file)
            release_gc()

        if new_uids:
            logger.info(f"Successfully loaded NeMo LoRA adapters with UIDs: {new_uids}")
        return new_uids

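    # For each adapter UID, load_from_nemo/load_from_hf pack _cpp_lora_weights[uid] into a single
    # 2D CPU tensor with one row per (layer, module) entry: the flattened 'in' weights followed by
    # the flattened 'out' weights (plus the DoRA magnitude vector for HF DoRA adapters),
    # zero-padded to the length of the longest row. _cpp_lora_config[uid] holds the matching
    # int32 description rows.
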
    def load_from_hf(
        self,
        model_dirs: List[str],
        model_config: Union["ModelConfig", LoraModelConfig],
        uids: Optional[List[str]] = None,
        component: Optional[str] = None,
    ) -> List[str]:
        """Returns the adapter UIDs that were loaded by this call.

        Note that an adapter that was already loaded before this call is not included in the
        returned list of UIDs.

        Lora config of https://huggingface.co/hfl/chinese-alpaca-2-lora-7b:

        {
            "base_model_name_or_path": "/Llama-2-7b-hf",
            "bias": "none",
            "enable_lora": null,
            "fan_in_fan_out": false,
            "inference_mode": true,
            "lora_alpha": 128.0,
            "lora_dropout": 0.05,
            "merge_weights": false,
            "modules_to_save": [
                "embed_tokens",
                "lm_head"
            ],
            "peft_type": "LORA",
            "r": 64,
            "target_modules": [
                "q_proj",
                "v_proj",
                "k_proj",
                "o_proj",
                "gate_proj",
                "down_proj",
                "up_proj"
            ],
            "task_type": "CAUSAL_LM"
        }

        keys in adapter_model.bin:
            base_model.model.model.layers.0.self_attn.q_proj.lora_A.weight torch.Size([64, 4096])
            base_model.model.model.layers.0.self_attn.q_proj.lora_B.weight torch.Size([4096, 64])
            base_model.model.model.layers.0.self_attn.k_proj.lora_A.weight torch.Size([64, 4096])
            base_model.model.model.layers.0.self_attn.k_proj.lora_B.weight torch.Size([4096, 64])
            base_model.model.model.layers.0.self_attn.v_proj.lora_A.weight torch.Size([64, 4096])
            base_model.model.model.layers.0.self_attn.v_proj.lora_B.weight torch.Size([4096, 64])
            base_model.model.model.layers.0.self_attn.o_proj.lora_A.weight torch.Size([64, 4096])
            base_model.model.model.layers.0.self_attn.o_proj.lora_B.weight torch.Size([4096, 64])
            base_model.model.model.layers.0.mlp.gate_proj.lora_A.weight torch.Size([64, 4096])
            base_model.model.model.layers.0.mlp.gate_proj.lora_B.weight torch.Size([11008, 64])
            base_model.model.model.layers.0.mlp.up_proj.lora_A.weight torch.Size([64, 4096])
            base_model.model.model.layers.0.mlp.up_proj.lora_B.weight torch.Size([11008, 64])
            base_model.model.model.layers.0.mlp.down_proj.lora_A.weight torch.Size([64, 11008])
            base_model.model.model.layers.0.mlp.down_proj.lora_B.weight torch.Size([4096, 64])
            ...
        """
        if uids is None:
            uids = [self._generate_uid() for _ in range(len(model_dirs))]
        assert len(uids) == len(model_dirs)

        new_uids, new_model_dirs = [], []
        for uid, model_dir in zip(uids, model_dirs):
            if uid in self._lora_uid_to_low_ranks:
                continue
            new_uids.append(uid)
            new_model_dirs.append(model_dir)

        if len(new_uids) == 0:
            return new_uids

        lora_hf_configs = []
        for model_dir in new_model_dirs:
            with open(f"{model_dir}/adapter_config.json", "r") as f:
                config = json.load(f)
                lora_hf_configs.append(config)

        self.lora_target_modules = model_config.lora_target_modules
        hf_modules_to_trtllm_modules = invert_module_mapping(
            model_config.trtllm_modules_to_hf_modules
        )
        hf_modules = set(hf_modules_to_trtllm_modules.keys())

        def preprocess_lora_weights(lora_model, model_config):
            # Swap weights of gate_up_proj
            if getattr(model_config, "swap_gate_up_proj_lora_b_weight", True):
                for key, value in lora_model.items():
                    if "gate_up_proj.lora_B.weight" in key:
                        original_weights = value.contiguous().clone()
                        half_split = original_weights.shape[0] // 2
                        first_half = original_weights[:half_split, :]
                        second_half = original_weights[half_split:, :]
                        value = torch.cat((second_half, first_half), dim=0)
                        lora_model[key] = value
            return lora_model

        def interleave_fused_lora_weights_for_tp(
            weight: torch.Tensor, rank_dim: int, tp_size: int, part_sizes: List[int]
        ) -> List[torch.Tensor]:
            """Interleaves fused LoRA modules' weights for TP.

            e.g. In case of attn_qkv: Convert t_out=torch.cat([Wq, Wk, Wv]) to
            torch.cat([Wq_rank0, Wk_rank0, Wv_rank0, ..., Wq_rankN, Wk_rankN, Wv_rankN])
            where N=TP size.
            """  # noqa: D205
            assert weight.shape[rank_dim] == sum(part_sizes)

            # Split the weights into their respective parts. e.g. weight -> [Wq, Wk, Wv] for attn_qkv.
            weight_parts = [
                weight.narrow(rank_dim, sum(part_sizes[:i]), part_sizes[i])
                for i in range(len(part_sizes))
            ]
            for i in range(len(part_sizes)):
                assert weight_parts[i].shape[rank_dim] % tp_size == 0

            # Split each part into tp_size chunks.
            # e.g. [Wq, Wk, Wv] -> [[Wq_rank0, ..., Wq_rankN], [Wk_rank0, ..., Wk_rankN], [Wv_rank0, ..., Wv_rankN]]
            # where N is TP size, for attn_qkv.
            weight_parts_tp_weights = [
                torch.split(
                    weight_parts[i], weight_parts[i].shape[rank_dim] // tp_size, dim=rank_dim
                )
                for i in range(len(part_sizes))
            ]

            # Interleave the parts across TP ranks and flatten the list of lists into a single list.
            # e.g. [[Wq_rank0, ..., Wq_rankN], [Wk_rank0, ..., Wk_rankN], [Wv_rank0, ..., Wv_rankN]]
            # -> [Wq_rank0, Wk_rank0, Wv_rank0, ..., Wq_rankN, Wk_rankN, Wv_rankN] where N is TP size, for attn_qkv.
            return list(itertools.chain.from_iterable(zip(*weight_parts_tp_weights)))

        def prepare_fused_lora_modules_for_tp(
            lora_module: str, t_out: torch.Tensor, rank_dim: int
        ) -> torch.Tensor:
            """Reorders fused LoRA modules' weights for TP. This is required since HF stores the
            parts' weights sequentially, whereas with TP>1 we need them to be interleaved so they
            are sharded correctly.

            See interleave_fused_lora_weights_for_tp for more details.
            """  # noqa: D205
            tp_size = self._mapping.tp_size
            if tp_size == 1:
                return t_out
            part_sizes = []
            if lora_module == "mlp_gate_up":
                assert t_out.shape[rank_dim] % 2 == 0
                half_size = t_out.shape[rank_dim] // 2
                part_sizes = [half_size, half_size]
            elif lora_module == "attn_qkv":
                # The sizes are multiplied by tp_size because num_heads and num_kv_heads here were already
                # divided by tp_size in tensorrt_llm/_torch/model_config.py::ModelConfig.get_bindings_model_config
                q_size = self._model_config.head_size * self._model_config.num_heads * tp_size
                kv_size = self._model_config.head_size * self._model_config.num_kv_heads * tp_size
                part_sizes = [q_size, kv_size, kv_size]

            if part_sizes:
                interleaved_parts = interleave_fused_lora_weights_for_tp(
                    t_out, rank_dim, tp_size, part_sizes
                )
                # Concatenate them all after interleaving, as the CPP implementation expects the full non-split weights.
                t_out = torch.cat(interleaved_parts, dim=rank_dim)
            return t_out

        def load_from_model_dir(uid, model_dir, hf_config):
            if uid not in self._cpp_lora_weights:
                self._cpp_lora_weights[uid] = []  # Will be converted to tensor later
            if uid not in self._cpp_lora_config:
                self._cpp_lora_config[uid] = []  # Will be converted to tensor later

            lora_model = load_state_dict(get_model_path(model_dir, "adapter_model"))
            if lora_model is None:
                raise ValueError(f"Failed to load adapter_model from {model_dir}")
            lora_model = preprocess_lora_weights(lora_model, model_config)
            all_weights = get_all_hf_lora_weights(lora_model, hf_modules, component)
            rank = int(hf_config["r"])
            rs_lora = bool(hf_config.get("use_rslora", False))

            self._lora_uid_to_low_ranks[uid] = {}
            self._lora_weights_pointers_list[uid] = {}
            for layer_idx in sorted(all_weights.keys()):
                layer_weights = all_weights[layer_idx]
                self._lora_uid_to_low_ranks[uid][layer_idx] = {}
                self._lora_weights_pointers_list[uid][layer_idx] = {}

                for lora_module in self.missing_qkv_modules:
                    hf_module = model_config.trtllm_modules_to_hf_modules[lora_module]
                    if isinstance(hf_module, list):
                        hf_module = hf_module[0]
                    layer_weights[hf_module] = {
                        "in": torch.zeros(rank, model_config.hidden_size),
                        "out": torch.zeros(model_config.hidden_size, rank),
                    }

                for hf_module, module_weights in layer_weights.items():
                    lora_module = hf_modules_to_trtllm_modules[hf_module]
                    if lora_module not in self.lora_target_modules:
                        warnings.warn(
                            f"LoRA module '{lora_module}' not in target modules {self.lora_target_modules}, skipping."
                        )
                        self._lora_uid_to_low_ranks[uid][layer_idx][lora_module] = 0
                        continue

                    has_expert_indices = _is_moe_module_weights(module_weights)

                    if has_expert_indices:  # MoE
                        # Validate and extract matrices in one pass
                        expert_indices = sorted(module_weights.keys())
                        t_in_list, t_out_list = [], []
                        for expert_idx in expert_indices:
                            expert_weights = module_weights[expert_idx]
                            _check_lora_in_out(
                                layer_idx=layer_idx,
                                lora_module=f"{lora_module}_expert_{expert_idx}",
                                available_matrices=expert_weights,
                                source_identifier=f"directory {model_dir}",
                            )
                            t_in_list.append(expert_weights["in"])
                            t_out_list.append(expert_weights["out"])

                        t_in = torch.stack(t_in_list)
                        t_out = torch.stack(t_out_list)
                        for weights in module_weights.values():
                            if "magnitude" in weights:
                                # TODO(oargov): this might work, but I had no MoE DoRA models to test
                                raise ValueError("DoRA with MoE is not supported")
                        t_mag = None
                    else:
                        # Not MoE - validate required matrices are present
                        _check_lora_in_out(
                            layer_idx=layer_idx,
                            lora_module=lora_module,
                            available_matrices=module_weights,
                            source_identifier=f"directory {model_dir}",
                        )

                        t_in = module_weights["in"]
                        t_out = module_weights["out"]
                        t_mag = module_weights.get("magnitude", None)

                    is_dora = t_mag is not None
                    rank_dim = 1 if has_expert_indices else 0
                    t_out = prepare_fused_lora_modules_for_tp(lora_module, t_out, rank_dim)

                    effective_rank = t_in.shape[rank_dim]

                    t_in = t_in.cuda().contiguous()
                    t_out = t_out.cuda().contiguous()
                    if is_dora and t_mag is not None:
                        t_mag = t_mag.cuda().contiguous()

                    if rs_lora:
                        scale = float(hf_config["lora_alpha"]) / np.sqrt(effective_rank)
                    else:
                        scale = float(hf_config["lora_alpha"]) / effective_rank
                    t_out = t_out * scale
                    t_in = t_in.to(str_dtype_to_torch(model_config.dtype))
                    t_out = t_out.to(str_dtype_to_torch(model_config.dtype))
                    if is_dora and t_mag is not None:
                        t_mag = t_mag.to(str_dtype_to_torch(model_config.dtype))

                    self._lora_uid_to_low_ranks[uid][layer_idx][lora_module] = effective_rank
                    self._lora_weights_pointers_list[uid][layer_idx][lora_module] = [
                        t_in.data_ptr(),
                        t_out.data_ptr(),
                        t_mag.data_ptr() if (is_dora and t_mag is not None) else 0,
                    ]

                    # prevent torch from freeing this buffer
                    self._lora_weights.append(t_in)
                    self._lora_weights.append(t_out)
                    if is_dora and t_mag is not None:
                        self._lora_weights.append(t_mag)

                    t_in_cpu = t_in.flatten().cpu()
                    t_out_cpu = t_out.flatten().cpu()
                    weights_to_concat = [t_in_cpu, t_out_cpu]

                    if is_dora and t_mag is not None:
                        t_mag_cpu = t_mag.flatten().cpu()
                        weights_to_concat.append(t_mag_cpu)

                    self._cpp_lora_weights[uid].append(torch.cat(weights_to_concat))
                    self._cpp_lora_config[uid].append(
                        torch.tensor(
                            [self.LORA_MODULE_IDS[lora_module], layer_idx, effective_rank, is_dora],
                            dtype=torch.int32,
                        )
                    )

            max_weight_size = max(w.size(0) for w in self._cpp_lora_weights[uid])
            self._cpp_lora_weights[uid] = torch.stack(
                [
                    torch.nn.functional.pad(w, (0, max_weight_size - w.size(0)))
                    for w in self._cpp_lora_weights[uid]
                ]
            )
            self._cpp_lora_config[uid] = torch.stack([c for c in self._cpp_lora_config[uid]])

        for uid, model_dir, hf_config in zip(new_uids, new_model_dirs, lora_hf_configs):
            load_from_model_dir(uid, model_dir, hf_config)
            release_gc()

        return new_uids

    @property
    def lora_weights(self):
        return self._lora_weights

    @property
    def lora_weights_pointers_list(self):
        return self._lora_weights_pointers_list

    @property
    def cpp_lora_weights(self):
        return self._cpp_lora_weights

    @property
    def cpp_lora_config(self):
        return self._cpp_lora_config

    def uid_to_low_ranks(self, uid: str):
        assert isinstance(uid, str)
        return self._lora_uid_to_low_ranks[uid]

    def _generate_uid(self):
        while str(self._lora_uid_counter) in self._lora_uid_to_low_ranks:
            self._lora_uid_counter += 1
        uid = str(self._lora_uid_counter)
        self._lora_uid_counter += 1
        return uid

    @property
    def num_lora_adapters(self):
        return len([uid for uid in self._lora_uid_to_low_ranks if uid != "-1"])

    def save_lora_weights_to_bin(self, out_dir):
        def save_val(val, dir, key, tp_num=None, write_npy=False):
            ext = "npy" if write_npy else "bin"
            suffix = ext if tp_num is None else f"{tp_num}.{ext}"
            if write_npy:
                np.save(dir / f"model.{key}.{suffix}", val)
            else:
                val.tofile(dir / f"model.{key}.{suffix}")

        if isinstance(out_dir, str):
            out_dir_path = Path(out_dir)
        elif isinstance(out_dir, Path):
            out_dir_path = out_dir
        else:
            assert False
        for uid in self.cpp_lora_weights:
            if uid == "-1":
                continue

            all_weights = np.expand_dims(torch_to_numpy(self.cpp_lora_weights[uid]), 0)
            all_configs = np.expand_dims(torch_to_numpy(self.cpp_lora_config[uid]), 0)

            uid_path = out_dir_path / f"{uid}"
            uid_path.mkdir(parents=True, exist_ok=True)
            save_val(all_weights, uid_path, "lora_weights", tp_num=None, write_npy=True)
            save_val(all_configs, uid_path, "lora_config", tp_num=None, write_npy=True)

    def input_buffers(self, lora_uids, mapping: Mapping, num_layers: int):
        inputs = {}
        for layer_idx in mapping.pp_layers(num_layers):
            for lora_module in self.lora_target_modules + self.missing_qkv_modules:
                lora_ranks_ = []
                lora_ptrs_ = []
                for lora_uid in lora_uids:
                    lora_rank = 0
                    lora_ptrs = [0, 0, 0]

                    if lora_uid != "-1":
                        low_ranks = self.uid_to_low_ranks(lora_uid)

                        if (
                            layer_idx in low_ranks
                            and lora_module in low_ranks[layer_idx].keys()
                            and low_ranks[layer_idx][lora_module] != 0
                        ):
                            lora_rank = low_ranks[layer_idx][lora_module]
                            lora_ptrs = self.lora_weights_pointers_list[lora_uid][layer_idx][
                                lora_module
                            ]

                    lora_ranks_.append(lora_rank)
                    lora_ptrs_.append(lora_ptrs)

                inputs[f"{lora_module}_lora_ranks_{layer_idx}"] = torch.IntTensor(lora_ranks_)
                inputs[f"{lora_module}_lora_weights_pointers_{layer_idx}"] = torch.LongTensor(
                    lora_ptrs_
                )
        return inputs
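

# Example (illustrative sketch, separate from the LoraManager API): the fused-module weight
# interleaving that load_from_hf applies for tensor parallelism, reproduced standalone with toy
# sizes. The real code derives the part sizes from the model config (Q/K/V sizes, or the two
# halves of gate_up) instead of the hard-coded values used here.
def _example_interleave_fused_lora_weights_for_tp() -> None:
    tp_size = 2
    part_sizes = [8, 4, 4]  # e.g. Q, K and V output sizes of a fused attn_qkv lora_B weight
    rank = 4
    # HF layout: the fused lora_B weight stores the parts sequentially along dim 0.
    t_out = torch.cat(
        [torch.full((size, rank), float(i)) for i, size in enumerate(part_sizes)], dim=0
    )
    # Split each part into tp_size chunks, then interleave them rank by rank so that each TP rank
    # ends up with a contiguous [Q_slice, K_slice, V_slice] block after sharding.
    parts = torch.split(t_out, part_sizes, dim=0)
    per_rank_chunks = [torch.chunk(p, tp_size, dim=0) for p in parts]
    interleaved = torch.cat(list(itertools.chain.from_iterable(zip(*per_rank_chunks))), dim=0)
    assert interleaved.shape == t_out.shape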