"""Multimodal utilities for handling images and other media types in TensorRT-LLM.""" from dataclasses import dataclass, field from typing import Any, Dict, List, Optional, Tuple, Union import numpy as np import PIL import torch from blake3 import blake3 from torchvision.transforms import ToPILImage # Default hasher default_hasher = blake3 @dataclass class MultimodalInput: multimodal_hashes: List[List[int]] """Hash values for multimodal data items (e.g., images). Each element is a list of 8 integers representing the hash digest of a multimodal item. """ multimodal_positions: List[int] """Starting positions of each multimodal chunk in the token sequence. Contains only the start position of each chunk, not all positions of multimodal tokens. This is different from mm_positions elsewhere which contains all positions. """ multimodal_lengths: List[int] """Length (number of tokens) of each multimodal item. Combined with multimodal_positions, this defines the token spans for each multimodal item. """ def __post_init__(self): """Validate input data structure and consistency.""" # Validate multimodal_hashes if not isinstance(self.multimodal_hashes, list): raise TypeError("multimodal_hashes must be a list") # Check that hashes are lists of consistent length containing integers if not all(isinstance(h, list) for h in self.multimodal_hashes): raise TypeError("Each element in multimodal_hashes must be a list") # Check consistent length of hash arrays hash_lengths = [len(h) for h in self.multimodal_hashes] if min(hash_lengths) != max(hash_lengths): raise ValueError( f"All hash arrays must have the same length, got lengths: {hash_lengths}" ) # Check that positions and lengths are valid if not all(isinstance(x, int) for x in self.multimodal_positions): raise TypeError("multimodal_positions must contain only integers") if not all(isinstance(x, int) for x in self.multimodal_lengths): raise TypeError("multimodal_lengths must contain only integers") # Check position and length arrays match in size if len(self.multimodal_positions) != len(self.multimodal_lengths): raise ValueError( f"Position and length arrays must match in size: " f"positions={len(self.multimodal_positions)}, lengths={len(self.multimodal_lengths)}" ) @classmethod def from_components(cls, mm_hashes: List[List[int]], mm_positions: List[int], mm_lengths: List[int]) -> 'MultimodalInput': return cls(multimodal_hashes=mm_hashes, multimodal_positions=mm_positions, multimodal_lengths=mm_lengths) def to_tensor(self) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: """Convert data to tensors""" return ( # int32 to match the type in TRTLLM SizeType32 torch.tensor(self.multimodal_hashes, dtype=torch.int32), torch.tensor(self.multimodal_positions, dtype=torch.int32), torch.tensor(self.multimodal_lengths, dtype=torch.int32)) @dataclass class MultimodalParams: """Unified container for multimodal parameters. This class encapsulates all multimodal-related data that flows through the system, providing a clean interface for handling multimodal inputs across different models. Attributes: multimodal_input: Multimodal input data with hashing information. multimodal_data: Processed multimodal data containing embeddings, configurations, and modality-specific data organized by type. Structure of multimodal_data: { "mrope_config": { "mrope_rotary_cos_sin": torch.Tensor, # Rotary embeddings (Qwen2/2.5-VL) "mrope_position_deltas": torch.Tensor, # Position deltas (Qwen2/2.5-VL) }, "multimodal_embedding": torch.Tensor, # Pre-computed vision embeddings "image": { "pixel_values": torch.Tensor, "image_height": torch.Tensor | List[int], "image_width": torch.Tensor | List[int], }, "video": { "pixel_values": torch.Tensor, "video_height": torch.Tensor | List[int], "video_width": torch.Tensor | List[int], }, # ... other modalities } """ multimodal_input: Optional[MultimodalInput] = None multimodal_data: Optional[Dict[str, Any]] = field(default_factory=dict) def __post_init__(self): """Ensure default values are properly set.""" if self.multimodal_data is None: self.multimodal_data = {} def to_device(self, element: str, device: str, pin_memory: bool = False) -> None: """Move specified multimodal data element to target device. Args: element: Element to move ("multimodal_data" or "multimodal_input") device: Target device (e.g., "cuda", "cpu") pin_memory: Whether to pin memory for faster transfers """ def _to_device( input_tensor: Union[torch.Tensor, List, dict, None], pin_memory: bool = False, ) -> Union[torch.Tensor, List, dict, None]: if input_tensor is None: return None elif isinstance(input_tensor, list): return [_to_device(item, pin_memory) for item in input_tensor] elif isinstance(input_tensor, dict): return { key: _to_device(value, pin_memory) for key, value in input_tensor.items() } elif isinstance(input_tensor, torch.Tensor): if pin_memory and input_tensor.device.type == 'cpu': return input_tensor.pin_memory().to(device, non_blocking=True) else: return input_tensor.to(device, non_blocking=True) else: return input_tensor if element == "multimodal_data": self.multimodal_data = _to_device(self.multimodal_data, pin_memory) elif element == "multimodal_input": self.multimodal_input = _to_device(self.multimodal_input, pin_memory) else: print( f"MultimodalParams: Unsupported element '{element}' to move to device. " f"Supported elements: 'multimodal_data', 'multimodal_input'") def strip_for_context(self) -> None: """Strip multimodal data for context processing. Removes only mrope_position_deltas while keeping all other multimodal data (embeddings, images, etc.) needed for context phase processing. """ if not (self.multimodal_data and 'mrope_config' in self.multimodal_data): return mrope_config = self.multimodal_data['mrope_config'] if 'mrope_position_deltas' in mrope_config: del mrope_config['mrope_position_deltas'] # Clean up empty mrope_config if not mrope_config: del self.multimodal_data['mrope_config'] def strip_for_generation(self) -> None: """Strip multimodal data for generation processing. Keeps only mrope_position_deltas and removes all other multimodal data (embeddings, images, etc.) as they're not needed during generation. """ if not self.multimodal_data: return # Extract mrope_position_deltas before clearing mrope_position_deltas = None if 'mrope_config' in self.multimodal_data: mrope_config = self.multimodal_data['mrope_config'] if isinstance(mrope_config, dict) and 'mrope_position_deltas' in mrope_config: mrope_position_deltas = mrope_config['mrope_position_deltas'] # Clear all data and restore only position deltas if they exist self.multimodal_data = {} if mrope_position_deltas is not None: self.multimodal_data['mrope_config'] = { 'mrope_position_deltas': mrope_position_deltas } def has_content(self) -> bool: """Check if this object contains any multimodal data.""" return bool(self.multimodal_input or self.multimodal_data) # adopt from vllm : https://github.com/vllm-project/vllm/blob/main/vllm/vllm/multimodal/hash.py def serialize_item(obj: object) -> bytes: # Simple cases if isinstance(obj, str): return obj.encode("utf-8") if isinstance(obj, bytes): return obj if isinstance(obj, (int, float)): return np.array(obj).tobytes() if isinstance(obj, PIL.Image.Image): return np.array(obj.convert("RGBA")).tobytes() if isinstance(obj, torch.Tensor): return obj.numpy().tobytes() if isinstance(obj, np.ndarray): return obj.tobytes() raise ValueError(f"Unsupported object type: {type(obj)}") def apply_mm_hashes(mm_data: Dict[str, Any], hash_lib=default_hasher) -> Dict[str, List[str]]: """Apply hashing to multimodal data items.""" def _hash_image(image): # only support single modality w/ PIL.Image.Image for now # TODO: possible hash collision w/ this simplified version (vllm/PR/17378) hasher = hash_lib() if isinstance(image, torch.Tensor): # TODO: Device tensor hashing is an open issue. Limited hashing to CPU for now. image = image.cpu() hasher.update(serialize_item(image)) return hasher.hexdigest() mm_items = { modality: items if isinstance(items, list) else [items] for modality, items in mm_data.items() } # TODO: need to hash both modality and item to distinguish modality (vllm/PR) mm_hashes = { modality: [_hash_image(item) for item in items] for modality, items in mm_items.items() } return mm_hashes def hexdigest_to_int32(hex_digest: str) -> List[int]: """Convert a 256-bit hexadecimal digest to 8 int32 values.""" if len(hex_digest) != 64: raise ValueError( f"Expected 64 character hexadecimal string, got {len(hex_digest)}") result = [] for i in range(0, 64, 8): hex_chunk = hex_digest[i:i + 8] value = int(hex_chunk, 16) if value > 0x7FFFFFFF: # Check if the highest bit is set (value > 2^31-1) value = value - 0x100000000 # Convert to signed by subtracting 2^32 result.append(value) return result def find_mm_token_lengths(mm_data: Dict[str, Any], input_processor: Any) -> List[int]: """Get multimodal token lengths from multimodal data items. """ mm_items = { modality: items if isinstance(items, list) else [items] for modality, items in mm_data.items() } num_mm_tokens = {} for modality, items in mm_items.items(): if modality != "image": #TODO: support other modalities raise ValueError( f"Unsupported modality: {modality}. Only 'image' modality is currently supported for hashing." ) if not hasattr(input_processor, "get_num_tokens_per_image"): #TODO: backward compatibility for models that don't yet have get_num_tokens_per_image implemented #TODO: only support qwen2_vl for now raise AttributeError( f"Input processor {type(input_processor).__name__} does not have 'get_num_tokens_per_image' method required for multimodal hashing." ) modality_token_lengths = [] for item in items: if isinstance(item, torch.Tensor): item = ToPILImage()(item) num_tokens = input_processor.get_num_tokens_per_image( image_width=item.width, image_height=item.height, ) modality_token_lengths.append(num_tokens) num_mm_tokens[modality] = modality_token_lengths return num_mm_tokens['image'] # flatten all mm instances to a single list def find_mm_token_positions(input_ids: Union[torch.Tensor, List[int], np.ndarray], num_mm_tokens: List[int], vocab_size: int, mm_token_ids: torch.Tensor = None) -> List[int]: """Get multimodal token positions using IDs > vocab_size and known lengths. This function finds multimodal tokens (with IDs > vocab_size) and uses the provided lengths in num_mm_tokens to identify where each chunk starts. This works even when there are no gaps between different image sequences (e.g., when all images use the same token IDs). Args: input_ids: Token sequence (tensor, list, or numpy array) num_mm_tokens: List of lengths for each multimodal token chunk vocab_size: Size of the model's vocabulary mm_token_ids (optional): possible token ids for multimodal tokens Returns: List of starting positions for each multimodal token chunk """ # Convert input_ids to tensor if needed if not isinstance(input_ids, torch.Tensor): if isinstance(input_ids, list): input_ids = torch.tensor(input_ids) elif isinstance(input_ids, np.ndarray): input_ids = torch.from_numpy(input_ids) # Create mask for multimodal tokens if mm_token_ids is None: mm_mask = input_ids >= vocab_size else: mm_mask = torch.isin(input_ids, mm_token_ids) # If no multimodal tokens found, return empty list if not torch.any(mm_mask): return [] # Get positions of all multimodal tokens mm_positions = torch.where(mm_mask)[0].tolist() assert len(mm_positions) == sum( num_mm_tokens ), f"Number of multimodal tokens does not match sum of all lengths" # Use num_mm_tokens to find the starting position of each chunk start_positions = [] current_position = 0 # Process each expected length for length in num_mm_tokens: if current_position < len(mm_positions): # Add the starting position of this chunk start_positions.append(mm_positions[current_position]) # Move to the next chunk current_position += length return start_positions def validate_mm_inputs(prompt_token_ids: Union[torch.Tensor, List[int], np.ndarray], mm_hashes: List[List[int]], start_positions: List[int], num_mm_tokens: List[int]) -> None: """Validates multimodal inputs for consistency and correctness.""" # Validate number of hashes matches number of chunks if len(mm_hashes) != len(num_mm_tokens): raise AssertionError( f"Number of hashes ({len(mm_hashes)}) does not match " f"number of multimodal chunks ({len(num_mm_tokens)})") # Validate number of start positions matches number of chunks if len(start_positions) != len(num_mm_tokens): raise AssertionError( f"Number of start positions ({len(start_positions)}) does not match " f"number of multimodal chunks ({len(num_mm_tokens)})") # Validate each chunk's position and length prompt_len = len(prompt_token_ids) # Verify start_positions are sorted if not all(start_positions[i] < start_positions[i + 1] for i in range(len(start_positions) - 1)): raise AssertionError( "start_positions must be sorted in ascending order") for chunk_idx, (start_pos, chunk_len) in enumerate(zip(start_positions, num_mm_tokens)): if start_pos < 0: raise AssertionError( f"Invalid negative start position {start_pos} for chunk {chunk_idx}" ) if start_pos + chunk_len > prompt_len: raise AssertionError( f"Multimodal chunk {chunk_idx} at position {start_pos} with length {chunk_len} " f"exceeds input sequence length {prompt_len}") # Check for overlap with next chunk if chunk_idx < len(start_positions) - 1: next_start = start_positions[chunk_idx + 1] if start_pos + chunk_len > next_start: raise AssertionError( f"Multimodal chunk {chunk_idx} at position {start_pos} with length {chunk_len} " f"overlaps with chunk {chunk_idx + 1} at position {next_start}" )