Mirror of https://github.com/NVIDIA/nccl-tests.git, synced 2026-04-25 08:58:18 +08:00
Add support for new datatype: bfloat16

commit e37545e491 (parent 0b30de583f)
@@ -12,8 +12,16 @@
 #include "cuda.h"
 
 #if NCCL_MAJOR >= 2
-ncclDataType_t test_types[ncclNumTypes] = {ncclInt8, ncclUint8, ncclInt32, ncclUint32, ncclInt64, ncclUint64, ncclHalf, ncclFloat, ncclDouble};
-const char *test_typenames[ncclNumTypes] = {"int8", "uint8", "int32", "uint32", "int64", "uint64", "half", "float", "double"};
+ncclDataType_t test_types[ncclNumTypes] = {ncclInt8, ncclUint8, ncclInt32, ncclUint32, ncclInt64, ncclUint64, ncclHalf, ncclFloat, ncclDouble,
+#if defined(__CUDA_BF16_TYPES_EXIST__)
+  ncclBfloat16
+#endif
+};
+const char *test_typenames[ncclNumTypes] = {"int8", "uint8", "int32", "uint32", "int64", "uint64", "half", "float", "double",
+#if defined(__CUDA_BF16_TYPES_EXIST__)
+  "bfloat16"
+#endif
+};
 #else
 ncclDataType_t test_types[ncclNumTypes] = {ncclChar, ncclInt, ncclHalf, ncclFloat, ncclDouble, ncclInt64, ncclUint64};
 const char *test_typenames[ncclNumTypes] = {"char", "int", "half", "float", "double", "int64", "uint64"};
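The __CUDA_BF16_TYPES_EXIST__ guard is what keeps these tables compiling on older toolkits: the macro is defined by CUDA's <cuda_bf16.h> header when the __nv_bfloat16 type is available (CUDA 11.0 and later). A standalone sketch of the same detection pattern, not part of this commit:

// Sketch: probe for bfloat16 support the same way the guard above does.
// Build with: nvcc probe_bf16.cu
#include <cstdio>
#include <cuda_runtime.h>
#if CUDART_VERSION >= 11000
#include <cuda_bf16.h>   // defines __CUDA_BF16_TYPES_EXIST__ when bf16 exists
#endif

int main() {
#if defined(__CUDA_BF16_TYPES_EXIST__)
  printf("bfloat16 types exist; bf16 test entries are compiled in\n");
#else
  printf("no bfloat16 support; bf16 test entries are compiled out\n");
#endif
  return 0;
}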
@@ -61,6 +69,9 @@ double parsesize(char *value) {
 double DeltaMaxValue(ncclDataType_t type) {
   switch(type) {
     case ncclHalf: return 1e-2;
+#if defined(__CUDA_BF16_TYPES_EXIST__)
+    case ncclBfloat16: return 1e-2;
+#endif
     case ncclFloat: return 1e-5;
     case ncclDouble: return 1e-12;
     case ncclInt:
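bfloat16 keeps float's 8-bit exponent but has only 7 explicit mantissa bits, so a single rounding step can cost a relative error of up to 2^-8 (about 0.4%); that is why it gets the same loose 1e-2 delta bound as half rather than float's 1e-5. A round-trip check of that bound, as a sketch assuming the host-callable conversion intrinsics from <cuda_bf16.h>:

// Sketch: relative error of one float -> bf16 -> float round trip.
// It stays below 2^-8 ~= 3.9e-3, well inside the 1e-2 bound above.
#include <cstdio>
#include <cmath>
#include <cuda_bf16.h>

int main() {
  float x = 3.14159265f;
  float r = __bfloat162float(__float2bfloat16(x));
  printf("x=%g  bf16(x)=%g  rel err=%g\n", x, r, std::fabs(x - r) / std::fabs(x));
  return 0;
}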
@@ -95,6 +106,12 @@ template<> __device__
 float toFloat(half a) {
   return __half2float(a);
 }
+#if defined(__CUDA_BF16_TYPES_EXIST__)
+template<> __device__
+float toFloat(__nv_bfloat16 a) {
+  return __bfloat162float(a);
+}
+#endif
 
 template<typename T, int BSIZE> __global__
 void deltaKern(void* A_, void* B_, size_t count, double* max) {
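The new toFloat<> specialization lets deltaKern compare bf16 buffers in float arithmetic; the promotion is exact, since every bfloat16 value is just a float with the low 16 mantissa bits cleared. A self-contained sketch of the same compare-in-float idea, simplified from deltaKern (kernel name and sizes are made up):

// Sketch: element-wise |a-b| for bf16 buffers, computed in float.
#include <cstdio>
#include <cuda_bf16.h>

__global__ void absDiffKern(const __nv_bfloat16* a, const __nv_bfloat16* b,
                            float* delta, int n) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n)
    delta[i] = fabsf(__bfloat162float(a[i]) - __bfloat162float(b[i]));
}

int main() {
  const int n = 4;
  __nv_bfloat16 *a, *b; float* d;
  cudaMallocManaged(&a, n * sizeof(*a));
  cudaMallocManaged(&b, n * sizeof(*b));
  cudaMallocManaged(&d, n * sizeof(*d));
  for (int i = 0; i < n; i++) {
    a[i] = __float2bfloat16(0.5f * i);
    b[i] = __float2bfloat16(0.5f * i + 0.001f);
  }
  absDiffKern<<<1, n>>>(a, b, d, n);
  cudaDeviceSynchronize();
  for (int i = 0; i < n; i++) printf("delta[%d] = %g\n", i, d[i]);
  return 0;
}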
@@ -128,6 +145,10 @@ void deltaKern(void* A_, void* B_, size_t count, double* max) {
 
 testResult_t CheckDelta(void* expected, void* results, size_t count, ncclDataType_t type, double* devmax) {
   switch (type) {
+#if defined(__CUDA_BF16_TYPES_EXIST__)
+    case ncclBfloat16:
+      deltaKern<__nv_bfloat16, 512><<<1, 512>>>(results, expected, count, devmax); break;
+#endif
     case ncclHalf:
       deltaKern<half, 512><<<1, 512>>>(results, expected, count, devmax); break;
     case ncclFloat:
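CheckDelta is the usual runtime-to-compile-time bridge: a switch on the ncclDataType_t tag picks a deltaKern template instantiation, and the bf16 case only exists when the guard holds. The dispatch pattern in isolation, as a sketch with hypothetical names:

// Sketch: dispatch a runtime type tag to a templated kernel instantiation.
#include <cstdio>
#include <cuda_fp16.h>
#include <cuda_bf16.h>

template<typename T> __global__ void fillKern(T* p, int n) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) p[i] = T(1.0f);
}

enum MyType { kHalf, kBfloat16 };  // stand-in for ncclDataType_t

void launchFill(MyType t, void* p, int n) {
  switch (t) {
    case kHalf:
      fillKern<half><<<1, n>>>((half*)p, n); break;
#if defined(__CUDA_BF16_TYPES_EXIST__)
    case kBfloat16:
      fillKern<__nv_bfloat16><<<1, n>>>((__nv_bfloat16*)p, n); break;
#endif
  }
}

int main() {
  half* h;
  cudaMallocManaged(&h, 8 * sizeof(*h));
  launchFill(kHalf, h, 8);
  cudaDeviceSynchronize();
  printf("h[0] = %g\n", __half2float(h[0]));
  return 0;
}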
@@ -174,6 +195,12 @@ template<>
 __device__ half testValue<half>(const size_t offset, const int rep, const int rank) {
   return __float2half(testValue<float>(offset, rep, rank));
 }
+#if defined(__CUDA_BF16_TYPES_EXIST__)
+template<>
+__device__ __nv_bfloat16 testValue<__nv_bfloat16>(const size_t offset, const int rep, const int rank) {
+  return __float2bfloat16(testValue<float>(offset, rep, rank));
+}
+#endif
 
 // Operations
 template<typename T>
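Like the half specialization right above it, the bf16 testValue just narrows the float reference value, so expected results are defined once in float for every precision. The shape of that pattern, as a sketch with a made-up value function (compiles with nvcc -c):

// Sketch: define reference values in float, then narrow per type.
#include <cstddef>
#include <cuda_fp16.h>
#include <cuda_bf16.h>

template<typename T> __device__ T sampleValue(size_t offset);

template<> __device__ float sampleValue<float>(size_t offset) {
  return 0.25f * (offset % 8);  // made-up reference pattern
}
template<> __device__ half sampleValue<half>(size_t offset) {
  return __float2half(sampleValue<float>(offset));
}
#if defined(__CUDA_BF16_TYPES_EXIST__)
template<> __device__ __nv_bfloat16 sampleValue<__nv_bfloat16>(size_t offset) {
  return __float2bfloat16(sampleValue<float>(offset));
}
#endif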
@@ -210,7 +237,10 @@ __global__ void InitDataReduceKernel(T* data, const size_t N, const size_t offse
 #define OPS(type) KERN(type, ncclOpSum), KERN(type, ncclOpProd), KERN(type, ncclOpMax), KERN(type, ncclOpMin)
 
 static void* const redInitDataKerns[ncclNumOps*ncclNumTypes] = {
-  OPS(int8_t), OPS(uint8_t), OPS(int32_t), OPS(uint32_t), OPS(int64_t), OPS(uint64_t), OPS(half), OPS(float), OPS(double)
+  OPS(int8_t), OPS(uint8_t), OPS(int32_t), OPS(uint32_t), OPS(int64_t), OPS(uint64_t), OPS(half), OPS(float), OPS(double),
+#if defined(__CUDA_BF16_TYPES_EXIST__)
+  OPS(__nv_bfloat16)
+#endif
 };
 
 testResult_t InitDataReduce(void* data, const size_t count, const size_t offset, ncclDataType_t type, ncclRedOp_t op, const int rep, const int nranks) {
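Each OPS(type) expands to the four reduction-op kernels for that type, so the table holds ncclNumOps consecutive entries per type and a (type, op) pair selects one flat slot (type*numOps + op in the miniature below). A sketch of that layout, not the repo's table:

// Sketch: a type-major table, numOps entries per type,
// indexed as type*numOps + op.
#include <cstdio>

enum { kNumOps = 2 };
static const char* const table[] = {
  "sum<half>",  "prod<half>",   // OPS(half)
  "sum<float>", "prod<float>",  // OPS(float)
};

int main() {
  int type = 1, op = 0;                        // float, sum
  printf("%s\n", table[type * kNumOps + op]);  // prints sum<float>
  return 0;
}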
@@ -236,7 +266,10 @@ static void* const initDataKerns[ncclNumTypes] = {
   (void*)InitDataKernel<uint64_t>,
   (void*)InitDataKernel<    half>,
   (void*)InitDataKernel<   float>,
-  (void*)InitDataKernel<  double>
+  (void*)InitDataKernel<  double>,
+#if defined(__CUDA_BF16_TYPES_EXIST__)
+  (void*)InitDataKernel<__nv_bfloat16>,
+#endif
 };
 
 template<typename T>
@@ -213,6 +213,9 @@ static size_t wordSize(ncclDataType_t type) {
 #endif
       return 1;
     case ncclHalf:
+#if defined(__CUDA_BF16_TYPES_EXIST__)
+    case ncclBfloat16:
+#endif
     //case ncclFloat16:
       return 2;
     case ncclInt:
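Falling through to return 2 alongside ncclHalf is right because __nv_bfloat16, like half, is a 2-byte type. A one-liner to confirm, as a sketch:

// Sketch: the storage sizes wordSize() reports for half and bf16.
#include <cstdio>
#include <cuda_fp16.h>
#include <cuda_bf16.h>

int main() {
  printf("sizeof(half)=%zu  sizeof(__nv_bfloat16)=%zu\n",
         sizeof(half), sizeof(__nv_bfloat16));  // both 2
  return 0;
}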