# SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import copy
import functools
import os
import shutil
import tarfile
import time
from collections import defaultdict
from pathlib import Path
from typing import Dict, Optional, Union

import numpy as np
import safetensors
import torch
import torch.nn as nn
import yaml
from tqdm import tqdm
from transformers import (AutoModelForCausalLM, AutoModelForVision2Seq,
                          AutoTokenizer)
from transformers.models.gpt2.modeling_gpt2 import GPT2Block
from transformers.pytorch_utils import Conv1D

from ..._utils import pad_vocab_size, str_dtype_to_torch
from ...logger import logger
from ...quantization import QuantAlgo
from ..convert_utils import (generate_int8, get_weight, get_weight_and_bias,
                             load_calib_dataset,
                             retrieved_layer_index_from_name, smooth_gemm)
from .config import GPTConfig


def rename_keys(model_state, layer_rename_config: Dict[str, str]):
    if not layer_rename_config:
        return model_state

    new_state_dict = {}
    for key, value in model_state.items():
        for old, new in layer_rename_config.items():
            key = key.replace(old, new)
        assert key not in new_state_dict, f"Key already exists: {key}"
        new_state_dict[key] = value

    return new_state_dict


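# Usage sketch (hypothetical renaming rule, not taken from this file): each
# (old, new) pair is applied to every key with str.replace, so the patterns
# should be specific enough that two distinct keys never collapse into one
# (the assert above guards against such collisions).
#
#     renamed = rename_keys(
#         {"model.language_model.encoder.layers.0.self_attention.qkv.weight": w},
#         {"self_attention": "attention"})
#     # -> {"model.language_model.encoder.layers.0.attention.qkv.weight": w}

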
def get_needed_padding(value: int, multiple: int) -> int:
    return (multiple - value % multiple) % multiple


def pad_array_up_to(v: torch.Tensor, axis: int, multiple: int) -> torch.Tensor:
    a = [0 for i in range(len(v.shape) * 2)]
    a[axis * 2 - 1] = get_needed_padding(v.shape[axis], multiple)
    return torch.nn.functional.pad(v, a)


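# Example: padding the output dimension (axis 0) of a weight so it divides
# evenly across tensor-parallel ranks. get_needed_padding(10, 4) == 2, and
# pad_array_up_to only pads the trailing side of the chosen axis:
#
#     w = torch.randn(10, 8)
#     w_padded = pad_array_up_to(w, axis=0, multiple=4)   # shape (12, 8)
#
# For the 1-D and 2-D tensors used in this file, index `axis * 2 - 1` selects
# the right-hand pad amount of the requested axis in torch.nn.functional.pad's
# reversed-dimension (last dimension first) layout.

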
def split(param: torch.Tensor,
          tp_rank: int,
          tp_size: int,
          is_column: bool = True) -> torch.Tensor:
    """Split linear layer's weight, bias or scaling factors for tensor parallelism."""
    if param is None:
        return None
    assert param.ndim in [1, 2]
    if tp_size == 1:
        return param
    if param.numel() == 1:
        return param
    if param.ndim == 1 and not is_column:
        return param
    split_dim = 0 if (param.ndim == 1 or is_column) else 1
    return torch.chunk(param, tp_size, dim=split_dim)[tp_rank].contiguous()


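# Example: 2-way tensor parallelism on an (8, 6) weight.
#
#     w = torch.randn(8, 6)
#     split(w, tp_rank=0, tp_size=2, is_column=True)    # rows 0..3, shape (4, 6)
#     split(w, tp_rank=1, tp_size=2, is_column=False)   # cols 3..5, shape (8, 3)
#
# 1-D tensors (biases) are only split for column-parallel layers; row-parallel
# biases are returned whole on every rank.

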
def split_qkv(
    param: torch.Tensor,
    tp_rank: int,
    tp_size: int,
    hidden_size: int,
    num_heads: int,
    num_kv_heads: Optional[int] = None,
) -> torch.Tensor:
    """Split qkv layer's weight, bias or scaling factors for tensor parallelism.

    param: (num_heads*head_dim + 2*num_kv_heads*head_dim, in_dim)
    """
    if param is None:
        return None
    assert hidden_size % num_heads == 0
    head_dim = hidden_size // num_heads
    num_kv_heads = num_kv_heads if num_kv_heads is not None else num_heads
    assert num_heads % num_kv_heads == 0
    assert num_heads % tp_size == 0

    q_param, k_param, v_param = torch.split(
        param, [hidden_size, num_kv_heads * head_dim, num_kv_heads * head_dim],
        dim=0)

    if num_kv_heads < tp_size:
        assert tp_size % num_kv_heads == 0
        num_dups = tp_size // num_kv_heads
        remain_shape = k_param.shape[1:]
        k_param = k_param.view(
            num_kv_heads, head_dim,
            *remain_shape).repeat_interleave(num_dups, dim=0).view(
                num_kv_heads * head_dim * num_dups, *remain_shape)
        v_param = v_param.view(
            num_kv_heads, head_dim,
            *remain_shape).repeat_interleave(num_dups, dim=0).view(
                num_kv_heads * head_dim * num_dups, *remain_shape)
    else:
        assert num_kv_heads % tp_size == 0

    q_param = split(q_param, tp_rank, tp_size, is_column=True)
    k_param = split(k_param, tp_rank, tp_size, is_column=True)
    v_param = split(v_param, tp_rank, tp_size, is_column=True)
    return torch.cat([q_param, k_param, v_param], dim=0)


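# Shape sketch (illustrative numbers, not from this file): for a GQA checkpoint
# with hidden_size=4096, num_heads=32, num_kv_heads=8 (head_dim=128), the fused
# parameter has 4096 + 2 * 8 * 128 = 6144 output rows. With tp_size=4 each rank
# keeps 4096/4 + 2 * 1024/4 = 1536 rows. When num_kv_heads < tp_size, the K/V
# heads are duplicated (repeat_interleave) so every rank still receives a full
# KV head before the per-rank split.

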
def split_embedding(
    param: torch.Tensor,
    tp_rank: int,
    tp_size: int,
    use_parallel_embedding: bool = False,
    sharding_dim: int = 0,
) -> torch.Tensor:
    if param is None:
        return None
    if not use_parallel_embedding:
        return param

    vocab_size, hidden_size = param.size()
    if sharding_dim == 0:
        if vocab_size % tp_size != 0:
            vocab_size_padded = pad_vocab_size(vocab_size, tp_size)
            pad_width = vocab_size_padded - vocab_size
            param = torch.nn.functional.pad(param, (0, 0, 0, pad_width),
                                            value=0)
    else:
        assert hidden_size % tp_size == 0
    return split(param, tp_rank, tp_size, is_column=(sharding_dim == 0))


def get_tllm_linear_weight(
    weight: torch.Tensor,
    prefix: str,
    bias: Optional[torch.Tensor] = None,
    use_weight_only: bool = False,
    plugin_weight_only_quant_type: torch.dtype = torch.int8
) -> Dict[str, torch.Tensor]:
    results = {}
    if use_weight_only:
        v = weight.t().contiguous()
        processed_torch_weights, torch_weight_scales = \
            torch.ops.trtllm.symmetric_quantize_last_axis_of_batched_matrix(
                v, plugin_weight_only_quant_type)
        results[f'{prefix}.weight'] = processed_torch_weights
        results[f'{prefix}.per_channel_scale'] = torch_weight_scales
    else:
        results[f'{prefix}.weight'] = weight

    if bias is not None:
        results[f'{prefix}.bias'] = bias

    return results


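# Example of the returned mapping (illustrative prefix): with
# use_weight_only=True and the INT8 quant type,
#
#     get_tllm_linear_weight(w, 'transformer.layers.0.mlp.fc', b,
#                            use_weight_only=True,
#                            plugin_weight_only_quant_type=torch.int8)
#
# yields 'transformer.layers.0.mlp.fc.weight' (quantized, transposed),
# 'transformer.layers.0.mlp.fc.per_channel_scale', and, because a bias was
# passed, 'transformer.layers.0.mlp.fc.bias'. Without weight-only quantization
# only '.weight' (and optionally '.bias') are emitted.

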
@torch.no_grad()
def capture_activation_range(model,
                             tokenizer,
                             dataset,
                             num_samples=512,
                             seq_len=512):
    model.eval()
    device = next(model.parameters()).device
    act_scales = defaultdict(lambda: {"x": None, "y": None, "w": None})

    def stat_tensor(name, tensor, act_scales, key):
        hidden_dim = tensor.shape[-1]
        tensor = tensor.view(-1, hidden_dim).abs().detach()
        comming_max = torch.max(tensor, dim=0)[0].float()

        if act_scales[name][key] is None:
            act_scales[name][key] = comming_max
        else:
            act_scales[name][key] = torch.max(act_scales[name][key],
                                              comming_max)

    def stat_input_hook(m, x, y, name):
        if isinstance(x, tuple):
            x = x[0]
        stat_tensor(name, x, act_scales, "x")
        stat_tensor(name, y, act_scales, "y")

        if act_scales[name]["w"] is None:
            act_scales[name]["w"] = m.weight.abs().clip(1e-8,
                                                        None).max(dim=0)[0]

    hooks = []
    for name, m in model.named_modules():
        if isinstance(m, nn.Linear) or isinstance(m, Conv1D):
            hooks.append(
                m.register_forward_hook(
                    functools.partial(stat_input_hook, name=name)))

    for i in tqdm(range(num_samples), desc="calibrating model"):
        input_ids = tokenizer(dataset[i],
                              return_tensors="pt",
                              max_length=seq_len,
                              truncation=True).input_ids.to(device)
        model(input_ids)

    for h in hooks:
        h.remove()

    return act_scales


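# Calibration sketch (hypothetical model/tokenizer/dataset objects): the hooks
# record per-channel absolute maxima of inputs ("x"), outputs ("y") and weights
# ("w") for every nn.Linear / Conv1D module, keyed by the module's name from
# named_modules().
#
#     act_range = capture_activation_range(hf_model, tokenizer, dataset,
#                                          num_samples=512, seq_len=512)
#     act_range['transformer.h.0.attn.c_attn']['x']  # 1-D tensor of maxima
#
# These statistics feed smooth_gpt_model() and generate_int8() below.

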
@torch.no_grad()
def smooth_gpt_model(model, scales, alpha):
    # Smooth the activation and weights with smoother = $\diag{s}$
    for name, module in model.named_modules():
        if not isinstance(module, GPT2Block):
            continue

        # qkv_proj
        layer_name = name + ".attn.c_attn"
        smoother = smooth_gemm(module.attn.c_attn.weight.T,
                               scales[layer_name]["x"], module.ln_1.weight,
                               module.ln_1.bias, alpha)
        scales[layer_name]["x"] = scales[layer_name]["x"] / smoother
        scales[layer_name]["w"] = module.attn.c_attn.weight.abs().max(dim=0)[0]

        # fc1
        layer_name = name + ".mlp.c_fc"
        smoother = smooth_gemm(module.mlp.c_fc.weight.T,
                               scales[layer_name]["x"], module.ln_2.weight,
                               module.ln_2.bias, alpha)
        scales[layer_name]["x"] = scales[layer_name]["x"] / smoother
        scales[layer_name]["w"] = module.mlp.c_fc.weight.abs().max(dim=0)[0]


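# SmoothQuant note (broad description, see convert_utils.smooth_gemm for the
# exact math): smooth_gemm derives a per-channel smoothing factor from the
# captured activation and weight ranges, balanced by the migration strength
# alpha (commonly around 0.5), and folds it into the preceding LayerNorm
# (ln_1 / ln_2) so the GEMM result is unchanged while activations become
# easier to quantize. The cached activation maxima are then rescaled by the
# smoother above. Only the attention QKV and the first MLP projection of each
# GPT2Block are smoothed here.

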
def get_tllm_linear_sq_weight(vals,
                              prefix,
                              shape,
                              tensor_parallel,
                              is_qkv=False,
                              per_token=False,
                              per_channel=False,
                              last_prefix=None,
                              bias=None,
                              smoother_value=None,
                              smoother_shape=None,
                              rank=0,
                              cat_dim=0,
                              multi_query_mode=False):
    results = {}

    def multi_query_split(data, local_dim, head_size, tp_size, cur_rank):
        q, k, v = torch.split(data, [local_dim, head_size, head_size], dim=-1)
        q_split = torch.split(q, q.shape[-1] // tp_size, dim=-1)
        k_split = torch.split(k, q.shape[-1] // tp_size, dim=-1)
        v_split = torch.split(v, q.shape[-1] // tp_size, dim=-1)
        return [
            torch.concat((q_split[ii], k_split[ii], v_split[ii]), dim=-1)
            for ii in range(tp_size)
        ][cur_rank]

    col_shape = shape if (is_qkv or per_channel) else [1, 1]

    if per_token:
        if per_channel:
            original_weights = vals["weight.int8.col"]
        else:
            original_weights = vals["weight.int8"]
        local_dim = original_weights.shape[0]
        head_size = (original_weights.shape[1] - local_dim) // 2

        if multi_query_mode:
            cur_weights = multi_query_split(original_weights, local_dim,
                                            head_size, tensor_parallel, rank)
        else:
            cur_weights = np.split(original_weights,
                                   tensor_parallel,
                                   axis=cat_dim)[rank]
        if is_qkv:
            hidden_dim = cur_weights.shape[0]
            cur_weights = cur_weights.reshape(hidden_dim, -1)
        results[prefix + 'weight'] = cur_weights.t().contiguous()
        if smoother_value is None:
            results[last_prefix] = torch.from_numpy(
                np.array([1.0], dtype=np.float32))

        if per_channel:
            cur_per_channel_value = vals["scale_w_quant_orig.col"]
            if smoother_value is None:
                if multi_query_mode:
                    cur_per_channel_value = multi_query_split(
                        vals["scale_w_quant_orig.col"], local_dim, head_size,
                        tensor_parallel, rank)
                else:
                    cur_per_channel_value = np.split(
                        vals["scale_w_quant_orig.col"],
                        tensor_parallel,
                        axis=cat_dim)[rank]
        else:
            cur_per_channel_value = vals["scale_w_quant_orig"]
            if is_qkv:
                if multi_query_mode:
                    cur_per_channel_value = multi_query_split(
                        vals["scale_w_quant_orig"], local_dim, head_size,
                        tensor_parallel, rank)
                else:
                    cur_per_channel_value = np.split(vals["scale_w_quant_orig"],
                                                     tensor_parallel,
                                                     axis=cat_dim)[rank]

        results[prefix + 'per_channel_scale'] = cur_per_channel_value.reshape(
            col_shape).contiguous()
    else:
        if per_channel:
            original_weights = np.array(vals["weight.int8.col"])
        else:
            original_weights = np.array(vals["weight.int8"])
        local_dim = original_weights.shape[0]
        head_size = (original_weights.shape[1] - local_dim) // 2

        if multi_query_mode:
            cur_weights = multi_query_split(original_weights, local_dim,
                                            head_size, tensor_parallel, rank)
        else:
            cur_weights = np.split(original_weights,
                                   tensor_parallel,
                                   axis=cat_dim)[rank]
        if is_qkv:
            hidden_dim = cur_weights.shape[0]
            cur_weights = cur_weights.reshape(hidden_dim, -1)
        results[prefix +
                'weight'] = torch.from_numpy(cur_weights).t().contiguous()

        if per_channel:
            cur_per_channel_value = vals["scale_y_accum_quant.col"]
            if smoother_value is None:
                if multi_query_mode:
                    cur_per_channel_value = multi_query_split(
                        vals["scale_y_accum_quant.col"], local_dim, head_size,
                        tensor_parallel, rank)
                else:
                    cur_per_channel_value = np.split(
                        vals["scale_y_accum_quant.col"],
                        tensor_parallel,
                        axis=cat_dim)[rank]
        else:
            cur_per_channel_value = vals["scale_y_accum_quant"]
            # QKV is always per_channel
            if is_qkv:
                if multi_query_mode:
                    cur_per_channel_value = multi_query_split(
                        vals["scale_y_accum_quant"], local_dim, head_size,
                        tensor_parallel, rank)
                else:
                    cur_per_channel_value = np.split(
                        vals["scale_y_accum_quant"],
                        tensor_parallel,
                        axis=cat_dim)[rank]

        results[prefix + 'per_channel_scale'] = cur_per_channel_value.reshape(
            col_shape).contiguous()

        results[last_prefix] = vals['scale_x_orig_quant'].contiguous()

        results[prefix + 'act_scale'] = vals["scale_y_quant_orig"].contiguous()

    if smoother_value is not None:
        cur_smoother_value = np.split(smoother_value,
                                      tensor_parallel,
                                      axis=cat_dim)[rank]
        results[prefix + 'smoother'] = cur_smoother_value.reshape(
            smoother_shape).contiguous().to(torch.float32)

    if bias is not None:
        results[prefix + 'bias'] = bias

    return results


def load_weights_from_hf_model(hf_model,
                               config: GPTConfig,
                               act_range: Optional[dict] = None):
    quant_algo = config.quantization.quant_algo
    use_weight_only = quant_algo in [QuantAlgo.W8A16, QuantAlgo.W4A16]
    if quant_algo == QuantAlgo.W8A16:
        plugin_weight_only_quant_type = torch.int8
    elif quant_algo == QuantAlgo.W4A16:
        plugin_weight_only_quant_type = torch.quint4x2
    else:
        plugin_weight_only_quant_type = None

    use_smooth_quant = config.quantization._use_plugin_sq
    per_channel = use_smooth_quant and 'PER_CHANNEL' in quant_algo
    per_token = use_smooth_quant and 'PER_TOKEN' in quant_algo
    int8_kv_cache = config.quantization.kv_cache_quant_algo == QuantAlgo.INT8
    if use_smooth_quant or int8_kv_cache:
        assert act_range is not None

    weights = {}
    tik = time.time()

    hf_config = hf_model.config
    model_params = dict(hf_model.named_parameters())
    dtype = getattr(torch, config.dtype)
    gpt_variant = config.gpt_variant
    num_attention_heads = config.num_attention_heads
    hidden_size = config.hidden_size
    vocab_size = config.vocab_size
    num_kv_heads = config.num_key_value_heads
    num_hidden_layers = config.num_hidden_layers
    multi_query_mode = (num_kv_heads != num_attention_heads)

    mapping = config.mapping
    layers_range = mapping.pp_layers(num_hidden_layers)
    for l in layers_range:
        if gpt_variant in ['starcoder2', 'nemotron']:
            prefix = f'model.layers.{l}'
        elif gpt_variant == 'persimmon':
            is_fuyu = f'language_model.model.embed_tokens.weight' in model_params
            prefix = f'language_model.model.layers.{l}' if is_fuyu else f'model.layers.{l}'
        elif gpt_variant == 'kosmos-2':
            prefix = f'text_model.model.layers.{l}'
        else:
            prefix = f'transformer.h.{l}'
        tllm_prex = f'transformer.layers.{l-layers_range[0]}'

        # (1) Attention QKV Linear
        if gpt_variant == 'santacoder':
            q_w, q_b = get_weight_and_bias(model_params,
                                           f'{prefix}.attn.q_attn', dtype)
            kv_w, kv_b = get_weight_and_bias(model_params,
                                             f'{prefix}.attn.kv_attn', dtype)
            qkv_w = torch.cat([q_w, kv_w], dim=-1)
            qkv_b = torch.cat([q_b, kv_b], dim=-1)
        elif gpt_variant in ['starcoder2', 'nemotron', 'kosmos-2']:
            q_w, q_b = get_weight_and_bias(model_params,
                                           f'{prefix}.self_attn.q_proj', dtype)
            k_w, k_b = get_weight_and_bias(model_params,
                                           f'{prefix}.self_attn.k_proj', dtype)
            v_w, v_b = get_weight_and_bias(model_params,
                                           f'{prefix}.self_attn.v_proj', dtype)
            qkv_w = torch.cat([q_w.cuda(), k_w.cuda(), v_w.cuda()], dim=0)
            qkv_b = torch.cat([q_b.cuda(), k_b.cuda(),
                               v_b.cuda()], dim=0) if q_b is not None else None
        elif gpt_variant == 'persimmon':
            qkv_w, qkv_b = get_weight_and_bias(
                model_params, f'{prefix}.self_attn.query_key_value', dtype)
        else:
            qkv_w, qkv_b = get_weight_and_bias(model_params,
                                               f'{prefix}.attn.c_attn', dtype)
        if gpt_variant in ['gpt2', 'santacoder', 'jais']:
            qkv_w = qkv_w.t().contiguous()  # transpose for Conv1D

        if use_smooth_quant:
            qkv_out_dim = qkv_w.shape[0]
            qkv_w_t = qkv_w.t()
            if not multi_query_mode:
                qkv_w_t = qkv_w_t.reshape(hidden_size, 3, hidden_size)
            int8_weights = generate_int8(qkv_w_t,
                                         act_range.get(f'{prefix}.attn.c_attn'),
                                         is_qkv=True,
                                         multi_query_mode=multi_query_mode)
            qkv_b = split_qkv(qkv_b, mapping.tp_rank, mapping.tp_size,
                              hidden_size, num_attention_heads, num_kv_heads)
            weights.update(
                get_tllm_linear_sq_weight(
                    int8_weights,
                    f'{tllm_prex}.attention.qkv.',
                    [1, qkv_out_dim // mapping.tp_size],
                    mapping.tp_size,
                    is_qkv=True,
                    per_token=per_token,
                    per_channel=per_channel,
                    last_prefix=f'{tllm_prex}.input_layernorm.scale_to_int',
                    bias=qkv_b,
                    smoother_value=None,
                    smoother_shape=None,
                    rank=mapping.tp_rank,
                    cat_dim=-1,
                    multi_query_mode=multi_query_mode))
        else:
            if gpt_variant == 'persimmon':
                qkv_w = split(qkv_w,
                              mapping.tp_rank,
                              mapping.tp_size,
                              is_column=True)

                qkv_b = split(qkv_b,
                              mapping.tp_rank,
                              mapping.tp_size,
                              is_column=True)
            else:
                qkv_w = split_qkv(qkv_w, mapping.tp_rank, mapping.tp_size,
                                  hidden_size, num_attention_heads,
                                  num_kv_heads)
                qkv_b = split_qkv(qkv_b, mapping.tp_rank, mapping.tp_size,
                                  hidden_size, num_attention_heads,
                                  num_kv_heads)
            weights.update(
                get_tllm_linear_weight(qkv_w, f'{tllm_prex}.attention.qkv',
                                       qkv_b, use_weight_only,
                                       plugin_weight_only_quant_type))

        if int8_kv_cache:
            qkv_w_t = qkv_w.t()
            if not multi_query_mode:
                qkv_w_t = qkv_w_t.reshape(hidden_size, 3, hidden_size)
            int8_weights = generate_int8(qkv_w_t,
                                         act_range.get(f'{prefix}.attn.c_attn'),
                                         is_qkv=True,
                                         multi_query_mode=multi_query_mode)
            weights[
                f'{tllm_prex}.attention.kv_cache_scaling_factor'] = int8_weights[
                    'scale_y_quant_orig'].contiguous()

        # (2) Attention Dense Linear
        if gpt_variant in ['starcoder2', 'nemotron']:
            attn_dense_w, attn_dense_b = get_weight_and_bias(
                model_params, f'{prefix}.self_attn.o_proj', dtype)
        elif gpt_variant == 'persimmon':
            attn_dense_w, attn_dense_b = get_weight_and_bias(
                model_params, f'{prefix}.self_attn.dense', dtype)
        elif gpt_variant == 'kosmos-2':
            attn_dense_w, attn_dense_b = get_weight_and_bias(
                model_params, f'{prefix}.self_attn.out_proj', dtype)
        else:
            attn_dense_w, attn_dense_b = get_weight_and_bias(
                model_params, f'{prefix}.attn.c_proj', dtype)
        if gpt_variant in ['gpt2', 'santacoder', 'jais']:
            attn_dense_w = attn_dense_w.t().contiguous()  # transpose for Conv1D

        if use_smooth_quant:
            attn_dense_w_t = attn_dense_w.t()
            int8_weights = generate_int8(attn_dense_w_t,
                                         act_range.get(f'{prefix}.attn.c_proj'))
            # Replace with the real smoother if SmoothQuant is applied to the dense layer.
            fake_smoother_value = torch.ones([1, hidden_size],
                                             dtype=torch.float32)
            weights.update(
                get_tllm_linear_sq_weight(
                    int8_weights,
                    f'{tllm_prex}.attention.dense.', [1, hidden_size],
                    mapping.tp_size,
                    is_qkv=False,
                    per_token=per_token,
                    per_channel=per_channel,
                    last_prefix=
                    f'{tllm_prex}.attention.quantization_scaling_factor',
                    bias=attn_dense_b,
                    smoother_value=fake_smoother_value,
                    smoother_shape=[1, hidden_size // mapping.tp_size],
                    rank=mapping.tp_rank,
                    cat_dim=0))
        else:
            attn_dense_w = split(attn_dense_w,
                                 mapping.tp_rank,
                                 mapping.tp_size,
                                 is_column=False)
            weights.update(
                get_tllm_linear_weight(attn_dense_w,
                                       f'{tllm_prex}.attention.dense',
                                       attn_dense_b, use_weight_only,
                                       plugin_weight_only_quant_type))

        # (3) MLP FC Linear
        if gpt_variant == 'persimmon':
            suffix = "mlp.dense_h_to_4h"
        elif gpt_variant == 'kosmos-2':
            suffix = "ffn.fc1"
        elif gpt_variant == 'nemotron':
            suffix = "mlp.up_proj"
        else:
            suffix = "mlp.c_fc"
        mlp_fc_w, mlp_fc_b = get_weight_and_bias(model_params,
                                                 f'{prefix}.{suffix}', dtype)
        if gpt_variant in ['gpt2', 'santacoder', 'jais']:
            mlp_fc_w = mlp_fc_w.t().contiguous()  # transpose for Conv1D
        if gpt_variant in ['jais']:
            mlp_fc_w = pad_array_up_to(mlp_fc_w, 0, mapping.tp_size)
            mlp_fc_b = pad_array_up_to(mlp_fc_b, 0, mapping.tp_size)
        if use_smooth_quant:
            mlp_fc_w_t = mlp_fc_w.t()
            int8_weights = generate_int8(mlp_fc_w_t,
                                         act_range.get(f'{prefix}.mlp.c_fc'))
            mlp_fc_b = split(mlp_fc_b,
                             mapping.tp_rank,
                             mapping.tp_size,
                             is_column=True)
            weights.update(
                get_tllm_linear_sq_weight(
                    int8_weights,
                    f'{tllm_prex}.mlp.fc.',
                    [1, 4 * hidden_size // mapping.tp_size],
                    mapping.tp_size,
                    is_qkv=False,
                    per_token=per_token,
                    per_channel=per_channel,
                    last_prefix=f'{tllm_prex}.post_layernorm.scale_to_int',
                    bias=mlp_fc_b,
                    smoother_value=None,
                    smoother_shape=None,
                    rank=mapping.tp_rank,
                    cat_dim=-1))
        else:
            mlp_fc_w = split(mlp_fc_w,
                             mapping.tp_rank,
                             mapping.tp_size,
                             is_column=True)
            mlp_fc_b = split(mlp_fc_b,
                             mapping.tp_rank,
                             mapping.tp_size,
                             is_column=True)
            if gpt_variant in ['jais']:
                weights.update(
                    get_tllm_linear_weight(mlp_fc_w, f'{tllm_prex}.mlp.gate',
                                           mlp_fc_b, use_weight_only,
                                           plugin_weight_only_quant_type))
            else:
                weights.update(
                    get_tllm_linear_weight(mlp_fc_w, f'{tllm_prex}.mlp.fc',
                                           mlp_fc_b, use_weight_only,
                                           plugin_weight_only_quant_type))
        if gpt_variant in ['jais']:
            mlp_fc2_w, mlp_fc2_b = get_weight_and_bias(
                model_params, f'{prefix}.mlp.c_fc2', dtype)
            mlp_fc2_w = mlp_fc2_w.t().contiguous()
            mlp_fc2_w = pad_array_up_to(mlp_fc2_w, 0, mapping.tp_size)
            mlp_fc2_b = pad_array_up_to(mlp_fc2_b, 0, mapping.tp_size)
            mlp_fc2_w = split(mlp_fc2_w,
                              mapping.tp_rank,
                              mapping.tp_size,
                              is_column=True)
            mlp_fc2_b = split(mlp_fc2_b,
                              mapping.tp_rank,
                              mapping.tp_size,
                              is_column=True)
            weights.update(
                get_tllm_linear_weight(mlp_fc2_w, f'{tllm_prex}.mlp.fc',
                                       mlp_fc2_b, use_weight_only,
                                       plugin_weight_only_quant_type))

        # (4) MLP Proj Layer
        if gpt_variant == 'persimmon':
            suffix = "mlp.dense_4h_to_h"
        elif gpt_variant == 'kosmos-2':
            suffix = "ffn.fc2"
        elif gpt_variant == 'nemotron':
            suffix = "mlp.down_proj"
        else:
            suffix = "mlp.c_proj"
        mlp_proj_w, mlp_proj_b = get_weight_and_bias(model_params,
                                                     f'{prefix}.{suffix}',
                                                     dtype)
        if gpt_variant in ['gpt2', 'santacoder', 'jais']:
            mlp_proj_w = mlp_proj_w.t().contiguous()  # transpose for Conv1D
        if gpt_variant in ['jais']:
            mlp_proj_w = pad_array_up_to(mlp_proj_w, 1, mapping.tp_size)
        if use_smooth_quant:
            mlp_proj_w_t = mlp_proj_w.t()
            int8_weights = generate_int8(mlp_proj_w_t,
                                         act_range.get(f'{prefix}.mlp.c_proj'))
            # Replace with the real smoother if SmoothQuant is applied to the proj layer.
            fake_smoother_value = torch.ones([1, 4 * hidden_size],
                                             dtype=torch.float32)
            weights.update(
                get_tllm_linear_sq_weight(
                    int8_weights,
                    f'{tllm_prex}.mlp.proj.', [1, hidden_size],
                    mapping.tp_size,
                    is_qkv=False,
                    per_token=per_token,
                    per_channel=per_channel,
                    last_prefix=f'{tllm_prex}.mlp.quantization_scaling_factor',
                    bias=mlp_proj_b,
                    smoother_value=fake_smoother_value,
                    smoother_shape=[1, 4 * hidden_size // mapping.tp_size],
                    rank=mapping.tp_rank,
                    cat_dim=0))
        else:
            mlp_proj_w = split(mlp_proj_w,
                               mapping.tp_rank,
                               mapping.tp_size,
                               is_column=False)
            weights.update(
                get_tllm_linear_weight(mlp_proj_w, f'{tllm_prex}.mlp.proj',
                                       mlp_proj_b, use_weight_only,
                                       plugin_weight_only_quant_type))

        # (5) Input layernorm
        apply_layernorm_1p = gpt_variant == 'nemotron'
        if gpt_variant in ['starcoder2', 'nemotron', 'persimmon']:
            input_ln_w, input_ln_b = get_weight_and_bias(
                model_params, f'{prefix}.input_layernorm', dtype)
        elif gpt_variant == 'kosmos-2':
            input_ln_w, input_ln_b = get_weight_and_bias(
                model_params, f'{prefix}.self_attn_layer_norm', dtype)
        else:
            input_ln_w, input_ln_b = get_weight_and_bias(
                model_params, f'{prefix}.ln_1', dtype)
        if apply_layernorm_1p:
            input_ln_w += 1.0
        weights[f'{tllm_prex}.input_layernorm.weight'] = input_ln_w
        if input_ln_b is not None:
            weights[f'{tllm_prex}.input_layernorm.bias'] = input_ln_b

        # (6) Post layernorm
        if gpt_variant in ['starcoder2', 'nemotron', 'persimmon']:
            post_ln_w, post_ln_b = get_weight_and_bias(
                model_params, f'{prefix}.post_attention_layernorm', dtype)
        elif gpt_variant == 'kosmos-2':
            post_ln_w, post_ln_b = get_weight_and_bias(
                model_params, f'{prefix}.final_layer_norm', dtype)
        else:
            post_ln_w, post_ln_b = get_weight_and_bias(model_params,
                                                       f'{prefix}.ln_2', dtype)
        if apply_layernorm_1p:
            post_ln_w += 1.0
        weights[f'{tllm_prex}.post_layernorm.weight'] = post_ln_w
        if post_ln_b is not None:
            weights[f'{tllm_prex}.post_layernorm.bias'] = post_ln_b

        if gpt_variant == 'persimmon':
            q_layernorm_w, q_layernorm_b = get_weight_and_bias(
                model_params, f'{prefix}.self_attn.q_layernorm', dtype)

            weights[f'{tllm_prex}.attention.q_layernorm.weight'] = q_layernorm_w
            weights[f'{tllm_prex}.attention.q_layernorm.bias'] = q_layernorm_b

            k_layernorm_w, k_layernorm_b = get_weight_and_bias(
                model_params, f'{prefix}.self_attn.k_layernorm', dtype)

            weights[f'{tllm_prex}.attention.k_layernorm.weight'] = k_layernorm_w
            weights[f'{tllm_prex}.attention.k_layernorm.bias'] = k_layernorm_b

        if gpt_variant == 'kosmos-2':
            q_layernorm_w, q_layernorm_b = get_weight_and_bias(
                model_params, f'{prefix}.self_attn.inner_attn_ln', dtype)

            weights[
                f'{tllm_prex}.attention.inner_layernorm.weight'] = q_layernorm_w
            weights[
                f'{tllm_prex}.attention.inner_layernorm.bias'] = q_layernorm_b

            k_layernorm_w, k_layernorm_b = get_weight_and_bias(
                model_params, f'{prefix}.ffn.ffn_layernorm', dtype)

            weights[f'{tllm_prex}.mlp.inner_layernorm.weight'] = k_layernorm_w
            weights[f'{tllm_prex}.mlp.inner_layernorm.bias'] = k_layernorm_b

    if mapping.is_first_pp_rank():
        if gpt_variant in ['starcoder2', 'nemotron']:
            embed_w = get_weight(model_params, 'model.embed_tokens', dtype)
        elif gpt_variant == 'kosmos-2':
            embed_w = get_weight(model_params, 'text_model.model.embed_tokens',
                                 dtype)
        elif gpt_variant == 'persimmon':
            embed_w = get_weight(model_params,
                                 ('language_model.' if is_fuyu else '') +
                                 'model.embed_tokens', dtype)
        else:
            embed_w = get_weight(model_params, 'transformer.wte', dtype)
        weights['transformer.vocab_embedding.weight'] = split_embedding(
            embed_w,
            mapping.tp_rank,
            mapping.tp_size,
            use_parallel_embedding=config.use_parallel_embedding,
            sharding_dim=config.embedding_sharding_dim)

        if gpt_variant == 'kosmos-2':
            padding_idx = hf_config.text_config.pad_token_id
            sin_pos_embedding = hf_model.text_model.model.embed_positions.get_embedding(
                padding_idx + 1 + hf_config.text_config.max_position_embeddings,
                hf_config.text_config.embed_dim,
                padding_idx=padding_idx)  # [2 + num_embeddings, embed_dim]
            pos_embed_w = sin_pos_embedding[2:].to(dtype).detach().cpu()
        else:
            pos_embed_w = get_weight(model_params, 'transformer.wpe', dtype)
        if pos_embed_w is not None:
            weights['transformer.position_embedding.weight'] = split_embedding(
                pos_embed_w,
                mapping.tp_rank,
                mapping.tp_size,
                use_parallel_embedding=config.use_parallel_embedding,
                sharding_dim=config.embedding_sharding_dim)

    if mapping.is_last_pp_rank():
        if gpt_variant in ['starcoder2', 'nemotron']:
            embed_w = get_weight(model_params, 'lm_head', dtype)
            if embed_w is None:
                embed_w = get_weight(model_params, 'model.embed_tokens', dtype)
        elif gpt_variant == 'persimmon':
            embed_w = get_weight(model_params,
                                 ('language_model.' if is_fuyu else '') +
                                 'lm_head', dtype)
        elif gpt_variant == 'kosmos-2':
            embed_w = get_weight(model_params, 'text_model.model.embed_tokens',
                                 dtype)
        else:
            embed_w = get_weight(model_params, 'transformer.wte', dtype)

        if vocab_size % mapping.tp_size != 0:
            vocab_size_padded = pad_vocab_size(vocab_size, mapping.tp_size)
            pad_width = vocab_size_padded - vocab_size
            embed_w = torch.nn.functional.pad(embed_w, (0, 0, 0, pad_width),
                                              value=0)
        if hasattr(hf_config, 'logits_scale'):
            embed_w *= hf_config.logits_scale
        weights['lm_head.weight'] = split(embed_w.clone(),
                                          mapping.tp_rank,
                                          mapping.tp_size,
                                          is_column=True)

        if gpt_variant in ['starcoder2', 'nemotron']:
            ln_f_w, ln_f_b = get_weight_and_bias(model_params, 'model.norm',
                                                 dtype)
        elif gpt_variant == 'persimmon':
            ln_f_w, ln_f_b = get_weight_and_bias(
                model_params, ('language_model.' if is_fuyu else '') +
                'model.final_layernorm', dtype)
        elif gpt_variant == 'kosmos-2':
            ln_f_w, ln_f_b = get_weight_and_bias(model_params,
                                                 'text_model.model.layer_norm',
                                                 dtype)
        else:
            ln_f_w, ln_f_b = get_weight_and_bias(model_params,
                                                 'transformer.ln_f', dtype)
        if apply_layernorm_1p:
            ln_f_w += 1.0
        weights['transformer.ln_f.weight'] = ln_f_w
        if ln_f_b is not None:
            weights['transformer.ln_f.bias'] = ln_f_b

    tok = time.time()
    t = time.strftime('%H:%M:%S', time.gmtime(tok - tik))
    print(f'Weights loaded. Total time: {t}')
    return weights


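# Usage sketch (assumes a rank-specific GPTConfig built elsewhere, e.g. via the
# config class's Hugging Face constructor, and an already loaded HF model):
#
#     hf_model = load_hf_gpt(model_dir)
#     weights = load_weights_from_hf_model(hf_model, config)
#     safetensors.torch.save_file(
#         weights, f'rank{config.mapping.rank}.safetensors')
#
# The returned dict maps TensorRT-LLM parameter names such as
# 'transformer.layers.0.attention.qkv.weight' or 'lm_head.weight' to tensors
# already split for this rank's tensor/pipeline parallelism.

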
def quantize(hf_model_dir: str,
             output_dir: str,
             config: GPTConfig,
             device: str = 'cuda',
             calib_dataset: str = 'cnn_dailymail',
             trust_remote_code: bool = True):
    os.makedirs(output_dir, exist_ok=True)
    config.to_json_file(os.path.join(output_dir, 'config.json'))

    mapping = config.mapping
    assert mapping.rank == 0, "quantize should be called at rank 0 only"

    quant_config = config.quantization
    use_smooth_quant = quant_config._use_plugin_sq
    int8_kv_cache = quant_config.kv_cache_quant_algo == QuantAlgo.INT8

    assert use_smooth_quant or int8_kv_cache, "Call from_hugging_face when there is no quantization"
    assert hf_model_dir is not None
    # Load the model and run the SmoothQuant calibration only once for all ranks.
    hf_model = AutoModelForCausalLM.from_pretrained(
        hf_model_dir,
        device_map='auto' if device != 'cpu' else 'cpu',
        torch_dtype='auto' if not use_smooth_quant else torch.float16,
        trust_remote_code=trust_remote_code)

    os.environ["TOKENIZERS_PARALLELISM"] = os.environ.get(
        "TOKENIZERS_PARALLELISM", "false")
    tokenizer = AutoTokenizer.from_pretrained(
        hf_model_dir,
        trust_remote_code=trust_remote_code,
        use_fast=False,
        padding_side='left')

    dataset = load_calib_dataset(calib_dataset)
    act_range = capture_activation_range(hf_model, tokenizer, dataset)
    if use_smooth_quant:
        smooth_gpt_model(hf_model, act_range, quant_config.smoothquant_val)

    for rank in range(mapping.world_size):
        # Deep-copy the config so the caller's rank-agnostic mapping is not
        # modified in place; quantize() runs on a single rank and iterates
        # over all ranks here.
        config = copy.deepcopy(config)
        config.set_rank(rank)
        weights = load_weights_from_hf_model(
            hf_model,
            config=config,
            act_range=act_range,
        )
        safetensors.torch.save_file(
            weights, os.path.join(output_dir, f'rank{rank}.safetensors'))
        del weights


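# Usage sketch (hypothetical paths; the quantization fields of `config` must
# already request SmoothQuant and/or an INT8 KV cache, otherwise the assertion
# above fires):
#
#     quantize('/models/gpt2-hf', '/ckpts/gpt2-sq', config,
#              calib_dataset='cnn_dailymail')
#
# This writes config.json plus one rank{N}.safetensors file per rank of
# config.mapping, all produced from a single calibrated HF model instance.

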
def load_hf_gpt(model_dir: str, load_model_on_cpu: bool = False):
    if 'kosmos-2' in model_dir:
        hf_model = AutoModelForVision2Seq.from_pretrained(
            model_dir, trust_remote_code=True)
    else:
        hf_model = AutoModelForCausalLM.from_pretrained(
            model_dir,
            device_map='auto' if not load_model_on_cpu else 'cpu',
            torch_dtype='auto',
            trust_remote_code=True,
        )
    return hf_model


def cpu_map_location(storage, loc):
    return storage.cpu()


def gpu_map_location(storage, loc):
    if loc.startswith("cuda"):
        training_gpu_idx = int(loc.split(":")[1])
        inference_gpu_idx = training_gpu_idx % torch.cuda.device_count()
        return storage.cuda(inference_gpu_idx)
    elif loc.startswith("cpu"):
        return storage.cpu()
    else:
        raise ValueError(f"Not handled {loc}")


def copy_tokenizer_files(config, out_dir):
    basenames = {
        "model": "tokenizer",
        "vocab_file": "vocab",
        "merge_file": "merges",
    }

    for key in basenames.keys():
        if config[key] is None:
            continue
        path = Path(config[key])
        if not path.exists():
            logger.debug(f"Tokenizer {key}: {path} file not found")
            continue

        dst_path = out_dir / f"{basenames[key]}{path.suffix}"
        logger.debug(f"Copy tokenizer {key}: {path}->{dst_path}")
        shutil.copy(path.as_posix(), dst_path.as_posix())


def update_tokenizer_paths(tokenizer_config: Dict,
                           tokenizer_file_paths: Dict[str, Optional[str]]):
    for key, new_path in tokenizer_file_paths.items():
        old_path = tokenizer_config[key]
        if old_path is None:
            continue
        old_path = Path(old_path)
        if new_path:
            logger.debug(f"Update tokenizer {key} {old_path} -> {new_path}")
            tokenizer_config[key] = new_path.as_posix()
        elif not old_path.exists():
            logger.warning(
                f"Tokenizer {key}'s path {old_path} does not exist; setting it to None"
            )
            tokenizer_config[key] = None
    return tokenizer_config


def unpack_nemo_ckpt(nemo_archive_path: Union[str, Path],
                     out_dir_path: Union[str, Path]):
    nemo_archive_path = Path(nemo_archive_path)
    if not nemo_archive_path.exists():
        raise FileNotFoundError(f"{nemo_archive_path} does not exist")

    for tar_mode in ["r:", "r:gz"]:
        try:
            with tarfile.open(nemo_archive_path, mode=tar_mode) as tar_file:

                def is_within_directory(directory, target):

                    abs_directory = os.path.abspath(directory)
                    abs_target = os.path.abspath(target)

                    prefix = os.path.commonprefix([abs_directory, abs_target])

                    return prefix == abs_directory

                def safe_members(tar_file):
                    members = []
                    for member in tar_file.getmembers():
                        member_path = os.path.join(out_dir_path, member.name)
                        if not is_within_directory(out_dir_path, member_path):
                            raise Exception(
                                "Attempted Path Traversal in Tar File")
                        members.append(member)
                    return members

                for member in safe_members(tar_file):
                    tar_file.extract(member,
                                     path=out_dir_path,
                                     numeric_owner=False,
                                     filter=tarfile.data_filter)

                return out_dir_path
        except tarfile.ReadError:
            pass

    raise RuntimeError(f"Could not unpack {nemo_archive_path}")


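# Usage sketch (hypothetical paths): unpack a .nemo archive (plain or gzipped
# tar) into a working directory, refusing members that would escape it, then
# wrap the result for inspection:
#
#     unpack_nemo_ckpt('gpt_2b.nemo', '/tmp/gpt_2b_unpacked')
#     ckpt_dir = UnpackedNemoCheckpointDir('/tmp/gpt_2b_unpacked',
#                                          load_checkpoints_to_cpu=True)

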
def extract_layers_with_prefix(model_, prefix):
    length_to_trim = len(prefix)
    model_state = model_.get("state_dict", model_)
    return {
        key[length_to_trim:]: model_state[key]
        for key in model_state.keys() if prefix in key
    }


class UnpackedNemoCheckpointDir:

    def __init__(self,
                 checkpoints_dir: Union[str, Path],
                 load_checkpoints_to_cpu: bool = False):
        self._checkpoints_dir = Path(checkpoints_dir)
        self._load_checkpoints_to_cpu = load_checkpoints_to_cpu

    @property
    @functools.lru_cache
    def model_config(self):
        model_config = None

        model_config_filename = "model_config.yaml"
        model_configs_paths = list(
            self._checkpoints_dir.rglob(model_config_filename))
        if model_configs_paths:
            if len(model_configs_paths) > 1:
                raise RuntimeError(
                    f"There is more than one {model_config_filename} "
                    f"in {self._checkpoints_dir}: {', '.join(map(lambda p: p.as_posix(), model_configs_paths))}"
                )
            model_config_path = model_configs_paths[0]
            logger.debug(f"Loading model config from {model_config_path}")
            with model_config_path.open("r") as model_config_file:
                model_config = yaml.load(model_config_file,
                                         Loader=yaml.SafeLoader)
        else:
            logger.debug("Searching model config in checkpoints")
            # try to obtain it from a checkpoint
            checkpoint_name = self.checkpoint_name
            checkpoints_paths = sorted(
                self._checkpoints_dir.rglob(checkpoint_name))
            if checkpoints_paths:
                # assume that the parallel-rank-0 checkpoint has the model config embedded
                checkpoint_path = checkpoints_paths[0]

                map_location_fn = cpu_map_location if self._load_checkpoints_to_cpu else gpu_map_location

                model_00 = torch.load(checkpoint_path,
                                      map_location=map_location_fn)
                if "hyper_parameters" in model_00 and "cfg" in model_00[
                        "hyper_parameters"]:
                    model_config = model_00["hyper_parameters"]["cfg"]
                    logger.debug(
                        f"Loaded model config from checkpoint {checkpoint_path}"
                    )
                else:
                    logger.debug(
                        f"Could not find model config in checkpoint {checkpoint_path}"
                    )
                del model_00

        if model_config is None:
            logger.warning(
                f"Could not find checkpoint with NeMo model config in {self._checkpoints_dir}"
            )

        logger.debug(f"Loaded model config {model_config}")

        return model_config

    @property
    def checkpoints_dir(self):
        return self._checkpoints_dir

    def get_checkpoints_paths(self,
                              tensor_model_parallel_size=1,
                              pipeline_model_parallel_size=1):
        """
        Injects tensor/pipeline model parallel ranks into the filepath.
        Does nothing if not using model parallelism.
        """

        checkpoint_path_without_rank = self.checkpoints_dir / self.checkpoint_name

        def _inject_parallel_ranks(tp_rank, pp_rank):
            if tensor_model_parallel_size > 1 or pipeline_model_parallel_size > 1:
                if pipeline_model_parallel_size is None or pipeline_model_parallel_size == 1:
                    checkpoint_path = (checkpoint_path_without_rank.parent /
                                       f"mp_rank_{tp_rank:02d}" /
                                       checkpoint_path_without_rank.name)
                else:
                    checkpoint_path = (
                        checkpoint_path_without_rank.parent /
                        f"tp_rank_{tp_rank:02d}_pp_rank_{pp_rank:03d}" /
                        checkpoint_path_without_rank.name)
                return checkpoint_path
            else:
                return checkpoint_path_without_rank

        return [[
            _inject_parallel_ranks(tp_rank=tp_rank, pp_rank=pp_rank)
            for pp_rank in range(pipeline_model_parallel_size)
        ] for tp_rank in range(tensor_model_parallel_size)]

    @property
    @functools.lru_cache
    def checkpoint_name(self):
        patterns = [
            "model_weights.ckpt",  # older megatron checkpoints
            "*last.ckpt",  # newer format of checkpoints
        ]
        for pattern in patterns:
            model_files = sorted(list(self._checkpoints_dir.rglob(pattern)))
            if model_files:
                return model_files[0].name

        raise ValueError(
            f"Could not find checkpoint files in {self._checkpoints_dir}")

    @functools.lru_cache
    def get_tokenizer_file_path(self, tokenizer_key, file_key,
                                default_filename_pattern):
        model_config = self.model_config
        file_property = None
        if tokenizer_key in model_config and file_key in model_config[
                tokenizer_key]:
            file_property = model_config[tokenizer_key][file_key]
        elif file_key in model_config:
            file_property = model_config[file_key]

        logger.debug(
            f"model_config[{tokenizer_key}][{file_key}]={file_property}")

        if file_property and file_property.startswith("nemo:"):
            filename = file_property.split("nemo:")[1]
            filename_pattern = f"*{filename}"
        elif file_property and file_property.startswith("/artifacts/"):
            filename = Path(file_property).name
            filename_pattern = f"*{filename}"
        elif file_property is None or file_property == "None":
            filename_pattern = None
        else:
            filename_pattern = default_filename_pattern
            logger.warning(
                f"Tokenizer file from config: {tokenizer_key}.{file_key}={file_property} "
                f"looks like an unsupported path. Pattern {filename_pattern} will be used."
            )

        file_path = None
        if filename_pattern is not None:
            files_paths = list(self._checkpoints_dir.glob(filename_pattern))
            if files_paths:
                assert len(files_paths) == 1
                file_path = files_paths[0]

        return file_path

    @functools.lru_cache
    def get_all_tokenizer_file_paths(self):
        return {
            "model":
            self.get_tokenizer_file_path("tokenizer", "model", "*.model"),
            "vocab_file":
            self.get_tokenizer_file_path("tokenizer", "vocab_file", "*vocab*"),
            "merge_file":
            self.get_tokenizer_file_path("tokenizer", "merge_file",
                                         "*merge*.txt"),
        }


@torch.no_grad()
def load_torch_checkpoints(checkpoints_paths,
                           merge_factor,
                           tp_rank,
                           pp_rank,
                           map_location_fn,
                           handle_model_level_weights,
                           layer_rename_config: Dict[str, str] = {}):
    models = []
    for k in range(merge_factor):
        rank_weights = checkpoints_paths[tp_rank * merge_factor + k][pp_rank]
        model = torch.load(rank_weights, map_location=map_location_fn)
        model = rename_keys(model, layer_rename_config)
        handle_model_level_weights(model, tp_rank * merge_factor + k, pp_rank)
        layers = extract_layers_with_prefix(model,
                                            "model.language_model.encoder.")
        models.append(layers)
    return models


@torch.no_grad()
def load_weights_from_nemo(nemo_ckpt_dir: str, config: GPTConfig, **kwargs):
    assert config.mapping.pp_size == 1, \
        "Pipeline parallelism is not supported."
    assert not config.quantization.quant_mode.has_any_quant(), \
        "Quantization is not supported."

    load_model_on_cpu = kwargs.pop('load_model_on_cpu', False)
    nemo_rename_key = kwargs.pop('nemo_rename_key', [])
    layer_rename_config = {
        pattern.split(':')[0]: pattern.split(':')[1]
        for pattern in nemo_rename_key
    }

    unpacked_checkpoints_dir = UnpackedNemoCheckpointDir(
        nemo_ckpt_dir, load_checkpoints_to_cpu=load_model_on_cpu)
    nemo_model_config = unpacked_checkpoints_dir.model_config

    checkpoints_paths = unpacked_checkpoints_dir.get_checkpoints_paths(
        nemo_model_config.get("tensor_model_parallel_size", 1),
        nemo_model_config.get("pipeline_model_parallel_size", 1),
    )

    if unpacked_checkpoints_dir._load_checkpoints_to_cpu:
        map_location_fn = cpu_map_location
    else:
        map_location_fn = gpu_map_location
    dtype = str_dtype_to_torch(config.dtype)

    # load position_embedding from rank 0
    model_00 = torch.load(checkpoints_paths[0][0], map_location=map_location_fn)
    model_00 = model_00.get("state_dict", model_00)
    model_00 = rename_keys(model_00, layer_rename_config)
    has_position_embedding = "model.language_model.embedding.position_embeddings.weight" in model_00
    has_lm_head = "model.language_model.output_layer.weight" in model_00
    del model_00

    num_layers = nemo_model_config["num_layers"]
    training_tp_size = nemo_model_config.get("tensor_model_parallel_size", 1)
    training_pp_size = nemo_model_config.get("pipeline_model_parallel_size", 1)
    inference_tp_size = config.mapping.tp_size
    inference_tp_rank = config.mapping.tp_rank

    apply_layernorm_1p = (nemo_model_config.get('normalization',
                                                '') == "layernorm1p")
    split_gated_activation = ("swiglu"
                              in nemo_model_config.get('activation', "gelu"))
    num_attention_heads = nemo_model_config["num_attention_heads"]
    # use_attention_nemo_shape = True
    transpose_weights = True
    # multi_query_mode = False
    local_dim = None

    # merge_factor: how many TP training nodes are merged into an inference TP node
    # split_factor: in how many parts a TP training node is split
    gcd = np.gcd(training_tp_size, inference_tp_size)
    merge_factor = training_tp_size // gcd
    split_factor = inference_tp_size // gcd

    model_level_weights = defaultdict(list)

    def handle_model_level_weights(model, tp_idx: int, pp_idx: int):
        if tp_idx == 0 and pp_idx == 0:
            if has_position_embedding:
                val = model[
                    "model.language_model.embedding.position_embeddings.weight"].detach(
                    ).cpu()
                model_level_weights[
                    "transformer.position_embedding.weight"].append(val)
        if pp_idx == 0:
            val = model.get(
                "state_dict", model
            )["model.language_model.embedding.word_embeddings.weight"].detach(
            ).cpu()
            model_level_weights["transformer.vocab_embedding.weight"].append(
                val)
        if has_lm_head and pp_idx == training_pp_size - 1:
            val = model.get(
                "state_dict",
                model)["model.language_model.output_layer.weight"].detach().cpu(
                )
            model_level_weights["lm_head.weight"].append(val)

    weights = {}
    tik = time.time()
    tp_rank = inference_tp_rank // split_factor
    # for tp_rank in range(training_tp_size // merge_factor):
    for pp_rank in range(training_pp_size):
        models = load_torch_checkpoints(checkpoints_paths, merge_factor,
                                        tp_rank, pp_rank, map_location_fn,
                                        handle_model_level_weights,
                                        layer_rename_config)
        for name in list(models[0].keys()):
            params = [model[name].detach().cpu() for model in models]
            if transpose_weights and params[0].ndim == 2:
                params = [p.T for p in params]
            if "layernorm.weight" in name and apply_layernorm_1p:
                params = [p + 1.0 for p in params]

            l = retrieved_layer_index_from_name(name)
            if l is not None:
                new_l = l + pp_rank * num_layers // training_pp_size
                prefix = f'transformer.layers.{new_l}'

                if 'attention.query_key_value' in name:
                    if name.endswith('weight'):
                        hidden_dim = params[0].shape[0]
                        if local_dim is None:
                            local_dim = params[0].shape[-1] // 3

                        # multi_query_mode = False; use_attention_nemo_shape = True
                        head_num = num_attention_heads // training_tp_size
                        size_per_head = hidden_dim // num_attention_heads
                        params = [
                            param.reshape(hidden_dim, head_num, 3,
                                          size_per_head) for param in params
                        ]
                        params = [param.permute(0, 2, 1, 3) for param in params]
                        params = [
                            param.reshape(hidden_dim, 3, local_dim)
                            for param in params
                        ]
                        cat_dim = -1
                        param = torch.concat(params, dim=cat_dim)
                        param = torch.chunk(param, split_factor,
                                            dim=cat_dim)[inference_tp_rank %
                                                         split_factor]
                        weights[
                            f'{prefix}.attention.qkv.weight'] = param.reshape(
                                hidden_dim, -1).t()
                    else:
                        if local_dim is None:
                            local_dim = params[0].shape[-1] // 3

                        # multi_query_mode = False; use_attention_nemo_shape = True
                        head_num = num_attention_heads // training_tp_size
                        size_per_head = local_dim // head_num
                        params = [
                            param.reshape(head_num, 3, size_per_head)
                            for param in params
                        ]
                        params = [param.permute(1, 0, 2) for param in params]
                        params = [
                            param.reshape(3, local_dim) for param in params
                        ]
                        cat_dim = -1
                        param = torch.concat(params, dim=cat_dim)
                        param = torch.chunk(param, split_factor,
                                            dim=cat_dim)[inference_tp_rank %
                                                         split_factor]
                        weights[f'{prefix}.attention.qkv.bias'] = param.reshape(
                            -1)

                elif 'attention.dense' in name:
                    if name.endswith('weight'):
                        cat_dim = 0
                        param = torch.concat(params, dim=cat_dim)
                        param = torch.chunk(param, split_factor,
                                            dim=cat_dim)[inference_tp_rank %
                                                         split_factor]
                        weights[f'{prefix}.attention.dense.weight'] = param.t()
                    else:
                        weights[f'{prefix}.attention.dense.bias'] = params[0]

                elif 'mlp.dense_h_to_4h' in name:
                    if name.endswith('weight'):
                        if split_gated_activation:
                            params = [torch.chunk(p, 2, dim=-1) for p in params]
                            params, gate_params = list(zip(*params))
                        cat_dim = -1
                        param = torch.concat(params, dim=cat_dim)
                        param = torch.chunk(param, split_factor,
                                            dim=cat_dim)[inference_tp_rank %
                                                         split_factor]
                        weights[f'{prefix}.mlp.fc.weight'] = param.t()
                        if split_gated_activation:
                            gate_param = torch.concat(gate_params, dim=cat_dim)
                            gate_param = torch.chunk(
                                gate_param, split_factor,
                                dim=cat_dim)[inference_tp_rank % split_factor]
                            weights[f'{prefix}.mlp.gate.weight'] = gate_param.t(
                            )
                    else:
                        if split_gated_activation:
                            params = [torch.chunk(p, 2, dim=-1) for p in params]
                            params, gate_params = list(zip(*params))
                        cat_dim = -1
                        param = torch.concat(params, dim=cat_dim)
                        param = torch.chunk(param, split_factor,
                                            dim=cat_dim)[inference_tp_rank %
                                                         split_factor]
                        weights[f'{prefix}.mlp.fc.bias'] = param
                        if split_gated_activation:
                            gate_param = torch.concat(gate_params, dim=cat_dim)
                            gate_param = torch.chunk(
                                gate_param, split_factor,
                                dim=cat_dim)[inference_tp_rank % split_factor]
                            weights[f'{prefix}.mlp.gate.bias'] = gate_param

                elif 'mlp.dense_4h_to_h' in name:
                    if name.endswith('weight'):
                        cat_dim = 0
                        param = torch.concat(params, dim=cat_dim)
                        param = torch.chunk(param, split_factor,
                                            dim=cat_dim)[inference_tp_rank %
                                                         split_factor]
                        weights[f'{prefix}.mlp.proj.weight'] = param.t()
                    else:
                        weights[f'{prefix}.mlp.proj.bias'] = params[0]

                elif 'input_layernorm' in name:
                    if name.endswith('weight'):
                        weights[f'{prefix}.input_layernorm.weight'] = params[0]
                    else:
                        weights[f'{prefix}.input_layernorm.bias'] = params[0]
                elif 'post_attention_layernorm' in name:
                    if name.endswith('weight'):
                        weights[f'{prefix}.post_layernorm.weight'] = params[0]
                    else:
                        weights[f'{prefix}.post_layernorm.bias'] = params[0]

            elif 'final_layernorm' in name:
                if name.endswith('weight'):
                    weights['transformer.ln_f.weight'] = params[0]
                else:
                    weights['transformer.ln_f.bias'] = params[0]
            for model in models:
                del model[name]
        del models

    for key in list(model_level_weights.keys()):
        weights[key] = torch.concat(model_level_weights[key], dim=0)
        weights[key] = torch.chunk(weights[key], split_factor,
                                   dim=0)[inference_tp_rank % split_factor]
        del model_level_weights[key]
    for key, param in weights.items():
        weights[key] = weights[key].to(dtype).contiguous()

    tok = time.time()
    t = time.strftime('%H:%M:%S', time.gmtime(tok - tik))
    print(f'Weights loaded. Total time: {t}')
    return weights


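# Usage sketch (hypothetical checkpoint directory and renaming rule): convert
# an unpacked NeMo checkpoint for a single rank; renaming rules use the
# "old:new" syntax parsed above.
#
#     weights = load_weights_from_nemo(
#         '/tmp/gpt_2b_unpacked', config,
#         load_model_on_cpu=True,
#         nemo_rename_key=['attention.linear_qkv:attention.query_key_value'])
#
# Pipeline parallelism and quantization are rejected up front; a mismatch
# between training and inference tensor parallelism is handled through the
# merge_factor / split_factor logic above.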