Replace memset with data initialization within kernels (#4851)
Signed-off-by: Christina Zhang <christinaz@nvidia.com>
parent 73389d6531
commit d64af85e8c
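This commit moves buffer clearing out of the host-side launch path and into the routing kernels: instead of issuing cudaMemsetAsync on the stream before the kernels run, each kernel initializes the buffers it needs with a grid-stride (or cluster-stride) loop. That drops extra operations from the stream and, since these kernels use programmatic dependent launch (PDL), lets the initialization overlap with the preceding grid instead of serializing behind it. A minimal sketch of the pattern, with hypothetical names (initKernel, buf, n) rather than the actual code from the hunks below:

// Before: a separate, stream-serialized clear ahead of the launch.
//     cudaMemsetAsync(buf, -1, n * sizeof(int32_t), stream);
// After: the kernel's own threads each clear a strided slice.
__global__ void initKernel(int32_t* buf, int32_t n)
{
    int32_t idx = static_cast<int32_t>(blockIdx.x * blockDim.x + threadIdx.x);
    int32_t stride = static_cast<int32_t>(gridDim.x * blockDim.x);
    for (int32_t i = idx; i < n; i += stride)
    {
        buf[i] = -1; // same fill value the byte-wise memset produced
    }
}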
@@ -2992,6 +2992,19 @@ __global__ void __cluster_dims__(NumBlocksPerCluster, 1, 1) __launch_bounds__(Nu
         cudaGridDependencySynchronize();
     }
 
+    // initialize the mPtrPermutedIdxToTokenIdx
+    if (params.mPtrPermutedIdxToTokenIdx != nullptr)
+    {
+        int32_t permIdxToTokenIdxNum
+            = (params.mNumTokens * NumTopExperts + (params.mNumExperts << params.mPaddingLog2) - params.mNumExperts);
+        for (int32_t i = clusterThreadIdx; i < permIdxToTokenIdxNum; i += NumThreadsPerCluster)
+        {
+            params.mPtrPermutedIdxToTokenIdx[i] = -1;
+        }
+        // A cluster synchronization is performed prior to setting mPtrPermutedIdxToTokenIdx at the end of the kernel.
+        // Don't need to use __threadfence() here.
+    }
+
     if (params.mPtrScores != nullptr)
     {
         // in this case, each warp represents a token
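The bound used by the initialization loop follows from the buffer layout the names suggest: mNumTokens * NumTopExperts real entries, plus per-expert padding of at most 2^mPaddingLog2 - 1 slots when each expert's segment is rounded up to a multiple of 2^mPaddingLog2, i.e. (mNumExperts << mPaddingLog2) - mNumExperts extra in the worst case. As a worked example (values illustrative): with 128 tokens, top-8 routing, 32 experts, and mPaddingLog2 = 3, that is 128 * 8 + (32 << 3) - 32 = 1024 + 224 = 1248 entries set to -1.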
@@ -3251,14 +3264,45 @@ __global__ void __launch_bounds__(NumThreadsHist) routingIndicesHistogramScoresK
     auto block = cg::this_thread_block();
     auto warp = cg::tiled_partition<WarpSize>(block);
 
-    // Wait on primary grid and trigger secondary kernel.
+    // Wait on primary grid.
     if constexpr (KernelParams::UsePdl)
     {
 #if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
         cudaGridDependencySynchronize();
 #endif // if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
     }
 
+    // initialize the mPtrPermutedIdxToTokenIdx
+    int32_t globalThreadIdx = globalWarpIdx * WarpSize + laneIdx;
+    int32_t globalThreadStride = globalWarpStride * WarpSize;
+    if (params.mPtrPermutedIdxToTokenIdx != nullptr)
+    {
+        int32_t permIdxToTokenIdxNum
+            = (params.mNumTokens * NumTopExperts + (params.mNumExperts << params.mPaddingLog2) - params.mNumExperts);
+        for (int32_t i = globalThreadIdx; i < permIdxToTokenIdxNum; i += globalThreadStride)
+        {
+            params.mPtrPermutedIdxToTokenIdx[i] = -1;
+        }
+    }
+
+    // initialize the mPtrExpertCounts
+    if (params.mPtrExpertCounts != nullptr)
+    {
+        int32_t expertCountsNum = 2 * params.mNumExperts;
+        for (int32_t i = globalThreadIdx; i < expertCountsNum; i += globalThreadStride)
+        {
+            params.mPtrExpertCounts[i] = 0;
+        }
+    }
+
+    // Trigger secondary kernel.
+    if constexpr (KernelParams::UsePdl)
+    {
+#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
+        cudaTriggerProgrammaticLaunchCompletion();
+#endif // if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
+    }
+
     // in this case, each warp represents a token, and we use a grid-stride loop
     // over all warps/tokens
     for (int tokenIdx = globalWarpIdx; tokenIdx < params.mNumTokens; tokenIdx += globalWarpStride)
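The placement matters for PDL correctness: the buffers are cleared after cudaGridDependencySynchronize() (so the previous grid has finished writing) and before cudaTriggerProgrammaticLaunchCompletion() (so any dependent grid launched with PDL observes initialized memory). A sketch of how such a launch chain is set up on the host with the standard CUDA PDL attributes, compiled for sm_90 to match the __CUDA_ARCH__ >= 900 guards above; kernel names and launch shapes (producerKernel, consumerKernel, launchChain) are illustrative, not taken from this file:

#include <cuda_runtime.h>

__global__ void producerKernel()
{
    // ... work that later grids depend on ...
    cudaTriggerProgrammaticLaunchCompletion(); // release the dependent grid early
}

__global__ void consumerKernel()
{
    // Pre-dependency work (e.g. buffer initialization) can run here.
    cudaGridDependencySynchronize(); // block until producerKernel triggers
    // ... consume producer results ...
}

void launchChain(cudaStream_t stream)
{
    producerKernel<<<64, 256, 0, stream>>>();

    cudaLaunchAttribute attr{};
    attr.id = cudaLaunchAttributeProgrammaticStreamSerialization;
    attr.val.programmaticStreamSerializationAllowed = 1; // enable PDL

    cudaLaunchConfig_t cfg{};
    cfg.gridDim = 64;
    cfg.blockDim = 256;
    cfg.stream = stream;
    cfg.attrs = &attr;
    cfg.numAttrs = 1;
    cudaLaunchKernelEx(&cfg, consumerKernel); // may start before producerKernel finishes
}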
@@ -3698,17 +3742,6 @@ void run(Data const& data, void* stream)
         data.mNumExperts % 4 == 0, "Routing kernel expects #experts %d to be a multiple of 4.", data.mNumExperts);
     TLLM_CHECK_WITH_INFO(data.mPaddingLog2 < 8, "Routing kernel expects padding log2 < 8, got %d", data.mPaddingLog2);
 
-    if (data.mPtrPermutedIdxToTokenIdx != nullptr)
-    {
-        // need to set all values to -1 before running the kernel
-        auto maxPermutedSize
-            = data.mNumTokens * data.mTopK + (data.mNumExperts << data.mPaddingLog2) - data.mNumExperts;
-        // note that a value of -1 per byte works for any size of signed integer
-        // to set each full value to the logical value -1
-        TLLM_CUDA_CHECK(cudaMemsetAsync(data.mPtrPermutedIdxToTokenIdx, -1,
-            static_cast<size_t>(maxPermutedSize) * sizeof(int32_t), (cudaStream_t) stream));
-    }
-
     bool const useSingleCluster
         = data.mNumTokens <= (data.mPtrScores != nullptr ? MaxNumTokensSingleClusterScores : MaxNumTokensSingleCluster);
    if (!useSingleCluster)
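The removed comment's observation, for reference: memset writes a byte pattern, and the byte 0xFF replicated across every byte of a signed integer is the two's-complement value -1 at any width, which is why a byte-wise fill of -1 produced the logical -1 in each int32_t. A small host-side check of that fact:

#include <cstdint>
#include <cstring>
#include <cstdio>

int main()
{
    int32_t v32;
    int16_t v16;
    std::memset(&v32, -1, sizeof v32); // fills all bytes with 0xFF
    std::memset(&v16, -1, sizeof v16);
    std::printf("%d %d\n", v32, v16);  // prints: -1 -1
}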
@@ -3717,9 +3750,6 @@ void run(Data const& data, void* stream)
             data.mPtrExpertIdx != nullptr, "When #tokens is large, `mPtrExpertIdx` is a required input.");
         TLLM_CHECK_WITH_INFO(
             data.mPtrExpertCounts != nullptr, "When #tokens is large, `mPtrExpertCounts` is a required input.");
-        // Reset the global histograms (not used in single-cluster code path).
-        TLLM_CUDA_CHECK(cudaMemsetAsync(data.mPtrExpertCounts, 0,
-            static_cast<size_t>(2 * data.mNumExperts) * sizeof(int32_t), (cudaStream_t) stream));
     }
 
     if (useSingleCluster)
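With this host-side memset gone, mPtrExpertCounts (the global histograms, 2 * mNumExperts entries) is zeroed by routingIndicesHistogramScoresKernel itself in the second hunk above, before that kernel calls cudaTriggerProgrammaticLaunchCompletion(), so the dependent histogram kernels still start from zeroed counts.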