Mirror of https://github.com/NVIDIA/TensorRT-LLM.git, synced 2026-01-13 22:18:36 +08:00
[fix] Performance Optimization for MNNVL TwoShot Kernel (#5934)
Signed-off-by: Shiyu Li <shili@nvidia.com>
Co-authored-by: Zongfei Jing <20381269+zongfeijing@users.noreply.github.com>
This commit is contained in:
parent fe070a0168
commit 6e1aee6fd6
@@ -61,6 +61,31 @@ inline __device__ __nv_bfloat16 fromFloat<__nv_bfloat16>(float val)
     return __float2bfloat16(val);
 }
 
+__device__ float4 loadfloat4(void const* ptr)
+{
+    float return_value[4];
+
+    asm volatile("ld.volatile.global.v4.f32 {%0, %1, %2, %3}, [%4];\n"
+                 : "=f"(return_value[0]), "=f"(return_value[1]), "=f"(return_value[2]), "=f"(return_value[3])
+                 : "l"(ptr));
+
+    return *(float4*) return_value;
+}
+
+__device__ __inline__ float2 loadfloat2(void const* ptr)
+{
+    float return_value[2];
+
+    asm volatile("ld.volatile.global.v2.f32 {%0, %1}, [%2];\n"
+                 : "=f"(return_value[0]), "=f"(return_value[1])
+                 : "l"(ptr)
+                 : "memory");
+
+    return *(float2*) return_value;
+}
+
 template <int WORLD_SIZE, typename T>
 __global__ void twoshot_allreduce_kernel(T* output_ptr, T* shard_ptr, T** input_ptrs, T* mcast_ptr, int num_tokens,
     int buffer_M, int token_dim, int rank, uint32_t* buffer_flags, bool wait_for_results)
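For context on how these volatile vector loads are meant to be used: the kernel clears its Lamport buffers with -0.f and treats that bit pattern as a "not yet written" sentinel, so a slot is ready only once it no longer reads back as negative zero. Below is a minimal sketch of such a check, assuming FP32 payloads; the helper names are illustrative and are not the isNegZero used in this file.

// Illustrative sketch only (not code from this commit).
__device__ __forceinline__ bool isNegZeroF32(float v)
{
    // Compare raw bits against 0x80000000 so a legitimate payload of +0.0f
    // (bit pattern 0x00000000) is not mistaken for the sentinel.
    return __float_as_uint(v) == 0x80000000u;
}

__device__ __forceinline__ bool isNegZeroF32x2(float2 v)
{
    // A float2 slot counts as ready only once both halves have been overwritten.
    return isNegZeroF32(v.x) || isNegZeroF32(v.y);
}
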
@@ -74,20 +99,13 @@ __global__ void twoshot_allreduce_kernel(T* output_ptr, T* shard_ptr, T** input_
     cudaGridDependencySynchronize();
 #endif
 
-    uint32_t* offset_access_ptr = &buffer_flags[3];
-    // Buffer size is M * N, and we need two buffers for reduce-scatter and allgather
-    uint32_t buffer_size = (buffer_flags[2] << 1);
-    uint32_t input_offset = buffer_flags[0] * buffer_size;
-    uint32_t clear_offset = buffer_flags[1] * buffer_size;
-
-    if (wait_for_results)
-    {
-        __syncthreads();
-        if (threadIdx.x == 0)
-        {
-            atomicAdd(offset_access_ptr, 1);
-        }
-    }
+    // [input_ptr, clear_ptr, buffer_size, access_counter]
+    uint4 flag = reinterpret_cast<uint4*>(buffer_flags)[0];
+    // Each buffer is M * N and we have 2 buffers in each group, one for reduce-scatter and one for allgather
+    uint32_t buffer_group_size = flag.z << 1;
+    uint32_t input_offset = flag.x * buffer_group_size;
+    uint32_t clear_offset = flag.y * buffer_group_size;
+    uint32_t* offset_access_ptr = &buffer_flags[3];
 
     if (elt < token_dim)
     {
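Reading buffer_flags as a single uint4 replaces four scalar loads of the same words, and the access-counter update is deferred to the end of the kernel. As a rough, hypothetical illustration of the flag layout the comment describes, a host-side setup could look like the sketch below; the function name, parameter, and initial buffer indices are assumptions, not code from this commit.

#include <cuda_runtime.h>
#include <cstdint>

// Hypothetical helper: writes {input buffer index, clear buffer index,
// per-buffer size in elements, access counter} so a kernel can read them
// back as one uint4.
inline cudaError_t initBufferFlags(uint32_t* flags_dev, uint32_t buffer_size_elems)
{
    uint32_t const host_flags[4] = {0u, 1u, buffer_size_elems, 0u};
    return cudaMemcpy(flags_dev, host_flags, sizeof(host_flags), cudaMemcpyHostToDevice);
}
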
@@ -101,17 +119,16 @@ __global__ void twoshot_allreduce_kernel(T* output_ptr, T* shard_ptr, T** input_
 
     // Reduce and broadcast
 
-    int global_token = token * WORLD_SIZE + rank;
-    if (global_token < num_tokens)
+    if ((token % WORLD_SIZE) == rank)
     {
+        int local_token = token / WORLD_SIZE;
         float accum = 0.f;
 
         T values[WORLD_SIZE];
 
         for (int r = 0; r < WORLD_SIZE; r++)
         {
-            input_ptrs[rank][clear_offset + token * token_dim * WORLD_SIZE + r * token_dim + elt]
+            input_ptrs[rank][clear_offset + local_token * token_dim * WORLD_SIZE + r * token_dim + elt]
                 = fromFloat<T>(-0.f);
         }
 
@@ -121,7 +138,7 @@ __global__ void twoshot_allreduce_kernel(T* output_ptr, T* shard_ptr, T** input_
         for (int r = 0; r < WORLD_SIZE; r++)
         {
             T volatile* lamport_ptr = (T volatile*) &input_ptrs[rank][input_offset
-                + token * token_dim * WORLD_SIZE + r * token_dim + elt];
+                + local_token * token_dim * WORLD_SIZE + r * token_dim + elt];
             values[r] = *lamport_ptr;
             valid &= !isNegZero(values[r]);
         }
@@ -132,7 +149,7 @@ __global__ void twoshot_allreduce_kernel(T* output_ptr, T* shard_ptr, T** input_
            {
                accum += toFloat<T>(values[r]);
            }
-            mcast_ptr[input_offset + buffer_M * token_dim + global_token * token_dim + elt] = fromFloat<T>(accum);
+            mcast_ptr[input_offset + buffer_M * token_dim + token * token_dim + elt] = fromFloat<T>(accum);
        }
    }
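The ownership test changes from deriving a global token per local slot to a round-robin check: rank r now reduces every token whose index satisfies token % WORLD_SIZE == rank, and local_token is that token's dense position within the rank's share. The following standalone host-side sketch (illustration only, not part of the commit) prints the mapping for WORLD_SIZE = 4.

#include <cstdio>

int main()
{
    constexpr int WORLD_SIZE = 4;
    constexpr int NUM_TOKENS = 10;
    for (int rank = 0; rank < WORLD_SIZE; ++rank)
    {
        std::printf("rank %d reduces:", rank);
        for (int token = 0; token < NUM_TOKENS; ++token)
        {
            if ((token % WORLD_SIZE) == rank)
            {
                // local_token is the dense index of this token within the rank's slice.
                std::printf(" token %d (local %d)", token, token / WORLD_SIZE);
            }
        }
        std::printf("\n");
    }
    return 0;
}
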
@@ -145,23 +162,50 @@ __global__ void twoshot_allreduce_kernel(T* output_ptr, T* shard_ptr, T** input_
     // Optionally wait for results if the next layer isn't doing the Lamport check
     if (wait_for_results)
     {
-        T volatile* lamport_ptr
-            = (T volatile*) &input_ptrs[rank][input_offset + buffer_M * token_dim + token * token_dim + elt];
-        T val = *lamport_ptr;
-        while (isNegZero(val))
-            val = *lamport_ptr;
-
-        // Copy if requested
-        if (output_ptr)
-            output_ptr[token * token_dim + elt] = val;
-        if (threadIdx.x == 0 && blockIdx.x == 0 && blockIdx.y == 0)
+        // Update the atomic counter to indicate the block has read the offsets
+        __syncthreads();
+        if (threadIdx.x == 0)
         {
+#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000))
+            asm volatile("red.async.release.global.gpu.add.u32 [%0], %1;" ::"l"(offset_access_ptr), "r"(1) : "memory");
+#elif (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
+            asm volatile("red.global.gpu.add.u32 [%0], %1;" ::"l"(offset_access_ptr), "r"(1) : "memory");
+#else
+            atomicAdd(offset_access_ptr, 1);
+#endif
+        }
+        // Only use a subset of CTAs for the Lamport sync; rearrange the grid
+        constexpr int ELTS_PER_LOAD = sizeof(float2) / sizeof(T);
+        // blockDim.x / ELTS_PER_LOAD should be at least the size of a warp (32)
+        if (threadIdx.x < (blockDim.x / ELTS_PER_LOAD))
+        {
+            uint64_t current_pos = blockIdx.x * token_dim + blockIdx.y * blockDim.x + threadIdx.x * ELTS_PER_LOAD;
+
+            void* lamport_ptr = (void*) &input_ptrs[rank][input_offset + buffer_M * token_dim + current_pos];
+            // We have 2 assumptions here:
+            // 1. The write is atomic at 8B granularity -> each buffer in the buffer group should be aligned to 8B
+            // 2. num_tokens * token_dim is divisible by ELTS_PER_LOAD (4 for BF16 and 2 for FP32)
+            float2 val = loadfloat2(lamport_ptr);
+            while (isNegZero(*(T*) &val))
+            {
+                val = loadfloat2(lamport_ptr);
+            }
+            if (output_ptr)
+            {
+                *((float2*) &output_ptr[current_pos]) = val;
+            }
+        }
+
+        // Update the buffer flags
+        if (threadIdx.x == 0 && blockIdx.x == gridDim.x - 1 && blockIdx.y == 0)
+        {
             // Make sure all blocks have finished reading the offsets, 2-D grid
             while (*reinterpret_cast<uint32_t volatile*>(offset_access_ptr) < gridDim.x * gridDim.y)
             {
             }
-            buffer_flags[0] = (buffer_flags[0] + 1) % 3;
-            buffer_flags[1] = (buffer_flags[1] + 1) % 3;
+            buffer_flags[0] = (flag.x + 1) % 3;
+            buffer_flags[1] = (flag.y + 1) % 3;
             *(offset_access_ptr) = 0;
         }
     }
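With the early atomicAdd gone, each block now announces completion through the arch-guarded reduction above, and only the last block of the 2-D grid rotates the triple-buffered flag indices once the counter reaches gridDim.x * gridDim.y. Stripped of the allreduce itself, the synchronization skeleton is roughly the sketch below, written with a plain atomicAdd instead of the PTX path; it is an illustration, not the kernel above.

#include <cstdint>

// Illustration only: each block bumps a global counter, and one designated
// block waits for the whole 2-D grid before rotating the buffer indices.
__global__ void counterBarrierSketch(uint32_t* counter, uint32_t* flags)
{
    // ... per-block work on the current buffers would go here ...

    __syncthreads();
    if (threadIdx.x == 0)
    {
        atomicAdd(counter, 1u); // announce that this block has finished with the flags
    }

    if (threadIdx.x == 0 && blockIdx.x == gridDim.x - 1 && blockIdx.y == 0)
    {
        // Spin until every block of the grid has announced itself.
        while (*reinterpret_cast<uint32_t volatile*>(counter) < gridDim.x * gridDim.y)
        {
        }
        flags[0] = (flags[0] + 1) % 3; // rotate the triple-buffered input index
        flags[1] = (flags[1] + 1) % 3; // rotate the clear index the same way
        *counter = 0;                  // reset the counter for the next launch
    }
}
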
@@ -251,18 +295,6 @@ __device__ void copy_f4_ldg(T_IN* dst, T_IN const* src)
     *dst4 = *src4;
 }
 
-__device__ float4 loadfloat4(void const* ptr)
-{
-    float return_value[4];
-
-    asm volatile("ld.volatile.global.v4.f32 {%0, %1, %2, %3}, [%4];\n"
-                 : "=f"(return_value[0]), "=f"(return_value[1]), "=f"(return_value[2]), "=f"(return_value[3])
-                 : "l"(ptr));
-
-    return *(float4*) return_value;
-}
-
 template <typename T>
 inline __device__ T add(T a, T b)
 {
@@ -322,19 +354,14 @@ __global__ void __launch_bounds__(128, 1)
     int offsets[NUM_INPUTS][DIM / (1 * ELTS_PER_THREAD * NUM_THREADS)];
 
     uint32_t* offset_access_ptr = &buffer_flags[3];
+    uint4 flag = reinterpret_cast<uint4*>(buffer_flags)[0];
     // Buffer size is M * N, and we need two buffers for reduce-scatter and allgather
-    uint32_t buffer_size = buffer_flags[2];
-    uint32_t buffer_offset = buffer_flags[0] * (buffer_size << 1);
+    uint32_t buffer_size = flag.z;
+    uint32_t buffer_offset = flag.x * (buffer_size << 1);
     T_IN const* input = &buffer_input[buffer_offset + buffer_size];
 
     cudaTriggerProgrammaticLaunchCompletion();
-
-    __syncthreads();
-    if (threadIdx.x == 0)
-    {
-        atomicAdd(offset_access_ptr, 1);
-    }
 
     for (int i = 0; i < NUM_INPUTS; i++)
     {
         for (int j = 0; j < DIM / (1 * ELTS_PER_THREAD * NUM_THREADS); j++)
@@ -361,7 +388,17 @@ __global__ void __launch_bounds__(128, 1)
     }
 
     __pipeline_commit();
 
+    __syncthreads();
+    if (threadIdx.x == 0)
+    {
+#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000))
+        asm volatile("red.async.release.global.gpu.add.u32 [%0], %1;" ::"l"(offset_access_ptr), "r"(1) : "memory");
+#elif (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
+        asm volatile("red.global.gpu.add.u32 [%0], %1;" ::"l"(offset_access_ptr), "r"(1) : "memory");
+#else
+        atomicAdd(offset_access_ptr, 1);
+#endif
+    }
     // Load all inputs
     bool valid = false;
 
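Both kernels emit the same arch-dispatched counter bump, so the increment could plausibly be factored into a small helper such as the sketch below; the helper name is an assumption, and the PTX strings are the ones used in the diff, with atomicAdd as the fallback for older architectures.

#include <cstdint>

// Hypothetical helper (not added by this commit) wrapping the arch-guarded
// counter increment used in both kernels above.
__device__ __forceinline__ void incrementFlagCounter(uint32_t* counter)
{
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000))
    asm volatile("red.async.release.global.gpu.add.u32 [%0], %1;" ::"l"(counter), "r"(1) : "memory");
#elif (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
    asm volatile("red.global.gpu.add.u32 [%0], %1;" ::"l"(counter), "r"(1) : "memory");
#else
    atomicAdd(counter, 1u);
#endif
}
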
@@ -494,14 +531,13 @@ __global__ void __launch_bounds__(128, 1)
     if (threadIdx.x == 0 && blockIdx.x == 0 && blockIdx.y == 0)
     {
         // Make sure all blocks have finished accessing the buffer
-        while (*reinterpret_cast<uint32_t volatile*>(offset_access_ptr) != gridDim.x * gridDim.y)
+        while (*reinterpret_cast<uint32_t volatile*>(offset_access_ptr) < gridDim.x * gridDim.y)
        {
        }
-        buffer_flags[0] = (buffer_flags[0] + 1) % 3;
-        buffer_flags[1] = (buffer_flags[1] + 1) % 3;
+        buffer_flags[0] = (flag.x + 1) % 3;
+        buffer_flags[1] = (flag.y + 1) % 3;
        *(offset_access_ptr) = 0;
    }
    __syncthreads();
#endif
}
@@ -50,7 +50,7 @@ McastDeviceMemory::McastDeviceMemory(
     , mMcHandle(0)
 {
-    cudaSetDevice(mDeviceIdx);
+    TLLM_CUDA_CHECK(cudaSetDevice(mDeviceIdx));
     // Check if the device support multicasting
     int multicast_supported{0};
     TLLM_CU_CHECK(cuDeviceGetAttribute(&multicast_supported, CU_DEVICE_ATTRIBUTE_MULTICAST_SUPPORTED, mDeviceIdx));
@@ -82,34 +82,41 @@ McastDeviceMemory::McastDeviceMemory(
     {
         allocNvlsMcastMem(mSignalPadOffset + kSIGNAL_PAD_SIZE);
     }
-    mSignalPadsDev.resize(mGroupSize);
+    // Initialize signal pads
+    mSignalPads.resize(mGroupSize);
     for (size_t i = 0; i < mGroupSize; i++)
     {
-        mSignalPadsDev[i] = mUcPtrs[i] + mSignalPadOffset;
+        mSignalPads[i] = mUcPtrs[i] + mSignalPadOffset;
         if (i == mGroupRank)
         {
-            cuMemsetD8(mSignalPadsDev[i], 0, kSIGNAL_PAD_SIZE);
+            cuMemsetD8(mSignalPads[i], 0, kSIGNAL_PAD_SIZE);
         }
     }
+    // Copy host array of pointers to device array
+    TLLM_CUDA_CHECK(cudaMalloc(&mSignalPadsDev, mGroupSize * sizeof(CUdeviceptr)));
+    TLLM_CUDA_CHECK(cudaMalloc(&mUcPtrsDev, mGroupSize * sizeof(CUdeviceptr)));
+    TLLM_CUDA_CHECK(
+        cudaMemcpy(mSignalPadsDev, mSignalPads.data(), mGroupSize * sizeof(CUdeviceptr), cudaMemcpyHostToDevice));
+    TLLM_CUDA_CHECK(cudaMemcpy(mUcPtrsDev, mUcPtrs.data(), mGroupSize * sizeof(CUdeviceptr), cudaMemcpyHostToDevice));
 }
 
 McastDeviceMemory::~McastDeviceMemory()
 {
     tensorrt_llm::common::unregisterMcastDevMemBuffer(this);
+    TLLM_CUDA_CHECK(cudaFree(mSignalPadsDev));
+    TLLM_CUDA_CHECK(cudaFree(mUcPtrsDev));
 
     if (mIsMNNvlink)
     {
         for (uint32_t rank = 0; rank < mGroupSize; rank++)
         {
-            if (rank == mGroupRank)
-            {
-                cuMemRelease(mUcHandles[rank]);
-            }
-            else
-            {
-                mUcHandles[rank] = 0;
-            }
             TLLM_CU_CHECK(cuMemUnmap(mUcPtrs[rank], mAllocationSize));
+            // We need to release the handle on each rank
+            TLLM_CU_CHECK(cuMemRelease(mUcHandles[rank]));
         }
-        cuMemRelease(mMcHandle);
         TLLM_CU_CHECK(cuMemUnmap(mMcPtr, mAllocationSize));
         TLLM_CU_CHECK(cuMemAddressFree(mMcPtr, mAllocationSize));
+        TLLM_CU_CHECK(cuMemRelease(mMcHandle));
     }
     else
     {
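The destructor reordering follows the usual teardown sequence for memory created with the CUDA virtual memory management driver API: unmap the virtual address range, free the address reservation, then release the allocation handle, now done for every rank's handle rather than only the local one. A generic sketch of that ordering is shown below; the function name is illustrative and, unlike the file above, it omits the TLLM_CU_CHECK error handling.

#include <cuda.h>

// Hypothetical helper showing only the teardown order (no error checking).
void releaseVmmAllocation(CUdeviceptr ptr, size_t size, CUmemGenericAllocationHandle handle)
{
    cuMemUnmap(ptr, size);       // detach the physical allocation from the VA range
    cuMemAddressFree(ptr, size); // return the reserved virtual address range
    cuMemRelease(handle);        // release the allocation handle last
}
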
@@ -44,16 +44,18 @@ public:
 
     McastDeviceMemory(size_t bufSize, uint32_t groupSize, uint32_t groupRank, int deviceIdx, bool mnNvlink);
 
+    // We don't register the pointers in these two functions, since we don't expect any Python-level code to call
+    // them to obtain the raw pointers.
     //! Get the raw array of signal pad pointers to all ranks (including self)
     void** getSignalPadPtrsDev()
     {
-        return reinterpret_cast<void**>(mSignalPadsDev.data());
+        return mSignalPadsDev;
     }
 
     //! Get the raw array of unicast pointers to all ranks (including self)
     void** getBufferPtrsDev()
     {
-        return reinterpret_cast<void**>(mUcPtrs.data());
+        return mUcPtrsDev;
     }
 
     //! Get the raw unicast pointer to a given rank
@@ -93,11 +95,17 @@ private:
     size_t mAllocationSize;
 
     CUdeviceptr mMcPtr;
-    std::vector<CUdeviceptr> mUcPtrs;
-    std::vector<CUdeviceptr> mSignalPadsDev;
     CUmemGenericAllocationHandle mMcHandle;
     std::vector<CUmemGenericAllocationHandle> mUcHandles;
 
+    // Host array of pointers
+    std::vector<CUdeviceptr> mUcPtrs;
+    std::vector<CUdeviceptr> mSignalPads;
+
+    // Device array of pointers
+    void** mUcPtrsDev;
+    void** mSignalPadsDev;
+
     // For intra-node mcast
     tensorrt_llm::runtime::IpcNvlsHandle* mNvlsHandle;
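The header now keeps the per-rank pointers twice: a host-side std::vector for bookkeeping and a separate device-resident void** table for kernels, since device code cannot dereference the host storage behind std::vector::data(). Below is a generic sketch of mirroring such a table into device memory; the names are illustrative and error checking is omitted, unlike the TLLM_CUDA_CHECK-wrapped calls above.

#include <cuda_runtime.h>
#include <vector>

// Hypothetical helper: copy a host table of device pointers into device
// memory so a kernel can index peer buffers by rank.
void** makeDevicePtrTable(std::vector<void*> const& host_ptrs)
{
    void** dev_table = nullptr;
    size_t const bytes = host_ptrs.size() * sizeof(void*);
    cudaMalloc(reinterpret_cast<void**>(&dev_table), bytes);
    // The vector's storage lives in host memory; only this copied table is
    // safe to dereference from device code.
    cudaMemcpy(dev_table, host_ptrs.data(), bytes, cudaMemcpyHostToDevice);
    return dev_table;
}
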
@@ -798,12 +798,12 @@ class DeepseekV3DecoderLayer(DecoderLayer):
         hidden_states, residual = self.post_attention_layernorm(
             hidden_states, residual)
 
-        # Note: this fusion pattern is only supported for TRTLLM-nvfp4 backend now
-        do_finalize = not (hidden_states.shape[0]
-                           <= self.moe_allreduce.max_token
-                           and self.fusion_config.POST_MOE_FUSION
-                           and self.model_config.moe_backend == 'TRTLLM'
-                           and self.mlp.experts.has_nvfp4)
+        # Note: this fusion pattern is only supported for single-node TRTLLM-nvfp4 backend now
+        do_finalize = self.mapping.is_multi_node() or (
+            not (hidden_states.shape[0] <= self.moe_allreduce.max_token
+                 and self.fusion_config.POST_MOE_FUSION
+                 and self.model_config.moe_backend == "TRTLLM"
+                 and self.mlp.experts.has_nvfp4))
 
         hidden_states = _run_MoE(hidden_states,
                                  hidden_states_fp4=None,