/*
 * Copyright (c) 2019-2025, NVIDIA CORPORATION. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "tensorrt_llm/kernels/fusedMoeCommKernels.h"

#include <cstdint>

#include "tensorrt_llm/common/cudaUtils.h"
#include "tensorrt_llm/common/logger.h"

namespace tensorrt_llm
{
namespace kernels
{

// Thin PTX wrappers for the mbarrier and cp.async(.bulk) instructions used by the kernels below.

static __device__ __forceinline__ uint32_t __as_ptr_smem(void const* __ptr)
{
    // Consider adding debug asserts here.
    return static_cast<uint32_t>(__cvta_generic_to_shared(__ptr));
}

static __device__ __forceinline__ uint64_t __as_ptr_gmem(void const* __ptr)
{
    // Consider adding debug asserts here.
    return static_cast<uint64_t>(__cvta_generic_to_global(__ptr));
}

__device__ __forceinline__ void fence_release_sys()
{
    asm volatile("fence.release.sys;" : : : "memory");
}

__device__ __forceinline__ void mbarrier_init(uint64_t* addr, uint32_t const& count)
{
#if defined(__CUDACC__) && __CUDA_ARCH__ >= 800
    asm("mbarrier.init.shared.b64 [%0], %1;" : : "r"(__as_ptr_smem(addr)), "r"(count) : "memory");
#endif
}

__device__ __forceinline__ void mbarrier_expect_tx(uint64_t* addr, const uint32_t txCount)
{
#if defined(__CUDACC__) && __CUDA_ARCH__ >= 900
    asm("mbarrier.expect_tx.relaxed.cta.shared::cta.b64 [%0], %1;"
        :
        : "r"(__as_ptr_smem(addr)), "r"(txCount)
        : "memory");
#endif
}

__device__ __forceinline__ uint64_t mbarrier_arrive(uint64_t* addr)
{
#if defined(__CUDACC__) && __CUDA_ARCH__ >= 800
    uint64_t state;
    asm("mbarrier.arrive.shared.b64 %0, [%1];" : "=l"(state) : "r"(__as_ptr_smem(addr)) : "memory");
    return state;
#else
    return 0;
#endif
}

__device__ __forceinline__ uint64_t mbarrier_arrive_expect_tx(uint64_t* addr, const uint32_t txCount)
{
#if defined(__CUDACC__) && __CUDA_ARCH__ >= 900
    uint64_t state;
    asm("mbarrier.arrive.expect_tx.release.cta.shared::cta.b64 %0, [%1], %2;"
        : "=l"(state)
        : "r"(__as_ptr_smem(addr)), "r"(txCount)
        : "memory");
    return state;
#else
    return 0;
#endif
}

__device__ __forceinline__ bool mbarrier_try_wait_parity(uint64_t* addr, uint32_t const& phaseParity)
{
#if defined(__CUDACC__) && __CUDA_ARCH__ >= 900
    uint32_t waitComplete;
    asm("{\n\t .reg .pred P_OUT; \n\t"
        "mbarrier.try_wait.parity.shared::cta.b64 P_OUT, [%1], %2;\n\t"
        "selp.b32 %0, 1, 0, P_OUT; \n"
        "}"
        : "=r"(waitComplete)
        : "r"(__as_ptr_smem(addr)), "r"(phaseParity)
        : "memory");
    return static_cast<bool>(waitComplete);
#else
    return false;
#endif
}

template <int COPY_SIZE>
__device__ __forceinline__ void ldgsts(int* dstShm, int const* srcMem, bool predGuard)
{
#if defined(__CUDACC__) && __CUDA_ARCH__ >= 800
    asm volatile(
        "{\n"
        " .reg .pred p;\n"
        " setp.ne.b32 p, %0, 0;\n"
        " @p cp.async.ca.shared.global [%1], [%2], %3;\n"
        "}\n" ::"r"((int) predGuard),
        "r"(__as_ptr_smem(dstShm)), "l"(__as_ptr_gmem(srcMem)), "n"(COPY_SIZE));
#endif
}

__device__ __forceinline__ void cp_async_commit_group()
{
#if defined(__CUDACC__) && __CUDA_ARCH__ >= 800
    asm volatile("cp.async.commit_group;" : : :);
#endif
}

template <int N>
__device__ __forceinline__ void cp_async_wait_group()
{
#if defined(__CUDACC__) && __CUDA_ARCH__ >= 800
    asm volatile("cp.async.wait_group %0;" : : "n"(N) : "memory");
#endif
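    // Note: cp.async.wait_group N blocks until at most N of the most recently committed
    // cp.async groups are still pending, so cp_async_wait_group<0>() waits for all of them.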
} __device__ __forceinline__ void cp_async_bulk_g2s(void* dstMem, void const* srcMem, int copySize, uint64_t* smemBar) { #if defined(__CUDACC__) && __CUDA_ARCH__ >= 900 asm("cp.async.bulk.shared::cta.global.mbarrier::complete_tx::bytes [%0], [%1], %2, [%3];" : : "r"(__as_ptr_smem(dstMem)), "l"(__as_ptr_gmem(srcMem)), "r"(copySize), "r"(__as_ptr_smem(smemBar)) : "memory"); #endif } __device__ __forceinline__ void cp_async_bulk_s2g(void* dstMem, void const* srcMem, int copySize) { #if defined(__CUDACC__) && __CUDA_ARCH__ >= 900 asm("cp.async.bulk.global.shared::cta.bulk_group [%0], [%1], %2;" : : "l"(__as_ptr_gmem(dstMem)), "r"(__as_ptr_smem(srcMem)), "r"(copySize) : "memory"); #endif } __device__ __forceinline__ void cp_async_bulk_commit_group() { #if defined(__CUDACC__) && __CUDA_ARCH__ >= 900 asm volatile("cp.async.bulk.commit_group;" : : :); #endif } template __device__ __forceinline__ void cp_async_bulk_wait_group() { #if defined(__CUDACC__) && __CUDA_ARCH__ >= 900 asm volatile("cp.async.bulk.wait_group %0;" : : "n"(N) : "memory"); #endif } template __device__ __forceinline__ void cp_async_bulk_wait_group_read() { #if defined(__CUDACC__) && __CUDA_ARCH__ >= 900 asm volatile("cp.async.bulk.wait_group.read %0;" : : "n"(N) : "memory"); #endif } __host__ void MoeCommFieldInfo::fillFieldInfo(uint8_t* dataPtr, size_t elementSize, int vectorSize, int stride) { TLLM_CHECK(elementSize == 1 || elementSize == 2 || elementSize == 4 || elementSize == 8 || elementSize == 16); dataPtrBase = dataPtr; uint64_t dataPtrU64 = reinterpret_cast(dataPtr); while (elementSize < 16 && dataPtrU64 % (elementSize * 2) == 0 && vectorSize % 2 == 0 && stride % 2 == 0) { elementSize *= 2; vectorSize /= 2; stride /= 2; } if (elementSize == 16) { alignedUnitBit = 4; } else if (elementSize == 8) { alignedUnitBit = 3; } else if (elementSize == 4) { alignedUnitBit = 2; } else if (elementSize == 2) { alignedUnitBit = 1; } else { alignedUnitBit = 0; } alignedUnitCount = vectorSize; alignedUnitStride = stride; } class Ll128Proto { public: static constexpr uint32_t INITIALIZED_VALUE = 0xFFFFFFFFU; template static __device__ __forceinline__ int checkDataReceivedInShm(uint8_t* sharedMemoryBase, uint64_t step, int countIn128Bytes, int fifoEntry128ByteIndexBase, int loaded128ByteCount, int warpId, int laneId) { // return value should be how many package already been received. // 0 means no data received, -1 means has received finish package(should be the very first 128 Byte). uint64_t* aligned128BytesShm = reinterpret_cast(sharedMemoryBase); int totalValidCount = 0; for (int idxBase = loaded128ByteCount; idxBase < countIn128Bytes; idxBase += WARP_SIZE) { int idx = idxBase + laneId; bool valid = false; bool finish = false; if (idx < countIn128Bytes) { int indexInFifoEntry = fifoEntry128ByteIndexBase + idx; uint64_t value = aligned128BytesShm[idx * MoeCommFieldInfo::UINT64_PER_128B_BLOCK + indexInFifoEntry % MoeCommFieldInfo::UINT64_PER_128B_BLOCK]; if (USE_FINISH) { finish = (value == (step & (1ULL << 63ULL))); valid = (value == step) || finish; } else { valid = (value == step); } } __syncwarp(); unsigned validMask = __ballot_sync(WARP_MASK, valid); // here we check valid in order, if previous valid is not true, we ignore the current valid. int validCount = (validMask == WARP_MASK) ? WARP_SIZE : (__ffs(~validMask) - 1); if (USE_FINISH) { unsigned finishedMask = __ballot_sync(WARP_MASK, finish); // finish should be the very first 128 Byte. 
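// The finish flag is expected only in the very first 128-byte block of the entry, so it is
// enough to test lane 0's bit of the ballot below.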
if (finishedMask & 0x1) { return -1; } } totalValidCount += validCount; if (validCount != WARP_SIZE) { break; } } return totalValidCount; } static __device__ __forceinline__ void protoPack(uint8_t* sharedMemoryBase, uint64_t step, int countIn128Bytes, int fifoEntry128ByteIndexBase, int warpId, int laneId) { uint64_t* aligned128BytesShm = reinterpret_cast(sharedMemoryBase); int halfLaneId = laneId % 16; int halfIndex = laneId / 16; int tailOffsetIn128Bytes = countIn128Bytes + halfIndex; // for LL128 15 * 128 Bytes will be packed to 16 * 128 Bytes, each 16 threads is used for one 15 * 128 bytes. for (int idxIn128BytesBase = halfIndex * 15; idxIn128BytesBase < countIn128Bytes; idxIn128BytesBase += 30) { int tailFlagIndexFromFifoEntry = fifoEntry128ByteIndexBase + tailOffsetIn128Bytes; int tailFlagInnerIndex = tailFlagIndexFromFifoEntry % MoeCommFieldInfo::UINT64_PER_128B_BLOCK; int idxIn128Bytes = idxIn128BytesBase + halfLaneId; int idxFromFifoEntry = fifoEntry128ByteIndexBase + idxIn128Bytes; uint64_t tailValue = step; uint64_t tailInnerIndex = (halfLaneId >= tailFlagInnerIndex) ? halfLaneId + 1 : halfLaneId; if (halfLaneId == 15) { tailInnerIndex = tailFlagInnerIndex; } int targetTailIndex = tailOffsetIn128Bytes * MoeCommFieldInfo::UINT64_PER_128B_BLOCK + tailInnerIndex; if (idxIn128Bytes < countIn128Bytes && halfLaneId < 15) { int flagIndex = idxIn128Bytes * MoeCommFieldInfo::UINT64_PER_128B_BLOCK + idxFromFifoEntry % MoeCommFieldInfo::UINT64_PER_128B_BLOCK; tailValue = aligned128BytesShm[flagIndex]; aligned128BytesShm[flagIndex] = step; } aligned128BytesShm[targetTailIndex] = tailValue; tailOffsetIn128Bytes += 2; } __syncwarp(); } static __device__ __forceinline__ void protoUnpack(uint8_t* sharedMemoryBase, uint64_t step, int countIn128Bytes, int fifoEntry128ByteIndexBase, int loaded128ByteCount, int warpId, int laneId) { uint64_t* aligned128BytesShm = reinterpret_cast(sharedMemoryBase); int halfLaneId = laneId % 16; int halfIndex = laneId / 16; int tailOffsetIn128Bytes = countIn128Bytes + halfIndex; for (int idxIn128BytesBase = halfIndex * 15; idxIn128BytesBase < countIn128Bytes; idxIn128BytesBase += 30) { int tailFlagIndexFromFifoEntry = fifoEntry128ByteIndexBase + tailOffsetIn128Bytes; int tailFlagInnerIndex = tailFlagIndexFromFifoEntry % MoeCommFieldInfo::UINT64_PER_128B_BLOCK; int idxIn128Bytes = idxIn128BytesBase + halfLaneId; int idxFromFifoEntry = fifoEntry128ByteIndexBase + idxIn128Bytes; uint64_t tailValue = 0; int tailInnerIndex = (halfLaneId >= tailFlagInnerIndex) ? 
halfLaneId + 1 : halfLaneId; int targetTailIndex = tailOffsetIn128Bytes * MoeCommFieldInfo::UINT64_PER_128B_BLOCK + tailInnerIndex; if (halfLaneId < 15) { tailValue = aligned128BytesShm[targetTailIndex]; } if (idxIn128Bytes < countIn128Bytes && halfLaneId < 15) { int flagIndex = idxIn128Bytes * MoeCommFieldInfo::UINT64_PER_128B_BLOCK + idxFromFifoEntry % MoeCommFieldInfo::UINT64_PER_128B_BLOCK; aligned128BytesShm[flagIndex] = tailValue; } tailOffsetIn128Bytes += 2; } __syncwarp(); } static __device__ __forceinline__ void rearm( uint32_t* u32FifoPtr, uint64_t step, int countIn128Bytes, int fifoEntry128ByteIndexBase, int warpId, int laneId) { // LL128 don't need rearm } static __device__ __host__ __forceinline__ int computeProtoTransfer128ByteAlignedSize( int compact128ByteSizeBeforeProto) { // each 15 * 128 byte need one tail 128 byte int tail128ByteSize = (compact128ByteSizeBeforeProto + 15 * 128 - 1) / (15 * 128) * 128; return compact128ByteSizeBeforeProto + tail128ByteSize; } }; using FusedMoeProto = Ll128Proto; // using FusedMoeProto = LamportProto; namespace fused_moe_impl { // returns copy size for txCount __device__ __forceinline__ int startFieldG2S(MoeCommFieldInfo const& fieldInfo, int dataIndex, uint8_t* sharedMemoryBase, int warpId, int laneId, uint64_t* smemBar) { // we can copy more data than needed, just align to 16 bytes. int alignedShmLoadOffset = fieldInfo.getUncompactShmOffset(); uint8_t* sharedMemoryLoadPtr = sharedMemoryBase + alignedShmLoadOffset; int copyByteCount = 0; uint8_t* loadPtr = fieldInfo.get16BAlignedLoadCopyRange(dataIndex, ©ByteCount); if (laneId == 0 && copyByteCount > 0) { cp_async_bulk_g2s(sharedMemoryLoadPtr, loadPtr, copyByteCount, smemBar); } return copyByteCount; } __device__ __forceinline__ void startFieldS2G( MoeCommFieldInfo const& fieldInfo, int dataIndex, uint8_t* sharedMemoryBase, int warpId, int laneId) { int alignedShmStoreOffset = fieldInfo.getUncompactShmOffset(); uint8_t* sharedMemoryStorePtr = sharedMemoryBase + alignedShmStoreOffset; int copyByteCount = 0; int headTailShmIdx; int headTailGlobalIdx; uint8_t* storePtr = fieldInfo.get16BAlignedStoreCopyRange(dataIndex, ©ByteCount, laneId, &headTailShmIdx, &headTailGlobalIdx); if (copyByteCount > 0 && laneId == 0) { cp_async_bulk_s2g(storePtr, sharedMemoryStorePtr + MoeCommFieldInfo::BYTES_PER_16B_BLOCK, copyByteCount); } if (headTailGlobalIdx >= 0) { // copy head and tail fieldInfo.getRawPtr(dataIndex, nullptr)[headTailGlobalIdx] = sharedMemoryStorePtr[headTailShmIdx]; } __syncwarp(); } // SRC_AFTER_DST is true, if src > dst, pack will use this, // SRC_AFTER_DST is false, if src < dst, unpack will use this template __device__ __forceinline__ void memmoveSharedMemory(uint8_t* dst, uint8_t const* src, int copySize, int laneId) { int count = (copySize + sizeof(T) - 1) / sizeof(T); int warpLoopStart = SRC_AFTER_DST ? 0 : (count + WARP_SIZE - 1) / WARP_SIZE - 1; int warpLoopEnd = SRC_AFTER_DST ? (count + WARP_SIZE - 1) / WARP_SIZE : -1; int warpLoopUpdate = SRC_AFTER_DST ? 
1 : -1; for (int i = warpLoopStart; i != warpLoopEnd; i += warpLoopUpdate) { int idx = laneId + i * WARP_SIZE; T data = T{}; if (idx < count) { data = reinterpret_cast(src)[idx]; } __syncwarp(); if (idx < count) { reinterpret_cast(dst)[idx] = data; } __syncwarp(); } } template __device__ __forceinline__ void memmoveFieldOnSharedMemory( MoeCommFieldInfo const& fieldInfo, int dataIndex, uint8_t* sharedMemoryBase, int laneId) { int movOffset = fieldInfo.getMemmoveOffsets(dataIndex); if (movOffset == 0) { // if movOffset is 0, src and dst are the same, don't need memmove. return; } int alignedBytes = 1 << fieldInfo.alignedUnitBit; int copySize = fieldInfo.alignedUnitCount * alignedBytes; uint8_t* sharedMemoryCompact = sharedMemoryBase + fieldInfo.getCompactShmOffset(); uint8_t* sharedMemoryUncompact = sharedMemoryCompact + movOffset; uint8_t* sharedMemoryDst = IS_PACK ? sharedMemoryCompact : sharedMemoryUncompact; uint8_t* sharedMemorySrc = IS_PACK ? sharedMemoryUncompact : sharedMemoryCompact; if (movOffset % 16 == 0) { memmoveSharedMemory(sharedMemoryDst, sharedMemorySrc, copySize, laneId); } else if (movOffset % 8 == 0) { memmoveSharedMemory(sharedMemoryDst, sharedMemorySrc, copySize, laneId); } else if (movOffset % 4 == 0) { memmoveSharedMemory(sharedMemoryDst, sharedMemorySrc, copySize, laneId); } else if (movOffset % 2 == 0) { memmoveSharedMemory(sharedMemoryDst, sharedMemorySrc, copySize, laneId); } else { memmoveSharedMemory(sharedMemoryDst, sharedMemorySrc, copySize, laneId); } } template __device__ __forceinline__ void packAllFields( FusedMoeFieldInfo const& sendFieldInfo, int dataIndex, uint8_t* sharedMemoryBase, int laneId) { #pragma unroll for (int i = 0; i < FIELD_COUNT; i++) { memmoveFieldOnSharedMemory(sendFieldInfo.fieldsInfo[i], dataIndex, sharedMemoryBase, laneId); } __syncwarp(); } template __device__ __forceinline__ void unpackAllFields( FusedMoeFieldInfo const& recvFieldInfo, int dataIndex, uint8_t* sharedMemoryBase, int laneId) { #pragma unroll for (int i = FIELD_COUNT - 1; i >= 0; i--) { memmoveFieldOnSharedMemory(recvFieldInfo.fieldsInfo[i], dataIndex, sharedMemoryBase, laneId); } __syncwarp(); } __device__ __forceinline__ void initSmemBar(uint64_t* smemBar, int laneId) { if (laneId == 0) { mbarrier_init(smemBar, WARP_SIZE); } __syncwarp(); } __device__ __forceinline__ void smemBarWait(uint64_t* smemBar, uint32_t* phaseParity) { while (!mbarrier_try_wait_parity(smemBar, *phaseParity)) { } *phaseParity = 1 - *phaseParity; } __device__ __forceinline__ void startWorkspaceS2G( uint64_t* fifoEntry, uint8_t* sharedMemoryBase, int send128ByteCount, int fifo128ByteOffset, int warpId, int laneId) { int copyByteCount = send128ByteCount * MoeCommFieldInfo::BYTES_PER_128B_BLOCK; if (laneId == 0) { cp_async_bulk_s2g(fifoEntry + fifo128ByteOffset * MoeCommFieldInfo::BYTES_PER_128B_BLOCK / sizeof(int64_t), sharedMemoryBase, copyByteCount); } __syncwarp(); cp_async_bulk_commit_group(); } __device__ __forceinline__ uint64_t startWorkspaceG2S(uint8_t* sharedMemoryBase, uint64_t* fifoEntry, int allLoad128ByteCount, int fifo128ByteOffset, int loaded128ByteCount, uint64_t* smemBar, int warpId, int laneId) { int copyByteCount = (allLoad128ByteCount - loaded128ByteCount) * MoeCommFieldInfo::BYTES_PER_128B_BLOCK; if (laneId == 0) { cp_async_bulk_g2s(sharedMemoryBase + loaded128ByteCount * MoeCommFieldInfo::BYTES_PER_128B_BLOCK, fifoEntry + (fifo128ByteOffset + loaded128ByteCount) * MoeCommFieldInfo::BYTES_PER_128B_BLOCK / sizeof(int64_t), copyByteCount, smemBar); } return 
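// Only lane 0 issued the bulk copy above, so only lane 0 announces the expected transaction
// byte count; every lane still arrives, and the barrier completes once all 32 lanes have
// arrived and the announced bytes have landed in shared memory.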
mbarrier_arrive_expect_tx(smemBar, laneId == 0 ? copyByteCount : 0); } __device__ __forceinline__ void g2sBasicFields(FusedMoeFieldInfo const& sendFieldInfo, MoeExpertParallelInfo const& expertParallelInfo, int dataIndex, uint8_t* sharedMemoryBase, int laneId) { int topK = expertParallelInfo.topK; int* tokenSelectedSlotsPtr = sendFieldInfo.getTokenSelectedSlotsPtr(dataIndex, laneId, topK); float* scalePtr = sendFieldInfo.getScalePtr(dataIndex, laneId, topK); ldgsts<4>(reinterpret_cast(sharedMemoryBase) + laneId, tokenSelectedSlotsPtr, laneId < topK); ldgsts<4>(reinterpret_cast(sharedMemoryBase) + laneId + topK, reinterpret_cast(scalePtr), laneId < topK && sendFieldInfo.expertScales != nullptr); } // May commit 1 group for basic fields(tokenSelectedSlots and scales) if HAS_BASIC_FIELDS is true // For other fields, use smemBar. template __device__ __forceinline__ uint64_t g2sAllFields(FusedMoeFieldInfo const& sendFieldInfo, MoeExpertParallelInfo const& expertParallelInfo, int dataIndex, uint8_t* sharedMemoryBase, int warpId, int laneId, uint64_t* smemBar) { if (HAS_BASIC_FIELDS) { g2sBasicFields(sendFieldInfo, expertParallelInfo, dataIndex, sharedMemoryBase, laneId); cp_async_commit_group(); } int asyncLoadSize = 0; #pragma unroll for (int i = 0; i < FIELD_COUNT; i++) { asyncLoadSize += startFieldG2S(sendFieldInfo.fieldsInfo[i], dataIndex, sharedMemoryBase, warpId, laneId, smemBar); } return mbarrier_arrive_expect_tx(smemBar, laneId == 0 ? asyncLoadSize : 0); } template __device__ __forceinline__ void waitG2SBasicFields() { if (HAS_BASIC_FIELDS) { cp_async_wait_group<0>(); __syncwarp(); } } __device__ __forceinline__ void waitG2SOtherFields(uint64_t* memBar, uint32_t* phaseParity) { tensorrt_llm::kernels::fused_moe_impl::smemBarWait(memBar, phaseParity); } template __device__ __forceinline__ void waitG2SAllFields(uint64_t* memBar, uint32_t* phaseParity) { waitG2SBasicFields(); waitG2SOtherFields(memBar, phaseParity); } __device__ __forceinline__ void waitS2GBulkRead() { cp_async_bulk_wait_group_read<0>(); __syncwarp(); } __device__ __forceinline__ void s2gBasicFields(FusedMoeFieldInfo const& recvFieldInfo, MoeExpertParallelInfo const& expertParallelInfo, int dataIndex, uint8_t* sharedMemoryBase, int warpId, int laneId) { int topK = expertParallelInfo.topK; int* tokenSelectedSlotsPtr = recvFieldInfo.getTokenSelectedSlotsPtr(dataIndex, laneId, topK); float* scalePtr = recvFieldInfo.getScalePtr(dataIndex, laneId, topK); if (laneId < topK) { int selectedSlot = reinterpret_cast(sharedMemoryBase)[laneId]; *tokenSelectedSlotsPtr = selectedSlot; if (recvFieldInfo.expertScales != nullptr) { float scale = reinterpret_cast(sharedMemoryBase)[laneId + topK]; *scalePtr = scale; } } } // Will commit 1 group, for all non-basic fields template __device__ __forceinline__ void s2gAllFields(FusedMoeFieldInfo const& recvFieldInfo, MoeExpertParallelInfo const& expertParallelInfo, int dataIndex, uint8_t* sharedMemoryBase, int warpId, int laneId) { if (HAS_BASIC_FIELDS) { s2gBasicFields(recvFieldInfo, expertParallelInfo, dataIndex, sharedMemoryBase, warpId, laneId); __syncwarp(); } #pragma unroll for (int i = 0; i < FIELD_COUNT; i++) { startFieldS2G(recvFieldInfo.fieldsInfo[i], dataIndex, sharedMemoryBase, warpId, laneId); } cp_async_bulk_commit_group(); } template class SingleChannelCommunicator { public: __device__ __forceinline__ SingleChannelCommunicator(FusedMoeFieldInfo const& fieldInfo, MoeExpertParallelInfo const& expertParallelInfo, MoeSingleCommMeta const& commMeta, FusedMoeWorkspace const& workspace, 
FusedMoeWorldInfo const& worldInfo, FusedMoePairInfo const& pairInfo, uint64_t* smemBar, uint8_t* shmemBase) : mFieldInfo(fieldInfo) , mExpertParallelInfo(expertParallelInfo) , mCommMeta(commMeta) , mWorkspace(workspace) , mWorldInfo(worldInfo) , mPairInfo(pairInfo) , mSmemBar(smemBar) , mShmemBase(shmemBase) { mWarpId = threadIdx.x / WARP_SIZE; mLaneId = threadIdx.x % WARP_SIZE; mFifoBasePtr = mWorkspace.getFifoBasePtr(mWorldInfo, mPairInfo); mSenderSideFifoInfo = mWorkspace.getSenderSideFifoInfo(mWorldInfo, mPairInfo); mReceiverSideFifoInfo = mWorkspace.getReceiverSideFifoInfo(mWorldInfo, mPairInfo); mSingleTransfer128ByteCount = mCommMeta.getTransfer128ByteCount(); mSingleCompactData128ByteCount = mCommMeta.getCompactData128ByteCount(); // initialize as need new Entry first mFifoEntry128ByteIndexBase = kFifoEntry128ByteCount; mFifoEntryIndex = -1; tensorrt_llm::kernels::fused_moe_impl::initSmemBar(mSmemBar, mLaneId); } __device__ __forceinline__ uint64_t* getFifoEntryPtr() const { return mFifoBasePtr + mFifoEntryIndex * kFifoEntrySizeInU64; } __device__ __forceinline__ bool needNewEntry() const { return mFifoEntry128ByteIndexBase + mSingleTransfer128ByteCount > kFifoEntry128ByteCount; } __device__ __forceinline__ void nextToken() { mFifoEntry128ByteIndexBase += mSingleTransfer128ByteCount; } __device__ __forceinline__ void senderInitFifo() { mHead = mSenderSideFifoInfo->head; mTail = mSenderSideFifoInfo->tail; } __device__ __forceinline__ void receiverInitFifo() { mHead = mReceiverSideFifoInfo->head; mTail = mReceiverSideFifoInfo->tail; } /* * Head | 0 | 1 | 2 | 3 | 4 | 4 | 4 | 4 | 4 | 5 | * Tail | 0 | 0 | 0 | 0 | 0 | 1 | 2 | 3 | 4 | 4 | * Writable | Y | Y | Y | Y | N | Y | Y | Y | Y | Y | * Readable | N | Y | Y | Y | Y | Y | Y | Y | N | Y | */ __device__ __forceinline__ void waitEntryWritable() { while (mTail + kFifoDepth <= mHead) { mTail = mSenderSideFifoInfo->tail; } } __device__ __forceinline__ void updateWriteEntry() { __syncwarp(); mSenderSideFifoInfo->head = mHead; } __device__ __forceinline__ void waitEntryReadable() { // always readable as long as flag matches. } __device__ __forceinline__ void updateReadEntry() { mReceiverSideFifoInfo->tail = mTail; mSenderSideFifoInfo->tail = mTail; } __device__ __forceinline__ void newSendEntry() { mFifoEntryIndex = mHead % kFifoDepth; mFifoEntry128ByteIndexBase = 0; waitEntryWritable(); __syncwarp(); } __device__ __forceinline__ void newReceiveEntry() { mFifoEntryIndex = mTail % kFifoDepth; mFifoEntry128ByteIndexBase = 0; waitEntryReadable(); __syncwarp(); } __device__ __forceinline__ void doSend(int tokenCount, int* sendIndexMapping) { senderInitFifo(); int sendIndex = mPairInfo.channel; uint32_t phaseParity = 0; for (; sendIndex < tokenCount; sendIndex += mPairInfo.runChannelCount) { int tokenIndex = sendIndexMapping == nullptr ? sendIndex : sendIndexMapping[sendIndex]; tensorrt_llm::kernels::fused_moe_impl::g2sAllFields( mFieldInfo, mExpertParallelInfo, tokenIndex, mShmemBase, mWarpId, mLaneId, mSmemBar); if (needNewEntry()) { if (mFifoEntryIndex >= 0) { // not first entry, update FIFO info from last entry. 
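// Advancing mHead and publishing it through updateWriteEntry() makes the previous FIFO entry
// visible to the receiver before a new entry is claimed.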
mHead++; updateWriteEntry(); } newSendEntry(); } tensorrt_llm::kernels::fused_moe_impl::waitG2SAllFields(mSmemBar, &phaseParity); tensorrt_llm::kernels::fused_moe_impl::packAllFields( mFieldInfo, tokenIndex, mShmemBase, mLaneId); FusedMoeProto::protoPack( mShmemBase, mHead, mSingleCompactData128ByteCount, mFifoEntry128ByteIndexBase, mWarpId, mLaneId); tensorrt_llm::kernels::fused_moe_impl::startWorkspaceS2G(getFifoEntryPtr(), mShmemBase, mSingleTransfer128ByteCount, mFifoEntry128ByteIndexBase, mWarpId, mLaneId); tensorrt_llm::kernels::fused_moe_impl::waitS2GBulkRead(); nextToken(); } if (mFifoEntry128ByteIndexBase > 0) { mHead++; updateWriteEntry(); } } __device__ __forceinline__ void rearmFifoBuffer() { constexpr int kUint32CountPer128Byte = 128 / sizeof(uint32_t); uint32_t* fifoPtr = reinterpret_cast(getFifoEntryPtr()); fifoPtr += mFifoEntry128ByteIndexBase * kUint32CountPer128Byte; FusedMoeProto::rearm(fifoPtr, mTail, mSingleTransfer128ByteCount, mFifoEntry128ByteIndexBase, mWarpId, mLaneId); __syncwarp(); } __device__ __forceinline__ void doReceive(int tokenCount, int* recvIndexMapping) { receiverInitFifo(); int recvIndex = mPairInfo.channel; uint32_t phaseParity = 0; bool needRelease = false; for (; recvIndex < tokenCount; recvIndex += mPairInfo.runChannelCount) { int tokenIndex = recvIndexMapping == nullptr ? recvIndex : recvIndexMapping[recvIndex]; int loaded128ByteCount = 0; if (needNewEntry()) { if (mFifoEntryIndex >= 0) { // not first entry, update FIFO info from last entry. mTail++; needRelease = true; } newReceiveEntry(); } while (loaded128ByteCount < mSingleTransfer128ByteCount) { tensorrt_llm::kernels::fused_moe_impl::startWorkspaceG2S(mShmemBase, getFifoEntryPtr(), mSingleTransfer128ByteCount, mFifoEntry128ByteIndexBase, loaded128ByteCount, mSmemBar, mWarpId, mLaneId); if (needRelease) { updateReadEntry(); needRelease = false; } tensorrt_llm::kernels::fused_moe_impl::smemBarWait(mSmemBar, &phaseParity); loaded128ByteCount += FusedMoeProto::template checkDataReceivedInShm(mShmemBase, mTail, mSingleTransfer128ByteCount, mFifoEntry128ByteIndexBase, loaded128ByteCount, mWarpId, mLaneId); } FusedMoeProto::protoUnpack(mShmemBase, mTail, mSingleCompactData128ByteCount, mFifoEntry128ByteIndexBase, loaded128ByteCount, mWarpId, mLaneId); tensorrt_llm::kernels::fused_moe_impl::unpackAllFields( mFieldInfo, tokenIndex, mShmemBase, mLaneId); tensorrt_llm::kernels::fused_moe_impl::s2gAllFields( mFieldInfo, mExpertParallelInfo, tokenIndex, mShmemBase, mWarpId, mLaneId); tensorrt_llm::kernels::fused_moe_impl::waitS2GBulkRead(); rearmFifoBuffer(); nextToken(); } if (mFifoEntry128ByteIndexBase > 0) { mTail++; updateReadEntry(); } } private: static constexpr int kFifoEntrySizeInU64 = FusedMoeCommunicator::FIFO_ENTRY_BYTES / sizeof(uint64_t); static constexpr int kFifoEntry128ByteCount = FusedMoeCommunicator::FIFO_ENTRY_128_BYTE_COUNT; static constexpr int kFifoDepth = FusedMoeCommunicator::FIFO_DEPTH; FusedMoeFieldInfo mFieldInfo; MoeExpertParallelInfo mExpertParallelInfo; MoeSingleCommMeta mCommMeta; FusedMoeWorkspace mWorkspace; FusedMoeWorldInfo mWorldInfo; FusedMoePairInfo mPairInfo; uint64_t* mSmemBar; uint8_t* mShmemBase; int mLaneId; int mWarpId; uint64_t* mFifoBasePtr; SenderSideFifoInfo* mSenderSideFifoInfo; ReceiverSideFifoInfo* mReceiverSideFifoInfo; int64_t mHead; int64_t mTail; int mSingleTransfer128ByteCount; int mSingleCompactData128ByteCount; int mFifoEntry128ByteIndexBase; int mFifoEntryIndex; }; template __global__ void moeAllToAllKernel(FusedMoeCommKernelParam params, 
FusedMoeWorkspace workspace, bool hasBasicFields) { __shared__ uint64_t allWarpSmemBar[32]; extern __shared__ int4 allWarpShm[]; bool isSender = blockIdx.z == 0; int runChannelCount = gridDim.y; int group = threadIdx.y; SendRecvIndices dataIndices = isSender ? params.sendIndices : params.recvIndices; FusedMoePairInfo pairInfo; int peerRank = blockIdx.x * blockDim.y + group; if (peerRank >= params.worldInfo.epInfo.epSize) { return; } int tokenCount; int* groupStartPtr = dataIndices.getGroupStart(peerRank, tokenCount); if (tokenCount == 0) { return; } pairInfo.channel = blockIdx.y; pairInfo.runChannelCount = runChannelCount; pairInfo.senderRank = isSender ? params.worldInfo.epInfo.epRank : peerRank; pairInfo.receiverRank = isSender ? peerRank : params.worldInfo.epInfo.epRank; if (isSender) { int singleShmSize = params.sendCommMeta.getSingleShmSize(); if (hasBasicFields) { SingleChannelCommunicator comm(params.sendFieldInfo, params.expertParallelInfo, params.sendCommMeta, workspace, params.worldInfo, pairInfo, allWarpSmemBar + group, reinterpret_cast(allWarpShm) + singleShmSize * group); comm.doSend(tokenCount, groupStartPtr); } else { SingleChannelCommunicator comm(params.sendFieldInfo, params.expertParallelInfo, params.sendCommMeta, workspace, params.worldInfo, pairInfo, allWarpSmemBar + group, reinterpret_cast(allWarpShm) + singleShmSize * group); comm.doSend(tokenCount, groupStartPtr); } } else { int singleShmSize = params.recvCommMeta.getSingleShmSize(); if (hasBasicFields) { SingleChannelCommunicator comm(params.recvFieldInfo, params.expertParallelInfo, params.recvCommMeta, workspace, params.worldInfo, pairInfo, allWarpSmemBar + group, reinterpret_cast(allWarpShm) + singleShmSize * group); comm.doReceive(tokenCount, groupStartPtr); } else { SingleChannelCommunicator comm(params.recvFieldInfo, params.expertParallelInfo, params.recvCommMeta, workspace, params.worldInfo, pairInfo, allWarpSmemBar + group, reinterpret_cast(allWarpShm) + singleShmSize * group); comm.doReceive(tokenCount, groupStartPtr); } } } int computeMoeAlltoallMaxDynamicSharedMemorySize() { int devId = -1; TLLM_CUDA_CHECK(cudaGetDevice(&devId)); cudaFuncAttributes attr{}; TLLM_CUDA_CHECK(cudaFuncGetAttributes(&attr, (void const*) moeAllToAllKernel<1>)); int staticSmem = static_cast(attr.sharedSizeBytes); int maxPerBlockShmOptin = 0; TLLM_CUDA_CHECK(cudaDeviceGetAttribute(&maxPerBlockShmOptin, cudaDevAttrMaxSharedMemoryPerBlockOptin, devId)); return maxPerBlockShmOptin - staticSmem; } } // namespace fused_moe_impl void FusedMoeFieldInfo::fillMetaInfo( MoeSingleCommMeta* singleCommMeta, int topK, bool hasScales, bool hasBasicFields) const { singleCommMeta->singleCompactAlignedSize = computeSingleCompactSize(topK, hasScales, hasBasicFields); singleCommMeta->singleUncompactAlignedSize = computeSingleUncompactSize(topK, hasScales, hasBasicFields); singleCommMeta->singleTransferAlignedSize = FusedMoeProto::computeProtoTransfer128ByteAlignedSize(singleCommMeta->singleCompactAlignedSize); } void FusedMoeFieldInfo::fillFieldPlacementInfo(int topK, bool hasBasicFields) { int basicFieldSize = 0; if (hasBasicFields) { basicFieldSize = topK * sizeof(int) + (expertScales != nullptr ? 
topK * sizeof(float) : 0); // align to 16 bytes basicFieldSize = (basicFieldSize + MoeCommFieldInfo::BYTES_PER_16B_BLOCK - 1) / MoeCommFieldInfo::BYTES_PER_16B_BLOCK * MoeCommFieldInfo::BYTES_PER_16B_BLOCK; } int offset = basicFieldSize; int unalignedFieldIndex = 0; for (int i = 0; i < fieldCount; i++) { fieldsInfo[i].compact16BOffset = offset / MoeCommFieldInfo::BYTES_PER_16B_BLOCK; offset += fieldsInfo[i].getFieldCompactSize(); fieldsInfo[i].unalignedFieldIndex = unalignedFieldIndex; if (fieldsInfo[i].alignedUnitBit < 4) { unalignedFieldIndex++; } } for (int i = fieldCount; i < MOE_COMM_FIELD_MAX_COUNT; i++) { fieldsInfo[i].setUnused(); } } void FusedMoeWorkspace::initializeLocalWorkspace(FusedMoeWorldInfo const& worldInfo) { int epSize = worldInfo.epInfo.epSize; int epRank = worldInfo.epInfo.epRank; size_t fifoSize = static_cast(FusedMoeCommunicator::FIFO_TOTAL_BYTES) * epSize * channelCount; size_t senderSideInfoSize = sizeof(SenderSideFifoInfo) * epSize * channelCount; size_t receiverSideInfoSize = sizeof(ReceiverSideFifoInfo) * epSize * channelCount; uint64_t* localWorkspacePtr = workspacePtr + epRank * rankStrideInU64; TLLM_CU_CHECK(cuMemsetD32(reinterpret_cast(localWorkspacePtr), FusedMoeProto::INITIALIZED_VALUE, fifoSize / sizeof(uint32_t))); TLLM_CUDA_CHECK(cudaMemset( reinterpret_cast(localWorkspacePtr) + fifoSize, 0, senderSideInfoSize + receiverSideInfoSize)); } void moeAllToAll(FusedMoeCommKernelParam params, FusedMoeWorkspace workspace, cudaStream_t stream) { bool hasBasicFields = params.sendFieldInfo.tokenSelectedSlots != nullptr; int warpSendShmSize = params.sendCommMeta.getSingleShmSize(); int warpRecvShmSize = params.recvCommMeta.getSingleShmSize(); int warpShmSize = warpSendShmSize; int epSize = params.worldInfo.epInfo.epSize; TLLM_CHECK_WITH_INFO(warpSendShmSize == warpRecvShmSize, "warpSendShmSize(%d) not same as warpRecvShmSize(%d)", warpSendShmSize, warpRecvShmSize); int maxGroupCountPerCta = std::min(params.worldInfo.epInfo.epSize, FusedMoeCommunicator::MAX_GROUP_COUNT_PER_BLOCK); static int maxDynamicShmSize = fused_moe_impl::computeMoeAlltoallMaxDynamicSharedMemorySize(); int groupCountPerCta = std::min(maxGroupCountPerCta, maxDynamicShmSize / warpShmSize); int maxFieldCount = std::max(params.sendFieldInfo.fieldCount, params.recvFieldInfo.fieldCount); auto getFunc = [](int fieldCount) { switch (fieldCount) { case 1: return fused_moe_impl::moeAllToAllKernel<1>; case 2: return fused_moe_impl::moeAllToAllKernel<2>; case 3: return fused_moe_impl::moeAllToAllKernel<3>; case 4: return fused_moe_impl::moeAllToAllKernel<4>; case 5: return fused_moe_impl::moeAllToAllKernel<5>; case 6: return fused_moe_impl::moeAllToAllKernel<6>; case 7: return fused_moe_impl::moeAllToAllKernel<7>; case 8: return fused_moe_impl::moeAllToAllKernel<8>; default: return fused_moe_impl::moeAllToAllKernel<8>; } return fused_moe_impl::moeAllToAllKernel<8>; }; auto* kernelFn = getFunc(maxFieldCount); if (groupCountPerCta * warpShmSize > 48 * 1024) { TLLM_CUDA_CHECK(cudaFuncSetAttribute( kernelFn, cudaFuncAttributeMaxDynamicSharedMemorySize, groupCountPerCta * warpShmSize)); } for (; groupCountPerCta > 0; groupCountPerCta--) { int dynamicShmSize = groupCountPerCta * warpShmSize; int numBlocks = 0; if (cudaOccupancyMaxActiveBlocksPerMultiprocessor( &numBlocks, kernelFn, WARP_SIZE * groupCountPerCta, dynamicShmSize) != cudaSuccess) { continue; } if (numBlocks >= 1) { break; } } TLLM_CHECK_WITH_INFO( groupCountPerCta >= 1, "computed groupCount=%d, warpShmSize=%d", groupCountPerCta, warpShmSize); int 
ctaPerChannel = (epSize + groupCountPerCta - 1) / groupCountPerCta; groupCountPerCta = (epSize + ctaPerChannel - 1) / ctaPerChannel; int totalDynamicShmSize = warpShmSize * groupCountPerCta; dim3 block = FusedMoeCommunicator::getLaunchBlockDim(groupCountPerCta); dim3 grid = FusedMoeCommunicator::getLaunchGridDim(params.worldInfo.epInfo.epSize, groupCountPerCta); kernelFn<<>>(params, workspace, hasBasicFields); TLLM_CUDA_CHECK(cudaGetLastError()); } int FusedMoeCommunicator::maxSmCount = -1; bool FusedMoeCommunicator::maxSmCountUsed = false; void setMaxUsableSmCount(int smCount) { FusedMoeCommunicator::setMaxUsableSmCount(smCount); } size_t getFusedMoeCommWorkspaceSize(int epSize) { int channelCount = FusedMoeCommunicator::getMoeCommChannelCount(epSize); size_t workspaceSize = FusedMoeWorkspace::computeWorkspaceSizePreRank(epSize, channelCount); return workspaceSize; } void constructWorkspace(FusedMoeWorkspace* workspace, uint64_t* workspacePtr, size_t rankStrideInU64, int epSize) { workspace->workspacePtr = workspacePtr; workspace->rankStrideInU64 = rankStrideInU64; workspace->channelCount = FusedMoeCommunicator::getMoeCommChannelCount(epSize); } void initializeFusedMoeLocalWorkspace(FusedMoeWorkspace* workspace, FusedMoeWorldInfo const& worldInfo) { workspace->initializeLocalWorkspace(worldInfo); } namespace fused_moe_comm_tests { __global__ void g2sKernel(FusedMoeFieldInfo allFieldInfo, MoeExpertParallelInfo expertParallelInfo, MoeSingleCommMeta singleCommMeta, int tokenCount, int* shmDump, bool hasBasicFields) { __shared__ uint64_t allWarpSmemBar[32]; extern __shared__ int4 allWarpShm[]; int laneId = threadIdx.x % WARP_SIZE; int warpId = threadIdx.x / WARP_SIZE; int warpCount = blockDim.x / WARP_SIZE; int tokenIndex = warpId + blockIdx.x * warpCount; if (tokenIndex >= tokenCount) { return; } int singleShmSize = singleCommMeta.singleUncompactAlignedSize; tensorrt_llm::kernels::fused_moe_impl::initSmemBar(&allWarpSmemBar[warpId], laneId); uint32_t phaseParity = 0; uint8_t* sharedMemoryBase = reinterpret_cast(allWarpShm) + singleShmSize * warpId; if (hasBasicFields) { tensorrt_llm::kernels::fused_moe_impl::g2sAllFields( allFieldInfo, expertParallelInfo, tokenIndex, sharedMemoryBase, warpId, laneId, &allWarpSmemBar[warpId]); tensorrt_llm::kernels::fused_moe_impl::waitG2SAllFields(&allWarpSmemBar[warpId], &phaseParity); } else { tensorrt_llm::kernels::fused_moe_impl::g2sAllFields( allFieldInfo, expertParallelInfo, tokenIndex, sharedMemoryBase, warpId, laneId, &allWarpSmemBar[warpId]); tensorrt_llm::kernels::fused_moe_impl::waitG2SAllFields(&allWarpSmemBar[warpId], &phaseParity); } for (int offset = laneId; offset < singleShmSize / sizeof(int); offset += WARP_SIZE) { shmDump[tokenIndex * singleShmSize / sizeof(int) + offset] = reinterpret_cast(sharedMemoryBase)[offset]; } } void launchSingleG2S(FusedMoeFieldInfo const& sendFieldInfo, MoeExpertParallelInfo const& expertParallelInfo, int tokenCount, int* shmDump, int warpsPerBlock, bool hasBasicFields, cudaStream_t stream) { int warpShmSize = sendFieldInfo.computeSingleUncompactSize( expertParallelInfo.topK, sendFieldInfo.expertScales != nullptr, hasBasicFields); dim3 blockDim(WARP_SIZE * warpsPerBlock, 1, 1); dim3 gridDim((tokenCount + warpsPerBlock - 1) / warpsPerBlock, 1, 1); MoeSingleCommMeta singleCommMeta; sendFieldInfo.fillMetaInfo( &singleCommMeta, expertParallelInfo.topK, sendFieldInfo.expertScales != nullptr, hasBasicFields); TLLM_CUDA_CHECK( cudaFuncSetAttribute(g2sKernel, cudaFuncAttributeMaxDynamicSharedMemorySize, warpShmSize * 
warpsPerBlock)); g2sKernel<<>>( sendFieldInfo, expertParallelInfo, singleCommMeta, tokenCount, shmDump, hasBasicFields); TLLM_CUDA_CHECK(cudaGetLastError()); } __global__ void s2gKernel(FusedMoeFieldInfo recvFieldInfo, MoeExpertParallelInfo expertParallelInfo, MoeSingleCommMeta singleCommMeta, int tokenCount, int* shmPreload, bool hasBasicFields) { extern __shared__ int4 allWarpShm[]; int laneId = threadIdx.x % WARP_SIZE; int warpId = threadIdx.x / WARP_SIZE; int warpCount = blockDim.x / WARP_SIZE; int tokenIndex = warpId + blockIdx.x * warpCount; if (tokenIndex >= tokenCount) { return; } int singleShmSize = singleCommMeta.singleUncompactAlignedSize; uint8_t* sharedMemoryBase = reinterpret_cast(allWarpShm) + singleShmSize * warpId; for (int offset = laneId; offset < singleShmSize / sizeof(int); offset += WARP_SIZE) { reinterpret_cast(sharedMemoryBase)[offset] = shmPreload[tokenIndex * singleShmSize / sizeof(int) + offset]; } __syncwarp(); if (hasBasicFields) { tensorrt_llm::kernels::fused_moe_impl::s2gAllFields( recvFieldInfo, expertParallelInfo, tokenIndex, sharedMemoryBase, warpId, laneId); } else { tensorrt_llm::kernels::fused_moe_impl::s2gAllFields( recvFieldInfo, expertParallelInfo, tokenIndex, sharedMemoryBase, warpId, laneId); } tensorrt_llm::kernels::fused_moe_impl::waitS2GBulkRead(); } void launchSingleS2G(FusedMoeFieldInfo const& recvFieldInfo, MoeExpertParallelInfo const& expertParallelInfo, int tokenCount, int* shmPreload, int warpsPerBlock, bool hasBasicFields, cudaStream_t stream) { int warpShmSize = recvFieldInfo.computeSingleUncompactSize( expertParallelInfo.topK, recvFieldInfo.expertScales != nullptr, hasBasicFields); dim3 blockDim(WARP_SIZE * warpsPerBlock, 1, 1); dim3 gridDim((tokenCount + warpsPerBlock - 1) / warpsPerBlock, 1, 1); MoeSingleCommMeta singleCommMeta; recvFieldInfo.fillMetaInfo( &singleCommMeta, expertParallelInfo.topK, recvFieldInfo.expertScales != nullptr, hasBasicFields); TLLM_CUDA_CHECK( cudaFuncSetAttribute(s2gKernel, cudaFuncAttributeMaxDynamicSharedMemorySize, warpShmSize * warpsPerBlock)); s2gKernel<<>>( recvFieldInfo, expertParallelInfo, singleCommMeta, tokenCount, shmPreload, hasBasicFields); TLLM_CUDA_CHECK(cudaGetLastError()); } __global__ void loopbackKernel(FusedMoeFieldInfo sendFieldInfo, FusedMoeFieldInfo recvFieldInfo, MoeExpertParallelInfo expertParallelInfo, MoeSingleCommMeta sendCommMeta, MoeSingleCommMeta recvCommMeta, int* recvIndexMapping, int tokenCount, bool hasBasicFields) { __shared__ uint64_t allWarpSmemBar[32]; extern __shared__ int4 allWarpShm[]; int laneId = threadIdx.x % WARP_SIZE; int warpId = threadIdx.x / WARP_SIZE; int warpCount = blockDim.x / WARP_SIZE; int tokenIndex = warpId + blockIdx.x * warpCount; if (tokenIndex >= tokenCount) { return; } int recvTokenIndex = recvIndexMapping[tokenIndex]; tensorrt_llm::kernels::fused_moe_impl::initSmemBar(&allWarpSmemBar[warpId], laneId); uint32_t phaseParity = 0; int singleShmSize = sendCommMeta.getSingleShmSize(); uint8_t* sharedMemoryBase = reinterpret_cast(allWarpShm) + singleShmSize * warpId; if (hasBasicFields) { tensorrt_llm::kernels::fused_moe_impl::g2sAllFields( sendFieldInfo, expertParallelInfo, tokenIndex, sharedMemoryBase, warpId, laneId, &allWarpSmemBar[warpId]); } else { tensorrt_llm::kernels::fused_moe_impl::g2sAllFields( sendFieldInfo, expertParallelInfo, tokenIndex, sharedMemoryBase, warpId, laneId, &allWarpSmemBar[warpId]); } if (hasBasicFields) { tensorrt_llm::kernels::fused_moe_impl::waitG2SAllFields(&allWarpSmemBar[warpId], &phaseParity); } else { 
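// No basic fields were staged in this branch, so only the mbarrier covering the bulk field
// copies needs to complete before packing; the cp.async wait for basic fields is skipped.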
tensorrt_llm::kernels::fused_moe_impl::waitG2SAllFields(&allWarpSmemBar[warpId], &phaseParity); } tensorrt_llm::kernels::fused_moe_impl::packAllFields(sendFieldInfo, tokenIndex, sharedMemoryBase, laneId); tokenIndex = recvTokenIndex; // switch to recvTokenIndex; tensorrt_llm::kernels::fused_moe_impl::unpackAllFields(recvFieldInfo, tokenIndex, sharedMemoryBase, laneId); if (hasBasicFields) { tensorrt_llm::kernels::fused_moe_impl::s2gAllFields( recvFieldInfo, expertParallelInfo, tokenIndex, sharedMemoryBase, warpId, laneId); } else { tensorrt_llm::kernels::fused_moe_impl::s2gAllFields( recvFieldInfo, expertParallelInfo, tokenIndex, sharedMemoryBase, warpId, laneId); } cp_async_bulk_wait_group_read<0>(); __syncwarp(); } // G2S -> Pack -> Unpack -> S2G void launchLoopback(FusedMoeFieldInfo const& sendFieldInfo, FusedMoeFieldInfo const& recvFieldInfo, MoeExpertParallelInfo const& expertParallelInfo, int* recvIndexMapping, int tokenCount, int warpsPerBlock, bool hasBasicFields, cudaStream_t stream) { MoeSingleCommMeta sendCommMeta, recvCommMeta; sendFieldInfo.fillMetaInfo( &sendCommMeta, expertParallelInfo.topK, sendFieldInfo.expertScales != nullptr, hasBasicFields); recvFieldInfo.fillMetaInfo( &recvCommMeta, expertParallelInfo.topK, recvFieldInfo.expertScales != nullptr, hasBasicFields); int warpSendShmSize = sendCommMeta.getSingleShmSize(); int warpRecvShmSize = recvCommMeta.getSingleShmSize(); int warpShmSize = warpSendShmSize; TLLM_CHECK_WITH_INFO(warpSendShmSize == warpRecvShmSize, "warpSendShmSize(%d) not same as warpRecvShmSize(%d)", warpSendShmSize, warpRecvShmSize); dim3 blockDim(WARP_SIZE * warpsPerBlock, 1, 1); dim3 gridDim((tokenCount + warpsPerBlock - 1) / warpsPerBlock, 1, 1); TLLM_CUDA_CHECK( cudaFuncSetAttribute(loopbackKernel, cudaFuncAttributeMaxDynamicSharedMemorySize, warpShmSize * warpsPerBlock)); loopbackKernel<<>>(sendFieldInfo, recvFieldInfo, expertParallelInfo, sendCommMeta, recvCommMeta, recvIndexMapping, tokenCount, hasBasicFields); TLLM_CUDA_CHECK(cudaGetLastError()); } template __global__ void localFifoSendRecvKernel(FusedMoeFieldInfo sendFieldInfo, FusedMoeFieldInfo recvFieldInfo, MoeExpertParallelInfo expertParallelInfo, MoeSingleCommMeta sendCommMeta, MoeSingleCommMeta recvCommMeta, FusedMoeWorkspace fusedMoeWorkspace, int* sendIndexMapping, int* recvIndexMapping, int tokenCount) { __shared__ uint64_t allWarpSmemBar[32]; extern __shared__ int4 allWarpShm[]; FusedMoeWorldInfo worldInfo; worldInfo.epInfo.epRank = 0; worldInfo.epInfo.epSize = 1; int warpId = threadIdx.x / WARP_SIZE; int warpCount = blockDim.x / WARP_SIZE; FusedMoePairInfo pairInfo; pairInfo.senderRank = 0; pairInfo.receiverRank = 0; pairInfo.channel = blockIdx.z * warpCount + warpId; pairInfo.runChannelCount = gridDim.z * warpCount; if (blockIdx.y == 0) { tensorrt_llm::kernels::fused_moe_impl::SingleChannelCommunicator senderComm(sendFieldInfo, expertParallelInfo, sendCommMeta, fusedMoeWorkspace, worldInfo, pairInfo, &allWarpSmemBar[warpId], reinterpret_cast(&allWarpShm[0]) + warpId * sendCommMeta.getSingleShmSize()); senderComm.doSend(tokenCount, sendIndexMapping); } else { tensorrt_llm::kernels::fused_moe_impl::SingleChannelCommunicator recverComm(recvFieldInfo, expertParallelInfo, recvCommMeta, fusedMoeWorkspace, worldInfo, pairInfo, &allWarpSmemBar[warpId], reinterpret_cast(&allWarpShm[0]) + warpId * recvCommMeta.getSingleShmSize()); recverComm.doReceive(tokenCount, recvIndexMapping); } } void launchLocalFifoSendRecv(FusedMoeFieldInfo const& sendFieldInfo, FusedMoeFieldInfo const& recvFieldInfo, 
MoeExpertParallelInfo const& expertParallelInfo, int* sendIndexMapping, int* recvIndexMapping, FusedMoeWorkspace fusedMoeWorkspace, int tokenCount, int warpsPerBlock, int blockChannelCount, bool hasBasicFields, cudaStream_t stream) { MoeSingleCommMeta sendCommMeta, recvCommMeta; sendFieldInfo.fillMetaInfo( &sendCommMeta, expertParallelInfo.topK, sendFieldInfo.expertScales != nullptr, hasBasicFields); recvFieldInfo.fillMetaInfo( &recvCommMeta, expertParallelInfo.topK, recvFieldInfo.expertScales != nullptr, hasBasicFields); int warpSendShmSize = sendCommMeta.getSingleShmSize(); int warpRecvShmSize = recvCommMeta.getSingleShmSize(); int warpShmSize = warpSendShmSize; TLLM_CHECK_WITH_INFO(warpSendShmSize == warpRecvShmSize, "warpSendShmSize(%d) not same as warpRecvShmSize(%d)", warpSendShmSize, warpRecvShmSize); dim3 blockDim(WARP_SIZE * warpsPerBlock, 1, 1); dim3 gridDim(1, 2, blockChannelCount); auto* kernelFn = localFifoSendRecvKernel<>; if (hasBasicFields) { kernelFn = localFifoSendRecvKernel; } else { kernelFn = localFifoSendRecvKernel; } TLLM_CUDA_CHECK( cudaFuncSetAttribute(kernelFn, cudaFuncAttributeMaxDynamicSharedMemorySize, warpShmSize * warpsPerBlock)); kernelFn<<>>(sendFieldInfo, recvFieldInfo, expertParallelInfo, sendCommMeta, recvCommMeta, fusedMoeWorkspace, sendIndexMapping, recvIndexMapping, tokenCount); TLLM_CUDA_CHECK(cudaGetLastError()); } } // namespace fused_moe_comm_tests } // namespace kernels } // namespace tensorrt_llm
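// Illustrative host-side call sequence (a minimal sketch, not part of the library; workspace
// allocation/exchange and population of `params`, `worldInfo`, `workspacePtr`, `rankStrideInU64`
// and `stream` are assumed to happen in the caller):
//
//   using namespace tensorrt_llm::kernels;
//   int epSize = worldInfo.epInfo.epSize;
//   size_t perRankBytes = getFusedMoeCommWorkspaceSize(epSize);   // size of each rank's slice
//   FusedMoeWorkspace workspace;
//   constructWorkspace(&workspace, workspacePtr, rankStrideInU64, epSize);
//   initializeFusedMoeLocalWorkspace(&workspace, worldInfo);      // init FIFO buffers and head/tail info
//   moeAllToAll(params, workspace, stream);                       // launch the fused all-to-all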