mirror of
https://github.com/NVIDIA/TensorRT-LLM.git
synced 2026-01-14 06:27:45 +08:00
* Update TensorRT-LLM --------- Co-authored-by: Altair-Alpha <62340011+Altair-Alpha@users.noreply.github.com>
281 lines
11 KiB
Python
281 lines
11 KiB
Python
# SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
# SPDX-License-Identifier: Apache-2.0
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
|
|
from ..._utils import pad_vocab_size
|
|
from ...functional import (Tensor, is_gated_activation, non_gated_version, recv,
|
|
send)
|
|
from ...layers import (MLP, MOE, Attention, AttentionMaskType, ColumnLinear,
|
|
Embedding, GatedMLP, LayerNorm, MoeConfig,
|
|
PositionEmbeddingType)
|
|
from ...lora_manager import LoraConfig, use_lora
|
|
from ...mapping import Mapping
|
|
from ...module import Module
|
|
from ...quantization import QuantMode
|
|
from ..modeling_utils import DecoderLayerList, DecoderModelForCausalLM
|
|
from .config import GPTConfig
|
|
|
|
|
|
def MLPFactory(hidden_size,
               ffn_hidden_size,
               hidden_act,
               bias=True,
               dtype=None,
               moe_config=None,
               tp_group=None,
               tp_size=1,
               mapping=None,
               quant_mode=None,
               inner_layernorm=False,
               eps=1e-05):
    """Build the feed-forward sub-layer used by a GPT decoder layer.

    Returns an ``MOE`` layer when ``moe_config.has_moe()`` is true, otherwise
    a dense ``GatedMLP`` (when ``hidden_act`` is a gated activation according
    to ``is_gated_activation``) or a plain ``MLP``.

    Args:
        hidden_size: Hidden dimension of the model.
        ffn_hidden_size: Intermediate dimension of the feed-forward layer.
        hidden_act: Activation name. For the dense path it is converted with
            ``non_gated_version`` before being handed to the MLP class.
        bias: Whether the projections use a bias.
        dtype: Parameter dtype.
        moe_config: Mixture-of-experts configuration. ``None`` (default)
            means a fresh ``MoeConfig()``.
        tp_group: Tensor-parallel group forwarded to the layer.
        tp_size: Tensor-parallel size forwarded to the layer.
        mapping: Parallelism mapping, only forwarded to ``MOE``. ``None``
            (default) means a fresh ``Mapping()``.
        quant_mode: Quantization mode. ``None`` (default) means
            ``QuantMode(0)``.
        inner_layernorm: Forwarded to the dense MLP only (not to ``MOE``).
        eps: Layer-norm epsilon forwarded to the dense MLP only.

    Returns:
        The constructed ``MOE``, ``GatedMLP`` or ``MLP`` module.
    """
    # Build default objects per call instead of once at definition time, so
    # no two calls ever share (and can never mutate) one default instance.
    if moe_config is None:
        moe_config = MoeConfig()
    if mapping is None:
        mapping = Mapping()
    if quant_mode is None:
        quant_mode = QuantMode(0)

    if moe_config.has_moe():
        return MOE(moe_config,
                   hidden_size,
                   ffn_hidden_size,
                   hidden_act,
                   mapping=mapping,
                   bias=bias,
                   dtype=dtype,
                   tp_group=tp_group,
                   tp_size=tp_size,
                   quant_mode=quant_mode)

    # The gatedness of the activation selects the MLP class; the class itself
    # is then given the non-gated activation name.
    MLPClass = GatedMLP if is_gated_activation(hidden_act) else MLP
    hidden_act = non_gated_version(hidden_act)
    return MLPClass(
        hidden_size,
        ffn_hidden_size,
        hidden_act,
        bias,
        dtype,
        tp_group,
        tp_size,
        quant_mode,
        inner_layernorm=inner_layernorm,
        eps=eps,
    )
|
|
|
|
|
|
class GPTDecoderLayer(Module):
    """One GPT transformer block.

    Data flow in ``forward``: input layer norm -> self-attention -> residual
    add, then post layer norm -> feed-forward (built by ``MLPFactory``) ->
    residual add.
    """

    def __init__(self, config: GPTConfig, layer_idx: int):
        """Create the sub-modules of decoder layer ``layer_idx``.

        Args:
            config: Model configuration. The attributes ``inner_layernorm``,
                ``head_size`` and ``norm_before_bmm1`` are optional; when
                missing they fall back to ``False``, ``None`` and ``False``.
            layer_idx: Index of this layer in the full model. It is offset by
                the first layer of this pipeline stage to obtain the local
                index passed to ``Attention``.
        """
        super().__init__()
        self.layer_idx = layer_idx
        self.config = config

        mapping = config.mapping
        tp_group = mapping.tp_group
        tp_size = mapping.tp_size
        tp_rank = mapping.tp_rank

        self.input_layernorm = LayerNorm(normalized_shape=config.hidden_size,
                                         eps=config.norm_epsilon,
                                         dtype=config.dtype)

        # pp_layers presumably yields the global layer indices owned by this
        # pipeline rank -- TODO confirm; its first entry is the stage offset.
        stage_first_layer = mapping.pp_layers(config.num_hidden_layers)[0]
        use_inner_layernorm = getattr(config, "inner_layernorm", False)
        head_size = getattr(config, "head_size", None)

        self.attention = Attention(
            local_layer_idx=layer_idx - stage_first_layer,
            hidden_size=config.hidden_size,
            num_attention_heads=config.num_attention_heads,
            num_kv_heads=config.num_key_value_heads,
            max_position_embeddings=config.max_position_embeddings,
            num_layers=config.num_hidden_layers,
            q_scaling=config.q_scaling,
            apply_query_key_layer_scaling=config.apply_query_key_layer_scaling,
            dtype=config.dtype,
            attention_mask_type=AttentionMaskType.causal,
            attention_head_size=head_size,
            position_embedding_type=config.position_embedding_type,
            rotary_embedding_percentage=config.rotary_pct,
            rotary_embedding_base=config.rotary_base,
            rotary_embedding_scaling=config.rotary_scaling,
            bias=config.bias,
            tp_group=tp_group,
            tp_size=tp_size,
            tp_rank=tp_rank,
            quant_mode=config.quant_mode,
            qk_layernorm=config.qk_layernorm,
            inner_layernorm=use_inner_layernorm,
            eps=config.norm_epsilon)

        # Forwarded to the attention call in forward().
        self.norm_before_bmm1 = getattr(config, "norm_before_bmm1", False)

        # Without an explicit intermediate size, use 4x the hidden size.
        if config.intermediate_size is None:
            ffn_hidden_size = config.hidden_size * 4
        else:
            ffn_hidden_size = config.intermediate_size

        self.mlp = MLPFactory(hidden_size=config.hidden_size,
                              ffn_hidden_size=ffn_hidden_size,
                              hidden_act=config.hidden_act,
                              dtype=config.dtype,
                              bias=config.bias,
                              moe_config=config.moe,
                              tp_group=tp_group,
                              tp_size=tp_size,
                              mapping=mapping,
                              quant_mode=config.quant_mode,
                              inner_layernorm=use_inner_layernorm,
                              eps=config.norm_epsilon)

        self.post_layernorm = LayerNorm(normalized_shape=config.hidden_size,
                                        eps=config.norm_epsilon,
                                        dtype=config.dtype)

    def forward(self,
                hidden_states: Tensor,
                attention_mask=None,
                use_cache=False,
                kv_cache_params=None,
                attention_params=None,
                lora_layer_params=None,
                spec_decoding_params=None):
        """Run the block on ``hidden_states``.

        Returns:
            The block output, or ``(output, presents)`` when ``use_cache`` is
            true. ``presents`` is the second element of the attention result.
        """
        assert isinstance(hidden_states, Tensor)

        # Attention sub-layer: normalize, attend, add the residual back.
        attn_result = self.attention(
            self.input_layernorm(hidden_states),
            attention_mask=attention_mask,
            use_cache=use_cache,
            spec_decoding_params=spec_decoding_params,
            kv_cache_params=kv_cache_params,
            attention_params=attention_params,
            lora_layer_params=lora_layer_params,
            norm_before_bmm1=self.norm_before_bmm1)
        if use_cache:
            # With caching enabled, the attention returns (output, presents).
            attn_result, presents = attn_result
        hidden_states = hidden_states + attn_result

        # Feed-forward sub-layer: normalize, project, add the residual back.
        ffn_result = self.mlp(self.post_layernorm(hidden_states),
                              lora_layer_params=lora_layer_params)
        hidden_states = hidden_states + ffn_result

        if not use_cache:
            return hidden_states
        return (hidden_states, presents)
|
|
|
|
|
|
class GPTModel(Module):
    """GPT decoder stack for the current pipeline-parallel stage.

    The first stage owns the token (and optional learned position)
    embeddings. The last stage owns the final layer norm ``ln_f``. Other
    stages exchange hidden states with their neighbours via ``recv``/``send``.
    """

    def __init__(self, config: GPTConfig):
        super().__init__()
        mapping = config.mapping
        self.mapping = mapping
        self.position_embedding_type = config.position_embedding_type

        if mapping.is_first_pp_rank():
            self.vocab_embedding = Embedding(config.vocab_size,
                                             config.hidden_size,
                                             dtype=config.dtype)

            # Optional multiplier applied to token embeddings in forward().
            self.embedding_scale = config.embedding_scale

            uses_learned_positions = (config.position_embedding_type ==
                                      PositionEmbeddingType.learned_absolute)
            if uses_learned_positions:
                self.position_embedding = Embedding(
                    num_embeddings=config.max_position_embeddings,
                    embedding_dim=config.hidden_size,
                    dtype=config.dtype)

        self.layers = DecoderLayerList(GPTDecoderLayer, config)

        if mapping.is_last_pp_rank():
            self.ln_f = LayerNorm(normalized_shape=config.hidden_size,
                                  eps=config.norm_epsilon,
                                  dtype=config.dtype)

    def forward(self,
                input_ids,
                position_ids,
                use_cache=False,
                attention_mask=None,
                kv_cache_params=None,
                attention_params=None,
                hidden_states=None,
                prompt_embedding_table=None,
                prompt_tasks=None,
                prompt_vocab_size=None,
                lora_params=None,
                spec_decoding_params=None):
        """Run this stage of the model.

        On the first stage, hidden states come from the embeddings. The
        prompt-tuning arguments are passed to the vocab embedding only when
        ``prompt_embedding_table`` is given. On later stages they are
        received from the previous pipeline rank.

        Returns:
            The stage output, or ``(output, tuple(presents))`` when
            ``use_cache`` is true.
        """
        mapping = self.mapping

        if mapping.is_first_pp_rank():
            if prompt_embedding_table is None:
                hidden_states = self.vocab_embedding(input_ids)
            else:
                hidden_states = self.vocab_embedding(input_ids,
                                                     prompt_embedding_table,
                                                     prompt_tasks,
                                                     prompt_vocab_size)
            if self.embedding_scale is not None:
                hidden_states *= self.embedding_scale
            if self.position_embedding_type == PositionEmbeddingType.learned_absolute:
                hidden_states = hidden_states + self.position_embedding(
                    position_ids)
        else:
            hidden_states = recv(hidden_states, mapping.prev_pp_rank())

        hidden_states = self.layers(hidden_states,
                                    use_cache=use_cache,
                                    attention_mask=attention_mask,
                                    kv_cache_params=kv_cache_params,
                                    attention_params=attention_params,
                                    lora_params=lora_params,
                                    spec_decoding_params=spec_decoding_params)
        if use_cache:
            hidden_states, presents = hidden_states

        if mapping.is_last_pp_rank():
            hidden_states = self.ln_f(hidden_states)
        else:
            hidden_states = send(hidden_states, mapping.next_pp_rank())

        return (hidden_states, tuple(presents)) if use_cache else hidden_states
|
|
|
|
|
|
class GPTForCausalLM(DecoderModelForCausalLM):
    """GPT causal language model.

    Combines a ``GPTModel`` with a column-parallel LM head. The head exists
    only on the last pipeline stage; other stages pass ``None``.
    """
    config_class = GPTConfig

    def __init__(self, config: GPTConfig):
        transformer = GPTModel(config)
        mapping = config.mapping

        lm_head = None
        if mapping.is_last_pp_rank():
            # pad_vocab_size presumably rounds the vocab up so it splits
            # evenly across tp_size ranks -- TODO confirm.
            padded_vocab = pad_vocab_size(config.vocab_size, mapping.tp_size)
            lm_head = ColumnLinear(config.hidden_size,
                                   padded_vocab,
                                   bias=False,
                                   dtype=config.dtype,
                                   tp_group=mapping.tp_group,
                                   tp_size=mapping.tp_size,
                                   gather_output=True)

        # Maps TRT-LLM LoRA module names to the HF module names; consumed by
        # use_lora() below. Assigned before the base-class initializer, as in
        # the original.
        self.trtllm_modules_to_hf_modules = {
            "attn_q": "q_proj",
            "attn_k": "k_proj",
            "attn_v": "v_proj",
            "attn_dense": "o_proj",
            "mlp_h_to_4h": "c_fc",
            "mlp_4h_to_h": "c_proj",
        }
        super().__init__(config, transformer, lm_head)

    def use_lora(self, lora_config: LoraConfig):
        """Apply ``lora_config`` using this model's module-name mapping.

        The bare name ``use_lora`` inside the body resolves to the
        module-level function imported from ``lora_manager``, not to this
        method.
        """
        use_lora(self, lora_config, self.trtllm_modules_to_hf_modules)
|