From d45236b253970d6270281bb37163823d3c915a42 Mon Sep 17 00:00:00 2001 From: Wanli Jiang <35160485+Wanli-Jiang@users.noreply.github.com> Date: Fri, 8 Aug 2025 20:09:26 +0800 Subject: [PATCH] [TRTLLM-6308][feat] Support Aggregate mode for phi4-mm (#6184) Signed-off-by: Wanli Jiang <35160485+Wanli-Jiang@users.noreply.github.com> --- tensorrt_llm/_torch/models/modeling_phi4mm.py | 486 +++++++++++++++--- 1 file changed, 421 insertions(+), 65 deletions(-) diff --git a/tensorrt_llm/_torch/models/modeling_phi4mm.py b/tensorrt_llm/_torch/models/modeling_phi4mm.py index b5ad4f4520..b54b64fc69 100644 --- a/tensorrt_llm/_torch/models/modeling_phi4mm.py +++ b/tensorrt_llm/_torch/models/modeling_phi4mm.py @@ -1,14 +1,23 @@ # Plan for phi4-mm model support. # (done) step 1: support legacy inference pipeline for phi4-mm model. -# (todo) step 2: refactor the inference pipeline to use AGGREGATE mode (https://github.com/NVIDIA/TensorRT-LLM/pull/5522). +# (done) step 2: refactor the inference pipeline to use AGGREGATE mode (https://github.com/NVIDIA/TensorRT-LLM/pull/5522). +# (todo) step 3: optimization +# * use TRTLLM-attention to replace original pytorch attention in vision/audio encoders. +# * use data parallel to accelerate inference. import copy +import importlib +import os +import sys +from pathlib import Path from typing import List, Optional, Tuple import torch import transformers from PIL import Image +from tensorrt_llm.inputs.multimodal import MultimodalParams + from ...executor.request import LoRARequest from ...inputs import (ExtraProcessedInputs, InputProcessor, TextPrompt, register_input_processor) @@ -21,16 +30,361 @@ from .modeling_auto import AutoModelForCausalLM from .modeling_multimodal_utils import fuse_input_embeds from .modeling_utils import register_auto_model -# Special tokens -_IMAGE_SPECIAL_TOKEN_ID = 200010 # '<|endoftext10|>' -_AUDIO_SPECIAL_TOKEN_ID = 200011 # '<|endoftext11|>' +# Special token ids from the original Phi-4-multimodal-instruct implementation +_IMAGE_SPECIAL_TOKEN_ID = 200010 # '<|endoftext10|>' from HF `modeling_phi4mm.py` +_AUDIO_SPECIAL_TOKEN_ID = 200011 # '<|endoftext11|>' from HF `modeling_phi4mm.py` +_PAD_TOKEN_ID = 199999 # '<|endoftext|>' from HF `special_tokens_map.json` +_COMPATIBLE_IMAGE_SPECIAL_TOKEN_ID_RANGE = [-9999, + -1] # from HF `modeling_phi4mm.py` +_COMPATIBLE_AUDIO_SPECIAL_TOKEN_ID_RANGE = [float('-inf'), -10000 + ] # from HF `modeling_phi4mm.py` + +# Below classes will be loaded from HuggingFace codes, rather than using transformers version, +# since transformers version is not compatible with checkpoints and configs from `microsoft/Phi-4-multimodal-instruct`. +Phi4MMAudioEmbedding = None +Phi4MMImageEmbedding = None +Phi4MMConfig = None -# Create a PreTrainedModel class for transformers=4.53.1 upgrade. -# Core idea is to provide `prepare_inputs_for_generation` method from `GenerationMixin`. -class NewPreTrainedModel(transformers.modeling_utils.PreTrainedModel, - transformers.generation.GenerationMixin): - pass +# Make this a runtime lookup rather than a module-wide constant for easier unit testing. +def _is_disagg() -> bool: + return os.getenv("TLLM_MULTIMODAL_DISAGGREGATED", "0") == "1" + + +# Load the Phi4MM classes from HuggingFace Phi-4-multimodal-instruct repo. +# Remove this function by using the transformers version of Phi4Multimodal when weights/configs are converted to transformers format. 
+def _load_phi4mm_classes(local_path): + """Load Phi4MM classes from the specified local path.""" + global Phi4MMAudioEmbedding, Phi4MMImageEmbedding, Phi4MMConfig + if Phi4MMAudioEmbedding is not None and Phi4MMImageEmbedding is not None and Phi4MMConfig is not None: + return + + # Add parent folder to sys.path to enable relative import. + original_sys_path = sys.path.copy() + package_folder = Path(local_path) + parent_folder = str(package_folder.parent) + if parent_folder not in sys.path: + sys.path.insert(0, parent_folder) + + try: + # Import Phi4MMConfig from configuration_phi4mm.py. + config_path = os.path.join(local_path, 'configuration_phi4mm.py') + if not os.path.exists(config_path): + raise FileNotFoundError( + f"configuration_phi4mm.py not found at {local_path}.") + spec = importlib.util.spec_from_file_location("hf_config", config_path) + hf_config = importlib.util.module_from_spec(spec) + spec.loader.exec_module(hf_config) + Phi4MMConfig = hf_config.Phi4MMConfig + + # Import Phi4MMAudioEmbedding and Phi4MMImageEmbedding from modeling_phi4mm.py. + modeling_phi4mm_path = os.path.join(local_path, 'modeling_phi4mm.py') + if not os.path.exists(modeling_phi4mm_path): + raise FileNotFoundError( + f"modeling_phi4mm.py not found at {local_path}.") + # `Phi-4-multimodal-instruct` as the package name to avoid relative import errors. + # `hf_modeling_phi4mm` as the module name to avoid name conflicts. + spec = importlib.util.spec_from_file_location( + "Phi-4-multimodal-instruct.hf_modeling_phi4mm", + modeling_phi4mm_path) + hf_modeling_phi4mm = importlib.util.module_from_spec(spec) + spec.loader.exec_module(hf_modeling_phi4mm) + Phi4MMAudioEmbedding = hf_modeling_phi4mm.Phi4MMAudioEmbedding + Phi4MMImageEmbedding = hf_modeling_phi4mm.Phi4MMImageEmbedding + finally: + sys.path = original_sys_path + + +class HFPhi4MultimodalEncoder(transformers.PreTrainedModel, + transformers.generation.GenerationMixin): + + # Copy and modify from https://huggingface.co/microsoft/Phi-4-multimodal-instruct/blob/main/modeling_phi4mm.py::Phi4MMImageAudioEmbedding + # Note: the HF implementation here will cause duplicated encoders on all GPUs for TP>1 scenario. + # TODO: use TRTLLM-attention to replace original pytorch Flash_attn_2 in HFPhi4MultimodalEncoder. + config_class = Phi4MMConfig + base_model_prefix = "model" + _tied_weights_keys = ["lm_head.weight"] + _supports_flash_attn_2 = True + _supports_sdpa = True + _supports_cache_class = True + + def __init__(self, config: transformers.PretrainedConfig, **kwargs): + super().__init__(config, **kwargs) + self.padding_idx = config.pad_token_id + + self.embed_tokens = torch.nn.Embedding(config.vocab_size, + config.hidden_size, + self.padding_idx) + + self._attn_implementation = config._attn_implementation + + self.vocab_size = config.vocab_size + + embedding_config = { + 'embedding_cls': config.embd_layer['embedding_cls'], + **config.embd_layer + } + # The default values are from HuggingFace Phi-4-multimodal-instruct codes. 
+ self.image_input_id = embedding_config.get('image_input_id', -1) + self.audio_input_id = embedding_config.get('audio_input_id', -10000) + if self.image_input_id == self.audio_input_id: + raise ValueError( + 'image_input_id and audio_input_id should be different') + + self.image_embd_layer_kwargs = embedding_config['image_embd_layer'] + self.image_embed = Phi4MMImageEmbedding(config, + **self.image_embd_layer_kwargs) + + self.audio_embd_layer_kwargs = embedding_config['audio_embd_layer'] + self.audio_embed = Phi4MMAudioEmbedding(config, + **self.audio_embd_layer_kwargs) + + def _replace_special_token_ids(self, + input_ids: torch.Tensor) -> torch.Tensor: + # Inplace-replacement for special token ids. + torch.where( + (input_ids >= _COMPATIBLE_IMAGE_SPECIAL_TOKEN_ID_RANGE[0]) + & (input_ids <= _COMPATIBLE_IMAGE_SPECIAL_TOKEN_ID_RANGE[1]), + torch.tensor(_IMAGE_SPECIAL_TOKEN_ID), + input_ids, + out=input_ids, + ) + torch.where( + (input_ids >= _COMPATIBLE_AUDIO_SPECIAL_TOKEN_ID_RANGE[0]) + & (input_ids <= _COMPATIBLE_AUDIO_SPECIAL_TOKEN_ID_RANGE[1]), + torch.tensor(_AUDIO_SPECIAL_TOKEN_ID), + input_ids, + out=input_ids, + ) + return input_ids + + def _batch_infer_image_embeds( + self, batched_input_ids: torch.Tensor, + multimodal_params: List[MultimodalParams]) -> torch.Tensor: + # Batch image inputs and attention mask with padding along dim=1 (patch num). + input_image_embeds_list, input_image_attn_mask_list, input_image_sizes_list = [], [], [] + for mm_param in multimodal_params: + mm_data = mm_param.multimodal_data + input_image_embeds = mm_data["input_image_embeds"] + if input_image_embeds is not None and input_image_embeds.numel( + ) > 0: + input_image_embeds_list.append(input_image_embeds) + input_image_attn_mask_list.append( + mm_data["image_attention_mask"]) + input_image_sizes_list.append(mm_data["image_sizes"]) + batched_image_hidden_states = None + if len(input_image_embeds_list) > 0: + # Padding image embeds/attn_masks along dim=1 (patch dimension). + b_list = [x.shape[0] for x in input_image_embeds_list] + p_list = [x.shape[1] for x in input_image_embeds_list] + c_i, h_i, w_i = input_image_embeds_list[0].shape[2:5] + h_i_attn, w_i_attn = input_image_attn_mask_list[0].shape[2:4] + total_b = sum(b_list) + max_p = max(p_list) + batched_image_embeds = torch.zeros( + (total_b, max_p, c_i, h_i, w_i), + dtype=input_image_embeds_list[0].dtype, + device=input_image_embeds_list[0].device) + batched_image_attn_mask = torch.zeros( + (total_b, max_p, h_i_attn, w_i_attn), + dtype=input_image_embeds_list[0].dtype, + device=input_image_embeds_list[0].device) + b_offset = 0 + for i, tensor in enumerate(input_image_embeds_list): + b, p = tensor.shape[:2] + batched_image_embeds[b_offset:b_offset + b, :p] = tensor + if input_image_attn_mask_list[i] is not None: + batched_image_attn_mask[ + b_offset:b_offset + + b, :p] = input_image_attn_mask_list[i] + else: + batched_image_attn_mask[b_offset:b_offset + b, :p] = 1 + b_offset += b + batched_image_sizes = torch.cat(input_image_sizes_list, dim=0) + # Forward image encoder with batched image embeds. 
+ batched_image_hidden_states = self.image_embed( + input_ids=batched_input_ids, + input_embeds=batched_image_embeds, + image_sizes=batched_image_sizes, + image_attention_mask=batched_image_attn_mask, + wte=self.embed_tokens, + ) + return batched_image_hidden_states + + def _batch_infer_audio_embeds( + self, batched_input_ids: torch.Tensor, + multimodal_params: List[MultimodalParams]) -> torch.Tensor: + # Batch audio inputs and attention mask with padding along dim=1 (patch num). + input_audio_embeds_list, input_audio_attn_mask_list, input_audio_sizes_list = [], [], [] + for mm_param in multimodal_params: + mm_data = mm_param.multimodal_data + input_audio_embeds = mm_data["input_audio_embeds"] + if input_audio_embeds is not None and input_audio_embeds.numel( + ) > 0: + input_audio_embeds_list.append(input_audio_embeds) + input_audio_attn_mask_list.append( + mm_data["audio_attention_mask"]) + input_audio_sizes_list.append(mm_data["audio_embed_sizes"]) + batched_audio_hidden_states = None + if len(input_audio_embeds_list) > 0: + b_list = [x.shape[0] for x in input_audio_embeds_list] + p_list = [x.shape[1] for x in input_audio_embeds_list] + d_a = input_audio_embeds_list[0].shape[2] + total_b = sum(b_list) + max_p = max(p_list) + batched_audio_embeds = torch.zeros( + (total_b, max_p, d_a), + dtype=input_audio_embeds_list[0].dtype, + device=input_audio_embeds_list[0].device) + batched_audio_attn_mask = torch.zeros( + (total_b, max_p), + dtype=input_audio_embeds_list[0].dtype, + device=input_audio_embeds_list[0].device) + b_offset = 0 + for i, tensor in enumerate(input_audio_embeds_list): + b, p = tensor.shape[:2] + batched_audio_embeds[b_offset:b_offset + b, :p] = tensor + if input_audio_attn_mask_list[i] is not None: + batched_audio_attn_mask[ + b_offset:b_offset + + b, :p] = input_audio_attn_mask_list[i] + else: + batched_audio_attn_mask[b_offset:b_offset + b, :p] = 1 + b_offset += b + batched_audio_sizes = torch.cat(input_audio_sizes_list, dim=0) + # Forward audio encoder with batched audio embeds. + batched_audio_hidden_states = self.audio_embed( + input_ids=batched_input_ids, + input_embeds=batched_audio_embeds, + audio_embed_sizes=batched_audio_sizes, + audio_attention_mask=batched_audio_attn_mask, + wte=self.embed_tokens, + ) + return batched_audio_hidden_states + + def _encoding_per_request( + self, multimodal_params: List[MultimodalParams], + mm_token_ids: torch.Tensor) -> List[torch.FloatTensor]: + # Loop implementation. 
+ mm_embeddings = [] + for i in range(len(multimodal_params)): + input_ids = multimodal_params[i].multimodal_data["input_ids"] + input_image_embeds = multimodal_params[i].multimodal_data[ + "input_image_embeds"] + input_audio_embeds = multimodal_params[i].multimodal_data[ + "input_audio_embeds"] + image_sizes = multimodal_params[i].multimodal_data["image_sizes"] + image_attention_mask = multimodal_params[i].multimodal_data[ + "image_attention_mask"] + audio_embed_sizes = multimodal_params[i].multimodal_data[ + "audio_embed_sizes"] + audio_attention_mask = multimodal_params[i].multimodal_data[ + "audio_attention_mask"] + audio_projection_mode = multimodal_params[i].multimodal_data[ + "audio_projection_mode"] + + input_shape = input_ids.size() + input_ids = input_ids.view(-1, input_shape[-1]) + input_ids = self._replace_special_token_ids(input_ids) + image_position_mask = input_ids == _IMAGE_SPECIAL_TOKEN_ID + non_image_position_mask = ~image_position_mask + + image_hidden_states = None + if input_image_embeds is not None: + image_hidden_states = self.image_embed( + input_ids=input_ids, + input_embeds=input_image_embeds, + image_sizes=image_sizes, + wte=self.embed_tokens, + image_attention_mask=image_attention_mask, + ) + audio_hidden_states = None + if input_audio_embeds is not None: + audio_hidden_states = self.audio_embed( + input_ids=input_ids, + input_embeds=input_audio_embeds, + audio_embed_sizes=audio_embed_sizes, + audio_attention_mask=audio_attention_mask, + wte=self.embed_tokens, + audio_projection_mode=audio_projection_mode, + ) + + if input_image_embeds is not None and input_audio_embeds is not None: + dtype = image_hidden_states.dtype + hidden_states = image_hidden_states * image_position_mask.to( + dtype).unsqueeze( + -1) + audio_hidden_states * non_image_position_mask.to( + dtype).unsqueeze(-1) + elif input_image_embeds is not None: + hidden_states = image_hidden_states + elif input_audio_embeds is not None: + hidden_states = audio_hidden_states + else: + hidden_states = self.embed_tokens(input_ids) + + # Postprocessing to get multimodal-only embeddings. + mm_token_mask = torch.isin(input_ids, mm_token_ids) + hidden_states = hidden_states[mm_token_mask] + + mm_embeddings.append(hidden_states) + return mm_embeddings + + def _encoding_batch_request( + self, multimodal_params: List[MultimodalParams], + mm_token_ids: torch.Tensor) -> List[torch.FloatTensor]: + # Batch input_ids. + input_ids_list = [ + multimodal_params[i].multimodal_data["input_ids"] + for i in range(len(multimodal_params)) + ] + max_input_ids_len = max( + [input_ids.shape[1] for input_ids in input_ids_list]) + batched_input_ids = torch.full( + (len(multimodal_params), max_input_ids_len), + _PAD_TOKEN_ID, + device=input_ids_list[0].device) + for i, input_ids in enumerate(input_ids_list): + batched_input_ids[i, :input_ids.shape[1]] = input_ids + batched_input_ids = batched_input_ids.view(-1, max_input_ids_len) + batched_input_ids = self._replace_special_token_ids(batched_input_ids) + image_position_mask = batched_input_ids == _IMAGE_SPECIAL_TOKEN_ID + non_image_position_mask = ~image_position_mask + + # Batch inference for image and audio embeds. + batched_image_hidden_states = self._batch_infer_image_embeds( + batched_input_ids, multimodal_params) + batched_audio_hidden_states = self._batch_infer_audio_embeds( + batched_input_ids, multimodal_params) + + # Combine different modalities into one. 
+ if batched_image_hidden_states is not None and batched_audio_hidden_states is not None: + batched_hidden_states = batched_image_hidden_states * image_position_mask.unsqueeze( + -1 + ) + batched_audio_hidden_states * non_image_position_mask.unsqueeze( + -1) + elif batched_image_hidden_states is not None: + batched_hidden_states = batched_image_hidden_states + elif batched_audio_hidden_states is not None: + batched_hidden_states = batched_audio_hidden_states + else: + batched_hidden_states = self.embed_tokens(batched_input_ids) + + # Postprocessing to get multimodal-only embeddings. + mm_token_mask = torch.isin(batched_input_ids, mm_token_ids) + batched_hidden_states = batched_hidden_states[mm_token_mask] + batched_hidden_states = [batched_hidden_states] + return batched_hidden_states + + @torch.inference_mode() + def forward(self, multimodal_params: List[MultimodalParams], + mm_token_ids: torch.Tensor) -> List[torch.FloatTensor]: + if os.getenv("PHI4_MM_PER_REQUEST_INFER", "0") == "1": + # Reference code path to check correctness of batch inference and further dev. + # (TODO) Remove this path after accuracy bench and data parallelism are supported. + return self._encoding_per_request(multimodal_params, mm_token_ids) + else: + # Batch inference as default path. + return self._encoding_batch_request(multimodal_params, mm_token_ids) class Phi4MMInputProcessor(InputProcessor): @@ -40,10 +394,11 @@ class Phi4MMInputProcessor(InputProcessor): model_config: transformers.PretrainedConfig, tokenizer: transformers.AutoTokenizer, trust_remote_code: bool = True): - assert trust_remote_code, "trust_remote_code must be True for Phi4MM" + if not trust_remote_code: + raise ValueError("trust_remote_code must be True for Phi4MM") self.model_config = model_config - self.device = 'cuda' + self.device = 'cpu' self.tokenizer = tokenizer self.use_fast = True @@ -58,37 +413,18 @@ class Phi4MMInputProcessor(InputProcessor): trust_remote_code=trust_remote_code, use_fast=self.use_fast) - # Build pure-pytorch model architecture for multimodal encoder. - # Model weights are also loaded here. - OldPreTrainedModel = transformers.modeling_utils.PreTrainedModel - transformers.modeling_utils.PreTrainedModel = NewPreTrainedModel - # TODO: Make separate Phi4VisionEncoder and Phi4AudioEncoder, and move them to LLM-side. - ref_phi4mm_model = transformers.AutoModelForCausalLM.from_pretrained( - model_path, - trust_remote_code=True, - # Flash_attn_2 only supports bf16 or fp16 and set in HF config. - torch_dtype='auto', - _attn_implementation='flash_attention_2', - ).eval() - transformers.modeling_utils.PreTrainedModel = OldPreTrainedModel - self.phi4mm_modal_encoder = ref_phi4mm_model.model.embed_tokens_extend.to( - self.device) - # Required by Phi4MMImageAudioEmbedding. - # See link: https://huggingface.co/microsoft/Phi-4-multimodal-instruct/blob/main/modeling_phi4mm.py#L701 - self.phi4mm_wte = ref_phi4mm_model.model.embed_tokens.to(self.device) - @torch.inference_mode() def __call__( self, inputs: TextPrompt, sampling_params: SamplingParams ) -> Tuple[List[int], Optional[ExtraProcessedInputs]]: - text_prompt, mm_data, mm_processor_kwargs = inputs.get("prompt"), \ - inputs.get("multi_modal_data", {}), inputs.get("mm_processor_kwargs", {}) + text_prompt, mm_data = inputs.get("prompt"), inputs.get( + "multi_modal_data", {}) images = mm_data.get("image", None) audios = mm_data.get("audio", None) if images is not None: if isinstance(images[0], torch.Tensor): - # Convert normalized tensors (0-1) to PIL images (0-255). 
+ # HF Phi4MM can only support PIL images. Convert normalized tensors (0-1) to PIL images (0-255). images = [ Image.fromarray((image.permute(1, 2, 0) * 255).to( torch.uint8).cpu().numpy()) for image in images @@ -109,29 +445,16 @@ class Phi4MMInputProcessor(InputProcessor): else: audio_projection_mode = 'speech' - # Processing with Phi4MMImageAudioEmbedding. - mm_features = self.phi4mm_modal_encoder( - input_ids=inputs['input_ids'], - input_embeds=None, - input_image_embeds=inputs['input_image_embeds'], - input_audio_embeds=inputs['input_audio_embeds'], - image_sizes=inputs['image_sizes'], - image_attention_mask=inputs['image_attention_mask'], - audio_embed_sizes=inputs['audio_embed_sizes'], - audio_attention_mask=inputs['audio_attention_mask'], - audio_projection_mode=audio_projection_mode, - wte=self.phi4mm_wte, - ) - - # Postprocessing to get multimodal-only embeddings. - image_token_mask = inputs['input_ids'] == _IMAGE_SPECIAL_TOKEN_ID - audio_token_mask = inputs['input_ids'] == _AUDIO_SPECIAL_TOKEN_ID - mm_token_mask = image_token_mask | audio_token_mask - mm_features = mm_features[mm_token_mask] - + # Will package inputs for language model forward in AGGREGATE mode. multimodal_data = {} - multimodal_data["multimodal_embedding"] = mm_features - + multimodal_data['input_ids'] = inputs['input_ids'] + multimodal_data['input_image_embeds'] = inputs['input_image_embeds'] + multimodal_data['image_sizes'] = inputs['image_sizes'] + multimodal_data['image_attention_mask'] = inputs['image_attention_mask'] + multimodal_data['input_audio_embeds'] = inputs['input_audio_embeds'] + multimodal_data['audio_embed_sizes'] = inputs['audio_embed_sizes'] + multimodal_data['audio_attention_mask'] = inputs['audio_attention_mask'] + multimodal_data['audio_projection_mode'] = audio_projection_mode return inputs['input_ids'][0].to(torch.int32).tolist(), { "multimodal_data": multimodal_data, } @@ -142,10 +465,11 @@ class Phi4MMInputProcessor(InputProcessor): class Phi4MMForCausalLM(transformers.PreTrainedModel): _supports_flash_attn_2 = True - MM_TOKEN_IDS = torch.tensor( - [_IMAGE_SPECIAL_TOKEN_ID, _AUDIO_SPECIAL_TOKEN_ID]) def __init__(self, model_config: ModelConfig): + if _is_disagg(): + raise ValueError( + "Phi4MM does not support disaggregated inference yet.") config = model_config.pretrained_config super().__init__(config) @@ -154,6 +478,15 @@ class Phi4MMForCausalLM(transformers.PreTrainedModel): if hasattr(self, "llm"): return + if not _is_disagg(): + _load_phi4mm_classes(config._name_or_path) + + # Setup HFPhi4MultimodalEncoder in AGGREGATE mode. + self.hf_phi4mm_model = HFPhi4MultimodalEncoder(config).eval() + self.hf_phi4mm_model.to(config.torch_dtype) + # Required by HFPhi4MultimodalEncoder. + self.phi4mm_wte = self.hf_phi4mm_model.embed_tokens + # We use Phi3ForCausalLM as the language model. llm_model_config = copy.deepcopy(model_config) llm_model_config.pretrained_config.architectures = ["Phi3ForCausalLM"] @@ -167,6 +500,18 @@ class Phi4MMForCausalLM(transformers.PreTrainedModel): self.is_loaded = True def load_weights(self, weights): + # Load weights into HFPhi4MultimodalEncoder. 
+ if not _is_disagg(): + filtered_weights = {} + for k, v in weights.items(): + if k.startswith("model.embed_tokens."): + new_k = k.replace("model.embed_tokens.", "embed_tokens.") + filtered_weights[new_k] = v + elif k.startswith("model.embed_tokens_extend."): + new_k = k.replace("model.embed_tokens_extend.", "") + filtered_weights[new_k] = v + self.hf_phi4mm_model.load_state_dict(filtered_weights, strict=True) + # Filter out non-language model weights. weights = { k: v @@ -185,9 +530,13 @@ class Phi4MMForCausalLM(transformers.PreTrainedModel): else: updated_weights[k] = weights[k] weights = updated_weights - self.llm.load_weights(weights) + # Move mm_token_ids to the correct device. + self.mm_token_ids = torch.tensor( + [_IMAGE_SPECIAL_TOKEN_ID, _AUDIO_SPECIAL_TOKEN_ID], + device=self.device) + def infer_max_seq_len(self) -> int: return self.llm.infer_max_seq_len() @@ -215,17 +564,24 @@ class Phi4MMForCausalLM(transformers.PreTrainedModel): ) multimodal_params = kwargs.get("multimodal_params", []) - mm_embeds = [] + mm_embedding = [] if len(multimodal_params) > 0: - mm_embeds = [ - multimodal_param.multimodal_data["multimodal_embedding"] - for multimodal_param in multimodal_params - ] + if not _is_disagg(): + # Forward the multimodal data to HFPhi4MultimodalEncoder in AGGREGATE mode. + mm_embedding = self.hf_phi4mm_model(multimodal_params, + self.mm_token_ids) + else: + # Directly fetch the multimodal embedding for DISAGG mode. + # This path is not functional now. `multimodal_params` will be prepared in PyExecutor. + mm_embedding = [ + multimodal_param.multimodal_data["multimodal_embedding"] + for multimodal_param in multimodal_params + ] input_ids, input_embeds = fuse_input_embeds( self.llm.model.embed_tokens, input_ids, - mm_embeds, - mm_token_ids=self.MM_TOKEN_IDS, + mm_embedding, + mm_token_ids=self.mm_token_ids, ) output_prob = self.llm.forward(
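
For reference, a minimal standalone sketch (not part of the patch) of the padding scheme that `_batch_infer_image_embeds` uses: per-request image embeds of shape (num_images, num_patches, C, H, W) are packed into one (sum(num_images), max(num_patches), C, H, W) tensor, zero-padded along the patch dimension, with a matching attention mask. Tensor sizes below are illustrative only; the real C/H/W come from the HF processor outputs.

    import torch

    C, H, W = 3, 32, 32
    per_request_embeds = [
        torch.randn(1, 5, C, H, W),  # request 0: 1 image, 5 patches
        torch.randn(2, 3, C, H, W),  # request 1: 2 images, 3 patches each
    ]

    total_b = sum(t.shape[0] for t in per_request_embeds)  # 3
    max_p = max(t.shape[1] for t in per_request_embeds)    # 5
    batched_embeds = torch.zeros(total_b, max_p, C, H, W,
                                 dtype=per_request_embeds[0].dtype)
    attn_mask = torch.zeros(total_b, max_p)

    b_offset = 0
    for t in per_request_embeds:
        b, p = t.shape[:2]
        batched_embeds[b_offset:b_offset + b, :p] = t
        attn_mask[b_offset:b_offset + b, :p] = 1  # mark real patches only
        b_offset += b

    assert batched_embeds.shape == (3, 5, C, H, W)

The audio path (`_batch_infer_audio_embeds`) follows the same pattern with (num_clips, num_frames, feature_dim) tensors.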
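A toy sketch (illustrative values only, not the patch code itself) of what `_replace_special_token_ids` and the combination step in `_encoding_batch_request` do: "compatible" negative placeholder ids are normalized to the canonical special-token ids, and image/audio hidden states are then blended position-wise.

    import torch

    _IMAGE_SPECIAL_TOKEN_ID = 200010
    _AUDIO_SPECIAL_TOKEN_ID = 200011

    input_ids = torch.tensor([[1, -5, -20000, 7]])
    # ids in [-9999, -1] become the image token; ids <= -10000 become the audio token
    input_ids = torch.where((input_ids >= -9999) & (input_ids <= -1),
                            torch.tensor(_IMAGE_SPECIAL_TOKEN_ID), input_ids)
    input_ids = torch.where(input_ids <= -10000,
                            torch.tensor(_AUDIO_SPECIAL_TOKEN_ID), input_ids)

    image_pos = (input_ids == _IMAGE_SPECIAL_TOKEN_ID).unsqueeze(-1)
    image_h = torch.randn(1, 4, 8)
    audio_h = torch.randn(1, 4, 8)
    # image hidden states are kept at image-token positions, audio everywhere else
    combined = image_h * image_pos + audio_h * (~image_pos)

Runtime behavior: the aggregate path is the default; setting TLLM_MULTIMODAL_DISAGGREGATED=1 makes Phi4MMForCausalLM raise (disaggregated serving is not supported yet), and PHI4_MM_PER_REQUEST_INFER=1 switches the encoder to the per-request reference path in `_encoding_per_request`.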