From e6f7ff3a46a0cc45333fe87fcde0b222dba3397a Mon Sep 17 00:00:00 2001
From: Mike Iovine
Date: Mon, 28 Apr 2025 07:58:03 -0700
Subject: [PATCH] [chore] Make llama4 MoE use maybe_execute_in_parallel (#3779)

Signed-off-by: Mike Iovine <6158008+mikeiovine@users.noreply.github.com>
---
 tensorrt_llm/_torch/models/modeling_llama.py | 20 +++++---------------
 .../_torch/modules/multi_stream_utils.py     |  2 +-
 2 files changed, 6 insertions(+), 16 deletions(-)

diff --git a/tensorrt_llm/_torch/models/modeling_llama.py b/tensorrt_llm/_torch/models/modeling_llama.py
index eaf5afa3c1..076c3a0278 100644
--- a/tensorrt_llm/_torch/models/modeling_llama.py
+++ b/tensorrt_llm/_torch/models/modeling_llama.py
@@ -310,21 +310,11 @@ class Llama4MoE(nn.Module):
     ) -> torch.Tensor:
         # Only enable multi-stream for cuda graph since switch stream has extra host overhead
         # This design is mainly for low latency use case. Need to improve for max throughput use case.
-        do_multi_stream = torch.cuda.is_current_stream_capturing()
-        if do_multi_stream:
-            self.moe_event[0].record()
-        shared_output = self.shared_expert(hidden_states)
-        if do_multi_stream:
-            with torch.cuda.stream(self.aux_stream):
-                self.moe_event[0].wait()
-                routed_output = self.compute_routed_output(
-                    hidden_states, all_rank_num_tokens, min_latency_mode)
-                self.moe_event[1].record()
-            self.moe_event[1].wait()
-        else:
-            routed_output = self.compute_routed_output(hidden_states,
-                                                       all_rank_num_tokens,
-                                                       min_latency_mode)
+        fn0 = lambda: self.shared_expert(hidden_states)
+        fn1 = lambda: self.compute_routed_output(
+            hidden_states, all_rank_num_tokens, min_latency_mode)
+        shared_output, routed_output = maybe_execute_in_parallel(
+            fn0, fn1, self.moe_event[0], self.moe_event[1], self.aux_stream)
 
         if min_latency_mode:
             return [shared_output, *routed_output]
diff --git a/tensorrt_llm/_torch/modules/multi_stream_utils.py b/tensorrt_llm/_torch/modules/multi_stream_utils.py
index 87a79eef4e..e91b7eac24 100644
--- a/tensorrt_llm/_torch/modules/multi_stream_utils.py
+++ b/tensorrt_llm/_torch/modules/multi_stream_utils.py
@@ -24,7 +24,7 @@ def maybe_execute_in_parallel(
         event0 (torch.cuda.Event): cuda event for fn0
         event1 (torch.cuda.Event): cuda event for fn1
         aux_stream (Optional[torch.cuda.Stream]): the second cuda stream for fn1.
-            Mutil-stream is disabled when aux_stream is None.
+            Multi-stream is disabled when aux_stream is None.
 
     Returns:
         tuple[Any, Any]: the return values of fn0() and fn1()
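
For reference, the refactor folds the hand-rolled dual-stream pattern above into the shared
maybe_execute_in_parallel helper: fn0 runs on the current stream, fn1 runs on aux_stream, and the
two events synchronize the hand-off. The sketch below is a rough reconstruction of that behavior
inferred from the code removed in this patch and from the docstring fragment shown; it is not the
helper's actual implementation, the _sketch name is illustrative, and the enable condition (stream
capture plus a non-None aux_stream) is an assumption based on the comment and docstring.

    # Hypothetical sketch only -- not the real maybe_execute_in_parallel implementation.
    from typing import Any, Callable, Optional, Tuple

    import torch


    def maybe_execute_in_parallel_sketch(
        fn0: Callable[[], Any],
        fn1: Callable[[], Any],
        event0: torch.cuda.Event,
        event1: torch.cuda.Event,
        aux_stream: Optional[torch.cuda.Stream] = None,
    ) -> Tuple[Any, Any]:
        """Run fn0 on the current stream and fn1 on aux_stream when multi-stream
        is enabled; otherwise run both sequentially on the current stream."""
        # Assumed enable condition: only overlap during CUDA graph capture and
        # when an auxiliary stream was provided (multi-stream disabled otherwise).
        do_multi_stream = (torch.cuda.is_current_stream_capturing()
                           and aux_stream is not None)
        if not do_multi_stream:
            return fn0(), fn1()

        # Record on the current stream so aux_stream can wait until fn1's inputs are ready.
        event0.record()
        result0 = fn0()
        with torch.cuda.stream(aux_stream):
            event0.wait()
            result1 = fn1()
            event1.record()
        # Make the current stream wait for fn1's work issued on aux_stream.
        event1.wait()
        return result0, result1

With the helper in place, the Llama4MoE forward path only has to package the shared-expert and
routed-expert computations as two callables, which is what the five added lines in the first hunk do.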