mirror of
https://github.com/vllm-project/vllm.git
synced 2026-06-06 00:16:14 +00:00
[Feature] Add support for JetBrains' Mellum v2 code generation model (#43992)
Signed-off-by: Madeesh Kannan <madeeswaran.kannan@jetbrains.com> Co-authored-by: Robert Shaw <114415538+robertgshaw2-redhat@users.noreply.github.com>
This commit is contained in:
@@ -437,6 +437,7 @@ th {
|
||||
| `LongcatFlashForCausalLM` | LongCat-Flash | `meituan-longcat/LongCat-Flash-Chat`, `meituan-longcat/LongCat-Flash-Chat-FP8` | ✅︎ | ✅︎ |
|
||||
| `MambaForCausalLM` | Mamba | `state-spaces/mamba-130m-hf`, `state-spaces/mamba-790m-hf`, `state-spaces/mamba-2.8b-hf`, etc. | | ✅︎ |
|
||||
| `Mamba2ForCausalLM` | Mamba2 | `mistralai/Mamba-Codestral-7B-v0.1`, etc. | | ✅︎ |
|
||||
| `MellumForCausalLM` | Mellum 2 | `JetBrains/Mellum2-12B-A2.5B-Base`, etc. | | ✅︎ |
|
||||
| `MiMoForCausalLM` | MiMo | `XiaomiMiMo/MiMo-7B-RL`, etc. | ✅︎ | ✅︎ |
|
||||
| `MiMoV2FlashForCausalLM` | MiMoV2Flash | `XiaomiMiMo/MiMo-V2-Flash`, etc. | | ✅︎ |
|
||||
| `MiMoV2ForCausalLM` | MiMoV2Pro | `XiaomiMiMo/MiMo-V2.5-Pro`, etc. | | ✅︎ |
|
||||
|
||||
@@ -522,6 +522,7 @@ _TEXT_GENERATION_EXAMPLE_MODELS = {
|
||||
"Qwen2MoeForCausalLM": _HfExamplesInfo("Qwen/Qwen1.5-MoE-A2.7B-Chat"),
|
||||
"Qwen3ForCausalLM": _HfExamplesInfo("Qwen/Qwen3-8B"),
|
||||
"Qwen3MoeForCausalLM": _HfExamplesInfo("Qwen/Qwen3-30B-A3B"),
|
||||
"MellumForCausalLM": _HfExamplesInfo("JetBrains/Mellum2-12B-A2.5B-Base"),
|
||||
"Qwen3NextForCausalLM": _HfExamplesInfo(
|
||||
"Qwen/Qwen3-Next-80B-A3B-Instruct",
|
||||
extras={"tiny-random": "tiny-random/qwen3-next-moe"},
|
||||
|
||||
@@ -0,0 +1,253 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from typing import Any
|
||||
|
||||
from torch import nn
|
||||
|
||||
from vllm.compilation.decorators import support_torch_compile
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.distributed import get_tensor_model_parallel_world_size
|
||||
from vllm.model_executor.layers.attention import Attention
|
||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||
from vllm.model_executor.layers.linear import QKVParallelLinear, RowParallelLinear
|
||||
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
||||
from vllm.model_executor.layers.rotary_embedding import get_rope
|
||||
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
|
||||
|
||||
from .qwen3_moe import (
|
||||
Qwen3MoeAttention,
|
||||
Qwen3MoeDecoderLayer,
|
||||
Qwen3MoeForCausalLM,
|
||||
Qwen3MoeMLP,
|
||||
Qwen3MoeModel,
|
||||
Qwen3MoeSparseMoeBlock,
|
||||
)
|
||||
from .utils import PPMissingLayer, extract_layer_index, maybe_prefix
|
||||
|
||||
|
||||
class MellumAttention(Qwen3MoeAttention):
|
||||
"""
|
||||
Differences from `Qwen3MoeAttention`:
|
||||
- Supports `per_layer_sliding_window` for `Attention`.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
hidden_size: int,
|
||||
num_heads: int,
|
||||
num_kv_heads: int,
|
||||
rope_parameters: dict[str, Any],
|
||||
max_position_embeddings: int = 8192,
|
||||
head_dim: int | None = None,
|
||||
rms_norm_eps: float = 1e-06,
|
||||
qkv_bias: bool = False,
|
||||
cache_config: Any | None = None,
|
||||
quant_config: Any | None = None,
|
||||
prefix: str = "",
|
||||
dual_chunk_attention_config: dict[str, Any] | None = None,
|
||||
per_layer_sliding_window: int | None = None,
|
||||
) -> None:
|
||||
nn.Module.__init__(self)
|
||||
|
||||
self.hidden_size = hidden_size
|
||||
tp_size = get_tensor_model_parallel_world_size()
|
||||
self.total_num_heads = num_heads
|
||||
assert self.total_num_heads % tp_size == 0
|
||||
self.num_heads = self.total_num_heads // tp_size
|
||||
self.total_num_kv_heads = num_kv_heads
|
||||
if self.total_num_kv_heads >= tp_size:
|
||||
assert self.total_num_kv_heads % tp_size == 0
|
||||
else:
|
||||
assert tp_size % self.total_num_kv_heads == 0
|
||||
self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
|
||||
self.head_dim = head_dim or (hidden_size // self.total_num_heads)
|
||||
self.q_size = self.num_heads * self.head_dim
|
||||
self.kv_size = self.num_kv_heads * self.head_dim
|
||||
self.scaling = self.head_dim**-0.5
|
||||
self.max_position_embeddings = max_position_embeddings
|
||||
self.dual_chunk_attention_config = dual_chunk_attention_config
|
||||
|
||||
self.qkv_proj = QKVParallelLinear(
|
||||
hidden_size,
|
||||
self.head_dim,
|
||||
self.total_num_heads,
|
||||
self.total_num_kv_heads,
|
||||
bias=qkv_bias,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.qkv_proj",
|
||||
)
|
||||
|
||||
self.o_proj = RowParallelLinear(
|
||||
self.total_num_heads * self.head_dim,
|
||||
hidden_size,
|
||||
bias=False,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.o_proj",
|
||||
)
|
||||
|
||||
self.rotary_emb = get_rope(
|
||||
self.head_dim,
|
||||
max_position=max_position_embeddings,
|
||||
rope_parameters=rope_parameters,
|
||||
dual_chunk_attention_config=dual_chunk_attention_config,
|
||||
)
|
||||
self.attn = Attention(
|
||||
self.num_heads,
|
||||
self.head_dim,
|
||||
self.scaling,
|
||||
num_kv_heads=self.num_kv_heads,
|
||||
cache_config=cache_config,
|
||||
quant_config=quant_config,
|
||||
per_layer_sliding_window=per_layer_sliding_window,
|
||||
prefix=f"{prefix}.attn",
|
||||
**(
|
||||
{
|
||||
"layer_idx": extract_layer_index(prefix),
|
||||
"dual_chunk_attention_config": dual_chunk_attention_config,
|
||||
}
|
||||
if dual_chunk_attention_config
|
||||
else {}
|
||||
),
|
||||
)
|
||||
|
||||
self.q_norm = RMSNorm(self.head_dim, eps=rms_norm_eps)
|
||||
self.k_norm = RMSNorm(self.head_dim, eps=rms_norm_eps)
|
||||
|
||||
|
||||
class MellumDecoderLayer(Qwen3MoeDecoderLayer):
|
||||
"""
|
||||
Differences from `Qwen3MoeDecoderLayer`:
|
||||
- Supports interleaved SWA and per-layer RoPE scaling.
|
||||
"""
|
||||
|
||||
def __init__(self, vllm_config: VllmConfig, prefix: str = "") -> None:
|
||||
nn.Module.__init__(self)
|
||||
|
||||
config = vllm_config.model_config.hf_text_config
|
||||
cache_config = vllm_config.cache_config
|
||||
quant_config = vllm_config.quant_config
|
||||
|
||||
self.hidden_size = config.hidden_size
|
||||
max_position_embeddings = getattr(config, "max_position_embeddings", 8192)
|
||||
dual_chunk_attention_config = getattr(
|
||||
config, "dual_chunk_attention_config", None
|
||||
)
|
||||
|
||||
layer_idx = extract_layer_index(prefix)
|
||||
layer_type = config.layer_types[layer_idx]
|
||||
if layer_type == "sliding_attention":
|
||||
sliding_window = getattr(config, "sliding_window", None)
|
||||
else:
|
||||
sliding_window = None
|
||||
rope_parameters = config.rope_parameters[layer_type]
|
||||
|
||||
self.self_attn = MellumAttention(
|
||||
hidden_size=self.hidden_size,
|
||||
num_heads=config.num_attention_heads,
|
||||
num_kv_heads=config.num_key_value_heads,
|
||||
rope_parameters=rope_parameters,
|
||||
max_position_embeddings=max_position_embeddings,
|
||||
rms_norm_eps=config.rms_norm_eps,
|
||||
qkv_bias=getattr(config, "attention_bias", False),
|
||||
head_dim=getattr(config, "head_dim", None),
|
||||
cache_config=cache_config,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.self_attn",
|
||||
dual_chunk_attention_config=dual_chunk_attention_config,
|
||||
per_layer_sliding_window=sliding_window,
|
||||
)
|
||||
|
||||
if config.mlp_layer_types[layer_idx] == "sparse":
|
||||
self.mlp = Qwen3MoeSparseMoeBlock(
|
||||
vllm_config=vllm_config, prefix=f"{prefix}.mlp"
|
||||
)
|
||||
else:
|
||||
self.mlp = Qwen3MoeMLP(
|
||||
hidden_size=config.hidden_size,
|
||||
intermediate_size=config.intermediate_size,
|
||||
hidden_act=config.hidden_act,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.mlp",
|
||||
)
|
||||
|
||||
self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
self.post_attention_layernorm = RMSNorm(
|
||||
config.hidden_size, eps=config.rms_norm_eps
|
||||
)
|
||||
|
||||
|
||||
@support_torch_compile
|
||||
class MellumModel(Qwen3MoeModel):
|
||||
"""
|
||||
Differences from `Qwen3MoeModel`:
|
||||
- Uses `MellumDecoderLayer`.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
vllm_config: VllmConfig,
|
||||
prefix: str = "",
|
||||
):
|
||||
super().__init__(
|
||||
vllm_config=vllm_config,
|
||||
prefix=prefix,
|
||||
decoder_layer_type=MellumDecoderLayer,
|
||||
)
|
||||
|
||||
|
||||
class MellumForCausalLM(Qwen3MoeForCausalLM):
|
||||
"""
|
||||
Differences from `Qwen3MoeForCausalLM`:
|
||||
- Uses `MellumModel`.
|
||||
"""
|
||||
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||
nn.Module.__init__(self)
|
||||
config = vllm_config.model_config.hf_text_config
|
||||
quant_config = vllm_config.quant_config
|
||||
self.config = config
|
||||
self.quant_config = quant_config
|
||||
if "dense" in getattr(config, "mlp_layer_types", []):
|
||||
self.packed_modules_mapping["gate_up_proj"] = ["gate_proj", "up_proj"]
|
||||
self.model = MellumModel(
|
||||
vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model")
|
||||
)
|
||||
self.lm_head = ParallelLMHead(
|
||||
config.vocab_size,
|
||||
config.hidden_size,
|
||||
quant_config=quant_config,
|
||||
prefix=maybe_prefix(prefix, "lm_head"),
|
||||
)
|
||||
if self.config.tie_word_embeddings:
|
||||
self.lm_head.weight = self.model.embed_tokens.weight
|
||||
self.logits_processor = LogitsProcessor(config.vocab_size)
|
||||
self.make_empty_intermediate_tensors = (
|
||||
self.model.make_empty_intermediate_tensors
|
||||
)
|
||||
|
||||
self.expert_weights = []
|
||||
|
||||
self.moe_layers = []
|
||||
example_layer = None
|
||||
for layer in self.model.layers:
|
||||
if isinstance(layer, PPMissingLayer):
|
||||
continue
|
||||
|
||||
assert isinstance(layer, Qwen3MoeDecoderLayer)
|
||||
if isinstance(layer.mlp, Qwen3MoeSparseMoeBlock):
|
||||
example_layer = layer.mlp
|
||||
self.moe_layers.append(layer.mlp.experts)
|
||||
|
||||
if example_layer is None:
|
||||
raise RuntimeError("No MoE layer found in the model.layers.")
|
||||
|
||||
self.num_moe_layers = len(self.moe_layers)
|
||||
self.num_expert_groups = 1
|
||||
self.num_shared_experts = 0
|
||||
self.num_logical_experts = example_layer.n_logical_experts
|
||||
self.num_physical_experts = example_layer.n_physical_experts
|
||||
self.num_local_physical_experts = example_layer.n_local_physical_experts
|
||||
self.num_routed_experts = example_layer.n_routed_experts
|
||||
self.num_redundant_experts = example_layer.n_redundant_experts
|
||||
@@ -159,6 +159,7 @@ _TEXT_GENERATION_MODELS = {
|
||||
"LongcatFlashForCausalLM": ("longcat_flash", "LongcatFlashForCausalLM"),
|
||||
"MambaForCausalLM": ("mamba", "MambaForCausalLM"),
|
||||
"Mamba2ForCausalLM": ("mamba2", "Mamba2ForCausalLM"),
|
||||
"MellumForCausalLM": ("mellum", "MellumForCausalLM"),
|
||||
"MiniCPMForCausalLM": ("minicpm", "MiniCPMForCausalLM"),
|
||||
"MiniCPM3ForCausalLM": ("minicpm3", "MiniCPM3ForCausalLM"),
|
||||
"MiniMaxForCausalLM": ("minimax_text_01", "MiniMaxText01ForCausalLM"),
|
||||
|
||||
@@ -116,6 +116,7 @@ _CONFIG_REGISTRY: dict[str, type[PretrainedConfig]] = LazyConfigDict(
|
||||
RefinedWebModel="RWConfig", # For tiiuae/falcon-7b(-instruct)
|
||||
mlp_speculator="MLPSpeculatorConfig",
|
||||
medusa="MedusaConfig",
|
||||
mellum="MellumConfig",
|
||||
midashenglm="MiDashengLMConfig",
|
||||
moondream3="Moondream3Config",
|
||||
eagle="EAGLEConfig",
|
||||
|
||||
@@ -49,6 +49,7 @@ _CLASS_TO_MODULE: dict[str, str] = {
|
||||
"LagunaConfig": "vllm.transformers_utils.configs.laguna",
|
||||
"Lfm2MoeConfig": "vllm.transformers_utils.configs.lfm2_moe",
|
||||
"MedusaConfig": "vllm.transformers_utils.configs.medusa",
|
||||
"MellumConfig": "vllm.transformers_utils.configs.mellum",
|
||||
"MiDashengLMConfig": "vllm.transformers_utils.configs.midashenglm",
|
||||
"MLPSpeculatorConfig": "vllm.transformers_utils.configs.mlp_speculator",
|
||||
"Moondream3Config": "vllm.transformers_utils.configs.moondream3",
|
||||
@@ -117,6 +118,7 @@ __all__ = [
|
||||
"LagunaConfig",
|
||||
"Lfm2MoeConfig",
|
||||
"MedusaConfig",
|
||||
"MellumConfig",
|
||||
"MiDashengLMConfig",
|
||||
"MLPSpeculatorConfig",
|
||||
"Moondream3Config",
|
||||
|
||||
@@ -0,0 +1,7 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from transformers import Qwen3MoeConfig
|
||||
|
||||
|
||||
class MellumConfig(Qwen3MoeConfig):
|
||||
model_type = "mellum"
|
||||
Reference in New Issue
Block a user