[None][feat] Support VLM part for Mistral Large 3 (#10188)

Signed-off-by: bhsueh <11360707+byshiue@users.noreply.github.com>
bhsueh_NV 2025-12-26 00:20:58 +08:00 committed by GitHub
parent 7e4cef9def
commit db3430f589
GPG Key ID: B5690EEEBB952194
9 changed files with 558 additions and 41 deletions

View File

@@ -7,6 +7,21 @@ export mistral_large_3_model_path=<mistral_large_3_model_path>
export mistral_large_3_eagle_model_path=<mistral_large_3_eagle_model_path>
```
## Multimodal run
* Run Mistral Large 3 with `quickstart_multimodal.py`:
```bash
mpirun -n 1 --allow-run-as-root --oversubscribe python3 examples/llm-api/quickstart_multimodal.py \
--model_dir ${mistral_large_3_model_path} \
--tp_size 4 \
--moe_ep_size 4 \
--max_tokens 100 \
--checkpoint_format mistral \
--model_type mistral_large_3 \
--moe_backend TRTLLM
```
## LLM-only run
* Run Mistral Large 3 with `quickstart_advanced.py`:
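The exact command lies outside this hunk. As an illustration only, here is a hedged sketch that mirrors the multimodal invocation above; every flag is an assumption carried over from that command and may differ from the actual README:
```bash
mpirun -n 1 --allow-run-as-root --oversubscribe python3 examples/llm-api/quickstart_advanced.py \
    --model_dir ${mistral_large_3_model_path} \
    --tp_size 4 \
    --moe_ep_size 4 \
    --max_tokens 100 \
    --checkpoint_format mistral \
    --moe_backend TRTLLM
```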
@@ -44,9 +59,6 @@ echo "
backend: pytorch
tensor_parallel_size: 4
moe_expert_parallel_size: 4
enable_attention_dp: false
kv_cache_config:
enable_block_reuse: true
checkpoint_format: mistral
" > serve.yml
mpirun -n 1 --allow-run-as-root --oversubscribe python3 -m tensorrt_llm.commands.serve serve \

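The remainder of the `trtllm-serve` command is truncated in this hunk. Once the server is up, it exposes an OpenAI-compatible API; below is a minimal request sketch, assuming the default port 8000 and that the model is addressed by its checkpoint path (both are assumptions, not part of this change):
```bash
curl -s http://localhost:8000/v1/chat/completions \
    -H "Content-Type: application/json" \
    -d "{\"model\": \"${mistral_large_3_model_path}\", \"messages\": [{\"role\": \"user\", \"content\": \"Hello\"}], \"max_tokens\": 32}"
```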
View File

@@ -0,0 +1,304 @@
from pathlib import Path
from typing import Any, Union
from mistral_common.tokens.tokenizers.base import SpecialTokenPolicy, SpecialTokens
from mistral_common.tokens.tokenizers.sentencepiece import SentencePieceTokenizer
from mistral_common.tokens.tokenizers.tekken import Tekkenizer
from transformers.tokenization_mistral_common import (
MistralCommonTokenizer as TransformersMistralTokenizer,
)
from tensorrt_llm.llmapi.tokenizer import TransformersTokenizer
from tensorrt_llm.logger import logger
# Adapted from:
# https://github.com/vllm-project/vllm/blob/8e67b2557aae7204c697d7a5c61e00754da465be/vllm/transformers_utils/tokenizers/mistral.py#L166
class MistralTokenizer(TransformersTokenizer):
def __init__(self, tokenizer: "TransformersMistralTokenizer"):
self.transformers_tokenizer = tokenizer
self.mistral = tokenizer.tokenizer
self.instruct = self.mistral.instruct_tokenizer
self.tokenizer = self.instruct.tokenizer
_mistral_version_str = str(self.tokenizer.version.value)
self.version: int = int(_mistral_version_str.split("v")[-1])
self.is_tekken = isinstance(self.tokenizer, Tekkenizer)
self.is_spm = isinstance(self.tokenizer, SentencePieceTokenizer)
if not (self.is_tekken or self.is_spm):
raise TypeError(f"Unsupported tokenizer: {type(self.tokenizer)}")
# Reverse order to ensure that the lowest token id is kept.
self._vocab_dict = {
self.convert_ids_to_tokens([i], skip_special_tokens=False)[0]: i
for i in range(self.transformers_tokenizer.vocab_size - 1, -1, -1)
}
# Sort the dict for convenience
self._vocab_dict = dict(sorted(self._vocab_dict.items(), key=lambda x: x[1]))
# Cache special tokens for faster access.
self._special_token_ids = self._get_special_token_ids()
self._special_token_ids_set = set(self._special_token_ids)
self._special_tokens = self._get_special_tokens(self._special_token_ids)
self._special_tokens_set = set(self._special_tokens)
# Vocab sorted by token id.
self._vocab = self.tokenizer._vocab
self._max_token_id = self.transformers_tokenizer.vocab_size - 1
self._all_special_tokens_set = set(self.all_special_tokens)
def _get_special_tokens(self, all_special_ids: list[int]) -> list[str]:
from mistral_common.tokens.tokenizers.base import SpecialTokenPolicy
return [
self.tokenizer.decode([i], special_token_policy=SpecialTokenPolicy.KEEP)
for i in all_special_ids
]
# the following attributes are set to fit vLLM's design and are used
# by the structured output backends.
@property
def all_special_tokens_extended(self) -> list[str]:
return self.all_special_tokens
@property
def all_special_tokens(self) -> list[str]:
return self._special_tokens
@property
def all_special_ids(self) -> list[int]:
return self._special_token_ids
@classmethod
def from_pretrained(cls, pretrained_model_dir: str, **kwargs):
if Path(pretrained_model_dir).is_file():
tokenizer = TransformersMistralTokenizer(tokenizer_path=pretrained_model_dir)
else:
tokenizer = TransformersMistralTokenizer.from_pretrained(pretrained_model_dir)
return cls(tokenizer)
def _get_special_token_ids(self) -> list[int]:
from mistral_common.tokens.tokenizers.sentencepiece import SentencePieceTokenizer
from mistral_common.tokens.tokenizers.tekken import Tekkenizer
if self.is_tekken:
assert isinstance(self.tokenizer, Tekkenizer), type(self.tokenizer)
special_ids = {t["rank"] for t in self.tokenizer._all_special_tokens}
elif self.is_spm:
assert isinstance(self.tokenizer, SentencePieceTokenizer), type(self.tokenizer)
special_ids = self.tokenizer._control_tokens
else:
raise ValueError(f"Unknown tokenizer type: {type(self.tokenizer)}")
return sorted(special_ids)
@property
def bos_token_id(self) -> int:
return self.tokenizer.bos_id
@property
def eos_token_id(self) -> int:
return self.tokenizer.eos_id
@property
def sep_token(self) -> str:
raise NotImplementedError()
@property
def pad_token(self) -> str:
return self.transformers_tokenizer.pad_token
@property
def pad_token_id(self) -> int:
return self.transformers_tokenizer.pad_token_id
def __call__(self, text: str, *args, **kwargs) -> Any:
return self.transformers_tokenizer(text=text, *args, **kwargs)
@property
def name_or_path(self) -> str:
return self.transformers_tokenizer.name_or_path
def batch_encode_plus(self, texts: list[str], *args, **kwargs) -> dict:
raise NotImplementedError
def get_chat_template(
self, chat_template: str | None = None, tools: list[dict] | None = None
) -> str:
raise NotImplementedError
def clean_up_tokenization(self, out_string: str) -> str:
raise NotImplementedError
@property
def is_fast(self) -> bool:
return True
def get_added_vocab(self) -> dict[str, int]:
# Mistral tokenizers have no added vocabulary
return {}
def _tekken_token_to_id(self, tokenizer: "Tekkenizer", t: str | bytes) -> int:
assert isinstance(tokenizer, Tekkenizer), type(tokenizer)
t_bytes = t.encode("utf-8") if not isinstance(t, bytes) else t
shift = tokenizer.num_special_tokens
try:
return shift + tokenizer._tekken_token2id_nospecial[t_bytes]
except KeyError:
t_str = t_bytes.decode("utf-8")
if t_str in tokenizer._special_tokens_reverse_vocab:
return tokenizer._special_tokens_reverse_vocab[t_str]
logger.warning("Failed to convert token %s to id, replacing with <unk>", t_bytes)
return tokenizer.unk_id
def _is_special_token_id(self, token_id: int) -> bool:
return token_id in self._special_token_ids_set
def convert_tokens_to_string(
self,
tokens: list[str],
skip_special_tokens: bool = False,
spaces_between_special_tokens: bool = True,
) -> str:
to_decode_special_tokens = {SpecialTokens.tool_calls}
if self.is_tekken:
assert isinstance(self.tokenizer, Tekkenizer), type(self.tokenizer)
tokens = [
t
for t in tokens
if (t in to_decode_special_tokens or t not in self._special_tokens_set)
]
if any(isinstance(t, bytes) for t in tokens):
# we need to encode and decode all tokens again
ids = [self._tekken_token_to_id(self.tokenizer, t) for t in tokens]
# We filtered unwanted special tokens before
# so we can decode the rest.
decoded = self.tokenizer.decode(ids, SpecialTokenPolicy.KEEP)
else:
decoded = "".join(tokens)
else:
# make sure certain special tokens like Tool calls are
# not decoded
assert isinstance(self.tokenizer, SentencePieceTokenizer), type(self.tokenizer)
regular_tokens: list[str] = []
decoded_list: list[str] = []
decoded = ""
for token in tokens:
if token in to_decode_special_tokens:
if regular_tokens:
decoded_list.append(
self.tokenizer.decode(regular_tokens, SpecialTokenPolicy.IGNORE)
)
regular_tokens = []
decoded_list.append(token)
else:
regular_tokens.append(token)
if regular_tokens:
decoded_list.append(
self.tokenizer.decode(regular_tokens, SpecialTokenPolicy.IGNORE)
)
decoded = "".join(decoded_list)
return decoded
def encode(
self,
text: str,
truncation: bool | None = None,
max_length: int | None = None,
add_special_tokens: bool | None = None,
) -> list[int]:
if add_special_tokens is not None:
return self.transformers_tokenizer.encode(
text,
truncation=truncation,
max_length=max_length,
add_special_tokens=add_special_tokens,
)
else:
encoded = self.tokenizer.encode(text, bos=True, eos=False)
if truncation is not False and max_length is not None:
return encoded[:max_length]
else:
return encoded
def decode(
self, token_ids: list[int] | int, skip_special_tokens: bool = True, *args, **kwargs
) -> str:
return self.transformers_tokenizer.decode(
token_ids, skip_special_tokens=skip_special_tokens
)
def convert_ids_to_tokens(
self,
ids: list[int],
skip_special_tokens: bool = True,
) -> list[str]:
from mistral_common.tokens.tokenizers.base import SpecialTokenPolicy, SpecialTokens
from mistral_common.tokens.tokenizers.instruct import InstructTokenizerV13
if not skip_special_tokens:
return [self.tokenizer.id_to_piece(token_id) for token_id in ids]
non_skip_special_tokens_ids = {
self.tokenizer.get_control_token(SpecialTokens.tool_calls),
}
if isinstance(self.instruct, InstructTokenizerV13):
if self.instruct.BEGIN_THINK:
non_skip_special_tokens_ids.add(self.instruct.BEGIN_THINK)
if self.instruct.END_THINK:
non_skip_special_tokens_ids.add(self.instruct.END_THINK)
ids_kept = [
i for i in ids if i in non_skip_special_tokens_ids or not self._is_special_token_id(i)
]
# We filtered unwanted special tokens so we can decode the rest.
tokens = [self.tokenizer.id_to_piece(token_id) for token_id in ids_kept]
if any("<EFBFBD>" in t for t in tokens) and self.is_tekken:
# if a decoded token contains the replacement character, then the
# token has an incomplete UTF-8 character so we must use bytes
# See: https://github.com/vllm-project/vllm/pull/8640
# https://github.com/vllm-project/vllm/pull/9625
# if underlying tokenizer is sentencepiece, we just add "�".
# We filtered unwanted special tokens so we can decode the rest.
tokens = [
self.tokenizer.id_to_byte_piece(token_id, SpecialTokenPolicy.KEEP)
if token_id not in self._special_token_ids_set
else self.tokenizer.decode([token_id], SpecialTokenPolicy.KEEP)
for token_id in ids_kept
]
return tokens
@property
def vocab_size(self) -> int:
return len(self._vocab_dict)
@property
def clean_up_tokenization_spaces(self):
return False
def hf_decode_incrementally(
self,
token_ids: list[int],
prev_text: str | None = None,
states: dict | None = None,
*,
skip_special_tokens: bool = False,
clean_up_tokenization_spaces: bool | None = None,
) -> tuple[str, dict]:
raise NotImplementedError
def apply_chat_template(
self, conversation: Union[list[dict[str, str]], list[list[dict[str, str]]]], *args, **kwargs
) -> Union[str, list[int], list[str], list[list[int]]]:
raise NotImplementedError
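For orientation, a minimal usage sketch of the wrapper above, not part of this diff (the checkpoint path is a placeholder and assumes a mistral-common tokenizer file is present in the checkpoint):
```python
# Illustrative only; the path below is a placeholder.
from tensorrt_llm._torch.models.checkpoints.mistral.tokenizer import MistralTokenizer

tokenizer = MistralTokenizer.from_pretrained("/path/to/Mistral-Large-3-checkpoint")

# encode() defers to mistral-common (BOS added, no EOS) unless
# add_special_tokens is given, in which case the transformers wrapper is used.
ids = tokenizer.encode("Describe the image in one sentence.")

# decode() drops special tokens by default.
text = tokenizer.decode(ids)

# Metadata consumed by the structured-output backends.
print(tokenizer.vocab_size, tokenizer.bos_token_id, tokenizer.eos_token_id)
print(tokenizer.all_special_tokens[:5])
```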

View File

@@ -4,6 +4,8 @@ from typing import Any, Dict, List, Tuple
import torch
import torchvision
from mistral_common.tokens.tokenizers.multimodal import ImageEncoder
from PIL import Image
from torch import nn
from transformers import (AutoProcessor, AutoTokenizer, Mistral3Config,
MistralConfig, PretrainedConfig, PreTrainedModel)
@@ -14,6 +16,8 @@ from tensorrt_llm._torch.attention_backend.interface import (
PositionalEmbeddingParams, RopeParams)
from tensorrt_llm._torch.model_config import ModelConfig
from tensorrt_llm._torch.models import modeling_pixtral
from tensorrt_llm._torch.models.checkpoints.mistral.tokenizer import \
MistralTokenizer
from tensorrt_llm._torch.models.checkpoints.mistral.weight_mapper import \
MistralWeightMapper
from tensorrt_llm._torch.models.modeling_mistral_large3 import (
@@ -214,6 +218,96 @@ class MistralForCausalLM(DecoderModelForCausalLM[MistralModel, MistralConfig]):
)
class MistralCommonImageProcessor:
def __init__(self, tokenizer: MistralTokenizer, dtype) -> None:
super().__init__()
self.tokenizer = tokenizer
self.dtype = dtype
@property
def image_processor(self) -> ImageEncoder:
image_encoder = self.tokenizer.instruct.mm_encoder
assert isinstance(image_encoder, ImageEncoder)
return image_encoder
@property
def image_break_id(self) -> int:
return self.image_processor.special_ids.img_break
@property
def image_break_token_id(self) -> int:
return self.image_break_id
@property
def image_token_id(self) -> int:
return self.image_processor.special_ids.img
@property
def image_end_id(self) -> int:
return self.image_processor.special_ids.img_end
@property
def image_end_token_id(self):
return self.image_end_id
@property
def image_size(self) -> int:
return self.image_processor.mm_config.max_image_size
@property
def patch_size(self) -> int:
return self.image_processor.mm_config.image_patch_size
def _get_num_multimodal_tokens(self, image_sizes):
return {
"num_image_tokens":
[self.get_num_tokens_per_image(size) for size in image_sizes]
}
def get_num_tokens_per_image(self, image_sizes):
# FIXME avoid double loading with custom loader
h, w = image_sizes
ncols, nrows = self.image_processor._image_to_num_tokens(
Image.new("RGB", (w, h)))
return ncols * nrows + nrows
def __call__(self, text, images, media, **kwargs):
assert media is not None
if isinstance(media, str):
media = [media]
mm_items = [{"type": "image_url", "image_url": url} for url in media]
logger.debug(f"text: {text}")
conversation = [{
"role": "user",
"content": [{
"type": "text",
"text": text
}, *mm_items]
}]
encoded = self.tokenizer.transformers_tokenizer.apply_chat_template(
conversation, tokenize=True, return_dict=True, return_tensors='pt')
logger.debug(
f"encoded.pixel_values.shape: {encoded.pixel_values.shape}, encoded.input_ids: {encoded.input_ids[0][-20:]}"
)
logger.debug(
f"encoded.input_ids list: {self.tokenizer.transformers_tokenizer.apply_chat_template(conversation)}"
)
processed = {
"input_ids": encoded.input_ids,
"pixel_values": encoded.pixel_values.to(self.dtype),
"attention_mask": encoded.attention_mask,
"image_sizes": torch.tensor([encoded.pixel_values.shape[2:]])
}
return processed
class Mistral3InputProcessor(BaseMultimodalInputProcessor,
BaseMultimodalDummyInputsBuilder):
@@ -223,6 +317,7 @@ class Mistral3InputProcessor(BaseMultimodalInputProcessor,
config: PretrainedConfig,
tokenizer: AutoTokenizer | None,
trust_remote_code: bool = False,
model_type: str = "mistral3",
**kwargs,
):
super().__init__(model_path=model_path,
@@ -233,12 +328,24 @@ class Mistral3InputProcessor(BaseMultimodalInputProcessor,
self._config = config
self._dtype = self._config.torch_dtype
self._tokenizer = tokenizer if tokenizer is not None else AutoTokenizer.from_pretrained(
model_path)
self._model_path = model_path
self._processor = AutoProcessor.from_pretrained(
model_path,
config=config,
use_fast=self.use_fast,
trust_remote_code=trust_remote_code)
self._model_path = model_path
if model_type == "mistral_large_3":
self._processor = MistralCommonImageProcessor(
tokenizer=self._tokenizer, dtype=self.dtype)
self.text_processor = AutoProcessor.from_pretrained(
model_path,
use_fast=self.use_fast,
trust_remote_code=trust_remote_code)
else:
self._processor = AutoProcessor.from_pretrained(
model_path,
use_fast=self.use_fast,
trust_remote_code=trust_remote_code)
self.text_processor = None
@property
def config(self) -> PretrainedConfig:
@@ -273,12 +380,20 @@ class Mistral3InputProcessor(BaseMultimodalInputProcessor,
# format is "pt" (pytorch tensors), but not for "pil" (PIL images).
do_rescale = False
processed = self.processor(
text=inputs["prompt"],
images=images,
do_rescale=do_rescale,
**mm_processor_kwargs,
)
if mm_processor_kwargs:
# Currently, we only support image modality in MistralCommonImageProcessor.
processed = self.processor(
text=inputs["prompt"],
images=images,
do_rescale=do_rescale,
**mm_processor_kwargs,
)
else:
processed = self.text_processor(
text=inputs["prompt"],
images=images,
do_rescale=do_rescale,
)
input_ids = processed.pop("input_ids").tolist()[0]
# Remaining in `processed`:
# * "attention_mask": [B, num_input_tokens]
@@ -332,8 +447,56 @@ class Mistral3InputProcessor(BaseMultimodalInputProcessor,
])
class MistralCommonInputProcessor(Mistral3InputProcessor):
def __init__(
self,
model_path: str,
config: PretrainedConfig,
tokenizer: AutoTokenizer,
trust_remote_code: bool = False,
**kwargs,
):
tokenizer = self.load_tokenizer(model_path, config=config)
super().__init__(model_path=model_path,
config=config,
tokenizer=tokenizer,
trust_remote_code=trust_remote_code,
model_type="mistral_large_3",
**kwargs)
@staticmethod
def load_tokenizer(model_path: str,
config: PretrainedConfig,
checkpoint_format: str = "mistral_large_3"):
if checkpoint_format == "mistral_large_3":
try:
return MistralTokenizer.from_pretrained(model_path)
except ValueError:
logger.info(
f"Could not load mistral-common tokenizer from {model_path}, falling back to HuggingFace"
)
tokenizer = AutoTokenizer.from_pretrained(model_path,
config=config,
use_fast=True,
trust_remote_code=True)
return tokenizer
@register_auto_model("Mistral3ForConditionalGeneration")
@register_auto_model("PixtralForConditionalGeneration")
@register_input_processor(
MistralCommonInputProcessor,
model_type="mistral_large_3",
placeholder_metadata=MultimodalPlaceholderMetadata(
placeholder_map={
# NOTE: mistral-common uses the tokenizer to set placeholders; this entry is ignored.
"image": "[IMG]",
},
placeholder_placement=MultimodalPlaceholderPlacement.BEFORE_TEXT,
))
@register_input_processor(
Mistral3InputProcessor,
model_type="mistral3",
@@ -410,7 +573,6 @@ class Mistral3VLM(PreTrainedModel):
self._multi_modal_projector = Mistral3MultiModalProjector(
model_config).eval().to(self._device)
self._post_config()
self.is_loaded = True
# This is necessary because the executor looks at
# `model.model_config.pretrained_config.vocab_size`.
@@ -766,11 +928,3 @@ class Mistral3MultiModalProjector(torch.nn.Module):
def load_weights(self, weights):
_load_weights_impl(self, weights)
def _filter_weights(weights: Dict[str, torch.Tensor],
prefix: str) -> Dict[str, torch.Tensor]:
return {
name[len(prefix):]: weight
for name, weight in weights.items() if name.startswith(prefix)
}
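As a rough illustration of how the new `MistralCommonImageProcessor` above is driven, here is a hedged sketch, not part of this diff (the checkpoint path, image URL, dtype, and the class's module path are placeholders/assumptions):
```python
# Illustrative only; paths and the image URL are placeholders.
import torch

from tensorrt_llm._torch.models.checkpoints.mistral.tokenizer import MistralTokenizer
# Module path assumed; MistralCommonImageProcessor is the class defined in the file above.
from tensorrt_llm._torch.models.modeling_mistral import MistralCommonImageProcessor

tokenizer = MistralTokenizer.from_pretrained("/path/to/Mistral-Large-3-checkpoint")
processor = MistralCommonImageProcessor(tokenizer=tokenizer, dtype=torch.bfloat16)

# `media` carries the raw image sources; mistral-common's chat template
# inserts the image placeholder tokens itself.
processed = processor(
    text="What is shown in this picture?",
    images=None,
    media=["https://example.com/cat.png"],
)
# `processed` contains "input_ids", "pixel_values" (cast to the model dtype),
# "attention_mask", and "image_sizes".
```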

View File

@@ -4,8 +4,6 @@ import torch
from torch import nn
from transformers import LlamaConfig, PretrainedConfig
from tensorrt_llm.logger import logger
from ...functional import PositionEmbeddingType
from ..attention_backend import AttentionMetadata
from ..attention_backend.interface import PositionalEmbeddingParams, RopeParams
@@ -754,19 +752,6 @@ class SpecDecOneEngineForCausalLM(DecoderModelForCausalLM[TModel, TConfig],
skip_create_weights_in_init=True,
)
self.draft_config.extra_attrs = model_config.extra_attrs
from tensorrt_llm._utils import get_sm_version
from tensorrt_llm.quantization.mode import QuantAlgo
if get_sm_version(
) == 100 and self.draft_config.quant_config.quant_algo == QuantAlgo.FP8_BLOCK_SCALES:
# FIXME There is a known issue on TRTLLM moe backend + FP8 blockwise
logger.warning(
"Switching moe_backend of draft model to DEEPGEMM for FP8_BLOCK_SCALES quantization on SM100"
"This is a workaround for the known issue on TRTLLM moe backend + FP8 blockwise"
)
self.draft_config._frozen = False
self.draft_config.moe_backend = "DEEPGEMM"
self.draft_config._frozen = True
elif spec_config.eagle3_model_arch == "llama3":
self.draft_config = ModelConfig.from_pretrained(
model_config.spec_config.speculative_model_dir,

View File

@@ -374,8 +374,10 @@ NOTE:
placeholder for the model needs to be added in retrieve_multimodal_placeholder().
"""
HF_CHAT_TEMPLATE_EXCEPTIONS = ["llava_llama"]
PLACEHOLDER_EXCEPTIONS = ["llava_next", "NemotronH_Nano_VL_V2"]
HF_CHAT_TEMPLATE_EXCEPTIONS = ["llava_llama", "mistral_large_3"]
PLACEHOLDER_EXCEPTIONS = [
"llava_next", "NemotronH_Nano_VL_V2", "mistral_large_3"
]
# Helpers to always get the latest supported multimodal model types from the registry
@@ -771,6 +773,13 @@ def default_multimodal_input_loader(
add_generation_prompt=True,
mm_placeholder_counts=[mm_placeholder_counts])
input = {"prompt": prompt}
# When the tokenizer is a MistralTokenizer, keep the source media so the processor can handle it later.
from tensorrt_llm._torch.models.checkpoints.mistral.tokenizer import \
MistralTokenizer
if isinstance(tokenizer, MistralTokenizer):
input["mm_processor_kwargs"] = {"media": media}
if mm_placeholder_counts:
if mm_embeddings is not None:
input[

View File

@@ -21,3 +21,5 @@ microsoft/Phi-4-multimodal-instruct:
- accuracy: 53.67
Qwen/Qwen3-VL-30B-A3B-Instruct:
- accuracy: 55.33
mistral/Mistral-Large-3-675B:
- accuracy: 60.00

View File

@@ -1,9 +1,9 @@
import pytest
from tensorrt_llm import LLM
from tensorrt_llm.llmapi import KvCacheConfig, SamplingParams
from tensorrt_llm.llmapi import CudaGraphConfig, KvCacheConfig, MoeConfig, SamplingParams
from ..conftest import llm_models_root
from ..conftest import llm_models_root, skip_pre_blackwell
from .accuracy_core import MMMU, LlmapiAccuracyTestHarness
@@ -263,3 +263,52 @@ class TestQwen3VL_MOE(LlmapiAccuracyTestHarness):
) as llm:
task = MMMU(self.MODEL_NAME)
task.evaluate(llm, sampling_params=self.sampling_params)
class TestMistralLarge3_675B(LlmapiAccuracyTestHarness):
MODEL_NAME = "mistral/Mistral-Large-3-675B"
MODEL_PATH = (
f"{llm_models_root()}/Mistral-Large-3-675B/Mistral-Large-3-675B-Instruct-2512-NVFP4/"
)
MAX_NUM_TOKENS = 16384
sampling_params = SamplingParams(
max_tokens=MAX_NUM_TOKENS, truncate_prompt_tokens=MMMU.MAX_INPUT_LEN, stop="<|endoftext|>"
)
@skip_pre_blackwell
@pytest.mark.skip_less_mpi_world_size(4)
@pytest.mark.skip_less_device_memory(183000)
@pytest.mark.parametrize(
"tp_size,pp_size,ep_size,attention_dp,cuda_graph,overlap_scheduler,moe_backend",
[
(4, 1, 4, False, True, True, "TRTLLM"),
],
ids=[
"latency_moe_trtllm",
],
)
def test_nvfp4_4gpus(
self, tp_size, pp_size, ep_size, attention_dp, cuda_graph, overlap_scheduler, moe_backend
):
pytorch_config = dict(
disable_overlap_scheduler=not overlap_scheduler,
cuda_graph_config=CudaGraphConfig() if cuda_graph else None,
moe_config=MoeConfig(backend=moe_backend),
)
kv_cache_config = KvCacheConfig(free_gpu_memory_fraction=0.4)
with LLM(
self.MODEL_PATH,
max_num_tokens=self.MAX_NUM_TOKENS,
checkpoint_format="mistral",
tensor_parallel_size=tp_size,
pipeline_parallel_size=pp_size,
moe_expert_parallel_size=ep_size,
**pytorch_config,
enable_attention_dp=attention_dp,
kv_cache_config=kv_cache_config,
) as llm:
task = MMMU(self.MODEL_NAME)
task.evaluate(llm, sampling_params=self.sampling_params, model_type="mistral_large_3")
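The test-list updates below reference this case as `accuracy/test_llm_api_pytorch_multimodal.py::TestMistralLarge3_675B::test_nvfp4_4gpus[latency_moe_trtllm]`. A hedged local invocation, assuming the usual integration-test layout under `tests/integration/defs` and a 4-GPU Blackwell node:
```bash
cd tests/integration/defs
pytest "accuracy/test_llm_api_pytorch_multimodal.py::TestMistralLarge3_675B::test_nvfp4_4gpus[latency_moe_trtllm]"
```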

View File

@@ -71,6 +71,7 @@ l0_gb200_multi_gpus:
- accuracy/test_llm_api_pytorch.py::TestQwen3NextInstruct::test_nvfp4[tp4ep4-trtllm]
- accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[nccl-cudagraph:none]
- accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[fifo-cudagraph:none]
- accuracy/test_llm_api_pytorch.py::TestMistralLarge3_675B::test_nvfp4_4gpus[latency_moe_trtllm_eagle] TIMEOUT (90)
- condition:
ranges:
system_gpu_count:
@@ -107,4 +108,4 @@ l0_gb200_multi_gpus:
- accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_4gpus_online_eplb[enable_configurable_moe-fp8]
- accuracy/test_llm_api_pytorch.py::TestQwen3_235B_A22B::test_nvfp4_4gpus[latency_moe_trtllm_eagle3] TIMEOUT (90)
- accuracy/test_llm_api_pytorch.py::TestMistralLarge3_675B::test_nvfp4_4gpus[latency_moe_trtllm] TIMEOUT (90)
- accuracy/test_llm_api_pytorch.py::TestMistralLarge3_675B::test_nvfp4_4gpus[latency_moe_trtllm_eagle] TIMEOUT (90)
- accuracy/test_llm_api_pytorch_multimodal.py::TestMistralLarge3_675B::test_nvfp4_4gpus[latency_moe_trtllm] TIMEOUT (90)

View File

@@ -517,3 +517,4 @@ test_e2e.py::test_eagle3_output_consistency_4gpus[Qwen3/saved_models_Qwen3-235B-
examples/test_mistral.py::test_mistral_with_bf16_lora_torch[mistral-7b-v0.1] SKIP (https://nvbugs/5769855)
disaggregated/test_disaggregated.py::test_disaggregated_benchmark_on_diff_backends[llama-3.1-8b-instruct-hf-fp8] SKIP (https://nvbugs/5769890)
disaggregated/test_disaggregated.py::test_disaggregated_benchmark_on_diff_backends[DeepSeek-V3-Lite-bf16] SKIP (https://nvbugs/5769890)
accuracy/test_llm_api_pytorch_multimodal.py::TestMistralLarge3_675B::test_nvfp4_4gpus[latency_moe_trtllm]