Mirror of https://github.com/NVIDIA/TensorRT-LLM.git (synced 2026-01-13 22:18:36 +08:00)
feat(models): Mistral3.1 VLM pytorch backend support (#5529)
Signed-off-by: William Zhang <133824995+2ez4bz@users.noreply.github.com>
parent: b61a717275
commit: 87fe44fd29
@@ -18,6 +18,7 @@ extend_skip_glob = [
    "tensorrt_llm/_dlpack_utils.py",
    "tensorrt_llm/_ipc_utils.py",
    "tensorrt_llm/_mnnvl_utils.py",
    "tensorrt_llm/_torch/models/modeling_pixtral.py",
    "tensorrt_llm/disaggregated_params.py",
    "tensorrt_llm/engine.py",
    "tensorrt_llm/graph_rewriting.py",
@@ -30,6 +31,8 @@ extend_skip_glob = [
    "tensorrt_llm/python_plugin.py",
    "tensorrt_llm/sampling_params.py",
    "tensorrt_llm/top_model_mixin.py",
    "tests/unittest/_torch/modeling/test_modeling_mistral.py",
    "tests/unittest/_torch/modeling/test_modeling_pixtral.py",
]

[tool.yapf]
@@ -45,6 +48,7 @@ ignore_patterns = [
    "tensorrt_llm/_dlpack_utils.py",
    "tensorrt_llm/_ipc_utils.py",
    "tensorrt_llm/_mnnvl_utils.py",
    "tensorrt_llm/_torch/models/modeling_pixtral.py",
    "tensorrt_llm/disaggregated_params.py",
    "tensorrt_llm/engine.py",
    "tensorrt_llm/graph_rewriting.py",
@@ -57,6 +61,8 @@ ignore_patterns = [
    "tensorrt_llm/python_plugin.py",
    "tensorrt_llm/sampling_params.py",
    "tensorrt_llm/top_model_mixin.py",
    "tests/unittest/_torch/modeling/test_modeling_mistral.py",
    "tests/unittest/_torch/modeling/test_modeling_pixtral.py",
]

[tool.codespell]
@@ -76,6 +82,7 @@ exclude = [
    "tensorrt_llm/_dlpack_utils.py",
    "tensorrt_llm/_ipc_utils.py",
    "tensorrt_llm/_mnnvl_utils.py",
    "tensorrt_llm/_torch/models/modeling_pixtral.py",
    "tensorrt_llm/disaggregated_params.py",
    "tensorrt_llm/engine.py",
    "tensorrt_llm/graph_rewriting.py",
@@ -88,6 +95,8 @@ exclude = [
    "tensorrt_llm/python_plugin.py",
    "tensorrt_llm/sampling_params.py",
    "tensorrt_llm/top_model_mixin.py",
    "tests/unittest/_torch/modeling/test_modeling_mistral.py",
    "tests/unittest/_torch/modeling/test_modeling_pixtral.py",
]

@@ -116,6 +125,7 @@ include = [
    "tensorrt_llm/_dlpack_utils.py",
    "tensorrt_llm/_ipc_utils.py",
    "tensorrt_llm/_mnnvl_utils.py",
    "tensorrt_llm/_torch/models/modeling_pixtral.py",
    "tensorrt_llm/disaggregated_params.py",
    "tensorrt_llm/engine.py",
    "tensorrt_llm/graph_rewriting.py",
@@ -128,6 +138,8 @@ include = [
    "tensorrt_llm/python_plugin.py",
    "tensorrt_llm/sampling_params.py",
    "tensorrt_llm/top_model_mixin.py",
    "tests/unittest/_torch/modeling/test_modeling_mistral.py",
    "tests/unittest/_torch/modeling/test_modeling_pixtral.py",
]
exclude = [
    "**3rdparty/**",
@@ -1,24 +1,43 @@
from typing import Any, Optional, Tuple
import dataclasses
import os
from typing import Any, Dict, List, Optional, Tuple

import torch
from torch import nn
from transformers import MistralConfig
from transformers import (AutoProcessor, AutoTokenizer, Mistral3Config,
                          MistralConfig, PretrainedConfig, PreTrainedModel)
from transformers.activations import ACT2FN

from tensorrt_llm._torch.attention_backend import AttentionMetadata
from tensorrt_llm._torch.attention_backend.interface import (
    PositionalEmbeddingParams, RopeParams)
from tensorrt_llm._torch.model_config import ModelConfig
from tensorrt_llm._torch.models import modeling_pixtral
from tensorrt_llm._torch.models.modeling_multimodal_utils import fuse_input_embeds
from tensorrt_llm._torch.models.modeling_utils import (
    DecoderModel, DecoderModelForCausalLM, _load_weights_impl, register_auto_model)
from tensorrt_llm._torch.modules.attention import Attention
from tensorrt_llm._torch.modules.decoder_layer import DecoderLayer
from tensorrt_llm._torch.modules.embedding import Embedding
from tensorrt_llm._torch.modules.gated_mlp import GatedMLP
from tensorrt_llm._torch.modules.linear import TensorParallelMode
from tensorrt_llm._torch.modules.linear import Linear, TensorParallelMode
from tensorrt_llm._torch.modules.rms_norm import RMSNorm
from tensorrt_llm._torch.speculative import SpecMetadata
from tensorrt_llm.functional import PositionEmbeddingType
from tensorrt_llm.inputs import (ExtraProcessedInputs, InputProcessor,
                                 TextPrompt, register_input_processor)
from tensorrt_llm.llmapi import SamplingParams
from tensorrt_llm.logger import logger

_MULTIMODAL_ENV_NAME = "TLLM_MULTIMODAL_DISAGGREGATED"


# Make this a runtime lookup rather than a module-wide constant for easier unit testing.
def _is_disagg() -> bool:
    return os.getenv(_MULTIMODAL_ENV_NAME, "0") == "1"


class MistralAttention(Attention):
@@ -187,3 +206,337 @@ class MistralForCausalLM(DecoderModelForCausalLM[MistralModel, MistralConfig]):
            hidden_size=model_config.pretrained_config.hidden_size,
            vocab_size=model_config.pretrained_config.vocab_size,
        )


class Mistral3InputProcessor(InputProcessor):

    def __init__(
        self,
        model_path: str,
        model_config: PretrainedConfig,
        tokenizer: Optional[AutoTokenizer],
        trust_remote_code: bool = False,
    ):
        if tokenizer is None:
            tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)

        # To abide by the `InputProcessor` interface.
        self.model_path = model_path
        self.model_config = model_config
        self.tokenizer = tokenizer

        self._device = "cuda"
        self._processor = AutoProcessor.from_pretrained(model_path, use_fast=False)

    @torch.inference_mode()
    def __call__(
        self, inputs: TextPrompt, sampling_params: SamplingParams
    ) -> Tuple[List[int], Optional[ExtraProcessedInputs]]:
        images = inputs.get("multi_modal_data", {}).get("image")
        do_rescale = self._processor.image_processor.do_rescale
        if images is not None and isinstance(images[0], torch.Tensor):
            # The default multimodal input loader will normalize images to [0, 1] when the requested
            # format is "pt" (pytorch tensors), but not for "pil" (PIL images).
            do_rescale = False

        processed = self._processor(
            text=inputs["prompt"],
            images=images,
            do_rescale=do_rescale,
        )
        input_ids = processed.pop("input_ids").tolist()[0]
        # Remaining in `processed`:
        # * "attention_mask": [B, num_input_tokens]
        # * "pixel_values": [B, C, H, W]
        # * "image_sizes": [B, 2]
        extra_processed_inputs = None
        pixel_values = processed.get("pixel_values")
        if pixel_values is not None:
            # We have no use for the `attention_mask`.
            processed.pop("attention_mask")
            processed = processed.to(self._device)
            # NOTE: `processed` is a dict-like object, but not actually a dict.
            extra_processed_inputs = {
                "multimodal_data": {
                    "image": {
                        **processed
                    }
                },
            }

        return input_ids, extra_processed_inputs


@register_auto_model("Mistral3ForConditionalGeneration")
# The below informs the registry which input processor to create for this model in `tensorrt_llm/llmapi/llm.py`.
@register_input_processor(Mistral3InputProcessor, model_type="mistral3")
class Mistral3VLM(PreTrainedModel):
    """Mistral3VLM implementation for TRTLLM.

    NOTE: for the time being, image tokens are only placed after the text (see
    `tensorrt_llm/inputs/utils.py`).
    """

    def __init__(
        self,
        model_config: ModelConfig[Mistral3Config],
    ):
        if _is_disagg():
            raise NotImplementedError(
                "Mistral3VLM does not support disaggregated inference yet. Please unset "
                f"the {_MULTIMODAL_ENV_NAME} environment variable, or set it to '0'."
            )

        config = model_config.pretrained_config
        super().__init__(config)

        self.model_config = model_config

        llm_model_config = self._get_sub_model_config(model_config, "text_config")
        self.llm = MistralForCausalLM(llm_model_config)

        self._device = "cuda"
        vision_model_config = self._get_sub_model_config(model_config, "vision_config")
        self._vision_tower = modeling_pixtral.PixtralVisionModel(vision_model_config)
        self._multi_modal_projector = Mistral3MultiModalProjector(model_config)
        vision_feature_layer = config.vision_feature_layer
        if vision_feature_layer != -1:
            raise ValueError(
                f"Using intermediate layers ({vision_feature_layer}) in the `PixtralVisionModel` "
                f"is not supported. Please use `vision_feature_layer=-1`.")

        self.model_dtype = getattr(config, "torch_dtype", torch.bfloat16)

        self._image_token_ids = torch.tensor([config.image_token_index],
                                             dtype=torch.int32,
                                             device=self._device)
        self._post_config()

    # This is necessary because the executor looks at
    # `model.model_config.pretrained_config.vocab_size`.
    def _post_config(self):
        self.config = self.llm.config
        self.model_config.pretrained_config = self.llm.config

    def load_weights(self, weights: Dict, *args, **kwargs):
        llm_weights = _filter_weights(weights, "language_model.")
        self.llm.load_weights(llm_weights, *args, **kwargs)

        vit_weights = _filter_weights(weights, "vision_tower.")
        self._vision_tower.load_weights(vit_weights, *args, **kwargs)

        mm_projector_weights = _filter_weights(weights, "multi_modal_projector.")
        # `_load_weights_impl` assumes `config.hidden_size` exists, which is not the case for the
        # top-level `Mistral3Config`.
        self._multi_modal_projector.load_state_dict(mm_projector_weights)

    def infer_max_seq_len(self) -> int:
        return self.llm.infer_max_seq_len()

    @torch.inference_mode()
    def forward(
        self,
        attn_metadata: AttentionMetadata,
        input_ids: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        return_context_logits: bool = False,
        **kwargs,
    ) -> torch.Tensor:
        """Forward method."""
        num_context_requests, num_generation_requests = attn_metadata.num_contexts, attn_metadata.num_generations
        logger.debug(f"{num_context_requests=}, {num_generation_requests=}")

        multimodal_params = kwargs.get("multimodal_params", [])
        image_features = []
        multimodal_params_len = len(multimodal_params)
        if multimodal_params_len > 0:
            if multimodal_params_len != num_context_requests:
                raise RuntimeError(
                    f"Number of multimodal tensors ({multimodal_params_len}) should be equal to number of "
                    f"context requests ({num_context_requests}) in the batch.")
            # NOTES:
            # 1. The pixel values in `multimodal_data["image"]` might vary in (height, width) between
            #    images, making them unsafe to batch in general. The input processor also cannot produce
            #    them in a batch, since it is always called with a single input - otherwise, we would
            #    have been able to naturally leverage the padding / resizing capabilities of the underlying
            #    `PixtralProcessor`.
            # 2. After each `pixel_values` tensor has gone through the vision tower's `patch_conv` layer,
            #    they are divided into patches that are then concatenated in order to treat them as a
            #    single "sequence" in the vision tower's attention layers, so some form of batching still
            #    happens in the vision tower.
            image_features = [
                self._get_image_features(**x.multimodal_data["image"])
                for x in multimodal_params
            ]

        input_ids, inputs_embeds = fuse_input_embeds(
            embedding_layer=self.llm.model.embed_tokens,
            input_ids=input_ids,
            mm_embeds=image_features,
            mm_token_ids=self._image_token_ids,
        )

        return self.llm.forward(
            attn_metadata=attn_metadata,
            input_ids=input_ids,
            position_ids=position_ids,
            inputs_embeds=inputs_embeds,
            return_context_logits=return_context_logits,
        )

    @staticmethod
    def _get_sub_model_config(
        model_config: ModelConfig[MistralConfig],
        name: str,
    ) -> ModelConfig:
        # Extract the subconfig from the `transformers` config and shove it into our own
        # `ModelConfig` class.
        sub_model_config: ModelConfig[MistralConfig] = dataclasses.replace(
            model_config,
            pretrained_config=getattr(model_config.pretrained_config, name),
        )
        # Make sure some fields that are not explicitly included in the sub config, but present
        # in the top-level config, are replicated.
        if (hasattr(sub_model_config.pretrained_config, "torch_dtype")
                and sub_model_config.pretrained_config.torch_dtype is None):
            sub_model_config.pretrained_config.torch_dtype = model_config.pretrained_config.torch_dtype

        return sub_model_config

    # Original implementation:
    # https://github.com/huggingface/transformers/blob/v4.51.3/src/transformers/models/mistral3/
    # modeling_mistral3.py#L341
    def _get_image_features(
        self,
        pixel_values: torch.Tensor,
        image_sizes: torch.Tensor,
    ):
        image_outputs = self._vision_tower(
            pixel_values=pixel_values,
            image_sizes=image_sizes,
        )

        image_features = self._multi_modal_projector(image_outputs.squeeze(0), image_sizes)
        return image_features
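As a quick aside, the `_get_sub_model_config` helper above leans entirely on `dataclasses.replace`; the following minimal standalone sketch (toy stand-in dataclasses, not the real `ModelConfig` or HF configs, and not part of this diff) shows how only `pretrained_config` is swapped while every other field is carried over, including the `torch_dtype` back-fill:

import dataclasses


@dataclasses.dataclass
class _ToyTextConfig:  # stand-in for the nested HF text config
    hidden_size: int = 5120
    torch_dtype: object = None


@dataclasses.dataclass
class _ToyTopLevelConfig:  # stand-in for the top-level Mistral3Config
    torch_dtype: str = "bfloat16"
    text_config: _ToyTextConfig = dataclasses.field(default_factory=_ToyTextConfig)


@dataclasses.dataclass
class _ToyModelConfig:  # stand-in for ModelConfig
    pretrained_config: object
    attn_backend: str = "TRTLLM"


top = _ToyModelConfig(pretrained_config=_ToyTopLevelConfig())
sub = dataclasses.replace(top, pretrained_config=getattr(top.pretrained_config, "text_config"))
assert sub.attn_backend == top.attn_backend  # every field except `pretrained_config` is shared
if getattr(sub.pretrained_config, "torch_dtype", None) is None:
    sub.pretrained_config.torch_dtype = top.pretrained_config.torch_dtype  # back-fill, as above
assert sub.pretrained_config.torch_dtype == "bfloat16"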
# Original implementation:
# https://github.com/huggingface/transformers/blob/v4.51.3/src/transformers/models/mistral3/modeling_mistral3.py#L66
# NOTE: the main difference is the usage of TRTLLM's own `Linear` layer over pytorch's built-in layer.
class Mistral3PatchMerger(torch.nn.Module):

    def __init__(self, model_config: ModelConfig[Mistral3Config]):
        super().__init__()
        config = model_config.pretrained_config
        # Both the below are needed in order to use `_load_weights_impl`.
        self.model_config = model_config
        self.config = config

        hidden_size = config.vision_config.hidden_size
        self._spatial_merge_size = config.spatial_merge_size
        self._patch_size = config.vision_config.patch_size
        self.merging_layer = Linear(
            in_features=hidden_size * self._spatial_merge_size**2,
            out_features=hidden_size,
            bias=False,
            dtype=config.torch_dtype,
        )

    @torch.inference_mode()
    def forward(self, image_features: torch.Tensor,
                image_sizes: torch.Tensor) -> torch.Tensor:
        image_sizes = [(image_size[0] // self._patch_size,
                        image_size[1] // self._patch_size)
                       for image_size in image_sizes]

        tokens_per_image = [h * w for h, w in image_sizes]
        d = image_features.shape[-1]

        permuted_tensor = []
        for image_index, image_tokens in enumerate(image_features.split(tokens_per_image)):
            # Reshape image_tokens into a 2D grid
            h, w = image_sizes[image_index]
            image_grid = image_tokens.view(h, w, d).permute(2, 0, 1).unsqueeze(0)
            grid = torch.nn.functional.unfold(
                image_grid,
                kernel_size=self._spatial_merge_size,
                stride=self._spatial_merge_size)
            grid = grid.view(d * self._spatial_merge_size**2, -1).t()
            permuted_tensor.append(grid)

        image_features = torch.cat(permuted_tensor, dim=0)
        image_features = self.merging_layer(image_features)
        return image_features

    def load_weights(self, weights):
        _load_weights_impl(self, weights)
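To see what the `unfold` above is doing, here is a self-contained sketch with made-up sizes (a plain `torch.nn.Linear` stands in for TRTLLM's `Linear`; this is illustrative only, not part of the diff): with `spatial_merge_size=2`, every 2x2 block of patch tokens is flattened into one row of width `4*d` before the merging projection.

import torch

d, h, w, merge = 16, 6, 4, 2  # hidden size and patch-grid shape (made up)
image_tokens = torch.randn(h * w, d)  # tokens of a single image, row-major over the grid

image_grid = image_tokens.view(h, w, d).permute(2, 0, 1).unsqueeze(0)  # [1, d, h, w]
grid = torch.nn.functional.unfold(image_grid, kernel_size=merge, stride=merge)
grid = grid.view(d * merge**2, -1).t()  # one row per 2x2 block of patches
assert grid.shape == (h * w // merge**2, d * merge**2)

merging_layer = torch.nn.Linear(d * merge**2, d, bias=False)  # stand-in for the TRTLLM Linear
assert merging_layer(grid).shape == (h * w // merge**2, d)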
# Original implementation:
# https://github.com/huggingface/transformers/blob/v4.51.3/src/transformers/models/mistral3/
# modeling_mistral3.py#L104C1-L127C29
class Mistral3MultiModalProjector(torch.nn.Module):

    def __init__(self, model_config: ModelConfig[Mistral3Config]):
        super().__init__()
        config = model_config.pretrained_config
        # Both the below are needed in order to use `_load_weights_impl`.
        self.model_config = model_config
        self.config = config

        dtype = config.torch_dtype
        self.norm = RMSNorm(
            hidden_size=config.vision_config.hidden_size,
            # NOTE: the original implementation actually does not look at the config for this value.
            # We therefore hardcode the default value `1e-6` from `Mistral3RMSNorm`.
            eps=1e-6,
            dtype=dtype,
        )
        self.patch_merger = Mistral3PatchMerger(model_config)
        # We have hidden_size * the number of vision feature layers.
        num_feature_layers = 1 if isinstance(config.vision_feature_layer, int) else len(config.vision_feature_layer)
        self.linear_1 = Linear(
            in_features=config.vision_config.hidden_size * num_feature_layers,
            out_features=config.text_config.hidden_size,
            bias=config.multimodal_projector_bias,
            dtype=dtype,
        )
        self.act = ACT2FN[config.projector_hidden_act]
        self.linear_2 = Linear(
            in_features=config.text_config.hidden_size,
            out_features=config.text_config.hidden_size,
            bias=config.multimodal_projector_bias,
            dtype=dtype,
        )

    @torch.inference_mode()
    def forward(self, image_features: torch.Tensor, image_sizes: torch.Tensor):
        image_features = self.norm(image_features)
        image_features = self.patch_merger(image_features, image_sizes)
        hidden_states = self.linear_1(image_features)
        hidden_states = self.act(hidden_states)
        hidden_states = self.linear_2(hidden_states)
        return hidden_states

    def load_weights(self, weights):
        _load_weights_impl(self, weights)


def _filter_weights(weights: Dict[str, torch.Tensor],
                    prefix: str) -> Dict[str, torch.Tensor]:
    return {
        name[len(prefix):]: weight
        for name, weight in weights.items() if name.startswith(prefix)
    }
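For reference, a tiny sketch of the prefix convention that `Mistral3VLM.load_weights` relies on, using the `_filter_weights` helper defined just above (toy tensors; the key names merely mimic the HF checkpoint layout and are not taken from the diff):

import torch

state_dict = {
    "language_model.model.embed_tokens.weight": torch.zeros(4, 2),
    "vision_tower.patch_conv.weight": torch.zeros(3, 3),
    "multi_modal_projector.linear_1.weight": torch.zeros(2, 2),
}
llm_weights = _filter_weights(state_dict, "language_model.")
assert list(llm_weights) == ["model.embed_tokens.weight"]  # prefix stripped, other entries dropped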
tensorrt_llm/_torch/models/modeling_pixtral.py (new file, 311 lines)
@@ -0,0 +1,311 @@
from typing import List

import torch
import transformers

from tensorrt_llm._torch import model_config as model_config_lib
from tensorrt_llm._torch.attention_backend import interface as attention_interface
from tensorrt_llm._torch.attention_backend import utils as attention_utils
from tensorrt_llm._torch.models import modeling_utils
from tensorrt_llm._torch.modules import attention as trtllm_attention
from tensorrt_llm._torch.modules import gated_mlp as trtllm_gated_mlp
from tensorrt_llm._torch.modules import rms_norm as trtllm_rmsnorm


class PixtralAttention(trtllm_attention.Attention):
    def __init__(
        self,
        model_config: model_config_lib.ModelConfig[transformers.PixtralVisionConfig],
        layer_idx: int,
    ):
        config = model_config.pretrained_config
        pos_embd_params = None
        max_position_embeddings = None

        super().__init__(
            hidden_size=config.hidden_size,
            num_attention_heads=config.num_attention_heads,
            num_key_value_heads=config.num_attention_heads,
            max_position_embeddings=max_position_embeddings,
            bias=False,
            pos_embd_params=pos_embd_params,
            layer_idx=layer_idx,
            dtype=getattr(config, "torch_dtype", torch.float32),
            config=model_config,
            # Pixtral first needs to compute positional embeddings using its own
            # `PixtralRotaryEmbedding`.
            rope_fusion=False,
        )


class PixtralAttentionLayer(torch.nn.Module):
    def __init__(
        self,
        config: model_config_lib.ModelConfig[transformers.PixtralVisionConfig],
        layer_idx: int,
    ):
        super().__init__()
        hidden_size = config.pretrained_config.hidden_size
        dtype = config.pretrained_config.torch_dtype
        self.attention_norm = trtllm_rmsnorm.RMSNorm(
            hidden_size=hidden_size,
            eps=1e-5,
            dtype=dtype,
        )
        pretrained_config = config.pretrained_config

        if pretrained_config.hidden_act != "silu":
            raise ValueError(
                "Only 'silu' is accepted as the activation function for the MLP in "
                f"{self.__class__.__name__}. Got: {pretrained_config.hidden_act}."
            )
        self.feed_forward = trtllm_gated_mlp.GatedMLP(
            hidden_size=pretrained_config.hidden_size,
            intermediate_size=pretrained_config.intermediate_size,
            bias=False,
            activation=torch.nn.functional.silu,
            dtype=pretrained_config.torch_dtype,
            config=config,
        )
        self.attention = PixtralAttention(config, layer_idx)
        self.ffn_norm = trtllm_rmsnorm.RMSNorm(
            hidden_size=hidden_size,
            eps=1e-5,
            dtype=dtype,
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        attn_metadata: attention_interface.AttentionMetadata,
        position_ids: torch.Tensor,
    ):
        residual = hidden_states
        hidden_states = self.attention_norm(hidden_states)
        hidden_states = self.attention(
            # NOTE: although we do not need the `position_ids` to compute RoPE (since it has already
            # been pre-computed), internally, the cos / sin vectors will not be applied to the
            # query / key tensors if `position_ids=None` in `RotaryEmbedding`.
            position_ids=position_ids,
            hidden_states=hidden_states,
            attn_metadata=attn_metadata,
            attention_mask=attention_interface.PredefinedAttentionMask.FULL,
        )

        hidden_states = residual + hidden_states
        residual = hidden_states
        hidden_states = self.ffn_norm(hidden_states)
        hidden_states = self.feed_forward(hidden_states)
        hidden_states = residual + hidden_states

        return hidden_states


# Original implementation:
# https://github.com/huggingface/transformers/blob/v4.51.3/src/transformers/models/pixtral/modeling_pixtral.py#L279
class PixtralTransformer(torch.nn.Module):
    def __init__(self, config: model_config_lib.ModelConfig[transformers.PixtralVisionConfig]):
        super().__init__()
        self.layers = torch.nn.ModuleList()
        for i in range(config.pretrained_config.num_hidden_layers):
            self.layers.append(PixtralAttentionLayer(config=config, layer_idx=i))
        self._head_dim = config.pretrained_config.head_dim
        self._num_heads = config.pretrained_config.num_attention_heads

    def forward(
        self,
        inputs_embeds: torch.Tensor,
        position_embeddings: torch.Tensor,
        position_ids: torch.Tensor,
        attn_metadata: attention_interface.AttentionMetadata,
    ):
        if inputs_embeds.ndim == 3:
            batch_size, patches, _ = inputs_embeds.shape
        elif inputs_embeds.ndim == 2:
            batch_size = 1
            patches = inputs_embeds.size(0)
        rope_function = _RopeFunction(
            batch_size=batch_size,
            patches=patches,
            num_heads=self._num_heads,
            head_dim=self._head_dim,
            position_embeddings=position_embeddings,
        )
        hidden_states = inputs_embeds
        for encoder_layer in self.layers:
            # The way pixtral applies rope is by:
            # 1. Computing the `position_ids` from the `patch_embeds` (which are essentially a
            #    sliced + concat'ed output of the conv2d layer), using their positions in a
            #    meshgrid.
            # 2. Computing `position_embeddings` once for the entire transformer portion.
            #    See: https://github.com/huggingface/transformers/blob/v4.51.3/src/transformers/
            #    models/pixtral/modeling_pixtral.py#L494
            # 3. These `position_embeddings` are then the ones used to apply rope in each attention
            #    layer.
            # By substituting `encoder_layer.attention.rotary_emb` with `_RopeFunction`, which
            # has these `position_embeddings` as an attribute, we can reuse the embeddings + application
            # logic for each encoder layer.
            encoder_layer.attention.rotary_emb = rope_function
            layer_outputs = encoder_layer(
                hidden_states=hidden_states,
                attn_metadata=attn_metadata,
                position_ids=position_ids,
            )

            hidden_states = layer_outputs

        return hidden_states


# Original implementation:
# https://github.com/huggingface/transformers/blob/v4.51.3/src/transformers/models/pixtral/modeling_pixtral.py#L440
@modeling_utils.register_auto_model("PixtralVisionModel")
class PixtralVisionModel(torch.nn.Module):
    def __init__(
        self, model_config: model_config_lib.ModelConfig[transformers.PixtralVisionConfig]
    ):
        super().__init__()
        tp_size = model_config.mapping.tp_size
        # TODO: implement support for `tp_size > 1`.
        if tp_size > 1:
            raise NotImplementedError(
                f"Mistral3VLM does not support `mapping.tp_size > 1` yet (got {tp_size})."
            )
        # Both the below are needed in order to use `_load_weights_impl`.
        self.model_config = model_config
        self.config: transformers.PixtralVisionConfig = model_config.pretrained_config
        self.patch_conv = torch.nn.Conv2d(
            in_channels=self.config.num_channels,
            out_channels=self.config.hidden_size,
            kernel_size=self.config.patch_size,
            stride=self.config.patch_size,
            bias=False,
        )
        self._patch_size = self.config.patch_size
        self.ln_pre = trtllm_rmsnorm.RMSNorm(
            hidden_size=self.config.hidden_size,
            eps=1e-5,
            dtype=self.config.torch_dtype,
        )
        self.transformer = PixtralTransformer(model_config)
        self._patch_positional_embedding = (
            transformers.models.pixtral.modeling_pixtral.PixtralRotaryEmbedding(self.config)
        )

        self._metadata_cls = attention_utils.get_attention_backend(
            model_config.attn_backend
        ).Metadata

    @torch.inference_mode()
    def forward(
        self,
        pixel_values: torch.Tensor,
        image_sizes: torch.Tensor,
    ):
        with torch.autocast(device_type="cuda", dtype=self.config.torch_dtype):
            patch_embeds = self.patch_conv(pixel_values)
            patch_embeds_list = [
                embed[..., : (size[0] // self._patch_size), : (size[1] // self._patch_size)]
                for embed, size in zip(patch_embeds, image_sizes)
            ]

            patch_embeds = torch.cat([p.flatten(1).T for p in patch_embeds_list], dim=0)
            patch_embeds = self.ln_pre(patch_embeds)

            position_ids = transformers.models.pixtral.modeling_pixtral.position_ids_in_meshgrid(
                patch_embeds_list, max_width=self.config.image_size // self.config.patch_size
            )
            position_embeddings = self._patch_positional_embedding(patch_embeds, position_ids)

            attn_metadata = self._prepare_attn_metadata(
                # The `torch.cat` that creates the `patch_embeds` flattens the conv features from multiple
                # images into a single sequence - hence why we hardcode the batch size to 1 here.
                batch_size=1,
                seq_len=position_ids.size(0),
            )
            out = self.transformer(
                patch_embeds,
                position_embeddings=position_embeddings,
                position_ids=position_ids,
                attn_metadata=attn_metadata,
            )

        return out

    def load_weights(self, weights):
        modeling_utils._load_weights_impl(self, weights)

    def _prepare_attn_metadata(self, batch_size: int, seq_len: int):
        request_ids = list(range(1, batch_size + 1))
        prompt_lens = [seq_len] * batch_size
        attn_metadata = self._metadata_cls(
            seq_lens=torch.tensor([seq_len] * batch_size, dtype=torch.int),
            num_contexts=batch_size,
            max_num_requests=batch_size,
            max_num_tokens=seq_len * batch_size,
            kv_cache_manager=None,
            request_ids=request_ids,
            prompt_lens=prompt_lens,
        )
        attn_metadata.max_seq_len = seq_len * batch_size
        attn_metadata.prepare()
        return attn_metadata
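As a standalone illustration of the crop-and-flatten step in `forward` above (made-up image sizes, patch size 14 as in Pixtral, and a tiny random conv standing in for the trained `patch_conv`; not part of the diff), the `seq_len` handed to `_prepare_attn_metadata` is simply the total number of kept patches across all images:

import torch

patch_size = 14
image_sizes = [(123, 456), (224, 224)]  # made-up (height, width) pairs
conv = torch.nn.Conv2d(3, 8, kernel_size=patch_size, stride=patch_size, bias=False)

# In the real pipeline the images arrive padded to a common size and are cropped back after the
# conv; here each image is run on its own, so the crop is a no-op.
patch_embeds_list = []
for h, w in image_sizes:
    embed = conv(torch.randn(1, 3, h, w))[0]
    patch_embeds_list.append(embed[..., : h // patch_size, : w // patch_size])

flat = torch.cat([p.flatten(1).T for p in patch_embeds_list], dim=0)
expected_seq_len = sum((h // patch_size) * (w // patch_size) for h, w in image_sizes)
assert flat.shape[0] == expected_seq_len  # 8*32 + 16*16 = 512 patches in one "sequence"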
class _RopeFunction:
    def __init__(
        self,
        batch_size: int,
        patches: int,
        num_heads: int,
        head_dim: int,
        position_embeddings: torch.Tensor,
    ):
        self._batch_size = batch_size
        self._patches = patches
        self._num_heads = num_heads
        self._head_dim = head_dim
        self._cos, self._sin = position_embeddings

    # This signature matches that of
    # `tensorrt_llm/_torch/modules/rotary_embedding.py::RotaryEmbedding.forward` so that we are
    # able to override the `PixtralAttention.rotary_emb` attribute.
    @torch.no_grad()
    def __call__(
        self,
        # Unused.
        position_ids: torch.Tensor,
        # Assumed to be in the order `[q, k]`.
        targets: List[torch.Tensor],
    ) -> List[torch.Tensor]:
        if len(targets) != 2:
            raise ValueError("Expected exactly two targets [q, k].")

        # TODO: see if we can reuse `RotaryEmbedding.apply_rotary_pos_emb`.
        orig_shape = targets[0].shape
        q_embed, k_embed = transformers.models.pixtral.modeling_pixtral.apply_rotary_pos_emb(
            q=targets[0]
            .view(
                self._batch_size,
                self._patches,
                self._num_heads,
                self._head_dim,
            )
            .transpose(1, 2),
            k=targets[1]
            .view(
                self._batch_size,
                self._patches,
                self._num_heads,
                self._head_dim,
            )
            .transpose(1, 2),
            cos=self._cos,
            sin=self._sin,
            unsqueeze_dim=0,
        )

        q_embed = q_embed.transpose(2, 1).reshape(orig_shape)
        k_embed = k_embed.transpose(2, 1).reshape(orig_shape)

        return [q_embed, k_embed]
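A small standalone check (made-up sizes, no rope actually applied; illustrative only) of the layout round-trip that `__call__` above performs around the HF helper: the flattened [tokens, num_heads * head_dim] tensors are viewed as [batch, heads, tokens, head_dim] and then restored losslessly afterwards.

import torch

batch, patches, heads, head_dim = 1, 6, 4, 8  # made-up sizes
q = torch.randn(patches, heads * head_dim)

orig_shape = q.shape
q_bhpd = q.view(batch, patches, heads, head_dim).transpose(1, 2)  # layout the HF helper expects
# ... cos / sin would be applied here ...
q_back = q_bhpd.transpose(2, 1).reshape(orig_shape)
assert torch.equal(q_back, q)  # the view / transpose round-trip is lossless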
@@ -184,13 +184,15 @@ SUPPORTED_GEMMA_MODEL_GROUP = ["gemma3"]
SUPPORTED_LLAMA_MODEL_GROUP = ["mllama", "llama4"]
SUPPORTED_LLAVA_IMAGE_MODEL_GROUP = ["llava_llama", "llava_next"]
SUPPORTED_LLAVA_VIDEO_MODEL_GROUP = ["llava_llama"]
SUPPORTED_MISTRAL_IMAGE_MODEL_GROUP = ["mistral3"]
SUPPORTED_HYPERCLOVAX_MODEL_GROUP = ["hyperclovax_vlm"]

ALL_SUPPORTED_IMAGE_MODELS = SUPPORTED_QWEN_MODEL_GROUP \
    + SUPPORTED_LLAMA_MODEL_GROUP \
    + SUPPORTED_LLAVA_IMAGE_MODEL_GROUP \
    + SUPPORTED_HYPERCLOVAX_MODEL_GROUP \
    + SUPPORTED_GEMMA_MODEL_GROUP
    + SUPPORTED_GEMMA_MODEL_GROUP \
    + SUPPORTED_MISTRAL_IMAGE_MODEL_GROUP

ALL_SUPPORTED_VIDEO_MODELS = SUPPORTED_QWEN_MODEL_GROUP \
    + SUPPORTED_LLAVA_VIDEO_MODEL_GROUP
@@ -217,6 +219,10 @@ PLACEHOLDER_PLACEMENT_MAP = {
    "mllama": MultimodalPlaceholderPlacement.BEFORE_TEXT,
    "hyperclovax_vlm": MultimodalPlaceholderPlacement.AFTER_TEXT,
    "gemma3": MultimodalPlaceholderPlacement.BEFORE_TEXT,
    # NOTE: for mistral3 multimodal models, it does not strictly have to be after the text.
    # Ref: https://github.com/mistralai/mistral-common/blob/039465db2bdc0486df36365c9bdb428188482a18/
    # src/mistral_common/tokens/tokenizers/base.py#L326
    "mistral3": MultimodalPlaceholderPlacement.AFTER_TEXT,
}
assert len(PLACEHOLDER_PLACEMENT_MAP) == len(ALL_SUPPORTED_MULTIMODAL_MODELS)

@@ -247,6 +253,10 @@ def retrieve_multimodal_placeholder(model_type: str, modality: str,
            '<|im_start|>user (vector)\n<|dummy3|><|im_end|>\n' + \
            '<|im_start|>image/aux\n다음 중 ocr은 사진에서 검출된 글자이고, lens_keyword는 사진에서 추출된 keyword와 bbox 위치입니다.' + \
            'bbox는 0~1 사이로 정규화된 [x1, y1, x2, y2]의 형태입니다. 참고하여 답변하세요. {"ocr": "", "lens_keywords": "", "lens_local_keywords": ""}'
    elif model_type in SUPPORTED_MISTRAL_IMAGE_MODEL_GROUP:
        # Ref: https://github.com/mistralai/mistral-common/blob/26a6bb3a07ee0b78a3808f2797f23e1d28514b93/
        # src/mistral_common/tokens/tokenizers/base.py#L60
        return "[IMG]"
    raise TypeError(
        f"For image modality, only {ALL_SUPPORTED_IMAGE_MODELS} are supported but got {model_type}"
    )
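As a rough illustration (simplified string handling, not the real prompt builder in `tensorrt_llm/inputs/utils.py`), `AFTER_TEXT` placement combined with the `[IMG]` placeholder above amounts to appending one placeholder per image after the user text:

def build_prompt(text: str, num_images: int, placeholder: str = "[IMG]") -> str:
    # AFTER_TEXT: one placeholder per image, appended after the text.
    return text + placeholder * num_images

assert build_prompt("Describe the image.", 2) == "Describe the image.[IMG][IMG]"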
@@ -1963,9 +1963,12 @@ def test_ptp_quickstart_advanced_mixed_precision(llm_root, llm_venv):
    ("llava-v1.6-mistral-7b", "llava-v1.6-mistral-7b-hf"),
    ("qwen2-vl-7b-instruct", "Qwen2-VL-7B-Instruct"),
    ("qwen2.5-vl-7b-instruct", "Qwen2.5-VL-7B-Instruct"),
    ("mistral-small-3.1-24b-instruct", "Mistral-Small-3.1-24B-Instruct-2503"),
])
def test_ptp_quickstart_multimodal(llm_root, llm_venv, model_name, model_path,
                                   modality, use_cuda_graph):
    # NOTE: individual tests need to be enabled in
    # tests/integration/test_lists/qa/examples_test_list.txt
    llm_venv.run_cmd(
        ['-m', 'pip', 'install', 'flash-attn==2.7.3', '--no-build-isolation'])

@@ -2051,6 +2054,16 @@ def test_ptp_quickstart_multimodal(llm_root, llm_venv, model_name, model_path,
            ["earth", "rotating", "night", "lights", "cities"],
        ],
    },
    "mistral-small-3.1-24b-instruct": {
        "image": [
            [
                "dramatic", "seascape", "stormy", "turbulent", "waves",
                "rough"
            ],
            ["scenic", "rock", "landscape", "snow", "formation"],
            ["highway", "traffic", "directions", "lanes", "Jurong"],
        ],
    },
}

cmd = [

@@ -523,6 +523,8 @@ test_e2e.py::test_ptp_quickstart_multimodal[qwen2.5-vl-7b-instruct-Qwen2.5-VL-7B
test_e2e.py::test_ptp_quickstart_multimodal[qwen2.5-vl-7b-instruct-Qwen2.5-VL-7B-Instruct-image-True]
test_e2e.py::test_ptp_quickstart_multimodal[qwen2.5-vl-7b-instruct-Qwen2.5-VL-7B-Instruct-video-False]
test_e2e.py::test_ptp_quickstart_multimodal[qwen2.5-vl-7b-instruct-Qwen2.5-VL-7B-Instruct-video-True]
test_e2e.py::test_ptp_quickstart_multimodal[mistral-small-3.1-24b-instruct-Mistral-Small-3.1-24B-Instruct-2503-image-True]
test_e2e.py::test_ptp_quickstart_multimodal[mistral-small-3.1-24b-instruct-Mistral-Small-3.1-24B-Instruct-2503-image-False]
test_e2e.py::test_ptp_quickstart_bert[VANILLA-BertForSequenceClassification-bert/bert-base-uncased-yelp-polarity]
test_e2e.py::test_ptp_quickstart_bert[TRTLLM-BertForSequenceClassification-bert/bert-base-uncased-yelp-polarity]
test_e2e.py::test_ptp_star_attention_example[Llama3.1-8B-BF16-llama-3.1-model/Meta-Llama-3.1-8B]
tests/unittest/_torch/modeling/test_modeling_mistral.py (new file, 415 lines)
@@ -0,0 +1,415 @@
import contextlib
import os
from typing import Any, Dict
from unittest import mock

import pytest
import torch
import transformers
from utils.util import getSMVersion

import tensorrt_llm
from tensorrt_llm import mapping as mapping_lib
from tensorrt_llm._torch import metadata as metadata_lib
from tensorrt_llm._torch import model_config as model_config_lib
from tensorrt_llm._torch.attention_backend import utils as attention_utils
from tensorrt_llm._torch.models import modeling_mistral
from tensorrt_llm._torch.pyexecutor import cuda_graph_runner, resource_manager
from tensorrt_llm.bindings import executor as executor_lib
from tensorrt_llm.models import modeling_utils


@pytest.fixture
def mistral_small_3_1_24b_config():
    return {
        "architectures": ["Mistral3ForConditionalGeneration"],
        "image_token_index": 10,
        "model_type": "mistral3",
        "multimodal_projector_bias": False,
        "projector_hidden_act": "gelu",
        "spatial_merge_size": 2,
        "text_config": {
            "attention_dropout": 0.0,
            "head_dim": 128,
            "hidden_act": "silu",
            "hidden_size": 5120,
            "initializer_range": 0.02,
            "intermediate_size": 32768,
            "max_position_embeddings": 131072,
            "model_type": "mistral",
            "num_attention_heads": 32,
            "num_hidden_layers": 40,
            "num_key_value_heads": 8,
            "rms_norm_eps": 1e-05,
            "rope_theta": 1000000000.0,
            "sliding_window": None,
            "use_cache": True,
            "vocab_size": 131072,
        },
        "torch_dtype": "bfloat16",
        "transformers_version": "4.50.0.dev0",
        "vision_config": {
            "attention_dropout": 0.0,
            "head_dim": 64,
            "hidden_act": "silu",
            "hidden_size": 1024,
            "image_size": 1540,
            "initializer_range": 0.02,
            "intermediate_size": 4096,
            "model_type": "pixtral",
            "num_attention_heads": 16,
            "num_channels": 3,
            "num_hidden_layers": 24,
            "patch_size": 14,
            "rope_theta": 10000.0,
        },
        "vision_feature_layer": -1,
    }


def reduce_mistral_config(
    mem_for_full_model: int, config_dict: Dict[str, Any], default_num_layers: int = 32
):
    _, total_mem = torch.cuda.mem_get_info()
    if "text_config" in config_dict:
        config_dict = config_dict["text_config"]
    # scale model down if gpu memory is low
    if total_mem < mem_for_full_model:
        model_fraction = total_mem / mem_for_full_model
        num_layers = int(config_dict["num_hidden_layers"] * model_fraction)
        num_layers = min(num_layers, default_num_layers)
        config_dict["num_hidden_layers"] = num_layers
def init_hf_model(cls, config, dtype, device):
    """Helper function for initializing a model from `transformers`.

    The reason this function exists is: by default, instantiating a `transformers` model also
    eagerly initializes the model's weights on the CPU, which takes an absurdly long time to
    complete.

    Instead, we lazily instantiate the model, and initialize the weights only after moving it to
    the requested `device`.
    """
    from transformers import modeling_utils as t_modeling_utils

    with t_modeling_utils.no_init_weights():
        model = cls(config).eval()

    model.to(device=device)
    model.init_weights()
    model.to(dtype=dtype)

    return model


@contextlib.contextmanager
def kv_cache_manager_context(kv_cache_manager):
    try:
        yield
    finally:
        kv_cache_manager.shutdown()


def test_mistral_3_vlm_rejects_disagg(mistral_small_3_1_24b_config):
    with (
        mock.patch.dict(os.environ, {"TLLM_MULTIMODAL_DISAGGREGATED": "1"}),
        pytest.raises(NotImplementedError, match="disaggregated inference"),
    ):
        modeling_mistral.Mistral3VLM(
            model_config=model_config_lib.ModelConfig(
                pretrained_config=transformers.Mistral3Config.from_dict(
                    mistral_small_3_1_24b_config
                )
            ),
        )


@pytest.mark.parametrize("quant_algo", [None, "FP8"])
def test_mistral_3_vlm_sanity(mistral_small_3_1_24b_config, quant_algo):
    if quant_algo == "FP8" and getSMVersion() < 89:
        pytest.skip("This test is not supported in pre-Ada architecture")

    config_dict = mistral_small_3_1_24b_config
    # 24B * sizeof(float16) plus some extra for activations
    mem_for_full_model = int(2.1 * 24 * 2 ** (30))
    reduce_mistral_config(mem_for_full_model, config_dict)

    if config_dict["text_config"]["num_hidden_layers"] <= 0:
        pytest.skip("Insufficient memory for a single Mistral layer")

    mistral_3_config = transformers.Mistral3Config.from_dict(config_dict)
    if quant_algo:
        quant_config = modeling_utils.QuantConfig(quant_algo=quant_algo)
    else:
        quant_config = None

    dtype = mistral_3_config.torch_dtype
    device = torch.device("cuda")

    model_config = model_config_lib.ModelConfig(
        pretrained_config=mistral_3_config,
        quant_config=quant_config,
    )
    mistral = modeling_mistral.Mistral3VLM(model_config).to(device)

    input_ids = torch.tensor(
        [100, 200, 300, 100, 200, 100, 400, 500], dtype=torch.int, device=device
    )

    context_sequence_lengths = [3, 2, 1]
    sequence_lengths = context_sequence_lengths + [1, 1]
    past_seen_tokens = [0, 0, 0, 62, 75]
    request_ids = list(range(len(sequence_lengths)))
    token_nums = (torch.tensor(past_seen_tokens) + torch.tensor(sequence_lengths)).tolist()
    prompt_lens = token_nums[:3] + past_seen_tokens[3:]

    num_blocks = 100
    tokens_per_block = 128
    head_dim = mistral.config.head_dim
    num_layers = mistral.config.num_hidden_layers
    num_kv_heads = mistral.config.num_key_value_heads
    max_seq_len = num_blocks * tokens_per_block
    batch_size = len(context_sequence_lengths) + 2

    if dtype == torch.half:
        kv_cache_dtype = tensorrt_llm.bindings.DataType.HALF
    elif dtype == torch.bfloat16:
        kv_cache_dtype = tensorrt_llm.bindings.DataType.BF16
    else:
        raise ValueError("Invalid dtype")

    mapping = mapping_lib.Mapping(world_size=1, tp_size=1, rank=0)
    kv_cache_config = executor_lib.KvCacheConfig(max_tokens=num_blocks * tokens_per_block)
    kv_cache_manager = resource_manager.KVCacheManager(
        kv_cache_config,
        tensorrt_llm.bindings.internal.batch_manager.CacheType.SELF,
        num_layers=num_layers,
        num_kv_heads=num_kv_heads,
        head_dim=head_dim,
        tokens_per_block=tokens_per_block,
        max_seq_len=max_seq_len,
        max_batch_size=batch_size,
        mapping=mapping,
        dtype=kv_cache_dtype,
    )
    with kv_cache_manager_context(kv_cache_manager):
        kv_cache_manager.add_dummy_requests(request_ids, token_nums)

        metadata_cls = attention_utils.get_attention_backend(model_config.attn_backend).Metadata
        attn_metadata = metadata_cls(
            seq_lens=torch.tensor(sequence_lengths, dtype=torch.int),
            num_contexts=len(context_sequence_lengths),
            kv_cache_params=metadata_lib.KVCacheParams(
                use_cache=True,
                num_cached_tokens_per_seq=past_seen_tokens,
            ),
            kv_cache_manager=kv_cache_manager,
            request_ids=request_ids,
            prompt_lens=prompt_lens,
            max_num_requests=len(context_sequence_lengths) + 2,
            max_num_tokens=8192,
        )

        position_ids = []
        for i, tokens in enumerate(past_seen_tokens):
            seq_len = context_sequence_lengths[i] if i < len(context_sequence_lengths) else 1
            position_id = torch.arange(tokens, tokens + seq_len, device=input_ids.device)
            position_ids.append(position_id)

        position_ids = torch.cat(position_ids).unsqueeze(0)

        with torch.inference_mode():
            attn_metadata.prepare()
            logits = mistral.forward(
                input_ids=input_ids, position_ids=position_ids, attn_metadata=attn_metadata
            )

        assert len(past_seen_tokens) == logits.shape[0]

        with torch.inference_mode():
            attn_metadata.prepare()
            logits = mistral.forward(
                input_ids=input_ids,
                position_ids=position_ids,
                attn_metadata=attn_metadata,
                return_context_logits=True,
            )
        assert input_ids.shape == logits.shape[:-1]


@pytest.mark.parametrize(
    "backend, use_cuda_graph",
    [
        ("VANILLA", False),
        ("FLASHINFER", False),
        ("FLASHINFER", True),
        ("TRTLLM", False),
        ("TRTLLM", True),
    ],
)
@torch.no_grad()
def test_mistral_3_vlm_allclose_to_hf(mistral_small_3_1_24b_config, backend, use_cuda_graph):
    metadata_cls = attention_utils.get_attention_backend(backend).Metadata

    torch.random.manual_seed(0)
    config_dict = mistral_small_3_1_24b_config
    # 24B * sizeof(float16) plus some extra for activations
    # times 2, since we'll need 2 of these
    mem_for_full_model = int(2.1 * 24 * 2 ** (30) * 2)
    reduce_mistral_config(mem_for_full_model, config_dict)
    if config_dict["text_config"]["num_hidden_layers"] <= 0:
        pytest.skip("Insufficient memory for a single Mistral layer")
    mistral_config = transformers.Mistral3Config.from_dict(config_dict)
    dtype = mistral_config.torch_dtype
    device = torch.device("cuda")

    hf_mistral = init_hf_model(
        cls=transformers.Mistral3ForConditionalGeneration,
        config=mistral_config,
        dtype=dtype,
        device=device,
    )

    model_config = model_config_lib.ModelConfig(
        pretrained_config=mistral_config,
        attn_backend=backend,
    )
    mistral = modeling_mistral.Mistral3VLM(model_config).to(dtype).to(device)
    mistral.load_weights(hf_mistral.state_dict())

    num_blocks = 1
    tokens_per_block = 128
    head_dim = mistral.config.head_dim
    num_layers = mistral.config.num_hidden_layers
    num_kv_heads = mistral.config.num_key_value_heads
    max_seq_len = num_blocks * tokens_per_block
    batch_size = 1

    if dtype == torch.half:
        kv_cache_dtype = tensorrt_llm.bindings.DataType.HALF
    elif dtype == torch.bfloat16:
        kv_cache_dtype = tensorrt_llm.bindings.DataType.BF16
    else:
        raise ValueError("Invalid dtype")

    mapping = mapping_lib.Mapping(world_size=1, tp_size=1, rank=0)
    kv_cache_config = executor_lib.KvCacheConfig(max_tokens=num_blocks * tokens_per_block)
    kv_cache_manager = resource_manager.KVCacheManager(
        kv_cache_config,
        tensorrt_llm.bindings.internal.batch_manager.CacheType.SELF,
        num_layers=num_layers,
        num_kv_heads=num_kv_heads,
        head_dim=head_dim,
        tokens_per_block=tokens_per_block,
        max_seq_len=max_seq_len,
        max_batch_size=batch_size,
        mapping=mapping,
        dtype=kv_cache_dtype,
    )

    with kv_cache_manager_context(kv_cache_manager):
        # context
        input_ids = torch.tensor(
            [100, 200, 300, 100, 200, 100, 400, 500], dtype=torch.int, device=device
        )

        num_cached_tokens_per_seq = [0]
        request_ids = [1]
        token_nums = [input_ids.size(-1)]
        prompt_lens = [input_ids.size(-1)]
        kv_cache_manager.add_dummy_requests(request_ids, token_nums)

        attn_metadata = metadata_cls(
            seq_lens=torch.tensor([input_ids.size(-1)], dtype=torch.int),
            num_contexts=1,
            kv_cache_params=metadata_lib.KVCacheParams(
                use_cache=True,
                num_cached_tokens_per_seq=num_cached_tokens_per_seq,
            ),
            max_num_requests=1,
            max_num_tokens=8192,
            kv_cache_manager=kv_cache_manager,
            request_ids=request_ids,
            prompt_lens=prompt_lens,
        )

        # Note: no CUDA graphs for prefill, the graph runner is built for
        # decoding only.
        position_ids = [torch.arange(0, input_ids.size(-1))]
        position_ids = torch.cat(position_ids).unsqueeze(0).cuda()
        with torch.inference_mode():
            attn_metadata.prepare()
            logits = mistral.forward(
                input_ids=input_ids, position_ids=position_ids, attn_metadata=attn_metadata
            )
            ref = hf_mistral.forward(
                input_ids=input_ids.unsqueeze(0), position_ids=position_ids, use_cache=True
            )

        torch.testing.assert_close(logits, ref.logits[:, -1].float(), atol=0.4, rtol=0.4)

        # gen
        gen_input_ids = torch.tensor([600], dtype=torch.int, device=device)

        num_cached_tokens_per_seq = [input_ids.size(-1)]

        attn_metadata = metadata_cls(
            seq_lens=torch.tensor([gen_input_ids.size(-1)], dtype=torch.int),
            num_contexts=0,
            kv_cache_params=metadata_lib.KVCacheParams(
                use_cache=True,
                num_cached_tokens_per_seq=num_cached_tokens_per_seq,
            ),
            max_num_requests=1,
            max_num_tokens=8192,
            kv_cache_manager=kv_cache_manager,
            request_ids=request_ids,
            prompt_lens=prompt_lens,
        )

        gen_position_ids = [
            torch.arange(input_ids.size(-1), input_ids.size(-1) + gen_input_ids.size(-1))
        ]
        gen_position_ids = torch.cat(gen_position_ids).unsqueeze(0).cuda()

        def run_forward(input_ids, position_ids, attn_metadata):
            attn_metadata.prepare()
            if not use_cuda_graph:
                return mistral.forward(
                    input_ids=input_ids, position_ids=position_ids, attn_metadata=attn_metadata
                )
            else:
                graph_runner = cuda_graph_runner.DecodingCUDAGraphRunner(
                    attn_metadata.max_num_requests, "cuda", attn_metadata
                )
                graph_runner.capture(lambda inputs: mistral.forward(**inputs))

                for _ in range(2):
                    # Run it twice. This helps us catch problems if buffers are accidentally
                    # reallocated in prepare().
                    attn_metadata.prepare()
                    logits = graph_runner.run(
                        {
                            "input_ids": input_ids,
                            "position_ids": position_ids,
                            "attn_metadata": attn_metadata,
                        }
                    )
                return logits

        if use_cuda_graph:
            attn_metadata = attn_metadata.create_cuda_graph_metadata(1)

        with torch.inference_mode():
            logits = run_forward(
                input_ids=gen_input_ids, position_ids=gen_position_ids, attn_metadata=attn_metadata
            )
            ref = hf_mistral.forward(
                input_ids=gen_input_ids.unsqueeze(0),
                position_ids=gen_position_ids,
                past_key_values=ref.past_key_values,
                use_cache=True,
            )

        torch.testing.assert_close(logits, ref.logits[:, -1].float(), atol=0.4, rtol=0.4)
tests/unittest/_torch/modeling/test_modeling_pixtral.py (new file, 104 lines)
@@ -0,0 +1,104 @@
import pytest
import torch
import transformers
from transformers.models.pixtral import modeling_pixtral as hf_modeling_pixtral

from tensorrt_llm import mapping as mapping_lib
from tensorrt_llm._torch import model_config as model_config_lib
from tensorrt_llm._torch.models import modeling_pixtral


@pytest.fixture
def pixtral_vision_config():
    # Values taken from:
    # https://huggingface.co/mistralai/Mistral-Small-3.1-24B-Instruct-2503/blob/main/config.json
    return model_config_lib.ModelConfig(
        pretrained_config=transformers.PixtralVisionConfig(
            hidden_size=1024,
            num_attention_heads=16,
            torch_dtype=torch.bfloat16,
            hidden_act="silu",
        ),
    )


@pytest.fixture
def set_seed():
    torch.manual_seed(322)


def init_hf_model(cls, config, dtype, device):
    """Helper function for initializing a model from `transformers`.

    The reason this function exists is: by default, instantiating a `transformers` model also
    eagerly initializes the model's weights on the CPU, which takes an absurdly long time to
    complete.

    Instead, we lazily instantiate the model, and initialize the weights only after moving it to
    the requested `device`.
    """
    from transformers import modeling_utils as t_modeling_utils

    with t_modeling_utils.no_init_weights():
        model = cls(config).eval()

    model.to(device=device)
    model.init_weights()
    model.to(dtype=dtype)

    return model


@pytest.mark.parametrize(
    "mapping",
    [
        mapping_lib.Mapping(world_size=2, tp_size=2),
        mapping_lib.Mapping(world_size=3, tp_size=3),
        mapping_lib.Mapping(world_size=4, tp_size=2, pp_size=2),
        mapping_lib.Mapping(world_size=8, tp_size=2, pp_size=2, cp_size=2),
    ],
)
def test_pixtral_vision_model_rejects_tp_size_greater_than_one(pixtral_vision_config, mapping):
    pixtral_vision_config.mapping = mapping
    with pytest.raises(NotImplementedError, match="tp_size > 1"):
        modeling_pixtral.PixtralVisionModel(model_config=pixtral_vision_config)


@torch.no_grad()
@pytest.mark.usefixtures("set_seed")
def test_pixtral_vision_model_vs_hf(pixtral_vision_config):
    dtype = torch.bfloat16
    device = torch.device("cuda")
    pretrained_config = pixtral_vision_config.pretrained_config

    pixtral_model = (
        modeling_pixtral.PixtralVisionModel(model_config=pixtral_vision_config).eval().to(device)
    )
    hf_pixtral_model = init_hf_model(
        cls=hf_modeling_pixtral.PixtralVisionModel,
        config=pretrained_config,
        dtype=dtype,
        device=device,
    )
    # Make sure both models have the same weights.
    pixtral_model.load_weights(hf_pixtral_model.state_dict())

    batch_size = 1
    height, width, channels = 123, 456, 3
    pixel_values = torch.randn(batch_size, channels, height, width, device=device, dtype=dtype)
    image_sizes = torch.tensor([[height, width]])
    out = pixtral_model(
        pixel_values=pixel_values,
        image_sizes=image_sizes,
    )
    with torch.autocast(device_type="cuda", dtype=dtype):
        hf_out = (
            hf_pixtral_model(
                pixel_values=pixel_values,
                image_sizes=image_sizes,
            )
            .last_hidden_state.squeeze(0)
            .to(dtype=dtype)
        )

    torch.testing.assert_close(out, hf_out, atol=0.2, rtol=0.2)