mirror of
https://github.com/NVIDIA/TensorRT-LLM.git
synced 2026-01-14 06:27:45 +08:00
523 lines
22 KiB
Python
523 lines
22 KiB
Python
# SPDX-FileCopyrightText: Copyright (c) 2022-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
# SPDX-License-Identifier: Apache-2.0
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
import re
|
|
import time
|
|
from pathlib import Path
|
|
from typing import Dict, List, Optional, Union
|
|
|
|
import numpy as np
|
|
import torch
|
|
|
|
import tensorrt_llm
|
|
import tensorrt_llm.logger as logger
|
|
from tensorrt_llm.mapping import Mapping
|
|
from tensorrt_llm.models.quantized.quant import get_dummy_quant_scales
|
|
from tensorrt_llm.quantization import QuantMode
|
|
|
|
|
|
def split(weight: np.ndarray, tp_size: int, rank: int = 0, dim: int = 0):
    """Return the ``rank``-th of ``tp_size`` equal chunks of ``weight``.

    With ``tp_size == 1`` the input array itself is returned (no copy).
    One-dimensional inputs are always split along their only axis, so ``dim``
    is ignored for them; higher-rank inputs are split along ``dim``.
    Otherwise the chunk is returned as a contiguous copy, so it does not keep
    the full ``weight`` buffer alive.
    """
    if tp_size == 1:
        return weight
    axis = 0 if weight.ndim == 1 else dim
    chunk = np.split(weight, tp_size, axis=axis)[rank]
    return np.ascontiguousarray(chunk.copy())
|
|
|
|
|
|
def reorder_qkv_weight_or_bias(weight: np.ndarray,
                               head_dim: int,
                               num_heads: int,
                               num_kv_heads: Optional[int] = None,
                               tp_size: int = 1,
                               is_bias: bool = False):
    """ Reorder the qkv weight for TRT-LLM use.

    The shape of the fused QKV weights in HF is different from the shape that
    TRT-LLM requires. In particular, the weight of HF consists of interleaved
    q, k, v head weights, while that of TRT-LLM is contiguous.
        HF     : [q1, k1, v1, ..., qh, kh, vh]
        TRT-LLM: [q1, ..., qh, k1, ..., kh, v1, ..., vh]
    where qi, ki, vi are weight vectors corresponding to attention head i.
    Multi/grouped query attention is analogous: the HF tensor is treated as
    one group per K/V head, laid out as [q heads of the group..., k, v].

    We reorder and split the weight of an attention layer to fit into TRT-LLM.
    The reordered weight and bias will be
        weight: (T, Qh * D + 2 * KVh * D, H)
        bias  : (T, Qh * D + 2 * KVh * D, 1)
    where T=tp_size, Qh=local_num_q_heads, KVh=local_num_kv_heads, D=head_dim,
    H=hidden_dim (= num_heads * D). In multi/grouped query attention the
    number of K/V attention heads is less than that of Q attention, so K/V
    heads are duplicated across ranks when there are fewer K/V heads than
    ranks.

    For tensor parallelism, the first dimension of the result selects the
    weights of a rank.

    Args:
        weight: fused HF QKV tensor. ``weight.shape[0]`` must equal
            ``num_kv_heads * (num_heads // num_kv_heads + 2) * head_dim``.
            A weight must have ``num_heads * head_dim`` columns; a bias holds
            one value per output row.
        head_dim: size of a single attention head.
        num_heads: total number of query heads.
        num_kv_heads: total number of K/V heads; defaults to ``num_heads``.
        tp_size: tensor-parallel world size.
        is_bias: whether ``weight`` is a bias vector.

    Returns:
        np.ndarray of shape (tp_size, local_qkv_out, H), with H == 1 for a
        bias, where local_qkv_out = (num_heads + 2 * kv_heads) // tp_size * D
        and kv_heads is ``tp_size`` when K/V heads were duplicated.

    Raises:
        AssertionError: if ``num_heads`` is not divisible by ``num_kv_heads``,
            if ``weight.shape[0]`` does not match the head layout, or if K/V
            heads must be duplicated and ``tp_size`` is not a multiple of
            ``num_kv_heads``.
    """

    # Query types and expected kv heads.
    #  - Conventional MHA: num_heads = num_kv_heads
    #  - Multi-Query Attention: num_kv_heads = 1
    #  - Grouped-Query Attention: num_heads % num_kv_heads = 0
    num_kv_heads = num_kv_heads if num_kv_heads is not None else num_heads
    assert num_heads % num_kv_heads == 0, \
        f'num_heads({num_heads}) must be divisible by '\
        f'num_kv_heads({num_kv_heads}).'

    # The number of attention heads per group: N q head + 1 k head + 1 v head.
    num_group_heads = num_heads // num_kv_heads + 2
    assert weight.shape[0] == num_kv_heads * num_group_heads * head_dim, \
        f'{weight.shape[0]} != {num_kv_heads} * {num_group_heads} * {head_dim}'

    # A bias is handled as a weight with a single input column.
    qkv_in = num_heads * head_dim if not is_bias else 1

    # Split Q/K/V weights
    weight = weight.reshape(num_kv_heads, num_group_heads, head_dim, qkv_in)
    q_w = weight[:, :-2, ...]  # (nKV, num_heads // nKV, head_dim, qkv_in)
    k_w = weight[:, -2:-1, ...]  # (nKV, 1, head_dim, qkv_in)
    v_w = weight[:, -1:, ...]  # (nKV, 1, head_dim, qkv_in)

    if num_kv_heads < num_heads and num_kv_heads < tp_size:
        # Duplicate K/V heads to make sure that each rank has at least one
        # K/V head. For instance, num_heads=8, num_kv_heads=2, tp_size=4,
        # we will make the qkv weight as below.
        #   Orig: [q0 q1 q2 q3 k0 v0 q4 q5 q6 q7 k1 v1]
        #   >>>>  [[q0 q1 k0 v0], [q2 q3 k0 v0], [q4 q5 k1 v1], [q6 q7 k1 v1]]
        assert tp_size % num_kv_heads == 0
        num_dups = tp_size // num_kv_heads

        # k_w and v_w have the same shape.
        new_shape = (num_kv_heads, num_dups) + k_w.shape[2:]
        k_w = np.broadcast_to(k_w, shape=new_shape)
        v_w = np.broadcast_to(v_w, shape=new_shape)

        # Update the number of kv heads.
        num_kv_heads = tp_size

    reordered = np.concatenate(
        [
            q_w.reshape(tp_size, num_heads // tp_size, head_dim, qkv_in),
            k_w.reshape(tp_size, num_kv_heads // tp_size, head_dim, qkv_in),
            v_w.reshape(tp_size, num_kv_heads // tp_size, head_dim, qkv_in),
        ],
        axis=1,
    )

    qkv_out = (num_heads + 2 * num_kv_heads) // tp_size * head_dim
    return reordered.reshape((tp_size, qkv_out, -1))
|
|
|
|
|
|
def split_qkv_weight(trtllm_falcon: tensorrt_llm.models.FalconModel,
                     weight: np.ndarray,
                     tp_size: int,
                     rank: int,
                     is_bias: bool,
                     num_kv_heads: Optional[int] = None):
    """ Splits the QKV matrix according to tensor parallelism

    The fused HF QKV tensor is first reordered into the TRT-LLM layout for
    all ``tp_size`` ranks (``reorder_qkv_weight_or_bias``). The slice that
    belongs to ``rank`` is then returned: a contiguous 2-D array for a
    weight, or a flattened 1-D array for a bias.

    The head size is derived as ``trtllm_falcon.hidden_size //
    trtllm_falcon.num_heads``.
    """
    num_heads = trtllm_falcon.num_heads
    head_dim = trtllm_falcon.hidden_size // num_heads
    per_rank = reorder_qkv_weight_or_bias(weight,
                                          head_dim=head_dim,
                                          num_heads=num_heads,
                                          num_kv_heads=num_kv_heads,
                                          tp_size=tp_size,
                                          is_bias=is_bias)
    local = per_rank[rank, ...]

    # Return a copy, never a view: a sliced array shares (and therefore keeps
    # alive) the buffer of the full reordered tensor covering every rank, so
    # the GC could not release it while the returned slice is referenced.
    if is_bias:
        return local.ravel().copy()
    return np.ascontiguousarray(local.copy())
|
|
|
|
|
|
def split_matrix(weight: np.ndarray, tp_size: int, rank: int, dim: int):
    """Return this rank's tensor-parallel shard of ``weight`` along ``dim``.

    A thin wrapper over ``split`` that additionally guarantees a C-contiguous
    result, since ``split`` hands back the input unchanged when
    ``tp_size == 1``.
    """
    shard = split(weight, tp_size, rank, dim=dim)
    return np.ascontiguousarray(shard)
|
|
|
|
|
|
def get_weight(params: Dict, prefix: str, dtype: torch.dtype):
    """Fetch ``<prefix>.weight`` from ``params`` as a NumPy array.

    Returns ``None`` when the key is absent. Otherwise the tensor is cast to
    ``dtype``, detached from autograd, moved to the CPU and converted with
    ``tensorrt_llm._utils.torch_to_numpy``.
    """
    key = f'{prefix}.weight'
    if key in params:
        tensor = params[key].to(dtype).detach().cpu()
        return tensorrt_llm._utils.torch_to_numpy(tensor)
    return None
|
|
|
|
|
|
def get_bias(params: Dict, prefix: str, dtype: torch.dtype):
    """Fetch ``<prefix>.bias`` from ``params`` as a NumPy array.

    Returns ``None`` when the model has no such bias. Otherwise the tensor is
    cast to ``dtype``, detached, moved to the CPU and converted with
    ``tensorrt_llm._utils.torch_to_numpy``.
    """
    bias_key = f'{prefix}.bias'
    if bias_key not in params:
        return None
    raw = params[bias_key]
    on_cpu = raw.to(dtype).detach().cpu()
    return tensorrt_llm._utils.torch_to_numpy(on_cpu)
|
|
|
|
|
|
def get_weight_and_bias(params: Dict, prefix: str, dtype: torch.dtype):
    """Return ``(weight, bias)`` for ``prefix``; either entry may be ``None``."""
    weight = get_weight(params, prefix, dtype)
    bias = get_bias(params, prefix, dtype)
    return weight, bias
|
|
|
|
|
|
def load_from_hf_falcon(trtllm_falcon: tensorrt_llm.models.FalconForCausalLM,
                        hf_falcon,
                        mapping: Optional[Mapping] = None,
                        dtype: Union[str, torch.dtype] = torch.float32):
    """Load weights from an in-memory HF Falcon model into ``trtllm_falcon``.

    Every parameter of ``hf_falcon`` (via ``named_parameters()``) is cast to
    ``dtype`` and converted to NumPy. Where needed it is split for tensor
    parallelism, then assigned in place to the matching TRT-LLM parameter.

    Args:
        trtllm_falcon: TRT-LLM Falcon model to populate. Only the transformer
            layers returned by ``get_transformer_layers`` for its own mapping
            are filled.
        hf_falcon: HF Falcon model providing ``named_parameters()``.
        mapping: parallel mapping used for the tensor-parallel splits. It also
            decides whether this rank loads the embedding, ``lm_head`` and
            ``ln_f``. A fresh ``Mapping()`` is created per call when omitted.
        dtype: target dtype, as a ``torch.dtype`` or a string understood by
            ``tensorrt_llm._utils.str_dtype_to_torch``.
    """
    # Build the default per call instead of sharing one instance that would
    # be created once at function-definition time.
    if mapping is None:
        mapping = Mapping()

    logger.info('Loading weights from HF Falcon...')
    tik = time.time()

    model_params = dict(hf_falcon.named_parameters())
    if isinstance(dtype, str):
        dtype = tensorrt_llm._utils.str_dtype_to_torch(dtype)
    num_kv_heads = trtllm_falcon.num_kv_heads

    layers_range = trtllm_falcon.get_transformer_layers(
        trtllm_falcon.mapping, trtllm_falcon.num_layers)
    for i in layers_range:
        prefix = f'transformer.h.{i}'
        # ``trtllm_falcon.layers`` is indexed relative to the first layer of
        # this rank's range, while ``i`` is the global HF layer index.
        layer = trtllm_falcon.layers[i - layers_range[0]]
        qkv_weight, qkv_bias = get_weight_and_bias(
            model_params, f'{prefix}.self_attention.query_key_value', dtype)
        qkv_w = split_qkv_weight(trtllm_falcon,
                                 qkv_weight,
                                 mapping.tp_size,
                                 mapping.tp_rank,
                                 is_bias=False,
                                 num_kv_heads=num_kv_heads)
        layer.attention.qkv.weight.value = qkv_w
        if qkv_bias is not None:
            layer.attention.qkv.bias.value = split_qkv_weight(
                trtllm_falcon,
                qkv_bias,
                mapping.tp_size,
                mapping.tp_rank,
                is_bias=True,
                num_kv_heads=num_kv_heads)

        logger.debug(f'Layer {i}: Loading attention Dense weights...')
        attn_dense_weight, attn_dense_bias = get_weight_and_bias(
            model_params, f'{prefix}.self_attention.dense', dtype)
        # Weight is split along dim 1; the bias is kept whole on every rank.
        layer.attention.dense.weight.value = split_matrix(attn_dense_weight,
                                                          mapping.tp_size,
                                                          mapping.tp_rank,
                                                          dim=1)
        if attn_dense_bias is not None:
            layer.attention.dense.bias.value = attn_dense_bias

        logger.debug(f'Layer {i}: Loading MLP FC weights...')
        mlp_fc_weight, mlp_fc_bias = get_weight_and_bias(
            model_params, f'{prefix}.mlp.dense_h_to_4h', dtype)
        # Weight and bias are both split along dim 0.
        layer.mlp.fc.weight.value = split_matrix(mlp_fc_weight,
                                                 mapping.tp_size,
                                                 mapping.tp_rank,
                                                 dim=0)
        if mlp_fc_bias is not None:
            layer.mlp.fc.bias.value = split_matrix(mlp_fc_bias,
                                                   mapping.tp_size,
                                                   mapping.tp_rank,
                                                   dim=0)

        logger.debug(f'Layer {i}: Loading MLP Proj weights...')
        mlp_proj_weight, mlp_proj_bias = get_weight_and_bias(
            model_params, f'{prefix}.mlp.dense_4h_to_h', dtype)
        # Weight is split along dim 1; the bias is kept whole on every rank.
        layer.mlp.proj.weight.value = split_matrix(mlp_proj_weight,
                                                   mapping.tp_size,
                                                   mapping.tp_rank,
                                                   dim=1)
        if mlp_proj_bias is not None:
            layer.mlp.proj.bias.value = mlp_proj_bias

        if trtllm_falcon.new_decoder_architecture:
            # New decoder architecture: separate norms in front of attention
            # (ln_attn) and MLP (ln_mlp).
            input_ln_weight, input_ln_bias = get_weight_and_bias(
                model_params, f'{prefix}.ln_attn', dtype)
            layer.input_layernorm.weight.value = input_ln_weight
            if input_ln_bias is not None:
                layer.input_layernorm.bias.value = input_ln_bias

            mlp_ln_weight, mlp_ln_bias = get_weight_and_bias(
                model_params, f'{prefix}.ln_mlp', dtype)
            layer.mlp_layernorm.weight.value = mlp_ln_weight
            if mlp_ln_bias is not None:
                layer.mlp_layernorm.bias.value = mlp_ln_bias
        else:
            # Layer norms do not use tensor parallelism
            logger.debug(f'Layer {i}: Loading normalization weights...')
            input_ln_weight, input_ln_bias = get_weight_and_bias(
                model_params, f'{prefix}.input_layernorm', dtype)
            layer.input_layernorm.weight.value = input_ln_weight
            if input_ln_bias is not None:
                layer.input_layernorm.bias.value = input_ln_bias

            # A post-attention norm is only loaded for sequential (non
            # parallel) attention/MLP blocks.
            if not trtllm_falcon.parallel_attention:
                post_ln_weight, post_ln_bias = get_weight_and_bias(
                    model_params, f'{prefix}.post_attention_layernorm', dtype)
                if post_ln_weight is not None:
                    layer.post_layernorm.weight.value = post_ln_weight
                if post_ln_bias is not None:
                    layer.post_layernorm.bias.value = post_ln_bias

    # lm_head reuses the word-embedding matrix, split along dim 0 across the
    # tensor-parallel ranks of the last pipeline stage.
    embed_w = get_weight(model_params, 'transformer.word_embeddings', dtype)
    if mapping.is_first_pp_rank():
        trtllm_falcon.embedding.weight.value = embed_w.copy()
    if mapping.is_last_pp_rank():
        trtllm_falcon.lm_head.weight.value = split_matrix(embed_w,
                                                          mapping.tp_size,
                                                          mapping.tp_rank,
                                                          dim=0)

    # The final norm is loaded only on the last pipeline stage, like lm_head
    # above and like ``load_from_hf_checkpoint``.
    if mapping.is_last_pp_rank():
        ln_f_w, ln_f_b = get_weight_and_bias(model_params, 'transformer.ln_f',
                                             dtype)
        trtllm_falcon.ln_f.weight.value = ln_f_w
        if ln_f_b is not None:
            trtllm_falcon.ln_f.bias.value = ln_f_b

    tok = time.time()
    t = time.strftime('%H:%M:%S', time.gmtime(tok - tik))
    logger.info(f'Weights loaded. Total time: {t}')
|
|
|
|
|
|
def load_state_dict(file_path: Union[str, Path],
                    dtype: torch.dtype) -> Dict[str, np.ndarray]:
    """ Load weights from model file

    `safetensors` or `pytorch binary` is supported

    # Args.
        file_path: model file path (str or Path), ends with .bin or
            .safetensors.
        dtype: torch.dtype; every tensor is cast to it before being
            converted to NumPy.
    # Returns.
        Dict[str, np.ndarray]: parameter name -> NumPy array.
    # Raises.
        NotImplementedError: for any other file suffix.
    """
    # Accept plain strings as well; ``.suffix`` below needs a Path.
    file_path = Path(file_path)

    state_dict = {}
    if file_path.suffix == '.safetensors':
        # load from safetensors file
        from safetensors import safe_open
        with safe_open(file_path, framework='pt', device='cpu') as f:
            for name in f.keys():
                param = f.get_tensor(name).to(dtype)
                state_dict[name] = tensorrt_llm._utils.torch_to_numpy(param)
    elif file_path.suffix == '.bin':
        # load from pytorch bin file
        # NOTE(review): torch.load unpickles the file and can execute arbitrary
        # code; only load .bin checkpoints from trusted sources.
        state_dict = torch.load(file_path, map_location='cpu')
        # Values are replaced in place; the key set does not change, so
        # iterating the dict while assigning is safe.
        for name in state_dict:
            param = state_dict[name].to(dtype)
            state_dict[name] = tensorrt_llm._utils.torch_to_numpy(param)
    else:
        raise NotImplementedError(
            f'Support .safetensors or .bin files, but got {str(file_path)}')
    return state_dict
|
|
|
|
|
|
def retrieved_layer_index_from_name(name: str) -> Optional[int]:
    """Return the first run of decimal digits in ``name`` as an int.

    For parameter names such as ``transformer.h.12.mlp...`` this is the layer
    index. Names containing no digits (e.g. ``transformer.ln_f.weight``)
    yield ``None``.
    """
    match = re.search(r'\d+', name)
    if match is None:
        return None
    return int(match.group(0))
|
|
|
|
|
|
def iterate_shard_files(model_dir: Path, rank: int):
    """Yield the checkpoint shard files of ``model_dir`` behind a progress bar.

    A checkpoint directory may contain the same weights both as
    ``*.safetensors`` and as ``*.bin``. Loading both formats would convert
    and assign every weight twice, so safetensors shards are preferred. The
    ``*.bin`` files are used only when no safetensors shard exists. Files are
    yielded in sorted order so the loading order is deterministic.

    Args:
        model_dir: directory to scan (non-recursive).
        rank: used in the progress-bar label and as its ``position``.
    """
    import tqdm

    shard_files = sorted(model_dir.glob('*.safetensors'))
    if not shard_files:
        shard_files = sorted(model_dir.glob('*.bin'))
    desc = f'Rank [{rank}] Loading weights'
    yield from tqdm.tqdm(shard_files, desc=desc, position=rank)
|
|
|
|
|
|
def load_from_hf_checkpoint(
    trtllm_falcon: tensorrt_llm.models.FalconForCausalLM,
    model_dir: Union[str, Path],
    mapping: Optional[Mapping] = None,
    dtype: Union[str, torch.dtype] = torch.float32,
):
    """Stream weights from HF Falcon checkpoint files into ``trtllm_falcon``.

    Each shard file in ``model_dir`` is read with ``load_state_dict``. Every
    parameter is dispatched on its name to the matching TRT-LLM parameter and
    split for tensor parallelism where needed. The shard's dict is dropped
    before the next file is read.

    Args:
        trtllm_falcon: TRT-LLM Falcon model to populate in place. Parameters
            of layers outside its ``get_transformer_layers`` range are skipped.
        model_dir: directory holding the ``*.bin`` / ``*.safetensors`` shards.
        mapping: parallel mapping used for the tensor-parallel splits. It also
            decides whether this rank loads the embedding, ``lm_head`` and
            ``ln_f``. A fresh ``Mapping()`` is created per call when omitted.
        dtype: target dtype, as a ``torch.dtype`` or a string understood by
            ``tensorrt_llm._utils.str_dtype_to_torch``.
    """
    # Build the default per call instead of sharing one instance that would
    # be created once at function-definition time.
    if mapping is None:
        mapping = Mapping()

    logger.info('Loading weights from HF Falcon...')
    tik = time.time()

    model_dir = Path(model_dir)
    if isinstance(dtype, str):
        dtype = tensorrt_llm._utils.str_dtype_to_torch(dtype)

    def is_bias(_name):
        return 'bias' in _name

    layers_range = trtllm_falcon.get_transformer_layers(
        trtllm_falcon.mapping, trtllm_falcon.num_layers)
    for model_file in iterate_shard_files(model_dir, mapping.tp_rank):
        logger.debug(f'Loading file {str(model_file)}...')
        state_dict = load_state_dict(model_file, dtype)
        for name, param in state_dict.items():
            logger.debug(f'Converting weight {name}...')
            # Names without digits (embeddings, ln_f) are model-level weights;
            # otherwise the first number in the name is the layer index.
            i = retrieved_layer_index_from_name(name)
            if i is None:
                layer = None
            else:
                if i not in layers_range:
                    continue
                # ``trtllm_falcon.layers`` is indexed relative to the first
                # layer of this rank's range.
                layer = trtllm_falcon.layers[i - layers_range[0]]

            if 'self_attention.query_key_value' in name:
                if not is_bias(name):
                    layer.attention.qkv.weight.value = split_qkv_weight(
                        trtllm_falcon,
                        param,
                        mapping.tp_size,
                        mapping.tp_rank,
                        is_bias=False,
                        num_kv_heads=trtllm_falcon.num_kv_heads)
                else:
                    layer.attention.qkv.bias.value = split_qkv_weight(
                        trtllm_falcon,
                        param,
                        mapping.tp_size,
                        mapping.tp_rank,
                        is_bias=True,
                        num_kv_heads=trtllm_falcon.num_kv_heads)
            elif 'self_attention.dense' in name:
                # Weight split along dim 1; bias kept whole on every rank.
                if not is_bias(name):
                    layer.attention.dense.weight.value = split_matrix(
                        param, mapping.tp_size, mapping.tp_rank, dim=1)
                else:
                    layer.attention.dense.bias.value = param
            elif 'mlp.dense_h_to_4h' in name:
                # Weight and bias both split along dim 0.
                if not is_bias(name):
                    layer.mlp.fc.weight.value = split_matrix(param,
                                                             mapping.tp_size,
                                                             mapping.tp_rank,
                                                             dim=0)
                else:
                    layer.mlp.fc.bias.value = split_matrix(param,
                                                           mapping.tp_size,
                                                           mapping.tp_rank,
                                                           dim=0)
            elif 'mlp.dense_4h_to_h' in name:
                # Weight split along dim 1; bias kept whole on every rank.
                if not is_bias(name):
                    layer.mlp.proj.weight.value = split_matrix(param,
                                                               mapping.tp_size,
                                                               mapping.tp_rank,
                                                               dim=1)
                else:
                    layer.mlp.proj.bias.value = param
            elif 'ln_attn' in name or 'input_layernorm' in name:
                if not is_bias(name):
                    layer.input_layernorm.weight.value = param
                else:
                    layer.input_layernorm.bias.value = param
            elif 'ln_mlp' in name:
                assert layer.mlp_layernorm is not None
                if not is_bias(name):
                    layer.mlp_layernorm.weight.value = param
                else:
                    layer.mlp_layernorm.bias.value = param
            elif 'post_attention_layernorm' in name:
                assert layer.post_layernorm is not None
                if not is_bias(name):
                    layer.post_layernorm.weight.value = param
                else:
                    layer.post_layernorm.bias.value = param
            elif 'word_embeddings' in name:
                # lm_head reuses the word-embedding matrix, split along dim 0.
                if mapping.is_first_pp_rank():
                    trtllm_falcon.embedding.weight.value = param.copy()
                if mapping.is_last_pp_rank():
                    trtllm_falcon.lm_head.weight.value = split_matrix(
                        param, mapping.tp_size, mapping.tp_rank, dim=0)
            elif 'ln_f' in name:
                if mapping.is_last_pp_rank():
                    if not is_bias(name):
                        trtllm_falcon.ln_f.weight.value = param
                    else:
                        trtllm_falcon.ln_f.bias.value = param
        # Drop this shard's arrays before the next file is loaded.
        del state_dict

    tok = time.time()
    t = time.strftime('%H:%M:%S', time.gmtime(tok - tik))
    logger.info(f'Weights loaded. Total time: {t}')
|
|
|
|
|
|
def get_scaling_factors(
    model_path: Optional[Union[str, Path]],
    num_layers: int,
    quant_mode: Optional[QuantMode] = None,
) -> Optional[Dict[str, List[float]]]:
    """ Get the scaling factors for Falcon model

    Returns a dictionary of scaling factors for every layer of the Falcon
    model.

    Args:
        model_path (str or Path): Path to the quantized Falcon model scales.
            It must be an archive loadable by ``np.load`` whose string keys
            look like ``_np:layers:{i}:attention:dense:activation_scaling_factor``.
            If None, dummy scales from ``get_dummy_quant_scales`` are
            returned instead.
        num_layers (int): Number of transformer layers to read scales for.
        quant_mode (QuantMode, optional): When it enables an FP8 KV cache,
            a ``qkv_output`` scale of 1.0 is recorded for every layer.
            Otherwise ``qkv_output`` is left empty.

    Returns:
        dict: For each key, a list with one float per layer. The exception
        is ``qkv_output``, which is empty unless an FP8 KV cache is enabled.
        The QKV scales are the maximum of the separate q/k/v scales.

        example:

        {
            'qkv_act': qkv_act_scale,
            'qkv_weights': qkv_weights_scale,
            'qkv_output' : qkv_outputs_scale,
            'dense_act': dense_act_scale,
            'dense_weights': dense_weights_scale,
            'fc_act': fc_act_scale,
            'fc_weights': fc_weights_scale,
            'proj_act': proj_act_scale,
            'proj_weights': proj_weights_scale,
        }
    """

    if model_path is None:
        logger.warning("--quantized_fp8_model_path not specified. "
                       "Initialize quantization scales automatically.")
        return get_dummy_quant_scales(num_layers)

    has_fp8_kv_cache = (quant_mode is not None
                        and quant_mode.has_fp8_kv_cache())

    # yapf: disable
    scaling_factor = {
        'qkv_act': [],
        'qkv_weights': [],
        'qkv_output': [],
        'dense_act': [],
        'dense_weights': [],
        'fc_act': [],
        'fc_weights': [],
        'proj_act': [],
        'proj_weights': [],
    }

    # For an .npz archive np.load returns an NpzFile that keeps the file
    # open; the context manager closes it once every scale has been read.
    with np.load(model_path) as weight_dict:
        for layer in range(num_layers):
            scaling_factor['qkv_act'].append(max(
                weight_dict[f'_np:layers:{layer}:attention:qkv:q:activation_scaling_factor'].item(),
                weight_dict[f'_np:layers:{layer}:attention:qkv:k:activation_scaling_factor'].item(),
                weight_dict[f'_np:layers:{layer}:attention:qkv:v:activation_scaling_factor'].item()
                ))
            scaling_factor['qkv_weights'].append(max(
                weight_dict[f'_np:layers:{layer}:attention:qkv:q:weights_scaling_factor'].item(),
                weight_dict[f'_np:layers:{layer}:attention:qkv:k:weights_scaling_factor'].item(),
                weight_dict[f'_np:layers:{layer}:attention:qkv:v:weights_scaling_factor'].item()
                ))
            if has_fp8_kv_cache:
                # Not calibrating KV cache.
                scaling_factor['qkv_output'].append(1.0)
            scaling_factor['dense_act'].append(weight_dict[f'_np:layers:{layer}:attention:dense:activation_scaling_factor'].item())
            scaling_factor['dense_weights'].append(weight_dict[f'_np:layers:{layer}:attention:dense:weights_scaling_factor'].item())
            scaling_factor['fc_act'].append(weight_dict[f'_np:layers:{layer}:mlp:fc:activation_scaling_factor'].item())
            scaling_factor['fc_weights'].append(weight_dict[f'_np:layers:{layer}:mlp:fc:weights_scaling_factor'].item())
            scaling_factor['proj_act'].append(weight_dict[f'_np:layers:{layer}:mlp:proj:activation_scaling_factor'].item())
            scaling_factor['proj_weights'].append(weight_dict[f'_np:layers:{layer}:mlp:proj:weights_scaling_factor'].item())
    # yapf: enable

    for k, v in scaling_factor.items():
        # 'qkv_output' is only populated for an FP8 KV cache; checking its
        # length otherwise would fail for every real model_path.
        if k == 'qkv_output' and not has_fp8_kv_cache:
            continue
        assert len(v) == num_layers, \
        f'Expect scaling factor {k} of length {num_layers}, got {len(v)}'

    return scaling_factor
|