Feat: Add vectorized loading for finalize kernel in MoE Trtllm backend (#5919)
Signed-off-by: Christina Zhang <83400082+ChristinaZ@users.noreply.github.com>
This commit is contained in:
parent 4c364b9a73
commit 7e033c392e
@@ -16,10 +16,21 @@
#include "DevKernel.h"

#include "cutlass/array.h"
#include "cutlass/numeric_conversion.h"
#include <cub/cub.cuh>
#include <cutlass/cutlass.h>
#include <cutlass/numeric_types.h>

#include <cub/cub.cuh>

////////////////////////////////////////////////////////////////////////////////////////////////////

// Helper function for array conversion
template <class T, class U>
__host__ __device__ constexpr static U arrayConvert(T const& input)
{
    cutlass::NumericArrayConverter<typename U::Element, typename T::Element, U::kElements> converter;
    return converter(input);
}

////////////////////////////////////////////////////////////////////////////////////////////////////
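As a quick illustration of the new helper: a hypothetical caller could use arrayConvert to widen packed low-precision values to float before accumulating. The half_t element type and the width of 8 below are assumptions for the sketch, not part of the diff; the kernels in this change instantiate it with their own InputElem/ComputeElem/OutputElem types.

// Hypothetical usage sketch (illustrative only, not part of this change).
using HalfVec = cutlass::Array<cutlass::half_t, 8>;
using FloatVec = cutlass::Array<float, 8>;

__device__ FloatVec widenForAccumulation(HalfVec const& in)
{
    // NumericArrayConverter performs the element-wise cast entirely in registers.
    return arrayConvert<HalfVec, FloatVec>(in);
}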
@@ -518,6 +529,83 @@ __global__ void finalizeKernel(KernelParams params)
    }
}

constexpr static int FINALIZE_THREADS_PER_BLOCK = 256;

// Forces a single 128-bit (v4.f32) global load via inline PTX.
__device__ float4 vectorizedLoadPtx(float4 const* ptr)
{
    float4 ret;
    asm volatile("ld.global.v4.f32 {%0, %1, %2, %3}, [%4];"
                 : "=f"(ret.x), "=f"(ret.y), "=f"(ret.z), "=f"(ret.w)
                 : "l"(ptr));
    return ret;
}

// Final kernel to unpermute and scale.
// This kernel unpermutes the original data, does the k-way reduction and performs the final skip connection.
template <typename KernelParams>
__global__ void finalizeKernelVecLoad(KernelParams params)
{
    using Type = typename KernelParams::Type;
    using TypeExpW = typename KernelParams::TypeExpW;

    int const hiddenDimBits = params.hiddenDim * cutlass::sizeof_bits<Type>::value;
    assert(hiddenDimBits % 128 == 0);

    // Load 128 bits per thread, according to the smallest data type we read/write.
    constexpr int64_t FINALIZE_ELEM_PER_THREAD = 128 / cutlass::sizeof_bits<Type>::value;
    using InputElem = cutlass::Array<Type, FINALIZE_ELEM_PER_THREAD>;
    using OutputElem = cutlass::Array<Type, FINALIZE_ELEM_PER_THREAD>;
    using ComputeElem = cutlass::Array<float, FINALIZE_ELEM_PER_THREAD>;

    int64_t const tokenIdx = blockIdx.x;
    int64_t const startOffset = threadIdx.x;
    int64_t const stride = FINALIZE_THREADS_PER_BLOCK;
    int64_t const numElemsInCol = params.hiddenDim / FINALIZE_ELEM_PER_THREAD;

    auto const offset = tokenIdx * params.hiddenDim;
    Type* outputPtr = params.outPtr + offset;
    auto* outElemPtr = reinterpret_cast<OutputElem*>(outputPtr);
    auto const* inElemPtr = reinterpret_cast<InputElem const*>(params.inPtr);

#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
    // Wait on the primary kernel when using PDL.
    if constexpr (KernelParams::UsePdl)
    {
        cudaGridDependencySynchronize();
    }
#endif

    for (int elemIndex = startOffset; elemIndex < numElemsInCol; elemIndex += stride)
    {
        ComputeElem threadOutput;
        threadOutput.fill(0);
        for (int k = 0; k < params.topK; ++k)
        {
            int const expandedIdx = tokenIdx * params.topK + k;
            int const permutedIdx = params.expandedIdxToPermutedIdx[expandedIdx];
            if (permutedIdx == -1)
            {
                continue;
            }

            float const scale
                = (params.expertWeightsPtr != nullptr) ? static_cast<float>(params.expertWeightsPtr[expandedIdx]) : 1.f;

            auto const* inputPermutedPtr = inElemPtr + permutedIdx * numElemsInCol;

            float4 input = vectorizedLoadPtx(reinterpret_cast<float4 const*>(&inputPermutedPtr[elemIndex]));
            InputElem inputPermutedElem = *reinterpret_cast<InputElem const*>(&input);
            ComputeElem expertResult = arrayConvert<InputElem, ComputeElem>(inputPermutedElem);

            threadOutput = threadOutput + scale * expertResult;
        }

        OutputElem outputElem = arrayConvert<ComputeElem, OutputElem>(threadOutput);
        outElemPtr[elemIndex] = outputElem;
    }
}

////////////////////////////////////////////////////////////////////////////////////////////////////

template <typename KernelParams>
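For a 16-bit activation type such as bfloat16, FINALIZE_ELEM_PER_THREAD is 128 / 16 = 8, so each iteration of the elemIndex loop moves 8 elements with a single 128-bit load, and the 256 threads of a block stride over hiddenDim / 8 such vectors. The same 128-bit access can often be written without inline PTX by dereferencing an aligned float4 pointer; the sketch below is a minimal, self-contained illustration of that idea (kernel name and buffer layout are assumptions, not part of this change), whereas the kernel above uses ld.global.v4.f32 directly to pin the load width irrespective of what the compiler can prove about alignment and aliasing.

// Illustrative only: stream one row with a 16-byte load/store per iteration.
// Assumes src and dst are 16-byte aligned and numFloats is a multiple of 4.
__global__ void copyRowVec4(float const* src, float* dst, int numFloats)
{
    int const numVecs = numFloats / 4;
    for (int i = threadIdx.x; i < numVecs; i += blockDim.x)
    {
        // With the alignment guaranteed, this typically compiles to one LDG.128 and one STG.128.
        float4 v = reinterpret_cast<float4 const*>(src)[i];
        reinterpret_cast<float4*>(dst)[i] = v;
    }
}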
@@ -552,7 +640,9 @@ __global__ void finalizeDeepSeekKernel(KernelParams params)
        int const expandedIdx = tokenIdx * params.topK + k;
        int const permutedIdx = params.expandedIdxToPermutedIdx[expandedIdx];
        if (permutedIdx == -1)
        {
            continue;
        }
        int const totalNumPaddedTokens = params.totalNumPaddedTokens[0];
        int const scaleIdx = permutedIdx + totalNumPaddedTokens * (hiddenIdx / 128);
        float const blockScale = params.inDqSfsPtr ? params.inDqSfsPtr[scaleIdx] : 1;
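To make the block-scale indexing concrete (the numbers are illustrative only): the formula above stores one scale per padded token for every 128-element block of the hidden dimension, with the block index as the outer stride. With totalNumPaddedTokens = 1024, permutedIdx = 5 and hiddenIdx = 300, the lookup reads inDqSfsPtr[5 + 1024 * (300 / 128)] = inDqSfsPtr[5 + 1024 * 2] = inDqSfsPtr[2053].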
@@ -591,7 +681,6 @@ __global__ void finalizeDeepSeekKernel(KernelParams params)
}

////////////////////////////////////////////////////////////////////////////////////////////////////

void run(Data const& data, void* stream)
{
    if (data.mUseDeepSeekFp8)
@@ -610,9 +699,22 @@ void run(Data const& data, void* stream)
        int const numBlocksX = (data.hiddenDim - 1 + numThreads) / numThreads;
        // Capped at a rather arbitrary 8192 to avoid gridDim exceeding the CUDA limit of 65535.
        int const numBlocksY = std::min(8192, data.numTokens);
        dim3 numBlocks(numBlocksX, numBlocksY);

        LAUNCH_EXPW(data, finalizeKernel, numBlocks, numThreads, 0, stream);
        if (numBlocksX * numBlocksY < 1184)
        {
            // The number 1184 comes from 148 * 8: 148 is the number of SMs (Streaming Multiprocessors) on the
            // Blackwell architecture, and 8 is the maximum number of blocks of this kernel that each SM can hold.
            // The cap ensures that whenever the launch would take more than one wave, we switch to the kernel
            // with vectorized loading instead.
            dim3 numBlocks(numBlocksX, numBlocksY);
            LAUNCH_EXPW(data, finalizeKernel, numBlocks, numThreads, 0, stream);
        }
        else
        {
            LAUNCH_EXPW(data, finalizeKernelVecLoad, /*numBlocks=*/data.numTokens,
                /*numThreads=*/FINALIZE_THREADS_PER_BLOCK, 0, stream);
        }
    }
}
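As a worked example of the dispatch heuristic (the concrete sizes here are assumptions for illustration, since numThreads and the model dimensions are defined outside this hunk): with hiddenDim = 7168 and numThreads = 1024, numBlocksX = (7168 - 1 + 1024) / 1024 = 7; with numTokens = 512, numBlocksY = 512, so the original grid would need 7 * 512 = 3584 blocks. That exceeds 1184 (about three waves on a 148-SM GPU at 8 blocks per SM), so run() launches finalizeKernelVecLoad instead with one block per token, i.e. 512 blocks of FINALIZE_THREADS_PER_BLOCK = 256 threads.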