mirror of
https://github.com/vllm-project/vllm.git
synced 2026-06-06 00:16:14 +00:00
[Model] Add Gemma4 Unified (encoder-free) support (#44429)
Signed-off-by: Luciano Martins <lucianommartins@users.noreply.github.com> Co-authored-by: Luciano Martins <lucianommartins@users.noreply.github.com>
This commit is contained in:
@@ -24,10 +24,11 @@ vllm serve google/gemma-4-E2B-it \
|
||||
--speculative-config '{"method":"mtp","model":"gg-hf-am/gemma-4-E2B-it-assistant","num_speculative_tokens":1}'
|
||||
```
|
||||
|
||||
The E2B, E4B, 26B-A4B, and 31B Gemma 4 IT assistant checkpoints are supported
|
||||
when their configuration uses `model_type: gemma4_assistant`. vLLM maps those
|
||||
checkpoints to `Gemma4MTPModel` internally and wires the assistant layers to
|
||||
share KV cache with the target model.
|
||||
The E2B, E4B, 12B, 26B-A4B, and 31B Gemma 4 IT assistant checkpoints are supported.
|
||||
Tower-based variants use `model_type: gemma4_assistant` and the encoder-free
|
||||
Gemma 4 Unified variant (12B) uses `model_type: gemma4_unified_assistant`.
|
||||
vLLM maps both to `Gemma4MTPModel` internally and wires the assistant layers
|
||||
to share KV cache with the target model.
|
||||
|
||||
If an older vLLM release logs `SpeculativeConfig(method='draft_model', ...)`
|
||||
for a Gemma 4 assistant checkpoint, that release is treating the assistant as a
|
||||
|
||||
@@ -562,6 +562,7 @@ These models primarily accept the [`LLM.generate`](./generative_models.md#llmgen
|
||||
| `Gemma3ForConditionalGeneration` | Gemma 3 | T + I<sup>E+</sup> | `google/gemma-3-4b-it`, `google/gemma-3-27b-it`, etc. | ✅︎ | ✅︎ |
|
||||
| `Gemma3nForConditionalGeneration` | Gemma 3n | T + I + A | `google/gemma-3n-E2B-it`, `google/gemma-3n-E4B-it`, etc. | | |
|
||||
| `Gemma4ForConditionalGeneration` | Gemma 4 | T + I<sup>+</sup> + V + A<sup>*</sup> | `google/gemma-4-E2B-it`, etc. | | ✅︎ |
|
||||
| `Gemma4UnifiedForConditionalGeneration` | Gemma 4 Unified | T + I<sup>+</sup> + V + A | `google/gemma-4-12B-it`, etc. | | ✅︎ |
|
||||
| `GLM4VForCausalLM`<sup>^</sup> | GLM-4V | T + I | `zai-org/glm-4v-9b`, `zai-org/cogagent-9b-20241220`, etc. | ✅︎ | ✅︎ |
|
||||
| `Glm4vForConditionalGeneration` | GLM-4.1V-Thinking | T + I<sup>E+</sup> + V<sup>E+</sup> | `zai-org/GLM-4.1V-9B-Thinking`, etc. | ✅︎ | ✅︎ |
|
||||
| `Glm4vMoeForConditionalGeneration` | GLM-4.5V | T + I<sup>E+</sup> + V<sup>E+</sup> | `zai-org/GLM-4.5V`, etc. | ✅︎ | ✅︎ |
|
||||
@@ -664,10 +665,16 @@ Some models are supported only via the [Transformers modeling backend](#transfor
|
||||
For `Gemma4ForConditionalGeneration`:
|
||||
- audio input is only supported by the `gemma-4-E2B` and `gemma-4-E4B` variants.
|
||||
- The model does not ingest videos directly. However, vLLM’s Gemma 4 implementation supports video inputs by handling video processing internally. Users can send videos directly in the message structure to vLLM, where they are converted into text and image frames before being passed to the model.
|
||||
- Gemma 4 assistant checkpoints for speculative decoding use vLLM's Gemma
|
||||
- Gemma 4 assistant checkpoints for speculative decoding use vLLM’s Gemma
|
||||
4 MTP path, not generic draft-model speculative decoding. See the
|
||||
[Gemma 4 assistant model MTP example](../features/speculative_decoding/mtp.md#gemma-4-assistant-models).
|
||||
|
||||
!!! note
|
||||
For `Gemma4UnifiedForConditionalGeneration`:
|
||||
- This is the encoder-free Gemma 4 variant (e.g. `gemma-4-12B-it`). Unlike the tower-based `Gemma4ForConditionalGeneration`, it has **no SigLIP vision encoder** and **no audio encoder**. Raw pixel patches are projected directly into LM space via a Dense+LayerNorm pipeline with factorized positional embeddings, and raw audio waveform frames are projected directly through a multimodal embedder.
|
||||
- All modalities (image, video, audio) are supported.
|
||||
- Gemma 4 Unified assistant checkpoints (`model_type: gemma4_unified_assistant`) use the same MTP path as the tower-based variant. See the [Gemma 4 assistant model MTP example](../features/speculative_decoding/mtp.md#gemma-4-assistant-models).
|
||||
|
||||
!!! note
|
||||
For `InternVLChatModel`, only InternVL2.5 with Qwen2.5 text backbone (`OpenGVLab/InternVL2.5-1B` etc.), InternVL3 and InternVL3.5 have video inputs support currently.
|
||||
|
||||
|
||||
@@ -0,0 +1,205 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from collections.abc import Mapping
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
from PIL import Image as PILImage
|
||||
|
||||
from vllm.model_executor.models.gemma4_mm import Gemma4ImagePixelInputs
|
||||
from vllm.multimodal import MULTIMODAL_REGISTRY
|
||||
from vllm.multimodal.inputs import MultiModalFieldConfig
|
||||
|
||||
from ....conftest import ImageTestAssets
|
||||
from ...utils import build_model_context
|
||||
|
||||
# The Unified model ID for testing purposes
|
||||
GEMMA4_UNIFIED_MODEL_ID = "google/gemma-4-12B-it"
|
||||
|
||||
|
||||
def test_gemma4_unified_image_schema_accepts_variable_patch_counts():
|
||||
Gemma4ImagePixelInputs(
|
||||
pixel_values=[
|
||||
torch.randn(10080, 768),
|
||||
torch.randn(2520, 768),
|
||||
],
|
||||
pixel_position_ids=[
|
||||
torch.zeros(10080, 2, dtype=torch.long),
|
||||
torch.zeros(2520, 2, dtype=torch.long),
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
def test_gemma4_unified_image_batching_keeps_variable_patch_counts_unstacked():
|
||||
field = MultiModalFieldConfig.batched("image").field
|
||||
elems = field.build_elems(
|
||||
"image",
|
||||
"pixel_values",
|
||||
[torch.randn(10080, 768), torch.randn(2520, 768)],
|
||||
)
|
||||
|
||||
reduced = field.reduce_data(list(elems))
|
||||
|
||||
assert isinstance(reduced, list)
|
||||
assert [tensor.shape for tensor in reduced] == [
|
||||
torch.Size([10080, 768]),
|
||||
torch.Size([2520, 768]),
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"image_width,image_height,max_soft_tokens",
|
||||
[
|
||||
(900, 3, 280),
|
||||
(3, 900, 280),
|
||||
(900, 3, 70),
|
||||
(4000, 2, 1120),
|
||||
],
|
||||
)
|
||||
@pytest.mark.parametrize("model_id", [GEMMA4_UNIFIED_MODEL_ID])
|
||||
def test_compute_num_soft_tokens_does_not_exceed_max_soft_tokens(
|
||||
model_id: str,
|
||||
image_width: int,
|
||||
image_height: int,
|
||||
max_soft_tokens: int,
|
||||
):
|
||||
"""Verify ``_compute_num_soft_tokens`` caps output at ``max_soft_tokens``."""
|
||||
ctx = build_model_context(
|
||||
model_id,
|
||||
mm_processor_kwargs={"do_pan_and_scan": True},
|
||||
limit_mm_per_prompt={"image": 1},
|
||||
)
|
||||
processor = MULTIMODAL_REGISTRY.create_processor(ctx.model_config)
|
||||
|
||||
num_soft_tokens = processor.info._compute_num_soft_tokens(
|
||||
image_width=image_width,
|
||||
image_height=image_height,
|
||||
max_soft_tokens=max_soft_tokens,
|
||||
)
|
||||
|
||||
assert num_soft_tokens <= max_soft_tokens, (
|
||||
f"_compute_num_soft_tokens returned {num_soft_tokens} for "
|
||||
f"image_width={image_width}, image_height={image_height}, "
|
||||
f"max_soft_tokens={max_soft_tokens} — exceeds the cap."
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("mm_processor_kwargs", "expected_image_tokens"),
|
||||
[
|
||||
({}, 280),
|
||||
({"max_soft_tokens": 70}, 70),
|
||||
({"max_soft_tokens": 280}, 280),
|
||||
({"max_soft_tokens": 1120}, 1120),
|
||||
({"images_kwargs": {"max_soft_tokens": 560}}, 560),
|
||||
({"images_kwargs": None}, 280),
|
||||
({"images_kwargs": "not-a-dict"}, 280),
|
||||
],
|
||||
)
|
||||
@pytest.mark.parametrize("model_id", [GEMMA4_UNIFIED_MODEL_ID])
|
||||
def test_get_mm_max_tokens_per_item_respects_configured_max_soft_tokens(
|
||||
model_id: str,
|
||||
mm_processor_kwargs: dict[str, object],
|
||||
expected_image_tokens: int,
|
||||
):
|
||||
ctx = build_model_context(
|
||||
model_id,
|
||||
mm_processor_kwargs=mm_processor_kwargs,
|
||||
limit_mm_per_prompt={"image": 1, "video": 1},
|
||||
)
|
||||
processor = MULTIMODAL_REGISTRY.create_processor(ctx.model_config)
|
||||
|
||||
tokens = processor.info.get_mm_max_tokens_per_item(
|
||||
seq_len=ctx.model_config.max_model_len,
|
||||
mm_counts={"image": 1, "video": 1},
|
||||
)
|
||||
|
||||
assert tokens is not None
|
||||
assert tokens["image"] == expected_image_tokens
|
||||
assert tokens["video"] == 32 * (70 + 2 + 6)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("limit_mm_per_prompt", "expected_video_tokens"),
|
||||
[
|
||||
({"video": 1}, 32 * (70 + 2 + 6)),
|
||||
({"video": {"count": 1}}, 32 * (70 + 2 + 6)),
|
||||
({"video": {"count": 1, "num_frames": 1}}, 1 * (70 + 2 + 6)),
|
||||
({"video": {"count": 1, "num_frames": 8}}, 8 * (70 + 2 + 6)),
|
||||
({"video": {"count": 1, "num_frames": 32}}, 32 * (70 + 2 + 6)),
|
||||
({"video": {"count": 1, "num_frames": 40}}, 32 * (70 + 2 + 6)),
|
||||
],
|
||||
)
|
||||
@pytest.mark.parametrize("model_id", [GEMMA4_UNIFIED_MODEL_ID])
|
||||
def test_get_mm_max_tokens_per_item_respects_configured_video_num_frames(
|
||||
model_id: str,
|
||||
limit_mm_per_prompt: Mapping[str, int | Mapping[str, int]],
|
||||
expected_video_tokens: int,
|
||||
):
|
||||
ctx = build_model_context(
|
||||
model_id,
|
||||
limit_mm_per_prompt=limit_mm_per_prompt,
|
||||
)
|
||||
processor = MULTIMODAL_REGISTRY.create_processor(ctx.model_config)
|
||||
|
||||
tokens = processor.info.get_mm_max_tokens_per_item(
|
||||
seq_len=ctx.model_config.max_model_len,
|
||||
mm_counts={"video": 1},
|
||||
)
|
||||
|
||||
assert tokens is not None
|
||||
assert tokens["image"] == 280
|
||||
assert tokens["video"] == expected_video_tokens
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model_id", [GEMMA4_UNIFIED_MODEL_ID])
|
||||
def test_get_prompt_updates_respects_nested_max_soft_tokens(model_id: str):
|
||||
ctx = build_model_context(
|
||||
model_id,
|
||||
mm_processor_kwargs={"images_kwargs": {"max_soft_tokens": 560}},
|
||||
limit_mm_per_prompt={"image": 1},
|
||||
)
|
||||
processor = MULTIMODAL_REGISTRY.create_processor(ctx.model_config)
|
||||
image = PILImage.new("RGB", (1000, 1000), color="white")
|
||||
image_size = image.size
|
||||
mm_items = processor.info.parse_mm_data({"image": image})
|
||||
|
||||
prompt_update = processor._get_prompt_updates(mm_items, {}, {})[0]
|
||||
replacement = prompt_update.resolve(0).content.full
|
||||
expected = processor.info.get_image_repl(
|
||||
image_width=image_size[0],
|
||||
image_height=image_size[1],
|
||||
processor=processor.info.get_hf_processor(),
|
||||
max_soft_tokens=560,
|
||||
).full
|
||||
|
||||
assert replacement == expected
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model_id", [GEMMA4_UNIFIED_MODEL_ID])
|
||||
def test_limit_mm_per_prompt(
|
||||
image_assets: ImageTestAssets,
|
||||
model_id: str,
|
||||
):
|
||||
"""Test that limit_mm_per_prompt restricts multiple images correctly."""
|
||||
ctx = build_model_context(
|
||||
model_id,
|
||||
mm_processor_kwargs={},
|
||||
limit_mm_per_prompt={"image": 1},
|
||||
)
|
||||
processor = MULTIMODAL_REGISTRY.create_processor(ctx.model_config)
|
||||
|
||||
prompt = "<image><image>"
|
||||
images = [asset.pil_image for asset in image_assets][:2]
|
||||
if len(images) < 2:
|
||||
images = [images[0], images[0]]
|
||||
|
||||
mm_data = {"image": images}
|
||||
|
||||
with pytest.raises(ValueError, match="At most 1 image"):
|
||||
processor(
|
||||
prompt,
|
||||
mm_items=processor.info.parse_mm_data(mm_data),
|
||||
hf_processor_mm_kwargs={},
|
||||
)
|
||||
@@ -917,6 +917,13 @@ _MULTIMODAL_EXAMPLE_MODELS = {
|
||||
"google/gemma-4-E2B-it",
|
||||
min_transformers_version="5.5.0",
|
||||
),
|
||||
# TODO: update min_transformers_version when Gemma4 Unified lands in
|
||||
# a stable transformers release.
|
||||
"Gemma4UnifiedForConditionalGeneration": _HfExamplesInfo(
|
||||
"google/gemma-4-12B-it",
|
||||
min_transformers_version="5.8.0",
|
||||
is_available_online=False,
|
||||
),
|
||||
"Gemma3nForConditionalGeneration": _HfExamplesInfo("google/gemma-3n-E2B-it"),
|
||||
"GlmAsrForConditionalGeneration": _HfExamplesInfo(
|
||||
"zai-org/GLM-ASR-Nano-2512",
|
||||
|
||||
@@ -509,7 +509,7 @@ class SpeculativeConfig:
|
||||
{"n_predict": n_predict, "architectures": ["HYV3MTPModel"]}
|
||||
)
|
||||
|
||||
if hf_config.model_type == "gemma4_assistant":
|
||||
if hf_config.model_type in ("gemma4_assistant", "gemma4_unified_assistant"):
|
||||
hf_config.model_type = "gemma4_mtp"
|
||||
text_config = getattr(hf_config, "text_config", hf_config)
|
||||
# The assistant runs all decoder layers in a single forward
|
||||
|
||||
@@ -597,6 +597,7 @@ MODELS_CONFIG_MAP: dict[str, type[VerifyAndUpdateConfig]] = {
|
||||
"Gemma3TextModel": Gemma3TextModelConfig,
|
||||
"Gemma4ForCausalLM": Gemma4Config,
|
||||
"Gemma4ForConditionalGeneration": Gemma4Config,
|
||||
"Gemma4UnifiedForConditionalGeneration": Gemma4Config,
|
||||
"GptOssForCausalLM": GptOssForCausalLMConfig,
|
||||
"GteModel": SnowflakeGteNewModelConfig,
|
||||
"GteNewForSequenceClassification": GteNewModelConfig,
|
||||
|
||||
@@ -1051,11 +1051,14 @@ class Gemma4Model(nn.Module, EagleModelMixin):
|
||||
# Final norm: output = norm(x) * weight
|
||||
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
|
||||
# Embedding scale = sqrt(hidden_size)
|
||||
# Downcast to model dtype (bfloat16 etc.) for numerical parity
|
||||
# Embedding scale = sqrt(hidden_size), cast to model dtype to avoid
|
||||
# mixed-precision drift from bf16 * fp32 across deep stacks.
|
||||
self.register_buffer(
|
||||
"normalizer",
|
||||
torch.tensor(config.hidden_size**0.5),
|
||||
torch.tensor(
|
||||
config.hidden_size**0.5,
|
||||
dtype=self.embed_tokens.weight.dtype,
|
||||
),
|
||||
persistent=False,
|
||||
)
|
||||
|
||||
|
||||
@@ -121,7 +121,7 @@ class Gemma4ImagePixelInputs(TensorSchema):
|
||||
- np: Number of patches (max_patches = max_soft_tokens * pooling_kernel_size²)
|
||||
- pp: Patch pixels (patch_size² * 3)
|
||||
|
||||
The HF Gemma4ImageProcessor outputs pixel_values as
|
||||
The Gemma4 image processor outputs pixel_values as
|
||||
(batch, max_patches, patch_pixels) — already patchified with
|
||||
zero-padding for patches beyond the real image content.
|
||||
pixel_position_ids provides (x, y) coordinates per patch,
|
||||
@@ -341,6 +341,29 @@ class Gemma4ProcessingInfo(BaseProcessingInfo):
|
||||
)
|
||||
return PromptUpdateDetails.select_token_id(token_ids, processor.image_token_id)
|
||||
|
||||
@staticmethod
|
||||
def _compute_audio_num_tokens(
|
||||
num_samples: int, sampling_rate: int, audio_seq_length: int
|
||||
) -> int:
|
||||
"""Replicate the audio encoder's sequence-length arithmetic.
|
||||
|
||||
Mirrors: mel framing (_unfold in Gemma4AudioFeatureExtractor)
|
||||
followed by two Conv2d subsampling layers (kernel=3, stride=2,
|
||||
semicausal padding top=1, bottom=1), capped at audio_seq_length.
|
||||
"""
|
||||
frame_length = int(round(sampling_rate * 20.0 / 1000.0))
|
||||
hop_length = int(round(sampling_rate * 10.0 / 1000.0))
|
||||
frame_size_for_unfold = frame_length + 1
|
||||
pad_left = frame_length // 2
|
||||
padded_samples = num_samples + pad_left
|
||||
num_mel_frames = (padded_samples - frame_size_for_unfold) // hop_length + 1
|
||||
if num_mel_frames <= 0:
|
||||
return 0
|
||||
t = num_mel_frames
|
||||
for _ in range(2):
|
||||
t = (t + 2 - 3) // 2 + 1
|
||||
return min(t, audio_seq_length)
|
||||
|
||||
def get_audio_repl(
|
||||
self,
|
||||
*,
|
||||
@@ -350,20 +373,21 @@ class Gemma4ProcessingInfo(BaseProcessingInfo):
|
||||
"""Return the dynamic audio token sequence for this audio.
|
||||
|
||||
Computes the number of soft tokens from the audio waveform
|
||||
length using ``ceil(duration_ms / audio_ms_per_token)``.
|
||||
length by replicating the audio encoder's sequence-length
|
||||
arithmetic (mel framing + two Conv2d subsampling layers).
|
||||
"""
|
||||
if processor is None:
|
||||
processor = self.get_hf_processor()
|
||||
|
||||
sampling_rate = processor.feature_extractor.sampling_rate
|
||||
num_tokens = processor._compute_audio_num_tokens(
|
||||
torch.zeros(audio_len), sampling_rate
|
||||
num_tokens = self._compute_audio_num_tokens(
|
||||
audio_len, sampling_rate, processor.audio_seq_length
|
||||
)
|
||||
config = self.get_hf_config()
|
||||
token_ids = (
|
||||
[config.boa_token_id]
|
||||
+ [processor.audio_token_id] * num_tokens
|
||||
+ [config.eoa_token_id]
|
||||
+ [getattr(config, "eoa_token_id", config.eoa_token_index)]
|
||||
)
|
||||
return PromptUpdateDetails.select_token_id(token_ids, processor.audio_token_id)
|
||||
|
||||
@@ -988,18 +1012,35 @@ class Gemma4ForConditionalGeneration(
|
||||
self.quant_config = quant_config
|
||||
self.multimodal_config = multimodal_config
|
||||
|
||||
# Only quantize towers when the quant method supports their
|
||||
# dimensions. BNB/torchao handle arbitrary sizes; other methods
|
||||
# (Marlin, FP8, …) require dimensions divisible by 64, which
|
||||
# the vision tower (intermediate_size=4304) does not satisfy.
|
||||
if quant_config and quant_config.get_name() in [
|
||||
"bitsandbytes",
|
||||
"torchao",
|
||||
]:
|
||||
tower_quant = quant_config
|
||||
else:
|
||||
vision_cfg = config.vision_config
|
||||
quantizable = (
|
||||
vision_cfg.hidden_size % 64 == 0
|
||||
and vision_cfg.intermediate_size % 64 == 0
|
||||
)
|
||||
tower_quant = quant_config if quantizable else None
|
||||
|
||||
# ---- Vision tower (shared by image and video) ----
|
||||
with self._mark_tower_model(vllm_config, {"image", "video"}):
|
||||
self.vision_tower = AutoModel.from_config(config=config.vision_config)
|
||||
self.embed_vision = Gemma4MultimodalEmbedder(
|
||||
config.vision_config,
|
||||
config.text_config,
|
||||
quant_config=quant_config,
|
||||
quant_config=tower_quant,
|
||||
prefix=maybe_prefix(prefix, "embed_vision"),
|
||||
)
|
||||
recursive_replace_linear(
|
||||
self.vision_tower,
|
||||
quant_config,
|
||||
tower_quant,
|
||||
prefix=maybe_prefix(prefix, "vision_tower"),
|
||||
)
|
||||
|
||||
@@ -1015,12 +1056,12 @@ class Gemma4ForConditionalGeneration(
|
||||
self.embed_audio = Gemma4MultimodalEmbedder(
|
||||
config.audio_config,
|
||||
config.text_config,
|
||||
quant_config=quant_config,
|
||||
quant_config=tower_quant,
|
||||
prefix=maybe_prefix(prefix, "embed_audio"),
|
||||
)
|
||||
recursive_replace_linear(
|
||||
self.audio_tower,
|
||||
quant_config,
|
||||
tower_quant,
|
||||
prefix=maybe_prefix(prefix, "audio_tower"),
|
||||
)
|
||||
else:
|
||||
@@ -1039,13 +1080,13 @@ class Gemma4ForConditionalGeneration(
|
||||
# Pre-allocate PLE buffer for CUDA graph compatibility.
|
||||
# Some variants have hidden_size_per_layer_input=None (no PLE).
|
||||
ple_dim = config.text_config.hidden_size_per_layer_input
|
||||
if ple_dim is not None:
|
||||
if ple_dim is not None and ple_dim > 0:
|
||||
self.per_layer_embeddings = torch.zeros(
|
||||
vllm_config.scheduler_config.max_num_batched_tokens,
|
||||
config.text_config.num_hidden_layers,
|
||||
ple_dim,
|
||||
device=(self.language_model.model.embed_tokens.weight.device),
|
||||
dtype=(self.language_model.model.embed_tokens.weight.dtype),
|
||||
device=self.language_model.model.embed_tokens.weight.device,
|
||||
dtype=self.language_model.model.embed_tokens.weight.dtype,
|
||||
)
|
||||
else:
|
||||
self.per_layer_embeddings = None
|
||||
@@ -1076,6 +1117,9 @@ class Gemma4ForConditionalGeneration(
|
||||
self.num_shared_experts = self.language_model.num_shared_experts
|
||||
self.num_redundant_experts = self.language_model.num_redundant_experts
|
||||
|
||||
gen_cfg = vllm_config.model_config.try_get_generation_config()
|
||||
self._suppress_token_ids = gen_cfg.get("suppress_tokens") if gen_cfg else None
|
||||
|
||||
# ------------------------------------------------------------------ #
|
||||
# Input parsing
|
||||
# ------------------------------------------------------------------ #
|
||||
@@ -1424,8 +1468,7 @@ class Gemma4ForConditionalGeneration(
|
||||
input_features = audio_input["input_features_padded"].squeeze(1)
|
||||
input_features_mask = audio_input["input_features_mask"].squeeze(1)
|
||||
|
||||
# Run audio tower — mask uses standard HF convention
|
||||
# (True=valid, False=padding).
|
||||
# Run audio tower — mask convention: True=valid, False=padding.
|
||||
audio_outputs = self.audio_tower(input_features, input_features_mask)
|
||||
if isinstance(audio_outputs, tuple):
|
||||
audio_encodings, audio_mask = audio_outputs
|
||||
@@ -1436,8 +1479,8 @@ class Gemma4ForConditionalGeneration(
|
||||
# Project into LM embedding space.
|
||||
audio_features = self.embed_audio(inputs_embeds=audio_encodings)
|
||||
|
||||
# Strip padding per-batch element: only keep real (non-padding)
|
||||
# tokens. audio_mask is True for valid positions (HF convention).
|
||||
# Strip padding per-batch element: only keep valid (non-padding)
|
||||
# tokens.
|
||||
per_audio = []
|
||||
for enc, mask in zip(audio_features, audio_mask, strict=True):
|
||||
per_audio.append(enc[mask]) # [num_real, hidden_size]
|
||||
@@ -1559,7 +1602,10 @@ class Gemma4ForConditionalGeneration(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
) -> torch.Tensor | None:
|
||||
return self.language_model.compute_logits(hidden_states)
|
||||
logits = self.language_model.compute_logits(hidden_states)
|
||||
if logits is not None and self._suppress_token_ids:
|
||||
logits[:, self._suppress_token_ids] = -float("inf")
|
||||
return logits
|
||||
|
||||
# ------------------------------------------------------------------ #
|
||||
# Bidirectional attention helpers
|
||||
@@ -1617,8 +1663,7 @@ class Gemma4ForConditionalGeneration(
|
||||
"embed_vision.embedding.",
|
||||
"embed_audio.embedding.",
|
||||
]
|
||||
# Models without audio tower should skip
|
||||
# audio weights entirely.
|
||||
# Models without audio tower should skip audio weights entirely.
|
||||
if self.audio_tower is None:
|
||||
ignore_prefixes.extend(
|
||||
[
|
||||
|
||||
@@ -279,11 +279,19 @@ class Gemma4MTPDecoderLayer(nn.Module):
|
||||
else config.head_dim
|
||||
)
|
||||
|
||||
use_k_eq_v = is_full_attention and getattr(config, "attention_k_eq_v", False)
|
||||
if use_k_eq_v:
|
||||
num_kv_heads = getattr(
|
||||
config, "num_global_key_value_heads", config.num_key_value_heads
|
||||
)
|
||||
else:
|
||||
num_kv_heads = config.num_key_value_heads
|
||||
|
||||
self.self_attn = Gemma4MTPAttention(
|
||||
config=config,
|
||||
hidden_size=self.hidden_size,
|
||||
num_heads=config.num_attention_heads,
|
||||
num_kv_heads=config.num_key_value_heads,
|
||||
num_kv_heads=num_kv_heads,
|
||||
head_dim=head_dim,
|
||||
max_position_embeddings=config.max_position_embeddings,
|
||||
cache_config=cache_config,
|
||||
@@ -545,6 +553,10 @@ class Gemma4MTP(nn.Module):
|
||||
else:
|
||||
self.masked_embedding = None
|
||||
|
||||
draft_cfg = vllm_config.speculative_config.draft_model_config
|
||||
gen_cfg = draft_cfg.try_get_generation_config()
|
||||
self._suppress_token_ids = gen_cfg.get("suppress_tokens") if gen_cfg else None
|
||||
|
||||
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
return self.model.embed_input_ids(input_ids)
|
||||
|
||||
@@ -589,11 +601,15 @@ class Gemma4MTP(nn.Module):
|
||||
spec_step_idx: int = 0,
|
||||
) -> torch.Tensor | None:
|
||||
if self.masked_embedding is not None:
|
||||
return self.masked_embedding(
|
||||
logits = self.masked_embedding(
|
||||
hidden_states,
|
||||
self._get_full_lm_head_weight(),
|
||||
)
|
||||
return self.logits_processor(self.lm_head, hidden_states)
|
||||
else:
|
||||
logits = self.logits_processor(self.lm_head, hidden_states)
|
||||
if logits is not None and self._suppress_token_ids:
|
||||
logits[:, self._suppress_token_ids] = -float("inf")
|
||||
return logits
|
||||
|
||||
def get_top_tokens(
|
||||
self,
|
||||
|
||||
@@ -0,0 +1,466 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""Gemma 4 Unified multimodal model (encoder-free image + audio + video).
|
||||
|
||||
The Unified Gemma4 variant has no SigLIP vision tower and no audio tower.
|
||||
Raw pixel patches are projected directly to LM space via a Dense+LayerNorm
|
||||
pipeline with factorized 2D positional embeddings (Gemma4UnifiedVisionEmbedder),
|
||||
then routed through the same Gemma4MultimodalEmbedder used by the tower-based
|
||||
variant. Audio inputs are raw waveform frames projected directly through the
|
||||
multimodal embedder.
|
||||
|
||||
This module subclasses Gemma4ForConditionalGeneration from gemma4_mm rather
|
||||
than reimplementing it from scratch. Only the multimodal pipeline differs;
|
||||
the language model, MTP integration, bidirectional attention helpers,
|
||||
embedding/forward path, and LoRA support are all inherited unchanged.
|
||||
"""
|
||||
|
||||
import math
|
||||
from collections.abc import Iterable, Mapping
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
from transformers.models.gemma4_unified.configuration_gemma4_unified import (
|
||||
Gemma4UnifiedConfig,
|
||||
)
|
||||
from transformers.models.gemma4_unified.processing_gemma4_unified import (
|
||||
Gemma4UnifiedProcessor,
|
||||
)
|
||||
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.config.multimodal import VideoDummyOptions
|
||||
from vllm.model_executor.layers.linear import ColumnParallelLinear
|
||||
from vllm.model_executor.models.gemma4 import Gemma4ForCausalLM
|
||||
from vllm.model_executor.models.gemma4_mm import (
|
||||
_SUPPORTED_SOFT_TOKENS,
|
||||
_VIDEO_MAX_FRAMES,
|
||||
_VIDEO_MAX_SOFT_TOKENS,
|
||||
Gemma4AudioInputs,
|
||||
Gemma4DummyInputsBuilder,
|
||||
Gemma4ForConditionalGeneration,
|
||||
Gemma4ImageInputs,
|
||||
Gemma4ImagePixelInputs,
|
||||
Gemma4MultimodalEmbedder,
|
||||
Gemma4MultiModalProcessor,
|
||||
Gemma4ProcessingInfo,
|
||||
_get_max_soft_tokens,
|
||||
)
|
||||
from vllm.model_executor.models.module_mapping import MultiModelKeys
|
||||
from vllm.multimodal import MULTIMODAL_REGISTRY
|
||||
|
||||
from .utils import (
|
||||
AutoWeightsLoader,
|
||||
WeightsMapper,
|
||||
init_vllm_registered_model,
|
||||
maybe_prefix,
|
||||
)
|
||||
|
||||
# Re-export so tests/code targeting the unified variant can import from here
|
||||
# rather than reaching into gemma4_mm.
|
||||
__all__ = [
|
||||
"Gemma4ImagePixelInputs",
|
||||
"Gemma4UnifiedVisionEmbedder",
|
||||
"Gemma4UnifiedProcessingInfo",
|
||||
"Gemma4UnifiedForConditionalGeneration",
|
||||
]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Encoder-free vision embedder
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class Gemma4UnifiedVisionEmbedder(nn.Module):
|
||||
"""Encoder-free vision embedder for Gemma4 Unified variants.
|
||||
|
||||
Projects raw pixel patches to LM space via dense projection and
|
||||
factorized 2D positional embeddings. Replaces the SigLIP vision
|
||||
tower used by the tower-based Gemma4 variant.
|
||||
|
||||
Pipeline: raw patches → LN₁ → Dense → LN₂ → +factorized_posemb → LN₃.
|
||||
"""
|
||||
|
||||
def __init__(self, config, quant_config=None):
|
||||
super().__init__()
|
||||
patch_dim = config.model_patch_size**2 * 3
|
||||
mm_embed_dim = config.mm_embed_dim
|
||||
|
||||
self.patch_ln1 = nn.LayerNorm(patch_dim)
|
||||
self.patch_dense = ColumnParallelLinear(
|
||||
patch_dim,
|
||||
mm_embed_dim,
|
||||
bias=True,
|
||||
quant_config=quant_config,
|
||||
gather_output=True,
|
||||
)
|
||||
self.patch_ln2 = nn.LayerNorm(mm_embed_dim)
|
||||
|
||||
self.pos_embedding = nn.Parameter(
|
||||
torch.zeros(config.mm_posemb_size, 2, mm_embed_dim)
|
||||
)
|
||||
self.pos_norm = nn.LayerNorm(mm_embed_dim)
|
||||
|
||||
def _factorized_posemb(self, positions_xy: torch.Tensor) -> torch.Tensor:
|
||||
clamped_pos = positions_xy.clamp(min=0).long()
|
||||
valid_mask = positions_xy != -1
|
||||
|
||||
pos_embs = torch.zeros(
|
||||
*positions_xy.shape[:-1],
|
||||
self.pos_embedding.shape[-1],
|
||||
device=positions_xy.device,
|
||||
dtype=self.pos_embedding.dtype,
|
||||
)
|
||||
for i in range(2):
|
||||
axis_pe = self.pos_embedding[:, i, :][clamped_pos[..., i]]
|
||||
mask = valid_mask[..., i].unsqueeze(-1).to(axis_pe.dtype)
|
||||
pos_embs = pos_embs + (axis_pe * mask)
|
||||
return pos_embs
|
||||
|
||||
def forward(
|
||||
self,
|
||||
pixel_values: torch.Tensor,
|
||||
pixel_position_ids: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
hidden_states = self.patch_ln1(pixel_values.to(self.pos_embedding.dtype))
|
||||
hidden_states, _ = self.patch_dense(hidden_states)
|
||||
hidden_states = self.patch_ln2(hidden_states)
|
||||
|
||||
pos_embs = self._factorized_posemb(pixel_position_ids)
|
||||
hidden_states = hidden_states + pos_embs
|
||||
hidden_states = self.pos_norm(hidden_states)
|
||||
return hidden_states
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Processing info
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class Gemma4UnifiedProcessingInfo(Gemma4ProcessingInfo):
|
||||
"""ProcessingInfo for the Gemma4 Unified variant.
|
||||
|
||||
Two field-name differences from the tower-based parent:
|
||||
* config → ``Gemma4UnifiedConfig`` (not ``Gemma4Config``)
|
||||
* vision_config.``num_soft_tokens`` (not ``default_output_length``)
|
||||
|
||||
Everything else (token sequencing, audio limits, video frame budget,
|
||||
parser construction) is inherited unchanged.
|
||||
"""
|
||||
|
||||
def get_hf_config(self):
|
||||
return self.ctx.get_hf_config(Gemma4UnifiedConfig)
|
||||
|
||||
def get_hf_processor(self, **kwargs: object) -> Gemma4UnifiedProcessor:
|
||||
return self.ctx.get_hf_processor(
|
||||
Gemma4UnifiedProcessor,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
def get_mm_max_tokens_per_item(
|
||||
self, seq_len: int, mm_counts: Mapping[str, int]
|
||||
) -> Mapping[str, int] | None:
|
||||
config = self.get_hf_config()
|
||||
# Unified field is `num_soft_tokens`. Tower-based parent uses
|
||||
# `default_output_length`, hence the override.
|
||||
tokens_per_image = config.vision_config.num_soft_tokens
|
||||
merged_kwargs = self.ctx.get_merged_mm_kwargs({})
|
||||
val, _ = _get_max_soft_tokens(merged_kwargs)
|
||||
if isinstance(val, int) and val in _SUPPORTED_SOFT_TOKENS:
|
||||
tokens_per_image = val
|
||||
tokens: dict[str, int] = {"image": tokens_per_image}
|
||||
if config.audio_config is not None:
|
||||
processor = self.get_hf_processor()
|
||||
tokens["audio"] = processor.audio_seq_length
|
||||
num_frames = _VIDEO_MAX_FRAMES
|
||||
mm_config = self.ctx.model_config.get_multimodal_config()
|
||||
video_opts = mm_config.limit_per_prompt.get("video")
|
||||
if (
|
||||
isinstance(video_opts, VideoDummyOptions)
|
||||
and video_opts.num_frames is not None
|
||||
):
|
||||
num_frames = min(num_frames, video_opts.num_frames)
|
||||
tokens["video"] = num_frames * (_VIDEO_MAX_SOFT_TOKENS + 2 + 6)
|
||||
return tokens
|
||||
|
||||
def _compute_num_soft_tokens(
|
||||
self,
|
||||
image_width: int,
|
||||
image_height: int,
|
||||
max_soft_tokens: int | None = None,
|
||||
) -> int:
|
||||
vision_cfg = self.get_hf_config().vision_config
|
||||
patch_size = vision_cfg.patch_size
|
||||
pooling_kernel_size = vision_cfg.pooling_kernel_size
|
||||
|
||||
if max_soft_tokens is None:
|
||||
max_soft_tokens = vision_cfg.num_soft_tokens
|
||||
|
||||
unit = patch_size * pooling_kernel_size
|
||||
max_patches = max_soft_tokens * pooling_kernel_size**2
|
||||
num_patches_orig = (image_height / patch_size) * (image_width / patch_size)
|
||||
scale = math.sqrt(max_patches / num_patches_orig)
|
||||
target_h = max(unit, int(math.floor(image_height * scale / unit)) * unit)
|
||||
target_w = max(unit, int(math.floor(image_width * scale / unit)) * unit)
|
||||
num_patches = (target_h // patch_size) * (target_w // patch_size)
|
||||
num_soft_tokens = num_patches // (pooling_kernel_size**2)
|
||||
return min(num_soft_tokens, max_soft_tokens)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Main model
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@MULTIMODAL_REGISTRY.register_processor(
|
||||
Gemma4MultiModalProcessor,
|
||||
info=Gemma4UnifiedProcessingInfo,
|
||||
dummy_inputs=Gemma4DummyInputsBuilder,
|
||||
)
|
||||
class Gemma4UnifiedForConditionalGeneration(Gemma4ForConditionalGeneration):
|
||||
"""Encoder-free Gemma4 (Unified) for conditional generation.
|
||||
|
||||
Inherits multimodal embedding routing, PLE handling, bidirectional
|
||||
attention helpers, language-model forward, LoRA, and pipeline-parallel
|
||||
support from :class:`Gemma4ForConditionalGeneration`. Overrides only:
|
||||
|
||||
* ``__init__`` — builds the encoder-free vision embedder instead of
|
||||
SigLIP/audio towers (LightOnOCR-style: ``nn.Module.__init__`` +
|
||||
full rebuild, no ``super().__init__()``).
|
||||
* ``hf_to_vllm_mapper`` — adds the ``model.vision_embedder.`` prefix.
|
||||
* ``_process_image_input`` / ``_process_video_input`` /
|
||||
``_process_audio_input`` — encoder-free projection paths.
|
||||
* ``load_weights`` — ignore-prefix list excludes the absent towers.
|
||||
* ``get_mm_mapping`` — no tower entries.
|
||||
"""
|
||||
|
||||
hf_to_vllm_mapper = WeightsMapper(
|
||||
orig_to_new_prefix={
|
||||
"model.embed_audio.": "embed_audio.",
|
||||
"model.embed_vision.": "embed_vision.",
|
||||
"model.language_model.": "language_model.model.",
|
||||
"model.vision_embedder.": "vision_embedder.",
|
||||
"lm_head.": "language_model.lm_head.",
|
||||
"model": "language_model.model",
|
||||
}
|
||||
)
|
||||
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||
# LightOnOCR-style rebuild: do NOT call super().__init__ — that
|
||||
# would build a SigLIP vision tower and an audio tower we don't
|
||||
# need. Initialize nn.Module directly and assemble the
|
||||
# encoder-free pipeline below.
|
||||
nn.Module.__init__(self)
|
||||
config = vllm_config.model_config.hf_config
|
||||
quant_config = vllm_config.quant_config
|
||||
multimodal_config = vllm_config.model_config.multimodal_config
|
||||
self.config = config
|
||||
self.quant_config = quant_config
|
||||
self.multimodal_config = multimodal_config
|
||||
|
||||
# No towers — set to None so inherited load_weights / get_mm_mapping
|
||||
# and any tower-aware logic short-circuits.
|
||||
self.vision_tower = None
|
||||
self.audio_tower = None
|
||||
|
||||
# ---- Encoder-free vision embedder ----
|
||||
self.vision_embedder = (
|
||||
Gemma4UnifiedVisionEmbedder(
|
||||
config.vision_config,
|
||||
quant_config=quant_config,
|
||||
)
|
||||
if config.vision_config is not None
|
||||
else None
|
||||
)
|
||||
self.embed_vision = (
|
||||
Gemma4MultimodalEmbedder(
|
||||
config.vision_config,
|
||||
config.text_config,
|
||||
)
|
||||
if config.vision_config is not None
|
||||
else None
|
||||
)
|
||||
|
||||
# ---- Encoder-free audio embedder ----
|
||||
self.embed_audio = (
|
||||
Gemma4MultimodalEmbedder(
|
||||
config.audio_config,
|
||||
config.text_config,
|
||||
)
|
||||
if config.audio_config is not None
|
||||
else None
|
||||
)
|
||||
|
||||
# ---- Language model (vLLM optimised) ----
|
||||
with self._mark_language_model(vllm_config):
|
||||
self.language_model: Gemma4ForCausalLM = init_vllm_registered_model(
|
||||
vllm_config=vllm_config,
|
||||
hf_config=config.text_config,
|
||||
prefix=maybe_prefix(prefix, "language_model"),
|
||||
architectures=["Gemma4ForCausalLM"],
|
||||
)
|
||||
|
||||
# PLE is disabled for the unified variant (text config defaults
|
||||
# hidden_size_per_layer_input to 0). Skip the buffer.
|
||||
ple_dim = getattr(
|
||||
config.text_config,
|
||||
"hidden_size_per_layer_input",
|
||||
None,
|
||||
)
|
||||
if ple_dim is not None and ple_dim > 0:
|
||||
self.per_layer_embeddings = torch.zeros(
|
||||
vllm_config.scheduler_config.max_num_batched_tokens,
|
||||
config.text_config.num_hidden_layers,
|
||||
ple_dim,
|
||||
device=self.language_model.model.embed_tokens.weight.device,
|
||||
dtype=self.language_model.model.embed_tokens.weight.dtype,
|
||||
)
|
||||
else:
|
||||
self.per_layer_embeddings = None
|
||||
|
||||
self.make_empty_intermediate_tensors = (
|
||||
self.language_model.make_empty_intermediate_tensors
|
||||
)
|
||||
|
||||
# --- Precompute full-attention layer indices for bidi clearing ---
|
||||
self._full_attn_layer_idxs: frozenset[int] = frozenset()
|
||||
text_config = config.text_config
|
||||
if getattr(text_config, "use_bidirectional_attention", None) == "vision":
|
||||
layer_types = getattr(text_config, "layer_types", None)
|
||||
if layer_types:
|
||||
self._full_attn_layer_idxs = frozenset(
|
||||
i for i, lt in enumerate(layer_types) if lt != "sliding_attention"
|
||||
)
|
||||
|
||||
# --- MixtureOfExperts delegation to language_model ---
|
||||
self.expert_weights = self.language_model.expert_weights
|
||||
self.moe_layers = self.language_model.moe_layers
|
||||
self.num_moe_layers = self.language_model.num_moe_layers
|
||||
self.num_logical_experts = self.language_model.num_logical_experts
|
||||
self.num_physical_experts = self.language_model.num_physical_experts
|
||||
self.num_local_physical_experts = self.language_model.num_local_physical_experts
|
||||
self.num_routed_experts = self.language_model.num_routed_experts
|
||||
self.num_expert_groups = self.language_model.num_expert_groups
|
||||
self.num_shared_experts = self.language_model.num_shared_experts
|
||||
self.num_redundant_experts = self.language_model.num_redundant_experts
|
||||
|
||||
gen_cfg = vllm_config.model_config.try_get_generation_config()
|
||||
self._suppress_token_ids = gen_cfg.get("suppress_tokens") if gen_cfg else None
|
||||
|
||||
# ------------------------------------------------------------------ #
|
||||
# Multimodal processing (encoder-free overrides)
|
||||
# ------------------------------------------------------------------ #
|
||||
|
||||
def _process_image_input(
|
||||
self,
|
||||
image_input: Gemma4ImageInputs,
|
||||
) -> list[torch.Tensor]:
|
||||
"""Project raw image patches directly to LM space.
|
||||
|
||||
No vision tower: each image's pre-patchified pixel values are
|
||||
embedded via Gemma4UnifiedVisionEmbedder, projected through
|
||||
Gemma4MultimodalEmbedder, and padding patches (pp == -1) are
|
||||
stripped per image.
|
||||
"""
|
||||
pixel_values = image_input["pixel_values"]
|
||||
pixel_position_ids = image_input["pixel_position_ids"]
|
||||
target_dtype = self.embed_vision.embedding_projection.weight.dtype
|
||||
|
||||
per_image_features: list[torch.Tensor] = []
|
||||
for pv, pp in zip(pixel_values, pixel_position_ids, strict=True):
|
||||
pv = pv.unsqueeze(0)
|
||||
pp = pp.unsqueeze(0)
|
||||
embedded = self.vision_embedder(pv, pp)
|
||||
projected = self.embed_vision(embedded.to(target_dtype))
|
||||
padding_mask = (pp.squeeze(0) == -1).all(dim=-1)
|
||||
valid_features = projected.squeeze(0)[~padding_mask]
|
||||
per_image_features.append(valid_features)
|
||||
return per_image_features
|
||||
|
||||
def _process_video_input(
|
||||
self,
|
||||
video_input: dict[str, torch.Tensor],
|
||||
) -> list[torch.Tensor]:
|
||||
"""Project video frames to LM space, one frame at a time.
|
||||
|
||||
Frames are split per video, each frame is embedded + projected,
|
||||
and per-frame valid embeddings are concatenated per video.
|
||||
"""
|
||||
pixel_values = video_input["pixel_values_videos"]
|
||||
pixel_position_ids = video_input["pixel_position_ids_videos"]
|
||||
frame_counts = video_input["video_frame_counts"]
|
||||
target_dtype = self.embed_vision.embedding_projection.weight.dtype
|
||||
|
||||
if isinstance(frame_counts, torch.Tensor):
|
||||
fc_list = frame_counts.tolist()
|
||||
else:
|
||||
fc_list = list(frame_counts)
|
||||
|
||||
pv_per_video = torch.split(pixel_values, fc_list, dim=0)
|
||||
pp_per_video = torch.split(pixel_position_ids, fc_list, dim=0)
|
||||
|
||||
per_video_embeddings: list[torch.Tensor] = []
|
||||
for pv_chunk, pp_chunk in zip(pv_per_video, pp_per_video):
|
||||
frame_embs: list[torch.Tensor] = []
|
||||
for i in range(pv_chunk.shape[0]):
|
||||
pv = pv_chunk[i].unsqueeze(0)
|
||||
pp = pp_chunk[i].unsqueeze(0)
|
||||
embedded = self.vision_embedder(pv, pp)
|
||||
projected = self.embed_vision(embedded.to(target_dtype))
|
||||
padding_mask = (pp.squeeze(0) == -1).all(dim=-1)
|
||||
frame_embs.append(projected.squeeze(0)[~padding_mask])
|
||||
per_video_embeddings.append(torch.cat(frame_embs, dim=0))
|
||||
return per_video_embeddings
|
||||
|
||||
def _process_audio_input(
|
||||
self,
|
||||
audio_input: Gemma4AudioInputs,
|
||||
) -> list[torch.Tensor]:
|
||||
"""Project raw waveform-frame features directly to LM space.
|
||||
|
||||
No audio tower: the per-frame raw features are passed straight
|
||||
through the multimodal embedder, then padding is stripped.
|
||||
"""
|
||||
input_features = audio_input["input_features_padded"].squeeze(1)
|
||||
input_features_mask = audio_input["input_features_mask"].squeeze(1)
|
||||
|
||||
target_dtype = self.embed_audio.embedding_projection.weight.dtype
|
||||
audio_features = self.embed_audio(input_features.to(target_dtype))
|
||||
per_audio: list[torch.Tensor] = []
|
||||
for enc, mask in zip(audio_features, input_features_mask, strict=True):
|
||||
per_audio.append(enc[mask])
|
||||
return per_audio
|
||||
|
||||
# ------------------------------------------------------------------ #
|
||||
# Weight loading
|
||||
# ------------------------------------------------------------------ #
|
||||
|
||||
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
|
||||
ignore_prefixes = [
|
||||
# Vestigial Gemma3n-style embedding tables not used by
|
||||
# Gemma4MultimodalEmbedder (which has only projection + norm).
|
||||
"embed_vision.embedding.",
|
||||
"embed_audio.embedding.",
|
||||
]
|
||||
if self.embed_audio is None:
|
||||
ignore_prefixes.append("embed_audio.")
|
||||
|
||||
loader = AutoWeightsLoader(
|
||||
self,
|
||||
ignore_unexpected_prefixes=ignore_prefixes,
|
||||
)
|
||||
return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)
|
||||
|
||||
# ------------------------------------------------------------------ #
|
||||
# LoRA / multimodal mapping
|
||||
# ------------------------------------------------------------------ #
|
||||
|
||||
def get_mm_mapping(self) -> MultiModelKeys:
|
||||
"""Module prefix mapping for the encoder-free model (no towers)."""
|
||||
connectors = ["embed_vision"]
|
||||
if self.embed_audio is not None:
|
||||
connectors.append("embed_audio")
|
||||
return MultiModelKeys.from_string_field(
|
||||
language_model="language_model",
|
||||
connector=connectors,
|
||||
tower_model=[],
|
||||
)
|
||||
@@ -406,6 +406,10 @@ _MULTIMODAL_MODELS = {
|
||||
"Gemma3nForConditionalGeneration",
|
||||
),
|
||||
"Gemma4ForConditionalGeneration": ("gemma4_mm", "Gemma4ForConditionalGeneration"),
|
||||
"Gemma4UnifiedForConditionalGeneration": (
|
||||
"gemma4_unified",
|
||||
"Gemma4UnifiedForConditionalGeneration",
|
||||
),
|
||||
"GlmAsrForConditionalGeneration": ("glmasr", "GlmAsrForConditionalGeneration"),
|
||||
"GLM4VForCausalLM": ("glm4v", "GLM4VForCausalLM"),
|
||||
"Glm4vForConditionalGeneration": ("glm4_1v", "Glm4vForConditionalGeneration"),
|
||||
|
||||
@@ -556,6 +556,8 @@ MODEL_ARCH_CONFIG_CONVERTORS = {
|
||||
"gemma4": Gemma4ModelArchConfigConvertor,
|
||||
"gemma4_mtp": Gemma4MTPModelArchConfigConvertor,
|
||||
"gemma4_text": Gemma4ModelArchConfigConvertor,
|
||||
"gemma4_unified": Gemma4ModelArchConfigConvertor,
|
||||
"gemma4_unified_text": Gemma4ModelArchConfigConvertor,
|
||||
"glm4_moe_mtp": GLM4MoeMTPModelArchConfigConvertor,
|
||||
"glm_ocr_mtp": GLM4MoeMTPModelArchConfigConvertor,
|
||||
"longcat_flash_mtp": LongCatFlashMTPModelArchConfigConvertor,
|
||||
|
||||
@@ -1232,6 +1232,7 @@ class SpecDecodeBaseProposer:
|
||||
"Qwen3VLForConditionalGeneration",
|
||||
"Qwen3VLMoeForConditionalGeneration",
|
||||
"Gemma4ForConditionalGeneration",
|
||||
"Gemma4UnifiedForConditionalGeneration",
|
||||
"Step3p7ForConditionalGeneration",
|
||||
]:
|
||||
self.model.config.image_token_index = target_model.config.image_token_id
|
||||
|
||||
@@ -2467,6 +2467,8 @@ class GPUModelRunner(
|
||||
image_doc_ranges = []
|
||||
req_state = self.requests[req_id]
|
||||
for mm_feature in req_state.mm_features:
|
||||
if mm_feature.modality == "audio":
|
||||
continue
|
||||
pos_info = mm_feature.mm_position
|
||||
img_doc_range = pos_info.extract_embeds_range()
|
||||
for r in img_doc_range:
|
||||
|
||||
Reference in New Issue
Block a user