Compatibility with 2.29 device API: use NCCL_DEV_COMM_REQUIREMENTS_INTIIALIZER, query properties to check for device api support

This commit is contained in:
Katie Gioioso 2025-11-13 01:16:07 +00:00
parent 7106245178
commit 24874bdaa8
4 changed files with 89 additions and 16 deletions

View File

@ -64,27 +64,50 @@ void AllReduceGetBw(size_t count, int typesize, double sec, double* algBw, doubl
*busBw = baseBw * factor;
}
#if NCCL_VERSION_CODE >= NCCL_VERSION(2,28,0)
// set devComm reqs for allreduce device kernels
bool AllReduceGetDevCommRequirements(int deviceImpl, ncclDevCommRequirements* reqs) {
if (!reqs) return false;
memset(reqs, 0, sizeof(*reqs));
#if NCCL_VERSION_CODE >= NCCL_VERSION(2,29,0)
// set devComm reqs for allreduce device kernels
testResult_t AllReduceGetDevCommRequirements(int deviceImpl, ncclDevCommRequirements* reqs, ncclCommProperties_t* commProperties) {
if (!reqs || !commProperties) return testInternalError;
switch(deviceImpl) {
case 1: // allReduceLsaKernel
case 2: // allReduceLsaVectorizedKernel
reqs->lsaBarrierCount = deviceCtaCount;
return true;
return testSuccess;
case 3: // allReduceMultimemKernel
case 4: // allReduceMultimemVectorizedKernel
if (!commProperties->multimemSupport) {
fprintf(stderr, "This test requires multimem support, but multimem support is not enabled for this communicator.\n");
return testInternalError;
}
reqs->lsaMultimem = true;
reqs->lsaBarrierCount = deviceCtaCount;
return testSuccess;
default:
return testNotImplemented;
}
}
#elif NCCL_VERSION_CODE >= NCCL_VERSION(2,28,0)
bool AllReduceGetDevCommRequirements(int deviceImpl, ncclDevCommRequirements* reqs) {
if (!reqs) return false;
memset(reqs, 0, sizeof(*reqs));
switch(deviceImpl) {
case 1: // allReduceLsaKernel
case 2: // allReduceLsaVectorizedKernel
reqs->lsaBarrierCount = deviceCtaCount;
return true;
case 3: // allReduceMultimemKernel
case 4: // allReduceMultimemVectorizedKernelMultimem = true;
reqs->lsaMultimem = true;
reqs->lsaBarrierCount = deviceCtaCount;
return true;
default:
return false;
}
}
}
#endif
#if NCCL_VERSION_CODE >= NCCL_VERSION(2,28,0)
/*
* Kernel 1: allReduceLsaKernel - Basic LSA-based AllReduce
*

View File

@ -51,7 +51,30 @@ void AlltoAllGetBw(size_t count, int typesize, double sec, double* algBw, double
*busBw = baseBw * factor;
}
#if NCCL_VERSION_CODE >= NCCL_VERSION(2,28,0)
#if NCCL_VERSION_CODE >= NCCL_VERSION(2,29,0)
// set devComm reqs for alltoall device kernels
testResult_t AlltoAllGetDevCommRequirements(int deviceImpl, ncclDevCommRequirements* reqs, ncclCommProperties_t* commProperties) {
if (!reqs || !commProperties) return testInternalError;
switch(deviceImpl) {
case 1: // NvlAlltoAllKernel
case 2: // NvlAlltoAllKernelOptimized
reqs->lsaBarrierCount = deviceCtaCount;
return testSuccess;
case 3: // GinAlltoAllKernel
case 4: // HybridAlltoAllKernel (LSA+GIN)
if (commProperties->ginType == NCCL_GIN_TYPE_NONE) {
fprintf(stderr, "This test requires GIN support, but GIN support is not enabled for this communicator.\n");
return testInternalError;
}
reqs->barrierCount = deviceCtaCount;
reqs->ginSignalCount = deviceCtaCount;
return testSuccess;
default:
return testNotImplemented;
}
}
#elif NCCL_VERSION_CODE >= NCCL_VERSION(2,28,0)
// set devComm reqs for alltoall device kernels
bool AlltoAllGetDevCommRequirements(int deviceImpl, ncclDevCommRequirements* reqs) {
if (!reqs) return false;
@ -73,7 +96,9 @@ bool AlltoAllGetDevCommRequirements(int deviceImpl, ncclDevCommRequirements* req
return false;
}
}
#endif
#if NCCL_VERSION_CODE >= NCCL_VERSION(2,28,0)
// shared scalar AlltoAll implementation used by both kernels
template <typename T>
__device__ void AlltoAllScalarImpl(ncclWindow_t sendwin, size_t sendoffset, ncclWindow_t recvwin, size_t recvoffset, size_t count, int rank, int nRanks, int tid, int nthreads) {

View File

@ -813,12 +813,23 @@ testResult_t threadInit(struct threadArgs* args) {
#if NCCL_VERSION_CODE >= NCCL_VERSION(2,28,0)
/* Create device communicators based on test-specific requirements */
if (deviceImpl) {
ncclDevCommRequirements reqs;
#if NCCL_VERSION_CODE >= NCCL_VERSION(2,29,0)
ncclDevCommRequirements reqs = NCCL_DEV_COMM_REQUIREMENTS_INITIALIZER;
if (!ncclTestEngine.getDevCommRequirements) {
fprintf(stderr, "Device implementation %d is not supported by this test\n", deviceImpl);
return testNotImplemented;
}
ncclCommProperties commProperties = NCCL_COMM_PROPERTIES_INITIALIZER;
NCCLCHECK(ncclCommQueryProperties(args->comms[0], &commProperties));
TESTCHECK(ncclTestEngine.getDevCommRequirements(deviceImpl, &reqs, &commProperties));
#else
ncclDevCommRequirements reqs = {};
if (!ncclTestEngine.getDevCommRequirements ||
!ncclTestEngine.getDevCommRequirements(deviceImpl, &reqs)) {
fprintf(stderr, "Device implementation %d is not supported by this test\n", deviceImpl);
return testNotImplemented;
}
#endif
NCCLCHECK(ncclGroupStart());
for (int i = 0; i < args->nGpus; i++) {
@ -1384,12 +1395,23 @@ testResult_t run() {
#if NCCL_VERSION_CODE >= NCCL_VERSION(2,28,0)
/* Create device communicators based on test-specific requirements */
if (deviceImpl) {
ncclDevCommRequirements reqs;
if (!ncclTestEngine.getDevCommRequirements ||
!ncclTestEngine.getDevCommRequirements(deviceImpl, &reqs)) {
fprintf(stderr, "Device implementation %d is not supported by this test\n", deviceImpl);
return testNotImplemented;
}
#if NCCL_VERSION_CODE >= NCCL_VERSION(2,29,0)
ncclDevCommRequirements reqs = NCCL_DEV_COMM_REQUIREMENTS_INITIALIZER;
if (!ncclTestEngine.getDevCommRequirements) {
fprintf(stderr, "Device implementation %d is not supported by this test\n", deviceImpl);
return testNotImplemented;
}
ncclCommProperties commProperties = NCCL_COMM_PROPERTIES_INITIALIZER;
NCCLCHECK(ncclCommQueryProperties(comms[0], &commProperties));
TESTCHECK(ncclTestEngine.getDevCommRequirements(deviceImpl, &reqs, &commProperties));
#else
ncclDevCommRequirements reqs = {};
if (!ncclTestEngine.getDevCommRequirements ||
!ncclTestEngine.getDevCommRequirements(deviceImpl, &reqs)) {
fprintf(stderr, "Device implementation %d is not supported by this test\n", deviceImpl);
return testNotImplemented;
}
#endif
NCCLCHECK(ncclGroupStart());
for (int i = 0; i < nGpus * nThreads; i++) {

View File

@ -111,7 +111,10 @@ struct testEngine {
void (*getBuffSize)(size_t *sendcount, size_t *recvcount, size_t count, int nranks);
testResult_t (*runTest)(struct threadArgs* args, int root, ncclDataType_t type,
const char* typeName, ncclRedOp_t op, const char* opName);
#if NCCL_VERSION_CODE >= NCCL_VERSION(2,28,0)
#if NCCL_VERSION_CODE >= NCCL_VERSION(2,29,0)
testResult_t (*getDevCommRequirements)(int deviceImpl, ncclDevCommRequirements* reqs, ncclCommProperties_t* commProperties);
#elif NCCL_VERSION_CODE >= NCCL_VERSION(2,28,0)
bool (*getDevCommRequirements)(int deviceImpl, ncclDevCommRequirements* reqs);
#endif
};