Mirror of https://github.com/NVIDIA/TensorRT-LLM.git, synced 2026-01-14 06:27:45 +08:00

Merge commit '31979aefacbf80d2742c98ef30385db162788c84' into feat/b300_cu13
Signed-off-by: Xiwen Yu <13230610+VALLIS-NERIA@users.noreply.github.com>
This commit is contained in: ab7febd4d8

1372  cpp/tensorrt_llm/kernels/fusedMoeCommKernels.cu  (new file; diff suppressed because it is too large)
 562  cpp/tensorrt_llm/kernels/fusedMoeCommKernels.h   (new file)

@@ -0,0 +1,562 @@
/*
 * Copyright (c) 2019-2025, NVIDIA CORPORATION. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#pragma once

#include <map>

#include <cuda_runtime_api.h>

#include "tensorrt_llm/common/cudaUtils.h"
#include "tensorrt_llm/kernels/moeCommKernelsCommon.h"

namespace tensorrt_llm
{
namespace kernels
{

struct ALIGN_256 SenderSideFifoInfo
{
    volatile uint64_t head; // write position
    volatile uint64_t tail; // read position
};

struct ALIGN_256 ReceiverSideFifoInfo
{
    volatile uint64_t head; // write position, do we use this?
    volatile uint64_t tail; // read position
};

// struct holding Send/Recv data pointer and its displacement information.
struct SendRecvIndices
{
    int const* rankCountCumSum; // length = epSize
    int* rankLocalIndices;      // length = rankCountCumSum[epRank] - rankCountCumSum[epRank - 1] if epRank > 0 else
                                // rankCountCumSum[epRank]

#ifdef __CUDACC__
    __inline__ __device__ int getCount(int rank) const
    {
        return rank == 0 ? rankCountCumSum[rank] : rankCountCumSum[rank] - rankCountCumSum[rank - 1];
    }

    __inline__ __device__ int getRankStart(int rank) const
    {
        return rank == 0 ? 0 : rankCountCumSum[rank - 1];
    }

    __inline__ __device__ int* getGroupStart(int rank, int& tokenCount) const
    {
        tokenCount = getCount(rank);
        int rankStart = getRankStart(rank);
        return rankLocalIndices + rankStart;
    }
#endif
};

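/*
 * Editor's note (illustrative sketch, not part of the original header): how the cumulative-sum
 * indexing in SendRecvIndices recovers per-rank counts and offsets. The values are hypothetical.
 *
 *   // epSize == 4, rankCountCumSum == {3, 5, 9, 10}
 *   // getCount(0)     == 3              (rank 0 owns entries [0, 3))
 *   // getCount(2)     == 9 - 5 == 4     (rank 2 owns entries [5, 9))
 *   // getRankStart(2) == 5, so getGroupStart(2, n) returns rankLocalIndices + 5 and sets n = 4.
 */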
struct MoeCommFieldInfo
{
    uint8_t* dataPtrBase;
    uint8_t alignedUnitBit;     // 0, 1, 2, 3, 4 (for 1, 2, 4, 8, 16 bytes), smallest aligned unit.
    uint16_t alignedUnitCount;  // data count in aligned units
    uint16_t alignedUnitStride; // data stride in aligned units

    uint8_t unalignedFieldIndex; // the index of the unaligned field, non-decreasing with field index
    uint16_t compact16BOffset;   // aligned to 16 bytes, offset is a count of 16-byte blocks

    static constexpr uint64_t kAlign16BytePtrMask = (1ULL << 4) - 1;
    static constexpr uint32_t kAligned16BMask = (1 << 4) - 1;

    // Constants for memory alignment and access
    static constexpr int BYTES_PER_128B_BLOCK = 128;
    static constexpr int INTS_PER_128B_BLOCK = BYTES_PER_128B_BLOCK / sizeof(int);
    static constexpr int UINT64_PER_128B_BLOCK = BYTES_PER_128B_BLOCK / sizeof(uint64_t);
    static constexpr int BYTES_PER_16B_BLOCK = 16;
    // One extra 16-byte block is padded for each unaligned field, since its head and tail 16 bytes
    // might not be aligned.

    // Fill single field info; the fields that need global info are not filled here.
    __host__ void fillFieldInfo(uint8_t* dataPtr, size_t elementSize, int vectorSize, int stride);

    __host__ void setUnused()
    {
        dataPtrBase = nullptr;
        alignedUnitBit = 4;
        alignedUnitCount = 0;
        alignedUnitStride = 0;
        unalignedFieldIndex = 0;
        compact16BOffset = 0;
    }

    template <typename T>
    __host__ void fillFieldInfo(T* dataPtr, int vectorSize, int stride)
    {
        size_t elementSize = sizeof(T);
        fillFieldInfo(reinterpret_cast<uint8_t*>(dataPtr), elementSize, vectorSize, stride);
    }

    __device__ __host__ __forceinline__ int getFieldUncompactSize() const
    {
        int alignedUnitBytes = 1 << alignedUnitBit;
        int currentFieldSize = alignedUnitCount * alignedUnitBytes;
        if (alignedUnitBytes != 16)
        {
            constexpr int alignedUnitBytes = BYTES_PER_16B_BLOCK;
            currentFieldSize = currentFieldSize / alignedUnitBytes * alignedUnitBytes;
            currentFieldSize += alignedUnitBytes * 2;
        }
        return currentFieldSize;
    }

    __device__ __host__ __forceinline__ int getFieldCompactSize() const
    {
        int alignedUnitBytes = 1 << alignedUnitBit;
        int currentFieldSize = alignedUnitCount * alignedUnitBytes;
        // Align to 16 bytes for compact size
        return (currentFieldSize + BYTES_PER_16B_BLOCK - 1) / BYTES_PER_16B_BLOCK * BYTES_PER_16B_BLOCK;
    }

    __device__ __forceinline__ int getCompactShmOffset() const
    {
        return compact16BOffset * BYTES_PER_16B_BLOCK;
    }

    __device__ __forceinline__ int getUncompactShmOffset() const
    {
        // each unaligned field needs a 16-byte head and a 16-byte tail
        return compact16BOffset * BYTES_PER_16B_BLOCK + unalignedFieldIndex * BYTES_PER_16B_BLOCK;
    }

    __device__ __forceinline__ int getMemmoveOffsets(int index) const
    {
        int alignedBytes = 1 << alignedUnitBit;
        uint8_t* dataPtr = dataPtrBase + index * alignedBytes * alignedUnitStride;
        int offset = reinterpret_cast<uint64_t>(dataPtr) & kAlign16BytePtrMask;
        return offset + unalignedFieldIndex * BYTES_PER_16B_BLOCK;
    }

    __device__ __forceinline__ uint8_t* getRawPtr(int index, int* rawSize) const
    {
        int alignedBytes = 1 << alignedUnitBit;
        uint8_t* dataPtr = dataPtrBase + static_cast<size_t>(index) * alignedBytes * alignedUnitStride;
        if (rawSize != nullptr)
        {
            *rawSize = alignedUnitCount * alignedBytes;
        }
        return dataPtr;
    }

    __device__ __forceinline__ uint8_t* get16BAlignedLoadCopyRange(int index, int* copyByteCount) const
    {
        int rawSize;
        uint8_t* rawDataPtr = getRawPtr(index, &rawSize);
        uint8_t* rawEndPtr = rawDataPtr + rawSize;
        uint8_t* alignedDataPtr
            = reinterpret_cast<uint8_t*>(reinterpret_cast<uint64_t>(rawDataPtr) & (~kAlign16BytePtrMask));
        uint32_t copySize = rawEndPtr - alignedDataPtr;
        *copyByteCount
            = (copySize & kAligned16BMask) != 0 ? (copySize & (~kAligned16BMask)) + BYTES_PER_16B_BLOCK : copySize;
        return alignedDataPtr;
    }

    __device__ __forceinline__ uint8_t* get16BAlignedStoreCopyRange(
        int index, int* copyByteCount, int laneId, int* headTailShmIdx, int* headTailGlobalIdx) const
    {
        int rawSize;
        uint8_t* rawDataPtr = getRawPtr(index, &rawSize);
        uint8_t* rawEndPtr = rawDataPtr + rawSize;
        int offset = reinterpret_cast<uint64_t>(rawDataPtr) & kAlign16BytePtrMask;
        uint8_t* alignedDataPtr
            = reinterpret_cast<uint8_t*>(reinterpret_cast<uint64_t>(rawDataPtr) + BYTES_PER_16B_BLOCK - offset);
        uint8_t* alignedEndPtr
            = reinterpret_cast<uint8_t*>(reinterpret_cast<uint64_t>(rawEndPtr) & (~kAlign16BytePtrMask));
        int alignedCopyBytes = alignedEndPtr - alignedDataPtr;
        if (alignedCopyBytes < 0)
        {
            alignedCopyBytes = 0;
        }
        *copyByteCount = alignedCopyBytes;

        if (laneId < BYTES_PER_16B_BLOCK)
        {
            *headTailShmIdx = laneId;
        }
        else
        {
            *headTailShmIdx = laneId + alignedCopyBytes;
        }
        *headTailGlobalIdx = *headTailShmIdx - offset;
        if (*headTailGlobalIdx < 0 || *headTailGlobalIdx >= rawSize)
        {
            *headTailGlobalIdx = -1;
            *headTailShmIdx = -1;
        }
        return alignedDataPtr;
    }
};

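/*
 * Editor's note (illustrative sketch, not part of the original header): compact vs. uncompact
 * sizes for a hypothetical field of 100 half-precision elements, i.e. alignedUnitBit == 1 and
 * alignedUnitCount == 100, so 200 raw bytes:
 *
 *   getFieldCompactSize()   == (200 + 15) / 16 * 16   == 208 bytes
 *   getFieldUncompactSize() == 200 / 16 * 16 + 2 * 16 == 224 bytes
 *
 * The extra two 16-byte blocks in the uncompact size are the head/tail padding for a field whose
 * smallest aligned unit is less than 16 bytes.
 */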
// Maximum number of fields supported, not counting tokenSelectedExpert and expertScales
static constexpr int MOE_COMM_FIELD_MAX_COUNT = 8;

struct MoeSingleCommMeta
{
    int singleTransferAlignedSize;  // transfer size aligned to 128 bytes.
    int singleCompactAlignedSize;   // compact buffer is always aligned to 128 bytes
    int singleUncompactAlignedSize; // uncompact shared memory size, aligned to 128 bytes, might be larger than the
                                    // compact buffer if unaligned fields exist.

    // TODO: Do we need to reduce shared memory usage, make it smaller, and enable multiple waves?

    __device__ __host__ __forceinline__ int getTransfer128ByteCount() const
    {
        return singleTransferAlignedSize / MoeCommFieldInfo::BYTES_PER_128B_BLOCK;
    }

    __device__ __host__ __forceinline__ int getCompactData128ByteCount() const
    {
        return singleCompactAlignedSize / MoeCommFieldInfo::BYTES_PER_128B_BLOCK;
    }

    __device__ __host__ __forceinline__ int getSingleShmSize() const
    {
        return std::max(singleUncompactAlignedSize, singleTransferAlignedSize);
    }
};

struct FusedMoeWorldInfo
{
    MoeEpWorldInfo epInfo;
};

struct FusedMoePairInfo
{
    int senderRank;
    int receiverRank;
    int channel;
    int runChannelCount;
};

class FusedMoeCommunicator
{
public:
    static constexpr int FIFO_DEPTH = 4;
    static constexpr int FIFO_ENTRY_BYTES = 256 * 1024;
    static constexpr int FIFO_ENTRY_128_BYTE_COUNT = FIFO_ENTRY_BYTES / 128;
    static constexpr int FIFO_TOTAL_BYTES = FIFO_ENTRY_BYTES * FIFO_DEPTH;
    static constexpr int FIFO_TOTAL_U64 = FIFO_TOTAL_BYTES / sizeof(uint64_t);
    static constexpr int MAX_GROUP_COUNT_PER_BLOCK = 8;

    static constexpr int WARP_SIZE = 32;

    static int maxSmCount;
    static bool maxSmCountUsed;

    static void setMaxUsableSmCount(int maxUsableSmCount)
    {
        TLLM_CHECK_WITH_INFO(
            FusedMoeCommunicator::maxSmCountUsed == false, "setMaxUsableSmCount can be called only before it is used");
        int smCount = tensorrt_llm::common::getMultiProcessorCount();
        if (maxUsableSmCount > smCount)
        {
            TLLM_LOG_WARNING("setMaxUsableSmCount, maxUsableSmCount=%d, larger than smCount=%d, using smCount instead",
                maxUsableSmCount, smCount);
            maxUsableSmCount = smCount;
        }
        FusedMoeCommunicator::maxSmCount = maxUsableSmCount;
    }

    static int getMaxUsableSmCount()
    {
        FusedMoeCommunicator::maxSmCountUsed = true;
        if (FusedMoeCommunicator::maxSmCount == -1)
        {
            int smCount = tensorrt_llm::common::getMultiProcessorCount();
            FusedMoeCommunicator::maxSmCount = smCount;
        }
        return FusedMoeCommunicator::maxSmCount;
    }

    static int computeMoeCommChannelCount(int epSize)
    {
        int smCount = getMaxUsableSmCount();
        int blockCountPerChannel = (epSize + MAX_GROUP_COUNT_PER_BLOCK - 1) / MAX_GROUP_COUNT_PER_BLOCK;
        blockCountPerChannel *= 2; // for send and recv
        TLLM_CHECK_WITH_INFO(
            blockCountPerChannel <= smCount, "GPU should support at least one channel, usableSmCount=%d", smCount);
        int preferredChannel = smCount / 2 / blockCountPerChannel; // use half of the SMs for communication
        int channelCount = std::max(preferredChannel, 1);          // at least one channel
        return channelCount;
    }

    static int getMoeCommChannelCount(int epSize)
    {
        static std::map<int, int> channelCountMap{};
        auto iter = channelCountMap.find(epSize);
        if (iter == channelCountMap.end())
        {
            auto channelCount = FusedMoeCommunicator::computeMoeCommChannelCount(epSize);
            channelCountMap[epSize] = channelCount;
            return channelCount;
        }
        return iter->second;
    }

    static dim3 getLaunchBlockDim(int groupCountPerCta)
    {
        return dim3(WARP_SIZE, groupCountPerCta);
    }

    static dim3 getLaunchGridDim(int epSize, int groupCountPerCta)
    {
        int maxChannelCount = FusedMoeCommunicator::getMoeCommChannelCount(epSize);
        int targetCtaCount = (epSize + MAX_GROUP_COUNT_PER_BLOCK - 1) / MAX_GROUP_COUNT_PER_BLOCK * maxChannelCount * 2;
        int ctaPerChannel = (epSize + groupCountPerCta - 1) / groupCountPerCta;
        int ctaLimitedChannelCount = targetCtaCount / 2 / ctaPerChannel;
        ctaLimitedChannelCount = std::max(1, ctaLimitedChannelCount);
        int channelCount = std::min(ctaLimitedChannelCount, maxChannelCount);
        return dim3(ctaPerChannel, channelCount, 2);
    }
};
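/*
 * Editor's note (illustrative sketch, not part of the original header): the launch-shape math
 * above, for a hypothetical epSize == 16 on a GPU with 132 usable SMs.
 *
 *   computeMoeCommChannelCount(16):
 *     blockCountPerChannel = ceil(16 / 8) * 2 = 4   // send blocks + recv blocks
 *     preferredChannel     = 132 / 2 / 4     = 16   // half of the SMs for communication
 *     -> 16 channels
 *
 *   getLaunchGridDim(16, 8):                        // groupCountPerCta == 8
 *     ctaPerChannel = ceil(16 / 8) = 2
 *     -> dim3(2, 16, 2), i.e. 64 CTAs; getLaunchBlockDim(8) gives dim3(32, 8).
 */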

size_t getFusedMoeCommWorkspaceSize(int epSize);

struct FusedMoeFieldInfo
{
    int8_t isBasicInterleaved; // whether tokenSelectedSlots and expertScales are stored interleaved
    int32_t* tokenSelectedSlots;
    float* expertScales; // can be nullptr if no scale is used (all 1.0); if so, isBasicInterleaved should be 0
    int fieldCount;
    MoeCommFieldInfo fieldsInfo[MOE_COMM_FIELD_MAX_COUNT];

    __host__ int computeSingleCompactSize(int topK, bool hasScales, bool hasBasicFields) const
    {
        int basicFieldSize = 0;
        if (hasBasicFields)
        {
            basicFieldSize = topK * sizeof(int) + (hasScales ? topK * sizeof(float) : 0);
            // align to 16 bytes
            basicFieldSize = (basicFieldSize + MoeCommFieldInfo::BYTES_PER_16B_BLOCK - 1)
                / MoeCommFieldInfo::BYTES_PER_16B_BLOCK * MoeCommFieldInfo::BYTES_PER_16B_BLOCK;
        }
        int otherFieldSize = 0;
        for (int i = 0; i < fieldCount; i++)
        {
            MoeCommFieldInfo const& fieldInfo = fieldsInfo[i];
            otherFieldSize += fieldInfo.getFieldCompactSize();
        }
        int totalSize = basicFieldSize + otherFieldSize;
        constexpr int totalSizeAlignment = MoeCommFieldInfo::BYTES_PER_128B_BLOCK;
        totalSize = (totalSize + totalSizeAlignment - 1) / totalSizeAlignment * totalSizeAlignment;
        return totalSize;
    }

    __host__ int computeSingleUncompactSize(int topK, bool hasScales, bool hasBasicFields) const
    {
        int basicFieldSize = 0;
        if (hasBasicFields)
        {
            basicFieldSize = topK * sizeof(int) + (hasScales ? topK * sizeof(float) : 0);
            // align to 16 bytes
            basicFieldSize = (basicFieldSize + MoeCommFieldInfo::BYTES_PER_16B_BLOCK - 1)
                / MoeCommFieldInfo::BYTES_PER_16B_BLOCK * MoeCommFieldInfo::BYTES_PER_16B_BLOCK;
        }
        int otherFieldSize = 0;
        for (int i = 0; i < fieldCount; i++)
        {
            MoeCommFieldInfo const& fieldInfo = fieldsInfo[i];
            otherFieldSize += fieldInfo.getFieldUncompactSize();
        }
        int totalSize = basicFieldSize + otherFieldSize;
        constexpr int totalSizeAlignment = MoeCommFieldInfo::BYTES_PER_128B_BLOCK;
        totalSize = (totalSize + totalSizeAlignment - 1) / totalSizeAlignment * totalSizeAlignment;
        return totalSize;
    }

    template <typename T = int, bool IS_SLOTS = true>
    __device__ __forceinline__ T* getBasicFieldPtr(int tokenIndex, int selectedIndex, int topK) const
    {
        T* fieldPtr = nullptr;
        fieldPtr = IS_SLOTS ? reinterpret_cast<T*>(tokenSelectedSlots) : reinterpret_cast<T*>(expertScales);
        if (fieldPtr == nullptr || selectedIndex >= topK)
        {
            return nullptr;
        }
        int tokenStride = isBasicInterleaved ? topK * 2 : topK;
        int elementStride = isBasicInterleaved ? 2 : 1;
        return fieldPtr + tokenIndex * tokenStride + selectedIndex * elementStride;
    }

    __device__ __forceinline__ int* getTokenSelectedSlotsPtr(int tokenIndex, int selectedIndex, int topK) const
    {
        return getBasicFieldPtr<int, true>(tokenIndex, selectedIndex, topK);
    }

    __device__ __forceinline__ float* getScalePtr(int tokenIndex, int selectedIndex, int topK) const
    {
        return getBasicFieldPtr<float, false>(tokenIndex, selectedIndex, topK);
    }

    void fillMetaInfo(MoeSingleCommMeta* singleCommMeta, int topK, bool hasScales, bool hasBasicFields) const;

    void fillFieldPlacementInfo(int topK, bool hasBasicFields);
};
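/*
 * Editor's note (illustrative sketch, not part of the original header): a hypothetical host-side
 * setup of the fields that one rank communicates. All names and sizes below are made up for the
 * example.
 *
 *   FusedMoeFieldInfo info{};
 *   info.isBasicInterleaved = 0;
 *   info.tokenSelectedSlots = dTokenSelectedSlots;  // [tokenCount, topK] int32
 *   info.expertScales       = dExpertScales;        // [tokenCount, topK] float, or nullptr
 *   info.fieldCount         = 1;
 *   info.fieldsInfo[0].fillFieldInfo(dHiddenStates, hiddenSize, hiddenSize); // half* payload
 *   info.fillFieldPlacementInfo(topK, true);
 *
 *   MoeSingleCommMeta meta{};
 *   info.fillMetaInfo(&meta, topK, dExpertScales != nullptr, true);
 */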

struct FusedMoeCommKernelParam
{
    FusedMoeWorldInfo worldInfo;
    MoeExpertParallelInfo expertParallelInfo; // expertCount inside should be slotCount if using redundant experts.
    MoeSingleCommMeta sendCommMeta;
    MoeSingleCommMeta recvCommMeta;
    SendRecvIndices sendIndices;
    SendRecvIndices recvIndices;
    FusedMoeFieldInfo sendFieldInfo;
    FusedMoeFieldInfo recvFieldInfo;
};

/*
 * Workspace Layout:
 * Ri: Rank i
 * N: Number of GPUs, e.g. EpSize or WorldSize, n = N - 1
 * Ci: Channel i
 * M: Number of Channels, m = M - 1
 * MMr: Memory Mapped from Rank r, physically located at rank r, and mapped to all ranks.
 *
 * Whole workspace memory space:
 * ---------------------------------------------------------------------------------------------------
 * |<--   MM0   -->|<--   MM1   -->|<--   MM2   -->| ......        |<--   MMn   -->|
 * ^               ^               ^               ^               ^               ^
 * 0        rankStrideInU64  2*rankStrideInU64  3*rankStrideInU64  n*rankStrideInU64  N*rankStrideInU64
 *
 * For each MMr, the layout is:
 * -------------------------------------------------------------------------------------------------
 * |<--- FIFO memory --->|<--- SenderSideFifoInfo memory --->|<--- ReceiverSideFifoInfo memory --->|
 * -------------------------------------------------------------------------------------------------
 *
 * The FIFO memory is physically placed at the receiver rank.
 * To find the FIFO whose receiver is rank r, look in the FIFO memory of MMr.
 * The layout of the FIFO memory of each MMr is (here the rank index is the sender rank):
 * -------------------------------------------------------------------------------------------------
 * | R0C0 | R0C1 | .... | R0Cm | R1C0 | R1C1 | .... | R1Cm | .... .... | RnC0 | RnC1 | .... | RnCm |
 * |<- Channels for Rank 0  ->|<- Channels for Rank 1  ->|             |<- Channels for Rank n  ->|
 * -------------------------------------------------------------------------------------------------
 * Each R*C* has a length of FIFO_TOTAL_U64 in uint64_t, which is internally divided into FIFO_DEPTH
 * entries of size FIFO_ENTRY_BYTES each.
 *
 * The SenderSideFifoInfo memory is physically placed at the sender rank.
 * To find the SenderSideFifoInfo whose sender is rank r, look in the SenderSideFifoInfo memory of MMr.
 * The layout of the SenderSideFifoInfo memory of each MMr is (here the rank index is the receiver rank):
 * -------------------------------------------------------------------------------------------------
 * | R0C0 | R0C1 | .... | R0Cm | R1C0 | R1C1 | .... | R1Cm | .... .... | RnC0 | RnC1 | .... | RnCm |
 * |<- Channels for Rank 0  ->|<- Channels for Rank 1  ->|             |<- Channels for Rank n  ->|
 * -------------------------------------------------------------------------------------------------
 * Each R*C* is one SenderSideFifoInfo struct. There are M * N SenderSideFifoInfo structs in total in each MMr.
 *
 * The ReceiverSideFifoInfo memory is physically placed at the receiver rank.
 * To find the ReceiverSideFifoInfo whose receiver is rank r, look in the ReceiverSideFifoInfo memory of MMr.
 * The layout of the ReceiverSideFifoInfo memory of each MMr is (here the rank index is the sender rank):
 * -------------------------------------------------------------------------------------------------
 * | R0C0 | R0C1 | .... | R0Cm | R1C0 | R1C1 | .... | R1Cm | .... .... | RnC0 | RnC1 | .... | RnCm |
 * |<- Channels for Rank 0  ->|<- Channels for Rank 1  ->|             |<- Channels for Rank n  ->|
 * -------------------------------------------------------------------------------------------------
 * Each R*C* is one ReceiverSideFifoInfo struct. There are M * N ReceiverSideFifoInfo structs in total in each MMr.
 */

struct FusedMoeWorkspace
{
    uint64_t* workspacePtr;
    size_t rankStrideInU64;
    int channelCount;

    template <bool isSenderSideBuffer>
    __device__ __forceinline__ uint8_t* commonGetPtrBase(
        FusedMoePairInfo const& pairInfo, size_t fieldOffset, int fieldSingleSize) const
    {
        int mappedMemoryRank = isSenderSideBuffer ? pairInfo.senderRank : pairInfo.receiverRank;
        int rankInsideMappedMemory = isSenderSideBuffer ? pairInfo.receiverRank : pairInfo.senderRank;
        auto* mappedMemory = reinterpret_cast<uint8_t*>(workspacePtr + mappedMemoryRank * rankStrideInU64);
        mappedMemory += fieldOffset;
        mappedMemory += rankInsideMappedMemory * channelCount * fieldSingleSize;
        mappedMemory += pairInfo.channel * fieldSingleSize;
        return mappedMemory;
    }

    __device__ __forceinline__ uint64_t* getFifoBasePtr(
        FusedMoeWorldInfo const& worldInfo, FusedMoePairInfo const& pairInfo) const
    {
        constexpr int fieldSingleSize = FusedMoeCommunicator::FIFO_TOTAL_BYTES;
        return reinterpret_cast<uint64_t*>(commonGetPtrBase<false>(pairInfo, 0, fieldSingleSize));
    }

    __device__ __forceinline__ SenderSideFifoInfo* getSenderSideFifoInfo(
        FusedMoeWorldInfo const& worldInfo, FusedMoePairInfo const& pairInfo) const
    {
        constexpr int fieldSingleSize = sizeof(SenderSideFifoInfo);
        size_t fieldOffset
            = static_cast<size_t>(FusedMoeCommunicator::FIFO_TOTAL_BYTES) * worldInfo.epInfo.epSize * channelCount;
        return reinterpret_cast<SenderSideFifoInfo*>(commonGetPtrBase<true>(pairInfo, fieldOffset, fieldSingleSize));
    }

    __device__ __forceinline__ ReceiverSideFifoInfo* getReceiverSideFifoInfo(
        FusedMoeWorldInfo const& worldInfo, FusedMoePairInfo const& pairInfo) const
    {
        constexpr int fieldSingleSize = sizeof(ReceiverSideFifoInfo);
        size_t fieldOffset
            = static_cast<size_t>(FusedMoeCommunicator::FIFO_TOTAL_BYTES) * worldInfo.epInfo.epSize * channelCount
            + sizeof(SenderSideFifoInfo) * worldInfo.epInfo.epSize * channelCount;
        return reinterpret_cast<ReceiverSideFifoInfo*>(commonGetPtrBase<false>(pairInfo, fieldOffset, fieldSingleSize));
    }

    static size_t computeWorkspaceSizePreRank(int epSize, int channelCount)
    {
        size_t fifoSize = static_cast<size_t>(FusedMoeCommunicator::FIFO_TOTAL_BYTES) * epSize * channelCount;
        size_t senderSideInfoSize = sizeof(SenderSideFifoInfo) * epSize * channelCount;
        size_t receiverSideInfoSize = sizeof(ReceiverSideFifoInfo) * epSize * channelCount;
        return fifoSize + senderSideInfoSize + receiverSideInfoSize;
    }

    void initializeLocalWorkspace(FusedMoeWorldInfo const& worldInfo);
};

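/*
 * Editor's note (illustrative sketch, not part of the original header): how the per-rank workspace
 * could be sized and wired up on the host. The allocator call is a placeholder; the real code maps
 * each rank's buffer into every peer's address space.
 *
 *   int epSize       = 8;                                          // hypothetical
 *   int channelCount = FusedMoeCommunicator::getMoeCommChannelCount(epSize);
 *   size_t perRank   = FusedMoeWorkspace::computeWorkspaceSizePreRank(epSize, channelCount);
 *   uint64_t* base   = allocateAndMapSymmetric(perRank * epSize);  // placeholder allocator
 *
 *   FusedMoeWorkspace workspace;
 *   constructWorkspace(&workspace, base, perRank / sizeof(uint64_t), epSize);
 *   initializeFusedMoeLocalWorkspace(&workspace, worldInfo);
 */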
void setMaxUsableSmCount(int smCount);

void moeAllToAll(FusedMoeCommKernelParam params, FusedMoeWorkspace workspace, cudaStream_t stream);

void constructWorkspace(FusedMoeWorkspace* workspace, uint64_t* workspacePtr, size_t rankStrideInU64, int epSize);

void initializeFusedMoeLocalWorkspace(FusedMoeWorkspace* workspace, FusedMoeWorldInfo const& worldInfo);

namespace fused_moe_comm_tests
{

// Functions for testing

void launchSingleG2S(FusedMoeFieldInfo const& sendFieldInfo, MoeExpertParallelInfo const& expertParallelInfo,
    int tokenCount, int* shmDump, int warpsPerBlock, bool hasBasicFields, cudaStream_t stream);

void launchSingleS2G(FusedMoeFieldInfo const& recvFieldInfo, MoeExpertParallelInfo const& expertParallelInfo,
    int tokenCount, int* shmPreload, int warpsPerBlock, bool hasBasicFields, cudaStream_t stream);

void launchLoopback(FusedMoeFieldInfo const& sendFieldInfo, FusedMoeFieldInfo const& recvFieldInfo,
    MoeExpertParallelInfo const& expertParallelInfo, int* recvIndexMapping, int tokenCount, int warpsPerBlock,
    bool hasBasicFields, cudaStream_t stream);

void launchLocalFifoSendRecv(FusedMoeFieldInfo const& sendFieldInfo, FusedMoeFieldInfo const& recvFieldInfo,
    MoeExpertParallelInfo const& expertParallelInfo, int* sendIndexMapping, int* recvIndexMapping,
    FusedMoeWorkspace fusedMoeWorkspace, int tokenCount, int warpsPerBlock, int blockChannelCount, bool hasBasicFields,
    cudaStream_t stream);

} // namespace fused_moe_comm_tests

} // namespace kernels
} // namespace tensorrt_llm
@@ -1,804 +0,0 @@
/*
 * Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "moeCommKernels.h"

#include <stdio.h>

#include <cooperative_groups.h>
#include <cub/cub.cuh>

namespace cg = cooperative_groups;

namespace tensorrt_llm::kernels
{

__device__ inline void barrier_sync(int name, int nThreads)
{
    asm volatile("barrier.sync.aligned %0, %1;" ::"r"(name), "r"(nThreads) : "memory");
}

inline __device__ void load128(uint64_t const* ptr, uint64_t& v0, uint64_t& v1)
{
    asm volatile("ld.volatile.global.v2.u64 {%0,%1}, [%2];" : "=l"(v0), "=l"(v1) : "l"(ptr) : "memory");
}

inline __device__ void store128(uint64_t* ptr, uint64_t v0, uint64_t v1)
{
    asm volatile("st.volatile.global.v2.u64 [%2], {%0,%1};" ::"l"(v0), "l"(v1), "l"(ptr) : "memory");
}
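/*
 * Editor's note (illustrative, not part of the original source): each 16-byte packet written by
 * store128() carries data in both 64-bit words except on every 8th lane (flagThread), where the
 * second word is replaced by the current step's flag. The receiver re-loads the packet until the
 * flag words match the expected flag, so data arrival is detected without a separate barrier.
 * Roughly:
 *
 *   // sender side
 *   store128(packetPtr, payloadLo, flagThread ? flag : payloadHi);
 *   // receiver side (the whole warp retries until every flag word is visible)
 *   do { load128(packetPtr, lo, hi); needReload = flagThread && (hi != flag); }
 *   while (__any_sync(WARP_MASK, needReload));
 */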

template <bool isSender>
class AllToAllChannelCommunicator : public AllToAllChannelCommunicatorBase
{
private:
    int const tid;          // thread index in primitives group
    int const nthreads;     // number of threads in primitives group
    int const wid;          // lane index in warp
    int const warp;         // warp index in primitives group
    const MoeEpWorldInfo worldInfo;
    const MoeCommWorkspace workspace;
    const SendRecvDataInfo sendRecvDataInfo;
    const SendRecvDispls dataDispls;
    int peerRank;           // peer rank index
    bool const flagThread;
    int const group;        // primitives group index
    int const channel;      // channel index
    int const channelCount; // count of channels

    MoeCommFifoConnInfo* fifoConnInfoPtr;
    uint64_t* fifoBasePtr; // pointer to fifo base address
    uint64_t step;
    uint64_t tailStepCache;
    uint64_t regs[U64_DATA_REG_PER_THREAD];
    GroupSharedBuffer* groupSharedBuffer;

    int groupStartIndice;
    int groupEndIndice;

    int sliceStartIndice;
    int sliceEndIndice;

    uint64_t* stepFifoEntryPtr;

public:
    __inline__ __device__ uint64_t getFlag()
    {
        return step + 1;
    }

    __inline__ __device__ AllToAllChannelCommunicator(MoeEpWorldInfo const& worldInfo, MoeCommWorkspace workspace,
        SendRecvDataInfo sendRecvDataInfo, SendRecvDispls dataDispls, GroupSharedBuffer* groupSharedBuffer,
        int channelCount)
        : worldInfo(worldInfo)
        , nthreads(blockDim.x)
        , tid(threadIdx.x)
        , workspace(workspace)
        , sendRecvDataInfo(sendRecvDataInfo)
        , dataDispls(dataDispls)
        , wid(threadIdx.x % WARP_SIZE)
        , warp(threadIdx.x / WARP_SIZE)
        , peerRank(blockIdx.x * GROUP_COUNT_PER_BLOCK + threadIdx.y)
        , group(threadIdx.y)
        , channel(blockIdx.y)
        , flagThread(threadIdx.x % 8 == 7)
        , fifoConnInfoPtr(nullptr)
        , fifoBasePtr(nullptr)
        , step(0)
        , tailStepCache(0)
        , groupSharedBuffer(groupSharedBuffer)
        , channelCount(channelCount)
    {
    }

    __inline__ __device__ void init()
    {
        fifoBasePtr = workspace.getFifoBasePtr(isSender, worldInfo.epRank, peerRank, channel, channelCount);
        fifoConnInfoPtr
            = workspace.getFifoConnInfo(isSender, worldInfo.epRank, peerRank, channel, worldInfo.epSize, channelCount);
        step = isSender ? fifoConnInfoPtr->head : fifoConnInfoPtr->tail;
        tailStepCache = isSender ? fifoConnInfoPtr->tail : 0;
    }

    __inline__ __device__ void computeGroupTransferRange()
    {
        if (tid == 0)
        {
            int rankCount = dataDispls.getCount(peerRank);
            int rankStart = dataDispls.getRankStart(peerRank);
            int countPerChannel = (rankCount + channelCount - 1) / channelCount;
            int groupEnd = min(rankStart + (channel + 1) * countPerChannel, rankStart + rankCount);
            int groupStart = min(rankStart + channel * countPerChannel, rankStart + rankCount);
            groupSharedBuffer->groupStartIndice = groupStart;
            groupSharedBuffer->groupEndIndice = groupEnd;
        }
        barrier();
        groupStartIndice = groupSharedBuffer->groupStartIndice;
        groupEndIndice = groupSharedBuffer->groupEndIndice;
    }

    __inline__ __device__ void loadTransferIndices()
    {
        sliceStartIndice = groupStartIndice;
        sliceEndIndice = min(groupStartIndice + sendRecvDataInfo.vectorCountPerFifoEntry, groupEndIndice);
        for (int i = groupStartIndice + tid; i < sliceEndIndice; i += WARP_SIZE * WARP_PER_GROUP)
        {
            groupSharedBuffer->groupIndiceBuffer[i - groupStartIndice] = dataDispls.getRealVectorIndice(i);
        }
        groupStartIndice = sliceEndIndice;
        barrier();
    }

    __inline__ __device__ void computeSlicePtr()
    {
        stepFifoEntryPtr = fifoBasePtr + RECV_FIFO_ENTRY_U64 * (step % RECV_FIFO_DEPTH);
    }

    __inline__ __device__ void sendSlice()
    {
        waitSend();
        int EltPer16B = 2;
        int eltN = sendRecvDataInfo.vectorSizeInU64;
        for (int vecId = warp + sliceStartIndice; vecId < sliceEndIndice; vecId += WARP_PER_GROUP)
        {
            int idxInSlice = vecId - sliceStartIndice;
            int vecRealIdx = groupSharedBuffer->groupIndiceBuffer[idxInSlice];
            uint64_t* src = dataDispls.getVectorDataPtr(vecRealIdx);
            uint64_t* slicePtr = stepFifoEntryPtr
                + idxInSlice * sendRecvDataInfo.dataPacketCountPerVector * PACKET_SIZE_IN_U64 + 2 * wid;
            for (int packetId = 0; packetId < sendRecvDataInfo.dataPacketCountPerVector; packetId++)
            {
                int vecOff = packetId * DATA_PAYLOAD_SIZE_PER_PACKET_IN_U64;
#pragma unroll
                for (int g = 0; g < U64_DATA_REG_PER_THREAD / 2; g++)
                {
                    int ix = g * WARP_SIZE - 4 * (g / 2) + wid - (g % 2) * (wid / 8);
                    __syncwarp();
                    if (!flagThread || g % 2 == 0)
                    {
                        if (ix * EltPer16B + vecOff < eltN)
                        {
                            load128((uint64_t*) (src + ix * EltPer16B + vecOff), regs[2 * g + 0], regs[2 * g + 1]);
                        }
                    }
                    __syncwarp();
                }
#pragma unroll
                for (int g = 1; g < U64_DATA_REG_PER_THREAD / 2; g += 2)
                {
                    if (flagThread)
                        regs[2 * g] = regs[2 * g - 1];
                }

                uint64_t flag = getFlag();
                uint64_t* packetPtr = slicePtr + packetId * PACKET_SIZE_IN_U64;
                __syncwarp();
#pragma unroll
                for (int u = 0; u < U64_DATA_REG_PER_THREAD; u += 2)
                {
                    store128(packetPtr + u * WARP_SIZE, regs[u], flagThread ? flag : regs[u + 1]);
                }
            }
        }
        updateSend();
    }

    __inline__ __device__ void recvSlice()
    {
        // the receiver does not need to wait, since we have the flag.
        int EltPer16B = 2;
        int eltN = sendRecvDataInfo.vectorSizeInU64;
        for (int vecId = warp + sliceStartIndice; vecId < sliceEndIndice; vecId += WARP_PER_GROUP)
        {
            int idxInSlice = vecId - sliceStartIndice;
            int vecRealIdx = groupSharedBuffer->groupIndiceBuffer[idxInSlice];

            uint64_t* dst = dataDispls.getVectorDataPtr(vecRealIdx);
            uint64_t* slicePtr = stepFifoEntryPtr
                + idxInSlice * sendRecvDataInfo.dataPacketCountPerVector * PACKET_SIZE_IN_U64 + 2 * wid;
            for (int packetId = 0; packetId < sendRecvDataInfo.dataPacketCountPerVector; packetId++)
            {
                uint64_t* packetPtr = slicePtr + packetId * PACKET_SIZE_IN_U64;
                int vecOff = packetId * DATA_PAYLOAD_SIZE_PER_PACKET_IN_U64;

                bool needReload;
                uint64_t flag = getFlag();
                __syncwarp();
                do
                {
                    needReload = false;
#pragma unroll
                    for (int u = 0; u < U64_DATA_REG_PER_THREAD; u += 2)
                    {
                        load128(packetPtr + u * WARP_SIZE, regs[u], regs[u + 1]);
                        needReload |= flagThread && (regs[u + 1] != flag);
                    }
                } while (__any_sync(WARP_MASK, needReload));
#pragma unroll
                for (int g = 1; g < U64_DATA_REG_PER_THREAD / 2; g += 2)
                {
                    if (flagThread)
                        regs[2 * g - 1] = regs[2 * g];
                }

#pragma unroll
                for (int g = 0; g < U64_DATA_REG_PER_THREAD / 2; g++)
                {
                    int ix = g * WARP_SIZE - 4 * (g / 2) + wid - (g % 2) * (wid / 8);
                    __syncwarp();
                    if (!flagThread || g % 2 == 0)
                    {
                        if (ix * EltPer16B + vecOff < eltN)
                        {
                            store128((uint64_t*) (dst + ix * EltPer16B + vecOff), regs[2 * g + 0], regs[2 * g + 1]);
                        }
                    }
                    __syncwarp();
                }
            }
        }
        updateRecv();
    }

    __inline__ __device__ void run()
    {
        if (peerRank >= worldInfo.epSize)
        {
            return;
        }
        init();
        computeGroupTransferRange();
        while (groupStartIndice < groupEndIndice)
        {
            loadTransferIndices();
            computeSlicePtr();
            if (isSender)
            {
                sendSlice();
            }
            else
            {
                recvSlice();
            }
        }
    }

    __inline__ __device__ ~AllToAllChannelCommunicator() {}

    __inline__ __device__ void barrier()
    {
        barrier_sync(15 - group, nthreads);
    }

    __inline__ __device__ void waitSend()
    {
        barrier();
        while (tailStepCache + RECV_FIFO_DEPTH < step + 1)
        {
            tailStepCache = fifoConnInfoPtr->tail;
        }
        barrier();
    }

    __inline__ __device__ void updateSend()
    {
        barrier();
        if (tid == 0)
        {
            atomicAdd_system((unsigned long long*) &fifoConnInfoPtr->head, 1);
        }
        barrier();
        step++;
    }

    __inline__ __device__ void updateRecv()
    {
        barrier();
        if (tid == 0)
        {
            atomicAdd_system((unsigned long long*) &fifoConnInfoPtr->tail, 1);
        }
        barrier();
        step++;
    }
};

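/*
 * Editor's note (illustrative, not part of the original source): the head/tail counters in
 * MoeCommFifoConnInfo implement a credit-based ring buffer of RECV_FIFO_DEPTH entries. A minimal
 * host-side analogue of the flow control done by waitSend()/updateSend()/updateRecv():
 *
 *   // sender: wait until a slot is free, then publish one step
 *   while (tail + RECV_FIFO_DEPTH < head + 1) { tail = conn->tail; } // spin on receiver progress
 *   writeEntry(fifo[head % RECV_FIFO_DEPTH]);
 *   conn->head = ++head;
 *
 *   // receiver: consume the entry, then return the credit
 *   readEntry(fifo[tail % RECV_FIFO_DEPTH]);
 *   conn->tail = ++tail;
 */
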
__global__ void moeAllToAllKernel(MoeEpWorldInfo worldInfo, MoeCommWorkspace workspace,
    SendRecvDataInfo sendRecvDataInfo, SendRecvDispls sendDispls, SendRecvDispls recvDispls)
{
    __shared__ AllToAllChannelCommunicatorBase::GroupSharedBuffer
        allGroupSharedBuffer[AllToAllChannelCommunicatorBase::GROUP_COUNT_PER_BLOCK];
    bool isSender = blockIdx.z == 0;
    int channelCount = gridDim.y;
    int group = threadIdx.y;
    SendRecvDispls dataDispls = isSender ? sendDispls : recvDispls;
    AllToAllChannelCommunicatorBase::GroupSharedBuffer* groupSharedBuffer = &allGroupSharedBuffer[group];
    if (isSender)
    {
        AllToAllChannelCommunicator<true> comm(
            worldInfo, workspace, sendRecvDataInfo, dataDispls, groupSharedBuffer, channelCount);
        comm.run();
    }
    else
    {
        AllToAllChannelCommunicator<false> comm(
            worldInfo, workspace, sendRecvDataInfo, dataDispls, groupSharedBuffer, channelCount);
        comm.run();
    }
}

void moeAllToAll(MoeEpWorldInfo worldInfo, SendRecvDataInfo sendRecvDataInfo, SendRecvDispls sendDispls,
    SendRecvDispls recvDispls, MoeCommWorkspace workspace, cudaStream_t stream)
{
    sendRecvDataInfo.DoPreCompute();
    TLLM_CHECK_WITH_INFO(
        reinterpret_cast<uintptr_t>(sendDispls.dataPtr) % 16 == 0, "sendDispls.dataPtr must be 16-byte aligned");
    TLLM_CHECK_WITH_INFO(
        reinterpret_cast<uintptr_t>(recvDispls.dataPtr) % 16 == 0, "recvDispls.dataPtr must be 16-byte aligned");
    dim3 block = AllToAllChannelCommunicatorBase::getLaunchBlockDim();
    dim3 grid = AllToAllChannelCommunicatorBase::getLaunchGridDim(worldInfo.epSize);
    moeAllToAllKernel<<<grid, block, 0, stream>>>(worldInfo, workspace, sendRecvDataInfo, sendDispls, recvDispls);
}

template <bool isSend, int kThreadsGroupSize>
__inline__ __device__ void computeSendRecvRankCountDevice(MoeEpWorldInfo worldInfo,
    MoeExpertParallelInfo expertParallelInfo, int maxTokenCountPerRank, int const* realRankTokenCountCumSum,
    int const* gatheredTargetRankIds, int* sharedSendRecvRankCount, int* sendRecvRankCount)
{
    cg::thread_block_tile<kThreadsGroupSize> tile = cg::tiled_partition<kThreadsGroupSize>(cg::this_thread_block());
    int laneInTile = tile.thread_rank();
    int tileId = threadIdx.x / kThreadsGroupSize;
    int tileCountPerBlock = blockDim.x / kThreadsGroupSize;

    int topK = expertParallelInfo.topK;
    int epRank = worldInfo.epRank;
    int epSize = worldInfo.epSize;

    if (threadIdx.x == 0)
    {
        *sharedSendRecvRankCount = 0;
    }

    __syncthreads();
    int readRank = isSend ? epRank : blockIdx.x;
    int compareRankId = isSend ? blockIdx.x : epRank;
    int const* readRankTargetRankIds = gatheredTargetRankIds + readRank * maxTokenCountPerRank * topK;
    int readRankTokenCount = maxTokenCountPerRank;
    if (realRankTokenCountCumSum != nullptr)
    {
        int readRankStart = readRank == 0 ? 0 : realRankTokenCountCumSum[readRank - 1];
        readRankTargetRankIds = gatheredTargetRankIds + readRankStart * topK;
        readRankTokenCount = realRankTokenCountCumSum[readRank] - readRankStart;
    }

    for (int i = tileId + blockIdx.z * tileCountPerBlock; i < readRankTokenCount; i += tileCountPerBlock * gridDim.z)
    {
        int targetRankId = laneInTile < topK ? readRankTargetRankIds[i * topK + laneInTile] : epSize;
        bool rankMatched = (targetRankId == compareRankId);
        bool hasRankMatched = tile.any(rankMatched);
        if (hasRankMatched && laneInTile == 0)
        {
            atomicAdd_block(sharedSendRecvRankCount, 1);
        }
        tile.sync();
    }
    __syncthreads();
    if (threadIdx.x == 0)
    {
        atomicAdd_system(sendRecvRankCount + blockIdx.x, *sharedSendRecvRankCount);
    }
}

template <int kThreadsGroupSize>
__global__ void computeSendRecvRankCountKernel(MoeEpWorldInfo worldInfo, MoeExpertParallelInfo expertParallelInfo,
    int maxTokenCountPerRank, int const* realRankTokenCountCumSum, int const* gatheredTargetRankIds, int* sendRankCount,
    int* recvRankCount)
{
    static_assert(kThreadsGroupSize == 1 || kThreadsGroupSize == 2 || kThreadsGroupSize == 4 || kThreadsGroupSize == 8
            || kThreadsGroupSize == 16 || kThreadsGroupSize == 32,
        "Only 1, 2, 4, 8, 16, 32 threads group size supported now.");
    __shared__ int sharedSendRecvRankCount;
    if (blockIdx.y == 0)
    {
        // compute send rank count
        computeSendRecvRankCountDevice<true, kThreadsGroupSize>(worldInfo, expertParallelInfo, maxTokenCountPerRank,
            realRankTokenCountCumSum, gatheredTargetRankIds, &sharedSendRecvRankCount, sendRankCount);
    }
    else
    {
        // compute recv rank count
        computeSendRecvRankCountDevice<false, kThreadsGroupSize>(worldInfo, expertParallelInfo, maxTokenCountPerRank,
            realRankTokenCountCumSum, gatheredTargetRankIds, &sharedSendRecvRankCount, recvRankCount);
    }
}

void computeSendRecvRankCount(MoeEpWorldInfo worldInfo, MoeExpertParallelInfo expertParallelInfo,
    int maxTokenCountPerRank, int const* realRankTokenCountCumSum, int const* gatheredTargetRankIds, int* sendRankCount,
    int* recvRankCount, cudaStream_t stream)
{
    TLLM_CHECK_WITH_INFO(expertParallelInfo.topK <= 32, "Only topK less than or equal to 32 supported now.");
    int threadsPerBlock = 1024;
    auto* kernelPtr = computeSendRecvRankCountKernel<32>;
    if (expertParallelInfo.topK <= 1)
    {
        kernelPtr = computeSendRecvRankCountKernel<1>;
    }
    else if (expertParallelInfo.topK <= 2)
    {
        kernelPtr = computeSendRecvRankCountKernel<2>;
    }
    else if (expertParallelInfo.topK <= 4)
    {
        kernelPtr = computeSendRecvRankCountKernel<4>;
    }
    else if (expertParallelInfo.topK <= 8)
    {
        kernelPtr = computeSendRecvRankCountKernel<8>;
    }
    else if (expertParallelInfo.topK <= 16)
    {
        kernelPtr = computeSendRecvRankCountKernel<16>;
    }
    dim3 block(worldInfo.epSize, 2, 1);
    kernelPtr<<<block, threadsPerBlock, 0, stream>>>(worldInfo, expertParallelInfo, maxTokenCountPerRank,
        realRankTokenCountCumSum, gatheredTargetRankIds, sendRankCount, recvRankCount);
}

template <int kThreadsPerBlock>
__global__ void inplaceSendRecvRankCumSumKernel(MoeEpWorldInfo worldInfo, int* sendRankCount, int* recvRankCount)
{
    int* inputOutputPtr = blockIdx.x == 0 ? sendRankCount : recvRankCount;
    typedef cub::BlockScan<int, kThreadsPerBlock> BlockScan;
    __shared__ typename BlockScan::TempStorage temp_storage;

    int tid = threadIdx.x;
    int threadData = tid < worldInfo.epSize ? inputOutputPtr[tid] : 0;

    BlockScan(temp_storage).InclusiveSum(threadData, threadData);
    if (tid < worldInfo.epSize)
    {
        inputOutputPtr[tid] = threadData;
    }
}

void inplaceSendRecvRankCumSum(MoeEpWorldInfo worldInfo, int* sendRankCount, int* recvRankCount, cudaStream_t stream)
{
    TLLM_CHECK_WITH_INFO(worldInfo.epSize <= 1024, "Only worldInfo.epSize less than or equal to 1024 supported now.");
    auto* kernelPtr = inplaceSendRecvRankCumSumKernel<1024>;
    int blockSize = 1024;
    if (worldInfo.epSize <= 32)
    {
        kernelPtr = inplaceSendRecvRankCumSumKernel<32>;
        blockSize = 32;
    }
    else if (worldInfo.epSize <= 64)
    {
        kernelPtr = inplaceSendRecvRankCumSumKernel<64>;
        blockSize = 64;
    }
    else if (worldInfo.epSize <= 128)
    {
        kernelPtr = inplaceSendRecvRankCumSumKernel<128>;
        blockSize = 128;
    }
    else if (worldInfo.epSize <= 256)
    {
        kernelPtr = inplaceSendRecvRankCumSumKernel<256>;
        blockSize = 256;
    }
    else if (worldInfo.epSize <= 512)
    {
        kernelPtr = inplaceSendRecvRankCumSumKernel<512>;
        blockSize = 512;
    }
    kernelPtr<<<2, blockSize, 0, stream>>>(worldInfo, sendRankCount, recvRankCount);
}
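/*
 * Editor's note (illustrative): the kernel above turns the per-rank counts into an inclusive
 * prefix sum in place, e.g. for a hypothetical epSize == 4:
 *
 *   sendRankCount before: {3, 0, 4, 1}
 *   sendRankCount after:  {3, 3, 7, 8}
 *
 * so rank r's entries occupy [cumSum[r - 1], cumSum[r]) with cumSum[-1] taken as 0, matching the
 * getCount()/getRankStart() helpers in the new header.
 */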

template <bool isSend, int kThreadsGroupSize, int kThreadsPerBlock>
__inline__ __device__ void computeSendRecvIndicesDevice(MoeEpWorldInfo worldInfo,
    MoeExpertParallelInfo expertParallelInfo, int maxTokenCountPerRank, int const* realRankTokenCountCumSum,
    int const* gatheredTargetRankIds, int const* sendRecvCumSum,
    int* sendRecvIndices,              // send or receive
    int* localGatherIndices,           // receive only
    int* backwardRecvRankLocalIndices, // send only
    int* sharedSendRecvRankStart, typename cub::BlockScan<int, kThreadsPerBlock>::TempStorage& tempStorage)
{
    cg::thread_block_tile<kThreadsGroupSize> tile = cg::tiled_partition<kThreadsGroupSize>(cg::this_thread_block());
    int laneInTile = tile.thread_rank();
    int tileId = threadIdx.x / kThreadsGroupSize;
    int tileCountPerBlock = blockDim.x / kThreadsGroupSize;

    int topK = expertParallelInfo.topK;
    int epRank = worldInfo.epRank;
    int epSize = worldInfo.epSize;

    if (threadIdx.x == 0)
    {
        *sharedSendRecvRankStart = blockIdx.x == 0 ? 0 : sendRecvCumSum[blockIdx.x - 1];
    }

    __syncthreads();
    int readRank = isSend ? epRank : blockIdx.x;
    int compareRankId = isSend ? blockIdx.x : epRank;
    int readRankStart = readRank * maxTokenCountPerRank;
    int const* readRankTargetRankIds = gatheredTargetRankIds + readRankStart * topK;
    int readRankTokenCount = maxTokenCountPerRank;
    if (realRankTokenCountCumSum != nullptr)
    {
        readRankStart = readRank == 0 ? 0 : realRankTokenCountCumSum[readRank - 1];
        readRankTargetRankIds = gatheredTargetRankIds + readRankStart * topK;
        readRankTokenCount = realRankTokenCountCumSum[readRank] - readRankStart;
    }

    for (int blockStartId = blockIdx.z * tileCountPerBlock; blockStartId < readRankTokenCount;
         blockStartId += tileCountPerBlock * gridDim.z)
    {
        int stepStartIndice = *sharedSendRecvRankStart;
        int i = blockStartId + tileId;
        int targetRankId
            = (laneInTile < topK && i < readRankTokenCount) ? readRankTargetRankIds[i * topK + laneInTile] : epSize;
        bool rankMatched = (targetRankId == compareRankId);
        bool hasRankMatched = tile.any(rankMatched);
        unsigned int laneMask = tile.ballot(rankMatched);
        int lowestLane = __ffs(laneMask) - 1;
        int isMatchedLane = (hasRankMatched && laneInTile == lowestLane) ? 1 : 0;
        int indice;
        typedef cub::BlockScan<int, kThreadsPerBlock> BlockScan;
        BlockScan(tempStorage).ExclusiveSum(isMatchedLane, indice);
        indice += stepStartIndice;
        __syncthreads();

        if (isMatchedLane == 1)
        {
            atomicAdd_block(sharedSendRecvRankStart, 1);
            if (isSend)
            {
                sendRecvIndices[indice] = i;
                backwardRecvRankLocalIndices[indice] = i * topK + lowestLane;
            }
            else
            {
                sendRecvIndices[indice] = indice;
                localGatherIndices[indice] = readRankStart + i;
            }
        }
        __syncthreads();
    }
}

template <int kThreadsGroupSize, int kThreadsPerBlock>
__global__ void computeSendRecvIndicesKernel(MoeEpWorldInfo worldInfo, MoeExpertParallelInfo expertParallelInfo,
    int maxTokenCountPerRank, int const* realRankTokenCountCumSum, int const* gatheredTargetRankIds,
    int const* sendRankCountCumSum, int const* recvRankCountCumSum, int* localGatherIndices, int* sendRankLocalIndices,
    int* recvRankLocalIndices, int* backwardRecvRankLocalIndices)
{
    static_assert(kThreadsGroupSize == 1 || kThreadsGroupSize == 2 || kThreadsGroupSize == 4 || kThreadsGroupSize == 8
            || kThreadsGroupSize == 16 || kThreadsGroupSize == 32,
        "Only 1, 2, 4, 8, 16, 32 threads group size supported now.");
    __shared__ int sharedSendRecvRankStart;
    __shared__ typename cub::BlockScan<int, kThreadsPerBlock>::TempStorage tempStorage;
    if (blockIdx.y == 0)
    {
        // compute send indices
        computeSendRecvIndicesDevice<true, kThreadsGroupSize, kThreadsPerBlock>(worldInfo, expertParallelInfo,
            maxTokenCountPerRank, realRankTokenCountCumSum, gatheredTargetRankIds, sendRankCountCumSum,
            sendRankLocalIndices, localGatherIndices, backwardRecvRankLocalIndices, &sharedSendRecvRankStart,
            tempStorage);
    }
    else
    {
        // compute recv indices
        computeSendRecvIndicesDevice<false, kThreadsGroupSize, kThreadsPerBlock>(worldInfo, expertParallelInfo,
            maxTokenCountPerRank, realRankTokenCountCumSum, gatheredTargetRankIds, recvRankCountCumSum,
            recvRankLocalIndices, localGatherIndices, backwardRecvRankLocalIndices, &sharedSendRecvRankStart,
            tempStorage);
    }
}

void computeSendRecvIndices(MoeEpWorldInfo worldInfo, MoeExpertParallelInfo expertParallelInfo,
    int maxTokenCountPerRank, int const* realRankTokenCountCumSum, int const* gatheredTargetRankIds,
    int const* sendRankCountCumSum, int const* recvRankCountCumSum, int* localGatherIndices, int* sendRankLocalIndices,
    int* recvRankLocalIndices, int* backwardRecvRankLocalIndices, cudaStream_t stream)
{
    TLLM_CHECK_WITH_INFO(expertParallelInfo.topK <= 32, "Only topK less than or equal to 32 supported now.");
    int threadsPerBlock = 1024;
    auto* kernelPtr = computeSendRecvIndicesKernel<32, 1024>;
    if (expertParallelInfo.topK <= 1)
    {
        kernelPtr = computeSendRecvIndicesKernel<1, 1024>;
    }
    else if (expertParallelInfo.topK <= 2)
    {
        kernelPtr = computeSendRecvIndicesKernel<2, 1024>;
    }
    else if (expertParallelInfo.topK <= 4)
    {
        kernelPtr = computeSendRecvIndicesKernel<4, 1024>;
    }
    else if (expertParallelInfo.topK <= 8)
    {
        kernelPtr = computeSendRecvIndicesKernel<8, 1024>;
    }
    else if (expertParallelInfo.topK <= 16)
    {
        kernelPtr = computeSendRecvIndicesKernel<16, 1024>;
    }
    else if (expertParallelInfo.topK <= 32)
    {
        kernelPtr = computeSendRecvIndicesKernel<32, 1024>;
    }
    dim3 block(worldInfo.epSize, 2, 1);
    kernelPtr<<<block, threadsPerBlock, 0, stream>>>(worldInfo, expertParallelInfo, maxTokenCountPerRank,
        realRankTokenCountCumSum, gatheredTargetRankIds, sendRankCountCumSum, recvRankCountCumSum, localGatherIndices,
        sendRankLocalIndices, recvRankLocalIndices, backwardRecvRankLocalIndices);
}

__global__ void moeAllToAllMemsetKernel(MoeEpWorldInfo worldInfo, MoeExpertParallelInfo expertParallelInfo,
    int maxTokenCountPerRank, int* sendRankCountCumSum, int* recvRankCountCumSum, int* localGatherIndices,
    int* sendRankLocalIndices, int* recvRankLocalIndices, int* backwardRecvRankLocalIndices)
{
    int maxSendRanksPerToken = std::max(worldInfo.epSize, expertParallelInfo.topK);
    int idx = threadIdx.x + blockIdx.x * blockDim.x;
    int maxRankRecvTokenCount = maxTokenCountPerRank * worldInfo.epSize;
    int maxRankSendTokenCount = maxTokenCountPerRank * maxSendRanksPerToken;
    if (idx < worldInfo.epSize)
    {
        sendRankCountCumSum[idx] = 0;
        recvRankCountCumSum[idx] = 0;
    }
    if (idx < maxRankRecvTokenCount)
    {
        localGatherIndices[idx] = -1;
        recvRankLocalIndices[idx] = -1;
    }
    if (idx < maxRankSendTokenCount)
    {
        sendRankLocalIndices[idx] = -1;
        backwardRecvRankLocalIndices[idx] = -1;
    }
}

void moeAllToAllMemset(MoeEpWorldInfo worldInfo, MoeExpertParallelInfo expertParallelInfo, int maxTokenCountPerRank,
    int* sendRankCountCumSum, int* recvRankCountCumSum, int* localGatherIndices, int* sendRankLocalIndices,
    int* recvRankLocalIndices, int* backwardRecvRankLocalIndices, cudaStream_t stream)
{
    int maxSendRanksPerToken = std::max(worldInfo.epSize, expertParallelInfo.topK);
    int maxRankRecvTokenCount = maxTokenCountPerRank * worldInfo.epSize;
    int maxRankSendTokenCount = maxTokenCountPerRank * maxSendRanksPerToken;
    int maxEltCount = std::max<int>(maxRankRecvTokenCount, maxRankSendTokenCount);
    maxEltCount = std::max<int>(maxEltCount, worldInfo.epSize);
    static constexpr int kBlockSize = 256;
    int blockCount = (maxEltCount + kBlockSize - 1) / kBlockSize;
    dim3 grid(blockCount, 1);
    moeAllToAllMemsetKernel<<<grid, kBlockSize, 0, stream>>>(worldInfo, expertParallelInfo, maxTokenCountPerRank,
        sendRankCountCumSum, recvRankCountCumSum, localGatherIndices, sendRankLocalIndices, recvRankLocalIndices,
        backwardRecvRankLocalIndices);
}

void moeAllToAllPrepareIndices(MoeEpWorldInfo worldInfo, MoeExpertParallelInfo expertParallelInfo,
    int maxTokenCountPerRank, int const* gatheredTargetRankIds, int const* realRankTokenCountCumSum,
    // indices of gatheredTargetRankIds that have the local rank in their topK
    int* localGatherIndices,   // max length = maxTokenCountPerRank * worldInfo.epSize when all ranks send to the
                               // current rank
    int* sendRankCountCumSum,  // max length = worldInfo.epSize
    int* sendRankLocalIndices, // max length = maxTokenCountPerRank * expertParallelInfo.expertCount when the current
                               // rank has maxTokenCountPerRank tokens to send and all have expertCount destinations
    int* recvRankCountCumSum,  // max length = worldInfo.epSize
    int* recvRankLocalIndices, // max length = maxTokenCountPerRank * worldInfo.epSize when all ranks send to the
                               // current rank
    // the rankCountCumSum of combineRecv should be the same as sendRankCountCumSum
    int* backwardRecvRankLocalIndices, // max length = maxTokenCountPerRank * expertParallelInfo.expertCount when the
                                       // current rank has maxTokenCountPerRank tokens to send and all have
                                       // expertCount destinations
    cudaStream_t stream)
{
    moeAllToAllMemset(worldInfo, expertParallelInfo, maxTokenCountPerRank, sendRankCountCumSum, recvRankCountCumSum,
        localGatherIndices, sendRankLocalIndices, recvRankLocalIndices, backwardRecvRankLocalIndices, stream);
    TLLM_CHECK_WITH_INFO(worldInfo.epSize <= 1024, "Only worldInfo.epSize less than or equal to 1024 supported now.");
    computeSendRecvRankCount(worldInfo, expertParallelInfo, maxTokenCountPerRank, realRankTokenCountCumSum,
        gatheredTargetRankIds, sendRankCountCumSum, recvRankCountCumSum, stream);
    inplaceSendRecvRankCumSum(worldInfo, sendRankCountCumSum, recvRankCountCumSum, stream);
    computeSendRecvIndices(worldInfo, expertParallelInfo, maxTokenCountPerRank, realRankTokenCountCumSum,
        gatheredTargetRankIds, sendRankCountCumSum, recvRankCountCumSum, localGatherIndices, sendRankLocalIndices,
        recvRankLocalIndices, backwardRecvRankLocalIndices, stream);
}
|
||||
template <int kThreadsGroupSize>
__global__ void moeLocalGatherDevice(MoeEpWorldInfo worldInfo, MoeExpertParallelInfo expertParallelInfo,
    int maxTokenCountPerRank, int localMaxTokenCount, int const* recvRankCountCumSum, int const* localGatherIndices,
    int const* gatheredExpertIds, float const* gatheredScales, int* localExpertIds, float* localScales)
{
    cg::thread_block_tile<kThreadsGroupSize> tile = cg::tiled_partition<kThreadsGroupSize>(cg::this_thread_block());
    int laneInTile = tile.thread_rank();
    int tileId = threadIdx.x / kThreadsGroupSize;
    int tileCountPerBlock = blockDim.x / kThreadsGroupSize;

    int epSize = worldInfo.epSize;
    int rankTokenCount = recvRankCountCumSum[epSize - 1];
    if (laneInTile >= expertParallelInfo.topK)
    {
        return;
    }

    for (int index = tileId + blockIdx.x * tileCountPerBlock; index < localMaxTokenCount;
         index += tileCountPerBlock * gridDim.x)
    {
        int localTokenIndice = localGatherIndices[index];
        int expertId = index < rankTokenCount
            ? gatheredExpertIds[localTokenIndice * expertParallelInfo.topK + laneInTile]
            : expertParallelInfo.expertCount;
        localExpertIds[index * expertParallelInfo.topK + laneInTile] = expertId;
        if (gatheredScales)
        {
            float scale = index < rankTokenCount
                ? gatheredScales[localTokenIndice * expertParallelInfo.topK + laneInTile]
                : 0.0f;
            localScales[index * expertParallelInfo.topK + laneInTile] = scale;
        }
    }
}

void moeLocalGather(MoeEpWorldInfo worldInfo, MoeExpertParallelInfo expertParallelInfo, int maxTokenCountPerRank,
    int localMaxTokenCount, int const* recvRankCountCumSum, int const* localGatherIndices, int const* gatheredExpertIds,
    float const* gatheredScales, int* localExpertIds, float* localScales, cudaStream_t stream)
{
    TLLM_CHECK_WITH_INFO(expertParallelInfo.topK <= 32, "Only topK less than or equal to 32 supported now.");
    auto* kernelPtr = moeLocalGatherDevice<32>;
    int paddedTopK = 32;
    if (expertParallelInfo.topK <= 1)
    {
        paddedTopK = 1;
        kernelPtr = moeLocalGatherDevice<1>;
    }
    else if (expertParallelInfo.topK <= 2)
    {
        paddedTopK = 2;
        kernelPtr = moeLocalGatherDevice<2>;
    }
    else if (expertParallelInfo.topK <= 4)
    {
        paddedTopK = 4;
        kernelPtr = moeLocalGatherDevice<4>;
    }
    else if (expertParallelInfo.topK <= 8)
    {
        paddedTopK = 8;
        kernelPtr = moeLocalGatherDevice<8>;
    }
    else if (expertParallelInfo.topK <= 16)
    {
        paddedTopK = 16;
        kernelPtr = moeLocalGatherDevice<16>;
    }

    int threadsPerBlock = 512;
    int tokenPerBlock = threadsPerBlock / paddedTopK;
    int blockCount = (localMaxTokenCount + tokenPerBlock - 1) / tokenPerBlock * 2;

    kernelPtr<<<blockCount, threadsPerBlock, 0, stream>>>(worldInfo, expertParallelInfo, maxTokenCountPerRank,
        localMaxTokenCount, recvRankCountCumSum, localGatherIndices, gatheredExpertIds, gatheredScales, localExpertIds,
        localScales);
}

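// Editor-added worked example of the launch math above (hypothetical sizes): with topK padded to 8,
// each 512-thread block covers 512 / 8 = 64 token tiles, and localMaxTokenCount = 1000 yields
// ceil(1000 / 64) * 2 = 32 blocks; the final * 2 oversubscribes the grid, presumably for load balance.
constexpr int exampleMoeLocalGatherBlockCount(int localMaxTokenCount, int paddedTopK)
{
    return (localMaxTokenCount + (512 / paddedTopK) - 1) / (512 / paddedTopK) * 2;
}

static_assert(exampleMoeLocalGatherBlockCount(1000, 8) == 32, "ceil(1000 / 64) * 2 blocks");
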
int AllToAllChannelCommunicatorBase::maxSmCount = -1;
bool AllToAllChannelCommunicatorBase::maxSmCountUsed = false;

void setMaxUsableSmCount(int smCount)
{
    AllToAllChannelCommunicatorBase::setMaxUsableSmCount(smCount);
}

} // namespace tensorrt_llm::kernels
@ -1,268 +0,0 @@
|
||||
/*
|
||||
* Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <map>
|
||||
|
||||
#include "tensorrt_llm/common/cudaUtils.h"
|
||||
|
||||
namespace tensorrt_llm::kernels
|
||||
{
|
||||
|
||||
#ifdef __CUDACC__
|
||||
#define ALIGN_256 __align__(256)
|
||||
#else
|
||||
#define ALIGN_256 alignas(256)
|
||||
#endif
|
||||
|
||||
struct ALIGN_256 MoeCommFifoConnInfo
|
||||
{
|
||||
volatile uint64_t head; // write position
|
||||
volatile uint64_t tail; // read position
|
||||
};
|
||||
|
||||
constexpr int WARP_SIZE = 32;
|
||||
constexpr uint32_t WARP_MASK = 0xffffffff;
|
||||
|
||||
constexpr int RECV_FIFO_DEPTH = 8;
|
||||
constexpr int RECV_FIFO_ENTRY_BYTES = 256 * 1024;
|
||||
constexpr int RECV_FIFO_ENTRY_U64 = RECV_FIFO_ENTRY_BYTES / sizeof(uint64_t);
|
||||
constexpr int RECV_FIFO_TOTAL_BYTES = RECV_FIFO_DEPTH * RECV_FIFO_ENTRY_BYTES;
|
||||
constexpr int RECV_FIFO_TOTAL_U64 = RECV_FIFO_TOTAL_BYTES / sizeof(uint64_t);
|
||||
|
||||
class AllToAllChannelCommunicatorBase
|
||||
{
|
||||
public:
|
||||
static constexpr int GROUP_COUNT_PER_BLOCK = 8;
|
||||
static_assert(GROUP_COUNT_PER_BLOCK <= 8, "GROUP_COUNT_PER_BLOCK must be less than or equal to 8");
|
||||
static constexpr int WARP_PER_GROUP = 2;
|
||||
static constexpr int U64_DATA_REG_PER_THREAD = 8;
|
||||
// A packet is a warp-sized chunk of data that is sent or received in one go,
|
||||
// but may be split into multiple 64-bit registers, the number of which is U64_DATA_REG_PER_THREAD.
|
||||
static constexpr int PACKET_SIZE_IN_U64 = WARP_SIZE * U64_DATA_REG_PER_THREAD;
|
||||
static constexpr int PACKET_SIZE_IN_BYTES = PACKET_SIZE_IN_U64 * sizeof(uint64_t);
|
||||
static constexpr int DATA_PAYLOAD_SIZE_PER_PACKET_IN_U64 = (WARP_SIZE - 2) * U64_DATA_REG_PER_THREAD;
|
||||
static constexpr int DATA_PAYLOAD_SIZE_PER_PACKET = DATA_PAYLOAD_SIZE_PER_PACKET_IN_U64 * sizeof(uint64_t);
|
||||
static constexpr int U64_ELT_COUNT_PER_PACKET = PACKET_SIZE_IN_BYTES / sizeof(uint64_t);
|
||||
|
||||
static constexpr int PACKET_COUNT_PER_FIFO_ENTRY = RECV_FIFO_ENTRY_BYTES / PACKET_SIZE_IN_BYTES;
|
||||
|
||||
static constexpr int GROUP_MAX_INDICE_COUNT
|
||||
= RECV_FIFO_ENTRY_BYTES / sizeof(uint64_t) / (WARP_SIZE * U64_DATA_REG_PER_THREAD);
|
||||
|
||||
struct GroupSharedBuffer
|
||||
{
|
||||
int groupIndiceBuffer[GROUP_MAX_INDICE_COUNT];
|
||||
int groupStartIndice;
|
||||
int groupEndIndice;
|
||||
};
|
||||
|
||||
static void setMaxUsableSmCount(int maxUsableSmCount)
|
||||
{
|
||||
TLLM_CHECK_WITH_INFO(AllToAllChannelCommunicatorBase::maxSmCountUsed == false,
|
||||
"setMaxUsableSmCount can be called only before it is used");
|
||||
int smCount = tensorrt_llm::common::getMultiProcessorCount();
|
||||
if (maxUsableSmCount > smCount)
|
||||
{
|
||||
TLLM_LOG_WARNING("setMaxUsableSmCount, maxUsableSmCount=%d, larger than smCount=%d, using smCount instead",
|
||||
maxUsableSmCount, smCount);
|
||||
maxUsableSmCount = smCount;
|
||||
}
|
||||
AllToAllChannelCommunicatorBase::maxSmCount = maxUsableSmCount;
|
||||
}
|
||||
|
||||
static int getMaxUsableSmCount()
|
||||
{
|
||||
AllToAllChannelCommunicatorBase::maxSmCountUsed = true;
|
||||
if (AllToAllChannelCommunicatorBase::maxSmCount == -1)
|
||||
{
|
||||
int smCount = tensorrt_llm::common::getMultiProcessorCount();
|
||||
AllToAllChannelCommunicatorBase::maxSmCount = smCount;
|
||||
}
|
||||
return AllToAllChannelCommunicatorBase::maxSmCount;
|
||||
}
|
||||
|
||||
    static int computeMoeCommChannelCount(int epSize)
    {
        int smCount = getMaxUsableSmCount();
        int blockCountPerChannel = (epSize + GROUP_COUNT_PER_BLOCK - 1) / GROUP_COUNT_PER_BLOCK;
        blockCountPerChannel *= 2; // for send and recv
        TLLM_CHECK_WITH_INFO(
            blockCountPerChannel <= smCount, "GPU should support at least one channel, usableSmCount=%d", smCount);
        int perferredChannel = smCount / 2 / blockCountPerChannel; // use half of the SMs for communication
        int channelCount = std::max(perferredChannel, 1);          // at least one channel
        return channelCount;
    }
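
    // Editor-added worked example (not in the original header; the SM count is hypothetical): on a
    // 132-SM GPU with GROUP_COUNT_PER_BLOCK = 8, epSize = 8 needs one group block doubled for send
    // and recv, i.e. 2 blocks per channel, so half the SMs give 132 / 2 / 2 = 33 channels;
    // epSize = 64 needs 16 blocks per channel and gets 132 / 2 / 16 = 4 channels.
    static constexpr int exampleChannelCount(int epSize, int smCount)
    {
        int blockCountPerChannel = (epSize + GROUP_COUNT_PER_BLOCK - 1) / GROUP_COUNT_PER_BLOCK * 2;
        int preferred = smCount / 2 / blockCountPerChannel;
        return preferred > 1 ? preferred : 1;
    }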

    static int getMoeCommChannelCount(int epSize)
    {
        static std::map<int, int> channelCountMap{};
        auto iter = channelCountMap.find(epSize);
        if (iter == channelCountMap.end())
        {
            auto channelCount = AllToAllChannelCommunicatorBase::computeMoeCommChannelCount(epSize);
            channelCountMap[epSize] = channelCount;
            return channelCount;
        }
        return iter->second;
    }

    static dim3 getLaunchBlockDim()
    {
        return dim3(WARP_SIZE * WARP_PER_GROUP, GROUP_COUNT_PER_BLOCK);
    }

    static dim3 getLaunchGridDim(int epSize)
    {
        int channelCount = AllToAllChannelCommunicatorBase::getMoeCommChannelCount(epSize);
        return dim3((epSize + GROUP_COUNT_PER_BLOCK - 1) / GROUP_COUNT_PER_BLOCK, channelCount, 2);
    }
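
    // Editor-added note: each block launched with the dimensions above holds
    // WARP_SIZE * WARP_PER_GROUP * GROUP_COUNT_PER_BLOCK = 64 * 8 = 512 threads, and the grid is
    // (ceil(epSize / GROUP_COUNT_PER_BLOCK), channelCount, 2); the z extent of 2 presumably
    // separates the send and receive halves counted in computeMoeCommChannelCount. For a
    // hypothetical epSize = 16 with 4 channels this is a (2, 4, 2) grid of 512-thread blocks.
    static constexpr int THREADS_PER_LAUNCH_BLOCK = WARP_SIZE * WARP_PER_GROUP * GROUP_COUNT_PER_BLOCK;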

protected:
    static int maxSmCount;
    static bool maxSmCountUsed;
};

inline size_t getMoeCommWorkspaceSize(int epSize)
{
    int channelCount = AllToAllChannelCommunicatorBase::getMoeCommChannelCount(epSize);
    return RECV_FIFO_TOTAL_BYTES * epSize * channelCount + sizeof(MoeCommFifoConnInfo) * epSize * channelCount;
}
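
// Editor-added sketch of the per-rank workspace layout implied by the size above together with
// MoeCommWorkspace::getFifoBasePtr / getFifoConnInfo below:
//   [ epSize * channelCount receive FIFOs, RECV_FIFO_TOTAL_BYTES (8 entries x 256 KiB = 2 MiB) each ]
//   [ epSize * channelCount MoeCommFifoConnInfo structs (256 B each) for head/tail bookkeeping      ]
// For a hypothetical epSize = 4 with 2 channels that is 4 * 2 * 2 MiB + 4 * 2 * 256 B, i.e. 16 MiB
// of FIFO space plus 2 KiB of connection state per rank.
inline size_t exampleMoeCommWorkspaceSize(int epSize = 4, int channelCount = 2)
{
    return static_cast<size_t>(RECV_FIFO_TOTAL_BYTES) * epSize * channelCount
        + sizeof(MoeCommFifoConnInfo) * epSize * channelCount;
}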
|
||||
|
||||
struct MoeEpWorldInfo
|
||||
{
|
||||
int epSize;
|
||||
int epRank;
|
||||
};
|
||||
|
||||
struct MoeExpertParallelInfo
|
||||
{
|
||||
int expertCount = -1;
|
||||
int topK = 1;
|
||||
};
|
||||
|
||||
struct SendRecvDataInfo
{
    int vectorSizeInU64;
    // pre-computed at host side for GPU kernel
    int dataPacketCountPerVector;
    int vectorCountPerFifoEntry;

    void ComputeDataPacketCountPerVector()
    {
        dataPacketCountPerVector
            = (vectorSizeInU64 * sizeof(uint64_t) + AllToAllChannelCommunicatorBase::DATA_PAYLOAD_SIZE_PER_PACKET - 1)
            / AllToAllChannelCommunicatorBase::DATA_PAYLOAD_SIZE_PER_PACKET;
    }

    void ComputeVectorCountPerFifoEntry()
    {
        ComputeDataPacketCountPerVector();
        vectorCountPerFifoEntry
            = AllToAllChannelCommunicatorBase::PACKET_COUNT_PER_FIFO_ENTRY / dataPacketCountPerVector;
    }

    void DoPreCompute()
    {
        ComputeDataPacketCountPerVector();
        ComputeVectorCountPerFifoEntry();
        assert(vectorCountPerFifoEntry <= AllToAllChannelCommunicatorBase::GROUP_MAX_INDICE_COUNT);
    }
};

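// Editor-added worked example (the token size is hypothetical): each packet carries
// DATA_PAYLOAD_SIZE_PER_PACKET = 30 payload lanes * 8 u64 registers = 1920 payload bytes, and a
// 256 KiB FIFO entry holds PACKET_COUNT_PER_FIFO_ENTRY = 128 packets. A token vector of
// 1792 x uint64_t (for instance a 7168-element fp16 hidden state) therefore needs 8 packets,
// so 128 / 8 = 16 vectors fit into one FIFO entry.
constexpr int exampleDataPacketCountPerVector(int vectorSizeInU64)
{
    return (vectorSizeInU64 * 8 + AllToAllChannelCommunicatorBase::DATA_PAYLOAD_SIZE_PER_PACKET - 1)
        / AllToAllChannelCommunicatorBase::DATA_PAYLOAD_SIZE_PER_PACKET;
}

static_assert(exampleDataPacketCountPerVector(1792) == 8, "a 14 KiB token vector is split into 8 packets");
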
// struct holding Send/Recv data pointer and its displacement information.
struct SendRecvDispls
{
    uint64_t* dataPtr;
    int const* rankCountCumSum;  // length = epSize
    int const* rankLocalIndices; // length = rankCountCumSum[epRank] - rankCountCumSum[epRank - 1] if epRank > 0 else
                                 // rankCountCumSum[epRank]
    int vectorStrideInU64;

#ifdef __CUDACC__
    __inline__ __device__ int getCount(int rank) const
    {
        return rank == 0 ? rankCountCumSum[rank] : rankCountCumSum[rank] - rankCountCumSum[rank - 1];
    }

    __inline__ __device__ int getRankStart(int rank) const
    {
        return rank == 0 ? 0 : rankCountCumSum[rank - 1];
    }

    __inline__ __device__ int getRealVectorIndice(int globalVectorIndex) const
    {
        return rankLocalIndices[globalVectorIndex];
    }

    __inline__ __device__ uint64_t* getVectorDataPtr(int realVectorIndex) const
    {
        return dataPtr + realVectorIndex * vectorStrideInU64;
    }
#endif
};
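
// Editor-added example of the cumulative-sum convention used by getCount / getRankStart above,
// with hypothetical values for epSize = 4: rank 0 owns indices [0, 3), rank 1 owns [3, 3) (empty),
// rank 2 owns [3, 7) and rank 3 owns [7, 9).
constexpr int exampleRankCountCumSum[4] = {3, 3, 7, 9};
static_assert(exampleRankCountCumSum[2] - exampleRankCountCumSum[1] == 4, "rank 2 owns 4 vectors");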

struct MoeCommWorkspace
{
    uint64_t* workspacePtr;
    size_t rankStrideInU64;
#ifdef __CUDACC__
    __inline__ __device__ uint64_t* getFifoBasePtr(
        bool isSender, int epRank, int peerRank, int channel, int channelCount) const
    {
        // The FIFO itself lives on the receiver's side.
        if (isSender)
        {
            return workspacePtr + peerRank * rankStrideInU64 + (epRank * channelCount + channel) * RECV_FIFO_TOTAL_U64;
        }
        else
        {
            return workspacePtr + epRank * rankStrideInU64 + (peerRank * channelCount + channel) * RECV_FIFO_TOTAL_U64;
        }
    }

    __inline__ __device__ MoeCommFifoConnInfo* getFifoConnInfo(
        bool isSender, int epRank, int peerRank, int channel, int epSize, int channelCount) const
    {
        // The fifoInfo lives on the sender's side.
        uint64_t* fifoInfoPtrU64 = workspacePtr + RECV_FIFO_TOTAL_U64 * channelCount * epSize;
        int strideIndice = isSender ? epRank : peerRank;
        int fifoInfoIndice = isSender ? peerRank : epRank;
        fifoInfoPtrU64 += strideIndice * rankStrideInU64;
        MoeCommFifoConnInfo* fifoInfoPtr = (MoeCommFifoConnInfo*) fifoInfoPtrU64;
        return fifoInfoPtr + fifoInfoIndice * channelCount + channel;
    }
#endif
};
|
||||
|
||||
void setMaxUsableSmCount(int smCount);
|
||||
|
||||
void moeAllToAll(MoeEpWorldInfo worldInfo, SendRecvDataInfo sendRecvDataInfo, SendRecvDispls sendDispls,
|
||||
SendRecvDispls recvDispls, MoeCommWorkspace workspace, cudaStream_t stream);
|
||||
|
||||
void moeAllToAllPrepareIndices(MoeEpWorldInfo worldInfo, MoeExpertParallelInfo expertParallelInfo,
|
||||
int maxTokenCountPerRank, int const* gatheredTargetRankIds, int const* realRankTokenCountCumSum,
|
||||
int* localGatheredIndices, // indices of gatheredTargetRankIds that has the local rank in topK
|
||||
int* sendRankCountCumSum, int* sendRankLocalIndices, int* recvRankCountCumSum, int* recvRankLocalIndices,
|
||||
// the rankCountCumSum of combineRecv should be the same as sendRankCountCumSum
|
||||
int* backwardRecvRankLocalIndices, cudaStream_t stream);
|
||||
|
||||
void moeLocalGather(MoeEpWorldInfo worldInfo, MoeExpertParallelInfo expertParallelInfo, int maxTokenCountPerRank,
|
||||
int localMaxTokenCount, int const* recvRankCountCumSum, int const* localGatherIndices, int const* gatheredExpertIds,
|
||||
float const* gatheredScales, int* localExpertIds, float* localScales, cudaStream_t stream);
|
||||
|
||||
} // namespace tensorrt_llm::kernels
|
||||
47
cpp/tensorrt_llm/kernels/moeCommKernelsCommon.h
Normal file
47
cpp/tensorrt_llm/kernels/moeCommKernelsCommon.h
Normal file
@ -0,0 +1,47 @@
|
||||
/*
|
||||
* Copyright (c) 2019-2025, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#pragma once
|
||||
|
||||
#include <stdint.h>
|
||||
|
||||
namespace tensorrt_llm
|
||||
{
|
||||
namespace kernels
|
||||
{
|
||||
|
||||
#ifdef __CUDACC__
|
||||
#define ALIGN_256 __align__(256)
|
||||
#else
|
||||
#define ALIGN_256 alignas(256)
|
||||
#endif
|
||||
|
||||
constexpr int WARP_SIZE = 32;
|
||||
constexpr uint32_t WARP_MASK = 0xffffffff;
|
||||
|
||||
struct MoeEpWorldInfo
|
||||
{
|
||||
int epSize;
|
||||
int epRank;
|
||||
};
|
||||
|
||||
struct MoeExpertParallelInfo
|
||||
{
|
||||
int expertCount = -1;
|
||||
int topK = 1;
|
||||
};
|
||||
|
||||
} // namespace kernels
|
||||
} // namespace tensorrt_llm
|
||||
@ -49,86 +49,6 @@ __device__ __forceinline__ int ld_acquire_sys_global_int(int volatile* ptr)
|
||||
return ret;
|
||||
}
|
||||
|
||||
class StepCommunicatorBase
|
||||
{
|
||||
public:
|
||||
static constexpr int META_SIZE = sizeof(MoeCommFifoConnInfo);
|
||||
|
||||
__device__ __inline__ StepCommunicatorBase(MoeCommFifoConnInfo* fifoConnInfo)
|
||||
: fifoConnInfo(fifoConnInfo)
|
||||
, localCachedHead(0)
|
||||
, localCachedTail(0)
|
||||
{
|
||||
}
|
||||
|
||||
__forceinline__ __device__ void reset()
|
||||
{
|
||||
fifoConnInfo->head = 0;
|
||||
fifoConnInfo->tail = 0;
|
||||
}
|
||||
|
||||
__forceinline__ __device__ void releaseSendStep()
|
||||
{
|
||||
localCachedHead += 1;
|
||||
st_release_sys_global(&(fifoConnInfo->head), uint64_t(localCachedHead));
|
||||
}
|
||||
|
||||
__forceinline__ __device__ void releaseRecvStep()
|
||||
{
|
||||
localCachedTail += 1;
|
||||
st_release_sys_global(&(fifoConnInfo->tail), uint64_t(localCachedTail));
|
||||
}
|
||||
|
||||
__forceinline__ __device__ uint64_t acquireTail()
|
||||
{
|
||||
uint64_t tail = ld_acquire_sys_global(&(fifoConnInfo->tail));
|
||||
localCachedTail = tail;
|
||||
return tail;
|
||||
}
|
||||
|
||||
__forceinline__ __device__ uint64_t acquireHead()
|
||||
{
|
||||
uint64_t head = ld_acquire_sys_global(&(fifoConnInfo->head));
|
||||
localCachedHead = head;
|
||||
return head;
|
||||
}
|
||||
|
||||
__forceinline__ __device__ int acquireNewSendStep()
|
||||
{
|
||||
|
||||
int64_t tail;
|
||||
do
|
||||
{
|
||||
tail = acquireTail();
|
||||
} while (localCachedHead >= tail + STEP_DEPTH);
|
||||
// depth = 2, head = 1, tail = 0 , ok
|
||||
// depth = 2, head = 2, tail = 0, should wait
|
||||
|
||||
return localCachedHead % STEP_DEPTH;
|
||||
}
|
||||
|
||||
__forceinline__ __device__ int acquireNewRecvStep()
|
||||
{
|
||||
int64_t head = 0;
|
||||
do
|
||||
{
|
||||
head = acquireHead();
|
||||
} while (localCachedTail >= head);
|
||||
|
||||
return localCachedTail % STEP_DEPTH;
|
||||
}
|
||||
|
||||
public:
|
||||
MoeCommFifoConnInfo* fifoConnInfo;
|
||||
uint64_t localCachedHead;
|
||||
uint64_t localCachedTail;
|
||||
int rank;
|
||||
int targetRank;
|
||||
};
|
||||
|
||||
// Use MoeCommFifoConnInfo as media to transfer a counter number.
|
||||
// Use the "head" field as flag.
|
||||
// Use the "tail" field to transfer the counter number.
|
||||
class CounterCommunicator
|
||||
{
|
||||
public:
|
||||
@ -137,23 +57,23 @@ public:
|
||||
{
|
||||
}
|
||||
|
||||
__forceinline__ __device__ void releaseValue(uint64_t value)
|
||||
__forceinline__ __device__ void releaseValue(uint64_t value, int index)
|
||||
{
|
||||
// Avoid block on 0
|
||||
st_release_sys_global(&(fifoConnInfo->count), value + 1);
|
||||
fifoConnInfo->values[index] = value + 1;
|
||||
}
|
||||
|
||||
__forceinline__ __device__ uint64_t acquireValue()
|
||||
__forceinline__ __device__ uint64_t acquireValue(int index)
|
||||
{
|
||||
uint64_t localCount = 0;
|
||||
uint64_t localValue = 0;
|
||||
do
|
||||
{
|
||||
localCount = ld_acquire_sys_global(&(fifoConnInfo->count));
|
||||
} while (localCount == 0);
|
||||
localValue = fifoConnInfo->values[index];
|
||||
} while (localValue == 0);
|
||||
|
||||
fifoConnInfo->count = 0; // reset the count
|
||||
fifoConnInfo->values[index] = 0; // reset the value
|
||||
|
||||
return localCount - 1;
|
||||
return localValue - 1;
|
||||
}
|
||||
|
||||
protected:
|
||||
@ -161,15 +81,16 @@ protected:
|
||||
};
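
// Editor-added sketch (not part of the original file): a host-side model of the per-index value
// handshake that the updated releaseValue / acquireValue above implement on the volatile
// MoeCommFifoConnInfo::values[] array. Publishing value + 1 keeps a genuine payload of 0
// distinguishable from an empty slot, which is why both sides add and subtract 1.
inline void exampleReleaseValue(volatile int* slot, int value)
{
    *slot = value + 1; // never store 0, so the reader cannot block forever on a zero payload
}

inline int exampleAcquireValue(volatile int* slot)
{
    int observed = 0;
    while ((observed = *slot) == 0) // spin until the sender publishes a non-zero token
    {
    }
    *slot = 0; // hand the slot back for the next transfer
    return observed - 1;
}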
|
||||
|
||||
template <int kThreadsGroupSize>
|
||||
__device__ __forceinline__ void computeCountAndSend(int* experts, int tokenCount, int* sharedSendRecvRankCount,
|
||||
int* sendCounts, int* sendIndiceWorkspace, int* backwardIndiceWorkspace, MoeCommWorkspace workspace,
|
||||
int maxTokenCountPerRank, int expertCount, int topK, int epRank, int epSize)
|
||||
__device__ __forceinline__ void computeCountAndSendStatics(int* experts, int tokenCount, int* sharedSendRecvRankCount,
|
||||
int* sendCounts, int* sendIndiceWorkspace, int* backwardIndiceWorkspace, int* expertStatics,
|
||||
MoeCommWorkspace workspace, int maxTokenCountPerRank, int slotCount, int expertCount, int topK, int epRank,
|
||||
int epSize)
|
||||
{
|
||||
cg::thread_block_tile<kThreadsGroupSize> tile = cg::tiled_partition<kThreadsGroupSize>(cg::this_thread_block());
|
||||
int laneInTile = tile.thread_rank();
|
||||
int tileId = threadIdx.x / kThreadsGroupSize;
|
||||
int tileCountPerBlock = blockDim.x / kThreadsGroupSize;
|
||||
int expertCountPerRank = expertCount / epSize;
|
||||
int expertCountPerRank = slotCount / epSize;
|
||||
if (threadIdx.x == 0)
|
||||
{
|
||||
*sharedSendRecvRankCount = 0;
|
||||
@ -201,18 +122,24 @@ __device__ __forceinline__ void computeCountAndSend(int* experts, int tokenCount
|
||||
tile.sync();
|
||||
}
|
||||
__syncthreads();
|
||||
if (threadIdx.x == 0)
|
||||
|
||||
CounterCommunicator counter(workspace.getFifoConnInfo(true, epRank, targetRankId, 0, epSize, 1));
|
||||
|
||||
int communicationCount = expertStatics == nullptr ? 1 : expertCount + 1;
|
||||
for (int i = threadIdx.x; i < communicationCount; i += blockDim.x)
|
||||
{
|
||||
CounterCommunicator counter(workspace.getFifoConnInfo(true, epRank, targetRankId, 0, epSize, 1));
|
||||
int count = *(sharedSendRecvRankCount);
|
||||
// printf("sendRecvCount: %d, rankId: %d, targetRankId: %d\n", count, rankId, targetRankId);
|
||||
counter.releaseValue(uint64_t(count));
|
||||
*(sendCounts + targetRankId) = count;
|
||||
int value = i == 0 ? *(sharedSendRecvRankCount) : *(expertStatics + i - 1);
|
||||
counter.releaseValue(value, i);
|
||||
if (i == 0)
|
||||
{
|
||||
*(sendCounts + targetRankId) = value;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
__device__ __forceinline__ void recvCount(int* recvIndiceWorkspace, int* recvCounts, int* sharedCountsBase,
|
||||
MoeCommWorkspace workspace, int maxTokenCountPerRank, int rankId, int rankCount)
|
||||
__device__ __forceinline__ void recvCountAndStatics(int* recvIndiceWorkspace, int* recvCounts, int* sharedCountsBase,
|
||||
int* gatheredExpertStatics, MoeCommWorkspace workspace, int expertCount, int maxTokenCountPerRank, int rankId,
|
||||
int rankCount)
|
||||
{
|
||||
int rankOffset = threadIdx.x / THREADS_PER_PIPELINE;
|
||||
if (rankOffset >= PIPELINE_PER_CTA)
|
||||
@ -229,18 +156,25 @@ __device__ __forceinline__ void recvCount(int* recvIndiceWorkspace, int* recvCou
|
||||
cg::thread_block_tile<THREADS_PER_PIPELINE> rankTile
|
||||
= cg::tiled_partition<THREADS_PER_PIPELINE>(cg::this_thread_block());
|
||||
int* localRecvIndice = recvIndiceWorkspace + targetRankId * maxTokenCountPerRank;
|
||||
int rankRecvCount;
|
||||
if (rankTile.thread_rank() == 0)
|
||||
|
||||
CounterCommunicator counter(workspace.getFifoConnInfo(false, rankId, targetRankId, 0, rankCount, 1));
|
||||
int communicationCount = gatheredExpertStatics == nullptr ? 1 : expertCount + 1;
|
||||
for (int i = rankTile.thread_rank(); i < communicationCount; i += THREADS_PER_PIPELINE)
|
||||
{
|
||||
CounterCommunicator counter(workspace.getFifoConnInfo(false, rankId, targetRankId, 0, rankCount, 1));
|
||||
rankRecvCount = int(counter.acquireValue());
|
||||
// printf("rankRecvCount: %d, rankId: %d, targetRankId: %d\n", rankRecvCount, rankId, targetRankId);
|
||||
*(recvCounts + targetRankId) = rankRecvCount;
|
||||
*(sharedCountsThisRank) = rankRecvCount;
|
||||
int recvValue = counter.acquireValue(i);
|
||||
if (i == 0)
|
||||
{
|
||||
*(recvCounts + targetRankId) = recvValue;
|
||||
*(sharedCountsThisRank) = recvValue;
|
||||
}
|
||||
else
|
||||
{
|
||||
*(gatheredExpertStatics + targetRankId * expertCount + i - 1) = recvValue;
|
||||
}
|
||||
}
|
||||
rankTile.sync();
|
||||
|
||||
rankRecvCount = *(sharedCountsThisRank);
|
||||
int rankRecvCount = *(sharedCountsThisRank);
|
||||
for (int tokenId = unitId; tokenId < rankRecvCount; tokenId += UNIT_PER_PIPELINE)
|
||||
{
|
||||
*(localRecvIndice + tokenId) = tokenId;
|
||||
@ -249,20 +183,22 @@ __device__ __forceinline__ void recvCount(int* recvIndiceWorkspace, int* recvCou
|
||||
|
||||
template <int kThreadsGroupSize>
|
||||
__global__ void computeCountAndIndiceDevice(int* experts, int* sendCounts, int* recvCounts, int* sendIndiceWorkspace,
|
||||
int* backwardIndiceWorkspace, int* recvIndiceWorkspace, MoeCommWorkspace workspace, int tokenCount,
|
||||
int maxTokenCountPerRank, int topK, int expertCount, int rankId, int rankCount)
|
||||
int* backwardIndiceWorkspace, int* recvIndiceWorkspace, int* expertStatics, int* gatheredExpertStatics,
|
||||
MoeCommWorkspace workspace, int tokenCount, int maxTokenCountPerRank, int topK, int slotCount, int expertCount,
|
||||
int rankId, int rankCount)
|
||||
{
|
||||
__shared__ int sharedCounts[PIPELINE_PER_CTA];
|
||||
bool isSender = blockIdx.x < rankCount;
|
||||
if (isSender)
|
||||
{
|
||||
computeCountAndSend<kThreadsGroupSize>(experts, tokenCount, &sharedCounts[0], sendCounts, sendIndiceWorkspace,
|
||||
backwardIndiceWorkspace, workspace, maxTokenCountPerRank, expertCount, topK, rankId, rankCount);
|
||||
computeCountAndSendStatics<kThreadsGroupSize>(experts, tokenCount, &sharedCounts[0], sendCounts,
|
||||
sendIndiceWorkspace, backwardIndiceWorkspace, expertStatics, workspace, maxTokenCountPerRank, slotCount,
|
||||
expertCount, topK, rankId, rankCount);
|
||||
}
|
||||
else
|
||||
{
|
||||
recvCount(
|
||||
recvIndiceWorkspace, recvCounts, &sharedCounts[0], workspace, maxTokenCountPerRank, rankId, rankCount);
|
||||
recvCountAndStatics(recvIndiceWorkspace, recvCounts, &sharedCounts[0], gatheredExpertStatics, workspace,
|
||||
expertCount, maxTokenCountPerRank, rankId, rankCount);
|
||||
}
|
||||
}
|
||||
|
||||
@ -307,259 +243,12 @@ __global__ void computeCumsumDevice(int* sendCountsCumsum, int* recvCountsCumsum
|
||||
|
||||
int tid = threadIdx.x;
|
||||
int threadData = tid < rankCount ? inputOutputPtr[tid] : 0;
|
||||
int count = threadData;
|
||||
__syncthreads();
|
||||
|
||||
BlockScan(temp_storage).InclusiveSum(threadData, threadData);
|
||||
if (tid < rankCount)
|
||||
{
|
||||
inputOutputPtr[tid] = threadData;
|
||||
// printf("cumsum, send? : %d, rankId:%d, tid:%d, threadData:%d, count:%d\n", blockIdx.x == 0, rankId, tid,
|
||||
// threadData, count);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename PipelineConfig>
|
||||
class PacketPipeline
|
||||
{
|
||||
public:
|
||||
__device__ __inline__ PacketPipeline(
|
||||
void* bufferBase, StepCommunicatorBase* stepCommunicator, int* sharedNewStepPtr, bool isSender)
|
||||
: bufferBase(bufferBase)
|
||||
, stepCommunicator(stepCommunicator)
|
||||
, shared_new_step(sharedNewStepPtr)
|
||||
{
|
||||
step = 0;
|
||||
needRelease = false;
|
||||
packetId = isSender ? 0 : PipelineConfig::PACKET_PER_STEP - 1;
|
||||
}
|
||||
|
||||
__device__ __forceinline__ void* getFirstSendPacket()
|
||||
{
|
||||
return bufferBase;
|
||||
}
|
||||
|
||||
__device__ __inline__ void* finishSendPacket(bool acquireNewStep)
|
||||
{
|
||||
|
||||
packetId++;
|
||||
if (packetId < PipelineConfig::PACKET_PER_STEP)
|
||||
{
|
||||
return acquireNewStep ? bufferBase + step * PipelineConfig::PACKET_PER_STEP * PipelineConfig::PACKET_SIZE
|
||||
+ packetId * PipelineConfig::PACKET_SIZE
|
||||
: nullptr;
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
if (threadIdx.x == 0)
|
||||
{
|
||||
stepCommunicator->releaseSendStep();
|
||||
if (acquireNewStep)
|
||||
{
|
||||
step = stepCommunicator->acquireNewSendStep();
|
||||
*(shared_new_step) = step;
|
||||
}
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
if (acquireNewStep)
|
||||
{
|
||||
step = *(shared_new_step);
|
||||
packetId = 0;
|
||||
return bufferBase + step * PipelineConfig::PACKET_SIZE * PipelineConfig::PACKET_PER_STEP;
|
||||
}
|
||||
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
__device__ __forceinline__ void* sendFinalize()
|
||||
{
|
||||
if (packetId > 0 && threadIdx.x == 0)
|
||||
{
|
||||
stepCommunicator->releaseSendStep();
|
||||
}
|
||||
}
|
||||
|
||||
__device__ __inline__ void* getNewRecvPacket()
|
||||
{
|
||||
packetId++;
|
||||
if (packetId < PipelineConfig::PACKET_PER_STEP)
|
||||
{
|
||||
return bufferBase + step * PipelineConfig::PACKET_PER_STEP * PipelineConfig::PACKET_SIZE
|
||||
+ packetId * PipelineConfig::PACKET_SIZE;
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
if (threadIdx.x == 0)
|
||||
{
|
||||
if (needRelease)
|
||||
{
|
||||
stepCommunicator->releaseRecvStep();
|
||||
}
|
||||
step = stepCommunicator->acquireNewRecvStep();
|
||||
needRelease = true;
|
||||
*(shared_new_step) = step;
|
||||
}
|
||||
__syncthreads();
|
||||
packetId = 0;
|
||||
step = *(shared_new_step);
|
||||
void* packetPtr = bufferBase + step * PipelineConfig::PACKET_SIZE * PipelineConfig::PACKET_PER_STEP;
|
||||
|
||||
return packetPtr;
|
||||
}
|
||||
|
||||
__device__ __forceinline__ void reset()
|
||||
{
|
||||
if (threadIdx.x == 0)
|
||||
{
|
||||
stepCommunicator->reset();
|
||||
}
|
||||
}
|
||||
|
||||
void* bufferBase;
|
||||
StepCommunicatorBase* stepCommunicator;
|
||||
int step;
|
||||
int packetId;
|
||||
bool needRelease;
|
||||
int* shared_new_step;
|
||||
};
|
||||
|
||||
template <typename PipelineConfig, typename ExpertType, typename ScaleType>
|
||||
__global__ void allToAllMetadataDevice(int* sendExperts, int* recvExperts, float* sendScales, float* recvScales,
|
||||
int* localExpertStatics, int* gatheredExpertStatics, MoeCommWorkspace workspace, int* sendCountsCumsum,
|
||||
int* localSendIndice, int* recvCountsCumsum, int* localRecvIndice, int tokenCount, int maxTokenCountPerRank,
|
||||
int topK, int expertCount, int slotCount, int rankId, int rankCount)
|
||||
{
|
||||
bool isSender = (blockIdx.y == 0);
|
||||
int targetRankId = blockIdx.x;
|
||||
int slotCountPerRank = slotCount / rankCount;
|
||||
int groupSize = topK / PipelineConfig::UNIT_SIZE;
|
||||
|
||||
__shared__ int sharedNewStep;
|
||||
__align__(16) int experts[PipelineConfig::UNIT_SIZE];
|
||||
__align__(16) float scales[PipelineConfig::UNIT_SIZE];
|
||||
|
||||
uint8_t* bufferBase = (uint8_t*) (workspace.getFifoBasePtr(isSender, rankId, targetRankId, 0, 1));
|
||||
StepCommunicatorBase stepCommunicator(workspace.getFifoConnInfo(isSender, rankId, targetRankId, 0, rankCount, 1));
|
||||
PacketPipeline<PipelineConfig> pipeline(bufferBase, &stepCommunicator, &sharedNewStep, isSender);
|
||||
|
||||
if (isSender)
|
||||
{
|
||||
int baseCumsum = targetRankId == 0 ? 0 : *(sendCountsCumsum + targetRankId - 1);
|
||||
int sendTokenCount = *(sendCountsCumsum + targetRankId) - baseCumsum;
|
||||
int unitCount = sendTokenCount * topK / PipelineConfig::UNIT_SIZE;
|
||||
|
||||
void* packPtr = pipeline.getFirstSendPacket();
|
||||
int indexBase = 0;
|
||||
int staticCopyBase = 0;
|
||||
bool acquireNewStep = unitCount > 0 || (localExpertStatics != nullptr && expertCount > 0);
|
||||
while (acquireNewStep)
|
||||
{
|
||||
if (threadIdx.x < UNIT_PER_ITER)
|
||||
{
|
||||
int index = indexBase + threadIdx.x;
|
||||
int groupId = index % groupSize;
|
||||
if (index < unitCount)
|
||||
{
|
||||
int tokenId = *(localSendIndice + maxTokenCountPerRank * targetRankId + (index / groupSize));
|
||||
*((ExpertType*) (experts))
|
||||
= *(ExpertType*) (sendExperts + tokenId * topK + groupId * PipelineConfig::UNIT_SIZE);
|
||||
|
||||
#pragma unroll
|
||||
for (int j = 0; j < PipelineConfig::UNIT_SIZE; j++)
|
||||
{
|
||||
int expertId = experts[j];
|
||||
if (expertId / slotCountPerRank != targetRankId)
|
||||
{
|
||||
experts[j] = slotCount;
|
||||
}
|
||||
}
|
||||
|
||||
int* expertsPtr = (int*) (packPtr) + threadIdx.x * PipelineConfig::UNIT_SIZE;
|
||||
*((ExpertType*) (expertsPtr)) = *((ExpertType*) (experts));
|
||||
if (sendScales != nullptr)
|
||||
{
|
||||
*((ScaleType*) (scales))
|
||||
= *(ScaleType*) (sendScales + tokenId * topK + groupId * PipelineConfig::UNIT_SIZE);
|
||||
float* scaleBasePtr = (float*) (packPtr + PipelineConfig::SCALE_OFFSET);
|
||||
float* scalesPtr = (float*) (scaleBasePtr) + threadIdx.x * PipelineConfig::UNIT_SIZE;
|
||||
*((ScaleType*) (scalesPtr)) = *((ScaleType*) (scales));
|
||||
}
|
||||
}
|
||||
}
|
||||
else if (localExpertStatics != nullptr)
|
||||
{
|
||||
int staticCopyIdx = threadIdx.x - UNIT_PER_ITER;
|
||||
if (staticCopyBase + staticCopyIdx * 4 < expertCount)
|
||||
{
|
||||
int4* staticBasePtr = (int4*) (packPtr + PipelineConfig::STATIC_COPY_OFFSET);
|
||||
int4 staticData = *(int4*) (localExpertStatics + staticCopyBase + staticCopyIdx * 4);
|
||||
*(staticBasePtr + staticCopyIdx) = staticData;
|
||||
}
|
||||
}
|
||||
|
||||
indexBase += UNIT_PER_ITER;
|
||||
staticCopyBase += STATIC_COPY_PER_ITER * 4;
|
||||
acquireNewStep = indexBase < unitCount || staticCopyBase < expertCount;
|
||||
packPtr = pipeline.finishSendPacket(acquireNewStep);
|
||||
}
|
||||
|
||||
pipeline.sendFinalize();
|
||||
}
|
||||
else
|
||||
{
|
||||
int baseCumsum = targetRankId == 0 ? 0 : *(recvCountsCumsum + targetRankId - 1);
|
||||
int recvTokenCount = *(recvCountsCumsum + targetRankId) - baseCumsum;
|
||||
int recvUnitCount = recvTokenCount * groupSize;
|
||||
|
||||
int unitIdBase = 0;
|
||||
int staticCopyBase = 0;
|
||||
while (unitIdBase < recvUnitCount || (localExpertStatics != nullptr && staticCopyBase < expertCount))
|
||||
{
|
||||
void* packetPtr = pipeline.getNewRecvPacket();
|
||||
int packetUnitCount
|
||||
= unitIdBase + UNIT_PER_ITER < recvUnitCount ? UNIT_PER_ITER : recvUnitCount - unitIdBase;
|
||||
packetUnitCount = max(packetUnitCount, 0);
|
||||
if (threadIdx.x < UNIT_PER_ITER)
|
||||
{
|
||||
if (threadIdx.x < packetUnitCount)
|
||||
{
|
||||
int tokenId = baseCumsum + (unitIdBase + threadIdx.x) / groupSize;
|
||||
int groupId = (unitIdBase + threadIdx.x) % groupSize;
|
||||
int* expertsPtr = (int*) (packetPtr) + threadIdx.x * PipelineConfig::UNIT_SIZE;
|
||||
*((ExpertType*) (experts)) = *((ExpertType*) (expertsPtr));
|
||||
ExpertType* dstExpertsPtr
|
||||
= (ExpertType*) (recvExperts + tokenId * topK + groupId * PipelineConfig::UNIT_SIZE);
|
||||
*dstExpertsPtr = *((ExpertType*) (experts));
|
||||
|
||||
if (recvScales != nullptr)
|
||||
{
|
||||
float* scaleBasePtr = (float*) (packetPtr + PipelineConfig::SCALE_OFFSET);
|
||||
float* scalesPtr = scaleBasePtr + threadIdx.x * PipelineConfig::UNIT_SIZE;
|
||||
*((ScaleType*) (scales)) = *((ScaleType*) (scalesPtr));
|
||||
ScaleType* dstScalesPtr
|
||||
= (ScaleType*) (recvScales + tokenId * topK + groupId * PipelineConfig::UNIT_SIZE);
|
||||
*dstScalesPtr = *((ScaleType*) (scales));
|
||||
}
|
||||
}
|
||||
}
|
||||
else if (localExpertStatics != nullptr)
|
||||
{
|
||||
int staticCopyIdx = threadIdx.x - UNIT_PER_ITER;
|
||||
if (staticCopyBase + staticCopyIdx * 4 < expertCount)
|
||||
{
|
||||
int4* staticBasePtr = (int4*) (packetPtr + PipelineConfig::STATIC_COPY_OFFSET);
|
||||
int4 staticData = *(staticBasePtr + staticCopyIdx);
|
||||
*(int4*) (gatheredExpertStatics + targetRankId * expertCount + staticCopyBase + staticCopyIdx * 4)
|
||||
= staticData;
|
||||
}
|
||||
}
|
||||
|
||||
unitIdBase += packetUnitCount;
|
||||
staticCopyBase += STATIC_COPY_PER_ITER * 4;
|
||||
}
|
||||
|
||||
pipeline.reset();
|
||||
}
|
||||
}
|
||||
|
||||
@ -576,8 +265,9 @@ __global__ void memsetExpertIdsDevice(
|
||||
}
|
||||
|
||||
void computeCountAndIndice(int* experts, int* sendCounts, int* recvCounts, int* sendIndiceWorkspace,
|
||||
int* backwardIndiceWorkspace, int* recvIndiceWorkspace, MoeCommWorkspace workspace, int tokenCount,
|
||||
int maxTokenCountPerRank, int topK, int expert_count, int rankId, int rankCount, cudaStream_t stream)
|
||||
int* backwardIndiceWorkspace, int* recvIndiceWorkspace, int* expertStatics, int* gatheredExpertStatics,
|
||||
MoeCommWorkspace workspace, int tokenCount, int maxTokenCountPerRank, int topK, int slotCount, int expertCount,
|
||||
int rankId, int rankCount, cudaStream_t stream)
|
||||
{
|
||||
// first rankCount CTAs for count and send, then rankCount / PIPELINE_PER_CTA CTAs only for receive
|
||||
int grid_x = rankCount + (rankCount + PIPELINE_PER_CTA - 1) / PIPELINE_PER_CTA;
|
||||
@ -607,7 +297,8 @@ void computeCountAndIndice(int* experts, int* sendCounts, int* recvCounts, int*
|
||||
kernelFn = computeCountAndIndiceDevice<2>;
|
||||
}
|
||||
kernelFn<<<grid, block, 0, stream>>>(experts, sendCounts, recvCounts, sendIndiceWorkspace, backwardIndiceWorkspace,
|
||||
recvIndiceWorkspace, workspace, tokenCount, maxTokenCountPerRank, topK, expert_count, rankId, rankCount);
|
||||
recvIndiceWorkspace, expertStatics, gatheredExpertStatics, workspace, tokenCount, maxTokenCountPerRank, topK,
|
||||
slotCount, expertCount, rankId, rankCount);
|
||||
}
|
||||
|
||||
void computeCumsum(int* sendCountsCumsum, int* recvCountsCumsum, int rankId, int rankCount, cudaStream_t stream)
|
||||
@ -628,46 +319,18 @@ void moveIndice(int* sendCountsCumsum, int* recvCountsCumsum, int* sendIndice, i
|
||||
backwardIndice, gatherBackwardIndice, recvIndice, gatherRecvIndice, maxTokenCountPerRank);
|
||||
}
|
||||
|
||||
void allToAllMetadata(int* sendExperts, int* recvExperts, float* sendScales, float* recvScales, int* localExpertStatics,
|
||||
int* gatheredExpertStatics, MoeCommWorkspace workspace, int* sendCountsCumsum, int* localSendIndice,
|
||||
int* recvCountsCumsum, int* localRecvIndice, int tokenCount, int maxTokenCountPerRank, int topK, int expertCount,
|
||||
int slotCount, int rankId, int rankCount, cudaStream_t stream)
|
||||
void memsetExpertIds(int* expertIds, int* recvCountsCumsum, int maxTokenCountPerRank, int topK, int slotCount,
|
||||
int rankCount, cudaStream_t stream)
|
||||
{
|
||||
int block_size = localExpertStatics == nullptr ? UNIT_PER_ITER : UNIT_PER_ITER + STATIC_COPY_PER_ITER;
|
||||
dim3 block(block_size);
|
||||
dim3 grid(rankCount, 2);
|
||||
|
||||
if (topK % 4 == 0)
|
||||
{
|
||||
using PipelineConfig = PipelineConfig<4, 16>;
|
||||
static_assert(
|
||||
PipelineConfig::PACKET_SIZE_IN_U64 * PipelineConfig::PACKET_PER_STEP * STEP_DEPTH <= FIFO_SIZE_IN_U64,
|
||||
"FIFO size is too small");
|
||||
allToAllMetadataDevice<PipelineConfig, int4, float4><<<grid, block, 0, stream>>>(sendExperts, recvExperts,
|
||||
sendScales, recvScales, localExpertStatics, gatheredExpertStatics, workspace, sendCountsCumsum,
|
||||
localSendIndice, recvCountsCumsum, localRecvIndice, tokenCount, maxTokenCountPerRank, topK, expertCount,
|
||||
slotCount, rankId, rankCount);
|
||||
}
|
||||
else
|
||||
{
|
||||
using PipelineConfig = PipelineConfig<1, 64>;
|
||||
static_assert(
|
||||
PipelineConfig::PACKET_SIZE_IN_U64 * PipelineConfig::PACKET_PER_STEP * STEP_DEPTH <= FIFO_SIZE_IN_U64,
|
||||
"FIFO size is too small");
|
||||
allToAllMetadataDevice<PipelineConfig, int, float><<<grid, block, 0, stream>>>(sendExperts, recvExperts,
|
||||
sendScales, recvScales, localExpertStatics, gatheredExpertStatics, workspace, sendCountsCumsum,
|
||||
localSendIndice, recvCountsCumsum, localRecvIndice, tokenCount, maxTokenCountPerRank, topK, expertCount,
|
||||
slotCount, rankId, rankCount);
|
||||
}
|
||||
|
||||
int smCount = tensorrt_llm::common::getMultiProcessorCount();
|
||||
memsetExpertIdsDevice<<<smCount, 256, 0, stream>>>(
|
||||
recvExperts, recvCountsCumsum, maxTokenCountPerRank, topK, slotCount, rankCount);
|
||||
int block_size = 256;
|
||||
memsetExpertIdsDevice<<<smCount, block_size, 0, stream>>>(
|
||||
expertIds, recvCountsCumsum, maxTokenCountPerRank, topK, slotCount, rankCount);
|
||||
}
|
||||
|
||||
size_t getMoePrepareWorkspaceSize(int epSize)
|
||||
{
|
||||
return (FIFO_SIZE_IN_U64 * 8 + StepCommunicatorBase::META_SIZE) * epSize;
|
||||
return sizeof(MoeCommFifoConnInfo) * epSize;
|
||||
}
|
||||
|
||||
} // namespace moe_prepare
|
||||
|
||||
@ -28,36 +28,11 @@ namespace tensorrt_llm::kernels
|
||||
namespace moe_prepare
|
||||
{
|
||||
|
||||
#define STEP_DEPTH 2
|
||||
#define THREADS_PER_UNIT 1
|
||||
#define UNIT_PER_PIPELINE 128
|
||||
#define PIPELINE_PER_CTA 4
|
||||
#define EXPERT_BYTES_PER_UNIT 32
|
||||
#define SCALE_BYTES_PER_UNIT 32
|
||||
#define UNIT_COUNT_PER_PACKET 1024
|
||||
#define BYTES_COUNTER 8
|
||||
#define CUMSUM_THREADS_PER_BLOCK 128
|
||||
|
||||
#define UNIT_PER_ITER 256
|
||||
#define STATIC_COPY_PER_ITER 128
|
||||
|
||||
static constexpr int THREADS_PER_PIPELINE = THREADS_PER_UNIT * UNIT_PER_PIPELINE;
|
||||
static constexpr int THREADS_PER_CTA = THREADS_PER_PIPELINE * PIPELINE_PER_CTA;
|
||||
|
||||
template <int UNIT_SIZE_INPUT, int PACKET_PER_STEP_INPUT>
|
||||
struct PipelineConfig
|
||||
{
|
||||
static constexpr int UNIT_SIZE = UNIT_SIZE_INPUT;
|
||||
static constexpr int PACKET_PER_STEP = PACKET_PER_STEP_INPUT;
|
||||
static constexpr int UNIT_BYTES_SIZE = UNIT_SIZE * UNIT_PER_ITER * (sizeof(int) + sizeof(float));
|
||||
static constexpr int SCALE_OFFSET = UNIT_SIZE * UNIT_PER_ITER * sizeof(int);
|
||||
static constexpr int STATIC_COPY_OFFSET = UNIT_SIZE * UNIT_PER_ITER * (sizeof(int) + sizeof(float));
|
||||
static constexpr int PACKET_SIZE = UNIT_BYTES_SIZE + STATIC_COPY_PER_ITER * 4 * sizeof(int);
|
||||
static constexpr int PACKET_SIZE_IN_U64 = (PACKET_SIZE / 8);
|
||||
};
|
||||
|
||||
// 1MB FIFO size
|
||||
static constexpr int FIFO_SIZE_IN_U64 = 1024 * 1024 / 8;
|
||||
static constexpr int THREADS_PER_PIPELINE = UNIT_PER_PIPELINE;
|
||||
|
||||
#ifdef __CUDACC__
|
||||
#define ALIGN_256 __align__(256)
|
||||
@ -67,9 +42,9 @@ static constexpr int FIFO_SIZE_IN_U64 = 1024 * 1024 / 8;
|
||||
|
||||
struct ALIGN_256 MoeCommFifoConnInfo
|
||||
{
|
||||
volatile uint64_t head; // write position
|
||||
volatile uint64_t tail; // read position
|
||||
volatile uint64_t count; // for counter
|
||||
volatile uint64_t head; // write position
|
||||
volatile uint64_t tail; // read position
|
||||
int volatile values[512]; // for values
|
||||
};
|
||||
|
||||
struct MoeCommWorkspace
|
||||
@ -77,25 +52,11 @@ struct MoeCommWorkspace
|
||||
uint64_t* workspacePtr;
|
||||
size_t rankStrideInU64;
|
||||
#ifdef __CUDACC__
|
||||
__inline__ __device__ uint64_t* getFifoBasePtr(
|
||||
bool isSender, int epRank, int peerRank, int channel, int channelCount) const
|
||||
{
|
||||
// fifo itself is in receiver's side.
|
||||
if (isSender)
|
||||
{
|
||||
return workspacePtr + peerRank * rankStrideInU64 + (epRank * channelCount + channel) * FIFO_SIZE_IN_U64;
|
||||
}
|
||||
else
|
||||
{
|
||||
return workspacePtr + epRank * rankStrideInU64 + (peerRank * channelCount + channel) * FIFO_SIZE_IN_U64;
|
||||
}
|
||||
}
|
||||
|
||||
__inline__ __device__ MoeCommFifoConnInfo* getFifoConnInfo(
|
||||
bool isSender, int epRank, int peerRank, int channel, int epSize, int channelCount) const
|
||||
{
|
||||
// fifoInfo is in sender's side.
|
||||
uint64_t* fifoInfoPtrU64 = workspacePtr + FIFO_SIZE_IN_U64 * channelCount * epSize;
|
||||
uint64_t* fifoInfoPtrU64 = workspacePtr;
|
||||
int strideIndice = isSender ? epRank : peerRank;
|
||||
int fifoInfoIndice = isSender ? peerRank : epRank;
|
||||
fifoInfoPtrU64 += strideIndice * rankStrideInU64;
|
||||
@ -108,8 +69,9 @@ struct MoeCommWorkspace
|
||||
};
|
||||
|
||||
void computeCountAndIndice(int* experts, int* sendCounts, int* recvCounts, int* sendIndiceWorkspace,
|
||||
int* backwardIndiceWorkspace, int* recvIndiceWorkspace, MoeCommWorkspace workspace, int tokenCount,
|
||||
int maxTokenCountPerRank, int topK, int expert_count, int rankId, int rankCount, cudaStream_t stream);
|
||||
int* backwardIndiceWorkspace, int* recvIndiceWorkspace, int* expertStatics, int* gatheredExpertStatics,
|
||||
MoeCommWorkspace workspace, int tokenCount, int maxTokenCountPerRank, int topK, int slotCount, int expertCount,
|
||||
int rankId, int rankCount, cudaStream_t stream);
|
||||
|
||||
void computeCumsum(int* sendCountsCumsum, int* recvCountsCumsum, int rankId, int rankCount, cudaStream_t stream);
|
||||
|
||||
@ -117,10 +79,8 @@ void moveIndice(int* sendCountsCumsum, int* recvCountsCumsum, int* sendIndice, i
|
||||
int* backwardIndice, int* gatherBackwardIndice, int* recvIndice, int* gatherRecvIndice, int rankId, int rankCount,
|
||||
int maxTokenCountPerRank, cudaStream_t stream);
|
||||
|
||||
void allToAllMetadata(int* sendExperts, int* recvExperts, float* sendScales, float* recvScales, int* localExpertStatics,
|
||||
int* gatheredExpertStatics, MoeCommWorkspace workspace, int* sendCountsCumsum, int* localSendIndice,
|
||||
int* recvCountsCumsum, int* localRecvIndice, int tokenCount, int maxTokenCountPerRank, int topK, int expertCount,
|
||||
int slotCount, int rankId, int rankCount, cudaStream_t stream);
|
||||
void memsetExpertIds(int* expertIds, int* recvCountsCumsum, int maxTokenCountPerRank, int topK, int slotCount,
|
||||
int epSize, cudaStream_t stream);
|
||||
|
||||
size_t getMoePrepareWorkspaceSize(int epSize);
|
||||
|
||||
|
||||
@ -16,7 +16,7 @@
|
||||
*/
|
||||
|
||||
#include "tensorrt_llm/common/opUtils.h"
|
||||
#include "tensorrt_llm/kernels/moeCommKernels.h"
|
||||
#include "tensorrt_llm/kernels/fusedMoeCommKernels.h"
|
||||
#include "tensorrt_llm/kernels/moePrepareKernels.h"
|
||||
#include "tensorrt_llm/runtime/torchUtils.h"
|
||||
#include "tensorrt_llm/thop/thUtils.h"
|
||||
@ -28,180 +28,104 @@
|
||||
namespace torch_ext
|
||||
{
|
||||
|
||||
std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>
|
||||
moeCommPrepareIndicesOp(torch::Tensor gatheredTargetRankIds, c10::optional<torch::Tensor> realRankTokenCountCumSum,
|
||||
int64_t maxTokenCountPerRank, int64_t expertCount, int64_t topK, int64_t epRank, int64_t epSize)
|
||||
void setMoeCommFieldInfo(tensorrt_llm::kernels::MoeCommFieldInfo& fieldInfo, torch::Tensor const& tensor)
|
||||
{
|
||||
CHECK_INPUT(gatheredTargetRankIds, torch::kInt32);
|
||||
TORCH_CHECK(gatheredTargetRankIds.dim() == 2, "gatheredTargetRankIds must be a 2D tensor");
|
||||
TORCH_CHECK(gatheredTargetRankIds.size(1) == topK, "gatheredTargetRankIds must have topK columns");
|
||||
|
||||
int const* realRankTokenCountCumSumPtr = nullptr;
|
||||
if (realRankTokenCountCumSum.has_value())
|
||||
{
|
||||
TORCH_CHECK(realRankTokenCountCumSum.value().dim() == 1, "realRankTokenCountCumSum must be a 1D tensor");
|
||||
TORCH_CHECK(realRankTokenCountCumSum.value().dtype() == torch::kInt32,
|
||||
"realRankTokenCountCumSum must be a int32 tensor");
|
||||
TORCH_CHECK(
|
||||
realRankTokenCountCumSum.value().size(0) == epSize, "realRankTokenCountCumSum must have epSize elements");
|
||||
realRankTokenCountCumSumPtr = realRankTokenCountCumSum.value().data_ptr<int>();
|
||||
}
|
||||
else
|
||||
{
|
||||
TORCH_CHECK(gatheredTargetRankIds.size(0) == epSize * maxTokenCountPerRank,
|
||||
"gatheredTargetRankIds should have shape (epSize * maxTokenCountPerRank, topK)");
|
||||
}
|
||||
TORCH_CHECK(maxTokenCountPerRank > 0, "maxTokenCountPerRank must be greater than 0");
|
||||
TORCH_CHECK(expertCount > 0, "expertCount must be greater than 0");
|
||||
TORCH_CHECK(topK > 0, "topK must be greater than 0");
|
||||
TORCH_CHECK(topK <= expertCount, "topK must be less than or equal to expertCount");
|
||||
TORCH_CHECK(epRank >= 0 && epRank < epSize, "epRank must be in the range [0, epSize)");
|
||||
|
||||
auto stream = at::cuda::getCurrentCUDAStream();
|
||||
|
||||
int maxSendRanksPerToken = std::max(epSize, topK);
|
||||
|
||||
torch::Tensor localGatherIndices
|
||||
= torch::empty({maxTokenCountPerRank * epSize}, gatheredTargetRankIds.options().dtype(torch::kInt32));
|
||||
torch::Tensor sendRankCountCumSum = torch::empty({epSize}, gatheredTargetRankIds.options().dtype(torch::kInt32));
|
||||
torch::Tensor sendRankLocalIndices = torch::empty(
|
||||
{maxTokenCountPerRank * maxSendRanksPerToken}, gatheredTargetRankIds.options().dtype(torch::kInt32));
|
||||
torch::Tensor recvRankCountCumSum = torch::empty({epSize}, gatheredTargetRankIds.options().dtype(torch::kInt32));
|
||||
torch::Tensor recvRankLocalIndices
|
||||
= torch::empty({maxTokenCountPerRank * epSize}, gatheredTargetRankIds.options().dtype(torch::kInt32));
|
||||
torch::Tensor backwardRecvRankLocalIndices = torch::empty(
|
||||
{maxTokenCountPerRank * maxSendRanksPerToken}, gatheredTargetRankIds.options().dtype(torch::kInt32));
|
||||
|
||||
tensorrt_llm::kernels::MoeExpertParallelInfo expertParallelInfo;
|
||||
expertParallelInfo.expertCount = expertCount;
|
||||
expertParallelInfo.topK = topK;
|
||||
|
||||
tensorrt_llm::kernels::MoeEpWorldInfo worldInfo = {static_cast<int>(epSize), static_cast<int>(epRank)};
|
||||
tensorrt_llm::kernels::moeAllToAllPrepareIndices(worldInfo, expertParallelInfo, maxTokenCountPerRank,
|
||||
gatheredTargetRankIds.data_ptr<int>(), realRankTokenCountCumSumPtr, localGatherIndices.data_ptr<int>(),
|
||||
sendRankCountCumSum.data_ptr<int>(), sendRankLocalIndices.data_ptr<int>(), recvRankCountCumSum.data_ptr<int>(),
|
||||
recvRankLocalIndices.data_ptr<int>(), backwardRecvRankLocalIndices.data_ptr<int>(), stream);
|
||||
|
||||
return std::make_tuple(localGatherIndices, sendRankCountCumSum, sendRankLocalIndices, recvRankCountCumSum,
|
||||
recvRankLocalIndices, backwardRecvRankLocalIndices);
|
||||
TORCH_CHECK(tensor.dim() == 2, "tensor must be a 2D tensor");
|
||||
int eltSize = tensor.dtype().itemsize();
|
||||
fieldInfo.fillFieldInfo(static_cast<uint8_t*>(tensor.data_ptr()), eltSize, tensor.size(1), tensor.stride(0));
|
||||
}
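
// Editor-added usage sketch (hypothetical tensor and sizes, shown as a comment so it does not
// alter the translation unit): every communicated field is a 2D tensor whose row i belongs to
// token i, so only the element size, the row length and the row stride are recorded.
//
//   torch::Tensor hiddenStates = torch::empty({tokenCount, hiddenSize},
//       torch::dtype(torch::kBFloat16).device(torch::kCUDA));
//   tensorrt_llm::kernels::MoeCommFieldInfo fieldInfo;
//   setMoeCommFieldInfo(fieldInfo, hiddenStates); // captures data_ptr(), 2-byte elements,
//                                                 // hiddenSize columns and the row pitch stride(0)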
|
||||
|
||||
void moeLocalGatherOp(torch::Tensor recvRankCumSum, torch::Tensor localGatherIndices, torch::Tensor gatheredExpertIds,
|
||||
c10::optional<torch::Tensor> gatheredScales, torch::Tensor localExpertIds, c10::optional<torch::Tensor> localScales,
|
||||
int64_t maxTokenCountPerRank, int64_t expertCount, int64_t topK, int64_t epRank, int64_t epSize)
|
||||
{
|
||||
CHECK_INPUT(recvRankCumSum, torch::kInt32);
|
||||
CHECK_INPUT(localGatherIndices, torch::kInt32);
|
||||
CHECK_INPUT(gatheredExpertIds, torch::kInt32);
|
||||
CHECK_INPUT(localExpertIds, torch::kInt32);
|
||||
|
||||
TORCH_CHECK(maxTokenCountPerRank > 0, "maxTokenCountPerRank must be greater than 0");
|
||||
TORCH_CHECK(expertCount > 0, "expertCount must be greater than 0");
|
||||
TORCH_CHECK(topK > 0, "topK must be greater than 0");
|
||||
TORCH_CHECK(topK <= expertCount, "topK must be less than or equal to expertCount");
|
||||
TORCH_CHECK(epRank >= 0 && epRank < epSize, "epRank must be in the range [0, epSize)");
|
||||
|
||||
TORCH_CHECK(recvRankCumSum.dim() == 1, "recvRankCumSum must be a 1D tensor");
|
||||
TORCH_CHECK(recvRankCumSum.size(0) == epSize, "recvRankCumSum must have epSize elements");
|
||||
TORCH_CHECK(localGatherIndices.dim() == 1, "localGatherIndices must be a 1D tensor");
|
||||
TORCH_CHECK(gatheredExpertIds.dim() == 2, "gatheredExpertIds must be a 2D tensor");
|
||||
TORCH_CHECK(localExpertIds.dim() == 2, "localExpertIds must be a 2D tensor");
|
||||
TORCH_CHECK(gatheredExpertIds.size(1) == topK, "gatheredExpertIds must have topK columns");
|
||||
TORCH_CHECK(localExpertIds.size(1) == topK, "localExpertIds must have topK columns");
|
||||
|
||||
int localMaxTokenCount = static_cast<int>(localGatherIndices.size(0));
|
||||
TORCH_CHECK(localExpertIds.size(0) == localMaxTokenCount, "localExpertIds must have localMaxTokenCount rows");
|
||||
|
||||
TORCH_CHECK(gatheredScales.has_value() == localScales.has_value(),
|
||||
"gatheredScales and localScales must be both valid or both invalid");
|
||||
float const* gatheredScalesPtr = nullptr;
|
||||
float* localScalesPtr = nullptr;
|
||||
if (gatheredScales.has_value())
|
||||
{
|
||||
CHECK_INPUT(gatheredScales.value(), torch::kFloat32);
|
||||
CHECK_INPUT(localScales.value(), torch::kFloat32);
|
||||
|
||||
TORCH_CHECK(gatheredScales->dim() == 2, "gatheredScales must be a 2D tensor");
|
||||
TORCH_CHECK(gatheredScales->size(1) == topK, "gatheredScales must have topK columns");
|
||||
TORCH_CHECK(localScales->dim() == 2, "localScales must be a 2D tensor");
|
||||
TORCH_CHECK(localScales->size(1) == topK, "localScales must have topK columns");
|
||||
TORCH_CHECK(localScales->size(0) == localMaxTokenCount, "localScales must have localMaxTokenCount rows");
|
||||
|
||||
gatheredScalesPtr = gatheredScales->data_ptr<float>();
|
||||
localScalesPtr = localScales->data_ptr<float>();
|
||||
}
|
||||
|
||||
auto stream = at::cuda::getCurrentCUDAStream();
|
||||
|
||||
tensorrt_llm::kernels::MoeExpertParallelInfo expertParallelInfo;
|
||||
expertParallelInfo.expertCount = expertCount;
|
||||
expertParallelInfo.topK = topK;
|
||||
|
||||
tensorrt_llm::kernels::MoeEpWorldInfo worldInfo = {static_cast<int>(epSize), static_cast<int>(epRank)};
|
||||
tensorrt_llm::kernels::moeLocalGather(worldInfo, expertParallelInfo, maxTokenCountPerRank, localMaxTokenCount,
|
||||
recvRankCumSum.data_ptr<int>(), localGatherIndices.data_ptr<int>(), gatheredExpertIds.data_ptr<int>(),
|
||||
gatheredScalesPtr, localExpertIds.data_ptr<int>(), localScalesPtr, stream);
|
||||
}
|
||||
|
||||
void moeCommOp(torch::Tensor input, torch::Tensor sendRankCumSum, torch::Tensor sendIndices, torch::Tensor output,
|
||||
torch::Tensor recvRankCumSum, torch::Tensor recvIndices, torch::Tensor allWorkspaces, int64_t epRank,
|
||||
int64_t epSize)
|
||||
c10::List<torch::Tensor> moeCommOp(c10::List<torch::Tensor> inputs, torch::Tensor sendRankCumSum,
|
||||
torch::Tensor sendIndiceTensor, torch::Tensor recvRankCumSum, torch::Tensor recvIndiceTensor,
|
||||
torch::Tensor allWorkspaces, int64_t outputAllocationCount, int64_t epRank, int64_t epSize,
|
||||
std::optional<c10::List<bool>> needZeroOutput = std::nullopt)
|
||||
{
|
||||
CHECK_INPUT(sendRankCumSum, torch::kInt32);
|
||||
CHECK_INPUT(sendIndices, torch::kInt32);
|
||||
CHECK_INPUT(sendIndiceTensor, torch::kInt32);
|
||||
CHECK_INPUT(recvRankCumSum, torch::kInt32);
|
||||
CHECK_INPUT(recvIndices, torch::kInt32);
|
||||
// allWorkspaces is a uint64 tensor, but may not be contiguous
|
||||
TORCH_CHECK(allWorkspaces.dtype() == torch::kUInt64, "allWorkspaces must be a uint64 tensor");
|
||||
CHECK_INPUT(recvIndiceTensor, torch::kInt32);
|
||||
|
||||
TORCH_CHECK(input.dim() == 2, "input must be a 2D tensor");
|
||||
TORCH_CHECK(output.dim() == 2, "output must be a 2D tensor");
|
||||
TORCH_CHECK(sendRankCumSum.dim() == 1, "sendRankCumSum must be a 1D tensor");
|
||||
TORCH_CHECK(sendIndices.dim() == 1, "sendIndices must be a 1D tensor");
|
||||
TORCH_CHECK(sendIndiceTensor.dim() == 1, "sendIndices must be a 1D tensor");
|
||||
TORCH_CHECK(recvRankCumSum.dim() == 1, "recvRankCumSum must be a 1D tensor");
|
||||
TORCH_CHECK(recvIndices.dim() == 1, "recvIndices must be a 1D tensor");
|
||||
TORCH_CHECK(recvIndiceTensor.dim() == 1, "recvIndices must be a 1D tensor");
|
||||
TORCH_CHECK(allWorkspaces.dim() == 2, "allWorkspaces must be a 2D tensor");
|
||||
|
||||
TORCH_CHECK(input.size(1) == output.size(1), "input and output must have the same second dimension");
|
||||
    TORCH_CHECK(sendRankCumSum.size(0) == epSize, "sendRankCumSum must have epSize elements");
    TORCH_CHECK(recvRankCumSum.size(0) == epSize, "recvRankCumSum must have epSize elements");
    TORCH_CHECK(allWorkspaces.size(0) == epSize, "allWorkspaces must have epSize elements");

    TORCH_CHECK(epRank >= 0 && epRank < epSize, "epRank must be in the range [0, epSize)");
    TORCH_CHECK(!needZeroOutput.has_value() || needZeroOutput.value().size() == inputs.size(),
        "needZeroOutput should have same length as inputs");
    c10::List<torch::Tensor> outputs;

    tensorrt_llm::kernels::MoeEpWorldInfo worldInfo = {static_cast<int>(epSize), static_cast<int>(epRank)};
    tensorrt_llm::kernels::SendRecvDataInfo sendRecvDataInfo;
    tensorrt_llm::kernels::MoeEpWorldInfo epWorldInfo = {static_cast<int>(epSize), static_cast<int>(epRank)};
    tensorrt_llm::kernels::FusedMoeWorldInfo worldInfo = {epWorldInfo};

    size_t eltSize = input.dtype().itemsize();
    size_t eltCountPerU64 = sizeof(uint64_t) / eltSize;
    TORCH_CHECK(input.size(1) % (eltCountPerU64 * 2) == 0, "input.size(1) must be aligned to 16 bytes");
    sendRecvDataInfo.vectorSizeInU64 = input.size(1) / eltCountPerU64;
    sendRecvDataInfo.DoPreCompute();
    tensorrt_llm::kernels::SendRecvIndices sendIndices, recvIndices;
    sendIndices.rankCountCumSum = sendRankCumSum.data_ptr<int>();
    sendIndices.rankLocalIndices = sendIndiceTensor.data_ptr<int>();

    tensorrt_llm::kernels::SendRecvDispls sendDispls, recvDispls;
    sendDispls.dataPtr = static_cast<uint64_t*>(input.data_ptr());
    sendDispls.rankCountCumSum = sendRankCumSum.data_ptr<int>();
    sendDispls.rankLocalIndices = sendIndices.data_ptr<int>();
    sendDispls.vectorStrideInU64 = input.stride(0) / eltCountPerU64;
    recvIndices.rankCountCumSum = recvRankCumSum.data_ptr<int>();
    recvIndices.rankLocalIndices = recvIndiceTensor.data_ptr<int>();

    recvDispls.dataPtr = static_cast<uint64_t*>(output.data_ptr());
    recvDispls.rankCountCumSum = recvRankCumSum.data_ptr<int>();
    recvDispls.rankLocalIndices = recvIndices.data_ptr<int>();
    recvDispls.vectorStrideInU64 = output.stride(0) / eltCountPerU64;
    int fieldCount = inputs.size();
    TORCH_CHECK(fieldCount <= tensorrt_llm::kernels::MOE_COMM_FIELD_MAX_COUNT, "Number of fields (", fieldCount,
        ") exceeds maximum allowed (", tensorrt_llm::kernels::MOE_COMM_FIELD_MAX_COUNT, ")");
    tensorrt_llm::kernels::FusedMoeFieldInfo sendFieldInfo, recvFieldInfo;
    sendFieldInfo.isBasicInterleaved = false;
    recvFieldInfo.isBasicInterleaved = false;
    sendFieldInfo.fieldCount = fieldCount;
    recvFieldInfo.fieldCount = fieldCount;
    sendFieldInfo.expertScales = nullptr;
    recvFieldInfo.expertScales = nullptr;
    sendFieldInfo.tokenSelectedSlots = nullptr;
    recvFieldInfo.tokenSelectedSlots = nullptr;

    tensorrt_llm::kernels::MoeCommWorkspace workspace;
    workspace.workspacePtr = allWorkspaces.data_ptr<uint64_t>();
    workspace.rankStrideInU64 = allWorkspaces.stride(0);
    for (int i = 0; i < fieldCount; i++)
    {
        torch::Tensor const& t = inputs[i];
        setMoeCommFieldInfo(sendFieldInfo.fieldsInfo[i], t);
        if (needZeroOutput.has_value() && needZeroOutput.value()[i])
        {
            outputs.push_back(torch::zeros({outputAllocationCount, t.size(1)}, t.options()));
        }
        else
        {
            outputs.push_back(torch::empty({outputAllocationCount, t.size(1)}, t.options()));
        }
        setMoeCommFieldInfo(recvFieldInfo.fieldsInfo[i], outputs[i]);
    }
    sendFieldInfo.fillFieldPlacementInfo(0, false);
    recvFieldInfo.fillFieldPlacementInfo(0, false);

    tensorrt_llm::kernels::FusedMoeCommKernelParam params;
    params.worldInfo = worldInfo;
    params.sendIndices = sendIndices;
    params.recvIndices = recvIndices;
    params.sendFieldInfo = sendFieldInfo;
    params.recvFieldInfo = recvFieldInfo;
    // Do not need expertParallelInfo for fused moe comm now

    params.sendFieldInfo.fillMetaInfo(&(params.sendCommMeta), params.expertParallelInfo.topK, false, false);
    params.recvFieldInfo.fillMetaInfo(&(params.recvCommMeta), params.expertParallelInfo.topK, false, false);

    tensorrt_llm::kernels::FusedMoeWorkspace fusedMoeWorkspace;
    tensorrt_llm::kernels::constructWorkspace(
        &fusedMoeWorkspace, allWorkspaces.data_ptr<uint64_t>(), allWorkspaces.stride(0), epSize);

    auto stream = at::cuda::getCurrentCUDAStream();

    tensorrt_llm::kernels::moeAllToAll(worldInfo, sendRecvDataInfo, sendDispls, recvDispls, workspace, stream);
    tensorrt_llm::kernels::moeAllToAll(params, fusedMoeWorkspace, stream);

    return outputs;
}
int64_t getWorkspaceSizePerRank(int64_t epSize)
{
    int epSize32 = static_cast<int>(epSize);
    return tensorrt_llm::kernels::getMoeCommWorkspaceSize(epSize32);
    return tensorrt_llm::kernels::getFusedMoeCommWorkspaceSize(epSize32);
}

void setMaxUsableSmCount(int64_t maxSmCount)
@ -215,15 +139,29 @@ int64_t getPrepareWorkspaceSizePerRank(int64_t epSize)
    return tensorrt_llm::kernels::moe_prepare::getMoePrepareWorkspaceSize(epSize32);
}

std::tuple<torch::Tensor, c10::optional<torch::Tensor>, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor,
    torch::Tensor, c10::optional<torch::Tensor>>
moePrepareOp(torch::Tensor expertsIds, c10::optional<torch::Tensor> scales, c10::optional<torch::Tensor> expertsStatics,
    torch::Tensor allWorkspaces, int64_t maxTokenCountPerRank, int64_t epRank, int64_t epSize, int64_t expertCount,
    int64_t slotCount, int64_t topK)
void initializeMoeWorkspace(torch::Tensor allWorkspaces, int64_t epRank, int64_t epSize)
{
    TORCH_CHECK(allWorkspaces.dim() == 2, "allWorkspaces must be a 2D tensor");
    TORCH_CHECK(epRank >= 0 && epRank < epSize, "epRank must be in the range [0, epSize)");

    tensorrt_llm::kernels::MoeEpWorldInfo epWorldInfo = {static_cast<int>(epSize), static_cast<int>(epRank)};
    tensorrt_llm::kernels::FusedMoeWorldInfo worldInfo = {epWorldInfo};

    tensorrt_llm::kernels::FusedMoeWorkspace fusedMoeWorkspace;
    tensorrt_llm::kernels::constructWorkspace(
        &fusedMoeWorkspace, allWorkspaces.data_ptr<uint64_t>(), allWorkspaces.stride(0), epSize);

    tensorrt_llm::kernels::initializeFusedMoeLocalWorkspace(&fusedMoeWorkspace, worldInfo);
}
std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, c10::optional<torch::Tensor>>
moePrepareOp(torch::Tensor expertsIds, c10::optional<torch::Tensor> expertsStatics, torch::Tensor allWorkspaces,
    int64_t maxTokenCountPerRank, int64_t epRank, int64_t epSize, int64_t expertCount, int64_t slotCount, int64_t topK)
{
    CHECK_INPUT(expertsIds, torch::kInt32);
    TORCH_CHECK(expertCount % 4 == 0, "expertCount must be divisible by 4");
    TORCH_CHECK(slotCount % 4 == 0, "slotCount must be divisible by 4");
    TORCH_CHECK(expertCount + 1 <= 512, "expertCount + 1 is larger than 512");

    int64_t maxSendRanksPerToken = std::max(epSize, topK);
    int64_t tokenCount = expertsIds.size(0);
@ -249,18 +187,6 @@ moePrepareOp(torch::Tensor expertsIds, c10::optional<torch::Tensor> scales, c10:
    torch::Tensor sendRankIndices
        = torch::empty({maxTokenCountPerRank * maxSendRanksPerToken}, expertsIds.options().dtype(torch::kInt32));

    c10::optional<torch::Tensor> preparedLocalScales;
    float* scalesPtr = nullptr;
    float* preparedLocalScalesPtr = nullptr;
    if (scales.has_value())
    {
        CHECK_INPUT(scales.value(), torch::kFloat32);
        scalesPtr = scales->data_ptr<float>();
        preparedLocalScales
            = torch::empty({maxTokenCountPerRank * epSize, topK}, expertsIds.options().dtype(torch::kFloat32));
        preparedLocalScalesPtr = preparedLocalScales->data_ptr<float>();
    }

    int* localExpertStaticsPtr = nullptr;
    int* gatheredExpertStaticsPtr = nullptr;
    c10::optional<torch::Tensor> gatheredExpertStatics;
@ -279,8 +205,9 @@ moePrepareOp(torch::Tensor expertsIds, c10::optional<torch::Tensor> scales, c10:

    tensorrt_llm::kernels::moe_prepare::computeCountAndIndice(expertsIds.data_ptr<int>(),
        sendRankCountCumSum.data_ptr<int>(), RecvRankCountCumSum.data_ptr<int>(), sendRankIndices.data_ptr<int>(),
        backwardRecvRankIndices.data_ptr<int>(), recvRankIndices.data_ptr<int>(), workspace, tokenCount,
        maxTokenCountPerRank, topK, slotCount, epRank, epSize, stream);
        backwardRecvRankIndices.data_ptr<int>(), recvRankIndices.data_ptr<int>(), localExpertStaticsPtr,
        gatheredExpertStaticsPtr, workspace, tokenCount, maxTokenCountPerRank, topK, slotCount, expertCount, epRank,
        epSize, stream);

    tensorrt_llm::kernels::moe_prepare::computeCumsum(
        sendRankCountCumSum.data_ptr<int>(), RecvRankCountCumSum.data_ptr<int>(), epRank, epSize, stream);
@ -291,14 +218,28 @@ moePrepareOp(torch::Tensor expertsIds, c10::optional<torch::Tensor> scales, c10:
        recvRankIndices.data_ptr<int>(), gatherRecvRankIndices.data_ptr<int>(), epRank, epSize, maxTokenCountPerRank,
        stream);

    tensorrt_llm::kernels::moe_prepare::allToAllMetadata(expertsIds.data_ptr<int>(),
        preparedLocalExpertIds.data_ptr<int>(), scalesPtr, preparedLocalScalesPtr, localExpertStaticsPtr,
        gatheredExpertStaticsPtr, workspace, sendRankCountCumSum.data_ptr<int>(), sendRankIndices.data_ptr<int>(),
        RecvRankCountCumSum.data_ptr<int>(), recvRankIndices.data_ptr<int>(), tokenCount, maxTokenCountPerRank, topK,
        expertCount, slotCount, epRank, epSize, stream);
    return std::make_tuple(sendRankCountCumSum, gatherSendRankIndices, RecvRankCountCumSum, gatherRecvRankIndices,
        gatherBackwardRecvRankIndices, gatheredExpertStatics);
}

    return std::make_tuple(preparedLocalExpertIds, preparedLocalScales, sendRankCountCumSum, gatherSendRankIndices,
        RecvRankCountCumSum, gatherRecvRankIndices, gatherBackwardRecvRankIndices, gatheredExpertStatics);
void memsetExpertIds(torch::Tensor expertsIds, torch::Tensor recvRankCountCumSum, int64_t maxTokenCountPerRank,
    int64_t topK, int64_t slotCount, int64_t epSize)
{
    CHECK_INPUT(expertsIds, torch::kInt32);
    TORCH_CHECK(expertsIds.dim() == 2, "expertsIds must be a 2D tensor");
    TORCH_CHECK(
        expertsIds.size(0) == maxTokenCountPerRank * epSize, "expertsIds must have maxTokenCountPerRank * epSize rows");
    TORCH_CHECK(expertsIds.size(1) == topK, "expertsIds must have topK columns");

    CHECK_INPUT(recvRankCountCumSum, torch::kInt32);
    TORCH_CHECK(recvRankCountCumSum.dim() == 1, "recvRankCountCumSum must be a 1D tensor");
    TORCH_CHECK(recvRankCountCumSum.size(0) == epSize, "recvRankCountCumSum must have epSize elements");

    auto stream = at::cuda::getCurrentCUDAStream();

    tensorrt_llm::kernels::moe_prepare::memsetExpertIds(expertsIds.data_ptr<int>(), recvRankCountCumSum.data_ptr<int>(),
        static_cast<int>(maxTokenCountPerRank), static_cast<int>(topK), static_cast<int>(slotCount),
        static_cast<int>(epSize), stream);
}

} // namespace torch_ext
@ -306,34 +247,9 @@ moePrepareOp(torch::Tensor expertsIds, c10::optional<torch::Tensor> scales, c10:
TORCH_LIBRARY_FRAGMENT(trtllm, m)
{
    m.def(
        "moe_comm_prepare_indices(Tensor gathered_target_rank_ids, Tensor? real_rank_token_count_cum_sum, int "
        "max_token_count_per_rank, int expert_count, int top_k, int ep_rank, int ep_size) -> (Tensor, Tensor, Tensor, "
        "Tensor, Tensor, Tensor)");
}

TORCH_LIBRARY_IMPL(trtllm, CUDA, m)
{
    m.impl("moe_comm_prepare_indices", &torch_ext::moeCommPrepareIndicesOp);
}

TORCH_LIBRARY_FRAGMENT(trtllm, m)
{
    m.def(
        "moe_local_gather(Tensor recv_rank_cum_sum, Tensor local_gather_indices, Tensor gathered_expert_ids, Tensor? "
        "gathered_scales, Tensor local_expert_ids, Tensor? local_scales, int max_token_count_per_rank, int "
        "expert_count, int top_k, int ep_rank, int ep_size) -> ()");
}

TORCH_LIBRARY_IMPL(trtllm, CUDA, m)
{
    m.impl("moe_local_gather", &torch_ext::moeLocalGatherOp);
}

TORCH_LIBRARY_FRAGMENT(trtllm, m)
{
    m.def(
        "moe_comm(Tensor input, Tensor send_rank_cum_sum, Tensor send_indices, Tensor output, Tensor "
        "recv_rank_cum_sum, Tensor recv_indices, Tensor all_workspaces, int ep_rank, int ep_size) -> ()");
        "moe_comm(Tensor[] inputs, Tensor send_rank_cum_sum, Tensor send_indices, Tensor "
        "recv_rank_cum_sum, Tensor recv_indices, Tensor all_workspaces, int output_allocation_count, int ep_rank, int "
        "ep_size, bool[]? need_zero_output=None) -> Tensor[]");
}

TORCH_LIBRARY_IMPL(trtllm, CUDA, m)
@ -341,6 +257,16 @@ TORCH_LIBRARY_IMPL(trtllm, CUDA, m)
    m.impl("moe_comm", &torch_ext::moeCommOp);
}

TORCH_LIBRARY_FRAGMENT(trtllm, m)
{
    m.def("moe_initialize_workspace(Tensor(a!) all_workspaces, int ep_rank, int ep_size) -> ()");
}

TORCH_LIBRARY_IMPL(trtllm, CUDA, m)
{
    m.impl("moe_initialize_workspace", &torch_ext::initializeMoeWorkspace);
}

TORCH_LIBRARY_FRAGMENT(trtllm, m)
{
    m.def("get_moe_commworkspace_size_per_rank(int ep_size) -> int");
@ -364,9 +290,9 @@ TORCH_LIBRARY_IMPL(trtllm, CompositeExplicitAutograd, m)
TORCH_LIBRARY_FRAGMENT(trtllm, m)
{
    m.def(
        "mnnvl_moe_alltoallv_prepare_without_allgather(Tensor experts_ids, Tensor? scales, Tensor? experts_statics, "
        "mnnvl_moe_alltoallv_prepare_without_allgather(Tensor experts_ids, Tensor? experts_statics, "
        "Tensor allWorkspace, int max_token_count_per_rank, int ep_rank, int ep_size, int expert_count, int "
        "slot_count, int top_k) -> (Tensor, Tensor?, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor?)");
        "slot_count, int top_k) -> (Tensor, Tensor, Tensor, Tensor, Tensor, Tensor?)");
}

TORCH_LIBRARY_IMPL(trtllm, CUDA, m)
@ -374,6 +300,19 @@ TORCH_LIBRARY_IMPL(trtllm, CUDA, m)
    m.impl("mnnvl_moe_alltoallv_prepare_without_allgather", &torch_ext::moePrepareOp);
}

TORCH_LIBRARY_FRAGMENT(trtllm, m)
{
    m.def(
        "memset_expert_ids(Tensor(a!) experts_ids, Tensor recv_rank_count_cumsum, int max_token_count_per_rank, int "
        "top_k, "
        "int slot_count, int ep_size) -> ()");
}

TORCH_LIBRARY_IMPL(trtllm, CUDA, m)
{
    m.impl("memset_expert_ids", &torch_ext::memsetExpertIds);
}

TORCH_LIBRARY_FRAGMENT(trtllm, m)
{
    m.def("get_moe_prepare_workspace_size_per_rank(int ep_size) -> int");
@ -16,7 +16,6 @@
|
||||
*/
|
||||
|
||||
#include "tensorrt_llm/common/opUtils.h"
|
||||
#include "tensorrt_llm/kernels/moeCommKernels.h"
|
||||
#include "tensorrt_llm/runtime/torchUtils.h"
|
||||
#include "tensorrt_llm/thop/thUtils.h"
|
||||
|
||||
|
||||
@ -69,32 +69,6 @@ function(add_gtest test_name test_src)
|
||||
add_dependencies(google-tests ${test_name})
|
||||
endfunction()
|
||||
|
||||
add_subdirectory(unit_tests)
|
||||
|
||||
add_gtest(mpiUtilsTest runtime/mpiUtilsTest.cpp)
|
||||
|
||||
add_gtest(gptDecoderTest runtime/gptDecoderTest.cpp)
|
||||
add_gtest(gptDecoderBatchedTest runtime/gptDecoderBatchedTest.cpp)
|
||||
add_gtest(medusaModuleTest runtime/medusaModuleTest.cpp)
|
||||
|
||||
add_gtest(moeLoadBalancerTest runtime/moeLoadBalancerTest.cpp)
|
||||
|
||||
add_gtest(sanitizerTest runtime/sanitizerTest.cpp)
|
||||
|
||||
add_gtest(eaglePackDataTest kernels/eaglePackDataTest.cpp)
|
||||
|
||||
add_gtest(medusaDecodeLayerTest layers/medusaDecodeLayerTest.cpp)
|
||||
|
||||
add_gtest(moeLoadBalanceKernelTest kernels/moeLoadBalanceKernelTest.cpp)
|
||||
|
||||
add_gtest(eagleLayerTest layers/eagleLayerTest.cpp)
|
||||
|
||||
add_subdirectory(utils)
|
||||
|
||||
if(EXISTS ${CMAKE_CURRENT_SOURCE_DIR}/batch_manager)
|
||||
add_subdirectory(batch_manager)
|
||||
endif()
|
||||
|
||||
if(EXISTS ${CMAKE_CURRENT_SOURCE_DIR}/executor)
|
||||
add_subdirectory(executor)
|
||||
endif()
|
||||
add_subdirectory(unit_tests)
|
||||
add_subdirectory(e2e_tests)
|
||||
|
||||
13
cpp/tests/e2e_tests/CMakeLists.txt
Normal file
13
cpp/tests/e2e_tests/CMakeLists.txt
Normal file
@ -0,0 +1,13 @@
|
||||
# SPDX-FileCopyrightText: Copyright (c) 2023-2024 NVIDIA CORPORATION &
|
||||
# AFFILIATES. All rights reserved. SPDX-License-Identifier: NVIDIA TensorRT
|
||||
# Source Code License Agreement
|
||||
#
|
||||
# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
|
||||
# property and proprietary rights in and to this material, related documentation
|
||||
# and any modifications thereto. Any use, reproduction, disclosure or
|
||||
# distribution of this material and related documentation without an express
|
||||
# license agreement from NVIDIA CORPORATION or its affiliates is strictly
|
||||
# prohibited.
|
||||
|
||||
add_subdirectory(batch_manager)
|
||||
add_subdirectory(executor)
|
||||
@ -9,13 +9,10 @@
|
||||
# license agreement from NVIDIA CORPORATION or its affiliates is strictly
|
||||
# prohibited.
|
||||
|
||||
add_gtest(cacheTransceiverTest cacheTransceiverTest.cpp)
|
||||
|
||||
# guidedDecoderTest requires model tokenizer info, so it's easier to run it with
|
||||
# e2e tests instead of unit tests.
|
||||
add_gtest(guidedDecoderTest guidedDecoderTest.cpp)
|
||||
add_gtest(trtEncoderModelTest trtEncoderModelTest.cpp)
|
||||
add_gtest(trtGptModelTest trtGptModelTest.cpp)
|
||||
add_gtest(trtGptModelRealDecoderTest trtGptModelRealDecoderTest.cpp)
|
||||
target_link_libraries(trtGptModelRealDecoderTest PRIVATE testingUtils)
|
||||
|
||||
add_gtest(peftCacheManagerTest peftCacheManagerTest.cpp)
|
||||
add_gtest(trtEncoderModelTest trtEncoderModelTest.cpp)
|
||||
add_gtest(guidedDecoderTest guidedDecoderTest.cpp)
|
||||
add_gtest(blockKeyTest blockKeyTest.cpp)
|
||||
@ -1405,7 +1405,7 @@ INSTANTIATE_TEST_SUITE_P(LlamaCon2TP1Gen1TP2PP2DisaaggOrchestrator, DisaggOrches
|
||||
),
|
||||
generateTestNameDisaggParams);
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(LlamaCon2TP2Gen2TP1DisaaggSpawnOrchestrator, DisaggOrchestratorParamsTest,
|
||||
INSTANTIATE_TEST_SUITE_P(LlamaCon2TP2Gen2TP1DisaggSpawnOrchestrator, DisaggOrchestratorParamsTest,
|
||||
testing::Combine( //
|
||||
testing::Values(1), // processNum
|
||||
testing::Values(
|
||||
@ -1418,7 +1418,7 @@ INSTANTIATE_TEST_SUITE_P(LlamaCon2TP2Gen2TP1DisaaggSpawnOrchestrator, DisaggOrch
|
||||
),
|
||||
generateTestNameDisaggParams);
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(LlamaCon2TP1Gen2PP2DisaaggSpawnOrchestrator, DisaggOrchestratorParamsTest,
|
||||
INSTANTIATE_TEST_SUITE_P(LlamaCon2TP1Gen2PP2DisaggSpawnOrchestrator, DisaggOrchestratorParamsTest,
|
||||
testing::Combine( //
|
||||
testing::Values(1), // processNum
|
||||
testing::Values(
|
||||
@ -19,6 +19,7 @@ endif()
|
||||
|
||||
add_subdirectory(common)
|
||||
add_subdirectory(kernels)
|
||||
add_subdirectory(multi_gpu)
|
||||
add_subdirectory(layers)
|
||||
add_subdirectory(runtime)
|
||||
add_subdirectory(thop)
|
||||
|
||||
@ -9,6 +9,8 @@
|
||||
# license agreement from NVIDIA CORPORATION or its affiliates is strictly
|
||||
# prohibited.
|
||||
|
||||
add_gtest(blockKeyTest blockKeyTest.cpp)
|
||||
add_gtest(cacheTransBufferTest cacheTransBufferTest.cpp)
|
||||
add_gtest(capacitySchedulerTest capacitySchedulerTest.cpp)
|
||||
add_gtest(contextProgressTest contextProgressTest.cu)
|
||||
add_gtest(evictionPolicyTest evictionPolicyTest.cpp)
|
||||
@ -16,5 +18,5 @@ add_gtest(kvCacheManagerTest kvCacheManagerTest.cpp)
|
||||
add_gtest(kvCacheUtilsTest kvCacheUtilsTest.cpp)
|
||||
add_gtest(llmRequestTest llmRequestTest.cpp)
|
||||
add_gtest(microBatchSchedulerTest microBatchSchedulerTest.cpp)
|
||||
add_gtest(peftCacheManagerTest peftCacheManagerTest.cpp)
|
||||
add_gtest(staticThreadPoolTest staticThreadPoolTest.cpp)
|
||||
add_gtest(cacheTransBufferTest cacheTransBufferTest.cpp)
|
||||
|
||||
@ -45,16 +45,7 @@ add_gtest(cudaCoreGemmKernelTest cudaCoreGemm/cudaCoreGemmKernelTest.cpp)
|
||||
|
||||
add_gtest(mlaChunkedPrefillTest mlaChunkedPrefillTest.cu)
|
||||
|
||||
if(NOT ENABLE_MULTI_DEVICE EQUAL 0)
|
||||
add_gtest(allReduceKernelTest allReduce/allReduceKernelTest.cu)
|
||||
add_gtest(allReduceFusionTest allReduce/allReduceFusionTest.cu)
|
||||
add_gtest(gemmAllReduceTest allReduce/gemmAllReduceTest.cu)
|
||||
if(USING_OSS_CUTLASS_ALLREDUCE_GEMM)
|
||||
target_link_libraries(gemmAllReduceTest PRIVATE ar_gemm_src)
|
||||
target_compile_definitions(gemmAllReduceTest
|
||||
PRIVATE USING_OSS_CUTLASS_ALLREDUCE_GEMM)
|
||||
endif()
|
||||
endif()
|
||||
add_gtest(fusedMoeCommKernelTest fusedMoeCommKernelTest.cpp)
|
||||
|
||||
add_gtest(
|
||||
gemmSwigluRunnerTest
|
||||
@ -89,11 +80,13 @@ set(SAMPLING_KERNEL_TEST_SRC
|
||||
sampling/samplingTest.cpp sampling/samplingTopKTest.cpp
|
||||
sampling/samplingTopPTest.cpp sampling/samplingAirTopPTest.cpp
|
||||
sampling/samplingPenaltyTest.cpp sampling/samplingUtilsTest.cu)
|
||||
|
||||
add_gtest(samplingKernelsTest "${SAMPLING_KERNEL_TEST_SRC}")
|
||||
|
||||
set(ROUTING_KERNEL_TEST_SRC
|
||||
routing/routingTest.cpp routing/routingLlama4Test.cpp
|
||||
routing/routingRenormalizeTest.cpp routing/routingDeepSeekTest.cpp)
|
||||
|
||||
add_gtest(routingKernelsTest "${ROUTING_KERNEL_TEST_SRC}")
|
||||
|
||||
add_gtest(moeLoadBalanceKernelTest moeLoadBalanceKernelTest.cpp)
|
||||
|
||||
add_gtest(eaglePackDataTest eaglePackDataTest.cpp)
|
||||
|
||||
1410
cpp/tests/unit_tests/kernels/fusedMoeCommKernelTest.cpp
Normal file
1410
cpp/tests/unit_tests/kernels/fusedMoeCommKernelTest.cpp
Normal file
File diff suppressed because it is too large
Load Diff
@ -31,5 +31,7 @@ set(LOOKAHEAD_DECODING_TEST_SRC randomLlm.cpp lookaheadDecodingLayerTest.cpp)
|
||||
add_gtest(lookaheadDecodingLayerTest "${LOOKAHEAD_DECODING_TEST_SRC}")
|
||||
|
||||
add_gtest(dynamicDecodeLayerTest dynamicDecodeLayerTest.cpp)
|
||||
add_gtest(eagleLayerTest eagleLayerTest.cpp)
|
||||
add_gtest(explicitDraftTokensLayerTest explicitDraftTokensLayerTest.cpp)
|
||||
add_gtest(layerUtilsTest layerUtilsTest.cpp)
|
||||
add_gtest(medusaDecodeLayerTest medusaDecodeLayerTest.cpp)
|
||||
|
||||
@ -14,7 +14,7 @@
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "tests/layers/eagleLayerTest.h"
|
||||
#include "eagleLayerTest.h"
|
||||
#include "tensorrt_llm/common/memoryUtils.h"
|
||||
#include "tensorrt_llm/kernels/decodingCommon.h"
|
||||
#include "tensorrt_llm/kernels/speculativeDecoding/eagleDecodingKernels.h"
|
||||
@ -14,7 +14,7 @@
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "tests/layers/medusaDecodeLayerTest.h"
|
||||
#include "medusaDecodeLayerTest.h"
|
||||
#include "tensorrt_llm/kernels/decodingCommon.h"
|
||||
#include "tensorrt_llm/runtime/medusaModule.h"
|
||||
#include "tensorrt_llm/runtime/runtimeKernels.h"
|
||||
16
cpp/tests/unit_tests/multi_gpu/CMakeLists.txt
Normal file
16
cpp/tests/unit_tests/multi_gpu/CMakeLists.txt
Normal file
@ -0,0 +1,16 @@
|
||||
# SPDX-FileCopyrightText: Copyright (c) 2023-2024 NVIDIA CORPORATION &
|
||||
# AFFILIATES. All rights reserved. SPDX-License-Identifier: NVIDIA TensorRT
|
||||
# Source Code License Agreement
|
||||
#
|
||||
# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
|
||||
# property and proprietary rights in and to this material, related documentation
|
||||
# and any modifications thereto. Any use, reproduction, disclosure or
|
||||
# distribution of this material and related documentation without an express
|
||||
# license agreement from NVIDIA CORPORATION or its affiliates is strictly
|
||||
# prohibited.
|
||||
|
||||
add_subdirectory(kernels)
|
||||
|
||||
add_gtest(cacheTransceiverTest cacheTransceiverTest.cpp)
|
||||
add_gtest(mpiUtilsTest mpiUtilsTest.cpp)
|
||||
add_gtest(userBufferTest userBufferTest.cpp)
|
||||
21
cpp/tests/unit_tests/multi_gpu/kernels/CMakeLists.txt
Normal file
21
cpp/tests/unit_tests/multi_gpu/kernels/CMakeLists.txt
Normal file
@ -0,0 +1,21 @@
|
||||
# SPDX-FileCopyrightText: Copyright (c) 2023-2024 NVIDIA CORPORATION &
|
||||
# AFFILIATES. All rights reserved. SPDX-License-Identifier: NVIDIA TensorRT
|
||||
# Source Code License Agreement
|
||||
#
|
||||
# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
|
||||
# property and proprietary rights in and to this material, related documentation
|
||||
# and any modifications thereto. Any use, reproduction, disclosure or
|
||||
# distribution of this material and related documentation without an express
|
||||
# license agreement from NVIDIA CORPORATION or its affiliates is strictly
|
||||
# prohibited.
|
||||
|
||||
if(NOT ENABLE_MULTI_DEVICE EQUAL 0)
|
||||
add_gtest(allReduceKernelTest allReduce/allReduceKernelTest.cu)
|
||||
add_gtest(allReduceFusionTest allReduce/allReduceFusionTest.cu)
|
||||
add_gtest(gemmAllReduceTest allReduce/gemmAllReduceTest.cu)
|
||||
if(USING_OSS_CUTLASS_ALLREDUCE_GEMM)
|
||||
target_link_libraries(gemmAllReduceTest PRIVATE ar_gemm_src)
|
||||
target_compile_definitions(gemmAllReduceTest
|
||||
PRIVATE USING_OSS_CUTLASS_ALLREDUCE_GEMM)
|
||||
endif()
|
||||
endif()
|
||||
@ -501,11 +501,8 @@ TEST(Kernel_AllReduceFusion, AllReduceAccuracyRandomTokenNum)
|
||||
auto& comm = mpi::MpiComm::world();
|
||||
auto world_size = comm.getSize();
|
||||
auto rank = comm.getRank();
|
||||
if (world_size % 2)
|
||||
{
|
||||
TLLM_LOG_WARNING("world size is not a multiple of 2, return");
|
||||
return;
|
||||
}
|
||||
ASSERT_EQ(world_size % 2, 0) << "Requires even world size (got " << world_size << ")";
|
||||
|
||||
int iter = 100;
|
||||
std::vector<int> candidate_hidden_dim{1024, 2048, 4096, 7168, 8192};
|
||||
int min_token_num = 1;
|
||||
@ -537,11 +534,8 @@ TEST(Kernel_AllReduceFusion, AllReduceAccuracyFixedTokenNum)
|
||||
auto& comm = mpi::MpiComm::world();
|
||||
auto world_size = comm.getSize();
|
||||
auto rank = comm.getRank();
|
||||
if (world_size % 2)
|
||||
{
|
||||
TLLM_LOG_WARNING("world size is not a multiple of 2, return");
|
||||
return;
|
||||
}
|
||||
ASSERT_EQ(world_size % 2, 0) << "Requires even world size (got " << world_size << ")";
|
||||
|
||||
int iter = 10;
|
||||
std::vector<int> candidate_hidden_dim{1024, 2048, 4096, 7168, 8192};
|
||||
int min_token_num = 1;
|
||||
@ -603,11 +597,8 @@ TEST(Kernel_AllReduceFusion, AllReduceFusionAccuracyDifferentHiddenDim)
|
||||
auto& comm = mpi::MpiComm::world();
|
||||
auto world_size = comm.getSize();
|
||||
auto rank = comm.getRank();
|
||||
if (world_size % 2)
|
||||
{
|
||||
TLLM_LOG_WARNING("world size is not a multiple of 2, return");
|
||||
return;
|
||||
}
|
||||
ASSERT_EQ(world_size % 2, 0) << "Requires even world size (got " << world_size << ")";
|
||||
|
||||
int const arch = tensorrt_llm::common::getSMVersion();
|
||||
if (arch >= 100)
|
||||
{
|
||||
@ -647,11 +638,8 @@ TEST(Kernel_AllReduceFusion, AllReduceFusionAccuracyDifferentDType)
|
||||
auto& comm = mpi::MpiComm::world();
|
||||
auto world_size = comm.getSize();
|
||||
auto rank = comm.getRank();
|
||||
if (world_size % 2)
|
||||
{
|
||||
TLLM_LOG_WARNING("world size is not a multiple of 2, return");
|
||||
return;
|
||||
}
|
||||
ASSERT_EQ(world_size % 2, 0) << "Requires even world size (got " << world_size << ")";
|
||||
|
||||
std::vector<int> candidate_hidden_dim{1024, 2048, 4096, 7168, 8192};
|
||||
int min_token_num = 1;
|
||||
int max_token_num = 2048;
|
||||
@ -683,53 +671,52 @@ TEST(Kernel_AllReduceFusion, AllReduceFusionAccuracyDifferentDType)
|
||||
TEST(Kernel_AllReduceFusion, Perf)
|
||||
{
|
||||
int const arch = tensorrt_llm::common::getSMVersion();
|
||||
if (arch >= 100)
|
||||
if (arch < 100)
|
||||
{
|
||||
using Runner = TestRunner<half, ar_fusion::AllReduceFusionPattern::kARResidualRMSNormFP4Quant>;
|
||||
auto& comm = mpi::MpiComm::world();
|
||||
auto world_size = comm.getSize();
|
||||
auto rank = comm.getRank();
|
||||
if (world_size % 2)
|
||||
GTEST_SKIP() << "Skipping test for SM < 100";
|
||||
}
|
||||
|
||||
using Runner = TestRunner<half, ar_fusion::AllReduceFusionPattern::kARResidualRMSNormFP4Quant>;
|
||||
auto& comm = mpi::MpiComm::world();
|
||||
auto world_size = comm.getSize();
|
||||
auto rank = comm.getRank();
|
||||
ASSERT_EQ(world_size % 2, 0) << "Requires even world size (got " << world_size << ")";
|
||||
|
||||
int warmup = 100, iter = 300;
|
||||
int hidden_dim = 7168;
|
||||
std::vector<int> candidate_token_num{1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048};
|
||||
int max_token_num = 2048;
|
||||
Runner runner(max_token_num, hidden_dim);
|
||||
for (auto token_num : candidate_token_num)
|
||||
{
|
||||
auto latency = runner.benchmark(&Runner::run_kernel, warmup, iter, token_num, hidden_dim);
|
||||
if (rank == 0)
|
||||
{
|
||||
TLLM_LOG_WARNING("world size is not a multiple of 2, return");
|
||||
return;
|
||||
TLLM_LOG_INFO(
|
||||
"token_num %-4d, hidden_dim %-4d, fusion kernel latency %4.4fus", token_num, hidden_dim, latency);
|
||||
}
|
||||
int warmup = 100, iter = 300;
|
||||
int hidden_dim = 7168;
|
||||
std::vector<int> candidate_token_num{1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048};
|
||||
int max_token_num = 2048;
|
||||
Runner runner(max_token_num, hidden_dim);
|
||||
for (auto token_num : candidate_token_num)
|
||||
auto nccl_latency = runner.benchmark(&Runner::run_nccl_allreduce, warmup, iter, token_num, hidden_dim);
|
||||
if (rank == 0)
|
||||
{
|
||||
auto latency = runner.benchmark(&Runner::run_kernel, warmup, iter, token_num, hidden_dim);
|
||||
if (rank == 0)
|
||||
{
|
||||
TLLM_LOG_INFO(
|
||||
"token_num %-4d, hidden_dim %-4d, fusion kernel latency %4.4fus", token_num, hidden_dim, latency);
|
||||
}
|
||||
auto nccl_latency = runner.benchmark(&Runner::run_nccl_allreduce, warmup, iter, token_num, hidden_dim);
|
||||
if (rank == 0)
|
||||
{
|
||||
TLLM_LOG_INFO("nccl allreduce latency %4.4fus", nccl_latency);
|
||||
}
|
||||
auto residual_latency = runner.benchmark(&Runner::run_residual_add, warmup, iter, token_num, hidden_dim);
|
||||
if (rank == 0)
|
||||
{
|
||||
TLLM_LOG_INFO("residual add latency %4.4fus", residual_latency);
|
||||
}
|
||||
auto rms_latency = runner.benchmark(&Runner::run_rms_norm, warmup, iter, token_num, hidden_dim);
|
||||
if (rank == 0)
|
||||
{
|
||||
TLLM_LOG_INFO("rms norm latency %4.4fus", rms_latency);
|
||||
}
|
||||
auto quant_latency = runner.benchmark(&Runner::run_fp4_quant, warmup, iter, token_num, hidden_dim);
|
||||
if (rank == 0)
|
||||
{
|
||||
TLLM_LOG_INFO("fp4 quant latency %4.4fus", quant_latency);
|
||||
auto tot_latency = nccl_latency + residual_latency + rms_latency + quant_latency;
|
||||
TLLM_LOG_INFO("fusion kernel latency %4.4fus, nccl + ops latency %4.4fus, total speedup %2.4fx",
|
||||
latency, tot_latency, tot_latency / latency);
|
||||
}
|
||||
TLLM_LOG_INFO("nccl allreduce latency %4.4fus", nccl_latency);
|
||||
}
|
||||
auto residual_latency = runner.benchmark(&Runner::run_residual_add, warmup, iter, token_num, hidden_dim);
|
||||
if (rank == 0)
|
||||
{
|
||||
TLLM_LOG_INFO("residual add latency %4.4fus", residual_latency);
|
||||
}
|
||||
auto rms_latency = runner.benchmark(&Runner::run_rms_norm, warmup, iter, token_num, hidden_dim);
|
||||
if (rank == 0)
|
||||
{
|
||||
TLLM_LOG_INFO("rms norm latency %4.4fus", rms_latency);
|
||||
}
|
||||
auto quant_latency = runner.benchmark(&Runner::run_fp4_quant, warmup, iter, token_num, hidden_dim);
|
||||
if (rank == 0)
|
||||
{
|
||||
TLLM_LOG_INFO("fp4 quant latency %4.4fus", quant_latency);
|
||||
auto tot_latency = nccl_latency + residual_latency + rms_latency + quant_latency;
|
||||
TLLM_LOG_INFO("fusion kernel latency %4.4fus, nccl + ops latency %4.4fus, total speedup %2.4fx", latency,
|
||||
tot_latency, tot_latency / latency);
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -573,8 +573,7 @@ TEST(Kernel, AllReduce)
|
||||
auto& comm = mpi::MpiComm::world();
|
||||
auto world_size = comm.getSize();
|
||||
auto rank = comm.getRank();
|
||||
if (world_size % 2)
|
||||
return;
|
||||
ASSERT_EQ(world_size % 2, 0) << "Requires even world size (got " << world_size << ")";
|
||||
|
||||
int warmup = 100, iter = 100;
|
||||
// clang-format off
|
||||
@ -645,8 +644,7 @@ TEST(Kernel, AllReduceOneShot)
|
||||
auto& comm = mpi::MpiComm::world();
|
||||
auto world_size = comm.getSize();
|
||||
auto rank = comm.getRank();
|
||||
if (world_size % 2)
|
||||
return;
|
||||
ASSERT_EQ(world_size % 2, 0) << "Requires even world size (got " << world_size << ")";
|
||||
|
||||
int warmup = 100, iter = 100;
|
||||
std::vector<int> candidate_bs{1, 2, 4, 8, 16};
|
||||
@ -673,19 +671,14 @@ TEST(Kernel, AllReduceOneShotPreNorm)
|
||||
char const* value = "1";
|
||||
int overwrite = 1; // Set to 1 to overwrite existing values, 0 to preserve
|
||||
|
||||
if (setenv(varName, value, overwrite) != 0)
|
||||
{
|
||||
perror("Error setting environment variable");
|
||||
return;
|
||||
}
|
||||
ASSERT_EQ(setenv(varName, value, overwrite), 0) << "Error setting environment variable";
|
||||
|
||||
std::cout << varName << " set to " << getenv(varName) << std::endl;
|
||||
|
||||
auto& comm = mpi::MpiComm::world();
|
||||
auto world_size = comm.getSize();
|
||||
auto rank = comm.getRank();
|
||||
if (world_size % 2)
|
||||
return;
|
||||
ASSERT_EQ(world_size % 2, 0) << "Requires even world size (got " << world_size << ")";
|
||||
|
||||
int warmup = 100, iter = 100;
|
||||
std::vector<int> candidate_bs{1, 2, 4, 8, 16};
|
||||
@ -628,11 +628,8 @@ TEST(Kernel, MoEReduceAddARFuse)
|
||||
auto& comm = mpi::MpiComm::world();
|
||||
auto world_size = comm.getSize();
|
||||
auto rank = comm.getRank();
|
||||
if (world_size % 2)
|
||||
{
|
||||
TLLM_LOG_WARNING("world size is not a multiple of 2, return");
|
||||
return;
|
||||
}
|
||||
ASSERT_EQ(world_size % 2, 0) << "Requires even world size (got " << world_size << ")";
|
||||
|
||||
int warmup = 100, iter = 100;
|
||||
int hidden_dim = 7168;
|
||||
std::vector<int> candidate_token_num{1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048};
|
||||
@ -26,15 +26,13 @@ TEST(UserBuffer, basic)
|
||||
{
|
||||
if (!tr::ub::ub_supported())
|
||||
{
|
||||
return;
|
||||
GTEST_SKIP() << "UserBuffer is not supported";
|
||||
}
|
||||
auto& comm = mpi::MpiComm::world();
|
||||
auto world_size = comm.getSize();
|
||||
auto rank = comm.getRank();
|
||||
if (world_size % 2)
|
||||
{
|
||||
return;
|
||||
}
|
||||
ASSERT_EQ(world_size % 2, 0) << "Requires even world size (got " << world_size << ")";
|
||||
|
||||
tr::ub::ub_initialize(world_size);
|
||||
EXPECT_EQ(tr::ub::ub_is_initialized(), true);
|
||||
EXPECT_NE(tr::ub::ub_comm(), nullptr);
|
||||
@ -13,6 +13,8 @@ add_gtest(bufferManagerTest bufferManagerTest.cpp)
|
||||
add_gtest(cudaMemPoolTest cudaMemPoolTest.cpp)
|
||||
add_gtest(decodingLayerWorkspaceTest decodingLayerWorkspaceTest.cpp)
|
||||
add_gtest(gdrcopyTest gdrcopyTest.cpp)
|
||||
add_gtest(gptDecoderBatchedTest gptDecoderBatchedTest.cpp)
|
||||
add_gtest(gptDecoderTest gptDecoderTest.cpp)
|
||||
add_gtest(hostAccessibleDeviceAllocatorTest
|
||||
hostAccessibleDeviceAllocatorTest.cu)
|
||||
add_gtest(iBufferTest iBufferTest.cpp)
|
||||
@ -20,13 +22,15 @@ add_gtest(iTensorTest iTensorTest.cpp)
|
||||
add_gtest(loraCacheTest loraCacheTest.cpp)
|
||||
add_gtest(loraManagerTest loraManagerTest.cpp)
|
||||
add_gtest(loraUtilsTest loraUtilsTest.cpp)
|
||||
add_gtest(medusaModuleTest medusaModuleTest.cpp)
|
||||
add_gtest(moeLoadBalancerTest moeLoadBalancerTest.cpp)
|
||||
add_gtest(runtimeKernelTest runtimeKernelTest.cpp)
|
||||
add_gtest(samplingConfigTest samplingConfigTest.cpp)
|
||||
add_gtest(samplingTest samplingTest.cpp)
|
||||
add_gtest(sanitizerTest sanitizerTest.cpp)
|
||||
add_gtest(tllmBuffersTest tllmBuffersTest.cpp)
|
||||
add_gtest(tllmRuntimeTest tllmRuntimeTest.cpp)
|
||||
add_gtest(transposeKVKernelTest transposeKVKernelTest.cpp)
|
||||
add_gtest(userBufferTest userBufferTest.cpp)
|
||||
add_gtest(utilsTest utilsTest.cpp)
|
||||
add_gtest(virtualMemoryTest virtualMemoryTest.cpp)
|
||||
add_gtest(workerPoolTest workerPoolTest.cpp)
|
||||
|
||||
@ -69,9 +69,11 @@ cat << EOF > ${EXTRA_LLM_API_FILE}
enable_attention_dp: false
cuda_graph_config:
  enable_padding: true
  max_batch_size: 128
  max_batch_size: 720
moe_config:
  backend: TRTLLM
stream_interval: 10
num_postprocess_workers: 4
EOF
```
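The generated file is consumed when the server is launched; a minimal sketch, assuming the `--extra_llm_api_options` flag of `trtllm-serve` is used to pass it (combine it with the other command-line options shown further below):

```shell
# Hypothetical launch line: point the server at the YAML file written above.
trtllm-serve openai/gpt-oss-120b --extra_llm_api_options ${EXTRA_LLM_API_FILE}
```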
@ -84,9 +86,11 @@ cat << EOF > ${EXTRA_LLM_API_FILE}
enable_attention_dp: true
cuda_graph_config:
  enable_padding: true
  max_batch_size: 128
  max_batch_size: 720
moe_config:
  backend: CUTLASS
stream_interval: 10
num_postprocess_workers: 4
EOF
```

@ -99,9 +103,8 @@ trtllm-serve openai/gpt-oss-120b \
    --host 0.0.0.0 \
    --port 8000 \
    --backend pytorch \
    --max_batch_size 128 \
    --max_batch_size 720 \
    --max_num_tokens 16384 \
    --max_seq_len 2048 \
    --kv_cache_free_gpu_memory_fraction 0.9 \
    --tp_size 8 \
    --ep_size 8 \
@ -134,7 +137,7 @@ These options are used directly on the command line when you start the `trtllm-s

#### `--max_batch_size`

* **Description:** The maximum number of user requests that can be grouped into a single batch for processing.
* **Description:** The maximum number of user requests that can be grouped into a single batch for processing. The batch size that can actually be achieved also depends on the total sequence length (input plus output tokens).

#### `--max_num_tokens`

@ -142,7 +145,7 @@ These options are used directly on the command line when you start the `trtllm-s

#### `--max_seq_len`

* **Description:** The maximum possible sequence length for a single request, including both input and generated output tokens.
* **Description:** The maximum possible sequence length for a single request, including both input and generated output tokens. We do not set it explicitly; it is inferred from the model config.

#### `--trust_remote_code`
@ -229,8 +232,7 @@ TODO: Use Chat Compeletions API / Responses API as the example after the PR is m
### Running Evaluations to Verify Accuracy (Optional)

We use OpenAI's official evaluation tool to test the model's accuracy. For more information, see [gpt-oss-eval](https://github.com/openai/gpt-oss/tree/main/gpt_oss/evals).

TODO(@Binghan Chen): Add instructions for running gpt-oss-eval.
With the added support of the Chat Completions and Responses APIs in `trtllm-serve`, `gpt_oss.evals` works directly without any modifications.
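For a quick smoke test of the endpoint before running the full eval suite, a minimal sketch using the OpenAI Python client (the base URL, dummy API key, and model name below are assumptions taken from the serve command above):

```python
from openai import OpenAI

# Assumes trtllm-serve is listening on localhost:8000 as launched above;
# the local server ignores the API key, but the client requires one.
client = OpenAI(base_url="http://localhost:8000/v1", api_key="dummy")

response = client.chat.completions.create(
    model="openai/gpt-oss-120b",
    messages=[{"role": "user", "content": "Reply with a single short sentence."}],
    max_tokens=64,
)
print(response.choices[0].message.content)
```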
## Benchmarking Performance

@ -267,6 +269,8 @@ EOF
chmod +x bench.sh
```

To achieve maximum throughput with attention DP enabled, sweep the request concurrency up to `concurrency = max_batch_size * num_gpus`.
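For example, a minimal sweep sketch, assuming the `bench.sh` script created above takes the concurrency as its first argument (with `max_batch_size` 720 on 8 GPUs the upper bound is 720 * 8 = 5760):

```shell
# Hypothetical sweep loop; adapt to how bench.sh actually reads its arguments.
for concurrency in 128 512 1024 2048 4096 5760; do
    ./bench.sh ${concurrency}
done
```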
If you want to save the results to a file, add the following options.

```shell
@ -690,9 +690,9 @@ def getMultiGpuFileChanged(pipeline, testFilter, globalVars)
|
||||
"cpp/tensorrt_llm/thop/allgatherOp.cpp",
|
||||
"cpp/tensorrt_llm/thop/allreduceOp.cpp",
|
||||
"cpp/tensorrt_llm/thop/reducescatterOp.cpp",
|
||||
"cpp/tests/executor/",
|
||||
"cpp/tests/kernels/allReduce/",
|
||||
"cpp/tests/runtime/mpiUtilsTest.cpp",
|
||||
"cpp/tests/e2e_tests/batch_manager/",
|
||||
"cpp/tests/e2e_tests/executor/",
|
||||
"cpp/tests/unit_tests/multi_gpu/",
|
||||
"jenkins/L0_Test.groovy",
|
||||
"tensorrt_llm/_ipc_utils.py",
|
||||
"tensorrt_llm/_torch/compilation/patterns/ar_residual_norm.py",
|
||||
|
||||
@ -7,7 +7,6 @@ import groovy.json.JsonOutput
|
||||
import com.nvidia.bloom.KubernetesManager
|
||||
import com.nvidia.bloom.Constants
|
||||
import com.nvidia.bloom.CloudManager
|
||||
import com.nvidia.bloom.KubernetesManager
|
||||
import com.nvidia.bloom.SlurmConfig
|
||||
import com.nvidia.bloom.SlurmCluster
|
||||
import com.nvidia.bloom.SlurmPartition
|
||||
@ -230,8 +229,11 @@ def runLLMTestlistOnSlurm(pipeline, platform, testList, config=VANILLA_CONFIG, p
|
||||
SlurmPartition partition = SlurmConfig.partitionConfig[platform] as SlurmPartition
|
||||
SlurmCluster cluster = SlurmConfig.clusterConfig[partition.clusterName]
|
||||
|
||||
def nodeName = "${cluster.host}-test-${UUID.randomUUID().toString()}"
|
||||
def nodeSecret = CloudManager.createNode(nodeName)
|
||||
// Create a unique suffix for the node name and workspace
|
||||
String customSuffix = "${env.BUILD_TAG}-${UUID.randomUUID().toString().replaceAll("-", "").substring(0, 6)}".toLowerCase()
|
||||
def nodeName = "${cluster.host}-test-${customSuffix}"
|
||||
def customWorkspace = "/tmp/${nodeName}"
|
||||
def nodeSecret = CloudManager.createNode(nodeName, customWorkspace)
|
||||
|
||||
try {
|
||||
// Run ssh command to start node in desired cluster via SLURM
|
||||
@ -274,12 +276,30 @@ def runLLMTestlistOnSlurm(pipeline, platform, testList, config=VANILLA_CONFIG, p
|
||||
}
|
||||
|
||||
if (CloudManager.isNodeOnline(nodeName)) {
|
||||
def dockerArgs = "--gpus ${gpuCount} --cap-add=SYS_ADMIN --ipc=host --security-opt seccomp=unconfined -u root:root -v /home/scratch.trt_llm_data:/scratch.trt_llm_data:ro -v /tmp/ccache:${CCACHE_DIR}:rw -v /tmp/pipcache/http-v2:/root/.cache/pip/http-v2:rw --cap-add syslog"
|
||||
node(nodeName) {
|
||||
sh """
|
||||
env | sort
|
||||
pwd && ls -alh
|
||||
ls -alh ${env.WORKSPACE}
|
||||
ls -alh ${env.WORKSPACE_TMP}
|
||||
"""
|
||||
}
|
||||
|
||||
def dockerArgs = "--gpus ${gpuCount} " +
|
||||
"--cap-add=SYS_ADMIN " +
|
||||
"--ipc=host " +
|
||||
"--security-opt seccomp=unconfined " +
|
||||
"-u root:root " +
|
||||
"-v /home/scratch.trt_llm_data:/scratch.trt_llm_data:ro " +
|
||||
"-v /tmp/ccache:${CCACHE_DIR}:rw " +
|
||||
"-v /tmp/pipcache/http-v2:/root/.cache/pip/http-v2:rw " +
|
||||
"--cap-add syslog"
|
||||
|
||||
if (partition.clusterName == "dlcluster") {
|
||||
dockerArgs += " -e NVIDIA_IMEX_CHANNELS=0"
|
||||
}
|
||||
slurmRunner = runInDockerOnNodeMultiStage(LLM_DOCKER_IMAGE, nodeName, dockerArgs, false)
|
||||
|
||||
slurmRunner = runInDockerOnNodeMultiStage(LLM_DOCKER_IMAGE, nodeName, dockerArgs, true)
|
||||
executeLLMTestOnSlurm(pipeline, platform, testList, config, perfMode, stageName, splitId, splits, skipInstallWheel, cpver, slurmRunner)
|
||||
} else {
|
||||
echo "The node does not come online in 2 hours, terminating the job"
|
||||
@ -571,6 +591,13 @@ def cacheErrorAndUploadResult(stageName, taskRunner, finallyRunner, noResultIfSu
|
||||
"${UPLOAD_PATH}/test-results/"
|
||||
)
|
||||
junit(testResults: "${stageName}/results*.xml")
|
||||
|
||||
// Clean up the workspace
|
||||
sh """
|
||||
env | sort
|
||||
pwd && ls -alh
|
||||
rm -rf ./*
|
||||
"""
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -807,7 +834,7 @@ def echoNodeAndGpuInfo(pipeline, stageName)
|
||||
|
||||
def runLLMDocBuild(pipeline, config)
|
||||
{
|
||||
// Step 1: cloning tekit source code
|
||||
// Step 1: cloning source code
|
||||
sh "pwd && ls -alh"
|
||||
sh "env | sort"
|
||||
// allow to checkout from forked repo, svc_tensorrt needs to have access to the repo, otherwise clone will fail
|
||||
@ -1252,13 +1279,16 @@ def rerunFailedTests(stageName, llmSrc, testCmdLine) {
|
||||
|
||||
def runLLMTestlistOnPlatformImpl(pipeline, platform, testList, config=VANILLA_CONFIG, perfMode=false, stageName="Undefined", splitId=1, splits=1, skipInstallWheel=false, cpver="cp312")
|
||||
{
|
||||
// Step 1: create LLM_ROOT dir
|
||||
sh "pwd && ls -alh"
|
||||
// TODO: proper way to clean workspace, maybe save in a folder named with BUILD_ID.
|
||||
// So that it can work with multiple job running in same node
|
||||
sh "rm -rf ./*"
|
||||
// Step 1: create LLM_ROOT dir and clean up the workspace
|
||||
def llmRootConfig = "${LLM_ROOT}${config}"
|
||||
sh "mkdir ${llmRootConfig}"
|
||||
sh """
|
||||
env | sort
|
||||
pwd && ls -alh
|
||||
rm -rf ./*
|
||||
mkdir ${llmRootConfig}
|
||||
ls -alh ${env.WORKSPACE}
|
||||
ls -alh ${env.WORKSPACE_TMP}
|
||||
"""
|
||||
|
||||
def llmPath = sh (script: "realpath ${llmRootConfig}", returnStdout: true).trim()
|
||||
def llmSrc = "${llmPath}/TensorRT-LLM/src"
|
||||
@ -1779,7 +1809,6 @@ def launchTestJobs(pipeline, testFilter, dockerNode=null)
|
||||
"DGX_H100-4_GPUs-PyTorch-DeepSeek-1": ["dgx-h100-x4", "l0_dgx_h100", 1, 2, 4],
|
||||
"DGX_H100-4_GPUs-PyTorch-DeepSeek-2": ["dgx-h100-x4", "l0_dgx_h100", 2, 2, 4],
|
||||
"DGX_H100-4_GPUs-PyTorch-Others-1": ["dgx-h100-x4", "l0_dgx_h100", 1, 1, 4],
|
||||
"DGX_H100-4_GPUs-Triton-Post-Merge-1": ["dgx-h100-x4", "l0_dgx_h100", 1, 1, 4],
|
||||
"DGX_H100-4_GPUs-CPP-1": ["dgx-h100-x4", "l0_dgx_h100", 1, 1, 4],
|
||||
"A10-PyTorch-1": ["a10", "l0_a10", 1, 1],
|
||||
"A10-CPP-1": ["a10", "l0_a10", 1, 1],
|
||||
@ -1852,6 +1881,7 @@ def launchTestJobs(pipeline, testFilter, dockerNode=null)
|
||||
"B200_PCIe-TensorRT-Post-Merge-2": ["b100-ts2", "l0_b200", 2, 2],
|
||||
"H100_PCIe-TensorRT-Perf-1": ["h100-cr", "l0_perf", 1, 1],
|
||||
"H100_PCIe-PyTorch-Perf-1": ["h100-cr", "l0_perf", 1, 1],
|
||||
"DGX_H200-4_GPUs-Triton-Post-Merge-1": ["dgx-h200-x4", "l0_dgx_h200", 1, 1, 4],
|
||||
"DGX_H200-8_GPUs-PyTorch-Post-Merge-1": ["dgx-h200-x8", "l0_dgx_h200", 1, 1, 8],
|
||||
"DGX_H200-4_GPUs-PyTorch-Post-Merge-1": ["dgx-h200-x4", "l0_dgx_h200", 1, 1, 4],
|
||||
"DGX_H200-4_GPUs-TensorRT-Post-Merge-1": ["dgx-h200-x4", "l0_dgx_h200", 1, 3, 4],
|
||||
@ -1910,8 +1940,10 @@ def launchTestJobs(pipeline, testFilter, dockerNode=null)
|
||||
fullSet += SBSATestConfigs.keySet()
|
||||
|
||||
SBSASlurmTestConfigs = [
|
||||
"GB200-4_GPUs-PyTorch-1": ["gb200-x4", "l0_gb200", 1, 1, 4],
|
||||
"GB200-4_GPUs-PyTorch-Post-Merge-1": ["gb200-x4", "l0_gb200", 1, 1, 4],
|
||||
// Not used in the pipeline now
|
||||
// "GB200-PyTorch-1": ["gb200-single", "l0_gb200", 1, 3],
|
||||
"GB200-4_GPUs-PyTorch-1": ["gb200-x4", "l0_gb200_multi_gpus", 1, 1, 4],
|
||||
"GB200-4_GPUs-PyTorch-Post-Merge-1": ["gb200-x4", "l0_gb200_multi_gpus", 1, 1, 4],
|
||||
]
|
||||
fullSet += SBSASlurmTestConfigs.keySet()
|
||||
|
||||
@ -1923,7 +1955,6 @@ def launchTestJobs(pipeline, testFilter, dockerNode=null)
|
||||
"GB200-8_GPUs-2_Nodes-PyTorch-Post-Merge-4": ["gb200-multi-node", "l0_gb200_multi_nodes", 4, 7, 8, 2],
|
||||
"GB200-8_GPUs-2_Nodes-PyTorch-Post-Merge-5": ["gb200-multi-node", "l0_gb200_multi_nodes", 5, 7, 8, 2],
|
||||
"GB200-8_GPUs-2_Nodes-PyTorch-Post-Merge-6": ["gb200-multi-node", "l0_gb200_multi_nodes", 6, 7, 8, 2],
|
||||
"GB200-8_GPUs-2_Nodes-PyTorch-Post-Merge-7": ["gb200-multi-node", "l0_gb200_multi_nodes", 7, 7, 8, 2],
|
||||
]
|
||||
fullSet += multiNodesSBSAConfigs.keySet()
|
||||
|
||||
@ -2145,7 +2176,9 @@ def launchTestJobs(pipeline, testFilter, dockerNode=null)
|
||||
echo "###### Check pip install Start ######"
|
||||
withEnv(libEnv) {
|
||||
sh "env | sort"
|
||||
checkPipInstall(pipeline, "${cpu_arch}/${wheelPath}")
|
||||
timeout(time: 1, unit: 'HOURS') {
|
||||
checkPipInstall(pipeline, "${cpu_arch}/${wheelPath}")
|
||||
}
|
||||
}
|
||||
echo "###### Run LLMAPI tests Start ######"
|
||||
def config = VANILLA_CONFIG_CU12
|
||||
@ -2481,7 +2514,7 @@ pipeline {
|
||||
|
||||
def testPhase2StageName = env.testPhase2StageName
|
||||
if (testPhase2StageName) {
|
||||
def dgxSigns = ["DGX_H100", "DGX_H200", "GB200", "DGX_B200", "RTXPro6000-4_GPUs"]
|
||||
def dgxSigns = ["2_GPUs", "4_GPUs", "8_GPUs"]
|
||||
singleGpuJobs = parallelJobs.findAll{!dgxSigns.any{sign -> it.key.contains(sign)}}
|
||||
dgxJobs = parallelJobs.findAll{dgxSigns.any{sign -> it.key.contains(sign)}}
|
||||
}
|
||||
|
||||
@ -34,7 +34,7 @@ else
|
||||
done
|
||||
fi
|
||||
testList="$testList_$splitId"
|
||||
export CPP_TEST_TIMEOUT_OVERRIDDEN=7200
|
||||
export CPP_TEST_TIMEOUT_OVERRIDDEN=$pytestTestTimeout
|
||||
export LLM_ROOT=$llmSrcNode
|
||||
export LLM_MODELS_ROOT=$MODEL_CACHE_DIR
|
||||
export UCX_TLS=^gdr_copy
|
||||
@ -43,6 +43,7 @@ testCmdLines=(
|
||||
"$llmSrcNode/tensorrt_llm/llmapi/trtllm-llmapi-launch"
|
||||
"pytest"
|
||||
"-v"
|
||||
"--timeout-method=thread"
|
||||
"--timeout=$pytestTestTimeout"
|
||||
"--test-list=$testListPathNode"
|
||||
"--waives-file=$waivesListPathNode"
|
||||
|
||||
@ -17,7 +17,7 @@ import os
|
||||
import platform
|
||||
import sys
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional
|
||||
from typing import List, Optional, Union
|
||||
|
||||
import pynvml
|
||||
import torch
|
||||
@ -366,6 +366,10 @@ class MnnvlMoe:
|
||||
)
|
||||
MnnvlMoe.moe_workspace = MnnvlMemory(mapping, workspace_size_per_rank)
|
||||
MnnvlMoe.moe_workspace_tensor = MnnvlMoe.moe_workspace.as_torch_strided_tensor(torch.uint64)
|
||||
torch.ops.trtllm.moe_initialize_workspace(
|
||||
MnnvlMoe.moe_workspace_tensor, mapping.tp_rank, mapping.tp_size
|
||||
)
|
||||
MnnvlMoe.moe_workspace.comm.barrier()
|
||||
return MnnvlMoe.moe_workspace_tensor
|
||||
|
||||
@staticmethod
|
||||
@ -394,7 +398,6 @@ class MnnvlMoe:
|
||||
@staticmethod
|
||||
def mnnvl_moe_alltoallv_prepare_without_allgather(
|
||||
expert_ids: torch.Tensor,
|
||||
scales: torch.Tensor,
|
||||
expert_statics: Optional[torch.Tensor],
|
||||
workspace: torch.Tensor,
|
||||
max_token_count_per_rank: int,
|
||||
@ -405,8 +408,6 @@ class MnnvlMoe:
|
||||
top_k: int,
|
||||
):
|
||||
(
|
||||
prepared_local_experts,
|
||||
prepared_local_scales,
|
||||
local_send_rank_count_cumsum,
|
||||
local_send_rank_indices,
|
||||
local_recv_rank_count_cumsum,
|
||||
@ -415,7 +416,6 @@ class MnnvlMoe:
|
||||
gathered_expert_statics,
|
||||
) = torch.ops.trtllm.mnnvl_moe_alltoallv_prepare_without_allgather(
|
||||
expert_ids,
|
||||
scales,
|
||||
expert_statics,
|
||||
workspace,
|
||||
max_token_count_per_rank,
|
||||
@ -440,7 +440,7 @@ class MnnvlMoe:
|
||||
local_token_allocation_count,
|
||||
)
|
||||
|
||||
return alltoall_info, prepared_local_experts, prepared_local_scales, gathered_expert_statics
|
||||
return alltoall_info, gathered_expert_statics
|
||||
|
||||
@staticmethod
|
||||
def mnnvl_moe_expert_static_allgather(
|
||||
@ -526,31 +526,67 @@ class MnnvlMoe:
|
||||
|
||||
@staticmethod
|
||||
def mnnvl_moe_alltoallv(
|
||||
x: torch.Tensor,
|
||||
x: Union[torch.Tensor, List[Optional[torch.Tensor]]],
|
||||
alltoall_info: MoEAlltoallInfo,
|
||||
workspace: torch.Tensor,
|
||||
ep_rank: int,
|
||||
ep_size: int,
|
||||
):
|
||||
assert x.dim() == 2, "only 2D tensor supported, please reshape."
|
||||
output_tensor = torch.empty(
|
||||
alltoall_info.local_token_allocation_count,
|
||||
x.shape[1],
|
||||
dtype=x.dtype,
|
||||
device=torch.device("cuda"),
|
||||
)
|
||||
torch.ops.trtllm.moe_comm(
|
||||
x,
|
||||
alltoall_info.send_rank_count_cumsum,
|
||||
alltoall_info.send_rank_local_indices,
|
||||
output_tensor,
|
||||
alltoall_info.recv_rank_count_cumsum,
|
||||
alltoall_info.recv_rank_local_indices,
|
||||
workspace,
|
||||
ep_rank,
|
||||
ep_size,
|
||||
)
|
||||
        return output_tensor
    ) -> Union[torch.Tensor, List[Optional[torch.Tensor]]]:
        # Convert single tensor to list for unified handling
        is_single_tensor = not isinstance(x, list)
        if is_single_tensor:
            assert x.dim() == 2, "only 2D tensor supported, please reshape."
            x = [x]

        assert len(x) > 0, "Empty tensor list not supported"

        # Filter out None values
        valid_list = [tensor is not None for tensor in x]
        valid_tensors = [tensor for tensor in x if tensor is not None]

        if len(valid_tensors) == 0:
            # All tensors are None, return list of None
            result = [None] * len(x)
        else:
            first_dim = None
            for tensor in valid_tensors:
                # Validate dimensions of valid tensors
                assert tensor.dim() == 2, "only 2D tensor supported, please reshape."
                if first_dim is None:
                    first_dim = tensor.shape[0]
                else:
                    assert tensor.shape[0] == first_dim, (
                        f"All tensors must have the same first dimension, got {tensor.shape[0]} vs {first_dim}"
                    )

            # Process only valid tensors
            output_tensors = torch.ops.trtllm.moe_comm(
                valid_tensors,
                alltoall_info.send_rank_count_cumsum,
                alltoall_info.send_rank_local_indices,
                alltoall_info.recv_rank_count_cumsum,
                alltoall_info.recv_rank_local_indices,
                workspace,
                alltoall_info.local_token_allocation_count,
                ep_rank,
                ep_size,
            )

            # Restore None positions in output
            idx = 0
            result = []
            for is_valid in valid_list:
                if is_valid:
                    result.append(output_tensors[idx])
                    idx += 1
                else:
                    result.append(None)

        # If input was a single tensor, return a single tensor
        if is_single_tensor:
            result = result[0]

        return result
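An illustrative call sketch (the variable names and shapes below are hypothetical; the key behavior, taken from the code above, is that `None` entries pass through and come back as `None` in the same positions):

```python
# Hypothetical tensors sharing the same first dimension (token count).
hidden, scales, expert_ids = MnnvlMoe.mnnvl_moe_alltoallv(
    [hidden, None, expert_ids],   # the absent scale tensor stays None
    alltoall_info,
    workspace,
    ep_rank,
    ep_size,
)
# scales is returned as None; hidden and expert_ids are the dispatched tensors.
```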
|
||||
|
||||
@staticmethod
|
||||
def mnnvl_moe_alltoallv_combine(
|
||||
@ -563,20 +599,19 @@ class MnnvlMoe:
|
||||
token_count: int,
|
||||
):
|
||||
assert x.dim() == 2, "2D tensor supported, please reshape."
|
||||
output_tensor = torch.zeros(
|
||||
token_count * top_k, x.shape[1], dtype=x.dtype, device=torch.device("cuda")
|
||||
)
|
||||
torch.ops.trtllm.moe_comm(
|
||||
x,
|
||||
output_tensors = torch.ops.trtllm.moe_comm(
|
||||
[x],
|
||||
alltoall_info.recv_rank_count_cumsum,
|
||||
alltoall_info.recv_rank_local_indices,
|
||||
output_tensor,
|
||||
alltoall_info.send_rank_count_cumsum,
|
||||
alltoall_info.backward_recv_rank_local_indices,
|
||||
workspace,
|
||||
token_count * top_k,
|
||||
ep_rank,
|
||||
ep_size,
|
||||
[True],
|
||||
)
|
||||
output_tensor = output_tensors[0]
|
||||
return torch.sum(
|
||||
output_tensor.reshape(token_count, top_k, x.shape[1]), dim=1, keepdim=False
|
||||
)
|
||||
|
||||
@ -476,9 +476,6 @@ class SequenceInfo:
|
||||
idx = self.previous_batch_indices_cuda[: len(previous_batch_indices)]
|
||||
idx.copy_(host_idx, non_blocking=True)
|
||||
|
||||
# sort them so that masked_scatter_ lines up correctly
|
||||
idx, _ = idx.sort()
|
||||
|
||||
# gather the exact values you want to write
|
||||
src = new_tokens[0, idx, 0]
|
||||
|
||||
|
||||
@ -179,71 +179,27 @@ def _register_fake():
|
||||
return (input.new_empty(output_shape, dtype=torch.uint8),
|
||||
global_scale.new_empty(scale_shape, dtype=torch.uint8))
|
||||
|
||||
@torch.library.register_fake("trtllm::moe_comm_prepare_indices")
|
||||
def _(
|
||||
gathered_target_rank_ids: torch.Tensor,
|
||||
real_rank_token_count_cum_sum: Optional[torch.Tensor],
|
||||
max_token_count_per_rank: int,
|
||||
expert_count: int,
|
||||
top_k: int,
|
||||
ep_rank: int,
|
||||
ep_size: int,
|
||||
):
|
||||
max_send_ranks_per_token = max(ep_size, top_k)
|
||||
local_gather_indices_shape = (max_token_count_per_rank * ep_size, )
|
||||
rank_count_cum_sum_shape = (ep_size, )
|
||||
send_rank_local_indices_shape = (max_token_count_per_rank *
|
||||
max_send_ranks_per_token, )
|
||||
recv_rank_local_indices_shape = (max_token_count_per_rank * ep_size, )
|
||||
backward_recv_rank_local_indices_shape = (max_token_count_per_rank *
|
||||
max_send_ranks_per_token, )
|
||||
|
||||
local_gather_indices = gathered_target_rank_ids.new_empty(
|
||||
local_gather_indices_shape, dtype=torch.int32)
|
||||
send_rank_count_cum_sum = gathered_target_rank_ids.new_empty(
|
||||
rank_count_cum_sum_shape, dtype=torch.int32)
|
||||
send_rank_local_indices = gathered_target_rank_ids.new_empty(
|
||||
send_rank_local_indices_shape, dtype=torch.int32)
|
||||
recv_rank_count_cum_sum = gathered_target_rank_ids.new_empty(
|
||||
rank_count_cum_sum_shape, dtype=torch.int32)
|
||||
recv_rank_local_indices = gathered_target_rank_ids.new_empty(
|
||||
recv_rank_local_indices_shape, dtype=torch.int32)
|
||||
backward_recv_rank_local_indices = gathered_target_rank_ids.new_empty(
|
||||
backward_recv_rank_local_indices_shape, dtype=torch.int32)
|
||||
|
||||
return (local_gather_indices, send_rank_count_cum_sum,
|
||||
send_rank_local_indices, recv_rank_count_cum_sum,
|
||||
recv_rank_local_indices, backward_recv_rank_local_indices)
|
||||
|
||||
@torch.library.register_fake("trtllm::moe_local_gather")
|
||||
def _(
|
||||
recv_rank_cum_sum: torch.Tensor,
|
||||
local_gather_indices: torch.Tensor,
|
||||
gathered_expert_ids: torch.Tensor,
|
||||
gathered_scales: Optional[torch.Tensor],
|
||||
local_expert_ids: torch.Tensor,
|
||||
local_scales: Optional[torch.Tensor],
|
||||
max_token_count_per_rank: int,
|
||||
expert_count: int,
|
||||
top_k: int,
|
||||
ep_rank: int,
|
||||
ep_size: int,
|
||||
):
|
||||
pass
|
||||
|
||||
@torch.library.register_fake("trtllm::moe_comm")
|
||||
def _(
|
||||
input: torch.Tensor,
|
||||
inputs: List[torch.Tensor],
|
||||
send_rank_cum_sum: torch.Tensor,
|
||||
send_indices: torch.Tensor,
|
||||
output: torch.Tensor,
|
||||
recv_rank_cum_sum: torch.Tensor,
|
||||
recv_indices: torch.Tensor,
|
||||
all_workspaces: torch.Tensor,
|
||||
output_allocation_count: int,
|
||||
ep_rank: int,
|
||||
ep_size: int,
|
||||
need_zero_output: Optional[List[bool]],
|
||||
):
|
||||
pass
|
||||
outputs = []
|
||||
for input_tensor in inputs:
|
||||
output_tensor = torch.empty(
|
||||
(output_allocation_count, input_tensor.shape[1]),
|
||||
dtype=input_tensor.dtype,
|
||||
device=input_tensor.device)
|
||||
outputs.append(output_tensor)
|
||||
return outputs
|
||||
|
||||
@torch.library.register_fake("trtllm::get_moe_commworkspace_size_per_rank")
|
||||
def _(ep_size: int):
|
||||
@ -287,6 +243,12 @@ def _register_fake():
|
||||
token_selected_experts: torch.Tensor, offset_by_ep_rank: bool):
|
||||
return torch.empty_like(token_selected_experts)
|
||||
|
||||
@torch.library.register_fake("trtllm::memset_expert_ids")
|
||||
def _(experts_ids: torch.Tensor, recv_rank_count_cumsum: torch.Tensor,
|
||||
max_token_count_per_rank: int, top_k: int, slot_count: int,
|
||||
ep_size: int):
|
||||
pass
|
||||
|
||||
@torch.library.custom_op("trtllm::group_rms_norm_base",
|
||||
mutates_args=("outputs", ))
|
||||
def group_rms_norm_base(
|
||||
|
||||
@ -58,8 +58,7 @@ from ..modules.decoder_layer import DecoderLayer
|
||||
from ..modules.embedding import Embedding
|
||||
from ..modules.fused_moe import (DeepSeekV3MoeRoutingMethod,
|
||||
MoEWeightLoadingMode, TRTLLMGenFusedMoE,
|
||||
create_moe,
|
||||
moe_load_balancer_set_repeated_for_next_layer)
|
||||
create_moe)
|
||||
from ..modules.gated_mlp import GatedMLP
|
||||
from ..modules.linear import Linear, TensorParallelMode, WeightsLoadingConfig
|
||||
from ..modules.multi_stream_utils import maybe_execute_in_parallel
|
||||
@ -1159,7 +1158,7 @@ class DeepseekV3ForCausalLM(SpecDecOneEngineForCausalLM[DeepseekV3Model,
|
||||
self.num_hidden_layers = self.config.num_hidden_layers
|
||||
assert ckpt_nextn > 0, "There is not MTP modules in the checkpoint."
|
||||
if ckpt_nextn == 1 and not model_config.spec_config.use_mtp_vanilla:
|
||||
moe_load_balancer_set_repeated_for_next_layer(model_nextn)
|
||||
pass
|
||||
else:
|
||||
# modify the QuantConfig to support duplicated mtp layers
|
||||
if model_config.quant_config.exclude_modules is not None:
|
||||
|
||||
@@ -14,6 +14,7 @@ from ..model_config import ModelConfig, TConfig
from ..modules.attention import Attention
from ..modules.decoder_layer import DecoderLayer
from ..modules.embedding import Embedding
from ..modules.fused_moe import moe_load_balancer_set_repeated_for_next_layer
from ..modules.gated_mlp import GatedMLP
from ..modules.linear import (Linear, TensorParallelMode, WeightMode,
                              WeightsLoadingConfig)
@@ -340,6 +341,9 @@ class MTPForCausalLM(nn.Module):
        mtp_num_layers = 1 if spec_dec_mode.is_mtp_eagle(
        ) else model_config.spec_config.num_nextn_predict_layers

        moe_load_balancer_set_repeated_for_next_layer(
            model_config.spec_config.num_nextn_predict_layers // mtp_num_layers)

        self.mtp_layers = nn.ModuleList([
            DeepseekV3MTP(model_config, layer_idx + start_layer_idx,
                          model.aux_stream_dict)

@@ -5,7 +5,6 @@ from typing import Dict, List, Optional, Union
import torch

from tensorrt_llm._mnnvl_utils import MnnvlMemory, MnnvlMoe
from tensorrt_llm.math_utils import pad_up

from ...distributed import allgather
from ...model_config import ModelConfig
@@ -190,8 +189,6 @@ class CutlassFusedMoE(MoE):
    @cached_property
    def enable_alltoall(self):
        return (self.mapping.moe_ep_size > self.routing_method.experts_per_token
                and self.routing_method.experts_per_token % 4 ==
                0  # alltoall without allgather only supports top_k % 4 == 0
                and self.mapping.enable_attention_dp
                and self.mapping.tp_size > 1
                and os.environ.get("TRTLLM_MOE_DISABLE_ALLTOALLV", "0") != "1"
@@ -353,39 +350,28 @@ class CutlassFusedMoE(MoE):
                token_final_scales = torch.ones_like(token_selected_experts,
                                                     dtype=torch.float32)

            # TODO: support alltoall without allgather for top_k % 4 != 0
            assert top_k % 4 == 0, "alltoall without allgather only supports top_k % 4 == 0"
            assert self.alltoall_prepare_workspace is not None, "alltoall_prepare_workspace should be initialized"
            alltoall_info, token_selected_experts, token_final_scales, _ = MnnvlMoe.mnnvl_moe_alltoallv_prepare_without_allgather(
                token_selected_experts, token_final_scales, None,
                self.alltoall_prepare_workspace, max_num_token, self.ep_rank,
                self.ep_size, self.num_experts, self.num_experts, top_k)
            alltoall_info, _ = MnnvlMoe.mnnvl_moe_alltoallv_prepare_without_allgather(
                token_selected_experts, None, self.alltoall_prepare_workspace,
                max_num_token, self.ep_rank, self.ep_size, self.num_experts,
                self.num_experts, top_k)

            # Dispatch alltoall (common for both paths)
            x = MnnvlMoe.mnnvl_moe_alltoallv(x, alltoall_info,
                                             self.alltoall_workspace,
                                             self.ep_rank, self.ep_size)
            if x_sf is not None:
                x_sf = x_sf.view(x_row, ceil_div(x_col,
                                                 self.scaling_vector_size))

                # Pad dim[1] to 16 bytes alignment for alltoall
                # TODO: Remove this padding if possible
                sf_per_16bytes = 16 // x_sf.element_size()
                x_sf_col_orig = x_sf.shape[1]
                x_sf_col = pad_up(x_sf_col_orig, sf_per_16bytes)
                if x_sf_col > x_sf_col_orig:
                    x_sf = torch.nn.functional.pad(
                        x_sf, (0, x_sf_col - x_sf_col_orig))
            # Dispatch x, x_sf, token_selected_experts, token_final_scales in one alltoall kernel
            x, x_sf, token_selected_experts, token_final_scales = MnnvlMoe.mnnvl_moe_alltoallv(
                [x, x_sf, token_selected_experts, token_final_scales],
                alltoall_info, self.alltoall_workspace, self.ep_rank,
                self.ep_size)

                x_sf = MnnvlMoe.mnnvl_moe_alltoallv(x_sf, alltoall_info,
                                                    self.alltoall_workspace,
                                                    self.ep_rank, self.ep_size)
            torch.ops.trtllm.memset_expert_ids(
                token_selected_experts, alltoall_info.recv_rank_count_cumsum,
                max_num_token, top_k, self.num_experts, self.ep_size)

            if x_sf is not None:
                x_row = x_sf.shape[0]

                # TODO: Remove this slicing required by padding if possible
                x_sf = x_sf[:, :x_sf_col_orig].contiguous()

                x_sf = swizzle_sf(x_sf, x_row, x_col, self.scaling_vector_size)

        elif run_post_quant_allgather:

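The 16-byte padding above exists because the fused all-to-all moves the scale tensor in aligned units, so its last dimension is rounded up before dispatch and sliced back afterwards. A self-contained illustration of that round-trip (editor-added sketch; the local `pad_up` mirrors `tensorrt_llm.math_utils.pad_up`, and the shapes and dtype are made up):

import torch

def pad_up(x: int, y: int) -> int:
    # Round x up to the next multiple of y.
    return ((x + y - 1) // y) * y

x_sf = torch.randn(8, 23, dtype=torch.float32)    # 23 * 4 B is not 16 B aligned
sf_per_16bytes = 16 // x_sf.element_size()        # 4 elements per 16 bytes
x_sf_col_orig = x_sf.shape[1]
x_sf_col = pad_up(x_sf_col_orig, sf_per_16bytes)  # 24
padded = torch.nn.functional.pad(x_sf, (0, x_sf_col - x_sf_col_orig))
assert padded.shape[1] * padded.element_size() % 16 == 0
# After the all-to-all the padding is sliced off again:
restored = padded[:, :x_sf_col_orig].contiguous()
assert torch.equal(restored, x_sf)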
@@ -192,18 +192,13 @@ class WideEPMoE(MoE):
        self.use_low_precision_combine = (os.environ.get(
            "TRTLLM_MOE_USE_LOW_PRECISION_COMBINE", "0")
                                          == "1") and qm.has_nvfp4()
        # TODO: support alltoall without allgather for top_k % 4 != 0
        self.enable_alltoall_without_allgather = (
            os.environ.get("TRTLLM_MOE_ENABLE_ALLTOALL_WITHOUT_ALLGATHER",
                           "1") == "1"
        ) and self.alltoall_method_type == AlltoallMethodType.MNNVL and routing_method.experts_per_token % 4 == 0

        if self.alltoall_method_type == AlltoallMethodType.MNNVL:
            MnnvlMemory.initialize()
            self.alltoall_workspace = MnnvlMoe.get_moe_workspaces(
                model_config.mapping)
            if self.enable_alltoall_without_allgather:
                self.alltoall_prepare_workspace = MnnvlMoe.get_moe_prepare_workspace(
                    model_config.mapping)
            self.alltoall_prepare_workspace = MnnvlMoe.get_moe_prepare_workspace(
                model_config.mapping)
        elif self.alltoall_method_type == AlltoallMethodType.DeepEP:
            self.deep_ep_buffer = buffer_pool.get_buffer(
                model_config.mapping)
@@ -301,6 +296,9 @@ class WideEPMoE(MoE):
                                 1) // self.moe_max_num_tokens

    def can_use_alltoall(self, all_rank_num_tokens, all_rank_max_num_tokens):
        if self.alltoall_method_type == AlltoallMethodType.MNNVL:
            return True

        # Disable alltoall when chunking is used
        if self.calculate_num_chunks(all_rank_num_tokens) > 1:
            return False
@@ -458,24 +456,23 @@ class WideEPMoE(MoE):
        else:
            tuner_num_tokens = None
            tuner_top_k = None
        alltoall_info = None
        if use_all_to_all:
            if self.alltoall_method_type == AlltoallMethodType.MNNVL:
                if self.enable_dummy_allreduce:
                    self.dummy_allreduce()
                token_count = x.shape[0]
                alltoall_info = None
                if is_last_call:
                if is_last_call and self.layer_load_balancer is not None and not self.layer_load_balancer.is_static_routing(
                ):
                    loadbalancer_local_statistic_info = self.layer_load_balancer.get_local_statistic_tensor(
                    )
                else:
                    loadbalancer_local_statistic_info = None
                x, token_selected_slots, token_final_scales, gathered_loadbalancer_local_statistic_info, alltoall_info = \
                    self.alltoall_prepare_maybe_dispatch(all_rank_max_num_tokens,
                                                         x,
                                                         token_selected_slots,
                                                         token_final_scales,
                                                         use_postquant_alltoall,
                                                         loadbalancer_local_statistic_info)
                token_selected_slots, gathered_loadbalancer_local_statistic_info, alltoall_info = \
                    self.alltoall_prepare(all_rank_max_num_tokens,
                                          token_selected_slots,
                                          loadbalancer_local_statistic_info)

                if gathered_loadbalancer_local_statistic_info is not None:
                    gathered_loadbalancer_local_statistic_info = gathered_loadbalancer_local_statistic_info.view(
                        (self.mapping.moe_ep_size, self.num_experts))
@@ -580,10 +577,15 @@ class WideEPMoE(MoE):
        cluster_rank = self.cluster_rank
        quant_scales = self.quant_scales

        if self.alltoall_method_type == AlltoallMethodType.MNNVL:
            top_k = self.routing_method.experts_per_token
            x, x_sf, token_selected_slots, token_final_scales = self.alltoall_dispatch(
                x, x_sf, token_selected_slots, token_final_scales,
                all_rank_max_num_tokens, top_k, alltoall_info)

        if use_postquant_alltoall:
            if self.alltoall_method_type == AlltoallMethodType.MNNVL:
                x, x_sf = self.alltoall_postquant_dispatch(
                    x, x_sf, alltoall_info)
                pass
            elif self.alltoall_method_type == AlltoallMethodType.DeepEP:
                if x_sf is not None:
                    # Adapter between `x_sf` and DeepEP

@@ -862,77 +864,34 @@ class WideEPMoE(MoE):
        self.repeat_idx = 0 if self.repeat_idx == self.repeat_count - 1 else self.repeat_idx + 1
        return outputs

    def alltoall_prepare_maybe_dispatch(
            self, all_rank_max_num_tokens: int, x: torch.Tensor,
            token_selected_slots: torch.Tensor,
            token_final_scales: torch.Tensor, use_postquant_alltoall: bool,
            local_statistic_tensor: Optional[torch.Tensor]):
    def alltoall_prepare(self, all_rank_max_num_tokens: int,
                         token_selected_slots: torch.Tensor,
                         local_statistic_tensor: Optional[torch.Tensor]):
        top_k = self.routing_method.experts_per_token

        if self.enable_alltoall_without_allgather:
            alltoall_info, token_selected_slots, token_final_scales, gathered_local_statistic_tensor = MnnvlMoe.mnnvl_moe_alltoallv_prepare_without_allgather(
                token_selected_slots, token_final_scales,
                local_statistic_tensor, self.alltoall_prepare_workspace,
                all_rank_max_num_tokens, self.ep_rank, self.ep_size,
                self.num_experts, self.num_slots, top_k)
        else:
            if all_rank_max_num_tokens > token_selected_slots.shape[0]:
                token_selected_slots = torch.nn.functional.pad(
                    token_selected_slots,
                    (0, 0, 0,
                     all_rank_max_num_tokens - token_selected_slots.shape[0]),
                    'constant', self.num_slots)
            if token_final_scales is not None and all_rank_max_num_tokens > token_final_scales.shape[
                    0]:
                token_final_scales = torch.nn.functional.pad(
                    token_final_scales,
                    (0, 0, 0,
                     all_rank_max_num_tokens - token_final_scales.shape[0]))
            gathered_token_selected_slots, gathered_token_final_scales, gathered_local_statistic_tensor = allgather(
                [
                    token_selected_slots, token_final_scales,
                    local_statistic_tensor
                ],
                self.mapping,
                dim=0)
            gathered_token_selected_slots = torch.flatten(
                gathered_token_selected_slots.contiguous(),
                start_dim=0,
                end_dim=-2)
            if gathered_token_final_scales is not None:
                gathered_token_final_scales = torch.flatten(
                    gathered_token_final_scales.contiguous(),
                    start_dim=0,
                    end_dim=-2)
            gathered_target_rank_ids = MnnvlMoe.compute_target_rank_id(
                gathered_token_selected_slots, self.num_slots, self.ep_size)
            alltoall_info, token_selected_slots, token_final_scales = MnnvlMoe.mnnvl_moe_alltoallv_prepare(
                gathered_target_rank_ids, None, gathered_token_selected_slots,
                gathered_token_final_scales, all_rank_max_num_tokens,
                self.num_slots, top_k, self.ep_rank, self.ep_size)
        alltoall_info, gathered_local_statistic_tensor = MnnvlMoe.mnnvl_moe_alltoallv_prepare_without_allgather(
            token_selected_slots, local_statistic_tensor,
            self.alltoall_prepare_workspace, all_rank_max_num_tokens,
            self.ep_rank, self.ep_size, self.num_experts, self.num_slots, top_k)

        if not use_postquant_alltoall:
            assert not isinstance(
                x, Fp4QuantizedTensor
            ), "pre-quant alltoall doesn't support fp4 tensor"
            x = MnnvlMoe.mnnvl_moe_alltoallv(x, alltoall_info,
                                             self.alltoall_workspace,
                                             self.ep_rank, self.ep_size)
        return token_selected_slots, gathered_local_statistic_tensor, alltoall_info

        return x, token_selected_slots, token_final_scales, gathered_local_statistic_tensor, alltoall_info
    def alltoall_dispatch(self, x: torch.Tensor, x_sf: Optional[torch.Tensor],
                          token_selected_slots: torch.Tensor,
                          token_final_scales: Optional[torch.Tensor],
                          all_rank_max_num_tokens: int, top_k: int,
                          alltoall_info: MoEAlltoallInfo):

    def alltoall_postquant_dispatch(self, x: torch.Tensor, x_sf: torch.Tensor,
                                    alltoall_info: MoEAlltoallInfo):
        x = MnnvlMoe.mnnvl_moe_alltoallv(x, alltoall_info,
                                         self.alltoall_workspace, self.ep_rank,
                                         self.ep_size)
        x, x_sf, token_selected_slots, token_final_scales = MnnvlMoe.mnnvl_moe_alltoallv(
            [x, x_sf, token_selected_slots, token_final_scales], alltoall_info,
            self.alltoall_workspace, self.ep_rank, self.ep_size)

        if x_sf is not None:
            x_sf = MnnvlMoe.mnnvl_moe_alltoallv(x_sf, alltoall_info,
                                                self.alltoall_workspace,
                                                self.ep_rank, self.ep_size)
        torch.ops.trtllm.memset_expert_ids(token_selected_slots,
                                           alltoall_info.recv_rank_count_cumsum,
                                           all_rank_max_num_tokens, top_k,
                                           self.num_slots, self.ep_size)

        return x, x_sf
        return x, x_sf, token_selected_slots, token_final_scales

    def alltoall_combine(self, final_hidden_states: torch.Tensor,
                         alltoall_info: MoEAlltoallInfo, token_count: int):

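Taken together, this hunk splits the old alltoall_prepare_maybe_dispatch into a metadata-only alltoall_prepare and a fused alltoall_dispatch that ships activations, scales and routing tensors in a single kernel launch. A condensed, non-literal sketch of the resulting MNNVL path (editor-added; argument plumbing is simplified and the function name is invented):

def mnnvl_alltoall_sketch(self, x, x_sf, token_selected_slots,
                          token_final_scales, all_rank_max_num_tokens,
                          loadbalancer_local_statistic_info):
    """Condensed sketch of the refactored MNNVL path, not the literal code."""
    top_k = self.routing_method.experts_per_token
    # 1) Prepare: exchange routing metadata only and build alltoall_info.
    token_selected_slots, gathered_stats, alltoall_info = self.alltoall_prepare(
        all_rank_max_num_tokens, token_selected_slots,
        loadbalancer_local_statistic_info)
    # 2) Dispatch: one fused all-to-all moves x, x_sf and the routing tensors;
    #    memset_expert_ids then marks padded rows so they select no expert.
    x, x_sf, token_selected_slots, token_final_scales = self.alltoall_dispatch(
        x, x_sf, token_selected_slots, token_final_scales,
        all_rank_max_num_tokens, top_k, alltoall_info)
    # 3) Run the local experts, then alltoall_combine() returns the expert
    #    outputs to the ranks that own the tokens.
    return x, x_sf, token_selected_slots, token_final_scales, alltoall_info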
@@ -423,7 +423,8 @@ def load_torch_hf_lora(lora_config: LoraConfig):
    pivot model config is the transformer's one.
    """
    # TODO smor- need to comibe with load_hf_lora
    lora_config.trtllm_modules_to_hf_modules = get_default_trtllm_modules_to_hf_modules()
    if not lora_config.trtllm_modules_to_hf_modules:
        lora_config.trtllm_modules_to_hf_modules = get_default_trtllm_modules_to_hf_modules()

    assert len(lora_config.lora_dir) == 1, "Expecting only a single lora dir"
    lora_loader = HfLoraLoader(lora_config.lora_dir)

@ -109,7 +109,6 @@
|
||||
"examples/test_whisper.py::test_llm_whisper_general[large-v3-disable_gemm_plugin-disable_attention_plugin-disable_weight_only-float16-nb:1-use_python_runtime]": 327.95307156071067,
|
||||
"test_e2e.py::test_build_time_benchmark_sanity": 165.71592589840293,
|
||||
"test_unittests.py::test_unittests_v2[unittest/trt/attention/test_bert_attention.py]": 99.96196278184652,
|
||||
"cpp/test_e2e.py::test_benchmarks[gpt-80]": 1376.0404928650241,
|
||||
"accuracy/test_cli_flow.py::TestPhi3Small128kInstruct::test_auto_dtype": 512.450893450994,
|
||||
"accuracy/test_llm_api.py::TestLlama3_1_8B::test_fp8_rowwise": 361.5573864541948,
|
||||
"examples/test_llama.py::test_llm_llama_v3_dora_1gpu[commonsense-llama-v3-8b-dora-r32-llama-v3-8b-hf-base_fp16]": 517.2770831151865,
|
||||
@ -150,10 +149,6 @@
|
||||
"test_unittests.py::test_unittests_v2[unittest/_torch/thop]": 852.56,
|
||||
"test_unittests.py::test_unittests_v2[unittest/_torch/modeling -k \"modeling_mixtral\"]": 208.1838396479725,
|
||||
"test_unittests.py::test_unittests_v2[unittest/_torch/multi_gpu_modeling -k \"deepseek\"]": 393.0210295501165,
|
||||
"cpp/test_e2e.py::test_model[-gpt_executor-80]": 4016.7569622844458,
|
||||
"cpp/test_e2e.py::test_model[-gpt_tests-80]": 1817.8153839111328,
|
||||
"cpp/test_unit_tests.py::test_unit_tests[executor-80]": 339.0683519244194,
|
||||
"cpp/test_unit_tests.py::test_unit_tests[kernels-80]": 846.0403860099614,
|
||||
"test_e2e.py::test_ptp_quickstart_bert[TRTLLM-BertForSequenceClassification-bert/bert-base-uncased-yelp-polarity]": 21.019993914989755,
|
||||
"test_e2e.py::test_ptp_quickstart_bert[VANILLA-BertForSequenceClassification-bert/bert-base-uncased-yelp-polarity]": 18.753523574909195,
|
||||
"test_e2e.py::test_ptp_quickstart_multimodal[NVILA-8B-FP16-vila/NVILA-8B-video-False]": 278.4781197870616,
|
||||
@ -254,35 +249,37 @@
|
||||
"accuracy/test_cli_flow.py::TestTinyLlama1_1BChat::test_weight_only[int8]": 159.531545445323,
|
||||
"accuracy/test_cli_flow.py::TestTinyLlama1_1BChat::test_weight_only_int8_kv_cache[int8]": 184.35870655626059,
|
||||
"test_unittests.py::test_unittests_v2[unittest/llmapi/test_llm_models.py -m \"not (part0 or part1)\"]": 825.9972547292709,
|
||||
"cpp/test_e2e.py::test_benchmarks[bart-90]": 271.95234084688127,
|
||||
"cpp/test_e2e.py::test_model[-bart-90]": 391.84748707409017,
|
||||
"cpp/test_e2e.py::test_benchmarks[gpt-80]": 1376.0404928650241,
|
||||
"cpp/test_e2e.py::test_model[-gpt_executor-80]": 1495.14,
|
||||
"cpp/test_e2e.py::test_model[-gpt_tests-80]": 1206.79,
|
||||
"cpp/test_e2e.py::test_model[-gpt-80]": 1568.98,
|
||||
"cpp/test_unit_tests.py::test_unit_tests[batch_manager-80]": 1005.24,
|
||||
"cpp/test_unit_tests.py::test_unit_tests[common-80]": 38.98,
|
||||
"cpp/test_unit_tests.py::test_unit_tests[executor-80]": 425.16,
|
||||
"cpp/test_unit_tests.py::test_unit_tests[kernels-80]": 2009.96,
|
||||
"cpp/test_unit_tests.py::test_unit_tests[layers-80]": 2209.11,
|
||||
"cpp/test_unit_tests.py::test_unit_tests[runtime-80]": 1671.42,
|
||||
"cpp/test_unit_tests.py::test_unit_tests[thop-80]": 6.76,
|
||||
"cpp/test_unit_tests.py::test_unit_tests[utils-80]": 8.53,
|
||||
"cpp/test_e2e.py::test_model[-eagle-86]": 850.5158995762467,
|
||||
"cpp/test_e2e.py::test_model[-mamba-86]": 893.8684413917363,
|
||||
"cpp/test_e2e.py::test_model[-medusa-86]": 577.0913726426661,
|
||||
"cpp/test_e2e.py::test_model[-redrafter-86]": 356.56682327389717,
|
||||
"cpp/test_e2e.py::test_benchmarks[t5-90]": 244.83684724476188,
|
||||
"cpp/test_e2e.py::test_model[-enc_dec_language_adapter-90]": 356.5558080910705,
|
||||
"cpp/test_e2e.py::test_model[-t5-90]": 167.93334361724555,
|
||||
"cpp/test_e2e.py::test_model[fp8-llama-90]": 810.9923318810761,
|
||||
"cpp/test_unit_tests.py::test_unit_tests[batch_manager-90]": 230.49758478673175,
|
||||
"cpp/test_unit_tests.py::test_unit_tests[common-90]": 12.953252204693854,
|
||||
"cpp/test_unit_tests.py::test_unit_tests[executor-90]": 317.7621980938129,
|
||||
"cpp/test_unit_tests.py::test_unit_tests[kernels-90]": 554.3154804841615,
|
||||
"cpp/test_unit_tests.py::test_unit_tests[layers-90]": 1288.0563295381144,
|
||||
"cpp/test_unit_tests.py::test_unit_tests[runtime-90]": 876.2420815587975,
|
||||
"cpp/test_unit_tests.py::test_unit_tests[thop-90]": 2.2652571727521718,
|
||||
"cpp/test_unit_tests.py::test_unit_tests[utils-90]": 3.6415831856429577,
|
||||
"cpp/test_e2e.py::test_benchmarks[bart-90]": 271.95234084688127,
|
||||
"cpp/test_e2e.py::test_benchmarks[t5-90]": 523.07,
|
||||
"cpp/test_e2e.py::test_model[-bart-90]": 391.84748707409017,
|
||||
"cpp/test_e2e.py::test_model[-enc_dec_language_adapter-90]": 416.06,
|
||||
"cpp/test_e2e.py::test_model[-t5-90]": 170.26,
|
||||
"cpp/test_e2e.py::test_model[fp8-llama-90]": 385.98,
|
||||
"cpp/test_unit_tests.py::test_unit_tests[common-90]": 25.06,
|
||||
"cpp/test_unit_tests.py::test_unit_tests[kernels-90]": 1333.18,
|
||||
"cpp/test_unit_tests.py::test_unit_tests[layers-90]": 1627.07,
|
||||
"cpp/test_unit_tests.py::test_unit_tests[thop-90]": 4.16,
|
||||
"cpp/test_unit_tests.py::test_unit_tests[utils-90]": 5.28,
|
||||
"accuracy/test_cli_flow.py::TestLlama3_1_8B::test_fp8_rowwise_meta_recipe": 634.7149123200215,
|
||||
"examples/test_llama.py::test_llm_llama_v2_lora_1gpu[chinese-llama-2-lora-13b-llama-v2-13b-hf-lora_fp16-base_fp16]": 895.7611340929288,
|
||||
"examples/test_nemotron_nas.py::test_nemotron_nas_summary_1gpu[DeciLM-7B]": 335.41048416192643,
|
||||
"examples/test_phi.py::test_llm_phi_lora_1gpu[Phi-3-mini-4k-instruct-ru-lora-Phi-3-mini-4k-instruct-lora_fp16-base_fp16]": 217.61977925198153,
|
||||
"cpp/test_e2e.py::test_model[-gpt-80]": 2498.94351779297,
|
||||
"cpp/test_unit_tests.py::test_unit_tests[batch_manager-80]": 380.8567730002105,
|
||||
"cpp/test_unit_tests.py::test_unit_tests[common-80]": 20.237869411706924,
|
||||
"cpp/test_unit_tests.py::test_unit_tests[layers-80]": 2141.2598778679967,
|
||||
"cpp/test_unit_tests.py::test_unit_tests[runtime-80]": 1491.7047495394945,
|
||||
"cpp/test_unit_tests.py::test_unit_tests[thop-80]": 3.3458465598523617,
|
||||
"cpp/test_unit_tests.py::test_unit_tests[utils-80]": 5.461210697889328,
|
||||
"accuracy/test_cli_flow.py::TestLlama2_7B::test_fp8_gemm_plugin": 593.3573900908232,
|
||||
"examples/test_gemma.py::test_llm_gemma_1gpu_summary_vswa[gemma-3-1b-it-other-bfloat16-8]": 195.3050664511975,
|
||||
"test_e2e.py::test_llmapi_quickstart_atexit": 110.45052940770984,
|
||||
|
||||
@@ -1,4 +1,3 @@
import glob
import logging as _logger
import os as _os
import pathlib as _pl
@@ -19,8 +18,16 @@ from defs.conftest import llm_models_root


@pytest.fixture(scope="session")
def build_dir():
    return _cpp.find_build_dir()
def build_type():
    """CMake build type for C++ builds."""
    # For debugging purposes, we can use the RelWithDebInfo build type.
    return _os.environ.get("TLLM_BUILD_TYPE", "Release")


@pytest.fixture(scope="session")
def build_dir(build_type):
    """Resolved build directory for the current build_type."""
    return _cpp.find_build_dir(build_type)


@pytest.fixture(scope="session")
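With this fixture chain, selecting a non-default build only requires exporting TLLM_BUILD_TYPE before invoking pytest; build_dir then resolves against that build. A hedged usage sketch (editor-added; the resolved directory layout depends on get_trt_llm_build_dir, which is not shown in this diff):

import os
import subprocess

# Assumption for illustration: TLLM_BUILD_TYPE is read by the build_type
# fixture above, so any CMake build type it understands should work here.
env = {**os.environ, "TLLM_BUILD_TYPE": "RelWithDebInfo"}
subprocess.run(
    ["pytest", "cpp/test_unit_tests.py::test_unit_tests[kernels-90]"],
    env=env, check=False)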
@@ -148,44 +155,35 @@ def install_additional_requirements(python_exe, root_dir):


@pytest.fixture(scope="session")
def build_google_tests(request, build_dir):
def build_google_tests(request, build_type):

    cuda_arch = f"{request.param}-real"

    print(f"Using CUDA arch: {cuda_arch}")
    _logger.info(f"Using CUDA arch: {cuda_arch}")

    build_trt_llm(cuda_architectures=cuda_arch,
                  job_count=12,
                  use_ccache=True,
                  clean=True,
                  generator="Ninja",
                  trt_root="/usr/local/tensorrt",
                  nixl_root="/opt/nvidia/nvda_nixl",
                  skip_building_wheel=True)

    make_google_tests = [
        "cmake",
        "--build",
        ".",
        "--config",
        "Release",
        "-j",
        "--target",
        "google-tests",
    ]

    _cpp.run_command(make_google_tests, cwd=build_dir, timeout=300)
    build_trt_llm(
        build_type=build_type,
        cuda_architectures=cuda_arch,
        job_count=12,
        use_ccache=True,
        clean=True,
        generator="Ninja",
        trt_root="/usr/local/tensorrt",
        nixl_root="/opt/nvidia/nvda_nixl",
        skip_building_wheel=True,
        extra_make_targets=["google-tests"],
    )


@pytest.fixture(scope="session")
def build_benchmarks(build_google_tests, build_dir):
def build_benchmarks(build_google_tests, build_dir, build_type):

    make_benchmarks = [
        "cmake",
        "--build",
        ".",
        "--config",
        "Release",
        build_type,
        "-j",
        "--target",
        "benchmarks",
@@ -224,16 +222,18 @@ def prepare_model(


@pytest.fixture(scope="function", autouse=True)
def keep_log_files(llm_root):
    "Backup previous cpp test results when run multiple ctest"
    results_dir = f"{llm_root}/cpp/build"
def keep_log_files(build_dir):
    """Backup previous cpp test results when run multiple ctest invocations."""
    results_dir = build_dir

    yield

    backup_dir = f"{llm_root}/cpp/build_backup"
    _os.makedirs(backup_dir, exist_ok=True)
    # Copy XML files to backup directory
    xml_files = glob.glob(f"{results_dir}/*.xml")
    build_parent_dir = build_dir.parent
    backup_dir_name = build_dir.name + "_backup"
    backup_dir = build_parent_dir / backup_dir_name
    backup_dir.mkdir(parents=True, exist_ok=True)
    # Copy XML files from all subdirectories to backup directory
    xml_files = list(results_dir.rglob("*.xml"))
    if xml_files:
        for xml_file in xml_files:
            try:

@@ -56,8 +56,6 @@ def generate_result_file_name(test_list: List[str],


def generate_excluded_test_list(test_list):
    if "gpt" in test_list:
        if "gpt_session" not in test_list:
            yield "GptSession"
        if "gpt_executor" not in test_list:
            yield "GptExecutor"
        if "gpt_tests" not in test_list:
@@ -84,11 +82,18 @@ def find_root_dir(start_dir: Optional[_pl.Path] = None) -> _pl.Path:
    return find_dir_containing(("scripts", "examples", "cpp"), start_dir)


def find_build_dir():
    root_dir = find_root_dir()
    dir = get_trt_llm_build_dir(None, "Release")
def find_build_dir(build_type: str) -> _pl.Path:
    """Resolve the TRT-LLM C++ build directory for the given CMake build type.

    return dir if dir.is_absolute() else root_dir / dir
    Args:
        build_type: CMake build type (e.g., "Release", "RelWithDebInfo", "Debug").

    Returns:
        Absolute path to the C++ build directory.
    """
    root_dir = find_root_dir()
    build_dir = get_trt_llm_build_dir(None, build_type)
    return build_dir if build_dir.is_absolute() else root_dir / build_dir


def run_command(command: Sequence[str],
|
||||
|
||||
@ -14,6 +14,7 @@ def run_single_gpu_tests(build_dir: _pl.Path,
|
||||
timeout=3600):
|
||||
|
||||
cpp_env = {**_os.environ}
|
||||
tests_dir = build_dir / "tests" / "e2e_tests"
|
||||
|
||||
included_tests = list(_cpp.generate_included_model_tests(test_list))
|
||||
|
||||
@ -38,7 +39,7 @@ def run_single_gpu_tests(build_dir: _pl.Path,
|
||||
parallel = int(parallel_override)
|
||||
|
||||
_cpp.parallel_run_ctest(ctest,
|
||||
cwd=build_dir,
|
||||
cwd=tests_dir,
|
||||
env=cpp_env,
|
||||
timeout=timeout,
|
||||
parallel=parallel)
|
||||
@ -50,12 +51,12 @@ def run_single_gpu_tests(build_dir: _pl.Path,
|
||||
global_commands=["mpirun", "--allow-run-as-root"],
|
||||
nranks=2,
|
||||
local_commands=[
|
||||
"tests/executor/disaggExecutorTest",
|
||||
"executor/disaggExecutorTest",
|
||||
"--gtest_filter=*GptSingleDeviceDisaggSymmetricExecutorTest*"
|
||||
],
|
||||
leader_commands=[f"--gtest_output=xml:{xml_output_file}"])
|
||||
_cpp.run_command(trt_model_test,
|
||||
cwd=build_dir,
|
||||
cwd=tests_dir,
|
||||
env=new_env,
|
||||
timeout=timeout)
|
||||
|
||||
@ -192,7 +193,7 @@ def run_benchmarks(
|
||||
def run_spec_dec_tests(build_dir: _pl.Path):
|
||||
xml_output_file = build_dir / "results-spec-dec-fast-logits.xml"
|
||||
cpp_env = {**_os.environ}
|
||||
tests_dir = build_dir / "tests"
|
||||
tests_dir = build_dir / "tests" / "e2e_tests"
|
||||
trt_model_test = _cpp.produce_mpirun_command(
|
||||
global_commands=["mpirun", "--allow-run-as-root"],
|
||||
nranks=3,
|
||||
|
||||
@ -50,7 +50,7 @@ def get_multi_gpu_env(kv_cache_type=KVCacheType.NONE, llama_multi_gpu=False):
|
||||
|
||||
def run_mpi_utils_tests(build_dir, timeout=300):
|
||||
|
||||
tests_dir = build_dir / "tests"
|
||||
tests_dir = build_dir / "tests" / "unit_tests" / "multi_gpu"
|
||||
mgpu_env = get_multi_gpu_env()
|
||||
|
||||
mpi_utils_test = [
|
||||
@ -68,7 +68,7 @@ def run_mpi_utils_tests(build_dir, timeout=300):
|
||||
|
||||
def run_gemm_allreduce_tests(build_dir, nprocs, timeout=300):
|
||||
|
||||
tests_dir = build_dir / "tests"
|
||||
tests_dir = build_dir / "tests" / "unit_tests" / "multi_gpu"
|
||||
mgpu_env = get_multi_gpu_env()
|
||||
|
||||
gemm_allreduce_test = [
|
||||
@ -76,7 +76,7 @@ def run_gemm_allreduce_tests(build_dir, nprocs, timeout=300):
|
||||
"-n",
|
||||
f"{nprocs}",
|
||||
"--allow-run-as-root",
|
||||
"unit_tests/kernels/gemmAllReduceTest",
|
||||
"kernels/gemmAllReduceTest",
|
||||
"--m=2032",
|
||||
"--n=8200",
|
||||
"--k=1024",
|
||||
@ -93,7 +93,7 @@ def run_cache_transceiver_tests(build_dir: _pl.Path,
|
||||
kv_cache_type=KVCacheType.MPI,
|
||||
timeout=600):
|
||||
|
||||
tests_dir = build_dir / "tests"
|
||||
tests_dir = build_dir / "tests" / "unit_tests" / "multi_gpu"
|
||||
mgpu_env = get_multi_gpu_env(kv_cache_type=kv_cache_type)
|
||||
|
||||
cache_trans_test = [
|
||||
@ -101,7 +101,7 @@ def run_cache_transceiver_tests(build_dir: _pl.Path,
|
||||
"-n",
|
||||
f"{nprocs}",
|
||||
"--allow-run-as-root",
|
||||
"batch_manager/cacheTransceiverTest",
|
||||
"cacheTransceiverTest",
|
||||
]
|
||||
_cpp.run_command(cache_trans_test,
|
||||
cwd=tests_dir,
|
||||
@ -117,7 +117,7 @@ def run_cache_transceiver_tests(build_dir: _pl.Path,
|
||||
"-n",
|
||||
"8",
|
||||
"--allow-run-as-root",
|
||||
"batch_manager/cacheTransceiverTest",
|
||||
"cacheTransceiverTest",
|
||||
]
|
||||
_cpp.run_command(cache_trans_test_8_proc,
|
||||
cwd=tests_dir,
|
||||
@@ -125,8 +125,26 @@ def run_cache_transceiver_tests(build_dir: _pl.Path,
                     timeout=600)


def run_user_buffer_tests(build_dir: _pl.Path, nprocs=2, timeout=300):
    tests_dir = build_dir / "tests" / "unit_tests" / "multi_gpu"
    mgpu_env = get_multi_gpu_env()

    user_buffer_test = [
        "mpirun",
        "-n",
        f"{nprocs}",
        "--allow-run-as-root",
        "userBufferTest",
    ]

    _cpp.run_command(user_buffer_test,
                     cwd=tests_dir,
                     env=mgpu_env,
                     timeout=timeout)

def run_llama_executor_leader_tests(build_dir: _pl.Path, timeout=1500):
|
||||
tests_dir = build_dir / "tests"
|
||||
tests_dir = build_dir / "tests" / "e2e_tests"
|
||||
|
||||
mgpu_env = get_multi_gpu_env(llama_multi_gpu=True)
|
||||
|
||||
@ -145,7 +163,7 @@ def run_llama_executor_leader_tests(build_dir: _pl.Path, timeout=1500):
|
||||
|
||||
|
||||
def run_llama_executor_orchestrator_tests(build_dir: _pl.Path, timeout=1500):
|
||||
tests_dir = build_dir / "tests"
|
||||
tests_dir = build_dir / "tests" / "e2e_tests"
|
||||
|
||||
mgpu_env = get_multi_gpu_env(llama_multi_gpu=True)
|
||||
|
||||
@ -160,7 +178,7 @@ def run_llama_executor_orchestrator_tests(build_dir: _pl.Path, timeout=1500):
|
||||
|
||||
|
||||
def run_llama_executor_logits_proc_tests(build_dir: _pl.Path, timeout=1500):
|
||||
tests_dir = build_dir / "tests"
|
||||
tests_dir = build_dir / "tests" / "e2e_tests"
|
||||
|
||||
mgpu_env = get_multi_gpu_env(llama_multi_gpu=True)
|
||||
|
||||
@ -187,7 +205,7 @@ def run_llama_executor_logits_proc_tests(build_dir: _pl.Path, timeout=1500):
|
||||
|
||||
|
||||
def run_llama_executor_guided_decoding_tests(build_dir: _pl.Path, timeout=1500):
|
||||
tests_dir = build_dir / "tests"
|
||||
tests_dir = build_dir / "tests" / "e2e_tests"
|
||||
|
||||
mgpu_env = get_multi_gpu_env(llama_multi_gpu=True)
|
||||
|
||||
@ -214,7 +232,7 @@ def run_llama_executor_guided_decoding_tests(build_dir: _pl.Path, timeout=1500):
|
||||
|
||||
|
||||
def run_enc_dec_multi_gpu_tests(build_dir: _pl.Path, timeout=1500):
|
||||
tests_dir = build_dir / "tests"
|
||||
tests_dir = build_dir / "tests" / "e2e_tests"
|
||||
cpp_env = {**_os.environ}
|
||||
|
||||
#EncDec test in leader mode
|
||||
@ -233,7 +251,7 @@ def run_enc_dec_multi_gpu_tests(build_dir: _pl.Path, timeout=1500):
|
||||
|
||||
def run_trt_gpt_model_real_decoder_multi_gpu_tests(build_dir: _pl.Path,
|
||||
timeout=1500):
|
||||
tests_dir = build_dir / "tests"
|
||||
tests_dir = build_dir / "tests" / "e2e_tests"
|
||||
cpp_env = {**_os.environ}
|
||||
|
||||
xml_output_file = build_dir / "results-multi-gpu-real-decoder.xml"
|
||||
@ -256,7 +274,7 @@ def run_disagg_symmetric_executor_tests(build_dir: _pl.Path,
|
||||
nprocs=2,
|
||||
kvcache_type=KVCacheType.MPI,
|
||||
timeout=1500):
|
||||
tests_dir = build_dir / "tests"
|
||||
tests_dir = build_dir / "tests" / "e2e_tests"
|
||||
|
||||
prefix = get_model_test_filter_prefix(model)
|
||||
|
||||
@ -285,7 +303,7 @@ def run_disagg_asymmetric_executor_tests(build_dir: _pl.Path,
|
||||
kvcache_type=KVCacheType.MPI,
|
||||
timeout=1500):
|
||||
|
||||
tests_dir = build_dir / "tests"
|
||||
tests_dir = build_dir / "tests" / "e2e_tests"
|
||||
|
||||
prefix = get_model_test_filter_prefix(model)
|
||||
|
||||
@ -314,7 +332,7 @@ def run_disagg_orchestrator_params_tests(build_dir: _pl.Path,
|
||||
kvcache_type=KVCacheType.MPI,
|
||||
timeout=1500):
|
||||
|
||||
tests_dir = build_dir / "tests"
|
||||
tests_dir = build_dir / "tests" / "e2e_tests"
|
||||
|
||||
prefix = get_model_test_filter_prefix(model)
|
||||
|
||||
@ -341,7 +359,7 @@ def run_disagg_spawn_orchestrator_tests(build_dir: _pl.Path,
|
||||
kvcache_type=False,
|
||||
timeout=1500):
|
||||
|
||||
tests_dir = build_dir / "tests"
|
||||
tests_dir = build_dir / "tests" / "e2e_tests"
|
||||
|
||||
prefix = get_model_test_filter_prefix(model)
|
||||
|
||||
@ -352,7 +370,7 @@ def run_disagg_spawn_orchestrator_tests(build_dir: _pl.Path,
|
||||
|
||||
comms = [
|
||||
"executor/disaggExecutorTest",
|
||||
f"--gtest_filter=*{prefix}*DisaaggSpawnOrchestrator*",
|
||||
f"--gtest_filter=*{prefix}*DisaggSpawnOrchestrator*",
|
||||
f"--gtest_output=xml:{xml_output_file}"
|
||||
]
|
||||
_cpp.run_command(comms, cwd=tests_dir, env=mgpu_env, timeout=timeout)
|
||||
@ -488,13 +506,21 @@ def test_fused_gemm_allreduce(build_google_tests, nprocs, build_dir):
|
||||
def test_cache_transceiver(build_google_tests, nprocs, kvcache_type, build_dir):
|
||||
|
||||
if platform.system() != "Windows":
|
||||
|
||||
run_cache_transceiver_tests(build_dir=build_dir,
|
||||
nprocs=nprocs,
|
||||
kv_cache_type=kvcache_type,
|
||||
timeout=600)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("build_google_tests", ["80", "86", "89", "90"],
|
||||
indirect=True)
|
||||
@pytest.mark.parametrize("nprocs", [2, 8], ids=["2proc", "8proc"])
|
||||
def test_user_buffer(build_google_tests, nprocs, build_dir):
|
||||
|
||||
if platform.system() != "Windows":
|
||||
run_user_buffer_tests(build_dir=build_dir, nprocs=nprocs, timeout=300)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("build_google_tests", ["80", "86", "89", "90"],
|
||||
indirect=True)
|
||||
@pytest.mark.parametrize("multi_gpu_model", ["t5"], indirect=True)
|
||||
|
||||
@ -38,13 +38,13 @@ l0_a30:
|
||||
- cpp/test_unit_tests.py::test_unit_tests[common-80]
|
||||
- cpp/test_unit_tests.py::test_unit_tests[executor-80]
|
||||
- cpp/test_unit_tests.py::test_unit_tests[kernels-80]
|
||||
- cpp/test_unit_tests.py::test_unit_tests[layers-80] TIMEOUT (90)
|
||||
- cpp/test_unit_tests.py::test_unit_tests[layers-80]
|
||||
- cpp/test_unit_tests.py::test_unit_tests[runtime-80]
|
||||
- cpp/test_unit_tests.py::test_unit_tests[thop-80]
|
||||
- cpp/test_unit_tests.py::test_unit_tests[utils-80]
|
||||
- cpp/test_e2e.py::test_model[-gpt-80] TIMEOUT (90)
|
||||
- cpp/test_e2e.py::test_model[-gpt_executor-80] TIMEOUT (90)
|
||||
- cpp/test_e2e.py::test_model[-gpt_executor-80]
|
||||
- cpp/test_e2e.py::test_model[-gpt_tests-80]
|
||||
- cpp/test_e2e.py::test_model[-gpt-80]
|
||||
- condition:
|
||||
ranges:
|
||||
system_gpu_count:
|
||||
|
||||
@ -69,6 +69,8 @@ l0_b200:
|
||||
- unittest/_torch/modeling -k "modeling_deepseek"
|
||||
- unittest/_torch/modeling -k "modeling_gpt_oss"
|
||||
- unittest/_torch/auto_deploy/unit/singlegpu -k "not test_trtllm_bench_backend_comparison"
|
||||
# ------------- AutoDeploy tests ---------------
|
||||
- accuracy/test_llm_api_autodeploy.py::TestLlama3_1_8B::test_auto_dtype
|
||||
- condition:
|
||||
ranges:
|
||||
system_gpu_count:
|
||||
|
||||
@ -89,3 +89,5 @@ l0_dgx_b200:
|
||||
- disaggregated/test_disaggregated.py::test_disaggregated_benchmark_on_diff_backends[DeepSeek-V3-Lite-fp8]
|
||||
- accuracy/test_disaggregated_serving.py::TestQwen3_8B::test_nixl_backend
|
||||
- accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_nixl_backend
|
||||
# ------------- AutoDeploy tests ---------------
|
||||
- accuracy/test_llm_api_autodeploy.py::TestLlama3_1_8B::test_auto_dtype
|
||||
|
||||
@ -61,6 +61,8 @@ l0_dgx_h100:
|
||||
- test_e2e.py::test_ptp_quickstart_advanced_bs1
|
||||
- test_e2e.py::test_ptp_quickstart_advanced_deepseek_v3_lite_4gpus_adp_balance[DeepSeek-V3-Lite-FP8-DeepSeek-V3-Lite/fp8]
|
||||
- unittest/_torch/modeling/test_modeling_pixtral.py::test_tensor_parallelism
|
||||
# ------------- AutoDeploy tests ---------------
|
||||
- accuracy/test_llm_api_autodeploy.py::TestLlama3_1_8B::test_auto_dtype
|
||||
- condition:
|
||||
ranges:
|
||||
system_gpu_count:
|
||||
@ -177,6 +179,7 @@ l0_dgx_h100:
|
||||
- cpp/test_multi_gpu.py::test_cache_transceiver[2proc-ucx_kvcache-90]
|
||||
- cpp/test_multi_gpu.py::test_cache_transceiver[8proc-mpi_kvcache-90]
|
||||
- cpp/test_multi_gpu.py::test_cache_transceiver[8proc-ucx_kvcache-90]
|
||||
- cpp/test_multi_gpu.py::test_user_buffer[2proc-90]
|
||||
- cpp/test_multi_gpu.py::test_enc_dec[t5-90]
|
||||
- cpp/test_multi_gpu.py::test_llama_executor[llama-orchestrator-90]
|
||||
- cpp/test_multi_gpu.py::test_llama_executor[llama-leader-90]
|
||||
@ -209,18 +212,3 @@ l0_dgx_h100:
|
||||
- cpp/test_multi_gpu.py::TestDisagg::test_spawn_orchestrator[llama-ucx_kvcache-90]
|
||||
- cpp/test_multi_gpu.py::TestDisagg::test_orchestrator_params[llama-nixl_kvcache-90] TIMEOUT (90)
|
||||
- cpp/test_multi_gpu.py::TestDisagg::test_spawn_orchestrator[llama-nixl_kvcache-90]
|
||||
- condition:
|
||||
ranges:
|
||||
system_gpu_count:
|
||||
gte: 4
|
||||
lte: 4
|
||||
wildcards:
|
||||
gpu:
|
||||
- '*h100*'
|
||||
linux_distribution_name: ubuntu*
|
||||
terms:
|
||||
stage: post_merge
|
||||
backend: triton
|
||||
auto_trigger: others
|
||||
tests:
|
||||
- triton_server/test_triton_llm.py::test_llmapi_backend[4-0-disableDecoupleMode-tensorrt_llm]
|
||||
|
||||
@ -34,6 +34,8 @@ l0_dgx_h200:
|
||||
- unittest/_torch/multi_gpu_modeling/test_llama4.py::test_llama4[pp1-ep1-disable_adp-enable_graph-tp8-trtllm-scout]
|
||||
- unittest/_torch/multi_gpu_modeling/test_llama4.py::test_llama4[pp1-ep4-enable_adp-enable_graph-tp8-trtllm-scout]
|
||||
- unittest/llmapi/test_llm_pytorch.py::test_nemotron_nas_lora
|
||||
# ------------- AutoDeploy tests ---------------
|
||||
- accuracy/test_llm_api_autodeploy.py::TestLlama3_1_8B::test_auto_dtype
|
||||
- condition:
|
||||
ranges:
|
||||
system_gpu_count:
|
||||
@ -166,3 +168,19 @@ l0_dgx_h200:
|
||||
- examples/test_enc_dec.py::test_llm_enc_dec_general[compare_hf-t5-small-float16-enable_gemm_plugin-enable_attention_plugin-disable_paged_kv_cache-tp:2-pp:2-nb:1-disable_fp8]
|
||||
- examples/test_gpt.py::test_llm_gpt2_next_prompt_tuning[use_py_session-tp2]
|
||||
- unittest/llmapi/apps/_test_openai_multi_gpu.py -m "part0"
|
||||
- condition:
|
||||
ranges:
|
||||
system_gpu_count:
|
||||
gte: 4
|
||||
lte: 4
|
||||
wildcards:
|
||||
gpu:
|
||||
- '*h200*'
|
||||
linux_distribution_name: ubuntu*
|
||||
cpu: x86_64
|
||||
terms:
|
||||
stage: post_merge
|
||||
backend: triton
|
||||
tests:
|
||||
# ------------- Triton tests ---------------
|
||||
- triton_server/test_triton_llm.py::test_llmapi_backend[4-0-disableDecoupleMode-tensorrt_llm]
|
||||
|
||||
@ -1,10 +1,12 @@
|
||||
# Don't add any tests here.
|
||||
# Copied from l0_b200.yml but not used in the pipeline now
|
||||
version: 0.0.1
|
||||
l0_gb200:
|
||||
- condition:
|
||||
ranges:
|
||||
system_gpu_count:
|
||||
gte: 4
|
||||
lte: 4
|
||||
gte: 1
|
||||
lte: 1
|
||||
wildcards:
|
||||
gpu:
|
||||
- '*gb200*'
|
||||
@ -14,33 +16,111 @@ l0_gb200:
|
||||
stage: pre_merge
|
||||
backend: pytorch
|
||||
tests:
|
||||
- accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_bfloat16_4gpus[tp4-attn_backend=TRTLLM-torch_compile=True]
|
||||
- accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_bfloat16_4gpus[tp4-attn_backend=FLASHINFER-torch_compile=False]
|
||||
- accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_bfloat16_4gpus[pp4-attn_backend=TRTLLM-torch_compile=False]
|
||||
- accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_fp8_4gpus[tp4-fp8kv=False-attn_backend=TRTLLM-torch_compile=True]
|
||||
- accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_fp8_4gpus[tp4-fp8kv=False-attn_backend=FLASHINFER-torch_compile=False]
|
||||
- accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_fp8_4gpus[pp4-fp8kv=True-attn_backend=TRTLLM-torch_compile=False]
|
||||
# ------------- PyTorch tests ---------------
|
||||
- accuracy/test_llm_api_pytorch.py::TestLlama3_1_8B::test_nvfp4
|
||||
- accuracy/test_llm_api_pytorch.py::TestLlama3_1_8B::test_nvfp4_streaming[stream_interval_4]
|
||||
- accuracy/test_llm_api_pytorch.py::TestLlama3_1_8B::test_nvfp4_streaming[stream_interval_64]
|
||||
- accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_bfloat16[mtp_nextn=0-attention_dp=True-cuda_graph=True-overlap_scheduler=True-torch_compile=False-enable_chunked_prefill=False]
|
||||
- accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_bfloat16[mtp_nextn=2-attention_dp=True-cuda_graph=True-overlap_scheduler=True-torch_compile=False-enable_chunked_prefill=False]
|
||||
- accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_bfloat16[mtp_nextn=2-attention_dp=True-cuda_graph=True-overlap_scheduler=True-torch_compile=False-enable_chunked_prefill=True]
|
||||
- accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_bfloat16_4gpus[tp4-mtp_nextn=2-attention_dp=True-cuda_graph=True-overlap_scheduler=True-torch_compile=True]
|
||||
- accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_bfloat16_4gpus[ep4-mtp_nextn=2-attention_dp=True-cuda_graph=True-overlap_scheduler=True-torch_compile=False]
|
||||
- accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_nvfp4_4gpus[moe_backend=CUTLASS-mtp_nextn=0-tp4-fp8kv=True-attention_dp=True-cuda_graph=True-overlap_scheduler=True-torch_compile=True]
|
||||
- accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_nvfp4_4gpus[moe_backend=CUTLASS-mtp_nextn=0-ep4-fp8kv=False-attention_dp=False-cuda_graph=False-overlap_scheduler=False-torch_compile=False]
|
||||
- accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_nvfp4_4gpus[moe_backend=CUTLASS-mtp_nextn=0-ep4-fp8kv=False-attention_dp=False-cuda_graph=False-overlap_scheduler=False-torch_compile=True]
|
||||
- accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_nvfp4_4gpus[moe_backend=CUTLASS-mtp_nextn=0-tp2pp2-fp8kv=True-attention_dp=True-cuda_graph=True-overlap_scheduler=True-torch_compile=False]
|
||||
- accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_nvfp4_4gpus[moe_backend=CUTLASS-mtp_nextn=0-tp2pp2-fp8kv=True-attention_dp=True-cuda_graph=True-overlap_scheduler=True-torch_compile=True]
|
||||
- accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_nvfp4_4gpus[moe_backend=CUTLASS-mtp_nextn=0-pp4-fp8kv=False-attention_dp=False-cuda_graph=False-overlap_scheduler=False-torch_compile=False]
|
||||
- accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_nvfp4_4gpus[moe_backend=CUTLASS-mtp_nextn=0-pp4-fp8kv=False-attention_dp=False-cuda_graph=False-overlap_scheduler=False-torch_compile=True]
|
||||
- accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_nvfp4_4gpus[moe_backend=CUTLASS-mtp_nextn=0-pp4-fp8kv=True-attention_dp=True-cuda_graph=True-overlap_scheduler=True-torch_compile=False]
|
||||
- accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_nvfp4_4gpus[moe_backend=CUTLASS-mtp_nextn=0-pp4-fp8kv=True-attention_dp=True-cuda_graph=True-overlap_scheduler=True-torch_compile=True]
|
||||
- accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_nvfp4_4gpus[moe_backend=TRTLLM-mtp_nextn=0-tp4-fp8kv=True-attention_dp=True-cuda_graph=True-overlap_scheduler=True-torch_compile=False]
|
||||
- accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_nvfp4_4gpus[moe_backend=TRTLLM-mtp_nextn=0-ep4-fp8kv=True-attention_dp=True-cuda_graph=True-overlap_scheduler=True-torch_compile=False]
|
||||
- accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_nvfp4[moe_backend=CUTLASS-mtp_nextn=0-fp8kv=True-attention_dp=False-cuda_graph=True-overlap_scheduler=True-torch_compile=True]
|
||||
- accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_nvfp4[moe_backend=CUTLASS-mtp_nextn=0-fp8kv=True-attention_dp=True-cuda_graph=True-overlap_scheduler=True-torch_compile=False]
|
||||
- accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_nvfp4[moe_backend=TRTLLM-mtp_nextn=0-fp8kv=True-attention_dp=False-cuda_graph=True-overlap_scheduler=True-torch_compile=False]
|
||||
- accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_nvfp4[moe_backend=CUTLASS-mtp_nextn=2-fp8kv=True-attention_dp=False-cuda_graph=True-overlap_scheduler=True-torch_compile=False]
|
||||
- accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_nvfp4[moe_backend=TRTLLM-mtp_nextn=2-fp8kv=True-attention_dp=False-cuda_graph=True-overlap_scheduler=True-torch_compile=False]
|
||||
- accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_no_kv_cache_reuse[quant_dtype=none-mtp_nextn=2-fp8kv=False-attention_dp=True-cuda_graph=True-overlap_scheduler=True]
|
||||
- accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_no_kv_cache_reuse[quant_dtype=nvfp4-mtp_nextn=0-fp8kv=True-attention_dp=True-cuda_graph=True-overlap_scheduler=True]
|
||||
- accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_chunked_prefill[quant_dtype=none-kv_cache_reuse=True-fp8kv=False-overlap_scheduler=True]
|
||||
- accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_chunked_prefill[quant_dtype=none-kv_cache_reuse=False-fp8kv=False-overlap_scheduler=True]
|
||||
- accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_chunked_prefill[quant_dtype=nvfp4-kv_cache_reuse=True-fp8kv=False-overlap_scheduler=True]
|
||||
- accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_chunked_prefill[quant_dtype=nvfp4-kv_cache_reuse=True-fp8kv=True-overlap_scheduler=True]
|
||||
- accuracy/test_llm_api_pytorch.py::TestQwen3_8B::test_w4a8_mxfp4[fp8-latency]
|
||||
- accuracy/test_llm_api_pytorch.py::TestQwen3_8B::test_w4a8_mxfp4[mxfp8-latency]
|
||||
- accuracy/test_llm_api_pytorch.py::TestQwen3_30B_A3B::test_nvfp4[latency_moe_cutlass-torch_compile=False]
|
||||
- accuracy/test_llm_api_pytorch.py::TestQwen3_30B_A3B::test_nvfp4[latency_moe_cutlass-torch_compile=True]
|
||||
- accuracy/test_llm_api_pytorch.py::TestQwen3_30B_A3B::test_nvfp4[latency_moe_trtllm-torch_compile=False]
|
||||
- accuracy/test_llm_api_pytorch.py::TestQwen3_30B_A3B::test_nvfp4[latency_moe_trtllm-torch_compile=True]
|
||||
- accuracy/test_llm_api_pytorch.py::TestQwen3_30B_A3B::test_w4a8_mxfp4[fp8-latency-CUTLASS]
|
||||
- accuracy/test_llm_api_pytorch.py::TestQwen3_30B_A3B::test_w4a8_mxfp4[fp8-latency-TRITON]
|
||||
- accuracy/test_llm_api_pytorch.py::TestQwen3_30B_A3B::test_w4a8_mxfp4[fp8-latency-TRTLLM]
|
||||
- accuracy/test_llm_api_pytorch.py::TestQwen3_30B_A3B::test_w4a8_mxfp4[mxfp8-latency-TRTLLM]
|
||||
- accuracy/test_llm_api_pytorch.py::TestQwen3_30B_A3B::test_w4a8_mxfp4[mxfp8-latency-CUTLASS]
|
||||
- accuracy/test_llm_api_pytorch.py::TestQwen3_30B_A3B::test_w4a16_mxfp4[latency-TRTLLM]
|
||||
- accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_1gpu[True-True-cutlass]
|
||||
- accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_1gpu[True-True-trtllm]
|
||||
- accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_1gpu[True-True-triton]
|
||||
- disaggregated/test_workers.py::test_workers_kv_cache_aware_router_eviction[TinyLlama-1.1B-Chat-v1.0] # nvbugs 5300551
|
||||
- test_e2e.py::test_ptp_quickstart_advanced[Llama3.1-8B-NVFP4-nvfp4-quantized/Meta-Llama-3.1-8B]
|
||||
- test_e2e.py::test_ptp_quickstart_advanced[Llama3.1-8B-FP8-llama-3.1-model/Llama-3.1-8B-Instruct-FP8]
|
||||
- test_e2e.py::test_ptp_quickstart_advanced_mtp[DeepSeek-V3-Lite-BF16-DeepSeek-V3-Lite/bf16]
|
||||
- test_e2e.py::test_ptp_quickstart_advanced_mixed_precision
|
||||
- test_e2e.py::test_ptp_quickstart_advanced_eagle3[Llama-3.1-8b-Instruct-llama-3.1-model/Llama-3.1-8B-Instruct-EAGLE3-LLaMA3.1-Instruct-8B]
|
||||
- test_e2e.py::test_ptp_quickstart_advanced_ngram[Llama-3.1-8B-Instruct-llama-3.1-model/Llama-3.1-8B-Instruct]
|
||||
- test_e2e.py::test_trtllm_bench_pytorch_backend_sanity[meta-llama/Llama-3.1-8B-llama-3.1-8b-False-False]
|
||||
- unittest/_torch/attention
|
||||
- unittest/_torch/compilation
|
||||
- unittest/_torch/debugger
|
||||
- unittest/_torch/executor
|
||||
- unittest/_torch/misc
|
||||
- unittest/_torch/modules
|
||||
- unittest/_torch/multimodal
|
||||
- unittest/_torch/sampler
|
||||
- unittest/_torch/speculative
|
||||
- unittest/_torch/thop
|
||||
- unittest/_torch/modeling -k "modeling_llama"
|
||||
- unittest/_torch/modeling -k "modeling_mixtral"
|
||||
- unittest/_torch/modeling -k "modeling_deepseek"
|
||||
- unittest/_torch/modeling -k "modeling_gpt_oss"
|
||||
- unittest/_torch/auto_deploy/unit/singlegpu -k "not test_trtllm_bench_backend_comparison"
|
||||
- condition:
|
||||
ranges:
|
||||
system_gpu_count:
|
||||
gte: 4
|
||||
lte: 4
|
||||
gte: 1
|
||||
lte: 1
|
||||
wildcards:
|
||||
gpu:
|
||||
- '*gb200*'
|
||||
linux_distribution_name: ubuntu*
|
||||
cpu: aarch64
|
||||
terms:
|
||||
stage: post_merge
|
||||
backend: tensorrt
|
||||
tests:
|
||||
# ------------- TRT tests ---------------
|
||||
- accuracy/test_cli_flow.py::TestLlama3_8BInstruct::test_nvfp4
|
||||
- accuracy/test_cli_flow.py::TestLlama3_8BInstruct::test_nvfp4_gemm_plugin[disable_norm_quant_fusion-disable_fused_quant]
|
||||
- accuracy/test_cli_flow.py::TestLlama3_8BInstruct::test_nvfp4_gemm_plugin[disable_norm_quant_fusion-enable_fused_quant]
|
||||
- accuracy/test_cli_flow.py::TestLlama3_8BInstruct::test_nvfp4_gemm_plugin[enable_norm_quant_fusion-disable_fused_quant]
|
||||
- accuracy/test_cli_flow.py::TestLlama3_8BInstruct::test_nvfp4_gemm_plugin[enable_norm_quant_fusion-enable_fused_quant]
|
||||
- accuracy/test_cli_flow.py::TestLlama3_8BInstruct::test_auto_dtype
|
||||
- accuracy/test_cli_flow.py::TestLlama3_8BInstruct::test_fp8
|
||||
- accuracy/test_cli_flow.py::TestMixtral8x7B::test_nvfp4_prequantized
|
||||
- unittest/trt/attention/test_gpt_attention.py -k "trtllm_gen"
|
||||
- unittest/llmapi/test_llm_quant.py
|
||||
- unittest/trt/functional/test_fp4_gemm.py
|
||||
- condition:
|
||||
ranges:
|
||||
system_gpu_count:
|
||||
gte: 1
|
||||
lte: 1
|
||||
wildcards:
|
||||
gpu:
|
||||
- '*gb200*'
|
||||
linux_distribution_name: ubuntu*
|
||||
cpu: aarch64
|
||||
terms:
|
||||
stage: post_merge
|
||||
backend: triton
|
||||
tests:
|
||||
# ------------- Triton tests ---------------
|
||||
- triton_server/test_triton.py::test_llava[llava]
|
||||
- triton_server/test_triton.py::test_gpt_ib_ptuning[gpt-ib-ptuning]
|
||||
- triton_server/test_triton.py::test_gpt_2b_ib_lora[gpt-2b-ib-lora]
|
||||
- condition:
|
||||
ranges:
|
||||
system_gpu_count:
|
||||
gte: 1
|
||||
lte: 1
|
||||
wildcards:
|
||||
gpu:
|
||||
- '*gb200*'
|
||||
@ -50,23 +130,16 @@ l0_gb200:
|
||||
stage: post_merge
|
||||
backend: pytorch
|
||||
tests:
|
||||
- accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_fp8_4gpus[tp4-fp8kv=True-attn_backend=TRTLLM-torch_compile=False]
|
||||
- accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_fp8_4gpus[tp4-fp8kv=True-attn_backend=FLASHINFER-torch_compile=True]
|
||||
- accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_fp8_4gpus[pp4-fp8kv=False-attn_backend=TRTLLM-torch_compile=False]
|
||||
- accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_nvfp4_4gpus[moe_backend=CUTLASS-mtp_nextn=0-ep4-fp8kv=True-attention_dp=True-cuda_graph=True-overlap_scheduler=True-torch_compile=False]
|
||||
- accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_nvfp4_4gpus[moe_backend=CUTLASS-mtp_nextn=0-ep4-fp8kv=True-attention_dp=True-cuda_graph=True-overlap_scheduler=True-torch_compile=True]
|
||||
- accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_nvfp4_4gpus[moe_backend=CUTLASS-mtp_nextn=0-tp2pp2-fp8kv=False-attention_dp=False-cuda_graph=False-overlap_scheduler=False-torch_compile=False]
|
||||
- accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_nvfp4_4gpus[moe_backend=CUTLASS-mtp_nextn=0-tp2pp2-fp8kv=False-attention_dp=False-cuda_graph=False-overlap_scheduler=False-torch_compile=True]
|
||||
- accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_nvfp4_4gpus[moe_backend=CUTLASS-mtp_nextn=2-tp4-fp8kv=False-attention_dp=False-cuda_graph=False-overlap_scheduler=False-torch_compile=False]
|
||||
- accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_nvfp4_4gpus[moe_backend=CUTLASS-mtp_nextn=2-tp4-fp8kv=True-attention_dp=True-cuda_graph=True-overlap_scheduler=True-torch_compile=False]
|
||||
- accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_nvfp4_4gpus[moe_backend=CUTLASS-mtp_nextn=2-ep4-fp8kv=False-attention_dp=False-cuda_graph=False-overlap_scheduler=False-torch_compile=False]
|
||||
- accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_nvfp4_4gpus[moe_backend=CUTLASS-mtp_nextn=2-ep4-fp8kv=True-attention_dp=True-cuda_graph=True-overlap_scheduler=True-torch_compile=False]
|
||||
- accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_nvfp4_4gpus[moe_backend=CUTLASS-mtp_nextn=2-tp2pp2-fp8kv=False-attention_dp=False-cuda_graph=False-overlap_scheduler=False-torch_compile=False]
|
||||
- accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_nvfp4_4gpus[moe_backend=CUTLASS-mtp_nextn=2-tp2pp2-fp8kv=True-attention_dp=True-cuda_graph=True-overlap_scheduler=True-torch_compile=False]
|
||||
- accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_nvfp4_4gpus[moe_backend=CUTLASS-mtp_nextn=2-pp4-fp8kv=False-attention_dp=False-cuda_graph=False-overlap_scheduler=False-torch_compile=False]
|
||||
- accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_nvfp4_4gpus[moe_backend=CUTLASS-mtp_nextn=2-pp4-fp8kv=True-attention_dp=True-cuda_graph=True-overlap_scheduler=True-torch_compile=False]
|
||||
- accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_nvfp4_4gpus[moe_backend=TRTLLM-mtp_nextn=2-tp4-fp8kv=True-attention_dp=True-cuda_graph=True-overlap_scheduler=True-torch_compile=False]
|
||||
- accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_nvfp4_4gpus[moe_backend=TRTLLM-mtp_nextn=2-ep4-fp8kv=True-attention_dp=True-cuda_graph=True-overlap_scheduler=True-torch_compile=False]
|
||||
- accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_nvfp4_4gpus_online_eplb[fp8kv=True]
|
||||
- accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_bfloat16_4gpus_online_eplb[mtp_nextn=2]
|
||||
- accuracy/test_llm_api_pytorch.py::TestQwen3_235B_A22B::test_nvfp4_4gpus[latency_moe_trtllm_eagle3] TIMEOUT (90)
|
||||
# ------------- PyTorch tests ---------------
|
||||
- accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_nvfp4[moe_backend=CUTLASS-mtp_nextn=0-fp8kv=False-attention_dp=False-cuda_graph=False-overlap_scheduler=False-torch_compile=False]
|
||||
- accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_nvfp4[moe_backend=CUTLASS-mtp_nextn=0-fp8kv=False-attention_dp=False-cuda_graph=False-overlap_scheduler=False-torch_compile=True]
|
||||
- accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_nvfp4[moe_backend=CUTLASS-mtp_nextn=0-fp8kv=True-attention_dp=False-cuda_graph=False-overlap_scheduler=False-torch_compile=False]
|
||||
- accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_nvfp4[moe_backend=CUTLASS-mtp_nextn=2-fp8kv=False-attention_dp=False-cuda_graph=False-overlap_scheduler=False-torch_compile=False]
|
||||
- accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_nvfp4[moe_backend=CUTLASS-mtp_nextn=2-fp8kv=True-attention_dp=False-cuda_graph=False-overlap_scheduler=False-torch_compile=False]
|
||||
- accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_nvfp4[moe_backend=CUTLASS-mtp_nextn=0-fp8kv=False-attention_dp=True-cuda_graph=False-overlap_scheduler=False-torch_compile=False]
|
||||
- accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_nvfp4[moe_backend=CUTLASS-mtp_nextn=0-fp8kv=False-attention_dp=True-cuda_graph=False-overlap_scheduler=False-torch_compile=True]
|
||||
- accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_nvfp4[moe_backend=CUTLASS-mtp_nextn=0-fp8kv=False-attention_dp=False-cuda_graph=False-overlap_scheduler=True-torch_compile=False]
|
||||
- accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_nvfp4[moe_backend=CUTLASS-mtp_nextn=0-fp8kv=False-attention_dp=False-cuda_graph=True-overlap_scheduler=False-torch_compile=False]
|
||||
- accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_nvfp4[moe_backend=TRTLLM-mtp_nextn=0-fp8kv=False-attention_dp=False-cuda_graph=True-overlap_scheduler=False-torch_compile=False]
|
||||
- accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_nvfp4[moe_backend=CUTLASS-mtp_nextn=2-fp8kv=False-attention_dp=False-cuda_graph=True-overlap_scheduler=False-torch_compile=False]
|
||||
- accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_nvfp4[moe_backend=TRTLLM-mtp_nextn=2-fp8kv=False-attention_dp=False-cuda_graph=True-overlap_scheduler=False-torch_compile=False]
|
||||
|
||||
69
tests/integration/test_lists/test-db/l0_gb200_multi_gpus.yml
Normal file
69
tests/integration/test_lists/test-db/l0_gb200_multi_gpus.yml
Normal file
@ -0,0 +1,69 @@
version: 0.0.1
l0_gb200_multi_gpus:
- condition:
    ranges:
      system_gpu_count:
        gte: 4
        lte: 4
    wildcards:
      gpu:
      - '*gb200*'
      linux_distribution_name: ubuntu*
      cpu: aarch64
    terms:
      stage: pre_merge
      backend: pytorch
  tests:
  - accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_bfloat16_4gpus[tp4-attn_backend=TRTLLM-torch_compile=True]
  - accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_bfloat16_4gpus[tp4-attn_backend=FLASHINFER-torch_compile=False]
  - accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_bfloat16_4gpus[pp4-attn_backend=TRTLLM-torch_compile=False]
  - accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_fp8_4gpus[tp4-fp8kv=False-attn_backend=TRTLLM-torch_compile=True]
  - accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_fp8_4gpus[tp4-fp8kv=False-attn_backend=FLASHINFER-torch_compile=False]
  - accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_fp8_4gpus[pp4-fp8kv=True-attn_backend=TRTLLM-torch_compile=False]
  - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_bfloat16_4gpus[tp4-mtp_nextn=2-attention_dp=True-cuda_graph=True-overlap_scheduler=True-torch_compile=True]
  - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_bfloat16_4gpus[ep4-mtp_nextn=2-attention_dp=True-cuda_graph=True-overlap_scheduler=True-torch_compile=False]
  - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_nvfp4_4gpus[moe_backend=CUTLASS-mtp_nextn=0-tp4-fp8kv=True-attention_dp=True-cuda_graph=True-overlap_scheduler=True-torch_compile=True]
  - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_nvfp4_4gpus[moe_backend=CUTLASS-mtp_nextn=0-ep4-fp8kv=False-attention_dp=False-cuda_graph=False-overlap_scheduler=False-torch_compile=False]
  - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_nvfp4_4gpus[moe_backend=CUTLASS-mtp_nextn=0-ep4-fp8kv=False-attention_dp=False-cuda_graph=False-overlap_scheduler=False-torch_compile=True]
  - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_nvfp4_4gpus[moe_backend=CUTLASS-mtp_nextn=0-tp2pp2-fp8kv=True-attention_dp=True-cuda_graph=True-overlap_scheduler=True-torch_compile=False]
  - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_nvfp4_4gpus[moe_backend=CUTLASS-mtp_nextn=0-tp2pp2-fp8kv=True-attention_dp=True-cuda_graph=True-overlap_scheduler=True-torch_compile=True]
  - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_nvfp4_4gpus[moe_backend=CUTLASS-mtp_nextn=0-pp4-fp8kv=False-attention_dp=False-cuda_graph=False-overlap_scheduler=False-torch_compile=False]
  - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_nvfp4_4gpus[moe_backend=CUTLASS-mtp_nextn=0-pp4-fp8kv=False-attention_dp=False-cuda_graph=False-overlap_scheduler=False-torch_compile=True]
  - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_nvfp4_4gpus[moe_backend=CUTLASS-mtp_nextn=0-pp4-fp8kv=True-attention_dp=True-cuda_graph=True-overlap_scheduler=True-torch_compile=False]
  - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_nvfp4_4gpus[moe_backend=CUTLASS-mtp_nextn=0-pp4-fp8kv=True-attention_dp=True-cuda_graph=True-overlap_scheduler=True-torch_compile=True]
  - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_nvfp4_4gpus[moe_backend=TRTLLM-mtp_nextn=0-tp4-fp8kv=True-attention_dp=True-cuda_graph=True-overlap_scheduler=True-torch_compile=False]
  - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_nvfp4_4gpus[moe_backend=TRTLLM-mtp_nextn=0-ep4-fp8kv=True-attention_dp=True-cuda_graph=True-overlap_scheduler=True-torch_compile=False]
- condition:
    ranges:
      system_gpu_count:
        gte: 4
        lte: 4
    wildcards:
      gpu:
      - '*gb200*'
      linux_distribution_name: ubuntu*
      cpu: aarch64
    terms:
      stage: post_merge
      backend: pytorch
  tests:
  - accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_fp8_4gpus[tp4-fp8kv=True-attn_backend=TRTLLM-torch_compile=False]
  - accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_fp8_4gpus[tp4-fp8kv=True-attn_backend=FLASHINFER-torch_compile=True]
  - accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_fp8_4gpus[pp4-fp8kv=False-attn_backend=TRTLLM-torch_compile=False]
  - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_nvfp4_4gpus[moe_backend=CUTLASS-mtp_nextn=0-ep4-fp8kv=True-attention_dp=True-cuda_graph=True-overlap_scheduler=True-torch_compile=False]
  - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_nvfp4_4gpus[moe_backend=CUTLASS-mtp_nextn=0-ep4-fp8kv=True-attention_dp=True-cuda_graph=True-overlap_scheduler=True-torch_compile=True]
  - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_nvfp4_4gpus[moe_backend=CUTLASS-mtp_nextn=0-tp2pp2-fp8kv=False-attention_dp=False-cuda_graph=False-overlap_scheduler=False-torch_compile=False]
  - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_nvfp4_4gpus[moe_backend=CUTLASS-mtp_nextn=0-tp2pp2-fp8kv=False-attention_dp=False-cuda_graph=False-overlap_scheduler=False-torch_compile=True]
  - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_nvfp4_4gpus[moe_backend=CUTLASS-mtp_nextn=2-tp4-fp8kv=False-attention_dp=False-cuda_graph=False-overlap_scheduler=False-torch_compile=False]
  - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_nvfp4_4gpus[moe_backend=CUTLASS-mtp_nextn=2-tp4-fp8kv=True-attention_dp=True-cuda_graph=True-overlap_scheduler=True-torch_compile=False]
  - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_nvfp4_4gpus[moe_backend=CUTLASS-mtp_nextn=2-ep4-fp8kv=False-attention_dp=False-cuda_graph=False-overlap_scheduler=False-torch_compile=False]
  - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_nvfp4_4gpus[moe_backend=CUTLASS-mtp_nextn=2-ep4-fp8kv=True-attention_dp=True-cuda_graph=True-overlap_scheduler=True-torch_compile=False]
  - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_nvfp4_4gpus[moe_backend=CUTLASS-mtp_nextn=2-tp2pp2-fp8kv=False-attention_dp=False-cuda_graph=False-overlap_scheduler=False-torch_compile=False]
  - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_nvfp4_4gpus[moe_backend=CUTLASS-mtp_nextn=2-tp2pp2-fp8kv=True-attention_dp=True-cuda_graph=True-overlap_scheduler=True-torch_compile=False]
  - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_nvfp4_4gpus[moe_backend=CUTLASS-mtp_nextn=2-pp4-fp8kv=False-attention_dp=False-cuda_graph=False-overlap_scheduler=False-torch_compile=False]
  - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_nvfp4_4gpus[moe_backend=CUTLASS-mtp_nextn=2-pp4-fp8kv=True-attention_dp=True-cuda_graph=True-overlap_scheduler=True-torch_compile=False]
  - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_nvfp4_4gpus[moe_backend=TRTLLM-mtp_nextn=2-tp4-fp8kv=True-attention_dp=True-cuda_graph=True-overlap_scheduler=True-torch_compile=False]
  - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_nvfp4_4gpus[moe_backend=TRTLLM-mtp_nextn=2-ep4-fp8kv=True-attention_dp=True-cuda_graph=True-overlap_scheduler=True-torch_compile=False]
  - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_nvfp4_4gpus_online_eplb[fp8kv=True]
  - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_bfloat16_4gpus_online_eplb[mtp_nextn=2]
  - accuracy/test_llm_api_pytorch.py::TestQwen3_235B_A22B::test_nvfp4_4gpus[latency_moe_trtllm_eagle3] TIMEOUT (90)
@ -102,6 +102,8 @@ l0_h100:
- test_e2e.py::test_trtllm_bench_request_rate_and_concurrency[enable_concurrency-enable_request_rate] # negative test
- test_e2e.py::test_trtllm_bench_help_sanity[meta-llama/Llama-3.1-8B]
- test_e2e.py::test_ptp_quickstart_multimodal[gemma-3-27b-it-gemma/gemma-3-27b-it-image-True]
# ------------- AutoDeploy tests ---------------
- accuracy/test_llm_api_autodeploy.py::TestLlama3_1_8B::test_auto_dtype
- condition:
ranges:
system_gpu_count:

@ -269,8 +269,6 @@ test_e2e.py::test_ptp_quickstart_multimodal[llava-v1.6-mistral-7b-llava-v1.6-mis
accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype[mtp_nextn=0-overlap_scheduler=True] SKIP (https://nvbugs/5433545)
examples/test_nemotron_nas.py::test_nemotron_nas_summary_1gpu[DeciLM-7B] SKIP (https://nvbugs/5444636)
accuracy/test_cli_flow.py::TestLongAlpaca7B::test_multiblock_aggressive SKIP (https://nvbugs/5444627)
accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_bfloat16_4gpus_online_eplb[mtp_nextn=2] SKIP (https://nvbugs/5444687)
accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_nvfp4_4gpus_online_eplb[fp8kv=True] SKIP (https://nvbugs/5444687)
examples/test_qwen2audio.py::test_llm_qwen2audio_single_gpu[qwen2_audio_7b_instruct] SKIP (https://nvbugs/5447530)
examples/test_nemotron_nas.py::test_nemotron_nas_summary_2gpu[DeciLM-7B] SKIP (https://nvbugs/5444636)
examples/test_multimodal.py::test_llm_multimodal_general[Qwen2-VL-7B-Instruct-pp:1-tp:1-float16-bs:1-cpp_e2e:False-nb:4] SKIP (https://nvbugs/5453709)

@ -325,3 +323,4 @@ full:L40S/accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_
full:L40S/accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_eagle3[eagle3_one_model=False-overlap_scheduler=False] SKIP (https://nvbugs/5471106)
full:L40S/accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_tp_pp_symmetric[MMLU-tp2pp2] SKIP (https://nvbugs/5471108)
test_e2e.py::test_multi_nodes_eval[llama4-models/nvidia/Llama-4-Maverick-17B-128E-Instruct-FP8-tp8pp2-mmlu] SKIP (https://nvbugs/5473781)
accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_nvfp4_4gpus[moe_backend=CUTLASS-mtp_nextn=0-tp4-fp8kv=True-attention_dp=True-cuda_graph=True-overlap_scheduler=True-torch_compile=True] SKIP (https://nvbugs/5476580)
@ -52,10 +52,6 @@ class TestMoeAlltoAllSingleGPU(unittest.TestCase):
vector_dim,
dtype=dtype,
device=torch.device('cuda'))
output_tensor = torch.zeros(output_entry_count,
vector_dim,
dtype=dtype,
device=torch.device('cuda'))

send_cumsum = torch.ones(
(1, ), dtype=torch.int32,
@ -78,13 +74,18 @@ class TestMoeAlltoAllSingleGPU(unittest.TestCase):

workspace_size = torch.ops.trtllm.get_moe_commworkspace_size_per_rank(1)
all_workspaces = torch.zeros(1,
workspace_size,
workspace_size // 8,
dtype=torch.uint64,
device=torch.device('cuda'))
torch.ops.trtllm.moe_initialize_workspace(all_workspaces, 0, 1)

torch.ops.trtllm.moe_comm(input_tensor, send_cumsum, send_indices,
output_tensor, recv_cumsum, recv_indices,
all_workspaces, 0, 1)
output_tensors = torch.ops.trtllm.moe_comm([input_tensor], send_cumsum,
send_indices, recv_cumsum,
recv_indices, all_workspaces,
output_entry_count, 0, 1,
[True])

output_tensor = output_tensors[0]

torch.testing.assert_close(output_tensor,
ref_output_tensor,
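A note on the call shape introduced by this hunk, for readers skimming the diff: the workspace is now allocated as 64-bit words (hence the workspace_size // 8 for the uint64 tensor, assuming workspace_size is a byte count) and moe_comm takes a list of input tensors and returns a list of outputs. The sketch below only illustrates the argument order as it appears in the updated test; the index/cumsum values are toy placeholders, not a validated configuration, and the meaning of the trailing [True] list is taken from the test rather than from op documentation.

# Illustrative single-rank sketch of the list-based moe_comm API (assumptions noted above).
import torch
import tensorrt_llm  # noqa: F401  -- importing registers the torch.ops.trtllm custom ops

entry_count, vector_dim = 16, 8
inp = torch.randn(entry_count, vector_dim, dtype=torch.float16, device='cuda')
send_cumsum = torch.tensor([entry_count], dtype=torch.int32, device='cuda')
send_indices = torch.arange(entry_count, dtype=torch.int32, device='cuda')
recv_cumsum = send_cumsum.clone()
recv_indices = send_indices.clone()

workspace_size = torch.ops.trtllm.get_moe_commworkspace_size_per_rank(1)
# uint64 workspace: assumed to hold workspace_size // 8 eight-byte words
all_workspaces = torch.zeros(1, workspace_size // 8, dtype=torch.uint64, device='cuda')
torch.ops.trtllm.moe_initialize_workspace(all_workspaces, 0, 1)

# List-based op: a list of inputs in, a list of outputs back (single rank, rank 0 of 1)
outputs = torch.ops.trtllm.moe_comm([inp], send_cumsum, send_indices, recv_cumsum,
                                    recv_indices, all_workspaces, entry_count, 0, 1,
                                    [True])
out = outputs[0]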
@ -103,40 +104,43 @@ class TestMoeAlltoAllSingleGPU(unittest.TestCase):
send_indices = torch.zeros(1,
dtype=torch.int32,
device=torch.device('cuda'))
output_tensor = torch.zeros(1,
8,
dtype=torch.float16,
device=torch.device('cuda'))
recv_cumsum = torch.ones(1,
dtype=torch.int32,
device=torch.device('cuda'))
recv_indices = torch.zeros(1,
dtype=torch.int32,
device=torch.device('cuda'))
input_tensors = [input_tensor]
workspace_size = torch.ops.trtllm.get_moe_commworkspace_size_per_rank(1)
all_workspaces = torch.zeros(1,
workspace_size,
workspace_size // 8,
dtype=torch.uint64,
device=torch.device('cuda'))
torch.ops.trtllm.moe_comm(input_tensor, send_cumsum, send_indices,
output_tensor, recv_cumsum, recv_indices,
all_workspaces, 0, 1)
_ = torch.ops.trtllm.moe_comm(input_tensors, send_cumsum, send_indices,
recv_cumsum, recv_indices, all_workspaces,
1, 0, 1, [True])
torch.cuda.synchronize()

@parameterized.expand([
(2, 5, 8, torch.float16), # small input as smoke test
(2, 1, 8, torch.float16), # some ranks have no data to send/recv
(4, 5, 8, torch.float16), # small input with larger world size
(4, 901, 32768, torch.bfloat16), # large input that reuses workspace
(8, 901, 32768,
(2, 5, [4, 4], torch.float16), # small input as smoke test
(2, 1, [8], torch.float16), # some ranks have no data to send/recv
(4, 5, [8], torch.float16), # small input with larger world size
(4, 901, [1472, 46, 4,
4], torch.float16), # large input that reuses workspace
(4, 5, [2944], torch.bfloat16), # large input that reuses workspace
(8, 901, [
32768,
],
torch.float16), # large input that reuses workspace, larger world size
(
8, 16384, 128, torch.float16
8, 16384, [
128,
], torch.float16
), # large input count with small vector dim that requires more indices per fifo
])
def test_moe_alltoall_multi_rank_single_gpu(self, world_size,
input_entry_per_rank,
vector_dim, dtype):
vector_dims, dtype):
torch.cuda.set_device(0)
max_world_size = 8
assert world_size <= max_world_size, f"should run with world_size at most {max_world_size}"

@ -148,27 +152,32 @@ class TestMoeAlltoAllSingleGPU(unittest.TestCase):
torch.ops.trtllm.set_moe_max_usable_sm_count(max_sm_count)
has_setup_max_sm_count = True

# Create a random input tensor
input_tensor = torch.randn(input_entry_per_rank * world_size,
vector_dim,
dtype=dtype,
device=torch.device('cuda'))
output_tensor = torch.zeros(input_entry_per_rank * world_size,
vector_dim,
dtype=dtype,
device=torch.device('cuda'))
ref_output_tensor = torch.zeros(input_entry_per_rank * world_size,
vector_dim,
dtype=dtype,
device=torch.device('cuda'))
tensor_count = len(vector_dims)
input_tensors = []
ref_output_tensors = []
for vector_dim in vector_dims:
input_tensors.append(
torch.randn(input_entry_per_rank * world_size,
vector_dim,
dtype=dtype,
device=torch.device('cuda')))
ref_output_tensors.append(
torch.zeros(input_entry_per_rank * world_size,
vector_dim,
dtype=dtype,
device=torch.device('cuda')))

target_rank_ids = torch.randint(0,
world_size,
(input_entry_per_rank * world_size, ),
dtype=torch.int32,
device=torch.device('cuda'))

input_tensors_all_ranks = list(
torch.split(input_tensor, input_entry_per_rank))
input_tensors_all_ranks = []
for i in range(tensor_count):
input_tensors_all_ranks.append(
list(torch.split(input_tensors[i], input_entry_per_rank)))

target_rank_ids_all_ranks = list(
torch.split(target_rank_ids, input_entry_per_rank))

@ -210,12 +219,9 @@ class TestMoeAlltoAllSingleGPU(unittest.TestCase):
recv_ids_all_ranks = []
recv_cumsum_all_ranks = []

output_tensors_all_ranks = []

total_recv_all_ranks_cpu = []
output_indice_offset = 0

output_start_current_rank = 0
# each rank do compute based on other ranks' send counts to get how to receive data from other ranks.
for rank in range(world_size):
local_recv_counts = torch.zeros(world_size,

@ -227,18 +233,15 @@ class TestMoeAlltoAllSingleGPU(unittest.TestCase):
local_recv_count_pair = local_recv_counts[other_rank].cpu(
).item()
send_rank_start_end = send_start_end_all_ranks[other_rank][rank]
ref_output_tensor[output_indice_offset:output_indice_offset + local_recv_count_pair] = \
input_tensors_all_ranks[other_rank][send_ids_all_ranks[other_rank][send_rank_start_end[0]:send_rank_start_end[1]]]
for i in range(tensor_count):
ref_output_tensors[i][output_indice_offset:output_indice_offset + local_recv_count_pair] = \
input_tensors_all_ranks[i][other_rank][send_ids_all_ranks[other_rank][send_rank_start_end[0]:send_rank_start_end[1]]]
output_indice_offset += local_recv_count_pair
local_recv_cumsum = torch.cumsum(local_recv_counts,
dim=0).to(torch.int32)
recv_cumsum_all_ranks.append(local_recv_cumsum)
total_recv_count = local_recv_cumsum[-1].cpu()
total_recv_all_ranks_cpu.append(total_recv_count)
output_tensors_all_ranks.append(output_tensor[
output_start_current_rank:output_start_current_rank +
total_recv_count])
output_start_current_rank += total_recv_count
local_recv_ids = torch.arange(total_recv_count,
dtype=torch.int32,
device=torch.device('cuda'))

@ -251,9 +254,12 @@ class TestMoeAlltoAllSingleGPU(unittest.TestCase):
workspace_size = torch.ops.trtllm.get_moe_commworkspace_size_per_rank(
world_size)
all_workspaces = torch.zeros(world_size,
workspace_size,
workspace_size // 8,
dtype=torch.uint64,
device=torch.device('cuda'))
for i in range(world_size):
torch.ops.trtllm.moe_initialize_workspace(all_workspaces, i,
world_size)

# do one warmup for each rank to avoid possible synchronization at first launch.
for rank in range(world_size):
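The reference logic in this test derives each rank's receive-side metadata purely from the other ranks' send counts: count what every peer sends to this rank, take a cumulative sum to get per-peer offsets, and number the receive slots contiguously. The sketch below is a minimal, CPU-only restatement of that idea; the function name and the 2-rank toy matrix are hypothetical and not part of the test.

# Illustrative only: receive-side metadata from a send-count matrix (toy values).
import torch

def build_recv_metadata(send_counts_per_rank: torch.Tensor, my_rank: int):
    # send_counts_per_rank[i][j] = number of tokens rank i sends to rank j
    recv_counts = send_counts_per_rank[:, my_rank].to(torch.int32)   # tokens arriving from each peer
    recv_cumsum = torch.cumsum(recv_counts, dim=0).to(torch.int32)   # per-peer offsets into local buffer
    total_recv = int(recv_cumsum[-1])
    recv_indices = torch.arange(total_recv, dtype=torch.int32)       # contiguous local receive slots
    return recv_cumsum, recv_indices

counts = torch.tensor([[0, 2], [3, 0]])  # hypothetical 2-rank example
print(build_recv_metadata(counts, my_rank=0))  # recv_cumsum = [0, 3], indices = [0, 1, 2]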
@ -262,212 +268,141 @@ class TestMoeAlltoAllSingleGPU(unittest.TestCase):

torch.cuda.synchronize()

# Store output tensors from each rank
output_tensors_all_ranks = []

# do alltoall in parallel
for rank in range(world_size):
input_tensors_this_rank = [
input_tensors_all_ranks[i][rank] for i in range(tensor_count)
]
with torch.cuda.stream(cuda_streams_all_ranks[rank]):
torch.ops.trtllm.moe_comm(
input_tensors_all_ranks[rank], send_cumsum_all_ranks[rank],
send_ids_all_ranks[rank], output_tensors_all_ranks[rank],
recv_cumsum_all_ranks[rank], recv_ids_all_ranks[rank],
all_workspaces, rank, world_size)
output_tensors_this_rank = torch.ops.trtllm.moe_comm(
input_tensors_this_rank, send_cumsum_all_ranks[rank],
send_ids_all_ranks[rank], recv_cumsum_all_ranks[rank],
recv_ids_all_ranks[rank], all_workspaces,
input_entry_per_rank * world_size, rank, world_size)
output_tensors_all_ranks.append(output_tensors_this_rank)

for rank in range(world_size):
cuda_streams_all_ranks[rank].synchronize()

torch.testing.assert_close(output_tensor,
ref_output_tensor,
atol=1e-5,
rtol=1e-5)
# Reconstruct the full output tensors by concatenating results from all ranks
for i in range(tensor_count):
# Collect the actual received data from each rank (trim to actual recv count)
actual_output_parts = []
for rank in range(world_size):
total_recv_count = total_recv_all_ranks_cpu[rank].item()
# Each rank returns tensor with size [input_entry_per_rank * world_size, vector_dim]
# but only the first total_recv_count entries are valid
actual_output_parts.append(
output_tensors_all_ranks[rank][i][:total_recv_count])

@parameterized.expand([
(0, 8, 256, 4, 3, False),
(0, 8, 256, 4, 3, True),
(1, 8, 256, 4, 3, False),
(1, 8, 256, 4, 3, True),
(1, 4, 256, 8, 3, False),
(1, 4, 256, 8, 3, True),
(7, 8, 256, 8, 1025, False),
(7, 8, 256, 8, 1025, True),
(7, 64, 1024, 32, 1029, False),
(7, 64, 1024, 32, 1029, True),
])
def test_moe_alltoall_prepare_indices(
self, ep_rank: int, ep_size: int, expert_count: int, top_k: int,
max_token_count_per_rank: int,
use_real_rank_token_count_cumsum: bool):
# Concatenate all ranks' outputs to form the complete result
actual_output = torch.cat(actual_output_parts, dim=0)
torch.testing.assert_close(actual_output,
ref_output_tensors[i],
atol=1e-5,
rtol=1e-5)


class TestMoeAlltoAllFP8SingleGPU(unittest.TestCase):

def setUp(self):
torch.manual_seed(0x1234)
tllm.logger.set_level('error')

def test_moe_alltoall_fp8_with_indices(self):
"""Test fp8 alltoall with properly constructed indices"""
torch.cuda.set_device(0)
gathered_target_rank_ids = torch.randint(
0,
ep_size, (ep_size * max_token_count_per_rank, top_k),
dtype=torch.int32,
device=torch.device('cuda'))
real_rank_token_count_cumsum = None
if use_real_rank_token_count_cumsum:
real_rank_token_count_cumsum = torch.randint(
0,
max_token_count_per_rank + 1, (ep_size, ),
dtype=torch.int32,
device=torch.device('cuda'))
real_rank_token_count_cumsum = torch.cumsum(
real_rank_token_count_cumsum, dim=0).to(torch.int32)

def generate_references():
gathered_target_rank_ids_cpu_lists = gathered_target_rank_ids.cpu(
).tolist()
if use_real_rank_token_count_cumsum:
real_rank_token_count_cumsum_cpu_lists = real_rank_token_count_cumsum.cpu(
).tolist()
else:
real_rank_token_count_cumsum_cpu_lists = [
(i + 1) * max_token_count_per_rank for i in range(ep_size)
]
rank_token_start = 0
ref_local_gather_indices_cpu_lists = []
ref_recv_rank_count_cumsum_cpu_lists = [0] * ep_size
ref_recv_rank_local_indices_cpu_lists = []
ref_send_rank_count_cumsum_cpu_lists = [0] * ep_size
ref_send_rank_local_indices_cpu_lists = []
ref_backward_recv_rank_local_indices_cpu_lists = []
total_recv_count = 0
for rank in range(ep_size):
rank_token_end = real_rank_token_count_cumsum_cpu_lists[rank]
for token_id in range(rank_token_start, rank_token_end):
if ep_rank in gathered_target_rank_ids_cpu_lists[token_id]:
ref_local_gather_indices_cpu_lists.append(token_id)
ref_recv_rank_local_indices_cpu_lists.append(
total_recv_count)
total_recv_count += 1
ref_recv_rank_count_cumsum_cpu_lists[rank] = total_recv_count
if rank == ep_rank:
total_send_count = 0
for target_rank in range(ep_size):
for token_id in range(rank_token_start, rank_token_end):
local_token_id = token_id - rank_token_start
if target_rank in gathered_target_rank_ids_cpu_lists[
token_id]:
pos = gathered_target_rank_ids_cpu_lists[
token_id].index(target_rank)
ref_send_rank_local_indices_cpu_lists.append(
local_token_id)
ref_backward_recv_rank_local_indices_cpu_lists.append(
local_token_id * top_k + pos)
total_send_count += 1
ref_send_rank_count_cumsum_cpu_lists[
target_rank] = total_send_count
rank_token_start = rank_token_end
ref_local_gather_indices = torch.IntTensor(
ref_local_gather_indices_cpu_lists).cuda()
ref_send_rank_count_cumsum = torch.IntTensor(
ref_send_rank_count_cumsum_cpu_lists).cuda()
ref_send_rank_local_indices = torch.IntTensor(
ref_send_rank_local_indices_cpu_lists).cuda()
ref_recv_rank_count_cumsum = torch.IntTensor(
ref_recv_rank_count_cumsum_cpu_lists).cuda()
ref_recv_rank_local_indices = torch.IntTensor(
ref_recv_rank_local_indices_cpu_lists).cuda()
ref_backward_recv_rank_local_indices = torch.IntTensor(
ref_backward_recv_rank_local_indices_cpu_lists).cuda()
return ref_local_gather_indices, ref_send_rank_count_cumsum, ref_send_rank_local_indices, ref_recv_rank_count_cumsum, ref_recv_rank_local_indices, ref_backward_recv_rank_local_indices
# Match dimensions from the error
input_entry_count = 16384
output_entry_count = 16384
vector_dim = 2944
sf_vector_dim = 92 # Scaling factor dimension from error
send_recv_count = 1000 # Number of entries to send/receive

ref_local_gather_indices, ref_send_rank_count_cumsum, ref_send_rank_local_indices, ref_recv_rank_count_cumsum, ref_recv_rank_local_indices, ref_backward_recv_rank_local_indices = generate_references(
)
# Create input tensors - first as float16, then convert
input_tensor_fp16 = torch.randn(input_entry_count,
vector_dim,
dtype=torch.float16,
device='cuda')
input_tensor_fp8 = input_tensor_fp16.to(torch.float8_e4m3fn)

local_gather_indices, send_rank_count_cumsum, send_rank_local_indices, recv_rank_count_cumsum, recv_rank_local_indices, backward_recv_rank_local_indices = \
torch.ops.trtllm.moe_comm_prepare_indices(gathered_target_rank_ids, real_rank_token_count_cumsum, max_token_count_per_rank, expert_count, top_k, ep_rank, ep_size)
# Scaling factor tensor
input_sf_tensor = torch.randint(1,
255, (input_entry_count, sf_vector_dim),
dtype=torch.uint8,
device='cuda')

assert torch.equal(
local_gather_indices[:torch.numel(ref_local_gather_indices)],
ref_local_gather_indices)
assert torch.equal(
send_rank_count_cumsum[:torch.numel(ref_send_rank_count_cumsum)],
ref_send_rank_count_cumsum)
assert torch.equal(
send_rank_local_indices[:torch.numel(ref_send_rank_local_indices)],
ref_send_rank_local_indices)
assert torch.equal(
recv_rank_count_cumsum[:torch.numel(ref_recv_rank_count_cumsum)],
ref_recv_rank_count_cumsum)
assert torch.equal(
recv_rank_local_indices[:torch.numel(ref_recv_rank_local_indices)],
ref_recv_rank_local_indices)
assert torch.equal(
backward_recv_rank_local_indices[:torch.numel(
ref_backward_recv_rank_local_indices)],
ref_backward_recv_rank_local_indices)
# Expert selection tensors
input_experts = torch.randint(0,
64, (input_entry_count, 4),
dtype=torch.int32,
device='cuda')
input_scales = torch.rand(input_entry_count,
4,
dtype=torch.float32,
device='cuda')

@parameterized.expand([
(0, 8, 256, 4, 3),
(1, 8, 256, 4, 3),
(7, 8, 256, 4, 3),
(7, 8, 256, 8, 32),
(7, 8, 256, 32, 10),
(7, 8, 1024, 32, 127),
(7, 64, 1024, 32, 1029),
(9, 64, 1024, 3, 1029),
])
def test_moe_local_gather(self, ep_rank: int, ep_size: int,
expert_count: int, top_k: int,
max_token_count_per_rank: int):
torch.cuda.set_device(0)
rank_token_count_cumsum = torch.randint(0,
max_token_count_per_rank + 1,
(ep_size, ),
dtype=torch.int32,
device=torch.device('cuda'))
rank_token_count_cumsum = torch.cumsum(rank_token_count_cumsum,
dim=0).to(torch.int32)
local_token_count = rank_token_count_cumsum[ep_size - 1].cpu().item()
local_max_token_count = max_token_count_per_rank * ep_size
local_gather_indices = torch.randint(0,
max_token_count_per_rank * ep_size,
(local_max_token_count, ),
dtype=torch.int32,
device=torch.device('cuda'))
# Construct send/recv indices
send_cumsum = torch.tensor([send_recv_count],
dtype=torch.int32,
device='cuda')
recv_cumsum = torch.tensor([send_recv_count],
dtype=torch.int32,
device='cuda')

gathered_expert_ids = torch.randint(
0,
expert_count, (max_token_count_per_rank * ep_size, top_k),
dtype=torch.int32,
device=torch.device('cuda'))
gathered_scales = torch.rand(
(max_token_count_per_rank * ep_size, top_k),
dtype=torch.float32,
device=torch.device('cuda'))
# Random indices for sending
send_indices = torch.randperm(input_entry_count,
dtype=torch.int32,
device='cuda')[:send_recv_count]
recv_indices = torch.randperm(output_entry_count,
dtype=torch.int32,
device='cuda')[:send_recv_count]

ref_local_expert_ids = torch.zeros(local_max_token_count,
top_k,
dtype=torch.int32,
device=torch.device('cuda'))
ref_local_scales = torch.zeros(local_max_token_count,
top_k,
dtype=torch.float32,
device=torch.device('cuda'))
# Create workspace
workspace_size = torch.ops.trtllm.get_moe_commworkspace_size_per_rank(1)
all_workspaces = torch.zeros(1,
workspace_size // 8,
dtype=torch.uint64,
device='cuda')
torch.ops.trtllm.moe_initialize_workspace(all_workspaces, 0, 1)

# compute reference
ref_local_expert_ids += expert_count
valid_local_gather_indices = local_gather_indices[:local_token_count]
ref_local_expert_ids[:local_token_count] = gathered_expert_ids[
valid_local_gather_indices]
ref_local_scales[:local_token_count] = gathered_scales[
valid_local_gather_indices]
print(f"Test configuration:")
print(f" Input entries: {input_entry_count}")
print(f" Vector dim: {vector_dim}")
print(f" SF vector dim: {sf_vector_dim}")
print(f" Send/recv count: {send_recv_count}")
print(f" FP8 tensor shape: {input_tensor_fp8.shape}")
print(f" SF tensor shape: {input_sf_tensor.shape}")

local_expert_ids = torch.empty(local_max_token_count,
top_k,
dtype=torch.int32,
device=torch.device('cuda'))
local_scales = torch.empty(local_max_token_count,
top_k,
dtype=torch.float32,
device=torch.device('cuda'))
try:
# Test with all 4 tensors
output_tensor_fp8, output_sf_tensor, output_experts, output_scales = \
torch.ops.trtllm.moe_comm([
input_tensor_fp8, input_sf_tensor, input_experts, input_scales
], send_cumsum, send_indices, recv_cumsum, recv_indices, all_workspaces, output_entry_count, 0, 1)

torch.ops.trtllm.moe_local_gather(rank_token_count_cumsum,
local_gather_indices,
gathered_expert_ids, gathered_scales,
local_expert_ids, local_scales,
max_token_count_per_rank,
expert_count, top_k, ep_rank, ep_size)
torch.cuda.synchronize()
print("FP8 alltoall test PASSED!")

assert torch.equal(local_expert_ids, ref_local_expert_ids)
assert torch.equal(local_scales, ref_local_scales)
# Verify outputs
print(f"\nOutput verification:")
print(f" Output FP8 shape: {output_tensor_fp8.shape}")
print(f" Output SF shape: {output_sf_tensor.shape}")
print(
f" Non-zero FP8 elements: {(output_tensor_fp8 != 0).sum().item()}"
)
print(
f" Non-zero SF elements: {(output_sf_tensor != 0).sum().item()}"
)

except Exception as e:
print(f"FP8 alltoall test FAILED: {e}")
print(f"Error type: {type(e)}")
raise

@parameterized.expand([
(0, 2, 16, 20, 8, 512),
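The verification pattern in the hunk above is easy to miss inside the diff: each rank's returned tensor is over-allocated to input_entry_per_rank * world_size rows, so only the first total_recv_count rows are valid, and those slices are concatenated before comparison with the reference. The sketch below restates that pattern with toy shapes; the variable values are hypothetical and the final assert is shown only as a comment because the reference tensors exist only inside the test.

# Illustrative only: trim each rank's output to its valid row count, then concatenate.
import torch

world_size, tensor_count = 2, 1
total_recv_all_ranks = [3, 1]                      # valid rows per rank (toy values)
outputs_all_ranks = [[torch.arange(8.).reshape(4, 2)] for _ in range(world_size)]

for i in range(tensor_count):
    parts = [
        outputs_all_ranks[rank][i][:total_recv_all_ranks[rank]]
        for rank in range(world_size)
    ]
    actual = torch.cat(parts, dim=0)               # (sum of valid rows, vector_dim)
    # torch.testing.assert_close(actual, ref_output_tensors[i], atol=1e-5, rtol=1e-5)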
@ -489,7 +424,6 @@ class TestMoeAlltoAllSingleGPU(unittest.TestCase):
cpu_expert_ids_all_ranks_lists = []
cpu_token_count_lists = []
cpu_scales_all_ranks_lists = []
for _ in range(ep_size):
token_count = torch.randint(max_token_count_per_rank // 2,
max_token_count_per_rank + 1, (1, ),

@ -505,12 +439,6 @@ class TestMoeAlltoAllSingleGPU(unittest.TestCase):
dtype=torch.int32,
device=torch.device('cpu')))

cpu_scales_all_ranks_lists.append(
torch.zeros(token_count,
top_k,
dtype=torch.float32,
device=torch.device('cpu')) + 0.5)

cpu_token_count_lists.append(token_count)

def compute_target_rank(expert_id):

@ -519,7 +447,6 @@ class TestMoeAlltoAllSingleGPU(unittest.TestCase):

def generate_references():
ref_prepared_local_expert_ids = []
ref_prepared_local_scales = []
ref_local_send_rank_count_cumsum = [0] * ep_size
ref_local_recv_rank_count_cumsum = [0] * ep_size
ref_local_recv_rank_indices = []

@ -580,16 +507,13 @@ class TestMoeAlltoAllSingleGPU(unittest.TestCase):
for pos in range(top_k):
expert_id = int(
cpu_expert_ids_all_ranks_lists[rank][token_id][pos])
sf = cpu_scales_all_ranks_lists[rank][token_id][pos]
target_rank_id = compute_target_rank(expert_id)
if target_rank_id == ep_rank:
if not token_is_received:
token_is_received = True
ref_prepared_local_expert_ids.append(
[slot_count] * top_k)
ref_prepared_local_scales.append([0.0] * top_k)
ref_prepared_local_expert_ids[-1][pos] = expert_id
ref_prepared_local_scales[-1][pos] = sf
if token_is_received:
ref_local_recv_rank_indices.append(
total_recv_token_count)

@ -599,9 +523,9 @@ class TestMoeAlltoAllSingleGPU(unittest.TestCase):
rank] = current_recv_token_count if rank == 0 else ref_local_recv_rank_count_cumsum[
rank - 1] + current_recv_token_count

return ref_prepared_local_expert_ids, ref_prepared_local_scales, ref_local_send_rank_count_cumsum, ref_local_send_rank_indices, ref_local_recv_rank_count_cumsum, ref_local_recv_rank_indices, ref_local_backward_send_rank_indices, total_recv_token_count
return ref_prepared_local_expert_ids, ref_local_send_rank_count_cumsum, ref_local_send_rank_indices, ref_local_recv_rank_count_cumsum, ref_local_recv_rank_indices, ref_local_backward_send_rank_indices, total_recv_token_count

ref_prepared_local_expert_ids, ref_prepared_local_scales, ref_local_send_rank_count_cumsum, ref_local_send_rank_indices, ref_local_recv_rank_count_cumsum, ref_local_recv_rank_indices, ref_local_backward_send_rank_indices, total_recv_token_count = generate_references(
ref_prepared_local_expert_ids, ref_local_send_rank_count_cumsum, ref_local_send_rank_indices, ref_local_recv_rank_count_cumsum, ref_local_recv_rank_indices, ref_local_backward_send_rank_indices, total_recv_token_count = generate_references(
)

cpu_experter_count_lists = []

@ -615,10 +539,6 @@ class TestMoeAlltoAllSingleGPU(unittest.TestCase):
expert_ids_all_ranks = [
cpu_expert_ids_all_ranks_lists[i].cuda() for i in range(ep_size)
]
#scales_all_ranks = torch.FloatTensor(cpu_scales_all_ranks_lists).cuda()
scales_all_ranks = [
cpu_scales_all_ranks_lists[i].cuda() for i in range(ep_size)
]

experter_count_lists = [
cpu_experter_count_lists[i].cuda() for i in range(ep_size)

@ -637,30 +557,18 @@ class TestMoeAlltoAllSingleGPU(unittest.TestCase):
stream = torch.cuda.Stream()
with torch.cuda.stream(stream):
torch.ops.trtllm.mnnvl_moe_alltoallv_prepare_without_allgather(
expert_ids_all_ranks[0], scales_all_ranks[0],
experter_count_lists[0], all_workspaces,
max_token_count_per_rank, 0, 1, expert_count, slot_count, top_k)
expert_ids_all_ranks[0], experter_count_lists[0],
all_workspaces, max_token_count_per_rank, 0, 1, expert_count,
slot_count, top_k)
stream.wait_stream(torch.cuda.current_stream())

# Make torch alloc tensor to avoid cuda sync
prepared_local_experts = []
prepared_local_scales = []
local_send_rank_count_cumsum = []
local_send_rank_indices = []
local_recv_rank_count_cumsum = []
local_recv_rank_indices = []
backward_local_recv_rank_indices = []
for _ in range(ep_size):
prepared_local_experts.append(
torch.empty(max_token_count_per_rank * ep_size,
top_k,
dtype=torch.int32,
device=torch.device('cuda')))
prepared_local_scales.append(
torch.empty(max_token_count_per_rank * ep_size,
top_k,
dtype=torch.float32,
device=torch.device('cuda')))
local_send_rank_count_cumsum.append(
torch.empty(ep_size,
dtype=torch.int32,

@ -676,8 +584,6 @@ class TestMoeAlltoAllSingleGPU(unittest.TestCase):
backward_local_recv_rank_indices.append(
torch.empty(0, dtype=torch.int32, device=torch.device('cuda')))

prepared_local_experts = []
prepared_local_scales = []
local_send_rank_count_cumsum = []
local_send_rank_indices = []
local_recv_rank_count_cumsum = []

@ -694,35 +600,19 @@ class TestMoeAlltoAllSingleGPU(unittest.TestCase):
for rank in range(ep_size):
with torch.cuda.stream(cuda_streams_all_ranks[rank]):
if rank == ep_rank:
prepared_local_experts, prepared_local_scales, local_send_rank_count_cumsum, \
local_send_rank_count_cumsum, \
local_send_rank_indices, local_recv_rank_count_cumsum, local_recv_rank_indices, \
backward_local_recv_rank_indices, gathered_expert_statics\
= torch.ops.trtllm.mnnvl_moe_alltoallv_prepare_without_allgather(expert_ids_all_ranks[rank], scales_all_ranks[rank], experter_count_lists[rank], all_workspaces, max_token_count_per_rank,
= torch.ops.trtllm.mnnvl_moe_alltoallv_prepare_without_allgather(expert_ids_all_ranks[rank], experter_count_lists[rank], all_workspaces, max_token_count_per_rank,
rank, ep_size, expert_count, slot_count, top_k)
else:
torch.ops.trtllm.mnnvl_moe_alltoallv_prepare_without_allgather(
expert_ids_all_ranks[rank], scales_all_ranks[rank],
experter_count_lists[rank], all_workspaces,
max_token_count_per_rank, rank, ep_size, expert_count,
slot_count, top_k)
expert_ids_all_ranks[rank], experter_count_lists[rank],
all_workspaces, max_token_count_per_rank, rank, ep_size,
expert_count, slot_count, top_k)
for rank in range(ep_size):
cuda_streams_all_ranks[rank].synchronize()

prepared_local_experts_cpu = prepared_local_experts[:
total_recv_token_count].cpu(
)
prepared_local_scales_cpu = prepared_local_scales[:
total_recv_token_count].cpu(
)
for i in range(total_recv_token_count):
for j in range(top_k):
expert_id = int(prepared_local_experts_cpu[i][j])
assert 0 <= expert_id and expert_id <= slot_count
if expert_id < slot_count:
assert compute_target_rank(expert_id) == ep_rank
scale = float(prepared_local_scales_cpu[i][j])
assert scale > 1e-6

gathered_expert_statics_cpu = gathered_expert_statics.cpu()
for rank in range(ep_size):
for i in range(expert_count):
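The assertions above depend on the test's compute_target_rank helper, whose body falls outside the hunks shown here. Purely to illustrate the kind of expert-slot-to-rank mapping being checked, a block-wise mapping is sketched below; this is an assumption for illustration, not the test's actual definition, and the parameter defaults are hypothetical.

# Hypothetical expert-to-rank mapping, for illustration only.
def compute_target_rank_example(expert_id: int, ep_size: int = 4, slot_count: int = 16) -> int:
    # Contiguous blocks of expert slots per rank: slots [0, slot_count / ep_size) -> rank 0, and so on.
    return expert_id // (slot_count // ep_size)

assert compute_target_rank_example(0) == 0
assert compute_target_rank_example(15) == 3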