feat: add PyTorch support for the Vision Encoder of multimodal models (#3791)

* feat: Add rename_weights_with_regex function for dynamic weight key renaming

Introduced a new utility function to rename weight keys in a dictionary based on regex pattern matching. This allows for flexible mapping of keys from Hugging Face naming conventions to TRT-LLM naming conventions, enhancing model compatibility and usability.

Signed-off-by: qixiang-99 <203170375+qixiang-99@users.noreply.github.com>
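
A minimal usage sketch of the new helper (assuming it is importable from tensorrt_llm._torch.models.modeling_utils, as the diff below suggests; the sample weights are illustrative):

import torch

from tensorrt_llm._torch.models.modeling_utils import rename_weights_with_regex

# HF-style patterns on the left, TRT-LLM-style replacements (with backreferences) on the right.
pattern_mapping = {
    r'(.*?)self_attn\.out_proj(.*)': r'\1self_attn.o_proj\2',
    r'(.*?)mlp\.fc1(.*)': r'\1mlp.up_proj\2',
    r'(.*?)mlp\.fc2(.*)': r'\1mlp.down_proj\2',
}
hf_weights = {
    'vision_model.encoder.layers.0.self_attn.out_proj.weight': torch.zeros(4, 4),
    'vision_model.encoder.layers.0.mlp.fc1.bias': torch.zeros(4),
}
trtllm_weights = rename_weights_with_regex(pattern_mapping, hf_weights)
# Keys become '...self_attn.o_proj.weight' and '...mlp.up_proj.bias';
# keys that match no pattern are passed through unchanged.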

* feat: Implement SiglipVisionModel and related components

Added the SiglipVisionModel along with its associated classes, including SiglipAttention, SiglipEncoderLayer, and SiglipEncoder.
Additionally, a new test suite for the SiglipVisionModel has been created to ensure compatibility with Hugging Face outputs.

Currently, SiglipVisionModel supports batch sizes larger than one. Input and output shapes match those of the HF model for compatibility.

Signed-off-by: qixiang-99 <203170375+qixiang-99@users.noreply.github.com>

* feat: Add CLIPVisionModel and associated components

Introduced the CLIPVisionModel along with its related classes, including CLIPAttention, CLIPEncoderLayer, CLIPEncoder, and CLIPVisionTransformer. This implementation aligns with Hugging Face's CLIP architecture, ensuring compatibility in input and output shapes.

Signed-off-by: qixiang-99 <203170375+qixiang-99@users.noreply.github.com>

* feat: Enhance CLIPVisionModel with attention metadata preparation and unit tests

Updated the CLIPVisionModel to include a method for preparing attention metadata, simplifying the model's usage. Additionally, added a comprehensive unit test suite for the CLIPVisionModel, ensuring compatibility with Hugging Face outputs and validating model performance across various scenarios.

Signed-off-by: qixiang-99 <203170375+qixiang-99@users.noreply.github.com>
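
A hedged inference sketch modeled on the new unit test; the import paths and config defaults come from the diff below, while device, dtype, and batch size are illustrative:

import torch
from transformers import CLIPVisionConfig

from tensorrt_llm._torch.model_config import ModelConfig
from tensorrt_llm._torch.models.modeling_clip import CLIPVisionModel

hf_config = CLIPVisionConfig()  # HF defaults: 224x224 images, 32x32 patches, hidden size 768
model_config = ModelConfig(pretrained_config=hf_config, attn_backend="TRTLLM")
model = CLIPVisionModel(model_config).to(torch.float16).to("cuda")
# model.load_weights(hf_clip_vision_model.state_dict())  # optional: HF weights map over via regex renaming

pixel_values = torch.rand(2, 3, 224, 224, device="cuda", dtype=torch.float16)
# Fill TRT-LLM attention metadata for batch_size=2; each image contributes
# (224 // 32)**2 + 1 = 50 tokens (patch tokens plus the class token).
attn_metadata = model.prepare_attn_metadata(batch_size=2)
hidden_states = model(pixel_values, attn_metadata=attn_metadata)
# `hidden_states` is a tuple of (batch, seq_len, hidden) tensors, one per encoder
# layer plus the embedding output, matching the HF model's hidden_states.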

* feat: Refactor SiglipVisionModel with attention metadata preparation and update unit tests

Enhanced the SiglipVisionModel by adding a method to prepare attention metadata, streamlining its usage. Updated unit tests to validate model performance and compatibility with Hugging Face outputs, including adjustments to the configuration and test scenarios.

Signed-off-by: qixiang-99 <203170375+qixiang-99@users.noreply.github.com>
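
A similar hedged sketch for Siglip, modeled on the updated unit test (import paths from the diff below; vision_use_head=False is passed as an extra config attribute as in the test, since the pooling head is not supported yet). Unlike CLIP, Siglip has no class token, so prepare_attn_metadata uses (image_size // patch_size)**2 tokens per image:

import torch
from transformers import SiglipVisionConfig
from transformers import SiglipVisionModel as HFSiglipVisionModel

from tensorrt_llm._torch.model_config import ModelConfig
from tensorrt_llm._torch.models.modeling_siglip import SiglipVisionModel

hf_config = SiglipVisionConfig(vision_use_head=False)  # HF defaults: 224x224 images, 16x16 patches
hf_model = HFSiglipVisionModel(hf_config).half().cuda().eval()

tllm_model = SiglipVisionModel(
    ModelConfig(pretrained_config=hf_config, attn_backend="TRTLLM")).half().cuda()
# HF keys are renamed via the regex mapping (out_proj/fc1/fc2 -> o_proj/up_proj/down_proj).
tllm_model.load_weights(hf_model.state_dict())

pixel_values = torch.rand(2, 3, hf_config.image_size, hf_config.image_size,
                          device="cuda", dtype=torch.float16)
attn_metadata = tllm_model.prepare_attn_metadata(batch_size=2)
hidden_states = tllm_model(pixel_values, attn_metadata=attn_metadata)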

* refactor: Remove unused rotary_emb parameter from CLIP and Siglip attention classes

Eliminated the rotary_emb parameter from the CLIPAttention and SiglipAttention classes to streamline the code. Updated unit tests to reflect changes in the model configurations, including clarifications in the default configurations sourced from Hugging Face.

Signed-off-by: qixiang-99 <203170375+qixiang-99@users.noreply.github.com>

* feat: Integrate CLIPVisionModel into LlavaNextInputProcessor and enhance weight loading

Added CLIPVisionModel to the LlavaNextInputProcessor for improved vision processing. Updated the model loading mechanism to ensure compatibility with the new vision model and added attention metadata preparation. Removed debug print statements from the weight renaming function for cleaner code.

Signed-off-by: qixiang-99 <203170375+qixiang-99@users.noreply.github.com>

* refactor: Remove unused max_position_embeddings from CLIPAttention and update Siglip classes to use CLIP components

Removed the unused max_position_embeddings variable from the CLIPAttention class. Updated the Siglip classes to utilize CLIP components, specifically replacing SiglipEncoder and SiglipAttention with their CLIP counterparts, streamlining the codebase and enhancing consistency across models.

Signed-off-by: qixiang-99 <203170375+qixiang-99@users.noreply.github.com>

* refactor: Consolidate weight loading logic into a shared implementation

Refactored the weight loading process across CLIP and Siglip models by using a new utility function, _load_weights_impl, to streamline the loading mechanism. This change enhances code maintainability and reduces redundancy in weight handling, ensuring consistent behavior across different model architectures.

Signed-off-by: qixiang-99 <203170375+qixiang-99@users.noreply.github.com>

* refactor: Simplify output handling in CLIP and Siglip models by removing output_hidden_states parameter

Removed the output_hidden_states parameter from the CLIPEncoder and SiglipVisionTransformer classes, streamlining the output handling process. Updated the corresponding unit tests to reflect these changes and ensure compatibility with the new output structure.

Signed-off-by: qixiang-99 <203170375+qixiang-99@users.noreply.github.com>
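
A hedged sketch of the processor's vision path with the simplified tuple output (the body mirrors the final _process in the diff below; the free-standing helper name is illustrative):

import torch

def encode_images(vision_tower, mm_projector, pixel_values: torch.Tensor) -> torch.Tensor:
    # Fill TRT-LLM attention metadata for this batch of images.
    attn_metadata = vision_tower.prepare_attn_metadata(pixel_values.shape[0])
    # The TRT-LLM CLIPVisionModel returns a tuple of per-layer hidden states rather
    # than an HF BaseModelOutput, so the penultimate layer is indexed directly.
    image_features = vision_tower(pixel_values, attn_metadata=attn_metadata)
    # [:, 1:] drops the class token before projection, as in the previous HF-based path.
    selected = image_features[-2][:, 1:]
    projected = mm_projector(selected)
    return projected.reshape(-1, projected.shape[-1])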

* feat: Enhance LlavaNextInputProcessor with dynamic model loading and memory optimization

Updated the LlavaNextInputProcessor to load model weights from a local path or directly from Hugging Face, and improved memory efficiency by instantiating and loading only the vision tower and multi-modal projector rather than the full model. Integrated the HF LlavaNextMultiModalProjector and adjusted weight loading for compatibility with the new architecture.

Signed-off-by: qixiang-99 <203170375+qixiang-99@users.noreply.github.com>
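
A hedged sketch of the partial-loading approach in isolation, mirroring the diff below; the checkpoint name is illustrative, and download_hf_model is assumed to resolve to tensorrt_llm.llmapi.utils as the relative import suggests:

import os

import torch.nn as nn
from transformers import AutoConfig, AutoModel
from transformers.modeling_utils import load_sharded_checkpoint
from transformers.models.llava_next.modeling_llava_next import LlavaNextMultiModalProjector

from tensorrt_llm.llmapi.utils import download_hf_model

model_path = "llava-hf/llava-v1.6-mistral-7b-hf"  # illustrative checkpoint
local_model_path = model_path if os.path.isdir(model_path) else download_hf_model(model_path)

hf_model_config = AutoConfig.from_pretrained(local_model_path)
# Instantiate only the two submodules that are needed and load their shards with
# strict=False, so the language-model weights are never materialized in memory.
module_dict = nn.ModuleDict({
    "vision_tower": AutoModel.from_config(hf_model_config.vision_config),
    "multi_modal_projector": LlavaNextMultiModalProjector(hf_model_config),
})
missing_keys, _ = load_sharded_checkpoint(module_dict, local_model_path, strict=False)
assert len(missing_keys) == 0, f"Missing keys: {missing_keys}"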

---------

Signed-off-by: qixiang-99 <203170375+qixiang-99@users.noreply.github.com>
Co-authored-by: Haohang Huang <31998628+symphonylyh@users.noreply.github.com>
qixiang-99 2025-05-02 14:13:47 -07:00 committed by GitHub
parent 906cddffb0
commit bf4f7ad744
7 changed files with 781 additions and 77 deletions


@@ -2,6 +2,7 @@ import transformers
from .modeling_auto import AutoModelForCausalLM
from .modeling_bert import BertForSequenceClassification
from .modeling_clip import CLIPVisionModel
from .modeling_deepseekv3 import DeepseekV3ForCausalLM
from .modeling_llama import LlamaForCausalLM
from .modeling_llava_next import LlavaNextModel
@@ -16,6 +17,7 @@ from .modeling_qwen2vl import Qwen2_5_VLModel, Qwen2VLModel
from .modeling_qwen3 import Qwen3ForCausalLM
from .modeling_qwen3_moe import Qwen3MoeForCausalLM
from .modeling_qwen_moe import Qwen2MoeForCausalLM
from .modeling_siglip import SiglipVisionModel
from .modeling_utils import get_model_architecture
from .modeling_vila import VilaModel
@@ -23,6 +25,7 @@ from .modeling_vila import VilaModel
__all__ = [
"AutoModelForCausalLM",
"BertForSequenceClassification",
"CLIPVisionModel",
"DeepseekV3ForCausalLM",
"LlamaForCausalLM",
"LlavaNextModel",
@@ -35,6 +38,7 @@ __all__ = [
"Qwen2ForProcessRewardModel",
"Qwen2ForRewardModel",
"Qwen2MoeForCausalLM",
"SiglipVisionModel",
"get_model_architecture",
"VilaModel",
"Qwen2VLModel",


@@ -0,0 +1,235 @@
from typing import Dict, Optional, Tuple, Union
import torch
import torch.nn as nn
from transformers.activations import ACT2FN
from transformers.modeling_outputs import BaseModelOutput
from transformers.modeling_utils import (get_parameter_device,
get_parameter_dtype)
from transformers.models.clip.configuration_clip import CLIPVisionConfig
from transformers.models.clip.modeling_clip import CLIPVisionEmbeddings
from ..attention_backend.interface import (AttentionMetadata,
PredefinedAttentionMask)
from ..attention_backend.utils import get_attention_backend
from ..model_config import ModelConfig
from ..modules.attention import Attention
from ..modules.mlp import MLP
from .modeling_utils import _load_weights_impl, register_auto_model
class CLIPAttention(Attention):
def __init__(self, model_config: ModelConfig[CLIPVisionConfig],
layer_idx: int):
config = model_config.pretrained_config
pos_embd_params = None
max_position_embeddings = None
# CLIP uses bias in attention QKV projections
bias = True
super().__init__(
hidden_size=config.hidden_size,
num_attention_heads=config.num_attention_heads,
num_key_value_heads=config.num_attention_heads, # CLIP uses MHA
max_position_embeddings=
max_position_embeddings, # does not matter for CLIP
bias=bias,
pos_embd_params=pos_embd_params,
layer_idx=layer_idx,
dtype=config.torch_dtype
if hasattr(config, 'torch_dtype') else torch.float32,
config=model_config,
)
class CLIPEncoderLayer(nn.Module):
def __init__(self, model_config: ModelConfig[CLIPVisionConfig],
layer_idx: int):
super().__init__()
config = model_config.pretrained_config
self.embed_dim = config.hidden_size
self.self_attn = CLIPAttention(model_config=model_config,
layer_idx=layer_idx)
self.layer_norm1 = nn.LayerNorm(self.embed_dim,
eps=config.layer_norm_eps)
self.mlp = MLP(
hidden_size=config.hidden_size,
intermediate_size=config.intermediate_size,
activation=ACT2FN[config.hidden_act],
bias=True, # CLIP MLP bias=True
config=model_config)
self.layer_norm2 = nn.LayerNorm(self.embed_dim,
eps=config.layer_norm_eps)
def forward(
self,
hidden_states: torch.Tensor,
attn_metadata: AttentionMetadata,
) -> Tuple[torch.FloatTensor]:
residual = hidden_states
hidden_states = self.layer_norm1(hidden_states)
hidden_states = self.self_attn(
position_ids=None, # CLIP doesn't use explicit position_ids here
hidden_states=hidden_states,
attn_metadata=attn_metadata,
attention_mask=PredefinedAttentionMask.
FULL # Always FULL for Vision
)
hidden_states = residual + hidden_states
residual = hidden_states
hidden_states = self.layer_norm2(hidden_states)
hidden_states = self.mlp(hidden_states)
hidden_states = residual + hidden_states
outputs = (hidden_states, )
return outputs
class CLIPEncoder(nn.Module):
def __init__(self, model_config: ModelConfig[CLIPVisionConfig]):
super().__init__()
config = model_config.pretrained_config
self.config = config # Keep HF config accessible
self.model_config = model_config # Keep TRT-LLM config accessible
self.layers = nn.ModuleList([
CLIPEncoderLayer(model_config, layer_idx)
for layer_idx in range(config.num_hidden_layers)
])
def forward(
self,
inputs_embeds,
attn_metadata: AttentionMetadata,
) -> Union[Tuple, BaseModelOutput]:
encoder_states = ()
hidden_states = inputs_embeds
for encoder_layer in self.layers:
# hidden_states is (batch_size * seq_len, embed_dim) because TRT-LLM Attention is applied to flattened tokens
# we want the output shape align with HF output shape (batch_size, seq_len, embed_dim)
encoder_states = encoder_states + (hidden_states.view(
attn_metadata.seq_lens.shape[0], attn_metadata.seq_lens[0],
-1), )
layer_outputs = encoder_layer(
hidden_states,
attn_metadata=attn_metadata,
)
hidden_states = layer_outputs[0]
# hidden_states is (batch_size * seq_len, embed_dim) because TRT-LLM Attention is applied to flattened tokens
# we want the output shape align with HF output shape (batch_size, seq_len, embed_dim)
encoder_states = encoder_states + (hidden_states.view(
attn_metadata.seq_lens.shape[0], attn_metadata.seq_lens[0], -1), )
return encoder_states
class CLIPVisionTransformer(nn.Module):
"""
This CLIPVisionTransformer is tailored for multimodal models that use CLIP as the vision encoder.
For example, it is different from the regular CLIPVisionTransformer in the sense that it does not return a pooled output.
"""
def __init__(self, model_config: ModelConfig[CLIPVisionConfig]):
super().__init__()
config = model_config.pretrained_config
self.config = config
embed_dim = config.hidden_size
# Use HF Embeddings
self.embeddings = CLIPVisionEmbeddings(config)
self.pre_layrnorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
self.encoder = CLIPEncoder(model_config)
def forward(
self,
pixel_values,
attn_metadata: AttentionMetadata,
interpolate_pos_encoding: Optional[bool] = False,
) -> Tuple[torch.Tensor]:
hidden_states = self.embeddings(
pixel_values, interpolate_pos_encoding=interpolate_pos_encoding)
hidden_states = self.pre_layrnorm(hidden_states)
# Reshape for TRT-LLM Attention: (batch * seq_len, hidden)
hidden_states = hidden_states.reshape(
hidden_states.shape[0] * hidden_states.shape[1],
hidden_states.shape[2])
encoder_outputs: Tuple[torch.Tensor] = self.encoder(
inputs_embeds=hidden_states,
attn_metadata=attn_metadata,
)
return encoder_outputs
@register_auto_model("CLIPVisionModel")
class CLIPVisionModel(nn.Module):
def __init__(self, model_config: ModelConfig[CLIPVisionConfig]):
super().__init__()
self.model_config = model_config
self.config = self.model_config.pretrained_config # HF Vision Config
self.vision_model = CLIPVisionTransformer(self.model_config)
self.metadata_cls = get_attention_backend(
model_config.attn_backend).Metadata
def prepare_attn_metadata(self, batch_size):
"""
To simplify the usage of the model, this function aims to fill the metadata for Attention
Call this function before forward pass
"""
seq_len = (self.config.image_size // self.config.patch_size)**2 + 1
request_ids = list(range(1, batch_size + 1))
prompt_lens = [seq_len] * batch_size
attn_metadata = self.metadata_cls(
seq_lens=torch.tensor([seq_len] * batch_size, dtype=torch.int),
num_contexts=batch_size,
max_num_requests=batch_size,
max_num_tokens=seq_len * batch_size,
kv_cache_manager=None,
request_ids=request_ids,
prompt_lens=prompt_lens,
)
attn_metadata.max_seq_len = seq_len * batch_size
attn_metadata.prepare()
return attn_metadata
@property
def dtype(self):
return get_parameter_dtype(self)
@property
def device(self):
return get_parameter_device(self)
@torch.inference_mode()
def forward(self,
pixel_values,
attn_metadata: AttentionMetadata,
interpolate_pos_encoding: Optional[bool] = False):
return self.vision_model(
pixel_values=pixel_values,
attn_metadata=attn_metadata,
interpolate_pos_encoding=interpolate_pos_encoding)
def load_weights(self, weights: Dict):
# Pattern mapping for CLIP based on Siglip's example and CLIP HF names
pattern_mapping = {
r'(.*?)self_attn\.out_proj(.*)': r'\1self_attn.o_proj\2',
r'(.*?)mlp\.fc1(.*)': r'\1mlp.up_proj\2',
r'(.*?)mlp\.fc2(.*)': r'\1mlp.down_proj\2',
}
_load_weights_impl(self, weights, pattern_mapping)


@@ -1,18 +1,25 @@
import copy
import os
from typing import Dict, List, Optional, Tuple
import torch
from transformers import (AutoModel, AutoProcessor, LlavaNextConfig,
LlavaNextForConditionalGeneration, PretrainedConfig,
PreTrainedModel)
import torch.nn as nn
from transformers import (AutoConfig, AutoModel, AutoProcessor, LlavaNextConfig,
PretrainedConfig, PreTrainedModel)
from transformers.modeling_utils import load_sharded_checkpoint
from transformers.models.llava_next.modeling_llava_next import \
LlavaNextMultiModalProjector
from ..._utils import nvtx_range
from ...inputs import (ExtraProcessedInputs, InputProcessor, TextPrompt,
register_input_processor)
from ...llmapi.utils import download_hf_model
from ...logger import logger
from ...sampling_params import SamplingParams
from ..attention_backend import AttentionMetadata
from ..model_config import ModelConfig
from .modeling_auto import AutoModelForCausalLM
from .modeling_clip import CLIPVisionModel
from .modeling_multimodal_utils import fuse_input_embeds
from .modeling_utils import ModelConfig, register_auto_model
@@ -25,11 +32,40 @@ class LlavaNextInputProcessor(InputProcessor):
use_fast=True)
self.model_config = model_config
model = LlavaNextForConditionalGeneration.from_pretrained(
model_path, torch_dtype=model_config.text_config.torch_dtype)
self.device = 'cuda'
self.vision_tower = model.vision_tower.vision_model.to(self.device)
self.mm_projector = model.multi_modal_projector.to(self.device)
# Determine the actual local path for model files
if os.path.isdir(model_path):
local_model_path = model_path
else:
local_model_path = download_hf_model(model_path)
# Partially load the model to reduce memory usage(Vision tower and multi-modal projector)
hf_model_config = AutoConfig.from_pretrained(local_model_path)
self.dtype = hf_model_config.text_config.torch_dtype
module_dict = nn.ModuleDict({
"vision_tower":
AutoModel.from_config(hf_model_config.vision_config),
"multi_modal_projector":
LlavaNextMultiModalProjector(hf_model_config)
})
missing_keys, _ = load_sharded_checkpoint(module_dict,
local_model_path,
strict=False)
assert len(missing_keys) == 0, f"Missing keys: {missing_keys}"
hf_vision_tower = module_dict["vision_tower"].to(self.dtype)
hf_mm_projector = module_dict["multi_modal_projector"].to(
self.dtype).to(self.device)
# Use TRTLLM vision tower(CLIPVisionModel)
vision_model_config = ModelConfig(
pretrained_config=model_config.vision_config, attn_backend="TRTLLM")
self.vision_tower = CLIPVisionModel(vision_model_config).to(
self.device).to(self.dtype)
self.vision_tower.load_weights(hf_vision_tower.state_dict())
# Use HF multi-modal projector
self.mm_projector = hf_mm_projector
@nvtx_range("[Vision] preprocess")
def _preprocess(self, images):
@@ -44,9 +80,13 @@ class LlavaNextInputProcessor(InputProcessor):
@nvtx_range("[Vision] process")
def _process(self, pixel_values):
image_features = self.vision_tower(pixel_values,
output_hidden_states=True)
selected_image_feature = image_features.hidden_states[-2][:, 1:]
attn_metadata = self.vision_tower.prepare_attn_metadata(
pixel_values.shape[0])
image_features: Tuple[torch.Tensor] = self.vision_tower(
pixel_values,
attn_metadata=attn_metadata,
)
selected_image_feature = image_features[-2][:, 1:]
image_features = self.mm_projector(selected_image_feature)
return image_features.reshape(-1, image_features.shape[-1])


@@ -0,0 +1,111 @@
from typing import Dict, Optional, Tuple
import torch
import torch.nn as nn
from transformers.modeling_utils import (get_parameter_device,
get_parameter_dtype)
from transformers.models.siglip.configuration_siglip import SiglipVisionConfig
from transformers.models.siglip.modeling_siglip import (SiglipVisionConfig,
SiglipVisionEmbeddings)
from ..attention_backend.interface import AttentionMetadata
from ..attention_backend.utils import get_attention_backend
from ..model_config import ModelConfig
from .modeling_clip import CLIPEncoder
from .modeling_utils import _load_weights_impl, register_auto_model
SiglipEncoder = CLIPEncoder
class SiglipVisionTransformer(nn.Module):
"""
This SiglipVisionTransformer is tailored for multimodal models that use Siglip as the vision encoder.
For example, it is different from the regular SiglipVisionTransformer in the sense that it does not return a pooled output.
"""
def __init__(self, model_config: ModelConfig[SiglipVisionConfig]):
super().__init__()
config = model_config.pretrained_config
self.config = config
self.embeddings = SiglipVisionEmbeddings(config)
self.encoder = SiglipEncoder(model_config)
if hasattr(config, "vision_use_head"):
assert not config.vision_use_head, "Currently, we only support vision_use_head = False"
def forward(
self,
pixel_values,
attn_metadata: AttentionMetadata,
interpolate_pos_encoding: Optional[bool] = False,
) -> Tuple[torch.Tensor]:
hidden_states = self.embeddings(
pixel_values, interpolate_pos_encoding=interpolate_pos_encoding)
# reshape hidden_states to (batch_size * seq_len, embed_dim)
hidden_states = hidden_states.reshape(
hidden_states.shape[0] * hidden_states.shape[1],
hidden_states.shape[2])
encoder_outputs: Tuple[torch.Tensor] = self.encoder(
inputs_embeds=hidden_states,
attn_metadata=attn_metadata,
)
return encoder_outputs
@register_auto_model("SiglipVisionModel")
class SiglipVisionModel(nn.Module):
def __init__(self, model_config: ModelConfig[SiglipVisionConfig]):
super().__init__()
self.config = model_config.pretrained_config
self.vision_model = SiglipVisionTransformer(model_config)
self.model_config = model_config
self.metadata_cls = get_attention_backend(
model_config.attn_backend).Metadata
def prepare_attn_metadata(self, batch_size):
"""
To simplify the usage of the model, this function aims to fill the metadata for Attention
Call this function before forward pass
"""
seq_len = (self.config.image_size // self.config.patch_size)**2
request_ids = list(range(1, batch_size + 1))
prompt_lens = [seq_len] * batch_size
attn_metadata = self.metadata_cls(
seq_lens=torch.tensor([seq_len] * batch_size, dtype=torch.int),
num_contexts=batch_size,
max_num_requests=batch_size,
max_num_tokens=seq_len * batch_size,
kv_cache_manager=None,
request_ids=request_ids,
prompt_lens=prompt_lens,
)
attn_metadata.max_seq_len = seq_len * batch_size
attn_metadata.prepare()
return attn_metadata
@property
def dtype(self):
return get_parameter_dtype(self)
@property
def device(self):
return get_parameter_device(self)
@torch.inference_mode()
def forward(self, pixel_values, attn_metadata: AttentionMetadata):
return self.vision_model(
pixel_values=pixel_values,
attn_metadata=attn_metadata,
)
def load_weights(self, weights: Dict):
pattern_mapping = {
r'(.*?)out_proj(.*)': r'\1o_proj\2',
r'(.*?)fc1(.*)': r'\1up_proj\2',
r'(.*?)fc2(.*)': r'\1down_proj\2',
}
_load_weights_impl(self, weights, pattern_mapping)


@@ -510,73 +510,7 @@ class DecoderModelForCausalLM(nn.Module,
)
def load_weights(self, weights: Dict):
tp_size = self.model_config.mapping.tp_size
head_dim = getattr(
self.config, "head_dim",
self.config.hidden_size // self.config.num_attention_heads)
def filter_weights(prefix, weights: Dict):
result = {}
for k, v in weights.items():
if k.startswith(prefix):
new_k = k[len(prefix) + 1:]
result[new_k] = v
return result
params_map = {
'qkv_proj': ['q_proj', 'k_proj', 'v_proj'],
'gate_up_proj': ['gate_proj', 'up_proj']
}
for name, module in tqdm(list(self.named_modules()),
desc="Loading weights"):
if len(module._parameters) > 0:
# skip load weights if tie word embeddings is enabled and layer is lm_head
if self.config.tie_word_embeddings and name.startswith(
"lm_head"):
continue
# Skip loading weights for embedding and lm_head if LoRA is enabled
if hasattr(self.model_config, 'lora_config'
) and self.model_config.lora_config is not None and (
name == "model.embed_tokens"
or name == "lm_head"):
continue
# Skip if parameter belongs to a missing layer
if missing_layer_parameter(name, self):
continue
names = name.split('.')
# WAR: better solution is that llama has its own load_weights function.
if names[-1] == 'next_layer_layernorm':
continue
if names[-1] in params_map:
module_weights = []
for new_name in params_map[names[-1]]:
fw = filter_weights('.'.join(names[:-1] + [new_name]),
weights)
if new_name in ['k_proj', 'v_proj']:
fw = {
k:
duplicate_kv_weight(
weight=v[:],
head_dim=head_dim,
tensor_parallel_size=tp_size)
if k in ["weight", "bias"] else v
for k, v in fw.items()
}
module_weights.append(fw)
module.load_weights(weights=module_weights)
else:
module_weights = filter_weights(name, weights)
if hasattr(module, 'load_weights'):
module.load_weights(weights=[module_weights])
else:
for n, p in module._parameters.items():
if p is not None:
p.data.copy_(module_weights[n][:])
_load_weights_impl(self, weights)
def infer_max_seq_len(self) -> int:
# Modified from tensorrt_llm/builder.py _init_max_seq_len
@@ -634,3 +568,129 @@ def get_model_architecture(
def support_pp(cls: Type) -> Type:
cls._supports_pp = True
return cls
def rename_weights_with_regex(pattern_mapping: Dict[str, str], weights: Dict):
"""
Rename weight keys according to regex pattern matching.
Args:
pattern_mapping: A dictionary mapping regex patterns to replacement strings. The key is HF name pattern, and the value is corresponding TRT-LLM name pattern.
The patterns will be used to match keys in the weights dict and replace
them according to the replacement string, which can use regex backreferences.
Example:
HF name: vision_model.encoder.layers.1.self_attn.out_proj.{weight,bias}
TRT-LLM name: vision_model.encoder.layers.1.self_attn.o_proj.{weight,bias}
Then the pattern_mapping could be:
pattern_mapping = {
r'(.*?)out_proj(.*)': r'\1o_proj\2'
}
weights: A dictionary of weights
Returns:
A dictionary of weights with renamed keys
"""
import re
# Create a new dictionary to store the renamed weights
renamed_weights = {}
# Keep track of keys that have been matched by a pattern
matched_keys = set()
# Process each key in the weights dictionary
for key in list(weights.keys()):
# Check each pattern for a match
for pattern, replacement in pattern_mapping.items():
if re.match(pattern, key):
# Create the new key by applying the regex replacement
new_key = re.sub(pattern, replacement, key)
# Store the weight with the new key
renamed_weights[new_key] = weights[key]
matched_keys.add(key)
break
# If the key wasn't matched by any pattern, keep it as is
if key not in matched_keys:
renamed_weights[key] = weights[key]
return renamed_weights
def _load_weights_impl(model: Union[nn.Module, DecoderModelForCausalLM],
weights: Dict,
params_map: Optional[Dict[str, str]] = None):
if not hasattr(model, 'model_config') or not isinstance(
model.model_config, ModelConfig):
raise ValueError("model must have a model_config attribute")
if not hasattr(model, 'config'):
raise ValueError("model must have a config attribute")
if params_map is not None:
weights = rename_weights_with_regex(params_map, weights)
logger.info(f"Renamed weights with params_map: {params_map}")
tp_size = model.model_config.mapping.tp_size
head_dim = getattr(
model.config, "head_dim",
model.config.hidden_size // model.config.num_attention_heads)
def filter_weights(prefix, weights: Dict):
result = {}
for k, v in weights.items():
if k.startswith(prefix):
new_k = k[len(prefix) + 1:]
result[new_k] = v
return result
params_map = {
'qkv_proj': ['q_proj', 'k_proj', 'v_proj'],
'gate_up_proj': ['gate_proj', 'up_proj']
}
for name, module in tqdm(list(model.named_modules()),
desc="Loading weights"):
if len(module._parameters) > 0:
# skip load weights if tie word embeddings is enabled and layer is lm_head
if model.config.tie_word_embeddings and name.startswith("lm_head"):
continue
# Skip loading weights for embedding and lm_head if LoRA is enabled
if hasattr(model.model_config, 'lora_config'
) and model.model_config.lora_config is not None and (
name == "model.embed_tokens" or name == "lm_head"):
continue
# Skip if parameter belongs to a missing layer
if missing_layer_parameter(name, model):
continue
names = name.split('.')
# WAR: better solution is that llama has its own load_weights function.
if names[-1] == 'next_layer_layernorm':
continue
if names[-1] in params_map:
module_weights = []
for new_name in params_map[names[-1]]:
fw = filter_weights('.'.join(names[:-1] + [new_name]),
weights)
if new_name in ['k_proj', 'v_proj']:
fw = {
k:
duplicate_kv_weight(weight=v[:],
head_dim=head_dim,
tensor_parallel_size=tp_size)
if k in ["weight", "bias"] else v
for k, v in fw.items()
}
module_weights.append(fw)
module.load_weights(weights=module_weights)
else:
module_weights = filter_weights(name, weights)
if hasattr(module, 'load_weights'):
module.load_weights(weights=[module_weights])
else:
for n, p in module._parameters.items():
if p is not None:
p.data.copy_(module_weights[n][:])


@@ -0,0 +1,129 @@
import unittest
from copy import deepcopy
from dataclasses import dataclass
import torch
from parameterized import parameterized
# Import CLIP specific classes from HF
from transformers import CLIPVisionConfig
from transformers import CLIPVisionModel as HFCLIPVisionModel
from tensorrt_llm._torch.model_config import ModelConfig
# Import TRT-LLM CLIP model
from tensorrt_llm._torch.models.modeling_clip import CLIPVisionModel
# Default CLIP config from HF (https://github.com/huggingface/transformers/blob/v4.51.3/src/transformers/models/clip/configuration_clip.py#L144-L172)
CLIP_CONFIG = {
"hidden_size": 768,
"intermediate_size": 3072,
"num_hidden_layers": 12,
"num_attention_heads": 12,
"num_channels": 3,
"image_size": 224,
"patch_size": 32,
"hidden_act": "quick_gelu",
"layer_norm_eps": 1e-5,
"attention_dropout": 0.0,
"initializer_range": 0.02,
"initializer_factor": 1.0,
}
ACCURACY_CONFIG = {
torch.float16: (1e-2, 1e-2),
}
@dataclass(repr=False)
class Scenario:
backend: str
num_images: int
dtype: torch.dtype # Add dtype to scenario
def __repr__(self) -> str:
return f"backend:{self.backend.lower()}_num_images:{self.num_images}_dtype:{self.dtype}"
class TestCLIPVisionModel(unittest.TestCase):
def setUp(self):
super().setUp()
torch.random.manual_seed(720)
@parameterized.expand([
Scenario(backend="VANILLA", num_images=2, dtype=torch.float16),
Scenario(backend="TRTLLM", num_images=2, dtype=torch.float16),
Scenario(backend="TRTLLM", num_images=21, dtype=torch.float16),
], lambda testcase_func, param_num, param:
f"{testcase_func.__name__}[{param.args[0]}]")
def test_clip_vision_allclose_to_hf(self, scenario: Scenario):
"""Compare output to HF"""
backend = scenario.backend
num_images = scenario.num_images
dtype = scenario.dtype
device = torch.device('cuda')
# Create configs
config_dict = deepcopy(CLIP_CONFIG)
hf_config = CLIPVisionConfig.from_dict(config_dict)
# Prepare HF model
hf_model = HFCLIPVisionModel(hf_config).to(dtype).to(device).eval()
# Prepare tllm pytorch model
model_config = ModelConfig(
pretrained_config=hf_config,
attn_backend=backend,
)
tllm_model = CLIPVisionModel(model_config).to(dtype).to(device)
# Use the load_weights method we are testing
tllm_model.load_weights(hf_model.state_dict())
# Prepare inputs - create random pixel values for images
batch_size = num_images
pixel_values = torch.rand(batch_size,
hf_config.num_channels,
hf_config.image_size,
hf_config.image_size,
device=device,
dtype=dtype)
# Run HF inference
with torch.inference_mode():
hf_outputs = hf_model(
pixel_values=pixel_values,
output_attentions=False,
output_hidden_states=True) # Get hidden states for comparison
# Run TRT-LLM inference
attn_metadata = tllm_model.prepare_attn_metadata(batch_size)
tllm_outputs = tllm_model(
pixel_values=pixel_values,
attn_metadata=attn_metadata,
)
# Compare outputs
rtol, atol = ACCURACY_CONFIG[dtype]
# Compare all hidden states
for i, (hf_hs, tllm_hs) in enumerate(
zip(hf_outputs.hidden_states,
tllm_outputs)): # Iterate through tllm_outputs directly
self.assertEqual(hf_hs.shape, tllm_hs.shape,
f"Shape mismatch for hidden state {i}")
torch.testing.assert_close(
hf_hs.float(),
tllm_hs.float(),
rtol=rtol,
atol=atol,
msg=
f"FAILED: TRT-LLM and HF hidden_states mismatch for {dtype} with {num_images} images at layer {i}"
)
print(
f"PASSED: TRT-LLM and HF hidden_states match for {dtype} with {num_images} images at layer {i}"
)
if __name__ == "__main__":
unittest.main()


@@ -0,0 +1,125 @@
import unittest
from copy import deepcopy
from dataclasses import dataclass
import torch
from parameterized import parameterized
from transformers import SiglipVisionConfig
from transformers import SiglipVisionModel as HFSiglipVisionModel
from tensorrt_llm._torch.model_config import ModelConfig
from tensorrt_llm._torch.models.modeling_siglip import SiglipVisionModel
# use the default config from HF (https://github.com/huggingface/transformers/blob/main/src/transformers/models/siglip/configuration_siglip.py#L126-L147)
SIGLIP_CONFIG = {
"hidden_size": 768,
"intermediate_size": 3072,
"num_hidden_layers": 12,
"num_attention_heads": 12,
"image_size": 224,
"patch_size": 16,
"hidden_act": "gelu_pytorch_tanh",
"layer_norm_eps": 1e-6,
"hidden_dropout_prob": 0.0,
"attention_probs_dropout_prob": 0.0,
"num_channels": 3,
"vision_use_head": False,
}
ACCURACY_CONFIG = {
torch.float16: (2e-2, 5e-2),
}
@dataclass(repr=False)
class Scenario:
backend: str
num_images: int
dtype: torch.dtype
def __repr__(self) -> str:
return f"backend:{self.backend.lower()}_num_images:{self.num_images}_dtype:{self.dtype}"
class TestSiglipVisionModel(unittest.TestCase):
def setUp(self):
super().setUp()
torch.random.manual_seed(1234)
@parameterized.expand([
Scenario(backend="VANILLA", num_images=2, dtype=torch.float16),
Scenario(backend="TRTLLM", num_images=2, dtype=torch.float16),
Scenario(backend="TRTLLM", num_images=21, dtype=torch.float16),
], lambda testcase_func, param_num, param:
f"{testcase_func.__name__}[{param.args[0]}]")
def test_siglip_vision_allclose_to_hf(self, scenario: Scenario):
"""Compare output to HF"""
backend = scenario.backend
num_images = scenario.num_images
dtype = scenario.dtype
device = torch.device('cuda')
# Create configs
config_dict = deepcopy(SIGLIP_CONFIG)
hf_config = SiglipVisionConfig.from_dict(config_dict)
# Prepare HF model
hf_model = HFSiglipVisionModel(hf_config).to(dtype).to(device).eval()
# Prepare tllm pytorch model
model_config = ModelConfig(
pretrained_config=hf_config,
attn_backend=backend,
)
tllm_model = SiglipVisionModel(model_config).to(dtype).to(device)
tllm_model.load_weights(hf_model.state_dict())
# Prepare inputs - create random pixel values for images
batch_size = num_images
pixel_values = torch.rand(batch_size,
hf_config.num_channels,
hf_config.image_size,
hf_config.image_size,
device=device,
dtype=dtype)
# Run HF inference
with torch.inference_mode():
# HF model forward
hf_outputs = hf_model(pixel_values=pixel_values,
output_attentions=False,
output_hidden_states=True)
# Fill the metadata for tllm attn
attn_metadata = tllm_model.prepare_attn_metadata(batch_size)
# TRT-LLM model forward
tllm_outputs = tllm_model(
pixel_values=pixel_values,
attn_metadata=attn_metadata,
)
# Compare all hidden states
for i, (hf_hs, tllm_hs) in enumerate(
zip(hf_outputs.hidden_states, tllm_outputs)):
self.assertEqual(hf_hs.shape, tllm_hs.shape,
f"Shape mismatch for hidden state {i}")
torch.testing.assert_close(
hf_hs.float(),
tllm_hs.float(),
rtol=ACCURACY_CONFIG[dtype][0],
atol=ACCURACY_CONFIG[dtype][1],
msg=
f"FAILED: TRT-LLM and HF hidden_states mismatch for {dtype} with {num_images} images at layer {i}, the mean value of this layer is {hf_hs.mean()}"
)
print(
f"PASSED: TRT-LLM and HF hidden_states match for {dtype} with {num_images} images at layer {i}, the mean value of this layer is {hf_hs.mean()}"
)
if __name__ == "__main__":
unittest.main()