/*
 * Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "mnnvlAllreduceKernels.h"
|
|
#include "tensorrt_llm/common/config.h"
|
|
#include <cooperative_groups.h>
|
|
#include <cstddef>
|
|
#include <cstdint>
|
|
#include <cuda/atomic>
|
|
#include <cuda_bf16.h>
|
|
#include <cuda_pipeline.h>
|
|
#include <tuple>
|
|
#include <type_traits>
|
|
|
|
#include "tensorrt_llm/common/cudaTypeUtils.cuh"
|
|
#include "tensorrt_llm/common/cudaUtils.h"
|
|
#include "tensorrt_llm/common/dataType.h"
|
|
#include "tensorrt_llm/common/envUtils.h"
|
|
#include "tensorrt_llm/common/lamportUtils.cuh"
|
|
#include "tensorrt_llm/common/logger.h"
|
|
#include "tensorrt_llm/common/reduceKernelUtils.cuh"
|
|
|
|
TRTLLM_NAMESPACE_BEGIN

namespace kernels::mnnvl
{

using tensorrt_llm::common::isNegZero;
using tensorrt_llm::common::LamportFlags;
using tensorrt_llm::common::cuda_cast;
using tensorrt_llm::common::getMultiProcessorCount;
using tensorrt_llm::common::getDTypeSize;

// Guard the helper functions used by the kernels in this file.
namespace detail
{
template <typename PackedType, typename T>
union PackedVec
{
    PackedType packed;
    T elements[sizeof(PackedType) / sizeof(T)];

    __device__ PackedVec& operator+=(PackedVec& other)
    {
#pragma unroll
        for (int i = 0; i < sizeof(PackedType) / sizeof(T); i++)
        {
            elements[i] += other.elements[i];
        }
        return *this;
    }

    __device__ PackedVec operator+(PackedVec& other)
    {
        PackedVec result;
#pragma unroll
        for (int i = 0; i < sizeof(PackedType) / sizeof(T); i++)
        {
            result.elements[i] = elements[i] + other.elements[i];
        }
        return result;
    }
};

template <typename PackedType, typename T>
inline __device__ PackedType loadPacked(T* ptr)
{
    return *reinterpret_cast<PackedType*>(ptr);
}

template <typename PackedType, typename T>
inline __device__ const PackedType loadPacked(T const* ptr)
{
    return *reinterpret_cast<PackedType const*>(ptr);
}

template <typename PackedType>
inline __device__ PackedType loadPackedVolatile(void const* ptr)
{
    static_assert(sizeof(PackedType) == 0, "Not implemented");
    return PackedType{};
}

template <>
inline __device__ float4 loadPackedVolatile<float4>(void const* ptr)
{
    float4 returnValue;
    asm volatile("ld.volatile.global.v4.f32 {%0, %1, %2, %3}, [%4];\n"
                 : "=f"(returnValue.x), "=f"(returnValue.y), "=f"(returnValue.z), "=f"(returnValue.w)
                 : "l"(ptr));
    return returnValue;
}

template <>
inline __device__ float2 loadPackedVolatile<float2>(void const* ptr)
{
    float2 returnValue;
    asm volatile("ld.volatile.global.v2.f32 {%0, %1}, [%2];\n" : "=f"(returnValue.x), "=f"(returnValue.y) : "l"(ptr));
    return returnValue;
}

template <typename T_IN>
inline __device__ void copyF4(T_IN* dst, T_IN const* src)
{
    float4* dst4 = reinterpret_cast<float4*>(dst);
    float4 const* src4 = reinterpret_cast<float4 const*>(src);
    __pipeline_memcpy_async(dst4, src4, sizeof(float4));
}

uint32_t constexpr kWARP_SIZE = 32U;
uint32_t constexpr kLOG2_WARP_SIZE = 5U;
uint32_t constexpr kLANE_ID_MASK = 0x1f;

template <typename T>
inline __device__ T warpReduceSumPartial(T val)
{
    int laneId = threadIdx.x & kLANE_ID_MASK;
    // We make sure that only the last (possibly partial) warp of the block calls this function.
    int warpSize = blockDim.x - (threadIdx.x & ~(kWARP_SIZE - 1));
    unsigned int active_mask = (1U << warpSize) - 1;

#pragma unroll
    for (int mask = 16; mask > 0; mask >>= 1)
    {
        int targetLane = laneId ^ mask;
        auto tmp = __shfl_xor_sync(active_mask, val, mask, kWARP_SIZE);
        val += targetLane < warpSize ? tmp : 0;
    }
    return val;
}

// SYNC:
// - true: the sum is shared across all threads.
// - false: only thread 0 gets the sum; the other threads' values are undefined.
template <typename T, bool SYNC = false>
inline __device__ T blockReduceSumPartial(T val)
{
    __shared__ T smem[kWARP_SIZE];
    int laneId = threadIdx.x & kLANE_ID_MASK;
    int warpId = threadIdx.x >> kLOG2_WARP_SIZE;
    int warpNum = (blockDim.x + kWARP_SIZE - 1) >> kLOG2_WARP_SIZE; // Ceiling division to include partial warps

    val = (warpId == warpNum - 1) ? warpReduceSumPartial(val) : tensorrt_llm::common::warpReduceSum(val);
    if (laneId == 0)
    {
        smem[warpId] = val;
    }
    __syncthreads();

    if (warpId == 0)
    {
        val = (laneId < warpNum) ? smem[laneId] : (T) 0.f;
        // Need to consider the corner case where we only have one warp and it is partial
        val = (warpNum == 1) ? warpReduceSumPartial(val) : tensorrt_llm::common::warpReduceSum(val);

        if constexpr (SYNC)
        {
            if (laneId == 0)
            {
                smem[warpId] = val;
            }
        }
    }
    if constexpr (SYNC)
    {
        __syncthreads();
        val = smem[0];
    }
    return val;
}

// blockReduceSum in reduceKernelUtils.cuh returns the result only on warp 0,
// so we need a separate implementation here where all threads get the result.
template <typename T>
inline __device__ T blockReduceSumFull(T val)
{
    __shared__ T smem[kWARP_SIZE];
    int lane_id = threadIdx.x & kLANE_ID_MASK;
    int warp_id = threadIdx.x >> kLOG2_WARP_SIZE;
    int warp_num = blockDim.x >> kLOG2_WARP_SIZE;

    val = tensorrt_llm::common::warpReduceSum(val);
    if (lane_id == 0)
    {
        smem[warp_id] = val;
    }
    __syncthreads();

    val = (lane_id < warp_num) ? smem[lane_id] : (T) 0.f;
    val = tensorrt_llm::common::warpReduceSum(val);

    return val;
}

template <typename T, bool SYNC = false>
inline __device__ T blockReduceSum(T val)
{
    bool hasPartialWarp = (blockDim.x & kLANE_ID_MASK) != 0;
    if (hasPartialWarp)
    {
        return blockReduceSumPartial<T, SYNC>(val);
    }
    else
    {
        return blockReduceSumFull<T>(val);
    }
}

// We have to define this again since the one in mathUtils.h is shadowed by the one from cudaUtils.h, which is a
// host-only function!
template <typename T>
inline __device__ __host__ T divUp(T m, T n)
{
    return (m + n - 1) / n;
}

// A helper function to tune the grid configuration for the fused oneshot and rmsnorm kernels.
// Returns (block_size, cluster_size, loads_per_thread).
std::tuple<int, int, int> adjustGridConfig(int numTokens, int dim, int eltsPerThread)
{
    // Start with the preferred block_size and cluster_size
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
    int clusterSize = 8;
#else
    int clusterSize = 1;
#endif
    int blockSize = 128;
    // ========================== Adjust the grid configuration ==========================
    int threadsNeeded = divUp(dim, eltsPerThread);
    int loadsPerThread = 1;

    blockSize = divUp(threadsNeeded, clusterSize);
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
    while (threadsNeeded % clusterSize != 0 && clusterSize > 1)
    {
        clusterSize /= 2;
    }
    blockSize = divUp(threadsNeeded, clusterSize);
    while (blockSize < 128 && clusterSize >= 2)
    {
        blockSize *= 2;
        clusterSize /= 2;
    }
    int smCount = getMultiProcessorCount();
    while (numTokens * clusterSize > smCount && clusterSize > 1 && blockSize <= 512)
    {
        blockSize *= 2;
        clusterSize /= 2;
    }
#endif

    // Try to scale up using multiple loads per thread or CGA
    while (blockSize > 1024)
    {
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
        if (clusterSize < 8)
        {
            clusterSize = clusterSize << 1;
        }
        else
        {
            break;
        }
#else
        if (loadsPerThread < 8)
        {
            loadsPerThread += 1;
        }
        else
        {
            break;
        }
#endif
        blockSize = divUp(threadsNeeded, clusterSize * loadsPerThread);
    }
    return {blockSize, clusterSize, loadsPerThread};
}

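// Worked example (illustrative only, following the non-CGA branch above): for dim = 4096 with a 16-bit dtype,
// eltsPerThread = 8, so threadsNeeded = 512; with clusterSize = 1 the block size settles at 512 and loadsPerThread
// stays 1, i.e. the function returns (512, 1, 1).
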
} // namespace detail

using detail::PackedVec;
using detail::loadPacked;
using detail::loadPackedVolatile;
using detail::blockReduceSum;
using detail::divUp;
using detail::copyF4;

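// One-shot allreduce (optionally fused with residual add + RMSNorm): every rank pushes its shard of each token
// through the multicast pointer into a [token][rank][dim] staging buffer, then polls its own unicast view of that
// buffer until all WorldSize copies have landed (Lamport-style: -0.0f means "not written yet"), reduces them in
// fp32, and writes the result (optionally normalized) to the output.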
template <uint8_t WorldSize, typename T, bool RMSNormFusion = false, typename PackedType = float4>
__global__ void __launch_bounds__(1024) oneshotAllreduceFusionKernel(T* outputPtr, T* prenormedPtr, T const* shardPtr,
    T const* residualInPtr, T const* gammaPtr, T** inputPtrs, T* mcastPtr, int const numTokens, int const tokenDim,
    float epsilon, int const rank, uint32_t* bufferFlags)
{
    constexpr int kELTS_PER_THREAD = sizeof(PackedType) / sizeof(T);
    constexpr int kLAMPORT_ELTS_PER_PACKED = sizeof(PackedType) / sizeof(float);
    constexpr uint32_t kELT_SIZE = sizeof(T);
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
    namespace cg = cooperative_groups;
    cg::cluster_group cluster = cg::this_cluster();
    int packedIdx = cluster.thread_rank();
    int token = blockIdx.x;
    int threadOffset = token * tokenDim + packedIdx * kELTS_PER_THREAD;

    cudaGridDependencySynchronize();
#else
    int packedIdx = blockIdx.y * blockDim.x + threadIdx.x;
    int token = blockIdx.x;
    // Offset w.r.t. the input shard
    int threadOffset = token * tokenDim + packedIdx * kELTS_PER_THREAD;
#endif

    // We only use 1 stage for the oneshot allreduce
    LamportFlags<PackedType> flag(bufferFlags, 1);
    T* stagePtrMcast = reinterpret_cast<T*>(flag.getCurLamportBuf(mcastPtr, 0));
    T* stagePtrLocal = reinterpret_cast<T*>(flag.getCurLamportBuf(inputPtrs[rank], 0));

    if (packedIdx * kELTS_PER_THREAD >= tokenDim)
    {
        flag.clearDirtyLamportBuf(inputPtrs[rank], -1);
        return;
    }

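    // Lamport convention used below: -0.0f marks "not yet written". Real negative zeros in the input are flushed to
    // +0.0f before being published, so a reader that sees anything other than -0.0f knows the packet has landed.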
    // ==================== Broadcast tokens to each rank =============================
    PackedVec<PackedType, T> val;
    val.packed = loadPacked<PackedType>(&shardPtr[threadOffset]);
#pragma unroll
    for (int i = 0; i < kELTS_PER_THREAD; i++)
    {
        if (isNegZero(val.elements[i]))
            val.elements[i] = cuda_cast<T, float>(0.f);
    }

    reinterpret_cast<PackedType*>(&stagePtrMcast[token * tokenDim * WorldSize + rank * tokenDim])[packedIdx]
        = val.packed;

    flag.ctaArrive();
    // ========== Lamport sync and clear the output buffer from the previous iteration ==========
    flag.clearDirtyLamportBuf(inputPtrs[rank], -1);

    PackedVec<PackedType, float> valuesLamport[WorldSize];
    while (1)
    {
        bool valid = true;
#pragma unroll
        for (int r = 0; r < WorldSize; r++)
        {
            valuesLamport[r].packed = loadPackedVolatile<PackedType>(
                &stagePtrLocal[token * tokenDim * WorldSize + r * tokenDim + packedIdx * kELTS_PER_THREAD]);

#pragma unroll
            for (int i = 0; i < kLAMPORT_ELTS_PER_PACKED; i++)
            {
                valid &= !isNegZero(valuesLamport[r].elements[i]);
            }
        }
        if (valid)
        {
            break;
        }
    }

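    // All WorldSize shards of this token have arrived; reinterpret the raw packets as T and reduce in fp32.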
    auto values = reinterpret_cast<PackedVec<PackedType, T>*>(valuesLamport);
    // ======================= Reduction =============================
    float accum[kELTS_PER_THREAD];
    PackedVec<PackedType, T> packedAccum;

#pragma unroll
    for (int i = 0; i < kELTS_PER_THREAD; i++)
    {
        accum[i] = cuda_cast<float, T>(values[0].elements[i]);
    }

#pragma unroll
    for (int r = 1; r < WorldSize; r++)
    {
#pragma unroll
        for (int i = 0; i < kELTS_PER_THREAD; i++)
        {
            accum[i] += cuda_cast<float, T>(values[r].elements[i]);
        }
    }

#pragma unroll
    for (int i = 0; i < kELTS_PER_THREAD; i++)
    {
        packedAccum.elements[i] = cuda_cast<T, float>(accum[i]);
    }
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
    cudaTriggerProgrammaticLaunchCompletion();
#endif
    if constexpr (RMSNormFusion)
    {
        // =============================== Residual ===============================
        PackedVec<PackedType, T> residualIn;
        residualIn.packed = *reinterpret_cast<PackedType const*>(&residualInPtr[threadOffset]);
        packedAccum += residualIn;
        *reinterpret_cast<PackedType*>(&prenormedPtr[threadOffset]) = packedAccum.packed;
        // =============================== Rmsnorm ================================
        PackedVec<PackedType, T> gamma;
        gamma.packed = *reinterpret_cast<PackedType const*>(&gammaPtr[packedIdx * kELTS_PER_THREAD]);

        float threadSum = 0.F;
#pragma unroll
        for (int i = 0; i < kELTS_PER_THREAD; i++)
        {
            // FIXME: Square in float if accuracy becomes an issue
            threadSum += cuda_cast<float, T>(packedAccum.elements[i] * packedAccum.elements[i]);
        }
        float blockSum = blockReduceSum<float, true>(threadSum);

        __shared__ float sharedVal[8]; // Temporary buffer to share the per-block sums within the cluster
        float fullSum = blockSum;
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
        namespace cg = cooperative_groups;
        cg::cluster_group cluster = cg::this_cluster();
        int const numBlocks = cluster.num_blocks();
        if (numBlocks > 1)
        {
            fullSum = 0.F;
            // Need to reduce over the entire cluster
            int const blockRank = cluster.block_rank();
            if (threadIdx.x < numBlocks)
            {
                cluster.map_shared_rank(&sharedVal[0], threadIdx.x)[blockRank] = blockSum;
            }
            // cluster.sync();
            cluster.barrier_wait(cluster.barrier_arrive());
            for (int i = 0; i < numBlocks; ++i)
            {
                fullSum += sharedVal[i];
            }
        }
#endif
        float rcpRms = rsqrtf(fullSum / tokenDim + epsilon);
#pragma unroll
        for (int i = 0; i < kELTS_PER_THREAD; i++)
        {
            packedAccum.elements[i] = cuda_cast<T, float>(
                cuda_cast<float, T>(packedAccum.elements[i]) * rcpRms * cuda_cast<float, T>(gamma.elements[i]));
        }
    }
    reinterpret_cast<PackedType*>(&outputPtr[threadOffset])[0] = packedAccum.packed;
    flag.waitAndUpdate({static_cast<uint32_t>(numTokens * tokenDim * WorldSize * kELT_SIZE), 0, 0, 0});
}

using detail::adjustGridConfig;

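// Illustrative host-side usage (a minimal sketch, not taken from this file; it assumes the MNNVL unicast/multicast
// buffers and the Lamport flag array have already been set up elsewhere and remain valid for the call):
//
//   AllReduceFusionParams params{};
//   params.nRanks = 8;
//   params.rank = myRank;                       // placeholder variables, for illustration only
//   params.dType = nvinfer1::DataType::kBF16;
//   params.numTokens = numTokens;
//   params.tokenDim = hiddenDim;
//   params.bufferPtrsDev = ucBufferPtrsDev;     // device array of per-rank unicast buffer pointers
//   params.multicastPtr = mcBufferPtr;
//   params.bufferFlags = lamportFlagsDev;
//   params.rmsNormFusion = true;
//   params.epsilon = 1e-5;
//   params.input = shardIn;
//   params.residualIn = residual;
//   params.gamma = gammaWeights;
//   params.output = out;
//   params.residualOut = preNormOut;
//   params.stream = stream;
//   oneshotAllreduceFusionOp(params);
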
void oneshotAllreduceFusionOp(AllReduceFusionParams const& params)
{
    int const numTokens = params.numTokens;
    int const tokenDim = params.tokenDim;
    int const eltsPerThread = sizeof(float4) / getDTypeSize(params.dType);

    auto [blockSize, clusterSize, loadsPerThread] = adjustGridConfig(numTokens, tokenDim, eltsPerThread);
    dim3 grid(numTokens, clusterSize, 1);

    TLLM_CHECK_WITH_INFO(blockSize <= 1024 && loadsPerThread == 1,
        "Hidden Dimension %d exceeds the maximum supported hidden dimension (%d)", tokenDim,
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
        1024 * 8 * eltsPerThread);
#else
        1024 * eltsPerThread);
#endif

    TLLM_LOG_DEBUG(
        "[MNNVL AllReduceOneShot] Dispatch: grid size: (%d, %d, 1), block_size: %d, cluster_size: %d, "
        "loads_per_thread: %d, "
        "threads_needed: %d",
        numTokens, clusterSize, blockSize, clusterSize, loadsPerThread, divUp(tokenDim, eltsPerThread));

    cudaLaunchAttribute attrs[2];
    attrs[0].id = cudaLaunchAttributeProgrammaticStreamSerialization;
    attrs[0].val.programmaticStreamSerializationAllowed = tensorrt_llm::common::getEnvEnablePDL() ? 1 : 0;
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
    attrs[1].id = cudaLaunchAttributeClusterDimension;
    attrs[1].val.clusterDim.x = 1;
    attrs[1].val.clusterDim.y = clusterSize;
    attrs[1].val.clusterDim.z = 1;
#endif

    cudaLaunchConfig_t config
    {
        .gridDim = grid, .blockDim = blockSize, .dynamicSmemBytes = 0, .stream = params.stream, .attrs = attrs,
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
        .numAttrs = 2,
#else
        .numAttrs = 1,
#endif
    };

#define LAUNCH_ALLREDUCE_KERNEL(WORLD_SIZE, T, RMSNORM) \
    TLLM_CUDA_CHECK(cudaLaunchKernelEx(&config, &oneshotAllreduceFusionKernel<WORLD_SIZE, T, RMSNORM>, output, \
        residualOut, input, residualIn, gamma, ucPtrs, mcPtr, numTokens, tokenDim, static_cast<float>(params.epsilon), \
        params.rank, params.bufferFlags));
#define DISPATCH_ALLREDUCE_KERNEL(WORLD_SIZE, T) \
    if (params.rmsNormFusion) \
    { \
        LAUNCH_ALLREDUCE_KERNEL(WORLD_SIZE, T, true); \
    } \
    else \
    { \
        LAUNCH_ALLREDUCE_KERNEL(WORLD_SIZE, T, false); \
    }
    // C++17-compatible type dispatch using a generic lambda
    auto dispatchImpl = [&](auto* type_ptr) -> bool
    {
        using T = std::remove_pointer_t<decltype(type_ptr)>;
        T** ucPtrs = reinterpret_cast<T**>(params.bufferPtrsDev);
        T* mcPtr = reinterpret_cast<T*>(params.multicastPtr);
        T* output = reinterpret_cast<T*>(params.output);
        T* residualOut = reinterpret_cast<T*>(params.residualOut);
        T const* input = reinterpret_cast<T const*>(params.input);
        T const* residualIn = reinterpret_cast<T const*>(params.residualIn);
        T const* gamma = reinterpret_cast<T const*>(params.gamma);

        switch (params.nRanks)
        {
        // FIXME: Do we need other world sizes?
        case 2: DISPATCH_ALLREDUCE_KERNEL(2, T); return true;
        case 4: DISPATCH_ALLREDUCE_KERNEL(4, T); return true;
        case 8: DISPATCH_ALLREDUCE_KERNEL(8, T); return true;
        case 16: DISPATCH_ALLREDUCE_KERNEL(16, T); return true;
        case 32: DISPATCH_ALLREDUCE_KERNEL(32, T); return true;
        case 64: DISPATCH_ALLREDUCE_KERNEL(64, T); return true;
        }
        return false;
    };
#undef LAUNCH_ALLREDUCE_KERNEL
#undef DISPATCH_ALLREDUCE_KERNEL
    bool launched = (params.dType == nvinfer1::DataType::kBF16 && dispatchImpl((__nv_bfloat16*) nullptr))
        || (params.dType == nvinfer1::DataType::kFLOAT && dispatchImpl((float*) nullptr))
        || (params.dType == nvinfer1::DataType::kHALF && dispatchImpl((__nv_half*) nullptr));
    if (!launched)
    {
        TLLM_CHECK_WITH_INFO(false, "Failed to dispatch MNNVL AllReduceOneShot kernel.");
    }
}

enum MNNVLTwoShotStage : uint8_t
{
    SCATTER = 0,
    BROADCAST = 1,
    NUM_STAGES = 2,
};

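// Two-shot allreduce: token t is owned by rank (t % WorldSize). In the scatter stage every rank writes its shard of
// token t into the owner's scatter buffer; the owner polls until all WorldSize shards have arrived, reduces them in
// fp32, and multicasts the reduced token into every rank's broadcast buffer. If wait_for_results is set, each rank
// then polls its broadcast buffer and copies the final value to outputPtr; otherwise a following kernel (e.g. the
// fused rmsNormLamport below) consumes the broadcast buffer directly.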
template <uint8_t WorldSize, typename T, typename PackedType = float4>
__global__ __launch_bounds__(128) void twoshotAllreduceKernel(T* outputPtr, T const* shardPtr, T** inputPtrs,
    T* mcastPtr, uint32_t const numTokens, uint32_t const tokenDim, uint32_t const rank, uint32_t* bufferFlags,
    bool const wait_for_results)
{
    constexpr int kELTS_PER_THREAD = sizeof(PackedType) / sizeof(T);
    constexpr int kLAMPORT_ELTS_PER_PACKED = sizeof(PackedType) / sizeof(float);
    constexpr uint32_t kELT_SIZE = sizeof(T);

    int packedIdx = blockIdx.y * blockDim.x + threadIdx.x;
    int token = blockIdx.x;
    // Offset w.r.t. the input shard
    int threadOffset = token * tokenDim + packedIdx * kELTS_PER_THREAD;

    int destRank = token % WorldSize;
    int destTokenOffset = token / WorldSize;
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
    cudaGridDependencySynchronize();
#endif
    LamportFlags<PackedType> flag(bufferFlags, MNNVLTwoShotStage::NUM_STAGES);

    T* scatterBufLocal = reinterpret_cast<T*>(flag.getCurLamportBuf(inputPtrs[rank], MNNVLTwoShotStage::SCATTER));
    T* scatterBufDest = reinterpret_cast<T*>(flag.getCurLamportBuf(inputPtrs[destRank], MNNVLTwoShotStage::SCATTER));
    T* broadcastBufW = reinterpret_cast<T*>(flag.getCurLamportBuf(mcastPtr, MNNVLTwoShotStage::BROADCAST));
    T* broadcastBufR = reinterpret_cast<T*>(flag.getCurLamportBuf(inputPtrs[rank], MNNVLTwoShotStage::BROADCAST));

#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
    cudaTriggerProgrammaticLaunchCompletion();
#endif
    // Make sure the clear function is called before out-of-bounds threads exit
    if (packedIdx * kELTS_PER_THREAD >= tokenDim)
    {
        flag.clearDirtyLamportBuf(inputPtrs[rank], -1);
        return;
    }

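    // Scatter-buffer layout on each owner rank: [localToken][srcRank][tokenDim]. As in the one-shot kernel, -0.0f
    // acts as the "not yet written" sentinel, so real negative zeros are flushed to +0.0f before being published.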
    // =============================== Scatter ===============================

    // Load vectorized data
    PackedVec<PackedType, T> val;
    val.packed = loadPacked<PackedType>(&shardPtr[threadOffset]);
#pragma unroll
    for (int i = 0; i < kELTS_PER_THREAD; i++)
    {
        if (isNegZero(val.elements[i]))
        {
            val.elements[i] = cuda_cast<T, float>(0.F);
        }
    }

    // Store vectorized data
    reinterpret_cast<PackedType*>(&scatterBufDest[destTokenOffset * tokenDim * WorldSize + rank * tokenDim])[packedIdx]
        = val.packed;

    flag.clearDirtyLamportBuf(inputPtrs[rank], MNNVLTwoShotStage::SCATTER);

    // =============================== Reduction and Broadcast ===============================

    if ((token % WorldSize) == rank)
    {
        int localToken = token / WorldSize;
        float accum[kELTS_PER_THREAD] = {0.F};

        // Use float as we only check each float value for validity
        PackedVec<PackedType, float> valuesLamport[WorldSize];
        while (1)
        {
            bool valid = true;
#pragma unroll
            for (int r = 0; r < WorldSize; r++)
            {
                valuesLamport[r].packed = loadPackedVolatile<PackedType>(
                    &scatterBufLocal[localToken * tokenDim * WorldSize + r * tokenDim + packedIdx * kELTS_PER_THREAD]);

                // Check validity across all elements
#pragma unroll
                for (int i = 0; i < kLAMPORT_ELTS_PER_PACKED; i++)
                {
                    valid &= !isNegZero(valuesLamport[r].elements[i]);
                }
            }
            if (valid)
            {
                break;
            }
        }

        // Now view the packets as values of type T for the reduction
        auto values = reinterpret_cast<PackedVec<PackedType, T>*>(valuesLamport);
#pragma unroll
        for (int r = 0; r < WorldSize; r++)
        {
#pragma unroll
            for (int i = 0; i < kELTS_PER_THREAD; i++)
            {
                accum[i] += cuda_cast<float, T>(values[r].elements[i]);
            }
        }

        // Store vectorized result
        PackedVec<PackedType, T> packedAccum;
#pragma unroll
        for (int i = 0; i < kELTS_PER_THREAD; i++)
        {
            packedAccum.elements[i] = cuda_cast<T, float>(accum[i]);
        }
        reinterpret_cast<PackedType*>(&broadcastBufW[token * tokenDim])[packedIdx] = packedAccum.packed;
    }

    flag.clearDirtyLamportBuf(inputPtrs[rank], MNNVLTwoShotStage::BROADCAST);

    // Optionally wait for results if the next layer isn't doing the Lamport check
    if (wait_for_results)
    {
        // Update the atomic counter to indicate the block has read the offsets
        flag.ctaArrive();

        PackedVec<PackedType, float> valLamport;
        valLamport.packed = loadPackedVolatile<PackedType>(&broadcastBufR[threadOffset]);
        while (isNegZero(valLamport.elements[0]))
        {
            valLamport.packed = loadPackedVolatile<PackedType>(&broadcastBufR[threadOffset]);
        }
        if (outputPtr)
        {
            reinterpret_cast<PackedType*>(&outputPtr[threadOffset])[0] = valLamport.packed;
        }

        // Update the buffer flags
        flag.waitAndUpdate({static_cast<uint32_t>(divUp<uint32_t>(numTokens, WorldSize) * WorldSize * tokenDim
                               * kELT_SIZE),                         // Clear size for the scatter stage
            static_cast<uint32_t>(numTokens * tokenDim * kELT_SIZE), // Clear size for the broadcast stage
            0, 0});
        // If we do not wait for results, we rely on the following kernel to update the buffer flags
    }
}

// This kernel performs best when loads_per_thread is 1.
// In this mode, we can support a hidden dimension of up to 1024 (threads) x 8 (elements) = 8192.
// There are two options for scaling further:
// 1. Use CGA if supported. This expands the hidden dimension to 8k x 8 = 64k.
// 2. Set loads_per_thread > 1, which can be used if CGA is not supported. Note that this is limited by the
//    shared memory size and register count.
template <typename T_IN, typename T_OUT, int LoadsPerThread = 1>
__global__ __launch_bounds__(1024) void rmsNormLamport(T_IN* outputPreNorm, T_OUT* outputNorm, T_IN* bufferInput,
    T_IN const* gamma, float epsilon, T_IN const* residual, uint32_t numTokens, uint32_t dim, uint32_t worldSize,
    uint32_t* bufferFlags)
{
    static_assert(std::is_same_v<T_IN, T_OUT>, "T_IN and T_OUT must be the same type");
    static int const kELTS_PER_LOAD = sizeof(float4) / sizeof(T_IN);

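    // Overall flow: stage the residual and gamma slices into shared memory with cp.async, poll the Lamport broadcast
    // buffer written by twoshotAllreduceKernel until this block's chunk has arrived, accumulate the sum of squares of
    // (input + residual), reduce it across the block (and across the cluster when CGA is used), then scale by gamma
    // and the reciprocal RMS before writing the normalized output.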
    uint32_t const token = blockIdx.x;
    uint32_t const blockSize = blockDim.x;
    uint32_t const threadOffset = threadIdx.x;

    uint32_t numThreads = blockSize;
    uint32_t clusterSize = 1;
    uint32_t blockOffset = 0;
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
    namespace cg = cooperative_groups;
    cg::cluster_group cluster = cg::this_cluster();
    numThreads = cluster.num_threads();
    clusterSize = cluster.num_blocks();
    blockOffset = cluster.block_rank();
#endif
    uint32_t const dimPadded = divUp(dim, kELTS_PER_LOAD * numThreads) * kELTS_PER_LOAD * numThreads;
    uint32_t const elemsPerThread = dimPadded / numThreads;
    uint32_t const loadStride = blockSize;

    extern __shared__ uint8_t smem[];
    float rInput[LoadsPerThread * kELTS_PER_LOAD];
    uint32_t offsets[LoadsPerThread * kELTS_PER_LOAD];

    uint32_t const smemBufferSize = blockSize * elemsPerThread * sizeof(T_IN);
    T_IN* smemInput = (T_IN*) &smem[0];
    T_IN* smemResidual = (T_IN*) &smem[smemBufferSize];
    T_IN* smemGamma = (T_IN*) &smem[2 * smemBufferSize];

    LamportFlags<float4> flag(bufferFlags, MNNVLTwoShotStage::NUM_STAGES);
    T_IN* input = reinterpret_cast<T_IN*>(
        flag.getCurLamportBuf(reinterpret_cast<void*>(bufferInput), MNNVLTwoShotStage::BROADCAST));

#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
    cudaTriggerProgrammaticLaunchCompletion();
#endif
    // The offset that the current thread should load from. Note that the hidden dimension is split across the CGA
    // and each block loads a contiguous chunk.
    // The size of the chunk that each block processes:
    uint32_t const blockChunkSize = divUp(dim, clusterSize * kELTS_PER_LOAD) * kELTS_PER_LOAD;
    uint32_t const blockLoadOffset = token * dim + blockOffset * blockChunkSize;

#pragma unroll
    for (uint32_t i = 0; i < LoadsPerThread; i++)
    {
        // Each block loads a contiguous chunk of this token's hidden dimension
        uint32_t const threadLoadOffset = (i * loadStride + threadOffset) * kELTS_PER_LOAD;
        offsets[i] = blockLoadOffset + threadLoadOffset;
    }

#pragma unroll
    for (uint32_t i = 0; i < LoadsPerThread; i++)
    {
        uint32_t const threadLoadOffset = (i * loadStride + threadOffset) * kELTS_PER_LOAD;
        if (blockOffset * blockChunkSize + threadLoadOffset < dim)
        {
            copyF4(&smemResidual[threadLoadOffset], &residual[blockLoadOffset + threadLoadOffset]);
        }
    }
    __pipeline_commit();
#pragma unroll
    for (uint32_t i = 0; i < LoadsPerThread; i++)
    {
        uint32_t const threadLoadOffset = (i * loadStride + threadOffset) * kELTS_PER_LOAD;
        if (blockOffset * blockChunkSize + threadLoadOffset < dim)
        {
            copyF4(&smemGamma[threadLoadOffset], &gamma[blockOffset * blockChunkSize + threadLoadOffset]);
        }
    }
    __pipeline_commit();

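    // Poll the Lamport broadcast buffer until every 16B packet this thread needs has arrived (the first float of a
    // packet being anything other than -0.0f means the whole packet landed), staging values into shared memory.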
    flag.ctaArrive();
    bool valid = false;
    // ACQBLK if not lamport
    while (!valid)
    {
        valid = true;
#pragma unroll
        for (uint32_t i = 0; i < LoadsPerThread; i++)
        {
            uint32_t threadLoadOffset = (i * loadStride + threadOffset) * kELTS_PER_LOAD;

            if (blockOffset * blockChunkSize + threadLoadOffset < dim)
            {
                float4* dst4 = reinterpret_cast<float4*>(&smemInput[threadLoadOffset]);
                float4 const* src4 = reinterpret_cast<float4 const*>(&input[offsets[i]]);

                float4 value = loadPackedVolatile<float4>(src4);
                // Assume that the 16B were written atomically, so we only need to check one value
                valid &= !isNegZero(value.x);
                *dst4 = value;
            }
        }
    }

    __pipeline_wait_prior(1);
    __syncthreads();

    float threadSum = 0.f;
#pragma unroll
    for (int i = 0; i < LoadsPerThread; i++)
    {
        int threadLoadOffset = (i * loadStride + threadOffset) * kELTS_PER_LOAD;
        if (blockOffset * blockChunkSize + threadLoadOffset < dim)
        {
            PackedVec<float4, T_IN> inp{.packed = loadPacked<float4>(&smemInput[threadLoadOffset])};
            PackedVec<float4, T_IN> res{.packed = loadPacked<float4>(&smemResidual[threadLoadOffset])};

            PackedVec<float4, T_IN> inp_plus_res = inp + res;
#pragma unroll
            for (int j = 0; j < kELTS_PER_LOAD; j++)
            {
                rInput[i * kELTS_PER_LOAD + j] = cuda_cast<float, T_IN>(inp_plus_res.elements[j]);
                threadSum += cuda_cast<float, T_IN>(inp_plus_res.elements[j] * inp_plus_res.elements[j]);
            }

            *reinterpret_cast<float4*>(&outputPreNorm[blockLoadOffset + threadLoadOffset]) = inp_plus_res.packed;
        }
    }

    __pipeline_wait_prior(0);

    float blockSum = blockReduceSum<float, true>(threadSum);

    float fullSum = blockSum;
    __shared__ float sharedVal[8];
    // Use CGA reduction if supported
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
    int const numBlocks = cluster.num_blocks();
    if (numBlocks > 1)
    {
        fullSum = 0.F;
        // Need to reduce over the entire cluster
        int const blockRank = cluster.block_rank();
        if (threadIdx.x < numBlocks)
        {
            cluster.map_shared_rank(&sharedVal[0], threadIdx.x)[blockRank] = blockSum;
        }
        // cluster.sync();
        cluster.barrier_wait(cluster.barrier_arrive());
        for (int i = 0; i < numBlocks; ++i)
        {
            fullSum += sharedVal[i];
        }
    }
#endif

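    // rcpRms = 1 / sqrt(mean((input + residual)^2) + epsilon); the normalized output below is
    // gamma * (input + residual) * rcpRms.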
    float rcpRms = rsqrtf(fullSum / dim + epsilon);

#pragma unroll
    for (int i = 0; i < LoadsPerThread; i++)
    {
        PackedVec<float4, T_OUT> r_out;
        uint32_t threadLoadOffset = (i * loadStride + threadOffset) * kELTS_PER_LOAD;
        if (blockOffset * blockChunkSize + threadLoadOffset < dim)
        {
            PackedVec<float4, T_IN> gamma = {.packed = loadPacked<float4>(&smemGamma[threadLoadOffset])};

#pragma unroll
            for (uint32_t j = 0; j < kELTS_PER_LOAD; j++)
            {
                r_out.elements[j] = cuda_cast<T_OUT, float>(
                    cuda_cast<float, T_IN>(gamma.elements[j]) * rInput[i * kELTS_PER_LOAD + j] * rcpRms);
            }

            *reinterpret_cast<float4*>(&outputNorm[blockLoadOffset + threadLoadOffset]) = r_out.packed;
        }
    }
    constexpr int kELTS_SIZE = sizeof(T_IN);

    // Update the buffer flags
    flag.waitAndUpdate({static_cast<uint32_t>(divUp<uint32_t>(numTokens, worldSize) * worldSize * dim * kELTS_SIZE),
        static_cast<uint32_t>(numTokens * dim * kELTS_SIZE), 0, 0});
}

void twoshotAllreduceFusionOp(AllReduceFusionParams const& params)
{
    int const numTokens = params.numTokens;
    int const tokenDim = params.tokenDim;
    int const numEltsPerThread = sizeof(float4) / getDTypeSize(params.dType);
    TLLM_CHECK_WITH_INFO(tokenDim % numEltsPerThread == 0, "[MNNVL AllReduceTwoShot] token_dim must be divisible by %d",
        numEltsPerThread);

    int const arNumThreads = divUp(tokenDim, numEltsPerThread);
    int const arNumBlocksPerToken = divUp(arNumThreads, 128);

    dim3 arGrid(numTokens, arNumBlocksPerToken);

    cudaLaunchAttribute arAttrs[1];
    arAttrs[0].id = cudaLaunchAttributeProgrammaticStreamSerialization;
    arAttrs[0].val.programmaticStreamSerializationAllowed = tensorrt_llm::common::getEnvEnablePDL() ? 1 : 0;

    cudaLaunchConfig_t arConfig{
        .gridDim = arGrid,
        .blockDim = 128,
        .dynamicSmemBytes = 0,
        .stream = params.stream,
        .attrs = arAttrs,
        .numAttrs = 1,
    };

    TLLM_LOG_DEBUG(
        "[MNNVL AllReduceTwoShot] Dispatch: grid size: (%d, %d, 1), block_size: 128", numTokens, arNumBlocksPerToken);

#define LAUNCH_ALLREDUCE_KERNEL(WORLD_SIZE, T) \
    TLLM_CUDA_CHECK(cudaLaunchKernelEx(&arConfig, &twoshotAllreduceKernel<WORLD_SIZE, T>, output, input, ucPtrs, \
        mcastPtr, numTokens, tokenDim, params.rank, params.bufferFlags, (!params.rmsNormFusion)));
    auto dispatchAR = [&](auto* type_ptr) -> bool
    {
        using T = std::remove_pointer_t<decltype(type_ptr)>;
        T** ucPtrs = reinterpret_cast<T**>(params.bufferPtrsDev);
        T* mcastPtr = reinterpret_cast<T*>(params.multicastPtr);
        T* output = reinterpret_cast<T*>(params.output);
        T const* input = reinterpret_cast<T const*>(params.input);
        switch (params.nRanks)
        {
        case 2: LAUNCH_ALLREDUCE_KERNEL(2, T); return true;
        case 4: LAUNCH_ALLREDUCE_KERNEL(4, T); return true;
        case 8: LAUNCH_ALLREDUCE_KERNEL(8, T); return true;
        case 16: LAUNCH_ALLREDUCE_KERNEL(16, T); return true;
        case 32: LAUNCH_ALLREDUCE_KERNEL(32, T); return true;
        case 64: LAUNCH_ALLREDUCE_KERNEL(64, T); return true;
        }
        return false;
    };

#undef LAUNCH_ALLREDUCE_KERNEL

    bool launched = (params.dType == nvinfer1::DataType::kFLOAT && dispatchAR((float*) nullptr))
        || (params.dType == nvinfer1::DataType::kBF16 && dispatchAR((__nv_bfloat16*) nullptr))
        || (params.dType == nvinfer1::DataType::kHALF && dispatchAR((__nv_half*) nullptr));
    if (!launched)
    {
        TLLM_CHECK_WITH_INFO(false, "[MNNVL AllReduceTwoShot] Failed to dispatch twoshotAllreduce kernel.");
    }
    // Launch the rmsnorm lamport kernel if fusion is enabled
    if (params.rmsNormFusion)
    {
        auto gridConfig = adjustGridConfig(numTokens, tokenDim, numEltsPerThread);
        int rnBlockSize = std::get<0>(gridConfig);
        int rnClusterSize = std::get<1>(gridConfig);
        int rnLoadsPerThread = std::get<2>(gridConfig);

        int rnNumThreads = rnClusterSize * rnBlockSize;
        dim3 rnGrid(numTokens, rnClusterSize, 1);
        cudaLaunchConfig_t rnConfig;
        cudaLaunchAttribute rnAttrs[2];
        rnConfig.stream = params.stream;
        rnConfig.gridDim = rnGrid;
        rnConfig.blockDim = rnBlockSize;
        rnConfig.attrs = rnAttrs;
        rnAttrs[0].id = cudaLaunchAttributeProgrammaticStreamSerialization;
        rnAttrs[0].val.programmaticStreamSerializationAllowed = tensorrt_llm::common::getEnvEnablePDL() ? 1 : 0;
#ifndef DISABLE_CGA
        rnAttrs[1].id = cudaLaunchAttributeClusterDimension;
        rnAttrs[1].val.clusterDim.x = 1;
        rnAttrs[1].val.clusterDim.y = rnClusterSize;
        rnAttrs[1].val.clusterDim.z = 1;
        rnConfig.numAttrs = 2;
#else
        rnConfig.numAttrs = 1;
#endif

        bool const rnUseCGA = rnClusterSize > 1;
        int const dimPadded = divUp(tokenDim, numEltsPerThread * rnNumThreads) * numEltsPerThread * rnNumThreads;
        int const iters = dimPadded / rnNumThreads;

        size_t const smemSize = 3 * rnBlockSize * iters * getDTypeSize(params.dType);

        TLLM_LOG_DEBUG(
            "[MNNVL AllReduceTwoShotRMSNorm] Dispatch: grid size: (%d, %d, 1), block_size: %d, cluster_size: %d, "
            "loads_per_thread: %d, "
            "threads_needed: %d",
            numTokens, rnClusterSize, rnBlockSize, rnClusterSize, rnLoadsPerThread, divUp(tokenDim, numEltsPerThread));

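        // rmsNormLamport needs three shared-memory slices (input, residual, gamma), so smemSize can exceed the 48 KB
        // default limit; the launch macro below therefore raises cudaFuncAttributeMaxDynamicSharedMemorySize before
        // launching with dynamic shared memory.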
#define RUN_RMSNORM_KERNEL(T_IN, T_OUT, LOADS_PER_THREAD) \
    TLLM_CUDA_CHECK(cudaFuncSetAttribute( \
        &rmsNormLamport<T_IN, T_OUT, LOADS_PER_THREAD>, cudaFuncAttributeMaxDynamicSharedMemorySize, smemSize)); \
    rnConfig.dynamicSmemBytes = smemSize; \
    TLLM_CUDA_CHECK(cudaLaunchKernelEx(&rnConfig, &rmsNormLamport<T_IN, T_OUT, LOADS_PER_THREAD>, residualOut, output, \
        bufferInput, gamma, static_cast<float>(params.epsilon), residualIn, numTokens, tokenDim, params.nRanks, \
        params.bufferFlags));

        // C++17 does not support capturing structured bindings
        auto dispatchRN = [&, rnLoadsPerThread](auto* type_ptr)
        {
            using T_IN = std::remove_pointer_t<decltype(type_ptr)>;
            using T_OUT = T_IN;
            T_OUT* residualOut = reinterpret_cast<T_OUT*>(params.residualOut);
            T_OUT* output = reinterpret_cast<T_OUT*>(params.output);
            T_IN* bufferInput = reinterpret_cast<T_IN*>(params.bufferPtrLocal);
            T_IN const* gamma = reinterpret_cast<T_IN const*>(params.gamma);
            T_IN const* residualIn = reinterpret_cast<T_IN const*>(params.residualIn);
            if (rnUseCGA)
            {
                RUN_RMSNORM_KERNEL(T_IN, T_OUT, 1);
            }
            else
            {
                switch (rnLoadsPerThread)
                {
                case 1: RUN_RMSNORM_KERNEL(T_IN, T_OUT, 1); break;
                case 2: RUN_RMSNORM_KERNEL(T_IN, T_OUT, 2); break;
                case 3: RUN_RMSNORM_KERNEL(T_IN, T_OUT, 3); break;
                case 4: RUN_RMSNORM_KERNEL(T_IN, T_OUT, 4); break;
                case 5: RUN_RMSNORM_KERNEL(T_IN, T_OUT, 5); break;
                case 6: RUN_RMSNORM_KERNEL(T_IN, T_OUT, 6); break;
                case 7: RUN_RMSNORM_KERNEL(T_IN, T_OUT, 7); break;
                case 8: RUN_RMSNORM_KERNEL(T_IN, T_OUT, 8); break;
                default: return false;
                }
            }
            return true;
        };

        launched = (params.dType == nvinfer1::DataType::kFLOAT && dispatchRN((float*) nullptr))
            || (params.dType == nvinfer1::DataType::kBF16 && dispatchRN((__nv_bfloat16*) nullptr))
            || (params.dType == nvinfer1::DataType::kHALF && dispatchRN((__nv_half*) nullptr));
        if (!launched)
        {
            TLLM_CHECK_WITH_INFO(false, "[MNNVL AllReduceTwoShot] Failed to dispatch rmsnorm lamport kernel.");
        }
#undef RUN_RMSNORM_KERNEL
    }
}

} // namespace kernels::mnnvl

TRTLLM_NAMESPACE_END