From 1902d73eb5ba979e6a8dd0575b21201a44937ae5 Mon Sep 17 00:00:00 2001
From: Chang Liu
Date: Mon, 14 Apr 2025 11:56:42 -0700
Subject: [PATCH] fix: llama4: add an option `apply_router_weight_on_input` in
 FusedMoE (#3492)

* apply a tentative fix to moe bypass kernel update

* Pass None to disable final stage in moe

Co-authored-by: hlu1 <14827759+hlu1@users.noreply.github.com>
Signed-off-by: Chang Liu

---------

Signed-off-by: Chang Liu
Co-authored-by: hlu1 <14827759+hlu1@users.noreply.github.com>
---
 tensorrt_llm/_torch/models/modeling_llama.py |  3 ++-
 tensorrt_llm/_torch/modules/fused_moe.py     | 10 ++++++++++
 2 files changed, 12 insertions(+), 1 deletion(-)

diff --git a/tensorrt_llm/_torch/models/modeling_llama.py b/tensorrt_llm/_torch/models/modeling_llama.py
index 6fc99754b7..dffe91ecbc 100644
--- a/tensorrt_llm/_torch/models/modeling_llama.py
+++ b/tensorrt_llm/_torch/models/modeling_llama.py
@@ -119,7 +119,8 @@ class Llama4MoE(nn.Module):
             reduce_results=
             False,  # In both low latency and max-throughput scenarios, FusedMoE needs not to do allreduce inside op.
             weight_loading_mode=MoEWeightLoadingMode.FUSED_GATE_UP_PROJ,
-            model_config=model_config)
+            model_config=model_config,
+            apply_router_weight_on_input=True)
 
         self.shared_expert = GatedMLP(
             hidden_size=hidden_size,
diff --git a/tensorrt_llm/_torch/modules/fused_moe.py b/tensorrt_llm/_torch/modules/fused_moe.py
index f6ef183e57..81ee17ad78 100755
--- a/tensorrt_llm/_torch/modules/fused_moe.py
+++ b/tensorrt_llm/_torch/modules/fused_moe.py
@@ -240,6 +240,7 @@ class FusedMoE(nn.Module):
         aux_stream: torch.cuda.Stream = torch.cuda.Stream(),
         weight_loading_mode: MoEWeightLoadingMode = MoEWeightLoadingMode.
         VANILLA,
+        apply_router_weight_on_input: bool = False,
     ):
         from ..distributed import AllReduce
 
@@ -302,6 +303,9 @@ class FusedMoE(nn.Module):
         if not model_config.skip_create_weights:
             self.create_weights()
 
+        # If True, the router weight is multiplied into the input rather than applied at the end of FC2
+        self.apply_router_weight_on_input = apply_router_weight_on_input
+
     def setup_quant_scales(self):
         self.quant_scales = None
         if not self.has_any_quant:
@@ -558,6 +562,12 @@ class FusedMoE(nn.Module):
         assert token_final_scales.dtype == torch.float32
         assert token_selected_experts.dtype == torch.int32
 
+        if self.apply_router_weight_on_input:
+            assert self.routing_method.top_k == 1, "The current workaround only supports top-1 routing"
+            x = x * token_final_scales.to(x.dtype)
+            # TODO: remove this once the correct fused MoE kernel is ready
+            token_final_scales = None
+
         x_sf = None
         if self.has_any_quant:
             if self.has_fp8_qdq:
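
Editor's note: the sketch below is a minimal, hypothetical illustration of what this patch does, not TensorRT-LLM code; `moe_step`, `toy_expert`, and `scales` are made-up stand-ins for the selected expert and `token_final_scales`. With a single scalar router score per token (top-1 routing) and a linear projection, scaling the input before the expert gives the same result as scaling the expert's output, which is why the workaround asserts top_k == 1 and then passes None as token_final_scales to skip the final scaling stage.

    import torch

    def moe_step(x, toy_expert, scales, apply_router_weight_on_input):
        # Illustrative only: toy_expert stands in for the routed expert,
        # scales for the per-token router weight (token_final_scales).
        if apply_router_weight_on_input:
            # Workaround path from this patch: scale the activations before the
            # expert; the real kernel then receives token_final_scales=None.
            return toy_expert(x * scales.to(x.dtype))
        # Default path: run the expert first, then weight its output.
        return toy_expert(x) * scales.to(x.dtype)

    torch.manual_seed(0)
    x = torch.randn(4, 8)
    scales = torch.rand(4, 1)                     # one router score per token (top-1)
    toy_expert = torch.nn.Linear(8, 8, bias=False)
    out_pre = moe_step(x, toy_expert, scales, apply_router_weight_on_input=True)
    out_post = moe_step(x, toy_expert, scales, apply_router_weight_on_input=False)
    assert torch.allclose(out_pre, out_post, atol=1e-5)  # equivalent for a linear expert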