Mirror of https://github.com/NVIDIA/TensorRT-LLM.git (synced 2026-01-14 06:27:45 +08:00)
[https://nvbugs/5772363][fix] fix bug of Mistral-Small-3.1-24B-Instruct-2503 (#10394)
Signed-off-by: bhsueh <11360707+byshiue@users.noreply.github.com>
commit 0517b62789
parent 8e2065b4d9
@@ -340,5 +340,6 @@ class MistralConfigLoader(BaseConfigLoader):
         from tensorrt_llm._torch.models.modeling_mistral_large3 import Mistral3Gate
 
         model_config.pretrained_config.gate_cls = Mistral3Gate
+        model_config.pretrained_config.input_processor_type = "mistral_large_3"
         model_config._frozen = True
         return model_config
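The loader now stamps input_processor_type onto the pretrained config before freezing it; later hunks read the attribute back with getattr and a default. A minimal standalone sketch of that read-back pattern (SimpleNamespace stands in for the real pretrained config; this is not the TensorRT-LLM API):

# Sketch only: how a loader-set config attribute is consumed downstream.
from types import SimpleNamespace

def select_processor_type(pretrained_config) -> str:
    # Checkpoints that never set the attribute fall back to "mistral3".
    return getattr(pretrained_config, "input_processor_type", "mistral3")

large3_cfg = SimpleNamespace(input_processor_type="mistral_large_3")
legacy_cfg = SimpleNamespace()  # e.g. Mistral-Small-3.1-24B-Instruct-2503

assert select_processor_type(large3_cfg) == "mistral_large_3"
assert select_processor_type(legacy_cfg) == "mistral3"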
@@ -334,6 +334,9 @@ class Mistral3InputProcessor(BaseMultimodalInputProcessor,
             trust_remote_code=trust_remote_code)
         self._model_path = model_path
         if model_type == "mistral_large_3":
+            # For mistral large 3, we add chat template in the model forward, and the
+            # MistralCommonImageProcessor is used to process the input when both text and images are provided.
+            # When the input only contains text, we use the text processor to process the input.
             self._processor = MistralCommonImageProcessor(
                 tokenizer=self._tokenizer, dtype=self.dtype)
             self.text_processor = AutoProcessor.from_pretrained(
@@ -341,11 +344,12 @@ class Mistral3InputProcessor(BaseMultimodalInputProcessor,
                 use_fast=self.use_fast,
                 trust_remote_code=trust_remote_code)
         else:
+            # For other mistral models, we use the AutoProcessor to process the input.
             self._processor = AutoProcessor.from_pretrained(
                 model_path,
                 use_fast=self.use_fast,
                 trust_remote_code=trust_remote_code)
-            self.text_processor = None
+            self.text_processor = self._processor
 
     @property
     def config(self) -> PretrainedConfig:
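This else branch is the heart of the fix: text_processor used to be left as None for non-large-3 checkpoints, so any text-only request on a mistral3 model such as Mistral-Small-3.1-24B-Instruct-2503 hit a None where a processor was expected. A toy reproduction of the routing (illustrative names, not the real classes):

# Sketch only: why text_processor must alias the main processor.
class InputRouter:
    def __init__(self, processor):
        self._processor = processor      # handles text + images
        self.text_processor = processor  # after the fix: an alias, never None

    def process(self, text, images=None):
        if images:
            return self._processor(text, images)
        # Before the fix this line amounted to None(text) -> TypeError.
        return self.text_processor(text)

router = InputRouter(lambda text, images=None: {"text": text, "images": images})
print(router.process("hello"))  # text-only input now works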
@@ -457,19 +461,22 @@ class MistralCommonInputProcessor(Mistral3InputProcessor):
         trust_remote_code: bool = False,
         **kwargs,
     ):
-        tokenizer = self.load_tokenizer(model_path, config=config)
+        tokenizer = self.load_tokenizer(model_path,
+                                        config=config,
+                                        tokenizer=tokenizer)
         super().__init__(model_path=model_path,
                          config=config,
                          tokenizer=tokenizer,
                          trust_remote_code=trust_remote_code,
-                         model_type="mistral_large_3",
+                         model_type=getattr(config, "input_processor_type",
+                                            "mistral3"),
                          **kwargs)
 
     @staticmethod
     def load_tokenizer(model_path: str,
                        config: PretrainedConfig,
-                       checkpoint_format: str = "mistral_large_3"):
-        if checkpoint_format == "mistral_large_3":
+                       tokenizer: AutoTokenizer | None = None):
+        if getattr(config, "input_processor_type", None) == "mistral_large_3":
             try:
                 return MistralTokenizer.from_pretrained(model_path)
 
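load_tokenizer changes in two ways: the hard-coded checkpoint_format parameter gives way to a getattr on the config, so the same code path serves both model families, and a caller-supplied tokenizer is now passed through instead of being rebuilt. A runnable sketch of both behaviors using assumed stand-in return values (not the real signatures):

# Sketch only: dispatch on the config attribute plus tokenizer pass-through.
from types import SimpleNamespace

def load_tokenizer_sketch(model_path, config, tokenizer=None):
    if getattr(config, "input_processor_type", None) == "mistral_large_3":
        return f"mistral-common tokenizer for {model_path}"  # stand-in value
    # Reuse the caller's tokenizer when provided; otherwise build a fallback
    # (stand-in for the AutoTokenizer.from_pretrained call in the next hunk).
    return tokenizer if tokenizer is not None else f"HF tokenizer for {model_path}"

cfg = SimpleNamespace()  # no input_processor_type -> HuggingFace path
assert load_tokenizer_sketch("m", cfg) == "HF tokenizer for m"
assert load_tokenizer_sketch("m", cfg, tokenizer="prebuilt") == "prebuilt"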
@@ -478,10 +485,8 @@ class MistralCommonInputProcessor(Mistral3InputProcessor):
                     f"Could not load mistral-common tokenizer from {model_path}, falling back to HuggingFace"
                 )
 
-        tokenizer = AutoTokenizer.from_pretrained(model_path,
-                                                  config=config,
-                                                  use_fast=True,
-                                                  trust_remote_code=True)
+        tokenizer = tokenizer if tokenizer is not None else AutoTokenizer.from_pretrained(
+            model_path, config=config, use_fast=True, trust_remote_code=True)
         return tokenizer
 
 
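The explicit "is not None" test matters here: a plain "tokenizer or AutoTokenizer.from_pretrained(...)" would discard a caller-supplied tokenizer object that happens to evaluate falsy. A small sketch of the pitfall (the falsy tokenizer is a contrived assumption, not a real tokenizer class):

# Sketch only: why `x if x is not None else y` beats `x or y` here.
class EmptyVocabTokenizer:
    def __len__(self):
        return 0  # falsy, yet still a valid caller-supplied object

supplied = EmptyVocabTokenizer()
chosen_or = supplied or "rebuilt"                            # wrongly rebuilds
chosen_is = supplied if supplied is not None else "rebuilt"  # keeps the caller's
assert chosen_or == "rebuilt"
assert chosen_is is supplied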
@@ -498,7 +503,7 @@ class MistralCommonInputProcessor(Mistral3InputProcessor):
         placeholder_placement=MultimodalPlaceholderPlacement.BEFORE_TEXT,
     ))
 @register_input_processor(
-    Mistral3InputProcessor,
+    MistralCommonInputProcessor,
     model_type="mistral3",
     placeholder_metadata=MultimodalPlaceholderMetadata(
         placeholder_map={
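Re-pointing the mistral3 registration from Mistral3InputProcessor to MistralCommonInputProcessor means both model families now enter through the common processor, which then picks the right tokenizer and processor from the config as shown above. A minimal sketch of the registry pattern (TensorRT-LLM's actual register_input_processor is a decorator with more arguments; names below are illustrative):

# Sketch only: a model_type -> input-processor-class registry.
_INPUT_PROCESSORS = {}

def register_sketch(cls, model_type):
    _INPUT_PROCESSORS[model_type] = cls
    return cls

class Mistral3Sketch: ...
class MistralCommonSketch(Mistral3Sketch): ...

register_sketch(MistralCommonSketch, model_type="mistral3")
assert _INPUT_PROCESSORS["mistral3"] is MistralCommonSketch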