# tensorrt_llm/models/qwen/model.py
# SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Optional, Union

from tensorrt_llm.lora_manager import LoraConfig, use_lora
from ..._utils import pad_vocab_size
from ...functional import Tensor, recv, send, sigmoid
from ...layers import (MLP, MOE, Attention, AttentionMaskType, ColumnLinear,
Embedding, GatedMLP, RmsNorm, RowLinear)
from ...mapping import Mapping
from ...module import Module
from ...quantization import W8A8_SQ_PLUGIN_LIST, QuantAlgo
from ..modeling_utils import (DecoderLayerList, DecoderModelForCausalLM,
QuantConfig, check_share_embedding)
from .config import QWenConfig
from .convert import (load_hf_qwen, load_weights_from_hf_gptq_model,
load_weights_from_hf_model)


class QWenDecoderLayer(Module):

    def __init__(self, config: QWenConfig, layer_idx: int):
super().__init__()
self.layer_idx = layer_idx
self.config = config
dtype = config.dtype
tp_group = config.mapping.tp_group
tp_size = config.mapping.tp_size
self.input_layernorm = RmsNorm(normalized_shape=config.hidden_size,
eps=config.norm_epsilon,
dtype=dtype)
layers_range = config.mapping.pp_layers(config.num_hidden_layers)
local_layer_idx = layer_idx - layers_range[0]
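        # Attention layers are indexed locally within this pipeline rank's
        # slice of the model, computed from the pp layer range above.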
self.attention = Attention(
local_layer_idx=local_layer_idx,
hidden_size=config.hidden_size,
attention_head_size=config.head_size,
num_attention_heads=config.num_attention_heads,
num_kv_heads=config.num_key_value_heads,
max_position_embeddings=config.max_position_embeddings,
dtype=dtype,
attention_mask_type=AttentionMaskType.causal,
bias=config.attn_bias,
position_embedding_type=config.position_embedding_type,
rotary_embedding_base=config.rotary_base,
rotary_embedding_scaling=config.rotary_scaling,
tp_group=tp_group,
tp_size=tp_size,
quant_mode=config.quant_mode,
dense_bias=False)
ClsMLP = GatedMLP
mlp_kwargs = {}
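        # When MoE is configured, the dense GatedMLP is swapped for a
        # mixture-of-experts block that needs the MoE config and mapping.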
if config.moe.has_moe():
ClsMLP = MOE
mlp_kwargs = {
"moe_config": config.moe,
"mapping": config.mapping,
}
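            # qwen2_moe additionally routes every token through a shared
            # expert; its output is blended in forward() via a sigmoid gate.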
if config.qwen_type == 'qwen2_moe':
self.shared_expert = MLP(
hidden_size=config.hidden_size,
ffn_hidden_size=config.moe_shared_expert_intermediate_size,
hidden_act=config.hidden_act,
dtype=dtype,
bias=False,
tp_group=tp_group,
tp_size=tp_size,
quant_mode=config.quant_mode)
self.shared_expert_gate = RowLinear(config.hidden_size,
1,
bias=False,
dtype=dtype,
tp_group=None,
tp_size=1)
        # Qwen's actual MLP intermediate size depends on qwen_type.
if self.config.qwen_type == 'qwen':
intermediate_size = config.intermediate_size // 2
elif self.config.qwen_type == 'qwen2_moe':
intermediate_size = config.moe_intermediate_size
else:
intermediate_size = config.intermediate_size
self.mlp = ClsMLP(hidden_size=config.hidden_size,
ffn_hidden_size=intermediate_size,
hidden_act=config.hidden_act,
dtype=dtype,
bias=config.mlp_bias,
tp_group=tp_group,
tp_size=tp_size,
quant_mode=config.quant_mode,
**mlp_kwargs)
self.post_layernorm = RmsNorm(normalized_shape=config.hidden_size,
eps=config.norm_epsilon,
dtype=dtype)

    def forward(
self,
hidden_states: Tensor,
attention_mask=None,
use_cache=False,
kv_cache_params=None,
attention_params=None,
lora_layer_params=None,
):
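        # Pre-norm transformer block: normalize, attend, add the residual,
        # then normalize again, run the MLP, and add the second residual.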
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
attention_output = self.attention(
hidden_states,
attention_mask=attention_mask,
use_cache=use_cache,
kv_cache_params=kv_cache_params,
attention_params=attention_params,
lora_layer_params=lora_layer_params,
)
if use_cache:
attention_output, presents = attention_output
hidden_states = residual + attention_output
residual = hidden_states
hidden_states = self.post_layernorm(hidden_states)
shared_output = None
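        # For qwen2_moe, compute the shared-expert branch and scale it by a
        # sigmoid gate before merging it with the routed-expert output.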
if self.config.qwen_type == 'qwen2_moe':
shared_output = self.shared_expert(hidden_states)
if self.shared_expert_gate is not None:
shared_output = sigmoid(
self.shared_expert_gate(hidden_states)) * shared_output
hidden_states = self.mlp(hidden_states,
lora_layer_params=lora_layer_params)
if shared_output is not None:
hidden_states = hidden_states + shared_output
hidden_states = residual + hidden_states
if use_cache:
return (hidden_states, presents)
return hidden_states


class QWenModel(Module):

    def __init__(self, config: QWenConfig) -> None:
super().__init__()
self.mapping = config.mapping
if self.mapping.is_first_pp_rank():
self.vocab_embedding = Embedding(config.vocab_size,
config.hidden_size,
dtype=config.dtype)
self.layers = DecoderLayerList(QWenDecoderLayer, config)
if self.mapping.is_last_pp_rank():
self.ln_f = RmsNorm(normalized_shape=config.hidden_size,
eps=config.norm_epsilon,
dtype=config.dtype)

    def forward(self,
input_ids: Tensor,
position_ids=None,
use_cache=False,
attention_mask=None,
kv_cache_params=None,
attention_params=None,
hidden_states=None,
prompt_embedding_table: Optional[Tensor] = None,
prompt_tasks: Optional[Tensor] = None,
prompt_vocab_size: Optional[Tensor] = None,
lora_params=None):
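        # Prompt-tuning (p-tuning) tensors are forwarded to the embedding
        # lookup only when a prompt embedding table is provided.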
ptuning_args = [
prompt_embedding_table, prompt_tasks, prompt_vocab_size
] if prompt_embedding_table is not None else []
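        # The first pipeline rank embeds the tokens; later ranks receive
        # activations from the previous pipeline stage instead.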
if self.mapping.is_first_pp_rank():
hidden_states = self.vocab_embedding(input_ids, *ptuning_args)
else:
hidden_states = recv(hidden_states, self.mapping.prev_pp_rank())
hidden_states = self.layers.forward(hidden_states,
use_cache=use_cache,
attention_mask=attention_mask,
kv_cache_params=kv_cache_params,
attention_params=attention_params,
lora_params=lora_params)
if use_cache:
hidden_states, presents = hidden_states
if self.mapping.is_last_pp_rank():
hidden_states = self.ln_f(hidden_states)
else:
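            # Non-final ranks forward their activations to the next stage.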
hidden_states = send(hidden_states, self.mapping.next_pp_rank())
if use_cache:
return (hidden_states, tuple(presents))
return hidden_states


class QWenForCausalLM(DecoderModelForCausalLM):
    config_class = QWenConfig

    def __init__(self, config: QWenConfig):
transformer = QWenModel(config)
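        # Pad the vocabulary so the LM head weight splits evenly across
        # tensor-parallel ranks.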
vocab_size_padded = pad_vocab_size(config.vocab_size,
config.mapping.tp_size)
if config.mapping.is_last_pp_rank():
lm_head = ColumnLinear(config.hidden_size,
vocab_size_padded,
bias=False,
dtype=config.dtype,
tp_group=config.mapping.tp_group,
tp_size=config.mapping.tp_size,
gather_output=True)
else:
lm_head = None
self.quant_mode = config.quant_mode
self.mapping = config.mapping
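        # Qwen (v1) uses non-standard HF module names, so LoRA loading needs
        # an explicit TRT-LLM -> HF module-name mapping; other Qwen variants
        # fall back to the default mapping.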
if config.qwen_type == 'qwen':
self.trtllm_modules_to_hf_modules = {
"attn_qkv": "c_attn",
"attn_dense": "attn.c_proj",
"mlp_h_to_4h": "w2",
"mlp_4h_to_h": "mlp.c_proj",
"mlp_gate": "w1",
}
else:
self.trtllm_modules_to_hf_modules = None
super().__init__(config, transformer, lm_head)

    @classmethod
def from_hugging_face(
cls,
hf_model_or_dir: Union[str, 'transformers.PreTrainedModel'],
dtype: str = 'auto',
mapping: Optional[Mapping] = None,
quant_config: Optional[QuantConfig] = None,
use_hf_gptq_checkpoint=False,
**kwargs):
        ''' Create a QWenForCausalLM object from the given parameters.
        '''
import transformers
load_model_on_cpu = kwargs.pop('load_model_on_cpu', False)
assert hf_model_or_dir is not None
use_preloading = isinstance(hf_model_or_dir,
transformers.PreTrainedModel)
if use_preloading:
hf_model = hf_model_or_dir
hf_config_or_dir = hf_model.config
else:
hf_model_dir = hf_model_or_dir
hf_config_or_dir = hf_model_or_dir
config = QWenConfig.from_hugging_face(hf_config_or_dir,
dtype=dtype,
mapping=mapping,
quant_config=quant_config,
**kwargs)
if not use_preloading:
hf_model = load_hf_qwen(hf_model_dir, load_model_on_cpu)
if use_hf_gptq_checkpoint:
weights = load_weights_from_hf_gptq_model(hf_model, config)
else:
weights = load_weights_from_hf_model(hf_model, config)
check_share_embedding(weights, config)
model = QWenForCausalLM(config)
model.load(weights)
return model

    def default_plugin_config(self, **kwargs):
plugin_config = super().default_plugin_config(**kwargs)
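        # Groupwise INT4 weight-only checkpoints (e.g. GPTQ/AWQ) require the
        # dedicated weight-only groupwise matmul plugin.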
if self.quant_mode.is_int4_weight_only_per_group():
plugin_config.weight_only_groupwise_quant_matmul_plugin = 'auto'
return plugin_config

    @classmethod
def quantize(
cls,
hf_model_dir: str,
output_dir: str,
dtype: str = 'auto',
mapping: Optional[Mapping] = None,
quant_config: Optional[QuantConfig] = None,
*,
calib_dataset='cnn_dailymail',
calib_batches=512,
calib_batch_size=1,
calib_max_seq_length=512,
random_seed=1234,
tokenizer_max_seq_length=2048,
**kwargs,
):
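        # These algorithms are quantized through NVIDIA Modelopt; anything
        # else falls through to the native TRT-LLM quantization path below.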
DEFAULT_MODELOPT_FLOW = [
QuantAlgo.W4A16_AWQ, QuantAlgo.FP8, QuantAlgo.W8A8_SQ_PER_CHANNEL,
QuantAlgo.W4A8_AWQ
]
config = QWenConfig.from_hugging_face(hf_model_dir,
dtype=dtype,
mapping=mapping,
quant_config=quant_config,
**kwargs)
if quant_config.quant_algo in DEFAULT_MODELOPT_FLOW:
super().quantize(hf_model_dir,
output_dir,
dtype=config.dtype,
mapping=config.mapping,
quant_config=config.quantization,
calib_dataset=calib_dataset,
calib_batches=calib_batches,
calib_batch_size=calib_batch_size,
calib_max_seq_length=calib_max_seq_length,
random_seed=random_seed,
tokenizer_max_seq_length=tokenizer_max_seq_length)
else:
            # Non-Modelopt path: the legacy TRT-LLM native quantization
            # algorithms: SmoothQuant, INT4/INT8 weight-only, INT8 KV cache.
NATIVE_QUANT_FLOW = [QuantAlgo.W4A16, QuantAlgo.W8A16, None
] + W8A8_SQ_PLUGIN_LIST
is_valid_native_quant = (quant_config.quant_algo in NATIVE_QUANT_FLOW) and \
(quant_config.kv_cache_quant_algo in [QuantAlgo.INT8, None])
            assert quant_config.quant_algo is not None or quant_config.kv_cache_quant_algo is not None, \
                "There is no point in calling quantize() when both quant_algo and kv_cache_quant_algo are None"
            assert is_valid_native_quant, f"Internal error: Modelopt should be used for quantization config {quant_config}"
from . import convert
convert.quantize(hf_model_dir,
output_dir,
config=config,
calib_dataset=calib_dataset)

    def use_lora(self, lora_config: LoraConfig):
use_lora(self, lora_config, self.trtllm_modules_to_hf_modules)
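

# Usage sketch (illustrative only, not part of the upstream module; the
# checkpoint path below is an assumption):
#
#   from tensorrt_llm.mapping import Mapping
#
#   model = QWenForCausalLM.from_hugging_face(
#       '/path/to/qwen/checkpoint',  # HF model directory or preloaded model
#       dtype='float16',
#       mapping=Mapping(world_size=1, tp_size=1, pp_size=1))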