mirror of
https://github.com/NVIDIA/nccl-tests.git
synced 2026-04-25 08:58:18 +08:00
Add support for ncclAvg operation
This commit is contained in:
parent
e37545e491
commit
cde7e769c1
@ -84,7 +84,7 @@ testResult_t AllGatherRunTest(struct threadArgs* args, int root, ncclDataType_t
|
||||
run_types = &type;
|
||||
run_typenames = &typeName;
|
||||
} else {
|
||||
type_count = ncclNumTypes;
|
||||
type_count = test_typenum;
|
||||
run_types = test_types;
|
||||
run_typenames = test_typenames;
|
||||
}
|
||||
|
||||
@ -83,7 +83,7 @@ testResult_t AllReduceRunTest(struct threadArgs* args, int root, ncclDataType_t
|
||||
run_types = &type;
|
||||
run_typenames = &typeName;
|
||||
} else {
|
||||
type_count = ncclNumTypes;
|
||||
type_count = test_typenum;
|
||||
run_types = test_types;
|
||||
run_typenames = test_typenames;
|
||||
}
|
||||
@ -93,7 +93,7 @@ testResult_t AllReduceRunTest(struct threadArgs* args, int root, ncclDataType_t
|
||||
run_ops = &op;
|
||||
run_opnames = &opName;
|
||||
} else {
|
||||
op_count = ncclNumOps;
|
||||
op_count = test_opnum;
|
||||
run_ops = test_ops;
|
||||
run_opnames = test_opnames;
|
||||
}
|
||||
|
||||
@ -102,7 +102,7 @@ testResult_t AlltoAllRunTest(struct threadArgs* args, int root, ncclDataType_t t
|
||||
run_types = &type;
|
||||
run_typenames = &typeName;
|
||||
} else {
|
||||
type_count = ncclNumTypes;
|
||||
type_count = test_typenum;
|
||||
run_types = test_types;
|
||||
run_typenames = test_typenames;
|
||||
}
|
||||
|
||||
@ -92,7 +92,7 @@ testResult_t BroadcastRunTest(struct threadArgs* args, int root, ncclDataType_t
|
||||
run_types = &type;
|
||||
run_typenames = &typeName;
|
||||
} else {
|
||||
type_count = ncclNumTypes;
|
||||
type_count = test_typenum;
|
||||
run_types = test_types;
|
||||
run_typenames = test_typenames;
|
||||
}
|
||||
|
||||
@ -11,23 +11,41 @@
|
||||
#include <libgen.h>
|
||||
#include "cuda.h"
|
||||
|
||||
int test_ncclVersion = 0; // init'd with ncclGetVersion()
|
||||
|
||||
#if NCCL_MAJOR >= 2
|
||||
ncclDataType_t test_types[ncclNumTypes] = {ncclInt8, ncclUint8, ncclInt32, ncclUint32, ncclInt64, ncclUint64, ncclHalf, ncclFloat, ncclDouble,
|
||||
#if defined(__CUDA_BF16_TYPES_EXIST__)
|
||||
#if defined(__CUDA_BF16_TYPES_EXIST__) && NCCL_VERSION_CODE >= NCCL_VERSION(2,10,0)
|
||||
ncclBfloat16
|
||||
#endif
|
||||
};
|
||||
const char *test_typenames[ncclNumTypes] = {"int8", "uint8", "int32", "uint32", "int64", "uint64", "half", "float", "double",
|
||||
#if defined(__CUDA_BF16_TYPES_EXIST__)
|
||||
#if defined(__CUDA_BF16_TYPES_EXIST__) && NCCL_VERSION_CODE >= NCCL_VERSION(2,10,0)
|
||||
"bfloat16"
|
||||
#endif
|
||||
};
|
||||
|
||||
#if defined(__CUDA_BF16_TYPES_EXIST__) && NCCL_VERSION_CODE >= NCCL_VERSION(2,10,0)
|
||||
int test_typenum = 10;
|
||||
#else
|
||||
int test_typenum = 9;
|
||||
#endif
|
||||
|
||||
#else
|
||||
ncclDataType_t test_types[ncclNumTypes] = {ncclChar, ncclInt, ncclHalf, ncclFloat, ncclDouble, ncclInt64, ncclUint64};
|
||||
const char *test_typenames[ncclNumTypes] = {"char", "int", "half", "float", "double", "int64", "uint64"};
|
||||
int test_typenum = 7;
|
||||
#endif
|
||||
|
||||
#if NCCL_VERSION_CODE >= NCCL_VERSION(2,10,0)
|
||||
ncclRedOp_t test_ops[ncclNumOps] = {ncclSum, ncclProd, ncclMax, ncclMin, ncclAvg};
|
||||
const char *test_opnames[ncclNumOps] = {"sum", "prod", "max", "min", "avg"};
|
||||
int test_opnum = 5;
|
||||
#else
|
||||
ncclRedOp_t test_ops[ncclNumOps] = {ncclSum, ncclProd, ncclMax, ncclMin};
|
||||
const char *test_opnames[ncclNumOps] = {"sum", "prod", "max", "min"};
|
||||
int test_opnum = 4;
|
||||
#endif
|
||||
|
||||
thread_local int is_main_thread = 0;
|
||||
|
||||
@ -126,7 +144,7 @@ void deltaKern(void* A_, void* B_, size_t count, double* max) {
|
||||
if( delta > locmax ) {
|
||||
locmax = delta;
|
||||
#ifdef DEBUG_PRINT
|
||||
if (delta > .1) printf("Error at %d/%ld : %f != %f\n", i, count, toFloat(A[i]), toFloat(B[i]));
|
||||
if (delta > .1) printf("Error at %ld/%ld(%p) : %f != %f\n", i, count, B+i, toFloat(A[i]), toFloat(B[i]));
|
||||
#endif
|
||||
}
|
||||
}
|
||||
@ -222,23 +240,48 @@ __device__ half ncclOpMax(half a, half b) { return __half2float(a)>__half2float(
|
||||
template<>
|
||||
__device__ half ncclOpMin(half a, half b) { return __half2float(a)<__half2float(b) ? a : b; }
|
||||
|
||||
template<typename T, T (*Op)(T, T)>
|
||||
template<typename T>
|
||||
__device__ T ncclPostOpIdent(T x, int n) { return x; }
|
||||
|
||||
template<typename T>
|
||||
__device__ T ncclPostOpDiv(T x, int n) { return x/n; }
|
||||
template<>
|
||||
__device__ half ncclPostOpDiv<half>(half x, int n) { return __float2half(__half2float(x)/n); }
|
||||
#if defined(__CUDA_BF16_TYPES_EXIST__)
|
||||
template<>
|
||||
__device__ __nv_bfloat16 ncclPostOpDiv<__nv_bfloat16>(__nv_bfloat16 x, int n) { return __float2bfloat16(__bfloat162float(x)/n); }
|
||||
#endif
|
||||
|
||||
template<typename T, T (*Op)(T, T), T(*PostOp)(T,int)>
|
||||
__global__ void InitDataReduceKernel(T* data, const size_t N, const size_t offset, const int rep, const int nranks) {
|
||||
for (size_t o=blockIdx.x*blockDim.x+threadIdx.x; o<N; o+=gridDim.x*blockDim.x) {
|
||||
T val = testValue<T>(o+offset, rep, 0);
|
||||
for (int i=1; i<nranks; i++) {
|
||||
val = Op(val, testValue<T>(o+offset, rep, i));
|
||||
}
|
||||
data[o] = val;
|
||||
data[o] = PostOp(val, nranks);
|
||||
}
|
||||
}
|
||||
|
||||
#define KERN(type, op) (void*)InitDataReduceKernel<type, op<type>>
|
||||
#define OPS(type) KERN(type, ncclOpSum), KERN(type, ncclOpProd), KERN(type, ncclOpMax), KERN(type, ncclOpMin)
|
||||
#define KERN(type, op, postop) (void*)InitDataReduceKernel<type, op<type>, postop<type> >
|
||||
#if NCCL_VERSION_CODE >= NCCL_VERSION(2,10,0)
|
||||
#define OPS(type) \
|
||||
KERN(type, ncclOpSum, ncclPostOpIdent), \
|
||||
KERN(type, ncclOpProd, ncclPostOpIdent), \
|
||||
KERN(type, ncclOpMax, ncclPostOpIdent), \
|
||||
KERN(type, ncclOpMin, ncclPostOpIdent), \
|
||||
KERN(type, ncclOpSum/*Avg*/, ncclPostOpDiv)
|
||||
#else
|
||||
#define OPS(type) \
|
||||
KERN(type, ncclOpSum, ncclPostOpIdent), \
|
||||
KERN(type, ncclOpProd, ncclPostOpIdent), \
|
||||
KERN(type, ncclOpMax, ncclPostOpIdent), \
|
||||
KERN(type, ncclOpMin, ncclPostOpIdent)
|
||||
#endif
|
||||
|
||||
static void* const redInitDataKerns[ncclNumOps*ncclNumTypes] = {
|
||||
OPS(int8_t), OPS(uint8_t), OPS(int32_t), OPS(uint32_t), OPS(int64_t), OPS(uint64_t), OPS(half), OPS(float), OPS(double),
|
||||
#if defined(__CUDA_BF16_TYPES_EXIST__)
|
||||
#if defined(__CUDA_BF16_TYPES_EXIST__) && NCCL_VERSION_CODE >= NCCL_VERSION(2,10,0)
|
||||
OPS(__nv_bfloat16)
|
||||
#endif
|
||||
};
|
||||
@ -267,7 +310,7 @@ static void* const initDataKerns[ncclNumTypes] = {
|
||||
(void*)InitDataKernel< half>,
|
||||
(void*)InitDataKernel< float>,
|
||||
(void*)InitDataKernel< double>,
|
||||
#if defined(__CUDA_BF16_TYPES_EXIST__)
|
||||
#if defined(__CUDA_BF16_TYPES_EXIST__) && NCCL_VERSION_CODE >= NCCL_VERSION(2,10,0)
|
||||
(void*)InitDataKernel<__nv_bfloat16>,
|
||||
#endif
|
||||
};
|
||||
@ -367,7 +410,7 @@ testResult_t testStreamSynchronize(int ngpus, cudaStream_t* streams, ncclComm_t*
|
||||
if (cudaErr != cudaErrorNotReady) CUDACHECK(cudaErr);
|
||||
|
||||
#if NCCL_VERSION_CODE >= NCCL_VERSION(2,4,0)
|
||||
if (comms) {
|
||||
if (test_ncclVersion >= NCCL_VERSION(2,4,0) && comms) {
|
||||
ncclResult_t ncclAsyncErr;
|
||||
NCCLCHECK(ncclCommGetAsyncError(comms[i], &ncclAsyncErr));
|
||||
if (ncclAsyncErr != ncclSuccess) {
|
||||
@ -602,6 +645,17 @@ int main(int argc, char* argv[]) {
|
||||
// Make sure everyline is flushed so that we see the progress of the test
|
||||
setlinebuf(stdout);
|
||||
|
||||
#if NCCL_VERSION_CODE >= NCCL_VERSION(2,4,0)
|
||||
ncclGetVersion(&test_ncclVersion);
|
||||
#else
|
||||
test_ncclVersion = NCCL_VERSION_CODE;
|
||||
#endif
|
||||
//printf("# NCCL_VERSION_CODE=%d ncclGetVersion=%d\n", NCCL_VERSION_CODE, test_ncclVersion);
|
||||
if (NCCL_VERSION_CODE >= NCCL_VERSION(2,10,0) && test_ncclVersion < NCCL_VERSION(2,10,0)) {
|
||||
test_opnum -= 1; // exclude ncclAvg
|
||||
test_typenum -= 1; // exclude bfloat16
|
||||
}
|
||||
|
||||
// Parse args
|
||||
int longindex;
|
||||
static struct option longopts[] = {
|
||||
@ -653,7 +707,7 @@ int main(int argc, char* argv[]) {
|
||||
iters = (int)strtol(optarg, NULL, 0);
|
||||
break;
|
||||
case 'm':
|
||||
#if NCCL_MAJOR >= 2 && NCCL_MINOR >= 2
|
||||
#if NCCL_MAJOR > 2 || (NCCL_MAJOR >= 2 && NCCL_MINOR >= 2)
|
||||
agg_iters = (int)strtol(optarg, NULL, 0);
|
||||
#else
|
||||
printf("Option -m not supported before NCCL 2.2. Ignoring\n");
|
||||
@ -693,7 +747,11 @@ int main(int argc, char* argv[]) {
|
||||
"[-w,--warmup_iters <warmup iteration count>] \n\t"
|
||||
"[-p,--parallel_init <0/1>] \n\t"
|
||||
"[-c,--check <0/1>] \n\t"
|
||||
#if NCCL_VERSION_CODE >= NCCL_VERSION(2,10,0)
|
||||
"[-o,--op <sum/prod/min/max/avg/all>] \n\t"
|
||||
#else
|
||||
"[-o,--op <sum/prod/min/max/all>] \n\t"
|
||||
#endif
|
||||
"[-d,--datatype <nccltype/all>] \n\t"
|
||||
"[-r,--root <root>] \n\t"
|
||||
"[-z,--blocking <0/1>] \n\t"
|
||||
@ -701,8 +759,8 @@ int main(int argc, char* argv[]) {
|
||||
basename(argv[0]));
|
||||
return 0;
|
||||
default:
|
||||
printf("invalid option \n");
|
||||
printf("USAGE: %s \n\t"
|
||||
if (c != 'h') printf("invalid option '%c'\n", c);
|
||||
printf("USAGE: %s \n\t"
|
||||
"[-t,--nthreads <num threads>] \n\t"
|
||||
"[-g,--ngpus <gpus per thread>] \n\t"
|
||||
"[-b,--minbytes <min size in bytes>] \n\t"
|
||||
@ -714,7 +772,11 @@ int main(int argc, char* argv[]) {
|
||||
"[-w,--warmup_iters <warmup iteration count>] \n\t"
|
||||
"[-p,--parallel_init <0/1>] \n\t"
|
||||
"[-c,--check <0/1>] \n\t"
|
||||
#if NCCL_VERSION_CODE >= NCCL_VERSION(2,10,0)
|
||||
"[-o,--op <sum/prod/min/max/avg/all>] \n\t"
|
||||
#else
|
||||
"[-o,--op <sum/prod/min/max/all>] \n\t"
|
||||
#endif
|
||||
"[-d,--datatype <nccltype/all>] \n\t"
|
||||
"[-r,--root <root>] \n\t"
|
||||
"[-z,--blocking <0/1>] \n\t"
|
||||
@ -899,8 +961,8 @@ testResult_t run() {
|
||||
|
||||
// Free off CUDA allocated memory
|
||||
for (int i=0; i<nGpus*nThreads; i++) {
|
||||
CUDACHECK(cudaFree(sendbuffs[i]));
|
||||
CUDACHECK(cudaFree(recvbuffs[i]));
|
||||
if (sendbuffs[i]) CUDACHECK(cudaFree((char*)sendbuffs[i]));
|
||||
if (recvbuffs[i]) CUDACHECK(cudaFree((char*)recvbuffs[i]));
|
||||
if (datacheck) CUDACHECK(cudaFree(expected[i]));
|
||||
}
|
||||
CUDACHECK(cudaFreeHost(delta));
|
||||
|
||||
@ -235,10 +235,13 @@ static size_t wordSize(ncclDataType_t type) {
|
||||
}
|
||||
}
|
||||
|
||||
extern int test_ncclVersion; // init'd with ncclGetVersion()
|
||||
extern ncclDataType_t test_types[ncclNumTypes];
|
||||
extern const char *test_typenames[ncclNumTypes];
|
||||
extern ncclRedOp_t test_ops[ncclNumOps];
|
||||
extern const char *test_opnames[ncclNumOps];
|
||||
extern int test_opnum;
|
||||
extern int test_typenum;
|
||||
|
||||
static int ncclstringtotype(char *str) {
|
||||
for (int t=0; t<ncclNumTypes; t++) {
|
||||
@ -254,7 +257,7 @@ static int ncclstringtotype(char *str) {
|
||||
}
|
||||
|
||||
static int ncclstringtoop (char *str) {
|
||||
for (int o=0; o<ncclNumOps; o++) {
|
||||
for (int o=0; o<test_opnum; o++) {
|
||||
if (strcmp(str, test_opnames[o]) == 0) {
|
||||
return o;
|
||||
}
|
||||
|
||||
@ -83,7 +83,7 @@ testResult_t ReduceRunTest(struct threadArgs* args, int root, ncclDataType_t typ
|
||||
run_types = &type;
|
||||
run_typenames = &typeName;
|
||||
} else {
|
||||
type_count = ncclNumTypes;
|
||||
type_count = test_typenum;
|
||||
run_types = test_types;
|
||||
run_typenames = test_typenames;
|
||||
}
|
||||
@ -93,7 +93,7 @@ testResult_t ReduceRunTest(struct threadArgs* args, int root, ncclDataType_t typ
|
||||
run_ops = &op;
|
||||
run_opnames = &opName;
|
||||
} else {
|
||||
op_count = ncclNumOps;
|
||||
op_count = test_opnum;
|
||||
run_ops = test_ops;
|
||||
run_opnames = test_opnames;
|
||||
}
|
||||
|
||||
@ -84,7 +84,7 @@ testResult_t ReduceScatterRunTest(struct threadArgs* args, int root, ncclDataTyp
|
||||
run_types = &type;
|
||||
run_typenames = &typeName;
|
||||
} else {
|
||||
type_count = ncclNumTypes;
|
||||
type_count = test_typenum;
|
||||
run_types = test_types;
|
||||
run_typenames = test_typenames;
|
||||
}
|
||||
@ -94,7 +94,7 @@ testResult_t ReduceScatterRunTest(struct threadArgs* args, int root, ncclDataTyp
|
||||
run_opnames = &opName;
|
||||
op_count = 1;
|
||||
} else {
|
||||
op_count = sizeof(test_ops)/sizeof(test_ops[0]);
|
||||
op_count = test_opnum;
|
||||
run_ops = test_ops;
|
||||
run_opnames = test_opnames;
|
||||
}
|
||||
|
||||
Loading…
Reference in New Issue
Block a user