mirror of
https://github.com/vllm-project/vllm.git
synced 2026-06-06 00:16:14 +00:00
[BUGFIX] Fix Pixtral consolidated format vision weight loading (#39916)
Signed-off-by: Julien Denize <julien.denize@mistral.ai> Signed-off-by: juliendenize <julien.denize@mistral.ai>
This commit is contained in:
File diff suppressed because one or more lines are too long
@@ -25,6 +25,7 @@ if TYPE_CHECKING:
|
||||
|
||||
PIXTRAL_ID = "mistralai/Pixtral-12B-2409"
|
||||
MISTRAL_SMALL_3_1_ID = "mistralai/Mistral-Small-3.1-24B-Instruct-2503"
|
||||
MINISTRAL_3B_ID = "mistralai/Ministral-3-3B-Instruct-2512"
|
||||
|
||||
MODELS = [PIXTRAL_ID, MISTRAL_SMALL_3_1_ID]
|
||||
|
||||
@@ -116,6 +117,7 @@ assert FIXTURES_PATH.exists()
|
||||
FIXTURE_LOGPROBS_CHAT = {
|
||||
PIXTRAL_ID: FIXTURES_PATH / "pixtral_chat.json",
|
||||
MISTRAL_SMALL_3_1_ID: FIXTURES_PATH / "mistral_small_3_chat.json",
|
||||
MINISTRAL_3B_ID: FIXTURES_PATH / "ministral_3b_chat.json",
|
||||
}
|
||||
|
||||
OutputsLogprobs = list[tuple[list[int], str, SampleLogprobs | None]]
|
||||
@@ -209,3 +211,41 @@ def test_chat(
|
||||
name_0="h100_ref",
|
||||
name_1="output",
|
||||
)
|
||||
|
||||
|
||||
@large_gpu_test(min_gb=16)
@pytest.mark.parametrize("dtype", ["bfloat16"])
def test_chat_consolidated(vllm_runner, dtype: str, local_asset_server) -> None:
    """Check chat logprobs for the Ministral 3B model loaded in Mistral format.

    The model is started with ``tokenizer_mode``, ``load_format`` and
    ``config_format`` all set to ``"mistral"``. Three chat requests are
    issued, built from the first one, the first two, and all of the image
    URLs, and the resulting logprobs are compared against the stored
    fixture for ``MINISTRAL_3B_ID``.
    """
    reference_logprobs = load_outputs_w_logprobs(
        FIXTURE_LOGPROBS_CHAT[MINISTRAL_3B_ID]
    )

    runner_kwargs = dict(
        dtype=dtype,
        tokenizer_mode="mistral",
        load_format="mistral",
        config_format="mistral",
        max_model_len=8192,
        limit_mm_per_prompt=LIMIT_MM_PER_PROMPT,
    )

    with vllm_runner(MINISTRAL_3B_ID, **runner_kwargs) as vllm_model:
        image_urls = [local_asset_server.url_for(u) for u in IMG_URLS]
        conversations = [
            _create_msg_format(image_urls[:1]),
            _create_msg_format(image_urls[:2]),
            _create_msg_format(image_urls),
        ]

        generated = []
        for conversation in conversations:
            generated.extend(
                vllm_model.llm.chat(conversation, sampling_params=SAMPLING_PARAMS)
            )

    logprobs = vllm_runner._final_steps_generate_w_logprobs(generated)

    # Every entry must end with a None element; drop it in place before the
    # comparison (presumably the fixture entries do not carry it -- verify).
    for idx, entry in enumerate(logprobs):
        assert entry[-1] is None
        logprobs[idx] = entry[:-1]

    check_logprobs_close(
        outputs_0_lst=reference_logprobs,
        outputs_1_lst=logprobs,
        name_0="h100_ref",
        name_1="output",
    )
|
||||
|
||||
@@ -458,13 +458,27 @@ class PixtralForConditionalGeneration(
|
||||
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
|
||||
_vision_encoder_stacked_params = [
|
||||
# (param_name, shard_name, shard_id)
|
||||
# HF format
|
||||
(".qkv_proj", ".q_proj", "q"),
|
||||
(".qkv_proj", ".k_proj", "k"),
|
||||
(".qkv_proj", ".v_proj", "v"),
|
||||
(".gate_up_proj", ".gate_proj", 0),
|
||||
(".gate_up_proj", ".up_proj", 1),
|
||||
# Mistral native (consolidated) format
|
||||
(".qkv_proj", ".wq", "q"),
|
||||
(".qkv_proj", ".wk", "k"),
|
||||
(".qkv_proj", ".wv", "v"),
|
||||
(".gate_up_proj", ".w1", 0),
|
||||
(".gate_up_proj", ".w3", 1),
|
||||
]
|
||||
|
||||
# Remap Mistral native names to HF-style names
|
||||
# used by the vLLM vision encoder modules.
|
||||
_vision_encoder_name_remap = {
|
||||
".wo.": ".o_proj.",
|
||||
".w2.": ".down_proj.",
|
||||
}
|
||||
|
||||
def is_vision_encoder_weights(weight: tuple[str, torch.Tensor]):
|
||||
return weight[0].startswith(("vision_encoder", "vision_tower"))
|
||||
|
||||
@@ -518,6 +532,11 @@ class PixtralForConditionalGeneration(
|
||||
weight_loader(param, w, shard_id)
|
||||
break
|
||||
else:
|
||||
for old, new in _vision_encoder_name_remap.items():
|
||||
if old in trimmed_name:
|
||||
trimmed_name = trimmed_name.replace(old, new)
|
||||
break
|
||||
|
||||
param = vision_encoder_dict.get(trimmed_name)
|
||||
if param is not None:
|
||||
weight_loader = getattr(
|
||||
|
||||
Reference in New Issue
Block a user