[None] [feat] Add PDL support for moeAlltoAllKernels (#10591)

Signed-off-by: Kaiyu Xie <26294424+kaiyux@users.noreply.github.com>
Signed-off-by: Zhenhuan Chen <zhenhuanc@nvidia.com>
Co-authored-by: Zhenhuan Chen <zhenhuanc@nvidia.com>
Authored by Kaiyu Xie on 2026-02-02 13:23:37 +08:00, committed by GitHub
parent 77afcbddae
commit 9909dca6fa

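This change threads Programmatic Dependent Launch (PDL) through the MoE all-to-all kernels: on the device side the kernels gain cudaGridDependencySynchronize() / cudaTriggerProgrammaticLaunchCompletion() calls guarded by __CUDA_ARCH__ >= 900 (the PDL intrinsics require Hopper), and on the host side the raw <<<...>>> launches are replaced by the launchWithPdlWhenEnabled helper from tensorrt_llm::common. The helper's implementation is not part of this diff; below is a minimal sketch of the standard CUDA PDL launch pattern it presumably wraps. The function name, the pdlEnabled switch, and the fallback path are assumptions for illustration, not the actual TensorRT-LLM code.

// Sketch only: illustrates the CUDA PDL launch pattern, not the actual
// tensorrt_llm::common::launchWithPdlWhenEnabled implementation.
#include <cuda_runtime.h>

template <typename Kernel, typename... Args>
void launchWithPdlSketch(Kernel kernel, dim3 grid, dim3 block, size_t smemBytes, cudaStream_t stream,
    bool pdlEnabled, Args... args)
{
    if (!pdlEnabled)
    {
        // Plain launch when PDL is disabled (hypothetical fallback path).
        kernel<<<grid, block, smemBytes, stream>>>(args...);
        return;
    }

    cudaLaunchConfig_t config{};
    config.gridDim = grid;
    config.blockDim = block;
    config.dynamicSmemBytes = smemBytes;
    config.stream = stream;

    // Opt this launch into Programmatic Dependent Launch: the kernel may start
    // before the previous kernel on the stream has fully finished. Inside the
    // kernel, cudaGridDependencySynchronize() blocks until the upstream
    // kernel's writes are visible, and cudaTriggerProgrammaticLaunchCompletion()
    // lets the next dependent kernel begin its launch early.
    cudaLaunchAttribute attr{};
    attr.id = cudaLaunchAttributeProgrammaticStreamSerialization;
    attr.val.programmaticStreamSerializationAllowed = 1;
    config.attrs = &attr;
    config.numAttrs = 1;

    cudaLaunchKernelEx(&config, kernel, args...);
}

Consistent with the device-side guards in this diff, the two intrinsics are only compiled for SM 9.0 and newer, so pre-Hopper builds keep the original launch behavior.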

@@ -29,6 +29,8 @@ TRTLLM_NAMESPACE_BEGIN
namespace kernels::moe_comm
{
using tensorrt_llm::common::launchWithPdlWhenEnabled;
#define ENABLE_DEBUG_PRINT 0
#define DISABLE_SYNC_FOR_PROFILING 0
@@ -345,6 +347,10 @@ __device__ void vectorized_dispatch(uint8_t const* src_ptr, int bytes_per_token,
__global__ void moeA2APrepareDispatchKernel(
int* send_counters, int* local_token_counter, int ep_size, uint32_t* flag_val_ptr)
{
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
cudaGridDependencySynchronize();
cudaTriggerProgrammaticLaunchCompletion();
#endif
int idx = blockIdx.x * blockDim.x + threadIdx.x;
// Zero send_counters
if (idx < ep_size)
@@ -371,7 +377,6 @@ __global__ void moeA2ADispatchKernel(int32_t const* token_selected_experts, // [
int max_tokens_per_rank, // Maximum tokens per rank
int local_num_tokens, int rank_id, int ep_size, int num_experts, int eplb_stats_num_experts)
{
int thread_idx = ThreadingPolicy::offset();
int local_token_idx = ThreadingPolicy::token_idx();
@@ -382,6 +387,9 @@ __global__ void moeA2ADispatchKernel(int32_t const* token_selected_experts, // [
// Other threads should return.
if (local_token_idx > 0)
return;
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
cudaGridDependencySynchronize();
#endif
}
else
{
@@ -408,6 +416,9 @@ __global__ void moeA2ADispatchKernel(int32_t const* token_selected_experts, // [
uint64_t already_copied = 0;
int num_experts_per_rank = num_experts / ep_size;
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
cudaGridDependencySynchronize();
#endif
for (int k = 0; k < TOP_K; k++)
{
int expert_id = token_selected_experts[local_token_idx * TOP_K + k];
@@ -467,6 +478,9 @@ __global__ void moeA2ADispatchKernel(int32_t const* token_selected_experts, // [
ThreadingPolicy::sync();
}
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
cudaTriggerProgrammaticLaunchCompletion();
#endif
bool is_first_warp = threadIdx.x / warpSize == 0;
if (is_first_warp)
@@ -569,8 +583,8 @@ __global__ void moeA2ADispatchKernel(int32_t const* token_selected_experts, // [
void moe_a2a_prepare_dispatch_launch(MoeA2ADispatchParams const& params)
{
moeA2APrepareDispatchKernel<<<1, params.ep_size, 0, params.stream>>>(
params.send_counters, params.local_token_counter, params.ep_size, params.flag_val);
launchWithPdlWhenEnabled("moeA2APrepareDispatchKernel", moeA2APrepareDispatchKernel, 1, params.ep_size, 0,
params.stream, params.send_counters, params.local_token_counter, params.ep_size, params.flag_val);
}
// ============================================================================
@@ -635,12 +649,13 @@ void moe_a2a_dispatch_launch(MoeA2ADispatchParams const& params)
grid_size = 1;
}
int shared_bytes = 2 * params.top_k * (int) sizeof(int);
SWITCH_BOOL(params.enable_eplb, EPLB_STATS,
SWITCH_TOP_K(params.top_k, TOP_K,
moeA2ADispatchKernel<BlockPolicy, TOP_K, EPLB_STATS>
<<<grid_size, kBlockSize, shared_bytes, params.stream>>>(params.token_selected_experts, kernel_ptrs,
params.num_payloads, params.max_tokens_per_rank, params.local_num_tokens, params.ep_rank,
params.ep_size, params.num_experts, params.eplb_stats_num_experts)))
SWITCH_BOOL(params.enable_eplb, EPLB_STATS, SWITCH_TOP_K(params.top_k, TOP_K, {
auto kernel_fn = moeA2ADispatchKernel<BlockPolicy, TOP_K, EPLB_STATS>;
launchWithPdlWhenEnabled("moeA2ADispatchKernel", kernel_fn, grid_size, kBlockSize, shared_bytes,
params.stream, params.token_selected_experts, kernel_ptrs, params.num_payloads,
params.max_tokens_per_rank, params.local_num_tokens, params.ep_rank, params.ep_size, params.num_experts,
params.eplb_stats_num_experts);
}))
}
else
{
@@ -651,12 +666,13 @@ void moe_a2a_dispatch_launch(MoeA2ADispatchParams const& params)
grid_size = 1;
}
int shared_bytes = 2 * kWarpsPerBlock * params.top_k * (int) sizeof(int);
SWITCH_BOOL(params.enable_eplb, EPLB_STATS,
SWITCH_TOP_K(params.top_k, TOP_K,
moeA2ADispatchKernel<WarpPolicy, TOP_K, EPLB_STATS>
<<<grid_size, kBlockSize, shared_bytes, params.stream>>>(params.token_selected_experts, kernel_ptrs,
params.num_payloads, params.max_tokens_per_rank, params.local_num_tokens, params.ep_rank,
params.ep_size, params.num_experts, params.eplb_stats_num_experts)))
SWITCH_BOOL(params.enable_eplb, EPLB_STATS, SWITCH_TOP_K(params.top_k, TOP_K, {
auto kernel_fn = moeA2ADispatchKernel<WarpPolicy, TOP_K, EPLB_STATS>;
launchWithPdlWhenEnabled("moeA2ADispatchKernel", kernel_fn, grid_size, kBlockSize, shared_bytes,
params.stream, params.token_selected_experts, kernel_ptrs, params.num_payloads,
params.max_tokens_per_rank, params.local_num_tokens, params.ep_rank, params.ep_size, params.num_experts,
params.eplb_stats_num_experts);
}))
}
}
@@ -989,6 +1005,11 @@ template <typename ThreadingPolicy>
__global__ void moeA2APrepareCombineKernel(uint8_t* recv_buffer_bytes, uint8_t const* payload_bytes,
int bytes_per_token, int ep_size, int max_tokens_per_rank, uint32_t* flag_val_ptr, int const* recv_counters)
{
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
cudaGridDependencySynchronize();
cudaTriggerProgrammaticLaunchCompletion();
#endif
if (blockIdx.x == 0 && threadIdx.x == 0)
{
// Increment flag_val for this combine round
@@ -996,7 +1017,9 @@ __global__ void moeA2APrepareCombineKernel(uint8_t* recv_buffer_bytes, uint8_t c
}
if (payload_bytes == nullptr)
{
return;
}
int global_token_idx = ThreadingPolicy::token_idx();
@@ -1048,6 +1071,11 @@ __global__ void moeA2ACombineKernel(
return;
}
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
cudaGridDependencySynchronize();
cudaTriggerProgrammaticLaunchCompletion();
#endif
#if !DISABLE_SYNC_FOR_PROFILING
// In-kernel readiness synchronization at start of combine:
// - One warp signals readiness to all peers with current flag_val.
@@ -1118,6 +1146,9 @@ __global__ void moeA2ACombineKernel(
// Accumulate across ranks in registers, then store once per segment
vectorized_combine<TOP_K, ThreadingPolicy, T>(token_output, size_per_token, rank_id, max_tokens_per_rank, ptrs);
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
cudaTriggerProgrammaticLaunchCompletion();
#endif
}
void moe_a2a_prepare_combine_launch(MoeA2ACombineParams const& params)
@@ -1139,21 +1170,16 @@ void moe_a2a_prepare_combine_launch(MoeA2ACombineParams const& params)
int global_token_num = params.prepare_payload == nullptr ? 1 : params.ep_size * params.max_tokens_per_rank;
int grid_size_warp = ceilDiv(global_token_num, kWarpsPerBlock);
int grid_size_block = global_token_num; // one block per token
int grid = params.one_block_per_token ? grid_size_block : grid_size_warp;
if (params.one_block_per_token)
{
moeA2APrepareCombineKernel<BlockPolicy><<<grid_size_block, kBlockSize, 0, params.stream>>>(
static_cast<uint8_t*>(const_cast<void*>(params.recv_buffers[params.ep_rank])),
static_cast<uint8_t const*>(params.prepare_payload), bytes_per_token, params.ep_size,
params.max_tokens_per_rank, params.flag_val, params.recv_counters);
}
else
{
moeA2APrepareCombineKernel<WarpPolicy><<<grid_size_warp, kBlockSize, 0, params.stream>>>(
static_cast<uint8_t*>(const_cast<void*>(params.recv_buffers[params.ep_rank])),
static_cast<uint8_t const*>(params.prepare_payload), bytes_per_token, params.ep_size,
params.max_tokens_per_rank, params.flag_val, params.recv_counters);
}
uint8_t* recv_buffer_bytes = static_cast<uint8_t*>(const_cast<void*>(params.recv_buffers[params.ep_rank]));
uint8_t const* payload_bytes = static_cast<uint8_t const*>(params.prepare_payload);
auto kernel_fn
= params.one_block_per_token ? moeA2APrepareCombineKernel<BlockPolicy> : moeA2APrepareCombineKernel<WarpPolicy>;
launchWithPdlWhenEnabled("moeA2APrepareCombineKernel", kernel_fn, grid, kBlockSize, 0, params.stream,
recv_buffer_bytes, payload_bytes, bytes_per_token, params.ep_size, params.max_tokens_per_rank, params.flag_val,
params.recv_counters);
}
// ============================================================================
@@ -1206,19 +1232,16 @@ void moe_a2a_combine_launch(MoeA2ACombineParams const& params)
kernel_ptrs.topk_target_ranks = params.topk_target_ranks;
kernel_ptrs.topk_send_indices = params.topk_send_indices;
int grid = params.one_block_per_token ? grid_size_block : grid_size_warp;
// Launch appropriate kernel with compact macros
SWITCH_DTYPE(params.dtype, TKernelType, {
SWITCH_POLICY(params.one_block_per_token, Policy, {
SWITCH_TOP_K(params.top_k, TOP_K, {
auto launch = [&](int grid_blocks, int block_threads)
{
moeA2ACombineKernel<TKernelType, Policy, TOP_K>
<<<grid_blocks, block_threads, 0, params.stream>>>(kernel_ptrs, params.max_tokens_per_rank,
params.elements_per_token, params.local_num_tokens, params.ep_rank, params.ep_size);
};
int grid = params.one_block_per_token ? grid_size_block : grid_size_warp;
int cta = kBlockSize;
launch(grid, cta);
auto kernel_fn = moeA2ACombineKernel<TKernelType, Policy, TOP_K>;
launchWithPdlWhenEnabled("moeA2ACombineKernel", kernel_fn, grid, kBlockSize, 0, params.stream,
kernel_ptrs, params.max_tokens_per_rank, params.elements_per_token, params.local_num_tokens,
params.ep_rank, params.ep_size);
});
});
});
@@ -1236,6 +1259,10 @@ __global__ void moeA2ASanitizeExpertIdsKernel(int32_t* expert_ids_ptr, int32_t c
int source_rank = tid / max_tokens_per_rank;
int token_idx = tid % max_tokens_per_rank;
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
cudaGridDependencySynchronize();
cudaTriggerProgrammaticLaunchCompletion();
#endif
if (token_idx >= recv_counters_ptr[source_rank])
{
int32_t* token_expert_ids = expert_ids_ptr + tid * top_k;
@@ -1252,8 +1279,8 @@ void moe_a2a_sanitize_expert_ids_launch(int32_t* expert_ids, int32_t const* recv
constexpr int kBlockSize = 256;
int total_tokens = ep_size * max_tokens_per_rank;
int grid = ceilDiv(total_tokens, kBlockSize);
moeA2ASanitizeExpertIdsKernel<<<grid, kBlockSize, 0, stream>>>(
expert_ids, recv_counters, ep_size, max_tokens_per_rank, top_k, invalid_id);
launchWithPdlWhenEnabled("moeA2ASanitizeExpertIdsKernel", moeA2ASanitizeExpertIdsKernel, grid, kBlockSize, 0,
stream, expert_ids, recv_counters, ep_size, max_tokens_per_rank, top_k, invalid_id);
}
} // namespace kernels::moe_comm