/*
 * Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "tensorrt_llm/common/config.h"
|
|
#include "tensorrt_llm/common/cudaBf16Fallbacks.cuh"
|
|
#include "tensorrt_llm/common/cudaTypeUtils.cuh"
|
|
#include "tensorrt_llm/common/cudaUtils.h"
|
|
#include "tensorrt_llm/common/customAllReduceUtils.h"
|
|
#include "tensorrt_llm/common/dataType.h"
|
|
#include "tensorrt_llm/common/envUtils.h"
|
|
#include "tensorrt_llm/kernels/communicationKernels/customLowPrecisionAllReduceKernels.h"
|
|
#include <cooperative_groups.h>
|
|
#include <tuple>
|
|
#include <type_traits>
|
|
|
|
TRTLLM_NAMESPACE_BEGIN
|
|
|
|
namespace kernels
|
|
{
|
|
|
|
using tensorrt_llm::common::divUp;
|
|
using tensorrt_llm::common::roundUp;
|
|
using tensorrt_llm::common::cuda_max;
|
|
using tensorrt_llm::common::cuda_abs;
|
|
|
|
static StaticLowPrecisionBuffers static_tp2_buffers;
|
|
static StaticLowPrecisionBuffers static_tp4_buffers;
|
|
static StaticLowPrecisionBuffers static_tp8_buffers;
|
|
|
|
StaticLowPrecisionBuffers* getBufferForTpSize(size_t tpSize)
{
    if (tpSize == 2)
    {
        return &static_tp2_buffers;
    }
    else if (tpSize == 4)
    {
        return &static_tp4_buffers;
    }
    else if (tpSize == 8)
    {
        return &static_tp8_buffers;
    }
    else
    {
        TLLM_THROW("Unsupported tpSize for LowPrecisionCustomAllReduce");
    }
}

void initialize_static_lowprecision_buffers(int64_t* buffer, size_t tpSize)
{
    void* const* buffer_ptrs = reinterpret_cast<void* const*>(buffer);

    StaticLowPrecisionBuffers* static_buffers = getBufferForTpSize(tpSize);

    // Store pointers in static structure
    for (int i = 0; i < tpSize; ++i)
    {
        static_buffers->peer_comm_buffer_ptrs[i] = buffer_ptrs[i];
        static_buffers->peer_comm_buffer_ptrs[tpSize + i] = buffer_ptrs[tpSize + i];
        static_buffers->peer_barrier_ptrs_in[i] = reinterpret_cast<uint64_t*>(buffer_ptrs[2 * tpSize + i]);
        static_buffers->peer_barrier_ptrs_out[i] = reinterpret_cast<uint64_t*>(buffer_ptrs[3 * tpSize + i]);
    }

    constexpr int LOW_PRECISION_NUM_POINTERS_PER_RANK = 4;
    // Store the flag pointer
    int flag_offset = 1;
    static_buffers->flag_ptr = &buffer[LOW_PRECISION_NUM_POINTERS_PER_RANK * tpSize + flag_offset];

    static_buffers->initialized = true;
    static_buffers->tpSize = tpSize;
}

////////////////////////////////////////////////////////////////////////////////////////////////////

static inline __device__ void lp_allreduce_st_flag_release(uint64_t const& flag, uint64_t* flag_addr)
{
#if __CUDA_ARCH__ >= 700
    asm volatile("st.global.release.sys.b64 [%1], %0;" ::"l"(flag), "l"(flag_addr));
#else
    __threadfence_system();
    asm volatile("st.global.volatile.b64 [%1], %0;" ::"l"(flag), "l"(flag_addr));
#endif
}

////////////////////////////////////////////////////////////////////////////////////////////////////

static inline __device__ void lp_allreduce_ld_flag_acquire(uint64_t& flag, uint64_t* flag_addr)
{
#if __CUDA_ARCH__ >= 700
    asm volatile("ld.global.acquire.sys.b64 %0, [%1];" : "=l"(flag) : "l"(flag_addr));
#else
    asm volatile("ld.global.volatile.b64 %0, [%1];" : "=l"(flag) : "l"(flag_addr));
#endif
}

////////////////////////////////////////////////////////////////////////////////////////////////////

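// The unions below alias a 128-bit (int4) register tile with the element type, so a single 16-byte
// load/store moves sizeof(int4) / sizeof(T) elements at once. PackedOn16Bytes / PackedOnNum pick the
// matching alias for a given element type and byte width.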
// Type Converter that packs data format to 128 bits data type
//
using PackedFloat = union
{
    int4 packed;
    float unpacked[4];
};

using PackedHalf = union
{
    int4 packed;
    // half2 unpacked[4];
    __half unpacked[8];
};

template <typename T>
struct PackedOn16Bytes
{
};

template <typename T, int Num>
struct PackedOnNum
{
};

template <>
struct PackedOn16Bytes<float>
{
    using Type = PackedFloat;
};

template <>
struct PackedOn16Bytes<half>
{
    using Type = PackedHalf;
};

using PackedINT8 = union
{
    int4 packed;
    int8_t unpacked[16];
};

using PackedINT8_8Bytes = union
{
    int2 packed;
    int8_t unpacked[8];
};

using PackedINT8_4Bytes = union
{
    int packed;
    int8_t unpacked[4];
};

template <>
struct PackedOn16Bytes<int8_t>
{
    using Type = PackedINT8;
};

template <>
struct PackedOnNum<int8_t, 8>
{
    using Type = PackedINT8_8Bytes;
};

template <>
struct PackedOnNum<int8_t, 4>
{
    using Type = PackedINT8_4Bytes;
};

#ifdef ENABLE_BF16
using PackedBFloat16 = union
{
    int4 packed;
    //__nv_bfloat162 unpacked[4];
    __nv_bfloat16 unpacked[8];
};

template <>
struct PackedOn16Bytes<__nv_bfloat16>
{
    using Type = PackedBFloat16;
};
#endif

#ifdef ENABLE_FP8
using PackedFloat8E4m3 = union
{
    int4 packed;
    __nv_fp8_e4m3 unpacked[16];
};

using PackedFloat8E4m3_8Bytes = union
{
    int2 packed;
    __nv_fp8_e4m3 unpacked[8];
};

using PackedFloat8E4m3_4Bytes = union
{
    int packed;
    __nv_fp8_e4m3 unpacked[4];
};

template <>
struct PackedOn16Bytes<__nv_fp8_e4m3>
{
    using Type = PackedFloat8E4m3;
};

template <>
struct PackedOnNum<__nv_fp8_e4m3, 8>
{
    using Type = PackedFloat8E4m3_8Bytes;
};

template <>
struct PackedOnNum<__nv_fp8_e4m3, 4>
{
    using Type = PackedFloat8E4m3_4Bytes;
};
#endif

template <int num>
struct LowPrecisionIntPack
{
};

template <>
struct LowPrecisionIntPack<4>
{
    using Type = int;
};

template <>
struct LowPrecisionIntPack<8>
{
    using Type = int2;
};

template <>
struct LowPrecisionIntPack<16>
{
    using Type = int4;
};

__inline__ __device__ void multi_gpu_barrier(
    uint64_t** signals, const uint64_t flag, const size_t rank, const size_t world_size, int const tidx, int const bidx)
{
    // At the end of this function, we know that at least block 0 of every other GPU has reached this point.
    uint64_t volatile* my_signals = signals[rank];
    if (tidx < world_size)
    {
        // The 1st block notifies the other ranks.
        if (bidx == 0)
        {
            signals[tidx][rank] = flag;
        }

        // Busy-wait until all ranks are ready.
        while (my_signals[tidx] != flag)
        {
        }
    }

    // Make sure we can move on...
    __syncthreads();
}

__device__ __forceinline__ void st_global_release(int4 const& val, int4* addr)
{
    asm volatile("st.release.global.sys.v4.b32 [%4], {%0, %1, %2, %3};" ::"r"(val.x), "r"(val.y), "r"(val.z),
        "r"(val.w), "l"(addr));
}

__device__ __forceinline__ int4 ld_global_acquire(int4* addr)
{
    int4 val;
    asm volatile("ld.acquire.global.sys.v4.b32 {%0, %1, %2, %3}, [%4];"
                 : "=r"(val.x), "=r"(val.y), "=r"(val.z), "=r"(val.w)
                 : "l"(addr));
    return val;
}

__device__ __forceinline__ void st_global_volatile(int4 const& val, int4* addr)
{
    asm volatile("st.volatile.global.v4.b32 [%4], {%0, %1, %2, %3};" ::"r"(val.x), "r"(val.y), "r"(val.z), "r"(val.w),
        "l"(addr));
}

__device__ __forceinline__ int4 ld_global_volatile(int4* addr)
{
    int4 val;
    asm volatile("ld.volatile.global.v4.b32 {%0, %1, %2, %3}, [%4];"
                 : "=r"(val.x), "=r"(val.y), "=r"(val.z), "=r"(val.w)
                 : "l"(addr));
    return val;
}

__device__ __forceinline__ void fence_acq_rel_sys()
{
    asm volatile("fence.acq_rel.sys;" ::: "memory");
}

template <typename T>
__device__ __forceinline__ uintptr_t cvta_to_global(T* ptr)
{
    return (uintptr_t) __cvta_generic_to_global(ptr);
}

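// Flag helpers for the chunked producer/consumer handshake used by the hierarchical kernel: a sender
// stalls while its flag runs LP_ALLREDUCE_BUFFER_CHUNKS ahead of the peer's ack, and a receiver stalls
// until the producer's flag advances past its own.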
__device__ __forceinline__ uint64_t ld_volatile_global(uint64_t* ptr)
{
    uint64_t ans;
    asm("ld.volatile.global.u64 %0, [%1];" : "=l"(ans) : "l"(cvta_to_global(ptr)));
    return ans;
}

__device__ __forceinline__ void wait_send_peer(uint64_t local_flag, uint64_t* peer_flag_ptr)
{
    uint64_t peer_flag = ld_volatile_global(peer_flag_ptr);
    while (local_flag - peer_flag >= LP_ALLREDUCE_BUFFER_CHUNKS)
    {
        peer_flag = ld_volatile_global(peer_flag_ptr);
    }
    return;
}

__device__ __forceinline__ void wait_recv_peer(uint64_t local_flag, uint64_t* peer_flag_ptr)
{
    uint64_t peer_flag = ld_volatile_global(peer_flag_ptr);

    while (local_flag >= peer_flag)
    {
        peer_flag = ld_volatile_global(peer_flag_ptr);
    }
    return;
}

__device__ __forceinline__ void notify_peer(uint64_t* peer_flag_ptr)
{
    asm volatile("st.relaxed.sys.global.u64 [%0], %1;" ::"l"(cvta_to_global(peer_flag_ptr)), "l"(uint64_t(1))
                 : "memory");
    return;
}

__device__ __forceinline__ void notify_peer_with_value_relax(uint64_t* peer_flag_ptr, uint64_t value)
{
    asm volatile("st.relaxed.sys.global.u64 [%0], %1;" ::"l"(cvta_to_global(peer_flag_ptr)), "l"(value) : "memory");
    return;
}

__device__ __forceinline__ void notify_peer_with_value(uint64_t* peer_flag_ptr, uint64_t value)
{
    *peer_flag_ptr = value;
    return;
}

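// Butterfly (XOR shuffle) reduction: after the five exchanges below, every lane of the warp holds the
// warp-wide maximum.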
__device__ float warp_reduce_max(float val)
{
    val = cuda_max(__shfl_xor_sync(~0, val, 16), val);
    val = cuda_max(__shfl_xor_sync(~0, val, 8), val);
    val = cuda_max(__shfl_xor_sync(~0, val, 4), val);
    val = cuda_max(__shfl_xor_sync(~0, val, 2), val);
    val = cuda_max(__shfl_xor_sync(~0, val, 1), val);
    return val;
}

template <typename QUANTIZE_T>
struct QuantMaxValue;

template <>
struct QuantMaxValue<int8_t>
{
    static constexpr float value = 127.0f;
};

template <>
struct QuantMaxValue<__nv_fp8_e4m3>
{
    static constexpr float value = 448.0f;
};

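// Quantization preprocess: each warp quantizes (WARP_SIZE - 1) * elts_per_thread input elements with a
// single shared scale, and the last lane's 16-byte slot carries that scale, so every warp writes
// WARP_SIZE * elts_per_thread quantized elements into the communication buffer.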
template <int32_t RANKS_PER_NODE, typename T_IN, typename T_OUT>
__global__ void lowPrecisionPreprocessKernel(
    const T_IN* __restrict__ input, size_t elts_per_rank_in, size_t elts_per_rank_out, T_OUT* __restrict__ output)
{
    constexpr float QUANT_MAX = QuantMaxValue<T_OUT>::value;
    constexpr int32_t output_rounds = sizeof(T_IN) / sizeof(T_OUT);
    constexpr int32_t elts_per_thread = sizeof(int4) / sizeof(T_OUT);
    constexpr int32_t elts_per_round = sizeof(int4) / sizeof(T_IN);
    constexpr int32_t elts_per_warp_per_round = elts_per_round * WARP_SIZE;
    constexpr int32_t NUM_ELTS_PER_WARP_IN = (WARP_SIZE - 1) * elts_per_thread;
    constexpr int32_t NUM_ELTS_PER_WARP_OUT = WARP_SIZE * elts_per_thread;
    using PackedInputType = typename PackedOn16Bytes<T_IN>::Type;
    using PackedOutputType = typename PackedOnNum<T_OUT, elts_per_round>::Type;

    using PackedInputIntType = typename LowPrecisionIntPack<sizeof(int4)>::Type;
    using PackedOutputIntType = typename LowPrecisionIntPack<elts_per_round>::Type;

    const int32_t target_rank = blockIdx.x / (gridDim.x / RANKS_PER_NODE);
    const int32_t local_bid = blockIdx.x % (gridDim.x / RANKS_PER_NODE);

    input += elts_per_rank_in * target_rank;
    output += elts_per_rank_out * target_rank;

    const int32_t lane_id = threadIdx.x % WARP_SIZE;
    const int32_t wid = threadIdx.x / WARP_SIZE;

    PackedInputType vals[output_rounds];
    size_t start_in = NUM_ELTS_PER_WARP_IN * LP_ALLREDUCE_WARP_NUM_PER_BLOCK * local_bid + wid * NUM_ELTS_PER_WARP_IN;
    size_t start_out
        = NUM_ELTS_PER_WARP_OUT * LP_ALLREDUCE_WARP_NUM_PER_BLOCK * local_bid + wid * NUM_ELTS_PER_WARP_OUT;

#pragma unroll
    for (int32_t i = 0; i < output_rounds; ++i)
    {
        int32_t local_offset = lane_id * elts_per_round + elts_per_warp_per_round * i;
        int32_t global_offset = start_in + local_offset;
        if (local_offset < NUM_ELTS_PER_WARP_IN && global_offset < elts_per_rank_in)
        {
            vals[i].packed = *reinterpret_cast<PackedInputIntType const*>(input + start_in + local_offset);
        }
        else
        {
#pragma unroll
            for (int j = 0; j < elts_per_round; j++)
            {
                vals[i].unpacked[j] = 0.0f;
            }
        }
    }

    // Calculate scaling factor
    float scalar = 0;
    for (int32_t i = 0; i < output_rounds; ++i)
    {
#pragma unroll
        for (int32_t j = 0; j < elts_per_round; ++j)
        {
            scalar = cuda_max(cuda_abs((float) (vals[i].unpacked[j])), scalar);
        }
    }

    scalar = warp_reduce_max(scalar);
    if (scalar != 0.0f)
    {
        scalar = QUANT_MAX / scalar;
    }

    // Quantize and write output
    PackedOutputType output_vals[output_rounds];
    for (int32_t i = 0; i < output_rounds; ++i)
    {
        int32_t local_write_offset = lane_id * elts_per_round + elts_per_warp_per_round * i;
        if (local_write_offset < NUM_ELTS_PER_WARP_IN)
        {
#pragma unroll
            for (int32_t j = 0; j < elts_per_round; ++j)
            {
                float out_val = vals[i].unpacked[j];
                if (scalar != 0.0f)
                {
                    out_val *= scalar;
                }
                output_vals[i].unpacked[j] = static_cast<T_OUT>(out_val);
            }
        }
        else if (local_write_offset == NUM_ELTS_PER_WARP_IN)
        {
            *(reinterpret_cast<float*>(&output_vals[i])) = scalar;
        }
    }

#pragma unroll
    for (int32_t i = 0; i < output_rounds; ++i)
    {
        int32_t local_write_offset = lane_id * elts_per_round + elts_per_warp_per_round * i;
        *reinterpret_cast<PackedOutputIntType*>(output + start_out + local_write_offset) = output_vals[i].packed;
    }
}

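// First stage of the two-shot allreduce: for its own slice, each rank loads the quantized chunks from
// every peer, dequantizes them with the per-warp scales, accumulates in fp32, re-quantizes with a fresh
// scale, and writes the partial result back into its slot of input[0].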
template <int32_t RANKS_PER_NODE, typename T_IN>
__device__ void lowPrecisionTwoShotFirstStageKernel(int32_t myrank, size_t elts_per_rank, T_IN** input, float* smem)
{
    constexpr float QUANT_MAX = QuantMaxValue<T_IN>::value;
    constexpr int32_t elts_per_thread = sizeof(int4) / sizeof(T_IN);
    constexpr int32_t NUM_ELTS_PER_WARP_IN = WARP_SIZE * elts_per_thread;

    const int32_t lane_id = threadIdx.x % WARP_SIZE;
    const int32_t bid = blockIdx.x;
    const int32_t wid = threadIdx.x / WARP_SIZE;
    const size_t in_start
        = (bid * LP_ALLREDUCE_WARP_NUM_PER_BLOCK + wid) * NUM_ELTS_PER_WARP_IN + lane_id * elts_per_thread;

    // Packed data type for comms
    using PackedType = typename PackedOn16Bytes<T_IN>::Type;
    float* smem_scalar_ptr = &smem[RANKS_PER_NODE * wid];
    const size_t rank_offset = elts_per_rank * myrank;

    for (size_t local_offset = in_start; local_offset < elts_per_rank;
         local_offset += gridDim.x * blockDim.x * elts_per_thread)
    {
        float sums[elts_per_thread];
#pragma unroll
        for (int32_t ii = 0; ii < elts_per_thread; ++ii)
        {
            sums[ii] = 0;
        }

        // Read, dequantize and reduce sum
        {
            PackedType vals[RANKS_PER_NODE];
#pragma unroll
            for (int32_t ii = 0; ii < RANKS_PER_NODE; ++ii)
            {
                vals[ii].packed = *reinterpret_cast<int4 const*>(&input[ii][local_offset + rank_offset]);
            }

            if (lane_id == (WARP_SIZE - 1))
            {
#pragma unroll
                for (int32_t ii = 0; ii < RANKS_PER_NODE; ++ii)
                {
                    float* tmp_scalar = (float*) (&(vals[ii]));
                    smem_scalar_ptr[ii] = tmp_scalar[0];
                }
            }
            __syncwarp();

            if (lane_id < (WARP_SIZE - 1))
            {
                // Sum the values from the different ranks
                for (int32_t ii = 0; ii < RANKS_PER_NODE; ++ii)
                {
#pragma unroll
                    for (int32_t jj = 0; jj < elts_per_thread; ++jj)
                    {
                        if (smem_scalar_ptr[ii] != 0)
                        {
                            sums[jj] += (float) (vals[ii].unpacked[jj]) / smem_scalar_ptr[ii];
                        }
                        else
                        {
                            sums[jj] += (float) (vals[ii].unpacked[jj]);
                        }
                    }
                }
            }
        }

        // Quantize and write back results
        {
            float scalar = 0;

            if (lane_id < (WARP_SIZE - 1))
            {
#pragma unroll
                for (int32_t ii = 0; ii < elts_per_thread; ++ii)
                {
                    scalar = cuda_max(cuda_abs(sums[ii]), scalar);
                }
            }

            scalar = warp_reduce_max(scalar);

            if (scalar != 0.0f)
            {
                scalar = (QUANT_MAX) / scalar;
            }

            PackedType tmp_val;
            if (lane_id < (WARP_SIZE - 1))
            {
#pragma unroll
                for (int32_t ii = 0; ii < elts_per_thread; ++ii)
                {
                    float tmp = sums[ii];
                    if (scalar != 0.0f)
                    {
                        tmp *= scalar;
                    }
                    tmp_val.unpacked[ii] = static_cast<T_IN>(tmp);
                }
            }
            else
            {
                ((float*) (&tmp_val))[0] = scalar;
            }

            *reinterpret_cast<int4*>(input[0] + local_offset + rank_offset) = tmp_val.packed;
        }
    }
}

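// Second stage: gather the reduced, quantized slices from all ranks (round-robin via dst_rank),
// dequantize with the per-warp scale held in the last lane's slot, and write the result to the final
// output buffer.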
template <int32_t RANKS_PER_NODE, typename T_IN, typename T_OUT>
__device__ void lowPrecisionTwoShotSecondStageKernel(size_t input_elts_per_rank, size_t output_elts_per_rank,
    T_IN** input, T_OUT* output, float* smem, int32_t* dst_rank)
{
    constexpr int32_t elts_per_thread = sizeof(int4) / sizeof(T_IN);
    constexpr int32_t output_rounds = sizeof(T_OUT) / sizeof(T_IN);
    constexpr int32_t depack_num = elts_per_thread / output_rounds;

    constexpr int32_t NUM_ELTS_PER_WARP_IN = WARP_SIZE * elts_per_thread;
    constexpr int32_t NUM_ELTS_PER_WARP_OUT = (WARP_SIZE - 1) * elts_per_thread;

    const int32_t lane_id = threadIdx.x % WARP_SIZE;
    const int32_t bid = blockIdx.x;
    const int32_t wid = threadIdx.x / WARP_SIZE;

    const size_t in_start
        = (bid * LP_ALLREDUCE_WARP_NUM_PER_BLOCK + wid) * NUM_ELTS_PER_WARP_IN + lane_id * elts_per_thread;
    const size_t out_start
        = (bid * LP_ALLREDUCE_WARP_NUM_PER_BLOCK + wid) * NUM_ELTS_PER_WARP_OUT + lane_id * elts_per_thread;

    float* smem_scalar_ptr = &smem[RANKS_PER_NODE * wid];

    using PackedInType = typename PackedOn16Bytes<T_IN>::Type;
    using PackedOutType = typename PackedOn16Bytes<T_OUT>::Type;

    PackedInType vals[RANKS_PER_NODE];

    for (size_t input_offset = in_start, output_offset = out_start; input_offset < input_elts_per_rank;
         input_offset += gridDim.x * LP_ALLREDUCE_WARP_NUM_PER_BLOCK * NUM_ELTS_PER_WARP_IN,
         output_offset += gridDim.x * LP_ALLREDUCE_WARP_NUM_PER_BLOCK * NUM_ELTS_PER_WARP_OUT)
    {
#pragma unroll
        for (int32_t ii = 0; ii < RANKS_PER_NODE; ++ii)
        {
            size_t tmp_offset = dst_rank[ii] * input_elts_per_rank + input_offset;
            if (input_offset < input_elts_per_rank)
            {
                vals[ii].packed = *reinterpret_cast<int4 const*>(&input[ii][tmp_offset]);
            }
        }

        if (lane_id == (WARP_SIZE - 1))
        {
#pragma unroll
            for (int32_t ii = 0; ii < RANKS_PER_NODE; ++ii)
            {
                float* tmp_scalar = (float*) (&(vals[ii]));
                smem_scalar_ptr[ii] = tmp_scalar[0];
            }
        }
        __syncwarp();

        for (int32_t ii = 0; ii < RANKS_PER_NODE; ++ii)
        {
            float scale = smem_scalar_ptr[ii];
            size_t tmp_output_offset = dst_rank[ii] * output_elts_per_rank + output_offset;

            if (output_offset < output_elts_per_rank)
            {
                if (lane_id < (WARP_SIZE - 1))
                {
                    for (int32_t jj = 0; jj < output_rounds; ++jj)
                    {
                        PackedOutType tmp_output;

#pragma unroll
                        for (int32_t kk = 0; kk < depack_num; kk++)
                        {
                            float tmp = (float) (vals[ii].unpacked[kk + jj * depack_num]);
                            if (scale != 0.0f)
                            {
                                tmp /= scale;
                            }
                            tmp_output.unpacked[kk] = static_cast<T_OUT>(tmp);
                        }

                        *reinterpret_cast<PackedOutType*>(output + tmp_output_offset + jj * depack_num) = tmp_output;
                    }
                }
            }
        }
    }
}

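// Top-level two-shot kernel: global barrier, reduce-scatter (first stage), a release/acquire handshake
// per block so peers see the partial results, then allgather + dequantize (second stage).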
template <typename T, typename QUANT_T, int32_t RANKS_PER_NODE>
static __global__ void lowPrecisionTwoShotAllReduceKernel(LowPrecisionAllReduceParams params)
{
    const int32_t bidx = blockIdx.x;
    const int32_t tidx = threadIdx.x;

    extern __shared__ float smem[];

    multi_gpu_barrier(params.peer_barrier_ptrs_in, params.barrier_flag, params.local_rank, RANKS_PER_NODE, tidx, bidx);

    // The source pointers. Distributed round-robin for the different warps.
    QUANT_T* src_d[RANKS_PER_NODE];
    // The destination ranks for round-robin gathering
    int32_t dst_rank[RANKS_PER_NODE];

#pragma unroll
    for (int32_t ii = 0; ii < RANKS_PER_NODE; ++ii)
    {
        int32_t rank = (params.local_rank + ii) % RANKS_PER_NODE;
        src_d[ii] = reinterpret_cast<QUANT_T*>(params.peer_comm_buffer_ptrs[rank]);
        dst_rank[ii] = rank;
    }

    lowPrecisionTwoShotFirstStageKernel<RANKS_PER_NODE, QUANT_T>(
        params.local_rank, params.buffer_elts_per_rank, src_d, smem);

    // Sync threads to make sure all block threads have the sums
    __syncthreads();

    // Barriers among the blocks with the same idx (release-acquire semantics)
    if (tidx < RANKS_PER_NODE)
    {
        // All blocks notify the other ranks.
        uint32_t flag_block_offset = RANKS_PER_NODE + bidx * RANKS_PER_NODE;
        lp_allreduce_st_flag_release(
            params.barrier_flag, params.peer_barrier_ptrs_in[tidx] + flag_block_offset + params.local_rank);

        // Busy-wait until all ranks are ready.
        uint64_t rank_barrier = 0;
        uint64_t* peer_barrier_d = params.peer_barrier_ptrs_in[params.local_rank] + flag_block_offset + tidx;
        do
        {
            lp_allreduce_ld_flag_acquire(rank_barrier, peer_barrier_d);
        } while (rank_barrier != params.barrier_flag);
    }

    __syncthreads();

    // Do allgather and dequantize
    float* smem_allgather = smem + (RANKS_PER_NODE * LP_ALLREDUCE_WARP_NUM_PER_BLOCK);
    lowPrecisionTwoShotSecondStageKernel<RANKS_PER_NODE, QUANT_T, T>(params.buffer_elts_per_rank, params.elts_per_rank,
        src_d, reinterpret_cast<T*>(params.local_output_buffer_ptr), smem_allgather, dst_rank);
}

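// Grid-stride variant of the preprocess kernel used by the hierarchical path: it quantizes the whole
// local input rather than a fixed per-rank slice.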
template <typename T_IN, typename T_OUT>
__global__ void lowPrecisionHierPreprocessKernel(
    const T_IN* __restrict__ input, size_t n_in, T_OUT* __restrict__ output)
{
    constexpr float QUANT_MAX = QuantMaxValue<T_OUT>::value;
    constexpr int32_t output_rounds = sizeof(T_IN) / sizeof(T_OUT);
    constexpr int32_t elts_per_thread = sizeof(int4) / sizeof(T_OUT);
    constexpr int32_t elts_per_round = sizeof(int4) / sizeof(T_IN);
    constexpr int32_t elts_per_warp_per_round = elts_per_round * WARP_SIZE;
    constexpr int32_t NUM_ELTS_PER_WARP_IN = (WARP_SIZE - 1) * elts_per_thread;
    constexpr int32_t NUM_ELTS_PER_WARP_OUT = WARP_SIZE * elts_per_thread;

    using PackedInputType = typename PackedOn16Bytes<T_IN>::Type;
    using PackedOutputType = typename PackedOnNum<T_OUT, elts_per_round>::Type;
    using PackedInputIntType = typename LowPrecisionIntPack<16>::Type;
    using PackedOutputIntType = typename LowPrecisionIntPack<elts_per_round>::Type;

    const int32_t lane_id = threadIdx.x % WARP_SIZE;
    const int32_t wid = threadIdx.x / WARP_SIZE;
    PackedInputType vals[output_rounds];

    for (size_t start = blockIdx.x * LP_ALLREDUCE_WARP_NUM_PER_BLOCK + wid; start * NUM_ELTS_PER_WARP_IN < n_in;
         start += LP_ALLREDUCE_WARP_NUM_PER_BLOCK * gridDim.x)
    {
        int32_t read_rounds = 0;
        int32_t local_n_in = (n_in - start * NUM_ELTS_PER_WARP_IN) > NUM_ELTS_PER_WARP_IN
            ? NUM_ELTS_PER_WARP_IN
            : (n_in - start * NUM_ELTS_PER_WARP_IN);
        if (local_n_in <= 0)
        {
            return;
        }

#pragma unroll
        for (int32_t i = 0; i < output_rounds; ++i)
        {
            int32_t local_offset = lane_id * elts_per_round + elts_per_warp_per_round * i;
            if (local_offset < local_n_in)
            {
                vals[i].packed
                    = *reinterpret_cast<PackedInputIntType const*>(input + start * NUM_ELTS_PER_WARP_IN + local_offset);
                read_rounds++;
            }
            else
            {
#pragma unroll
                for (int j = 0; j < elts_per_round; j++)
                {
                    vals[i].unpacked[j] = 0.0f;
                }
            }
        }

        // Calculate scaling factor
        float scalar = 0;
        for (int32_t i = 0; i < read_rounds; ++i)
        {
#pragma unroll
            for (int32_t j = 0; j < elts_per_round; ++j)
            {
                scalar = cuda_max(cuda_abs((float) (vals[i].unpacked[j])), scalar);
            }
        }

        scalar = warp_reduce_max(scalar);
        if (scalar != 0.0f)
        {
            scalar = QUANT_MAX / scalar;
        }

        // Quantize and write output
        PackedOutputType output_vals[output_rounds];
        for (int32_t i = 0; i < output_rounds; ++i)
        {
            int32_t local_write_offset = lane_id * elts_per_round + elts_per_warp_per_round * i;
            if (local_write_offset < NUM_ELTS_PER_WARP_IN)
            {
#pragma unroll
                for (int32_t j = 0; j < elts_per_round; ++j)
                {
                    float out_val = vals[i].unpacked[j];
                    if (scalar != 0.0f)
                    {
                        out_val *= scalar;
                    }
                    output_vals[i].unpacked[j] = static_cast<T_OUT>(out_val);
                }
            }
            else if (local_write_offset == NUM_ELTS_PER_WARP_IN)
            {
                *(reinterpret_cast<float*>(&output_vals[i])) = scalar;
            }
        }

#pragma unroll
        for (int32_t i = 0; i < output_rounds; ++i)
        {
            int32_t local_write_offset = lane_id * elts_per_round + elts_per_warp_per_round * i;
            *reinterpret_cast<PackedOutputIntType*>(output + start * NUM_ELTS_PER_WARP_OUT + local_write_offset)
                = output_vals[i].packed;
        }
    }
}

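// Device helper for the hierarchical kernel: dequantize, reduce and re-quantize one ELTS_PER_BLOCK
// chunk, carrying each warp's scale in the last lane's 16-byte slot.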
template <int32_t RANKS_PER_NODE, typename T>
__device__ void hierReduceWithQdq(
    LowPrecisionAllReduceParams params, T** input, T* output, int64_t start_offset, int64_t length, float* smem)
{
    // Constants
    constexpr float QUANT_MAX = QuantMaxValue<T>::value;
    constexpr int32_t elts_per_thread = sizeof(int4) / sizeof(T);

    // Thread indices
    const int32_t lane_id = threadIdx.x % WARP_SIZE;
    const int32_t wid = threadIdx.x / WARP_SIZE;
    const size_t start = threadIdx.x * elts_per_thread;

    // Packed data type for comms
    using PackedType = typename PackedOn16Bytes<T>::Type;
    float* smem_scalar_ptr = &smem[RANKS_PER_NODE * wid];

    for (size_t index = start; index < length; index += LP_ALLREDUCE_DEFAULT_BLOCK_SIZE * elts_per_thread)
    {
        // Initialize sum array
        float sums[elts_per_thread];
#pragma unroll
        for (int32_t ii = 0; ii < elts_per_thread; ++ii)
        {
            sums[ii] = 0;
        }

        // Load values from different ranks and dequantize
        {
            PackedType vals[RANKS_PER_NODE];

#pragma unroll
            for (int32_t ii = 0; ii < RANKS_PER_NODE; ++ii)
            {
                vals[ii].packed = *reinterpret_cast<int4 const*>(&input[ii][start_offset + index]);
            }

            if (lane_id == (WARP_SIZE - 1))
            {
#pragma unroll
                for (int32_t ii = 0; ii < RANKS_PER_NODE; ++ii)
                {
                    float* tmp_scalar = (float*) (&(vals[ii]));
                    smem_scalar_ptr[ii] = tmp_scalar[0];
                }
            }
            __syncwarp();

            if (lane_id < (WARP_SIZE - 1))
            {
                for (int32_t ii = 0; ii < RANKS_PER_NODE; ++ii)
                {
#pragma unroll
                    for (int32_t jj = 0; jj < elts_per_thread; ++jj)
                    {
                        if (smem_scalar_ptr[ii] != 0)
                        {
                            sums[jj] += (float) (vals[ii].unpacked[jj]) / smem_scalar_ptr[ii];
                        }
                        else
                        {
                            sums[jj] += (float) (vals[ii].unpacked[jj]);
                        }
                    }
                }
            }
        }

        // Quantize results and write output
        {
            float scalar = 0;

            if (lane_id < (WARP_SIZE - 1))
            {
#pragma unroll
                for (int32_t ii = 0; ii < elts_per_thread; ++ii)
                {
                    scalar = cuda_max(cuda_abs(sums[ii]), scalar);
                }
            }

            scalar = warp_reduce_max(scalar);

            if (scalar != 0.0f)
            {
                scalar = QUANT_MAX / scalar;
            }

            PackedType tmp_val;

            if (lane_id < (WARP_SIZE - 1))
            {
#pragma unroll
                for (int32_t ii = 0; ii < elts_per_thread; ++ii)
                {
                    float tmp = sums[ii];
                    if (scalar != 0.0f)
                    {
                        tmp *= scalar;
                    }
                    tmp_val.unpacked[ii] = (T) tmp;
                }
            }
            else
            {
                ((float*) (&tmp_val))[0] = scalar;
            }

            *reinterpret_cast<int4*>(&output[threadIdx.x * elts_per_thread]) = tmp_val.packed;
        }
    }
}

template <int32_t RANKS_PER_NODE, typename T_IN, typename T_OUT>
__device__ void hierAllgatherWithDq(LowPrecisionAllReduceParams params, T_IN** input, T_OUT* output,
    size_t input_offset, int32_t global_iter, int32_t length, int32_t blocks_per_stage, float* smem)
{
    // Constants and thread indices
    constexpr int32_t elts_per_thread = sizeof(int4) / sizeof(T_IN);
    constexpr int32_t output_rounds = sizeof(T_OUT) / sizeof(T_IN);
    constexpr int32_t depack_num = elts_per_thread / output_rounds;

    const int32_t bidx = blockIdx.x;
    const int32_t tidx = threadIdx.x;
    const int32_t lane_id = tidx % WARP_SIZE;
    const int32_t wid = tidx / WARP_SIZE;
    const int32_t start = tidx * elts_per_thread;

    const int32_t OUTPUT_ELEMENT_PER_WARP = (WARP_SIZE - 1) * elts_per_thread;
    const int32_t OUTPUT_ELEMENT_PER_BLOCK = OUTPUT_ELEMENT_PER_WARP * LP_ALLREDUCE_WARP_NUM_PER_BLOCK;

    using PackedType = typename PackedOn16Bytes<T_IN>::Type;
    using PackedOutputType = typename PackedOn16Bytes<T_OUT>::Type;
    const int32_t numa_rank = params.numa_rank;

    PackedType vals[RANKS_PER_NODE];
    float* smem_scalar_ptr = &smem[RANKS_PER_NODE * wid];

    for (size_t index = start; index < length; index += LP_ALLREDUCE_DEFAULT_BLOCK_SIZE * elts_per_thread)
    {
#pragma unroll
        for (int32_t ii = 0; ii < RANKS_PER_NODE; ++ii)
        {
            vals[ii].packed = *reinterpret_cast<int4 const*>(&input[ii][input_offset + index]);
        }

#pragma unroll
        for (int32_t ii = 0; ii < RANKS_PER_NODE; ++ii)
        {
            if (lane_id == WARP_SIZE - 1)
            {
                float* tmp_scalar = (float*) (&(vals[ii]));
                smem_scalar_ptr[ii] = tmp_scalar[0];
            }
        }
        __syncwarp();

        const size_t elts_total = params.elts_total;

        for (int32_t ii = 0; ii < RANKS_PER_NODE; ++ii)
        {
            float scale = smem_scalar_ptr[ii];
            size_t offset_global = global_iter * blocks_per_stage * RANKS_PER_NODE * OUTPUT_ELEMENT_PER_BLOCK;

            int32_t tmp_rank = (numa_rank + ii) % RANKS_PER_NODE;
            size_t offset_local = offset_global + (bidx % blocks_per_stage) * RANKS_PER_NODE * OUTPUT_ELEMENT_PER_BLOCK
                + tmp_rank * OUTPUT_ELEMENT_PER_BLOCK + wid * OUTPUT_ELEMENT_PER_WARP + lane_id * elts_per_thread;
            bool need_write = elts_total > offset_local;

            if (lane_id < WARP_SIZE - 1 && need_write)
            {
                for (int32_t jj = 0; jj < output_rounds; ++jj)
                {
                    PackedOutputType tmp_output;

#pragma unroll
                    for (int32_t kk = 0; kk < depack_num; kk++)
                    {
                        float tmp = (float) (vals[ii].unpacked[kk + jj * depack_num]);
                        if (scale != 0)
                        {
                            tmp /= scale;
                        }
                        ((T_OUT*) (&tmp_output))[kk] = (T_OUT) tmp;
                    }

                    *reinterpret_cast<int4*>(&reinterpret_cast<T_OUT*>(output)[offset_local + jj * depack_num])
                        = *reinterpret_cast<int4*>(&tmp_output);
                }
            }
        }
    }
}

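// Hierarchical (NUMA-aware) two-shot allreduce: the grid is split into three block groups that form a
// pipeline over LP_ALLREDUCE_BUFFER_CHUNKS chunks: reduce-scatter inside the NUMA domain, a partial
// allreduce across NUMA domains, and an allgather inside the NUMA domain, synchronized through the
// send/ack flag pairs set up in deserialize_hier().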
template <typename T, typename QUANT_T, int RANKS_PER_NODE>
static __global__ __launch_bounds__(512, 1) void lowPrecisionTwoShotHierAllReduceKernel(
    LowPrecisionAllReduceParams params)
{

    // The block index.
    int const bidx = blockIdx.x;
    // The thread index within the block.
    int const tidx = threadIdx.x;
    // The number of blocks.
    int const block_num = gridDim.x;
    int const duplicate = LP_ALLREDUCE_BUFFER_DUPLICATE;
    // This algorithm has 3 stages, so each stage gets 1/3 of the blocks.
    int const block_num_per_stage = block_num / LP_ALLREDUCE_HIER_STAGE_NUM;

    // The number of elements packed into one for comms
    constexpr int elts_per_thread = sizeof(int4) / sizeof(QUANT_T);
    constexpr int ELTS_PER_BLOCK = elts_per_thread * LP_ALLREDUCE_DEFAULT_BLOCK_SIZE;

    extern __shared__ float smem[];

    multi_gpu_barrier(params.peer_barrier_ptrs_in, params.barrier_flag, params.local_rank, RANKS_PER_NODE, tidx, bidx);
    // Packed data type for comms
    using PackedType = typename PackedOn16Bytes<QUANT_T>::Type;

    if (bidx < block_num_per_stage)
    {
        // reduce-scatter inside NUMA
        int local_bid = bidx % block_num_per_stage;
        uint64_t send_flag = *params.rs_send_flags[local_bid];
        QUANT_T* src_d[LP_ALLREDUCE_RANKS_PER_NUMA];
        QUANT_T* dst = reinterpret_cast<QUANT_T*>(params.rs_buffers[local_bid]);

        // The destination ranks for round-robin gathering
#pragma unroll
        for (int ii = 0; ii < LP_ALLREDUCE_RANKS_PER_NUMA; ++ii)
        {
            int numa_rank = (params.numa_rank + ii) % LP_ALLREDUCE_RANKS_PER_NUMA;
            src_d[ii] = reinterpret_cast<QUANT_T*>(params.inputs_inside_numa[numa_rank]);
        }

        int32_t index = 0;
        while (index < params.num_rounds_fence)
        {
            if (tidx < LP_ALLREDUCE_NUMA_NUM)
            {
                wait_send_peer(send_flag, params.rs_ack_flags[local_bid] + tidx);
            }
            __syncthreads();
            int const processed = index * duplicate;
            int const remaining = params.num_rounds - processed;
            int const transfer_times = min(duplicate, remaining);

            for (int i = 0; i < transfer_times; ++i)
            {
                int const global_iter = index * duplicate + i;

                int const chunk_idx = send_flag % LP_ALLREDUCE_BUFFER_CHUNKS;
                int const dst_offset = chunk_idx * ELTS_PER_BLOCK * duplicate + ELTS_PER_BLOCK * i;
                int const global_per_tier = block_num_per_stage * LP_ALLREDUCE_RANKS_PER_NUMA * ELTS_PER_BLOCK;
                int const rank_offset = LP_ALLREDUCE_RANKS_PER_NUMA * ELTS_PER_BLOCK;
                const size_t global_offset
                    = global_iter * global_per_tier + local_bid * rank_offset + params.numa_rank * ELTS_PER_BLOCK;
                hierReduceWithQdq<LP_ALLREDUCE_RANKS_PER_NUMA, QUANT_T>(
                    params, src_d, dst + dst_offset, global_offset, ELTS_PER_BLOCK, smem);
            }

            __syncthreads();
            send_flag++;
            if (tidx == 0)
            {
                __threadfence_system();
                notify_peer_with_value(params.rs_notify_remote_flags[local_bid], send_flag);
                notify_peer_with_value(params.rs_notify_local_flags[local_bid], send_flag);
            }
            index++;
        }
        if (tidx == 0)
        {
            *params.rs_send_flags[local_bid] = send_flag;
        }
        return;
    }

    else if (bidx >= block_num_per_stage && bidx < block_num_per_stage * 2)
    {
        // partial allreduce across NUMA domains
        int local_bid = bidx % block_num_per_stage;
        uint64_t send_flag = *params.ar_send_flags[local_bid];
        // One source per NUMA domain: the local reduce-scatter result and the cross-NUMA peer buffer.
        QUANT_T* src_d[LP_ALLREDUCE_NUMA_NUM];
        QUANT_T* dst = reinterpret_cast<QUANT_T*>(params.ar_buffers[local_bid]);
        src_d[0] = reinterpret_cast<QUANT_T*>(params.rs_buffers[local_bid]);
        src_d[1] = reinterpret_cast<QUANT_T*>(params.ar_peer_buffers_cross_numa[local_bid]);

        int32_t index = 0;
        while (index < params.num_rounds_fence)
        {
            if (tidx == 0)
            {
                wait_recv_peer(send_flag, params.rs_notify_local_flags[local_bid]);
                wait_recv_peer(send_flag, params.ar_ack_peer_rs_flags[local_bid]);
                wait_send_peer(send_flag, params.ar_ack_flags[local_bid]);
            }
            __syncthreads();

            int const processed = index * duplicate;
            int const remaining = params.num_rounds - processed;
            int const transfer_times = min(duplicate, remaining);

            int const chunk_idx = send_flag % LP_ALLREDUCE_BUFFER_CHUNKS;
            int const base_offset = chunk_idx * ELTS_PER_BLOCK * duplicate;

            for (int i = 0; i < transfer_times; ++i)
            {
                int const offset = base_offset + i * ELTS_PER_BLOCK;
                hierReduceWithQdq<LP_ALLREDUCE_NUMA_NUM, QUANT_T>(
                    params, src_d, dst + offset, offset, ELTS_PER_BLOCK, smem);
            }
            __syncthreads();

            send_flag++;
            if (tidx == 0)
            {
                __threadfence_system();
                notify_peer_with_value(params.ar_notify_rs_remote_flags[local_bid], send_flag);
                notify_peer_with_value(params.ar_notify_rs_local_flags[local_bid], send_flag);
                notify_peer_with_value(params.ar_notify_ag_flags[local_bid], send_flag);
            }
            index++;
        }
        if (tidx == 0)
        {
            *params.ar_send_flags[local_bid] = send_flag;
        }
        return;
    }
    else if (bidx >= block_num_per_stage * 2 && bidx < block_num_per_stage * 3)
    {
        // allgather inside NUMA
        int local_bid = bidx % block_num_per_stage;
        uint64_t send_flag = *params.ag_send_flags[local_bid];
        QUANT_T* src_d[LP_ALLREDUCE_RANKS_PER_NUMA];
        T* dst = reinterpret_cast<T*>(params.local_output_buffer_ptr);
#pragma unroll
        for (int ii = 0; ii < LP_ALLREDUCE_RANKS_PER_NUMA; ++ii)
        {
            int numa_rank = (params.numa_rank + ii) % LP_ALLREDUCE_RANKS_PER_NUMA;

            src_d[ii] = reinterpret_cast<QUANT_T*>(params.ag_peer_buffers_inside_numa[local_bid * 4 + numa_rank]);
        }

        int32_t index = 0;
        while (index < params.num_rounds_fence)
        {
            if (tidx == 0)
            {
                wait_recv_peer(send_flag, params.ar_notify_ag_flags[local_bid]);
            }

            __syncthreads();
            if (tidx < LP_ALLREDUCE_RANKS_PER_NUMA)
            {

                notify_peer_with_value_relax(
                    params.ag_notify_peer_inside_numa_flags[local_bid * LP_ALLREDUCE_RANKS_PER_NUMA + tidx],
                    send_flag + 1);
                wait_recv_peer(send_flag, params.ag_ack_peer_inside_numa_flags[local_bid] + tidx);
            }
            __syncthreads();

            int const processed = index * duplicate;
            int const remaining = params.num_rounds - processed;
            int const transfer_times = min(duplicate, remaining);

            int const chunk_idx = send_flag % LP_ALLREDUCE_BUFFER_CHUNKS;
            int const base_offset = chunk_idx * ELTS_PER_BLOCK * duplicate;

            for (int i = 0; i < transfer_times; ++i)
            {

                int const global_iter = processed + i;
                const size_t curr_offset = base_offset + i * ELTS_PER_BLOCK;

                hierAllgatherWithDq<LP_ALLREDUCE_RANKS_PER_NUMA, QUANT_T, T>(
                    params, src_d, dst, curr_offset, global_iter, ELTS_PER_BLOCK, block_num_per_stage, smem);
            }

            __syncthreads();

            send_flag++;
            if (tidx == 0)
            {
                notify_peer_with_value_relax(params.ar_ack_flags[local_bid], send_flag);
            }
            index++;
        }
        if (tidx == 0)
        {
            *params.ag_send_flags[local_bid] = send_flag;
        }
    }
    else
    {
        return;
    }
}

template <typename T, typename QUANT_T, int RANKS_PER_NODE>
void lowPrecisionAllReduceDispatchRanksPerNode(kernels::LowPrecisionAllReduceParams& params, cudaStream_t stream)
{
    constexpr int qtype_elts_per_load = LP_ALLREDUCE_BYTES_PER_LOAD / sizeof(QUANT_T);
    constexpr int elts_per_block = qtype_elts_per_load * (LP_ALLREDUCE_WARPSIZE - 1) * LP_ALLREDUCE_WARP_NUM_PER_BLOCK;
    constexpr int elts_per_block_with_scale = qtype_elts_per_load * LP_ALLREDUCE_DEFAULT_BLOCK_SIZE;
    if (RANKS_PER_NODE <= 4)
    {

        int blocks_per_grid = LP_ALLREDUCE_MAX_BLOCKS * 2, threads_per_block = LP_ALLREDUCE_DEFAULT_BLOCK_SIZE;

        params.elts_per_rank = params.elts_total / RANKS_PER_NODE;
        params.rank_offset = params.rank * params.elts_per_rank;
        params.elts_per_block = elts_per_block;

        size_t num_rounds_per_rank = (params.elts_per_rank - 1) / elts_per_block + 1;
        size_t my_rank = params.local_rank;

        params.buffer_offset = my_rank * elts_per_block_with_scale * num_rounds_per_rank;
        params.buffer_elts_per_rank = elts_per_block_with_scale * num_rounds_per_rank;
        lowPrecisionPreprocessKernel<RANKS_PER_NODE, T, QUANT_T>
            <<<num_rounds_per_rank * RANKS_PER_NODE, threads_per_block, 0, stream>>>(
                (T const*) params.local_input_buffer_ptr, params.elts_per_rank, params.buffer_elts_per_rank,
                (QUANT_T*) params.peer_comm_buffer_ptrs[my_rank]);
        lowPrecisionTwoShotAllReduceKernel<T, QUANT_T, RANKS_PER_NODE><<<blocks_per_grid, threads_per_block,
            (LP_ALLREDUCE_WARP_NUM_PER_BLOCK * RANKS_PER_NODE) * sizeof(float) * 2, stream>>>(params);
    }
    else
    {
        int blocks_per_grid = LP_ALLREDUCE_MAX_BLOCKS, threads_per_block = LP_ALLREDUCE_DEFAULT_BLOCK_SIZE;
        params.num_rounds = (((params.elts_total - 1) / elts_per_block + 1) - 1) / LP_ALLREDUCE_MAX_RANKS_PER_NUMA
                / LP_ALLREDUCE_MAX_BLOCKS
            + 1;
        params.num_rounds_fence = (params.num_rounds - 1) / LP_ALLREDUCE_BUFFER_DUPLICATE + 1;
        blocks_per_grid = params.num_rounds < LP_ALLREDUCE_MAX_BLOCKS ? params.num_rounds : blocks_per_grid;

        size_t preprocess_blocks_per_grid = params.num_rounds * LP_ALLREDUCE_MAX_RANKS_PER_NUMA * blocks_per_grid;
        size_t my_rank = params.local_rank;
        blocks_per_grid *= LP_ALLREDUCE_HIER_STAGE_NUM; // 3 stages need more blocks

        lowPrecisionHierPreprocessKernel<T, QUANT_T><<<preprocess_blocks_per_grid, LP_ALLREDUCE_DEFAULT_BLOCK_SIZE,
            (LP_ALLREDUCE_WARP_NUM_PER_BLOCK) * sizeof(float), stream>>>((T const*) params.local_input_buffer_ptr,
            params.elts_total, (QUANT_T*) params.peer_comm_buffer_ptrs[my_rank]);
        lowPrecisionTwoShotHierAllReduceKernel<T, QUANT_T, RANKS_PER_NODE><<<blocks_per_grid, threads_per_block,
            (LP_ALLREDUCE_WARP_NUM_PER_BLOCK * RANKS_PER_NODE) * sizeof(float), stream>>>(params);
    }
}

template <typename T>
void lowPrecisionAllReduceDispatchType(kernels::LowPrecisionAllReduceParams& param, cudaStream_t stream)
{
#ifdef ENABLE_FP8
    switch (param.ranks_per_node)
    {
    case 2: lowPrecisionAllReduceDispatchRanksPerNode<T, __nv_fp8_e4m3, 2>(param, stream); break;
    case 4: lowPrecisionAllReduceDispatchRanksPerNode<T, __nv_fp8_e4m3, 4>(param, stream); break;
    case 8: lowPrecisionAllReduceDispatchRanksPerNode<T, __nv_fp8_e4m3, 8>(param, stream); break;
    default: TLLM_THROW("Custom LowPrecision all reduce only supported on {2, 4, 8} GPUs per node.");
    }
#else
    TLLM_THROW("Can't Use Low Precision Allreduce When Compile Without ENABLE_FP8");
#endif
}

std::vector<size_t> splitNumber(size_t number)
{
    std::vector<size_t> parts;
    size_t parts_num = number / LP_ALLREDUCE_MAX_ELTS_IN_WORKSPACE;
    size_t remain = number % LP_ALLREDUCE_MAX_ELTS_IN_WORKSPACE;
    if (parts_num == 0)
    {
        parts.push_back(remain);
    }
    else
    {
        if (remain == 0)
        {
            for (size_t i = 0; i < parts_num; ++i)
            {
                parts.push_back(LP_ALLREDUCE_MAX_ELTS_IN_WORKSPACE);
            }
        }
        else
        {
            for (size_t i = 0; i < parts_num - 1; ++i)
            {
                parts.push_back(LP_ALLREDUCE_MAX_ELTS_IN_WORKSPACE);
            }
            // If the last remainder is small, split the final full part in half and fuse the remainder
            // into one of the halves.
            if (remain < LP_ALLREDUCE_MIN_ELTS_THRESHOLD)
            {
                parts.push_back(LP_ALLREDUCE_MAX_ELTS_IN_WORKSPACE / 2 + remain);
                parts.push_back(LP_ALLREDUCE_MAX_ELTS_IN_WORKSPACE / 2);
            }
            else
            {
                parts.push_back(LP_ALLREDUCE_MAX_ELTS_IN_WORKSPACE);
                parts.push_back(remain);
            }
        }
    }
    return parts;
}

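// Worked example (symbolic sizes): with W = LP_ALLREDUCE_MAX_ELTS_IN_WORKSPACE and an input of
// 2 * W + r elements (r != 0), splitNumber returns {W, W / 2 + r, W / 2} when r is below
// LP_ALLREDUCE_MIN_ELTS_THRESHOLD, and {W, W, r} otherwise.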
LowPrecisionAllReduceParams LowPrecisionAllReduceParams::deserialize(
    size_t tpSize, size_t tpRank, nvinfer1::DataType dataType, int token_num, int hidden_size)
{

    // Get appropriate static buffer
    StaticLowPrecisionBuffers* static_buffers = getBufferForTpSize(tpSize);

    // Check initialization
    if (!static_buffers->initialized || static_buffers->tpSize != tpSize)
    {
        TLLM_THROW("Static buffers for TP size %zu not initialized", tpSize);
    }

    // Use the stored flag pointer
    *(static_buffers->flag_ptr) += 1;

    TLLM_LOG_TRACE("AllReduceParams's flag value is %d", *(static_buffers->flag_ptr));
    uint64_t flag_value = *(static_buffers->flag_ptr);
    LowPrecisionAllReduceParams params;
    // Even plugins use ping buffers, odd plugins use pong.
    // That way, we don't need to wait for other GPUs to be done
    // before copying input tensor to workspace.
    auto const buffer_offset = (flag_value % 2 == 0) ? 0 : tpSize;

    for (int i = 0; i < tpSize; ++i)
    {
        params.peer_comm_buffer_ptrs[i] = static_buffers->peer_comm_buffer_ptrs[buffer_offset + i];
    }
    for (int i = 0; i < tpSize; ++i)
    {
        params.peer_barrier_ptrs_in[i] = static_buffers->peer_barrier_ptrs_in[i];
    }
    for (int i = 0; i < tpSize; ++i)
    {
        params.peer_barrier_ptrs_out[i] = static_buffers->peer_barrier_ptrs_out[i];
    }
    // Assume a single allreduce is not split into more than 64 allreduces of 64 MB each; this is not strictly
    // guaranteed.
    params.barrier_flag = flag_value;
    params.ranks_per_node = tpSize;
    params.local_rank = tpRank;

    return params;
}

LowPrecisionAllReduceParams LowPrecisionAllReduceParams::deserialize_hier(
    size_t tpSize, size_t tpRank, nvinfer1::DataType dataType, int token_num, int hidden_size)
{

    // Get appropriate static buffer
    StaticLowPrecisionBuffers* static_buffers = getBufferForTpSize(tpSize);

    // Check initialization
    if (!static_buffers->initialized || static_buffers->tpSize != tpSize)
    {
        TLLM_THROW("Static buffers for TP size %zu not initialized", tpSize);
    }

    // Use the stored flag pointer
    *(static_buffers->flag_ptr) += 1;

    TLLM_LOG_TRACE("AllReduceParams's flag value is %d", *(static_buffers->flag_ptr));
    uint64_t flag_value = *(static_buffers->flag_ptr);
    LowPrecisionAllReduceParams params;
    // Even plugins use ping buffers, odd plugins use pong.
    // That way, we don't need to wait for other GPUs to be done
    // before copying input tensor to workspace.
    auto const buffer_offset = (flag_value % 2 == 0) ? 0 : tpSize;

    for (int i = 0; i < tpSize; ++i)
    {
        params.peer_comm_buffer_ptrs[i] = static_buffers->peer_comm_buffer_ptrs[buffer_offset + i];
    }
    for (int i = 0; i < tpSize; ++i)
    {
        params.peer_barrier_ptrs_in[i] = static_buffers->peer_barrier_ptrs_in[i];
    }
    for (int i = 0; i < tpSize; ++i)
    {
        params.peer_barrier_ptrs_out[i] = static_buffers->peer_barrier_ptrs_out[i];
    }
    // Assume a single allreduce is not split into more than 64 allreduces of 64 MB each; this is not strictly
    // guaranteed.
    params.barrier_flag = flag_value;
    params.ranks_per_node = tpSize;
    params.local_rank = tpRank;

    params.numa_rank = tpRank % LP_ALLREDUCE_MAX_RANKS_PER_NUMA;

    // Assume the quantized type is 1 byte, so one load transfers LP_ALLREDUCE_BYTES_PER_LOAD elements.
    int REAL_ELTS_PER_BLOCK
        = (LP_ALLREDUCE_WARPSIZE - 1) * LP_ALLREDUCE_BYTES_PER_LOAD * LP_ALLREDUCE_WARP_NUM_PER_BLOCK;
    int QUANT_ELTS_PER_BLOCK = LP_ALLREDUCE_DEFAULT_BLOCK_SIZE * LP_ALLREDUCE_BYTES_PER_LOAD;

    int max_rounds = (((LP_ALLREDUCE_MAX_ELTS_IN_WORKSPACE - 1) / REAL_ELTS_PER_BLOCK + 1) - 1)
            / LP_ALLREDUCE_MAX_RANKS_PER_NUMA / LP_ALLREDUCE_MAX_BLOCKS
        + 1;
    int max_fence_rounds = (max_rounds - 1) / LP_ALLREDUCE_BUFFER_DUPLICATE + 1;

    uint64_t quantize_offset = max_fence_rounds * LP_ALLREDUCE_MAX_RANKS_PER_NUMA * LP_ALLREDUCE_MAX_BLOCKS
        * LP_ALLREDUCE_BUFFER_DUPLICATE * QUANT_ELTS_PER_BLOCK;
    for (int i = 0; i < LP_ALLREDUCE_MAX_RANKS_PER_NUMA; ++i)
    {
        params.inputs_inside_numa[i]
            = params.peer_comm_buffer_ptrs[(tpRank / LP_ALLREDUCE_MAX_RANKS_PER_NUMA) * LP_ALLREDUCE_MAX_RANKS_PER_NUMA
                + i];
    }

    for (int i = 0; i < LP_ALLREDUCE_MAX_BLOCKS; ++i)
    {

        const size_t block_buffer_size
            = QUANT_ELTS_PER_BLOCK * LP_ALLREDUCE_BUFFER_CHUNKS * LP_ALLREDUCE_BUFFER_DUPLICATE;
        char* base_ptr = reinterpret_cast<char*>(params.peer_comm_buffer_ptrs[tpRank]);

        params.rs_buffers[i] = base_ptr + quantize_offset + block_buffer_size * i;

        const size_t ar_buffer_offset = quantize_offset + block_buffer_size * LP_ALLREDUCE_MAX_BLOCKS;

        params.ar_buffers[i] = base_ptr + ar_buffer_offset + block_buffer_size * i;

        int const cross_numa_rank = (tpRank + LP_ALLREDUCE_MAX_RANKS_PER_NUMA) % tpSize;
        params.ar_peer_buffers_cross_numa[i] = reinterpret_cast<char*>(params.peer_comm_buffer_ptrs[cross_numa_rank])
            + quantize_offset + block_buffer_size * i;
        int const numa_group_base = (tpRank / LP_ALLREDUCE_MAX_RANKS_PER_NUMA) * LP_ALLREDUCE_MAX_RANKS_PER_NUMA;
        for (int j = 0; j < LP_ALLREDUCE_MAX_RANKS_PER_NUMA; ++j)
        {
            int const rank_in_numa = numa_group_base + j;
            params.ag_peer_buffers_inside_numa[i * LP_ALLREDUCE_MAX_RANKS_PER_NUMA + j]
                = reinterpret_cast<uint64_t*>(reinterpret_cast<char*>(params.peer_comm_buffer_ptrs[rank_in_numa])
                    + ar_buffer_offset + block_buffer_size * i);
        }

        const size_t rs_send_flags_offset = ar_buffer_offset + block_buffer_size * LP_ALLREDUCE_MAX_BLOCKS;
        params.rs_send_flags[i] = reinterpret_cast<uint64_t*>(base_ptr + rs_send_flags_offset + i * sizeof(uint64_t));

        uint64_t rs_ack_flags_offset = rs_send_flags_offset + LP_ALLREDUCE_MAX_BLOCKS * sizeof(uint64_t);
        params.rs_ack_flags[i]
            = reinterpret_cast<uint64_t*>(reinterpret_cast<char*>(params.peer_comm_buffer_ptrs[tpRank])
                + rs_ack_flags_offset + i * sizeof(uint64_t) * 2);

        uint64_t rs_notify_local_flags_offset = rs_ack_flags_offset + LP_ALLREDUCE_MAX_BLOCKS * sizeof(uint64_t) * 2;
        params.rs_notify_local_flags[i]
            = reinterpret_cast<uint64_t*>(reinterpret_cast<char*>(params.peer_comm_buffer_ptrs[tpRank])
                + rs_notify_local_flags_offset + i * sizeof(uint64_t));

        uint64_t rs_notify_remote_flags_offset
            = rs_notify_local_flags_offset + LP_ALLREDUCE_MAX_BLOCKS * sizeof(uint64_t);

        // Only 8-GPU setups use the hierarchical path, so tpSize is effectively 8 here and the modulo
        // wraps to the peer NUMA domain.
        params.rs_notify_remote_flags[i] = reinterpret_cast<uint64_t*>(
            reinterpret_cast<char*>(params.peer_comm_buffer_ptrs[(tpRank + LP_ALLREDUCE_MAX_RANKS_PER_NUMA) % tpSize])
            + rs_notify_remote_flags_offset + i * sizeof(uint64_t));

        // Special flag for the AR stage.
        params.ar_ack_peer_rs_flags[i]
            = reinterpret_cast<uint64_t*>(reinterpret_cast<char*>(params.peer_comm_buffer_ptrs[tpRank])
                + rs_notify_remote_flags_offset + i * sizeof(uint64_t));

        // RS stage handshake done.

        // Flags for the partial AR stage handshake.
        uint64_t ar_send_flags_offset = rs_notify_remote_flags_offset + LP_ALLREDUCE_MAX_BLOCKS * sizeof(uint64_t);
        params.ar_send_flags[i]
            = reinterpret_cast<uint64_t*>(reinterpret_cast<char*>(params.peer_comm_buffer_ptrs[tpRank])
                + ar_send_flags_offset + i * sizeof(uint64_t));

        // Two flags per NUMA domain, hence the fixed * 2 stride.
        // For the AR notify, reuse rs_ack_flags.
        params.ar_notify_rs_local_flags[i]
            = reinterpret_cast<uint64_t*>(reinterpret_cast<char*>(params.peer_comm_buffer_ptrs[tpRank])
                + rs_ack_flags_offset + i * sizeof(uint64_t) * 2);
        // Only 8-GPU setups use the hierarchical path, so tpSize is effectively 8 here and the modulo
        // wraps to the peer NUMA domain.
        params.ar_notify_rs_remote_flags[i] = reinterpret_cast<uint64_t*>(
            reinterpret_cast<char*>(params.peer_comm_buffer_ptrs[(tpRank + LP_ALLREDUCE_MAX_RANKS_PER_NUMA) % tpSize])
            + rs_ack_flags_offset + i * sizeof(uint64_t) * 2 + sizeof(uint64_t));

        uint64_t ar_ack_flags_offset = ar_send_flags_offset + LP_ALLREDUCE_MAX_BLOCKS * sizeof(uint64_t);

        params.ar_ack_flags[i] = reinterpret_cast<uint64_t*>(
            reinterpret_cast<char*>(params.peer_comm_buffer_ptrs[tpRank]) + ar_ack_flags_offset + i * sizeof(uint64_t));

        uint64_t ar_notify_ag_flags_offset = ar_ack_flags_offset + LP_ALLREDUCE_MAX_BLOCKS * sizeof(uint64_t);
        params.ar_notify_ag_flags[i]
            = reinterpret_cast<uint64_t*>(reinterpret_cast<char*>(params.peer_comm_buffer_ptrs[tpRank])
                + ar_notify_ag_flags_offset + i * sizeof(uint64_t));

        // Partial AR stage done.

        // Flags for the AG stage.
        uint64_t ag_send_flags_offset = ar_notify_ag_flags_offset + LP_ALLREDUCE_MAX_BLOCKS * sizeof(uint64_t);
        params.ag_send_flags[i]
            = reinterpret_cast<uint64_t*>(reinterpret_cast<char*>(params.peer_comm_buffer_ptrs[tpRank])
                + ag_send_flags_offset + i * sizeof(uint64_t));

        // Four flags per NUMA domain, hence the fixed * 4 stride.
        uint64_t ag_ack_peer_inside_numa_flags_offset
            = ag_send_flags_offset + LP_ALLREDUCE_MAX_BLOCKS * sizeof(uint64_t);
        params.ag_ack_peer_inside_numa_flags[i]
            = reinterpret_cast<uint64_t*>(reinterpret_cast<char*>(params.peer_comm_buffer_ptrs[tpRank])
                + ag_ack_peer_inside_numa_flags_offset + i * sizeof(uint64_t) * 4);

        for (int j = 0; j < LP_ALLREDUCE_MAX_RANKS_PER_NUMA; ++j)
        {
            params.ag_notify_peer_inside_numa_flags[i * LP_ALLREDUCE_MAX_RANKS_PER_NUMA + j]
                = reinterpret_cast<uint64_t*>(
                    reinterpret_cast<char*>(params.peer_comm_buffer_ptrs[(tpRank / LP_ALLREDUCE_MAX_RANKS_PER_NUMA)
                            * LP_ALLREDUCE_MAX_RANKS_PER_NUMA
                        + j])
                    + ag_ack_peer_inside_numa_flags_offset + i * sizeof(uint64_t) * 4
                    + (tpRank % LP_ALLREDUCE_MAX_RANKS_PER_NUMA) * sizeof(uint64_t));
        }
        // AG stage done.
    }

    return params;
}

bool lowPrecisionConfigurationSupported(size_t n_ranks, size_t msg_size)
{
    size_t elts_per_thread = LP_ALLREDUCE_BYTES_PER_LOAD; // assume quant_type size is 1 byte
    int msg_align = elts_per_thread;
    if (n_ranks <= 4)
    {
        msg_align *= n_ranks;
    }
    return msg_size % msg_align == 0;
}

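// For example, assuming LP_ALLREDUCE_BYTES_PER_LOAD is 16, a 4-rank message must contain a multiple of
// 64 elements, while an 8-rank message only needs a multiple of 16.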
int32_t max_workspace_size_lowprecision(int32_t tp_size)
{
    // Assume the quantized type is 1 byte, so one load transfers LP_ALLREDUCE_BYTES_PER_LOAD elements.
    constexpr int32_t REAL_ELTS_PER_BLOCK
        = (LP_ALLREDUCE_WARPSIZE - 1) * LP_ALLREDUCE_BYTES_PER_LOAD * LP_ALLREDUCE_WARP_NUM_PER_BLOCK;
    constexpr int32_t QUANT_ELTS_PER_BLOCK = LP_ALLREDUCE_DEFAULT_BLOCK_SIZE * LP_ALLREDUCE_BYTES_PER_LOAD;

    int32_t buffer_bytes;
    if (tp_size == 8)
    {
        int32_t max_rounds = ((((LP_ALLREDUCE_MAX_ELTS_IN_WORKSPACE - 1) / REAL_ELTS_PER_BLOCK + 1) - 1)
                                 / LP_ALLREDUCE_MAX_RANKS_PER_NUMA / LP_ALLREDUCE_MAX_BLOCKS)
            + 1;
        int32_t max_fence_rounds = ((max_rounds - 1) / LP_ALLREDUCE_BUFFER_DUPLICATE) + 1;
        int32_t quantize_buffer_bytes = max_fence_rounds * LP_ALLREDUCE_MAX_RANKS_PER_NUMA * LP_ALLREDUCE_MAX_BLOCKS
            * LP_ALLREDUCE_BUFFER_DUPLICATE * QUANT_ELTS_PER_BLOCK;
        int32_t comm_buffer_bytes = LP_ALLREDUCE_BUFFER_CHUNKS * LP_ALLREDUCE_BUFFER_DUPLICATE * LP_ALLREDUCE_MAX_BLOCKS
            * LP_ALLREDUCE_HIER_STAGE_NUM * QUANT_ELTS_PER_BLOCK;
        buffer_bytes = quantize_buffer_bytes + comm_buffer_bytes;
    }
    else
    {
        buffer_bytes = (((LP_ALLREDUCE_MAX_ELTS_IN_WORKSPACE / tp_size - 1) / REAL_ELTS_PER_BLOCK) + 1)
            * QUANT_ELTS_PER_BLOCK * tp_size;
    }

    constexpr int32_t HANDSHAKE_FLAG_NUM = 32;
    int32_t flag_bytes = LP_ALLREDUCE_MAX_BLOCKS * HANDSHAKE_FLAG_NUM * sizeof(uint64_t);

    return buffer_bytes + flag_bytes;
}

void customLowPrecisionAllReduce(
    kernels::LowPrecisionAllReduceParams& params, nvinfer1::DataType dataType, cudaStream_t stream)
{
    TLLM_CHECK_WITH_INFO(lowPrecisionConfigurationSupported(params.ranks_per_node, params.elts_total),
        "Low Precision Custom all-reduce configuration unsupported");

    sync_check_cuda_error(stream);

    switch (dataType)
    {
    case nvinfer1::DataType::kFLOAT: lowPrecisionAllReduceDispatchType<float>(params, stream); break;
    case nvinfer1::DataType::kHALF: lowPrecisionAllReduceDispatchType<half>(params, stream); break;
#ifdef ENABLE_BF16
    case nvinfer1::DataType::kBF16: lowPrecisionAllReduceDispatchType<__nv_bfloat16>(params, stream); break;
#endif
    default: TLLM_THROW("Unsupported dataType for customAllReduce");
    }
    sync_check_cuda_error(stream);
}

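// Typical host-side call sequence (sketch only; `workspace`, `input`, `output`, `rank`, and `stream`
// are caller-provided names, not part of this file):
//
//   // Register the IPC-mapped workspace pointers once per TP group.
//   initialize_static_lowprecision_buffers(workspace, /*tpSize=*/8);
//
//   // Per allreduce: build the params for this rank (the hierarchical variant for 8 GPUs),
//   // fill in the message, and launch.
//   auto params = LowPrecisionAllReduceParams::deserialize_hier(
//       /*tpSize=*/8, /*tpRank=*/rank, nvinfer1::DataType::kHALF, token_num, hidden_size);
//   params.local_input_buffer_ptr = input;
//   params.local_output_buffer_ptr = output;
//   params.elts_total = token_num * hidden_size;
//   customLowPrecisionAllReduce(params, nvinfer1::DataType::kHALF, stream);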
} // namespace kernels

TRTLLM_NAMESPACE_END