TensorRT-LLMs/tensorrt_llm/_torch/model_config.py
Sharan Chetlur 258c7540c0 open source 09df54c0cc99354a60bbc0303e3e8ea33a96bef0 (#2725)
Co-authored-by: Kaiyu Xie <26294424+kaiyux@users.noreply.github.com>

open source f8c0381a2bc50ee2739c3d8c2be481b31e5f00bd (#2736)

Co-authored-by: Kaiyu Xie <26294424+kaiyux@users.noreply.github.com>

Add note for blackwell (#2742)

Update the docs to workaround the extra-index-url issue (#2744)

update README.md (#2751)

Fix github io pages (#2761)

Update
2025-02-11 02:21:51 +00:00

65 lines
2.1 KiB
Python

import json
import os
from dataclasses import dataclass, field
from typing import Dict, Generic, Optional, TypeVar
import transformers
from tensorrt_llm.mapping import Mapping
from tensorrt_llm.models.modeling_utils import QuantConfig
TConfig = TypeVar("TConfig", bound=transformers.PretrainedConfig)


@dataclass(kw_only=True)
class ModelConfig(Generic[TConfig]):
    """Bundles everything needed to build a model: the Hugging Face
    pretrained config, the parallel mapping, and quantization settings.
    """

    # Hugging Face config of the checkpoint (None until loaded).
    pretrained_config: Optional[TConfig] = None
    # Tensor/pipeline parallel mapping.
    mapping: Mapping = field(default_factory=Mapping)
    # Checkpoint-wide quantization settings.
    quant_config: QuantConfig = field(default_factory=QuantConfig)
    # Optional per-layer overrides, keyed by layer name.
    # TODO(qijun): support per linear layer quantization
    quant_config_dict: Optional[Dict[str, QuantConfig]] = None
    # Attention backend selector; 'TRTLLM' or 'FLASHINFER'.
    attn_backend: str = 'TRTLLM'

    @property
    def fuse_pos_embd(self) -> bool:
        """Whether the attention backend fuses positional embeddings
        (only the TRTLLM backend does)."""
        return self.attn_backend == 'TRTLLM'

    def get_quant_config(self, name: Optional[str] = None) -> QuantConfig:
        """Return the quant config for layer ``name``, falling back to the
        checkpoint-wide ``quant_config`` when ``name`` is None or no
        per-layer dict is set.

        Raises:
            ValueError: if ``name`` is given but has no per-layer entry.
        """
        # Fix: the original read ``self.per_layer_quant_configs``, which is
        # not a field of this dataclass and raised AttributeError; the
        # per-layer overrides live in ``quant_config_dict``.
        if name is None or self.quant_config_dict is None:
            return self.quant_config
        if name in self.quant_config_dict:
            return self.quant_config_dict[name]
        raise ValueError(f'quant config of {name} is not found')

    @classmethod
    def from_pretrained(cls,
                        checkpoint_dir: str,
                        trust_remote_code=False,
                        **kwargs):
        """Build a ModelConfig from a Hugging Face checkpoint directory.

        Loads the transformers AutoConfig and, if the directory contains a
        ``hf_quant_config.json`` (ModelOpt convention), the quantization
        algorithms from it. Extra ``kwargs`` are forwarded to the
        dataclass constructor.
        """
        pretrained_config = transformers.AutoConfig.from_pretrained(
            checkpoint_dir,
            trust_remote_code=trust_remote_code,
        )

        quant_config = QuantConfig()
        quant_config_file = os.path.join(checkpoint_dir,
                                         'hf_quant_config.json')
        if os.path.exists(quant_config_file):
            with open(quant_config_file) as f:
                quant_config_dict = json.load(f)
            # NOTE(review): assumes the ModelOpt schema with a top-level
            # 'quantization' section containing both keys — a file missing
            # either key raises KeyError here, as in the original.
            quant_config.quant_algo = quant_config_dict['quantization'][
                'quant_algo']
            quant_config.kv_cache_quant_algo = quant_config_dict[
                'quantization']['kv_cache_quant_algo']

        return cls(pretrained_config=pretrained_config,
                   quant_config=quant_config,
                   **kwargs)