[#2511][fix] eagle: qwen2 capture hidden states (#10091)

Signed-off-by: SpicyNoodle <522169030@qq.com>
Xiao Xuan 2026-01-06 10:46:41 +08:00 committed by GitHub
parent 9cae7277ea
commit 46f035befe


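This change threads an optional spec_metadata argument from Qwen2ForCausalLM.forward through QwenModel into each QwenDecoderLayer, so EAGLE-style speculative decoding can record the hidden states of selected target-model layers. As a rough illustration only (not TensorRT-LLM's actual SpecMetadata; the layers_to_capture field, the captured dict, and the residual handling below are assumptions made for the sketch), the hook behaves roughly like this:

# Illustrative sketch only -- not the real SpecMetadata implementation.
from dataclasses import dataclass, field
from typing import Dict, List, Optional

import torch


@dataclass
class CaptureSketch:
    # Hypothetical: indices of the decoder layers whose outputs the draft
    # model wants to consume.
    layers_to_capture: List[int] = field(default_factory=list)
    # Hypothetical storage: layer index -> captured hidden states.
    captured: Dict[int, torch.Tensor] = field(default_factory=dict)

    def maybe_capture_hidden_states(self, layer_idx: int,
                                    hidden_states: torch.Tensor,
                                    residual: Optional[torch.Tensor]) -> None:
        # Do nothing unless this layer was requested by the speculative config.
        if layer_idx not in self.layers_to_capture:
            return
        # The decoder layers keep hidden_states and the pre-norm residual
        # separate, so fold the residual back in before storing (an
        # assumption in this sketch).
        out = hidden_states if residual is None else hidden_states + residual
        self.captured[layer_idx] = out

With a hook shaped like this, the per-layer call added by the diff (spec_metadata.maybe_capture_hidden_states(self.layer_idx, hidden_states, residual)) is a no-op for layers that were not requested, so the change stays inert unless speculative decoding is configured.
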
@@ -16,6 +16,7 @@ from ..modules.gated_mlp import GatedMLP
 from ..modules.linear import Linear, TensorParallelMode
 from ..modules.qk_norm_attention import QKNormRoPEAttention
 from ..modules.rms_norm import RMSNorm
+from ..speculative import SpecMetadata
 from .modeling_utils import (DecoderModel, DecoderModelForCausalLM,
                              register_auto_model)
@@ -148,6 +149,7 @@ class QwenDecoderLayer(DecoderLayer):
         attn_metadata: AttentionMetadata,
         residual: Optional[torch.Tensor],
         mrope_config: Optional[Tuple[torch.Tensor, int]] = None,
+        spec_metadata: Optional[SpecMetadata] = None,
         **kwargs,
     ) -> torch.Tensor:
         if residual is None:
@@ -170,6 +172,10 @@
         hidden_states, residual = self.post_attention_layernorm(
             hidden_states, residual)
         hidden_states = self.mlp(hidden_states)
+        if spec_metadata is not None:
+            spec_metadata.maybe_capture_hidden_states(self.layer_idx,
+                                                      hidden_states, residual)
         return hidden_states, residual
@@ -204,6 +210,7 @@ class QwenModel(DecoderModel):
         position_ids: Optional[torch.IntTensor] = None,
         inputs_embeds: Optional[torch.FloatTensor] = None,
         mrope_config: Optional[Tuple[torch.Tensor, int]] = None,
+        spec_metadata: Optional[SpecMetadata] = None,
         **kwargs,
     ) -> torch.Tensor:
         if (input_ids is None) ^ (inputs_embeds is not None):
@@ -222,7 +229,8 @@
                 hidden_states=hidden_states,
                 attn_metadata=attn_metadata,
                 residual=residual,
-                mrope_config=mrope_config)
+                mrope_config=mrope_config,
+                spec_metadata=spec_metadata)
         hidden_states, _ = self.norm(hidden_states, residual)
         return hidden_states
@@ -249,15 +257,15 @@ class Qwen2ForCausalLM(DecoderModelForCausalLM[QwenModel, Qwen2Config]):
         inputs_embeds: Optional[torch.FloatTensor] = None,
         return_context_logits: bool = False,
         mrope_config: Optional[dict] = None,
+        spec_metadata: Optional[SpecMetadata] = None,
         **kwargs,
     ) -> torch.Tensor:
-        output = self.model(
-            input_ids=input_ids,
-            attn_metadata=attn_metadata,
-            position_ids=position_ids,
-            inputs_embeds=inputs_embeds,
-            mrope_config=mrope_config,
-        )
+        output = self.model(input_ids=input_ids,
+                            attn_metadata=attn_metadata,
+                            position_ids=position_ids,
+                            inputs_embeds=inputs_embeds,
+                            mrope_config=mrope_config,
+                            spec_metadata=spec_metadata)
         return self.logits_processor.forward(
             output,