[Feature] Add support for JetBrains' Mellum v2 code generation model (#43992)

Signed-off-by: Madeesh Kannan <madeeswaran.kannan@jetbrains.com>
Co-authored-by: Robert Shaw <114415538+robertgshaw2-redhat@users.noreply.github.com>
This commit is contained in:
Madeesh Kannan
2026-06-01 14:11:35 +00:00
committed by khluu
parent 682ffebfef
commit 932dfd5276
7 changed files with 266 additions and 0 deletions
+1
View File
@@ -438,6 +438,7 @@ th {
| `LongcatFlashForCausalLM` | LongCat-Flash | `meituan-longcat/LongCat-Flash-Chat`, `meituan-longcat/LongCat-Flash-Chat-FP8` | ✅︎ | ✅︎ |
| `MambaForCausalLM` | Mamba | `state-spaces/mamba-130m-hf`, `state-spaces/mamba-790m-hf`, `state-spaces/mamba-2.8b-hf`, etc. | | ✅︎ |
| `Mamba2ForCausalLM` | Mamba2 | `mistralai/Mamba-Codestral-7B-v0.1`, etc. | | ✅︎ |
| `MellumForCausalLM` | Mellum 2 | `JetBrains/Mellum2-12B-A2.5B-Base`, etc. | | ✅︎ |
| `MiMoForCausalLM` | MiMo | `XiaomiMiMo/MiMo-7B-RL`, etc. | ✅︎ | ✅︎ |
| `MiMoV2FlashForCausalLM` | MiMoV2Flash | `XiaomiMiMo/MiMo-V2-Flash`, etc. | | ✅︎ |
| `MiMoV2ForCausalLM` | MiMoV2Pro | `XiaomiMiMo/MiMo-V2.5-Pro`, etc. | | ✅︎ |
+1
View File
@@ -523,6 +523,7 @@ _TEXT_GENERATION_EXAMPLE_MODELS = {
"Qwen2MoeForCausalLM": _HfExamplesInfo("Qwen/Qwen1.5-MoE-A2.7B-Chat"),
"Qwen3ForCausalLM": _HfExamplesInfo("Qwen/Qwen3-8B"),
"Qwen3MoeForCausalLM": _HfExamplesInfo("Qwen/Qwen3-30B-A3B"),
"MellumForCausalLM": _HfExamplesInfo("JetBrains/Mellum2-12B-A2.5B-Base"),
"Qwen3NextForCausalLM": _HfExamplesInfo(
"Qwen/Qwen3-Next-80B-A3B-Instruct",
extras={"tiny-random": "tiny-random/qwen3-next-moe"},
+253
View File
@@ -0,0 +1,253 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Any
from torch import nn
from vllm.compilation.decorators import support_torch_compile
from vllm.config import VllmConfig
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.model_executor.layers.attention import Attention
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import QKVParallelLinear, RowParallelLinear
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
from .qwen3_moe import (
Qwen3MoeAttention,
Qwen3MoeDecoderLayer,
Qwen3MoeForCausalLM,
Qwen3MoeMLP,
Qwen3MoeModel,
Qwen3MoeSparseMoeBlock,
)
from .utils import PPMissingLayer, extract_layer_index, maybe_prefix
class MellumAttention(Qwen3MoeAttention):
    """Attention block for Mellum.

    Differences from `Qwen3MoeAttention`:
        - `per_layer_sliding_window` is accepted and forwarded to the
          underlying `Attention` layer.
    """

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        num_kv_heads: int,
        rope_parameters: dict[str, Any],
        max_position_embeddings: int = 8192,
        head_dim: int | None = None,
        rms_norm_eps: float = 1e-06,
        qkv_bias: bool = False,
        cache_config: Any | None = None,
        quant_config: Any | None = None,
        prefix: str = "",
        dual_chunk_attention_config: dict[str, Any] | None = None,
        per_layer_sliding_window: int | None = None,
    ) -> None:
        # Initialise as a plain nn.Module rather than running the parent
        # constructor; every attribute and sub-module is set up below.
        nn.Module.__init__(self)
        self.hidden_size = hidden_size

        tp_size = get_tensor_model_parallel_world_size()

        # Query heads are split evenly across tensor-parallel ranks.
        self.total_num_heads = num_heads
        assert num_heads % tp_size == 0
        self.num_heads = num_heads // tp_size

        # KV heads are either split across ranks (count is a multiple of
        # tp_size) or shared by several ranks (tp_size is a multiple of the
        # count); in the latter case each rank still holds one KV head.
        self.total_num_kv_heads = num_kv_heads
        assert (
            num_kv_heads % tp_size == 0
            if num_kv_heads >= tp_size
            else tp_size % num_kv_heads == 0
        )
        self.num_kv_heads = max(1, num_kv_heads // tp_size)

        # Fall back to the conventional hidden_size / num_heads when no
        # explicit head dimension is given.
        self.head_dim = head_dim or (hidden_size // self.total_num_heads)
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim
        self.scaling = self.head_dim**-0.5
        self.max_position_embeddings = max_position_embeddings
        self.dual_chunk_attention_config = dual_chunk_attention_config

        self.qkv_proj = QKVParallelLinear(
            hidden_size,
            self.head_dim,
            self.total_num_heads,
            self.total_num_kv_heads,
            bias=qkv_bias,
            quant_config=quant_config,
            prefix=f"{prefix}.qkv_proj",
        )
        self.o_proj = RowParallelLinear(
            self.total_num_heads * self.head_dim,
            hidden_size,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.o_proj",
        )
        self.rotary_emb = get_rope(
            self.head_dim,
            max_position=max_position_embeddings,
            rope_parameters=rope_parameters,
            dual_chunk_attention_config=dual_chunk_attention_config,
        )

        # The layer index and dual-chunk settings are only passed to
        # `Attention` when dual-chunk attention is actually configured.
        dual_chunk_kwargs: dict[str, Any] = {}
        if dual_chunk_attention_config:
            dual_chunk_kwargs["layer_idx"] = extract_layer_index(prefix)
            dual_chunk_kwargs["dual_chunk_attention_config"] = (
                dual_chunk_attention_config
            )

        self.attn = Attention(
            self.num_heads,
            self.head_dim,
            self.scaling,
            num_kv_heads=self.num_kv_heads,
            cache_config=cache_config,
            quant_config=quant_config,
            per_layer_sliding_window=per_layer_sliding_window,
            prefix=f"{prefix}.attn",
            **dual_chunk_kwargs,
        )

        # RMS norms sized to a single head dimension for queries and keys.
        self.q_norm = RMSNorm(self.head_dim, eps=rms_norm_eps)
        self.k_norm = RMSNorm(self.head_dim, eps=rms_norm_eps)
class MellumDecoderLayer(Qwen3MoeDecoderLayer):
    """Single Mellum transformer layer.

    Differences from `Qwen3MoeDecoderLayer`:
        - Interleaved sliding-window attention: the window is only applied
          to layers whose entry in `config.layer_types` is
          `"sliding_attention"`.
        - Per-layer RoPE settings: `config.rope_parameters` is indexed by the
          layer type.
        - The MLP is a sparse MoE block or a dense MLP depending on
          `config.mlp_layer_types`.
    """

    def __init__(self, vllm_config: VllmConfig, prefix: str = "") -> None:
        # Initialise as a plain nn.Module rather than running the parent
        # constructor; all sub-modules are created below.
        nn.Module.__init__(self)

        config = vllm_config.model_config.hf_text_config
        cache_config = vllm_config.cache_config
        quant_config = vllm_config.quant_config

        self.hidden_size = config.hidden_size
        max_position_embeddings = getattr(config, "max_position_embeddings", 8192)
        dual_chunk_attention_config = getattr(
            config, "dual_chunk_attention_config", None
        )

        # The layer's position (parsed from its prefix) selects its attention
        # flavour, and with it the sliding window and RoPE parameters.
        layer_idx = extract_layer_index(prefix)
        layer_type = config.layer_types[layer_idx]
        sliding_window = (
            getattr(config, "sliding_window", None)
            if layer_type == "sliding_attention"
            else None
        )
        rope_parameters = config.rope_parameters[layer_type]

        attention_kwargs = {
            "hidden_size": self.hidden_size,
            "num_heads": config.num_attention_heads,
            "num_kv_heads": config.num_key_value_heads,
            "rope_parameters": rope_parameters,
            "max_position_embeddings": max_position_embeddings,
            "rms_norm_eps": config.rms_norm_eps,
            "qkv_bias": getattr(config, "attention_bias", False),
            "head_dim": getattr(config, "head_dim", None),
            "cache_config": cache_config,
            "quant_config": quant_config,
            "prefix": f"{prefix}.self_attn",
            "dual_chunk_attention_config": dual_chunk_attention_config,
            "per_layer_sliding_window": sliding_window,
        }
        self.self_attn = MellumAttention(**attention_kwargs)

        mlp_prefix = f"{prefix}.mlp"
        if config.mlp_layer_types[layer_idx] == "sparse":
            self.mlp = Qwen3MoeSparseMoeBlock(
                vllm_config=vllm_config,
                prefix=mlp_prefix,
            )
        else:
            self.mlp = Qwen3MoeMLP(
                hidden_size=config.hidden_size,
                intermediate_size=config.intermediate_size,
                hidden_act=config.hidden_act,
                quant_config=quant_config,
                prefix=mlp_prefix,
            )

        self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = RMSNorm(
            config.hidden_size,
            eps=config.rms_norm_eps,
        )
@support_torch_compile
class MellumModel(Qwen3MoeModel):
    """Mellum decoder stack.

    Differences from `Qwen3MoeModel`:
        - Layers are built from `MellumDecoderLayer`.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        # Everything is delegated to the parent; only the decoder layer
        # class is swapped out.
        super().__init__(
            decoder_layer_type=MellumDecoderLayer,
            vllm_config=vllm_config,
            prefix=prefix,
        )
class MellumForCausalLM(Qwen3MoeForCausalLM):
    """Mellum causal language model (decoder stack plus LM head).

    Differences from `Qwen3MoeForCausalLM`:
        - The backbone is a `MellumModel`.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        # Initialise as a plain nn.Module rather than running the parent
        # constructor; all attributes are set up below.
        nn.Module.__init__(self)

        config = vllm_config.model_config.hf_text_config
        quant_config = vllm_config.quant_config
        self.config = config
        self.quant_config = quant_config

        # Register the gate/up packing only when at least one layer uses the
        # dense MLP.
        # NOTE(review): this writes into whatever dict
        # `self.packed_modules_mapping` resolves to; if that is a class-level
        # attribute inherited from the parent, the entry is shared across
        # instances — confirm this is intended.
        if "dense" in getattr(config, "mlp_layer_types", []):
            self.packed_modules_mapping["gate_up_proj"] = ["gate_proj", "up_proj"]

        self.model = MellumModel(
            vllm_config=vllm_config,
            prefix=maybe_prefix(prefix, "model"),
        )
        self.lm_head = ParallelLMHead(
            config.vocab_size,
            config.hidden_size,
            quant_config=quant_config,
            prefix=maybe_prefix(prefix, "lm_head"),
        )
        if config.tie_word_embeddings:
            # Share the input embedding matrix with the output projection.
            self.lm_head.weight = self.model.embed_tokens.weight
        self.logits_processor = LogitsProcessor(config.vocab_size)
        self.make_empty_intermediate_tensors = (
            self.model.make_empty_intermediate_tensors
        )

        # Collect the expert modules of every sparse MoE layer present in
        # this model's layer list, skipping pipeline-parallel placeholders.
        self.expert_weights = []
        self.moe_layers = []
        last_sparse_block = None
        for decoder_layer in self.model.layers:
            if isinstance(decoder_layer, PPMissingLayer):
                continue
            assert isinstance(decoder_layer, Qwen3MoeDecoderLayer)
            mlp_block = decoder_layer.mlp
            if not isinstance(mlp_block, Qwen3MoeSparseMoeBlock):
                continue
            last_sparse_block = mlp_block
            self.moe_layers.append(mlp_block.experts)

        if last_sparse_block is None:
            raise RuntimeError("No MoE layer found in the model.layers.")

        # Expert bookkeeping; the per-expert counts are copied from the last
        # sparse block encountered above.
        self.num_moe_layers = len(self.moe_layers)
        self.num_expert_groups = 1
        self.num_shared_experts = 0
        self.num_logical_experts = last_sparse_block.n_logical_experts
        self.num_physical_experts = last_sparse_block.n_physical_experts
        self.num_local_physical_experts = last_sparse_block.n_local_physical_experts
        self.num_routed_experts = last_sparse_block.n_routed_experts
        self.num_redundant_experts = last_sparse_block.n_redundant_experts
+1
View File
@@ -160,6 +160,7 @@ _TEXT_GENERATION_MODELS = {
"LongcatFlashForCausalLM": ("longcat_flash", "LongcatFlashForCausalLM"),
"MambaForCausalLM": ("mamba", "MambaForCausalLM"),
"Mamba2ForCausalLM": ("mamba2", "Mamba2ForCausalLM"),
"MellumForCausalLM": ("mellum", "MellumForCausalLM"),
"MiniCPMForCausalLM": ("minicpm", "MiniCPMForCausalLM"),
"MiniCPM3ForCausalLM": ("minicpm3", "MiniCPM3ForCausalLM"),
"MiniMaxForCausalLM": ("minimax_text_01", "MiniMaxText01ForCausalLM"),
+1
View File
@@ -114,6 +114,7 @@ _CONFIG_REGISTRY: dict[str, type[PretrainedConfig]] = LazyConfigDict(
jais="JAISConfig",
mlp_speculator="MLPSpeculatorConfig",
medusa="MedusaConfig",
mellum="MellumConfig",
midashenglm="MiDashengLMConfig",
moondream3="Moondream3Config",
eagle="EAGLEConfig",
@@ -49,6 +49,7 @@ _CLASS_TO_MODULE: dict[str, str] = {
"LagunaConfig": "vllm.transformers_utils.configs.laguna",
"Lfm2MoeConfig": "vllm.transformers_utils.configs.lfm2_moe",
"MedusaConfig": "vllm.transformers_utils.configs.medusa",
"MellumConfig": "vllm.transformers_utils.configs.mellum",
"MiDashengLMConfig": "vllm.transformers_utils.configs.midashenglm",
"MLPSpeculatorConfig": "vllm.transformers_utils.configs.mlp_speculator",
"Moondream3Config": "vllm.transformers_utils.configs.moondream3",
@@ -117,6 +118,7 @@ __all__ = [
"LagunaConfig",
"Lfm2MoeConfig",
"MedusaConfig",
"MellumConfig",
"MiDashengLMConfig",
"MLPSpeculatorConfig",
"Moondream3Config",
@@ -0,0 +1,7 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from transformers import Qwen3MoeConfig
class MellumConfig(Qwen3MoeConfig):
    """Configuration class for Mellum models.

    Inherits every field and default from `Qwen3MoeConfig`; the only
    override is `model_type`.
    """

    # Identifier for this config class; overrides the value inherited from
    # `Qwen3MoeConfig`.
    model_type = "mellum"