Mirror of https://github.com/NVIDIA/TensorRT-LLM.git (synced 2026-01-14 06:27:45 +08:00)
fix: llama4: add an option apply_router_weight_on_input in FusedMoE (#3492)
* Apply a tentative fix to the MoE bypass kernel update

* Pass None to disable the final stage in MoE

Co-authored-by: hlu1 <14827759+hlu1@users.noreply.github.com>
Signed-off-by: Chang Liu <lc9114@gmail.com>

---------

Signed-off-by: Chang Liu <lc9114@gmail.com>
Co-authored-by: hlu1 <14827759+hlu1@users.noreply.github.com>
This commit is contained in:
parent b286b51118
commit 1902d73eb5
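The option named in the title controls where the per-token router weight is multiplied in. A minimal sketch of the two placements for a single top-1 routed token, assuming a hypothetical expert_mlp callable (not TensorRT-LLM's actual fused kernel):

import torch

def route_one_token(x: torch.Tensor, router_weight: torch.Tensor, expert_mlp,
                    apply_router_weight_on_input: bool) -> torch.Tensor:
    # Hypothetical helper, for illustration only.
    if apply_router_weight_on_input:
        # Llama4-style: scale the token before it enters the expert (before FC1)
        return expert_mlp(router_weight * x)
    # Default: scale the expert output (i.e. after FC2)
    return router_weight * expert_mlp(x)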
@@ -119,7 +119,8 @@ class Llama4MoE(nn.Module):
             reduce_results=
             False,  # In both low latency and max-throughput scenarios, FusedMoE needs not to do allreduce inside op.
             weight_loading_mode=MoEWeightLoadingMode.FUSED_GATE_UP_PROJ,
-            model_config=model_config)
+            model_config=model_config,
+            apply_router_weight_on_input=True)

         self.shared_expert = GatedMLP(
             hidden_size=hidden_size,
@@ -240,6 +240,7 @@ class FusedMoE(nn.Module):
         aux_stream: torch.cuda.Stream = torch.cuda.Stream(),
         weight_loading_mode: MoEWeightLoadingMode = MoEWeightLoadingMode.
         VANILLA,
+        apply_router_weight_on_input: bool = False,
     ):
         from ..distributed import AllReduce

@@ -302,6 +303,9 @@ class FusedMoE(nn.Module):
         if not model_config.skip_create_weights:
             self.create_weights()

+        # If True, the router weight will be multiplied on the input rather than at the end of FC2
+        self.apply_router_weight_on_input = apply_router_weight_on_input
+
     def setup_quant_scales(self):
         self.quant_scales = None
         if not self.has_any_quant:
@@ -558,6 +562,12 @@ class FusedMoE(nn.Module):
         assert token_final_scales.dtype == torch.float32
         assert token_selected_experts.dtype == torch.int32

+        if self.apply_router_weight_on_input:
+            assert self.routing_method.top_k == 1, "Current workaround only supports top-1 routing"
+            x = x * token_final_scales.to(x.dtype)
+            # TODO: remove this once we have the correct fused MoE kernel ready
+            token_final_scales = None
+
         x_sf = None
         if self.has_any_quant:
             if self.has_fp8_qdq:
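As the last hunk shows, the current workaround only supports top-1 routing: the selected expert's router weight is folded into the activations up front, and token_final_scales is set to None so the fused kernel skips its final scaling stage ("Pass None to disable the final stage in MoE" in the commit message). A minimal sketch of that pre-scaling step, using a hypothetical helper name:

import torch

def prescale_for_top1(x: torch.Tensor, token_final_scales: torch.Tensor):
    # Hypothetical helper mirroring the workaround in FusedMoE.forward:
    # fold the top-1 router weight into the input, then return None so the
    # fused MoE kernel's final scaling stage is disabled.
    x = x * token_final_scales.to(x.dtype)
    return x, None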