Mirror of https://github.com/NVIDIA/TensorRT-LLM.git, synced 2026-02-15 23:44:02 +08:00
Qwen cherrypick2

Signed-off-by: Faraz Khoubsirat <58580514+farazkh80@users.noreply.github.com>

parent 2c41fc54d2
commit 3fcfab0883
@@ -4,8 +4,9 @@ from typing import Any, Dict, List, Optional, Tuple, Union
 
 import torch
 import torch.nn as nn
-from transformers import (AutoProcessor, AutoTokenizer, PretrainedConfig,
-                          PreTrainedModel, Qwen2_5_VLForConditionalGeneration,
+from transformers import (AutoConfig, AutoProcessor, AutoTokenizer,
+                          PretrainedConfig, PreTrainedModel,
+                          Qwen2_5_VLForConditionalGeneration,
                           Qwen2VLForConditionalGeneration)
 from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import \
     Qwen2_5_VisionTransformerPretrainedModel
@@ -362,9 +363,10 @@ class Qwen2VisionModelBase(nn.Module):
         # Currently, copying vision encoder on all devices.
         # NOTE: Using attn_implementation='flash_attention_2' to avoid the issue of vision model's GPU OOM.
         hf_model_config = AutoConfig.from_pretrained(model_path)
-        vision_model = model_class(config=hf_model_config.vision_config,
-                                   torch_dtype=pretrained_config.torch_dtype,
-                                   attn_implementation='flash_attention_2')
+        vision_model = model_class._from_config(
+            hf_model_config.vision_config,
+            torch_dtype=pretrained_config.torch_dtype,
+            attn_implementation='flash_attention_2')
         # TODO: Make vision model compatible with meta init mode and load_weights at the same place
         self.visual = vision_model.to(self.device)
         self.post_config()
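The functional change in this hunk is swapping the plain constructor for the PreTrainedModel._from_config classmethod. As a rough illustration (not part of the commit), the sketch below shows the pattern standalone; the checkpoint name is illustrative, and it assumes a transformers version with attn_implementation support (>= 4.36) plus flash-attn installed.

    import torch
    from transformers import AutoConfig
    from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import \
        Qwen2_5_VisionTransformerPretrainedModel

    # Assumption: any Qwen2.5-VL checkpoint (local path or hub id) works here.
    model_path = "Qwen/Qwen2.5-VL-7B-Instruct"
    hf_model_config = AutoConfig.from_pretrained(model_path)

    # _from_config builds the module from the config alone (no weight loading)
    # and applies torch_dtype and the attention backend at construction time;
    # a plain model_class(config=...) call does not process those kwargs.
    vision_model = Qwen2_5_VisionTransformerPretrainedModel._from_config(
        hf_model_config.vision_config,
        torch_dtype=torch.bfloat16,
        attn_implementation='flash_attention_2')

Weight loading then happens as a separate step, which appears to be what the TODO in the hunk refers to.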