Add support for FP8 datatypes

Added new datatypes: f8e4m3, f8e5m2

Only supported on H100 and newer architectures (sm_90+), and requires NCCL >= 2.24.0
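
For reference, a minimal sketch of the availability gating involved (mirroring the HAVE_FP8 macro this diff adds to the changed files below):

    #include <cuda_runtime.h>
    #include <nccl.h>
    #if CUDART_VERSION >= 11080
    #include <cuda_fp8.h>  // defines __CUDA_FP8_TYPES_EXIST__
    #endif
    // FP8 needs both the CUDA fp8 header and an NCCL new enough to expose
    // ncclFloat8e4m3/ncclFloat8e5m2:
    #if defined(__CUDA_FP8_TYPES_EXIST__) && NCCL_VERSION_CODE >= NCCL_VERSION(2,24,0)
    #define HAVE_FP8 1
    #else
    #define HAVE_FP8 0
    #endif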
commit 501a149d57
parent b4300cc79d
Author: David Addison
Date: 2025-04-18 19:20:59 -07:00

4 changed files with 415 additions and 156 deletions


@@ -21,15 +21,21 @@ int test_ncclVersion = 0; // init'd with ncclGetVersion()
#if NCCL_MAJOR >= 2
ncclDataType_t test_types[ncclNumTypes] = {
ncclInt8, ncclUint8, ncclInt32, ncclUint32, ncclInt64, ncclUint64, ncclHalf, ncclFloat, ncclDouble
#if defined(__CUDA_BF16_TYPES_EXIST__) && NCCL_VERSION_CODE >= NCCL_VERSION(2,10,0)
#if HAVE_BF16
, ncclBfloat16
#endif
#if HAVE_FP8
, ncclFloat8e4m3, ncclFloat8e5m2
#endif
};
const char *test_typenames[ncclNumTypes] = {
"int8", "uint8", "int32", "uint32", "int64", "uint64", "half", "float", "double"
#if defined(__CUDA_BF16_TYPES_EXIST__) && NCCL_VERSION_CODE >= NCCL_VERSION(2,10,0)
#if HAVE_BF16
, "bfloat16"
#endif
#if HAVE_FP8
, "f8e4m3", "f8e5m2"
#endif
};
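// The strings above are what the harness's -d/--datatype option is matched
// against, so (assuming the stock nccl-tests CLI) a run such as
//   ./all_reduce_perf -b 8 -e 128M -f 2 -d f8e4m3
// exercises one of the new types once the runtime checks below pass.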
int test_typenum = -1;
@@ -86,6 +92,7 @@ static int average = 1;
#if NCCL_VERSION_CODE >= NCCL_VERSION(2,19,0)
static int local_register = 0;
#endif
static int minCudaArch = 1<<30;
#define NUM_BLOCKS 32
@@ -126,18 +133,18 @@ static double parsesize(const char *value) {
}
testResult_t CheckDelta(void* results, void* expected, size_t count, size_t offset, ncclDataType_t type, ncclRedOp_t op, uint64_t seed, int nranks, int64_t *wrongEltN) {
ncclVerifiableVerify(results, expected, count, (int)type, (int)op, nranks, seed, offset, wrongEltN, cudaStreamDefault);
CUDACHECK(ncclVerifiableVerify(results, expected, count, (int)type, (int)op, nranks, seed, offset, wrongEltN, cudaStreamDefault));
CUDACHECK(cudaDeviceSynchronize());
return testSuccess;
}
testResult_t InitDataReduce(void* data, const size_t count, const size_t offset, ncclDataType_t type, ncclRedOp_t op, uint64_t seed, int nranks) {
ncclVerifiablePrepareExpected(data, count, (int)type, (int)op, nranks, seed, offset, cudaStreamDefault);
CUDACHECK(ncclVerifiablePrepareExpected(data, count, (int)type, (int)op, nranks, seed, offset, cudaStreamDefault));
return testSuccess;
}
testResult_t InitData(void* data, const size_t count, size_t offset, ncclDataType_t type, ncclRedOp_t op, uint64_t seed, int nranks, int rank) {
ncclVerifiablePrepareInput(data, count, (int)type, (int)op, nranks, rank, seed, offset, cudaStreamDefault);
CUDACHECK(ncclVerifiablePrepareInput(data, count, (int)type, (int)op, nranks, rank, seed, offset, cudaStreamDefault));
return testSuccess;
}
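// Note: ncclVerifiablePrepareInput/PrepareExpected/Verify now return
// cudaError_t (see verifiable.h at the end of this diff), which is what lets
// the three wrappers above route kernel-launch failures through CUDACHECK
// instead of silently dropping them.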
@@ -358,9 +365,12 @@ testResult_t startColl(struct threadArgs* args, ncclDataType_t type, ncclRedOp_t
union {
int8_t i8; uint8_t u8; int32_t i32; uint32_t u32; int64_t i64; uint64_t u64;
half f16; float f32; double f64;
#if defined(__CUDA_BF16_TYPES_EXIST__)
#if HAVE_BF16
__nv_bfloat16 bf16;
#endif
#if HAVE_FP8
__nv_fp8_e4m3 f8e4m3; __nv_fp8_e5m2 f8e5m2;
#endif
};
switch(type) {
case ncclInt8: i8 = ncclVerifiablePremulScalar<int8_t>(rank); break;
@@ -372,9 +382,14 @@ testResult_t startColl(struct threadArgs* args, ncclDataType_t type, ncclRedOp_t
case ncclFloat16: f16 = ncclVerifiablePremulScalar<half>(rank); break;
case ncclFloat32: f32 = ncclVerifiablePremulScalar<float>(rank); break;
case ncclFloat64: f64 = ncclVerifiablePremulScalar<double>(rank); break;
#if defined(__CUDA_BF16_TYPES_EXIST__)
#if HAVE_BF16
case ncclBfloat16: bf16 = ncclVerifiablePremulScalar<__nv_bfloat16>(rank); break;
#endif
#if HAVE_FP8
case ncclFloat8e4m3: f8e4m3 = ncclVerifiablePremulScalar<__nv_fp8_e4m3>(rank); break;
case ncclFloat8e5m2: f8e5m2 = ncclVerifiablePremulScalar<__nv_fp8_e5m2>(rank); break;
#endif
default: break; // Just to silence clang
}
NCCLCHECK(ncclRedOpCreatePreMulSum(&op, &u64, type, ncclScalarHostImmediate, args->comms[i]));
}
@@ -702,13 +717,20 @@ int main(int argc, char* argv[]) {
test_typenum = 9;
if (NCCL_VERSION_CODE >= NCCL_VERSION(2,10,0) && test_ncclVersion >= NCCL_VERSION(2,10,0)) {
test_opnum++; // ncclAvg
#if defined(__CUDA_BF16_TYPES_EXIST__)
test_typenum++; // bfloat16
#endif
}
if (NCCL_VERSION_CODE >= NCCL_VERSION(2,11,0) && test_ncclVersion >= NCCL_VERSION(2,11,0)) {
test_opnum++; // PreMulSum
}
#if defined(__CUDA_BF16_TYPES_EXIST__)
if (NCCL_VERSION_CODE >= NCCL_VERSION(2,10,0) && test_ncclVersion >= NCCL_VERSION(2,10,0)) {
test_typenum++; // bfloat16
}
#endif
#if defined(__CUDA_FP8_TYPES_EXIST__)
if (NCCL_VERSION_CODE >= NCCL_VERSION(2,24,0) && test_ncclVersion >= NCCL_VERSION(2,24,0)) {
test_typenum += 2; // fp8 e4m3,e5m2
}
#endif
#endif
// Parse args
@@ -1033,12 +1055,37 @@ testResult_t run() {
gpus[i] = (gpu0 != -1 ? gpu0 : localRank*nThreads*nGpus) + i;
CUDACHECK(cudaSetDevice(gpus[i]));
TESTCHECK(AllocateBuffs(sendbuffs+i, sendBytes, recvbuffs+i, recvBytes, expected+i, (size_t)maxBytes));
if (streamnull)
if (streamnull) {
streams[i] = NULL;
else
}
else {
CUDACHECK(cudaStreamCreateWithFlags(streams+i, cudaStreamNonBlocking));
}
int archMajor, archMinor;
CUDACHECK(cudaDeviceGetAttribute(&archMajor, cudaDevAttrComputeCapabilityMajor, gpus[i]));
CUDACHECK(cudaDeviceGetAttribute(&archMinor, cudaDevAttrComputeCapabilityMinor, gpus[i]));
minCudaArch = std::min(minCudaArch, 100*archMajor + 10*archMinor);
}
#ifdef MPI_SUPPORT
MPI_Allreduce(MPI_IN_PLACE, &minCudaArch, 1, MPI_INT, MPI_MIN, MPI_COMM_WORLD);
#endif
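// minCudaArch starts at the 1<<30 sentinel, drops to the smallest SM version
// among the local GPUs (100*major + 10*minor, so sm90 == 900), and the MPI
// min-reduction makes it global. The filter below can therefore disable fp8
// for the whole job if even one rank sits on pre-Hopper hardware.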
#if defined(__CUDA_FP8_TYPES_EXIST__)
if (NCCL_VERSION_CODE >= NCCL_VERSION(2,24,0) && test_ncclVersion >= NCCL_VERSION(2,24,0)) {
if (minCudaArch < 900) { // Filter out fp8 on pre-Hopper hardware
int n = 0;
for (int i=0; i < test_typenum; i++) {
if (!(test_types[i] == ncclFloat8e4m3 || test_types[i] == ncclFloat8e5m2)) {
test_types[n] = test_types[i];
test_typenames[n] = test_typenames[i];
n += 1;
}
}
test_typenum = n;
}
}
#endif
// If parallel init is not selected, use the main thread to initialize NCCL
ncclComm_t* comms = (ncclComm_t*)malloc(sizeof(ncclComm_t)*nThreads*nGpus);
#if NCCL_VERSION_CODE >= NCCL_VERSION(2,19,0)


@@ -213,16 +213,34 @@ static uint64_t getHostHash(const char* hostname) {
return getHash(hostHash, strlen(hostHash));
}
#define HAVE_BF16 0
#define HAVE_FP8 0
#if NCCL_MAJOR >= 2
#if defined(__CUDA_BF16_TYPES_EXIST__) && NCCL_VERSION_CODE >= NCCL_VERSION(2,10,0)
#undef HAVE_BF16
#define HAVE_BF16 1
#if defined(__CUDA_FP8_TYPES_EXIST__) && NCCL_VERSION_CODE >= NCCL_VERSION(2,24,0)
#undef HAVE_FP8
#define HAVE_FP8 1
#endif
#endif
#endif
static size_t wordSize(ncclDataType_t type) {
switch(type) {
case ncclChar:
#if NCCL_MAJOR >= 2
//case ncclInt8:
case ncclUint8:
#endif
#if HAVE_FP8
case ncclFloat8e4m3:
case ncclFloat8e5m2:
#endif
return 1;
case ncclHalf:
#if defined(__CUDA_BF16_TYPES_EXIST__)
#if HAVE_BF16
case ncclBfloat16:
#endif
//case ncclFloat16:


@@ -8,6 +8,15 @@
#if CUDART_VERSION >= 11000
#include <cuda_bf16.h>
#endif
#if CUDART_VERSION >= 11080
#include <cuda_fp8.h>
#endif
#if NCCL_VERSION_CODE >= NCCL_VERSION(2,24,0) && defined(__CUDA_FP8_TYPES_EXIST__)
#define HAVE_ncclFloat8 1
#else
#define HAVE_ncclFloat8 0
#endif
#if NCCL_VERSION_CODE >= NCCL_VERSION(2,10,0) && defined(__CUDA_BF16_TYPES_EXIST__)
#define HAVE_ncclBfloat16 1
@@ -84,10 +93,16 @@ template<typename T>
struct IsIntegral: std::is_integral<T> {};
template<>
struct IsIntegral<half>: std::false_type {};
#ifdef __CUDA_BF16_TYPES_EXIST__
#if HAVE_ncclBfloat16
template<>
struct IsIntegral<__nv_bfloat16>: std::false_type {};
#endif
#if HAVE_ncclFloat8
template<>
struct IsIntegral<__nv_fp8_e4m3>: std::false_type {};
template<>
struct IsIntegral<__nv_fp8_e5m2>: std::false_type {};
#endif
}
////////////////////////////////////////////////////////////////////////////////
@@ -107,23 +122,72 @@ __host__ __device__ T inhibit(T x) {
////////////////////////////////////////////////////////////////////////////////
namespace {
template<typename Y, typename X>
__host__ __device__ Y castTo(X x) {
template<typename Y>
__host__ __device__ Y castTo(uint64_t x) {
return Y(x);
}
template<typename Y>
__host__ __device__ Y castTo(float x) {
return Y(x);
}
template<typename Y>
__host__ __device__ Y castTo(double x) {
return Y(x);
}
template<>
__host__ __device__ half castTo<half>(float x) {
return __float2half(x);
}
#ifdef __CUDA_BF16_TYPES_EXIST__
template<>
__host__ __device__ half castTo<half>(double x) {
return __double2half(x);
}
template<>
__host__ __device__ half castTo<half>(uint64_t x) {
return __ull2half_rn(x);
}
#if HAVE_ncclBfloat16
template<>
__host__ __device__ __nv_bfloat16 castTo<__nv_bfloat16>(float x) {
return __float2bfloat16(x);
}
template<>
__host__ __device__ __nv_bfloat16 castTo<__nv_bfloat16>(double x) {
return __double2bfloat16(x);
}
template<>
__host__ __device__ __nv_bfloat16 castTo<__nv_bfloat16>(uint64_t x) {
return __double2bfloat16((double)x);
}
#endif
#if HAVE_ncclFloat8
template<>
__host__ __device__ __nv_fp8_e4m3 castTo<__nv_fp8_e4m3>(float x) {
return __nv_fp8_e4m3(x);
}
template<>
__host__ __device__ __nv_fp8_e4m3 castTo<__nv_fp8_e4m3>(double x) {
return __nv_fp8_e4m3(x);
}
template<>
__host__ __device__ __nv_fp8_e4m3 castTo<__nv_fp8_e4m3>(uint64_t x) {
return __nv_fp8_e4m3((double)x);
}
template<>
__host__ __device__ __nv_fp8_e5m2 castTo<__nv_fp8_e5m2>(float x) {
return __nv_fp8_e5m2(x);
}
template<>
__host__ __device__ __nv_fp8_e5m2 castTo<__nv_fp8_e5m2>(double x) {
return __nv_fp8_e5m2(x);
}
template<>
__host__ __device__ __nv_fp8_e5m2 castTo<__nv_fp8_e5m2>(uint64_t x) {
return __nv_fp8_e5m2((double)x);
}
#endif
}
@@ -151,7 +215,7 @@ struct ReduceSum {
return __float2half(__half2float(a) + __half2float(b));
#endif
}
#ifdef __CUDA_BF16_TYPES_EXIST__
#if HAVE_ncclBfloat16
__host__ __device__ __nv_bfloat16 operator()(__nv_bfloat16 a, __nv_bfloat16 b) const {
#if __CUDA_ARCH__ >= 800
return __hadd(a, b);
@@ -160,6 +224,22 @@ struct ReduceSum {
#endif
}
#endif
#if HAVE_ncclFloat8
__host__ __device__ __nv_fp8_e4m3 operator()(__nv_fp8_e4m3 a, __nv_fp8_e4m3 b) const {
#if __CUDA_ARCH__ >= 800
return __nv_fp8_e4m3(__hadd(__half(a), __half(b)));
#else
return __nv_fp8_e4m3(float(a) + float(b));
#endif
}
__host__ __device__ __nv_fp8_e5m2 operator()(__nv_fp8_e5m2 a, __nv_fp8_e5m2 b) const {
#if __CUDA_ARCH__ >= 800
return __nv_fp8_e5m2(__hadd(__half(a), __half(b)));
#else
return __nv_fp8_e5m2(float(a) + float(b));
#endif
}
#endif
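// cuda_fp8.h provides conversions but no arithmetic operators, so each fp8
// reduction step here round-trips through __half (via __hadd when
// __CUDA_ARCH__ >= 800) or plain float, rounding twice: once in the wider op
// and once converting back to 8 bits.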
template<typename T>
__host__ __device__ T postOp(T x) const { return x; }
};
@@ -175,7 +255,7 @@ struct ReduceProd {
return __float2half(__half2float(a) * __half2float(b));
#endif
}
#ifdef __CUDA_BF16_TYPES_EXIST__
#if HAVE_ncclBfloat16
__host__ __device__ __nv_bfloat16 operator()(__nv_bfloat16 a, __nv_bfloat16 b) const {
#if __CUDA_ARCH__ >= 800
return __hmul(a, b);
@@ -184,6 +264,22 @@ struct ReduceProd {
#endif
}
#endif
#if HAVE_ncclFloat8
__host__ __device__ __nv_fp8_e4m3 operator()(__nv_fp8_e4m3 a, __nv_fp8_e4m3 b) const {
#if __CUDA_ARCH__ >= 800
return __nv_fp8_e4m3(__hmul(__half(a), __half(b)));
#else
return __nv_fp8_e4m3(float(a) * float(b));
#endif
}
__host__ __device__ __nv_fp8_e5m2 operator()(__nv_fp8_e5m2 a, __nv_fp8_e5m2 b) const {
#if __CUDA_ARCH__ >= 800
return __nv_fp8_e5m2(__hmul(__half(a), __half(b)));
#else
return __nv_fp8_e5m2(float(a) * float(b));
#endif
}
#endif
template<typename T>
__host__ __device__ T postOp(T x) const { return x; }
};
@@ -201,7 +297,7 @@ struct ReduceMin {
return __half2float(a) < __half2float(b) ? a : b;
#endif
}
#ifdef __CUDA_BF16_TYPES_EXIST__
#if HAVE_ncclBfloat16
__host__ __device__ __nv_bfloat16 operator()(__nv_bfloat16 a, __nv_bfloat16 b) const {
#if __CUDA_ARCH__ >= 800
return __hmin(a, b);
@@ -212,6 +308,22 @@ struct ReduceMin {
#endif
}
#endif
#if HAVE_ncclFloat8
__host__ __device__ __nv_fp8_e4m3 operator()(__nv_fp8_e4m3 a, __nv_fp8_e4m3 b) const {
#if __CUDA_ARCH__ >= 800
return __nv_fp8_e4m3(__hmin(__half(a), __half(b)));
#else
return __nv_fp8_e4m3(float(a) < float(b) ? a : b);
#endif
}
__host__ __device__ __nv_fp8_e5m2 operator()(__nv_fp8_e5m2 a, __nv_fp8_e5m2 b) const {
#if __CUDA_ARCH__ >= 800
return __nv_fp8_e5m2(__hmin(__half(a), __half(b)));
#else
return __nv_fp8_e5m2(float(a) < float(b) ? a : b);
#endif
}
#endif
template<typename T>
__host__ __device__ T postOp(T x) const { return x; }
};
@@ -229,7 +341,7 @@ struct ReduceMax {
return __half2float(a) > __half2float(b) ? a : b;
#endif
}
#ifdef __CUDA_BF16_TYPES_EXIST__
#if HAVE_ncclBfloat16
__host__ __device__ __nv_bfloat16 operator()(__nv_bfloat16 a, __nv_bfloat16 b) const {
#if __CUDA_ARCH__ >= 800
return __hmax(a, b);
@@ -240,6 +352,22 @@ struct ReduceMax {
#endif
}
#endif
#if HAVE_ncclFloat8
__host__ __device__ __nv_fp8_e4m3 operator()(__nv_fp8_e4m3 a, __nv_fp8_e4m3 b) const {
#if __CUDA_ARCH__ >= 800
return __nv_fp8_e4m3(__hmax(__half(a), __half(b)));
#else
return __nv_fp8_e4m3(float(a) > float(b) ? a : b);
#endif
}
__host__ __device__ __nv_fp8_e5m2 operator()(__nv_fp8_e5m2 a, __nv_fp8_e5m2 b) const {
#if __CUDA_ARCH__ >= 800
return __nv_fp8_e5m2(__hmax(__half(a), __half(b)));
#else
return __nv_fp8_e5m2(float(a) > float(b) ? a : b);
#endif
}
#endif
template<typename T>
__host__ __device__ T postOp(T x) const { return x; }
};
@@ -297,29 +425,47 @@ struct ReduceAvg {
namespace {
template<typename T>
struct FloatLayout;
struct FloatLayout { static constexpr bool is_floating_point = false; };
template<>
struct FloatLayout<float> {
static constexpr bool is_floating_point = true;
static constexpr int exponent_bits = 8, mantissa_bits = 23;
static constexpr int exponent_bias = (1<<(exponent_bits-1))-1;
};
template<>
struct FloatLayout<double> {
static constexpr bool is_floating_point = true;
static constexpr int exponent_bits = 11, mantissa_bits = 52;
static constexpr int exponent_bias = (1<<(exponent_bits-1))-1;
};
template<>
struct FloatLayout<half> {
static constexpr bool is_floating_point = true;
static constexpr int exponent_bits = 5, mantissa_bits = 10;
static constexpr int exponent_bias = (1<<(exponent_bits-1))-1;
};
#ifdef __CUDA_BF16_TYPES_EXIST__
#if HAVE_ncclBfloat16
template<>
struct FloatLayout<__nv_bfloat16> {
static constexpr bool is_floating_point = true;
static constexpr int exponent_bits = 8, mantissa_bits = 7;
static constexpr int exponent_bias = (1<<(exponent_bits-1))-1;
};
#endif
#if HAVE_ncclFloat8
template<>
struct FloatLayout<__nv_fp8_e4m3> {
static constexpr bool is_floating_point = true;
static constexpr int exponent_bits = 4, mantissa_bits = 3;
static constexpr int exponent_bias = (1<<(exponent_bits-1))-1;
};
template<>
struct FloatLayout<__nv_fp8_e5m2> {
static constexpr bool is_floating_point = true;
static constexpr int exponent_bits = 5, mantissa_bits = 2;
static constexpr int exponent_bias = (1<<(exponent_bits-1))-1;
};
#endif
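// Worked out, the two layouts above match the usual FP8 formats:
//   e4m3: bias (1<<3)-1 = 7,  largest finite value 1.75 * 2^8  = 448
//   e5m2: bias (1<<4)-1 = 15, largest finite value 1.75 * 2^15 = 57344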
template<typename T>
__host__ __device__ T makeFloat(int sign, int exp, uint64_t mant) {
@@ -632,11 +778,12 @@ __host__ __device__ void genOutput(
////////////////////////////////////////////////////////////////////////////////
// Nil reduction (byte copy functions). Optimized to assume rank_n=1
// genInput specialization for integer ReduceNil.
namespace {
template<typename T, bool IsIntegral>
template<typename T>
__host__ __device__ void genInput(
T &ans, ReduceNil, int rank_n, int rank_me, uint64_t seed, intptr_t index,
std::integral_constant<bool, IsIntegral>
std::true_type /*integral*/
) {
(void)rank_n, (void)rank_me; // silence unused warnings
union { uint64_t bits; T tmp; };
@@ -646,6 +793,24 @@ __host__ __device__ void genInput(
ans = tmp;
}
// genInput specialization for floating point ReduceNil.
template<typename T>
__host__ __device__ void genInput(
T &ans, ReduceNil, int rank_n, int rank_me, uint64_t seed, intptr_t index,
std::false_type /*integral*/
) {
(void)rank_n; // silence unused warnings
constexpr uint64_t mant_mask = (uint64_t(1) << FloatLayout<T>::mantissa_bits)-1;
uint64_t rng = hashOf(index ^ index<<16 ^ rank_me, seed);
int sign = rng & 1;
rng ^= rng>>1;
int exp = rng & ((1<<(FloatLayout<T>::exponent_bits-1))-1);
exp += 1<<(FloatLayout<T>::exponent_bits-2);
rng ^= rng >> FloatLayout<T>::exponent_bits;
uint64_t mant = rng & mant_mask;
ans = makeFloat<T>(sign, exp, mant);
}
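// The exponent above is drawn from the middle of the representable range
// (e.g. unbiased exponents in [-3,4] for e4m3), so generated values stay
// normal and finite with headroom against overflow during reduction.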
template<typename T, typename ReduceFn, bool IsIntegral>
__host__ __device__ void genOutput(
T &ans, ReduceNil op, int rank_n, uint64_t seed, intptr_t index,
@@ -734,20 +899,34 @@ __host__ __device__ void genOutput(
namespace {
template<typename T>
__host__ __device__ void genInput(
T &ans, ReduceAvg, int rank_n, int rank_me, uint64_t seed, intptr_t index,
T &ans, ReduceAvg, int rank_n, int rank_me, uint64_t rng, intptr_t index,
std::false_type /*integral*/
) {
ans = genInOutFloatSum<T>(/*input_not_output=*/true, rank_n, rank_me, seed, index, /*same_sign=*/true);
// We can't control the nranks divisor in averages, so to control error we
// limit to two ranks contributing non-zero values. This way there is no
// ambiguity in the order of summation.
int r = shuffleRank(rank_n, rank_me, rng);
uint64_t m = (rng*(r ? 0xbeef : 1)) & ((1ul<<FloatLayout<T>::mantissa_bits)-1);
ans = r < 2 ? castTo<T>(1+m) : castTo<T>((uint64_t)0);
}
template<typename T>
__host__ __device__ void genOutput(
T &ans, ReduceAvg, int rank_n, uint64_t seed, intptr_t index,
T &ans, ReduceAvg, int rank_n, uint64_t rng, intptr_t index,
std::false_type /*integral*/
) {
ans = genInOutFloatSum<T>(/*input_not_output=*/false, rank_n, 0, seed, index, /*same_sign=*/true);
using T1 = typename std::conditional<(sizeof(T)<sizeof(double)), float, double>::type;
ans = ReduceProd()(ans, T1(1)/T1(rank_n));
shuffleRank(rank_n, -1, rng);
uint64_t m0 = (rng*(0 ? 0xbeef : 1)) & ((1ul<<FloatLayout<T>::mantissa_bits)-1);
uint64_t m1 = (rng*(1 ? 0xbeef : 1)) & ((1ul<<FloatLayout<T>::mantissa_bits)-1);
if (rank_n == 1) {
ans = castTo<T>(1+m0);
} else {
// NCCL varies which datatype it does the muls with depending on __CUDA_ARCH__.
// We account for this by using a tolerance of 2 ulps during the verification.
using TMul = typename std::conditional<(sizeof(T) < sizeof(double)), float, double>::type;
ans = ReduceSum()((T)(TMul(1+m0)*TMul(1.0/rank_n)),
(T)(TMul(1+m1)*TMul(1.0/rank_n)));
}
}
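// Example with rank_n = 4: genInput hands ranks 0 and 1 the values 1+m0 and
// 1+m1 (the `r ? 0xbeef : 1` factor is what differentiates m0 from m1) and
// all other ranks exact zeros; the expected average is then each product
// (1+m)*(1/rank_n) computed in TMul, cast to T, and summed, which matches
// NCCL to within the 2-ulp verification tolerance.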
}
@@ -809,10 +988,9 @@ __host__ __device__ T genOutput(
////////////////////////////////////////////////////////////////////////////////
#if !SELF_TEST
namespace {
template<typename T, typename ReduceFn>
__global__ void prepareInput2(
__global__ void __launch_bounds__(512, 1) prepareInput2(
T *elts, intptr_t elt_n, ReduceFn op, int rank_n, int rank_me,
uint64_t seed, intptr_t elt_ix0
) {
@@ -833,40 +1011,49 @@ __global__ void prepareInput2(
}
template<typename ReduceOp>
void prepareInput1(
cudaError_t prepareInput1(
void *elts, intptr_t elt_n, int elt_ty, ReduceOp op, int rank_n, int rank_me,
uint64_t seed, intptr_t elt_ix0, cudaStream_t stream
) {
int block_n = std::min<intptr_t>(32, (elt_n + 4*512-1)/(4*512));
#define CASE_TY(T) prepareInput2<<<block_n, 512, 0, stream>>>((T*)elts, elt_n, op, rank_n, rank_me, seed, elt_ix0); break;
void const *fn = nullptr;
switch(elt_ty) {
case ncclInt8: CASE_TY(int8_t)
case ncclUint8: CASE_TY(uint8_t)
case ncclInt32: CASE_TY(int32_t)
case ncclUint32: CASE_TY(uint32_t)
case ncclInt64: CASE_TY(int64_t)
case ncclUint64: CASE_TY(uint64_t)
case ncclFloat16: CASE_TY(half)
case ncclInt8: fn = (void const*)&prepareInput2<int8_t, ReduceOp>; break;
case ncclUint8: fn = (void const*)&prepareInput2<uint8_t, ReduceOp>; break;
case ncclInt32: fn = (void const*)&prepareInput2<int32_t, ReduceOp>; break;
case ncclUint32: fn = (void const*)&prepareInput2<uint32_t, ReduceOp>; break;
case ncclInt64: fn = (void const*)&prepareInput2<int64_t, ReduceOp>; break;
case ncclUint64: fn = (void const*)&prepareInput2<uint64_t, ReduceOp>; break;
case ncclFloat16: fn = (void const*)&prepareInput2<half, ReduceOp>; break;
#if HAVE_ncclBfloat16
case ncclBfloat16: CASE_TY(__nv_bfloat16)
case ncclBfloat16: fn = (void const*)&prepareInput2<__nv_bfloat16, ReduceOp>; break;
#endif
case ncclFloat32: CASE_TY(float)
case ncclFloat64: CASE_TY(double)
default: assert(0);
#if HAVE_ncclFloat8
case ncclFloat8e4m3: fn = (void const*)&prepareInput2<__nv_fp8_e4m3, ReduceOp>; break;
case ncclFloat8e5m2: fn = (void const*)&prepareInput2<__nv_fp8_e5m2, ReduceOp>; break;
#endif
case ncclFloat32: fn = (void const*)&prepareInput2<float, ReduceOp>; break;
case ncclFloat64: fn = (void const*)&prepareInput2<double, ReduceOp>; break;
default: assert(0); return cudaErrorInvalidValue;
}
#undef CASE_TY
dim3 grid = {1, 1, 1};
grid.x = (unsigned int)std::min<intptr_t>(32, (elt_n + 4*512-1)/(4*512));
dim3 block = {512, 1, 1};
void *args[7] = {&elts, &elt_n, &op, &rank_n, &rank_me, &seed, &elt_ix0};
if (grid.x == 0) return cudaSuccess;
return cudaLaunchKernel(fn, grid, block, args, 0, stream);
}
}
void ncclVerifiablePrepareInput(
cudaError_t ncclVerifiablePrepareInput(
void *elts, intptr_t elt_n, int elt_ty, int red_op, int rank_n, int rank_me,
uint64_t seed, intptr_t elt_ix0, cudaStream_t stream
) {
#define CASE_OP(op) \
if(rank_n == 1) \
prepareInput1(elts, elt_n, elt_ty, ReduceNil(), rank_n, rank_me, seed, elt_ix0, stream); \
return prepareInput1(elts, elt_n, elt_ty, ReduceNil(), rank_n, rank_me, seed, elt_ix0, stream); \
else \
prepareInput1(elts, elt_n, elt_ty, op, rank_n, rank_me, seed, elt_ix0, stream); \
return prepareInput1(elts, elt_n, elt_ty, op, rank_n, rank_me, seed, elt_ix0, stream); \
break;
switch(red_op) {
case ncclSum: CASE_OP(ReduceSum())
@@ -882,14 +1069,12 @@ void ncclVerifiablePrepareInput(
}
#undef CASE_OP
}
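// With cudaError_t propagated all the way out, a caller can now write, as the
// updated common.cu above does:
//   CUDACHECK(ncclVerifiablePrepareInput(buf, count, ncclFloat8e4m3, ncclSum,
//                                        nranks, rank, seed, 0, stream));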
#endif
////////////////////////////////////////////////////////////////////////////////
#if !SELF_TEST
namespace {
template<typename T, typename ReduceFn>
__global__ void prepareExpected2(
__global__ void __launch_bounds__(512, 1) prepareExpected2(
T *elts, intptr_t elt_n, ReduceFn op, int rank_n,
uint64_t seed, intptr_t elt_ix0
) {
@@ -909,40 +1094,49 @@ __global__ void prepareExpected2(
}
template<typename ReduceOp>
void prepareExpected1(
cudaError_t prepareExpected1(
void *elts, intptr_t elt_n, int elt_ty, ReduceOp op, int rank_n,
uint64_t seed, intptr_t elt_ix0, cudaStream_t stream
) {
int block_n = std::min<intptr_t>(32, (elt_n + 4*512-1)/(4*512));
#define CASE_TY(T) prepareExpected2<<<block_n, 512, 0, stream>>>((T*)elts, elt_n, op, rank_n, seed, elt_ix0); break;
void const *fn = nullptr;
switch(elt_ty) {
case ncclInt8: CASE_TY(int8_t)
case ncclUint8: CASE_TY(uint8_t)
case ncclInt32: CASE_TY(int32_t)
case ncclUint32: CASE_TY(uint32_t)
case ncclInt64: CASE_TY(int64_t)
case ncclUint64: CASE_TY(uint64_t)
case ncclFloat16: CASE_TY(half)
case ncclInt8: fn = (void const*)&prepareExpected2<int8_t, ReduceOp>; break;
case ncclUint8: fn = (void const*)&prepareExpected2<uint8_t, ReduceOp>; break;
case ncclInt32: fn = (void const*)&prepareExpected2<int32_t, ReduceOp>; break;
case ncclUint32: fn = (void const*)&prepareExpected2<uint32_t, ReduceOp>; break;
case ncclInt64: fn = (void const*)&prepareExpected2<int64_t, ReduceOp>; break;
case ncclUint64: fn = (void const*)&prepareExpected2<uint64_t, ReduceOp>; break;
case ncclFloat16: fn = (void const*)&prepareExpected2<half, ReduceOp>; break;
#if HAVE_ncclBfloat16
case ncclBfloat16: CASE_TY(__nv_bfloat16)
case ncclBfloat16: fn = (void const*)&prepareExpected2<__nv_bfloat16, ReduceOp>; break;
#endif
case ncclFloat32: CASE_TY(float)
case ncclFloat64: CASE_TY(double)
default: assert(0);
#if HAVE_ncclFloat8
case ncclFloat8e4m3: fn = (void const*)&prepareExpected2<__nv_fp8_e4m3, ReduceOp>; break;
case ncclFloat8e5m2: fn = (void const*)&prepareExpected2<__nv_fp8_e5m2, ReduceOp>; break;
#endif
case ncclFloat32: fn = (void const*)&prepareExpected2<float, ReduceOp>; break;
case ncclFloat64: fn = (void const*)&prepareExpected2<double, ReduceOp>; break;
default: assert(0); return cudaErrorInvalidValue;
}
#undef CASE_TY
dim3 grid = {1, 1, 1};
grid.x = (unsigned int)std::min<intptr_t>(32, (elt_n + 4*512-1)/(4*512));
dim3 block = {512, 1, 1};
void *args[6] = {&elts, &elt_n, &op, &rank_n, &seed, &elt_ix0};
if (grid.x == 0) return cudaSuccess;
return cudaLaunchKernel(fn, grid, block, args, 0, stream);
}
}
void ncclVerifiablePrepareExpected(
cudaError_t ncclVerifiablePrepareExpected(
void *elts, intptr_t elt_n, int elt_ty, int red_op, int rank_n,
uint64_t seed, intptr_t elt_ix0, cudaStream_t stream
) {
#define CASE_OP(op) \
if(rank_n == 1) \
prepareExpected1(elts, elt_n, elt_ty, ReduceNil(), rank_n, seed, elt_ix0, stream); \
return prepareExpected1(elts, elt_n, elt_ty, ReduceNil(), rank_n, seed, elt_ix0, stream); \
else \
prepareExpected1(elts, elt_n, elt_ty, op, rank_n, seed, elt_ix0, stream); \
return prepareExpected1(elts, elt_n, elt_ty, op, rank_n, seed, elt_ix0, stream); \
break;
switch(red_op) {
case ncclSum: CASE_OP(ReduceSum())
@@ -958,52 +1152,10 @@ void ncclVerifiablePrepareExpected(
}
#undef CASE_OP
}
#endif
////////////////////////////////////////////////////////////////////////////////
namespace {
/* How we compare floating point values when exactness is impossible is interesting.
* First, we take note that simply reinterpreting integer bits as floating point
* gives us a monotonic mapping which exponentially spaces out floats. Thus
* consecutive integers encode consecutive floats. In general, using integer
* subtraction on the bitpatterns of two floats gives us an integer which is the
* logarithm of their relative difference. But, if the floats always have similar
* exponents, then the integer difference is actually proportional to the
* relative error (this is because we are counting hops in the mantissa bits only,
* not the exponent bits). So a cheap way to compare if two floats are relatively
* close is: abs(intBits(a) - intBits(b)) < tolerance. The following formula
* calculates such a tolerance for a summation of n floats. This formula
* was derived by inspecting the maximum observed integer difference over many
* random runs of summation. The parameter values were computed by the
* companion program "inexact_regress.cu".
*/
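// Worked example: nextafterf(1.0f, 2.0f) is exactly one bit pattern above
// 1.0f, so abs(intBits(a) - intBits(b)) == 1 for those two values; a
// tolerance of 8 would thus permit about three low mantissa bits of
// accumulated rounding error.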
__host__ __device__ unsigned calcSumFloatTolerance(int rank_n, int elt_ty) {
float power, coef;
switch(elt_ty) {
case ncclFloat32:
case ncclFloat64:
power = .51f;
coef = 1.25f;
break;
case ncclFloat16:
power = .91f;
coef = .75f;
break;
#if HAVE_ncclBfloat16
case ncclBfloat16:
power = .91f;
coef = .66f;
break;
#endif
}
#if __CUDA_ARCH__
return 1 + unsigned(coef*powf(float(rank_n), power));
#else
return 1 + unsigned(coef*std::pow(float(rank_n), power));
#endif
}
template<typename T>
__host__ __device__ uint64_t calcDelta(T a, T b) {
union { T t; uint8_t i1; uint16_t i2; uint32_t i4; uint64_t i8; } x, y;
@@ -1020,10 +1172,9 @@ __host__ __device__ uint64_t calcDelta(T a, T b) {
////////////////////////////////////////////////////////////////////////////////
#if !SELF_TEST
namespace {
template<typename T>
__global__ void verifyPrepared(
__global__ void __launch_bounds__(512, 1) verifyPrepared(
T const *results, T const *expected, intptr_t elt_n, unsigned tolerance, int64_t *bad_elt_n
) {
intptr_t i0 = blockIdx.x*(elt_n/gridDim.x);
@@ -1039,16 +1190,34 @@ __global__ void verifyPrepared(
bad += tolerance < delta ? 1 : 0;
#if 0
if(tolerance < delta) {
printf("verifyPrepared ix=%lld got=%g exp=%g\n", (long long)i, (float)results[i], (float)expected[i]);
printf("verifyPrepared ix=%lld got=%g exp=%g tol=%d\n", (long long)i, (float)results[i], (float)expected[i], tolerance);
}
#endif
i += blockDim.x;
}
asm volatile("red.global.add.u64 [%0],%1;" :: "l"(bad_elt_n), "l"(bad));
asm volatile("red.global.add.u64 [%0],%1;" :: "l"(bad_elt_n), "l"(bad) : "memory");
}
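// The "memory" clobber added to the red.global.add asm above tells the
// compiler that the statement writes memory, so the accumulation into
// *bad_elt_n can neither be reordered around surrounding accesses nor
// optimized away.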
cudaError_t verifyPrepared1(int bytePerElt,
void const *results, void const *expected, intptr_t elt_n, unsigned tolerance, int64_t *bad_elt_n, cudaStream_t stream, int block_n
) {
void const *fn = nullptr;
switch(bytePerElt) {
case 1: fn = (void const*)&verifyPrepared<uint8_t>; break;
case 2: fn = (void const*)&verifyPrepared<uint16_t>; break;
case 4: fn = (void const*)&verifyPrepared<uint32_t>; break;
case 8: fn = (void const*)&verifyPrepared<uint64_t>; break;
default: assert(0); return cudaErrorInvalidValue;
}
dim3 grid = {(unsigned int)block_n, 1, 1};
dim3 block = {512, 1, 1};
void *args[5] = {&results, &expected, &elt_n, &tolerance, &bad_elt_n};
if (grid.x == 0) return cudaSuccess;
return cudaLaunchKernel(fn, grid, block, args, 0, stream);
}
template<typename T, typename Uint, typename ReduceFn>
__global__ void verifyInline2(
__global__ void __launch_bounds__(512, 1) verifyInline2(
T const *results, intptr_t elt_n, ReduceFn op, int rank_n, uint64_t seed,
intptr_t elt_ix0, unsigned tolerance, int64_t *bad_elt_n
) {
@@ -1077,39 +1246,52 @@ __global__ void verifyInline2(
#endif
i += blockDim.x;
}
asm volatile("red.global.add.u64 [%0],%1;" :: "l"(bad_elt_n), "l"(bad));
asm volatile("red.global.add.u64 [%0],%1;" :: "l"(bad_elt_n), "l"(bad) : "memory");
}
template<typename T, typename Uint>
void verifyInline1(
cudaError_t verifyInline1(
T const *results, intptr_t elt_n, int red_op, int rank_n, uint64_t seed, intptr_t elt_ix0,
unsigned tolerance, int64_t *bad_elt_n, cudaStream_t stream, int block_n
) {
void const *fn = nullptr;
ReduceNil opnil;
ReduceSum opsum;
ReduceMin opmin;
ReduceMax opmax;
ReduceProd opprod;
ReduceAvg opavg{rank_n};
ReducePreMulSum oppremulsum;
void *args[8] = {&results, &elt_n, nullptr, &rank_n, &seed, &elt_ix0, &tolerance, &bad_elt_n};
#define CASE_OP(op) \
if(rank_n == 1) \
verifyInline2<T, Uint><<<block_n, 512, 0, stream>>> \
((T const*)results, elt_n, ReduceNil(), rank_n, seed, elt_ix0, tolerance, bad_elt_n); \
else \
verifyInline2<T, Uint><<<block_n, 512, 0, stream>>> \
((T const*)results, elt_n, op, rank_n, seed, elt_ix0, tolerance, bad_elt_n); \
break;
if(rank_n == 1) { \
fn = (void const*)&verifyInline2<T, Uint, ReduceNil>; \
args[2] = &opnil; \
} else { \
fn = (void const*)&verifyInline2<T, Uint, decltype(op)>; \
args[2] = &op; \
} break;
switch(red_op) {
case ncclSum: CASE_OP(ReduceSum())
case ncclMin: CASE_OP(ReduceMin())
case ncclMax: CASE_OP(ReduceMax())
case ncclProd: CASE_OP(ReduceProd())
case ncclSum: CASE_OP(opsum)
case ncclMin: CASE_OP(opmin)
case ncclMax: CASE_OP(opmax)
case ncclProd: CASE_OP(opprod)
#if HAVE_ncclAvg
case ncclAvg: CASE_OP(ReduceAvg{rank_n})
case ncclAvg: CASE_OP(opavg)
#endif
#if HAVE_ncclPreMulSum
default: CASE_OP(ReducePreMulSum())
default: CASE_OP(oppremulsum)
#endif
}
#undef CASE_OP
dim3 grid = {(unsigned int)block_n, 1, 1};
dim3 block = {512, 1, 1};
if (grid.x == 0) return cudaSuccess;
return cudaLaunchKernel(fn, grid, block, args, 0, stream);
}
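// The reduce-op objects (opsum, opavg, ...) are stack locals whose addresses
// land in args[2]; this is safe because cudaLaunchKernel copies the kernel
// parameter values when the launch is enqueued, before returning.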
}
void ncclVerifiableVerify(
cudaError_t ncclVerifiableVerify(
void const *results, void const *expected, intptr_t elt_n, int elt_ty,
int red_op, int rank_n, uint64_t seed, intptr_t elt_ix0,
int64_t *bad_elt_n, cudaStream_t stream
@@ -1118,11 +1300,21 @@ void ncclVerifiableVerify(
#if HAVE_ncclBfloat16
floating |= elt_ty == ncclBfloat16;
#endif
#if HAVE_ncclFloat8
floating |= elt_ty == ncclFloat8e4m3;
floating |= elt_ty == ncclFloat8e5m2;
#endif
unsigned tolerance = 0;
#if HAVE_ncclAvg
if (floating && red_op == ncclAvg)
tolerance = calcSumFloatTolerance(rank_n, elt_ty);
if (floating && red_op == ncclAvg) {
// Average does its pre-multiplies in an unspecified floating point format
// (could be the actual type T, or float, or half). That means the premultiply
// the verifier does could differ in the least mantissa digit. After adding
// those two (since avg only has two non-zero contributions) we could be off
// by a distance of 2 units in the last place.
tolerance = 2;
}
#endif
int block_n = std::min<intptr_t>(32, (elt_n + 4*512-1)/(4*512));
@@ -1130,9 +1322,9 @@
*bad_elt_n = 0;
#define CASE_TY(T, Uint) { \
if(expected != nullptr) { \
verifyPrepared<<<block_n, 512, 0, stream>>>((Uint const*)results, (Uint const*)expected, elt_n, tolerance, bad_elt_n); \
return verifyPrepared1(sizeof(T), results, expected, elt_n, tolerance, bad_elt_n, stream, block_n); \
} else { \
verifyInline1<T, Uint>((T const*)results, elt_n, red_op, rank_n, seed, elt_ix0, tolerance, bad_elt_n, stream, block_n); \
return verifyInline1<T, Uint>((T const*)results, elt_n, red_op, rank_n, seed, elt_ix0, tolerance, bad_elt_n, stream, block_n); \
} \
} break;
switch(elt_ty) {
@@ -1143,29 +1335,30 @@
case ncclInt64: CASE_TY(int64_t, uint64_t)
case ncclUint64: CASE_TY(uint64_t, uint64_t)
case ncclFloat16: CASE_TY(half, uint16_t)
#if HAVE_ncclFloat8
case ncclFloat8e4m3: CASE_TY(__nv_fp8_e4m3, uint8_t)
case ncclFloat8e5m2: CASE_TY(__nv_fp8_e5m2, uint8_t)
#endif
#if HAVE_ncclBfloat16
case ncclBfloat16: CASE_TY(__nv_bfloat16, uint16_t)
#endif
case ncclFloat32: CASE_TY(float, uint32_t)
case ncclFloat64: CASE_TY(double, uint64_t)
default: assert(0);
default: assert(0); return cudaErrorInvalidValue;
}
#undef CASE_TY
}
#endif
////////////////////////////////////////////////////////////////////////////////
#if SELF_TEST
#include <iostream>
namespace {
template<typename T, typename Op>
__device__ void sweep2(int ty, char const *tyname, Op op, char const *opname, int rank_n) {
//if(!std::is_same<T,half>::value) return;
//if(!std::is_same<Op,ReduceProd>::value) return;
//if(rank_n!=3) return;
unsigned tolerance = !IsIntegral<T>::value && std::is_same<Op,ReduceAvg>::value ? calcSumFloatTolerance(rank_n, ty) : 0;
unsigned tolerance = !IsIntegral<T>::value && std::is_same<Op,ReduceAvg>::value ? 2 : 0;
uint64_t seed = 0xc8e2bed69766d533;
for(int ix=threadIdx.x; ix < 10000; ix+=blockDim.x) {
@@ -1202,7 +1395,7 @@ __device__ void sweep1(int ty, char const *tyname) {
}
}
__global__ void sweep() {
__global__ void __launch_bounds__(512, 1) sweep() {
sweep1<int8_t>(ncclInt8, "int8");
sweep1<uint8_t>(ncclUint8, "uint8");
sweep1<int32_t>(ncclInt32, "int32");
@@ -1210,18 +1403,18 @@
sweep1<int64_t>(ncclInt64, "int64");
sweep1<uint64_t>(ncclUint64, "uint64");
sweep1<half>(ncclFloat16, "half");
#if HAVE_ncclFloat8
sweep1<__nv_fp8_e4m3>(ncclFloat8e4m3, "float8e4m3");
sweep1<__nv_fp8_e5m2>(ncclFloat8e5m2, "float8e5m2");
#endif
#if HAVE_ncclBfloat16
sweep1<__nv_bfloat16>(ncclBfloat16, "bfloat16");
#endif
sweep1<float>(ncclFloat32, "float");
sweep1<double>(ncclFloat64, "double");
}
int main(int arg_n, char **args) {
std::cerr<<"You are hoping to see no output beyond this line."<<std::endl;
cudaSetDevice(0);
sweep<<<1,512>>>();
cudaDeviceSynchronize();
return 0;
}
#endif
void ncclVerifiableLaunchSelfTest() {
sweep<<<1,512>>>();
}


@@ -34,13 +34,13 @@ __host__ __device__ T ncclVerifiablePremulScalar(int rank_me) {
}
// Enqueue kernel to generate data which is to be reduced.
void ncclVerifiablePrepareInput(
cudaError_t ncclVerifiablePrepareInput(
void *elts, intptr_t elt_n, int elt_ty, int red_op, int rank_n, int rank_me,
uint64_t seed, intptr_t elt_ix0, cudaStream_t stream
);
// Enqueue kernel to generate expected results of reduction.
void ncclVerifiablePrepareExpected(
cudaError_t ncclVerifiablePrepareExpected(
void *elts, intptr_t elt_n, int elt_ty, int red_op, int rank_n,
uint64_t seed, intptr_t elt_ix0, cudaStream_t stream
);
@@ -51,9 +51,10 @@ void ncclVerifiablePrepareExpected(
// which can be costly. Thus if you plan to run the same reduction multiple
// times it is advantageous to precompute the expected values with
// ncclVerifiablePrepareExpected and pass them as `expected` here.
void ncclVerifiableVerify(
cudaError_t ncclVerifiableVerify(
void const *results, void const *expected, intptr_t elt_n, int elt_ty,
int red_op, int rank_n, uint64_t seed, intptr_t elt_ix0,
int64_t *bad_elt_n, cudaStream_t stream
);
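// Typical flow (a sketch; `exp`, `out`, and `bad` are hypothetical buffers):
// prepare the expected values once, then verify any number of runs cheaply.
//   CUDACHECK(ncclVerifiablePrepareExpected(exp, n, ncclFloat32, ncclSum,
//                                           nranks, seed, 0, stream));
//   // ... run the collective into `out` ...
//   CUDACHECK(ncclVerifiableVerify(out, exp, n, ncclFloat32, ncclSum,
//                                  nranks, seed, 0, &bad, stream));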
#endif