TensorRT-LLMs/cpp/tensorrt_llm/kernels/fusedMoeCommKernels.cu
dongxuy04 19a0ea363b
[TRTLLM-6743][feat] Optimize and refactor alltoall in WideEP (#6973)
Signed-off-by: Dongxu Yang <78518666+dongxuy04@users.noreply.github.com>
Signed-off-by: Fred Wei <20514172+WeiHaocheng@users.noreply.github.com>
Signed-off-by: Dongxu Yang <dongxuy@nvidia.com>
Co-authored-by: Fred Wei <20514172+WeiHaocheng@users.noreply.github.com>
2025-08-24 08:15:29 -04:00

1373 lines
52 KiB
Plaintext

/*
* Copyright (c) 2019-2025, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "tensorrt_llm/kernels/fusedMoeCommKernels.h"
#include <type_traits>
#include "tensorrt_llm/common/cudaUtils.h"
#include "tensorrt_llm/common/logger.h"
namespace tensorrt_llm
{
namespace kernels
{
// Converts a generic pointer into the 32-bit shared-state-space address that PTX `.shared`
// operands expect. The pointer is assumed (not checked) to refer to shared memory.
static __device__ __forceinline__ uint32_t __as_ptr_smem(void const* __ptr)
{
    auto const sharedAddress = __cvta_generic_to_shared(__ptr);
    return static_cast<uint32_t>(sharedAddress);
}
// Converts a generic pointer into the 64-bit global-state-space address that PTX `.global`
// operands expect. The pointer is assumed (not checked) to refer to global memory.
static __device__ __forceinline__ uint64_t __as_ptr_gmem(void const* __ptr)
{
    auto const globalAddress = __cvta_generic_to_global(__ptr);
    return static_cast<uint64_t>(globalAddress);
}
// Emits a system-scope release fence ("fence.release.sys"). The "memory" clobber also stops the
// compiler from moving memory accesses across it. Not referenced in the visible part of this file.
__device__ __forceinline__ void fence_release_sys()
{
    asm volatile("fence.release.sys;"
                 :
                 :
                 : "memory");
}
// Initializes the shared-memory mbarrier at `addr` with an expected arrival count of `count`.
// Compiled to nothing below SM80.
__device__ __forceinline__ void mbarrier_init(uint64_t* addr, uint32_t const& count)
{
#if defined(__CUDACC__) && __CUDA_ARCH__ >= 800
    // `volatile`: this statement exists only for its side effect. Say so explicitly (as the other
    // wrappers in this file do) so the compiler never drops or reorders it.
    asm volatile("mbarrier.init.shared.b64 [%0], %1;" : : "r"(__as_ptr_smem(addr)), "r"(count) : "memory");
#endif
}
// Raises the expected transaction-byte count of the shared-memory mbarrier at `addr` by `txCount`
// without arriving. Compiled to nothing below SM90.
__device__ __forceinline__ void mbarrier_expect_tx(uint64_t* addr, const uint32_t txCount)
{
#if defined(__CUDACC__) && __CUDA_ARCH__ >= 900
    // `volatile`: side-effect-only statement; keep it from being removed or reordered.
    asm volatile("mbarrier.expect_tx.relaxed.cta.shared::cta.b64 [%0], %1;"
                 :
                 : "r"(__as_ptr_smem(addr)), "r"(txCount)
                 : "memory");
#endif
}
// Arrives once on the shared-memory mbarrier at `addr` and returns the opaque barrier state token.
// Returns 0 (and does nothing) below SM80.
__device__ __forceinline__ uint64_t mbarrier_arrive(uint64_t* addr)
{
#if defined(__CUDACC__) && __CUDA_ARCH__ >= 800
    uint64_t state;
    // `volatile` is required: the statement has an output. Without it, a caller that discards the
    // returned state would allow the compiler to delete the arrive entirely, and the barrier phase
    // would then never complete.
    asm volatile("mbarrier.arrive.shared.b64 %0, [%1];" : "=l"(state) : "r"(__as_ptr_smem(addr)) : "memory");
    return state;
#else
    return 0;
#endif
}
// Arrives once on the shared-memory mbarrier at `addr` and, in the same operation, adds `txCount`
// expected transaction bytes (release semantics, CTA scope). Returns the opaque state token.
// Returns 0 (and does nothing) below SM90.
__device__ __forceinline__ uint64_t mbarrier_arrive_expect_tx(uint64_t* addr, const uint32_t txCount)
{
#if defined(__CUDACC__) && __CUDA_ARCH__ >= 900
    uint64_t state;
    // `volatile` is required: callers such as startWorkspaceG2S / g2sAllFields return this state and
    // their callers ignore it. A non-volatile asm whose output is unused may be eliminated, which
    // would leave the barrier waiting forever.
    asm volatile("mbarrier.arrive.expect_tx.release.cta.shared::cta.b64 %0, [%1], %2;"
                 : "=l"(state)
                 : "r"(__as_ptr_smem(addr)), "r"(txCount)
                 : "memory");
    return state;
#else
    return 0;
#endif
}
// Tests (with a bounded hardware wait) whether the phase with parity `phaseParity` of the
// shared-memory mbarrier at `addr` has completed. Returns true once it has.
// Below SM90 this always returns false, so a caller that spins on it would never exit there.
__device__ __forceinline__ bool mbarrier_try_wait_parity(uint64_t* addr, uint32_t const& phaseParity)
{
#if defined(__CUDACC__) && __CUDA_ARCH__ >= 900
    uint32_t waitComplete;
    // `volatile` is required: this is polled in a spin loop with loop-invariant inputs. A
    // non-volatile asm with outputs may be CSE'd or hoisted out of the loop, turning the poll into
    // a single evaluation.
    asm volatile(
        "{\n\t .reg .pred P_OUT; \n\t"
        "mbarrier.try_wait.parity.shared::cta.b64 P_OUT, [%1], %2;\n\t"
        "selp.b32 %0, 1, 0, P_OUT; \n"
        "}"
        : "=r"(waitComplete)
        : "r"(__as_ptr_smem(addr)), "r"(phaseParity)
        : "memory");
    return static_cast<bool>(waitComplete);
#else
    return false;
#endif
}
// Predicated cp.async (L1-cached) copy of COPY_SIZE bytes from global `srcMem` to shared `dstShm`.
// When `predGuard` is false the copy is skipped entirely, so `srcMem` is not dereferenced.
// The copy is asynchronous: commit it with cp_async_commit_group and wait with cp_async_wait_group.
// Compiled to nothing below SM80.
template <int COPY_SIZE = 4>
__device__ __forceinline__ void ldgsts(int* dstShm, int const* srcMem, bool predGuard)
{
#if defined(__CUDACC__) && __CUDA_ARCH__ >= 800
    int const enableFlag = predGuard ? 1 : 0;
    uint32_t const dstSharedAddr = __as_ptr_smem(dstShm);
    uint64_t const srcGlobalAddr = __as_ptr_gmem(srcMem);
    asm volatile(
        "{\n"
        " .reg .pred p;\n"
        " setp.ne.b32 p, %0, 0;\n"
        " @p cp.async.ca.shared.global [%1], [%2], %3;\n"
        "}\n" ::"r"(enableFlag),
        "r"(dstSharedAddr), "l"(srcGlobalAddr), "n"(COPY_SIZE));
#endif
}
// Groups every cp.async (non-bulk) copy issued so far by this thread into one committed group,
// which cp_async_wait_group can later wait on. Compiled to nothing below SM80.
__device__ __forceinline__ void cp_async_commit_group()
{
#if defined(__CUDACC__) && __CUDA_ARCH__ >= 800
    asm volatile("cp.async.commit_group;"
                 :
                 :
                 :);
#endif
}
// Blocks the calling thread until at most N of its committed cp.async groups are still pending
// (N = 0 waits for all of them). Compiled to nothing below SM80.
template <int N = 0>
__device__ __forceinline__ void cp_async_wait_group()
{
#if defined(__CUDACC__) && __CUDA_ARCH__ >= 800
    asm volatile("cp.async.wait_group %0;"
                 :
                 : "n"(N)
                 : "memory");
#endif
}
// Starts a bulk (TMA-style) asynchronous copy of `copySize` bytes from global `srcMem` to shared
// `dstMem`. Completion is reported as transaction bytes on the mbarrier `smemBar`, so the expected
// byte count must be registered on that barrier (e.g. via mbarrier_arrive_expect_tx).
// Compiled to nothing below SM90.
__device__ __forceinline__ void cp_async_bulk_g2s(void* dstMem, void const* srcMem, int copySize, uint64_t* smemBar)
{
#if defined(__CUDACC__) && __CUDA_ARCH__ >= 900
    // `volatile` for consistency with the other async-copy wrappers: the copy is a pure side effect
    // and must never be dropped or reordered by the compiler.
    asm volatile("cp.async.bulk.shared::cta.global.mbarrier::complete_tx::bytes [%0], [%1], %2, [%3];"
                 :
                 : "r"(__as_ptr_smem(dstMem)), "l"(__as_ptr_gmem(srcMem)), "r"(copySize), "r"(__as_ptr_smem(smemBar))
                 : "memory");
#endif
}
// Starts a bulk asynchronous copy of `copySize` bytes from shared `srcMem` to global `dstMem`,
// tracked by the bulk-group mechanism: commit with cp_async_bulk_commit_group and wait with
// cp_async_bulk_wait_group[_read]. Compiled to nothing below SM90.
__device__ __forceinline__ void cp_async_bulk_s2g(void* dstMem, void const* srcMem, int copySize)
{
#if defined(__CUDACC__) && __CUDA_ARCH__ >= 900
    // `volatile` for consistency with the other async-copy wrappers: side-effect-only statement.
    asm volatile("cp.async.bulk.global.shared::cta.bulk_group [%0], [%1], %2;"
                 :
                 : "l"(__as_ptr_gmem(dstMem)), "r"(__as_ptr_smem(srcMem)), "r"(copySize)
                 : "memory");
#endif
}
// Closes all bulk async copies issued so far by this thread into one bulk group.
// Compiled to nothing below SM90.
__device__ __forceinline__ void cp_async_bulk_commit_group()
{
#if defined(__CUDACC__) && __CUDA_ARCH__ >= 900
    asm volatile("cp.async.bulk.commit_group;"
                 :
                 :
                 :);
#endif
}
// Waits until at most N of this thread's committed bulk groups are incomplete (N = 0 waits for
// all of them to fully complete). Compiled to nothing below SM90.
template <int N = 0>
__device__ __forceinline__ void cp_async_bulk_wait_group()
{
#if defined(__CUDACC__) && __CUDA_ARCH__ >= 900
    asm volatile("cp.async.bulk.wait_group %0;"
                 :
                 : "n"(N)
                 : "memory");
#endif
}
// Like cp_async_bulk_wait_group, but the `.read` variant only waits until the pending groups have
// finished reading their source, so the shared-memory source may be reused.
// Compiled to nothing below SM90.
template <int N = 0>
__device__ __forceinline__ void cp_async_bulk_wait_group_read()
{
#if defined(__CUDACC__) && __CUDA_ARCH__ >= 900
    asm volatile("cp.async.bulk.wait_group.read %0;"
                 :
                 : "n"(N)
                 : "memory");
#endif
}
// Records where a field lives and picks the widest access unit (up to 16 bytes) usable for it.
//
// The element size is doubled (and the element count / stride halved) for as long as:
// the unit is still below 16 bytes, the base pointer is aligned to the doubled size, and both the
// count and the stride are even. The result is stored as log2(unit bytes) in alignedUnitBit,
// together with the rescaled count/stride.
//
// elementSize must be 1, 2, 4, 8 or 16 (enforced by TLLM_CHECK).
__host__ void MoeCommFieldInfo::fillFieldInfo(uint8_t* dataPtr, size_t elementSize, int vectorSize, int stride)
{
    TLLM_CHECK(elementSize == 1 || elementSize == 2 || elementSize == 4 || elementSize == 8 || elementSize == 16);
    dataPtrBase = dataPtr;
    uint64_t const baseAddress = reinterpret_cast<uint64_t>(dataPtr);

    size_t unitBytes = elementSize;
    int unitCount = vectorSize;
    int unitStride = stride;
    for (;;)
    {
        bool const canWiden = unitBytes < 16 && baseAddress % (unitBytes * 2) == 0 && unitCount % 2 == 0
            && unitStride % 2 == 0;
        if (!canWiden)
        {
            break;
        }
        unitBytes *= 2;
        unitCount /= 2;
        unitStride /= 2;
    }

    // unitBytes is a power of two in [1, 16]; its log2 selects the access width.
    int unitBit = 0;
    while ((size_t{1} << unitBit) < unitBytes)
    {
        ++unitBit;
    }
    alignedUnitBit = unitBit;
    alignedUnitCount = unitCount;
    alignedUnitStride = unitStride;
}
// LL128-style flag protocol used for the FIFO transfers in this file.
// Each 128-byte block i of a transfer carries one uint64_t flag slot holding the current FIFO step.
// The slot index inside block i is (fifoEntry128ByteIndexBase + i) % UINT64_PER_128B_BLOCK, so it
// rotates with the block's position inside the FIFO entry. The payload word displaced by each flag
// is saved in a trailing "tail" 128-byte block (one tail block per 15 data blocks). The tail block
// also carries a flag, so a receiver can detect arrival of every block by comparing flags.
class Ll128Proto
{
public:
// 32-bit fill pattern for freshly initialized FIFO memory (FusedMoeWorkspace::initializeLocalWorkspace
// memsets the FIFO with it). Read as a uint64_t flag it is all ones, presumably never equal to a real
// step value written by a sender.
static constexpr uint32_t INITIALIZED_VALUE = 0xFFFFFFFFU;
// Counts how many additional consecutive 128-byte blocks, starting at block `loaded128ByteCount`,
// already hold flag == `step` in the shared-memory copy at `sharedMemoryBase`.
// Counting stops at the first block whose flag does not match, so the result is always a prefix.
// Returns that count (the caller adds it to its loaded count), or -1 when USE_FINISH is set and a
// finish marker is found in lane 0's block of a 32-block window. Must be called by the whole warp.
template <bool USE_FINISH>
static __device__ __forceinline__ int checkDataReceivedInShm(uint8_t* sharedMemoryBase, uint64_t step,
int countIn128Bytes, int fifoEntry128ByteIndexBase, int loaded128ByteCount, int warpId, int laneId)
{
// return value should be how many package already been received.
// 0 means no data received, -1 means has received finish package(should be the very first 128 Byte).
uint64_t* aligned128BytesShm = reinterpret_cast<uint64_t*>(sharedMemoryBase);
int totalValidCount = 0;
// Each iteration examines a window of up to 32 blocks, one block per lane.
for (int idxBase = loaded128ByteCount; idxBase < countIn128Bytes; idxBase += WARP_SIZE)
{
int idx = idxBase + laneId;
bool valid = false;
bool finish = false;
if (idx < countIn128Bytes)
{
int indexInFifoEntry = fifoEntry128ByteIndexBase + idx;
// Read this block's flag slot (rotating position inside the 128-byte block).
uint64_t value = aligned128BytesShm[idx * MoeCommFieldInfo::UINT64_PER_128B_BLOCK
+ indexInFifoEntry % MoeCommFieldInfo::UINT64_PER_128B_BLOCK];
if (USE_FINISH)
{
// NOTE(review): step & (1ULL << 63) is 0 unless bit 63 of step is set, so this effectively
// matches a zero flag word. Confirm whether step | (1ULL << 63) was intended. Only
// USE_FINISH = false is instantiated in the visible code.
finish = (value == (step & (1ULL << 63ULL)));
valid = (value == step) || finish;
}
else
{
valid = (value == step);
}
}
__syncwarp();
unsigned validMask = __ballot_sync(WARP_MASK, valid);
// here we check valid in order, if previous valid is not true, we ignore the current valid.
// (__ffs(~mask) - 1 is the index of the lowest lane whose block is not yet valid.)
int validCount = (validMask == WARP_MASK) ? WARP_SIZE : (__ffs(~validMask) - 1);
if (USE_FINISH)
{
unsigned finishedMask = __ballot_sync(WARP_MASK, finish);
// finish should be the very first 128 Byte.
if (finishedMask & 0x1)
{
return -1;
}
}
totalValidCount += validCount;
// A gap in this window means later blocks cannot be counted yet.
if (validCount != WARP_SIZE)
{
break;
}
}
return totalValidCount;
}
// In-place LL128 packing of `countIn128Bytes` data blocks in shared memory.
// The half-warp with lanes 0-15 handles groups 0, 2, 4, ... and lanes 16-31 handle groups 1, 3, ...
// Group g covers data blocks [15g, 15g + 15) and owns the tail block at countIn128Bytes + g.
// Lanes 0..14 of a half each move their data block's flag-slot word into the tail block and
// overwrite the slot with `step`. Lane 15 writes `step` into the tail block's own flag slot.
// For a partial group, the unused tail words are filled with `step`.
static __device__ __forceinline__ void protoPack(uint8_t* sharedMemoryBase, uint64_t step, int countIn128Bytes,
int fifoEntry128ByteIndexBase, int warpId, int laneId)
{
uint64_t* aligned128BytesShm = reinterpret_cast<uint64_t*>(sharedMemoryBase);
int halfLaneId = laneId % 16;
int halfIndex = laneId / 16;
int tailOffsetIn128Bytes = countIn128Bytes + halfIndex;
// for LL128 15 * 128 Bytes will be packed to 16 * 128 Bytes, each 16 threads is used for one 15 * 128 bytes.
for (int idxIn128BytesBase = halfIndex * 15; idxIn128BytesBase < countIn128Bytes; idxIn128BytesBase += 30)
{
int tailFlagIndexFromFifoEntry = fifoEntry128ByteIndexBase + tailOffsetIn128Bytes;
int tailFlagInnerIndex = tailFlagIndexFromFifoEntry % MoeCommFieldInfo::UINT64_PER_128B_BLOCK;
int idxIn128Bytes = idxIn128BytesBase + halfLaneId;
int idxFromFifoEntry = fifoEntry128ByteIndexBase + idxIn128Bytes;
uint64_t tailValue = step;
// Map lanes 0..14 onto the tail block's 16 words, skipping the tail's own flag slot.
// Lane 15 takes the flag slot itself.
uint64_t tailInnerIndex = (halfLaneId >= tailFlagInnerIndex) ? halfLaneId + 1 : halfLaneId;
if (halfLaneId == 15)
{
tailInnerIndex = tailFlagInnerIndex;
}
int targetTailIndex = tailOffsetIn128Bytes * MoeCommFieldInfo::UINT64_PER_128B_BLOCK + tailInnerIndex;
if (idxIn128Bytes < countIn128Bytes && halfLaneId < 15)
{
int flagIndex = idxIn128Bytes * MoeCommFieldInfo::UINT64_PER_128B_BLOCK
+ idxFromFifoEntry % MoeCommFieldInfo::UINT64_PER_128B_BLOCK;
// Save the payload word displaced by the flag, then stamp the flag.
tailValue = aligned128BytesShm[flagIndex];
aligned128BytesShm[flagIndex] = step;
}
aligned128BytesShm[targetTailIndex] = tailValue;
// Next group for this half is two groups later, so its tail block is two blocks later.
tailOffsetIn128Bytes += 2;
}
__syncwarp();
}
// Inverse of protoPack: copies each saved payload word from the tail blocks back into its data
// block's flag slot. Uses the same lane/group mapping as protoPack; lane 15 of each half is idle.
// `loaded128ByteCount` is not used.
static __device__ __forceinline__ void protoUnpack(uint8_t* sharedMemoryBase, uint64_t step, int countIn128Bytes,
int fifoEntry128ByteIndexBase, int loaded128ByteCount, int warpId, int laneId)
{
uint64_t* aligned128BytesShm = reinterpret_cast<uint64_t*>(sharedMemoryBase);
int halfLaneId = laneId % 16;
int halfIndex = laneId / 16;
int tailOffsetIn128Bytes = countIn128Bytes + halfIndex;
for (int idxIn128BytesBase = halfIndex * 15; idxIn128BytesBase < countIn128Bytes; idxIn128BytesBase += 30)
{
int tailFlagIndexFromFifoEntry = fifoEntry128ByteIndexBase + tailOffsetIn128Bytes;
int tailFlagInnerIndex = tailFlagIndexFromFifoEntry % MoeCommFieldInfo::UINT64_PER_128B_BLOCK;
int idxIn128Bytes = idxIn128BytesBase + halfLaneId;
int idxFromFifoEntry = fifoEntry128ByteIndexBase + idxIn128Bytes;
uint64_t tailValue = 0;
int tailInnerIndex = (halfLaneId >= tailFlagInnerIndex) ? halfLaneId + 1 : halfLaneId;
int targetTailIndex = tailOffsetIn128Bytes * MoeCommFieldInfo::UINT64_PER_128B_BLOCK + tailInnerIndex;
if (halfLaneId < 15)
{
tailValue = aligned128BytesShm[targetTailIndex];
}
if (idxIn128Bytes < countIn128Bytes && halfLaneId < 15)
{
int flagIndex = idxIn128Bytes * MoeCommFieldInfo::UINT64_PER_128B_BLOCK
+ idxFromFifoEntry % MoeCommFieldInfo::UINT64_PER_128B_BLOCK;
aligned128BytesShm[flagIndex] = tailValue;
}
tailOffsetIn128Bytes += 2;
}
__syncwarp();
}
// FIFO rearm hook required by the communicator interface; a no-op for this protocol.
static __device__ __forceinline__ void rearm(
uint32_t* u32FifoPtr, uint64_t step, int countIn128Bytes, int fifoEntry128ByteIndexBase, int warpId, int laneId)
{
// LL128 don't need rearm
}
// Size of a transfer after LL128 packing. The argument appears to be a byte count (it is
// divided by 15 * 128 bytes): one extra 128-byte tail block is added per started 15 * 128 bytes.
static __device__ __host__ __forceinline__ int computeProtoTransfer128ByteAlignedSize(
int compact128ByteSizeBeforeProto)
{
// each 15 * 128 byte need one tail 128 byte
int tail128ByteSize = (compact128ByteSizeBeforeProto + 15 * 128 - 1) / (15 * 128) * 128;
return compact128ByteSizeBeforeProto + tail128ByteSize;
}
};
// Wire protocol used by the fused MoE all-to-all FIFO transfers. Ll128Proto is the only protocol
// defined here; the alternative below is kept commented out.
typedef Ll128Proto FusedMoeProto;
// typedef LamportProto FusedMoeProto;
namespace fused_moe_impl
{
// returns copy size for txCount
// Lane 0 starts a bulk global->shared copy of one field of token `dataIndex` into the field's
// uncompacted shared-memory slot. The copied range is the 16-byte-aligned superset returned by
// get16BAlignedLoadCopyRange, so a few extra bytes may be fetched. Completion is reported on
// `smemBar`. Every lane returns the byte count, which the caller feeds to expect_tx.
__device__ __forceinline__ int startFieldG2S(MoeCommFieldInfo const& fieldInfo, int dataIndex,
    uint8_t* sharedMemoryBase, int warpId, int laneId, uint64_t* smemBar)
{
    uint8_t* const shmDestination = sharedMemoryBase + fieldInfo.getUncompactShmOffset();
    int bytesToCopy = 0;
    uint8_t* const globalSource = fieldInfo.get16BAlignedLoadCopyRange(dataIndex, &bytesToCopy);
    bool const issuesCopy = (laneId == 0) && (bytesToCopy > 0);
    if (issuesCopy)
    {
        cp_async_bulk_g2s(shmDestination, globalSource, bytesToCopy, smemBar);
    }
    return bytesToCopy;
}
// Writes one field of token `dataIndex` from its uncompacted shared-memory slot back to global memory.
// The 16-byte-aligned interior goes through one bulk shared->global copy issued by lane 0.
// Its shared source starts BYTES_PER_16B_BLOCK bytes into the slot. Lanes for which
// get16BAlignedStoreCopyRange reports a non-negative head/tail index copy one unaligned edge byte
// each with plain stores. Ends with a warp sync.
__device__ __forceinline__ void startFieldS2G(
    MoeCommFieldInfo const& fieldInfo, int dataIndex, uint8_t* sharedMemoryBase, int warpId, int laneId)
{
    uint8_t* const fieldShm = sharedMemoryBase + fieldInfo.getUncompactShmOffset();
    int bulkBytes = 0;
    int edgeShmIdx;
    int edgeGlobalIdx;
    uint8_t* const globalDestination
        = fieldInfo.get16BAlignedStoreCopyRange(dataIndex, &bulkBytes, laneId, &edgeShmIdx, &edgeGlobalIdx);
    bool const issuesBulk = bulkBytes > 0 && laneId == 0;
    if (issuesBulk)
    {
        cp_async_bulk_s2g(globalDestination, fieldShm + MoeCommFieldInfo::BYTES_PER_16B_BLOCK, bulkBytes);
    }
    bool const ownsEdgeByte = edgeGlobalIdx >= 0;
    if (ownsEdgeByte)
    {
        fieldInfo.getRawPtr(dataIndex, nullptr)[edgeGlobalIdx] = fieldShm[edgeShmIdx];
    }
    __syncwarp();
}
// Warp-cooperative overlapping move of `copySize` bytes (rounded up to whole T elements) inside
// shared memory.
// SRC_AFTER_DST is true, if src > dst, pack will use this: chunks are processed front to back.
// SRC_AFTER_DST is false, if src < dst, unpack will use this: chunks are processed back to front.
// Each 32-element chunk is read by all lanes and warp-synced before being written, so overlap
// within a chunk is safe; the chunk order makes overlap across chunks safe.
template <typename T, bool SRC_AFTER_DST = true>
__device__ __forceinline__ void memmoveSharedMemory(uint8_t* dst, uint8_t const* src, int copySize, int laneId)
{
    int const elementCount = (copySize + sizeof(T) - 1) / sizeof(T);
    int const chunkCount = (elementCount + WARP_SIZE - 1) / WARP_SIZE;
    T const* const srcElems = reinterpret_cast<T const*>(src);
    T* const dstElems = reinterpret_cast<T*>(dst);
    for (int iter = 0; iter < chunkCount; ++iter)
    {
        int const chunk = SRC_AFTER_DST ? iter : (chunkCount - 1 - iter);
        int const elemIdx = chunk * WARP_SIZE + laneId;
        bool const inRange = elemIdx < elementCount;
        T const staged = inRange ? srcElems[elemIdx] : T{};
        __syncwarp();
        if (inRange)
        {
            dstElems[elemIdx] = staged;
        }
        __syncwarp();
    }
}
// Moves one field between its uncompacted shared-memory location (compact offset + movOffset) and
// its compacted location.
// IS_PACK = true moves uncompacted -> compacted; false moves compacted -> uncompacted.
// The widest element type whose size divides movOffset is used. Nothing happens when movOffset is 0.
template <bool IS_PACK = true>
__device__ __forceinline__ void memmoveFieldOnSharedMemory(
    MoeCommFieldInfo const& fieldInfo, int dataIndex, uint8_t* sharedMemoryBase, int laneId)
{
    int const movOffset = fieldInfo.getMemmoveOffsets(dataIndex);
    if (movOffset == 0)
    {
        // Source and destination coincide.
        return;
    }
    int const copySize = fieldInfo.alignedUnitCount * (1 << fieldInfo.alignedUnitBit);
    uint8_t* const compactPtr = sharedMemoryBase + fieldInfo.getCompactShmOffset();
    uint8_t* const uncompactPtr = compactPtr + movOffset;
    uint8_t* const dstPtr = IS_PACK ? compactPtr : uncompactPtr;
    uint8_t* const srcPtr = IS_PACK ? uncompactPtr : compactPtr;

    if (movOffset % 16 == 0)
    {
        memmoveSharedMemory<int4, IS_PACK>(dstPtr, srcPtr, copySize, laneId);
        return;
    }
    if (movOffset % 8 == 0)
    {
        memmoveSharedMemory<int64_t, IS_PACK>(dstPtr, srcPtr, copySize, laneId);
        return;
    }
    if (movOffset % 4 == 0)
    {
        memmoveSharedMemory<int, IS_PACK>(dstPtr, srcPtr, copySize, laneId);
        return;
    }
    if (movOffset % 2 == 0)
    {
        memmoveSharedMemory<int16_t, IS_PACK>(dstPtr, srcPtr, copySize, laneId);
        return;
    }
    memmoveSharedMemory<int8_t, IS_PACK>(dstPtr, srcPtr, copySize, laneId);
}
// Compacts fields 0..FIELD_COUNT-1 in shared memory, in ascending field order, then warp-syncs.
template <int FIELD_COUNT = MOE_COMM_FIELD_MAX_COUNT>
__device__ __forceinline__ void packAllFields(
    FusedMoeFieldInfo const& sendFieldInfo, int dataIndex, uint8_t* sharedMemoryBase, int laneId)
{
#pragma unroll
    for (int fieldIdx = 0; fieldIdx < FIELD_COUNT; ++fieldIdx)
    {
        auto const& field = sendFieldInfo.fieldsInfo[fieldIdx];
        memmoveFieldOnSharedMemory<true>(field, dataIndex, sharedMemoryBase, laneId);
    }
    __syncwarp();
}
// Expands fields back to their uncompacted shared-memory layout, then warp-syncs.
// Runs in descending field order (the reverse of packAllFields).
template <int FIELD_COUNT = MOE_COMM_FIELD_MAX_COUNT>
__device__ __forceinline__ void unpackAllFields(
    FusedMoeFieldInfo const& recvFieldInfo, int dataIndex, uint8_t* sharedMemoryBase, int laneId)
{
#pragma unroll
    for (int step = 0; step < FIELD_COUNT; ++step)
    {
        int const fieldIdx = FIELD_COUNT - 1 - step;
        auto const& field = recvFieldInfo.fieldsInfo[fieldIdx];
        memmoveFieldOnSharedMemory<false>(field, dataIndex, sharedMemoryBase, laneId);
    }
    __syncwarp();
}
// Lane 0 initializes the warp's mbarrier with an arrival count of WARP_SIZE, because every lane
// later arrives once per phase. The warp then syncs so no lane uses the barrier before it exists.
__device__ __forceinline__ void initSmemBar(uint64_t* smemBar, int laneId)
{
    bool const isLeader = laneId == 0;
    if (isLeader)
    {
        mbarrier_init(smemBar, WARP_SIZE);
    }
    __syncwarp();
}
// Spins until the mbarrier phase with parity *phaseParity completes, then flips *phaseParity to
// the parity of the next phase.
__device__ __forceinline__ void smemBarWait(uint64_t* smemBar, uint32_t* phaseParity)
{
    bool phaseDone = false;
    do
    {
        phaseDone = mbarrier_try_wait_parity(smemBar, *phaseParity);
    } while (!phaseDone);
    *phaseParity = 1 - *phaseParity;
}
// Copies `send128ByteCount` 128-byte blocks from `sharedMemoryBase` into the FIFO entry `fifoEntry`,
// starting `fifo128ByteOffset` blocks into the entry. Lane 0 issues one bulk shared->global copy;
// every lane then commits a bulk group, so the caller can wait with cp_async_bulk_wait_group[_read].
//
// The staged data was written to shared memory with ordinary (generic-proxy) stores by all lanes
// (field packing and LL128 flag insertion). cp.async.bulk reads shared memory through the async
// proxy, so each writer must execute fence.proxy.async before the copy is issued. The warp sync
// after the fence orders the other lanes' fences before lane 0's copy.
__device__ __forceinline__ void startWorkspaceS2G(
    uint64_t* fifoEntry, uint8_t* sharedMemoryBase, int send128ByteCount, int fifo128ByteOffset, int warpId, int laneId)
{
    int copyByteCount = send128ByteCount * MoeCommFieldInfo::BYTES_PER_128B_BLOCK;
#if defined(__CUDACC__) && __CUDA_ARCH__ >= 900
    asm volatile("fence.proxy.async.shared::cta;" : : : "memory");
#endif
    __syncwarp();
    if (laneId == 0)
    {
        cp_async_bulk_s2g(fifoEntry + fifo128ByteOffset * MoeCommFieldInfo::BYTES_PER_128B_BLOCK / sizeof(int64_t),
            sharedMemoryBase, copyByteCount);
    }
    __syncwarp();
    cp_async_bulk_commit_group();
}
// Loads the not-yet-confirmed tail of a FIFO transfer into shared memory.
// The range is blocks [loaded128ByteCount, allLoad128ByteCount), relative to `fifo128ByteOffset`
// blocks into `fifoEntry`. Lane 0 issues one bulk global->shared copy tracked by `smemBar`.
// Every lane then arrives on `smemBar`, with lane 0 also registering the byte count.
// Returns the mbarrier state token.
__device__ __forceinline__ uint64_t startWorkspaceG2S(uint8_t* sharedMemoryBase, uint64_t* fifoEntry,
    int allLoad128ByteCount, int fifo128ByteOffset, int loaded128ByteCount, uint64_t* smemBar, int warpId, int laneId)
{
    int const remainingBlocks = allLoad128ByteCount - loaded128ByteCount;
    int const copyByteCount = remainingBlocks * MoeCommFieldInfo::BYTES_PER_128B_BLOCK;
    bool const isLeader = laneId == 0;
    if (isLeader)
    {
        auto const shmByteOffset = loaded128ByteCount * MoeCommFieldInfo::BYTES_PER_128B_BLOCK;
        auto const fifoU64Offset
            = (fifo128ByteOffset + loaded128ByteCount) * MoeCommFieldInfo::BYTES_PER_128B_BLOCK / sizeof(int64_t);
        cp_async_bulk_g2s(sharedMemoryBase + shmByteOffset, fifoEntry + fifoU64Offset, copyByteCount, smemBar);
    }
    return mbarrier_arrive_expect_tx(smemBar, isLeader ? copyByteCount : 0);
}
// Stages the per-token "basic fields" into the start of shared memory with predicated cp.async:
// topK selected-slot ints go at [0, topK) and, when expertScales is present, topK scale floats go at
// [topK, 2 * topK) as 32-bit words. Lane k handles entry k, so topK is assumed <= 32 (TODO confirm).
// The copies are asynchronous; the caller commits and waits on them.
__device__ __forceinline__ void g2sBasicFields(FusedMoeFieldInfo const& sendFieldInfo,
    MoeExpertParallelInfo const& expertParallelInfo, int dataIndex, uint8_t* sharedMemoryBase, int laneId)
{
    int const topK = expertParallelInfo.topK;
    int* const shmWords = reinterpret_cast<int*>(sharedMemoryBase);
    int* slotSource = sendFieldInfo.getTokenSelectedSlotsPtr(dataIndex, laneId, topK);
    float* scaleSource = sendFieldInfo.getScalePtr(dataIndex, laneId, topK);
    bool const laneHasEntry = laneId < topK;
    ldgsts<4>(shmWords + laneId, slotSource, laneHasEntry);
    bool const laneHasScale = laneHasEntry && sendFieldInfo.expertScales != nullptr;
    ldgsts<4>(shmWords + laneId + topK, reinterpret_cast<int*>(scaleSource), laneHasScale);
}
// Starts all global->shared loads for token `dataIndex`.
// When HAS_BASIC_FIELDS is true, the basic fields (tokenSelectedSlots and scales) are loaded with
// cp.async and committed as one group (wait with waitG2SBasicFields).
// Every other field is loaded with bulk copies tracked by `smemBar`. All lanes arrive on it, with
// lane 0 registering the total byte count. Returns the mbarrier state token.
template <bool HAS_BASIC_FIELDS = true, int FIELD_COUNT = MOE_COMM_FIELD_MAX_COUNT>
__device__ __forceinline__ uint64_t g2sAllFields(FusedMoeFieldInfo const& sendFieldInfo,
    MoeExpertParallelInfo const& expertParallelInfo, int dataIndex, uint8_t* sharedMemoryBase, int warpId, int laneId,
    uint64_t* smemBar)
{
    if (HAS_BASIC_FIELDS)
    {
        g2sBasicFields(sendFieldInfo, expertParallelInfo, dataIndex, sharedMemoryBase, laneId);
        cp_async_commit_group();
    }
    int bulkBytesTotal = 0;
#pragma unroll
    for (int fieldIdx = 0; fieldIdx < FIELD_COUNT; ++fieldIdx)
    {
        int const fieldBytes
            = startFieldG2S(sendFieldInfo.fieldsInfo[fieldIdx], dataIndex, sharedMemoryBase, warpId, laneId, smemBar);
        bulkBytesTotal += fieldBytes;
    }
    int const txBytes = (laneId == 0) ? bulkBytesTotal : 0;
    return mbarrier_arrive_expect_tx(smemBar, txBytes);
}
// Waits for the cp.async group holding the basic fields (if any) and warp-syncs so every lane sees it.
template <bool HAS_BASIC_FIELDS = true>
__device__ __forceinline__ void waitG2SBasicFields()
{
    if (!HAS_BASIC_FIELDS)
    {
        return;
    }
    cp_async_wait_group<0>();
    __syncwarp();
}
// Waits for the bulk-copied (non-basic) fields by waiting on the current mbarrier phase.
__device__ __forceinline__ void waitG2SOtherFields(uint64_t* memBar, uint32_t* phaseParity)
{
    uint64_t* const bar = memBar;
    smemBarWait(bar, phaseParity);
}
// Waits for every load started by g2sAllFields: the cp.async basic-field group first, then the
// bulk-copy mbarrier phase.
template <bool HAS_BASIC_FIELDS = true>
__device__ __forceinline__ void waitG2SAllFields(uint64_t* memBar, uint32_t* phaseParity)
{
    if (HAS_BASIC_FIELDS)
    {
        cp_async_wait_group<0>();
        __syncwarp();
    }
    smemBarWait(memBar, phaseParity);
}
// Waits until all committed bulk shared->global copies have finished reading shared memory, so the
// staging buffer can be reused. Global visibility is not awaited. Ends with a warp sync.
__device__ __forceinline__ void waitS2GBulkRead()
{
    constexpr int kAllowedPendingGroups = 0;
    cp_async_bulk_wait_group_read<kAllowedPendingGroups>();
    __syncwarp();
}
// Writes the basic fields of token `dataIndex` from shared memory to their global destinations with
// plain stores. The layout matches g2sBasicFields: slots in words [0, topK), scales (only when
// expertScales is set) in words [topK, 2 * topK). Lane k handles entry k.
__device__ __forceinline__ void s2gBasicFields(FusedMoeFieldInfo const& recvFieldInfo,
    MoeExpertParallelInfo const& expertParallelInfo, int dataIndex, uint8_t* sharedMemoryBase, int warpId, int laneId)
{
    int const topK = expertParallelInfo.topK;
    int* slotDestination = recvFieldInfo.getTokenSelectedSlotsPtr(dataIndex, laneId, topK);
    float* scaleDestination = recvFieldInfo.getScalePtr(dataIndex, laneId, topK);
    if (laneId >= topK)
    {
        return;
    }
    *slotDestination = reinterpret_cast<int*>(sharedMemoryBase)[laneId];
    if (recvFieldInfo.expertScales == nullptr)
    {
        return;
    }
    *scaleDestination = reinterpret_cast<float*>(sharedMemoryBase)[laneId + topK];
}
// Writes token `dataIndex` from shared memory to its global destinations.
// Basic fields (when HAS_BASIC_FIELDS) use plain stores; every other field uses bulk
// shared->global copies. Will commit 1 group, for all non-basic fields.
//
// The unpacked field data was written to shared memory with generic-proxy stores
// (unpackAllFields), and cp.async.bulk reads it through the async proxy. Every lane therefore
// executes fence.proxy.async, and the warp syncs, before lane 0 issues the bulk copies.
template <bool HAS_BASIC_FIELDS = true, int FIELD_COUNT = MOE_COMM_FIELD_MAX_COUNT>
__device__ __forceinline__ void s2gAllFields(FusedMoeFieldInfo const& recvFieldInfo,
    MoeExpertParallelInfo const& expertParallelInfo, int dataIndex, uint8_t* sharedMemoryBase, int warpId, int laneId)
{
    if (HAS_BASIC_FIELDS)
    {
        s2gBasicFields(recvFieldInfo, expertParallelInfo, dataIndex, sharedMemoryBase, warpId, laneId);
    }
#if defined(__CUDACC__) && __CUDA_ARCH__ >= 900
    asm volatile("fence.proxy.async.shared::cta;" : : : "memory");
#endif
    __syncwarp();
#pragma unroll
    for (int i = 0; i < FIELD_COUNT; i++)
    {
        startFieldS2G(recvFieldInfo.fieldsInfo[i], dataIndex, sharedMemoryBase, warpId, laneId);
    }
    cp_async_bulk_commit_group();
}
// Drives one FIFO channel between a (sender rank, receiver rank) pair with a single warp.
// The warp handles token indices channel, channel + runChannelCount, ... (see doSend/doReceive).
// Sender, per token: load fields into the warp's shared-memory slice, pack them, apply the LL128
// flags with step = mHead, and bulk-copy the result into the current FIFO entry.
// Receiver, per token: repeatedly bulk-load the entry slice and count blocks whose flags equal
// step = mTail until the whole transfer has arrived, then unpack and write out the fields.
// FIFO entries are reused modulo kFifoDepth. head/tail counters persist in the workspace's
// SenderSideFifoInfo / ReceiverSideFifoInfo.
template <int FIELD_COUNT, bool HAS_BASIC_FIELD = true>
class SingleChannelCommunicator
{
public:
// smemBar: this warp's mbarrier; shmemBase: this warp's shared-memory staging slice.
__device__ __forceinline__ SingleChannelCommunicator(FusedMoeFieldInfo const& fieldInfo,
MoeExpertParallelInfo const& expertParallelInfo, MoeSingleCommMeta const& commMeta,
FusedMoeWorkspace const& workspace, FusedMoeWorldInfo const& worldInfo, FusedMoePairInfo const& pairInfo,
uint64_t* smemBar, uint8_t* shmemBase)
: mFieldInfo(fieldInfo)
, mExpertParallelInfo(expertParallelInfo)
, mCommMeta(commMeta)
, mWorkspace(workspace)
, mWorldInfo(worldInfo)
, mPairInfo(pairInfo)
, mSmemBar(smemBar)
, mShmemBase(shmemBase)
{
mWarpId = threadIdx.x / WARP_SIZE;
mLaneId = threadIdx.x % WARP_SIZE;
mFifoBasePtr = mWorkspace.getFifoBasePtr(mWorldInfo, mPairInfo);
mSenderSideFifoInfo = mWorkspace.getSenderSideFifoInfo(mWorldInfo, mPairInfo);
mReceiverSideFifoInfo = mWorkspace.getReceiverSideFifoInfo(mWorldInfo, mPairInfo);
mSingleTransfer128ByteCount = mCommMeta.getTransfer128ByteCount();
mSingleCompactData128ByteCount = mCommMeta.getCompactData128ByteCount();
// initialize as need new Entry first
// (with the base at the end of an entry, needNewEntry() is true for the first token,
// assuming a non-zero transfer size).
mFifoEntry128ByteIndexBase = kFifoEntry128ByteCount;
mFifoEntryIndex = -1;
tensorrt_llm::kernels::fused_moe_impl::initSmemBar(mSmemBar, mLaneId);
}
// Base address of FIFO entry mFifoEntryIndex (entries are kFifoEntrySizeInU64 uint64_t apart).
__device__ __forceinline__ uint64_t* getFifoEntryPtr() const
{
return mFifoBasePtr + mFifoEntryIndex * kFifoEntrySizeInU64;
}
// True when the next token's transfer would not fit in the rest of the current entry.
__device__ __forceinline__ bool needNewEntry() const
{
return mFifoEntry128ByteIndexBase + mSingleTransfer128ByteCount > kFifoEntry128ByteCount;
}
// Advances the write/read position inside the current entry by one token's transfer.
__device__ __forceinline__ void nextToken()
{
mFifoEntry128ByteIndexBase += mSingleTransfer128ByteCount;
}
// Load the persistent head/tail counters for this pair/channel.
__device__ __forceinline__ void senderInitFifo()
{
mHead = mSenderSideFifoInfo->head;
mTail = mSenderSideFifoInfo->tail;
}
__device__ __forceinline__ void receiverInitFifo()
{
mHead = mReceiverSideFifoInfo->head;
mTail = mReceiverSideFifoInfo->tail;
}
/*
* Head | 0 | 1 | 2 | 3 | 4 | 4 | 4 | 4 | 4 | 5 |
* Tail | 0 | 0 | 0 | 0 | 0 | 1 | 2 | 3 | 4 | 4 |
* Writable | Y | Y | Y | Y | N | Y | Y | Y | Y | Y |
* Readable | N | Y | Y | Y | Y | Y | Y | Y | N | Y |
*/
// Spins until fewer than kFifoDepth entries are outstanding (i.e. the receiver has released one).
// NOTE(review): this relies on each read of mSenderSideFifoInfo->tail being re-issued (e.g. the
// field being declared volatile). Confirm in the header.
__device__ __forceinline__ void waitEntryWritable()
{
while (mTail + kFifoDepth <= mHead)
{
mTail = mSenderSideFifoInfo->tail;
}
}
// Publishes the sender's head. Every lane stores the same value.
__device__ __forceinline__ void updateWriteEntry()
{
__syncwarp();
mSenderSideFifoInfo->head = mHead;
}
__device__ __forceinline__ void waitEntryReadable()
{
// always readable as long as flag matches.
}
// Publishes the receiver's tail both locally and on the sender side (which waitEntryWritable polls).
__device__ __forceinline__ void updateReadEntry()
{
mReceiverSideFifoInfo->tail = mTail;
mSenderSideFifoInfo->tail = mTail;
}
// Switch to FIFO entry mHead % kFifoDepth, waiting until it may be overwritten.
__device__ __forceinline__ void newSendEntry()
{
mFifoEntryIndex = mHead % kFifoDepth;
mFifoEntry128ByteIndexBase = 0;
waitEntryWritable();
__syncwarp();
}
// Switch to FIFO entry mTail % kFifoDepth.
__device__ __forceinline__ void newReceiveEntry()
{
mFifoEntryIndex = mTail % kFifoDepth;
mFifoEntry128ByteIndexBase = 0;
waitEntryReadable();
__syncwarp();
}
// Sends tokens channel, channel + runChannelCount, ... < tokenCount. sendIndexMapping (if non-null)
// maps that index to the token index used to address fields. Every lane must call this.
__device__ __forceinline__ void doSend(int tokenCount, int* sendIndexMapping)
{
senderInitFifo();
int sendIndex = mPairInfo.channel;
uint32_t phaseParity = 0;
for (; sendIndex < tokenCount; sendIndex += mPairInfo.runChannelCount)
{
int tokenIndex = sendIndexMapping == nullptr ? sendIndex : sendIndexMapping[sendIndex];
// Start the loads first so they overlap with the entry bookkeeping below.
tensorrt_llm::kernels::fused_moe_impl::g2sAllFields<HAS_BASIC_FIELD, FIELD_COUNT>(
mFieldInfo, mExpertParallelInfo, tokenIndex, mShmemBase, mWarpId, mLaneId, mSmemBar);
if (needNewEntry())
{
if (mFifoEntryIndex >= 0)
{
// not first entry, update FIFO info from last entry.
mHead++;
updateWriteEntry();
}
newSendEntry();
}
tensorrt_llm::kernels::fused_moe_impl::waitG2SAllFields<HAS_BASIC_FIELD>(mSmemBar, &phaseParity);
tensorrt_llm::kernels::fused_moe_impl::packAllFields<FIELD_COUNT>(
mFieldInfo, tokenIndex, mShmemBase, mLaneId);
// The flag value written into every 128-byte block is the current head.
FusedMoeProto::protoPack(
mShmemBase, mHead, mSingleCompactData128ByteCount, mFifoEntry128ByteIndexBase, mWarpId, mLaneId);
tensorrt_llm::kernels::fused_moe_impl::startWorkspaceS2G(getFifoEntryPtr(), mShmemBase,
mSingleTransfer128ByteCount, mFifoEntry128ByteIndexBase, mWarpId, mLaneId);
// Only wait until shared memory may be reused; the global writes may still be in flight.
tensorrt_llm::kernels::fused_moe_impl::waitS2GBulkRead();
nextToken();
}
// Close the last entry. If the loop did not run, the base still equals kFifoEntry128ByteCount,
// so the head is advanced anyway. doReceive does the same with the tail, keeping both sides in step.
if (mFifoEntry128ByteIndexBase > 0)
{
mHead++;
updateWriteEntry();
}
}
// Points at the current token's slot in the FIFO entry and calls the protocol's rearm hook
// (a no-op for Ll128Proto).
__device__ __forceinline__ void rearmFifoBuffer()
{
constexpr int kUint32CountPer128Byte = 128 / sizeof(uint32_t);
uint32_t* fifoPtr = reinterpret_cast<uint32_t*>(getFifoEntryPtr());
fifoPtr += mFifoEntry128ByteIndexBase * kUint32CountPer128Byte;
FusedMoeProto::rearm(fifoPtr, mTail, mSingleTransfer128ByteCount, mFifoEntry128ByteIndexBase, mWarpId, mLaneId);
__syncwarp();
}
// Receives tokens channel, channel + runChannelCount, ... < tokenCount, mirroring doSend.
// recvIndexMapping (if non-null) maps the index to the destination token index.
__device__ __forceinline__ void doReceive(int tokenCount, int* recvIndexMapping)
{
receiverInitFifo();
int recvIndex = mPairInfo.channel;
uint32_t phaseParity = 0;
bool needRelease = false;
for (; recvIndex < tokenCount; recvIndex += mPairInfo.runChannelCount)
{
int tokenIndex = recvIndexMapping == nullptr ? recvIndex : recvIndexMapping[recvIndex];
int loaded128ByteCount = 0;
if (needNewEntry())
{
if (mFifoEntryIndex >= 0)
{
// not first entry, update FIFO info from last entry.
mTail++;
needRelease = true;
}
newReceiveEntry();
}
// Poll: reload the not-yet-confirmed blocks until every block's flag equals mTail.
while (loaded128ByteCount < mSingleTransfer128ByteCount)
{
tensorrt_llm::kernels::fused_moe_impl::startWorkspaceG2S(mShmemBase, getFifoEntryPtr(),
mSingleTransfer128ByteCount, mFifoEntry128ByteIndexBase, loaded128ByteCount, mSmemBar, mWarpId,
mLaneId);
// The previous entry is released only after the first load from the new entry is issued.
if (needRelease)
{
updateReadEntry();
needRelease = false;
}
tensorrt_llm::kernels::fused_moe_impl::smemBarWait(mSmemBar, &phaseParity);
loaded128ByteCount += FusedMoeProto::template checkDataReceivedInShm<false>(mShmemBase, mTail,
mSingleTransfer128ByteCount, mFifoEntry128ByteIndexBase, loaded128ByteCount, mWarpId, mLaneId);
}
FusedMoeProto::protoUnpack(mShmemBase, mTail, mSingleCompactData128ByteCount, mFifoEntry128ByteIndexBase,
loaded128ByteCount, mWarpId, mLaneId);
tensorrt_llm::kernels::fused_moe_impl::unpackAllFields<FIELD_COUNT>(
mFieldInfo, tokenIndex, mShmemBase, mLaneId);
tensorrt_llm::kernels::fused_moe_impl::s2gAllFields<HAS_BASIC_FIELD, FIELD_COUNT>(
mFieldInfo, mExpertParallelInfo, tokenIndex, mShmemBase, mWarpId, mLaneId);
tensorrt_llm::kernels::fused_moe_impl::waitS2GBulkRead();
rearmFifoBuffer();
nextToken();
}
// Release the last entry (see the matching note at the end of doSend).
if (mFifoEntry128ByteIndexBase > 0)
{
mTail++;
updateReadEntry();
}
}
private:
static constexpr int kFifoEntrySizeInU64 = FusedMoeCommunicator::FIFO_ENTRY_BYTES / sizeof(uint64_t);
static constexpr int kFifoEntry128ByteCount = FusedMoeCommunicator::FIFO_ENTRY_128_BYTE_COUNT;
static constexpr int kFifoDepth = FusedMoeCommunicator::FIFO_DEPTH;
FusedMoeFieldInfo mFieldInfo;
MoeExpertParallelInfo mExpertParallelInfo;
MoeSingleCommMeta mCommMeta;
FusedMoeWorkspace mWorkspace;
FusedMoeWorldInfo mWorldInfo;
FusedMoePairInfo mPairInfo;
uint64_t* mSmemBar;
uint8_t* mShmemBase;
int mLaneId;
int mWarpId;
uint64_t* mFifoBasePtr;
SenderSideFifoInfo* mSenderSideFifoInfo;
ReceiverSideFifoInfo* mReceiverSideFifoInfo;
// Entry counters; also used as the LL128 flag (step) values.
int64_t mHead;
int64_t mTail;
int mSingleTransfer128ByteCount;
int mSingleCompactData128ByteCount;
// Position (in 128-byte blocks) of the current token inside the current entry.
int mFifoEntry128ByteIndexBase;
// Index of the current FIFO entry; -1 before the first entry is opened.
int mFifoEntryIndex;
};
// Runs one send (isSender) or receive channel for a single rank pair on the calling warp.
template <int FIELD_COUNT, bool HAS_BASIC_FIELD>
__device__ __forceinline__ void runMoeAllToAllChannel(bool isSender, FusedMoeCommKernelParam const& params,
    FusedMoeWorkspace const& workspace, FusedMoePairInfo const& pairInfo, uint64_t* smemBar, uint8_t* shmBase,
    int tokenCount, int* indexMapping)
{
    if (isSender)
    {
        SingleChannelCommunicator<FIELD_COUNT, HAS_BASIC_FIELD> comm(params.sendFieldInfo, params.expertParallelInfo,
            params.sendCommMeta, workspace, params.worldInfo, pairInfo, smemBar, shmBase);
        comm.doSend(tokenCount, indexMapping);
    }
    else
    {
        SingleChannelCommunicator<FIELD_COUNT, HAS_BASIC_FIELD> comm(params.recvFieldInfo, params.expertParallelInfo,
            params.recvCommMeta, workspace, params.worldInfo, pairInfo, smemBar, shmBase);
        comm.doReceive(tokenCount, indexMapping);
    }
}

// All-to-all kernel. Grid layout as used by the code:
//   blockIdx.z == 0 -> sender blocks, otherwise receiver blocks;
//   blockIdx.y      -> channel index (gridDim.y channels);
//   blockIdx.x * blockDim.y + threadIdx.y -> peer rank handled by this thread group.
// Each threadIdx.y group gets mbarrier allWarpSmemBar[group] (so blockDim.y is assumed <= 32 — TODO
// confirm against MAX_GROUP_COUNT_PER_BLOCK) and a getSingleShmSize()-byte slice of dynamic shared memory.
// Groups with no peer or no tokens exit immediately.
template <int FIELD_COUNT = MOE_COMM_FIELD_MAX_COUNT>
__global__ void moeAllToAllKernel(FusedMoeCommKernelParam params, FusedMoeWorkspace workspace, bool hasBasicFields)
{
    __shared__ uint64_t allWarpSmemBar[32];
    extern __shared__ int4 allWarpShm[];

    bool const isSender = blockIdx.z == 0;
    int const group = threadIdx.y;
    int const peerRank = blockIdx.x * blockDim.y + group;
    if (peerRank >= params.worldInfo.epInfo.epSize)
    {
        return;
    }

    SendRecvIndices dataIndices = isSender ? params.sendIndices : params.recvIndices;
    int tokenCount;
    int* groupStartPtr = dataIndices.getGroupStart(peerRank, tokenCount);
    if (tokenCount == 0)
    {
        return;
    }

    int const localRank = params.worldInfo.epInfo.epRank;
    FusedMoePairInfo pairInfo;
    pairInfo.channel = blockIdx.y;
    pairInfo.runChannelCount = gridDim.y;
    pairInfo.senderRank = isSender ? localRank : peerRank;
    pairInfo.receiverRank = isSender ? peerRank : localRank;

    int const singleShmSize
        = isSender ? params.sendCommMeta.getSingleShmSize() : params.recvCommMeta.getSingleShmSize();
    uint64_t* const groupSmemBar = allWarpSmemBar + group;
    uint8_t* const groupShm = reinterpret_cast<uint8_t*>(allWarpShm) + singleShmSize * group;

    if (hasBasicFields)
    {
        runMoeAllToAllChannel<FIELD_COUNT, true>(
            isSender, params, workspace, pairInfo, groupSmemBar, groupShm, tokenCount, groupStartPtr);
    }
    else
    {
        runMoeAllToAllChannel<FIELD_COUNT, false>(
            isSender, params, workspace, pairInfo, groupSmemBar, groupShm, tokenCount, groupStartPtr);
    }
}
// Returns how many bytes of dynamic shared memory a moeAllToAllKernel block may use on the current
// device: the opt-in per-block maximum minus the kernel's static shared memory. The static size is
// taken from the FIELD_COUNT = 1 instantiation.
int computeMoeAlltoallMaxDynamicSharedMemorySize()
{
    int deviceId = -1;
    TLLM_CUDA_CHECK(cudaGetDevice(&deviceId));

    cudaFuncAttributes kernelAttributes{};
    TLLM_CUDA_CHECK(cudaFuncGetAttributes(&kernelAttributes, (void const*) moeAllToAllKernel<1>));
    int const staticSharedBytes = static_cast<int>(kernelAttributes.sharedSizeBytes);

    int optinLimitBytes = 0;
    TLLM_CUDA_CHECK(cudaDeviceGetAttribute(&optinLimitBytes, cudaDevAttrMaxSharedMemoryPerBlockOptin, deviceId));

    int const dynamicBudget = optinLimitBytes - staticSharedBytes;
    return dynamicBudget;
}
} // namespace fused_moe_impl
// Fills `singleCommMeta` with the per-token compact, uncompact and on-the-wire (after
// FusedMoeProto packing) aligned sizes for this field set.
void FusedMoeFieldInfo::fillMetaInfo(
    MoeSingleCommMeta* singleCommMeta, int topK, bool hasScales, bool hasBasicFields) const
{
    MoeSingleCommMeta& meta = *singleCommMeta;
    meta.singleCompactAlignedSize = computeSingleCompactSize(topK, hasScales, hasBasicFields);
    meta.singleUncompactAlignedSize = computeSingleUncompactSize(topK, hasScales, hasBasicFields);
    // The transfer size is derived from the compact size just stored.
    meta.singleTransferAlignedSize = FusedMoeProto::computeProtoTransfer128ByteAlignedSize(meta.singleCompactAlignedSize);
}
// Lays out fields back to back in the compact buffer.
// When hasBasicFields is set, the region starts after the basic fields: topK ints, plus topK
// floats if expertScales is set, rounded up to 16 bytes. Each field records its compact offset
// in 16-byte units and a running index counting the preceding fields whose access unit is
// narrower than 16 bytes (alignedUnitBit < 4). Slots past fieldCount are marked unused.
void FusedMoeFieldInfo::fillFieldPlacementInfo(int topK, bool hasBasicFields)
{
    int basicFieldBytes = 0;
    if (hasBasicFields)
    {
        size_t const scaleBytes = expertScales != nullptr ? topK * sizeof(float) : 0;
        basicFieldBytes = topK * sizeof(int) + scaleBytes;
        // Round up to the next 16-byte boundary.
        basicFieldBytes = (basicFieldBytes + MoeCommFieldInfo::BYTES_PER_16B_BLOCK - 1)
            / MoeCommFieldInfo::BYTES_PER_16B_BLOCK * MoeCommFieldInfo::BYTES_PER_16B_BLOCK;
    }

    int runningOffset = basicFieldBytes;
    int nextUnalignedIndex = 0;
    for (int fieldIdx = 0; fieldIdx < fieldCount; ++fieldIdx)
    {
        auto& field = fieldsInfo[fieldIdx];
        field.compact16BOffset = runningOffset / MoeCommFieldInfo::BYTES_PER_16B_BLOCK;
        runningOffset += field.getFieldCompactSize();
        field.unalignedFieldIndex = nextUnalignedIndex;
        nextUnalignedIndex += (field.alignedUnitBit < 4) ? 1 : 0;
    }

    for (int unusedIdx = fieldCount; unusedIdx < MOE_COMM_FIELD_MAX_COUNT; ++unusedIdx)
    {
        fieldsInfo[unusedIdx].setUnused();
    }
}
// Resets this rank's workspace slice. The FIFO region (FIFO_TOTAL_BYTES per peer per channel) is
// filled with FusedMoeProto::INITIALIZED_VALUE as 32-bit words. The following sender-side and
// receiver-side FIFO info records are zeroed. Both calls are issued without an explicit stream.
void FusedMoeWorkspace::initializeLocalWorkspace(FusedMoeWorldInfo const& worldInfo)
{
    int const epSize = worldInfo.epInfo.epSize;
    int const epRank = worldInfo.epInfo.epRank;
    size_t const pairChannelCount = static_cast<size_t>(epSize) * channelCount;

    size_t const fifoBytes = static_cast<size_t>(FusedMoeCommunicator::FIFO_TOTAL_BYTES) * pairChannelCount;
    size_t const senderInfoBytes = sizeof(SenderSideFifoInfo) * pairChannelCount;
    size_t const receiverInfoBytes = sizeof(ReceiverSideFifoInfo) * pairChannelCount;

    uint64_t* localWorkspacePtr = workspacePtr + epRank * rankStrideInU64;
    uint8_t* const fifoInfoStart = reinterpret_cast<uint8_t*>(localWorkspacePtr) + fifoBytes;

    TLLM_CU_CHECK(cuMemsetD32(reinterpret_cast<CUdeviceptr>(localWorkspacePtr), FusedMoeProto::INITIALIZED_VALUE,
        fifoBytes / sizeof(uint32_t)));
    TLLM_CUDA_CHECK(cudaMemset(fifoInfoStart, 0, senderInfoBytes + receiverInfoBytes));
}
// Host launcher for the fused MoE all-to-all kernel.
//
// The moeAllToAllKernel<N> instantiation is chosen from the larger of the send/recv field counts; counts other
// than 1..8 fall back to N = 8.
//
// The number of warp groups per CTA starts at
// min(epSize, MAX_GROUP_COUNT_PER_BLOCK, maxDynamicShm / warpShmSize). It is then decreased until the occupancy
// API reports at least one resident block. Finally it is rebalanced so the CTAs needed to cover epSize peers are
// evenly filled.
//
// The kernel is launched on `stream` with warpShmSize bytes of dynamic shared memory per group.
// Throws (TLLM_CHECK / TLLM_CUDA_CHECK) if:
//   - the send and recv per-warp shared-memory sizes differ,
//   - that size is not positive,
//   - no group count fits, or
//   - the launch fails.
void moeAllToAll(FusedMoeCommKernelParam params, FusedMoeWorkspace workspace, cudaStream_t stream)
{
    bool hasBasicFields = params.sendFieldInfo.tokenSelectedSlots != nullptr;
    int warpSendShmSize = params.sendCommMeta.getSingleShmSize();
    int warpRecvShmSize = params.recvCommMeta.getSingleShmSize();
    int warpShmSize = warpSendShmSize;
    int epSize = params.worldInfo.epInfo.epSize;
    TLLM_CHECK_WITH_INFO(warpSendShmSize == warpRecvShmSize, "warpSendShmSize(%d) not same as warpRecvShmSize(%d)",
        warpSendShmSize, warpRecvShmSize);
    // Guards the division below: a zero size would otherwise be an integer divide-by-zero on the host.
    TLLM_CHECK_WITH_INFO(warpShmSize > 0, "warpShmSize(%d) must be positive", warpShmSize);

    int maxGroupCountPerCta = std::min(epSize, FusedMoeCommunicator::MAX_GROUP_COUNT_PER_BLOCK);
    // NOTE(review): computed once per process for the device that is current on the first call; this assumes
    // every device this process launches on has the same shared-memory limits.
    static int maxDynamicShmSize = fused_moe_impl::computeMoeAlltoallMaxDynamicSharedMemorySize();
    int groupCountPerCta = std::min(maxGroupCountPerCta, maxDynamicShmSize / warpShmSize);

    int maxFieldCount = std::max(params.sendFieldInfo.fieldCount, params.recvFieldInfo.fieldCount);
    auto getFunc = [](int fieldCount)
    {
        switch (fieldCount)
        {
        case 1: return fused_moe_impl::moeAllToAllKernel<1>;
        case 2: return fused_moe_impl::moeAllToAllKernel<2>;
        case 3: return fused_moe_impl::moeAllToAllKernel<3>;
        case 4: return fused_moe_impl::moeAllToAllKernel<4>;
        case 5: return fused_moe_impl::moeAllToAllKernel<5>;
        case 6: return fused_moe_impl::moeAllToAllKernel<6>;
        case 7: return fused_moe_impl::moeAllToAllKernel<7>;
        case 8: return fused_moe_impl::moeAllToAllKernel<8>;
        default: return fused_moe_impl::moeAllToAllKernel<8>;
        }
        return fused_moe_impl::moeAllToAllKernel<8>;
    };
    auto* kernelFn = getFunc(maxFieldCount);

    // Requests above the 48 KB default need an explicit opt-in. Smaller group counts tried below stay within it.
    if (groupCountPerCta * warpShmSize > 48 * 1024)
    {
        TLLM_CUDA_CHECK(cudaFuncSetAttribute(
            kernelFn, cudaFuncAttributeMaxDynamicSharedMemorySize, groupCountPerCta * warpShmSize));
    }

    for (; groupCountPerCta > 0; groupCountPerCta--)
    {
        int dynamicShmSize = groupCountPerCta * warpShmSize;
        int numBlocks = 0;
        cudaError_t occupancyStatus = cudaOccupancyMaxActiveBlocksPerMultiprocessor(
            &numBlocks, kernelFn, WARP_SIZE * groupCountPerCta, dynamicShmSize);
        if (occupancyStatus != cudaSuccess)
        {
            // A failed runtime call is also recorded as this thread's "last error". This configuration is only
            // skipped, so consume that error here. Otherwise the cudaGetLastError() check after the launch would
            // report it even when the launch itself succeeds.
            (void) cudaGetLastError();
            continue;
        }
        if (numBlocks >= 1)
        {
            break;
        }
    }
    TLLM_CHECK_WITH_INFO(
        groupCountPerCta >= 1, "computed groupCount=%d, warpShmSize=%d", groupCountPerCta, warpShmSize);

    // Spread epSize peers evenly over the minimum number of CTAs per channel.
    int ctaPerChannel = (epSize + groupCountPerCta - 1) / groupCountPerCta;
    groupCountPerCta = (epSize + ctaPerChannel - 1) / ctaPerChannel;
    int totalDynamicShmSize = warpShmSize * groupCountPerCta;
    dim3 block = FusedMoeCommunicator::getLaunchBlockDim(groupCountPerCta);
    dim3 grid = FusedMoeCommunicator::getLaunchGridDim(epSize, groupCountPerCta);
    kernelFn<<<grid, block, totalDynamicShmSize, stream>>>(params, workspace, hasBasicFields);
    TLLM_CUDA_CHECK(cudaGetLastError());
}
// Out-of-class definitions of FusedMoeCommunicator's static SM-count state, initialized to -1 and false.
// NOTE(review): -1 presumably means "no SM limit configured yet", and maxSmCountUsed presumably records whether
// the limit has already been consumed. Confirm against FusedMoeCommunicator::setMaxUsableSmCount.
int FusedMoeCommunicator::maxSmCount = -1;
bool FusedMoeCommunicator::maxSmCountUsed = false;
// Free-function entry point that forwards the usable-SM cap to FusedMoeCommunicator::setMaxUsableSmCount.
void setMaxUsableSmCount(int usableSmCount)
{
    FusedMoeCommunicator::setMaxUsableSmCount(usableSmCount);
}
// Returns the per-rank workspace size for an expert-parallel group of epSize ranks. The size comes from
// FusedMoeWorkspace::computeWorkspaceSizePreRank, using the channel count FusedMoeCommunicator picks for epSize.
size_t getFusedMoeCommWorkspaceSize(int epSize)
{
    int const channels = FusedMoeCommunicator::getMoeCommChannelCount(epSize);
    return FusedMoeWorkspace::computeWorkspaceSizePreRank(epSize, channels);
}
// Fills *workspace with the base pointer, the per-rank stride (in uint64_t units) and the channel count
// FusedMoeCommunicator selects for epSize. No memory is allocated or initialized here.
void constructWorkspace(FusedMoeWorkspace* workspace, uint64_t* workspacePtr, size_t rankStrideInU64, int epSize)
{
    FusedMoeWorkspace& target = *workspace;
    target.channelCount = FusedMoeCommunicator::getMoeCommChannelCount(epSize);
    target.workspacePtr = workspacePtr;
    target.rankStrideInU64 = rankStrideInU64;
}
// Free-function wrapper: resets this rank's slice of *workspace via FusedMoeWorkspace::initializeLocalWorkspace.
void initializeFusedMoeLocalWorkspace(FusedMoeWorkspace* workspace, FusedMoeWorldInfo const& worldInfo)
{
    FusedMoeWorkspace& target = *workspace;
    target.initializeLocalWorkspace(worldInfo);
}
namespace fused_moe_comm_tests
{
// Test kernel for the global->shared path, one token per warp.
// Launch layout (see launchSingleG2S):
//   - blockDim.x = WARP_SIZE * warpsPerBlock,
//   - token index = warpId + blockIdx.x * warpsPerBlock,
//   - singleCommMeta.singleUncompactAlignedSize bytes of dynamic shared memory per warp.
// Each warp loads its token's fields into its shared-memory slice with g2sAllFields and waits on its own barrier.
// It then copies the slice verbatim, as ints, into shmDump so the host can inspect the shared-memory image.
__global__ void g2sKernel(FusedMoeFieldInfo allFieldInfo, MoeExpertParallelInfo expertParallelInfo,
    MoeSingleCommMeta singleCommMeta, int tokenCount, int* shmDump, bool hasBasicFields)
{
    // One barrier word per warp; the fixed size assumes at most 32 warps per block.
    __shared__ uint64_t allWarpSmemBar[32];
    extern __shared__ int4 allWarpShm[];
    int laneId = threadIdx.x % WARP_SIZE;
    int warpId = threadIdx.x / WARP_SIZE;
    int warpCount = blockDim.x / WARP_SIZE;
    int tokenIndex = warpId + blockIdx.x * warpCount;
    if (tokenIndex >= tokenCount)
    {
        // tokenIndex is uniform within a warp, so whole warps exit together.
        return;
    }
    int singleShmSize = singleCommMeta.singleUncompactAlignedSize;
    tensorrt_llm::kernels::fused_moe_impl::initSmemBar(&allWarpSmemBar[warpId], laneId);
    // Phase parity tracked by waitG2SAllFields for this warp's barrier.
    uint32_t phaseParity = 0;
    // This warp's private slice of dynamic shared memory.
    uint8_t* sharedMemoryBase = reinterpret_cast<uint8_t*>(allWarpShm) + singleShmSize * warpId;
    if (hasBasicFields)
    {
        tensorrt_llm::kernels::fused_moe_impl::g2sAllFields<true>(
            allFieldInfo, expertParallelInfo, tokenIndex, sharedMemoryBase, warpId, laneId, &allWarpSmemBar[warpId]);
        tensorrt_llm::kernels::fused_moe_impl::waitG2SAllFields<true>(&allWarpSmemBar[warpId], &phaseParity);
    }
    else
    {
        tensorrt_llm::kernels::fused_moe_impl::g2sAllFields<false>(
            allFieldInfo, expertParallelInfo, tokenIndex, sharedMemoryBase, warpId, laneId, &allWarpSmemBar[warpId]);
        tensorrt_llm::kernels::fused_moe_impl::waitG2SAllFields<false>(&allWarpSmemBar[warpId], &phaseParity);
    }
    // Lane-strided copy of the whole slice to shmDump; token t occupies singleShmSize bytes starting at
    // element t * singleShmSize / sizeof(int).
    for (int offset = laneId; offset < singleShmSize / sizeof(int); offset += WARP_SIZE)
    {
        shmDump[tokenIndex * singleShmSize / sizeof(int) + offset] = reinterpret_cast<int*>(sharedMemoryBase)[offset];
    }
}
// Test helper: launches g2sKernel with one warp per token and warpsPerBlock warps per block. Each warp gets
// computeSingleUncompactSize(...) bytes of dynamic shared memory, and shmDump receives every token's
// shared-memory image as ints.
// Throws if warpsPerBlock is outside [1, 32] or tokenCount is negative. Returns without launching when
// tokenCount is 0.
void launchSingleG2S(FusedMoeFieldInfo const& sendFieldInfo, MoeExpertParallelInfo const& expertParallelInfo,
    int tokenCount, int* shmDump, int warpsPerBlock, bool hasBasicFields, cudaStream_t stream)
{
    // warpsPerBlock is a divisor in the grid computation below, and g2sKernel keeps one barrier per warp in a
    // 32-entry shared array.
    TLLM_CHECK_WITH_INFO(
        warpsPerBlock >= 1 && warpsPerBlock <= 32, "warpsPerBlock(%d) must be in [1, 32]", warpsPerBlock);
    TLLM_CHECK_WITH_INFO(tokenCount >= 0, "tokenCount(%d) must be non-negative", tokenCount);
    if (tokenCount == 0)
    {
        // A zero-sized grid is an invalid launch configuration, and there is nothing to copy.
        return;
    }
    int warpShmSize = sendFieldInfo.computeSingleUncompactSize(
        expertParallelInfo.topK, sendFieldInfo.expertScales != nullptr, hasBasicFields);
    dim3 blockDim(WARP_SIZE * warpsPerBlock, 1, 1);
    dim3 gridDim((tokenCount + warpsPerBlock - 1) / warpsPerBlock, 1, 1);
    MoeSingleCommMeta singleCommMeta;
    sendFieldInfo.fillMetaInfo(
        &singleCommMeta, expertParallelInfo.topK, sendFieldInfo.expertScales != nullptr, hasBasicFields);
    TLLM_CUDA_CHECK(
        cudaFuncSetAttribute(g2sKernel, cudaFuncAttributeMaxDynamicSharedMemorySize, warpShmSize * warpsPerBlock));
    g2sKernel<<<gridDim, blockDim, warpShmSize * warpsPerBlock, stream>>>(
        sendFieldInfo, expertParallelInfo, singleCommMeta, tokenCount, shmDump, hasBasicFields);
    TLLM_CUDA_CHECK(cudaGetLastError());
}
// Test kernel for the shared->global path, one token per warp.
// Launch layout (see launchSingleS2G):
//   - blockDim.x = WARP_SIZE * warpsPerBlock,
//   - token index = warpId + blockIdx.x * warpsPerBlock,
//   - singleCommMeta.singleUncompactAlignedSize bytes of dynamic shared memory per warp.
// Each warp fills its shared-memory slice from shmPreload, using the same int-granular indexing g2sKernel uses for
// shmDump. It then writes the fields to global memory with s2gAllFields and finishes with waitS2GBulkRead
// (presumably waiting until the asynchronous stores have finished reading shared memory).
__global__ void s2gKernel(FusedMoeFieldInfo recvFieldInfo, MoeExpertParallelInfo expertParallelInfo,
    MoeSingleCommMeta singleCommMeta, int tokenCount, int* shmPreload, bool hasBasicFields)
{
    extern __shared__ int4 allWarpShm[];
    int laneId = threadIdx.x % WARP_SIZE;
    int warpId = threadIdx.x / WARP_SIZE;
    int warpCount = blockDim.x / WARP_SIZE;
    int tokenIndex = warpId + blockIdx.x * warpCount;
    if (tokenIndex >= tokenCount)
    {
        // tokenIndex is uniform within a warp, so whole warps exit together.
        return;
    }
    int singleShmSize = singleCommMeta.singleUncompactAlignedSize;
    // This warp's private slice of dynamic shared memory.
    uint8_t* sharedMemoryBase = reinterpret_cast<uint8_t*>(allWarpShm) + singleShmSize * warpId;
    for (int offset = laneId; offset < singleShmSize / sizeof(int); offset += WARP_SIZE)
    {
        reinterpret_cast<int*>(sharedMemoryBase)[offset]
            = shmPreload[tokenIndex * singleShmSize / sizeof(int) + offset];
    }
    // Make every lane's preload writes visible to the whole warp before the fields are stored out.
    __syncwarp();
    if (hasBasicFields)
    {
        tensorrt_llm::kernels::fused_moe_impl::s2gAllFields<true>(
            recvFieldInfo, expertParallelInfo, tokenIndex, sharedMemoryBase, warpId, laneId);
    }
    else
    {
        tensorrt_llm::kernels::fused_moe_impl::s2gAllFields<false>(
            recvFieldInfo, expertParallelInfo, tokenIndex, sharedMemoryBase, warpId, laneId);
    }
    tensorrt_llm::kernels::fused_moe_impl::waitS2GBulkRead();
}
// Test helper: launches s2gKernel with one warp per token and warpsPerBlock warps per block. Each warp gets
// computeSingleUncompactSize(...) bytes of dynamic shared memory, preloaded from shmPreload.
// Throws if warpsPerBlock < 1 or tokenCount is negative. Returns without launching when tokenCount is 0.
void launchSingleS2G(FusedMoeFieldInfo const& recvFieldInfo, MoeExpertParallelInfo const& expertParallelInfo,
    int tokenCount, int* shmPreload, int warpsPerBlock, bool hasBasicFields, cudaStream_t stream)
{
    // warpsPerBlock is a divisor in the grid computation below.
    TLLM_CHECK_WITH_INFO(warpsPerBlock >= 1, "warpsPerBlock(%d) must be positive", warpsPerBlock);
    TLLM_CHECK_WITH_INFO(tokenCount >= 0, "tokenCount(%d) must be non-negative", tokenCount);
    if (tokenCount == 0)
    {
        // A zero-sized grid is an invalid launch configuration, and there is nothing to store.
        return;
    }
    int warpShmSize = recvFieldInfo.computeSingleUncompactSize(
        expertParallelInfo.topK, recvFieldInfo.expertScales != nullptr, hasBasicFields);
    dim3 blockDim(WARP_SIZE * warpsPerBlock, 1, 1);
    dim3 gridDim((tokenCount + warpsPerBlock - 1) / warpsPerBlock, 1, 1);
    MoeSingleCommMeta singleCommMeta;
    recvFieldInfo.fillMetaInfo(
        &singleCommMeta, expertParallelInfo.topK, recvFieldInfo.expertScales != nullptr, hasBasicFields);
    TLLM_CUDA_CHECK(
        cudaFuncSetAttribute(s2gKernel, cudaFuncAttributeMaxDynamicSharedMemorySize, warpShmSize * warpsPerBlock));
    s2gKernel<<<gridDim, blockDim, warpShmSize * warpsPerBlock, stream>>>(
        recvFieldInfo, expertParallelInfo, singleCommMeta, tokenCount, shmPreload, hasBasicFields);
    TLLM_CUDA_CHECK(cudaGetLastError());
}
// Test kernel: single-GPU loopback of one token per warp through G2S -> pack -> unpack -> S2G.
// Launch layout (see launchLoopback):
//   - blockDim.x = WARP_SIZE * warpsPerBlock,
//   - send token index = warpId + blockIdx.x * warpsPerBlock,
//   - sendCommMeta.getSingleShmSize() bytes of dynamic shared memory per warp.
// The warp loads send token `tokenIndex` into its shared slice and packs it using sendFieldInfo. It then unpacks
// that same slice using recvFieldInfo and writes it to receive token recvIndexMapping[tokenIndex].
// recvCommMeta is accepted but not read here.
__global__ void loopbackKernel(FusedMoeFieldInfo sendFieldInfo, FusedMoeFieldInfo recvFieldInfo,
    MoeExpertParallelInfo expertParallelInfo, MoeSingleCommMeta sendCommMeta, MoeSingleCommMeta recvCommMeta,
    int* recvIndexMapping, int tokenCount, bool hasBasicFields)
{
    // One barrier word per warp; the fixed size assumes at most 32 warps per block.
    __shared__ uint64_t allWarpSmemBar[32];
    extern __shared__ int4 allWarpShm[];
    int laneId = threadIdx.x % WARP_SIZE;
    int warpId = threadIdx.x / WARP_SIZE;
    int warpCount = blockDim.x / WARP_SIZE;
    int tokenIndex = warpId + blockIdx.x * warpCount;
    if (tokenIndex >= tokenCount)
    {
        // tokenIndex is uniform within a warp, so whole warps exit together.
        return;
    }
    int recvTokenIndex = recvIndexMapping[tokenIndex];
    tensorrt_llm::kernels::fused_moe_impl::initSmemBar(&allWarpSmemBar[warpId], laneId);
    // Phase parity tracked by waitG2SAllFields for this warp's barrier.
    uint32_t phaseParity = 0;
    int singleShmSize = sendCommMeta.getSingleShmSize();
    // This warp's private slice of dynamic shared memory.
    uint8_t* sharedMemoryBase = reinterpret_cast<uint8_t*>(allWarpShm) + singleShmSize * warpId;
    // Stage 1: issue the global->shared loads for the send token.
    if (hasBasicFields)
    {
        tensorrt_llm::kernels::fused_moe_impl::g2sAllFields<true>(
            sendFieldInfo, expertParallelInfo, tokenIndex, sharedMemoryBase, warpId, laneId, &allWarpSmemBar[warpId]);
    }
    else
    {
        tensorrt_llm::kernels::fused_moe_impl::g2sAllFields<false>(
            sendFieldInfo, expertParallelInfo, tokenIndex, sharedMemoryBase, warpId, laneId, &allWarpSmemBar[warpId]);
    }
    // Stage 2: wait for those loads to complete before touching the slice.
    if (hasBasicFields)
    {
        tensorrt_llm::kernels::fused_moe_impl::waitG2SAllFields<true>(&allWarpSmemBar[warpId], &phaseParity);
    }
    else
    {
        tensorrt_llm::kernels::fused_moe_impl::waitG2SAllFields<false>(&allWarpSmemBar[warpId], &phaseParity);
    }
    // Stage 3: pack with the send layout, then unpack the same slice with the receive layout.
    tensorrt_llm::kernels::fused_moe_impl::packAllFields(sendFieldInfo, tokenIndex, sharedMemoryBase, laneId);
    tokenIndex = recvTokenIndex; // switch to recvTokenIndex;
    tensorrt_llm::kernels::fused_moe_impl::unpackAllFields(recvFieldInfo, tokenIndex, sharedMemoryBase, laneId);
    // Stage 4: store the unpacked fields to the receive token's global location.
    if (hasBasicFields)
    {
        tensorrt_llm::kernels::fused_moe_impl::s2gAllFields<true>(
            recvFieldInfo, expertParallelInfo, tokenIndex, sharedMemoryBase, warpId, laneId);
    }
    else
    {
        tensorrt_llm::kernels::fused_moe_impl::s2gAllFields<false>(
            recvFieldInfo, expertParallelInfo, tokenIndex, sharedMemoryBase, warpId, laneId);
    }
    // Wait until all outstanding bulk-copy groups have finished reading their source, then re-converge the warp.
    cp_async_bulk_wait_group_read<0>();
    __syncwarp();
}
// G2S -> Pack -> Unpack -> S2G
// Test helper: launches loopbackKernel with one warp per send token and warpsPerBlock warps per block.
// Each warp gets getSingleShmSize() bytes of dynamic shared memory, and the send and receive sizes must match.
// Throws if warpsPerBlock is outside [1, 32], tokenCount is negative, or the sizes differ. Returns without
// launching when tokenCount is 0.
void launchLoopback(FusedMoeFieldInfo const& sendFieldInfo, FusedMoeFieldInfo const& recvFieldInfo,
    MoeExpertParallelInfo const& expertParallelInfo, int* recvIndexMapping, int tokenCount, int warpsPerBlock,
    bool hasBasicFields, cudaStream_t stream)
{
    // warpsPerBlock is a divisor in the grid computation below, and loopbackKernel keeps one barrier per warp in a
    // 32-entry shared array.
    TLLM_CHECK_WITH_INFO(
        warpsPerBlock >= 1 && warpsPerBlock <= 32, "warpsPerBlock(%d) must be in [1, 32]", warpsPerBlock);
    TLLM_CHECK_WITH_INFO(tokenCount >= 0, "tokenCount(%d) must be non-negative", tokenCount);
    MoeSingleCommMeta sendCommMeta, recvCommMeta;
    sendFieldInfo.fillMetaInfo(
        &sendCommMeta, expertParallelInfo.topK, sendFieldInfo.expertScales != nullptr, hasBasicFields);
    recvFieldInfo.fillMetaInfo(
        &recvCommMeta, expertParallelInfo.topK, recvFieldInfo.expertScales != nullptr, hasBasicFields);
    int warpSendShmSize = sendCommMeta.getSingleShmSize();
    int warpRecvShmSize = recvCommMeta.getSingleShmSize();
    int warpShmSize = warpSendShmSize;
    TLLM_CHECK_WITH_INFO(warpSendShmSize == warpRecvShmSize, "warpSendShmSize(%d) not same as warpRecvShmSize(%d)",
        warpSendShmSize, warpRecvShmSize);
    if (tokenCount == 0)
    {
        // A zero-sized grid is an invalid launch configuration, and there is nothing to loop back.
        return;
    }
    dim3 blockDim(WARP_SIZE * warpsPerBlock, 1, 1);
    dim3 gridDim((tokenCount + warpsPerBlock - 1) / warpsPerBlock, 1, 1);
    TLLM_CUDA_CHECK(
        cudaFuncSetAttribute(loopbackKernel, cudaFuncAttributeMaxDynamicSharedMemorySize, warpShmSize * warpsPerBlock));
    loopbackKernel<<<gridDim, blockDim, warpShmSize * warpsPerBlock, stream>>>(sendFieldInfo, recvFieldInfo,
        expertParallelInfo, sendCommMeta, recvCommMeta, recvIndexMapping, tokenCount, hasBasicFields);
    TLLM_CUDA_CHECK(cudaGetLastError());
}
// Test kernel: drives the single-channel FIFO protocol on one GPU, with rank 0 sending to itself (epSize = 1).
// Grid layout:
//   - blockIdx.y == 0 blocks send, all other blocks receive;
//   - every warp owns one channel, numbered blockIdx.z * warpsPerBlock + warpId out of gridDim.z * warpsPerBlock.
// Each warp uses its own barrier word (at most 32 warps per block) and its own getSingleShmSize()-byte slice of
// dynamic shared memory.
template <bool HAS_BASIC_FIELD = true>
__global__ void localFifoSendRecvKernel(FusedMoeFieldInfo sendFieldInfo, FusedMoeFieldInfo recvFieldInfo,
    MoeExpertParallelInfo expertParallelInfo, MoeSingleCommMeta sendCommMeta, MoeSingleCommMeta recvCommMeta,
    FusedMoeWorkspace fusedMoeWorkspace, int* sendIndexMapping, int* recvIndexMapping, int tokenCount)
{
    using Communicator
        = tensorrt_llm::kernels::fused_moe_impl::SingleChannelCommunicator<MOE_COMM_FIELD_MAX_COUNT, HAS_BASIC_FIELD>;

    __shared__ uint64_t allWarpSmemBar[32];
    extern __shared__ int4 allWarpShm[];

    int const warp = threadIdx.x / WARP_SIZE;
    int const warpsPerCta = blockDim.x / WARP_SIZE;

    // Loopback world: a single rank that is both sender and receiver.
    FusedMoeWorldInfo worldInfo;
    worldInfo.epInfo.epSize = 1;
    worldInfo.epInfo.epRank = 0;

    FusedMoePairInfo pairInfo;
    pairInfo.senderRank = 0;
    pairInfo.receiverRank = 0;
    pairInfo.runChannelCount = gridDim.z * warpsPerCta;
    pairInfo.channel = blockIdx.z * warpsPerCta + warp;

    uint64_t* const warpBarrier = &allWarpSmemBar[warp];
    uint8_t* const shmBase = reinterpret_cast<uint8_t*>(&allWarpShm[0]);

    if (blockIdx.y != 0)
    {
        Communicator recverComm(recvFieldInfo, expertParallelInfo, recvCommMeta, fusedMoeWorkspace, worldInfo,
            pairInfo, warpBarrier, shmBase + warp * recvCommMeta.getSingleShmSize());
        recverComm.doReceive(tokenCount, recvIndexMapping);
        return;
    }

    Communicator senderComm(sendFieldInfo, expertParallelInfo, sendCommMeta, fusedMoeWorkspace, worldInfo, pairInfo,
        warpBarrier, shmBase + warp * sendCommMeta.getSingleShmSize());
    senderComm.doSend(tokenCount, sendIndexMapping);
}
// Test helper: launches localFifoSendRecvKernel<hasBasicFields> with a (1, 2, blockChannelCount) grid, where
// y = 0 is the sender half and y = 1 the receiver half. Each block has warpsPerBlock warps, each warp gets
// getSingleShmSize() bytes of dynamic shared memory, and the send and receive sizes must match.
void launchLocalFifoSendRecv(FusedMoeFieldInfo const& sendFieldInfo, FusedMoeFieldInfo const& recvFieldInfo,
    MoeExpertParallelInfo const& expertParallelInfo, int* sendIndexMapping, int* recvIndexMapping,
    FusedMoeWorkspace fusedMoeWorkspace, int tokenCount, int warpsPerBlock, int blockChannelCount, bool hasBasicFields,
    cudaStream_t stream)
{
    MoeSingleCommMeta sendCommMeta;
    MoeSingleCommMeta recvCommMeta;
    sendFieldInfo.fillMetaInfo(
        &sendCommMeta, expertParallelInfo.topK, sendFieldInfo.expertScales != nullptr, hasBasicFields);
    recvFieldInfo.fillMetaInfo(
        &recvCommMeta, expertParallelInfo.topK, recvFieldInfo.expertScales != nullptr, hasBasicFields);

    int const sendBytesPerWarp = sendCommMeta.getSingleShmSize();
    int const recvBytesPerWarp = recvCommMeta.getSingleShmSize();
    TLLM_CHECK_WITH_INFO(sendBytesPerWarp == recvBytesPerWarp, "warpSendShmSize(%d) not same as warpRecvShmSize(%d)",
        sendBytesPerWarp, recvBytesPerWarp);
    int const dynamicShmBytes = sendBytesPerWarp * warpsPerBlock;

    auto* const kernelFn = hasBasicFields ? &localFifoSendRecvKernel<true> : &localFifoSendRecvKernel<false>;
    TLLM_CUDA_CHECK(cudaFuncSetAttribute(kernelFn, cudaFuncAttributeMaxDynamicSharedMemorySize, dynamicShmBytes));

    dim3 const grid(1, 2, blockChannelCount);
    dim3 const block(WARP_SIZE * warpsPerBlock, 1, 1);
    kernelFn<<<grid, block, dynamicShmBytes, stream>>>(sendFieldInfo, recvFieldInfo, expertParallelInfo, sendCommMeta,
        recvCommMeta, fusedMoeWorkspace, sendIndexMapping, recvIndexMapping, tokenCount);
    TLLM_CUDA_CHECK(cudaGetLastError());
}
} // namespace fused_moe_comm_tests
} // namespace kernels
} // namespace tensorrt_llm