add runtime guards for ncclAlltoAll()

This commit is contained in:
Shane Snyder 2025-10-27 13:38:10 -07:00 committed by David Addison
parent 3744121a2d
commit f66d20e360

View File

@@ -11,6 +11,8 @@
#include "vector_types.h"
#endif
#pragma weak ncclAlltoAll
void AlltoAllGetCollByteCount(size_t *sendcount, size_t *recvcount, size_t *paramcount, size_t *sendInplaceOffset, size_t *recvInplaceOffset, size_t count, size_t eltSize, int nranks) {
*paramcount = (count/nranks) & -(16/eltSize);
*sendcount = nranks*(*paramcount);
@@ -264,8 +266,13 @@ testResult_t AlltoAllRunColl(void* sendbuff, size_t sendoffset, void* recvbuff,
char* sptr = (char*)sendbuff + sendoffset;
char* rptr = (char*)recvbuff + recvoffset;
#if NCCL_VERSION_CODE >= NCCL_VERSION(2,28,0)
NCCLCHECK(ncclAlltoAll(sptr, rptr, count, type, comm, stream));
#elif NCCL_VERSION_CODE >= NCCL_VERSION(2,7,0)
if (test_ncclVersion >= NCCL_VERSION(2,28,0)) {
NCCLCHECK(ncclAlltoAll(sptr, rptr, count, type, comm, stream));
return testSuccess;
}
// fall through to the send/recv implementation if ncclAlltoAll is not available at runtime
#endif
#if NCCL_VERSION_CODE >= NCCL_VERSION(2,7,0)
int nRanks;
NCCLCHECK(ncclCommCount(comm, &nRanks));
size_t rankOffset = count * wordSize(type);