From 9b9de7da81c3fcbee62e5e3a5e621391673f3cd3 Mon Sep 17 00:00:00 2001 From: David Addison Date: Mon, 19 May 2025 18:20:22 -0700 Subject: [PATCH] Add support for Symmetric Memory Registration From NCCL 2.27.x we can now use the Symmetric Memory APIs (-R 2) --- src/common.cu | 58 +++++++++++++++++++++++++++++++++++++++++---------- 1 file changed, 47 insertions(+), 11 deletions(-) diff --git a/src/common.cu b/src/common.cu index f83cdf0..3987d89 100644 --- a/src/common.cu +++ b/src/common.cu @@ -90,6 +90,8 @@ static int report_cputime = 0; // Report average iteration time: (0=RANK0,1=AVG,2=MIN,3=MAX) static int average = 1; #if NCCL_VERSION_CODE >= NCCL_VERSION(2,19,0) +#define LOCAL_REGISTER 1 +#define SYMMETRIC_REGISTER 2 static int local_register = 0; #endif static int minCudaArch = 1<<30; @@ -660,8 +662,16 @@ testResult_t threadInit(struct threadArgs* args) { void **sendRegHandles = (local_register) ? (void **)malloc(sizeof(*sendRegHandles)*args->nGpus) : NULL; void **recvRegHandles = (local_register) ? 
(void **)malloc(sizeof(*recvRegHandles)*args->nGpus) : NULL; for (int i=0; i<args->nGpus; i++) { - if (local_register) NCCLCHECK(ncclCommRegister(args->comms[i], args->sendbuffs[i], args->maxbytes, &sendRegHandles[i])); - if (local_register) NCCLCHECK(ncclCommRegister(args->comms[i], args->recvbuffs[i], args->maxbytes, &recvRegHandles[i])); +#if NCCL_VERSION_CODE >= NCCL_VERSION(2,27,0) + if (test_ncclVersion >= NCCL_VERSION(2,27,0) && (local_register == SYMMETRIC_REGISTER)) { + NCCLCHECK(ncclCommWindowRegister(args->comms[i], args->sendbuffs[i], args->maxbytes, (ncclWindow_t*)&sendRegHandles[i], NCCL_WIN_COLL_SYMMETRIC)); + NCCLCHECK(ncclCommWindowRegister(args->comms[i], args->recvbuffs[i], args->maxbytes, (ncclWindow_t*)&recvRegHandles[i], NCCL_WIN_COLL_SYMMETRIC)); + } else +#endif + { + if (local_register) NCCLCHECK(ncclCommRegister(args->comms[i], args->sendbuffs[i], args->maxbytes, &sendRegHandles[i])); + if (local_register) NCCLCHECK(ncclCommRegister(args->comms[i], args->recvbuffs[i], args->maxbytes, &recvRegHandles[i])); + } } #endif @@ -669,8 +679,16 @@ testResult_t threadInit(struct threadArgs* args) { for (int i=0; i<args->nGpus; i++) { #if NCCL_VERSION_CODE >= NCCL_VERSION(2,19,0) - if (local_register) NCCLCHECK(ncclCommDeregister(args->comms[i], sendRegHandles[i])); - if (local_register) NCCLCHECK(ncclCommDeregister(args->comms[i], recvRegHandles[i])); +#if NCCL_VERSION_CODE >= NCCL_VERSION(2,27,0) + if (test_ncclVersion >= NCCL_VERSION(2,27,0) && (local_register == SYMMETRIC_REGISTER)) { + NCCLCHECK(ncclCommWindowDeregister(args->comms[i], (ncclWindow_t)sendRegHandles[i])); + NCCLCHECK(ncclCommWindowDeregister(args->comms[i], (ncclWindow_t)recvRegHandles[i])); + } else +#endif + { + if (local_register) NCCLCHECK(ncclCommDeregister(args->comms[i], sendRegHandles[i])); + if (local_register) NCCLCHECK(ncclCommDeregister(args->comms[i], recvRegHandles[i])); + } #endif NCCLCHECK(ncclCommDestroy(args->comms[i])); } @@ -859,8 +877,10 @@ int main(int argc, char* argv[]) 
{ break; case 'R': #if NCCL_VERSION_CODE >= NCCL_VERSION(2,19,0) - if ((int)strtol(optarg, NULL, 0)) { - local_register = 1; + local_register = (int)strtol(optarg, NULL, 0); + if (local_register == SYMMETRIC_REGISTER && test_ncclVersion < NCCL_VERSION(2,27,0)) { + printf("Option -R 2 (symmetric) is not supported before NCCL 2.27. Defaulting to local registration\n"); + local_register = LOCAL_REGISTER; } #else printf("Option -R (register) is not supported before NCCL 2.19. Ignoring\n"); @@ -897,7 +917,7 @@ int main(int argc, char* argv[]) { "[-G,--cudagraph ] \n\t" "[-C,--report_cputime <0/1>] \n\t" "[-a,--average <0/1/2/3> report average iteration time <0=RANK0/1=AVG/2=MIN/3=MAX>] \n\t" - "[-R,--local_register <1/0> enable local buffer registration on send/recv buffers (default: disable)] \n\t" + "[-R,--local_register <0/1/2> enable local (1) or symmetric (2) buffer registration on send/recv buffers (default: disable (0))] \n\t" "[-h,--help]\n", basename(argv[0])); return 0; @@ -1107,8 +1127,16 @@ testResult_t run() { sendRegHandles = (local_register) ? (void **)malloc(sizeof(*sendRegHandles)*nThreads*nGpus) : NULL; recvRegHandles = (local_register) ? 
(void **)malloc(sizeof(*recvRegHandles)*nThreads*nGpus) : NULL; for (int i=0; i<nGpus*nThreads; i++) { - if (local_register) NCCLCHECK(ncclCommRegister(comms[i], sendbuffs[i], maxBytes, &sendRegHandles[i])); - if (local_register) NCCLCHECK(ncclCommRegister(comms[i], recvbuffs[i], maxBytes, &recvRegHandles[i])); +#if NCCL_VERSION_CODE >= NCCL_VERSION(2,27,0) + if (test_ncclVersion >= NCCL_VERSION(2,27,0) && (local_register == SYMMETRIC_REGISTER)) { + NCCLCHECK(ncclCommWindowRegister(comms[i], sendbuffs[i], maxBytes, (ncclWindow_t*)&sendRegHandles[i], NCCL_WIN_COLL_SYMMETRIC)); + NCCLCHECK(ncclCommWindowRegister(comms[i], recvbuffs[i], maxBytes, (ncclWindow_t*)&recvRegHandles[i], NCCL_WIN_COLL_SYMMETRIC)); + } else +#endif + { + if (local_register) NCCLCHECK(ncclCommRegister(comms[i], sendbuffs[i], maxBytes, &sendRegHandles[i])); + if (local_register) NCCLCHECK(ncclCommRegister(comms[i], recvbuffs[i], maxBytes, &recvRegHandles[i])); + } } #endif } @@ -1188,8 +1216,16 @@ testResult_t run() { if (!parallel_init) { for(int i=0; i<nGpus*nThreads; ++i) { #if NCCL_VERSION_CODE >= NCCL_VERSION(2,19,0) - if (local_register) NCCLCHECK(ncclCommDeregister(comms[i], sendRegHandles[i])); - if (local_register) NCCLCHECK(ncclCommDeregister(comms[i], recvRegHandles[i])); +#if NCCL_VERSION_CODE >= NCCL_VERSION(2,27,0) + if (test_ncclVersion >= NCCL_VERSION(2,27,0) && (local_register == SYMMETRIC_REGISTER)) { + NCCLCHECK(ncclCommWindowDeregister(comms[i], (ncclWindow_t)sendRegHandles[i])); + NCCLCHECK(ncclCommWindowDeregister(comms[i], (ncclWindow_t)recvRegHandles[i])); + } else +#endif + { + if (local_register) NCCLCHECK(ncclCommDeregister(comms[i], sendRegHandles[i])); + if (local_register) NCCLCHECK(ncclCommDeregister(comms[i], recvRegHandles[i])); + } #endif NCCLCHECK(ncclCommDestroy(comms[i])); }