mirror of
https://github.com/NVIDIA/nccl-tests.git
synced 2026-01-14 02:47:21 +08:00
Update NCCL tests
This commit is contained in:
parent
749573f2d6
commit
d313d20a26
@ -1,5 +1,5 @@
|
||||
#
|
||||
# Copyright (c) 2015-2021, NVIDIA CORPORATION. All rights reserved.
|
||||
# Copyright (c) 2015-2022, NVIDIA CORPORATION. All rights reserved.
|
||||
#
|
||||
# See LICENSE.txt for license information
|
||||
#
|
||||
@ -92,7 +92,12 @@ ${DST_DIR}/%.o: %.cu common.h $(TEST_VERIFIABLE_HDRS)
|
||||
@mkdir -p ${DST_DIR}
|
||||
$(NVCC) -o $@ $(NVCUFLAGS) -c $<
|
||||
|
||||
${DST_DIR}/%_perf:${DST_DIR}/%.o ${DST_DIR}/common.o $(TEST_VERIFIABLE_OBJS)
|
||||
${DST_DIR}/timer.o: timer.cc timer.h
|
||||
@printf "Compiling %-35s > %s\n" $< $@
|
||||
@mkdir -p ${DST_DIR}
|
||||
$(CXX) $(CXXFLAGS) -o $@ -c timer.cc
|
||||
|
||||
${DST_DIR}/%_perf:${DST_DIR}/%.o ${DST_DIR}/common.o ${DST_DIR}/timer.o $(TEST_VERIFIABLE_OBJS)
|
||||
@printf "Linking %-35s > %s\n" $< $@
|
||||
@mkdir -p ${DST_DIR}
|
||||
$(NVCC) -o $@ $(NVCUFLAGS) $^ ${NVLDFLAGS}
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
/*************************************************************************
|
||||
* Copyright (c) 2016-2019, NVIDIA CORPORATION. All rights reserved.
|
||||
* Copyright (c) 2016-2022, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* See LICENSE.txt for license information
|
||||
************************************************************************/
|
||||
@ -7,12 +7,15 @@
|
||||
#include "cuda_runtime.h"
|
||||
#include "common.h"
|
||||
|
||||
#define ALIGN 4
|
||||
|
||||
void AllGatherGetCollByteCount(size_t *sendcount, size_t *recvcount, size_t *paramcount, size_t *sendInplaceOffset, size_t *recvInplaceOffset, size_t count, int nranks) {
|
||||
*sendcount = count/nranks;
|
||||
*recvcount = (count/nranks)*nranks;
|
||||
*sendInplaceOffset = count/nranks;
|
||||
size_t base = (count/(ALIGN*nranks))*ALIGN;
|
||||
*sendcount = base;
|
||||
*recvcount = base*nranks;
|
||||
*sendInplaceOffset = base;
|
||||
*recvInplaceOffset = 0;
|
||||
*paramcount = *sendcount;
|
||||
*paramcount = base;
|
||||
}
|
||||
|
||||
testResult_t AllGatherInitData(struct threadArgs* args, ncclDataType_t type, ncclRedOp_t op, int root, int rep, int in_place) {
|
||||
@ -21,8 +24,7 @@ testResult_t AllGatherInitData(struct threadArgs* args, ncclDataType_t type, ncc
|
||||
int nranks = args->nProcs*args->nThreads*args->nGpus;
|
||||
|
||||
for (int i=0; i<args->nGpus; i++) {
|
||||
int gpuid = args->localRank*args->nThreads*args->nGpus + args->thread*args->nGpus + i;
|
||||
CUDACHECK(cudaSetDevice(gpuid));
|
||||
CUDACHECK(cudaSetDevice(args->gpus[i]));
|
||||
int rank = ((args->proc*args->nThreads + args->thread)*args->nGpus + i);
|
||||
CUDACHECK(cudaMemset(args->recvbuffs[i], 0, args->expectedBytes));
|
||||
void* data = in_place ? ((char*)args->recvbuffs[i])+rank*args->sendBytes : args->sendbuffs[i];
|
||||
@ -78,7 +80,7 @@ testResult_t AllGatherRunTest(struct threadArgs* args, int root, ncclDataType_t
|
||||
}
|
||||
|
||||
for (int i=0; i<type_count; i++) {
|
||||
TESTCHECK(TimeTest(args, run_types[i], run_typenames[i], (ncclRedOp_t)0, "", -1));
|
||||
TESTCHECK(TimeTest(args, run_types[i], run_typenames[i], (ncclRedOp_t)0, "none", -1));
|
||||
}
|
||||
return testSuccess;
|
||||
}
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
/*************************************************************************
|
||||
* Copyright (c) 2016-2019, NVIDIA CORPORATION. All rights reserved.
|
||||
* Copyright (c) 2016-2022, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* See LICENSE.txt for license information
|
||||
************************************************************************/
|
||||
@ -21,8 +21,7 @@ testResult_t AllReduceInitData(struct threadArgs* args, ncclDataType_t type, ncc
|
||||
int nranks = args->nProcs*args->nThreads*args->nGpus;
|
||||
|
||||
for (int i=0; i<args->nGpus; i++) {
|
||||
int gpuid = args->localRank*args->nThreads*args->nGpus + args->thread*args->nGpus + i;
|
||||
CUDACHECK(cudaSetDevice(gpuid));
|
||||
CUDACHECK(cudaSetDevice(args->gpus[i]));
|
||||
int rank = ((args->proc*args->nThreads + args->thread)*args->nGpus + i);
|
||||
CUDACHECK(cudaMemset(args->recvbuffs[i], 0, args->expectedBytes));
|
||||
void* data = in_place ? args->recvbuffs[i] : args->sendbuffs[i];
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
/*************************************************************************
|
||||
* Copyright (c) 2016-2020, NVIDIA CORPORATION. All rights reserved.
|
||||
* Copyright (c) 2016-2022, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* See LICENSE.txt for license information
|
||||
************************************************************************/
|
||||
@ -21,9 +21,7 @@ testResult_t AlltoAllInitData(struct threadArgs* args, ncclDataType_t type, nccl
|
||||
int nranks = args->nProcs*args->nThreads*args->nGpus;
|
||||
|
||||
for (int i=0; i<args->nGpus; i++) {
|
||||
char* str = getenv("NCCL_TESTS_DEVICE");
|
||||
int gpuid = str ? atoi(str) : args->localRank*args->nThreads*args->nGpus + args->thread*args->nGpus + i;
|
||||
CUDACHECK(cudaSetDevice(gpuid));
|
||||
CUDACHECK(cudaSetDevice(args->gpus[i]));
|
||||
int rank = ((args->proc*args->nThreads + args->thread)*args->nGpus + i);
|
||||
CUDACHECK(cudaMemset(args->recvbuffs[i], 0, args->expectedBytes));
|
||||
void* data = in_place ? args->recvbuffs[i] : args->sendbuffs[i];
|
||||
@ -51,7 +49,6 @@ testResult_t AlltoAllRunColl(void* sendbuff, void* recvbuff, size_t count, ncclD
|
||||
int nRanks;
|
||||
NCCLCHECK(ncclCommCount(comm, &nRanks));
|
||||
size_t rankOffset = count * wordSize(type);
|
||||
if (count == 0) return testSuccess;
|
||||
|
||||
#if NCCL_MAJOR < 2 || NCCL_MINOR < 7
|
||||
printf("NCCL 2.7 or later is needed for alltoall. This test was compiled with %d.%d.\n", NCCL_MAJOR, NCCL_MINOR);
|
||||
@ -97,7 +94,7 @@ testResult_t AlltoAllRunTest(struct threadArgs* args, int root, ncclDataType_t t
|
||||
}
|
||||
|
||||
for (int i=0; i<type_count; i++) {
|
||||
TESTCHECK(TimeTest(args, run_types[i], run_typenames[i], (ncclRedOp_t)0, "", -1));
|
||||
TESTCHECK(TimeTest(args, run_types[i], run_typenames[i], (ncclRedOp_t)0, "none", -1));
|
||||
}
|
||||
return testSuccess;
|
||||
}
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
/*************************************************************************
|
||||
* Copyright (c) 2015-2016, NVIDIA CORPORATION. All rights reserved.
|
||||
* Copyright (c) 2015-2022, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* See LICENSE.txt for license information
|
||||
************************************************************************/
|
||||
@ -20,8 +20,7 @@ testResult_t BroadcastInitData(struct threadArgs* args, ncclDataType_t type, ncc
|
||||
size_t recvcount = args->expectedBytes / wordSize(type);
|
||||
|
||||
for (int i=0; i<args->nGpus; i++) {
|
||||
int gpuid = args->localRank*args->nThreads*args->nGpus + args->thread*args->nGpus + i;
|
||||
CUDACHECK(cudaSetDevice(gpuid));
|
||||
CUDACHECK(cudaSetDevice(args->gpus[i]));
|
||||
int rank = ((args->proc*args->nThreads + args->thread)*args->nGpus + i);
|
||||
CUDACHECK(cudaMemset(args->recvbuffs[i], 0, args->expectedBytes));
|
||||
void* data = in_place ? args->recvbuffs[i] : args->sendbuffs[i];
|
||||
@ -94,7 +93,7 @@ testResult_t BroadcastRunTest(struct threadArgs* args, int root, ncclDataType_t
|
||||
|
||||
for (int i=0; i<type_count; i++) {
|
||||
for (int j=begin_root; j<=end_root; j++) {
|
||||
TESTCHECK(TimeTest(args, run_types[i], run_typenames[i], (ncclRedOp_t)0, "", j));
|
||||
TESTCHECK(TimeTest(args, run_types[i], run_typenames[i], (ncclRedOp_t)0, "none", j));
|
||||
}
|
||||
}
|
||||
return testSuccess;
|
||||
|
||||
151
src/common.cu
151
src/common.cu
@ -1,5 +1,5 @@
|
||||
/*************************************************************************
|
||||
* Copyright (c) 2016-2019, NVIDIA CORPORATION. All rights reserved.
|
||||
* Copyright (c) 2016-2022, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* See LICENSE.txt for license information
|
||||
************************************************************************/
|
||||
@ -50,6 +50,12 @@ int test_ncclVersion = 0; // init'd with ncclGetVersion()
|
||||
int test_opnum = 4;
|
||||
#endif
|
||||
|
||||
// For libnccl's < 2.13
|
||||
extern "C" __attribute__((weak)) char const* ncclGetLastError(ncclComm_t comm) {
|
||||
return "";
|
||||
}
|
||||
|
||||
int is_main_proc = 0;
|
||||
thread_local int is_main_thread = 0;
|
||||
|
||||
// Command line parameter defaults
|
||||
@ -68,7 +74,10 @@ static int nccltype = ncclFloat;
|
||||
static int ncclroot = 0;
|
||||
static int parallel_init = 0;
|
||||
static int blocking_coll = 0;
|
||||
static int streamnull = 0;
|
||||
static int timeout = 0;
|
||||
static int cudaGraphLaunches = 0;
|
||||
static int report_cputime = 0;
|
||||
// Report average iteration time: (0=RANK0,1=AVG,2=MIN,3=MAX)
|
||||
static int average = 1;
|
||||
|
||||
@ -198,7 +207,7 @@ void Allreduce(struct threadArgs* args, T* value, int average) {
|
||||
}
|
||||
#endif
|
||||
|
||||
if(average == 1) accumulator[epoch] /= args->nProcs*args->nThreads;
|
||||
if(average == 1) accumulator[epoch] /= args->totalProcs*args->nThreads;
|
||||
counter[epoch] = 0;
|
||||
pthread_cond_broadcast(&cond[epoch]);
|
||||
}
|
||||
@ -220,10 +229,8 @@ testResult_t CheckData(struct threadArgs* args, ncclDataType_t type, ncclRedOp_t
|
||||
CUDACHECK(cudaHostAlloc((void**)&wrongPerGpu, args->nGpus*sizeof(int64_t), cudaHostAllocMapped));
|
||||
|
||||
for (int i=0; i<args->nGpus; i++) {
|
||||
int device;
|
||||
int rank = ((args->proc*args->nThreads + args->thread)*args->nGpus + i);
|
||||
NCCLCHECK(ncclCommCuDevice(args->comms[i], &device));
|
||||
CUDACHECK(cudaSetDevice(device));
|
||||
CUDACHECK(cudaSetDevice(args->gpus[i]));
|
||||
void *data = in_place ? ((void *)((uintptr_t)args->recvbuffs[i] + args->recvInplaceOffset*rank)) : args->recvbuffs[i];
|
||||
|
||||
TESTCHECK(CheckDelta(data, args->expected[i], count, 0, type, op, 0, nranks, wrongPerGpu+i));
|
||||
@ -266,6 +273,8 @@ testResult_t testStreamSynchronize(int ngpus, cudaStream_t* streams, ncclComm_t*
|
||||
int remaining = ngpus;
|
||||
int* done = (int*)malloc(sizeof(int)*ngpus);
|
||||
memset(done, 0, sizeof(int)*ngpus);
|
||||
timer tim;
|
||||
|
||||
while (remaining) {
|
||||
int idle = 1;
|
||||
for (int i=0; i<ngpus; i++) {
|
||||
@ -294,6 +303,19 @@ testResult_t testStreamSynchronize(int ngpus, cudaStream_t* streams, ncclComm_t*
|
||||
NCCLCHECK(ncclAsyncErr);
|
||||
}
|
||||
}
|
||||
double delta = tim.elapsed();
|
||||
if (delta > timeout && timeout > 0) {
|
||||
for (int i=0; i<ngpus; i++)
|
||||
NCCLCHECK(ncclCommAbort(comms[i]));
|
||||
char hostname[1024];
|
||||
getHostName(hostname, 1024);
|
||||
printf("%s: Test timeout (%ds) %s:%d\n",
|
||||
hostname,
|
||||
timeout,
|
||||
__FILE__,__LINE__);
|
||||
free(done);
|
||||
return testTimeout;
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
@ -315,9 +337,7 @@ testResult_t startColl(struct threadArgs* args, ncclDataType_t type, ncclRedOp_t
|
||||
if (args->nGpus > 1) NCCLCHECK(ncclGroupStart());
|
||||
for (int i = 0; i < args->nGpus; i++) {
|
||||
#ifndef NCCL_MAJOR
|
||||
int cudaDev;
|
||||
NCCLCHECK(ncclCommCuDevice(args->comms[i], &cudaDev));
|
||||
CUDACHECK(cudaSetDevice(cudaDev));
|
||||
CUDACHECK(cudaSetDevice(args->gpus[i]));
|
||||
#endif
|
||||
int rank = ((args->proc*args->nThreads + args->thread)*args->nGpus + i);
|
||||
char* recvBuff = ((char*)args->recvbuffs[i]) + shift;
|
||||
@ -411,7 +431,7 @@ testResult_t BenchTime(struct threadArgs* args, ncclDataType_t type, ncclRedOp_t
|
||||
#endif
|
||||
|
||||
// Performance Benchmark
|
||||
auto start = std::chrono::high_resolution_clock::now();
|
||||
timer tim;
|
||||
for (int iter = 0; iter < iters; iter++) {
|
||||
if (agg_iters>1) NCCLCHECK(ncclGroupStart());
|
||||
for (int aiter = 0; aiter < agg_iters; aiter++) {
|
||||
@ -432,7 +452,7 @@ testResult_t BenchTime(struct threadArgs* args, ncclDataType_t type, ncclRedOp_t
|
||||
}
|
||||
// Resync CPU, restart timing, launch cuda graph
|
||||
Barrier(args);
|
||||
start = std::chrono::high_resolution_clock::now();
|
||||
tim.reset();
|
||||
for (int l=0; l<cudaGraphLaunches; l++) {
|
||||
for (int i=0; i<args->nGpus; i++) {
|
||||
CUDACHECK(cudaGraphLaunch(graphExec[i], args->streams[i]));
|
||||
@ -441,10 +461,10 @@ testResult_t BenchTime(struct threadArgs* args, ncclDataType_t type, ncclRedOp_t
|
||||
}
|
||||
#endif
|
||||
|
||||
double cputimeSec = tim.elapsed()/(iters*agg_iters);
|
||||
TESTCHECK(completeColl(args));
|
||||
|
||||
auto delta = std::chrono::high_resolution_clock::now() - start;
|
||||
double deltaSec = std::chrono::duration_cast<std::chrono::duration<double>>(delta).count();
|
||||
double deltaSec = tim.elapsed();
|
||||
deltaSec = deltaSec/(iters*agg_iters);
|
||||
if (cudaGraphLaunches >= 1) deltaSec = deltaSec/cudaGraphLaunches;
|
||||
Allreduce(args, &deltaSec, average);
|
||||
@ -520,7 +540,7 @@ testResult_t BenchTime(struct threadArgs* args, ncclDataType_t type, ncclRedOp_t
|
||||
wrongElts = wrongElts1;
|
||||
}
|
||||
|
||||
double timeUsec = deltaSec*1.0E6;
|
||||
double timeUsec = (report_cputime ? cputimeSec : deltaSec)*1.0E6;
|
||||
char timeStr[100];
|
||||
if (timeUsec >= 10000.0) {
|
||||
sprintf(timeStr, "%7.0f", timeUsec);
|
||||
@ -555,6 +575,9 @@ void setupArgs(size_t size, ncclDataType_t type, struct threadArgs* args) {
|
||||
}
|
||||
|
||||
testResult_t TimeTest(struct threadArgs* args, ncclDataType_t type, const char* typeName, ncclRedOp_t op, const char* opName, int root) {
|
||||
// Sync to avoid first-call timeout
|
||||
Barrier(args);
|
||||
|
||||
// Warm-up for large size
|
||||
setupArgs(args->maxbytes, type, args);
|
||||
for (int iter = 0; iter < warmup_iters; iter++) {
|
||||
@ -586,8 +609,7 @@ testResult_t threadRunTests(struct threadArgs* args) {
|
||||
// Set device to the first of our GPUs. If we don't do that, some operations
|
||||
// will be done on the current GPU (by default : 0) and if the GPUs are in
|
||||
// exclusive mode those operations will fail.
|
||||
int gpuid = args->localRank*args->nThreads*args->nGpus + args->thread*args->nGpus;
|
||||
CUDACHECK(cudaSetDevice(gpuid));
|
||||
CUDACHECK(cudaSetDevice(args->gpus[0]));
|
||||
TESTCHECK(ncclTestEngine.runTest(args, ncclroot, (ncclDataType_t)nccltype, test_typenames[nccltype], (ncclRedOp_t)ncclop, test_opnames[ncclop]));
|
||||
return testSuccess;
|
||||
}
|
||||
@ -598,13 +620,12 @@ testResult_t threadInit(struct threadArgs* args) {
|
||||
int nranks = args->nProcs*args->nThreads*args->nGpus;
|
||||
|
||||
//set main thread again
|
||||
is_main_thread = (args->proc == 0 && args->thread == 0) ? 1 : 0;
|
||||
is_main_thread = (is_main_proc && args->thread == 0) ? 1 : 0;
|
||||
|
||||
NCCLCHECK(ncclGroupStart());
|
||||
for (int i=0; i<args->nGpus; i++) {
|
||||
int rank = args->proc*args->nThreads*args->nGpus + args->thread*args->nGpus + i;
|
||||
int gpuid = args->localRank*args->nThreads*args->nGpus + args->thread*args->nGpus + i;
|
||||
CUDACHECK(cudaSetDevice(gpuid));
|
||||
CUDACHECK(cudaSetDevice(args->gpus[i]));
|
||||
NCCLCHECK(ncclCommInitRank(args->comms+i, nranks, args->ncclId, rank));
|
||||
}
|
||||
NCCLCHECK(ncclGroupEnd());
|
||||
@ -679,7 +700,10 @@ int main(int argc, char* argv[]) {
|
||||
{"datatype", required_argument, 0, 'd'},
|
||||
{"root", required_argument, 0, 'r'},
|
||||
{"blocking", required_argument, 0, 'z'},
|
||||
{"stream_null", required_argument, 0, 'y'},
|
||||
{"timeout", required_argument, 0, 'T'},
|
||||
{"cudagraph", required_argument, 0, 'G'},
|
||||
{"report_cputime", required_argument, 0, 'C'},
|
||||
{"average", required_argument, 0, 'a'},
|
||||
{"help", no_argument, 0, 'h'},
|
||||
{}
|
||||
@ -687,7 +711,7 @@ int main(int argc, char* argv[]) {
|
||||
|
||||
while(1) {
|
||||
int c;
|
||||
c = getopt_long(argc, argv, "t:g:b:e:i:f:n:m:w:p:c:o:d:r:z:hG:a:", longopts, &longindex);
|
||||
c = getopt_long(argc, argv, "t:g:b:e:i:f:n:m:w:p:c:o:d:r:z:y:T:hG:C:a:", longopts, &longindex);
|
||||
|
||||
if (c == -1)
|
||||
break;
|
||||
@ -752,6 +776,12 @@ int main(int argc, char* argv[]) {
|
||||
case 'z':
|
||||
blocking_coll = strtol(optarg, NULL, 0);
|
||||
break;
|
||||
case 'y':
|
||||
streamnull = strtol(optarg, NULL, 0);
|
||||
break;
|
||||
case 'T':
|
||||
timeout = strtol(optarg, NULL, 0);
|
||||
break;
|
||||
case 'G':
|
||||
#if (NCCL_MAJOR > 2 || (NCCL_MAJOR >= 2 && NCCL_MINOR >= 9)) && CUDART_VERSION >= 11030
|
||||
cudaGraphLaunches = strtol(optarg, NULL, 0);
|
||||
@ -759,6 +789,9 @@ int main(int argc, char* argv[]) {
|
||||
printf("Option -G (CUDA graph) not supported before NCCL 2.9 + CUDA 11.3. Ignoring\n");
|
||||
#endif
|
||||
break;
|
||||
case 'C':
|
||||
report_cputime = strtol(optarg, NULL, 0);
|
||||
break;
|
||||
case 'a':
|
||||
average = (int)strtol(optarg, NULL, 0);
|
||||
break;
|
||||
@ -787,11 +820,14 @@ int main(int argc, char* argv[]) {
|
||||
"[-d,--datatype <nccltype/all>] \n\t"
|
||||
"[-r,--root <root>] \n\t"
|
||||
"[-z,--blocking <0/1>] \n\t"
|
||||
"[-y,--stream_null <0/1>] \n\t"
|
||||
"[-T,--timeout <time in seconds>] \n\t"
|
||||
"[-G,--cudagraph <num graph launches>] \n\t"
|
||||
"[-C,--report_cputime <0/1>] \n\t"
|
||||
"[-a,--average <0/1/2/3> report average iteration time <0=RANK0/1=AVG/2=MIN/3=MAX>] \n\t"
|
||||
"[-h,--help]\n",
|
||||
basename(argv[0]));
|
||||
return 0;
|
||||
basename(argv[0]));
|
||||
return 0;
|
||||
}
|
||||
}
|
||||
if (minBytes > maxBytes) {
|
||||
@ -808,23 +844,31 @@ int main(int argc, char* argv[]) {
|
||||
}
|
||||
|
||||
testResult_t run() {
|
||||
int nProcs = 1, proc = 0;
|
||||
int totalProcs = 1, proc = 0, ncclProcs = 1, ncclProc = 0, color = 0;
|
||||
int localRank = 0;
|
||||
char hostname[1024];
|
||||
getHostName(hostname, 1024);
|
||||
|
||||
#ifdef MPI_SUPPORT
|
||||
MPI_Comm_size(MPI_COMM_WORLD, &nProcs);
|
||||
MPI_Comm_size(MPI_COMM_WORLD, &totalProcs);
|
||||
MPI_Comm_rank(MPI_COMM_WORLD, &proc);
|
||||
uint64_t hostHashs[nProcs];
|
||||
uint64_t hostHashs[totalProcs];
|
||||
hostHashs[proc] = getHostHash(hostname);
|
||||
MPI_Allgather(MPI_IN_PLACE, 0, MPI_DATATYPE_NULL, hostHashs, sizeof(uint64_t), MPI_BYTE, MPI_COMM_WORLD);
|
||||
for (int p=0; p<nProcs; p++) {
|
||||
for (int p=0; p<totalProcs; p++) {
|
||||
if (p == proc) break;
|
||||
if (hostHashs[p] == hostHashs[proc]) localRank++;
|
||||
}
|
||||
|
||||
char* str = getenv("NCCL_TESTS_SPLIT_MASK");
|
||||
uint64_t mask = str ? strtoul(str, NULL, 16) : 0;
|
||||
MPI_Comm mpi_comm;
|
||||
color = proc & mask;
|
||||
MPI_Comm_split(MPI_COMM_WORLD, color, proc, &mpi_comm);
|
||||
MPI_Comm_size(mpi_comm, &ncclProcs);
|
||||
MPI_Comm_rank(mpi_comm, &ncclProc);
|
||||
#endif
|
||||
is_main_thread = (proc == 0) ? 1 : 0;
|
||||
is_main_thread = is_main_proc = (proc == 0) ? 1 : 0;
|
||||
|
||||
PRINT("# nThread %d nGpus %d minBytes %ld maxBytes %ld step: %ld(%s) warmup iters: %d iters: %d agg iters: %d validation: %d graph: %d\n",
|
||||
nThreads, nGpus, minBytes, maxBytes,
|
||||
@ -839,22 +883,24 @@ testResult_t run() {
|
||||
char line[MAX_LINE];
|
||||
int len = 0;
|
||||
size_t maxMem = ~0;
|
||||
char* envstr = getenv("NCCL_TESTS_DEVICE");
|
||||
int gpu0 = envstr ? atoi(envstr) : -1;
|
||||
for (int i=0; i<nThreads*nGpus; i++) {
|
||||
int cudaDev = localRank*nThreads*nGpus+i;
|
||||
int cudaDev = (gpu0 != -1 ? gpu0 : localRank*nThreads*nGpus) + i;
|
||||
int rank = proc*nThreads*nGpus+i;
|
||||
cudaDeviceProp prop;
|
||||
CUDACHECK(cudaGetDeviceProperties(&prop, cudaDev));
|
||||
len += snprintf(line+len, MAX_LINE-len, "# Rank %2d Pid %6d on %10s device %2d [0x%02x] %s\n",
|
||||
rank, getpid(), hostname, cudaDev, prop.pciBusID, prop.name);
|
||||
len += snprintf(line+len, MAX_LINE-len, "# Rank %2d Group %2d Pid %6d on %10s device %2d [0x%02x] %s\n",
|
||||
rank, color, getpid(), hostname, cudaDev, prop.pciBusID, prop.name);
|
||||
maxMem = std::min(maxMem, prop.totalGlobalMem);
|
||||
}
|
||||
|
||||
#if MPI_SUPPORT
|
||||
char *lines = (proc == 0) ? (char *)malloc(nProcs*MAX_LINE) : NULL;
|
||||
char *lines = (proc == 0) ? (char *)malloc(totalProcs*MAX_LINE) : NULL;
|
||||
// Gather all output in rank order to root (0)
|
||||
MPI_Gather(line, MAX_LINE, MPI_BYTE, lines, MAX_LINE, MPI_BYTE, 0, MPI_COMM_WORLD);
|
||||
if (proc == 0) {
|
||||
for (int p = 0; p < nProcs; p++)
|
||||
for (int p = 0; p < totalProcs; p++)
|
||||
PRINT("%s", lines+MAX_LINE*p);
|
||||
free(lines);
|
||||
}
|
||||
@ -871,39 +917,43 @@ testResult_t run() {
|
||||
}
|
||||
|
||||
ncclUniqueId ncclId;
|
||||
if (proc == 0) {
|
||||
if (ncclProc == 0) {
|
||||
NCCLCHECK(ncclGetUniqueId(&ncclId));
|
||||
}
|
||||
#ifdef MPI_SUPPORT
|
||||
MPI_Bcast(&ncclId, sizeof(ncclId), MPI_BYTE, 0, MPI_COMM_WORLD);
|
||||
MPI_Barrier(MPI_COMM_WORLD);
|
||||
MPI_Bcast(&ncclId, sizeof(ncclId), MPI_BYTE, 0, mpi_comm);
|
||||
#endif
|
||||
int gpus[nGpus*nThreads];
|
||||
cudaStream_t streams[nGpus*nThreads];
|
||||
void* sendbuffs[nGpus*nThreads];
|
||||
void* recvbuffs[nGpus*nThreads];
|
||||
void* expected[nGpus*nThreads];
|
||||
size_t sendBytes, recvBytes;
|
||||
|
||||
ncclTestEngine.getBuffSize(&sendBytes, &recvBytes, (size_t)maxBytes, (size_t)nProcs*nGpus*nThreads);
|
||||
ncclTestEngine.getBuffSize(&sendBytes, &recvBytes, (size_t)maxBytes, (size_t)ncclProcs*nGpus*nThreads);
|
||||
|
||||
envstr = getenv("NCCL_TESTS_DEVICE");
|
||||
gpu0 = envstr ? atoi(envstr) : -1;
|
||||
for (int i=0; i<nGpus*nThreads; i++) {
|
||||
CUDACHECK(cudaSetDevice(localRank*nThreads*nGpus+i));
|
||||
gpus[i] = (gpu0 != -1 ? gpu0 : localRank*nThreads*nGpus) + i;
|
||||
CUDACHECK(cudaSetDevice(gpus[i]));
|
||||
TESTCHECK(AllocateBuffs(sendbuffs+i, sendBytes, recvbuffs+i, recvBytes, expected+i, (size_t)maxBytes));
|
||||
CUDACHECK(cudaStreamCreateWithFlags(streams+i, cudaStreamNonBlocking));
|
||||
if (streamnull)
|
||||
streams[i] = NULL;
|
||||
else
|
||||
CUDACHECK(cudaStreamCreateWithFlags(streams+i, cudaStreamNonBlocking));
|
||||
}
|
||||
|
||||
//if parallel init is not selected, use main thread to initialize NCCL
|
||||
ncclComm_t* comms = (ncclComm_t*)malloc(sizeof(ncclComm_t)*nThreads*nGpus);
|
||||
if (!parallel_init) {
|
||||
if (nProcs == 1) {
|
||||
int gpuArray[nGpus*nThreads];
|
||||
for (int i=0; i<nGpus*nThreads; i++) gpuArray[i] = i;
|
||||
NCCLCHECK(ncclCommInitAll(comms, nGpus*nThreads, gpuArray));
|
||||
if (ncclProcs == 1) {
|
||||
NCCLCHECK(ncclCommInitAll(comms, nGpus*nThreads, gpus));
|
||||
} else {
|
||||
NCCLCHECK(ncclGroupStart());
|
||||
for (int i=0; i<nGpus*nThreads; i++) {
|
||||
CUDACHECK(cudaSetDevice(localRank*nThreads*nGpus+i));
|
||||
NCCLCHECK(ncclCommInitRank(comms+i, nProcs*nThreads*nGpus, ncclId, proc*nThreads*nGpus+i));
|
||||
CUDACHECK(cudaSetDevice(gpus[i]));
|
||||
NCCLCHECK(ncclCommInitRank(comms+i, ncclProcs*nThreads*nGpus, ncclId, ncclProc*nThreads*nGpus+i));
|
||||
}
|
||||
NCCLCHECK(ncclGroupEnd());
|
||||
}
|
||||
@ -919,10 +969,11 @@ testResult_t run() {
|
||||
errors[t] = bw_count[t] = 0;
|
||||
}
|
||||
|
||||
const char* timeStr = report_cputime ? "cputime" : "time";
|
||||
PRINT("#\n");
|
||||
PRINT("# %10s %12s %8s %6s %6s out-of-place in-place \n", "", "", "", "", "");
|
||||
PRINT("# %10s %12s %8s %6s %6s %7s %6s %6s %6s %7s %6s %6s %6s\n", "size", "count", "type", "redop", "root",
|
||||
"time", "algbw", "busbw", "#wrong", "time", "algbw", "busbw", "#wrong");
|
||||
timeStr, "algbw", "busbw", "#wrong", timeStr, "algbw", "busbw", "#wrong");
|
||||
PRINT("# %10s %12s %8s %6s %6s %7s %6s %6s %5s %7s %6s %6s %5s\n", "(B)", "(elements)", "", "", "",
|
||||
"(us)", "(GB/s)", "(GB/s)", "", "(us)", "(GB/s)", "(GB/s)", "");
|
||||
|
||||
@ -936,11 +987,13 @@ testResult_t run() {
|
||||
threads[t].args.stepfactor=stepFactor;
|
||||
threads[t].args.localRank = localRank;
|
||||
|
||||
threads[t].args.nProcs=nProcs;
|
||||
threads[t].args.proc=proc;
|
||||
threads[t].args.totalProcs=totalProcs;
|
||||
threads[t].args.nProcs=ncclProcs;
|
||||
threads[t].args.proc=ncclProc;
|
||||
threads[t].args.nThreads=nThreads;
|
||||
threads[t].args.thread=t;
|
||||
threads[t].args.nGpus=nGpus;
|
||||
threads[t].args.gpus=gpus+t*nGpus;
|
||||
threads[t].args.sendbuffs = sendbuffs+t*nGpus;
|
||||
threads[t].args.recvbuffs = recvbuffs+t*nGpus;
|
||||
threads[t].args.expected = expected+t*nGpus;
|
||||
@ -990,8 +1043,8 @@ testResult_t run() {
|
||||
}
|
||||
CUDACHECK(cudaFreeHost(delta));
|
||||
|
||||
char* str = getenv("NCCL_TESTS_MIN_BW");
|
||||
double check_avg_bw = str ? atof(str) : -1;
|
||||
envstr = getenv("NCCL_TESTS_MIN_BW");
|
||||
double check_avg_bw = envstr ? atof(envstr) : -1;
|
||||
bw[0] /= bw_count[0];
|
||||
|
||||
PRINT("# Out of bounds values : %d %s\n", errors[0], errors[0] ? "FAILED" : "OK");
|
||||
@ -1001,6 +1054,8 @@ testResult_t run() {
|
||||
MPI_Finalize();
|
||||
#endif
|
||||
|
||||
PRINT("%s\n", ncclGetLastError(NULL));
|
||||
|
||||
// 'cuda-memcheck --leak-check full' requires this
|
||||
cudaDeviceReset();
|
||||
|
||||
|
||||
56
src/common.h
56
src/common.h
@ -1,5 +1,5 @@
|
||||
/*************************************************************************
|
||||
* Copyright (c) 2016-2019, NVIDIA CORPORATION. All rights reserved.
|
||||
* Copyright (c) 2016-2022, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* See LICENSE.txt for license information
|
||||
************************************************************************/
|
||||
@ -15,6 +15,10 @@
|
||||
#endif
|
||||
#include <pthread.h>
|
||||
#include "nccl1_compat.h"
|
||||
#include "timer.h"
|
||||
|
||||
// For nccl.h < 2.13 since we define a weak fallback
|
||||
extern "C" char const* ncclGetLastError(ncclComm_t comm);
|
||||
|
||||
#define CUDACHECK(cmd) do { \
|
||||
cudaError_t err = cmd; \
|
||||
@ -61,6 +65,8 @@ typedef enum {
|
||||
testInternalError = 1,
|
||||
testCudaError = 2,
|
||||
testNcclError = 3,
|
||||
testTimeout = 4,
|
||||
testNumResults = 5
|
||||
} testResult_t;
|
||||
|
||||
// Relay errors up and trace
|
||||
@ -110,11 +116,13 @@ struct threadArgs {
|
||||
size_t stepbytes;
|
||||
size_t stepfactor;
|
||||
|
||||
int totalProcs;
|
||||
int nProcs;
|
||||
int proc;
|
||||
int nThreads;
|
||||
int thread;
|
||||
int nGpus;
|
||||
int* gpus;
|
||||
int localRank;
|
||||
void** sendbuffs;
|
||||
size_t sendBytes;
|
||||
@ -144,8 +152,6 @@ struct testThread {
|
||||
testResult_t ret;
|
||||
};
|
||||
|
||||
#include <chrono>
|
||||
|
||||
// Provided by common.cu
|
||||
extern void Barrier(struct threadArgs* args);
|
||||
extern testResult_t TimeTest(struct threadArgs* args, ncclDataType_t type, const char* typeName, ncclRedOp_t op, const char* opName, int root);
|
||||
@ -153,10 +159,6 @@ extern testResult_t InitDataReduce(void* data, const size_t count, const size_t
|
||||
extern testResult_t InitData(void* data, const size_t count, size_t offset, ncclDataType_t type, ncclRedOp_t op, const uint64_t seed, const int nranks, const int rank);
|
||||
extern void AllocateBuffs(void **sendbuff, void **recvbuff, void **expected, void **expectedHost, size_t nbytes, int nranks);
|
||||
|
||||
// Provided by each coll
|
||||
extern void print_line_header (size_t size, size_t count, const char *typeName, const char *opName, int root);
|
||||
extern void print_header();
|
||||
|
||||
#include <unistd.h>
|
||||
|
||||
static void getHostName(char* hostname, int maxlen) {
|
||||
@ -171,46 +173,15 @@ static void getHostName(char* hostname, int maxlen) {
|
||||
|
||||
#include <stdint.h>
|
||||
|
||||
static uint64_t getHash(const char* string, size_t n) {
|
||||
// Based on DJB2a, result = result * 33 ^ char
|
||||
static uint64_t getHostHash(const char* string) {
|
||||
// Based on DJB2, result = result * 33 + char
|
||||
uint64_t result = 5381;
|
||||
for (size_t c = 0; c < n; c++) {
|
||||
result = ((result << 5) + result) ^ string[c];
|
||||
for (int c = 0; string[c] != '\0'; c++){
|
||||
result = ((result << 5) + result) + string[c];
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
/* Generate a hash of the unique identifying string for this host
|
||||
* that will be unique for both bare-metal and container instances
|
||||
* Equivalent of a hash of;
|
||||
*
|
||||
* $(hostname)$(cat /proc/sys/kernel/random/boot_id)
|
||||
*
|
||||
*/
|
||||
#define HOSTID_FILE "/proc/sys/kernel/random/boot_id"
|
||||
static uint64_t getHostHash(const char* hostname) {
|
||||
char hostHash[1024];
|
||||
|
||||
// Fall back is the hostname if something fails
|
||||
(void) strncpy(hostHash, hostname, sizeof(hostHash));
|
||||
int offset = strlen(hostHash);
|
||||
|
||||
FILE *file = fopen(HOSTID_FILE, "r");
|
||||
if (file != NULL) {
|
||||
char *p;
|
||||
if (fscanf(file, "%ms", &p) == 1) {
|
||||
strncpy(hostHash+offset, p, sizeof(hostHash)-offset-1);
|
||||
free(p);
|
||||
}
|
||||
}
|
||||
fclose(file);
|
||||
|
||||
// Make sure the string is terminated
|
||||
hostHash[sizeof(hostHash)-1]='\0';
|
||||
|
||||
return getHash(hostHash, strlen(hostHash));
|
||||
}
|
||||
|
||||
static size_t wordSize(ncclDataType_t type) {
|
||||
switch(type) {
|
||||
case ncclChar:
|
||||
@ -277,6 +248,7 @@ static int ncclstringtoop (char *str) {
|
||||
return ncclSum;
|
||||
}
|
||||
|
||||
extern int is_main_proc;
|
||||
extern thread_local int is_main_thread;
|
||||
#define PRINT if (is_main_thread) printf
|
||||
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
/*************************************************************************
|
||||
* Copyright (c) 2016-2021, NVIDIA CORPORATION. All rights reserved.
|
||||
* Copyright (c) 2016-2022, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* See LICENSE.txt for license information
|
||||
************************************************************************/
|
||||
@ -21,8 +21,7 @@ testResult_t GatherInitData(struct threadArgs* args, ncclDataType_t type, ncclRe
|
||||
int nranks = args->nProcs*args->nThreads*args->nGpus;
|
||||
|
||||
for (int i=0; i<args->nGpus; i++) {
|
||||
int gpuid = args->localRank*args->nThreads*args->nGpus + args->thread*args->nGpus + i;
|
||||
CUDACHECK(cudaSetDevice(gpuid));
|
||||
CUDACHECK(cudaSetDevice(args->gpus[i]));
|
||||
int rank = ((args->proc*args->nThreads + args->thread)*args->nGpus + i);
|
||||
CUDACHECK(cudaMemset(args->recvbuffs[i], 0, args->expectedBytes));
|
||||
void* data = in_place ? ((char*)args->recvbuffs[i])+rank*args->sendBytes : args->sendbuffs[i];
|
||||
@ -103,7 +102,7 @@ testResult_t GatherRunTest(struct threadArgs* args, int root, ncclDataType_t typ
|
||||
|
||||
for (int i=0; i<type_count; i++) {
|
||||
for (int j=begin_root; j<=end_root; j++) {
|
||||
TESTCHECK(TimeTest(args, run_types[i], run_typenames[i], (ncclRedOp_t)0, "", j));
|
||||
TESTCHECK(TimeTest(args, run_types[i], run_typenames[i], (ncclRedOp_t)0, "none", j));
|
||||
}
|
||||
}
|
||||
return testSuccess;
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
/*************************************************************************
|
||||
* Copyright (c) 2016-2021, NVIDIA CORPORATION. All rights reserved.
|
||||
* Copyright (c) 2016-2022, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* See LICENSE.txt for license information
|
||||
************************************************************************/
|
||||
@ -24,8 +24,7 @@ testResult_t HyperCubeInitData(struct threadArgs* args, ncclDataType_t type, ncc
|
||||
int nranks = args->nProcs*args->nThreads*args->nGpus;
|
||||
|
||||
for (int i=0; i<args->nGpus; i++) {
|
||||
int gpuid = args->localRank*args->nThreads*args->nGpus + args->thread*args->nGpus + i;
|
||||
CUDACHECK(cudaSetDevice(gpuid));
|
||||
CUDACHECK(cudaSetDevice(args->gpus[i]));
|
||||
int rank = ((args->proc*args->nThreads + args->thread)*args->nGpus + i);
|
||||
CUDACHECK(cudaMemset(args->recvbuffs[i], 0, args->expectedBytes));
|
||||
void* data = in_place ? ((char*)args->recvbuffs[i])+rank*args->sendBytes : args->sendbuffs[i];
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
/*************************************************************************
|
||||
* Copyright (c) 2016-2019, NVIDIA CORPORATION. All rights reserved.
|
||||
* Copyright (c) 2016-2022, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* See LICENSE.txt for license information
|
||||
************************************************************************/
|
||||
@ -21,8 +21,7 @@ testResult_t ReduceInitData(struct threadArgs* args, ncclDataType_t type, ncclRe
|
||||
int nranks = args->nProcs*args->nThreads*args->nGpus;
|
||||
|
||||
for (int i=0; i<args->nGpus; i++) {
|
||||
int gpuid = args->localRank*args->nThreads*args->nGpus + args->thread*args->nGpus + i;
|
||||
CUDACHECK(cudaSetDevice(gpuid));
|
||||
CUDACHECK(cudaSetDevice(args->gpus[i]));
|
||||
int rank = ((args->proc*args->nThreads + args->thread)*args->nGpus + i);
|
||||
CUDACHECK(cudaMemset(args->recvbuffs[i], 0, args->expectedBytes));
|
||||
void* data = in_place ? args->recvbuffs[i] : args->sendbuffs[i];
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
/*************************************************************************
|
||||
* Copyright (c) 2016-2019, NVIDIA CORPORATION. All rights reserved.
|
||||
* Copyright (c) 2016-2022, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* See LICENSE.txt for license information
|
||||
************************************************************************/
|
||||
@ -7,12 +7,15 @@
|
||||
#include "cuda_runtime.h"
|
||||
#include "common.h"
|
||||
|
||||
#define ALIGN 4
|
||||
|
||||
void ReduceScatterGetCollByteCount(size_t *sendcount, size_t *recvcount, size_t *paramcount, size_t *sendInplaceOffset, size_t *recvInplaceOffset, size_t count, int nranks) {
|
||||
*sendcount = (count/nranks)*nranks;
|
||||
*recvcount = count/nranks;
|
||||
size_t base = (count/(ALIGN*nranks))*ALIGN;
|
||||
*sendcount = base*nranks;
|
||||
*recvcount = base;
|
||||
*sendInplaceOffset = 0;
|
||||
*recvInplaceOffset = count/nranks;
|
||||
*paramcount = *recvcount;
|
||||
*recvInplaceOffset = base;
|
||||
*paramcount = base;
|
||||
}
|
||||
|
||||
testResult_t ReduceScatterInitData(struct threadArgs* args, ncclDataType_t type, ncclRedOp_t op, int root, int rep, int in_place) {
|
||||
@ -21,8 +24,7 @@ testResult_t ReduceScatterInitData(struct threadArgs* args, ncclDataType_t type,
|
||||
int nranks = args->nProcs*args->nThreads*args->nGpus;
|
||||
|
||||
for (int i=0; i<args->nGpus; i++) {
|
||||
int gpuid = args->localRank*args->nThreads*args->nGpus + args->thread*args->nGpus + i;
|
||||
CUDACHECK(cudaSetDevice(gpuid));
|
||||
CUDACHECK(cudaSetDevice(args->gpus[i]));
|
||||
int rank = ((args->proc*args->nThreads + args->thread)*args->nGpus + i);
|
||||
CUDACHECK(cudaMemset(args->recvbuffs[i], 0, args->expectedBytes));
|
||||
void* data = in_place ? args->recvbuffs[i] : args->sendbuffs[i];
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
/*************************************************************************
|
||||
* Copyright (c) 2016-2021, NVIDIA CORPORATION. All rights reserved.
|
||||
* Copyright (c) 2016-2022, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* See LICENSE.txt for license information
|
||||
************************************************************************/
|
||||
@ -20,8 +20,7 @@ testResult_t ScatterInitData(struct threadArgs* args, ncclDataType_t type, ncclR
|
||||
size_t recvcount = args->expectedBytes / wordSize(type);
|
||||
|
||||
for (int i=0; i<args->nGpus; i++) {
|
||||
int gpuid = args->localRank*args->nThreads*args->nGpus + args->thread*args->nGpus + i;
|
||||
CUDACHECK(cudaSetDevice(gpuid));
|
||||
CUDACHECK(cudaSetDevice(args->gpus[i]));
|
||||
int rank = ((args->proc*args->nThreads + args->thread)*args->nGpus + i);
|
||||
CUDACHECK(cudaMemset(args->recvbuffs[i], 0, args->expectedBytes));
|
||||
void* data = in_place ? args->recvbuffs[i] : args->sendbuffs[i];
|
||||
@ -99,7 +98,7 @@ testResult_t ScatterRunTest(struct threadArgs* args, int root, ncclDataType_t ty
|
||||
|
||||
for (int i=0; i<type_count; i++) {
|
||||
for (int j=begin_root; j<=end_root; j++) {
|
||||
TESTCHECK(TimeTest(args, run_types[i], run_typenames[i], (ncclRedOp_t)0, "", j));
|
||||
TESTCHECK(TimeTest(args, run_types[i], run_typenames[i], (ncclRedOp_t)0, "none", j));
|
||||
}
|
||||
}
|
||||
return testSuccess;
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
/*************************************************************************
|
||||
* Copyright (c) 2016-2021, NVIDIA CORPORATION. All rights reserved.
|
||||
* Copyright (c) 2016-2022, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* See LICENSE.txt for license information
|
||||
************************************************************************/
|
||||
@ -21,8 +21,7 @@ testResult_t SendRecvInitData(struct threadArgs* args, ncclDataType_t type, nccl
|
||||
int nranks = args->nProcs*args->nThreads*args->nGpus;
|
||||
|
||||
for (int i=0; i<args->nGpus; i++) {
|
||||
int gpuid = args->localRank*args->nThreads*args->nGpus + args->thread*args->nGpus + i;
|
||||
CUDACHECK(cudaSetDevice(gpuid));
|
||||
CUDACHECK(cudaSetDevice(args->gpus[i]));
|
||||
int rank = ((args->proc*args->nThreads + args->thread)*args->nGpus + i);
|
||||
CUDACHECK(cudaMemset(args->recvbuffs[i], 0, args->expectedBytes));
|
||||
void* data = in_place ? args->recvbuffs[i] : args->sendbuffs[i];
|
||||
|
||||
28
src/timer.cc
Normal file
28
src/timer.cc
Normal file
@ -0,0 +1,28 @@
|
||||
#include "timer.h"
|
||||
|
||||
// Make sure to compile this translation unit with the host compiler and not
|
||||
// nvcc, lest you hit an internal compiler error (ICE) with GCC 10.3.0
|
||||
#include <chrono>
|
||||
|
||||
namespace {
|
||||
std::uint64_t now() {
|
||||
using clock = std::chrono::steady_clock;
|
||||
return std::chrono::duration_cast<std::chrono::nanoseconds>(clock::now().time_since_epoch()).count();
|
||||
}
|
||||
}
|
||||
|
||||
timer::timer() {
|
||||
t0 = now();
|
||||
}
|
||||
|
||||
double timer::elapsed() const {
|
||||
std::uint64_t t1 = now();
|
||||
return 1.e-9*(t1 - t0);
|
||||
}
|
||||
|
||||
double timer::reset() {
|
||||
std::uint64_t t1 = now();
|
||||
double ans = 1.e-9*(t1 - t0);
|
||||
t0 = t1;
|
||||
return ans;
|
||||
}
|
||||
15
src/timer.h
Normal file
15
src/timer.h
Normal file
@ -0,0 +1,15 @@
|
||||
#ifndef _408319ecdd5b47b28bf8f511c4fdf816
|
||||
#define _408319ecdd5b47b28bf8f511c4fdf816
|
||||
|
||||
#include <cstdint>
|
||||
|
||||
// Can't include <chrono> because of bug with gcc 10.3.0
|
||||
class timer {
|
||||
std::uint64_t t0;
|
||||
public:
|
||||
timer();
|
||||
double elapsed() const;
|
||||
double reset();
|
||||
};
|
||||
|
||||
#endif
|
||||
Loading…
Reference in New Issue
Block a user