Merge pull request #46 from NVIDIA/p2p

Add alltoall perf test
This commit is contained in:
Sylvain Jeaugey 2020-06-17 10:45:29 -07:00 committed by GitHub
commit a7b304dde5
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 129 additions and 8 deletions

View File

@ -59,7 +59,7 @@ NVLDFLAGS += $(LIBRARIES:%=-l%)
DST_DIR := $(BUILDDIR)
SRC_FILES := $(wildcard *.cu)
OBJ_FILES := $(SRC_FILES:%.cu=${DST_DIR}/%.o)
BIN_FILES_LIST := all_reduce all_gather broadcast reduce_scatter reduce
BIN_FILES_LIST := all_reduce all_gather broadcast reduce_scatter reduce alltoall
BIN_FILES := $(BIN_FILES_LIST:%=${DST_DIR}/%_perf)
build: ${BIN_FILES}

117
src/alltoall.cu Normal file
View File

@ -0,0 +1,117 @@
/*************************************************************************
* Copyright (c) 2016-2020, NVIDIA CORPORATION. All rights reserved.
*
* See LICENSE.txt for license information
************************************************************************/
#include "cuda_runtime.h"
#include "common.h"
void print_header() {
PRINT("# %10s %12s %6s %6s out-of-place in-place \n", "", "", "", "");
PRINT("# %10s %12s %6s %6s %7s %6s %6s %5s %7s %6s %6s %5s\n", "size", "count", "type", "redop",
"time", "algbw", "busbw", "error", "time", "algbw", "busbw", "error");
PRINT("# %10s %12s %6s %6s %7s %6s %6s %5s %7s %6s %6s %5s\n", "(B)", "(elements)", "", "",
"(us)", "(GB/s)", "(GB/s)", "", "(us)", "(GB/s)", "(GB/s)", "");
}
void print_line_header (size_t size, size_t count, const char *typeName, const char *opName, int root) {
PRINT("%12li %12li %6s %6s", size, count, typeName, opName);
}
void AlltoAllGetCollByteCount(size_t *sendcount, size_t *recvcount, size_t *paramcount, size_t *sendInplaceOffset, size_t *recvInplaceOffset, size_t count, int nranks) {
*sendcount = (count/nranks)*nranks;
*recvcount = (count/nranks)*nranks;
*sendInplaceOffset = 0;
*recvInplaceOffset = 0;
*paramcount = count/nranks;
}
testResult_t AlltoAllInitData(struct threadArgs* args, ncclDataType_t type, ncclRedOp_t op, int root, int rep, int in_place) {
size_t sendcount = args->sendBytes / wordSize(type);
size_t recvcount = args->expectedBytes / wordSize(type);
int nranks = args->nProcs*args->nThreads*args->nGpus;
for (int i=0; i<args->nGpus; i++) {
char* str = getenv("NCCL_TESTS_DEVICE");
int gpuid = str ? atoi(str) : args->localRank*args->nThreads*args->nGpus + args->thread*args->nGpus + i;
CUDACHECK(cudaSetDevice(gpuid));
int rank = ((args->proc*args->nThreads + args->thread)*args->nGpus + i);
CUDACHECK(cudaMemset(args->recvbuffs[i], 0, args->expectedBytes));
void* data = in_place ? args->recvbuffs[i] : args->sendbuffs[i];
TESTCHECK(InitData(data, sendcount, type, rep, rank));
for (int j=0; j<nranks; j++) {
TESTCHECK(InitData(((char*)args->expected[i])+args->sendBytes/nranks*j, sendcount/nranks, type, rep+rank*sendcount/nranks, j));
}
CUDACHECK(cudaDeviceSynchronize());
}
// We don't support in-place alltoall
args->reportErrors = in_place ? 0 : 1;
return testSuccess;
}
void AlltoAllGetBw(size_t count, int typesize, double sec, double* algBw, double* busBw, int nranks) {
double baseBw = (double)(count * nranks * typesize) / 1.0E9 / sec;
*algBw = baseBw;
double factor = ((double)(nranks-1))/((double)(nranks));
*busBw = baseBw * factor;
}
testResult_t AlltoAllRunColl(void* sendbuff, void* recvbuff, size_t count, ncclDataType_t type, ncclRedOp_t op, int root, ncclComm_t comm, cudaStream_t stream) {
int nRanks;
NCCLCHECK(ncclCommCount(comm, &nRanks));
size_t rankOffset = count * wordSize(type);
if (count == 0) return testSuccess;
NCCLCHECK(ncclGroupStart());
for (int r=0; r<nRanks; r++) {
NCCLCHECK(ncclSend(((char*)sendbuff)+r*rankOffset, count, type, r, comm, stream));
NCCLCHECK(ncclRecv(((char*)recvbuff)+r*rankOffset, count, type, r, comm, stream));
}
NCCLCHECK(ncclGroupEnd());
return testSuccess;
}
struct testColl alltoAllTest = {
"AlltoAll",
AlltoAllGetCollByteCount,
AlltoAllInitData,
AlltoAllGetBw,
AlltoAllRunColl
};
void AlltoAllGetBuffSize(size_t *sendcount, size_t *recvcount, size_t count, int nranks) {
size_t paramcount, sendInplaceOffset, recvInplaceOffset;
AlltoAllGetCollByteCount(sendcount, recvcount, &paramcount, &sendInplaceOffset, &recvInplaceOffset, count, nranks);
}
testResult_t AlltoAllRunTest(struct threadArgs* args, int root, ncclDataType_t type, const char* typeName, ncclRedOp_t op, const char* opName) {
args->collTest = &alltoAllTest;
ncclDataType_t *run_types;
const char **run_typenames;
int type_count;
if ((int)type != -1) {
type_count = 1;
run_types = &type;
run_typenames = &typeName;
} else {
type_count = ncclNumTypes;
run_types = test_types;
run_typenames = test_typenames;
}
for (int i=0; i<type_count; i++) {
TESTCHECK(TimeTest(args, run_types[i], run_typenames[i], (ncclRedOp_t)0, "", -1));
}
return testSuccess;
}
struct testEngine alltoAllEngine = {
AlltoAllGetBuffSize,
AlltoAllRunTest
};
#pragma weak ncclTestEngine=alltoAllEngine

View File

@ -308,7 +308,7 @@ testResult_t CheckData(struct threadArgs* args, ncclDataType_t type, ncclRedOp_t
#endif
}
double nranks = args->nProcs*args->nThreads*args->nGpus;
if (maxDelta > DeltaMaxValue(type)*(nranks - 1)) args->errors[0]++;
if (args->reportErrors && maxDelta > DeltaMaxValue(type)*(nranks - 1)) args->errors[0]++;
*delta = maxDelta;
return testSuccess;
}
@ -834,6 +834,8 @@ testResult_t run() {
threads[t].args.bw=bw+t;
threads[t].args.bw_count=bw_count+t;
threads[t].args.reportErrors = 1;
threads[t].func = parallel_init ? threadInit : threadRunTests;
if (t)
TESTCHECK(threadLaunch(threads+t));

View File

@ -17,25 +17,25 @@
#include "nccl1_compat.h"
#define CUDACHECK(cmd) do { \
cudaError_t e = cmd; \
if( e != cudaSuccess ) { \
cudaError_t err = cmd; \
if( err != cudaSuccess ) { \
char hostname[1024]; \
getHostName(hostname, 1024); \
printf("%s: Test CUDA failure %s:%d '%s'\n", \
hostname, \
__FILE__,__LINE__,cudaGetErrorString(e)); \
__FILE__,__LINE__,cudaGetErrorString(err)); \
return testCudaError; \
} \
} while(0)
#define NCCLCHECK(cmd) do { \
ncclResult_t r = cmd; \
if (r!= ncclSuccess) { \
ncclResult_t res = cmd; \
if (res != ncclSuccess) { \
char hostname[1024]; \
getHostName(hostname, 1024); \
printf("%s: Test NCCL failure %s:%d '%s'\n", \
hostname, \
__FILE__,__LINE__,ncclGetErrorString(r)); \
__FILE__,__LINE__,ncclGetErrorString(res)); \
return testNcclError; \
} \
} while(0)
@ -124,6 +124,8 @@ struct threadArgs {
double* bw;
int* bw_count;
int reportErrors;
struct testColl* collTest;
};