[None][fix] Mistral Large 3: a few code refinements (#10405)

Signed-off-by: bhsueh <11360707+byshiue@users.noreply.github.com>
bhsueh_NV 2026-01-08 19:38:49 +08:00 committed by GitHub
parent dc6b743fb6
commit bea61bb17d
GPG Key ID: B5690EEEBB952194
11 changed files with 84 additions and 58 deletions

@@ -19,7 +19,8 @@ mpirun -n 1 --allow-run-as-root --oversubscribe python3 examples/llm-api/quickst
--max_tokens 100 \
--checkpoint_format mistral \
--model_type mistral_large_3 \
--moe_backend TRTLLM
--moe_backend TRTLLM \
--image_format pil
```
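Side note on the new flag: `--image_format pil` keeps inputs as PIL images, which land under the `image` key of `multi_modal_data` (see the lm_eval wrapper change later in this commit). A purely illustrative sketch:

```python
# Illustrative only: what a PIL-backed multimodal prompt looks like at the LLM API
# level. The quickstart script builds this internally; the prompt text and image
# below are hypothetical stand-ins.
from PIL import Image

image = Image.new("RGB", (64, 64))  # stand-in for a real photo loaded via PIL
prompt = {
    "prompt": "Describe the image.",
    "multi_modal_data": {"image": [image]},
}
```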
## LLM-only run

@@ -103,17 +103,14 @@ def _remap_mistral_yarn_args(config: dict) -> dict:
"apply_scale": "apply_yarn_scaling",
}
yarn_config = config.get("yarn") or {}
config["rope_parameters"] = {
config["rope_scaling"] = {
"rope_type": "yarn",
"mscale_all_dim": 1,
}
if rope_theta := config.pop("rope_theta", None):
config["rope_parameters"]["rope_theta"] = rope_theta
for old_name, new_name in yarn_config_map.items():
if old_name in yarn_config:
config["rope_parameters"][new_name] = yarn_config.pop(old_name)
config["rope_scaling"][new_name] = yarn_config.pop(old_name)
assert len(yarn_config) == 0, f"Unparsed yarn config: {yarn_config}"
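For reference, a standalone sketch of the remap after this rename (`rope_parameters` becomes `rope_scaling`); only the `apply_scale` entry of `yarn_config_map` is visible in this hunk, so the map below is just that subset:

```python
# Minimal stand-in for _remap_mistral_yarn_args with the one mapping entry shown above.
def remap_yarn_args(config: dict) -> dict:
    yarn_config_map = {"apply_scale": "apply_yarn_scaling"}  # subset visible in the diff
    yarn_config = config.get("yarn") or {}
    config["rope_scaling"] = {"rope_type": "yarn", "mscale_all_dim": 1}
    if rope_theta := config.pop("rope_theta", None):
        config["rope_scaling"]["rope_theta"] = rope_theta
    for old_name, new_name in yarn_config_map.items():
        if old_name in yarn_config:
            config["rope_scaling"][new_name] = yarn_config.pop(old_name)
    assert not yarn_config, f"Unparsed yarn config: {yarn_config}"
    return config

# Hypothetical Mistral-style config fragment:
cfg = {"rope_theta": 1_000_000.0, "yarn": {"apply_scale": True}}
print(remap_yarn_args(cfg)["rope_scaling"])
# {'rope_type': 'yarn', 'mscale_all_dim': 1, 'rope_theta': 1000000.0, 'apply_yarn_scaling': True}
```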

@@ -46,6 +46,7 @@ from tensorrt_llm.inputs import (BaseMultimodalDummyInputsBuilder,
MultimodalPlaceholderPlacement, TextPrompt,
register_input_processor)
from tensorrt_llm.inputs.multimodal import MultimodalParams
from tensorrt_llm.inputs.utils import encode_base64_image
from tensorrt_llm.llmapi import SamplingParams
from tensorrt_llm.logger import logger
@@ -58,16 +59,28 @@ class MistralAttention(Attention):
layer_idx: int | None = None,
):
config = model_config.pretrained_config
rope_params = RopeParams.from_config(config)
rope_params_section = getattr(config, "rope_scaling", None) or getattr(
config, "rope_parameters", None)
rope_type = getattr(rope_params_section, "rope_type", None)
if rope_type == "yarn":
pos_embd_params = PositionalEmbeddingParams(
type=PositionEmbeddingType.yarn,
rope=rope_params,
is_neox=False)
else:
pos_embd_params = PositionalEmbeddingParams(
type=PositionEmbeddingType.rope_gpt_neox,
rope=rope_params,
)
super().__init__(
hidden_size=config.hidden_size,
num_attention_heads=config.num_attention_heads,
num_key_value_heads=config.num_key_value_heads,
max_position_embeddings=config.max_position_embeddings,
bias=False,
pos_embd_params=PositionalEmbeddingParams(
type=PositionEmbeddingType.rope_gpt_neox,
rope=RopeParams.from_config(config),
),
pos_embd_params=pos_embd_params,
layer_idx=layer_idx,
dtype=config.torch_dtype,
config=model_config,
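A quick standalone sketch of the rope-type branch introduced above (`SimpleNamespace` stands in for the pretrained config; the real code builds `PositionalEmbeddingParams`/`RopeParams` from TRT-LLM internals):

```python
# Illustrative only: which positional-embedding mode the new branch selects.
from types import SimpleNamespace

def select_rope_mode(config) -> str:
    rope_section = getattr(config, "rope_scaling", None) or getattr(
        config, "rope_parameters", None)
    rope_type = getattr(rope_section, "rope_type", None)
    # "yarn" selects PositionEmbeddingType.yarn with is_neox=False;
    # anything else falls back to rope_gpt_neox, as before this change.
    return "yarn" if rope_type == "yarn" else "rope_gpt_neox"

cfg = SimpleNamespace(rope_scaling=SimpleNamespace(rope_type="yarn"))
assert select_rope_mode(cfg) == "yarn"
assert select_rope_mode(SimpleNamespace()) == "rope_gpt_neox"
```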
@@ -266,20 +279,18 @@ class MistralCommonImageProcessor:
}
def get_num_tokens_per_image(self, image_sizes):
# FIXME avoid double loading with custom loader
h, w = image_sizes
ncols, nrows = self.image_processor._image_to_num_tokens(
Image.new("RGB", (w, h)))
return ncols * nrows + nrows
def __call__(self, text, images, media, **kwargs):
assert media is not None
if isinstance(media, str):
media = [media]
mm_items = [{"type": "image_url", "image_url": url} for url in media]
logger.debug(f"text: {text}")
def __call__(self, text, images, **kwargs):
mm_items = []
if images:
mm_items = [{
"type": "image",
"base64": encode_base64_image(image)
} for image in images]
conversation = [{
"role": "user",
@@ -292,19 +303,20 @@ class MistralCommonImageProcessor:
encoded = self.tokenizer.transformers_tokenizer.apply_chat_template(
conversation, tokenize=True, return_dict=True, return_tensors='pt')
logger.debug(
f"encoded.pixel_values.shape: {encoded.pixel_values.shape}, encoded.input_ids: {encoded.input_ids[0][-20:]}"
)
logger.debug(
f"encoded.input_ids list: {self.tokenizer.transformers_tokenizer.apply_chat_template(conversation)}"
)
processed = {
"input_ids": encoded.input_ids,
"pixel_values": encoded.pixel_values.to(self.dtype),
"attention_mask": encoded.attention_mask,
"image_sizes": torch.tensor([encoded.pixel_values.shape[2:]])
}
# text-only mode for VLM
if "pixel_values" in encoded:
processed.update({
"pixel_values":
encoded.pixel_values.to(self.dtype),
"attention_mask":
encoded.attention_mask,
"image_sizes":
torch.tensor([encoded.pixel_values.shape[2:]])
})
return processed
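For context, a rough stand-in for the `mm_items` construction in the new `__call__` (the real helper is `encode_base64_image` from `tensorrt_llm.inputs.utils`; the serialization details below are assumptions):

```python
# Illustrative only: sketch of serializing PIL images to base64 for the chat template.
import base64
import io

from PIL import Image

def to_base64(image: Image.Image, fmt: str = "PNG") -> str:
    # Hypothetical stand-in for encode_base64_image; format/encoding are assumptions.
    buf = io.BytesIO()
    image.save(buf, format=fmt)
    return base64.b64encode(buf.getvalue()).decode("utf-8")

images = [Image.new("RGB", (64, 64))]
mm_items = [{"type": "image", "base64": to_base64(img)} for img in images]
# With no images, mm_items stays empty and the chat template yields a text-only
# encoding (no pixel_values) -- the "text-only mode for VLM" branch above.
```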
@@ -376,7 +388,6 @@ class Mistral3InputProcessor(BaseMultimodalInputProcessor,
self, inputs: TextPrompt, sampling_params: SamplingParams
) -> Tuple[List[int], ExtraProcessedInputs | None]:
images = inputs.get("multi_modal_data", {}).get("image")
mm_processor_kwargs = inputs.get("mm_processor_kwargs", {})
do_rescale = getattr(self.processor.image_processor, "do_rescale",
False)
if images is not None and isinstance(images[0], torch.Tensor):
@@ -384,18 +395,15 @@
# format is "pt" (pytorch tensors), but not for "pil" (PIL images).
do_rescale = False
if mm_processor_kwargs:
# Currently, we only support image modality in MistralCommonImageProcessor.
if images is not None:
processed = self.processor(
text=inputs["prompt"],
images=images,
do_rescale=do_rescale,
**mm_processor_kwargs,
)
else:
processed = self.text_processor(
text=inputs["prompt"],
images=images,
do_rescale=do_rescale,
)
input_ids = processed.pop("input_ids").tolist()[0]
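A small sketch of the `do_rescale` decision above: as the comment in this hunk notes, tensor inputs (image format "pt") arrive already scaled, so rescaling is disabled, while PIL inputs keep the processor's default. The helper name below is illustrative:

```python
# Illustrative only: mirrors the do_rescale handling in Mistral3InputProcessor.
import torch
from PIL import Image

def resolve_do_rescale(images, processor_default: bool = True) -> bool:
    if images is not None and isinstance(images[0], torch.Tensor):
        return False  # "pt" inputs are already rescaled by the loader
    return processor_default  # "pil" inputs rely on the image processor

assert resolve_do_rescale([torch.rand(3, 64, 64)]) is False
assert resolve_do_rescale([Image.new("RGB", (64, 64))]) is True
```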

@@ -52,7 +52,9 @@ class LmEvalWrapper(TemplateLM):
llm: Union[LLM, PyTorchLLM],
sampling_params: Optional[SamplingParams] = None,
streaming: bool = False,
chat_template_kwargs: Optional[dict[str, Any]] = None):
chat_template_kwargs: Optional[dict[str, Any]] = None,
model_type: str | None = None,
is_force_single_image: bool = False):
super().__init__()
self.llm = llm
self.sampling_params = sampling_params
@@ -163,7 +165,9 @@ class MultimodalLmEvalWrapper(LmEvalWrapper):
sampling_params: Optional[SamplingParams] = None,
streaming: bool = False,
max_images: int = 999,
chat_template_kwargs: Optional[dict[str, Any]] = None):
chat_template_kwargs: Optional[dict[str, Any]] = None,
model_type: str | None = None,
is_force_single_image: bool = False):
"""
Initialize the multimodal wrapper.
@@ -179,7 +183,9 @@ class MultimodalLmEvalWrapper(LmEvalWrapper):
self.MULTIMODAL = True
self.max_images = max_images
self.chat_template_kwargs = chat_template_kwargs
self.model_type = self._get_model_type(llm)
self.model_type = model_type if model_type is not None else self._get_model_type(
llm)
self.is_force_single_image = is_force_single_image
# NOTE: In TRT-LLM, currently we do not support interleaved text and image. Instead, we are adding image placeholders at the end of the text or at the beginning of the text.
# So, until we support interleaved text and image, we set this to False.
@@ -287,8 +293,13 @@ class MultimodalLmEvalWrapper(LmEvalWrapper):
prompt = prompt_inputs(prompt)
# NOTE: Convert RGBA format to RGB format
if self.is_force_single_image:
# NOTE: This is a workaround to force single image for models which only support single image.
images = [convert_image_mode(media_data["visual"][0], "RGB")]
else:
images = [
convert_image_mode(img, "RGB") for img in media_data["visual"]
convert_image_mode(img, "RGB")
for img in media_data["visual"]
]
prompt["multi_modal_data"] = {"image": images}
@@ -429,14 +440,18 @@ class LmEvalEvaluator(Evaluator):
llm: Union[LLM, PyTorchLLM],
sampling_params: Optional[SamplingParams] = None,
streaming: bool = False,
scores_filter: str = None) -> float:
scores_filter: str = None,
model_type: str = None,
is_force_single_image: bool = False) -> float:
import lm_eval
lm_cls = MultimodalLmEvalWrapper if self.MULTIMODAL else LmEvalWrapper
results = lm_eval.evaluate(
lm=lm_cls(llm,
sampling_params=sampling_params,
streaming=streaming,
chat_template_kwargs=self.chat_template_kwargs),
chat_template_kwargs=self.chat_template_kwargs,
model_type=model_type,
is_force_single_image=is_force_single_image),
task_dict=self.task_dict,
limit=self.num_samples,
apply_chat_template=self.apply_chat_template,

@@ -774,12 +774,6 @@ def default_multimodal_input_loader(
mm_placeholder_counts=[mm_placeholder_counts])
input = {"prompt": prompt}
# When the tokenizer is a MistralTokenizer, we need to keep the source media to handle in processor later.
from tensorrt_llm._torch.models.checkpoints.mistral.tokenizer import \
MistralTokenizer
if isinstance(tokenizer, MistralTokenizer):
input["mm_processor_kwargs"] = {"media": media}
if mm_placeholder_counts:
if mm_embeddings is not None:
input[

@@ -402,6 +402,8 @@ class MMMU(AccuracyTask):
is_multimodal=True,
apply_chat_template=True)
EVALUATE_KWARGS = dict(model_type=None, is_force_single_image=False)
class PassKeyRetrieval64k(AccuracyTask):
DATASET = "passkey_retrieval_64k"
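Presumably the accuracy harness merges each task's `EVALUATE_KWARGS` into its `evaluate()` call, which is why the test later in this commit can patch the class attribute instead of passing `model_type` explicitly. A simplified, hypothetical sketch of that pattern:

```python
# Assumption: class-level EVALUATE_KWARGS act as per-task defaults that get merged
# into evaluate(). Task is a stand-in, not the real AccuracyTask implementation.
class Task:
    EVALUATE_KWARGS = dict(model_type=None, is_force_single_image=False)

    def evaluate(self, llm, **kwargs):
        merged = {**self.EVALUATE_KWARGS, **kwargs}
        return merged  # the real harness would forward these to the evaluator

print(Task().evaluate(llm=None))
# {'model_type': None, 'is_force_single_image': False}
```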

@@ -345,9 +345,9 @@ mistralai/Mistral-Nemo-12b-Base:
- quant_algo: FP8
accuracy: 69.66
mistral/Mistral-Large-3-675B:
- accuracy: 87.54
- accuracy: 85.30
- spec_dec_algo: Eagle
accuracy: 87.54
accuracy: 85.30
nvidia/Nemotron-Super-V3:
- accuracy: 81.07
- quant_algo: NVFP4

@@ -25,4 +25,5 @@ microsoft/Phi-4-multimodal-instruct:
Qwen/Qwen3-VL-30B-A3B-Instruct:
- accuracy: 55.33
mistral/Mistral-Large-3-675B:
- accuracy: 60.00
# Mistral Large 3 675B only supports single image input, so accuracy is lower.
- accuracy: 47

@@ -293,8 +293,19 @@ class TestMistralLarge3_675B(LlmapiAccuracyTestHarness):
],
)
def test_nvfp4_4gpus(
self, tp_size, pp_size, ep_size, attention_dp, cuda_graph, overlap_scheduler, moe_backend
self,
tp_size,
pp_size,
ep_size,
attention_dp,
cuda_graph,
overlap_scheduler,
moe_backend,
mocker,
):
mocker.patch.dict(
MMMU.EVALUATE_KWARGS, {"model_type": "mistral_large_3", "is_force_single_image": True}
)
pytorch_config = dict(
disable_overlap_scheduler=not overlap_scheduler,
cuda_graph_config=CudaGraphConfig() if cuda_graph else None,
@@ -315,4 +326,4 @@ class TestMistralLarge3_675B(LlmapiAccuracyTestHarness):
kv_cache_config=kv_cache_config,
) as llm:
task = MMMU(self.MODEL_NAME)
task.evaluate(llm, sampling_params=self.sampling_params, model_type="mistral_large_3")
task.evaluate(llm, sampling_params=self.sampling_params)
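For reference, a minimal sketch of the `mocker.patch.dict` pattern used in the test above (pytest-mock wraps `unittest.mock.patch.dict`, so the override is reverted when the test ends; the `Task` class is a stand-in):

```python
# Illustrative pytest-mock usage: patch a class-level kwargs dict for one test only.
class Task:
    EVALUATE_KWARGS = dict(model_type=None, is_force_single_image=False)

def test_patch_evaluate_kwargs(mocker):
    mocker.patch.dict(
        Task.EVALUATE_KWARGS,
        {"model_type": "mistral_large_3", "is_force_single_image": True},
    )
    assert Task.EVALUATE_KWARGS["model_type"] == "mistral_large_3"
    # pytest-mock restores the original dict contents after the test.
```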

@@ -72,7 +72,7 @@ l0_gb200_multi_gpus:
- accuracy/test_llm_api_pytorch.py::TestQwen3NextInstruct::test_nvfp4[tp4ep4-cutlass]
- accuracy/test_llm_api_pytorch.py::TestQwen3NextInstruct::test_nvfp4[no_cuda_graph_overlap-cutlass]
- accuracy/test_llm_api_pytorch.py::TestQwen3NextInstruct::test_nvfp4[tp4ep4-trtllm]
- accuracy/test_llm_api_pytorch.py::TestMistralLarge3_675B::test_nvfp4_4gpus[latency_moe_trtllm_eagle] TIMEOUT (90)
- accuracy/test_llm_api_pytorch_multimodal.py::TestMistralLarge3_675B::test_nvfp4_4gpus[latency_moe_trtllm] TIMEOUT (90)
- condition:
ranges:
system_gpu_count:
@@ -105,4 +105,4 @@ l0_gb200_multi_gpus:
- accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_4gpus_online_eplb[enable_configurable_moe-fp8]
- accuracy/test_llm_api_pytorch.py::TestQwen3_235B_A22B::test_nvfp4_4gpus[latency_moe_trtllm_eagle3] TIMEOUT (90)
- accuracy/test_llm_api_pytorch.py::TestMistralLarge3_675B::test_nvfp4_4gpus[latency_moe_trtllm] TIMEOUT (90)
- accuracy/test_llm_api_pytorch_multimodal.py::TestMistralLarge3_675B::test_nvfp4_4gpus[latency_moe_trtllm] TIMEOUT (90)
- accuracy/test_llm_api_pytorch.py::TestMistralLarge3_675B::test_nvfp4_4gpus[latency_moe_trtllm_eagle] TIMEOUT (90)

@@ -441,8 +441,6 @@ test_e2e.py::test_ptp_quickstart_advanced_2gpus_sm120[Nemotron-Super-49B-v1-BF16
unittest/_torch/multi_gpu/test_mnnvl_allreduce.py::test_row_linear_residual_norm_fusion[no_fusion-strategy:8-dtype:bfloat16-hidden:8192-seqlen:[15]] SKIP (https://nvbugs/5761364)
triton_server/test_triton.py::test_gpt_speculative_decoding[gpt-speculative-decoding] SKIP (https://nvbugs/5762854)
accuracy/test_llm_api_pytorch.py::TestLlama3_1_8B_Instruct_RocketKV::test_auto_dtype SKIP (https://nvbugs/5762822)
accuracy/test_llm_api_pytorch.py::TestMistralLarge3_675B::test_nvfp4_4gpus[latency_moe_trtllm] SKIP (https://nvbugs/5762852)
accuracy/test_llm_api_pytorch.py::TestMistralLarge3_675B::test_nvfp4_4gpus[latency_moe_trtllm_eagle] SKIP (https://nvbugs/5762852)
unittest/_torch/sampler/test_return_logits.py SKIP (https://nvbugs/5764627)
examples/serve/test_serve.py::test_config_file_loading[--config] SKIP (https://nvbugs/5754977)
full:RTXPro6000D/accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_nvfp4_4gpus[moe_backend=CUTLASS-mtp_nextn=2-tp2pp2-fp8kv=False-attention_dp=False-cuda_graph=False-overlap_scheduler=False-torch_compile=False] SKIP (https://nvbugs/5740075)
@@ -456,7 +454,6 @@ full:sm89/accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_
accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_ngram SKIP (https://nvbugs/5768068)
test_e2e.py::test_eagle3_output_consistency_4gpus[Qwen3/saved_models_Qwen3-235B-A22B_fp8_hf-Qwen3/qwen3-235B-eagle3] SKIP (https://nvbugs/5685010)
examples/test_mistral.py::test_mistral_with_bf16_lora_torch[mistral-7b-v0.1] SKIP (https://nvbugs/5769855)
accuracy/test_llm_api_pytorch_multimodal.py::TestMistralLarge3_675B::test_nvfp4_4gpus[latency_moe_trtllm] SKIP (TBD)
accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_4gpus[dp4-trtllm-fp8] SKIP (https://nvbugs/5772396)
full:sm100/accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_4gpus[dp4-trtllm-auto] SKIP (https://nvbugs/5772396)
accuracy/test_llm_api_pytorch.py::TestGLM4_6::test_nvfp4_2_model_mtp[2model_trtllm] SKIP (https://nvbugs/5772360)