From 73ba4fc320575a7057c5514a95e72b4a9e376cc1 Mon Sep 17 00:00:00 2001 From: bhsueh_NV <11360707+byshiue@users.noreply.github.com> Date: Wed, 25 Jun 2025 09:20:23 +0800 Subject: [PATCH] fix: fix bug of qwen3 + eagle3 + finalize_moe_fusion (#5369) Signed-off-by: bhsueh <11360707+byshiue@users.noreply.github.com> --- tensorrt_llm/_torch/models/modeling_qwen3_moe.py | 14 +++++++++++--- 1 file changed, 11 insertions(+), 3 deletions(-) diff --git a/tensorrt_llm/_torch/models/modeling_qwen3_moe.py b/tensorrt_llm/_torch/models/modeling_qwen3_moe.py index cd94ec494f..db126b1341 100644 --- a/tensorrt_llm/_torch/models/modeling_qwen3_moe.py +++ b/tensorrt_llm/_torch/models/modeling_qwen3_moe.py @@ -263,11 +263,11 @@ class Qwen3MoEDecoderLayer(DecoderLayer): do_finalize=do_finalize, ) - if spec_metadata: - spec_metadata.maybe_capture_hidden_states(self.layer_idx, - hidden_states, residual) if self.fusion_config.POST_MOE_FUSION: if do_finalize: + if spec_metadata: + spec_metadata.maybe_capture_hidden_states( + self.layer_idx, hidden_states, residual) hidden_states, residual = self.allreduce( hidden_states, all_reduce_params=AllReduceParams( @@ -296,7 +296,15 @@ class Qwen3MoEDecoderLayer(DecoderLayer): ) hidden_states, residual = self.moe_allreduce( fc2_output, all_reduce_params=moe_all_reduce_params) + + if spec_metadata: + spec_metadata.maybe_capture_hidden_states( + self.layer_idx, hidden_states, residual) + else: + if spec_metadata: + spec_metadata.maybe_capture_hidden_states( + self.layer_idx, hidden_states, residual) if self.next_layer_layernorm is not None: hidden_states, residual = self.next_layer_layernorm( hidden_states, residual)