Gemma4: honor the configured video num_frames when sizing video tokens.

Summary of the changes carried by this patch:
  * tests/models/multimodal/processing/test_gemma4.py
      Adds a parametrized test for get_mm_max_tokens_per_item. The expected
      video budget is frames * (70 + 2 + 6); with no num_frames limit the
      32-frame default is used, and a configured num_frames of 40 is still
      capped at 32 frames.
  * tests/models/utils.py
      build_model_context now accepts per-modality limits that are either an
      int or a nested mapping, and copies nested mappings into plain dicts.
  * vllm/model_executor/models/gemma4_mm.py
      The video entry uses min(_VIDEO_MAX_FRAMES, video_opts.num_frames) when
      the "video" limit is a VideoDummyOptions with num_frames set; otherwise
      it keeps _VIDEO_MAX_FRAMES.

NOTE(review): this patch was recovered from a copy whose line breaks had been
collapsed. The line structure was rebuilt from the hunk headers (each hunk's
old/new line counts match), but leading indentation is a reconstruction.
Confirm the context lines against the tree, or apply with
`git apply --ignore-whitespace`.

NOTE(review): the gemma4_mm.py hunk references VideoDummyOptions but this
patch adds no import for it; presumably the module already imports it.
Confirm before merging.

diff --git a/tests/models/multimodal/processing/test_gemma4.py b/tests/models/multimodal/processing/test_gemma4.py
index 24a30cae9d4..8541701ae10 100644
--- a/tests/models/multimodal/processing/test_gemma4.py
+++ b/tests/models/multimodal/processing/test_gemma4.py
@@ -1,6 +1,8 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 
+from collections.abc import Mapping
+
 import pytest
 from PIL import Image as PILImage
 
@@ -102,6 +104,39 @@ def test_get_mm_max_tokens_per_item_respects_configured_max_soft_tokens(
     assert tokens["video"] == 32 * (70 + 2 + 6)
 
 
+@pytest.mark.parametrize(
+    ("limit_mm_per_prompt", "expected_video_tokens"),
+    [
+        ({"video": 1}, 32 * (70 + 2 + 6)),
+        ({"video": {"count": 1}}, 32 * (70 + 2 + 6)),
+        ({"video": {"count": 1, "num_frames": 1}}, 1 * (70 + 2 + 6)),
+        ({"video": {"count": 1, "num_frames": 8}}, 8 * (70 + 2 + 6)),
+        ({"video": {"count": 1, "num_frames": 32}}, 32 * (70 + 2 + 6)),
+        ({"video": {"count": 1, "num_frames": 40}}, 32 * (70 + 2 + 6)),
+    ],
+)
+@pytest.mark.parametrize("model_id", [GEMMA4_MODEL_ID])
+def test_get_mm_max_tokens_per_item_respects_configured_video_num_frames(
+    model_id: str,
+    limit_mm_per_prompt: Mapping[str, int | Mapping[str, int]],
+    expected_video_tokens: int,
+):
+    ctx = build_model_context(
+        model_id,
+        limit_mm_per_prompt=limit_mm_per_prompt,
+    )
+    processor = MULTIMODAL_REGISTRY.create_processor(ctx.model_config)
+
+    tokens = processor.info.get_mm_max_tokens_per_item(
+        seq_len=ctx.model_config.max_model_len,
+        mm_counts={"video": 1},
+    )
+
+    assert tokens is not None
+    assert tokens["image"] == 280
+    assert tokens["video"] == expected_video_tokens
+
+
 @pytest.mark.parametrize("model_id", [GEMMA4_MODEL_ID])
 def test_get_prompt_updates_respects_nested_max_soft_tokens(model_id: str):
     ctx = build_model_context(
diff --git a/tests/models/utils.py b/tests/models/utils.py
index b12ab72d77c..a5d1844a307 100644
--- a/tests/models/utils.py
+++ b/tests/models/utils.py
@@ -2,7 +2,7 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 
 import warnings
-from collections.abc import Sequence
+from collections.abc import Mapping, Sequence
 from dataclasses import dataclass
 from typing import Any
 
@@ -277,7 +277,7 @@ def build_model_context(
     dtype: ModelDType = "auto",
     model_config_kwargs: dict[str, Any] | None = None,
     mm_processor_kwargs: dict[str, Any] | None = None,
-    limit_mm_per_prompt: dict[str, int] | None = None,
+    limit_mm_per_prompt: Mapping[str, int | Mapping[str, int]] | None = None,
     mm_processor_cache_gb: int = 0,
 ):
     """Creates an InputProcessingContext for a given model.
@@ -300,7 +300,10 @@ def build_model_context(
     )
 
     model_config_kwargs = model_config_kwargs or {}
-    limit_mm_per_prompt = limit_mm_per_prompt or {}
+    limit_mm_per_prompt = {
+        modality: dict(limit) if isinstance(limit, Mapping) else limit
+        for modality, limit in (limit_mm_per_prompt or {}).items()
+    }
     model_config = ModelConfig(
         model_id,
         runner=runner,
diff --git a/vllm/model_executor/models/gemma4_mm.py b/vllm/model_executor/models/gemma4_mm.py
index 029f4ff9bf5..73a5e701e7a 100644
--- a/vllm/model_executor/models/gemma4_mm.py
+++ b/vllm/model_executor/models/gemma4_mm.py
@@ -246,7 +246,15 @@ class Gemma4ProcessingInfo(BaseProcessingInfo):
         processor = self.get_hf_processor()
         tokens["audio"] = processor.audio_seq_length
         # Video: each frame ≤ 70 soft tokens + boi + eoi + ~6 ts tokens.
-        tokens["video"] = _VIDEO_MAX_FRAMES * (_VIDEO_MAX_SOFT_TOKENS + 2 + 6)
+        num_frames = _VIDEO_MAX_FRAMES
+        mm_config = self.ctx.model_config.get_multimodal_config()
+        video_opts = mm_config.limit_per_prompt.get("video")
+        if (
+            isinstance(video_opts, VideoDummyOptions)
+            and video_opts.num_frames is not None
+        ):
+            num_frames = min(num_frames, video_opts.num_frames)
+        tokens["video"] = num_frames * (_VIDEO_MAX_SOFT_TOKENS + 2 + 6)
         return tokens
 
     def get_data_parser(self) -> MultiModalDataParser: