[None][fix] fix config loading for DeepSeek-V3.2 in trtllm-bench (#8729)

Signed-off-by: Fanrong Li <23290157+lfr-0531@users.noreply.github.com>
Fanrong Li 2025-10-29 20:17:16 +08:00 committed by GitHub
parent e2c5a38879
commit a21697ead9
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 76 additions and 86 deletions

View File

@@ -12,7 +12,8 @@ import transformers
from transformers.utils import HF_MODULES_CACHE
from tensorrt_llm import logger
from tensorrt_llm._torch.pyexecutor.config_utils import is_nemotron_hybrid
from tensorrt_llm._torch.pyexecutor.config_utils import (is_nemotron_hybrid,
load_pretrained_config)
from tensorrt_llm._utils import get_sm_version, torch_dtype_to_binding
from tensorrt_llm.bindings import LayerType as LayerTypeCpp
from tensorrt_llm.functional import AllReduceStrategy
@@ -25,18 +26,6 @@ from tensorrt_llm.quantization.mode import QuantAlgo
TConfig = TypeVar("TConfig", bound=transformers.PretrainedConfig)
class LazyConfigDict(dict):
def __getitem__(self, key):
import tensorrt_llm._torch.configs as configs
return getattr(configs, super().__getitem__(key))
_CONFIG_REGISTRY: dict[str, type[transformers.PretrainedConfig]] = LazyConfigDict(
deepseek_v32="DeepseekV3Config",
) # NOTE: DeepSeek-V3.2 checkpoints use deepseek_v32 as model_type in config.json but share the DeepSeek-V3 config class
@dataclass
class MoeLoadBalancerConfig:
num_slots: Optional[int] = None
@@ -432,51 +421,31 @@ class ModelConfig(Generic[TConfig]):
# When handling the case where model_format is TLLM_ENGINE
# send cyclic requests to the NONE URL.
if checkpoint_dir is not None:
config_dict, _ = transformers.PretrainedConfig.get_config_dict(
pretrained_config = load_pretrained_config(
checkpoint_dir,
trust_remote_code=trust_remote_code,
**kwargs,
)
model_type = config_dict.get("model_type")
if model_type in _CONFIG_REGISTRY:
config_class = _CONFIG_REGISTRY[model_type]
pretrained_config = config_class.from_pretrained(
checkpoint_dir,
**kwargs,
)
if model_type == "deepseek_v32":
sparse_attention_config = kwargs.get(
'sparse_attention_config')
kwargs[
'sparse_attention_config'] = DeepSeekSparseAttentionConfig(
index_n_heads=(
sparse_attention_config.index_n_heads
if sparse_attention_config
and sparse_attention_config.index_n_heads
is not None else
pretrained_config.index_n_heads),
index_head_dim=(
sparse_attention_config.index_head_dim
if sparse_attention_config
and sparse_attention_config.index_head_dim
is not None else
pretrained_config.index_head_dim),
index_topk=(sparse_attention_config.index_topk
if sparse_attention_config and
sparse_attention_config.index_topk
is not None else
pretrained_config.index_topk),
indexer_max_chunk_size=(
sparse_attention_config.
indexer_max_chunk_size
if sparse_attention_config
and sparse_attention_config.
indexer_max_chunk_size is not None else
None))
else:
pretrained_config = transformers.AutoConfig.from_pretrained(
checkpoint_dir,
trust_remote_code=trust_remote_code,
)
if pretrained_config.architectures[
0] == "DeepseekV32ForCausalLM":
sparse_attention_config = kwargs.get(
'sparse_attention_config')
if sparse_attention_config:
index_n_heads = sparse_attention_config.index_n_heads or pretrained_config.index_n_heads
index_head_dim = sparse_attention_config.index_head_dim or pretrained_config.index_head_dim
index_topk = sparse_attention_config.index_topk or pretrained_config.index_topk
indexer_max_chunk_size = sparse_attention_config.indexer_max_chunk_size
else:
index_n_heads = pretrained_config.index_n_heads
index_head_dim = pretrained_config.index_head_dim
index_topk = pretrained_config.index_topk
indexer_max_chunk_size = None
kwargs[
'sparse_attention_config'] = DeepSeekSparseAttentionConfig(
index_n_heads=index_n_heads,
index_head_dim=index_head_dim,
index_topk=index_topk,
indexer_max_chunk_size=indexer_max_chunk_size)
else:
raise ValueError(
"checkpoint_dir is None. Cannot load model config without a valid checkpoint directory."

View File

@@ -1,3 +1,6 @@
import transformers
def is_nemotron_hybrid(config):
if hasattr(config, "hybrid_override_pattern"
) and config.hybrid_override_pattern is not None and len(
@@ -18,3 +21,32 @@ def is_qwen3_next(config):
config, 'architectures'
) and config.architectures is not None and config.architectures[
0] == 'Qwen3NextForCausalLM'
# TODO: remove this once transformers supports all of the models in _CONFIG_REGISTRY
class LazyConfigDict(dict):
def __getitem__(self, key):
import tensorrt_llm._torch.configs as configs
return getattr(configs, super().__getitem__(key))
_CONFIG_REGISTRY: dict[str, type[transformers.PretrainedConfig]] = LazyConfigDict(
deepseek_v32="DeepseekV3Config",
) # NOTE: DeepSeek-V3.2 checkpoints use deepseek_v32 as model_type in config.json but share the DeepSeek-V3 config class
def load_pretrained_config(model_name_or_path: str,
trust_remote_code: bool = False,
**kwargs) -> transformers.PretrainedConfig:
config_dict, _ = transformers.PretrainedConfig.get_config_dict(
model_name_or_path, **kwargs)
model_type = config_dict.get("model_type")
if model_type in _CONFIG_REGISTRY:
config_class = _CONFIG_REGISTRY[model_type]
model_config = config_class.from_pretrained(model_name_or_path,
**kwargs)
else:
model_config = transformers.AutoConfig.from_pretrained(
model_name_or_path, trust_remote_code=trust_remote_code)
return model_config
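A minimal usage sketch of the new helper (the checkpoint paths below are placeholders, not taken from the commit): a deepseek_v32 checkpoint is resolved through _CONFIG_REGISTRY to the in-tree DeepseekV3Config, while any unregistered model_type falls back to AutoConfig.

from tensorrt_llm._torch.pyexecutor.config_utils import load_pretrained_config

# deepseek_v32 is in _CONFIG_REGISTRY, so the in-tree DeepseekV3Config is used
# for the config object itself.
dsv32_cfg = load_pretrained_config("/models/DeepSeek-V3.2-Exp",
                                   trust_remote_code=True)
print(type(dsv32_cfg).__name__)

# Any other model_type goes through transformers.AutoConfig.from_pretrained.
llama_cfg = load_pretrained_config("meta-llama/Llama-3.1-8B",
                                   trust_remote_code=False)
print(type(llama_cfg).__name__)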

View File

@@ -1,12 +1,11 @@
from __future__ import annotations
from transformers import AutoConfig
from pathlib import Path
from typing import Tuple, get_args
import click
from click_option_group import AllOptionGroup, optgroup
from tensorrt_llm._torch.pyexecutor.config_utils import is_nemotron_hybrid
from tensorrt_llm._torch.pyexecutor.config_utils import is_nemotron_hybrid, load_pretrained_config
from tensorrt_llm.bench.dataclasses.general import BenchmarkEnvironment
from tensorrt_llm.bench.utils.data import create_dataset_from_stream, initialize_tokenizer
from tensorrt_llm.bench.utils import VALID_QUANT_ALGOS
@@ -86,9 +85,9 @@ def get_model_config(model_name: str, model_path: Path = None) -> ModelConfig:
Raises:
ValueError: When model is not supported.
"""
if is_nemotron_hybrid(
AutoConfig.from_pretrained(model_path or model_name,
trust_remote_code=True)):
pretrained_config = load_pretrained_config(model_path or model_name,
trust_remote_code=True)
if is_nemotron_hybrid(pretrained_config):
return NemotronHybridConfig.from_hf(model_name, model_path)
return ModelConfig.from_hf(model_name, model_path)
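With the benchmark entry point now routed through the same helper, a DeepSeek-V3.2 checkpoint no longer fails on the unregistered deepseek_v32 model_type before reaching ModelConfig.from_hf. A hedged sketch of the call (the checkpoint path is a placeholder, and get_model_config refers to the function patched above):

from pathlib import Path

# DeepSeek-V3.2 resolves via load_pretrained_config; Nemotron hybrid checkpoints
# still branch to NemotronHybridConfig.from_hf.
cfg = get_model_config("deepseek-ai/DeepSeek-V3.2-Exp",
                       Path("/models/DeepSeek-V3.2-Exp"))
print(cfg.param_count)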

View File

@@ -1,4 +1,3 @@
from transformers import AutoConfig
from typing import Optional, Literal
from pydantic import AliasPath, BaseModel, Field, AliasChoices, model_validator
import huggingface_hub
@@ -14,6 +13,8 @@ import os
import json
import struct
from tensorrt_llm._torch.pyexecutor.config_utils import load_pretrained_config
def parse_safetensors_file_metadata(model_path, filename):
@@ -192,9 +193,10 @@ class ModelConfig(BaseModel):
@classmethod
def from_hf(cls, model_hf_name, hf_model_path):
model_name_or_path = hf_model_path or model_hf_name
hf_config = AutoConfig.from_pretrained(
model_name_or_path, trust_remote_code=True).to_dict()
pretrained_config = load_pretrained_config(hf_model_path
or model_hf_name,
trust_remote_code=True)
hf_config = pretrained_config.to_dict()
param_count = cls.get_param_count(model_hf_name, hf_model_path)
return cls(name=model_hf_name, param_count=param_count, **hf_config)
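from_hf now flattens whatever load_pretrained_config returns into a plain dict before handing it to the pydantic model, so registry-backed configs and AutoConfig results take the same path. A small sketch of that flattening (the path is a placeholder and the keys shown are standard HF config fields, not verified against this ModelConfig's aliases):

from tensorrt_llm._torch.pyexecutor.config_utils import load_pretrained_config

pretrained_config = load_pretrained_config("/models/DeepSeek-V3.2-Exp",
                                           trust_remote_code=True)
hf_config = pretrained_config.to_dict()
# Pydantic then picks fields out of this dict via AliasChoices / AliasPath.
print(hf_config.get("num_hidden_layers"), hf_config.get("hidden_size"))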

View File

@@ -17,7 +17,7 @@ from fastapi import Body, FastAPI, Request
from fastapi.exceptions import RequestValidationError
from fastapi.responses import JSONResponse, Response, StreamingResponse
from starlette.routing import Mount
from transformers import AutoConfig, AutoProcessor
from transformers import AutoProcessor
from tensorrt_llm._tensorrt_engine import LLM
# yapf: disable
@@ -101,27 +101,15 @@ class OpenAIServer:
except Exception:
logger.debug("Failed to load AutoProcessor or AutoConfig for %s", hf_tokenizer_path)
self.processor = None
# Temporary workaround for DSv3.2 config.
import transformers
from tensorrt_llm._torch.model_config import _CONFIG_REGISTRY
config_dict, _ = transformers.PretrainedConfig.get_config_dict(
hf_tokenizer_path,
trust_remote_code=trust_remote_code
)
model_type = config_dict.get("model_type")
if model_type in _CONFIG_REGISTRY:
config_class = _CONFIG_REGISTRY[model_type]
self.model_config = config_class.from_pretrained(
hf_tokenizer_path,
trust_remote_code=trust_remote_code
)
else:
try:
self.model_config = AutoConfig.from_pretrained(hf_tokenizer_path, trust_remote_code=trust_remote_code)
except Exception:
logger.debug("Failed to load AutoConfig for %s", hf_tokenizer_path)
self.model_config = None
# load model config
try:
from tensorrt_llm._torch.pyexecutor.config_utils import \
load_pretrained_config
self.model_config = load_pretrained_config(hf_tokenizer_path,
trust_remote_code=trust_remote_code)
except Exception:
logger.debug("Failed to load AutoConfig for %s", hf_tokenizer_path)
self.model_config = None
# Enable response storage for Responses API
self.enable_store = True