Mirror of https://github.com/NVIDIA/TensorRT-LLM.git, synced 2026-02-16 07:53:55 +08:00
Qwen cherrypick1
Signed-off-by: Faraz Khoubsirat <58580514+farazkh80@users.noreply.github.com>
This commit is contained in:
parent 1010e15bb5
commit 2c41fc54d2
@@ -7,7 +7,10 @@ import torch.nn as nn
 from transformers import (AutoProcessor, AutoTokenizer, PretrainedConfig,
                           PreTrainedModel, Qwen2_5_VLForConditionalGeneration,
                           Qwen2VLForConditionalGeneration)
 from transformers.models.qwen2_vl.image_processing_qwen2_vl import smart_resize
+from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import \
+    Qwen2_5_VisionTransformerPretrainedModel
+from transformers.models.qwen2_vl.modeling_qwen2_vl import \
+    Qwen2VisionTransformerPretrainedModel
 
 from tensorrt_llm.inputs.multimodal import MultimodalParams
 
@@ -358,14 +361,21 @@ class Qwen2VisionModelBase(nn.Module):
         # TODO: Change the model class to TRT-LLM's Qwen2VisionModel
         # Currently, copying vision encoder on all devices.
         # NOTE: Using attn_implementation='flash_attention_2' to avoid the issue of vision model's GPU OOM.
-        model = model_class.from_pretrained(
-            model_path,
-            torch_dtype=pretrained_config.torch_dtype,
-            attn_implementation='flash_attention_2').eval()
+        hf_model_config = AutoConfig.from_pretrained(model_path)
+        vision_model = model_class(config=hf_model_config.vision_config,
+                                   torch_dtype=pretrained_config.torch_dtype,
+                                   attn_implementation='flash_attention_2')
+        # TODO: Make vision model compatible with meta init mode and load_weights at the same place
-        self.visual = model.visual.to(self.device)
+        self.visual = vision_model.to(self.device)
         self.post_config()
 
+    def load_weights(self, weights):
+        filtered_weights = {
+            k.replace('visual.', ''): v
+            for k, v in weights.items() if k.startswith('visual.')
+        }
+        self.visual.load_state_dict(filtered_weights)
+
     def post_config(self):
         self.config = self.visual.config
 
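For readers who want to see it in isolation, here is a minimal standalone sketch of the prefix-filtering pattern that the new Qwen2VisionModelBase.load_weights uses. The tiny module and checkpoint keys below are illustrative stand-ins, not the real Qwen2-VL layout:

import torch
import torch.nn as nn

# Toy vision module; the real one is the HF Qwen2/2.5 vision transformer.
class TinyVisual(nn.Module):
    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(4, 4)

visual = TinyVisual()

# A full-model checkpoint mixes LLM and vision weights; only keys under the
# 'visual.' prefix belong to the vision encoder.
full_checkpoint = {
    'visual.proj.weight': torch.zeros(4, 4),
    'visual.proj.bias': torch.zeros(4),
    'model.layers.0.mlp.up_proj.weight': torch.zeros(8, 4),  # LLM weight, skipped
}

filtered = {
    k.replace('visual.', ''): v
    for k, v in full_checkpoint.items() if k.startswith('visual.')
}
visual.load_state_dict(filtered)  # loads only the vision-encoder weights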
@@ -503,6 +513,7 @@ class Qwen2VLModelBase(PreTrainedModel):
 
     def load_weights(self, weights):
         self.llm.load_weights(weights)
+        self.mm_encoder.load_weights(weights)
         self.init_rotary_cos_sin_ori()
 
     def infer_max_seq_len(self) -> int:
@@ -675,7 +686,7 @@ class Qwen2VLModel(Qwen2VLModelBase):
         super().__init__(model_config, *args, **kwargs)
         if not DISAGG:
             self.mm_encoder = Qwen2VisionModelBase(
-                model_config, Qwen2VLForConditionalGeneration)
+                model_config, Qwen2VisionTransformerPretrainedModel)
 
 
 @register_vision_encoder(Qwen2VisionModelBase,
@@ -696,4 +707,4 @@ class Qwen2_5_VLModel(Qwen2VLModelBase):
         super().__init__(model_config, *args, **kwargs)
         if not DISAGG:
             self.mm_encoder = Qwen2VisionModelBase(
-                model_config, Qwen2_5_VLForConditionalGeneration)
+                model_config, Qwen2_5_VisionTransformerPretrainedModel)
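For context, a minimal sketch of the construction pattern these last two hunks switch to: instantiate only the Hugging Face vision tower from the checkpoint's vision sub-config instead of materializing the full *ForConditionalGeneration model; weights are then loaded separately through the filtered load_weights shown above. The model id below is an illustrative assumption, and fetching the config requires network access:

from transformers import AutoConfig
from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import \
    Qwen2_5_VisionTransformerPretrainedModel

# Illustrative checkpoint id; any Qwen2.5-VL model with a vision_config works.
model_path = 'Qwen/Qwen2.5-VL-7B-Instruct'

hf_config = AutoConfig.from_pretrained(model_path)
# Build just the vision transformer; no language-model parameters are allocated.
vision_model = Qwen2_5_VisionTransformerPretrainedModel(
    config=hf_config.vision_config).eval()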