From 9909dca6fad58479b4b62febab1cc9e90bd48241 Mon Sep 17 00:00:00 2001 From: Kaiyu Xie <26294424+kaiyux@users.noreply.github.com> Date: Mon, 2 Feb 2026 13:23:37 +0800 Subject: [PATCH] [None] [feat] Add PDL support for moeAlltoAllKernels (#10591) Signed-off-by: Kaiyu Xie <26294424+kaiyux@users.noreply.github.com> Signed-off-by: Zhenhuan Chen Co-authored-by: Zhenhuan Chen --- .../moeAlltoAllKernels.cu | 107 +++++++++++------- 1 file changed, 67 insertions(+), 40 deletions(-) diff --git a/cpp/tensorrt_llm/kernels/communicationKernels/moeAlltoAllKernels.cu b/cpp/tensorrt_llm/kernels/communicationKernels/moeAlltoAllKernels.cu index da1aed6a37..99a5e8c413 100644 --- a/cpp/tensorrt_llm/kernels/communicationKernels/moeAlltoAllKernels.cu +++ b/cpp/tensorrt_llm/kernels/communicationKernels/moeAlltoAllKernels.cu @@ -29,6 +29,8 @@ TRTLLM_NAMESPACE_BEGIN namespace kernels::moe_comm { +using tensorrt_llm::common::launchWithPdlWhenEnabled; + #define ENABLE_DEBUG_PRINT 0 #define DISABLE_SYNC_FOR_PROFILING 0 @@ -345,6 +347,10 @@ __device__ void vectorized_dispatch(uint8_t const* src_ptr, int bytes_per_token, __global__ void moeA2APrepareDispatchKernel( int* send_counters, int* local_token_counter, int ep_size, uint32_t* flag_val_ptr) { +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) + cudaGridDependencySynchronize(); + cudaTriggerProgrammaticLaunchCompletion(); +#endif int idx = blockIdx.x * blockDim.x + threadIdx.x; // Zero send_counters if (idx < ep_size) @@ -371,7 +377,6 @@ __global__ void moeA2ADispatchKernel(int32_t const* token_selected_experts, // [ int max_tokens_per_rank, // Maximum tokens per rank int local_num_tokens, int rank_id, int ep_size, int num_experts, int eplb_stats_num_experts) { - int thread_idx = ThreadingPolicy::offset(); int local_token_idx = ThreadingPolicy::token_idx(); @@ -382,6 +387,9 @@ __global__ void moeA2ADispatchKernel(int32_t const* token_selected_experts, // [ // Other threads should return. if (local_token_idx > 0) return; +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) + cudaGridDependencySynchronize(); +#endif } else { @@ -408,6 +416,9 @@ __global__ void moeA2ADispatchKernel(int32_t const* token_selected_experts, // [ uint64_t already_copied = 0; int num_experts_per_rank = num_experts / ep_size; +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) + cudaGridDependencySynchronize(); +#endif for (int k = 0; k < TOP_K; k++) { int expert_id = token_selected_experts[local_token_idx * TOP_K + k]; @@ -467,6 +478,9 @@ __global__ void moeA2ADispatchKernel(int32_t const* token_selected_experts, // [ ThreadingPolicy::sync(); } +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) + cudaTriggerProgrammaticLaunchCompletion(); +#endif bool is_first_warp = threadIdx.x / warpSize == 0; if (is_first_warp) @@ -569,8 +583,8 @@ __global__ void moeA2ADispatchKernel(int32_t const* token_selected_experts, // [ void moe_a2a_prepare_dispatch_launch(MoeA2ADispatchParams const& params) { - moeA2APrepareDispatchKernel<<<1, params.ep_size, 0, params.stream>>>( - params.send_counters, params.local_token_counter, params.ep_size, params.flag_val); + launchWithPdlWhenEnabled("moeA2APrepareDispatchKernel", moeA2APrepareDispatchKernel, 1, params.ep_size, 0, + params.stream, params.send_counters, params.local_token_counter, params.ep_size, params.flag_val); } // ============================================================================ @@ -635,12 +649,13 @@ void moe_a2a_dispatch_launch(MoeA2ADispatchParams const& params) grid_size = 1; } int shared_bytes = 2 * params.top_k * (int) sizeof(int); - SWITCH_BOOL(params.enable_eplb, EPLB_STATS, - SWITCH_TOP_K(params.top_k, TOP_K, - moeA2ADispatchKernel - <<>>(params.token_selected_experts, kernel_ptrs, - params.num_payloads, params.max_tokens_per_rank, params.local_num_tokens, params.ep_rank, - params.ep_size, params.num_experts, params.eplb_stats_num_experts))) + SWITCH_BOOL(params.enable_eplb, EPLB_STATS, SWITCH_TOP_K(params.top_k, TOP_K, { + auto kernel_fn = moeA2ADispatchKernel; + launchWithPdlWhenEnabled("moeA2ADispatchKernel", kernel_fn, grid_size, kBlockSize, shared_bytes, + params.stream, params.token_selected_experts, kernel_ptrs, params.num_payloads, + params.max_tokens_per_rank, params.local_num_tokens, params.ep_rank, params.ep_size, params.num_experts, + params.eplb_stats_num_experts); + })) } else { @@ -651,12 +666,13 @@ void moe_a2a_dispatch_launch(MoeA2ADispatchParams const& params) grid_size = 1; } int shared_bytes = 2 * kWarpsPerBlock * params.top_k * (int) sizeof(int); - SWITCH_BOOL(params.enable_eplb, EPLB_STATS, - SWITCH_TOP_K(params.top_k, TOP_K, - moeA2ADispatchKernel - <<>>(params.token_selected_experts, kernel_ptrs, - params.num_payloads, params.max_tokens_per_rank, params.local_num_tokens, params.ep_rank, - params.ep_size, params.num_experts, params.eplb_stats_num_experts))) + SWITCH_BOOL(params.enable_eplb, EPLB_STATS, SWITCH_TOP_K(params.top_k, TOP_K, { + auto kernel_fn = moeA2ADispatchKernel; + launchWithPdlWhenEnabled("moeA2ADispatchKernel", kernel_fn, grid_size, kBlockSize, shared_bytes, + params.stream, params.token_selected_experts, kernel_ptrs, params.num_payloads, + params.max_tokens_per_rank, params.local_num_tokens, params.ep_rank, params.ep_size, params.num_experts, + params.eplb_stats_num_experts); + })) } } @@ -989,6 +1005,11 @@ template __global__ void moeA2APrepareCombineKernel(uint8_t* recv_buffer_bytes, uint8_t const* payload_bytes, int bytes_per_token, int ep_size, int max_tokens_per_rank, uint32_t* flag_val_ptr, int const* recv_counters) { +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) + cudaGridDependencySynchronize(); + cudaTriggerProgrammaticLaunchCompletion(); +#endif + if (blockIdx.x == 0 && threadIdx.x == 0) { // Increment flag_val for this combine round @@ -996,7 +1017,9 @@ __global__ void moeA2APrepareCombineKernel(uint8_t* recv_buffer_bytes, uint8_t c } if (payload_bytes == nullptr) + { return; + } int global_token_idx = ThreadingPolicy::token_idx(); @@ -1048,6 +1071,11 @@ __global__ void moeA2ACombineKernel( return; } +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) + cudaGridDependencySynchronize(); + cudaTriggerProgrammaticLaunchCompletion(); +#endif + #if !DISABLE_SYNC_FOR_PROFILING // In-kernel readiness synchronization at start of combine: // - One warp signals readiness to all peers with current flag_val. @@ -1118,6 +1146,9 @@ __global__ void moeA2ACombineKernel( // Accumulate across ranks in registers, then store once per segment vectorized_combine(token_output, size_per_token, rank_id, max_tokens_per_rank, ptrs); +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) + cudaTriggerProgrammaticLaunchCompletion(); +#endif } void moe_a2a_prepare_combine_launch(MoeA2ACombineParams const& params) @@ -1139,21 +1170,16 @@ void moe_a2a_prepare_combine_launch(MoeA2ACombineParams const& params) int global_token_num = params.prepare_payload == nullptr ? 1 : params.ep_size * params.max_tokens_per_rank; int grid_size_warp = ceilDiv(global_token_num, kWarpsPerBlock); int grid_size_block = global_token_num; // one block per token + int grid = params.one_block_per_token ? grid_size_block : grid_size_warp; - if (params.one_block_per_token) - { - moeA2APrepareCombineKernel<<>>( - static_cast(const_cast(params.recv_buffers[params.ep_rank])), - static_cast(params.prepare_payload), bytes_per_token, params.ep_size, - params.max_tokens_per_rank, params.flag_val, params.recv_counters); - } - else - { - moeA2APrepareCombineKernel<<>>( - static_cast(const_cast(params.recv_buffers[params.ep_rank])), - static_cast(params.prepare_payload), bytes_per_token, params.ep_size, - params.max_tokens_per_rank, params.flag_val, params.recv_counters); - } + uint8_t* recv_buffer_bytes = static_cast(const_cast(params.recv_buffers[params.ep_rank])); + uint8_t const* payload_bytes = static_cast(params.prepare_payload); + + auto kernel_fn + = params.one_block_per_token ? moeA2APrepareCombineKernel : moeA2APrepareCombineKernel; + launchWithPdlWhenEnabled("moeA2APrepareCombineKernel", kernel_fn, grid, kBlockSize, 0, params.stream, + recv_buffer_bytes, payload_bytes, bytes_per_token, params.ep_size, params.max_tokens_per_rank, params.flag_val, + params.recv_counters); } // ============================================================================ @@ -1206,19 +1232,16 @@ void moe_a2a_combine_launch(MoeA2ACombineParams const& params) kernel_ptrs.topk_target_ranks = params.topk_target_ranks; kernel_ptrs.topk_send_indices = params.topk_send_indices; + int grid = params.one_block_per_token ? grid_size_block : grid_size_warp; + // Launch appropriate kernel with compact macros SWITCH_DTYPE(params.dtype, TKernelType, { SWITCH_POLICY(params.one_block_per_token, Policy, { SWITCH_TOP_K(params.top_k, TOP_K, { - auto launch = [&](int grid_blocks, int block_threads) - { - moeA2ACombineKernel - <<>>(kernel_ptrs, params.max_tokens_per_rank, - params.elements_per_token, params.local_num_tokens, params.ep_rank, params.ep_size); - }; - int grid = params.one_block_per_token ? grid_size_block : grid_size_warp; - int cta = kBlockSize; - launch(grid, cta); + auto kernel_fn = moeA2ACombineKernel; + launchWithPdlWhenEnabled("moeA2ACombineKernel", kernel_fn, grid, kBlockSize, 0, params.stream, + kernel_ptrs, params.max_tokens_per_rank, params.elements_per_token, params.local_num_tokens, + params.ep_rank, params.ep_size); }); }); }); @@ -1236,6 +1259,10 @@ __global__ void moeA2ASanitizeExpertIdsKernel(int32_t* expert_ids_ptr, int32_t c int source_rank = tid / max_tokens_per_rank; int token_idx = tid % max_tokens_per_rank; +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) + cudaGridDependencySynchronize(); + cudaTriggerProgrammaticLaunchCompletion(); +#endif if (token_idx >= recv_counters_ptr[source_rank]) { int32_t* token_expert_ids = expert_ids_ptr + tid * top_k; @@ -1252,8 +1279,8 @@ void moe_a2a_sanitize_expert_ids_launch(int32_t* expert_ids, int32_t const* recv constexpr int kBlockSize = 256; int total_tokens = ep_size * max_tokens_per_rank; int grid = ceilDiv(total_tokens, kBlockSize); - moeA2ASanitizeExpertIdsKernel<<>>( - expert_ids, recv_counters, ep_size, max_tokens_per_rank, top_k, invalid_id); + launchWithPdlWhenEnabled("moeA2ASanitizeExpertIdsKernel", moeA2ASanitizeExpertIdsKernel, grid, kBlockSize, 0, + stream, expert_ids, recv_counters, ep_size, max_tokens_per_rank, top_k, invalid_id); } } // namespace kernels::moe_comm