[feat] Add EAGLE3 support for Qwen3 (#5206)

Signed-off-by: Yilin Fan <206948969+nv-yilinf@users.noreply.github.com>
commit 498fadceb4 (parent 517c1ecf72)
Author: Yilin Fan
Date:   2025-06-17 02:07:06 -07:00
Committed by: GitHub

@@ -18,11 +18,13 @@ from ..modules.fused_moe import (BaseMoeRoutingMethod, CutlassFusedMoE, MoE,
                                  RoutingMethodType, create_moe)
 from ..modules.linear import TensorParallelMode
 from ..modules.rms_norm import RMSNorm
+from ..speculative import SpecMetadata
 from ..utils import disable_fp4_allgather
 from .modeling_qwen3 import Qwen3Attention
-from .modeling_utils import (DecoderModel, DecoderModelForCausalLM,
-                             EagerFusionConfig, duplicate_kv_weight,
-                             filter_weights, register_auto_model)
+from .modeling_speculative import SpecDecOneEngineForCausalLM
+from .modeling_utils import (DecoderModel, EagerFusionConfig,
+                             duplicate_kv_weight, filter_weights,
+                             register_auto_model)
 
 class Qwen3Gate(nn.Module):
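
Note: the imports above bring in SpecMetadata from the speculative-decoding package and swap the base class used later in this file from DecoderModelForCausalLM to SpecDecOneEngineForCausalLM, which, judging by its name and the draft_model handling further down, keeps the target model and an optional draft model inside one engine. A rough sketch of that wrapper idea, under those assumptions (OneEngineSpecDecLM and its attributes are invented for the sketch, not TensorRT-LLM's implementation):

# Illustrative only: a minimal "one engine" speculative-decoding wrapper in the
# spirit of SpecDecOneEngineForCausalLM. Names are invented for this sketch.
from typing import Generic, Optional, TypeVar

from torch import nn

TModel = TypeVar("TModel", bound=nn.Module)
TConfig = TypeVar("TConfig")


class OneEngineSpecDecLM(nn.Module, Generic[TModel, TConfig]):

    def __init__(self, model: TModel, model_config: TConfig,
                 draft_model: Optional[nn.Module] = None):
        super().__init__()
        self.model = model              # target model, e.g. a Qwen3 MoE stack
        self.model_config = model_config
        self.draft_model = draft_model  # e.g. an EAGLE3 drafter; may be None

    def forward(self, *args, spec_metadata=None, **kwargs):
        # Run the target model; hidden states it records through spec_metadata
        # are what a draft model would consume to propose tokens.
        return self.model(*args, spec_metadata=spec_metadata, **kwargs)
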
@@ -203,6 +205,7 @@ class Qwen3MoEDecoderLayer(DecoderLayer):
         hidden_states: torch.Tensor,
         attn_metadata: AttentionMetadata,
         residual: Optional[torch.Tensor],
+        spec_metadata: Optional[SpecMetadata] = None,
         **kwargs,
     ) -> torch.Tensor:
         if residual is None:
@@ -240,6 +243,9 @@ class Qwen3MoEDecoderLayer(DecoderLayer):
                 enable_allreduce=not (self.fusion_config.POST_MOE_FUSION
                                       or self.mapping.tp_size == 1)))
+        if spec_metadata:
+            spec_metadata.maybe_capture_hidden_states(self.layer_idx,
+                                                      hidden_states, residual)
         if self.fusion_config.POST_MOE_FUSION:
             hidden_states, residual = self.allreduce(
                 hidden_states,
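
The maybe_capture_hidden_states call added above is the EAGLE3 hook: the draft model is fed hidden states recorded at a small set of target-model layers. A minimal, self-contained sketch of that capture pattern, assuming the metadata object simply stores states for a configured set of layer indices (CaptureMetadata and its fields are illustrative, not the real SpecMetadata API):

# Minimal sketch of the per-layer capture pattern; names are illustrative.
from typing import Dict, Optional, Set

import torch


class CaptureMetadata:

    def __init__(self, layers_to_capture: Set[int]):
        self.layers_to_capture = layers_to_capture
        self.captured: Dict[int, torch.Tensor] = {}

    def maybe_capture_hidden_states(
            self,
            layer_idx: int,
            hidden_states: torch.Tensor,
            residual: Optional[torch.Tensor] = None) -> None:
        # Only the configured layers record anything; the rest are no-ops.
        if layer_idx in self.layers_to_capture:
            out = hidden_states if residual is None else hidden_states + residual
            self.captured[layer_idx] = out.detach()


# Toy usage: record low/mid/high layers of a 32-layer stack, then concatenate
# them as the feature input of an EAGLE3-style draft head.
meta = CaptureMetadata(layers_to_capture={1, 16, 30})
for layer_idx in range(32):
    hidden = torch.randn(4, 8)  # stand-in for this layer's hidden states
    meta.maybe_capture_hidden_states(layer_idx, hidden)
draft_features = torch.cat([meta.captured[i] for i in sorted(meta.captured)],
                           dim=-1)  # shape (4, 24)
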
@@ -300,6 +306,7 @@ class Qwen3MoEModel(DecoderModel):
         input_ids: Optional[torch.IntTensor] = None,
         position_ids: Optional[torch.IntTensor] = None,
         inputs_embeds: Optional[torch.FloatTensor] = None,
+        spec_metadata: Optional[SpecMetadata] = None,
         **kwargs,
     ) -> torch.Tensor:
         if (input_ids is None) ^ (inputs_embeds is not None):
@@ -317,13 +324,14 @@ class Qwen3MoEModel(DecoderModel):
             hidden_states, residual = decoder_layer(position_ids=position_ids,
                                                     hidden_states=hidden_states,
                                                     attn_metadata=attn_metadata,
-                                                    residual=residual)
+                                                    residual=residual,
+                                                    spec_metadata=spec_metadata)
 
         return hidden_states
 
 
 @register_auto_model("Qwen3MoeForCausalLM")
-class Qwen3MoeForCausalLM(DecoderModelForCausalLM[Qwen3MoEModel,
-                                                  Qwen3MoeConfig]):
+class Qwen3MoeForCausalLM(SpecDecOneEngineForCausalLM[Qwen3MoEModel,
+                                                      Qwen3MoeConfig]):
 
     def __init__(
         self,
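
The two hunks above thread the same optional spec_metadata from Qwen3MoEModel.forward into every decoder layer and rebase the registered Qwen3MoeForCausalLM onto the speculative-decoding-aware wrapper. A toy version of that plumbing, with invented stand-in classes (not the real Qwen3MoEDecoderLayer/Qwen3MoEModel):

# Toy plumbing only: an optional spec_metadata flows from the model's forward
# into every decoder layer, mirroring the diff above.
import torch
from torch import nn


class ToyDecoderLayer(nn.Module):

    def __init__(self, layer_idx: int, hidden_size: int = 8):
        super().__init__()
        self.layer_idx = layer_idx
        self.proj = nn.Linear(hidden_size, hidden_size)

    def forward(self, hidden_states, residual=None, spec_metadata=None):
        if residual is None:
            residual = hidden_states
        hidden_states = self.proj(hidden_states)
        if spec_metadata is not None:
            # Same hook as above: the metadata object decides whether this
            # layer's output should be kept for the draft model.
            spec_metadata.maybe_capture_hidden_states(self.layer_idx,
                                                      hidden_states, residual)
        return hidden_states, residual


class ToyDecoderModel(nn.Module):

    def __init__(self, num_layers: int = 4):
        super().__init__()
        self.layers = nn.ModuleList(
            ToyDecoderLayer(i) for i in range(num_layers))

    def forward(self, hidden_states, spec_metadata=None):
        residual = None
        for layer in self.layers:
            hidden_states, residual = layer(hidden_states,
                                            residual=residual,
                                            spec_metadata=spec_metadata)
        return hidden_states


out = ToyDecoderModel()(torch.randn(2, 8))  # spec_metadata=None disables capture
# Any object exposing maybe_capture_hidden_states(layer_idx, hidden_states,
# residual) can be passed as spec_metadata to record per-layer states.
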
@@ -331,9 +339,7 @@ class Qwen3MoeForCausalLM(DecoderModelForCausalLM[Qwen3MoEModel,
     ):
         super().__init__(
             Qwen3MoEModel(model_config),
-            config=model_config,
-            hidden_size=model_config.pretrained_config.hidden_size,
-            vocab_size=model_config.pretrained_config.vocab_size,
+            model_config,
         )
 
     def load_weights(self, weights: Dict):
@@ -353,7 +359,7 @@ class Qwen3MoeForCausalLM(DecoderModelForCausalLM[Qwen3MoEModel,
             if len(module._parameters) > 0:
                 # skip load weights if tie word embeddings is enabled and layer is lm_head
                 if self.config.tie_word_embeddings and name.startswith(
-                        "lm_head"):
+                        "lm_head") or name.startswith("draft_model"):
                     continue
                 names = name.split(".")
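
The widened condition above also skips any module under the draft_model prefix when loading the target checkpoint, presumably because the EAGLE3 draft weights come from a separate draft checkpoint rather than from these weights. A toy illustration of that prefix-based filtering, with invented function and argument names:

# Toy illustration of the prefix-based skip; names are invented for the sketch.
from typing import List


def modules_loaded_from_target_checkpoint(module_names: List[str],
                                          tie_word_embeddings: bool) -> List[str]:
    loaded = []
    for name in module_names:
        # lm_head reuses the embedding when embeddings are tied, and anything
        # under draft_model.* is expected to get its weights elsewhere
        # (from the EAGLE3 draft checkpoint rather than the target one).
        if (tie_word_embeddings and name.startswith("lm_head")
                or name.startswith("draft_model")):
            continue
        loaded.append(name)
    return loaded


print(modules_loaded_from_target_checkpoint(
    ["model.embed_tokens", "lm_head", "draft_model.fc", "model.layers.0.mlp"],
    tie_word_embeddings=True))
# -> ['model.embed_tokens', 'model.layers.0.mlp']
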