[TRTLLM-9493][feat] Custom AllToAll for helix parallelism (#9986)
Signed-off-by: Balaram Buddharaju <169953907+brb-nv@users.noreply.github.com>
parent 92d90fa29a
commit 8c1cfc872b
cpp/tensorrt_llm/kernels/helixAllToAll.cu (new file, 683 lines)
@@ -0,0 +1,683 @@
|
||||
/*
|
||||
* Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "tensorrt_llm/common/assert.h"
|
||||
#include "tensorrt_llm/common/cudaUtils.h"
|
||||
#include "tensorrt_llm/common/envUtils.h"
|
||||
#include "tensorrt_llm/kernels/cudaAsyncOps.cuh"
|
||||
#include "tensorrt_llm/kernels/helixAllToAll.h"
|
||||
#include "tensorrt_llm/kernels/ll128Proto.cuh"
|
||||
#include "tensorrt_llm/kernels/moeCommKernelsCommon.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <tuple>
|
||||
#include <unordered_map>
|
||||
|
||||
TRTLLM_NAMESPACE_BEGIN
|
||||
|
||||
namespace kernels
|
||||
{
|
||||
|
||||
namespace
|
||||
{
|
||||
|
||||
// ============================================================================
|
||||
// Structure declarations and definitions
|
||||
// ============================================================================
|
||||
|
||||
// ALIGN_256 is defined in moeCommKernelsCommon.h
|
||||
|
||||
struct ALIGN_256 HelixFifoInfo
|
||||
{
|
||||
volatile int64_t head;
|
||||
volatile int64_t tail;
|
||||
};
|
||||
|
||||
// ============================================================================
|
||||
// Helix-specific FIFO constants
|
||||
// Note: Helix uses 128KB FIFO entries vs 256KB in FusedMoe
|
||||
// ============================================================================
|
||||
|
||||
constexpr int HELIX_FIFO_DEPTH = 4;
|
||||
constexpr int HELIX_FIFO_ENTRY_BYTES = 128 * 1024;
|
||||
constexpr int HELIX_FIFO_TOTAL_BYTES = HELIX_FIFO_ENTRY_BYTES * HELIX_FIFO_DEPTH;
|
||||
constexpr int HELIX_FIFO_ENTRY_128B_COUNT = HELIX_FIFO_ENTRY_BYTES / BYTES_PER_128B_BLOCK;
|
||||
constexpr int HELIX_FIFO_TOTAL_U64 = HELIX_FIFO_TOTAL_BYTES / sizeof(uint64_t);
|
||||
|
||||
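For reference, the derived constants above work out as follows. The checks below are purely illustrative (they are not part of the original file) and assume BYTES_PER_128B_BLOCK is 128 bytes, as its name suggests.

// Illustrative sanity checks for the FIFO constants (not part of the original source):
static_assert(HELIX_FIFO_ENTRY_BYTES == 128 * 1024, "one FIFO entry is 128 KiB");
static_assert(HELIX_FIFO_TOTAL_BYTES == 512 * 1024, "4 entries give a 512 KiB FIFO per (peer, channel)");
static_assert(HELIX_FIFO_ENTRY_128B_COUNT == 1024, "1024 x 128-byte blocks per entry");
static_assert(HELIX_FIFO_TOTAL_U64 == 64 * 1024, "512 KiB expressed in uint64_t words");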
// ============================================================================
|
||||
// Implementation-only structures
|
||||
// ============================================================================
|
||||
|
||||
struct HelixPairInfo
|
||||
{
|
||||
int senderRank;
|
||||
int receiverRank;
|
||||
int channel;
|
||||
int runChannelCount;
|
||||
};
|
||||
|
||||
// WARP_SIZE, WARP_MASK, and other constants are defined in moeCommKernelsCommon.h
|
||||
|
||||
// ============================================================================
|
||||
// Helper Functions
|
||||
// ============================================================================
|
||||
|
||||
__host__ __device__ inline int getFieldSize(HelixFieldInfo const& fieldInfo)
|
||||
{
|
||||
return fieldInfo.elementCount * fieldInfo.elementSize;
|
||||
}
|
||||
|
||||
__host__ __device__ inline uint8_t* getPtr(HelixFieldInfo const& fieldInfo, int blockIdx)
|
||||
{
|
||||
return fieldInfo.dataPtr + blockIdx * fieldInfo.stride;
|
||||
}
|
||||
|
||||
__device__ __forceinline__ void waitG2sAllFields(uint64_t* smemBar, uint32_t* phaseParity)
|
||||
{
|
||||
cp_async_wait_group<0>();
|
||||
smemBarWait(smemBar, phaseParity);
|
||||
}
|
||||
|
||||
// Align size to 128 bytes
|
||||
__host__ __device__ __forceinline__ int align128(int size)
|
||||
{
|
||||
return align_up(size, BYTES_PER_128B_BLOCK);
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// G2S (Global to Shared) Operations
|
||||
// ============================================================================
|
||||
|
||||
__device__ __forceinline__ void g2sField(
|
||||
HelixFieldInfo const& fieldInfo, int dataIndex, uint8_t* shmemBase, int shmemOffset, uint64_t* smemBar, int laneId)
|
||||
{
|
||||
int copySize = getFieldSize(fieldInfo);
|
||||
if (copySize > 0 && laneId == 0)
|
||||
{
|
||||
uint8_t* srcPtr = getPtr(fieldInfo, dataIndex);
|
||||
uint8_t* dstPtr = shmemBase + shmemOffset;
|
||||
cp_async_bulk_g2s(dstPtr, srcPtr, copySize, smemBar);
|
||||
}
|
||||
}
|
||||
|
||||
template <bool ALLOW_VARIABLE_FIELD1>
|
||||
__device__ __forceinline__ int g2sAllFields(
|
||||
HelixFieldInfo const* fieldInfo, int dataIndex, uint8_t* shmemBase, uint64_t* smemBar, int laneId)
|
||||
{
|
||||
int totalSize = 0;
|
||||
|
||||
// Load field 0 (variable size half)
|
||||
g2sField(fieldInfo[0], dataIndex, shmemBase, 0, smemBar, laneId);
|
||||
int field0Size = getFieldSize(fieldInfo[0]);
|
||||
totalSize += field0Size;
|
||||
|
||||
// Load field 1 (single float2)
|
||||
if constexpr (ALLOW_VARIABLE_FIELD1)
|
||||
{
|
||||
g2sField(fieldInfo[1], dataIndex, shmemBase, totalSize, smemBar, laneId);
|
||||
totalSize += getFieldSize(fieldInfo[1]);
|
||||
}
|
||||
else
|
||||
{
|
||||
ldgsts<8>(reinterpret_cast<int*>(shmemBase + totalSize),
|
||||
reinterpret_cast<int const*>(getPtr(fieldInfo[1], dataIndex)), laneId == 0);
|
||||
cp_async_commit_group();
|
||||
}
|
||||
|
||||
return totalSize;
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// S2G (Shared to Global) Operations
|
||||
// ============================================================================
|
||||
|
||||
__device__ __forceinline__ void s2gField(
|
||||
HelixFieldInfo const& fieldInfo, int dataIndex, uint8_t* shmemBase, int shmemOffset, int laneId)
|
||||
{
|
||||
int copySize = getFieldSize(fieldInfo);
|
||||
if (copySize > 0 && laneId == 0)
|
||||
{
|
||||
uint8_t* srcPtr = shmemBase + shmemOffset;
|
||||
uint8_t* dstPtr = getPtr(fieldInfo, dataIndex);
|
||||
cp_async_bulk_s2g(dstPtr, srcPtr, copySize);
|
||||
}
|
||||
}
|
||||
|
||||
template <bool ALLOW_VARIABLE_FIELD1>
|
||||
__device__ __forceinline__ void s2gAllFields(
|
||||
HelixFieldInfo const* fieldInfo, int dataIndex, uint8_t* shmemBase, int laneId)
|
||||
{
|
||||
int offset = 0;
|
||||
|
||||
// Store field 0 (variable size half)
|
||||
s2gField(fieldInfo[0], dataIndex, shmemBase, offset, laneId);
|
||||
int field0Size = getFieldSize(fieldInfo[0]);
|
||||
offset += field0Size;
|
||||
|
||||
// Store field 1 (single float2)
|
||||
if constexpr (ALLOW_VARIABLE_FIELD1)
|
||||
{
|
||||
s2gField(fieldInfo[1], dataIndex, shmemBase, offset, laneId);
|
||||
offset += getFieldSize(fieldInfo[1]);
|
||||
}
|
||||
else
|
||||
{
|
||||
if (laneId == 0)
|
||||
{
|
||||
auto* srcPtr = reinterpret_cast<float2*>(reinterpret_cast<uint8_t*>(shmemBase) + offset);
|
||||
auto* dstPtr = reinterpret_cast<float2*>(getPtr(fieldInfo[1], dataIndex));
|
||||
dstPtr[0] = srcPtr[0];
|
||||
}
|
||||
}
|
||||
cp_async_bulk_commit_group();
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Workspace FIFO Operations
|
||||
// ============================================================================
|
||||
|
||||
__device__ __forceinline__ uint64_t* getFifoBasePtr(HelixAllToAllParams const& params, HelixPairInfo const& pairInfo)
|
||||
{
|
||||
// FIFO is physically located at receiver rank
|
||||
int mappedMemoryRank = pairInfo.receiverRank;
|
||||
int rankInsideMappedMemory = pairInfo.senderRank;
|
||||
|
||||
auto* mappedMemory = params.workspace + mappedMemoryRank * params.workspaceStrideInU64;
|
||||
// Navigate to the right FIFO: [peer_rank][channel]
|
||||
size_t fifoOffset = rankInsideMappedMemory * params.maxChannelCount * HELIX_FIFO_TOTAL_U64;
|
||||
fifoOffset += pairInfo.channel * HELIX_FIFO_TOTAL_U64;
|
||||
|
||||
return mappedMemory + fifoOffset;
|
||||
}
|
||||
|
||||
__device__ __forceinline__ HelixFifoInfo* getSenderHelixFifoInfo(
|
||||
HelixAllToAllParams const& params, HelixPairInfo const& pairInfo)
|
||||
{
|
||||
// SenderSideHelixFifoInfo is physically located at sender rank
|
||||
int mappedMemoryRank = pairInfo.senderRank;
|
||||
int rankInsideMappedMemory = pairInfo.receiverRank;
|
||||
|
||||
auto* mappedMemory = reinterpret_cast<uint8_t*>(params.workspace + mappedMemoryRank * params.workspaceStrideInU64);
|
||||
size_t fieldOffset = static_cast<size_t>(HELIX_FIFO_TOTAL_BYTES) * params.cpSize * params.maxChannelCount;
|
||||
mappedMemory += fieldOffset;
|
||||
mappedMemory += rankInsideMappedMemory * params.maxChannelCount * sizeof(HelixFifoInfo);
|
||||
mappedMemory += pairInfo.channel * sizeof(HelixFifoInfo);
|
||||
|
||||
return reinterpret_cast<HelixFifoInfo*>(mappedMemory);
|
||||
}
|
||||
|
||||
__device__ __forceinline__ HelixFifoInfo* getReceiverHelixFifoInfo(
|
||||
HelixAllToAllParams const& params, HelixPairInfo const& pairInfo)
|
||||
{
|
||||
// ReceiverSideHelixFifoInfo is physically located at receiver rank
|
||||
int mappedMemoryRank = pairInfo.receiverRank;
|
||||
int rankInsideMappedMemory = pairInfo.senderRank;
|
||||
|
||||
auto* mappedMemory = reinterpret_cast<uint8_t*>(params.workspace + mappedMemoryRank * params.workspaceStrideInU64);
|
||||
size_t fieldOffset = static_cast<size_t>(HELIX_FIFO_TOTAL_BYTES) * params.cpSize * params.maxChannelCount;
|
||||
fieldOffset += sizeof(HelixFifoInfo) * params.cpSize * params.maxChannelCount;
|
||||
mappedMemory += fieldOffset;
|
||||
mappedMemory += rankInsideMappedMemory * params.maxChannelCount * sizeof(HelixFifoInfo);
|
||||
mappedMemory += pairInfo.channel * sizeof(HelixFifoInfo);
|
||||
|
||||
return reinterpret_cast<HelixFifoInfo*>(mappedMemory);
|
||||
}
|
||||
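The three helpers above jointly define how each rank's slice of the symmetric workspace is carved up. The sketch below simply restates their offset arithmetic in one place; it is a reading aid, not code from the original file.

// Per-rank workspace layout implied by getFifoBasePtr / getSenderHelixFifoInfo /
// getReceiverHelixFifoInfo (illustrative summary, not part of the original source):
//
//   offset 0
//     FIFO data, laid out as [senderRank][channel] blocks of HELIX_FIFO_TOTAL_BYTES
//     (lives on the receiver rank; size = HELIX_FIFO_TOTAL_BYTES * cpSize * maxChannelCount)
//   + fifoBytes
//     sender-side HelixFifoInfo, laid out as [receiverRank][channel]
//     (lives on the sender rank; size = sizeof(HelixFifoInfo) * cpSize * maxChannelCount)
//   + senderInfoBytes
//     receiver-side HelixFifoInfo, laid out as [senderRank][channel]
//     (lives on the receiver rank; same size as the sender-side region)
//
// This matches the total returned by computeHelixWorkspaceSizePerRank() further below.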
|
||||
__device__ __forceinline__ void startWorkspaceS2G(
|
||||
uint64_t* fifoEntry, uint8_t* shmemBase, int send128ByteCount, int fifo128ByteOffset, int laneId)
|
||||
{
|
||||
int copyByteCount = send128ByteCount * BYTES_PER_128B_BLOCK;
|
||||
if (laneId == 0)
|
||||
{
|
||||
cp_async_bulk_s2g(
|
||||
fifoEntry + fifo128ByteOffset * BYTES_PER_128B_BLOCK / sizeof(uint64_t), shmemBase, copyByteCount);
|
||||
}
|
||||
cp_async_bulk_commit_group();
|
||||
}
|
||||
|
||||
__device__ __forceinline__ void startWorkspaceS2GReg(
|
||||
uint64_t* fifoEntry, uint8_t* sharedMemoryBase, int send128ByteCount, int fifo128ByteOffset, int laneId)
|
||||
{
|
||||
int copyInt4Count = send128ByteCount * BYTES_PER_128B_BLOCK / sizeof(int4);
|
||||
int4* sharedMemoryInt4 = reinterpret_cast<int4*>(sharedMemoryBase);
|
||||
uint64_t* fifoPtr = fifoEntry + fifo128ByteOffset * UINT64_PER_128B_BLOCK;
|
||||
int4* fifoPtrInt4 = reinterpret_cast<int4*>(fifoPtr);
|
||||
#pragma unroll 4
|
||||
for (int i = laneId; i < copyInt4Count; i += WARP_SIZE)
|
||||
{
|
||||
fifoPtrInt4[i] = sharedMemoryInt4[i];
|
||||
}
|
||||
}
|
||||
|
||||
__device__ __forceinline__ uint64_t startWorkspaceG2S(uint8_t* shmemBase, uint64_t* fifoEntry, int allLoad128ByteCount,
|
||||
int fifo128ByteOffset, int loaded128ByteCount, uint64_t* smemBar, int laneId)
|
||||
{
|
||||
int copyByteCount = (allLoad128ByteCount - loaded128ByteCount) * BYTES_PER_128B_BLOCK;
|
||||
if (laneId == 0)
|
||||
{
|
||||
cp_async_bulk_g2s(shmemBase + loaded128ByteCount * BYTES_PER_128B_BLOCK,
|
||||
fifoEntry + (fifo128ByteOffset + loaded128ByteCount) * UINT64_PER_128B_BLOCK, copyByteCount, smemBar);
|
||||
}
|
||||
return mbarrier_arrive_expect_tx(smemBar, laneId == 0 ? copyByteCount : 0);
|
||||
}
|
||||
|
||||
// LL128Proto is now defined in ll128Proto.cuh
|
||||
|
||||
// ============================================================================
|
||||
// Size helpers
|
||||
// ============================================================================
|
||||
|
||||
// Compute total size needed for both fields
|
||||
__host__ __device__ __forceinline__ int computeTotalUnpackedSize(HelixFieldInfo const* fields)
|
||||
{
|
||||
int size = 0;
|
||||
// Field 0: note it must be aligned to 16 bytes
|
||||
size += align_up(getFieldSize(fields[0]), 16);
|
||||
// Field 1: single float2
|
||||
size += align_up(getFieldSize(fields[1]), 16);
|
||||
return align128(size);
|
||||
}
|
||||
|
||||
__host__ __device__ __forceinline__ int computeTotalPackedSize(HelixFieldInfo const* fields)
|
||||
{
|
||||
// because field 0 must be aligned to 16 bytes, this is the same as unpacked
|
||||
return computeTotalUnpackedSize(fields);
|
||||
}
|
||||
|
||||
__host__ __device__ __forceinline__ int computeProtoTransferSize(HelixFieldInfo const* fields)
|
||||
{
|
||||
return LL128Proto::computeProtoTransfer128ByteAlignedSize(computeTotalPackedSize(fields));
|
||||
}
|
||||
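For intuition, here is a worked example of the size helpers for a hypothetical field layout (512 half elements for field 0, one float2 for field 1); the numbers assume BYTES_PER_128B_BLOCK is 128 bytes, as its name suggests.

// Worked example for the size helpers above (hypothetical field sizes, illustrative only):
//   field 0: 512 half elements -> 512 * 2 = 1024 bytes (already 16-byte aligned)
//   field 1: one float2        -> 8 bytes, padded to 16 bytes
//   computeTotalUnpackedSize  = align128(1024 + 16) = 1152 bytes (9 x 128-byte blocks)
//   computeTotalPackedSize    = 1152 bytes (identical, since field 0 is 16-byte aligned)
//   computeProtoTransferSize  = 1152 bytes plus the LL128 protocol overhead; the exact value
//                               is whatever LL128Proto::computeProtoTransfer128ByteAlignedSize returns.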
|
||||
// ============================================================================
|
||||
// Main All-to-All Kernel
|
||||
// ============================================================================
|
||||
|
||||
template <bool ALLOW_VARIABLE_FIELD1>
|
||||
__global__ void helixAllToAllKernel(HelixAllToAllParams params)
|
||||
{
|
||||
extern __shared__ uint8_t allWarpShmem[];
|
||||
__shared__ uint64_t allWarpSmemBar[MAX_GROUP_COUNT_PER_BLOCK];
|
||||
|
||||
bool isSender = (blockIdx.z == 0);
|
||||
// Each warp is a group handling a different peer rank
|
||||
int group = __shfl_sync(WARP_MASK, threadIdx.y, 0);
|
||||
int laneId = threadIdx.x % WARP_SIZE;
|
||||
int runChannelCount = gridDim.y;
|
||||
|
||||
// Compute peer rank: blockIdx.x determines which set of peers, group
|
||||
// determines which peer in that set
|
||||
int peerRank = blockIdx.x * blockDim.y + group;
|
||||
|
||||
if (peerRank >= params.cpSize)
|
||||
{
|
||||
return;
|
||||
}
|
||||
|
||||
// Setup pair info for this communication
|
||||
HelixPairInfo pairInfo;
|
||||
pairInfo.channel = blockIdx.y;
|
||||
pairInfo.runChannelCount = runChannelCount;
|
||||
pairInfo.senderRank = isSender ? params.cpRank : peerRank;
|
||||
pairInfo.receiverRank = isSender ? peerRank : params.cpRank;
|
||||
|
||||
// Initialize barrier for this group
|
||||
initSmemBar(&allWarpSmemBar[group], laneId);
|
||||
uint32_t phaseParity = 0;
|
||||
|
||||
// Get shared memory for this group
|
||||
int singlePackedSize = computeTotalPackedSize(params.sendFields);
|
||||
int singlePacked128ByteCount = singlePackedSize / BYTES_PER_128B_BLOCK;
|
||||
int singleUnpackedSize = computeTotalUnpackedSize(params.sendFields);
|
||||
int singleProtoTransferSize = computeProtoTransferSize(params.sendFields);
|
||||
int singleProtoTransfer128ByteCount = singleProtoTransferSize / BYTES_PER_128B_BLOCK;
|
||||
int singleShmSize = std::max(singleUnpackedSize, singleProtoTransferSize);
|
||||
uint8_t* shmem = allWarpShmem + group * singleShmSize;
|
||||
|
||||
// Get FIFO pointers
|
||||
uint64_t* fifoBase = getFifoBasePtr(params, pairInfo);
|
||||
HelixFifoInfo* senderFifo = getSenderHelixFifoInfo(params, pairInfo);
|
||||
HelixFifoInfo* receiverFifo = getReceiverHelixFifoInfo(params, pairInfo);
|
||||
|
||||
int fifoEntry128ByteIndexBase = HELIX_FIFO_ENTRY_128B_COUNT;
|
||||
int fifoEntryIndex = -1;
|
||||
|
||||
// Regardless of whether this block is a sender or a receiver, wait for the previous kernel here.
// Receiver blocks do not strictly need to wait, but they should not start
// stressing the memory system early either.
|
||||
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
|
||||
cudaGridDependencySynchronize();
|
||||
#endif
|
||||
|
||||
if (isSender)
|
||||
{
|
||||
// Sender blocks should trigger the next kernel immediately, so that they
// do not block it from starting.
|
||||
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
|
||||
cudaTriggerProgrammaticLaunchCompletion();
|
||||
#endif
|
||||
|
||||
// Sender logic: send data from cpRank's slice to peerRank
|
||||
int64_t head = senderFifo->head;
|
||||
int64_t tail = senderFifo->tail;
|
||||
|
||||
// Each channel processes entries with stride
|
||||
// Start at channel index, increment by total channel count
|
||||
for (int entryIdx = pairInfo.channel; entryIdx < params.entryCount; entryIdx += runChannelCount)
|
||||
{
|
||||
|
||||
// dataIndex points to the data for peerRank in this entry
|
||||
int dataIndex = entryIdx * params.cpSize + peerRank;
|
||||
|
||||
// Load data from global to shared, then arrive on barrier
|
||||
int loadedSize = g2sAllFields<ALLOW_VARIABLE_FIELD1>(
|
||||
params.sendFields, dataIndex, shmem, &allWarpSmemBar[group], laneId);
|
||||
uint64_t arriveState = mbarrier_arrive_expect_tx(&allWarpSmemBar[group], laneId == 0 ? loadedSize : 0);
|
||||
|
||||
// update FIFO entry index and head if needed
|
||||
if (fifoEntry128ByteIndexBase + singleProtoTransfer128ByteCount > HELIX_FIFO_ENTRY_128B_COUNT)
|
||||
{
|
||||
if (fifoEntryIndex >= 0)
|
||||
{
|
||||
head++;
|
||||
__syncwarp();
|
||||
senderFifo->head = head;
|
||||
}
|
||||
fifoEntryIndex = head % HELIX_FIFO_DEPTH;
|
||||
fifoEntry128ByteIndexBase = 0;
|
||||
while (tail + HELIX_FIFO_DEPTH <= head)
|
||||
{
|
||||
tail = senderFifo->tail;
|
||||
}
|
||||
__syncwarp();
|
||||
}
|
||||
|
||||
// wait for data to be loaded into shared memory
|
||||
waitG2sAllFields(&allWarpSmemBar[group], &phaseParity);
|
||||
// note: we don't need to pack anything, fields are already packed in
|
||||
// shared memory
|
||||
|
||||
LL128Proto::protoPack(shmem, head, singlePacked128ByteCount, fifoEntry128ByteIndexBase, laneId);
|
||||
|
||||
uint64_t* fifoEntry = fifoBase + fifoEntryIndex * (HELIX_FIFO_ENTRY_BYTES / sizeof(uint64_t));
|
||||
|
||||
// Copy from shared to workspace FIFO
|
||||
startWorkspaceS2GReg(fifoEntry, shmem, singleProtoTransfer128ByteCount, fifoEntry128ByteIndexBase, laneId);
|
||||
|
||||
fifoEntry128ByteIndexBase += singleProtoTransfer128ByteCount;
|
||||
|
||||
// ensure that we can over-write shmem in next iteration
|
||||
// (it must be fully read by all threads when doing S2G above)
|
||||
__syncwarp();
|
||||
}
|
||||
if (fifoEntry128ByteIndexBase > 0)
|
||||
{
|
||||
head++;
|
||||
senderFifo->head = head;
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
// Receiver logic: receive data from peerRank to cpRank's slice
|
||||
int64_t tail = receiverFifo->tail;
|
||||
bool needRelease = false;
|
||||
|
||||
// Each channel processes entries with stride
|
||||
// Start at channel index, increment by total channel count
|
||||
for (int entryIdx = pairInfo.channel; entryIdx < params.entryCount; entryIdx += runChannelCount)
|
||||
{
|
||||
// Receiver blocks should trigger the next kernel on their last iteration.
// Note: some blocks might not enter this for-loop at all; they simply exit,
// which is equivalent to triggering completion before exiting.
|
||||
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
|
||||
if (entryIdx + runChannelCount >= params.entryCount)
|
||||
{
|
||||
cudaTriggerProgrammaticLaunchCompletion();
|
||||
}
|
||||
#endif
|
||||
// dataIndex points to where we receive data from peerRank in this entry
|
||||
int dataIndex = entryIdx * params.cpSize + peerRank;
|
||||
int loaded128ByteCount = 0;
|
||||
|
||||
if (fifoEntry128ByteIndexBase + singleProtoTransfer128ByteCount > HELIX_FIFO_ENTRY_128B_COUNT)
|
||||
{
|
||||
if (fifoEntryIndex >= 0)
|
||||
{
|
||||
tail++;
|
||||
needRelease = true;
|
||||
}
|
||||
fifoEntryIndex = tail % HELIX_FIFO_DEPTH;
|
||||
fifoEntry128ByteIndexBase = 0;
|
||||
// receiver doesn't need to wait on FIFO entry being readable: it's
|
||||
// always readable
|
||||
__syncwarp();
|
||||
}
|
||||
|
||||
uint64_t* fifoEntry = fifoBase + fifoEntryIndex * (HELIX_FIFO_ENTRY_BYTES / sizeof(uint64_t));
|
||||
while (loaded128ByteCount < singleProtoTransfer128ByteCount)
|
||||
{
|
||||
startWorkspaceG2S(shmem, fifoEntry, singleProtoTransfer128ByteCount, fifoEntry128ByteIndexBase,
|
||||
loaded128ByteCount, &allWarpSmemBar[group], laneId);
|
||||
if (needRelease)
|
||||
{
|
||||
receiverFifo->tail = tail;
|
||||
senderFifo->tail = tail;
|
||||
needRelease = false;
|
||||
}
|
||||
smemBarWait(&allWarpSmemBar[group], &phaseParity);
|
||||
loaded128ByteCount += LL128Proto::template checkDataReceivedInShm<false>(shmem, tail,
|
||||
singleProtoTransfer128ByteCount, fifoEntry128ByteIndexBase, loaded128ByteCount, laneId);
|
||||
}
|
||||
|
||||
LL128Proto::protoUnpack(
|
||||
shmem, tail, singlePacked128ByteCount, fifoEntry128ByteIndexBase, loaded128ByteCount, laneId);
|
||||
|
||||
// note: fields are already unpacked in shared memory
|
||||
s2gAllFields<ALLOW_VARIABLE_FIELD1>(params.recvFields, dataIndex, shmem, laneId);
|
||||
// wait for data to be read from shared memory
|
||||
cp_async_bulk_wait_group_read<0>();
|
||||
|
||||
// note: LL128Proto doesn't need rearm
|
||||
// rearmFifoBuffer();
|
||||
fifoEntry128ByteIndexBase += singleProtoTransfer128ByteCount;
|
||||
}
|
||||
if (fifoEntry128ByteIndexBase > 0)
|
||||
{
|
||||
tail++;
|
||||
receiverFifo->tail = tail;
|
||||
senderFifo->tail = tail;
|
||||
}
|
||||
}
|
||||
}
|
||||
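The sender and receiver halves of the kernel above coordinate through a small head/tail ring protocol per (peer rank, channel) FIFO. The sketch below paraphrases that flow control as pseudo-code; it is a simplified reading aid, not additional functionality.

// Simplified flow control of helixAllToAllKernel (per peer rank and channel):
//
//   sender (blockIdx.z == 0):
//     for each entry assigned to this channel:
//       if the current FIFO entry cannot hold another packed record:
//         head++; publish senderFifo->head;
//         spin while (tail + HELIX_FIFO_DEPTH <= head), refreshing tail from senderFifo->tail;
//       load fields G2S, LL128-pack in shared memory, copy shared -> FIFO entry;
//     a final head++ publishes any partially filled entry.
//
//   receiver (blockIdx.z == 1):
//     for each entry assigned to this channel:
//       if the current FIFO entry has been fully consumed:
//         tail++; write tail back to both receiverFifo->tail and senderFifo->tail;
//       poll the FIFO entry until the LL128 flags show the data for this tail has arrived,
//       then LL128-unpack in shared memory and copy shared -> output fields (S2G).
//
// Back-pressure: the sender stalls once it gets HELIX_FIFO_DEPTH (4) entries ahead of the receiver.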
|
||||
// ============================================================================
|
||||
// Compute actual channel count
|
||||
// ============================================================================
|
||||
|
||||
struct hash_cache_key
|
||||
{
|
||||
size_t operator()(std::tuple<int, int, int> const& x) const
|
||||
{
|
||||
return std::get<0>(x) ^ std::get<1>(x) ^ std::get<2>(x);
|
||||
}
|
||||
};
|
||||
|
||||
template <bool ALLOW_VARIABLE_FIELD1>
|
||||
std::tuple<int, int, int> computeChannelAndGroupCount(int cpSize, HelixFieldInfo const* fields)
|
||||
{
|
||||
static std::unordered_map<std::tuple<int, int, int>, std::tuple<int, int, int>, hash_cache_key> cache;
|
||||
int deviceId = 0;
|
||||
TLLM_CUDA_CHECK(cudaGetDevice(&deviceId));
|
||||
int singleShmSize = std::max(computeTotalUnpackedSize(fields), computeProtoTransferSize(fields));
|
||||
auto key = std::make_tuple(deviceId, cpSize, singleShmSize);
|
||||
auto it = cache.find(key);
|
||||
if (it != cache.end())
|
||||
{
|
||||
return it->second;
|
||||
}
|
||||
|
||||
int maxGroupCountPerCta = std::min(cpSize, MAX_GROUP_COUNT_PER_BLOCK);
|
||||
int groupCountPerCta = maxGroupCountPerCta; // Start with max
|
||||
int totalDynamicShmemSize = singleShmSize * groupCountPerCta;
|
||||
int maxDynamicShmSize = 0;
|
||||
TLLM_CUDA_CHECK(cudaDeviceGetAttribute(&maxDynamicShmSize, cudaDevAttrMaxSharedMemoryPerBlockOptin, deviceId));
|
||||
|
||||
while (totalDynamicShmemSize > maxDynamicShmSize)
|
||||
{
|
||||
groupCountPerCta--;
|
||||
totalDynamicShmemSize = singleShmSize * groupCountPerCta;
|
||||
}
|
||||
|
||||
TLLM_CHECK_WITH_INFO(totalDynamicShmemSize <= maxDynamicShmSize, "Single packed size %d exceeds limit %d",
|
||||
singleShmSize, maxDynamicShmSize);
|
||||
|
||||
// Set shared memory attribute if needed
|
||||
if (totalDynamicShmemSize > 48 * 1024)
|
||||
{
|
||||
TLLM_CUDA_CHECK(cudaFuncSetAttribute(helixAllToAllKernel<ALLOW_VARIABLE_FIELD1>,
|
||||
cudaFuncAttributeMaxDynamicSharedMemorySize, totalDynamicShmemSize));
|
||||
}
|
||||
|
||||
int blockCountPerChannel = ceil_div(cpSize, groupCountPerCta);
|
||||
blockCountPerChannel *= 2; // for send and recv
|
||||
|
||||
int smCount = 0;
|
||||
TLLM_CUDA_CHECK(cudaDeviceGetAttribute(&smCount, cudaDevAttrMultiProcessorCount, deviceId));
|
||||
// TODO: we might only want to use half the SMs to overlap with other kernels.
|
||||
// note that overlap with FMHA is almost impossible because it must use
|
||||
// all SMs and probably uses >50% shmem per SM.
|
||||
// overlap with the subsequent BMM / out proj GEMMs might be possible,
|
||||
// so we need experiments to see whether it makes sense.
|
||||
int channelCount = std::max(smCount / blockCountPerChannel, 1);
|
||||
auto value = std::make_tuple(channelCount, groupCountPerCta, totalDynamicShmemSize);
|
||||
cache[key] = value;
|
||||
return value;
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Host Launch Function
|
||||
// ============================================================================
|
||||
|
||||
template <bool ALLOW_VARIABLE_FIELD1>
|
||||
void launchHelixAllToAllImpl(HelixAllToAllParams const& params, cudaStream_t stream)
|
||||
{
|
||||
int maxChannelCount = computeHelixMaxChannelCount(params.cpSize);
|
||||
TLLM_CHECK_WITH_INFO(params.maxChannelCount == maxChannelCount,
|
||||
"maxChannelCount %d does not match computed maxChannelCount %d", params.maxChannelCount, maxChannelCount);
|
||||
auto [channelCount, groupCountPerCta, totalDynamicShmemSize]
|
||||
= computeChannelAndGroupCount<ALLOW_VARIABLE_FIELD1>(params.cpSize, params.sendFields);
|
||||
if (params.channelCount > 0)
|
||||
{
|
||||
channelCount = params.channelCount;
|
||||
TLLM_CHECK_WITH_INFO(channelCount <= maxChannelCount, "channelCount %d exceeds maxChannelCount %d",
|
||||
channelCount, maxChannelCount);
|
||||
}
|
||||
|
||||
// Compute grid dimensions:
//   grid.x = blocks per channel (how many blocks are needed to cover all peer ranks)
//   grid.y = number of channels (parallel channels)
//   grid.z = 2 (sender and receiver)
|
||||
int ctaPerChannel = ceil_div(params.cpSize, groupCountPerCta);
|
||||
|
||||
auto* kernel_instance = &helixAllToAllKernel<ALLOW_VARIABLE_FIELD1>;
|
||||
cudaLaunchConfig_t config;
|
||||
config.gridDim = dim3(ctaPerChannel, channelCount, 2);
|
||||
config.blockDim = dim3(WARP_SIZE, groupCountPerCta);
|
||||
config.dynamicSmemBytes = totalDynamicShmemSize;
|
||||
config.stream = stream;
|
||||
cudaLaunchAttribute attrs[1];
|
||||
attrs[0].id = cudaLaunchAttributeProgrammaticStreamSerialization;
|
||||
attrs[0].val.programmaticStreamSerializationAllowed = common::getEnvEnablePDL();
|
||||
config.numAttrs = 1;
|
||||
config.attrs = attrs;
|
||||
TLLM_CUDA_CHECK(cudaLaunchKernelEx(&config, kernel_instance, params));
|
||||
}
|
||||
|
||||
} // anonymous namespace
|
||||
|
||||
// ============================================================================
|
||||
// Public API Functions
|
||||
// ============================================================================
|
||||
|
||||
int computeHelixMaxChannelCount(int cpSize, int smCount)
|
||||
{
|
||||
if (smCount == 0)
|
||||
{
|
||||
int deviceId = 0;
|
||||
TLLM_CUDA_CHECK(cudaGetDevice(&deviceId));
|
||||
TLLM_CUDA_CHECK(cudaDeviceGetAttribute(&smCount, cudaDevAttrMultiProcessorCount, deviceId));
|
||||
}
|
||||
|
||||
int blockCountPerChannel = ceil_div(cpSize, MAX_GROUP_COUNT_PER_BLOCK);
|
||||
blockCountPerChannel *= 2; // for send and recv
|
||||
|
||||
int preferredChannel = smCount / blockCountPerChannel;
|
||||
return std::max(preferredChannel, 1); // at least one channel
|
||||
}
|
||||
|
||||
size_t computeHelixWorkspaceSizePerRank(int cpSize)
|
||||
{
|
||||
static int maxChannelCount = 0;
|
||||
if (maxChannelCount == 0)
|
||||
{
|
||||
maxChannelCount = computeHelixMaxChannelCount(cpSize);
|
||||
}
|
||||
|
||||
// FIFO buffers: cpSize * channelCount pairs
|
||||
size_t fifoSize = static_cast<size_t>(HELIX_FIFO_TOTAL_BYTES) * cpSize * maxChannelCount;
|
||||
|
||||
// Sender and receiver FIFO info structures
|
||||
size_t senderInfoSize = sizeof(HelixFifoInfo) * cpSize * maxChannelCount;
|
||||
size_t receiverInfoSize = sizeof(HelixFifoInfo) * cpSize * maxChannelCount;
|
||||
|
||||
return fifoSize + senderInfoSize + receiverInfoSize;
|
||||
}
|
||||
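As a rough sense of scale, consider a hypothetical configuration with cpSize = 8 on a GPU where computeHelixMaxChannelCount() happens to return 8. The numbers below are illustrative only: the real channel count depends on the SM count and MAX_GROUP_COUNT_PER_BLOCK, and sizeof(HelixFifoInfo) is taken to be 256 bytes because of its ALIGN_256 attribute.

// Hypothetical workspace sizing example (illustrative only):
//   cpSize = 8, maxChannelCount = 8 (assumed)
//   fifoSize          = 512 KiB * 8 * 8 = 32 MiB
//   senderInfoSize    = 256 B  * 8 * 8  = 16 KiB
//   receiverInfoSize  = 256 B  * 8 * 8  = 16 KiB
//   computeHelixWorkspaceSizePerRank(8) ~= 32 MiB + 32 KiB per rank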
|
||||
void launchHelixAllToAll(HelixAllToAllParams const& params, bool allowVariableField1, cudaStream_t stream)
|
||||
{
|
||||
if (allowVariableField1)
|
||||
{
|
||||
launchHelixAllToAllImpl<true>(params, stream);
|
||||
}
|
||||
else
|
||||
{
|
||||
launchHelixAllToAllImpl<false>(params, stream);
|
||||
}
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Workspace Initialization
|
||||
// ============================================================================
|
||||
|
||||
void initializeHelixWorkspace(uint64_t* local_workspace_ptr, int cpSize, cudaStream_t stream)
|
||||
{
|
||||
int maxChannelCount = computeHelixMaxChannelCount(cpSize);
|
||||
// Calculate sizes with channel dimension
|
||||
size_t fifoSize = static_cast<size_t>(HELIX_FIFO_TOTAL_BYTES) * cpSize * maxChannelCount;
|
||||
size_t senderInfoSize = sizeof(HelixFifoInfo) * cpSize * maxChannelCount;
|
||||
size_t receiverInfoSize = sizeof(HelixFifoInfo) * cpSize * maxChannelCount;
|
||||
|
||||
// Initialize FIFO buffers to 0xFFFFFFFF (-1 for signed integer types)
|
||||
TLLM_CUDA_CHECK(cudaMemsetAsync(local_workspace_ptr, 0xFF, fifoSize, stream));
|
||||
|
||||
// Initialize sender and receiver info to zero (single call for both)
|
||||
uint8_t* infoPtr = reinterpret_cast<uint8_t*>(local_workspace_ptr) + fifoSize;
|
||||
TLLM_CUDA_CHECK(cudaMemsetAsync(infoPtr, 0, senderInfoSize + receiverInfoSize, stream));
|
||||
}
|
||||
|
||||
} // namespace kernels
|
||||
|
||||
TRTLLM_NAMESPACE_END
|
||||
cpp/tensorrt_llm/kernels/helixAllToAll.h (new file, 94 lines)
@@ -0,0 +1,94 @@
/*
 * Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#pragma once

#include "tensorrt_llm/common/config.h"

#include <cuda_runtime.h>

#include <cstddef>
#include <cstdint>

TRTLLM_NAMESPACE_BEGIN

namespace kernels
{

struct HelixFieldInfo
{
    uint8_t* dataPtr;
    int elementCount; // Number of elements (e.g., kv_lora_rank for field 0, 1 for field 1)
    int elementSize;  // Size of each element in bytes (2 for half, 8 for float2)
    int stride;       // Stride between rows in bytes
};

struct HelixAllToAllParams
{
    HelixFieldInfo sendFields[2];
    HelixFieldInfo recvFields[2];
    int entryCount; // Number of entries per peer rank to process
    uint64_t* workspace;
    int workspaceStrideInU64;
    int cpRank;
    int cpSize;
    int channelCount; // use 0 to auto-compute
    int maxChannelCount;
};

// ============================================================================
// Workspace Management Functions
// ============================================================================

/**
 * Compute number of channels for communication based on cpSize.
 *
 * @param cpSize Number of context parallel ranks
 * @param smCount Number of SMs available (0 = auto-detect)
 * @return Number of channels to use
 */
int computeHelixMaxChannelCount(int cpSize, int smCount = 0);

/**
 * Compute the workspace size required per rank for the all-to-all operation.
 *
 * @param cpSize Number of context parallel ranks
 * @return Size in bytes
 */
size_t computeHelixWorkspaceSizePerRank(int cpSize);

/**
 * Initialize workspace memory for a given rank.
 * Should be called once during setup.
 *
 * @param workspace Pointer to workspace memory (per-rank view)
 * @param cpSize Number of context parallel ranks
 * @param stream CUDA stream for asynchronous operations
 */
void initializeHelixWorkspace(uint64_t* workspace, int cpSize, cudaStream_t stream);

/**
 * Launch the helix all-to-all kernel.
 *
 * @param params Kernel parameters including field info and workspace
 * @param allowVariableField1 Whether to allow variable field 1
 * @param stream CUDA stream for kernel launch
 */
void launchHelixAllToAll(HelixAllToAllParams const& params, bool allowVariableField1, cudaStream_t stream);

} // namespace kernels

TRTLLM_NAMESPACE_END
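Putting the header API together, a typical host-side call sequence looks roughly like the sketch below. It is illustrative only: the symmetric, cpSize-strided workspace is assumed to be allocated elsewhere (in this commit that happens on the Python side via MNNVL memory), and the helper name runHelixAllToAll as well as the contiguous tensor layouts are hypothetical.

// Hypothetical host-side usage sketch of the helixAllToAll API (illustrative only).
using namespace tensorrt_llm::kernels;

void runHelixAllToAll(int cpRank, int cpSize, cudaStream_t stream,
    uint64_t* workspaceBase, int workspaceStrideInU64, // symmetric, cpSize-strided workspace (allocated elsewhere)
    uint8_t* sendO, uint8_t* recvO, float* sendStats, float* recvStats,
    int entryCount, int kvLoraRank, int elemSize /* 2 for half/bf16 */)
{
    HelixAllToAllParams params{};

    // Field 0: partial attention outputs, contiguous [entryCount, cpSize, kvLoraRank].
    params.sendFields[0] = {sendO, kvLoraRank, elemSize, kvLoraRank * elemSize};
    params.recvFields[0] = {recvO, kvLoraRank, elemSize, kvLoraRank * elemSize};

    // Field 1: softmax statistics, one float2 per entry, contiguous [entryCount, cpSize, 2].
    params.sendFields[1] = {reinterpret_cast<uint8_t*>(sendStats), 2, sizeof(float), 2 * (int) sizeof(float)};
    params.recvFields[1] = {reinterpret_cast<uint8_t*>(recvStats), 2, sizeof(float), 2 * (int) sizeof(float)};

    params.entryCount = entryCount;
    params.workspace = workspaceBase;
    params.workspaceStrideInU64 = workspaceStrideInU64;
    params.cpRank = cpRank;
    params.cpSize = cpSize;
    params.channelCount = 0; // auto-compute
    params.maxChannelCount = computeHelixMaxChannelCount(cpSize);

    // One-time setup per rank, before the first launch:
    //   initializeHelixWorkspace(workspaceBase + cpRank * workspaceStrideInU64, cpSize, stream);

    launchHelixAllToAll(params, /*allowVariableField1=*/false, stream);
}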
@@ -18,6 +18,7 @@
#include <nanobind/nanobind.h>
#include <nanobind/stl/optional.h>
#include <nanobind/stl/vector.h>
#include <tensorrt_llm/kernels/helixAllToAll.h>
#include <tensorrt_llm/thop/attentionOp.h>
#include <tensorrt_llm/thop/moeAlltoAllMeta.h>
#include <torch/extension.h>
@@ -73,5 +74,10 @@ void initBindings(nb::module_& m)
        nb::arg("mla_bmm1_scale") = std::nullopt, nb::arg("mla_bmm2_scale") = std::nullopt,
        nb::arg("quant_q_buffer") = std::nullopt, "Multi-head attention operation",
        nb::call_guard<nb::gil_scoped_release>());

    m.def(
        "get_helix_workspace_size_per_rank",
        [](int cp_size) { return tensorrt_llm::kernels::computeHelixWorkspaceSizePerRank(cp_size); },
        nb::arg("cp_size"), "Get helix all-to-all workspace size per rank in bytes");
}
} // namespace tensorrt_llm::nanobind::thop
@@ -18,6 +18,7 @@
#include <pybind11/functional.h>
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include <tensorrt_llm/kernels/helixAllToAll.h>
#include <tensorrt_llm/thop/attentionOp.h>
#include <tensorrt_llm/thop/moeAlltoAllMeta.h>
#include <torch/extension.h>
@@ -73,5 +74,10 @@ void initBindings(pybind11::module_& m)
        py::arg("mla_bmm1_scale") = std::nullopt, py::arg("mla_bmm2_scale") = std::nullopt,
        py::arg("quant_q_buffer") = std::nullopt, "Multi-head attention operation",
        py::call_guard<py::gil_scoped_release>());

    m.def(
        "get_helix_workspace_size_per_rank",
        [](int cp_size) { return tensorrt_llm::kernels::computeHelixWorkspaceSizePerRank(cp_size); },
        py::arg("cp_size"), "Get helix all-to-all workspace size per rank in bytes");
}
} // namespace tensorrt_llm::pybind::thop
@ -16,19 +16,12 @@
|
||||
*/
|
||||
|
||||
#include "tensorrt_llm/common/opUtils.h"
|
||||
#include "tensorrt_llm/kernels/helixAllToAll.h"
|
||||
#include "tensorrt_llm/runtime/torchUtils.h"
|
||||
#include "tensorrt_llm/runtime/utils/mpiUtils.h"
|
||||
#include "tensorrt_llm/thop/thUtils.h"
|
||||
|
||||
#include <NvInferRuntime.h>
|
||||
#include <c10/cuda/CUDAStream.h>
|
||||
#include <cassert>
|
||||
#include <set>
|
||||
#include <string>
|
||||
#include <torch/extension.h>
|
||||
#include <vector>
|
||||
#if ENABLE_MULTI_DEVICE
|
||||
#include <nccl.h>
|
||||
#endif // ENABLE_MULTI_DEVICE
|
||||
|
||||
TRTLLM_NAMESPACE_BEGIN
|
||||
|
||||
@ -119,6 +112,145 @@ std::vector<torch::Tensor> alltoall_helix(
|
||||
#endif // ENABLE_MULTI_DEVICE
|
||||
}
|
||||
|
||||
/**
|
||||
* Helix All-to-All operation with two fields.
|
||||
*
|
||||
* Input tensors have shape [..., cp_size, kv_lora_rank] for partial_o and [...,
|
||||
* cp_size, 2] for softmax_stats. The operation exchanges data along the cp_size
|
||||
* dimension across all ranks.
|
||||
*
|
||||
* @param partial_o Field 0 tensor (half precision, shape [..., cp_size,
|
||||
* kv_lora_rank])
|
||||
* @param softmax_stats Field 1 tensor (float32, shape [..., cp_size, 2])
|
||||
* @param workspace Workspace tensor (uint64, strided across ranks)
|
||||
* @param cp_rank Current context parallel rank
|
||||
* @param cp_size Total number of context parallel ranks
|
||||
* @return tuple of (partial_o_out, softmax_stats_out) with same shapes as inputs
|
||||
*/
|
||||
std::tuple<torch::Tensor, torch::Tensor> alltoall_helix_native(
|
||||
torch::Tensor partial_o, torch::Tensor softmax_stats, torch::Tensor workspace, int64_t cp_rank, int64_t cp_size)
|
||||
{
|
||||
|
||||
// Input validation
|
||||
CHECK_TH_CUDA(partial_o);
|
||||
CHECK_TH_CUDA(softmax_stats);
|
||||
CHECK_TH_CUDA(workspace);
|
||||
CHECK_CONTIGUOUS(partial_o);
|
||||
CHECK_CONTIGUOUS(softmax_stats);
|
||||
|
||||
// Type checks
|
||||
TORCH_CHECK(partial_o.scalar_type() == at::ScalarType::Half || partial_o.scalar_type() == at::ScalarType::BFloat16,
|
||||
"partial_o must be half or bfloat16");
|
||||
CHECK_TYPE(softmax_stats, at::ScalarType::Float);
|
||||
CHECK_TYPE(workspace, at::ScalarType::UInt64);
|
||||
|
||||
// Shape validation
|
||||
TORCH_CHECK(partial_o.dim() >= 2, "partial_o must have at least 2 dimensions");
|
||||
TORCH_CHECK(softmax_stats.dim() >= 2, "softmax_stats must have at least 2 dimensions");
|
||||
TORCH_CHECK(
|
||||
partial_o.dim() == softmax_stats.dim(), "partial_o and softmax_stats must have same number of dimensions");
|
||||
|
||||
// Get dimensions
|
||||
int kv_lora_rank = partial_o.size(-1);
|
||||
TORCH_CHECK(partial_o.size(-2) == cp_size && softmax_stats.size(-2) == cp_size,
|
||||
"partial_o/softmax_stats second-to-last dimension must equal cp_size");
|
||||
TORCH_CHECK(softmax_stats.size(-1) % 2 == 0 && softmax_stats.size(-1) >= 2,
|
||||
"softmax_stats last dimension must be divisible by 2 (float2)");
|
||||
bool allowVariableField1 = softmax_stats.size(-1) > 2;
|
||||
|
||||
// Check that leading dimensions match
|
||||
for (int i = 0; i < partial_o.dim() - 2; i++)
|
||||
{
|
||||
TORCH_CHECK(partial_o.size(i) == softmax_stats.size(i),
|
||||
"partial_o and softmax_stats must have matching dimensions except last two");
|
||||
}
|
||||
TORCH_CHECK(partial_o.size(-1) * partial_o.element_size() % 16 == 0, "partial_o must be aligned to 16 bytes");
|
||||
|
||||
TORCH_CHECK(workspace.dim() == 2, "workspace must be 2D (strided across ranks)");
|
||||
TORCH_CHECK(workspace.size(0) == cp_size, "workspace must have cp_size rows");
|
||||
|
||||
// Calculate entry count (product of all dimensions before cp_size)
|
||||
// This is the number of entries to process per peer rank
|
||||
int entry_count = 1;
|
||||
for (int i = 0; i < partial_o.dim() - 2; i++)
|
||||
{
|
||||
entry_count *= partial_o.size(i);
|
||||
}
|
||||
|
||||
// Reshape to 3D: [entry_count, cp_size, feature_dim]
|
||||
torch::Tensor partial_o_3d = partial_o.reshape({entry_count, cp_size, kv_lora_rank});
|
||||
torch::Tensor softmax_stats_3d = softmax_stats.reshape({entry_count, cp_size, softmax_stats.size(-1)});
|
||||
|
||||
// Allocate output tensors (same shape as input)
|
||||
torch::Tensor partial_o_out = torch::empty_like(partial_o);
|
||||
torch::Tensor softmax_stats_out = torch::empty_like(softmax_stats);
|
||||
|
||||
torch::Tensor partial_o_out_3d = partial_o_out.reshape({entry_count, cp_size, kv_lora_rank});
|
||||
torch::Tensor softmax_stats_out_3d = softmax_stats_out.reshape({entry_count, cp_size, softmax_stats.size(-1)});
|
||||
|
||||
// Setup parameters
|
||||
tensorrt_llm::kernels::HelixAllToAllParams params;
|
||||
|
||||
// Field 0 (variable size half)
|
||||
params.sendFields[0].dataPtr = reinterpret_cast<uint8_t*>(partial_o_3d.data_ptr());
|
||||
params.sendFields[0].elementCount = kv_lora_rank;
|
||||
params.sendFields[0].elementSize = partial_o.element_size();
|
||||
params.sendFields[0].stride = partial_o_3d.stride(1) * partial_o.element_size();
|
||||
|
||||
params.recvFields[0].dataPtr = reinterpret_cast<uint8_t*>(partial_o_out_3d.data_ptr());
|
||||
params.recvFields[0].elementCount = kv_lora_rank;
|
||||
params.recvFields[0].elementSize = partial_o.element_size();
|
||||
params.recvFields[0].stride = partial_o_out_3d.stride(1) * partial_o.element_size();
|
||||
|
||||
// Field 1 (single float2)
|
||||
params.sendFields[1].dataPtr = reinterpret_cast<uint8_t*>(softmax_stats_3d.data_ptr<float>());
|
||||
params.sendFields[1].elementCount = softmax_stats.size(-1);
|
||||
params.sendFields[1].elementSize = softmax_stats.element_size();
|
||||
params.sendFields[1].stride = softmax_stats_3d.stride(1) * softmax_stats.element_size();
|
||||
|
||||
params.recvFields[1].dataPtr = reinterpret_cast<uint8_t*>(softmax_stats_out_3d.data_ptr<float>());
|
||||
params.recvFields[1].elementCount = softmax_stats.size(-1);
|
||||
params.recvFields[1].elementSize = softmax_stats.element_size();
|
||||
params.recvFields[1].stride = softmax_stats_out_3d.stride(1) * softmax_stats.element_size();
|
||||
|
||||
// Entry count and workspace
|
||||
params.entryCount = entry_count;
|
||||
params.workspace = workspace.data_ptr<uint64_t>();
|
||||
params.workspaceStrideInU64 = workspace.stride(0);
|
||||
|
||||
// CP info
|
||||
params.cpRank = cp_rank;
|
||||
params.cpSize = cp_size;
|
||||
params.channelCount = 0; // auto-compute
|
||||
params.maxChannelCount = tensorrt_llm::kernels::computeHelixMaxChannelCount(cp_size);
|
||||
|
||||
// Launch kernel
|
||||
auto stream = at::cuda::getCurrentCUDAStream();
|
||||
tensorrt_llm::kernels::launchHelixAllToAll(params, allowVariableField1, stream);
|
||||
|
||||
return std::make_tuple(partial_o_out, softmax_stats_out);
|
||||
}
|
||||
|
||||
/**
|
||||
* Initialize workspace for helix all-to-all
|
||||
*/
|
||||
void initialize_helix_workspace(torch::Tensor workspace, int64_t cp_rank, int64_t cp_size)
|
||||
{
|
||||
CHECK_TH_CUDA(workspace);
|
||||
CHECK_TYPE(workspace, at::ScalarType::UInt64);
|
||||
TORCH_CHECK(workspace.dim() == 2, "workspace must be 2D");
|
||||
TORCH_CHECK(workspace.size(0) == cp_size, "workspace must have cp_size rows");
|
||||
TORCH_CHECK(cp_rank >= 0 && cp_rank < cp_size, "cp_rank must be in [0, cp_size)");
|
||||
|
||||
auto stream = at::cuda::getCurrentCUDAStream();
|
||||
uint64_t* global_workspace_ptr = workspace.data_ptr<uint64_t>();
|
||||
uint64_t* local_workspace_ptr = workspace[cp_rank].data_ptr<uint64_t>();
|
||||
TORCH_CHECK(local_workspace_ptr == global_workspace_ptr + cp_rank * workspace.stride(0),
|
||||
"local_workspace_ptr must be at the correct offset in the global "
|
||||
"workspace");
|
||||
tensorrt_llm::kernels::initializeHelixWorkspace(local_workspace_ptr, cp_size, stream);
|
||||
}
|
||||
|
||||
} // namespace torch_ext
|
||||
|
||||
TRTLLM_NAMESPACE_END
|
||||
@ -126,9 +258,17 @@ TRTLLM_NAMESPACE_END
|
||||
TORCH_LIBRARY_FRAGMENT(trtllm, m)
|
||||
{
|
||||
m.def("alltoall_helix(Tensor[] input_list, int[] group, int? num_lists) -> Tensor[]");
|
||||
m.def(
|
||||
"alltoall_helix_native(Tensor partial_o, Tensor softmax_stats, Tensor(a!) workspace, int "
|
||||
"cp_rank, int cp_size) -> (Tensor, Tensor)");
|
||||
m.def(
|
||||
"initialize_helix_workspace(Tensor(a!) workspace, int cp_rank, int cp_size) "
|
||||
"-> ()");
|
||||
}
|
||||
|
||||
TORCH_LIBRARY_IMPL(trtllm, CUDA, m)
|
||||
{
|
||||
m.impl("alltoall_helix", &tensorrt_llm::torch_ext::alltoall_helix);
|
||||
m.impl("alltoall_helix_native", &tensorrt_llm::torch_ext::alltoall_helix_native);
|
||||
m.impl("initialize_helix_workspace", &tensorrt_llm::torch_ext::initialize_helix_workspace);
|
||||
}
|
||||
|
||||
@ -51,38 +51,54 @@ def _check_cu_result(cu_func_ret):
|
||||
|
||||
|
||||
class MnnvlMemory:
|
||||
"""MNNVL memory management for tensor parallel (TP) operations."""
|
||||
|
||||
# Shared across all subclasses (global/device state).
|
||||
initialized: bool = False
|
||||
|
||||
current_mem_offset: int = 0
|
||||
current_rank_stride: int = 0 # stride for ranks and also address space size.
|
||||
current_start_address: int = 0
|
||||
|
||||
# allocation granularity
|
||||
allocation_granularity: int = 0
|
||||
|
||||
# fabric address page size (512 MB)
|
||||
fabric_page_size: int = 1 << 29
|
||||
|
||||
# MPI communicator
|
||||
comm = None
|
||||
|
||||
fabric_page_size: int = 1 << 29 # 512 MB.
|
||||
dev_id: int = None
|
||||
|
||||
# Per-class state attributes. These will be auto-initialized for each subclass
|
||||
# to avoid polluting the parent class's state. Use callable (e.g., dict) for mutable defaults.
|
||||
_per_class_attrs = {
|
||||
"current_mem_offset": 0,
|
||||
"current_rank_stride": 0, # stride for ranks and also address space size.
|
||||
"current_start_address": 0,
|
||||
"comm": None, # MPI communicator.
|
||||
"allocated_map": dict, # callable for fresh dict.
|
||||
"address_refcnt": dict, # callable for fresh dict.
|
||||
}
|
||||
|
||||
# Initialize per-class state for the base class.
|
||||
current_mem_offset: int = 0
|
||||
current_rank_stride: int = 0
|
||||
current_start_address: int = 0
|
||||
comm = None
|
||||
allocated_map = {}
|
||||
address_refcnt = {}
|
||||
|
||||
def __init_subclass__(cls, **kwargs):
|
||||
"""Auto-initialize per-class attributes for each subclass to avoid sharing state with parent."""
|
||||
super().__init_subclass__(**kwargs)
|
||||
for attr, default in cls._per_class_attrs.items():
|
||||
if callable(default):
|
||||
setattr(cls, attr, default()) # e.g., dict() creates a fresh dict.
|
||||
else:
|
||||
setattr(cls, attr, default)
|
||||
|
||||
def __init__(self, mapping: Mapping, size: int):
|
||||
self.mapping = mapping
|
||||
self.segment_size = size
|
||||
self.ptr, self.rank_stride = MnnvlMemory.open_mnnvl_memory(self.mapping, size)
|
||||
self.ptr, self.rank_stride = type(self).open_mnnvl_memory(self.mapping, size)
|
||||
|
||||
def __del__(self):
|
||||
if not sys.is_finalizing():
|
||||
if hasattr(self, "ptr"):
|
||||
MnnvlMemory.close_mnnvl_memory(self.ptr)
|
||||
type(self).close_mnnvl_memory(self.ptr)
|
||||
|
||||
def as_torch_strided_tensor(self, dtype):
|
||||
num_segments = MnnvlMemory.comm.Get_size()
|
||||
num_segments = type(self).comm.Get_size()
|
||||
return pack_strided_memory(
|
||||
self.ptr, self.segment_size, self.rank_stride, num_segments, dtype, MnnvlMemory.dev_id
|
||||
)
|
||||
@ -99,16 +115,17 @@ class MnnvlMemory:
|
||||
pynvml.nvmlInit()
|
||||
MnnvlMemory.initialized = True
|
||||
|
||||
@staticmethod
|
||||
def get_comm(mapping: Mapping):
|
||||
if MnnvlMemory.comm is not None:
|
||||
return MnnvlMemory.comm
|
||||
@classmethod
|
||||
def get_comm(cls, mapping: Mapping):
|
||||
"""Get TP-based communicator (ranks grouped by PP+CP+MOE_TP, ordered by TP rank)."""
|
||||
if cls.comm is not None:
|
||||
return cls.comm
|
||||
comm = mpi_comm().Split(
|
||||
(mapping.pp_rank * mapping.cp_size + mapping.cp_rank) * mapping.moe_tp_size
|
||||
+ mapping.moe_tp_rank,
|
||||
mapping.tp_rank,
|
||||
)
|
||||
MnnvlMemory.comm = comm
|
||||
cls.comm = comm
|
||||
return comm
|
||||
|
||||
@staticmethod
|
||||
@ -148,23 +165,26 @@ class MnnvlMemory:
|
||||
MnnvlMemory.allocation_granularity = granularity
|
||||
return MnnvlMemory.allocation_granularity
|
||||
|
||||
@staticmethod
|
||||
def new_mnnvl_memory_address(mapping: Mapping, size: int):
|
||||
@classmethod
|
||||
def new_mnnvl_memory_address(cls, mapping: Mapping, size: int):
|
||||
page_count = (size + MnnvlMemory.fabric_page_size - 1) // MnnvlMemory.fabric_page_size
|
||||
current_rank_stride = page_count * MnnvlMemory.fabric_page_size
|
||||
logger.info(f"[MnnvlMemory] creating address with stride={current_rank_stride}")
|
||||
comm = MnnvlMemory.get_comm(mapping)
|
||||
logger.info(f"[{cls.__name__}] creating address with stride={current_rank_stride}")
|
||||
comm = cls.get_comm(mapping)
|
||||
comm_size = comm.Get_size()
|
||||
address_size = current_rank_stride * comm_size
|
||||
ptr = _check_cu_result(
|
||||
cuda.cuMemAddressReserve(address_size, MnnvlMemory.fabric_page_size, 0, 0)
|
||||
)
|
||||
MnnvlMemory.current_start_address = int(ptr)
|
||||
MnnvlMemory.current_rank_stride = current_rank_stride
|
||||
MnnvlMemory.current_mem_offset = 0
|
||||
cls.current_start_address = int(ptr)
|
||||
cls.current_rank_stride = current_rank_stride
|
||||
cls.current_mem_offset = 0
|
||||
|
||||
@classmethod
|
||||
def open_mnnvl_memory(cls, mapping: Mapping, size: int):
|
||||
# Ensure MnnvlMemory is initialized (for dev_id and allocation_granularity)
|
||||
MnnvlMemory.initialize()
|
||||
|
||||
@staticmethod
|
||||
def open_mnnvl_memory(mapping: Mapping, size: int):
|
||||
dev = _check_cu_result(cuda.cuCtxGetDevice())
|
||||
dev_id = int(dev)
|
||||
if MnnvlMemory.dev_id is None:
|
||||
@ -172,7 +192,7 @@ class MnnvlMemory:
|
||||
assert dev_id == MnnvlMemory.dev_id, (
|
||||
f"Different dev_id found dev_id={dev_id} but MnnvlMemory.dev_id={MnnvlMemory.dev_id}"
|
||||
)
|
||||
comm = MnnvlMemory.get_comm(mapping)
|
||||
comm = cls.get_comm(mapping)
|
||||
comm_rank = comm.Get_rank()
|
||||
comm_size = comm.Get_size()
|
||||
all_rank_allocate_sizes = comm.allgather(size)
|
||||
@ -181,10 +201,10 @@ class MnnvlMemory:
|
||||
granularity = MnnvlMemory.get_allocation_granularity(dev_id)
|
||||
aligned_size = (size + granularity - 1) // granularity * granularity
|
||||
|
||||
if MnnvlMemory.current_mem_offset + aligned_size > MnnvlMemory.current_rank_stride:
|
||||
MnnvlMemory.new_mnnvl_memory_address(mapping, aligned_size)
|
||||
if cls.current_mem_offset + aligned_size > cls.current_rank_stride:
|
||||
cls.new_mnnvl_memory_address(mapping, aligned_size)
|
||||
|
||||
assert MnnvlMemory.current_mem_offset + aligned_size <= MnnvlMemory.current_rank_stride
|
||||
assert cls.current_mem_offset + aligned_size <= cls.current_rank_stride
|
||||
|
||||
allocation_prop = MnnvlMemory.get_allocation_prop(dev_id)
|
||||
allocated_mem_handle = _check_cu_result(
|
||||
@ -245,9 +265,7 @@ class MnnvlMemory:
|
||||
|
||||
for i, remote_handle_data in enumerate(all_handles_data):
|
||||
rank_ptr = (
|
||||
MnnvlMemory.current_start_address
|
||||
+ MnnvlMemory.current_rank_stride * i
|
||||
+ MnnvlMemory.current_mem_offset
|
||||
cls.current_start_address + cls.current_rank_stride * i + cls.current_mem_offset
|
||||
)
|
||||
if i == comm_rank:
|
||||
# Local memory mapping
|
||||
@ -265,44 +283,44 @@ class MnnvlMemory:
|
||||
|
||||
_check_cu_result(cuda.cuMemSetAccess(rank_ptr, aligned_size, [madesc], 1))
|
||||
|
||||
ptr = MnnvlMemory.current_start_address + MnnvlMemory.current_mem_offset
|
||||
stride = MnnvlMemory.current_rank_stride
|
||||
MnnvlMemory.allocated_map[ptr] = (
|
||||
ptr = cls.current_start_address + cls.current_mem_offset
|
||||
stride = cls.current_rank_stride
|
||||
cls.allocated_map[ptr] = (
|
||||
mapping,
|
||||
aligned_size,
|
||||
mem_handles,
|
||||
MnnvlMemory.current_start_address,
|
||||
MnnvlMemory.current_rank_stride,
|
||||
MnnvlMemory.current_mem_offset,
|
||||
cls.current_start_address,
|
||||
cls.current_rank_stride,
|
||||
cls.current_mem_offset,
|
||||
)
|
||||
MnnvlMemory.address_refcnt[MnnvlMemory.current_start_address] = (
|
||||
MnnvlMemory.address_refcnt.get(MnnvlMemory.current_start_address, 0) + 1
|
||||
cls.address_refcnt[cls.current_start_address] = (
|
||||
cls.address_refcnt.get(cls.current_start_address, 0) + 1
|
||||
)
|
||||
|
||||
MnnvlMemory.current_mem_offset += aligned_size
|
||||
cls.current_mem_offset += aligned_size
|
||||
return ptr, stride
|
||||
|
||||
@staticmethod
|
||||
def close_mnnvl_memory(ptr: int):
|
||||
@classmethod
|
||||
def close_mnnvl_memory(cls, ptr: int):
|
||||
mapping, aligned_size, mem_handles, start_address, rank_stride, address_offset = (
|
||||
MnnvlMemory.allocated_map.pop(ptr)
|
||||
cls.allocated_map.pop(ptr)
|
||||
)
|
||||
comm = MnnvlMemory.get_comm(mapping)
|
||||
comm = cls.get_comm(mapping)
|
||||
comm_size = comm.Get_size()
|
||||
for i in range(comm_size):
|
||||
rank_ptr = start_address + i * rank_stride + address_offset
|
||||
_check_cu_result(cuda.cuMemUnmap(rank_ptr, aligned_size))
|
||||
_check_cu_result(cuda.cuMemRelease(mem_handles[i]))
|
||||
MnnvlMemory.address_refcnt[start_address] -= 1
|
||||
cls.address_refcnt[start_address] -= 1
|
||||
|
||||
if MnnvlMemory.address_refcnt[start_address] == 0:
|
||||
MnnvlMemory.address_refcnt.pop(start_address)
|
||||
if cls.address_refcnt[start_address] == 0:
|
||||
cls.address_refcnt.pop(start_address)
|
||||
device_ptr = cuda.CUdeviceptr(start_address)
|
||||
_check_cu_result(cuda.cuMemAddressFree(device_ptr, comm_size * rank_stride))
|
||||
if start_address == MnnvlMemory.current_start_address:
|
||||
MnnvlMemory.current_start_address = 0
|
||||
MnnvlMemory.current_rank_stride = 0
|
||||
MnnvlMemory.current_mem_offset = 0
|
||||
if start_address == cls.current_start_address:
|
||||
cls.current_start_address = 0
|
||||
cls.current_rank_stride = 0
|
||||
cls.current_mem_offset = 0
|
||||
|
||||
@staticmethod
|
||||
def support_nvlink(need_all_up: bool = True):
|
||||
@ -338,6 +356,29 @@ class MnnvlMemory:
|
||||
return support_nvlink_and_all_up
|
||||
|
||||
|
||||
class HelixCpMnnvlMemory(MnnvlMemory):
|
||||
"""MNNVL memory management for Helix context parallel (CP) operations.
|
||||
|
||||
Per-class state (current_mem_offset, comm, allocated_map, etc.) is automatically
|
||||
initialized via __init_subclass__ in the parent class, ensuring this class has
|
||||
its own isolated state separate from MnnvlMemory.
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def get_comm(cls, mapping: Mapping):
|
||||
"""Get CP-based communicator (ranks grouped by PP+TP+MOE_TP, ordered by CP rank)."""
|
||||
if cls.comm is not None:
|
||||
return cls.comm
|
||||
comm = mpi_comm().Split(
|
||||
mapping.pp_rank * mapping.tp_size * mapping.moe_tp_size
|
||||
+ mapping.tp_rank * mapping.moe_tp_size
|
||||
+ mapping.moe_tp_rank,
|
||||
mapping.cp_rank,
|
||||
)
|
||||
cls.comm = comm
|
||||
return comm
|
||||
|
||||
|
||||
@dataclass
|
||||
class MoEAlltoallInfo:
|
||||
local_gather_indices: torch.Tensor
|
||||
|
||||
@ -752,6 +752,17 @@ def _register_fake():
|
||||
for i in range(0, len(input_list), num_ranks)
|
||||
]
|
||||
|
||||
@torch.library.register_fake("trtllm::alltoall_helix_native")
|
||||
def _(partial_o, softmax_stats, workspace, cp_rank, cp_size):
|
||||
# Returns outputs with same shapes as inputs
|
||||
return partial_o.new_empty(partial_o.shape), softmax_stats.new_empty(
|
||||
softmax_stats.shape)
|
||||
|
||||
@torch.library.register_fake("trtllm::initialize_helix_workspace")
|
||||
def _(workspace, cp_rank, cp_size):
|
||||
# This op initializes workspace in-place and returns nothing
|
||||
return None
|
||||
|
||||
@torch.library.register_fake("trtllm::helix_post_process")
|
||||
def _(gathered_o, gathered_stats, scale):
|
||||
return gathered_o.new_empty(*gathered_o.shape[1:])
|
||||
|
||||
@ -2,9 +2,10 @@ from tensorrt_llm.functional import AllReduceFusionOp
|
||||
|
||||
from .communicator import Distributed, MPIDist, TorchDist
|
||||
from .moe_alltoall import MoeAlltoAll
|
||||
from .ops import (AllReduce, AllReduceParams, AllReduceStrategy, MoEAllReduce,
|
||||
MoEAllReduceParams, allgather, alltoall_helix, cp_allgather,
|
||||
reducescatter, userbuffers_allreduce_finalize)
|
||||
from .ops import (AllReduce, AllReduceParams, AllReduceStrategy,
|
||||
HelixAllToAllNative, MoEAllReduce, MoEAllReduceParams,
|
||||
allgather, alltoall_helix, cp_allgather, reducescatter,
|
||||
userbuffers_allreduce_finalize)
|
||||
|
||||
__all__ = [
|
||||
"allgather",
|
||||
@ -16,6 +17,7 @@ __all__ = [
|
||||
"AllReduceParams",
|
||||
"AllReduceFusionOp",
|
||||
"AllReduceStrategy",
|
||||
"HelixAllToAllNative",
|
||||
"MoEAllReduce",
|
||||
"MoEAllReduceParams",
|
||||
"MoeAlltoAll",
|
||||
|
||||
@ -2,14 +2,16 @@ import math
|
||||
import os
|
||||
import platform
|
||||
import threading
|
||||
from typing import List, Optional, Tuple, Union
|
||||
from typing import Dict, List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from tensorrt_llm._mnnvl_utils import HelixCpMnnvlMemory, MnnvlMemory
|
||||
from tensorrt_llm._torch.distributed.symm_mem_allreduce import \
|
||||
SymmetricMemoryAllReduce
|
||||
from tensorrt_llm._utils import mpi_comm, mpi_disabled
|
||||
from tensorrt_llm.bindings import internal as _tllm_internal
|
||||
from tensorrt_llm.bindings.internal.runtime import McastGPUBuffer
|
||||
from tensorrt_llm.functional import (AllReduceFusionOp, AllReduceParams,
|
||||
AllReduceStrategy, MoEAllReduceParams)
|
||||
@@ -363,6 +365,84 @@ def alltoall_helix(
    return torch.ops.trtllm.alltoall_helix(inputs, group, num_lists)


class HelixAllToAllNative:
    """
    Manager for Helix All-to-All operations with MNNVL workspace management.

    Exchanges data along the cp_size dimension:
    - partial_o: [..., cp_size, kv_lora_rank] half-precision
    - softmax_stats: [..., cp_size, 2] float32
    """

    # Global cache: mapping -> instance
    _cache: Dict[Mapping, "HelixAllToAllNative"] = {}

    def __init__(self, mapping: Mapping, workspace: HelixCpMnnvlMemory,
                 workspace_tensor: torch.Tensor):
        """Private constructor - use get() instead."""
        self.mapping = mapping
        self.workspace = workspace
        self.workspace_tensor = workspace_tensor

    @staticmethod
    def get(mapping: Mapping) -> "HelixAllToAllNative":
        """
        Get or create a HelixAllToAllNative instance for the given configuration.

        Args:
            mapping: TensorRT-LLM mapping object containing cp_size and cp_rank

        Returns:
            Cached or newly-created HelixAllToAllNative instance
        """
        if mapping not in HelixAllToAllNative._cache:
            logger.info(
                f"Rank {mapping.cp_rank} initializing HelixCpMnnvlMemory for Helix"
            )
            MnnvlMemory.initialize()

            # Get workspace size (in bytes)
            workspace_size_per_rank = _tllm_internal.thop.get_helix_workspace_size_per_rank(
                mapping.cp_size)

            # Allocate MNNVL memory using CP communicator for Helix
            workspace = HelixCpMnnvlMemory(mapping, workspace_size_per_rank)
            workspace_tensor = workspace.as_torch_strided_tensor(torch.uint64)

            torch.ops.trtllm.initialize_helix_workspace(workspace_tensor,
                                                        mapping.cp_rank,
                                                        mapping.cp_size)
            torch.cuda.synchronize()
            HelixCpMnnvlMemory.get_comm(mapping).barrier()

            HelixAllToAllNative._cache[mapping] = HelixAllToAllNative(
                mapping, workspace, workspace_tensor)

        return HelixAllToAllNative._cache[mapping]

    def alltoall_native(self, partial_o: torch.Tensor,
                        softmax_stats: torch.Tensor):
        """
        Perform all-to-all data exchange.

        Args:
            partial_o: Tensor with shape [..., cp_size, kv_lora_rank], dtype half.
            softmax_stats: Tensor with shape [..., cp_size, 2], dtype float32.

        Returns:
            Tuple of (partial_o_out, softmax_stats_out) with same shapes as inputs.
        """
        partial_o_out, softmax_stats_out = torch.ops.trtllm.alltoall_helix_native(
            partial_o,
            softmax_stats,
            self.workspace_tensor,
            self.mapping.cp_rank,
            self.mapping.cp_size,
        )

        return partial_o_out, softmax_stats_out


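# Illustrative usage sketch (not part of the change), showing how a caller is expected to
# drive the FIFO-based exchange. It requires an initialized multi-GPU MNNVL job, so this is
# a shape-level sketch only; `mapping` is assumed to carry cp_rank/cp_size, and the comment
# about what the cp dimension holds after the exchange is an assumption.
def helix_exchange(mapping, partial_o: torch.Tensor, softmax_stats: torch.Tensor):
    # partial_o:     [num_tokens, num_heads_local, cp_size, kv_lora_rank], half precision
    # softmax_stats: [num_tokens, num_heads_local, cp_size, 2], float32
    helix = HelixAllToAllNative.get(mapping)  # cached per Mapping; allocates workspace once
    out_o, out_stats = helix.alltoall_native(partial_o, softmax_stats)
    # Same shapes on output; presumably slot i of the cp_size dimension now holds the
    # contribution received from CP rank i rather than the slice destined for it.
    return out_o, out_stats
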
def reducescatter(
    input: Union[torch.Tensor, List[torch.Tensor]],
    mapping: Mapping,

@@ -20,7 +20,7 @@ from ..attention_backend.interface import (AttentionBackend, AttentionMask,
from ..attention_backend.sparse.dsa import (
    DSAtrtllmAttentionMetadata, transform_local_topk_and_prepare_pool_view)
from ..attention_backend.utils import create_attention, get_attention_backend
from ..distributed import AllReduceParams, alltoall_helix
from ..distributed import AllReduceParams, HelixAllToAllNative, alltoall_helix
from ..model_config import ModelConfig
from ..peft.lora.layer import LoraLayer, LoraModuleType
from ..utils import (Fp4QuantizedTensor, get_model_extra_attrs,
@@ -1120,24 +1120,60 @@ class MLA(nn.Module):
                softmax_stats_tensor=softmax_stats,
                **kwargs,
            )
            # this is the post-processing of helix parallel attention,
            # similar to the post-processing of ring attention
            kv_lora_rank = partial_o.shape[-1] // self.num_heads_tp
            assert self.kv_lora_rank == kv_lora_rank
            # transpose the tensors to make the split across cp_size contiguous
            # for both tensors, we need to split across the second dimension
            chunks = []
            for t in [partial_o, softmax_stats]:
                t = t.transpose(1, 0).contiguous()
                chunks.extend(torch.split(t,
                                          t.shape[0] // self.mapping.cp_size))
            gathered = alltoall_helix(chunks, self.mapping.cp_group)
            # transpose the tensors back to ensure dimensions are ordered correctly
            # note: an additional dimension was added at the first index for all-to-all,
            # so the transpose dimensions are shifted by 1
            gathered = [t.transpose(1, 2).contiguous() for t in gathered]
            return torch.ops.trtllm.helix_post_process(gathered[0], gathered[1],
                                                       1.0)

            # Switch between NCCL-based and FIFO-based (MNNVL) all-to-all based on cp_config.
            if self.mapping.cp_config.get("use_nccl_for_alltoall", True):
                # NCCL-based implementation using alltoall_helix.
                # This is the post-processing of helix parallel attention,
                # similar to the post-processing of ring attention.
                # Transpose the tensors to make the split across cp_size contiguous.
                # For both tensors, we need to split across the second dimension.
                chunks = []
                for t in [partial_o, softmax_stats]:
                    t = t.transpose(1, 0).contiguous()
                    chunks.extend(
                        torch.split(t, t.shape[0] // self.mapping.cp_size))
                gathered = alltoall_helix(chunks, self.mapping.cp_group)
                # Transpose the tensors back to ensure dimensions are ordered correctly.
                # Note: an additional dimension was added at the first index for all-to-all,
                # so the transpose dimensions are shifted by 1.
                gathered = [t.transpose(1, 2).contiguous() for t in gathered]
                return torch.ops.trtllm.helix_post_process(
                    gathered[0], gathered[1], 1.0)
            else:
                # FIFO-based implementation using MNNVL workspace and LL128 Proto.
                # Get or create Helix All-to-All instance.
                helix = HelixAllToAllNative.get(self.mapping)

                # Get dimensions.
                num_tokens = partial_o.shape[0]
                cp_size = self.mapping.cp_size

                # Reshape for FIFO-based all-to-all.
                # partial_o: [num_tokens, num_heads * kv_lora_rank] -> [num_tokens, cp_size, num_heads_tp_cp, kv_lora_rank]
                # softmax_stats: [num_tokens, num_heads, 2] -> [num_tokens, cp_size, num_heads_tp_cp, 2]
                partial_o = partial_o.view(
                    num_tokens, cp_size, self.num_heads_tp_cp,
                    kv_lora_rank).transpose(1, 2).contiguous()
                softmax_stats = softmax_stats.view(num_tokens, cp_size,
                                                   self.num_heads_tp_cp,
                                                   2).transpose(1,
                                                                2).contiguous()

                # Call FIFO-based helixAllToAll.
                partial_o_out, softmax_stats_out = helix.alltoall_native(
                    partial_o, softmax_stats)

                # partial_o_out: [num_tokens, num_heads_tp_cp, cp_size, kv_lora_rank]
                # softmax_stats_out: [num_tokens, num_heads_tp_cp, cp_size, 2]
                # cp_dim = 2 (the dimension where cp_size is located)

                # Call helix_post_process_native with cp_dim=2.
                return torch.ops.trtllm.helix_post_process_native(
                    partial_o_out, softmax_stats_out, 1.0, 2)
        else:
            attn_output = attn_backend.forward(q, k, v, attn_metadata, **kwargs)
            return attn_output

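# Reference sketch (not part of the change) of the log-sum-exp combine that the
# helix_post_process* ops are assumed to perform, in the spirit of ring/flash attention.
# Assumptions not confirmed by this diff: stats[..., 0] is the per-chunk row max,
# stats[..., 1] the per-chunk sum of exponentials, each partial output is already
# normalized by its local sum, and cp_dim indexes the second-to-last axis of both tensors.
import torch

def combine_partials(partial_o: torch.Tensor, stats: torch.Tensor, cp_dim: int) -> torch.Tensor:
    # partial_o: [..., cp_size, kv_lora_rank]; stats: [..., cp_size, 2]
    m = stats[..., 0]                                   # per-chunk running max
    s = stats[..., 1]                                   # per-chunk sum of exp(logit - max)
    m_glob = m.max(dim=cp_dim, keepdim=True).values     # global max across CP chunks
    w = s * torch.exp(m - m_glob)                       # rescale each chunk to the global max
    w = w / w.sum(dim=cp_dim, keepdim=True)             # normalized per-chunk weights
    return (partial_o * w.unsqueeze(-1)).sum(dim=cp_dim)
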
@@ -887,7 +887,9 @@ class TestDeepSeekV3Lite(LlmapiAccuracyTestHarness):
        "cudagraph:none", "cudagraph:without_padding",
        "cudagraph:with_padding"
    ])
    def test_auto_dtype_with_helix(self, cuda_graph_config):
    @pytest.mark.parametrize("comms_medium", ["fifo", "nccl"])
    def test_auto_dtype_with_helix(self, comms_medium, cuda_graph_config):
        use_nccl_for_alltoall = comms_medium == "nccl"
        kv_cache_config = {
            "free_gpu_memory_fraction": 0.5,
            "enable_block_reuse": False,
@@ -912,7 +914,8 @@ class TestDeepSeekV3Lite(LlmapiAccuracyTestHarness):
            "context_parallel_size": 2,
            "cp_config": {
                "cp_type": "HELIX",
                "tokens_per_block": 32
                "tokens_per_block": 32,
                "use_nccl_for_alltoall": use_nccl_for_alltoall
            },
            "disable_overlap_scheduler": True,
            "kv_cache_config": kv_cache_config,

@@ -535,9 +535,12 @@ accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_guided_decoding
accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_guided_decoding[xgrammar-mtp_nextn=2]
accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_guided_decoding[llguidance-mtp_nextn=0]
accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_guided_decoding[llguidance-mtp_nextn=2]
accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[cudagraph:none]
accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[cudagraph:without_padding]
accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[cudagraph:with_padding]
accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[nccl-cudagraph:none]
accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[nccl-cudagraph:without_padding]
accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[nccl-cudagraph:with_padding]
accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[fifo-cudagraph:none]
accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[fifo-cudagraph:without_padding]
accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[fifo-cudagraph:with_padding]
accuracy/test_disaggregated_serving.py::TestGemma3_1BInstruct::test_auto_dtype[False]
accuracy/test_disaggregated_serving.py::TestGemma3_1BInstruct::test_auto_dtype[True]
accuracy/test_disaggregated_serving.py::TestGPTOSS::test_auto_dtype[True]

@@ -69,7 +69,8 @@ l0_gb200_multi_gpus:
  - accuracy/test_llm_api_pytorch.py::TestQwen3NextInstruct::test_nvfp4[tp4ep4-cutlass]
  - accuracy/test_llm_api_pytorch.py::TestQwen3NextInstruct::test_nvfp4[no_cuda_graph_overlap-cutlass]
  - accuracy/test_llm_api_pytorch.py::TestQwen3NextInstruct::test_nvfp4[tp4ep4-trtllm]
  - accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[cudagraph:without_padding]
  - accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[nccl-cudagraph:none]
  - accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[fifo-cudagraph:none]
- condition:
    ranges:
      system_gpu_count:
@@ -84,8 +85,10 @@ l0_gb200_multi_gpus:
    stage: post_merge
    backend: pytorch
  tests:
  - accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[cudagraph:with_padding]
  - accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[cudagraph:none]
  - accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[nccl-cudagraph:without_padding]
  - accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[nccl-cudagraph:with_padding]
  - accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[fifo-cudagraph:without_padding]
  - accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[fifo-cudagraph:with_padding]
  - accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_fp8_4gpus[tp4-fp8kv=True-attn_backend=FLASHINFER-torch_compile=True]
  - accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_fp8_4gpus[pp4-fp8kv=False-attn_backend=TRTLLM-torch_compile=False]
  - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_nvfp4_4gpus[moe_backend=CUTLASS-mtp_nextn=0-ep4-fp8kv=True-attention_dp=True-cuda_graph=True-overlap_scheduler=True-torch_compile=False]

@@ -644,7 +644,13 @@ def _run_mla_distributed(


@torch.inference_mode
def _full_test_multi_gpu(rank: int, world_size: int, scenario: Scenario, gen_steps: int):
def _full_test_multi_gpu(
    rank: int,
    world_size: int,
    scenario: Scenario,
    gen_steps: int,
    comms_medium: str = "fifo",
):
    if scenario.rope_scaling:
        rope_scaling = {
            "beta_fast": scenario.rope_beta_fast,
@@ -814,7 +820,13 @@ def _full_test_multi_gpu(rank: int, world_size: int, scenario: Scenario, gen_ste

    # Distributed mapping for helix
    mapping = Mapping(
        world_size=world_size, rank=rank, cp_size=world_size, cp_config={"cp_type": CpType.HELIX}
        world_size=world_size,
        rank=rank,
        cp_size=world_size,
        cp_config={
            "cp_type": CpType.HELIX,
            "use_nccl_for_alltoall": comms_medium == "nccl",
        },
    )
    # we use cp_allgather here because there is no broadcast op across CP group
    ref_output_all = cp_allgather(ref_output, mapping=mapping, dim=0)
@@ -849,18 +861,24 @@ def _run_single_rank(func, *args, **kwargs):


@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="needs 2 GPUs to run this test")
@pytest.mark.parametrize("scenario", test_scenarios, ids=lambda x: f"scenario: {x}")
@pytest.mark.parametrize("comms_medium", ["nccl", "fifo"])
def test_mla_helix_distributed(
    scenario: Scenario,
    comms_medium: str,
    gen_steps: Optional[int] = None,
    max_mismatch_ratio: float = 0.02,
    mismatch_ratios: Optional[List[float]] = None,
):
    world_size = 2
    print(f"Testing with comms_medium={comms_medium}.")
    gen_steps = scenario.ref_steps if gen_steps is None else gen_steps
    with MPIPoolExecutor(max_workers=world_size) as executor:
        results = executor.map(
            _run_single_rank,
            *zip(*[(_full_test_multi_gpu, world_size, scenario, gen_steps)] * world_size),
            *zip(
                *[(_full_test_multi_gpu, world_size, scenario, gen_steps, comms_medium)]
                * world_size
            ),
        )
        if mismatch_ratios is None:
            for ratio_mismatch in results:
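# Illustrative sketch (not part of the change) of the argument-broadcast idiom used with
# executor.map above: one args tuple is replicated per rank and transposed into parallel
# per-argument iterables, so every worker receives identical positional arguments. The
# stand-in values below are hypothetical.
world_size = 2
args = ("func", "scenario", 128, "fifo")
columns = list(zip(*[args] * world_size))
assert columns == [("func",) * 2, ("scenario",) * 2, (128,) * 2, ("fifo",) * 2]
# executor.map(worker, *columns) therefore calls worker("func", "scenario", 128, "fifo")
# once per rank, which is how _run_single_rank receives _full_test_multi_gpu's arguments.
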
@@ -870,13 +888,22 @@ test_mla_helix_distributed(


if __name__ == "__main__":
    for scenario in all_scenarios[:11]:
        timing_steps = 256
        gen_steps = scenario.ref_steps + timing_steps
        print(f"Running scenario: {scenario} and timing {timing_steps} steps")
        mismatch_ratios = []
        test_mla_helix_distributed(scenario, gen_steps=gen_steps, mismatch_ratios=mismatch_ratios)
        if any(mismatch > 0 for mismatch in mismatch_ratios):
            print(f"Numerical test failed with mismatch ratios: {mismatch_ratios}")
        else:
            print("Numerical test passed")
    for comms_medium in ["fifo", "nccl"]:
        print(f"\n{'=' * 60}")
        print(f"Testing with comms_medium={comms_medium}")
        print(f"{'=' * 60}\n")
        for scenario in all_scenarios[:11]:
            timing_steps = 256
            gen_steps = scenario.ref_steps + timing_steps
            print(f"Running scenario: {scenario} and timing {timing_steps} steps")
            mismatch_ratios = []
            test_mla_helix_distributed(
                scenario,
                comms_medium=comms_medium,
                gen_steps=gen_steps,
                mismatch_ratios=mismatch_ratios,
            )
            if any(mismatch > 0 for mismatch in mismatch_ratios):
                print(f"Numerical test failed with mismatch ratios: {mismatch_ratios}")
            else:
                print("Numerical test passed")