mirror of
https://github.com/NVIDIA/nccl-tests.git
synced 2026-01-14 02:47:21 +08:00
Compatibility with 2.29 device API: use NCCL_DEV_COMM_REQUIREMENTS_INITIALIZER, query properties to check for device API support
This commit is contained in:
parent
7106245178
commit
24874bdaa8
@ -64,27 +64,50 @@ void AllReduceGetBw(size_t count, int typesize, double sec, double* algBw, doubl
|
||||
*busBw = baseBw * factor;
|
||||
}
|
||||
|
||||
#if NCCL_VERSION_CODE >= NCCL_VERSION(2,28,0)
|
||||
// set devComm reqs for allreduce device kernels
|
||||
bool AllReduceGetDevCommRequirements(int deviceImpl, ncclDevCommRequirements* reqs) {
|
||||
if (!reqs) return false;
|
||||
memset(reqs, 0, sizeof(*reqs));
|
||||
#if NCCL_VERSION_CODE >= NCCL_VERSION(2,29,0)
|
||||
// set devComm reqs for allreduce device kernels
|
||||
testResult_t AllReduceGetDevCommRequirements(int deviceImpl, ncclDevCommRequirements* reqs, ncclCommProperties_t* commProperties) {
|
||||
if (!reqs || !commProperties) return testInternalError;
|
||||
|
||||
switch(deviceImpl) {
|
||||
case 1: // allReduceLsaKernel
|
||||
case 2: // allReduceLsaVectorizedKernel
|
||||
reqs->lsaBarrierCount = deviceCtaCount;
|
||||
return true;
|
||||
return testSuccess;
|
||||
case 3: // allReduceMultimemKernel
|
||||
case 4: // allReduceMultimemVectorizedKernel
|
||||
if (!commProperties->multimemSupport) {
|
||||
fprintf(stderr, "This test requires multimem support, but multimem support is not enabled for this communicator.\n");
|
||||
return testInternalError;
|
||||
}
|
||||
reqs->lsaMultimem = true;
|
||||
reqs->lsaBarrierCount = deviceCtaCount;
|
||||
return testSuccess;
|
||||
default:
|
||||
return testNotImplemented;
|
||||
}
|
||||
}
|
||||
#elif NCCL_VERSION_CODE >= NCCL_VERSION(2,28,0)
|
||||
bool AllReduceGetDevCommRequirements(int deviceImpl, ncclDevCommRequirements* reqs) {
|
||||
if (!reqs) return false;
|
||||
memset(reqs, 0, sizeof(*reqs));
|
||||
switch(deviceImpl) {
|
||||
case 1: // allReduceLsaKernel
|
||||
case 2: // allReduceLsaVectorizedKernel
|
||||
reqs->lsaBarrierCount = deviceCtaCount;
|
||||
return true;
|
||||
case 3: // allReduceMultimemKernel
|
||||
case 4: // allReduceMultimemVectorizedKernelMultimem = true;
|
||||
reqs->lsaMultimem = true;
|
||||
reqs->lsaBarrierCount = deviceCtaCount;
|
||||
return true;
|
||||
default:
|
||||
return false;
|
||||
}
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
#if NCCL_VERSION_CODE >= NCCL_VERSION(2,28,0)
|
||||
/*
|
||||
* Kernel 1: allReduceLsaKernel - Basic LSA-based AllReduce
|
||||
*
|
||||
|
||||
@ -51,7 +51,30 @@ void AlltoAllGetBw(size_t count, int typesize, double sec, double* algBw, double
|
||||
*busBw = baseBw * factor;
|
||||
}
|
||||
|
||||
#if NCCL_VERSION_CODE >= NCCL_VERSION(2,28,0)
|
||||
#if NCCL_VERSION_CODE >= NCCL_VERSION(2,29,0)
|
||||
// set devComm reqs for alltoall device kernels
|
||||
testResult_t AlltoAllGetDevCommRequirements(int deviceImpl, ncclDevCommRequirements* reqs, ncclCommProperties_t* commProperties) {
|
||||
if (!reqs || !commProperties) return testInternalError;
|
||||
|
||||
switch(deviceImpl) {
|
||||
case 1: // NvlAlltoAllKernel
|
||||
case 2: // NvlAlltoAllKernelOptimized
|
||||
reqs->lsaBarrierCount = deviceCtaCount;
|
||||
return testSuccess;
|
||||
case 3: // GinAlltoAllKernel
|
||||
case 4: // HybridAlltoAllKernel (LSA+GIN)
|
||||
if (commProperties->ginType == NCCL_GIN_TYPE_NONE) {
|
||||
fprintf(stderr, "This test requires GIN support, but GIN support is not enabled for this communicator.\n");
|
||||
return testInternalError;
|
||||
}
|
||||
reqs->barrierCount = deviceCtaCount;
|
||||
reqs->ginSignalCount = deviceCtaCount;
|
||||
return testSuccess;
|
||||
default:
|
||||
return testNotImplemented;
|
||||
}
|
||||
}
|
||||
#elif NCCL_VERSION_CODE >= NCCL_VERSION(2,28,0)
|
||||
// set devComm reqs for alltoall device kernels
|
||||
bool AlltoAllGetDevCommRequirements(int deviceImpl, ncclDevCommRequirements* reqs) {
|
||||
if (!reqs) return false;
|
||||
@ -73,7 +96,9 @@ bool AlltoAllGetDevCommRequirements(int deviceImpl, ncclDevCommRequirements* req
|
||||
return false;
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
#if NCCL_VERSION_CODE >= NCCL_VERSION(2,28,0)
|
||||
// shared scalar AlltoAll implementation used by both kernels
|
||||
template <typename T>
|
||||
__device__ void AlltoAllScalarImpl(ncclWindow_t sendwin, size_t sendoffset, ncclWindow_t recvwin, size_t recvoffset, size_t count, int rank, int nRanks, int tid, int nthreads) {
|
||||
|
||||
@ -813,12 +813,23 @@ testResult_t threadInit(struct threadArgs* args) {
|
||||
#if NCCL_VERSION_CODE >= NCCL_VERSION(2,28,0)
|
||||
/* Create device communicators based on test-specific requirements */
|
||||
if (deviceImpl) {
|
||||
ncclDevCommRequirements reqs;
|
||||
#if NCCL_VERSION_CODE >= NCCL_VERSION(2,29,0)
|
||||
ncclDevCommRequirements reqs = NCCL_DEV_COMM_REQUIREMENTS_INITIALIZER;
|
||||
if (!ncclTestEngine.getDevCommRequirements) {
|
||||
fprintf(stderr, "Device implementation %d is not supported by this test\n", deviceImpl);
|
||||
return testNotImplemented;
|
||||
}
|
||||
ncclCommProperties commProperties = NCCL_COMM_PROPERTIES_INITIALIZER;
|
||||
NCCLCHECK(ncclCommQueryProperties(args->comms[0], &commProperties));
|
||||
TESTCHECK(ncclTestEngine.getDevCommRequirements(deviceImpl, &reqs, &commProperties));
|
||||
#else
|
||||
ncclDevCommRequirements reqs = {};
|
||||
if (!ncclTestEngine.getDevCommRequirements ||
|
||||
!ncclTestEngine.getDevCommRequirements(deviceImpl, &reqs)) {
|
||||
fprintf(stderr, "Device implementation %d is not supported by this test\n", deviceImpl);
|
||||
return testNotImplemented;
|
||||
}
|
||||
#endif
|
||||
|
||||
NCCLCHECK(ncclGroupStart());
|
||||
for (int i = 0; i < args->nGpus; i++) {
|
||||
@ -1384,12 +1395,23 @@ testResult_t run() {
|
||||
#if NCCL_VERSION_CODE >= NCCL_VERSION(2,28,0)
|
||||
/* Create device communicators based on test-specific requirements */
|
||||
if (deviceImpl) {
|
||||
ncclDevCommRequirements reqs;
|
||||
if (!ncclTestEngine.getDevCommRequirements ||
|
||||
!ncclTestEngine.getDevCommRequirements(deviceImpl, &reqs)) {
|
||||
fprintf(stderr, "Device implementation %d is not supported by this test\n", deviceImpl);
|
||||
return testNotImplemented;
|
||||
}
|
||||
#if NCCL_VERSION_CODE >= NCCL_VERSION(2,29,0)
|
||||
ncclDevCommRequirements reqs = NCCL_DEV_COMM_REQUIREMENTS_INITIALIZER;
|
||||
if (!ncclTestEngine.getDevCommRequirements) {
|
||||
fprintf(stderr, "Device implementation %d is not supported by this test\n", deviceImpl);
|
||||
return testNotImplemented;
|
||||
}
|
||||
ncclCommProperties commProperties = NCCL_COMM_PROPERTIES_INITIALIZER;
|
||||
NCCLCHECK(ncclCommQueryProperties(comms[0], &commProperties));
|
||||
TESTCHECK(ncclTestEngine.getDevCommRequirements(deviceImpl, &reqs, &commProperties));
|
||||
#else
|
||||
ncclDevCommRequirements reqs = {};
|
||||
if (!ncclTestEngine.getDevCommRequirements ||
|
||||
!ncclTestEngine.getDevCommRequirements(deviceImpl, &reqs)) {
|
||||
fprintf(stderr, "Device implementation %d is not supported by this test\n", deviceImpl);
|
||||
return testNotImplemented;
|
||||
}
|
||||
#endif
|
||||
|
||||
NCCLCHECK(ncclGroupStart());
|
||||
for (int i = 0; i < nGpus * nThreads; i++) {
|
||||
|
||||
@ -111,7 +111,10 @@ struct testEngine {
|
||||
void (*getBuffSize)(size_t *sendcount, size_t *recvcount, size_t count, int nranks);
|
||||
testResult_t (*runTest)(struct threadArgs* args, int root, ncclDataType_t type,
|
||||
const char* typeName, ncclRedOp_t op, const char* opName);
|
||||
#if NCCL_VERSION_CODE >= NCCL_VERSION(2,28,0)
|
||||
|
||||
#if NCCL_VERSION_CODE >= NCCL_VERSION(2,29,0)
|
||||
testResult_t (*getDevCommRequirements)(int deviceImpl, ncclDevCommRequirements* reqs, ncclCommProperties_t* commProperties);
|
||||
#elif NCCL_VERSION_CODE >= NCCL_VERSION(2,28,0)
|
||||
bool (*getDevCommRequirements)(int deviceImpl, ncclDevCommRequirements* reqs);
|
||||
#endif
|
||||
};
|
||||
|
||||
Loading…
Reference in New Issue
Block a user