mirror of
https://github.com/NVIDIA/TensorRT-LLM.git
synced 2026-01-14 06:27:45 +08:00
183 lines
7.8 KiB
Python
183 lines
7.8 KiB
Python
# SPDX-FileCopyrightText: Copyright (c) 2022-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
|
|
import time
|
|
|
|
import numpy as np
|
|
import torch
|
|
|
|
import tensorrt_llm
|
|
from tensorrt_llm._utils import str_dtype_to_torch
|
|
from tensorrt_llm.models import ChatGLM2HeadModel
|
|
from tensorrt_llm.quantization import QuantMode
|
|
|
|
|
|
def extract_layer_idx(name):
    """Return the first purely numeric dot-separated component of ``name``.

    For a parameter name such as ``"transformer.encoder.layers.3.mlp.weight"``
    this yields the string ``"3"``. Returns ``None`` when no component
    consists solely of digits.
    """
    numeric_parts = (part for part in name.split('.') if part.isdigit())
    return next(numeric_parts, None)
|
|
|
|
|
|
def split(v, tp_size, idx, dim=0):
    """Return shard ``idx`` of ``v`` after cutting it into ``tp_size`` equal parts.

    With ``tp_size == 1`` the input object itself is returned (no copy).
    One-dimensional inputs are always cut along their only axis, so ``dim``
    is ignored for them; higher-rank inputs are cut along ``dim``. The
    returned shard is made C-contiguous. ``np.split`` raises ``ValueError``
    when the chosen axis length is not divisible by ``tp_size``.
    """
    if tp_size == 1:
        return v
    axis = 0 if len(v.shape) == 1 else dim
    shards = np.split(v, tp_size, axis=axis)
    return np.ascontiguousarray(shards[idx])
|
|
|
|
|
|
def load_from_hf_chatglm2_6B(tensorrt_llm_model,
                             hf_model,
                             rank=0,
                             tensor_parallel=1,
                             dtype="float32",
                             multi_query_mode=False):
    """Copy ChatGLM2 weights from a HuggingFace model into a TensorRT-LLM model.

    Args:
        tensorrt_llm_model: Target model. Its ``encoder``, ``embedding`` and
            ``lm_head`` parameter ``.value``s are overwritten in place. If it
            has a ``quant_mode`` attribute with weight-only quantization
            enabled, the linear weights are quantized on the fly.
        hf_model: HuggingFace ChatGLM2 model exposing ``transformer.embedding``,
            ``transformer.encoder`` and ``transformer.output_layer``.
        rank: Tensor-parallel rank. Currently unused by this loader.
        tensor_parallel: Tensor-parallel size. Weights are NOT sharded by
            this loader; a warning is logged when this is greater than 1.
        dtype: Name of the weight dtype (passed to ``str_dtype_to_torch``).
            Exported weights are cast to it, except for ``bfloat16``, which
            numpy cannot represent.
        multi_query_mode: Currently unused by this loader.

    Returns:
        The same ``tensorrt_llm_model`` object, with its weights populated.
    """
    tensorrt_llm.logger.info('Loading weights from HF ChatGLM2...')
    tik = time.time()

    if tensor_parallel > 1:
        # Nothing below splits tensors per rank, so every rank would receive
        # the full (unsharded) weights.
        tensorrt_llm.logger.warning(
            f'load_from_hf_chatglm2_6B does not shard weights; '
            f'tensor_parallel={tensor_parallel} is ignored and full weights '
            f'are loaded.')

    quant_mode = getattr(tensorrt_llm_model, 'quant_mode', QuantMode(0))
    plugin_weight_only_quant_type = None
    if quant_mode.is_int8_weight_only():
        plugin_weight_only_quant_type = torch.int8
    elif quant_mode.is_int4_weight_only():
        plugin_weight_only_quant_type = torch.quint4x2
    use_weight_only = quant_mode.is_weight_only()

    torch_dtype = str_dtype_to_torch(dtype)

    def to_numpy(tensor):
        """Detach ``tensor``, move it to CPU, cast it to ``dtype`` and return a numpy array."""
        tensor = tensor.detach().cpu()
        # numpy has no bfloat16, so such a request leaves the tensor un-cast.
        if torch_dtype != torch.bfloat16:
            tensor = tensor.to(torch_dtype)
        return tensor.numpy()

    def load_quant_weight(src, value_dst, scale_dst,
                          plugin_weight_only_quant_type):
        """Quantize ``src`` (transposed) and store the weights and scales.

        The quantized values go into ``value_dst`` and the scales into
        ``scale_dst``.
        """
        v = np.ascontiguousarray(src.transpose())
        processed_torch_weights, torch_weight_scales = torch.ops.fastertransformer.symmetric_quantize_last_axis_of_batched_matrix(
            torch.tensor(v), plugin_weight_only_quant_type)
        # workaround for trt not supporting int8 inputs in plugins currently
        value_dst.value = processed_torch_weights.view(
            dtype=torch.float32).numpy()
        scale_dst.value = torch_weight_scales.numpy()

    hf_transformer = hf_model.transformer

    tensorrt_llm_model.encoder.final_layernorm.weight.value = to_numpy(
        hf_transformer.encoder.final_layernorm.weight)
    tensorrt_llm_model.embedding.weight.value = to_numpy(
        hf_transformer.embedding.word_embeddings.weight)
    tensorrt_llm_model.lm_head.weight.value = to_numpy(
        hf_transformer.output_layer.weight)

    # Derive the layer count from the source model instead of hard-coding it.
    for i, hf_layer in enumerate(hf_transformer.encoder.layers):
        trt_layer = tensorrt_llm_model.encoder.layers[i]

        trt_layer.input_layernorm.weight.value = to_numpy(
            hf_layer.input_layernorm.weight)
        trt_layer.post_attention_layernorm.weight.value = to_numpy(
            hf_layer.post_attention_layernorm.weight)
        trt_layer.self_attention.qkv.bias.value = to_numpy(
            hf_layer.self_attention.query_key_value.bias)

        # Swap the first and second halves of the weight rows (dim 0) to
        # adapt to TensorRT-LLM's SwiGLU.
        h_to_4h = hf_layer.mlp.dense_h_to_4h.weight.detach().cpu()
        h_to_4h_halves = torch.split(h_to_4h, h_to_4h.shape[0] // 2, 0)
        h_to_4h_weight = to_numpy(torch.cat(h_to_4h_halves[::-1], 0))

        qkv_weight = to_numpy(hf_layer.self_attention.query_key_value.weight)
        dense_weight = to_numpy(hf_layer.self_attention.dense.weight)
        proj_weight = to_numpy(hf_layer.mlp.dense_4h_to_h.weight)

        if use_weight_only:
            load_quant_weight(
                src=h_to_4h_weight,
                value_dst=trt_layer.mlp.fc.weight,
                scale_dst=trt_layer.mlp.fc.per_channel_scale,
                plugin_weight_only_quant_type=plugin_weight_only_quant_type)
            load_quant_weight(
                src=proj_weight,
                value_dst=trt_layer.mlp.proj.weight,
                scale_dst=trt_layer.mlp.proj.per_channel_scale,
                plugin_weight_only_quant_type=plugin_weight_only_quant_type)
            load_quant_weight(
                src=qkv_weight,
                value_dst=trt_layer.self_attention.qkv.weight,
                scale_dst=trt_layer.self_attention.qkv.per_channel_scale,
                plugin_weight_only_quant_type=plugin_weight_only_quant_type)
            load_quant_weight(
                src=dense_weight,
                value_dst=trt_layer.self_attention.dense.weight,
                scale_dst=trt_layer.self_attention.dense.per_channel_scale,
                plugin_weight_only_quant_type=plugin_weight_only_quant_type)
        else:
            trt_layer.self_attention.qkv.weight.value = qkv_weight
            trt_layer.self_attention.dense.weight.value = dense_weight
            trt_layer.mlp.fc.weight.value = h_to_4h_weight
            trt_layer.mlp.proj.weight.value = proj_weight

    tok = time.time()
    elapsed = time.strftime('%H:%M:%S', time.gmtime(tok - tik))
    tensorrt_llm.logger.info(f'Weights loaded. Total time: {elapsed}')
    return tensorrt_llm_model
|
|
|
|
|
|
if __name__ == '__main__':
    # Script entry point: build a ChatGLM2-6B TensorRT-LLM model, apply the
    # configured weight-only quantization, then fill it from a local HF
    # checkpoint directory.
    from tensorrt_llm.layers.attention import PositionEmbeddingType
    from tensorrt_llm.models import weight_only_quantize

    kv_dtype = 'float16'
    tp_size = 1
    quant_mode = QuantMode.use_weight_only(False)

    trt_model = ChatGLM2HeadModel(
        num_layers=28,
        num_heads=32,
        hidden_size=4096,
        inter_size=None,
        vocab_size=65024,
        hidden_act='swiglu',
        max_position_embeddings=4096,
        position_embedding_type=PositionEmbeddingType.learned_absolute,
        rotary_embedding_percentage=1.0,
        dtype=kv_dtype,
        tensor_parallel=tp_size,  # TP only
        tensor_parallel_group=list(range(tp_size)),  # TP only
        apply_query_key_layer_scaling=False,
        quant_mode=quant_mode,
        bias=False,
        multi_query_mode=False)
    trt_model = weight_only_quantize(trt_model, quant_mode)

    model_dir = './pyTorchModel'
    print(f'Loading HF Chat_GLM2 ... from {model_dir}')

    import transformers
    hf_model = transformers.AutoModel.from_pretrained(
        model_dir, trust_remote_code=True).cpu()

    load_from_hf_chatglm2_6B(trt_model,
                             hf_model,
                             rank=0,
                             tensor_parallel=tp_size,
                             dtype='float16',
                             multi_query_mode=False)
    # Release the HF copy of the weights once they have been transferred.
    del hf_model
|