/*
 * Copyright (c) 2019-2025, NVIDIA CORPORATION. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "tensorrt_llm/kernels/fusedMoeCommKernels.h"

#include <type_traits>

#include "tensorrt_llm/common/cudaUtils.h"
#include "tensorrt_llm/common/logger.h"

namespace tensorrt_llm
{
namespace kernels
{

static __device__ __forceinline__ uint32_t __as_ptr_smem(void const* __ptr)
{
    // Consider adding debug asserts here.
    return static_cast<uint32_t>(__cvta_generic_to_shared(__ptr));
}

static __device__ __forceinline__ uint64_t __as_ptr_gmem(void const* __ptr)
{
    // Consider adding debug asserts here.
    return static_cast<uint64_t>(__cvta_generic_to_global(__ptr));
}

__device__ __forceinline__ void fence_release_sys()
{
    asm volatile("fence.release.sys;" : : : "memory");
}

__device__ __forceinline__ void mbarrier_init(uint64_t* addr, uint32_t const& count)
{
#if defined(__CUDACC__) && __CUDA_ARCH__ >= 800
    asm("mbarrier.init.shared.b64 [%0], %1;" : : "r"(__as_ptr_smem(addr)), "r"(count) : "memory");
#endif
}

__device__ __forceinline__ void mbarrier_expect_tx(uint64_t* addr, const uint32_t txCount)
{
#if defined(__CUDACC__) && __CUDA_ARCH__ >= 900
    asm("mbarrier.expect_tx.relaxed.cta.shared::cta.b64 [%0], %1;"
        :
        : "r"(__as_ptr_smem(addr)), "r"(txCount)
        : "memory");
#endif
}

__device__ __forceinline__ uint64_t mbarrier_arrive(uint64_t* addr)
{
#if defined(__CUDACC__) && __CUDA_ARCH__ >= 800
    uint64_t state;
    asm("mbarrier.arrive.shared.b64 %0, [%1];" : "=l"(state) : "r"(__as_ptr_smem(addr)) : "memory");
    return state;
#else
    return 0;
#endif
}

__device__ __forceinline__ uint64_t mbarrier_arrive_expect_tx(uint64_t* addr, const uint32_t txCount)
{
#if defined(__CUDACC__) && __CUDA_ARCH__ >= 900
    uint64_t state;
    asm("mbarrier.arrive.expect_tx.release.cta.shared::cta.b64 %0, [%1], %2;"
        : "=l"(state)
        : "r"(__as_ptr_smem(addr)), "r"(txCount)
        : "memory");
    return state;
#else
    return 0;
#endif
}

__device__ __forceinline__ bool mbarrier_try_wait_parity(uint64_t* addr, uint32_t const& phaseParity)
{
#if defined(__CUDACC__) && __CUDA_ARCH__ >= 900
    uint32_t waitComplete;
    asm("{\n\t .reg .pred P_OUT; \n\t"
        "mbarrier.try_wait.parity.shared::cta.b64 P_OUT, [%1], %2;\n\t"
        "selp.b32 %0, 1, 0, P_OUT; \n"
        "}"
        : "=r"(waitComplete)
        : "r"(__as_ptr_smem(addr)), "r"(phaseParity)
        : "memory");
    return static_cast<bool>(waitComplete);
#else
    return false;
#endif
}
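
// Usage sketch (illustrative only; it mirrors how initSmemBar/smemBarWait/startWorkspaceG2S below pair these
// helpers). A warp typically drives one shared-memory barrier like this:
//
//   if (laneId == 0) { mbarrier_init(smemBar, WARP_SIZE); }        // once; all lanes then __syncwarp()
//   mbarrier_arrive_expect_tx(smemBar, laneId == 0 ? txBytes : 0); // every lane arrives; lane 0 adds tx bytes
//   while (!mbarrier_try_wait_parity(smemBar, phaseParity)) { }    // spin until the phase flips
//   phaseParity = 1 - phaseParity;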

template <int COPY_SIZE = 4>
__device__ __forceinline__ void ldgsts(int* dstShm, int const* srcMem, bool predGuard)
{
#if defined(__CUDACC__) && __CUDA_ARCH__ >= 800
    asm volatile(
        "{\n"
        " .reg .pred p;\n"
        " setp.ne.b32 p, %0, 0;\n"
        " @p cp.async.ca.shared.global [%1], [%2], %3;\n"
        "}\n" ::"r"((int) predGuard),
        "r"(__as_ptr_smem(dstShm)), "l"(__as_ptr_gmem(srcMem)), "n"(COPY_SIZE));
#endif
}

__device__ __forceinline__ void cp_async_commit_group()
{
#if defined(__CUDACC__) && __CUDA_ARCH__ >= 800
    asm volatile("cp.async.commit_group;" : : :);
#endif
}

template <int N = 0>
__device__ __forceinline__ void cp_async_wait_group()
{
#if defined(__CUDACC__) && __CUDA_ARCH__ >= 800
    asm volatile("cp.async.wait_group %0;" : : "n"(N) : "memory");
#endif
}

__device__ __forceinline__ void cp_async_bulk_g2s(void* dstMem, void const* srcMem, int copySize, uint64_t* smemBar)
{
#if defined(__CUDACC__) && __CUDA_ARCH__ >= 900
    asm("cp.async.bulk.shared::cta.global.mbarrier::complete_tx::bytes [%0], [%1], %2, [%3];"
        :
        : "r"(__as_ptr_smem(dstMem)), "l"(__as_ptr_gmem(srcMem)), "r"(copySize), "r"(__as_ptr_smem(smemBar))
        : "memory");
#endif
}

__device__ __forceinline__ void cp_async_bulk_s2g(void* dstMem, void const* srcMem, int copySize)
{
#if defined(__CUDACC__) && __CUDA_ARCH__ >= 900
    asm("cp.async.bulk.global.shared::cta.bulk_group [%0], [%1], %2;"
        :
        : "l"(__as_ptr_gmem(dstMem)), "r"(__as_ptr_smem(srcMem)), "r"(copySize)
        : "memory");
#endif
}

__device__ __forceinline__ void cp_async_bulk_commit_group()
{
#if defined(__CUDACC__) && __CUDA_ARCH__ >= 900
    asm volatile("cp.async.bulk.commit_group;" : : :);
#endif
}

template <int N = 0>
__device__ __forceinline__ void cp_async_bulk_wait_group()
{
#if defined(__CUDACC__) && __CUDA_ARCH__ >= 900
    asm volatile("cp.async.bulk.wait_group %0;" : : "n"(N) : "memory");
#endif
}

template <int N = 0>
__device__ __forceinline__ void cp_async_bulk_wait_group_read()
{
#if defined(__CUDACC__) && __CUDA_ARCH__ >= 900
    asm volatile("cp.async.bulk.wait_group.read %0;" : : "n"(N) : "memory");
#endif
}

__host__ void MoeCommFieldInfo::fillFieldInfo(uint8_t* dataPtr, size_t elementSize, int vectorSize, int stride)
{
    TLLM_CHECK(elementSize == 1 || elementSize == 2 || elementSize == 4 || elementSize == 8 || elementSize == 16);

    dataPtrBase = dataPtr;

    uint64_t dataPtrU64 = reinterpret_cast<uint64_t>(dataPtr);

    while (elementSize < 16 && dataPtrU64 % (elementSize * 2) == 0 && vectorSize % 2 == 0 && stride % 2 == 0)
    {
        elementSize *= 2;
        vectorSize /= 2;
        stride /= 2;
    }

    if (elementSize == 16)
    {
        alignedUnitBit = 4;
    }
    else if (elementSize == 8)
    {
        alignedUnitBit = 3;
    }
    else if (elementSize == 4)
    {
        alignedUnitBit = 2;
    }
    else if (elementSize == 2)
    {
        alignedUnitBit = 1;
    }
    else
    {
        alignedUnitBit = 0;
    }

    alignedUnitCount = vectorSize;
    alignedUnitStride = stride;
}
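
// Worked example (illustrative): for elementSize = 4, vectorSize = 8, stride = 8 and a 16-byte-aligned
// dataPtr, the loop above doubles elementSize to 16 while halving vectorSize and stride to 2, giving
// alignedUnitBit = 4, alignedUnitCount = 2 and alignedUnitStride = 2, so the field is handled as two
// 16-byte units per vector.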

class Ll128Proto
{
public:
    static constexpr uint32_t INITIALIZED_VALUE = 0xFFFFFFFFU;

    template <bool USE_FINISH>
    static __device__ __forceinline__ int checkDataReceivedInShm(uint8_t* sharedMemoryBase, uint64_t step,
        int countIn128Bytes, int fifoEntry128ByteIndexBase, int loaded128ByteCount, int warpId, int laneId)
    {
        // The return value is how many 128-byte packets have already been received.
        // 0 means no data received; -1 means a finish packet was received (it must be the very first 128 bytes).
        uint64_t* aligned128BytesShm = reinterpret_cast<uint64_t*>(sharedMemoryBase);
        int totalValidCount = 0;
        for (int idxBase = loaded128ByteCount; idxBase < countIn128Bytes; idxBase += WARP_SIZE)
        {
            int idx = idxBase + laneId;
            bool valid = false;
            bool finish = false;
            if (idx < countIn128Bytes)
            {
                int indexInFifoEntry = fifoEntry128ByteIndexBase + idx;
                uint64_t value = aligned128BytesShm[idx * MoeCommFieldInfo::UINT64_PER_128B_BLOCK
                    + indexInFifoEntry % MoeCommFieldInfo::UINT64_PER_128B_BLOCK];
                if (USE_FINISH)
                {
                    // A finish packet carries the step value with the top bit set.
                    finish = (value == (step | (1ULL << 63ULL)));
                    valid = (value == step) || finish;
                }
                else
                {
                    valid = (value == step);
                }
            }
            __syncwarp();
            unsigned validMask = __ballot_sync(WARP_MASK, valid);
            // Check validity in order: if an earlier 128-byte block is not yet valid, ignore later valid blocks.
            int validCount = (validMask == WARP_MASK) ? WARP_SIZE : (__ffs(~validMask) - 1);
            if (USE_FINISH)
            {
                unsigned finishedMask = __ballot_sync(WARP_MASK, finish);
                // The finish packet must be the very first 128 bytes.
                if (finishedMask & 0x1)
                {
                    return -1;
                }
            }
            totalValidCount += validCount;

            if (validCount != WARP_SIZE)
            {
                break;
            }
        }
        return totalValidCount;
    }
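
    // Example of the in-order counting above (illustrative): with WARP_SIZE == 32, if validMask is
    // 0x0000FFFF (lanes 0..15 saw their flag, lane 16 did not yet), __ffs(~validMask) - 1 == 16, so only
    // the contiguous prefix of 16 blocks is counted even if some later lane already observed a valid flag.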

    static __device__ __forceinline__ void protoPack(uint8_t* sharedMemoryBase, uint64_t step, int countIn128Bytes,
        int fifoEntry128ByteIndexBase, int warpId, int laneId)
    {
        uint64_t* aligned128BytesShm = reinterpret_cast<uint64_t*>(sharedMemoryBase);
        int halfLaneId = laneId % 16;
        int halfIndex = laneId / 16;
        int tailOffsetIn128Bytes = countIn128Bytes + halfIndex;
        // For LL128, every 15 * 128 bytes are packed into 16 * 128 bytes; each group of 16 threads handles
        // one 15 * 128-byte chunk.
        for (int idxIn128BytesBase = halfIndex * 15; idxIn128BytesBase < countIn128Bytes; idxIn128BytesBase += 30)
        {
            int tailFlagIndexFromFifoEntry = fifoEntry128ByteIndexBase + tailOffsetIn128Bytes;
            int tailFlagInnerIndex = tailFlagIndexFromFifoEntry % MoeCommFieldInfo::UINT64_PER_128B_BLOCK;
            int idxIn128Bytes = idxIn128BytesBase + halfLaneId;
            int idxFromFifoEntry = fifoEntry128ByteIndexBase + idxIn128Bytes;
            uint64_t tailValue = step;
            int tailInnerIndex = (halfLaneId >= tailFlagInnerIndex) ? halfLaneId + 1 : halfLaneId;
            if (halfLaneId == 15)
            {
                tailInnerIndex = tailFlagInnerIndex;
            }
            int targetTailIndex = tailOffsetIn128Bytes * MoeCommFieldInfo::UINT64_PER_128B_BLOCK + tailInnerIndex;
            if (idxIn128Bytes < countIn128Bytes && halfLaneId < 15)
            {
                int flagIndex = idxIn128Bytes * MoeCommFieldInfo::UINT64_PER_128B_BLOCK
                    + idxFromFifoEntry % MoeCommFieldInfo::UINT64_PER_128B_BLOCK;
                tailValue = aligned128BytesShm[flagIndex];
                aligned128BytesShm[flagIndex] = step;
            }
            aligned128BytesShm[targetTailIndex] = tailValue;
            tailOffsetIn128Bytes += 2;
        }
        __syncwarp();
    }
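
    // Layout sketch (illustrative): with countIn128Bytes == 30, lanes 0..15 handle data blocks 0..14 plus
    // tail block 30 while lanes 16..31 handle blocks 15..29 plus tail block 31. Each tail block gathers the
    // fifteen uint64 words displaced from the flag positions plus the step flag itself, so every 128-byte
    // line written to the FIFO contains one word whose expected value the receiver can poll.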

    static __device__ __forceinline__ void protoUnpack(uint8_t* sharedMemoryBase, uint64_t step, int countIn128Bytes,
        int fifoEntry128ByteIndexBase, int loaded128ByteCount, int warpId, int laneId)
    {
        uint64_t* aligned128BytesShm = reinterpret_cast<uint64_t*>(sharedMemoryBase);
        int halfLaneId = laneId % 16;
        int halfIndex = laneId / 16;
        int tailOffsetIn128Bytes = countIn128Bytes + halfIndex;
        for (int idxIn128BytesBase = halfIndex * 15; idxIn128BytesBase < countIn128Bytes; idxIn128BytesBase += 30)
        {
            int tailFlagIndexFromFifoEntry = fifoEntry128ByteIndexBase + tailOffsetIn128Bytes;
            int tailFlagInnerIndex = tailFlagIndexFromFifoEntry % MoeCommFieldInfo::UINT64_PER_128B_BLOCK;
            int idxIn128Bytes = idxIn128BytesBase + halfLaneId;
            int idxFromFifoEntry = fifoEntry128ByteIndexBase + idxIn128Bytes;
            uint64_t tailValue = 0;
            int tailInnerIndex = (halfLaneId >= tailFlagInnerIndex) ? halfLaneId + 1 : halfLaneId;
            int targetTailIndex = tailOffsetIn128Bytes * MoeCommFieldInfo::UINT64_PER_128B_BLOCK + tailInnerIndex;
            if (halfLaneId < 15)
            {
                tailValue = aligned128BytesShm[targetTailIndex];
            }
            if (idxIn128Bytes < countIn128Bytes && halfLaneId < 15)
            {
                int flagIndex = idxIn128Bytes * MoeCommFieldInfo::UINT64_PER_128B_BLOCK
                    + idxFromFifoEntry % MoeCommFieldInfo::UINT64_PER_128B_BLOCK;
                aligned128BytesShm[flagIndex] = tailValue;
            }
            tailOffsetIn128Bytes += 2;
        }
        __syncwarp();
    }

    static __device__ __forceinline__ void rearm(
        uint32_t* u32FifoPtr, uint64_t step, int countIn128Bytes, int fifoEntry128ByteIndexBase, int warpId, int laneId)
    {
        // LL128 does not need rearming.
    }

    static __device__ __host__ __forceinline__ int computeProtoTransfer128ByteAlignedSize(
        int compact128ByteSizeBeforeProto)
    {
        // Every 15 * 128 bytes of payload need one 128-byte tail block.
        int tail128ByteSize = (compact128ByteSizeBeforeProto + 15 * 128 - 1) / (15 * 128) * 128;
        return compact128ByteSizeBeforeProto + tail128ByteSize;
    }
};
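
// Worked example (illustrative): for a compact payload of 3840 bytes (30 x 128B blocks),
// computeProtoTransfer128ByteAlignedSize adds ceil(3840 / (15 * 128)) = 2 tail blocks, so the transfer
// size becomes 3840 + 256 = 4096 bytes (32 x 128B blocks).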

using FusedMoeProto = Ll128Proto;

// using FusedMoeProto = LamportProto;

namespace fused_moe_impl
{

// Returns the copy size contributing to txCount.
__device__ __forceinline__ int startFieldG2S(MoeCommFieldInfo const& fieldInfo, int dataIndex,
    uint8_t* sharedMemoryBase, int warpId, int laneId, uint64_t* smemBar)
{
    // We may copy more data than strictly needed; just align to 16 bytes.
    int alignedShmLoadOffset = fieldInfo.getUncompactShmOffset();
    uint8_t* sharedMemoryLoadPtr = sharedMemoryBase + alignedShmLoadOffset;
    int copyByteCount = 0;
    uint8_t* loadPtr = fieldInfo.get16BAlignedLoadCopyRange(dataIndex, &copyByteCount);
    if (laneId == 0 && copyByteCount > 0)
    {
        cp_async_bulk_g2s(sharedMemoryLoadPtr, loadPtr, copyByteCount, smemBar);
    }
    return copyByteCount;
}

__device__ __forceinline__ void startFieldS2G(
    MoeCommFieldInfo const& fieldInfo, int dataIndex, uint8_t* sharedMemoryBase, int warpId, int laneId)
{
    int alignedShmStoreOffset = fieldInfo.getUncompactShmOffset();
    uint8_t* sharedMemoryStorePtr = sharedMemoryBase + alignedShmStoreOffset;
    int copyByteCount = 0;
    int headTailShmIdx;
    int headTailGlobalIdx;
    uint8_t* storePtr
        = fieldInfo.get16BAlignedStoreCopyRange(dataIndex, &copyByteCount, laneId, &headTailShmIdx, &headTailGlobalIdx);
    if (copyByteCount > 0 && laneId == 0)
    {
        cp_async_bulk_s2g(storePtr, sharedMemoryStorePtr + MoeCommFieldInfo::BYTES_PER_16B_BLOCK, copyByteCount);
    }
    if (headTailGlobalIdx >= 0)
    {
        // Copy the head and tail bytes (one byte per lane).
        fieldInfo.getRawPtr(dataIndex, nullptr)[headTailGlobalIdx] = sharedMemoryStorePtr[headTailShmIdx];
    }
    __syncwarp();
}

// SRC_AFTER_DST is true if src > dst; pack uses this.
// SRC_AFTER_DST is false if src < dst; unpack uses this.
template <typename T, bool SRC_AFTER_DST = true>
__device__ __forceinline__ void memmoveSharedMemory(uint8_t* dst, uint8_t const* src, int copySize, int laneId)
{
    int count = (copySize + sizeof(T) - 1) / sizeof(T);
    int warpLoopStart = SRC_AFTER_DST ? 0 : (count + WARP_SIZE - 1) / WARP_SIZE - 1;
    int warpLoopEnd = SRC_AFTER_DST ? (count + WARP_SIZE - 1) / WARP_SIZE : -1;
    int warpLoopUpdate = SRC_AFTER_DST ? 1 : -1;
    for (int i = warpLoopStart; i != warpLoopEnd; i += warpLoopUpdate)
    {
        int idx = laneId + i * WARP_SIZE;
        T data = T{};
        if (idx < count)
        {
            data = reinterpret_cast<T const*>(src)[idx];
        }
        __syncwarp();
        if (idx < count)
        {
            reinterpret_cast<T*>(dst)[idx] = data;
        }
        __syncwarp();
    }
}
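
// Direction note (illustrative): src and dst overlap in shared memory, so copy order matters. When packing
// (SRC_AFTER_DST == true, dst < src) the loop walks chunks from low to high addresses and never overwrites
// unread source data; when unpacking (dst > src) it walks from high to low instead. The __syncwarp() between
// the load and the store keeps the 32 lanes of one chunk coherent.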

template <bool IS_PACK = true>
__device__ __forceinline__ void memmoveFieldOnSharedMemory(
    MoeCommFieldInfo const& fieldInfo, int dataIndex, uint8_t* sharedMemoryBase, int laneId)
{
    int movOffset = fieldInfo.getMemmoveOffsets(dataIndex);
    if (movOffset == 0)
    {
        // If movOffset is 0, src and dst are identical; no memmove is needed.
        return;
    }
    int alignedBytes = 1 << fieldInfo.alignedUnitBit;
    int copySize = fieldInfo.alignedUnitCount * alignedBytes;
    uint8_t* sharedMemoryCompact = sharedMemoryBase + fieldInfo.getCompactShmOffset();
    uint8_t* sharedMemoryUncompact = sharedMemoryCompact + movOffset;
    uint8_t* sharedMemoryDst = IS_PACK ? sharedMemoryCompact : sharedMemoryUncompact;
    uint8_t* sharedMemorySrc = IS_PACK ? sharedMemoryUncompact : sharedMemoryCompact;

    if (movOffset % 16 == 0)
    {
        memmoveSharedMemory<int4, IS_PACK>(sharedMemoryDst, sharedMemorySrc, copySize, laneId);
    }
    else if (movOffset % 8 == 0)
    {
        memmoveSharedMemory<int64_t, IS_PACK>(sharedMemoryDst, sharedMemorySrc, copySize, laneId);
    }
    else if (movOffset % 4 == 0)
    {
        memmoveSharedMemory<int, IS_PACK>(sharedMemoryDst, sharedMemorySrc, copySize, laneId);
    }
    else if (movOffset % 2 == 0)
    {
        memmoveSharedMemory<int16_t, IS_PACK>(sharedMemoryDst, sharedMemorySrc, copySize, laneId);
    }
    else
    {
        memmoveSharedMemory<int8_t, IS_PACK>(sharedMemoryDst, sharedMemorySrc, copySize, laneId);
    }
}

template <int FIELD_COUNT = MOE_COMM_FIELD_MAX_COUNT>
__device__ __forceinline__ void packAllFields(
    FusedMoeFieldInfo const& sendFieldInfo, int dataIndex, uint8_t* sharedMemoryBase, int laneId)
{
#pragma unroll
    for (int i = 0; i < FIELD_COUNT; i++)
    {
        memmoveFieldOnSharedMemory<true>(sendFieldInfo.fieldsInfo[i], dataIndex, sharedMemoryBase, laneId);
    }
    __syncwarp();
}

template <int FIELD_COUNT = MOE_COMM_FIELD_MAX_COUNT>
__device__ __forceinline__ void unpackAllFields(
    FusedMoeFieldInfo const& recvFieldInfo, int dataIndex, uint8_t* sharedMemoryBase, int laneId)
{
#pragma unroll
    for (int i = FIELD_COUNT - 1; i >= 0; i--)
    {
        memmoveFieldOnSharedMemory<false>(recvFieldInfo.fieldsInfo[i], dataIndex, sharedMemoryBase, laneId);
    }
    __syncwarp();
}

__device__ __forceinline__ void initSmemBar(uint64_t* smemBar, int laneId)
{
    if (laneId == 0)
    {
        mbarrier_init(smemBar, WARP_SIZE);
    }
    __syncwarp();
}

__device__ __forceinline__ void smemBarWait(uint64_t* smemBar, uint32_t* phaseParity)
{
    while (!mbarrier_try_wait_parity(smemBar, *phaseParity))
    {
    }
    *phaseParity = 1 - *phaseParity;
}

__device__ __forceinline__ void startWorkspaceS2G(
    uint64_t* fifoEntry, uint8_t* sharedMemoryBase, int send128ByteCount, int fifo128ByteOffset, int warpId, int laneId)
{
    int copyByteCount = send128ByteCount * MoeCommFieldInfo::BYTES_PER_128B_BLOCK;
    if (laneId == 0)
    {
        cp_async_bulk_s2g(fifoEntry + fifo128ByteOffset * MoeCommFieldInfo::BYTES_PER_128B_BLOCK / sizeof(int64_t),
            sharedMemoryBase, copyByteCount);
    }
    __syncwarp();
    cp_async_bulk_commit_group();
}

__device__ __forceinline__ uint64_t startWorkspaceG2S(uint8_t* sharedMemoryBase, uint64_t* fifoEntry,
    int allLoad128ByteCount, int fifo128ByteOffset, int loaded128ByteCount, uint64_t* smemBar, int warpId, int laneId)
{
    int copyByteCount = (allLoad128ByteCount - loaded128ByteCount) * MoeCommFieldInfo::BYTES_PER_128B_BLOCK;
    if (laneId == 0)
    {
        cp_async_bulk_g2s(sharedMemoryBase + loaded128ByteCount * MoeCommFieldInfo::BYTES_PER_128B_BLOCK,
            fifoEntry
                + (fifo128ByteOffset + loaded128ByteCount) * MoeCommFieldInfo::BYTES_PER_128B_BLOCK / sizeof(int64_t),
            copyByteCount, smemBar);
    }
    return mbarrier_arrive_expect_tx(smemBar, laneId == 0 ? copyByteCount : 0);
}
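
// Note (illustrative): only lane 0 issues the bulk copy above, but all 32 lanes arrive on the barrier.
// Passing txCount == 0 for lanes 1..31 means the phase completes only after all 32 lanes have arrived and
// copyByteCount bytes have been delivered by the bulk copy.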

__device__ __forceinline__ void g2sBasicFields(FusedMoeFieldInfo const& sendFieldInfo,
    MoeExpertParallelInfo const& expertParallelInfo, int dataIndex, uint8_t* sharedMemoryBase, int laneId)
{
    int topK = expertParallelInfo.topK;
    int* tokenSelectedSlotsPtr = sendFieldInfo.getTokenSelectedSlotsPtr(dataIndex, laneId, topK);
    float* scalePtr = sendFieldInfo.getScalePtr(dataIndex, laneId, topK);
    ldgsts<4>(reinterpret_cast<int*>(sharedMemoryBase) + laneId, tokenSelectedSlotsPtr, laneId < topK);
    ldgsts<4>(reinterpret_cast<int*>(sharedMemoryBase) + laneId + topK, reinterpret_cast<int*>(scalePtr),
        laneId < topK && sendFieldInfo.expertScales != nullptr);
}
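
// Note (illustrative): the predicated ldgsts keeps all 32 lanes converged. Lanes with laneId >= topK (or a
// null expertScales) simply issue no copy, so no divergent branch is needed around the cp.async path.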

// May commit 1 cp.async group for the basic fields (tokenSelectedSlots and scales) if HAS_BASIC_FIELDS is true.
// The other fields are tracked through smemBar.
template <bool HAS_BASIC_FIELDS = true, int FIELD_COUNT = MOE_COMM_FIELD_MAX_COUNT>
__device__ __forceinline__ uint64_t g2sAllFields(FusedMoeFieldInfo const& sendFieldInfo,
    MoeExpertParallelInfo const& expertParallelInfo, int dataIndex, uint8_t* sharedMemoryBase, int warpId, int laneId,
    uint64_t* smemBar)
{
    if (HAS_BASIC_FIELDS)
    {
        g2sBasicFields(sendFieldInfo, expertParallelInfo, dataIndex, sharedMemoryBase, laneId);
        cp_async_commit_group();
    }
    int asyncLoadSize = 0;
#pragma unroll
    for (int i = 0; i < FIELD_COUNT; i++)
    {
        asyncLoadSize
            += startFieldG2S(sendFieldInfo.fieldsInfo[i], dataIndex, sharedMemoryBase, warpId, laneId, smemBar);
    }
    return mbarrier_arrive_expect_tx(smemBar, laneId == 0 ? asyncLoadSize : 0);
}

template <bool HAS_BASIC_FIELDS = true>
__device__ __forceinline__ void waitG2SBasicFields()
{
    if (HAS_BASIC_FIELDS)
    {
        cp_async_wait_group<0>();
        __syncwarp();
    }
}

__device__ __forceinline__ void waitG2SOtherFields(uint64_t* memBar, uint32_t* phaseParity)
{
    tensorrt_llm::kernels::fused_moe_impl::smemBarWait(memBar, phaseParity);
}

template <bool HAS_BASIC_FIELDS = true>
__device__ __forceinline__ void waitG2SAllFields(uint64_t* memBar, uint32_t* phaseParity)
{
    waitG2SBasicFields<HAS_BASIC_FIELDS>();
    waitG2SOtherFields(memBar, phaseParity);
}

__device__ __forceinline__ void waitS2GBulkRead()
{
    cp_async_bulk_wait_group_read<0>();
    __syncwarp();
}

__device__ __forceinline__ void s2gBasicFields(FusedMoeFieldInfo const& recvFieldInfo,
    MoeExpertParallelInfo const& expertParallelInfo, int dataIndex, uint8_t* sharedMemoryBase, int warpId, int laneId)
{
    int topK = expertParallelInfo.topK;
    int* tokenSelectedSlotsPtr = recvFieldInfo.getTokenSelectedSlotsPtr(dataIndex, laneId, topK);
    float* scalePtr = recvFieldInfo.getScalePtr(dataIndex, laneId, topK);
    if (laneId < topK)
    {
        int selectedSlot = reinterpret_cast<int*>(sharedMemoryBase)[laneId];
        *tokenSelectedSlotsPtr = selectedSlot;
        if (recvFieldInfo.expertScales != nullptr)
        {
            float scale = reinterpret_cast<float*>(sharedMemoryBase)[laneId + topK];
            *scalePtr = scale;
        }
    }
}

// Commits 1 group for all non-basic fields.
template <bool HAS_BASIC_FIELDS = true, int FIELD_COUNT = MOE_COMM_FIELD_MAX_COUNT>
__device__ __forceinline__ void s2gAllFields(FusedMoeFieldInfo const& recvFieldInfo,
    MoeExpertParallelInfo const& expertParallelInfo, int dataIndex, uint8_t* sharedMemoryBase, int warpId, int laneId)
{
    if (HAS_BASIC_FIELDS)
    {
        s2gBasicFields(recvFieldInfo, expertParallelInfo, dataIndex, sharedMemoryBase, warpId, laneId);
        __syncwarp();
    }
#pragma unroll
    for (int i = 0; i < FIELD_COUNT; i++)
    {
        startFieldS2G(recvFieldInfo.fieldsInfo[i], dataIndex, sharedMemoryBase, warpId, laneId);
    }
    cp_async_bulk_commit_group();
}

template <int FIELD_COUNT, bool HAS_BASIC_FIELD = true>
class SingleChannelCommunicator
{
public:
    __device__ __forceinline__ SingleChannelCommunicator(FusedMoeFieldInfo const& fieldInfo,
        MoeExpertParallelInfo const& expertParallelInfo, MoeSingleCommMeta const& commMeta,
        FusedMoeWorkspace const& workspace, FusedMoeWorldInfo const& worldInfo, FusedMoePairInfo const& pairInfo,
        uint64_t* smemBar, uint8_t* shmemBase)
        : mFieldInfo(fieldInfo)
        , mExpertParallelInfo(expertParallelInfo)
        , mCommMeta(commMeta)
        , mWorkspace(workspace)
        , mWorldInfo(worldInfo)
        , mPairInfo(pairInfo)
        , mSmemBar(smemBar)
        , mShmemBase(shmemBase)
    {
        mWarpId = threadIdx.x / WARP_SIZE;
        mLaneId = threadIdx.x % WARP_SIZE;

        mFifoBasePtr = mWorkspace.getFifoBasePtr(mWorldInfo, mPairInfo);
        mSenderSideFifoInfo = mWorkspace.getSenderSideFifoInfo(mWorldInfo, mPairInfo);
        mReceiverSideFifoInfo = mWorkspace.getReceiverSideFifoInfo(mWorldInfo, mPairInfo);

        mSingleTransfer128ByteCount = mCommMeta.getTransfer128ByteCount();
        mSingleCompactData128ByteCount = mCommMeta.getCompactData128ByteCount();
        // Initialize so that a new entry is needed first.
        mFifoEntry128ByteIndexBase = kFifoEntry128ByteCount;
        mFifoEntryIndex = -1;

        tensorrt_llm::kernels::fused_moe_impl::initSmemBar(mSmemBar, mLaneId);
    }

    __device__ __forceinline__ uint64_t* getFifoEntryPtr() const
    {
        return mFifoBasePtr + mFifoEntryIndex * kFifoEntrySizeInU64;
    }

    __device__ __forceinline__ bool needNewEntry() const
    {
        return mFifoEntry128ByteIndexBase + mSingleTransfer128ByteCount > kFifoEntry128ByteCount;
    }

    __device__ __forceinline__ void nextToken()
    {
        mFifoEntry128ByteIndexBase += mSingleTransfer128ByteCount;
    }

    __device__ __forceinline__ void senderInitFifo()
    {
        mHead = mSenderSideFifoInfo->head;
        mTail = mSenderSideFifoInfo->tail;
    }

    __device__ __forceinline__ void receiverInitFifo()
    {
        mHead = mReceiverSideFifoInfo->head;
        mTail = mReceiverSideFifoInfo->tail;
    }

    /*
     * Head     | 0 | 1 | 2 | 3 | 4 | 4 | 4 | 4 | 4 | 5 |
     * Tail     | 0 | 0 | 0 | 0 | 0 | 1 | 2 | 3 | 4 | 4 |
     * Writable | Y | Y | Y | Y | N | Y | Y | Y | Y | Y |
     * Readable | N | Y | Y | Y | Y | Y | Y | Y | N | Y |
     */
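
    // Reading the table above (illustrative, with kFifoDepth == 4 in this example): an entry is writable
    // while head - tail < kFifoDepth and readable while head > tail. waitEntryWritable() below spins on
    // exactly the first condition; the receiver never has to spin on the second, because the LL128 flags
    // already tell it whether an entry's data has arrived.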

    __device__ __forceinline__ void waitEntryWritable()
    {
        while (mTail + kFifoDepth <= mHead)
        {
            mTail = mSenderSideFifoInfo->tail;
        }
    }

    __device__ __forceinline__ void updateWriteEntry()
    {
        __syncwarp();
        mSenderSideFifoInfo->head = mHead;
    }

    __device__ __forceinline__ void waitEntryReadable()
    {
        // Always readable as long as the flag matches.
    }

    __device__ __forceinline__ void updateReadEntry()
    {
        mReceiverSideFifoInfo->tail = mTail;
        mSenderSideFifoInfo->tail = mTail;
    }

    __device__ __forceinline__ void newSendEntry()
    {
        mFifoEntryIndex = mHead % kFifoDepth;
        mFifoEntry128ByteIndexBase = 0;
        waitEntryWritable();
        __syncwarp();
    }

    __device__ __forceinline__ void newReceiveEntry()
    {
        mFifoEntryIndex = mTail % kFifoDepth;
        mFifoEntry128ByteIndexBase = 0;
        waitEntryReadable();
        __syncwarp();
    }

    __device__ __forceinline__ void doSend(int tokenCount, int* sendIndexMapping)
    {
        senderInitFifo();

        int sendIndex = mPairInfo.channel;
        uint32_t phaseParity = 0;
        for (; sendIndex < tokenCount; sendIndex += mPairInfo.runChannelCount)
        {
            int tokenIndex = sendIndexMapping == nullptr ? sendIndex : sendIndexMapping[sendIndex];
            tensorrt_llm::kernels::fused_moe_impl::g2sAllFields<HAS_BASIC_FIELD, FIELD_COUNT>(
                mFieldInfo, mExpertParallelInfo, tokenIndex, mShmemBase, mWarpId, mLaneId, mSmemBar);
            if (needNewEntry())
            {
                if (mFifoEntryIndex >= 0)
                {
                    // Not the first entry; update the FIFO info from the last entry.
                    mHead++;
                    updateWriteEntry();
                }
                newSendEntry();
            }
            tensorrt_llm::kernels::fused_moe_impl::waitG2SAllFields<HAS_BASIC_FIELD>(mSmemBar, &phaseParity);
            tensorrt_llm::kernels::fused_moe_impl::packAllFields<FIELD_COUNT>(
                mFieldInfo, tokenIndex, mShmemBase, mLaneId);

            FusedMoeProto::protoPack(
                mShmemBase, mHead, mSingleCompactData128ByteCount, mFifoEntry128ByteIndexBase, mWarpId, mLaneId);

            tensorrt_llm::kernels::fused_moe_impl::startWorkspaceS2G(getFifoEntryPtr(), mShmemBase,
                mSingleTransfer128ByteCount, mFifoEntry128ByteIndexBase, mWarpId, mLaneId);

            tensorrt_llm::kernels::fused_moe_impl::waitS2GBulkRead();

            nextToken();
        }
        if (mFifoEntry128ByteIndexBase > 0)
        {
            mHead++;
            updateWriteEntry();
        }
    }

    __device__ __forceinline__ void rearmFifoBuffer()
    {
        constexpr int kUint32CountPer128Byte = 128 / sizeof(uint32_t);
        uint32_t* fifoPtr = reinterpret_cast<uint32_t*>(getFifoEntryPtr());
        fifoPtr += mFifoEntry128ByteIndexBase * kUint32CountPer128Byte;

        FusedMoeProto::rearm(fifoPtr, mTail, mSingleTransfer128ByteCount, mFifoEntry128ByteIndexBase, mWarpId, mLaneId);
        __syncwarp();
    }

    __device__ __forceinline__ void doReceive(int tokenCount, int* recvIndexMapping)
    {
        receiverInitFifo();
        int recvIndex = mPairInfo.channel;
        uint32_t phaseParity = 0;
        bool needRelease = false;
        for (; recvIndex < tokenCount; recvIndex += mPairInfo.runChannelCount)
        {
            int tokenIndex = recvIndexMapping == nullptr ? recvIndex : recvIndexMapping[recvIndex];
            int loaded128ByteCount = 0;
            if (needNewEntry())
            {
                if (mFifoEntryIndex >= 0)
                {
                    // Not the first entry; update the FIFO info from the last entry.
                    mTail++;
                    needRelease = true;
                }
                newReceiveEntry();
            }
            while (loaded128ByteCount < mSingleTransfer128ByteCount)
            {
                tensorrt_llm::kernels::fused_moe_impl::startWorkspaceG2S(mShmemBase, getFifoEntryPtr(),
                    mSingleTransfer128ByteCount, mFifoEntry128ByteIndexBase, loaded128ByteCount, mSmemBar, mWarpId,
                    mLaneId);
                if (needRelease)
                {
                    updateReadEntry();
                    needRelease = false;
                }
                tensorrt_llm::kernels::fused_moe_impl::smemBarWait(mSmemBar, &phaseParity);
                loaded128ByteCount += FusedMoeProto::template checkDataReceivedInShm<false>(mShmemBase, mTail,
                    mSingleTransfer128ByteCount, mFifoEntry128ByteIndexBase, loaded128ByteCount, mWarpId, mLaneId);
            }

            FusedMoeProto::protoUnpack(mShmemBase, mTail, mSingleCompactData128ByteCount, mFifoEntry128ByteIndexBase,
                loaded128ByteCount, mWarpId, mLaneId);
            tensorrt_llm::kernels::fused_moe_impl::unpackAllFields<FIELD_COUNT>(
                mFieldInfo, tokenIndex, mShmemBase, mLaneId);
            tensorrt_llm::kernels::fused_moe_impl::s2gAllFields<HAS_BASIC_FIELD, FIELD_COUNT>(
                mFieldInfo, mExpertParallelInfo, tokenIndex, mShmemBase, mWarpId, mLaneId);
            tensorrt_llm::kernels::fused_moe_impl::waitS2GBulkRead();

            rearmFifoBuffer();
            nextToken();
        }
        if (mFifoEntry128ByteIndexBase > 0)
        {
            mTail++;
            updateReadEntry();
        }
    }

private:
    static constexpr int kFifoEntrySizeInU64 = FusedMoeCommunicator::FIFO_ENTRY_BYTES / sizeof(uint64_t);
    static constexpr int kFifoEntry128ByteCount = FusedMoeCommunicator::FIFO_ENTRY_128_BYTE_COUNT;
    static constexpr int kFifoDepth = FusedMoeCommunicator::FIFO_DEPTH;

    FusedMoeFieldInfo mFieldInfo;
    MoeExpertParallelInfo mExpertParallelInfo;
    MoeSingleCommMeta mCommMeta;
    FusedMoeWorkspace mWorkspace;
    FusedMoeWorldInfo mWorldInfo;
    FusedMoePairInfo mPairInfo;
    uint64_t* mSmemBar;
    uint8_t* mShmemBase;

    int mLaneId;
    int mWarpId;

    uint64_t* mFifoBasePtr;
    SenderSideFifoInfo* mSenderSideFifoInfo;
    ReceiverSideFifoInfo* mReceiverSideFifoInfo;

    int64_t mHead;
    int64_t mTail;

    int mSingleTransfer128ByteCount;
    int mSingleCompactData128ByteCount;
    int mFifoEntry128ByteIndexBase;
    int mFifoEntryIndex;
};

template <int FIELD_COUNT = MOE_COMM_FIELD_MAX_COUNT>
__global__ void moeAllToAllKernel(FusedMoeCommKernelParam params, FusedMoeWorkspace workspace, bool hasBasicFields)
{
    __shared__ uint64_t allWarpSmemBar[32];
    extern __shared__ int4 allWarpShm[];

    bool isSender = blockIdx.z == 0;
    int runChannelCount = gridDim.y;
    int group = threadIdx.y;
    SendRecvIndices dataIndices = isSender ? params.sendIndices : params.recvIndices;

    FusedMoePairInfo pairInfo;
    int peerRank = blockIdx.x * blockDim.y + group;
    if (peerRank >= params.worldInfo.epInfo.epSize)
    {
        return;
    }
    int tokenCount;
    int* groupStartPtr = dataIndices.getGroupStart(peerRank, tokenCount);
    if (tokenCount == 0)
    {
        return;
    }

    pairInfo.channel = blockIdx.y;
    pairInfo.runChannelCount = runChannelCount;
    pairInfo.senderRank = isSender ? params.worldInfo.epInfo.epRank : peerRank;
    pairInfo.receiverRank = isSender ? peerRank : params.worldInfo.epInfo.epRank;

    if (isSender)
    {
        int singleShmSize = params.sendCommMeta.getSingleShmSize();
        if (hasBasicFields)
        {
            SingleChannelCommunicator<FIELD_COUNT, true> comm(params.sendFieldInfo, params.expertParallelInfo,
                params.sendCommMeta, workspace, params.worldInfo, pairInfo, allWarpSmemBar + group,
                reinterpret_cast<uint8_t*>(allWarpShm) + singleShmSize * group);
            comm.doSend(tokenCount, groupStartPtr);
        }
        else
        {
            SingleChannelCommunicator<FIELD_COUNT, false> comm(params.sendFieldInfo, params.expertParallelInfo,
                params.sendCommMeta, workspace, params.worldInfo, pairInfo, allWarpSmemBar + group,
                reinterpret_cast<uint8_t*>(allWarpShm) + singleShmSize * group);
            comm.doSend(tokenCount, groupStartPtr);
        }
    }
    else
    {
        int singleShmSize = params.recvCommMeta.getSingleShmSize();
        if (hasBasicFields)
        {
            SingleChannelCommunicator<FIELD_COUNT, true> comm(params.recvFieldInfo, params.expertParallelInfo,
                params.recvCommMeta, workspace, params.worldInfo, pairInfo, allWarpSmemBar + group,
                reinterpret_cast<uint8_t*>(allWarpShm) + singleShmSize * group);
            comm.doReceive(tokenCount, groupStartPtr);
        }
        else
        {
            SingleChannelCommunicator<FIELD_COUNT, false> comm(params.recvFieldInfo, params.expertParallelInfo,
                params.recvCommMeta, workspace, params.worldInfo, pairInfo, allWarpSmemBar + group,
                reinterpret_cast<uint8_t*>(allWarpShm) + singleShmSize * group);
            comm.doReceive(tokenCount, groupStartPtr);
        }
    }
}

int computeMoeAlltoallMaxDynamicSharedMemorySize()
{
    int devId = -1;
    TLLM_CUDA_CHECK(cudaGetDevice(&devId));
    cudaFuncAttributes attr{};
    TLLM_CUDA_CHECK(cudaFuncGetAttributes(&attr, (void const*) moeAllToAllKernel<1>));
    int staticSmem = static_cast<int>(attr.sharedSizeBytes);
    int maxPerBlockShmOptin = 0;
    TLLM_CUDA_CHECK(cudaDeviceGetAttribute(&maxPerBlockShmOptin, cudaDevAttrMaxSharedMemoryPerBlockOptin, devId));
    return maxPerBlockShmOptin - staticSmem;
}

} // namespace fused_moe_impl

void FusedMoeFieldInfo::fillMetaInfo(
    MoeSingleCommMeta* singleCommMeta, int topK, bool hasScales, bool hasBasicFields) const
{
    singleCommMeta->singleCompactAlignedSize = computeSingleCompactSize(topK, hasScales, hasBasicFields);
    singleCommMeta->singleUncompactAlignedSize = computeSingleUncompactSize(topK, hasScales, hasBasicFields);
    singleCommMeta->singleTransferAlignedSize
        = FusedMoeProto::computeProtoTransfer128ByteAlignedSize(singleCommMeta->singleCompactAlignedSize);
}

void FusedMoeFieldInfo::fillFieldPlacementInfo(int topK, bool hasBasicFields)
{
    int basicFieldSize = 0;
    if (hasBasicFields)
    {
        basicFieldSize = topK * sizeof(int) + (expertScales != nullptr ? topK * sizeof(float) : 0);
        // Align to 16 bytes.
        basicFieldSize = (basicFieldSize + MoeCommFieldInfo::BYTES_PER_16B_BLOCK - 1)
            / MoeCommFieldInfo::BYTES_PER_16B_BLOCK * MoeCommFieldInfo::BYTES_PER_16B_BLOCK;
    }
    int offset = basicFieldSize;
    int unalignedFieldIndex = 0;
    for (int i = 0; i < fieldCount; i++)
    {
        fieldsInfo[i].compact16BOffset = offset / MoeCommFieldInfo::BYTES_PER_16B_BLOCK;
        offset += fieldsInfo[i].getFieldCompactSize();
        fieldsInfo[i].unalignedFieldIndex = unalignedFieldIndex;
        if (fieldsInfo[i].alignedUnitBit < 4)
        {
            unalignedFieldIndex++;
        }
    }
    for (int i = fieldCount; i < MOE_COMM_FIELD_MAX_COUNT; i++)
    {
        fieldsInfo[i].setUnused();
    }
}

void FusedMoeWorkspace::initializeLocalWorkspace(FusedMoeWorldInfo const& worldInfo)
{
    int epSize = worldInfo.epInfo.epSize;
    int epRank = worldInfo.epInfo.epRank;
    size_t fifoSize = static_cast<size_t>(FusedMoeCommunicator::FIFO_TOTAL_BYTES) * epSize * channelCount;
    size_t senderSideInfoSize = sizeof(SenderSideFifoInfo) * epSize * channelCount;
    size_t receiverSideInfoSize = sizeof(ReceiverSideFifoInfo) * epSize * channelCount;
    uint64_t* localWorkspacePtr = workspacePtr + epRank * rankStrideInU64;
    TLLM_CU_CHECK(cuMemsetD32(reinterpret_cast<CUdeviceptr>(localWorkspacePtr), FusedMoeProto::INITIALIZED_VALUE,
        fifoSize / sizeof(uint32_t)));
    TLLM_CUDA_CHECK(cudaMemset(
        reinterpret_cast<uint8_t*>(localWorkspacePtr) + fifoSize, 0, senderSideInfoSize + receiverSideInfoSize));
}

void moeAllToAll(FusedMoeCommKernelParam params, FusedMoeWorkspace workspace, cudaStream_t stream)
{
    bool hasBasicFields = params.sendFieldInfo.tokenSelectedSlots != nullptr;
    int warpSendShmSize = params.sendCommMeta.getSingleShmSize();
    int warpRecvShmSize = params.recvCommMeta.getSingleShmSize();
    int warpShmSize = warpSendShmSize;
    int epSize = params.worldInfo.epInfo.epSize;
    TLLM_CHECK_WITH_INFO(warpSendShmSize == warpRecvShmSize, "warpSendShmSize(%d) not same as warpRecvShmSize(%d)",
        warpSendShmSize, warpRecvShmSize);
    int maxGroupCountPerCta = std::min(params.worldInfo.epInfo.epSize, FusedMoeCommunicator::MAX_GROUP_COUNT_PER_BLOCK);
    static int maxDynamicShmSize = fused_moe_impl::computeMoeAlltoallMaxDynamicSharedMemorySize();
    int groupCountPerCta = std::min(maxGroupCountPerCta, maxDynamicShmSize / warpShmSize);

    int maxFieldCount = std::max(params.sendFieldInfo.fieldCount, params.recvFieldInfo.fieldCount);
    auto getFunc = [](int fieldCount)
    {
        switch (fieldCount)
        {
        case 1: return fused_moe_impl::moeAllToAllKernel<1>;
        case 2: return fused_moe_impl::moeAllToAllKernel<2>;
        case 3: return fused_moe_impl::moeAllToAllKernel<3>;
        case 4: return fused_moe_impl::moeAllToAllKernel<4>;
        case 5: return fused_moe_impl::moeAllToAllKernel<5>;
        case 6: return fused_moe_impl::moeAllToAllKernel<6>;
        case 7: return fused_moe_impl::moeAllToAllKernel<7>;
        case 8: return fused_moe_impl::moeAllToAllKernel<8>;
        default: return fused_moe_impl::moeAllToAllKernel<8>;
        }
        return fused_moe_impl::moeAllToAllKernel<8>;
    };
    auto* kernelFn = getFunc(maxFieldCount);

    if (groupCountPerCta * warpShmSize > 48 * 1024)
    {
        TLLM_CUDA_CHECK(cudaFuncSetAttribute(
            kernelFn, cudaFuncAttributeMaxDynamicSharedMemorySize, groupCountPerCta * warpShmSize));
    }
    for (; groupCountPerCta > 0; groupCountPerCta--)
    {
        int dynamicShmSize = groupCountPerCta * warpShmSize;
        int numBlocks = 0;
        if (cudaOccupancyMaxActiveBlocksPerMultiprocessor(
                &numBlocks, kernelFn, WARP_SIZE * groupCountPerCta, dynamicShmSize)
            != cudaSuccess)
        {
            continue;
        }
        if (numBlocks >= 1)
        {
            break;
        }
    }
    TLLM_CHECK_WITH_INFO(
        groupCountPerCta >= 1, "computed groupCount=%d, warpShmSize=%d", groupCountPerCta, warpShmSize);
    int ctaPerChannel = (epSize + groupCountPerCta - 1) / groupCountPerCta;
    groupCountPerCta = (epSize + ctaPerChannel - 1) / ctaPerChannel;
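    // Worked example for the rebalancing above (illustrative): with epSize = 10 and an occupancy-limited
    // groupCountPerCta of 4, ctaPerChannel = ceil(10 / 4) = 3 and groupCountPerCta is rebalanced to
    // ceil(10 / 3) = 4, so the three CTAs per channel cover 4 + 4 + 2 peer ranks.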
    int totalDynamicShmSize = warpShmSize * groupCountPerCta;

    dim3 block = FusedMoeCommunicator::getLaunchBlockDim(groupCountPerCta);
    dim3 grid = FusedMoeCommunicator::getLaunchGridDim(params.worldInfo.epInfo.epSize, groupCountPerCta);
    kernelFn<<<grid, block, totalDynamicShmSize, stream>>>(params, workspace, hasBasicFields);
    TLLM_CUDA_CHECK(cudaGetLastError());
}

int FusedMoeCommunicator::maxSmCount = -1;
bool FusedMoeCommunicator::maxSmCountUsed = false;

void setMaxUsableSmCount(int smCount)
{
    FusedMoeCommunicator::setMaxUsableSmCount(smCount);
}

size_t getFusedMoeCommWorkspaceSize(int epSize)
{
    int channelCount = FusedMoeCommunicator::getMoeCommChannelCount(epSize);
    size_t workspaceSize = FusedMoeWorkspace::computeWorkspaceSizePreRank(epSize, channelCount);
    return workspaceSize;
}

void constructWorkspace(FusedMoeWorkspace* workspace, uint64_t* workspacePtr, size_t rankStrideInU64, int epSize)
{
    workspace->workspacePtr = workspacePtr;
    workspace->rankStrideInU64 = rankStrideInU64;
    workspace->channelCount = FusedMoeCommunicator::getMoeCommChannelCount(epSize);
}

void initializeFusedMoeLocalWorkspace(FusedMoeWorkspace* workspace, FusedMoeWorldInfo const& worldInfo)
{
    workspace->initializeLocalWorkspace(worldInfo);
}

namespace fused_moe_comm_tests
{

__global__ void g2sKernel(FusedMoeFieldInfo allFieldInfo, MoeExpertParallelInfo expertParallelInfo,
    MoeSingleCommMeta singleCommMeta, int tokenCount, int* shmDump, bool hasBasicFields)
{
    __shared__ uint64_t allWarpSmemBar[32];
    extern __shared__ int4 allWarpShm[];
    int laneId = threadIdx.x % WARP_SIZE;
    int warpId = threadIdx.x / WARP_SIZE;
    int warpCount = blockDim.x / WARP_SIZE;
    int tokenIndex = warpId + blockIdx.x * warpCount;
    if (tokenIndex >= tokenCount)
    {
        return;
    }

    int singleShmSize = singleCommMeta.singleUncompactAlignedSize;

    tensorrt_llm::kernels::fused_moe_impl::initSmemBar(&allWarpSmemBar[warpId], laneId);
    uint32_t phaseParity = 0;

    uint8_t* sharedMemoryBase = reinterpret_cast<uint8_t*>(allWarpShm) + singleShmSize * warpId;

    if (hasBasicFields)
    {
        tensorrt_llm::kernels::fused_moe_impl::g2sAllFields<true>(
            allFieldInfo, expertParallelInfo, tokenIndex, sharedMemoryBase, warpId, laneId, &allWarpSmemBar[warpId]);
        tensorrt_llm::kernels::fused_moe_impl::waitG2SAllFields<true>(&allWarpSmemBar[warpId], &phaseParity);
    }
    else
    {
        tensorrt_llm::kernels::fused_moe_impl::g2sAllFields<false>(
            allFieldInfo, expertParallelInfo, tokenIndex, sharedMemoryBase, warpId, laneId, &allWarpSmemBar[warpId]);
        tensorrt_llm::kernels::fused_moe_impl::waitG2SAllFields<false>(&allWarpSmemBar[warpId], &phaseParity);
    }

    for (int offset = laneId; offset < singleShmSize / sizeof(int); offset += WARP_SIZE)
    {
        shmDump[tokenIndex * singleShmSize / sizeof(int) + offset] = reinterpret_cast<int*>(sharedMemoryBase)[offset];
    }
}

void launchSingleG2S(FusedMoeFieldInfo const& sendFieldInfo, MoeExpertParallelInfo const& expertParallelInfo,
    int tokenCount, int* shmDump, int warpsPerBlock, bool hasBasicFields, cudaStream_t stream)
{
    int warpShmSize = sendFieldInfo.computeSingleUncompactSize(
        expertParallelInfo.topK, sendFieldInfo.expertScales != nullptr, hasBasicFields);
    dim3 blockDim(WARP_SIZE * warpsPerBlock, 1, 1);
    dim3 gridDim((tokenCount + warpsPerBlock - 1) / warpsPerBlock, 1, 1);
    MoeSingleCommMeta singleCommMeta;
    sendFieldInfo.fillMetaInfo(
        &singleCommMeta, expertParallelInfo.topK, sendFieldInfo.expertScales != nullptr, hasBasicFields);
    TLLM_CUDA_CHECK(
        cudaFuncSetAttribute(g2sKernel, cudaFuncAttributeMaxDynamicSharedMemorySize, warpShmSize * warpsPerBlock));
    g2sKernel<<<gridDim, blockDim, warpShmSize * warpsPerBlock, stream>>>(
        sendFieldInfo, expertParallelInfo, singleCommMeta, tokenCount, shmDump, hasBasicFields);
    TLLM_CUDA_CHECK(cudaGetLastError());
}

__global__ void s2gKernel(FusedMoeFieldInfo recvFieldInfo, MoeExpertParallelInfo expertParallelInfo,
    MoeSingleCommMeta singleCommMeta, int tokenCount, int* shmPreload, bool hasBasicFields)
{
    extern __shared__ int4 allWarpShm[];
    int laneId = threadIdx.x % WARP_SIZE;
    int warpId = threadIdx.x / WARP_SIZE;
    int warpCount = blockDim.x / WARP_SIZE;
    int tokenIndex = warpId + blockIdx.x * warpCount;
    if (tokenIndex >= tokenCount)
    {
        return;
    }
    int singleShmSize = singleCommMeta.singleUncompactAlignedSize;
    uint8_t* sharedMemoryBase = reinterpret_cast<uint8_t*>(allWarpShm) + singleShmSize * warpId;

    for (int offset = laneId; offset < singleShmSize / sizeof(int); offset += WARP_SIZE)
    {
        reinterpret_cast<int*>(sharedMemoryBase)[offset]
            = shmPreload[tokenIndex * singleShmSize / sizeof(int) + offset];
    }
    __syncwarp();

    if (hasBasicFields)
    {
        tensorrt_llm::kernels::fused_moe_impl::s2gAllFields<true>(
            recvFieldInfo, expertParallelInfo, tokenIndex, sharedMemoryBase, warpId, laneId);
    }
    else
    {
        tensorrt_llm::kernels::fused_moe_impl::s2gAllFields<false>(
            recvFieldInfo, expertParallelInfo, tokenIndex, sharedMemoryBase, warpId, laneId);
    }

    tensorrt_llm::kernels::fused_moe_impl::waitS2GBulkRead();
}

void launchSingleS2G(FusedMoeFieldInfo const& recvFieldInfo, MoeExpertParallelInfo const& expertParallelInfo,
    int tokenCount, int* shmPreload, int warpsPerBlock, bool hasBasicFields, cudaStream_t stream)
{
    int warpShmSize = recvFieldInfo.computeSingleUncompactSize(
        expertParallelInfo.topK, recvFieldInfo.expertScales != nullptr, hasBasicFields);
    dim3 blockDim(WARP_SIZE * warpsPerBlock, 1, 1);
    dim3 gridDim((tokenCount + warpsPerBlock - 1) / warpsPerBlock, 1, 1);
    MoeSingleCommMeta singleCommMeta;
    recvFieldInfo.fillMetaInfo(
        &singleCommMeta, expertParallelInfo.topK, recvFieldInfo.expertScales != nullptr, hasBasicFields);
    TLLM_CUDA_CHECK(
        cudaFuncSetAttribute(s2gKernel, cudaFuncAttributeMaxDynamicSharedMemorySize, warpShmSize * warpsPerBlock));
    s2gKernel<<<gridDim, blockDim, warpShmSize * warpsPerBlock, stream>>>(
        recvFieldInfo, expertParallelInfo, singleCommMeta, tokenCount, shmPreload, hasBasicFields);
    TLLM_CUDA_CHECK(cudaGetLastError());
}

__global__ void loopbackKernel(FusedMoeFieldInfo sendFieldInfo, FusedMoeFieldInfo recvFieldInfo,
    MoeExpertParallelInfo expertParallelInfo, MoeSingleCommMeta sendCommMeta, MoeSingleCommMeta recvCommMeta,
    int* recvIndexMapping, int tokenCount, bool hasBasicFields)
{
    __shared__ uint64_t allWarpSmemBar[32];
    extern __shared__ int4 allWarpShm[];
    int laneId = threadIdx.x % WARP_SIZE;
    int warpId = threadIdx.x / WARP_SIZE;
    int warpCount = blockDim.x / WARP_SIZE;
    int tokenIndex = warpId + blockIdx.x * warpCount;
    if (tokenIndex >= tokenCount)
    {
        return;
    }

    int recvTokenIndex = recvIndexMapping[tokenIndex];

    tensorrt_llm::kernels::fused_moe_impl::initSmemBar(&allWarpSmemBar[warpId], laneId);
    uint32_t phaseParity = 0;

    int singleShmSize = sendCommMeta.getSingleShmSize();

    uint8_t* sharedMemoryBase = reinterpret_cast<uint8_t*>(allWarpShm) + singleShmSize * warpId;

    if (hasBasicFields)
    {
        tensorrt_llm::kernels::fused_moe_impl::g2sAllFields<true>(
            sendFieldInfo, expertParallelInfo, tokenIndex, sharedMemoryBase, warpId, laneId, &allWarpSmemBar[warpId]);
    }
    else
    {
        tensorrt_llm::kernels::fused_moe_impl::g2sAllFields<false>(
            sendFieldInfo, expertParallelInfo, tokenIndex, sharedMemoryBase, warpId, laneId, &allWarpSmemBar[warpId]);
    }

    if (hasBasicFields)
    {
        tensorrt_llm::kernels::fused_moe_impl::waitG2SAllFields<true>(&allWarpSmemBar[warpId], &phaseParity);
    }
    else
    {
        tensorrt_llm::kernels::fused_moe_impl::waitG2SAllFields<false>(&allWarpSmemBar[warpId], &phaseParity);
    }

    tensorrt_llm::kernels::fused_moe_impl::packAllFields(sendFieldInfo, tokenIndex, sharedMemoryBase, laneId);

    tokenIndex = recvTokenIndex; // Switch to recvTokenIndex.

    tensorrt_llm::kernels::fused_moe_impl::unpackAllFields(recvFieldInfo, tokenIndex, sharedMemoryBase, laneId);

    if (hasBasicFields)
    {
        tensorrt_llm::kernels::fused_moe_impl::s2gAllFields<true>(
            recvFieldInfo, expertParallelInfo, tokenIndex, sharedMemoryBase, warpId, laneId);
    }
    else
    {
        tensorrt_llm::kernels::fused_moe_impl::s2gAllFields<false>(
            recvFieldInfo, expertParallelInfo, tokenIndex, sharedMemoryBase, warpId, laneId);
    }

    cp_async_bulk_wait_group_read<0>();
    __syncwarp();
}

// G2S -> Pack -> Unpack -> S2G
void launchLoopback(FusedMoeFieldInfo const& sendFieldInfo, FusedMoeFieldInfo const& recvFieldInfo,
    MoeExpertParallelInfo const& expertParallelInfo, int* recvIndexMapping, int tokenCount, int warpsPerBlock,
    bool hasBasicFields, cudaStream_t stream)
{
    MoeSingleCommMeta sendCommMeta, recvCommMeta;
    sendFieldInfo.fillMetaInfo(
        &sendCommMeta, expertParallelInfo.topK, sendFieldInfo.expertScales != nullptr, hasBasicFields);
    recvFieldInfo.fillMetaInfo(
        &recvCommMeta, expertParallelInfo.topK, recvFieldInfo.expertScales != nullptr, hasBasicFields);
    int warpSendShmSize = sendCommMeta.getSingleShmSize();
    int warpRecvShmSize = recvCommMeta.getSingleShmSize();
    int warpShmSize = warpSendShmSize;
    TLLM_CHECK_WITH_INFO(warpSendShmSize == warpRecvShmSize, "warpSendShmSize(%d) not same as warpRecvShmSize(%d)",
        warpSendShmSize, warpRecvShmSize);
    dim3 blockDim(WARP_SIZE * warpsPerBlock, 1, 1);
    dim3 gridDim((tokenCount + warpsPerBlock - 1) / warpsPerBlock, 1, 1);
    TLLM_CUDA_CHECK(
        cudaFuncSetAttribute(loopbackKernel, cudaFuncAttributeMaxDynamicSharedMemorySize, warpShmSize * warpsPerBlock));
    loopbackKernel<<<gridDim, blockDim, warpShmSize * warpsPerBlock, stream>>>(sendFieldInfo, recvFieldInfo,
        expertParallelInfo, sendCommMeta, recvCommMeta, recvIndexMapping, tokenCount, hasBasicFields);
    TLLM_CUDA_CHECK(cudaGetLastError());
}

template <bool HAS_BASIC_FIELD = true>
__global__ void localFifoSendRecvKernel(FusedMoeFieldInfo sendFieldInfo, FusedMoeFieldInfo recvFieldInfo,
    MoeExpertParallelInfo expertParallelInfo, MoeSingleCommMeta sendCommMeta, MoeSingleCommMeta recvCommMeta,
    FusedMoeWorkspace fusedMoeWorkspace, int* sendIndexMapping, int* recvIndexMapping, int tokenCount)
{
    __shared__ uint64_t allWarpSmemBar[32];
    extern __shared__ int4 allWarpShm[];

    FusedMoeWorldInfo worldInfo;
    worldInfo.epInfo.epRank = 0;
    worldInfo.epInfo.epSize = 1;

    int warpId = threadIdx.x / WARP_SIZE;
    int warpCount = blockDim.x / WARP_SIZE;

    FusedMoePairInfo pairInfo;
    pairInfo.senderRank = 0;
    pairInfo.receiverRank = 0;
    pairInfo.channel = blockIdx.z * warpCount + warpId;
    pairInfo.runChannelCount = gridDim.z * warpCount;

    if (blockIdx.y == 0)
    {
        tensorrt_llm::kernels::fused_moe_impl::SingleChannelCommunicator<MOE_COMM_FIELD_MAX_COUNT, HAS_BASIC_FIELD>
            senderComm(sendFieldInfo, expertParallelInfo, sendCommMeta, fusedMoeWorkspace, worldInfo, pairInfo,
                &allWarpSmemBar[warpId],
                reinterpret_cast<uint8_t*>(&allWarpShm[0]) + warpId * sendCommMeta.getSingleShmSize());
        senderComm.doSend(tokenCount, sendIndexMapping);
    }
    else
    {
        tensorrt_llm::kernels::fused_moe_impl::SingleChannelCommunicator<MOE_COMM_FIELD_MAX_COUNT, HAS_BASIC_FIELD>
            recverComm(recvFieldInfo, expertParallelInfo, recvCommMeta, fusedMoeWorkspace, worldInfo, pairInfo,
                &allWarpSmemBar[warpId],
                reinterpret_cast<uint8_t*>(&allWarpShm[0]) + warpId * recvCommMeta.getSingleShmSize());
        recverComm.doReceive(tokenCount, recvIndexMapping);
    }
}

void launchLocalFifoSendRecv(FusedMoeFieldInfo const& sendFieldInfo, FusedMoeFieldInfo const& recvFieldInfo,
    MoeExpertParallelInfo const& expertParallelInfo, int* sendIndexMapping, int* recvIndexMapping,
    FusedMoeWorkspace fusedMoeWorkspace, int tokenCount, int warpsPerBlock, int blockChannelCount, bool hasBasicFields,
    cudaStream_t stream)
{
    MoeSingleCommMeta sendCommMeta, recvCommMeta;
    sendFieldInfo.fillMetaInfo(
        &sendCommMeta, expertParallelInfo.topK, sendFieldInfo.expertScales != nullptr, hasBasicFields);
    recvFieldInfo.fillMetaInfo(
        &recvCommMeta, expertParallelInfo.topK, recvFieldInfo.expertScales != nullptr, hasBasicFields);
    int warpSendShmSize = sendCommMeta.getSingleShmSize();
    int warpRecvShmSize = recvCommMeta.getSingleShmSize();
    int warpShmSize = warpSendShmSize;
    TLLM_CHECK_WITH_INFO(warpSendShmSize == warpRecvShmSize, "warpSendShmSize(%d) not same as warpRecvShmSize(%d)",
        warpSendShmSize, warpRecvShmSize);
    dim3 blockDim(WARP_SIZE * warpsPerBlock, 1, 1);
    dim3 gridDim(1, 2, blockChannelCount);
    auto* kernelFn = localFifoSendRecvKernel<>;
    if (hasBasicFields)
    {
        kernelFn = localFifoSendRecvKernel<true>;
    }
    else
    {
        kernelFn = localFifoSendRecvKernel<false>;
    }
    TLLM_CUDA_CHECK(
        cudaFuncSetAttribute(kernelFn, cudaFuncAttributeMaxDynamicSharedMemorySize, warpShmSize * warpsPerBlock));
    kernelFn<<<gridDim, blockDim, warpShmSize * warpsPerBlock, stream>>>(sendFieldInfo, recvFieldInfo,
        expertParallelInfo, sendCommMeta, recvCommMeta, fusedMoeWorkspace, sendIndexMapping, recvIndexMapping,
        tokenCount);
    TLLM_CUDA_CHECK(cudaGetLastError());
}

} // namespace fused_moe_comm_tests

} // namespace kernels
} // namespace tensorrt_llm