fix: llama4: add an option apply_router_weight_on_input in FusedMoE (#3492)

* apply a tentative fix to MoE to bypass the kernel update

* Pass None to disable the final scaling stage in MoE

Co-authored-by: hlu1 <14827759+hlu1@users.noreply.github.com>
Signed-off-by: Chang Liu <lc9114@gmail.com>

---------

Signed-off-by: Chang Liu <lc9114@gmail.com>
Co-authored-by: hlu1 <14827759+hlu1@users.noreply.github.com>
Chang Liu, 2025-04-14 11:56:42 -07:00 (committed via GitHub)
parent b286b51118
commit 1902d73eb5
2 changed files with 12 additions and 1 deletion
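
For context, the new flag only changes where the top-1 router weight is multiplied in: on the expert input instead of on the expert output after FC2. The following is a minimal, illustrative PyTorch sketch of the two orderings; the `expert` module, shapes, and weight value are made up for the example and are not the FusedMoE kernel.

import torch
import torch.nn as nn

# Hypothetical single expert and router score for one token (top-1 routing).
expert = nn.Sequential(nn.Linear(16, 64), nn.SiLU(), nn.Linear(64, 16))
x = torch.randn(1, 16)
router_weight = torch.tensor(0.7)

# Default behavior: the router weight scales the expert output (after FC2).
out_weight_on_output = expert(x) * router_weight

# apply_router_weight_on_input=True: the input is pre-scaled instead,
# which is the ordering the Llama 4 architecture expects.
out_weight_on_input = expert(x * router_weight)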


@@ -119,7 +119,8 @@ class Llama4MoE(nn.Module):
             reduce_results=
             False,  # In both low latency and max-throughput scenarios, FusedMoE needs not to do allreduce inside op.
             weight_loading_mode=MoEWeightLoadingMode.FUSED_GATE_UP_PROJ,
-            model_config=model_config)
+            model_config=model_config,
+            apply_router_weight_on_input=True)

         self.shared_expert = GatedMLP(
             hidden_size=hidden_size,

@@ -240,6 +240,7 @@ class FusedMoE(nn.Module):
         aux_stream: torch.cuda.Stream = torch.cuda.Stream(),
         weight_loading_mode: MoEWeightLoadingMode = MoEWeightLoadingMode.
         VANILLA,
+        apply_router_weight_on_input: bool = False,
     ):
         from ..distributed import AllReduce
@@ -302,6 +303,9 @@ class FusedMoE(nn.Module):
         if not model_config.skip_create_weights:
             self.create_weights()
+
+        # If True, the router weight will be multiplied on the input rather than at the end of FC2
+        self.apply_router_weight_on_input = apply_router_weight_on_input

     def setup_quant_scales(self):
         self.quant_scales = None
         if not self.has_any_quant:
@@ -558,6 +562,12 @@ class FusedMoE(nn.Module):
         assert token_final_scales.dtype == torch.float32
         assert token_selected_experts.dtype == torch.int32

+        if self.apply_router_weight_on_input:
+            assert self.routing_method.top_k == 1, "Current workaround only supports top-1 routing"
+            x = x * token_final_scales.to(x.dtype)
+            # TODO: remove this once we have the correct fused MoE kernel ready
+            token_final_scales = None
+
         x_sf = None
         if self.has_any_quant:
             if self.has_fp8_qdq:
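
The workaround added above can be summarized in isolation: when the flag is set, the router weight is folded into the input and the final-scale tensor is replaced with None so the (not yet updated) fused kernel skips its own end-of-FC2 scaling. A standalone sketch under those assumptions, with a hypothetical helper name:

import torch

def maybe_apply_router_weight_on_input(x, token_final_scales, top_k,
                                        apply_router_weight_on_input):
    """Illustrative helper (not part of FusedMoE): pre-scale tokens by their
    router weight and disable the kernel's final scaling stage."""
    if not apply_router_weight_on_input:
        return x, token_final_scales
    assert top_k == 1, "current workaround only supports top-1 routing"
    x = x * token_final_scales.to(x.dtype)
    # Returning None tells the downstream kernel to skip multiplying the
    # router weight at the end of FC2, since it was already applied here.
    return x, None

# Example: 4 tokens, hidden size 16, one routing scale per token.
x, scales = maybe_apply_router_weight_on_input(
    torch.randn(4, 16), torch.rand(4, 1), top_k=1, apply_router_weight_on_input=True)
assert scales is None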