Merge commit '31979aefacbf80d2742c98ef30385db162788c84' into feat/b300_cu13

Signed-off-by: Xiwen Yu <13230610+VALLIS-NERIA@users.noreply.github.com>
Xiwen Yu 2025-08-26 10:31:35 +08:00
commit ab7febd4d8
78 changed files with 4469 additions and 2538 deletions

File diff suppressed because it is too large.


@@ -0,0 +1,562 @@
/*
* Copyright (c) 2019-2025, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include <map>
#include <cuda_runtime_api.h>
#include "tensorrt_llm/common/cudaUtils.h"
#include "tensorrt_llm/kernels/moeCommKernelsCommon.h"
namespace tensorrt_llm
{
namespace kernels
{
struct ALIGN_256 SenderSideFifoInfo
{
volatile uint64_t head; // write position
volatile uint64_t tail; // read position
};
struct ALIGN_256 ReceiverSideFifoInfo
{
volatile uint64_t head; // write position (TODO: is this actually used on the receiver side?)
volatile uint64_t tail; // read position
};
// Struct holding the Send/Recv displacement (index) information.
struct SendRecvIndices
{
int const* rankCountCumSum; // length = epSize
int* rankLocalIndices; // length = rankCountCumSum[epRank] - rankCountCumSum[epRank - 1] if epRank > 0 else
// rankCountCumSum[epRank]
#ifdef __CUDACC__
__inline__ __device__ int getCount(int rank) const
{
return rank == 0 ? rankCountCumSum[rank] : rankCountCumSum[rank] - rankCountCumSum[rank - 1];
}
__inline__ __device__ int getRankStart(int rank) const
{
return rank == 0 ? 0 : rankCountCumSum[rank - 1];
}
__inline__ __device__ int* getGroupStart(int rank, int& tokenCount) const
{
tokenCount = getCount(rank);
int rankStart = getRankStart(rank);
return rankLocalIndices + rankStart;
}
#endif
};
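// Illustrative example (not part of the original header; values are hypothetical):
// with epSize = 4 and rankCountCumSum = {3, 5, 9, 12},
//   getCount(0) == 3, getRankStart(0) == 0,
//   getCount(2) == 9 - 5 == 4, getRankStart(2) == 5,
// so getGroupStart(2, tokenCount) sets tokenCount = 4 and returns rankLocalIndices + 5.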
struct MoeCommFieldInfo
{
uint8_t* dataPtrBase;
uint8_t alignedUnitBit; // 0, 1, 2, 3, 4 (for 1, 2, 4, 8, 16 Bytes), smallest aligned unit.
uint16_t alignedUnitCount; // data count in aligned unit
uint16_t alignedUnitStride; // data stride in aligned unit
uint8_t unalignedFieldIndex; // index of the unaligned field; non-decreasing with the field index
uint16_t compact16BOffset; // aligned to 16 Bytes, offset is count of 16 Byte
static constexpr uint64_t kAlign16BytePtrMask = (1ULL << 4) - 1;
static constexpr uint32_t kAligned16BMask = (1 << 4) - 1;
// Constants for memory alignment and access
static constexpr int BYTES_PER_128B_BLOCK = 128;
static constexpr int INTS_PER_128B_BLOCK = BYTES_PER_128B_BLOCK / sizeof(int);
static constexpr int UINT64_PER_128B_BLOCK = BYTES_PER_128B_BLOCK / sizeof(uint64_t);
static constexpr int BYTES_PER_16B_BLOCK = 16;
// One extra 16-byte block is padded for each unaligned field, so its head and tail 16 bytes might not be aligned.
// Fills the info for a single field; fields that need global info are not filled here.
__host__ void fillFieldInfo(uint8_t* dataPtr, size_t elementSize, int vectorSize, int stride);
__host__ void setUnused()
{
dataPtrBase = nullptr;
alignedUnitBit = 4;
alignedUnitCount = 0;
alignedUnitStride = 0;
unalignedFieldIndex = 0;
compact16BOffset = 0;
}
template <typename T>
__host__ void fillFieldInfo(T* dataPtr, int vectorSize, int stride)
{
size_t elementSize = sizeof(T);
fillFieldInfo(reinterpret_cast<uint8_t*>(dataPtr), elementSize, vectorSize, stride);
}
__device__ __host__ __forceinline__ int getFieldUncompactSize() const
{
int alignedUnitBytes = 1 << alignedUnitBit;
int currentFieldSize = alignedUnitCount * alignedUnitBytes;
if (alignedUnitBytes != 16)
{
constexpr int alignedUnitBytes = BYTES_PER_16B_BLOCK;
currentFieldSize = currentFieldSize / alignedUnitBytes * alignedUnitBytes;
currentFieldSize += alignedUnitBytes * 2;
}
return currentFieldSize;
}
__device__ __host__ __forceinline__ int getFieldCompactSize() const
{
int alignedUnitBytes = 1 << alignedUnitBit;
int currentFieldSize = alignedUnitCount * alignedUnitBytes;
// Align to 16 bytes for compact size
return (currentFieldSize + BYTES_PER_16B_BLOCK - 1) / BYTES_PER_16B_BLOCK * BYTES_PER_16B_BLOCK;
}
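// Worked example (hypothetical field, not from the original code): a field of 100
// 2-byte elements has alignedUnitBit = 1 and alignedUnitCount = 100, i.e. 200 raw bytes.
// getFieldCompactSize() rounds up to 16 bytes: (200 + 15) / 16 * 16 = 208 bytes.
// getFieldUncompactSize() keeps whole 16-byte blocks and reserves a 16-byte head and
// a 16-byte tail for the unaligned field: 200 / 16 * 16 + 2 * 16 = 224 bytes.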
__device__ __forceinline__ int getCompactShmOffset() const
{
return compact16BOffset * BYTES_PER_16B_BLOCK;
}
__device__ __forceinline__ int getUncompactShmOffset() const
{
// each unaligned field needs a 16-byte head and a 16-byte tail
return compact16BOffset * BYTES_PER_16B_BLOCK + unalignedFieldIndex * BYTES_PER_16B_BLOCK;
}
__device__ __forceinline__ int getMemmoveOffsets(int index) const
{
int alignedBytes = 1 << alignedUnitBit;
uint8_t* dataPtr = dataPtrBase + index * alignedBytes * alignedUnitStride;
int offset = reinterpret_cast<uint64_t>(dataPtr) & kAlign16BytePtrMask;
return offset + unalignedFieldIndex * BYTES_PER_16B_BLOCK;
}
__device__ __forceinline__ uint8_t* getRawPtr(int index, int* rawSize) const
{
int alignedBytes = 1 << alignedUnitBit;
uint8_t* dataPtr = dataPtrBase + static_cast<size_t>(index) * alignedBytes * alignedUnitStride;
if (rawSize != nullptr)
{
*rawSize = alignedUnitCount * alignedBytes;
}
return dataPtr;
}
__device__ __forceinline__ uint8_t* get16BAlignedLoadCopyRange(int index, int* copyByteCount) const
{
int rawSize;
uint8_t* rawDataPtr = getRawPtr(index, &rawSize);
uint8_t* rawEndPtr = rawDataPtr + rawSize;
uint8_t* alignedDataPtr
= reinterpret_cast<uint8_t*>(reinterpret_cast<uint64_t>(rawDataPtr) & (~kAlign16BytePtrMask));
uint32_t copySize = rawEndPtr - alignedDataPtr;
*copyByteCount
= (copySize & kAligned16BMask) != 0 ? (copySize & (~kAligned16BMask)) + BYTES_PER_16B_BLOCK : copySize;
return alignedDataPtr;
}
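// Descriptive note (added for clarity): the store path below returns a 16-byte-aligned
// sub-range of the raw field that can be written with vectorized stores; the unaligned
// head and tail bytes are assigned to individual lanes via headTailShmIdx /
// headTailGlobalIdx, and a value of -1 marks a lane with no byte to copy.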
__device__ __forceinline__ uint8_t* get16BAlignedStoreCopyRange(
int index, int* copyByteCount, int laneId, int* headTailShmIdx, int* headTailGlobalIdx) const
{
int rawSize;
uint8_t* rawDataPtr = getRawPtr(index, &rawSize);
uint8_t* rawEndPtr = rawDataPtr + rawSize;
int offset = reinterpret_cast<uint64_t>(rawDataPtr) & kAlign16BytePtrMask;
uint8_t* alignedDataPtr
= reinterpret_cast<uint8_t*>(reinterpret_cast<uint64_t>(rawDataPtr) + BYTES_PER_16B_BLOCK - offset);
uint8_t* alignedEndPtr
= reinterpret_cast<uint8_t*>(reinterpret_cast<uint64_t>(rawEndPtr) & (~kAlign16BytePtrMask));
int alignedCopyBytes = alignedEndPtr - alignedDataPtr;
if (alignedCopyBytes < 0)
{
alignedCopyBytes = 0;
}
*copyByteCount = alignedCopyBytes;
if (laneId < BYTES_PER_16B_BLOCK)
{
*headTailShmIdx = laneId;
}
else
{
*headTailShmIdx = laneId + alignedCopyBytes;
}
*headTailGlobalIdx = *headTailShmIdx - offset;
if (*headTailGlobalIdx < 0 || *headTailGlobalIdx >= rawSize)
{
*headTailGlobalIdx = -1;
*headTailShmIdx = -1;
}
return alignedDataPtr;
}
};
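// Usage sketch (hypothetical values, not part of the original code): registering a
// contiguous float payload of 1024 elements per token whose rows are 1024 elements apart:
//   MoeCommFieldInfo info;
//   info.fillFieldInfo(floatPayloadPtr, /*vectorSize=*/1024, /*stride=*/1024);
// The templated fillFieldInfo() derives elementSize from sizeof(T); unused slots in a
// field array can be cleared with setUnused().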
// Maximum number of fields supported, excluding tokenSelectedExpert and expertScales
static constexpr int MOE_COMM_FIELD_MAX_COUNT = 8;
struct MoeSingleCommMeta
{
int singleTransferAlignedSize; // transfer size aligned to 128 bytes.
int singleCompactAlignedSize; // compact buffer is always aligned to 128 bytes
int singleUncompactAlignedSize; // uncompact shared memory size, aligned to 128 bytes; might be larger than the compact
// buffer if unaligned fields exist.
// TODO: Do we need to reduce shared memory usage so it can be smaller and enable multiple waves?
__device__ __host__ __forceinline__ int getTransfer128ByteCount() const
{
return singleTransferAlignedSize / MoeCommFieldInfo::BYTES_PER_128B_BLOCK;
}
__device__ __host__ __forceinline__ int getCompactData128ByteCount() const
{
return singleCompactAlignedSize / MoeCommFieldInfo::BYTES_PER_128B_BLOCK;
}
__device__ __host__ __forceinline__ int getSingleShmSize() const
{
return std::max(singleUncompactAlignedSize, singleTransferAlignedSize);
}
};
struct FusedMoeWorldInfo
{
MoeEpWorldInfo epInfo;
};
struct FusedMoePairInfo
{
int senderRank;
int receiverRank;
int channel;
int runChannelCount;
};
class FusedMoeCommunicator
{
public:
static constexpr int FIFO_DEPTH = 4;
static constexpr int FIFO_ENTRY_BYTES = 256 * 1024;
static constexpr int FIFO_ENTRY_128_BYTE_COUNT = FIFO_ENTRY_BYTES / 128;
static constexpr int FIFO_TOTAL_BYTES = FIFO_ENTRY_BYTES * FIFO_DEPTH;
static constexpr int FIFO_TOTAL_U64 = FIFO_TOTAL_BYTES / sizeof(uint64_t);
static constexpr int MAX_GROUP_COUNT_PER_BLOCK = 8;
static constexpr int WARP_SIZE = 32;
static int maxSmCount;
static bool maxSmCountUsed;
static void setMaxUsableSmCount(int maxUsableSmCount)
{
TLLM_CHECK_WITH_INFO(
FusedMoeCommunicator::maxSmCountUsed == false, "setMaxUsableSmCount can be called only before it is used");
int smCount = tensorrt_llm::common::getMultiProcessorCount();
if (maxUsableSmCount > smCount)
{
TLLM_LOG_WARNING("setMaxUsableSmCount, maxUsableSmCount=%d, larger than smCount=%d, using smCount instead",
maxUsableSmCount, smCount);
maxUsableSmCount = smCount;
}
FusedMoeCommunicator::maxSmCount = maxUsableSmCount;
}
static int getMaxUsableSmCount()
{
FusedMoeCommunicator::maxSmCountUsed = true;
if (FusedMoeCommunicator::maxSmCount == -1)
{
int smCount = tensorrt_llm::common::getMultiProcessorCount();
FusedMoeCommunicator::maxSmCount = smCount;
}
return FusedMoeCommunicator::maxSmCount;
}
static int computeMoeCommChannelCount(int epSize)
{
int smCount = getMaxUsableSmCount();
int blockCountPerChannel = (epSize + MAX_GROUP_COUNT_PER_BLOCK - 1) / MAX_GROUP_COUNT_PER_BLOCK;
blockCountPerChannel *= 2; // for send and recv
TLLM_CHECK_WITH_INFO(
blockCountPerChannel <= smCount, "GPU should support at least one channel, usableSmCount=%d", smCount);
int preferredChannel = smCount / 2 / blockCountPerChannel; // use half of the SMs for communication
int channelCount = std::max(preferredChannel, 1); // at least one channel
return channelCount;
}
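// Worked example (the SM count is hypothetical): with epSize = 16,
// blockCountPerChannel = ((16 + 7) / 8) * 2 = 4; if getMaxUsableSmCount() returns 132,
// then preferredChannel = 132 / 2 / 4 = 16, so 16 channels are used.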
static int getMoeCommChannelCount(int epSize)
{
static std::map<int, int> channelCountMap{};
auto iter = channelCountMap.find(epSize);
if (iter == channelCountMap.end())
{
auto channelCount = FusedMoeCommunicator::computeMoeCommChannelCount(epSize);
channelCountMap[epSize] = channelCount;
return channelCount;
}
return iter->second;
}
static dim3 getLaunchBlockDim(int groupCountPerCta)
{
return dim3(WARP_SIZE, groupCountPerCta);
}
static dim3 getLaunchGridDim(int epSize, int groupCountPerCta)
{
int maxChannelCount = FusedMoeCommunicator::getMoeCommChannelCount(epSize);
int targetCtaCount = (epSize + MAX_GROUP_COUNT_PER_BLOCK - 1) / MAX_GROUP_COUNT_PER_BLOCK * maxChannelCount * 2;
int ctaPerChannel = (epSize + groupCountPerCta - 1) / groupCountPerCta;
int ctaLimitedChannelCount = targetCtaCount / 2 / ctaPerChannel;
ctaLimitedChannelCount = std::max(1, ctaLimitedChannelCount);
int channelCount = std::min(ctaLimitedChannelCount, maxChannelCount);
return dim3(ctaPerChannel, channelCount, 2);
}
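// Continuing the example above (hypothetical SM count of 132): getLaunchGridDim(16, 8)
// gives ctaPerChannel = (16 + 7) / 8 = 2, targetCtaCount = 2 * 16 * 2 = 64,
// ctaLimitedChannelCount = 64 / 2 / 2 = 16, so the grid is dim3(2, 16, 2) and the
// block is dim3(32, 8) from getLaunchBlockDim(8).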
};
size_t getFusedMoeCommWorkspaceSize(int epSize);
struct FusedMoeFieldInfo
{
int8_t isBasicInterleaved; // whether tokenSelectedSlots and expertScales are stored interleaved
int32_t* tokenSelectedSlots;
float* expertScales; // can be nullptr if no scales are used (all 1.0); in that case isBasicInterleaved must be 0
int fieldCount;
MoeCommFieldInfo fieldsInfo[MOE_COMM_FIELD_MAX_COUNT];
__host__ int computeSingleCompactSize(int topK, bool hasScales, bool hasBasicFields) const
{
int basicFieldSize = 0;
if (hasBasicFields)
{
basicFieldSize = topK * sizeof(int) + (hasScales ? topK * sizeof(float) : 0);
// align to 16 bytes
basicFieldSize = (basicFieldSize + MoeCommFieldInfo::BYTES_PER_16B_BLOCK - 1)
/ MoeCommFieldInfo::BYTES_PER_16B_BLOCK * MoeCommFieldInfo::BYTES_PER_16B_BLOCK;
}
int otherFieldSize = 0;
for (int i = 0; i < fieldCount; i++)
{
MoeCommFieldInfo const& fieldInfo = fieldsInfo[i];
otherFieldSize += fieldInfo.getFieldCompactSize();
}
int totalSize = basicFieldSize + otherFieldSize;
constexpr int totalSizeAlignment = MoeCommFieldInfo::BYTES_PER_128B_BLOCK;
totalSize = (totalSize + totalSizeAlignment - 1) / totalSizeAlignment * totalSizeAlignment;
return totalSize;
}
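// Worked example (hypothetical configuration): with topK = 8, hasScales = true and
// hasBasicFields = true, basicFieldSize = 8 * 4 + 8 * 4 = 64 bytes (already a multiple of 16).
// With one extra field whose compact size is 208 bytes (see the MoeCommFieldInfo example),
// totalSize = 64 + 208 = 272, rounded up to the next 128-byte multiple: 384 bytes.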
__host__ int computeSingleUncompactSize(int topK, bool hasScales, bool hasBasicFields) const
{
int basicFieldSize = 0;
if (hasBasicFields)
{
basicFieldSize = topK * sizeof(int) + (hasScales ? topK * sizeof(float) : 0);
// align to 16 bytes
basicFieldSize = (basicFieldSize + MoeCommFieldInfo::BYTES_PER_16B_BLOCK - 1)
/ MoeCommFieldInfo::BYTES_PER_16B_BLOCK * MoeCommFieldInfo::BYTES_PER_16B_BLOCK;
}
int otherFieldSize = 0;
for (int i = 0; i < fieldCount; i++)
{
MoeCommFieldInfo const& fieldInfo = fieldsInfo[i];
otherFieldSize += fieldInfo.getFieldUncompactSize();
}
int totalSize = basicFieldSize + otherFieldSize;
constexpr int totalSizeAlignment = MoeCommFieldInfo::BYTES_PER_128B_BLOCK;
totalSize = (totalSize + totalSizeAlignment - 1) / totalSizeAlignment * totalSizeAlignment;
return totalSize;
}
template <typename T = int, bool IS_SLOTS = true>
__device__ __forceinline__ T* getBasicFieldPtr(int tokenIndex, int selectedIndex, int topK) const
{
T* fieldPtr = nullptr;
fieldPtr = IS_SLOTS ? reinterpret_cast<T*>(tokenSelectedSlots) : reinterpret_cast<T*>(expertScales);
if (fieldPtr == nullptr || selectedIndex >= topK)
{
return nullptr;
}
int tokenStride = isBasicInterleaved ? topK * 2 : topK;
int elementStride = isBasicInterleaved ? 2 : 1;
return fieldPtr + tokenIndex * tokenStride + selectedIndex * elementStride;
}
__device__ __forceinline__ int* getTokenSelectedSlotsPtr(int tokenIndex, int selectedIndex, int topK) const
{
return getBasicFieldPtr<int, true>(tokenIndex, selectedIndex, topK);
}
__device__ __forceinline__ float* getScalePtr(int tokenIndex, int selectedIndex, int topK) const
{
return getBasicFieldPtr<float, false>(tokenIndex, selectedIndex, topK);
}
void fillMetaInfo(MoeSingleCommMeta* singleCommMeta, int topK, bool hasScales, bool hasBasicFields) const;
void fillFieldPlacementInfo(int topK, bool hasBasicFields);
};
struct FusedMoeCommKernelParam
{
FusedMoeWorldInfo worldInfo;
MoeExpertParallelInfo expertParallelInfo; // expertCount inside should be slotCount if using redundant experts.
MoeSingleCommMeta sendCommMeta;
MoeSingleCommMeta recvCommMeta;
SendRecvIndices sendIndices;
SendRecvIndices recvIndices;
FusedMoeFieldInfo sendFieldInfo;
FusedMoeFieldInfo recvFieldInfo;
};
/*
* Workspace Layout:
* Ri: Rank i
* N: Number of GPUs, e.g. EpSize or WorldSize, n = N - 1
* Ci: Channel i
* M: Number of Channels, m = M - 1
* MMr: Memory Mapped from Rank r, physically located at rank r, and mapped to all ranks.
*
* Whole workspace memory space:
* ---------------------------------------------------------------------------------------------------
* |<-- MM0 --> |<-- MM1 --> |<-- MM2 --> | ...... |<-- MMn --> |
* ^ ^ ^ ^ ^ ^
* 0 rankStrideInU64 2*rankStrideInU64 3*rankStrideInU64 n*rankStrideInU64 N*rankStrideInU64
*
* For each MMr, the layout is:
* -------------------------------------------------------------------------------------------------
* |<--- FIFO memory --->|<--- SenderSideFifoInfo memory --->|<--- ReceiverSideFifoInfo memory --->|
* -------------------------------------------------------------------------------------------------
*
* For each FIFO memory, it is physically placed at the receiver rank.
* To find the FIFO whose receiver is rank r, look it up in the FIFO memory of MMr.
* The layout of the FIFO memory of each MMr is (here the rank is the sender rank):
* -------------------------------------------------------------------------------------------------
* | R0C0 | R0C1 | .... | R0Cm | R1C0 | R1C1 | .... | R1Cm | .... .... | RnC0 | RnC1 | .... | RnCm |
* |<- Channels for Rank 0 ->|<- Channels for Rank 1 ->| |<- Channels for Rank n ->|
* -------------------------------------------------------------------------------------------------
* Each R*C* has a length of FIFO_TOTAL_U64 uint64_t elements, internally divided into FIFO_DEPTH entries of
* FIFO_ENTRY_BYTES bytes each.
*
* For each SenderSideFifoInfo memory, it is physically placed at the sender rank.
* To find the SenderSideFifoInfo whose sender is rank r, look it up in the SenderSideFifoInfo memory of MMr.
* The layout of the SenderSideFifoInfo memory of each MMr is (here the rank is the receiver rank):
* -------------------------------------------------------------------------------------------------
* | R0C0 | R0C1 | .... | R0Cm | R1C0 | R1C1 | .... | R1Cm | .... .... | RnC0 | RnC1 | .... | RnCm |
* |<- Channels for Rank 0 ->|<- Channels for Rank 1 ->| |<- Channels for Rank n ->|
* -------------------------------------------------------------------------------------------------
* Each R*C* is one SenderSideFifoInfo struct. There are M * N SenderSideFifoInfo structs in total in each MMr.
*
* For each ReceiverSideFifoInfo memory, it is physically placed at the receiver rank.
* To find the ReceiverSideFifoInfo whose receiver is rank r, look it up in the ReceiverSideFifoInfo memory of MMr.
* The layout of the ReceiverSideFifoInfo memory of each MMr is (here the rank is the sender rank):
* -------------------------------------------------------------------------------------------------
* | R0C0 | R0C1 | .... | R0Cm | R1C0 | R1C1 | .... | R1Cm | .... .... | RnC0 | RnC1 | .... | RnCm |
* |<- Channels for Rank 0 ->|<- Channels for Rank 1 ->| |<- Channels for Rank n ->|
* -------------------------------------------------------------------------------------------------
* Each R*C* is one ReceiverSideFifoInfo struct. There are M * N ReceiverSideFifoInfo structs in total in each MMr.
*/
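// Worked example of the layout above (hypothetical world size): with N = 4 ranks and
// M = 2 channels, each MMr holds 4 * 2 = 8 FIFOs of FIFO_TOTAL_BYTES = 4 * 256 KiB = 1 MiB
// each (8 MiB), followed by 8 SenderSideFifoInfo and 8 ReceiverSideFifoInfo structs
// (256 bytes each due to ALIGN_256, i.e. 2 KiB per info region). Inside MMr, the FIFO
// for sender rank s and channel c starts at byte offset (s * 2 + c) * FIFO_TOTAL_BYTES.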
struct FusedMoeWorkspace
{
uint64_t* workspacePtr;
size_t rankStrideInU64;
int channelCount;
template <bool isSenderSideBuffer>
__device__ __forceinline__ uint8_t* commonGetPtrBase(
FusedMoePairInfo const& pairInfo, size_t fieldOffset, int fieldSingleSize) const
{
int mappedMemoryRank = isSenderSideBuffer ? pairInfo.senderRank : pairInfo.receiverRank;
int rankInsideMappedMemory = isSenderSideBuffer ? pairInfo.receiverRank : pairInfo.senderRank;
auto* mappedMemory = reinterpret_cast<uint8_t*>(workspacePtr + mappedMemoryRank * rankStrideInU64);
mappedMemory += fieldOffset;
mappedMemory += rankInsideMappedMemory * channelCount * fieldSingleSize;
mappedMemory += pairInfo.channel * fieldSingleSize;
return mappedMemory;
}
__device__ __forceinline__ uint64_t* getFifoBasePtr(
FusedMoeWorldInfo const& worldInfo, FusedMoePairInfo const& pairInfo) const
{
constexpr int fieldSingleSize = FusedMoeCommunicator::FIFO_TOTAL_BYTES;
return reinterpret_cast<uint64_t*>(commonGetPtrBase<false>(pairInfo, 0, fieldSingleSize));
}
__device__ __forceinline__ SenderSideFifoInfo* getSenderSideFifoInfo(
FusedMoeWorldInfo const& worldInfo, FusedMoePairInfo const& pairInfo) const
{
constexpr int fieldSingleSize = sizeof(SenderSideFifoInfo);
size_t fieldOffset
= static_cast<size_t>(FusedMoeCommunicator::FIFO_TOTAL_BYTES) * worldInfo.epInfo.epSize * channelCount;
return reinterpret_cast<SenderSideFifoInfo*>(commonGetPtrBase<true>(pairInfo, fieldOffset, fieldSingleSize));
}
__device__ __forceinline__ ReceiverSideFifoInfo* getReceiverSideFifoInfo(
FusedMoeWorldInfo const& worldInfo, FusedMoePairInfo const& pairInfo) const
{
constexpr int fieldSingleSize = sizeof(ReceiverSideFifoInfo);
size_t fieldOffset
= static_cast<size_t>(FusedMoeCommunicator::FIFO_TOTAL_BYTES) * worldInfo.epInfo.epSize * channelCount
+ sizeof(SenderSideFifoInfo) * worldInfo.epInfo.epSize * channelCount;
return reinterpret_cast<ReceiverSideFifoInfo*>(commonGetPtrBase<false>(pairInfo, fieldOffset, fieldSingleSize));
}
static size_t computeWorkspaceSizePreRank(int epSize, int channelCount)
{
size_t fifoSize = static_cast<size_t>(FusedMoeCommunicator::FIFO_TOTAL_BYTES) * epSize * channelCount;
size_t senderSideInfoSize = sizeof(SenderSideFifoInfo) * epSize * channelCount;
size_t receiverSideInfoSize = sizeof(ReceiverSideFifoInfo) * epSize * channelCount;
return fifoSize + senderSideInfoSize + receiverSideInfoSize;
}
void initializeLocalWorkspace(FusedMoeWorldInfo const& worldInfo);
};
void setMaxUsableSmCount(int smCount);
void moeAllToAll(FusedMoeCommKernelParam params, FusedMoeWorkspace workspace, cudaStream_t stream);
void constructWorkspace(FusedMoeWorkspace* workspace, uint64_t* workspacePtr, size_t rankStrideInU64, int epSize);
void initializeFusedMoeLocalWorkspace(FusedMoeWorkspace* workspace, FusedMoeWorldInfo const& worldInfo);
namespace fused_moe_comm_tests
{
// Functions for testing
void launchSingleG2S(FusedMoeFieldInfo const& sendFieldInfo, MoeExpertParallelInfo const& expertParallelInfo,
int tokenCount, int* shmDump, int warpsPerBlock, bool hasBasicFields, cudaStream_t stream);
void launchSingleS2G(FusedMoeFieldInfo const& recvFieldInfo, MoeExpertParallelInfo const& expertParallelInfo,
int tokenCount, int* shmPreload, int warpsPerBlock, bool hasBasicFields, cudaStream_t stream);
void launchLoopback(FusedMoeFieldInfo const& sendFieldInfo, FusedMoeFieldInfo const& recvFieldInfo,
MoeExpertParallelInfo const& expertParallelInfo, int* recvIndexMapping, int tokenCount, int warpsPerBlock,
bool hasBasicFields, cudaStream_t stream);
void launchLocalFifoSendRecv(FusedMoeFieldInfo const& sendFieldInfo, FusedMoeFieldInfo const& recvFieldInfo,
MoeExpertParallelInfo const& expertParallelInfo, int* sendIndexMapping, int* recvIndexMapping,
FusedMoeWorkspace fusedMoeWorkspace, int tokenCount, int warpsPerBlock, int blockChannelCount, bool hasBasicFields,
cudaStream_t stream);
} // namespace fused_moe_comm_tests
} // namespace kernels
} // namespace tensorrt_llm


@@ -1,804 +0,0 @@
/*
* Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "moeCommKernels.h"
#include <stdio.h>
#include <cooperative_groups.h>
#include <cub/cub.cuh>
namespace cg = cooperative_groups;
namespace tensorrt_llm::kernels
{
__device__ inline void barrier_sync(int name, int nThreads)
{
asm volatile("barrier.sync.aligned %0, %1;" ::"r"(name), "r"(nThreads) : "memory");
}
inline __device__ void load128(uint64_t const* ptr, uint64_t& v0, uint64_t& v1)
{
asm volatile("ld.volatile.global.v2.u64 {%0,%1}, [%2];" : "=l"(v0), "=l"(v1) : "l"(ptr) : "memory");
}
inline __device__ void store128(uint64_t* ptr, uint64_t v0, uint64_t v1)
{
asm volatile("st.volatile.global.v2.u64 [%2], {%0,%1};" ::"l"(v0), "l"(v1), "l"(ptr) : "memory");
}
template <bool isSender>
class AllToAllChannelCommunicator : public AllToAllChannelCommunicatorBase
{
private:
int const tid; // thread index in primitives group
int const nthreads; // number of threads in primitives group
int const wid; // lane index in warp
int const warp; // warp index in primitives group
const MoeEpWorldInfo worldInfo;
const MoeCommWorkspace workspace;
const SendRecvDataInfo sendRecvDataInfo;
const SendRecvDispls dataDispls;
int peerRank; // peer rank index
bool const flagThread;
int const group; // primitives group index
int const channel; // channel index
int const channelCount; // count of channels
MoeCommFifoConnInfo* fifoConnInfoPtr;
uint64_t* fifoBasePtr; // pointer to fifo base address
uint64_t step;
uint64_t tailStepCache;
uint64_t regs[U64_DATA_REG_PER_THREAD];
GroupSharedBuffer* groupSharedBuffer;
int groupStartIndice;
int groupEndIndice;
int sliceStartIndice;
int sliceEndIndice;
uint64_t* stepFifoEntryPtr;
public:
__inline__ __device__ uint64_t getFlag()
{
return step + 1;
}
__inline__ __device__ AllToAllChannelCommunicator(MoeEpWorldInfo const& worldInfo, MoeCommWorkspace workspace,
SendRecvDataInfo sendRecvDataInfo, SendRecvDispls dataDispls, GroupSharedBuffer* groupSharedBuffer,
int channelCount)
: worldInfo(worldInfo)
, nthreads(blockDim.x)
, tid(threadIdx.x)
, workspace(workspace)
, sendRecvDataInfo(sendRecvDataInfo)
, dataDispls(dataDispls)
, wid(threadIdx.x % WARP_SIZE)
, warp(threadIdx.x / WARP_SIZE)
, peerRank(blockIdx.x * GROUP_COUNT_PER_BLOCK + threadIdx.y)
, group(threadIdx.y)
, channel(blockIdx.y)
, flagThread(threadIdx.x % 8 == 7)
, fifoConnInfoPtr(nullptr)
, fifoBasePtr(nullptr)
, step(0)
, tailStepCache(0)
, groupSharedBuffer(groupSharedBuffer)
, channelCount(channelCount)
{
}
__inline__ __device__ void init()
{
fifoBasePtr = workspace.getFifoBasePtr(isSender, worldInfo.epRank, peerRank, channel, channelCount);
fifoConnInfoPtr
= workspace.getFifoConnInfo(isSender, worldInfo.epRank, peerRank, channel, worldInfo.epSize, channelCount);
step = isSender ? fifoConnInfoPtr->head : fifoConnInfoPtr->tail;
tailStepCache = isSender ? fifoConnInfoPtr->tail : 0;
}
__inline__ __device__ void computeGroupTransferRange()
{
if (tid == 0)
{
int rankCount = dataDispls.getCount(peerRank);
int rankStart = dataDispls.getRankStart(peerRank);
int countPerChannel = (rankCount + channelCount - 1) / channelCount;
int groupEnd = min(rankStart + (channel + 1) * countPerChannel, rankStart + rankCount);
int groupStart = min(rankStart + channel * countPerChannel, rankStart + rankCount);
groupSharedBuffer->groupStartIndice = groupStart;
groupSharedBuffer->groupEndIndice = groupEnd;
}
barrier();
groupStartIndice = groupSharedBuffer->groupStartIndice;
groupEndIndice = groupSharedBuffer->groupEndIndice;
}
__inline__ __device__ void loadTransferIndices()
{
sliceStartIndice = groupStartIndice;
sliceEndIndice = min(groupStartIndice + sendRecvDataInfo.vectorCountPerFifoEntry, groupEndIndice);
for (int i = groupStartIndice + tid; i < sliceEndIndice; i += WARP_SIZE * WARP_PER_GROUP)
{
groupSharedBuffer->groupIndiceBuffer[i - groupStartIndice] = dataDispls.getRealVectorIndice(i);
}
groupStartIndice = sliceEndIndice;
barrier();
}
__inline__ __device__ void computeSlicePtr()
{
stepFifoEntryPtr = fifoBasePtr + RECV_FIFO_ENTRY_U64 * (step % RECV_FIFO_DEPTH);
}
__inline__ __device__ void sendSlice()
{
waitSend();
int EltPer16B = 2;
int eltN = sendRecvDataInfo.vectorSizeInU64;
for (int vecId = warp + sliceStartIndice; vecId < sliceEndIndice; vecId += WARP_PER_GROUP)
{
int idxInSlice = vecId - sliceStartIndice;
int vecRealIdx = groupSharedBuffer->groupIndiceBuffer[idxInSlice];
uint64_t* src = dataDispls.getVectorDataPtr(vecRealIdx);
uint64_t* slicePtr = stepFifoEntryPtr
+ idxInSlice * sendRecvDataInfo.dataPacketCountPerVector * PACKET_SIZE_IN_U64 + 2 * wid;
for (int packetId = 0; packetId < sendRecvDataInfo.dataPacketCountPerVector; packetId++)
{
int vecOff = packetId * DATA_PAYLOAD_SIZE_PER_PACKET_IN_U64;
#pragma unroll
for (int g = 0; g < U64_DATA_REG_PER_THREAD / 2; g++)
{
int ix = g * WARP_SIZE - 4 * (g / 2) + wid - (g % 2) * (wid / 8);
__syncwarp();
if (!flagThread || g % 2 == 0)
{
if (ix * EltPer16B + vecOff < eltN)
{
load128((uint64_t*) (src + ix * EltPer16B + vecOff), regs[2 * g + 0], regs[2 * g + 1]);
}
}
__syncwarp();
}
#pragma unroll
for (int g = 1; g < U64_DATA_REG_PER_THREAD / 2; g += 2)
{
if (flagThread)
regs[2 * g] = regs[2 * g - 1];
}
uint64_t flag = getFlag();
uint64_t* packetPtr = slicePtr + packetId * PACKET_SIZE_IN_U64;
__syncwarp();
#pragma unroll
for (int u = 0; u < U64_DATA_REG_PER_THREAD; u += 2)
{
store128(packetPtr + u * WARP_SIZE, regs[u], flagThread ? flag : regs[u + 1]);
}
}
}
updateSend();
}
__inline__ __device__ void recvSlice()
{
// the receiver doesn't need to wait since we have the flag.
int EltPer16B = 2;
int eltN = sendRecvDataInfo.vectorSizeInU64;
for (int vecId = warp + sliceStartIndice; vecId < sliceEndIndice; vecId += WARP_PER_GROUP)
{
int idxInSlice = vecId - sliceStartIndice;
int vecRealIdx = groupSharedBuffer->groupIndiceBuffer[idxInSlice];
uint64_t* dst = dataDispls.getVectorDataPtr(vecRealIdx);
uint64_t* slicePtr = stepFifoEntryPtr
+ idxInSlice * sendRecvDataInfo.dataPacketCountPerVector * PACKET_SIZE_IN_U64 + 2 * wid;
for (int packetId = 0; packetId < sendRecvDataInfo.dataPacketCountPerVector; packetId++)
{
uint64_t* packetPtr = slicePtr + packetId * PACKET_SIZE_IN_U64;
int vecOff = packetId * DATA_PAYLOAD_SIZE_PER_PACKET_IN_U64;
bool needReload;
uint64_t flag = getFlag();
__syncwarp();
do
{
needReload = false;
#pragma unroll
for (int u = 0; u < U64_DATA_REG_PER_THREAD; u += 2)
{
load128(packetPtr + u * WARP_SIZE, regs[u], regs[u + 1]);
needReload |= flagThread && (regs[u + 1] != flag);
}
} while (__any_sync(WARP_MASK, needReload));
#pragma unroll
for (int g = 1; g < U64_DATA_REG_PER_THREAD / 2; g += 2)
{
if (flagThread)
regs[2 * g - 1] = regs[2 * g];
}
#pragma unroll
for (int g = 0; g < U64_DATA_REG_PER_THREAD / 2; g++)
{
int ix = g * WARP_SIZE - 4 * (g / 2) + wid - (g % 2) * (wid / 8);
__syncwarp();
if (!flagThread || g % 2 == 0)
{
if (ix * EltPer16B + vecOff < eltN)
{
store128((uint64_t*) (dst + ix * EltPer16B + vecOff), regs[2 * g + 0], regs[2 * g + 1]);
}
}
__syncwarp();
}
}
}
updateRecv();
}
__inline__ __device__ void run()
{
if (peerRank >= worldInfo.epSize)
{
return;
}
init();
computeGroupTransferRange();
while (groupStartIndice < groupEndIndice)
{
loadTransferIndices();
computeSlicePtr();
if (isSender)
{
sendSlice();
}
else
{
recvSlice();
}
}
}
__inline__ __device__ ~AllToAllChannelCommunicator() {}
__inline__ __device__ void barrier()
{
barrier_sync(15 - group, nthreads);
}
__inline__ __device__ void waitSend()
{
barrier();
while (tailStepCache + RECV_FIFO_DEPTH < step + 1)
{
tailStepCache = fifoConnInfoPtr->tail;
}
barrier();
}
__inline__ __device__ void updateSend()
{
barrier();
if (tid == 0)
{
atomicAdd_system((unsigned long long*) &fifoConnInfoPtr->head, 1);
}
barrier();
step++;
}
__inline__ __device__ void updateRecv()
{
barrier();
if (tid == 0)
{
atomicAdd_system((unsigned long long*) &fifoConnInfoPtr->tail, 1);
}
barrier();
step++;
}
};
__global__ void moeAllToAllKernel(MoeEpWorldInfo worldInfo, MoeCommWorkspace workspace,
SendRecvDataInfo sendRecvDataInfo, SendRecvDispls sendDispls, SendRecvDispls recvDispls)
{
__shared__ AllToAllChannelCommunicatorBase::GroupSharedBuffer
allGroupSharedBuffer[AllToAllChannelCommunicatorBase::GROUP_COUNT_PER_BLOCK];
bool isSender = blockIdx.z == 0;
int channelCount = gridDim.y;
int group = threadIdx.y;
SendRecvDispls dataDispls = isSender ? sendDispls : recvDispls;
AllToAllChannelCommunicatorBase::GroupSharedBuffer* groupSharedBuffer = &allGroupSharedBuffer[group];
if (isSender)
{
AllToAllChannelCommunicator<true> comm(
worldInfo, workspace, sendRecvDataInfo, dataDispls, groupSharedBuffer, channelCount);
comm.run();
}
else
{
AllToAllChannelCommunicator<false> comm(
worldInfo, workspace, sendRecvDataInfo, dataDispls, groupSharedBuffer, channelCount);
comm.run();
}
}
void moeAllToAll(MoeEpWorldInfo worldInfo, SendRecvDataInfo sendRecvDataInfo, SendRecvDispls sendDispls,
SendRecvDispls recvDispls, MoeCommWorkspace workspace, cudaStream_t stream)
{
sendRecvDataInfo.DoPreCompute();
TLLM_CHECK_WITH_INFO(
reinterpret_cast<uintptr_t>(sendDispls.dataPtr) % 16 == 0, "sendDispls.dataPtr must be 16-byte aligned");
TLLM_CHECK_WITH_INFO(
reinterpret_cast<uintptr_t>(recvDispls.dataPtr) % 16 == 0, "recvDispls.dataPtr must be 16-byte aligned");
dim3 block = AllToAllChannelCommunicatorBase::getLaunchBlockDim();
dim3 grid = AllToAllChannelCommunicatorBase::getLaunchGridDim(worldInfo.epSize);
moeAllToAllKernel<<<grid, block, 0, stream>>>(worldInfo, workspace, sendRecvDataInfo, sendDispls, recvDispls);
}
template <bool isSend, int kThreadsGroupSize>
__inline__ __device__ void computeSendRecvRankCountDevice(MoeEpWorldInfo worldInfo,
MoeExpertParallelInfo expertParallelInfo, int maxTokenCountPerRank, int const* realRankTokenCountCumSum,
int const* gatheredTargetRankIds, int* sharedSendRecvRankCount, int* sendRecvRankCount)
{
cg::thread_block_tile<kThreadsGroupSize> tile = cg::tiled_partition<kThreadsGroupSize>(cg::this_thread_block());
int laneInTile = tile.thread_rank();
int tileId = threadIdx.x / kThreadsGroupSize;
int tileCountPerBlock = blockDim.x / kThreadsGroupSize;
int topK = expertParallelInfo.topK;
int epRank = worldInfo.epRank;
int epSize = worldInfo.epSize;
if (threadIdx.x == 0)
{
*sharedSendRecvRankCount = 0;
}
__syncthreads();
int readRank = isSend ? epRank : blockIdx.x;
int compareRankId = isSend ? blockIdx.x : epRank;
int const* readRankTargetRankIds = gatheredTargetRankIds + readRank * maxTokenCountPerRank * topK;
int readRankTokenCount = maxTokenCountPerRank;
if (realRankTokenCountCumSum != nullptr)
{
int readRankStart = readRank == 0 ? 0 : realRankTokenCountCumSum[readRank - 1];
readRankTargetRankIds = gatheredTargetRankIds + readRankStart * topK;
readRankTokenCount = realRankTokenCountCumSum[readRank] - readRankStart;
}
for (int i = tileId + blockIdx.z * tileCountPerBlock; i < readRankTokenCount; i += tileCountPerBlock * gridDim.z)
{
int targetRankId = laneInTile < topK ? readRankTargetRankIds[i * topK + laneInTile] : epSize;
bool rankMatched = (targetRankId == compareRankId);
bool hasRankMatched = tile.any(rankMatched);
if (hasRankMatched && laneInTile == 0)
{
atomicAdd_block(sharedSendRecvRankCount, 1);
}
tile.sync();
}
__syncthreads();
if (threadIdx.x == 0)
{
atomicAdd_system(sendRecvRankCount + blockIdx.x, *sharedSendRecvRankCount);
}
}
template <int kThreadsGroupSize>
__global__ void computeSendRecvRankCountKernel(MoeEpWorldInfo worldInfo, MoeExpertParallelInfo expertParallelInfo,
int maxTokenCountPerRank, int const* realRankTokenCountCumSum, int const* gatheredTargetRankIds, int* sendRankCount,
int* recvRankCount)
{
static_assert(kThreadsGroupSize == 1 || kThreadsGroupSize == 2 || kThreadsGroupSize == 4 || kThreadsGroupSize == 8
|| kThreadsGroupSize == 16 || kThreadsGroupSize == 32,
"Only 1, 2, 4, 8, 16, 32 threads group size supported now.");
__shared__ int sharedSendRecvRankCount;
if (blockIdx.y == 0)
{
// compute send rank count
computeSendRecvRankCountDevice<true, kThreadsGroupSize>(worldInfo, expertParallelInfo, maxTokenCountPerRank,
realRankTokenCountCumSum, gatheredTargetRankIds, &sharedSendRecvRankCount, sendRankCount);
}
else
{
// compute recv rank count
computeSendRecvRankCountDevice<false, kThreadsGroupSize>(worldInfo, expertParallelInfo, maxTokenCountPerRank,
realRankTokenCountCumSum, gatheredTargetRankIds, &sharedSendRecvRankCount, recvRankCount);
}
}
void computeSendRecvRankCount(MoeEpWorldInfo worldInfo, MoeExpertParallelInfo expertParallelInfo,
int maxTokenCountPerRank, int const* realRankTokenCountCumSum, int const* gatheredTargetRankIds, int* sendRankCount,
int* recvRankCount, cudaStream_t stream)
{
TLLM_CHECK_WITH_INFO(expertParallelInfo.topK <= 32, "Only topK less than or equal to 32 supported now.");
int threadsPerBlock = 1024;
auto* kernelPtr = computeSendRecvRankCountKernel<32>;
if (expertParallelInfo.topK <= 1)
{
kernelPtr = computeSendRecvRankCountKernel<1>;
}
else if (expertParallelInfo.topK <= 2)
{
kernelPtr = computeSendRecvRankCountKernel<2>;
}
else if (expertParallelInfo.topK <= 4)
{
kernelPtr = computeSendRecvRankCountKernel<4>;
}
else if (expertParallelInfo.topK <= 8)
{
kernelPtr = computeSendRecvRankCountKernel<8>;
}
else if (expertParallelInfo.topK <= 16)
{
kernelPtr = computeSendRecvRankCountKernel<16>;
}
dim3 block(worldInfo.epSize, 2, 1);
kernelPtr<<<block, threadsPerBlock, 0, stream>>>(worldInfo, expertParallelInfo, maxTokenCountPerRank,
realRankTokenCountCumSum, gatheredTargetRankIds, sendRankCount, recvRankCount);
}
template <int kThreadsPerBlock>
__global__ void inplaceSendRecvRankCumSumKernel(MoeEpWorldInfo worldInfo, int* sendRankCount, int* recvRankCount)
{
int* inputOutputPtr = blockIdx.x == 0 ? sendRankCount : recvRankCount;
typedef cub::BlockScan<int, kThreadsPerBlock> BlockScan;
__shared__ typename BlockScan::TempStorage temp_storage;
int tid = threadIdx.x;
int threadData = tid < worldInfo.epSize ? inputOutputPtr[tid] : 0;
BlockScan(temp_storage).InclusiveSum(threadData, threadData);
if (tid < worldInfo.epSize)
{
inputOutputPtr[tid] = threadData;
}
}
void inplaceSendRecvRankCumSum(MoeEpWorldInfo worldInfo, int* sendRankCount, int* recvRankCount, cudaStream_t stream)
{
TLLM_CHECK_WITH_INFO(worldInfo.epSize <= 1024, "Only worldInfo.epSize less than or equal to 1024 supported now.");
auto* kernelPtr = inplaceSendRecvRankCumSumKernel<1024>;
int blockSize = 1024;
if (worldInfo.epSize <= 32)
{
kernelPtr = inplaceSendRecvRankCumSumKernel<32>;
blockSize = 32;
}
else if (worldInfo.epSize <= 64)
{
kernelPtr = inplaceSendRecvRankCumSumKernel<64>;
blockSize = 64;
}
else if (worldInfo.epSize <= 128)
{
kernelPtr = inplaceSendRecvRankCumSumKernel<128>;
blockSize = 128;
}
else if (worldInfo.epSize <= 256)
{
kernelPtr = inplaceSendRecvRankCumSumKernel<256>;
blockSize = 256;
}
else if (worldInfo.epSize <= 512)
{
kernelPtr = inplaceSendRecvRankCumSumKernel<512>;
blockSize = 512;
}
kernelPtr<<<2, blockSize, 0, stream>>>(worldInfo, sendRankCount, recvRankCount);
}
template <bool isSend, int kThreadsGroupSize, int kThreadsPerBlock>
__inline__ __device__ void computeSendRecvIndicesDevice(MoeEpWorldInfo worldInfo,
MoeExpertParallelInfo expertParallelInfo, int maxTokenCountPerRank, int const* realRankTokenCountCumSum,
int const* gatheredTargetRankIds, int const* sendRecvCumSum,
int* sendRecvIndices, // send or receive
int* localGatherIndices, // receive only
int* backwardRecvRankLocalIndices, // send only
int* sharedSendRecvRankStart, typename cub::BlockScan<int, kThreadsPerBlock>::TempStorage& tempStorage)
{
cg::thread_block_tile<kThreadsGroupSize> tile = cg::tiled_partition<kThreadsGroupSize>(cg::this_thread_block());
int laneInTile = tile.thread_rank();
int tileId = threadIdx.x / kThreadsGroupSize;
int tileCountPerBlock = blockDim.x / kThreadsGroupSize;
int topK = expertParallelInfo.topK;
int epRank = worldInfo.epRank;
int epSize = worldInfo.epSize;
if (threadIdx.x == 0)
{
*sharedSendRecvRankStart = blockIdx.x == 0 ? 0 : sendRecvCumSum[blockIdx.x - 1];
}
__syncthreads();
int readRank = isSend ? epRank : blockIdx.x;
int compareRankId = isSend ? blockIdx.x : epRank;
int readRankStart = readRank * maxTokenCountPerRank;
int const* readRankTargetRankIds = gatheredTargetRankIds + readRankStart * topK;
int readRankTokenCount = maxTokenCountPerRank;
if (realRankTokenCountCumSum != nullptr)
{
readRankStart = readRank == 0 ? 0 : realRankTokenCountCumSum[readRank - 1];
readRankTargetRankIds = gatheredTargetRankIds + readRankStart * topK;
readRankTokenCount = realRankTokenCountCumSum[readRank] - readRankStart;
}
for (int blockStartId = blockIdx.z * tileCountPerBlock; blockStartId < readRankTokenCount;
blockStartId += tileCountPerBlock * gridDim.z)
{
int stepStartIndice = *sharedSendRecvRankStart;
int i = blockStartId + tileId;
int targetRankId
= (laneInTile < topK && i < readRankTokenCount) ? readRankTargetRankIds[i * topK + laneInTile] : epSize;
bool rankMatched = (targetRankId == compareRankId);
bool hasRankMatched = tile.any(rankMatched);
unsigned int laneMask = tile.ballot(rankMatched);
int lowestLane = __ffs(laneMask) - 1;
int isMatchedLane = (hasRankMatched && laneInTile == lowestLane) ? 1 : 0;
int indice;
typedef cub::BlockScan<int, kThreadsPerBlock> BlockScan;
BlockScan(tempStorage).ExclusiveSum(isMatchedLane, indice);
indice += stepStartIndice;
__syncthreads();
if (isMatchedLane == 1)
{
atomicAdd_block(sharedSendRecvRankStart, 1);
if (isSend)
{
sendRecvIndices[indice] = i;
backwardRecvRankLocalIndices[indice] = i * topK + lowestLane;
}
else
{
sendRecvIndices[indice] = indice;
localGatherIndices[indice] = readRankStart + i;
}
}
__syncthreads();
}
}
template <int kThreadsGroupSize, int kThreadsPerBlock>
__global__ void computeSendRecvIndicesKernel(MoeEpWorldInfo worldInfo, MoeExpertParallelInfo expertParallelInfo,
int maxTokenCountPerRank, int const* realRankTokenCountCumSum, int const* gatheredTargetRankIds,
int const* sendRankCountCumSum, int const* recvRankCountCumSum, int* localGatherIndices, int* sendRankLocalIndices,
int* recvRankLocalIndices, int* backwardRecvRankLocalIndices)
{
static_assert(kThreadsGroupSize == 1 || kThreadsGroupSize == 2 || kThreadsGroupSize == 4 || kThreadsGroupSize == 8
|| kThreadsGroupSize == 16 || kThreadsGroupSize == 32,
"Only 1, 2, 4, 8, 16, 32 threads group size supported now.");
__shared__ int sharedSendRecvRankStart;
__shared__ typename cub::BlockScan<int, kThreadsPerBlock>::TempStorage tempStorage;
if (blockIdx.y == 0)
{
// compute send rank count
computeSendRecvIndicesDevice<true, kThreadsGroupSize, kThreadsPerBlock>(worldInfo, expertParallelInfo,
maxTokenCountPerRank, realRankTokenCountCumSum, gatheredTargetRankIds, sendRankCountCumSum,
sendRankLocalIndices, localGatherIndices, backwardRecvRankLocalIndices, &sharedSendRecvRankStart,
tempStorage);
}
else
{
// compute recv rank count
computeSendRecvIndicesDevice<false, kThreadsGroupSize, kThreadsPerBlock>(worldInfo, expertParallelInfo,
maxTokenCountPerRank, realRankTokenCountCumSum, gatheredTargetRankIds, recvRankCountCumSum,
recvRankLocalIndices, localGatherIndices, backwardRecvRankLocalIndices, &sharedSendRecvRankStart,
tempStorage);
}
}
void computeSendRecvIndices(MoeEpWorldInfo worldInfo, MoeExpertParallelInfo expertParallelInfo,
int maxTokenCountPerRank, int const* realRankTokenCountCumSum, int const* gatheredTargetRankIds,
int const* sendRankCountCumSum, int const* recvRankCountCumSum, int* localGatherIndices, int* sendRankLocalIndices,
int* recvRankLocalIndices, int* backwardRecvRankLocalIndices, cudaStream_t stream)
{
TLLM_CHECK_WITH_INFO(expertParallelInfo.topK <= 32, "Only topK less than or equal to 32 supported now.");
int threadsPerBlock = 1024;
auto* kernelPtr = computeSendRecvIndicesKernel<32, 1024>;
if (expertParallelInfo.topK <= 1)
{
kernelPtr = computeSendRecvIndicesKernel<1, 1024>;
}
else if (expertParallelInfo.topK <= 2)
{
kernelPtr = computeSendRecvIndicesKernel<2, 1024>;
}
else if (expertParallelInfo.topK <= 4)
{
kernelPtr = computeSendRecvIndicesKernel<4, 1024>;
}
else if (expertParallelInfo.topK <= 8)
{
kernelPtr = computeSendRecvIndicesKernel<8, 1024>;
}
else if (expertParallelInfo.topK <= 16)
{
kernelPtr = computeSendRecvIndicesKernel<16, 1024>;
}
else if (expertParallelInfo.topK <= 32)
{
kernelPtr = computeSendRecvIndicesKernel<32, 1024>;
}
dim3 block(worldInfo.epSize, 2, 1);
kernelPtr<<<block, threadsPerBlock, 0, stream>>>(worldInfo, expertParallelInfo, maxTokenCountPerRank,
realRankTokenCountCumSum, gatheredTargetRankIds, sendRankCountCumSum, recvRankCountCumSum, localGatherIndices,
sendRankLocalIndices, recvRankLocalIndices, backwardRecvRankLocalIndices);
}
__global__ void moeAllToAllMemsetKernel(MoeEpWorldInfo worldInfo, MoeExpertParallelInfo expertParallelInfo,
int maxTokenCountPerRank, int* sendRankCountCumSum, int* recvRankCountCumSum, int* localGatherIndices,
int* sendRankLocalIndices, int* recvRankLocalIndices, int* backwardRecvRankLocalIndices)
{
int maxSendRanksPerToken = std::max(worldInfo.epSize, expertParallelInfo.topK);
int idx = threadIdx.x + blockIdx.x * blockDim.x;
int maxRankRecvTokenCount = maxTokenCountPerRank * worldInfo.epSize;
int maxRankSendTokenCount = maxTokenCountPerRank * maxSendRanksPerToken;
if (idx < worldInfo.epSize)
{
sendRankCountCumSum[idx] = 0;
recvRankCountCumSum[idx] = 0;
}
if (idx < maxRankRecvTokenCount)
{
localGatherIndices[idx] = -1;
recvRankLocalIndices[idx] = -1;
}
if (idx < maxRankSendTokenCount)
{
sendRankLocalIndices[idx] = -1;
backwardRecvRankLocalIndices[idx] = -1;
}
}
void moeAllToAllMemset(MoeEpWorldInfo worldInfo, MoeExpertParallelInfo expertParallelInfo, int maxTokenCountPerRank,
int* sendRankCountCumSum, int* recvRankCountCumSum, int* localGatherIndices, int* sendRankLocalIndices,
int* recvRankLocalIndices, int* backwardRecvRankLocalIndices, cudaStream_t stream)
{
int maxSendRanksPerToken = std::max(worldInfo.epSize, expertParallelInfo.topK);
int maxRankRecvTokenCount = maxTokenCountPerRank * worldInfo.epSize;
int maxRankSendTokenCount = maxTokenCountPerRank * maxSendRanksPerToken;
int maxEltCount = std::max<int>(maxRankRecvTokenCount, maxRankSendTokenCount);
maxEltCount = std::max<int>(maxEltCount, worldInfo.epSize);
static constexpr int kBlockSize = 256;
int blockCount = (maxEltCount + kBlockSize - 1) / kBlockSize;
dim3 grid(blockCount, 1);
moeAllToAllMemsetKernel<<<grid, kBlockSize, 0, stream>>>(worldInfo, expertParallelInfo, maxTokenCountPerRank,
sendRankCountCumSum, recvRankCountCumSum, localGatherIndices, sendRankLocalIndices, recvRankLocalIndices,
backwardRecvRankLocalIndices);
}
void moeAllToAllPrepareIndices(MoeEpWorldInfo worldInfo, MoeExpertParallelInfo expertParallelInfo,
int maxTokenCountPerRank, int const* gatheredTargetRankIds, int const* realRankTokenCountCumSum,
// indices of gatheredTargetRankIds that have the local rank in topK
int* localGatherIndices, // max length = maxTokenCountPerRank * worldInfo.epSize when all ranks send to current
// rank
int* sendRankCountCumSum, // max length = worldInfo.epSize
int* sendRankLocalIndices, // max length = maxTokenCountPerRank * expertParallelInfo.expertCount when the current rank
// has maxTokenCountPerRank tokens to send and all of them have expertCount destinations
int* recvRankCountCumSum, // max length = worldInfo.epSize
int* recvRankLocalIndices, // max length = maxTokenCountPerRank * worldInfo.epSize when all ranks send to current
// rank
// the rankCountCumSum of combineRecv should be the same as sendRankCountCumSum
int*
backwardRecvRankLocalIndices, // max length = maxTokenCountPerRank * expertParallelInfo.expertCount when the current
// rank has maxTokenCountPerRank tokens to send and all of them have expertCount destinations
cudaStream_t stream)
{
moeAllToAllMemset(worldInfo, expertParallelInfo, maxTokenCountPerRank, sendRankCountCumSum, recvRankCountCumSum,
localGatherIndices, sendRankLocalIndices, recvRankLocalIndices, backwardRecvRankLocalIndices, stream);
TLLM_CHECK_WITH_INFO(worldInfo.epSize <= 1024, "Only worldInfo.epSize less than or equal to 1024 supported now.");
computeSendRecvRankCount(worldInfo, expertParallelInfo, maxTokenCountPerRank, realRankTokenCountCumSum,
gatheredTargetRankIds, sendRankCountCumSum, recvRankCountCumSum, stream);
inplaceSendRecvRankCumSum(worldInfo, sendRankCountCumSum, recvRankCountCumSum, stream);
computeSendRecvIndices(worldInfo, expertParallelInfo, maxTokenCountPerRank, realRankTokenCountCumSum,
gatheredTargetRankIds, sendRankCountCumSum, recvRankCountCumSum, localGatherIndices, sendRankLocalIndices,
recvRankLocalIndices, backwardRecvRankLocalIndices, stream);
}
template <int kThreadsGroupSize>
__global__ void moeLocalGatherDevice(MoeEpWorldInfo worldInfo, MoeExpertParallelInfo expertParallelInfo,
int maxTokenCountPerRank, int localMaxTokenCount, int const* recvRankCountCumSum, int const* localGatherIndices,
int const* gatheredExpertIds, float const* gatheredScales, int* localExpertIds, float* localScales)
{
cg::thread_block_tile<kThreadsGroupSize> tile = cg::tiled_partition<kThreadsGroupSize>(cg::this_thread_block());
int laneInTile = tile.thread_rank();
int tileId = threadIdx.x / kThreadsGroupSize;
int tileCountPerBlock = blockDim.x / kThreadsGroupSize;
int epSize = worldInfo.epSize;
int rankTokenCount = recvRankCountCumSum[epSize - 1];
if (laneInTile >= expertParallelInfo.topK)
{
return;
}
for (int index = tileId + blockIdx.x * tileCountPerBlock; index < localMaxTokenCount;
index += tileCountPerBlock * gridDim.x)
{
int localTokenIndice = localGatherIndices[index];
int expertId = index < rankTokenCount
? gatheredExpertIds[localTokenIndice * expertParallelInfo.topK + laneInTile]
: expertParallelInfo.expertCount;
localExpertIds[index * expertParallelInfo.topK + laneInTile] = expertId;
if (gatheredScales)
{
float scale = index < rankTokenCount
? gatheredScales[localTokenIndice * expertParallelInfo.topK + laneInTile]
: 0.0f;
localScales[index * expertParallelInfo.topK + laneInTile] = scale;
}
}
}
void moeLocalGather(MoeEpWorldInfo worldInfo, MoeExpertParallelInfo expertParallelInfo, int maxTokenCountPerRank,
int localMaxTokenCount, int const* recvRankCountCumSum, int const* localGatherIndices, int const* gatheredExpertIds,
float const* gatheredScales, int* localExpertIds, float* localScales, cudaStream_t stream)
{
TLLM_CHECK_WITH_INFO(expertParallelInfo.topK <= 32, "Only topK less than or equal to 32 supported now.");
auto* kernelPtr = moeLocalGatherDevice<32>;
int paddedTopK = 32;
if (expertParallelInfo.topK <= 1)
{
paddedTopK = 1;
kernelPtr = moeLocalGatherDevice<1>;
}
else if (expertParallelInfo.topK <= 2)
{
paddedTopK = 2;
kernelPtr = moeLocalGatherDevice<2>;
}
else if (expertParallelInfo.topK <= 4)
{
paddedTopK = 4;
kernelPtr = moeLocalGatherDevice<4>;
}
else if (expertParallelInfo.topK <= 8)
{
paddedTopK = 8;
kernelPtr = moeLocalGatherDevice<8>;
}
else if (expertParallelInfo.topK <= 16)
{
paddedTopK = 16;
kernelPtr = moeLocalGatherDevice<16>;
}
int threadsPerBlock = 512;
int tokenPerBlock = threadsPerBlock / paddedTopK;
int blockCount = (localMaxTokenCount + tokenPerBlock - 1) / tokenPerBlock * 2;
kernelPtr<<<blockCount, threadsPerBlock, 0, stream>>>(worldInfo, expertParallelInfo, maxTokenCountPerRank,
localMaxTokenCount, recvRankCountCumSum, localGatherIndices, gatheredExpertIds, gatheredScales, localExpertIds,
localScales);
}
int AllToAllChannelCommunicatorBase::maxSmCount = -1;
bool AllToAllChannelCommunicatorBase::maxSmCountUsed = false;
void setMaxUsableSmCount(int smCount)
{
AllToAllChannelCommunicatorBase::setMaxUsableSmCount(smCount);
}
} // namespace tensorrt_llm::kernels


@@ -1,268 +0,0 @@
/*
* Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include <map>
#include "tensorrt_llm/common/cudaUtils.h"
namespace tensorrt_llm::kernels
{
#ifdef __CUDACC__
#define ALIGN_256 __align__(256)
#else
#define ALIGN_256 alignas(256)
#endif
struct ALIGN_256 MoeCommFifoConnInfo
{
volatile uint64_t head; // write position
volatile uint64_t tail; // read position
};
constexpr int WARP_SIZE = 32;
constexpr uint32_t WARP_MASK = 0xffffffff;
constexpr int RECV_FIFO_DEPTH = 8;
constexpr int RECV_FIFO_ENTRY_BYTES = 256 * 1024;
constexpr int RECV_FIFO_ENTRY_U64 = RECV_FIFO_ENTRY_BYTES / sizeof(uint64_t);
constexpr int RECV_FIFO_TOTAL_BYTES = RECV_FIFO_DEPTH * RECV_FIFO_ENTRY_BYTES;
constexpr int RECV_FIFO_TOTAL_U64 = RECV_FIFO_TOTAL_BYTES / sizeof(uint64_t);
class AllToAllChannelCommunicatorBase
{
public:
static constexpr int GROUP_COUNT_PER_BLOCK = 8;
static_assert(GROUP_COUNT_PER_BLOCK <= 8, "GROUP_COUNT_PER_BLOCK must be less than or equal to 8");
static constexpr int WARP_PER_GROUP = 2;
static constexpr int U64_DATA_REG_PER_THREAD = 8;
// A packet is a warp-sized chunk of data that is sent or received in one go,
// but may be split into multiple 64-bit registers, the number of which is U64_DATA_REG_PER_THREAD.
static constexpr int PACKET_SIZE_IN_U64 = WARP_SIZE * U64_DATA_REG_PER_THREAD;
static constexpr int PACKET_SIZE_IN_BYTES = PACKET_SIZE_IN_U64 * sizeof(uint64_t);
static constexpr int DATA_PAYLOAD_SIZE_PER_PACKET_IN_U64 = (WARP_SIZE - 2) * U64_DATA_REG_PER_THREAD;
static constexpr int DATA_PAYLOAD_SIZE_PER_PACKET = DATA_PAYLOAD_SIZE_PER_PACKET_IN_U64 * sizeof(uint64_t);
static constexpr int U64_ELT_COUNT_PER_PACKET = PACKET_SIZE_IN_BYTES / sizeof(uint64_t);
static constexpr int PACKET_COUNT_PER_FIFO_ENTRY = RECV_FIFO_ENTRY_BYTES / PACKET_SIZE_IN_BYTES;
static constexpr int GROUP_MAX_INDICE_COUNT
= RECV_FIFO_ENTRY_BYTES / sizeof(uint64_t) / (WARP_SIZE * U64_DATA_REG_PER_THREAD);
struct GroupSharedBuffer
{
int groupIndiceBuffer[GROUP_MAX_INDICE_COUNT];
int groupStartIndice;
int groupEndIndice;
};
static void setMaxUsableSmCount(int maxUsableSmCount)
{
TLLM_CHECK_WITH_INFO(AllToAllChannelCommunicatorBase::maxSmCountUsed == false,
"setMaxUsableSmCount can be called only before it is used");
int smCount = tensorrt_llm::common::getMultiProcessorCount();
if (maxUsableSmCount > smCount)
{
TLLM_LOG_WARNING("setMaxUsableSmCount, maxUsableSmCount=%d, larger than smCount=%d, using smCount instead",
maxUsableSmCount, smCount);
maxUsableSmCount = smCount;
}
AllToAllChannelCommunicatorBase::maxSmCount = maxUsableSmCount;
}
static int getMaxUsableSmCount()
{
AllToAllChannelCommunicatorBase::maxSmCountUsed = true;
if (AllToAllChannelCommunicatorBase::maxSmCount == -1)
{
int smCount = tensorrt_llm::common::getMultiProcessorCount();
AllToAllChannelCommunicatorBase::maxSmCount = smCount;
}
return AllToAllChannelCommunicatorBase::maxSmCount;
}
static int computeMoeCommChannelCount(int epSize)
{
int smCount = getMaxUsableSmCount();
int blockCountPerChannel = (epSize + GROUP_COUNT_PER_BLOCK - 1) / GROUP_COUNT_PER_BLOCK;
blockCountPerChannel *= 2; // for send and recv
TLLM_CHECK_WITH_INFO(
blockCountPerChannel <= smCount, "GPU should support at least one channel, usableSmCount=%d", smCount);
int preferredChannel = smCount / 2 / blockCountPerChannel; // use half of the SMs for communication
int channelCount = std::max(preferredChannel, 1); // at least one channel
return channelCount;
}
static int getMoeCommChannelCount(int epSize)
{
static std::map<int, int> channelCountMap{};
auto iter = channelCountMap.find(epSize);
if (iter == channelCountMap.end())
{
auto channelCount = AllToAllChannelCommunicatorBase::computeMoeCommChannelCount(epSize);
channelCountMap[epSize] = channelCount;
return channelCount;
}
return iter->second;
}
static dim3 getLaunchBlockDim()
{
return dim3(WARP_SIZE * WARP_PER_GROUP, GROUP_COUNT_PER_BLOCK);
}
static dim3 getLaunchGridDim(int epSize)
{
int channelCount = AllToAllChannelCommunicatorBase::getMoeCommChannelCount(epSize);
return dim3((epSize + GROUP_COUNT_PER_BLOCK - 1) / GROUP_COUNT_PER_BLOCK, channelCount, 2);
}
protected:
static int maxSmCount;
static bool maxSmCountUsed;
};
inline size_t getMoeCommWorkspaceSize(int epSize)
{
int channelCount = AllToAllChannelCommunicatorBase::getMoeCommChannelCount(epSize);
return RECV_FIFO_TOTAL_BYTES * epSize * channelCount + sizeof(MoeCommFifoConnInfo) * epSize * channelCount;
}
struct MoeEpWorldInfo
{
int epSize;
int epRank;
};
struct MoeExpertParallelInfo
{
int expertCount = -1;
int topK = 1;
};
struct SendRecvDataInfo
{
int vectorSizeInU64;
// pre-computed at host side for GPU kernel
int dataPacketCountPerVector;
int vectorCountPerFifoEntry;
void ComputeDataPacketCountPerVector()
{
dataPacketCountPerVector
= (vectorSizeInU64 * sizeof(uint64_t) + AllToAllChannelCommunicatorBase::DATA_PAYLOAD_SIZE_PER_PACKET - 1)
/ AllToAllChannelCommunicatorBase::DATA_PAYLOAD_SIZE_PER_PACKET;
}
void ComputeVectorCountPerFifoEntry()
{
ComputeDataPacketCountPerVector();
vectorCountPerFifoEntry
= AllToAllChannelCommunicatorBase::PACKET_COUNT_PER_FIFO_ENTRY / dataPacketCountPerVector;
}
void DoPreCompute()
{
ComputeDataPacketCountPerVector();
ComputeVectorCountPerFifoEntry();
assert(vectorCountPerFifoEntry <= AllToAllChannelCommunicatorBase::GROUP_MAX_INDICE_COUNT);
}
};
// struct holding Send/Recv data pointer and its displacement information.
struct SendRecvDispls
{
uint64_t* dataPtr;
int const* rankCountCumSum; // length = epSize
int const* rankLocalIndices; // length = rankCountCumSum[epRank] - rankCountCumSum[epRank - 1] if epRank > 0 else
// rankCountCumSum[epRank]
int vectorStrideInU64;
#ifdef __CUDACC__
__inline__ __device__ int getCount(int rank) const
{
return rank == 0 ? rankCountCumSum[rank] : rankCountCumSum[rank] - rankCountCumSum[rank - 1];
}
__inline__ __device__ int getRankStart(int rank) const
{
return rank == 0 ? 0 : rankCountCumSum[rank - 1];
}
__inline__ __device__ int getRealVectorIndice(int globalVectorIndex) const
{
return rankLocalIndices[globalVectorIndex];
}
__inline__ __device__ uint64_t* getVectorDataPtr(int realVectorIndex) const
{
return dataPtr + realVectorIndex * vectorStrideInU64;
}
#endif
};
struct MoeCommWorkspace
{
uint64_t* workspacePtr;
size_t rankStrideInU64;
#ifdef __CUDACC__
__inline__ __device__ uint64_t* getFifoBasePtr(
bool isSender, int epRank, int peerRank, int channel, int channelCount) const
{
// the FIFO itself is on the receiver's side.
if (isSender)
{
return workspacePtr + peerRank * rankStrideInU64 + (epRank * channelCount + channel) * RECV_FIFO_TOTAL_U64;
}
else
{
return workspacePtr + epRank * rankStrideInU64 + (peerRank * channelCount + channel) * RECV_FIFO_TOTAL_U64;
}
}
__inline__ __device__ MoeCommFifoConnInfo* getFifoConnInfo(
bool isSender, int epRank, int peerRank, int channel, int epSize, int channelCount) const
{
// the fifoInfo is on the sender's side.
uint64_t* fifoInfoPtrU64 = workspacePtr + RECV_FIFO_TOTAL_U64 * channelCount * epSize;
int strideIndice = isSender ? epRank : peerRank;
int fifoInfoIndice = isSender ? peerRank : epRank;
fifoInfoPtrU64 += strideIndice * rankStrideInU64;
MoeCommFifoConnInfo* fifoInfoPtr = (MoeCommFifoConnInfo*) fifoInfoPtrU64;
return fifoInfoPtr + fifoInfoIndice * channelCount + channel;
}
#endif
};
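// Hedged sketch (not part of this diff): the workspace descriptor is just the base pointer of
// the gathered all-rank workspace plus the per-rank stride in uint64_t words, matching how the
// torch binding later in this diff uses allWorkspaces.data_ptr<uint64_t>() and
// allWorkspaces.stride(0).
inline MoeCommWorkspace makeMoeCommWorkspaceExample(uint64_t* allWorkspaces, size_t rankStrideInU64)
{
    MoeCommWorkspace workspace;
    workspace.workspacePtr = allWorkspaces;
    workspace.rankStrideInU64 = rankStrideInU64;
    return workspace;
}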
void setMaxUsableSmCount(int smCount);
void moeAllToAll(MoeEpWorldInfo worldInfo, SendRecvDataInfo sendRecvDataInfo, SendRecvDispls sendDispls,
SendRecvDispls recvDispls, MoeCommWorkspace workspace, cudaStream_t stream);
void moeAllToAllPrepareIndices(MoeEpWorldInfo worldInfo, MoeExpertParallelInfo expertParallelInfo,
int maxTokenCountPerRank, int const* gatheredTargetRankIds, int const* realRankTokenCountCumSum,
int* localGatheredIndices, // indices of gatheredTargetRankIds that have the local rank in their topK
int* sendRankCountCumSum, int* sendRankLocalIndices, int* recvRankCountCumSum, int* recvRankLocalIndices,
// the rankCountCumSum of combineRecv should be the same as sendRankCountCumSum
int* backwardRecvRankLocalIndices, cudaStream_t stream);
void moeLocalGather(MoeEpWorldInfo worldInfo, MoeExpertParallelInfo expertParallelInfo, int maxTokenCountPerRank,
int localMaxTokenCount, int const* recvRankCountCumSum, int const* localGatherIndices, int const* gatheredExpertIds,
float const* gatheredScales, int* localExpertIds, float* localScales, cudaStream_t stream);
} // namespace tensorrt_llm::kernels

View File

@ -0,0 +1,47 @@
/*
* Copyright (c) 2019-2025, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include <stdint.h>
namespace tensorrt_llm
{
namespace kernels
{
#ifdef __CUDACC__
#define ALIGN_256 __align__(256)
#else
#define ALIGN_256 alignas(256)
#endif
constexpr int WARP_SIZE = 32;
constexpr uint32_t WARP_MASK = 0xffffffff;
struct MoeEpWorldInfo
{
int epSize;
int epRank;
};
struct MoeExpertParallelInfo
{
int expertCount = -1;
int topK = 1;
};
} // namespace kernels
} // namespace tensorrt_llm

View File

@ -49,86 +49,6 @@ __device__ __forceinline__ int ld_acquire_sys_global_int(int volatile* ptr)
return ret;
}
class StepCommunicatorBase
{
public:
static constexpr int META_SIZE = sizeof(MoeCommFifoConnInfo);
__device__ __inline__ StepCommunicatorBase(MoeCommFifoConnInfo* fifoConnInfo)
: fifoConnInfo(fifoConnInfo)
, localCachedHead(0)
, localCachedTail(0)
{
}
__forceinline__ __device__ void reset()
{
fifoConnInfo->head = 0;
fifoConnInfo->tail = 0;
}
__forceinline__ __device__ void releaseSendStep()
{
localCachedHead += 1;
st_release_sys_global(&(fifoConnInfo->head), uint64_t(localCachedHead));
}
__forceinline__ __device__ void releaseRecvStep()
{
localCachedTail += 1;
st_release_sys_global(&(fifoConnInfo->tail), uint64_t(localCachedTail));
}
__forceinline__ __device__ uint64_t acquireTail()
{
uint64_t tail = ld_acquire_sys_global(&(fifoConnInfo->tail));
localCachedTail = tail;
return tail;
}
__forceinline__ __device__ uint64_t acquireHead()
{
uint64_t head = ld_acquire_sys_global(&(fifoConnInfo->head));
localCachedHead = head;
return head;
}
__forceinline__ __device__ int acquireNewSendStep()
{
int64_t tail;
do
{
tail = acquireTail();
} while (localCachedHead >= tail + STEP_DEPTH);
// e.g. depth = 2, head = 1, tail = 0: ok
//      depth = 2, head = 2, tail = 0: should wait
return localCachedHead % STEP_DEPTH;
}
__forceinline__ __device__ int acquireNewRecvStep()
{
int64_t head = 0;
do
{
head = acquireHead();
} while (localCachedTail >= head);
return localCachedTail % STEP_DEPTH;
}
public:
MoeCommFifoConnInfo* fifoConnInfo;
uint64_t localCachedHead;
uint64_t localCachedTail;
int rank;
int targetRank;
};
// Use MoeCommFifoConnInfo as the medium to transfer counter values.
// Each value is published through the values[] array at its own index,
// offset by one so that a stored zero means "not yet written".
class CounterCommunicator
{
public:
@ -137,23 +57,23 @@ public:
{
}
__forceinline__ __device__ void releaseValue(uint64_t value)
__forceinline__ __device__ void releaseValue(uint64_t value, int index)
{
// Offset by one to avoid blocking on 0
st_release_sys_global(&(fifoConnInfo->count), value + 1);
fifoConnInfo->values[index] = value + 1;
}
__forceinline__ __device__ uint64_t acquireValue()
__forceinline__ __device__ uint64_t acquireValue(int index)
{
uint64_t localCount = 0;
uint64_t localValue = 0;
do
{
localCount = ld_acquire_sys_global(&(fifoConnInfo->count));
} while (localCount == 0);
localValue = fifoConnInfo->values[index];
} while (localValue == 0);
fifoConnInfo->count = 0; // reset the count
fifoConnInfo->values[index] = 0; // reset the value
return localCount - 1;
return localValue - 1;
}
protected:
@ -161,15 +81,16 @@ protected:
};
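// Hedged device-side sketch (not part of this diff) of the pairing contract used by the kernels
// below: the sender publishes a value at a given index and the matching receiver polls the same
// index. In this file, index 0 carries the per-rank token count and indices 1..expertCount carry
// the expert statics.
__device__ __forceinline__ uint64_t exchangeValueExample(
    MoeCommWorkspace workspace, int epRank, int peerRank, int epSize, bool isSender, uint64_t sendValue)
{
    CounterCommunicator counter(workspace.getFifoConnInfo(isSender, epRank, peerRank, 0, epSize, 1));
    if (isSender)
    {
        counter.releaseValue(sendValue, /*index=*/0);
        return sendValue;
    }
    return counter.acquireValue(/*index=*/0);
}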
template <int kThreadsGroupSize>
__device__ __forceinline__ void computeCountAndSend(int* experts, int tokenCount, int* sharedSendRecvRankCount,
int* sendCounts, int* sendIndiceWorkspace, int* backwardIndiceWorkspace, MoeCommWorkspace workspace,
int maxTokenCountPerRank, int expertCount, int topK, int epRank, int epSize)
__device__ __forceinline__ void computeCountAndSendStatics(int* experts, int tokenCount, int* sharedSendRecvRankCount,
int* sendCounts, int* sendIndiceWorkspace, int* backwardIndiceWorkspace, int* expertStatics,
MoeCommWorkspace workspace, int maxTokenCountPerRank, int slotCount, int expertCount, int topK, int epRank,
int epSize)
{
cg::thread_block_tile<kThreadsGroupSize> tile = cg::tiled_partition<kThreadsGroupSize>(cg::this_thread_block());
int laneInTile = tile.thread_rank();
int tileId = threadIdx.x / kThreadsGroupSize;
int tileCountPerBlock = blockDim.x / kThreadsGroupSize;
int expertCountPerRank = expertCount / epSize;
int expertCountPerRank = slotCount / epSize;
if (threadIdx.x == 0)
{
*sharedSendRecvRankCount = 0;
@ -201,18 +122,24 @@ __device__ __forceinline__ void computeCountAndSend(int* experts, int tokenCount
tile.sync();
}
__syncthreads();
if (threadIdx.x == 0)
CounterCommunicator counter(workspace.getFifoConnInfo(true, epRank, targetRankId, 0, epSize, 1));
int communicationCount = expertStatics == nullptr ? 1 : expertCount + 1;
for (int i = threadIdx.x; i < communicationCount; i += blockDim.x)
{
CounterCommunicator counter(workspace.getFifoConnInfo(true, epRank, targetRankId, 0, epSize, 1));
int count = *(sharedSendRecvRankCount);
// printf("sendRecvCount: %d, rankId: %d, targetRankId: %d\n", count, rankId, targetRankId);
counter.releaseValue(uint64_t(count));
*(sendCounts + targetRankId) = count;
int value = i == 0 ? *(sharedSendRecvRankCount) : *(expertStatics + i - 1);
counter.releaseValue(value, i);
if (i == 0)
{
*(sendCounts + targetRankId) = value;
}
}
}
__device__ __forceinline__ void recvCount(int* recvIndiceWorkspace, int* recvCounts, int* sharedCountsBase,
MoeCommWorkspace workspace, int maxTokenCountPerRank, int rankId, int rankCount)
__device__ __forceinline__ void recvCountAndStatics(int* recvIndiceWorkspace, int* recvCounts, int* sharedCountsBase,
int* gatheredExpertStatics, MoeCommWorkspace workspace, int expertCount, int maxTokenCountPerRank, int rankId,
int rankCount)
{
int rankOffset = threadIdx.x / THREADS_PER_PIPELINE;
if (rankOffset >= PIPELINE_PER_CTA)
@ -229,18 +156,25 @@ __device__ __forceinline__ void recvCount(int* recvIndiceWorkspace, int* recvCou
cg::thread_block_tile<THREADS_PER_PIPELINE> rankTile
= cg::tiled_partition<THREADS_PER_PIPELINE>(cg::this_thread_block());
int* localRecvIndice = recvIndiceWorkspace + targetRankId * maxTokenCountPerRank;
int rankRecvCount;
if (rankTile.thread_rank() == 0)
CounterCommunicator counter(workspace.getFifoConnInfo(false, rankId, targetRankId, 0, rankCount, 1));
int communicationCount = gatheredExpertStatics == nullptr ? 1 : expertCount + 1;
for (int i = rankTile.thread_rank(); i < communicationCount; i += THREADS_PER_PIPELINE)
{
CounterCommunicator counter(workspace.getFifoConnInfo(false, rankId, targetRankId, 0, rankCount, 1));
rankRecvCount = int(counter.acquireValue());
// printf("rankRecvCount: %d, rankId: %d, targetRankId: %d\n", rankRecvCount, rankId, targetRankId);
*(recvCounts + targetRankId) = rankRecvCount;
*(sharedCountsThisRank) = rankRecvCount;
int recvValue = counter.acquireValue(i);
if (i == 0)
{
*(recvCounts + targetRankId) = recvValue;
*(sharedCountsThisRank) = recvValue;
}
else
{
*(gatheredExpertStatics + targetRankId * expertCount + i - 1) = recvValue;
}
}
rankTile.sync();
rankRecvCount = *(sharedCountsThisRank);
int rankRecvCount = *(sharedCountsThisRank);
for (int tokenId = unitId; tokenId < rankRecvCount; tokenId += UNIT_PER_PIPELINE)
{
*(localRecvIndice + tokenId) = tokenId;
@ -249,20 +183,22 @@ __device__ __forceinline__ void recvCount(int* recvIndiceWorkspace, int* recvCou
template <int kThreadsGroupSize>
__global__ void computeCountAndIndiceDevice(int* experts, int* sendCounts, int* recvCounts, int* sendIndiceWorkspace,
int* backwardIndiceWorkspace, int* recvIndiceWorkspace, MoeCommWorkspace workspace, int tokenCount,
int maxTokenCountPerRank, int topK, int expertCount, int rankId, int rankCount)
int* backwardIndiceWorkspace, int* recvIndiceWorkspace, int* expertStatics, int* gatheredExpertStatics,
MoeCommWorkspace workspace, int tokenCount, int maxTokenCountPerRank, int topK, int slotCount, int expertCount,
int rankId, int rankCount)
{
__shared__ int sharedCounts[PIPELINE_PER_CTA];
bool isSender = blockIdx.x < rankCount;
if (isSender)
{
computeCountAndSend<kThreadsGroupSize>(experts, tokenCount, &sharedCounts[0], sendCounts, sendIndiceWorkspace,
backwardIndiceWorkspace, workspace, maxTokenCountPerRank, expertCount, topK, rankId, rankCount);
computeCountAndSendStatics<kThreadsGroupSize>(experts, tokenCount, &sharedCounts[0], sendCounts,
sendIndiceWorkspace, backwardIndiceWorkspace, expertStatics, workspace, maxTokenCountPerRank, slotCount,
expertCount, topK, rankId, rankCount);
}
else
{
recvCount(
recvIndiceWorkspace, recvCounts, &sharedCounts[0], workspace, maxTokenCountPerRank, rankId, rankCount);
recvCountAndStatics(recvIndiceWorkspace, recvCounts, &sharedCounts[0], gatheredExpertStatics, workspace,
expertCount, maxTokenCountPerRank, rankId, rankCount);
}
}
@ -307,259 +243,12 @@ __global__ void computeCumsumDevice(int* sendCountsCumsum, int* recvCountsCumsum
int tid = threadIdx.x;
int threadData = tid < rankCount ? inputOutputPtr[tid] : 0;
int count = threadData;
__syncthreads();
BlockScan(temp_storage).InclusiveSum(threadData, threadData);
if (tid < rankCount)
{
inputOutputPtr[tid] = threadData;
// printf("cumsum, send? : %d, rankId:%d, tid:%d, threadData:%d, count:%d\n", blockIdx.x == 0, rankId, tid,
// threadData, count);
}
}
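// Worked example (illustration only): with rankCount = 4 and per-rank counts [3, 1, 4, 2] in
// inputOutputPtr, the in-place InclusiveSum above rewrites the buffer to [3, 4, 8, 10], i.e. the
// cumulative-sum form consumed as sendCountsCumsum / recvCountsCumsum by the kernels below.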
template <typename PipelineConfig>
class PacketPipeline
{
public:
__device__ __inline__ PacketPipeline(
void* bufferBase, StepCommunicatorBase* stepCommunicator, int* sharedNewStepPtr, bool isSender)
: bufferBase(bufferBase)
, stepCommunicator(stepCommunicator)
, shared_new_step(sharedNewStepPtr)
{
step = 0;
needRelease = false;
packetId = isSender ? 0 : PipelineConfig::PACKET_PER_STEP - 1;
}
__device__ __forceinline__ void* getFirstSendPacket()
{
return bufferBase;
}
__device__ __inline__ void* finishSendPacket(bool acquireNewStep)
{
packetId++;
if (packetId < PipelineConfig::PACKET_PER_STEP)
{
return acquireNewStep ? bufferBase + step * PipelineConfig::PACKET_PER_STEP * PipelineConfig::PACKET_SIZE
+ packetId * PipelineConfig::PACKET_SIZE
: nullptr;
}
__syncthreads();
if (threadIdx.x == 0)
{
stepCommunicator->releaseSendStep();
if (acquireNewStep)
{
step = stepCommunicator->acquireNewSendStep();
*(shared_new_step) = step;
}
}
__syncthreads();
if (acquireNewStep)
{
step = *(shared_new_step);
packetId = 0;
return bufferBase + step * PipelineConfig::PACKET_SIZE * PipelineConfig::PACKET_PER_STEP;
}
return nullptr;
}
__device__ __forceinline__ void* sendFinalize()
{
if (packetId > 0 && threadIdx.x == 0)
{
stepCommunicator->releaseSendStep();
}
}
__device__ __inline__ void* getNewRecvPacket()
{
packetId++;
if (packetId < PipelineConfig::PACKET_PER_STEP)
{
return bufferBase + step * PipelineConfig::PACKET_PER_STEP * PipelineConfig::PACKET_SIZE
+ packetId * PipelineConfig::PACKET_SIZE;
}
__syncthreads();
if (threadIdx.x == 0)
{
if (needRelease)
{
stepCommunicator->releaseRecvStep();
}
step = stepCommunicator->acquireNewRecvStep();
needRelease = true;
*(shared_new_step) = step;
}
__syncthreads();
packetId = 0;
step = *(shared_new_step);
void* packetPtr = bufferBase + step * PipelineConfig::PACKET_SIZE * PipelineConfig::PACKET_PER_STEP;
return packetPtr;
}
__device__ __forceinline__ void reset()
{
if (threadIdx.x == 0)
{
stepCommunicator->reset();
}
}
void* bufferBase;
StepCommunicatorBase* stepCommunicator;
int step;
int packetId;
bool needRelease;
int* shared_new_step;
};
template <typename PipelineConfig, typename ExpertType, typename ScaleType>
__global__ void allToAllMetadataDevice(int* sendExperts, int* recvExperts, float* sendScales, float* recvScales,
int* localExpertStatics, int* gatheredExpertStatics, MoeCommWorkspace workspace, int* sendCountsCumsum,
int* localSendIndice, int* recvCountsCumsum, int* localRecvIndice, int tokenCount, int maxTokenCountPerRank,
int topK, int expertCount, int slotCount, int rankId, int rankCount)
{
bool isSender = (blockIdx.y == 0);
int targetRankId = blockIdx.x;
int slotCountPerRank = slotCount / rankCount;
int groupSize = topK / PipelineConfig::UNIT_SIZE;
__shared__ int sharedNewStep;
__align__(16) int experts[PipelineConfig::UNIT_SIZE];
__align__(16) float scales[PipelineConfig::UNIT_SIZE];
uint8_t* bufferBase = (uint8_t*) (workspace.getFifoBasePtr(isSender, rankId, targetRankId, 0, 1));
StepCommunicatorBase stepCommunicator(workspace.getFifoConnInfo(isSender, rankId, targetRankId, 0, rankCount, 1));
PacketPipeline<PipelineConfig> pipeline(bufferBase, &stepCommunicator, &sharedNewStep, isSender);
if (isSender)
{
int baseCumsum = targetRankId == 0 ? 0 : *(sendCountsCumsum + targetRankId - 1);
int sendTokenCount = *(sendCountsCumsum + targetRankId) - baseCumsum;
int unitCount = sendTokenCount * topK / PipelineConfig::UNIT_SIZE;
void* packPtr = pipeline.getFirstSendPacket();
int indexBase = 0;
int staticCopyBase = 0;
bool acquireNewStep = unitCount > 0 || (localExpertStatics != nullptr && expertCount > 0);
while (acquireNewStep)
{
if (threadIdx.x < UNIT_PER_ITER)
{
int index = indexBase + threadIdx.x;
int groupId = index % groupSize;
if (index < unitCount)
{
int tokenId = *(localSendIndice + maxTokenCountPerRank * targetRankId + (index / groupSize));
*((ExpertType*) (experts))
= *(ExpertType*) (sendExperts + tokenId * topK + groupId * PipelineConfig::UNIT_SIZE);
#pragma unroll
for (int j = 0; j < PipelineConfig::UNIT_SIZE; j++)
{
int expertId = experts[j];
if (expertId / slotCountPerRank != targetRankId)
{
experts[j] = slotCount;
}
}
int* expertsPtr = (int*) (packPtr) + threadIdx.x * PipelineConfig::UNIT_SIZE;
*((ExpertType*) (expertsPtr)) = *((ExpertType*) (experts));
if (sendScales != nullptr)
{
*((ScaleType*) (scales))
= *(ScaleType*) (sendScales + tokenId * topK + groupId * PipelineConfig::UNIT_SIZE);
float* scaleBasePtr = (float*) (packPtr + PipelineConfig::SCALE_OFFSET);
float* scalesPtr = (float*) (scaleBasePtr) + threadIdx.x * PipelineConfig::UNIT_SIZE;
*((ScaleType*) (scalesPtr)) = *((ScaleType*) (scales));
}
}
}
else if (localExpertStatics != nullptr)
{
int staticCopyIdx = threadIdx.x - UNIT_PER_ITER;
if (staticCopyBase + staticCopyIdx * 4 < expertCount)
{
int4* staticBasePtr = (int4*) (packPtr + PipelineConfig::STATIC_COPY_OFFSET);
int4 staticData = *(int4*) (localExpertStatics + staticCopyBase + staticCopyIdx * 4);
*(staticBasePtr + staticCopyIdx) = staticData;
}
}
indexBase += UNIT_PER_ITER;
staticCopyBase += STATIC_COPY_PER_ITER * 4;
acquireNewStep = indexBase < unitCount || staticCopyBase < expertCount;
packPtr = pipeline.finishSendPacket(acquireNewStep);
}
pipeline.sendFinalize();
}
else
{
int baseCumsum = targetRankId == 0 ? 0 : *(recvCountsCumsum + targetRankId - 1);
int recvTokenCount = *(recvCountsCumsum + targetRankId) - baseCumsum;
int recvUnitCount = recvTokenCount * groupSize;
int unitIdBase = 0;
int staticCopyBase = 0;
while (unitIdBase < recvUnitCount || (localExpertStatics != nullptr && staticCopyBase < expertCount))
{
void* packetPtr = pipeline.getNewRecvPacket();
int packetUnitCount
= unitIdBase + UNIT_PER_ITER < recvUnitCount ? UNIT_PER_ITER : recvUnitCount - unitIdBase;
packetUnitCount = max(packetUnitCount, 0);
if (threadIdx.x < UNIT_PER_ITER)
{
if (threadIdx.x < packetUnitCount)
{
int tokenId = baseCumsum + (unitIdBase + threadIdx.x) / groupSize;
int groupId = (unitIdBase + threadIdx.x) % groupSize;
int* expertsPtr = (int*) (packetPtr) + threadIdx.x * PipelineConfig::UNIT_SIZE;
*((ExpertType*) (experts)) = *((ExpertType*) (expertsPtr));
ExpertType* dstExpertsPtr
= (ExpertType*) (recvExperts + tokenId * topK + groupId * PipelineConfig::UNIT_SIZE);
*dstExpertsPtr = *((ExpertType*) (experts));
if (recvScales != nullptr)
{
float* scaleBasePtr = (float*) (packetPtr + PipelineConfig::SCALE_OFFSET);
float* scalesPtr = scaleBasePtr + threadIdx.x * PipelineConfig::UNIT_SIZE;
*((ScaleType*) (scales)) = *((ScaleType*) (scalesPtr));
ScaleType* dstScalesPtr
= (ScaleType*) (recvScales + tokenId * topK + groupId * PipelineConfig::UNIT_SIZE);
*dstScalesPtr = *((ScaleType*) (scales));
}
}
}
else if (localExpertStatics != nullptr)
{
int staticCopyIdx = threadIdx.x - UNIT_PER_ITER;
if (staticCopyBase + staticCopyIdx * 4 < expertCount)
{
int4* staticBasePtr = (int4*) (packetPtr + PipelineConfig::STATIC_COPY_OFFSET);
int4 staticData = *(staticBasePtr + staticCopyIdx);
*(int4*) (gatheredExpertStatics + targetRankId * expertCount + staticCopyBase + staticCopyIdx * 4)
= staticData;
}
}
unitIdBase += packetUnitCount;
staticCopyBase += STATIC_COPY_PER_ITER * 4;
}
pipeline.reset();
}
}
@ -576,8 +265,9 @@ __global__ void memsetExpertIdsDevice(
}
void computeCountAndIndice(int* experts, int* sendCounts, int* recvCounts, int* sendIndiceWorkspace,
int* backwardIndiceWorkspace, int* recvIndiceWorkspace, MoeCommWorkspace workspace, int tokenCount,
int maxTokenCountPerRank, int topK, int expert_count, int rankId, int rankCount, cudaStream_t stream)
int* backwardIndiceWorkspace, int* recvIndiceWorkspace, int* expertStatics, int* gatheredExpertStatics,
MoeCommWorkspace workspace, int tokenCount, int maxTokenCountPerRank, int topK, int slotCount, int expertCount,
int rankId, int rankCount, cudaStream_t stream)
{
// first rankCount CTAs for count and send, then rankCount / PIPELINE_PER_CTA CTAs only for receive
int grid_x = rankCount + (rankCount + PIPELINE_PER_CTA - 1) / PIPELINE_PER_CTA;
@ -607,7 +297,8 @@ void computeCountAndIndice(int* experts, int* sendCounts, int* recvCounts, int*
kernelFn = computeCountAndIndiceDevice<2>;
}
kernelFn<<<grid, block, 0, stream>>>(experts, sendCounts, recvCounts, sendIndiceWorkspace, backwardIndiceWorkspace,
recvIndiceWorkspace, workspace, tokenCount, maxTokenCountPerRank, topK, expert_count, rankId, rankCount);
recvIndiceWorkspace, expertStatics, gatheredExpertStatics, workspace, tokenCount, maxTokenCountPerRank, topK,
slotCount, expertCount, rankId, rankCount);
}
void computeCumsum(int* sendCountsCumsum, int* recvCountsCumsum, int rankId, int rankCount, cudaStream_t stream)
@ -628,46 +319,18 @@ void moveIndice(int* sendCountsCumsum, int* recvCountsCumsum, int* sendIndice, i
backwardIndice, gatherBackwardIndice, recvIndice, gatherRecvIndice, maxTokenCountPerRank);
}
void allToAllMetadata(int* sendExperts, int* recvExperts, float* sendScales, float* recvScales, int* localExpertStatics,
int* gatheredExpertStatics, MoeCommWorkspace workspace, int* sendCountsCumsum, int* localSendIndice,
int* recvCountsCumsum, int* localRecvIndice, int tokenCount, int maxTokenCountPerRank, int topK, int expertCount,
int slotCount, int rankId, int rankCount, cudaStream_t stream)
void memsetExpertIds(int* expertIds, int* recvCountsCumsum, int maxTokenCountPerRank, int topK, int slotCount,
int rankCount, cudaStream_t stream)
{
int block_size = localExpertStatics == nullptr ? UNIT_PER_ITER : UNIT_PER_ITER + STATIC_COPY_PER_ITER;
dim3 block(block_size);
dim3 grid(rankCount, 2);
if (topK % 4 == 0)
{
using PipelineConfig = PipelineConfig<4, 16>;
static_assert(
PipelineConfig::PACKET_SIZE_IN_U64 * PipelineConfig::PACKET_PER_STEP * STEP_DEPTH <= FIFO_SIZE_IN_U64,
"FIFO size is too small");
allToAllMetadataDevice<PipelineConfig, int4, float4><<<grid, block, 0, stream>>>(sendExperts, recvExperts,
sendScales, recvScales, localExpertStatics, gatheredExpertStatics, workspace, sendCountsCumsum,
localSendIndice, recvCountsCumsum, localRecvIndice, tokenCount, maxTokenCountPerRank, topK, expertCount,
slotCount, rankId, rankCount);
}
else
{
using PipelineConfig = PipelineConfig<1, 64>;
static_assert(
PipelineConfig::PACKET_SIZE_IN_U64 * PipelineConfig::PACKET_PER_STEP * STEP_DEPTH <= FIFO_SIZE_IN_U64,
"FIFO size is too small");
allToAllMetadataDevice<PipelineConfig, int, float><<<grid, block, 0, stream>>>(sendExperts, recvExperts,
sendScales, recvScales, localExpertStatics, gatheredExpertStatics, workspace, sendCountsCumsum,
localSendIndice, recvCountsCumsum, localRecvIndice, tokenCount, maxTokenCountPerRank, topK, expertCount,
slotCount, rankId, rankCount);
}
int smCount = tensorrt_llm::common::getMultiProcessorCount();
memsetExpertIdsDevice<<<smCount, 256, 0, stream>>>(
recvExperts, recvCountsCumsum, maxTokenCountPerRank, topK, slotCount, rankCount);
int block_size = 256;
memsetExpertIdsDevice<<<smCount, block_size, 0, stream>>>(
expertIds, recvCountsCumsum, maxTokenCountPerRank, topK, slotCount, rankCount);
}
size_t getMoePrepareWorkspaceSize(int epSize)
{
return (FIFO_SIZE_IN_U64 * 8 + StepCommunicatorBase::META_SIZE) * epSize;
return sizeof(MoeCommFifoConnInfo) * epSize;
}
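// Hedged sizing note (not part of this diff): after this change the prepare workspace per rank is
// just one MoeCommFifoConnInfo slot per peer, e.g. epSize = 8 gives 8 * sizeof(MoeCommFifoConnInfo)
// bytes, since the metadata exchange no longer streams payload through a FIFO here.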
} // namespace moe_prepare

View File

@ -28,36 +28,11 @@ namespace tensorrt_llm::kernels
namespace moe_prepare
{
#define STEP_DEPTH 2
#define THREADS_PER_UNIT 1
#define UNIT_PER_PIPELINE 128
#define PIPELINE_PER_CTA 4
#define EXPERT_BYTES_PER_UNIT 32
#define SCALE_BYTES_PER_UNIT 32
#define UNIT_COUNT_PER_PACKET 1024
#define BYTES_COUNTER 8
#define CUMSUM_THREADS_PER_BLOCK 128
#define UNIT_PER_ITER 256
#define STATIC_COPY_PER_ITER 128
static constexpr int THREADS_PER_PIPELINE = THREADS_PER_UNIT * UNIT_PER_PIPELINE;
static constexpr int THREADS_PER_CTA = THREADS_PER_PIPELINE * PIPELINE_PER_CTA;
template <int UNIT_SIZE_INPUT, int PACKET_PER_STEP_INPUT>
struct PipelineConfig
{
static constexpr int UNIT_SIZE = UNIT_SIZE_INPUT;
static constexpr int PACKET_PER_STEP = PACKET_PER_STEP_INPUT;
static constexpr int UNIT_BYTES_SIZE = UNIT_SIZE * UNIT_PER_ITER * (sizeof(int) + sizeof(float));
static constexpr int SCALE_OFFSET = UNIT_SIZE * UNIT_PER_ITER * sizeof(int);
static constexpr int STATIC_COPY_OFFSET = UNIT_SIZE * UNIT_PER_ITER * (sizeof(int) + sizeof(float));
static constexpr int PACKET_SIZE = UNIT_BYTES_SIZE + STATIC_COPY_PER_ITER * 4 * sizeof(int);
static constexpr int PACKET_SIZE_IN_U64 = (PACKET_SIZE / 8);
};
// 1MB FIFO size
static constexpr int FIFO_SIZE_IN_U64 = 1024 * 1024 / 8;
static constexpr int THREADS_PER_PIPELINE = UNIT_PER_PIPELINE;
#ifdef __CUDACC__
#define ALIGN_256 __align__(256)
@ -67,9 +42,9 @@ static constexpr int FIFO_SIZE_IN_U64 = 1024 * 1024 / 8;
struct ALIGN_256 MoeCommFifoConnInfo
{
volatile uint64_t head; // write position
volatile uint64_t tail; // read position
volatile uint64_t count; // for counter
volatile uint64_t head; // write position
volatile uint64_t tail; // read position
int volatile values[512]; // value slots used by CounterCommunicator: index 0 carries the token count, later indices the expert statics
};
struct MoeCommWorkspace
@ -77,25 +52,11 @@ struct MoeCommWorkspace
uint64_t* workspacePtr;
size_t rankStrideInU64;
#ifdef __CUDACC__
__inline__ __device__ uint64_t* getFifoBasePtr(
bool isSender, int epRank, int peerRank, int channel, int channelCount) const
{
// fifo itself is in receiver's side.
if (isSender)
{
return workspacePtr + peerRank * rankStrideInU64 + (epRank * channelCount + channel) * FIFO_SIZE_IN_U64;
}
else
{
return workspacePtr + epRank * rankStrideInU64 + (peerRank * channelCount + channel) * FIFO_SIZE_IN_U64;
}
}
__inline__ __device__ MoeCommFifoConnInfo* getFifoConnInfo(
bool isSender, int epRank, int peerRank, int channel, int epSize, int channelCount) const
{
// fifoInfo is in sender's side.
uint64_t* fifoInfoPtrU64 = workspacePtr + FIFO_SIZE_IN_U64 * channelCount * epSize;
uint64_t* fifoInfoPtrU64 = workspacePtr;
int strideIndice = isSender ? epRank : peerRank;
int fifoInfoIndice = isSender ? peerRank : epRank;
fifoInfoPtrU64 += strideIndice * rankStrideInU64;
@ -108,8 +69,9 @@ struct MoeCommWorkspace
};
void computeCountAndIndice(int* experts, int* sendCounts, int* recvCounts, int* sendIndiceWorkspace,
int* backwardIndiceWorkspace, int* recvIndiceWorkspace, MoeCommWorkspace workspace, int tokenCount,
int maxTokenCountPerRank, int topK, int expert_count, int rankId, int rankCount, cudaStream_t stream);
int* backwardIndiceWorkspace, int* recvIndiceWorkspace, int* expertStatics, int* gatheredExpertStatics,
MoeCommWorkspace workspace, int tokenCount, int maxTokenCountPerRank, int topK, int slotCount, int expertCount,
int rankId, int rankCount, cudaStream_t stream);
void computeCumsum(int* sendCountsCumsum, int* recvCountsCumsum, int rankId, int rankCount, cudaStream_t stream);
@ -117,10 +79,8 @@ void moveIndice(int* sendCountsCumsum, int* recvCountsCumsum, int* sendIndice, i
int* backwardIndice, int* gatherBackwardIndice, int* recvIndice, int* gatherRecvIndice, int rankId, int rankCount,
int maxTokenCountPerRank, cudaStream_t stream);
void allToAllMetadata(int* sendExperts, int* recvExperts, float* sendScales, float* recvScales, int* localExpertStatics,
int* gatheredExpertStatics, MoeCommWorkspace workspace, int* sendCountsCumsum, int* localSendIndice,
int* recvCountsCumsum, int* localRecvIndice, int tokenCount, int maxTokenCountPerRank, int topK, int expertCount,
int slotCount, int rankId, int rankCount, cudaStream_t stream);
void memsetExpertIds(int* expertIds, int* recvCountsCumsum, int maxTokenCountPerRank, int topK, int slotCount,
int epSize, cudaStream_t stream);
size_t getMoePrepareWorkspaceSize(int epSize);

View File

@ -16,7 +16,7 @@
*/
#include "tensorrt_llm/common/opUtils.h"
#include "tensorrt_llm/kernels/moeCommKernels.h"
#include "tensorrt_llm/kernels/fusedMoeCommKernels.h"
#include "tensorrt_llm/kernels/moePrepareKernels.h"
#include "tensorrt_llm/runtime/torchUtils.h"
#include "tensorrt_llm/thop/thUtils.h"
@ -28,180 +28,104 @@
namespace torch_ext
{
std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>
moeCommPrepareIndicesOp(torch::Tensor gatheredTargetRankIds, c10::optional<torch::Tensor> realRankTokenCountCumSum,
int64_t maxTokenCountPerRank, int64_t expertCount, int64_t topK, int64_t epRank, int64_t epSize)
void setMoeCommFieldInfo(tensorrt_llm::kernels::MoeCommFieldInfo& fieldInfo, torch::Tensor const& tensor)
{
CHECK_INPUT(gatheredTargetRankIds, torch::kInt32);
TORCH_CHECK(gatheredTargetRankIds.dim() == 2, "gatheredTargetRankIds must be a 2D tensor");
TORCH_CHECK(gatheredTargetRankIds.size(1) == topK, "gatheredTargetRankIds must have topK columns");
int const* realRankTokenCountCumSumPtr = nullptr;
if (realRankTokenCountCumSum.has_value())
{
TORCH_CHECK(realRankTokenCountCumSum.value().dim() == 1, "realRankTokenCountCumSum must be a 1D tensor");
TORCH_CHECK(realRankTokenCountCumSum.value().dtype() == torch::kInt32,
"realRankTokenCountCumSum must be a int32 tensor");
TORCH_CHECK(
realRankTokenCountCumSum.value().size(0) == epSize, "realRankTokenCountCumSum must have epSize elements");
realRankTokenCountCumSumPtr = realRankTokenCountCumSum.value().data_ptr<int>();
}
else
{
TORCH_CHECK(gatheredTargetRankIds.size(0) == epSize * maxTokenCountPerRank,
"gatheredTargetRankIds should have shape (epSize * maxTokenCountPerRank, topK)");
}
TORCH_CHECK(maxTokenCountPerRank > 0, "maxTokenCountPerRank must be greater than 0");
TORCH_CHECK(expertCount > 0, "expertCount must be greater than 0");
TORCH_CHECK(topK > 0, "topK must be greater than 0");
TORCH_CHECK(topK <= expertCount, "topK must be less than or equal to expertCount");
TORCH_CHECK(epRank >= 0 && epRank < epSize, "epRank must be in the range [0, epSize)");
auto stream = at::cuda::getCurrentCUDAStream();
int maxSendRanksPerToken = std::max(epSize, topK);
torch::Tensor localGatherIndices
= torch::empty({maxTokenCountPerRank * epSize}, gatheredTargetRankIds.options().dtype(torch::kInt32));
torch::Tensor sendRankCountCumSum = torch::empty({epSize}, gatheredTargetRankIds.options().dtype(torch::kInt32));
torch::Tensor sendRankLocalIndices = torch::empty(
{maxTokenCountPerRank * maxSendRanksPerToken}, gatheredTargetRankIds.options().dtype(torch::kInt32));
torch::Tensor recvRankCountCumSum = torch::empty({epSize}, gatheredTargetRankIds.options().dtype(torch::kInt32));
torch::Tensor recvRankLocalIndices
= torch::empty({maxTokenCountPerRank * epSize}, gatheredTargetRankIds.options().dtype(torch::kInt32));
torch::Tensor backwardRecvRankLocalIndices = torch::empty(
{maxTokenCountPerRank * maxSendRanksPerToken}, gatheredTargetRankIds.options().dtype(torch::kInt32));
tensorrt_llm::kernels::MoeExpertParallelInfo expertParallelInfo;
expertParallelInfo.expertCount = expertCount;
expertParallelInfo.topK = topK;
tensorrt_llm::kernels::MoeEpWorldInfo worldInfo = {static_cast<int>(epSize), static_cast<int>(epRank)};
tensorrt_llm::kernels::moeAllToAllPrepareIndices(worldInfo, expertParallelInfo, maxTokenCountPerRank,
gatheredTargetRankIds.data_ptr<int>(), realRankTokenCountCumSumPtr, localGatherIndices.data_ptr<int>(),
sendRankCountCumSum.data_ptr<int>(), sendRankLocalIndices.data_ptr<int>(), recvRankCountCumSum.data_ptr<int>(),
recvRankLocalIndices.data_ptr<int>(), backwardRecvRankLocalIndices.data_ptr<int>(), stream);
return std::make_tuple(localGatherIndices, sendRankCountCumSum, sendRankLocalIndices, recvRankCountCumSum,
recvRankLocalIndices, backwardRecvRankLocalIndices);
TORCH_CHECK(tensor.dim() == 2, "tensor must be a 2D tensor");
int eltSize = tensor.dtype().itemsize();
fieldInfo.fillFieldInfo(static_cast<uint8_t*>(tensor.data_ptr()), eltSize, tensor.size(1), tensor.stride(0));
}
void moeLocalGatherOp(torch::Tensor recvRankCumSum, torch::Tensor localGatherIndices, torch::Tensor gatheredExpertIds,
c10::optional<torch::Tensor> gatheredScales, torch::Tensor localExpertIds, c10::optional<torch::Tensor> localScales,
int64_t maxTokenCountPerRank, int64_t expertCount, int64_t topK, int64_t epRank, int64_t epSize)
{
CHECK_INPUT(recvRankCumSum, torch::kInt32);
CHECK_INPUT(localGatherIndices, torch::kInt32);
CHECK_INPUT(gatheredExpertIds, torch::kInt32);
CHECK_INPUT(localExpertIds, torch::kInt32);
TORCH_CHECK(maxTokenCountPerRank > 0, "maxTokenCountPerRank must be greater than 0");
TORCH_CHECK(expertCount > 0, "expertCount must be greater than 0");
TORCH_CHECK(topK > 0, "topK must be greater than 0");
TORCH_CHECK(topK <= expertCount, "topK must be less than or equal to expertCount");
TORCH_CHECK(epRank >= 0 && epRank < epSize, "epRank must be in the range [0, epSize)");
TORCH_CHECK(recvRankCumSum.dim() == 1, "recvRankCumSum must be a 1D tensor");
TORCH_CHECK(recvRankCumSum.size(0) == epSize, "recvRankCumSum must have epSize elements");
TORCH_CHECK(localGatherIndices.dim() == 1, "localGatherIndices must be a 1D tensor");
TORCH_CHECK(gatheredExpertIds.dim() == 2, "gatheredExpertIds must be a 2D tensor");
TORCH_CHECK(localExpertIds.dim() == 2, "localExpertIds must be a 2D tensor");
TORCH_CHECK(gatheredExpertIds.size(1) == topK, "gatheredExpertIds must have topK columns");
TORCH_CHECK(localExpertIds.size(1) == topK, "localExpertIds must have topK columns");
int localMaxTokenCount = static_cast<int>(localGatherIndices.size(0));
TORCH_CHECK(localExpertIds.size(0) == localMaxTokenCount, "localExpertIds must have localMaxTokenCount rows");
TORCH_CHECK(gatheredScales.has_value() == localScales.has_value(),
"gatheredScales and localScales must be both valid or both invalid");
float const* gatheredScalesPtr = nullptr;
float* localScalesPtr = nullptr;
if (gatheredScales.has_value())
{
CHECK_INPUT(gatheredScales.value(), torch::kFloat32);
CHECK_INPUT(localScales.value(), torch::kFloat32);
TORCH_CHECK(gatheredScales->dim() == 2, "gatheredScales must be a 2D tensor");
TORCH_CHECK(gatheredScales->size(1) == topK, "gatheredScales must have topK columns");
TORCH_CHECK(localScales->dim() == 2, "localScales must be a 2D tensor");
TORCH_CHECK(localScales->size(1) == topK, "localScales must have topK columns");
TORCH_CHECK(localScales->size(0) == localMaxTokenCount, "localScales must have localMaxTokenCount rows");
gatheredScalesPtr = gatheredScales->data_ptr<float>();
localScalesPtr = localScales->data_ptr<float>();
}
auto stream = at::cuda::getCurrentCUDAStream();
tensorrt_llm::kernels::MoeExpertParallelInfo expertParallelInfo;
expertParallelInfo.expertCount = expertCount;
expertParallelInfo.topK = topK;
tensorrt_llm::kernels::MoeEpWorldInfo worldInfo = {static_cast<int>(epSize), static_cast<int>(epRank)};
tensorrt_llm::kernels::moeLocalGather(worldInfo, expertParallelInfo, maxTokenCountPerRank, localMaxTokenCount,
recvRankCumSum.data_ptr<int>(), localGatherIndices.data_ptr<int>(), gatheredExpertIds.data_ptr<int>(),
gatheredScalesPtr, localExpertIds.data_ptr<int>(), localScalesPtr, stream);
}
void moeCommOp(torch::Tensor input, torch::Tensor sendRankCumSum, torch::Tensor sendIndices, torch::Tensor output,
torch::Tensor recvRankCumSum, torch::Tensor recvIndices, torch::Tensor allWorkspaces, int64_t epRank,
int64_t epSize)
c10::List<torch::Tensor> moeCommOp(c10::List<torch::Tensor> inputs, torch::Tensor sendRankCumSum,
torch::Tensor sendIndiceTensor, torch::Tensor recvRankCumSum, torch::Tensor recvIndiceTensor,
torch::Tensor allWorkspaces, int64_t outputAllocationCount, int64_t epRank, int64_t epSize,
std::optional<c10::List<bool>> needZeroOutput = std::nullopt)
{
CHECK_INPUT(sendRankCumSum, torch::kInt32);
CHECK_INPUT(sendIndices, torch::kInt32);
CHECK_INPUT(sendIndiceTensor, torch::kInt32);
CHECK_INPUT(recvRankCumSum, torch::kInt32);
CHECK_INPUT(recvIndices, torch::kInt32);
// allWorkspaces is a uint64 tensor, but may not be contiguous
TORCH_CHECK(allWorkspaces.dtype() == torch::kUInt64, "allWorkspaces must be a uint64 tensor");
CHECK_INPUT(recvIndiceTensor, torch::kInt32);
TORCH_CHECK(input.dim() == 2, "input must be a 2D tensor");
TORCH_CHECK(output.dim() == 2, "output must be a 2D tensor");
TORCH_CHECK(sendRankCumSum.dim() == 1, "sendRankCumSum must be a 1D tensor");
TORCH_CHECK(sendIndices.dim() == 1, "sendIndices must be a 1D tensor");
TORCH_CHECK(sendIndiceTensor.dim() == 1, "sendIndices must be a 1D tensor");
TORCH_CHECK(recvRankCumSum.dim() == 1, "recvRankCumSum must be a 1D tensor");
TORCH_CHECK(recvIndices.dim() == 1, "recvIndices must be a 1D tensor");
TORCH_CHECK(recvIndiceTensor.dim() == 1, "recvIndices must be a 1D tensor");
TORCH_CHECK(allWorkspaces.dim() == 2, "allWorkspaces must be a 2D tensor");
TORCH_CHECK(input.size(1) == output.size(1), "input and output must have the same second dimension");
TORCH_CHECK(sendRankCumSum.size(0) == epSize, "sendRankCumSum must have epSize elements");
TORCH_CHECK(recvRankCumSum.size(0) == epSize, "recvRankCumSum must have epSize elements");
TORCH_CHECK(allWorkspaces.size(0) == epSize, "allWorkspaces must have epSize elements");
TORCH_CHECK(epRank >= 0 && epRank < epSize, "epRank must be in the range [0, epSize)");
TORCH_CHECK(!needZeroOutput.has_value() || needZeroOutput.value().size() == inputs.size(),
"needZeroOutput must have the same length as inputs");
c10::List<torch::Tensor> outputs;
tensorrt_llm::kernels::MoeEpWorldInfo worldInfo = {static_cast<int>(epSize), static_cast<int>(epRank)};
tensorrt_llm::kernels::SendRecvDataInfo sendRecvDataInfo;
tensorrt_llm::kernels::MoeEpWorldInfo epWorldInfo = {static_cast<int>(epSize), static_cast<int>(epRank)};
tensorrt_llm::kernels::FusedMoeWorldInfo worldInfo = {epWorldInfo};
size_t eltSize = input.dtype().itemsize();
size_t eltCountPerU64 = sizeof(uint64_t) / eltSize;
TORCH_CHECK(input.size(1) % (eltCountPerU64 * 2) == 0, "input.size(1) must be aligned to 16 bytes");
sendRecvDataInfo.vectorSizeInU64 = input.size(1) / eltCountPerU64;
sendRecvDataInfo.DoPreCompute();
tensorrt_llm::kernels::SendRecvIndices sendIndices, recvIndices;
sendIndices.rankCountCumSum = sendRankCumSum.data_ptr<int>();
sendIndices.rankLocalIndices = sendIndiceTensor.data_ptr<int>();
tensorrt_llm::kernels::SendRecvDispls sendDispls, recvDispls;
sendDispls.dataPtr = static_cast<uint64_t*>(input.data_ptr());
sendDispls.rankCountCumSum = sendRankCumSum.data_ptr<int>();
sendDispls.rankLocalIndices = sendIndices.data_ptr<int>();
sendDispls.vectorStrideInU64 = input.stride(0) / eltCountPerU64;
recvIndices.rankCountCumSum = recvRankCumSum.data_ptr<int>();
recvIndices.rankLocalIndices = recvIndiceTensor.data_ptr<int>();
recvDispls.dataPtr = static_cast<uint64_t*>(output.data_ptr());
recvDispls.rankCountCumSum = recvRankCumSum.data_ptr<int>();
recvDispls.rankLocalIndices = recvIndices.data_ptr<int>();
recvDispls.vectorStrideInU64 = output.stride(0) / eltCountPerU64;
int fieldCount = inputs.size();
TORCH_CHECK(fieldCount <= tensorrt_llm::kernels::MOE_COMM_FIELD_MAX_COUNT, "Number of fields (", fieldCount,
") exceeds maximum allowed (", tensorrt_llm::kernels::MOE_COMM_FIELD_MAX_COUNT, ")");
tensorrt_llm::kernels::FusedMoeFieldInfo sendFieldInfo, recvFieldInfo;
sendFieldInfo.isBasicInterleaved = false;
recvFieldInfo.isBasicInterleaved = false;
sendFieldInfo.fieldCount = fieldCount;
recvFieldInfo.fieldCount = fieldCount;
sendFieldInfo.expertScales = nullptr;
recvFieldInfo.expertScales = nullptr;
sendFieldInfo.tokenSelectedSlots = nullptr;
recvFieldInfo.tokenSelectedSlots = nullptr;
tensorrt_llm::kernels::MoeCommWorkspace workspace;
workspace.workspacePtr = allWorkspaces.data_ptr<uint64_t>();
workspace.rankStrideInU64 = allWorkspaces.stride(0);
for (int i = 0; i < fieldCount; i++)
{
torch::Tensor const& t = inputs[i];
setMoeCommFieldInfo(sendFieldInfo.fieldsInfo[i], t);
if (needZeroOutput.has_value() && needZeroOutput.value()[i])
{
outputs.push_back(torch::zeros({outputAllocationCount, t.size(1)}, t.options()));
}
else
{
outputs.push_back(torch::empty({outputAllocationCount, t.size(1)}, t.options()));
}
setMoeCommFieldInfo(recvFieldInfo.fieldsInfo[i], outputs[i]);
}
sendFieldInfo.fillFieldPlacementInfo(0, false);
recvFieldInfo.fillFieldPlacementInfo(0, false);
tensorrt_llm::kernels::FusedMoeCommKernelParam params;
params.worldInfo = worldInfo;
params.sendIndices = sendIndices;
params.recvIndices = recvIndices;
params.sendFieldInfo = sendFieldInfo;
params.recvFieldInfo = recvFieldInfo;
// Do not need expertParallelInfo for fused moe comm now
params.sendFieldInfo.fillMetaInfo(&(params.sendCommMeta), params.expertParallelInfo.topK, false, false);
params.recvFieldInfo.fillMetaInfo(&(params.recvCommMeta), params.expertParallelInfo.topK, false, false);
tensorrt_llm::kernels::FusedMoeWorkspace fusedMoeWorkspace;
tensorrt_llm::kernels::constructWorkspace(
&fusedMoeWorkspace, allWorkspaces.data_ptr<uint64_t>(), allWorkspaces.stride(0), epSize);
auto stream = at::cuda::getCurrentCUDAStream();
tensorrt_llm::kernels::moeAllToAll(worldInfo, sendRecvDataInfo, sendDispls, recvDispls, workspace, stream);
tensorrt_llm::kernels::moeAllToAll(params, fusedMoeWorkspace, stream);
return outputs;
}
int64_t getWorkspaceSizePerRank(int64_t epSize)
{
int epSize32 = static_cast<int>(epSize);
return tensorrt_llm::kernels::getMoeCommWorkspaceSize(epSize32);
return tensorrt_llm::kernels::getFusedMoeCommWorkspaceSize(epSize32);
}
void setMaxUsableSmCount(int64_t maxSmCount)
@ -215,15 +139,29 @@ int64_t getPrepareWorkspaceSizePerRank(int64_t epSize)
return tensorrt_llm::kernels::moe_prepare::getMoePrepareWorkspaceSize(epSize32);
}
std::tuple<torch::Tensor, c10::optional<torch::Tensor>, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor,
torch::Tensor, c10::optional<torch::Tensor>>
moePrepareOp(torch::Tensor expertsIds, c10::optional<torch::Tensor> scales, c10::optional<torch::Tensor> expertsStatics,
torch::Tensor allWorkspaces, int64_t maxTokenCountPerRank, int64_t epRank, int64_t epSize, int64_t expertCount,
int64_t slotCount, int64_t topK)
void initializeMoeWorkspace(torch::Tensor allWorkspaces, int64_t epRank, int64_t epSize)
{
TORCH_CHECK(allWorkspaces.dim() == 2, "allWorkspaces must be a 2D tensor");
TORCH_CHECK(epRank >= 0 && epRank < epSize, "epRank must be in the range [0, epSize)");
tensorrt_llm::kernels::MoeEpWorldInfo epWorldInfo = {static_cast<int>(epSize), static_cast<int>(epRank)};
tensorrt_llm::kernels::FusedMoeWorldInfo worldInfo = {epWorldInfo};
tensorrt_llm::kernels::FusedMoeWorkspace fusedMoeWorkspace;
tensorrt_llm::kernels::constructWorkspace(
&fusedMoeWorkspace, allWorkspaces.data_ptr<uint64_t>(), allWorkspaces.stride(0), epSize);
tensorrt_llm::kernels::initializeFusedMoeLocalWorkspace(&fusedMoeWorkspace, worldInfo);
}
std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, c10::optional<torch::Tensor>>
moePrepareOp(torch::Tensor expertsIds, c10::optional<torch::Tensor> expertsStatics, torch::Tensor allWorkspaces,
int64_t maxTokenCountPerRank, int64_t epRank, int64_t epSize, int64_t expertCount, int64_t slotCount, int64_t topK)
{
CHECK_INPUT(expertsIds, torch::kInt32);
TORCH_CHECK(expertCount % 4 == 0, "expertCount must be divisible by 4");
TORCH_CHECK(slotCount % 4 == 0, "slotCount must be divisible by 4");
TORCH_CHECK(expertCount + 1 <= 512, "expertCount + 1 must not exceed 512 (capacity of MoeCommFifoConnInfo::values)");
int64_t maxSendRanksPerToken = std::max(epSize, topK);
int64_t tokenCount = expertsIds.size(0);
@ -249,18 +187,6 @@ moePrepareOp(torch::Tensor expertsIds, c10::optional<torch::Tensor> scales, c10:
torch::Tensor sendRankIndices
= torch::empty({maxTokenCountPerRank * maxSendRanksPerToken}, expertsIds.options().dtype(torch::kInt32));
c10::optional<torch::Tensor> preparedLocalScales;
float* scalesPtr = nullptr;
float* preparedLocalScalesPtr = nullptr;
if (scales.has_value())
{
CHECK_INPUT(scales.value(), torch::kFloat32);
scalesPtr = scales->data_ptr<float>();
preparedLocalScales
= torch::empty({maxTokenCountPerRank * epSize, topK}, expertsIds.options().dtype(torch::kFloat32));
preparedLocalScalesPtr = preparedLocalScales->data_ptr<float>();
}
int* localExpertStaticsPtr = nullptr;
int* gatheredExpertStaticsPtr = nullptr;
c10::optional<torch::Tensor> gatheredExpertStatics;
@ -279,8 +205,9 @@ moePrepareOp(torch::Tensor expertsIds, c10::optional<torch::Tensor> scales, c10:
tensorrt_llm::kernels::moe_prepare::computeCountAndIndice(expertsIds.data_ptr<int>(),
sendRankCountCumSum.data_ptr<int>(), RecvRankCountCumSum.data_ptr<int>(), sendRankIndices.data_ptr<int>(),
backwardRecvRankIndices.data_ptr<int>(), recvRankIndices.data_ptr<int>(), workspace, tokenCount,
maxTokenCountPerRank, topK, slotCount, epRank, epSize, stream);
backwardRecvRankIndices.data_ptr<int>(), recvRankIndices.data_ptr<int>(), localExpertStaticsPtr,
gatheredExpertStaticsPtr, workspace, tokenCount, maxTokenCountPerRank, topK, slotCount, expertCount, epRank,
epSize, stream);
tensorrt_llm::kernels::moe_prepare::computeCumsum(
sendRankCountCumSum.data_ptr<int>(), RecvRankCountCumSum.data_ptr<int>(), epRank, epSize, stream);
@ -291,14 +218,28 @@ moePrepareOp(torch::Tensor expertsIds, c10::optional<torch::Tensor> scales, c10:
recvRankIndices.data_ptr<int>(), gatherRecvRankIndices.data_ptr<int>(), epRank, epSize, maxTokenCountPerRank,
stream);
tensorrt_llm::kernels::moe_prepare::allToAllMetadata(expertsIds.data_ptr<int>(),
preparedLocalExpertIds.data_ptr<int>(), scalesPtr, preparedLocalScalesPtr, localExpertStaticsPtr,
gatheredExpertStaticsPtr, workspace, sendRankCountCumSum.data_ptr<int>(), sendRankIndices.data_ptr<int>(),
RecvRankCountCumSum.data_ptr<int>(), recvRankIndices.data_ptr<int>(), tokenCount, maxTokenCountPerRank, topK,
expertCount, slotCount, epRank, epSize, stream);
return std::make_tuple(sendRankCountCumSum, gatherSendRankIndices, RecvRankCountCumSum, gatherRecvRankIndices,
gatherBackwardRecvRankIndices, gatheredExpertStatics);
}
return std::make_tuple(preparedLocalExpertIds, preparedLocalScales, sendRankCountCumSum, gatherSendRankIndices,
RecvRankCountCumSum, gatherRecvRankIndices, gatherBackwardRecvRankIndices, gatheredExpertStatics);
void memsetExpertIds(torch::Tensor expertsIds, torch::Tensor recvRankCountCumSum, int64_t maxTokenCountPerRank,
int64_t topK, int64_t slotCount, int64_t epSize)
{
CHECK_INPUT(expertsIds, torch::kInt32);
TORCH_CHECK(expertsIds.dim() == 2, "expertsIds must be a 1D tensor");
TORCH_CHECK(
expertsIds.size(0) == maxTokenCountPerRank * epSize, "expertsIds must have maxTokenCountPerRank * epSize rows");
TORCH_CHECK(expertsIds.size(1) == topK, "expertsIds must have topK columns");
CHECK_INPUT(recvRankCountCumSum, torch::kInt32);
TORCH_CHECK(recvRankCountCumSum.dim() == 1, "recvRankCountCumSum must be a 1D tensor");
TORCH_CHECK(recvRankCountCumSum.size(0) == epSize, "recvRankCountCumSum must have epSize elements");
auto stream = at::cuda::getCurrentCUDAStream();
tensorrt_llm::kernels::moe_prepare::memsetExpertIds(expertsIds.data_ptr<int>(), recvRankCountCumSum.data_ptr<int>(),
static_cast<int>(maxTokenCountPerRank), static_cast<int>(topK), static_cast<int>(slotCount),
static_cast<int>(epSize), stream);
}
} // namespace torch_ext
@ -306,34 +247,9 @@ moePrepareOp(torch::Tensor expertsIds, c10::optional<torch::Tensor> scales, c10:
TORCH_LIBRARY_FRAGMENT(trtllm, m)
{
m.def(
"moe_comm_prepare_indices(Tensor gathered_target_rank_ids, Tensor? real_rank_token_count_cum_sum, int "
"max_token_count_per_rank, int expert_count, int top_k, int ep_rank, int ep_size) -> (Tensor, Tensor, Tensor, "
"Tensor, Tensor, Tensor)");
}
TORCH_LIBRARY_IMPL(trtllm, CUDA, m)
{
m.impl("moe_comm_prepare_indices", &torch_ext::moeCommPrepareIndicesOp);
}
TORCH_LIBRARY_FRAGMENT(trtllm, m)
{
m.def(
"moe_local_gather(Tensor recv_rank_cum_sum, Tensor local_gather_indices, Tensor gathered_expert_ids, Tensor? "
"gathered_scales, Tensor local_expert_ids, Tensor? local_scales, int max_token_count_per_rank, int "
"expert_count, int top_k, int ep_rank, int ep_size) -> ()");
}
TORCH_LIBRARY_IMPL(trtllm, CUDA, m)
{
m.impl("moe_local_gather", &torch_ext::moeLocalGatherOp);
}
TORCH_LIBRARY_FRAGMENT(trtllm, m)
{
m.def(
"moe_comm(Tensor input, Tensor send_rank_cum_sum, Tensor send_indices, Tensor output, Tensor "
"recv_rank_cum_sum, Tensor recv_indices, Tensor all_workspaces, int ep_rank, int ep_size) -> ()");
"moe_comm(Tensor[] inputs, Tensor send_rank_cum_sum, Tensor send_indices, Tensor "
"recv_rank_cum_sum, Tensor recv_indices, Tensor all_workspaces, int output_allocation_count, int ep_rank, int "
"ep_size, bool[]? need_zero_output=None) -> Tensor[]");
}
TORCH_LIBRARY_IMPL(trtllm, CUDA, m)
@ -341,6 +257,16 @@ TORCH_LIBRARY_IMPL(trtllm, CUDA, m)
m.impl("moe_comm", &torch_ext::moeCommOp);
}
TORCH_LIBRARY_FRAGMENT(trtllm, m)
{
m.def("moe_initialize_workspace(Tensor(a!) all_workspaces, int ep_rank, int ep_size) -> ()");
}
TORCH_LIBRARY_IMPL(trtllm, CUDA, m)
{
m.impl("moe_initialize_workspace", &torch_ext::initializeMoeWorkspace);
}
TORCH_LIBRARY_FRAGMENT(trtllm, m)
{
m.def("get_moe_commworkspace_size_per_rank(int ep_size) -> int");
@ -364,9 +290,9 @@ TORCH_LIBRARY_IMPL(trtllm, CompositeExplicitAutograd, m)
TORCH_LIBRARY_FRAGMENT(trtllm, m)
{
m.def(
"mnnvl_moe_alltoallv_prepare_without_allgather(Tensor experts_ids, Tensor? scales, Tensor? experts_statics, "
"mnnvl_moe_alltoallv_prepare_without_allgather(Tensor experts_ids, Tensor? experts_statics, "
"Tensor allWorkspace, int max_token_count_per_rank, int ep_rank, int ep_size, int expert_count, int "
"slot_count, int top_k) -> (Tensor, Tensor?, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor?)");
"slot_count, int top_k) -> (Tensor, Tensor, Tensor, Tensor, Tensor, Tensor?)");
}
TORCH_LIBRARY_IMPL(trtllm, CUDA, m)
@ -374,6 +300,19 @@ TORCH_LIBRARY_IMPL(trtllm, CUDA, m)
m.impl("mnnvl_moe_alltoallv_prepare_without_allgather", &torch_ext::moePrepareOp);
}
TORCH_LIBRARY_FRAGMENT(trtllm, m)
{
m.def(
"memset_expert_ids(Tensor(a!) experts_ids, Tensor recv_rank_count_cumsum, int max_token_count_per_rank, int "
"top_k, "
"int slot_count, int ep_size) -> ()");
}
TORCH_LIBRARY_IMPL(trtllm, CUDA, m)
{
m.impl("memset_expert_ids", &torch_ext::memsetExpertIds);
}
TORCH_LIBRARY_FRAGMENT(trtllm, m)
{
m.def("get_moe_prepare_workspace_size_per_rank(int ep_size) -> int");

View File

@ -16,7 +16,6 @@
*/
#include "tensorrt_llm/common/opUtils.h"
#include "tensorrt_llm/kernels/moeCommKernels.h"
#include "tensorrt_llm/runtime/torchUtils.h"
#include "tensorrt_llm/thop/thUtils.h"

View File

@ -69,32 +69,6 @@ function(add_gtest test_name test_src)
add_dependencies(google-tests ${test_name})
endfunction()
add_subdirectory(unit_tests)
add_gtest(mpiUtilsTest runtime/mpiUtilsTest.cpp)
add_gtest(gptDecoderTest runtime/gptDecoderTest.cpp)
add_gtest(gptDecoderBatchedTest runtime/gptDecoderBatchedTest.cpp)
add_gtest(medusaModuleTest runtime/medusaModuleTest.cpp)
add_gtest(moeLoadBalancerTest runtime/moeLoadBalancerTest.cpp)
add_gtest(sanitizerTest runtime/sanitizerTest.cpp)
add_gtest(eaglePackDataTest kernels/eaglePackDataTest.cpp)
add_gtest(medusaDecodeLayerTest layers/medusaDecodeLayerTest.cpp)
add_gtest(moeLoadBalanceKernelTest kernels/moeLoadBalanceKernelTest.cpp)
add_gtest(eagleLayerTest layers/eagleLayerTest.cpp)
add_subdirectory(utils)
if(EXISTS ${CMAKE_CURRENT_SOURCE_DIR}/batch_manager)
add_subdirectory(batch_manager)
endif()
if(EXISTS ${CMAKE_CURRENT_SOURCE_DIR}/executor)
add_subdirectory(executor)
endif()
add_subdirectory(unit_tests)
add_subdirectory(e2e_tests)

View File

@ -0,0 +1,13 @@
# SPDX-FileCopyrightText: Copyright (c) 2023-2024 NVIDIA CORPORATION &
# AFFILIATES. All rights reserved. SPDX-License-Identifier: NVIDIA TensorRT
# Source Code License Agreement
#
# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
# property and proprietary rights in and to this material, related documentation
# and any modifications thereto. Any use, reproduction, disclosure or
# distribution of this material and related documentation without an express
# license agreement from NVIDIA CORPORATION or its affiliates is strictly
# prohibited.
add_subdirectory(batch_manager)
add_subdirectory(executor)

View File

@ -9,13 +9,10 @@
# license agreement from NVIDIA CORPORATION or its affiliates is strictly
# prohibited.
add_gtest(cacheTransceiverTest cacheTransceiverTest.cpp)
# guidedDecoderTest requires model tokenizer info, so it's easier to run it with
# e2e tests instead of unit tests.
add_gtest(guidedDecoderTest guidedDecoderTest.cpp)
add_gtest(trtEncoderModelTest trtEncoderModelTest.cpp)
add_gtest(trtGptModelTest trtGptModelTest.cpp)
add_gtest(trtGptModelRealDecoderTest trtGptModelRealDecoderTest.cpp)
target_link_libraries(trtGptModelRealDecoderTest PRIVATE testingUtils)
add_gtest(peftCacheManagerTest peftCacheManagerTest.cpp)
add_gtest(trtEncoderModelTest trtEncoderModelTest.cpp)
add_gtest(guidedDecoderTest guidedDecoderTest.cpp)
add_gtest(blockKeyTest blockKeyTest.cpp)

View File

@ -1405,7 +1405,7 @@ INSTANTIATE_TEST_SUITE_P(LlamaCon2TP1Gen1TP2PP2DisaaggOrchestrator, DisaggOrches
),
generateTestNameDisaggParams);
INSTANTIATE_TEST_SUITE_P(LlamaCon2TP2Gen2TP1DisaaggSpawnOrchestrator, DisaggOrchestratorParamsTest,
INSTANTIATE_TEST_SUITE_P(LlamaCon2TP2Gen2TP1DisaggSpawnOrchestrator, DisaggOrchestratorParamsTest,
testing::Combine( //
testing::Values(1), // processNum
testing::Values(
@ -1418,7 +1418,7 @@ INSTANTIATE_TEST_SUITE_P(LlamaCon2TP2Gen2TP1DisaaggSpawnOrchestrator, DisaggOrch
),
generateTestNameDisaggParams);
INSTANTIATE_TEST_SUITE_P(LlamaCon2TP1Gen2PP2DisaaggSpawnOrchestrator, DisaggOrchestratorParamsTest,
INSTANTIATE_TEST_SUITE_P(LlamaCon2TP1Gen2PP2DisaggSpawnOrchestrator, DisaggOrchestratorParamsTest,
testing::Combine( //
testing::Values(1), // processNum
testing::Values(

View File

@ -19,6 +19,7 @@ endif()
add_subdirectory(common)
add_subdirectory(kernels)
add_subdirectory(multi_gpu)
add_subdirectory(layers)
add_subdirectory(runtime)
add_subdirectory(thop)

View File

@ -9,6 +9,8 @@
# license agreement from NVIDIA CORPORATION or its affiliates is strictly
# prohibited.
add_gtest(blockKeyTest blockKeyTest.cpp)
add_gtest(cacheTransBufferTest cacheTransBufferTest.cpp)
add_gtest(capacitySchedulerTest capacitySchedulerTest.cpp)
add_gtest(contextProgressTest contextProgressTest.cu)
add_gtest(evictionPolicyTest evictionPolicyTest.cpp)
@ -16,5 +18,5 @@ add_gtest(kvCacheManagerTest kvCacheManagerTest.cpp)
add_gtest(kvCacheUtilsTest kvCacheUtilsTest.cpp)
add_gtest(llmRequestTest llmRequestTest.cpp)
add_gtest(microBatchSchedulerTest microBatchSchedulerTest.cpp)
add_gtest(peftCacheManagerTest peftCacheManagerTest.cpp)
add_gtest(staticThreadPoolTest staticThreadPoolTest.cpp)
add_gtest(cacheTransBufferTest cacheTransBufferTest.cpp)

View File

@ -45,16 +45,7 @@ add_gtest(cudaCoreGemmKernelTest cudaCoreGemm/cudaCoreGemmKernelTest.cpp)
add_gtest(mlaChunkedPrefillTest mlaChunkedPrefillTest.cu)
if(NOT ENABLE_MULTI_DEVICE EQUAL 0)
add_gtest(allReduceKernelTest allReduce/allReduceKernelTest.cu)
add_gtest(allReduceFusionTest allReduce/allReduceFusionTest.cu)
add_gtest(gemmAllReduceTest allReduce/gemmAllReduceTest.cu)
if(USING_OSS_CUTLASS_ALLREDUCE_GEMM)
target_link_libraries(gemmAllReduceTest PRIVATE ar_gemm_src)
target_compile_definitions(gemmAllReduceTest
PRIVATE USING_OSS_CUTLASS_ALLREDUCE_GEMM)
endif()
endif()
add_gtest(fusedMoeCommKernelTest fusedMoeCommKernelTest.cpp)
add_gtest(
gemmSwigluRunnerTest
@ -89,11 +80,13 @@ set(SAMPLING_KERNEL_TEST_SRC
sampling/samplingTest.cpp sampling/samplingTopKTest.cpp
sampling/samplingTopPTest.cpp sampling/samplingAirTopPTest.cpp
sampling/samplingPenaltyTest.cpp sampling/samplingUtilsTest.cu)
add_gtest(samplingKernelsTest "${SAMPLING_KERNEL_TEST_SRC}")
set(ROUTING_KERNEL_TEST_SRC
routing/routingTest.cpp routing/routingLlama4Test.cpp
routing/routingRenormalizeTest.cpp routing/routingDeepSeekTest.cpp)
add_gtest(routingKernelsTest "${ROUTING_KERNEL_TEST_SRC}")
add_gtest(moeLoadBalanceKernelTest moeLoadBalanceKernelTest.cpp)
add_gtest(eaglePackDataTest eaglePackDataTest.cpp)

File diff suppressed because it is too large Load Diff

View File

@ -31,5 +31,7 @@ set(LOOKAHEAD_DECODING_TEST_SRC randomLlm.cpp lookaheadDecodingLayerTest.cpp)
add_gtest(lookaheadDecodingLayerTest "${LOOKAHEAD_DECODING_TEST_SRC}")
add_gtest(dynamicDecodeLayerTest dynamicDecodeLayerTest.cpp)
add_gtest(eagleLayerTest eagleLayerTest.cpp)
add_gtest(explicitDraftTokensLayerTest explicitDraftTokensLayerTest.cpp)
add_gtest(layerUtilsTest layerUtilsTest.cpp)
add_gtest(medusaDecodeLayerTest medusaDecodeLayerTest.cpp)

View File

@ -14,7 +14,7 @@
* limitations under the License.
*/
#include "tests/layers/eagleLayerTest.h"
#include "eagleLayerTest.h"
#include "tensorrt_llm/common/memoryUtils.h"
#include "tensorrt_llm/kernels/decodingCommon.h"
#include "tensorrt_llm/kernels/speculativeDecoding/eagleDecodingKernels.h"

View File

@ -14,7 +14,7 @@
* limitations under the License.
*/
#include "tests/layers/medusaDecodeLayerTest.h"
#include "medusaDecodeLayerTest.h"
#include "tensorrt_llm/kernels/decodingCommon.h"
#include "tensorrt_llm/runtime/medusaModule.h"
#include "tensorrt_llm/runtime/runtimeKernels.h"

View File

@ -0,0 +1,16 @@
# SPDX-FileCopyrightText: Copyright (c) 2023-2024 NVIDIA CORPORATION &
# AFFILIATES. All rights reserved. SPDX-License-Identifier: NVIDIA TensorRT
# Source Code License Agreement
#
# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
# property and proprietary rights in and to this material, related documentation
# and any modifications thereto. Any use, reproduction, disclosure or
# distribution of this material and related documentation without an express
# license agreement from NVIDIA CORPORATION or its affiliates is strictly
# prohibited.
add_subdirectory(kernels)
add_gtest(cacheTransceiverTest cacheTransceiverTest.cpp)
add_gtest(mpiUtilsTest mpiUtilsTest.cpp)
add_gtest(userBufferTest userBufferTest.cpp)

View File

@ -0,0 +1,21 @@
# SPDX-FileCopyrightText: Copyright (c) 2023-2024 NVIDIA CORPORATION &
# AFFILIATES. All rights reserved. SPDX-License-Identifier: NVIDIA TensorRT
# Source Code License Agreement
#
# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
# property and proprietary rights in and to this material, related documentation
# and any modifications thereto. Any use, reproduction, disclosure or
# distribution of this material and related documentation without an express
# license agreement from NVIDIA CORPORATION or its affiliates is strictly
# prohibited.
if(NOT ENABLE_MULTI_DEVICE EQUAL 0)
add_gtest(allReduceKernelTest allReduce/allReduceKernelTest.cu)
add_gtest(allReduceFusionTest allReduce/allReduceFusionTest.cu)
add_gtest(gemmAllReduceTest allReduce/gemmAllReduceTest.cu)
if(USING_OSS_CUTLASS_ALLREDUCE_GEMM)
target_link_libraries(gemmAllReduceTest PRIVATE ar_gemm_src)
target_compile_definitions(gemmAllReduceTest
PRIVATE USING_OSS_CUTLASS_ALLREDUCE_GEMM)
endif()
endif()

View File

@ -501,11 +501,8 @@ TEST(Kernel_AllReduceFusion, AllReduceAccuracyRandomTokenNum)
auto& comm = mpi::MpiComm::world();
auto world_size = comm.getSize();
auto rank = comm.getRank();
if (world_size % 2)
{
TLLM_LOG_WARNING("world size is not a multiple of 2, return");
return;
}
ASSERT_EQ(world_size % 2, 0) << "Requires even world size (got " << world_size << ")";
int iter = 100;
std::vector<int> candidate_hidden_dim{1024, 2048, 4096, 7168, 8192};
int min_token_num = 1;
@ -537,11 +534,8 @@ TEST(Kernel_AllReduceFusion, AllReduceAccuracyFixedTokenNum)
auto& comm = mpi::MpiComm::world();
auto world_size = comm.getSize();
auto rank = comm.getRank();
if (world_size % 2)
{
TLLM_LOG_WARNING("world size is not a multiple of 2, return");
return;
}
ASSERT_EQ(world_size % 2, 0) << "Requires even world size (got " << world_size << ")";
int iter = 10;
std::vector<int> candidate_hidden_dim{1024, 2048, 4096, 7168, 8192};
int min_token_num = 1;
@ -603,11 +597,8 @@ TEST(Kernel_AllReduceFusion, AllReduceFusionAccuracyDifferentHiddenDim)
auto& comm = mpi::MpiComm::world();
auto world_size = comm.getSize();
auto rank = comm.getRank();
if (world_size % 2)
{
TLLM_LOG_WARNING("world size is not a multiple of 2, return");
return;
}
ASSERT_EQ(world_size % 2, 0) << "Requires even world size (got " << world_size << ")";
int const arch = tensorrt_llm::common::getSMVersion();
if (arch >= 100)
{
@ -647,11 +638,8 @@ TEST(Kernel_AllReduceFusion, AllReduceFusionAccuracyDifferentDType)
auto& comm = mpi::MpiComm::world();
auto world_size = comm.getSize();
auto rank = comm.getRank();
if (world_size % 2)
{
TLLM_LOG_WARNING("world size is not a multiple of 2, return");
return;
}
ASSERT_EQ(world_size % 2, 0) << "Requires even world size (got " << world_size << ")";
std::vector<int> candidate_hidden_dim{1024, 2048, 4096, 7168, 8192};
int min_token_num = 1;
int max_token_num = 2048;
@ -683,53 +671,52 @@ TEST(Kernel_AllReduceFusion, AllReduceFusionAccuracyDifferentDType)
TEST(Kernel_AllReduceFusion, Perf)
{
int const arch = tensorrt_llm::common::getSMVersion();
if (arch >= 100)
if (arch < 100)
{
using Runner = TestRunner<half, ar_fusion::AllReduceFusionPattern::kARResidualRMSNormFP4Quant>;
auto& comm = mpi::MpiComm::world();
auto world_size = comm.getSize();
auto rank = comm.getRank();
if (world_size % 2)
GTEST_SKIP() << "Skipping test for SM < 100";
}
using Runner = TestRunner<half, ar_fusion::AllReduceFusionPattern::kARResidualRMSNormFP4Quant>;
auto& comm = mpi::MpiComm::world();
auto world_size = comm.getSize();
auto rank = comm.getRank();
ASSERT_EQ(world_size % 2, 0) << "Requires even world size (got " << world_size << ")";
int warmup = 100, iter = 300;
int hidden_dim = 7168;
std::vector<int> candidate_token_num{1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048};
int max_token_num = 2048;
Runner runner(max_token_num, hidden_dim);
for (auto token_num : candidate_token_num)
{
auto latency = runner.benchmark(&Runner::run_kernel, warmup, iter, token_num, hidden_dim);
if (rank == 0)
{
TLLM_LOG_WARNING("world size is not a multiple of 2, return");
return;
TLLM_LOG_INFO(
"token_num %-4d, hidden_dim %-4d, fusion kernel latency %4.4fus", token_num, hidden_dim, latency);
}
int warmup = 100, iter = 300;
int hidden_dim = 7168;
std::vector<int> candidate_token_num{1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048};
int max_token_num = 2048;
Runner runner(max_token_num, hidden_dim);
for (auto token_num : candidate_token_num)
auto nccl_latency = runner.benchmark(&Runner::run_nccl_allreduce, warmup, iter, token_num, hidden_dim);
if (rank == 0)
{
auto latency = runner.benchmark(&Runner::run_kernel, warmup, iter, token_num, hidden_dim);
if (rank == 0)
{
TLLM_LOG_INFO(
"token_num %-4d, hidden_dim %-4d, fusion kernel latency %4.4fus", token_num, hidden_dim, latency);
}
auto nccl_latency = runner.benchmark(&Runner::run_nccl_allreduce, warmup, iter, token_num, hidden_dim);
if (rank == 0)
{
TLLM_LOG_INFO("nccl allreduce latency %4.4fus", nccl_latency);
}
auto residual_latency = runner.benchmark(&Runner::run_residual_add, warmup, iter, token_num, hidden_dim);
if (rank == 0)
{
TLLM_LOG_INFO("residual add latency %4.4fus", residual_latency);
}
auto rms_latency = runner.benchmark(&Runner::run_rms_norm, warmup, iter, token_num, hidden_dim);
if (rank == 0)
{
TLLM_LOG_INFO("rms norm latency %4.4fus", rms_latency);
}
auto quant_latency = runner.benchmark(&Runner::run_fp4_quant, warmup, iter, token_num, hidden_dim);
if (rank == 0)
{
TLLM_LOG_INFO("fp4 quant latency %4.4fus", quant_latency);
auto tot_latency = nccl_latency + residual_latency + rms_latency + quant_latency;
TLLM_LOG_INFO("fusion kernel latency %4.4fus, nccl + ops latency %4.4fus, total speedup %2.4fx",
latency, tot_latency, tot_latency / latency);
}
TLLM_LOG_INFO("nccl allreduce latency %4.4fus", nccl_latency);
}
auto residual_latency = runner.benchmark(&Runner::run_residual_add, warmup, iter, token_num, hidden_dim);
if (rank == 0)
{
TLLM_LOG_INFO("residual add latency %4.4fus", residual_latency);
}
auto rms_latency = runner.benchmark(&Runner::run_rms_norm, warmup, iter, token_num, hidden_dim);
if (rank == 0)
{
TLLM_LOG_INFO("rms norm latency %4.4fus", rms_latency);
}
auto quant_latency = runner.benchmark(&Runner::run_fp4_quant, warmup, iter, token_num, hidden_dim);
if (rank == 0)
{
TLLM_LOG_INFO("fp4 quant latency %4.4fus", quant_latency);
auto tot_latency = nccl_latency + residual_latency + rms_latency + quant_latency;
TLLM_LOG_INFO("fusion kernel latency %4.4fus, nccl + ops latency %4.4fus, total speedup %2.4fx", latency,
tot_latency, tot_latency / latency);
}
}
}

View File

@ -573,8 +573,7 @@ TEST(Kernel, AllReduce)
auto& comm = mpi::MpiComm::world();
auto world_size = comm.getSize();
auto rank = comm.getRank();
if (world_size % 2)
return;
ASSERT_EQ(world_size % 2, 0) << "Requires even world size (got " << world_size << ")";
int warmup = 100, iter = 100;
// clang-format off
@ -645,8 +644,7 @@ TEST(Kernel, AllReduceOneShot)
auto& comm = mpi::MpiComm::world();
auto world_size = comm.getSize();
auto rank = comm.getRank();
if (world_size % 2)
return;
ASSERT_EQ(world_size % 2, 0) << "Requires even world size (got " << world_size << ")";
int warmup = 100, iter = 100;
std::vector<int> candidate_bs{1, 2, 4, 8, 16};
@ -673,19 +671,14 @@ TEST(Kernel, AllReduceOneShotPreNorm)
char const* value = "1";
int overwrite = 1; // Set to 1 to overwrite existing values, 0 to preserve
if (setenv(varName, value, overwrite) != 0)
{
perror("Error setting environment variable");
return;
}
ASSERT_EQ(setenv(varName, value, overwrite), 0) << "Error setting environment variable";
std::cout << varName << " set to " << getenv(varName) << std::endl;
auto& comm = mpi::MpiComm::world();
auto world_size = comm.getSize();
auto rank = comm.getRank();
if (world_size % 2)
return;
ASSERT_EQ(world_size % 2, 0) << "Requires even world size (got " << world_size << ")";
int warmup = 100, iter = 100;
std::vector<int> candidate_bs{1, 2, 4, 8, 16};

View File

@ -628,11 +628,8 @@ TEST(Kernel, MoEReduceAddARFuse)
auto& comm = mpi::MpiComm::world();
auto world_size = comm.getSize();
auto rank = comm.getRank();
if (world_size % 2)
{
TLLM_LOG_WARNING("world size is not a multiple of 2, return");
return;
}
ASSERT_EQ(world_size % 2, 0) << "Requires even world size (got " << world_size << ")";
int warmup = 100, iter = 100;
int hidden_dim = 7168;
std::vector<int> candidate_token_num{1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048};

View File

@ -26,15 +26,13 @@ TEST(UserBuffer, basic)
{
if (!tr::ub::ub_supported())
{
return;
GTEST_SKIP() << "UserBuffer is not supported";
}
auto& comm = mpi::MpiComm::world();
auto world_size = comm.getSize();
auto rank = comm.getRank();
if (world_size % 2)
{
return;
}
ASSERT_EQ(world_size % 2, 0) << "Requires even world size (got " << world_size << ")";
tr::ub::ub_initialize(world_size);
EXPECT_EQ(tr::ub::ub_is_initialized(), true);
EXPECT_NE(tr::ub::ub_comm(), nullptr);

View File

@ -13,6 +13,8 @@ add_gtest(bufferManagerTest bufferManagerTest.cpp)
add_gtest(cudaMemPoolTest cudaMemPoolTest.cpp)
add_gtest(decodingLayerWorkspaceTest decodingLayerWorkspaceTest.cpp)
add_gtest(gdrcopyTest gdrcopyTest.cpp)
add_gtest(gptDecoderBatchedTest gptDecoderBatchedTest.cpp)
add_gtest(gptDecoderTest gptDecoderTest.cpp)
add_gtest(hostAccessibleDeviceAllocatorTest
hostAccessibleDeviceAllocatorTest.cu)
add_gtest(iBufferTest iBufferTest.cpp)
@ -20,13 +22,15 @@ add_gtest(iTensorTest iTensorTest.cpp)
add_gtest(loraCacheTest loraCacheTest.cpp)
add_gtest(loraManagerTest loraManagerTest.cpp)
add_gtest(loraUtilsTest loraUtilsTest.cpp)
add_gtest(medusaModuleTest medusaModuleTest.cpp)
add_gtest(moeLoadBalancerTest moeLoadBalancerTest.cpp)
add_gtest(runtimeKernelTest runtimeKernelTest.cpp)
add_gtest(samplingConfigTest samplingConfigTest.cpp)
add_gtest(samplingTest samplingTest.cpp)
add_gtest(sanitizerTest sanitizerTest.cpp)
add_gtest(tllmBuffersTest tllmBuffersTest.cpp)
add_gtest(tllmRuntimeTest tllmRuntimeTest.cpp)
add_gtest(transposeKVKernelTest transposeKVKernelTest.cpp)
add_gtest(userBufferTest userBufferTest.cpp)
add_gtest(utilsTest utilsTest.cpp)
add_gtest(virtualMemoryTest virtualMemoryTest.cpp)
add_gtest(workerPoolTest workerPoolTest.cpp)

View File

@ -69,9 +69,11 @@ cat << EOF > ${EXTRA_LLM_API_FILE}
enable_attention_dp: false
cuda_graph_config:
enable_padding: true
max_batch_size: 128
max_batch_size: 720
moe_config:
backend: TRTLLM
stream_interval: 10
num_postprocess_workers: 4
EOF
```
@ -84,9 +86,11 @@ cat << EOF > ${EXTRA_LLM_API_FILE}
enable_attention_dp: true
cuda_graph_config:
enable_padding: true
max_batch_size: 128
max_batch_size: 720
moe_config:
backend: CUTLASS
stream_interval: 10
num_postprocess_workers: 4
EOF
```
@ -99,9 +103,8 @@ trtllm-serve openai/gpt-oss-120b \
--host 0.0.0.0 \
--port 8000 \
--backend pytorch \
--max_batch_size 128 \
--max_batch_size 720 \
--max_num_tokens 16384 \
--max_seq_len 2048 \
--kv_cache_free_gpu_memory_fraction 0.9 \
--tp_size 8 \
--ep_size 8 \
@ -134,7 +137,7 @@ These options are used directly on the command line when you start the `trtllm-s
#### `--max_batch_size`
* **Description:** The maximum number of user requests that can be grouped into a single batch for processing.
* **Description:** The maximum number of user requests that can be grouped into a single batch for processing. The actual max batch size that can be achieved depends on total sequence length (input + output).
#### `--max_num_tokens`
@ -142,7 +145,7 @@ These options are used directly on the command line when you start the `trtllm-s
#### `--max_seq_len`
* **Description:** The maximum possible sequence length for a single request, including both input and generated output tokens.
* **Description:** The maximum possible sequence length for a single request, including both input and generated output tokens. We do not set it explicitly here; it is inferred from the model config.
#### `--trust_remote_code`
@ -229,8 +232,7 @@ TODO: Use Chat Compeletions API / Responses API as the example after the PR is m
### Running Evaluations to Verify Accuracy (Optional)
We use OpenAI's official evaluation tool to test the model's accuracy. For more information, see [gpt-oss-eval](https://github.com/openai/gpt-oss/tree/main/gpt_oss/evals).
TODO(@Binghan Chen): Add instructions for running gpt-oss-eval.
With the added support of the Chat Completions and Responses APIs in `trtllm-serve`, `gpt_oss.evals` works directly without any modifications.
## Benchmarking Performance
@ -267,6 +269,8 @@ EOF
chmod +x bench.sh
```
To achieve maximum throughput with attention DP enabled, sweep the request concurrency up to `concurrency = max_batch_size * num_gpus`.
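As a rough illustration of that sweep (a sketch only; the 720 and 8 below assume the `--max_batch_size 720`, 8-GPU setup shown above):
```python
# Enumerate concurrency values to sweep: double until we reach
# max_batch_size * num_gpus (720 * 8 = 5760 with the settings above).
max_batch_size = 720
num_gpus = 8
limit = max_batch_size * num_gpus

sweep, concurrency = [], 1
while concurrency < limit:
    sweep.append(concurrency)
    concurrency *= 2
sweep.append(limit)
print(sweep)  # [1, 2, 4, ..., 2048, 4096, 5760]
```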
If you want to save the results to a file, add the following options.
```shell

View File

@ -690,9 +690,9 @@ def getMultiGpuFileChanged(pipeline, testFilter, globalVars)
"cpp/tensorrt_llm/thop/allgatherOp.cpp",
"cpp/tensorrt_llm/thop/allreduceOp.cpp",
"cpp/tensorrt_llm/thop/reducescatterOp.cpp",
"cpp/tests/executor/",
"cpp/tests/kernels/allReduce/",
"cpp/tests/runtime/mpiUtilsTest.cpp",
"cpp/tests/e2e_tests/batch_manager/",
"cpp/tests/e2e_tests/executor/",
"cpp/tests/unit_tests/multi_gpu/",
"jenkins/L0_Test.groovy",
"tensorrt_llm/_ipc_utils.py",
"tensorrt_llm/_torch/compilation/patterns/ar_residual_norm.py",

View File

@ -7,7 +7,6 @@ import groovy.json.JsonOutput
import com.nvidia.bloom.KubernetesManager
import com.nvidia.bloom.Constants
import com.nvidia.bloom.CloudManager
import com.nvidia.bloom.KubernetesManager
import com.nvidia.bloom.SlurmConfig
import com.nvidia.bloom.SlurmCluster
import com.nvidia.bloom.SlurmPartition
@ -230,8 +229,11 @@ def runLLMTestlistOnSlurm(pipeline, platform, testList, config=VANILLA_CONFIG, p
SlurmPartition partition = SlurmConfig.partitionConfig[platform] as SlurmPartition
SlurmCluster cluster = SlurmConfig.clusterConfig[partition.clusterName]
def nodeName = "${cluster.host}-test-${UUID.randomUUID().toString()}"
def nodeSecret = CloudManager.createNode(nodeName)
// Create a unique suffix for the node name and workspace
String customSuffix = "${env.BUILD_TAG}-${UUID.randomUUID().toString().replaceAll("-", "").substring(0, 6)}".toLowerCase()
def nodeName = "${cluster.host}-test-${customSuffix}"
def customWorkspace = "/tmp/${nodeName}"
def nodeSecret = CloudManager.createNode(nodeName, customWorkspace)
try {
// Run ssh command to start node in desired cluster via SLURM
@ -274,12 +276,30 @@ def runLLMTestlistOnSlurm(pipeline, platform, testList, config=VANILLA_CONFIG, p
}
if (CloudManager.isNodeOnline(nodeName)) {
def dockerArgs = "--gpus ${gpuCount} --cap-add=SYS_ADMIN --ipc=host --security-opt seccomp=unconfined -u root:root -v /home/scratch.trt_llm_data:/scratch.trt_llm_data:ro -v /tmp/ccache:${CCACHE_DIR}:rw -v /tmp/pipcache/http-v2:/root/.cache/pip/http-v2:rw --cap-add syslog"
node(nodeName) {
sh """
env | sort
pwd && ls -alh
ls -alh ${env.WORKSPACE}
ls -alh ${env.WORKSPACE_TMP}
"""
}
def dockerArgs = "--gpus ${gpuCount} " +
"--cap-add=SYS_ADMIN " +
"--ipc=host " +
"--security-opt seccomp=unconfined " +
"-u root:root " +
"-v /home/scratch.trt_llm_data:/scratch.trt_llm_data:ro " +
"-v /tmp/ccache:${CCACHE_DIR}:rw " +
"-v /tmp/pipcache/http-v2:/root/.cache/pip/http-v2:rw " +
"--cap-add syslog"
if (partition.clusterName == "dlcluster") {
dockerArgs += " -e NVIDIA_IMEX_CHANNELS=0"
}
slurmRunner = runInDockerOnNodeMultiStage(LLM_DOCKER_IMAGE, nodeName, dockerArgs, false)
slurmRunner = runInDockerOnNodeMultiStage(LLM_DOCKER_IMAGE, nodeName, dockerArgs, true)
executeLLMTestOnSlurm(pipeline, platform, testList, config, perfMode, stageName, splitId, splits, skipInstallWheel, cpver, slurmRunner)
} else {
echo "The node does not come online in 2 hours, terminating the job"
@ -571,6 +591,13 @@ def cacheErrorAndUploadResult(stageName, taskRunner, finallyRunner, noResultIfSu
"${UPLOAD_PATH}/test-results/"
)
junit(testResults: "${stageName}/results*.xml")
// Clean up the workspace
sh """
env | sort
pwd && ls -alh
rm -rf ./*
"""
}
}
}
@ -807,7 +834,7 @@ def echoNodeAndGpuInfo(pipeline, stageName)
def runLLMDocBuild(pipeline, config)
{
// Step 1: cloning tekit source code
// Step 1: cloning source code
sh "pwd && ls -alh"
sh "env | sort"
// allow to checkout from forked repo, svc_tensorrt needs to have access to the repo, otherwise clone will fail
@ -1252,13 +1279,16 @@ def rerunFailedTests(stageName, llmSrc, testCmdLine) {
def runLLMTestlistOnPlatformImpl(pipeline, platform, testList, config=VANILLA_CONFIG, perfMode=false, stageName="Undefined", splitId=1, splits=1, skipInstallWheel=false, cpver="cp312")
{
// Step 1: create LLM_ROOT dir
sh "pwd && ls -alh"
// TODO: proper way to clean workspace, maybe save in a folder named with BUILD_ID.
// So that it can work with multiple job running in same node
sh "rm -rf ./*"
// Step 1: create LLM_ROOT dir and clean up the workspace
def llmRootConfig = "${LLM_ROOT}${config}"
sh "mkdir ${llmRootConfig}"
sh """
env | sort
pwd && ls -alh
rm -rf ./*
mkdir ${llmRootConfig}
ls -alh ${env.WORKSPACE}
ls -alh ${env.WORKSPACE_TMP}
"""
def llmPath = sh (script: "realpath ${llmRootConfig}", returnStdout: true).trim()
def llmSrc = "${llmPath}/TensorRT-LLM/src"
@ -1779,7 +1809,6 @@ def launchTestJobs(pipeline, testFilter, dockerNode=null)
"DGX_H100-4_GPUs-PyTorch-DeepSeek-1": ["dgx-h100-x4", "l0_dgx_h100", 1, 2, 4],
"DGX_H100-4_GPUs-PyTorch-DeepSeek-2": ["dgx-h100-x4", "l0_dgx_h100", 2, 2, 4],
"DGX_H100-4_GPUs-PyTorch-Others-1": ["dgx-h100-x4", "l0_dgx_h100", 1, 1, 4],
"DGX_H100-4_GPUs-Triton-Post-Merge-1": ["dgx-h100-x4", "l0_dgx_h100", 1, 1, 4],
"DGX_H100-4_GPUs-CPP-1": ["dgx-h100-x4", "l0_dgx_h100", 1, 1, 4],
"A10-PyTorch-1": ["a10", "l0_a10", 1, 1],
"A10-CPP-1": ["a10", "l0_a10", 1, 1],
@ -1852,6 +1881,7 @@ def launchTestJobs(pipeline, testFilter, dockerNode=null)
"B200_PCIe-TensorRT-Post-Merge-2": ["b100-ts2", "l0_b200", 2, 2],
"H100_PCIe-TensorRT-Perf-1": ["h100-cr", "l0_perf", 1, 1],
"H100_PCIe-PyTorch-Perf-1": ["h100-cr", "l0_perf", 1, 1],
"DGX_H200-4_GPUs-Triton-Post-Merge-1": ["dgx-h200-x4", "l0_dgx_h200", 1, 1, 4],
"DGX_H200-8_GPUs-PyTorch-Post-Merge-1": ["dgx-h200-x8", "l0_dgx_h200", 1, 1, 8],
"DGX_H200-4_GPUs-PyTorch-Post-Merge-1": ["dgx-h200-x4", "l0_dgx_h200", 1, 1, 4],
"DGX_H200-4_GPUs-TensorRT-Post-Merge-1": ["dgx-h200-x4", "l0_dgx_h200", 1, 3, 4],
@ -1910,8 +1940,10 @@ def launchTestJobs(pipeline, testFilter, dockerNode=null)
fullSet += SBSATestConfigs.keySet()
SBSASlurmTestConfigs = [
"GB200-4_GPUs-PyTorch-1": ["gb200-x4", "l0_gb200", 1, 1, 4],
"GB200-4_GPUs-PyTorch-Post-Merge-1": ["gb200-x4", "l0_gb200", 1, 1, 4],
// Not used in the pipeline now
// "GB200-PyTorch-1": ["gb200-single", "l0_gb200", 1, 3],
"GB200-4_GPUs-PyTorch-1": ["gb200-x4", "l0_gb200_multi_gpus", 1, 1, 4],
"GB200-4_GPUs-PyTorch-Post-Merge-1": ["gb200-x4", "l0_gb200_multi_gpus", 1, 1, 4],
]
fullSet += SBSASlurmTestConfigs.keySet()
@ -1923,7 +1955,6 @@ def launchTestJobs(pipeline, testFilter, dockerNode=null)
"GB200-8_GPUs-2_Nodes-PyTorch-Post-Merge-4": ["gb200-multi-node", "l0_gb200_multi_nodes", 4, 7, 8, 2],
"GB200-8_GPUs-2_Nodes-PyTorch-Post-Merge-5": ["gb200-multi-node", "l0_gb200_multi_nodes", 5, 7, 8, 2],
"GB200-8_GPUs-2_Nodes-PyTorch-Post-Merge-6": ["gb200-multi-node", "l0_gb200_multi_nodes", 6, 7, 8, 2],
"GB200-8_GPUs-2_Nodes-PyTorch-Post-Merge-7": ["gb200-multi-node", "l0_gb200_multi_nodes", 7, 7, 8, 2],
]
fullSet += multiNodesSBSAConfigs.keySet()
@ -2145,7 +2176,9 @@ def launchTestJobs(pipeline, testFilter, dockerNode=null)
echo "###### Check pip install Start ######"
withEnv(libEnv) {
sh "env | sort"
checkPipInstall(pipeline, "${cpu_arch}/${wheelPath}")
timeout(time: 1, unit: 'HOURS') {
checkPipInstall(pipeline, "${cpu_arch}/${wheelPath}")
}
}
echo "###### Run LLMAPI tests Start ######"
def config = VANILLA_CONFIG_CU12
@ -2481,7 +2514,7 @@ pipeline {
def testPhase2StageName = env.testPhase2StageName
if (testPhase2StageName) {
def dgxSigns = ["DGX_H100", "DGX_H200", "GB200", "DGX_B200", "RTXPro6000-4_GPUs"]
def dgxSigns = ["2_GPUs", "4_GPUs", "8_GPUs"]
singleGpuJobs = parallelJobs.findAll{!dgxSigns.any{sign -> it.key.contains(sign)}}
dgxJobs = parallelJobs.findAll{dgxSigns.any{sign -> it.key.contains(sign)}}
}

View File

@ -34,7 +34,7 @@ else
done
fi
testList="$testList_$splitId"
export CPP_TEST_TIMEOUT_OVERRIDDEN=7200
export CPP_TEST_TIMEOUT_OVERRIDDEN=$pytestTestTimeout
export LLM_ROOT=$llmSrcNode
export LLM_MODELS_ROOT=$MODEL_CACHE_DIR
export UCX_TLS=^gdr_copy
@ -43,6 +43,7 @@ testCmdLines=(
"$llmSrcNode/tensorrt_llm/llmapi/trtllm-llmapi-launch"
"pytest"
"-v"
"--timeout-method=thread"
"--timeout=$pytestTestTimeout"
"--test-list=$testListPathNode"
"--waives-file=$waivesListPathNode"

View File

@ -17,7 +17,7 @@ import os
import platform
import sys
from dataclasses import dataclass
from typing import Optional
from typing import List, Optional, Union
import pynvml
import torch
@ -366,6 +366,10 @@ class MnnvlMoe:
)
MnnvlMoe.moe_workspace = MnnvlMemory(mapping, workspace_size_per_rank)
MnnvlMoe.moe_workspace_tensor = MnnvlMoe.moe_workspace.as_torch_strided_tensor(torch.uint64)
torch.ops.trtllm.moe_initialize_workspace(
MnnvlMoe.moe_workspace_tensor, mapping.tp_rank, mapping.tp_size
)
MnnvlMoe.moe_workspace.comm.barrier()
return MnnvlMoe.moe_workspace_tensor
@staticmethod
@ -394,7 +398,6 @@ class MnnvlMoe:
@staticmethod
def mnnvl_moe_alltoallv_prepare_without_allgather(
expert_ids: torch.Tensor,
scales: torch.Tensor,
expert_statics: Optional[torch.Tensor],
workspace: torch.Tensor,
max_token_count_per_rank: int,
@ -405,8 +408,6 @@ class MnnvlMoe:
top_k: int,
):
(
prepared_local_experts,
prepared_local_scales,
local_send_rank_count_cumsum,
local_send_rank_indices,
local_recv_rank_count_cumsum,
@ -415,7 +416,6 @@ class MnnvlMoe:
gathered_expert_statics,
) = torch.ops.trtllm.mnnvl_moe_alltoallv_prepare_without_allgather(
expert_ids,
scales,
expert_statics,
workspace,
max_token_count_per_rank,
@ -440,7 +440,7 @@ class MnnvlMoe:
local_token_allocation_count,
)
return alltoall_info, prepared_local_experts, prepared_local_scales, gathered_expert_statics
return alltoall_info, gathered_expert_statics
@staticmethod
def mnnvl_moe_expert_static_allgather(
@ -526,31 +526,67 @@ class MnnvlMoe:
@staticmethod
def mnnvl_moe_alltoallv(
x: torch.Tensor,
x: Union[torch.Tensor, List[Optional[torch.Tensor]]],
alltoall_info: MoEAlltoallInfo,
workspace: torch.Tensor,
ep_rank: int,
ep_size: int,
):
assert x.dim() == 2, "only 2D tensor supported, please reshape."
output_tensor = torch.empty(
alltoall_info.local_token_allocation_count,
x.shape[1],
dtype=x.dtype,
device=torch.device("cuda"),
)
torch.ops.trtllm.moe_comm(
x,
alltoall_info.send_rank_count_cumsum,
alltoall_info.send_rank_local_indices,
output_tensor,
alltoall_info.recv_rank_count_cumsum,
alltoall_info.recv_rank_local_indices,
workspace,
ep_rank,
ep_size,
)
return output_tensor
) -> Union[torch.Tensor, List[Optional[torch.Tensor]]]:
# Convert single tensor to list for unified handling
is_single_tensor = not isinstance(x, list)
if is_single_tensor:
assert x.dim() == 2, "only 2D tensor supported, please reshape."
x = [x]
assert len(x) > 0, "Empty tensor list not supported"
# Filter out None values
valid_list = [tensor is not None for tensor in x]
valid_tensors = [tensor for tensor in x if tensor is not None]
if len(valid_tensors) == 0:
# All tensors are None, return list of None
result = [None] * len(x)
else:
first_dim = None
for tensor in valid_tensors:
# Validate dimensions of valid tensors
assert tensor.dim() == 2, "only 2D tensor supported, please reshape."
if first_dim is None:
first_dim = tensor.shape[0]
else:
assert tensor.shape[0] == first_dim, (
f"All tensors must have the same first dimension, got {tensor.shape[0]} vs {first_dim}"
)
# Process only valid tensors
output_tensors = torch.ops.trtllm.moe_comm(
valid_tensors,
alltoall_info.send_rank_count_cumsum,
alltoall_info.send_rank_local_indices,
alltoall_info.recv_rank_count_cumsum,
alltoall_info.recv_rank_local_indices,
workspace,
alltoall_info.local_token_allocation_count,
ep_rank,
ep_size,
)
# Restore None positions in output
idx = 0
result = []
for is_valid in valid_list:
if is_valid:
result.append(output_tensors[idx])
idx += 1
else:
result.append(None)
# If input was a single tensor, return a single tensor
if is_single_tensor:
result = result[0]
return result
@staticmethod
def mnnvl_moe_alltoallv_combine(
@ -563,20 +599,19 @@ class MnnvlMoe:
token_count: int,
):
assert x.dim() == 2, "2D tensor supported, please reshape."
output_tensor = torch.zeros(
token_count * top_k, x.shape[1], dtype=x.dtype, device=torch.device("cuda")
)
torch.ops.trtllm.moe_comm(
x,
output_tensors = torch.ops.trtllm.moe_comm(
[x],
alltoall_info.recv_rank_count_cumsum,
alltoall_info.recv_rank_local_indices,
output_tensor,
alltoall_info.send_rank_count_cumsum,
alltoall_info.backward_recv_rank_local_indices,
workspace,
token_count * top_k,
ep_rank,
ep_size,
[True],
)
output_tensor = output_tensors[0]
return torch.sum(
output_tensor.reshape(token_count, top_k, x.shape[1]), dim=1, keepdim=False
)
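As an aside, the None-preserving round trip that the list form of `mnnvl_moe_alltoallv` performs above can be captured by a small standalone sketch; the doubling lambda below is only a stand-in for the real `torch.ops.trtllm.moe_comm` call:
```python
from typing import Callable, List, Optional, TypeVar

T = TypeVar("T")

def apply_skipping_none(items: List[Optional[T]],
                        fn: Callable[[List[T]], List[T]]) -> List[Optional[T]]:
    """Run fn on the non-None entries and put the results back in their original slots."""
    valid_mask = [item is not None for item in items]
    valid_items = [item for item in items if item is not None]
    if not valid_items:
        return [None] * len(items)
    transformed = iter(fn(valid_items))  # stand-in for torch.ops.trtllm.moe_comm(...)
    return [next(transformed) if is_valid else None for is_valid in valid_mask]

# Valid entries are transformed, None positions are preserved.
print(apply_skipping_none([1, None, 3], lambda xs: [2 * x for x in xs]))  # [2, None, 6]
```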

View File

@ -476,9 +476,6 @@ class SequenceInfo:
idx = self.previous_batch_indices_cuda[: len(previous_batch_indices)]
idx.copy_(host_idx, non_blocking=True)
# sort them so that masked_scatter_ lines up correctly
idx, _ = idx.sort()
# gather the exact values you want to write
src = new_tokens[0, idx, 0]

View File

@ -179,71 +179,27 @@ def _register_fake():
return (input.new_empty(output_shape, dtype=torch.uint8),
global_scale.new_empty(scale_shape, dtype=torch.uint8))
@torch.library.register_fake("trtllm::moe_comm_prepare_indices")
def _(
gathered_target_rank_ids: torch.Tensor,
real_rank_token_count_cum_sum: Optional[torch.Tensor],
max_token_count_per_rank: int,
expert_count: int,
top_k: int,
ep_rank: int,
ep_size: int,
):
max_send_ranks_per_token = max(ep_size, top_k)
local_gather_indices_shape = (max_token_count_per_rank * ep_size, )
rank_count_cum_sum_shape = (ep_size, )
send_rank_local_indices_shape = (max_token_count_per_rank *
max_send_ranks_per_token, )
recv_rank_local_indices_shape = (max_token_count_per_rank * ep_size, )
backward_recv_rank_local_indices_shape = (max_token_count_per_rank *
max_send_ranks_per_token, )
local_gather_indices = gathered_target_rank_ids.new_empty(
local_gather_indices_shape, dtype=torch.int32)
send_rank_count_cum_sum = gathered_target_rank_ids.new_empty(
rank_count_cum_sum_shape, dtype=torch.int32)
send_rank_local_indices = gathered_target_rank_ids.new_empty(
send_rank_local_indices_shape, dtype=torch.int32)
recv_rank_count_cum_sum = gathered_target_rank_ids.new_empty(
rank_count_cum_sum_shape, dtype=torch.int32)
recv_rank_local_indices = gathered_target_rank_ids.new_empty(
recv_rank_local_indices_shape, dtype=torch.int32)
backward_recv_rank_local_indices = gathered_target_rank_ids.new_empty(
backward_recv_rank_local_indices_shape, dtype=torch.int32)
return (local_gather_indices, send_rank_count_cum_sum,
send_rank_local_indices, recv_rank_count_cum_sum,
recv_rank_local_indices, backward_recv_rank_local_indices)
@torch.library.register_fake("trtllm::moe_local_gather")
def _(
recv_rank_cum_sum: torch.Tensor,
local_gather_indices: torch.Tensor,
gathered_expert_ids: torch.Tensor,
gathered_scales: Optional[torch.Tensor],
local_expert_ids: torch.Tensor,
local_scales: Optional[torch.Tensor],
max_token_count_per_rank: int,
expert_count: int,
top_k: int,
ep_rank: int,
ep_size: int,
):
pass
@torch.library.register_fake("trtllm::moe_comm")
def _(
input: torch.Tensor,
inputs: List[torch.Tensor],
send_rank_cum_sum: torch.Tensor,
send_indices: torch.Tensor,
output: torch.Tensor,
recv_rank_cum_sum: torch.Tensor,
recv_indices: torch.Tensor,
all_workspaces: torch.Tensor,
output_allocation_count: int,
ep_rank: int,
ep_size: int,
need_zero_output: Optional[List[bool]],
):
pass
outputs = []
for input_tensor in inputs:
output_tensor = torch.empty(
(output_allocation_count, input_tensor.shape[1]),
dtype=input_tensor.dtype,
device=input_tensor.device)
outputs.append(output_tensor)
return outputs
@torch.library.register_fake("trtllm::get_moe_commworkspace_size_per_rank")
def _(ep_size: int):
@ -287,6 +243,12 @@ def _register_fake():
token_selected_experts: torch.Tensor, offset_by_ep_rank: bool):
return torch.empty_like(token_selected_experts)
@torch.library.register_fake("trtllm::memset_expert_ids")
def _(experts_ids: torch.Tensor, recv_rank_count_cumsum: torch.Tensor,
max_token_count_per_rank: int, top_k: int, slot_count: int,
ep_size: int):
pass
@torch.library.custom_op("trtllm::group_rms_norm_base",
mutates_args=("outputs", ))
def group_rms_norm_base(
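For background, the shape-only "fake" registrations above can be reproduced in isolation with a toy op; `demo::row_gather` below is made up purely for illustration (it is not a TensorRT-LLM op) and assumes a PyTorch version that provides `torch.library.custom_op`:
```python
import torch

# Hypothetical op used only to illustrate the register_fake pattern above.
@torch.library.custom_op("demo::row_gather", mutates_args=())
def row_gather(inp: torch.Tensor, indices: torch.Tensor) -> torch.Tensor:
    return inp[indices]

@row_gather.register_fake
def _(inp: torch.Tensor, indices: torch.Tensor) -> torch.Tensor:
    # Like the moe_comm fake: allocate outputs with the right row count and the
    # input's dtype/device, without doing any real work.
    return inp.new_empty((indices.shape[0], inp.shape[1]))

x = torch.randn(4, 3)
idx = torch.tensor([0, 2])
print(row_gather(x, idx).shape)  # torch.Size([2, 3])
```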

View File

@ -58,8 +58,7 @@ from ..modules.decoder_layer import DecoderLayer
from ..modules.embedding import Embedding
from ..modules.fused_moe import (DeepSeekV3MoeRoutingMethod,
MoEWeightLoadingMode, TRTLLMGenFusedMoE,
create_moe,
moe_load_balancer_set_repeated_for_next_layer)
create_moe)
from ..modules.gated_mlp import GatedMLP
from ..modules.linear import Linear, TensorParallelMode, WeightsLoadingConfig
from ..modules.multi_stream_utils import maybe_execute_in_parallel
@ -1159,7 +1158,7 @@ class DeepseekV3ForCausalLM(SpecDecOneEngineForCausalLM[DeepseekV3Model,
self.num_hidden_layers = self.config.num_hidden_layers
assert ckpt_nextn > 0, "There is not MTP modules in the checkpoint."
if ckpt_nextn == 1 and not model_config.spec_config.use_mtp_vanilla:
moe_load_balancer_set_repeated_for_next_layer(model_nextn)
pass
else:
# modify the QuantConfig to support duplicated mtp layers
if model_config.quant_config.exclude_modules is not None:

View File

@ -14,6 +14,7 @@ from ..model_config import ModelConfig, TConfig
from ..modules.attention import Attention
from ..modules.decoder_layer import DecoderLayer
from ..modules.embedding import Embedding
from ..modules.fused_moe import moe_load_balancer_set_repeated_for_next_layer
from ..modules.gated_mlp import GatedMLP
from ..modules.linear import (Linear, TensorParallelMode, WeightMode,
WeightsLoadingConfig)
@ -340,6 +341,9 @@ class MTPForCausalLM(nn.Module):
mtp_num_layers = 1 if spec_dec_mode.is_mtp_eagle(
) else model_config.spec_config.num_nextn_predict_layers
moe_load_balancer_set_repeated_for_next_layer(
model_config.spec_config.num_nextn_predict_layers // mtp_num_layers)
self.mtp_layers = nn.ModuleList([
DeepseekV3MTP(model_config, layer_idx + start_layer_idx,
model.aux_stream_dict)

View File

@ -5,7 +5,6 @@ from typing import Dict, List, Optional, Union
import torch
from tensorrt_llm._mnnvl_utils import MnnvlMemory, MnnvlMoe
from tensorrt_llm.math_utils import pad_up
from ...distributed import allgather
from ...model_config import ModelConfig
@ -190,8 +189,6 @@ class CutlassFusedMoE(MoE):
@cached_property
def enable_alltoall(self):
return (self.mapping.moe_ep_size > self.routing_method.experts_per_token
and self.routing_method.experts_per_token % 4 ==
0 # alltoall without allgather only supports top_k % 4 == 0
and self.mapping.enable_attention_dp
and self.mapping.tp_size > 1
and os.environ.get("TRTLLM_MOE_DISABLE_ALLTOALLV", "0") != "1"
@ -353,39 +350,28 @@ class CutlassFusedMoE(MoE):
token_final_scales = torch.ones_like(token_selected_experts,
dtype=torch.float32)
# TODO: support alltoall without allgather for top_k % 4 != 0
assert top_k % 4 == 0, "alltoall without allgather only supports top_k % 4 == 0"
assert self.alltoall_prepare_workspace is not None, "alltoall_prepare_workspace should be initialized"
alltoall_info, token_selected_experts, token_final_scales, _ = MnnvlMoe.mnnvl_moe_alltoallv_prepare_without_allgather(
token_selected_experts, token_final_scales, None,
self.alltoall_prepare_workspace, max_num_token, self.ep_rank,
self.ep_size, self.num_experts, self.num_experts, top_k)
alltoall_info, _ = MnnvlMoe.mnnvl_moe_alltoallv_prepare_without_allgather(
token_selected_experts, None, self.alltoall_prepare_workspace,
max_num_token, self.ep_rank, self.ep_size, self.num_experts,
self.num_experts, top_k)
# Dispatch alltoall (common for both paths)
x = MnnvlMoe.mnnvl_moe_alltoallv(x, alltoall_info,
self.alltoall_workspace,
self.ep_rank, self.ep_size)
if x_sf is not None:
x_sf = x_sf.view(x_row, ceil_div(x_col,
self.scaling_vector_size))
# Pad dim[1] to 16 bytes alignment for alltoall
# TODO: Remove this padding if possible
sf_per_16bytes = 16 // x_sf.element_size()
x_sf_col_orig = x_sf.shape[1]
x_sf_col = pad_up(x_sf_col_orig, sf_per_16bytes)
if x_sf_col > x_sf_col_orig:
x_sf = torch.nn.functional.pad(
x_sf, (0, x_sf_col - x_sf_col_orig))
# Dispatch x, x_sf, token_selected_experts, token_final_scales in one alltoall kernel
x, x_sf, token_selected_experts, token_final_scales = MnnvlMoe.mnnvl_moe_alltoallv(
[x, x_sf, token_selected_experts, token_final_scales],
alltoall_info, self.alltoall_workspace, self.ep_rank,
self.ep_size)
x_sf = MnnvlMoe.mnnvl_moe_alltoallv(x_sf, alltoall_info,
self.alltoall_workspace,
self.ep_rank, self.ep_size)
torch.ops.trtllm.memset_expert_ids(
token_selected_experts, alltoall_info.recv_rank_count_cumsum,
max_num_token, top_k, self.num_experts, self.ep_size)
if x_sf is not None:
x_row = x_sf.shape[0]
# TODO: Remove this slicing required by padding if possible
x_sf = x_sf[:, :x_sf_col_orig].contiguous()
x_sf = swizzle_sf(x_sf, x_row, x_col, self.scaling_vector_size)
elif run_post_quant_allgather:
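For clarity, a minimal standalone sketch of the pad-to-16-bytes / slice-back pattern that the scale-factor (`x_sf`) handling in this hunk relies on (dtype and shapes below are illustrative only):
```python
import torch
import torch.nn.functional as F

def pad_cols_to_16_bytes(t: torch.Tensor):
    """Pad dim 1 so each row spans a multiple of 16 bytes; return (padded, original col count)."""
    elems_per_16_bytes = 16 // t.element_size()
    orig_cols = t.shape[1]
    padded_cols = -(-orig_cols // elems_per_16_bytes) * elems_per_16_bytes  # ceil to multiple
    if padded_cols > orig_cols:
        t = F.pad(t, (0, padded_cols - orig_cols))
    return t, orig_cols

x_sf = torch.randint(0, 256, (8, 21), dtype=torch.uint8)
padded, orig_cols = pad_cols_to_16_bytes(x_sf)  # rows padded from 21 to 32 elements
restored = padded[:, :orig_cols].contiguous()   # slice the padding back off after the alltoall
assert torch.equal(restored, x_sf)
```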

View File

@ -192,18 +192,13 @@ class WideEPMoE(MoE):
self.use_low_precision_combine = (os.environ.get(
"TRTLLM_MOE_USE_LOW_PRECISION_COMBINE", "0")
== "1") and qm.has_nvfp4()
# TODO: support alltoall without allgather for top_k % 4 != 0
self.enable_alltoall_without_allgather = (
os.environ.get("TRTLLM_MOE_ENABLE_ALLTOALL_WITHOUT_ALLGATHER",
"1") == "1"
) and self.alltoall_method_type == AlltoallMethodType.MNNVL and routing_method.experts_per_token % 4 == 0
if self.alltoall_method_type == AlltoallMethodType.MNNVL:
MnnvlMemory.initialize()
self.alltoall_workspace = MnnvlMoe.get_moe_workspaces(
model_config.mapping)
if self.enable_alltoall_without_allgather:
self.alltoall_prepare_workspace = MnnvlMoe.get_moe_prepare_workspace(
model_config.mapping)
self.alltoall_prepare_workspace = MnnvlMoe.get_moe_prepare_workspace(
model_config.mapping)
elif self.alltoall_method_type == AlltoallMethodType.DeepEP:
self.deep_ep_buffer = buffer_pool.get_buffer(
model_config.mapping)
@ -301,6 +296,9 @@ class WideEPMoE(MoE):
1) // self.moe_max_num_tokens
def can_use_alltoall(self, all_rank_num_tokens, all_rank_max_num_tokens):
if self.alltoall_method_type == AlltoallMethodType.MNNVL:
return True
# Disable alltoall when chunking is used
if self.calculate_num_chunks(all_rank_num_tokens) > 1:
return False
@ -458,24 +456,23 @@ class WideEPMoE(MoE):
else:
tuner_num_tokens = None
tuner_top_k = None
alltoall_info = None
if use_all_to_all:
if self.alltoall_method_type == AlltoallMethodType.MNNVL:
if self.enable_dummy_allreduce:
self.dummy_allreduce()
token_count = x.shape[0]
alltoall_info = None
if is_last_call:
if is_last_call and self.layer_load_balancer is not None and not self.layer_load_balancer.is_static_routing(
):
loadbalancer_local_statistic_info = self.layer_load_balancer.get_local_statistic_tensor(
)
else:
loadbalancer_local_statistic_info = None
x, token_selected_slots, token_final_scales, gathered_loadbalancer_local_statistic_info, alltoall_info = \
self.alltoall_prepare_maybe_dispatch(all_rank_max_num_tokens,
x,
token_selected_slots,
token_final_scales,
use_postquant_alltoall,
loadbalancer_local_statistic_info)
token_selected_slots, gathered_loadbalancer_local_statistic_info, alltoall_info = \
self.alltoall_prepare(all_rank_max_num_tokens,
token_selected_slots,
loadbalancer_local_statistic_info)
if gathered_loadbalancer_local_statistic_info is not None:
gathered_loadbalancer_local_statistic_info = gathered_loadbalancer_local_statistic_info.view(
(self.mapping.moe_ep_size, self.num_experts))
@ -580,10 +577,15 @@ class WideEPMoE(MoE):
cluster_rank = self.cluster_rank
quant_scales = self.quant_scales
if self.alltoall_method_type == AlltoallMethodType.MNNVL:
top_k = self.routing_method.experts_per_token
x, x_sf, token_selected_slots, token_final_scales = self.alltoall_dispatch(
x, x_sf, token_selected_slots, token_final_scales,
all_rank_max_num_tokens, top_k, alltoall_info)
if use_postquant_alltoall:
if self.alltoall_method_type == AlltoallMethodType.MNNVL:
x, x_sf = self.alltoall_postquant_dispatch(
x, x_sf, alltoall_info)
pass
elif self.alltoall_method_type == AlltoallMethodType.DeepEP:
if x_sf is not None:
# Adapter between `x_sf` and DeepEP
@ -862,77 +864,34 @@ class WideEPMoE(MoE):
self.repeat_idx = 0 if self.repeat_idx == self.repeat_count - 1 else self.repeat_idx + 1
return outputs
def alltoall_prepare_maybe_dispatch(
self, all_rank_max_num_tokens: int, x: torch.Tensor,
token_selected_slots: torch.Tensor,
token_final_scales: torch.Tensor, use_postquant_alltoall: bool,
local_statistic_tensor: Optional[torch.Tensor]):
def alltoall_prepare(self, all_rank_max_num_tokens: int,
token_selected_slots: torch.Tensor,
local_statistic_tensor: Optional[torch.Tensor]):
top_k = self.routing_method.experts_per_token
if self.enable_alltoall_without_allgather:
alltoall_info, token_selected_slots, token_final_scales, gathered_local_statistic_tensor = MnnvlMoe.mnnvl_moe_alltoallv_prepare_without_allgather(
token_selected_slots, token_final_scales,
local_statistic_tensor, self.alltoall_prepare_workspace,
all_rank_max_num_tokens, self.ep_rank, self.ep_size,
self.num_experts, self.num_slots, top_k)
else:
if all_rank_max_num_tokens > token_selected_slots.shape[0]:
token_selected_slots = torch.nn.functional.pad(
token_selected_slots,
(0, 0, 0,
all_rank_max_num_tokens - token_selected_slots.shape[0]),
'constant', self.num_slots)
if token_final_scales is not None and all_rank_max_num_tokens > token_final_scales.shape[
0]:
token_final_scales = torch.nn.functional.pad(
token_final_scales,
(0, 0, 0,
all_rank_max_num_tokens - token_final_scales.shape[0]))
gathered_token_selected_slots, gathered_token_final_scales, gathered_local_statistic_tensor = allgather(
[
token_selected_slots, token_final_scales,
local_statistic_tensor
],
self.mapping,
dim=0)
gathered_token_selected_slots = torch.flatten(
gathered_token_selected_slots.contiguous(),
start_dim=0,
end_dim=-2)
if gathered_token_final_scales is not None:
gathered_token_final_scales = torch.flatten(
gathered_token_final_scales.contiguous(),
start_dim=0,
end_dim=-2)
gathered_target_rank_ids = MnnvlMoe.compute_target_rank_id(
gathered_token_selected_slots, self.num_slots, self.ep_size)
alltoall_info, token_selected_slots, token_final_scales = MnnvlMoe.mnnvl_moe_alltoallv_prepare(
gathered_target_rank_ids, None, gathered_token_selected_slots,
gathered_token_final_scales, all_rank_max_num_tokens,
self.num_slots, top_k, self.ep_rank, self.ep_size)
alltoall_info, gathered_local_statistic_tensor = MnnvlMoe.mnnvl_moe_alltoallv_prepare_without_allgather(
token_selected_slots, local_statistic_tensor,
self.alltoall_prepare_workspace, all_rank_max_num_tokens,
self.ep_rank, self.ep_size, self.num_experts, self.num_slots, top_k)
if not use_postquant_alltoall:
assert not isinstance(
x, Fp4QuantizedTensor
), "pre-quant alltoall doesn't support fp4 tensor"
x = MnnvlMoe.mnnvl_moe_alltoallv(x, alltoall_info,
self.alltoall_workspace,
self.ep_rank, self.ep_size)
return token_selected_slots, gathered_local_statistic_tensor, alltoall_info
return x, token_selected_slots, token_final_scales, gathered_local_statistic_tensor, alltoall_info
def alltoall_dispatch(self, x: torch.Tensor, x_sf: Optional[torch.Tensor],
token_selected_slots: torch.Tensor,
token_final_scales: Optional[torch.Tensor],
all_rank_max_num_tokens: int, top_k: int,
alltoall_info: MoEAlltoallInfo):
def alltoall_postquant_dispatch(self, x: torch.Tensor, x_sf: torch.Tensor,
alltoall_info: MoEAlltoallInfo):
x = MnnvlMoe.mnnvl_moe_alltoallv(x, alltoall_info,
self.alltoall_workspace, self.ep_rank,
self.ep_size)
x, x_sf, token_selected_slots, token_final_scales = MnnvlMoe.mnnvl_moe_alltoallv(
[x, x_sf, token_selected_slots, token_final_scales], alltoall_info,
self.alltoall_workspace, self.ep_rank, self.ep_size)
if x_sf is not None:
x_sf = MnnvlMoe.mnnvl_moe_alltoallv(x_sf, alltoall_info,
self.alltoall_workspace,
self.ep_rank, self.ep_size)
torch.ops.trtllm.memset_expert_ids(token_selected_slots,
alltoall_info.recv_rank_count_cumsum,
all_rank_max_num_tokens, top_k,
self.num_slots, self.ep_size)
return x, x_sf
return x, x_sf, token_selected_slots, token_final_scales
def alltoall_combine(self, final_hidden_states: torch.Tensor,
alltoall_info: MoEAlltoallInfo, token_count: int):

View File

@ -423,7 +423,8 @@ def load_torch_hf_lora(lora_config: LoraConfig):
pivot model config is the transformer's one.
"""
# TODO smor- need to combine with load_hf_lora
lora_config.trtllm_modules_to_hf_modules = get_default_trtllm_modules_to_hf_modules()
if not lora_config.trtllm_modules_to_hf_modules:
lora_config.trtllm_modules_to_hf_modules = get_default_trtllm_modules_to_hf_modules()
assert len(lora_config.lora_dir) == 1, "Expecting only a single lora dir"
lora_loader = HfLoraLoader(lora_config.lora_dir)

View File

@ -109,7 +109,6 @@
"examples/test_whisper.py::test_llm_whisper_general[large-v3-disable_gemm_plugin-disable_attention_plugin-disable_weight_only-float16-nb:1-use_python_runtime]": 327.95307156071067,
"test_e2e.py::test_build_time_benchmark_sanity": 165.71592589840293,
"test_unittests.py::test_unittests_v2[unittest/trt/attention/test_bert_attention.py]": 99.96196278184652,
"cpp/test_e2e.py::test_benchmarks[gpt-80]": 1376.0404928650241,
"accuracy/test_cli_flow.py::TestPhi3Small128kInstruct::test_auto_dtype": 512.450893450994,
"accuracy/test_llm_api.py::TestLlama3_1_8B::test_fp8_rowwise": 361.5573864541948,
"examples/test_llama.py::test_llm_llama_v3_dora_1gpu[commonsense-llama-v3-8b-dora-r32-llama-v3-8b-hf-base_fp16]": 517.2770831151865,
@ -150,10 +149,6 @@
"test_unittests.py::test_unittests_v2[unittest/_torch/thop]": 852.56,
"test_unittests.py::test_unittests_v2[unittest/_torch/modeling -k \"modeling_mixtral\"]": 208.1838396479725,
"test_unittests.py::test_unittests_v2[unittest/_torch/multi_gpu_modeling -k \"deepseek\"]": 393.0210295501165,
"cpp/test_e2e.py::test_model[-gpt_executor-80]": 4016.7569622844458,
"cpp/test_e2e.py::test_model[-gpt_tests-80]": 1817.8153839111328,
"cpp/test_unit_tests.py::test_unit_tests[executor-80]": 339.0683519244194,
"cpp/test_unit_tests.py::test_unit_tests[kernels-80]": 846.0403860099614,
"test_e2e.py::test_ptp_quickstart_bert[TRTLLM-BertForSequenceClassification-bert/bert-base-uncased-yelp-polarity]": 21.019993914989755,
"test_e2e.py::test_ptp_quickstart_bert[VANILLA-BertForSequenceClassification-bert/bert-base-uncased-yelp-polarity]": 18.753523574909195,
"test_e2e.py::test_ptp_quickstart_multimodal[NVILA-8B-FP16-vila/NVILA-8B-video-False]": 278.4781197870616,
@ -254,35 +249,37 @@
"accuracy/test_cli_flow.py::TestTinyLlama1_1BChat::test_weight_only[int8]": 159.531545445323,
"accuracy/test_cli_flow.py::TestTinyLlama1_1BChat::test_weight_only_int8_kv_cache[int8]": 184.35870655626059,
"test_unittests.py::test_unittests_v2[unittest/llmapi/test_llm_models.py -m \"not (part0 or part1)\"]": 825.9972547292709,
"cpp/test_e2e.py::test_benchmarks[bart-90]": 271.95234084688127,
"cpp/test_e2e.py::test_model[-bart-90]": 391.84748707409017,
"cpp/test_e2e.py::test_benchmarks[gpt-80]": 1376.0404928650241,
"cpp/test_e2e.py::test_model[-gpt_executor-80]": 1495.14,
"cpp/test_e2e.py::test_model[-gpt_tests-80]": 1206.79,
"cpp/test_e2e.py::test_model[-gpt-80]": 1568.98,
"cpp/test_unit_tests.py::test_unit_tests[batch_manager-80]": 1005.24,
"cpp/test_unit_tests.py::test_unit_tests[common-80]": 38.98,
"cpp/test_unit_tests.py::test_unit_tests[executor-80]": 425.16,
"cpp/test_unit_tests.py::test_unit_tests[kernels-80]": 2009.96,
"cpp/test_unit_tests.py::test_unit_tests[layers-80]": 2209.11,
"cpp/test_unit_tests.py::test_unit_tests[runtime-80]": 1671.42,
"cpp/test_unit_tests.py::test_unit_tests[thop-80]": 6.76,
"cpp/test_unit_tests.py::test_unit_tests[utils-80]": 8.53,
"cpp/test_e2e.py::test_model[-eagle-86]": 850.5158995762467,
"cpp/test_e2e.py::test_model[-mamba-86]": 893.8684413917363,
"cpp/test_e2e.py::test_model[-medusa-86]": 577.0913726426661,
"cpp/test_e2e.py::test_model[-redrafter-86]": 356.56682327389717,
"cpp/test_e2e.py::test_benchmarks[t5-90]": 244.83684724476188,
"cpp/test_e2e.py::test_model[-enc_dec_language_adapter-90]": 356.5558080910705,
"cpp/test_e2e.py::test_model[-t5-90]": 167.93334361724555,
"cpp/test_e2e.py::test_model[fp8-llama-90]": 810.9923318810761,
"cpp/test_unit_tests.py::test_unit_tests[batch_manager-90]": 230.49758478673175,
"cpp/test_unit_tests.py::test_unit_tests[common-90]": 12.953252204693854,
"cpp/test_unit_tests.py::test_unit_tests[executor-90]": 317.7621980938129,
"cpp/test_unit_tests.py::test_unit_tests[kernels-90]": 554.3154804841615,
"cpp/test_unit_tests.py::test_unit_tests[layers-90]": 1288.0563295381144,
"cpp/test_unit_tests.py::test_unit_tests[runtime-90]": 876.2420815587975,
"cpp/test_unit_tests.py::test_unit_tests[thop-90]": 2.2652571727521718,
"cpp/test_unit_tests.py::test_unit_tests[utils-90]": 3.6415831856429577,
"cpp/test_e2e.py::test_benchmarks[bart-90]": 271.95234084688127,
"cpp/test_e2e.py::test_benchmarks[t5-90]": 523.07,
"cpp/test_e2e.py::test_model[-bart-90]": 391.84748707409017,
"cpp/test_e2e.py::test_model[-enc_dec_language_adapter-90]": 416.06,
"cpp/test_e2e.py::test_model[-t5-90]": 170.26,
"cpp/test_e2e.py::test_model[fp8-llama-90]": 385.98,
"cpp/test_unit_tests.py::test_unit_tests[common-90]": 25.06,
"cpp/test_unit_tests.py::test_unit_tests[kernels-90]": 1333.18,
"cpp/test_unit_tests.py::test_unit_tests[layers-90]": 1627.07,
"cpp/test_unit_tests.py::test_unit_tests[thop-90]": 4.16,
"cpp/test_unit_tests.py::test_unit_tests[utils-90]": 5.28,
"accuracy/test_cli_flow.py::TestLlama3_1_8B::test_fp8_rowwise_meta_recipe": 634.7149123200215,
"examples/test_llama.py::test_llm_llama_v2_lora_1gpu[chinese-llama-2-lora-13b-llama-v2-13b-hf-lora_fp16-base_fp16]": 895.7611340929288,
"examples/test_nemotron_nas.py::test_nemotron_nas_summary_1gpu[DeciLM-7B]": 335.41048416192643,
"examples/test_phi.py::test_llm_phi_lora_1gpu[Phi-3-mini-4k-instruct-ru-lora-Phi-3-mini-4k-instruct-lora_fp16-base_fp16]": 217.61977925198153,
"cpp/test_e2e.py::test_model[-gpt-80]": 2498.94351779297,
"cpp/test_unit_tests.py::test_unit_tests[batch_manager-80]": 380.8567730002105,
"cpp/test_unit_tests.py::test_unit_tests[common-80]": 20.237869411706924,
"cpp/test_unit_tests.py::test_unit_tests[layers-80]": 2141.2598778679967,
"cpp/test_unit_tests.py::test_unit_tests[runtime-80]": 1491.7047495394945,
"cpp/test_unit_tests.py::test_unit_tests[thop-80]": 3.3458465598523617,
"cpp/test_unit_tests.py::test_unit_tests[utils-80]": 5.461210697889328,
"accuracy/test_cli_flow.py::TestLlama2_7B::test_fp8_gemm_plugin": 593.3573900908232,
"examples/test_gemma.py::test_llm_gemma_1gpu_summary_vswa[gemma-3-1b-it-other-bfloat16-8]": 195.3050664511975,
"test_e2e.py::test_llmapi_quickstart_atexit": 110.45052940770984,

View File

@ -1,4 +1,3 @@
import glob
import logging as _logger
import os as _os
import pathlib as _pl
@ -19,8 +18,16 @@ from defs.conftest import llm_models_root
@pytest.fixture(scope="session")
def build_dir():
return _cpp.find_build_dir()
def build_type():
"""CMake build type for C++ builds."""
# For debugging purposes, we can use the RelWithDebInfo build type.
return _os.environ.get("TLLM_BUILD_TYPE", "Release")
@pytest.fixture(scope="session")
def build_dir(build_type):
"""Resolved build directory for the current build_type."""
return _cpp.find_build_dir(build_type)
@pytest.fixture(scope="session")
@ -148,44 +155,35 @@ def install_additional_requirements(python_exe, root_dir):
@pytest.fixture(scope="session")
def build_google_tests(request, build_dir):
def build_google_tests(request, build_type):
cuda_arch = f"{request.param}-real"
print(f"Using CUDA arch: {cuda_arch}")
_logger.info(f"Using CUDA arch: {cuda_arch}")
build_trt_llm(cuda_architectures=cuda_arch,
job_count=12,
use_ccache=True,
clean=True,
generator="Ninja",
trt_root="/usr/local/tensorrt",
nixl_root="/opt/nvidia/nvda_nixl",
skip_building_wheel=True)
make_google_tests = [
"cmake",
"--build",
".",
"--config",
"Release",
"-j",
"--target",
"google-tests",
]
_cpp.run_command(make_google_tests, cwd=build_dir, timeout=300)
build_trt_llm(
build_type=build_type,
cuda_architectures=cuda_arch,
job_count=12,
use_ccache=True,
clean=True,
generator="Ninja",
trt_root="/usr/local/tensorrt",
nixl_root="/opt/nvidia/nvda_nixl",
skip_building_wheel=True,
extra_make_targets=["google-tests"],
)
@pytest.fixture(scope="session")
def build_benchmarks(build_google_tests, build_dir):
def build_benchmarks(build_google_tests, build_dir, build_type):
make_benchmarks = [
"cmake",
"--build",
".",
"--config",
"Release",
build_type,
"-j",
"--target",
"benchmarks",
@ -224,16 +222,18 @@ def prepare_model(
@pytest.fixture(scope="function", autouse=True)
def keep_log_files(llm_root):
"Backup previous cpp test results when run multiple ctest"
results_dir = f"{llm_root}/cpp/build"
def keep_log_files(build_dir):
"""Backup previous cpp test results when run multiple ctest invocations."""
results_dir = build_dir
yield
backup_dir = f"{llm_root}/cpp/build_backup"
_os.makedirs(backup_dir, exist_ok=True)
# Copy XML files to backup directory
xml_files = glob.glob(f"{results_dir}/*.xml")
build_parent_dir = build_dir.parent
backup_dir_name = build_dir.name + "_backup"
backup_dir = build_parent_dir / backup_dir_name
backup_dir.mkdir(parents=True, exist_ok=True)
# Copy XML files from all subdirectories to backup directory
xml_files = list(results_dir.rglob("*.xml"))
if xml_files:
for xml_file in xml_files:
try:

View File

@ -56,8 +56,6 @@ def generate_result_file_name(test_list: List[str],
def generate_excluded_test_list(test_list):
if "gpt" in test_list:
if "gpt_session" not in test_list:
yield "GptSession"
if "gpt_executor" not in test_list:
yield "GptExecutor"
if "gpt_tests" not in test_list:
@ -84,11 +82,18 @@ def find_root_dir(start_dir: Optional[_pl.Path] = None) -> _pl.Path:
return find_dir_containing(("scripts", "examples", "cpp"), start_dir)
def find_build_dir():
root_dir = find_root_dir()
dir = get_trt_llm_build_dir(None, "Release")
def find_build_dir(build_type: str) -> _pl.Path:
"""Resolve the TRT-LLM C++ build directory for the given CMake build type.
return dir if dir.is_absolute() else root_dir / dir
Args:
build_type: CMake build type (e.g., "Release", "RelWithDebInfo", "Debug").
Returns:
Absolute path to the C++ build directory.
"""
root_dir = find_root_dir()
build_dir = get_trt_llm_build_dir(None, build_type)
return build_dir if build_dir.is_absolute() else root_dir / build_dir
def run_command(command: Sequence[str],
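A minimal sketch of the relative-versus-absolute resolution that `find_build_dir` performs (the paths below are placeholders, not the repository's actual layout):
```python
import pathlib

def resolve_build_dir(root_dir: pathlib.Path, build_dir: pathlib.Path) -> pathlib.Path:
    # Absolute paths are taken as-is; relative ones are anchored at the repo root.
    return build_dir if build_dir.is_absolute() else root_dir / build_dir

print(resolve_build_dir(pathlib.Path("/src/TensorRT-LLM"), pathlib.Path("cpp/build")))  # /src/TensorRT-LLM/cpp/build
print(resolve_build_dir(pathlib.Path("/src/TensorRT-LLM"), pathlib.Path("/opt/build")))  # /opt/build
```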

View File

@ -14,6 +14,7 @@ def run_single_gpu_tests(build_dir: _pl.Path,
timeout=3600):
cpp_env = {**_os.environ}
tests_dir = build_dir / "tests" / "e2e_tests"
included_tests = list(_cpp.generate_included_model_tests(test_list))
@ -38,7 +39,7 @@ def run_single_gpu_tests(build_dir: _pl.Path,
parallel = int(parallel_override)
_cpp.parallel_run_ctest(ctest,
cwd=build_dir,
cwd=tests_dir,
env=cpp_env,
timeout=timeout,
parallel=parallel)
@ -50,12 +51,12 @@ def run_single_gpu_tests(build_dir: _pl.Path,
global_commands=["mpirun", "--allow-run-as-root"],
nranks=2,
local_commands=[
"tests/executor/disaggExecutorTest",
"executor/disaggExecutorTest",
"--gtest_filter=*GptSingleDeviceDisaggSymmetricExecutorTest*"
],
leader_commands=[f"--gtest_output=xml:{xml_output_file}"])
_cpp.run_command(trt_model_test,
cwd=build_dir,
cwd=tests_dir,
env=new_env,
timeout=timeout)
@ -192,7 +193,7 @@ def run_benchmarks(
def run_spec_dec_tests(build_dir: _pl.Path):
xml_output_file = build_dir / "results-spec-dec-fast-logits.xml"
cpp_env = {**_os.environ}
tests_dir = build_dir / "tests"
tests_dir = build_dir / "tests" / "e2e_tests"
trt_model_test = _cpp.produce_mpirun_command(
global_commands=["mpirun", "--allow-run-as-root"],
nranks=3,

View File

@ -50,7 +50,7 @@ def get_multi_gpu_env(kv_cache_type=KVCacheType.NONE, llama_multi_gpu=False):
def run_mpi_utils_tests(build_dir, timeout=300):
tests_dir = build_dir / "tests"
tests_dir = build_dir / "tests" / "unit_tests" / "multi_gpu"
mgpu_env = get_multi_gpu_env()
mpi_utils_test = [
@ -68,7 +68,7 @@ def run_mpi_utils_tests(build_dir, timeout=300):
def run_gemm_allreduce_tests(build_dir, nprocs, timeout=300):
tests_dir = build_dir / "tests"
tests_dir = build_dir / "tests" / "unit_tests" / "multi_gpu"
mgpu_env = get_multi_gpu_env()
gemm_allreduce_test = [
@ -76,7 +76,7 @@ def run_gemm_allreduce_tests(build_dir, nprocs, timeout=300):
"-n",
f"{nprocs}",
"--allow-run-as-root",
"unit_tests/kernels/gemmAllReduceTest",
"kernels/gemmAllReduceTest",
"--m=2032",
"--n=8200",
"--k=1024",
@ -93,7 +93,7 @@ def run_cache_transceiver_tests(build_dir: _pl.Path,
kv_cache_type=KVCacheType.MPI,
timeout=600):
tests_dir = build_dir / "tests"
tests_dir = build_dir / "tests" / "unit_tests" / "multi_gpu"
mgpu_env = get_multi_gpu_env(kv_cache_type=kv_cache_type)
cache_trans_test = [
@ -101,7 +101,7 @@ def run_cache_transceiver_tests(build_dir: _pl.Path,
"-n",
f"{nprocs}",
"--allow-run-as-root",
"batch_manager/cacheTransceiverTest",
"cacheTransceiverTest",
]
_cpp.run_command(cache_trans_test,
cwd=tests_dir,
@ -117,7 +117,7 @@ def run_cache_transceiver_tests(build_dir: _pl.Path,
"-n",
"8",
"--allow-run-as-root",
"batch_manager/cacheTransceiverTest",
"cacheTransceiverTest",
]
_cpp.run_command(cache_trans_test_8_proc,
cwd=tests_dir,
@ -125,8 +125,26 @@ def run_cache_transceiver_tests(build_dir: _pl.Path,
timeout=600)
def run_user_buffer_tests(build_dir: _pl.Path, nprocs=2, timeout=300):
tests_dir = build_dir / "tests" / "unit_tests" / "multi_gpu"
mgpu_env = get_multi_gpu_env()
user_buffer_test = [
"mpirun",
"-n",
f"{nprocs}",
"--allow-run-as-root",
"userBufferTest",
]
_cpp.run_command(user_buffer_test,
cwd=tests_dir,
env=mgpu_env,
timeout=timeout)
def run_llama_executor_leader_tests(build_dir: _pl.Path, timeout=1500):
tests_dir = build_dir / "tests"
tests_dir = build_dir / "tests" / "e2e_tests"
mgpu_env = get_multi_gpu_env(llama_multi_gpu=True)
@ -145,7 +163,7 @@ def run_llama_executor_leader_tests(build_dir: _pl.Path, timeout=1500):
def run_llama_executor_orchestrator_tests(build_dir: _pl.Path, timeout=1500):
tests_dir = build_dir / "tests"
tests_dir = build_dir / "tests" / "e2e_tests"
mgpu_env = get_multi_gpu_env(llama_multi_gpu=True)
@ -160,7 +178,7 @@ def run_llama_executor_orchestrator_tests(build_dir: _pl.Path, timeout=1500):
def run_llama_executor_logits_proc_tests(build_dir: _pl.Path, timeout=1500):
tests_dir = build_dir / "tests"
tests_dir = build_dir / "tests" / "e2e_tests"
mgpu_env = get_multi_gpu_env(llama_multi_gpu=True)
@ -187,7 +205,7 @@ def run_llama_executor_logits_proc_tests(build_dir: _pl.Path, timeout=1500):
def run_llama_executor_guided_decoding_tests(build_dir: _pl.Path, timeout=1500):
tests_dir = build_dir / "tests"
tests_dir = build_dir / "tests" / "e2e_tests"
mgpu_env = get_multi_gpu_env(llama_multi_gpu=True)
@ -214,7 +232,7 @@ def run_llama_executor_guided_decoding_tests(build_dir: _pl.Path, timeout=1500):
def run_enc_dec_multi_gpu_tests(build_dir: _pl.Path, timeout=1500):
tests_dir = build_dir / "tests"
tests_dir = build_dir / "tests" / "e2e_tests"
cpp_env = {**_os.environ}
#EncDec test in leader mode
@ -233,7 +251,7 @@ def run_enc_dec_multi_gpu_tests(build_dir: _pl.Path, timeout=1500):
def run_trt_gpt_model_real_decoder_multi_gpu_tests(build_dir: _pl.Path,
timeout=1500):
tests_dir = build_dir / "tests"
tests_dir = build_dir / "tests" / "e2e_tests"
cpp_env = {**_os.environ}
xml_output_file = build_dir / "results-multi-gpu-real-decoder.xml"
@ -256,7 +274,7 @@ def run_disagg_symmetric_executor_tests(build_dir: _pl.Path,
nprocs=2,
kvcache_type=KVCacheType.MPI,
timeout=1500):
tests_dir = build_dir / "tests"
tests_dir = build_dir / "tests" / "e2e_tests"
prefix = get_model_test_filter_prefix(model)
@ -285,7 +303,7 @@ def run_disagg_asymmetric_executor_tests(build_dir: _pl.Path,
kvcache_type=KVCacheType.MPI,
timeout=1500):
tests_dir = build_dir / "tests"
tests_dir = build_dir / "tests" / "e2e_tests"
prefix = get_model_test_filter_prefix(model)
@ -314,7 +332,7 @@ def run_disagg_orchestrator_params_tests(build_dir: _pl.Path,
kvcache_type=KVCacheType.MPI,
timeout=1500):
tests_dir = build_dir / "tests"
tests_dir = build_dir / "tests" / "e2e_tests"
prefix = get_model_test_filter_prefix(model)
@ -341,7 +359,7 @@ def run_disagg_spawn_orchestrator_tests(build_dir: _pl.Path,
kvcache_type=False,
timeout=1500):
tests_dir = build_dir / "tests"
tests_dir = build_dir / "tests" / "e2e_tests"
prefix = get_model_test_filter_prefix(model)
@ -352,7 +370,7 @@ def run_disagg_spawn_orchestrator_tests(build_dir: _pl.Path,
comms = [
"executor/disaggExecutorTest",
f"--gtest_filter=*{prefix}*DisaaggSpawnOrchestrator*",
f"--gtest_filter=*{prefix}*DisaggSpawnOrchestrator*",
f"--gtest_output=xml:{xml_output_file}"
]
_cpp.run_command(comms, cwd=tests_dir, env=mgpu_env, timeout=timeout)
@ -488,13 +506,21 @@ def test_fused_gemm_allreduce(build_google_tests, nprocs, build_dir):
def test_cache_transceiver(build_google_tests, nprocs, kvcache_type, build_dir):
if platform.system() != "Windows":
run_cache_transceiver_tests(build_dir=build_dir,
nprocs=nprocs,
kv_cache_type=kvcache_type,
timeout=600)
@pytest.mark.parametrize("build_google_tests", ["80", "86", "89", "90"],
indirect=True)
@pytest.mark.parametrize("nprocs", [2, 8], ids=["2proc", "8proc"])
def test_user_buffer(build_google_tests, nprocs, build_dir):
if platform.system() != "Windows":
run_user_buffer_tests(build_dir=build_dir, nprocs=nprocs, timeout=300)
@pytest.mark.parametrize("build_google_tests", ["80", "86", "89", "90"],
indirect=True)
@pytest.mark.parametrize("multi_gpu_model", ["t5"], indirect=True)

View File

@ -38,13 +38,13 @@ l0_a30:
- cpp/test_unit_tests.py::test_unit_tests[common-80]
- cpp/test_unit_tests.py::test_unit_tests[executor-80]
- cpp/test_unit_tests.py::test_unit_tests[kernels-80]
- cpp/test_unit_tests.py::test_unit_tests[layers-80] TIMEOUT (90)
- cpp/test_unit_tests.py::test_unit_tests[layers-80]
- cpp/test_unit_tests.py::test_unit_tests[runtime-80]
- cpp/test_unit_tests.py::test_unit_tests[thop-80]
- cpp/test_unit_tests.py::test_unit_tests[utils-80]
- cpp/test_e2e.py::test_model[-gpt-80] TIMEOUT (90)
- cpp/test_e2e.py::test_model[-gpt_executor-80] TIMEOUT (90)
- cpp/test_e2e.py::test_model[-gpt_executor-80]
- cpp/test_e2e.py::test_model[-gpt_tests-80]
- cpp/test_e2e.py::test_model[-gpt-80]
- condition:
ranges:
system_gpu_count:

View File

@ -69,6 +69,8 @@ l0_b200:
- unittest/_torch/modeling -k "modeling_deepseek"
- unittest/_torch/modeling -k "modeling_gpt_oss"
- unittest/_torch/auto_deploy/unit/singlegpu -k "not test_trtllm_bench_backend_comparison"
# ------------- AutoDeploy tests ---------------
- accuracy/test_llm_api_autodeploy.py::TestLlama3_1_8B::test_auto_dtype
- condition:
ranges:
system_gpu_count:

View File

@ -89,3 +89,5 @@ l0_dgx_b200:
- disaggregated/test_disaggregated.py::test_disaggregated_benchmark_on_diff_backends[DeepSeek-V3-Lite-fp8]
- accuracy/test_disaggregated_serving.py::TestQwen3_8B::test_nixl_backend
- accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_nixl_backend
# ------------- AutoDeploy tests ---------------
- accuracy/test_llm_api_autodeploy.py::TestLlama3_1_8B::test_auto_dtype

View File

@ -61,6 +61,8 @@ l0_dgx_h100:
- test_e2e.py::test_ptp_quickstart_advanced_bs1
- test_e2e.py::test_ptp_quickstart_advanced_deepseek_v3_lite_4gpus_adp_balance[DeepSeek-V3-Lite-FP8-DeepSeek-V3-Lite/fp8]
- unittest/_torch/modeling/test_modeling_pixtral.py::test_tensor_parallelism
# ------------- AutoDeploy tests ---------------
- accuracy/test_llm_api_autodeploy.py::TestLlama3_1_8B::test_auto_dtype
- condition:
ranges:
system_gpu_count:
@ -177,6 +179,7 @@ l0_dgx_h100:
- cpp/test_multi_gpu.py::test_cache_transceiver[2proc-ucx_kvcache-90]
- cpp/test_multi_gpu.py::test_cache_transceiver[8proc-mpi_kvcache-90]
- cpp/test_multi_gpu.py::test_cache_transceiver[8proc-ucx_kvcache-90]
- cpp/test_multi_gpu.py::test_user_buffer[2proc-90]
- cpp/test_multi_gpu.py::test_enc_dec[t5-90]
- cpp/test_multi_gpu.py::test_llama_executor[llama-orchestrator-90]
- cpp/test_multi_gpu.py::test_llama_executor[llama-leader-90]
@ -209,18 +212,3 @@ l0_dgx_h100:
- cpp/test_multi_gpu.py::TestDisagg::test_spawn_orchestrator[llama-ucx_kvcache-90]
- cpp/test_multi_gpu.py::TestDisagg::test_orchestrator_params[llama-nixl_kvcache-90] TIMEOUT (90)
- cpp/test_multi_gpu.py::TestDisagg::test_spawn_orchestrator[llama-nixl_kvcache-90]
- condition:
ranges:
system_gpu_count:
gte: 4
lte: 4
wildcards:
gpu:
- '*h100*'
linux_distribution_name: ubuntu*
terms:
stage: post_merge
backend: triton
auto_trigger: others
tests:
- triton_server/test_triton_llm.py::test_llmapi_backend[4-0-disableDecoupleMode-tensorrt_llm]

View File

@ -34,6 +34,8 @@ l0_dgx_h200:
- unittest/_torch/multi_gpu_modeling/test_llama4.py::test_llama4[pp1-ep1-disable_adp-enable_graph-tp8-trtllm-scout]
- unittest/_torch/multi_gpu_modeling/test_llama4.py::test_llama4[pp1-ep4-enable_adp-enable_graph-tp8-trtllm-scout]
- unittest/llmapi/test_llm_pytorch.py::test_nemotron_nas_lora
# ------------- AutoDeploy tests ---------------
- accuracy/test_llm_api_autodeploy.py::TestLlama3_1_8B::test_auto_dtype
- condition:
ranges:
system_gpu_count:
@ -166,3 +168,19 @@ l0_dgx_h200:
- examples/test_enc_dec.py::test_llm_enc_dec_general[compare_hf-t5-small-float16-enable_gemm_plugin-enable_attention_plugin-disable_paged_kv_cache-tp:2-pp:2-nb:1-disable_fp8]
- examples/test_gpt.py::test_llm_gpt2_next_prompt_tuning[use_py_session-tp2]
- unittest/llmapi/apps/_test_openai_multi_gpu.py -m "part0"
- condition:
ranges:
system_gpu_count:
gte: 4
lte: 4
wildcards:
gpu:
- '*h200*'
linux_distribution_name: ubuntu*
cpu: x86_64
terms:
stage: post_merge
backend: triton
tests:
# ------------- Triton tests ---------------
- triton_server/test_triton_llm.py::test_llmapi_backend[4-0-disableDecoupleMode-tensorrt_llm]
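All of these l0_*.yml test lists share the same shape: an optional version field plus a list of blocks, each pairing a condition (ranges, wildcards, terms) with a tests list. The sketch below is a simplified, hypothetical reader for that structure, not the project's actual pipeline code; PyYAML is assumed, and the linux_distribution_name / cpu wildcards are ignored for brevity:
import fnmatch

import yaml  # PyYAML assumed, for illustration only


def collect_tests(yml_path, gpu_name, gpu_count, stage, backend):
    """Hypothetical sketch: pick the tests whose condition matches this machine and stage."""
    with open(yml_path) as f:
        doc = yaml.safe_load(f)
    selected = []
    for key, blocks in doc.items():
        if key == "version":
            continue
        for block in blocks:
            cond = block.get("condition", {})
            gpu_range = cond.get("ranges", {}).get("system_gpu_count", {})
            if not gpu_range.get("gte", 0) <= gpu_count <= gpu_range.get("lte", gpu_count):
                continue
            gpu_patterns = cond.get("wildcards", {}).get("gpu", ["*"])
            if not any(fnmatch.fnmatch(gpu_name, pat) for pat in gpu_patterns):
                continue
            terms = cond.get("terms", {})
            if terms.get("stage", stage) != stage or terms.get("backend", backend) != backend:
                continue
            selected.extend(block.get("tests", []))
    return selected


# Usage sketch: collect_tests("l0_dgx_h200.yml", "h200", 4, "post_merge", "triton")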

View File

@ -1,10 +1,12 @@
# Don't add any tests here.
# Copied from l0_b200.yml but not used in the pipeline now
version: 0.0.1
l0_gb200:
- condition:
ranges:
system_gpu_count:
gte: 4
lte: 4
gte: 1
lte: 1
wildcards:
gpu:
- '*gb200*'
@ -14,33 +16,111 @@ l0_gb200:
stage: pre_merge
backend: pytorch
tests:
- accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_bfloat16_4gpus[tp4-attn_backend=TRTLLM-torch_compile=True]
- accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_bfloat16_4gpus[tp4-attn_backend=FLASHINFER-torch_compile=False]
- accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_bfloat16_4gpus[pp4-attn_backend=TRTLLM-torch_compile=False]
- accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_fp8_4gpus[tp4-fp8kv=False-attn_backend=TRTLLM-torch_compile=True]
- accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_fp8_4gpus[tp4-fp8kv=False-attn_backend=FLASHINFER-torch_compile=False]
- accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_fp8_4gpus[pp4-fp8kv=True-attn_backend=TRTLLM-torch_compile=False]
# ------------- PyTorch tests ---------------
- accuracy/test_llm_api_pytorch.py::TestLlama3_1_8B::test_nvfp4
- accuracy/test_llm_api_pytorch.py::TestLlama3_1_8B::test_nvfp4_streaming[stream_interval_4]
- accuracy/test_llm_api_pytorch.py::TestLlama3_1_8B::test_nvfp4_streaming[stream_interval_64]
- accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_bfloat16[mtp_nextn=0-attention_dp=True-cuda_graph=True-overlap_scheduler=True-torch_compile=False-enable_chunked_prefill=False]
- accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_bfloat16[mtp_nextn=2-attention_dp=True-cuda_graph=True-overlap_scheduler=True-torch_compile=False-enable_chunked_prefill=False]
- accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_bfloat16[mtp_nextn=2-attention_dp=True-cuda_graph=True-overlap_scheduler=True-torch_compile=False-enable_chunked_prefill=True]
- accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_bfloat16_4gpus[tp4-mtp_nextn=2-attention_dp=True-cuda_graph=True-overlap_scheduler=True-torch_compile=True]
- accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_bfloat16_4gpus[ep4-mtp_nextn=2-attention_dp=True-cuda_graph=True-overlap_scheduler=True-torch_compile=False]
- accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_nvfp4_4gpus[moe_backend=CUTLASS-mtp_nextn=0-tp4-fp8kv=True-attention_dp=True-cuda_graph=True-overlap_scheduler=True-torch_compile=True]
- accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_nvfp4_4gpus[moe_backend=CUTLASS-mtp_nextn=0-ep4-fp8kv=False-attention_dp=False-cuda_graph=False-overlap_scheduler=False-torch_compile=False]
- accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_nvfp4_4gpus[moe_backend=CUTLASS-mtp_nextn=0-ep4-fp8kv=False-attention_dp=False-cuda_graph=False-overlap_scheduler=False-torch_compile=True]
- accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_nvfp4_4gpus[moe_backend=CUTLASS-mtp_nextn=0-tp2pp2-fp8kv=True-attention_dp=True-cuda_graph=True-overlap_scheduler=True-torch_compile=False]
- accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_nvfp4_4gpus[moe_backend=CUTLASS-mtp_nextn=0-tp2pp2-fp8kv=True-attention_dp=True-cuda_graph=True-overlap_scheduler=True-torch_compile=True]
- accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_nvfp4_4gpus[moe_backend=CUTLASS-mtp_nextn=0-pp4-fp8kv=False-attention_dp=False-cuda_graph=False-overlap_scheduler=False-torch_compile=False]
- accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_nvfp4_4gpus[moe_backend=CUTLASS-mtp_nextn=0-pp4-fp8kv=False-attention_dp=False-cuda_graph=False-overlap_scheduler=False-torch_compile=True]
- accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_nvfp4_4gpus[moe_backend=CUTLASS-mtp_nextn=0-pp4-fp8kv=True-attention_dp=True-cuda_graph=True-overlap_scheduler=True-torch_compile=False]
- accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_nvfp4_4gpus[moe_backend=CUTLASS-mtp_nextn=0-pp4-fp8kv=True-attention_dp=True-cuda_graph=True-overlap_scheduler=True-torch_compile=True]
- accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_nvfp4_4gpus[moe_backend=TRTLLM-mtp_nextn=0-tp4-fp8kv=True-attention_dp=True-cuda_graph=True-overlap_scheduler=True-torch_compile=False]
- accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_nvfp4_4gpus[moe_backend=TRTLLM-mtp_nextn=0-ep4-fp8kv=True-attention_dp=True-cuda_graph=True-overlap_scheduler=True-torch_compile=False]
- accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_nvfp4[moe_backend=CUTLASS-mtp_nextn=0-fp8kv=True-attention_dp=False-cuda_graph=True-overlap_scheduler=True-torch_compile=True]
- accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_nvfp4[moe_backend=CUTLASS-mtp_nextn=0-fp8kv=True-attention_dp=True-cuda_graph=True-overlap_scheduler=True-torch_compile=False]
- accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_nvfp4[moe_backend=TRTLLM-mtp_nextn=0-fp8kv=True-attention_dp=False-cuda_graph=True-overlap_scheduler=True-torch_compile=False]
- accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_nvfp4[moe_backend=CUTLASS-mtp_nextn=2-fp8kv=True-attention_dp=False-cuda_graph=True-overlap_scheduler=True-torch_compile=False]
- accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_nvfp4[moe_backend=TRTLLM-mtp_nextn=2-fp8kv=True-attention_dp=False-cuda_graph=True-overlap_scheduler=True-torch_compile=False]
- accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_no_kv_cache_reuse[quant_dtype=none-mtp_nextn=2-fp8kv=False-attention_dp=True-cuda_graph=True-overlap_scheduler=True]
- accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_no_kv_cache_reuse[quant_dtype=nvfp4-mtp_nextn=0-fp8kv=True-attention_dp=True-cuda_graph=True-overlap_scheduler=True]
- accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_chunked_prefill[quant_dtype=none-kv_cache_reuse=True-fp8kv=False-overlap_scheduler=True]
- accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_chunked_prefill[quant_dtype=none-kv_cache_reuse=False-fp8kv=False-overlap_scheduler=True]
- accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_chunked_prefill[quant_dtype=nvfp4-kv_cache_reuse=True-fp8kv=False-overlap_scheduler=True]
- accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_chunked_prefill[quant_dtype=nvfp4-kv_cache_reuse=True-fp8kv=True-overlap_scheduler=True]
- accuracy/test_llm_api_pytorch.py::TestQwen3_8B::test_w4a8_mxfp4[fp8-latency]
- accuracy/test_llm_api_pytorch.py::TestQwen3_8B::test_w4a8_mxfp4[mxfp8-latency]
- accuracy/test_llm_api_pytorch.py::TestQwen3_30B_A3B::test_nvfp4[latency_moe_cutlass-torch_compile=False]
- accuracy/test_llm_api_pytorch.py::TestQwen3_30B_A3B::test_nvfp4[latency_moe_cutlass-torch_compile=True]
- accuracy/test_llm_api_pytorch.py::TestQwen3_30B_A3B::test_nvfp4[latency_moe_trtllm-torch_compile=False]
- accuracy/test_llm_api_pytorch.py::TestQwen3_30B_A3B::test_nvfp4[latency_moe_trtllm-torch_compile=True]
- accuracy/test_llm_api_pytorch.py::TestQwen3_30B_A3B::test_w4a8_mxfp4[fp8-latency-CUTLASS]
- accuracy/test_llm_api_pytorch.py::TestQwen3_30B_A3B::test_w4a8_mxfp4[fp8-latency-TRITON]
- accuracy/test_llm_api_pytorch.py::TestQwen3_30B_A3B::test_w4a8_mxfp4[fp8-latency-TRTLLM]
- accuracy/test_llm_api_pytorch.py::TestQwen3_30B_A3B::test_w4a8_mxfp4[mxfp8-latency-TRTLLM]
- accuracy/test_llm_api_pytorch.py::TestQwen3_30B_A3B::test_w4a8_mxfp4[mxfp8-latency-CUTLASS]
- accuracy/test_llm_api_pytorch.py::TestQwen3_30B_A3B::test_w4a16_mxfp4[latency-TRTLLM]
- accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_1gpu[True-True-cutlass]
- accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_1gpu[True-True-trtllm]
- accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_1gpu[True-True-triton]
- disaggregated/test_workers.py::test_workers_kv_cache_aware_router_eviction[TinyLlama-1.1B-Chat-v1.0] # nvbugs 5300551
- test_e2e.py::test_ptp_quickstart_advanced[Llama3.1-8B-NVFP4-nvfp4-quantized/Meta-Llama-3.1-8B]
- test_e2e.py::test_ptp_quickstart_advanced[Llama3.1-8B-FP8-llama-3.1-model/Llama-3.1-8B-Instruct-FP8]
- test_e2e.py::test_ptp_quickstart_advanced_mtp[DeepSeek-V3-Lite-BF16-DeepSeek-V3-Lite/bf16]
- test_e2e.py::test_ptp_quickstart_advanced_mixed_precision
- test_e2e.py::test_ptp_quickstart_advanced_eagle3[Llama-3.1-8b-Instruct-llama-3.1-model/Llama-3.1-8B-Instruct-EAGLE3-LLaMA3.1-Instruct-8B]
- test_e2e.py::test_ptp_quickstart_advanced_ngram[Llama-3.1-8B-Instruct-llama-3.1-model/Llama-3.1-8B-Instruct]
- test_e2e.py::test_trtllm_bench_pytorch_backend_sanity[meta-llama/Llama-3.1-8B-llama-3.1-8b-False-False]
- unittest/_torch/attention
- unittest/_torch/compilation
- unittest/_torch/debugger
- unittest/_torch/executor
- unittest/_torch/misc
- unittest/_torch/modules
- unittest/_torch/multimodal
- unittest/_torch/sampler
- unittest/_torch/speculative
- unittest/_torch/thop
- unittest/_torch/modeling -k "modeling_llama"
- unittest/_torch/modeling -k "modeling_mixtral"
- unittest/_torch/modeling -k "modeling_deepseek"
- unittest/_torch/modeling -k "modeling_gpt_oss"
- unittest/_torch/auto_deploy/unit/singlegpu -k "not test_trtllm_bench_backend_comparison"
- condition:
ranges:
system_gpu_count:
gte: 4
lte: 4
gte: 1
lte: 1
wildcards:
gpu:
- '*gb200*'
linux_distribution_name: ubuntu*
cpu: aarch64
terms:
stage: post_merge
backend: tensorrt
tests:
# ------------- TRT tests ---------------
- accuracy/test_cli_flow.py::TestLlama3_8BInstruct::test_nvfp4
- accuracy/test_cli_flow.py::TestLlama3_8BInstruct::test_nvfp4_gemm_plugin[disable_norm_quant_fusion-disable_fused_quant]
- accuracy/test_cli_flow.py::TestLlama3_8BInstruct::test_nvfp4_gemm_plugin[disable_norm_quant_fusion-enable_fused_quant]
- accuracy/test_cli_flow.py::TestLlama3_8BInstruct::test_nvfp4_gemm_plugin[enable_norm_quant_fusion-disable_fused_quant]
- accuracy/test_cli_flow.py::TestLlama3_8BInstruct::test_nvfp4_gemm_plugin[enable_norm_quant_fusion-enable_fused_quant]
- accuracy/test_cli_flow.py::TestLlama3_8BInstruct::test_auto_dtype
- accuracy/test_cli_flow.py::TestLlama3_8BInstruct::test_fp8
- accuracy/test_cli_flow.py::TestMixtral8x7B::test_nvfp4_prequantized
- unittest/trt/attention/test_gpt_attention.py -k "trtllm_gen"
- unittest/llmapi/test_llm_quant.py
- unittest/trt/functional/test_fp4_gemm.py
- condition:
ranges:
system_gpu_count:
gte: 1
lte: 1
wildcards:
gpu:
- '*gb200*'
linux_distribution_name: ubuntu*
cpu: aarch64
terms:
stage: post_merge
backend: triton
tests:
# ------------- Triton tests ---------------
- triton_server/test_triton.py::test_llava[llava]
- triton_server/test_triton.py::test_gpt_ib_ptuning[gpt-ib-ptuning]
- triton_server/test_triton.py::test_gpt_2b_ib_lora[gpt-2b-ib-lora]
- condition:
ranges:
system_gpu_count:
gte: 1
lte: 1
wildcards:
gpu:
- '*gb200*'
@ -50,23 +130,16 @@ l0_gb200:
stage: post_merge
backend: pytorch
tests:
- accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_fp8_4gpus[tp4-fp8kv=True-attn_backend=TRTLLM-torch_compile=False]
- accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_fp8_4gpus[tp4-fp8kv=True-attn_backend=FLASHINFER-torch_compile=True]
- accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_fp8_4gpus[pp4-fp8kv=False-attn_backend=TRTLLM-torch_compile=False]
- accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_nvfp4_4gpus[moe_backend=CUTLASS-mtp_nextn=0-ep4-fp8kv=True-attention_dp=True-cuda_graph=True-overlap_scheduler=True-torch_compile=False]
- accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_nvfp4_4gpus[moe_backend=CUTLASS-mtp_nextn=0-ep4-fp8kv=True-attention_dp=True-cuda_graph=True-overlap_scheduler=True-torch_compile=True]
- accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_nvfp4_4gpus[moe_backend=CUTLASS-mtp_nextn=0-tp2pp2-fp8kv=False-attention_dp=False-cuda_graph=False-overlap_scheduler=False-torch_compile=False]
- accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_nvfp4_4gpus[moe_backend=CUTLASS-mtp_nextn=0-tp2pp2-fp8kv=False-attention_dp=False-cuda_graph=False-overlap_scheduler=False-torch_compile=True]
- accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_nvfp4_4gpus[moe_backend=CUTLASS-mtp_nextn=2-tp4-fp8kv=False-attention_dp=False-cuda_graph=False-overlap_scheduler=False-torch_compile=False]
- accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_nvfp4_4gpus[moe_backend=CUTLASS-mtp_nextn=2-tp4-fp8kv=True-attention_dp=True-cuda_graph=True-overlap_scheduler=True-torch_compile=False]
- accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_nvfp4_4gpus[moe_backend=CUTLASS-mtp_nextn=2-ep4-fp8kv=False-attention_dp=False-cuda_graph=False-overlap_scheduler=False-torch_compile=False]
- accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_nvfp4_4gpus[moe_backend=CUTLASS-mtp_nextn=2-ep4-fp8kv=True-attention_dp=True-cuda_graph=True-overlap_scheduler=True-torch_compile=False]
- accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_nvfp4_4gpus[moe_backend=CUTLASS-mtp_nextn=2-tp2pp2-fp8kv=False-attention_dp=False-cuda_graph=False-overlap_scheduler=False-torch_compile=False]
- accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_nvfp4_4gpus[moe_backend=CUTLASS-mtp_nextn=2-tp2pp2-fp8kv=True-attention_dp=True-cuda_graph=True-overlap_scheduler=True-torch_compile=False]
- accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_nvfp4_4gpus[moe_backend=CUTLASS-mtp_nextn=2-pp4-fp8kv=False-attention_dp=False-cuda_graph=False-overlap_scheduler=False-torch_compile=False]
- accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_nvfp4_4gpus[moe_backend=CUTLASS-mtp_nextn=2-pp4-fp8kv=True-attention_dp=True-cuda_graph=True-overlap_scheduler=True-torch_compile=False]
- accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_nvfp4_4gpus[moe_backend=TRTLLM-mtp_nextn=2-tp4-fp8kv=True-attention_dp=True-cuda_graph=True-overlap_scheduler=True-torch_compile=False]
- accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_nvfp4_4gpus[moe_backend=TRTLLM-mtp_nextn=2-ep4-fp8kv=True-attention_dp=True-cuda_graph=True-overlap_scheduler=True-torch_compile=False]
- accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_nvfp4_4gpus_online_eplb[fp8kv=True]
- accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_bfloat16_4gpus_online_eplb[mtp_nextn=2]
- accuracy/test_llm_api_pytorch.py::TestQwen3_235B_A22B::test_nvfp4_4gpus[latency_moe_trtllm_eagle3] TIMEOUT (90)
# ------------- PyTorch tests ---------------
- accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_nvfp4[moe_backend=CUTLASS-mtp_nextn=0-fp8kv=False-attention_dp=False-cuda_graph=False-overlap_scheduler=False-torch_compile=False]
- accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_nvfp4[moe_backend=CUTLASS-mtp_nextn=0-fp8kv=False-attention_dp=False-cuda_graph=False-overlap_scheduler=False-torch_compile=True]
- accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_nvfp4[moe_backend=CUTLASS-mtp_nextn=0-fp8kv=True-attention_dp=False-cuda_graph=False-overlap_scheduler=False-torch_compile=False]
- accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_nvfp4[moe_backend=CUTLASS-mtp_nextn=2-fp8kv=False-attention_dp=False-cuda_graph=False-overlap_scheduler=False-torch_compile=False]
- accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_nvfp4[moe_backend=CUTLASS-mtp_nextn=2-fp8kv=True-attention_dp=False-cuda_graph=False-overlap_scheduler=False-torch_compile=False]
- accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_nvfp4[moe_backend=CUTLASS-mtp_nextn=0-fp8kv=False-attention_dp=True-cuda_graph=False-overlap_scheduler=False-torch_compile=False]
- accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_nvfp4[moe_backend=CUTLASS-mtp_nextn=0-fp8kv=False-attention_dp=True-cuda_graph=False-overlap_scheduler=False-torch_compile=True]
- accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_nvfp4[moe_backend=CUTLASS-mtp_nextn=0-fp8kv=False-attention_dp=False-cuda_graph=False-overlap_scheduler=True-torch_compile=False]
- accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_nvfp4[moe_backend=CUTLASS-mtp_nextn=0-fp8kv=False-attention_dp=False-cuda_graph=True-overlap_scheduler=False-torch_compile=False]
- accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_nvfp4[moe_backend=TRTLLM-mtp_nextn=0-fp8kv=False-attention_dp=False-cuda_graph=True-overlap_scheduler=False-torch_compile=False]
- accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_nvfp4[moe_backend=CUTLASS-mtp_nextn=2-fp8kv=False-attention_dp=False-cuda_graph=True-overlap_scheduler=False-torch_compile=False]
- accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_nvfp4[moe_backend=TRTLLM-mtp_nextn=2-fp8kv=False-attention_dp=False-cuda_graph=True-overlap_scheduler=False-torch_compile=False]

View File

@ -0,0 +1,69 @@
version: 0.0.1
l0_gb200_multi_gpus:
- condition:
ranges:
system_gpu_count:
gte: 4
lte: 4
wildcards:
gpu:
- '*gb200*'
linux_distribution_name: ubuntu*
cpu: aarch64
terms:
stage: pre_merge
backend: pytorch
tests:
- accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_bfloat16_4gpus[tp4-attn_backend=TRTLLM-torch_compile=True]
- accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_bfloat16_4gpus[tp4-attn_backend=FLASHINFER-torch_compile=False]
- accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_bfloat16_4gpus[pp4-attn_backend=TRTLLM-torch_compile=False]
- accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_fp8_4gpus[tp4-fp8kv=False-attn_backend=TRTLLM-torch_compile=True]
- accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_fp8_4gpus[tp4-fp8kv=False-attn_backend=FLASHINFER-torch_compile=False]
- accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_fp8_4gpus[pp4-fp8kv=True-attn_backend=TRTLLM-torch_compile=False]
- accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_bfloat16_4gpus[tp4-mtp_nextn=2-attention_dp=True-cuda_graph=True-overlap_scheduler=True-torch_compile=True]
- accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_bfloat16_4gpus[ep4-mtp_nextn=2-attention_dp=True-cuda_graph=True-overlap_scheduler=True-torch_compile=False]
- accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_nvfp4_4gpus[moe_backend=CUTLASS-mtp_nextn=0-tp4-fp8kv=True-attention_dp=True-cuda_graph=True-overlap_scheduler=True-torch_compile=True]
- accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_nvfp4_4gpus[moe_backend=CUTLASS-mtp_nextn=0-ep4-fp8kv=False-attention_dp=False-cuda_graph=False-overlap_scheduler=False-torch_compile=False]
- accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_nvfp4_4gpus[moe_backend=CUTLASS-mtp_nextn=0-ep4-fp8kv=False-attention_dp=False-cuda_graph=False-overlap_scheduler=False-torch_compile=True]
- accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_nvfp4_4gpus[moe_backend=CUTLASS-mtp_nextn=0-tp2pp2-fp8kv=True-attention_dp=True-cuda_graph=True-overlap_scheduler=True-torch_compile=False]
- accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_nvfp4_4gpus[moe_backend=CUTLASS-mtp_nextn=0-tp2pp2-fp8kv=True-attention_dp=True-cuda_graph=True-overlap_scheduler=True-torch_compile=True]
- accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_nvfp4_4gpus[moe_backend=CUTLASS-mtp_nextn=0-pp4-fp8kv=False-attention_dp=False-cuda_graph=False-overlap_scheduler=False-torch_compile=False]
- accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_nvfp4_4gpus[moe_backend=CUTLASS-mtp_nextn=0-pp4-fp8kv=False-attention_dp=False-cuda_graph=False-overlap_scheduler=False-torch_compile=True]
- accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_nvfp4_4gpus[moe_backend=CUTLASS-mtp_nextn=0-pp4-fp8kv=True-attention_dp=True-cuda_graph=True-overlap_scheduler=True-torch_compile=False]
- accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_nvfp4_4gpus[moe_backend=CUTLASS-mtp_nextn=0-pp4-fp8kv=True-attention_dp=True-cuda_graph=True-overlap_scheduler=True-torch_compile=True]
- accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_nvfp4_4gpus[moe_backend=TRTLLM-mtp_nextn=0-tp4-fp8kv=True-attention_dp=True-cuda_graph=True-overlap_scheduler=True-torch_compile=False]
- accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_nvfp4_4gpus[moe_backend=TRTLLM-mtp_nextn=0-ep4-fp8kv=True-attention_dp=True-cuda_graph=True-overlap_scheduler=True-torch_compile=False]
- condition:
ranges:
system_gpu_count:
gte: 4
lte: 4
wildcards:
gpu:
- '*gb200*'
linux_distribution_name: ubuntu*
cpu: aarch64
terms:
stage: post_merge
backend: pytorch
tests:
- accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_fp8_4gpus[tp4-fp8kv=True-attn_backend=TRTLLM-torch_compile=False]
- accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_fp8_4gpus[tp4-fp8kv=True-attn_backend=FLASHINFER-torch_compile=True]
- accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_fp8_4gpus[pp4-fp8kv=False-attn_backend=TRTLLM-torch_compile=False]
- accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_nvfp4_4gpus[moe_backend=CUTLASS-mtp_nextn=0-ep4-fp8kv=True-attention_dp=True-cuda_graph=True-overlap_scheduler=True-torch_compile=False]
- accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_nvfp4_4gpus[moe_backend=CUTLASS-mtp_nextn=0-ep4-fp8kv=True-attention_dp=True-cuda_graph=True-overlap_scheduler=True-torch_compile=True]
- accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_nvfp4_4gpus[moe_backend=CUTLASS-mtp_nextn=0-tp2pp2-fp8kv=False-attention_dp=False-cuda_graph=False-overlap_scheduler=False-torch_compile=False]
- accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_nvfp4_4gpus[moe_backend=CUTLASS-mtp_nextn=0-tp2pp2-fp8kv=False-attention_dp=False-cuda_graph=False-overlap_scheduler=False-torch_compile=True]
- accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_nvfp4_4gpus[moe_backend=CUTLASS-mtp_nextn=2-tp4-fp8kv=False-attention_dp=False-cuda_graph=False-overlap_scheduler=False-torch_compile=False]
- accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_nvfp4_4gpus[moe_backend=CUTLASS-mtp_nextn=2-tp4-fp8kv=True-attention_dp=True-cuda_graph=True-overlap_scheduler=True-torch_compile=False]
- accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_nvfp4_4gpus[moe_backend=CUTLASS-mtp_nextn=2-ep4-fp8kv=False-attention_dp=False-cuda_graph=False-overlap_scheduler=False-torch_compile=False]
- accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_nvfp4_4gpus[moe_backend=CUTLASS-mtp_nextn=2-ep4-fp8kv=True-attention_dp=True-cuda_graph=True-overlap_scheduler=True-torch_compile=False]
- accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_nvfp4_4gpus[moe_backend=CUTLASS-mtp_nextn=2-tp2pp2-fp8kv=False-attention_dp=False-cuda_graph=False-overlap_scheduler=False-torch_compile=False]
- accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_nvfp4_4gpus[moe_backend=CUTLASS-mtp_nextn=2-tp2pp2-fp8kv=True-attention_dp=True-cuda_graph=True-overlap_scheduler=True-torch_compile=False]
- accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_nvfp4_4gpus[moe_backend=CUTLASS-mtp_nextn=2-pp4-fp8kv=False-attention_dp=False-cuda_graph=False-overlap_scheduler=False-torch_compile=False]
- accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_nvfp4_4gpus[moe_backend=CUTLASS-mtp_nextn=2-pp4-fp8kv=True-attention_dp=True-cuda_graph=True-overlap_scheduler=True-torch_compile=False]
- accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_nvfp4_4gpus[moe_backend=TRTLLM-mtp_nextn=2-tp4-fp8kv=True-attention_dp=True-cuda_graph=True-overlap_scheduler=True-torch_compile=False]
- accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_nvfp4_4gpus[moe_backend=TRTLLM-mtp_nextn=2-ep4-fp8kv=True-attention_dp=True-cuda_graph=True-overlap_scheduler=True-torch_compile=False]
- accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_nvfp4_4gpus_online_eplb[fp8kv=True]
- accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_bfloat16_4gpus_online_eplb[mtp_nextn=2]
- accuracy/test_llm_api_pytorch.py::TestQwen3_235B_A22B::test_nvfp4_4gpus[latency_moe_trtllm_eagle3] TIMEOUT (90)

View File

@ -102,6 +102,8 @@ l0_h100:
- test_e2e.py::test_trtllm_bench_request_rate_and_concurrency[enable_concurrency-enable_request_rate] # negative test
- test_e2e.py::test_trtllm_bench_help_sanity[meta-llama/Llama-3.1-8B]
- test_e2e.py::test_ptp_quickstart_multimodal[gemma-3-27b-it-gemma/gemma-3-27b-it-image-True]
# ------------- AutoDeploy tests ---------------
- accuracy/test_llm_api_autodeploy.py::TestLlama3_1_8B::test_auto_dtype
- condition:
ranges:
system_gpu_count:

View File

@ -269,8 +269,6 @@ test_e2e.py::test_ptp_quickstart_multimodal[llava-v1.6-mistral-7b-llava-v1.6-mis
accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype[mtp_nextn=0-overlap_scheduler=True] SKIP (https://nvbugs/5433545)
examples/test_nemotron_nas.py::test_nemotron_nas_summary_1gpu[DeciLM-7B] SKIP (https://nvbugs/5444636)
accuracy/test_cli_flow.py::TestLongAlpaca7B::test_multiblock_aggressive SKIP (https://nvbugs/5444627)
accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_bfloat16_4gpus_online_eplb[mtp_nextn=2] SKIP (https://nvbugs/5444687)
accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_nvfp4_4gpus_online_eplb[fp8kv=True] SKIP (https://nvbugs/5444687)
examples/test_qwen2audio.py::test_llm_qwen2audio_single_gpu[qwen2_audio_7b_instruct] SKIP (https://nvbugs/5447530)
examples/test_nemotron_nas.py::test_nemotron_nas_summary_2gpu[DeciLM-7B] SKIP (https://nvbugs/5444636)
examples/test_multimodal.py::test_llm_multimodal_general[Qwen2-VL-7B-Instruct-pp:1-tp:1-float16-bs:1-cpp_e2e:False-nb:4] SKIP (https://nvbugs/5453709)
@ -325,3 +323,4 @@ full:L40S/accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_
full:L40S/accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_eagle3[eagle3_one_model=False-overlap_scheduler=False] SKIP (https://nvbugs/5471106)
full:L40S/accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_tp_pp_symmetric[MMLU-tp2pp2] SKIP (https://nvbugs/5471108)
test_e2e.py::test_multi_nodes_eval[llama4-models/nvidia/Llama-4-Maverick-17B-128E-Instruct-FP8-tp8pp2-mmlu] SKIP (https://nvbugs/5473781)
accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_nvfp4_4gpus[moe_backend=CUTLASS-mtp_nextn=0-tp4-fp8kv=True-attention_dp=True-cuda_graph=True-overlap_scheduler=True-torch_compile=True] SKIP (https://nvbugs/5476580)
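Every waive entry above follows the same <test id> SKIP (<bug URL>) pattern, optionally prefixed with a scope such as full:L40S/. A tiny, hypothetical parser for lines in this format (not the project's actual tooling) just splits on the SKIP keyword:
def parse_waive_line(line):
    # Hypothetical helper: split '<test id> SKIP (<reason>)' waive entries.
    line = line.strip()
    if not line or line.startswith("#"):
        return None
    test_id, sep, rest = line.partition(" SKIP ")
    if not sep:
        return None
    reason = rest.strip().lstrip("(").rstrip(")")
    return test_id, reason

# Example: parse_waive_line("accuracy/test_cli_flow.py::TestLongAlpaca7B::test_multiblock_aggressive SKIP (https://nvbugs/5444627)")
# returns ("accuracy/test_cli_flow.py::TestLongAlpaca7B::test_multiblock_aggressive", "https://nvbugs/5444627")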

View File

@ -52,10 +52,6 @@ class TestMoeAlltoAllSingleGPU(unittest.TestCase):
vector_dim,
dtype=dtype,
device=torch.device('cuda'))
output_tensor = torch.zeros(output_entry_count,
vector_dim,
dtype=dtype,
device=torch.device('cuda'))
send_cumsum = torch.ones(
(1, ), dtype=torch.int32,
@ -78,13 +74,18 @@ class TestMoeAlltoAllSingleGPU(unittest.TestCase):
workspace_size = torch.ops.trtllm.get_moe_commworkspace_size_per_rank(1)
all_workspaces = torch.zeros(1,
workspace_size,
workspace_size // 8,
dtype=torch.uint64,
device=torch.device('cuda'))
torch.ops.trtllm.moe_initialize_workspace(all_workspaces, 0, 1)
torch.ops.trtllm.moe_comm(input_tensor, send_cumsum, send_indices,
output_tensor, recv_cumsum, recv_indices,
all_workspaces, 0, 1)
output_tensors = torch.ops.trtllm.moe_comm([input_tensor], send_cumsum,
send_indices, recv_cumsum,
recv_indices, all_workspaces,
output_entry_count, 0, 1,
[True])
output_tensor = output_tensors[0]
torch.testing.assert_close(output_tensor,
ref_output_tensor,
@ -103,40 +104,43 @@ class TestMoeAlltoAllSingleGPU(unittest.TestCase):
send_indices = torch.zeros(1,
dtype=torch.int32,
device=torch.device('cuda'))
output_tensor = torch.zeros(1,
8,
dtype=torch.float16,
device=torch.device('cuda'))
recv_cumsum = torch.ones(1,
dtype=torch.int32,
device=torch.device('cuda'))
recv_indices = torch.zeros(1,
dtype=torch.int32,
device=torch.device('cuda'))
input_tensors = [input_tensor]
workspace_size = torch.ops.trtllm.get_moe_commworkspace_size_per_rank(1)
all_workspaces = torch.zeros(1,
workspace_size,
workspace_size // 8,
dtype=torch.uint64,
device=torch.device('cuda'))
torch.ops.trtllm.moe_comm(input_tensor, send_cumsum, send_indices,
output_tensor, recv_cumsum, recv_indices,
all_workspaces, 0, 1)
_ = torch.ops.trtllm.moe_comm(input_tensors, send_cumsum, send_indices,
recv_cumsum, recv_indices, all_workspaces,
1, 0, 1, [True])
torch.cuda.synchronize()
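Two workspace details change in the calls above: the shared workspace is now allocated as uint64, so the per-rank size returned by get_moe_commworkspace_size_per_rank is divided by 8 (presumably a byte count, 8 bytes per uint64 element), and each rank must call moe_initialize_workspace once before its first moe_comm. A minimal single-rank setup sketch, using only the ops that appear in this diff:
import torch
import tensorrt_llm  # imported as tllm elsewhere in this test; registers the torch.ops.trtllm ops

world_size = 1
# Presumably a size in bytes, hence // 8 to get the number of uint64 elements per rank.
workspace_size = torch.ops.trtllm.get_moe_commworkspace_size_per_rank(world_size)
all_workspaces = torch.zeros(world_size,
                             workspace_size // 8,
                             dtype=torch.uint64,
                             device=torch.device('cuda'))
# Each rank initializes its own row of the shared workspace before moe_comm is called.
for rank in range(world_size):
    torch.ops.trtllm.moe_initialize_workspace(all_workspaces, rank, world_size)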
@parameterized.expand([
(2, 5, 8, torch.float16), # small input as smoke test
(2, 1, 8, torch.float16), # some ranks have no data to send/recv
(4, 5, 8, torch.float16), # small input with larger world size
(4, 901, 32768, torch.bfloat16), # large input that reuses workspace
(8, 901, 32768,
(2, 5, [4, 4], torch.float16), # small input as smoke test
(2, 1, [8], torch.float16), # some ranks have no data to send/recv
(4, 5, [8], torch.float16), # small input with larger world size
(4, 901, [1472, 46, 4,
4], torch.float16), # large input that reuses workspace
(4, 5, [2944], torch.bfloat16), # large input that reuses workspace
(8, 901, [
32768,
],
torch.float16), # large input that reuses workspace, larger world size
(
8, 16384, 128, torch.float16
8, 16384, [
128,
], torch.float16
), # large input count with small vector dim that requires more indices per fifo
])
def test_moe_alltoall_multi_rank_single_gpu(self, world_size,
input_entry_per_rank,
vector_dim, dtype):
vector_dims, dtype):
torch.cuda.set_device(0)
max_world_size = 8
assert world_size <= max_world_size, f"should run with world_size at most {max_world_size}"
@ -148,27 +152,32 @@ class TestMoeAlltoAllSingleGPU(unittest.TestCase):
torch.ops.trtllm.set_moe_max_usable_sm_count(max_sm_count)
has_setup_max_sm_count = True
# Create a random input tensor
input_tensor = torch.randn(input_entry_per_rank * world_size,
vector_dim,
dtype=dtype,
device=torch.device('cuda'))
output_tensor = torch.zeros(input_entry_per_rank * world_size,
vector_dim,
dtype=dtype,
device=torch.device('cuda'))
ref_output_tensor = torch.zeros(input_entry_per_rank * world_size,
vector_dim,
dtype=dtype,
device=torch.device('cuda'))
tensor_count = len(vector_dims)
input_tensors = []
ref_output_tensors = []
for vector_dim in vector_dims:
input_tensors.append(
torch.randn(input_entry_per_rank * world_size,
vector_dim,
dtype=dtype,
device=torch.device('cuda')))
ref_output_tensors.append(
torch.zeros(input_entry_per_rank * world_size,
vector_dim,
dtype=dtype,
device=torch.device('cuda')))
target_rank_ids = torch.randint(0,
world_size,
(input_entry_per_rank * world_size, ),
dtype=torch.int32,
device=torch.device('cuda'))
input_tensors_all_ranks = list(
torch.split(input_tensor, input_entry_per_rank))
input_tensors_all_ranks = []
for i in range(tensor_count):
input_tensors_all_ranks.append(
list(torch.split(input_tensors[i], input_entry_per_rank)))
target_rank_ids_all_ranks = list(
torch.split(target_rank_ids, input_entry_per_rank))
@ -210,12 +219,9 @@ class TestMoeAlltoAllSingleGPU(unittest.TestCase):
recv_ids_all_ranks = []
recv_cumsum_all_ranks = []
output_tensors_all_ranks = []
total_recv_all_ranks_cpu = []
output_indice_offset = 0
output_start_current_rank = 0
# each rank computes, based on the other ranks' send counts, how it should receive data from them.
for rank in range(world_size):
local_recv_counts = torch.zeros(world_size,
@ -227,18 +233,15 @@ class TestMoeAlltoAllSingleGPU(unittest.TestCase):
local_recv_count_pair = local_recv_counts[other_rank].cpu(
).item()
send_rank_start_end = send_start_end_all_ranks[other_rank][rank]
ref_output_tensor[output_indice_offset:output_indice_offset + local_recv_count_pair] = \
input_tensors_all_ranks[other_rank][send_ids_all_ranks[other_rank][send_rank_start_end[0]:send_rank_start_end[1]]]
for i in range(tensor_count):
ref_output_tensors[i][output_indice_offset:output_indice_offset + local_recv_count_pair] = \
input_tensors_all_ranks[i][other_rank][send_ids_all_ranks[other_rank][send_rank_start_end[0]:send_rank_start_end[1]]]
output_indice_offset += local_recv_count_pair
local_recv_cumsum = torch.cumsum(local_recv_counts,
dim=0).to(torch.int32)
recv_cumsum_all_ranks.append(local_recv_cumsum)
total_recv_count = local_recv_cumsum[-1].cpu()
total_recv_all_ranks_cpu.append(total_recv_count)
output_tensors_all_ranks.append(output_tensor[
output_start_current_rank:output_start_current_rank +
total_recv_count])
output_start_current_rank += total_recv_count
local_recv_ids = torch.arange(total_recv_count,
dtype=torch.int32,
device=torch.device('cuda'))
@ -251,9 +254,12 @@ class TestMoeAlltoAllSingleGPU(unittest.TestCase):
workspace_size = torch.ops.trtllm.get_moe_commworkspace_size_per_rank(
world_size)
all_workspaces = torch.zeros(world_size,
workspace_size,
workspace_size // 8,
dtype=torch.uint64,
device=torch.device('cuda'))
for i in range(world_size):
torch.ops.trtllm.moe_initialize_workspace(all_workspaces, i,
world_size)
# do one warmup for each rank to avoid possible synchronization at first launch.
for rank in range(world_size):
@ -262,212 +268,141 @@ class TestMoeAlltoAllSingleGPU(unittest.TestCase):
torch.cuda.synchronize()
# Store output tensors from each rank
output_tensors_all_ranks = []
# do alltoall in parallel
for rank in range(world_size):
input_tensors_this_rank = [
input_tensors_all_ranks[i][rank] for i in range(tensor_count)
]
with torch.cuda.stream(cuda_streams_all_ranks[rank]):
torch.ops.trtllm.moe_comm(
input_tensors_all_ranks[rank], send_cumsum_all_ranks[rank],
send_ids_all_ranks[rank], output_tensors_all_ranks[rank],
recv_cumsum_all_ranks[rank], recv_ids_all_ranks[rank],
all_workspaces, rank, world_size)
output_tensors_this_rank = torch.ops.trtllm.moe_comm(
input_tensors_this_rank, send_cumsum_all_ranks[rank],
send_ids_all_ranks[rank], recv_cumsum_all_ranks[rank],
recv_ids_all_ranks[rank], all_workspaces,
input_entry_per_rank * world_size, rank, world_size)
output_tensors_all_ranks.append(output_tensors_this_rank)
for rank in range(world_size):
cuda_streams_all_ranks[rank].synchronize()
torch.testing.assert_close(output_tensor,
ref_output_tensor,
atol=1e-5,
rtol=1e-5)
# Reconstruct the full output tensors by concatenating results from all ranks
for i in range(tensor_count):
# Collect the actual received data from each rank (trim to actual recv count)
actual_output_parts = []
for rank in range(world_size):
total_recv_count = total_recv_all_ranks_cpu[rank].item()
# Each rank returns a tensor of shape [input_entry_per_rank * world_size, vector_dim],
# but only the first total_recv_count entries are valid
actual_output_parts.append(
output_tensors_all_ranks[rank][i][:total_recv_count])
@parameterized.expand([
(0, 8, 256, 4, 3, False),
(0, 8, 256, 4, 3, True),
(1, 8, 256, 4, 3, False),
(1, 8, 256, 4, 3, True),
(1, 4, 256, 8, 3, False),
(1, 4, 256, 8, 3, True),
(7, 8, 256, 8, 1025, False),
(7, 8, 256, 8, 1025, True),
(7, 64, 1024, 32, 1029, False),
(7, 64, 1024, 32, 1029, True),
])
def test_moe_alltoall_prepare_indices(
self, ep_rank: int, ep_size: int, expert_count: int, top_k: int,
max_token_count_per_rank: int,
use_real_rank_token_count_cumsum: bool):
# Concatenate all ranks' outputs to form the complete result
actual_output = torch.cat(actual_output_parts, dim=0)
torch.testing.assert_close(actual_output,
ref_output_tensors[i],
atol=1e-5,
rtol=1e-5)
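The updated calls above show the new moe_comm interface: it takes a list of payload tensors, no longer takes pre-allocated outputs, and instead returns a list of output tensors sized by an explicit output_entry_count argument, with rank/world_size and (in some calls) a trailing per-tensor flag list. The following self-contained single-rank round trip is a sketch assuming the signature exactly as used in these tests; the identity send/recv indices and the trailing [True] mirror the single-GPU test earlier in this file:
import torch
import tensorrt_llm  # registers torch.ops.trtllm (imported as tllm elsewhere in this file)

device = torch.device('cuda')
entry_count, vector_dim = 4, 8
input_tensor = torch.randn(entry_count, vector_dim, dtype=torch.float16, device=device)

# Rank 0 sends every entry to itself with identity indices.
send_cumsum = torch.tensor([entry_count], dtype=torch.int32, device=device)
recv_cumsum = torch.tensor([entry_count], dtype=torch.int32, device=device)
send_indices = torch.arange(entry_count, dtype=torch.int32, device=device)
recv_indices = torch.arange(entry_count, dtype=torch.int32, device=device)

workspace_size = torch.ops.trtllm.get_moe_commworkspace_size_per_rank(1)
all_workspaces = torch.zeros(1, workspace_size // 8, dtype=torch.uint64, device=device)
torch.ops.trtllm.moe_initialize_workspace(all_workspaces, 0, 1)

# (inputs, send_cumsum, send_indices, recv_cumsum, recv_indices, workspaces,
#  output_entry_count, ep_rank, ep_size, per-tensor flags)
output_tensors = torch.ops.trtllm.moe_comm([input_tensor], send_cumsum, send_indices,
                                           recv_cumsum, recv_indices, all_workspaces,
                                           entry_count, 0, 1, [True])
torch.cuda.synchronize()
# With identity indices on a single rank, the payload is expected to round-trip unchanged.
torch.testing.assert_close(output_tensors[0], input_tensor, atol=1e-5, rtol=1e-5)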
class TestMoeAlltoAllFP8SingleGPU(unittest.TestCase):
def setUp(self):
torch.manual_seed(0x1234)
tllm.logger.set_level('error')
def test_moe_alltoall_fp8_with_indices(self):
"""Test fp8 alltoall with properly constructed indices"""
torch.cuda.set_device(0)
gathered_target_rank_ids = torch.randint(
0,
ep_size, (ep_size * max_token_count_per_rank, top_k),
dtype=torch.int32,
device=torch.device('cuda'))
real_rank_token_count_cumsum = None
if use_real_rank_token_count_cumsum:
real_rank_token_count_cumsum = torch.randint(
0,
max_token_count_per_rank + 1, (ep_size, ),
dtype=torch.int32,
device=torch.device('cuda'))
real_rank_token_count_cumsum = torch.cumsum(
real_rank_token_count_cumsum, dim=0).to(torch.int32)
def generate_references():
gathered_target_rank_ids_cpu_lists = gathered_target_rank_ids.cpu(
).tolist()
if use_real_rank_token_count_cumsum:
real_rank_token_count_cumsum_cpu_lists = real_rank_token_count_cumsum.cpu(
).tolist()
else:
real_rank_token_count_cumsum_cpu_lists = [
(i + 1) * max_token_count_per_rank for i in range(ep_size)
]
rank_token_start = 0
ref_local_gather_indices_cpu_lists = []
ref_recv_rank_count_cumsum_cpu_lists = [0] * ep_size
ref_recv_rank_local_indices_cpu_lists = []
ref_send_rank_count_cumsum_cpu_lists = [0] * ep_size
ref_send_rank_local_indices_cpu_lists = []
ref_backward_recv_rank_local_indices_cpu_lists = []
total_recv_count = 0
for rank in range(ep_size):
rank_token_end = real_rank_token_count_cumsum_cpu_lists[rank]
for token_id in range(rank_token_start, rank_token_end):
if ep_rank in gathered_target_rank_ids_cpu_lists[token_id]:
ref_local_gather_indices_cpu_lists.append(token_id)
ref_recv_rank_local_indices_cpu_lists.append(
total_recv_count)
total_recv_count += 1
ref_recv_rank_count_cumsum_cpu_lists[rank] = total_recv_count
if rank == ep_rank:
total_send_count = 0
for target_rank in range(ep_size):
for token_id in range(rank_token_start, rank_token_end):
local_token_id = token_id - rank_token_start
if target_rank in gathered_target_rank_ids_cpu_lists[
token_id]:
pos = gathered_target_rank_ids_cpu_lists[
token_id].index(target_rank)
ref_send_rank_local_indices_cpu_lists.append(
local_token_id)
ref_backward_recv_rank_local_indices_cpu_lists.append(
local_token_id * top_k + pos)
total_send_count += 1
ref_send_rank_count_cumsum_cpu_lists[
target_rank] = total_send_count
rank_token_start = rank_token_end
ref_local_gather_indices = torch.IntTensor(
ref_local_gather_indices_cpu_lists).cuda()
ref_send_rank_count_cumsum = torch.IntTensor(
ref_send_rank_count_cumsum_cpu_lists).cuda()
ref_send_rank_local_indices = torch.IntTensor(
ref_send_rank_local_indices_cpu_lists).cuda()
ref_recv_rank_count_cumsum = torch.IntTensor(
ref_recv_rank_count_cumsum_cpu_lists).cuda()
ref_recv_rank_local_indices = torch.IntTensor(
ref_recv_rank_local_indices_cpu_lists).cuda()
ref_backward_recv_rank_local_indices = torch.IntTensor(
ref_backward_recv_rank_local_indices_cpu_lists).cuda()
return ref_local_gather_indices, ref_send_rank_count_cumsum, ref_send_rank_local_indices, ref_recv_rank_count_cumsum, ref_recv_rank_local_indices, ref_backward_recv_rank_local_indices
# Match dimensions from the error
input_entry_count = 16384
output_entry_count = 16384
vector_dim = 2944
sf_vector_dim = 92 # Scaling factor dimension from error
send_recv_count = 1000 # Number of entries to send/receive
ref_local_gather_indices, ref_send_rank_count_cumsum, ref_send_rank_local_indices, ref_recv_rank_count_cumsum, ref_recv_rank_local_indices, ref_backward_recv_rank_local_indices = generate_references(
)
# Create input tensors - first as float16, then convert
input_tensor_fp16 = torch.randn(input_entry_count,
vector_dim,
dtype=torch.float16,
device='cuda')
input_tensor_fp8 = input_tensor_fp16.to(torch.float8_e4m3fn)
local_gather_indices, send_rank_count_cumsum, send_rank_local_indices, recv_rank_count_cumsum, recv_rank_local_indices, backward_recv_rank_local_indices = \
torch.ops.trtllm.moe_comm_prepare_indices(gathered_target_rank_ids, real_rank_token_count_cumsum, max_token_count_per_rank, expert_count, top_k, ep_rank, ep_size)
# Scaling factor tensor
input_sf_tensor = torch.randint(1,
255, (input_entry_count, sf_vector_dim),
dtype=torch.uint8,
device='cuda')
assert torch.equal(
local_gather_indices[:torch.numel(ref_local_gather_indices)],
ref_local_gather_indices)
assert torch.equal(
send_rank_count_cumsum[:torch.numel(ref_send_rank_count_cumsum)],
ref_send_rank_count_cumsum)
assert torch.equal(
send_rank_local_indices[:torch.numel(ref_send_rank_local_indices)],
ref_send_rank_local_indices)
assert torch.equal(
recv_rank_count_cumsum[:torch.numel(ref_recv_rank_count_cumsum)],
ref_recv_rank_count_cumsum)
assert torch.equal(
recv_rank_local_indices[:torch.numel(ref_recv_rank_local_indices)],
ref_recv_rank_local_indices)
assert torch.equal(
backward_recv_rank_local_indices[:torch.numel(
ref_backward_recv_rank_local_indices)],
ref_backward_recv_rank_local_indices)
# Expert selection tensors
input_experts = torch.randint(0,
64, (input_entry_count, 4),
dtype=torch.int32,
device='cuda')
input_scales = torch.rand(input_entry_count,
4,
dtype=torch.float32,
device='cuda')
@parameterized.expand([
(0, 8, 256, 4, 3),
(1, 8, 256, 4, 3),
(7, 8, 256, 4, 3),
(7, 8, 256, 8, 32),
(7, 8, 256, 32, 10),
(7, 8, 1024, 32, 127),
(7, 64, 1024, 32, 1029),
(9, 64, 1024, 3, 1029),
])
def test_moe_local_gather(self, ep_rank: int, ep_size: int,
expert_count: int, top_k: int,
max_token_count_per_rank: int):
torch.cuda.set_device(0)
rank_token_count_cumsum = torch.randint(0,
max_token_count_per_rank + 1,
(ep_size, ),
dtype=torch.int32,
device=torch.device('cuda'))
rank_token_count_cumsum = torch.cumsum(rank_token_count_cumsum,
dim=0).to(torch.int32)
local_token_count = rank_token_count_cumsum[ep_size - 1].cpu().item()
local_max_token_count = max_token_count_per_rank * ep_size
local_gather_indices = torch.randint(0,
max_token_count_per_rank * ep_size,
(local_max_token_count, ),
dtype=torch.int32,
device=torch.device('cuda'))
# Construct send/recv indices
send_cumsum = torch.tensor([send_recv_count],
dtype=torch.int32,
device='cuda')
recv_cumsum = torch.tensor([send_recv_count],
dtype=torch.int32,
device='cuda')
gathered_expert_ids = torch.randint(
0,
expert_count, (max_token_count_per_rank * ep_size, top_k),
dtype=torch.int32,
device=torch.device('cuda'))
gathered_scales = torch.rand(
(max_token_count_per_rank * ep_size, top_k),
dtype=torch.float32,
device=torch.device('cuda'))
# Random indices for sending
send_indices = torch.randperm(input_entry_count,
dtype=torch.int32,
device='cuda')[:send_recv_count]
recv_indices = torch.randperm(output_entry_count,
dtype=torch.int32,
device='cuda')[:send_recv_count]
ref_local_expert_ids = torch.zeros(local_max_token_count,
top_k,
dtype=torch.int32,
device=torch.device('cuda'))
ref_local_scales = torch.zeros(local_max_token_count,
top_k,
dtype=torch.float32,
device=torch.device('cuda'))
# Create workspace
workspace_size = torch.ops.trtllm.get_moe_commworkspace_size_per_rank(1)
all_workspaces = torch.zeros(1,
workspace_size // 8,
dtype=torch.uint64,
device='cuda')
torch.ops.trtllm.moe_initialize_workspace(all_workspaces, 0, 1)
# compute reference
ref_local_expert_ids += expert_count
valid_local_gather_indices = local_gather_indices[:local_token_count]
ref_local_expert_ids[:local_token_count] = gathered_expert_ids[
valid_local_gather_indices]
ref_local_scales[:local_token_count] = gathered_scales[
valid_local_gather_indices]
print(f"Test configuration:")
print(f" Input entries: {input_entry_count}")
print(f" Vector dim: {vector_dim}")
print(f" SF vector dim: {sf_vector_dim}")
print(f" Send/recv count: {send_recv_count}")
print(f" FP8 tensor shape: {input_tensor_fp8.shape}")
print(f" SF tensor shape: {input_sf_tensor.shape}")
local_expert_ids = torch.empty(local_max_token_count,
top_k,
dtype=torch.int32,
device=torch.device('cuda'))
local_scales = torch.empty(local_max_token_count,
top_k,
dtype=torch.float32,
device=torch.device('cuda'))
try:
# Test with all 4 tensors
output_tensor_fp8, output_sf_tensor, output_experts, output_scales = \
torch.ops.trtllm.moe_comm([
input_tensor_fp8, input_sf_tensor, input_experts, input_scales
], send_cumsum, send_indices, recv_cumsum, recv_indices, all_workspaces, output_entry_count, 0, 1)
torch.ops.trtllm.moe_local_gather(rank_token_count_cumsum,
local_gather_indices,
gathered_expert_ids, gathered_scales,
local_expert_ids, local_scales,
max_token_count_per_rank,
expert_count, top_k, ep_rank, ep_size)
torch.cuda.synchronize()
print("FP8 alltoall test PASSED!")
assert torch.equal(local_expert_ids, ref_local_expert_ids)
assert torch.equal(local_scales, ref_local_scales)
# Verify outputs
print(f"\nOutput verification:")
print(f" Output FP8 shape: {output_tensor_fp8.shape}")
print(f" Output SF shape: {output_sf_tensor.shape}")
print(
f" Non-zero FP8 elements: {(output_tensor_fp8 != 0).sum().item()}"
)
print(
f" Non-zero SF elements: {(output_sf_tensor != 0).sum().item()}"
)
except Exception as e:
print(f"FP8 alltoall test FAILED: {e}")
print(f"Error type: {type(e)}")
raise
@parameterized.expand([
(0, 2, 16, 20, 8, 512),
@ -489,7 +424,6 @@ class TestMoeAlltoAllSingleGPU(unittest.TestCase):
cpu_expert_ids_all_ranks_lists = []
cpu_token_count_lists = []
cpu_scales_all_ranks_lists = []
for _ in range(ep_size):
token_count = torch.randint(max_token_count_per_rank // 2,
max_token_count_per_rank + 1, (1, ),
@ -505,12 +439,6 @@ class TestMoeAlltoAllSingleGPU(unittest.TestCase):
dtype=torch.int32,
device=torch.device('cpu')))
cpu_scales_all_ranks_lists.append(
torch.zeros(token_count,
top_k,
dtype=torch.float32,
device=torch.device('cpu')) + 0.5)
cpu_token_count_lists.append(token_count)
def compute_target_rank(expert_id):
@ -519,7 +447,6 @@ class TestMoeAlltoAllSingleGPU(unittest.TestCase):
def generate_references():
ref_prepared_local_expert_ids = []
ref_prepared_local_scales = []
ref_local_send_rank_count_cumsum = [0] * ep_size
ref_local_recv_rank_count_cumsum = [0] * ep_size
ref_local_recv_rank_indices = []
@ -580,16 +507,13 @@ class TestMoeAlltoAllSingleGPU(unittest.TestCase):
for pos in range(top_k):
expert_id = int(
cpu_expert_ids_all_ranks_lists[rank][token_id][pos])
sf = cpu_scales_all_ranks_lists[rank][token_id][pos]
target_rank_id = compute_target_rank(expert_id)
if target_rank_id == ep_rank:
if not token_is_received:
token_is_received = True
ref_prepared_local_expert_ids.append(
[slot_count] * top_k)
ref_prepared_local_scales.append([0.0] * top_k)
ref_prepared_local_expert_ids[-1][pos] = expert_id
ref_prepared_local_scales[-1][pos] = sf
if token_is_received:
ref_local_recv_rank_indices.append(
total_recv_token_count)
@ -599,9 +523,9 @@ class TestMoeAlltoAllSingleGPU(unittest.TestCase):
rank] = current_recv_token_count if rank == 0 else ref_local_recv_rank_count_cumsum[
rank - 1] + current_recv_token_count
return ref_prepared_local_expert_ids, ref_prepared_local_scales, ref_local_send_rank_count_cumsum, ref_local_send_rank_indices, ref_local_recv_rank_count_cumsum, ref_local_recv_rank_indices, ref_local_backward_send_rank_indices, total_recv_token_count
return ref_prepared_local_expert_ids, ref_local_send_rank_count_cumsum, ref_local_send_rank_indices, ref_local_recv_rank_count_cumsum, ref_local_recv_rank_indices, ref_local_backward_send_rank_indices, total_recv_token_count
ref_prepared_local_expert_ids, ref_prepared_local_scales, ref_local_send_rank_count_cumsum, ref_local_send_rank_indices, ref_local_recv_rank_count_cumsum, ref_local_recv_rank_indices, ref_local_backward_send_rank_indices, total_recv_token_count = generate_references(
ref_prepared_local_expert_ids, ref_local_send_rank_count_cumsum, ref_local_send_rank_indices, ref_local_recv_rank_count_cumsum, ref_local_recv_rank_indices, ref_local_backward_send_rank_indices, total_recv_token_count = generate_references(
)
cpu_experter_count_lists = []
@ -615,10 +539,6 @@ class TestMoeAlltoAllSingleGPU(unittest.TestCase):
expert_ids_all_ranks = [
cpu_expert_ids_all_ranks_lists[i].cuda() for i in range(ep_size)
]
#scales_all_ranks = torch.FloatTensor(cpu_scales_all_ranks_lists).cuda()
scales_all_ranks = [
cpu_scales_all_ranks_lists[i].cuda() for i in range(ep_size)
]
experter_count_lists = [
cpu_experter_count_lists[i].cuda() for i in range(ep_size)
@ -637,30 +557,18 @@ class TestMoeAlltoAllSingleGPU(unittest.TestCase):
stream = torch.cuda.Stream()
with torch.cuda.stream(stream):
torch.ops.trtllm.mnnvl_moe_alltoallv_prepare_without_allgather(
expert_ids_all_ranks[0], scales_all_ranks[0],
experter_count_lists[0], all_workspaces,
max_token_count_per_rank, 0, 1, expert_count, slot_count, top_k)
expert_ids_all_ranks[0], experter_count_lists[0],
all_workspaces, max_token_count_per_rank, 0, 1, expert_count,
slot_count, top_k)
stream.wait_stream(torch.cuda.current_stream())
# Let torch allocate the tensors up front to avoid a CUDA sync later.
prepared_local_experts = []
prepared_local_scales = []
local_send_rank_count_cumsum = []
local_send_rank_indices = []
local_recv_rank_count_cumsum = []
local_recv_rank_indices = []
backward_local_recv_rank_indices = []
for _ in range(ep_size):
prepared_local_experts.append(
torch.empty(max_token_count_per_rank * ep_size,
top_k,
dtype=torch.int32,
device=torch.device('cuda')))
prepared_local_scales.append(
torch.empty(max_token_count_per_rank * ep_size,
top_k,
dtype=torch.float32,
device=torch.device('cuda')))
local_send_rank_count_cumsum.append(
torch.empty(ep_size,
dtype=torch.int32,
@ -676,8 +584,6 @@ class TestMoeAlltoAllSingleGPU(unittest.TestCase):
backward_local_recv_rank_indices.append(
torch.empty(0, dtype=torch.int32, device=torch.device('cuda')))
prepared_local_experts = []
prepared_local_scales = []
local_send_rank_count_cumsum = []
local_send_rank_indices = []
local_recv_rank_count_cumsum = []
@ -694,35 +600,19 @@ class TestMoeAlltoAllSingleGPU(unittest.TestCase):
for rank in range(ep_size):
with torch.cuda.stream(cuda_streams_all_ranks[rank]):
if rank == ep_rank:
prepared_local_experts, prepared_local_scales, local_send_rank_count_cumsum, \
local_send_rank_count_cumsum, \
local_send_rank_indices, local_recv_rank_count_cumsum, local_recv_rank_indices, \
backward_local_recv_rank_indices, gathered_expert_statics\
= torch.ops.trtllm.mnnvl_moe_alltoallv_prepare_without_allgather(expert_ids_all_ranks[rank], scales_all_ranks[rank], experter_count_lists[rank], all_workspaces, max_token_count_per_rank,
= torch.ops.trtllm.mnnvl_moe_alltoallv_prepare_without_allgather(expert_ids_all_ranks[rank], experter_count_lists[rank], all_workspaces, max_token_count_per_rank,
rank, ep_size, expert_count, slot_count, top_k)
else:
torch.ops.trtllm.mnnvl_moe_alltoallv_prepare_without_allgather(
expert_ids_all_ranks[rank], scales_all_ranks[rank],
experter_count_lists[rank], all_workspaces,
max_token_count_per_rank, rank, ep_size, expert_count,
slot_count, top_k)
expert_ids_all_ranks[rank], experter_count_lists[rank],
all_workspaces, max_token_count_per_rank, rank, ep_size,
expert_count, slot_count, top_k)
for rank in range(ep_size):
cuda_streams_all_ranks[rank].synchronize()
prepared_local_experts_cpu = prepared_local_experts[:
total_recv_token_count].cpu(
)
prepared_local_scales_cpu = prepared_local_scales[:
total_recv_token_count].cpu(
)
for i in range(total_recv_token_count):
for j in range(top_k):
expert_id = int(prepared_local_experts_cpu[i][j])
assert 0 <= expert_id and expert_id <= slot_count
if expert_id < slot_count:
assert compute_target_rank(expert_id) == ep_rank
scale = float(prepared_local_scales_cpu[i][j])
assert scale > 1e-6
gathered_expert_statics_cpu = gathered_expert_statics.cpu()
for rank in range(ep_size):
for i in range(expert_count):