mirror of
https://github.com/NVIDIA/TensorRT-LLM.git
synced 2026-01-14 06:27:45 +08:00
feat: add Pytorch support of Vision Encoder for multimodal models (#3791)
* feat: Add rename_weights_with_regex function for dynamic weight key renaming Introduced a new utility function to rename weight keys in a dictionary based on regex pattern matching. This allows for flexible mapping of keys from Hugging Face naming conventions to TRT-LLM naming conventions, enhancing model compatibility and usability. Signed-off-by: qixiang-99 <203170375+qixiang-99@users.noreply.github.com> * feat: Implement SiglipVisionModel and related components Added the SiglipVisionModel along with its associated classes, including SiglipAttention, SiglipEncoderLayer, and SiglipEncoder. Additionally, a new test suite for the SiglipVisionModel has been created to ensure compatibility with Hugging Face outputs. Currently SiglipVisionModel support batch size larger than one. Also, inputs and outputs shape are same with the HF for compatibility. Signed-off-by: qixiang-99 <203170375+qixiang-99@users.noreply.github.com> * feat: Add CLIPVisionModel and associated components Introduced the CLIPVisionModel along with its related classes, including CLIPAttention, CLIPEncoderLayer, CLIPEncoder, and CLIPVisionTransformer. This implementation aligns with Hugging Face's CLIP architecture, ensuring compatibility in input and output shapes. Signed-off-by: qixiang-99 <203170375+qixiang-99@users.noreply.github.com> * feat: Enhance CLIPVisionModel with attention metadata preparation and unit tests Updated the CLIPVisionModel to include a method for preparing attention metadata, simplifying the model's usage. Additionally, added a comprehensive unit test suite for the CLIPVisionModel, ensuring compatibility with Hugging Face outputs and validating model performance across various scenarios. Signed-off-by: qixiang-99 <203170375+qixiang-99@users.noreply.github.com> * feat: Refactor SiglipVisionModel with attention metadata preparation and update unit tests Enhanced the SiglipVisionModel by adding a method to prepare attention metadata, streamlining its usage. Updated unit tests to validate model performance and compatibility with Hugging Face outputs, including adjustments to the configuration and test scenarios. Signed-off-by: qixiang-99 <203170375+qixiang-99@users.noreply.github.com> * refactor: Remove unused rotary_emb parameter from CLIP and Siglip attention classes Eliminated the rotary_emb parameter from the CLIPAttention and SiglipAttention classes to streamline the code. Updated unit tests to reflect changes in the model configurations, including clarifications in the default configurations sourced from Hugging Face. Signed-off-by: qixiang-99 <203170375+qixiang-99@users.noreply.github.com> * feat: Integrate CLIPVisionModel into LlavaNextInputProcessor and enhance weight loading Added CLIPVisionModel to the LlavaNextInputProcessor for improved vision processing. Updated the model loading mechanism to ensure compatibility with the new vision model and added attention metadata preparation. Removed debug print statements from weight renaming function for cleaner code. Signed-off-by: qixiang-99 <203170375+qixiang-99@users.noreply.github.com> * refactor: Remove unused max_position_embeddings from CLIPAttention and update Siglip classes to use CLIP components Removed the unused max_position_embeddings variable from the CLIPAttention class. Updated the Siglip classes to utilize CLIP components, specifically replacing SiglipEncoder and SiglipAttention with their CLIP counterparts, streamlining the codebase and enhancing consistency across models. Signed-off-by: qixiang-99 <203170375+qixiang-99@users.noreply.github.com> * refactor: Consolidate weight loading logic into a shared implementation Refactored the weight loading process across CLIP and Siglip models by using a new utility function, _load_weights_impl, to streamline the loading mechanism. This change enhances code maintainability and reduces redundancy in weight handling, ensuring consistent behavior across different model architectures. Signed-off-by: qixiang-99 <203170375+qixiang-99@users.noreply.github.com> * refactor: Simplify output handling in CLIP and Siglip models by removing output_hidden_states parameter Removed the output_hidden_states parameter from the CLIPEncoder and SiglipVisionTransformer classes, streamlining the output handling process. Updated the corresponding unit tests to reflect these changes and ensure compatibility with the new output structure. Signed-off-by: qixiang-99 <203170375+qixiang-99@users.noreply.github.com> * feat: Enhance LlavaNextInputProcessor with dynamic model loading and memory optimization Updated the LlavaNextInputProcessor to support dynamic model loading from local paths or Hugging Face, improving memory efficiency by partially loading the model components. Integrated the LlavaNextMultiModalProjector and adjusted weight loading to ensure compatibility with the new architecture. Signed-off-by: qixiang-99 <203170375+qixiang-99@users.noreply.github.com> --------- Signed-off-by: qixiang-99 <203170375+qixiang-99@users.noreply.github.com> Co-authored-by: Haohang Huang <31998628+symphonylyh@users.noreply.github.com>
This commit is contained in:
parent
906cddffb0
commit
bf4f7ad744
@ -2,6 +2,7 @@ import transformers
|
||||
|
||||
from .modeling_auto import AutoModelForCausalLM
|
||||
from .modeling_bert import BertForSequenceClassification
|
||||
from .modeling_clip import CLIPVisionModel
|
||||
from .modeling_deepseekv3 import DeepseekV3ForCausalLM
|
||||
from .modeling_llama import LlamaForCausalLM
|
||||
from .modeling_llava_next import LlavaNextModel
|
||||
@ -16,6 +17,7 @@ from .modeling_qwen2vl import Qwen2_5_VLModel, Qwen2VLModel
|
||||
from .modeling_qwen3 import Qwen3ForCausalLM
|
||||
from .modeling_qwen3_moe import Qwen3MoeForCausalLM
|
||||
from .modeling_qwen_moe import Qwen2MoeForCausalLM
|
||||
from .modeling_siglip import SiglipVisionModel
|
||||
from .modeling_utils import get_model_architecture
|
||||
from .modeling_vila import VilaModel
|
||||
|
||||
@ -23,6 +25,7 @@ from .modeling_vila import VilaModel
|
||||
__all__ = [
|
||||
"AutoModelForCausalLM",
|
||||
"BertForSequenceClassification",
|
||||
"CLIPVisionModel",
|
||||
"DeepseekV3ForCausalLM",
|
||||
"LlamaForCausalLM",
|
||||
"LlavaNextModel",
|
||||
@ -35,6 +38,7 @@ __all__ = [
|
||||
"Qwen2ForProcessRewardModel",
|
||||
"Qwen2ForRewardModel",
|
||||
"Qwen2MoeForCausalLM",
|
||||
"SiglipVisionModel",
|
||||
"get_model_architecture",
|
||||
"VilaModel",
|
||||
"Qwen2VLModel",
|
||||
|
||||
235
tensorrt_llm/_torch/models/modeling_clip.py
Normal file
235
tensorrt_llm/_torch/models/modeling_clip.py
Normal file
@ -0,0 +1,235 @@
|
||||
from typing import Dict, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from transformers.activations import ACT2FN
|
||||
from transformers.modeling_outputs import BaseModelOutput
|
||||
from transformers.modeling_utils import (get_parameter_device,
|
||||
get_parameter_dtype)
|
||||
from transformers.models.clip.configuration_clip import CLIPVisionConfig
|
||||
from transformers.models.clip.modeling_clip import CLIPVisionEmbeddings
|
||||
|
||||
from ..attention_backend.interface import (AttentionMetadata,
|
||||
PredefinedAttentionMask)
|
||||
from ..attention_backend.utils import get_attention_backend
|
||||
from ..model_config import ModelConfig
|
||||
from ..modules.attention import Attention
|
||||
from ..modules.mlp import MLP
|
||||
from .modeling_utils import _load_weights_impl, register_auto_model
|
||||
|
||||
|
||||
class CLIPAttention(Attention):
|
||||
|
||||
def __init__(self, model_config: ModelConfig[CLIPVisionConfig],
|
||||
layer_idx: int):
|
||||
config = model_config.pretrained_config
|
||||
pos_embd_params = None
|
||||
max_position_embeddings = None
|
||||
|
||||
# CLIP uses bias in attention QKV projections
|
||||
bias = True
|
||||
super().__init__(
|
||||
hidden_size=config.hidden_size,
|
||||
num_attention_heads=config.num_attention_heads,
|
||||
num_key_value_heads=config.num_attention_heads, # CLIP uses MHA
|
||||
max_position_embeddings=
|
||||
max_position_embeddings, # does not matter for CLIP
|
||||
bias=bias,
|
||||
pos_embd_params=pos_embd_params,
|
||||
layer_idx=layer_idx,
|
||||
dtype=config.torch_dtype
|
||||
if hasattr(config, 'torch_dtype') else torch.float32,
|
||||
config=model_config,
|
||||
)
|
||||
|
||||
|
||||
class CLIPEncoderLayer(nn.Module):
|
||||
|
||||
def __init__(self, model_config: ModelConfig[CLIPVisionConfig],
|
||||
layer_idx: int):
|
||||
super().__init__()
|
||||
config = model_config.pretrained_config
|
||||
self.embed_dim = config.hidden_size
|
||||
self.self_attn = CLIPAttention(model_config=model_config,
|
||||
layer_idx=layer_idx)
|
||||
self.layer_norm1 = nn.LayerNorm(self.embed_dim,
|
||||
eps=config.layer_norm_eps)
|
||||
self.mlp = MLP(
|
||||
hidden_size=config.hidden_size,
|
||||
intermediate_size=config.intermediate_size,
|
||||
activation=ACT2FN[config.hidden_act],
|
||||
bias=True, # CLIP MLP bias=True
|
||||
config=model_config)
|
||||
self.layer_norm2 = nn.LayerNorm(self.embed_dim,
|
||||
eps=config.layer_norm_eps)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
attn_metadata: AttentionMetadata,
|
||||
) -> Tuple[torch.FloatTensor]:
|
||||
residual = hidden_states
|
||||
hidden_states = self.layer_norm1(hidden_states)
|
||||
|
||||
hidden_states = self.self_attn(
|
||||
position_ids=None, # CLIP doesn't use explicit position_ids here
|
||||
hidden_states=hidden_states,
|
||||
attn_metadata=attn_metadata,
|
||||
attention_mask=PredefinedAttentionMask.
|
||||
FULL # Always FULL for Vision
|
||||
)
|
||||
|
||||
hidden_states = residual + hidden_states
|
||||
|
||||
residual = hidden_states
|
||||
hidden_states = self.layer_norm2(hidden_states)
|
||||
hidden_states = self.mlp(hidden_states)
|
||||
hidden_states = residual + hidden_states
|
||||
|
||||
outputs = (hidden_states, )
|
||||
|
||||
return outputs
|
||||
|
||||
|
||||
class CLIPEncoder(nn.Module):
|
||||
|
||||
def __init__(self, model_config: ModelConfig[CLIPVisionConfig]):
|
||||
super().__init__()
|
||||
config = model_config.pretrained_config
|
||||
self.config = config # Keep HF config accessible
|
||||
self.model_config = model_config # Keep TRT-LLM config accessible
|
||||
self.layers = nn.ModuleList([
|
||||
CLIPEncoderLayer(model_config, layer_idx)
|
||||
for layer_idx in range(config.num_hidden_layers)
|
||||
])
|
||||
|
||||
def forward(
|
||||
self,
|
||||
inputs_embeds,
|
||||
attn_metadata: AttentionMetadata,
|
||||
) -> Union[Tuple, BaseModelOutput]:
|
||||
|
||||
encoder_states = ()
|
||||
hidden_states = inputs_embeds
|
||||
|
||||
for encoder_layer in self.layers:
|
||||
# hidden_states is (batch_size * seq_len, embed_dim) because TRT-LLM Attention is applied to flattened tokens
|
||||
# we want the output shape align with HF output shape (batch_size, seq_len, embed_dim)
|
||||
encoder_states = encoder_states + (hidden_states.view(
|
||||
attn_metadata.seq_lens.shape[0], attn_metadata.seq_lens[0],
|
||||
-1), )
|
||||
|
||||
layer_outputs = encoder_layer(
|
||||
hidden_states,
|
||||
attn_metadata=attn_metadata,
|
||||
)
|
||||
hidden_states = layer_outputs[0]
|
||||
|
||||
# hidden_states is (batch_size * seq_len, embed_dim) because TRT-LLM Attention is applied to flattened tokens
|
||||
# we want the output shape align with HF output shape (batch_size, seq_len, embed_dim)
|
||||
encoder_states = encoder_states + (hidden_states.view(
|
||||
attn_metadata.seq_lens.shape[0], attn_metadata.seq_lens[0], -1), )
|
||||
|
||||
return encoder_states
|
||||
|
||||
|
||||
class CLIPVisionTransformer(nn.Module):
|
||||
"""
|
||||
This CLIPVisionTransformer is tailored for multimodal models that use CLIP as the vision encoder.
|
||||
For example, it is different from the regular CLIPVisionTransformer in the sense that it does not return a pooled output.
|
||||
"""
|
||||
|
||||
def __init__(self, model_config: ModelConfig[CLIPVisionConfig]):
|
||||
super().__init__()
|
||||
config = model_config.pretrained_config
|
||||
self.config = config
|
||||
embed_dim = config.hidden_size
|
||||
|
||||
# Use HF Embeddings
|
||||
self.embeddings = CLIPVisionEmbeddings(config)
|
||||
self.pre_layrnorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
|
||||
self.encoder = CLIPEncoder(model_config)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
pixel_values,
|
||||
attn_metadata: AttentionMetadata,
|
||||
interpolate_pos_encoding: Optional[bool] = False,
|
||||
) -> Tuple[torch.Tensor]:
|
||||
|
||||
hidden_states = self.embeddings(
|
||||
pixel_values, interpolate_pos_encoding=interpolate_pos_encoding)
|
||||
|
||||
hidden_states = self.pre_layrnorm(hidden_states)
|
||||
|
||||
# Reshape for TRT-LLM Attention: (batch * seq_len, hidden)
|
||||
hidden_states = hidden_states.reshape(
|
||||
hidden_states.shape[0] * hidden_states.shape[1],
|
||||
hidden_states.shape[2])
|
||||
encoder_outputs: Tuple[torch.Tensor] = self.encoder(
|
||||
inputs_embeds=hidden_states,
|
||||
attn_metadata=attn_metadata,
|
||||
)
|
||||
|
||||
return encoder_outputs
|
||||
|
||||
|
||||
@register_auto_model("CLIPVisionModel")
|
||||
class CLIPVisionModel(nn.Module):
|
||||
|
||||
def __init__(self, model_config: ModelConfig[CLIPVisionConfig]):
|
||||
super().__init__()
|
||||
self.model_config = model_config
|
||||
self.config = self.model_config.pretrained_config # HF Vision Config
|
||||
self.vision_model = CLIPVisionTransformer(self.model_config)
|
||||
self.metadata_cls = get_attention_backend(
|
||||
model_config.attn_backend).Metadata
|
||||
|
||||
def prepare_attn_metadata(self, batch_size):
|
||||
"""
|
||||
To simplify the usage of the model, this function aims to fill the metadata for Attention
|
||||
Call this function before forward pass
|
||||
"""
|
||||
seq_len = (self.config.image_size // self.config.patch_size)**2 + 1
|
||||
request_ids = list(range(1, batch_size + 1))
|
||||
prompt_lens = [seq_len] * batch_size
|
||||
attn_metadata = self.metadata_cls(
|
||||
seq_lens=torch.tensor([seq_len] * batch_size, dtype=torch.int),
|
||||
num_contexts=batch_size,
|
||||
max_num_requests=batch_size,
|
||||
max_num_tokens=seq_len * batch_size,
|
||||
kv_cache_manager=None,
|
||||
request_ids=request_ids,
|
||||
prompt_lens=prompt_lens,
|
||||
)
|
||||
attn_metadata.max_seq_len = seq_len * batch_size
|
||||
attn_metadata.prepare()
|
||||
return attn_metadata
|
||||
|
||||
@property
|
||||
def dtype(self):
|
||||
return get_parameter_dtype(self)
|
||||
|
||||
@property
|
||||
def device(self):
|
||||
return get_parameter_device(self)
|
||||
|
||||
@torch.inference_mode()
|
||||
def forward(self,
|
||||
pixel_values,
|
||||
attn_metadata: AttentionMetadata,
|
||||
interpolate_pos_encoding: Optional[bool] = False):
|
||||
|
||||
return self.vision_model(
|
||||
pixel_values=pixel_values,
|
||||
attn_metadata=attn_metadata,
|
||||
interpolate_pos_encoding=interpolate_pos_encoding)
|
||||
|
||||
def load_weights(self, weights: Dict):
|
||||
# Pattern mapping for CLIP based on Siglip's example and CLIP HF names
|
||||
pattern_mapping = {
|
||||
r'(.*?)self_attn\.out_proj(.*)': r'\1self_attn.o_proj\2',
|
||||
r'(.*?)mlp\.fc1(.*)': r'\1mlp.up_proj\2',
|
||||
r'(.*?)mlp\.fc2(.*)': r'\1mlp.down_proj\2',
|
||||
}
|
||||
_load_weights_impl(self, weights, pattern_mapping)
|
||||
@ -1,18 +1,25 @@
|
||||
import copy
|
||||
import os
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
from transformers import (AutoModel, AutoProcessor, LlavaNextConfig,
|
||||
LlavaNextForConditionalGeneration, PretrainedConfig,
|
||||
PreTrainedModel)
|
||||
import torch.nn as nn
|
||||
from transformers import (AutoConfig, AutoModel, AutoProcessor, LlavaNextConfig,
|
||||
PretrainedConfig, PreTrainedModel)
|
||||
from transformers.modeling_utils import load_sharded_checkpoint
|
||||
from transformers.models.llava_next.modeling_llava_next import \
|
||||
LlavaNextMultiModalProjector
|
||||
|
||||
from ..._utils import nvtx_range
|
||||
from ...inputs import (ExtraProcessedInputs, InputProcessor, TextPrompt,
|
||||
register_input_processor)
|
||||
from ...llmapi.utils import download_hf_model
|
||||
from ...logger import logger
|
||||
from ...sampling_params import SamplingParams
|
||||
from ..attention_backend import AttentionMetadata
|
||||
from ..model_config import ModelConfig
|
||||
from .modeling_auto import AutoModelForCausalLM
|
||||
from .modeling_clip import CLIPVisionModel
|
||||
from .modeling_multimodal_utils import fuse_input_embeds
|
||||
from .modeling_utils import ModelConfig, register_auto_model
|
||||
|
||||
@ -25,11 +32,40 @@ class LlavaNextInputProcessor(InputProcessor):
|
||||
use_fast=True)
|
||||
self.model_config = model_config
|
||||
|
||||
model = LlavaNextForConditionalGeneration.from_pretrained(
|
||||
model_path, torch_dtype=model_config.text_config.torch_dtype)
|
||||
self.device = 'cuda'
|
||||
self.vision_tower = model.vision_tower.vision_model.to(self.device)
|
||||
self.mm_projector = model.multi_modal_projector.to(self.device)
|
||||
|
||||
# Determine the actual local path for model files
|
||||
if os.path.isdir(model_path):
|
||||
local_model_path = model_path
|
||||
else:
|
||||
local_model_path = download_hf_model(model_path)
|
||||
|
||||
# Partially load the model to reduce memory usage(Vision tower and multi-modal projector)
|
||||
hf_model_config = AutoConfig.from_pretrained(local_model_path)
|
||||
self.dtype = hf_model_config.text_config.torch_dtype
|
||||
module_dict = nn.ModuleDict({
|
||||
"vision_tower":
|
||||
AutoModel.from_config(hf_model_config.vision_config),
|
||||
"multi_modal_projector":
|
||||
LlavaNextMultiModalProjector(hf_model_config)
|
||||
})
|
||||
missing_keys, _ = load_sharded_checkpoint(module_dict,
|
||||
local_model_path,
|
||||
strict=False)
|
||||
assert len(missing_keys) == 0, f"Missing keys: {missing_keys}"
|
||||
hf_vision_tower = module_dict["vision_tower"].to(self.dtype)
|
||||
hf_mm_projector = module_dict["multi_modal_projector"].to(
|
||||
self.dtype).to(self.device)
|
||||
|
||||
# Use TRTLLM vision tower(CLIPVisionModel)
|
||||
vision_model_config = ModelConfig(
|
||||
pretrained_config=model_config.vision_config, attn_backend="TRTLLM")
|
||||
self.vision_tower = CLIPVisionModel(vision_model_config).to(
|
||||
self.device).to(self.dtype)
|
||||
self.vision_tower.load_weights(hf_vision_tower.state_dict())
|
||||
|
||||
# Use HF multi-modal projector
|
||||
self.mm_projector = hf_mm_projector
|
||||
|
||||
@nvtx_range("[Vision] preprocess")
|
||||
def _preprocess(self, images):
|
||||
@ -44,9 +80,13 @@ class LlavaNextInputProcessor(InputProcessor):
|
||||
|
||||
@nvtx_range("[Vision] process")
|
||||
def _process(self, pixel_values):
|
||||
image_features = self.vision_tower(pixel_values,
|
||||
output_hidden_states=True)
|
||||
selected_image_feature = image_features.hidden_states[-2][:, 1:]
|
||||
attn_metadata = self.vision_tower.prepare_attn_metadata(
|
||||
pixel_values.shape[0])
|
||||
image_features: Tuple[torch.Tensor] = self.vision_tower(
|
||||
pixel_values,
|
||||
attn_metadata=attn_metadata,
|
||||
)
|
||||
selected_image_feature = image_features[-2][:, 1:]
|
||||
image_features = self.mm_projector(selected_image_feature)
|
||||
return image_features.reshape(-1, image_features.shape[-1])
|
||||
|
||||
|
||||
111
tensorrt_llm/_torch/models/modeling_siglip.py
Normal file
111
tensorrt_llm/_torch/models/modeling_siglip.py
Normal file
@ -0,0 +1,111 @@
|
||||
from typing import Dict, Optional, Tuple
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from transformers.modeling_utils import (get_parameter_device,
|
||||
get_parameter_dtype)
|
||||
from transformers.models.siglip.configuration_siglip import SiglipVisionConfig
|
||||
from transformers.models.siglip.modeling_siglip import (SiglipVisionConfig,
|
||||
SiglipVisionEmbeddings)
|
||||
|
||||
from ..attention_backend.interface import AttentionMetadata
|
||||
from ..attention_backend.utils import get_attention_backend
|
||||
from ..model_config import ModelConfig
|
||||
from .modeling_clip import CLIPEncoder
|
||||
from .modeling_utils import _load_weights_impl, register_auto_model
|
||||
|
||||
SiglipEncoder = CLIPEncoder
|
||||
|
||||
|
||||
class SiglipVisionTransformer(nn.Module):
|
||||
"""
|
||||
This SiglipVisionTransformer is tailored for multimodal models that use Siglip as the vision encoder.
|
||||
For example, it is different from the regular SiglipVisionTransformer in the sense that it does not return a pooled output.
|
||||
"""
|
||||
|
||||
def __init__(self, model_config: ModelConfig[SiglipVisionConfig]):
|
||||
super().__init__()
|
||||
config = model_config.pretrained_config
|
||||
self.config = config
|
||||
|
||||
self.embeddings = SiglipVisionEmbeddings(config)
|
||||
self.encoder = SiglipEncoder(model_config)
|
||||
if hasattr(config, "vision_use_head"):
|
||||
assert not config.vision_use_head, "Currently, we only support vision_use_head = False"
|
||||
|
||||
def forward(
|
||||
self,
|
||||
pixel_values,
|
||||
attn_metadata: AttentionMetadata,
|
||||
interpolate_pos_encoding: Optional[bool] = False,
|
||||
) -> Tuple[torch.Tensor]:
|
||||
|
||||
hidden_states = self.embeddings(
|
||||
pixel_values, interpolate_pos_encoding=interpolate_pos_encoding)
|
||||
# reshape hidden_states to (batch_size * seq_len, embed_dim)
|
||||
hidden_states = hidden_states.reshape(
|
||||
hidden_states.shape[0] * hidden_states.shape[1],
|
||||
hidden_states.shape[2])
|
||||
|
||||
encoder_outputs: Tuple[torch.Tensor] = self.encoder(
|
||||
inputs_embeds=hidden_states,
|
||||
attn_metadata=attn_metadata,
|
||||
)
|
||||
|
||||
return encoder_outputs
|
||||
|
||||
|
||||
@register_auto_model("SiglipVisionModel")
|
||||
class SiglipVisionModel(nn.Module):
|
||||
|
||||
def __init__(self, model_config: ModelConfig[SiglipVisionConfig]):
|
||||
super().__init__()
|
||||
self.config = model_config.pretrained_config
|
||||
self.vision_model = SiglipVisionTransformer(model_config)
|
||||
self.model_config = model_config
|
||||
self.metadata_cls = get_attention_backend(
|
||||
model_config.attn_backend).Metadata
|
||||
|
||||
def prepare_attn_metadata(self, batch_size):
|
||||
"""
|
||||
To simplify the usage of the model, this function aims to fill the metadata for Attention
|
||||
Call this function before forward pass
|
||||
"""
|
||||
seq_len = (self.config.image_size // self.config.patch_size)**2
|
||||
request_ids = list(range(1, batch_size + 1))
|
||||
prompt_lens = [seq_len] * batch_size
|
||||
attn_metadata = self.metadata_cls(
|
||||
seq_lens=torch.tensor([seq_len] * batch_size, dtype=torch.int),
|
||||
num_contexts=batch_size,
|
||||
max_num_requests=batch_size,
|
||||
max_num_tokens=seq_len * batch_size,
|
||||
kv_cache_manager=None,
|
||||
request_ids=request_ids,
|
||||
prompt_lens=prompt_lens,
|
||||
)
|
||||
attn_metadata.max_seq_len = seq_len * batch_size
|
||||
attn_metadata.prepare()
|
||||
return attn_metadata
|
||||
|
||||
@property
|
||||
def dtype(self):
|
||||
return get_parameter_dtype(self)
|
||||
|
||||
@property
|
||||
def device(self):
|
||||
return get_parameter_device(self)
|
||||
|
||||
@torch.inference_mode()
|
||||
def forward(self, pixel_values, attn_metadata: AttentionMetadata):
|
||||
return self.vision_model(
|
||||
pixel_values=pixel_values,
|
||||
attn_metadata=attn_metadata,
|
||||
)
|
||||
|
||||
def load_weights(self, weights: Dict):
|
||||
pattern_mapping = {
|
||||
r'(.*?)out_proj(.*)': r'\1o_proj\2',
|
||||
r'(.*?)fc1(.*)': r'\1up_proj\2',
|
||||
r'(.*?)fc2(.*)': r'\1down_proj\2',
|
||||
}
|
||||
_load_weights_impl(self, weights, pattern_mapping)
|
||||
@ -510,73 +510,7 @@ class DecoderModelForCausalLM(nn.Module,
|
||||
)
|
||||
|
||||
def load_weights(self, weights: Dict):
|
||||
tp_size = self.model_config.mapping.tp_size
|
||||
head_dim = getattr(
|
||||
self.config, "head_dim",
|
||||
self.config.hidden_size // self.config.num_attention_heads)
|
||||
|
||||
def filter_weights(prefix, weights: Dict):
|
||||
result = {}
|
||||
for k, v in weights.items():
|
||||
if k.startswith(prefix):
|
||||
new_k = k[len(prefix) + 1:]
|
||||
result[new_k] = v
|
||||
return result
|
||||
|
||||
params_map = {
|
||||
'qkv_proj': ['q_proj', 'k_proj', 'v_proj'],
|
||||
'gate_up_proj': ['gate_proj', 'up_proj']
|
||||
}
|
||||
|
||||
for name, module in tqdm(list(self.named_modules()),
|
||||
desc="Loading weights"):
|
||||
if len(module._parameters) > 0:
|
||||
# skip load weights if tie word embeddings is enabled and layer is lm_head
|
||||
if self.config.tie_word_embeddings and name.startswith(
|
||||
"lm_head"):
|
||||
continue
|
||||
|
||||
# Skip loading weights for embedding and lm_head if LoRA is enabled
|
||||
if hasattr(self.model_config, 'lora_config'
|
||||
) and self.model_config.lora_config is not None and (
|
||||
name == "model.embed_tokens"
|
||||
or name == "lm_head"):
|
||||
continue
|
||||
|
||||
# Skip if parameter belongs to a missing layer
|
||||
if missing_layer_parameter(name, self):
|
||||
continue
|
||||
|
||||
names = name.split('.')
|
||||
# WAR: better solution is that llama has its own load_weights function.
|
||||
if names[-1] == 'next_layer_layernorm':
|
||||
continue
|
||||
if names[-1] in params_map:
|
||||
module_weights = []
|
||||
for new_name in params_map[names[-1]]:
|
||||
fw = filter_weights('.'.join(names[:-1] + [new_name]),
|
||||
weights)
|
||||
if new_name in ['k_proj', 'v_proj']:
|
||||
fw = {
|
||||
k:
|
||||
duplicate_kv_weight(
|
||||
weight=v[:],
|
||||
head_dim=head_dim,
|
||||
tensor_parallel_size=tp_size)
|
||||
if k in ["weight", "bias"] else v
|
||||
for k, v in fw.items()
|
||||
}
|
||||
|
||||
module_weights.append(fw)
|
||||
module.load_weights(weights=module_weights)
|
||||
else:
|
||||
module_weights = filter_weights(name, weights)
|
||||
if hasattr(module, 'load_weights'):
|
||||
module.load_weights(weights=[module_weights])
|
||||
else:
|
||||
for n, p in module._parameters.items():
|
||||
if p is not None:
|
||||
p.data.copy_(module_weights[n][:])
|
||||
_load_weights_impl(self, weights)
|
||||
|
||||
def infer_max_seq_len(self) -> int:
|
||||
# Modified from tensorrt_llm/builder.py _init_max_seq_len
|
||||
@ -634,3 +568,129 @@ def get_model_architecture(
|
||||
def support_pp(cls: Type) -> Type:
|
||||
cls._supports_pp = True
|
||||
return cls
|
||||
|
||||
|
||||
def rename_weights_with_regex(pattern_mapping: Dict[str, str], weights: Dict):
|
||||
"""
|
||||
Rename weight keys according to regex pattern matching.
|
||||
|
||||
Args:
|
||||
pattern_mapping: A dictionary mapping regex patterns to replacement strings. The key is HF name pattern, and the value is corresponding TRT-LLM name pattern.
|
||||
The patterns will be used to match keys in the weights dict and replace
|
||||
them according to the replacement string, which can use regex backreferences.
|
||||
Example:
|
||||
HF name: vision_model.encoder.layers.1.self_attn.out_proj.{weight,bias}
|
||||
TRT-LLM name: vision_model.encoder.layers.1.self_attn.o_proj.{weight,bias}
|
||||
Then the pattern_mapping could be:
|
||||
pattern_mapping = {
|
||||
r'(.*?)out_proj(.*)': r'\1o_proj\2'
|
||||
}
|
||||
weights: A dictionary of weights
|
||||
|
||||
Returns:
|
||||
A dictionary of weights with renamed keys
|
||||
"""
|
||||
import re
|
||||
|
||||
# Create a new dictionary to store the renamed weights
|
||||
renamed_weights = {}
|
||||
|
||||
# Keep track of keys that have been matched by a pattern
|
||||
matched_keys = set()
|
||||
|
||||
# Process each key in the weights dictionary
|
||||
for key in list(weights.keys()):
|
||||
# Check each pattern for a match
|
||||
for pattern, replacement in pattern_mapping.items():
|
||||
if re.match(pattern, key):
|
||||
# Create the new key by applying the regex replacement
|
||||
new_key = re.sub(pattern, replacement, key)
|
||||
# Store the weight with the new key
|
||||
renamed_weights[new_key] = weights[key]
|
||||
matched_keys.add(key)
|
||||
break
|
||||
|
||||
# If the key wasn't matched by any pattern, keep it as is
|
||||
if key not in matched_keys:
|
||||
renamed_weights[key] = weights[key]
|
||||
|
||||
return renamed_weights
|
||||
|
||||
|
||||
def _load_weights_impl(model: Union[nn.Module, DecoderModelForCausalLM],
|
||||
weights: Dict,
|
||||
params_map: Optional[Dict[str, str]] = None):
|
||||
if not hasattr(model, 'model_config') or not isinstance(
|
||||
model.model_config, ModelConfig):
|
||||
raise ValueError("model must have a model_config attribute")
|
||||
if not hasattr(model, 'config'):
|
||||
raise ValueError("model must have a config attribute")
|
||||
|
||||
if params_map is not None:
|
||||
weights = rename_weights_with_regex(params_map, weights)
|
||||
logger.info(f"Renamed weights with params_map: {params_map}")
|
||||
|
||||
tp_size = model.model_config.mapping.tp_size
|
||||
head_dim = getattr(
|
||||
model.config, "head_dim",
|
||||
model.config.hidden_size // model.config.num_attention_heads)
|
||||
|
||||
def filter_weights(prefix, weights: Dict):
|
||||
result = {}
|
||||
for k, v in weights.items():
|
||||
if k.startswith(prefix):
|
||||
new_k = k[len(prefix) + 1:]
|
||||
result[new_k] = v
|
||||
return result
|
||||
|
||||
params_map = {
|
||||
'qkv_proj': ['q_proj', 'k_proj', 'v_proj'],
|
||||
'gate_up_proj': ['gate_proj', 'up_proj']
|
||||
}
|
||||
|
||||
for name, module in tqdm(list(model.named_modules()),
|
||||
desc="Loading weights"):
|
||||
if len(module._parameters) > 0:
|
||||
# skip load weights if tie word embeddings is enabled and layer is lm_head
|
||||
if model.config.tie_word_embeddings and name.startswith("lm_head"):
|
||||
continue
|
||||
|
||||
# Skip loading weights for embedding and lm_head if LoRA is enabled
|
||||
if hasattr(model.model_config, 'lora_config'
|
||||
) and model.model_config.lora_config is not None and (
|
||||
name == "model.embed_tokens" or name == "lm_head"):
|
||||
continue
|
||||
|
||||
# Skip if parameter belongs to a missing layer
|
||||
if missing_layer_parameter(name, model):
|
||||
continue
|
||||
|
||||
names = name.split('.')
|
||||
# WAR: better solution is that llama has its own load_weights function.
|
||||
if names[-1] == 'next_layer_layernorm':
|
||||
continue
|
||||
if names[-1] in params_map:
|
||||
module_weights = []
|
||||
for new_name in params_map[names[-1]]:
|
||||
fw = filter_weights('.'.join(names[:-1] + [new_name]),
|
||||
weights)
|
||||
if new_name in ['k_proj', 'v_proj']:
|
||||
fw = {
|
||||
k:
|
||||
duplicate_kv_weight(weight=v[:],
|
||||
head_dim=head_dim,
|
||||
tensor_parallel_size=tp_size)
|
||||
if k in ["weight", "bias"] else v
|
||||
for k, v in fw.items()
|
||||
}
|
||||
|
||||
module_weights.append(fw)
|
||||
module.load_weights(weights=module_weights)
|
||||
else:
|
||||
module_weights = filter_weights(name, weights)
|
||||
if hasattr(module, 'load_weights'):
|
||||
module.load_weights(weights=[module_weights])
|
||||
else:
|
||||
for n, p in module._parameters.items():
|
||||
if p is not None:
|
||||
p.data.copy_(module_weights[n][:])
|
||||
|
||||
129
tests/unittest/_torch/modeling/test_modeling_clip.py
Normal file
129
tests/unittest/_torch/modeling/test_modeling_clip.py
Normal file
@ -0,0 +1,129 @@
|
||||
import unittest
|
||||
from copy import deepcopy
|
||||
from dataclasses import dataclass
|
||||
|
||||
import torch
|
||||
from parameterized import parameterized
|
||||
# Import CLIP specific classes from HF
|
||||
from transformers import CLIPVisionConfig
|
||||
from transformers import CLIPVisionModel as HFCLIPVisionModel
|
||||
|
||||
from tensorrt_llm._torch.model_config import ModelConfig
|
||||
# Import TRT-LLM CLIP model
|
||||
from tensorrt_llm._torch.models.modeling_clip import CLIPVisionModel
|
||||
|
||||
# Default CLIP config from HF (https://github.com/huggingface/transformers/blob/v4.51.3/src/transformers/models/clip/configuration_clip.py#L144-L172)
|
||||
CLIP_CONFIG = {
|
||||
"hidden_size": 768,
|
||||
"intermediate_size": 3072,
|
||||
"num_hidden_layers": 12,
|
||||
"num_attention_heads": 12,
|
||||
"num_channels": 3,
|
||||
"image_size": 224,
|
||||
"patch_size": 32,
|
||||
"hidden_act": "quick_gelu",
|
||||
"layer_norm_eps": 1e-5,
|
||||
"attention_dropout": 0.0,
|
||||
"initializer_range": 0.02,
|
||||
"initializer_factor": 1.0,
|
||||
}
|
||||
|
||||
ACCURACY_CONFIG = {
|
||||
torch.float16: (1e-2, 1e-2),
|
||||
}
|
||||
|
||||
|
||||
@dataclass(repr=False)
|
||||
class Scenario:
|
||||
backend: str
|
||||
num_images: int
|
||||
dtype: torch.dtype # Add dtype to scenario
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"backend:{self.backend.lower()}_num_images:{self.num_images}_dtype:{self.dtype}"
|
||||
|
||||
|
||||
class TestCLIPVisionModel(unittest.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
torch.random.manual_seed(720)
|
||||
|
||||
@parameterized.expand([
|
||||
Scenario(backend="VANILLA", num_images=2, dtype=torch.float16),
|
||||
Scenario(backend="TRTLLM", num_images=2, dtype=torch.float16),
|
||||
Scenario(backend="TRTLLM", num_images=21, dtype=torch.float16),
|
||||
], lambda testcase_func, param_num, param:
|
||||
f"{testcase_func.__name__}[{param.args[0]}]")
|
||||
def test_clip_vision_allclose_to_hf(self, scenario: Scenario):
|
||||
"""Compare output to HF"""
|
||||
backend = scenario.backend
|
||||
num_images = scenario.num_images
|
||||
dtype = scenario.dtype
|
||||
device = torch.device('cuda')
|
||||
|
||||
# Create configs
|
||||
config_dict = deepcopy(CLIP_CONFIG)
|
||||
hf_config = CLIPVisionConfig.from_dict(config_dict)
|
||||
|
||||
# Prepare HF model
|
||||
hf_model = HFCLIPVisionModel(hf_config).to(dtype).to(device).eval()
|
||||
|
||||
# Prepare tllm pytorch model
|
||||
model_config = ModelConfig(
|
||||
pretrained_config=hf_config,
|
||||
attn_backend=backend,
|
||||
)
|
||||
|
||||
tllm_model = CLIPVisionModel(model_config).to(dtype).to(device)
|
||||
# Use the load_weights method we are testing
|
||||
tllm_model.load_weights(hf_model.state_dict())
|
||||
|
||||
# Prepare inputs - create random pixel values for images
|
||||
batch_size = num_images
|
||||
pixel_values = torch.rand(batch_size,
|
||||
hf_config.num_channels,
|
||||
hf_config.image_size,
|
||||
hf_config.image_size,
|
||||
device=device,
|
||||
dtype=dtype)
|
||||
|
||||
# Run HF inference
|
||||
with torch.inference_mode():
|
||||
hf_outputs = hf_model(
|
||||
pixel_values=pixel_values,
|
||||
output_attentions=False,
|
||||
output_hidden_states=True) # Get hidden states for comparison
|
||||
|
||||
# Run TRT-LLM inference
|
||||
attn_metadata = tllm_model.prepare_attn_metadata(batch_size)
|
||||
tllm_outputs = tllm_model(
|
||||
pixel_values=pixel_values,
|
||||
attn_metadata=attn_metadata,
|
||||
)
|
||||
|
||||
# Compare outputs
|
||||
rtol, atol = ACCURACY_CONFIG[dtype]
|
||||
|
||||
# Compare all hidden states
|
||||
|
||||
for i, (hf_hs, tllm_hs) in enumerate(
|
||||
zip(hf_outputs.hidden_states,
|
||||
tllm_outputs)): # Iterate through tllm_outputs directly
|
||||
self.assertEqual(hf_hs.shape, tllm_hs.shape,
|
||||
f"Shape mismatch for hidden state {i}")
|
||||
torch.testing.assert_close(
|
||||
hf_hs.float(),
|
||||
tllm_hs.float(),
|
||||
rtol=rtol,
|
||||
atol=atol,
|
||||
msg=
|
||||
f"FAILED: TRT-LLM and HF hidden_states mismatch for {dtype} with {num_images} images at layer {i}"
|
||||
)
|
||||
print(
|
||||
f"PASSED: TRT-LLM and HF hidden_states match for {dtype} with {num_images} images at layer {i}"
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
125
tests/unittest/_torch/modeling/test_modeling_siglip.py
Normal file
125
tests/unittest/_torch/modeling/test_modeling_siglip.py
Normal file
@ -0,0 +1,125 @@
|
||||
import unittest
|
||||
from copy import deepcopy
|
||||
from dataclasses import dataclass
|
||||
|
||||
import torch
|
||||
from parameterized import parameterized
|
||||
from transformers import SiglipVisionConfig
|
||||
from transformers import SiglipVisionModel as HFSiglipVisionModel
|
||||
|
||||
from tensorrt_llm._torch.model_config import ModelConfig
|
||||
from tensorrt_llm._torch.models.modeling_siglip import SiglipVisionModel
|
||||
|
||||
# use the default config from HF (https://github.com/huggingface/transformers/blob/main/src/transformers/models/siglip/configuration_siglip.py#L126-L147)
|
||||
SIGLIP_CONFIG = {
|
||||
"hidden_size": 768,
|
||||
"intermediate_size": 3072,
|
||||
"num_hidden_layers": 12,
|
||||
"num_attention_heads": 12,
|
||||
"image_size": 224,
|
||||
"patch_size": 16,
|
||||
"hidden_act": "gelu_pytorch_tanh",
|
||||
"layer_norm_eps": 1e-6,
|
||||
"hidden_dropout_prob": 0.0,
|
||||
"attention_probs_dropout_prob": 0.0,
|
||||
"num_channels": 3,
|
||||
"vision_use_head": False,
|
||||
}
|
||||
|
||||
ACCURACY_CONFIG = {
|
||||
torch.float16: (2e-2, 5e-2),
|
||||
}
|
||||
|
||||
|
||||
@dataclass(repr=False)
|
||||
class Scenario:
|
||||
backend: str
|
||||
num_images: int
|
||||
dtype: torch.dtype
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"backend:{self.backend.lower()}_num_images:{self.num_images}_dtype:{self.dtype}"
|
||||
|
||||
|
||||
class TestSiglipVisionModel(unittest.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
torch.random.manual_seed(1234)
|
||||
|
||||
@parameterized.expand([
|
||||
Scenario(backend="VANILLA", num_images=2, dtype=torch.float16),
|
||||
Scenario(backend="TRTLLM", num_images=2, dtype=torch.float16),
|
||||
Scenario(backend="TRTLLM", num_images=21, dtype=torch.float16),
|
||||
], lambda testcase_func, param_num, param:
|
||||
f"{testcase_func.__name__}[{param.args[0]}]")
|
||||
def test_siglip_vision_allclose_to_hf(self, scenario: Scenario):
|
||||
"""Compare output to HF"""
|
||||
backend = scenario.backend
|
||||
num_images = scenario.num_images
|
||||
dtype = scenario.dtype
|
||||
device = torch.device('cuda')
|
||||
|
||||
# Create configs
|
||||
config_dict = deepcopy(SIGLIP_CONFIG)
|
||||
hf_config = SiglipVisionConfig.from_dict(config_dict)
|
||||
|
||||
# Prepare HF model
|
||||
hf_model = HFSiglipVisionModel(hf_config).to(dtype).to(device).eval()
|
||||
|
||||
# Prepare tllm pytorch model
|
||||
model_config = ModelConfig(
|
||||
pretrained_config=hf_config,
|
||||
attn_backend=backend,
|
||||
)
|
||||
|
||||
tllm_model = SiglipVisionModel(model_config).to(dtype).to(device)
|
||||
tllm_model.load_weights(hf_model.state_dict())
|
||||
|
||||
# Prepare inputs - create random pixel values for images
|
||||
batch_size = num_images
|
||||
pixel_values = torch.rand(batch_size,
|
||||
hf_config.num_channels,
|
||||
hf_config.image_size,
|
||||
hf_config.image_size,
|
||||
device=device,
|
||||
dtype=dtype)
|
||||
|
||||
# Run HF inference
|
||||
with torch.inference_mode():
|
||||
# HF model forward
|
||||
hf_outputs = hf_model(pixel_values=pixel_values,
|
||||
output_attentions=False,
|
||||
output_hidden_states=True)
|
||||
|
||||
# Fill the metadata for tllm attn
|
||||
attn_metadata = tllm_model.prepare_attn_metadata(batch_size)
|
||||
|
||||
# TRT-LLM model forward
|
||||
tllm_outputs = tllm_model(
|
||||
pixel_values=pixel_values,
|
||||
attn_metadata=attn_metadata,
|
||||
)
|
||||
|
||||
# Compare all hidden states
|
||||
|
||||
for i, (hf_hs, tllm_hs) in enumerate(
|
||||
zip(hf_outputs.hidden_states, tllm_outputs)):
|
||||
self.assertEqual(hf_hs.shape, tllm_hs.shape,
|
||||
f"Shape mismatch for hidden state {i}")
|
||||
|
||||
torch.testing.assert_close(
|
||||
hf_hs.float(),
|
||||
tllm_hs.float(),
|
||||
rtol=ACCURACY_CONFIG[dtype][0],
|
||||
atol=ACCURACY_CONFIG[dtype][1],
|
||||
msg=
|
||||
f"FAILED: TRT-LLM and HF hidden_states mismatch for {dtype} with {num_images} images at layer {i}, the mean value of this layer is {hf_hs.mean()}"
|
||||
)
|
||||
print(
|
||||
f"PASSED: TRT-LLM and HF hidden_states match for {dtype} with {num_images} images at layer {i}, the mean value of this layer is {hf_hs.mean()}"
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
Loading…
Reference in New Issue
Block a user