TensorRT-LLMs/cpp/tensorrt_llm/kernels/moePrepareKernels.h
WeiHaocheng fddb7f1141
feat: moe prepare support topk % 4 != 0 (#5742)
Signed-off-by: Fred Wei <20514172+WeiHaocheng@users.noreply.github.com>
2025-07-22 10:42:46 +08:00

130 lines
4.7 KiB
C++

/*
* Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include <map>
#include "tensorrt_llm/common/cudaUtils.h"
// Debug switch for the MoE prepare pipeline. It is not referenced in this header; presumably the
// implementation tests it with `#if DEBUG_PIPELINE` (TODO confirm). Guarded so a build can override it
// (e.g. -DDEBUG_PIPELINE=1) without a macro-redefinition diagnostic.
#ifndef DEBUG_PIPELINE
#define DEBUG_PIPELINE 0
#endif
namespace tensorrt_llm::kernels
{
namespace moe_prepare
{
// NOTE: these are preprocessor macros, so although they appear inside this namespace they are global and
// unscoped. They are kept as macros for compatibility; other code may possibly use them in preprocessor
// conditionals.
#define STEP_DEPTH 2        // NOTE(review): not used in this header; meaning defined by the implementation.
#define THREADS_PER_UNIT 1  // factor of THREADS_PER_PIPELINE below
#define UNIT_PER_PIPELINE 128 // factor of THREADS_PER_PIPELINE below
#define PIPELINE_PER_CTA 4  // factor of THREADS_PER_CTA below
// The following five are not referenced in this header; presumably consumed by the implementation file.
#define EXPERT_BYTES_PER_UNIT 32
#define SCALE_BYTES_PER_UNIT 32
#define UNIT_COUNT_PER_PACKET 1024
#define BYTES_COUNTER 8
#define CUMSUM_THREADS_PER_BLOCK 128
// Per-iteration capacities used by PipelineConfig to size one packet.
#define UNIT_PER_ITER 256
#define STATIC_COPY_PER_ITER 128

// Threads driving one pipeline and threads per CTA. With the values above: 1 * 128 = 128 and 128 * 4 = 512.
static constexpr int THREADS_PER_PIPELINE = THREADS_PER_UNIT * UNIT_PER_PIPELINE;
static constexpr int THREADS_PER_CTA = THREADS_PER_PIPELINE * PIPELINE_PER_CTA;

// Compile-time layout of one packet (all sizes/offsets in bytes). From the arithmetic below:
//   [0, SCALE_OFFSET)                   UNIT_SIZE * UNIT_PER_ITER ints
//   [SCALE_OFFSET, STATIC_COPY_OFFSET)  UNIT_SIZE * UNIT_PER_ITER floats (scales, judging by the name)
//   [STATIC_COPY_OFFSET, PACKET_SIZE)   STATIC_COPY_PER_ITER * 4 ints
// PACKET_PER_STEP is only recorded here; how it is used is up to the implementation.
template <int UNIT_SIZE_INPUT, int PACKET_PER_STEP_INPUT>
struct PipelineConfig
{
    static constexpr int UNIT_SIZE = UNIT_SIZE_INPUT;
    static constexpr int PACKET_PER_STEP = PACKET_PER_STEP_INPUT;
    // Bytes occupied by the per-unit int section plus the per-unit float section.
    static constexpr int UNIT_BYTES_SIZE = UNIT_SIZE * UNIT_PER_ITER * (sizeof(int) + sizeof(float));
    // The float section starts right after the int section.
    static constexpr int SCALE_OFFSET = UNIT_SIZE * UNIT_PER_ITER * sizeof(int);
    // The static-copy section starts right after the per-unit payload.
    static constexpr int STATIC_COPY_OFFSET = UNIT_BYTES_SIZE;
    static constexpr int PACKET_SIZE = UNIT_BYTES_SIZE + STATIC_COPY_PER_ITER * 4 * sizeof(int);
    static constexpr int PACKET_SIZE_IN_U64 = (PACKET_SIZE / 8);
    // PACKET_SIZE_IN_U64 truncates, so make sure no trailing bytes can be silently dropped.
    static_assert(PACKET_SIZE % 8 == 0, "PACKET_SIZE must be a multiple of 8 bytes");
};

// Size of one FIFO: 1 MiB, expressed in uint64_t words (131072).
static constexpr int FIFO_SIZE_IN_U64 = 1024 * 1024 / 8;

#ifdef __CUDACC__
#define ALIGN_256 __align__(256)
#else
#define ALIGN_256 alignas(256)
#endif

// Bookkeeping for one (sender, receiver, channel) FIFO. ALIGN_256 makes each record occupy 256 bytes.
// Fields are volatile so every access goes to memory. Presumably they are updated by the peer rank
// (TODO confirm against the implementation).
struct ALIGN_256 MoeCommFifoConnInfo
{
    volatile uint64_t head;  // write position
    volatile uint64_t tail;  // read position
    volatile uint64_t count; // counter value (its use is defined by the implementation)
};

// View of the MoE communication workspace. Rank r's region starts at workspacePtr + r * rankStrideInU64
// (in uint64_t words). All ranks' regions are addressed from this single base pointer, so the memory must
// be accessible from every participating GPU. Presumably the caller sets that up.
//
// Per the accessors below, each rank's region holds:
//   - epSize * channelCount FIFOs of FIFO_SIZE_IN_U64 words, indexed by (peer rank * channelCount + channel);
//   - then an array of MoeCommFifoConnInfo records, indexed the same way.
struct MoeCommWorkspace
{
    uint64_t* workspacePtr;
    size_t rankStrideInU64;
#ifdef __CUDACC__
    // Returns the base of the FIFO used between epRank and peerRank on `channel`.
    // The FIFO lives in the receiver's region. A sender therefore addresses the peer's region, while a
    // receiver addresses its own region. Either way, the FIFO is selected by the sending rank's index.
    __inline__ __device__ uint64_t* getFifoBasePtr(
        bool isSender, int epRank, int peerRank, int channel, int channelCount) const
    {
        int const receiverRank = isSender ? peerRank : epRank;
        int const senderRank = isSender ? epRank : peerRank;
        // 64-bit math: (rank * channelCount + channel) * FIFO_SIZE_IN_U64 overflows a signed int once the
        // offset reaches 2^31 words.
        size_t const fifoIndex = static_cast<size_t>(senderRank) * channelCount + channel;
        return workspacePtr + static_cast<size_t>(receiverRank) * rankStrideInU64 + fifoIndex * FIFO_SIZE_IN_U64;
    }

    // Returns the MoeCommFifoConnInfo record for the FIFO between epRank and peerRank on `channel`.
    // The records live in the sender's region, right after that region's FIFO area, and are indexed by
    // (receiver rank * channelCount + channel).
    __inline__ __device__ MoeCommFifoConnInfo* getFifoConnInfo(
        bool isSender, int epRank, int peerRank, int channel, int epSize, int channelCount) const
    {
        int const senderRank = isSender ? epRank : peerRank;
        int const receiverRank = isSender ? peerRank : epRank;
        // Size of the FIFO area at the start of each rank's region, computed in 64 bits to avoid int overflow.
        size_t const fifoAreaInU64 = static_cast<size_t>(FIFO_SIZE_IN_U64) * channelCount * epSize;
        uint64_t* senderRegion = workspacePtr + static_cast<size_t>(senderRank) * rankStrideInU64;
        MoeCommFifoConnInfo* connInfoBase = reinterpret_cast<MoeCommFifoConnInfo*>(senderRegion + fifoAreaInU64);
        return connInfoBase + static_cast<size_t>(receiverRank) * channelCount + channel;
    }
#endif
};

// Host-side entry points, implemented outside this header. Each takes a cudaStream_t, presumably the stream
// the work is enqueued on (TODO confirm in the implementation). The parameter meanings are inferred from
// their names only.
void computeCountAndIndice(int* experts, int* sendCounts, int* recvCounts, int* sendIndiceWorkspace,
    int* backwardIndiceWorkspace, int* recvIndiceWorkspace, MoeCommWorkspace workspace, int tokenCount,
    int maxTokenCountPerRank, int topK, int expertCount, int rankId, int rankCount, cudaStream_t stream);
void computeCumsum(int* sendCountsCumsum, int* recvCountsCumsum, int rankId, int rankCount, cudaStream_t stream);
void moveIndice(int* sendCountsCumsum, int* recvCountsCumsum, int* sendIndice, int* gatherSendIndice,
    int* backwardIndice, int* gatherBackwardIndice, int* recvIndice, int* gatherRecvIndice, int rankId, int rankCount,
    int maxTokenCountPerRank, cudaStream_t stream);
void allToAllMetadata(int* sendExperts, int* recvExperts, float* sendScales, float* recvScales, int* localExpertStatics,
    int* gatheredExpertStatics, MoeCommWorkspace workspace, int* sendCountsCumsum, int* localSendIndice,
    int* recvCountsCumsum, int* localRecvIndice, int tokenCount, int maxTokenCountPerRank, int topK, int expertCount,
    int slotCount, int rankId, int rankCount, cudaStream_t stream);
// Workspace size for an expert-parallel group of epSize ranks (unit presumably bytes -- TODO confirm).
size_t getMoePrepareWorkspaceSize(int epSize);
} // namespace moe_prepare
} // namespace tensorrt_llm::kernels