# SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import time
from pathlib import Path
from typing import Optional

import jax
import torch
from jax import dlpack as jax_dlpack
from torch.utils import dlpack as torch_dlpack

from ..._utils import pad_vocab_size, release_gc
from ...layers import MoeConfig
from ...logger import logger
from ...quantization import QuantAlgo
from ..modeling_utils import PretrainedConfig, QuantConfig, optimize_model
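

# Helpers for tensor-parallel (TP) weight sharding. `split` slices a tensor
# into `tp_size` equal chunks along `dim` and returns the chunk owned by rank
# `idx`; 1-D tensors are always chunked along dim 0. A minimal illustration
# (values shown for reference only):
#
#   split(torch.arange(8), tp_size=2, idx=1)          # -> tensor([4, 5, 6, 7])
#   split(torch.ones(4, 6), tp_size=3, idx=0, dim=1)  # -> shape (4, 2)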
def split(v, tp_size, idx, dim=0):
    if tp_size == 1:
        return v
    if len(v.shape) == 1:
        return torch.chunk(v, tp_size)[idx].contiguous()
    else:
        return torch.chunk(v, tp_size, dim=dim)[idx]


def split_matrix_tp(v, tensor_parallel, rank, dim):
    return split(v, tensor_parallel, rank, dim=dim)


def get_weight(config, prefix, dtype, postfix='.weight'):
    if config[prefix + postfix].dtype != dtype:
        config[prefix + postfix].data = config[prefix + postfix].to(dtype)
    return config[prefix + postfix].detach().cpu()
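

# The xAI Grok-1 release is a JAX checkpoint: parameters are addressed by
# path-like keys such as 'transformer/decoder_layer_0/multi_head_attention/query',
# and quantized tensors carry an int8 `weight` plus a separate `scales` array.
# The two helpers below pull those arrays out of the loaded JAX pytree and hand
# them to torch, using DLPack so the scale tensors are converted without an
# extra copy.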
def get_jax_weight(config, prefix, dtype, postfix='.weight', key_name='scale'):
    return torch.as_tensor((config[prefix + postfix][key_name])._value,
                           dtype=dtype).T


def get_jax_weight_scale(params, key):
    jax_obj = params[key]['w']
    jax_scales = jax.device_put(jax_obj.scales, device=jax.devices('cpu')[0])
    # jax_scales = jax.device_put(jax_obj.scales, device=jax.devices('gpu')[rank])
    torch_scales = torch_dlpack.from_dlpack(jax_dlpack.to_dlpack(jax_scales))
    return torch.as_tensor(jax_obj.weight._value,
                           dtype=torch.int8), torch_scales
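

# Packs one weight-only-quantized linear layer into TRT-LLM checkpoint format:
# the int8 weight is preprocessed for the mixed-GEMM plugin and stored under
# '<prefix>weight', the per-channel dequantization scales under
# '<prefix>per_channel_scale'. For example, the fused QKV projection of layer 0
# ends up as 'transformer.layers.0.attention.qkv.weight' and
# 'transformer.layers.0.attention.qkv.per_channel_scale'.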
def get_tllm_linear_weight(
        weight,
        torch_weight_scales,
        prefix,
        plugin_weight_only_quant_type=torch.int8,
        postfix='weight',
):
    results = {}
    processed_weight = torch.ops.trtllm.preprocess_weights_for_mixed_gemm(
        weight.contiguous(), plugin_weight_only_quant_type, torch.bfloat16)
    results[prefix + postfix] = processed_weight

    results[prefix + 'per_channel_scale'] = torch_weight_scales.contiguous()

    return results
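

# convert_grok walks the Grok-1 JAX parameter tree and emits a TRT-LLM
# checkpoint dict. Per decoder layer, the parameter paths map roughly as
# follows (see convert_layer below):
#
#   multi_head_attention/{query,key,value} -> attention.qkv.*    (fused, column-parallel)
#   multi_head_attention/linear            -> attention.dense.*  (row-parallel)
#   moe/linear_v + moe/linear              -> mlp.fc.*           (fused gated-MLP projection)
#   moe/linear_1                           -> mlp.proj.*         (down projection)
#   router                                 -> mlp.router.weight
#   rms_norm, rms_norm_1/_2/_3             -> input / post_attn / post / post_mlp layernorms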
def convert_grok(hf_model,
                 config,
                 mapping,
                 vocab_size=32000,
                 dtype='float32',
                 use_parallel_embedding=False,
                 sharding_dim=0,
                 use_weight_only=False,
                 share_embedding_table=False,
                 use_gemm_woq_plugin=False,
                 plugin_weight_only_quant_type=torch.int8,
                 moe_config=None):

    weights = {}
    tik = time.time()
    tensor_parallel = mapping.tp_size
    model_params = hf_model
    dtype = getattr(torch, dtype)

    num_attention_heads = config['num_attention_heads']
    hidden_size = config['hidden_size']
    hidden_size // num_attention_heads

    layers_range = mapping.pp_layers(config['num_hidden_layers'])

    # layers_range = mapping.pp_layers(2)

    def convert_layer(l):
        prefix = f'transformer/decoder_layer_{l}/'
        print(prefix)
        tllm_prex = f'transformer.layers.{l - layers_range[0]}.'

        q_weight, q_scale = get_jax_weight_scale(
            model_params, prefix + 'multi_head_attention/query')
        k_weight, k_scale = get_jax_weight_scale(
            model_params, prefix + 'multi_head_attention/key')
        v_weight, v_scale = get_jax_weight_scale(
            model_params, prefix + 'multi_head_attention/value')
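
        # The JAX attention weights are laid out [in_features, out_features],
        # so dim=1 below selects output channels: Q/K/V are sharded
        # column-parallel per TP rank and then concatenated into the single
        # fused QKV matrix expected by TRT-LLM (scales are sliced the same way).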
        wq = split(q_weight, mapping.tp_size, mapping.tp_rank, dim=1)
        wk = split(k_weight, mapping.tp_size, mapping.tp_rank, dim=1)
        wv = split(v_weight, mapping.tp_size, mapping.tp_rank, dim=1)
        qs = split(q_scale, mapping.tp_size, mapping.tp_rank, dim=1)
        ks = split(k_scale, mapping.tp_size, mapping.tp_rank, dim=1)
        vs = split(v_scale, mapping.tp_size, mapping.tp_rank, dim=1)
        split_v = torch.concat((wq, wk, wv), dim=1)
        scale_v = torch.concat((qs, ks, vs), dim=1)

        weights.update(
            get_tllm_linear_weight(split_v, scale_v.squeeze(),
                                   tllm_prex + 'attention.qkv.',
                                   plugin_weight_only_quant_type))

        attn_dense_weight, attn_dense_scales = get_jax_weight_scale(
            model_params, prefix + 'multi_head_attention/linear')

        split_v = split_matrix_tp(attn_dense_weight,
                                  tensor_parallel,
                                  mapping.tp_rank,
                                  dim=0)
        split_scales = split_matrix_tp(attn_dense_scales,
                                       tensor_parallel,
                                       mapping.tp_rank,
                                       dim=0)

        weights.update(
            get_tllm_linear_weight(split_v, split_scales.squeeze(),
                                   tllm_prex + 'attention.dense.',
                                   plugin_weight_only_quant_type))
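
        # MoE expert weights: 'moe/linear_v' and 'moe/linear' feed the gated
        # (geglu) activation and are fused into 'mlp.fc', while 'moe/linear_1'
        # is the down projection ('mlp.proj'). Each tensor has a leading expert
        # dimension, so it is sharded first across expert-parallel ranks on
        # dim 0 and then across tensor-parallel ranks on its feature dimension.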
        w3, s3 = get_jax_weight_scale(
            model_params, f'transformer/decoder_layer_{l}/moe/linear_v')

        w2, s2 = get_jax_weight_scale(
            model_params, f'transformer/decoder_layer_{l}/moe/linear_1')

        w1, s1 = get_jax_weight_scale(
            model_params, f'transformer/decoder_layer_{l}/moe/linear')

        # moe expert parallel
        w3_split = split(w3, mapping.moe_ep_size, mapping.moe_ep_rank, dim=0)
        w2_split = split(w2, mapping.moe_ep_size, mapping.moe_ep_rank, dim=0)
        w1_split = split(w1, mapping.moe_ep_size, mapping.moe_ep_rank, dim=0)

        s3_split = split(s3, mapping.moe_ep_size, mapping.moe_ep_rank, dim=0)
        s2_split = split(s2, mapping.moe_ep_size, mapping.moe_ep_rank, dim=0)
        s1_split = split(s1, mapping.moe_ep_size, mapping.moe_ep_rank, dim=0)
        # moe tensor parallel
        w3_split = split(w3_split,
                         mapping.moe_tp_size,
                         mapping.moe_tp_rank,
                         dim=2)
        w2_split = split(w2_split,
                         mapping.moe_tp_size,
                         mapping.moe_tp_rank,
                         dim=1)
        w1_split = split(w1_split,
                         mapping.moe_tp_size,
                         mapping.moe_tp_rank,
                         dim=2)

        s3_split = split(s3_split,
                         mapping.moe_tp_size,
                         mapping.moe_tp_rank,
                         dim=2)
        s2_split = split(s2_split,
                         mapping.moe_tp_size,
                         mapping.moe_tp_rank,
                         dim=1)
        s1_split = split(s1_split,
                         mapping.moe_tp_size,
                         mapping.moe_tp_rank,
                         dim=2)

        weights.update(
            get_tllm_linear_weight(w2_split,
                                   s2_split.reshape(moe_config.num_experts, -1),
                                   tllm_prex + 'mlp.proj.',
                                   plugin_weight_only_quant_type))

        weights.update(
            get_tllm_linear_weight(
                torch.concat([w3_split, w1_split], dim=-1),
                torch.concat([s3_split, s1_split],
                             dim=-1).reshape(moe_config.num_experts, -1),
                tllm_prex + 'mlp.fc.',
                plugin_weight_only_quant_type,
            ))
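
        # The router / gating weight is kept in float32 and stored directly,
        # bypassing the weight-only quantization path above.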
        moe_experts_gate_weights = get_jax_weight(model_params,
                                                  prefix + 'router',
                                                  torch.float32,
                                                  postfix='',
                                                  key_name='w').contiguous()

        weights[tllm_prex + 'mlp.router.weight'] = moe_experts_gate_weights

        # Layer norms do not use tensor parallelism
        input_ln_weight = get_jax_weight(model_params,
                                         prefix + 'rms_norm',
                                         dtype,
                                         postfix='')
        weights[tllm_prex + 'input_layernorm.weight'] = input_ln_weight

        post_attn_weight = get_jax_weight(model_params,
                                          prefix + 'rms_norm_1',
                                          dtype,
                                          postfix='')
        weights[tllm_prex + 'post_attn_layernorm.weight'] = post_attn_weight

        post_ln_weight = get_jax_weight(model_params,
                                        prefix + 'rms_norm_2',
                                        dtype,
                                        postfix='')
        weights[tllm_prex + 'post_layernorm.weight'] = post_ln_weight

        post_mlp_weight = get_jax_weight(model_params,
                                         prefix + 'rms_norm_3',
                                         dtype,
                                         postfix='')
        weights[tllm_prex + 'post_mlp_layernorm.weight'] = post_mlp_weight

    for l in layers_range:
        convert_layer(l)
        release_gc()
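
    # Embedding / LM head: Grok-1 ties the output head to the input embedding,
    # so lm_head.weight is a (possibly TP-sharded) copy of the embedding table.
    # If vocab_size is not divisible by tp_size it is first zero-padded up to
    # the next multiple; e.g. a hypothetical 32000-token vocabulary with
    # tp_size=6 would be padded to 32004 (Grok-1's 131072 vocabulary divides
    # evenly for the usual TP sizes).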
    v = get_jax_weight(model_params,
                       'language_model/in_out_embed',
                       dtype,
                       postfix='',
                       key_name='embeddings').T
    tie_word_embeddings = config['tie_word_embeddings']
    if tie_word_embeddings:
        # lm_head.weight has the same weights as embedding
        if mapping.is_last_pp_rank():
            if vocab_size % mapping.tp_size != 0:
                # padding
                vocab_size_padded = pad_vocab_size(vocab_size, mapping.tp_size)
                pad_width = vocab_size_padded - vocab_size

                v = torch.nn.functional.pad(v, (0, pad_width, 0, 0), 'constant',
                                            0)
            weights['lm_head.weight'] = split(v, mapping.tp_size,
                                              mapping.tp_rank)

    if use_parallel_embedding:
        v = split_matrix_tp(v,
                            mapping.tp_size,
                            mapping.tp_rank,
                            dim=sharding_dim)

    if mapping.is_first_pp_rank():
        weights['transformer.vocab_embedding.weight'] = v

    ln_f_w = get_jax_weight(model_params,
                            'language_model/rms_norm',
                            dtype,
                            postfix='')
    weights['transformer.ln_f.weight'] = ln_f_w

    tok = time.time()
    t = time.strftime('%H:%M:%S', time.gmtime(tok - tik))
    print(f'Weights loaded. Total time: {t}')
    return weights


def create_config_from_xai(dtype,
                           mapping,
                           quantization: QuantConfig = None,
                           override_fields: dict = {}):
    config = {}
    hf_config = {
        "architectures": ["Grok1ModelForCausalLM"],
        "vocab_size": 131072,
        "hidden_size": 6144,
        "intermediate_size": 32768,
        "num_hidden_layers": 64,
        "num_attention_heads": 48,
        "num_key_value_heads": 8,
        "attn_output_multiplier": 0.08838834764831845,
        "embedding_multiplier_scale": 78.38367176906169,
        "output_multiplier_scale": 0.5773502691896257,
        "max_attn_value": 30.0,
        "max_position_embeddings": 8192,
        "rms_norm_eps": 1e-5,
        "use_cache": True,
        "pad_token_id": 0,
        "bos_token_id": 1,
        "eos_token_id": 2,
        "tie_word_embeddings": True,
        "num_experts_per_tok": 2,
        "num_experts": 8,
        "output_router_logits": False,
        "router_aux_loss_coef": 0.001,
        "torch_dtype": "bfloat16",
        "transformers_version": "4.35.0"
    }
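
    # These Grok-1 hyperparameters are hard-coded above rather than read from
    # the checkpoint: 64 layers, hidden size 6144, 48 query heads / 8 KV heads,
    # and 8 experts with top-2 routing, matching the model xAI published.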
    # same for from_meta and from_cli_args
    n_head = hf_config['num_attention_heads']
    inter_size = hf_config['intermediate_size']
    n_layer = hf_config['num_hidden_layers']
    # n_layer = 2
    n_embd = hf_config['hidden_size']
    n_kv_head = hf_config['num_key_value_heads']
    rms_norm_eps = hf_config['rms_norm_eps']
    vocab_size = hf_config['vocab_size']
    n_positions = hf_config['max_position_embeddings']
    hidden_act = 'geglu'
    config['rotary_scaling'] = None
    rotary_base = 10000.0

    config[
        'moe_normalization_mode'] = MoeConfig.ExpertScaleNormalizationMode.NONE

    moe_num_experts = hf_config['num_experts']

    moe_top_k = hf_config['num_experts_per_tok']

    attn_output_multiplier = hf_config['attn_output_multiplier']
    embedding_multiplier_scale = hf_config['embedding_multiplier_scale']

    output_multiplier_scale = hf_config['output_multiplier_scale']
    max_attn_value = hf_config['max_attn_value']

    architecture = hf_config['architectures'][0]

    attn_bias = False

    config.update({
        'architecture': architecture,
        'dtype': dtype,
        'logits_dtype': 'float32',
        'num_hidden_layers': n_layer,
        'num_attention_heads': n_head,
        'hidden_size': n_embd,
        'intermediate_size': inter_size,
        'num_key_value_heads': n_kv_head,
        'vocab_size': vocab_size,
        'position_embedding_type': 'rope_gpt_neox',
        'max_position_embeddings': n_positions,
        'hidden_act': hidden_act,
        'rotary_base': rotary_base,
        'norm_epsilon': rms_norm_eps,
        'moe_num_experts': moe_num_experts,
        'moe_top_k': moe_top_k,
        'moe_normalization_mode': MoeConfig.ExpertScaleNormalizationMode.NONE,
        # TODO: should map directly from the Mapping object to the TRT-LLM checkpoint fields
        'mapping': {
            'world_size': mapping.tp_size * mapping.pp_size,
            'tp_size': mapping.tp_size,
            'pp_size': mapping.pp_size,
            'moe_tp_size': mapping.moe_tp_size,
            'moe_ep_size': mapping.moe_ep_size,
        },
        'attn_bias': attn_bias,
        "attn_output_multiplier": attn_output_multiplier,
        "embedding_multiplier_scale": embedding_multiplier_scale,
        "output_multiplier_scale": output_multiplier_scale,
        "max_attn_value": max_attn_value,
        "tie_word_embeddings": True,
    })

    config['quantization'] = quantization.to_dict()
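
    # override_fields is applied last and wins over everything above; callers
    # use it for flags consumed later by load_weights_from_xai, e.g.
    # 'use_parallel_embedding', 'embedding_sharding_dim',
    # 'share_embedding_table' or 'disable_weight_only_quant_plugin'.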
    config.update(override_fields)

    return config


def from_hugging_face(cls,
                      model_dir,
                      dtype,
                      *,
                      mapping,
                      quantization: QuantConfig = None,
                      override_fields={},
                      skip_loading_weights=False,
                      preloaded_model=None):
    ''' Create a Grok causal-LM object (an instance of `cls`) from the given parameters.
    '''
    assert model_dir is not None
    if isinstance(model_dir, Path):  # some code relies on this being a string
        model_dir = str(model_dir)

    if override_fields.get('share_embedding_table', False):
        logger.warning(
            "Grok model does not support share_embedding_table; setting share_embedding_table=False"
        )
        override_fields['share_embedding_table'] = False
    config = create_config_from_xai(dtype,
                                    mapping,
                                    quantization,
                                    override_fields=override_fields)

    pretrained_config = PretrainedConfig.from_dict(config)
    pretrained_config.set_rank(mapping.rank)  # TODO: remove this hack

    grok = cls.from_config(pretrained_config)
    grok = optimize_model(
        grok,
        use_parallel_embedding=pretrained_config.use_parallel_embedding,
        share_embedding_table=pretrained_config.share_embedding_table,
    )

    if skip_loading_weights:
        return grok

    weights = load_weights_from_xai(config=config,
                                    mapping=mapping,
                                    model=preloaded_model)

    grok.load(weights)
    return grok
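

# Illustrative driver code (hypothetical, not part of this module), assuming the
# Grok model class exposed by tensorrt_llm is passed as `cls` and the xAI JAX
# params have already been loaded into `jax_params`:
#
#   from tensorrt_llm.mapping import Mapping
#   mapping = Mapping(world_size=8, rank=0, tp_size=8, pp_size=1,
#                     moe_tp_size=1, moe_ep_size=8)
#   quant_config = QuantConfig(quant_algo=QuantAlgo.W8A16)
#   model = from_hugging_face(GrokForCausalLM, './grok-1', 'bfloat16',
#                             mapping=mapping, quantization=quant_config,
#                             preloaded_model=jax_params)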
def quantize(dtype,
             model_dir,
             output_dir,
             mapping,
             quantization: QuantConfig,
             *,
             override_fields,
             dataset_cache_dir: Optional[str] = None):
    '''
    Quantize and save the model as a TRT-LLM checkpoint to output_dir.
    '''
    pass  # The official Grok-1 model is published in int8 weight-only format, so there is no need to quantize again.


def load_weights_from_xai(*, config, mapping, model):
    assert model is not None
    plugin_weight_only_quant_type = None  # the value does not matter when use_weight_only is False
    quant_algo = config['quantization']['quant_algo']
    assert quant_algo == QuantAlgo.W8A16
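
    # Only the int8 weight-only (W8A16) format is supported here, which is how
    # the official Grok-1 checkpoint is distributed.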
    plugin_weight_only_quant_type = torch.int8

    moe_config = MoeConfig(config['moe_num_experts'], config['moe_top_k'],
                           config['moe_normalization_mode']).validate()

    use_weight_only = quant_algo in [QuantAlgo.W8A16]

    weights = convert_grok(
        model,
        config,
        mapping,
        vocab_size=config['vocab_size'],
        dtype=config['dtype'],
        use_weight_only=use_weight_only,
        use_gemm_woq_plugin=not config.get('disable_weight_only_quant_plugin',
                                           False),
        plugin_weight_only_quant_type=plugin_weight_only_quant_type,
        use_parallel_embedding=config.get('use_parallel_embedding', False),
        sharding_dim=config.get('embedding_sharding_dim', 0),
        share_embedding_table=config.get('share_embedding_table', False),
        moe_config=moe_config)
    return weights