mirror of
https://github.com/NVIDIA/TensorRT-LLM.git
synced 2026-01-14 06:27:45 +08:00
* Update TensorRT-LLM --------- Co-authored-by: Denis Kayshev <topenkoff@gmail.com> Co-authored-by: akhoroshev <arthoroshev@gmail.com> Co-authored-by: Patrick Reiter Horn <patrick.horn@gmail.com> Update
649 lines
26 KiB
Python
649 lines
26 KiB
Python
import copy
|
|
import functools
|
|
import time
|
|
from collections import defaultdict
|
|
from pathlib import Path
|
|
|
|
import numpy as np
|
|
import torch
|
|
import torch.nn as nn
|
|
from tqdm import tqdm
|
|
from transformers.models.llama.modeling_llama import LlamaDecoderLayer
|
|
from transformers.pytorch_utils import Conv1D
|
|
|
|
from tensorrt_llm import logger
|
|
from tensorrt_llm._utils import str_dtype_to_torch
|
|
from tensorrt_llm.mapping import Mapping
|
|
from tensorrt_llm.models.convert_utils import (dup_kv_weight, generate_int8,
|
|
smooth_gemm,
|
|
smooth_gemm_fc1_gate, split,
|
|
split_matrix_tp, split_qkv_tp)
|
|
|
|
|
|
def get_tllm_linear_weight(weight,
                           prefix,
                           bias=None,
                           use_weight_only=False,
                           plugin_weight_only_quant_type=torch.int8,
                           postfix='weight'):
    """Build the checkpoint entries for one linear layer.

    Args:
        weight: Weight tensor of the layer.
        prefix: Key prefix; every emitted key is ``prefix + <suffix>``.
        bias: Optional bias, stored as-is under ``prefix + 'bias'``.
        use_weight_only: When true, the transposed weight is moved to CPU and
            quantized with
            ``torch.ops.trtllm.symmetric_quantize_last_axis_of_batched_matrix``;
            the scales it returns are stored under
            ``prefix + 'per_channel_scale'``.
        plugin_weight_only_quant_type: Quantization type forwarded to the
            quantization op.
        postfix: Suffix used for the weight key.

    Returns:
        A dict mapping key names to tensors.
    """
    if use_weight_only:
        transposed = weight.t().contiguous().cpu()
        quantized, scales = (
            torch.ops.trtllm.symmetric_quantize_last_axis_of_batched_matrix(
                transposed, plugin_weight_only_quant_type))
        results = {
            prefix + postfix: quantized,
            prefix + 'per_channel_scale': scales,
        }
    else:
        # Unquantized path: only make sure the tensor is densely laid out.
        results = {prefix + postfix: weight.contiguous()}

    if bias is not None:
        results[prefix + 'bias'] = bias
    return results
|
|
|
|
|
|
def load_medusa_hf(medusa_path: str,
                   num_medusa_heads: int,
                   num_medusa_layers: int,
                   mapping=Mapping(),
                   dtype='float32',
                   use_weight_only=False,
                   plugin_weight_only_quant_type=None,
                   is_modelopt_ckpt=False):
    """Load Medusa head weights and rename them to TensorRT-LLM key names.

    Args:
        medusa_path: Directory holding the Medusa checkpoint. For a ModelOpt
            checkpoint every ``*.safetensors`` file in it is merged; otherwise
            ``medusa_lm_head.pt`` is used, falling back to
            ``medusa_lm_head.safetensors`` when the ``.pt`` file is absent.
        num_medusa_heads: Number of Medusa heads to read.
        num_medusa_layers: Number of linear layers per head. The per-head LM
            head is read from index ``num_medusa_layers``.
        mapping: Provides ``tp_size`` / ``tp_rank`` used to split each tensor.
        dtype: String dtype the weights and biases are cast to.
        use_weight_only: Forwarded to ``get_tllm_linear_weight``.
        plugin_weight_only_quant_type: Forwarded to
            ``get_tllm_linear_weight``.
        is_modelopt_ckpt: When true, source keys are prefixed with
            ``medusa_heads.`` and ``input_scale`` / ``weight_scale`` entries
            are exported as float32 scaling factors.

    Returns:
        A dict mapping TensorRT-LLM weight names to tensors.

    Raises:
        KeyError: If an expected tensor is missing from the checkpoint.
    """
    # logger.info("Loading Medusa heads' weights ...")

    if is_modelopt_ckpt:
        from safetensors.torch import load_file
        state_dict = {}
        for filename in sorted(Path(medusa_path).glob("*.safetensors")):
            print(f"Loading the weights of Medusa heads from {filename}")
            state_dict.update(load_file(filename))
    else:
        ckpt_file = Path(medusa_path) / "medusa_lm_head.pt"
        if ckpt_file.exists():
            # NOTE(review): torch.load unpickles the file; only use it on
            # checkpoints from a trusted source.
            state_dict = torch.load(ckpt_file, map_location="cpu")
        else:
            ckpt_file = Path(medusa_path) / "medusa_lm_head.safetensors"
            logger.info("Safetensors Found ...")
            from safetensors.torch import load_file
            state_dict = load_file(ckpt_file)

    torch_dtype = str_dtype_to_torch(dtype)
    scaling_dtype = str_dtype_to_torch("float32") if is_modelopt_ckpt else None
    prefix = "medusa_heads." if is_modelopt_ckpt else ""
    weights = {}

    def fetch(key, target_dtype):
        # Clone so the cast never aliases the loaded checkpoint tensor.
        return state_dict[key].clone().to(target_dtype)

    for h in range(num_medusa_heads):
        for l in range(num_medusa_layers):
            src = f"{prefix}{h}.{l}.linear"
            dst = f'medusa_heads.{h}.medusa_layers.{l}.linear.'

            w = fetch(f"{src}.weight", torch_dtype)
            split_v = split(w, mapping.tp_size, mapping.tp_rank)
            weights.update(
                get_tllm_linear_weight(split_v, dst, None, use_weight_only,
                                       plugin_weight_only_quant_type))

            b = fetch(f"{src}.bias", torch_dtype)
            weights[dst + 'bias'] = split(b, mapping.tp_size, mapping.tp_rank)

            # Scaling factors are exported for every layer. Previously this
            # ran after the layer loop and reused the leaked loop variable,
            # so only the last layer of each head received its factors.
            if is_modelopt_ckpt:
                weights[dst + 'activation_scaling_factor'] = fetch(
                    f"{src}.input_scale", scaling_dtype)
                weights[dst + 'weights_scaling_factor'] = fetch(
                    f"{src}.weight_scale", scaling_dtype)

        # LM head of this Medusa head.
        lm_src = f"{prefix}{h}.{num_medusa_layers}"
        lm = fetch(f"{lm_src}.weight", torch_dtype)
        weights[f'medusa_heads.{h}.lm_head.weight'] = split(
            lm, mapping.tp_size, mapping.tp_rank)

        if is_modelopt_ckpt:
            weights[f'medusa_heads.{h}.lm_head.activation_scaling_factor'] = \
                fetch(f"{lm_src}.input_scale", scaling_dtype)
            weights[f'medusa_heads.{h}.lm_head.weights_scaling_factor'] = \
                fetch(f"{lm_src}.weight_scale", scaling_dtype)

    return weights
|
|
|
|
|
|
@torch.no_grad()
def smooth_llama_model(model, scales, alpha, llama_qkv_para, llama_smoother):
    """Run the smoothing helpers over every ``LlamaDecoderLayer`` of ``model``.

    For each decoder layer the fused QKV projection, the attention output
    projection, the gate/up MLP pair and the MLP down projection are passed to
    ``smooth_gemm`` / ``smooth_gemm_fc1_gate``. The ``"x"`` entries of
    ``scales`` are divided by the returned smoother and the ``"w"`` entries
    are set to the per-row absolute maximum of the corresponding weight.

    Args:
        model: Module tree that is walked with ``named_modules()``.
        scales: Mapping from module name to a dict with ``"x"``, ``"y"`` and
            ``"w"`` entries; updated in place.
            NOTE(review): an entry for ``<layer>.self_attn.qkv_proj`` is
            written without being created first, so this presumably has to be
            a defaultdict - confirm against the caller.
        alpha: Smoothing strength forwarded to the smoothing helpers.
        llama_qkv_para: Output dict; receives the transposed fused QKV weight
            per layer.
        llama_smoother: Output dict; receives the float smoother of the
            ``o_proj`` and ``down_proj`` layers.
    """

    def row_abs_max(matrix):
        return matrix.abs().max(dim=1)[0]

    # Smooth the activation and weights with smoother = $\diag{s}$
    for name, module in model.named_modules():
        if not isinstance(module, LlamaDecoderLayer):
            continue

        attn = module.self_attn
        mlp = module.mlp

        # ---- fused qkv_proj --------------------------------------------
        q_key = f"{name}.self_attn.q_proj"
        k_key = f"{name}.self_attn.k_proj"
        v_key = f"{name}.self_attn.v_proj"
        qkv_key = f"{name}.self_attn.qkv_proj"

        fused_qkv = torch.cat(
            [attn.q_proj.weight, attn.k_proj.weight, attn.v_proj.weight],
            dim=0)

        # The "w" statistic and the stored QKV weight are taken from
        # ``fused_qkv`` only after this call, so the order must not change.
        qkv_smoother = smooth_gemm(fused_qkv, scales[q_key]["x"],
                                   module.input_layernorm.weight, None, alpha)

        scales[qkv_key]["x"] = scales[q_key]["x"] / qkv_smoother
        scales[qkv_key]["w"] = row_abs_max(fused_qkv)
        scales[qkv_key]["y"] = torch.cat(
            [scales[q_key]["y"], scales[k_key]["y"], scales[v_key]["y"]],
            dim=0)

        # see transpose_weights function
        llama_qkv_para[qkv_key] = fused_qkv.transpose(0, 1)

        # ---- attention output projection ---------------------------------
        o_key = f"{name}.self_attn.o_proj"
        o_smoother = smooth_gemm(attn.o_proj.weight, scales[o_key]["x"], None,
                                 None, alpha)
        llama_smoother[o_key] = o_smoother.float()

        scales[o_key]["x"] = scales[o_key]["x"] / o_smoother
        scales[o_key]["w"] = row_abs_max(attn.o_proj.weight)

        # ---- MLP gate_proj / up_proj (smoothed together) ------------------
        fc1_key = f"{name}.mlp.gate_proj"
        gate_key = f"{name}.mlp.up_proj"

        mlp_smoother = smooth_gemm_fc1_gate(
            mlp.gate_proj.weight, mlp.up_proj.weight, scales[fc1_key]["x"],
            module.post_attention_layernorm.weight, None, alpha)

        scales[fc1_key]["x"] = scales[fc1_key]["x"] / mlp_smoother
        scales[fc1_key]["w"] = row_abs_max(mlp.gate_proj.weight)

        scales[gate_key]["x"] = scales[gate_key]["x"] / mlp_smoother
        scales[gate_key]["w"] = row_abs_max(mlp.up_proj.weight)

        # ---- MLP down projection -----------------------------------------
        down_key = f"{name}.mlp.down_proj"
        down_smoother = smooth_gemm(mlp.down_proj.weight,
                                    scales[down_key]["x"], None, None, alpha)
        llama_smoother[down_key] = down_smoother.float()
        scales[down_key]["x"] = scales[down_key]["x"] / down_smoother
        scales[down_key]["w"] = row_abs_max(mlp.down_proj.weight)
|
|
|
|
|
|
@torch.no_grad()
def capture_activation_range(model,
                             tokenizer,
                             dataset,
                             num_samples=512,
                             seq_len=512):
    """Collect per-channel absolute maxima of linear-layer inputs and outputs.

    Forward hooks are attached to every ``nn.Linear`` / ``Conv1D`` module,
    ``num_samples`` prompts from ``dataset`` are run through the model, and
    the hooks are removed again (also when calibration raises).

    Args:
        model: Model to calibrate; it is switched to eval mode.
        tokenizer: Tokenizer used to encode each prompt. Its ``pad_token`` is
            overwritten with its ``eos_token``.
        dataset: Sliceable collection of prompt strings; ``dataset[i:i + 1]``
            must yield a one-element mutable sequence.
        num_samples: Number of prompts to run.
        seq_len: ``max_length`` passed to the tokenizer (with truncation).

    Returns:
        A defaultdict mapping module name to
        ``{"x": input max, "y": output max, "w": per-row weight max}``.
    """
    model.eval()
    device = next(model.parameters()).device
    act_scales = defaultdict(lambda: {"x": None, "y": None, "w": None})

    tokenizer.pad_token = tokenizer.eos_token

    def stat_tensor(name, tensor, act_scales, key):
        hidden_dim = tensor.shape[-1]
        # reshape (not view) so non-contiguous activations are accepted too.
        tensor = tensor.reshape(-1, hidden_dim).abs().detach()
        incoming_max = torch.max(tensor, dim=0)[0].float()

        previous = act_scales[name][key]
        if previous is None:
            act_scales[name][key] = incoming_max
        else:
            act_scales[name][key] = torch.max(previous, incoming_max)

    def stat_input_hook(m, x, y, name):
        if isinstance(x, tuple):
            x = x[0]
        stat_tensor(name, x, act_scales, "x")
        stat_tensor(name, y, act_scales, "y")

        # The weight statistic is recorded once, on the first call.
        if act_scales[name]["w"] is None:
            act_scales[name]["w"] = m.weight.abs().clip(1e-8,
                                                        None).max(dim=1)[0]

    hooks = []
    try:
        for name, m in model.named_modules():
            if isinstance(m, (nn.Linear, Conv1D)):
                hooks.append(
                    m.register_forward_hook(
                        functools.partial(stat_input_hook, name=name)))

        for i in tqdm(range(num_samples), desc="calibrating model"):
            datapoint = dataset[i:i + 1]
            line = copy.copy(datapoint)
            line[0] = (line[0] + ' TL;DR: ').strip().replace(" n't", "n't")
            input_ids = tokenizer(line,
                                  return_tensors="pt",
                                  max_length=seq_len,
                                  padding=True,
                                  truncation=True).input_ids.to(device)
            model(input_ids)
    finally:
        # Always detach the hooks so a failed calibration run does not leave
        # them on the model.
        for h in hooks:
            h.remove()
    return act_scales
|
|
|
|
|
|
def get_weight(config, prefix, dtype):
    """Return ``config[prefix + '.weight']``, cast in place to ``dtype``.

    When the stored tensor has a different dtype, its ``.data`` is replaced
    by the converted tensor, so the object held in ``config`` is the one that
    is modified and returned.
    """
    key = prefix + '.weight'
    tensor = config[key]
    if tensor.dtype != dtype:
        tensor.data = tensor.to(dtype)
    return config[key]
|
|
|
|
|
|
def get_bias(config, prefix, dtype):
    """Return ``config[prefix + '.bias']``, cast in place to ``dtype``.

    When the stored tensor has a different dtype, its ``.data`` is replaced
    by the converted tensor, so the object held in ``config`` is the one that
    is modified and returned.
    """
    key = prefix + '.bias'
    tensor = config[key]
    if tensor.dtype != dtype:
        tensor.data = tensor.to(dtype)
    return config[key]
|
|
|
|
|
|
def get_weight_and_bias(config, prefix, dtype):
    """Return the ``(weight, bias)`` pair stored under ``prefix`` in ``config``.

    Both tensors go through ``get_weight`` / ``get_bias`` and are therefore
    cast to ``dtype``.
    """
    weight = get_weight(config, prefix, dtype)
    bias = get_bias(config, prefix, dtype)
    return weight, bias
|
|
|
|
|
|
def get_tllm_linear_sq_weight(vals,
                              prefix,
                              shape,
                              tensor_parallel,
                              is_qkv=False,
                              per_token=False,
                              per_channel=False,
                              last_prefix=None,
                              bias=None,
                              smoother_value=None,
                              smoother_shape=None,
                              rank=0,
                              cat_dim=0,
                              multi_query_mode=False):
    """Assemble the INT8 entries of one linear layer for a given TP rank.

    Args:
        vals: Dict of quantization results. The per-token path reads
            ``"weight.int8.col"`` and ``"scale_w_quant_orig.col"``; the other
            path reads ``"weight.int8"``, ``"scale_y_accum_quant"``,
            ``"scale_x_orig_quant"`` and ``"scale_y_quant_orig"``.
        prefix: Key prefix of the emitted entries.
        shape: Shape of ``per_channel_scale`` when ``is_qkv`` or
            ``per_channel`` is set; otherwise ``[1, 1]`` is used.
        tensor_parallel: Number of shards the arrays are split into.
        is_qkv: Flatten the local weight to ``(rows, -1)`` before transposing.
        per_token: Selects which of the two paths above is taken.
        per_channel: See ``shape``.
        last_prefix: Full key under which the input scale is stored.
        bias: Optional bias stored unchanged under ``prefix + 'bias'``.
        smoother_value: Optional smoother; its shard for ``rank`` is stored
            as float32 under ``prefix + 'smoother'``.
        smoother_shape: Shape the smoother shard is reshaped to.
        rank: Index of the shard to keep.
        cat_dim: Axis along which arrays are split.
        multi_query_mode: On the per-token path, split the q, k and v column
            groups separately instead of splitting the whole array.

    Returns:
        A dict mapping key names to tensors.
    """

    def multi_query_split(data, local_dim, head_size, tp_size, cur_rank):
        # Cut the last axis into q | k | v, shard each part on its own and
        # glue the pieces belonging to ``cur_rank`` back together.
        q, k, v = np.split(data, [local_dim, local_dim + head_size], axis=-1)
        pieces = [
            np.split(part, tp_size, axis=-1)[cur_rank] for part in (q, k, v)
        ]
        return np.concatenate(pieces, axis=-1)

    def shard(array):
        return np.split(array, tensor_parallel, axis=cat_dim)[rank]

    def to_weight_tensor(local):
        if is_qkv:
            local = local.reshape(local.shape[0], -1)
        return torch.from_numpy(local).t().contiguous()

    results = {}
    col_shape = shape if (is_qkv or per_channel) else [1, 1]
    has_smoother = smoother_value is not None

    if per_token:
        full_weights = vals["weight.int8.col"]

        local_dim = full_weights.shape[0]
        head_size = (full_weights.shape[1] - local_dim) // 2
        if multi_query_mode:
            local_weights = multi_query_split(full_weights, local_dim,
                                              head_size, tensor_parallel, rank)
        else:
            local_weights = shard(full_weights)
        results[prefix + 'weight'] = to_weight_tensor(local_weights)

        if has_smoother:
            # With a smoother the column scales are used unsplit.
            local_scale = vals["scale_w_quant_orig.col"]
        else:
            results[last_prefix] = torch.from_numpy(
                np.array([1.0], dtype=np.float32))
            if multi_query_mode:
                local_scale = multi_query_split(vals["scale_w_quant_orig.col"],
                                                local_dim, head_size,
                                                tensor_parallel, rank)
            else:
                local_scale = shard(vals["scale_w_quant_orig.col"])

        scale_array = np.array(local_scale, dtype=np.float32)
        results[prefix + 'per_channel_scale'] = torch.from_numpy(
            scale_array.reshape(col_shape)).contiguous()
    else:
        full_weights = np.array(vals["weight.int8"])
        results[prefix + 'weight'] = to_weight_tensor(shard(full_weights))

        accum_scale = np.array([vals["scale_y_accum_quant"]], dtype=np.float32)
        results[prefix + 'per_channel_scale'] = torch.from_numpy(
            accum_scale.reshape(col_shape)).contiguous()

        input_scale = np.array([vals['scale_x_orig_quant']], dtype=np.float32)
        results[last_prefix] = torch.from_numpy(input_scale).contiguous()

        output_scale = np.array([[vals["scale_y_quant_orig"]]],
                                dtype=np.float32)
        results[prefix + 'act_scale'] = torch.from_numpy(
            output_scale).contiguous()

    if has_smoother:
        local_smoother = np.split(smoother_value, tensor_parallel,
                                  axis=cat_dim)[rank]
        results[prefix + 'smoother'] = local_smoother.reshape(
            smoother_shape).contiguous().to(torch.float32)

    if bias is not None:
        results[prefix + 'bias'] = bias

    return results
|
|
|
|
|
|
def convert_hf_llama(hf_model,
                     mapping,
                     rank=0,
                     dtype='float32',
                     use_parallel_embedding=False,
                     sharding_dim=0,
                     use_weight_only=False,
                     plugin_weight_only_quant_type=torch.int8,
                     use_smooth_quant=False,
                     per_channel=False,
                     per_token=False,
                     int8_kv_cache=False,
                     act_range=None,
                     qkv_para=None,
                     smoother=None,
                     lora_config=None):
    """Convert a Hugging Face LLaMA model into a TensorRT-LLM weight dict.

    Args:
        hf_model: Model whose ``named_parameters()`` and ``config`` are read.
        mapping: Supplies ``tp_size``, ``tp_rank``, ``pp_layers`` and the
            first/last pipeline-rank predicates.
        rank: Not used by this function; ``mapping.tp_rank`` is used instead.
        dtype: Name of a ``torch`` dtype attribute the weights are cast to.
        use_parallel_embedding: Split the vocabulary embedding across TP ranks.
        sharding_dim: Dimension along which the embedding is split.
        use_weight_only: Forwarded to ``get_tllm_linear_weight``.
        plugin_weight_only_quant_type: Forwarded to ``get_tllm_linear_weight``.
        use_smooth_quant: Emit INT8 weights through ``generate_int8`` /
            ``get_tllm_linear_sq_weight`` instead of plain weights.
        per_channel: Forwarded to ``get_tllm_linear_sq_weight``.
        per_token: Forwarded to ``get_tllm_linear_sq_weight``.
        int8_kv_cache: Also emit ``attention.kv_cache_scaling_factor`` per
            layer, computed from the ``"y"`` entries of ``act_range``.
        act_range: Mapping from HF module name to activation statistics;
            required for smooth-quant and INT8 KV cache. Defaults to empty.
        qkv_para: Mapping from ``<layer>.self_attn.qkv_proj`` to the fused QKV
            weight; required for smooth-quant. Defaults to empty.
        smoother: Mapping from ``o_proj`` / ``down_proj`` module names to
            their smoother; required for smooth-quant. Defaults to empty.
        lora_config: Not used by this function.

    Returns:
        A dict mapping TensorRT-LLM weight names to tensors.
    """
    # None replaces the former mutable list defaults; the code below only
    # ever uses these three arguments as mappings.
    act_range = {} if act_range is None else act_range
    qkv_para = {} if qkv_para is None else qkv_para
    smoother = {} if smoother is None else smoother

    weights = {}
    tik = time.time()
    tensor_parallel = mapping.tp_size
    model_params = dict(hf_model.named_parameters())
    dtype = getattr(torch, dtype)
    num_attention_heads = hf_model.config.num_attention_heads
    hidden_size = hf_model.config.hidden_size
    intermediate_size = hf_model.config.intermediate_size
    num_key_value_heads = hf_model.config.num_key_value_heads
    mha_mode = (num_key_value_heads == num_attention_heads)

    num_hidden_layers = hf_model.config.num_hidden_layers
    layers_range = mapping.pp_layers(num_hidden_layers)
    for l in layers_range:
        # TensorRT-LLM layer indices are local to this pipeline stage.
        layer_idx = l - layers_range[0]
        prefix = f'model.layers.{l}.'
        tllm_prex = f'transformer.layers.{layer_idx}.'

        # ---- attention QKV --------------------------------------------
        q_weight = get_weight(model_params, prefix + 'self_attn.q_proj', dtype)
        k_weight = get_weight(model_params, prefix + 'self_attn.k_proj', dtype)
        v_weight = get_weight(model_params, prefix + 'self_attn.v_proj', dtype)

        if not mha_mode:
            head_size = hidden_size // num_attention_heads
            if num_key_value_heads < tensor_parallel:
                # duplicate the KV heads up to tensor_parallel
                k_weight = dup_kv_weight(k_weight, num_key_value_heads,
                                         tensor_parallel)
                v_weight = dup_kv_weight(v_weight, num_key_value_heads,
                                         tensor_parallel)
            assert (k_weight.shape[0] % (mapping.tp_size * head_size)) == 0
            assert (v_weight.shape[0] % (mapping.tp_size * head_size)) == 0

            wq = split(q_weight, mapping.tp_size, mapping.tp_rank)
            wk = split(k_weight, mapping.tp_size, mapping.tp_rank)
            wv = split(v_weight, mapping.tp_size, mapping.tp_rank)

            split_v = torch.concat((wq, wk, wv))

        else:
            qkv_weight = torch.cat([q_weight, k_weight, v_weight], dim=0)

            split_v = split_qkv_tp(qkv_weight, num_attention_heads, hidden_size,
                                   tensor_parallel, mapping.tp_rank)
        if use_smooth_quant:
            qkv_weight = qkv_para[prefix + 'self_attn.qkv_proj']

            if not mha_mode:
                hidden_size = qkv_weight.shape[0]
                local_dim = hidden_size
                head_size = (qkv_weight.shape[-1] - local_dim) // 2
                qkv_weight = qkv_weight.reshape(hidden_size,
                                                local_dim + 2 * head_size)
            else:
                qkv_weight = qkv_weight.reshape(hidden_size, 3, hidden_size)

            int8_weights = generate_int8(qkv_weight,
                                         act_range.get(prefix +
                                                       'self_attn.qkv_proj'),
                                         is_qkv=True,
                                         multi_query_mode=bool(not mha_mode))

            weights.update(
                get_tllm_linear_sq_weight(
                    int8_weights,
                    tllm_prex + 'attention.qkv.', [
                        1, 3 * hidden_size // tensor_parallel
                        if mha_mode else hidden_size // tensor_parallel +
                        (hidden_size // num_key_value_heads) //
                        tensor_parallel * 2
                    ],
                    tensor_parallel,
                    is_qkv=True,
                    per_token=per_token,
                    per_channel=per_channel,
                    last_prefix=tllm_prex + 'input_layernorm.scale_to_int',
                    smoother_value=None,
                    smoother_shape=None,
                    rank=mapping.tp_rank,
                    cat_dim=-1,
                    multi_query_mode=bool(not mha_mode)))
        else:
            weights.update(
                get_tllm_linear_weight(split_v, tllm_prex + 'attention.qkv.',
                                       None, use_weight_only,
                                       plugin_weight_only_quant_type))

        if int8_kv_cache:
            qkv_y = torch.cat([
                act_range.get(prefix + 'self_attn.q_proj')["y"],
                act_range.get(prefix + 'self_attn.k_proj')["y"],
                act_range.get(prefix + 'self_attn.v_proj')["y"]
            ],
                              dim=0)

            int8_kv_scales = qkv_y.max() / 127.

            # Store the factor in the returned dict; it used to be written
            # to a temporary dict that was never merged into ``weights``.
            weights[tllm_prex +
                    'attention.kv_cache_scaling_factor'] = \
                int8_kv_scales.reshape([1])

        # ---- attention output projection -------------------------------
        attn_dense_weight = get_weight(model_params,
                                       prefix + 'self_attn.o_proj', dtype)
        split_v = split_matrix_tp(attn_dense_weight,
                                  tensor_parallel,
                                  mapping.tp_rank,
                                  dim=1)
        if use_smooth_quant:
            attn_dense_weight = attn_dense_weight.t()
            int8_weights = generate_int8(
                attn_dense_weight, act_range.get(prefix + 'self_attn.o_proj'))
            weights.update(
                get_tllm_linear_sq_weight(
                    int8_weights,
                    tllm_prex + 'attention.dense.', [1, hidden_size],
                    tensor_parallel,
                    is_qkv=False,
                    per_token=per_token,
                    per_channel=per_channel,
                    last_prefix=tllm_prex +
                    'attention.quantization_scaling_factor',
                    smoother_value=smoother[(prefix + 'self_attn.o_proj')],
                    smoother_shape=[1, hidden_size // tensor_parallel],
                    rank=mapping.tp_rank,
                    cat_dim=0))
        else:
            weights.update(
                get_tllm_linear_weight(split_v, tllm_prex + 'attention.dense.',
                                       None, use_weight_only,
                                       plugin_weight_only_quant_type))

        # ---- MLP: HF up_proj is stored as TensorRT-LLM mlp.gate ----------
        mlp_gate_weight = get_weight(model_params, prefix + 'mlp.up_proj',
                                     dtype)
        split_v = split_matrix_tp(mlp_gate_weight,
                                  tensor_parallel,
                                  mapping.tp_rank,
                                  dim=0)
        if use_smooth_quant:
            mlp_gate_weight = mlp_gate_weight.t()
            int8_weights = generate_int8(mlp_gate_weight,
                                         act_range.get(prefix + 'mlp.up_proj'))

            weights.update(
                get_tllm_linear_sq_weight(
                    int8_weights,
                    tllm_prex + 'mlp.gate.',
                    [1, intermediate_size // tensor_parallel],
                    tensor_parallel,
                    is_qkv=False,
                    per_token=per_token,
                    per_channel=per_channel,
                    last_prefix=tllm_prex + 'post_layernorm.scale_to_int',
                    smoother_value=None,
                    smoother_shape=None,
                    rank=mapping.tp_rank,
                    cat_dim=-1))
        else:
            weights.update(
                get_tllm_linear_weight(split_v, tllm_prex + 'mlp.gate.', None,
                                       use_weight_only,
                                       plugin_weight_only_quant_type))

        # ---- MLP: HF gate_proj is stored as TensorRT-LLM mlp.fc ----------
        mlp_fc_weight = get_weight(model_params, prefix + 'mlp.gate_proj',
                                   dtype)
        split_v = split_matrix_tp(mlp_fc_weight,
                                  tensor_parallel,
                                  mapping.tp_rank,
                                  dim=0)

        if use_smooth_quant:
            mlp_fc_weight = mlp_fc_weight.t()  #verified
            int8_weights = generate_int8(
                mlp_fc_weight, act_range.get(prefix + 'mlp.gate_proj'))
            weights.update(
                get_tllm_linear_sq_weight(
                    int8_weights,
                    tllm_prex + 'mlp.fc.',
                    [1, intermediate_size // tensor_parallel],
                    tensor_parallel,
                    is_qkv=False,
                    per_token=per_token,
                    per_channel=per_channel,
                    last_prefix=tllm_prex + 'post_layernorm.scale_to_int',
                    smoother_value=None,
                    smoother_shape=None,
                    rank=mapping.tp_rank,
                    cat_dim=-1))
        else:
            weights.update(
                get_tllm_linear_weight(split_v, tllm_prex + 'mlp.fc.', None,
                                       use_weight_only,
                                       plugin_weight_only_quant_type))

        # ---- MLP: HF down_proj is stored as TensorRT-LLM mlp.proj --------
        mlp_proj_weight = get_weight(model_params, prefix + 'mlp.down_proj',
                                     dtype)
        split_v = split_matrix_tp(mlp_proj_weight,
                                  tensor_parallel,
                                  mapping.tp_rank,
                                  dim=1)

        if use_smooth_quant:
            mlp_proj_weight = mlp_proj_weight.t()
            int8_weights = generate_int8(
                mlp_proj_weight, act_range.get(prefix + 'mlp.down_proj'))
            weights.update(
                get_tllm_linear_sq_weight(
                    int8_weights,
                    tllm_prex + 'mlp.proj.', [1, hidden_size],
                    tensor_parallel,
                    is_qkv=False,
                    per_token=per_token,
                    per_channel=per_channel,
                    last_prefix=tllm_prex + 'mlp.quantization_scaling_factor',
                    smoother_value=smoother[prefix + 'mlp.down_proj'],
                    smoother_shape=[1, intermediate_size // tensor_parallel],
                    rank=mapping.tp_rank,
                    cat_dim=0))
        else:
            weights.update(
                get_tllm_linear_weight(split_v, tllm_prex + 'mlp.proj.', None,
                                       use_weight_only,
                                       plugin_weight_only_quant_type))

        # Layer norms do not use tensor parallelism
        input_ln_weight = get_weight(model_params, prefix + 'input_layernorm',
                                     dtype)
        weights[tllm_prex + 'input_layernorm.weight'] = input_ln_weight

        post_ln_weight = get_weight(model_params,
                                    prefix + 'post_attention_layernorm', dtype)
        weights[tllm_prex + 'post_layernorm.weight'] = post_ln_weight

    # ---- embedding, LM head and final norm ------------------------------
    v = get_weight(model_params, 'model.embed_tokens', dtype)
    tie_word_embeddings = hf_model.config.tie_word_embeddings

    if tie_word_embeddings:
        # lm_head.weight has the same weights as embedding
        if mapping.is_last_pp_rank():
            weights['lm_head.weight'] = split(v, mapping.tp_size,
                                              mapping.tp_rank)

    if use_parallel_embedding:
        v = split_matrix_tp(v,
                            mapping.tp_size,
                            mapping.tp_rank,
                            dim=sharding_dim)

    if mapping.is_first_pp_rank():
        weights['transformer.vocab_embedding.weight'] = v

    # named_parameters() skips duplicate parameters by default, so a tied
    # lm_head may be absent from model_params. In that case keep the
    # embedding-derived lm_head set above instead of raising KeyError.
    if not tie_word_embeddings or 'lm_head.weight' in model_params:
        lm_head_weights = get_weight(model_params, 'lm_head', dtype)

        if mapping.is_last_pp_rank():
            weights['lm_head.weight'] = split_matrix_tp(lm_head_weights,
                                                        tensor_parallel,
                                                        mapping.tp_rank,
                                                        dim=0)

    ln_f_w = get_weight(model_params, 'model.norm', dtype)
    weights['transformer.ln_f.weight'] = ln_f_w

    tok = time.time()
    t = time.strftime('%H:%M:%S', time.gmtime(tok - tik))
    print(f'Weights loaded. Total time: {t}')
    return weights
|