diff --git a/examples/models/core/mistral_large_3/README.md b/examples/models/core/mistral_large_3/README.md
index 5ea730c9f1..da219bf7b0 100644
--- a/examples/models/core/mistral_large_3/README.md
+++ b/examples/models/core/mistral_large_3/README.md
@@ -19,7 +19,8 @@ mpirun -n 1 --allow-run-as-root --oversubscribe python3 examples/llm-api/quickst
     --max_tokens 100 \
     --checkpoint_format mistral \
     --model_type mistral_large_3 \
-    --moe_backend TRTLLM
+    --moe_backend TRTLLM \
+    --image_format pil
 ```
 
 ## LLM-only run
diff --git a/tensorrt_llm/_torch/models/checkpoints/mistral/config_loader.py b/tensorrt_llm/_torch/models/checkpoints/mistral/config_loader.py
index b72cb6da38..c679734fcf 100644
--- a/tensorrt_llm/_torch/models/checkpoints/mistral/config_loader.py
+++ b/tensorrt_llm/_torch/models/checkpoints/mistral/config_loader.py
@@ -103,17 +103,14 @@ def _remap_mistral_yarn_args(config: dict) -> dict:
         "apply_scale": "apply_yarn_scaling",
     }
     yarn_config = config.get("yarn") or {}
-    config["rope_parameters"] = {
+    config["rope_scaling"] = {
         "rope_type": "yarn",
         "mscale_all_dim": 1,
     }
-    if rope_theta := config.pop("rope_theta", None):
-        config["rope_parameters"]["rope_theta"] = rope_theta
-
     for old_name, new_name in yarn_config_map.items():
         if old_name in yarn_config:
-            config["rope_parameters"][new_name] = yarn_config.pop(old_name)
+            config["rope_scaling"][new_name] = yarn_config.pop(old_name)
 
     assert len(yarn_config) == 0, f"Unparsed yarn config: {yarn_config}"
diff --git a/tensorrt_llm/_torch/models/modeling_mistral.py b/tensorrt_llm/_torch/models/modeling_mistral.py
index ea06b5e100..99ff8169c1 100644
--- a/tensorrt_llm/_torch/models/modeling_mistral.py
+++ b/tensorrt_llm/_torch/models/modeling_mistral.py
@@ -46,6 +46,7 @@ from tensorrt_llm.inputs import (BaseMultimodalDummyInputsBuilder,
                                   MultimodalPlaceholderPlacement, TextPrompt,
                                   register_input_processor)
 from tensorrt_llm.inputs.multimodal import MultimodalParams
+from tensorrt_llm.inputs.utils import encode_base64_image
 from tensorrt_llm.llmapi import SamplingParams
 from tensorrt_llm.logger import logger
 
@@ -58,16 +59,28 @@ class MistralAttention(Attention):
         layer_idx: int | None = None,
     ):
         config = model_config.pretrained_config
+        rope_params = RopeParams.from_config(config)
+        rope_params_section = getattr(config, "rope_scaling", None) or getattr(
+            config, "rope_parameters", None)
+        rope_type = getattr(rope_params_section, "rope_type", None)
+        if rope_type == "yarn":
+            pos_embd_params = PositionalEmbeddingParams(
+                type=PositionEmbeddingType.yarn,
+                rope=rope_params,
+                is_neox=False)
+        else:
+            pos_embd_params = PositionalEmbeddingParams(
+                type=PositionEmbeddingType.rope_gpt_neox,
+                rope=rope_params,
+            )
+
         super().__init__(
             hidden_size=config.hidden_size,
             num_attention_heads=config.num_attention_heads,
             num_key_value_heads=config.num_key_value_heads,
             max_position_embeddings=config.max_position_embeddings,
             bias=False,
-            pos_embd_params=PositionalEmbeddingParams(
-                type=PositionEmbeddingType.rope_gpt_neox,
-                rope=RopeParams.from_config(config),
-            ),
+            pos_embd_params=pos_embd_params,
             layer_idx=layer_idx,
             dtype=config.torch_dtype,
             config=model_config,
@@ -266,20 +279,18 @@ class MistralCommonImageProcessor:
         }
 
     def get_num_tokens_per_image(self, image_sizes):
-        # FIXME avoid double loading with custom loader
         h, w = image_sizes
         ncols, nrows = self.image_processor._image_to_num_tokens(
             Image.new("RGB", (w, h)))
         return ncols * nrows + nrows
 
-    def __call__(self, text, images, media, **kwargs):
-        assert media is not None
-        if isinstance(media, str):
-            media = [media]
-
-        mm_items = [{"type": "image_url", "image_url": url} for url in media]
-
-        logger.debug(f"text: {text}")
+    def __call__(self, text, images, **kwargs):
+        mm_items = []
+        if images:
+            mm_items = [{
+                "type": "image",
+                "base64": encode_base64_image(image)
+            } for image in images]
 
         conversation = [{
             "role": "user",
@@ -292,19 +303,20 @@ class MistralCommonImageProcessor:
 
         encoded = self.tokenizer.transformers_tokenizer.apply_chat_template(
             conversation, tokenize=True, return_dict=True, return_tensors='pt')
 
-        logger.debug(
-            f"encoded.pixel_values.shape: {encoded.pixel_values.shape}, encoded.input_ids: {encoded.input_ids[0][-20:]}"
-        )
-        logger.debug(
-            f"encoded.input_ids list: {self.tokenizer.transformers_tokenizer.apply_chat_template(conversation)}"
-        )
-
         processed = {
             "input_ids": encoded.input_ids,
-            "pixel_values": encoded.pixel_values.to(self.dtype),
-            "attention_mask": encoded.attention_mask,
-            "image_sizes": torch.tensor([encoded.pixel_values.shape[2:]])
         }
+
+        # "pixel_values" is absent in text-only mode for the VLM
+        if "pixel_values" in encoded:
+            processed.update({
+                "pixel_values":
+                encoded.pixel_values.to(self.dtype),
+                "attention_mask":
+                encoded.attention_mask,
+                "image_sizes":
+                torch.tensor([encoded.pixel_values.shape[2:]])
+            })
 
         return processed
@@ -376,7 +388,6 @@ class Mistral3InputProcessor(BaseMultimodalInputProcessor,
         self, inputs: TextPrompt, sampling_params: SamplingParams
     ) -> Tuple[List[int], ExtraProcessedInputs | None]:
         images = inputs.get("multi_modal_data", {}).get("image")
-        mm_processor_kwargs = inputs.get("mm_processor_kwargs", {})
         do_rescale = getattr(self.processor.image_processor, "do_rescale",
                              False)
         if images is not None and isinstance(images[0], torch.Tensor):
@@ -384,18 +395,15 @@ class Mistral3InputProcessor(BaseMultimodalInputProcessor,
             # format is "pt" (pytorch tensors), but not for "pil" (PIL images).
             do_rescale = False
 
-        if mm_processor_kwargs:
-            # Currently, we only support image modality in MistralCommonImageProcessor.
+        if images is not None:
             processed = self.processor(
                 text=inputs["prompt"],
                 images=images,
                 do_rescale=do_rescale,
-                **mm_processor_kwargs,
             )
         else:
             processed = self.text_processor(
                 text=inputs["prompt"],
-                images=images,
                 do_rescale=do_rescale,
             )
         input_ids = processed.pop("input_ids").tolist()[0]
diff --git a/tensorrt_llm/evaluate/lm_eval.py b/tensorrt_llm/evaluate/lm_eval.py
index a3a59c3f5f..4a877d75f4 100644
--- a/tensorrt_llm/evaluate/lm_eval.py
+++ b/tensorrt_llm/evaluate/lm_eval.py
@@ -52,7 +52,9 @@ class LmEvalWrapper(TemplateLM):
                  llm: Union[LLM, PyTorchLLM],
                  sampling_params: Optional[SamplingParams] = None,
                  streaming: bool = False,
-                 chat_template_kwargs: Optional[dict[str, Any]] = None):
+                 chat_template_kwargs: Optional[dict[str, Any]] = None,
+                 model_type: str | None = None,
+                 is_force_single_image: bool = False):
         super().__init__()
         self.llm = llm
         self.sampling_params = sampling_params
@@ -163,7 +165,9 @@ class MultimodalLmEvalWrapper(LmEvalWrapper):
                  sampling_params: Optional[SamplingParams] = None,
                  streaming: bool = False,
                  max_images: int = 999,
-                 chat_template_kwargs: Optional[dict[str, Any]] = None):
+                 chat_template_kwargs: Optional[dict[str, Any]] = None,
+                 model_type: str | None = None,
+                 is_force_single_image: bool = False):
         """
         Initialize the multimodal wrapper.
@@ -179,7 +183,9 @@ class MultimodalLmEvalWrapper(LmEvalWrapper):
         self.MULTIMODAL = True
         self.max_images = max_images
         self.chat_template_kwargs = chat_template_kwargs
-        self.model_type = self._get_model_type(llm)
+        self.model_type = model_type if model_type is not None else self._get_model_type(
+            llm)
+        self.is_force_single_image = is_force_single_image
 
         # NOTE: In TRT-LLM, currently we do not support interleaved text and image. Instead, we are adding image placeholders at the end of the text or at the beginning of the text.
         # So, until we support interleaved text and image, we set this to False.
@@ -287,9 +293,14 @@ class MultimodalLmEvalWrapper(LmEvalWrapper):
         prompt = prompt_inputs(prompt)
 
         # NOTE: Convert RGBA format to RGB format
-        images = [
-            convert_image_mode(img, "RGB") for img in media_data["visual"]
-        ]
+        if self.is_force_single_image:
+            # NOTE: This is a workaround to force a single image for models that only support single-image input.
+            images = [convert_image_mode(media_data["visual"][0], "RGB")]
+        else:
+            images = [
+                convert_image_mode(img, "RGB")
+                for img in media_data["visual"]
+            ]
         prompt["multi_modal_data"] = {"image": images}
 
         sampling_params = self._get_sampling_params(gen_kwargs)
@@ -429,14 +440,18 @@ class LmEvalEvaluator(Evaluator):
                  llm: Union[LLM, PyTorchLLM],
                  sampling_params: Optional[SamplingParams] = None,
                  streaming: bool = False,
-                 scores_filter: str = None) -> float:
+                 scores_filter: str = None,
+                 model_type: str = None,
+                 is_force_single_image: bool = False) -> float:
         import lm_eval
         lm_cls = MultimodalLmEvalWrapper if self.MULTIMODAL else LmEvalWrapper
         results = lm_eval.evaluate(
             lm=lm_cls(llm,
                       sampling_params=sampling_params,
                       streaming=streaming,
-                      chat_template_kwargs=self.chat_template_kwargs),
+                      chat_template_kwargs=self.chat_template_kwargs,
+                      model_type=model_type,
+                      is_force_single_image=is_force_single_image),
             task_dict=self.task_dict,
             limit=self.num_samples,
             apply_chat_template=self.apply_chat_template,
diff --git a/tensorrt_llm/inputs/utils.py b/tensorrt_llm/inputs/utils.py
index a6f7e49fa8..bbbd5f4f8f 100644
--- a/tensorrt_llm/inputs/utils.py
+++ b/tensorrt_llm/inputs/utils.py
@@ -774,12 +774,6 @@ def default_multimodal_input_loader(
                 mm_placeholder_counts=[mm_placeholder_counts])
 
         input = {"prompt": prompt}
-        # When the tokenizer is a MistralTokenizer, we need to keep the source media to handle in processor later.
-        from tensorrt_llm._torch.models.checkpoints.mistral.tokenizer import \
-            MistralTokenizer
-        if isinstance(tokenizer, MistralTokenizer):
-            input["mm_processor_kwargs"] = {"media": media}
-
         if mm_placeholder_counts:
             if mm_embeddings is not None:
                 input[
diff --git a/tests/integration/defs/accuracy/accuracy_core.py b/tests/integration/defs/accuracy/accuracy_core.py
index f96ac7d618..e30c6e2c2c 100644
--- a/tests/integration/defs/accuracy/accuracy_core.py
+++ b/tests/integration/defs/accuracy/accuracy_core.py
@@ -402,6 +402,8 @@ class MMMU(AccuracyTask):
                             is_multimodal=True,
                             apply_chat_template=True)
 
+    EVALUATE_KWARGS = dict(model_type=None, is_force_single_image=False)
+
 
 class PassKeyRetrieval64k(AccuracyTask):
     DATASET = "passkey_retrieval_64k"
diff --git a/tests/integration/defs/accuracy/references/mmlu.yaml b/tests/integration/defs/accuracy/references/mmlu.yaml
index a0e38d67c1..9cbd7a9f73 100644
--- a/tests/integration/defs/accuracy/references/mmlu.yaml
+++ b/tests/integration/defs/accuracy/references/mmlu.yaml
@@ -345,9 +345,9 @@ mistralai/Mistral-Nemo-12b-Base:
   - quant_algo: FP8
     accuracy: 69.66
 mistral/Mistral-Large-3-675B:
-  - accuracy: 87.54
+  - accuracy: 85.30
   - spec_dec_algo: Eagle
-    accuracy: 87.54
+    accuracy: 85.30
 nvidia/Nemotron-Super-V3:
   - accuracy: 81.07
   - quant_algo: NVFP4
diff --git a/tests/integration/defs/accuracy/references/mmmu.yaml b/tests/integration/defs/accuracy/references/mmmu.yaml
index a2fb8f4a77..37819c3f14 100644
--- a/tests/integration/defs/accuracy/references/mmmu.yaml
+++ b/tests/integration/defs/accuracy/references/mmmu.yaml
@@ -25,4 +25,5 @@ microsoft/Phi-4-multimodal-instruct:
 Qwen/Qwen3-VL-30B-A3B-Instruct:
   - accuracy: 55.33
 mistral/Mistral-Large-3-675B:
-  - accuracy: 60.00
+# Mistral Large 3 675B only supports single image input, so accuracy is lower.
+  - accuracy: 47
diff --git a/tests/integration/defs/accuracy/test_llm_api_pytorch_multimodal.py b/tests/integration/defs/accuracy/test_llm_api_pytorch_multimodal.py
index 78e0f3e401..c3a812b195 100644
--- a/tests/integration/defs/accuracy/test_llm_api_pytorch_multimodal.py
+++ b/tests/integration/defs/accuracy/test_llm_api_pytorch_multimodal.py
@@ -293,8 +293,19 @@ class TestMistralLarge3_675B(LlmapiAccuracyTestHarness):
         ],
     )
     def test_nvfp4_4gpus(
-        self, tp_size, pp_size, ep_size, attention_dp, cuda_graph, overlap_scheduler, moe_backend
+        self,
+        tp_size,
+        pp_size,
+        ep_size,
+        attention_dp,
+        cuda_graph,
+        overlap_scheduler,
+        moe_backend,
+        mocker,
     ):
+        mocker.patch.dict(
+            MMMU.EVALUATE_KWARGS, {"model_type": "mistral_large_3", "is_force_single_image": True}
+        )
         pytorch_config = dict(
             disable_overlap_scheduler=not overlap_scheduler,
             cuda_graph_config=CudaGraphConfig() if cuda_graph else None,
@@ -315,4 +326,4 @@ class TestMistralLarge3_675B(LlmapiAccuracyTestHarness):
             kv_cache_config=kv_cache_config,
         ) as llm:
             task = MMMU(self.MODEL_NAME)
-            task.evaluate(llm, sampling_params=self.sampling_params, model_type="mistral_large_3")
+            task.evaluate(llm, sampling_params=self.sampling_params)
diff --git a/tests/integration/test_lists/test-db/l0_gb200_multi_gpus.yml b/tests/integration/test_lists/test-db/l0_gb200_multi_gpus.yml
index 2241aea415..62c0af24f8 100644
--- a/tests/integration/test_lists/test-db/l0_gb200_multi_gpus.yml
+++ b/tests/integration/test_lists/test-db/l0_gb200_multi_gpus.yml
@@ -72,7 +72,7 @@ l0_gb200_multi_gpus:
   - accuracy/test_llm_api_pytorch.py::TestQwen3NextInstruct::test_nvfp4[tp4ep4-cutlass]
   - accuracy/test_llm_api_pytorch.py::TestQwen3NextInstruct::test_nvfp4[no_cuda_graph_overlap-cutlass]
   - accuracy/test_llm_api_pytorch.py::TestQwen3NextInstruct::test_nvfp4[tp4ep4-trtllm]
-  - accuracy/test_llm_api_pytorch.py::TestMistralLarge3_675B::test_nvfp4_4gpus[latency_moe_trtllm_eagle] TIMEOUT (90)
+  - accuracy/test_llm_api_pytorch_multimodal.py::TestMistralLarge3_675B::test_nvfp4_4gpus[latency_moe_trtllm] TIMEOUT (90)
 - condition:
     ranges:
       system_gpu_count:
@@ -105,4 +105,4 @@ l0_gb200_multi_gpus:
   - accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_4gpus_online_eplb[enable_configurable_moe-fp8]
   - accuracy/test_llm_api_pytorch.py::TestQwen3_235B_A22B::test_nvfp4_4gpus[latency_moe_trtllm_eagle3] TIMEOUT (90)
   - accuracy/test_llm_api_pytorch.py::TestMistralLarge3_675B::test_nvfp4_4gpus[latency_moe_trtllm] TIMEOUT (90)
-  - accuracy/test_llm_api_pytorch_multimodal.py::TestMistralLarge3_675B::test_nvfp4_4gpus[latency_moe_trtllm] TIMEOUT (90)
+  - accuracy/test_llm_api_pytorch.py::TestMistralLarge3_675B::test_nvfp4_4gpus[latency_moe_trtllm_eagle] TIMEOUT (90)
diff --git a/tests/integration/test_lists/waives.txt b/tests/integration/test_lists/waives.txt
index ec161196b8..563a38a76e 100644
--- a/tests/integration/test_lists/waives.txt
+++ b/tests/integration/test_lists/waives.txt
@@ -441,8 +441,6 @@ test_e2e.py::test_ptp_quickstart_advanced_2gpus_sm120[Nemotron-Super-49B-v1-BF16
 unittest/_torch/multi_gpu/test_mnnvl_allreduce.py::test_row_linear_residual_norm_fusion[no_fusion-strategy:8-dtype:bfloat16-hidden:8192-seqlen:[15]] SKIP (https://nvbugs/5761364)
 triton_server/test_triton.py::test_gpt_speculative_decoding[gpt-speculative-decoding] SKIP (https://nvbugs/5762854)
 accuracy/test_llm_api_pytorch.py::TestLlama3_1_8B_Instruct_RocketKV::test_auto_dtype SKIP (https://nvbugs/5762822)
-accuracy/test_llm_api_pytorch.py::TestMistralLarge3_675B::test_nvfp4_4gpus[latency_moe_trtllm] SKIP (https://nvbugs/5762852)
-accuracy/test_llm_api_pytorch.py::TestMistralLarge3_675B::test_nvfp4_4gpus[latency_moe_trtllm_eagle] SKIP (https://nvbugs/5762852)
 unittest/_torch/sampler/test_return_logits.py SKIP (https://nvbugs/5764627)
 examples/serve/test_serve.py::test_config_file_loading[--config] SKIP (https://nvbugs/5754977)
 full:RTXPro6000D/accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_nvfp4_4gpus[moe_backend=CUTLASS-mtp_nextn=2-tp2pp2-fp8kv=False-attention_dp=False-cuda_graph=False-overlap_scheduler=False-torch_compile=False] SKIP (https://nvbugs/5740075)
@@ -456,7 +454,6 @@ full:sm89/accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_
 accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_ngram SKIP (https://nvbugs/5768068)
 test_e2e.py::test_eagle3_output_consistency_4gpus[Qwen3/saved_models_Qwen3-235B-A22B_fp8_hf-Qwen3/qwen3-235B-eagle3] SKIP (https://nvbugs/5685010)
 examples/test_mistral.py::test_mistral_with_bf16_lora_torch[mistral-7b-v0.1] SKIP (https://nvbugs/5769855)
-accuracy/test_llm_api_pytorch_multimodal.py::TestMistralLarge3_675B::test_nvfp4_4gpus[latency_moe_trtllm] SKIP (TBD)
 accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_4gpus[dp4-trtllm-fp8] SKIP (https://nvbugs/5772396)
 full:sm100/accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_4gpus[dp4-trtllm-auto] SKIP (https://nvbugs/5772396)
 accuracy/test_llm_api_pytorch.py::TestGLM4_6::test_nvfp4_2_model_mtp[2model_trtllm] SKIP (https://nvbugs/5772360)