[TRTLLM-9493][feat] Custom AllToAll for helix parallelism (#9986)

Signed-off-by: Balaram Buddharaju <169953907+brb-nv@users.noreply.github.com>
Balaram Buddharaju 2025-12-23 18:14:30 -08:00 committed by GitHub
parent 92d90fa29a
commit 8c1cfc872b
14 changed files with 1243 additions and 108 deletions

View File

@ -0,0 +1,683 @@
/*
* Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "tensorrt_llm/common/assert.h"
#include "tensorrt_llm/common/cudaUtils.h"
#include "tensorrt_llm/common/envUtils.h"
#include "tensorrt_llm/kernels/cudaAsyncOps.cuh"
#include "tensorrt_llm/kernels/helixAllToAll.h"
#include "tensorrt_llm/kernels/ll128Proto.cuh"
#include "tensorrt_llm/kernels/moeCommKernelsCommon.h"
#include <algorithm>
#include <tuple>
#include <unordered_map>
TRTLLM_NAMESPACE_BEGIN
namespace kernels
{
namespace
{
// ============================================================================
// Structure declarations and definitions
// ============================================================================
// ALIGN_256 is defined in moeCommKernelsCommon.h
struct ALIGN_256 HelixFifoInfo
{
volatile int64_t head;
volatile int64_t tail;
};
// ============================================================================
// Helix-specific FIFO constants
// Note: Helix uses 128KB FIFO entries vs 256KB in FusedMoe
// ============================================================================
constexpr int HELIX_FIFO_DEPTH = 4;
constexpr int HELIX_FIFO_ENTRY_BYTES = 128 * 1024;
constexpr int HELIX_FIFO_TOTAL_BYTES = HELIX_FIFO_ENTRY_BYTES * HELIX_FIFO_DEPTH;
constexpr int HELIX_FIFO_ENTRY_128B_COUNT = HELIX_FIFO_ENTRY_BYTES / BYTES_PER_128B_BLOCK;
constexpr int HELIX_FIFO_TOTAL_U64 = HELIX_FIFO_TOTAL_BYTES / sizeof(uint64_t);
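// Illustrative arithmetic for the derived constants (assuming BYTES_PER_128B_BLOCK == 128,
// as the name suggests): each (peer, channel) FIFO holds 4 entries of 128 KiB = 512 KiB total,
// i.e. HELIX_FIFO_ENTRY_128B_COUNT = 128 KiB / 128 B = 1024 blocks per entry and
// HELIX_FIFO_TOTAL_U64 = 512 KiB / 8 B = 65536 uint64_t values per FIFO.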
// ============================================================================
// Implementation-only structures
// ============================================================================
struct HelixPairInfo
{
int senderRank;
int receiverRank;
int channel;
int runChannelCount;
};
// WARP_SIZE, WARP_MASK, and other constants are defined in moeCommKernelsCommon.h
// ============================================================================
// Helper Functions
// ============================================================================
__host__ __device__ inline int getFieldSize(HelixFieldInfo const& fieldInfo)
{
return fieldInfo.elementCount * fieldInfo.elementSize;
}
__host__ __device__ inline uint8_t* getPtr(HelixFieldInfo const& fieldInfo, int blockIdx)
{
return fieldInfo.dataPtr + blockIdx * fieldInfo.stride;
}
__device__ __forceinline__ void waitG2sAllFields(uint64_t* smemBar, uint32_t* phaseParity)
{
cp_async_wait_group<0>();
smemBarWait(smemBar, phaseParity);
}
// Align size to 128 bytes
__host__ __device__ __forceinline__ int align128(int size)
{
return align_up(size, BYTES_PER_128B_BLOCK);
}
// ============================================================================
// G2S (Global to Shared) Operations
// ============================================================================
__device__ __forceinline__ void g2sField(
HelixFieldInfo const& fieldInfo, int dataIndex, uint8_t* shmemBase, int shmemOffset, uint64_t* smemBar, int laneId)
{
int copySize = getFieldSize(fieldInfo);
if (copySize > 0 && laneId == 0)
{
uint8_t* srcPtr = getPtr(fieldInfo, dataIndex);
uint8_t* dstPtr = shmemBase + shmemOffset;
cp_async_bulk_g2s(dstPtr, srcPtr, copySize, smemBar);
}
}
template <bool ALLOW_VARIABLE_FIELD1>
__device__ __forceinline__ int g2sAllFields(
HelixFieldInfo const* fieldInfo, int dataIndex, uint8_t* shmemBase, uint64_t* smemBar, int laneId)
{
int totalSize = 0;
// Load field 0 (variable size half)
g2sField(fieldInfo[0], dataIndex, shmemBase, 0, smemBar, laneId);
int field0Size = getFieldSize(fieldInfo[0]);
totalSize += field0Size;
// Load field 1 (single float2)
if constexpr (ALLOW_VARIABLE_FIELD1)
{
g2sField(fieldInfo[1], dataIndex, shmemBase, totalSize, smemBar, laneId);
totalSize += getFieldSize(fieldInfo[1]);
}
else
{
ldgsts<8>(reinterpret_cast<int*>(shmemBase + totalSize),
reinterpret_cast<int const*>(getPtr(fieldInfo[1], dataIndex)), laneId == 0);
cp_async_commit_group();
}
return totalSize;
}
// ============================================================================
// S2G (Shared to Global) Operations
// ============================================================================
__device__ __forceinline__ void s2gField(
HelixFieldInfo const& fieldInfo, int dataIndex, uint8_t* shmemBase, int shmemOffset, int laneId)
{
int copySize = getFieldSize(fieldInfo);
if (copySize > 0 && laneId == 0)
{
uint8_t* srcPtr = shmemBase + shmemOffset;
uint8_t* dstPtr = getPtr(fieldInfo, dataIndex);
cp_async_bulk_s2g(dstPtr, srcPtr, copySize);
}
}
template <bool ALLOW_VARIABLE_FIELD1>
__device__ __forceinline__ void s2gAllFields(
HelixFieldInfo const* fieldInfo, int dataIndex, uint8_t* shmemBase, int laneId)
{
int offset = 0;
// Store field 0 (variable size half)
s2gField(fieldInfo[0], dataIndex, shmemBase, offset, laneId);
int field0Size = getFieldSize(fieldInfo[0]);
offset += field0Size;
// Store field 1 (single float2)
if constexpr (ALLOW_VARIABLE_FIELD1)
{
s2gField(fieldInfo[1], dataIndex, shmemBase, offset, laneId);
offset += getFieldSize(fieldInfo[1]);
}
else
{
if (laneId == 0)
{
auto* srcPtr = reinterpret_cast<float2*>(reinterpret_cast<uint8_t*>(shmemBase) + offset);
auto* dstPtr = reinterpret_cast<float2*>(getPtr(fieldInfo[1], dataIndex));
dstPtr[0] = srcPtr[0];
}
}
cp_async_bulk_commit_group();
}
// ============================================================================
// Workspace FIFO Operations
// ============================================================================
__device__ __forceinline__ uint64_t* getFifoBasePtr(HelixAllToAllParams const& params, HelixPairInfo const& pairInfo)
{
// FIFO is physically located at receiver rank
int mappedMemoryRank = pairInfo.receiverRank;
int rankInsideMappedMemory = pairInfo.senderRank;
auto* mappedMemory = params.workspace + mappedMemoryRank * params.workspaceStrideInU64;
// Navigate to the right FIFO: [peer_rank][channel]
size_t fifoOffset = rankInsideMappedMemory * params.maxChannelCount * HELIX_FIFO_TOTAL_U64;
fifoOffset += pairInfo.channel * HELIX_FIFO_TOTAL_U64;
return mappedMemory + fifoOffset;
}
__device__ __forceinline__ HelixFifoInfo* getSenderHelixFifoInfo(
HelixAllToAllParams const& params, HelixPairInfo const& pairInfo)
{
// SenderSideHelixFifoInfo is physically located at sender rank
int mappedMemoryRank = pairInfo.senderRank;
int rankInsideMappedMemory = pairInfo.receiverRank;
auto* mappedMemory = reinterpret_cast<uint8_t*>(params.workspace + mappedMemoryRank * params.workspaceStrideInU64);
size_t fieldOffset = static_cast<size_t>(HELIX_FIFO_TOTAL_BYTES) * params.cpSize * params.maxChannelCount;
mappedMemory += fieldOffset;
mappedMemory += rankInsideMappedMemory * params.maxChannelCount * sizeof(HelixFifoInfo);
mappedMemory += pairInfo.channel * sizeof(HelixFifoInfo);
return reinterpret_cast<HelixFifoInfo*>(mappedMemory);
}
__device__ __forceinline__ HelixFifoInfo* getReceiverHelixFifoInfo(
HelixAllToAllParams const& params, HelixPairInfo const& pairInfo)
{
// ReceiverSideHelixFifoInfo is physically located at receiver rank
int mappedMemoryRank = pairInfo.receiverRank;
int rankInsideMappedMemory = pairInfo.senderRank;
auto* mappedMemory = reinterpret_cast<uint8_t*>(params.workspace + mappedMemoryRank * params.workspaceStrideInU64);
size_t fieldOffset = static_cast<size_t>(HELIX_FIFO_TOTAL_BYTES) * params.cpSize * params.maxChannelCount;
fieldOffset += sizeof(HelixFifoInfo) * params.cpSize * params.maxChannelCount;
mappedMemory += fieldOffset;
mappedMemory += rankInsideMappedMemory * params.maxChannelCount * sizeof(HelixFifoInfo);
mappedMemory += pairInfo.channel * sizeof(HelixFifoInfo);
return reinterpret_cast<HelixFifoInfo*>(mappedMemory);
}
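// Per-rank workspace layout implied by the getters above (offsets within one rank's slice
// of params.workspace, with F = HELIX_FIFO_TOTAL_BYTES * cpSize * maxChannelCount and
// S = sizeof(HelixFifoInfo) * cpSize * maxChannelCount):
//   [0, F)          FIFO buffers, indexed [peerRank][channel]
//   [F, F + S)      sender-side HelixFifoInfo, indexed [peerRank][channel]
//   [F + S, F + 2S) receiver-side HelixFifoInfo, indexed [peerRank][channel]
// A FIFO and its receiver-side info live on the receiver rank (indexed by the sender rank);
// the sender-side info lives on the sender rank (indexed by the receiver rank).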
__device__ __forceinline__ void startWorkspaceS2G(
uint64_t* fifoEntry, uint8_t* shmemBase, int send128ByteCount, int fifo128ByteOffset, int laneId)
{
int copyByteCount = send128ByteCount * BYTES_PER_128B_BLOCK;
if (laneId == 0)
{
cp_async_bulk_s2g(
fifoEntry + fifo128ByteOffset * BYTES_PER_128B_BLOCK / sizeof(uint64_t), shmemBase, copyByteCount);
}
cp_async_bulk_commit_group();
}
__device__ __forceinline__ void startWorkspaceS2GReg(
uint64_t* fifoEntry, uint8_t* sharedMemoryBase, int send128ByteCount, int fifo128ByteOffset, int laneId)
{
int copyInt4Count = send128ByteCount * BYTES_PER_128B_BLOCK / sizeof(int4);
int4* sharedMemoryInt4 = reinterpret_cast<int4*>(sharedMemoryBase);
uint64_t* fifoPtr = fifoEntry + fifo128ByteOffset * UINT64_PER_128B_BLOCK;
int4* fifoPtrInt4 = reinterpret_cast<int4*>(fifoPtr);
#pragma unroll 4
for (int i = laneId; i < copyInt4Count; i += WARP_SIZE)
{
fifoPtrInt4[i] = sharedMemoryInt4[i];
}
}
__device__ __forceinline__ uint64_t startWorkspaceG2S(uint8_t* shmemBase, uint64_t* fifoEntry, int allLoad128ByteCount,
int fifo128ByteOffset, int loaded128ByteCount, uint64_t* smemBar, int laneId)
{
int copyByteCount = (allLoad128ByteCount - loaded128ByteCount) * BYTES_PER_128B_BLOCK;
if (laneId == 0)
{
cp_async_bulk_g2s(shmemBase + loaded128ByteCount * BYTES_PER_128B_BLOCK,
fifoEntry + (fifo128ByteOffset + loaded128ByteCount) * UINT64_PER_128B_BLOCK, copyByteCount, smemBar);
}
return mbarrier_arrive_expect_tx(smemBar, laneId == 0 ? copyByteCount : 0);
}
// LL128Proto is now defined in ll128Proto.cuh
// ============================================================================
// Size helpers
// ============================================================================
// Compute total size needed for both fields
__host__ __device__ __forceinline__ int computeTotalUnpackedSize(HelixFieldInfo const* fields)
{
int size = 0;
// Field 0: note it must be aligned to 16 bytes
size += align_up(getFieldSize(fields[0]), 16);
// Field 1: single float2
size += align_up(getFieldSize(fields[1]), 16);
return align128(size);
}
__host__ __device__ __forceinline__ int computeTotalPackedSize(HelixFieldInfo const* fields)
{
// because field 0 must be aligned to 16 bytes, this is the same as unpacked
return computeTotalUnpackedSize(fields);
}
__host__ __device__ __forceinline__ int computeProtoTransferSize(HelixFieldInfo const* fields)
{
return LL128Proto::computeProtoTransfer128ByteAlignedSize(computeTotalPackedSize(fields));
}
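// Worked example (hypothetical field sizes, for illustration only): with field 0 holding
// 512 half elements (1024 B, already 16-byte aligned) and field 1 holding one float2
// (8 B -> 16 B after alignment), the packed/unpacked size is align128(1040) = 1152 B;
// the proto transfer size is whatever LL128Proto adds on top of that 1152 B payload.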
// ============================================================================
// Main All-to-All Kernel
// ============================================================================
template <bool ALLOW_VARIABLE_FIELD1>
__global__ void helixAllToAllKernel(HelixAllToAllParams params)
{
extern __shared__ uint8_t allWarpShmem[];
__shared__ uint64_t allWarpSmemBar[MAX_GROUP_COUNT_PER_BLOCK];
bool isSender = (blockIdx.z == 0);
// Each warp is a group handling a different peer rank
int group = __shfl_sync(WARP_MASK, threadIdx.y, 0);
int laneId = threadIdx.x % WARP_SIZE;
int runChannelCount = gridDim.y;
// Compute peer rank: blockIdx.x determines which set of peers, group
// determines which peer in that set
int peerRank = blockIdx.x * blockDim.y + group;
if (peerRank >= params.cpSize)
{
return;
}
// Setup pair info for this communication
HelixPairInfo pairInfo;
pairInfo.channel = blockIdx.y;
pairInfo.runChannelCount = runChannelCount;
pairInfo.senderRank = isSender ? params.cpRank : peerRank;
pairInfo.receiverRank = isSender ? peerRank : params.cpRank;
// Initialize barrier for this group
initSmemBar(&allWarpSmemBar[group], laneId);
uint32_t phaseParity = 0;
// Get shared memory for this group
int singlePackedSize = computeTotalPackedSize(params.sendFields);
int singlePacked128ByteCount = singlePackedSize / BYTES_PER_128B_BLOCK;
int singleUnpackedSize = computeTotalUnpackedSize(params.sendFields);
int singleProtoTransferSize = computeProtoTransferSize(params.sendFields);
int singleProtoTransfer128ByteCount = singleProtoTransferSize / BYTES_PER_128B_BLOCK;
int singleShmSize = std::max(singleUnpackedSize, singleProtoTransferSize);
uint8_t* shmem = allWarpShmem + group * singleShmSize;
// Get FIFO pointers
uint64_t* fifoBase = getFifoBasePtr(params, pairInfo);
HelixFifoInfo* senderFifo = getSenderHelixFifoInfo(params, pairInfo);
HelixFifoInfo* receiverFifo = getReceiverHelixFifoInfo(params, pairInfo);
int fifoEntry128ByteIndexBase = HELIX_FIFO_ENTRY_128B_COUNT;
int fifoEntryIndex = -1;
// regardless of sender or receiver, we wait for the previous kernel here
// receiver blocks do not strictly need to wait, but holding them back avoids
// stressing the memory system before the previous kernel has finished
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
cudaGridDependencySynchronize();
#endif
if (isSender)
{
// sender blocks should trigger next kernel immediately, s.t. they
// do not block the next kernel from starting
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
cudaTriggerProgrammaticLaunchCompletion();
#endif
// Sender logic: send data from cpRank's slice to peerRank
int64_t head = senderFifo->head;
int64_t tail = senderFifo->tail;
// Each channel processes entries with stride
// Start at channel index, increment by total channel count
for (int entryIdx = pairInfo.channel; entryIdx < params.entryCount; entryIdx += runChannelCount)
{
// dataIndex points to the data for peerRank in this entry
int dataIndex = entryIdx * params.cpSize + peerRank;
// Load data from global to shared, then arrive on barrier
int loadedSize = g2sAllFields<ALLOW_VARIABLE_FIELD1>(
params.sendFields, dataIndex, shmem, &allWarpSmemBar[group], laneId);
uint64_t arriveState = mbarrier_arrive_expect_tx(&allWarpSmemBar[group], laneId == 0 ? loadedSize : 0);
// update FIFO entry index and head if needed
if (fifoEntry128ByteIndexBase + singleProtoTransfer128ByteCount > HELIX_FIFO_ENTRY_128B_COUNT)
{
if (fifoEntryIndex >= 0)
{
head++;
__syncwarp();
senderFifo->head = head;
}
fifoEntryIndex = head % HELIX_FIFO_DEPTH;
fifoEntry128ByteIndexBase = 0;
while (tail + HELIX_FIFO_DEPTH <= head)
{
tail = senderFifo->tail;
}
__syncwarp();
}
// wait for data to be loaded into shared memory
waitG2sAllFields(&allWarpSmemBar[group], &phaseParity);
// note: we don't need to pack anything, fields are already packed in
// shared memory
LL128Proto::protoPack(shmem, head, singlePacked128ByteCount, fifoEntry128ByteIndexBase, laneId);
uint64_t* fifoEntry = fifoBase + fifoEntryIndex * (HELIX_FIFO_ENTRY_BYTES / sizeof(uint64_t));
// Copy from shared to workspace FIFO
startWorkspaceS2GReg(fifoEntry, shmem, singleProtoTransfer128ByteCount, fifoEntry128ByteIndexBase, laneId);
fifoEntry128ByteIndexBase += singleProtoTransfer128ByteCount;
// ensure that we can over-write shmem in next iteration
// (it must be fully read by all threads when doing S2G above)
__syncwarp();
}
if (fifoEntry128ByteIndexBase > 0)
{
head++;
senderFifo->head = head;
}
}
else
{
// Receiver logic: receive data from peerRank to cpRank's slice
int64_t tail = receiverFifo->tail;
bool needRelease = false;
// Each channel processes entries with stride
// Start at channel index, increment by total channel count
for (int entryIdx = pairInfo.channel; entryIdx < params.entryCount; entryIdx += runChannelCount)
{
// receiver blocks should trigger the next kernel on their last iteration
// note: some blocks might not enter this for-loop at all; they simply exit,
// which is equivalent to triggering right before exit
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
if (entryIdx + runChannelCount >= params.entryCount)
{
cudaTriggerProgrammaticLaunchCompletion();
}
#endif
// dataIndex points to where we receive data from peerRank in this entry
int dataIndex = entryIdx * params.cpSize + peerRank;
int loaded128ByteCount = 0;
if (fifoEntry128ByteIndexBase + singleProtoTransfer128ByteCount > HELIX_FIFO_ENTRY_128B_COUNT)
{
if (fifoEntryIndex >= 0)
{
tail++;
needRelease = true;
}
fifoEntryIndex = tail % HELIX_FIFO_DEPTH;
fifoEntry128ByteIndexBase = 0;
// receiver doesn't need to wait on FIFO entry being readable: it's
// always readable
__syncwarp();
}
uint64_t* fifoEntry = fifoBase + fifoEntryIndex * (HELIX_FIFO_ENTRY_BYTES / sizeof(uint64_t));
while (loaded128ByteCount < singleProtoTransfer128ByteCount)
{
startWorkspaceG2S(shmem, fifoEntry, singleProtoTransfer128ByteCount, fifoEntry128ByteIndexBase,
loaded128ByteCount, &allWarpSmemBar[group], laneId);
if (needRelease)
{
receiverFifo->tail = tail;
senderFifo->tail = tail;
needRelease = false;
}
smemBarWait(&allWarpSmemBar[group], &phaseParity);
loaded128ByteCount += LL128Proto::template checkDataReceivedInShm<false>(shmem, tail,
singleProtoTransfer128ByteCount, fifoEntry128ByteIndexBase, loaded128ByteCount, laneId);
}
LL128Proto::protoUnpack(
shmem, tail, singlePacked128ByteCount, fifoEntry128ByteIndexBase, loaded128ByteCount, laneId);
// note: fields are already unpacked in shared memory
s2gAllFields<ALLOW_VARIABLE_FIELD1>(params.recvFields, dataIndex, shmem, laneId);
// wait for data to be read from shared memory
cp_async_bulk_wait_group_read<0>();
// note: LL128Proto doesn't need rearm
// rearmFifoBuffer();
fifoEntry128ByteIndexBase += singleProtoTransfer128ByteCount;
}
if (fifoEntry128ByteIndexBase > 0)
{
tail++;
receiverFifo->tail = tail;
senderFifo->tail = tail;
}
}
}
// ============================================================================
// Compute actual channel count
// ============================================================================
struct hash_cache_key
{
size_t operator()(std::tuple<int, int, int> const& x) const
{
return std::get<0>(x) ^ std::get<1>(x) ^ std::get<2>(x);
}
};
template <bool ALLOW_VARIABLE_FIELD1>
std::tuple<int, int, int> computeChannelAndGroupCount(int cpSize, HelixFieldInfo const* fields)
{
static std::unordered_map<std::tuple<int, int, int>, std::tuple<int, int, int>, hash_cache_key> cache;
int deviceId = 0;
TLLM_CUDA_CHECK(cudaGetDevice(&deviceId));
int singleShmSize = std::max(computeTotalUnpackedSize(fields), computeProtoTransferSize(fields));
auto key = std::make_tuple(deviceId, cpSize, singleShmSize);
auto it = cache.find(key);
if (it != cache.end())
{
return it->second;
}
int maxGroupCountPerCta = std::min(cpSize, MAX_GROUP_COUNT_PER_BLOCK);
int groupCountPerCta = maxGroupCountPerCta; // Start with max
int totalDynamicShmemSize = singleShmSize * groupCountPerCta;
int maxDynamicShmSize = 0;
TLLM_CUDA_CHECK(cudaDeviceGetAttribute(&maxDynamicShmSize, cudaDevAttrMaxSharedMemoryPerBlockOptin, deviceId));
while (totalDynamicShmemSize > maxDynamicShmSize)
{
groupCountPerCta--;
totalDynamicShmemSize = singleShmSize * groupCountPerCta;
}
TLLM_CHECK_WITH_INFO(totalDynamicShmemSize <= maxDynamicShmSize, "Single packed size %d exceeds limit %d",
singleShmSize, maxDynamicShmSize);
// Set shared memory attribute if needed
if (totalDynamicShmemSize > 48 * 1024)
{
TLLM_CUDA_CHECK(cudaFuncSetAttribute(helixAllToAllKernel<ALLOW_VARIABLE_FIELD1>,
cudaFuncAttributeMaxDynamicSharedMemorySize, totalDynamicShmemSize));
}
int blockCountPerChannel = ceil_div(cpSize, groupCountPerCta);
blockCountPerChannel *= 2; // for send and recv
int smCount = 0;
TLLM_CUDA_CHECK(cudaDeviceGetAttribute(&smCount, cudaDevAttrMultiProcessorCount, deviceId));
// TODO: we might only want to use half the SMs to overlap with other kernels.
// note that overlap with FMHA is almost impossible because it must use
// all SMs and probably uses >50% shmem per SM.
// overlap with the subsequent BMM / out proj GEMMs might be possible,
// so we need experiments to see whether it makes sense.
int channelCount = std::max(smCount / blockCountPerChannel, 1);
auto value = std::make_tuple(channelCount, groupCountPerCta, totalDynamicShmemSize);
cache[key] = value;
return value;
}
// ============================================================================
// Host Launch Function
// ============================================================================
template <bool ALLOW_VARIABLE_FIELD1>
void launchHelixAllToAllImpl(HelixAllToAllParams const& params, cudaStream_t stream)
{
int maxChannelCount = computeHelixMaxChannelCount(params.cpSize);
TLLM_CHECK_WITH_INFO(params.maxChannelCount == maxChannelCount,
"maxChannelCount %d does not match computed maxChannelCount %d", params.maxChannelCount, maxChannelCount);
auto [channelCount, groupCountPerCta, totalDynamicShmemSize]
= computeChannelAndGroupCount<ALLOW_VARIABLE_FIELD1>(params.cpSize, params.sendFields);
if (params.channelCount > 0)
{
channelCount = params.channelCount;
TLLM_CHECK_WITH_INFO(channelCount <= maxChannelCount, "channelCount %d exceeds maxChannelCount %d",
channelCount, maxChannelCount);
}
// Compute grid dimensions:
//   grid.x = blocks per channel (how many blocks are needed to cover all peer ranks)
//   grid.y = number of channels (parallel channels)
//   grid.z = 2 (sender and receiver)
int ctaPerChannel = ceil_div(params.cpSize, groupCountPerCta);
auto* kernel_instance = &helixAllToAllKernel<ALLOW_VARIABLE_FIELD1>;
cudaLaunchConfig_t config;
config.gridDim = dim3(ctaPerChannel, channelCount, 2);
config.blockDim = dim3(WARP_SIZE, groupCountPerCta);
config.dynamicSmemBytes = totalDynamicShmemSize;
config.stream = stream;
cudaLaunchAttribute attrs[1];
attrs[0].id = cudaLaunchAttributeProgrammaticStreamSerialization;
attrs[0].val.programmaticStreamSerializationAllowed = common::getEnvEnablePDL();
config.numAttrs = 1;
config.attrs = attrs;
TLLM_CUDA_CHECK(cudaLaunchKernelEx(&config, kernel_instance, params));
}
} // anonymous namespace
// ============================================================================
// Public API Functions
// ============================================================================
int computeHelixMaxChannelCount(int cpSize, int smCount)
{
if (smCount == 0)
{
int deviceId = 0;
TLLM_CUDA_CHECK(cudaGetDevice(&deviceId));
TLLM_CUDA_CHECK(cudaDeviceGetAttribute(&smCount, cudaDevAttrMultiProcessorCount, deviceId));
}
int blockCountPerChannel = ceil_div(cpSize, MAX_GROUP_COUNT_PER_BLOCK);
blockCountPerChannel *= 2; // for send and recv
int preferredChannel = smCount / blockCountPerChannel;
return std::max(preferredChannel, 1); // at least one channel
}
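// Worked example (hypothetical numbers): with cpSize <= MAX_GROUP_COUNT_PER_BLOCK the
// ceil_div above is 1, so blockCountPerChannel = 2 (send + recv); a device with, say,
// 132 SMs would then prefer 132 / 2 = 66 channels.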
size_t computeHelixWorkspaceSizePerRank(int cpSize)
{
static int maxChannelCount = 0;
if (maxChannelCount == 0)
{
maxChannelCount = computeHelixMaxChannelCount(cpSize);
}
// FIFO buffers: cpSize * channelCount pairs
size_t fifoSize = static_cast<size_t>(HELIX_FIFO_TOTAL_BYTES) * cpSize * maxChannelCount;
// Sender and receiver FIFO info structures
size_t senderInfoSize = sizeof(HelixFifoInfo) * cpSize * maxChannelCount;
size_t receiverInfoSize = sizeof(HelixFifoInfo) * cpSize * maxChannelCount;
return fifoSize + senderInfoSize + receiverInfoSize;
}
void launchHelixAllToAll(HelixAllToAllParams const& params, bool allowVariableField1, cudaStream_t stream)
{
if (allowVariableField1)
{
launchHelixAllToAllImpl<true>(params, stream);
}
else
{
launchHelixAllToAllImpl<false>(params, stream);
}
}
// ============================================================================
// Workspace Initialization
// ============================================================================
void initializeHelixWorkspace(uint64_t* local_workspace_ptr, int cpSize, cudaStream_t stream)
{
int maxChannelCount = computeHelixMaxChannelCount(cpSize);
// Calculate sizes with channel dimension
size_t fifoSize = static_cast<size_t>(HELIX_FIFO_TOTAL_BYTES) * cpSize * maxChannelCount;
size_t senderInfoSize = sizeof(HelixFifoInfo) * cpSize * maxChannelCount;
size_t receiverInfoSize = sizeof(HelixFifoInfo) * cpSize * maxChannelCount;
// Initialize FIFO buffers to 0xFFFFFFFF (-1 for signed integer types)
TLLM_CUDA_CHECK(cudaMemsetAsync(local_workspace_ptr, 0xFF, fifoSize, stream));
// Initialize sender and receiver info to zero (single call for both)
uint8_t* infoPtr = reinterpret_cast<uint8_t*>(local_workspace_ptr) + fifoSize;
TLLM_CUDA_CHECK(cudaMemsetAsync(infoPtr, 0, senderInfoSize + receiverInfoSize, stream));
}
} // namespace kernels
TRTLLM_NAMESPACE_END

View File

@ -0,0 +1,94 @@
/*
* Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include "tensorrt_llm/common/config.h"
#include <cuda_runtime.h>
#include <cstddef>
#include <cstdint>
TRTLLM_NAMESPACE_BEGIN
namespace kernels
{
struct HelixFieldInfo
{
uint8_t* dataPtr;
int elementCount; // Number of elements (e.g., kv_lora_rank for field 0, 1 for
// field 1)
int elementSize; // Size of each element in bytes (2 for half, 8 for float2)
int stride; // Stride between rows in bytes
};
struct HelixAllToAllParams
{
HelixFieldInfo sendFields[2];
HelixFieldInfo recvFields[2];
int entryCount; // Number of entries per peer rank to process
uint64_t* workspace;
int workspaceStrideInU64;
int cpRank;
int cpSize;
int channelCount; // use 0 to auto-compute
int maxChannelCount;
};
// ============================================================================
// Workspace Management Functions
// ============================================================================
/**
* Compute the maximum number of channels to use for communication, based on
* cpSize and the available SM count.
*
* @param cpSize Number of context parallel ranks
* @param smCount Number of SMs available (0 = auto-detect)
* @return Maximum number of channels to use
*/
int computeHelixMaxChannelCount(int cpSize, int smCount = 0);
/**
* Compute the workspace size required per rank for the all-to-all operation.
*
* @param cpSize Number of context parallel ranks
* @return Size in bytes
*/
size_t computeHelixWorkspaceSizePerRank(int cpSize);
/**
* Initialize workspace memory for a given rank.
* Should be called once during setup.
*
* @param workspace Pointer to workspace memory (per-rank view)
* @param cpSize Number of context parallel ranks
* @param stream CUDA stream for asynchronous operations
*/
void initializeHelixWorkspace(uint64_t* workspace, int cpSize, cudaStream_t stream);
/**
* Launch the helix all-to-all kernel.
*
* @param params Kernel parameters including field info and workspace
* @param allowVariableField1 Whether to allow variable field 1
* @param stream CUDA stream for kernel launch
*/
void launchHelixAllToAll(HelixAllToAllParams const& params, bool allowVariableField1, cudaStream_t stream);
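// Typical host-side call sequence (illustrative sketch only; allocating a workspace that is
// visible to all CP ranks, e.g. via MNNVL, is the caller's responsibility; variable names
// below are placeholders):
//   size_t bytesPerRank = computeHelixWorkspaceSizePerRank(cpSize);
//   // ... allocate a strided [cpSize, bytesPerRank] workspace shared across CP ranks ...
//   initializeHelixWorkspace(localWorkspacePtr, cpSize, stream);
//   HelixAllToAllParams params{};
//   // ... fill sendFields/recvFields, entryCount, workspace, workspaceStrideInU64, cpRank, cpSize ...
//   params.channelCount = 0; // 0 = auto-compute
//   params.maxChannelCount = computeHelixMaxChannelCount(cpSize);
//   launchHelixAllToAll(params, /*allowVariableField1=*/false, stream);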
} // namespace kernels
TRTLLM_NAMESPACE_END

View File

@ -18,6 +18,7 @@
#include <nanobind/nanobind.h>
#include <nanobind/stl/optional.h>
#include <nanobind/stl/vector.h>
#include <tensorrt_llm/kernels/helixAllToAll.h>
#include <tensorrt_llm/thop/attentionOp.h>
#include <tensorrt_llm/thop/moeAlltoAllMeta.h>
#include <torch/extension.h>
@ -73,5 +74,10 @@ void initBindings(nb::module_& m)
nb::arg("mla_bmm1_scale") = std::nullopt, nb::arg("mla_bmm2_scale") = std::nullopt,
nb::arg("quant_q_buffer") = std::nullopt, "Multi-head attention operation",
nb::call_guard<nb::gil_scoped_release>());
m.def(
"get_helix_workspace_size_per_rank",
[](int cp_size) { return tensorrt_llm::kernels::computeHelixWorkspaceSizePerRank(cp_size); },
nb::arg("cp_size"), "Get helix all-to-all workspace size per rank in bytes");
}
} // namespace tensorrt_llm::nanobind::thop

View File

@ -18,6 +18,7 @@
#include <pybind11/functional.h>
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include <tensorrt_llm/kernels/helixAllToAll.h>
#include <tensorrt_llm/thop/attentionOp.h>
#include <tensorrt_llm/thop/moeAlltoAllMeta.h>
#include <torch/extension.h>
@ -73,5 +74,10 @@ void initBindings(pybind11::module_& m)
py::arg("mla_bmm1_scale") = std::nullopt, py::arg("mla_bmm2_scale") = std::nullopt,
py::arg("quant_q_buffer") = std::nullopt, "Multi-head attention operation",
py::call_guard<py::gil_scoped_release>());
m.def(
"get_helix_workspace_size_per_rank",
[](int cp_size) { return tensorrt_llm::kernels::computeHelixWorkspaceSizePerRank(cp_size); },
py::arg("cp_size"), "Get helix all-to-all workspace size per rank in bytes");
}
} // namespace tensorrt_llm::pybind::thop

View File

@ -16,19 +16,12 @@
*/
#include "tensorrt_llm/common/opUtils.h"
#include "tensorrt_llm/kernels/helixAllToAll.h"
#include "tensorrt_llm/runtime/torchUtils.h"
#include "tensorrt_llm/runtime/utils/mpiUtils.h"
#include "tensorrt_llm/thop/thUtils.h"
#include <NvInferRuntime.h>
#include <c10/cuda/CUDAStream.h>
#include <cassert>
#include <set>
#include <string>
#include <torch/extension.h>
#include <vector>
#if ENABLE_MULTI_DEVICE
#include <nccl.h>
#endif // ENABLE_MULTI_DEVICE
TRTLLM_NAMESPACE_BEGIN
@ -119,6 +112,145 @@ std::vector<torch::Tensor> alltoall_helix(
#endif // ENABLE_MULTI_DEVICE
}
/**
* Helix All-to-All operation with two fields.
*
* Input tensors have shape [..., cp_size, kv_lora_rank] for partial_o and [...,
* cp_size, 2] for softmax_stats. The operation exchanges data along the cp_size
* dimension across all ranks.
*
* @param partial_o Field 0 tensor (half or bfloat16 precision, shape [..., cp_size,
* kv_lora_rank])
* @param softmax_stats Field 1 tensor (float32, shape [..., cp_size, 2])
* @param workspace Workspace tensor (uint64, strided across ranks)
* @param cp_rank Current context parallel rank
* @param cp_size Total number of context parallel ranks
* @return tuple of (partial_o_out, softmax_stats_out) with same shapes as inputs
*/
std::tuple<torch::Tensor, torch::Tensor> alltoall_helix_native(
torch::Tensor partial_o, torch::Tensor softmax_stats, torch::Tensor workspace, int64_t cp_rank, int64_t cp_size)
{
// Input validation
CHECK_TH_CUDA(partial_o);
CHECK_TH_CUDA(softmax_stats);
CHECK_TH_CUDA(workspace);
CHECK_CONTIGUOUS(partial_o);
CHECK_CONTIGUOUS(softmax_stats);
// Type checks
TORCH_CHECK(partial_o.scalar_type() == at::ScalarType::Half || partial_o.scalar_type() == at::ScalarType::BFloat16,
"partial_o must be half or bfloat16");
CHECK_TYPE(softmax_stats, at::ScalarType::Float);
CHECK_TYPE(workspace, at::ScalarType::UInt64);
// Shape validation
TORCH_CHECK(partial_o.dim() >= 2, "partial_o must have at least 2 dimensions");
TORCH_CHECK(softmax_stats.dim() >= 2, "softmax_stats must have at least 2 dimensions");
TORCH_CHECK(
partial_o.dim() == softmax_stats.dim(), "partial_o and softmax_stats must have same number of dimensions");
// Get dimensions
int kv_lora_rank = partial_o.size(-1);
TORCH_CHECK(partial_o.size(-2) == cp_size && softmax_stats.size(-2) == cp_size,
"partial_o/softmax_stats second-to-last dimension must equal cp_size");
TORCH_CHECK(softmax_stats.size(-1) % 2 == 0 && softmax_stats.size(-1) >= 2,
"softmax_stats last dimension must be divisible by 2 (float2)");
bool allowVariableField1 = softmax_stats.size(-1) > 2;
// Check that leading dimensions match
for (int i = 0; i < partial_o.dim() - 2; i++)
{
TORCH_CHECK(partial_o.size(i) == softmax_stats.size(i),
"partial_o and softmax_stats must have matching dimensions except last two");
}
TORCH_CHECK(partial_o.size(-1) * partial_o.element_size() % 16 == 0, "partial_o must be aligned to 16 bytes");
TORCH_CHECK(workspace.dim() == 2, "workspace must be 2D (strided across ranks)");
TORCH_CHECK(workspace.size(0) == cp_size, "workspace must have cp_size rows");
// Calculate entry count (product of all dimensions before cp_size)
// This is the number of entries to process per peer rank
int entry_count = 1;
for (int i = 0; i < partial_o.dim() - 2; i++)
{
entry_count *= partial_o.size(i);
}
// Reshape to 3D: [entry_count, cp_size, feature_dim]
torch::Tensor partial_o_3d = partial_o.reshape({entry_count, cp_size, kv_lora_rank});
torch::Tensor softmax_stats_3d = softmax_stats.reshape({entry_count, cp_size, softmax_stats.size(-1)});
// Allocate output tensors (same shape as input)
torch::Tensor partial_o_out = torch::empty_like(partial_o);
torch::Tensor softmax_stats_out = torch::empty_like(softmax_stats);
torch::Tensor partial_o_out_3d = partial_o_out.reshape({entry_count, cp_size, kv_lora_rank});
torch::Tensor softmax_stats_out_3d = softmax_stats_out.reshape({entry_count, cp_size, softmax_stats.size(-1)});
// Setup parameters
tensorrt_llm::kernels::HelixAllToAllParams params;
// Field 0 (variable size half)
params.sendFields[0].dataPtr = reinterpret_cast<uint8_t*>(partial_o_3d.data_ptr());
params.sendFields[0].elementCount = kv_lora_rank;
params.sendFields[0].elementSize = partial_o.element_size();
params.sendFields[0].stride = partial_o_3d.stride(1) * partial_o.element_size();
params.recvFields[0].dataPtr = reinterpret_cast<uint8_t*>(partial_o_out_3d.data_ptr());
params.recvFields[0].elementCount = kv_lora_rank;
params.recvFields[0].elementSize = partial_o.element_size();
params.recvFields[0].stride = partial_o_out_3d.stride(1) * partial_o.element_size();
// Field 1 (single float2)
params.sendFields[1].dataPtr = reinterpret_cast<uint8_t*>(softmax_stats_3d.data_ptr<float>());
params.sendFields[1].elementCount = softmax_stats.size(-1);
params.sendFields[1].elementSize = softmax_stats.element_size();
params.sendFields[1].stride = softmax_stats_3d.stride(1) * softmax_stats.element_size();
params.recvFields[1].dataPtr = reinterpret_cast<uint8_t*>(softmax_stats_out_3d.data_ptr<float>());
params.recvFields[1].elementCount = softmax_stats.size(-1);
params.recvFields[1].elementSize = softmax_stats.element_size();
params.recvFields[1].stride = softmax_stats_out_3d.stride(1) * softmax_stats.element_size();
// Entry count and workspace
params.entryCount = entry_count;
params.workspace = workspace.data_ptr<uint64_t>();
params.workspaceStrideInU64 = workspace.stride(0);
// CP info
params.cpRank = cp_rank;
params.cpSize = cp_size;
params.channelCount = 0; // auto-compute
params.maxChannelCount = tensorrt_llm::kernels::computeHelixMaxChannelCount(cp_size);
// Launch kernel
auto stream = at::cuda::getCurrentCUDAStream();
tensorrt_llm::kernels::launchHelixAllToAll(params, allowVariableField1, stream);
return std::make_tuple(partial_o_out, softmax_stats_out);
}
/**
* Initialize workspace for helix all-to-all
*/
void initialize_helix_workspace(torch::Tensor workspace, int64_t cp_rank, int64_t cp_size)
{
CHECK_TH_CUDA(workspace);
CHECK_TYPE(workspace, at::ScalarType::UInt64);
TORCH_CHECK(workspace.dim() == 2, "workspace must be 2D");
TORCH_CHECK(workspace.size(0) == cp_size, "workspace must have cp_size rows");
TORCH_CHECK(cp_rank >= 0 && cp_rank < cp_size, "cp_rank must be in [0, cp_size)");
auto stream = at::cuda::getCurrentCUDAStream();
uint64_t* global_workspace_ptr = workspace.data_ptr<uint64_t>();
uint64_t* local_workspace_ptr = workspace[cp_rank].data_ptr<uint64_t>();
TORCH_CHECK(local_workspace_ptr == global_workspace_ptr + cp_rank * workspace.stride(0),
"local_workspace_ptr must be at the correct offset in the global "
"workspace");
tensorrt_llm::kernels::initializeHelixWorkspace(local_workspace_ptr, cp_size, stream);
}
} // namespace torch_ext
TRTLLM_NAMESPACE_END
@ -126,9 +258,17 @@ TRTLLM_NAMESPACE_END
TORCH_LIBRARY_FRAGMENT(trtllm, m)
{
m.def("alltoall_helix(Tensor[] input_list, int[] group, int? num_lists) -> Tensor[]");
m.def(
"alltoall_helix_native(Tensor partial_o, Tensor softmax_stats, Tensor(a!) workspace, int "
"cp_rank, int cp_size) -> (Tensor, Tensor)");
m.def(
"initialize_helix_workspace(Tensor(a!) workspace, int cp_rank, int cp_size) "
"-> ()");
}
TORCH_LIBRARY_IMPL(trtllm, CUDA, m)
{
m.impl("alltoall_helix", &tensorrt_llm::torch_ext::alltoall_helix);
m.impl("alltoall_helix_native", &tensorrt_llm::torch_ext::alltoall_helix_native);
m.impl("initialize_helix_workspace", &tensorrt_llm::torch_ext::initialize_helix_workspace);
}

View File

@ -51,38 +51,54 @@ def _check_cu_result(cu_func_ret):
class MnnvlMemory:
"""MNNVL memory management for tensor parallel (TP) operations."""
# Shared across all subclasses (global/device state).
initialized: bool = False
current_mem_offset: int = 0
current_rank_stride: int = 0 # stride for ranks and also address space size.
current_start_address: int = 0
# allocation granularity
allocation_granularity: int = 0
# fabric address page size (512 MB)
fabric_page_size: int = 1 << 29
# MPI communicator
comm = None
fabric_page_size: int = 1 << 29 # 512 MB.
dev_id: int = None
# Per-class state attributes. These will be auto-initialized for each subclass
# to avoid polluting the parent class's state. Use callable (e.g., dict) for mutable defaults.
_per_class_attrs = {
"current_mem_offset": 0,
"current_rank_stride": 0, # stride for ranks and also address space size.
"current_start_address": 0,
"comm": None, # MPI communicator.
"allocated_map": dict, # callable for fresh dict.
"address_refcnt": dict, # callable for fresh dict.
}
# Initialize per-class state for the base class.
current_mem_offset: int = 0
current_rank_stride: int = 0
current_start_address: int = 0
comm = None
allocated_map = {}
address_refcnt = {}
def __init_subclass__(cls, **kwargs):
"""Auto-initialize per-class attributes for each subclass to avoid sharing state with parent."""
super().__init_subclass__(**kwargs)
for attr, default in cls._per_class_attrs.items():
if callable(default):
setattr(cls, attr, default()) # e.g., dict() creates a fresh dict.
else:
setattr(cls, attr, default)
def __init__(self, mapping: Mapping, size: int):
self.mapping = mapping
self.segment_size = size
self.ptr, self.rank_stride = MnnvlMemory.open_mnnvl_memory(self.mapping, size)
self.ptr, self.rank_stride = type(self).open_mnnvl_memory(self.mapping, size)
def __del__(self):
if not sys.is_finalizing():
if hasattr(self, "ptr"):
MnnvlMemory.close_mnnvl_memory(self.ptr)
type(self).close_mnnvl_memory(self.ptr)
def as_torch_strided_tensor(self, dtype):
num_segments = MnnvlMemory.comm.Get_size()
num_segments = type(self).comm.Get_size()
return pack_strided_memory(
self.ptr, self.segment_size, self.rank_stride, num_segments, dtype, MnnvlMemory.dev_id
)
@ -99,16 +115,17 @@ class MnnvlMemory:
pynvml.nvmlInit()
MnnvlMemory.initialized = True
@staticmethod
def get_comm(mapping: Mapping):
if MnnvlMemory.comm is not None:
return MnnvlMemory.comm
@classmethod
def get_comm(cls, mapping: Mapping):
"""Get TP-based communicator (ranks grouped by PP+CP+MOE_TP, ordered by TP rank)."""
if cls.comm is not None:
return cls.comm
comm = mpi_comm().Split(
(mapping.pp_rank * mapping.cp_size + mapping.cp_rank) * mapping.moe_tp_size
+ mapping.moe_tp_rank,
mapping.tp_rank,
)
MnnvlMemory.comm = comm
cls.comm = comm
return comm
@staticmethod
@ -148,23 +165,26 @@ class MnnvlMemory:
MnnvlMemory.allocation_granularity = granularity
return MnnvlMemory.allocation_granularity
@staticmethod
def new_mnnvl_memory_address(mapping: Mapping, size: int):
@classmethod
def new_mnnvl_memory_address(cls, mapping: Mapping, size: int):
page_count = (size + MnnvlMemory.fabric_page_size - 1) // MnnvlMemory.fabric_page_size
current_rank_stride = page_count * MnnvlMemory.fabric_page_size
logger.info(f"[MnnvlMemory] creating address with stride={current_rank_stride}")
comm = MnnvlMemory.get_comm(mapping)
logger.info(f"[{cls.__name__}] creating address with stride={current_rank_stride}")
comm = cls.get_comm(mapping)
comm_size = comm.Get_size()
address_size = current_rank_stride * comm_size
ptr = _check_cu_result(
cuda.cuMemAddressReserve(address_size, MnnvlMemory.fabric_page_size, 0, 0)
)
MnnvlMemory.current_start_address = int(ptr)
MnnvlMemory.current_rank_stride = current_rank_stride
MnnvlMemory.current_mem_offset = 0
cls.current_start_address = int(ptr)
cls.current_rank_stride = current_rank_stride
cls.current_mem_offset = 0
@classmethod
def open_mnnvl_memory(cls, mapping: Mapping, size: int):
# Ensure MnnvlMemory is initialized (for dev_id and allocation_granularity)
MnnvlMemory.initialize()
@staticmethod
def open_mnnvl_memory(mapping: Mapping, size: int):
dev = _check_cu_result(cuda.cuCtxGetDevice())
dev_id = int(dev)
if MnnvlMemory.dev_id is None:
@ -172,7 +192,7 @@ class MnnvlMemory:
assert dev_id == MnnvlMemory.dev_id, (
f"Different dev_id found dev_id={dev_id} but MnnvlMemory.dev_id={MnnvlMemory.dev_id}"
)
comm = MnnvlMemory.get_comm(mapping)
comm = cls.get_comm(mapping)
comm_rank = comm.Get_rank()
comm_size = comm.Get_size()
all_rank_allocate_sizes = comm.allgather(size)
@ -181,10 +201,10 @@ class MnnvlMemory:
granularity = MnnvlMemory.get_allocation_granularity(dev_id)
aligned_size = (size + granularity - 1) // granularity * granularity
if MnnvlMemory.current_mem_offset + aligned_size > MnnvlMemory.current_rank_stride:
MnnvlMemory.new_mnnvl_memory_address(mapping, aligned_size)
if cls.current_mem_offset + aligned_size > cls.current_rank_stride:
cls.new_mnnvl_memory_address(mapping, aligned_size)
assert MnnvlMemory.current_mem_offset + aligned_size <= MnnvlMemory.current_rank_stride
assert cls.current_mem_offset + aligned_size <= cls.current_rank_stride
allocation_prop = MnnvlMemory.get_allocation_prop(dev_id)
allocated_mem_handle = _check_cu_result(
@ -245,9 +265,7 @@ class MnnvlMemory:
for i, remote_handle_data in enumerate(all_handles_data):
rank_ptr = (
MnnvlMemory.current_start_address
+ MnnvlMemory.current_rank_stride * i
+ MnnvlMemory.current_mem_offset
cls.current_start_address + cls.current_rank_stride * i + cls.current_mem_offset
)
if i == comm_rank:
# Local memory mapping
@ -265,44 +283,44 @@ class MnnvlMemory:
_check_cu_result(cuda.cuMemSetAccess(rank_ptr, aligned_size, [madesc], 1))
ptr = MnnvlMemory.current_start_address + MnnvlMemory.current_mem_offset
stride = MnnvlMemory.current_rank_stride
MnnvlMemory.allocated_map[ptr] = (
ptr = cls.current_start_address + cls.current_mem_offset
stride = cls.current_rank_stride
cls.allocated_map[ptr] = (
mapping,
aligned_size,
mem_handles,
MnnvlMemory.current_start_address,
MnnvlMemory.current_rank_stride,
MnnvlMemory.current_mem_offset,
cls.current_start_address,
cls.current_rank_stride,
cls.current_mem_offset,
)
MnnvlMemory.address_refcnt[MnnvlMemory.current_start_address] = (
MnnvlMemory.address_refcnt.get(MnnvlMemory.current_start_address, 0) + 1
cls.address_refcnt[cls.current_start_address] = (
cls.address_refcnt.get(cls.current_start_address, 0) + 1
)
MnnvlMemory.current_mem_offset += aligned_size
cls.current_mem_offset += aligned_size
return ptr, stride
@staticmethod
def close_mnnvl_memory(ptr: int):
@classmethod
def close_mnnvl_memory(cls, ptr: int):
mapping, aligned_size, mem_handles, start_address, rank_stride, address_offset = (
MnnvlMemory.allocated_map.pop(ptr)
cls.allocated_map.pop(ptr)
)
comm = MnnvlMemory.get_comm(mapping)
comm = cls.get_comm(mapping)
comm_size = comm.Get_size()
for i in range(comm_size):
rank_ptr = start_address + i * rank_stride + address_offset
_check_cu_result(cuda.cuMemUnmap(rank_ptr, aligned_size))
_check_cu_result(cuda.cuMemRelease(mem_handles[i]))
MnnvlMemory.address_refcnt[start_address] -= 1
cls.address_refcnt[start_address] -= 1
if MnnvlMemory.address_refcnt[start_address] == 0:
MnnvlMemory.address_refcnt.pop(start_address)
if cls.address_refcnt[start_address] == 0:
cls.address_refcnt.pop(start_address)
device_ptr = cuda.CUdeviceptr(start_address)
_check_cu_result(cuda.cuMemAddressFree(device_ptr, comm_size * rank_stride))
if start_address == MnnvlMemory.current_start_address:
MnnvlMemory.current_start_address = 0
MnnvlMemory.current_rank_stride = 0
MnnvlMemory.current_mem_offset = 0
if start_address == cls.current_start_address:
cls.current_start_address = 0
cls.current_rank_stride = 0
cls.current_mem_offset = 0
@staticmethod
def support_nvlink(need_all_up: bool = True):
@ -338,6 +356,29 @@ class MnnvlMemory:
return support_nvlink_and_all_up
class HelixCpMnnvlMemory(MnnvlMemory):
"""MNNVL memory management for Helix context parallel (CP) operations.
Per-class state (current_mem_offset, comm, allocated_map, etc.) is automatically
initialized via __init_subclass__ in the parent class, ensuring this class has
its own isolated state separate from MnnvlMemory.
"""
@classmethod
def get_comm(cls, mapping: Mapping):
"""Get CP-based communicator (ranks grouped by PP+TP+MOE_TP, ordered by CP rank)."""
if cls.comm is not None:
return cls.comm
comm = mpi_comm().Split(
mapping.pp_rank * mapping.tp_size * mapping.moe_tp_size
+ mapping.tp_rank * mapping.moe_tp_size
+ mapping.moe_tp_rank,
mapping.cp_rank,
)
cls.comm = comm
return comm
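# Illustrative note (not part of the module): __init_subclass__ gives each subclass its own
# copies of the per-class attributes, so MnnvlMemory and HelixCpMnnvlMemory keep independent
# communicators and allocation maps, e.g.:
#     assert HelixCpMnnvlMemory.allocated_map is not MnnvlMemory.allocated_map
#     # each class caches its own communicator the first time get_comm() is called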
@dataclass
class MoEAlltoallInfo:
local_gather_indices: torch.Tensor

View File

@ -752,6 +752,17 @@ def _register_fake():
for i in range(0, len(input_list), num_ranks)
]
@torch.library.register_fake("trtllm::alltoall_helix_native")
def _(partial_o, softmax_stats, workspace, cp_rank, cp_size):
# Returns outputs with same shapes as inputs
return partial_o.new_empty(partial_o.shape), softmax_stats.new_empty(
softmax_stats.shape)
@torch.library.register_fake("trtllm::initialize_helix_workspace")
def _(workspace, cp_rank, cp_size):
# This op initializes workspace in-place and returns nothing
return None
@torch.library.register_fake("trtllm::helix_post_process")
def _(gathered_o, gathered_stats, scale):
return gathered_o.new_empty(*gathered_o.shape[1:])

View File

@ -2,9 +2,10 @@ from tensorrt_llm.functional import AllReduceFusionOp
from .communicator import Distributed, MPIDist, TorchDist
from .moe_alltoall import MoeAlltoAll
from .ops import (AllReduce, AllReduceParams, AllReduceStrategy, MoEAllReduce,
MoEAllReduceParams, allgather, alltoall_helix, cp_allgather,
reducescatter, userbuffers_allreduce_finalize)
from .ops import (AllReduce, AllReduceParams, AllReduceStrategy,
HelixAllToAllNative, MoEAllReduce, MoEAllReduceParams,
allgather, alltoall_helix, cp_allgather, reducescatter,
userbuffers_allreduce_finalize)
__all__ = [
"allgather",
@ -16,6 +17,7 @@ __all__ = [
"AllReduceParams",
"AllReduceFusionOp",
"AllReduceStrategy",
"HelixAllToAllNative",
"MoEAllReduce",
"MoEAllReduceParams",
"MoeAlltoAll",

View File

@ -2,14 +2,16 @@ import math
import os
import platform
import threading
from typing import List, Optional, Tuple, Union
from typing import Dict, List, Optional, Tuple, Union
import torch
from torch import nn
from tensorrt_llm._mnnvl_utils import HelixCpMnnvlMemory, MnnvlMemory
from tensorrt_llm._torch.distributed.symm_mem_allreduce import \
SymmetricMemoryAllReduce
from tensorrt_llm._utils import mpi_comm, mpi_disabled
from tensorrt_llm.bindings import internal as _tllm_internal
from tensorrt_llm.bindings.internal.runtime import McastGPUBuffer
from tensorrt_llm.functional import (AllReduceFusionOp, AllReduceParams,
AllReduceStrategy, MoEAllReduceParams)
@ -363,6 +365,84 @@ def alltoall_helix(
return torch.ops.trtllm.alltoall_helix(inputs, group, num_lists)
class HelixAllToAllNative:
"""
Manager for Helix All-to-All operations with MNNVL workspace management.
Exchanges data along the cp_size dimension:
- partial_o: [..., cp_size, kv_lora_rank] half-precision
- softmax_stats: [..., cp_size, 2] float32
"""
# Global cache: mapping -> instance
_cache: Dict[Mapping, "HelixAllToAllNative"] = {}
def __init__(self, mapping: Mapping, workspace: HelixCpMnnvlMemory,
workspace_tensor: torch.Tensor):
"""Private constructor - use get() instead."""
self.mapping = mapping
self.workspace = workspace
self.workspace_tensor = workspace_tensor
@staticmethod
def get(mapping: Mapping) -> "HelixAllToAllNative":
"""
Get or create a HelixAllToAllNative instance for the given configuration.
Args:
mapping: TensorRT-LLM mapping object containing cp_size and cp_rank
Returns:
Cached or newly-created HelixAllToAllNative instance
"""
if mapping not in HelixAllToAllNative._cache:
logger.info(
f"Rank {mapping.cp_rank} initializing HelixCpMnnvlMemory for Helix"
)
MnnvlMemory.initialize()
# Get workspace size (in bytes)
workspace_size_per_rank = _tllm_internal.thop.get_helix_workspace_size_per_rank(
mapping.cp_size)
# Allocate MNNVL memory using CP communicator for Helix
workspace = HelixCpMnnvlMemory(mapping, workspace_size_per_rank)
workspace_tensor = workspace.as_torch_strided_tensor(torch.uint64)
torch.ops.trtllm.initialize_helix_workspace(workspace_tensor,
mapping.cp_rank,
mapping.cp_size)
torch.cuda.synchronize()
HelixCpMnnvlMemory.get_comm(mapping).barrier()
HelixAllToAllNative._cache[mapping] = HelixAllToAllNative(
mapping, workspace, workspace_tensor)
return HelixAllToAllNative._cache[mapping]
def alltoall_native(self, partial_o: torch.Tensor,
softmax_stats: torch.Tensor):
"""
Perform all-to-all data exchange.
Args:
partial_o: Tensor with shape [..., cp_size, kv_lora_rank], dtype half.
softmax_stats: Tensor with shape [..., cp_size, 2], dtype float32.
Returns:
Tuple of (partial_o_out, softmax_stats_out) with same shapes as inputs.
"""
partial_o_out, softmax_stats_out = torch.ops.trtllm.alltoall_helix_native(
partial_o,
softmax_stats,
self.workspace_tensor,
self.mapping.cp_rank,
self.mapping.cp_size,
)
return partial_o_out, softmax_stats_out
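# Usage sketch (illustrative only): `mapping` is assumed to be a tensorrt_llm Mapping with
# cp_size/cp_rank set, and the tensor shapes follow the MLA helix layout used in modeling code.
#
#     helix = HelixAllToAllNative.get(mapping)  # allocates and initializes the MNNVL workspace once
#     # partial_o:     [num_tokens, num_heads_tp_cp, cp_size, kv_lora_rank], half/bf16, contiguous
#     # softmax_stats: [num_tokens, num_heads_tp_cp, cp_size, 2], float32, contiguous
#     partial_o_out, softmax_stats_out = helix.alltoall_native(partial_o, softmax_stats)
#     # Internally this calls torch.ops.trtllm.alltoall_helix_native with the cached strided
#     # workspace tensor; the outputs have the same shapes as the inputs.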
def reducescatter(
input: Union[torch.Tensor, List[torch.Tensor]],
mapping: Mapping,

View File

@ -20,7 +20,7 @@ from ..attention_backend.interface import (AttentionBackend, AttentionMask,
from ..attention_backend.sparse.dsa import (
DSAtrtllmAttentionMetadata, transform_local_topk_and_prepare_pool_view)
from ..attention_backend.utils import create_attention, get_attention_backend
from ..distributed import AllReduceParams, alltoall_helix
from ..distributed import AllReduceParams, HelixAllToAllNative, alltoall_helix
from ..model_config import ModelConfig
from ..peft.lora.layer import LoraLayer, LoraModuleType
from ..utils import (Fp4QuantizedTensor, get_model_extra_attrs,
@ -1120,24 +1120,60 @@ class MLA(nn.Module):
softmax_stats_tensor=softmax_stats,
**kwargs,
)
# this is the post-processing of helix parallel attention,
# similar to the post-processing of ring attention
kv_lora_rank = partial_o.shape[-1] // self.num_heads_tp
assert self.kv_lora_rank == kv_lora_rank
# transpose the tensors to make the split across cp_size contiguous
# for both tensors, we need to split across the second dimension
chunks = []
for t in [partial_o, softmax_stats]:
t = t.transpose(1, 0).contiguous()
chunks.extend(torch.split(t,
t.shape[0] // self.mapping.cp_size))
gathered = alltoall_helix(chunks, self.mapping.cp_group)
# transpose the tensors back to ensure dimensions are ordered correctly
# note: an additional dimension was added at the first index for all-to-all,
# so the transpose dimensions are shifted by 1
gathered = [t.transpose(1, 2).contiguous() for t in gathered]
return torch.ops.trtllm.helix_post_process(gathered[0], gathered[1],
1.0)
# Switch between NCCL-based and FIFO-based (MNNVL) all-to-all based on cp_config.
if self.mapping.cp_config.get("use_nccl_for_alltoall", True):
# NCCL-based implementation using alltoall_helix.
# This is the post-processing of helix parallel attention,
# similar to the post-processing of ring attention.
# Transpose the tensors to make the split across cp_size contiguous
# For both tensors, we need to split across the second dimension.
chunks = []
for t in [partial_o, softmax_stats]:
t = t.transpose(1, 0).contiguous()
chunks.extend(
torch.split(t, t.shape[0] // self.mapping.cp_size))
gathered = alltoall_helix(chunks, self.mapping.cp_group)
# Transpose the tensors back to ensure dimensions are ordered correctly.
# Note: an additional dimension was added at the first index for all-to-all,
# so the transpose dimensions are shifted by 1.
gathered = [t.transpose(1, 2).contiguous() for t in gathered]
return torch.ops.trtllm.helix_post_process(
gathered[0], gathered[1], 1.0)
else:
# FIFO-based implementation using MNNVL workspace and LL128 Proto.
# Get or create Helix All-to-All instance.
helix = HelixAllToAllNative.get(self.mapping)
# Get dimensions.
num_tokens = partial_o.shape[0]
cp_size = self.mapping.cp_size
# Reshape for FIFO-based all-to-all.
# partial_o: [num_tokens, num_heads_tp * kv_lora_rank] -> [num_tokens, cp_size, num_heads_tp_cp, kv_lora_rank]
# softmax_stats: [num_tokens, num_heads_tp, 2] -> [num_tokens, cp_size, num_heads_tp_cp, 2]
partial_o = partial_o.view(
num_tokens, cp_size, self.num_heads_tp_cp,
kv_lora_rank).transpose(1, 2).contiguous()
softmax_stats = softmax_stats.view(num_tokens, cp_size,
self.num_heads_tp_cp,
2).transpose(1,
2).contiguous()
# Call FIFO-based helixAllToAll.
partial_o_out, softmax_stats_out = helix.alltoall_native(
partial_o, softmax_stats)
# partial_o_out: [num_tokens, num_heads_tp_cp, cp_size, kv_lora_rank]
# softmax_stats_out: [num_tokens, num_heads_tp_cp, cp_size, 2]
# cp_dim = 2 (the dimension where cp_size is located)
# Call helix_post_process_native with cp_dim=2.
return torch.ops.trtllm.helix_post_process_native(
partial_o_out, softmax_stats_out, 1.0, 2)
else:
attn_output = attn_backend.forward(q, k, v, attn_metadata, **kwargs)
return attn_output

View File

@ -887,7 +887,9 @@ class TestDeepSeekV3Lite(LlmapiAccuracyTestHarness):
"cudagraph:none", "cudagraph:without_padding",
"cudagraph:with_padding"
])
def test_auto_dtype_with_helix(self, cuda_graph_config):
@pytest.mark.parametrize("comms_medium", ["fifo", "nccl"])
def test_auto_dtype_with_helix(self, comms_medium, cuda_graph_config):
use_nccl_for_alltoall = comms_medium == "nccl"
kv_cache_config = {
"free_gpu_memory_fraction": 0.5,
"enable_block_reuse": False,
@ -912,7 +914,8 @@ class TestDeepSeekV3Lite(LlmapiAccuracyTestHarness):
"context_parallel_size": 2,
"cp_config": {
"cp_type": "HELIX",
"tokens_per_block": 32
"tokens_per_block": 32,
"use_nccl_for_alltoall": use_nccl_for_alltoall
},
"disable_overlap_scheduler": True,
"kv_cache_config": kv_cache_config,

View File

@ -535,9 +535,12 @@ accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_guided_decoding
accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_guided_decoding[xgrammar-mtp_nextn=2]
accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_guided_decoding[llguidance-mtp_nextn=0]
accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_guided_decoding[llguidance-mtp_nextn=2]
accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[cudagraph:none]
accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[cudagraph:without_padding]
accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[cudagraph:with_padding]
accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[nccl-cudagraph:none]
accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[nccl-cudagraph:without_padding]
accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[nccl-cudagraph:with_padding]
accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[fifo-cudagraph:none]
accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[fifo-cudagraph:without_padding]
accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[fifo-cudagraph:with_padding]
accuracy/test_disaggregated_serving.py::TestGemma3_1BInstruct::test_auto_dtype[False]
accuracy/test_disaggregated_serving.py::TestGemma3_1BInstruct::test_auto_dtype[True]
accuracy/test_disaggregated_serving.py::TestGPTOSS::test_auto_dtype[True]

View File

@ -69,7 +69,8 @@ l0_gb200_multi_gpus:
- accuracy/test_llm_api_pytorch.py::TestQwen3NextInstruct::test_nvfp4[tp4ep4-cutlass]
- accuracy/test_llm_api_pytorch.py::TestQwen3NextInstruct::test_nvfp4[no_cuda_graph_overlap-cutlass]
- accuracy/test_llm_api_pytorch.py::TestQwen3NextInstruct::test_nvfp4[tp4ep4-trtllm]
- accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[cudagraph:without_padding]
- accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[nccl-cudagraph:none]
- accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[fifo-cudagraph:none]
- condition:
ranges:
system_gpu_count:
@ -84,8 +85,10 @@ l0_gb200_multi_gpus:
stage: post_merge
backend: pytorch
tests:
- accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[cudagraph:with_padding]
- accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[cudagraph:none]
- accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[nccl-cudagraph:without_padding]
- accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[nccl-cudagraph:with_padding]
- accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[fifo-cudagraph:without_padding]
- accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[fifo-cudagraph:with_padding]
- accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_fp8_4gpus[tp4-fp8kv=True-attn_backend=FLASHINFER-torch_compile=True]
- accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_fp8_4gpus[pp4-fp8kv=False-attn_backend=TRTLLM-torch_compile=False]
- accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_nvfp4_4gpus[moe_backend=CUTLASS-mtp_nextn=0-ep4-fp8kv=True-attention_dp=True-cuda_graph=True-overlap_scheduler=True-torch_compile=False]

View File

@ -644,7 +644,13 @@ def _run_mla_distributed(
@torch.inference_mode
def _full_test_multi_gpu(rank: int, world_size: int, scenario: Scenario, gen_steps: int):
def _full_test_multi_gpu(
rank: int,
world_size: int,
scenario: Scenario,
gen_steps: int,
comms_medium: str = "fifo",
):
if scenario.rope_scaling:
rope_scaling = {
"beta_fast": scenario.rope_beta_fast,
@ -814,7 +820,13 @@ def _full_test_multi_gpu(rank: int, world_size: int, scenario: Scenario, gen_ste
# Distributed mapping for helix
mapping = Mapping(
world_size=world_size, rank=rank, cp_size=world_size, cp_config={"cp_type": CpType.HELIX}
world_size=world_size,
rank=rank,
cp_size=world_size,
cp_config={
"cp_type": CpType.HELIX,
"use_nccl_for_alltoall": comms_medium == "nccl",
},
)
# we use cp_allgather here because there is no broadcast op across CP group
ref_output_all = cp_allgather(ref_output, mapping=mapping, dim=0)
@ -849,18 +861,24 @@ def _run_single_rank(func, *args, **kwargs):
@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="needs 2 GPUs to run this test")
@pytest.mark.parametrize("scenario", test_scenarios, ids=lambda x: f"scenario: {x}")
@pytest.mark.parametrize("comms_medium", ["nccl", "fifo"])
def test_mla_helix_distributed(
scenario: Scenario,
comms_medium: str,
gen_steps: Optional[int] = None,
max_mismatch_ratio: float = 0.02,
mismatch_ratios: Optional[List[float]] = None,
):
world_size = 2
print(f"Testing with comms_medium={comms_medium}.")
gen_steps = scenario.ref_steps if gen_steps is None else gen_steps
with MPIPoolExecutor(max_workers=world_size) as executor:
results = executor.map(
_run_single_rank,
*zip(*[(_full_test_multi_gpu, world_size, scenario, gen_steps)] * world_size),
*zip(
*[(_full_test_multi_gpu, world_size, scenario, gen_steps, comms_medium)]
* world_size
),
)
if mismatch_ratios is None:
for ratio_mismatch in results:
@ -870,13 +888,22 @@ def test_mla_helix_distributed(
if __name__ == "__main__":
for scenario in all_scenarios[:11]:
timing_steps = 256
gen_steps = scenario.ref_steps + timing_steps
print(f"Running scenario: {scenario} and timing {timing_steps} steps")
mismatch_ratios = []
test_mla_helix_distributed(scenario, gen_steps=gen_steps, mismatch_ratios=mismatch_ratios)
if any(mismatch > 0 for mismatch in mismatch_ratios):
print(f"Numerical test failed with mismatch ratios: {mismatch_ratios}")
else:
print("Numerical test passed")
for comms_medium in ["fifo", "nccl"]:
print(f"\n{'=' * 60}")
print(f"Testing with comms_medium={comms_medium}")
print(f"{'=' * 60}\n")
for scenario in all_scenarios[:11]:
timing_steps = 256
gen_steps = scenario.ref_steps + timing_steps
print(f"Running scenario: {scenario} and timing {timing_steps} steps")
mismatch_ratios = []
test_mla_helix_distributed(
scenario,
comms_medium=comms_medium,
gen_steps=gen_steps,
mismatch_ratios=mismatch_ratios,
)
if any(mismatch > 0 for mismatch in mismatch_ratios):
print(f"Numerical test failed with mismatch ratios: {mismatch_ratios}")
else:
print("Numerical test passed")