Mirror of https://github.com/NVIDIA/TensorRT-LLM.git (synced 2026-01-14 06:27:45 +08:00)
[TRTLLM-7918][feat] Support kvcache reuse for phi4mm (#7563)
Signed-off-by: Wanli Jiang <35160485+Wanli-Jiang@users.noreply.github.com>
parent 335c007df8
commit fc9f4c9295
@@ -45,13 +45,13 @@ Note: Support for other models may vary. Features marked "N/A" are not applicable.
 | Model Architecture/Feature | Overlap Scheduler | CUDA Graph | Chunked Prefill | Torch Sampler | TLLM C++ Sampler | KV Cache Reuse | Logits Post Processor | EPD Disaggregated Serving | Modality |
 | ---------------------------------- | ----------------- | ---------- | --------------- | ------------- | ---------------- | -------------- | --------------------- | ------------------------- | -------- |
 | Gemma3ForConditionalGeneration | Yes | Yes | N/A | Yes | Yes | N/A | Yes | No | L + I |
 | HCXVisionForCausalLM | Yes | Yes | No | Yes | Yes | No | Yes | No | L + I |
 | LlavaLlamaModel (VILA) | Yes | Yes | No | Yes | Yes | No | Yes | No | L + I + V |
 | LlavaNextForConditionalGeneration | Yes | Yes | No | Yes | Yes | No | Yes | No | L + I |
 | Llama4ForConditionalGeneration | Yes | Yes | No | Yes | Yes | No | Yes | No | L + I |
 | Mistral3ForConditionalGeneration | Yes | Yes | No | Yes | Yes | No | Yes | No | L + I |
-| Phi4MMForCausalLM | Yes | Yes | No | Yes | Yes | No | Yes | No | L + I + A |
+| Phi4MMForCausalLM | Yes | Yes | No | Yes | Yes | Yes | Yes | No | L + I + A |
 | Qwen2VLForConditionalGeneration | Yes | Yes | No | Yes | Yes | Yes | Yes | No | L + I + V |
 | Qwen2_5_VLForConditionalGeneration | Yes | Yes | No | Yes | Yes | Yes | Yes | No | L + I + V |
@@ -8,6 +8,6 @@
 | LLaVA-NeXT | Yes | Yes | Yes | Yes |
 | Llama 4 | Yes | Yes | No | No |
 | Mistral-Small-3.1 | Yes | Yes | No | No |
-| Phi-4-multimodal | Yes | Yes | No | No |
+| Phi-4-multimodal | Yes | Yes | Yes | No |
 | Qwen2-VL | Yes | Yes | Yes | Yes |
 | Qwen2.5-VL | Yes | Yes | Yes | Yes |
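For context, block-level KV cache reuse is toggled through the LLM API's `KvCacheConfig`. A minimal sketch of running Phi-4-multimodal with reuse enabled; the checkpoint path is a placeholder and the multimodal input plumbing is omitted for brevity:

```python
# Sketch: enable KV cache block reuse for Phi-4-multimodal via the LLM API.
# The model path is illustrative; enable_block_reuse is the knob that the
# support added by this commit responds to.
from tensorrt_llm import LLM, SamplingParams
from tensorrt_llm.llmapi import KvCacheConfig

llm = LLM(
    model="microsoft/Phi-4-multimodal-instruct",  # placeholder checkpoint
    kv_cache_config=KvCacheConfig(enable_block_reuse=True),
)
outputs = llm.generate(["Summarize the attached image."],
                       sampling_params=SamplingParams(max_tokens=64))
print(outputs[0].outputs[0].text)
```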
@@ -19,8 +19,8 @@ from PIL import Image
 from tensorrt_llm.inputs.multimodal import MultimodalParams

 from ...executor.request import LoRARequest
-from ...inputs import (ExtraProcessedInputs, InputProcessor,
-                       MultimodalPlaceholderMetadata,
+from ...inputs import (BaseMultimodalInputProcessor, ExtraProcessedInputs,
+                       InputProcessor, MultimodalPlaceholderMetadata,
                        MultimodalPlaceholderPlacement, TextPrompt,
                        register_input_processor)
 from ...logger import logger
@@ -29,7 +29,8 @@ from ...sampling_params import SamplingParams
 from ..attention_backend import AttentionMetadata
 from ..model_config import ModelConfig
 from .modeling_auto import AutoModelForCausalLM
-from .modeling_multimodal_utils import fuse_input_embeds
+from .modeling_multimodal_utils import (find_uncached_mm_embeds,
+                                        fuse_input_embeds)
 from .modeling_utils import register_auto_model

 # Special token ids from the original Phi-4-multimodal-instruct implementation
@@ -389,7 +390,7 @@ class HFPhi4MultimodalEncoder(transformers.PreTrainedModel,
         return self._encoding_batch_request(multimodal_params, mm_token_ids)


-class Phi4MMInputProcessor(InputProcessor):
+class Phi4MMInputProcessor(BaseMultimodalInputProcessor, InputProcessor):

     def __init__(self,
                  model_path: str,
@@ -415,6 +416,20 @@ class Phi4MMInputProcessor(InputProcessor):
             trust_remote_code=trust_remote_code,
             use_fast=self.use_fast)

+    def get_mm_token_ids(self) -> Optional[torch.Tensor]:
+        return torch.tensor([_IMAGE_SPECIAL_TOKEN_ID, _AUDIO_SPECIAL_TOKEN_ID],
+                            dtype=torch.int32,
+                            device=self.device)
+
+    def get_num_tokens_per_image(
+        self,
+        *,
+        image: Image.Image,
+        **kwargs,
+    ):
+        data = self.processor.image_processor.preprocess(image)
+        return data["num_img_tokens"][0]
+
     @torch.inference_mode()
     def __call__(
         self, inputs: TextPrompt, sampling_params: SamplingParams
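The two new hooks feed the reuse machinery: `get_mm_token_ids` identifies which token ids are multimodal placeholders, and `get_num_tokens_per_image` reports how many placeholder tokens an image expands to, presumably so cache blocks covering multimodal content can be keyed correctly. A rough standalone sketch of the second hook's call chain, assuming the Hugging Face `AutoProcessor` for this model (the checkpoint and image paths are illustrative):

```python
# Sketch, assuming the Phi-4-multimodal HF processor is available locally;
# mirrors the call chain get_num_tokens_per_image uses in the diff above.
from PIL import Image
from transformers import AutoProcessor

proc = AutoProcessor.from_pretrained("microsoft/Phi-4-multimodal-instruct",
                                     trust_remote_code=True)
image = Image.open("example.jpg")  # placeholder path

# The image processor reports how many placeholder tokens this image
# expands to in the prompt.
data = proc.image_processor.preprocess(image)
print(data["num_img_tokens"][0])
```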
@ -589,6 +604,9 @@ class Phi4MMForCausalLM(transformers.PreTrainedModel):
|
||||
multimodal_param.multimodal_data["multimodal_embedding"]
|
||||
for multimodal_param in multimodal_params
|
||||
]
|
||||
mm_embedding = find_uncached_mm_embeds(
|
||||
mm_embedding, multimodal_params[:num_context_requests])
|
||||
|
||||
input_ids, input_embeds = fuse_input_embeds(
|
||||
self.llm.model.embed_tokens,
|
||||
input_ids,
|
||||
|
||||
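`find_uncached_mm_embeds` is what makes reuse correct here: when part of a context request's prompt is already covered by reused cache blocks, only the embeddings for the uncached tail should be passed on to `fuse_input_embeds`. A simplified illustration of that slicing idea; this is not the library's implementation, and the function name, inputs, and slicing rule below are assumptions:

```python
import torch
from typing import List

def find_uncached_suffix(mm_embeds: List[torch.Tensor],
                         num_cached_mm_tokens: List[int]) -> List[torch.Tensor]:
    # Illustrative only: drop the leading rows of each request's multimodal
    # embedding that correspond to tokens already served from reused KV cache
    # blocks, keeping just the uncached suffix that still needs prefill.
    return [
        embed[cached:] for embed, cached in zip(mm_embeds, num_cached_mm_tokens)
    ]

# Example: 10 image tokens, 6 already covered by a reused cache prefix.
embeds = [torch.randn(10, 3072)]
print(find_uncached_suffix(embeds, [6])[0].shape)  # torch.Size([4, 3072])
```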