TensorRT-LLMs/tensorrt_llm/_torch/model_config.py
Jinyang Yuan 992d513bc6
feat: Optionally split MoE inputs into chunks to reduce GPU memory usage (#3104)
Signed-off-by: Jinyang Yuan <154768711+jinyangyuan-nvidia@users.noreply.github.com>
Co-authored-by: raccoonliukai <raccoonliu@tencent.com>
2025-04-01 16:07:02 +08:00


import json
from dataclasses import dataclass, field
from pathlib import Path
from typing import Dict, Generic, List, Optional, TypeVar

import torch
import transformers

from tensorrt_llm import logger
from tensorrt_llm.mapping import Mapping
from tensorrt_llm.models.modeling_utils import QuantConfig
from tensorrt_llm.quantization.mode import QuantAlgo

TConfig = TypeVar("TConfig", bound=transformers.PretrainedConfig)

@dataclass(kw_only=True)
class ModelConfig(Generic[TConfig]):
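    """Configuration consumed by the PyTorch-flow model implementations.

    Bundles the HuggingFace pretrained config, the parallel mapping, and the
    quantization settings for a model.
    """
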
    pretrained_config: Optional[TConfig] = None
    mapping: Mapping = field(default_factory=Mapping)
    quant_config: QuantConfig = field(default_factory=QuantConfig)
    # TODO(qijun): support per linear layer quantization
    quant_config_dict: Optional[Dict[str, QuantConfig]] = None
    skip_create_weights: bool = False
    is_generation: bool = True
    max_num_tokens: int = 8192
    moe_max_num_tokens: Optional[int] = None
    attn_backend: str = 'TRTLLM'

    def __post_init__(self):
        if self.pretrained_config and hasattr(self.pretrained_config,
                                              "architectures"):
            self.is_generation = self.is_generation_model(
                self.pretrained_config.architectures)

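    # Positional embedding is fused into the attention op only for the TRTLLM
    # attention backend.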
    @property
    def fuse_pos_embd(self):
        if self.attn_backend == 'TRTLLM':
            return True
        elif self.attn_backend == 'FLASHINFER':
            return False
        return False

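    # FlashMLA is used only for MLA checkpoints whose combined head dim
    # (kv_lora_rank + qk_rope_head_dim) equals 576, running on SM 9.0 (Hopper).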
    @property
    def enable_flash_mla(self):
        if self.attn_backend == 'TRTLLM':
            if hasattr(self.pretrained_config, "kv_lora_rank") and hasattr(
                    self.pretrained_config, "qk_rope_head_dim"):
                head_dim = self.pretrained_config.kv_lora_rank + self.pretrained_config.qk_rope_head_dim
                if head_dim == 576 and torch.cuda.get_device_capability() == (
                        9, 0):
                    return True
        return False

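    # When a module name is given and per-layer configs exist, return that
    # layer's config; otherwise fall back to the global quant_config.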
    def get_quant_config(self, name: Optional[str] = None) -> QuantConfig:
        if name is None or self.quant_config_dict is None:
            return self.quant_config
        if name in self.quant_config_dict:
            return self.quant_config_dict[name]
        raise ValueError(f'quant config of {name} is not found')

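    # A model is treated as a generation model unless its first architecture is
    # a known classification or reward-model head.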
    @staticmethod
    def is_generation_model(model_architectures: Optional[List[str]]) -> bool:
        if model_architectures is None:
            logger.warning(
                "Model architectures is None, default to is_generation_model=True"
            )
            return True
        return model_architectures[0] not in [
            "BertForSequenceClassification", "Qwen2ForProcessRewardModel",
            "Qwen2ForRewardModel"
        ]
        # TODO: should be 'not model_type == ModelType.ENCODER_ONLY'
        # once ModelType is used in pytorch flow.

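    # Builds a ModelConfig from a HuggingFace checkpoint: loads the pretrained
    # config and, if present, the quantization settings (ModelOpt
    # hf_quant_config.json or the HF quantization_config field).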
    @classmethod
    def from_pretrained(cls,
                        checkpoint_dir: str,
                        trust_remote_code=False,
                        **kwargs):
        pretrained_config = transformers.AutoConfig.from_pretrained(
            checkpoint_dir,
            trust_remote_code=trust_remote_code,
        )

        # Find the cache path by looking for the config.json file, which should
        # be present in all HuggingFace models.
        model_dir = Path(
            transformers.file_utils.get_file_from_repo(checkpoint_dir,
                                                       'config.json')).parent

        quant_config = QuantConfig()
        layer_quant_config = None
        # quantized ckpt in modelopt format
        quant_config_file = model_dir / 'hf_quant_config.json'
        if quant_config_file.exists():
            with open(quant_config_file) as f:
                quant_config_dict = json.load(f)

            json_quant_configs = quant_config_dict['quantization']

            def _load_json_quant_config(key: str):
                if key in json_quant_configs:
                    return json_quant_configs[key]
                return None

            quant_config.quant_algo = _load_json_quant_config('quant_algo')
            quant_config.kv_cache_quant_algo = _load_json_quant_config(
                'kv_cache_quant_algo')
            quant_config.group_size = _load_json_quant_config('group_size')
            quant_config.exclude_modules = _load_json_quant_config(
                'exclude_modules')

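            # MIXED_PRECISION checkpoints additionally ship a quant_cfg.json
            # that maps each quantized layer to its own quant_algo.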
            if quant_config.quant_algo == QuantAlgo.MIXED_PRECISION:
                mixed_quant_config_file = model_dir / 'quant_cfg.json'
                with open(mixed_quant_config_file) as fm:
                    mixed_quant_config = json.load(fm)

                mixed_quant_config = mixed_quant_config['quantized_layers']
                for k in mixed_quant_config:
                    config = QuantConfig()
                    config.quant_algo = mixed_quant_config[k]['quant_algo']
                    mixed_quant_config[k] = config
                layer_quant_config = mixed_quant_config
        # quantized ckpt in other formats
        elif hasattr(pretrained_config, "quantization_config"):
            hf_quant_config = pretrained_config.quantization_config
            # DeepSeek V3 FP8 ckpt
            if hf_quant_config.get(
                    "quant_method") == "fp8" and hf_quant_config.get(
                        "weight_block_size", []):
                quant_config.quant_algo = QuantAlgo.FP8_BLOCK_SCALES
                quant_config.exclude_modules = ["*eh_proj"]

        return cls(pretrained_config=pretrained_config,
                   quant_config=quant_config,
                   quant_config_dict=layer_quant_config,
                   **kwargs)
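
# A minimal usage sketch (hypothetical checkpoint path and kwargs, shown only
# for illustration; not part of the module):
#
#   config = ModelConfig.from_pretrained("deepseek-ai/DeepSeek-V3",
#                                        trust_remote_code=True,
#                                        max_num_tokens=4096)
#   quant = config.get_quant_config()   # global QuantConfig
#   use_mla = config.enable_flash_mla   # True only for MLA geometry on Hopper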