mirror of
https://github.com/NVIDIA/TensorRT-LLM.git
synced 2026-01-14 06:27:45 +08:00
* add MNNVL memory mapping support Signed-off-by: Dongxu Yang <78518666+dongxuy04@users.noreply.github.com> * add more MPI environment for trtllm-llmapi-launch Signed-off-by: Dongxu Yang <78518666+dongxuy04@users.noreply.github.com> * add MoE communication and prepare kernels Signed-off-by: Dongxu Yang <78518666+dongxuy04@users.noreply.github.com> * add MNNVL AlltoAll support for DeepSeekV3 Signed-off-by: Dongxu Yang <78518666+dongxuy04@users.noreply.github.com> * add output dump for throughput benchmark Signed-off-by: Dongxu Yang <78518666+dongxuy04@users.noreply.github.com> * support dynamic kernel launch grid Signed-off-by: Dongxu Yang <78518666+dongxuy04@users.noreply.github.com> * address review comments Signed-off-by: Dongxu Yang <78518666+dongxuy04@users.noreply.github.com> * address review comments #2 Signed-off-by: Dongxu Yang <78518666+dongxuy04@users.noreply.github.com> --------- Signed-off-by: Dongxu Yang <78518666+dongxuy04@users.noreply.github.com>
269 lines
9.8 KiB
C++
269 lines
9.8 KiB
C++
/*
|
|
* Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved.
|
|
*
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
* you may not use this file except in compliance with the License.
|
|
* You may obtain a copy of the License at
|
|
*
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
* See the License for the specific language governing permissions and
|
|
* limitations under the License.
|
|
*/
|
|
|
|
#pragma once
|
|
|
|
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <map>

#include "tensorrt_llm/common/cudaUtils.h"
|
|
|
|
namespace tensorrt_llm::kernels
|
|
{
|
|
|
|
// 256-byte alignment attribute, spelled per-compiler: NVCC device compilation
// uses the CUDA __align__ extension, host-only compilers use standard alignas.
#ifdef __CUDACC__
#define ALIGN_256 __align__(256)
#else
#define ALIGN_256 alignas(256)
#endif
|
|
|
|
// Head/tail state of one sender/receiver FIFO connection. The struct is aligned
// (and thus padded) to 256 bytes, so each connection record occupies its own
// 256-byte region. NOTE(review): presumably this isolates each record from
// neighboring traffic on the interconnect/cache line — confirm against kernels.
struct ALIGN_256 MoeCommFifoConnInfo
{
    volatile uint64_t head; // write position
    volatile uint64_t tail; // read position
};
|
|
|
|
// CUDA warp width and a mask with all 32 lanes set.
constexpr int WARP_SIZE = 32;
constexpr uint32_t WARP_MASK = 0xffffffff;

// Receive-FIFO geometry: each FIFO holds RECV_FIFO_DEPTH entries of
// RECV_FIFO_ENTRY_BYTES each; the *_U64 variants express the same sizes
// in units of uint64_t elements.
constexpr int RECV_FIFO_DEPTH = 8;
constexpr int RECV_FIFO_ENTRY_BYTES = 256 * 1024;
constexpr int RECV_FIFO_ENTRY_U64 = RECV_FIFO_ENTRY_BYTES / sizeof(uint64_t);
constexpr int RECV_FIFO_TOTAL_BYTES = RECV_FIFO_DEPTH * RECV_FIFO_ENTRY_BYTES;
constexpr int RECV_FIFO_TOTAL_U64 = RECV_FIFO_TOTAL_BYTES / sizeof(uint64_t);
|
|
|
|
// Host-side configuration/base class for the MoE all-to-all channel
// communicator: packet geometry constants, the SM budget used for
// communication, and kernel launch dimensions.
class AllToAllChannelCommunicatorBase
{
public:
    // Number of peer-rank groups processed by one thread block.
    static constexpr int GROUP_COUNT_PER_BLOCK = 8;
    static_assert(GROUP_COUNT_PER_BLOCK <= 8, "GROUP_COUNT_PER_BLOCK must be less than or equal to 8");
    // Warps cooperating on one group.
    static constexpr int WARP_PER_GROUP = 2;
    // 64-bit data registers each thread carries per packet.
    static constexpr int U64_DATA_REG_PER_THREAD = 8;
    // A packet is a warp-sized chunk of data that is sent or received in one go,
    // but may be split into multiple 64-bit registers, the number of which is U64_DATA_REG_PER_THREAD.
    static constexpr int PACKET_SIZE_IN_U64 = WARP_SIZE * U64_DATA_REG_PER_THREAD;
    static constexpr int PACKET_SIZE_IN_BYTES = PACKET_SIZE_IN_U64 * sizeof(uint64_t);
    // Payload per packet excludes two lanes' worth of registers (presumably
    // reserved for in-band control data — see the communication kernels).
    static constexpr int DATA_PAYLOAD_SIZE_PER_PACKET_IN_U64 = (WARP_SIZE - 2) * U64_DATA_REG_PER_THREAD;
    static constexpr int DATA_PAYLOAD_SIZE_PER_PACKET = DATA_PAYLOAD_SIZE_PER_PACKET_IN_U64 * sizeof(uint64_t);
    static constexpr int U64_ELT_COUNT_PER_PACKET = PACKET_SIZE_IN_BYTES / sizeof(uint64_t);

    // Whole packets that fit into a single FIFO entry.
    static constexpr int PACKET_COUNT_PER_FIFO_ENTRY = RECV_FIFO_ENTRY_BYTES / PACKET_SIZE_IN_BYTES;

    static constexpr int GROUP_MAX_INDICE_COUNT
        = RECV_FIFO_ENTRY_BYTES / sizeof(uint64_t) / (WARP_SIZE * U64_DATA_REG_PER_THREAD);

    // Per-group scratch buffer layout used by the communication kernels.
    struct GroupSharedBuffer
    {
        int groupIndiceBuffer[GROUP_MAX_INDICE_COUNT];
        int groupStartIndice;
        int groupEndIndice;
    };

    // Caps the number of SMs used for communication. Must be called before the
    // first call to getMaxUsableSmCount(); values above the device's SM count
    // are clamped with a warning.
    static void setMaxUsableSmCount(int maxUsableSmCount)
    {
        TLLM_CHECK_WITH_INFO(AllToAllChannelCommunicatorBase::maxSmCountUsed == false,
            "setMaxUsableSmCount can be called only before it is used");
        int smCount = tensorrt_llm::common::getMultiProcessorCount();
        if (maxUsableSmCount > smCount)
        {
            TLLM_LOG_WARNING("setMaxUsableSmCount, maxUsableSmCount=%d, larger than smCount=%d, using smCount instead",
                maxUsableSmCount, smCount);
            maxUsableSmCount = smCount;
        }
        AllToAllChannelCommunicatorBase::maxSmCount = maxUsableSmCount;
    }

    // Returns the SM budget, defaulting to the device SM count on first use.
    // Marks the budget as "used" so later setMaxUsableSmCount() calls fail.
    static int getMaxUsableSmCount()
    {
        AllToAllChannelCommunicatorBase::maxSmCountUsed = true;
        if (AllToAllChannelCommunicatorBase::maxSmCount == -1)
        {
            AllToAllChannelCommunicatorBase::maxSmCount = tensorrt_llm::common::getMultiProcessorCount();
        }
        return AllToAllChannelCommunicatorBase::maxSmCount;
    }

    // Computes how many parallel channels to use for an EP world of `epSize`
    // ranks, dedicating roughly half of the usable SMs to communication and
    // guaranteeing at least one channel.
    static int computeMoeCommChannelCount(int epSize)
    {
        int smCount = getMaxUsableSmCount();
        int blockCountPerChannel = (epSize + GROUP_COUNT_PER_BLOCK - 1) / GROUP_COUNT_PER_BLOCK;
        blockCountPerChannel *= 2; // one set of blocks for send, one for recv
        TLLM_CHECK_WITH_INFO(
            blockCountPerChannel <= smCount, "GPU should support at least one channel, usableSmCount=%d", smCount);
        int preferredChannel = smCount / 2 / blockCountPerChannel; // use half the SMs for communication
        return std::max(preferredChannel, 1);                      // at least one channel
    }

    // Memoized wrapper around computeMoeCommChannelCount(), keyed by epSize.
    static int getMoeCommChannelCount(int epSize)
    {
        static std::map<int, int> channelCountMap{};
        auto iter = channelCountMap.find(epSize);
        if (iter == channelCountMap.end())
        {
            auto channelCount = AllToAllChannelCommunicatorBase::computeMoeCommChannelCount(epSize);
            channelCountMap[epSize] = channelCount;
            return channelCount;
        }
        return iter->second;
    }

    // Block layout: x = lanes of WARP_PER_GROUP warps, y = groups per block.
    static dim3 getLaunchBlockDim()
    {
        return dim3(WARP_SIZE * WARP_PER_GROUP, GROUP_COUNT_PER_BLOCK);
    }

    // Grid layout: x = blocks covering epSize groups, y = channels,
    // z = 2 (send and recv).
    static dim3 getLaunchGridDim(int epSize)
    {
        int channelCount = AllToAllChannelCommunicatorBase::getMoeCommChannelCount(epSize);
        return dim3((epSize + GROUP_COUNT_PER_BLOCK - 1) / GROUP_COUNT_PER_BLOCK, channelCount, 2);
    }

protected:
    static int maxSmCount;      // SM budget; -1 until initialized (defined in the .cu file)
    static bool maxSmCountUsed; // true once getMaxUsableSmCount() has been called
};
|
|
|
|
// Returns the workspace size in bytes one rank needs for MoE all-to-all with an
// EP world of `epSize` ranks: per-(peer, channel) receive FIFOs followed by
// per-(peer, channel) MoeCommFifoConnInfo records.
inline size_t getMoeCommWorkspaceSize(int epSize)
{
    int channelCount = AllToAllChannelCommunicatorBase::getMoeCommChannelCount(epSize);
    // Widen to size_t before multiplying: RECV_FIFO_TOTAL_BYTES is 2 MiB, so a
    // plain int product overflows once epSize * channelCount reaches 1024.
    size_t slotCount = static_cast<size_t>(epSize) * static_cast<size_t>(channelCount);
    return static_cast<size_t>(RECV_FIFO_TOTAL_BYTES) * slotCount + sizeof(MoeCommFifoConnInfo) * slotCount;
}
|
|
|
|
// Expert-parallel (EP) world description: total rank count and this rank's id.
struct MoeEpWorldInfo
{
    int epSize; // number of EP ranks
    int epRank; // this rank's index, in [0, epSize)
};
|
|
|
|
// Expert-parallel MoE configuration.
struct MoeExpertParallelInfo
{
    int expertCount = -1; // total number of experts; -1 means unset
    int topK = 1;         // top-K value used for expert selection per token
};
|
|
|
|
// Size description of one all-to-all transfer, with packet counts pre-computed
// on the host for consumption by the GPU kernels.
struct SendRecvDataInfo
{
    int vectorSizeInU64; // size of one data vector, in uint64_t units

    // pre-computed at host side for GPU kernel
    int dataPacketCountPerVector; // packets needed to carry one vector's payload
    int vectorCountPerFifoEntry;  // whole vectors that fit into one FIFO entry

    // Derives dataPacketCountPerVector from vectorSizeInU64: ceiling division of
    // the vector's byte size by the per-packet payload size.
    void ComputeDataPacketCountPerVector()
    {
        dataPacketCountPerVector
            = (vectorSizeInU64 * sizeof(uint64_t) + AllToAllChannelCommunicatorBase::DATA_PAYLOAD_SIZE_PER_PACKET - 1)
            / AllToAllChannelCommunicatorBase::DATA_PAYLOAD_SIZE_PER_PACKET;
    }

    // Derives vectorCountPerFifoEntry; refreshes dataPacketCountPerVector first
    // so it can always be called on its own.
    void ComputeVectorCountPerFifoEntry()
    {
        ComputeDataPacketCountPerVector();
        vectorCountPerFifoEntry
            = AllToAllChannelCommunicatorBase::PACKET_COUNT_PER_FIFO_ENTRY / dataPacketCountPerVector;
    }

    // Runs all host-side pre-computation and validates the result.
    // ComputeVectorCountPerFifoEntry() already computes dataPacketCountPerVector
    // internally, so the previously separate call was redundant and is dropped.
    void DoPreCompute()
    {
        ComputeVectorCountPerFifoEntry();
        assert(vectorCountPerFifoEntry <= AllToAllChannelCommunicatorBase::GROUP_MAX_INDICE_COUNT);
    }
};
|
|
|
|
// struct holding Send/Recv data pointer and its displacement information.
// The cumulative-sum table partitions the vector index space by rank;
// rankLocalIndices then maps each global index to a local buffer index.
struct SendRecvDispls
{
    uint64_t* dataPtr;           // base pointer of the strided vector data
    int const* rankCountCumSum;  // length = epSize
    int const* rankLocalIndices; // length = rankCountCumSum[epRank] - rankCountCumSum[epRank - 1] if epRank > 0 else
                                 // rankCountCumSum[epRank]
    int vectorStrideInU64;       // stride between consecutive vectors, in uint64_t units

#ifdef __CUDACC__
    // Number of vectors belonging to `rank` (difference of adjacent cumsums).
    __inline__ __device__ int getCount(int rank) const
    {
        return rank == 0 ? rankCountCumSum[rank] : rankCountCumSum[rank] - rankCountCumSum[rank - 1];
    }

    // First global vector index belonging to `rank`.
    __inline__ __device__ int getRankStart(int rank) const
    {
        return rank == 0 ? 0 : rankCountCumSum[rank - 1];
    }

    // Maps a global vector index to its index in the local data buffer.
    __inline__ __device__ int getRealVectorIndice(int globalVectorIndex) const
    {
        return rankLocalIndices[globalVectorIndex];
    }

    // Pointer to the start of the vector stored at `realVectorIndex`.
    __inline__ __device__ uint64_t* getVectorDataPtr(int realVectorIndex) const
    {
        return dataPtr + realVectorIndex * vectorStrideInU64;
    }
#endif
};
|
|
|
|
// View of the shared communication workspace. Layout per rank region:
// [epSize * channelCount receive FIFOs][epSize * channelCount MoeCommFifoConnInfo].
struct MoeCommWorkspace
{
    uint64_t* workspacePtr; // base pointer of the workspace
    size_t rankStrideInU64; // distance between consecutive ranks' regions, in uint64_t units
#ifdef __CUDACC__
    // Base pointer of the FIFO for a (sender epRank or peerRank, channel) pair.
    // The FIFO storage always lives in the RECEIVER's workspace region, indexed
    // there by the sending rank and channel.
    __inline__ __device__ uint64_t* getFifoBasePtr(
        bool isSender, int epRank, int peerRank, int channel, int channelCount) const
    {
        // fifo itself is in receiver's side.
        if (isSender)
        {
            return workspacePtr + peerRank * rankStrideInU64 + (epRank * channelCount + channel) * RECV_FIFO_TOTAL_U64;
        }
        else
        {
            return workspacePtr + epRank * rankStrideInU64 + (peerRank * channelCount + channel) * RECV_FIFO_TOTAL_U64;
        }
    }

    // Head/tail record for a channel. Unlike the FIFO data, the
    // MoeCommFifoConnInfo records live in the SENDER's region, placed after all
    // FIFO buffers (offset RECV_FIFO_TOTAL_U64 * channelCount * epSize) and
    // indexed by the receiving rank and channel.
    __inline__ __device__ MoeCommFifoConnInfo* getFifoConnInfo(
        bool isSender, int epRank, int peerRank, int channel, int epSize, int channelCount) const
    {
        // fifoInfo is in sender's side.
        uint64_t* fifoInfoPtrU64 = workspacePtr + RECV_FIFO_TOTAL_U64 * channelCount * epSize;
        int strideIndice = isSender ? epRank : peerRank;   // which rank's region holds the record
        int fifoInfoIndice = isSender ? peerRank : epRank; // record index inside that region
        fifoInfoPtrU64 += strideIndice * rankStrideInU64;
        MoeCommFifoConnInfo* fifoInfoPtr = (MoeCommFifoConnInfo*) fifoInfoPtrU64;
        return fifoInfoPtr + fifoInfoIndice * channelCount + channel;
    }
#endif
};
|
|
|
|
// Free-function entry point for capping the communication SM budget
// (presumably forwards to AllToAllChannelCommunicatorBase::setMaxUsableSmCount;
// definition lives in the .cu file — TODO confirm).
void setMaxUsableSmCount(int smCount);

// Launches the MoE all-to-all exchange on `stream`, moving vectors described by
// sendRecvDataInfo from sendDispls-addressed buffers into recvDispls-addressed
// buffers through `workspace`.
void moeAllToAll(MoeEpWorldInfo worldInfo, SendRecvDataInfo sendRecvDataInfo, SendRecvDispls sendDispls,
    SendRecvDispls recvDispls, MoeCommWorkspace workspace, cudaStream_t stream);

// Builds the cumulative-sum and local-index tables consumed by moeAllToAll from
// the gathered per-token target rank ids.
void moeAllToAllPrepareIndices(MoeEpWorldInfo worldInfo, MoeExpertParallelInfo expertParallelInfo,
    int maxTokenCountPerRank, int const* gatheredTargetRankIds, int const* realRankTokenCountCumSum,
    int* localGatheredIndices, // indices of gatheredTargetRankIds that has the local rank in topK
    int* sendRankCountCumSum, int* sendRankLocalIndices, int* recvRankCountCumSum, int* recvRankLocalIndices,
    // the rankCountCumSum of combineRecv should be the same as sendRankCountCumSum
    int* backwardRecvRankLocalIndices, cudaStream_t stream);

// Gathers expert ids and scales for locally held tokens out of the gathered
// buffers, using the index tables produced by moeAllToAllPrepareIndices.
void moeLocalGather(MoeEpWorldInfo worldInfo, MoeExpertParallelInfo expertParallelInfo, int maxTokenCountPerRank,
    int localMaxTokenCount, int const* recvRankCountCumSum, int const* localGatherIndices, int const* gatheredExpertIds,
    float const* gatheredScales, int* localExpertIds, float* localScales, cudaStream_t stream);
|
|
|
|
} // namespace tensorrt_llm::kernels
|