Mirror of https://github.com/NVIDIA/TensorRT-LLM.git, synced 2026-01-13 22:18:36 +08:00
[None][fix] fix config loading for DeepSeek-V3.2 in trtllm-bench (#8729)
Signed-off-by: Fanrong Li <23290157+lfr-0531@users.noreply.github.com>
parent e2c5a38879
commit a21697ead9
@@ -12,7 +12,8 @@ import transformers
 from transformers.utils import HF_MODULES_CACHE
 
 from tensorrt_llm import logger
-from tensorrt_llm._torch.pyexecutor.config_utils import is_nemotron_hybrid
+from tensorrt_llm._torch.pyexecutor.config_utils import (is_nemotron_hybrid,
+                                                         load_pretrained_config)
 from tensorrt_llm._utils import get_sm_version, torch_dtype_to_binding
 from tensorrt_llm.bindings import LayerType as LayerTypeCpp
 from tensorrt_llm.functional import AllReduceStrategy
@@ -25,18 +26,6 @@ from tensorrt_llm.quantization.mode import QuantAlgo
 TConfig = TypeVar("TConfig", bound=transformers.PretrainedConfig)
 
 
-class LazyConfigDict(dict):
-
-    def __getitem__(self, key):
-        import tensorrt_llm._torch.configs as configs
-        return getattr(configs, super().__getitem__(key))
-
-
-_CONFIG_REGISTRY: dict[str, type[transformers.PretrainedConfig]] = LazyConfigDict(
-    deepseek_v32="DeepseekV3Config",
-)  # NOTE: HF config.json uses deepseek_v32 as model_type but with same DSV3 config class
-
-
 @dataclass
 class MoeLoadBalancerConfig:
     num_slots: Optional[int] = None
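The registry block removed here is re-added in config_utils below. As a standalone illustration of the lazy-lookup pattern it relies on (class names are stored as strings and only resolved against the target module when a key is actually looked up, so importing the registry does not import every config class), here is a minimal runnable sketch; the stand-in module and class are placeholders, not TensorRT-LLM code:

import importlib


class LazyClassDict(dict):
    """Store class names and resolve them against a module only on lookup."""

    def __init__(self, module_name, **entries):
        super().__init__(**entries)
        self._module_name = module_name

    def __getitem__(self, key):
        # Deferred import: the module is loaded on first lookup, not at definition time.
        module = importlib.import_module(self._module_name)
        return getattr(module, super().__getitem__(key))


# Stand-in registry that resolves "deepseek_v32" to collections.OrderedDict,
# purely to show that lookups return the class object itself.
registry = LazyClassDict("collections", deepseek_v32="OrderedDict")
print(registry["deepseek_v32"])  # <class 'collections.OrderedDict'>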
@@ -432,51 +421,31 @@ class ModelConfig(Generic[TConfig]):
         # When handling the case where model_format is TLLM_ENGINE
         # send cyclic requests to the NONE URL.
         if checkpoint_dir is not None:
-            config_dict, _ = transformers.PretrainedConfig.get_config_dict(
+            pretrained_config = load_pretrained_config(
                 checkpoint_dir,
+                trust_remote_code=trust_remote_code,
                 **kwargs,
             )
-            model_type = config_dict.get("model_type")
-            if model_type in _CONFIG_REGISTRY:
-                config_class = _CONFIG_REGISTRY[model_type]
-                pretrained_config = config_class.from_pretrained(
-                    checkpoint_dir,
-                    **kwargs,
-                )
-                if model_type == "deepseek_v32":
-                    sparse_attention_config = kwargs.get(
-                        'sparse_attention_config')
-                    kwargs[
-                        'sparse_attention_config'] = DeepSeekSparseAttentionConfig(
-                            index_n_heads=(
-                                sparse_attention_config.index_n_heads
-                                if sparse_attention_config
-                                and sparse_attention_config.index_n_heads
-                                is not None else
-                                pretrained_config.index_n_heads),
-                            index_head_dim=(
-                                sparse_attention_config.index_head_dim
-                                if sparse_attention_config
-                                and sparse_attention_config.index_head_dim
-                                is not None else
-                                pretrained_config.index_head_dim),
-                            index_topk=(sparse_attention_config.index_topk
-                                        if sparse_attention_config and
-                                        sparse_attention_config.index_topk
-                                        is not None else
-                                        pretrained_config.index_topk),
-                            indexer_max_chunk_size=(
-                                sparse_attention_config.
-                                indexer_max_chunk_size
-                                if sparse_attention_config
-                                and sparse_attention_config.
-                                indexer_max_chunk_size is not None else
-                                None))
-            else:
-                pretrained_config = transformers.AutoConfig.from_pretrained(
-                    checkpoint_dir,
-                    trust_remote_code=trust_remote_code,
-                )
+            if pretrained_config.architectures[
+                    0] == "DeepseekV32ForCausalLM":
+                sparse_attention_config = kwargs.get(
+                    'sparse_attention_config')
+                if sparse_attention_config:
+                    index_n_heads = sparse_attention_config.index_n_heads or pretrained_config.index_n_heads
+                    index_head_dim = sparse_attention_config.index_head_dim or pretrained_config.index_head_dim
+                    index_topk = sparse_attention_config.index_topk or pretrained_config.index_topk
+                    indexer_max_chunk_size = sparse_attention_config.indexer_max_chunk_size
+                else:
+                    index_n_heads = pretrained_config.index_n_heads
+                    index_head_dim = pretrained_config.index_head_dim
+                    index_topk = pretrained_config.index_topk
+                    indexer_max_chunk_size = None
+                kwargs[
+                    'sparse_attention_config'] = DeepSeekSparseAttentionConfig(
+                        index_n_heads=index_n_heads,
+                        index_head_dim=index_head_dim,
+                        index_topk=index_topk,
+                        indexer_max_chunk_size=indexer_max_chunk_size)
         else:
             raise ValueError(
                 "checkpoint_dir is None. Cannot load model config without a valid checkpoint directory."
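The net effect of the new branch above is a per-field override rule for DeepSeek-V3.2's indexer settings. A condensed sketch of that rule, with field names taken from the diff (DeepSeekSparseAttentionConfig itself is not constructed here):

def resolve_indexer_settings(sparse_attention_config, pretrained_config):
    # A user-supplied sparse_attention_config wins field by field; unset or
    # falsy fields fall back to the values from the HF pretrained config.
    # indexer_max_chunk_size has no HF fallback and defaults to None.
    if sparse_attention_config:
        return {
            "index_n_heads": sparse_attention_config.index_n_heads
            or pretrained_config.index_n_heads,
            "index_head_dim": sparse_attention_config.index_head_dim
            or pretrained_config.index_head_dim,
            "index_topk": sparse_attention_config.index_topk
            or pretrained_config.index_topk,
            "indexer_max_chunk_size": sparse_attention_config.indexer_max_chunk_size,
        }
    return {
        "index_n_heads": pretrained_config.index_n_heads,
        "index_head_dim": pretrained_config.index_head_dim,
        "index_topk": pretrained_config.index_topk,
        "indexer_max_chunk_size": None,
    }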
@@ -1,3 +1,6 @@
+import transformers
+
+
 def is_nemotron_hybrid(config):
     if hasattr(config, "hybrid_override_pattern"
                ) and config.hybrid_override_pattern is not None and len(
@@ -18,3 +21,32 @@ def is_qwen3_next(config):
         config, 'architectures'
     ) and config.architectures is not None and config.architectures[
         0] == 'Qwen3NextForCausalLM'
+
+
+# TODO: remove this once the transformers can support all of those models in _CONFIG_REGISTRY
+class LazyConfigDict(dict):
+
+    def __getitem__(self, key):
+        import tensorrt_llm._torch.configs as configs
+        return getattr(configs, super().__getitem__(key))
+
+
+_CONFIG_REGISTRY: dict[str, type[transformers.PretrainedConfig]] = LazyConfigDict(
+    deepseek_v32="DeepseekV3Config",
+)  # NOTE: HF config.json uses deepseek_v32 as model_type but with same DSV3 config class
+
+
+def load_pretrained_config(model_name_or_path: str,
+                           trust_remote_code: bool = False,
+                           **kwargs) -> transformers.PretrainedConfig:
+    config_dict, _ = transformers.PretrainedConfig.get_config_dict(
+        model_name_or_path, **kwargs)
+    model_type = config_dict.get("model_type")
+    if model_type in _CONFIG_REGISTRY:
+        config_class = _CONFIG_REGISTRY[model_type]
+        model_config = config_class.from_pretrained(model_name_or_path,
+                                                    **kwargs)
+    else:
+        model_config = transformers.AutoConfig.from_pretrained(
+            model_name_or_path, trust_remote_code=trust_remote_code)
+    return model_config
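A minimal usage sketch of the new helper (the checkpoint path is a placeholder): a config.json whose model_type is in the registry, currently only deepseek_v32, is resolved through the mapped config class, and anything else falls through to transformers.AutoConfig.

from tensorrt_llm._torch.pyexecutor.config_utils import load_pretrained_config

# "/path/to/checkpoint" stands in for a local HF checkpoint directory or hub id.
config = load_pretrained_config("/path/to/checkpoint", trust_remote_code=True)

# For a DeepSeek-V3.2 checkpoint (model_type "deepseek_v32") this yields a
# DeepseekV3Config; for other models it is whatever AutoConfig returns.
print(type(config).__name__)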
@@ -1,12 +1,11 @@
 from __future__ import annotations
-from transformers import AutoConfig
 
 from pathlib import Path
 from typing import Tuple, get_args
 import click
 from click_option_group import AllOptionGroup, optgroup
 
-from tensorrt_llm._torch.pyexecutor.config_utils import is_nemotron_hybrid
+from tensorrt_llm._torch.pyexecutor.config_utils import is_nemotron_hybrid, load_pretrained_config
 from tensorrt_llm.bench.dataclasses.general import BenchmarkEnvironment
 from tensorrt_llm.bench.utils.data import create_dataset_from_stream, initialize_tokenizer
 from tensorrt_llm.bench.utils import VALID_QUANT_ALGOS
@@ -86,9 +85,9 @@ def get_model_config(model_name: str, model_path: Path = None) -> ModelConfig:
     Raises:
         ValueError: When model is not supported.
     """
-    if is_nemotron_hybrid(
-            AutoConfig.from_pretrained(model_path or model_name,
-                                       trust_remote_code=True)):
+    pretrained_config = load_pretrained_config(model_path or model_name,
+                                               trust_remote_code=True)
+    if is_nemotron_hybrid(pretrained_config):
         return NemotronHybridConfig.from_hf(model_name, model_path)
     return ModelConfig.from_hf(model_name, model_path)
 
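With this hunk, trtllm-bench resolves the checkpoint config once through the shared helper and reuses it for the Nemotron-hybrid check, so a DeepSeek-V3.2 checkpoint (model_type deepseek_v32, which stock AutoConfig cannot map yet, per the TODO note above) takes the registry path. A condensed sketch of the resulting flow, for illustration only (the real get_model_config builds config objects rather than returning names):

from tensorrt_llm._torch.pyexecutor.config_utils import (is_nemotron_hybrid,
                                                          load_pretrained_config)


def choose_bench_config_class(model_name, model_path=None):
    # Mirrors the updated get_model_config flow: resolve the HF config once,
    # then branch on the Nemotron-hybrid check.
    pretrained_config = load_pretrained_config(model_path or model_name,
                                               trust_remote_code=True)
    return "NemotronHybridConfig" if is_nemotron_hybrid(
        pretrained_config) else "ModelConfig"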
@@ -1,4 +1,3 @@
-from transformers import AutoConfig
 from typing import Optional, Literal
 from pydantic import AliasPath, BaseModel, Field, AliasChoices, model_validator
 import huggingface_hub
@@ -14,6 +13,8 @@ import os
 import json
 import struct
 
+from tensorrt_llm._torch.pyexecutor.config_utils import load_pretrained_config
+
 
 def parse_safetensors_file_metadata(model_path, filename):
@@ -192,9 +193,10 @@ class ModelConfig(BaseModel):
 
     @classmethod
     def from_hf(cls, model_hf_name, hf_model_path):
-        model_name_or_path = hf_model_path or model_hf_name
-        hf_config = AutoConfig.from_pretrained(
-            model_name_or_path, trust_remote_code=True).to_dict()
+        pretrained_config = load_pretrained_config(hf_model_path
+                                                   or model_hf_name,
+                                                   trust_remote_code=True)
+        hf_config = pretrained_config.to_dict()
         param_count = cls.get_param_count(model_hf_name, hf_model_path)
 
         return cls(name=model_hf_name, param_count=param_count, **hf_config)
@@ -17,7 +17,7 @@ from fastapi import Body, FastAPI, Request
 from fastapi.exceptions import RequestValidationError
 from fastapi.responses import JSONResponse, Response, StreamingResponse
 from starlette.routing import Mount
-from transformers import AutoConfig, AutoProcessor
+from transformers import AutoProcessor
 
 from tensorrt_llm._tensorrt_engine import LLM
 # yapf: disable
@@ -101,27 +101,15 @@ class OpenAIServer:
         except Exception:
             logger.debug("Failed to load AutoProcessor or AutoConfig for %s", hf_tokenizer_path)
             self.processor = None
-        # Temporary workaround for DSv3.2 config.
-        import transformers
-
-        from tensorrt_llm._torch.model_config import _CONFIG_REGISTRY
-        config_dict, _ = transformers.PretrainedConfig.get_config_dict(
-            hf_tokenizer_path,
-            trust_remote_code=trust_remote_code
-        )
-        model_type = config_dict.get("model_type")
-        if model_type in _CONFIG_REGISTRY:
-            config_class = _CONFIG_REGISTRY[model_type]
-            self.model_config = config_class.from_pretrained(
-                hf_tokenizer_path,
-                trust_remote_code=trust_remote_code
-            )
-        else:
-            try:
-                self.model_config = AutoConfig.from_pretrained(hf_tokenizer_path, trust_remote_code=trust_remote_code)
-            except Exception:
-                logger.debug("Failed to load AutoConfig for %s", hf_tokenizer_path)
-                self.model_config = None
+        # load model config
+        try:
+            from tensorrt_llm._torch.pyexecutor.config_utils import \
+                load_pretrained_config
+            self.model_config = load_pretrained_config(hf_tokenizer_path,
+                                                       trust_remote_code=trust_remote_code)
+        except Exception:
+            logger.debug("Failed to load AutoConfig for %s", hf_tokenizer_path)
+            self.model_config = None
 
         # Enable response storage for Responses API
         self.enable_store = True