[fix] Performance Optimization for MNNVL TwoShot Kernel (#5934)

Signed-off-by: Shiyu Li <shili@nvidia.com>
Co-authored-by: Zongfei Jing <20381269+zongfeijing@users.noreply.github.com>
Shiyu Li 2025-07-16 19:49:51 -07:00 committed by GitHub
parent fe070a0168
commit 6e1aee6fd6
4 changed files with 129 additions and 78 deletions


@@ -61,6 +61,31 @@ inline __device__ __nv_bfloat16 fromFloat<__nv_bfloat16>(float val)
return __float2bfloat16(val);
}
__device__ float4 loadfloat4(void const* ptr)
{
float return_value[4];
asm volatile("ld.volatile.global.v4.f32 {%0, %1, %2, %3}, [%4];\n"
: "=f"(return_value[0]), "=f"(return_value[1]), "=f"(return_value[2]), "=f"(return_value[3])
: "l"(ptr));
return *(float4*) return_value;
}
__device__ __inline__ float2 loadfloat2(void const* ptr)
{
float return_value[2];
asm volatile("ld.volatile.global.v2.f32 {%0, %1}, [%2];\n"
: "=f"(return_value[0]), "=f"(return_value[1])
: "l"(ptr)
: "memory");
return *(float2*) return_value;
}
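
These two helpers back the Lamport-style polling added later in this kernel: an ordinary load can be served from cache and might never observe a peer GPU's write, so the spin loops need a load that re-reads global memory on every iteration, which is what the ld.volatile PTX provides. A minimal sketch of the intended usage, assuming the -0.0f "not yet written" sentinel convention used below and that the producer writes the whole 8B slot at once (the helper name is mine):

__device__ float2 wait_for_slot(float2 const* slot)
{
    float2 v = loadfloat2(slot);
    // 0x80000000 is the bit pattern of -0.0f, the sentinel for "not yet written".
    while (__float_as_uint(v.x) == 0x80000000u)
    {
        v = loadfloat2(slot); // volatile load: fetches fresh data from global memory
    }
    return v;
}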
template <int WORLD_SIZE, typename T>
__global__ void twoshot_allreduce_kernel(T* output_ptr, T* shard_ptr, T** input_ptrs, T* mcast_ptr, int num_tokens,
int buffer_M, int token_dim, int rank, uint32_t* buffer_flags, bool wait_for_results)
@@ -74,20 +99,13 @@ __global__ void twoshot_allreduce_kernel(T* output_ptr, T* shard_ptr, T** input_
cudaGridDependencySynchronize();
#endif
// [input_ptr, clear_ptr, buffer_size, access_counter]
uint4 flag = reinterpret_cast<uint4*>(buffer_flags)[0];
// Each buffer is M * N and we have 2 buffers in each group, one for reduce-scatter and one for allgather
uint32_t buffer_group_size = flag.z << 1;
uint32_t input_offset = flag.x * buffer_group_size;
uint32_t clear_offset = flag.y * buffer_group_size;
uint32_t* offset_access_ptr = &buffer_flags[3];
// Buffer size is M * N, and we need two buffers for reduce-scatter and allgather
uint32_t buffer_size = (buffer_flags[2] << 1);
uint32_t input_offset = buffer_flags[0] * buffer_size;
uint32_t clear_offset = buffer_flags[1] * buffer_size;
if (wait_for_results)
{
__syncthreads();
if (threadIdx.x == 0)
{
atomicAdd(offset_access_ptr, 1);
}
}
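
The uint4 load above fetches the whole control word in one shot instead of four scalar reads. A hypothetical mirror of that layout, with field names inferred from the comment and from the flag-rotation code at the end of the kernel (not part of the source):

// Illustrative view of buffer_flags as consumed by this kernel.
struct TwoShotBufferFlags
{
    uint32_t input_group;  // flag.x: which of the 3 buffer groups holds this round's input
    uint32_t clear_group;  // flag.y: which group to reset to the -0.0f sentinel
    uint32_t buffer_size;  // flag.z: elements per buffer; a group is twice this size
                           // (one reduce-scatter buffer plus one allgather buffer)
    uint32_t access_count; // buffer_flags[3]: CTAs that have read the flags; the last CTA
                           // rotates input_group/clear_group (mod 3) and zeroes this counter
};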
if (elt < token_dim)
{
@@ -101,17 +119,16 @@ __global__ void twoshot_allreduce_kernel(T* output_ptr, T* shard_ptr, T** input_
// Reduce and broadcast
int global_token = token * WORLD_SIZE + rank;
if (global_token < num_tokens)
if ((token % WORLD_SIZE) == rank)
{
int local_token = token / WORLD_SIZE;
float accum = 0.f;
T values[WORLD_SIZE];
for (int r = 0; r < WORLD_SIZE; r++)
{
input_ptrs[rank][clear_offset + token * token_dim * WORLD_SIZE + r * token_dim + elt]
input_ptrs[rank][clear_offset + local_token * token_dim * WORLD_SIZE + r * token_dim + elt]
= fromFloat<T>(-0.f);
}
@@ -121,7 +138,7 @@ __global__ void twoshot_allreduce_kernel(T* output_ptr, T* shard_ptr, T** input_
for (int r = 0; r < WORLD_SIZE; r++)
{
T volatile* lamport_ptr = (T volatile*) &input_ptrs[rank][input_offset
+ token * token_dim * WORLD_SIZE + r * token_dim + elt];
+ local_token * token_dim * WORLD_SIZE + r * token_dim + elt];
values[r] = *lamport_ptr;
valid &= !isNegZero(values[r]);
}
@@ -132,7 +149,7 @@ __global__ void twoshot_allreduce_kernel(T* output_ptr, T* shard_ptr, T** input_
{
accum += toFloat<T>(values[r]);
}
mcast_ptr[input_offset + buffer_M * token_dim + global_token * token_dim + elt] = fromFloat<T>(accum);
mcast_ptr[input_offset + buffer_M * token_dim + token * token_dim + elt] = fromFloat<T>(accum);
}
}
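
The change in this hunk moves token ownership from "global_token = token * WORLD_SIZE + rank" to "the rank where token % WORLD_SIZE == rank", with local_token = token / WORLD_SIZE addressing the rank-local slice of the reduce-scatter buffer while the multicast broadcast keeps the global token index. A stripped-down sketch of that step (Lamport validity checks and sentinel clearing omitted; toFloat/fromFloat are the helpers from this file):

// Sketch only, not the kernel: reduce-and-broadcast for one (token, elt) pair.
template <int WORLD_SIZE, typename T>
__device__ void reduce_and_broadcast(T** input_ptrs, T* mcast_ptr, int rank, int token,
    int elt, int token_dim, int buffer_M, uint32_t input_offset)
{
    if ((token % WORLD_SIZE) == rank) // this rank owns the token
    {
        int local_token = token / WORLD_SIZE; // row in the rank-local scatter buffer
        float accum = 0.f;
        for (int r = 0; r < WORLD_SIZE; r++) // sum the WORLD_SIZE partial shards
        {
            accum += toFloat<T>(input_ptrs[rank][input_offset
                + local_token * token_dim * WORLD_SIZE + r * token_dim + elt]);
        }
        // Broadcast the reduced value into every rank's allgather half of the buffer,
        // addressed by the global token index.
        mcast_ptr[input_offset + buffer_M * token_dim + token * token_dim + elt] = fromFloat<T>(accum);
    }
}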
@@ -145,23 +162,50 @@ __global__ void twoshot_allreduce_kernel(T* output_ptr, T* shard_ptr, T** input_
// Optionally wait for results if the next layer isn't doing the Lamport check
if (wait_for_results)
{
T volatile* lamport_ptr
= (T volatile*) &input_ptrs[rank][input_offset + buffer_M * token_dim + token * token_dim + elt];
T val = *lamport_ptr;
while (isNegZero(val))
val = *lamport_ptr;
// Update the atomic counter to indicate the block has read the offsets
__syncthreads();
// Copy if requested
if (output_ptr)
output_ptr[token * token_dim + elt] = val;
if (threadIdx.x == 0 && blockIdx.x == 0 && blockIdx.y == 0)
if (threadIdx.x == 0)
{
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000))
asm volatile("red.async.release.global.gpu.add.u32 [%0], %1;" ::"l"(offset_access_ptr), "r"(1) : "memory");
#elif (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
asm volatile("red.global.gpu.add.u32 [%0], %1;" ::"l"(offset_access_ptr), "r"(1) : "memory");
#else
atomicAdd(offset_access_ptr, 1);
#endif
}
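
The guarded PTX is the performance detail here: a red instruction performs the memory-side reduction without returning the previous value, while atomicAdd is specified to hand the old value back even when the caller discards it, so writing the reduction directly avoids depending on the compiler to drop that unused result. The same guard pattern in isolation (the function name is mine; the PTX strings are the ones used above):

// Count one CTA arrival without requesting the old counter value back.
__device__ void count_arrival(uint32_t* counter)
{
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000))
    asm volatile("red.async.release.global.gpu.add.u32 [%0], %1;" ::"l"(counter), "r"(1) : "memory");
#elif (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
    asm volatile("red.global.gpu.add.u32 [%0], %1;" ::"l"(counter), "r"(1) : "memory");
#else
    atomicAdd(counter, 1); // fallback; the returned old value is simply ignored
#endif
}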
// Only use a subset of CTAs for the Lamport sync; rearrange the grid accordingly
constexpr int ELTS_PER_LOAD = sizeof(float2) / sizeof(T);
// blockDim.x / ELTS_PER_LOAD should be at least the size of a warp (32)
if (threadIdx.x < (blockDim.x / ELTS_PER_LOAD))
{
uint64_t current_pos = blockIdx.x * token_dim + blockIdx.y * blockDim.x + threadIdx.x * ELTS_PER_LOAD;
void* lamport_ptr = (void*) &input_ptrs[rank][input_offset + buffer_M * token_dim + current_pos];
// We make two assumptions here:
// 1. The write is atomic at 8B granularity -> each buffer in the buffer group must be aligned to 8B
// 2. num_tokens * token_dim is divisible by ELTS_PER_LOAD (4 for BF16 and 2 for FP32)
float2 val = loadfloat2(lamport_ptr);
while (isNegZero(*(T*) &val))
{
val = loadfloat2(lamport_ptr);
}
if (output_ptr)
{
*((float2*) &output_ptr[current_pos]) = val;
}
}
// Update the buffer flags
if (threadIdx.x == 0 && blockIdx.x == gridDim.x - 1 && blockIdx.y == 0)
{
// Make sure all blocks of the 2-D grid have finished reading the offsets
while (*reinterpret_cast<uint32_t volatile*>(offset_access_ptr) < gridDim.x * gridDim.y)
{
}
buffer_flags[0] = (buffer_flags[0] + 1) % 3;
buffer_flags[1] = (buffer_flags[1] + 1) % 3;
buffer_flags[0] = (flag.x + 1) % 3;
buffer_flags[1] = (flag.y + 1) % 3;
*(offset_access_ptr) = 0;
}
}
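
For reference, the constants that the new wait path leans on follow directly from the 8-byte float2 load. A compile-time sanity check of the element counts assumed in the comments above (illustrative, not part of the commit); with a 128-thread CTA this gives 128 / 4 = 32 polling threads for BF16, i.e. exactly one warp:

#include <cuda_bf16.h> // __nv_bfloat16

static_assert(sizeof(float2) / sizeof(__nv_bfloat16) == 4, "one 8B Lamport load covers 4 BF16 elements");
static_assert(sizeof(float2) / sizeof(float) == 2, "one 8B Lamport load covers 2 FP32 elements");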
@@ -251,18 +295,6 @@ __device__ void copy_f4_ldg(T_IN* dst, T_IN const* src)
*dst4 = *src4;
}
__device__ float4 loadfloat4(void const* ptr)
{
float return_value[4];
asm volatile("ld.volatile.global.v4.f32 {%0, %1, %2, %3}, [%4];\n"
: "=f"(return_value[0]), "=f"(return_value[1]), "=f"(return_value[2]), "=f"(return_value[3])
: "l"(ptr));
return *(float4*) return_value;
}
template <typename T>
inline __device__ T add(T a, T b)
{
@@ -322,19 +354,14 @@ __global__ void __launch_bounds__(128, 1)
int offsets[NUM_INPUTS][DIM / (1 * ELTS_PER_THREAD * NUM_THREADS)];
uint32_t* offset_access_ptr = &buffer_flags[3];
uint4 flag = reinterpret_cast<uint4*>(buffer_flags)[0];
// Buffer size is M * N, and we need two buffers for reduce-scatter and allgather
uint32_t buffer_size = buffer_flags[2];
uint32_t buffer_offset = buffer_flags[0] * (buffer_size << 1);
uint32_t buffer_size = flag.z;
uint32_t buffer_offset = flag.x * (buffer_size << 1);
T_IN const* input = &buffer_input[buffer_offset + buffer_size];
cudaTriggerProgrammaticLaunchCompletion();
__syncthreads();
if (threadIdx.x == 0)
{
atomicAdd(offset_access_ptr, 1);
}
for (int i = 0; i < NUM_INPUTS; i++)
{
for (int j = 0; j < DIM / (1 * ELTS_PER_THREAD * NUM_THREADS); j++)
@@ -361,7 +388,17 @@ __global__ void __launch_bounds__(128, 1)
}
__pipeline_commit();
__syncthreads();
if (threadIdx.x == 0)
{
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000))
asm volatile("red.async.release.global.gpu.add.u32 [%0], %1;" ::"l"(offset_access_ptr), "r"(1) : "memory");
#elif (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
asm volatile("red.global.gpu.add.u32 [%0], %1;" ::"l"(offset_access_ptr), "r"(1) : "memory");
#else
atomicAdd(offset_access_ptr, 1);
#endif
}
// Load all inputs
bool valid = false;
@@ -494,14 +531,13 @@ __global__ void __launch_bounds__(128, 1)
if (threadIdx.x == 0 && blockIdx.x == 0 && blockIdx.y == 0)
{
// Make sure all blocks have finished accessing the buffer
while (*reinterpret_cast<uint32_t volatile*>(offset_access_ptr) != gridDim.x * gridDim.y)
while (*reinterpret_cast<uint32_t volatile*>(offset_access_ptr) < gridDim.x * gridDim.y)
{
}
buffer_flags[0] = (buffer_flags[0] + 1) % 3;
buffer_flags[1] = (buffer_flags[1] + 1) % 3;
buffer_flags[0] = (flag.x + 1) % 3;
buffer_flags[1] = (flag.y + 1) % 3;
*(offset_access_ptr) = 0;
}
__syncthreads();
#endif
}


@@ -50,7 +50,7 @@ McastDeviceMemory::McastDeviceMemory(
, mMcHandle(0)
{
cudaSetDevice(mDeviceIdx);
TLLM_CUDA_CHECK(cudaSetDevice(mDeviceIdx));
// Check if the device supports multicasting
int multicast_supported{0};
TLLM_CU_CHECK(cuDeviceGetAttribute(&multicast_supported, CU_DEVICE_ATTRIBUTE_MULTICAST_SUPPORTED, mDeviceIdx));
@@ -82,34 +82,41 @@ McastDeviceMemory::McastDeviceMemory(
{
allocNvlsMcastMem(mSignalPadOffset + kSIGNAL_PAD_SIZE);
}
mSignalPadsDev.resize(mGroupSize);
// Initialize signal pads
mSignalPads.resize(mGroupSize);
for (size_t i = 0; i < mGroupSize; i++)
{
mSignalPadsDev[i] = mUcPtrs[i] + mSignalPadOffset;
mSignalPads[i] = mUcPtrs[i] + mSignalPadOffset;
if (i == mGroupRank)
{
cuMemsetD8(mSignalPadsDev[i], 0, kSIGNAL_PAD_SIZE);
cuMemsetD8(mSignalPads[i], 0, kSIGNAL_PAD_SIZE);
}
}
// Copy host array of pointers to device array
TLLM_CUDA_CHECK(cudaMalloc(&mSignalPadsDev, mGroupSize * sizeof(CUdeviceptr)));
TLLM_CUDA_CHECK(cudaMalloc(&mUcPtrsDev, mGroupSize * sizeof(CUdeviceptr)));
TLLM_CUDA_CHECK(
cudaMemcpy(mSignalPadsDev, mSignalPads.data(), mGroupSize * sizeof(CUdeviceptr), cudaMemcpyHostToDevice));
TLLM_CUDA_CHECK(cudaMemcpy(mUcPtrsDev, mUcPtrs.data(), mGroupSize * sizeof(CUdeviceptr), cudaMemcpyHostToDevice));
}
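
The new cudaMalloc/cudaMemcpy pair exists because the kernels in this commit take the pointer table itself as an argument (T** input_ptrs) and index it on the GPU; the previous reinterpret_cast of std::vector<CUdeviceptr>::data() handed them a pointer to pageable host memory, which is not device-accessible. A minimal sketch of the consuming side (the kernel name is illustrative):

// The per-rank pointer table must itself live in device memory, because the
// kernel dereferences the table on the GPU before touching the buffers.
__global__ void consume_ptr_table(float** uc_ptrs_dev, int rank)
{
    float* my_buffer = uc_ptrs_dev[rank]; // reads the table entry from device memory
    my_buffer[threadIdx.x] = 0.f;
}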
McastDeviceMemory::~McastDeviceMemory()
{
tensorrt_llm::common::unregisterMcastDevMemBuffer(this);
TLLM_CUDA_CHECK(cudaFree(mSignalPadsDev));
TLLM_CUDA_CHECK(cudaFree(mUcPtrsDev));
if (mIsMNNvlink)
{
for (uint32_t rank = 0; rank < mGroupSize; rank++)
{
if (rank == mGroupRank)
{
cuMemRelease(mUcHandles[rank]);
}
else
{
mUcHandles[rank] = 0;
}
TLLM_CU_CHECK(cuMemUnmap(mUcPtrs[rank], mAllocationSize));
// We need to release the handle on each rank
TLLM_CU_CHECK(cuMemRelease(mUcHandles[rank]));
}
cuMemRelease(mMcHandle);
TLLM_CU_CHECK(cuMemUnmap(mMcPtr, mAllocationSize));
TLLM_CU_CHECK(cuMemAddressFree(mMcPtr, mAllocationSize));
TLLM_CU_CHECK(cuMemRelease(mMcHandle));
}
else
{

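For context, the removed branch released only the local rank's allocation handle, zeroed the handles imported from peer ranks without releasing them, and never unmapped or freed the multicast mapping. A condensed sketch of the teardown the MNNVL branch performs after this change (error checking dropped, member names shortened):

#include <cstdint>
#include <cuda.h>

static void teardown_mnnvl(CUdeviceptr* uc_ptrs, CUmemGenericAllocationHandle* uc_handles,
    uint32_t group_size, size_t allocation_size, CUdeviceptr mc_ptr, CUmemGenericAllocationHandle mc_handle)
{
    for (uint32_t r = 0; r < group_size; ++r)
    {
        cuMemUnmap(uc_ptrs[r], allocation_size); // this rank's mapping of rank r's buffer
        cuMemRelease(uc_handles[r]);             // every rank drops its reference to the handle
    }
    cuMemUnmap(mc_ptr, allocation_size);
    cuMemAddressFree(mc_ptr, allocation_size);   // return the reserved VA range
    cuMemRelease(mc_handle);
}
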

@@ -44,16 +44,18 @@ public:
McastDeviceMemory(size_t bufSize, uint32_t groupSize, uint32_t groupRank, int deviceIdx, bool mnNvlink);
// We don't register the pointers in these two functions since we don't expect any Python-level code
// to call them to obtain the raw pointers.
//! Get the raw array of signal pad pointers to all ranks (including self)
void** getSignalPadPtrsDev()
{
return reinterpret_cast<void**>(mSignalPadsDev.data());
return mSignalPadsDev;
}
//! Get the raw array of unicast pointers to all ranks (including self)
void** getBufferPtrsDev()
{
return reinterpret_cast<void**>(mUcPtrs.data());
return mUcPtrsDev;
}
//! Get the raw unicast pointer to a given rank
@@ -93,11 +95,17 @@ private:
size_t mAllocationSize;
CUdeviceptr mMcPtr;
std::vector<CUdeviceptr> mUcPtrs;
std::vector<CUdeviceptr> mSignalPadsDev;
CUmemGenericAllocationHandle mMcHandle;
std::vector<CUmemGenericAllocationHandle> mUcHandles;
// Host arrays of pointers
std::vector<CUdeviceptr> mUcPtrs;
std::vector<CUdeviceptr> mSignalPads;
// Device arrays of pointers
void** mUcPtrsDev;
void** mSignalPadsDev;
// For intra-node mcast
tensorrt_llm::runtime::IpcNvlsHandle* mNvlsHandle;


@@ -798,12 +798,12 @@ class DeepseekV3DecoderLayer(DecoderLayer):
hidden_states, residual = self.post_attention_layernorm(
hidden_states, residual)
# Note: this fusion pattern is only supported for TRTLLM-nvfp4 backend now
do_finalize = not (hidden_states.shape[0]
<= self.moe_allreduce.max_token
and self.fusion_config.POST_MOE_FUSION
and self.model_config.moe_backend == 'TRTLLM'
and self.mlp.experts.has_nvfp4)
# Note: this fusion pattern is only supported for the single-node TRTLLM-nvfp4 backend for now
do_finalize = self.mapping.is_multi_node() or (
not (hidden_states.shape[0] <= self.moe_allreduce.max_token
and self.fusion_config.POST_MOE_FUSION
and self.model_config.moe_backend == "TRTLLM"
and self.mlp.experts.has_nvfp4))
hidden_states = _run_MoE(hidden_states,
hidden_states_fp4=None,