[chore] Make llama4 MoE use maybe_execute_in_parallel (#3779)

Signed-off-by: Mike Iovine <6158008+mikeiovine@users.noreply.github.com>
Author: Mike Iovine · 2025-04-28 07:58:03 -07:00 · committed by GitHub
parent 19da82d68f
commit e6f7ff3a46
2 changed files with 6 additions and 16 deletions


@@ -310,21 +310,11 @@ class Llama4MoE(nn.Module):
     ) -> torch.Tensor:
         # Only enable multi-stream for cuda graph since switch stream has extra host overhead
         # This design is mainly for low latency use case. Need to improve for max throughput use case.
-        do_multi_stream = torch.cuda.is_current_stream_capturing()
-        if do_multi_stream:
-            self.moe_event[0].record()
-        shared_output = self.shared_expert(hidden_states)
-        if do_multi_stream:
-            with torch.cuda.stream(self.aux_stream):
-                self.moe_event[0].wait()
-                routed_output = self.compute_routed_output(
-                    hidden_states, all_rank_num_tokens, min_latency_mode)
-                self.moe_event[1].record()
-            self.moe_event[1].wait()
-        else:
-            routed_output = self.compute_routed_output(hidden_states,
-                                                       all_rank_num_tokens,
-                                                       min_latency_mode)
+        fn0 = lambda: self.shared_expert(hidden_states)
+        fn1 = lambda: self.compute_routed_output(
+            hidden_states, all_rank_num_tokens, min_latency_mode)
+        shared_output, routed_output = maybe_execute_in_parallel(
+            fn0, fn1, self.moe_event[0], self.moe_event[1], self.aux_stream)
         if min_latency_mode:
             return [shared_output, *routed_output]
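
For readers skimming the diff, the following is a minimal standalone sketch of the CUDA event/stream handshake that the removed inline code performed and that maybe_execute_in_parallel now encapsulates. The tensors, shapes, and matmuls are hypothetical and chosen purely for illustration.

import torch

# Hypothetical toy analogue of the shared-expert / routed-expert overlap:
# run one op on the current stream and the other on an auxiliary stream,
# synchronizing with CUDA events the same way the removed code did.
a = torch.randn(1024, 1024, device="cuda")
b = torch.randn(1024, 1024, device="cuda")

aux_stream = torch.cuda.Stream()
event0 = torch.cuda.Event()
event1 = torch.cuda.Event()

event0.record()                      # current stream: mark the fork point
shared = a @ a                       # "shared expert" work stays on the current stream
with torch.cuda.stream(aux_stream):
    event0.wait()                    # aux stream waits for the fork point
    routed = b @ b                   # "routed experts" work runs on the aux stream
    event1.record()                  # mark completion of the aux-stream work
event1.wait()                        # current stream joins before consuming `routed`
out = shared + routed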


@@ -24,7 +24,7 @@ def maybe_execute_in_parallel(
         event0 (torch.cuda.Event): cuda event for fn0
         event1 (torch.cuda.Event): cuda event for fn1
         aux_stream (Optional[torch.cuda.Stream]): the second cuda stream for fn1.
-            Mutil-stream is disabled when aux_stream is None.
+            Multi-stream is disabled when aux_stream is None.
     Returns:
         tuple[Any, Any]: the return values of fn0() and fn1()
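
The body of maybe_execute_in_parallel is not shown in this diff. As a rough, non-authoritative sketch, a helper matching the documented signature and the behavior implied by the docstring above and the call-site comment (multi-stream only while a CUDA graph is being captured, and disabled when aux_stream is None) could look like this; the actual implementation in the repository may differ.

from typing import Any, Callable, Optional, Tuple

import torch


def maybe_execute_in_parallel(
        fn0: Callable[[], Any],
        fn1: Callable[[], Any],
        event0: torch.cuda.Event,
        event1: torch.cuda.Event,
        aux_stream: Optional[torch.cuda.Stream] = None) -> Tuple[Any, Any]:
    """Run fn0 on the current stream and fn1 on aux_stream when multi-stream
    is enabled; otherwise run both sequentially on the current stream."""
    # Only overlap while a CUDA graph is being captured, mirroring the comment
    # at the call site; multi-stream is disabled when aux_stream is None.
    do_multi_stream = (aux_stream is not None
                       and torch.cuda.is_current_stream_capturing())

    if not do_multi_stream:
        return fn0(), fn1()

    event0.record()                  # mark the point aux_stream must wait for
    result0 = fn0()                  # fn0 stays on the current stream
    with torch.cuda.stream(aux_stream):
        event0.wait()                # order aux_stream after event0
        result1 = fn1()              # fn1 runs on the auxiliary stream
        event1.record()              # mark completion of fn1
    event1.wait()                    # current stream waits for fn1's result
    return result0, result1

This sketch simply parameterizes the inline logic removed in the first hunk over the two callables and the synchronization primitives, so each MoE call site no longer repeats the stream bookkeeping.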