refactor comm init

This commit is contained in:
Katie Gioioso 2025-11-21 02:56:37 +00:00
parent 332e61896f
commit 070d17528c

View File

@ -184,6 +184,32 @@ static void outputFileFinalize(output_file_type_t output_file_type) {
}
}
testResult_t initComms(ncclComm_t* comms, int nComms, int firstRank, int nRanks, int* cudaDevs, ncclUniqueId& ncclId) {
#if NCCL_VERSION_CODE >= NCCL_VERSION(2,14,0)
ncclConfig_t config = NCCL_CONFIG_INITIALIZER;
#if NCCL_VERSION_CODE >= NCCL_VERSION(2,27,0)
if (ctaPolicy >= 0)
config.CTAPolicy = ctaPolicy;
#if NCCL_VERSION_CODE >= NCCL_VERSION(2,28,0)
config.nvlinkCentricSched = 1;
#endif
#endif
#endif
NCCLCHECK(ncclGroupStart());
for (int i=0; i<nComms; i++) {
int rank = firstRank + i;
CUDACHECK(cudaSetDevice(cudaDevs[i]));
#if NCCL_VERSION_CODE >= NCCL_VERSION(2,14,0)
NCCLCHECK(ncclCommInitRankConfig(comms+i, nRanks, ncclId, rank, &config));
#else
NCCLCHECK(ncclCommInitRank(comms+i, nRanks, ncclId, rank));
#endif
}
NCCLCHECK(ncclGroupEnd());
return testSuccess;
}
static double parsesize(const char *value) {
long long int units;
double size;
@ -759,28 +785,8 @@ testResult_t threadInit(struct threadArgs* args) {
getGPUMemoryInfo(nullptr, &initFreeGpuMem[g]);
}
#if NCCL_VERSION_CODE >= NCCL_VERSION(2,14,0)
ncclConfig_t config = NCCL_CONFIG_INITIALIZER;
#if NCCL_VERSION_CODE >= NCCL_VERSION(2,27,0)
if (ctaPolicy >= 0)
config.CTAPolicy = ctaPolicy;
#if NCCL_VERSION_CODE >= NCCL_VERSION(2,28,0)
config.nvlinkCentricSched = 1;
#endif
#endif
#endif
NCCLCHECK(ncclGroupStart());
for (int i=0; i<args->nGpus; i++) {
int rank = args->proc*args->nThreads*args->nGpus + args->thread*args->nGpus + i;
CUDACHECK(cudaSetDevice(args->gpus[i]));
#if NCCL_VERSION_CODE >= NCCL_VERSION(2,14,0)
NCCLCHECK(ncclCommInitRankConfig(args->comms+i, nranks, args->ncclId, rank, &config));
#else
NCCLCHECK(ncclCommInitRank(args->comms+i, nranks, args->ncclId, rank));
#endif
}
NCCLCHECK(ncclGroupEnd());
int firstRank = args->proc*args->nThreads*args->nGpus + args->thread*args->nGpus;
TESTCHECK(initComms(args->comms, args->nGpus, firstRank, nranks, args->gpus, args->ncclId));
// Capture the memory used by the GPUs after initializing the NCCL communicators
for (int g = 0; g < args->nGpus; ++g) {
@ -1347,26 +1353,7 @@ testResult_t run() {
getGPUMemoryInfo(nullptr, &initFreeGpuMem[g]);
}
//if parallel init is not selected, use main thread to initialize NCCL
#if NCCL_VERSION_CODE >= NCCL_VERSION(2,14,0)
ncclConfig_t config = NCCL_CONFIG_INITIALIZER;
#if NCCL_VERSION_CODE >= NCCL_VERSION(2,27,0)
if (ctaPolicy >= 0)
config.CTAPolicy = ctaPolicy;
#if NCCL_VERSION_CODE >= NCCL_VERSION(2,28,0)
config.nvlinkCentricSched = 1;
#endif
#endif
#endif
NCCLCHECK(ncclGroupStart());
for (int i=0; i<nGpus*nThreads; i++) {
CUDACHECK(cudaSetDevice(gpus[i]));
#if NCCL_VERSION_CODE >= NCCL_VERSION(2,14,0)
NCCLCHECK(ncclCommInitRankConfig(comms+i, ncclProcs*nThreads*nGpus, ncclId, ncclProc*nThreads*nGpus+i, &config));
#else
NCCLCHECK(ncclCommInitRank(comms+i, ncclProcs*nThreads*nGpus, ncclId, ncclProc*nThreads*nGpus+i));
#endif
}
NCCLCHECK(ncclGroupEnd());
TESTCHECK(initComms(comms, nGpus*nThreads, ncclProc*nThreads*nGpus, ncclProcs*nThreads*nGpus, gpus, ncclId));
// Capture the memory used by the GPUs after initializing the NCCL communicators
for (int g = 0; g < nGpus; ++g) {