mirror of
https://github.com/NVIDIA/TensorRT-LLM.git
synced 2026-01-14 06:27:45 +08:00
1526 lines
60 KiB
Plaintext
1526 lines
60 KiB
Plaintext
/*
|
|
* Copyright (c) 2019-2025, NVIDIA CORPORATION. All rights reserved.
|
|
*
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
* you may not use this file except in compliance with the License.
|
|
* You may obtain a copy of the License at
|
|
*
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
* See the License for the specific language governing permissions and
|
|
* limitations under the License.
|
|
*/
|
|
|
|
#include "tensorrt_llm/common/config.h"
|
|
#include "tensorrt_llm/common/cudaUtils.h"
|
|
|
|
#include <type_traits>
|
|
|
|
#include "tensorrt_llm/common/logger.h"
|
|
#include "tensorrt_llm/kernels/cudaAsyncOps.cuh"
|
|
#include "tensorrt_llm/kernels/fusedMoeCommKernels.h"
|
|
#include "tensorrt_llm/kernels/ll128Proto.cuh"
|
|
#include "tensorrt_llm/kernels/quantization.cuh"
|
|
|
|
TRTLLM_NAMESPACE_BEGIN
|
|
|
|
namespace kernels
|
|
{
|
|
|
|
using tensorrt_llm::common::launchWithPdlWhenEnabled;
|
|
|
|
// Quantize a contiguous shared-memory buffer containing elements of DType into NVFP4 with per-16-element FP8 scales.
|
|
// Output layout (repeated per 16-element group per lane), followed by one global scale float:
|
|
// [WARP_SIZE * 8 bytes packed e2m1 values] [WARP_SIZE * 1 byte E4M3 per-group scales] ... [global_scale (4 bytes)]
|
|
// Each lane writes one 64-bit packed e2m1 for its 16 values and one 1-byte E4M3 scale per group.
|
|
// Global scale is computed as (448*6)/absmax and written once at the end of the buffer.
|
|
// In-place NVFP4 quantization of a warp-owned shared-memory field.
// compact_ptr: shared memory holding numElems values of DType; overwritten with packed output.
// Warp-cooperative: all 32 lanes must enter together (full-mask shuffles and __syncwarp below).
// Requires SM100+ (guarded by __CUDA_ARCH__ >= 1000); compiles to an empty body otherwise.
template <typename DType>
__device__ __forceinline__ void quantize_nvfp4_sharedmem(uint8_t* compact_ptr, int sizeInBytes, int laneId)
{
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
    int const numElems = sizeInBytes / sizeof(DType);
    assert(numElems % 2 == 0);
    if (numElems <= 0)
    {
        return;
    }

    DType const* in = reinterpret_cast<DType const*>(compact_ptr);

    // 1) Global absmax across the field (warp reduce) in original dtype precision when possible
    float threadMaxFloat = 0.f;
    if constexpr (std::is_same_v<DType, half> || std::is_same_v<DType, __nv_bfloat16>)
    {
        // 16-bit types: read packed pairs so each lane loads 4 bytes per access.
        using DType2 = typename tensorrt_llm::common::packed_as<DType, 2>::type;
        DType2 const* in2 = reinterpret_cast<DType2 const*>(in);
        int const numPairs = numElems / 2;

        // Initialize to zero to avoid a concentrated shared-memory read from index 0 across all lanes
        DType2 localMax2;
        localMax2.x = DType(0.);
        localMax2.y = DType(0.);
        // stride over pairs
        for (int i = laneId; i < numPairs; i += WARP_SIZE)
        {
            DType2 v2 = in2[i];
            localMax2 = tensorrt_llm::common::cuda_max(localMax2, tensorrt_llm::common::cuda_abs(v2));
        }
        // Reduce vector to scalar float in-thread
        DType localMax = tensorrt_llm::common::cuda_max<DType, DType2>(localMax2);
        threadMaxFloat = tensorrt_llm::common::cuda_cast<float>(localMax);
    }
    else
    {
        // Generic path: element-wise absmax computed in float.
        float localMax = 0.f;
        for (int i = laneId; i < numElems; i += WARP_SIZE)
        {
            float v = fabsf(tensorrt_llm::common::cuda_cast<float>(in[i]));
            localMax = fmaxf(localMax, v);
        }
        threadMaxFloat = localMax;
    }

    // Butterfly (xor-shuffle) reduction: every lane ends with the warp-wide absmax.
#pragma unroll
    for (int offset = 16; offset > 0; offset >>= 1)
    {
        threadMaxFloat = fmaxf(threadMaxFloat, __shfl_xor_sync(0xffffffff, threadMaxFloat, offset));
    }
    // Epsilon guards the division below when the field is all zeros.
    float const eps = 1e-12f;
    float const globalAbsMax = fmaxf(threadMaxFloat, eps);

    // 2) Global scale: (448 * 6) / absmax, per the layout comment above this function.
    float const SFScaleVal = (448.0f * 6.0f) * (1.0f / globalAbsMax);

    // 3) Output layout: numGroups blocks of [values | scales], global scale appended after the last block.
    int const numGroups = (numElems + WARP_SIZE * 16 - 1) / (WARP_SIZE * 16);

    // 8 bytes for e2m1, 1 byte for scale
    int const outputBlockSizeInBytes = 8 * WARP_SIZE + WARP_SIZE;
    uint8_t* const globalScaleOutBytes = compact_ptr + numGroups * outputBlockSizeInBytes;

    // 4) Per-16 group quantization. The swizzle spreads each lane's 8 pair-loads across
    // shared-memory banks; the same pattern is mirrored in dequantize_nvfp4_sharedmem.
    int const swizzle_idy = laneId / 4;
    int const swizzle_idx = (laneId % 4) * 8;

    for (int groupId = 0; groupId < numGroups; groupId++)
    {
        int groupStart = groupId * (WARP_SIZE * 16);
        float vecMax = 0.f;
        // This lane's 16 input values, kept in registers for the pack step below.
        float2 raw[8];

        if constexpr (std::is_same_v<DType, half> || std::is_same_v<DType, __nv_bfloat16>)
        {
            using DType2 = typename tensorrt_llm::common::packed_as<DType, 2>::type;
            int const numPairs = numElems / 2;
            DType2 const* in2Ptr = reinterpret_cast<DType2 const*>(in);
            int const pairBase = groupStart >> 1;

#pragma unroll
            for (int i = 0; i < 8; ++i)
            {
                // Swizzled pair index; out-of-range tail pairs are zero-filled.
                int const pi = pairBase + swizzle_idy * 32 + swizzle_idx + (i + swizzle_idy) % 8;
                if (pi < numPairs)
                {
                    DType2 v2 = in2Ptr[pi];
                    float x = tensorrt_llm::common::cuda_cast<float>(v2.x);
                    float y = tensorrt_llm::common::cuda_cast<float>(v2.y);
                    raw[i] = make_float2(x, y);
                    vecMax = fmaxf(vecMax, fmaxf(fabsf(x), fabsf(y)));
                }
                else
                {
                    raw[i] = make_float2(0.0f, 0.0f);
                }
            }
        }
        else
        {
            // Linear layout: each lane owns 16 consecutive elements of this group.
            groupStart += laneId * 16;
#pragma unroll
            for (int i = 0; i < 8; ++i)
            {
                int idx = groupStart + (i << 1);
                if (idx < numElems)
                {
                    float x = tensorrt_llm::common::cuda_cast<float>(in[idx]);
                    float y = (idx + 1 < numElems) ? tensorrt_llm::common::cuda_cast<float>(in[idx + 1]) : 0.0f;
                    raw[i] = make_float2(x, y);
                    vecMax = fmaxf(vecMax, fmaxf(fabsf(x), fabsf(y)));
                }
                else
                {
                    raw[i] = make_float2(0.0f, 0.0f);
                }
            }
        }

        // SF from vecMax and global scale; write as E4M3
        float SFValue = SFScaleVal * (vecMax * reciprocal_approximate_ftz(6.0f));
        __nv_fp8_e4m3 sf8 = __nv_fp8_e4m3(SFValue);
        // Re-derive the effective scale from the rounded E4M3 value so encode and decode agree.
        float SFValueNarrow = static_cast<float>(sf8);
        float const outputScale = (vecMax != 0.f)
            ? reciprocal_approximate_ftz(SFValueNarrow * reciprocal_approximate_ftz(SFScaleVal))
            : 0.0f;

        // Pack 16 values -> 8 bytes e2m1 (use raw[] read above to avoid a second shared-memory read)
        float2 fp2Vals[8];
#pragma unroll
        for (int i = 0; i < 8; ++i)
        {
            fp2Vals[i] = make_float2(raw[i].x * outputScale, raw[i].y * outputScale);
        }
        uint64_t const e2m1Vec = fp32_vec_to_e2m1(fp2Vals);

        uint8_t* const outValPtr = compact_ptr + groupId * outputBlockSizeInBytes + laneId * sizeof(uint64_t);
        uint8_t* const outScalePtr
            = compact_ptr + groupId * outputBlockSizeInBytes + WARP_SIZE * sizeof(uint64_t) + laneId * sizeof(uint8_t);

        // Serialize the two half-warps so lanes i and i+16 do not write the same bank in one cycle
        // (same pattern as the read side in dequantize_nvfp4_sharedmem).
        if (laneId < 16)
        {
            reinterpret_cast<uint64_t*>(outValPtr)[0] = e2m1Vec;
        }
        __syncwarp();
        if (laneId >= 16)
        {
            reinterpret_cast<uint64_t*>(outValPtr)[0] = e2m1Vec;
        }
        outScalePtr[0] = sf8.__x;
    }

    // Store global scale (fp32) once with a single 32-bit store. Use lane 0 to avoid races.
    if (laneId == 0)
    {
        *reinterpret_cast<float*>(globalScaleOutBytes) = SFScaleVal;
    }
#endif
}
|
|
|
|
// Convert one lane's packed 16 e2m1 values (in a 64-bit word) into eight float2 values (16 floats).
// Uses 8 cvt.rn.f16x2.e2m1x2 instructions, one per input byte, to produce eight half2 which are cast to float2.
// Requires SM100+ (cvt.rn.f16x2.e2m1x2); guarded body is empty (array left unwritten) on older architectures.
inline __device__ void e2m1_to_fp32_vec(uint64_t e2m1Vec, float2 (&array)[8])
{
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
    // Each 32-bit result holds one half2 decoded from one input byte (two e2m1 nibbles).
    uint32_t out_fp16[8];
    // Split the 64-bit word into 8 bytes (b0 = lowest byte of the low word), then convert each.
    asm volatile(
        "{\n"
        ".reg .b8 b0;\n"
        ".reg .b8 b1;\n"
        ".reg .b8 b2;\n"
        ".reg .b8 b3;\n"
        ".reg .b8 b4;\n"
        ".reg .b8 b5;\n"
        ".reg .b8 b6;\n"
        ".reg .b8 b7;\n"
        ".reg .b32 lo;\n"
        ".reg .b32 hi;\n"
        "mov.b64 {lo, hi}, %8;\n"
        "mov.b32 {b0, b1, b2, b3}, lo;\n"
        "mov.b32 {b4, b5, b6, b7}, hi;\n"
        "cvt.rn.f16x2.e2m1x2 %0, b0;\n"
        "cvt.rn.f16x2.e2m1x2 %1, b1;\n"
        "cvt.rn.f16x2.e2m1x2 %2, b2;\n"
        "cvt.rn.f16x2.e2m1x2 %3, b3;\n"
        "cvt.rn.f16x2.e2m1x2 %4, b4;\n"
        "cvt.rn.f16x2.e2m1x2 %5, b5;\n"
        "cvt.rn.f16x2.e2m1x2 %6, b6;\n"
        "cvt.rn.f16x2.e2m1x2 %7, b7;\n"
        "}"
        : "=r"(out_fp16[0]), "=r"(out_fp16[1]), "=r"(out_fp16[2]), "=r"(out_fp16[3]), "=r"(out_fp16[4]),
        "=r"(out_fp16[5]), "=r"(out_fp16[6]), "=r"(out_fp16[7])
        : "l"(e2m1Vec));

    // Widen each half2 to float2 for the caller.
    array[0] = __half22float2(reinterpret_cast<__half2&>(out_fp16[0]));
    array[1] = __half22float2(reinterpret_cast<__half2&>(out_fp16[1]));
    array[2] = __half22float2(reinterpret_cast<__half2&>(out_fp16[2]));
    array[3] = __half22float2(reinterpret_cast<__half2&>(out_fp16[3]));
    array[4] = __half22float2(reinterpret_cast<__half2&>(out_fp16[4]));
    array[5] = __half22float2(reinterpret_cast<__half2&>(out_fp16[5]));
    array[6] = __half22float2(reinterpret_cast<__half2&>(out_fp16[6]));
    array[7] = __half22float2(reinterpret_cast<__half2&>(out_fp16[7]));
#endif
}
|
|
|
|
// In-place NVFP4 -> DType dequantization of a warp-owned shared-memory field; inverse of
// quantize_nvfp4_sharedmem (same block layout and swizzle). sizeInBytes is the size of the
// *dequantized* output field. Warp-cooperative; requires SM100+ (guarded below).
template <typename DType>
__device__ __forceinline__ void dequantize_nvfp4_sharedmem(uint8_t* compact_ptr, int sizeInBytes, int laneId)
{
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
    int const numElems = sizeInBytes / sizeof(DType);
    if (numElems <= 0)
    {
        return;
    }

    int const numGroups = (numElems + WARP_SIZE * 16 - 1) / (WARP_SIZE * 16);

    // New layout matches quantize: per-group blocks of [8*WARP_SIZE bytes values][WARP_SIZE bytes scales],
    // followed by a single 4-byte global scale at the end.
    int const inputBlockSizeInBytes = 8 * WARP_SIZE + WARP_SIZE;
    uint8_t* const globalScaleOutBytes = compact_ptr + numGroups * inputBlockSizeInBytes;
    // Invert the encode-side global scale once; every group reuses it.
    float const SFScaleVal = reciprocal_approximate_ftz(*reinterpret_cast<float const*>(globalScaleOutBytes));
    __syncwarp();

    DType* out = reinterpret_cast<DType*>(compact_ptr);

    // Process groups in reverse order to avoid overwriting packed input before it is read
    for (int groupId = numGroups - 1; groupId >= 0; --groupId)
    {
        // Start of this lane's 16 output elements (used by the linear fallback path only).
        int const groupStart = laneId * 16 + groupId * (WARP_SIZE * 16);
        // Conflict-free read of packed 64-bit e2m1 values from shared memory:
        // serialize half-warps to avoid lane i and i+16 hitting the same bank in the same cycle.
        uint8_t const* const valBase = compact_ptr + groupId * inputBlockSizeInBytes;
        uint64_t packed = 0ull;
        if (laneId < 16)
        {
            packed = reinterpret_cast<uint64_t const*>(valBase)[laneId];
        }
        __syncwarp();
        if (laneId >= 16)
        {
            packed = reinterpret_cast<uint64_t const*>(valBase)[laneId];
        }

        // Read per-lane 1-byte scales to match quantize access pattern
        uint8_t const* const scalesBase = compact_ptr + groupId * inputBlockSizeInBytes + WARP_SIZE * sizeof(uint64_t);
        uint8_t sfByte = scalesBase[laneId];
        __nv_fp8_e4m3 sf8;
        sf8.__x = sfByte;
        float const SFValueNarrow = static_cast<float>(sf8);
        // Per-group dequant factor: decoded E4M3 group scale times the inverted global scale.
        float const dequantScale = SFScaleVal * SFValueNarrow;
        // All lanes must finish reading the packed input before the stores below overwrite it.
        __syncwarp();

        float2 tmp[8];
        e2m1_to_fp32_vec(packed, tmp);

        // Vectorized stores with swizzle to avoid bank conflicts, matching quantize path
        if constexpr (std::is_same_v<DType, half> || std::is_same_v<DType, __nv_bfloat16>)
        {
            using DType2 = typename tensorrt_llm::common::packed_as<DType, 2>::type;
            DType2* out2 = reinterpret_cast<DType2*>(out);
            int const numPairs = numElems / 2;
            int const pairBase = (groupId * (WARP_SIZE * 16)) >> 1;
            int const swizzle_idy = laneId / 4;
            int const swizzle_idx = (laneId % 4) * 8;

#pragma unroll
            for (int t = 0; t < 8; ++t)
            {
                // Same swizzled pair index as the quantize read, so values return to their origin.
                int const pi = pairBase + swizzle_idy * 32 + swizzle_idx + (t + swizzle_idy) % 8;
                if (pi < numPairs)
                {
                    DType2 v2;
                    v2.x = tensorrt_llm::common::cuda_cast<DType>(tmp[t].x * dequantScale);
                    v2.y = tensorrt_llm::common::cuda_cast<DType>(tmp[t].y * dequantScale);
                    out2[pi] = v2;
                }
            }
        }
        else
        {
            // Fallback linear layout for non-16-bit types
            // NOTE(review): the paired store writes two elements but only idx0 is bounds-checked;
            // assumes numElems is even (the quantize path asserts this) — confirm for this path.
#pragma unroll
            for (int t = 0; t < 8; ++t)
            {
                int idx0 = groupStart + (t << 1);
                if (idx0 < numElems)
                {
                    using DType2 = typename tensorrt_llm::common::packed_as<DType, 2>::type;
                    DType2 v2;
                    v2.x = tensorrt_llm::common::cuda_cast<DType>(tmp[t].x * dequantScale);
                    v2.y = tensorrt_llm::common::cuda_cast<DType>(tmp[t].y * dequantScale);
                    reinterpret_cast<DType2*>(out + idx0)[0] = v2;
                }
            }
        }
        // Group fully written before the next (lower) group's packed input is consumed.
        __syncwarp();
    }
#endif
}
|
|
|
|
// Record the layout of one communication field: widen the element size as far as the
// base-pointer alignment, vector size, and stride allow (capped at 16 bytes), then store
// the widened unit as a bit shift plus the rescaled count/stride and the original dtype.
__host__ void MoeCommFieldInfo::fillFieldInfo(
    uint8_t* dataPtr, size_t elementSize, int vectorSize, int stride, cudaDataType_t dataType)
{
    TLLM_CHECK(elementSize == 1 || elementSize == 2 || elementSize == 4 || elementSize == 8 || elementSize == 16);

    dataPtrBase = dataPtr;

    uint64_t const baseAddr = reinterpret_cast<uint64_t>(dataPtr);

    // Double the element size while the pointer stays aligned to the wider unit and
    // both the vector size and stride remain evenly divisible.
    for (; elementSize < 16 && baseAddr % (elementSize * 2) == 0 && vectorSize % 2 == 0 && stride % 2 == 0;
         elementSize *= 2)
    {
        vectorSize /= 2;
        stride /= 2;
    }

    // alignedUnitBit = log2(elementSize); elementSize is a power of two in {1,2,4,8,16}.
    int unitBit = 0;
    for (size_t sz = elementSize; sz > 1; sz >>= 1)
    {
        ++unitBit;
    }
    alignedUnitBit = unitBit;

    alignedUnitCount = vectorSize;
    alignedUnitStride = stride;
    originalDataType = dataType;
}
|
|
|
|
// Wrapper class that delegates to LL128Proto but accepts extra warpId parameter for backward compatibility.
// The warpId argument is ignored in every method; only laneId is forwarded to the protocol.
class Ll128ProtoWrapper
{
public:
    // Sentinel flag value re-exported from LL128Proto.
    static constexpr uint32_t INITIALIZED_VALUE = LL128Proto::INITIALIZED_VALUE;

    // Delegates to LL128Proto::checkDataReceivedInShm, dropping the unused warpId.
    // Returns the count reported by the protocol (added to loaded128ByteCount by callers).
    template <bool USE_FINISH>
    static __device__ __forceinline__ int checkDataReceivedInShm(uint8_t* sharedMemoryBase, uint64_t step,
        int countIn128Bytes, int fifoEntry128ByteIndexBase, int loaded128ByteCount, int /*warpId*/, int laneId)
    {
        return LL128Proto::checkDataReceivedInShm<USE_FINISH>(
            sharedMemoryBase, step, countIn128Bytes, fifoEntry128ByteIndexBase, loaded128ByteCount, laneId);
    }

    // Delegates to LL128Proto::protoPack, dropping the unused warpId.
    static __device__ __forceinline__ void protoPack(uint8_t* sharedMemoryBase, uint64_t step, int countIn128Bytes,
        int fifoEntry128ByteIndexBase, int /*warpId*/, int laneId)
    {
        LL128Proto::protoPack(sharedMemoryBase, step, countIn128Bytes, fifoEntry128ByteIndexBase, laneId);
    }

    // Delegates to LL128Proto::protoUnpack, dropping the unused warpId.
    static __device__ __forceinline__ void protoUnpack(uint8_t* sharedMemoryBase, uint64_t step, int countIn128Bytes,
        int fifoEntry128ByteIndexBase, int loaded128ByteCount, int /*warpId*/, int laneId)
    {
        LL128Proto::protoUnpack(
            sharedMemoryBase, step, countIn128Bytes, fifoEntry128ByteIndexBase, loaded128ByteCount, laneId);
    }

    // Delegates to LL128Proto::rearm, dropping the unused warpId.
    static __device__ __forceinline__ void rearm(uint32_t* u32FifoPtr, uint64_t step, int countIn128Bytes,
        int fifoEntry128ByteIndexBase, int /*warpId*/, int laneId)
    {
        LL128Proto::rearm(u32FifoPtr, step, countIn128Bytes, fifoEntry128ByteIndexBase, laneId);
    }

    // Delegates to LL128Proto::computeProtoTransfer128ByteAlignedSize: on-wire size (in
    // 128-byte units) for a given compact payload size.
    static __device__ __host__ __forceinline__ int computeProtoTransfer128ByteAlignedSize(
        int compact128ByteSizeBeforeProto)
    {
        return LL128Proto::computeProtoTransfer128ByteAlignedSize(compact128ByteSizeBeforeProto);
    }
};
|
|
|
|
using FusedMoeProto = Ll128ProtoWrapper;
|
|
|
|
// using FusedMoeProto = LamportProto;
|
|
|
|
namespace fused_moe_impl
|
|
{
|
|
|
|
// returns copy size for txCount
// Starts an async bulk G2S copy of one field (for one token) from global memory into its
// uncompacted shared-memory slot, tracked by smemBar. Only lane 0 issues the copy; the
// returned byte count is accumulated by the caller into the mbarrier expected-tx count.
// warpId is accepted but unused.
__device__ __forceinline__ int startFieldG2S(MoeCommFieldInfo const& fieldInfo, int dataIndex,
    uint8_t* sharedMemoryBase, int warpId, int laneId, uint64_t* smemBar)
{
    // we can copy more data than needed, just align to 16 bytes.
    int alignedShmLoadOffset = fieldInfo.getUncompactShmOffset();
    uint8_t* sharedMemoryLoadPtr = sharedMemoryBase + alignedShmLoadOffset;
    int copyByteCount = 0;
    uint8_t* loadPtr = fieldInfo.get16BAlignedLoadCopyRange(dataIndex, &copyByteCount);
    if (laneId == 0 && copyByteCount > 0)
    {
        cp_async_bulk_g2s(sharedMemoryLoadPtr, loadPtr, copyByteCount, smemBar);
    }
    return copyByteCount;
}
|
|
|
|
// Starts copying one field (for one token) from its uncompacted shared-memory slot back to
// global memory. The 16B-aligned interior goes out via one async bulk S2G copy (lane 0 only);
// unaligned head/tail bytes are stored one byte per lane via the headTail indices.
// warpId is accepted but unused.
__device__ __forceinline__ void startFieldS2G(
    MoeCommFieldInfo const& fieldInfo, int dataIndex, uint8_t* sharedMemoryBase, int warpId, int laneId)
{
    int alignedShmStoreOffset = fieldInfo.getUncompactShmOffset();
    uint8_t* sharedMemoryStorePtr = sharedMemoryBase + alignedShmStoreOffset;
    int copyByteCount = 0;
    int headTailShmIdx;
    int headTailGlobalIdx;
    uint8_t* storePtr
        = fieldInfo.get16BAlignedStoreCopyRange(dataIndex, &copyByteCount, laneId, &headTailShmIdx, &headTailGlobalIdx);
    if (copyByteCount > 0 && laneId == 0)
    {
        // Bulk source skips the first 16B block of the slot.
        // NOTE(review): presumably that block holds the unaligned head bytes handled below —
        // confirm against get16BAlignedStoreCopyRange.
        cp_async_bulk_s2g(storePtr, sharedMemoryStorePtr + MoeCommFieldInfo::BYTES_PER_16B_BLOCK, copyByteCount);
    }
    if (headTailGlobalIdx >= 0)
    {
        // copy head and tail
        fieldInfo.getRawPtr(dataIndex, nullptr)[headTailGlobalIdx] = sharedMemoryStorePtr[headTailShmIdx];
    }
    __syncwarp();
}
|
|
|
|
// SRC_AFTER_DST is true, if src > dst, pack will use this,
// SRC_AFTER_DST is false, if src < dst, unpack will use this
// Warp-cooperative memmove inside shared memory for overlapping ranges, in units of T.
// Iterates WARP_SIZE-element chunks in the direction that consumes source data before the
// destination reaches it: front-to-back when src is after dst, back-to-front otherwise.
// Each chunk is fully read (barrier) before any lane writes it.
template <typename T, bool SRC_AFTER_DST = true>
__device__ __forceinline__ void memmoveSharedMemory(uint8_t* dst, uint8_t const* src, int copySize, int laneId)
{
    // Number of T units, rounding up (a partial tail unit is still moved).
    int count = (copySize + sizeof(T) - 1) / sizeof(T);
    // Chunk-loop bounds, forward or reverse depending on overlap direction.
    int warpLoopStart = SRC_AFTER_DST ? 0 : (count + WARP_SIZE - 1) / WARP_SIZE - 1;
    int warpLoopEnd = SRC_AFTER_DST ? (count + WARP_SIZE - 1) / WARP_SIZE : -1;
    int warpLoopUpdate = SRC_AFTER_DST ? 1 : -1;
    for (int i = warpLoopStart; i != warpLoopEnd; i += warpLoopUpdate)
    {
        int idx = laneId + i * WARP_SIZE;
        T data = T{};
        if (idx < count)
        {
            data = reinterpret_cast<T const*>(src)[idx];
        }
        // All lanes finish reading this chunk before any lane overwrites it.
        __syncwarp();
        if (idx < count)
        {
            reinterpret_cast<T*>(dst)[idx] = data;
        }
        // Writes complete before the next chunk's reads begin.
        __syncwarp();
    }
}
|
|
|
|
// Move one field between its compact and uncompacted shared-memory locations for one token.
// IS_PACK=true moves data into the compact layout; false moves it back out. Dispatches to
// memmoveSharedMemory with the widest element type whose size divides the move offset.
template <bool IS_PACK = true>
__device__ __forceinline__ void memmoveFieldOnSharedMemory(
    MoeCommFieldInfo const& fieldInfo, int dataIndex, uint8_t* sharedMemoryBase, int laneId)
{
    int const shift = fieldInfo.getMemmoveOffsets(dataIndex);
    if (shift == 0)
    {
        // Zero offset: compact and uncompacted locations coincide, nothing to move.
        return;
    }

    int const unitBytes = 1 << fieldInfo.alignedUnitBit;
    int const moveBytes = fieldInfo.alignedUnitCount * unitBytes;
    uint8_t* const compactPtr = sharedMemoryBase + fieldInfo.getCompactShmOffset();
    uint8_t* const uncompactPtr = compactPtr + shift;
    // Packing writes toward the compact location; unpacking writes away from it.
    uint8_t* const dstPtr = IS_PACK ? compactPtr : uncompactPtr;
    uint8_t* const srcPtr = IS_PACK ? uncompactPtr : compactPtr;

    // Widest element whose size divides the offset keeps every access aligned.
    if (shift % 16 == 0)
    {
        memmoveSharedMemory<int4, IS_PACK>(dstPtr, srcPtr, moveBytes, laneId);
    }
    else if (shift % 8 == 0)
    {
        memmoveSharedMemory<int64_t, IS_PACK>(dstPtr, srcPtr, moveBytes, laneId);
    }
    else if (shift % 4 == 0)
    {
        memmoveSharedMemory<int, IS_PACK>(dstPtr, srcPtr, moveBytes, laneId);
    }
    else if (shift % 2 == 0)
    {
        memmoveSharedMemory<int16_t, IS_PACK>(dstPtr, srcPtr, moveBytes, laneId);
    }
    else
    {
        memmoveSharedMemory<int8_t, IS_PACK>(dstPtr, srcPtr, moveBytes, laneId);
    }
}
|
|
|
|
// Compact every non-basic field of one token inside shared memory, moving each field
// from its uncompacted slot into the packed (compact) layout.
template <int FIELD_COUNT = MOE_COMM_FIELD_MAX_COUNT>
__device__ __forceinline__ void packAllFields(
    FusedMoeFieldInfo const& sendFieldInfo, int dataIndex, uint8_t* sharedMemoryBase, int laneId)
{
    // Ascending field order for the pack direction (mirrors unpackAllFields' descending order).
#pragma unroll
    for (int fieldIdx = 0; fieldIdx < FIELD_COUNT; ++fieldIdx)
    {
        memmoveFieldOnSharedMemory<true>(sendFieldInfo.fieldsInfo[fieldIdx], dataIndex, sharedMemoryBase, laneId);
    }
    __syncwarp();
}
|
|
|
|
// Expand every non-basic field of one token inside shared memory, moving each field from
// the packed (compact) layout back into its uncompacted slot.
template <int FIELD_COUNT = MOE_COMM_FIELD_MAX_COUNT>
__device__ __forceinline__ void unpackAllFields(
    FusedMoeFieldInfo const& recvFieldInfo, int dataIndex, uint8_t* sharedMemoryBase, int laneId)
{
    // Descending field order for the unpack direction (mirrors packAllFields' ascending order).
#pragma unroll
    for (int fieldIdx = FIELD_COUNT - 1; fieldIdx >= 0; --fieldIdx)
    {
        memmoveFieldOnSharedMemory<false>(recvFieldInfo.fieldsInfo[fieldIdx], dataIndex, sharedMemoryBase, laneId);
    }
    __syncwarp();
}
|
|
|
|
// Issue one async bulk S2G copy of the packed payload from shared memory into the FIFO
// entry at the given 128-byte offset, then commit the bulk group. Only lane 0 issues the
// copy; all lanes reach the commit after the __syncwarp. warpId is accepted but unused.
__device__ __forceinline__ void startWorkspaceS2G(
    uint64_t* fifoEntry, uint8_t* sharedMemoryBase, int send128ByteCount, int fifo128ByteOffset, int warpId, int laneId)
{
    int copyByteCount = send128ByteCount * MoeCommFieldInfo::BYTES_PER_128B_BLOCK;
    if (laneId == 0)
    {
        // fifoEntry is uint64_t*, so the 128-byte offset is scaled by sizeof(int64_t).
        cp_async_bulk_s2g(fifoEntry + fifo128ByteOffset * MoeCommFieldInfo::BYTES_PER_128B_BLOCK / sizeof(int64_t),
            sharedMemoryBase, copyByteCount);
    }
    __syncwarp();
    cp_async_bulk_commit_group();
}
|
|
|
|
// Register-based variant of startWorkspaceS2G: copies the packed payload from shared
// memory into the FIFO entry with plain 16-byte loads/stores strided across the warp,
// instead of an async bulk copy. warpId is accepted but unused.
__device__ __forceinline__ void startWorkspaceS2GReg(
    uint64_t* fifoEntry, uint8_t* sharedMemoryBase, int send128ByteCount, int fifo128ByteOffset, int warpId, int laneId)
{
    int const totalInt4 = send128ByteCount * MoeCommFieldInfo::BYTES_PER_128B_BLOCK / sizeof(int4);
    int4 const* srcInt4 = reinterpret_cast<int4 const*>(sharedMemoryBase);
    // fifoEntry is uint64_t*, so the 128-byte offset is scaled by sizeof(int64_t).
    int4* dstInt4 = reinterpret_cast<int4*>(
        fifoEntry + fifo128ByteOffset * MoeCommFieldInfo::BYTES_PER_128B_BLOCK / sizeof(int64_t));
#pragma unroll 4
    for (int idx = laneId; idx < totalInt4; idx += WARP_SIZE)
    {
        dstInt4[idx] = srcInt4[idx];
    }
    __syncwarp();
}
|
|
|
|
// Issue an async bulk G2S copy of the not-yet-loaded tail of a FIFO entry into shared memory
// (resuming at loaded128ByteCount), then arrive on the shared-memory barrier with the expected
// transaction byte count. Returns the mbarrier arrival result. warpId is accepted but unused.
__device__ __forceinline__ uint64_t startWorkspaceG2S(uint8_t* sharedMemoryBase, uint64_t* fifoEntry,
    int allLoad128ByteCount, int fifo128ByteOffset, int loaded128ByteCount, uint64_t* smemBar, int warpId, int laneId)
{
    int copyByteCount = (allLoad128ByteCount - loaded128ByteCount) * MoeCommFieldInfo::BYTES_PER_128B_BLOCK;
    if (laneId == 0)
    {
        // Resume both the shared-memory destination and the FIFO source at the already-loaded offset.
        cp_async_bulk_g2s(sharedMemoryBase + loaded128ByteCount * MoeCommFieldInfo::BYTES_PER_128B_BLOCK,
            fifoEntry
                + (fifo128ByteOffset + loaded128ByteCount) * MoeCommFieldInfo::BYTES_PER_128B_BLOCK / sizeof(int64_t),
            copyByteCount, smemBar);
    }
    // Only lane 0 registers a non-zero expected-tx count — it issued the copy.
    return mbarrier_arrive_expect_tx(smemBar, laneId == 0 ? copyByteCount : 0);
}
|
|
|
|
// Stage the basic fields for one token into the head of shared memory via cp.async (ldgsts):
// topK token-selected-slot ints first, then topK scale values (loaded as 4-byte words).
// NOTE(review): one lane per slot — assumes topK <= WARP_SIZE; confirm upstream.
__device__ __forceinline__ void g2sBasicFields(FusedMoeFieldInfo const& sendFieldInfo,
    MoeExpertParallelInfo const& expertParallelInfo, int dataIndex, uint8_t* sharedMemoryBase, int laneId)
{
    int topK = expertParallelInfo.topK;
    int* tokenSelectedSlotsPtr = sendFieldInfo.getTokenSelectedSlotsPtr(dataIndex, laneId, topK);
    float* scalePtr = sendFieldInfo.getScalePtr(dataIndex, laneId, topK);
    // One 4-byte cp.async per lane; lanes >= topK are predicated off.
    ldgsts<4>(reinterpret_cast<int*>(sharedMemoryBase) + laneId, tokenSelectedSlotsPtr, laneId < topK);
    // Scales land immediately after the topK slot ints; predicated off when no scales exist.
    ldgsts<4>(reinterpret_cast<int*>(sharedMemoryBase) + laneId + topK, reinterpret_cast<int*>(scalePtr),
        laneId < topK && sendFieldInfo.expertScales != nullptr);
}
|
|
|
|
// May commit 1 group for basic fields(tokenSelectedSlots and scales) if HAS_BASIC_FIELDS is true
// For other fields, use smemBar.
// Stages all fields of one token into shared memory: basic fields via cp.async (commit group),
// the remaining fields via bulk copies tracked by smemBar. Returns the mbarrier arrival result.
template <bool HAS_BASIC_FIELDS = true, int FIELD_COUNT = MOE_COMM_FIELD_MAX_COUNT>
__device__ __forceinline__ uint64_t g2sAllFields(FusedMoeFieldInfo const& sendFieldInfo,
    MoeExpertParallelInfo const& expertParallelInfo, int dataIndex, uint8_t* sharedMemoryBase, int warpId, int laneId,
    uint64_t* smemBar)
{
    if (HAS_BASIC_FIELDS)
    {
        g2sBasicFields(sendFieldInfo, expertParallelInfo, dataIndex, sharedMemoryBase, laneId);
        // Waited separately via waitG2SBasicFields (cp_async_wait_group).
        cp_async_commit_group();
    }
    // Total bytes issued to the barrier-tracked bulk copies across all fields.
    int asyncLoadSize = 0;
#pragma unroll
    for (int i = 0; i < FIELD_COUNT; i++)
    {
        asyncLoadSize
            += startFieldG2S(sendFieldInfo.fieldsInfo[i], dataIndex, sharedMemoryBase, warpId, laneId, smemBar);
    }
    // Only lane 0 registers a non-zero expected-tx count — it issued all the copies.
    return mbarrier_arrive_expect_tx(smemBar, laneId == 0 ? asyncLoadSize : 0);
}
|
|
|
|
// Wait for the cp.async group committed for the basic fields in g2sAllFields.
// Compile-time no-op when HAS_BASIC_FIELDS is false (no group was committed).
template <bool HAS_BASIC_FIELDS = true>
__device__ __forceinline__ void waitG2SBasicFields()
{
    if (HAS_BASIC_FIELDS)
    {
        cp_async_wait_group<0>();
        __syncwarp();
    }
}
|
|
|
|
// Wait on the shared-memory barrier tracking the bulk G2S copies of the non-basic fields.
// phaseParity is the caller-maintained barrier phase, toggled by smemBarWait.
__device__ __forceinline__ void waitG2SOtherFields(uint64_t* memBar, uint32_t* phaseParity)
{
    smemBarWait(memBar, phaseParity);
}
|
|
|
|
// Wait until every field staged by g2sAllFields has arrived in shared memory:
// first the basic fields' cp.async group (if any), then the barrier-tracked bulk copies.
template <bool HAS_BASIC_FIELDS = true>
__device__ __forceinline__ void waitG2SAllFields(uint64_t* memBar, uint32_t* phaseParity)
{
    waitG2SBasicFields<HAS_BASIC_FIELDS>();
    waitG2SOtherFields(memBar, phaseParity);
}
|
|
|
|
// Wait until all committed S2G bulk copies have finished *reading* their shared-memory
// source, so the shared buffer can be safely reused (wait_group_read semantics).
__device__ __forceinline__ void waitS2GBulkRead()
{
    cp_async_bulk_wait_group_read<0>();
    __syncwarp();
}
|
|
|
|
// Scatter the basic fields for one token from the head of shared memory back to their
// global-memory destinations with plain stores; inverse of g2sBasicFields (same layout:
// topK slot ints followed by topK scale floats). warpId is accepted but unused.
__device__ __forceinline__ void s2gBasicFields(FusedMoeFieldInfo const& recvFieldInfo,
    MoeExpertParallelInfo const& expertParallelInfo, int dataIndex, uint8_t* sharedMemoryBase, int warpId, int laneId)
{
    int topK = expertParallelInfo.topK;
    int* tokenSelectedSlotsPtr = recvFieldInfo.getTokenSelectedSlotsPtr(dataIndex, laneId, topK);
    float* scalePtr = recvFieldInfo.getScalePtr(dataIndex, laneId, topK);
    // One lane per top-K entry; lanes >= topK do nothing.
    if (laneId < topK)
    {
        int selectedSlot = reinterpret_cast<int*>(sharedMemoryBase)[laneId];
        *tokenSelectedSlotsPtr = selectedSlot;
        if (recvFieldInfo.expertScales != nullptr)
        {
            // Scales sit right after the topK slot ints in shared memory.
            float scale = reinterpret_cast<float*>(sharedMemoryBase)[laneId + topK];
            *scalePtr = scale;
        }
    }
}
|
|
|
|
// Will commit 1 group, for all non-basic fields
// Writes all fields of one token from shared memory back to global memory: basic fields via
// plain stores, remaining fields via bulk S2G copies committed as one group (see waitS2GBulkRead).
template <bool HAS_BASIC_FIELDS = true, int FIELD_COUNT = MOE_COMM_FIELD_MAX_COUNT>
__device__ __forceinline__ void s2gAllFields(FusedMoeFieldInfo const& recvFieldInfo,
    MoeExpertParallelInfo const& expertParallelInfo, int dataIndex, uint8_t* sharedMemoryBase, int warpId, int laneId)
{
    if (HAS_BASIC_FIELDS)
    {
        s2gBasicFields(recvFieldInfo, expertParallelInfo, dataIndex, sharedMemoryBase, warpId, laneId);
        // All lanes finish their basic-field stores before the bulk copies are issued.
        __syncwarp();
    }
#pragma unroll
    for (int i = 0; i < FIELD_COUNT; i++)
    {
        startFieldS2G(recvFieldInfo.fieldsInfo[i], dataIndex, sharedMemoryBase, warpId, laneId);
    }
    // Batch every bulk copy issued above into a single group.
    cp_async_bulk_commit_group();
}
|
|
|
|
template <int FIELD_COUNT, bool HAS_BASIC_FIELD = true, bool LOW_PRECISION = false>
|
|
class SingleChannelCommunicator
|
|
{
|
|
public:
|
|
__device__ __forceinline__ SingleChannelCommunicator(FusedMoeFieldInfo const& fieldInfo,
|
|
MoeExpertParallelInfo const& expertParallelInfo, MoeSingleCommMeta const& commMeta,
|
|
FusedMoeWorkspace const& workspace, FusedMoeWorldInfo const& worldInfo, FusedMoePairInfo const& pairInfo,
|
|
uint64_t* smemBar, uint8_t* shmemBase)
|
|
: mFieldInfo(fieldInfo)
|
|
, mExpertParallelInfo(expertParallelInfo)
|
|
, mCommMeta(commMeta)
|
|
, mWorkspace(workspace)
|
|
, mWorldInfo(worldInfo)
|
|
, mPairInfo(pairInfo)
|
|
, mSmemBar(smemBar)
|
|
, mShmemBase(shmemBase)
|
|
{
|
|
if constexpr (LOW_PRECISION)
|
|
{
|
|
static_assert(FIELD_COUNT == 1, "Low precision alltoall only support 1 field");
|
|
}
|
|
|
|
mWarpId = threadIdx.x / WARP_SIZE;
|
|
mLaneId = threadIdx.x % WARP_SIZE;
|
|
|
|
mFifoBasePtr = mWorkspace.getFifoBasePtr(mWorldInfo, mPairInfo);
|
|
mSenderSideFifoInfo = mWorkspace.getSenderSideFifoInfo(mWorldInfo, mPairInfo);
|
|
mReceiverSideFifoInfo = mWorkspace.getReceiverSideFifoInfo(mWorldInfo, mPairInfo);
|
|
|
|
mSingleTransfer128ByteCount = mCommMeta.getTransfer128ByteCount();
|
|
mSingleCompactData128ByteCount = mCommMeta.getCompactData128ByteCount();
|
|
// initialize as need new Entry first
|
|
mFifoEntry128ByteIndexBase = kFifoEntry128ByteCount;
|
|
mFifoEntryIndex = -1;
|
|
|
|
initSmemBar(mSmemBar, mLaneId);
|
|
}
|
|
|
|
__device__ __forceinline__ uint64_t* getFifoEntryPtr() const
|
|
{
|
|
return mFifoBasePtr + mFifoEntryIndex * kFifoEntrySizeInU64;
|
|
}
|
|
|
|
__device__ __forceinline__ bool needNewEntry() const
|
|
{
|
|
return mFifoEntry128ByteIndexBase + mSingleTransfer128ByteCount > kFifoEntry128ByteCount;
|
|
}
|
|
|
|
__device__ __forceinline__ void nextToken()
|
|
{
|
|
mFifoEntry128ByteIndexBase += mSingleTransfer128ByteCount;
|
|
}
|
|
|
|
__device__ __forceinline__ void senderInitFifo()
|
|
{
|
|
mHead = mSenderSideFifoInfo->head;
|
|
mTail = mSenderSideFifoInfo->tail;
|
|
}
|
|
|
|
__device__ __forceinline__ void receiverInitFifo()
|
|
{
|
|
mHead = mReceiverSideFifoInfo->head;
|
|
mTail = mReceiverSideFifoInfo->tail;
|
|
}
|
|
|
|
/*
|
|
* Head | 0 | 1 | 2 | 3 | 4 | 4 | 4 | 4 | 4 | 5 |
|
|
* Tail | 0 | 0 | 0 | 0 | 0 | 1 | 2 | 3 | 4 | 4 |
|
|
* Writable | Y | Y | Y | Y | N | Y | Y | Y | Y | Y |
|
|
* Readable | N | Y | Y | Y | Y | Y | Y | Y | N | Y |
|
|
*/
|
|
|
|
__device__ __forceinline__ void waitEntryWritable()
|
|
{
|
|
while (mTail + kFifoDepth <= mHead)
|
|
{
|
|
mTail = mSenderSideFifoInfo->tail;
|
|
}
|
|
}
|
|
|
|
__device__ __forceinline__ void updateWriteEntry()
|
|
{
|
|
__syncwarp();
|
|
mSenderSideFifoInfo->head = mHead;
|
|
}
|
|
|
|
__device__ __forceinline__ void waitEntryReadable()
|
|
{
|
|
// always readable as long as flag matches.
|
|
}
|
|
|
|
__device__ __forceinline__ void updateReadEntry()
|
|
{
|
|
mReceiverSideFifoInfo->tail = mTail;
|
|
mSenderSideFifoInfo->tail = mTail;
|
|
}
|
|
|
|
__device__ __forceinline__ void newSendEntry()
|
|
{
|
|
mFifoEntryIndex = mHead % kFifoDepth;
|
|
mFifoEntry128ByteIndexBase = 0;
|
|
waitEntryWritable();
|
|
__syncwarp();
|
|
}
|
|
|
|
__device__ __forceinline__ void newReceiveEntry()
|
|
{
|
|
mFifoEntryIndex = mTail % kFifoDepth;
|
|
mFifoEntry128ByteIndexBase = 0;
|
|
waitEntryReadable();
|
|
__syncwarp();
|
|
}
|
|
|
|
__device__ __forceinline__ void doSend(int tokenCount, int* sendIndexMapping)
|
|
{
|
|
senderInitFifo();
|
|
|
|
int sendIndex = mPairInfo.channel;
|
|
uint32_t phaseParity = 0;
|
|
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
|
|
cudaGridDependencySynchronize();
|
|
cudaTriggerProgrammaticLaunchCompletion();
|
|
#endif
|
|
for (; sendIndex < tokenCount; sendIndex += mPairInfo.runChannelCount)
|
|
{
|
|
int tokenIndex = sendIndexMapping == nullptr ? sendIndex : sendIndexMapping[sendIndex];
|
|
tensorrt_llm::kernels::fused_moe_impl::g2sAllFields<HAS_BASIC_FIELD, FIELD_COUNT>(
|
|
mFieldInfo, mExpertParallelInfo, tokenIndex, mShmemBase, mWarpId, mLaneId, mSmemBar);
|
|
if (needNewEntry())
|
|
{
|
|
if (mFifoEntryIndex >= 0)
|
|
{
|
|
// not first entry, update FIFO info from last entry.
|
|
mHead++;
|
|
updateWriteEntry();
|
|
}
|
|
newSendEntry();
|
|
}
|
|
tensorrt_llm::kernels::fused_moe_impl::waitG2SAllFields<HAS_BASIC_FIELD>(mSmemBar, &phaseParity);
|
|
tensorrt_llm::kernels::fused_moe_impl::packAllFields<FIELD_COUNT>(
|
|
mFieldInfo, tokenIndex, mShmemBase, mLaneId);
|
|
|
|
if constexpr (LOW_PRECISION)
|
|
{
|
|
// quantize here.
|
|
int alignedUnitBit = mFieldInfo.fieldsInfo[0].alignedUnitBit;
|
|
int alignedUnitCount = mFieldInfo.fieldsInfo[0].alignedUnitCount;
|
|
int sizeInBytes = alignedUnitCount * (1 << alignedUnitBit);
|
|
uint8_t* sharedMemoryCompact = mShmemBase + mFieldInfo.fieldsInfo[0].getCompactShmOffset();
|
|
cudaDataType_t originalDataType = mFieldInfo.fieldsInfo[0].originalDataType;
|
|
|
|
switch (originalDataType)
|
|
{
|
|
case CUDA_R_16BF:
|
|
quantize_nvfp4_sharedmem<__nv_bfloat16>(sharedMemoryCompact, sizeInBytes, mLaneId);
|
|
break;
|
|
case CUDA_R_16F: quantize_nvfp4_sharedmem<half>(sharedMemoryCompact, sizeInBytes, mLaneId); break;
|
|
default: break;
|
|
}
|
|
}
|
|
|
|
FusedMoeProto::protoPack(
|
|
mShmemBase, mHead, mSingleCompactData128ByteCount, mFifoEntry128ByteIndexBase, mWarpId, mLaneId);
|
|
|
|
tensorrt_llm::kernels::fused_moe_impl::startWorkspaceS2GReg(getFifoEntryPtr(), mShmemBase,
|
|
mSingleTransfer128ByteCount, mFifoEntry128ByteIndexBase, mWarpId, mLaneId);
|
|
|
|
// tensorrt_llm::kernels::fused_moe_impl::waitS2GBulkRead();
|
|
|
|
nextToken();
|
|
}
|
|
if (mFifoEntry128ByteIndexBase > 0)
|
|
{
|
|
mHead++;
|
|
updateWriteEntry();
|
|
}
|
|
}
|
|
|
|
__device__ __forceinline__ void rearmFifoBuffer()
|
|
{
|
|
constexpr int kUint32CountPer128Byte = 128 / sizeof(uint32_t);
|
|
uint32_t* fifoPtr = reinterpret_cast<uint32_t*>(getFifoEntryPtr());
|
|
fifoPtr += mFifoEntry128ByteIndexBase * kUint32CountPer128Byte;
|
|
|
|
FusedMoeProto::rearm(fifoPtr, mTail, mSingleTransfer128ByteCount, mFifoEntry128ByteIndexBase, mWarpId, mLaneId);
|
|
__syncwarp();
|
|
}
|
|
|
|
// Receiver side of one channel: drains the tokens assigned to this channel from the
// sender's FIFO, strips the wire protocol, optionally dequantizes (NVFP4 -> 16-bit),
// and scatters the fields back to global memory.
// recvIndexMapping (optional) remaps the dense receive index to the destination token index.
__device__ __forceinline__ void doReceive(int tokenCount, int* recvIndexMapping)
{
    receiverInitFifo();
    // Tokens are strided over channels: this channel handles indices
    // channel, channel + runChannelCount, channel + 2 * runChannelCount, ...
    int recvIndex = mPairInfo.channel;
    uint32_t phaseParity = 0;
    bool needRelease = false; // a consumed FIFO entry still needs its read position published
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
    // Programmatic dependent launch: wait for the upstream kernel's output, then allow
    // the downstream kernel to start early.
    cudaGridDependencySynchronize();
    cudaTriggerProgrammaticLaunchCompletion();
#endif
    for (; recvIndex < tokenCount; recvIndex += mPairInfo.runChannelCount)
    {
        int tokenIndex = recvIndexMapping == nullptr ? recvIndex : recvIndexMapping[recvIndex];
        int loaded128ByteCount = 0;
        if (needNewEntry())
        {
            if (mFifoEntryIndex >= 0)
            {
                // not first entry, update FIFO info from last entry.
                mTail++;
                needRelease = true;
            }
            newReceiveEntry();
        }
        // Keep pulling the FIFO entry into shared memory until the protocol confirms the
        // whole transfer for this token has arrived.
        while (loaded128ByteCount < mSingleTransfer128ByteCount)
        {
            tensorrt_llm::kernels::fused_moe_impl::startWorkspaceG2S(mShmemBase, getFifoEntryPtr(),
                mSingleTransfer128ByteCount, mFifoEntry128ByteIndexBase, loaded128ByteCount, mSmemBar, mWarpId,
                mLaneId);
            if (needRelease)
            {
                // Publish the previous entry's release while the G2S copy is in flight.
                updateReadEntry();
                needRelease = false;
            }
            smemBarWait(mSmemBar, &phaseParity);
            loaded128ByteCount += FusedMoeProto::template checkDataReceivedInShm<false>(mShmemBase, mTail,
                mSingleTransfer128ByteCount, mFifoEntry128ByteIndexBase, loaded128ByteCount, mWarpId, mLaneId);
        }

        // Strip the protocol framing, leaving the compact payload in shared memory.
        FusedMoeProto::protoUnpack(mShmemBase, mTail, mSingleCompactData128ByteCount, mFifoEntry128ByteIndexBase,
            loaded128ByteCount, mWarpId, mLaneId);

        if constexpr (LOW_PRECISION)
        {
            // Dequantize the NVFP4 payload back to the field's original 16-bit type in place.
            int alignedUnitBit = mFieldInfo.fieldsInfo[0].alignedUnitBit;
            int alignedUnitCount = mFieldInfo.fieldsInfo[0].alignedUnitCount;
            int sizeInBytes = alignedUnitCount * (1 << alignedUnitBit);
            uint8_t* sharedMemoryCompact = mShmemBase + mFieldInfo.fieldsInfo[0].getCompactShmOffset();
            cudaDataType_t originalDataType = mFieldInfo.fieldsInfo[0].originalDataType;

            switch (originalDataType)
            {
            case CUDA_R_16BF:
                dequantize_nvfp4_sharedmem<__nv_bfloat16>(sharedMemoryCompact, sizeInBytes, mLaneId);
                break;
            case CUDA_R_16F: dequantize_nvfp4_sharedmem<half>(sharedMemoryCompact, sizeInBytes, mLaneId); break;
            default: break; // only half/bf16 are valid in low-precision mode (asserted at meta fill)
            }
        }

        tensorrt_llm::kernels::fused_moe_impl::unpackAllFields<FIELD_COUNT>(
            mFieldInfo, tokenIndex, mShmemBase, mLaneId);
        tensorrt_llm::kernels::fused_moe_impl::s2gAllFields<HAS_BASIC_FIELD, FIELD_COUNT>(
            mFieldInfo, mExpertParallelInfo, tokenIndex, mShmemBase, mWarpId, mLaneId);
        // Make sure the bulk S2G copies are done reading shared memory before it is reused.
        tensorrt_llm::kernels::fused_moe_impl::waitS2GBulkRead();

        // Re-arm the consumed FIFO region for reuse by a later transfer.
        rearmFifoBuffer();
        nextToken();
    }
    // Release the final (possibly partially consumed) FIFO entry.
    if (mFifoEntry128ByteIndexBase > 0)
    {
        mTail++;
        updateReadEntry();
    }
}
|
|
|
|
private:
    // FIFO entry geometry derived from FusedMoeCommunicator compile-time constants.
    static constexpr int kFifoEntrySizeInU64 = FusedMoeCommunicator::FIFO_ENTRY_BYTES / sizeof(uint64_t);
    static constexpr int kFifoEntry128ByteCount = FusedMoeCommunicator::FIFO_ENTRY_128_BYTE_COUNT;
    static constexpr int kFifoDepth = FusedMoeCommunicator::FIFO_DEPTH;

    // Per-channel configuration captured at construction.
    FusedMoeFieldInfo mFieldInfo;              // field layout of the tokens being communicated
    MoeExpertParallelInfo mExpertParallelInfo;
    MoeSingleCommMeta mCommMeta;               // per-token compact/transfer sizes
    FusedMoeWorkspace mWorkspace;
    FusedMoeWorldInfo mWorldInfo;
    FusedMoePairInfo mPairInfo;                // sender/receiver ranks + channel assignment

    uint64_t* mSmemBar;                        // shared-memory barrier used to wait on G2S copies
    uint8_t* mShmemBase;                       // this warp's shared-memory staging area

    int mLaneId;
    int mWarpId;

    // FIFO state in global workspace memory.
    uint64_t* mFifoBasePtr;
    SenderSideFifoInfo* mSenderSideFifoInfo;
    ReceiverSideFifoInfo* mReceiverSideFifoInfo;

    // Monotonic FIFO positions: mHead advances on the send path, mTail on the receive path.
    int64_t mHead;
    int64_t mTail;

    // Sizes in 128-byte units for one token's on-wire transfer and its compact payload.
    int mSingleTransfer128ByteCount;
    int mSingleCompactData128ByteCount;
    int mFifoEntry128ByteIndexBase;            // current 128-byte offset within the FIFO entry
    int mFifoEntryIndex;                       // current entry index; negative before the first entry
};
|
|
|
|
// Main all-to-all kernel: one warp ("group") serves one (peer rank, channel) pair.
// Grid: x spans peer ranks in chunks of blockDim.y, y = channel, z selects role
// (z == 0 -> sender CTAs, z == 1 -> receiver CTAs).
// Dynamic shared memory: getSingleShmSize() bytes per group; allWarpSmemBar provides
// one barrier slot per group.
template <int FIELD_COUNT = MOE_COMM_FIELD_MAX_COUNT, bool LOW_PRECISION = false>
__global__ void moeAllToAllKernel(FusedMoeCommKernelParam params, FusedMoeWorkspace workspace, bool hasBasicFields)
{
    __shared__ uint64_t allWarpSmemBar[32];
    extern __shared__ int4 allWarpShm[];

    bool isSender = blockIdx.z == 0;
    int runChannelCount = gridDim.y;
    int group = threadIdx.y;
    // Sender CTAs walk the send index lists, receiver CTAs the receive index lists.
    SendRecvIndices dataIndices = isSender ? params.sendIndices : params.recvIndices;

    FusedMoePairInfo pairInfo;
    // Each group (one warp row of the CTA) talks to exactly one peer rank.
    int peerRank = blockIdx.x * blockDim.y + group;
    if (peerRank >= params.worldInfo.epInfo.epSize)
    {
        return;
    }
    int tokenCount;
    int* groupStartPtr = dataIndices.getGroupStart(peerRank, tokenCount);
    if (tokenCount == 0)
    {
        return; // nothing to exchange with this peer
    }

    pairInfo.channel = blockIdx.y;
    pairInfo.runChannelCount = runChannelCount;
    pairInfo.senderRank = isSender ? params.worldInfo.epInfo.epRank : peerRank;
    pairInfo.receiverRank = isSender ? peerRank : params.worldInfo.epInfo.epRank;

    // hasBasicFields is a kernel-uniform runtime flag; both branches instantiate the
    // communicator with the matching HAS_BASIC_FIELD template parameter.
    if (isSender)
    {
        int singleShmSize = params.sendCommMeta.getSingleShmSize();
        if (hasBasicFields)
        {
            SingleChannelCommunicator<FIELD_COUNT, true, LOW_PRECISION> comm(params.sendFieldInfo,
                params.expertParallelInfo, params.sendCommMeta, workspace, params.worldInfo, pairInfo,
                allWarpSmemBar + group, reinterpret_cast<uint8_t*>(allWarpShm) + singleShmSize * group);
            comm.doSend(tokenCount, groupStartPtr);
        }
        else
        {
            SingleChannelCommunicator<FIELD_COUNT, false, LOW_PRECISION> comm(params.sendFieldInfo,
                params.expertParallelInfo, params.sendCommMeta, workspace, params.worldInfo, pairInfo,
                allWarpSmemBar + group, reinterpret_cast<uint8_t*>(allWarpShm) + singleShmSize * group);
            comm.doSend(tokenCount, groupStartPtr);
        }
    }
    else
    {
        int singleShmSize = params.recvCommMeta.getSingleShmSize();
        if (hasBasicFields)
        {
            SingleChannelCommunicator<FIELD_COUNT, true, LOW_PRECISION> comm(params.recvFieldInfo,
                params.expertParallelInfo, params.recvCommMeta, workspace, params.worldInfo, pairInfo,
                allWarpSmemBar + group, reinterpret_cast<uint8_t*>(allWarpShm) + singleShmSize * group);
            comm.doReceive(tokenCount, groupStartPtr);
        }
        else
        {
            SingleChannelCommunicator<FIELD_COUNT, false, LOW_PRECISION> comm(params.recvFieldInfo,
                params.expertParallelInfo, params.recvCommMeta, workspace, params.worldInfo, pairInfo,
                allWarpSmemBar + group, reinterpret_cast<uint8_t*>(allWarpShm) + singleShmSize * group);
            comm.doReceive(tokenCount, groupStartPtr);
        }
    }
}
|
|
|
|
// Returns the largest dynamic shared-memory size (bytes) a moeAllToAllKernel launch can
// opt into on the current device: per-block opt-in limit minus the kernel's static usage.
int computeMoeAlltoallMaxDynamicSharedMemorySize()
{
    int device = -1;
    TLLM_CUDA_CHECK(cudaGetDevice(&device));

    int optinLimit = 0;
    TLLM_CUDA_CHECK(cudaDeviceGetAttribute(&optinLimit, cudaDevAttrMaxSharedMemoryPerBlockOptin, device));

    // The static shared memory (allWarpSmemBar) does not depend on the template
    // arguments, so the FIELD_COUNT == 1 instantiation is representative.
    cudaFuncAttributes funcAttr{};
    TLLM_CUDA_CHECK(cudaFuncGetAttributes(&funcAttr, (void const*) moeAllToAllKernel<1>));

    return optinLimit - static_cast<int>(funcAttr.sharedSizeBytes);
}
|
|
|
|
} // namespace fused_moe_impl
|
|
|
|
// Fills per-token communication sizes: the uncompact shared-memory staging size, the
// compact payload size, and the protocol-framed transfer size (all 128-byte aligned).
// In low-precision mode the compact size is the NVFP4 quantized layout instead of the
// plain compact layout.
void FusedMoeFieldInfo::fillMetaInfo(
    MoeSingleCommMeta* singleCommMeta, int topK, bool hasScales, bool hasBasicFields, bool isLowPrecision) const
{
    singleCommMeta->singleUncompactAlignedSize = computeSingleUncompactSize(topK, hasScales, hasBasicFields);

    if (isLowPrecision)
    {
        // NVFP4 path supports exactly one field of 16-bit elements (half or bf16).
        assert(fieldCount == 1);
        assert(fieldsInfo[0].originalDataType == CUDA_R_16F || fieldsInfo[0].originalDataType == CUDA_R_16BF);

        auto alignment128 = MoeCommFieldInfo::BYTES_PER_128B_BLOCK;
        auto originalFieldSize = fieldsInfo[0].alignedUnitCount * (1 << fieldsInfo[0].alignedUnitBit);

        // 16-bit source elements: element count is half the byte count.
        int numElements = originalFieldSize / 2;
        // Each warp-wide group quantizes WARP_SIZE * 16 elements ...
        int numGroups = (numElements + WARP_SIZE * 16 - 1) / (WARP_SIZE * 16);
        // ... into WARP_SIZE * 8 bytes of packed e2m1 data plus WARP_SIZE bytes of E4M3 scales.
        int sizePerGroupInBytes = (WARP_SIZE * 16 / 2 + WARP_SIZE * 1);

        // A trailing 4-byte global float scale follows the groups; round up to 128 bytes.
        int totalSize = numGroups * sizePerGroupInBytes + 4;
        singleCommMeta->singleCompactAlignedSize = (totalSize + alignment128 - 1) / alignment128 * alignment128;
    }
    else
    {
        singleCommMeta->singleCompactAlignedSize = computeSingleCompactSize(topK, hasScales, hasBasicFields);
    }

    singleCommMeta->singleTransferAlignedSize
        = FusedMoeProto::computeProtoTransfer128ByteAlignedSize(singleCommMeta->singleCompactAlignedSize);
}
|
|
|
|
// Lays out each field's compact offset (in 16-byte units) within a token's packed buffer.
// Basic fields (topK selected slots + optional topK scales), when present, occupy the
// front of the buffer rounded up to 16 bytes. Unused field slots are marked as such.
void FusedMoeFieldInfo::fillFieldPlacementInfo(int topK, bool hasBasicFields)
{
    int const block16 = MoeCommFieldInfo::BYTES_PER_16B_BLOCK;

    int offset = 0;
    if (hasBasicFields)
    {
        int rawBytes = topK * sizeof(int) + (expertScales != nullptr ? topK * sizeof(float) : 0);
        // Round the basic-field region up to a 16-byte boundary.
        offset = (rawBytes + block16 - 1) / block16 * block16;
    }

    int nextUnalignedIndex = 0;
    for (int i = 0; i < fieldCount; i++)
    {
        auto& field = fieldsInfo[i];
        field.compact16BOffset = offset / block16;
        offset += field.getFieldCompactSize();
        field.unalignedFieldIndex = nextUnalignedIndex;
        // Fields whose alignment unit is smaller than 16 bytes consume an "unaligned" slot.
        if (field.alignedUnitBit < 4)
        {
            nextUnalignedIndex++;
        }
    }

    // Mark the remaining field slots as unused.
    for (int i = fieldCount; i < MOE_COMM_FIELD_MAX_COUNT; i++)
    {
        fieldsInfo[i].setUnused();
    }
}
|
|
|
|
// Initializes this rank's slice of the workspace: the FIFO payload area is filled with
// the protocol's 32-bit "initialized" marker, and the sender/receiver bookkeeping
// structures that follow it are zeroed.
void FusedMoeWorkspace::initializeLocalWorkspace(FusedMoeWorldInfo const& worldInfo)
{
    int const epSize = worldInfo.epInfo.epSize;
    uint64_t* localWorkspacePtr = workspacePtr + worldInfo.epInfo.epRank * rankStrideInU64;

    size_t const fifoBytes = static_cast<size_t>(FusedMoeCommunicator::FIFO_TOTAL_BYTES) * epSize * channelCount;
    size_t const infoBytes = (sizeof(SenderSideFifoInfo) + sizeof(ReceiverSideFifoInfo)) * epSize * channelCount;

    // FIFO payload: pattern-fill with the protocol marker, 32 bits at a time.
    TLLM_CU_CHECK(cuMemsetD32(
        reinterpret_cast<CUdeviceptr>(localWorkspacePtr), FusedMoeProto::INITIALIZED_VALUE, fifoBytes / sizeof(uint32_t)));
    // Bookkeeping (sender then receiver info arrays): zero-initialize.
    TLLM_CUDA_CHECK(cudaMemset(reinterpret_cast<uint8_t*>(localWorkspacePtr) + fifoBytes, 0, infoBytes));
}
|
|
|
|
// Host launcher for the fused MoE all-to-all kernel.
// Picks the template instantiation from the maximum field count (and low-precision flag),
// then searches downward for the largest per-CTA group count whose dynamic shared memory
// still allows at least one resident CTA per SM.
void moeAllToAll(FusedMoeCommKernelParam params, FusedMoeWorkspace workspace, cudaStream_t stream)
{
    // Basic fields (selected slots / scales) are present iff the slot pointer is set.
    bool hasBasicFields = params.sendFieldInfo.tokenSelectedSlots != nullptr;
    int warpSendShmSize = params.sendCommMeta.getSingleShmSize();
    int warpRecvShmSize = params.recvCommMeta.getSingleShmSize();
    int warpShmSize = warpSendShmSize;
    int epSize = params.worldInfo.epInfo.epSize;
    TLLM_CHECK_WITH_INFO(warpSendShmSize == warpRecvShmSize, "warpSendShmSize(%d) not same as warpRecvShmSize(%d)",
        warpSendShmSize, warpRecvShmSize);
    int maxGroupCountPerCta = std::min(params.worldInfo.epInfo.epSize, FusedMoeCommunicator::MAX_GROUP_COUNT_PER_BLOCK);
    // Device-dependent and kernel-static: compute once per process.
    static int maxDynamicShmSize = fused_moe_impl::computeMoeAlltoallMaxDynamicSharedMemorySize();
    int groupCountPerCta = std::min(maxGroupCountPerCta, maxDynamicShmSize / warpShmSize);

    int maxFieldCount = std::max(params.sendFieldInfo.fieldCount, params.recvFieldInfo.fieldCount);
    TLLM_CHECK_WITH_INFO(params.isLowPrecision == false || maxFieldCount == 1, "low precision only support 1 field");

    // Map (fieldCount, lowPrecision) to the matching kernel instantiation. The low
    // precision variant only exists for the single-field case (checked above).
    auto getFunc = [](int fieldCount, bool lowPrecision)
    {
        switch (fieldCount)
        {
        case 1:
            if (lowPrecision)
                return fused_moe_impl::moeAllToAllKernel<1, true>;
            else
                return fused_moe_impl::moeAllToAllKernel<1>;
        case 2: return fused_moe_impl::moeAllToAllKernel<2>;
        case 3: return fused_moe_impl::moeAllToAllKernel<3>;
        case 4: return fused_moe_impl::moeAllToAllKernel<4>;
        case 5: return fused_moe_impl::moeAllToAllKernel<5>;
        case 6: return fused_moe_impl::moeAllToAllKernel<6>;
        case 7: return fused_moe_impl::moeAllToAllKernel<7>;
        case 8: return fused_moe_impl::moeAllToAllKernel<8>;
        default: return fused_moe_impl::moeAllToAllKernel<8>;
        }
        return fused_moe_impl::moeAllToAllKernel<8>; // unreachable; all paths return above
    };
    auto* kernelFn = getFunc(maxFieldCount, params.isLowPrecision);

    // Dynamic shared memory beyond the 48KB default requires an explicit opt-in.
    if (groupCountPerCta * warpShmSize > 48 * 1024)
    {
        TLLM_CUDA_CHECK(cudaFuncSetAttribute(
            kernelFn, cudaFuncAttributeMaxDynamicSharedMemorySize, groupCountPerCta * warpShmSize));
    }
    // Shrink the group count until at least one CTA with that much shared memory fits on an SM.
    for (; groupCountPerCta > 0; groupCountPerCta--)
    {
        int dynamicShmSize = groupCountPerCta * warpShmSize;
        int numBlocks = 0;
        if (cudaOccupancyMaxActiveBlocksPerMultiprocessor(
                &numBlocks, kernelFn, WARP_SIZE * groupCountPerCta, dynamicShmSize)
            != cudaSuccess)
        {
            continue;
        }
        if (numBlocks >= 1)
        {
            break;
        }
    }
    TLLM_CHECK_WITH_INFO(
        groupCountPerCta >= 1, "computed groupCount=%d, warpShmSize=%d", groupCountPerCta, warpShmSize);
    // Rebalance so peer ranks spread evenly over the CTAs of each channel.
    int ctaPerChannel = (epSize + groupCountPerCta - 1) / groupCountPerCta;
    groupCountPerCta = (epSize + ctaPerChannel - 1) / ctaPerChannel;
    int totalDynamicShmSize = warpShmSize * groupCountPerCta;

    dim3 block = FusedMoeCommunicator::getLaunchBlockDim(groupCountPerCta);
    dim3 grid = FusedMoeCommunicator::getLaunchGridDim(params.worldInfo.epInfo.epSize, groupCountPerCta);
    launchWithPdlWhenEnabled(
        "moeAllToAll", kernelFn, grid, block, totalDynamicShmSize, stream, params, workspace, hasBasicFields);
    TLLM_CUDA_CHECK(cudaGetLastError());
}
|
|
|
|
// Cap on the number of SMs usable by MoE comm kernels; -1 means no explicit cap has been set.
int FusedMoeCommunicator::maxSmCount = -1;
// NOTE(review): presumably latched once the cap has been consumed by a launch —
// confirm against FusedMoeCommunicator::setMaxUsableSmCount.
bool FusedMoeCommunicator::maxSmCountUsed = false;
|
|
|
|
// Public entry point: forwards to FusedMoeCommunicator to cap how many SMs the MoE
// communication kernels may use.
void setMaxUsableSmCount(int smCount)
{
    FusedMoeCommunicator::setMaxUsableSmCount(smCount);
}
|
|
|
|
// Returns the per-rank workspace size in bytes for a given EP world size; the size
// depends on how many communication channels that EP size uses.
size_t getFusedMoeCommWorkspaceSize(int epSize)
{
    int const channelCount = FusedMoeCommunicator::getMoeCommChannelCount(epSize);
    return FusedMoeWorkspace::computeWorkspaceSizePreRank(epSize, channelCount);
}
|
|
|
|
// Populates a FusedMoeWorkspace descriptor: base pointer, per-rank stride (in uint64_t
// units), and the channel count implied by the EP world size.
void constructWorkspace(FusedMoeWorkspace* workspace, uint64_t* workspacePtr, size_t rankStrideInU64, int epSize)
{
    workspace->channelCount = FusedMoeCommunicator::getMoeCommChannelCount(epSize);
    workspace->workspacePtr = workspacePtr;
    workspace->rankStrideInU64 = rankStrideInU64;
}
|
|
|
|
// Public entry point: initializes this rank's slice of the workspace (FIFO marker fill
// plus zeroed bookkeeping); see FusedMoeWorkspace::initializeLocalWorkspace.
void initializeFusedMoeLocalWorkspace(FusedMoeWorkspace* workspace, FusedMoeWorldInfo const& worldInfo)
{
    workspace->initializeLocalWorkspace(worldInfo);
}
|
|
|
|
namespace fused_moe_comm_tests
|
|
{
|
|
|
|
// Test kernel: one warp stages one token's fields global->shared (G2S), then dumps the
// raw staged shared memory to shmDump so the host can verify the layout.
// Grid: x = ceil(tokenCount / warpsPerBlock). Block: x = WARP_SIZE * warpsPerBlock.
// Dynamic shared memory: singleUncompactAlignedSize bytes per warp.
__global__ void g2sKernel(FusedMoeFieldInfo allFieldInfo, MoeExpertParallelInfo expertParallelInfo,
    MoeSingleCommMeta singleCommMeta, int tokenCount, int* shmDump, bool hasBasicFields)
{
    __shared__ uint64_t allWarpSmemBar[32]; // one barrier slot per warp
    extern __shared__ int4 allWarpShm[];
    int laneId = threadIdx.x % WARP_SIZE;
    int warpId = threadIdx.x / WARP_SIZE;
    int warpCount = blockDim.x / WARP_SIZE;
    int tokenIndex = warpId + blockIdx.x * warpCount;
    if (tokenIndex >= tokenCount)
    {
        return;
    }

    int singleShmSize = singleCommMeta.singleUncompactAlignedSize;

    initSmemBar(&allWarpSmemBar[warpId], laneId);
    uint32_t phaseParity = 0;

    uint8_t* sharedMemoryBase = reinterpret_cast<uint8_t*>(allWarpShm) + singleShmSize * warpId;

    // hasBasicFields is uniform across the grid, so the branch is not divergent;
    // it only selects the template instantiation.
    if (hasBasicFields)
    {
        tensorrt_llm::kernels::fused_moe_impl::g2sAllFields<true>(
            allFieldInfo, expertParallelInfo, tokenIndex, sharedMemoryBase, warpId, laneId, &allWarpSmemBar[warpId]);
        tensorrt_llm::kernels::fused_moe_impl::waitG2SAllFields<true>(&allWarpSmemBar[warpId], &phaseParity);
    }
    else
    {
        tensorrt_llm::kernels::fused_moe_impl::g2sAllFields<false>(
            allFieldInfo, expertParallelInfo, tokenIndex, sharedMemoryBase, warpId, laneId, &allWarpSmemBar[warpId]);
        tensorrt_llm::kernels::fused_moe_impl::waitG2SAllFields<false>(&allWarpSmemBar[warpId], &phaseParity);
    }

    // Dump the staged bytes (as ints) for host-side layout validation.
    for (int offset = laneId; offset < singleShmSize / sizeof(int); offset += WARP_SIZE)
    {
        shmDump[tokenIndex * singleShmSize / sizeof(int) + offset] = reinterpret_cast<int*>(sharedMemoryBase)[offset];
    }
}
|
|
|
|
// Test harness launcher for g2sKernel: stages each token's fields into shared memory
// and dumps the staged bytes to shmDump. One warp per token.
void launchSingleG2S(FusedMoeFieldInfo const& sendFieldInfo, MoeExpertParallelInfo const& expertParallelInfo,
    int tokenCount, int* shmDump, int warpsPerBlock, bool hasBasicFields, cudaStream_t stream)
{
    bool const hasScales = sendFieldInfo.expertScales != nullptr;
    int const perWarpShmBytes
        = sendFieldInfo.computeSingleUncompactSize(expertParallelInfo.topK, hasScales, hasBasicFields);
    int const dynamicShmBytes = perWarpShmBytes * warpsPerBlock;

    MoeSingleCommMeta singleCommMeta;
    sendFieldInfo.fillMetaInfo(&singleCommMeta, expertParallelInfo.topK, hasScales, hasBasicFields, false);

    dim3 const blockDim(WARP_SIZE * warpsPerBlock, 1, 1);
    dim3 const gridDim((tokenCount + warpsPerBlock - 1) / warpsPerBlock, 1, 1);

    // May exceed the 48KB default dynamic shared-memory limit; opt in explicitly.
    TLLM_CUDA_CHECK(cudaFuncSetAttribute(g2sKernel, cudaFuncAttributeMaxDynamicSharedMemorySize, dynamicShmBytes));
    g2sKernel<<<gridDim, blockDim, dynamicShmBytes, stream>>>(
        sendFieldInfo, expertParallelInfo, singleCommMeta, tokenCount, shmDump, hasBasicFields);
    TLLM_CUDA_CHECK(cudaGetLastError());
}
|
|
|
|
// Test kernel: one warp preloads one token's shared-memory image from shmPreload, then
// scatters the fields shared->global (S2G) so the host can verify the store path.
// Grid/block layout matches g2sKernel; dynamic shared memory is one token buffer per warp.
__global__ void s2gKernel(FusedMoeFieldInfo recvFieldInfo, MoeExpertParallelInfo expertParallelInfo,
    MoeSingleCommMeta singleCommMeta, int tokenCount, int* shmPreload, bool hasBasicFields)
{
    extern __shared__ int4 allWarpShm[];
    int laneId = threadIdx.x % WARP_SIZE;
    int warpId = threadIdx.x / WARP_SIZE;
    int warpCount = blockDim.x / WARP_SIZE;
    int tokenIndex = warpId + blockIdx.x * warpCount;
    if (tokenIndex >= tokenCount)
    {
        return;
    }
    int singleShmSize = singleCommMeta.singleUncompactAlignedSize;
    uint8_t* sharedMemoryBase = reinterpret_cast<uint8_t*>(allWarpShm) + singleShmSize * warpId;

    // Preload this token's shared-memory image from global memory.
    for (int offset = laneId; offset < singleShmSize / sizeof(int); offset += WARP_SIZE)
    {
        reinterpret_cast<int*>(sharedMemoryBase)[offset]
            = shmPreload[tokenIndex * singleShmSize / sizeof(int) + offset];
    }
    // Make the preloaded data visible warp-wide before the S2G copies read it.
    __syncwarp();

    if (hasBasicFields)
    {
        tensorrt_llm::kernels::fused_moe_impl::s2gAllFields<true>(
            recvFieldInfo, expertParallelInfo, tokenIndex, sharedMemoryBase, warpId, laneId);
    }
    else
    {
        tensorrt_llm::kernels::fused_moe_impl::s2gAllFields<false>(
            recvFieldInfo, expertParallelInfo, tokenIndex, sharedMemoryBase, warpId, laneId);
    }

    // Wait for the bulk S2G operations to finish reading shared memory.
    tensorrt_llm::kernels::fused_moe_impl::waitS2GBulkRead();
}
|
|
|
|
// Test harness launcher for s2gKernel: preloads each token's shared-memory image from
// shmPreload and scatters the fields back to global memory. One warp per token.
void launchSingleS2G(FusedMoeFieldInfo const& recvFieldInfo, MoeExpertParallelInfo const& expertParallelInfo,
    int tokenCount, int* shmPreload, int warpsPerBlock, bool hasBasicFields, cudaStream_t stream)
{
    bool const hasScales = recvFieldInfo.expertScales != nullptr;
    int const perWarpShmBytes
        = recvFieldInfo.computeSingleUncompactSize(expertParallelInfo.topK, hasScales, hasBasicFields);
    int const dynamicShmBytes = perWarpShmBytes * warpsPerBlock;

    MoeSingleCommMeta singleCommMeta;
    recvFieldInfo.fillMetaInfo(&singleCommMeta, expertParallelInfo.topK, hasScales, hasBasicFields, false);

    dim3 const blockDim(WARP_SIZE * warpsPerBlock, 1, 1);
    dim3 const gridDim((tokenCount + warpsPerBlock - 1) / warpsPerBlock, 1, 1);

    // May exceed the 48KB default dynamic shared-memory limit; opt in explicitly.
    TLLM_CUDA_CHECK(cudaFuncSetAttribute(s2gKernel, cudaFuncAttributeMaxDynamicSharedMemorySize, dynamicShmBytes));
    s2gKernel<<<gridDim, blockDim, dynamicShmBytes, stream>>>(
        recvFieldInfo, expertParallelInfo, singleCommMeta, tokenCount, shmPreload, hasBasicFields);
    TLLM_CUDA_CHECK(cudaGetLastError());
}
|
|
|
|
// Test kernel covering the in-shared-memory path without the FIFO:
// G2S load of token tokenIndex -> pack -> unpack -> S2G store at recvIndexMapping[tokenIndex].
// One warp per token; dynamic shared memory is one token buffer per warp.
__global__ void loopbackKernel(FusedMoeFieldInfo sendFieldInfo, FusedMoeFieldInfo recvFieldInfo,
    MoeExpertParallelInfo expertParallelInfo, MoeSingleCommMeta sendCommMeta, MoeSingleCommMeta recvCommMeta,
    int* recvIndexMapping, int tokenCount, bool hasBasicFields)
{
    __shared__ uint64_t allWarpSmemBar[32];
    extern __shared__ int4 allWarpShm[];
    int laneId = threadIdx.x % WARP_SIZE;
    int warpId = threadIdx.x / WARP_SIZE;
    int warpCount = blockDim.x / WARP_SIZE;
    int tokenIndex = warpId + blockIdx.x * warpCount;
    if (tokenIndex >= tokenCount)
    {
        return;
    }

    // Destination token index for the store side of the loopback.
    int recvTokenIndex = recvIndexMapping[tokenIndex];

    initSmemBar(&allWarpSmemBar[warpId], laneId);
    uint32_t phaseParity = 0;

    int singleShmSize = sendCommMeta.getSingleShmSize();

    uint8_t* sharedMemoryBase = reinterpret_cast<uint8_t*>(allWarpShm) + singleShmSize * warpId;

    // Stage this token's fields into shared memory (completion tracked by the barrier).
    if (hasBasicFields)
    {
        tensorrt_llm::kernels::fused_moe_impl::g2sAllFields<true>(
            sendFieldInfo, expertParallelInfo, tokenIndex, sharedMemoryBase, warpId, laneId, &allWarpSmemBar[warpId]);
    }
    else
    {
        tensorrt_llm::kernels::fused_moe_impl::g2sAllFields<false>(
            sendFieldInfo, expertParallelInfo, tokenIndex, sharedMemoryBase, warpId, laneId, &allWarpSmemBar[warpId]);
    }

    if (hasBasicFields)
    {
        tensorrt_llm::kernels::fused_moe_impl::waitG2SAllFields<true>(&allWarpSmemBar[warpId], &phaseParity);
    }
    else
    {
        tensorrt_llm::kernels::fused_moe_impl::waitG2SAllFields<false>(&allWarpSmemBar[warpId], &phaseParity);
    }

    // Compact the staged fields as the sender would before a transfer.
    tensorrt_llm::kernels::fused_moe_impl::packAllFields(sendFieldInfo, tokenIndex, sharedMemoryBase, laneId);

    tokenIndex = recvTokenIndex; // switch to recvTokenIndex;

    // Expand the compact layout as the receiver would after a transfer.
    tensorrt_llm::kernels::fused_moe_impl::unpackAllFields(recvFieldInfo, tokenIndex, sharedMemoryBase, laneId);

    if (hasBasicFields)
    {
        tensorrt_llm::kernels::fused_moe_impl::s2gAllFields<true>(
            recvFieldInfo, expertParallelInfo, tokenIndex, sharedMemoryBase, warpId, laneId);
    }
    else
    {
        tensorrt_llm::kernels::fused_moe_impl::s2gAllFields<false>(
            recvFieldInfo, expertParallelInfo, tokenIndex, sharedMemoryBase, warpId, laneId);
    }

    // Wait for the bulk S2G copies to finish reading shared memory before exit.
    cp_async_bulk_wait_group_read<0>();
    __syncwarp();
}
|
|
|
|
// G2S -> Pack -> Unpack -> S2G
|
|
// Test harness launcher for loopbackKernel: G2S -> pack -> unpack -> S2G in one kernel,
// with recvIndexMapping permuting the destination token indices. One warp per token.
void launchLoopback(FusedMoeFieldInfo const& sendFieldInfo, FusedMoeFieldInfo const& recvFieldInfo,
    MoeExpertParallelInfo const& expertParallelInfo, int* recvIndexMapping, int tokenCount, int warpsPerBlock,
    bool hasBasicFields, cudaStream_t stream)
{
    MoeSingleCommMeta sendCommMeta, recvCommMeta;
    sendFieldInfo.fillMetaInfo(
        &sendCommMeta, expertParallelInfo.topK, sendFieldInfo.expertScales != nullptr, hasBasicFields, false);
    recvFieldInfo.fillMetaInfo(
        &recvCommMeta, expertParallelInfo.topK, recvFieldInfo.expertScales != nullptr, hasBasicFields, false);

    int const sendShmBytes = sendCommMeta.getSingleShmSize();
    int const recvShmBytes = recvCommMeta.getSingleShmSize();
    TLLM_CHECK_WITH_INFO(sendShmBytes == recvShmBytes, "warpSendShmSize(%d) not same as warpRecvShmSize(%d)",
        sendShmBytes, recvShmBytes);
    int const dynamicShmBytes = sendShmBytes * warpsPerBlock;

    dim3 const blockDim(WARP_SIZE * warpsPerBlock, 1, 1);
    dim3 const gridDim((tokenCount + warpsPerBlock - 1) / warpsPerBlock, 1, 1);

    // May exceed the 48KB default dynamic shared-memory limit; opt in explicitly.
    TLLM_CUDA_CHECK(
        cudaFuncSetAttribute(loopbackKernel, cudaFuncAttributeMaxDynamicSharedMemorySize, dynamicShmBytes));
    loopbackKernel<<<gridDim, blockDim, dynamicShmBytes, stream>>>(sendFieldInfo, recvFieldInfo,
        expertParallelInfo, sendCommMeta, recvCommMeta, recvIndexMapping, tokenCount, hasBasicFields);
    TLLM_CUDA_CHECK(cudaGetLastError());
}
|
|
|
|
// Test kernel: exercises the full FIFO protocol on a single rank (epSize == 1, rank 0
// sends to itself through the workspace FIFO).
// blockIdx.y == 0 CTAs run the sender path, blockIdx.y == 1 CTAs the receiver path;
// each warp drives one independent channel (channel = blockIdx.z * warpCount + warpId).
template <bool HAS_BASIC_FIELD = true>
__global__ void localFifoSendRecvKernel(FusedMoeFieldInfo sendFieldInfo, FusedMoeFieldInfo recvFieldInfo,
    MoeExpertParallelInfo expertParallelInfo, MoeSingleCommMeta sendCommMeta, MoeSingleCommMeta recvCommMeta,
    FusedMoeWorkspace fusedMoeWorkspace, int* sendIndexMapping, int* recvIndexMapping, int tokenCount)
{
    __shared__ uint64_t allWarpSmemBar[32];
    extern __shared__ int4 allWarpShm[];

    // Single-rank world: rank 0 talks to itself.
    FusedMoeWorldInfo worldInfo;
    worldInfo.epInfo.epRank = 0;
    worldInfo.epInfo.epSize = 1;

    int warpId = threadIdx.x / WARP_SIZE;
    int warpCount = blockDim.x / WARP_SIZE;

    FusedMoePairInfo pairInfo;
    pairInfo.senderRank = 0;
    pairInfo.receiverRank = 0;
    pairInfo.channel = blockIdx.z * warpCount + warpId;
    pairInfo.runChannelCount = gridDim.z * warpCount;

    if (blockIdx.y == 0)
    {
        // Sender half of the loop.
        tensorrt_llm::kernels::fused_moe_impl::SingleChannelCommunicator<MOE_COMM_FIELD_MAX_COUNT, HAS_BASIC_FIELD>
            senderComm(sendFieldInfo, expertParallelInfo, sendCommMeta, fusedMoeWorkspace, worldInfo, pairInfo,
                &allWarpSmemBar[warpId],
                reinterpret_cast<uint8_t*>(&allWarpShm[0]) + warpId * sendCommMeta.getSingleShmSize());
        senderComm.doSend(tokenCount, sendIndexMapping);
    }
    else
    {
        // Receiver half of the loop.
        tensorrt_llm::kernels::fused_moe_impl::SingleChannelCommunicator<MOE_COMM_FIELD_MAX_COUNT, HAS_BASIC_FIELD>
            recverComm(recvFieldInfo, expertParallelInfo, recvCommMeta, fusedMoeWorkspace, worldInfo, pairInfo,
                &allWarpSmemBar[warpId],
                reinterpret_cast<uint8_t*>(&allWarpShm[0]) + warpId * recvCommMeta.getSingleShmSize());
        recverComm.doReceive(tokenCount, recvIndexMapping);
    }
}
|
|
|
|
// Test harness launcher for localFifoSendRecvKernel: runs the FIFO send/recv protocol on
// a single rank. gridDim.y == 2 splits CTAs into sender (y == 0) and receiver (y == 1)
// halves; each warp drives one channel, so warpsPerBlock * blockChannelCount channels run.
void launchLocalFifoSendRecv(FusedMoeFieldInfo const& sendFieldInfo, FusedMoeFieldInfo const& recvFieldInfo,
    MoeExpertParallelInfo const& expertParallelInfo, int* sendIndexMapping, int* recvIndexMapping,
    FusedMoeWorkspace fusedMoeWorkspace, int tokenCount, int warpsPerBlock, int blockChannelCount, bool hasBasicFields,
    cudaStream_t stream)
{
    MoeSingleCommMeta sendCommMeta, recvCommMeta;
    sendFieldInfo.fillMetaInfo(
        &sendCommMeta, expertParallelInfo.topK, sendFieldInfo.expertScales != nullptr, hasBasicFields, false);
    recvFieldInfo.fillMetaInfo(
        &recvCommMeta, expertParallelInfo.topK, recvFieldInfo.expertScales != nullptr, hasBasicFields, false);
    int warpSendShmSize = sendCommMeta.getSingleShmSize();
    int warpRecvShmSize = recvCommMeta.getSingleShmSize();
    int warpShmSize = warpSendShmSize;
    TLLM_CHECK_WITH_INFO(warpSendShmSize == warpRecvShmSize, "warpSendShmSize(%d) not same as warpRecvShmSize(%d)",
        warpSendShmSize, warpRecvShmSize);
    dim3 blockDim(WARP_SIZE * warpsPerBlock, 1, 1);
    dim3 gridDim(1, 2, blockChannelCount); // y == 0: sender CTAs, y == 1: receiver CTAs
    // Select the instantiation directly; the previous default-init + unconditional
    // reassignment in both branches left a dead initializer.
    auto* kernelFn = hasBasicFields ? localFifoSendRecvKernel<true> : localFifoSendRecvKernel<false>;
    // May exceed the 48KB default dynamic shared-memory limit; opt in explicitly.
    TLLM_CUDA_CHECK(
        cudaFuncSetAttribute(kernelFn, cudaFuncAttributeMaxDynamicSharedMemorySize, warpShmSize * warpsPerBlock));
    kernelFn<<<gridDim, blockDim, warpShmSize * warpsPerBlock, stream>>>(sendFieldInfo, recvFieldInfo,
        expertParallelInfo, sendCommMeta, recvCommMeta, fusedMoeWorkspace, sendIndexMapping, recvIndexMapping,
        tokenCount);
    TLLM_CUDA_CHECK(cudaGetLastError());
}
|
|
|
|
} // namespace fused_moe_comm_tests
|
|
|
|
} // namespace kernels
|
|
|
|
TRTLLM_NAMESPACE_END
|