TensorRT-LLMs/tensorrt_llm/models/medusa/weight.py
Dan Blanaru 16d2467ea8 Update TensorRT-LLM (#2755)
* Update TensorRT-LLM

---------

Co-authored-by: Denis Kayshev <topenkoff@gmail.com>
Co-authored-by: akhoroshev <arthoroshev@gmail.com>
Co-authored-by: Patrick Reiter Horn <patrick.horn@gmail.com>

Update
2025-02-11 03:01:00 +00:00

649 lines
26 KiB
Python

import copy
import functools
import time
from collections import defaultdict
from pathlib import Path
import numpy as np
import torch
import torch.nn as nn
from tqdm import tqdm
from transformers.models.llama.modeling_llama import LlamaDecoderLayer
from transformers.pytorch_utils import Conv1D
from tensorrt_llm import logger
from tensorrt_llm._utils import str_dtype_to_torch
from tensorrt_llm.mapping import Mapping
from tensorrt_llm.models.convert_utils import (dup_kv_weight, generate_int8,
smooth_gemm,
smooth_gemm_fc1_gate, split,
split_matrix_tp, split_qkv_tp)
def get_tllm_linear_weight(weight,
                           prefix,
                           bias=None,
                           use_weight_only=False,
                           plugin_weight_only_quant_type=torch.int8,
                           postfix='weight'):
    """Build the checkpoint entries for one linear layer.

    Args:
        weight: Weight tensor of the layer.
        prefix: Key prefix of the layer; entry names are appended to it as-is
            (no separator is inserted).
        bias: Optional bias tensor, stored unchanged under ``prefix + 'bias'``.
        use_weight_only: When true, the weight is transposed, moved to CPU and
            passed to the ``trtllm`` weight-only quantization op; the op's two
            outputs are stored as the weight and its ``per_channel_scale``.
        plugin_weight_only_quant_type: Quantization dtype forwarded to the op.
        postfix: Name of the weight entry (appended to ``prefix``).

    Returns:
        dict mapping entry names to tensors.
    """
    weight_key = prefix + postfix

    if not use_weight_only:
        # Plain path: keep the weight as-is, just make it contiguous.
        tensors = {weight_key: weight.contiguous()}
    else:
        transposed = weight.t().contiguous().cpu()
        quantize = \
            torch.ops.trtllm.symmetric_quantize_last_axis_of_batched_matrix
        quantized, scales = quantize(transposed, plugin_weight_only_quant_type)
        tensors = {
            weight_key: quantized,
            prefix + 'per_channel_scale': scales,
        }

    if bias is not None:
        tensors[prefix + 'bias'] = bias
    return tensors
def load_medusa_hf(medusa_path: str,
                   num_medusa_heads: int,
                   num_medusa_layers: int,
                   mapping=Mapping(),
                   dtype='float32',
                   use_weight_only=False,
                   plugin_weight_only_quant_type=None,
                   is_modelopt_ckpt=False):
    """Load Medusa head weights and rename them to TRT-LLM checkpoint keys.

    Args:
        medusa_path: Directory holding the Medusa checkpoint. For ModelOpt
            checkpoints every ``*.safetensors`` file in it is merged (in
            sorted order); otherwise ``medusa_lm_head.pt`` is used when
            present, falling back to ``medusa_lm_head.safetensors``.
        num_medusa_heads: Number of Medusa heads to read.
        num_medusa_layers: Number of linear layers per head. The head's LM
            head is read from index ``num_medusa_layers`` of that head.
        mapping: Supplies ``tp_size`` / ``tp_rank`` used to shard tensors.
        dtype: Target dtype name for weights and biases.
        use_weight_only: Forwarded to ``get_tllm_linear_weight``.
        plugin_weight_only_quant_type: Forwarded to
            ``get_tllm_linear_weight``.
        is_modelopt_ckpt: Whether the checkpoint is a ModelOpt export; its
            keys carry a ``medusa_heads.`` prefix and it also provides
            ``input_scale`` / ``weight_scale`` entries.

    Returns:
        dict mapping TRT-LLM parameter names to tensors for this TP rank.
    """
    if is_modelopt_ckpt:
        from safetensors.torch import load_file
        state_dict = {}
        for filename in sorted(Path(medusa_path).glob("*.safetensors")):
            print(f"Loading the weights of Medusa heads from {filename}")
            state_dict.update(load_file(filename))
    else:
        ckpt_file = Path(medusa_path) / "medusa_lm_head.pt"
        if ckpt_file.exists():
            # NOTE(review): torch.load unpickles the file; only point this at
            # checkpoints from a trusted source.
            state_dict = torch.load(ckpt_file, map_location="cpu")
        else:
            ckpt_file = Path(medusa_path) / "medusa_lm_head.safetensors"
            logger.info("Safetensors Found ...")
            from safetensors.torch import load_file
            state_dict = load_file(ckpt_file)

    torch_dtype = str_dtype_to_torch(dtype)
    # Scaling factors are always stored as float32, independent of `dtype`.
    scaling_dtype = str_dtype_to_torch("float32")
    weights = {}

    prefix = "medusa_heads." if is_modelopt_ckpt else ""
    for h in range(num_medusa_heads):
        for layer in range(num_medusa_layers):
            src = f"{prefix}{h}.{layer}.linear"
            dst = f'medusa_heads.{h}.medusa_layers.{layer}.linear.'

            w = state_dict[f"{src}.weight"].clone().to(torch_dtype)
            split_v = split(w, mapping.tp_size, mapping.tp_rank)
            weights.update(
                get_tllm_linear_weight(split_v, dst, None, use_weight_only,
                                       plugin_weight_only_quant_type))

            b = state_dict[f"{src}.bias"].clone().to(torch_dtype)
            weights[dst + 'bias'] = split(b, mapping.tp_size, mapping.tp_rank)

            # Per-layer scaling factors are exported for every layer, inside
            # the layer loop, rather than relying on the loop variable's
            # final value after the loop (which covered only the last layer
            # and was undefined when num_medusa_layers == 0).
            if is_modelopt_ckpt:
                weights[dst + 'activation_scaling_factor'] = \
                    state_dict[f"{src}.input_scale"].clone().to(scaling_dtype)
                weights[dst + 'weights_scaling_factor'] = \
                    state_dict[f"{src}.weight_scale"].clone().to(scaling_dtype)

        # LM head of this Medusa head.
        lm_src = f"{prefix}{h}.{num_medusa_layers}"
        lm_dst = f'medusa_heads.{h}.lm_head.'
        lm = state_dict[f"{lm_src}.weight"].clone().to(torch_dtype)
        weights[lm_dst + 'weight'] = split(lm, mapping.tp_size,
                                           mapping.tp_rank)

        if is_modelopt_ckpt:
            weights[lm_dst + 'activation_scaling_factor'] = \
                state_dict[f"{lm_src}.input_scale"].clone().to(scaling_dtype)
            weights[lm_dst + 'weights_scaling_factor'] = \
                state_dict[f"{lm_src}.weight_scale"].clone().to(scaling_dtype)

    return weights
@torch.no_grad()
def smooth_llama_model(model, scales, alpha, llama_qkv_para, llama_smoother):
    """Apply SmoothQuant-style smoothing to every ``LlamaDecoderLayer``.

    Smooths the activation and weights with smoother = $\diag{s}$.

    Args:
        model: Model whose ``named_modules()`` are scanned for decoder layers.
        scales: Per-layer ``{"x", "y", "w"}`` statistics keyed by module name.
            Updated in place; a ``...self_attn.qkv_proj`` entry is written for
            each layer (presumably ``scales`` is a ``defaultdict`` so the new
            key can be indexed directly -- verify against caller).
        alpha: Smoothing strength forwarded to the smoothing helpers.
        llama_qkv_para: Output dict; receives the transposed fused QKV weight
            per layer.
        llama_smoother: Output dict; receives the float smoother of the
            ``o_proj`` and ``down_proj`` layers.
    """

    def _smooth_single_linear(key, linear):
        # Shared recipe for o_proj / down_proj: no preceding layer norm is
        # passed, and the smoother itself is recorded for later use.
        factor = smooth_gemm(linear.weight, scales[key]["x"], None, None,
                             alpha)
        llama_smoother[key] = factor.float()
        scales[key]["x"] = scales[key]["x"] / factor
        scales[key]["w"] = linear.weight.abs().max(dim=1)[0]

    decoder_layers = ((n, m) for n, m in model.named_modules()
                      if isinstance(m, LlamaDecoderLayer))
    for name, module in decoder_layers:
        attn = module.self_attn

        # ---- fused q/k/v projection ---------------------------------------
        q_key = name + ".self_attn.q_proj"
        k_key = name + ".self_attn.k_proj"
        v_key = name + ".self_attn.v_proj"
        qkv_key = name + ".self_attn.qkv_proj"

        # torch.cat builds a new tensor, so `fused_qkv` is a copy of the three
        # projection weights stacked along dim 0.
        fused_qkv = torch.cat(
            [attn.q_proj.weight, attn.k_proj.weight, attn.v_proj.weight],
            dim=0)
        qkv_factor = smooth_gemm(fused_qkv, scales[q_key]["x"],
                                 module.input_layernorm.weight, None, alpha)

        scales[qkv_key]["x"] = scales[q_key]["x"] / qkv_factor
        scales[qkv_key]["w"] = fused_qkv.abs().max(dim=1)[0]
        scales[qkv_key]["y"] = torch.cat(
            [scales[q_key]["y"], scales[k_key]["y"], scales[v_key]["y"]],
            dim=0)

        # Stored transposed (the original note points at a
        # `transpose_weights` function that is not part of this file).
        llama_qkv_para[qkv_key] = fused_qkv.transpose(0, 1)

        # ---- attention output projection ----------------------------------
        _smooth_single_linear(name + ".self_attn.o_proj", attn.o_proj)

        # ---- MLP gate_proj / up_proj (share one smoother) -----------------
        mlp = module.mlp
        fc1_key = name + ".mlp.gate_proj"
        gate_key = name + ".mlp.up_proj"
        mlp_factor = smooth_gemm_fc1_gate(
            mlp.gate_proj.weight, mlp.up_proj.weight, scales[fc1_key]["x"],
            module.post_attention_layernorm.weight, None, alpha)

        scales[fc1_key]["x"] = scales[fc1_key]["x"] / mlp_factor
        scales[fc1_key]["w"] = mlp.gate_proj.weight.abs().max(dim=1)[0]

        scales[gate_key]["x"] = scales[gate_key]["x"] / mlp_factor
        scales[gate_key]["w"] = mlp.up_proj.weight.abs().max(dim=1)[0]

        # ---- MLP down projection ------------------------------------------
        _smooth_single_linear(name + ".mlp.down_proj", mlp.down_proj)
@torch.no_grad()
def capture_activation_range(model,
                             tokenizer,
                             dataset,
                             num_samples=512,
                             seq_len=512):
    """Run calibration samples through ``model`` and record abs-max ranges.

    A forward hook is attached to every ``nn.Linear`` / ``Conv1D`` module.
    For each such module the hook tracks the running per-feature maximum of
    ``|input|`` ("x") and ``|output|`` ("y") over the last dimension, plus the
    per-row maximum of ``|weight|`` ("w", clipped below at 1e-8, captured on
    the first call only).

    Side effects: the model is put in eval mode and ``tokenizer.pad_token`` is
    overwritten with ``tokenizer.eos_token``.

    Args:
        model: Module to calibrate; inputs are moved to the device of its
            first parameter.
        tokenizer: Callable tokenizer returning an object with ``input_ids``.
        dataset: Sliceable collection of samples; ``dataset[i:i + 1]`` is
            expected to yield a one-element list of text (presumably ``str``
            -- verify against caller).
        num_samples: Number of samples to run.
        seq_len: ``max_length`` used for truncation.

    Returns:
        ``defaultdict`` mapping module name to ``{"x", "y", "w"}`` tensors.
    """
    model.eval()
    device = next(model.parameters()).device
    act_scales = defaultdict(lambda: {"x": None, "y": None, "w": None})

    tokenizer.pad_token = tokenizer.eos_token

    def stat_tensor(name, tensor, act_scales, key):
        hidden_dim = tensor.shape[-1]
        # reshape (rather than view) so non-contiguous activations are
        # accepted too; it is identical to view whenever view would succeed.
        tensor = tensor.reshape(-1, hidden_dim).abs().detach()
        coming_max = torch.max(tensor, dim=0)[0].float()

        if act_scales[name][key] is None:
            act_scales[name][key] = coming_max
        else:
            act_scales[name][key] = torch.max(act_scales[name][key],
                                              coming_max)

    def stat_input_hook(m, x, y, name):
        if isinstance(x, tuple):
            x = x[0]
        stat_tensor(name, x, act_scales, "x")
        stat_tensor(name, y, act_scales, "y")

        if act_scales[name]["w"] is None:
            act_scales[name]["w"] = m.weight.abs().clip(1e-8,
                                                        None).max(dim=1)[0]

    hooks = []
    try:
        for name, m in model.named_modules():
            if isinstance(m, (nn.Linear, Conv1D)):
                hooks.append(
                    m.register_forward_hook(
                        functools.partial(stat_input_hook, name=name)))

        for i in tqdm(range(num_samples), desc="calibrating model"):
            datapoint = dataset[i:i + 1]
            line = copy.copy(datapoint)
            line[0] = line[0] + ' TL;DR: '
            line[0] = line[0].strip()
            line[0] = line[0].replace(" n't", "n't")
            input_ids = tokenizer(line,
                                  return_tensors="pt",
                                  max_length=seq_len,
                                  padding=True,
                                  truncation=True).input_ids.to(device)
            model(input_ids)
    finally:
        # Always detach the hooks, even if tokenization or a forward pass
        # raised, so the model is not left instrumented.
        for h in hooks:
            h.remove()

    return act_scales
def get_weight(config, prefix, dtype):
    """Return ``config[prefix + '.weight']``, cast in place to ``dtype``.

    If the stored tensor has a different dtype, its ``.data`` is replaced by
    the converted tensor, so the object held in ``config`` itself changes
    dtype and that same object is returned.
    """
    key = prefix + '.weight'
    tensor = config[key]
    if tensor.dtype != dtype:
        tensor.data = tensor.to(dtype)
    return tensor
def get_bias(config, prefix, dtype):
    """Return ``config[prefix + '.bias']``, cast in place to ``dtype``.

    If the stored tensor has a different dtype, its ``.data`` is replaced by
    the converted tensor, so the object held in ``config`` itself changes
    dtype and that same object is returned.
    """
    key = prefix + '.bias'
    tensor = config[key]
    if tensor.dtype != dtype:
        tensor.data = tensor.to(dtype)
    return tensor
def get_weight_and_bias(config, prefix, dtype):
    """Return the ``(weight, bias)`` pair stored under ``prefix``.

    Both tensors are fetched through ``get_weight`` / ``get_bias`` with the
    same ``dtype`` argument, weight first.
    """
    weight = get_weight(config, prefix, dtype)
    bias = get_bias(config, prefix, dtype)
    return weight, bias
def get_tllm_linear_sq_weight(vals,
                              prefix,
                              shape,
                              tensor_parallel,
                              is_qkv=False,
                              per_token=False,
                              per_channel=False,
                              last_prefix=None,
                              bias=None,
                              smoother_value=None,
                              smoother_shape=None,
                              rank=0,
                              cat_dim=0,
                              multi_query_mode=False):
    """Assemble the SmoothQuant checkpoint entries of one linear layer.

    Args:
        vals: dict of quantization results. Keys read here:
            ``weight.int8.col`` and ``scale_w_quant_orig.col`` when
            ``per_token`` is set; otherwise ``weight.int8``,
            ``scale_y_accum_quant``, ``scale_x_orig_quant`` and
            ``scale_y_quant_orig``.
        prefix: Key prefix for this layer's entries.
        shape: Shape of ``per_channel_scale`` when ``is_qkv`` or
            ``per_channel`` is set; otherwise ``[1, 1]`` is used.
        tensor_parallel: Number of shards to split into.
        is_qkv: Flatten the weight shard to 2-D before storing.
        per_token: Selects the per-token branch (see ``vals``).
        per_channel: See ``shape``.
        last_prefix: Full key under which the input scale is stored.
        bias: Optional bias, stored unchanged.
        smoother_value: Optional smoother; sharded along ``cat_dim`` and
            stored as float32. Must support ``.contiguous()`` / ``.to()``
            after ``np.split`` (presumably a torch tensor -- TODO confirm).
        smoother_shape: Shape the smoother shard is reshaped to.
        rank: Index of the shard to keep.
        cat_dim: Axis along which weights / scales / smoother are split.
        multi_query_mode: Split q, k and v sections separately along the last
            axis (per-token branch only).

    Returns:
        dict mapping entry names to tensors.
    """
    results = {}

    def _shard(data):
        # Even split along `cat_dim`; keep only this rank's piece.
        return np.split(data, tensor_parallel, axis=cat_dim)[rank]

    def _shard_multi_query(data, local_dim, head_size):
        # Cut the last axis into q | k | v, shard each section on its own and
        # glue this rank's three pieces back together.
        sections = np.split(data, [local_dim, local_dim + head_size], axis=-1)
        pieces = [
            np.split(section, tensor_parallel, axis=-1)[rank]
            for section in sections
        ]
        return np.concatenate(pieces, axis=-1)

    def _store_weight(shard):
        if is_qkv:
            shard = shard.reshape(shard.shape[0], -1)
        results[prefix + 'weight'] = torch.from_numpy(shard).t().contiguous()

    scale_shape = shape if (is_qkv or per_channel) else [1, 1]
    has_smoother = smoother_value is not None

    if per_token:
        full_weight = vals["weight.int8.col"]
        local_dim = full_weight.shape[0]
        head_size = (full_weight.shape[1] - local_dim) // 2

        if multi_query_mode:
            _store_weight(
                _shard_multi_query(full_weight, local_dim, head_size))
        else:
            _store_weight(_shard(full_weight))

        if not has_smoother:
            results[last_prefix] = torch.from_numpy(
                np.array([1.0], dtype=np.float32))

        # The per-channel scale is sharded only when no smoother is given.
        channel_scale = vals["scale_w_quant_orig.col"]
        if not has_smoother:
            if multi_query_mode:
                channel_scale = _shard_multi_query(channel_scale, local_dim,
                                                   head_size)
            else:
                channel_scale = _shard(channel_scale)
        results[prefix + 'per_channel_scale'] = torch.from_numpy(
            np.array(channel_scale,
                     dtype=np.float32).reshape(scale_shape)).contiguous()
    else:
        _store_weight(_shard(np.array(vals["weight.int8"])))

        results[prefix + 'per_channel_scale'] = torch.from_numpy(
            np.array([vals["scale_y_accum_quant"]],
                     dtype=np.float32).reshape(scale_shape)).contiguous()

        results[last_prefix] = torch.from_numpy(
            np.array([vals['scale_x_orig_quant']],
                     dtype=np.float32)).contiguous()

        results[prefix + 'act_scale'] = torch.from_numpy(
            np.array([[vals["scale_y_quant_orig"]]],
                     dtype=np.float32)).contiguous()

    if has_smoother:
        smoother_shard = _shard(smoother_value)
        results[prefix + 'smoother'] = smoother_shard.reshape(
            smoother_shape).contiguous().to(torch.float32)

    if bias is not None:
        results[prefix + 'bias'] = bias

    return results
def convert_hf_llama(hf_model,
                     mapping,
                     rank=0,
                     dtype='float32',
                     use_parallel_embedding=False,
                     sharding_dim=0,
                     use_weight_only=False,
                     plugin_weight_only_quant_type=torch.int8,
                     use_smooth_quant=False,
                     per_channel=False,
                     per_token=False,
                     int8_kv_cache=False,
                     act_range=[],
                     qkv_para=[],
                     smoother=[],
                     lora_config=None):
    """Convert a HF LLaMA model into a TRT-LLM weight dict for one rank.

    For every decoder layer owned by this pipeline rank
    (``mapping.pp_layers``) the following HF parameters are sharded for
    ``mapping.tp_rank`` and stored under ``transformer.layers.<i>.``:

    * ``self_attn.{q,k,v}_proj`` -> ``attention.qkv.``
    * ``self_attn.o_proj``       -> ``attention.dense.``
    * ``mlp.up_proj``            -> ``mlp.gate.``
    * ``mlp.gate_proj``          -> ``mlp.fc.``
    * ``mlp.down_proj``          -> ``mlp.proj.``
    * ``input_layernorm`` / ``post_attention_layernorm`` (not sharded)

    followed by the vocabulary embedding, ``lm_head`` and the final norm.

    Args:
        hf_model: HF model exposing ``named_parameters()`` and ``config``.
        mapping: Provides ``tp_size``, ``tp_rank``, ``pp_layers`` and the
            first/last pipeline-rank predicates.
        rank: Not read by this function (``mapping.tp_rank`` is used).
        dtype: Name of a ``torch`` dtype attribute, e.g. ``'float16'``.
        use_parallel_embedding: Shard the embedding along ``sharding_dim``.
        sharding_dim: Axis used when sharding the embedding.
        use_weight_only: Forwarded to ``get_tllm_linear_weight``.
        plugin_weight_only_quant_type: Forwarded to
            ``get_tllm_linear_weight``.
        use_smooth_quant: Emit int8 SmoothQuant entries instead of plain
            weights; requires ``act_range``, ``qkv_para`` and ``smoother``.
        per_channel: Forwarded to ``get_tllm_linear_sq_weight``.
        per_token: Forwarded to ``get_tllm_linear_sq_weight``.
        int8_kv_cache: Also emit ``attention.kv_cache_scaling_factor``
            (requires ``act_range``).
        act_range: Activation statistics keyed by HF module name. Accessed
            with ``.get`` so it must be dict-like whenever it is used; the
            list default only works when neither SmoothQuant nor int8 KV
            cache is requested.
        qkv_para: Fused QKV weights keyed by ``...self_attn.qkv_proj``.
        smoother: Smoother tensors keyed by HF module name.
        lora_config: Not read by this function.

    Returns:
        dict mapping TRT-LLM parameter names to tensors.
    """
    weights = {}
    tik = time.time()
    tensor_parallel = mapping.tp_size
    model_params = dict(hf_model.named_parameters())
    dtype = getattr(torch, dtype)
    num_attention_heads = hf_model.config.num_attention_heads
    hidden_size = hf_model.config.hidden_size
    intermediate_size = hf_model.config.intermediate_size
    num_key_value_heads = hf_model.config.num_key_value_heads
    # True when every attention head has its own K/V head.
    mha_mode = (num_key_value_heads == num_attention_heads)

    num_hidden_layers = hf_model.config.num_hidden_layers
    layers_range = mapping.pp_layers(num_hidden_layers)
    for l in layers_range:
        # Layer indices in the output are local to this pipeline rank.
        layer_idx = l - layers_range[0]
        prefix = f'model.layers.{l}.'
        tllm_prex = f'transformer.layers.{layer_idx}.'

        # ------------------------------------------------------------ QKV
        q_weight = get_weight(model_params, prefix + 'self_attn.q_proj', dtype)
        k_weight = get_weight(model_params, prefix + 'self_attn.k_proj', dtype)
        v_weight = get_weight(model_params, prefix + 'self_attn.v_proj', dtype)
        if not mha_mode:
            head_size = hidden_size // num_attention_heads
            if num_key_value_heads < tensor_parallel:
                # duplicate the KV heads up to tensor_parallel
                k_weight = dup_kv_weight(k_weight, num_key_value_heads,
                                         tensor_parallel)
                v_weight = dup_kv_weight(v_weight, num_key_value_heads,
                                         tensor_parallel)
            assert (k_weight.shape[0] % (mapping.tp_size * head_size)) == 0
            assert (v_weight.shape[0] % (mapping.tp_size * head_size)) == 0

            wq = split(q_weight, mapping.tp_size, mapping.tp_rank)
            wk = split(k_weight, mapping.tp_size, mapping.tp_rank)
            wv = split(v_weight, mapping.tp_size, mapping.tp_rank)

            split_v = torch.concat((wq, wk, wv))
        else:
            qkv_weight = torch.cat([q_weight, k_weight, v_weight], dim=0)

            split_v = split_qkv_tp(qkv_weight, num_attention_heads, hidden_size,
                                   tensor_parallel, mapping.tp_rank)
        if use_smooth_quant:
            # The smoothed, fused QKV weight comes from `qkv_para`, not from
            # the tensors loaded above.
            qkv_weight = qkv_para[prefix + 'self_attn.qkv_proj']

            if not mha_mode:
                # NOTE(review): this rebinds `hidden_size` for the rest of
                # the function to the first dim of the fused weight.
                hidden_size = qkv_weight.shape[0]
                local_dim = hidden_size
                head_size = (qkv_weight.shape[-1] - local_dim) // 2
                qkv_weight = qkv_weight.reshape(hidden_size,
                                                local_dim + 2 * head_size)
            else:
                qkv_weight = qkv_weight.reshape(hidden_size, 3, hidden_size)

            int8_weights = generate_int8(qkv_weight,
                                         act_range.get(prefix +
                                                       'self_attn.qkv_proj'),
                                         is_qkv=True,
                                         multi_query_mode=bool(not mha_mode))

            weights.update(
                get_tllm_linear_sq_weight(
                    int8_weights,
                    tllm_prex + 'attention.qkv.', [
                        1, 3 * hidden_size // tensor_parallel
                        if mha_mode else hidden_size // tensor_parallel +
                        (hidden_size // num_key_value_heads) //
                        tensor_parallel * 2
                    ],
                    tensor_parallel,
                    is_qkv=True,
                    per_token=per_token,
                    per_channel=per_channel,
                    last_prefix=tllm_prex + 'input_layernorm.scale_to_int',
                    smoother_value=None,
                    smoother_shape=None,
                    rank=mapping.tp_rank,
                    cat_dim=-1,
                    multi_query_mode=bool(not mha_mode)))
        else:
            weights.update(
                get_tllm_linear_weight(split_v, tllm_prex + 'attention.qkv.',
                                       None, use_weight_only,
                                       plugin_weight_only_quant_type))

        # --------------------------------------------------- int8 KV cache
        if int8_kv_cache:
            qkv_y = torch.cat([
                act_range.get(prefix + 'self_attn.q_proj')["y"],
                act_range.get(prefix + 'self_attn.k_proj')["y"],
                act_range.get(prefix + 'self_attn.v_proj')["y"]
            ],
                              dim=0)

            # One scale per layer: largest observed |q/k/v output| / 127.
            int8_kv_scales = qkv_y.max() / 127.

            # Store the scale in the returned dict; it was previously put in
            # a throw-away local dict and therefore never exported.
            weights[tllm_prex +
                    'attention.kv_cache_scaling_factor'] = int8_kv_scales.reshape(
                        [1])

        # ------------------------------------------------ attention dense
        attn_dense_weight = get_weight(model_params,
                                       prefix + 'self_attn.o_proj', dtype)
        split_v = split_matrix_tp(attn_dense_weight,
                                  tensor_parallel,
                                  mapping.tp_rank,
                                  dim=1)
        if use_smooth_quant:
            attn_dense_weight = attn_dense_weight.t()
            int8_weights = generate_int8(
                attn_dense_weight, act_range.get(prefix + 'self_attn.o_proj'))
            weights.update(
                get_tllm_linear_sq_weight(
                    int8_weights,
                    tllm_prex + 'attention.dense.', [1, hidden_size],
                    tensor_parallel,
                    is_qkv=False,
                    per_token=per_token,
                    per_channel=per_channel,
                    last_prefix=tllm_prex +
                    'attention.quantization_scaling_factor',
                    smoother_value=smoother[(prefix + 'self_attn.o_proj')],
                    smoother_shape=[1, hidden_size // tensor_parallel],
                    rank=mapping.tp_rank,
                    cat_dim=0))
        else:
            weights.update(
                get_tllm_linear_weight(split_v, tllm_prex + 'attention.dense.',
                                       None, use_weight_only,
                                       plugin_weight_only_quant_type))

        # ------------------------------------- MLP gate (HF `up_proj`)
        mlp_gate_weight = get_weight(model_params, prefix + 'mlp.up_proj',
                                     dtype)
        split_v = split_matrix_tp(mlp_gate_weight,
                                  tensor_parallel,
                                  mapping.tp_rank,
                                  dim=0)
        if use_smooth_quant:
            mlp_gate_weight = mlp_gate_weight.t()
            int8_weights = generate_int8(mlp_gate_weight,
                                         act_range.get(prefix + 'mlp.up_proj'))

            weights.update(
                get_tllm_linear_sq_weight(
                    int8_weights,
                    tllm_prex + 'mlp.gate.',
                    [1, intermediate_size // tensor_parallel],
                    tensor_parallel,
                    is_qkv=False,
                    per_token=per_token,
                    per_channel=per_channel,
                    last_prefix=tllm_prex + 'post_layernorm.scale_to_int',
                    smoother_value=None,
                    smoother_shape=None,
                    rank=mapping.tp_rank,
                    cat_dim=-1))
        else:
            weights.update(
                get_tllm_linear_weight(split_v, tllm_prex + 'mlp.gate.', None,
                                       use_weight_only,
                                       plugin_weight_only_quant_type))

        # --------------------------------------- MLP fc (HF `gate_proj`)
        mlp_fc_weight = get_weight(model_params, prefix + 'mlp.gate_proj',
                                   dtype)
        split_v = split_matrix_tp(mlp_fc_weight,
                                  tensor_parallel,
                                  mapping.tp_rank,
                                  dim=0)

        if use_smooth_quant:
            mlp_fc_weight = mlp_fc_weight.t()  #verified
            int8_weights = generate_int8(
                mlp_fc_weight, act_range.get(prefix + 'mlp.gate_proj'))

            weights.update(
                get_tllm_linear_sq_weight(
                    int8_weights,
                    tllm_prex + 'mlp.fc.',
                    [1, intermediate_size // tensor_parallel],
                    tensor_parallel,
                    is_qkv=False,
                    per_token=per_token,
                    per_channel=per_channel,
                    last_prefix=tllm_prex + 'post_layernorm.scale_to_int',
                    smoother_value=None,
                    smoother_shape=None,
                    rank=mapping.tp_rank,
                    cat_dim=-1))
        else:
            weights.update(
                get_tllm_linear_weight(split_v, tllm_prex + 'mlp.fc.', None,
                                       use_weight_only,
                                       plugin_weight_only_quant_type))

        # ------------------------------------- MLP proj (HF `down_proj`)
        mlp_proj_weight = get_weight(model_params, prefix + 'mlp.down_proj',
                                     dtype)
        split_v = split_matrix_tp(mlp_proj_weight,
                                  tensor_parallel,
                                  mapping.tp_rank,
                                  dim=1)

        if use_smooth_quant:
            mlp_proj_weight = mlp_proj_weight.t()
            int8_weights = generate_int8(
                mlp_proj_weight, act_range.get(prefix + 'mlp.down_proj'))

            weights.update(
                get_tllm_linear_sq_weight(
                    int8_weights,
                    tllm_prex + 'mlp.proj.', [1, hidden_size],
                    tensor_parallel,
                    is_qkv=False,
                    per_token=per_token,
                    per_channel=per_channel,
                    last_prefix=tllm_prex + 'mlp.quantization_scaling_factor',
                    smoother_value=smoother[prefix + 'mlp.down_proj'],
                    smoother_shape=[1, intermediate_size // tensor_parallel],
                    rank=mapping.tp_rank,
                    cat_dim=0))
        else:
            weights.update(
                get_tllm_linear_weight(split_v, tllm_prex + 'mlp.proj.', None,
                                       use_weight_only,
                                       plugin_weight_only_quant_type))

        # Layer norms do not use tensor parallelism
        input_ln_weight = get_weight(model_params, prefix + 'input_layernorm',
                                     dtype)
        weights[tllm_prex + 'input_layernorm.weight'] = input_ln_weight

        post_ln_weight = get_weight(model_params,
                                    prefix + 'post_attention_layernorm', dtype)
        weights[tllm_prex + 'post_layernorm.weight'] = post_ln_weight

    # ------------------------------------------- embedding / lm_head / norm
    v = get_weight(model_params, 'model.embed_tokens', dtype)

    if hf_model.config.tie_word_embeddings:
        # lm_head.weight has the same weights as embedding
        if mapping.is_last_pp_rank():
            weights['lm_head.weight'] = split(v, mapping.tp_size,
                                              mapping.tp_rank)

    if use_parallel_embedding:
        v = split_matrix_tp(v,
                            mapping.tp_size,
                            mapping.tp_rank,
                            dim=sharding_dim)

    if mapping.is_first_pp_rank():
        weights['transformer.vocab_embedding.weight'] = v

    # NOTE(review): `lm_head.weight` is looked up unconditionally, and on the
    # last pipeline rank it overwrites the tied-embedding entry set above --
    # confirm this is intended for models with tie_word_embeddings.
    lm_head_weights = get_weight(model_params, 'lm_head', dtype)

    if mapping.is_last_pp_rank():
        weights['lm_head.weight'] = split_matrix_tp(lm_head_weights,
                                                    tensor_parallel,
                                                    mapping.tp_rank,
                                                    dim=0)
        # NOTE(review): the final norm is emitted only on the last pipeline
        # rank here -- placement assumed, confirm against the intended layout.
        ln_f_w = get_weight(model_params, 'model.norm', dtype)
        weights['transformer.ln_f.weight'] = ln_f_w

    tok = time.time()
    t = time.strftime('%H:%M:%S', time.gmtime(tok - tik))
    print(f'Weights loaded. Total time: {t}')
    return weights