Feat: Add vectorized loading for finalize kernel in MoE Trtllm backend (#5919)

Signed-off-by: Christina Zhang <83400082+ChristinaZ@users.noreply.github.com>
Author: ChristinaZ (committed by GitHub), 2025-07-17 12:38:29 +08:00
parent 4c364b9a73
commit 7e033c392e


@@ -16,10 +16,21 @@
#include "DevKernel.h"
#include "cutlass/array.h"
#include "cutlass/numeric_conversion.h"
#include <cub/cub.cuh>
#include <cutlass/cutlass.h>
#include <cutlass/numeric_types.h>
////////////////////////////////////////////////////////////////////////////////////////////////////
// Helper function for array conversion
template <class T, class U>
__host__ __device__ constexpr static U arrayConvert(T const& input)
{
    cutlass::NumericArrayConverter<typename U::Element, typename T::Element, U::kElements> converter;
    return converter(input);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
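For orientation, here is a minimal usage sketch of this helper (not part of the diff; the names, element type, and width below are illustrative, chosen to match the 128-bit vectors used by the kernel further down):

// Hypothetical example (illustrative only): widen a 128-bit vector of bf16
// (8 elements) to fp32, scale it, then narrow back for the store.
using ExampleInputElem = cutlass::Array<cutlass::bfloat16_t, 8>;
using ExampleComputeElem = cutlass::Array<float, 8>;

__device__ inline ExampleInputElem scaleExample(ExampleInputElem const& in, float scale)
{
    ExampleComputeElem acc = arrayConvert<ExampleInputElem, ExampleComputeElem>(in); // bf16 -> fp32
    acc = scale * acc;                                                               // compute in fp32
    return arrayConvert<ExampleComputeElem, ExampleInputElem>(acc);                  // fp32 -> bf16
}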
@@ -518,6 +529,83 @@ __global__ void finalizeKernel(KernelParams params)
}
}
constexpr static int FINALIZE_THREADS_PER_BLOCK = 256;
__device__ float4 vectorizedLoadPtx(float4 const* ptr)
{
    float4 ret;
    asm volatile("ld.global.v4.f32 {%0, %1, %2, %3}, [%4];"
                 : "=f"(ret.x), "=f"(ret.y), "=f"(ret.z), "=f"(ret.w)
                 : "l"(ptr));
    return ret;
}
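For comparison, a plain C++ formulation of the same load is sketched below (not from the commit). The inline PTX above guarantees a single 128-bit ld.global.v4.f32, whereas the plain form leaves the vector width to the compiler.

// Sketch only: equivalent load without inline PTX. The compiler will usually emit a
// 128-bit LDG for an aligned float4 dereference, but the PTX wrapper makes it explicit.
__device__ inline float4 vectorizedLoadPlain(float4 const* ptr)
{
    return *ptr;
}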
// Final kernel to unpermute and scale, using vectorized 128-bit loads
// This kernel unpermutes the original data, does the k-way reduction and performs the final skip connection.
template <typename KernelParams>
__global__ void finalizeKernelVecLoad(KernelParams params)
{
    using Type = typename KernelParams::Type;
    using TypeExpW = typename KernelParams::TypeExpW;

    int const hiddenDimBits = params.hiddenDim * cutlass::sizeof_bits<Type>::value;
    assert(hiddenDimBits % 128 == 0);

    // Load 128 bits per thread, according to the smallest data type we read/write
    constexpr int64_t FINALIZE_ELEM_PER_THREAD = 128 / cutlass::sizeof_bits<Type>::value;

    using InputElem = cutlass::Array<Type, FINALIZE_ELEM_PER_THREAD>;
    using OutputElem = cutlass::Array<Type, FINALIZE_ELEM_PER_THREAD>;
    using ComputeElem = cutlass::Array<float, FINALIZE_ELEM_PER_THREAD>;

    int64_t const tokenIdx = blockIdx.x;
    int64_t const startOffset = threadIdx.x;
    int64_t const stride = FINALIZE_THREADS_PER_BLOCK;
    int64_t const numElemsInCol = params.hiddenDim / FINALIZE_ELEM_PER_THREAD;

    auto const offset = tokenIdx * params.hiddenDim;
    Type* outputPtr = params.outPtr + offset;
    auto* outElemPtr = reinterpret_cast<OutputElem*>(outputPtr);
    auto const* inElemPtr = reinterpret_cast<InputElem const*>(params.inPtr);

#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
    // wait on primary kernel when using PDL
    if constexpr (KernelParams::UsePdl)
    {
        cudaGridDependencySynchronize();
    }
#endif

    for (int elemIndex = startOffset; elemIndex < numElemsInCol; elemIndex += stride)
    {
        ComputeElem threadOutput;
        threadOutput.fill(0);
        for (int k = 0; k < params.topK; ++k)
        {
            int const expandedIdx = tokenIdx * params.topK + k;
            int const permutedIdx = params.expandedIdxToPermutedIdx[expandedIdx];
            if (permutedIdx == -1)
            {
                continue;
            }

            float const scale
                = (params.expertWeightsPtr != nullptr) ? static_cast<float>(params.expertWeightsPtr[expandedIdx]) : 1.f;

            auto const* inputPermutedPtr = inElemPtr + permutedIdx * numElemsInCol;
            float4 input = vectorizedLoadPtx(reinterpret_cast<float4 const*>(&inputPermutedPtr[elemIndex]));
            InputElem inputPermutedElem = *reinterpret_cast<InputElem const*>(&input);
            ComputeElem expertResult = arrayConvert<InputElem, ComputeElem>(inputPermutedElem);
            threadOutput = threadOutput + scale * expertResult;
        }

        OutputElem outputElem = arrayConvert<ComputeElem, OutputElem>(threadOutput);
        outElemPtr[elemIndex] = outputElem;
    }
}
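To make the vector-width arithmetic concrete, here is a small compile-time sketch with illustrative numbers (a 16-bit element type and a hidden dimension of 7168; neither value comes from the commit):

// Illustrative only: with 16-bit elements each thread handles 8 elements (128 bits) per load,
// so a 7168-wide hidden dimension becomes 896 vector elements per token, which the
// 256-thread block covers in 4 strided iterations of the loop above.
constexpr int kExampleElemBits = 16;                                                  // e.g. bf16
constexpr int kExampleElemsPerThread = 128 / kExampleElemBits;                        // = 8
constexpr int kExampleHiddenDim = 7168;                                               // example hidden size
constexpr int kExampleVecElemsPerToken = kExampleHiddenDim / kExampleElemsPerThread;  // = 896
constexpr int kExampleIters
    = (kExampleVecElemsPerToken + FINALIZE_THREADS_PER_BLOCK - 1) / FINALIZE_THREADS_PER_BLOCK; // = 4
static_assert(kExampleIters == 4, "illustrative arithmetic check");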
////////////////////////////////////////////////////////////////////////////////////////////////////
template <typename KernelParams>
@@ -552,7 +640,9 @@ __global__ void finalizeDeepSeekKernel(KernelParams params)
int const expandedIdx = tokenIdx * params.topK + k;
int const permutedIdx = params.expandedIdxToPermutedIdx[expandedIdx];
if (permutedIdx == -1)
{
continue;
}
int const totalNumPaddedTokens = params.totalNumPaddedTokens[0];
int const scaleIdx = permutedIdx + totalNumPaddedTokens * (hiddenIdx / 128);
float const blockScale = params.inDqSfsPtr ? params.inDqSfsPtr[scaleIdx] : 1;
@@ -591,7 +681,6 @@ __global__ void finalizeDeepSeekKernel(KernelParams params)
}
////////////////////////////////////////////////////////////////////////////////////////////////////
void run(Data const& data, void* stream)
{
if (data.mUseDeepSeekFp8)
@@ -610,9 +699,22 @@ void run(Data const& data, void* stream)
    int const numBlocksX = (data.hiddenDim - 1 + numThreads) / numThreads;
    // Capped at a rather arbitrary 8192 to keep gridDim.y below the CUDA limit of 65535.
    int const numBlocksY = std::min(8192, data.numTokens);
    if (numBlocksX * numBlocksY < 1184)
    {
        // The threshold 1184 = 148 * 8: 148 is the number of SMs (Streaming Multiprocessors) on the Blackwell
        // architecture, and 8 is the maximum number of blocks of this kernel that each SM can hold.
        // If the grid is smaller than one full wave, keep the original kernel; once more than one wave would be
        // needed, switch to the kernel with vectorized loading, which launches one block per token.
        dim3 numBlocks(numBlocksX, numBlocksY);
        LAUNCH_EXPW(data, finalizeKernel, numBlocks, numThreads, 0, stream);
    }
    else
    {
        LAUNCH_EXPW(data, finalizeKernelVecLoad, /*numBlocks=*/data.numTokens,
            /*numThreads=*/FINALIZE_THREADS_PER_BLOCK, 0, stream);
    }
}
}
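As a hedged restatement of the dispatch heuristic above (the helper name is illustrative; the 148-SM and 8-blocks-per-SM figures are the ones cited in the comment):

#include <algorithm>  // std::min; for this standalone sketch only

// Illustrative helper (not part of the commit): true when the grid of the original finalize
// kernel would span at least one full wave, i.e. when finalizeKernelVecLoad is launched instead.
inline bool wouldUseVecLoadFinalize(int hiddenDim, int numTokens)
{
    int const numThreads = 128;
    int const numBlocksX = (hiddenDim + numThreads - 1) / numThreads;
    int const numBlocksY = std::min(8192, numTokens);
    return numBlocksX * numBlocksY >= 148 * 8; // 148 SMs * 8 blocks/SM = 1184 (one full wave)
}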