Qwen3: Fix eagle hidden states (#6199)

Signed-off-by: Izzy Putterman <iputterman@nvidia.com>
This commit is contained in:
Izzy Putterman 2025-08-06 14:05:18 -07:00 committed by GitHub
parent a16ba6445c
commit 7e0158b583
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -214,7 +214,9 @@ class Qwen3MoEDecoderLayer(DecoderLayer):
if residual is None:
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
if spec_metadata is not None and spec_metadata.is_layer_capture(
self.layer_idx):
self.fusion_config.POST_MOE_FUSION = False
# Self Attention
hidden_states = self.self_attn(
position_ids=position_ids,
@ -257,9 +259,6 @@ class Qwen3MoEDecoderLayer(DecoderLayer):
if self.fusion_config.POST_MOE_FUSION:
if do_finalize:
if spec_metadata:
spec_metadata.maybe_capture_hidden_states(
self.layer_idx, hidden_states, residual)
hidden_states, residual = self.allreduce(
hidden_states,
all_reduce_params=AllReduceParams(
@ -289,12 +288,8 @@ class Qwen3MoEDecoderLayer(DecoderLayer):
hidden_states, residual = self.moe_allreduce(
fc2_output, all_reduce_params=moe_all_reduce_params)
if spec_metadata:
spec_metadata.maybe_capture_hidden_states(
self.layer_idx, hidden_states, residual)
else:
if spec_metadata:
if spec_metadata and spec_metadata.is_layer_capture(self.layer_idx):
spec_metadata.maybe_capture_hidden_states(
self.layer_idx, hidden_states, residual)
if self.next_layer_layernorm is not None: