mirror of
https://github.com/vllm-project/vllm.git
synced 2026-06-06 00:16:14 +00:00
[Bugfix][Pooling] Fix silent weight corruption with buffer-reusing iterators (#39650)
Signed-off-by: Pedram Razavi <pedram.razavi@gmail.com>
This commit is contained in:
@@ -0,0 +1,148 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""Tests for model adapter weight loading (adapters.py)."""
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from vllm.model_executor.models.adapters import _create_pooling_model_cls
|
||||
from vllm.model_executor.models.utils import AutoWeightsLoader, StageMissingLayer
|
||||
|
||||
pytestmark = pytest.mark.cpu_test
|
||||
|
||||
|
||||
class SimpleInnerModel(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.embed = torch.nn.Linear(4, 8, bias=False)
|
||||
self.layer0 = torch.nn.Linear(8, 8, bias=False)
|
||||
self.layer1 = torch.nn.Linear(8, 8, bias=False)
|
||||
self.norm = torch.nn.Linear(8, 4, bias=False)
|
||||
|
||||
def load_weights(self, weights):
|
||||
params = dict(self.named_parameters())
|
||||
loaded = set()
|
||||
for name, tensor in weights:
|
||||
if name in params:
|
||||
params[name].data.copy_(tensor)
|
||||
loaded.add(name)
|
||||
return loaded
|
||||
|
||||
|
||||
class SimpleModel(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.model = SimpleInnerModel()
|
||||
self.lm_head = torch.nn.Linear(8, 16, bias=False)
|
||||
|
||||
def load_weights(self, weights):
|
||||
loader = AutoWeightsLoader(self)
|
||||
return loader.load_weights(weights)
|
||||
|
||||
|
||||
class PackedWeightInnerModel(torch.nn.Module):
|
||||
"""Remaps q_proj/k_proj into a fused qkv_proj (Qwen2/Llama pattern)."""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.qkv_proj = torch.nn.Linear(4, 16, bias=False)
|
||||
self.out = torch.nn.Linear(8, 4, bias=False)
|
||||
|
||||
def load_weights(self, weights):
|
||||
params = dict(self.named_parameters())
|
||||
loaded = set()
|
||||
for name, tensor in weights:
|
||||
if name == "q_proj.weight":
|
||||
params["qkv_proj.weight"].data[:8].copy_(tensor)
|
||||
loaded.add("qkv_proj.weight")
|
||||
elif name == "k_proj.weight":
|
||||
params["qkv_proj.weight"].data[8:].copy_(tensor)
|
||||
loaded.add("qkv_proj.weight")
|
||||
elif name in params:
|
||||
params[name].data.copy_(tensor)
|
||||
loaded.add(name)
|
||||
return loaded
|
||||
|
||||
|
||||
class PackedWeightModel(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.model = PackedWeightInnerModel()
|
||||
self.lm_head = torch.nn.Linear(4, 8, bias=False)
|
||||
|
||||
def load_weights(self, weights):
|
||||
loader = AutoWeightsLoader(self)
|
||||
return loader.load_weights(weights)
|
||||
|
||||
|
||||
def _buffer_reusing_iterator(weight_dict):
|
||||
"""Yield weights through a shared buffer overwritten each step.
|
||||
|
||||
Mimics ``runai_model_streamer`` with ``RUNAI_STREAMER_MEMORY_LIMIT=0``.
|
||||
"""
|
||||
buf = None
|
||||
for name, tensor in weight_dict.items():
|
||||
if buf is None or buf.numel() < tensor.numel():
|
||||
buf = torch.empty(tensor.numel(), dtype=tensor.dtype)
|
||||
view = buf[: tensor.numel()].view(tensor.shape)
|
||||
view.copy_(tensor)
|
||||
yield name, view
|
||||
|
||||
|
||||
def _make_pooling_model(base_cls=SimpleModel):
|
||||
PoolingModel = _create_pooling_model_cls(base_cls)
|
||||
model = base_cls()
|
||||
model.__class__ = PoolingModel
|
||||
model.lm_head = StageMissingLayer("output", model.lm_head)
|
||||
return model
|
||||
|
||||
|
||||
def _make_reference_weights():
|
||||
torch.manual_seed(42)
|
||||
return {
|
||||
"model.embed.weight": torch.randn(8, 4),
|
||||
"model.layer0.weight": torch.randn(8, 8),
|
||||
"model.layer1.weight": torch.randn(8, 8),
|
||||
"model.norm.weight": torch.randn(4, 8),
|
||||
"lm_head.weight": torch.randn(16, 8),
|
||||
}
|
||||
|
||||
|
||||
def _make_packed_reference_weights():
|
||||
torch.manual_seed(42)
|
||||
return {
|
||||
"model.q_proj.weight": torch.randn(8, 4),
|
||||
"model.k_proj.weight": torch.randn(8, 4),
|
||||
"model.out.weight": torch.randn(4, 8),
|
||||
"lm_head.weight": torch.randn(8, 4),
|
||||
}
|
||||
|
||||
|
||||
def _load_and_compare(model, ref, expected):
|
||||
for p in model.parameters():
|
||||
p.data.zero_()
|
||||
model.load_weights(_buffer_reusing_iterator(ref))
|
||||
for name, param in model.named_parameters():
|
||||
assert torch.equal(param.data, expected[name]), name
|
||||
|
||||
|
||||
def test_pooling_load_weights_with_buffer_reuse():
|
||||
"""Ensure ModelForPooling.load_weights works with buffer-reusing iterators."""
|
||||
ref = _make_reference_weights()
|
||||
|
||||
ground_truth = SimpleModel()
|
||||
ground_truth.load_weights(ref.items())
|
||||
expected = {n: p.data.clone() for n, p in ground_truth.named_parameters()}
|
||||
|
||||
_load_and_compare(_make_pooling_model(), ref, expected)
|
||||
|
||||
|
||||
def test_pooling_load_weights_clones_probed_weights():
|
||||
"""Ensure probed weights survive buffer reuse during packed remapping."""
|
||||
ref = _make_packed_reference_weights()
|
||||
|
||||
ground_truth = PackedWeightModel()
|
||||
ground_truth.load_weights(ref.items())
|
||||
expected = {n: p.data.clone() for n, p in ground_truth.named_parameters()}
|
||||
|
||||
_load_and_compare(_make_pooling_model(PackedWeightModel), ref, expected)
|
||||
@@ -1,6 +1,7 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import itertools
|
||||
from collections.abc import Iterable
|
||||
from contextlib import contextmanager
|
||||
from typing import TYPE_CHECKING, Any, TypeVar, cast
|
||||
@@ -181,7 +182,8 @@ def _create_pooling_model_cls(orig_cls: _T) -> _T:
|
||||
|
||||
seen_weights = list[tuple[str, torch.Tensor]]()
|
||||
for name, loaded_weight in weights:
|
||||
seen_weights.append((name, loaded_weight))
|
||||
# Clone because the iterator may reuse the tensor buffer
|
||||
seen_weights.append((name, loaded_weight.clone()))
|
||||
|
||||
try:
|
||||
target_prefix = next(
|
||||
@@ -208,9 +210,11 @@ def _create_pooling_model_cls(orig_cls: _T) -> _T:
|
||||
self._get_name(),
|
||||
)
|
||||
|
||||
# Lazy chain so buffer-reusing weight iterators (e.g.
|
||||
# runai_streamer) are consumed one tensor at a time.
|
||||
mapped_weights = (
|
||||
(target_prefix + name, weight)
|
||||
for name, weight in (*seen_weights, *weights)
|
||||
for name, weight in itertools.chain(seen_weights, weights)
|
||||
)
|
||||
|
||||
def default_load_weights(weights):
|
||||
|
||||
Reference in New Issue
Block a user