# tensorrt_llm/models/automodel.py

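"""Auto classes that resolve a Hugging Face checkpoint to the corresponding
TRT-LLM config and model classes via ``MODEL_MAP``."""
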
from pathlib import Path
from typing import Optional, Union

from ..bindings.executor import DecodingMode
from ..mapping import Mapping
from . import MODEL_MAP
from .modeling_utils import QuantConfig


class AutoConfig:

    @staticmethod
    def from_hugging_face(hf_model_or_dir,
                          dtype: str = 'auto',
                          mapping: Optional[Mapping] = None,
                          quant_config: Optional[QuantConfig] = None,
                          **kwargs):
        import transformers

        # Propagate trust_remote_code to transformers.AutoConfig so that
        # checkpoints shipping custom modeling code can be inspected; it
        # defaults to False for safety.
        trust_remote_code = kwargs.get('trust_remote_code', False)

        hf_config = transformers.AutoConfig.from_pretrained(
            hf_model_or_dir, trust_remote_code=trust_remote_code)

        # Determine the Hugging Face architecture name, falling back to Mamba
        # detection via model_type since Mamba configs may omit architectures.
        hf_arch = None
        if getattr(hf_config, 'architectures', None) is not None:
            hf_arch = hf_config.architectures[0]
        elif 'mamba' in getattr(hf_config, 'model_type', ''):
            hf_arch = 'MambaForCausalLM'

        trtllm_model_cls = MODEL_MAP.get(hf_arch, None)
        if trtllm_model_cls is None:
            raise NotImplementedError(
                f"The given Hugging Face model architecture {hf_arch} is not supported in TRT-LLM yet"
            )

        if not hasattr(trtllm_model_cls, 'config_class'):
            raise NotImplementedError(
                f"The given TRT-LLM model class {trtllm_model_cls} does not support AutoConfig"
            )
        trtllm_cfg_cls = getattr(trtllm_model_cls, 'config_class')

        if not hasattr(trtllm_cfg_cls, 'from_hugging_face'):
            raise NotImplementedError(
                f"The given TRT-LLM config class {trtllm_cfg_cls} does not support from_hugging_face"
            )

        return trtllm_cfg_cls.from_hugging_face(hf_model_or_dir, dtype, mapping,
                                                quant_config, **kwargs)
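

# Hedged usage sketch, not part of the library surface: the checkpoint path
# and dtype below are illustrative assumptions; any local Hugging Face
# directory containing a config.json works the same way.
def _example_auto_config():
    # Resolves the checkpoint architecture to the matching TRT-LLM config
    # class and builds the config; trust_remote_code is forwarded via kwargs.
    return AutoConfig.from_hugging_face(
        '/path/to/llama-checkpoint',  # hypothetical local checkpoint dir
        dtype='float16',
        trust_remote_code=False)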


class AutoModelForCausalLM:

    @staticmethod
    def get_trtllm_model_class(hf_model_or_dir: Union[str, Path],
                               trust_remote_code: bool = False,
                               decoding_mode: Optional[DecodingMode] = None):
        import transformers

        hf_model_or_dir = Path(hf_model_or_dir) if not isinstance(
            hf_model_or_dir, Path) else hf_model_or_dir
        assert (hf_model_or_dir / "config.json").exists(
        ), "Please provide a Hugging Face model as the input to the LLM API."

        hf_config = transformers.AutoConfig.from_pretrained(
            hf_model_or_dir, trust_remote_code=trust_remote_code)

        # A speculative decoding mode overrides whatever architecture the
        # checkpoint config records.
        hf_arch = None
        if decoding_mode is not None:
            if decoding_mode.isMedusa():
                hf_arch = 'MedusaForCausalLM'
            elif decoding_mode.isEagle():
                hf_arch = 'EagleForCausalLM'
            else:
                raise NotImplementedError(
                    f"Unsupported speculative decoding mode: {decoding_mode}.")
        elif getattr(hf_config, 'architectures', None) is not None:
            hf_arch = hf_config.architectures[0]
        elif 'mamba' in getattr(hf_config, 'model_type', ''):
            hf_arch = 'MambaForCausalLM'

        trtllm_model_cls = MODEL_MAP.get(hf_arch, None)
        if trtllm_model_cls is None:
            raise NotImplementedError(
                f"The given Hugging Face model architecture {hf_arch} is not supported in TRT-LLM yet"
            )
        return trtllm_model_cls
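
    # Hedged sketch of the speculative-decoding override; the checkpoint path
    # is hypothetical, and DecodingMode.Medusa() is assumed to be the bindings'
    # factory for the Medusa mode:
    #
    #     from tensorrt_llm.bindings.executor import DecodingMode
    #     cls = AutoModelForCausalLM.get_trtllm_model_class(
    #         '/path/to/medusa-checkpoint',
    #         decoding_mode=DecodingMode.Medusa())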

    @staticmethod
    def from_hugging_face(hf_model_or_dir,
                          dtype: str = 'auto',
                          mapping: Optional[Mapping] = None,
                          quant_config: Optional[QuantConfig] = None,
                          **kwargs):
        # Propagate trust_remote_code when resolving the model class, mirroring
        # the propagation in AutoConfig.from_hugging_face above.
        trtllm_model_cls = AutoModelForCausalLM.get_trtllm_model_class(
            hf_model_or_dir, kwargs.get('trust_remote_code', False))
        if not hasattr(trtllm_model_cls, 'from_hugging_face'):
            raise NotImplementedError(
                f"The given {trtllm_model_cls} does not support from_hugging_face yet"
            )
        return trtllm_model_cls.from_hugging_face(hf_model_or_dir, dtype,
                                                  mapping, quant_config,
                                                  **kwargs)
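

# Minimal end-to-end sketch, kept behind a __main__ guard so importing this
# module stays side-effect free; the checkpoint path is hypothetical and any
# supported local Hugging Face checkpoint directory would do.
if __name__ == '__main__':
    model = AutoModelForCausalLM.from_hugging_face(
        '/path/to/hf-checkpoint',  # hypothetical local directory
        dtype='auto')
    print(type(model).__name__)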