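# OPT model support for the TensorRT-LLM PyTorch backend. The classes below
# mirror the Hugging Face `transformers` OPT architecture on top of
# TensorRT-LLM's reusable modules (Attention, Linear, Embedding, DecoderModel).
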
from typing import Optional

import torch
from torch import nn
from tqdm import tqdm
from transformers import OPTConfig
from transformers.activations import ACT2FN

from tensorrt_llm._torch.attention_backend import AttentionMetadata
from tensorrt_llm._torch.model_config import ModelConfig
from tensorrt_llm._torch.models.modeling_utils import (DecoderModel,
                                                       DecoderModelForCausalLM,
                                                       duplicate_kv_weight,
                                                       register_auto_model)
from tensorrt_llm._torch.modules.attention import Attention
from tensorrt_llm._torch.modules.decoder_layer import DecoderLayer
from tensorrt_llm._torch.modules.embedding import Embedding
from tensorrt_llm._torch.modules.linear import Linear, TensorParallelMode


class LayerNorm(nn.LayerNorm):

    def reset_parameters(self) -> None:
        # Skip the initialization operations that conflict with MetaInitMode
        pass


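# OPT uses standard multi-head attention, so the number of key/value heads
# equals the number of attention heads (no grouped-query attention). Position
# information comes from the learned absolute position embeddings added in
# OPTModel.forward, which is why the decoder layer passes position_ids=None
# to this attention module.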
class OPTAttention(Attention):

    def __init__(
        self,
        model_config: ModelConfig[OPTConfig],
        layer_idx: Optional[int] = None,
    ):
        config = model_config.pretrained_config
        super().__init__(
            hidden_size=config.hidden_size,
            num_attention_heads=config.num_attention_heads,
            num_key_value_heads=config.num_attention_heads,
            max_position_embeddings=config.max_position_embeddings,
            bias=config.enable_bias,
            layer_idx=layer_idx,
            dtype=config.torch_dtype,
            config=model_config,
        )


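# One OPT decoder block: self-attention followed by a two-layer feed-forward
# network. `do_layer_norm_before` picks pre-LN placement (most OPT sizes) or
# post-LN placement (OPT-350m) for both layer norms. fc1 is column-parallel
# and fc2 is row-parallel, so under tensor parallelism only fc2's output needs
# an all-reduce to restore the full hidden states.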
class OPTDecoderLayer(DecoderLayer):

    def __init__(
        self,
        model_config: ModelConfig[OPTConfig],
        layer_idx: int,
    ):
        super().__init__()
        config = model_config.pretrained_config

        self.self_attn = OPTAttention(model_config, layer_idx=layer_idx)

        self.do_layer_norm_before = config.do_layer_norm_before
        self.activation_fn = ACT2FN[config.activation_function]

        self.self_attn_layer_norm = LayerNorm(
            config.hidden_size,
            elementwise_affine=config.layer_norm_elementwise_affine,
            dtype=config.torch_dtype)
        self.fc1 = Linear(config.hidden_size,
                          config.ffn_dim,
                          bias=config.enable_bias,
                          dtype=config.torch_dtype,
                          mapping=model_config.mapping,
                          tensor_parallel_mode=TensorParallelMode.COLUMN,
                          quant_config=model_config.get_quant_config(),
                          allreduce_strategy=model_config.allreduce_strategy)
        self.fc2 = Linear(config.ffn_dim,
                          config.hidden_size,
                          bias=config.enable_bias,
                          dtype=config.torch_dtype,
                          mapping=model_config.mapping,
                          tensor_parallel_mode=TensorParallelMode.ROW,
                          quant_config=model_config.get_quant_config(),
                          allreduce_strategy=model_config.allreduce_strategy)
        self.final_layer_norm = LayerNorm(
            config.hidden_size,
            elementwise_affine=config.layer_norm_elementwise_affine,
            dtype=config.torch_dtype)

    def forward(
        self,
        position_ids: torch.IntTensor,
        hidden_states: torch.Tensor,
        attn_metadata: AttentionMetadata,
        **kwargs,
    ) -> torch.Tensor:
        # Self Attention
        residual = hidden_states

        # 125m, 1.3B, ..., 175B applies layer norm BEFORE attention
        if self.do_layer_norm_before:
            hidden_states = self.self_attn_layer_norm(hidden_states)

        hidden_states = self.self_attn(
            position_ids=None,
            hidden_states=hidden_states,
            attn_metadata=attn_metadata,
            **kwargs,
        )
        hidden_states = residual + hidden_states

        # 350m applies layer norm AFTER attention
        if not self.do_layer_norm_before:
            hidden_states = self.self_attn_layer_norm(hidden_states)

        # Fully Connected
        residual = hidden_states

        # 125m, 1.3B, ..., 175B applies layer norm BEFORE the feed-forward block
        if self.do_layer_norm_before:
            hidden_states = self.final_layer_norm(hidden_states)

        hidden_states = self.fc1(hidden_states)
        hidden_states = self.activation_fn(hidden_states)
        hidden_states = self.fc2(hidden_states)
        hidden_states = residual + hidden_states

        # 350m applies layer norm AFTER the feed-forward block
        if not self.do_layer_norm_before:
            hidden_states = self.final_layer_norm(hidden_states)

        return hidden_states


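# The OPT decoder backbone. Token embeddings use `word_embed_proj_dim`; when it
# differs from `hidden_size` (as in OPT-350m), `project_in` / `project_out` map
# between the two widths. The learned position table has
# `max_position_embeddings + 2` rows because OPT checkpoints offset position
# indices by 2 (a legacy of the original fairseq/metaseq embedding layout);
# forward() applies the same +2 shift before the lookup.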
class OPTModel(DecoderModel):

    def __init__(self, model_config: ModelConfig[OPTConfig]):
        super().__init__(model_config)
        config = model_config.pretrained_config

        self.padding_idx = config.pad_token_id
        self.max_target_positions = config.max_position_embeddings
        self.vocab_size = config.vocab_size

        self.embed_tokens = Embedding(
            config.vocab_size,
            config.word_embed_proj_dim,
            dtype=config.torch_dtype,
            mapping=model_config.mapping,
            tensor_parallel_mode=TensorParallelMode.COLUMN,
        )
        self.embed_positions = Embedding(
            config.max_position_embeddings + 2,
            config.hidden_size,
            dtype=config.torch_dtype,
        )

        if config.word_embed_proj_dim != config.hidden_size:
            self.project_out = nn.Linear(config.hidden_size,
                                         config.word_embed_proj_dim,
                                         bias=False,
                                         dtype=config.torch_dtype)
        else:
            self.project_out = None

        if config.word_embed_proj_dim != config.hidden_size:
            self.project_in = nn.Linear(config.word_embed_proj_dim,
                                        config.hidden_size,
                                        bias=False,
                                        dtype=config.torch_dtype)
        else:
            self.project_in = None

        # Note that the only purpose of `config._remove_final_layer_norm` is to
        # keep backward compatibility with checkpoints that have been fine-tuned
        # before transformers v4.20.1
        # see https://github.com/facebookresearch/metaseq/pull/164
        if config.do_layer_norm_before and not config._remove_final_layer_norm:
            self.final_layer_norm = LayerNorm(
                config.hidden_size,
                elementwise_affine=config.layer_norm_elementwise_affine,
                dtype=config.torch_dtype)
        else:
            self.final_layer_norm = None

        self.layers = nn.ModuleList([
            OPTDecoderLayer(model_config, layer_idx)
            for layer_idx in range(config.num_hidden_layers)
        ])

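    # Exactly one of `input_ids` or `inputs_embeds` must be given. Token
    # embeddings (optionally passed through project_in) are summed with the
    # +2-offset position embeddings, run through the decoder layers, and then
    # through the optional final layer norm and project_out.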
    def forward(
        self,
        attn_metadata: AttentionMetadata,
        input_ids: Optional[torch.IntTensor] = None,
        position_ids: Optional[torch.IntTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        **kwargs,
    ) -> torch.Tensor:
        if (input_ids is None) ^ (inputs_embeds is not None):
            raise ValueError(
                "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one"
            )

        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)

        if self.project_in is not None:
            inputs_embeds = self.project_in(inputs_embeds)

        pos_embeds = self.embed_positions(position_ids.squeeze(0) + 2)
        hidden_states = inputs_embeds + pos_embeds

        # residual = None
        for decoder_layer in self.layers:
            hidden_states = decoder_layer(
                position_ids=position_ids,
                hidden_states=hidden_states,
                attn_metadata=attn_metadata,
            )

        if self.final_layer_norm is not None:
            hidden_states = self.final_layer_norm(hidden_states)

        if self.project_out is not None:
            hidden_states = self.project_out(hidden_states)

        return hidden_states


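# register_auto_model maps the Hugging Face architecture name "OPTForCausalLM"
# to this class, so the PyTorch-backend model factory can instantiate it
# directly from an OPT checkpoint's config.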
@register_auto_model("OPTForCausalLM")
class OPTForCausalLM(DecoderModelForCausalLM[OPTModel, OPTConfig]):

    def __init__(
        self,
        model_config: ModelConfig[OPTConfig],
    ):
        super().__init__(OPTModel(model_config),
                         config=model_config,
                         hidden_size=model_config.pretrained_config.hidden_size,
                         vocab_size=model_config.pretrained_config.vocab_size)

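    # Loads Hugging Face OPT checkpoint tensors. The checkpoint prefix is
    # either "model.decoder" (OPTForCausalLM exports) or "decoder" (bare
    # OPTModel exports) and is substituted for this module tree's "model"
    # prefix. Separate q_proj / k_proj / v_proj tensors are fused into the
    # single qkv_proj module, out_proj maps onto o_proj, and
    # duplicate_kv_weight replicates K/V weights as needed when the
    # tensor-parallel size exceeds the number of KV heads.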
    def load_weights(self, weights: dict):
        tp_size = self.model_config.mapping.tp_size
        num_kv_heads = self.model_config.pretrained_config.num_attention_heads

        def filter_weights(prefix: str, weights: dict):
            result = {}
            for k, v in weights.items():
                if k.startswith(prefix):
                    new_k = k[len(prefix) + 1:]
                    result[new_k] = v
            return result

        params_map = {
            'qkv_proj': ['q_proj', 'k_proj', 'v_proj'],
            'o_proj': ['out_proj']
        }

        weight_prefix = 'model.decoder'
        if any(name.startswith('decoder') for name, _ in weights.items()):
            weight_prefix = 'decoder'

        for name, module in tqdm(list(self.named_modules()),
                                 desc="Loading weights"):
            if len(module._parameters) > 0:
                # skip load weights if tie word embeddings is enabled and layer is lm_head
                if self.config.tie_word_embeddings and name.startswith(
                        'lm_head'):
                    continue

                if name.startswith('model'):
                    name = name.replace('model', weight_prefix, 1)

                names = name.split('.')
                if names[-1] in params_map:
                    module_weights = []
                    for new_name in params_map[names[-1]]:
                        fw = filter_weights('.'.join(names[:-1] + [new_name]),
                                            weights)
                        if new_name in ['k_proj', 'v_proj']:
                            fw = {
                                k:
                                duplicate_kv_weight(
                                    weight=v[:],
                                    num_kv_heads=num_kv_heads,
                                    tensor_parallel_size=tp_size)
                                if k in ['weight', 'bias'] else v
                                for k, v in fw.items()
                            }
                        module_weights.append(fw)
                    module.load_weights(weights=module_weights)
                else:
                    module_weights = filter_weights(name, weights)
                    if hasattr(module, 'load_weights'):
                        module.load_weights(weights=[module_weights])
                    else:
                        for n, p in module._parameters.items():
                            if p is not None:
                                p.data.copy_(module_weights[n][:])