[TRTLLM-9409][feat] Pass MRoPE tensors for EPD disagg (#9758)

* Why?

Certain VLMs, such as the Qwen family, need more than just the multimodal
embeddings in the language model: they also need MRoPE position IDs and
deltas. Prior to this commit, only the embeddings could be communicated
from the encoder worker to the prefill worker.

* What?

This commit extends `DisaggregatedParams` to include the MRoPE
information, and adjusts the code paths needed to communicate it between
the E, P, and D workers.

Closes TRTLLM-9409.
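For context, a minimal, self-contained sketch of the new hand-off. The stand-in dataclass below mirrors only the `DisaggregatedParams` fields touched by this change; the handle dicts stand in for what `SharedTensorContainer.dump_to_dict()` produces, and the shapes/values are placeholders, not real data.

```python
from dataclasses import dataclass
from typing import Any, Dict, List, Optional

# Stand-in for tensorrt_llm's DisaggregatedParams, reduced to the fields
# relevant to this change (the real class has more fields).
@dataclass
class DisaggregatedParams:
    request_type: Optional[str] = None
    multimodal_embedding_handles: Optional[List[Dict[str, Any]]] = None
    mrope_position_ids_handle: Optional[Dict[str, Any]] = None      # new
    mrope_position_deltas_handle: Optional[Dict[str, Any]] = None   # new

# Encoder (E) worker output -> prefill (P) worker input: the MRoPE handles
# now ride along with the multimodal embedding handles.
ep_params = DisaggregatedParams(
    request_type="context_only",
    multimodal_embedding_handles=[{"tensor_size": [1296, 3584]}],   # placeholder handle
    mrope_position_ids_handle={"tensor_size": [3, 1, 1296]},        # placeholder handle
    mrope_position_deltas_handle={"tensor_size": [1, 1]},           # placeholder handle
)

# Prefill (P) worker: the handles are unpacked into the request's mrope_config
# (see the BaseLLM changes below).
mrope_config = {
    "mrope_position_ids": ep_params.mrope_position_ids_handle,
    "mrope_position_deltas": ep_params.mrope_position_deltas_handle,
}
```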

Signed-off-by: William Zhang <133824995+2ez4bz@users.noreply.github.com>
William Zhang 2025-12-22 03:32:49 -08:00 committed by GitHub
parent 472fe497dc
commit a6a88985cf
10 changed files with 271 additions and 97 deletions

View File

@ -527,6 +527,8 @@ class LlavaNextModel(PreTrainedModel):
return
if not DISAGG:
self.mm_encoder = LlavaNextVisionModel(model_config)
else:
self.mm_encoder = None
llm_model_config = copy.deepcopy(model_config)
llm_model_config.pretrained_config = model_config.pretrained_config.text_config
@ -545,7 +547,8 @@ class LlavaNextModel(PreTrainedModel):
if isinstance(weight_mapper, LlavaNextHfWeightMapper):
weights = weight_mapper.preprocess_weights(weights)
self.mm_encoder.load_weights(weights)
if self.mm_encoder is not None:
self.mm_encoder.load_weights(weights)
def filter_weights(weights: Dict):
transformed_weights = {}

View File

@ -32,7 +32,8 @@ from ...inputs import (BaseMultimodalDummyInputsBuilder,
BaseMultimodalInputProcessor, ExtraProcessedInputs,
MultimodalPlaceholderMetadata,
MultimodalPlaceholderPlacement, TextPrompt,
register_input_processor)
register_input_processor,
support_multimodal_disaggregated)
from ...logger import logger
from ...sampling_params import SamplingParams
from ..attention_backend import AttentionMetadata
@ -865,6 +866,8 @@ class Qwen2VLModelBase(PreTrainedModel):
mm_encoder_config = copy.deepcopy(model_config)
self.mm_encoder = Qwen2VisionModelBase(
mm_encoder_config, kwargs.get('vision_model_class', None))
else:
self.mm_encoder = None
def init_mrope_embedding(self, model_config: ModelConfig[PretrainedConfig]):
config = model_config.pretrained_config
@ -953,24 +956,21 @@ class Qwen2VLModelBase(PreTrainedModel):
"""
VLM forward logic with inflight batching support.
"""
num_context_requests, num_generation_requests = attn_metadata.num_contexts, attn_metadata.num_generations
num_context_requests = attn_metadata.num_contexts
multimodal_params = kwargs.get("multimodal_params", [])
mm_embeds = []
mrope_config = {}
# NOTE: Qwen*-VL series has mrope_config even on the text-only prompts, so we need to separate the mm_multimodal_params from the text-only prompts.
mm_multimodal_params = [
multimodal_param for multimodal_param in multimodal_params
if multimodal_param.multimodal_data.get("image", {}).get(
"pixel_values") is not None or multimodal_param.multimodal_data.
get("video", {}).get("pixel_values_videos") is not None
]
# NOTE: Qwen*-VL series has mrope_config even on the text-only prompts, so we need to separate
# the entries that do have multimodal data from those that correspond to text-only prompts.
mm_multimodal_params = self._get_requests_with_mm_data(
multimodal_params)
if len(mm_multimodal_params) > 0:
if not _is_disagg():
mm_embeds = get_multimodal_embeddings(
encoder_forward_fn=self.mm_encoder.forward,
multimodal_params=mm_multimodal_params)
else:
elif not getattr(self, "support_mm_disagg", False):
raise NotImplementedError(
"Qwen2VLModel does not support disaggregated inference yet. Please unset "
f"the TLLM_MULTIMODAL_DISAGGREGATED environment variable, or set it to '0'."
@ -995,6 +995,21 @@ class Qwen2VLModelBase(PreTrainedModel):
logger.debug(f'output shape: {output_prob.shape}')
return output_prob
def _get_requests_with_mm_data(self, multimodal_params):
mm_multimodal_params = []
for multimodal_param in multimodal_params:
data = multimodal_param.multimodal_data
if (
# The first 2 conditions check whether there is input on which inference should be run.
data.get("image", {}).get("pixel_values") is not None or
data.get("video", {}).get("pixel_values_videos") is not None
# This condition corresponds to when the embeddings are already populated, as is e.g.
# the case in EPD disagg in the prefill worker.
or data.get("multimodal_embedding")):
mm_multimodal_params.append(multimodal_param)
return mm_multimodal_params
@register_vision_encoder(Qwen2VisionModelBase,
vlm_base_model=Qwen2VisionTransformerPretrainedModel)
@ -1032,11 +1047,89 @@ class Qwen2VLModel(Qwen2VLModelBase):
self.llm.load_weights(weights, weight_mapper)
class Qwen2_5VLInputProcessorBase(Qwen2VLInputProcessorBase):
def get_prompt_token_ids(
self, inputs: TextPrompt,
mm_handles: List[Dict[str,
Any]]) -> Tuple[List[int], List[int], List[int]]:
"""
Build input token ids with multimodal placeholders expanded to the number of MM tokens.
Args:
inputs: Text prompt input container. Must contain a non-empty prompt string.
mm_handles: List of multimodal embedding handles. Currently only a single handle is supported.
Returns:
Tuple[List[int], List[int], List[int]]:
- expanded_ids: token ids with each image token expanded to a placeholder repeated per MM token
- mm_token_length: per-image MM token lengths
- mm_token_offsets: start offsets (positions) for each image's MM tokens within expanded_ids
"""
# TODO: Move this function to the base input processor class when extending for more models
text_prompt = inputs.get("prompt")
if not text_prompt:
raise ValueError("Text prompt is required but not provided")
if not isinstance(mm_handles, list):
raise TypeError("mm_handles must be a list")
if len(mm_handles) != 1:
# TODO: only a single multimodal item per request is supported for now
raise NotImplementedError(
"Only one mm_handle is supported for Qwen2.5 VL for now")
hidden_size = mm_handles[0]['tensor_size'][1]
assert hidden_size == self.config.text_config.hidden_size, "Multimodal embedding hidden size must match model hidden size"
input_ids = self.tokenizer(text_prompt,
return_tensors="pt").input_ids[0]
image_token_index = self.config.image_token_id
image_mask = input_ids == image_token_index
image_positions = torch.where(image_mask)[0]
num_images = len(image_positions)
assert num_images == len(
mm_handles), "Number of images must match number of mm_handles"
total_mm_tokens = sum(mm_handle["tensor_size"][0]
for mm_handle in mm_handles)
final_length = len(input_ids) - num_images + total_mm_tokens
# Create output tensor
expanded_ids = torch.empty(final_length, dtype=input_ids.dtype)
placeholder_id = self.tllm_multimodal_token_id
# Fill the expanded sequence
write_pos = 0
image_cnt = 0
mm_token_length = []
mm_token_offsets = []
for read_pos in range(len(input_ids)):
if input_ids[read_pos] == image_token_index:
# Replace with placeholder id
mm_token_num = mm_handles[image_cnt]["tensor_size"][0]
expanded_ids[write_pos:write_pos + mm_token_num] = \
placeholder_id
mm_token_offsets.append(write_pos)
mm_token_length.append(mm_token_num)
write_pos += mm_token_num
image_cnt += 1
else:
# Copy text token as-is
expanded_ids[write_pos] = input_ids[read_pos]
write_pos += 1
assert write_pos == final_length, f"Write position mismatch: {write_pos} != {final_length}"
assert mm_token_length[-1] + mm_token_offsets[
-1] <= final_length, f"mm_token_length[-1] + mm_token_offsets[-1] ({mm_token_length[-1] + mm_token_offsets[-1]}) should be less than or equal to final_length ({final_length})"
return expanded_ids.to(
torch.int32).tolist(), mm_token_length, mm_token_offsets
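For illustration, a small standalone walk-through of the expansion performed by `get_prompt_token_ids`. The token ids, image token id, and placeholder id below are made up; the real method takes them from the tokenizer, the model config, and `tllm_multimodal_token_id`.

```python
import torch

# Tokenized prompt: [BOS, <image>, t1, t2]; the single handle reports
# tensor_size == [4, hidden], i.e. 4 multimodal tokens for that image.
input_ids = torch.tensor([101, 151655, 17, 23])
image_token_id = 151655   # example value
placeholder_id = -1       # stand-in for tllm_multimodal_token_id
mm_token_num = 4          # mm_handles[0]["tensor_size"][0]

expanded, mm_token_offsets, mm_token_length = [], [], []
for tok in input_ids.tolist():
    if tok == image_token_id:
        # Image token: record offset/length and expand to mm_token_num placeholders.
        mm_token_offsets.append(len(expanded))
        mm_token_length.append(mm_token_num)
        expanded.extend([placeholder_id] * mm_token_num)
    else:
        # Text token: copied as-is.
        expanded.append(tok)

# expanded          == [101, -1, -1, -1, -1, 17, 23]
# mm_token_length   == [4]
# mm_token_offsets  == [1]
```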
@support_multimodal_disaggregated
@register_vision_encoder(Qwen2VisionModelBase,
vlm_base_model=Qwen2_5_VisionModel)
@register_auto_model("Qwen2_5_VLForConditionalGeneration")
@register_input_processor(
Qwen2VLInputProcessorBase,
Qwen2_5VLInputProcessorBase,
model_type="qwen2_5_vl",
placeholder_metadata=MultimodalPlaceholderMetadata(
placeholder_map={

View File

@ -262,6 +262,8 @@ class PyResult:
chunk_size=self._chunk_size) if return_generation_logits else None
self._log_probs = LogProbStorage() if return_log_probs else None
self._mm_embeddings = None
self._mrope_position_ids = None
self._mrope_position_deltas = None
self._additional_context_outputs = {
name: []
for name in additional_outputs
@ -293,6 +295,16 @@ class PyResult:
self._mm_embeddings = SharedTensorContainer.from_tensor(
mm_embeddings).dump_to_dict()
def set_mrope_position(
self,
mrope_position_ids: torch.Tensor,
mrope_position_deltas: torch.Tensor,
):
self._mrope_position_ids = (SharedTensorContainer.from_tensor(
mrope_position_ids).dump_to_dict())
self._mrope_position_deltas = (SharedTensorContainer.from_tensor(
mrope_position_deltas).dump_to_dict())
def transfer_remaining_device_logits(self):
"""Finalize any remaining generation logits transfers (for chunked mode)"""
if self._generation_logits:
@ -352,6 +364,18 @@ class PyResult:
def mm_embedding_handle(self) -> Dict[str, Any] | None:
return self._mm_embeddings
@property
def mrope_position_ids_handle(self) -> Dict[str, Any] | None:
# NOTE: when populated, the returned `dict` contains the information necessary to rebuild
# the `SharedTensorContainer` using the `from_dict` class method.
return self._mrope_position_ids
@property
def mrope_position_deltas_handle(self) -> Dict[str, Any] | None:
# NOTE: when populated, the returned `dict` contains the information necessary to rebuild
# the `SharedTensorContainer` using the `from_dict` class method.
return self._mrope_position_deltas
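As the notes above indicate, each handle is simply a serialized `SharedTensorContainer`. A consumer (e.g. the prefill worker) would rebuild the tensors roughly as sketched below; the import path and the `get_local_view()` accessor are assumptions about the `SharedTensorContainer` API, not something introduced by this change.

```python
# Assumed import path for SharedTensorContainer.
from tensorrt_llm._torch.shared_tensor import SharedTensorContainer

def rebuild_mrope_tensors(ids_handle, deltas_handle):
    """Turn the serialized MRoPE handles back into torch.Tensors (sketch)."""
    if ids_handle is None or deltas_handle is None:
        return None, None
    mrope_position_ids = SharedTensorContainer.from_dict(ids_handle).get_local_view()
    mrope_position_deltas = SharedTensorContainer.from_dict(deltas_handle).get_local_view()
    return mrope_position_ids, mrope_position_deltas
```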
@property
def additional_context_outputs(self) -> Dict[str, torch.Tensor] | None:
if self._additional_context_outputs is None:
@ -382,7 +406,8 @@ class LlmResult:
py_result_properties = frozenset(
('context_logits', 'generation_logits', 'log_probs', 'cum_log_probs',
'mm_embedding_handle', 'additional_context_outputs',
'additional_generation_outputs'))
'additional_generation_outputs', 'mrope_position_ids_handle',
'mrope_position_deltas_handle'))
def __init__(self,
result: Union[bytes, tensorrt_llm.bindings.executor.Result],

View File

@ -2213,13 +2213,14 @@ class PyTorchModelEngine(ModelEngine):
mrope_position_deltas).expand(
3, 1, 1)
mrope_position_ids.append(gen_mrope_position_ids)
multimodal_params.to_device(
"multimodal_data",
"cuda",
pin_memory=True,
target_keywords=[
"mrope_config.mrope_position_deltas"
])
if mrope_position_deltas.device.type == "cpu":
multimodal_params.to_device(
"multimodal_data",
"cuda",
pin_memory=True,
target_keywords=[
"mrope_config.mrope_position_deltas"
])
multimodal_params_list.append(multimodal_params)
request.py_batch_idx = request.py_seq_slot
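The device-type check added above (and the similar one in the next hunk) guards against pinning tensors that are already on the GPU, which presumably happens when the MRoPE tensors are rebuilt from shared-tensor handles in EPD disagg rather than produced on the host; `Tensor.pin_memory()` only applies to CPU tensors. A minimal illustration:

```python
import torch

x_cpu = torch.arange(8)
x_pinned = x_cpu.pin_memory()        # valid: pinning is for CPU tensors
assert x_pinned.is_pinned()

if torch.cuda.is_available():
    x_gpu = x_cpu.cuda()
    # x_gpu.pin_memory() would raise a RuntimeError, hence the
    # `device.type == "cpu"` checks before pinning in the engine code.
    assert x_gpu.device.type == "cuda"
```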
@ -2448,8 +2449,9 @@ class PyTorchModelEngine(ModelEngine):
# NOTE: self.use_mrope alone would be enough to decide whether to use mrope_position_ids, but
# `_create_dummy_context_requests` from `kv_cache_creater` cannot attach multimodal_data to the
# dummy request, so we only replace position_ids with mrope_position_ids for non-dummy requests
# on models that use mrope.
mrope_position_ids = torch.cat(mrope_position_ids,
dim=-1).pin_memory()
mrope_position_ids = torch.cat(mrope_position_ids, dim=-1)
if mrope_position_ids.device.type == "cpu":
mrope_position_ids = mrope_position_ids.pin_memory()
self.mrope_position_ids_cuda[:, :, :total_num_tokens].copy_(
mrope_position_ids[:, :, :total_num_tokens], non_blocking=True)
final_position_ids = self.mrope_position_ids_cuda[:, :, :
@ -3362,7 +3364,26 @@ class PyTorchModelEngine(ModelEngine):
mm_embeddings = list(
torch.split(mm_embeddings[0], multimodal_chunks, dim=0))
return {'mm_embeddings': mm_embeddings, 'logits': None}
# Extract mrope position data from multimodal_params if available
mrope_position_ids_list = []
mrope_position_deltas_list = []
for multimodal_param in multimodal_params:
mrope_config = multimodal_param.multimodal_data.get(
'mrope_config', {})
mrope_position_ids = mrope_config.get('mrope_position_ids')
mrope_position_deltas = mrope_config.get('mrope_position_deltas')
if mrope_position_ids is not None:
mrope_position_ids_list.append(mrope_position_ids)
if mrope_position_deltas is not None:
mrope_position_deltas_list.append(mrope_position_deltas)
result = {'mm_embeddings': mm_embeddings, 'logits': None}
if mrope_position_ids_list:
result['mrope_position_ids'] = mrope_position_ids_list
if mrope_position_deltas_list:
result['mrope_position_deltas'] = mrope_position_deltas_list
return result
def _init_userbuffers(self, hidden_size):
if self.mapping.tp_size <= 1 or self.mapping.pp_size > 1:

View File

@ -21,7 +21,7 @@ from concurrent import futures
from dataclasses import dataclass
from functools import cached_property
from itertools import repeat
from typing import Any, Callable, Generic, List, Optional, Type, TypeVar, cast
from typing import Any, Callable, Dict, Generic, List, Optional, Type, TypeVar, cast
import numpy as np
import torch
@ -199,6 +199,8 @@ class EarlyStopSampler(Sampler):
@dataclass(kw_only=True)
class MultimodalResult:
mm_embeddings: List[torch.Tensor]
# Can be used to carry extra outputs such as `mrope_position_ids`.
extra_data: Optional[Dict[str, Any]] = None
def values(self):
return vars(self).values()
@ -262,7 +264,10 @@ class EarlyStopWithMMResult(Sampler):
resource_manager: Optional[ResourceManager] = None,
) -> SampleStateWithMMResult:
# from model_outputs to MultimodalResult
data = MultimodalResult(mm_embeddings=model_outputs["mm_embeddings"])
data = MultimodalResult(
mm_embeddings=model_outputs.pop("mm_embeddings"),
extra_data={**model_outputs},
)
return SampleStateWithMMResult(scheduled_requests=scheduled_requests, data=data)
@override
@ -276,7 +281,12 @@ class EarlyStopWithMMResult(Sampler):
scheduled_requests = state.scheduled_requests
assert not scheduled_requests.generation_requests
mm_embeddings = state.data.mm_embeddings
for request, mm_embedding in zip(scheduled_requests.context_requests, mm_embeddings):
extra_data = state.data.extra_data or {}
mrope_position_ids = extra_data.get("mrope_position_ids", None)
mrope_position_deltas = extra_data.get("mrope_position_deltas", None)
for i, (request, mm_embedding) in enumerate(
zip(scheduled_requests.context_requests, mm_embeddings)
):
request.state = LlmRequestState.GENERATION_COMPLETE
# NOTE: This is a hack: manually set the finish reason for beam 0
request.set_finished_reason(FinishReason.LENGTH, 0)
@ -287,6 +297,12 @@ class EarlyStopWithMMResult(Sampler):
request.py_result.append_mm_embeddings(mm_embedding)
# Store mrope data if available
if mrope_position_ids is not None and mrope_position_deltas is not None:
request.py_result.set_mrope_position(
mrope_position_ids[i], mrope_position_deltas[i]
)
@override
def is_generation_model(self) -> bool:
return False

View File

@ -40,6 +40,8 @@ class DisaggregatedParams:
multimodal_hashes: Optional[List[List[int]]] = (
None # user provided mm hashes should be a list of 8 integers
)
mrope_position_ids_handle: Optional[Dict[str, Any]] = None
mrope_position_deltas_handle: Optional[Dict[str, Any]] = None
def get_context_phase_params(self) -> tllme.ContextPhaseParams:
return tllme.ContextPhaseParams(

View File

@ -1,4 +1,5 @@
import asyncio
import dataclasses
import json
import time
import weakref
@ -415,12 +416,19 @@ class GenerationResultBase:
self.cached_tokens = getattr(response_result, 'cached_tokens', 0)
self.avg_decoded_tokens_per_iter = response_result.avg_decoded_tokens_per_iter
if context_phase_params is not None:
self.disaggregated_params = DisaggregatedParams(
existing_disagg_params = self.disaggregated_params
# Use `replace` to preserve things like `mrope_position_ids_handle` and
# `mrope_position_deltas_handle`. However, explicitly set
# `multimodal_embedding_handles=None` since they should no longer be needed.
self.disaggregated_params = dataclasses.replace(
existing_disagg_params or DisaggregatedParams(),
request_type="context_only",
first_gen_tokens=context_phase_params.first_gen_tokens,
ctx_request_id=context_phase_params.req_id,
opaque_state=context_phase_params.opaque_state,
draft_tokens=context_phase_params.draft_tokens)
draft_tokens=context_phase_params.draft_tokens,
multimodal_embedding_handles=None,
)
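The switch to `dataclasses.replace` matters because fields not listed in the call keep their current values, so handles stored earlier (including the new MRoPE ones) survive while the embedding handles are explicitly dropped. A toy illustration with a simplified dataclass:

```python
import dataclasses
from typing import Any, Dict, List, Optional

@dataclasses.dataclass
class Params:
    request_type: Optional[str] = None
    multimodal_embedding_handles: Optional[List[Dict[str, Any]]] = None
    mrope_position_ids_handle: Optional[Dict[str, Any]] = None

old = Params(multimodal_embedding_handles=[{"h": 1}],
             mrope_position_ids_handle={"h": 2})
new = dataclasses.replace(old,
                          request_type="context_only",
                          multimodal_embedding_handles=None)
assert new.mrope_position_ids_handle == {"h": 2}   # preserved implicitly
assert new.multimodal_embedding_handles is None    # explicitly cleared
```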
finish_reasons = response_result.finish_reasons
# output_token_ids = (beams, tokens)
@ -440,6 +448,8 @@ class GenerationResultBase:
if hasattr(response_result, 'mm_embedding_handle'
) and response_result.mm_embedding_handle is not None:
self._mm_embedding_handle = response_result.mm_embedding_handle
mrope_position_ids_handle = response_result.mrope_position_ids_handle
mrope_position_deltas_handle = response_result.mrope_position_deltas_handle
if self.disaggregated_params is not None:
self.disaggregated_params.multimodal_embedding_handles = [
response_result.mm_embedding_handle
@ -451,6 +461,8 @@ class GenerationResultBase:
response_result.mm_embedding_handle
],
multimodal_hashes=self._multimodal_hashes)
self.disaggregated_params.mrope_position_ids_handle = mrope_position_ids_handle
self.disaggregated_params.mrope_position_deltas_handle = mrope_position_deltas_handle
# Process background errors here ASAP during generation.
if self._background_error_handler and (
@ -811,8 +823,12 @@ class GenerationResult(GenerationResultBase):
def _repr_fields(self):
return [
'request_id', 'prompt_token_ids', 'outputs', 'finished',
"context_logits", "mm_embedding_handle"
'request_id',
'prompt_token_ids',
'outputs',
'finished',
"context_logits",
"mm_embedding_handle",
]
def __repr__(self) -> str:

View File

@ -89,8 +89,12 @@ class RequestOutput(DetokenizedGenerationResultBase, GenerationResult):
def _repr_fields(self):
return [
"request_id", "prompt", "prompt_token_ids", "outputs", "finished",
"mm_embedding_handle"
"request_id",
"prompt",
"prompt_token_ids",
"outputs",
"finished",
"mm_embedding_handle",
]
@ -419,7 +423,7 @@ class BaseLLM:
multimodal_params = None
if is_mm_disagg:
if not self.input_processor.support_mm_disagg:
if not getattr(self.input_processor, "support_mm_disagg", False):
raise ValueError(
"Multimodal disaggregated inference is not supported for this model"
)
@ -436,14 +440,42 @@ class BaseLLM:
mm_hashes = disaggregated_params.multimodal_hashes
multimodal_input = MultimodalInput.from_components(
mm_hashes, mm_token_positions, mm_token_length)
multimodal_data = {"multimodal_embedding": mm_handles}
if disaggregated_params.mrope_position_ids_handle is not None:
# NOTE: `PyTorchModelEngine` assumes both are present when using mrope.
assert disaggregated_params.mrope_position_deltas_handle is not None
mrope_config = {}
mrope_config[
"mrope_position_ids"] = disaggregated_params.mrope_position_ids_handle
mrope_config[
"mrope_position_deltas"] = disaggregated_params.mrope_position_deltas_handle
multimodal_data["mrope_config"] = mrope_config
multimodal_params = MultimodalParams(
multimodal_input=multimodal_input,
multimodal_data={"multimodal_embedding": mm_handles})
multimodal_data=multimodal_data,
)
elif "prompt_token_ids" in inputs:
prompt_token_ids = inputs['prompt_token_ids']
prompt = None
query_token_ids = inputs.get("query_token_ids", None)
multimodal_data = {}
# NOTE: when running in `generation_only` for disagg, this is the code path we expect to hit.
if disaggregated_params is not None and disaggregated_params.mrope_position_ids_handle is not None:
# NOTE: `PyTorchModelEngine` assumes both are present when using mrope.
if disaggregated_params.mrope_position_deltas_handle is None:
raise RuntimeError(
"`mrope_position_ids_handle` and `mrope_position_deltas_handle` must both "
"be provided, or both `None`.")
mrope_config = {}
mrope_config[
"mrope_position_ids"] = disaggregated_params.mrope_position_ids_handle
mrope_config[
"mrope_position_deltas"] = disaggregated_params.mrope_position_deltas_handle
multimodal_data["mrope_config"] = mrope_config
if multimodal_data:
multimodal_params = MultimodalParams(
multimodal_data=multimodal_data)
elif "prompt" in inputs:
if 'multi_modal_data' in inputs:
# TODO: The current design uses a wrapper for existing input processor (input_processor_with_hash)

View File

@ -101,14 +101,8 @@ class MultimodalEncoder(_TorchLLM):
inputs = [prompt_inputs(i) for i in inputs]
def _item_at(maybe_batched: Union[Any, Sequence[Any]], pos: int) -> Any:
if isinstance(maybe_batched, list):
return maybe_batched[pos]
else:
return maybe_batched
futures = []
for i, request_inputs in enumerate(inputs):
for request_inputs in inputs:
future = self.generate_async(request_inputs)
futures.append(future)

View File

@ -19,49 +19,23 @@ example_images = [
str(test_data_root / "61.jpg"),
]
@pytest.fixture(scope="function")
def multimodal_model_config():
"""Get multimodal model configuration similar to integration tests"""
# You can extend this to support multiple models or get from environment
model_configs = {
'llava-v1.6-mistral-7b-hf': {
'model_name':
'llava-v1.6-mistral-7b-hf',
'hf_model_dir':
'llava-hf/llava-v1.6-mistral-7b-hf',
'model_dir':
llm_models_root() / "multimodals" / "llava-v1.6-mistral-7b-hf",
}
}
return model_configs['llava-v1.6-mistral-7b-hf']
_LLAVA_DIR = llm_models_root() / "multimodals" / "llava-v1.6-mistral-7b-hf"
_QWEN_2_5_VL_DIR = llm_models_root() / "Qwen2.5-VL-3B-Instruct"
# TODO: Add multi-image in single chat test
@pytest.mark.parametrize("model_key", [
"llava-v1.6-mistral-7b-hf",
])
@pytest.mark.parametrize("model_dir", [_LLAVA_DIR, _QWEN_2_5_VL_DIR])
@pytest.mark.parametrize("pd_disagg", [False, True])
def test_single_image_chat(model_key, pd_disagg, multimodal_model_config):
def test_single_image_chat(model_dir, pd_disagg):
"""Test processing single image using encoder (pass mm_embeddings) + LLM API.
This test verifies that encoder (pass mm_embeddings) + LLM API produces identical
results to standard llm generation (pass raw image) by comparing outputs.
"""
# Get model configuration
if model_key != "llava-v1.6-mistral-7b-hf":
#TODO: add more model tests progressively here
pytest.skip(
f"Skipping test for {model_key} - only testing llava-v1.6-mistral-7b-hf for now"
)
# Extract model information from config
encoder_model_dir = multimodal_model_config['model_dir']
# Test configuration
max_tokens = 64
free_gpu_memory_fraction = 0.6 if not pd_disagg else 0.2
free_gpu_memory_fraction = 0.2
max_batch_size = 1
# Test data - OpenAI chat completion format
@ -76,15 +50,14 @@ def test_single_image_chat(model_key, pd_disagg, multimodal_model_config):
)
# Process multimodal data using encoder (pass mm_embeddings)
encoder = MultimodalEncoder(model=encoder_model_dir,
max_batch_size=max_batch_size)
encoder = MultimodalEncoder(model=model_dir, max_batch_size=max_batch_size)
cache_transceiver_cfg = CacheTransceiverConfig(
backend="DEFAULT") if pd_disagg else None
disable_overlap_scheduler = pd_disagg
llm = LLM(model=encoder_model_dir,
llm = LLM(model=model_dir,
backend='pytorch',
kv_cache_config=kv_cache_config,
trust_remote_code=True,
@ -93,7 +66,7 @@ def test_single_image_chat(model_key, pd_disagg, multimodal_model_config):
llm_decode = None
if pd_disagg:
llm_decode = LLM(model=encoder_model_dir,
llm_decode = LLM(model=model_dir,
backend='pytorch',
kv_cache_config=kv_cache_config,
trust_remote_code=True,
@ -141,6 +114,7 @@ def test_single_image_chat(model_key, pd_disagg, multimodal_model_config):
assert ep_disaggregated_params is not None, "Encoder output disaggregated params is None"
ep_disaggregated_params.request_type = "context_and_generation" if not pd_disagg else "context_only"
outputs = llm.generate(inputs,
sampling_params=sampling_params,
disaggregated_params=ep_disaggregated_params)
@ -151,10 +125,10 @@ def test_single_image_chat(model_key, pd_disagg, multimodal_model_config):
pd_disaggregated_params = outputs[0].disaggregated_params
pd_disaggregated_params.request_type = "generation_only"
sampling_params = SamplingParams(max_tokens=max_tokens)
inputs[0][
'multi_modal_data'] = None # remove multimodal data from input as decoder worker doesn't need it
inputs[0]['prompt_token_ids'] = outputs[
0].prompt_token_ids # use prompt token ids from encoder output
# remove multimodal data from input as decoder worker doesn't need it
inputs[0]['multi_modal_data'] = None
# use prompt token ids from encoder output
inputs[0]['prompt_token_ids'] = outputs[0].prompt_token_ids
outputs = llm_decode.generate(
inputs,
@ -199,24 +173,23 @@ def test_single_image_chat(model_key, pd_disagg, multimodal_model_config):
f"Log probabilities don't match for output {i}, generation {j}"
@pytest.mark.parametrize("model_key", [
"llava-v1.6-mistral-7b-hf",
])
def test_multi_request_batch_chat(model_key, multimodal_model_config):
@pytest.mark.parametrize(
"model_dir, encoder_max_batch_size",
[
(_LLAVA_DIR, 3),
# Qwen2.5 VL's vision encoder seems to output different embeddings based on this value.
# The test only passes with this set to 1.
(_QWEN_2_5_VL_DIR, 1),
],
)
def test_multi_request_batch_chat(model_dir, encoder_max_batch_size):
"""Test batching multiple multimodal requests and verify encoder path matches raw path.
This mirrors test_single_image_chat but with a batch of size 3.
"""
if model_key != "llava-v1.6-mistral-7b-hf":
pytest.skip(
f"Skipping test for {model_key} - only testing llava-v1.6-mistral-7b-hf for now"
)
encoder_model_dir = multimodal_model_config['model_dir']
max_tokens = 64
free_gpu_memory_fraction = 0.6
max_batch_size = 3
prompts = [
"Describe the natural environment in the image.",
@ -232,10 +205,10 @@ def test_multi_request_batch_chat(model_key, multimodal_model_config):
free_gpu_memory_fraction=free_gpu_memory_fraction,
)
encoder = MultimodalEncoder(model=encoder_model_dir,
max_batch_size=max_batch_size)
encoder = MultimodalEncoder(model=model_dir,
max_batch_size=encoder_max_batch_size)
llm = LLM(
model=encoder_model_dir,
model=model_dir,
backend='pytorch',
kv_cache_config=kv_cache_config,
max_batch_size=1, # fix batch size to reduce non-determinism in tests
@ -305,8 +278,7 @@ def test_multi_request_batch_chat(model_key, multimodal_model_config):
"Describe the weather in the image.",
], 2),
])
def test_kv_event_mm_keys_with_reuse(prompts, expected_num_duplicates,
multimodal_model_config):
def test_kv_event_mm_keys_with_reuse(prompts, expected_num_duplicates):
"""Test mm_keys in KV cache events with cache reuse scenarios.
This test verifies:
@ -316,7 +288,7 @@ def test_kv_event_mm_keys_with_reuse(prompts, expected_num_duplicates,
- Same media + same prompts: full reuse (0 duplicate offsets)
- Same media + different prompts: partial reuse (prefix blocks reused)
"""
encoder_model_dir = multimodal_model_config['model_dir']
encoder_model_dir = _LLAVA_DIR
max_tokens = 16
free_gpu_memory_fraction = 0.6