feat(models): Mistral3.1 VLM pytorch backend support (#5529)

Signed-off-by: William Zhang <133824995+2ez4bz@users.noreply.github.com>
2ez4bz 2025-07-09 13:17:40 -07:00 committed by GitHub
parent b61a717275
commit 87fe44fd29
8 changed files with 1224 additions and 4 deletions

View File

@@ -18,6 +18,7 @@ extend_skip_glob = [
"tensorrt_llm/_dlpack_utils.py",
"tensorrt_llm/_ipc_utils.py",
"tensorrt_llm/_mnnvl_utils.py",
"tensorrt_llm/_torch/models/modeling_pixtral.py",
"tensorrt_llm/disaggregated_params.py",
"tensorrt_llm/engine.py",
"tensorrt_llm/graph_rewriting.py",
@@ -30,6 +31,8 @@ extend_skip_glob = [
"tensorrt_llm/python_plugin.py",
"tensorrt_llm/sampling_params.py",
"tensorrt_llm/top_model_mixin.py",
"tests/unittest/_torch/modeling/test_modeling_mistral.py",
"tests/unittest/_torch/modeling/test_modeling_pixtral.py",
]
[tool.yapf]
@@ -45,6 +48,7 @@ ignore_patterns = [
"tensorrt_llm/_dlpack_utils.py",
"tensorrt_llm/_ipc_utils.py",
"tensorrt_llm/_mnnvl_utils.py",
"tensorrt_llm/_torch/models/modeling_pixtral.py",
"tensorrt_llm/disaggregated_params.py",
"tensorrt_llm/engine.py",
"tensorrt_llm/graph_rewriting.py",
@@ -57,6 +61,8 @@ ignore_patterns = [
"tensorrt_llm/python_plugin.py",
"tensorrt_llm/sampling_params.py",
"tensorrt_llm/top_model_mixin.py",
"tests/unittest/_torch/modeling/test_modeling_mistral.py",
"tests/unittest/_torch/modeling/test_modeling_pixtral.py",
]
[tool.codespell]
@@ -76,6 +82,7 @@ exclude = [
"tensorrt_llm/_dlpack_utils.py",
"tensorrt_llm/_ipc_utils.py",
"tensorrt_llm/_mnnvl_utils.py",
"tensorrt_llm/_torch/models/modeling_pixtral.py",
"tensorrt_llm/disaggregated_params.py",
"tensorrt_llm/engine.py",
"tensorrt_llm/graph_rewriting.py",
@@ -88,6 +95,8 @@ exclude = [
"tensorrt_llm/python_plugin.py",
"tensorrt_llm/sampling_params.py",
"tensorrt_llm/top_model_mixin.py",
"tests/unittest/_torch/modeling/test_modeling_mistral.py",
"tests/unittest/_torch/modeling/test_modeling_pixtral.py",
]
@@ -116,6 +125,7 @@ include = [
"tensorrt_llm/_dlpack_utils.py",
"tensorrt_llm/_ipc_utils.py",
"tensorrt_llm/_mnnvl_utils.py",
"tensorrt_llm/_torch/models/modeling_pixtral.py",
"tensorrt_llm/disaggregated_params.py",
"tensorrt_llm/engine.py",
"tensorrt_llm/graph_rewriting.py",
@@ -128,6 +138,8 @@ include = [
"tensorrt_llm/python_plugin.py",
"tensorrt_llm/sampling_params.py",
"tensorrt_llm/top_model_mixin.py",
"tests/unittest/_torch/modeling/test_modeling_mistral.py",
"tests/unittest/_torch/modeling/test_modeling_pixtral.py",
]
exclude = [
"**3rdparty/**",

View File

@@ -1,24 +1,43 @@
from typing import Any, Optional, Tuple
import dataclasses
import os
from typing import Any, Dict, List, Optional, Tuple
import torch
from torch import nn
from transformers import MistralConfig
from transformers import (AutoProcessor, AutoTokenizer, Mistral3Config,
MistralConfig, PretrainedConfig, PreTrainedModel)
from transformers.activations import ACT2FN
from tensorrt_llm._torch.attention_backend import AttentionMetadata
from tensorrt_llm._torch.attention_backend.interface import (
PositionalEmbeddingParams, RopeParams)
from tensorrt_llm._torch.model_config import ModelConfig
from tensorrt_llm._torch.models import modeling_pixtral
from tensorrt_llm._torch.models.modeling_multimodal_utils import \
fuse_input_embeds
from tensorrt_llm._torch.models.modeling_utils import (DecoderModel,
DecoderModelForCausalLM,
_load_weights_impl,
register_auto_model)
from tensorrt_llm._torch.modules.attention import Attention
from tensorrt_llm._torch.modules.decoder_layer import DecoderLayer
from tensorrt_llm._torch.modules.embedding import Embedding
from tensorrt_llm._torch.modules.gated_mlp import GatedMLP
from tensorrt_llm._torch.modules.linear import TensorParallelMode
from tensorrt_llm._torch.modules.linear import Linear, TensorParallelMode
from tensorrt_llm._torch.modules.rms_norm import RMSNorm
from tensorrt_llm._torch.speculative import SpecMetadata
from tensorrt_llm.functional import PositionEmbeddingType
from tensorrt_llm.inputs import (ExtraProcessedInputs, InputProcessor,
TextPrompt, register_input_processor)
from tensorrt_llm.llmapi import SamplingParams
from tensorrt_llm.logger import logger
_MULTIMODAL_ENV_NAME = "TLLM_MULTIMODAL_DISAGGREGATED"
# Make this a runtime lookup rather than a module-wide constant for easier unit testing.
def _is_disagg() -> bool:
return os.getenv(_MULTIMODAL_ENV_NAME, "0") == "1"
class MistralAttention(Attention):
@@ -187,3 +206,337 @@ class MistralForCausalLM(DecoderModelForCausalLM[MistralModel, MistralConfig]):
hidden_size=model_config.pretrained_config.hidden_size,
vocab_size=model_config.pretrained_config.vocab_size,
)
class Mistral3InputProcessor(InputProcessor):
def __init__(
self,
model_path: str,
model_config: PretrainedConfig,
tokenizer: Optional[AutoTokenizer],
trust_remote_code: bool = False,
):
if tokenizer is None:
tokenizer = AutoTokenizer.from_pretrained(model_path,
use_fast=False)
# To abide by the `InputProcessor` interface.
self.model_path = model_path
self.model_config = model_config
self.tokenizer = tokenizer
self._device = "cuda"
self._processor = AutoProcessor.from_pretrained(model_path,
use_fast=False)
@torch.inference_mode()
def __call__(
self, inputs: TextPrompt, sampling_params: SamplingParams
) -> Tuple[List[int], Optional[ExtraProcessedInputs]]:
images = inputs.get("multi_modal_data", {}).get("image")
do_rescale = self._processor.image_processor.do_rescale
if images is not None and isinstance(images[0], torch.Tensor):
# The default multimodal input loader will normalize images to [0, 1] when the requested
# format is "pt" (pytorch tensors), but not for "pil" (PIL images).
do_rescale = False
processed = self._processor(
text=inputs["prompt"],
images=images,
do_rescale=do_rescale,
)
input_ids = processed.pop("input_ids").tolist()[0]
# Remaining in `processed`:
# * "attention_mask": [B, num_input_tokens]
# * "pixel_values": [B, C, H, W]
# * "image_sizes": [B, 2]
extra_processed_inputs = None
pixel_values = processed.get("pixel_values")
if pixel_values is not None:
# We have no use for the `attention_mask`.
processed.pop("attention_mask")
processed = processed.to(self._device)
# NOTE: `processed` is a dict-like object, but not actually a dict.
extra_processed_inputs = {
"multimodal_data": {
"image": {
**processed
}
},
}
return input_ids, extra_processed_inputs
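# Illustrative usage sketch (hypothetical, for reference only): the processor above is normally
# constructed by the LLM API through the registration decorators below, but it can be exercised
# directly. The model path is an assumed Hugging Face repo id.
def _example_input_processor_usage():
    from PIL import Image
    processor = Mistral3InputProcessor(
        model_path="mistralai/Mistral-Small-3.1-24B-Instruct-2503",  # assumed repo id
        model_config=None,  # a `PretrainedConfig` instance in real usage
        tokenizer=None,  # falls back to `AutoTokenizer.from_pretrained(model_path)`
    )
    prompt = {
        "prompt": "Describe this image: [IMG]",
        "multi_modal_data": {"image": [Image.new("RGB", (64, 64))]},
    }
    # Returns the token ids plus, when an image is present, a nested dict holding
    # "pixel_values" and "image_sizes" under `extra["multimodal_data"]["image"]`.
    input_ids, extra = processor(prompt, sampling_params=None)
    return input_ids, extra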
@register_auto_model("Mistral3ForConditionalGeneration")
# The decorator below tells the registry which input processor to create for this model in `tensorrt_llm/llmapi/llm.py`.
@register_input_processor(Mistral3InputProcessor, model_type="mistral3")
class Mistral3VLM(PreTrainedModel):
"""Mistral3VLM implementation for TRTLLM.
NOTE: for the time being, image tokens are only placed after the text (see
`tensorrt_llm/inputs/utils.py`).
"""
def __init__(
self,
model_config: ModelConfig[Mistral3Config],
):
if _is_disagg():
raise NotImplementedError(
"Mistral3VLM does not support disaggregated inference yet. Please unset "
f"the {_MULTIMODAL_ENV_NAME} environment variable, or set it to '0'."
)
config = model_config.pretrained_config
super().__init__(config)
self.model_config = model_config
llm_model_config = self._get_sub_model_config(model_config,
"text_config")
self.llm = MistralForCausalLM(llm_model_config)
self._device = "cuda"
vision_model_config = self._get_sub_model_config(
model_config, "vision_config")
self._vision_tower = modeling_pixtral.PixtralVisionModel(
vision_model_config)
self._multi_modal_projector = Mistral3MultiModalProjector(model_config)
vision_feature_layer = config.vision_feature_layer
if vision_feature_layer != -1:
raise ValueError(
f"Using intermediate layers ({vision_feature_layer}) in the `PixtralVisionModel` "
f"is not supported. Please use `vision_feature_layer=-1`.")
self.model_dtype = getattr(config, "torch_dtype", torch.bfloat16)
self._image_token_ids = torch.tensor([config.image_token_index],
dtype=torch.int32,
device=self._device)
self._post_config()
# This is necessary because the executor looks at
# `model.model_config.pretrained_config.vocab_size`.
def _post_config(self):
self.config = self.llm.config
self.model_config.pretrained_config = self.llm.config
def load_weights(self, weights: Dict, *args, **kwargs):
llm_weights = _filter_weights(weights, "language_model.")
self.llm.load_weights(llm_weights, *args, **kwargs)
vit_weights = _filter_weights(weights, "vision_tower.")
self._vision_tower.load_weights(vit_weights, *args, **kwargs)
mm_projector_weights = _filter_weights(weights,
"multi_modal_projector.")
# `_load_weights_impl` assumes `config.hidden_size` exists, which is not the case for the
# top-level `Mistral3Config`.
self._multi_modal_projector.load_state_dict(mm_projector_weights)
def infer_max_seq_len(self) -> int:
return self.llm.infer_max_seq_len()
@torch.inference_mode()
def forward(
self,
attn_metadata: AttentionMetadata,
input_ids: Optional[torch.LongTensor] = None,
position_ids: Optional[torch.LongTensor] = None,
inputs_embeds: Optional[torch.Tensor] = None,
return_context_logits: bool = False,
**kwargs,
) -> torch.Tensor:
"""Forward method."""
num_context_requests, num_generation_requests = attn_metadata.num_contexts, attn_metadata.num_generations
logger.debug(f"{num_context_requests=}, {num_generation_requests=}")
multimodal_params = kwargs.get("multimodal_params", [])
image_features = []
multimodal_params_len = len(multimodal_params)
if multimodal_params_len > 0:
if multimodal_params_len != num_context_requests:
raise RuntimeError(
f"Number of multimodal tensors ({multimodal_params_len}) should be equal to number of "
f"context requests ({num_context_requests}) in the batch.")
# NOTES:
# 1. the pixel values in `multimodal_data["image"]` might vary in (height, width) between
# images, making them unsafe to batch in general. The input processor also cannot produce
# them in a batch, since it is always called with a single input - otherwise, we would
# have been able to naturally leverage the padding / resizing capabilities of the underlying
# `PixtralProcessor`.
# 2. After each `pixel_values` tensor has gone through the vision tower's `patch_conv` layer,
# they are divided into patches that are then concatenated in order to treat them as a
# single "sequence" in the vision tower's attention layers, so some form of batching still
# happens in the vision tower.
image_features = [
self._get_image_features(**x.multimodal_data["image"])
for x in multimodal_params
]
input_ids, inputs_embeds = fuse_input_embeds(
embedding_layer=self.llm.model.embed_tokens,
input_ids=input_ids,
mm_embeds=image_features,
mm_token_ids=self._image_token_ids,
)
return self.llm.forward(
attn_metadata=attn_metadata,
input_ids=input_ids,
position_ids=position_ids,
inputs_embeds=inputs_embeds,
return_context_logits=return_context_logits,
)
@staticmethod
def _get_sub_model_config(
model_config: ModelConfig[MistralConfig],
name: str,
) -> ModelConfig:
# Extract the subconfig from the `transformers` config and shove it into our own
# `ModelConfig` class.
sub_model_config: ModelConfig[MistralConfig] = dataclasses.replace(
model_config,
pretrained_config=getattr(model_config.pretrained_config, name),
)
# Make sure some fields that are not explicitly included in the sub config, but present
# in the top-level config, are replicated.
if (hasattr(sub_model_config.pretrained_config, "torch_dtype")
and sub_model_config.pretrained_config.torch_dtype is None):
sub_model_config.pretrained_config.torch_dtype = model_config.pretrained_config.torch_dtype
return sub_model_config
# Original implementation:
# https://github.com/huggingface/transformers/blob/v4.51.3/src/transformers/models/mistral3/
# modeling_mistral3.py#L341
def _get_image_features(
self,
pixel_values: torch.Tensor,
image_sizes: torch.Tensor,
):
image_outputs = self._vision_tower(
pixel_values=pixel_values,
image_sizes=image_sizes,
)
image_features = self._multi_modal_projector(image_outputs.squeeze(0),
image_sizes)
return image_features
# Original implementation:
# https://github.com/huggingface/transformers/blob/v4.51.3/src/transformers/models/mistral3/modeling_mistral3.py#L66
# NOTE: the main difference is the usage of TRTLLM's own `Linear` layer over pytorch's built-in layer.
class Mistral3PatchMerger(torch.nn.Module):
def __init__(self, model_config: ModelConfig[Mistral3Config]):
super().__init__()
config = model_config.pretrained_config
# Both of the attributes below are needed in order to use `_load_weights_impl`.
self.model_config = model_config
self.config = config
hidden_size = config.vision_config.hidden_size
self._spatial_merge_size = config.spatial_merge_size
self._patch_size = config.vision_config.patch_size
self.merging_layer = Linear(
in_features=hidden_size * self._spatial_merge_size**2,
out_features=hidden_size,
bias=False,
dtype=config.torch_dtype,
)
@torch.inference_mode()
def forward(self, image_features: torch.Tensor,
image_sizes: torch.Tensor) -> torch.Tensor:
image_sizes = [(image_size[0] // self._patch_size,
image_size[1] // self._patch_size)
for image_size in image_sizes]
tokens_per_image = [h * w for h, w in image_sizes]
d = image_features.shape[-1]
permuted_tensor = []
for image_index, image_tokens in enumerate(
image_features.split(tokens_per_image)):
# Reshape image_tokens into a 2D grid
h, w = image_sizes[image_index]
image_grid = image_tokens.view(h, w, d).permute(2, 0,
1).unsqueeze(0)
grid = torch.nn.functional.unfold(
image_grid,
kernel_size=self._spatial_merge_size,
stride=self._spatial_merge_size)
grid = grid.view(d * self._spatial_merge_size**2, -1).t()
permuted_tensor.append(grid)
image_features = torch.cat(permuted_tensor, dim=0)
image_features = self.merging_layer(image_features)
return image_features
def load_weights(self, weights):
_load_weights_impl(self, weights)
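# Shape sketch for the unfold-based merge above (standalone, illustration only): with
# `spatial_merge_size=2`, a 4 x 4 grid of d-dimensional patch tokens is regrouped into
# (4 // 2) * (4 // 2) = 4 tokens of dimension 4 * d before `merging_layer` projects them back down.
def _example_patch_merge_shapes():
    d, h, w, merge = 8, 4, 4, 2
    tokens = torch.randn(h * w, d)
    grid = tokens.view(h, w, d).permute(2, 0, 1).unsqueeze(0)  # [1, d, h, w]
    merged = torch.nn.functional.unfold(grid, kernel_size=merge, stride=merge)
    merged = merged.view(d * merge**2, -1).t()  # [(h // merge) * (w // merge), d * merge**2]
    assert merged.shape == (4, 32)
    return merged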
# Original implementation:
# https://github.com/huggingface/transformers/blob/v4.51.3/src/transformers/models/mistral3/
# modeling_mistral3.py#L104C1-L127C29
class Mistral3MultiModalProjector(torch.nn.Module):
def __init__(self, model_config: ModelConfig[Mistral3Config]):
super().__init__()
config = model_config.pretrained_config
# Both of the attributes below are needed in order to use `_load_weights_impl`.
self.model_config = model_config
self.config = config
dtype = config.torch_dtype
self.norm = RMSNorm(
hidden_size=config.vision_config.hidden_size,
# NOTE: the original implementation actually does not look at the config for this value.
# We therefore hardcode the default value `1e-6` from `Mistral3RMSNorm`.
eps=1e-6,
dtype=dtype,
)
self.patch_merger = Mistral3PatchMerger(model_config)
# The projector input width is `hidden_size` times the number of vision feature layers.
num_feature_layers = 1 if isinstance(config.vision_feature_layer,
int) else len(
config.vision_feature_layer)
self.linear_1 = Linear(
in_features=config.vision_config.hidden_size * num_feature_layers,
out_features=config.text_config.hidden_size,
bias=config.multimodal_projector_bias,
dtype=dtype,
)
self.act = ACT2FN[config.projector_hidden_act]
self.linear_2 = Linear(
in_features=config.text_config.hidden_size,
out_features=config.text_config.hidden_size,
bias=config.multimodal_projector_bias,
dtype=dtype,
)
@torch.inference_mode()
def forward(self, image_features: torch.Tensor, image_sizes: torch.Tensor):
image_features = self.norm(image_features)
image_features = self.patch_merger(image_features, image_sizes)
hidden_states = self.linear_1(image_features)
hidden_states = self.act(hidden_states)
hidden_states = self.linear_2(hidden_states)
return hidden_states
def load_weights(self, weights):
_load_weights_impl(self, weights)
def _filter_weights(weights: Dict[str, torch.Tensor],
prefix: str) -> Dict[str, torch.Tensor]:
return {
name[len(prefix):]: weight
for name, weight in weights.items() if name.startswith(prefix)
}
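A minimal sketch of how `_filter_weights` routes a Hugging Face checkpoint into the three submodules in `Mistral3VLM.load_weights` (key names assumed to follow the upstream `Mistral3ForConditionalGeneration` layout):
weights = {
    "language_model.model.embed_tokens.weight": torch.empty(8, 4),
    "vision_tower.patch_conv.weight": torch.empty(4, 3, 2, 2),
    "multi_modal_projector.linear_1.weight": torch.empty(4, 4),
}
assert set(_filter_weights(weights, "language_model.")) == {"model.embed_tokens.weight"}
assert set(_filter_weights(weights, "vision_tower.")) == {"patch_conv.weight"}
assert set(_filter_weights(weights, "multi_modal_projector.")) == {"linear_1.weight"}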

View File

@@ -0,0 +1,311 @@
from typing import List
import torch
import transformers
from tensorrt_llm._torch import model_config as model_config_lib
from tensorrt_llm._torch.attention_backend import interface as attention_interface
from tensorrt_llm._torch.attention_backend import utils as attention_utils
from tensorrt_llm._torch.models import modeling_utils
from tensorrt_llm._torch.modules import attention as trtllm_attention
from tensorrt_llm._torch.modules import gated_mlp as trtllm_gated_mlp
from tensorrt_llm._torch.modules import rms_norm as trtllm_rmsnorm
class PixtralAttention(trtllm_attention.Attention):
def __init__(
self,
model_config: model_config_lib.ModelConfig[transformers.PixtralVisionConfig],
layer_idx: int,
):
config = model_config.pretrained_config
pos_embd_params = None
max_position_embeddings = None
super().__init__(
hidden_size=config.hidden_size,
num_attention_heads=config.num_attention_heads,
num_key_value_heads=config.num_attention_heads,
max_position_embeddings=max_position_embeddings,
bias=False,
pos_embd_params=pos_embd_params,
layer_idx=layer_idx,
dtype=getattr(config, "torch_dtype", torch.float32),
config=model_config,
# Pixtral first needs to compute positional embeddings using its own
# `PixtralRotaryEmbedding`.
rope_fusion=False,
)
class PixtralAttentionLayer(torch.nn.Module):
def __init__(
self,
config: model_config_lib.ModelConfig[transformers.PixtralVisionConfig],
layer_idx: int,
):
super().__init__()
hidden_size = config.pretrained_config.hidden_size
dtype = config.pretrained_config.torch_dtype
self.attention_norm = trtllm_rmsnorm.RMSNorm(
hidden_size=hidden_size,
eps=1e-5,
dtype=dtype,
)
pretrained_config = config.pretrained_config
if pretrained_config.hidden_act != "silu":
raise ValueError(
"Only 'silu' is accepted as the activation function for the MLP in "
f"{self.__class__.__name__}. Got: {pretrained_config.hidden_act}."
)
self.feed_forward = trtllm_gated_mlp.GatedMLP(
hidden_size=pretrained_config.hidden_size,
intermediate_size=pretrained_config.intermediate_size,
bias=False,
activation=torch.nn.functional.silu,
dtype=pretrained_config.torch_dtype,
config=config,
)
self.attention = PixtralAttention(config, layer_idx)
self.ffn_norm = trtllm_rmsnorm.RMSNorm(
hidden_size=hidden_size,
eps=1e-5,
dtype=dtype,
)
def forward(
self,
hidden_states: torch.Tensor,
attn_metadata: attention_interface.AttentionMetadata,
position_ids: torch.Tensor,
):
residual = hidden_states
hidden_states = self.attention_norm(hidden_states)
hidden_states = self.attention(
# NOTE: although the `position_ids` are not needed to compute RoPE here (the rotary
# embeddings are pre-computed), `RotaryEmbedding` internally skips applying the cos / sin
# vectors to the query / key tensors when `position_ids=None`.
position_ids=position_ids,
hidden_states=hidden_states,
attn_metadata=attn_metadata,
attention_mask=attention_interface.PredefinedAttentionMask.FULL,
)
hidden_states = residual + hidden_states
residual = hidden_states
hidden_states = self.ffn_norm(hidden_states)
hidden_states = self.feed_forward(hidden_states)
hidden_states = residual + hidden_states
return hidden_states
# Original implementation:
# https://github.com/huggingface/transformers/blob/v4.51.3/src/transformers/models/pixtral/modeling_pixtral.py#L279
class PixtralTransformer(torch.nn.Module):
def __init__(self, config: model_config_lib.ModelConfig[transformers.PixtralVisionConfig]):
super().__init__()
self.layers = torch.nn.ModuleList()
for i in range(config.pretrained_config.num_hidden_layers):
self.layers.append(PixtralAttentionLayer(config=config, layer_idx=i))
self._head_dim = config.pretrained_config.head_dim
self._num_heads = config.pretrained_config.num_attention_heads
def forward(
self,
inputs_embeds: torch.Tensor,
position_embeddings: torch.Tensor,
position_ids: torch.Tensor,
attn_metadata: attention_interface.AttentionMetadata,
):
if inputs_embeds.ndim == 3:
batch_size, patches, _ = inputs_embeds.shape
elif inputs_embeds.ndim == 2:
batch_size = 1
patches = inputs_embeds.size(0)
rope_function = _RopeFunction(
batch_size=batch_size,
patches=patches,
num_heads=self._num_heads,
head_dim=self._head_dim,
position_embeddings=position_embeddings,
)
hidden_states = inputs_embeds
for encoder_layer in self.layers:
# The way pixtral applies rope is by:
# 1. Computing the `position_ids` using the `patch_embeds` (which are essentially a
# sliced + concat'ed output of the conv2d layer), using their positions in a meshgrid.
# 2. Computing `position_embeddings` once for the entire transformer portion.
# See: https://github.com/huggingface/transformers/blob/v4.51.3/src/transformers/
# models/pixtral/modeling_pixtral.py#L494
# 3. These `position_embeddings` are then the ones used to apply rope in each attention
# layer.
# By substituting the `encoder_layer.attention.rotary_emb` to use `_RopeFunction`, which
# has these `position_embeddings` as an attribute, we can reuse the embeddings + application
# logic for each encoder layer.
encoder_layer.attention.rotary_emb = rope_function
layer_outputs = encoder_layer(
hidden_states=hidden_states,
attn_metadata=attn_metadata,
position_ids=position_ids,
)
hidden_states = layer_outputs
return hidden_states
# Original implementation:
# https://github.com/huggingface/transformers/blob/v4.51.3/src/transformers/models/pixtral/modeling_pixtral.py#L440
@modeling_utils.register_auto_model("PixtralVisionModel")
class PixtralVisionModel(torch.nn.Module):
def __init__(
self, model_config: model_config_lib.ModelConfig[transformers.PixtralVisionConfig]
):
super().__init__()
tp_size = model_config.mapping.tp_size
# TODO: implement support for `tp_size > 1`.
if tp_size > 1:
raise NotImplementedError(
f"Mistral3VLM does not support `mapping.tp_size > 1` yet (got {tp_size})."
)
# Both of the attributes below are needed in order to use `_load_weights_impl`.
self.model_config = model_config
self.config: transformers.PixtralVisionConfig = model_config.pretrained_config
self.patch_conv = torch.nn.Conv2d(
in_channels=self.config.num_channels,
out_channels=self.config.hidden_size,
kernel_size=self.config.patch_size,
stride=self.config.patch_size,
bias=False,
)
self._patch_size = self.config.patch_size
self.ln_pre = trtllm_rmsnorm.RMSNorm(
hidden_size=self.config.hidden_size,
eps=1e-5,
dtype=self.config.torch_dtype,
)
self.transformer = PixtralTransformer(model_config)
self._patch_positional_embedding = (
transformers.models.pixtral.modeling_pixtral.PixtralRotaryEmbedding(self.config)
)
self._metadata_cls = attention_utils.get_attention_backend(
model_config.attn_backend
).Metadata
@torch.inference_mode()
def forward(
self,
pixel_values: torch.Tensor,
image_sizes: torch.Tensor,
):
with torch.autocast(device_type="cuda", dtype=self.config.torch_dtype):
patch_embeds = self.patch_conv(pixel_values)
patch_embeds_list = [
embed[..., : (size[0] // self._patch_size), : (size[1] // self._patch_size)]
for embed, size in zip(patch_embeds, image_sizes)
]
patch_embeds = torch.cat([p.flatten(1).T for p in patch_embeds_list], dim=0)
patch_embeds = self.ln_pre(patch_embeds)
position_ids = transformers.models.pixtral.modeling_pixtral.position_ids_in_meshgrid(
patch_embeds_list, max_width=self.config.image_size // self.config.patch_size
)
position_embeddings = self._patch_positional_embedding(patch_embeds, position_ids)
attn_metadata = self._prepare_attn_metadata(
# The `torch.cat` that creates the `patch_embeds` flattens the conv features from multiple
# images into a single sequence, which is why the batch size is hardcoded to 1 here.
batch_size=1,
seq_len=position_ids.size(0),
)
out = self.transformer(
patch_embeds,
position_embeddings=position_embeddings,
position_ids=position_ids,
attn_metadata=attn_metadata,
)
return out
def load_weights(self, weights):
modeling_utils._load_weights_impl(self, weights)
def _prepare_attn_metadata(self, batch_size: int, seq_len: int):
request_ids = list(range(1, batch_size + 1))
prompt_lens = [seq_len] * batch_size
attn_metadata = self._metadata_cls(
seq_lens=torch.tensor([seq_len] * batch_size, dtype=torch.int),
num_contexts=batch_size,
max_num_requests=batch_size,
max_num_tokens=seq_len * batch_size,
kv_cache_manager=None,
request_ids=request_ids,
prompt_lens=prompt_lens,
)
attn_metadata.max_seq_len = seq_len * batch_size
attn_metadata.prepare()
return attn_metadata
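# Sequence-length sketch (hypothetical helper, illustration only): `forward` above flattens all
# images in a call into a single context "request", so the attention sequence length is simply
# the total patch count across images.
def _example_total_patch_tokens(image_sizes, patch_size=14):
    # e.g. [(224, 224), (140, 336)] with patch_size=14 gives 16 * 16 + 10 * 24 = 496 tokens.
    return sum((h // patch_size) * (w // patch_size) for h, w in image_sizes)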
class _RopeFunction:
def __init__(
self,
batch_size: int,
patches: int,
num_heads: int,
head_dim: int,
position_embeddings: torch.Tensor,
):
self._batch_size = batch_size
self._patches = patches
self._num_heads = num_heads
self._head_dim = head_dim
self._cos, self._sin = position_embeddings
# This signature matches that of
# `tensorrt_llm/_torch/modules/rotary_embedding.py::RotaryEmbedding.forward` so that we are
# able to override the `PixtralAttention.rotary_emb` attribute.
@torch.no_grad()
def __call__(
self,
# Unused.
position_ids: torch.Tensor,
# Assumed to be in the order `[q, k]`.
targets: List[torch.Tensor],
) -> List[torch.Tensor]:
if len(targets) != 2:
raise ValueError("Expected exactly two targets [q, k].")
# TODO: see if we can reuse `RotaryEmbedding.apply_rotary_pos_emb`.
orig_shape = targets[0].shape
q_embed, k_embed = transformers.models.pixtral.modeling_pixtral.apply_rotary_pos_emb(
q=targets[0]
.view(
self._batch_size,
self._patches,
self._num_heads,
self._head_dim,
)
.transpose(1, 2),
k=targets[1]
.view(
self._batch_size,
self._patches,
self._num_heads,
self._head_dim,
)
.transpose(1, 2),
cos=self._cos,
sin=self._sin,
unsqueeze_dim=0,
)
q_embed = q_embed.transpose(2, 1).reshape(orig_shape)
k_embed = k_embed.transpose(2, 1).reshape(orig_shape)
return [q_embed, k_embed]
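A minimal sketch of the `rotary_emb` substitution described in `PixtralTransformer.forward` (dummy cos / sin values chosen so the rotation is the identity, which makes the shape round trip easy to check):
num_tokens, num_heads, head_dim = 6, 4, 8
cos = torch.ones(num_tokens, head_dim)
sin = torch.zeros(num_tokens, head_dim)
rope = _RopeFunction(
    batch_size=1,
    patches=num_tokens,
    num_heads=num_heads,
    head_dim=head_dim,
    position_embeddings=(cos, sin),
)
q = torch.randn(num_tokens, num_heads * head_dim)
k = torch.randn(num_tokens, num_heads * head_dim)
q_rot, k_rot = rope(position_ids=None, targets=[q, k])
assert q_rot.shape == q.shape and torch.allclose(q_rot, q)  # cos=1, sin=0 is the identity rotation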

View File

@@ -184,13 +184,15 @@ SUPPORTED_GEMMA_MODEL_GROUP = ["gemma3"]
SUPPORTED_LLAMA_MODEL_GROUP = ["mllama", "llama4"]
SUPPORTED_LLAVA_IMAGE_MODEL_GROUP = ["llava_llama", "llava_next"]
SUPPORTED_LLAVA_VIDEO_MODEL_GROUP = ["llava_llama"]
SUPPORTED_MISTRAL_IMAGE_MODEL_GROUP = ["mistral3"]
SUPPORTED_HYPERCLOVAX_MODEL_GROUP = ["hyperclovax_vlm"]
ALL_SUPPORTED_IMAGE_MODELS = SUPPORTED_QWEN_MODEL_GROUP \
+ SUPPORTED_LLAMA_MODEL_GROUP \
+ SUPPORTED_LLAVA_IMAGE_MODEL_GROUP \
+ SUPPORTED_HYPERCLOVAX_MODEL_GROUP \
+ SUPPORTED_GEMMA_MODEL_GROUP
+ SUPPORTED_GEMMA_MODEL_GROUP \
+ SUPPORTED_MISTRAL_IMAGE_MODEL_GROUP
ALL_SUPPORTED_VIDEO_MODELS = SUPPORTED_QWEN_MODEL_GROUP \
+ SUPPORTED_LLAVA_VIDEO_MODEL_GROUP
@@ -217,6 +219,10 @@ PLACEHOLDER_PLACEMENT_MAP = {
"mllama": MultimodalPlaceholderPlacement.BEFORE_TEXT,
"hyperclovax_vlm": MultimodalPlaceholderPlacement.AFTER_TEXT,
"gemma3": MultimodalPlaceholderPlacement.BEFORE_TEXT,
# NOTE: for mistral3 multimodal models, the image placeholder does not strictly have to come after the text.
# Ref: https://github.com/mistralai/mistral-common/blob/039465db2bdc0486df36365c9bdb428188482a18/
# src/mistral_common/tokens/tokenizers/base.py#L326
"mistral3": MultimodalPlaceholderPlacement.AFTER_TEXT,
}
assert len(PLACEHOLDER_PLACEMENT_MAP) == len(ALL_SUPPORTED_MULTIMODAL_MODELS)
@@ -247,6 +253,10 @@ def retrieve_multimodal_placeholder(model_type: str, modality: str,
'<|im_start|>user (vector)\n<|dummy3|><|im_end|>\n' + \
'<|im_start|>image/aux\n다음 중 ocr은 사진에서 검출된 글자이고, lens_keyword는 사진에서 추출된 keyword와 bbox 위치입니다.' + \
'bbox는 0~1 사이로 정규화된 [x1, y1, x2, y2]의 형태입니다. 참고하여 답변하세요. {"ocr": "", "lens_keywords": "", "lens_local_keywords": ""}'
elif model_type in SUPPORTED_MISTRAL_IMAGE_MODEL_GROUP:
# Ref: https://github.com/mistralai/mistral-common/blob/26a6bb3a07ee0b78a3808f2797f23e1d28514b93/
# src/mistral_common/tokens/tokenizers/base.py#L60
return "[IMG]"
raise TypeError(
f"For image modality, only {ALL_SUPPORTED_IMAGE_MODELS} are supported but got {model_type}"
)
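A minimal sketch of how the placeholder and the AFTER_TEXT placement combine for mistral3 (the chat-template handling done elsewhere in the input pipeline is omitted):
placeholder = "[IMG]"  # what retrieve_multimodal_placeholder("mistral3", "image", ...) returns
text = "Describe this image."
prompt = text + placeholder  # AFTER_TEXT appends the image placeholder(s) to the user text
assert prompt == "Describe this image.[IMG]"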

View File

@@ -1963,9 +1963,12 @@ def test_ptp_quickstart_advanced_mixed_precision(llm_root, llm_venv):
("llava-v1.6-mistral-7b", "llava-v1.6-mistral-7b-hf"),
("qwen2-vl-7b-instruct", "Qwen2-VL-7B-Instruct"),
("qwen2.5-vl-7b-instruct", "Qwen2.5-VL-7B-Instruct"),
("mistral-small-3.1-24b-instruct", "Mistral-Small-3.1-24B-Instruct-2503"),
])
def test_ptp_quickstart_multimodal(llm_root, llm_venv, model_name, model_path,
modality, use_cuda_graph):
# NOTE: individual tests need to be enabled in
# tests/integration/test_lists/qa/examples_test_list.txt
llm_venv.run_cmd(
['-m', 'pip', 'install', 'flash-attn==2.7.3', '--no-build-isolation'])
@@ -2051,6 +2054,16 @@ def test_ptp_quickstart_multimodal(llm_root, llm_venv, model_name, model_path,
["earth", "rotating", "night", "lights", "cities"],
],
},
"mistral-small-3.1-24b-instruct": {
"image": [
[
"dramatic", "seascape", "stormy", "turbulent", "waves",
"rough"
],
["scenic", "rock", "landscape", "snow", "formation"],
["highway", "traffic", "directions", "lanes", "Jurong"],
],
},
}
cmd = [

View File

@@ -523,6 +523,8 @@ test_e2e.py::test_ptp_quickstart_multimodal[qwen2.5-vl-7b-instruct-Qwen2.5-VL-7B
test_e2e.py::test_ptp_quickstart_multimodal[qwen2.5-vl-7b-instruct-Qwen2.5-VL-7B-Instruct-image-True]
test_e2e.py::test_ptp_quickstart_multimodal[qwen2.5-vl-7b-instruct-Qwen2.5-VL-7B-Instruct-video-False]
test_e2e.py::test_ptp_quickstart_multimodal[qwen2.5-vl-7b-instruct-Qwen2.5-VL-7B-Instruct-video-True]
test_e2e.py::test_ptp_quickstart_multimodal[mistral-small-3.1-24b-instruct-Mistral-Small-3.1-24B-Instruct-2503-image-True]
test_e2e.py::test_ptp_quickstart_multimodal[mistral-small-3.1-24b-instruct-Mistral-Small-3.1-24B-Instruct-2503-image-False]
test_e2e.py::test_ptp_quickstart_bert[VANILLA-BertForSequenceClassification-bert/bert-base-uncased-yelp-polarity]
test_e2e.py::test_ptp_quickstart_bert[TRTLLM-BertForSequenceClassification-bert/bert-base-uncased-yelp-polarity]
test_e2e.py::test_ptp_star_attention_example[Llama3.1-8B-BF16-llama-3.1-model/Meta-Llama-3.1-8B]

View File

@@ -0,0 +1,415 @@
import contextlib
import os
from typing import Any, Dict
from unittest import mock
import pytest
import torch
import transformers
from utils.util import getSMVersion
import tensorrt_llm
from tensorrt_llm import mapping as mapping_lib
from tensorrt_llm._torch import metadata as metadata_lib
from tensorrt_llm._torch import model_config as model_config_lib
from tensorrt_llm._torch.attention_backend import utils as attention_utils
from tensorrt_llm._torch.models import modeling_mistral
from tensorrt_llm._torch.pyexecutor import cuda_graph_runner, resource_manager
from tensorrt_llm.bindings import executor as executor_lib
from tensorrt_llm.models import modeling_utils
@pytest.fixture
def mistral_small_3_1_24b_config():
return {
"architectures": ["Mistral3ForConditionalGeneration"],
"image_token_index": 10,
"model_type": "mistral3",
"multimodal_projector_bias": False,
"projector_hidden_act": "gelu",
"spatial_merge_size": 2,
"text_config": {
"attention_dropout": 0.0,
"head_dim": 128,
"hidden_act": "silu",
"hidden_size": 5120,
"initializer_range": 0.02,
"intermediate_size": 32768,
"max_position_embeddings": 131072,
"model_type": "mistral",
"num_attention_heads": 32,
"num_hidden_layers": 40,
"num_key_value_heads": 8,
"rms_norm_eps": 1e-05,
"rope_theta": 1000000000.0,
"sliding_window": None,
"use_cache": True,
"vocab_size": 131072,
},
"torch_dtype": "bfloat16",
"transformers_version": "4.50.0.dev0",
"vision_config": {
"attention_dropout": 0.0,
"head_dim": 64,
"hidden_act": "silu",
"hidden_size": 1024,
"image_size": 1540,
"initializer_range": 0.02,
"intermediate_size": 4096,
"model_type": "pixtral",
"num_attention_heads": 16,
"num_channels": 3,
"num_hidden_layers": 24,
"patch_size": 14,
"rope_theta": 10000.0,
},
"vision_feature_layer": -1,
}
def reduce_mistral_config(
mem_for_full_model: int, config_dict: Dict[str, Any], default_num_layers: int = 32
):
_, total_mem = torch.cuda.mem_get_info()
if "text_config" in config_dict:
config_dict = config_dict["text_config"]
# scale model down if gpu memory is low
if total_mem < mem_for_full_model:
model_fraction = total_mem / mem_for_full_model
num_layers = int(config_dict["num_hidden_layers"] * model_fraction)
num_layers = min(num_layers, default_num_layers)
config_dict["num_hidden_layers"] = num_layers
def init_hf_model(cls, config, dtype, device):
"""Helper function for initializing a model from `transformers`.
The reason this function exists is: by default, instantiating a `transformers` model also
eagerly initializes the model's weights on the CPU, which takes an absurdly long time to
complete.
Instead, we lazily instantiate the model, and initialize the weights only after moving it to
the requested `device`.
"""
from transformers import modeling_utils as t_modeling_utils
with t_modeling_utils.no_init_weights():
model = cls(config).eval()
model.to(device=device)
model.init_weights()
model.to(dtype=dtype)
return model
@contextlib.contextmanager
def kv_cache_manager_context(kv_cache_manager):
try:
yield
finally:
kv_cache_manager.shutdown()
def test_mistral_3_vlm_rejects_disagg(mistral_small_3_1_24b_config):
with (
mock.patch.dict(os.environ, {"TLLM_MULTIMODAL_DISAGGREGATED": "1"}),
pytest.raises(NotImplementedError, match="disaggregated inference"),
):
modeling_mistral.Mistral3VLM(
model_config=model_config_lib.ModelConfig(
pretrained_config=transformers.Mistral3Config.from_dict(
mistral_small_3_1_24b_config
)
),
)
@pytest.mark.parametrize("quant_algo", [None, "FP8"])
def test_mistral_3_vlm_sanity(mistral_small_3_1_24b_config, quant_algo):
if quant_algo == "FP8" and getSMVersion() < 89:
pytest.skip("This test is not supported in pre-Ada architecture")
config_dict = mistral_small_3_1_24b_config
# 24B * sizeof(float16) plus some extra for activations
mem_for_full_model = int(2.1 * 24 * 2 ** (30))
reduce_mistral_config(mem_for_full_model, config_dict)
if config_dict["text_config"]["num_hidden_layers"] <= 0:
pytest.skip("Insufficient memory for a single Mistral layer")
mistral_3_config = transformers.Mistral3Config.from_dict(config_dict)
if quant_algo:
quant_config = modeling_utils.QuantConfig(quant_algo=quant_algo)
else:
quant_config = None
dtype = mistral_3_config.torch_dtype
device = torch.device("cuda")
model_config = model_config_lib.ModelConfig(
pretrained_config=mistral_3_config,
quant_config=quant_config,
)
mistral = modeling_mistral.Mistral3VLM(model_config).to(device)
input_ids = torch.tensor(
[100, 200, 300, 100, 200, 100, 400, 500], dtype=torch.int, device=device
)
context_sequence_lengths = [3, 2, 1]
sequence_lengths = context_sequence_lengths + [1, 1]
past_seen_tokens = [0, 0, 0, 62, 75]
request_ids = list(range(len(sequence_lengths)))
token_nums = (torch.tensor(past_seen_tokens) + torch.tensor(sequence_lengths)).tolist()
prompt_lens = token_nums[:3] + past_seen_tokens[3:]
num_blocks = 100
tokens_per_block = 128
head_dim = mistral.config.head_dim
num_layers = mistral.config.num_hidden_layers
num_kv_heads = mistral.config.num_key_value_heads
max_seq_len = num_blocks * tokens_per_block
batch_size = len(context_sequence_lengths) + 2
if dtype == torch.half:
kv_cache_dtype = tensorrt_llm.bindings.DataType.HALF
elif dtype == torch.bfloat16:
kv_cache_dtype = tensorrt_llm.bindings.DataType.BF16
else:
raise ValueError("Invalid dtype")
mapping = mapping_lib.Mapping(world_size=1, tp_size=1, rank=0)
kv_cache_config = executor_lib.KvCacheConfig(max_tokens=num_blocks * tokens_per_block)
kv_cache_manager = resource_manager.KVCacheManager(
kv_cache_config,
tensorrt_llm.bindings.internal.batch_manager.CacheType.SELF,
num_layers=num_layers,
num_kv_heads=num_kv_heads,
head_dim=head_dim,
tokens_per_block=tokens_per_block,
max_seq_len=max_seq_len,
max_batch_size=batch_size,
mapping=mapping,
dtype=kv_cache_dtype,
)
with kv_cache_manager_context(kv_cache_manager):
kv_cache_manager.add_dummy_requests(request_ids, token_nums)
metadata_cls = attention_utils.get_attention_backend(model_config.attn_backend).Metadata
attn_metadata = metadata_cls(
seq_lens=torch.tensor(sequence_lengths, dtype=torch.int),
num_contexts=len(context_sequence_lengths),
kv_cache_params=metadata_lib.KVCacheParams(
use_cache=True,
num_cached_tokens_per_seq=past_seen_tokens,
),
kv_cache_manager=kv_cache_manager,
request_ids=request_ids,
prompt_lens=prompt_lens,
max_num_requests=len(context_sequence_lengths) + 2,
max_num_tokens=8192,
)
position_ids = []
for i, tokens in enumerate(past_seen_tokens):
seq_len = context_sequence_lengths[i] if i < len(context_sequence_lengths) else 1
position_id = torch.arange(tokens, tokens + seq_len, device=input_ids.device)
position_ids.append(position_id)
position_ids = torch.cat(position_ids).unsqueeze(0)
with torch.inference_mode():
attn_metadata.prepare()
logits = mistral.forward(
input_ids=input_ids, position_ids=position_ids, attn_metadata=attn_metadata
)
assert len(past_seen_tokens) == logits.shape[0]
with torch.inference_mode():
attn_metadata.prepare()
logits = mistral.forward(
input_ids=input_ids,
position_ids=position_ids,
attn_metadata=attn_metadata,
return_context_logits=True,
)
assert input_ids.shape == logits.shape[:-1]
@pytest.mark.parametrize(
"backend, use_cuda_graph",
[
("VANILLA", False),
("FLASHINFER", False),
("FLASHINFER", True),
("TRTLLM", False),
("TRTLLM", True),
],
)
@torch.no_grad()
def test_mistral_3_vlm_allclose_to_hf(mistral_small_3_1_24b_config, backend, use_cuda_graph):
metadata_cls = attention_utils.get_attention_backend(backend).Metadata
torch.random.manual_seed(0)
config_dict = mistral_small_3_1_24b_config
# 24B * sizeof(float16) plus some extra for activations
# times 2, since we'll need 2 of these
mem_for_full_model = int(2.1 * 24 * 2 ** (30) * 2)
reduce_mistral_config(mem_for_full_model, config_dict)
if config_dict["text_config"]["num_hidden_layers"] <= 0:
pytest.skip("Insufficient memory for a single Mistral layer")
mistral_config = transformers.Mistral3Config.from_dict(config_dict)
dtype = mistral_config.torch_dtype
device = torch.device("cuda")
hf_mistral = init_hf_model(
cls=transformers.Mistral3ForConditionalGeneration,
config=mistral_config,
dtype=dtype,
device=device,
)
model_config = model_config_lib.ModelConfig(
pretrained_config=mistral_config,
attn_backend=backend,
)
mistral = modeling_mistral.Mistral3VLM(model_config).to(dtype).to(device)
mistral.load_weights(hf_mistral.state_dict())
num_blocks = 1
tokens_per_block = 128
head_dim = mistral.config.head_dim
num_layers = mistral.config.num_hidden_layers
num_kv_heads = mistral.config.num_key_value_heads
max_seq_len = num_blocks * tokens_per_block
batch_size = 1
if dtype == torch.half:
kv_cache_dtype = tensorrt_llm.bindings.DataType.HALF
elif dtype == torch.bfloat16:
kv_cache_dtype = tensorrt_llm.bindings.DataType.BF16
else:
raise ValueError("Invalid dtype")
mapping = mapping_lib.Mapping(world_size=1, tp_size=1, rank=0)
kv_cache_config = executor_lib.KvCacheConfig(max_tokens=num_blocks * tokens_per_block)
kv_cache_manager = resource_manager.KVCacheManager(
kv_cache_config,
tensorrt_llm.bindings.internal.batch_manager.CacheType.SELF,
num_layers=num_layers,
num_kv_heads=num_kv_heads,
head_dim=head_dim,
tokens_per_block=tokens_per_block,
max_seq_len=max_seq_len,
max_batch_size=batch_size,
mapping=mapping,
dtype=kv_cache_dtype,
)
with kv_cache_manager_context(kv_cache_manager):
# context
input_ids = torch.tensor(
[100, 200, 300, 100, 200, 100, 400, 500], dtype=torch.int, device=device
)
num_cached_tokens_per_seq = [0]
request_ids = [1]
token_nums = [input_ids.size(-1)]
prompt_lens = [input_ids.size(-1)]
kv_cache_manager.add_dummy_requests(request_ids, token_nums)
attn_metadata = metadata_cls(
seq_lens=torch.tensor([input_ids.size(-1)], dtype=torch.int),
num_contexts=1,
kv_cache_params=metadata_lib.KVCacheParams(
use_cache=True,
num_cached_tokens_per_seq=num_cached_tokens_per_seq,
),
max_num_requests=1,
max_num_tokens=8192,
kv_cache_manager=kv_cache_manager,
request_ids=request_ids,
prompt_lens=prompt_lens,
)
# Note: no CUDA graphs for prefill, the graph runner is built for
# decoding only.
position_ids = [torch.arange(0, input_ids.size(-1))]
position_ids = torch.cat(position_ids).unsqueeze(0).cuda()
with torch.inference_mode():
attn_metadata.prepare()
logits = mistral.forward(
input_ids=input_ids, position_ids=position_ids, attn_metadata=attn_metadata
)
ref = hf_mistral.forward(
input_ids=input_ids.unsqueeze(0), position_ids=position_ids, use_cache=True
)
torch.testing.assert_close(logits, ref.logits[:, -1].float(), atol=0.4, rtol=0.4)
# gen
gen_input_ids = torch.tensor([600], dtype=torch.int, device=device)
num_cached_tokens_per_seq = [input_ids.size(-1)]
attn_metadata = metadata_cls(
seq_lens=torch.tensor([gen_input_ids.size(-1)], dtype=torch.int),
num_contexts=0,
kv_cache_params=metadata_lib.KVCacheParams(
use_cache=True,
num_cached_tokens_per_seq=num_cached_tokens_per_seq,
),
max_num_requests=1,
max_num_tokens=8192,
kv_cache_manager=kv_cache_manager,
request_ids=request_ids,
prompt_lens=prompt_lens,
)
gen_position_ids = [
torch.arange(input_ids.size(-1), input_ids.size(-1) + gen_input_ids.size(-1))
]
gen_position_ids = torch.cat(gen_position_ids).unsqueeze(0).cuda()
def run_forward(input_ids, position_ids, attn_metadata):
attn_metadata.prepare()
if not use_cuda_graph:
return mistral.forward(
input_ids=input_ids, position_ids=position_ids, attn_metadata=attn_metadata
)
else:
graph_runner = cuda_graph_runner.DecodingCUDAGraphRunner(
attn_metadata.max_num_requests, "cuda", attn_metadata
)
graph_runner.capture(lambda inputs: mistral.forward(**inputs))
for _ in range(2):
# Run it twice. This helps us catch problems if buffers are accidentally reallocated
# in prepare().
attn_metadata.prepare()
logits = graph_runner.run(
{
"input_ids": input_ids,
"position_ids": position_ids,
"attn_metadata": attn_metadata,
}
)
return logits
if use_cuda_graph:
attn_metadata = attn_metadata.create_cuda_graph_metadata(1)
with torch.inference_mode():
logits = run_forward(
input_ids=gen_input_ids, position_ids=gen_position_ids, attn_metadata=attn_metadata
)
ref = hf_mistral.forward(
input_ids=gen_input_ids.unsqueeze(0),
position_ids=gen_position_ids,
past_key_values=ref.past_key_values,
use_cache=True,
)
torch.testing.assert_close(logits, ref.logits[:, -1].float(), atol=0.4, rtol=0.4)

View File

@@ -0,0 +1,104 @@
import pytest
import torch
import transformers
from transformers.models.pixtral import modeling_pixtral as hf_modeling_pixtral
from tensorrt_llm import mapping as mapping_lib
from tensorrt_llm._torch import model_config as model_config_lib
from tensorrt_llm._torch.models import modeling_pixtral
@pytest.fixture
def pixtral_vision_config():
# Values taken from:
# https://huggingface.co/mistralai/Mistral-Small-3.1-24B-Instruct-2503/blob/main/config.json
return model_config_lib.ModelConfig(
pretrained_config=transformers.PixtralVisionConfig(
hidden_size=1024,
num_attention_heads=16,
torch_dtype=torch.bfloat16,
hidden_act="silu",
),
)
@pytest.fixture
def set_seed():
torch.manual_seed(322)
def init_hf_model(cls, config, dtype, device):
"""Helper function for initializing a model from `transformers`.
The reason this function exists is: by default, instantiating a `transformers` model also
eagerly initializes the model's weights on the CPU, which takes an absurdly long time to
complete.
Instead, we lazily instantiate the model, and initialize the weights only after moving it to
the requested `device`.
"""
from transformers import modeling_utils as t_modeling_utils
with t_modeling_utils.no_init_weights():
model = cls(config).eval()
model.to(device=device)
model.init_weights()
model.to(dtype=dtype)
return model
@pytest.mark.parametrize(
"mapping",
[
mapping_lib.Mapping(world_size=2, tp_size=2),
mapping_lib.Mapping(world_size=3, tp_size=3),
mapping_lib.Mapping(world_size=4, tp_size=2, pp_size=2),
mapping_lib.Mapping(world_size=8, tp_size=2, pp_size=2, cp_size=2),
],
)
def test_pixtral_vision_model_rejects_tp_size_greater_than_one(pixtral_vision_config, mapping):
pixtral_vision_config.mapping = mapping
with pytest.raises(NotImplementedError, match="tp_size > 1"):
modeling_pixtral.PixtralVisionModel(model_config=pixtral_vision_config)
@torch.no_grad()
@pytest.mark.usefixtures("set_seed")
def test_pixtral_vision_model_vs_hf(pixtral_vision_config):
dtype = torch.bfloat16
device = torch.device("cuda")
pretrained_config = pixtral_vision_config.pretrained_config
pixtral_model = (
modeling_pixtral.PixtralVisionModel(model_config=pixtral_vision_config).eval().to(device)
)
hf_pixtral_model = init_hf_model(
cls=hf_modeling_pixtral.PixtralVisionModel,
config=pretrained_config,
dtype=dtype,
device=device,
)
# Make sure both models have the same weights.
pixtral_model.load_weights(hf_pixtral_model.state_dict())
batch_size = 1
height, width, channels = 123, 456, 3
pixel_values = torch.randn(batch_size, channels, height, width, device=device, dtype=dtype)
image_sizes = torch.tensor([[height, width]])
out = pixtral_model(
pixel_values=pixel_values,
image_sizes=image_sizes,
)
with torch.autocast(device_type="cuda", dtype=dtype):
hf_out = (
hf_pixtral_model(
pixel_values=pixel_values,
image_sizes=image_sizes,
)
.last_hidden_state.squeeze(0)
.to(dtype=dtype)
)
torch.testing.assert_close(out, hf_out, atol=0.2, rtol=0.2)
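    # Shape sanity note for the comparison above (illustration only): the flattened patch
    # sequence holds one token per full (patch_size x patch_size) tile, so both `out` and
    # `hf_out` have shape [(height // patch_size) * (width // patch_size), hidden_size].
    patch = pretrained_config.patch_size
    assert out.shape == ((height // patch) * (width // patch), pretrained_config.hidden_size)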