Use launchWithPdlWhenEnabled

Signed-off-by: Kaiyu Xie <26294424+kaiyux@users.noreply.github.com>

This commit is contained in:
parent 9520144116
commit 36edeadfc1
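At each launch site, this change collapses the hand-rolled PDL setup (a cudaLaunchConfig_t, the cudaLaunchAttributeProgrammaticStreamSerialization attribute gated by getEnvEnablePDL(), and a cudaLaunchKernelEx call) into a single call to tensorrt_llm::common::launchWithPdlWhenEnabled. The helper's implementation is not part of this diff; the following is only a minimal sketch of what such a wrapper could look like, inferred from the removed code and the new call sites. The "Sketch" name is hypothetical, and TLLM_CUDA_CHECK / getEnvEnablePDL are assumed to come from the same TensorRT-LLM headers the removed code used.

#include <cuda_runtime.h>

// Sketch only: the real tensorrt_llm::common::launchWithPdlWhenEnabled may differ in details.
template <typename KernelFn, typename... Args>
void launchWithPdlWhenEnabledSketch(char const* kernelName, KernelFn kernelFn, dim3 gridDim, dim3 blockDim,
    size_t dynamicSmemBytes, cudaStream_t stream, Args... args)
{
    cudaLaunchConfig_t config{};
    config.gridDim = gridDim;
    config.blockDim = blockDim;
    config.dynamicSmemBytes = dynamicSmemBytes;
    config.stream = stream;

    // PDL is opt-in: the attribute only allows overlap with the preceding kernel when the toggle is on.
    cudaLaunchAttribute attrs[1];
    attrs[0].id = cudaLaunchAttributeProgrammaticStreamSerialization;
    attrs[0].val.programmaticStreamSerializationAllowed = tensorrt_llm::common::getEnvEnablePDL();
    config.attrs = attrs;
    config.numAttrs = 1;

    (void) kernelName; // presumably used for logging or error reporting in the real helper
    TLLM_CUDA_CHECK(cudaLaunchKernelEx(&config, kernelFn, args...));
}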
@@ -29,6 +29,8 @@ TRTLLM_NAMESPACE_BEGIN
namespace kernels::moe_comm
{

using tensorrt_llm::common::launchWithPdlWhenEnabled;

#define ENABLE_DEBUG_PRINT 0
#define DISABLE_SYNC_FOR_PROFILING 0

@@ -335,7 +337,6 @@ __global__ void moeA2APrepareDispatchKernel(
    int* send_counters, int* local_token_counter, int ep_size, uint32_t* flag_val_ptr)
{
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
    // PDL: Wait for any dependent kernels to complete before starting
    cudaGridDependencySynchronize();
#endif

@@ -354,7 +355,6 @@ __global__ void moeA2APrepareDispatchKernel(
    }

#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
    // PDL: Signal that this kernel's main work is complete and dependent kernels can launch
    cudaTriggerProgrammaticLaunchCompletion();
#endif
}
@@ -375,7 +375,6 @@ __global__ void moeA2ADispatchKernel(int32_t const* token_selected_experts, // [
    int local_num_tokens, int rank_id, int ep_size, int num_experts_per_rank)
{
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
    // PDL: Wait for any dependent kernels to complete before starting
    cudaGridDependencySynchronize();
#endif

@@ -545,28 +544,14 @@ __global__ void moeA2ADispatchKernel(int32_t const* token_selected_experts, // [
    }

#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
    // PDL: Signal that this kernel's main work is complete and dependent kernels can launch
    cudaTriggerProgrammaticLaunchCompletion();
#endif
}

void moe_a2a_prepare_dispatch_launch(MoeA2ADispatchParams const& params)
{
    // Setup PDL launch configuration
    cudaLaunchConfig_t config;
    config.gridDim = 1;
    config.blockDim = params.ep_size;
    config.dynamicSmemBytes = 0;
    config.stream = params.stream;

    cudaLaunchAttribute attrs[1];
    attrs[0].id = cudaLaunchAttributeProgrammaticStreamSerialization;
    attrs[0].val.programmaticStreamSerializationAllowed = tensorrt_llm::common::getEnvEnablePDL();
    config.attrs = attrs;
    config.numAttrs = 1;

    TLLM_CUDA_CHECK(cudaLaunchKernelEx(&config, moeA2APrepareDispatchKernel, params.send_counters,
        params.local_token_counter, params.ep_size, params.flag_val));
    launchWithPdlWhenEnabled("moeA2APrepareDispatchKernel", moeA2APrepareDispatchKernel, 1, params.ep_size, 0,
        params.stream, params.send_counters, params.local_token_counter, params.flag_val);
}

// ============================================================================
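To make the mechanical translation explicit, the launch site just above maps onto the helper as follows. The values are copied from the hunk; the roles noted in the inline comments (grid, block, dynamic shared memory) are inferred from the removed config fields rather than from the helper's declaration.

// Before: configure by hand, then launch
//     config.gridDim = 1; config.blockDim = params.ep_size;
//     config.dynamicSmemBytes = 0; config.stream = params.stream;
//     attrs[0].val.programmaticStreamSerializationAllowed = tensorrt_llm::common::getEnvEnablePDL();
//     TLLM_CUDA_CHECK(cudaLaunchKernelEx(&config, moeA2APrepareDispatchKernel, <kernel arguments>));

// After: one call, same launch geometry, PDL applied only when enabled
launchWithPdlWhenEnabled("moeA2APrepareDispatchKernel", moeA2APrepareDispatchKernel,
    /*grid*/ 1, /*block*/ params.ep_size, /*dynamic smem*/ 0, params.stream,
    params.send_counters, params.local_token_counter, params.flag_val);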
@@ -619,17 +604,6 @@ void moe_a2a_dispatch_launch(MoeA2ADispatchParams const& params)
    constexpr int kWarpSize = 32;
    int const kWarpsPerBlock = kBlockSize / kWarpSize;

    // Setup PDL launch configuration
    cudaLaunchConfig_t config;
    config.blockDim = kBlockSize;
    config.stream = params.stream;

    cudaLaunchAttribute attrs[1];
    attrs[0].id = cudaLaunchAttributeProgrammaticStreamSerialization;
    attrs[0].val.programmaticStreamSerializationAllowed = tensorrt_llm::common::getEnvEnablePDL();
    config.attrs = attrs;
    config.numAttrs = 1;

    // Configure kernel launch
    if (params.one_block_per_token)
    {
@@ -640,14 +614,13 @@ void moe_a2a_dispatch_launch(MoeA2ADispatchParams const& params)
            grid_size = 1;
        }
        int shared_bytes = 2 * params.top_k * (int) sizeof(int);
        config.gridDim = grid_size;
        config.dynamicSmemBytes = shared_bytes;

        SWITCH_TOP_K(params.top_k, TOP_K, {
            auto kernel_fn = moeA2ADispatchKernel<BlockPolicy, TOP_K>;
            TLLM_CUDA_CHECK(cudaLaunchKernelEx(&config, kernel_fn, params.token_selected_experts, kernel_ptrs,
                params.num_payloads, params.max_tokens_per_rank, params.local_num_tokens, params.ep_rank,
                params.ep_size, params.num_experts_per_rank));
            launchWithPdlWhenEnabled("moeA2ADispatchKernel", kernel_fn, grid_size, kBlockSize, shared_bytes,
                params.stream, params.token_selected_experts, kernel_ptrs, params.num_payloads,
                params.max_tokens_per_rank, params.local_num_tokens, params.ep_rank, params.ep_size,
                params.num_experts_per_rank);
        })
    }
    else
@@ -659,14 +632,13 @@ void moe_a2a_dispatch_launch(MoeA2ADispatchParams const& params)
            grid_size = 1;
        }
        int shared_bytes = 2 * kWarpsPerBlock * params.top_k * (int) sizeof(int);
        config.gridDim = grid_size;
        config.dynamicSmemBytes = shared_bytes;

        SWITCH_TOP_K(params.top_k, TOP_K, {
            auto kernel_fn = moeA2ADispatchKernel<WarpPolicy, TOP_K>;
            TLLM_CUDA_CHECK(cudaLaunchKernelEx(&config, kernel_fn, params.token_selected_experts, kernel_ptrs,
                params.num_payloads, params.max_tokens_per_rank, params.local_num_tokens, params.ep_rank,
                params.ep_size, params.num_experts_per_rank));
            launchWithPdlWhenEnabled("moeA2ADispatchKernel", kernel_fn, grid_size, kBlockSize, shared_bytes,
                params.stream, params.token_selected_experts, kernel_ptrs, params.num_payloads,
                params.max_tokens_per_rank, params.local_num_tokens, params.ep_rank, params.ep_size,
                params.num_experts_per_rank);
        })
    }
}
@@ -1001,7 +973,6 @@ __global__ void moeA2APrepareCombineKernel(uint8_t* recv_buffer_bytes, uint8_t c
    int bytes_per_token, int ep_size, int max_tokens_per_rank, uint32_t* flag_val_ptr, int const* recv_counters)
{
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
    // PDL: Wait for any dependent kernels to complete before starting
    cudaGridDependencySynchronize();
#endif

@@ -1014,7 +985,6 @@ __global__ void moeA2APrepareCombineKernel(uint8_t* recv_buffer_bytes, uint8_t c
    if (payload_bytes == nullptr)
    {
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
        // PDL: Signal completion even if no payload
        cudaTriggerProgrammaticLaunchCompletion();
#endif
        return;
@@ -1026,7 +996,6 @@ __global__ void moeA2APrepareCombineKernel(uint8_t* recv_buffer_bytes, uint8_t c
    if (slot_idx >= total_slots)
    {
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
        // PDL: Signal completion before early return
        cudaTriggerProgrammaticLaunchCompletion();
#endif
        return;
@@ -1040,7 +1009,6 @@ __global__ void moeA2APrepareCombineKernel(uint8_t* recv_buffer_bytes, uint8_t c
    if (token_idx >= recv_counters[source_rank])
    {
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
        // PDL: Signal completion before early return
        cudaTriggerProgrammaticLaunchCompletion();
#endif
        return;
@@ -1055,7 +1023,6 @@ __global__ void moeA2APrepareCombineKernel(uint8_t* recv_buffer_bytes, uint8_t c
    vectorized_copy<ThreadingPolicy>(dst_ptr, src_ptr, bytes_per_token);

#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
    // PDL: Signal that this kernel's main work is complete and dependent kernels can launch
    cudaTriggerProgrammaticLaunchCompletion();
#endif
}
@@ -1070,7 +1037,6 @@ __global__ void moeA2ACombineKernel(
    int max_tokens_per_rank, int elements_per_token, int local_num_tokens, int rank_id, int ep_size)
{
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
    // PDL: Wait for any dependent kernels to complete before starting
    cudaGridDependencySynchronize();
#endif

@@ -1151,7 +1117,6 @@ __global__ void moeA2ACombineKernel(
    vectorized_combine<TOP_K, ThreadingPolicy, T>(token_output, size_per_token, rank_id, max_tokens_per_rank, ptrs);

#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
    // PDL: Signal that this kernel's main work is complete and dependent kernels can launch
    cudaTriggerProgrammaticLaunchCompletion();
#endif
}
@@ -1175,39 +1140,16 @@ void moe_a2a_prepare_combine_launch(MoeA2ACombineParams const& params)
    int total_slots = params.prepare_payload == nullptr ? 1 : params.ep_size * params.max_tokens_per_rank;
    int grid_size_warp = ceilDiv(total_slots, kWarpsPerBlock);
    int grid_size_block = total_slots; // one block per token

    int grid = params.one_block_per_token ? grid_size_block : grid_size_warp;

    // Setup PDL launch configuration
    cudaLaunchConfig_t config;
    config.gridDim = grid;
    config.blockDim = kBlockSize;
    config.dynamicSmemBytes = 0;
    config.stream = params.stream;

    cudaLaunchAttribute attrs[1];
    attrs[0].id = cudaLaunchAttributeProgrammaticStreamSerialization;
    attrs[0].val.programmaticStreamSerializationAllowed = tensorrt_llm::common::getEnvEnablePDL();
    config.attrs = attrs;
    config.numAttrs = 1;

    uint8_t* recv_buffer_bytes = static_cast<uint8_t*>(const_cast<void*>(params.recv_buffers[params.ep_rank]));
    uint8_t const* payload_bytes = static_cast<uint8_t const*>(params.prepare_payload);

    if (params.one_block_per_token)
    {
        auto kernel_fn = moeA2APrepareCombineKernel<BlockPolicy>;
        TLLM_CUDA_CHECK(cudaLaunchKernelEx(
            &config, kernel_fn, recv_buffer_bytes, payload_bytes, bytes_per_token, params.ep_size,
            params.max_tokens_per_rank, params.flag_val, params.recv_counters));
    }
    else
    {
        auto kernel_fn = moeA2APrepareCombineKernel<WarpPolicy>;
        TLLM_CUDA_CHECK(cudaLaunchKernelEx(
            &config, kernel_fn, recv_buffer_bytes, payload_bytes, bytes_per_token, params.ep_size,
            params.max_tokens_per_rank, params.flag_val, params.recv_counters));
    }
    auto kernel_fn
        = params.one_block_per_token ? moeA2APrepareCombineKernel<BlockPolicy> : moeA2APrepareCombineKernel<WarpPolicy>;
    launchWithPdlWhenEnabled("moeA2APrepareCombineKernel", kernel_fn, grid, kBlockSize, 0, params.stream,
        recv_buffer_bytes, payload_bytes, bytes_per_token, params.ep_size, params.max_tokens_per_rank, params.flag_val,
        params.recv_counters);
}

// ============================================================================
@@ -1262,26 +1204,14 @@ void moe_a2a_combine_launch(MoeA2ACombineParams const& params)

    int grid = params.one_block_per_token ? grid_size_block : grid_size_warp;

    // Setup PDL launch configuration
    cudaLaunchConfig_t config;
    config.gridDim = grid;
    config.blockDim = kBlockSize;
    config.dynamicSmemBytes = 0;
    config.stream = params.stream;

    cudaLaunchAttribute attrs[1];
    attrs[0].id = cudaLaunchAttributeProgrammaticStreamSerialization;
    attrs[0].val.programmaticStreamSerializationAllowed = tensorrt_llm::common::getEnvEnablePDL();
    config.attrs = attrs;
    config.numAttrs = 1;

    // Launch appropriate kernel with compact macros
    SWITCH_DTYPE(params.dtype, TKernelType, {
        SWITCH_POLICY(params.one_block_per_token, Policy, {
            SWITCH_TOP_K(params.top_k, TOP_K, {
                auto kernel_fn = moeA2ACombineKernel<TKernelType, Policy, TOP_K>;
                TLLM_CUDA_CHECK(cudaLaunchKernelEx(&config, kernel_fn, kernel_ptrs, params.max_tokens_per_rank,
                    params.elements_per_token, params.local_num_tokens, params.ep_rank, params.ep_size));
                launchWithPdlWhenEnabled("moeA2ACombineKernel", kernel_fn, grid, kBlockSize, 0, params.stream,
                    kernel_ptrs, params.max_tokens_per_rank, params.elements_per_token, params.local_num_tokens,
                    params.ep_rank, params.ep_size);
            });
        });
    });
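The device-side calls left untouched by this diff are the other half of the PDL mechanism: a kernel launched with programmaticStreamSerializationAllowed may start before its predecessor finishes, so it calls cudaGridDependencySynchronize() before reading the predecessor's output, and it calls cudaTriggerProgrammaticLaunchCompletion() once its own results are ready so the next PDL-launched kernel can begin. A minimal standalone sketch of that producer/consumer pairing on SM 90+, using toy kernels and a toy launch helper that are not part of this file:

#include <cuda_runtime.h>

__global__ void producerToy(int* out)
{
    out[threadIdx.x] = (int) threadIdx.x;
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
    // Main work is done: a dependent kernel launched with PDL may now start.
    cudaTriggerProgrammaticLaunchCompletion();
#endif
}

__global__ void consumerToy(int const* in, int* out)
{
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
    // Block until the producer grid has signaled completion (or finished) before reading its output.
    cudaGridDependencySynchronize();
#endif
    out[threadIdx.x] = in[threadIdx.x] * 2;
}

void launchOverlappedPairToy(int* bufA, int* bufB, cudaStream_t stream, bool enablePdl)
{
    producerToy<<<1, 32, 0, stream>>>(bufA);

    // Same attribute the removed code set by hand and the helper now sets when PDL is enabled.
    cudaLaunchConfig_t config{};
    config.gridDim = 1;
    config.blockDim = 32;
    config.dynamicSmemBytes = 0;
    config.stream = stream;
    cudaLaunchAttribute attrs[1];
    attrs[0].id = cudaLaunchAttributeProgrammaticStreamSerialization;
    attrs[0].val.programmaticStreamSerializationAllowed = enablePdl ? 1 : 0;
    config.attrs = attrs;
    config.numAttrs = 1;
    cudaLaunchKernelEx(&config, consumerToy, (int const*) bufA, bufB);
}

In the kernels touched here both device-side calls appear in the same kernel, so each one can overlap with both its predecessor and its successor when PDL is enabled.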