# SPDX-FileCopyrightText: Copyright (c) 2022-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import tensorrt as trt
import torch
from ..._common import default_net
from ..._utils import pad_vocab_size, str_dtype_to_trt
from ...functional import (PositionEmbeddingType, Tensor, concat, constant,
expand, expand_dims, gather_last_token_logits,
gpt_attention, index_select, select, shape, slice,
split)
from ...layers import (MLP, AttentionMaskType, AttentionParams, ColumnLinear,
Embedding, KeyValueCacheParams, RmsNorm, RowLinear)
from ...mapping import Mapping
from ...module import Module, ModuleList
from ...parameter import Parameter
from ...quantization import QuantMode
from ..generation_mixin import GenerationMixin


def apply_rotary_pos_emb_trt(x: Tensor, rope_cache: Tensor) -> Tensor:
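    """Apply rotary position embedding to x using a precomputed cos/sin cache.

    For each pair (x0, x1) of channels in the first half of the head dimension,
    the rotation is
        out0 = x0 * cos - x1 * sin
        out1 = x1 * cos + x0 * sin
    while the second half of the head channels passes through unchanged.
    x is [batch, seq, num_heads, head_size]; rope_cache is
    [seq, batch, head_size // 4, 2] holding (cos, sin) per rotated pair.
    """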
    # x: [batch, seq, num_heads, head_size] -> [seq, batch, num_heads, head_size]
    x = x.permute((1, 0, 2, 3))
    # symbolic sizes: sq = seq, b = batch, nh = num_heads
    sq = shape(x, 0)
    b = shape(x, 1)
    nh = shape(x, 2)
    # rope_cache: [seq, batch, kv_channels // 4, 2]; the rotary part covers the
    # first half of the head channels, so rot_dim == 2 * shape(rope_cache, 2).
    rot_dim = shape(rope_cache, 2) * constant(np.array(2, dtype=np.int32))
starts = concat([0, 0, 0, 0])
sizes = concat([sq, b, nh, rot_dim])
    # first half of the head channels: the part that gets rotated
    x_rot = slice(x, starts, sizes)
    starts = concat([0, 0, 0, rot_dim])
    # second half of the head channels: passed through unchanged
    # (reusing `sizes` is valid because rot_dim is exactly half of head_size)
    x_pass = slice(x, starts, sizes)
    # truncate the cached rotary table to the actual sequence length
rope_cache = slice(rope_cache, (0, 0, 0, 0), (concat(
[sq,
shape(rope_cache, 1),
shape(rope_cache, 2),
shape(rope_cache, 3)])))
xshaped = x_rot.view(concat([sq, b, nh, rot_dim / 2, 2]))
rope_cache = rope_cache.view(concat([sq, b, 1, shape(xshaped, 3), 2]))
    # even element of each rotated pair
    xshape0 = select(xshaped, 4, 0)
    # odd element of each rotated pair
    xshape1 = select(xshaped, 4, 1)
    # cos component of the cache
    rope_cache0 = select(rope_cache, 4, 0)
    # sin component of the cache
    rope_cache1 = select(rope_cache, 4, 1)
out0 = xshape0 * rope_cache0 - xshape1 * rope_cache1
out1 = xshape1 * rope_cache0 + xshape0 * rope_cache1
out0 = expand_dims(out0, 4)
out1 = expand_dims(out1, 4)
x_out2_v1 = concat([out0, out1], 4)
x_out2 = x_out2_v1.view(
concat([sq, b, nh, shape(x_out2_v1, 3) * shape(x_out2_v1, 4)]))
output = concat([x_out2, x_pass], dim=3)
    # back to [batch, seq, num_heads_or_groups, head_size]
output = output.permute((1, 0, 2, 3))
return output


class RotaryEmbedding(Module):
def __init__(self, dim):
super().__init__()
self.dim = dim
def forward(self, seq_len: int):
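        """Precompute the rotary cos/sin table for seq_len positions.

        Returns a TensorRT constant of shape [seq_len, dim // 2, 2], where the
        last axis holds (cos, sin). With the ChatGLM2-6B defaults
        (dim == kv_channels // 2 == 64) and seq_len == 32768, the table is
        [32768, 32, 2] in fp16.
        """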
theta = 1.0 / (10000**(torch.arange(0, self.dim, 2) / self.dim))
seq_idx = torch.arange(seq_len)
idx_theta = torch.outer(seq_idx, theta).float()
cache = torch.stack(
[torch.cos(idx_theta), torch.sin(idx_theta)], dim=-1)
cache = cache.half()
# create rope embeddings and make it constant
cache = constant(cache.numpy())
return cache


class ChatGLM2Attention(Module):
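    """Multi-query (grouped) self-attention for ChatGLM2.

    The KV groups are expanded to the full number of query heads in forward()
    and the packed QKV tensor is handed to the fused gpt_attention plugin.
    """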
def __init__(
self,
hidden_size,
num_attention_heads,
layer_number,
kv_channels=128,
multi_query_group_num=2,
apply_query_key_layer_scaling=False,
attention_mask_type=AttentionMaskType.causal,
qkv_bias=True,
linear_bias=False,
dtype='float16',
use_int8_kv_cache=False,
tp_group=None,
tp_size=1,
):
super().__init__()
self.attention_mask_type = attention_mask_type
self.attention_head_size = hidden_size // num_attention_heads
self.num_attention_heads = num_attention_heads // tp_size
self.num_multi_query_groups_per_partition = multi_query_group_num
self.num_attention_kv_heads = self.num_attention_heads
self.hidden_size = hidden_size // tp_size
self.projection_size = num_attention_heads * kv_channels
self.hidden_size_per_attention_head = kv_channels
self.layer_number = layer_number
self.apply_query_key_layer_scaling = apply_query_key_layer_scaling
self.q_scaling = 1
if apply_query_key_layer_scaling:
self.q_scaling *= self.layer_number
self.position_embedding_type = PositionEmbeddingType.learned_absolute
self.multi_block_mode = False
self.multi_query_mode = False
self.rotary_embedding_dim = 0
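        # Rotary embeddings are applied explicitly in forward() via
        # apply_rotary_pos_emb_trt, so the attention plugin is configured with
        # no rotary embedding of its own (rotary_embedding_dim == 0).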
self.dtype = dtype
self.use_int8_kv_cache = use_int8_kv_cache
if self.use_int8_kv_cache:
self.kv_orig_quant_scale = Parameter(shape=(1, ), dtype='float32')
self.kv_quant_orig_scale = Parameter(shape=(1, ), dtype='float32')
else:
self.register_parameter('kv_orig_quant_scale', None)
self.register_parameter('kv_quant_orig_scale', None)
        # The fused QKV projection produces `projection_size` query channels plus
        # key and value channels for each KV group: 2 (key and value) * head_size
        # * multi_query_group_num, with the default group count of 2 hard-coded
        # in the expression below.
        self.qkv_hidden_size = (self.projection_size +
                                2 * self.hidden_size_per_attention_head * 2)
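        # For the usual ChatGLM2-6B settings (32 heads, kv_channels == 128, 2 KV
        # groups) this is 32 * 128 + 2 * 128 * 2 = 4096 + 512 = 4608 output channels.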
self.qkv = ColumnLinear(hidden_size,
self.qkv_hidden_size,
bias=qkv_bias,
dtype=dtype,
tp_group=tp_group,
tp_size=tp_size,
gather_output=False)
self.dense = RowLinear(hidden_size,
hidden_size,
bias=linear_bias,
dtype=dtype,
tp_group=tp_group,
tp_size=tp_size)
def forward(self,
hidden_states: Tensor,
rotary_pos_emb,
use_cache=True,
kv_cache_params=None,
attention_params=None):
if not default_net().plugin_config.gpt_attention_plugin:
raise ValueError(
                'ChatGLM2 is only supported with the GPTAttention plugin; please build with the --use_gpt_attention_plugin argument.'
)
assert isinstance(hidden_states, Tensor)
qkv = self.qkv(hidden_states)
query, key, value = split(qkv, [
self.num_attention_heads * self.hidden_size_per_attention_head,
self.num_multi_query_groups_per_partition *
self.hidden_size_per_attention_head,
self.num_multi_query_groups_per_partition *
self.hidden_size_per_attention_head,
],
dim=-1)
query = query.view(
concat([
shape(qkv, 0),
shape(qkv, 1), self.num_attention_heads,
self.attention_head_size
]))
key = key.view(
concat([
shape(qkv, 0),
shape(qkv, 1), self.num_multi_query_groups_per_partition,
self.attention_head_size
]))
value = value.view(
concat([
shape(qkv, 0),
shape(qkv, 1), self.num_multi_query_groups_per_partition,
self.attention_head_size
]))
if rotary_pos_emb is not None:
query = apply_rotary_pos_emb_trt(query, rotary_pos_emb)
key = apply_rotary_pos_emb_trt(key, rotary_pos_emb)
        # add a broadcast axis: key -> [batch, seq, num_kv_groups, 1, head_size]
key = expand_dims(key, 3)
        # replicate each KV group to match the number of query heads
        # (num_attention_heads // num_kv_groups copies, e.g. 32 // 2 = 16 for ChatGLM2-6B)
expand_rate = self.num_attention_heads // self.num_multi_query_groups_per_partition
key = expand(
key,
concat([
shape(key, 0),
shape(key, 1),
shape(key, 2), expand_rate,
shape(key, 4)
]))
        # key -> [batch, seq, num_heads, head_size]
key = key.view(
concat([
shape(key, 0),
shape(key, 1),
shape(key, 2) * shape(key, 3),
shape(key, 4)
]))
value = expand_dims(value, 3)
value = expand(
value,
concat([
shape(value, 0),
shape(value, 1),
shape(value, 2), expand_rate,
shape(value, 4)
]))
value = value.view(
concat([
shape(value, 0),
shape(value, 1),
shape(value, 2) * shape(value, 3),
shape(value, 4)
]))
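        # The expand + view pairs above replicate each KV group expand_rate times
        # along the head dimension (conceptually a repeat_interleave), so key and
        # value end up with the same number of heads as query.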
qkv = concat([query, key, value], dim=2)
qkv = qkv.view(
concat([shape(qkv, 0),
shape(qkv, 1), self.hidden_size * 3]))
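        # The concat + view above pack query/key/value into a single
        # [batch, seq, 3 * local_hidden_size] tensor, the layout expected by the
        # fused gpt_attention plugin.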
assert attention_params.is_valid(
default_net().plugin_config.gpt_attention_plugin,
default_net().plugin_config.remove_input_padding)
assert kv_cache_params.is_valid(
default_net().plugin_config.gpt_attention_plugin)
kv_orig_quant_scale = self.kv_orig_quant_scale.value if self.use_int8_kv_cache else None
kv_quant_orig_scale = self.kv_quant_orig_scale.value if self.use_int8_kv_cache else None
context, past_key_value = gpt_attention(
tensor=qkv,
past_key_value=kv_cache_params.get_first_past_key_value(),
sequence_length=attention_params.sequence_length,
host_past_key_value_lengths=kv_cache_params.
host_past_key_value_lengths,
context_lengths=attention_params.context_lengths,
cache_indirection=kv_cache_params.cache_indirection,
host_request_types=attention_params.host_request_types,
num_heads=self.num_attention_heads,
            num_kv_heads=self.num_attention_heads,  # KV heads were expanded above to match the query heads
hidden_size_per_head=self.attention_head_size,
q_scaling=self.q_scaling,
rotary_embedding_dim=self.rotary_embedding_dim,
position_embedding_type=self.position_embedding_type,
multi_block_mode=self.multi_block_mode,
kv_orig_quant_scale=kv_orig_quant_scale,
kv_quant_orig_scale=kv_quant_orig_scale,
kv_cache_quant_mode=QuantMode.INT8_KV_CACHE
if self.use_int8_kv_cache else QuantMode(0),
kv_cache_block_pointers=kv_cache_params.
get_first_kv_cache_block_pointers(),
max_context_length=attention_params.max_context_length,
host_context_lengths=attention_params.host_context_lengths)
# dense layer after self-attention
context = self.dense(context)
if use_cache:
return (context, past_key_value)
else:
return context


class ChatGLM2Block(Module):
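    """A single ChatGLM2 transformer layer.

    Structure: RMSNorm -> multi-query self-attention -> residual add ->
    RMSNorm -> SwiGLU MLP -> residual add.
    """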
def __init__(self,
hidden_size,
num_attention_heads,
kv_channels=128,
multi_query_group_num=2,
apply_query_key_layer_scaling=False,
attention_mask_type=AttentionMaskType.causal,
qkv_bias=True,
linear_bias=False,
use_int8_kv_cache=False,
tp_group=None,
tp_size=1,
ffn_hiden_size=13696,
layer_number=1,
eps=1e-5,
act_func='swiglu',
dtype=trt.float16,
quant_mode=QuantMode(0)):
super(ChatGLM2Block, self).__init__()
self.layer_number = layer_number
self.hidden_size = hidden_size
self.num_attention_heads = num_attention_heads
self.dtype = dtype
self.ffn_hiden_size = ffn_hiden_size
self.apply_residual_connection_post_layernorm = False
self.fp32_residual_connection = False
LayerNormFunc = RmsNorm
# Layernorm on the input data.
self.input_layernorm = LayerNormFunc(self.hidden_size,
eps=eps,
dtype=dtype)
# Self attention.
self.self_attention = ChatGLM2Attention(
hidden_size, num_attention_heads, layer_number, kv_channels,
multi_query_group_num, apply_query_key_layer_scaling,
attention_mask_type, qkv_bias, linear_bias, dtype,
use_int8_kv_cache, tp_group, tp_size)
self.hidden_dropout = 0.0
# Layernorm on the attention output
self.post_attention_layernorm = LayerNormFunc(self.hidden_size,
eps=eps,
dtype=dtype)
self.mlp = MLP(self.hidden_size, ffn_hiden_size, act_func, linear_bias,
dtype)
def forward(self,
hidden_states,
rotary_pos_emb,
use_cache=True,
kv_cache_params=None,
attention_params=None):
# hidden_states: [s, b, h]
# Layer norm at the beginning of the transformer layer.
layernorm_output = self.input_layernorm(hidden_states)
# Self attention.
attention_output, kv_cache = self.self_attention(
layernorm_output,
rotary_pos_emb,
use_cache=use_cache,
kv_cache_params=kv_cache_params,
attention_params=attention_params)
        # Residual connection (taken from the layernorm output or from the raw
        # block input, depending on the configuration flag).
        if self.apply_residual_connection_post_layernorm:
            residual = layernorm_output
        else:
            residual = hidden_states

        layernorm_input = residual + attention_output
# Layer norm post the self attention.
layernorm_output = self.post_attention_layernorm(layernorm_input)
# MLP.
mlp_output = self.mlp(layernorm_output)
# Second residual connection.
if self.apply_residual_connection_post_layernorm:
residual = layernorm_output
else:
residual = layernorm_input
output = residual + mlp_output
return output, kv_cache


class ChatGLM2Transformer(Module):
    """Stack of `num_layers` ChatGLM2Block layers with a final RMSNorm."""
def __init__(self,
hidden_size,
num_attention_heads,
kv_channels=128,
multi_query_group_num=2,
apply_query_key_layer_scaling=False,
attention_mask_type=AttentionMaskType.causal,
qkv_bias=True,
linear_bias=False,
use_int8_kv_cache=False,
tp_group=None,
tp_size=1,
ffn_hiden_size=13696,
num_layers=28,
eps=1e-5,
act_func='swiglu',
dtype=trt.float16,
quant_mode=QuantMode(0)):
super(ChatGLM2Transformer, self).__init__()
self.fp32_residual_connection = False
self.post_layer_norm = True
# Number of layers.
self.num_layers = num_layers
# Transformer layers.
def build_layer(layer_number):
return ChatGLM2Block(hidden_size, num_attention_heads, kv_channels,
multi_query_group_num,
apply_query_key_layer_scaling,
attention_mask_type, qkv_bias, linear_bias,
use_int8_kv_cache, tp_group, tp_size,
ffn_hiden_size, layer_number, eps, act_func,
dtype, quant_mode)
self.layers = ModuleList(
build_layer(i + 1) for i in range(self.num_layers))
if self.post_layer_norm:
self.final_layernorm = RmsNorm(hidden_size, eps=eps, dtype=dtype)
self.gradient_checkpointing = False
def _get_layer(self, layer_number):
return self.layers[layer_number]
def forward(self,
hidden_states,
rotary_pos_emb,
use_cache=True,
kv_cache_params=None,
attention_params=None):
presents = []
for index in range(self.num_layers):
layer = self._get_layer(index)
hidden_states, kv_cache = layer(
hidden_states,
rotary_pos_emb,
use_cache=use_cache,
kv_cache_params=KeyValueCacheParams(
past_key_value=[kv_cache_params.past_key_value[index]],
kv_cache_block_pointers=[
kv_cache_params.kv_cache_block_pointers[index]
],
host_past_key_value_lengths=kv_cache_params.
host_past_key_value_lengths,
cache_indirection=kv_cache_params.cache_indirection),
attention_params=attention_params)
presents.append(kv_cache)
if self.post_layer_norm:
hidden_states = self.final_layernorm(hidden_states)
return hidden_states, presents


class ChatGLM2Model(Module):
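    """ChatGLM2 backbone: token embedding, precomputed rotary position cache and
    the ChatGLM2Transformer encoder stack, without the LM head."""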
def __init__(self,
hidden_size,
num_attention_heads,
kv_channels=128,
multi_query_group_num=2,
apply_query_key_layer_scaling=False,
attention_mask_type=AttentionMaskType.causal,
qkv_bias=True,
linear_bias=False,
use_int8_kv_cache=False,
mapping=Mapping(),
ffn_hiden_size=13696,
num_layers=28,
eps=1e-5,
act_func='swiglu',
dtype=trt.float16,
quant_mode=QuantMode(0),
max_seq_length=32768,
vocab_size=65024):
super(ChatGLM2Model, self).__init__()
self.dtype = dtype
self.embedding = Embedding(vocab_size, hidden_size, dtype=dtype)
self.num_layers = num_layers
self.multi_query_group_num = multi_query_group_num
self.kv_channels = kv_channels
# Rotary positional embeddings
self.max_seq_length = max_seq_length
rotary_dim = kv_channels
        self.rotary_pos_emb = RotaryEmbedding(rotary_dim // 2)
self.encoder = ChatGLM2Transformer(
hidden_size, num_attention_heads, kv_channels,
multi_query_group_num, apply_query_key_layer_scaling,
attention_mask_type, qkv_bias, linear_bias, use_int8_kv_cache,
mapping.tp_group, mapping.tp_size, ffn_hiden_size, num_layers, eps,
act_func, dtype, quant_mode)
def forward(
self,
input_ids: Tensor,
position_ids,
use_cache=True,
kv_cache_params=None,
attention_params=None,
):
inputs_embeds = self.embedding(input_ids)
        # Rotary positional embeddings: build the full table for max_seq_length
        # positions, shape [max_seq_length, kv_channels // 4, 2] (cos/sin pairs).
rotary_pos_emb = self.rotary_pos_emb(self.max_seq_length)
flat_position = position_ids.view(
concat([shape(position_ids, 0) * shape(position_ids, 1)]))
selected_pos_emb = index_select(rotary_pos_emb, 0, flat_position)
        # gather the cache entries for the requested positions: [batch, seq, kv_channels // 4, 2]
selected_pos_emb = selected_pos_emb.view(
concat([
shape(position_ids, 0),
shape(position_ids, 1),
shape(rotary_pos_emb, 1),
shape(rotary_pos_emb, 2)
]))
        # to [seq, batch, kv_channels // 4, 2], as expected by apply_rotary_pos_emb_trt
        selected_pos_emb = selected_pos_emb.permute((1, 0, 2, 3))

        # Run the encoder.
hidden_states, presents = self.encoder(
inputs_embeds,
selected_pos_emb,
use_cache=use_cache,
kv_cache_params=kv_cache_params,
attention_params=attention_params,
)
return hidden_states, presents


class ChatGLM2HeadModel(ChatGLM2Model, GenerationMixin):
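    """ChatGLM2 backbone plus a vocab-padded LM head.

    GenerationMixin supplies prepare_basic_inputs(), which prepare_inputs() uses
    to declare the dynamic-shape engine inputs.
    """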
def __init__(self,
hidden_size,
num_attention_heads,
kv_channels=128,
multi_query_group_num=2,
apply_query_key_layer_scaling=False,
attention_mask_type=AttentionMaskType.causal,
qkv_bias=True,
linear_bias=False,
use_int8_kv_cache=False,
mapping=Mapping(),
ffn_hiden_size=13696,
num_layers=28,
eps=1e-5,
act_func='swiglu',
dtype=trt.float16,
quant_mode=QuantMode(0),
max_seq_length=32768,
vocab_size=65024,
use_cache=True,
kv_cache_block_pointers=None):
if isinstance(dtype, str):
self._kv_dtype = str_dtype_to_trt(dtype)
else:
assert isinstance(dtype, trt.DataType)
self._kv_dtype = dtype
self._dtype = self._kv_dtype
if quant_mode.has_int8_kv_cache():
self._kv_dtype = str_dtype_to_trt('int8')
elif quant_mode.has_fp8_kv_cache():
self._kv_dtype = str_dtype_to_trt('fp8')
self.use_cache = use_cache
self.kv_cache_block_pointers = kv_cache_block_pointers
self.quant_mode = quant_mode
self._num_layers = num_layers
self._num_heads = num_attention_heads
self._hidden_size = hidden_size
self._vocab_size = vocab_size
self._tp_size = mapping.tp_size
super().__init__(hidden_size, num_attention_heads, kv_channels,
multi_query_group_num, apply_query_key_layer_scaling,
attention_mask_type, qkv_bias, linear_bias,
use_int8_kv_cache, mapping, ffn_hiden_size, num_layers,
eps, act_func, dtype, quant_mode, max_seq_length,
vocab_size)
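        # Pad the vocabulary so the LM head can be split evenly across
        # tensor-parallel ranks.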
vocab_size_padded = pad_vocab_size(vocab_size, mapping.tp_size)
self.lm_head = ColumnLinear(hidden_size,
vocab_size_padded,
bias=False,
dtype=dtype,
tp_group=mapping.tp_group,
tp_size=mapping.tp_size,
gather_output=True)
def forward(self,
input_ids=None,
position_ids=None,
last_token_ids=None,
kv_cache_params=None,
attention_params=None):
hidden_states = super().forward(input_ids, position_ids, self.use_cache,
kv_cache_params, attention_params)
if self.use_cache:
hidden_states, presents = hidden_states
hidden_states = gather_last_token_logits(
hidden_states, last_token_ids,
default_net().plugin_config.remove_input_padding)
lm_logits = self.lm_head(hidden_states)
lm_logits.mark_output('logits', self._dtype)
        if not default_net().plugin_config.paged_kv_cache:
for i, present in enumerate(presents):
present.mark_output(f'present_key_value_{i}', self._kv_dtype)
return (lm_logits, presents)
return lm_logits
def prepare_inputs(self,
max_batch_size,
max_input_len,
max_new_tokens,
use_cache,
max_beam_width: int = 1):
        '''@brief: Prepare input tensors for the model. The given sizes are used to
            determine the ranges of the dynamic dimensions when building with TRT
            dynamic shapes.
        @return: a list of values that can be fed into self.forward()
        '''
# Prepare inputs
head_size = self._hidden_size // self._num_heads
num_heads = self._num_heads // self._tp_size
remove_input_padding = default_net().plugin_config.remove_input_padding
use_gpt_attention_plugin = default_net(
).plugin_config.gpt_attention_plugin
use_gemm_plugin = default_net().plugin_config.gemm_plugin
model_inputs = self.prepare_basic_inputs(
max_batch_size,
max_beam_width,
max_input_len,
max_new_tokens,
num_heads,
head_size,
self.num_layers,
self._kv_dtype,
remove_input_padding,
use_gpt_attention_plugin,
use_gemm_plugin=use_gemm_plugin)
return (model_inputs['input_ids'], model_inputs['position_ids'],
model_inputs['last_token_ids'],
KeyValueCacheParams(
past_key_value=model_inputs['past_key_value'],
host_past_key_value_lengths=model_inputs[
'host_past_key_value_lengths'],
kv_cache_block_pointers=model_inputs[
'kv_cache_block_pointers_list'],
cache_indirection=model_inputs['cache_indirection'],
),
AttentionParams(
sequence_length=model_inputs['sequence_length'],
context_lengths=model_inputs['context_lengths'],
host_context_lengths=model_inputs['host_context_lengths'],
max_context_length=max_input_len,
host_request_types=model_inputs['host_request_types']))
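

# A minimal usage sketch (hypothetical values; the real build flow additionally
# creates a Builder/Network, enables the gpt_attention plugin, and loads the
# HF weights before tracing the model):
#
#   from tensorrt_llm.mapping import Mapping
#
#   model = ChatGLM2HeadModel(hidden_size=4096, num_attention_heads=32,
#                             num_layers=28, dtype='float16', mapping=Mapping())
#   inputs = model.prepare_inputs(max_batch_size=8, max_input_len=1024,
#                                 max_new_tokens=1024, use_cache=True)
#   # inputs == (input_ids, position_ids, last_token_ids,
#   #            KeyValueCacheParams, AttentionParams)
#   logits = model(*inputs)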