[None][feat] Eagle: MLA Based Eagle (#9677)

Signed-off-by: Izzy Putterman <iputterman@nvidia.com>
This commit is contained in:
Izzy Putterman 2026-01-02 10:45:07 -08:00 committed by GitHub
parent f3dd6da080
commit bdf6953ddc
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 322 additions and 69 deletions

View File

@@ -1377,13 +1377,13 @@ class DeepseekV3DecoderLayer(DecoderLayer):
hidden_states, residual = self.moe_allreduce(
fc2_output, all_reduce_params=moe_all_reduce_params)
else:
if self.next_layer_layernorm is not None:
hidden_states, residual = self.next_layer_layernorm(
hidden_states, residual)
if spec_metadata is not None and spec_metadata.is_layer_capture(
self.layer_idx):
spec_metadata.maybe_capture_hidden_states(
self.layer_idx, hidden_states, None)
self.layer_idx, hidden_states, residual)
if self.next_layer_layernorm is not None:
hidden_states, residual = self.next_layer_layernorm(
hidden_states, residual)
return hidden_states, residual
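This reorders the draft-capture hook in the target DeepseekV3 decoder layer: hidden states are now recorded before next_layer_layernorm is applied, and the residual is recorded alongside them instead of None. A minimal sketch of what such a capture hook could look like; the class name and storage layout here are illustrative, not the actual Eagle3SpecMetadata implementation:

import torch
from typing import Dict, Optional, Tuple

class CaptureStore:
    """Illustrative per-layer capture of (hidden_states, residual) for a draft model."""

    def __init__(self, layers_to_capture: Tuple[int, ...]):
        self.layers_to_capture = set(layers_to_capture)
        self.captured: Dict[int, Tuple[torch.Tensor, Optional[torch.Tensor]]] = {}

    def is_layer_capture(self, layer_idx: int) -> bool:
        return layer_idx in self.layers_to_capture

    def maybe_capture_hidden_states(self, layer_idx: int,
                                    hidden_states: torch.Tensor,
                                    residual: Optional[torch.Tensor]) -> None:
        # Capture before the next layer's input layernorm, so the draft
        # model consumes the un-normalized target activations.
        if self.is_layer_capture(layer_idx):
            self.captured[layer_idx] = (hidden_states, residual)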

View File

@@ -8,7 +8,7 @@ from ...functional import PositionEmbeddingType
from ..attention_backend import AttentionMetadata
from ..attention_backend.interface import PositionalEmbeddingParams, RopeParams
from ..model_config import ModelConfig, TConfig
from ..modules.attention import Attention
from ..modules.attention import MLA, Attention
from ..modules.decoder_layer import DecoderLayer
from ..modules.embedding import Embedding
from ..modules.fused_moe import moe_load_balancer_set_repeated_for_next_layer
@@ -76,14 +76,81 @@ class Eagle3Attention(Attention):
)
class Eagle3DecoderLayer(DecoderLayer):
class Eagle3MLAttention(MLA):
"""
MLA (Multi-head Latent Attention) for Eagle3 draft model (e.g., DeepSeekV3).
The first layer takes concatenated [embeds, hidden_states] as input (2x hidden_size),
while subsequent layers take regular hidden_states (1x hidden_size).
"""
def __init__(
self,
model_config: LlamaConfig,
model_config: ModelConfig[PretrainedConfig],
layer_idx: Optional[int] = None,
aux_stream: Optional[torch.cuda.Stream] = None,
next_layer_regular: bool = False,
):
config = model_config.pretrained_config
self._next_layer_regular = next_layer_regular
predicted_tokens_per_seq = (
model_config.spec_config.max_total_draft_tokens +
1 if model_config.spec_config is not None else 1)
super().__init__(
hidden_size=config.hidden_size,
num_attention_heads=config.num_attention_heads,
num_key_value_heads=config.num_key_value_heads,
qk_rope_head_dim=config.qk_rope_head_dim,
qk_nope_head_dim=config.qk_nope_head_dim,
q_lora_rank=config.q_lora_rank,
kv_lora_rank=config.kv_lora_rank,
v_head_dim=config.v_head_dim,
predicted_tokens_per_seq=predicted_tokens_per_seq,
max_position_embeddings=config.max_position_embeddings,
bias=False,
pos_embd_params=PositionalEmbeddingParams(
type=PositionEmbeddingType.yarn,
rope=RopeParams.from_config(config),
is_neox=False,
),
layer_idx=layer_idx,
dtype=config.torch_dtype,
config=model_config,
aux_stream=aux_stream,
)
# Override the kv_a_proj_with_mqa projection for the first layer.
# The number of input features is twice as big for EAGLE3 draft models.
if not self._next_layer_regular:
quant_config = model_config.get_quant_config()
# For Eagle3, first layer takes [embeds, hidden_states] concatenated
self.kv_a_proj_with_mqa = Linear(
2 * config.hidden_size, # Double input size
self.q_lora_rank + self.kv_lora_rank + self.qk_rope_head_dim,
bias=False,
dtype=config.torch_dtype,
quant_config=quant_config,
skip_create_weights_in_init=model_config.
skip_create_weights_in_init,
use_custom_cublas_mm=True,
)
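Because the first draft layer consumes the concatenation of token embeddings and the captured target hidden states, the kv_a_proj_with_mqa above is rebuilt with a doubled input width. A rough shape sketch, with a plain nn.Linear standing in for the quantization-aware Linear and hypothetical sizes:

import torch

hidden_size = 2560          # draft hidden size (hypothetical)
q_lora_rank, kv_lora_rank, qk_rope_head_dim = 1536, 512, 64

# First layer: [embeds, hidden_states] concatenated along the last dim.
embeds = torch.randn(8, hidden_size)
target_hidden = torch.randn(8, hidden_size)
first_layer_input = torch.cat([embeds, target_hidden], dim=-1)   # (8, 2 * hidden_size)

proj = torch.nn.Linear(2 * hidden_size,
                       q_lora_rank + kv_lora_rank + qk_rope_head_dim,
                       bias=False)
latent = proj(first_layer_input)   # (8, q_lora_rank + kv_lora_rank + qk_rope_head_dim)

# Subsequent layers keep the regular 1x hidden_size input width.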
class Eagle3DecoderLayer(DecoderLayer):
"""
Unified decoder layer for Eagle3 speculative decoding.
Supports both standard attention (Llama-style) and MLA (DeepSeekV3-style).
"""
def __init__(
self,
model_config: ModelConfig[PretrainedConfig],
layer_idx: int = 0,
is_first_layer: bool = True,
) -> Tuple[torch.Tensor, torch.Tensor]:
use_mla: bool = False,
aux_stream: Optional[torch.cuda.Stream] = None,
) -> None:
super().__init__()
config = model_config.pretrained_config
eagle_config = config.eagle_config if hasattr(config,
@@ -92,8 +159,18 @@ class Eagle3DecoderLayer(DecoderLayer):
self._next_layer_regular = (eagle_config.get("next_layer_regular", True)
and not is_first_layer) or eagle_config.get(
"eh_proj_before_attn", False)
self.self_attn = Eagle3Attention(model_config, layer_idx,
self._next_layer_regular)
# Select attention type based on config
if use_mla:
self.self_attn = Eagle3MLAttention(
model_config,
layer_idx,
aux_stream=aux_stream,
next_layer_regular=self._next_layer_regular,
)
else:
self.self_attn = Eagle3Attention(model_config, layer_idx,
self._next_layer_regular)
if config.model_type == "llama4_text":
inter_size = config.intermediate_size_mlp
@@ -109,18 +186,25 @@ class Eagle3DecoderLayer(DecoderLayer):
overridden_tp_size=1
if model_config.mapping.enable_attention_dp else None,
)
if not self._next_layer_regular:
self.input_layernorm = RMSNorm(hidden_size=config.hidden_size,
eps=config.rms_norm_eps,
dtype=config.torch_dtype)
self.input_layernorm = RMSNorm(
hidden_size=config.hidden_size,
eps=config.rms_norm_eps,
dtype=config.torch_dtype,
)
self.hidden_norm = RMSNorm(hidden_size=config.hidden_size,
eps=config.rms_norm_eps,
dtype=config.torch_dtype)
self.hidden_norm = RMSNorm(
hidden_size=config.hidden_size,
eps=config.rms_norm_eps,
dtype=config.torch_dtype,
)
self.post_attention_layernorm = RMSNorm(hidden_size=config.hidden_size,
eps=config.rms_norm_eps,
dtype=config.torch_dtype)
self.post_attention_layernorm = RMSNorm(
hidden_size=config.hidden_size,
eps=config.rms_norm_eps,
dtype=config.torch_dtype,
)
def forward(
self,
@@ -157,11 +241,16 @@ class Eagle3DecoderLayer(DecoderLayer):
class Eagle3DraftModel(DecoderModel):
"""
Unified Eagle3 draft model supporting both standard attention (Llama-style)
and MLA attention (DeepSeekV3-style).
"""
def __init__(
self,
model_config: LlamaConfig,
model_config: ModelConfig[PretrainedConfig],
start_layer_idx: int = 0,
use_mla: bool = False,
) -> None:
super().__init__(model_config)
@@ -175,6 +264,7 @@ class Eagle3DraftModel(DecoderModel):
self.num_layers = model_config.pretrained_config.num_hidden_layers
self._eh_proj_before_attn = eagle_config.get("eh_proj_before_attn",
False)
self._use_mla = use_mla
if hasattr(config, "target_hidden_size"):
self.hidden_size_in = config.target_hidden_size
@@ -184,40 +274,60 @@ class Eagle3DraftModel(DecoderModel):
self._return_hidden_post_norm = eagle_config.get(
"return_hidden_post_norm", False)
# Create an auxiliary CUDA stream for MLA operations (only needed when use_mla is True)
self.aux_stream = torch.cuda.Stream() if use_mla else None
if self.spec_config.num_capture_layers > 1:
self.fc = Linear(self.hidden_size_in *
self.spec_config.num_capture_layers,
config.hidden_size,
bias=getattr(config, "bias", False),
dtype=config.torch_dtype)
self.fc = Linear(
self.hidden_size_in * self.spec_config.num_capture_layers,
config.hidden_size,
bias=getattr(config, "bias", False),
dtype=config.torch_dtype,
)
if self.num_layers > 1:
self.midlayer = nn.ModuleList([
Eagle3DecoderLayer(model_config,
start_layer_idx + i,
is_first_layer=(i == 0))
for i in range(self.num_layers)
Eagle3DecoderLayer(
model_config,
start_layer_idx + i,
is_first_layer=(i == 0),
use_mla=use_mla,
aux_stream=self.aux_stream,
) for i in range(self.num_layers)
])
else:
self.midlayer = Eagle3DecoderLayer(model_config, start_layer_idx)
self.midlayer = Eagle3DecoderLayer(
model_config,
start_layer_idx,
use_mla=use_mla,
aux_stream=self.aux_stream,
)
self.norm = RMSNorm(hidden_size=config.hidden_size,
eps=config.rms_norm_eps,
dtype=config.torch_dtype)
self.norm = RMSNorm(
hidden_size=config.hidden_size,
eps=config.rms_norm_eps,
dtype=config.torch_dtype,
)
if (config.draft_vocab_size is not None
and config.vocab_size != config.draft_vocab_size):
self.d2t = nn.Parameter(
torch.empty((config.draft_vocab_size, ), dtype=torch.int32),
requires_grad=False,
)
if config.draft_vocab_size is not None and config.vocab_size != config.draft_vocab_size:
self.d2t = nn.Parameter(torch.empty((config.draft_vocab_size, ),
dtype=torch.int32),
requires_grad=False)
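When the draft vocabulary differs from the target vocabulary, the d2t buffer above translates draft-vocab token ids back into target-vocab ids; in the usual EAGLE-3 convention it stores per-token offsets. A hedged sketch of that lookup (the actual call site lives in the sampling path, outside this diff; sizes are hypothetical):

import torch

draft_vocab_size = 32000                                  # hypothetical
d2t = torch.zeros(draft_vocab_size, dtype=torch.int32)    # offset table, loaded from the checkpoint

draft_token_ids = torch.tensor([5, 17, 123])
# draft id + offset -> target id (assuming the offset-table convention)
target_token_ids = draft_token_ids + d2t[draft_token_ids]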
if self._eh_proj_before_attn:
self.enorm = RMSNorm(hidden_size=config.hidden_size,
eps=config.rms_norm_eps,
dtype=config.torch_dtype)
self.eh_proj = nn.Linear(config.hidden_size * 2,
config.hidden_size,
bias=eagle_config.get(
"eh_proj_bias", False),
dtype=config.torch_dtype)
self.enorm = RMSNorm(
hidden_size=config.hidden_size,
eps=config.rms_norm_eps,
dtype=config.torch_dtype,
)
self.eh_proj = nn.Linear(
config.hidden_size * 2,
config.hidden_size,
bias=eagle_config.get("eh_proj_bias", False),
dtype=config.torch_dtype,
)
if self.hidden_size_in != config.hidden_size:
if model_config.mapping.enable_attention_dp:
@@ -261,7 +371,7 @@ class Eagle3DraftModel(DecoderModel):
assert hidden_states is not None
# NOTE: If hidden states from the target model have to be concatenated,
# ideally, we expect that to happen outside the model definition. This
# helps usavoid data-dependent control flow and gives us better CUDA
# helps us avoid data-dependent control flow and gives us better CUDA
# graph coverage.
if self._eh_proj_before_attn:
input_embeds = self.enorm(inputs_embeds)
@@ -273,17 +383,21 @@ class Eagle3DraftModel(DecoderModel):
for layer in self.midlayer:
if residual is not None:
hidden_states = hidden_states + residual
hidden_states, residual = layer(position_ids=position_ids,
embeds=inputs_embeds,
hidden_states=hidden_states,
attn_metadata=attn_metadata,
spec_metadata=spec_metadata)
hidden_states, residual = layer(
position_ids=position_ids,
embeds=inputs_embeds,
hidden_states=hidden_states,
attn_metadata=attn_metadata,
spec_metadata=spec_metadata,
)
else:
hidden_states, residual = self.midlayer(position_ids=position_ids,
embeds=inputs_embeds,
hidden_states=hidden_states,
attn_metadata=attn_metadata,
spec_metadata=spec_metadata)
hidden_states, residual = self.midlayer(
position_ids=position_ids,
embeds=inputs_embeds,
hidden_states=hidden_states,
attn_metadata=attn_metadata,
spec_metadata=spec_metadata,
)
hidden_states, hidden_states_to_save = self.norm(
hidden_states, residual)
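The final self.norm call follows the fused add-and-norm pattern used throughout these models: given a residual, it returns both the normalized activations and the pre-norm sum, and the latter is what is kept as hidden_states_to_save. A minimal sketch of that two-output contract, assuming a plain RMSNorm:

import torch

def fused_rms_norm(x: torch.Tensor, residual: torch.Tensor,
                   weight: torch.Tensor, eps: float = 1e-5):
    # Add the residual first, then normalize; return both tensors.
    pre_norm = x + residual
    variance = pre_norm.pow(2).mean(-1, keepdim=True)
    normed = pre_norm * torch.rsqrt(variance + eps) * weight
    return normed, pre_norm   # (hidden_states, hidden_states_to_save)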
@@ -294,20 +408,36 @@ class Eagle3DraftModel(DecoderModel):
# We use Llama3 as the base architecture for EAGLE3 draft layers
@register_auto_model("EAGLE3LlamaForCausalLM")
class Eagle3ForCausalLM(DecoderModelForCausalLM[Eagle3DraftModel, LlamaConfig]):
@register_auto_model("Eagle3DeepSeekV3ForCausalLM")
class Eagle3ForCausalLM(DecoderModelForCausalLM[Eagle3DraftModel,
PretrainedConfig]):
def __init__(
self,
model_config: LlamaConfig,
model_config: ModelConfig[PretrainedConfig],
start_layer_idx: int = 0,
):
draft_vocab_size = model_config.pretrained_config.vocab_size
if model_config.pretrained_config.draft_vocab_size is not None:
draft_vocab_size = model_config.pretrained_config.draft_vocab_size
super().__init__(Eagle3DraftModel(model_config, start_layer_idx),
config=model_config,
hidden_size=model_config.pretrained_config.hidden_size,
vocab_size=draft_vocab_size)
# Determine if we should use MLA attention based on config
# MLA is used for DeepSeekV3-style models that have kv_lora_rank
config = model_config.pretrained_config
self._use_mla = hasattr(config, 'kv_lora_rank') and config.kv_lora_rank
draft_model = Eagle3DraftModel(
model_config,
start_layer_idx,
use_mla=self._use_mla,
)
super().__init__(
draft_model,
config=model_config,
hidden_size=model_config.pretrained_config.hidden_size,
vocab_size=draft_vocab_size,
)
self.load_lm_head_from_target = True
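The use_mla switch above keys off kv_lora_rank, which only DeepSeekV3-style configs define (Llama-style configs lack the attribute). A tiny self-contained check of that heuristic, with stand-in configs:

from types import SimpleNamespace

deepseek_like = SimpleNamespace(kv_lora_rank=512, q_lora_rank=1536)
llama_like = SimpleNamespace()  # no kv_lora_rank attribute

def wants_mla(config) -> bool:
    # Mirrors the hasattr + truthiness check used in Eagle3ForCausalLM.__init__.
    return bool(getattr(config, "kv_lora_rank", None))

assert wants_mla(deepseek_like) is True
assert wants_mla(llama_like) is False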
def forward(
@@ -339,6 +469,7 @@ class Eagle3ForCausalLM(DecoderModelForCausalLM[Eagle3DraftModel, LlamaConfig]):
)
def load_weights(self, weights: Dict, weight_mapper: BaseWeightMapper):
new_weights = {}
for k, v in weights.items():
if 'lm_head' not in k:
@@ -347,13 +478,24 @@ class Eagle3ForCausalLM(DecoderModelForCausalLM[Eagle3DraftModel, LlamaConfig]):
self.load_lm_head_from_target = False
new_k = k
new_weights[new_k] = v
if self.load_lm_head_from_target:
super().load_weights(weights=new_weights,
weight_mapper=weight_mapper,
skip_modules=['lm_head'])
if self._use_mla:
# Use DeepseekV3WeightLoader for proper MLA weight handling
from .modeling_deepseekv3 import DeepseekV3WeightLoader
weight_loader = DeepseekV3WeightLoader(self, is_draft_model=False)
if self.load_lm_head_from_target:
weight_loader.load_weights(new_weights,
skip_modules=['lm_head'])
else:
weight_loader.load_weights(new_weights)
else:
super().load_weights(weights=new_weights,
weight_mapper=weight_mapper)
if self.load_lm_head_from_target:
super().load_weights(weights=new_weights,
weight_mapper=weight_mapper,
skip_modules=['lm_head'])
else:
super().load_weights(weights=new_weights,
weight_mapper=weight_mapper)
def load_weights_from_target_model(self,
target_model: torch.nn.Module) -> None:
@@ -704,6 +846,7 @@ def get_draft_model(model_config, draft_config, lm_head, model):
spec_dec_mode = model_config.spec_config.spec_dec_mode
if spec_dec_mode.is_eagle3_one_model():
if model_config.spec_config.eagle3_model_arch == "llama3":
# Eagle3ForCausalLM handles both Llama3 and DeepSeekV3 architectures
return Eagle3ForCausalLM(
draft_config, model_config.pretrained_config.num_hidden_layers)
elif model_config.spec_config.eagle3_model_arch == "mistral_large3":
@@ -714,6 +857,7 @@ def get_draft_model(model_config, draft_config, lm_head, model):
raise ValueError(
f"Unsupported eagle3 model architecture: {spec_dec_mode.eagle3_model_arch}"
)
elif spec_dec_mode.is_mtp_one_model():
return MTPForCausalLM(model_config,
model_config.pretrained_config.num_hidden_layers,

View File

@@ -33,6 +33,7 @@ class LazyConfigDict(dict):
_CONFIG_REGISTRY: dict[str, type[transformers.PretrainedConfig]] = LazyConfigDict(
deepseek_v32="DeepseekV3Config",
kimi_k2="DeepseekV3Config",
) # NOTE: HF config.json uses deepseek_v32 as model_type but with the same DSV3 config class
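With this entry, checkpoints whose config.json declares model_type "kimi_k2" are parsed with the same DeepseekV3-style config class as "deepseek_v32". A plain-dict approximation of the lazy registry (the real LazyConfigDict resolves the class on first access):

import transformers

# Plain-dict approximation of LazyConfigDict: both model_type strings map
# to the DeepseekV3-style config class.
_APPROX_REGISTRY = {
    "deepseek_v32": "DeepseekV3Config",
    "kimi_k2": "DeepseekV3Config",
}

def resolve_config_class(model_type: str):
    cls_name = _APPROX_REGISTRY.get(model_type)
    # Guarded getattr: DeepseekV3Config only exists in recent transformers releases.
    return getattr(transformers, cls_name, None) if cls_name else None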

View File

@@ -133,7 +133,7 @@ class Eagle3SpecMetadata(SpecMetadata):
self.layers_to_capture = (self.num_layers - 1, )
elif self.layers_to_capture is None:
if self.num_layers == 1 or self.is_mtp_eagle:
self.layers_to_capture = (self.num_layers - 1, )
self.layers_to_capture = (-1, )
else:
if self.num_layers <= 5:
raise ValueError(

View File

@@ -462,6 +462,114 @@ def test_deepseek_eagle3():
pass
def test_deepseek_mla_eagle3():
use_cuda_graph = True
attn_backend = "TRTLLM"
disable_overlap_scheduler = False
enable_block_reuse = False
use_one_model = True
enable_chunked_prefill = False
# Eagle3 one-model works with the overlap scheduler and block reuse.
total_mem_gb = torch.cuda.get_device_properties(0).total_memory / 1e9
if total_mem_gb < 150:
pytest.skip("Not enough memory to load target + draft model")
models_path = llm_models_root()
eagle_config = {
"architectures": ["Eagle3DeepseekV3ForCausalLM"],
"attention_bias": False,
"attention_dropout": 0.0,
"first_k_dense_replace": 1,
"hidden_act": "silu",
"hidden_size": 2560,
"intermediate_size": 8192,
"kv_lora_rank": 512,
"max_position_embeddings": 4096,
"model_type": "kimi_k2",
"num_attention_heads": 32,
"num_hidden_layers": 1,
"num_key_value_heads": 32,
"num_nextn_predict_layers": 0,
"q_lora_rank": 1536,
"qk_nope_head_dim": 128,
"qk_rope_head_dim": 64,
"rms_norm_eps": 1e-05,
"rope_scaling": {
"beta_fast": 1.0,
"beta_slow": 1.0,
"factor": 64.0,
"mscale": 1.0,
"mscale_all_dim": 1.0,
"original_max_position_embeddings": 4096,
"type": "yarn"
},
"rope_theta": 50000.0,
"routed_scaling_factor": 2.827,
"scoring_func": "sigmoid",
"seq_aux": True,
"topk_group": 1,
"topk_method": "noaux_tc",
"torch_dtype": "bfloat16",
"torchscript": False,
"transformers_version": "4.51.3",
"use_bfloat16": False,
"use_cache": True,
"v_head_dim": 128,
"vocab_size": 129280,
"draft_vocab_size": 129280,
"eagle_config": {
"use_aux_hidden_state": True,
"use_input_layernorm_in_first_layer": True,
"use_last_layernorm": True,
"use_mtp_layernorm": False
}
}
with tempfile.TemporaryDirectory() as temp_dir:
eagle_model_dir = Path(temp_dir)
config_path = eagle_model_dir / "config.json"
with config_path.open("w") as f:
json.dump(eagle_config, f, indent=2)
target_model_dir = f"{models_path}/DeepSeek-V3-Lite/nvfp4_moe_only"
# bs > 1 gives non-deterministic results when doing IFB. There is a slight
# chance that ref and spec do not match 100%.
max_batch_size = 16
max_draft_len = 3
kv_cache_config = KvCacheConfig(enable_block_reuse=enable_block_reuse,
max_tokens=8192)
cuda_graph_config = CudaGraphConfig(
batch_sizes=[1]) if use_cuda_graph else None
llm_common_config = dict(
model=target_model_dir,
attn_backend=attn_backend,
disable_overlap_scheduler=disable_overlap_scheduler,
cuda_graph_config=cuda_graph_config,
max_batch_size=max_batch_size,
max_num_tokens=4096,
max_seq_len=4096,
kv_cache_config=kv_cache_config,
enable_chunked_prefill=enable_chunked_prefill,
load_format="dummy",
)
spec_config = EagleDecodingConfig(max_draft_len=max_draft_len,
speculative_model_dir=eagle_model_dir,
eagle3_one_model=use_one_model,
load_format="dummy")
llm_spec = LLM(**llm_common_config, speculative_config=spec_config)
tok_ids = llm_spec.tokenizer.encode("The future of AI is")
sampling_params = SamplingParams(max_tokens=32, temperature=0)
for output in llm_spec.generate_async(tok_ids,
sampling_params,
streaming=True):
pass
@pytest.mark.parametrize("use_one_model", [True, False])
def test_multi_eagle3(use_one_model: bool):
use_cuda_graph = True