mirror of
https://github.com/NVIDIA/TensorRT-LLM.git
synced 2026-01-14 06:27:45 +08:00
676 lines
25 KiB
Plaintext
676 lines
25 KiB
Plaintext
/*
|
|
* Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved.
|
|
*
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
* you may not use this file except in compliance with the License.
|
|
* You may obtain a copy of the License at
|
|
*
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
* See the License for the specific language governing permissions and
|
|
* limitations under the License.
|
|
*/
|
|
|
|
#include "moePrepareKernels.h"
|
|
|
|
#include <stdio.h>
|
|
|
|
#include <cooperative_groups.h>
|
|
#include <cooperative_groups/reduce.h>
|
|
#include <cub/cub.cuh>
|
|
|
|
namespace cg = cooperative_groups;
|
|
|
|
namespace tensorrt_llm::kernels
|
|
{
|
|
|
|
namespace moe_prepare
|
|
{
|
|
|
|
// Stores `val` to the global address `ptr` with a system-scope (.sys) release store.
// The "memory" clobber also stops the compiler from sinking earlier memory accesses below the store.
__device__ __forceinline__ void st_release_sys_global(uint64_t volatile* ptr, uint64_t val)
{
    asm volatile("st.release.sys.global.u64 [%[addr]], %[value];" ::[addr] "l"(ptr), [value] "l"(val) : "memory");
}
|
|
|
|
// Loads a 64-bit value from the global address `ptr` with a system-scope (.sys) acquire load.
// The "memory" clobber is required for the acquire to hold at the compiler level too: without it the asm statement
// is assumed not to touch memory, so later loads of the data guarded by this flag could be hoisted above it.
__device__ __forceinline__ uint64_t ld_acquire_sys_global(uint64_t volatile* ptr)
{
    uint64_t ret;
    asm volatile("ld.acquire.sys.global.u64 %0, [%1];" : "=l"(ret) : "l"(ptr) : "memory");
    return ret;
}
|
|
|
|
// Loads a 32-bit signed value from the global address `ptr` with a system-scope (.sys) acquire load.
// As for the 64-bit variant, the "memory" clobber keeps the compiler from moving later memory accesses above the load.
__device__ __forceinline__ int ld_acquire_sys_global_int(int volatile* ptr)
{
    int ret;
    asm volatile("ld.acquire.sys.global.s32 %0, [%1];" : "=r"(ret) : "l"(ptr) : "memory");
    return ret;
}
|
|
|
|
// Step-granular flow control over one MoeCommFifoConnInfo.
// `head` counts the steps published through releaseSendStep() and `tail` the steps handed back through
// releaseRecvStep(). Both are written with system-scope release stores and read with system-scope acquire loads.
// localCachedHead / localCachedTail hold this side's latest view of those counters; a counter value maps to a ring
// slot as counter % STEP_DEPTH.
class StepCommunicatorBase
{
public:
    static constexpr int META_SIZE = sizeof(MoeCommFifoConnInfo);

    __device__ __inline__ StepCommunicatorBase(MoeCommFifoConnInfo* fifoConnInfo)
        : fifoConnInfo(fifoConnInfo)
        , localCachedHead(0)
        , localCachedTail(0)
    {
    }

    // Clears both shared counters with plain stores; the cached copies are not touched.
    __forceinline__ __device__ void reset()
    {
        fifoConnInfo->head = 0;
        fifoConnInfo->tail = 0;
    }

    // Publishes one more filled step by bumping `head`.
    __forceinline__ __device__ void releaseSendStep()
    {
        ++localCachedHead;
        st_release_sys_global(&(fifoConnInfo->head), uint64_t(localCachedHead));
    }

    // Hands one more consumed step back by bumping `tail`.
    __forceinline__ __device__ void releaseRecvStep()
    {
        ++localCachedTail;
        st_release_sys_global(&(fifoConnInfo->tail), uint64_t(localCachedTail));
    }

    // Refreshes the cached `tail` from shared memory and returns it.
    __forceinline__ __device__ uint64_t acquireTail()
    {
        localCachedTail = ld_acquire_sys_global(&(fifoConnInfo->tail));
        return localCachedTail;
    }

    // Refreshes the cached `head` from shared memory and returns it.
    __forceinline__ __device__ uint64_t acquireHead()
    {
        localCachedHead = ld_acquire_sys_global(&(fifoConnInfo->head));
        return localCachedHead;
    }

    // Spins until fewer than STEP_DEPTH published steps are still outstanding, then returns the ring slot of the next
    // step to fill. Example with STEP_DEPTH == 2: head 1 / tail 0 proceeds, head 2 / tail 0 keeps waiting.
    __forceinline__ __device__ int acquireNewSendStep()
    {
        uint64_t drained = acquireTail();
        while (localCachedHead >= drained + STEP_DEPTH)
        {
            drained = acquireTail();
        }
        return localCachedHead % STEP_DEPTH;
    }

    // Spins until at least one published step has not been consumed yet, then returns that step's ring slot.
    __forceinline__ __device__ int acquireNewRecvStep()
    {
        uint64_t published = acquireHead();
        while (localCachedTail >= published)
        {
            published = acquireHead();
        }
        return localCachedTail % STEP_DEPTH;
    }

public:
    MoeCommFifoConnInfo* fifoConnInfo;
    uint64_t localCachedHead;
    uint64_t localCachedTail;
    // NOTE(review): rank / targetRank are not initialized by the constructor and no use is visible here.
    int rank;
    int targetRank;
};
|
|
|
|
// One-shot counter exchange through a MoeCommFifoConnInfo, using only its `count` field.
// The writer stores value + 1 with a system-scope release store, so that 0 can mean "nothing published yet".
// The reader spins with acquire loads until `count` becomes non-zero, clears it back to 0 and returns the decoded
// value. The `head` / `tail` fields are not used.
class CounterCommunicator
{
public:
    __device__ __inline__ CounterCommunicator(MoeCommFifoConnInfo* fifoConnInfo)
        : fifoConnInfo(fifoConnInfo)
    {
    }

    // Publishes `value`, encoded as value + 1 so that a published 0 cannot be mistaken for "not written".
    __forceinline__ __device__ void releaseValue(uint64_t value)
    {
        uint64_t const encoded = value + 1;
        st_release_sys_global(&(fifoConnInfo->count), encoded);
    }

    // Waits for a published value, consumes it (resets `count` to 0) and returns the decoded value.
    __forceinline__ __device__ uint64_t acquireValue()
    {
        uint64_t encoded = ld_acquire_sys_global(&(fifoConnInfo->count));
        while (encoded == 0)
        {
            encoded = ld_acquire_sys_global(&(fifoConnInfo->count));
        }

        fifoConnInfo->count = 0;
        return encoded - 1;
    }

protected:
    MoeCommFifoConnInfo* fifoConnInfo;
};
|
|
|
|
// Sender half of computeCountAndIndiceDevice, run by blocks with blockIdx.x < epSize; block b handles target rank b.
// Each tile of kThreadsGroupSize lanes walks the local tokens; lane j inspects the token's j-th top-k choice, which
// maps to rank expert / (expertCount / epSize). The caller picks kThreadsGroupSize >= topK so every choice has a lane.
// A token is recorded once if any of its choices lives on the target rank: the send list gets the token index and
// the backward list gets token * topK + j for the first matching lane j. Slot order depends on atomicAdd_block, so it
// is not deterministic. Thread 0 then publishes the final count to the target rank and stores it in
// sendCounts[target]. *sharedSendRecvRankCount is used as the block-wide slot counter.
template <int kThreadsGroupSize>
__device__ __forceinline__ void computeCountAndSend(int* experts, int tokenCount, int* sharedSendRecvRankCount,
    int* sendCounts, int* sendIndiceWorkspace, int* backwardIndiceWorkspace, MoeCommWorkspace workspace,
    int maxTokenCountPerRank, int expertCount, int topK, int epRank, int epSize)
{
    cg::thread_block_tile<kThreadsGroupSize> tile = cg::tiled_partition<kThreadsGroupSize>(cg::this_thread_block());
    int const lane = tile.thread_rank();
    int const firstToken = threadIdx.x / kThreadsGroupSize;
    int const tokenStride = blockDim.x / kThreadsGroupSize;
    int const expertsPerRank = expertCount / epSize;

    // Reset the block-wide slot counter before any tile appends to it.
    if (threadIdx.x == 0)
    {
        *sharedSendRecvRankCount = 0;
    }
    __syncthreads();

    int const targetRankId = blockIdx.x;
    if (targetRankId >= epSize)
    {
        return;
    }

    int* const sendList = sendIndiceWorkspace + targetRankId * maxTokenCountPerRank;
    int* const backwardList = backwardIndiceWorkspace + targetRankId * maxTokenCountPerRank;

    for (int token = firstToken; token < tokenCount; token += tokenStride)
    {
        // Lanes beyond topK report the out-of-range rank epSize so they can never match.
        int const choiceRank = (lane < topK) ? experts[token * topK + lane] / expertsPerRank : epSize;
        int const matchMask = tile.ballot(choiceRank == targetRankId);
        if (matchMask != 0 && lane == 0)
        {
            int const slot = atomicAdd_block(sharedSendRecvRankCount, 1);
            sendList[slot] = token;
            backwardList[slot] = token * topK + (__ffs(matchMask) - 1);
        }
        tile.sync();
    }
    __syncthreads();

    if (threadIdx.x == 0)
    {
        CounterCommunicator counter(workspace.getFifoConnInfo(true, epRank, targetRankId, 0, epSize, 1));
        int const total = *sharedSendRecvRankCount;
        counter.releaseValue(uint64_t(total));
        sendCounts[targetRankId] = total;
    }
}
|
|
|
|
// Receiver half of computeCountAndIndiceDevice, run by blocks with blockIdx.x >= rankCount.
// Each block hosts PIPELINE_PER_CTA pipelines of THREADS_PER_PIPELINE threads. Pipeline p of block b serves source
// rank (b - rankCount) * PIPELINE_PER_CTA + p. Its first lane waits for that rank's token count, stores it in
// recvCounts and in the pipeline's shared slot; then the pipeline writes 0..count-1 into its recvIndice row.
// NOTE(review): the fill loop strides by UNIT_PER_PIPELINE while the tile has THREADS_PER_PIPELINE threads;
// presumably these are equal — confirm against the header.
__device__ __forceinline__ void recvCount(int* recvIndiceWorkspace, int* recvCounts, int* sharedCountsBase,
    MoeCommWorkspace workspace, int maxTokenCountPerRank, int rankId, int rankCount)
{
    int const pipelineInCta = threadIdx.x / THREADS_PER_PIPELINE;
    if (pipelineInCta >= PIPELINE_PER_CTA)
    {
        return;
    }

    int const sourceRankId = (blockIdx.x - rankCount) * PIPELINE_PER_CTA + pipelineInCta;
    if (sourceRankId >= rankCount)
    {
        return;
    }

    int* const sharedCount = sharedCountsBase + pipelineInCta;
    int* const recvIndiceRow = recvIndiceWorkspace + sourceRankId * maxTokenCountPerRank;
    cg::thread_block_tile<THREADS_PER_PIPELINE> pipelineTile
        = cg::tiled_partition<THREADS_PER_PIPELINE>(cg::this_thread_block());

    if (pipelineTile.thread_rank() == 0)
    {
        CounterCommunicator counter(workspace.getFifoConnInfo(false, rankId, sourceRankId, 0, rankCount, 1));
        int const arrived = int(counter.acquireValue());
        recvCounts[sourceRankId] = arrived;
        *sharedCount = arrived;
    }
    pipelineTile.sync();

    int const rankRecvCount = *sharedCount;
    for (int tokenId = threadIdx.x % UNIT_PER_PIPELINE; tokenId < rankRecvCount; tokenId += UNIT_PER_PIPELINE)
    {
        recvIndiceRow[tokenId] = tokenId;
    }
}
|
|
|
|
// Combined count/exchange kernel. Blocks [0, rankCount) run computeCountAndSend (one per target rank); the remaining
// blocks run recvCount. The static shared array provides one counter slot per receive pipeline; senders use slot 0.
template <int kThreadsGroupSize>
__global__ void computeCountAndIndiceDevice(int* experts, int* sendCounts, int* recvCounts, int* sendIndiceWorkspace,
    int* backwardIndiceWorkspace, int* recvIndiceWorkspace, MoeCommWorkspace workspace, int tokenCount,
    int maxTokenCountPerRank, int topK, int expertCount, int rankId, int rankCount)
{
    __shared__ int sharedCounts[PIPELINE_PER_CTA];

    if (blockIdx.x >= rankCount)
    {
        recvCount(recvIndiceWorkspace, recvCounts, sharedCounts, workspace, maxTokenCountPerRank, rankId, rankCount);
        return;
    }

    computeCountAndSend<kThreadsGroupSize>(experts, tokenCount, sharedCounts, sendCounts, sendIndiceWorkspace,
        backwardIndiceWorkspace, workspace, maxTokenCountPerRank, expertCount, topK, rankId, rankCount);
}
|
|
|
|
// Compacts the per-rank index rows into contiguous gathered arrays.
// Launch: gridDim.x = rankCount, gridDim.y = 2. Block (r, 0) copies the r-th rows of sendIndice / backwardIndice
// (row stride maxTokenCountPerRank) to [cumsum[r-1], cumsum[r]) of the gathered arrays using sendCountsCumsum.
// Block (r, 1) fills the same range of gatherRecvIndice, taken from recvCountsCumsum, with its own positions.
// The `recvIndice` parameter is not read.
__global__ void moveIndiceDevice(int* sendCountsCumsum, int* recvCountsCumsum, int* sendIndice, int* gatherSendIndice,
    int* backwardIndice, int* gatherBackwardIndice, int* recvIndice, int* gatherRecvIndice, int maxTokenCountPerRank)
{
    int const targetRankId = blockIdx.x;
    bool const sendSide = (blockIdx.y == 0);
    int const* const cumsum = sendSide ? sendCountsCumsum : recvCountsCumsum;
    int const begin = (targetRankId == 0) ? 0 : cumsum[targetRankId - 1];
    int const count = cumsum[targetRankId] - begin;

    if (!sendSide)
    {
        for (int k = threadIdx.x; k < count; k += blockDim.x)
        {
            gatherRecvIndice[begin + k] = begin + k;
        }
        return;
    }

    int const* const sendRow = sendIndice + targetRankId * maxTokenCountPerRank;
    int const* const backwardRow = backwardIndice + targetRankId * maxTokenCountPerRank;
    for (int k = threadIdx.x; k < count; k += blockDim.x)
    {
        gatherSendIndice[begin + k] = sendRow[k];
        gatherBackwardIndice[begin + k] = backwardRow[k];
    }
}
|
|
|
|
// In-place inclusive prefix sum over the first rankCount ints of one array, one element per thread.
// Launched with 2 blocks of CUMSUM_THREADS_PER_BLOCK threads: block 0 scans sendCountsCumsum and block 1 scans
// recvCountsCumsum. Entries at index >= CUMSUM_THREADS_PER_BLOCK are not covered. `rankId` is unused.
__global__ void computeCumsumDevice(int* sendCountsCumsum, int* recvCountsCumsum, int rankId, int rankCount)
{
    using BlockScan = cub::BlockScan<int, CUMSUM_THREADS_PER_BLOCK>;
    __shared__ typename BlockScan::TempStorage scanStorage;

    int* const values = (blockIdx.x == 0) ? sendCountsCumsum : recvCountsCumsum;
    int const tid = threadIdx.x;
    bool const active = tid < rankCount;

    int running = active ? values[tid] : 0;
    __syncthreads();

    BlockScan(scanStorage).InclusiveSum(running, running);
    if (active)
    {
        values[tid] = running;
    }
}
|
|
|
|
// Packetized, block-cooperative view of one FIFO buffer.
// The buffer is a ring of steps, each holding PipelineConfig::PACKET_PER_STEP packets of PipelineConfig::PACKET_SIZE
// bytes. Flow control is per step through a StepCommunicatorBase. Only thread 0 talks to the communicator; the step
// index it obtains is broadcast to the block through `shared_new_step`. Every thread of the block must call the same
// methods in the same order, because several of them execute __syncthreads().
template <typename PipelineConfig>
class PacketPipeline
{
public:
    __device__ __inline__ PacketPipeline(
        void* bufferBase, StepCommunicatorBase* stepCommunicator, int* sharedNewStepPtr, bool isSender)
        : bufferBase(bufferBase)
        , stepCommunicator(stepCommunicator)
        , shared_new_step(sharedNewStepPtr)
    {
        step = 0;
        needRelease = false;
        // The receiver starts on the last packet of a virtual step so its first getNewRecvPacket() acquires a step.
        packetId = isSender ? 0 : PipelineConfig::PACKET_PER_STEP - 1;
    }

    // First packet for the sender: packet 0 of step 0, used without acquiring it.
    // NOTE(review): this assumes the ring is empty when the kernel starts.
    __device__ __forceinline__ void* getFirstSendPacket()
    {
        return bufferBase;
    }

    // Marks the current send packet as written.
    // Returns the next packet to write when `acquireNewStep` is true, nullptr otherwise. When the current step becomes
    // full it is released after a block barrier; if another step is wanted, thread 0 waits for a free slot.
    __device__ __inline__ void* finishSendPacket(bool acquireNewStep)
    {
        packetId++;
        if (packetId < PipelineConfig::PACKET_PER_STEP)
        {
            return acquireNewStep ? packetAddress(step, packetId) : nullptr;
        }

        __syncthreads();
        if (threadIdx.x == 0)
        {
            stepCommunicator->releaseSendStep();
            if (acquireNewStep)
            {
                step = stepCommunicator->acquireNewSendStep();
                *(shared_new_step) = step;
            }
        }
        __syncthreads();

        // The full step has just been released. Clearing packetId here (also when no new step is acquired) keeps
        // sendFinalize() from releasing the same step a second time.
        packetId = 0;
        if (acquireNewStep)
        {
            step = *(shared_new_step);
            return packetAddress(step, 0);
        }

        return nullptr;
    }

    // Releases the partially filled last step, if any. The barrier makes sure every thread has finished writing
    // its part of the last packet before thread 0's release store publishes the step.
    __device__ __forceinline__ void sendFinalize()
    {
        if (packetId == 0)
        {
            return;
        }

        __syncthreads();
        if (threadIdx.x == 0)
        {
            stepCommunicator->releaseSendStep();
        }
    }

    // Advances the receiver to its next packet. When the current step is exhausted, the block synchronizes.
    // Thread 0 then hands the previous step back (except before the very first step) and waits for the next
    // published step, whose index is broadcast through shared memory.
    __device__ __inline__ void* getNewRecvPacket()
    {
        packetId++;
        if (packetId < PipelineConfig::PACKET_PER_STEP)
        {
            return packetAddress(step, packetId);
        }

        __syncthreads();
        if (threadIdx.x == 0)
        {
            if (needRelease)
            {
                stepCommunicator->releaseRecvStep();
            }
            step = stepCommunicator->acquireNewRecvStep();
            needRelease = true;
            *(shared_new_step) = step;
        }
        __syncthreads();

        packetId = 0;
        step = *(shared_new_step);
        return packetAddress(step, 0);
    }

    // Thread 0 clears the shared head/tail counters of the connection. No barrier is executed here.
    __device__ __forceinline__ void reset()
    {
        if (threadIdx.x == 0)
        {
            stepCommunicator->reset();
        }
    }

    void* bufferBase;
    StepCommunicatorBase* stepCommunicator;
    int step;
    int packetId;
    bool needRelease;
    int* shared_new_step;

private:
    // Byte address of packet `packetIdx` in ring step `stepIdx` (byte-pointer arithmetic instead of void*).
    __device__ __forceinline__ void* packetAddress(int stepIdx, int packetIdx) const
    {
        return static_cast<uint8_t*>(bufferBase)
            + (stepIdx * PipelineConfig::PACKET_PER_STEP + packetIdx) * PipelineConfig::PACKET_SIZE;
    }
};
|
|
|
|
// Metadata all-to-all over the per-peer FIFOs.
// Launch (see allToAllMetadata): grid = (rankCount, 2). Blocks with blockIdx.y == 0 send to rank blockIdx.x; blocks
// with blockIdx.y == 1 receive from rank blockIdx.x. Block size is UNIT_PER_ITER, plus STATIC_COPY_PER_ITER when
// localExpertStatics is non-null.
// Each loop iteration moves one packet. Threads [0, UNIT_PER_ITER) each carry one unit of UNIT_SIZE expert ids (and
// scales when present); the remaining threads each carry 4 expert-static ints.
// Preconditions: topK % UNIT_SIZE == 0; expertCount % 4 == 0 when localExpertStatics != nullptr (int4 copies).
// `tokenCount` and `localRecvIndice` are not read.
template <typename PipelineConfig, typename ExpertType, typename ScaleType>
__global__ void allToAllMetadataDevice(int* sendExperts, int* recvExperts, float* sendScales, float* recvScales,
    int* localExpertStatics, int* gatheredExpertStatics, MoeCommWorkspace workspace, int* sendCountsCumsum,
    int* localSendIndice, int* recvCountsCumsum, int* localRecvIndice, int tokenCount, int maxTokenCountPerRank,
    int topK, int expertCount, int slotCount, int rankId, int rankCount)
{
    bool isSender = (blockIdx.y == 0);
    int targetRankId = blockIdx.x;
    int slotCountPerRank = slotCount / rankCount;
    int groupSize = topK / PipelineConfig::UNIT_SIZE;
    bool const hasStatics = localExpertStatics != nullptr;

    __shared__ int sharedNewStep;
    __align__(16) int experts[PipelineConfig::UNIT_SIZE];
    __align__(16) float scales[PipelineConfig::UNIT_SIZE];

    uint8_t* bufferBase = (uint8_t*) (workspace.getFifoBasePtr(isSender, rankId, targetRankId, 0, 1));
    StepCommunicatorBase stepCommunicator(workspace.getFifoConnInfo(isSender, rankId, targetRankId, 0, rankCount, 1));
    PacketPipeline<PipelineConfig> pipeline(bufferBase, &stepCommunicator, &sharedNewStep, isSender);

    if (isSender)
    {
        int baseCumsum = targetRankId == 0 ? 0 : *(sendCountsCumsum + targetRankId - 1);
        int sendTokenCount = *(sendCountsCumsum + targetRankId) - baseCumsum;
        int unitCount = sendTokenCount * topK / PipelineConfig::UNIT_SIZE;

        // Byte pointer so that the SCALE_OFFSET / STATIC_COPY_OFFSET arithmetic below is well defined.
        uint8_t* packPtr = static_cast<uint8_t*>(pipeline.getFirstSendPacket());
        int indexBase = 0;
        int staticCopyBase = 0;
        bool acquireNewStep = unitCount > 0 || (hasStatics && expertCount > 0);
        while (acquireNewStep)
        {
            if (threadIdx.x < UNIT_PER_ITER)
            {
                int index = indexBase + threadIdx.x;
                int groupId = index % groupSize;
                if (index < unitCount)
                {
                    int tokenId = *(localSendIndice + maxTokenCountPerRank * targetRankId + (index / groupSize));
                    *((ExpertType*) (experts))
                        = *(ExpertType*) (sendExperts + tokenId * topK + groupId * PipelineConfig::UNIT_SIZE);

#pragma unroll
                    for (int j = 0; j < PipelineConfig::UNIT_SIZE; j++)
                    {
                        // Choices whose slot is not on the target rank are replaced by slotCount.
                        int expertId = experts[j];
                        if (expertId / slotCountPerRank != targetRankId)
                        {
                            experts[j] = slotCount;
                        }
                    }

                    int* expertsPtr = (int*) (packPtr) + threadIdx.x * PipelineConfig::UNIT_SIZE;
                    *((ExpertType*) (expertsPtr)) = *((ExpertType*) (experts));
                    if (sendScales != nullptr)
                    {
                        *((ScaleType*) (scales))
                            = *(ScaleType*) (sendScales + tokenId * topK + groupId * PipelineConfig::UNIT_SIZE);
                        float* scaleBasePtr = (float*) (packPtr + PipelineConfig::SCALE_OFFSET);
                        float* scalesPtr = (float*) (scaleBasePtr) + threadIdx.x * PipelineConfig::UNIT_SIZE;
                        *((ScaleType*) (scalesPtr)) = *((ScaleType*) (scales));
                    }
                }
            }
            else if (hasStatics)
            {
                int staticCopyIdx = threadIdx.x - UNIT_PER_ITER;
                if (staticCopyBase + staticCopyIdx * 4 < expertCount)
                {
                    int4* staticBasePtr = (int4*) (packPtr + PipelineConfig::STATIC_COPY_OFFSET);
                    int4 staticData = *(int4*) (localExpertStatics + staticCopyBase + staticCopyIdx * 4);
                    *(staticBasePtr + staticCopyIdx) = staticData;
                }
            }

            indexBase += UNIT_PER_ITER;
            staticCopyBase += STATIC_COPY_PER_ITER * 4;
            // Must mirror the receiver's loop condition exactly. Without the hasStatics guard, a sender with no
            // statics kept emitting packets for expertCount that the receiver never consumes.
            acquireNewStep = indexBase < unitCount || (hasStatics && staticCopyBase < expertCount);
            packPtr = static_cast<uint8_t*>(pipeline.finishSendPacket(acquireNewStep));
        }

        pipeline.sendFinalize();
    }
    else
    {
        int baseCumsum = targetRankId == 0 ? 0 : *(recvCountsCumsum + targetRankId - 1);
        int recvTokenCount = *(recvCountsCumsum + targetRankId) - baseCumsum;
        int recvUnitCount = recvTokenCount * groupSize;

        int unitIdBase = 0;
        int staticCopyBase = 0;
        while (unitIdBase < recvUnitCount || (hasStatics && staticCopyBase < expertCount))
        {
            uint8_t* packetPtr = static_cast<uint8_t*>(pipeline.getNewRecvPacket());
            int packetUnitCount
                = unitIdBase + UNIT_PER_ITER < recvUnitCount ? UNIT_PER_ITER : recvUnitCount - unitIdBase;
            packetUnitCount = max(packetUnitCount, 0);
            if (threadIdx.x < UNIT_PER_ITER)
            {
                if (threadIdx.x < packetUnitCount)
                {
                    int tokenId = baseCumsum + (unitIdBase + threadIdx.x) / groupSize;
                    int groupId = (unitIdBase + threadIdx.x) % groupSize;
                    int* expertsPtr = (int*) (packetPtr) + threadIdx.x * PipelineConfig::UNIT_SIZE;
                    *((ExpertType*) (experts)) = *((ExpertType*) (expertsPtr));
                    ExpertType* dstExpertsPtr
                        = (ExpertType*) (recvExperts + tokenId * topK + groupId * PipelineConfig::UNIT_SIZE);
                    *dstExpertsPtr = *((ExpertType*) (experts));

                    if (recvScales != nullptr)
                    {
                        float* scaleBasePtr = (float*) (packetPtr + PipelineConfig::SCALE_OFFSET);
                        float* scalesPtr = scaleBasePtr + threadIdx.x * PipelineConfig::UNIT_SIZE;
                        *((ScaleType*) (scales)) = *((ScaleType*) (scalesPtr));
                        ScaleType* dstScalesPtr
                            = (ScaleType*) (recvScales + tokenId * topK + groupId * PipelineConfig::UNIT_SIZE);
                        *dstScalesPtr = *((ScaleType*) (scales));
                    }
                }
            }
            else if (hasStatics)
            {
                int staticCopyIdx = threadIdx.x - UNIT_PER_ITER;
                if (staticCopyBase + staticCopyIdx * 4 < expertCount)
                {
                    int4* staticBasePtr = (int4*) (packetPtr + PipelineConfig::STATIC_COPY_OFFSET);
                    int4 staticData = *(staticBasePtr + staticCopyIdx);
                    *(int4*) (gatheredExpertStatics + targetRankId * expertCount + staticCopyBase + staticCopyIdx * 4)
                        = staticData;
                }
            }

            unitIdBase += packetUnitCount;
            staticCopyBase += STATIC_COPY_PER_ITER * 4;
        }

        pipeline.reset();
    }
}
|
|
|
|
// Grid-stride fill of expertIds[totalRecv * topK, maxTokenCountPerRank * rankCount * topK) with `slotCount`.
// totalRecv is the last entry of recvCountsCumsum, i.e. the number of tokens actually received.
__global__ void memsetExpertIdsDevice(
    int* expertIds, int* recvCountsCumsum, int maxTokenCountPerRank, int topK, int slotCount, int rankCount)
{
    int const firstPadded = *(recvCountsCumsum + rankCount - 1) * topK;
    int const entryCount = maxTokenCountPerRank * rankCount * topK;
    int const stride = gridDim.x * blockDim.x;

    for (int pos = firstPadded + blockIdx.x * blockDim.x + threadIdx.x; pos < entryCount; pos += stride)
    {
        expertIds[pos] = slotCount;
    }
}
|
|
|
|
// Host launcher for computeCountAndIndiceDevice on `stream`.
// Grid: rankCount sender blocks followed by ceil(rankCount / PIPELINE_PER_CTA) receiver blocks, 1024 threads each.
// The tile-size template argument is the smallest power of two (at most 32) that is >= topK.
void computeCountAndIndice(int* experts, int* sendCounts, int* recvCounts, int* sendIndiceWorkspace,
    int* backwardIndiceWorkspace, int* recvIndiceWorkspace, MoeCommWorkspace workspace, int tokenCount,
    int maxTokenCountPerRank, int topK, int expert_count, int rankId, int rankCount, cudaStream_t stream)
{
    TLLM_CHECK_WITH_INFO(topK >= 1 && topK <= 32, "Only 1 <= topK <= 32 is supported now.");

    int const senderBlocks = rankCount;
    int const receiverBlocks = (rankCount + PIPELINE_PER_CTA - 1) / PIPELINE_PER_CTA;
    dim3 grid(senderBlocks + receiverBlocks);
    dim3 block(1024);

    auto* kernelFn = computeCountAndIndiceDevice<32>;
    if (topK <= 1)
    {
        kernelFn = computeCountAndIndiceDevice<1>;
    }
    else if (topK <= 2)
    {
        kernelFn = computeCountAndIndiceDevice<2>;
    }
    else if (topK <= 4)
    {
        kernelFn = computeCountAndIndiceDevice<4>;
    }
    else if (topK <= 8)
    {
        kernelFn = computeCountAndIndiceDevice<8>;
    }
    else if (topK <= 16)
    {
        kernelFn = computeCountAndIndiceDevice<16>;
    }

    kernelFn<<<grid, block, 0, stream>>>(experts, sendCounts, recvCounts, sendIndiceWorkspace, backwardIndiceWorkspace,
        recvIndiceWorkspace, workspace, tokenCount, maxTokenCountPerRank, topK, expert_count, rankId, rankCount);
}
|
|
|
|
// Host launcher for computeCumsumDevice on `stream`: 2 blocks of CUMSUM_THREADS_PER_BLOCK threads. Block 0 turns the
// first rankCount ints of sendCountsCumsum, and block 1 those of recvCountsCumsum, into inclusive prefix sums in place.
void computeCumsum(int* sendCountsCumsum, int* recvCountsCumsum, int rankId, int rankCount, cudaStream_t stream)
{
    // The kernel scans one rank per thread of a single block; any rank beyond the block size would be left
    // unscanned and silently produce wrong offsets downstream.
    TLLM_CHECK_WITH_INFO(
        rankCount <= CUMSUM_THREADS_PER_BLOCK, "rankCount must not exceed CUMSUM_THREADS_PER_BLOCK in computeCumsum.");

    int block_size = CUMSUM_THREADS_PER_BLOCK;
    dim3 block(block_size);
    dim3 grid(2);
    computeCumsumDevice<<<grid, block, 0, stream>>>(sendCountsCumsum, recvCountsCumsum, rankId, rankCount);
}
|
|
|
|
// Host launcher for moveIndiceDevice on `stream`: a (rankCount, 2) grid of 512-thread blocks. Row y == 0 gathers the
// send/backward indices and row y == 1 writes the receive indices. `rankId` is not used.
void moveIndice(int* sendCountsCumsum, int* recvCountsCumsum, int* sendIndice, int* gatherSendIndice,
    int* backwardIndice, int* gatherBackwardIndice, int* recvIndice, int* gatherRecvIndice, int rankId, int rankCount,
    int maxTokenCountPerRank, cudaStream_t stream)
{
    dim3 const grid(rankCount, 2);
    dim3 const block(512);
    moveIndiceDevice<<<grid, block, 0, stream>>>(sendCountsCumsum, recvCountsCumsum, sendIndice, gatherSendIndice,
        backwardIndice, gatherBackwardIndice, recvIndice, gatherRecvIndice, maxTokenCountPerRank);
}
|
|
|
|
// Host launcher for the metadata all-to-all on `stream`, followed by padding of recvExperts.
// allToAllMetadataDevice runs on a (rankCount, 2) grid. Its block size is UNIT_PER_ITER, plus STATIC_COPY_PER_ITER
// extra threads when expert statics are exchanged. The int4/float4 vectorized configuration is used only when topK is
// a multiple of 4. Afterwards memsetExpertIdsDevice fills the unused tail of recvExperts with slotCount.
void allToAllMetadata(int* sendExperts, int* recvExperts, float* sendScales, float* recvScales, int* localExpertStatics,
    int* gatheredExpertStatics, MoeCommWorkspace workspace, int* sendCountsCumsum, int* localSendIndice,
    int* recvCountsCumsum, int* localRecvIndice, int tokenCount, int maxTokenCountPerRank, int topK, int expertCount,
    int slotCount, int rankId, int rankCount, cudaStream_t stream)
{
    // Statics are moved as int4 vectors into per-rank rows of expertCount ints; any other count misaligns those rows
    // and reads/writes past their end on the device.
    TLLM_CHECK_WITH_INFO(localExpertStatics == nullptr || expertCount % 4 == 0,
        "expertCount must be a multiple of 4 when expert statics are exchanged.");

    int block_size = localExpertStatics == nullptr ? UNIT_PER_ITER : UNIT_PER_ITER + STATIC_COPY_PER_ITER;
    dim3 block(block_size);
    dim3 grid(rankCount, 2);

    if (topK % 4 == 0)
    {
        // Named Config to avoid shadowing the PipelineConfig template inside this scope.
        using Config = PipelineConfig<4, 16>;
        static_assert(Config::PACKET_SIZE_IN_U64 * Config::PACKET_PER_STEP * STEP_DEPTH <= FIFO_SIZE_IN_U64,
            "FIFO size is too small");
        allToAllMetadataDevice<Config, int4, float4><<<grid, block, 0, stream>>>(sendExperts, recvExperts, sendScales,
            recvScales, localExpertStatics, gatheredExpertStatics, workspace, sendCountsCumsum, localSendIndice,
            recvCountsCumsum, localRecvIndice, tokenCount, maxTokenCountPerRank, topK, expertCount, slotCount, rankId,
            rankCount);
    }
    else
    {
        using Config = PipelineConfig<1, 64>;
        static_assert(Config::PACKET_SIZE_IN_U64 * Config::PACKET_PER_STEP * STEP_DEPTH <= FIFO_SIZE_IN_U64,
            "FIFO size is too small");
        allToAllMetadataDevice<Config, int, float><<<grid, block, 0, stream>>>(sendExperts, recvExperts, sendScales,
            recvScales, localExpertStatics, gatheredExpertStatics, workspace, sendCountsCumsum, localSendIndice,
            recvCountsCumsum, localRecvIndice, tokenCount, maxTokenCountPerRank, topK, expertCount, slotCount, rankId,
            rankCount);
    }

    int smCount = tensorrt_llm::common::getMultiProcessorCount();
    memsetExpertIdsDevice<<<smCount, 256, 0, stream>>>(
        recvExperts, recvCountsCumsum, maxTokenCountPerRank, topK, slotCount, rankCount);
}
|
|
|
|
// Workspace bytes for `epSize` peers: per peer, one FIFO of FIFO_SIZE_IN_U64 64-bit words plus one
// MoeCommFifoConnInfo (StepCommunicatorBase::META_SIZE bytes).
size_t getMoePrepareWorkspaceSize(int epSize)
{
    size_t const fifoBytes = FIFO_SIZE_IN_U64 * sizeof(uint64_t);
    size_t const bytesPerPeer = fifoBytes + StepCommunicatorBase::META_SIZE;
    return bytesPerPeer * epSize;
}
|
|
|
|
} // namespace moe_prepare
|
|
|
|
} // namespace tensorrt_llm::kernels
|