[https://nvbugs/5772363][fix] fix bug of Mistral-Small-3.1-24B-Instruct-2503 (#10394)

Signed-off-by: bhsueh <11360707+byshiue@users.noreply.github.com>
This commit is contained in:
bhsueh_NV 2026-01-05 09:04:13 +08:00 committed by GitHub
parent 8e2065b4d9
commit 0517b62789
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 16 additions and 10 deletions

View File

@ -340,5 +340,6 @@ class MistralConfigLoader(BaseConfigLoader):
from tensorrt_llm._torch.models.modeling_mistral_large3 import Mistral3Gate
model_config.pretrained_config.gate_cls = Mistral3Gate
model_config.pretrained_config.input_processor_type = "mistral_large_3"
model_config._frozen = True
return model_config

View File

@ -334,6 +334,9 @@ class Mistral3InputProcessor(BaseMultimodalInputProcessor,
trust_remote_code=trust_remote_code)
self._model_path = model_path
if model_type == "mistral_large_3":
# For mistral large 3, we add chat template in the model forward, and the
# MistralCommonImageProcessor is used to process the input when both text and images are provided.
# When the input only contains text, we use the text processor to process the input.
self._processor = MistralCommonImageProcessor(
tokenizer=self._tokenizer, dtype=self.dtype)
self.text_processor = AutoProcessor.from_pretrained(
@ -341,11 +344,12 @@ class Mistral3InputProcessor(BaseMultimodalInputProcessor,
use_fast=self.use_fast,
trust_remote_code=trust_remote_code)
else:
# For other mistral models, we use the AutoProcessor to process the input.
self._processor = AutoProcessor.from_pretrained(
model_path,
use_fast=self.use_fast,
trust_remote_code=trust_remote_code)
self.text_processor = None
self.text_processor = self._processor
@property
def config(self) -> PretrainedConfig:
@ -457,19 +461,22 @@ class MistralCommonInputProcessor(Mistral3InputProcessor):
trust_remote_code: bool = False,
**kwargs,
):
tokenizer = self.load_tokenizer(model_path, config=config)
tokenizer = self.load_tokenizer(model_path,
config=config,
tokenizer=tokenizer)
super().__init__(model_path=model_path,
config=config,
tokenizer=tokenizer,
trust_remote_code=trust_remote_code,
model_type="mistral_large_3",
model_type=getattr(config, "input_processor_type",
"mistral3"),
**kwargs)
@staticmethod
def load_tokenizer(model_path: str,
config: PretrainedConfig,
checkpoint_format: str = "mistral_large_3"):
if checkpoint_format == "mistral_large_3":
tokenizer: AutoTokenizer | None = None):
if getattr(config, "input_processor_type", None) == "mistral_large_3":
try:
return MistralTokenizer.from_pretrained(model_path)
@ -478,10 +485,8 @@ class MistralCommonInputProcessor(Mistral3InputProcessor):
f"Could not load mistral-common tokenizer from {model_path}, falling back to HuggingFace"
)
tokenizer = AutoTokenizer.from_pretrained(model_path,
config=config,
use_fast=True,
trust_remote_code=True)
tokenizer = tokenizer if tokenizer is not None else AutoTokenizer.from_pretrained(
model_path, config=config, use_fast=True, trust_remote_code=True)
return tokenizer
@ -498,7 +503,7 @@ class MistralCommonInputProcessor(Mistral3InputProcessor):
placeholder_placement=MultimodalPlaceholderPlacement.BEFORE_TEXT,
))
@register_input_processor(
Mistral3InputProcessor,
MistralCommonInputProcessor,
model_type="mistral3",
placeholder_metadata=MultimodalPlaceholderMetadata(
placeholder_map={