diff --git a/docs/source/models/supported-models.md b/docs/source/models/supported-models.md
index 9db83e1bd7..577586b5de 100644
--- a/docs/source/models/supported-models.md
+++ b/docs/source/models/supported-models.md
@@ -45,13 +45,13 @@ Note: Support for other models may vary. Features marked "N/A" are not applicabl
 
 | Model Architecture/Feature | Overlap Scheduler | CUDA Graph | Chunked Prefill | Torch Sampler | TLLM C++ Sampler | KV Cache Reuse | Logits Post Processor | EPD Disaggregated Serving | Modality |
 | ---------------------------------- | ----------------- | ---------- | --------------- | ------------- | ---------------- | -------------- | --------------------- | ------------------------- | -------- |
-| Gemma3ForConditionalGeneration | Yes | Yes | N/A | Yes | Yes | N/A | Yes | No | L + I |
+| Gemma3ForConditionalGeneration | Yes | Yes | N/A | Yes | Yes | N/A | Yes | No | L + I |
 | HCXVisionForCausalLM | Yes | Yes | No | Yes | Yes | No | Yes | No | L + I |
 | LlavaLlamaModel (VILA) | Yes | Yes | No | Yes | Yes | No | Yes | No | L + I + V |
 | LlavaNextForConditionalGeneration | Yes | Yes | No | Yes | Yes | No | Yes | No | L + I |
 | Llama4ForConditionalGeneration | Yes | Yes | No | Yes | Yes | No | Yes | No | L + I |
 | Mistral3ForConditionalGeneration | Yes | Yes | No | Yes | Yes | No | Yes | No | L + I |
-| Phi4MMForCausalLM | Yes | Yes | No | Yes | Yes | No | Yes | No | L + I + A |
+| Phi4MMForCausalLM | Yes | Yes | No | Yes | Yes | Yes | Yes | No | L + I + A |
 | Qwen2VLForConditionalGeneration | Yes | Yes | No | Yes | Yes | Yes | Yes | No | L + I + V |
 | Qwen2_5_VLForConditionalGeneration | Yes | Yes | No | Yes | Yes | Yes | Yes | No | L + I + V |
 
diff --git a/docs/source/reference/multimodal-feature-support-matrix.md b/docs/source/reference/multimodal-feature-support-matrix.md
index d0cf237268..269aca60a4 100644
--- a/docs/source/reference/multimodal-feature-support-matrix.md
+++ b/docs/source/reference/multimodal-feature-support-matrix.md
@@ -8,6 +8,6 @@
 | LLaVA-NeXT | Yes | Yes | Yes | Yes |
 | Llama 4 | Yes | Yes | No | No |
 | Mistral-Small-3.1 | Yes | Yes | No | No |
-| Phi-4-multimodal | Yes | Yes | No | No |
+| Phi-4-multimodal | Yes | Yes | Yes | No |
 | Qwen2-VL | Yes | Yes | Yes | Yes |
 | Qwen2.5-VL | Yes | Yes | Yes | Yes |
diff --git a/tensorrt_llm/_torch/models/modeling_phi4mm.py b/tensorrt_llm/_torch/models/modeling_phi4mm.py
index 38ee1eb110..e69a0c5d50 100644
--- a/tensorrt_llm/_torch/models/modeling_phi4mm.py
+++ b/tensorrt_llm/_torch/models/modeling_phi4mm.py
@@ -19,8 +19,8 @@ from PIL import Image
 from tensorrt_llm.inputs.multimodal import MultimodalParams
 
 from ...executor.request import LoRARequest
-from ...inputs import (ExtraProcessedInputs, InputProcessor,
-                       MultimodalPlaceholderMetadata,
+from ...inputs import (BaseMultimodalInputProcessor, ExtraProcessedInputs,
+                       InputProcessor, MultimodalPlaceholderMetadata,
                        MultimodalPlaceholderPlacement, TextPrompt,
                        register_input_processor)
 from ...logger import logger
@@ -29,7 +29,8 @@ from ...sampling_params import SamplingParams
 from ..attention_backend import AttentionMetadata
 from ..model_config import ModelConfig
 from .modeling_auto import AutoModelForCausalLM
-from .modeling_multimodal_utils import fuse_input_embeds
+from .modeling_multimodal_utils import (find_uncached_mm_embeds,
+                                        fuse_input_embeds)
 from .modeling_utils import register_auto_model
 
 # Special token ids from the original Phi-4-multimodal-instruct implementation
@@ -389,7 +390,7 @@ class HFPhi4MultimodalEncoder(transformers.PreTrainedModel,
         return self._encoding_batch_request(multimodal_params, mm_token_ids)
 
 
-class Phi4MMInputProcessor(InputProcessor):
+class Phi4MMInputProcessor(BaseMultimodalInputProcessor, InputProcessor):
 
     def __init__(self,
                  model_path: str,
@@ -415,6 +416,20 @@ class Phi4MMInputProcessor(InputProcessor):
             trust_remote_code=trust_remote_code,
             use_fast=self.use_fast)
 
+    def get_mm_token_ids(self) -> Optional[torch.Tensor]:
+        return torch.tensor([_IMAGE_SPECIAL_TOKEN_ID, _AUDIO_SPECIAL_TOKEN_ID],
+                            dtype=torch.int32,
+                            device=self.device)
+
+    def get_num_tokens_per_image(
+        self,
+        *,
+        image: Image.Image,
+        **kwargs,
+    ):
+        data = self.processor.image_processor.preprocess(image)
+        return data["num_img_tokens"][0]
+
     @torch.inference_mode()
     def __call__(
         self, inputs: TextPrompt, sampling_params: SamplingParams
@@ -589,6 +604,9 @@ class Phi4MMForCausalLM(transformers.PreTrainedModel):
                 multimodal_param.multimodal_data["multimodal_embedding"]
                 for multimodal_param in multimodal_params
             ]
+            mm_embedding = find_uncached_mm_embeds(
+                mm_embedding, multimodal_params[:num_context_requests])
+
         input_ids, input_embeds = fuse_input_embeds(
             self.llm.model.embed_tokens,
             input_ids,
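
For reference, a minimal sketch (not part of this change) of how the KV cache reuse enabled here for Phi-4-multimodal might be exercised through the TensorRT-LLM LLM API; the model path, prompt, and sampling settings below are illustrative assumptions:

# Sketch only: turn on KV cache block reuse when serving Phi-4-multimodal.
# The model path and prompt are illustrative assumptions, not taken from this diff.
from tensorrt_llm import LLM, SamplingParams
from tensorrt_llm.llmapi import KvCacheConfig

llm = LLM(model="microsoft/Phi-4-multimodal-instruct",
          kv_cache_config=KvCacheConfig(enable_block_reuse=True))

# Requests sharing a common prefix (e.g. the same system prompt or image tokens)
# can now reuse cached KV blocks instead of recomputing the prefill.
outputs = llm.generate(["Summarize the meeting notes."],
                       sampling_params=SamplingParams(max_tokens=64))
print(outputs[0].outputs[0].text)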