mirror of
https://github.com/NVIDIA/TensorRT-LLM.git
synced 2026-01-14 06:27:45 +08:00
256 lines
10 KiB
Python
256 lines
10 KiB
Python
# SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
# SPDX-License-Identifier: Apache-2.0
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
from typing import Optional, Union
|
|
|
|
from ..._utils import pad_vocab_size
|
|
from ...functional import Tensor
|
|
from ...layers import (Attention, AttentionMaskType, ColumnLinear, Embedding,
|
|
GatedMLP, RmsNorm)
|
|
from ...mapping import Mapping
|
|
from ...module import Module
|
|
from ..modeling_utils import (DecoderLayerList, DecoderModelForCausalLM,
|
|
PretrainedConfig, QuantConfig,
|
|
check_share_embedding)
|
|
from .config import BaichuanConfig
|
|
from .convert import load_weights_from_hf_model
|
|
|
|
|
|
class BaichuanDecoderLayer(Module):
    """A single Baichuan decoder layer.

    Pre-norm residual layout:
        x = x + Attention(RmsNorm(x))
        x = x + GatedMLP(RmsNorm(x))
    Attention and MLP are configured with the tensor-parallel group, size
    (and, for attention, rank) taken from ``config.mapping``.
    """

    def __init__(self, config: PretrainedConfig, layer_idx):
        super().__init__()
        self.layer_idx = layer_idx
        self.config = config

        mapping = config.mapping
        width = config.hidden_size
        precision = config.dtype
        quantization = config.quant_mode

        self.input_layernorm = RmsNorm(normalized_shape=width,
                                       eps=config.norm_epsilon,
                                       dtype=precision)

        # Attention receives this layer's index relative to the first layer
        # owned by the current pipeline stage, not its global index.
        stage_first_layer = mapping.pp_layers(config.num_hidden_layers)[0]
        self.attention = Attention(
            local_layer_idx=layer_idx - stage_first_layer,
            hidden_size=width,
            num_attention_heads=config.num_attention_heads,
            num_kv_heads=config.num_key_value_heads,
            max_position_embeddings=config.max_position_embeddings,
            dtype=precision,
            attention_mask_type=AttentionMaskType.causal,
            bias=False,
            position_embedding_type=config.position_embedding_type,
            tp_group=mapping.tp_group,
            tp_size=mapping.tp_size,
            tp_rank=mapping.tp_rank,
            quant_mode=quantization)

        self.mlp = GatedMLP(hidden_size=width,
                            ffn_hidden_size=config.intermediate_size,
                            hidden_act=config.hidden_act,
                            dtype=precision,
                            bias=False,
                            tp_group=mapping.tp_group,
                            tp_size=mapping.tp_size,
                            quant_mode=quantization)
        self.post_layernorm = RmsNorm(normalized_shape=width,
                                      eps=config.norm_epsilon,
                                      dtype=precision)

    def forward(self,
                hidden_states: Tensor,
                attention_mask=None,
                use_cache=False,
                kv_cache_params=None,
                attention_params=None):
        """Apply the layer to ``hidden_states``.

        Returns the new hidden states. When ``use_cache`` is true, the
        attention module's output is unpacked as ``(output, presents)`` and
        this method returns ``(hidden_states, presents)`` instead.
        """
        attn_result = self.attention(self.input_layernorm(hidden_states),
                                     attention_mask=attention_mask,
                                     use_cache=use_cache,
                                     kv_cache_params=kv_cache_params,
                                     attention_params=attention_params)
        presents = None
        if use_cache:
            attn_result, presents = attn_result

        # First residual connection around the attention sub-block.
        after_attention = hidden_states + attn_result
        # Second residual connection around the MLP sub-block.
        layer_output = after_attention + self.mlp(
            self.post_layernorm(after_attention))

        return (layer_output, presents) if use_cache else layer_output
|
|
|
|
|
|
class BaichuanModel(Module):
    """Baichuan decoder stack: token embedding, decoder layers, final RmsNorm."""

    def __init__(self, config: PretrainedConfig):
        super().__init__()
        self.vocab_embedding = Embedding(config.vocab_size,
                                         config.hidden_size,
                                         dtype=config.dtype)

        self.layers = DecoderLayerList(BaichuanDecoderLayer, config)
        self.ln_f = RmsNorm(normalized_shape=config.hidden_size,
                            eps=config.norm_epsilon,
                            dtype=config.dtype)

    def forward(self,
                input_ids: Tensor,
                position_ids=None,
                use_cache=False,
                attention_mask=None,
                kv_cache_params=None,
                attention_params=None,
                prompt_embedding_table=None,
                prompt_tasks=None,
                prompt_vocab_size=None):
        """Embed ``input_ids``, run all decoder layers, and apply ``ln_f``.

        ``position_ids`` is accepted for interface compatibility but is not
        used by this method. The prompt-tuning arguments are forwarded to the
        embedding only when ``prompt_embedding_table`` is given.

        Returns the normalized hidden states, or
        ``(hidden_states, tuple(presents))`` when ``use_cache`` is true.
        """
        if prompt_embedding_table is None:
            prompt_args = []
        else:
            prompt_args = [
                prompt_embedding_table, prompt_tasks, prompt_vocab_size
            ]
        embedded = self.vocab_embedding(input_ids, *prompt_args)

        stack_output = self.layers(embedded,
                                   use_cache=use_cache,
                                   attention_mask=attention_mask,
                                   kv_cache_params=kv_cache_params,
                                   attention_params=attention_params)

        if not use_cache:
            return self.ln_f(stack_output)

        final_hidden, presents = stack_output
        return (self.ln_f(final_hidden), tuple(presents))
|
|
|
|
|
|
class BaichuanForCausalLM(DecoderModelForCausalLM):
    """Baichuan causal language model: ``BaichuanModel`` plus an LM head."""
    config_class = BaichuanConfig

    def __init__(self, config: PretrainedConfig):
        transformer = BaichuanModel(config)
        # pad_vocab_size presumably rounds vocab_size up so the column-parallel
        # LM head splits evenly over tp_size ranks — TODO confirm.
        vocab_size_padded = pad_vocab_size(config.vocab_size,
                                           config.mapping.tp_size)
        lm_head = ColumnLinear(config.hidden_size,
                               vocab_size_padded,
                               bias=False,
                               dtype=config.dtype,
                               tp_group=config.mapping.tp_group,
                               tp_size=config.mapping.tp_size,
                               gather_output=True)
        super().__init__(config, transformer, lm_head)

    @classmethod
    def from_hugging_face(
            cls,
            hf_model_or_dir: Union[str, 'transformers.PreTrainedModel'],
            dtype: str = 'auto',
            mapping: Optional[Mapping] = None,
            quant_config: Optional[QuantConfig] = None,
            **kwargs):
        '''Create a BaichuanForCausalLM object from the given parameters.

        Args:
            hf_model_or_dir: either an already loaded
                ``transformers.PreTrainedModel`` or a path / model id that is
                loaded with ``transformers.AutoModelForCausalLM``.
            dtype, mapping, quant_config: forwarded to
                ``BaichuanConfig.from_hugging_face``.
            **kwargs: forwarded to ``BaichuanConfig.from_hugging_face``.
                ``trust_remote_code`` (default True) is additionally used when
                loading the Hugging Face model from a directory.

        Returns:
            A BaichuanForCausalLM with weights converted from the HF model.
        '''
        import transformers

        assert hf_model_or_dir is not None
        if isinstance(hf_model_or_dir, transformers.PreTrainedModel):
            hf_model = hf_model_or_dir
            hf_config_or_dir = hf_model.config
        else:
            # Read (not pop) the flag so that the config loader below sees the
            # same trust_remote_code value as the model loader, matching the
            # pre-loaded-model branch where kwargs are forwarded untouched.
            trust_remote_code = kwargs.get('trust_remote_code', True)

            hf_model = transformers.AutoModelForCausalLM.from_pretrained(
                hf_model_or_dir,
                trust_remote_code=trust_remote_code,
                torch_dtype='auto')
            hf_config_or_dir = hf_model_or_dir

        config = BaichuanConfig.from_hugging_face(hf_config_or_dir,
                                                  dtype=dtype,
                                                  mapping=mapping,
                                                  quant_config=quant_config,
                                                  **kwargs)

        weights = load_weights_from_hf_model(hf_model, config)

        check_share_embedding(weights, config)
        model = cls(config)
        model.load(weights)
        return model

    @classmethod
    def quantize(
        cls,
        hf_model_dir: str,
        output_dir: str,
        dtype: str = 'auto',
        mapping: Optional[Mapping] = None,
        quant_config: Optional[QuantConfig] = None,
        *,
        device: str = 'cuda',
        calib_dataset: str = 'cnn_dailymail',
        calib_batches: int = 512,
        calib_batch_size: int = 1,
        calib_max_seq_length: int = 512,
        random_seed: int = 1234,
        tokenizer_max_seq_length: int = 2048,
        **kwargs,
    ):
        '''Quantize the HF checkpoint in ``hf_model_dir`` into ``output_dir``.

        Configs that require ModelOpt quantization are delegated to the base
        class; this flow uses all calibration arguments but ignores
        ``kwargs``. Other configs that require calibration use
        ``.convert.quantize``, which only receives ``device`` and
        ``calib_dataset``; ``kwargs`` are forwarded to
        ``BaichuanConfig.from_hugging_face`` in that flow.

        Raises:
            ValueError: if ``quant_config`` is None or does not require
                calibration (use ``from_hugging_face`` instead).
        '''
        # Guard against the declared default of None, which previously raised
        # AttributeError instead of the intended ValueError below.
        if (quant_config is not None
                and quant_config.requires_modelopt_quantization):
            # modelopt quantization flow
            super().quantize(hf_model_dir,
                             output_dir,
                             dtype=dtype,
                             mapping=mapping,
                             quant_config=quant_config,
                             device=device,
                             calib_dataset=calib_dataset,
                             calib_batches=calib_batches,
                             calib_batch_size=calib_batch_size,
                             calib_max_seq_length=calib_max_seq_length,
                             random_seed=random_seed,
                             tokenizer_max_seq_length=tokenizer_max_seq_length)
        elif quant_config is not None and quant_config.requires_calibration:
            # non-modelopt quantization flow
            from .convert import quantize

            config = BaichuanConfig.from_hugging_face(hf_model_dir,
                                                      dtype=dtype,
                                                      mapping=mapping,
                                                      quant_config=quant_config,
                                                      **kwargs)
            quantize(hf_model_dir,
                     output_dir,
                     config=config,
                     device=device,
                     calib_dataset=calib_dataset)
        else:
            raise ValueError(
                f"The quant_config ({quant_config}) does not require calibration, try {cls.__name__}.from_hugging_face instead."
            )
|