[None][feat] Extend VLM factory and add Mistral3 factory (#7583)

This commit:

* extends existing factory interfaces to enable Mistral3 in AutoDeploy.
* adds a Mistral3 VLM factory.
* adds various model patches for pixtral (the vision model) and mistral3
  to make the VLM export compliant.
* adjusts checkpoint loading code to take possible parameter name
  conversions into account.
* fixes a sampling bug (the `end_id` needs to be taken into account when
  sampling, but it is not included in the stop words' token IDs; sketched below).
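
A minimal sketch of the sampling fix (mirroring the `DemoEngine` hunk further below;
`sampling_params` holds the request's sampling parameters):

    stop_tokens = sampling_params._get_stop_words()
    # `end_id` is deliberately kept out of `stop_words`, so append it explicitly to make
    # generation stop at the model's end-of-sequence token.
    stop_tokens.append([sampling_params.end_id])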

Signed-off-by: William Zhang <133824995+2ez4bz@users.noreply.github.com>
William Zhang 2025-09-08 23:47:18 -07:00 committed by GitHub
parent 6ba1c8421c
commit c53d1814a7
13 changed files with 737 additions and 19 deletions

View File

@ -211,6 +211,17 @@ class AutoDeployConfig(DynamicYamlMixInForSettings, BaseSettings):
self.attn_page_size = self.max_seq_len
return self
@field_validator("model_factory", mode="after")
@classmethod
def model_factory_exists(cls, value: str) -> str:
if not ModelFactoryRegistry.has(value):
raise ValueError(
f"'{value}' does not exist in the model factory registry. Available values: "
f"{ModelFactoryRegistry.entries()}."
)
return value
### UTILITY METHODS ############################################################################
def create_factory(self) -> ModelFactory:
"""Create a model factory from the arguments."""

View File

@ -1,2 +1,2 @@
from . import hf, mistral3, patches
from .factory import *

View File

@ -3,7 +3,7 @@
import copy
from abc import ABC, abstractmethod
from enum import Enum
from typing import Any, Callable, Dict, List, Optional, Tuple, Type
import torch
import torch.nn as nn
@ -282,3 +282,7 @@ class ModelFactoryRegistry:
@classmethod
def has(cls, model_factory_cls: str) -> bool:
return model_factory_cls in cls._registry
@classmethod
def entries(cls) -> List[str]:
return list(cls._registry.keys())

View File

@ -1,6 +1,7 @@
"""Interface to initialize and load HF models.""" """Interface to initialize and load HF models."""
import os import os
import re
import types
from contextlib import contextmanager, nullcontext
from typing import Any, Dict, List, Optional, Tuple, Union
@ -99,6 +100,11 @@ class AutoModelForCausalLMFactory(ModelFactory):
# set sharding config source to huggingface
self._sharding_config["source"] = ShardingConfigSource.HUGGINGFACE
# Some models' transformers implementation may have changed since their safetensors were
# produced and / or uploaded to the Hugging Face Hub. When building the model, we try to
# determine whether a parameter name mapping exists and store that information in this attribute.
self._checkpoint_conversion_mapping: Optional[Dict[str, str]] = None
@property
def autoconfig_from_pretrained(self):
return AutoConfig.from_pretrained
@ -168,6 +174,7 @@ class AutoModelForCausalLMFactory(ModelFactory):
# if present, initialize sharding config. We need head_dim for colwise sharding.
self._set_sharding_config(model.config)
self._checkpoint_conversion_mapping = getattr(model, "_checkpoint_conversion_mapping", None)
# patch forward method
model.forward = types.MethodType(self._simple_forward, model)
@ -326,15 +333,30 @@ class AutoModelForCausalLMFactory(ModelFactory):
"""Load the checkpoint into the model.""" """Load the checkpoint into the model."""
# identify the most relevant checkpoint file # identify the most relevant checkpoint file
ckpt_file = self._get_checkpoint_file(self.model) ckpt_file = self._get_checkpoint_file(self.model)
load_handle = model.register_load_state_dict_pre_hook(self._remap_param_names_load_hook)
# Ensure it's the first one.
model._load_state_dict_pre_hooks.move_to_end(key=load_handle.id, last=False)
get_handle = model.register_state_dict_post_hook(
_StateDictParamNameConverter(self._checkpoint_conversion_mapping)
)
# Ensure it's the first one.
model._state_dict_hooks.move_to_end(key=get_handle.id, last=False)
# reuse the load checkpoint utility from accelerate
try:
with hf_load_state_dict_with_device(device):
# Set `full_state_dict=False` to skip Accelerate's FSDP weight sync logic.
# Internally, load_checkpoint_in_model → set_model_state_dict → _load_model_state_dict,
# which collects local model params, syncs weights from checkpoint, and applies them via
# model.load_state_dict.
# This sync step can interfere with load_hooks by mixing raw checkpoint weights and
# model-transformed weights, leading to unexpected key mismatches or format issues.
load_checkpoint_in_model(model, checkpoint=ckpt_file, full_state_dict=False)
finally:
load_handle.remove()
get_handle.remove()
def _load_quantization_config(self, fetched_dir: str):
"""Load the quantization config from the model directory if not done already."""
@ -351,6 +373,63 @@ class AutoModelForCausalLMFactory(ModelFactory):
self._quant_config_reader = reader
self.model_kwargs = deep_merge_dicts(self.model_kwargs, extra_model_kwargs)
def _remap_param_names_load_hook(self, model, state_dict, *args, **kwargs) -> None:
"""Hook to handle potential param name conversions.
Some models' transformers implementation may have changed since their safetensors were
produced and / or uploaded to the Hugging Face Hub. This hook applies the mapping (when
present) to account for these differences.
"""
conversion_mapping = self._checkpoint_conversion_mapping
if conversion_mapping:
keys_to_process = list(state_dict.keys())
for key in keys_to_process:
new_key = key
for pattern, replacement in conversion_mapping.items():
new_key = re.sub(pattern, replacement, new_key)
if new_key != key:
state_dict[new_key] = state_dict.pop(key)
class _StateDictParamNameConverter:
"""Helper class for applying param name conversions to a state dict.
This is a class rather than a method of the factory like `_remap_param_names_load_hook`
because PyTorch tries to set an `_from_public_api` attribute on hooks, and bound instance
methods cannot have attributes set on them without major hacks.
"""
def __init__(self, conversion_mapping: Optional[Dict[str, str]]):
conversion_mapping = conversion_mapping or {}
# NOTE: most of the code in this class is forked from `PreTrainedModel.save_pretrained`.
reverse_key_mapping = {v: k for k, v in conversion_mapping.items()}
self._mapping = reverse_key_mapping
def __call__(self, module, state_dict, *args, **kwargs) -> None:
"""Hook to handle potential param name conversions.
For the same reasons as the `load` hook, we define one for `state_dict`. This silences
potentially misleading warnings about certain parameter names not being used, because the
`accelerate` library determines which keys are unexpected based on the keys returned by
`module.state_dict()`, not on what `module.load_state_dict()` reports.
"""
if self._mapping:
keys_to_process = list(state_dict.keys())
for key in keys_to_process:
new_key = key
for pattern, replacement in self._mapping.items():
replacement = replacement.lstrip("^") # strip off un-needed chars and patterns
replacement = re.sub(r"\(.*\)", "", replacement)
new_key, n_replace = re.subn(pattern, replacement, key)
# Early exit of the loop
if n_replace > 0:
break
if new_key != key:
state_dict[new_key] = state_dict.pop(key)
@ModelFactoryRegistry.register("AutoModelForImageTextToText")
class AutoModelForImageTextToTextFactory(AutoModelForCausalLMFactory):
@ -426,17 +505,19 @@ class AutoModelForImageTextToTextFactory(AutoModelForCausalLMFactory):
}
]
# Create a batch of conversations (batch_size = 2).
# Note that we explicitly use 2 images in the examples to avoid potential shape specialization(s)
# in `torch.compile` / `torch.export`.
batch_messages = [
_prep_seq(
"Describe what you see in the two images and their differences.",
Image.new("RGB", self._example_image_dims, color=(128, 128, 128)),
Image.new("RGB", self._example_image_dims, color=(64, 64, 64)),
),
_prep_seq(
"What are the main differences between these two images?",
Image.new("RGB", self._example_image_dims, color=(255, 0, 0)),
Image.new("RGB", self._example_image_dims, color=(0, 255, 0)),
),
]
@ -451,10 +532,15 @@ class AutoModelForImageTextToTextFactory(AutoModelForCausalLMFactory):
return_attention_mask=False,
)
# We should have no need for the attention mask, and it can actually cause issues in
# downstream code.
inputs.pop("attention_mask", None)
# NOTES:
# 1. `inputs` is dict-like, but not a dict (hence the dict unpacking below).
# 2. Although `get_extra_inputs` allows implementations to specify "extra inputs", the example
# values still need to be returned by `get_example_inputs`.
return {**inputs}
def get_extra_inputs(self) -> Dict[str, Tuple[torch.Tensor, Optional[DynamicShapeCallback]]]:
"""Return a dictionary of extra inputs for the model.
@ -476,3 +562,10 @@ class AutoModelForImageTextToTextFactory(AutoModelForCausalLMFactory):
none_pixel_values = torch.zeros(0, 3, 336, 336)
return {"pixel_values": (none_pixel_values, _get_dynamic_shape)}
@property
def _example_image_dims(self) -> Tuple[int, int]:
# Some specializations (children) of this class may override this if their models have
# assumptions on the image dimensions. For example, they may have a lower bound due to
# the patch size they use.
return (16, 16)
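
A standalone sketch of the parameter-name remapping mechanism introduced above, assuming a
made-up conversion mapping and a toy module; the factory wires `_remap_param_names_load_hook`
up the same way via `register_load_state_dict_pre_hook`:

    import re
    import torch
    import torch.nn as nn

    # Hypothetical mapping in the `_checkpoint_conversion_mapping` format:
    # regex pattern -> replacement, applied to checkpoint keys.
    conversion_mapping = {r"^language_model\.": "model.language_model."}

    model = nn.Module()
    model.model = nn.Module()
    model.model.language_model = nn.Linear(4, 4)

    def remap_hook(module, state_dict, *args, **kwargs):
        # Rename checkpoint keys in place before they are matched against module params.
        for key in list(state_dict.keys()):
            new_key = key
            for pattern, replacement in conversion_mapping.items():
                new_key = re.sub(pattern, replacement, new_key)
            if new_key != key:
                state_dict[new_key] = state_dict.pop(key)

    handle = model.register_load_state_dict_pre_hook(remap_hook)
    # A checkpoint produced with the old naming; the hook renames its keys on the fly.
    old_style_ckpt = {
        "language_model.weight": torch.zeros(4, 4),
        "language_model.bias": torch.zeros(4),
    }
    model.load_state_dict(old_style_ckpt)
    handle.remove()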

View File

@ -0,0 +1,56 @@
"""Auto-deploy model factory for Mistral3 models."""
from typing import Dict, Tuple
import torch
from tensorrt_llm._torch.auto_deploy.custom_ops import attention_interface
from tensorrt_llm._torch.auto_deploy.models import factory, hf
@factory.ModelFactoryRegistry.register("Mistral3VLM")
class Mistral3VLM(hf.AutoModelForImageTextToTextFactory):
def get_extra_inputs(
self,
) -> Dict[str, Tuple[torch.Tensor, attention_interface.DynamicShapeCallback]]:
"""Return a dictionary of extra inputs for the model.
Returns:
A dictionary of extra inputs for the model where the key corresponds to the argument
name and the value corresponds to a tuple of (example_input, dynamic_shape_callback).
The dynamic shape callback is a function that returns the dynamic shape of the extra
input.
"""
extra_inputs = super().get_extra_inputs()
# Reuse the same dynamic batch dimension for `image_sizes`.
batch_dim = extra_inputs["pixel_values"][1]()[0]
extra_inputs["image_sizes"] = (torch.zeros(0, 2, dtype=torch.long), lambda: {0: batch_dim})
return extra_inputs
@staticmethod
def _simple_forward(
model: torch.nn.Module,
input_ids: torch.Tensor,
position_ids: torch.Tensor,
pixel_values: torch.Tensor,
image_sizes: torch.Tensor,
):
"""A simple forward pass for the model to functionalize the args.
This follows the standard function signature as expected by factory.py.
"""
return type(model).forward(
model,
input_ids=input_ids,
position_ids=position_ids,
pixel_values=pixel_values,
image_sizes=image_sizes,
)
@property
def _example_image_dims(self) -> Tuple[int, int]:
# The pixtral processor requires a minimum image size, which is larger than the default (16, 16)
# in the parent class.
# TODO: derive this from the model config (patch size value, etc.).
return (64, 64)

View File

@ -0,0 +1,179 @@
"""A patch for the Mistral3Model to make it compatible with torch.export."""
from typing import List, Optional, Union
import torch
from transformers.models.mistral3.modeling_mistral3 import (
Mistral3Model,
Mistral3ModelOutputWithPast,
)
from ...export.interface import BaseExportPatch, ExportPatchRegistry
def _get_image_features_flat(
self,
pixel_values: torch.FloatTensor,
image_sizes: torch.Tensor,
vision_feature_layer: Optional[Union[int, List[int]]] = None,
**kwargs,
):
vision_feature_layer = (
vision_feature_layer
if vision_feature_layer is not None
else self.config.vision_feature_layer
)
kwargs = {k: v for k, v in kwargs.items() if v is not None}
image_outputs = self.vision_tower(
pixel_values, image_sizes=image_sizes, output_hidden_states=True, **kwargs
)
if isinstance(vision_feature_layer, int):
selected_image_feature = image_outputs.hidden_states[vision_feature_layer]
else:
hs_pool = [image_outputs.hidden_states[layer_idx] for layer_idx in vision_feature_layer]
selected_image_feature = torch.cat(hs_pool, dim=-1)
image_features = self.multi_modal_projector(selected_image_feature.squeeze(0), image_sizes)
image_features = image_features.squeeze(0)
return image_features
# NOTE: the main reason for this patch's existence is the `torch.cond` branching logic to handle the
# presence / absence of image features in a `torch.export`-compatible way.
def _mistral_forward(
self,
input_ids: torch.LongTensor = None,
pixel_values: torch.FloatTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
vision_feature_layer: Optional[Union[int, List[int]]] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
image_sizes: Optional[torch.Tensor] = None,
**kwargs,
) -> Union[tuple, Mistral3ModelOutputWithPast]:
output_attentions = (
output_attentions if output_attentions is not None else self.config.output_attentions
)
output_hidden_states = (
output_hidden_states
if output_hidden_states is not None
else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
vision_feature_layer = (
vision_feature_layer
if vision_feature_layer is not None
else self.config.vision_feature_layer
)
if (input_ids is None) ^ (inputs_embeds is not None):
raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
if pixel_values is not None and inputs_embeds is not None:
raise ValueError(
"You cannot specify both pixel_values and inputs_embeds at the same time, and must specify either one"
)
if inputs_embeds is None:
inputs_embeds = self.get_input_embeddings()(input_ids)
def _no_vision_branch(
input_ids: torch.LongTensor,
inputs_embeds: torch.FloatTensor,
pixel_values: torch.Tensor,
image_sizes: Optional[torch.Tensor],
):
return inputs_embeds
def _vision_branch(
# ! The type annotations in the original transformers code are all wrong.
input_ids: torch.LongTensor,
inputs_embeds: torch.FloatTensor,
pixel_values: torch.Tensor,
image_sizes: Optional[torch.Tensor],
):
pixel_values = pixel_values.to(torch.bfloat16)
image_features = self.get_image_features(
pixel_values=pixel_values,
vision_feature_layer=vision_feature_layer,
image_sizes=image_sizes,
)
# HF returns a list of tensors; our patch may already return a single tensor.
# Only concatenate when a list/tuple is returned.
if isinstance(image_features, (list, tuple)):
image_features = torch.cat(image_features, dim=0)
special_image_mask = (input_ids == self.config.image_token_id).unsqueeze(-1)
special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device)
image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
return inputs_embeds
# Decide based on whether there are any non-zero pixel_values.
has_image: torch.Tensor = (pixel_values is not None) and torch.any(pixel_values != 0)
# `torch.cond` serves 2 purposes here:
# 1. It lets the export stage know that there could be both image and no-image branches.
# Without this, the export stage would just assume that whatever the example input contains
# is representative of _all_ inputs at runtime. This means that, if we export it with images
# in the inputs, it would crash when called without images (i.e. in text-only mode).
# 2. It introduces a subgraph, which the pattern matcher will ignore. This is important as we
# do not want the vision model's attention ops to be converted by the pattern matcher to have
# KV cache enabled on them, as it would be both unnecessary to do so and potentially bad for
# performance.
inputs_embeds = torch.cond(
has_image,
_vision_branch,
_no_vision_branch,
(input_ids, inputs_embeds, pixel_values, image_sizes),
)
outputs = self.language_model(
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=True,
cache_position=cache_position,
**kwargs,
)
return Mistral3ModelOutputWithPast(
last_hidden_state=outputs.last_hidden_state,
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
# NOTE: hardcoded to None since we make no use of image hidden states.
image_hidden_states=None,
)
@ExportPatchRegistry.register("hf_mistral3")
class Mistral3ModelPatch(BaseExportPatch):
"""Patch for `Mistral3Model`."""
def _apply_patch(self):
"""Apply the Mistral3Model patch."""
self.original_values["Mistral3Model.forward"] = Mistral3Model.forward
self.original_values["Mistral3Model.get_image_features"] = Mistral3Model.get_image_features
Mistral3Model.forward = _mistral_forward
Mistral3Model.get_image_features = _get_image_features_flat
def _revert_patch(self):
"""Revert the Mistral3Model patch."""
# Restore original forward method.
Mistral3Model.forward = self.original_values["Mistral3Model.forward"]
Mistral3Model.get_image_features = self.original_values["Mistral3Model.get_image_features"]

View File

@ -0,0 +1,231 @@
"""Patches for the PixtralVisionModel to make it compatible with `torch.export`.
On top of the patching, `custom_op`s are registered to replace specific parts of the Pixtral model's
forward pass that are not compatible with `torch.export`. Note that the `register_fake` portion of
the ops needs to return the shape (and dtype) of the output tensor(s) without accessing the values in
the input tensors, which is where things get tricky, and why so many custom ops / patches are needed.
"""
import torch
from transformers.models.mistral3.modeling_mistral3 import Mistral3PatchMerger
from transformers.models.pixtral.modeling_pixtral import (
PixtralRMSNorm,
PixtralVisionModel,
position_ids_in_meshgrid,
)
from ...export.interface import BaseExportPatch, ExportPatchRegistry
# NOTES:
# 1. Everything decorated by a `custom_op` must be type annotated.
# 2. The annotations must be one of the internally supported param types. As such, `self: PixtralVisionModel`
# is a no-go.
# 3. This means that pretty much only free-standing functions with tensor inputs are supported - instance
# methods cannot be decorated.
@torch.library.custom_op("auto_deploy::pixtral_process_patch_embeds", mutates_args={})
def _process_patch_embeds(
patch_embeds: torch.Tensor,
image_sizes: torch.Tensor,
patch_size: int,
hidden_size: int,
max_width: int,
) -> tuple[torch.Tensor, torch.Tensor]:
patch_embeds_list = []
for embed, size in zip(patch_embeds, image_sizes):
# size is a 1-D tensor [H, W]; convert to Python ints for indexing.
h = int((size[0] // patch_size).item())
w = int((size[1] // patch_size).item())
patch_embeds_list.append(embed[..., :h, :w])
# flatten to a single sequence
patch_embeds = torch.cat([p.flatten(1).T for p in patch_embeds_list], dim=0).unsqueeze(0)
position_ids = position_ids_in_meshgrid(patch_embeds_list, max_width=max_width)
return patch_embeds, position_ids
@_process_patch_embeds.register_fake
def _process_patch_embeds_meta(
patch_embeds: torch.Tensor,
image_sizes: torch.Tensor,
patch_size: int,
hidden_size: int,
max_width: int,
):
B = (image_sizes // patch_size).prod(dim=1).sum()
device = patch_embeds.device
return (
# Leading 1 = `unsqueeze(0)` after concatenating the `patch_embeds_list`.
torch.empty(1, B, hidden_size, device=device),
torch.empty(B, device=device, dtype=torch.int64),
)
def _pixtral_forward(
self: PixtralVisionModel,
pixel_values: torch.Tensor,
image_sizes: torch.Tensor | None,
output_hidden_states: bool | None = None,
output_attentions: bool | None = None,
return_dict: bool | None = None,
*args,
**kwargs,
):
if image_sizes is None:
batch_size, _, height, width = pixel_values.shape
image_sizes = torch.tensor([(height, width)] * batch_size, device=pixel_values.device)
# pass images through initial convolution independently
patch_embeds = self.patch_conv(pixel_values)
patch_embeds, position_ids = torch.ops.auto_deploy.pixtral_process_patch_embeds(
patch_embeds=patch_embeds,
image_sizes=image_sizes,
patch_size=self.patch_size,
hidden_size=self.config.hidden_size,
max_width=self.config.image_size // self.config.patch_size,
)
patch_embeds = self.ln_pre(patch_embeds)
# Constrain sequence length to be size-like and > 1 for export guards.
_seq_len = patch_embeds.shape[1]
torch._check_is_size(_seq_len)
torch._check(_seq_len > 1)
position_embeddings = self.patch_positional_embedding(patch_embeds, position_ids)
if self.config._attn_implementation == "flash_attention_2":
# We only rely on position_ids when using flash_attention_2
attention_mask = None
else:
attention_mask = generate_block_attention_mask(
(image_sizes // self.config.patch_size).prod(dim=1),
patch_embeds,
)
out = self.transformer(
patch_embeds,
attention_mask=attention_mask,
position_ids=position_ids,
position_embeddings=position_embeddings,
output_hidden_states=output_hidden_states,
output_attentions=output_attentions,
return_dict=True,
**kwargs,
)
return out
def generate_block_attention_mask(num_ids_per_image, tensor):
dtype = tensor.dtype
device = tensor.device
if not isinstance(num_ids_per_image, torch.Tensor):
num_ids_per_image = torch.as_tensor(num_ids_per_image, device=device, dtype=torch.long)
else:
num_ids_per_image = num_ids_per_image.to(device=device, dtype=torch.long)
# Build per-token block ids: [0 repeated n0, 1 repeated n1, ...].
block_ids = torch.repeat_interleave(
torch.arange(num_ids_per_image.numel(), device=device), num_ids_per_image
)
# same_block[i, j] is True if tokens i and j belong to the same image block.
same_block = block_ids[:, None] == block_ids[None, :]
# Mask: 0 inside blocks, 1 outside blocks (match previous function's output), tensor-only.
mask = (~same_block).to(dtype)
d_min = torch.finfo(dtype).min
mask *= d_min
return mask
@torch.library.custom_op("auto_deploy::pixtral_unfold_to_2d_grid", mutates_args={})
def _unfold_to_2d_grid(
image_features: torch.Tensor,
image_sizes: torch.Tensor,
patch_size: int,
spatial_merge_size: int,
) -> torch.Tensor:
image_sizes = [
(image_size[0] // patch_size, image_size[1] // patch_size) for image_size in image_sizes
]
tokens_per_image = [h * w for h, w in image_sizes]
d = image_features.shape[-1]
permuted_tensor = []
for image_index, image_tokens in enumerate(image_features.split(tokens_per_image)):
# Reshape image_tokens into a 2D grid
h, w = image_sizes[image_index]
image_grid = image_tokens.view(h, w, d).permute(2, 0, 1).unsqueeze(0)
grid = torch.nn.functional.unfold(
image_grid, kernel_size=spatial_merge_size, stride=spatial_merge_size
)
grid = grid.view(d * spatial_merge_size**2, -1).t()
permuted_tensor.append(grid)
image_features = torch.cat(permuted_tensor, dim=0)
return image_features
@_unfold_to_2d_grid.register_fake
def _unfold_to_2d_grid_meta(
image_features: torch.Tensor,
image_sizes: torch.Tensor,
patch_size: int,
spatial_merge_size: int,
):
embedding_sizes = (image_sizes // patch_size).prod(dim=1)
spatial_factor = spatial_merge_size * spatial_merge_size
grid_sizes = embedding_sizes // spatial_factor
total_size = grid_sizes.sum()
return image_features.new_empty(total_size, image_features.shape[-1] * spatial_factor)
def _patch_merger_forward(
self, image_features: torch.Tensor, image_sizes: torch.Tensor
) -> torch.Tensor:
unfolded_features = torch.ops.auto_deploy.pixtral_unfold_to_2d_grid(
image_features=image_features,
image_sizes=image_sizes,
patch_size=self.patch_size,
spatial_merge_size=self.spatial_merge_size,
)
image_features = self.merging_layer(unfolded_features)
return image_features
# Somehow there are dtype mismatches at runtime between bfloat16 and float32 without this.
def _pixtral_rms_norm_forward(self, hidden_states):
input_dtype = torch.bfloat16
hidden_states = hidden_states.to(torch.float32)
variance = hidden_states.pow(2).mean(-1, keepdim=True)
hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
return self.weight * hidden_states.to(input_dtype)
@ExportPatchRegistry.register("hf_pixtral_vit")
class PixtralVisionModelPatch(BaseExportPatch):
"""Patch for `PixtralVisionModel`."""
def _apply_patch(self):
"""Apply the PixtralVisionModel patch."""
self.original_values["PixtralVisionModel.forward"] = PixtralVisionModel.forward
self.original_values["Mistral3PatchMerger.forward"] = Mistral3PatchMerger.forward
self.original_values["PixtralRMSNorm.forward"] = PixtralRMSNorm.forward
PixtralVisionModel.forward = _pixtral_forward
Mistral3PatchMerger.forward = _patch_merger_forward
PixtralRMSNorm.forward = _pixtral_rms_norm_forward
def _revert_patch(self):
"""Revert the PixtralVisionModel patch."""
PixtralVisionModel.forward = self.original_values["PixtralVisionModel.forward"]
Mistral3PatchMerger.forward = self.original_values["Mistral3PatchMerger.forward"]
PixtralRMSNorm.forward = self.original_values["PixtralRMSNorm.forward"]

View File

@ -121,6 +121,11 @@ class DemoEngine(ADEngine):
batch_size = sequence_info.num_sequences
new_tokens = [[] for _ in range(batch_size)]  # [batch_size][max_seq_len]
stop_tokens = sampling_params._get_stop_words()
# NOTE: TRTLLM has made the intentional choice to separate `end_id` from `stop_words`, and not
# include the former in the latter's corresponding stop IDs. From a UX perspective, `stop_words`
are optional and can be customized per user request, whereas `end_id` is static per model,
# and should always be used outside of benchmarking.
stop_tokens.append([sampling_params.end_id])
idxs_stop = [sampling_params.max_tokens - 1] * batch_size
gen_logits = [] if sampling_params.return_generation_logits else None
context_logits: Optional[List[torch.Tensor]] = None

View File

@ -434,6 +434,15 @@ _SMALL_MODEL_CONFIGS = {
"num_hidden_layers": 2, "num_hidden_layers": 2,
}, },
}, },
"mistralai/Mistral-Small-3.1-24B-Instruct-2503": {
"model": f"{llm_models_root()}/Mistral-Small-3.1-24B-Instruct-2503",
"model_factory": "Mistral3VLM",
"compile_backend": "torch-simple",
"model_kwargs": {
"text_config": {"num_hidden_layers": 2},
"vision_config": {"num_hidden_layers": 2},
},
},
}
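
For reference, a minimal sketch of driving the new factory through the Python API; the
checkpoint path is a placeholder and the calls mirror the config above and the tests below:

    from tensorrt_llm._torch.auto_deploy import LlmArgs

    llm_args = LlmArgs(
        model="/path/to/Mistral-Small-3.1-24B-Instruct-2503",  # placeholder path
        model_factory="Mistral3VLM",  # must be registered, or the new validator raises
    )
    factory = llm_args.create_factory()        # Mistral3VLM instance from the registry
    extra_inputs = factory.get_extra_inputs()  # includes "pixel_values" and "image_sizes"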

View File

@ -0,0 +1,15 @@
from tensorrt_llm._torch.auto_deploy.models import mistral3
def test_get_extra_inputs_includes_image_sizes():
factory = mistral3.Mistral3VLM(model="test-model")
extra_inputs = factory.get_extra_inputs()
pixel_values = extra_inputs["pixel_values"]
image_sizes = extra_inputs["image_sizes"]
pixel_values_dynamic_shape = pixel_values[1]()
image_sizes_dynamic_shape = image_sizes[1]()
# Unfortunately, direct object comparisons do not work.
assert pixel_values_dynamic_shape[0].__dict__ == image_sizes_dynamic_shape[0].__dict__

View File

@ -0,0 +1,90 @@
import torch
from _model_test_utils import get_small_model_config
from build_and_run_ad import ExperimentConfig
from tensorrt_llm._torch.auto_deploy import LlmArgs
from tensorrt_llm._torch.auto_deploy.export import apply_export_patches, torch_export_to_gm
from tensorrt_llm._torch.auto_deploy.transformations._graph import move_to_device
def test_build_run_mistral3_vlm():
experiment_config = get_small_model_config("mistralai/Mistral-Small-3.1-24B-Instruct-2503")
experiment_config = ExperimentConfig(**experiment_config)
llm_args: LlmArgs = experiment_config.args
factory = llm_args.create_factory()
model = factory.build_model("cuda")
inputs = factory.get_example_inputs()
for key, value in inputs.items():
if isinstance(value, torch.Tensor):
dtype = torch.bfloat16 if isinstance(value, torch.FloatTensor) else None
inputs[key] = value.to(device=model.device, dtype=dtype)
# get relevant inputs
input_ids = inputs["input_ids"]
position_ids = torch.arange(input_ids.shape[1], device=input_ids.device).repeat(
input_ids.shape[0], 1
)
pixel_values = inputs["pixel_values"]
image_sizes = inputs["image_sizes"]
def _run_with_and_without_image(model, use_patch=True):
with apply_export_patches(
patch_list=["hf_mistral3", "hf_pixtral_vit"] if use_patch else []
):
with torch.inference_mode():
out_no_images = model(
input_ids=input_ids,
position_ids=position_ids,
pixel_values=torch.zeros_like(pixel_values) if use_patch else None,
image_sizes=image_sizes if use_patch else None,
)
out_with_images = model(
input_ids=input_ids,
position_ids=position_ids,
pixel_values=pixel_values,
image_sizes=image_sizes,
)
return {"no_images": out_no_images.logits, "with_images": out_with_images.logits}
# Get output pre-patch.
out_original = _run_with_and_without_image(model, use_patch=False)
# Get output post-patch.
outputs_for_comparison = {}
# TODO(2ez4bz): Figure out why the patches do not work outside of `torch_export_to_gm`.
# outputs_for_comparison["model_with_patch"] = _run_with_and_without_image(model)
gm = torch_export_to_gm(
model,
args=(),
kwargs={
"input_ids": input_ids,
"position_ids": position_ids,
"pixel_values": pixel_values,
"image_sizes": image_sizes,
},
patch_list=[
"transformers_sdpa_mask",
"autocast_noop",
"torch_where",
"tensor_meta_device",
"sdpa_kernel_noop",
"sdpa",
"hf_mistral3",
"hf_pixtral_vit",
],
)
move_to_device(gm, model.device)
outputs_for_comparison["gm"] = _run_with_and_without_image(gm)
atol, rtol = 1e-3, 1e-3
for comp, outs in outputs_for_comparison.items():
torch.testing.assert_close(
outs,
out_original,
rtol=rtol,
atol=atol,
)

View File

@ -1,5 +1,6 @@
from unittest.mock import MagicMock, patch
import pydantic
import pytest
from tensorrt_llm._torch.auto_deploy import LLM, DemoLLM, LlmArgs
@ -147,6 +148,21 @@ def test_config_flow(
pass
@pytest.mark.parametrize(
"model_factory",
[
"Foo",
# typo.
"AutomodelForCausalLMFactory",
],
)
def test_non_registered_model_factory(model_factory: str):
with pytest.raises(
pydantic.ValidationError, match="does not exist in the model factory registry"
):
LlmArgs(model="test-model", model_factory=model_factory)
@pytest.mark.parametrize(
"parallel_field,invalid_value",
[

View File

@ -84,6 +84,15 @@ def _check_ad_config(experiment_config: ExperimentConfig, llm_args: LlmArgs):
attn_backend="triton", attn_backend="triton",
compile_backend="torch-compile", compile_backend="torch-compile",
), ),
pytest.param(
get_small_model_config(
"mistralai/Mistral-Small-3.1-24B-Instruct-2503",
attn_backend="flashinfer",
compile_backend="torch-simple",
),
# Human-readable id for easier selection with `-k`.
id="mistral-small-3.1-24b",
),
],
)
def test_build_ad(experiment_config: Dict):