# SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import asdict
from typing import Any, Dict, List, Optional, Union

import torch

from tensorrt_llm._utils import torch_dtype_to_str
from tensorrt_llm.functional import PositionEmbeddingType
from tensorrt_llm.logger import logger
from tensorrt_llm.mapping import Mapping
from tensorrt_llm.models.deci.convert import hf_block_config_to_layer_config
from tensorrt_llm.models.deci.layer_config import (AttentionConfig,
                                                   AttentionImplementation,
                                                   DeciLayerConfig, FFNConfig)
from tensorrt_llm.models.modeling_utils import PretrainedConfig, QuantConfig


class DeciConfig(PretrainedConfig):
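    """Model configuration for DeciLM-style checkpoints.

    Extends `PretrainedConfig` with an optional per-layer `layer_configs`
    list of `DeciLayerConfig`, so attention and FFN settings (e.g. the number
    of KV heads or the FFN intermediate size) may differ from layer to layer.
    """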

    def __init__(self,
                 *,
                 architecture: str = 'DeciLMForCausalLM',
                 dtype: str,
                 hidden_size: int,
                 num_hidden_layers: int,
                 num_attention_heads: int,
                 vocab_size: int,
                 hidden_act: str = 'gelu',
                 logits_dtype: str = 'float32',
                 norm_epsilon: float = 0.00001,
                 position_embedding_type: Union[
                     PositionEmbeddingType,
                     str] = PositionEmbeddingType.rope_gpt_neox,
                 rotary_base: float = 10000.0,
                 rotary_scaling: Optional[dict] = None,
                 max_position_embeddings: int,
                 num_key_value_heads: Optional[int] = None,
                 intermediate_size: Optional[int] = None,
                 mapping: Optional[Union[Mapping, dict]] = None,
                 quantization: Optional[Union[QuantConfig, dict]] = None,
                 use_parallel_embedding: bool = False,
                 embedding_sharding_dim: int = 0,
                 share_embedding_table: bool = False,
                 head_size: Optional[int] = None,
                 qk_layernorm: bool = False,
                 layer_configs: Optional[List[Union[DeciLayerConfig,
                                                    Dict[str,
                                                         Dict[str,
                                                              Any]]]]] = None,
                 **kwargs):
        super().__init__(architecture=architecture,
                         dtype=dtype,
                         hidden_size=hidden_size,
                         num_hidden_layers=num_hidden_layers,
                         num_attention_heads=num_attention_heads,
                         vocab_size=vocab_size,
                         hidden_act=hidden_act,
                         logits_dtype=logits_dtype,
                         norm_epsilon=norm_epsilon,
                         position_embedding_type=position_embedding_type,
                         max_position_embeddings=max_position_embeddings,
                         num_key_value_heads=num_key_value_heads,
                         intermediate_size=intermediate_size,
                         mapping=mapping,
                         quantization=quantization,
                         use_parallel_embedding=use_parallel_embedding,
                         embedding_sharding_dim=embedding_sharding_dim,
                         share_embedding_table=share_embedding_table,
                         head_size=head_size,
                         qk_layernorm=qk_layernorm,
                         **kwargs)

        self.rotary_base = rotary_base
        self.rotary_scaling = rotary_scaling

        if layer_configs is not None:
            assert len(
                layer_configs
            ) == num_hidden_layers, f"num_hidden_layers ({num_hidden_layers}) must match len(layer_configs) ({len(layer_configs)})"

            self.layer_configs = self._ensure_layer_configs(layer_configs)
        else:
            self.layer_configs = None

        # HACK: several parts of the codebase expect a per-layer `layer_types`
        # list, so derive it from each layer's attention implementation.
        self.layer_types = [
            AttentionImplementation(
                self.get_layer_config(layer_idx).attention.impl).value
            for layer_idx in range(self.num_hidden_layers)
        ]

    def _ensure_layer_configs(
            self, layer_configs: List[Union[DeciLayerConfig, Dict[str, Any]]]
    ) -> List[DeciLayerConfig]:
        return [
            DeciLayerConfig.from_dict(c) if isinstance(c, dict) else c
            for c in layer_configs
        ]

    def to_dict(self):
        output = super().to_dict()
        if self.layer_configs is not None:
            output["layer_configs"] = [asdict(c) for c in self.layer_configs]
        return output

    def get_layer_config(self, layer_idx: int) -> DeciLayerConfig:
        if self.layer_configs is not None:
            conf = self.layer_configs[layer_idx]
        else:
            conf = DeciLayerConfig()

        attention_impl = conf.attention.impl
        num_key_value_heads = conf.attention.num_key_value_heads or self.num_key_value_heads
        ffn_impl = conf.ffn.impl
        intermediate_size = conf.ffn.intermediate_size or self.intermediate_size

        return DeciLayerConfig(
            attention=AttentionConfig(impl=attention_impl,
                                      num_key_value_heads=num_key_value_heads),
            ffn=FFNConfig(impl=ffn_impl, intermediate_size=intermediate_size))

    def get_layer_num_kv_heads(self, layer_idx) -> int:
        layer_config = self.get_layer_config(layer_idx)
        assert layer_config.is_attention_layer, f"Layer {layer_idx} is not an attention layer"
        return layer_config.attention.num_key_value_heads or self.num_key_value_heads

    @classmethod
    def from_hugging_face(
            cls,
            hf_config_or_dir: Union[str, 'transformers.PretrainedConfig'],
            dtype: str = 'auto',
            mapping: Optional[Mapping] = None,
            quant_config: Optional[QuantConfig] = None,
            trust_remote_code: bool = False,
            **kwargs):
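        """Build a `DeciConfig` from a HuggingFace config object or checkpoint
        directory.

        Only HF configs with `model_type == "deci"` are accepted. Both the
        newer `block_configs` layout and the older
        `num_key_value_heads_per_layer` layout are supported, and
        `dtype='auto'` is resolved from the HF config's `torch_dtype`.
        """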
        import transformers

        if isinstance(hf_config_or_dir, transformers.PretrainedConfig):
            hf_config = hf_config_or_dir
        else:
            hf_config = transformers.AutoConfig.from_pretrained(
                hf_config_or_dir, trust_remote_code=trust_remote_code)

        assert hf_config.model_type == "deci", f"Unsupported model type: {hf_config.model_type}"

        block_configs = getattr(hf_config, "block_configs", None)
        if block_configs is not None:
            layer_configs = [
                hf_block_config_to_layer_config(block_config,
                                                hf_config.num_attention_heads,
                                                hf_config.hidden_size)
                for block_config in block_configs
            ]
        else:
            # older deci arch
            num_key_value_heads_per_layer = getattr(
                hf_config, "num_key_value_heads_per_layer", None)
            if num_key_value_heads_per_layer is not None:
                layer_configs = [
                    DeciLayerConfig(attention=AttentionConfig(
                        num_key_value_heads=num_key_value_heads))
                    for num_key_value_heads in num_key_value_heads_per_layer
                ]
            else:
                layer_configs = None

        if dtype == 'auto':
            dtype = getattr(hf_config, 'torch_dtype', "float16")
            if isinstance(dtype, torch.dtype):
                dtype = torch_dtype_to_str(dtype)
            if dtype == 'float32':
                dtype = 'float16'
        if dtype == 'bfloat16' and torch.cuda.get_device_properties(
                0).major < 8:
            logger.warning(
                "Pre SM 80 GPUs do not support bfloat16, fallback to float16")
            dtype = 'float16'

        return cls(dtype=dtype,
                   hidden_size=hf_config.hidden_size,
                   hidden_act=hf_config.hidden_act,
                   intermediate_size=hf_config.intermediate_size,
                   num_attention_heads=hf_config.num_attention_heads,
                   num_hidden_layers=hf_config.num_hidden_layers,
                   num_key_value_heads=hf_config.num_key_value_heads,
                   norm_epsilon=hf_config.rms_norm_eps,
                   rotary_scaling=hf_config.rope_scaling,
                   rotary_base=hf_config.rope_theta,
                   vocab_size=hf_config.vocab_size,
                   max_position_embeddings=hf_config.max_position_embeddings,
                   mapping=mapping,
                   quantization=quant_config,
                   layer_configs=layer_configs,
                   **kwargs)
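

# Illustrative usage sketch: building a config from a local HF checkpoint and
# querying per-layer attention settings. The checkpoint path is a placeholder,
# and `trust_remote_code=True` may be needed for checkpoints that ship custom
# config code.
#
#     config = DeciConfig.from_hugging_face("./deci-checkpoint",
#                                           dtype="auto",
#                                           trust_remote_code=True)
#     first_layer = config.get_layer_config(0)
#     if first_layer.is_attention_layer:
#         kv_heads = config.get_layer_num_kv_heads(0)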