mirror of
https://github.com/NVIDIA/TensorRT-LLM.git
synced 2026-01-13 22:18:36 +08:00
[None][feat] Support VLM part for Mistral Large 3 (#10188)
Signed-off-by: bhsueh <11360707+byshiue@users.noreply.github.com>
This commit is contained in:
parent
7e4cef9def
commit
db3430f589
@ -7,6 +7,21 @@ export mistral_large_3_model_path=<mistral_large_3_model_path>
|
||||
export mistral_large_3_eagle_model_path=<mistral_large_3_eagle_model_path>
|
||||
```
|
||||
|
||||
## Multimodal run
|
||||
|
||||
* Run the Mistral Large V3 by `quickstart_multimodal.py`
|
||||
|
||||
```bash
|
||||
mpirun -n 1 --allow-run-as-root --oversubscribe python3 examples/llm-api/quickstart_multimodal.py \
|
||||
--model_dir ${mistral_large_3_model_path} \
|
||||
--tp_size 4 \
|
||||
--moe_ep_size 4 \
|
||||
--max_tokens 100 \
|
||||
--checkpoint_format mistral \
|
||||
--model_type mistral_large_3 \
|
||||
--moe_backend TRTLLM
|
||||
```
|
||||
|
||||
## LLM-only run
|
||||
|
||||
* Run the Mistral Large V3 by `quickstart_advanced.py`
|
||||
@ -44,9 +59,6 @@ echo "
|
||||
backend: pytorch
|
||||
tensor_parallel_size: 4
|
||||
moe_expert_parallel_size: 4
|
||||
enable_attention_dp: false
|
||||
kv_cache_config:
|
||||
enable_block_reuse: true
|
||||
checkpoint_format: mistral
|
||||
" > serve.yml
|
||||
mpirun -n 1 --allow-run-as-root --oversubscribe python3 -m tensorrt_llm.commands.serve serve \
|
||||
|
||||
304
tensorrt_llm/_torch/models/checkpoints/mistral/tokenizer.py
Normal file
304
tensorrt_llm/_torch/models/checkpoints/mistral/tokenizer.py
Normal file
@ -0,0 +1,304 @@
|
||||
from pathlib import Path
|
||||
from typing import Union
|
||||
|
||||
from mistral_common.tokens.tokenizers.base import SpecialTokenPolicy, SpecialTokens
|
||||
from mistral_common.tokens.tokenizers.sentencepiece import SentencePieceTokenizer
|
||||
from mistral_common.tokens.tokenizers.tekken import Tekkenizer
|
||||
from transformers.tokenization_mistral_common import (
|
||||
MistralCommonTokenizer as TransformersMistralTokenizer,
|
||||
)
|
||||
|
||||
from tensorrt_llm.llmapi.tokenizer import TransformersTokenizer
|
||||
from tensorrt_llm.logger import logger
|
||||
|
||||
|
||||
# Adapted from:
|
||||
# https://github.com/vllm-project/vllm/blob/8e67b2557aae7204c697d7a5c61e00754da465be/vllm/transformers_utils/tokenizers/mistral.py#L166
|
||||
class MistralTokenizer(TransformersTokenizer):
|
||||
def __init__(self, tokenizer: "TransformersMistralTokenizer"):
|
||||
self.transformers_tokenizer = tokenizer
|
||||
self.mistral = tokenizer.tokenizer
|
||||
self.instruct = self.mistral.instruct_tokenizer
|
||||
self.tokenizer = self.instruct.tokenizer
|
||||
|
||||
_mistral_version_str = str(self.tokenizer.version.value)
|
||||
self.version: int = int(_mistral_version_str.split("v")[-1])
|
||||
|
||||
self.is_tekken = isinstance(self.tokenizer, Tekkenizer)
|
||||
self.is_spm = isinstance(self.tokenizer, SentencePieceTokenizer)
|
||||
if not (self.is_tekken or self.is_spm):
|
||||
raise TypeError(f"Unsupported tokenizer: {type(self.tokenizer)}")
|
||||
|
||||
# Reverse order to ensure that the lowest token id is kept.
|
||||
self._vocab_dict = {
|
||||
self.convert_ids_to_tokens([i], skip_special_tokens=False)[0]: i
|
||||
for i in range(self.transformers_tokenizer.vocab_size - 1, -1, -1)
|
||||
}
|
||||
# Sort the dict for convenience
|
||||
self._vocab_dict = dict(sorted(self._vocab_dict.items(), key=lambda x: x[1]))
|
||||
|
||||
# Cache special tokens for faster access.
|
||||
self._special_token_ids = self._get_special_token_ids()
|
||||
self._special_token_ids_set = set(self._special_token_ids)
|
||||
self._special_tokens = self._get_special_tokens(self._special_token_ids)
|
||||
self._special_tokens_set = set(self._special_tokens)
|
||||
|
||||
# Vocab sorted by token id.
|
||||
self._vocab = self.tokenizer._vocab
|
||||
self._max_token_id = self.transformers_tokenizer.vocab_size - 1
|
||||
|
||||
self._all_special_tokens_set = set(self.all_special_tokens)
|
||||
|
||||
def _get_special_tokens(self, all_special_ids: list[int]) -> list[str]:
|
||||
from mistral_common.tokens.tokenizers.base import SpecialTokenPolicy
|
||||
|
||||
return [
|
||||
self.tokenizer.decode([i], special_token_policy=SpecialTokenPolicy.KEEP)
|
||||
for i in all_special_ids
|
||||
]
|
||||
|
||||
# the following attributes are set to fit vLLM's design and are used
|
||||
# by the structured output backends.
|
||||
@property
|
||||
def all_special_tokens_extended(self) -> list[str]:
|
||||
return self.all_special_tokens
|
||||
|
||||
@property
|
||||
def all_special_tokens(self) -> list[str]:
|
||||
return self._special_tokens
|
||||
|
||||
@property
|
||||
def all_special_ids(self) -> list[int]:
|
||||
return self._special_token_ids
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, pretrained_model_dir: str, **kwargs):
|
||||
if Path(pretrained_model_dir).is_file():
|
||||
tokenizer = TransformersMistralTokenizer(tokenizer_path=pretrained_model_dir)
|
||||
else:
|
||||
tokenizer = TransformersMistralTokenizer.from_pretrained(pretrained_model_dir)
|
||||
return cls(tokenizer)
|
||||
|
||||
def _get_special_token_ids(self) -> list[int]:
|
||||
from mistral_common.tokens.tokenizers.sentencepiece import SentencePieceTokenizer
|
||||
from mistral_common.tokens.tokenizers.tekken import Tekkenizer
|
||||
|
||||
if self.is_tekken:
|
||||
assert isinstance(self.tokenizer, Tekkenizer), type(self.tokenizer)
|
||||
special_ids = {t["rank"] for t in self.tokenizer._all_special_tokens}
|
||||
elif self.is_spm:
|
||||
assert isinstance(self.tokenizer, SentencePieceTokenizer), type(self.tokenizer)
|
||||
special_ids = self.tokenizer._control_tokens
|
||||
else:
|
||||
raise ValueError(f"Unknown tokenizer type: {type(self.tokenizer)}")
|
||||
return sorted(special_ids)
|
||||
|
||||
@property
|
||||
def bos_token_id(self) -> int:
|
||||
return self.tokenizer.bos_id
|
||||
|
||||
@property
|
||||
def eos_token_id(self) -> int:
|
||||
return self.tokenizer.eos_id
|
||||
|
||||
@property
|
||||
def sep_token(self) -> str:
|
||||
raise NotImplementedError()
|
||||
|
||||
@property
|
||||
def pad_token(self) -> str:
|
||||
return self.transformers_tokenizer.pad_token
|
||||
|
||||
@property
|
||||
def pad_token_id(self) -> int:
|
||||
return self.transformers_tokenizer.pad_token_id
|
||||
|
||||
def __call__(self, text: str, *args, **kwargs) -> any:
|
||||
return self.transformers_tokenizer(text=text, *args, **kwargs)
|
||||
|
||||
@property
|
||||
def name_or_path(self) -> str:
|
||||
return self.transformers_tokenizer.name_or_path
|
||||
|
||||
def batch_encode_plus(self, texts: list[str], *args, **kwargs) -> dict:
|
||||
raise NotImplementedError
|
||||
|
||||
def get_chat_template(
|
||||
self, chat_template: str | None = None, tools: list[dict] | None = None
|
||||
) -> str:
|
||||
raise NotImplementedError
|
||||
|
||||
def clean_up_tokenization(self, out_string: str) -> str:
|
||||
raise NotImplementedError
|
||||
|
||||
@property
|
||||
def is_fast(self) -> bool:
|
||||
return True
|
||||
|
||||
def get_added_vocab(self) -> dict[str, int]:
|
||||
# Mistral tokenizers have no added vocabulary
|
||||
return {}
|
||||
|
||||
def _tekken_token_to_id(self, tokenizer: "Tekkenizer", t: str | bytes) -> int:
|
||||
assert isinstance(tokenizer, Tekkenizer), type(tokenizer)
|
||||
|
||||
t_bytes = t.encode("utf-8") if not isinstance(t, bytes) else t
|
||||
shift = tokenizer.num_special_tokens
|
||||
try:
|
||||
return shift + tokenizer._tekken_token2id_nospecial[t_bytes]
|
||||
except KeyError:
|
||||
t_str = t_bytes.decode("utf-8")
|
||||
if t_str in tokenizer._special_tokens_reverse_vocab:
|
||||
return tokenizer._special_tokens_reverse_vocab[t_str]
|
||||
logger.warning("Failed to convert token %s to id, replacing with <unk>", t_bytes)
|
||||
return tokenizer.unk_id
|
||||
|
||||
def _is_special_token_id(self, token_id: int) -> bool:
|
||||
return token_id in self._special_token_ids_set
|
||||
|
||||
def convert_tokens_to_string(
|
||||
self,
|
||||
tokens: list[str],
|
||||
skip_special_tokens: bool = False,
|
||||
spaces_between_special_tokens: bool = True,
|
||||
) -> str:
|
||||
to_decode_special_tokens = {SpecialTokens.tool_calls}
|
||||
if self.is_tekken:
|
||||
assert isinstance(self.tokenizer, Tekkenizer), type(self.tokenizer)
|
||||
tokens = [
|
||||
t
|
||||
for t in tokens
|
||||
if (t in to_decode_special_tokens or t not in self._special_tokens_set)
|
||||
]
|
||||
|
||||
if any(isinstance(t, bytes) for t in tokens):
|
||||
# we need to encode and decode all tokens again
|
||||
ids = [self._tekken_token_to_id(self.tokenizer, t) for t in tokens]
|
||||
# We filtered unwanted special tokens before
|
||||
# so we can decode the rest.
|
||||
decoded = self.tokenizer.decode(ids, SpecialTokenPolicy.KEEP)
|
||||
else:
|
||||
decoded = "".join(tokens)
|
||||
else:
|
||||
# make sure certain special tokens like Tool calls are
|
||||
# not decoded
|
||||
assert isinstance(self.tokenizer, SentencePieceTokenizer), type(self.tokenizer)
|
||||
|
||||
regular_tokens: list[str] = []
|
||||
decoded_list: list[str] = []
|
||||
decoded = ""
|
||||
|
||||
for token in tokens:
|
||||
if token in to_decode_special_tokens:
|
||||
if regular_tokens:
|
||||
decoded_list.append(
|
||||
self.tokenizer.decode(regular_tokens, SpecialTokenPolicy.IGNORE)
|
||||
)
|
||||
regular_tokens = []
|
||||
decoded_list.append(token)
|
||||
else:
|
||||
regular_tokens.append(token)
|
||||
|
||||
if regular_tokens:
|
||||
decoded_list.append(
|
||||
self.tokenizer.decode(regular_tokens, SpecialTokenPolicy.IGNORE)
|
||||
)
|
||||
decoded = "".join(decoded_list)
|
||||
|
||||
return decoded
|
||||
|
||||
def encode(
|
||||
self,
|
||||
text: str,
|
||||
truncation: bool | None = None,
|
||||
max_length: int | None = None,
|
||||
add_special_tokens: bool | None = None,
|
||||
) -> list[int]:
|
||||
if add_special_tokens is not None:
|
||||
return self.transformers_tokenizer.encode(
|
||||
text,
|
||||
truncation=truncation,
|
||||
max_length=max_length,
|
||||
add_special_tokens=add_special_tokens,
|
||||
)
|
||||
else:
|
||||
encoded = self.tokenizer.encode(text, bos=True, eos=False)
|
||||
|
||||
if truncation is not False and max_length is not None:
|
||||
return encoded[:max_length]
|
||||
else:
|
||||
return encoded
|
||||
|
||||
def decode(
|
||||
self, token_ids: list[int] | int, skip_special_tokens: bool = True, *args, **kwargs
|
||||
) -> str:
|
||||
return self.transformers_tokenizer.decode(
|
||||
token_ids, skip_special_tokens=skip_special_tokens
|
||||
)
|
||||
|
||||
def convert_ids_to_tokens(
|
||||
self,
|
||||
ids: list[int],
|
||||
skip_special_tokens: bool = True,
|
||||
) -> list[str]:
|
||||
from mistral_common.tokens.tokenizers.base import SpecialTokenPolicy, SpecialTokens
|
||||
from mistral_common.tokens.tokenizers.instruct import InstructTokenizerV13
|
||||
|
||||
if not skip_special_tokens:
|
||||
return [self.tokenizer.id_to_piece(token_id) for token_id in ids]
|
||||
|
||||
non_skip_special_tokens_ids = {
|
||||
self.tokenizer.get_control_token(SpecialTokens.tool_calls),
|
||||
}
|
||||
if isinstance(self.instruct, InstructTokenizerV13):
|
||||
if self.instruct.BEGIN_THINK:
|
||||
non_skip_special_tokens_ids.add(self.instruct.BEGIN_THINK)
|
||||
if self.instruct.END_THINK:
|
||||
non_skip_special_tokens_ids.add(self.instruct.END_THINK)
|
||||
|
||||
ids_kept = [
|
||||
i for i in ids if i in non_skip_special_tokens_ids or not self._is_special_token_id(i)
|
||||
]
|
||||
|
||||
# We filtered unwanted special tokens so we can decode the rest.
|
||||
tokens = [self.tokenizer.id_to_piece(token_id) for token_id in ids_kept]
|
||||
|
||||
if any("<EFBFBD>" in t for t in tokens) and self.is_tekken:
|
||||
# if a decoded token contains the replacement character, then the
|
||||
# token has an incomplete UTF-8 character so we must use bytes
|
||||
# See: https://github.com/vllm-project/vllm/pull/8640
|
||||
# https://github.com/vllm-project/vllm/pull/9625
|
||||
# if underlying tokenizer is sentencepiece, we just add "<22>".
|
||||
# We filtered unwanted special tokens so we can decode the rest.
|
||||
tokens = [
|
||||
self.tokenizer.id_to_byte_piece(token_id, SpecialTokenPolicy.KEEP)
|
||||
if token_id not in self._special_token_ids_set
|
||||
else self.tokenizer.decode([token_id], SpecialTokenPolicy.KEEP)
|
||||
for token_id in ids_kept
|
||||
]
|
||||
|
||||
return tokens
|
||||
|
||||
@property
|
||||
def vocab_size(self) -> int:
|
||||
return len(self._vocab_dict)
|
||||
|
||||
@property
|
||||
def clean_up_tokenization_spaces(self):
|
||||
return False
|
||||
|
||||
def hf_decode_incrementally(
|
||||
self,
|
||||
token_ids: list[int],
|
||||
prev_text: str | None = None,
|
||||
states: dict | None = None,
|
||||
*,
|
||||
skip_special_tokens: bool = False,
|
||||
clean_up_tokenization_spaces: bool | None = None,
|
||||
) -> tuple[str, dict]:
|
||||
raise NotImplementedError
|
||||
|
||||
def apply_chat_template(
|
||||
self, conversation: Union[list[dict[str, str]], list[list[dict[str, str]]]], *args, **kwargs
|
||||
) -> Union[str, list[int], list[str], list[list[int]]]:
|
||||
raise NotImplementedError
|
||||
@ -4,6 +4,8 @@ from typing import Any, Dict, List, Tuple
|
||||
|
||||
import torch
|
||||
import torchvision
|
||||
from mistral_common.tokens.tokenizers.multimodal import ImageEncoder
|
||||
from PIL import Image
|
||||
from torch import nn
|
||||
from transformers import (AutoProcessor, AutoTokenizer, Mistral3Config,
|
||||
MistralConfig, PretrainedConfig, PreTrainedModel)
|
||||
@ -14,6 +16,8 @@ from tensorrt_llm._torch.attention_backend.interface import (
|
||||
PositionalEmbeddingParams, RopeParams)
|
||||
from tensorrt_llm._torch.model_config import ModelConfig
|
||||
from tensorrt_llm._torch.models import modeling_pixtral
|
||||
from tensorrt_llm._torch.models.checkpoints.mistral.tokenizer import \
|
||||
MistralTokenizer
|
||||
from tensorrt_llm._torch.models.checkpoints.mistral.weight_mapper import \
|
||||
MistralWeightMapper
|
||||
from tensorrt_llm._torch.models.modeling_mistral_large3 import (
|
||||
@ -214,6 +218,96 @@ class MistralForCausalLM(DecoderModelForCausalLM[MistralModel, MistralConfig]):
|
||||
)
|
||||
|
||||
|
||||
class MistralCommonImageProcessor:
|
||||
|
||||
def __init__(self, tokenizer: MistralTokenizer, dtype) -> None:
|
||||
super().__init__()
|
||||
self.tokenizer = tokenizer
|
||||
self.dtype = dtype
|
||||
|
||||
@property
|
||||
def image_processor(self) -> ImageEncoder:
|
||||
image_encoder = self.tokenizer.instruct.mm_encoder
|
||||
assert isinstance(image_encoder, ImageEncoder)
|
||||
return image_encoder
|
||||
|
||||
@property
|
||||
def image_break_id(self) -> int:
|
||||
return self.image_processor.special_ids.img_break
|
||||
|
||||
@property
|
||||
def image_break_token_id(self) -> int:
|
||||
return self.image_break_id
|
||||
|
||||
@property
|
||||
def image_token_id(self) -> int:
|
||||
return self.image_processor.special_ids.img
|
||||
|
||||
@property
|
||||
def image_end_id(self) -> int:
|
||||
return self.image_processor.special_ids.img_end
|
||||
|
||||
@property
|
||||
def image_end_token_id(self):
|
||||
return self.image_end_id
|
||||
|
||||
@property
|
||||
def image_size(self) -> int:
|
||||
return self.image_processor.mm_config.max_image_size
|
||||
|
||||
@property
|
||||
def patch_size(self) -> int:
|
||||
return self.image_processor.mm_config.image_patch_size
|
||||
|
||||
def _get_num_multimodal_tokens(self, image_sizes):
|
||||
return {
|
||||
"num_image_tokens":
|
||||
[self.get_num_tokens_per_image(size) for size in image_sizes]
|
||||
}
|
||||
|
||||
def get_num_tokens_per_image(self, image_sizes):
|
||||
# FIXME avoid double loading with custom loader
|
||||
h, w = image_sizes
|
||||
ncols, nrows = self.image_processor._image_to_num_tokens(
|
||||
Image.new("RGB", (w, h)))
|
||||
return ncols * nrows + nrows
|
||||
|
||||
def __call__(self, text, images, media, **kwargs):
|
||||
assert media is not None
|
||||
if isinstance(media, str):
|
||||
media = [media]
|
||||
|
||||
mm_items = [{"type": "image_url", "image_url": url} for url in media]
|
||||
|
||||
logger.debug(f"text: {text}")
|
||||
|
||||
conversation = [{
|
||||
"role": "user",
|
||||
"content": [{
|
||||
"type": "text",
|
||||
"text": text
|
||||
}, *mm_items]
|
||||
}]
|
||||
|
||||
encoded = self.tokenizer.transformers_tokenizer.apply_chat_template(
|
||||
conversation, tokenize=True, return_dict=True, return_tensors='pt')
|
||||
|
||||
logger.debug(
|
||||
f"encoded.pixel_values.shape: {encoded.pixel_values.shape}, encoded.input_ids: {encoded.input_ids[0][-20:]}"
|
||||
)
|
||||
logger.debug(
|
||||
f"encoded.input_ids list: {self.tokenizer.transformers_tokenizer.apply_chat_template(conversation)}"
|
||||
)
|
||||
|
||||
processed = {
|
||||
"input_ids": encoded.input_ids,
|
||||
"pixel_values": encoded.pixel_values.to(self.dtype),
|
||||
"attention_mask": encoded.attention_mask,
|
||||
"image_sizes": torch.tensor([encoded.pixel_values.shape[2:]])
|
||||
}
|
||||
return processed
|
||||
|
||||
|
||||
class Mistral3InputProcessor(BaseMultimodalInputProcessor,
|
||||
BaseMultimodalDummyInputsBuilder):
|
||||
|
||||
@ -223,6 +317,7 @@ class Mistral3InputProcessor(BaseMultimodalInputProcessor,
|
||||
config: PretrainedConfig,
|
||||
tokenizer: AutoTokenizer | None,
|
||||
trust_remote_code: bool = False,
|
||||
model_type: str = "mistral3",
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(model_path=model_path,
|
||||
@ -233,12 +328,24 @@ class Mistral3InputProcessor(BaseMultimodalInputProcessor,
|
||||
self._config = config
|
||||
self._dtype = self._config.torch_dtype
|
||||
self._tokenizer = tokenizer if tokenizer is not None else AutoTokenizer.from_pretrained(
|
||||
model_path)
|
||||
self._model_path = model_path
|
||||
self._processor = AutoProcessor.from_pretrained(
|
||||
model_path,
|
||||
config=config,
|
||||
use_fast=self.use_fast,
|
||||
trust_remote_code=trust_remote_code)
|
||||
self._model_path = model_path
|
||||
if model_type == "mistral_large_3":
|
||||
self._processor = MistralCommonImageProcessor(
|
||||
tokenizer=self._tokenizer, dtype=self.dtype)
|
||||
self.text_processor = AutoProcessor.from_pretrained(
|
||||
model_path,
|
||||
use_fast=self.use_fast,
|
||||
trust_remote_code=trust_remote_code)
|
||||
else:
|
||||
self._processor = AutoProcessor.from_pretrained(
|
||||
model_path,
|
||||
use_fast=self.use_fast,
|
||||
trust_remote_code=trust_remote_code)
|
||||
self.text_processor = None
|
||||
|
||||
@property
|
||||
def config(self) -> PretrainedConfig:
|
||||
@ -273,12 +380,20 @@ class Mistral3InputProcessor(BaseMultimodalInputProcessor,
|
||||
# format is "pt" (pytorch tensors), but not for "pil" (PIL images).
|
||||
do_rescale = False
|
||||
|
||||
processed = self.processor(
|
||||
text=inputs["prompt"],
|
||||
images=images,
|
||||
do_rescale=do_rescale,
|
||||
**mm_processor_kwargs,
|
||||
)
|
||||
if mm_processor_kwargs:
|
||||
# Currently, we only support image modality in MistralCommonImageProcessor.
|
||||
processed = self.processor(
|
||||
text=inputs["prompt"],
|
||||
images=images,
|
||||
do_rescale=do_rescale,
|
||||
**mm_processor_kwargs,
|
||||
)
|
||||
else:
|
||||
processed = self.text_processor(
|
||||
text=inputs["prompt"],
|
||||
images=images,
|
||||
do_rescale=do_rescale,
|
||||
)
|
||||
input_ids = processed.pop("input_ids").tolist()[0]
|
||||
# Remaining in `processed`:
|
||||
# * "attention_mask": [B, num_input_tokens]
|
||||
@ -332,8 +447,56 @@ class Mistral3InputProcessor(BaseMultimodalInputProcessor,
|
||||
])
|
||||
|
||||
|
||||
class MistralCommonInputProcessor(Mistral3InputProcessor):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model_path: str,
|
||||
config: PretrainedConfig,
|
||||
tokenizer: AutoTokenizer,
|
||||
trust_remote_code: bool = False,
|
||||
**kwargs,
|
||||
):
|
||||
tokenizer = self.load_tokenizer(model_path, config=config)
|
||||
super().__init__(model_path=model_path,
|
||||
config=config,
|
||||
tokenizer=tokenizer,
|
||||
trust_remote_code=trust_remote_code,
|
||||
model_type="mistral_large_3",
|
||||
**kwargs)
|
||||
|
||||
@staticmethod
|
||||
def load_tokenizer(model_path: str,
|
||||
config: PretrainedConfig,
|
||||
checkpoint_format: str = "mistral_large_3"):
|
||||
if checkpoint_format == "mistral_large_3":
|
||||
try:
|
||||
return MistralTokenizer.from_pretrained(model_path)
|
||||
|
||||
except ValueError:
|
||||
logger.info(
|
||||
f"Could not load mistral-common tokenizer from {model_path}, falling back to HuggingFace"
|
||||
)
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_path,
|
||||
config=config,
|
||||
use_fast=True,
|
||||
trust_remote_code=True)
|
||||
return tokenizer
|
||||
|
||||
|
||||
@register_auto_model("Mistral3ForConditionalGeneration")
|
||||
@register_auto_model("PixtralForConditionalGeneration")
|
||||
@register_input_processor(
|
||||
MistralCommonInputProcessor,
|
||||
model_type="mistral_large_3",
|
||||
placeholder_metadata=MultimodalPlaceholderMetadata(
|
||||
placeholder_map={
|
||||
# NOTE: mistral-common uses the tokenizer to set placeholders, this will be ignored
|
||||
"image": "[IMG]",
|
||||
},
|
||||
placeholder_placement=MultimodalPlaceholderPlacement.BEFORE_TEXT,
|
||||
))
|
||||
@register_input_processor(
|
||||
Mistral3InputProcessor,
|
||||
model_type="mistral3",
|
||||
@ -410,7 +573,6 @@ class Mistral3VLM(PreTrainedModel):
|
||||
self._multi_modal_projector = Mistral3MultiModalProjector(
|
||||
model_config).eval().to(self._device)
|
||||
self._post_config()
|
||||
self.is_loaded = True
|
||||
|
||||
# This is necessary because the executor looks at
|
||||
# `model.model_config.pretrained_config.vocab_size`.
|
||||
@ -766,11 +928,3 @@ class Mistral3MultiModalProjector(torch.nn.Module):
|
||||
|
||||
def load_weights(self, weights):
|
||||
_load_weights_impl(self, weights)
|
||||
|
||||
|
||||
def _filter_weights(weights: Dict[str, torch.Tensor],
|
||||
prefix: str) -> Dict[str, torch.Tensor]:
|
||||
return {
|
||||
name[len(prefix):]: weight
|
||||
for name, weight in weights.items() if name.startswith(prefix)
|
||||
}
|
||||
|
||||
@ -4,8 +4,6 @@ import torch
|
||||
from torch import nn
|
||||
from transformers import LlamaConfig, PretrainedConfig
|
||||
|
||||
from tensorrt_llm.logger import logger
|
||||
|
||||
from ...functional import PositionEmbeddingType
|
||||
from ..attention_backend import AttentionMetadata
|
||||
from ..attention_backend.interface import PositionalEmbeddingParams, RopeParams
|
||||
@ -754,19 +752,6 @@ class SpecDecOneEngineForCausalLM(DecoderModelForCausalLM[TModel, TConfig],
|
||||
skip_create_weights_in_init=True,
|
||||
)
|
||||
self.draft_config.extra_attrs = model_config.extra_attrs
|
||||
from tensorrt_llm._utils import get_sm_version
|
||||
from tensorrt_llm.quantization.mode import QuantAlgo
|
||||
if get_sm_version(
|
||||
) == 100 and self.draft_config.quant_config.quant_algo == QuantAlgo.FP8_BLOCK_SCALES:
|
||||
# FIXME There is a known issue on TRTLLM moe backend + FP8 blockwise
|
||||
logger.warning(
|
||||
"Switching moe_backend of draft model to DEEPGEMM for FP8_BLOCK_SCALES quantization on SM100"
|
||||
"This is a workaround for the known issue on TRTLLM moe backend + FP8 blockwise"
|
||||
)
|
||||
self.draft_config._frozen = False
|
||||
self.draft_config.moe_backend = "DEEPGEMM"
|
||||
self.draft_config._frozen = True
|
||||
|
||||
elif spec_config.eagle3_model_arch == "llama3":
|
||||
self.draft_config = ModelConfig.from_pretrained(
|
||||
model_config.spec_config.speculative_model_dir,
|
||||
|
||||
@ -374,8 +374,10 @@ NOTE:
|
||||
placeholder for the model needs to be added in retrieve_multimodal_placeholder().
|
||||
"""
|
||||
|
||||
HF_CHAT_TEMPLATE_EXCEPTIONS = ["llava_llama"]
|
||||
PLACEHOLDER_EXCEPTIONS = ["llava_next", "NemotronH_Nano_VL_V2"]
|
||||
HF_CHAT_TEMPLATE_EXCEPTIONS = ["llava_llama", "mistral_large_3"]
|
||||
PLACEHOLDER_EXCEPTIONS = [
|
||||
"llava_next", "NemotronH_Nano_VL_V2", "mistral_large_3"
|
||||
]
|
||||
|
||||
|
||||
# Helpers to always get the latest supported multimodal model types from the registry
|
||||
@ -771,6 +773,13 @@ def default_multimodal_input_loader(
|
||||
add_generation_prompt=True,
|
||||
mm_placeholder_counts=[mm_placeholder_counts])
|
||||
input = {"prompt": prompt}
|
||||
|
||||
# When the tokenizer is a MistralTokenizer, we need to keep the source media to handle in processor later.
|
||||
from tensorrt_llm._torch.models.checkpoints.mistral.tokenizer import \
|
||||
MistralTokenizer
|
||||
if isinstance(tokenizer, MistralTokenizer):
|
||||
input["mm_processor_kwargs"] = {"media": media}
|
||||
|
||||
if mm_placeholder_counts:
|
||||
if mm_embeddings is not None:
|
||||
input[
|
||||
|
||||
@ -21,3 +21,5 @@ microsoft/Phi-4-multimodal-instruct:
|
||||
- accuracy: 53.67
|
||||
Qwen/Qwen3-VL-30B-A3B-Instruct:
|
||||
- accuracy: 55.33
|
||||
mistral/Mistral-Large-3-675B:
|
||||
- accuracy: 60.00
|
||||
|
||||
@ -1,9 +1,9 @@
|
||||
import pytest
|
||||
|
||||
from tensorrt_llm import LLM
|
||||
from tensorrt_llm.llmapi import KvCacheConfig, SamplingParams
|
||||
from tensorrt_llm.llmapi import CudaGraphConfig, KvCacheConfig, MoeConfig, SamplingParams
|
||||
|
||||
from ..conftest import llm_models_root
|
||||
from ..conftest import llm_models_root, skip_pre_blackwell
|
||||
from .accuracy_core import MMMU, LlmapiAccuracyTestHarness
|
||||
|
||||
|
||||
@ -263,3 +263,52 @@ class TestQwen3VL_MOE(LlmapiAccuracyTestHarness):
|
||||
) as llm:
|
||||
task = MMMU(self.MODEL_NAME)
|
||||
task.evaluate(llm, sampling_params=self.sampling_params)
|
||||
|
||||
|
||||
class TestMistralLarge3_675B(LlmapiAccuracyTestHarness):
|
||||
MODEL_NAME = "mistral/Mistral-Large-3-675B"
|
||||
MODEL_PATH = (
|
||||
f"{llm_models_root()}/Mistral-Large-3-675B/Mistral-Large-3-675B-Instruct-2512-NVFP4/"
|
||||
)
|
||||
MAX_NUM_TOKENS = 16384
|
||||
|
||||
sampling_params = SamplingParams(
|
||||
max_tokens=MAX_NUM_TOKENS, truncate_prompt_tokens=MMMU.MAX_INPUT_LEN, stop="<|endoftext|>"
|
||||
)
|
||||
|
||||
@skip_pre_blackwell
|
||||
@pytest.mark.skip_less_mpi_world_size(4)
|
||||
@pytest.mark.skip_less_device_memory(183000)
|
||||
@pytest.mark.parametrize(
|
||||
"tp_size,pp_size,ep_size,attention_dp,cuda_graph,overlap_scheduler,moe_backend",
|
||||
[
|
||||
(4, 1, 4, False, True, True, "TRTLLM"),
|
||||
],
|
||||
ids=[
|
||||
"latency_moe_trtllm",
|
||||
],
|
||||
)
|
||||
def test_nvfp4_4gpus(
|
||||
self, tp_size, pp_size, ep_size, attention_dp, cuda_graph, overlap_scheduler, moe_backend
|
||||
):
|
||||
pytorch_config = dict(
|
||||
disable_overlap_scheduler=not overlap_scheduler,
|
||||
cuda_graph_config=CudaGraphConfig() if cuda_graph else None,
|
||||
moe_config=MoeConfig(backend=moe_backend),
|
||||
)
|
||||
|
||||
kv_cache_config = KvCacheConfig(free_gpu_memory_fraction=0.4)
|
||||
|
||||
with LLM(
|
||||
self.MODEL_PATH,
|
||||
max_num_tokens=self.MAX_NUM_TOKENS,
|
||||
checkpoint_format="mistral",
|
||||
tensor_parallel_size=tp_size,
|
||||
pipeline_parallel_size=pp_size,
|
||||
moe_expert_parallel_size=ep_size,
|
||||
**pytorch_config,
|
||||
enable_attention_dp=attention_dp,
|
||||
kv_cache_config=kv_cache_config,
|
||||
) as llm:
|
||||
task = MMMU(self.MODEL_NAME)
|
||||
task.evaluate(llm, sampling_params=self.sampling_params, model_type="mistral_large_3")
|
||||
|
||||
@ -71,6 +71,7 @@ l0_gb200_multi_gpus:
|
||||
- accuracy/test_llm_api_pytorch.py::TestQwen3NextInstruct::test_nvfp4[tp4ep4-trtllm]
|
||||
- accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[nccl-cudagraph:none]
|
||||
- accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[fifo-cudagraph:none]
|
||||
- accuracy/test_llm_api_pytorch.py::TestMistralLarge3_675B::test_nvfp4_4gpus[latency_moe_trtllm_eagle] TIMEOUT (90)
|
||||
- condition:
|
||||
ranges:
|
||||
system_gpu_count:
|
||||
@ -107,4 +108,4 @@ l0_gb200_multi_gpus:
|
||||
- accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_4gpus_online_eplb[enable_configurable_moe-fp8]
|
||||
- accuracy/test_llm_api_pytorch.py::TestQwen3_235B_A22B::test_nvfp4_4gpus[latency_moe_trtllm_eagle3] TIMEOUT (90)
|
||||
- accuracy/test_llm_api_pytorch.py::TestMistralLarge3_675B::test_nvfp4_4gpus[latency_moe_trtllm] TIMEOUT (90)
|
||||
- accuracy/test_llm_api_pytorch.py::TestMistralLarge3_675B::test_nvfp4_4gpus[latency_moe_trtllm_eagle] TIMEOUT (90)
|
||||
- accuracy/test_llm_api_pytorch_multimodal.py::TestMistralLarge3_675B::test_nvfp4_4gpus[latency_moe_trtllm] TIMEOUT (90)
|
||||
|
||||
@ -517,3 +517,4 @@ test_e2e.py::test_eagle3_output_consistency_4gpus[Qwen3/saved_models_Qwen3-235B-
|
||||
examples/test_mistral.py::test_mistral_with_bf16_lora_torch[mistral-7b-v0.1] SKIP (https://nvbugs/5769855)
|
||||
disaggregated/test_disaggregated.py::test_disaggregated_benchmark_on_diff_backends[llama-3.1-8b-instruct-hf-fp8] SKIP (https://nvbugs/5769890)
|
||||
disaggregated/test_disaggregated.py::test_disaggregated_benchmark_on_diff_backends[DeepSeek-V3-Lite-bf16] SKIP (https://nvbugs/5769890)
|
||||
accuracy/test_llm_api_pytorch_multimodal.py::TestMistralLarge3_675B::test_nvfp4_4gpus[latency_moe_trtllm]
|
||||
|
||||
Loading…
Reference in New Issue
Block a user