From 070d17528ce2eea4bbd89b23faafe16eafff7b9c Mon Sep 17 00:00:00 2001 From: Katie Gioioso Date: Fri, 21 Nov 2025 02:56:37 +0000 Subject: [PATCH] refactor comm init --- src/common.cu | 71 +++++++++++++++++++++------------------------------ 1 file changed, 29 insertions(+), 42 deletions(-) diff --git a/src/common.cu b/src/common.cu index 4da475f..800e0ff 100644 --- a/src/common.cu +++ b/src/common.cu @@ -184,6 +184,32 @@ static void outputFileFinalize(output_file_type_t output_file_type) { } } +testResult_t initComms(ncclComm_t* comms, int nComms, int firstRank, int nRanks, int* cudaDevs, ncclUniqueId& ncclId) { +#if NCCL_VERSION_CODE >= NCCL_VERSION(2,14,0) + ncclConfig_t config = NCCL_CONFIG_INITIALIZER; +#if NCCL_VERSION_CODE >= NCCL_VERSION(2,27,0) + if (ctaPolicy >= 0) + config.CTAPolicy = ctaPolicy; +#if NCCL_VERSION_CODE >= NCCL_VERSION(2,28,0) + config.nvlinkCentricSched = 1; +#endif +#endif +#endif + + NCCLCHECK(ncclGroupStart()); + for (int i=0; i= NCCL_VERSION(2,14,0) + NCCLCHECK(ncclCommInitRankConfig(comms+i, nRanks, ncclId, rank, &config)); +#else + NCCLCHECK(ncclCommInitRank(comms+i, nRanks, ncclId, rank)); +#endif + } + NCCLCHECK(ncclGroupEnd()); + return testSuccess; +} + static double parsesize(const char *value) { long long int units; double size; @@ -759,28 +785,8 @@ testResult_t threadInit(struct threadArgs* args) { getGPUMemoryInfo(nullptr, &initFreeGpuMem[g]); } -#if NCCL_VERSION_CODE >= NCCL_VERSION(2,14,0) - ncclConfig_t config = NCCL_CONFIG_INITIALIZER; -#if NCCL_VERSION_CODE >= NCCL_VERSION(2,27,0) - if (ctaPolicy >= 0) - config.CTAPolicy = ctaPolicy; -#if NCCL_VERSION_CODE >= NCCL_VERSION(2,28,0) - config.nvlinkCentricSched = 1; -#endif -#endif -#endif - - NCCLCHECK(ncclGroupStart()); - for (int i=0; inGpus; i++) { - int rank = args->proc*args->nThreads*args->nGpus + args->thread*args->nGpus + i; - CUDACHECK(cudaSetDevice(args->gpus[i])); -#if NCCL_VERSION_CODE >= NCCL_VERSION(2,14,0) - NCCLCHECK(ncclCommInitRankConfig(args->comms+i, nranks, args->ncclId, rank, &config)); -#else - NCCLCHECK(ncclCommInitRank(args->comms+i, nranks, args->ncclId, rank)); -#endif - } - NCCLCHECK(ncclGroupEnd()); + int firstRank = args->proc*args->nThreads*args->nGpus + args->thread*args->nGpus; + TESTCHECK(initComms(args->comms, args->nGpus, firstRank, nranks, args->gpus, args->ncclId)); // Capture the memory used by the GPUs after initializing the NCCL communicators for (int g = 0; g < args->nGpus; ++g) { @@ -1347,26 +1353,7 @@ testResult_t run() { getGPUMemoryInfo(nullptr, &initFreeGpuMem[g]); } //if parallel init is not selected, use main thread to initialize NCCL -#if NCCL_VERSION_CODE >= NCCL_VERSION(2,14,0) - ncclConfig_t config = NCCL_CONFIG_INITIALIZER; -#if NCCL_VERSION_CODE >= NCCL_VERSION(2,27,0) - if (ctaPolicy >= 0) - config.CTAPolicy = ctaPolicy; -#if NCCL_VERSION_CODE >= NCCL_VERSION(2,28,0) - config.nvlinkCentricSched = 1; -#endif -#endif -#endif - NCCLCHECK(ncclGroupStart()); - for (int i=0; i= NCCL_VERSION(2,14,0) - NCCLCHECK(ncclCommInitRankConfig(comms+i, ncclProcs*nThreads*nGpus, ncclId, ncclProc*nThreads*nGpus+i, &config)); -#else - NCCLCHECK(ncclCommInitRank(comms+i, ncclProcs*nThreads*nGpus, ncclId, ncclProc*nThreads*nGpus+i)); -#endif - } - NCCLCHECK(ncclGroupEnd()); + TESTCHECK(initComms(comms, nGpus*nThreads, ncclProc*nThreads*nGpus, ncclProcs*nThreads*nGpus, gpus, ncclId)); // Capture the memory used by the GPUs after initializing the NCCL communicators for (int g = 0; g < nGpus; ++g) {