# TensorRT-LLM/examples/llm-api/out_of_tree_example/modeling_opt.py
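"""Out-of-tree model example: OPT on TensorRT-LLM's PyTorch backend.

The classes below implement the OPT architecture with TensorRT-LLM's PyTorch
modules and register it via `register_auto_model`, so the LLM API can load
Hugging Face OPT checkpoints without any changes inside TensorRT-LLM itself.
"""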

from typing import Optional
import torch
from torch import nn
from tqdm import tqdm
from transformers import OPTConfig
from transformers.activations import ACT2FN
from tensorrt_llm._torch.attention_backend import AttentionMetadata
from tensorrt_llm._torch.model_config import ModelConfig
from tensorrt_llm._torch.models.modeling_utils import (DecoderModel,
DecoderModelForCausalLM,
duplicate_kv_weight,
register_auto_model)
from tensorrt_llm._torch.modules.attention import Attention
from tensorrt_llm._torch.modules.decoder_layer import DecoderLayer
from tensorrt_llm._torch.modules.embedding import Embedding
from tensorrt_llm._torch.modules.linear import Linear, TensorParallelMode


class LayerNorm(nn.LayerNorm):

def reset_parameters(self) -> None:
# Skip the initialization operations that conflict with MetaInitMode
pass


class OPTAttention(Attention):

def __init__(
self,
model_config: ModelConfig[OPTConfig],
layer_idx: Optional[int] = None,
):
config = model_config.pretrained_config
super().__init__(
hidden_size=config.hidden_size,
num_attention_heads=config.num_attention_heads,
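            # OPT uses plain multi-head attention, so the number of KV heads
            # equals the number of attention heads (no grouped-query attention).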
num_key_value_heads=config.num_attention_heads,
max_position_embeddings=config.max_position_embeddings,
bias=config.enable_bias,
layer_idx=layer_idx,
dtype=config.torch_dtype,
config=model_config,
)


class OPTDecoderLayer(DecoderLayer):

def __init__(
self,
model_config: ModelConfig[OPTConfig],
layer_idx: int,
):
super().__init__()
config = model_config.pretrained_config
self.self_attn = OPTAttention(model_config, layer_idx=layer_idx)
self.do_layer_norm_before = config.do_layer_norm_before
self.activation_fn = ACT2FN[config.activation_function]
self.self_attn_layer_norm = LayerNorm(
config.hidden_size,
elementwise_affine=config.layer_norm_elementwise_affine,
dtype=config.torch_dtype)
self.fc1 = Linear(config.hidden_size,
config.ffn_dim,
bias=config.enable_bias,
dtype=config.torch_dtype,
mapping=model_config.mapping,
tensor_parallel_mode=TensorParallelMode.COLUMN,
quant_config=model_config.get_quant_config(),
allreduce_strategy=model_config.allreduce_strategy)
self.fc2 = Linear(config.ffn_dim,
config.hidden_size,
bias=config.enable_bias,
dtype=config.torch_dtype,
mapping=model_config.mapping,
tensor_parallel_mode=TensorParallelMode.ROW,
quant_config=model_config.get_quant_config(),
allreduce_strategy=model_config.allreduce_strategy)
self.final_layer_norm = LayerNorm(
config.hidden_size,
elementwise_affine=config.layer_norm_elementwise_affine,
dtype=config.torch_dtype)

    def forward(
self,
position_ids: torch.IntTensor,
hidden_states: torch.Tensor,
attn_metadata: AttentionMetadata,
**kwargs,
) -> torch.Tensor:
# Self Attention
residual = hidden_states
# 125m, 1.7B, ..., 175B applies layer norm BEFORE attention
if self.do_layer_norm_before:
hidden_states = self.self_attn_layer_norm(hidden_states)
hidden_states = self.self_attn(
position_ids=None,
hidden_states=hidden_states,
attn_metadata=attn_metadata,
**kwargs,
)
hidden_states = residual + hidden_states
# 350m applies layer norm AFTER attention
if not self.do_layer_norm_before:
hidden_states = self.self_attn_layer_norm(hidden_states)
# Fully Connected
residual = hidden_states
        # 125m, 1.7B, ..., 175B applies layer norm BEFORE the feed-forward block
if self.do_layer_norm_before:
hidden_states = self.final_layer_norm(hidden_states)
hidden_states = self.fc1(hidden_states)
hidden_states = self.activation_fn(hidden_states)
hidden_states = self.fc2(hidden_states)
hidden_states = residual + hidden_states
        # 350m applies layer norm AFTER the feed-forward block
if not self.do_layer_norm_before:
hidden_states = self.final_layer_norm(hidden_states)
return hidden_states


class OPTModel(DecoderModel):

def __init__(self, model_config: ModelConfig[OPTConfig]):
super().__init__(model_config)
config = model_config.pretrained_config
self.padding_idx = config.pad_token_id
self.max_target_positions = config.max_position_embeddings
self.vocab_size = config.vocab_size
self.embed_tokens = Embedding(
config.vocab_size,
config.word_embed_proj_dim,
dtype=config.torch_dtype,
mapping=model_config.mapping,
tensor_parallel_mode=TensorParallelMode.COLUMN,
)
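        # OPT's learned positional embeddings are offset by 2 (a fairseq
        # legacy), hence the extra two rows in the embedding table.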
self.embed_positions = Embedding(
config.max_position_embeddings + 2,
config.hidden_size,
dtype=config.torch_dtype,
)
if config.word_embed_proj_dim != config.hidden_size:
self.project_out = nn.Linear(config.hidden_size,
config.word_embed_proj_dim,
bias=False,
dtype=config.torch_dtype)
else:
self.project_out = None
if config.word_embed_proj_dim != config.hidden_size:
self.project_in = nn.Linear(config.word_embed_proj_dim,
config.hidden_size,
bias=False,
dtype=config.torch_dtype)
else:
self.project_in = None
# Note that the only purpose of `config._remove_final_layer_norm` is to
# keep backward compatibility with checkpoints that have been fine-tuned
# before transformers v4.20.1
# see https://github.com/facebookresearch/metaseq/pull/164
if config.do_layer_norm_before and not config._remove_final_layer_norm:
self.final_layer_norm = LayerNorm(
config.hidden_size,
elementwise_affine=config.layer_norm_elementwise_affine,
dtype=config.torch_dtype)
else:
self.final_layer_norm = None
self.layers = nn.ModuleList([
OPTDecoderLayer(model_config, layer_idx)
for layer_idx in range(config.num_hidden_layers)
])

    def forward(
self,
attn_metadata: AttentionMetadata,
input_ids: Optional[torch.IntTensor] = None,
position_ids: Optional[torch.IntTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
**kwargs,
) -> torch.Tensor:
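        # Exactly one of input_ids / inputs_embeds must be given; the XOR
        # below is True only when both or neither are provided.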
if (input_ids is None) ^ (inputs_embeds is not None):
            raise ValueError(
                "You must specify exactly one of input_ids or inputs_embeds")
if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids)
if self.project_in is not None:
inputs_embeds = self.project_in(inputs_embeds)
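        # Apply the fairseq position offset of 2 before the table lookup.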
pos_embeds = self.embed_positions(position_ids.squeeze(0) + 2)
hidden_states = inputs_embeds + pos_embeds
for decoder_layer in self.layers:
hidden_states = decoder_layer(
position_ids=position_ids,
hidden_states=hidden_states,
attn_metadata=attn_metadata,
)
if self.final_layer_norm is not None:
hidden_states = self.final_layer_norm(hidden_states)
if self.project_out is not None:
hidden_states = self.project_out(hidden_states)
return hidden_states


@register_auto_model("OPTForCausalLM")
class OPTForCausalLM(DecoderModelForCausalLM[OPTModel, OPTConfig]):
def __init__(
self,
model_config: ModelConfig[OPTConfig],
):
super().__init__(OPTModel(model_config),
config=model_config,
hidden_size=model_config.pretrained_config.hidden_size,
vocab_size=model_config.pretrained_config.vocab_size)

    def load_weights(self, weights: dict):
tp_size = self.model_config.mapping.tp_size
num_kv_heads = self.model_config.pretrained_config.num_attention_heads
def filter_weights(prefix: str, weights: dict):
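            # Return the checkpoint tensors under `prefix`, with the prefix
            # (and the following dot) stripped from their keys.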
result = {}
for k, v in weights.items():
if k.startswith(prefix):
new_k = k[len(prefix) + 1:]
result[new_k] = v
return result
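        # Fused modules in this implementation gather several HF checkpoint
        # tensors: qkv_proj packs q/k/v, and o_proj is HF's out_proj.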
params_map = {
'qkv_proj': ['q_proj', 'k_proj', 'v_proj'],
'o_proj': ['out_proj']
}
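        # Some OPT checkpoints store weights under 'decoder.*' instead of
        # 'model.decoder.*'; detect which prefix this checkpoint uses.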
weight_prefix = 'model.decoder'
        if any(name.startswith('decoder') for name in weights):
            weight_prefix = 'decoder'
for name, module in tqdm(list(self.named_modules()),
desc="Loading weights"):
if len(module._parameters) > 0:
                # With tied word embeddings, lm_head shares its weights with
                # embed_tokens, so skip loading it separately.
                if self.config.tie_word_embeddings and name.startswith(
                        'lm_head'):
                    continue
if name.startswith('model'):
name = name.replace('model', weight_prefix, 1)
names = name.split('.')
if names[-1] in params_map:
module_weights = []
for new_name in params_map[names[-1]]:
fw = filter_weights('.'.join(names[:-1] + [new_name]),
weights)
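                        # duplicate_kv_weight replicates K/V heads when tp_size
                        # exceeds the KV head count, so each rank still gets a
                        # full shard (rarely needed for OPT, which uses MHA).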
if new_name in ['k_proj', 'v_proj']:
fw = {
k:
duplicate_kv_weight(
weight=v[:],
num_kv_heads=num_kv_heads,
tensor_parallel_size=tp_size)
if k in ['weight', 'bias'] else v
for k, v in fw.items()
}
module_weights.append(fw)
module.load_weights(weights=module_weights)
else:
module_weights = filter_weights(name, weights)
if hasattr(module, 'load_weights'):
module.load_weights(weights=[module_weights])
else:
for n, p in module._parameters.items():
if p is not None:
p.data.copy_(module_weights[n][:])
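

# A minimal usage sketch (illustrative, assuming this file is importable as
# `modeling_opt` and an OPT checkpoint such as facebook/opt-125m is available):
# importing the module runs `register_auto_model`, after which the LLM API
# resolves "OPTForCausalLM" checkpoints to this implementation.
#
#     from tensorrt_llm import LLM
#     import modeling_opt  # noqa: F401  (import registers the model)
#
#     llm = LLM(model="facebook/opt-125m")
#     for output in llm.generate(["Hello, my name is"]):
#         print(output.outputs[0].text)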