diff --git a/docs/models/supported_models.md b/docs/models/supported_models.md index b0b6c080ac4..e4d1d9c1c13 100644 --- a/docs/models/supported_models.md +++ b/docs/models/supported_models.md @@ -550,6 +550,7 @@ These models primarily accept the [`LLM.generate`](./generative_models.md#llmgen | `DeepseekOCR2ForCausalLM` | DeepSeek-OCR-2 | T + I+ | `deepseek-ai/DeepSeek-OCR-2`, etc. | ✅︎ | ✅︎ | | `Eagle2_5_VLForConditionalGeneration` | Eagle2.5-VL | T + IE+ | `nvidia/Eagle2.5-8B`, etc. | ✅︎ | ✅︎ | | `Ernie4_5_VLMoeForConditionalGeneration` | Ernie4.5-VL | T + I+/ V+ | `baidu/ERNIE-4.5-VL-28B-A3B-PT`, `baidu/ERNIE-4.5-VL-424B-A47B-PT` | | ✅︎ | +| `Exaone4_5_ForConditionalGeneration` | EXAONE-4.5 | T + IE+ | `LGAI-EXAONE/EXAONE-4.5-33B`, etc. | ✅︎ | ✅︎ | | `FuyuForCausalLM` | Fuyu | T + I | `adept/fuyu-8b`, etc. | | ✅︎ | | `Gemma3ForConditionalGeneration` | Gemma 3 | T + IE+ | `google/gemma-3-4b-it`, `google/gemma-3-27b-it`, etc. | ✅︎ | ✅︎ | | `Gemma3nForConditionalGeneration` | Gemma 3n | T + I + A | `google/gemma-3n-E2B-it`, `google/gemma-3n-E4B-it`, etc. 
# EXAONE-4.5
def run_exaone4_5(questions: list[str], modality: str) -> ModelRequestData:
    """Build single-image / single-video request data for EXAONE-4.5.

    Args:
        questions: Questions to ask; one prompt is produced per question.
        modality: ``"image"`` or ``"video"``. Selects the placeholder token
            inserted in front of each question and is used as the key of
            ``limit_mm_per_prompt``.

    Returns:
        A ``ModelRequestData`` carrying the engine arguments and the prompts.

    Raises:
        ValueError: If ``modality`` is neither ``"image"`` nor ``"video"``.
    """
    placeholders = {
        "image": "<|image_pad|>",
        "video": "<|video_pad|>",
    }
    if modality not in placeholders:
        # Without this guard an unsupported modality left ``placeholder``
        # unbound and only failed later, with an UnboundLocalError, while the
        # prompts were being formatted.
        supported = ", ".join(placeholders)
        raise ValueError(
            f"Unsupported modality {modality!r} for EXAONE-4.5; "
            f"expected one of: {supported}"
        )
    placeholder = placeholders[modality]

    model_name = "LGAI-EXAONE/EXAONE-4.5-33B"

    engine_args = EngineArgs(
        model=model_name,
        max_model_len=4096,
        max_num_seqs=5,
        mm_processor_kwargs={
            "min_pixels": 28 * 28,
            "max_pixels": 1280 * 28 * 28,
            "fps": 1,
        },
        limit_mm_per_prompt={modality: 1},
    )

    # System turn, then a user turn made of the placeholder immediately
    # followed by the question, then an open assistant turn.
    prompts = [
        (
            "<|system|>\nYou are a helpful assistant.<|endofturn|>\n"
            f"<|user|>\n{placeholder}"
            f"{question}<|endofturn|>\n"
            "<|assistant|>\n"
        )
        for question in questions
    ]

    return ModelRequestData(
        engine_args=engine_args,
        prompts=prompts,
    )
EngineArgs( + model=model_name, + max_model_len=8192, + max_num_seqs=2, + limit_mm_per_prompt={"image": len(image_urls)}, + ) + + placeholders = [{"type": "image", "image": url} for url in image_urls] + messages = [ + { + "role": "user", + "content": [ + *placeholders, + {"type": "text", "text": question}, + ], + } + ] + + processor = AutoProcessor.from_pretrained(model_name) + + prompt = processor.apply_chat_template( + messages, tokenize=False, add_generation_prompt=True + ) + + return ModelRequestData( + engine_args=engine_args, + prompt=prompt, + image_data=[fetch_image(url) for url in image_urls], + ) + + def load_gemma3(question: str, image_urls: list[str]) -> ModelRequestData: model_name = "google/gemma-3-4b-it" @@ -1450,6 +1485,7 @@ model_example_map = { "command_a_vision": load_command_a_vision, "deepseek_vl_v2": load_deepseek_vl2, "deepseek_ocr": load_deepseek_ocr, + "exaone4_5": load_exaone4_5, "gemma3": load_gemma3, "h2ovl_chat": load_h2ovl, "hunyuan_vl": load_hunyuan_vl, diff --git a/tests/models/registry.py b/tests/models/registry.py index 9f2bf5f1a2a..95753b2a60e 100644 --- a/tests/models/registry.py +++ b/tests/models/registry.py @@ -813,6 +813,10 @@ _MULTIMODAL_EXAMPLE_MODELS = { trust_remote_code=True, revision="refs/pr/17", ), + "Exaone4_5_ForConditionalGeneration": _HfExamplesInfo( + "LGAI-EXAONE/EXAONE-4.5-33B", + min_transformers_version="5.6.0", + ), "FireRedASR2ForConditionalGeneration": _HfExamplesInfo( "allendou/FireRedASR2-LLM-vllm", ), @@ -1306,6 +1310,11 @@ _SPECULATIVE_DECODING_EXAMPLE_MODELS = { min_transformers_version="5.1.0", enable_prefix_caching=False, ), + "Exaone4_5_MTP": _HfExamplesInfo( + "LGAI-EXAONE/EXAONE-4.5-33B", + speculative_model="LGAI-EXAONE/EXAONE-4.5-33B", + min_transformers_version="5.6.0", + ), "ExtractHiddenStatesModel": _HfExamplesInfo( "Qwen/Qwen3-8B", speculative_method="extract_hidden_states", diff --git a/vllm/config/speculative.py b/vllm/config/speculative.py index 0e74501dd9a..8f9b9dcc811 100644 --- 
a/vllm/config/speculative.py +++ b/vllm/config/speculative.py @@ -40,6 +40,7 @@ MTPModelTypes = Literal[ "ernie_mtp", "nemotron_h_mtp", "exaone_moe_mtp", + "exaone4_5_mtp", "qwen3_next_mtp", "qwen3_5_mtp", "longcat_flash_mtp", @@ -327,7 +328,13 @@ class SpeculativeConfig: hf_config.update( {"n_predict": n_predict, "architectures": ["ExaoneMoeMTP"]} ) - + if "exaone4_5" in hf_config.model_type: + hf_config.model_type = "exaone4_5_mtp" + if hf_config.model_type == "exaone4_5_mtp": + n_predict = getattr(hf_config, "num_nextn_predict_layers", None) + hf_config.update( + {"n_predict": n_predict, "architectures": ["Exaone4_5_MTP"]} + ) if hf_config.model_type in ("qwen3_5", "qwen3_5_moe"): is_moe = hf_config.model_type == "qwen3_5_moe" hf_config.model_type = "qwen3_5_mtp" diff --git a/vllm/model_executor/models/exaone4.py b/vllm/model_executor/models/exaone4.py index 485b145b9cd..04708de93d3 100644 --- a/vllm/model_executor/models/exaone4.py +++ b/vllm/model_executor/models/exaone4.py @@ -75,6 +75,7 @@ class Exaone4GatedMLP(nn.Module): reduce_results: bool = True, bias: bool = False, prefix: str = "", + use_data_parallel: bool = False, ) -> None: super().__init__() self.gate_up_proj = MergedColumnParallelLinear( @@ -83,6 +84,7 @@ class Exaone4GatedMLP(nn.Module): bias=bias, quant_config=quant_config, prefix=f"{prefix}.gate_up_proj", + disable_tp=use_data_parallel, ) self.down_proj = RowParallelLinear( input_size=intermediate_size, @@ -91,6 +93,7 @@ class Exaone4GatedMLP(nn.Module): quant_config=quant_config, reduce_results=reduce_results, prefix=f"{prefix}.down_proj", + disable_tp=use_data_parallel, ) if hidden_act != "silu": raise ValueError( diff --git a/vllm/model_executor/models/exaone4_5.py b/vllm/model_executor/models/exaone4_5.py new file mode 100644 index 00000000000..1eac43ccb0c --- /dev/null +++ b/vllm/model_executor/models/exaone4_5.py @@ -0,0 +1,366 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project 
+# ruff: noqa: E501 + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Inference-only EXAONE-4.5 model compatible with HuggingFace weights.""" + +from collections.abc import Callable, Iterable +from functools import partial + +import einops +import torch +import torch.nn as nn +from transformers.models.exaone4_5 import ( + Exaone4_5_Config, + Exaone4_5_ImageProcessor, + Exaone4_5_Processor, +) +from transformers.models.exaone4_5.configuration_exaone4_5 import Exaone4_5_VisionConfig + +from vllm.compilation.decorators import ( + should_torch_compile_mm_encoder, + support_torch_compile, +) +from vllm.config import VllmConfig +from vllm.distributed import parallel_state +from vllm.distributed import utils as dist_utils +from vllm.logger import init_logger +from vllm.model_executor.layers.attention.mm_encoder_attention import MMEncoderAttention +from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.linear import QKVParallelLinear, RowParallelLinear +from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.layers.rotary_embedding.common import ( + ApplyRotaryEmb, +) +from vllm.model_executor.models.exaone4 import Exaone4GatedMLP as Exaone4_5_VisionMLP +from vllm.model_executor.models.qwen2_5_vl import ( + Qwen2_5_VisionTransformer, + Qwen2_5_VLForConditionalGeneration, + Qwen2VLProcessingInfo, +) +from vllm.multimodal import MULTIMODAL_REGISTRY + +from .qwen2_vl import Qwen2VLDummyInputsBuilder as 
Exaone4_5_DummyInputsBuilder +from .qwen2_vl import Qwen2VLMultiModalProcessor as Exaone4_5_MultiModalProcessor +from .utils import AutoWeightsLoader, init_vllm_registered_model, maybe_prefix + +logger = init_logger(__name__) + + +# === Vision Encoder === # + + +class EXAONE4_5_VisionAttention(nn.Module): + def __init__( + self, + embed_dim: int, + num_heads: int, + num_kv_heads: int, + projection_size: int, + quant_config: QuantizationConfig | None = None, + prefix: str = "", + use_data_parallel: bool = False, + ) -> None: + super().__init__() + # Per attention head and per partition values. + self.tp_size = ( + 1 + if use_data_parallel + else parallel_state.get_tensor_model_parallel_world_size() + ) + self.tp_rank = parallel_state.get_tensor_model_parallel_rank() + self.hidden_size_per_attention_head = dist_utils.divide( + projection_size, num_heads + ) + self.num_attention_heads_per_partition = dist_utils.divide( + num_heads, self.tp_size + ) + + self.total_num_heads = num_heads + self.total_num_kv_heads = num_kv_heads + self.num_heads = num_heads // self.tp_size + self.num_kv_heads = max(1, num_kv_heads // self.tp_size) + + self.head_dim = embed_dim // num_heads + self.q_size = self.num_heads * self.head_dim + self.kv_size = self.num_kv_heads * self.head_dim + + self.qkv = QKVParallelLinear( + hidden_size=embed_dim, + head_size=self.hidden_size_per_attention_head, + total_num_heads=self.total_num_heads, + total_num_kv_heads=self.total_num_kv_heads, + bias=True, + quant_config=quant_config, + prefix=f"{prefix}.qkv", + disable_tp=use_data_parallel, + ) + + self.proj = RowParallelLinear( + input_size=projection_size, + output_size=embed_dim, + bias=True, + quant_config=quant_config, + prefix=f"{prefix}.proj", + disable_tp=use_data_parallel, + ) + + self.attn = MMEncoderAttention( + num_heads=self.num_attention_heads_per_partition, + head_size=self.hidden_size_per_attention_head, + num_kv_heads=self.num_kv_heads, + scale=self.hidden_size_per_attention_head**-0.5, + 
@support_torch_compile(
    dynamic_arg_dims={
        "x": 0,
        "cu_seqlens": 0,
        "rotary_pos_emb_cos": 0,
        "rotary_pos_emb_sin": 0,
    },
    enable_if=should_torch_compile_mm_encoder,
    is_encoder=True,
)
class Exaone4_5_VisionBlock(nn.Module):
    """One EXAONE-4.5 vision-encoder layer: attention followed by a gated MLP.

    Args:
        dim: Hidden size. Used for both norms, as the attention embedding and
            projection size, and as the first positional argument of the MLP.
        num_heads: Number of attention heads.
        num_kv_heads: Number of key/value heads.
        mlp_hidden_dim: Second positional argument of the MLP (its
            intermediate size).
        hidden_act: Activation name forwarded to the MLP.
        norm_layer: Factory ``dim -> nn.Module`` used to build both norms.
            ``forward`` calls the second norm as ``norm(x, residual=...)`` and
            unpacks a 2-tuple from it, so the factory must produce a norm with
            that fused-residual interface. Defaults to ``RMSNorm`` with
            ``eps=1e-6``.
        quant_config: Quantization config forwarded to attention and MLP.
        prefix: Module-name prefix; ``.attn`` / ``.mlp`` are appended for the
            sub-modules.
        use_data_parallel: Forwarded unchanged to attention and MLP.
    """

    def __init__(
        self,
        dim: int,
        num_heads: int,
        num_kv_heads: int,
        mlp_hidden_dim: int,
        hidden_act: str = "silu",
        norm_layer: Callable[[int], nn.Module] | None = None,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
        use_data_parallel: bool = False,
    ) -> None:
        super().__init__()
        if norm_layer is None:
            # The previous default, ``nn.LayerNorm``, cannot work here: its
            # ``forward`` takes a single input, while ``forward`` below calls
            # ``self.norm2(x, residual=...)`` and expects two return values.
            # ``RMSNorm`` is the norm the vision transformer passes in.
            norm_layer = partial(RMSNorm, eps=1e-6)
        self.norm1 = norm_layer(dim)
        self.norm2 = norm_layer(dim)
        self.attn = EXAONE4_5_VisionAttention(
            embed_dim=dim,
            num_heads=num_heads,
            num_kv_heads=num_kv_heads,
            projection_size=dim,
            quant_config=quant_config,
            prefix=f"{prefix}.attn",
            use_data_parallel=use_data_parallel,
        )
        self.mlp = Exaone4_5_VisionMLP(
            dim,
            mlp_hidden_dim,
            hidden_act=hidden_act,
            bias=True,
            quant_config=quant_config,
            prefix=f"{prefix}.mlp",
            use_data_parallel=use_data_parallel,
        )

    def forward(
        self,
        x: torch.Tensor,
        cu_seqlens: torch.Tensor,
        rotary_pos_emb_cos: torch.Tensor,
        rotary_pos_emb_sin: torch.Tensor,
        max_seqlen: int | None = None,  # Only used for Flash Attention
        seqlens: list[int] | None = None,  # Only used for xFormers
    ) -> torch.Tensor:
        """Run attention and the MLP over ``x`` and return the block output.

        ``cu_seqlens``, the rotary cos/sin tensors and ``max_seqlen`` are
        passed straight to the attention module. ``seqlens`` is accepted for
        signature compatibility but is not read by this method.
        """
        x_attn = self.attn(
            self.norm1(x),
            cu_seqlens=cu_seqlens,
            rotary_pos_emb_cos=rotary_pos_emb_cos,
            rotary_pos_emb_sin=rotary_pos_emb_sin,
            max_seqlen=max_seqlen,
        )
        # Fused call: the norm receives the block input together with the
        # attention output as ``residual`` and hands back both the normalized
        # activations and the residual stream to add the MLP output onto.
        # NOTE(review): presumably residual == x + x_attn - confirm against
        # the RMSNorm implementation.
        x_fused_norm, residual = self.norm2(x, residual=x_attn)
        x = residual + self.mlp(x_fused_norm)
        return x
@MULTIMODAL_REGISTRY.register_processor(
    Exaone4_5_MultiModalProcessor,
    info=Exaone4_5_ProcessingInfo,
    dummy_inputs=Exaone4_5_DummyInputsBuilder,
)
class Exaone4_5_ForConditionalGeneration(Qwen2_5_VLForConditionalGeneration):
    """EXAONE-4.5 vision-language model.

    Combines an ``EXAONE4_5_VisionTransformer`` tower (``self.visual``) with a
    language model created through ``init_vllm_registered_model`` using the
    ``Exaone4ForCausalLM`` architecture (``self.language_model``).
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        """Create the vision tower and the language model.

        Args:
            vllm_config: Engine configuration; the HF config and the
                multimodal config are read from ``vllm_config.model_config``.
            prefix: Module-name prefix; ``visual`` / ``language_model`` are
                appended for the two sub-models.
        """
        # The parent ``__init__`` is intentionally not invoked; only the
        # ``nn.Module`` machinery is initialised and everything else is set
        # up below.
        nn.Module.__init__(self)

        model_config = vllm_config.model_config
        config: Exaone4_5_Config = model_config.hf_config
        multimodal_config = model_config.multimodal_config

        self.vllm_config = vllm_config
        self.config = config
        self.multimodal_config = multimodal_config
        # The vision encoder runs data-parallel when the multimodal encoder
        # TP mode is "data".
        self.use_data_parallel = multimodal_config.mm_encoder_tp_mode == "data"
        self.is_multimodal_pruning_enabled = (
            multimodal_config.is_multimodal_pruning_enabled()
        )

        # Vision tower, registered for both image and video inputs.
        with self._mark_tower_model(vllm_config, {"image", "video"}):
            self.visual = EXAONE4_5_VisionTransformer(
                config.vision_config,
                norm_eps=getattr(config, "rms_norm_eps", 1e-6),
                quant_config=self.quant_config,
                prefix=maybe_prefix(prefix, "visual"),
                use_data_parallel=self.use_data_parallel,
            )

        # Text backbone, built from the text sub-config.
        with self._mark_language_model(vllm_config):
            self.language_model = init_vllm_registered_model(
                vllm_config=vllm_config,
                prefix=maybe_prefix(prefix, "language_model"),
                hf_config=config.get_text_config(),
                architectures=["Exaone4ForCausalLM"],
            )

        self.make_empty_intermediate_tensors = (
            self.language_model.make_empty_intermediate_tensors
        )

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load checkpoint weights through ``AutoWeightsLoader``.

        The loader is configured with ``skip_prefixes=["mtp."]`` and is given
        ``self.hf_to_vllm_mapper`` as the name mapper.

        Returns:
            Whatever ``AutoWeightsLoader.load_weights`` returns (annotated as
            the set of loaded names).
        """
        weights_loader = AutoWeightsLoader(self, skip_prefixes=["mtp."])
        return weights_loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)

    @classmethod
    def get_placeholder_str(cls, modality: str, i: int) -> str | None:
        """Return the prompt placeholder token for ``modality``.

        Raises:
            ValueError: If ``modality`` starts with neither ``"image"`` nor
                ``"video"``.
        """
        placeholder_by_kind = (
            ("image", "<|image_pad|>"),
            ("video", "<|video_pad|>"),
        )
        for kind, token in placeholder_by_kind:
            if modality.startswith(kind):
                return token

        raise ValueError("Only image or video modality is supported")
@support_torch_compile
class Exaone4_5_MTP(ExaoneMoeMTP):
    """EXAONE-4.5 multi-token-prediction (MTP) draft model.

    Wraps an ``Exaone4_5MultiTokenPredictor`` (``self.model``) together with
    an LM head and a logits processor.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        """Build the predictor, the LM head and the logits processor.

        Args:
            vllm_config: Engine configuration; the HF config is read from
                ``vllm_config.model_config.hf_config``.
            prefix: Module-name prefix; ``mtp`` / ``lm_head`` are appended for
                the sub-modules.
        """
        # The parent ``__init__`` is intentionally bypassed. Initialise the
        # ``nn.Module`` machinery first so every later attribute assignment
        # happens on a fully set-up module.
        nn.Module.__init__(self)

        config = vllm_config.model_config.hf_config
        self.vllm_config = vllm_config
        self.quant_config = vllm_config.quant_config
        self.config = config

        self.model = Exaone4_5MultiTokenPredictor(
            vllm_config=vllm_config, prefix=maybe_prefix(prefix, "mtp")
        )
        self.unpadded_vocab_size = config.vocab_size
        self.lm_head = ParallelLMHead(
            self.unpadded_vocab_size,
            config.hidden_size,
            org_num_embeddings=config.vocab_size,
            prefix=maybe_prefix(prefix, "lm_head"),
        )
        if config.tie_word_embeddings:
            # Tied embeddings: the LM head reuses the embedding weight.
            self.lm_head.weight = self.model.embed_tokens.weight
        self.logits_processor = LogitsProcessor(
            self.unpadded_vocab_size, config.vocab_size
        )

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load the MTP weights plus the shared embedding / LM-head weights.

        Checkpoint names are filtered and renamed before being handed to
        ``AutoWeightsLoader``:

        * names starting with ``mtp.`` have that prefix replaced by
          ``model.``;
        * names containing ``embed_tokens`` have ``language_model.`` removed;
        * names containing ``lm_head`` are passed through unchanged;
        * every other tensor is dropped.

        Returns:
            Whatever ``AutoWeightsLoader.load_weights`` returns (annotated as
            the set of loaded names).
        """
        mtp_prefix = "mtp."
        shared_weight_names = ("embed_tokens", "lm_head")

        def remap_weight_names(weights):
            for name, weight in weights:
                if name.startswith(mtp_prefix):
                    # Rewrite only the leading prefix. ``str.replace`` would
                    # also rewrite any later "mtp." inside the name.
                    name = "model." + name.removeprefix(mtp_prefix)
                elif any(key in name for key in shared_weight_names):
                    if "embed_tokens" in name:
                        name = name.replace("language_model.", "")
                else:
                    continue
                yield name, weight

        loader = AutoWeightsLoader(self)
        return loader.load_weights(remap_weight_names(weights))
"LongCatFlashMTPModel": ("longcat_flash_mtp", "LongCatFlashMTP"), "Glm4MoeMTPModel": ("glm4_moe_mtp", "Glm4MoeMTP"), diff --git a/vllm/v1/spec_decode/eagle.py b/vllm/v1/spec_decode/eagle.py index 77e6f64bab8..323b16931fb 100644 --- a/vllm/v1/spec_decode/eagle.py +++ b/vllm/v1/spec_decode/eagle.py @@ -1321,13 +1321,14 @@ class SpecDecodeBaseProposer: # handle multimodality assert hasattr(target_model, "config") if self.get_model_name(target_model) in [ - "Qwen2_5_VLForConditionalGeneration", - "Qwen3VLForConditionalGeneration", - "Qwen3VLMoeForConditionalGeneration", - "HunYuanVLForConditionalGeneration", + "Exaone4_5_ForConditionalGeneration", "GlmOcrForConditionalGeneration", + "HunYuanVLForConditionalGeneration", + "Qwen2_5_VLForConditionalGeneration", "Qwen3_5ForConditionalGeneration", "Qwen3_5MoeForConditionalGeneration", + "Qwen3VLForConditionalGeneration", + "Qwen3VLMoeForConditionalGeneration", ]: self.model.config.image_token_index = target_model.config.image_token_id elif self.get_model_name(target_model) == "PixtralForConditionalGeneration":