Qwen cherrypick2

Signed-off-by: Faraz Khoubsirat <58580514+farazkh80@users.noreply.github.com>
Faraz Khoubsirat 2025-09-10 02:19:06 +00:00
parent 2c41fc54d2
commit 3fcfab0883

@@ -4,8 +4,9 @@ from typing import Any, Dict, List, Optional, Tuple, Union
 
 import torch
 import torch.nn as nn
-from transformers import (AutoProcessor, AutoTokenizer, PretrainedConfig,
-                          PreTrainedModel, Qwen2_5_VLForConditionalGeneration,
+from transformers import (AutoConfig, AutoProcessor, AutoTokenizer,
+                          PretrainedConfig, PreTrainedModel,
+                          Qwen2_5_VLForConditionalGeneration,
                           Qwen2VLForConditionalGeneration)
 from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import \
     Qwen2_5_VisionTransformerPretrainedModel
@@ -362,9 +363,10 @@ class Qwen2VisionModelBase(nn.Module):
         # Currently, copying vision encoder on all devices.
         # NOTE: Using attn_implementation='flash_attention_2' to avoid the issue of vision model's GPU OOM.
         hf_model_config = AutoConfig.from_pretrained(model_path)
-        vision_model = model_class(config=hf_model_config.vision_config,
-                                   torch_dtype=pretrained_config.torch_dtype,
-                                   attn_implementation='flash_attention_2')
+        vision_model = model_class._from_config(
+            hf_model_config.vision_config,
+            torch_dtype=pretrained_config.torch_dtype,
+            attn_implementation='flash_attention_2')
         # TODO: Make vision model compatible with meta init mode and load_weights at the same place
         self.visual = vision_model.to(self.device)
         self.post_config()
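
For context on the second hunk: calling model_class(config) builds the module in the default dtype and attention backend, while the _from_config classmethod on transformers' PreTrainedModel applies torch_dtype and attn_implementation during construction. A minimal standalone sketch of the same pattern, using a hypothetical public checkpoint name in place of this repo's model_path:

import torch
from transformers import AutoConfig
from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import \
    Qwen2_5_VisionTransformerPretrainedModel

# Illustrative checkpoint name; the diff above resolves this from model_path.
hf_model_config = AutoConfig.from_pretrained("Qwen/Qwen2.5-VL-7B-Instruct")

# _from_config builds the model from the config alone (randomly initialized;
# weights are loaded separately, per the TODO above) and honors the requested
# dtype and attention backend, so the vision tower never materializes in fp32.
# flash_attention_2 requires the flash-attn package to be installed.
vision_model = Qwen2_5_VisionTransformerPretrainedModel._from_config(
    hf_model_config.vision_config,
    torch_dtype=torch.bfloat16,
    attn_implementation='flash_attention_2')

This keeps the NOTE's flash_attention_2 workaround intact while routing the dtype through construction, which is presumably why the call site switched to _from_config.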