Qwen cherrypick1

Signed-off-by: Faraz Khoubsirat <58580514+farazkh80@users.noreply.github.com>
This commit is contained in:
Faraz Khoubsirat 2025-09-10 02:18:49 +00:00
parent 1010e15bb5
commit 2c41fc54d2
No known key found for this signature in database
GPG Key ID: 15733A5323348457

View File

@ -7,7 +7,10 @@ import torch.nn as nn
from transformers import (AutoProcessor, AutoTokenizer, PretrainedConfig,
PreTrainedModel, Qwen2_5_VLForConditionalGeneration,
Qwen2VLForConditionalGeneration)
from transformers.models.qwen2_vl.image_processing_qwen2_vl import smart_resize
from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import \
Qwen2_5_VisionTransformerPretrainedModel
from transformers.models.qwen2_vl.modeling_qwen2_vl import \
Qwen2VisionTransformerPretrainedModel
from tensorrt_llm.inputs.multimodal import MultimodalParams
@ -358,14 +361,21 @@ class Qwen2VisionModelBase(nn.Module):
# TODO: Change the model class to TRT-LLM's Qwen2VisionModel
# Currently, copying vision encoder on all devices.
# NOTE: Using attn_implementation='flash_attention_2' to avoid the issue of vision model's GPU OOM.
model = model_class.from_pretrained(
model_path,
torch_dtype=pretrained_config.torch_dtype,
attn_implementation='flash_attention_2').eval()
hf_model_config = AutoConfig.from_pretrained(model_path)
vision_model = model_class(config=hf_model_config.vision_config,
torch_dtype=pretrained_config.torch_dtype,
attn_implementation='flash_attention_2')
# TODO: Make vision model compatible with meta init mode and load_weights at the same place
self.visual = model.visual.to(self.device)
self.visual = vision_model.to(self.device)
self.post_config()
def load_weights(self, weights):
filtered_weights = {
k.replace('visual.', ''): v
for k, v in weights.items() if k.startswith('visual.')
}
self.visual.load_state_dict(filtered_weights)
def post_config(self):
self.config = self.visual.config
@ -503,6 +513,7 @@ class Qwen2VLModelBase(PreTrainedModel):
def load_weights(self, weights):
self.llm.load_weights(weights)
self.mm_encoder.load_weights(weights)
self.init_rotary_cos_sin_ori()
def infer_max_seq_len(self) -> int:
@ -675,7 +686,7 @@ class Qwen2VLModel(Qwen2VLModelBase):
super().__init__(model_config, *args, **kwargs)
if not DISAGG:
self.mm_encoder = Qwen2VisionModelBase(
model_config, Qwen2VLForConditionalGeneration)
model_config, Qwen2VisionTransformerPretrainedModel)
@register_vision_encoder(Qwen2VisionModelBase,
@ -696,4 +707,4 @@ class Qwen2_5_VLModel(Qwen2VLModelBase):
super().__init__(model_config, *args, **kwargs)
if not DISAGG:
self.mm_encoder = Qwen2VisionModelBase(
model_config, Qwen2_5_VLForConditionalGeneration)
model_config, Qwen2_5_VisionTransformerPretrainedModel)