add GIN-based device API kernels to alltoall

- add GIN-only A2A kernel implementation
- add hybrid LSA+GIN A2A kernel implementation
- update perf test cases to expose a function for setting
  devCommRequirements for each device implementation and
  simplify devCommCreate code path to use this directly instead
  of complex fallback logic
- add missing call to devCommDestroy
This commit is contained in:
Shane Snyder 2025-10-06 08:49:00 -05:00 committed by David Addison
parent 00f52811b8
commit 9829ea42b5
12 changed files with 173 additions and 117 deletions

View File

@ -90,8 +90,8 @@ testResult_t AllGatherRunTest(struct threadArgs* args, int root, ncclDataType_t
}
struct testEngine allGatherEngine = {
AllGatherGetBuffSize,
AllGatherRunTest
.getBuffSize = AllGatherGetBuffSize,
.runTest = AllGatherRunTest
};
#pragma weak ncclTestEngine=allGatherEngine

View File

@ -65,6 +65,26 @@ void AllReduceGetBw(size_t count, int typesize, double sec, double* algBw, doubl
}
#if NCCL_VERSION_CODE >= NCCL_VERSION(2,28,0)
// set devComm reqs for allreduce device kernels
// Populate the devComm requirements needed by the selected allreduce device
// kernel. On success *reqs describes the resources to request from
// ncclDevCommCreate; for an unknown deviceImpl, *reqs is left zeroed and
// false is returned so the caller can report the implementation as unsupported.
bool AllReduceGetDevCommRequirements(int deviceImpl, ncclDevCommRequirements* reqs) {
  if (reqs == nullptr) return false;
  memset(reqs, 0, sizeof(*reqs));
  bool supported = false;
  if (deviceImpl == 1 || deviceImpl == 2) {
    // allReduceLsaKernel / allReduceLsaVectorizedKernel: LSA barriers only,
    // one per CTA launched.
    reqs->lsaBarrierCount = deviceCtaCount;
    supported = true;
  } else if (deviceImpl == 3 || deviceImpl == 4) {
    // allReduceMultimemKernel / allReduceMultimemVectorizedKernel: multimem
    // access plus the same per-CTA LSA barriers.
    reqs->lsaMultimem = true;
    reqs->lsaBarrierCount = deviceCtaCount;
    supported = true;
  }
  return supported;
}
/*
* Kernel 1: allReduceLsaKernel - Basic LSA-based AllReduce
*
@ -453,19 +473,19 @@ testResult_t AllReduceRunColl(void* sendbuff, size_t sendoffset, void* recvbuff,
#if NCCL_VERSION_CODE >= NCCL_VERSION(2,28,0)
case 1:
TESTCHECK(testLaunchDeviceKernel(SPECIALIZE_KERNEL(allReduceLsaKernel, type, op),
sendbuff, sendoffset, recvbuff, recvoffset, count, type, op, root, comm, stream, 0));
sendbuff, sendoffset, recvbuff, recvoffset, count, type, op, root, comm, stream));
return testSuccess;
case 2:
TESTCHECK(testLaunchDeviceKernel(SPECIALIZE_KERNEL(allReduceLsaVectorizedKernel, type, op),
sendbuff, sendoffset, recvbuff, recvoffset, count, type, op, root, comm, stream, 0));
sendbuff, sendoffset, recvbuff, recvoffset, count, type, op, root, comm, stream));
return testSuccess;
case 3:
TESTCHECK(testLaunchDeviceKernel(SPECIALIZE_KERNEL(allReduceMultimemKernel, type, op),
sendbuff, sendoffset, recvbuff, recvoffset, count, type, op, root, comm, stream, 1));
sendbuff, sendoffset, recvbuff, recvoffset, count, type, op, root, comm, stream));
return testSuccess;
case 4:
TESTCHECK(testLaunchDeviceKernel(SPECIALIZE_KERNEL(allReduceMultimemVectorizedKernel, type, op),
sendbuff, sendoffset, recvbuff, recvoffset, count, type, op, root, comm, stream, 1));
sendbuff, sendoffset, recvbuff, recvoffset, count, type, op, root, comm, stream));
return testSuccess;
#endif
}
@ -522,8 +542,11 @@ testResult_t AllReduceRunTest(struct threadArgs* args, int root, ncclDataType_t
}
struct testEngine allReduceEngine = {
AllReduceGetBuffSize,
AllReduceRunTest
.getBuffSize = AllReduceGetBuffSize,
.runTest = AllReduceRunTest,
#if NCCL_VERSION_CODE >= NCCL_VERSION(2,28,0)
.getDevCommRequirements = AllReduceGetDevCommRequirements
#endif
};
#pragma weak ncclTestEngine=allReduceEngine

View File

@ -50,6 +50,26 @@ void AlltoAllGetBw(size_t count, int typesize, double sec, double* algBw, double
}
#if NCCL_VERSION_CODE >= NCCL_VERSION(2,28,0)
// set devComm reqs for alltoall device kernels
// Fill in the devComm requirements for the chosen alltoall device kernel.
// Returns true and populates *reqs for a known deviceImpl; otherwise *reqs
// stays zeroed and false is returned so the caller can treat the
// implementation as unsupported.
bool AlltoAllGetDevCommRequirements(int deviceImpl, ncclDevCommRequirements* reqs) {
  if (reqs == nullptr) return false;
  memset(reqs, 0, sizeof(*reqs));
  if (deviceImpl == 1 || deviceImpl == 2) {
    // NvlAlltoAllKernel / NvlAlltoAllKernelOptimized: LSA barriers only,
    // one per CTA launched.
    reqs->lsaBarrierCount = deviceCtaCount;
    return true;
  }
  if (deviceImpl == 3 || deviceImpl == 4) {
    // GinAlltoAllKernel / HybridAlltoAllKernel: world-scope barriers plus
    // GIN signal slots, one of each per CTA.
    reqs->barrierCount = deviceCtaCount;
    reqs->ginSignalCount = deviceCtaCount;
    return true;
  }
  return false;
}
// shared scalar AlltoAll implementation used by both kernels
template <typename T>
__device__ void AlltoAllScalarImpl(ncclWindow_t sendwin, size_t sendoffset, ncclWindow_t recvwin, size_t recvoffset, size_t count, int rank, int nRanks, int tid, int nthreads) {
@ -159,6 +179,84 @@ __global__ void NvlAlltoAllKernelOptimized(ncclWindow_t sendwin, size_t sendoffs
bar.sync(ncclCoopCta(), cuda::memory_order_release);
}
// AlltoAll implemented entirely over GIN (GPU-initiated networking): each
// rank pushes its per-peer slice of sendwin directly into every peer's
// receive window with gin.put, then waits for nRanks signal increments to
// confirm all inbound slices have arrived. Note the put loop includes our
// own rank, so the local slice also travels through GIN and is counted in
// the awaited increments. `root` is unused (alltoall has no root).
template <typename T>
__global__ void GinAlltoAllKernel(ncclWindow_t sendwin, size_t sendoffset, ncclWindow_t recvwin, size_t recvoffset, size_t count, int root, struct ncclDevComm devComm) {
// Single GIN context and signal slot 0 — presumably backed by the
// ginSignalCount requested at devComm creation; verify against the
// requirements helper.
int ginContext = 0;
unsigned int signalIndex = 0;
ncclGin gin { devComm, ginContext };
// Snapshot the signal counter before posting puts so completion can be
// detected as a delta of exactly nRanks below.
uint64_t signalValue = gin.readSignal(signalIndex);
ncclBarrierSession<ncclCoopCta> bar { ncclCoopCta(), ncclTeamTagWorld(), gin, blockIdx.x };
// Entry barrier: ensure all ranks are ready before any put can land.
bar.sync(ncclCoopCta(), cuda::memory_order_relaxed, ncclGinFenceLevel::Relaxed);
int tid = threadIdx.x + blockIdx.x * blockDim.x;
int nthreads = blockDim.x * gridDim.x;
/* send to all peers via GIN */
const size_t size = count * sizeof(T);
// Grid-wide distribution of peers: thread tid handles ranks tid,
// tid+nthreads, ... Each put deposits our slice at the slot indexed by our
// world rank in the peer's receive window and bumps the peer's signal.
for (int r=tid; r<devComm.nRanks; r+=nthreads) {
gin.put(ncclTeamWorld(devComm), r,
recvwin, recvoffset + devComm.rank * size,
sendwin, sendoffset + r * size,
size, ncclGin_SignalInc{signalIndex});
}
// Wait until all nRanks peers (ourselves included) have signaled delivery,
// then drain our own outstanding puts.
gin.waitSignal(ncclCoopCta(), signalIndex, signalValue + devComm.nRanks);
gin.flush(ncclCoopCta());
// Exit barrier with release ordering so received data is globally visible
// before any rank proceeds.
bar.sync(ncclCoopCta(), cuda::memory_order_release, ncclGinFenceLevel::Relaxed);
}
// Hybrid AlltoAll: peers inside our LSA (NVLink-reachable) team are served
// with direct load/store copies through LSA pointers, while peers outside
// the LSA range are served over GIN puts. Only the remote puts are counted
// toward the awaited signal increments. `root` is unused (alltoall has no
// root).
template <typename T>
__global__ void HybridAlltoAllKernel(ncclWindow_t sendwin, size_t sendoffset, ncclWindow_t recvwin, size_t recvoffset, size_t count, int root, struct ncclDevComm devComm) {
// Single GIN context and signal slot 0, snapshot taken up front so remote
// completion is detected as a delta below.
int ginContext = 0;
unsigned int signalIndex = 0;
ncclGin gin { devComm, ginContext };
uint64_t signalValue = gin.readSignal(signalIndex);
ncclBarrierSession<ncclCoopCta> bar { ncclCoopCta(), ncclTeamTagWorld(), gin, blockIdx.x };
// Entry barrier: all ranks ready before any put or LSA store can land.
bar.sync(ncclCoopCta(), cuda::memory_order_relaxed, ncclGinFenceLevel::Relaxed);
int tid = threadIdx.x + blockIdx.x*blockDim.x;
int nthreads = blockDim.x * gridDim.x;
ncclTeam world = ncclTeamWorld(devComm);
ncclTeam lsa = ncclTeamLsa(devComm);
// World rank of the first LSA member. NOTE(review): this assumes the LSA
// team occupies a contiguous range [startLsa, startLsa+lsaSize) of world
// ranks — confirm that layout is guaranteed by the LSA team construction.
const int startLsa = world.rank - lsa.rank;
const int lsaSize = lsa.nRanks;
/* handle remote peers (i.e., non-LSA) using GIN */
const size_t size = count * sizeof(T);
// Remote peers below the LSA range...
for (int r = tid; r < startLsa; r += nthreads) {
gin.put(world, r,
recvwin, recvoffset + world.rank * size,
sendwin, sendoffset + r * size,
size, ncclGin_SignalInc{signalIndex});
}
// ...and remote peers above it. Each put lands our slice at the slot for
// our world rank and bumps the peer's signal.
for (int r = startLsa + lsaSize + tid; r < world.nRanks; r += nthreads) {
gin.put(world, r,
recvwin, recvoffset + world.rank * size,
sendwin, sendoffset + r * size,
size, ncclGin_SignalInc{signalIndex});
}
/* handle local peers with LSA */
// Direct element-wise copy: for every LSA peer lp, write our slice for
// world rank wr into lp's receive window at our world-rank slot. Grid
// threads stride over elements; each thread writes to all LSA peers.
T* sendLocal = (T*)ncclGetLocalPointer(sendwin, sendoffset);
for (size_t offset = tid; offset < count; offset += nthreads) {
for (int lp = 0; lp < lsa.nRanks; lp++) {
int wr = startLsa + lp;
T* recvPtr = (T*)ncclGetLsaPointer(recvwin, recvoffset, lp);
recvPtr[world.rank * count + offset] = sendLocal[wr * count + offset];
}
}
// Only non-LSA peers signaled us over GIN; wait for exactly that many
// increments, then drain our outstanding puts.
int numRemotePeers = world.nRanks - lsa.nRanks;
gin.waitSignal(ncclCoopCta(), signalIndex, signalValue + numRemotePeers);
gin.flush(ncclCoopCta());
// Exit barrier with release ordering publishes both the LSA stores and the
// GIN-delivered data before any rank proceeds.
bar.sync(ncclCoopCta(), cuda::memory_order_release, ncclGinFenceLevel::Relaxed);
}
#endif
testResult_t AlltoAllRunColl(void* sendbuff, size_t sendoffset, void* recvbuff, size_t recvoffset, size_t count, ncclDataType_t type, ncclRedOp_t op, int root, ncclComm_t comm, cudaStream_t stream, int deviceImpl) {
@ -184,10 +282,16 @@ testResult_t AlltoAllRunColl(void* sendbuff, size_t sendoffset, void* recvbuff,
} else {
switch(deviceImpl) {
case 1:
TESTCHECK(testLaunchDeviceKernel(SPECIALIZE_KERNEL(NvlAlltoAllKernel, type, op), sendbuff, sendoffset, recvbuff, recvoffset, count, type, op, root, comm, stream, 0));
TESTCHECK(testLaunchDeviceKernel(SPECIALIZE_KERNEL(NvlAlltoAllKernel, type, op), sendbuff, sendoffset, recvbuff, recvoffset, count, type, op, root, comm, stream));
return testSuccess;
case 2:
TESTCHECK(testLaunchDeviceKernel(SPECIALIZE_KERNEL(NvlAlltoAllKernelOptimized, type, op), sendbuff, sendoffset, recvbuff, recvoffset, count, type, op, root, comm, stream, 0));
TESTCHECK(testLaunchDeviceKernel(SPECIALIZE_KERNEL(NvlAlltoAllKernelOptimized, type, op), sendbuff, sendoffset, recvbuff, recvoffset, count, type, op, root, comm, stream));
return testSuccess;
case 3:
TESTCHECK(testLaunchDeviceKernel(SPECIALIZE_KERNEL(GinAlltoAllKernel, type, op), sendbuff, sendoffset, recvbuff, recvoffset, count, type, op, root, comm, stream));
return testSuccess;
case 4:
TESTCHECK(testLaunchDeviceKernel(SPECIALIZE_KERNEL(HybridAlltoAllKernel, type, op), sendbuff, sendoffset, recvbuff, recvoffset, count, type, op, root, comm, stream));
return testSuccess;
default:
return testNotImplemented;
@ -232,8 +336,11 @@ testResult_t AlltoAllRunTest(struct threadArgs* args, int root, ncclDataType_t t
}
struct testEngine alltoAllEngine = {
AlltoAllGetBuffSize,
AlltoAllRunTest
.getBuffSize = AlltoAllGetBuffSize,
.runTest = AlltoAllRunTest,
#if NCCL_VERSION_CODE >= NCCL_VERSION(2,28,0)
.getDevCommRequirements = AlltoAllGetDevCommRequirements
#endif
};
#pragma weak ncclTestEngine=alltoAllEngine

View File

@ -107,8 +107,8 @@ testResult_t BroadcastRunTest(struct threadArgs* args, int root, ncclDataType_t
}
struct testEngine broadcastEngine = {
BroadcastGetBuffSize,
BroadcastRunTest
.getBuffSize = BroadcastGetBuffSize,
.runTest = BroadcastRunTest
};
#pragma weak ncclTestEngine=broadcastEngine

View File

@ -100,7 +100,6 @@ static int report_cputime = 0;
static int deviceImpl = 0;
int deviceCtaCount = 16; // Default number of CTAs for device implementation
bool deviceMultimemEnabled = false; // Track whether multimem was successfully enabled
// Report average iteration time: (0=RANK0,1=AVG,2=MIN,3=MAX)
static int average = 1;
@ -768,49 +767,18 @@ testResult_t threadInit(struct threadArgs* args) {
NCCLCHECK(ncclGroupEnd());
#endif
#if NCCL_VERSION_CODE >= NCCL_VERSION(2,28,0)
/* Create device communicators with multimem fallback */
/* Create device communicators based on test-specific requirements */
if (deviceImpl) {
// Duplicate comms so our checks here do not affect the originals
ncclComm_t tmpComms[args->nGpus];
memset(tmpComms, 0, sizeof(tmpComms));
NCCLCHECK(ncclGroupStart());
for (int i = 0; i < args->nGpus; i++) {
int rank;
NCCLCHECK(ncclCommUserRank(args->comms[i], &rank));
NCCLCHECK(ncclCommSplit(args->comms[i], 0, rank, &tmpComms[i], NULL));
ncclDevCommRequirements reqs;
if (!ncclTestEngine.getDevCommRequirements ||
!ncclTestEngine.getDevCommRequirements(deviceImpl, &reqs)) {
fprintf(stderr, "Device implementation %d is not supported by this test\n", deviceImpl);
return testNotImplemented;
}
NCCLCHECK(ncclGroupEnd());
// Check multimem support on the duplicated comms
bool checkMultimemFailed = false;
ncclResult_t result;
ncclDevComm tmpDevComms[args->nGpus];
memset(tmpDevComms, 0, sizeof(tmpDevComms));
NCCLCHECK(ncclGroupStart());
for (int i = 0; i < args->nGpus; i++) {
ncclDevCommRequirements reqs;
memset(&reqs, 0, sizeof(reqs));
reqs.lsaBarrierCount = deviceCtaCount;
reqs.lsaMultimem = true;
result = ncclDevCommCreate(tmpComms[i], &reqs, &tmpDevComms[i]);
if (result != ncclInProgress && result != ncclSuccess) {
checkMultimemFailed = true;
}
}
result = ncclGroupEnd();
if (result != ncclSuccess) checkMultimemFailed = true;
deviceMultimemEnabled = !checkMultimemFailed;
// Create final dev comms with correct multimem setting and cleanup temps
NCCLCHECK(ncclGroupStart());
for (int i = 0; i < args->nGpus; i++) {
ncclDevCommRequirements reqs;
memset(&reqs, 0, sizeof(reqs));
reqs.lsaBarrierCount = deviceCtaCount;
reqs.lsaMultimem = deviceMultimemEnabled;
NCCLCHECK(ncclDevCommCreate(args->comms[i], &reqs, args->devComms+i));
NCCLCHECK(ncclDevCommDestroy(tmpComms[i], &tmpDevComms[i]));
NCCLCHECK(ncclCommDestroy(tmpComms[i]));
}
NCCLCHECK(ncclGroupEnd());
}
@ -1320,49 +1288,18 @@ testResult_t run() {
NCCLCHECK(ncclGroupEnd());
#endif
#if NCCL_VERSION_CODE >= NCCL_VERSION(2,28,0)
/* Create device communicators with multimem fallback */
/* Create device communicators based on test-specific requirements */
if (deviceImpl) {
// Duplicate comms so our checks here do not affect the originals
ncclComm_t tmpComms[nGpus * nThreads];
memset(tmpComms, 0, sizeof(tmpComms));
NCCLCHECK(ncclGroupStart());
for (int i = 0; i < nGpus * nThreads; i++) {
int rank;
NCCLCHECK(ncclCommUserRank(comms[i], &rank));
NCCLCHECK(ncclCommSplit(comms[i], 0, rank, &tmpComms[i], NULL));
ncclDevCommRequirements reqs;
if (!ncclTestEngine.getDevCommRequirements ||
!ncclTestEngine.getDevCommRequirements(deviceImpl, &reqs)) {
fprintf(stderr, "Device implementation %d is not supported by this test\n", deviceImpl);
return testNotImplemented;
}
NCCLCHECK(ncclGroupEnd());
// Check multimem support on the duplicated comms
bool checkMultimemFailed = false;
ncclResult_t result;
ncclDevComm tmpDevComms[nGpus * nThreads];
memset(tmpDevComms, 0, sizeof(tmpDevComms));
NCCLCHECK(ncclGroupStart());
for (int i = 0; i < nGpus * nThreads; i++) {
ncclDevCommRequirements reqs;
memset(&reqs, 0, sizeof(reqs));
reqs.lsaBarrierCount = deviceCtaCount;
reqs.lsaMultimem = true;
result = ncclDevCommCreate(tmpComms[i], &reqs, &tmpDevComms[i]);
if (result != ncclInProgress && result != ncclSuccess) {
checkMultimemFailed = true;
}
}
result = ncclGroupEnd();
if (result != ncclSuccess) checkMultimemFailed = true;
deviceMultimemEnabled = !checkMultimemFailed;
// Create final dev comms with correct multimem setting and cleanup temps
NCCLCHECK(ncclGroupStart());
for (int i = 0; i < nGpus * nThreads; i++) {
ncclDevCommRequirements reqs;
memset(&reqs, 0, sizeof(reqs));
reqs.lsaBarrierCount = deviceCtaCount;
reqs.lsaMultimem = deviceMultimemEnabled;
NCCLCHECK(ncclDevCommCreate(comms[i], &reqs, devComms+i));
NCCLCHECK(ncclDevCommDestroy(tmpComms[i], &tmpDevComms[i]));
NCCLCHECK(ncclCommDestroy(tmpComms[i]));
}
NCCLCHECK(ncclGroupEnd());
}

View File

@ -109,6 +109,9 @@ struct testEngine {
void (*getBuffSize)(size_t *sendcount, size_t *recvcount, size_t count, int nranks);
testResult_t (*runTest)(struct threadArgs* args, int root, ncclDataType_t type,
const char* typeName, ncclRedOp_t op, const char* opName);
#if NCCL_VERSION_CODE >= NCCL_VERSION(2,28,0)
bool (*getDevCommRequirements)(int deviceImpl, ncclDevCommRequirements* reqs);
#endif
};
extern struct testEngine ncclTestEngine;
@ -276,7 +279,6 @@ static size_t wordSize(ncclDataType_t type) {
extern int test_ncclVersion; // init'd with ncclGetVersion()
extern int deviceCtaCount; // number of CTAs for device implementation
extern bool deviceMultimemEnabled; // whether multimem was successfully enabled
constexpr int test_opNumMax = (int)ncclNumOps + (NCCL_VERSION_CODE >= NCCL_VERSION(2,11,0) ? 1 : 0);
extern int test_opnum;
extern int test_typenum;
@ -317,23 +319,10 @@ extern thread_local int is_main_thread;
#if NCCL_VERSION_CODE >= NCCL_VERSION(2,28,0)
template <typename F>
testResult_t testLaunchDeviceKernel(F kernel, void* sendbuff, size_t sendoffset, void* recvbuff, size_t recvoffset, size_t count, ncclDataType_t type, ncclRedOp_t op, int root, ncclComm_t comm, cudaStream_t stream, int useMultimem) {
testResult_t testLaunchDeviceKernel(F kernel, void* sendbuff, size_t sendoffset, void* recvbuff, size_t recvoffset, size_t count, ncclDataType_t type, ncclRedOp_t op, int root, ncclComm_t comm, cudaStream_t stream) {
if (kernel == nullptr) return testNotImplemented;
ncclDevComm* devComm = (ncclDevComm*)comm;
// Check if multimem is enabled for this kernel
if (useMultimem && !deviceMultimemEnabled) {
printf("[KERNEL_LAUNCH_ERROR] Device kernel requires multimem but it was not available during "
"DevComm creation. Multimem support may not be available on this hardware.\n");
return testInternalError;
}
// Only check mcBasePtr if multimem is active for this kernel
if (useMultimem && devComm->lsaMultimem.mcBasePtr == nullptr) {
printf("[KERNEL_LAUNCH_ERROR] Device kernel requires multimem, which may not be available.\n");
return testInternalError;
}
ncclWindow_t sendwin = (ncclWindow_t)sendbuff;
ncclWindow_t recvwin = (ncclWindow_t)recvbuff;
kernel<<<deviceCtaCount, 512, 0, stream>>>(sendwin, sendoffset, recvwin, recvoffset, count, root, *devComm);
@ -355,7 +344,7 @@ testResult_t testLaunchDeviceKernel(F kernel, void* sendbuff, size_t sendoffset,
)
#else
template <typename F>
testResult_t testLaunchDeviceKernel(F kernel, void* sendbuff, size_t sendoffset, void* recvbuff, size_t recvoffset, size_t count, ncclDataType_t type, ncclRedOp_t op, int root, ncclComm_t comm, cudaStream_t stream, int useMultimem) {
testResult_t testLaunchDeviceKernel(F kernel, void* sendbuff, size_t sendoffset, void* recvbuff, size_t recvoffset, size_t count, ncclDataType_t type, ncclRedOp_t op, int root, ncclComm_t comm, cudaStream_t stream) {
return testNotImplemented;
}
#define SPECIALIZE_KERNEL(kernel, type, op) nullptr

View File

@ -121,8 +121,8 @@ testResult_t GatherRunTest(struct threadArgs* args, int root, ncclDataType_t typ
}
struct testEngine gatherEngine = {
GatherGetBuffSize,
GatherRunTest
.getBuffSize = GatherGetBuffSize,
.runTest = GatherRunTest
};
#pragma weak ncclTestEngine=gatherEngine

View File

@ -115,8 +115,8 @@ testResult_t HyperCubeRunTest(struct threadArgs* args, int root, ncclDataType_t
}
struct testEngine hyperCubeEngine = {
HyperCubeGetBuffSize,
HyperCubeRunTest
.getBuffSize = HyperCubeGetBuffSize,
.runTest = HyperCubeRunTest
};
#pragma weak ncclTestEngine=hyperCubeEngine

View File

@ -109,8 +109,8 @@ testResult_t ReduceRunTest(struct threadArgs* args, int root, ncclDataType_t typ
}
struct testEngine reduceEngine = {
ReduceGetBuffSize,
ReduceRunTest
.getBuffSize = ReduceGetBuffSize,
.runTest = ReduceRunTest
};
#pragma weak ncclTestEngine=reduceEngine

View File

@ -102,8 +102,8 @@ testResult_t ReduceScatterRunTest(struct threadArgs* args, int root, ncclDataTyp
}
struct testEngine reduceScatterEngine = {
ReduceScatterGetBuffSize,
ReduceScatterRunTest
.getBuffSize = ReduceScatterGetBuffSize,
.runTest = ReduceScatterRunTest
};
#pragma weak ncclTestEngine=reduceScatterEngine

View File

@ -117,8 +117,8 @@ testResult_t ScatterRunTest(struct threadArgs* args, int root, ncclDataType_t ty
}
struct testEngine scatterEngine = {
ScatterGetBuffSize,
ScatterRunTest
.getBuffSize = ScatterGetBuffSize,
.runTest = ScatterRunTest
};
#pragma weak ncclTestEngine=scatterEngine

View File

@ -113,8 +113,8 @@ testResult_t SendRecvRunTest(struct threadArgs* args, int root, ncclDataType_t t
}
struct testEngine sendRecvEngine = {
SendRecvGetBuffSize,
SendRecvRunTest
.getBuffSize = SendRecvGetBuffSize,
.runTest = SendRecvRunTest
};
#pragma weak ncclTestEngine=sendRecvEngine