mirror of
https://github.com/vllm-project/vllm.git
synced 2026-06-06 00:16:14 +00:00
[Bugfix] Convert Gemma4-MM ViT linear layers to vllm native impl (#43798)
Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn> Co-authored-by: ZiTian Zhao <zitian.zhao@tencentmusic.com> Co-authored-by: B-201 <Joy25810@foxmail.com>
This commit is contained in:
@@ -1,6 +1,7 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from functools import cached_property
|
||||
from typing import Any, Union
|
||||
|
||||
import torch
|
||||
@@ -168,6 +169,12 @@ class BitsAndBytesConfig(QuantizationConfig):
|
||||
return None
|
||||
|
||||
|
||||
class BitsAndBytesWeightParameter(torch.nn.Parameter):
    """``torch.nn.Parameter`` subclass whose ``dtype`` reports torch's default dtype.

    The ``dtype`` attribute is overridden so that it no longer reflects the
    dtype of the underlying storage.
    NOTE(review): presumably this lets callers that inspect ``weight.dtype``
    see the compute dtype instead of the packed ``uint8`` storage used for
    BitsAndBytes-quantized weights -- confirm against the call sites.
    """

    @cached_property
    def dtype(self) -> torch.dtype:
        """Return ``torch.get_default_dtype()``, evaluated once per instance.

        Because this is a ``cached_property``, the value is computed on the
        first access and stored on the instance; later changes to torch's
        default dtype are not reflected by this parameter.
        """
        default_dtype = torch.get_default_dtype()
        return default_dtype
|
||||
|
||||
|
||||
def is_layer_skipped_bnb(prefix: str, llm_int8_skip_modules: list[str]):
|
||||
# Split the prefix into its dot-separated components
|
||||
components = prefix.split(".")
|
||||
@@ -246,7 +253,7 @@ class BitsAndBytesLinearMethod(LinearMethodBase):
|
||||
"The input size is not aligned with the quantized weight shape."
|
||||
)
|
||||
|
||||
qweight = torch.nn.Parameter(
|
||||
qweight = BitsAndBytesWeightParameter(
|
||||
torch.empty(total_size // quant_ratio, 1, dtype=torch.uint8),
|
||||
requires_grad=False,
|
||||
)
|
||||
|
||||
@@ -744,6 +744,29 @@ class BitsAndBytesModelLoader(BaseModelLoader):
|
||||
stacked_quant_state_dict[quant_param_name][shard_index] = quant_state_dict[
|
||||
non_stacked_param_name
|
||||
]
|
||||
|
||||
# repeat k_proj for v_proj for k_eq_v models (e.g. Gemma4)
|
||||
config = getattr(model, "config", None)
|
||||
if config is not None:
|
||||
text_config = config.get_text_config()
|
||||
if getattr(text_config, "attention_k_eq_v", False):
|
||||
shard_packed = {
|
||||
name
|
||||
for name, subs in self.modules_mapping.packed_mapping.items()
|
||||
if len(subs) == 3
|
||||
}
|
||||
for param_name, shards in stacked_quant_state_dict.items():
|
||||
is_target = (
|
||||
isinstance(shards, dict)
|
||||
and len(shards) == 2
|
||||
and any(
|
||||
param_name.endswith(f"{p}.weight") for p in shard_packed
|
||||
)
|
||||
)
|
||||
if is_target:
|
||||
assert 1 in shards and 2 not in shards
|
||||
shards[2] = shards[1]
|
||||
|
||||
return stacked_quant_state_dict
|
||||
|
||||
def _bind_quant_states_to_params(
|
||||
|
||||
@@ -16,7 +16,7 @@ reason about temporal order.
|
||||
|
||||
import math
|
||||
from collections.abc import Iterable, Mapping, Sequence
|
||||
from typing import Annotated, Any, Literal
|
||||
from typing import TYPE_CHECKING, Annotated, Any, Literal
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
@@ -41,6 +41,7 @@ from vllm.model_executor.layers.layernorm import RMSNorm
|
||||
from vllm.model_executor.layers.linear import ReplicatedLinear
|
||||
from vllm.model_executor.models.gemma4 import Gemma4ForCausalLM
|
||||
from vllm.model_executor.models.module_mapping import MultiModelKeys
|
||||
from vllm.model_executor.models.transformers.utils import recursive_replace_linear
|
||||
from vllm.multimodal import MULTIMODAL_REGISTRY
|
||||
from vllm.multimodal.inputs import (
|
||||
MultiModalFieldConfig,
|
||||
@@ -71,6 +72,7 @@ from .interfaces import (
|
||||
SupportsLoRA,
|
||||
SupportsMultiModal,
|
||||
SupportsPP,
|
||||
SupportsQuant,
|
||||
)
|
||||
from .utils import (
|
||||
AutoWeightsLoader,
|
||||
@@ -79,6 +81,9 @@ from .utils import (
|
||||
maybe_prefix,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.model_executor.layers.quantization import QuantizationConfig
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
# Video constants — match transformers Gemma4VideoProcessor defaults.
|
||||
@@ -872,6 +877,9 @@ class Gemma4MultimodalEmbedder(nn.Module):
|
||||
self,
|
||||
multimodal_config: Gemma4VisionConfig | Gemma4AudioConfig,
|
||||
text_config: Gemma4TextConfig,
|
||||
*,
|
||||
quant_config: "QuantizationConfig | None" = None,
|
||||
prefix: str = "",
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
@@ -895,6 +903,8 @@ class Gemma4MultimodalEmbedder(nn.Module):
|
||||
embedding_dim,
|
||||
self.text_hidden_size,
|
||||
bias=False,
|
||||
quant_config=quant_config,
|
||||
prefix=maybe_prefix(prefix, "embedding_projection"),
|
||||
)
|
||||
|
||||
def forward(self, inputs_embeds: torch.Tensor) -> torch.Tensor:
|
||||
@@ -917,6 +927,7 @@ class Gemma4MultimodalEmbedder(nn.Module):
|
||||
class Gemma4ForConditionalGeneration(
|
||||
nn.Module,
|
||||
SupportsMultiModal,
|
||||
SupportsQuant,
|
||||
SupportsPP,
|
||||
SupportsLoRA,
|
||||
SupportsEagle3,
|
||||
@@ -936,11 +947,14 @@ class Gemma4ForConditionalGeneration(
|
||||
# Maps checkpoint prefixes to vLLM module paths.
|
||||
hf_to_vllm_mapper = WeightsMapper(
|
||||
orig_to_new_prefix={
|
||||
"model.embed_audio.": "embed_audio.",
|
||||
"model.embed_vision.": "embed_vision.",
|
||||
"model.language_model.": "language_model.model.",
|
||||
"model.vision_tower.": "vision_tower.",
|
||||
# vision tower
|
||||
"model.vision_tower": "vision_tower",
|
||||
"model.embed_vision": "embed_vision",
|
||||
# audio tower
|
||||
"model.audio_tower.": "audio_tower.",
|
||||
"model.embed_audio.": "embed_audio.",
|
||||
# backbone
|
||||
"model.language_model.": "language_model.model.",
|
||||
"lm_head.": "language_model.lm_head.",
|
||||
"model": "language_model.model",
|
||||
}
|
||||
@@ -959,7 +973,15 @@ class Gemma4ForConditionalGeneration(
|
||||
with self._mark_tower_model(vllm_config, {"image", "video"}):
|
||||
self.vision_tower = AutoModel.from_config(config=config.vision_config)
|
||||
self.embed_vision = Gemma4MultimodalEmbedder(
|
||||
config.vision_config, config.text_config
|
||||
config.vision_config,
|
||||
config.text_config,
|
||||
quant_config=quant_config,
|
||||
prefix=maybe_prefix(prefix, "embed_vision"),
|
||||
)
|
||||
recursive_replace_linear(
|
||||
self.vision_tower,
|
||||
quant_config,
|
||||
prefix=maybe_prefix(prefix, "vision_tower"),
|
||||
)
|
||||
|
||||
# ---- Audio tower (variants with audio_config) ----
|
||||
@@ -972,7 +994,15 @@ class Gemma4ForConditionalGeneration(
|
||||
# position embeddings, softcap, gradient_clipping).
|
||||
self.audio_tower.post_init()
|
||||
self.embed_audio = Gemma4MultimodalEmbedder(
|
||||
config.audio_config, config.text_config
|
||||
config.audio_config,
|
||||
config.text_config,
|
||||
quant_config=quant_config,
|
||||
prefix=maybe_prefix(prefix, "embed_audio"),
|
||||
)
|
||||
recursive_replace_linear(
|
||||
self.audio_tower,
|
||||
quant_config,
|
||||
prefix=maybe_prefix(prefix, "audio_tower"),
|
||||
)
|
||||
else:
|
||||
self.audio_tower = None
|
||||
@@ -1153,6 +1183,7 @@ class Gemma4ForConditionalGeneration(
|
||||
vt = self.vision_tower
|
||||
vision_cfg = self.config.vision_config
|
||||
pooling_k2 = vision_cfg.pooling_kernel_size**2
|
||||
target_dtype = self.language_model.model.embed_tokens.weight.dtype
|
||||
|
||||
# Concurrent requests with different image resolutions may
|
||||
# arrive as a list of per-image tensors, while same-resolution
|
||||
@@ -1193,7 +1224,11 @@ class Gemma4ForConditionalGeneration(
|
||||
)
|
||||
pad_tensor = (pp_tensor == -1).all(dim=-1)
|
||||
|
||||
inputs_embeds = vt.patch_embedder(pv_tensor, pp_tensor, pad_tensor)
|
||||
inputs_embeds = vt.patch_embedder(
|
||||
pv_tensor,
|
||||
pp_tensor,
|
||||
pad_tensor,
|
||||
).to(target_dtype)
|
||||
encoder_outputs = vt.encoder(
|
||||
inputs_embeds=inputs_embeds,
|
||||
attention_mask=~pad_tensor,
|
||||
@@ -1230,7 +1265,9 @@ class Gemma4ForConditionalGeneration(
|
||||
all_valid_states[orig_idx] = valid_states
|
||||
valid_lens[orig_idx] = valid_states.shape[0]
|
||||
|
||||
target_dtype = self.embed_vision.embedding_projection.weight.dtype
|
||||
# Use embed_tokens dtype as compute dtype; embedding_projection.weight
|
||||
# may be uint8 under BnB 4-bit, which would corrupt the cast.
|
||||
target_dtype = self.language_model.model.embed_tokens.weight.dtype
|
||||
|
||||
# Project all images in a single batched call.
|
||||
flat_valid_states = torch.cat(all_valid_states, dim=0).to(target_dtype)
|
||||
@@ -1273,7 +1310,7 @@ class Gemma4ForConditionalGeneration(
|
||||
vt = self.vision_tower
|
||||
vision_cfg = self.config.vision_config
|
||||
pooling_k2 = vision_cfg.pooling_kernel_size**2
|
||||
target_dtype = self.embed_vision.embedding_projection.weight.dtype
|
||||
target_dtype = self.language_model.model.embed_tokens.weight.dtype
|
||||
|
||||
if isinstance(frame_counts, torch.Tensor):
|
||||
fc_list = frame_counts.tolist()
|
||||
@@ -1301,7 +1338,11 @@ class Gemma4ForConditionalGeneration(
|
||||
pp_chunk = pixel_position_ids[i : i + max_batch_size]
|
||||
pad_chunk = padding_positions[i : i + max_batch_size]
|
||||
|
||||
inputs_embeds = vt.patch_embedder(pv_chunk, pp_chunk, pad_chunk)
|
||||
inputs_embeds = vt.patch_embedder(
|
||||
pv_chunk,
|
||||
pp_chunk,
|
||||
pad_chunk,
|
||||
).to(target_dtype)
|
||||
encoder_outputs = vt.encoder(
|
||||
inputs_embeds=inputs_embeds,
|
||||
attention_mask=~pad_chunk,
|
||||
|
||||
@@ -32,6 +32,7 @@ from vllm.model_executor.layers.linear import (
|
||||
ReplicatedLinear,
|
||||
RowParallelLinear,
|
||||
)
|
||||
from vllm.model_executor.models.utils import maybe_prefix
|
||||
from vllm.transformers_utils.config import is_rope_parameters_nested
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@@ -227,6 +228,34 @@ def replace_rms_norm_class(rms_norm: nn.Module, hidden_size: int) -> RMSNorm:
|
||||
return RMSNorm(**kwargs)
|
||||
|
||||
|
||||
def recursive_replace_linear(
    model: nn.Module,
    quant_config: "QuantizationConfig | None",
    prefix: str = "",
):
    """Swap every ``nn.Linear`` found beneath ``model`` in place.

    The module tree is walked depth-first through ``named_children()``.
    Each child that is an ``nn.Linear`` is handed to ``replace_linear_class``
    with the ``"replicate"`` style, and the result is assigned back onto its
    parent under the same attribute name. Children of any other type are
    descended into instead.

    Only descendants are candidates for replacement: ``model`` itself is
    never swapped, even if it is an ``nn.Linear``.

    Args:
        model: Root module whose descendants are scanned; modified in place.
        quant_config: Quantization config forwarded unchanged to
            ``replace_linear_class`` (may be ``None``).
        prefix: Qualified name of ``model``; each child's qualified name is
            built from it with ``maybe_prefix``.

    Returns:
        None. The replacement is performed as a side effect on ``model``.
    """

    def _walk(parent: nn.Module, parent_prefix: str):
        for name, child in parent.named_children():
            child_prefix = maybe_prefix(parent_prefix, name)

            # Anything that is not a plain linear layer is only traversed.
            if not isinstance(child, nn.Linear):
                _walk(child, child_prefix)
                continue

            replacement = replace_linear_class(
                child,
                "replicate",
                quant_config,
                prefix=child_prefix,
            )
            # Re-register only when a different object came back.
            if replacement is not child:
                setattr(parent, name, replacement)

    _walk(model, prefix)
|
||||
|
||||
|
||||
def log_replacement(name: str, old_module: nn.Module, new_module: nn.Module):
|
||||
logger.debug("%s: %s -> %s", name, old_module, new_module)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user