[TRTLLM-7918][feat] Support kvcache reuse for phi4mm (#7563)

Signed-off-by: Wanli Jiang <35160485+Wanli-Jiang@users.noreply.github.com>
Wanli Jiang 2025-09-15 15:47:00 +08:00 committed by GitHub
parent 335c007df8
commit fc9f4c9295
3 changed files with 25 additions and 7 deletions


@@ -45,13 +45,13 @@ Note: Support for other models may vary. Features marked "N/A" are not applicable
| Model Architecture/Feature | Overlap Scheduler | CUDA Graph | Chunked Prefill | Torch Sampler | TLLM C++ Sampler | KV Cache Reuse | Logits Post Processor | EPD Disaggregated Serving | Modality |
| ---------------------------------- | ----------------- | ---------- | --------------- | ------------- | ---------------- | -------------- | --------------------- | ------------------------- | -------- |
| Gemma3ForConditionalGeneration | Yes | Yes | N/A | Yes | Yes | N/A | Yes | No | L + I |
| HCXVisionForCausalLM | Yes | Yes | No | Yes | Yes | No | Yes | No | L + I |
| LlavaLlamaModel (VILA) | Yes | Yes | No | Yes | Yes | No | Yes | No | L + I + V |
| LlavaNextForConditionalGeneration | Yes | Yes | No | Yes | Yes | No | Yes | No | L + I |
| Llama4ForConditionalGeneration | Yes | Yes | No | Yes | Yes | No | Yes | No | L + I |
| Mistral3ForConditionalGeneration | Yes | Yes | No | Yes | Yes | No | Yes | No | L + I |
| Phi4MMForCausalLM | Yes | Yes | No | Yes | Yes | No | Yes | No | L + I + A |
| Phi4MMForCausalLM | Yes | Yes | No | Yes | Yes | Yes | Yes | No | L + I + A |
| Qwen2VLForConditionalGeneration | Yes | Yes | No | Yes | Yes | Yes | Yes | No | L + I + V |
| Qwen2_5_VLForConditionalGeneration | Yes | Yes | No | Yes | Yes | Yes | Yes | No | L + I + V |
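
With this change, KV cache reuse for Phi4MMForCausalLM is driven by the same switch as for the other models in the table. A minimal sketch with the LLM API, assuming the usual `KvCacheConfig(enable_block_reuse=True)` option and the Hugging Face `microsoft/Phi-4-multimodal-instruct` checkpoint:

```python
from tensorrt_llm import LLM
from tensorrt_llm.llmapi import KvCacheConfig

# Enable block reuse so requests sharing a prompt prefix (including the
# image/audio placeholder spans) can reuse previously computed KV blocks.
kv_cache_config = KvCacheConfig(enable_block_reuse=True)

llm = LLM(
    model="microsoft/Phi-4-multimodal-instruct",  # assumed HF checkpoint id
    kv_cache_config=kv_cache_config,
    trust_remote_code=True,  # the checkpoint ships custom processing code
)
```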


@@ -8,6 +8,6 @@
| LLaVA-NeXT | Yes | Yes | Yes | Yes |
| Llama 4 | Yes | Yes | No | No |
| Mistral-Small-3.1 | Yes | Yes | No | No |
| Phi-4-multimodal | Yes | Yes | No | No |
| Phi-4-multimodal | Yes | Yes | Yes | No |
| Qwen2-VL | Yes | Yes | Yes | Yes |
| Qwen2.5-VL | Yes | Yes | Yes | Yes |
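
Reuse pays off when requests repeat a prompt prefix, for example re-querying the same image across calls. A hedged sketch of that pattern, assuming the LLM API accepts dict prompts with a `multi_modal_data` field and the Phi-4-multimodal chat template shown below:

```python
from PIL import Image
from tensorrt_llm import LLM, SamplingParams
from tensorrt_llm.llmapi import KvCacheConfig

llm = LLM(
    model="microsoft/Phi-4-multimodal-instruct",  # assumed checkpoint id
    kv_cache_config=KvCacheConfig(enable_block_reuse=True),
    trust_remote_code=True,
)

image = Image.open("photo.png")  # placeholder image path
prompt = {
    # Phi-4-multimodal chat format with an image placeholder.
    "prompt": "<|user|><|image_1|>Describe this image.<|end|><|assistant|>",
    "multi_modal_data": {"image": [image]},  # assumed multimodal prompt field
}
params = SamplingParams(max_tokens=64)

# Both calls share the same image + text prefix; with reuse enabled the
# second request can skip prefill for the blocks cached by the first.
first = llm.generate(prompt, params)
second = llm.generate(prompt, params)
```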


@@ -19,8 +19,8 @@ from PIL import Image
from tensorrt_llm.inputs.multimodal import MultimodalParams
from ...executor.request import LoRARequest
from ...inputs import (ExtraProcessedInputs, InputProcessor,
MultimodalPlaceholderMetadata,
from ...inputs import (BaseMultimodalInputProcessor, ExtraProcessedInputs,
InputProcessor, MultimodalPlaceholderMetadata,
MultimodalPlaceholderPlacement, TextPrompt,
register_input_processor)
from ...logger import logger
@@ -29,7 +29,8 @@ from ...sampling_params import SamplingParams
from ..attention_backend import AttentionMetadata
from ..model_config import ModelConfig
from .modeling_auto import AutoModelForCausalLM
from .modeling_multimodal_utils import fuse_input_embeds
from .modeling_multimodal_utils import (find_uncached_mm_embeds,
fuse_input_embeds)
from .modeling_utils import register_auto_model
# Special token ids from the original Phi-4-multimodal-instruct implementation
@@ -389,7 +390,7 @@ class HFPhi4MultimodalEncoder(transformers.PreTrainedModel,
return self._encoding_batch_request(multimodal_params, mm_token_ids)
class Phi4MMInputProcessor(InputProcessor):
class Phi4MMInputProcessor(BaseMultimodalInputProcessor, InputProcessor):
def __init__(self,
model_path: str,
@@ -415,6 +416,20 @@ class Phi4MMInputProcessor(InputProcessor):
trust_remote_code=trust_remote_code,
use_fast=self.use_fast)
def get_mm_token_ids(self) -> Optional[torch.Tensor]:
return torch.tensor([_IMAGE_SPECIAL_TOKEN_ID, _AUDIO_SPECIAL_TOKEN_ID],
dtype=torch.int32,
device=self.device)
def get_num_tokens_per_image(
self,
*,
image: Image.Image,
**kwargs,
):
data = self.processor.image_processor.preprocess(image)
return data["num_img_tokens"][0]
@torch.inference_mode()
def __call__(
self, inputs: TextPrompt, sampling_params: SamplingParams
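
The new `get_mm_token_ids` hook reports which special token ids mark image/audio positions (presumably so cached blocks containing multimodal tokens can be keyed correctly), and `get_num_tokens_per_image` defers to the Hugging Face image processor, which already reports how many placeholder tokens an image expands to. A standalone sketch of that lookup, assuming the `microsoft/Phi-4-multimodal-instruct` processor:

```python
import transformers
from PIL import Image

# Assumption: the Phi-4-multimodal processor exposes an image_processor whose
# preprocess() output carries "num_img_tokens", as relied on in the diff above.
processor = transformers.AutoProcessor.from_pretrained(
    "microsoft/Phi-4-multimodal-instruct", trust_remote_code=True)

image = Image.open("photo.png")  # placeholder path
data = processor.image_processor.preprocess(image)
print("image expands to", data["num_img_tokens"][0], "placeholder tokens")
```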
@@ -589,6 +604,9 @@ class Phi4MMForCausalLM(transformers.PreTrainedModel):
multimodal_param.multimodal_data["multimodal_embedding"]
for multimodal_param in multimodal_params
]
mm_embedding = find_uncached_mm_embeds(
mm_embedding, multimodal_params[:num_context_requests])
input_ids, input_embeds = fuse_input_embeds(
self.llm.model.embed_tokens,
input_ids,
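
For a context request with a partial cache hit, only the multimodal embeddings that fall outside the reused prefix still need to be fused into the input embeddings; as the name suggests, `find_uncached_mm_embeds` trims each request's embedding down to that remainder before `fuse_input_embeds` runs. A toy sketch of the idea (the real helper derives the cached-token counts from the request metadata rather than taking them as an argument):

```python
from typing import List
import torch

def slice_uncached_mm_embeds(mm_embeds: List[torch.Tensor],
                             cached_mm_tokens: List[int]) -> List[torch.Tensor]:
    """Toy stand-in for the idea behind find_uncached_mm_embeds: drop the
    leading rows already covered by reused KV cache blocks and keep only the
    part that still needs to be prefilled and fused."""
    return [emb[n:] for emb, n in zip(mm_embeds, cached_mm_tokens)]

# Example: 16 of 32 image tokens of a request hit the cache.
embeds = [torch.randn(32, 3072)]  # hypothetical hidden size
print(slice_uncached_mm_embeds(embeds, [16])[0].shape)  # torch.Size([16, 3072])
```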