TensorRT-LLM/cpp/tensorrt_llm/kernels/fusedMoeCommKernels.h

/*
* Copyright (c) 2019-2025, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include <algorithm>
#include <cstdint>
#include <map>
#include <cuda_runtime_api.h>
#include "tensorrt_llm/common/config.h"
#include "tensorrt_llm/common/cudaUtils.h"
#include "tensorrt_llm/common/envUtils.h"
#include "tensorrt_llm/kernels/moeCommKernelsCommon.h"
TRTLLM_NAMESPACE_BEGIN
namespace kernels
{
struct ALIGN_256 SenderSideFifoInfo
{
volatile uint64_t head; // write position
volatile uint64_t tail; // read position
};
struct ALIGN_256 ReceiverSideFifoInfo
{
volatile uint64_t head; // write position
volatile uint64_t tail; // read position
};
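// An assumption from the head/tail comments above (the protocol is not spelled out in this header):
// the sender advances head after writing an entry, the receiver advances tail after consuming one;
// the FIFO is empty when head == tail and full when head - tail == FIFO_DEPTH (defined below).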
// Struct holding send/recv data pointers and their displacement information.
struct SendRecvIndices
{
int const* rankCountCumSum; // length = epSize
int* rankLocalIndices; // length = rankCountCumSum[epRank] - rankCountCumSum[epRank - 1] if epRank > 0 else
// rankCountCumSum[epRank]
#ifdef __CUDACC__
__inline__ __device__ int getCount(int rank) const
{
return rank == 0 ? rankCountCumSum[rank] : rankCountCumSum[rank] - rankCountCumSum[rank - 1];
}
__inline__ __device__ int getRankStart(int rank) const
{
return rank == 0 ? 0 : rankCountCumSum[rank - 1];
}
__inline__ __device__ int* getGroupStart(int rank, int& tokenCount) const
{
tokenCount = getCount(rank);
int rankStart = getRankStart(rank);
return rankLocalIndices + rankStart;
}
#endif
};
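// Worked example (illustrative): with epSize = 4 and rankCountCumSum = {3, 5, 9, 12},
// rank 0 owns rankLocalIndices[0..2], rank 1 owns [3..4], rank 2 owns [5..8], and rank 3
// owns [9..11]; so getRankStart(2) == 5 and getCount(2) == 4.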
struct MoeCommFieldInfo
{
uint8_t* dataPtrBase;
uint8_t alignedUnitBit; // 0, 1, 2, 3, 4 (for 1, 2, 4, 8, 16 Bytes), smallest aligned unit.
uint16_t alignedUnitCount; // data count in aligned unit
uint16_t alignedUnitStride; // data stride in aligned unit
uint8_t unalignedFieldIndex; // the index of the unaligned field, non-decreasing with field index
uint16_t compact16BOffset; // aligned to 16 bytes; offset counted in 16-byte units
cudaDataType_t originalDataType; // original data type, used for low precision alltoall.
static constexpr uint64_t kAlign16BytePtrMask = (1ULL << 4) - 1;
static constexpr uint32_t kAligned16BMask = (1 << 4) - 1;
// Constants for memory alignment and access (reference common constants for consistency)
static constexpr int BYTES_PER_128B_BLOCK = tensorrt_llm::kernels::BYTES_PER_128B_BLOCK;
static constexpr int INTS_PER_128B_BLOCK = tensorrt_llm::kernels::INTS_PER_128B_BLOCK;
static constexpr int UINT64_PER_128B_BLOCK = tensorrt_llm::kernels::UINT64_PER_128B_BLOCK;
static constexpr int BYTES_PER_16B_BLOCK = tensorrt_llm::kernels::BYTES_PER_16B_BLOCK;
// One extra 16-byte block is padded for each unaligned field, since its head and tail 16-byte blocks might not be
// aligned.
// Fill single-field info; fields that need global info are not filled here.
__host__ void fillFieldInfo(
uint8_t* dataPtr, size_t elementSize, int vectorSize, int stride, cudaDataType_t originalDataType);
__host__ void setUnused()
{
dataPtrBase = nullptr;
alignedUnitBit = 4;
alignedUnitCount = 0;
alignedUnitStride = 0;
unalignedFieldIndex = 0;
compact16BOffset = 0;
}
__device__ __host__ __forceinline__ int getFieldUncompactSize() const
{
int alignedUnitBytes = 1 << alignedUnitBit;
int currentFieldSize = alignedUnitCount * alignedUnitBytes;
if (alignedUnitBytes != 16)
{
currentFieldSize = currentFieldSize / BYTES_PER_16B_BLOCK * BYTES_PER_16B_BLOCK;
currentFieldSize += BYTES_PER_16B_BLOCK * 2;
}
return currentFieldSize;
}
__device__ __host__ __forceinline__ int getFieldCompactSize() const
{
int alignedUnitBytes = 1 << alignedUnitBit;
int currentFieldSize = alignedUnitCount * alignedUnitBytes;
// Align to 16 bytes for compact size
return (currentFieldSize + BYTES_PER_16B_BLOCK - 1) / BYTES_PER_16B_BLOCK * BYTES_PER_16B_BLOCK;
}
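// Worked example (illustrative): a 200-byte field whose smallest aligned unit is 4 bytes
// (alignedUnitBit = 2, alignedUnitCount = 50) has getFieldCompactSize() == 208 (200 rounded
// up to a 16-byte multiple) and getFieldUncompactSize() == 224 (200 rounded down to 192,
// plus one 16-byte block each for the possibly unaligned head and tail).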
__device__ __forceinline__ int getCompactShmOffset() const
{
return compact16BOffset * BYTES_PER_16B_BLOCK;
}
__device__ __forceinline__ int getUncompactShmOffset() const
{
// each unaligned field needs a 16-byte head and a 16-byte tail
return compact16BOffset * BYTES_PER_16B_BLOCK + unalignedFieldIndex * BYTES_PER_16B_BLOCK;
}
__device__ __forceinline__ int getMemmoveOffsets(int index) const
{
int alignedBytes = 1 << alignedUnitBit;
uint8_t* dataPtr = dataPtrBase + index * alignedBytes * alignedUnitStride;
int offset = reinterpret_cast<uint64_t>(dataPtr) & kAlign16BytePtrMask;
return offset + unalignedFieldIndex * BYTES_PER_16B_BLOCK;
}
__device__ __forceinline__ uint8_t* getRawPtr(int index, int* rawSize) const
{
int alignedBytes = 1 << alignedUnitBit;
uint8_t* dataPtr = dataPtrBase + static_cast<size_t>(index) * alignedBytes * alignedUnitStride;
if (rawSize != nullptr)
{
*rawSize = alignedUnitCount * alignedBytes;
}
return dataPtr;
}
__device__ __forceinline__ uint8_t* get16BAlignedLoadCopyRange(int index, int* copyByteCount) const
{
int rawSize;
uint8_t* rawDataPtr = getRawPtr(index, &rawSize);
uint8_t* rawEndPtr = rawDataPtr + rawSize;
uint8_t* alignedDataPtr
= reinterpret_cast<uint8_t*>(reinterpret_cast<uint64_t>(rawDataPtr) & (~kAlign16BytePtrMask));
uint32_t copySize = rawEndPtr - alignedDataPtr;
*copyByteCount
= (copySize & kAligned16BMask) != 0 ? (copySize & (~kAligned16BMask)) + BYTES_PER_16B_BLOCK : copySize;
return alignedDataPtr;
}
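// Worked example (illustrative): with rawDataPtr == 0x1008 and rawSize == 100, rawEndPtr is
// 0x106C and alignedDataPtr is 0x1000; the raw span covers 0x6C (108) bytes from the aligned
// base, which is rounded up to *copyByteCount == 112, so the copy reads [0x1000, 0x1070).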
__device__ __forceinline__ uint8_t* get16BAlignedStoreCopyRange(
int index, int* copyByteCount, int laneId, int* headTailShmIdx, int* headTailGlobalIdx) const
{
int rawSize;
uint8_t* rawDataPtr = getRawPtr(index, &rawSize);
uint8_t* rawEndPtr = rawDataPtr + rawSize;
int offset = reinterpret_cast<uint64_t>(rawDataPtr) & kAlign16BytePtrMask;
uint8_t* alignedDataPtr
= reinterpret_cast<uint8_t*>(reinterpret_cast<uint64_t>(rawDataPtr) + BYTES_PER_16B_BLOCK - offset);
uint8_t* alignedEndPtr
= reinterpret_cast<uint8_t*>(reinterpret_cast<uint64_t>(rawEndPtr) & (~kAlign16BytePtrMask));
int alignedCopyBytes = alignedEndPtr - alignedDataPtr;
if (alignedCopyBytes < 0)
{
alignedCopyBytes = 0;
}
*copyByteCount = alignedCopyBytes;
if (laneId < BYTES_PER_16B_BLOCK)
{
*headTailShmIdx = laneId;
}
else
{
*headTailShmIdx = laneId + alignedCopyBytes;
}
*headTailGlobalIdx = *headTailShmIdx - offset;
if (*headTailGlobalIdx < 0 || *headTailGlobalIdx >= rawSize)
{
*headTailGlobalIdx = -1;
*headTailShmIdx = -1;
}
return alignedDataPtr;
}
};
// Maximum number of fields supported, excluding tokenSelectedSlots and expertScales
static constexpr int MOE_COMM_FIELD_MAX_COUNT = 8;
struct MoeSingleCommMeta
{
int singleTransferAlignedSize; // transfer size aligned to 128 bytes.
int singleCompactAlignedSize; // compact buffer is always aligned to 128 bytes
int singleUncompactAlignedSize; // uncompact shared memory size, aligned to 128 bytes; might be larger than the
// compact buffer if unaligned fields exist.
// TODO: Do we need to reduce shared memory usage so the buffer can be smaller and multiple waves are possible?
__device__ __host__ __forceinline__ int getTransfer128ByteCount() const
{
return singleTransferAlignedSize / MoeCommFieldInfo::BYTES_PER_128B_BLOCK;
}
__device__ __host__ __forceinline__ int getCompactData128ByteCount() const
{
return singleCompactAlignedSize / MoeCommFieldInfo::BYTES_PER_128B_BLOCK;
}
__device__ __host__ __forceinline__ int getSingleShmSize() const
{
return std::max(singleUncompactAlignedSize, singleTransferAlignedSize);
}
};
struct FusedMoeWorldInfo
{
MoeEpWorldInfo epInfo;
};
struct FusedMoePairInfo
{
int senderRank;
int receiverRank;
int channel;
int runChannelCount;
};
class FusedMoeCommunicator
{
public:
static constexpr int FIFO_DEPTH = 4;
static constexpr int FIFO_ENTRY_BYTES = 256 * 1024;
static constexpr int FIFO_ENTRY_128_BYTE_COUNT = FIFO_ENTRY_BYTES / 128;
static constexpr int FIFO_TOTAL_BYTES = FIFO_ENTRY_BYTES * FIFO_DEPTH;
static constexpr int FIFO_TOTAL_U64 = FIFO_TOTAL_BYTES / sizeof(uint64_t);
// Reference common constants for consistency
static constexpr int MAX_GROUP_COUNT_PER_BLOCK = tensorrt_llm::kernels::MAX_GROUP_COUNT_PER_BLOCK;
static constexpr int WARP_SIZE = tensorrt_llm::kernels::WARP_SIZE;
static int maxSmCount;
static bool maxSmCountUsed;
static void setMaxUsableSmCount(int maxUsableSmCount)
{
TLLM_CHECK_WITH_INFO(
FusedMoeCommunicator::maxSmCountUsed == false, "setMaxUsableSmCount can be called only before it is used");
int smCount = tensorrt_llm::common::getMultiProcessorCount();
if (maxUsableSmCount > smCount)
{
TLLM_LOG_WARNING("setMaxUsableSmCount, maxUsableSmCount=%d, larger than smCount=%d, using smCount instead",
maxUsableSmCount, smCount);
maxUsableSmCount = smCount;
}
FusedMoeCommunicator::maxSmCount = maxUsableSmCount;
}
static int getMaxUsableSmCount()
{
FusedMoeCommunicator::maxSmCountUsed = true;
if (FusedMoeCommunicator::maxSmCount == -1)
{
int smCount = tensorrt_llm::common::getMultiProcessorCount();
FusedMoeCommunicator::maxSmCount = smCount;
}
return FusedMoeCommunicator::maxSmCount;
}
static int computeMoeCommChannelCount(int epSize)
{
int smCount = getMaxUsableSmCount();
int blockCountPerChannel = (epSize + MAX_GROUP_COUNT_PER_BLOCK - 1) / MAX_GROUP_COUNT_PER_BLOCK;
blockCountPerChannel *= 2; // for send and recv
TLLM_CHECK_WITH_INFO(
blockCountPerChannel <= smCount, "GPU should support at least one channel, usableSmCount=%d", smCount);
int preferredChannel = smCount / 2 / blockCountPerChannel; // use half of the SMs for communication
int channelCount = std::max(preferredChannel, 1); // at least one channel
return channelCount;
}
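// Worked example (illustrative; MAX_GROUP_COUNT_PER_BLOCK comes from moeCommKernelsCommon.h
// and is assumed to be 8 here): with epSize = 16 and smCount = 132, blockCountPerChannel is
// ceil(16 / 8) * 2 == 4, preferredChannel is 132 / 2 / 4 == 16, so 16 channels are used.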
static int getMoeCommChannelCount(int epSize)
{
static std::map<int, int> channelCountMap{};
auto iter = channelCountMap.find(epSize);
if (iter == channelCountMap.end())
{
auto channelCount = FusedMoeCommunicator::computeMoeCommChannelCount(epSize);
channelCountMap[epSize] = channelCount;
return channelCount;
}
return iter->second;
}
static dim3 getLaunchBlockDim(int groupCountPerCta)
{
return dim3(WARP_SIZE, groupCountPerCta);
}
static dim3 getLaunchGridDim(int epSize, int groupCountPerCta)
{
int maxChannelCount = FusedMoeCommunicator::getMoeCommChannelCount(epSize);
int targetCtaCount = (epSize + MAX_GROUP_COUNT_PER_BLOCK - 1) / MAX_GROUP_COUNT_PER_BLOCK * maxChannelCount * 2;
int ctaPerChannel = (epSize + groupCountPerCta - 1) / groupCountPerCta;
int ctaLimitedChannelCount = targetCtaCount / 2 / ctaPerChannel;
ctaLimitedChannelCount = std::max(1, ctaLimitedChannelCount);
int channelCount = std::min(ctaLimitedChannelCount, maxChannelCount);
return dim3(ctaPerChannel, channelCount, 2);
}
};
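// Compile-time sanity checks (illustrative; these follow from the constants above and add no
// new requirements): each FIFO entry splits into whole 128-byte blocks, and the uint64_t view
// covers the byte view exactly.
static_assert(FusedMoeCommunicator::FIFO_ENTRY_BYTES % 128 == 0, "FIFO entry must split into 128-byte blocks");
static_assert(FusedMoeCommunicator::FIFO_TOTAL_U64 * sizeof(uint64_t) == FusedMoeCommunicator::FIFO_TOTAL_BYTES,
"uint64_t view must cover the FIFO bytes exactly");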
size_t getFusedMoeCommWorkspaceSize(int epSize);
struct FusedMoeFieldInfo
{
int8_t isBasicInterleaved; // whether tokenSelectedSlots and expertScales are interleaved
int32_t* tokenSelectedSlots;
float* expertScales; // can be nullptr if no scales are used (all 1.0); if so, isBasicInterleaved should be 0
int fieldCount;
MoeCommFieldInfo fieldsInfo[MOE_COMM_FIELD_MAX_COUNT];
__host__ int computeSingleCompactSize(int topK, bool hasScales, bool hasBasicFields) const
{
int basicFieldSize = 0;
if (hasBasicFields)
{
basicFieldSize = topK * sizeof(int) + (hasScales ? topK * sizeof(float) : 0);
// align to 16 bytes
basicFieldSize = (basicFieldSize + MoeCommFieldInfo::BYTES_PER_16B_BLOCK - 1)
/ MoeCommFieldInfo::BYTES_PER_16B_BLOCK * MoeCommFieldInfo::BYTES_PER_16B_BLOCK;
}
int otherFieldSize = 0;
for (int i = 0; i < fieldCount; i++)
{
MoeCommFieldInfo const& fieldInfo = fieldsInfo[i];
otherFieldSize += fieldInfo.getFieldCompactSize();
}
int totalSize = basicFieldSize + otherFieldSize;
constexpr int totalSizeAlignment = MoeCommFieldInfo::BYTES_PER_128B_BLOCK;
totalSize = (totalSize + totalSizeAlignment - 1) / totalSizeAlignment * totalSizeAlignment;
return totalSize;
}
__host__ int computeSingleUncompactSize(int topK, bool hasScales, bool hasBasicFields) const
{
int basicFieldSize = 0;
if (hasBasicFields)
{
basicFieldSize = topK * sizeof(int) + (hasScales ? topK * sizeof(float) : 0);
// align to 16 bytes
basicFieldSize = (basicFieldSize + MoeCommFieldInfo::BYTES_PER_16B_BLOCK - 1)
/ MoeCommFieldInfo::BYTES_PER_16B_BLOCK * MoeCommFieldInfo::BYTES_PER_16B_BLOCK;
}
int otherFieldSize = 0;
for (int i = 0; i < fieldCount; i++)
{
MoeCommFieldInfo const& fieldInfo = fieldsInfo[i];
otherFieldSize += fieldInfo.getFieldUncompactSize();
}
int totalSize = basicFieldSize + otherFieldSize;
constexpr int totalSizeAlignment = MoeCommFieldInfo::BYTES_PER_128B_BLOCK;
totalSize = (totalSize + totalSizeAlignment - 1) / totalSizeAlignment * totalSizeAlignment;
return totalSize;
}
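// Worked example (illustrative): topK = 8 with scales gives a basic field size of
// 8 * 4 + 8 * 4 == 64 bytes (already 16-byte aligned); adding one extra field of compact
// size 208 (see the MoeCommFieldInfo example above) gives 272 bytes, which
// computeSingleCompactSize() rounds up to 384, the next 128-byte multiple.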
template <typename T = int, bool IS_SLOTS = true>
__device__ __forceinline__ T* getBasicFieldPtr(int tokenIndex, int selectedIndex, int topK) const
{
T* fieldPtr = nullptr;
fieldPtr = IS_SLOTS ? reinterpret_cast<T*>(tokenSelectedSlots) : reinterpret_cast<T*>(expertScales);
if (fieldPtr == nullptr || selectedIndex >= topK)
{
return nullptr;
}
int tokenStride = isBasicInterleaved ? topK * 2 : topK;
int elementStride = isBasicInterleaved ? 2 : 1;
return fieldPtr + tokenIndex * tokenStride + selectedIndex * elementStride;
}
__device__ __forceinline__ int* getTokenSelectedSlotsPtr(int tokenIndex, int selectedIndex, int topK) const
{
return getBasicFieldPtr<int, true>(tokenIndex, selectedIndex, topK);
}
__device__ __forceinline__ float* getScalePtr(int tokenIndex, int selectedIndex, int topK) const
{
return getBasicFieldPtr<float, false>(tokenIndex, selectedIndex, topK);
}
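// Illustrative layout note: when isBasicInterleaved != 0, slots and scales are stored
// element-interleaved per token (e.g. [slot0, scale0, slot1, scale1, ...], depending on the
// two base pointers), hence tokenStride = topK * 2 and elementStride = 2 above; otherwise the
// two arrays are separate and densely packed (tokenStride = topK, elementStride = 1).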
void fillMetaInfo(
MoeSingleCommMeta* singleCommMeta, int topK, bool hasScales, bool hasBasicFields, bool isLowPrecision) const;
void fillFieldPlacementInfo(int topK, bool hasBasicFields);
};
struct FusedMoeCommKernelParam
{
FusedMoeWorldInfo worldInfo;
MoeExpertParallelInfo expertParallelInfo; // expertCount inside should be slotCount if using redundant experts.
MoeSingleCommMeta sendCommMeta;
MoeSingleCommMeta recvCommMeta;
SendRecvIndices sendIndices;
SendRecvIndices recvIndices;
FusedMoeFieldInfo sendFieldInfo;
FusedMoeFieldInfo recvFieldInfo;
bool isLowPrecision;
};
/*
* Workspace Layout:
* Ri: Rank i
* N: Number of GPUs, e.g. EpSize or WorldSize, n = N - 1
* Ci: Channel i
* M: Number of Channels, m = M - 1
* MMr: Memory Mapped from Rank r, physically located at rank r, and mapped to all ranks.
*
* Whole workspace memory space:
* ---------------------------------------------------------------------------------------------------
* |<-- MM0 --> |<-- MM1 --> |<-- MM2 --> | ...... |<-- MMn --> |
* ^ ^ ^ ^ ^ ^
* 0 rankStrideInU64 2*rankStrideInU64 3*rankStrideInU64 n*rankStrideInU64 N*rankStrideInU64
*
* For each MMr, the layout is:
* -------------------------------------------------------------------------------------------------
* |<--- FIFO memory --->|<--- SenderSideFifoInfo memory --->|<--- ReceiverSideFifoInfo memory --->|
* -------------------------------------------------------------------------------------------------
*
* Each FIFO's memory is physically placed at the receiver rank.
* To find the FIFO whose receiver is rank r, look in the FIFO memory of MMr.
* The layout of the FIFO memory of each MMr is (here the rank is the sender rank):
* -------------------------------------------------------------------------------------------------
* | R0C0 | R0C1 | .... | R0Cm | R1C0 | R1C1 | .... | R1Cm | .... .... | RnC0 | RnC1 | .... | RnCm |
* |<- Channels for Rank 0 ->|<- Channels for Rank 1 ->| |<- Channels for Rank n ->|
* -------------------------------------------------------------------------------------------------
* Each R*C* has a length of FIFO_TOTAL_U64 uint64_t values and is internally divided into FIFO_DEPTH entries of
* size FIFO_ENTRY_BYTES each.
*
* Each SenderSideFifoInfo's memory is physically placed at the sender rank.
* To find the SenderSideFifoInfo whose sender is rank r, look in the SenderSideFifoInfo memory of MMr.
* The layout of the SenderSideFifoInfo memory of each MMr is (here the rank is the receiver rank):
* -------------------------------------------------------------------------------------------------
* | R0C0 | R0C1 | .... | R0Cm | R1C0 | R1C1 | .... | R1Cm | .... .... | RnC0 | RnC1 | .... | RnCm |
* |<- Channels for Rank 0 ->|<- Channels for Rank 1 ->| |<- Channels for Rank n ->|
* -------------------------------------------------------------------------------------------------
* Each R*C* is one SenderSideFifoInfo struct. There are M * N SenderSideFifoInfo structs in total in each MMr.
*
* Each ReceiverSideFifoInfo's memory is physically placed at the receiver rank.
* To find the ReceiverSideFifoInfo whose receiver is rank r, look in the ReceiverSideFifoInfo memory of MMr.
* The layout of the ReceiverSideFifoInfo memory of each MMr is (here the rank is the sender rank):
* -------------------------------------------------------------------------------------------------
* | R0C0 | R0C1 | .... | R0Cm | R1C0 | R1C1 | .... | R1Cm | .... .... | RnC0 | RnC1 | .... | RnCm |
* |<- Channels for Rank 0 ->|<- Channels for Rank 1 ->| |<- Channels for Rank n ->|
* -------------------------------------------------------------------------------------------------
* Each R*C* is one ReceiverSideFifoInfo struct. There are M * N ReceiverSideFifoInfo structs in total in each MMr.
*/
struct FusedMoeWorkspace
{
uint64_t* workspacePtr;
size_t rankStrideInU64;
int channelCount;
template <bool isSenderSideBuffer>
__device__ __forceinline__ uint8_t* commonGetPtrBase(
FusedMoePairInfo const& pairInfo, size_t fieldOffset, int fieldSingleSize) const
{
int mappedMemoryRank = isSenderSideBuffer ? pairInfo.senderRank : pairInfo.receiverRank;
int rankInsideMappedMemory = isSenderSideBuffer ? pairInfo.receiverRank : pairInfo.senderRank;
auto* mappedMemory = reinterpret_cast<uint8_t*>(workspacePtr + mappedMemoryRank * rankStrideInU64);
mappedMemory += fieldOffset;
mappedMemory += rankInsideMappedMemory * channelCount * fieldSingleSize;
mappedMemory += pairInfo.channel * fieldSingleSize;
return mappedMemory;
}
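// Worked example (illustrative): for getFifoBasePtr below with senderRank = 1, receiverRank = 2
// and channel = 0, isSenderSideBuffer is false, so the FIFO lives in the receiver's mapped
// memory MM2: the result is workspacePtr + 2 * rankStrideInU64 (in uint64_t units), plus
// 1 * channelCount * FIFO_TOTAL_BYTES for sender rank 1 plus 0 for channel 0, i.e. the R1C0
// slot in the FIFO layout diagram above.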
__device__ __forceinline__ uint64_t* getFifoBasePtr(
FusedMoeWorldInfo const& worldInfo, FusedMoePairInfo const& pairInfo) const
{
constexpr int fieldSingleSize = FusedMoeCommunicator::FIFO_TOTAL_BYTES;
return reinterpret_cast<uint64_t*>(commonGetPtrBase<false>(pairInfo, 0, fieldSingleSize));
}
__device__ __forceinline__ SenderSideFifoInfo* getSenderSideFifoInfo(
FusedMoeWorldInfo const& worldInfo, FusedMoePairInfo const& pairInfo) const
{
constexpr int fieldSingleSize = sizeof(SenderSideFifoInfo);
size_t fieldOffset
= static_cast<size_t>(FusedMoeCommunicator::FIFO_TOTAL_BYTES) * worldInfo.epInfo.epSize * channelCount;
return reinterpret_cast<SenderSideFifoInfo*>(commonGetPtrBase<true>(pairInfo, fieldOffset, fieldSingleSize));
}
__device__ __forceinline__ ReceiverSideFifoInfo* getReceiverSideFifoInfo(
FusedMoeWorldInfo const& worldInfo, FusedMoePairInfo const& pairInfo) const
{
constexpr int fieldSingleSize = sizeof(ReceiverSideFifoInfo);
size_t fieldOffset
= static_cast<size_t>(FusedMoeCommunicator::FIFO_TOTAL_BYTES) * worldInfo.epInfo.epSize * channelCount
+ sizeof(SenderSideFifoInfo) * worldInfo.epInfo.epSize * channelCount;
return reinterpret_cast<ReceiverSideFifoInfo*>(commonGetPtrBase<false>(pairInfo, fieldOffset, fieldSingleSize));
}
static size_t computeWorkspaceSizePreRank(int epSize, int channelCount)
{
size_t fifoSize = static_cast<size_t>(FusedMoeCommunicator::FIFO_TOTAL_BYTES) * epSize * channelCount;
size_t senderSideInfoSize = sizeof(SenderSideFifoInfo) * epSize * channelCount;
size_t receiverSideInfoSize = sizeof(ReceiverSideFifoInfo) * epSize * channelCount;
return fifoSize + senderSideInfoSize + receiverSideInfoSize;
}
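// Worked example (illustrative, assuming ALIGN_256 is alignas(256) so each info struct
// occupies 256 bytes): with epSize = 8 and channelCount = 2, the FIFO space is
// 1 MiB * 16 == 16 MiB and each info array is 16 * 256 B == 4 KiB, so roughly
// 16 MiB + 8 KiB per rank.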
void initializeLocalWorkspace(FusedMoeWorldInfo const& worldInfo);
};
void setMaxUsableSmCount(int smCount);
void moeAllToAll(FusedMoeCommKernelParam params, FusedMoeWorkspace workspace, cudaStream_t stream);
void constructWorkspace(FusedMoeWorkspace* workspace, uint64_t* workspacePtr, size_t rankStrideInU64, int epSize);
void initializeFusedMoeLocalWorkspace(FusedMoeWorkspace* workspace, FusedMoeWorldInfo const& worldInfo);
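// Illustrative host-side call sequence (a sketch: allocAndMapAcrossRanks and the derivation of
// rankStrideInU64 from the workspace size are assumptions, only the functions declared above
// are part of this header):
//
// size_t wsBytes = getFusedMoeCommWorkspaceSize(epSize); // assumed per-rank size here
// uint64_t* wsPtr = allocAndMapAcrossRanks(wsBytes, epSize); // hypothetical allocator
// FusedMoeWorkspace workspace;
// constructWorkspace(&workspace, wsPtr, wsBytes / sizeof(uint64_t), epSize);
// initializeFusedMoeLocalWorkspace(&workspace, worldInfo);
// moeAllToAll(params, workspace, stream); // params filled per FusedMoeCommKernelParam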
namespace fused_moe_comm_tests
{
// Functions for testing
void launchSingleG2S(FusedMoeFieldInfo const& sendFieldInfo, MoeExpertParallelInfo const& expertParallelInfo,
int tokenCount, int* shmDump, int warpsPerBlock, bool hasBasicFields, cudaStream_t stream);
void launchSingleS2G(FusedMoeFieldInfo const& recvFieldInfo, MoeExpertParallelInfo const& expertParallelInfo,
int tokenCount, int* shmPreload, int warpsPerBlock, bool hasBasicFields, cudaStream_t stream);
void launchLoopback(FusedMoeFieldInfo const& sendFieldInfo, FusedMoeFieldInfo const& recvFieldInfo,
MoeExpertParallelInfo const& expertParallelInfo, int* recvIndexMapping, int tokenCount, int warpsPerBlock,
bool hasBasicFields, cudaStream_t stream);
void launchLocalFifoSendRecv(FusedMoeFieldInfo const& sendFieldInfo, FusedMoeFieldInfo const& recvFieldInfo,
MoeExpertParallelInfo const& expertParallelInfo, int* sendIndexMapping, int* recvIndexMapping,
FusedMoeWorkspace fusedMoeWorkspace, int tokenCount, int warpsPerBlock, int blockChannelCount, bool hasBasicFields,
cudaStream_t stream);
} // namespace fused_moe_comm_tests
} // namespace kernels
TRTLLM_NAMESPACE_END