From 8c1cfc872bcb077e7628b918e324976a599ba6b1 Mon Sep 17 00:00:00 2001 From: Balaram Buddharaju <169953907+brb-nv@users.noreply.github.com> Date: Tue, 23 Dec 2025 18:14:30 -0800 Subject: [PATCH] [TRTLLM-9493][feat] Custom AllToAll for helix parallelism (#9986) Signed-off-by: Balaram Buddharaju <169953907+brb-nv@users.noreply.github.com> --- cpp/tensorrt_llm/kernels/helixAllToAll.cu | 683 ++++++++++++++++++ cpp/tensorrt_llm/kernels/helixAllToAll.h | 94 +++ cpp/tensorrt_llm/nanobind/thop/bindings.cpp | 6 + cpp/tensorrt_llm/pybind/thop/bindings.cpp | 6 + cpp/tensorrt_llm/thop/alltoallOp.cpp | 158 +++- tensorrt_llm/_mnnvl_utils.py | 155 ++-- .../_torch/custom_ops/cpp_custom_ops.py | 11 + tensorrt_llm/_torch/distributed/__init__.py | 8 +- tensorrt_llm/_torch/distributed/ops.py | 82 ++- tensorrt_llm/_torch/modules/attention.py | 70 +- .../accuracy/test_disaggregated_serving.py | 7 +- .../test_lists/qa/llm_function_core.txt | 9 +- .../test-db/l0_gb200_multi_gpus.yml | 9 +- .../unittest/_torch/modules/test_mla_helix.py | 53 +- 14 files changed, 1243 insertions(+), 108 deletions(-) create mode 100644 cpp/tensorrt_llm/kernels/helixAllToAll.cu create mode 100644 cpp/tensorrt_llm/kernels/helixAllToAll.h diff --git a/cpp/tensorrt_llm/kernels/helixAllToAll.cu b/cpp/tensorrt_llm/kernels/helixAllToAll.cu new file mode 100644 index 0000000000..09e38f7d48 --- /dev/null +++ b/cpp/tensorrt_llm/kernels/helixAllToAll.cu @@ -0,0 +1,683 @@ +/* + * Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "tensorrt_llm/common/assert.h" +#include "tensorrt_llm/common/cudaUtils.h" +#include "tensorrt_llm/common/envUtils.h" +#include "tensorrt_llm/kernels/cudaAsyncOps.cuh" +#include "tensorrt_llm/kernels/helixAllToAll.h" +#include "tensorrt_llm/kernels/ll128Proto.cuh" +#include "tensorrt_llm/kernels/moeCommKernelsCommon.h" + +#include +#include +#include + +TRTLLM_NAMESPACE_BEGIN + +namespace kernels +{ + +namespace +{ + +// ============================================================================ +// Structure declarations and definitions +// ============================================================================ + +// ALIGN_256 is defined in moeCommKernelsCommon.h + +struct ALIGN_256 HelixFifoInfo +{ + volatile int64_t head; + volatile int64_t tail; +}; + +// ============================================================================ +// Helix-specific FIFO constants +// Note: Helix uses 128KB FIFO entries vs 256KB in FusedMoe +// ============================================================================ + +constexpr int HELIX_FIFO_DEPTH = 4; +constexpr int HELIX_FIFO_ENTRY_BYTES = 128 * 1024; +constexpr int HELIX_FIFO_TOTAL_BYTES = HELIX_FIFO_ENTRY_BYTES * HELIX_FIFO_DEPTH; +constexpr int HELIX_FIFO_ENTRY_128B_COUNT = HELIX_FIFO_ENTRY_BYTES / BYTES_PER_128B_BLOCK; +constexpr int HELIX_FIFO_TOTAL_U64 = HELIX_FIFO_TOTAL_BYTES / sizeof(uint64_t); + +// ============================================================================ +// Implementation-only structures +// ============================================================================ + +struct HelixPairInfo +{ + int senderRank; + int receiverRank; + int channel; + int runChannelCount; +}; + +// WARP_SIZE, WARP_MASK, and other constants are defined in moeCommKernelsCommon.h + +// ============================================================================ +// Helper Functions +// ============================================================================ + +__host__ __device__ inline int getFieldSize(HelixFieldInfo const& fieldInfo) +{ + return fieldInfo.elementCount * fieldInfo.elementSize; +} + +__host__ __device__ inline uint8_t* getPtr(HelixFieldInfo const& fieldInfo, int blockIdx) +{ + return fieldInfo.dataPtr + blockIdx * fieldInfo.stride; +} + +__device__ __forceinline__ void waitG2sAllFields(uint64_t* smemBar, uint32_t* phaseParity) +{ + cp_async_wait_group<0>(); + smemBarWait(smemBar, phaseParity); +} + +// Align size to 128 bytes +__host__ __device__ __forceinline__ int align128(int size) +{ + return align_up(size, BYTES_PER_128B_BLOCK); +} + +// ============================================================================ +// G2S (Global to Shared) Operations +// ============================================================================ + +__device__ __forceinline__ void g2sField( + HelixFieldInfo const& fieldInfo, int dataIndex, uint8_t* shmemBase, int shmemOffset, uint64_t* smemBar, int laneId) +{ + int copySize = getFieldSize(fieldInfo); + if (copySize > 0 && laneId == 0) + { + uint8_t* srcPtr = getPtr(fieldInfo, dataIndex); + uint8_t* dstPtr = shmemBase + shmemOffset; + cp_async_bulk_g2s(dstPtr, srcPtr, copySize, smemBar); + } +} + +template +__device__ __forceinline__ int g2sAllFields( + HelixFieldInfo const* fieldInfo, int dataIndex, uint8_t* shmemBase, uint64_t* smemBar, int laneId) +{ + int totalSize = 0; + + // Load field 0 (variable size half) + g2sField(fieldInfo[0], dataIndex, shmemBase, 0, smemBar, laneId); + int field0Size = getFieldSize(fieldInfo[0]); + totalSize += field0Size; + + 
// Load field 1 (single float2) + if constexpr (ALLOW_VARIABLE_FIELD1) + { + g2sField(fieldInfo[1], dataIndex, shmemBase, totalSize, smemBar, laneId); + totalSize += getFieldSize(fieldInfo[1]); + } + else + { + ldgsts<8>(reinterpret_cast(shmemBase + totalSize), + reinterpret_cast(getPtr(fieldInfo[1], dataIndex)), laneId == 0); + cp_async_commit_group(); + } + + return totalSize; +} + +// ============================================================================ +// S2G (Shared to Global) Operations +// ============================================================================ + +__device__ __forceinline__ void s2gField( + HelixFieldInfo const& fieldInfo, int dataIndex, uint8_t* shmemBase, int shmemOffset, int laneId) +{ + int copySize = getFieldSize(fieldInfo); + if (copySize > 0 && laneId == 0) + { + uint8_t* srcPtr = shmemBase + shmemOffset; + uint8_t* dstPtr = getPtr(fieldInfo, dataIndex); + cp_async_bulk_s2g(dstPtr, srcPtr, copySize); + } +} + +template +__device__ __forceinline__ void s2gAllFields( + HelixFieldInfo const* fieldInfo, int dataIndex, uint8_t* shmemBase, int laneId) +{ + int offset = 0; + + // Store field 0 (variable size half) + s2gField(fieldInfo[0], dataIndex, shmemBase, offset, laneId); + int field0Size = getFieldSize(fieldInfo[0]); + offset += field0Size; + + // Store field 1 (single float2) + if constexpr (ALLOW_VARIABLE_FIELD1) + { + s2gField(fieldInfo[1], dataIndex, shmemBase, offset, laneId); + offset += getFieldSize(fieldInfo[1]); + } + else + { + if (laneId == 0) + { + auto* srcPtr = reinterpret_cast(reinterpret_cast(shmemBase) + offset); + auto* dstPtr = reinterpret_cast(getPtr(fieldInfo[1], dataIndex)); + dstPtr[0] = srcPtr[0]; + } + } + cp_async_bulk_commit_group(); +} + +// ============================================================================ +// Workspace FIFO Operations +// ============================================================================ + +__device__ __forceinline__ uint64_t* getFifoBasePtr(HelixAllToAllParams const& params, HelixPairInfo const& pairInfo) +{ + // FIFO is physically located at receiver rank + int mappedMemoryRank = pairInfo.receiverRank; + int rankInsideMappedMemory = pairInfo.senderRank; + + auto* mappedMemory = params.workspace + mappedMemoryRank * params.workspaceStrideInU64; + // Navigate to the right FIFO: [peer_rank][channel] + size_t fifoOffset = rankInsideMappedMemory * params.maxChannelCount * HELIX_FIFO_TOTAL_U64; + fifoOffset += pairInfo.channel * HELIX_FIFO_TOTAL_U64; + + return mappedMemory + fifoOffset; +} + +__device__ __forceinline__ HelixFifoInfo* getSenderHelixFifoInfo( + HelixAllToAllParams const& params, HelixPairInfo const& pairInfo) +{ + // SenderSideHelixFifoInfo is physically located at sender rank + int mappedMemoryRank = pairInfo.senderRank; + int rankInsideMappedMemory = pairInfo.receiverRank; + + auto* mappedMemory = reinterpret_cast(params.workspace + mappedMemoryRank * params.workspaceStrideInU64); + size_t fieldOffset = static_cast(HELIX_FIFO_TOTAL_BYTES) * params.cpSize * params.maxChannelCount; + mappedMemory += fieldOffset; + mappedMemory += rankInsideMappedMemory * params.maxChannelCount * sizeof(HelixFifoInfo); + mappedMemory += pairInfo.channel * sizeof(HelixFifoInfo); + + return reinterpret_cast(mappedMemory); +} + +__device__ __forceinline__ HelixFifoInfo* getReceiverHelixFifoInfo( + HelixAllToAllParams const& params, HelixPairInfo const& pairInfo) +{ + // ReceiverSideHelixFifoInfo is physically located at receiver rank + int mappedMemoryRank = pairInfo.receiverRank; + int 
rankInsideMappedMemory = pairInfo.senderRank; + + auto* mappedMemory = reinterpret_cast(params.workspace + mappedMemoryRank * params.workspaceStrideInU64); + size_t fieldOffset = static_cast(HELIX_FIFO_TOTAL_BYTES) * params.cpSize * params.maxChannelCount; + fieldOffset += sizeof(HelixFifoInfo) * params.cpSize * params.maxChannelCount; + mappedMemory += fieldOffset; + mappedMemory += rankInsideMappedMemory * params.maxChannelCount * sizeof(HelixFifoInfo); + mappedMemory += pairInfo.channel * sizeof(HelixFifoInfo); + + return reinterpret_cast(mappedMemory); +} + +__device__ __forceinline__ void startWorkspaceS2G( + uint64_t* fifoEntry, uint8_t* shmemBase, int send128ByteCount, int fifo128ByteOffset, int laneId) +{ + int copyByteCount = send128ByteCount * BYTES_PER_128B_BLOCK; + if (laneId == 0) + { + cp_async_bulk_s2g( + fifoEntry + fifo128ByteOffset * BYTES_PER_128B_BLOCK / sizeof(uint64_t), shmemBase, copyByteCount); + } + cp_async_bulk_commit_group(); +} + +__device__ __forceinline__ void startWorkspaceS2GReg( + uint64_t* fifoEntry, uint8_t* sharedMemoryBase, int send128ByteCount, int fifo128ByteOffset, int laneId) +{ + int copyInt4Count = send128ByteCount * BYTES_PER_128B_BLOCK / sizeof(int4); + int4* sharedMemoryInt4 = reinterpret_cast(sharedMemoryBase); + uint64_t* fifoPtr = fifoEntry + fifo128ByteOffset * UINT64_PER_128B_BLOCK; + int4* fifoPtrInt4 = reinterpret_cast(fifoPtr); +#pragma unroll 4 + for (int i = laneId; i < copyInt4Count; i += WARP_SIZE) + { + fifoPtrInt4[i] = sharedMemoryInt4[i]; + } +} + +__device__ __forceinline__ uint64_t startWorkspaceG2S(uint8_t* shmemBase, uint64_t* fifoEntry, int allLoad128ByteCount, + int fifo128ByteOffset, int loaded128ByteCount, uint64_t* smemBar, int laneId) +{ + int copyByteCount = (allLoad128ByteCount - loaded128ByteCount) * BYTES_PER_128B_BLOCK; + if (laneId == 0) + { + cp_async_bulk_g2s(shmemBase + loaded128ByteCount * BYTES_PER_128B_BLOCK, + fifoEntry + (fifo128ByteOffset + loaded128ByteCount) * UINT64_PER_128B_BLOCK, copyByteCount, smemBar); + } + return mbarrier_arrive_expect_tx(smemBar, laneId == 0 ? 
copyByteCount : 0); +} + +// LL128Proto is now defined in ll128Proto.cuh + +// ============================================================================ +// Size helpers +// ============================================================================ + +// Compute total size needed for both fields +__host__ __device__ __forceinline__ int computeTotalUnpackedSize(HelixFieldInfo const* fields) +{ + int size = 0; + // Field 0: note it must be aligned to 16 bytes + size += align_up(getFieldSize(fields[0]), 16); + // Field 1: single float2 + size += align_up(getFieldSize(fields[1]), 16); + return align128(size); +} + +__host__ __device__ __forceinline__ int computeTotalPackedSize(HelixFieldInfo const* fields) +{ + // because field 0 must be aligned to 16 bytes, this is the same as unpacked + return computeTotalUnpackedSize(fields); +} + +__host__ __device__ __forceinline__ int computeProtoTransferSize(HelixFieldInfo const* fields) +{ + return LL128Proto::computeProtoTransfer128ByteAlignedSize(computeTotalPackedSize(fields)); +} + +// ============================================================================ +// Main All-to-All Kernel +// ============================================================================ + +template +__global__ void helixAllToAllKernel(HelixAllToAllParams params) +{ + extern __shared__ uint8_t allWarpShmem[]; + __shared__ uint64_t allWarpSmemBar[MAX_GROUP_COUNT_PER_BLOCK]; + + bool isSender = (blockIdx.z == 0); + // Each warp is a group handling a different peer rank + int group = __shfl_sync(WARP_MASK, threadIdx.y, 0); + int laneId = threadIdx.x % WARP_SIZE; + int runChannelCount = gridDim.y; + + // Compute peer rank: blockIdx.x determines which set of peers, group + // determines which peer in that set + int peerRank = blockIdx.x * blockDim.y + group; + + if (peerRank >= params.cpSize) + { + return; + } + + // Setup pair info for this communication + HelixPairInfo pairInfo; + pairInfo.channel = blockIdx.y; + pairInfo.runChannelCount = runChannelCount; + pairInfo.senderRank = isSender ? params.cpRank : peerRank; + pairInfo.receiverRank = isSender ? peerRank : params.cpRank; + + // Initialize barrier for this group + initSmemBar(&allWarpSmemBar[group], laneId); + uint32_t phaseParity = 0; + + // Get shared memory for this group + int singlePackedSize = computeTotalPackedSize(params.sendFields); + int singlePacked128ByteCount = singlePackedSize / BYTES_PER_128B_BLOCK; + int singleUnpackedSize = computeTotalUnpackedSize(params.sendFields); + int singleProtoTransferSize = computeProtoTransferSize(params.sendFields); + int singleProtoTransfer128ByteCount = singleProtoTransferSize / BYTES_PER_128B_BLOCK; + int singleShmSize = std::max(singleUnpackedSize, singleProtoTransferSize); + uint8_t* shmem = allWarpShmem + group * singleShmSize; + + // Get FIFO pointers + uint64_t* fifoBase = getFifoBasePtr(params, pairInfo); + HelixFifoInfo* senderFifo = getSenderHelixFifoInfo(params, pairInfo); + HelixFifoInfo* receiverFifo = getReceiverHelixFifoInfo(params, pairInfo); + + int fifoEntry128ByteIndexBase = HELIX_FIFO_ENTRY_128B_COUNT; + int fifoEntryIndex = -1; + + // regardless of sender or receiver, we wait for the previous kernel here + // receiver blocks do not need to wait at all, but they should not start + // to stress the memory system regardless +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) + cudaGridDependencySynchronize(); +#endif + + if (isSender) + { + // sender blocks should trigger next kernel immediately, s.t. 
they + // do not block the next kernel from starting +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) + cudaTriggerProgrammaticLaunchCompletion(); +#endif + + // Sender logic: send data from cpRank's slice to peerRank + int64_t head = senderFifo->head; + int64_t tail = senderFifo->tail; + + // Each channel processes entries with stride + // Start at channel index, increment by total channel count + for (int entryIdx = pairInfo.channel; entryIdx < params.entryCount; entryIdx += runChannelCount) + { + + // dataIndex points to the data for peerRank in this entry + int dataIndex = entryIdx * params.cpSize + peerRank; + + // Load data from global to shared, then arrive on barrier + int loadedSize = g2sAllFields( + params.sendFields, dataIndex, shmem, &allWarpSmemBar[group], laneId); + uint64_t arriveState = mbarrier_arrive_expect_tx(&allWarpSmemBar[group], laneId == 0 ? loadedSize : 0); + + // update FIFO entry index and head if needed + if (fifoEntry128ByteIndexBase + singleProtoTransfer128ByteCount > HELIX_FIFO_ENTRY_128B_COUNT) + { + if (fifoEntryIndex >= 0) + { + head++; + __syncwarp(); + senderFifo->head = head; + } + fifoEntryIndex = head % HELIX_FIFO_DEPTH; + fifoEntry128ByteIndexBase = 0; + while (tail + HELIX_FIFO_DEPTH <= head) + { + tail = senderFifo->tail; + } + __syncwarp(); + } + + // wait for data to be loaded into shared memory + waitG2sAllFields(&allWarpSmemBar[group], &phaseParity); + // note: we don't need to pack anything, fields are already packed in + // shared memory + + LL128Proto::protoPack(shmem, head, singlePacked128ByteCount, fifoEntry128ByteIndexBase, laneId); + + uint64_t* fifoEntry = fifoBase + fifoEntryIndex * (HELIX_FIFO_ENTRY_BYTES / sizeof(uint64_t)); + + // Copy from shared to workspace FIFO + startWorkspaceS2GReg(fifoEntry, shmem, singleProtoTransfer128ByteCount, fifoEntry128ByteIndexBase, laneId); + + fifoEntry128ByteIndexBase += singleProtoTransfer128ByteCount; + + // ensure that we can over-write shmem in next iteration + // (it must be fully read by all threads when doing S2G above) + __syncwarp(); + } + if (fifoEntry128ByteIndexBase > 0) + { + head++; + senderFifo->head = head; + } + } + else + { + // Receiver logic: receive data from peerRank to cpRank's slice + int64_t tail = receiverFifo->tail; + bool needRelease = false; + + // Each channel processes entries with stride + // Start at channel index, increment by total channel count + for (int entryIdx = pairInfo.channel; entryIdx < params.entryCount; entryIdx += runChannelCount) + { + // receiver blocks should trigger next kernel at last iteration + // note: some blocks might not even go into this for-loop, but they + // would exit which is equivalent to the pre-exit trigger +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) + if (entryIdx + runChannelCount >= params.entryCount) + { + cudaTriggerProgrammaticLaunchCompletion(); + } +#endif + // dataIndex points to where we receive data from peerRank in this entry + int dataIndex = entryIdx * params.cpSize + peerRank; + int loaded128ByteCount = 0; + + if (fifoEntry128ByteIndexBase + singleProtoTransfer128ByteCount > HELIX_FIFO_ENTRY_128B_COUNT) + { + if (fifoEntryIndex >= 0) + { + tail++; + needRelease = true; + } + fifoEntryIndex = tail % HELIX_FIFO_DEPTH; + fifoEntry128ByteIndexBase = 0; + // receiver doesn't need to wait on FIFO entry being readable: it's + // always readable + __syncwarp(); + } + + uint64_t* fifoEntry = fifoBase + fifoEntryIndex * (HELIX_FIFO_ENTRY_BYTES / sizeof(uint64_t)); + while (loaded128ByteCount < 
singleProtoTransfer128ByteCount) + { + startWorkspaceG2S(shmem, fifoEntry, singleProtoTransfer128ByteCount, fifoEntry128ByteIndexBase, + loaded128ByteCount, &allWarpSmemBar[group], laneId); + if (needRelease) + { + receiverFifo->tail = tail; + senderFifo->tail = tail; + needRelease = false; + } + smemBarWait(&allWarpSmemBar[group], &phaseParity); + loaded128ByteCount += LL128Proto::template checkDataReceivedInShm(shmem, tail, + singleProtoTransfer128ByteCount, fifoEntry128ByteIndexBase, loaded128ByteCount, laneId); + } + + LL128Proto::protoUnpack( + shmem, tail, singlePacked128ByteCount, fifoEntry128ByteIndexBase, loaded128ByteCount, laneId); + + // note: fields are already unpacked in shared memory + s2gAllFields(params.recvFields, dataIndex, shmem, laneId); + // wait for data to be read from shared memory + cp_async_bulk_wait_group_read<0>(); + + // note: LL128Proto doesn't need rearm + // rearmFifoBuffer(); + fifoEntry128ByteIndexBase += singleProtoTransfer128ByteCount; + } + if (fifoEntry128ByteIndexBase > 0) + { + tail++; + receiverFifo->tail = tail; + senderFifo->tail = tail; + } + } +} + +// ============================================================================ +// Compute actual channel count +// ============================================================================ + +struct hash_cache_key +{ + size_t operator()(std::tuple const& x) const + { + return std::get<0>(x) ^ std::get<1>(x) ^ std::get<2>(x); + } +}; + +template +std::tuple computeChannelAndGroupCount(int cpSize, HelixFieldInfo const* fields) +{ + static std::unordered_map, std::tuple, hash_cache_key> cache; + int deviceId = 0; + TLLM_CUDA_CHECK(cudaGetDevice(&deviceId)); + int singleShmSize = std::max(computeTotalUnpackedSize(fields), computeProtoTransferSize(fields)); + auto key = std::make_tuple(deviceId, cpSize, singleShmSize); + auto it = cache.find(key); + if (it != cache.end()) + { + return it->second; + } + + int maxGroupCountPerCta = std::min(cpSize, MAX_GROUP_COUNT_PER_BLOCK); + int groupCountPerCta = maxGroupCountPerCta; // Start with max + int totalDynamicShmemSize = singleShmSize * groupCountPerCta; + int maxDynamicShmSize = 0; + TLLM_CUDA_CHECK(cudaDeviceGetAttribute(&maxDynamicShmSize, cudaDevAttrMaxSharedMemoryPerBlockOptin, deviceId)); + + while (totalDynamicShmemSize > maxDynamicShmSize) + { + groupCountPerCta--; + totalDynamicShmemSize = singleShmSize * groupCountPerCta; + } + + TLLM_CHECK_WITH_INFO(totalDynamicShmemSize <= maxDynamicShmSize, "Single packed size %d exceeds limit %d", + singleShmSize, maxDynamicShmSize); + + // Set shared memory attribute if needed + if (totalDynamicShmemSize > 48 * 1024) + { + TLLM_CUDA_CHECK(cudaFuncSetAttribute(helixAllToAllKernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, totalDynamicShmemSize)); + } + + int blockCountPerChannel = ceil_div(cpSize, groupCountPerCta); + blockCountPerChannel *= 2; // for send and recv + + int smCount = 0; + TLLM_CUDA_CHECK(cudaDeviceGetAttribute(&smCount, cudaDevAttrMultiProcessorCount, deviceId)); + // TODO: we might only want to use half the SMs to overlap with other kernels. + // note that overlap with FMHA is almost impossible because it must use + // all SMs and probably uses >50% shmem per SM. + // overlap with the subsequent BMM / out proj GEMMs might be possible, + // so we need experiments to see whether it makes sense. 
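+    // Worked example (illustrative values, not a measured configuration): with cpSize = 8 and
+    // groupCountPerCta = 8, blockCountPerChannel = ceil_div(8, 8) * 2 = 2; on a GPU with
+    // smCount = 132, the line below gives channelCount = max(132 / 2, 1) = 66.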
+ int channelCount = std::max(smCount / blockCountPerChannel, 1); + auto value = std::make_tuple(channelCount, groupCountPerCta, totalDynamicShmemSize); + cache[key] = value; + return value; +} + +// ============================================================================ +// Host Launch Function +// ============================================================================ + +template +void launchHelixAllToAllImpl(HelixAllToAllParams const& params, cudaStream_t stream) +{ + int maxChannelCount = computeHelixMaxChannelCount(params.cpSize); + TLLM_CHECK_WITH_INFO(params.maxChannelCount == maxChannelCount, + "maxChannelCount %d does not match computed maxChannelCount %d", params.maxChannelCount, maxChannelCount); + auto [channelCount, groupCountPerCta, totalDynamicShmemSize] + = computeChannelAndGroupCount(params.cpSize, params.sendFields); + if (params.channelCount > 0) + { + channelCount = params.channelCount; + TLLM_CHECK_WITH_INFO(channelCount <= maxChannelCount, "channelCount %d exceeds maxChannelCount %d", + channelCount, maxChannelCount); + } + + // Compute grid dimensions + // grid.x = blocks per channel (how many blocks needed to cover all peer + // ranks) grid.y = number of channels (parallel channels) grid.z = 2 (sender + // and receiver) + int ctaPerChannel = ceil_div(params.cpSize, groupCountPerCta); + + auto* kernel_instance = &helixAllToAllKernel; + cudaLaunchConfig_t config; + config.gridDim = dim3(ctaPerChannel, channelCount, 2); + config.blockDim = dim3(WARP_SIZE, groupCountPerCta); + config.dynamicSmemBytes = totalDynamicShmemSize; + config.stream = stream; + cudaLaunchAttribute attrs[1]; + attrs[0].id = cudaLaunchAttributeProgrammaticStreamSerialization; + attrs[0].val.programmaticStreamSerializationAllowed = common::getEnvEnablePDL(); + config.numAttrs = 1; + config.attrs = attrs; + TLLM_CUDA_CHECK(cudaLaunchKernelEx(&config, kernel_instance, params)); +} + +} // anonymous namespace + +// ============================================================================ +// Public API Functions +// ============================================================================ + +int computeHelixMaxChannelCount(int cpSize, int smCount) +{ + if (smCount == 0) + { + int deviceId = 0; + TLLM_CUDA_CHECK(cudaGetDevice(&deviceId)); + TLLM_CUDA_CHECK(cudaDeviceGetAttribute(&smCount, cudaDevAttrMultiProcessorCount, deviceId)); + } + + int blockCountPerChannel = ceil_div(cpSize, MAX_GROUP_COUNT_PER_BLOCK); + blockCountPerChannel *= 2; // for send and recv + + int preferredChannel = smCount / blockCountPerChannel; + return std::max(preferredChannel, 1); // at least one channel +} + +size_t computeHelixWorkspaceSizePerRank(int cpSize) +{ + static int maxChannelCount = 0; + if (maxChannelCount == 0) + { + maxChannelCount = computeHelixMaxChannelCount(cpSize); + } + + // FIFO buffers: cpSize * channelCount pairs + size_t fifoSize = static_cast(HELIX_FIFO_TOTAL_BYTES) * cpSize * maxChannelCount; + + // Sender and receiver FIFO info structures + size_t senderInfoSize = sizeof(HelixFifoInfo) * cpSize * maxChannelCount; + size_t receiverInfoSize = sizeof(HelixFifoInfo) * cpSize * maxChannelCount; + + return fifoSize + senderInfoSize + receiverInfoSize; +} + +void launchHelixAllToAll(HelixAllToAllParams const& params, bool allowVariableField1, cudaStream_t stream) +{ + if (allowVariableField1) + { + launchHelixAllToAllImpl(params, stream); + } + else + { + launchHelixAllToAllImpl(params, stream); + } +} + +// ============================================================================ +// 
Workspace Initialization +// ============================================================================ + +void initializeHelixWorkspace(uint64_t* local_workspace_ptr, int cpSize, cudaStream_t stream) +{ + int maxChannelCount = computeHelixMaxChannelCount(cpSize); + // Calculate sizes with channel dimension + size_t fifoSize = static_cast(HELIX_FIFO_TOTAL_BYTES) * cpSize * maxChannelCount; + size_t senderInfoSize = sizeof(HelixFifoInfo) * cpSize * maxChannelCount; + size_t receiverInfoSize = sizeof(HelixFifoInfo) * cpSize * maxChannelCount; + + // Initialize FIFO buffers to 0xFFFFFFFF (-1 for signed integer types) + TLLM_CUDA_CHECK(cudaMemsetAsync(local_workspace_ptr, 0xFF, fifoSize, stream)); + + // Initialize sender and receiver info to zero (single call for both) + uint8_t* infoPtr = reinterpret_cast(local_workspace_ptr) + fifoSize; + TLLM_CUDA_CHECK(cudaMemsetAsync(infoPtr, 0, senderInfoSize + receiverInfoSize, stream)); +} + +} // namespace kernels + +TRTLLM_NAMESPACE_END diff --git a/cpp/tensorrt_llm/kernels/helixAllToAll.h b/cpp/tensorrt_llm/kernels/helixAllToAll.h new file mode 100644 index 0000000000..c35634133f --- /dev/null +++ b/cpp/tensorrt_llm/kernels/helixAllToAll.h @@ -0,0 +1,94 @@ +/* + * Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include "tensorrt_llm/common/config.h" + +#include + +#include +#include + +TRTLLM_NAMESPACE_BEGIN + +namespace kernels +{ + +struct HelixFieldInfo +{ + uint8_t* dataPtr; + int elementCount; // Number of elements (e.g., kv_lora_rank for field 0, 1 for + // field 1) + int elementSize; // Size of each element in bytes (2 for half, 8 for float2) + int stride; // Stride between rows in bytes +}; + +struct HelixAllToAllParams +{ + HelixFieldInfo sendFields[2]; + HelixFieldInfo recvFields[2]; + int entryCount; // Number of entries per peer rank to process + uint64_t* workspace; + int workspaceStrideInU64; + int cpRank; + int cpSize; + int channelCount; // use 0 to auto-compute + int maxChannelCount; +}; + +// ============================================================================ +// Workspace Management Functions +// ============================================================================ + +/** + * Compute number of channels for communication based on cpSize. + * + * @param cpSize Number of context parallel ranks + * @param smCount Number of SMs available (0 = auto-detect) + * @return Number of channels to use + */ +int computeHelixMaxChannelCount(int cpSize, int smCount = 0); + +/** + * Compute the workspace size required per rank for the all-to-all operation. + * + * @param cpSize Number of context parallel ranks + * @return Size in bytes + */ +size_t computeHelixWorkspaceSizePerRank(int cpSize); + +/** + * Initialize workspace memory for a given rank. + * Should be called once during setup. 
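+ * The per-rank workspace is laid out as [FIFO buffers][sender FIFO info][receiver FIFO info];
+ * this call fills the FIFO buffers with 0xFF and zeroes both info arrays.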
+ * + * @param workspace Pointer to workspace memory (per-rank view) + * @param cpSize Number of context parallel ranks + * @param stream CUDA stream for asynchronous operations + */ +void initializeHelixWorkspace(uint64_t* workspace, int cpSize, cudaStream_t stream); + +/** + * Launch the helix all-to-all kernel. + * + * @param params Kernel parameters including field info and workspace + * @param allowVariableField1 Whether to allow variable field 1 + * @param stream CUDA stream for kernel launch + */ +void launchHelixAllToAll(HelixAllToAllParams const& params, bool allowVariableField1, cudaStream_t stream); + +} // namespace kernels + +TRTLLM_NAMESPACE_END diff --git a/cpp/tensorrt_llm/nanobind/thop/bindings.cpp b/cpp/tensorrt_llm/nanobind/thop/bindings.cpp index 60e0d00939..6577d1cf18 100644 --- a/cpp/tensorrt_llm/nanobind/thop/bindings.cpp +++ b/cpp/tensorrt_llm/nanobind/thop/bindings.cpp @@ -18,6 +18,7 @@ #include #include #include +#include #include #include #include @@ -73,5 +74,10 @@ void initBindings(nb::module_& m) nb::arg("mla_bmm1_scale") = std::nullopt, nb::arg("mla_bmm2_scale") = std::nullopt, nb::arg("quant_q_buffer") = std::nullopt, "Multi-head attention operation", nb::call_guard()); + + m.def( + "get_helix_workspace_size_per_rank", + [](int cp_size) { return tensorrt_llm::kernels::computeHelixWorkspaceSizePerRank(cp_size); }, + nb::arg("cp_size"), "Get helix all-to-all workspace size per rank in bytes"); } } // namespace tensorrt_llm::nanobind::thop diff --git a/cpp/tensorrt_llm/pybind/thop/bindings.cpp b/cpp/tensorrt_llm/pybind/thop/bindings.cpp index 7c017a10a6..8cdc8a9982 100644 --- a/cpp/tensorrt_llm/pybind/thop/bindings.cpp +++ b/cpp/tensorrt_llm/pybind/thop/bindings.cpp @@ -18,6 +18,7 @@ #include #include #include +#include #include #include #include @@ -73,5 +74,10 @@ void initBindings(pybind11::module_& m) py::arg("mla_bmm1_scale") = std::nullopt, py::arg("mla_bmm2_scale") = std::nullopt, py::arg("quant_q_buffer") = std::nullopt, "Multi-head attention operation", py::call_guard()); + + m.def( + "get_helix_workspace_size_per_rank", + [](int cp_size) { return tensorrt_llm::kernels::computeHelixWorkspaceSizePerRank(cp_size); }, + py::arg("cp_size"), "Get helix all-to-all workspace size per rank in bytes"); } } // namespace tensorrt_llm::pybind::thop diff --git a/cpp/tensorrt_llm/thop/alltoallOp.cpp b/cpp/tensorrt_llm/thop/alltoallOp.cpp index 61c09466db..8a775b1cf6 100644 --- a/cpp/tensorrt_llm/thop/alltoallOp.cpp +++ b/cpp/tensorrt_llm/thop/alltoallOp.cpp @@ -16,19 +16,12 @@ */ #include "tensorrt_llm/common/opUtils.h" +#include "tensorrt_llm/kernels/helixAllToAll.h" #include "tensorrt_llm/runtime/torchUtils.h" #include "tensorrt_llm/runtime/utils/mpiUtils.h" +#include "tensorrt_llm/thop/thUtils.h" -#include -#include -#include -#include -#include -#include #include -#if ENABLE_MULTI_DEVICE -#include -#endif // ENABLE_MULTI_DEVICE TRTLLM_NAMESPACE_BEGIN @@ -119,6 +112,145 @@ std::vector alltoall_helix( #endif // ENABLE_MULTI_DEVICE } +/** + * Helix All-to-All operation with two fields. + * + * Input tensors have shape [..., cp_size, kv_lora_rank] for partial_o and [..., + * cp_size, 2] for softmax_stats. The operation exchanges data along the cp_size + * dimension across all ranks. 
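+ * The workspace must be the 2D uint64 tensor (one row per rank, strided across ranks) that was
+ * initialized once via initialize_helix_workspace() before the first call.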
+ * + * @param partial_o Field 0 tensor (half precision, shape [..., cp_size, + * kv_lora_rank]) + * @param softmax_stats Field 1 tensor (float32, shape [..., cp_size, 2]) + * @param workspace Workspace tensor (uint64, strided across ranks) + * @param cp_rank Current context parallel rank + * @param cp_size Total number of context parallel ranks + * @return tuple of (partial_o_out, softmax_stats_out) with same shapes as inputs + */ +std::tuple alltoall_helix_native( + torch::Tensor partial_o, torch::Tensor softmax_stats, torch::Tensor workspace, int64_t cp_rank, int64_t cp_size) +{ + + // Input validation + CHECK_TH_CUDA(partial_o); + CHECK_TH_CUDA(softmax_stats); + CHECK_TH_CUDA(workspace); + CHECK_CONTIGUOUS(partial_o); + CHECK_CONTIGUOUS(softmax_stats); + + // Type checks + TORCH_CHECK(partial_o.scalar_type() == at::ScalarType::Half || partial_o.scalar_type() == at::ScalarType::BFloat16, + "partial_o must be half or bfloat16"); + CHECK_TYPE(softmax_stats, at::ScalarType::Float); + CHECK_TYPE(workspace, at::ScalarType::UInt64); + + // Shape validation + TORCH_CHECK(partial_o.dim() >= 2, "partial_o must have at least 2 dimensions"); + TORCH_CHECK(softmax_stats.dim() >= 2, "softmax_stats must have at least 2 dimensions"); + TORCH_CHECK( + partial_o.dim() == softmax_stats.dim(), "partial_o and softmax_stats must have same number of dimensions"); + + // Get dimensions + int kv_lora_rank = partial_o.size(-1); + TORCH_CHECK(partial_o.size(-2) == cp_size && softmax_stats.size(-2) == cp_size, + "partial_o/softmax_stats second-to-last dimension must equal cp_size"); + TORCH_CHECK(softmax_stats.size(-1) % 2 == 0 && softmax_stats.size(-1) >= 2, + "softmax_stats last dimension must be divisible by 2 (float2)"); + bool allowVariableField1 = softmax_stats.size(-1) > 2; + + // Check that leading dimensions match + for (int i = 0; i < partial_o.dim() - 2; i++) + { + TORCH_CHECK(partial_o.size(i) == softmax_stats.size(i), + "partial_o and softmax_stats must have matching dimensions except last two"); + } + TORCH_CHECK(partial_o.size(-1) * partial_o.element_size() % 16 == 0, "partial_o must be aligned to 16 bytes"); + + TORCH_CHECK(workspace.dim() == 2, "workspace must be 2D (strided across ranks)"); + TORCH_CHECK(workspace.size(0) == cp_size, "workspace must have cp_size rows"); + + // Calculate entry count (product of all dimensions before cp_size) + // This is the number of entries to process per peer rank + int entry_count = 1; + for (int i = 0; i < partial_o.dim() - 2; i++) + { + entry_count *= partial_o.size(i); + } + + // Reshape to 3D: [entry_count, cp_size, feature_dim] + torch::Tensor partial_o_3d = partial_o.reshape({entry_count, cp_size, kv_lora_rank}); + torch::Tensor softmax_stats_3d = softmax_stats.reshape({entry_count, cp_size, softmax_stats.size(-1)}); + + // Allocate output tensors (same shape as input) + torch::Tensor partial_o_out = torch::empty_like(partial_o); + torch::Tensor softmax_stats_out = torch::empty_like(softmax_stats); + + torch::Tensor partial_o_out_3d = partial_o_out.reshape({entry_count, cp_size, kv_lora_rank}); + torch::Tensor softmax_stats_out_3d = softmax_stats_out.reshape({entry_count, cp_size, softmax_stats.size(-1)}); + + // Setup parameters + tensorrt_llm::kernels::HelixAllToAllParams params; + + // Field 0 (variable size half) + params.sendFields[0].dataPtr = reinterpret_cast(partial_o_3d.data_ptr()); + params.sendFields[0].elementCount = kv_lora_rank; + params.sendFields[0].elementSize = partial_o.element_size(); + params.sendFields[0].stride = 
partial_o_3d.stride(1) * partial_o.element_size(); + + params.recvFields[0].dataPtr = reinterpret_cast(partial_o_out_3d.data_ptr()); + params.recvFields[0].elementCount = kv_lora_rank; + params.recvFields[0].elementSize = partial_o.element_size(); + params.recvFields[0].stride = partial_o_out_3d.stride(1) * partial_o.element_size(); + + // Field 1 (single float2) + params.sendFields[1].dataPtr = reinterpret_cast(softmax_stats_3d.data_ptr()); + params.sendFields[1].elementCount = softmax_stats.size(-1); + params.sendFields[1].elementSize = softmax_stats.element_size(); + params.sendFields[1].stride = softmax_stats_3d.stride(1) * softmax_stats.element_size(); + + params.recvFields[1].dataPtr = reinterpret_cast(softmax_stats_out_3d.data_ptr()); + params.recvFields[1].elementCount = softmax_stats.size(-1); + params.recvFields[1].elementSize = softmax_stats.element_size(); + params.recvFields[1].stride = softmax_stats_out_3d.stride(1) * softmax_stats.element_size(); + + // Entry count and workspace + params.entryCount = entry_count; + params.workspace = workspace.data_ptr(); + params.workspaceStrideInU64 = workspace.stride(0); + + // CP info + params.cpRank = cp_rank; + params.cpSize = cp_size; + params.channelCount = 0; // auto-compute + params.maxChannelCount = tensorrt_llm::kernels::computeHelixMaxChannelCount(cp_size); + + // Launch kernel + auto stream = at::cuda::getCurrentCUDAStream(); + tensorrt_llm::kernels::launchHelixAllToAll(params, allowVariableField1, stream); + + return std::make_tuple(partial_o_out, softmax_stats_out); +} + +/** + * Initialize workspace for helix all-to-all + */ +void initialize_helix_workspace(torch::Tensor workspace, int64_t cp_rank, int64_t cp_size) +{ + CHECK_TH_CUDA(workspace); + CHECK_TYPE(workspace, at::ScalarType::UInt64); + TORCH_CHECK(workspace.dim() == 2, "workspace must be 2D"); + TORCH_CHECK(workspace.size(0) == cp_size, "workspace must have cp_size rows"); + TORCH_CHECK(cp_rank >= 0 && cp_rank < cp_size, "cp_rank must be in [0, cp_size)"); + + auto stream = at::cuda::getCurrentCUDAStream(); + uint64_t* global_workspace_ptr = workspace.data_ptr(); + uint64_t* local_workspace_ptr = workspace[cp_rank].data_ptr(); + TORCH_CHECK(local_workspace_ptr == global_workspace_ptr + cp_rank * workspace.stride(0), + "local_workspace_ptr must be at the correct offset in the global " + "workspace"); + tensorrt_llm::kernels::initializeHelixWorkspace(local_workspace_ptr, cp_size, stream); +} + } // namespace torch_ext TRTLLM_NAMESPACE_END @@ -126,9 +258,17 @@ TRTLLM_NAMESPACE_END TORCH_LIBRARY_FRAGMENT(trtllm, m) { m.def("alltoall_helix(Tensor[] input_list, int[] group, int? num_lists) -> Tensor[]"); + m.def( + "alltoall_helix_native(Tensor partial_o, Tensor softmax_stats, Tensor(a!) workspace, int " + "cp_rank, int cp_size) -> (Tensor, Tensor)"); + m.def( + "initialize_helix_workspace(Tensor(a!) 
workspace, int cp_rank, int cp_size) " + "-> ()"); } TORCH_LIBRARY_IMPL(trtllm, CUDA, m) { m.impl("alltoall_helix", &tensorrt_llm::torch_ext::alltoall_helix); + m.impl("alltoall_helix_native", &tensorrt_llm::torch_ext::alltoall_helix_native); + m.impl("initialize_helix_workspace", &tensorrt_llm::torch_ext::initialize_helix_workspace); } diff --git a/tensorrt_llm/_mnnvl_utils.py b/tensorrt_llm/_mnnvl_utils.py index 6d981fc1bd..5d168447f9 100644 --- a/tensorrt_llm/_mnnvl_utils.py +++ b/tensorrt_llm/_mnnvl_utils.py @@ -51,38 +51,54 @@ def _check_cu_result(cu_func_ret): class MnnvlMemory: + """MNNVL memory management for tensor parallel (TP) operations.""" + + # Shared across all subclasses (global/device state). initialized: bool = False - - current_mem_offset: int = 0 - current_rank_stride: int = 0 # stride for ranks and also address space size. - current_start_address: int = 0 - - # allocation granularity allocation_granularity: int = 0 - - # fabric address page size (512 MB) - fabric_page_size: int = 1 << 29 - - # MPI communicator - comm = None - + fabric_page_size: int = 1 << 29 # 512 MB. dev_id: int = None + # Per-class state attributes. These will be auto-initialized for each subclass + # to avoid polluting the parent class's state. Use callable (e.g., dict) for mutable defaults. + _per_class_attrs = { + "current_mem_offset": 0, + "current_rank_stride": 0, # stride for ranks and also address space size. + "current_start_address": 0, + "comm": None, # MPI communicator. + "allocated_map": dict, # callable for fresh dict. + "address_refcnt": dict, # callable for fresh dict. + } + + # Initialize per-class state for the base class. + current_mem_offset: int = 0 + current_rank_stride: int = 0 + current_start_address: int = 0 + comm = None allocated_map = {} address_refcnt = {} + def __init_subclass__(cls, **kwargs): + """Auto-initialize per-class attributes for each subclass to avoid sharing state with parent.""" + super().__init_subclass__(**kwargs) + for attr, default in cls._per_class_attrs.items(): + if callable(default): + setattr(cls, attr, default()) # e.g., dict() creates a fresh dict. 
+ else: + setattr(cls, attr, default) + def __init__(self, mapping: Mapping, size: int): self.mapping = mapping self.segment_size = size - self.ptr, self.rank_stride = MnnvlMemory.open_mnnvl_memory(self.mapping, size) + self.ptr, self.rank_stride = type(self).open_mnnvl_memory(self.mapping, size) def __del__(self): if not sys.is_finalizing(): if hasattr(self, "ptr"): - MnnvlMemory.close_mnnvl_memory(self.ptr) + type(self).close_mnnvl_memory(self.ptr) def as_torch_strided_tensor(self, dtype): - num_segments = MnnvlMemory.comm.Get_size() + num_segments = type(self).comm.Get_size() return pack_strided_memory( self.ptr, self.segment_size, self.rank_stride, num_segments, dtype, MnnvlMemory.dev_id ) @@ -99,16 +115,17 @@ class MnnvlMemory: pynvml.nvmlInit() MnnvlMemory.initialized = True - @staticmethod - def get_comm(mapping: Mapping): - if MnnvlMemory.comm is not None: - return MnnvlMemory.comm + @classmethod + def get_comm(cls, mapping: Mapping): + """Get TP-based communicator (ranks grouped by PP+CP+MOE_TP, ordered by TP rank).""" + if cls.comm is not None: + return cls.comm comm = mpi_comm().Split( (mapping.pp_rank * mapping.cp_size + mapping.cp_rank) * mapping.moe_tp_size + mapping.moe_tp_rank, mapping.tp_rank, ) - MnnvlMemory.comm = comm + cls.comm = comm return comm @staticmethod @@ -148,23 +165,26 @@ class MnnvlMemory: MnnvlMemory.allocation_granularity = granularity return MnnvlMemory.allocation_granularity - @staticmethod - def new_mnnvl_memory_address(mapping: Mapping, size: int): + @classmethod + def new_mnnvl_memory_address(cls, mapping: Mapping, size: int): page_count = (size + MnnvlMemory.fabric_page_size - 1) // MnnvlMemory.fabric_page_size current_rank_stride = page_count * MnnvlMemory.fabric_page_size - logger.info(f"[MnnvlMemory] creating address with stride={current_rank_stride}") - comm = MnnvlMemory.get_comm(mapping) + logger.info(f"[{cls.__name__}] creating address with stride={current_rank_stride}") + comm = cls.get_comm(mapping) comm_size = comm.Get_size() address_size = current_rank_stride * comm_size ptr = _check_cu_result( cuda.cuMemAddressReserve(address_size, MnnvlMemory.fabric_page_size, 0, 0) ) - MnnvlMemory.current_start_address = int(ptr) - MnnvlMemory.current_rank_stride = current_rank_stride - MnnvlMemory.current_mem_offset = 0 + cls.current_start_address = int(ptr) + cls.current_rank_stride = current_rank_stride + cls.current_mem_offset = 0 + + @classmethod + def open_mnnvl_memory(cls, mapping: Mapping, size: int): + # Ensure MnnvlMemory is initialized (for dev_id and allocation_granularity) + MnnvlMemory.initialize() - @staticmethod - def open_mnnvl_memory(mapping: Mapping, size: int): dev = _check_cu_result(cuda.cuCtxGetDevice()) dev_id = int(dev) if MnnvlMemory.dev_id is None: @@ -172,7 +192,7 @@ class MnnvlMemory: assert dev_id == MnnvlMemory.dev_id, ( f"Different dev_id found dev_id={dev_id} but MnnvlMemory.dev_id={MnnvlMemory.dev_id}" ) - comm = MnnvlMemory.get_comm(mapping) + comm = cls.get_comm(mapping) comm_rank = comm.Get_rank() comm_size = comm.Get_size() all_rank_allocate_sizes = comm.allgather(size) @@ -181,10 +201,10 @@ class MnnvlMemory: granularity = MnnvlMemory.get_allocation_granularity(dev_id) aligned_size = (size + granularity - 1) // granularity * granularity - if MnnvlMemory.current_mem_offset + aligned_size > MnnvlMemory.current_rank_stride: - MnnvlMemory.new_mnnvl_memory_address(mapping, aligned_size) + if cls.current_mem_offset + aligned_size > cls.current_rank_stride: + cls.new_mnnvl_memory_address(mapping, aligned_size) - assert 
MnnvlMemory.current_mem_offset + aligned_size <= MnnvlMemory.current_rank_stride + assert cls.current_mem_offset + aligned_size <= cls.current_rank_stride allocation_prop = MnnvlMemory.get_allocation_prop(dev_id) allocated_mem_handle = _check_cu_result( @@ -245,9 +265,7 @@ class MnnvlMemory: for i, remote_handle_data in enumerate(all_handles_data): rank_ptr = ( - MnnvlMemory.current_start_address - + MnnvlMemory.current_rank_stride * i - + MnnvlMemory.current_mem_offset + cls.current_start_address + cls.current_rank_stride * i + cls.current_mem_offset ) if i == comm_rank: # Local memory mapping @@ -265,44 +283,44 @@ class MnnvlMemory: _check_cu_result(cuda.cuMemSetAccess(rank_ptr, aligned_size, [madesc], 1)) - ptr = MnnvlMemory.current_start_address + MnnvlMemory.current_mem_offset - stride = MnnvlMemory.current_rank_stride - MnnvlMemory.allocated_map[ptr] = ( + ptr = cls.current_start_address + cls.current_mem_offset + stride = cls.current_rank_stride + cls.allocated_map[ptr] = ( mapping, aligned_size, mem_handles, - MnnvlMemory.current_start_address, - MnnvlMemory.current_rank_stride, - MnnvlMemory.current_mem_offset, + cls.current_start_address, + cls.current_rank_stride, + cls.current_mem_offset, ) - MnnvlMemory.address_refcnt[MnnvlMemory.current_start_address] = ( - MnnvlMemory.address_refcnt.get(MnnvlMemory.current_start_address, 0) + 1 + cls.address_refcnt[cls.current_start_address] = ( + cls.address_refcnt.get(cls.current_start_address, 0) + 1 ) - MnnvlMemory.current_mem_offset += aligned_size + cls.current_mem_offset += aligned_size return ptr, stride - @staticmethod - def close_mnnvl_memory(ptr: int): + @classmethod + def close_mnnvl_memory(cls, ptr: int): mapping, aligned_size, mem_handles, start_address, rank_stride, address_offset = ( - MnnvlMemory.allocated_map.pop(ptr) + cls.allocated_map.pop(ptr) ) - comm = MnnvlMemory.get_comm(mapping) + comm = cls.get_comm(mapping) comm_size = comm.Get_size() for i in range(comm_size): rank_ptr = start_address + i * rank_stride + address_offset _check_cu_result(cuda.cuMemUnmap(rank_ptr, aligned_size)) _check_cu_result(cuda.cuMemRelease(mem_handles[i])) - MnnvlMemory.address_refcnt[start_address] -= 1 + cls.address_refcnt[start_address] -= 1 - if MnnvlMemory.address_refcnt[start_address] == 0: - MnnvlMemory.address_refcnt.pop(start_address) + if cls.address_refcnt[start_address] == 0: + cls.address_refcnt.pop(start_address) device_ptr = cuda.CUdeviceptr(start_address) _check_cu_result(cuda.cuMemAddressFree(device_ptr, comm_size * rank_stride)) - if start_address == MnnvlMemory.current_start_address: - MnnvlMemory.current_start_address = 0 - MnnvlMemory.current_rank_stride = 0 - MnnvlMemory.current_mem_offset = 0 + if start_address == cls.current_start_address: + cls.current_start_address = 0 + cls.current_rank_stride = 0 + cls.current_mem_offset = 0 @staticmethod def support_nvlink(need_all_up: bool = True): @@ -338,6 +356,29 @@ class MnnvlMemory: return support_nvlink_and_all_up +class HelixCpMnnvlMemory(MnnvlMemory): + """MNNVL memory management for Helix context parallel (CP) operations. + + Per-class state (current_mem_offset, comm, allocated_map, etc.) is automatically + initialized via __init_subclass__ in the parent class, ensuring this class has + its own isolated state separate from MnnvlMemory. 
+ """ + + @classmethod + def get_comm(cls, mapping: Mapping): + """Get CP-based communicator (ranks grouped by PP+TP+MOE_TP, ordered by CP rank).""" + if cls.comm is not None: + return cls.comm + comm = mpi_comm().Split( + mapping.pp_rank * mapping.tp_size * mapping.moe_tp_size + + mapping.tp_rank * mapping.moe_tp_size + + mapping.moe_tp_rank, + mapping.cp_rank, + ) + cls.comm = comm + return comm + + @dataclass class MoEAlltoallInfo: local_gather_indices: torch.Tensor diff --git a/tensorrt_llm/_torch/custom_ops/cpp_custom_ops.py b/tensorrt_llm/_torch/custom_ops/cpp_custom_ops.py index 68b114a8d7..6fd705f633 100644 --- a/tensorrt_llm/_torch/custom_ops/cpp_custom_ops.py +++ b/tensorrt_llm/_torch/custom_ops/cpp_custom_ops.py @@ -752,6 +752,17 @@ def _register_fake(): for i in range(0, len(input_list), num_ranks) ] + @torch.library.register_fake("trtllm::alltoall_helix_native") + def _(partial_o, softmax_stats, workspace, cp_rank, cp_size): + # Returns outputs with same shapes as inputs + return partial_o.new_empty(partial_o.shape), softmax_stats.new_empty( + softmax_stats.shape) + + @torch.library.register_fake("trtllm::initialize_helix_workspace") + def _(workspace, cp_rank, cp_size): + # This op initializes workspace in-place and returns nothing + return None + @torch.library.register_fake("trtllm::helix_post_process") def _(gathered_o, gathered_stats, scale): return gathered_o.new_empty(*gathered_o.shape[1:]) diff --git a/tensorrt_llm/_torch/distributed/__init__.py b/tensorrt_llm/_torch/distributed/__init__.py index 6049ebeb6e..b8bfe4ffdf 100644 --- a/tensorrt_llm/_torch/distributed/__init__.py +++ b/tensorrt_llm/_torch/distributed/__init__.py @@ -2,9 +2,10 @@ from tensorrt_llm.functional import AllReduceFusionOp from .communicator import Distributed, MPIDist, TorchDist from .moe_alltoall import MoeAlltoAll -from .ops import (AllReduce, AllReduceParams, AllReduceStrategy, MoEAllReduce, - MoEAllReduceParams, allgather, alltoall_helix, cp_allgather, - reducescatter, userbuffers_allreduce_finalize) +from .ops import (AllReduce, AllReduceParams, AllReduceStrategy, + HelixAllToAllNative, MoEAllReduce, MoEAllReduceParams, + allgather, alltoall_helix, cp_allgather, reducescatter, + userbuffers_allreduce_finalize) __all__ = [ "allgather", @@ -16,6 +17,7 @@ __all__ = [ "AllReduceParams", "AllReduceFusionOp", "AllReduceStrategy", + "HelixAllToAllNative", "MoEAllReduce", "MoEAllReduceParams", "MoeAlltoAll", diff --git a/tensorrt_llm/_torch/distributed/ops.py b/tensorrt_llm/_torch/distributed/ops.py index fa8e61f322..51de18c7b1 100644 --- a/tensorrt_llm/_torch/distributed/ops.py +++ b/tensorrt_llm/_torch/distributed/ops.py @@ -2,14 +2,16 @@ import math import os import platform import threading -from typing import List, Optional, Tuple, Union +from typing import Dict, List, Optional, Tuple, Union import torch from torch import nn +from tensorrt_llm._mnnvl_utils import HelixCpMnnvlMemory, MnnvlMemory from tensorrt_llm._torch.distributed.symm_mem_allreduce import \ SymmetricMemoryAllReduce from tensorrt_llm._utils import mpi_comm, mpi_disabled +from tensorrt_llm.bindings import internal as _tllm_internal from tensorrt_llm.bindings.internal.runtime import McastGPUBuffer from tensorrt_llm.functional import (AllReduceFusionOp, AllReduceParams, AllReduceStrategy, MoEAllReduceParams) @@ -363,6 +365,84 @@ def alltoall_helix( return torch.ops.trtllm.alltoall_helix(inputs, group, num_lists) +class HelixAllToAllNative: + """ + Manager for Helix All-to-All operations with MNNVL workspace management. 
+ + Exchanges data along the cp_size dimension: + - partial_o: [..., cp_size, kv_lora_rank] half-precision + - softmax_stats: [..., cp_size, 2] float32 + """ + + # Global cache: mapping -> instance + _cache: Dict[Mapping, "HelixAllToAllNative"] = {} + + def __init__(self, mapping: Mapping, workspace: HelixCpMnnvlMemory, + workspace_tensor: torch.Tensor): + """Private constructor - use get() instead.""" + self.mapping = mapping + self.workspace = workspace + self.workspace_tensor = workspace_tensor + + @staticmethod + def get(mapping: Mapping) -> "HelixAllToAllNative": + """ + Get or create a HelixAllToAllNative instance for the given configuration. + + Args: + mapping: TensorRT-LLM mapping object containing cp_size and cp_rank + + Returns: + Cached or newly-created HelixAllToAllNative instance + """ + if mapping not in HelixAllToAllNative._cache: + logger.info( + f"Rank {mapping.cp_rank} initializing HelixCpMnnvlMemory for Helix" + ) + MnnvlMemory.initialize() + + # Get workspace size (in bytes) + workspace_size_per_rank = _tllm_internal.thop.get_helix_workspace_size_per_rank( + mapping.cp_size) + + # Allocate MNNVL memory using CP communicator for Helix + workspace = HelixCpMnnvlMemory(mapping, workspace_size_per_rank) + workspace_tensor = workspace.as_torch_strided_tensor(torch.uint64) + + torch.ops.trtllm.initialize_helix_workspace(workspace_tensor, + mapping.cp_rank, + mapping.cp_size) + torch.cuda.synchronize() + HelixCpMnnvlMemory.get_comm(mapping).barrier() + + HelixAllToAllNative._cache[mapping] = HelixAllToAllNative( + mapping, workspace, workspace_tensor) + + return HelixAllToAllNative._cache[mapping] + + def alltoall_native(self, partial_o: torch.Tensor, + softmax_stats: torch.Tensor): + """ + Perform all-to-all data exchange. + + Args: + partial_o: Tensor with shape [..., cp_size, kv_lora_rank], dtype half. + softmax_stats: Tensor with shape [..., cp_size, 2], dtype float32. + + Returns: + Tuple of (partial_o_out, softmax_stats_out) with same shapes as inputs. 
+ """ + partial_o_out, softmax_stats_out = torch.ops.trtllm.alltoall_helix_native( + partial_o, + softmax_stats, + self.workspace_tensor, + self.mapping.cp_rank, + self.mapping.cp_size, + ) + + return partial_o_out, softmax_stats_out + + def reducescatter( input: Union[torch.Tensor, List[torch.Tensor]], mapping: Mapping, diff --git a/tensorrt_llm/_torch/modules/attention.py b/tensorrt_llm/_torch/modules/attention.py index 8543fcf4a1..10cce12a5d 100644 --- a/tensorrt_llm/_torch/modules/attention.py +++ b/tensorrt_llm/_torch/modules/attention.py @@ -20,7 +20,7 @@ from ..attention_backend.interface import (AttentionBackend, AttentionMask, from ..attention_backend.sparse.dsa import ( DSAtrtllmAttentionMetadata, transform_local_topk_and_prepare_pool_view) from ..attention_backend.utils import create_attention, get_attention_backend -from ..distributed import AllReduceParams, alltoall_helix +from ..distributed import AllReduceParams, HelixAllToAllNative, alltoall_helix from ..model_config import ModelConfig from ..peft.lora.layer import LoraLayer, LoraModuleType from ..utils import (Fp4QuantizedTensor, get_model_extra_attrs, @@ -1120,24 +1120,60 @@ class MLA(nn.Module): softmax_stats_tensor=softmax_stats, **kwargs, ) - # this is the post-processing of helix parallel attention, - # similar to the post-processing of ring attention kv_lora_rank = partial_o.shape[-1] // self.num_heads_tp assert self.kv_lora_rank == kv_lora_rank - # transpose the tensors to make the split across cp_size contiguous - # for both tensors, we need to split across the second dimension - chunks = [] - for t in [partial_o, softmax_stats]: - t = t.transpose(1, 0).contiguous() - chunks.extend(torch.split(t, - t.shape[0] // self.mapping.cp_size)) - gathered = alltoall_helix(chunks, self.mapping.cp_group) - # transpose the tensors back to ensure dimensions are ordered correctly - # note: an additional dimension was added at the first index for all-to-all, - # so the transpose dimensions are shifted by 1 - gathered = [t.transpose(1, 2).contiguous() for t in gathered] - return torch.ops.trtllm.helix_post_process(gathered[0], gathered[1], - 1.0) + + # Switch between NCCL-based and FIFO-based (MNNVL) all-to-all based on cp_config. + if self.mapping.cp_config.get("use_nccl_for_alltoall", True): + # NCCL-based implementation using alltoall_helix. + # This is the post-processing of helix parallel attention, + # similar to the post-processing of ring attention. + # Transpose the tensors to make the split across cp_size contiguous + # For both tensors, we need to split across the second dimension. + chunks = [] + for t in [partial_o, softmax_stats]: + t = t.transpose(1, 0).contiguous() + chunks.extend( + torch.split(t, t.shape[0] // self.mapping.cp_size)) + gathered = alltoall_helix(chunks, self.mapping.cp_group) + # Transpose the tensors back to ensure dimensions are ordered correctly. + # Note: an additional dimension was added at the first index for all-to-all, + # so the transpose dimensions are shifted by 1. + gathered = [t.transpose(1, 2).contiguous() for t in gathered] + return torch.ops.trtllm.helix_post_process( + gathered[0], gathered[1], 1.0) + else: + # FIFO-based implementation using MNNVL workspace and LL128 Proto. + # Get or create Helix All-to-All instance. + helix = HelixAllToAllNative.get(self.mapping) + + # Get dimensions. + num_tokens = partial_o.shape[0] + cp_size = self.mapping.cp_size + + # Reshape for FIFO-based all-to-all. 
+ # partial_o: [num_tokens, num_heads * kv_lora_rank] -> [num_tokens, cp_size, num_heads_tp_cp, kv_lora_rank] + # softmax_stats: [num_tokens, num_heads, 2] -> [num_tokens, cp_size, num_heads_tp_cp, 2] + + partial_o = partial_o.view( + num_tokens, cp_size, self.num_heads_tp_cp, + kv_lora_rank).transpose(1, 2).contiguous() + softmax_stats = softmax_stats.view(num_tokens, cp_size, + self.num_heads_tp_cp, + 2).transpose(1, + 2).contiguous() + + # Call FIFO-based helixAllToAll. + partial_o_out, softmax_stats_out = helix.alltoall_native( + partial_o, softmax_stats) + + # partial_o_out: [num_tokens, num_heads_tp_cp, cp_size, kv_lora_rank] + # softmax_stats_out: [num_tokens, num_heads_tp_cp, cp_size, 2] + # cp_dim = 2 (the dimension where cp_size is located) + + # Call helix_post_process_native with cp_dim=2. + return torch.ops.trtllm.helix_post_process_native( + partial_o_out, softmax_stats_out, 1.0, 2) else: attn_output = attn_backend.forward(q, k, v, attn_metadata, **kwargs) return attn_output diff --git a/tests/integration/defs/accuracy/test_disaggregated_serving.py b/tests/integration/defs/accuracy/test_disaggregated_serving.py index 809482e719..2d6b02897d 100644 --- a/tests/integration/defs/accuracy/test_disaggregated_serving.py +++ b/tests/integration/defs/accuracy/test_disaggregated_serving.py @@ -887,7 +887,9 @@ class TestDeepSeekV3Lite(LlmapiAccuracyTestHarness): "cudagraph:none", "cudagraph:without_padding", "cudagraph:with_padding" ]) - def test_auto_dtype_with_helix(self, cuda_graph_config): + @pytest.mark.parametrize("comms_medium", ["fifo", "nccl"]) + def test_auto_dtype_with_helix(self, comms_medium, cuda_graph_config): + use_nccl_for_alltoall = comms_medium == "nccl" kv_cache_config = { "free_gpu_memory_fraction": 0.5, "enable_block_reuse": False, @@ -912,7 +914,8 @@ class TestDeepSeekV3Lite(LlmapiAccuracyTestHarness): "context_parallel_size": 2, "cp_config": { "cp_type": "HELIX", - "tokens_per_block": 32 + "tokens_per_block": 32, + "use_nccl_for_alltoall": use_nccl_for_alltoall }, "disable_overlap_scheduler": True, "kv_cache_config": kv_cache_config, diff --git a/tests/integration/test_lists/qa/llm_function_core.txt b/tests/integration/test_lists/qa/llm_function_core.txt index e071eed889..1f6f254aee 100644 --- a/tests/integration/test_lists/qa/llm_function_core.txt +++ b/tests/integration/test_lists/qa/llm_function_core.txt @@ -535,9 +535,12 @@ accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_guided_decoding accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_guided_decoding[xgrammar-mtp_nextn=2] accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_guided_decoding[llguidance-mtp_nextn=0] accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_guided_decoding[llguidance-mtp_nextn=2] -accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[cudagraph:none] -accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[cudagraph:without_padding] -accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[cudagraph:with_padding] +accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[nccl-cudagraph:none] +accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[nccl-cudagraph:without_padding] +accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[nccl-cudagraph:with_padding] 
+accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[fifo-cudagraph:none]
+accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[fifo-cudagraph:without_padding]
+accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[fifo-cudagraph:with_padding]
 accuracy/test_disaggregated_serving.py::TestGemma3_1BInstruct::test_auto_dtype[False]
 accuracy/test_disaggregated_serving.py::TestGemma3_1BInstruct::test_auto_dtype[True]
 accuracy/test_disaggregated_serving.py::TestGPTOSS::test_auto_dtype[True]
diff --git a/tests/integration/test_lists/test-db/l0_gb200_multi_gpus.yml b/tests/integration/test_lists/test-db/l0_gb200_multi_gpus.yml
index 2a7562dd75..d52e960636 100644
--- a/tests/integration/test_lists/test-db/l0_gb200_multi_gpus.yml
+++ b/tests/integration/test_lists/test-db/l0_gb200_multi_gpus.yml
@@ -69,7 +69,8 @@ l0_gb200_multi_gpus:
   - accuracy/test_llm_api_pytorch.py::TestQwen3NextInstruct::test_nvfp4[tp4ep4-cutlass]
   - accuracy/test_llm_api_pytorch.py::TestQwen3NextInstruct::test_nvfp4[no_cuda_graph_overlap-cutlass]
   - accuracy/test_llm_api_pytorch.py::TestQwen3NextInstruct::test_nvfp4[tp4ep4-trtllm]
-  - accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[cudagraph:without_padding]
+  - accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[nccl-cudagraph:none]
+  - accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[fifo-cudagraph:none]
 - condition:
     ranges:
       system_gpu_count:
@@ -84,8 +85,10 @@ l0_gb200_multi_gpus:
     stage: post_merge
     backend: pytorch
   tests:
-  - accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[cudagraph:with_padding]
-  - accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[cudagraph:none]
+  - accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[nccl-cudagraph:without_padding]
+  - accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[nccl-cudagraph:with_padding]
+  - accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[fifo-cudagraph:without_padding]
+  - accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[fifo-cudagraph:with_padding]
   - accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_fp8_4gpus[tp4-fp8kv=True-attn_backend=FLASHINFER-torch_compile=True]
   - accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_fp8_4gpus[pp4-fp8kv=False-attn_backend=TRTLLM-torch_compile=False]
   - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_nvfp4_4gpus[moe_backend=CUTLASS-mtp_nextn=0-ep4-fp8kv=True-attention_dp=True-cuda_graph=True-overlap_scheduler=True-torch_compile=False]
diff --git a/tests/unittest/_torch/modules/test_mla_helix.py b/tests/unittest/_torch/modules/test_mla_helix.py
index 2fc27dc2ca..fc7aedf10e 100644
--- a/tests/unittest/_torch/modules/test_mla_helix.py
+++ b/tests/unittest/_torch/modules/test_mla_helix.py
@@ -644,7 +644,13 @@ def _run_mla_distributed(


 @torch.inference_mode
-def _full_test_multi_gpu(rank: int, world_size: int, scenario: Scenario, gen_steps: int):
+def _full_test_multi_gpu(
+    rank: int,
+    world_size: int,
+    scenario: Scenario,
+    gen_steps: int,
+    comms_medium: str = "nccl",
+):
     if scenario.rope_scaling:
         rope_scaling = {
             "beta_fast": scenario.rope_beta_fast,
@@ -814,7 +820,13 @@ def _full_test_multi_gpu(rank: int, world_size: int, scenario: Scenario, gen_ste
     # Distributed mapping for helix
     mapping = Mapping(
-        world_size=world_size, rank=rank, cp_size=world_size, cp_config={"cp_type": CpType.HELIX}
+        world_size=world_size,
+        rank=rank,
+        cp_size=world_size,
+        cp_config={
+            "cp_type": CpType.HELIX,
+            "use_nccl_for_alltoall": comms_medium == "nccl",
+        },
     )
     # we use cp_allgather here because there is no broadcast op across CP group
     ref_output_all = cp_allgather(ref_output, mapping=mapping, dim=0)
@@ -849,18 +861,24 @@ def _run_single_rank(func, *args, **kwargs):

 @pytest.mark.skipif(torch.cuda.device_count() < 2, reason="needs 2 GPUs to run this test")
 @pytest.mark.parametrize("scenario", test_scenarios, ids=lambda x: f"scenario: {x}")
+@pytest.mark.parametrize("comms_medium", ["nccl", "fifo"])
 def test_mla_helix_distributed(
     scenario: Scenario,
+    comms_medium: str,
     gen_steps: Optional[int] = None,
     max_mismatch_ratio: float = 0.02,
     mismatch_ratios: Optional[List[float]] = None,
 ):
     world_size = 2
+    print(f"Testing with comms_medium={comms_medium}.")
     gen_steps = scenario.ref_steps if gen_steps is None else gen_steps
     with MPIPoolExecutor(max_workers=world_size) as executor:
         results = executor.map(
             _run_single_rank,
-            *zip(*[(_full_test_multi_gpu, world_size, scenario, gen_steps)] * world_size),
+            *zip(
+                *[(_full_test_multi_gpu, world_size, scenario, gen_steps, comms_medium)]
+                * world_size
+            ),
         )
     if mismatch_ratios is None:
         for ratio_mismatch in results:
@@ -870,13 +888,22 @@ def test_mla_helix_distributed(


 if __name__ == "__main__":
-    for scenario in all_scenarios[:11]:
-        timing_steps = 256
-        gen_steps = scenario.ref_steps + timing_steps
-        print(f"Running scenario: {scenario} and timing {timing_steps} steps")
-        mismatch_ratios = []
-        test_mla_helix_distributed(scenario, gen_steps=gen_steps, mismatch_ratios=mismatch_ratios)
-        if any(mismatch > 0 for mismatch in mismatch_ratios):
-            print(f"Numerical test failed with mismatch ratios: {mismatch_ratios}")
-        else:
-            print("Numerical test passed")
+    for comms_medium in ["fifo", "nccl"]:
+        print(f"\n{'=' * 60}")
+        print(f"Testing with comms_medium={comms_medium}")
+        print(f"{'=' * 60}\n")
+        for scenario in all_scenarios[:11]:
+            timing_steps = 256
+            gen_steps = scenario.ref_steps + timing_steps
+            print(f"Running scenario: {scenario} and timing {timing_steps} steps")
+            mismatch_ratios = []
+            test_mla_helix_distributed(
+                scenario,
+                comms_medium=comms_medium,
+                gen_steps=gen_steps,
+                mismatch_ratios=mismatch_ratios,
+            )
+            if any(mismatch > 0 for mismatch in mismatch_ratios):
+                print(f"Numerical test failed with mismatch ratios: {mismatch_ratios}")
+            else:
+                print("Numerical test passed")
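
Usage sketch (illustrative only, not part of this patch): the new "use_nccl_for_alltoall"
key is read in the MLA attention module via
self.mapping.cp_config.get("use_nccl_for_alltoall", True), so a helix deployment picks the
communication medium from its cp_config alone. The fragment below mirrors the keys
exercised in test_disaggregated_serving.py; the variable name and the surrounding
server/LLM setup are assumptions and are not shown.

    # Hypothetical config fragment: select the FIFO-based (MNNVL) helix all-to-all
    # instead of the default NCCL-based path.
    ctx_server_config = {
        "context_parallel_size": 2,
        "cp_config": {
            "cp_type": "HELIX",
            "tokens_per_block": 32,
            # True (the default) keeps the NCCL-based alltoall_helix path;
            # False switches to the FIFO-based alltoall_helix_native path.
            "use_nccl_for_alltoall": False,
        },
        "disable_overlap_scheduler": True,
    }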