Mirror of https://github.com/NVIDIA/TensorRT-LLM.git (synced 2026-01-14 06:27:45 +08:00)
[None][feat] Extend VLM factory and add Mistral3 factory (#7583)
This commit:
* extends existing factory interfaces to enable Mistral3 in AutoDeploy.
* adds a Mistral3 VLM factory.
* adds various model patches for pixtral (the vision model) and mistral3 to make the VLM export compliant.
* adjusts checkpoint loading code to take possible parameter name conversions into account.
* fixes a sampling bug (the `end_id` needs to be taken into account when sampling, but it is not included in the stop words' token IDs).

Signed-off-by: William Zhang <133824995+2ez4bz@users.noreply.github.com>
Contained in: parent 6ba1c8421c, commit c53d1814a7
@@ -211,6 +211,17 @@ class AutoDeployConfig(DynamicYamlMixInForSettings, BaseSettings):
             self.attn_page_size = self.max_seq_len
         return self

+    @field_validator("model_factory", mode="after")
+    @classmethod
+    def model_factory_exists(cls, value: str) -> str:
+        if not ModelFactoryRegistry.has(value):
+            raise ValueError(
+                f"'{value}' does not exist in the model factory registry. Available values: "
+                f"{ModelFactoryRegistry.entries()}."
+            )
+
+        return value
+
     ### UTILITY METHODS ############################################################################

     def create_factory(self) -> ModelFactory:
         """Create a model factory from the arguments."""
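For context, the validator added above follows the usual pydantic v2 pattern of failing fast on an unknown registry key. Below is a minimal, self-contained sketch of that pattern; `TinyRegistry` and `TinyConfig` are illustrative stand-ins, not the actual TensorRT-LLM classes.

```python
# Minimal sketch of the registry + field_validator pattern (assumes pydantic v2).
from typing import Callable, Dict, List

from pydantic import BaseModel, field_validator


class TinyRegistry:
    _registry: Dict[str, Callable] = {}

    @classmethod
    def register(cls, name: str) -> Callable:
        def decorator(fn: Callable) -> Callable:
            cls._registry[name] = fn
            return fn

        return decorator

    @classmethod
    def has(cls, name: str) -> bool:
        return name in cls._registry

    @classmethod
    def entries(cls) -> List[str]:
        return list(cls._registry.keys())


class TinyConfig(BaseModel):
    model_factory: str = "AutoModelForCausalLM"

    @field_validator("model_factory", mode="after")
    @classmethod
    def model_factory_exists(cls, value: str) -> str:
        if not TinyRegistry.has(value):
            raise ValueError(
                f"'{value}' does not exist in the registry. Available values: {TinyRegistry.entries()}."
            )
        return value


TinyRegistry.register("AutoModelForCausalLM")(lambda: None)
TinyConfig(model_factory="AutoModelForCausalLM")  # passes
# TinyConfig(model_factory="Foo") raises pydantic.ValidationError with the message above.
```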
@@ -1,2 +1,2 @@
-from . import hf, patches
+from . import hf, mistral3, patches
 from .factory import *
@@ -3,7 +3,7 @@
 import copy
 from abc import ABC, abstractmethod
 from enum import Enum
-from typing import Any, Callable, Dict, Optional, Tuple, Type
+from typing import Any, Callable, Dict, List, Optional, Tuple, Type

 import torch
 import torch.nn as nn
@@ -282,3 +282,7 @@ class ModelFactoryRegistry:
     @classmethod
     def has(cls, model_factory_cls: str) -> bool:
         return model_factory_cls in cls._registry
+
+    @classmethod
+    def entries(cls) -> List[str]:
+        return list(cls._registry.keys())
@@ -1,6 +1,7 @@
 """Interface to initialize and load HF models."""

 import os
+import re
 import types
 from contextlib import contextmanager, nullcontext
 from typing import Any, Dict, List, Optional, Tuple, Union
@@ -99,6 +100,11 @@ class AutoModelForCausalLMFactory(ModelFactory):
         # set sharding config source to huggingface
         self._sharding_config["source"] = ShardingConfigSource.HUGGINGFACE

+        # Some models' transformers implementation has changed in between when safetensors were produced
+        # and / or uploaded to HuggingFace hub. When building the model, we will try to determine whether
+        # a mapping of the parameter names exists and hold that information in this attribute.
+        self._checkpoint_conversion_mapping: Optional[Dict[str, str]] = None
+
     @property
     def autoconfig_from_pretrained(self):
         return AutoConfig.from_pretrained
@@ -168,6 +174,7 @@ class AutoModelForCausalLMFactory(ModelFactory):

         # if present, initialize sharding config. We need head_dim for colwise sharding.
         self._set_sharding_config(model.config)
+        self._checkpoint_conversion_mapping = getattr(model, "_checkpoint_conversion_mapping", None)

         # patch forward method
         model.forward = types.MethodType(self._simple_forward, model)
@@ -326,15 +333,30 @@ class AutoModelForCausalLMFactory(ModelFactory):
         """Load the checkpoint into the model."""
         # identify the most relevant checkpoint file
         ckpt_file = self._get_checkpoint_file(self.model)

+        load_handle = model.register_load_state_dict_pre_hook(self._remap_param_names_load_hook)
+        # Ensure it's the first one.
+        model._load_state_dict_pre_hooks.move_to_end(key=load_handle.id, last=False)
+
+        get_handle = model.register_state_dict_post_hook(
+            _StateDictParamNameConverter(self._checkpoint_conversion_mapping)
+        )
+        # Ensure it's the first one.
+        model._state_dict_hooks.move_to_end(key=get_handle.id, last=False)
+
         # reuse the load checkpoint utility from accelerate
-        with hf_load_state_dict_with_device(device):
-            # Set `full_state_dict=False` to skip Accelerate's FSDP weight sync logic.
-            # Internally, load_checkpoint_in_model → set_model_state_dict → _load_model_state_dict,
-            # which collects local model params, syncs weights from checkpoint, and applies them via
-            # model.load_state_dict.
-            # This sync step can interfere with load_hooks by mixing raw checkpoint weights and
-            # model-transformed weights, leading to unexpected key mismatches or format issues.
-            load_checkpoint_in_model(model, checkpoint=ckpt_file, full_state_dict=False)
+        try:
+            with hf_load_state_dict_with_device(device):
+                # Set `full_state_dict=False` to skip Accelerate's FSDP weight sync logic.
+                # Internally, load_checkpoint_in_model → set_model_state_dict → _load_model_state_dict,
+                # which collects local model params, syncs weights from checkpoint, and applies them via
+                # model.load_state_dict.
+                # This sync step can interfere with load_hooks by mixing raw checkpoint weights and
+                # model-transformed weights, leading to unexpected key mismatches or format issues.
+                load_checkpoint_in_model(model, checkpoint=ckpt_file, full_state_dict=False)
+        finally:
+            load_handle.remove()
+            get_handle.remove()

     def _load_quantization_config(self, fetched_dir: str):
         """Load the quantization config from the model directory if not done already."""
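The hook registration above relies on the root module's load-state-dict pre-hooks running before parameter matching, which is why the handle is moved to the front of the hook ordering. A small standalone sketch of the same mechanism (toy module and a hypothetical "fc" to "linear" rename, not the TensorRT-LLM code):

```python
# Rename checkpoint keys in place before PyTorch matches them against module parameters.
import re

import torch
import torch.nn as nn


def remap_keys_hook(module, state_dict, *args, **kwargs):
    # Hypothetical mapping: the checkpoint calls the layer "fc", the module calls it "linear".
    mapping = {r"^fc\.": "linear."}
    for key in list(state_dict.keys()):
        new_key = key
        for pattern, replacement in mapping.items():
            new_key = re.sub(pattern, replacement, new_key)
        if new_key != key:
            state_dict[new_key] = state_dict.pop(key)


model = nn.Sequential()
model.add_module("linear", nn.Linear(4, 4))
handle = model.register_load_state_dict_pre_hook(remap_keys_hook)

old_style_ckpt = {"fc.weight": torch.zeros(4, 4), "fc.bias": torch.zeros(4)}
model.load_state_dict(old_style_ckpt)  # succeeds because the hook renamed the keys first
handle.remove()
```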
@@ -351,6 +373,63 @@ class AutoModelForCausalLMFactory(ModelFactory):
             self._quant_config_reader = reader
             self.model_kwargs = deep_merge_dicts(self.model_kwargs, extra_model_kwargs)

+    def _remap_param_names_load_hook(self, model, state_dict, *args, **kwargs) -> None:
+        """Hook to handle potential param name conversions.
+
+        Some models' transformers implementation can change in between when safetensors were produced
+        and / or uploaded to HuggingFace hub. This hook applies the mapping (when present) to reflect
+        these differences.
+        """
+        conversion_mapping = self._checkpoint_conversion_mapping
+        if conversion_mapping:
+            keys_to_process = list(state_dict.keys())
+            for key in keys_to_process:
+                new_key = key
+                for pattern, replacement in conversion_mapping.items():
+                    new_key = re.sub(pattern, replacement, new_key)
+
+                if new_key != key:
+                    state_dict[new_key] = state_dict.pop(key)
+
+
+class _StateDictParamNameConverter:
+    """Helper class for applying param name conversions to a state dict.
+
+    The reason this is a class instead of a method of the factory like `_remap_param_names_load_hook`
+    is because PyTorch tries to set an `_from_public_api` attribute on hooks, and bound instance
+    methods cannot have attributes set on them without major hacks.
+    """
+
+    def __init__(self, conversion_mapping: Optional[Dict[str, str]]):
+        conversion_mapping = conversion_mapping or {}
+
+        # NOTE: most of the code in this class is forked from `PreTrainedModel.save_pretrained`.
+        reverse_key_mapping = {v: k for k, v in conversion_mapping.items()}
+        self._mapping = reverse_key_mapping
+
+    def __call__(self, module, state_dict, *args, **kwargs) -> None:
+        """Hook to handle potential param name conversions.
+
+        For the same reasons as the `load` hook, we define one for `state_dict`. This is to silence
+        potentially misleading warnings about certain parameter names not being used, because the
+        `accelerate` library's logic for determining which keys are unexpected bases it on the keys
+        in the `module.state_dict()` return value, not on what `module.load_state_dict()` returns.
+        """
+        if self._mapping:
+            keys_to_process = list(state_dict.keys())
+            for key in keys_to_process:
+                new_key = key
+                for pattern, replacement in self._mapping.items():
+                    replacement = replacement.lstrip("^")  # strip off un-needed chars and patterns
+                    replacement = re.sub(r"\(.*\)", "", replacement)
+                    new_key, n_replace = re.subn(pattern, replacement, key)
+                    # Early exit of the loop
+                    if n_replace > 0:
+                        break
+
+                if new_key != key:
+                    state_dict[new_key] = state_dict.pop(key)
+
+
 @ModelFactoryRegistry.register("AutoModelForImageTextToText")
 class AutoModelForImageTextToTextFactory(AutoModelForCausalLMFactory):
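The two hooks apply the mapping in opposite directions: checkpoint names become module names at load time, and module names are mapped back to checkpoint names when the state dict is read. A toy round trip with a hypothetical mapping (not the real transformers `_checkpoint_conversion_mapping`) illustrates both directions:

```python
import re

# Hypothetical mapping: checkpoint keys start with "old_prefix", the module uses "new_prefix".
conversion_mapping = {"^old_prefix": "new_prefix"}

# Load direction: rename checkpoint keys so they match the module.
ckpt = {"old_prefix.weight": 1}
loaded = {}
for key, value in ckpt.items():
    new_key = key
    for pattern, replacement in conversion_mapping.items():
        new_key = re.sub(pattern, replacement, new_key)
    loaded[new_key] = value
assert "new_prefix.weight" in loaded

# state_dict direction: invert the mapping and strip anchors / capture groups from what
# now becomes the replacement string, mirroring the converter class above.
reverse_mapping = {v: k for k, v in conversion_mapping.items()}
module_sd = {"new_prefix.weight": 1}
reported = {}
for key, value in module_sd.items():
    new_key = key
    for pattern, replacement in reverse_mapping.items():
        replacement = replacement.lstrip("^")
        replacement = re.sub(r"\(.*\)", "", replacement)
        new_key, n_replace = re.subn(pattern, replacement, key)
        if n_replace > 0:
            break
    reported[new_key] = value
assert "old_prefix.weight" in reported
```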
@@ -426,17 +505,19 @@ class AutoModelForImageTextToTextFactory(AutoModelForCausalLMFactory):
             }
         ]

-        # Create a batch of conversations (batch_size = 2)
+        # Create a batch of conversations (batch_size = 2).
+        # Note that we explicitly use 2 images in the examples to avoid potential shape specialization(s)
+        # in `torch.compile` / `torch.export`.
         batch_messages = [
             _prep_seq(
                 "Describe what you see in the two images and their differences.",
-                Image.new("RGB", (16, 16), color=(128, 128, 128)),
-                Image.new("RGB", (16, 16), color=(64, 64, 64)),
+                Image.new("RGB", self._example_image_dims, color=(128, 128, 128)),
+                Image.new("RGB", self._example_image_dims, color=(64, 64, 64)),
             ),
             _prep_seq(
                 "What are the main differences between these two images?",
-                Image.new("RGB", (16, 16), color=(255, 0, 0)),
-                Image.new("RGB", (16, 16), color=(0, 255, 0)),
+                Image.new("RGB", self._example_image_dims, color=(255, 0, 0)),
+                Image.new("RGB", self._example_image_dims, color=(0, 255, 0)),
             ),
         ]

@@ -451,10 +532,15 @@ class AutoModelForImageTextToTextFactory(AutoModelForCausalLMFactory):
             return_attention_mask=False,
         )

-        return {
-            "input_ids": inputs["input_ids"],
-            "pixel_values": inputs["pixel_values"],
-        }
+        # We should have no need for the attention mask, and it can actually cause issues in
+        # downstream code.
+        inputs.pop("attention_mask", None)
+
+        # NOTES:
+        # 1. `inputs` is dict-like, but not a dict (hence the dict unpacking below).
+        # 2. Although `get_extra_inputs` allows implementations to specify "extra inputs", the example
+        #    values still need to be returned by `get_example_inputs`.
+        return {**inputs}

     def get_extra_inputs(self) -> Dict[str, Tuple[torch.Tensor, Optional[DynamicShapeCallback]]]:
         """Return a dictionary of extra inputs for the model.
@@ -476,3 +562,10 @@ class AutoModelForImageTextToTextFactory(AutoModelForCausalLMFactory):

         none_pixel_values = torch.zeros(0, 3, 336, 336)
         return {"pixel_values": (none_pixel_values, _get_dynamic_shape)}
+
+    @property
+    def _example_image_dims(self) -> Tuple[int, int]:
+        # Some specializations (children) of this class may override this if their models have
+        # assumptions on the image dimensions. For example, they may have a lower bound due to
+        # the patch size they use.
+        return (16, 16)
tensorrt_llm/_torch/auto_deploy/models/mistral3.py (new file, 56 lines)
@@ -0,0 +1,56 @@
"""Auto-deploy model factory for Mistral3 models."""

from typing import Dict, Tuple

import torch

from tensorrt_llm._torch.auto_deploy.custom_ops import attention_interface
from tensorrt_llm._torch.auto_deploy.models import factory, hf


@factory.ModelFactoryRegistry.register("Mistral3VLM")
class Mistral3VLM(hf.AutoModelForImageTextToTextFactory):
    def get_extra_inputs(
        self,
    ) -> Dict[str, Tuple[torch.Tensor, attention_interface.DynamicShapeCallback]]:
        """Return a dictionary of extra inputs for the model.

        Returns:
            A dictionary of extra inputs for the model where the key corresponds to the argument
            name and the value corresponds to a tuple of (example_input, dynamic_shape_callback).
            The dynamic shape callback is a function that returns the dynamic shape of the extra
            input.
        """
        extra_inputs = super().get_extra_inputs()
        # Reuse the same dynamic batch dimension for `image_sizes`.
        batch_dim = extra_inputs["pixel_values"][1]()[0]
        extra_inputs["image_sizes"] = (torch.zeros(0, 2, dtype=torch.long), lambda: {0: batch_dim})

        return extra_inputs

    @staticmethod
    def _simple_forward(
        model: torch.nn.Module,
        input_ids: torch.Tensor,
        position_ids: torch.Tensor,
        pixel_values: torch.Tensor,
        image_sizes: torch.Tensor,
    ):
        """A simple forward pass for the model to functionalize the args.

        This follows the standard function signature as expected by factory.py.
        """
        return type(model).forward(
            model,
            input_ids=input_ids,
            position_ids=position_ids,
            pixel_values=pixel_values,
            image_sizes=image_sizes,
        )

    @property
    def _example_image_dims(self) -> Tuple[int, int]:
        # The pixtral processor requires a minimum image size, which is larger than the default (16, 16)
        # in the parent class.
        # TODO: figure this out on the model config somehow (patch size value, etc.).
        return (64, 64)
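The factory above reuses the batch dimension returned by the parent's `pixel_values` callback so that `pixel_values` and `image_sizes` share one dynamic batch axis at export time. A standalone sketch of that idea with a toy module (not the Mistral3 wiring) and `torch.export`:

```python
# Sharing a single torch.export Dim across two inputs ties their batch dimensions together.
import torch
from torch.export import Dim, export


class Toy(torch.nn.Module):
    def forward(self, pixel_values: torch.Tensor, image_sizes: torch.Tensor) -> torch.Tensor:
        # Pretend each image contributes its pixel sum plus its (H * W) to a score.
        return pixel_values.flatten(1).sum(dim=1) + image_sizes.prod(dim=1)


batch = Dim("batch", min=1, max=8)
example = (torch.randn(2, 3, 64, 64), torch.tensor([[64, 64], [64, 64]]))
ep = export(
    Toy(),
    example,
    dynamic_shapes={
        "pixel_values": {0: batch},  # same Dim object on both inputs, so the exported
        "image_sizes": {0: batch},   # graph treats the two batch sizes as one symbol
    },
)
print(ep)
```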
tensorrt_llm/_torch/auto_deploy/models/patches/mistral3.py (new file, 179 lines)
@@ -0,0 +1,179 @@
"""A patch for the Mistral3Model to make it compatible with torch.export."""

from typing import List, Optional, Union

import torch
from transformers.models.mistral3.modeling_mistral3 import (
    Mistral3Model,
    Mistral3ModelOutputWithPast,
)

from ...export.interface import BaseExportPatch, ExportPatchRegistry


def _get_image_features_flat(
    self,
    pixel_values: torch.FloatTensor,
    image_sizes: torch.Tensor,
    vision_feature_layer: Optional[Union[int, List[int]]] = None,
    **kwargs,
):
    vision_feature_layer = (
        vision_feature_layer
        if vision_feature_layer is not None
        else self.config.vision_feature_layer
    )

    kwargs = {k: v for k, v in kwargs.items() if v is not None}
    image_outputs = self.vision_tower(
        pixel_values, image_sizes=image_sizes, output_hidden_states=True, **kwargs
    )

    if isinstance(vision_feature_layer, int):
        selected_image_feature = image_outputs.hidden_states[vision_feature_layer]
    else:
        hs_pool = [image_outputs.hidden_states[layer_idx] for layer_idx in vision_feature_layer]
        selected_image_feature = torch.cat(hs_pool, dim=-1)

    image_features = self.multi_modal_projector(selected_image_feature.squeeze(0), image_sizes)
    image_features = image_features.squeeze(0)
    return image_features


# NOTE: the main reason for this patch's existence is the `torch.cond` branching logic to handle the
# presence / absence of image features in a `torch.export`-compatible way.
def _mistral_forward(
    self,
    input_ids: torch.LongTensor = None,
    pixel_values: torch.FloatTensor = None,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_values: Optional[List[torch.FloatTensor]] = None,
    inputs_embeds: Optional[torch.FloatTensor] = None,
    vision_feature_layer: Optional[Union[int, List[int]]] = None,
    use_cache: Optional[bool] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    return_dict: Optional[bool] = None,
    cache_position: Optional[torch.LongTensor] = None,
    image_sizes: Optional[torch.Tensor] = None,
    **kwargs,
) -> Union[tuple, Mistral3ModelOutputWithPast]:
    output_attentions = (
        output_attentions if output_attentions is not None else self.config.output_attentions
    )
    output_hidden_states = (
        output_hidden_states
        if output_hidden_states is not None
        else self.config.output_hidden_states
    )
    return_dict = return_dict if return_dict is not None else self.config.use_return_dict
    vision_feature_layer = (
        vision_feature_layer
        if vision_feature_layer is not None
        else self.config.vision_feature_layer
    )

    if (input_ids is None) ^ (inputs_embeds is not None):
        raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

    if pixel_values is not None and inputs_embeds is not None:
        raise ValueError(
            "You cannot specify both pixel_values and inputs_embeds at the same time, and must specify either one"
        )

    if inputs_embeds is None:
        inputs_embeds = self.get_input_embeddings()(input_ids)

    def _no_vision_branch(
        input_ids: torch.LongTensor,
        inputs_embeds: torch.FloatTensor,
        pixel_values: torch.Tensor,
        image_sizes: Optional[torch.Tensor],
    ):
        return inputs_embeds

    def _vision_branch(
        # ! The type annotations in the original transformers code are all wrong.
        input_ids: torch.LongTensor,
        inputs_embeds: torch.FloatTensor,
        pixel_values: torch.Tensor,
        image_sizes: Optional[torch.Tensor],
    ):
        pixel_values = pixel_values.to(torch.bfloat16)
        image_features = self.get_image_features(
            pixel_values=pixel_values,
            vision_feature_layer=vision_feature_layer,
            image_sizes=image_sizes,
        )
        # HF returns a list of tensors; our patch may already return a single tensor.
        # Only concatenate when a list/tuple is returned.
        if isinstance(image_features, (list, tuple)):
            image_features = torch.cat(image_features, dim=0)

        special_image_mask = (input_ids == self.config.image_token_id).unsqueeze(-1)
        special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device)
        image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
        inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)

        return inputs_embeds

    # Decide by whether there is any non-zero pixel_values.
    has_image: torch.Tensor = (pixel_values is not None) and torch.any(pixel_values != 0)

    # `torch.cond` serves 2 purposes here:
    # 1. It lets the export stage know that there could be both image and no-image branches.
    #    Without this, the export stage would just assume that whatever the example input contains
    #    is representative of _all_ inputs at runtime. This means that, if we export it with images
    #    in the inputs, it would crash when called without images (i.e. in text-only mode).
    # 2. It introduces a subgraph, which the pattern matcher will ignore. This is important as we
    #    do not want the vision model's attention ops to be converted by the pattern matcher to have
    #    KV cache enabled on them, as it would be both unnecessary to do so and potentially bad for
    #    performance.
    inputs_embeds = torch.cond(
        has_image,
        _vision_branch,
        _no_vision_branch,
        (input_ids, inputs_embeds, pixel_values, image_sizes),
    )

    outputs = self.language_model(
        attention_mask=attention_mask,
        position_ids=position_ids,
        past_key_values=past_key_values,
        inputs_embeds=inputs_embeds,
        use_cache=use_cache,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
        return_dict=True,
        cache_position=cache_position,
        **kwargs,
    )

    return Mistral3ModelOutputWithPast(
        last_hidden_state=outputs.last_hidden_state,
        past_key_values=outputs.past_key_values,
        hidden_states=outputs.hidden_states,
        attentions=outputs.attentions,
        # NOTE: this is hardcoded since we make no use of this.
        image_hidden_states=None,
    )


@ExportPatchRegistry.register("hf_mistral3")
class Mistral3ModelPatch(BaseExportPatch):
    """Patch for `Mistral3Model`."""

    def _apply_patch(self):
        """Apply the Mistral3Model patch."""
        self.original_values["Mistral3Model.forward"] = Mistral3Model.forward
        self.original_values["Mistral3Model.get_image_features"] = Mistral3Model.get_image_features

        Mistral3Model.forward = _mistral_forward
        Mistral3Model.get_image_features = _get_image_features_flat

    def _revert_patch(self):
        """Revert the Mistral3Model patch."""
        # Restore original forward method.
        Mistral3Model.forward = self.original_values["Mistral3Model.forward"]
        Mistral3Model.get_image_features = self.original_values["Mistral3Model.get_image_features"]
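The key constraint behind the `torch.cond` usage above is that both branches must accept the same operands and return tensors with matching shape and dtype, which is why the no-vision branch simply passes `inputs_embeds` through. A minimal, self-contained sketch of that pattern (a toy module, not the Mistral3 patch itself):

```python
import torch


class CondDemo(torch.nn.Module):
    def forward(self, inputs_embeds: torch.Tensor, pixel_values: torch.Tensor) -> torch.Tensor:
        def with_images(inputs_embeds, pixel_values):
            # Stand-in for "scatter image features into the text embeddings".
            return inputs_embeds + pixel_values.mean()

        def without_images(inputs_embeds, pixel_values):
            return inputs_embeds

        has_image = torch.any(pixel_values != 0)
        return torch.cond(has_image, with_images, without_images, (inputs_embeds, pixel_values))


# Both branches show up as subgraphs in the exported program, so the text-only path is
# not specialized away when the example input happens to contain images (or vice versa).
ep = torch.export.export(CondDemo(), (torch.randn(2, 4, 8), torch.zeros(2, 3, 16, 16)))
print(ep)
```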
tensorrt_llm/_torch/auto_deploy/models/patches/pixtral.py (new file, 231 lines)
@@ -0,0 +1,231 @@
"""Patches for the PixtralVisionModel to make it compatible with `torch.export`.

On top of the patching, `custom_op`s are registered to replace specific parts of the Pixtral model's
forward pass that are not compatible with `torch.export`. Note that the `register_fake` portion of
the ops needs to return the shape (and dtype) of the output tensor(s) without accessing the values in
the input tensors, which is where things get tricky, and why so many custom ops / patches are needed.
"""

import torch
from transformers.models.mistral3.modeling_mistral3 import Mistral3PatchMerger
from transformers.models.pixtral.modeling_pixtral import (
    PixtralRMSNorm,
    PixtralVisionModel,
    position_ids_in_meshgrid,
)

from ...export.interface import BaseExportPatch, ExportPatchRegistry

# NOTES:
# 1. Everything decorated by a `custom_op` must be type annotated.
# 2. The annotations must be one of the internally supported param types. As such, `self: PixtralVisionModel`
#    is a no-go.
# 3. This means that pretty much only free-standing functions with tensor inputs are supported - instance
#    methods cannot be decorated.


@torch.library.custom_op("auto_deploy::pixtral_process_patch_embeds", mutates_args={})
def _process_patch_embeds(
    patch_embeds: torch.Tensor,
    image_sizes: torch.Tensor,
    patch_size: int,
    hidden_size: int,
    max_width: int,
) -> tuple[torch.Tensor, torch.Tensor]:
    patch_embeds_list = []
    for embed, size in zip(patch_embeds, image_sizes):
        # size is a 1-D tensor [H, W]; convert to Python ints for indexing.
        h = int((size[0] // patch_size).item())
        w = int((size[1] // patch_size).item())
        patch_embeds_list.append(embed[..., :h, :w])

    # flatten to a single sequence
    patch_embeds = torch.cat([p.flatten(1).T for p in patch_embeds_list], dim=0).unsqueeze(0)

    position_ids = position_ids_in_meshgrid(patch_embeds_list, max_width=max_width)

    return patch_embeds, position_ids


@_process_patch_embeds.register_fake
def _process_patch_embeds_meta(
    patch_embeds: torch.Tensor,
    image_sizes: torch.Tensor,
    patch_size: int,
    hidden_size: int,
    max_width: int,
):
    B = (image_sizes // patch_size).prod(dim=1).sum()
    device = patch_embeds.device
    return (
        # Leading 1 = `unsqueeze(0)` after concatenating the `patch_embeds_list`.
        torch.empty(1, B, hidden_size, device=device),
        torch.empty(B, device=device, dtype=torch.int64),
    )


def _pixtral_forward(
    self: PixtralVisionModel,
    pixel_values: torch.Tensor,
    image_sizes: torch.Tensor | None,
    output_hidden_states: bool | None = None,
    output_attentions: bool | None = None,
    return_dict: bool | None = None,
    *args,
    **kwargs,
):
    if image_sizes is None:
        batch_size, _, height, width = pixel_values.shape
        image_sizes = torch.tensor([(height, width)] * batch_size, device=pixel_values.device)

    # pass images through initial convolution independently
    patch_embeds = self.patch_conv(pixel_values)
    patch_embeds, position_ids = torch.ops.auto_deploy.pixtral_process_patch_embeds(
        patch_embeds=patch_embeds,
        image_sizes=image_sizes,
        patch_size=self.patch_size,
        hidden_size=self.config.hidden_size,
        max_width=self.config.image_size // self.config.patch_size,
    )

    patch_embeds = self.ln_pre(patch_embeds)

    # Constrain sequence length to be size-like and > 1 for export guards.
    _seq_len = patch_embeds.shape[1]
    torch._check_is_size(_seq_len)
    torch._check(_seq_len > 1)

    position_embeddings = self.patch_positional_embedding(patch_embeds, position_ids)

    if self.config._attn_implementation == "flash_attention_2":
        # We only rely on position_ids when using flash_attention_2
        attention_mask = None
    else:
        attention_mask = generate_block_attention_mask(
            (image_sizes // self.config.patch_size).prod(dim=1),
            patch_embeds,
        )

    out = self.transformer(
        patch_embeds,
        attention_mask=attention_mask,
        position_ids=position_ids,
        position_embeddings=position_embeddings,
        output_hidden_states=output_hidden_states,
        output_attentions=output_attentions,
        return_dict=True,
        **kwargs,
    )
    return out


def generate_block_attention_mask(num_ids_per_image, tensor):
    dtype = tensor.dtype
    device = tensor.device

    if not isinstance(num_ids_per_image, torch.Tensor):
        num_ids_per_image = torch.as_tensor(num_ids_per_image, device=device, dtype=torch.long)
    else:
        num_ids_per_image = num_ids_per_image.to(device=device, dtype=torch.long)

    # Build per-token block ids: [0 repeated n0, 1 repeated n1, ...].
    block_ids = torch.repeat_interleave(
        torch.arange(num_ids_per_image.numel(), device=device), num_ids_per_image
    )
    # same_block[i, j] is True if tokens i and j belong to the same image block.
    same_block = block_ids[:, None] == block_ids[None, :]

    # Mask: 0 inside blocks, 1 outside blocks (match previous function's output), tensor-only.
    mask = (~same_block).to(dtype)
    d_min = torch.finfo(dtype).min
    mask *= d_min

    return mask


@torch.library.custom_op("auto_deploy::pixtral_unfold_to_2d_grid", mutates_args={})
def _unfold_to_2d_grid(
    image_features: torch.Tensor,
    image_sizes: torch.Tensor,
    patch_size: int,
    spatial_merge_size: int,
) -> torch.Tensor:
    image_sizes = [
        (image_size[0] // patch_size, image_size[1] // patch_size) for image_size in image_sizes
    ]

    tokens_per_image = [h * w for h, w in image_sizes]
    d = image_features.shape[-1]

    permuted_tensor = []
    for image_index, image_tokens in enumerate(image_features.split(tokens_per_image)):
        # Reshape image_tokens into a 2D grid
        h, w = image_sizes[image_index]
        image_grid = image_tokens.view(h, w, d).permute(2, 0, 1).unsqueeze(0)
        grid = torch.nn.functional.unfold(
            image_grid, kernel_size=spatial_merge_size, stride=spatial_merge_size
        )
        grid = grid.view(d * spatial_merge_size**2, -1).t()
        permuted_tensor.append(grid)

    image_features = torch.cat(permuted_tensor, dim=0)

    return image_features


@_unfold_to_2d_grid.register_fake
def _unfold_to_2d_grid_meta(
    image_features: torch.Tensor,
    image_sizes: torch.Tensor,
    patch_size: int,
    spatial_merge_size: int,
):
    embedding_sizes = (image_sizes // patch_size).prod(dim=1)
    spatial_factor = spatial_merge_size * spatial_merge_size
    grid_sizes = embedding_sizes // spatial_factor
    total_size = grid_sizes.sum()

    return image_features.new_empty(total_size, image_features.shape[-1] * spatial_factor)


def _patch_merger_forward(
    self, image_features: torch.Tensor, image_sizes: torch.Tensor
) -> torch.Tensor:
    unfolded_features = torch.ops.auto_deploy.pixtral_unfold_to_2d_grid(
        image_features=image_features,
        image_sizes=image_sizes,
        patch_size=self.patch_size,
        spatial_merge_size=self.spatial_merge_size,
    )
    image_features = self.merging_layer(unfolded_features)
    return image_features


# Somehow there are dtype mismatches at runtime between bfloat16 and float32 without this.
def _pixtral_rms_norm_forward(self, hidden_states):
    input_dtype = torch.bfloat16
    hidden_states = hidden_states.to(torch.float32)
    variance = hidden_states.pow(2).mean(-1, keepdim=True)
    hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
    return self.weight * hidden_states.to(input_dtype)


@ExportPatchRegistry.register("hf_pixtral_vit")
class PixtralVisionModelPatch(BaseExportPatch):
    """Patch for `PixtralVisionModel`."""

    def _apply_patch(self):
        """Apply the PixtralVisionModel patch."""
        self.original_values["PixtralVisionModel.forward"] = PixtralVisionModel.forward
        self.original_values["Mistral3PatchMerger.forward"] = Mistral3PatchMerger.forward
        self.original_values["PixtralRMSNorm.forward"] = PixtralRMSNorm.forward

        PixtralVisionModel.forward = _pixtral_forward
        Mistral3PatchMerger.forward = _patch_merger_forward
        PixtralRMSNorm.forward = _pixtral_rms_norm_forward

    def _revert_patch(self):
        """Revert the PixtralVisionModel patch."""
        PixtralVisionModel.forward = self.original_values["PixtralVisionModel.forward"]
        Mistral3PatchMerger.forward = self.original_values["Mistral3PatchMerger.forward"]
        PixtralRMSNorm.forward = self.original_values["PixtralRMSNorm.forward"]
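The essential trick in both custom ops above is the split between a real kernel that may do data-dependent work and a fake kernel that derives output shapes purely from input metadata. A minimal sketch of that pattern with a made-up `demo::row_select` op (not part of TensorRT-LLM), mirroring how the fake above computes sizes from `image_sizes` rather than from tensor values:

```python
import torch


@torch.library.custom_op("demo::row_select", mutates_args={})
def row_select(x: torch.Tensor, lengths: torch.Tensor) -> torch.Tensor:
    # Data-dependent: keep the first lengths[i] rows of each item and concatenate them.
    pieces = [xi[: int(n.item())] for xi, n in zip(x, lengths)]
    return torch.cat(pieces, dim=0)


@row_select.register_fake
def _row_select_fake(x: torch.Tensor, lengths: torch.Tensor) -> torch.Tensor:
    # Shape-only: the total row count is the sum of lengths; never read values of `x`.
    total = lengths.sum()
    return x.new_empty(total, x.shape[-1])


x = torch.randn(2, 5, 8)
lengths = torch.tensor([3, 2])
out = torch.ops.demo.row_select(x, lengths)
assert out.shape == (5, 8)
```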
@@ -121,6 +121,11 @@ class DemoEngine(ADEngine):
         batch_size = sequence_info.num_sequences
         new_tokens = [[] for _ in range(batch_size)]  # [batch_size][max_seq_len]
         stop_tokens = sampling_params._get_stop_words()
+        # NOTE: TRTLLM has made the intentional choice to separate `end_id` from `stop_words`, and not
+        # include the former in the latter's corresponding stop IDs. From a UX perspective, `stop_words`
+        # are optional, and can be customized per user requests, whereas `end_id` is static per model,
+        # and should always be used outside of benchmarking.
+        stop_tokens.append([sampling_params.end_id])
         idxs_stop = [sampling_params.max_tokens - 1] * batch_size
         gen_logits = [] if sampling_params.return_generation_logits else None
         context_logits: Optional[List[torch.Tensor]] = None
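This is the sampling fix from the commit message: the user-supplied stop words do not automatically contain the model's EOS token, so `end_id` must be appended explicitly. A toy illustration with hypothetical token ids:

```python
# Hypothetical values for illustration only.
end_id = 2                    # model EOS token id
stop_words = [[13, 13]]       # user-provided stop sequence(s)

stop_tokens = list(stop_words)
stop_tokens.append([end_id])  # without this, generation would run past EOS

generated = [5, 9, 2]
hit_stop = any(generated[-len(seq):] == seq for seq in stop_tokens)
assert hit_stop  # the trailing EOS token now terminates the sequence
```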
@@ -434,6 +434,15 @@ _SMALL_MODEL_CONFIGS = {
             "num_hidden_layers": 2,
         },
     },
+    "mistralai/Mistral-Small-3.1-24B-Instruct-2503": {
+        "model": f"{llm_models_root()}/Mistral-Small-3.1-24B-Instruct-2503",
+        "model_factory": "Mistral3VLM",
+        "compile_backend": "torch-simple",
+        "model_kwargs": {
+            "text_config": {"num_hidden_layers": 2},
+            "vision_config": {"num_hidden_layers": 2},
+        },
+    },
 }

(new test file)
@@ -0,0 +1,15 @@
from tensorrt_llm._torch.auto_deploy.models import mistral3


def test_get_extra_inputs_includes_image_sizes():
    factory = mistral3.Mistral3VLM(model="test-model")
    extra_inputs = factory.get_extra_inputs()

    pixel_values = extra_inputs["pixel_values"]
    image_sizes = extra_inputs["image_sizes"]

    pixel_values_dynamic_shape = pixel_values[1]()
    image_sizes_dynamic_shape = image_sizes[1]()

    # Unfortunately, direct object comparisons do not work.
    assert pixel_values_dynamic_shape[0].__dict__ == image_sizes_dynamic_shape[0].__dict__
(new test file)
@@ -0,0 +1,90 @@
import torch
from _model_test_utils import get_small_model_config
from build_and_run_ad import ExperimentConfig

from tensorrt_llm._torch.auto_deploy import LlmArgs
from tensorrt_llm._torch.auto_deploy.export import apply_export_patches, torch_export_to_gm
from tensorrt_llm._torch.auto_deploy.transformations._graph import move_to_device


def test_build_run_mistral3_vlm():
    experiment_config = get_small_model_config("mistralai/Mistral-Small-3.1-24B-Instruct-2503")
    experiment_config = ExperimentConfig(**experiment_config)
    llm_args: LlmArgs = experiment_config.args

    factory = llm_args.create_factory()
    model = factory.build_model("cuda")

    inputs = factory.get_example_inputs()
    for key, value in inputs.items():
        if isinstance(value, torch.Tensor):
            dtype = torch.bfloat16 if isinstance(value, torch.FloatTensor) else None
            inputs[key] = value.to(device=model.device, dtype=dtype)

    # get relevant inputs
    input_ids = inputs["input_ids"]
    position_ids = torch.arange(input_ids.shape[1], device=input_ids.device).repeat(
        input_ids.shape[0], 1
    )
    pixel_values = inputs["pixel_values"]
    image_sizes = inputs["image_sizes"]

    def _run_with_and_without_image(model, use_patch=True):
        with apply_export_patches(
            patch_list=["hf_mistral3", "hf_pixtral_vit"] if use_patch else []
        ):
            with torch.inference_mode():
                out_no_images = model(
                    input_ids=input_ids,
                    position_ids=position_ids,
                    pixel_values=torch.zeros_like(pixel_values) if use_patch else None,
                    image_sizes=image_sizes if use_patch else None,
                )
                out_with_images = model(
                    input_ids=input_ids,
                    position_ids=position_ids,
                    pixel_values=pixel_values,
                    image_sizes=image_sizes,
                )
        return {"no_images": out_no_images.logits, "with_images": out_with_images.logits}

    # Get output pre-patch.
    out_original = _run_with_and_without_image(model, use_patch=False)

    # Get output post-patch.
    outputs_for_comparison = {}
    # TODO(2ez4bz): Figure out why the patches do not work outside of `torch_export_to_gm`.
    # outputs_for_comparison["model_with_patch"] = _run_with_and_without_image(model)

    gm = torch_export_to_gm(
        model,
        args=(),
        kwargs={
            "input_ids": input_ids,
            "position_ids": position_ids,
            "pixel_values": pixel_values,
            "image_sizes": image_sizes,
        },
        patch_list=[
            "transformers_sdpa_mask",
            "autocast_noop",
            "torch_where",
            "tensor_meta_device",
            "sdpa_kernel_noop",
            "sdpa",
            "hf_mistral3",
            "hf_pixtral_vit",
        ],
    )
    move_to_device(gm, model.device)

    outputs_for_comparison["gm"] = _run_with_and_without_image(gm)

    atol, rtol = 1e-3, 1e-3
    for comp, outs in outputs_for_comparison.items():
        torch.testing.assert_close(
            outs,
            out_original,
            rtol=rtol,
            atol=atol,
        )
@@ -1,5 +1,6 @@
 from unittest.mock import MagicMock, patch

+import pydantic
 import pytest

 from tensorrt_llm._torch.auto_deploy import LLM, DemoLLM, LlmArgs
@@ -147,6 +148,21 @@ def test_config_flow(
     pass


+@pytest.mark.parametrize(
+    "model_factory",
+    [
+        "Foo",
+        # typo.
+        "AutomodelForCausalLMFactory",
+    ],
+)
+def test_non_registered_model_factory(model_factory: str):
+    with pytest.raises(
+        pydantic.ValidationError, match="does not exist in the model factory registry"
+    ):
+        LlmArgs(model="test-model", model_factory=model_factory)
+
+
 @pytest.mark.parametrize(
     "parallel_field,invalid_value",
     [
@@ -84,6 +84,15 @@ def _check_ad_config(experiment_config: ExperimentConfig, llm_args: LlmArgs):
             attn_backend="triton",
             compile_backend="torch-compile",
         ),
+        pytest.param(
+            get_small_model_config(
+                "mistralai/Mistral-Small-3.1-24B-Instruct-2503",
+                attn_backend="flashinfer",
+                compile_backend="torch-simple",
+            ),
+            # Human readable name for readability / easier selection with `-k`.
+            id="mistral-small-3.1-24b",
+        ),
     ],
 )
 def test_build_ad(experiment_config: Dict):