Mirror of https://github.com/NVIDIA/TensorRT-LLM.git, synced 2026-01-14 06:27:45 +08:00
[chore] Make llama4 MoE use maybe_execute_in_parallel (#3779)
Signed-off-by: Mike Iovine <6158008+mikeiovine@users.noreply.github.com>
commit e6f7ff3a46
parent 19da82d68f
@@ -310,21 +310,11 @@ class Llama4MoE(nn.Module):
     ) -> torch.Tensor:
-        # Only enable multi-stream for cuda graph since switch stream has extra host overhead
-        # This design is mainly for low latency use case. Need to improve for max throughput use case.
-        do_multi_stream = torch.cuda.is_current_stream_capturing()
-        if do_multi_stream:
-            self.moe_event[0].record()
-        shared_output = self.shared_expert(hidden_states)
-        if do_multi_stream:
-            with torch.cuda.stream(self.aux_stream):
-                self.moe_event[0].wait()
-                routed_output = self.compute_routed_output(
-                    hidden_states, all_rank_num_tokens, min_latency_mode)
-                self.moe_event[1].record()
-            self.moe_event[1].wait()
-        else:
-            routed_output = self.compute_routed_output(hidden_states,
-                                                       all_rank_num_tokens,
-                                                       min_latency_mode)
+        fn0 = lambda: self.shared_expert(hidden_states)
+        fn1 = lambda: self.compute_routed_output(
+            hidden_states, all_rank_num_tokens, min_latency_mode)
+        shared_output, routed_output = maybe_execute_in_parallel(
+            fn0, fn1, self.moe_event[0], self.moe_event[1], self.aux_stream)
         if min_latency_mode:
             return [shared_output, *routed_output]
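The new forward body delegates the event and stream choreography that the removed block spelled out inline. Below is a minimal sketch of that overlap pattern, assuming the signature documented in the next hunk (fn0, fn1, event0, event1, aux_stream) and the graph-capture gating taken from the removed code; it illustrates the technique and is not the TensorRT-LLM implementation.

# Sketch of the two-stream overlap that a helper like maybe_execute_in_parallel
# wraps. Assumptions: the argument order from the docstring below, and the
# "only overlap under CUDA graph capture" gating from the removed code above.
from typing import Any, Callable, Optional

import torch


def maybe_execute_in_parallel_sketch(
        fn0: Callable[[], Any],
        fn1: Callable[[], Any],
        event0: torch.cuda.Event,
        event1: torch.cuda.Event,
        aux_stream: Optional[torch.cuda.Stream] = None) -> tuple[Any, Any]:
    """Run fn0 on the current stream and fn1 on aux_stream when overlap is enabled.

    Multi-stream is disabled when aux_stream is None; fn0 and fn1 then run
    sequentially on the current stream.
    """
    # Assumption: stream switching has host-side overhead, so overlap only
    # when capturing a CUDA graph, mirroring the removed code's comment.
    do_multi_stream = (aux_stream is not None
                       and torch.cuda.is_current_stream_capturing())
    if not do_multi_stream:
        return fn0(), fn1()

    event0.record()                  # mark the point fn1 must wait for
    result0 = fn0()                  # runs on the current stream
    with torch.cuda.stream(aux_stream):
        event0.wait()                # aux stream waits for prior work
        result1 = fn1()              # runs concurrently with fn0
        event1.record()
    event1.wait()                    # current stream waits for fn1
    return result0, result1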
@@ -24,7 +24,7 @@ def maybe_execute_in_parallel(
         event0 (torch.cuda.Event): cuda event for fn0
         event1 (torch.cuda.Event): cuda event for fn1
         aux_stream (Optional[torch.cuda.Stream]): the second cuda stream for fn1.
-            Mutil-stream is disabled when aux_stream is None.
+            Multi-stream is disabled when aux_stream is None.

     Returns:
         tuple[Any, Any]: the return values of fn0() and fn1()
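For context, a hedged standalone usage sketch of the documented signature; in Llama4MoE the events and aux stream are owned by the module (self.moe_event, self.aux_stream), and the import path for maybe_execute_in_parallel is omitted because it is not shown in this diff.

# Hypothetical standalone call; assumes maybe_execute_in_parallel is already
# imported from its TensorRT-LLM module (path not shown in this diff).
import torch

if torch.cuda.is_available():
    aux_stream = torch.cuda.Stream()
    moe_event = [torch.cuda.Event(), torch.cuda.Event()]

    a = torch.randn(1024, 1024, device="cuda")
    b = torch.randn(1024, 1024, device="cuda")

    # fn0 runs on the current stream, fn1 may run on aux_stream;
    # passing aux_stream=None forces sequential execution.
    out0, out1 = maybe_execute_in_parallel(
        lambda: a @ a,
        lambda: b @ b,
        moe_event[0], moe_event[1], aux_stream)
    torch.cuda.synchronize()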