mirror of
https://github.com/vllm-project/vllm.git
synced 2026-06-06 00:16:14 +00:00
[MM][Gemma4] Use video profiling hints in encoder budget (#41837)
Signed-off-by: lesj0610 <lesj0610@users.noreply.github.com> Co-authored-by: lesj0610 <lesj0610@users.noreply.github.com>
This commit is contained in:
@@ -1,6 +1,8 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from collections.abc import Mapping
|
||||
|
||||
import pytest
|
||||
from PIL import Image as PILImage
|
||||
|
||||
@@ -102,6 +104,39 @@ def test_get_mm_max_tokens_per_item_respects_configured_max_soft_tokens(
|
||||
assert tokens["video"] == 32 * (70 + 2 + 6)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
    ("limit_mm_per_prompt", "expected_video_tokens"),
    [
        # No ``num_frames`` hint: expect the 32-frame budget.
        ({"video": 1}, 32 * (70 + 2 + 6)),
        ({"video": {"count": 1}}, 32 * (70 + 2 + 6)),
        # A ``num_frames`` hint of at most 32 scales the budget linearly.
        ({"video": {"count": 1, "num_frames": 1}}, 1 * (70 + 2 + 6)),
        ({"video": {"count": 1, "num_frames": 8}}, 8 * (70 + 2 + 6)),
        ({"video": {"count": 1, "num_frames": 32}}, 32 * (70 + 2 + 6)),
        # A hint above 32 frames is expected to be capped at 32.
        ({"video": {"count": 1, "num_frames": 40}}, 32 * (70 + 2 + 6)),
    ],
)
@pytest.mark.parametrize("model_id", [GEMMA4_MODEL_ID])
def test_get_mm_max_tokens_per_item_respects_configured_video_num_frames(
    model_id: str,
    limit_mm_per_prompt: Mapping[str, int | Mapping[str, int]],
    expected_video_tokens: int,
):
    """Check that the video token budget follows the ``num_frames`` hint.

    Each parametrized case builds a model context with the given
    ``limit_mm_per_prompt`` and asserts that the per-item maximum for
    ``"video"`` equals ``expected_video_tokens`` (a frame count times
    ``70 + 2 + 6`` tokens per frame), while the ``"image"`` maximum
    stays at 280.
    """
    model_ctx = build_model_context(
        model_id,
        limit_mm_per_prompt=limit_mm_per_prompt,
    )
    model_config = model_ctx.model_config
    processing_info = MULTIMODAL_REGISTRY.create_processor(model_config).info

    max_tokens = processing_info.get_mm_max_tokens_per_item(
        seq_len=model_config.max_model_len,
        mm_counts={"video": 1},
    )

    assert max_tokens is not None
    assert max_tokens["image"] == 280
    assert max_tokens["video"] == expected_video_tokens
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model_id", [GEMMA4_MODEL_ID])
|
||||
def test_get_prompt_updates_respects_nested_max_soft_tokens(model_id: str):
|
||||
ctx = build_model_context(
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import warnings
|
||||
from collections.abc import Sequence
|
||||
from collections.abc import Mapping, Sequence
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
|
||||
@@ -277,7 +277,7 @@ def build_model_context(
|
||||
dtype: ModelDType = "auto",
|
||||
model_config_kwargs: dict[str, Any] | None = None,
|
||||
mm_processor_kwargs: dict[str, Any] | None = None,
|
||||
limit_mm_per_prompt: dict[str, int] | None = None,
|
||||
limit_mm_per_prompt: Mapping[str, int | Mapping[str, int]] | None = None,
|
||||
mm_processor_cache_gb: int = 0,
|
||||
):
|
||||
"""Creates an InputProcessingContext for a given model.
|
||||
@@ -300,7 +300,10 @@ def build_model_context(
|
||||
)
|
||||
|
||||
model_config_kwargs = model_config_kwargs or {}
|
||||
limit_mm_per_prompt = limit_mm_per_prompt or {}
|
||||
limit_mm_per_prompt = {
|
||||
modality: dict(limit) if isinstance(limit, Mapping) else limit
|
||||
for modality, limit in (limit_mm_per_prompt or {}).items()
|
||||
}
|
||||
model_config = ModelConfig(
|
||||
model_id,
|
||||
runner=runner,
|
||||
|
||||
@@ -246,7 +246,15 @@ class Gemma4ProcessingInfo(BaseProcessingInfo):
|
||||
processor = self.get_hf_processor()
|
||||
tokens["audio"] = processor.audio_seq_length
|
||||
# Video: each frame ≤ 70 soft tokens + boi + eoi + ~6 ts tokens.
|
||||
tokens["video"] = _VIDEO_MAX_FRAMES * (_VIDEO_MAX_SOFT_TOKENS + 2 + 6)
|
||||
num_frames = _VIDEO_MAX_FRAMES
|
||||
mm_config = self.ctx.model_config.get_multimodal_config()
|
||||
video_opts = mm_config.limit_per_prompt.get("video")
|
||||
if (
|
||||
isinstance(video_opts, VideoDummyOptions)
|
||||
and video_opts.num_frames is not None
|
||||
):
|
||||
num_frames = min(num_frames, video_opts.num_frames)
|
||||
tokens["video"] = num_frames * (_VIDEO_MAX_SOFT_TOKENS + 2 + 6)
|
||||
return tokens
|
||||
|
||||
def get_data_parser(self) -> MultiModalDataParser:
|
||||
|
||||
Reference in New Issue
Block a user