TensorRT-LLMs/cpp/tensorrt_llm/kernels/moeCommKernels.h
dongxuy04 16535991b2
feat: Add MNNVL MoE A2A support (#3504)
* add MNNVL memory mapping support

Signed-off-by: Dongxu Yang <78518666+dongxuy04@users.noreply.github.com>

* add more MPI environment for trtllm-llmapi-launch

Signed-off-by: Dongxu Yang <78518666+dongxuy04@users.noreply.github.com>

* add MoE communication and prepare kernels

Signed-off-by: Dongxu Yang <78518666+dongxuy04@users.noreply.github.com>

* add MNNVL AlltoAll support for DeepSeekV3

Signed-off-by: Dongxu Yang <78518666+dongxuy04@users.noreply.github.com>

* add output dump for throughput benchmark

Signed-off-by: Dongxu Yang <78518666+dongxuy04@users.noreply.github.com>

* support dynamic kernel launch grid

Signed-off-by: Dongxu Yang <78518666+dongxuy04@users.noreply.github.com>

* address review comments

Signed-off-by: Dongxu Yang <78518666+dongxuy04@users.noreply.github.com>

* address review comments #2

Signed-off-by: Dongxu Yang <78518666+dongxuy04@users.noreply.github.com>

---------

Signed-off-by: Dongxu Yang <78518666+dongxuy04@users.noreply.github.com>
2025-04-25 17:29:08 +08:00

269 lines
9.8 KiB
C++

/*
* Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <map>
#include "tensorrt_llm/common/cudaUtils.h"
namespace tensorrt_llm::kernels
{
// ALIGN_256 expands to a 256-byte alignment specifier: the CUDA `__align__` attribute when the
// translation unit is compiled by nvcc (__CUDACC__ defined), the standard `alignas` otherwise.
#if defined(__CUDACC__)
#define ALIGN_256 __align__(256)
#else
#define ALIGN_256 alignas(256)
#endif

/// Head/tail cursor pair of one FIFO connection. The struct is 256-byte aligned, so consecutive
/// elements of an array of these are 256 bytes apart.
struct ALIGN_256 MoeCommFifoConnInfo
{
    /// Write position.
    volatile uint64_t head;
    /// Read position.
    volatile uint64_t tail;
};
// Number of threads in a warp, and a 32-bit mask with every bit set (one bit per lane).
constexpr int WARP_SIZE = 32;
constexpr uint32_t WARP_MASK = 0xffffffffu;

// Receive FIFO geometry: RECV_FIFO_DEPTH entries of RECV_FIFO_ENTRY_BYTES each.
// The *_U64 variants express the same sizes in 64-bit words.
constexpr int RECV_FIFO_DEPTH = 8;
constexpr int RECV_FIFO_ENTRY_BYTES = 256 << 10; // 256 KiB
constexpr int RECV_FIFO_ENTRY_U64 = RECV_FIFO_ENTRY_BYTES / static_cast<int>(sizeof(uint64_t));
constexpr int RECV_FIFO_TOTAL_BYTES = RECV_FIFO_ENTRY_BYTES * RECV_FIFO_DEPTH;
constexpr int RECV_FIFO_TOTAL_U64 = RECV_FIFO_TOTAL_BYTES / static_cast<int>(sizeof(uint64_t));
/// Host-side launch-configuration helpers and compile-time packet/FIFO geometry for the MoE
/// all-to-all communication channels.
///
/// All members are static. The usable SM count is a process-wide setting stored in `maxSmCount`;
/// it can only be changed before the first call to getMaxUsableSmCount().
class AllToAllChannelCommunicatorBase
{
public:
    // Block shape returned by getLaunchBlockDim():
    // x = WARP_SIZE * WARP_PER_GROUP threads, y = GROUP_COUNT_PER_BLOCK groups.
    static constexpr int GROUP_COUNT_PER_BLOCK = 8;
    static_assert(GROUP_COUNT_PER_BLOCK <= 8, "GROUP_COUNT_PER_BLOCK must be less than or equal to 8");
    static constexpr int WARP_PER_GROUP = 2;
    static constexpr int U64_DATA_REG_PER_THREAD = 8;

    // A packet is a warp-sized chunk of data that is sent or received in one go,
    // but may be split into multiple 64-bit registers, the number of which is U64_DATA_REG_PER_THREAD.
    static constexpr int PACKET_SIZE_IN_U64 = WARP_SIZE * U64_DATA_REG_PER_THREAD;
    static constexpr int PACKET_SIZE_IN_BYTES = PACKET_SIZE_IN_U64 * sizeof(uint64_t);
    // The payload is sized for (WARP_SIZE - 2) lanes of a packet.
    // NOTE(review): the two remaining lanes are presumably reserved for non-payload data; confirm
    // against the kernels.
    static constexpr int DATA_PAYLOAD_SIZE_PER_PACKET_IN_U64 = (WARP_SIZE - 2) * U64_DATA_REG_PER_THREAD;
    static constexpr int DATA_PAYLOAD_SIZE_PER_PACKET = DATA_PAYLOAD_SIZE_PER_PACKET_IN_U64 * sizeof(uint64_t);
    static constexpr int U64_ELT_COUNT_PER_PACKET = PACKET_SIZE_IN_BYTES / sizeof(uint64_t);
    static constexpr int PACKET_COUNT_PER_FIFO_ENTRY = RECV_FIFO_ENTRY_BYTES / PACKET_SIZE_IN_BYTES;
    // Capacity of GroupSharedBuffer::groupIndiceBuffer; numerically equal to the number of packets
    // that fit in one FIFO entry.
    static constexpr int GROUP_MAX_INDICE_COUNT
        = RECV_FIFO_ENTRY_BYTES / sizeof(uint64_t) / (WARP_SIZE * U64_DATA_REG_PER_THREAD);

    /// Per-group scratch record: an index buffer plus a start/end index pair.
    /// NOTE(review): presumably placed in shared memory by the kernels; not visible in this header.
    struct GroupSharedBuffer
    {
        int groupIndiceBuffer[GROUP_MAX_INDICE_COUNT];
        int groupStartIndice;
        int groupEndIndice;
    };

    /// Sets the SM budget later reported by getMaxUsableSmCount().
    ///
    /// Fails the check if getMaxUsableSmCount() has already been called. A value larger than the
    /// device's multiprocessor count is clamped to that count, with a warning.
    static void setMaxUsableSmCount(int maxUsableSmCount)
    {
        TLLM_CHECK_WITH_INFO(AllToAllChannelCommunicatorBase::maxSmCountUsed == false,
            "setMaxUsableSmCount can be called only before it is used");
        int smCount = tensorrt_llm::common::getMultiProcessorCount();
        if (maxUsableSmCount > smCount)
        {
            TLLM_LOG_WARNING("setMaxUsableSmCount, maxUsableSmCount=%d, larger than smCount=%d, using smCount instead",
                maxUsableSmCount, smCount);
            maxUsableSmCount = smCount;
        }
        AllToAllChannelCommunicatorBase::maxSmCount = maxUsableSmCount;
    }

    /// Returns the SM budget and marks it as used, so it can no longer be changed.
    /// A stored value of -1 is treated as "not set" and replaced by the device's multiprocessor count.
    static int getMaxUsableSmCount()
    {
        AllToAllChannelCommunicatorBase::maxSmCountUsed = true;
        if (AllToAllChannelCommunicatorBase::maxSmCount == -1)
        {
            int smCount = tensorrt_llm::common::getMultiProcessorCount();
            AllToAllChannelCommunicatorBase::maxSmCount = smCount;
        }
        return AllToAllChannelCommunicatorBase::maxSmCount;
    }

    /// Computes how many communication channels to use for an expert-parallel group of `epSize` ranks.
    ///
    /// One channel needs ceil(epSize / GROUP_COUNT_PER_BLOCK) blocks for each of send and recv.
    /// The result is the number of channels that fit in half of the usable SMs, but never less than 1.
    /// Fails the check if `epSize` is not positive or if a single channel does not fit in the usable SMs.
    static int computeMoeCommChannelCount(int epSize)
    {
        // A non-positive epSize would give zero blocks per channel and a division by zero below.
        TLLM_CHECK_WITH_INFO(epSize > 0, "epSize should be positive, epSize=%d", epSize);
        int smCount = getMaxUsableSmCount();
        int blockCountPerChannel = (epSize + GROUP_COUNT_PER_BLOCK - 1) / GROUP_COUNT_PER_BLOCK;
        blockCountPerChannel *= 2; // for send and recv
        TLLM_CHECK_WITH_INFO(
            blockCountPerChannel <= smCount, "GPU should support at lease one channel, usableSmCount=%d", smCount);
        int preferredChannel = smCount / 2 / blockCountPerChannel; // use half SMs for communication
        int channelCount = std::max(preferredChannel, 1);          // at least one channel
        return channelCount;
    }

    /// Cached wrapper around computeMoeCommChannelCount(): the result for each `epSize` is computed
    /// once and stored in a function-local static map. Accesses to that map are not synchronized here.
    static int getMoeCommChannelCount(int epSize)
    {
        static std::map<int, int> channelCountMap{};
        auto iter = channelCountMap.find(epSize);
        if (iter == channelCountMap.end())
        {
            auto channelCount = AllToAllChannelCommunicatorBase::computeMoeCommChannelCount(epSize);
            channelCountMap[epSize] = channelCount;
            return channelCount;
        }
        return iter->second;
    }

    /// Block dimensions: (WARP_SIZE * WARP_PER_GROUP) x GROUP_COUNT_PER_BLOCK.
    static dim3 getLaunchBlockDim()
    {
        return dim3(WARP_SIZE * WARP_PER_GROUP, GROUP_COUNT_PER_BLOCK);
    }

    /// Grid dimensions: x = ceil(epSize / GROUP_COUNT_PER_BLOCK), y = channel count, z = 2.
    /// NOTE(review): z = 2 presumably separates send and recv blocks, matching the factor of two in
    /// computeMoeCommChannelCount(); confirm against the kernels.
    static dim3 getLaunchGridDim(int epSize)
    {
        int channelCount = AllToAllChannelCommunicatorBase::getMoeCommChannelCount(epSize);
        return dim3((epSize + GROUP_COUNT_PER_BLOCK - 1) / GROUP_COUNT_PER_BLOCK, channelCount, 2);
    }

protected:
    // SM budget; -1 is treated as "not yet resolved" by getMaxUsableSmCount(). Defined outside this header.
    static int maxSmCount;
    // Set once getMaxUsableSmCount() has been called; blocks later setMaxUsableSmCount() calls.
    static bool maxSmCountUsed;
};
/// Returns the workspace size in bytes for an expert-parallel group of `epSize` ranks.
///
/// There is one receive FIFO (RECV_FIFO_TOTAL_BYTES) and one MoeCommFifoConnInfo record for each
/// (rank, channel) pair, where the channel count comes from
/// AllToAllChannelCommunicatorBase::getMoeCommChannelCount(epSize).
inline size_t getMoeCommWorkspaceSize(int epSize)
{
    int const channelCount = AllToAllChannelCommunicatorBase::getMoeCommChannelCount(epSize);
    // Widen before multiplying: each FIFO is 2 MiB, so the product evaluated in `int` would overflow
    // once epSize * channelCount reaches 1024.
    size_t const fifoCount = static_cast<size_t>(epSize) * static_cast<size_t>(channelCount);
    size_t const bytesPerFifo = static_cast<size_t>(RECV_FIFO_TOTAL_BYTES) + sizeof(MoeCommFifoConnInfo);
    return fifoCount * bytesPerFifo;
}
/// Expert-parallel world description: group size and the rank within it.
struct MoeEpWorldInfo
{
    /// Number of ranks in the expert-parallel group.
    int epSize;
    /// Rank index within the expert-parallel group.
    int epRank;
};
/// Expert configuration; a default-constructed value has expertCount == -1 and topK == 1.
struct MoeExpertParallelInfo
{
    /// Expert count; defaults to -1.
    int expertCount{-1};
    /// Top-K value; defaults to 1.
    int topK{1};
};
/// Size of one data vector together with packet/FIFO counts derived from it on the host.
struct SendRecvDataInfo
{
    /// Size of one data vector in 64-bit words. Must be positive before any Compute*/DoPreCompute call.
    int vectorSizeInU64;
    // pre-computed at host side for GPU kernel
    /// Number of packets needed to carry one vector (rounded up).
    int dataPacketCountPerVector;
    /// Number of whole vectors whose packets fit in a single FIFO entry.
    int vectorCountPerFifoEntry;

    /// Sets dataPacketCountPerVector = ceil(vector bytes / payload bytes per packet).
    void ComputeDataPacketCountPerVector()
    {
        // A non-positive size would yield a zero packet count, which ComputeVectorCountPerFifoEntry()
        // then divides by.
        assert(vectorSizeInU64 > 0);
        dataPacketCountPerVector
            = (vectorSizeInU64 * sizeof(uint64_t) + AllToAllChannelCommunicatorBase::DATA_PAYLOAD_SIZE_PER_PACKET - 1)
            / AllToAllChannelCommunicatorBase::DATA_PAYLOAD_SIZE_PER_PACKET;
    }

    /// Refreshes dataPacketCountPerVector, then sets
    /// vectorCountPerFifoEntry = PACKET_COUNT_PER_FIFO_ENTRY / dataPacketCountPerVector.
    /// NOTE(review): the result is 0 when one vector needs more packets than a FIFO entry holds;
    /// whether the kernels tolerate that is not visible from this header.
    void ComputeVectorCountPerFifoEntry()
    {
        ComputeDataPacketCountPerVector();
        vectorCountPerFifoEntry
            = AllToAllChannelCommunicatorBase::PACKET_COUNT_PER_FIFO_ENTRY / dataPacketCountPerVector;
    }

    /// Fills both derived fields from vectorSizeInU64.
    void DoPreCompute()
    {
        // ComputeVectorCountPerFifoEntry() already refreshes dataPacketCountPerVector, so one call
        // computes both derived fields.
        ComputeVectorCountPerFifoEntry();
        assert(vectorCountPerFifoEntry <= AllToAllChannelCommunicatorBase::GROUP_MAX_INDICE_COUNT);
    }
};
// Send/Recv data pointer together with the displacement tables used to locate vectors in it.
struct SendRecvDispls
{
    uint64_t* dataPtr;           // base of the vector storage
    int const* rankCountCumSum;  // length = epSize
    int const* rankLocalIndices; // length = rankCountCumSum[epRank] - rankCountCumSum[epRank - 1] if epRank > 0 else
                                 // rankCountCumSum[epRank]
    int vectorStrideInU64;       // distance between consecutive vectors, in 64-bit words

#ifdef __CUDACC__
    // Number of vectors for `rank`: the difference between its cumulative count and the previous rank's.
    __inline__ __device__ int getCount(int rank) const
    {
        int const cumulativeEnd = rankCountCumSum[rank];
        if (rank == 0)
        {
            return cumulativeEnd;
        }
        return cumulativeEnd - rankCountCumSum[rank - 1];
    }

    // First index belonging to `rank` in the cumulative ordering (0 for rank 0).
    __inline__ __device__ int getRankStart(int rank) const
    {
        if (rank == 0)
        {
            return 0;
        }
        return rankCountCumSum[rank - 1];
    }

    // Looks up rankLocalIndices at the given index.
    __inline__ __device__ int getRealVectorIndice(int globalVectorIndex) const
    {
        int const realIndex = rankLocalIndices[globalVectorIndex];
        return realIndex;
    }

    // Address of the vector at `realVectorIndex`, using vectorStrideInU64 as the per-vector stride.
    __inline__ __device__ uint64_t* getVectorDataPtr(int realVectorIndex) const
    {
        return &dataPtr[realVectorIndex * vectorStrideInU64];
    }
#endif
};
/// Handle to the communication workspace: a base pointer plus the distance, in 64-bit words,
/// between the regions of consecutive ranks.
struct MoeCommWorkspace
{
    uint64_t* workspacePtr;
    size_t rankStrideInU64;

#ifdef __CUDACC__
    /// Base address of the FIFO connecting `epRank` and `peerRank` on `channel`.
    /// The FIFO is stored in the receiver's region and indexed there by (sender rank, channel).
    __inline__ __device__ uint64_t* getFifoBasePtr(
        bool isSender, int epRank, int peerRank, int channel, int channelCount) const
    {
        // fifo itself is in receiver's side.
        int const receiverRank = isSender ? peerRank : epRank;
        int const senderRank = isSender ? epRank : peerRank;
        return workspacePtr + receiverRank * rankStrideInU64
            + (senderRank * channelCount + channel) * RECV_FIFO_TOTAL_U64;
    }

    /// Address of the head/tail record for the FIFO connecting `epRank` and `peerRank` on `channel`.
    /// The records start after the epSize * channelCount FIFOs of a rank's region. They are stored
    /// in the sender's region and indexed there by (receiver rank, channel).
    __inline__ __device__ MoeCommFifoConnInfo* getFifoConnInfo(
        bool isSender, int epRank, int peerRank, int channel, int epSize, int channelCount) const
    {
        // fifoInfo is in sender's side.
        int const senderRank = isSender ? epRank : peerRank;
        int const receiverRank = isSender ? peerRank : epRank;
        uint64_t* connInfoRegion = workspacePtr + RECV_FIFO_TOTAL_U64 * channelCount * epSize;
        connInfoRegion += senderRank * rankStrideInU64;
        auto* const connInfoTable = reinterpret_cast<MoeCommFifoConnInfo*>(connInfoRegion);
        return connInfoTable + receiverRank * channelCount + channel;
    }
#endif
};
// Namespace-level entry point for configuring the usable SM count.
// NOTE(review): presumably forwards to AllToAllChannelCommunicatorBase::setMaxUsableSmCount; the
// definition is not in this header, so confirm.
void setMaxUsableSmCount(int smCount);
// MoE all-to-all exchange. Takes the EP world info, the pre-computed per-vector sizing, the send and
// receive displacement descriptors, the communication workspace and a CUDA stream.
// NOTE(review): sendRecvDataInfo is presumably expected to have had DoPreCompute() called, and the
// work is presumably enqueued on `stream`; the definition is not in this header, so confirm both.
void moeAllToAll(MoeEpWorldInfo worldInfo, SendRecvDataInfo sendRecvDataInfo, SendRecvDispls sendDispls,
SendRecvDispls recvDispls, MoeCommWorkspace workspace, cudaStream_t stream);
// Prepares the index and cumulative-count arrays used by the all-to-all exchange.
// The `int const*` parameters are inputs; the `int*` parameters are written by the callee.
// NOTE(review): buffer lengths and memory spaces are not visible in this header; confirm against
// the definition.
void moeAllToAllPrepareIndices(MoeEpWorldInfo worldInfo, MoeExpertParallelInfo expertParallelInfo,
int maxTokenCountPerRank, int const* gatheredTargetRankIds, int const* realRankTokenCountCumSum,
int* localGatheredIndices, // indices of gatheredTargetRankIds that have the local rank in topK
int* sendRankCountCumSum, int* sendRankLocalIndices, int* recvRankCountCumSum, int* recvRankLocalIndices,
// the rankCountCumSum of combineRecv should be the same as sendRankCountCumSum
int* backwardRecvRankLocalIndices, cudaStream_t stream);
// Gathers expert ids and scales for the local rank. Inputs are recvRankCountCumSum,
// localGatherIndices, gatheredExpertIds and gatheredScales; outputs are localExpertIds and localScales.
// NOTE(review): the exact selection rule and output sizes are not visible in this header; confirm
// against the definition.
void moeLocalGather(MoeEpWorldInfo worldInfo, MoeExpertParallelInfo expertParallelInfo, int maxTokenCountPerRank,
int localMaxTokenCount, int const* recvRankCountCumSum, int const* localGatherIndices, int const* gatheredExpertIds,
float const* gatheredScales, int* localExpertIds, float* localScales, cudaStream_t stream);
} // namespace tensorrt_llm::kernels