diff --git a/tensorrt_llm/_torch/models/modeling_qwen3_moe.py b/tensorrt_llm/_torch/models/modeling_qwen3_moe.py index cd94ec494f..db126b1341 100644 --- a/tensorrt_llm/_torch/models/modeling_qwen3_moe.py +++ b/tensorrt_llm/_torch/models/modeling_qwen3_moe.py @@ -263,11 +263,11 @@ class Qwen3MoEDecoderLayer(DecoderLayer): do_finalize=do_finalize, ) - if spec_metadata: - spec_metadata.maybe_capture_hidden_states(self.layer_idx, - hidden_states, residual) if self.fusion_config.POST_MOE_FUSION: if do_finalize: + if spec_metadata: + spec_metadata.maybe_capture_hidden_states( + self.layer_idx, hidden_states, residual) hidden_states, residual = self.allreduce( hidden_states, all_reduce_params=AllReduceParams( @@ -296,7 +296,15 @@ class Qwen3MoEDecoderLayer(DecoderLayer): ) hidden_states, residual = self.moe_allreduce( fc2_output, all_reduce_params=moe_all_reduce_params) + + if spec_metadata: + spec_metadata.maybe_capture_hidden_states( + self.layer_idx, hidden_states, residual) + else: + if spec_metadata: + spec_metadata.maybe_capture_hidden_states( + self.layer_idx, hidden_states, residual) if self.next_layer_layernorm is not None: hidden_states, residual = self.next_layer_layernorm( hidden_states, residual)