[Model] Add Gemma4 Unified (encoder-free) support (#44429)

Signed-off-by: Luciano Martins <lucianommartins@users.noreply.github.com>
Co-authored-by: Luciano Martins <lucianommartins@users.noreply.github.com>
This commit is contained in:
Luciano Martins
2026-06-03 16:01:39 -03:00
committed by GitHub
parent 271328e256
commit a248b45d05
14 changed files with 791 additions and 31 deletions
+5 -4
View File
@@ -24,10 +24,11 @@ vllm serve google/gemma-4-E2B-it \
--speculative-config '{"method":"mtp","model":"gg-hf-am/gemma-4-E2B-it-assistant","num_speculative_tokens":1}'
```
The E2B, E4B, 26B-A4B, and 31B Gemma 4 IT assistant checkpoints are supported
when their configuration uses `model_type: gemma4_assistant`. vLLM maps those
checkpoints to `Gemma4MTPModel` internally and wires the assistant layers to
share KV cache with the target model.
The E2B, E4B, 12B, 26B-A4B, and 31B Gemma 4 IT assistant checkpoints are supported.
Tower-based variants use `model_type: gemma4_assistant` and the encoder-free
Gemma 4 Unified variant (12B) uses `model_type: gemma4_unified_assistant`.
vLLM maps both to `Gemma4MTPModel` internally and wires the assistant layers
to share KV cache with the target model.
If an older vLLM release logs `SpeculativeConfig(method='draft_model', ...)`
for a Gemma 4 assistant checkpoint, that release is treating the assistant as a
+8 -1
View File
@@ -562,6 +562,7 @@ These models primarily accept the [`LLM.generate`](./generative_models.md#llmgen
| `Gemma3ForConditionalGeneration` | Gemma 3 | T + I<sup>E+</sup> | `google/gemma-3-4b-it`, `google/gemma-3-27b-it`, etc. | ✅︎ | ✅︎ |
| `Gemma3nForConditionalGeneration` | Gemma 3n | T + I + A | `google/gemma-3n-E2B-it`, `google/gemma-3n-E4B-it`, etc. | | |
| `Gemma4ForConditionalGeneration` | Gemma 4 | T + I<sup>+</sup> + V + A<sup>*</sup> | `google/gemma-4-E2B-it`, etc. | | ✅︎ |
| `Gemma4UnifiedForConditionalGeneration` | Gemma 4 Unified | T + I<sup>+</sup> + V + A | `google/gemma-4-12B-it`, etc. | | ✅︎ |
| `GLM4VForCausalLM`<sup>^</sup> | GLM-4V | T + I | `zai-org/glm-4v-9b`, `zai-org/cogagent-9b-20241220`, etc. | ✅︎ | ✅︎ |
| `Glm4vForConditionalGeneration` | GLM-4.1V-Thinking | T + I<sup>E+</sup> + V<sup>E+</sup> | `zai-org/GLM-4.1V-9B-Thinking`, etc. | ✅︎ | ✅︎ |
| `Glm4vMoeForConditionalGeneration` | GLM-4.5V | T + I<sup>E+</sup> + V<sup>E+</sup> | `zai-org/GLM-4.5V`, etc. | ✅︎ | ✅︎ |
@@ -664,10 +665,16 @@ Some models are supported only via the [Transformers modeling backend](#transfor
For `Gemma4ForConditionalGeneration`:
- audio input is only supported by the `gemma-4-E2B` and `gemma-4-E4B` variants.
- The model does not ingest videos directly. However, vLLM's Gemma 4 implementation supports video inputs by handling video processing internally. Users can send videos directly in the message structure to vLLM, where they are converted into text and image frames before being passed to the model.
- Gemma 4 assistant checkpoints for speculative decoding use vLLM's Gemma
- Gemma 4 assistant checkpoints for speculative decoding use vLLM's Gemma
4 MTP path, not generic draft-model speculative decoding. See the
[Gemma 4 assistant model MTP example](../features/speculative_decoding/mtp.md#gemma-4-assistant-models).
!!! note
For `Gemma4UnifiedForConditionalGeneration`:
- This is the encoder-free Gemma 4 variant (e.g. `gemma-4-12B-it`). Unlike the tower-based `Gemma4ForConditionalGeneration`, it has **no SigLIP vision encoder** and **no audio encoder**. Raw pixel patches are projected directly into LM space via a Dense+LayerNorm pipeline with factorized positional embeddings, and raw audio waveform frames are projected directly through a multimodal embedder.
- All modalities (image, video, audio) are supported.
- Gemma 4 Unified assistant checkpoints (`model_type: gemma4_unified_assistant`) use the same MTP path as the tower-based variant. See the [Gemma 4 assistant model MTP example](../features/speculative_decoding/mtp.md#gemma-4-assistant-models).
!!! note
For `InternVLChatModel`, only InternVL2.5 with Qwen2.5 text backbone (`OpenGVLab/InternVL2.5-1B` etc.), InternVL3 and InternVL3.5 have video inputs support currently.
@@ -0,0 +1,205 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Mapping
import pytest
import torch
from PIL import Image as PILImage
from vllm.model_executor.models.gemma4_mm import Gemma4ImagePixelInputs
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import MultiModalFieldConfig
from ....conftest import ImageTestAssets
from ...utils import build_model_context
# The Unified model ID for testing purposes
GEMMA4_UNIFIED_MODEL_ID = "google/gemma-4-12B-it"
def test_gemma4_unified_image_schema_accepts_variable_patch_counts():
    """Build the image schema from two images with different patch counts.

    The test passes when constructing ``Gemma4ImagePixelInputs`` from lists
    of per-image tensors with unequal first dimensions does not raise.
    """
    patch_counts = (10080, 2520)
    Gemma4ImagePixelInputs(
        pixel_values=[torch.randn(count, 768) for count in patch_counts],
        pixel_position_ids=[
            torch.zeros(count, 2, dtype=torch.long) for count in patch_counts
        ],
    )
def test_gemma4_unified_image_batching_keeps_variable_patch_counts_unstacked():
    """Reduce a batched field holding ragged tensors and expect a plain list.

    Two ``pixel_values`` tensors with different patch counts are wrapped
    into field elements and reduced; the result must stay a list whose
    entries keep their individual shapes.
    """
    shapes = [(10080, 768), (2520, 768)]
    batched_field = MultiModalFieldConfig.batched("image").field

    elements = batched_field.build_elems(
        "image",
        "pixel_values",
        [torch.randn(*shape) for shape in shapes],
    )
    merged = batched_field.reduce_data(list(elements))

    assert isinstance(merged, list)
    assert [item.shape for item in merged] == [
        torch.Size(shape) for shape in shapes
    ]
@pytest.mark.parametrize(
    "image_width,image_height,max_soft_tokens",
    [
        (900, 3, 280),
        (3, 900, 280),
        (900, 3, 70),
        (4000, 2, 1120),
    ],
)
@pytest.mark.parametrize("model_id", [GEMMA4_UNIFIED_MODEL_ID])
def test_compute_num_soft_tokens_does_not_exceed_max_soft_tokens(
    model_id: str,
    image_width: int,
    image_height: int,
    max_soft_tokens: int,
):
    """Verify ``_compute_num_soft_tokens`` caps output at ``max_soft_tokens``.

    Every parametrized case uses an extreme aspect ratio (one side is 2 or
    3 pixels).
    """
    model_ctx = build_model_context(
        model_id,
        mm_processor_kwargs={"do_pan_and_scan": True},
        limit_mm_per_prompt={"image": 1},
    )
    info = MULTIMODAL_REGISTRY.create_processor(model_ctx.model_config).info

    actual = info._compute_num_soft_tokens(
        image_width=image_width,
        image_height=image_height,
        max_soft_tokens=max_soft_tokens,
    )

    failure_message = (
        f"_compute_num_soft_tokens returned {actual} for "
        f"image_width={image_width}, image_height={image_height}, "
        f"max_soft_tokens={max_soft_tokens} — exceeds the cap."
    )
    assert actual <= max_soft_tokens, failure_message
@pytest.mark.parametrize(
    ("mm_processor_kwargs", "expected_image_tokens"),
    [
        ({}, 280),
        ({"max_soft_tokens": 70}, 70),
        ({"max_soft_tokens": 280}, 280),
        ({"max_soft_tokens": 1120}, 1120),
        ({"images_kwargs": {"max_soft_tokens": 560}}, 560),
        ({"images_kwargs": None}, 280),
        ({"images_kwargs": "not-a-dict"}, 280),
    ],
)
@pytest.mark.parametrize("model_id", [GEMMA4_UNIFIED_MODEL_ID])
def test_get_mm_max_tokens_per_item_respects_configured_max_soft_tokens(
    model_id: str,
    mm_processor_kwargs: dict[str, object],
    expected_image_tokens: int,
):
    """Check the per-image token budget against ``max_soft_tokens`` settings.

    Covers the top-level kwarg, the nested ``images_kwargs`` form, and
    malformed ``images_kwargs`` values, which are expected to yield 280.
    The video budget is expected to be the same in every case.
    """
    modality_counts = {"image": 1, "video": 1}
    model_ctx = build_model_context(
        model_id,
        mm_processor_kwargs=mm_processor_kwargs,
        limit_mm_per_prompt=dict(modality_counts),
    )
    info = MULTIMODAL_REGISTRY.create_processor(model_ctx.model_config).info

    budget = info.get_mm_max_tokens_per_item(
        seq_len=model_ctx.model_config.max_model_len,
        mm_counts=modality_counts,
    )

    assert budget is not None
    assert budget["image"] == expected_image_tokens
    # NOTE(review): presumably 32 frames x (70 soft tokens + 2 + 6 extra
    # tokens per frame) — confirm against the video constants in gemma4_mm.
    expected_video_tokens = 32 * (70 + 2 + 6)
    assert budget["video"] == expected_video_tokens
@pytest.mark.parametrize(
    ("limit_mm_per_prompt", "expected_video_tokens"),
    [
        ({"video": 1}, 32 * (70 + 2 + 6)),
        ({"video": {"count": 1}}, 32 * (70 + 2 + 6)),
        ({"video": {"count": 1, "num_frames": 1}}, 1 * (70 + 2 + 6)),
        ({"video": {"count": 1, "num_frames": 8}}, 8 * (70 + 2 + 6)),
        ({"video": {"count": 1, "num_frames": 32}}, 32 * (70 + 2 + 6)),
        ({"video": {"count": 1, "num_frames": 40}}, 32 * (70 + 2 + 6)),
    ],
)
@pytest.mark.parametrize("model_id", [GEMMA4_UNIFIED_MODEL_ID])
def test_get_mm_max_tokens_per_item_respects_configured_video_num_frames(
    model_id: str,
    limit_mm_per_prompt: Mapping[str, int | Mapping[str, int]],
    expected_video_tokens: int,
):
    """Check that the video token budget follows the configured frame count.

    The expected values scale with ``num_frames`` up to 32; the
    ``num_frames=40`` case expects the same budget as 32 frames.
    """
    model_ctx = build_model_context(
        model_id,
        limit_mm_per_prompt=limit_mm_per_prompt,
    )
    info = MULTIMODAL_REGISTRY.create_processor(model_ctx.model_config).info

    budget = info.get_mm_max_tokens_per_item(
        seq_len=model_ctx.model_config.max_model_len,
        mm_counts={"video": 1},
    )

    assert budget is not None
    # The image entry is asserted as well, although only video is requested.
    assert budget["image"] == 280
    assert budget["video"] == expected_video_tokens
@pytest.mark.parametrize("model_id", [GEMMA4_UNIFIED_MODEL_ID])
def test_get_prompt_updates_respects_nested_max_soft_tokens(model_id: str):
    """Check that prompt replacement honours ``images_kwargs.max_soft_tokens``.

    The replacement produced by ``_get_prompt_updates`` for a 1000x1000
    image must equal what ``get_image_repl`` returns when given the same
    limit explicitly.
    """
    nested_limit = 560
    model_ctx = build_model_context(
        model_id,
        mm_processor_kwargs={"images_kwargs": {"max_soft_tokens": nested_limit}},
        limit_mm_per_prompt={"image": 1},
    )
    processor = MULTIMODAL_REGISTRY.create_processor(model_ctx.model_config)
    info = processor.info

    image = PILImage.new("RGB", (1000, 1000), color="white")
    width, height = image.size
    mm_items = info.parse_mm_data({"image": image})

    first_update = processor._get_prompt_updates(mm_items, {}, {})[0]
    actual = first_update.resolve(0).content.full

    reference = info.get_image_repl(
        image_width=width,
        image_height=height,
        processor=info.get_hf_processor(),
        max_soft_tokens=nested_limit,
    )
    assert actual == reference.full
@pytest.mark.parametrize("model_id", [GEMMA4_UNIFIED_MODEL_ID])
def test_limit_mm_per_prompt(
    image_assets: ImageTestAssets,
    model_id: str,
):
    """Test that limit_mm_per_prompt restricts multiple images correctly."""
    model_ctx = build_model_context(
        model_id,
        mm_processor_kwargs={},
        limit_mm_per_prompt={"image": 1},
    )
    processor = MULTIMODAL_REGISTRY.create_processor(model_ctx.model_config)

    # Two images are needed to exceed the limit of one; duplicate the first
    # asset when the fixture supplies fewer than two.
    pil_images = [asset.pil_image for asset in image_assets][:2]
    if len(pil_images) < 2:
        pil_images = [pil_images[0]] * 2

    parsed_items = processor.info.parse_mm_data({"image": pil_images})
    with pytest.raises(ValueError, match="At most 1 image"):
        processor(
            "<image><image>",
            mm_items=parsed_items,
            hf_processor_mm_kwargs={},
        )
+7
View File
@@ -917,6 +917,13 @@ _MULTIMODAL_EXAMPLE_MODELS = {
"google/gemma-4-E2B-it",
min_transformers_version="5.5.0",
),
# TODO: update min_transformers_version when Gemma4 Unified lands in
# a stable transformers release.
"Gemma4UnifiedForConditionalGeneration": _HfExamplesInfo(
"google/gemma-4-12B-it",
min_transformers_version="5.8.0",
is_available_online=False,
),
"Gemma3nForConditionalGeneration": _HfExamplesInfo("google/gemma-3n-E2B-it"),
"GlmAsrForConditionalGeneration": _HfExamplesInfo(
"zai-org/GLM-ASR-Nano-2512",
+1 -1
View File
@@ -509,7 +509,7 @@ class SpeculativeConfig:
{"n_predict": n_predict, "architectures": ["HYV3MTPModel"]}
)
if hf_config.model_type == "gemma4_assistant":
if hf_config.model_type in ("gemma4_assistant", "gemma4_unified_assistant"):
hf_config.model_type = "gemma4_mtp"
text_config = getattr(hf_config, "text_config", hf_config)
# The assistant runs all decoder layers in a single forward
+1
View File
@@ -597,6 +597,7 @@ MODELS_CONFIG_MAP: dict[str, type[VerifyAndUpdateConfig]] = {
"Gemma3TextModel": Gemma3TextModelConfig,
"Gemma4ForCausalLM": Gemma4Config,
"Gemma4ForConditionalGeneration": Gemma4Config,
"Gemma4UnifiedForConditionalGeneration": Gemma4Config,
"GptOssForCausalLM": GptOssForCausalLMConfig,
"GteModel": SnowflakeGteNewModelConfig,
"GteNewForSequenceClassification": GteNewModelConfig,
+6 -3
View File
@@ -1051,11 +1051,14 @@ class Gemma4Model(nn.Module, EagleModelMixin):
# Final norm: output = norm(x) * weight
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
# Embedding scale = sqrt(hidden_size)
# Downcast to model dtype (bfloat16 etc.) for numerical parity
# Embedding scale = sqrt(hidden_size), cast to model dtype to avoid
# mixed-precision drift from bf16 * fp32 across deep stacks.
self.register_buffer(
"normalizer",
torch.tensor(config.hidden_size**0.5),
torch.tensor(
config.hidden_size**0.5,
dtype=self.embed_tokens.weight.dtype,
),
persistent=False,
)
+64 -19
View File
@@ -121,7 +121,7 @@ class Gemma4ImagePixelInputs(TensorSchema):
- np: Number of patches (max_patches = max_soft_tokens * pooling_kernel_size²)
- pp: Patch pixels (patch_size² * 3)
The HF Gemma4ImageProcessor outputs pixel_values as
The Gemma4 image processor outputs pixel_values as
(batch, max_patches, patch_pixels) — already patchified with
zero-padding for patches beyond the real image content.
pixel_position_ids provides (x, y) coordinates per patch,
@@ -341,6 +341,29 @@ class Gemma4ProcessingInfo(BaseProcessingInfo):
)
return PromptUpdateDetails.select_token_id(token_ids, processor.image_token_id)
@staticmethod
def _compute_audio_num_tokens(
    num_samples: int, sampling_rate: int, audio_seq_length: int
) -> int:
    """Replicate the audio encoder's sequence-length arithmetic.

    Mirrors: mel framing (_unfold in Gemma4AudioFeatureExtractor)
    followed by two Conv2d subsampling layers (kernel=3, stride=2,
    semicausal padding top=1, bottom=1), capped at audio_seq_length.

    Args:
        num_samples: Number of waveform samples in the clip.
        sampling_rate: Samples per second; frames are 20 ms long with a
            10 ms hop.
        audio_seq_length: Upper bound on the returned token count.

    Returns:
        The number of audio soft tokens, or 0 when the clip is too short
        to produce a single mel frame.

    Raises:
        ValueError: If ``sampling_rate`` is so small that the 10 ms hop
            rounds to zero samples (the framing would divide by zero).
    """
    frame_length = round(sampling_rate * 20.0 / 1000.0)
    hop_length = round(sampling_rate * 10.0 / 1000.0)
    if hop_length <= 0:
        raise ValueError(
            f"sampling_rate={sampling_rate} is too low: the 10 ms hop "
            "rounds to zero samples."
        )

    # The unfold window is one sample longer than the frame, and the
    # waveform is left-padded by half a frame before framing.
    frame_size_for_unfold = frame_length + 1
    padded_samples = num_samples + frame_length // 2
    num_mel_frames = (padded_samples - frame_size_for_unfold) // hop_length + 1
    if num_mel_frames <= 0:
        return 0

    # Two stride-2 convolutions: out = (in + 2 * pad - kernel) // 2 + 1
    # with pad = 1 and kernel = 3.
    num_tokens = num_mel_frames
    for _ in range(2):
        num_tokens = (num_tokens + 2 - 3) // 2 + 1
    return min(num_tokens, audio_seq_length)
def get_audio_repl(
self,
*,
@@ -350,20 +373,21 @@ class Gemma4ProcessingInfo(BaseProcessingInfo):
"""Return the dynamic audio token sequence for this audio.
Computes the number of soft tokens from the audio waveform
length using ``ceil(duration_ms / audio_ms_per_token)``.
length by replicating the audio encoder's sequence-length
arithmetic (mel framing + two Conv2d subsampling layers).
"""
if processor is None:
processor = self.get_hf_processor()
sampling_rate = processor.feature_extractor.sampling_rate
num_tokens = processor._compute_audio_num_tokens(
torch.zeros(audio_len), sampling_rate
num_tokens = self._compute_audio_num_tokens(
audio_len, sampling_rate, processor.audio_seq_length
)
config = self.get_hf_config()
token_ids = (
[config.boa_token_id]
+ [processor.audio_token_id] * num_tokens
+ [config.eoa_token_id]
+ [getattr(config, "eoa_token_id", config.eoa_token_index)]
)
return PromptUpdateDetails.select_token_id(token_ids, processor.audio_token_id)
@@ -988,18 +1012,35 @@ class Gemma4ForConditionalGeneration(
self.quant_config = quant_config
self.multimodal_config = multimodal_config
# Only quantize towers when the quant method supports their
# dimensions. BNB/torchao handle arbitrary sizes; other methods
# (Marlin, FP8, …) require dimensions divisible by 64, which
# the vision tower (intermediate_size=4304) does not satisfy.
if quant_config and quant_config.get_name() in [
"bitsandbytes",
"torchao",
]:
tower_quant = quant_config
else:
vision_cfg = config.vision_config
quantizable = (
vision_cfg.hidden_size % 64 == 0
and vision_cfg.intermediate_size % 64 == 0
)
tower_quant = quant_config if quantizable else None
# ---- Vision tower (shared by image and video) ----
with self._mark_tower_model(vllm_config, {"image", "video"}):
self.vision_tower = AutoModel.from_config(config=config.vision_config)
self.embed_vision = Gemma4MultimodalEmbedder(
config.vision_config,
config.text_config,
quant_config=quant_config,
quant_config=tower_quant,
prefix=maybe_prefix(prefix, "embed_vision"),
)
recursive_replace_linear(
self.vision_tower,
quant_config,
tower_quant,
prefix=maybe_prefix(prefix, "vision_tower"),
)
@@ -1015,12 +1056,12 @@ class Gemma4ForConditionalGeneration(
self.embed_audio = Gemma4MultimodalEmbedder(
config.audio_config,
config.text_config,
quant_config=quant_config,
quant_config=tower_quant,
prefix=maybe_prefix(prefix, "embed_audio"),
)
recursive_replace_linear(
self.audio_tower,
quant_config,
tower_quant,
prefix=maybe_prefix(prefix, "audio_tower"),
)
else:
@@ -1039,13 +1080,13 @@ class Gemma4ForConditionalGeneration(
# Pre-allocate PLE buffer for CUDA graph compatibility.
# Some variants have hidden_size_per_layer_input=None (no PLE).
ple_dim = config.text_config.hidden_size_per_layer_input
if ple_dim is not None:
if ple_dim is not None and ple_dim > 0:
self.per_layer_embeddings = torch.zeros(
vllm_config.scheduler_config.max_num_batched_tokens,
config.text_config.num_hidden_layers,
ple_dim,
device=(self.language_model.model.embed_tokens.weight.device),
dtype=(self.language_model.model.embed_tokens.weight.dtype),
device=self.language_model.model.embed_tokens.weight.device,
dtype=self.language_model.model.embed_tokens.weight.dtype,
)
else:
self.per_layer_embeddings = None
@@ -1076,6 +1117,9 @@ class Gemma4ForConditionalGeneration(
self.num_shared_experts = self.language_model.num_shared_experts
self.num_redundant_experts = self.language_model.num_redundant_experts
gen_cfg = vllm_config.model_config.try_get_generation_config()
self._suppress_token_ids = gen_cfg.get("suppress_tokens") if gen_cfg else None
# ------------------------------------------------------------------ #
# Input parsing
# ------------------------------------------------------------------ #
@@ -1424,8 +1468,7 @@ class Gemma4ForConditionalGeneration(
input_features = audio_input["input_features_padded"].squeeze(1)
input_features_mask = audio_input["input_features_mask"].squeeze(1)
# Run audio tower — mask uses standard HF convention
# (True=valid, False=padding).
# Run audio tower — mask convention: True=valid, False=padding.
audio_outputs = self.audio_tower(input_features, input_features_mask)
if isinstance(audio_outputs, tuple):
audio_encodings, audio_mask = audio_outputs
@@ -1436,8 +1479,8 @@ class Gemma4ForConditionalGeneration(
# Project into LM embedding space.
audio_features = self.embed_audio(inputs_embeds=audio_encodings)
# Strip padding per-batch element: only keep real (non-padding)
# tokens. audio_mask is True for valid positions (HF convention).
# Strip padding per-batch element: only keep valid (non-padding)
# tokens.
per_audio = []
for enc, mask in zip(audio_features, audio_mask, strict=True):
per_audio.append(enc[mask]) # [num_real, hidden_size]
@@ -1559,7 +1602,10 @@ class Gemma4ForConditionalGeneration(
self,
hidden_states: torch.Tensor,
) -> torch.Tensor | None:
return self.language_model.compute_logits(hidden_states)
logits = self.language_model.compute_logits(hidden_states)
if logits is not None and self._suppress_token_ids:
logits[:, self._suppress_token_ids] = -float("inf")
return logits
# ------------------------------------------------------------------ #
# Bidirectional attention helpers
@@ -1617,8 +1663,7 @@ class Gemma4ForConditionalGeneration(
"embed_vision.embedding.",
"embed_audio.embedding.",
]
# Models without audio tower should skip
# audio weights entirely.
# Models without audio tower should skip audio weights entirely.
if self.audio_tower is None:
ignore_prefixes.extend(
[
+19 -3
View File
@@ -279,11 +279,19 @@ class Gemma4MTPDecoderLayer(nn.Module):
else config.head_dim
)
use_k_eq_v = is_full_attention and getattr(config, "attention_k_eq_v", False)
if use_k_eq_v:
num_kv_heads = getattr(
config, "num_global_key_value_heads", config.num_key_value_heads
)
else:
num_kv_heads = config.num_key_value_heads
self.self_attn = Gemma4MTPAttention(
config=config,
hidden_size=self.hidden_size,
num_heads=config.num_attention_heads,
num_kv_heads=config.num_key_value_heads,
num_kv_heads=num_kv_heads,
head_dim=head_dim,
max_position_embeddings=config.max_position_embeddings,
cache_config=cache_config,
@@ -545,6 +553,10 @@ class Gemma4MTP(nn.Module):
else:
self.masked_embedding = None
draft_cfg = vllm_config.speculative_config.draft_model_config
gen_cfg = draft_cfg.try_get_generation_config()
self._suppress_token_ids = gen_cfg.get("suppress_tokens") if gen_cfg else None
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.model.embed_input_ids(input_ids)
@@ -589,11 +601,15 @@ class Gemma4MTP(nn.Module):
spec_step_idx: int = 0,
) -> torch.Tensor | None:
if self.masked_embedding is not None:
return self.masked_embedding(
logits = self.masked_embedding(
hidden_states,
self._get_full_lm_head_weight(),
)
return self.logits_processor(self.lm_head, hidden_states)
else:
logits = self.logits_processor(self.lm_head, hidden_states)
if logits is not None and self._suppress_token_ids:
logits[:, self._suppress_token_ids] = -float("inf")
return logits
def get_top_tokens(
self,
@@ -0,0 +1,466 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Gemma 4 Unified multimodal model (encoder-free image + audio + video).
The Unified Gemma4 variant has no SigLIP vision tower and no audio tower.
Raw pixel patches are projected directly to LM space via a Dense+LayerNorm
pipeline with factorized 2D positional embeddings (Gemma4UnifiedVisionEmbedder),
then routed through the same Gemma4MultimodalEmbedder used by the tower-based
variant. Audio inputs are raw waveform frames projected directly through the
multimodal embedder.
This module subclasses Gemma4ForConditionalGeneration from gemma4_mm rather
than reimplementing it from scratch. Only the multimodal pipeline differs;
the language model, MTP integration, bidirectional attention helpers,
embedding/forward path, and LoRA support are all inherited unchanged.
"""
import math
from collections.abc import Iterable, Mapping
import torch
from torch import nn
from transformers.models.gemma4_unified.configuration_gemma4_unified import (
Gemma4UnifiedConfig,
)
from transformers.models.gemma4_unified.processing_gemma4_unified import (
Gemma4UnifiedProcessor,
)
from vllm.config import VllmConfig
from vllm.config.multimodal import VideoDummyOptions
from vllm.model_executor.layers.linear import ColumnParallelLinear
from vllm.model_executor.models.gemma4 import Gemma4ForCausalLM
from vllm.model_executor.models.gemma4_mm import (
_SUPPORTED_SOFT_TOKENS,
_VIDEO_MAX_FRAMES,
_VIDEO_MAX_SOFT_TOKENS,
Gemma4AudioInputs,
Gemma4DummyInputsBuilder,
Gemma4ForConditionalGeneration,
Gemma4ImageInputs,
Gemma4ImagePixelInputs,
Gemma4MultimodalEmbedder,
Gemma4MultiModalProcessor,
Gemma4ProcessingInfo,
_get_max_soft_tokens,
)
from vllm.model_executor.models.module_mapping import MultiModelKeys
from vllm.multimodal import MULTIMODAL_REGISTRY
from .utils import (
AutoWeightsLoader,
WeightsMapper,
init_vllm_registered_model,
maybe_prefix,
)
# Re-export so tests/code targeting the unified variant can import from here
# rather than reaching into gemma4_mm.
__all__ = [
"Gemma4ImagePixelInputs",
"Gemma4UnifiedVisionEmbedder",
"Gemma4UnifiedProcessingInfo",
"Gemma4UnifiedForConditionalGeneration",
]
# ---------------------------------------------------------------------------
# Encoder-free vision embedder
# ---------------------------------------------------------------------------
class Gemma4UnifiedVisionEmbedder(nn.Module):
    """Encoder-free vision embedder for Gemma4 Unified variants.

    Projects raw pixel patches to LM space via dense projection and
    factorized 2D positional embeddings. Replaces the SigLIP vision
    tower used by the tower-based Gemma4 variant.

    Pipeline: raw patches → LN₁ → Dense → LN₂ → +factorized_posemb → LN₃.
    """

    def __init__(self, config, quant_config=None, prefix: str = ""):
        """Build the projection layers and the positional-embedding table.

        Args:
            config: Vision config; ``model_patch_size``, ``mm_embed_dim``
                and ``mm_posemb_size`` are read from it.
            quant_config: Optional quantization config forwarded to the
                dense projection.
            prefix: Optional dotted module path of this embedder. When
                given, the dense projection is created with
                ``"<prefix>.patch_dense"`` as its layer prefix; when empty
                the layer is created without a prefix, as before.
        """
        super().__init__()
        # Each patch is model_patch_size x model_patch_size pixels, 3 channels.
        patch_dim = config.model_patch_size**2 * 3
        mm_embed_dim = config.mm_embed_dim

        self.patch_ln1 = nn.LayerNorm(patch_dim)
        self.patch_dense = ColumnParallelLinear(
            patch_dim,
            mm_embed_dim,
            bias=True,
            quant_config=quant_config,
            gather_output=True,
            prefix=f"{prefix}.patch_dense" if prefix else "",
        )
        self.patch_ln2 = nn.LayerNorm(mm_embed_dim)
        # One embedding table per spatial axis, stored along dim 1.
        self.pos_embedding = nn.Parameter(
            torch.zeros(config.mm_posemb_size, 2, mm_embed_dim)
        )
        self.pos_norm = nn.LayerNorm(mm_embed_dim)

    def _factorized_posemb(self, positions_xy: torch.Tensor) -> torch.Tensor:
        """Sum the per-axis positional embeddings for each patch.

        Args:
            positions_xy: Integer coordinates whose last dimension holds
                the two axis indices. A coordinate of ``-1`` marks the
                axis as invalid and contributes nothing.

        Returns:
            A tensor shaped like ``positions_xy`` with the last dimension
            replaced by the embedding width, in the table's dtype.
        """
        # Clamp so invalid (-1) coordinates still index safely; their
        # contribution is zeroed by the mask below.
        clamped_pos = positions_xy.clamp(min=0).long()
        valid_mask = positions_xy != -1

        pos_embs = torch.zeros(
            *positions_xy.shape[:-1],
            self.pos_embedding.shape[-1],
            device=positions_xy.device,
            dtype=self.pos_embedding.dtype,
        )
        for axis in range(2):
            axis_table = self.pos_embedding[:, axis, :]
            axis_pe = axis_table[clamped_pos[..., axis]]
            axis_mask = valid_mask[..., axis].unsqueeze(-1).to(axis_pe.dtype)
            pos_embs = pos_embs + (axis_pe * axis_mask)
        return pos_embs

    def forward(
        self,
        pixel_values: torch.Tensor,
        pixel_position_ids: torch.Tensor,
    ) -> torch.Tensor:
        """Embed pre-patchified pixels and add positional information.

        Args:
            pixel_values: Flattened patch pixels; cast to the dtype of the
                positional-embedding table before the first LayerNorm.
            pixel_position_ids: Per-patch coordinates passed to
                ``_factorized_posemb``.

        Returns:
            The normalized patch embeddings.
        """
        hidden_states = self.patch_ln1(pixel_values.to(self.pos_embedding.dtype))
        # The parallel linear returns (output, bias); only output is used.
        hidden_states, _ = self.patch_dense(hidden_states)
        hidden_states = self.patch_ln2(hidden_states)
        hidden_states = hidden_states + self._factorized_posemb(pixel_position_ids)
        return self.pos_norm(hidden_states)
# ---------------------------------------------------------------------------
# Processing info
# ---------------------------------------------------------------------------
class Gemma4UnifiedProcessingInfo(Gemma4ProcessingInfo):
    """ProcessingInfo for the Gemma4 Unified variant.

    Differs from the tower-based parent in two field names:

    * the HF config is resolved as ``Gemma4UnifiedConfig``;
    * the default image budget is read from
      ``vision_config.num_soft_tokens``.

    Methods not overridden here are inherited from the parent.
    """

    def get_hf_config(self):
        """Return the HF config, resolved as ``Gemma4UnifiedConfig``."""
        return self.ctx.get_hf_config(Gemma4UnifiedConfig)

    def get_hf_processor(self, **kwargs: object) -> Gemma4UnifiedProcessor:
        """Return the HF processor, resolved as ``Gemma4UnifiedProcessor``."""
        return self.ctx.get_hf_processor(Gemma4UnifiedProcessor, **kwargs)

    def get_mm_max_tokens_per_item(
        self, seq_len: int, mm_counts: Mapping[str, int]
    ) -> Mapping[str, int] | None:
        """Return the maximum token count per multimodal item, by modality.

        ``seq_len`` and ``mm_counts`` are accepted for interface
        compatibility; this implementation does not read them.
        """
        hf_config = self.get_hf_config()

        # Image: config default, overridden by a supported
        # ``max_soft_tokens`` value found in the merged processor kwargs.
        image_budget = hf_config.vision_config.num_soft_tokens
        override, _ = _get_max_soft_tokens(self.ctx.get_merged_mm_kwargs({}))
        if isinstance(override, int) and override in _SUPPORTED_SOFT_TOKENS:
            image_budget = override
        budgets: dict[str, int] = {"image": image_budget}

        # Audio: only reported when the checkpoint has an audio config.
        if hf_config.audio_config is not None:
            budgets["audio"] = self.get_hf_processor().audio_seq_length

        # Video: frame count is capped by _VIDEO_MAX_FRAMES and may be
        # lowered by the per-prompt video options.
        frame_cap = _VIDEO_MAX_FRAMES
        limits = self.ctx.model_config.get_multimodal_config().limit_per_prompt
        video_limit = limits.get("video")
        if (
            isinstance(video_limit, VideoDummyOptions)
            and video_limit.num_frames is not None
        ):
            frame_cap = min(frame_cap, video_limit.num_frames)
        budgets["video"] = frame_cap * (_VIDEO_MAX_SOFT_TOKENS + 2 + 6)

        return budgets

    def _compute_num_soft_tokens(
        self,
        image_width: int,
        image_height: int,
        max_soft_tokens: int | None = None,
    ) -> int:
        """Return the soft-token count for an image, capped at the maximum.

        The image is rescaled so its patch count approaches
        ``max_soft_tokens * pooling_kernel_size**2``, each side is snapped
        down to a multiple of ``patch_size * pooling_kernel_size`` (at
        least one such unit), and the pooled patch count is clamped to
        ``max_soft_tokens``. When ``max_soft_tokens`` is ``None`` the
        config's ``num_soft_tokens`` is used.
        """
        vision_cfg = self.get_hf_config().vision_config
        patch_size = vision_cfg.patch_size
        pooling_kernel_size = vision_cfg.pooling_kernel_size
        if max_soft_tokens is None:
            max_soft_tokens = vision_cfg.num_soft_tokens

        pool_area = pooling_kernel_size**2
        unit = patch_size * pooling_kernel_size
        max_patches = max_soft_tokens * pool_area

        num_patches_orig = (image_height / patch_size) * (image_width / patch_size)
        scale = math.sqrt(max_patches / num_patches_orig)

        def snap(extent: int) -> int:
            # Round the scaled side down to a whole unit, never below one.
            return max(unit, int(math.floor(extent * scale / unit)) * unit)

        target_h = snap(image_height)
        target_w = snap(image_width)

        num_patches = (target_h // patch_size) * (target_w // patch_size)
        # The per-side floor of one unit can overshoot the budget for very
        # elongated images, hence the final clamp.
        return min(num_patches // pool_area, max_soft_tokens)
# ---------------------------------------------------------------------------
# Main model
# ---------------------------------------------------------------------------
@MULTIMODAL_REGISTRY.register_processor(
Gemma4MultiModalProcessor,
info=Gemma4UnifiedProcessingInfo,
dummy_inputs=Gemma4DummyInputsBuilder,
)
class Gemma4UnifiedForConditionalGeneration(Gemma4ForConditionalGeneration):
"""Encoder-free Gemma4 (Unified) for conditional generation.
Inherits multimodal embedding routing, PLE handling, bidirectional
attention helpers, language-model forward, LoRA, and pipeline-parallel
support from :class:`Gemma4ForConditionalGeneration`. Overrides only:
* ``__init__`` — builds the encoder-free vision embedder instead of
SigLIP/audio towers (LightOnOCR-style: ``nn.Module.__init__`` +
full rebuild, no ``super().__init__()``).
* ``hf_to_vllm_mapper`` — adds the ``model.vision_embedder.`` prefix.
* ``_process_image_input`` / ``_process_video_input`` /
``_process_audio_input`` — encoder-free projection paths.
* ``load_weights`` — ignore-prefix list excludes the absent towers.
* ``get_mm_mapping`` — no tower entries.
"""
hf_to_vllm_mapper = WeightsMapper(
orig_to_new_prefix={
"model.embed_audio.": "embed_audio.",
"model.embed_vision.": "embed_vision.",
"model.language_model.": "language_model.model.",
"model.vision_embedder.": "vision_embedder.",
"lm_head.": "language_model.lm_head.",
"model": "language_model.model",
}
)
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
    """Assemble the encoder-free model without the parent constructor.

    Args:
        vllm_config: Engine configuration; the HF config, quantization
            config, multimodal config, scheduler limits and generation
            config are read from it.
        prefix: Module path prefix used when creating the language model.
    """
    # LightOnOCR-style rebuild: the parent __init__ is intentionally NOT
    # called, because it would build a SigLIP vision tower and an audio
    # tower this variant does not need. Only nn.Module is initialized.
    nn.Module.__init__(self)

    model_config = vllm_config.model_config
    config = model_config.hf_config
    quant_config = vllm_config.quant_config
    text_config = config.text_config

    self.config = config
    self.quant_config = quant_config
    self.multimodal_config = model_config.multimodal_config

    # No towers — kept as None so inherited load_weights / get_mm_mapping
    # and any tower-aware logic short-circuits.
    self.vision_tower = None
    self.audio_tower = None

    # ---- Encoder-free vision embedder + projection into LM space ----
    vision_config = config.vision_config
    if vision_config is not None:
        self.vision_embedder = Gemma4UnifiedVisionEmbedder(
            vision_config,
            quant_config=quant_config,
        )
        self.embed_vision = Gemma4MultimodalEmbedder(vision_config, text_config)
    else:
        self.vision_embedder = None
        self.embed_vision = None

    # ---- Encoder-free audio embedder ----
    audio_config = config.audio_config
    if audio_config is not None:
        self.embed_audio = Gemma4MultimodalEmbedder(audio_config, text_config)
    else:
        self.embed_audio = None

    # ---- Language model (vLLM optimised) ----
    with self._mark_language_model(vllm_config):
        self.language_model: Gemma4ForCausalLM = init_vllm_registered_model(
            vllm_config=vllm_config,
            hf_config=text_config,
            prefix=maybe_prefix(prefix, "language_model"),
            architectures=["Gemma4ForCausalLM"],
        )

    # PLE is disabled for the unified variant (text config defaults
    # hidden_size_per_layer_input to 0), so the buffer is only allocated
    # when a positive dimension is configured.
    ple_dim = getattr(text_config, "hidden_size_per_layer_input", None)
    has_ple = ple_dim is not None and ple_dim > 0
    if has_ple:
        embed_weight = self.language_model.model.embed_tokens.weight
        self.per_layer_embeddings = torch.zeros(
            vllm_config.scheduler_config.max_num_batched_tokens,
            text_config.num_hidden_layers,
            ple_dim,
            device=embed_weight.device,
            dtype=embed_weight.dtype,
        )
    else:
        self.per_layer_embeddings = None

    self.make_empty_intermediate_tensors = (
        self.language_model.make_empty_intermediate_tensors
    )

    # --- Precompute full-attention layer indices for bidi clearing ---
    full_attn_layers: frozenset[int] = frozenset()
    if getattr(text_config, "use_bidirectional_attention", None) == "vision":
        layer_types = getattr(text_config, "layer_types", None)
        if layer_types:
            full_attn_layers = frozenset(
                idx
                for idx, kind in enumerate(layer_types)
                if kind != "sliding_attention"
            )
    self._full_attn_layer_idxs = full_attn_layers

    # --- MixtureOfExperts delegation to language_model ---
    for moe_attr in (
        "expert_weights",
        "moe_layers",
        "num_moe_layers",
        "num_logical_experts",
        "num_physical_experts",
        "num_local_physical_experts",
        "num_routed_experts",
        "num_expert_groups",
        "num_shared_experts",
        "num_redundant_experts",
    ):
        setattr(self, moe_attr, getattr(self.language_model, moe_attr))

    # Token ids listed under "suppress_tokens" in the generation config,
    # or None when no generation config is available.
    gen_cfg = model_config.try_get_generation_config()
    self._suppress_token_ids = gen_cfg.get("suppress_tokens") if gen_cfg else None
# ------------------------------------------------------------------ #
# Multimodal processing (encoder-free overrides)
# ------------------------------------------------------------------ #
def _process_image_input(
    self,
    image_input: Gemma4ImageInputs,
) -> list[torch.Tensor]:
    """Turn pre-patchified images into LM-space embeddings, one per image.

    There is no vision tower in this model. Every image is handled on its
    own: its patches go through ``self.vision_embedder`` and then through
    ``self.embed_vision``, and patches whose position ids are all ``-1``
    (padding) are removed from the result.

    Args:
        image_input: Mapping providing ``"pixel_values"`` and
            ``"pixel_position_ids"``; the two are iterated in lockstep,
            one entry per image.

    Returns:
        One tensor per image holding only the non-padding patch
        embeddings.

    Raises:
        ValueError: If the two inputs yield a different number of images
            (enforced by ``zip(..., strict=True)``).
    """
    all_patches = image_input["pixel_values"]
    all_position_ids = image_input["pixel_position_ids"]
    # Embedder output is cast to the projection weight dtype before use.
    projection_dtype = self.embed_vision.embedding_projection.weight.dtype

    def embed_single_image(
        patches: torch.Tensor,
        position_ids: torch.Tensor,
    ) -> torch.Tensor:
        # Both modules are called with a leading batch axis of size 1.
        hidden = self.vision_embedder(
            patches.unsqueeze(0),
            position_ids.unsqueeze(0),
        )
        lm_features = self.embed_vision(hidden.to(projection_dtype)).squeeze(0)
        # A patch counts as padding only when every position coordinate
        # equals -1.
        is_padding = (position_ids == -1).all(dim=-1)
        return lm_features[~is_padding]

    return [
        embed_single_image(patches, position_ids)
        for patches, position_ids in zip(
            all_patches, all_position_ids, strict=True
        )
    ]
def _process_video_input(
self,
video_input: dict[str, torch.Tensor],
) -> list[torch.Tensor]:
"""Project video frames to LM space, one frame at a time.
Frames are split per video, each frame is embedded + projected,
and per-frame valid embeddings are concatenated per video.
"""
pixel_values = video_input["pixel_values_videos"]
pixel_position_ids = video_input["pixel_position_ids_videos"]
frame_counts = video_input["video_frame_counts"]
target_dtype = self.embed_vision.embedding_projection.weight.dtype
if isinstance(frame_counts, torch.Tensor):
fc_list = frame_counts.tolist()
else:
fc_list = list(frame_counts)
pv_per_video = torch.split(pixel_values, fc_list, dim=0)
pp_per_video = torch.split(pixel_position_ids, fc_list, dim=0)
per_video_embeddings: list[torch.Tensor] = []
for pv_chunk, pp_chunk in zip(pv_per_video, pp_per_video):
frame_embs: list[torch.Tensor] = []
for i in range(pv_chunk.shape[0]):
pv = pv_chunk[i].unsqueeze(0)
pp = pp_chunk[i].unsqueeze(0)
embedded = self.vision_embedder(pv, pp)
projected = self.embed_vision(embedded.to(target_dtype))
padding_mask = (pp.squeeze(0) == -1).all(dim=-1)
frame_embs.append(projected.squeeze(0)[~padding_mask])
per_video_embeddings.append(torch.cat(frame_embs, dim=0))
return per_video_embeddings
def _process_audio_input(
    self,
    audio_input: Gemma4AudioInputs,
) -> list[torch.Tensor]:
    """Map raw audio frame features into LM space and drop padding.

    There is no audio tower: the padded features are cast to the
    projection weight dtype and passed straight to ``self.embed_audio``.
    Each item's output is then indexed with its own mask.

    Args:
        audio_input: Mapping providing ``"input_features_padded"`` and
            ``"input_features_mask"``; dim 1 of each is squeezed away
            before use.

    Returns:
        One tensor per audio item, containing the rows selected by that
        item's mask.

    Raises:
        ValueError: If the projected features and the mask differ in
            length along dim 0 (enforced by ``zip(..., strict=True)``).
    """
    padded_features = audio_input["input_features_padded"].squeeze(1)
    # NOTE(review): presumably a boolean mask with True on valid frames -
    # confirm against the processor, since it is used for indexing below.
    valid_mask = audio_input["input_features_mask"].squeeze(1)
    projection_dtype = self.embed_audio.embedding_projection.weight.dtype
    projected = self.embed_audio(padded_features.to(projection_dtype))
    return [
        item_features[item_mask]
        for item_features, item_mask in zip(projected, valid_mask, strict=True)
    ]
# ------------------------------------------------------------------ #
# Weight loading
# ------------------------------------------------------------------ #
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
    """Load checkpoint weights via ``AutoWeightsLoader``.

    Names are remapped with ``self.hf_to_vllm_mapper``. Checkpoint
    entries under the ignored prefixes below are tolerated when they
    have no counterpart in this module.

    Args:
        weights: Iterable of ``(name, tensor)`` pairs from the checkpoint.

    Returns:
        Whatever ``AutoWeightsLoader.load_weights`` returns (annotated as
        a set of names).
    """
    # Vestigial Gemma3n-style embedding tables not used by
    # Gemma4MultimodalEmbedder (which has only projection + norm).
    skipped_prefixes = ["embed_vision.embedding.", "embed_audio.embedding."]
    # With no audio embedder present, every audio-embedder weight in the
    # checkpoint is unexpected, so ignore that whole subtree.
    if self.embed_audio is None:
        skipped_prefixes += ["embed_audio."]
    weights_loader = AutoWeightsLoader(
        self,
        ignore_unexpected_prefixes=skipped_prefixes,
    )
    return weights_loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)
# ------------------------------------------------------------------ #
# LoRA / multimodal mapping
# ------------------------------------------------------------------ #
def get_mm_mapping(self) -> MultiModelKeys:
    """Return the module-prefix mapping for this encoder-free model.

    The language model lives under ``"language_model"``. The connector
    list always contains ``"embed_vision"`` and additionally
    ``"embed_audio"`` when an audio embedder is present. The tower list
    is empty because this model has no towers.
    """
    has_audio = self.embed_audio is not None
    connector_prefixes = (
        ["embed_vision", "embed_audio"] if has_audio else ["embed_vision"]
    )
    return MultiModelKeys.from_string_field(
        language_model="language_model",
        connector=connector_prefixes,
        tower_model=[],
    )
+4
View File
@@ -406,6 +406,10 @@ _MULTIMODAL_MODELS = {
"Gemma3nForConditionalGeneration",
),
"Gemma4ForConditionalGeneration": ("gemma4_mm", "Gemma4ForConditionalGeneration"),
"Gemma4UnifiedForConditionalGeneration": (
"gemma4_unified",
"Gemma4UnifiedForConditionalGeneration",
),
"GlmAsrForConditionalGeneration": ("glmasr", "GlmAsrForConditionalGeneration"),
"GLM4VForCausalLM": ("glm4v", "GLM4VForCausalLM"),
"Glm4vForConditionalGeneration": ("glm4_1v", "Glm4vForConditionalGeneration"),
@@ -556,6 +556,8 @@ MODEL_ARCH_CONFIG_CONVERTORS = {
"gemma4": Gemma4ModelArchConfigConvertor,
"gemma4_mtp": Gemma4MTPModelArchConfigConvertor,
"gemma4_text": Gemma4ModelArchConfigConvertor,
"gemma4_unified": Gemma4ModelArchConfigConvertor,
"gemma4_unified_text": Gemma4ModelArchConfigConvertor,
"glm4_moe_mtp": GLM4MoeMTPModelArchConfigConvertor,
"glm_ocr_mtp": GLM4MoeMTPModelArchConfigConvertor,
"longcat_flash_mtp": LongCatFlashMTPModelArchConfigConvertor,
+1
View File
@@ -1232,6 +1232,7 @@ class SpecDecodeBaseProposer:
"Qwen3VLForConditionalGeneration",
"Qwen3VLMoeForConditionalGeneration",
"Gemma4ForConditionalGeneration",
"Gemma4UnifiedForConditionalGeneration",
"Step3p7ForConditionalGeneration",
]:
self.model.config.image_token_index = target_model.config.image_token_id
+2
View File
@@ -2467,6 +2467,8 @@ class GPUModelRunner(
image_doc_ranges = []
req_state = self.requests[req_id]
for mm_feature in req_state.mm_features:
if mm_feature.modality == "audio":
continue
pos_info = mm_feature.mm_position
img_doc_range = pos_info.extract_embeds_range()
for r in img_doc_range: