Mirror of https://github.com/NVIDIA/TensorRT-LLM.git, synced 2026-01-13 22:18:36 +08:00
[TRTLLM-9409][feat] Pass MRoPE tensors for EPD disagg (#9758)
* Why? Certain VLMs, such as the Qwen family, need more than the multimodal embeddings in the language model: they also need the MRoPE position IDs and deltas. Prior to this commit, only the embeddings could be communicated from the encoder worker to the prefill worker.
* What? This commit extends `DisaggregatedParams` to include the MRoPE information and adjusts the pieces of code required to communicate it between the E, P and D workers.

Closes TRTLLM-9409.

Signed-off-by: William Zhang <133824995+2ez4bz@users.noreply.github.com>
This commit is contained in:
parent 472fe497dc
commit a6a88985cf
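For orientation, here is an illustrative Python sketch (not part of the commit) of how the new MRoPE handles are intended to flow, based on the diff below: the encoder worker serializes the tensors through `SharedTensorContainer` and exposes them as handles on `DisaggregatedParams`, and the prefill/decode workers forward those handles as `multimodal_data["mrope_config"]`. The helper names `pack_mrope_handles` and `build_mrope_config` are hypothetical; only the attribute names come from the diff.

# Illustrative sketch (not part of the commit); helper names are hypothetical.
from typing import Any, Dict, Optional


def pack_mrope_handles(py_result) -> Dict[str, Optional[Dict[str, Any]]]:
    # Encoder side: PyResult exposes the serialized SharedTensorContainer dicts
    # via the new *_handle properties added in this commit.
    return {
        "mrope_position_ids_handle": py_result.mrope_position_ids_handle,
        "mrope_position_deltas_handle": py_result.mrope_position_deltas_handle,
    }


def build_mrope_config(disaggregated_params) -> Optional[Dict[str, Any]]:
    # Prefill/decode side: when both handles are present on DisaggregatedParams,
    # they are forwarded as multimodal_data["mrope_config"] (see the BaseLLM changes below).
    ids_handle = disaggregated_params.mrope_position_ids_handle
    deltas_handle = disaggregated_params.mrope_position_deltas_handle
    if ids_handle is None:
        return None
    if deltas_handle is None:
        raise RuntimeError(
            "mrope_position_ids_handle and mrope_position_deltas_handle must both "
            "be provided, or both be None.")
    return {
        "mrope_position_ids": ids_handle,
        "mrope_position_deltas": deltas_handle,
    }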
@@ -527,6 +527,8 @@ class LlavaNextModel(PreTrainedModel):
return
if not DISAGG:
self.mm_encoder = LlavaNextVisionModel(model_config)
else:
self.mm_encoder = None

llm_model_config = copy.deepcopy(model_config)
llm_model_config.pretrained_config = model_config.pretrained_config.text_config
@@ -545,7 +547,8 @@ class LlavaNextModel(PreTrainedModel):
if isinstance(weight_mapper, LlavaNextHfWeightMapper):
weights = weight_mapper.preprocess_weights(weights)

self.mm_encoder.load_weights(weights)
if self.mm_encoder is not None:
self.mm_encoder.load_weights(weights)

def filter_weights(weights: Dict):
transformed_weights = {}

@@ -32,7 +32,8 @@ from ...inputs import (BaseMultimodalDummyInputsBuilder,
BaseMultimodalInputProcessor, ExtraProcessedInputs,
MultimodalPlaceholderMetadata,
MultimodalPlaceholderPlacement, TextPrompt,
register_input_processor)
register_input_processor,
support_multimodal_disaggregated)
from ...logger import logger
from ...sampling_params import SamplingParams
from ..attention_backend import AttentionMetadata
@@ -865,6 +866,8 @@ class Qwen2VLModelBase(PreTrainedModel):
mm_encoder_config = copy.deepcopy(model_config)
self.mm_encoder = Qwen2VisionModelBase(
mm_encoder_config, kwargs.get('vision_model_class', None))
else:
self.mm_encoder = None

def init_mrope_embedding(self, model_config: ModelConfig[PretrainedConfig]):
config = model_config.pretrained_config
@@ -953,24 +956,21 @@ class Qwen2VLModelBase(PreTrainedModel):
"""
VLM forward logic with inflight batching support.
"""
num_context_requests, num_generation_requests = attn_metadata.num_contexts, attn_metadata.num_generations
num_context_requests = attn_metadata.num_contexts

multimodal_params = kwargs.get("multimodal_params", [])
mm_embeds = []
mrope_config = {}
# NOTE: Qwen*-VL series has mrope_config even on the text-only prompts, so we need to separate the mm_multimodal_params from the text-only prompts.
mm_multimodal_params = [
multimodal_param for multimodal_param in multimodal_params
if multimodal_param.multimodal_data.get("image", {}).get(
"pixel_values") is not None or multimodal_param.multimodal_data.
get("video", {}).get("pixel_values_videos") is not None
]
# NOTE: Qwen*-VL series has mrope_config even on the text-only prompts, so we need to separate
# the entries that do have multimodal data from those that correspond to text-only prompts.
mm_multimodal_params = self._get_requests_with_mm_data(
multimodal_params)
if len(mm_multimodal_params) > 0:
if not _is_disagg():
mm_embeds = get_multimodal_embeddings(
encoder_forward_fn=self.mm_encoder.forward,
multimodal_params=mm_multimodal_params)
else:
elif not getattr(self, "support_mm_disagg", False):
raise NotImplementedError(
"Qwen2VLModel does not support disaggregated inference yet. Please unset "
f"the TLLM_MULTIMODAL_DISAGGREGATED environment variable, or set it to '0'."
@@ -995,6 +995,21 @@ class Qwen2VLModelBase(PreTrainedModel):
logger.debug(f'output shape: {output_prob.shape}')
return output_prob

def _get_requests_with_mm_data(self, multimodal_params):
mm_multimodal_params = []
for multimodal_param in multimodal_params:
data = multimodal_param.multimodal_data
if (
# The first 2 conditions check whether there is input on which inference should be run.
data.get("image", {}).get("pixel_values") is not None or
data.get("video", {}).get("pixel_values_videos") is not None
# This condition corresponds to when the embeddings are already populated, as is e.g.
# the case in EPD disagg in the prefill worker.
or data.get("multimodal_embedding")):
mm_multimodal_params.append(multimodal_param)

return mm_multimodal_params


@register_vision_encoder(Qwen2VisionModelBase,
vlm_base_model=Qwen2VisionTransformerPretrainedModel)
@@ -1032,11 +1047,89 @@ class Qwen2VLModel(Qwen2VLModelBase):
self.llm.load_weights(weights, weight_mapper)


class Qwen2_5VLInputProcessorBase(Qwen2VLInputProcessorBase):

def get_prompt_token_ids(
self, inputs: TextPrompt,
mm_handles: List[Dict[str,
Any]]) -> Tuple[List[int], List[int], List[int]]:
"""
Build input token ids with multimodal placeholders expanded to the number of MM tokens.

Args:
inputs: Text prompt input container. Must contain a non-empty prompt string.
mm_handles: List of multimodal embedding handles. Currently only a single handle is supported.

Returns:
Tuple[List[int], List[int], List[int]]:
- expanded_ids: token ids with each image token expanded to a placeholder repeated per MM token
- mm_token_length: per-image MM token lengths
- mm_token_offsets: start offsets (positions) for each image's MM tokens within expanded_ids
"""
# TODO: Move this function to the base input processor class when extending for more models
text_prompt = inputs.get("prompt")
if not text_prompt:
raise ValueError("Text prompt is required but not provided")

if not isinstance(mm_handles, list):
raise TypeError("mm_handles must be a list")

if len(mm_handles) != 1:
# TODO: only support single multimodal item within a request for now
raise NotImplementedError(
"Only one mm_handle is supported for Qwen2.5 VL for now")
hidden_size = mm_handles[0]['tensor_size'][1]
assert hidden_size == self.config.text_config.hidden_size, "Multimodal embedding hidden size must match model hidden size"
input_ids = self.tokenizer(text_prompt,
return_tensors="pt").input_ids[0]

image_token_index = self.config.image_token_id

image_mask = input_ids == image_token_index
image_positions = torch.where(image_mask)[0]
num_images = len(image_positions)
assert num_images == len(
mm_handles), "Number of images must match number of mm_handles"
total_mm_tokens = sum(mm_handle["tensor_size"][0]
for mm_handle in mm_handles)
final_length = len(input_ids) - num_images + total_mm_tokens
# Create output tensor
expanded_ids = torch.empty(final_length, dtype=input_ids.dtype)
placeholder_id = self.tllm_multimodal_token_id

# Fill the expanded sequence
write_pos = 0
image_cnt = 0
mm_token_length = []
mm_token_offsets = []
for read_pos in range(len(input_ids)):
if input_ids[read_pos] == image_token_index:
# Replace with placeholder id
mm_token_num = mm_handles[image_cnt]["tensor_size"][0]
expanded_ids[write_pos:write_pos + mm_token_num] = \
placeholder_id
mm_token_offsets.append(write_pos)
mm_token_length.append(mm_token_num)
write_pos += mm_token_num
image_cnt += 1
else:
# Copy text token as-is
expanded_ids[write_pos] = input_ids[read_pos]
write_pos += 1

assert write_pos == final_length, f"Write position mismatch: {write_pos} != {final_length}"
assert mm_token_length[-1] + mm_token_offsets[
-1] <= final_length, f"mm_token_length[-1] + mm_token_offsets[-1] ({mm_token_length[-1] + mm_token_offsets[-1]}) should be less than or equal to final_length ({final_length})"
return expanded_ids.to(
torch.int32).tolist(), mm_token_length, mm_token_offsets


@support_multimodal_disaggregated
@register_vision_encoder(Qwen2VisionModelBase,
vlm_base_model=Qwen2_5_VisionModel)
@register_auto_model("Qwen2_5_VLForConditionalGeneration")
@register_input_processor(
Qwen2VLInputProcessorBase,
Qwen2_5VLInputProcessorBase,
model_type="qwen2_5_vl",
placeholder_metadata=MultimodalPlaceholderMetadata(
placeholder_map={

@@ -262,6 +262,8 @@ class PyResult:
chunk_size=self._chunk_size) if return_generation_logits else None
self._log_probs = LogProbStorage() if return_log_probs else None
self._mm_embeddings = None
self._mrope_position_ids = None
self._mrope_position_deltas = None
self._additional_context_outputs = {
name: []
for name in additional_outputs
@@ -293,6 +295,16 @@ class PyResult:
self._mm_embeddings = SharedTensorContainer.from_tensor(
mm_embeddings).dump_to_dict()

def set_mrope_position(
self,
mrope_position_ids: torch.Tensor,
mrope_position_deltas: torch.Tensor,
):
self._mrope_position_ids = (SharedTensorContainer.from_tensor(
mrope_position_ids).dump_to_dict())
self._mrope_position_deltas = (SharedTensorContainer.from_tensor(
mrope_position_deltas).dump_to_dict())

def transfer_remaining_device_logits(self):
"""Finalize any remaining generation logits transfers (for chunked mode)"""
if self._generation_logits:
@@ -352,6 +364,18 @@ class PyResult:
def mm_embedding_handle(self) -> Dict[str, Any] | None:
return self._mm_embeddings

@property
def mrope_position_ids_handle(self) -> Dict[str, Any] | None:
# NOTE: when populated, the returned `dict` contains the information necessary to rebuild
# the `SharedTensorContainer` using the `from_dict` class method.
return self._mrope_position_ids

@property
def mrope_position_deltas_handle(self) -> Dict[str, Any] | None:
# NOTE: when populated, the returned `dict` contains the information necessary to rebuild
# the `SharedTensorContainer` using the `from_dict` class method.
return self._mrope_position_deltas

@property
def additional_context_outputs(self) -> Dict[str, torch.Tensor] | None:
if self._additional_context_outputs is None:
@@ -382,7 +406,8 @@ class LlmResult:
py_result_properties = frozenset(
('context_logits', 'generation_logits', 'log_probs', 'cum_log_probs',
'mm_embedding_handle', 'additional_context_outputs',
'additional_generation_outputs'))
'additional_generation_outputs', 'mrope_position_ids_handle',
'mrope_position_deltas_handle'))

def __init__(self,
result: Union[bytes, tensorrt_llm.bindings.executor.Result],

@@ -2213,13 +2213,14 @@ class PyTorchModelEngine(ModelEngine):
mrope_position_deltas).expand(
3, 1, 1)
mrope_position_ids.append(gen_mrope_position_ids)
multimodal_params.to_device(
"multimodal_data",
"cuda",
pin_memory=True,
target_keywords=[
"mrope_config.mrope_position_deltas"
])
if mrope_position_deltas.device.type == "cpu":
multimodal_params.to_device(
"multimodal_data",
"cuda",
pin_memory=True,
target_keywords=[
"mrope_config.mrope_position_deltas"
])
multimodal_params_list.append(multimodal_params)

request.py_batch_idx = request.py_seq_slot
@@ -2448,8 +2449,9 @@ class PyTorchModelEngine(ModelEngine):
# NOTE: self.use_mrope is enough for differentiating whether to use mrope_position_ids but
# `_create_dummy_context_requests` from `kv_cache_creater` makes an exception that I can not add multimodal_data to the dummy_request
# so that we only replace position_ids with mrope_position_ids when it is not a dummy request and for models who is using mrope.
mrope_position_ids = torch.cat(mrope_position_ids,
dim=-1).pin_memory()
mrope_position_ids = torch.cat(mrope_position_ids, dim=-1)
if mrope_position_ids.device.type == "cpu":
mrope_position_ids = mrope_position_ids.pin_memory()
self.mrope_position_ids_cuda[:, :, :total_num_tokens].copy_(
mrope_position_ids[:, :, :total_num_tokens], non_blocking=True)
final_position_ids = self.mrope_position_ids_cuda[:, :, :
@@ -3362,7 +3364,26 @@ class PyTorchModelEngine(ModelEngine):
mm_embeddings = list(
torch.split(mm_embeddings[0], multimodal_chunks, dim=0))

return {'mm_embeddings': mm_embeddings, 'logits': None}
# Extract mrope position data from multimodal_params if available
mrope_position_ids_list = []
mrope_position_deltas_list = []
for multimodal_param in multimodal_params:
mrope_config = multimodal_param.multimodal_data.get(
'mrope_config', {})
mrope_position_ids = mrope_config.get('mrope_position_ids')
mrope_position_deltas = mrope_config.get('mrope_position_deltas')
if mrope_position_ids is not None:
mrope_position_ids_list.append(mrope_position_ids)
if mrope_position_deltas is not None:
mrope_position_deltas_list.append(mrope_position_deltas)

result = {'mm_embeddings': mm_embeddings, 'logits': None}
if mrope_position_ids_list:
result['mrope_position_ids'] = mrope_position_ids_list
if mrope_position_deltas_list:
result['mrope_position_deltas'] = mrope_position_deltas_list

return result

def _init_userbuffers(self, hidden_size):
if self.mapping.tp_size <= 1 or self.mapping.pp_size > 1:

@@ -21,7 +21,7 @@ from concurrent import futures
from dataclasses import dataclass
from functools import cached_property
from itertools import repeat
from typing import Any, Callable, Generic, List, Optional, Type, TypeVar, cast
from typing import Any, Callable, Dict, Generic, List, Optional, Type, TypeVar, cast

import numpy as np
import torch
@@ -199,6 +199,8 @@ class EarlyStopSampler(Sampler):
@dataclass(kw_only=True)
class MultimodalResult:
mm_embeddings: List[torch.Tensor]
# Can be used to include e.g. `mrope_position_ids`, etc.
extra_data: Optional[Dict[str, Any]] = None

def values(self):
return vars(self).values()
@@ -262,7 +264,10 @@ class EarlyStopWithMMResult(Sampler):
resource_manager: Optional[ResourceManager] = None,
) -> SampleStateWithMMResult:
# from model_outputs to MultimodalResult
data = MultimodalResult(mm_embeddings=model_outputs["mm_embeddings"])
data = MultimodalResult(
mm_embeddings=model_outputs.pop("mm_embeddings"),
extra_data={**model_outputs},
)
return SampleStateWithMMResult(scheduled_requests=scheduled_requests, data=data)

@override
@@ -276,7 +281,12 @@ class EarlyStopWithMMResult(Sampler):
scheduled_requests = state.scheduled_requests
assert not scheduled_requests.generation_requests
mm_embeddings = state.data.mm_embeddings
for request, mm_embedding in zip(scheduled_requests.context_requests, mm_embeddings):
extra_data = state.data.extra_data or {}
mrope_position_ids = extra_data.get("mrope_position_ids", None)
mrope_position_deltas = extra_data.get("mrope_position_deltas", None)
for i, (request, mm_embedding) in enumerate(
zip(scheduled_requests.context_requests, mm_embeddings)
):
request.state = LlmRequestState.GENERATION_COMPLETE
# NOTE: This is a hack: set finish reason manually and set the beam 0
request.set_finished_reason(FinishReason.LENGTH, 0)
@@ -287,6 +297,12 @@ class EarlyStopWithMMResult(Sampler):

request.py_result.append_mm_embeddings(mm_embedding)

# Store mrope data if available
if mrope_position_ids is not None and mrope_position_deltas is not None:
request.py_result.set_mrope_position(
mrope_position_ids[i], mrope_position_deltas[i]
)

@override
def is_generation_model(self) -> bool:
return False

@@ -40,6 +40,8 @@ class DisaggregatedParams:
multimodal_hashes: Optional[List[List[int]]] = (
None # user provided mm hashes should be a list of 8 integers
)
mrope_position_ids_handle: Optional[Dict[str, Any]] = None
mrope_position_deltas_handle: Optional[Dict[str, Any]] = None

def get_context_phase_params(self) -> tllme.ContextPhaseParams:
return tllme.ContextPhaseParams(

@@ -1,4 +1,5 @@
import asyncio
import dataclasses
import json
import time
import weakref
@@ -415,12 +416,19 @@ class GenerationResultBase:
self.cached_tokens = getattr(response_result, 'cached_tokens', 0)
self.avg_decoded_tokens_per_iter = response_result.avg_decoded_tokens_per_iter
if context_phase_params is not None:
self.disaggregated_params = DisaggregatedParams(
existing_disagg_params = self.disaggregated_params
# Use `replace` to preserve things like `mrope_position_ids_handle` and
# `mrope_position_deltas_handle`. However, explicitly set
# `multimodal_embedding_handles=None` since they should no longer be needed.
self.disaggregated_params = dataclasses.replace(
existing_disagg_params or DisaggregatedParams(),
request_type="context_only",
first_gen_tokens=context_phase_params.first_gen_tokens,
ctx_request_id=context_phase_params.req_id,
opaque_state=context_phase_params.opaque_state,
draft_tokens=context_phase_params.draft_tokens)
draft_tokens=context_phase_params.draft_tokens,
multimodal_embedding_handles=None,
)

finish_reasons = response_result.finish_reasons
# output_token_ids = (beams, tokens)
@@ -440,6 +448,8 @@ class GenerationResultBase:
if hasattr(response_result, 'mm_embedding_handle'
) and response_result.mm_embedding_handle is not None:
self._mm_embedding_handle = response_result.mm_embedding_handle
mrope_position_ids_handle = response_result.mrope_position_ids_handle
mrope_position_deltas_handle = response_result.mrope_position_deltas_handle
if self.disaggregated_params is not None:
self.disaggregated_params.multimodal_embedding_handles = [
response_result.mm_embedding_handle
@@ -451,6 +461,8 @@ class GenerationResultBase:
response_result.mm_embedding_handle
],
multimodal_hashes=self._multimodal_hashes)
self.disaggregated_params.mrope_position_ids_handle = mrope_position_ids_handle
self.disaggregated_params.mrope_position_deltas_handle = mrope_position_deltas_handle

# Processing background errors here ASAF during generation.
if self._background_error_handler and (
@@ -811,8 +823,12 @@ class GenerationResult(GenerationResultBase):

def _repr_fields(self):
return [
'request_id', 'prompt_token_ids', 'outputs', 'finished',
"context_logits", "mm_embedding_handle"
'request_id',
'prompt_token_ids',
'outputs',
'finished',
"context_logits",
"mm_embedding_handle",
]

def __repr__(self) -> str:

@@ -89,8 +89,12 @@ class RequestOutput(DetokenizedGenerationResultBase, GenerationResult):

def _repr_fields(self):
return [
"request_id", "prompt", "prompt_token_ids", "outputs", "finished",
"mm_embedding_handle"
"request_id",
"prompt",
"prompt_token_ids",
"outputs",
"finished",
"mm_embedding_handle",
]

@@ -419,7 +423,7 @@ class BaseLLM:
multimodal_params = None

if is_mm_disagg:
if not self.input_processor.support_mm_disagg:
if not getattr(self.input_processor, "support_mm_disagg", False):
raise ValueError(
"Multimodal disaggregated inference is not supported for this model"
)
@@ -436,14 +440,42 @@ class BaseLLM:
mm_hashes = disaggregated_params.multimodal_hashes
multimodal_input = MultimodalInput.from_components(
mm_hashes, mm_token_positions, mm_token_length)
multimodal_data = {"multimodal_embedding": mm_handles}
if disaggregated_params.mrope_position_ids_handle is not None:
# NOTE: `PyTorchModelEngine` assumes both are present when using mrope.
assert disaggregated_params.mrope_position_deltas_handle is not None
mrope_config = {}
mrope_config[
"mrope_position_ids"] = disaggregated_params.mrope_position_ids_handle
mrope_config[
"mrope_position_deltas"] = disaggregated_params.mrope_position_deltas_handle
multimodal_data["mrope_config"] = mrope_config
multimodal_params = MultimodalParams(
multimodal_input=multimodal_input,
multimodal_data={"multimodal_embedding": mm_handles})
multimodal_data=multimodal_data,
)

elif "prompt_token_ids" in inputs:
prompt_token_ids = inputs['prompt_token_ids']
prompt = None
query_token_ids = inputs.get("query_token_ids", None)
multimodal_data = {}
# NOTE: when running in `generation_only` for disagg, this is the code path we expect to hit.
if disaggregated_params is not None and disaggregated_params.mrope_position_ids_handle is not None:
# It looks like `PyTorchModelEngine` assumes both are present when using mrope?
if disaggregated_params.mrope_position_deltas_handle is None:
raise RuntimeError(
"`mrope_position_ids_handle` and `mrope_position_deltas_handle` must both "
"be provided, or both `None`.")
mrope_config = {}
mrope_config[
"mrope_position_ids"] = disaggregated_params.mrope_position_ids_handle
mrope_config[
"mrope_position_deltas"] = disaggregated_params.mrope_position_deltas_handle
multimodal_data["mrope_config"] = mrope_config
if multimodal_data:
multimodal_params = MultimodalParams(
multimodal_data=multimodal_data)
elif "prompt" in inputs:
if 'multi_modal_data' in inputs:
# TODO: The current design uses a wrapper for existing input processor (input_processor_with_hash)

@@ -101,14 +101,8 @@ class MultimodalEncoder(_TorchLLM):

inputs = [prompt_inputs(i) for i in inputs]

def _item_at(maybe_batched: Union[Any, Sequence[Any]], pos: int) -> Any:
if isinstance(maybe_batched, list):
return maybe_batched[pos]
else:
return maybe_batched

futures = []
for i, request_inputs in enumerate(inputs):
for request_inputs in inputs:
future = self.generate_async(request_inputs)
futures.append(future)

@@ -19,49 +19,23 @@ example_images = [
str(test_data_root / "61.jpg"),
]


@pytest.fixture(scope="function")
def multimodal_model_config():
"""Get multimodal model configuration similar to integration tests"""
# You can extend this to support multiple models or get from environment
model_configs = {
'llava-v1.6-mistral-7b-hf': {
'model_name':
'llava-v1.6-mistral-7b-hf',
'hf_model_dir':
'llava-hf/llava-v1.6-mistral-7b-hf',
'model_dir':
llm_models_root() / "multimodals" / "llava-v1.6-mistral-7b-hf",
}
}

return model_configs['llava-v1.6-mistral-7b-hf']
_LLAVA_DIR = llm_models_root() / "multimodals" / "llava-v1.6-mistral-7b-hf"
_QWEN_2_5_VL_DIR = llm_models_root() / "Qwen2.5-VL-3B-Instruct"


# TODO: Add multi-image in single chat test
@pytest.mark.parametrize("model_key", [
"llava-v1.6-mistral-7b-hf",
])
@pytest.mark.parametrize("model_dir", [_LLAVA_DIR, _QWEN_2_5_VL_DIR])
@pytest.mark.parametrize("pd_disagg", [False, True])
def test_single_image_chat(model_key, pd_disagg, multimodal_model_config):
def test_single_image_chat(model_dir, pd_disagg):
"""Test processing single image using encoder (pass mm_embeddings) + LLM API.

This test verifies that encoder (pass mm_embeddings) + LLM API produces identical
results to standard llm generation (pass raw image) by comparing outputs.
"""
# Get model configuration
if model_key != "llava-v1.6-mistral-7b-hf":
#TODO: add more model tests progressively here
pytest.skip(
f"Skipping test for {model_key} - only testing llava-v1.6-mistral-7b-hf for now"
)

# Extract model information from config
encoder_model_dir = multimodal_model_config['model_dir']

# Test configuration
max_tokens = 64
free_gpu_memory_fraction = 0.6 if not pd_disagg else 0.2
free_gpu_memory_fraction = 0.2
max_batch_size = 1

# Test data - OpenAI chat completion format
@@ -76,15 +50,14 @@ def test_single_image_chat(model_key, pd_disagg, multimodal_model_config):
)

# Process multimodal data using encoder (pass mm_embeddings)
encoder = MultimodalEncoder(model=encoder_model_dir,
max_batch_size=max_batch_size)
encoder = MultimodalEncoder(model=model_dir, max_batch_size=max_batch_size)

cache_transceiver_cfg = CacheTransceiverConfig(
backend="DEFAULT") if pd_disagg else None

disable_overlap_scheduler = pd_disagg

llm = LLM(model=encoder_model_dir,
llm = LLM(model=model_dir,
backend='pytorch',
kv_cache_config=kv_cache_config,
trust_remote_code=True,
@@ -93,7 +66,7 @@ def test_single_image_chat(model_key, pd_disagg, multimodal_model_config):

llm_decode = None
if pd_disagg:
llm_decode = LLM(model=encoder_model_dir,
llm_decode = LLM(model=model_dir,
backend='pytorch',
kv_cache_config=kv_cache_config,
trust_remote_code=True,
@@ -141,6 +114,7 @@ def test_single_image_chat(model_key, pd_disagg, multimodal_model_config):

assert ep_disaggregated_params is not None, "Encoder output disaggregated params is None"
ep_disaggregated_params.request_type = "context_and_generation" if not pd_disagg else "context_only"

outputs = llm.generate(inputs,
sampling_params=sampling_params,
disaggregated_params=ep_disaggregated_params)
@@ -151,10 +125,10 @@ def test_single_image_chat(model_key, pd_disagg, multimodal_model_config):
pd_disaggregated_params = outputs[0].disaggregated_params
pd_disaggregated_params.request_type = "generation_only"
sampling_params = SamplingParams(max_tokens=max_tokens)
inputs[0][
'multi_modal_data'] = None # remove multimodal data from input as decoder worker doesn't need it
inputs[0]['prompt_token_ids'] = outputs[
0].prompt_token_ids # use prompt token ids from encoder output
# remove multimodal data from input as decoder worker doesn't need it
inputs[0]['multi_modal_data'] = None
# use prompt token ids from encoder output
inputs[0]['prompt_token_ids'] = outputs[0].prompt_token_ids

outputs = llm_decode.generate(
inputs,
@@ -199,24 +173,23 @@ def test_single_image_chat(model_key, pd_disagg, multimodal_model_config):
f"Log probabilities don't match for output {i}, generation {j}"

@pytest.mark.parametrize("model_key", [
|
||||
"llava-v1.6-mistral-7b-hf",
|
||||
])
|
||||
def test_multi_request_batch_chat(model_key, multimodal_model_config):
|
||||
@pytest.mark.parametrize(
|
||||
"model_dir, encoder_max_batch_size",
|
||||
[
|
||||
(_LLAVA_DIR, 3),
|
||||
# Qwen2.5 VL's vision encoder seems to output different embeddings based on this value.
|
||||
# The test only passes with this set to 1.
|
||||
(_QWEN_2_5_VL_DIR, 1),
|
||||
],
|
||||
)
|
||||
def test_multi_request_batch_chat(model_dir, encoder_max_batch_size):
|
||||
"""Test batching multiple multimodal requests and verify encoder path matches raw path.
|
||||
|
||||
This mirrors test_single_image_chat but with a batch of size 3.
|
||||
"""
|
||||
if model_key != "llava-v1.6-mistral-7b-hf":
|
||||
pytest.skip(
|
||||
f"Skipping test for {model_key} - only testing llava-v1.6-mistral-7b-hf for now"
|
||||
)
|
||||
|
||||
encoder_model_dir = multimodal_model_config['model_dir']
|
||||
|
||||
max_tokens = 64
|
||||
free_gpu_memory_fraction = 0.6
|
||||
max_batch_size = 3
|
||||
|
||||
prompts = [
|
||||
"Describe the natural environment in the image.",
|
||||
@ -232,10 +205,10 @@ def test_multi_request_batch_chat(model_key, multimodal_model_config):
|
||||
free_gpu_memory_fraction=free_gpu_memory_fraction,
|
||||
)
|
||||
|
||||
encoder = MultimodalEncoder(model=encoder_model_dir,
|
||||
max_batch_size=max_batch_size)
|
||||
encoder = MultimodalEncoder(model=model_dir,
|
||||
max_batch_size=encoder_max_batch_size)
|
||||
llm = LLM(
|
||||
model=encoder_model_dir,
|
||||
model=model_dir,
|
||||
backend='pytorch',
|
||||
kv_cache_config=kv_cache_config,
|
||||
max_batch_size=1, # fix batch size to reduce non-determinism in tests
|
||||
@ -305,8 +278,7 @@ def test_multi_request_batch_chat(model_key, multimodal_model_config):
|
||||
"Describe the weather in the image.",
|
||||
], 2),
|
||||
])
|
||||
def test_kv_event_mm_keys_with_reuse(prompts, expected_num_duplicates,
|
||||
multimodal_model_config):
|
||||
def test_kv_event_mm_keys_with_reuse(prompts, expected_num_duplicates):
|
||||
"""Test mm_keys in KV cache events with cache reuse scenarios.
|
||||
|
||||
This test verifies:
|
||||
@ -316,7 +288,7 @@ def test_kv_event_mm_keys_with_reuse(prompts, expected_num_duplicates,
|
||||
- Same media + same prompts: full reuse (0 duplicate offsets)
|
||||
- Same media + different prompts: partial reuse (prefix blocks reused)
|
||||
"""
|
||||
encoder_model_dir = multimodal_model_config['model_dir']
|
||||
encoder_model_dir = _LLAVA_DIR
|
||||
|
||||
max_tokens = 16
|
||||
free_gpu_memory_fraction = 0.6
|
||||
|
||||