Add PDL support for moeAlltoAllKernels

Signed-off-by: Kaiyu Xie <26294424+kaiyux@users.noreply.github.com>
Kaiyu Xie 2026-01-11 19:18:04 -08:00
parent 38296a472b
commit 9520144116

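Background for the change: on Hopper (SM 9.0) and newer, Programmatic Dependent Launch (PDL) lets a kernel launched with the programmatic-stream-serialization attribute begin executing before the preceding kernel on the same stream has fully finished. The producer calls cudaTriggerProgrammaticLaunchCompletion() once its externally visible work is done; the consumer calls cudaGridDependencySynchronize() before reading the producer's results. A minimal sketch of the pattern this commit applies (the two kernels are illustrative, not part of the change):

__global__ void producerKernel(int* out)
{
    out[threadIdx.x] = threadIdx.x; // externally visible work
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
    // Fire the launch-completion event so a dependent kernel may start early.
    cudaTriggerProgrammaticLaunchCompletion();
#endif
}

__global__ void consumerKernel(int const* in, int* out)
{
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
    // Prologue work that does not read the producer's output may run here.
    // Block until the producer's writes are guaranteed visible.
    cudaGridDependencySynchronize();
#endif
    out[threadIdx.x] = 2 * in[threadIdx.x];
}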

@@ -334,6 +334,11 @@ __device__ void vectorized_dispatch(uint8_t const* src_ptr, int bytes_per_token,
__global__ void moeA2APrepareDispatchKernel(
int* send_counters, int* local_token_counter, int ep_size, uint32_t* flag_val_ptr)
{
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
// PDL: Wait for the upstream kernel(s) this grid depends on to complete before starting
cudaGridDependencySynchronize();
#endif
int idx = blockIdx.x * blockDim.x + threadIdx.x;
// Zero send_counters
if (idx < ep_size)
@@ -347,6 +352,11 @@ __global__ void moeA2APrepareDispatchKernel(
// Increment flag_val for this dispatch round
*flag_val_ptr = *flag_val_ptr + 1;
}
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
// PDL: Signal that this kernel's main work is complete and dependent kernels can launch
cudaTriggerProgrammaticLaunchCompletion();
#endif
}
// ============================================================================
@@ -364,6 +374,10 @@ __global__ void moeA2ADispatchKernel(int32_t const* token_selected_experts, // [
int max_tokens_per_rank, // Maximum tokens per rank
int local_num_tokens, int rank_id, int ep_size, int num_experts_per_rank)
{
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
// PDL: Wait for the upstream kernel(s) this grid depends on to complete before starting
cudaGridDependencySynchronize();
#endif
int thread_idx = ThreadingPolicy::offset();
int local_token_idx = ThreadingPolicy::token_idx();
@@ -529,12 +543,30 @@ __global__ void moeA2ADispatchKernel(int32_t const* token_selected_experts, // [
#endif
}
}
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
// PDL: Signal that this kernel's main work is complete and dependent kernels can launch
cudaTriggerProgrammaticLaunchCompletion();
#endif
}
void moe_a2a_prepare_dispatch_launch(MoeA2ADispatchParams const& params)
{
moeA2APrepareDispatchKernel<<<1, params.ep_size, 0, params.stream>>>(
params.send_counters, params.local_token_counter, params.ep_size, params.flag_val);
// Setup PDL launch configuration
cudaLaunchConfig_t config;
config.gridDim = 1;
config.blockDim = params.ep_size;
config.dynamicSmemBytes = 0;
config.stream = params.stream;
cudaLaunchAttribute attrs[1];
attrs[0].id = cudaLaunchAttributeProgrammaticStreamSerialization;
attrs[0].val.programmaticStreamSerializationAllowed = tensorrt_llm::common::getEnvEnablePDL();
config.attrs = attrs;
config.numAttrs = 1;
TLLM_CUDA_CHECK(cudaLaunchKernelEx(&config, moeA2APrepareDispatchKernel, params.send_counters,
params.local_token_counter, params.ep_size, params.flag_val));
}
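On the host side, each launcher now opts in through cudaLaunchKernelEx. When getEnvEnablePDL() returns false, the attribute value is 0 and the launch degrades to an ordinary stream-ordered launch, so PDL remains a runtime toggle. The same configuration boilerplate recurs in every launcher below; a small helper along these lines could centralize it (makePdlLaunchConfig is a hypothetical sketch, not part of this commit):

// Hypothetical helper, not part of this commit: builds a cudaLaunchConfig_t
// with the PDL attribute attached. The attribute storage must outlive the
// launch, so the caller owns it.
static cudaLaunchConfig_t makePdlLaunchConfig(dim3 grid, dim3 block, size_t smemBytes,
    cudaStream_t stream, cudaLaunchAttribute (&attrs)[1], bool enablePdl)
{
    cudaLaunchConfig_t config{};
    config.gridDim = grid;
    config.blockDim = block;
    config.dynamicSmemBytes = smemBytes;
    config.stream = stream;
    attrs[0].id = cudaLaunchAttributeProgrammaticStreamSerialization;
    attrs[0].val.programmaticStreamSerializationAllowed = enablePdl ? 1 : 0;
    config.attrs = attrs;
    config.numAttrs = 1;
    return config;
}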
// ============================================================================
@@ -587,6 +619,17 @@ void moe_a2a_dispatch_launch(MoeA2ADispatchParams const& params)
constexpr int kWarpSize = 32;
int const kWarpsPerBlock = kBlockSize / kWarpSize;
// Setup PDL launch configuration
cudaLaunchConfig_t config;
config.blockDim = kBlockSize;
config.stream = params.stream;
cudaLaunchAttribute attrs[1];
attrs[0].id = cudaLaunchAttributeProgrammaticStreamSerialization;
attrs[0].val.programmaticStreamSerializationAllowed = tensorrt_llm::common::getEnvEnablePDL();
config.attrs = attrs;
config.numAttrs = 1;
// Configure kernel launch
if (params.one_block_per_token)
{
@@ -597,10 +640,15 @@ void moe_a2a_dispatch_launch(MoeA2ADispatchParams const& params)
grid_size = 1;
}
int shared_bytes = 2 * params.top_k * (int) sizeof(int);
SWITCH_TOP_K(params.top_k, TOP_K,
moeA2ADispatchKernel<BlockPolicy, TOP_K><<<grid_size, kBlockSize, shared_bytes, params.stream>>>(
params.token_selected_experts, kernel_ptrs, params.num_payloads, params.max_tokens_per_rank,
params.local_num_tokens, params.ep_rank, params.ep_size, params.num_experts_per_rank))
config.gridDim = grid_size;
config.dynamicSmemBytes = shared_bytes;
SWITCH_TOP_K(params.top_k, TOP_K, {
auto kernel_fn = moeA2ADispatchKernel<BlockPolicy, TOP_K>;
TLLM_CUDA_CHECK(cudaLaunchKernelEx(&config, kernel_fn, params.token_selected_experts, kernel_ptrs,
params.num_payloads, params.max_tokens_per_rank, params.local_num_tokens, params.ep_rank,
params.ep_size, params.num_experts_per_rank));
})
}
else
{
@@ -611,10 +659,15 @@ void moe_a2a_dispatch_launch(MoeA2ADispatchParams const& params)
grid_size = 1;
}
int shared_bytes = 2 * kWarpsPerBlock * params.top_k * (int) sizeof(int);
SWITCH_TOP_K(params.top_k, TOP_K,
moeA2ADispatchKernel<WarpPolicy, TOP_K><<<grid_size, kBlockSize, shared_bytes, params.stream>>>(
params.token_selected_experts, kernel_ptrs, params.num_payloads, params.max_tokens_per_rank,
params.local_num_tokens, params.ep_rank, params.ep_size, params.num_experts_per_rank))
config.gridDim = grid_size;
config.dynamicSmemBytes = shared_bytes;
SWITCH_TOP_K(params.top_k, TOP_K, {
auto kernel_fn = moeA2ADispatchKernel<WarpPolicy, TOP_K>;
TLLM_CUDA_CHECK(cudaLaunchKernelEx(&config, kernel_fn, params.token_selected_experts, kernel_ptrs,
params.num_payloads, params.max_tokens_per_rank, params.local_num_tokens, params.ep_rank,
params.ep_size, params.num_experts_per_rank));
})
}
}
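A detail worth noting in the SWITCH_TOP_K blocks above: the kernel is first bound to a local variable (auto kernel_fn = ...) rather than named inline in the call. cudaLaunchKernelEx is a typed template that deduces the kernel's parameter types from the function pointer and checks the trailing arguments against them at compile time, unlike the untyped void** path of cudaLaunchKernel; binding the instantiation to kernel_fn keeps that checked call readable inside the macro body.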
@@ -947,6 +1000,11 @@ template <typename ThreadingPolicy>
__global__ void moeA2APrepareCombineKernel(uint8_t* recv_buffer_bytes, uint8_t const* payload_bytes,
int bytes_per_token, int ep_size, int max_tokens_per_rank, uint32_t* flag_val_ptr, int const* recv_counters)
{
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
// PDL: Wait for the upstream kernel(s) this grid depends on to complete before starting
cudaGridDependencySynchronize();
#endif
if (blockIdx.x == 0 && threadIdx.x == 0)
{
// Increment flag_val for this combine round
@@ -954,13 +1012,25 @@ __global__ void moeA2APrepareCombineKernel(uint8_t* recv_buffer_bytes, uint8_t c
}
if (payload_bytes == nullptr)
{
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
// PDL: Signal completion even if no payload
cudaTriggerProgrammaticLaunchCompletion();
#endif
return;
}
int slot_idx = ThreadingPolicy::token_idx();
int total_slots = ep_size * max_tokens_per_rank;
if (slot_idx >= total_slots)
{
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
// PDL: Signal completion before early return
cudaTriggerProgrammaticLaunchCompletion();
#endif
return;
}
// Map global token to (source_rank, token_idx)
int source_rank = slot_idx / max_tokens_per_rank;
@@ -968,7 +1038,13 @@ __global__ void moeA2APrepareCombineKernel(uint8_t* recv_buffer_bytes, uint8_t c
// Skip invalid tokens beyond per-source recv count
if (token_idx >= recv_counters[source_rank])
{
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
// PDL: Signal completion before early return
cudaTriggerProgrammaticLaunchCompletion();
#endif
return;
}
// Calculate source and destination pointers for this token
size_t slot_offset = static_cast<size_t>(slot_idx) * bytes_per_token;
@@ -977,6 +1053,11 @@ __global__ void moeA2APrepareCombineKernel(uint8_t* recv_buffer_bytes, uint8_t c
// Copy one token's data using vectorized copy with policy
vectorized_copy<ThreadingPolicy>(dst_ptr, src_ptr, bytes_per_token);
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
// PDL: Signal that this kernel's main work is complete and dependent kernels can launch
cudaTriggerProgrammaticLaunchCompletion();
#endif
}
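Note on the early returns above: per the CUDA programming guide, the launch-completion event that a downstream kernel's cudaGridDependencySynchronize() waits on fires only once every block of this grid has either exited or called cudaTriggerProgrammaticLaunchCompletion, and only writes issued before the trigger are guaranteed visible downstream. Signaling explicitly on every exit path keeps that guarantee uniform, whether or not a thread had a token to copy.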
// ============================================================================
@@ -988,6 +1069,11 @@ __global__ void moeA2ACombineKernel(
const CombineKernelPointers ptrs, // Combine-specific struct, src_data_ptrs[0] is output
int max_tokens_per_rank, int elements_per_token, int local_num_tokens, int rank_id, int ep_size)
{
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
// PDL: Wait for the upstream kernel(s) this grid depends on to complete before starting
cudaGridDependencySynchronize();
#endif
int local_token_idx = ThreadingPolicy::token_idx();
int const size_per_token = elements_per_token * sizeof(T);
@@ -1063,6 +1149,11 @@ __global__ void moeA2ACombineKernel(
// Accumulate across ranks in registers, then store once per segment
vectorized_combine<TOP_K, ThreadingPolicy, T>(token_output, size_per_token, rank_id, max_tokens_per_rank, ptrs);
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
// PDL: Signal that this kernel's main work is complete and dependent kernels can launch
cudaTriggerProgrammaticLaunchCompletion();
#endif
}
void moe_a2a_prepare_combine_launch(MoeA2ACombineParams const& params)
@@ -1085,19 +1176,37 @@ void moe_a2a_prepare_combine_launch(MoeA2ACombineParams const& params)
int grid_size_warp = ceilDiv(total_slots, kWarpsPerBlock);
int grid_size_block = total_slots; // one block per token
int grid = params.one_block_per_token ? grid_size_block : grid_size_warp;
// Setup PDL launch configuration
cudaLaunchConfig_t config;
config.gridDim = grid;
config.blockDim = kBlockSize;
config.dynamicSmemBytes = 0;
config.stream = params.stream;
cudaLaunchAttribute attrs[1];
attrs[0].id = cudaLaunchAttributeProgrammaticStreamSerialization;
attrs[0].val.programmaticStreamSerializationAllowed = tensorrt_llm::common::getEnvEnablePDL();
config.attrs = attrs;
config.numAttrs = 1;
uint8_t* recv_buffer_bytes = static_cast<uint8_t*>(const_cast<void*>(params.recv_buffers[params.ep_rank]));
uint8_t const* payload_bytes = static_cast<uint8_t const*>(params.prepare_payload);
if (params.one_block_per_token)
{
moeA2APrepareCombineKernel<BlockPolicy><<<grid_size_block, kBlockSize, 0, params.stream>>>(
static_cast<uint8_t*>(const_cast<void*>(params.recv_buffers[params.ep_rank])),
static_cast<uint8_t const*>(params.prepare_payload), bytes_per_token, params.ep_size,
params.max_tokens_per_rank, params.flag_val, params.recv_counters);
auto kernel_fn = moeA2APrepareCombineKernel<BlockPolicy>;
TLLM_CUDA_CHECK(cudaLaunchKernelEx(
&config, kernel_fn, recv_buffer_bytes, payload_bytes, bytes_per_token, params.ep_size,
params.max_tokens_per_rank, params.flag_val, params.recv_counters));
}
else
{
moeA2APrepareCombineKernel<WarpPolicy><<<grid_size_warp, kBlockSize, 0, params.stream>>>(
static_cast<uint8_t*>(const_cast<void*>(params.recv_buffers[params.ep_rank])),
static_cast<uint8_t const*>(params.prepare_payload), bytes_per_token, params.ep_size,
params.max_tokens_per_rank, params.flag_val, params.recv_counters);
auto kernel_fn = moeA2APrepareCombineKernel<WarpPolicy>;
TLLM_CUDA_CHECK(cudaLaunchKernelEx(
&config, kernel_fn, recv_buffer_bytes, payload_bytes, bytes_per_token, params.ep_size,
params.max_tokens_per_rank, params.flag_val, params.recv_counters));
}
}
@@ -1151,19 +1260,28 @@ void moe_a2a_combine_launch(MoeA2ACombineParams const& params)
kernel_ptrs.topk_target_ranks = params.topk_target_ranks;
kernel_ptrs.topk_send_indices = params.topk_send_indices;
int grid = params.one_block_per_token ? grid_size_block : grid_size_warp;
// Setup PDL launch configuration
cudaLaunchConfig_t config;
config.gridDim = grid;
config.blockDim = kBlockSize;
config.dynamicSmemBytes = 0;
config.stream = params.stream;
cudaLaunchAttribute attrs[1];
attrs[0].id = cudaLaunchAttributeProgrammaticStreamSerialization;
attrs[0].val.programmaticStreamSerializationAllowed = tensorrt_llm::common::getEnvEnablePDL();
config.attrs = attrs;
config.numAttrs = 1;
// Launch appropriate kernel with compact macros
SWITCH_DTYPE(params.dtype, TKernelType, {
SWITCH_POLICY(params.one_block_per_token, Policy, {
SWITCH_TOP_K(params.top_k, TOP_K, {
auto launch = [&](int grid_blocks, int block_threads)
{
moeA2ACombineKernel<TKernelType, Policy, TOP_K>
<<<grid_blocks, block_threads, 0, params.stream>>>(kernel_ptrs, params.max_tokens_per_rank,
params.elements_per_token, params.local_num_tokens, params.ep_rank, params.ep_size);
};
int grid = params.one_block_per_token ? grid_size_block : grid_size_warp;
int cta = kBlockSize;
launch(grid, cta);
auto kernel_fn = moeA2ACombineKernel<TKernelType, Policy, TOP_K>;
TLLM_CUDA_CHECK(cudaLaunchKernelEx(&config, kernel_fn, kernel_ptrs, params.max_tokens_per_rank,
params.elements_per_token, params.local_num_tokens, params.ep_rank, params.ep_size));
});
});
});