Mirror of https://github.com/NVIDIA/TensorRT-LLM.git, synced 2026-01-13 22:18:36 +08:00
[https://nvbugs/5772363][fix] fix bug of Mistral-Small-3.1-24B-Instruct-2503 (#10394)
Signed-off-by: bhsueh <11360707+byshiue@users.noreply.github.com>
This commit is contained in:
parent 8e2065b4d9 · commit 0517b62789
@@ -340,5 +340,6 @@ class MistralConfigLoader(BaseConfigLoader):
         from tensorrt_llm._torch.models.modeling_mistral_large3 import Mistral3Gate

         model_config.pretrained_config.gate_cls = Mistral3Gate
+        model_config.pretrained_config.input_processor_type = "mistral_large_3"
         model_config._frozen = True
         return model_config
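This one-line addition is the heart of the fix: the loader stamps a dispatch tag onto the pretrained config before the model config is frozen, so downstream input-processor code can tell Mistral-Large-3 checkpoints apart from Mistral-Small ones, which never set the tag. A minimal, hypothetical sketch of the tag's read side (SimpleNamespace stands in for the real PretrainedConfig; this is an illustration, not the repository's code):

# Illustration only: reading the tag with a safe default.
# SimpleNamespace is a stand-in for tensorrt_llm's PretrainedConfig.
from types import SimpleNamespace

large3 = SimpleNamespace(input_processor_type="mistral_large_3")
small = SimpleNamespace()  # Mistral-Small-3.1: the tag is never set

def processor_type(config) -> str:
    # Same getattr-with-default pattern the later hunks adopt.
    return getattr(config, "input_processor_type", "mistral3")

assert processor_type(large3) == "mistral_large_3"
assert processor_type(small) == "mistral3"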
@@ -334,6 +334,9 @@ class Mistral3InputProcessor(BaseMultimodalInputProcessor,
             trust_remote_code=trust_remote_code)
         self._model_path = model_path
         if model_type == "mistral_large_3":
+            # For Mistral Large 3, we add the chat template in the model forward, and
+            # MistralCommonImageProcessor is used to process the input when both text and images are provided.
+            # When the input contains only text, we use the text processor instead.
             self._processor = MistralCommonImageProcessor(
                 tokenizer=self._tokenizer, dtype=self.dtype)
             self.text_processor = AutoProcessor.from_pretrained(
@@ -341,11 +344,12 @@ class Mistral3InputProcessor(BaseMultimodalInputProcessor,
                 use_fast=self.use_fast,
                 trust_remote_code=trust_remote_code)
         else:
+            # For other Mistral models, we use the AutoProcessor to process the input.
             self._processor = AutoProcessor.from_pretrained(
                 model_path,
                 use_fast=self.use_fast,
                 trust_remote_code=trust_remote_code)
-            self.text_processor = None
+            self.text_processor = self._processor

     @property
     def config(self) -> PretrainedConfig:
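Together, the two hunks above fix text-only inputs on non-Large-3 Mistral models: previously self.text_processor was left as None in the else branch, so a text-only request had no processor to fall back on; now it aliases the shared AutoProcessor. A stub sketch of the resulting routing (all class and method names here are placeholders, not the repository's API):

# Stub sketch of the processor selection fixed above; the real classes are
# MistralCommonImageProcessor / AutoProcessor in the actual code.
class StubImageProcessor:
    def __call__(self, text, images):
        return {"mode": "text+image", "n_images": len(images)}

class StubTextProcessor:
    def __call__(self, text, images=None):
        return {"mode": "text-only"}

class InputProcessorSketch:
    def __init__(self, model_type: str):
        if model_type == "mistral_large_3":
            self._processor = StubImageProcessor()
            self.text_processor = StubTextProcessor()
        else:
            self._processor = StubTextProcessor()
            # The fix: alias instead of None, so text-only requests on
            # non-Large-3 models no longer hit a null processor.
            self.text_processor = self._processor

    def process(self, text, images=None):
        if images:
            return self._processor(text, images)
        return self.text_processor(text)

proc = InputProcessorSketch("mistral3")
assert proc.process("hello")["mode"] == "text-only"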
@@ -457,19 +461,22 @@ class MistralCommonInputProcessor(Mistral3InputProcessor):
                  trust_remote_code: bool = False,
                  **kwargs,
     ):
-        tokenizer = self.load_tokenizer(model_path, config=config)
+        tokenizer = self.load_tokenizer(model_path,
+                                        config=config,
+                                        tokenizer=tokenizer)
         super().__init__(model_path=model_path,
                          config=config,
                          tokenizer=tokenizer,
                          trust_remote_code=trust_remote_code,
-                         model_type="mistral_large_3",
+                         model_type=getattr(config, "input_processor_type",
+                                            "mistral3"),
                          **kwargs)

     @staticmethod
     def load_tokenizer(model_path: str,
                        config: PretrainedConfig,
-                       checkpoint_format: str = "mistral_large_3"):
-        if checkpoint_format == "mistral_large_3":
+                       tokenizer: AutoTokenizer | None = None):
+        if getattr(config, "input_processor_type", None) == "mistral_large_3":
             try:
                 return MistralTokenizer.from_pretrained(model_path)

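The constructor now threads any caller-supplied tokenizer into load_tokenizer and derives model_type from the config tag with "mistral3" as the default, instead of hardcoding "mistral_large_3" for every checkpoint, which is what mis-routed Mistral-Small-3.1. A compressed, hypothetical sketch of that call chain (stub functions, not the repository's classes):

# Stub call chain mirroring the constructor fix: the tokenizer argument is
# forwarded into load_tokenizer instead of being re-resolved from scratch,
# and model_type comes from the config tag with a "mistral3" default.
from types import SimpleNamespace

def load_tokenizer(model_path, config, tokenizer=None):
    if getattr(config, "input_processor_type", None) == "mistral_large_3":
        return f"mistral-common:{model_path}"
    return tokenizer if tokenizer is not None else f"hf:{model_path}"

def init_processor(model_path, config, tokenizer=None):
    tokenizer = load_tokenizer(model_path, config=config, tokenizer=tokenizer)
    model_type = getattr(config, "input_processor_type", "mistral3")
    return tokenizer, model_type

# Mistral-Small-style config: no tag, a caller-provided tokenizer is reused.
assert init_processor("m", SimpleNamespace(), "cached") == ("cached", "mistral3")
# Mistral-Large-3-style config: the tag selects the mistral-common path.
cfg = SimpleNamespace(input_processor_type="mistral_large_3")
assert init_processor("m", cfg) == ("mistral-common:m", "mistral_large_3")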
@@ -478,10 +485,8 @@ class MistralCommonInputProcessor(Mistral3InputProcessor):
                     f"Could not load mistral-common tokenizer from {model_path}, falling back to HuggingFace"
                 )

-        tokenizer = AutoTokenizer.from_pretrained(model_path,
-                                                  config=config,
-                                                  use_fast=True,
-                                                  trust_remote_code=True)
+        tokenizer = tokenizer if tokenizer is not None else AutoTokenizer.from_pretrained(
+            model_path, config=config, use_fast=True, trust_remote_code=True)
         return tokenizer

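load_tokenizer keeps its try-then-fall-back shape, but the fallback now reuses a tokenizer the caller already loaded and only constructs a fresh AutoTokenizer when none was passed in. A self-contained sketch of that pattern, with stand-in loader functions in place of MistralTokenizer.from_pretrained and AutoTokenizer.from_pretrained:

# Illustrative try-then-fallback shape of the loader above, with stand-in
# loader callables instead of the real tokenizer classes.
import logging

logger = logging.getLogger(__name__)

def load_with_fallback(model_path, primary, fallback, cached=None):
    try:
        return primary(model_path)
    except Exception:
        logger.warning(
            "Could not load mistral-common tokenizer from %s, "
            "falling back to HuggingFace", model_path)
    # The fix: prefer a tokenizer the caller already loaded over
    # constructing a fresh one.
    return cached if cached is not None else fallback(model_path)

def failing_primary(path):
    raise OSError("no mistral-common files")

assert load_with_fallback("m", failing_primary, lambda p: f"hf:{p}") == "hf:m"
assert load_with_fallback("m", failing_primary, lambda p: f"hf:{p}",
                          cached="cached") == "cached"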
@@ -498,7 +503,7 @@ class MistralCommonInputProcessor(Mistral3InputProcessor):
         placeholder_placement=MultimodalPlaceholderPlacement.BEFORE_TEXT,
     ))
 @register_input_processor(
-    Mistral3InputProcessor,
+    MistralCommonInputProcessor,
     model_type="mistral3",
     placeholder_metadata=MultimodalPlaceholderMetadata(
         placeholder_map={