From f66d20e360cf94cad955b97c1e908f75b332ca43 Mon Sep 17 00:00:00 2001 From: Shane Snyder Date: Mon, 27 Oct 2025 13:38:10 -0700 Subject: [PATCH] add runtime guards for ncclAlltoAll() --- src/alltoall.cu | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/src/alltoall.cu b/src/alltoall.cu index 6021171..54e6836 100644 --- a/src/alltoall.cu +++ b/src/alltoall.cu @@ -11,6 +11,8 @@ #include "vector_types.h" #endif +#pragma weak ncclAlltoAll + void AlltoAllGetCollByteCount(size_t *sendcount, size_t *recvcount, size_t *paramcount, size_t *sendInplaceOffset, size_t *recvInplaceOffset, size_t count, size_t eltSize, int nranks) { *paramcount = (count/nranks) & -(16/eltSize); *sendcount = nranks*(*paramcount); @@ -264,8 +266,13 @@ testResult_t AlltoAllRunColl(void* sendbuff, size_t sendoffset, void* recvbuff, char* sptr = (char*)sendbuff + sendoffset; char* rptr = (char*)recvbuff + recvoffset; #if NCCL_VERSION_CODE >= NCCL_VERSION(2,28,0) - NCCLCHECK(ncclAlltoAll(sptr, rptr, count, type, comm, stream)); -#elif NCCL_VERSION_CODE >= NCCL_VERSION(2,7,0) + if (test_ncclVersion >= NCCL_VERSION(2,28,0)) { + NCCLCHECK(ncclAlltoAll(sptr, rptr, count, type, comm, stream)); + return testSuccess; + } + // fall-through to send/recv implementation if ncclAlltoAll is not available +#endif +#if NCCL_VERSION_CODE >= NCCL_VERSION(2,7,0) int nRanks; NCCLCHECK(ncclCommCount(comm, &nRanks)); size_t rankOffset = count * wordSize(type);