Mirror of https://github.com/NVIDIA/TensorRT-LLM.git (synced 2026-01-14 06:27:45 +08:00)
[None][feat] Refactor Llava-Next (#6478)

Signed-off-by: yechank <161688079+yechank-nvidia@users.noreply.github.com>

parent f92397493e
commit c17f4984e2
@@ -1,16 +1,19 @@
import copy
import os
from typing import Dict, List, Optional, Tuple
from typing import List, Optional, Tuple

import numpy as np
import torch
import torch.nn as nn
from transformers import (AutoConfig, AutoModel, AutoProcessor, AutoTokenizer,
                          LlavaNextConfig, PretrainedConfig, PreTrainedModel)
from transformers.modeling_utils import load_sharded_checkpoint
from transformers.models.llava_next.modeling_llava_next import \
    LlavaNextMultiModalProjector
from transformers.models.llava_next.modeling_llava_next import (
    LlavaNextMultiModalProjector, get_anyres_image_grid_shape,
    image_size_to_num_patches, unpad_image)

from tensorrt_llm.inputs.multimodal import MultimodalParams

from ..._utils import nvtx_range
from ...inputs import (ExtraProcessedInputs, InputProcessor, TextPrompt,
                       register_input_processor)
from ...llmapi.utils import download_hf_model
@@ -23,13 +26,15 @@ from .modeling_clip import CLIPVisionModel
from .modeling_multimodal_utils import fuse_input_embeds
from .modeling_utils import ModelConfig, filter_weights, register_auto_model

DISAGG = os.getenv('TLLM_MULTIMODAL_DISAGGREGATED', '0') == '1'


class LlavaNextInputProcessor(InputProcessor):

    def __init__(self,
                 model_path,
                 model_config,
                 tokenizer,
                 model_path: str,
                 model_config: PretrainedConfig,
                 tokenizer: AutoTokenizer,
                 trust_remote_code: bool = True):
        self.tokenizer = tokenizer
        self.use_fast = True
@@ -44,7 +49,50 @@ class LlavaNextInputProcessor(InputProcessor):
            use_fast=self.use_fast)
        self.model_config = model_config

        self.device = 'cuda'
        self.image_token_index = model_config.image_token_index
        self.vocab_size = model_config.vocab_size

    @torch.inference_mode()
    def __call__(
        self, inputs: TextPrompt, sampling_params: SamplingParams
    ) -> Tuple[List[int], Optional[ExtraProcessedInputs]]:
        text_prompt, mm_data = inputs.get("prompt"), inputs.get(
            "multi_modal_data", {})
        # Preprocess
        images = mm_data.get('image', [])
        if not images:
            return self.processor.tokenizer(
                text_prompt,
                return_tensors="pt").input_ids[0].to(torch.int32).tolist(), {}

        processed_values = self.processor(
            text=text_prompt,
            images=images,
            do_rescale=not (images and isinstance(images[0], torch.Tensor)),
            return_tensors="pt")
        # Postprocess
        fused_input_ids = processed_values['input_ids'][0]
        fused_input_ids[fused_input_ids ==
                        self.image_token_index] = self.vocab_size + 1

        multimodal_data = {}
        multimodal_data["image"] = {
            "pixel_values": processed_values['pixel_values'],
            "image_sizes": processed_values['image_sizes'],
        }
        return fused_input_ids.to(torch.int32).tolist(), {
            "multimodal_data": multimodal_data
        }


class LlavaNextVisionModel(nn.Module):

    def __init__(self, model_config: ModelConfig[PretrainedConfig], *args,
                 **kwargs) -> None:
        super().__init__()
        self.pretrained_config = model_config.pretrained_config
        self.device = f"cuda:{model_config.mapping.rank}"
        model_path = self.pretrained_config._name_or_path

        # Determine the actual local path for model files
        if os.path.isdir(model_path):
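A minimal sketch of the placeholder-remapping step in `__call__` above, using toy values rather than a real LlavaNext tokenizer; the `image_token_index` of 32000 and `vocab_size` of 32064 below are illustrative assumptions, not taken from this diff.

import torch

# Illustrative values; a real LlavaNext config provides these.
image_token_index = 32000
vocab_size = 32064

# Pretend tokenizer output where 32000 marks the image placeholder.
fused_input_ids = torch.tensor([1, 733, 32000, 28705, 2], dtype=torch.int64)

# Same remapping as in __call__ above: move the placeholder out of the
# text vocabulary so the rest of the pipeline can recognize it later.
fused_input_ids[fused_input_ids == image_token_index] = vocab_size + 1
print(fused_input_ids.tolist())  # [1, 733, 32065, 28705, 2]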
@@ -61,6 +109,10 @@ class LlavaNextInputProcessor(InputProcessor):
            "multi_modal_projector":
            LlavaNextMultiModalProjector(hf_model_config)
        })
        module_dict.register_parameter(
            "image_newline",
            nn.Parameter(torch.empty(hf_model_config.text_config.hidden_size)))

        missing_keys, _ = load_sharded_checkpoint(module_dict,
                                                  local_model_path,
                                                  strict=False)
@@ -68,6 +120,8 @@ class LlavaNextInputProcessor(InputProcessor):
        hf_vision_tower = module_dict["vision_tower"].to(self.dtype)
        hf_mm_projector = module_dict["multi_modal_projector"].to(
            self.dtype).to(self.device)
        hf_image_newline = module_dict.image_newline.to(self.dtype).to(
            self.device)

        # For A100 GPU, fallback to HF vision tower due to accuracy issue in TRT-LLM CLIPAttention
        # Otherwise, use TRTLLM vision tower(CLIPVisionModel)
@@ -78,7 +132,7 @@ class LlavaNextInputProcessor(InputProcessor):
            self.vision_tower = hf_vision_tower.to(self.device)
        else:
            vision_model_config = ModelConfig(
                pretrained_config=model_config.vision_config,
                pretrained_config=model_config.pretrained_config.vision_config,
                attn_backend="TRTLLM")
            self.vision_tower = CLIPVisionModel(vision_model_config).to(
                self.device).to(self.dtype)
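The A100 fallback mentioned above is gated elsewhere in the file (the check itself is not shown in these hunks). A hedged sketch of one way such a device check could look, assuming the decision is based on CUDA compute capability (A100-class GPUs report SM 8.0); the helper name and logic below are hypothetical, purely for illustration.

import torch

def prefer_hf_vision_tower(device_index: int = 0) -> bool:
    # Hypothetical helper, not part of this diff: fall back to the HF
    # vision tower on A100-class (SM 8.0) GPUs, otherwise use the
    # TRT-LLM CLIPVisionModel path.
    if not torch.cuda.is_available():
        return True
    major, minor = torch.cuda.get_device_capability(device_index)
    return (major, minor) == (8, 0)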
@@ -86,179 +140,126 @@ class LlavaNextInputProcessor(InputProcessor):

        # Use HF multi-modal projector
        self.mm_projector = hf_mm_projector
        self.image_newline = hf_image_newline
        self.vision_feature_select_strategy = getattr(
            model_config.pretrained_config, "vision_feature_select_strategy",
            "default")

    @nvtx_range("[Vision] preprocess")
    def _preprocess(self, images):
        return [
            self.processor(text="dummy",
                           images=image,
                           do_rescale=not isinstance(images[0], torch.Tensor),
                           return_tensors="pt",
                           device=self.device)['pixel_values'][0].to(
                               self.device) for image in images
    # Copied from https://github.com/huggingface/transformers/blob/main/src/transformers/models/llava_next/modeling_llava_next.py#L284
    def pack_image_features(self,
                            image_features,
                            image_sizes,
                            vision_feature_select_strategy,
                            image_newline=None):
        new_image_features = []
        feature_lens = []
        for image_idx, image_feature in enumerate(image_features):
            if image_feature.shape[0] > 1:
                base_image_feature = image_feature[0]
                image_feature = image_feature[1:]
                height = width = self.pretrained_config.vision_config.image_size // self.pretrained_config.vision_config.patch_size

                num_patch_height, num_patch_width = get_anyres_image_grid_shape(
                    image_sizes[image_idx],
                    self.pretrained_config.image_grid_pinpoints,
                    self.pretrained_config.vision_config.image_size,
                )

                if (np.prod(image_feature.shape) %
                    (num_patch_height * num_patch_width * height * width) != 0
                        and vision_feature_select_strategy == "default"):
                    logger.warning_once(
                        "Image feature shape does not line up with the provided patch size. "
                        "You may be using the `default` vision_feature_select_strategy with a"
                        " visual encoder that does not have CLS.")

                image_feature = image_feature.view(num_patch_height,
                                                   num_patch_width, height,
                                                   width, -1)
                image_feature = image_feature.permute(4, 0, 2, 1,
                                                      3).contiguous()
                image_feature = image_feature.flatten(1, 2).flatten(2, 3)
                image_feature = unpad_image(image_feature,
                                            image_sizes[image_idx])
                if image_newline is not None:
                    image_feature = torch.cat(
                        (
                            image_feature,
                            image_newline[:, None, None].expand(
                                *image_feature.shape[:-1], 1).to(
                                    image_feature.device, image_feature.dtype),
                        ),
                        dim=-1,
                    )
                image_feature = image_feature.flatten(1, 2).transpose(0, 1)
                image_feature = torch.cat((base_image_feature, image_feature),
                                          dim=0)
            else:
                image_feature = image_feature[0]
                if image_newline is not None:
                    image_feature = torch.cat(
                        (image_feature, image_newline[None].to(image_feature)),
                        dim=0)
            new_image_features.append(image_feature)
            feature_lens.append(image_feature.size(0))
        feature_lens = torch.tensor(feature_lens,
                                    dtype=torch.long,
                                    device=image_features[0].device)
        return new_image_features, feature_lens
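# A small worked example of the patch-grid arithmetic in pack_image_features
# above, using the common CLIP-L/336 numbers (image_size=336, patch_size=14)
# as assumed inputs; these values are illustrative, not read from this diff.
image_size = 336
patch_size = 14
height = width = image_size // patch_size   # 24, as in pack_image_features
tokens_per_tile = height * width            # 576 features per image tile
# A 2x2 any-res grid plus the base image yields (2 * 2 + 1) * 576 = 2880
# features before unpadding and newline tokens are applied.
num_patch_height, num_patch_width = 2, 2
total_features = (num_patch_height * num_patch_width + 1) * tokens_per_tile
print(height, tokens_per_tile, total_features)  # 24 576 2880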
    @torch.inference_mode()
    def forward(self, multimodal_params: List[MultimodalParams]):
        pixel_values = [
            multimodal_param.multimodal_data["image"]["pixel_values"]
            for multimodal_param in multimodal_params
        ]
        image_sizes = [
            multimodal_param.multimodal_data["image"]["image_sizes"]
            for multimodal_param in multimodal_params
        ]
        pixel_values = torch.cat(pixel_values, dim=0)
        image_sizes = torch.cat(image_sizes, dim=0)

        image_num_patches = [
            image_size_to_num_patches(
                image_size=imsize,
                grid_pinpoints=self.pretrained_config.image_grid_pinpoints,
                patch_size=self.pretrained_config.vision_config.image_size,
            ) for imsize in image_sizes
        ]

    @nvtx_range("[Vision] process")
    def _process(self, pixel_values):
        if pixel_values.dim() == 5:
            # stacked if input is (batch_size, num_patches, num_channels, height, width)
            _pixel_values_list = [
                pix_val[:num_patch]
                for pix_val, num_patch in zip(pixel_values, image_num_patches)
            ]
            pixel_values = torch.cat(_pixel_values_list, dim=0)

        if self.use_hf_vision_tower:
            image_features = self.vision_tower(
                pixel_values, output_hidden_states=True).hidden_states
        else:
            attn_metadata = self.vision_tower.prepare_attn_metadata(
                pixel_values.shape[0])
            image_features: Tuple[torch.Tensor] = self.vision_tower(
            image_features = self.vision_tower(
                pixel_values,
                attn_metadata=attn_metadata,
            )
        selected_image_feature = image_features[-2][:, 1:]
        image_features = self.mm_projector(selected_image_feature)
        return image_features.reshape(-1, image_features.shape[-1])

    @nvtx_range("[Vision] postprocess")
    def _postprocess(self, input_ids, mm_features):
        # Define model specific variables here before shared logic
        mm_tokens = torch.tensor([self.model_config.image_token_index
                                  ]).to(input_ids.device)
        model_hidden_size = self.model_config.text_config.hidden_size
        vocab_size = self.model_config.text_config.vocab_size
        start_len = end_len = 0  # for llava, need not append start/end token around each image token
        # End model specific variables
        image_features = torch.split(image_features, image_num_patches, dim=0)

        ## find mm token positions in input_ids
        mm_token_positions = torch.where(torch.isin(input_ids, mm_tokens))[0]
        num_medias = num_mm_tokens = len(mm_token_positions)
        if num_medias > 1 and isinstance(mm_features, torch.Tensor):
            mm_features = list(
                mm_features.split(mm_features.shape[0] // num_medias))

        if isinstance(mm_features, torch.Tensor):
            # 1 prompt + 1 media
            # "split" means what a single mm_token in the input_ids should represent
            # image: one split --> one frame
            # video: one split --> N frames
            num_frames, mm_feature_length, mm_hidden_dim = mm_features.shape
            mm_lengths_per_split = [mm_feature_length * num_frames]
            mm_lengths_per_frame = [mm_feature_length]
        elif isinstance(mm_features, list):
            # 1 prompt + N media
            num_frames = len(mm_features) if mm_features[0].dim() == 2 else sum(
                [f.shape[0] for f in mm_features])
            mm_lengths_per_split = [
                f.shape[0] if f.dim() == 2 else f.shape[0] * f.shape[1]
                for f in mm_features
            ]
            mm_lengths_per_frame = [
                f.shape[0] if f.dim() == 2 else f.shape[1] for f in mm_features
            ]
            mm_hidden_dim = mm_features[0].shape[-1]
            mm_features = torch.cat(mm_features, dim=0)
        else:
            raise ValueError(
                f"Invalid multimodal features type: {type(mm_features)}")
        mm_total_length = sum(mm_lengths_per_split)
        assert mm_hidden_dim == model_hidden_size, "Multimodal embedding_dim must match model hidden_size"

        ## split input_ids into segments by isolating mm tokens
        mm_split_positions = torch.cat(
            [mm_token_positions, mm_token_positions + 1]).unique()
        input_ids_splits = list(input_ids.tensor_split(mm_split_positions.cpu(
        )))  # len(input_ids_splits) = num_segments after mm tokens are isolated
        mm_ids_splits = list(
            torch.arange(vocab_size,
                         vocab_size + mm_total_length,
                         device=input_ids.device).split(mm_lengths_per_split)
        )  # len(mm_ids_splits) = num_mm_segments

        for i, mm_ids in enumerate(mm_ids_splits):
            mm_ids = mm_ids.reshape(-1, mm_lengths_per_frame[i])
            mm_ids_splits[i] = mm_ids.flatten()

        ## replace mm token ids with the expanded out-of-vocab ids
        mm_split_idx = 0
        for i, split in enumerate(input_ids_splits):
            if torch.isin(split, mm_tokens).any().item():
                input_ids_splits[i] = mm_ids_splits[mm_split_idx]
                mm_split_idx += 1
        assert mm_split_idx == len(
            mm_ids_splits), "All mm_ids_splits should be consumed"

        ## concat text & mm input_ids, wrap mm feature in prompt tuning config
        fused_input_ids = torch.cat(input_ids_splits).to(
            device=input_ids.device)
        fused_length = len(input_ids) + mm_total_length + num_frames * (
            start_len + end_len) - num_medias
        assert len(
            fused_input_ids
        ) == fused_length, f"Fused input_ids length {len(fused_input_ids)} should match the sum of text and multimodal embedding lengths {fused_length}"

        # [num_frames, feature_length, hidden_dim] -> [num_frames * feature_length, hidden_dim]
        mm_features = mm_features.view(-1, mm_features.shape[-1])
        return fused_input_ids, mm_features
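# A self-contained toy version of the splitting logic in _postprocess above,
# with a made-up 10-token vocabulary and a single 3-feature image, showing how
# the placeholder token is isolated and replaced by a contiguous
# out-of-vocab id range (all values here are illustrative).
import torch

toy_vocab_size = 10                        # toy text vocabulary size
toy_mm_tokens = torch.tensor([7])          # toy image placeholder id
toy_input_ids = torch.tensor([1, 2, 7, 3])
toy_mm_total_length = 3                    # one image -> 3 multimodal features

# Isolate the placeholder token exactly as _postprocess does.
positions = torch.where(torch.isin(toy_input_ids, toy_mm_tokens))[0]
split_points = torch.cat([positions, positions + 1]).unique()
segments = list(toy_input_ids.tensor_split(split_points))

# Replace the placeholder segment with out-of-vocab ids [10, 11, 12].
oov_ids = torch.arange(toy_vocab_size, toy_vocab_size + toy_mm_total_length)
segments = [
    oov_ids if torch.isin(seg, toy_mm_tokens).any() else seg for seg in segments
]
print(torch.cat(segments).tolist())        # [1, 2, 10, 11, 12, 3]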
    def attach_multimodal_embeddings(
        self, inputs: TextPrompt,
        multimodal_embedding: Dict[str, List[torch.Tensor]],
        sampling_params: SamplingParams
    ) -> Tuple[List[int], Optional[ExtraProcessedInputs]]:
        """
        Attach pre-processed multimodal embeddings into text token stream for LlavaNext model.

        This method skips vision processing and works with externally provided embeddings.
        It replaces/expands image placeholders in the text with appropriate tokens and prepares
        the embeddings for model forward pass.

        Args:
            inputs: Text prompt containing image placeholders
            multimodal_embedding: Dictionary containing pre-processed image embedding data
        Returns:
            Tuple of (token_ids, extra_processed_inputs) where:
            - token_ids: List of processed token IDs with image placeholders
            - extra_processed_inputs: Optional dictionary containing multimodal embeddings
        """
        text_prompt = inputs.get("prompt")
        if not text_prompt:
            raise ValueError("Text prompt is required but not provided")

        if not isinstance(multimodal_embedding, dict):
            raise ValueError("multimodal_embedding must be a dictionary")

        if 'image' not in multimodal_embedding:
            raise ValueError(
                "Only image modality is supported for external multimodal embedding"
            )

        input_ids = self.tokenizer(
            text_prompt, return_tensors="pt").input_ids[0].to(self.device)
        mm_features = torch.stack(multimodal_embedding['image'])
        fused_input_ids, mm_features = self._postprocess(input_ids, mm_features)
        multimodal_data = {}
        multimodal_data["multimodal_embedding"] = mm_features
        return fused_input_ids.to(torch.int32).tolist(), {
            "multimodal_data": multimodal_data
        }

    @torch.inference_mode()
    def __call__(
        self, inputs: TextPrompt, sampling_params: SamplingParams
    ) -> Tuple[List[int], Optional[ExtraProcessedInputs]]:
        text_prompt, mm_data = inputs.get("prompt"), inputs.get(
            "multi_modal_data", {})

        input_ids = self.tokenizer(
            text_prompt, return_tensors="pt").input_ids[0].to(self.device)

        if not mm_data:
            return input_ids.to(torch.int32).tolist(), {}

        mm_tensor = self._preprocess(mm_data['image'])
        mm_features = torch.stack(
            [self._process(tensor) for tensor in mm_tensor])
        fused_input_ids, mm_features = self._postprocess(input_ids, mm_features)
        multimodal_data = {}
        multimodal_data["multimodal_embedding"] = mm_features
        return fused_input_ids.to(torch.int32).tolist(), {
            "multimodal_data": multimodal_data
        }
        # NOTE: 'pack_image_features' is directly copied from the HF's code
        image_features, feature_lens = self.pack_image_features(
            image_features,
            image_sizes,
            vision_feature_select_strategy=self.vision_feature_select_strategy,
            image_newline=self.image_newline,
        )
        image_features = torch.cat(image_features, dim=0)
        return [image_features]


@register_auto_model("LlavaNextForConditionalGeneration")
@@ -272,6 +273,8 @@ class LlavaNextModel(PreTrainedModel):
        super().__init__(config)
        if hasattr(self, "llm"):
            return
        if not DISAGG:
            self.mm_encoder = LlavaNextVisionModel(model_config)

        llm_model_config = copy.deepcopy(model_config)
        llm_model_config.pretrained_config = model_config.pretrained_config.text_config
@@ -293,7 +296,6 @@ class LlavaNextModel(PreTrainedModel):
        self.is_loaded = True

    def load_weights(self, weights):

        weights = filter_weights("language_model", weights)
        self.llm.load_weights(weights)
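`filter_weights` is imported from `.modeling_utils` in the first hunk, but its body is not shown in this diff. A toy stand-in illustrating the effect it is expected to have in `load_weights` above, under the assumption that it selects the `language_model.*` entries and strips the prefix; the helper name and behavior are assumptions, not the TRT-LLM implementation.

import torch

def filter_weights_toy(prefix: str, weights: dict) -> dict:
    # Hypothetical stand-in for modeling_utils.filter_weights: keep only the
    # entries under `prefix.` and strip that prefix so the LLM can load them.
    return {
        k[len(prefix) + 1:]: v
        for k, v in weights.items() if k.startswith(prefix + ".")
    }

weights = {
    "language_model.lm_head.weight": torch.zeros(2, 2),
    "vision_tower.embeddings.weight": torch.zeros(2, 2),
}
print(list(filter_weights_toy("language_model", weights)))  # ['lm_head.weight']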
@@ -320,13 +322,16 @@ class LlavaNextModel(PreTrainedModel):
        multimodal_params = kwargs.get("multimodal_params", [])
        mm_embeds = []
        if len(multimodal_params) > 0:
            mm_embeds = [
                multimodal_param.multimodal_data["multimodal_embedding"]
                for multimodal_param in multimodal_params
            ]

            if not DISAGG:
                mm_embeds = self.mm_encoder.forward(multimodal_params)
            else:
                mm_embeds = [
                    multimodal_param.multimodal_data["multimodal_embedding"]
                    for multimodal_param in multimodal_params
                ]
        input_ids, inputs_embeds = fuse_input_embeds(
            self.llm.model.embed_tokens, input_ids, mm_embeds)

        logits = self.llm.forward(attn_metadata, input_ids, position_ids,
                                  inputs_embeds, return_context_logits)
        return logits
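For reference, a hedged sketch of the prompt structure the refactored input processor consumes, based only on the keys its `__call__` reads ("prompt" and "multi_modal_data" -> "image"); the chat template, in-memory image, and the variable names `input_processor`/`sampling_params` below are illustrative placeholders, not taken from this diff.

from PIL import Image

# Stand-in for a real photo; a real run would pass actual image data.
image = Image.new("RGB", (336, 336), color="gray")

inputs = {
    "prompt": "[INST] <image>\nDescribe the picture. [/INST]",
    "multi_modal_data": {"image": [image]},
}
# token_ids, extra = input_processor(inputs, sampling_params)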
@@ -2150,10 +2150,7 @@ def test_ptp_quickstart_multimodal(llm_root, llm_venv, model_name, model_path,
    "llava-v1.6-mistral-7b": {
        "image": [
            ["ocean", "sky", "large", "waves", "shore", "blue"],
            [
                "landscape", "rock", "landmark", "formation", "smooth",
                "mountain"
            ],
            ['mountain', 'flat', 'dome', 'formation', 'sky'],
            ["highway", "vehicles", "traffic", "bus", "suburban"],
        ],
    },