Replace memset with data initialization within kernels (#4851)

Signed-off-by: Christina Zhang <christinaz@nvidia.com>
ChristinaZ 2025-06-04 08:56:46 +08:00 committed by GitHub
parent 73389d6531
commit d64af85e8c
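
In short, this change moves buffer initialization out of host-side cudaMemsetAsync calls and into the routing kernels themselves, where it can sit between the programmatic dependent launch (PDL) primitives and overlap with the preceding grid. Below is a minimal standalone sketch of the pattern, under simplified assumptions; initAndConsumeKernel, buffer, and size are illustrative names, not code from this repository.

// Sketch only: a kernel that initializes its own buffer instead of relying
// on a prior cudaMemsetAsync on the stream.
__global__ void initAndConsumeKernel(int32_t* buffer, int32_t size)
{
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
    // Wait until the previous grid in the stream has flushed its outputs.
    cudaGridDependencySynchronize();
#endif
    // Grid-stride loop replaces the host-side memset; writing -1 per element
    // matches what the old byte-wise memset of 0xFF produced for int32_t.
    int32_t stride = gridDim.x * blockDim.x;
    for (int32_t i = blockIdx.x * blockDim.x + threadIdx.x; i < size; i += stride)
    {
        buffer[i] = -1;
    }
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
    // Signal the dependent grid only after initialization is complete.
    cudaTriggerProgrammaticLaunchCompletion();
#endif
    // ... the kernel's real work on buffer follows here ...
}

With PDL enabled at launch time, the initialization loop can start while the previous grid is still draining, instead of waiting behind a separate stream-ordered memset.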


@@ -2992,6 +2992,19 @@ __global__ void __cluster_dims__(NumBlocksPerCluster, 1, 1) __launch_bounds__(Nu
        cudaGridDependencySynchronize();
    }
    // Initialize mPtrPermutedIdxToTokenIdx.
    if (params.mPtrPermutedIdxToTokenIdx != nullptr)
    {
        int32_t permIdxToTokenIdxNum
            = (params.mNumTokens * NumTopExperts + (params.mNumExperts << params.mPaddingLog2) - params.mNumExperts);
        for (int32_t i = clusterThreadIdx; i < permIdxToTokenIdxNum; i += NumThreadsPerCluster)
        {
            params.mPtrPermutedIdxToTokenIdx[i] = -1;
        }
        // A cluster synchronization is performed before mPtrPermutedIdxToTokenIdx is written at the end of the
        // kernel, so no __threadfence() is needed here.
    }
    if (params.mPtrScores != nullptr)
    {
        // in this case, each warp represents a token
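
For a sense of scale, with hypothetical values mNumTokens = 128, NumTopExperts = 8, mNumExperts = 64, and mPaddingLog2 = 5 (each expert's segment padded to a multiple of 32), the loop clears 128 * 8 + (64 << 5) - 64 = 1024 + 1984 = 3008 entries. The (mNumExperts << mPaddingLog2) - mNumExperts term equals mNumExperts * (2^mPaddingLog2 - 1), the worst-case padding summed across all experts.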
@@ -3251,14 +3264,45 @@ __global__ void __launch_bounds__(NumThreadsHist) routingIndicesHistogramScoresK
    auto block = cg::this_thread_block();
    auto warp = cg::tiled_partition<WarpSize>(block);
    // Wait on primary grid and trigger secondary kernel.
    // Wait on primary grid.
    if constexpr (KernelParams::UsePdl)
    {
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
        cudaGridDependencySynchronize();
#endif // if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
    }
    // Initialize mPtrPermutedIdxToTokenIdx.
    int32_t globalThreadIdx = globalWarpIdx * WarpSize + laneIdx;
    int32_t globalThreadStride = globalWarpStride * WarpSize;
    if (params.mPtrPermutedIdxToTokenIdx != nullptr)
    {
        int32_t permIdxToTokenIdxNum
            = (params.mNumTokens * NumTopExperts + (params.mNumExperts << params.mPaddingLog2) - params.mNumExperts);
        for (int32_t i = globalThreadIdx; i < permIdxToTokenIdxNum; i += globalThreadStride)
        {
            params.mPtrPermutedIdxToTokenIdx[i] = -1;
        }
    }
    // Initialize mPtrExpertCounts.
    if (params.mPtrExpertCounts != nullptr)
    {
        int32_t expertCountsNum = 2 * params.mNumExperts;
        for (int32_t i = globalThreadIdx; i < expertCountsNum; i += globalThreadStride)
        {
            params.mPtrExpertCounts[i] = 0;
        }
    }
    // Trigger secondary kernel only after initialization is complete.
    if constexpr (KernelParams::UsePdl)
    {
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
        cudaTriggerProgrammaticLaunchCompletion();
#endif // if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
    }
    // in this case, each warp represents a token, and we use a grid-stride loop
    // over all warps/tokens
    for (int tokenIdx = globalWarpIdx; tokenIdx < params.mNumTokens; tokenIdx += globalWarpStride)
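
The cudaGridDependencySynchronize()/cudaTriggerProgrammaticLaunchCompletion() pair above only takes effect when the kernel is launched with programmatic stream serialization enabled. For context, a minimal host-side sketch (launchWithPdl is a hypothetical helper, not this repository's launcher), assuming CUDA 11.8+ and SM 90:

#include <cuda_runtime.h>

// Hypothetical helper (not from this repository) showing how a PDL-aware
// kernel is typically launched.
template <typename Params>
cudaError_t launchWithPdl(void (*kernel)(Params), Params const& params, dim3 grid, dim3 block, cudaStream_t stream)
{
    cudaLaunchConfig_t config = {};
    config.gridDim = grid;
    config.blockDim = block;
    config.stream = stream;

    cudaLaunchAttribute attrs[1];
    attrs[0].id = cudaLaunchAttributeProgrammaticStreamSerialization;
    attrs[0].val.programmaticStreamSerializationAllowed = 1;
    config.attrs = attrs;
    config.numAttrs = 1;

    // Without this attribute, the device-side PDL intrinsics are inert and
    // the launch behaves like an ordinary stream-ordered launch.
    return cudaLaunchKernelEx(&config, kernel, params);
}

The repository has its own launch utilities; the sketch only shows which launch attribute lets the in-kernel initialization overlap safely with the previous grid.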
@@ -3698,17 +3742,6 @@ void run(Data const& data, void* stream)
        data.mNumExperts % 4 == 0, "Routing kernel expects #experts %d to be a multiple of 4.", data.mNumExperts);
    TLLM_CHECK_WITH_INFO(data.mPaddingLog2 < 8, "Routing kernel expects padding log2 < 8, got %d", data.mPaddingLog2);
    if (data.mPtrPermutedIdxToTokenIdx != nullptr)
    {
        // need to set all values to -1 before running the kernel
        auto maxPermutedSize
            = data.mNumTokens * data.mTopK + (data.mNumExperts << data.mPaddingLog2) - data.mNumExperts;
        // note that setting each byte to -1 (0xFF) yields the logical value -1
        // for a signed integer of any width
        TLLM_CUDA_CHECK(cudaMemsetAsync(data.mPtrPermutedIdxToTokenIdx, -1,
            static_cast<size_t>(maxPermutedSize) * sizeof(int32_t), (cudaStream_t) stream));
    }
    bool const useSingleCluster
        = data.mNumTokens <= (data.mPtrScores != nullptr ? MaxNumTokensSingleClusterScores : MaxNumTokensSingleCluster);
    if (!useSingleCluster)
@@ -3717,9 +3750,6 @@ void run(Data const& data, void* stream)
            data.mPtrExpertIdx != nullptr, "When #tokens is large, `mPtrExpertIdx` is a required input.");
        TLLM_CHECK_WITH_INFO(
            data.mPtrExpertCounts != nullptr, "When #tokens is large, `mPtrExpertCounts` is a required input.");
        // Reset the global histograms (not used in single-cluster code path).
        TLLM_CUDA_CHECK(cudaMemsetAsync(data.mPtrExpertCounts, 0,
            static_cast<size_t>(2 * data.mNumExperts) * sizeof(int32_t), (cudaStream_t) stream));
    }
    if (useSingleCluster)