mirror of
https://github.com/NVIDIA/nccl-tests.git
synced 2026-01-14 02:47:21 +08:00
add runtime guards for ncclAlltoAll()
This commit is contained in:
parent
3744121a2d
commit
f66d20e360
@ -11,6 +11,8 @@
|
||||
#include "vector_types.h"
|
||||
#endif
|
||||
|
||||
#pragma weak ncclAlltoAll
|
||||
|
||||
void AlltoAllGetCollByteCount(size_t *sendcount, size_t *recvcount, size_t *paramcount, size_t *sendInplaceOffset, size_t *recvInplaceOffset, size_t count, size_t eltSize, int nranks) {
|
||||
*paramcount = (count/nranks) & -(16/eltSize);
|
||||
*sendcount = nranks*(*paramcount);
|
||||
@ -264,8 +266,13 @@ testResult_t AlltoAllRunColl(void* sendbuff, size_t sendoffset, void* recvbuff,
|
||||
char* sptr = (char*)sendbuff + sendoffset;
|
||||
char* rptr = (char*)recvbuff + recvoffset;
|
||||
#if NCCL_VERSION_CODE >= NCCL_VERSION(2,28,0)
|
||||
NCCLCHECK(ncclAlltoAll(sptr, rptr, count, type, comm, stream));
|
||||
#elif NCCL_VERSION_CODE >= NCCL_VERSION(2,7,0)
|
||||
if (test_ncclVersion >= NCCL_VERSION(2,28,0)) {
|
||||
NCCLCHECK(ncclAlltoAll(sptr, rptr, count, type, comm, stream));
|
||||
return testSuccess;
|
||||
}
|
||||
// fall-through to send/recv implementation if ncclAlltoAll is not available
|
||||
#endif
|
||||
#if NCCL_VERSION_CODE >= NCCL_VERSION(2,7,0)
|
||||
int nRanks;
|
||||
NCCLCHECK(ncclCommCount(comm, &nRanks));
|
||||
size_t rankOffset = count * wordSize(type);
|
||||
|
||||
Loading…
Reference in New Issue
Block a user