[Bugfix] Convert Gemma4-MM ViT linear layers to vllm native impl (#43798)

Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn>
Co-authored-by: ZiTian Zhao <zitian.zhao@tencentmusic.com>
Co-authored-by: B-201 <Joy25810@foxmail.com>
This commit is contained in:
Isotr0py
2026-06-02 12:41:16 +08:00
committed by GitHub
parent a045c7425f
commit f91fb2fcf3
4 changed files with 112 additions and 12 deletions
@@ -1,6 +1,7 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from functools import cached_property
from typing import Any, Union
import torch
@@ -168,6 +169,12 @@ class BitsAndBytesConfig(QuantizationConfig):
return None
class BitsAndBytesWeightParameter(torch.nn.Parameter):
    """Parameter subclass whose ``dtype`` reports torch's default dtype.

    Elsewhere in this file the BitsAndBytes quantized weight is allocated
    as a ``torch.uint8`` tensor wrapped in this class. Overriding ``dtype``
    means callers reading ``param.dtype`` see the default dtype instead of
    the storage dtype.
    NOTE(review): presumably so code that derives a compute dtype from
    ``weight.dtype`` does not pick up ``uint8`` -- confirm against callers.
    """

    # cached_property: evaluated once on first access and then stored on the
    # instance, so later changes to torch's default dtype are not reflected.
    @cached_property
    def dtype(self) -> torch.dtype:
        return torch.get_default_dtype()
def is_layer_skipped_bnb(prefix: str, llm_int8_skip_modules: list[str]):
# Split the prefix into its dot-separated components
components = prefix.split(".")
@@ -246,7 +253,7 @@ class BitsAndBytesLinearMethod(LinearMethodBase):
"The input size is not aligned with the quantized weight shape."
)
qweight = torch.nn.Parameter(
qweight = BitsAndBytesWeightParameter(
torch.empty(total_size // quant_ratio, 1, dtype=torch.uint8),
requires_grad=False,
)
@@ -744,6 +744,29 @@ class BitsAndBytesModelLoader(BaseModelLoader):
stacked_quant_state_dict[quant_param_name][shard_index] = quant_state_dict[
non_stacked_param_name
]
# For k_eq_v models (e.g. Gemma4), reuse the k_proj shard as the v_proj shard.
config = getattr(model, "config", None)
if config is not None:
text_config = config.get_text_config()
if getattr(text_config, "attention_k_eq_v", False):
shard_packed = {
name
for name, subs in self.modules_mapping.packed_mapping.items()
if len(subs) == 3
}
for param_name, shards in stacked_quant_state_dict.items():
is_target = (
isinstance(shards, dict)
and len(shards) == 2
and any(
param_name.endswith(f"{p}.weight") for p in shard_packed
)
)
if is_target:
assert 1 in shards and 2 not in shards
shards[2] = shards[1]
return stacked_quant_state_dict
def _bind_quant_states_to_params(
+52 -11
View File
@@ -16,7 +16,7 @@ reason about temporal order.
import math
from collections.abc import Iterable, Mapping, Sequence
from typing import Annotated, Any, Literal
from typing import TYPE_CHECKING, Annotated, Any, Literal
import numpy as np
import torch
@@ -41,6 +41,7 @@ from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import ReplicatedLinear
from vllm.model_executor.models.gemma4 import Gemma4ForCausalLM
from vllm.model_executor.models.module_mapping import MultiModelKeys
from vllm.model_executor.models.transformers.utils import recursive_replace_linear
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import (
MultiModalFieldConfig,
@@ -71,6 +72,7 @@ from .interfaces import (
SupportsLoRA,
SupportsMultiModal,
SupportsPP,
SupportsQuant,
)
from .utils import (
AutoWeightsLoader,
@@ -79,6 +81,9 @@ from .utils import (
maybe_prefix,
)
if TYPE_CHECKING:
from vllm.model_executor.layers.quantization import QuantizationConfig
logger = init_logger(__name__)
# Video constants — match transformers Gemma4VideoProcessor defaults.
@@ -872,6 +877,9 @@ class Gemma4MultimodalEmbedder(nn.Module):
self,
multimodal_config: Gemma4VisionConfig | Gemma4AudioConfig,
text_config: Gemma4TextConfig,
*,
quant_config: "QuantizationConfig | None" = None,
prefix: str = "",
):
super().__init__()
@@ -895,6 +903,8 @@ class Gemma4MultimodalEmbedder(nn.Module):
embedding_dim,
self.text_hidden_size,
bias=False,
quant_config=quant_config,
prefix=maybe_prefix(prefix, "embedding_projection"),
)
def forward(self, inputs_embeds: torch.Tensor) -> torch.Tensor:
@@ -917,6 +927,7 @@ class Gemma4MultimodalEmbedder(nn.Module):
class Gemma4ForConditionalGeneration(
nn.Module,
SupportsMultiModal,
SupportsQuant,
SupportsPP,
SupportsLoRA,
SupportsEagle3,
@@ -936,11 +947,14 @@ class Gemma4ForConditionalGeneration(
# Maps checkpoint prefixes to vLLM module paths.
hf_to_vllm_mapper = WeightsMapper(
orig_to_new_prefix={
"model.embed_audio.": "embed_audio.",
"model.embed_vision.": "embed_vision.",
"model.language_model.": "language_model.model.",
"model.vision_tower.": "vision_tower.",
# vision tower
"model.vision_tower": "vision_tower",
"model.embed_vision": "embed_vision",
# audio tower
"model.audio_tower.": "audio_tower.",
"model.embed_audio.": "embed_audio.",
# backbone
"model.language_model.": "language_model.model.",
"lm_head.": "language_model.lm_head.",
"model": "language_model.model",
}
@@ -959,7 +973,15 @@ class Gemma4ForConditionalGeneration(
with self._mark_tower_model(vllm_config, {"image", "video"}):
self.vision_tower = AutoModel.from_config(config=config.vision_config)
self.embed_vision = Gemma4MultimodalEmbedder(
config.vision_config, config.text_config
config.vision_config,
config.text_config,
quant_config=quant_config,
prefix=maybe_prefix(prefix, "embed_vision"),
)
recursive_replace_linear(
self.vision_tower,
quant_config,
prefix=maybe_prefix(prefix, "vision_tower"),
)
# ---- Audio tower (variants with audio_config) ----
@@ -972,7 +994,15 @@ class Gemma4ForConditionalGeneration(
# position embeddings, softcap, gradient_clipping).
self.audio_tower.post_init()
self.embed_audio = Gemma4MultimodalEmbedder(
config.audio_config, config.text_config
config.audio_config,
config.text_config,
quant_config=quant_config,
prefix=maybe_prefix(prefix, "embed_audio"),
)
recursive_replace_linear(
self.audio_tower,
quant_config,
prefix=maybe_prefix(prefix, "audio_tower"),
)
else:
self.audio_tower = None
@@ -1153,6 +1183,7 @@ class Gemma4ForConditionalGeneration(
vt = self.vision_tower
vision_cfg = self.config.vision_config
pooling_k2 = vision_cfg.pooling_kernel_size**2
target_dtype = self.language_model.model.embed_tokens.weight.dtype
# Concurrent requests with different image resolutions may
# arrive as a list of per-image tensors, while same-resolution
@@ -1193,7 +1224,11 @@ class Gemma4ForConditionalGeneration(
)
pad_tensor = (pp_tensor == -1).all(dim=-1)
inputs_embeds = vt.patch_embedder(pv_tensor, pp_tensor, pad_tensor)
inputs_embeds = vt.patch_embedder(
pv_tensor,
pp_tensor,
pad_tensor,
).to(target_dtype)
encoder_outputs = vt.encoder(
inputs_embeds=inputs_embeds,
attention_mask=~pad_tensor,
@@ -1230,7 +1265,9 @@ class Gemma4ForConditionalGeneration(
all_valid_states[orig_idx] = valid_states
valid_lens[orig_idx] = valid_states.shape[0]
target_dtype = self.embed_vision.embedding_projection.weight.dtype
# Use embed_tokens dtype as compute dtype; embedding_projection.weight
# may be uint8 under BnB 4-bit, which would corrupt the cast.
target_dtype = self.language_model.model.embed_tokens.weight.dtype
# Project all images in a single batched call.
flat_valid_states = torch.cat(all_valid_states, dim=0).to(target_dtype)
@@ -1273,7 +1310,7 @@ class Gemma4ForConditionalGeneration(
vt = self.vision_tower
vision_cfg = self.config.vision_config
pooling_k2 = vision_cfg.pooling_kernel_size**2
target_dtype = self.embed_vision.embedding_projection.weight.dtype
target_dtype = self.language_model.model.embed_tokens.weight.dtype
if isinstance(frame_counts, torch.Tensor):
fc_list = frame_counts.tolist()
@@ -1301,7 +1338,11 @@ class Gemma4ForConditionalGeneration(
pp_chunk = pixel_position_ids[i : i + max_batch_size]
pad_chunk = padding_positions[i : i + max_batch_size]
inputs_embeds = vt.patch_embedder(pv_chunk, pp_chunk, pad_chunk)
inputs_embeds = vt.patch_embedder(
pv_chunk,
pp_chunk,
pad_chunk,
).to(target_dtype)
encoder_outputs = vt.encoder(
inputs_embeds=inputs_embeds,
attention_mask=~pad_chunk,
@@ -32,6 +32,7 @@ from vllm.model_executor.layers.linear import (
ReplicatedLinear,
RowParallelLinear,
)
from vllm.model_executor.models.utils import maybe_prefix
from vllm.transformers_utils.config import is_rope_parameters_nested
if TYPE_CHECKING:
@@ -227,6 +228,34 @@ def replace_rms_norm_class(rms_norm: nn.Module, hidden_size: int) -> RMSNorm:
return RMSNorm(**kwargs)
def recursive_replace_linear(
    model: nn.Module,
    quant_config: "QuantizationConfig | None",
    prefix: str = "",
):
    """Swap every ``nn.Linear`` found under ``model`` for its replacement.

    The module tree is walked depth-first. Each ``nn.Linear`` child is
    passed to ``replace_linear_class`` with the ``"replicate"`` style and,
    when a different object comes back, the child attribute on its parent
    is rebound to it. Children that are not ``nn.Linear`` are descended
    into instead.

    Args:
        model: Root module; it is modified in place.
        quant_config: Forwarded unchanged to ``replace_linear_class``.
        prefix: Qualified name of ``model``; child names are joined onto it
            with ``maybe_prefix`` to build the prefix given to each
            replacement.
    """
    for attr_name, submodule in model.named_children():
        qualified = maybe_prefix(prefix, attr_name)

        if not isinstance(submodule, nn.Linear):
            # Not a linear layer: keep the module itself, look inside it.
            recursive_replace_linear(submodule, quant_config, prefix=qualified)
            continue

        replacement = replace_linear_class(
            submodule,
            "replicate",
            quant_config,
            prefix=qualified,
        )
        if replacement is not submodule:
            setattr(model, attr_name, replacement)
def log_replacement(name: str, old_module: nn.Module, new_module: nn.Module):
    """Emit a debug-level log line of the form ``name: old_module -> new_module``."""
    logger.debug("%s: %s -> %s", name, old_module, new_module)