/*
 * Copyright (c) 2019-2025, NVIDIA CORPORATION. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "tensorrt_llm/common/config.h"
#include "tensorrt_llm/common/cudaUtils.h"
#include
#include "tensorrt_llm/common/logger.h"
#include "tensorrt_llm/kernels/cudaAsyncOps.cuh"
#include "tensorrt_llm/kernels/fusedMoeCommKernels.h"
#include "tensorrt_llm/kernels/ll128Proto.cuh"
#include "tensorrt_llm/kernels/quantization.cuh"

TRTLLM_NAMESPACE_BEGIN

namespace kernels
{

using tensorrt_llm::common::launchWithPdlWhenEnabled;

// Quantize a contiguous shared-memory buffer containing elements of DType into NVFP4 with per-16-element FP8 scales.
// Output layout (repeated per 16-element group per lane), followed by one global scale float:
// [WARP_SIZE * 8 bytes packed e2m1 values] [WARP_SIZE * 1 byte E4M3 per-group scales] ... [global_scale (4 bytes)]
// Each lane writes one 64-bit packed e2m1 for its 16 values and one 1-byte E4M3 scale per group.
// Global scale is computed as (448*6)/absmax and written once at the end of the buffer.
template <typename DType>
__device__ __forceinline__ void quantize_nvfp4_sharedmem(uint8_t* compact_ptr, int sizeInBytes, int laneId)
{
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
    int const numElems = sizeInBytes / sizeof(DType);
    assert(numElems % 2 == 0);
    if (numElems <= 0)
    {
        return;
    }
    DType const* in = reinterpret_cast<DType const*>(compact_ptr);

    // 1) Global absmax across the field (warp reduce) in original dtype precision when possible
    float threadMaxFloat = 0.f;
    if constexpr (std::is_same_v<DType, half> || std::is_same_v<DType, __nv_bfloat16>)
    {
        using DType2 = typename tensorrt_llm::common::packed_as<DType, 2>::type;
        DType2 const* in2 = reinterpret_cast<DType2 const*>(in);
        int const numPairs = numElems / 2;
        // Initialize to zero to avoid a concentrated shared-memory read from index 0 across all lanes
        DType2 localMax2;
        localMax2.x = DType(0.);
        localMax2.y = DType(0.);
        // stride over pairs
        for (int i = laneId; i < numPairs; i += WARP_SIZE)
        {
            DType2 v2 = in2[i];
            localMax2 = tensorrt_llm::common::cuda_max(localMax2, tensorrt_llm::common::cuda_abs(v2));
        }
        // Reduce vector to scalar float in-thread
        DType localMax = tensorrt_llm::common::cuda_max(localMax2);
        threadMaxFloat = tensorrt_llm::common::cuda_cast<float>(localMax);
    }
    else
    {
        float localMax = 0.f;
        for (int i = laneId; i < numElems; i += WARP_SIZE)
        {
            float v = fabsf(tensorrt_llm::common::cuda_cast<float>(in[i]));
            localMax = fmaxf(localMax, v);
        }
        threadMaxFloat = localMax;
    }
#pragma unroll
    for (int offset = 16; offset > 0; offset >>= 1)
    {
        threadMaxFloat = fmaxf(threadMaxFloat, __shfl_xor_sync(0xffffffff, threadMaxFloat, offset));
    }
    float const eps = 1e-12f;
    float const globalAbsMax = fmaxf(threadMaxFloat, eps);

    // 2) Global scale
    float const SFScaleVal = (448.0f * 6.0f) * (1.0f / globalAbsMax);

    // 3) Output layout
    int const numGroups = (numElems + WARP_SIZE * 16 - 1) / (WARP_SIZE * 16);
    // 8 bytes for e2m1, 1 byte for scale
    int const outputBlockSizeInBytes = 8 * WARP_SIZE + WARP_SIZE;
    uint8_t* const globalScaleOutBytes = compact_ptr + numGroups * outputBlockSizeInBytes;

    // 4) Per-16-element group quantization
    int const swizzle_idy = laneId / 4;
    int const swizzle_idx = (laneId % 4) * 8;
    for (int groupId = 0; groupId < numGroups; groupId++)
    {
        int groupStart = groupId * (WARP_SIZE * 16);
        float vecMax = 0.f;
        float2 raw[8];
        if constexpr (std::is_same_v<DType, half> || std::is_same_v<DType, __nv_bfloat16>)
        {
            using DType2 = typename tensorrt_llm::common::packed_as<DType, 2>::type;
            int const numPairs = numElems / 2;
            DType2 const* in2Ptr = reinterpret_cast<DType2 const*>(in);
            int const pairBase = groupStart >> 1;
#pragma unroll
            for (int i = 0; i < 8; ++i)
            {
                int const pi = pairBase + swizzle_idy * 32 + swizzle_idx + (i + swizzle_idy) % 8;
                if (pi < numPairs)
                {
                    DType2 v2 = in2Ptr[pi];
                    float x = tensorrt_llm::common::cuda_cast<float>(v2.x);
                    float y = tensorrt_llm::common::cuda_cast<float>(v2.y);
                    raw[i] = make_float2(x, y);
                    vecMax = fmaxf(vecMax, fmaxf(fabsf(x), fabsf(y)));
                }
                else
                {
                    raw[i] = make_float2(0.0f, 0.0f);
                }
            }
        }
        else
        {
            groupStart += laneId * 16;
#pragma unroll
            for (int i = 0; i < 8; ++i)
            {
                int idx = groupStart + (i << 1);
                if (idx < numElems)
                {
                    float x = tensorrt_llm::common::cuda_cast<float>(in[idx]);
                    float y = (idx + 1 < numElems) ? tensorrt_llm::common::cuda_cast<float>(in[idx + 1]) : 0.0f;
                    raw[i] = make_float2(x, y);
                    vecMax = fmaxf(vecMax, fmaxf(fabsf(x), fabsf(y)));
                }
                else
                {
                    raw[i] = make_float2(0.0f, 0.0f);
                }
            }
        }
        // SF from vecMax and global scale; write as E4M3
        float SFValue = SFScaleVal * (vecMax * reciprocal_approximate_ftz(6.0f));
        __nv_fp8_e4m3 sf8 = __nv_fp8_e4m3(SFValue);
        float SFValueNarrow = static_cast<float>(sf8);
        float const outputScale = (vecMax != 0.f)
            ? reciprocal_approximate_ftz(SFValueNarrow * reciprocal_approximate_ftz(SFScaleVal))
            : 0.0f;
        // Pack 16 values -> 8 bytes e2m1 (use raw[] read above to avoid a second shared-memory read)
        float2 fp2Vals[8];
#pragma unroll
        for (int i = 0; i < 8; ++i)
        {
            fp2Vals[i] = make_float2(raw[i].x * outputScale, raw[i].y * outputScale);
        }
        uint64_t const e2m1Vec = fp32_vec_to_e2m1(fp2Vals);
        uint8_t* const outValPtr = compact_ptr + groupId * outputBlockSizeInBytes + laneId * sizeof(uint64_t);
        uint8_t* const outScalePtr
            = compact_ptr + groupId * outputBlockSizeInBytes + WARP_SIZE * sizeof(uint64_t) + laneId * sizeof(uint8_t);
        // Serialize half-warps so lane i and lane i + 16 do not hit the same shared-memory bank in the same cycle.
        if (laneId < 16)
        {
            reinterpret_cast<uint64_t*>(outValPtr)[0] = e2m1Vec;
        }
        __syncwarp();
        if (laneId >= 16)
        {
            reinterpret_cast<uint64_t*>(outValPtr)[0] = e2m1Vec;
        }
        outScalePtr[0] = sf8.__x;
    }
    // Store global scale (fp32) once with a single 32-bit store. Use lane 0 to avoid races.
    if (laneId == 0)
    {
        *reinterpret_cast<float*>(globalScaleOutBytes) = SFScaleVal;
    }
#endif
}
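
// Illustrative sizing sketch, not used by the kernels in this file (the production sizing lives in
// FusedMoeFieldInfo::fillMetaInfo further below). It only restates the layout documented above
// quantize_nvfp4_sharedmem: each group of WARP_SIZE * 16 input elements becomes 8 * WARP_SIZE bytes of
// packed e2m1 plus WARP_SIZE bytes of E4M3 scales, and one 4-byte global scale is appended at the end.
// The helper name, the local 32-lane constant, and the 7168-element example are ours and purely illustrative.
constexpr int kNvfp4ExampleWarpSize = 32;

constexpr int nvfp4CompactBytesExample(int numElems)
{
    int const groupElems = kNvfp4ExampleWarpSize * 16;                            // 512 elements per group
    int const numGroups = (numElems + groupElems - 1) / groupElems;               // ceil-divide
    int const bytesPerGroup = 8 * kNvfp4ExampleWarpSize + kNvfp4ExampleWarpSize;  // 256 B values + 32 B scales
    return numGroups * bytesPerGroup + 4;                                         // + global scale (fp32)
}

// Example: a 7168-element bf16 field (14336 bytes) compacts into 14 * 288 + 4 = 4036 bytes,
// which becomes 4096 bytes once rounded up to the 128-byte transfer granularity.
static_assert(nvfp4CompactBytesExample(7168) == 4036, "worked example for the NVFP4 layout comment above");
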
// Convert one lane's packed 16 e2m1 values (in a 64-bit word) into eight float2 values (16 floats).
// Uses 8 cvt.rn.f16x2.e2m1x2 instructions, one per input byte, to produce eight half2 which are cast to float2.
inline __device__ void e2m1_to_fp32_vec(uint64_t e2m1Vec, float2 (&array)[8])
{
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
    uint32_t out_fp16[8];
    asm volatile(
        "{\n"
        ".reg .b8 b0;\n"
        ".reg .b8 b1;\n"
        ".reg .b8 b2;\n"
        ".reg .b8 b3;\n"
        ".reg .b8 b4;\n"
        ".reg .b8 b5;\n"
        ".reg .b8 b6;\n"
        ".reg .b8 b7;\n"
        ".reg .b32 lo;\n"
        ".reg .b32 hi;\n"
        "mov.b64 {lo, hi}, %8;\n"
        "mov.b32 {b0, b1, b2, b3}, lo;\n"
        "mov.b32 {b4, b5, b6, b7}, hi;\n"
        "cvt.rn.f16x2.e2m1x2 %0, b0;\n"
        "cvt.rn.f16x2.e2m1x2 %1, b1;\n"
        "cvt.rn.f16x2.e2m1x2 %2, b2;\n"
        "cvt.rn.f16x2.e2m1x2 %3, b3;\n"
        "cvt.rn.f16x2.e2m1x2 %4, b4;\n"
        "cvt.rn.f16x2.e2m1x2 %5, b5;\n"
        "cvt.rn.f16x2.e2m1x2 %6, b6;\n"
        "cvt.rn.f16x2.e2m1x2 %7, b7;\n"
        "}"
        : "=r"(out_fp16[0]), "=r"(out_fp16[1]), "=r"(out_fp16[2]), "=r"(out_fp16[3]), "=r"(out_fp16[4]),
        "=r"(out_fp16[5]), "=r"(out_fp16[6]), "=r"(out_fp16[7])
        : "l"(e2m1Vec));
    array[0] = __half22float2(reinterpret_cast<__half2&>(out_fp16[0]));
    array[1] = __half22float2(reinterpret_cast<__half2&>(out_fp16[1]));
    array[2] = __half22float2(reinterpret_cast<__half2&>(out_fp16[2]));
    array[3] = __half22float2(reinterpret_cast<__half2&>(out_fp16[3]));
    array[4] = __half22float2(reinterpret_cast<__half2&>(out_fp16[4]));
    array[5] = __half22float2(reinterpret_cast<__half2&>(out_fp16[5]));
    array[6] = __half22float2(reinterpret_cast<__half2&>(out_fp16[6]));
    array[7] = __half22float2(reinterpret_cast<__half2&>(out_fp16[7]));
#endif
}

template <typename DType>
__device__ __forceinline__ void dequantize_nvfp4_sharedmem(uint8_t* compact_ptr, int sizeInBytes, int laneId)
{
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
    int const numElems = sizeInBytes / sizeof(DType);
    if (numElems <= 0)
    {
        return;
    }
    int const numGroups = (numElems + WARP_SIZE * 16 - 1) / (WARP_SIZE * 16);
    // Layout matches quantize: per-group blocks of [8*WARP_SIZE bytes values][WARP_SIZE bytes scales],
    // followed by a single 4-byte global scale at the end.
    int const inputBlockSizeInBytes = 8 * WARP_SIZE + WARP_SIZE;
    uint8_t* const globalScaleOutBytes = compact_ptr + numGroups * inputBlockSizeInBytes;
    float const SFScaleVal = reciprocal_approximate_ftz(*reinterpret_cast<float const*>(globalScaleOutBytes));
    __syncwarp();
    DType* out = reinterpret_cast<DType*>(compact_ptr);
    // Process groups in reverse order to avoid overwriting packed input before it is read
    for (int groupId = numGroups - 1; groupId >= 0; --groupId)
    {
        int const groupStart = laneId * 16 + groupId * (WARP_SIZE * 16);
        // Conflict-free read of packed 64-bit e2m1 values from shared memory:
        // serialize half-warps to avoid lane i and i+16 hitting the same bank in the same cycle.
uint8_t const* const valBase = compact_ptr + groupId * inputBlockSizeInBytes; uint64_t packed = 0ull; if (laneId < 16) { packed = reinterpret_cast(valBase)[laneId]; } __syncwarp(); if (laneId >= 16) { packed = reinterpret_cast(valBase)[laneId]; } // Read per-lane 1-byte scales to match quantize access pattern uint8_t const* const scalesBase = compact_ptr + groupId * inputBlockSizeInBytes + WARP_SIZE * sizeof(uint64_t); uint8_t sfByte = scalesBase[laneId]; __nv_fp8_e4m3 sf8; sf8.__x = sfByte; float const SFValueNarrow = static_cast(sf8); float const dequantScale = SFScaleVal * SFValueNarrow; __syncwarp(); float2 tmp[8]; e2m1_to_fp32_vec(packed, tmp); // Vectorized stores with swizzle to avoid bank conflicts, matching quantize path if constexpr (std::is_same_v || std::is_same_v) { using DType2 = typename tensorrt_llm::common::packed_as::type; DType2* out2 = reinterpret_cast(out); int const numPairs = numElems / 2; int const pairBase = (groupId * (WARP_SIZE * 16)) >> 1; int const swizzle_idy = laneId / 4; int const swizzle_idx = (laneId % 4) * 8; #pragma unroll for (int t = 0; t < 8; ++t) { int const pi = pairBase + swizzle_idy * 32 + swizzle_idx + (t + swizzle_idy) % 8; if (pi < numPairs) { DType2 v2; v2.x = tensorrt_llm::common::cuda_cast(tmp[t].x * dequantScale); v2.y = tensorrt_llm::common::cuda_cast(tmp[t].y * dequantScale); out2[pi] = v2; } } } else { // Fallback linear layout for non-16-bit types #pragma unroll for (int t = 0; t < 8; ++t) { int idx0 = groupStart + (t << 1); if (idx0 < numElems) { using DType2 = typename tensorrt_llm::common::packed_as::type; DType2 v2; v2.x = tensorrt_llm::common::cuda_cast(tmp[t].x * dequantScale); v2.y = tensorrt_llm::common::cuda_cast(tmp[t].y * dequantScale); reinterpret_cast(out + idx0)[0] = v2; } } } __syncwarp(); } #endif } __host__ void MoeCommFieldInfo::fillFieldInfo( uint8_t* dataPtr, size_t elementSize, int vectorSize, int stride, cudaDataType_t dataType) { TLLM_CHECK(elementSize == 1 || elementSize == 2 || elementSize == 4 || elementSize == 8 || elementSize == 16); dataPtrBase = dataPtr; uint64_t dataPtrU64 = reinterpret_cast(dataPtr); while (elementSize < 16 && dataPtrU64 % (elementSize * 2) == 0 && vectorSize % 2 == 0 && stride % 2 == 0) { elementSize *= 2; vectorSize /= 2; stride /= 2; } if (elementSize == 16) { alignedUnitBit = 4; } else if (elementSize == 8) { alignedUnitBit = 3; } else if (elementSize == 4) { alignedUnitBit = 2; } else if (elementSize == 2) { alignedUnitBit = 1; } else { alignedUnitBit = 0; } alignedUnitCount = vectorSize; alignedUnitStride = stride; originalDataType = dataType; } // Wrapper class that delegates to LL128Proto but accepts extra warpId parameter for backward compatibility class Ll128ProtoWrapper { public: static constexpr uint32_t INITIALIZED_VALUE = LL128Proto::INITIALIZED_VALUE; template static __device__ __forceinline__ int checkDataReceivedInShm(uint8_t* sharedMemoryBase, uint64_t step, int countIn128Bytes, int fifoEntry128ByteIndexBase, int loaded128ByteCount, int /*warpId*/, int laneId) { return LL128Proto::checkDataReceivedInShm( sharedMemoryBase, step, countIn128Bytes, fifoEntry128ByteIndexBase, loaded128ByteCount, laneId); } static __device__ __forceinline__ void protoPack(uint8_t* sharedMemoryBase, uint64_t step, int countIn128Bytes, int fifoEntry128ByteIndexBase, int /*warpId*/, int laneId) { LL128Proto::protoPack(sharedMemoryBase, step, countIn128Bytes, fifoEntry128ByteIndexBase, laneId); } static __device__ __forceinline__ void protoUnpack(uint8_t* sharedMemoryBase, uint64_t step, int 
countIn128Bytes, int fifoEntry128ByteIndexBase, int loaded128ByteCount, int /*warpId*/, int laneId) { LL128Proto::protoUnpack( sharedMemoryBase, step, countIn128Bytes, fifoEntry128ByteIndexBase, loaded128ByteCount, laneId); } static __device__ __forceinline__ void rearm(uint32_t* u32FifoPtr, uint64_t step, int countIn128Bytes, int fifoEntry128ByteIndexBase, int /*warpId*/, int laneId) { LL128Proto::rearm(u32FifoPtr, step, countIn128Bytes, fifoEntry128ByteIndexBase, laneId); } static __device__ __host__ __forceinline__ int computeProtoTransfer128ByteAlignedSize( int compact128ByteSizeBeforeProto) { return LL128Proto::computeProtoTransfer128ByteAlignedSize(compact128ByteSizeBeforeProto); } }; using FusedMoeProto = Ll128ProtoWrapper; // using FusedMoeProto = LamportProto; namespace fused_moe_impl { // returns copy size for txCount __device__ __forceinline__ int startFieldG2S(MoeCommFieldInfo const& fieldInfo, int dataIndex, uint8_t* sharedMemoryBase, int warpId, int laneId, uint64_t* smemBar) { // we can copy more data than needed, just align to 16 bytes. int alignedShmLoadOffset = fieldInfo.getUncompactShmOffset(); uint8_t* sharedMemoryLoadPtr = sharedMemoryBase + alignedShmLoadOffset; int copyByteCount = 0; uint8_t* loadPtr = fieldInfo.get16BAlignedLoadCopyRange(dataIndex, ©ByteCount); if (laneId == 0 && copyByteCount > 0) { cp_async_bulk_g2s(sharedMemoryLoadPtr, loadPtr, copyByteCount, smemBar); } return copyByteCount; } __device__ __forceinline__ void startFieldS2G( MoeCommFieldInfo const& fieldInfo, int dataIndex, uint8_t* sharedMemoryBase, int warpId, int laneId) { int alignedShmStoreOffset = fieldInfo.getUncompactShmOffset(); uint8_t* sharedMemoryStorePtr = sharedMemoryBase + alignedShmStoreOffset; int copyByteCount = 0; int headTailShmIdx; int headTailGlobalIdx; uint8_t* storePtr = fieldInfo.get16BAlignedStoreCopyRange(dataIndex, ©ByteCount, laneId, &headTailShmIdx, &headTailGlobalIdx); if (copyByteCount > 0 && laneId == 0) { cp_async_bulk_s2g(storePtr, sharedMemoryStorePtr + MoeCommFieldInfo::BYTES_PER_16B_BLOCK, copyByteCount); } if (headTailGlobalIdx >= 0) { // copy head and tail fieldInfo.getRawPtr(dataIndex, nullptr)[headTailGlobalIdx] = sharedMemoryStorePtr[headTailShmIdx]; } __syncwarp(); } // SRC_AFTER_DST is true, if src > dst, pack will use this, // SRC_AFTER_DST is false, if src < dst, unpack will use this template __device__ __forceinline__ void memmoveSharedMemory(uint8_t* dst, uint8_t const* src, int copySize, int laneId) { int count = (copySize + sizeof(T) - 1) / sizeof(T); int warpLoopStart = SRC_AFTER_DST ? 0 : (count + WARP_SIZE - 1) / WARP_SIZE - 1; int warpLoopEnd = SRC_AFTER_DST ? (count + WARP_SIZE - 1) / WARP_SIZE : -1; int warpLoopUpdate = SRC_AFTER_DST ? 1 : -1; for (int i = warpLoopStart; i != warpLoopEnd; i += warpLoopUpdate) { int idx = laneId + i * WARP_SIZE; T data = T{}; if (idx < count) { data = reinterpret_cast(src)[idx]; } __syncwarp(); if (idx < count) { reinterpret_cast(dst)[idx] = data; } __syncwarp(); } } template __device__ __forceinline__ void memmoveFieldOnSharedMemory( MoeCommFieldInfo const& fieldInfo, int dataIndex, uint8_t* sharedMemoryBase, int laneId) { int movOffset = fieldInfo.getMemmoveOffsets(dataIndex); if (movOffset == 0) { // if movOffset is 0, src and dst are the same, don't need memmove. 
return; } int alignedBytes = 1 << fieldInfo.alignedUnitBit; int copySize = fieldInfo.alignedUnitCount * alignedBytes; uint8_t* sharedMemoryCompact = sharedMemoryBase + fieldInfo.getCompactShmOffset(); uint8_t* sharedMemoryUncompact = sharedMemoryCompact + movOffset; uint8_t* sharedMemoryDst = IS_PACK ? sharedMemoryCompact : sharedMemoryUncompact; uint8_t* sharedMemorySrc = IS_PACK ? sharedMemoryUncompact : sharedMemoryCompact; if (movOffset % 16 == 0) { memmoveSharedMemory(sharedMemoryDst, sharedMemorySrc, copySize, laneId); } else if (movOffset % 8 == 0) { memmoveSharedMemory(sharedMemoryDst, sharedMemorySrc, copySize, laneId); } else if (movOffset % 4 == 0) { memmoveSharedMemory(sharedMemoryDst, sharedMemorySrc, copySize, laneId); } else if (movOffset % 2 == 0) { memmoveSharedMemory(sharedMemoryDst, sharedMemorySrc, copySize, laneId); } else { memmoveSharedMemory(sharedMemoryDst, sharedMemorySrc, copySize, laneId); } } template __device__ __forceinline__ void packAllFields( FusedMoeFieldInfo const& sendFieldInfo, int dataIndex, uint8_t* sharedMemoryBase, int laneId) { #pragma unroll for (int i = 0; i < FIELD_COUNT; i++) { memmoveFieldOnSharedMemory(sendFieldInfo.fieldsInfo[i], dataIndex, sharedMemoryBase, laneId); } __syncwarp(); } template __device__ __forceinline__ void unpackAllFields( FusedMoeFieldInfo const& recvFieldInfo, int dataIndex, uint8_t* sharedMemoryBase, int laneId) { #pragma unroll for (int i = FIELD_COUNT - 1; i >= 0; i--) { memmoveFieldOnSharedMemory(recvFieldInfo.fieldsInfo[i], dataIndex, sharedMemoryBase, laneId); } __syncwarp(); } __device__ __forceinline__ void startWorkspaceS2G( uint64_t* fifoEntry, uint8_t* sharedMemoryBase, int send128ByteCount, int fifo128ByteOffset, int warpId, int laneId) { int copyByteCount = send128ByteCount * MoeCommFieldInfo::BYTES_PER_128B_BLOCK; if (laneId == 0) { cp_async_bulk_s2g(fifoEntry + fifo128ByteOffset * MoeCommFieldInfo::BYTES_PER_128B_BLOCK / sizeof(int64_t), sharedMemoryBase, copyByteCount); } __syncwarp(); cp_async_bulk_commit_group(); } __device__ __forceinline__ void startWorkspaceS2GReg( uint64_t* fifoEntry, uint8_t* sharedMemoryBase, int send128ByteCount, int fifo128ByteOffset, int warpId, int laneId) { int copyInt4Count = send128ByteCount * MoeCommFieldInfo::BYTES_PER_128B_BLOCK / sizeof(int4); int4* sharedMemoryInt4 = reinterpret_cast(sharedMemoryBase); uint64_t* fifoPtr = fifoEntry + fifo128ByteOffset * MoeCommFieldInfo::BYTES_PER_128B_BLOCK / sizeof(int64_t); int4* fifoPtrInt4 = reinterpret_cast(fifoPtr); #pragma unroll 4 for (int i = laneId; i < copyInt4Count; i += WARP_SIZE) { fifoPtrInt4[i] = sharedMemoryInt4[i]; } __syncwarp(); } __device__ __forceinline__ uint64_t startWorkspaceG2S(uint8_t* sharedMemoryBase, uint64_t* fifoEntry, int allLoad128ByteCount, int fifo128ByteOffset, int loaded128ByteCount, uint64_t* smemBar, int warpId, int laneId) { int copyByteCount = (allLoad128ByteCount - loaded128ByteCount) * MoeCommFieldInfo::BYTES_PER_128B_BLOCK; if (laneId == 0) { cp_async_bulk_g2s(sharedMemoryBase + loaded128ByteCount * MoeCommFieldInfo::BYTES_PER_128B_BLOCK, fifoEntry + (fifo128ByteOffset + loaded128ByteCount) * MoeCommFieldInfo::BYTES_PER_128B_BLOCK / sizeof(int64_t), copyByteCount, smemBar); } return mbarrier_arrive_expect_tx(smemBar, laneId == 0 ? 
copyByteCount : 0); } __device__ __forceinline__ void g2sBasicFields(FusedMoeFieldInfo const& sendFieldInfo, MoeExpertParallelInfo const& expertParallelInfo, int dataIndex, uint8_t* sharedMemoryBase, int laneId) { int topK = expertParallelInfo.topK; int* tokenSelectedSlotsPtr = sendFieldInfo.getTokenSelectedSlotsPtr(dataIndex, laneId, topK); float* scalePtr = sendFieldInfo.getScalePtr(dataIndex, laneId, topK); ldgsts<4>(reinterpret_cast(sharedMemoryBase) + laneId, tokenSelectedSlotsPtr, laneId < topK); ldgsts<4>(reinterpret_cast(sharedMemoryBase) + laneId + topK, reinterpret_cast(scalePtr), laneId < topK && sendFieldInfo.expertScales != nullptr); } // May commit 1 group for basic fields(tokenSelectedSlots and scales) if HAS_BASIC_FIELDS is true // For other fields, use smemBar. template __device__ __forceinline__ uint64_t g2sAllFields(FusedMoeFieldInfo const& sendFieldInfo, MoeExpertParallelInfo const& expertParallelInfo, int dataIndex, uint8_t* sharedMemoryBase, int warpId, int laneId, uint64_t* smemBar) { if (HAS_BASIC_FIELDS) { g2sBasicFields(sendFieldInfo, expertParallelInfo, dataIndex, sharedMemoryBase, laneId); cp_async_commit_group(); } int asyncLoadSize = 0; #pragma unroll for (int i = 0; i < FIELD_COUNT; i++) { asyncLoadSize += startFieldG2S(sendFieldInfo.fieldsInfo[i], dataIndex, sharedMemoryBase, warpId, laneId, smemBar); } return mbarrier_arrive_expect_tx(smemBar, laneId == 0 ? asyncLoadSize : 0); } template __device__ __forceinline__ void waitG2SBasicFields() { if (HAS_BASIC_FIELDS) { cp_async_wait_group<0>(); __syncwarp(); } } __device__ __forceinline__ void waitG2SOtherFields(uint64_t* memBar, uint32_t* phaseParity) { smemBarWait(memBar, phaseParity); } template __device__ __forceinline__ void waitG2SAllFields(uint64_t* memBar, uint32_t* phaseParity) { waitG2SBasicFields(); waitG2SOtherFields(memBar, phaseParity); } __device__ __forceinline__ void waitS2GBulkRead() { cp_async_bulk_wait_group_read<0>(); __syncwarp(); } __device__ __forceinline__ void s2gBasicFields(FusedMoeFieldInfo const& recvFieldInfo, MoeExpertParallelInfo const& expertParallelInfo, int dataIndex, uint8_t* sharedMemoryBase, int warpId, int laneId) { int topK = expertParallelInfo.topK; int* tokenSelectedSlotsPtr = recvFieldInfo.getTokenSelectedSlotsPtr(dataIndex, laneId, topK); float* scalePtr = recvFieldInfo.getScalePtr(dataIndex, laneId, topK); if (laneId < topK) { int selectedSlot = reinterpret_cast(sharedMemoryBase)[laneId]; *tokenSelectedSlotsPtr = selectedSlot; if (recvFieldInfo.expertScales != nullptr) { float scale = reinterpret_cast(sharedMemoryBase)[laneId + topK]; *scalePtr = scale; } } } // Will commit 1 group, for all non-basic fields template __device__ __forceinline__ void s2gAllFields(FusedMoeFieldInfo const& recvFieldInfo, MoeExpertParallelInfo const& expertParallelInfo, int dataIndex, uint8_t* sharedMemoryBase, int warpId, int laneId) { if (HAS_BASIC_FIELDS) { s2gBasicFields(recvFieldInfo, expertParallelInfo, dataIndex, sharedMemoryBase, warpId, laneId); __syncwarp(); } #pragma unroll for (int i = 0; i < FIELD_COUNT; i++) { startFieldS2G(recvFieldInfo.fieldsInfo[i], dataIndex, sharedMemoryBase, warpId, laneId); } cp_async_bulk_commit_group(); } template class SingleChannelCommunicator { public: __device__ __forceinline__ SingleChannelCommunicator(FusedMoeFieldInfo const& fieldInfo, MoeExpertParallelInfo const& expertParallelInfo, MoeSingleCommMeta const& commMeta, FusedMoeWorkspace const& workspace, FusedMoeWorldInfo const& worldInfo, FusedMoePairInfo const& pairInfo, uint64_t* smemBar, 
uint8_t* shmemBase) : mFieldInfo(fieldInfo) , mExpertParallelInfo(expertParallelInfo) , mCommMeta(commMeta) , mWorkspace(workspace) , mWorldInfo(worldInfo) , mPairInfo(pairInfo) , mSmemBar(smemBar) , mShmemBase(shmemBase) { if constexpr (LOW_PRECISION) { static_assert(FIELD_COUNT == 1, "Low precision alltoall only support 1 field"); } mWarpId = threadIdx.x / WARP_SIZE; mLaneId = threadIdx.x % WARP_SIZE; mFifoBasePtr = mWorkspace.getFifoBasePtr(mWorldInfo, mPairInfo); mSenderSideFifoInfo = mWorkspace.getSenderSideFifoInfo(mWorldInfo, mPairInfo); mReceiverSideFifoInfo = mWorkspace.getReceiverSideFifoInfo(mWorldInfo, mPairInfo); mSingleTransfer128ByteCount = mCommMeta.getTransfer128ByteCount(); mSingleCompactData128ByteCount = mCommMeta.getCompactData128ByteCount(); // initialize as need new Entry first mFifoEntry128ByteIndexBase = kFifoEntry128ByteCount; mFifoEntryIndex = -1; initSmemBar(mSmemBar, mLaneId); } __device__ __forceinline__ uint64_t* getFifoEntryPtr() const { return mFifoBasePtr + mFifoEntryIndex * kFifoEntrySizeInU64; } __device__ __forceinline__ bool needNewEntry() const { return mFifoEntry128ByteIndexBase + mSingleTransfer128ByteCount > kFifoEntry128ByteCount; } __device__ __forceinline__ void nextToken() { mFifoEntry128ByteIndexBase += mSingleTransfer128ByteCount; } __device__ __forceinline__ void senderInitFifo() { mHead = mSenderSideFifoInfo->head; mTail = mSenderSideFifoInfo->tail; } __device__ __forceinline__ void receiverInitFifo() { mHead = mReceiverSideFifoInfo->head; mTail = mReceiverSideFifoInfo->tail; } /* * Head | 0 | 1 | 2 | 3 | 4 | 4 | 4 | 4 | 4 | 5 | * Tail | 0 | 0 | 0 | 0 | 0 | 1 | 2 | 3 | 4 | 4 | * Writable | Y | Y | Y | Y | N | Y | Y | Y | Y | Y | * Readable | N | Y | Y | Y | Y | Y | Y | Y | N | Y | */ __device__ __forceinline__ void waitEntryWritable() { while (mTail + kFifoDepth <= mHead) { mTail = mSenderSideFifoInfo->tail; } } __device__ __forceinline__ void updateWriteEntry() { __syncwarp(); mSenderSideFifoInfo->head = mHead; } __device__ __forceinline__ void waitEntryReadable() { // always readable as long as flag matches. } __device__ __forceinline__ void updateReadEntry() { mReceiverSideFifoInfo->tail = mTail; mSenderSideFifoInfo->tail = mTail; } __device__ __forceinline__ void newSendEntry() { mFifoEntryIndex = mHead % kFifoDepth; mFifoEntry128ByteIndexBase = 0; waitEntryWritable(); __syncwarp(); } __device__ __forceinline__ void newReceiveEntry() { mFifoEntryIndex = mTail % kFifoDepth; mFifoEntry128ByteIndexBase = 0; waitEntryReadable(); __syncwarp(); } __device__ __forceinline__ void doSend(int tokenCount, int* sendIndexMapping) { senderInitFifo(); int sendIndex = mPairInfo.channel; uint32_t phaseParity = 0; #if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) cudaGridDependencySynchronize(); cudaTriggerProgrammaticLaunchCompletion(); #endif for (; sendIndex < tokenCount; sendIndex += mPairInfo.runChannelCount) { int tokenIndex = sendIndexMapping == nullptr ? sendIndex : sendIndexMapping[sendIndex]; tensorrt_llm::kernels::fused_moe_impl::g2sAllFields( mFieldInfo, mExpertParallelInfo, tokenIndex, mShmemBase, mWarpId, mLaneId, mSmemBar); if (needNewEntry()) { if (mFifoEntryIndex >= 0) { // not first entry, update FIFO info from last entry. mHead++; updateWriteEntry(); } newSendEntry(); } tensorrt_llm::kernels::fused_moe_impl::waitG2SAllFields(mSmemBar, &phaseParity); tensorrt_llm::kernels::fused_moe_impl::packAllFields( mFieldInfo, tokenIndex, mShmemBase, mLaneId); if constexpr (LOW_PRECISION) { // quantize here. 
int alignedUnitBit = mFieldInfo.fieldsInfo[0].alignedUnitBit; int alignedUnitCount = mFieldInfo.fieldsInfo[0].alignedUnitCount; int sizeInBytes = alignedUnitCount * (1 << alignedUnitBit); uint8_t* sharedMemoryCompact = mShmemBase + mFieldInfo.fieldsInfo[0].getCompactShmOffset(); cudaDataType_t originalDataType = mFieldInfo.fieldsInfo[0].originalDataType; switch (originalDataType) { case CUDA_R_16BF: quantize_nvfp4_sharedmem<__nv_bfloat16>(sharedMemoryCompact, sizeInBytes, mLaneId); break; case CUDA_R_16F: quantize_nvfp4_sharedmem(sharedMemoryCompact, sizeInBytes, mLaneId); break; default: break; } } FusedMoeProto::protoPack( mShmemBase, mHead, mSingleCompactData128ByteCount, mFifoEntry128ByteIndexBase, mWarpId, mLaneId); tensorrt_llm::kernels::fused_moe_impl::startWorkspaceS2GReg(getFifoEntryPtr(), mShmemBase, mSingleTransfer128ByteCount, mFifoEntry128ByteIndexBase, mWarpId, mLaneId); // tensorrt_llm::kernels::fused_moe_impl::waitS2GBulkRead(); nextToken(); } if (mFifoEntry128ByteIndexBase > 0) { mHead++; updateWriteEntry(); } } __device__ __forceinline__ void rearmFifoBuffer() { constexpr int kUint32CountPer128Byte = 128 / sizeof(uint32_t); uint32_t* fifoPtr = reinterpret_cast(getFifoEntryPtr()); fifoPtr += mFifoEntry128ByteIndexBase * kUint32CountPer128Byte; FusedMoeProto::rearm(fifoPtr, mTail, mSingleTransfer128ByteCount, mFifoEntry128ByteIndexBase, mWarpId, mLaneId); __syncwarp(); } __device__ __forceinline__ void doReceive(int tokenCount, int* recvIndexMapping) { receiverInitFifo(); int recvIndex = mPairInfo.channel; uint32_t phaseParity = 0; bool needRelease = false; #if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) cudaGridDependencySynchronize(); cudaTriggerProgrammaticLaunchCompletion(); #endif for (; recvIndex < tokenCount; recvIndex += mPairInfo.runChannelCount) { int tokenIndex = recvIndexMapping == nullptr ? recvIndex : recvIndexMapping[recvIndex]; int loaded128ByteCount = 0; if (needNewEntry()) { if (mFifoEntryIndex >= 0) { // not first entry, update FIFO info from last entry. 
mTail++; needRelease = true; } newReceiveEntry(); } while (loaded128ByteCount < mSingleTransfer128ByteCount) { tensorrt_llm::kernels::fused_moe_impl::startWorkspaceG2S(mShmemBase, getFifoEntryPtr(), mSingleTransfer128ByteCount, mFifoEntry128ByteIndexBase, loaded128ByteCount, mSmemBar, mWarpId, mLaneId); if (needRelease) { updateReadEntry(); needRelease = false; } smemBarWait(mSmemBar, &phaseParity); loaded128ByteCount += FusedMoeProto::template checkDataReceivedInShm(mShmemBase, mTail, mSingleTransfer128ByteCount, mFifoEntry128ByteIndexBase, loaded128ByteCount, mWarpId, mLaneId); } FusedMoeProto::protoUnpack(mShmemBase, mTail, mSingleCompactData128ByteCount, mFifoEntry128ByteIndexBase, loaded128ByteCount, mWarpId, mLaneId); if constexpr (LOW_PRECISION) { int alignedUnitBit = mFieldInfo.fieldsInfo[0].alignedUnitBit; int alignedUnitCount = mFieldInfo.fieldsInfo[0].alignedUnitCount; int sizeInBytes = alignedUnitCount * (1 << alignedUnitBit); uint8_t* sharedMemoryCompact = mShmemBase + mFieldInfo.fieldsInfo[0].getCompactShmOffset(); cudaDataType_t originalDataType = mFieldInfo.fieldsInfo[0].originalDataType; switch (originalDataType) { case CUDA_R_16BF: dequantize_nvfp4_sharedmem<__nv_bfloat16>(sharedMemoryCompact, sizeInBytes, mLaneId); break; case CUDA_R_16F: dequantize_nvfp4_sharedmem(sharedMemoryCompact, sizeInBytes, mLaneId); break; default: break; } } tensorrt_llm::kernels::fused_moe_impl::unpackAllFields( mFieldInfo, tokenIndex, mShmemBase, mLaneId); tensorrt_llm::kernels::fused_moe_impl::s2gAllFields( mFieldInfo, mExpertParallelInfo, tokenIndex, mShmemBase, mWarpId, mLaneId); tensorrt_llm::kernels::fused_moe_impl::waitS2GBulkRead(); rearmFifoBuffer(); nextToken(); } if (mFifoEntry128ByteIndexBase > 0) { mTail++; updateReadEntry(); } } private: static constexpr int kFifoEntrySizeInU64 = FusedMoeCommunicator::FIFO_ENTRY_BYTES / sizeof(uint64_t); static constexpr int kFifoEntry128ByteCount = FusedMoeCommunicator::FIFO_ENTRY_128_BYTE_COUNT; static constexpr int kFifoDepth = FusedMoeCommunicator::FIFO_DEPTH; FusedMoeFieldInfo mFieldInfo; MoeExpertParallelInfo mExpertParallelInfo; MoeSingleCommMeta mCommMeta; FusedMoeWorkspace mWorkspace; FusedMoeWorldInfo mWorldInfo; FusedMoePairInfo mPairInfo; uint64_t* mSmemBar; uint8_t* mShmemBase; int mLaneId; int mWarpId; uint64_t* mFifoBasePtr; SenderSideFifoInfo* mSenderSideFifoInfo; ReceiverSideFifoInfo* mReceiverSideFifoInfo; int64_t mHead; int64_t mTail; int mSingleTransfer128ByteCount; int mSingleCompactData128ByteCount; int mFifoEntry128ByteIndexBase; int mFifoEntryIndex; }; template __global__ void moeAllToAllKernel(FusedMoeCommKernelParam params, FusedMoeWorkspace workspace, bool hasBasicFields) { __shared__ uint64_t allWarpSmemBar[32]; extern __shared__ int4 allWarpShm[]; bool isSender = blockIdx.z == 0; int runChannelCount = gridDim.y; int group = threadIdx.y; SendRecvIndices dataIndices = isSender ? params.sendIndices : params.recvIndices; FusedMoePairInfo pairInfo; int peerRank = blockIdx.x * blockDim.y + group; if (peerRank >= params.worldInfo.epInfo.epSize) { return; } int tokenCount; int* groupStartPtr = dataIndices.getGroupStart(peerRank, tokenCount); if (tokenCount == 0) { return; } pairInfo.channel = blockIdx.y; pairInfo.runChannelCount = runChannelCount; pairInfo.senderRank = isSender ? params.worldInfo.epInfo.epRank : peerRank; pairInfo.receiverRank = isSender ? 
peerRank : params.worldInfo.epInfo.epRank; if (isSender) { int singleShmSize = params.sendCommMeta.getSingleShmSize(); if (hasBasicFields) { SingleChannelCommunicator comm(params.sendFieldInfo, params.expertParallelInfo, params.sendCommMeta, workspace, params.worldInfo, pairInfo, allWarpSmemBar + group, reinterpret_cast(allWarpShm) + singleShmSize * group); comm.doSend(tokenCount, groupStartPtr); } else { SingleChannelCommunicator comm(params.sendFieldInfo, params.expertParallelInfo, params.sendCommMeta, workspace, params.worldInfo, pairInfo, allWarpSmemBar + group, reinterpret_cast(allWarpShm) + singleShmSize * group); comm.doSend(tokenCount, groupStartPtr); } } else { int singleShmSize = params.recvCommMeta.getSingleShmSize(); if (hasBasicFields) { SingleChannelCommunicator comm(params.recvFieldInfo, params.expertParallelInfo, params.recvCommMeta, workspace, params.worldInfo, pairInfo, allWarpSmemBar + group, reinterpret_cast(allWarpShm) + singleShmSize * group); comm.doReceive(tokenCount, groupStartPtr); } else { SingleChannelCommunicator comm(params.recvFieldInfo, params.expertParallelInfo, params.recvCommMeta, workspace, params.worldInfo, pairInfo, allWarpSmemBar + group, reinterpret_cast(allWarpShm) + singleShmSize * group); comm.doReceive(tokenCount, groupStartPtr); } } } int computeMoeAlltoallMaxDynamicSharedMemorySize() { int devId = -1; TLLM_CUDA_CHECK(cudaGetDevice(&devId)); cudaFuncAttributes attr{}; TLLM_CUDA_CHECK(cudaFuncGetAttributes(&attr, (void const*) moeAllToAllKernel<1>)); int staticSmem = static_cast(attr.sharedSizeBytes); int maxPerBlockShmOptin = 0; TLLM_CUDA_CHECK(cudaDeviceGetAttribute(&maxPerBlockShmOptin, cudaDevAttrMaxSharedMemoryPerBlockOptin, devId)); return maxPerBlockShmOptin - staticSmem; } } // namespace fused_moe_impl void FusedMoeFieldInfo::fillMetaInfo( MoeSingleCommMeta* singleCommMeta, int topK, bool hasScales, bool hasBasicFields, bool isLowPrecision) const { singleCommMeta->singleUncompactAlignedSize = computeSingleUncompactSize(topK, hasScales, hasBasicFields); if (isLowPrecision) { assert(fieldCount == 1); assert(fieldsInfo[0].originalDataType == CUDA_R_16F || fieldsInfo[0].originalDataType == CUDA_R_16BF); auto alignment128 = MoeCommFieldInfo::BYTES_PER_128B_BLOCK; auto alignedUnitBit = fieldsInfo[0].alignedUnitBit; auto alignedUnitCount = fieldsInfo[0].alignedUnitCount; auto originalFieldSize = alignedUnitCount * (1 << alignedUnitBit); int numElements = originalFieldSize / 2; int numGroups = (numElements + WARP_SIZE * 16 - 1) / (WARP_SIZE * 16); int sizePerGroupInBytes = (WARP_SIZE * 16 / 2 + WARP_SIZE * 1); int totalSize = numGroups * sizePerGroupInBytes + 4; int compactSize = (totalSize + alignment128 - 1) / alignment128 * alignment128; singleCommMeta->singleCompactAlignedSize = compactSize; singleCommMeta->singleTransferAlignedSize = FusedMoeProto::computeProtoTransfer128ByteAlignedSize(singleCommMeta->singleCompactAlignedSize); return; } singleCommMeta->singleCompactAlignedSize = computeSingleCompactSize(topK, hasScales, hasBasicFields); singleCommMeta->singleTransferAlignedSize = FusedMoeProto::computeProtoTransfer128ByteAlignedSize(singleCommMeta->singleCompactAlignedSize); } void FusedMoeFieldInfo::fillFieldPlacementInfo(int topK, bool hasBasicFields) { int basicFieldSize = 0; if (hasBasicFields) { basicFieldSize = topK * sizeof(int) + (expertScales != nullptr ? 
topK * sizeof(float) : 0); // align to 16 bytes basicFieldSize = (basicFieldSize + MoeCommFieldInfo::BYTES_PER_16B_BLOCK - 1) / MoeCommFieldInfo::BYTES_PER_16B_BLOCK * MoeCommFieldInfo::BYTES_PER_16B_BLOCK; } int offset = basicFieldSize; int unalignedFieldIndex = 0; for (int i = 0; i < fieldCount; i++) { fieldsInfo[i].compact16BOffset = offset / MoeCommFieldInfo::BYTES_PER_16B_BLOCK; offset += fieldsInfo[i].getFieldCompactSize(); fieldsInfo[i].unalignedFieldIndex = unalignedFieldIndex; if (fieldsInfo[i].alignedUnitBit < 4) { unalignedFieldIndex++; } } for (int i = fieldCount; i < MOE_COMM_FIELD_MAX_COUNT; i++) { fieldsInfo[i].setUnused(); } } void FusedMoeWorkspace::initializeLocalWorkspace(FusedMoeWorldInfo const& worldInfo) { int epSize = worldInfo.epInfo.epSize; int epRank = worldInfo.epInfo.epRank; size_t fifoSize = static_cast(FusedMoeCommunicator::FIFO_TOTAL_BYTES) * epSize * channelCount; size_t senderSideInfoSize = sizeof(SenderSideFifoInfo) * epSize * channelCount; size_t receiverSideInfoSize = sizeof(ReceiverSideFifoInfo) * epSize * channelCount; uint64_t* localWorkspacePtr = workspacePtr + epRank * rankStrideInU64; TLLM_CU_CHECK(cuMemsetD32(reinterpret_cast(localWorkspacePtr), FusedMoeProto::INITIALIZED_VALUE, fifoSize / sizeof(uint32_t))); TLLM_CUDA_CHECK(cudaMemset( reinterpret_cast(localWorkspacePtr) + fifoSize, 0, senderSideInfoSize + receiverSideInfoSize)); } void moeAllToAll(FusedMoeCommKernelParam params, FusedMoeWorkspace workspace, cudaStream_t stream) { bool hasBasicFields = params.sendFieldInfo.tokenSelectedSlots != nullptr; int warpSendShmSize = params.sendCommMeta.getSingleShmSize(); int warpRecvShmSize = params.recvCommMeta.getSingleShmSize(); int warpShmSize = warpSendShmSize; int epSize = params.worldInfo.epInfo.epSize; TLLM_CHECK_WITH_INFO(warpSendShmSize == warpRecvShmSize, "warpSendShmSize(%d) not same as warpRecvShmSize(%d)", warpSendShmSize, warpRecvShmSize); int maxGroupCountPerCta = std::min(params.worldInfo.epInfo.epSize, FusedMoeCommunicator::MAX_GROUP_COUNT_PER_BLOCK); static int maxDynamicShmSize = fused_moe_impl::computeMoeAlltoallMaxDynamicSharedMemorySize(); int groupCountPerCta = std::min(maxGroupCountPerCta, maxDynamicShmSize / warpShmSize); int maxFieldCount = std::max(params.sendFieldInfo.fieldCount, params.recvFieldInfo.fieldCount); TLLM_CHECK_WITH_INFO(params.isLowPrecision == false || maxFieldCount == 1, "low precision only support 1 field"); auto getFunc = [](int fieldCount, bool lowPrecision) { switch (fieldCount) { case 1: if (lowPrecision) return fused_moe_impl::moeAllToAllKernel<1, true>; else return fused_moe_impl::moeAllToAllKernel<1>; case 2: return fused_moe_impl::moeAllToAllKernel<2>; case 3: return fused_moe_impl::moeAllToAllKernel<3>; case 4: return fused_moe_impl::moeAllToAllKernel<4>; case 5: return fused_moe_impl::moeAllToAllKernel<5>; case 6: return fused_moe_impl::moeAllToAllKernel<6>; case 7: return fused_moe_impl::moeAllToAllKernel<7>; case 8: return fused_moe_impl::moeAllToAllKernel<8>; default: return fused_moe_impl::moeAllToAllKernel<8>; } return fused_moe_impl::moeAllToAllKernel<8>; }; auto* kernelFn = getFunc(maxFieldCount, params.isLowPrecision); if (groupCountPerCta * warpShmSize > 48 * 1024) { TLLM_CUDA_CHECK(cudaFuncSetAttribute( kernelFn, cudaFuncAttributeMaxDynamicSharedMemorySize, groupCountPerCta * warpShmSize)); } for (; groupCountPerCta > 0; groupCountPerCta--) { int dynamicShmSize = groupCountPerCta * warpShmSize; int numBlocks = 0; if (cudaOccupancyMaxActiveBlocksPerMultiprocessor( &numBlocks, kernelFn, 
WARP_SIZE * groupCountPerCta, dynamicShmSize) != cudaSuccess) { continue; } if (numBlocks >= 1) { break; } } TLLM_CHECK_WITH_INFO( groupCountPerCta >= 1, "computed groupCount=%d, warpShmSize=%d", groupCountPerCta, warpShmSize); int ctaPerChannel = (epSize + groupCountPerCta - 1) / groupCountPerCta; groupCountPerCta = (epSize + ctaPerChannel - 1) / ctaPerChannel; int totalDynamicShmSize = warpShmSize * groupCountPerCta; dim3 block = FusedMoeCommunicator::getLaunchBlockDim(groupCountPerCta); dim3 grid = FusedMoeCommunicator::getLaunchGridDim(params.worldInfo.epInfo.epSize, groupCountPerCta); launchWithPdlWhenEnabled( "moeAllToAll", kernelFn, grid, block, totalDynamicShmSize, stream, params, workspace, hasBasicFields); TLLM_CUDA_CHECK(cudaGetLastError()); } int FusedMoeCommunicator::maxSmCount = -1; bool FusedMoeCommunicator::maxSmCountUsed = false; void setMaxUsableSmCount(int smCount) { FusedMoeCommunicator::setMaxUsableSmCount(smCount); } size_t getFusedMoeCommWorkspaceSize(int epSize) { int channelCount = FusedMoeCommunicator::getMoeCommChannelCount(epSize); size_t workspaceSize = FusedMoeWorkspace::computeWorkspaceSizePreRank(epSize, channelCount); return workspaceSize; } void constructWorkspace(FusedMoeWorkspace* workspace, uint64_t* workspacePtr, size_t rankStrideInU64, int epSize) { workspace->workspacePtr = workspacePtr; workspace->rankStrideInU64 = rankStrideInU64; workspace->channelCount = FusedMoeCommunicator::getMoeCommChannelCount(epSize); } void initializeFusedMoeLocalWorkspace(FusedMoeWorkspace* workspace, FusedMoeWorldInfo const& worldInfo) { workspace->initializeLocalWorkspace(worldInfo); } namespace fused_moe_comm_tests { __global__ void g2sKernel(FusedMoeFieldInfo allFieldInfo, MoeExpertParallelInfo expertParallelInfo, MoeSingleCommMeta singleCommMeta, int tokenCount, int* shmDump, bool hasBasicFields) { __shared__ uint64_t allWarpSmemBar[32]; extern __shared__ int4 allWarpShm[]; int laneId = threadIdx.x % WARP_SIZE; int warpId = threadIdx.x / WARP_SIZE; int warpCount = blockDim.x / WARP_SIZE; int tokenIndex = warpId + blockIdx.x * warpCount; if (tokenIndex >= tokenCount) { return; } int singleShmSize = singleCommMeta.singleUncompactAlignedSize; initSmemBar(&allWarpSmemBar[warpId], laneId); uint32_t phaseParity = 0; uint8_t* sharedMemoryBase = reinterpret_cast(allWarpShm) + singleShmSize * warpId; if (hasBasicFields) { tensorrt_llm::kernels::fused_moe_impl::g2sAllFields( allFieldInfo, expertParallelInfo, tokenIndex, sharedMemoryBase, warpId, laneId, &allWarpSmemBar[warpId]); tensorrt_llm::kernels::fused_moe_impl::waitG2SAllFields(&allWarpSmemBar[warpId], &phaseParity); } else { tensorrt_llm::kernels::fused_moe_impl::g2sAllFields( allFieldInfo, expertParallelInfo, tokenIndex, sharedMemoryBase, warpId, laneId, &allWarpSmemBar[warpId]); tensorrt_llm::kernels::fused_moe_impl::waitG2SAllFields(&allWarpSmemBar[warpId], &phaseParity); } for (int offset = laneId; offset < singleShmSize / sizeof(int); offset += WARP_SIZE) { shmDump[tokenIndex * singleShmSize / sizeof(int) + offset] = reinterpret_cast(sharedMemoryBase)[offset]; } } void launchSingleG2S(FusedMoeFieldInfo const& sendFieldInfo, MoeExpertParallelInfo const& expertParallelInfo, int tokenCount, int* shmDump, int warpsPerBlock, bool hasBasicFields, cudaStream_t stream) { int warpShmSize = sendFieldInfo.computeSingleUncompactSize( expertParallelInfo.topK, sendFieldInfo.expertScales != nullptr, hasBasicFields); dim3 blockDim(WARP_SIZE * warpsPerBlock, 1, 1); dim3 gridDim((tokenCount + warpsPerBlock - 1) / warpsPerBlock, 1, 
1); MoeSingleCommMeta singleCommMeta; sendFieldInfo.fillMetaInfo( &singleCommMeta, expertParallelInfo.topK, sendFieldInfo.expertScales != nullptr, hasBasicFields, false); TLLM_CUDA_CHECK( cudaFuncSetAttribute(g2sKernel, cudaFuncAttributeMaxDynamicSharedMemorySize, warpShmSize * warpsPerBlock)); g2sKernel<<>>( sendFieldInfo, expertParallelInfo, singleCommMeta, tokenCount, shmDump, hasBasicFields); TLLM_CUDA_CHECK(cudaGetLastError()); } __global__ void s2gKernel(FusedMoeFieldInfo recvFieldInfo, MoeExpertParallelInfo expertParallelInfo, MoeSingleCommMeta singleCommMeta, int tokenCount, int* shmPreload, bool hasBasicFields) { extern __shared__ int4 allWarpShm[]; int laneId = threadIdx.x % WARP_SIZE; int warpId = threadIdx.x / WARP_SIZE; int warpCount = blockDim.x / WARP_SIZE; int tokenIndex = warpId + blockIdx.x * warpCount; if (tokenIndex >= tokenCount) { return; } int singleShmSize = singleCommMeta.singleUncompactAlignedSize; uint8_t* sharedMemoryBase = reinterpret_cast(allWarpShm) + singleShmSize * warpId; for (int offset = laneId; offset < singleShmSize / sizeof(int); offset += WARP_SIZE) { reinterpret_cast(sharedMemoryBase)[offset] = shmPreload[tokenIndex * singleShmSize / sizeof(int) + offset]; } __syncwarp(); if (hasBasicFields) { tensorrt_llm::kernels::fused_moe_impl::s2gAllFields( recvFieldInfo, expertParallelInfo, tokenIndex, sharedMemoryBase, warpId, laneId); } else { tensorrt_llm::kernels::fused_moe_impl::s2gAllFields( recvFieldInfo, expertParallelInfo, tokenIndex, sharedMemoryBase, warpId, laneId); } tensorrt_llm::kernels::fused_moe_impl::waitS2GBulkRead(); } void launchSingleS2G(FusedMoeFieldInfo const& recvFieldInfo, MoeExpertParallelInfo const& expertParallelInfo, int tokenCount, int* shmPreload, int warpsPerBlock, bool hasBasicFields, cudaStream_t stream) { int warpShmSize = recvFieldInfo.computeSingleUncompactSize( expertParallelInfo.topK, recvFieldInfo.expertScales != nullptr, hasBasicFields); dim3 blockDim(WARP_SIZE * warpsPerBlock, 1, 1); dim3 gridDim((tokenCount + warpsPerBlock - 1) / warpsPerBlock, 1, 1); MoeSingleCommMeta singleCommMeta; recvFieldInfo.fillMetaInfo( &singleCommMeta, expertParallelInfo.topK, recvFieldInfo.expertScales != nullptr, hasBasicFields, false); TLLM_CUDA_CHECK( cudaFuncSetAttribute(s2gKernel, cudaFuncAttributeMaxDynamicSharedMemorySize, warpShmSize * warpsPerBlock)); s2gKernel<<>>( recvFieldInfo, expertParallelInfo, singleCommMeta, tokenCount, shmPreload, hasBasicFields); TLLM_CUDA_CHECK(cudaGetLastError()); } __global__ void loopbackKernel(FusedMoeFieldInfo sendFieldInfo, FusedMoeFieldInfo recvFieldInfo, MoeExpertParallelInfo expertParallelInfo, MoeSingleCommMeta sendCommMeta, MoeSingleCommMeta recvCommMeta, int* recvIndexMapping, int tokenCount, bool hasBasicFields) { __shared__ uint64_t allWarpSmemBar[32]; extern __shared__ int4 allWarpShm[]; int laneId = threadIdx.x % WARP_SIZE; int warpId = threadIdx.x / WARP_SIZE; int warpCount = blockDim.x / WARP_SIZE; int tokenIndex = warpId + blockIdx.x * warpCount; if (tokenIndex >= tokenCount) { return; } int recvTokenIndex = recvIndexMapping[tokenIndex]; initSmemBar(&allWarpSmemBar[warpId], laneId); uint32_t phaseParity = 0; int singleShmSize = sendCommMeta.getSingleShmSize(); uint8_t* sharedMemoryBase = reinterpret_cast(allWarpShm) + singleShmSize * warpId; if (hasBasicFields) { tensorrt_llm::kernels::fused_moe_impl::g2sAllFields( sendFieldInfo, expertParallelInfo, tokenIndex, sharedMemoryBase, warpId, laneId, &allWarpSmemBar[warpId]); } else { tensorrt_llm::kernels::fused_moe_impl::g2sAllFields( 
sendFieldInfo, expertParallelInfo, tokenIndex, sharedMemoryBase, warpId, laneId, &allWarpSmemBar[warpId]); } if (hasBasicFields) { tensorrt_llm::kernels::fused_moe_impl::waitG2SAllFields(&allWarpSmemBar[warpId], &phaseParity); } else { tensorrt_llm::kernels::fused_moe_impl::waitG2SAllFields(&allWarpSmemBar[warpId], &phaseParity); } tensorrt_llm::kernels::fused_moe_impl::packAllFields(sendFieldInfo, tokenIndex, sharedMemoryBase, laneId); tokenIndex = recvTokenIndex; // switch to recvTokenIndex; tensorrt_llm::kernels::fused_moe_impl::unpackAllFields(recvFieldInfo, tokenIndex, sharedMemoryBase, laneId); if (hasBasicFields) { tensorrt_llm::kernels::fused_moe_impl::s2gAllFields( recvFieldInfo, expertParallelInfo, tokenIndex, sharedMemoryBase, warpId, laneId); } else { tensorrt_llm::kernels::fused_moe_impl::s2gAllFields( recvFieldInfo, expertParallelInfo, tokenIndex, sharedMemoryBase, warpId, laneId); } cp_async_bulk_wait_group_read<0>(); __syncwarp(); } // G2S -> Pack -> Unpack -> S2G void launchLoopback(FusedMoeFieldInfo const& sendFieldInfo, FusedMoeFieldInfo const& recvFieldInfo, MoeExpertParallelInfo const& expertParallelInfo, int* recvIndexMapping, int tokenCount, int warpsPerBlock, bool hasBasicFields, cudaStream_t stream) { MoeSingleCommMeta sendCommMeta, recvCommMeta; sendFieldInfo.fillMetaInfo( &sendCommMeta, expertParallelInfo.topK, sendFieldInfo.expertScales != nullptr, hasBasicFields, false); recvFieldInfo.fillMetaInfo( &recvCommMeta, expertParallelInfo.topK, recvFieldInfo.expertScales != nullptr, hasBasicFields, false); int warpSendShmSize = sendCommMeta.getSingleShmSize(); int warpRecvShmSize = recvCommMeta.getSingleShmSize(); int warpShmSize = warpSendShmSize; TLLM_CHECK_WITH_INFO(warpSendShmSize == warpRecvShmSize, "warpSendShmSize(%d) not same as warpRecvShmSize(%d)", warpSendShmSize, warpRecvShmSize); dim3 blockDim(WARP_SIZE * warpsPerBlock, 1, 1); dim3 gridDim((tokenCount + warpsPerBlock - 1) / warpsPerBlock, 1, 1); TLLM_CUDA_CHECK( cudaFuncSetAttribute(loopbackKernel, cudaFuncAttributeMaxDynamicSharedMemorySize, warpShmSize * warpsPerBlock)); loopbackKernel<<>>(sendFieldInfo, recvFieldInfo, expertParallelInfo, sendCommMeta, recvCommMeta, recvIndexMapping, tokenCount, hasBasicFields); TLLM_CUDA_CHECK(cudaGetLastError()); } template __global__ void localFifoSendRecvKernel(FusedMoeFieldInfo sendFieldInfo, FusedMoeFieldInfo recvFieldInfo, MoeExpertParallelInfo expertParallelInfo, MoeSingleCommMeta sendCommMeta, MoeSingleCommMeta recvCommMeta, FusedMoeWorkspace fusedMoeWorkspace, int* sendIndexMapping, int* recvIndexMapping, int tokenCount) { __shared__ uint64_t allWarpSmemBar[32]; extern __shared__ int4 allWarpShm[]; FusedMoeWorldInfo worldInfo; worldInfo.epInfo.epRank = 0; worldInfo.epInfo.epSize = 1; int warpId = threadIdx.x / WARP_SIZE; int warpCount = blockDim.x / WARP_SIZE; FusedMoePairInfo pairInfo; pairInfo.senderRank = 0; pairInfo.receiverRank = 0; pairInfo.channel = blockIdx.z * warpCount + warpId; pairInfo.runChannelCount = gridDim.z * warpCount; if (blockIdx.y == 0) { tensorrt_llm::kernels::fused_moe_impl::SingleChannelCommunicator senderComm(sendFieldInfo, expertParallelInfo, sendCommMeta, fusedMoeWorkspace, worldInfo, pairInfo, &allWarpSmemBar[warpId], reinterpret_cast(&allWarpShm[0]) + warpId * sendCommMeta.getSingleShmSize()); senderComm.doSend(tokenCount, sendIndexMapping); } else { tensorrt_llm::kernels::fused_moe_impl::SingleChannelCommunicator recverComm(recvFieldInfo, expertParallelInfo, recvCommMeta, fusedMoeWorkspace, worldInfo, pairInfo, 
&allWarpSmemBar[warpId], reinterpret_cast(&allWarpShm[0]) + warpId * recvCommMeta.getSingleShmSize()); recverComm.doReceive(tokenCount, recvIndexMapping); } } void launchLocalFifoSendRecv(FusedMoeFieldInfo const& sendFieldInfo, FusedMoeFieldInfo const& recvFieldInfo, MoeExpertParallelInfo const& expertParallelInfo, int* sendIndexMapping, int* recvIndexMapping, FusedMoeWorkspace fusedMoeWorkspace, int tokenCount, int warpsPerBlock, int blockChannelCount, bool hasBasicFields, cudaStream_t stream) { MoeSingleCommMeta sendCommMeta, recvCommMeta; sendFieldInfo.fillMetaInfo( &sendCommMeta, expertParallelInfo.topK, sendFieldInfo.expertScales != nullptr, hasBasicFields, false); recvFieldInfo.fillMetaInfo( &recvCommMeta, expertParallelInfo.topK, recvFieldInfo.expertScales != nullptr, hasBasicFields, false); int warpSendShmSize = sendCommMeta.getSingleShmSize(); int warpRecvShmSize = recvCommMeta.getSingleShmSize(); int warpShmSize = warpSendShmSize; TLLM_CHECK_WITH_INFO(warpSendShmSize == warpRecvShmSize, "warpSendShmSize(%d) not same as warpRecvShmSize(%d)", warpSendShmSize, warpRecvShmSize); dim3 blockDim(WARP_SIZE * warpsPerBlock, 1, 1); dim3 gridDim(1, 2, blockChannelCount); auto* kernelFn = localFifoSendRecvKernel<>; if (hasBasicFields) { kernelFn = localFifoSendRecvKernel; } else { kernelFn = localFifoSendRecvKernel; } TLLM_CUDA_CHECK( cudaFuncSetAttribute(kernelFn, cudaFuncAttributeMaxDynamicSharedMemorySize, warpShmSize * warpsPerBlock)); kernelFn<<>>(sendFieldInfo, recvFieldInfo, expertParallelInfo, sendCommMeta, recvCommMeta, fusedMoeWorkspace, sendIndexMapping, recvIndexMapping, tokenCount); TLLM_CUDA_CHECK(cudaGetLastError()); } } // namespace fused_moe_comm_tests } // namespace kernels TRTLLM_NAMESPACE_END
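
// Illustrative sketch of the FIFO occupancy rule documented in the head/tail table inside
// SingleChannelCommunicator. The helper names and the depth-4 example are ours and are not used by the
// kernels above; in the real code the receiver does not poll head at all but relies on the LL128 flags,
// so fifoEntryReadable() only restates the conceptual "Readable" row of that table.
namespace fused_moe_comm_notes
{

constexpr bool fifoEntryWritable(long long head, long long tail, long long depth)
{
    // Mirrors waitEntryWritable(): the sender spins while tail + depth <= head.
    return head - tail < depth;
}

constexpr bool fifoEntryReadable(long long head, long long tail)
{
    return head > tail;
}

static_assert(fifoEntryWritable(3, 0, 4) && !fifoEntryWritable(4, 0, 4), "at most `depth` entries may be in flight");
static_assert(fifoEntryReadable(1, 0) && !fifoEntryReadable(0, 0), "readable only once head moves past tail");

} // namespace fused_moe_comm_notes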