mirror of
https://github.com/NVIDIA/nccl-tests.git
synced 2026-04-23 16:08:20 +08:00
add GIN-based device API kernels to alltoall
- add GIN-only A2A kernel implementation - add hybrid LSA+GIN A2A kernel implementation - update perf test cases to expose a function for setting devCommRequirements for each device implementation and simplify devCommCreate code path to use this directly instead of complex fallback logic - add missing call to devCommDestroy
This commit is contained in:
parent
00f52811b8
commit
9829ea42b5
@ -90,8 +90,8 @@ testResult_t AllGatherRunTest(struct threadArgs* args, int root, ncclDataType_t
|
||||
}
|
||||
|
||||
struct testEngine allGatherEngine = {
|
||||
AllGatherGetBuffSize,
|
||||
AllGatherRunTest
|
||||
.getBuffSize = AllGatherGetBuffSize,
|
||||
.runTest = AllGatherRunTest
|
||||
};
|
||||
|
||||
#pragma weak ncclTestEngine=allGatherEngine
|
||||
|
||||
@ -65,6 +65,26 @@ void AllReduceGetBw(size_t count, int typesize, double sec, double* algBw, doubl
|
||||
}
|
||||
|
||||
#if NCCL_VERSION_CODE >= NCCL_VERSION(2,28,0)
|
||||
// set devComm reqs for allreduce device kernels
|
||||
bool AllReduceGetDevCommRequirements(int deviceImpl, ncclDevCommRequirements* reqs) {
|
||||
if (!reqs) return false;
|
||||
memset(reqs, 0, sizeof(*reqs));
|
||||
|
||||
switch(deviceImpl) {
|
||||
case 1: // allReduceLsaKernel
|
||||
case 2: // allReduceLsaVectorizedKernel
|
||||
reqs->lsaBarrierCount = deviceCtaCount;
|
||||
return true;
|
||||
case 3: // allReduceMultimemKernel
|
||||
case 4: // allReduceMultimemVectorizedKernel
|
||||
reqs->lsaMultimem = true;
|
||||
reqs->lsaBarrierCount = deviceCtaCount;
|
||||
return true;
|
||||
default:
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
/*
|
||||
* Kernel 1: allReduceLsaKernel - Basic LSA-based AllReduce
|
||||
*
|
||||
@ -453,19 +473,19 @@ testResult_t AllReduceRunColl(void* sendbuff, size_t sendoffset, void* recvbuff,
|
||||
#if NCCL_VERSION_CODE >= NCCL_VERSION(2,28,0)
|
||||
case 1:
|
||||
TESTCHECK(testLaunchDeviceKernel(SPECIALIZE_KERNEL(allReduceLsaKernel, type, op),
|
||||
sendbuff, sendoffset, recvbuff, recvoffset, count, type, op, root, comm, stream, 0));
|
||||
sendbuff, sendoffset, recvbuff, recvoffset, count, type, op, root, comm, stream));
|
||||
return testSuccess;
|
||||
case 2:
|
||||
TESTCHECK(testLaunchDeviceKernel(SPECIALIZE_KERNEL(allReduceLsaVectorizedKernel, type, op),
|
||||
sendbuff, sendoffset, recvbuff, recvoffset, count, type, op, root, comm, stream, 0));
|
||||
sendbuff, sendoffset, recvbuff, recvoffset, count, type, op, root, comm, stream));
|
||||
return testSuccess;
|
||||
case 3:
|
||||
TESTCHECK(testLaunchDeviceKernel(SPECIALIZE_KERNEL(allReduceMultimemKernel, type, op),
|
||||
sendbuff, sendoffset, recvbuff, recvoffset, count, type, op, root, comm, stream, 1));
|
||||
sendbuff, sendoffset, recvbuff, recvoffset, count, type, op, root, comm, stream));
|
||||
return testSuccess;
|
||||
case 4:
|
||||
TESTCHECK(testLaunchDeviceKernel(SPECIALIZE_KERNEL(allReduceMultimemVectorizedKernel, type, op),
|
||||
sendbuff, sendoffset, recvbuff, recvoffset, count, type, op, root, comm, stream, 1));
|
||||
sendbuff, sendoffset, recvbuff, recvoffset, count, type, op, root, comm, stream));
|
||||
return testSuccess;
|
||||
#endif
|
||||
}
|
||||
@ -522,8 +542,11 @@ testResult_t AllReduceRunTest(struct threadArgs* args, int root, ncclDataType_t
|
||||
}
|
||||
|
||||
struct testEngine allReduceEngine = {
|
||||
AllReduceGetBuffSize,
|
||||
AllReduceRunTest
|
||||
.getBuffSize = AllReduceGetBuffSize,
|
||||
.runTest = AllReduceRunTest,
|
||||
#if NCCL_VERSION_CODE >= NCCL_VERSION(2,28,0)
|
||||
.getDevCommRequirements = AllReduceGetDevCommRequirements
|
||||
#endif
|
||||
};
|
||||
|
||||
#pragma weak ncclTestEngine=allReduceEngine
|
||||
|
||||
115
src/alltoall.cu
115
src/alltoall.cu
@ -50,6 +50,26 @@ void AlltoAllGetBw(size_t count, int typesize, double sec, double* algBw, double
|
||||
}
|
||||
|
||||
#if NCCL_VERSION_CODE >= NCCL_VERSION(2,28,0)
|
||||
// set devComm reqs for alltoall device kernels
|
||||
bool AlltoAllGetDevCommRequirements(int deviceImpl, ncclDevCommRequirements* reqs) {
|
||||
if (!reqs) return false;
|
||||
memset(reqs, 0, sizeof(*reqs));
|
||||
|
||||
switch(deviceImpl) {
|
||||
case 1: // NvlAlltoAllKernel
|
||||
case 2: // NvlAlltoAllKernelOptimized
|
||||
reqs->lsaBarrierCount = deviceCtaCount;
|
||||
return true;
|
||||
case 3: // GinAlltoAllKernel
|
||||
case 4: // HybridAlltoAllKernel (LSA+GIN)
|
||||
reqs->barrierCount = deviceCtaCount;
|
||||
reqs->ginSignalCount = deviceCtaCount;
|
||||
return true;
|
||||
default:
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
// shared scalar AlltoAll implementation used by both kernels
|
||||
template <typename T>
|
||||
__device__ void AlltoAllScalarImpl(ncclWindow_t sendwin, size_t sendoffset, ncclWindow_t recvwin, size_t recvoffset, size_t count, int rank, int nRanks, int tid, int nthreads) {
|
||||
@ -159,6 +179,84 @@ __global__ void NvlAlltoAllKernelOptimized(ncclWindow_t sendwin, size_t sendoffs
|
||||
|
||||
bar.sync(ncclCoopCta(), cuda::memory_order_release);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
__global__ void GinAlltoAllKernel(ncclWindow_t sendwin, size_t sendoffset, ncclWindow_t recvwin, size_t recvoffset, size_t count, int root, struct ncclDevComm devComm) {
|
||||
int ginContext = 0;
|
||||
unsigned int signalIndex = 0;
|
||||
ncclGin gin { devComm, ginContext };
|
||||
uint64_t signalValue = gin.readSignal(signalIndex);
|
||||
|
||||
ncclBarrierSession<ncclCoopCta> bar { ncclCoopCta(), ncclTeamTagWorld(), gin, blockIdx.x };
|
||||
bar.sync(ncclCoopCta(), cuda::memory_order_relaxed, ncclGinFenceLevel::Relaxed);
|
||||
|
||||
int tid = threadIdx.x + blockIdx.x * blockDim.x;
|
||||
int nthreads = blockDim.x * gridDim.x;
|
||||
|
||||
/* send to all peers via GIN */
|
||||
const size_t size = count * sizeof(T);
|
||||
for (int r=tid; r<devComm.nRanks; r+=nthreads) {
|
||||
gin.put(ncclTeamWorld(devComm), r,
|
||||
recvwin, recvoffset + devComm.rank * size,
|
||||
sendwin, sendoffset + r * size,
|
||||
size, ncclGin_SignalInc{signalIndex});
|
||||
}
|
||||
|
||||
gin.waitSignal(ncclCoopCta(), signalIndex, signalValue + devComm.nRanks);
|
||||
gin.flush(ncclCoopCta());
|
||||
|
||||
bar.sync(ncclCoopCta(), cuda::memory_order_release, ncclGinFenceLevel::Relaxed);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
__global__ void HybridAlltoAllKernel(ncclWindow_t sendwin, size_t sendoffset, ncclWindow_t recvwin, size_t recvoffset, size_t count, int root, struct ncclDevComm devComm) {
|
||||
int ginContext = 0;
|
||||
unsigned int signalIndex = 0;
|
||||
ncclGin gin { devComm, ginContext };
|
||||
uint64_t signalValue = gin.readSignal(signalIndex);
|
||||
|
||||
ncclBarrierSession<ncclCoopCta> bar { ncclCoopCta(), ncclTeamTagWorld(), gin, blockIdx.x };
|
||||
bar.sync(ncclCoopCta(), cuda::memory_order_relaxed, ncclGinFenceLevel::Relaxed);
|
||||
|
||||
int tid = threadIdx.x + blockIdx.x*blockDim.x;
|
||||
int nthreads = blockDim.x * gridDim.x;
|
||||
|
||||
ncclTeam world = ncclTeamWorld(devComm);
|
||||
ncclTeam lsa = ncclTeamLsa(devComm);
|
||||
const int startLsa = world.rank - lsa.rank;
|
||||
const int lsaSize = lsa.nRanks;
|
||||
|
||||
/* handle remote peers (i.e., non-LSA) using GIN */
|
||||
const size_t size = count * sizeof(T);
|
||||
for (int r = tid; r < startLsa; r += nthreads) {
|
||||
gin.put(world, r,
|
||||
recvwin, recvoffset + world.rank * size,
|
||||
sendwin, sendoffset + r * size,
|
||||
size, ncclGin_SignalInc{signalIndex});
|
||||
}
|
||||
for (int r = startLsa + lsaSize + tid; r < world.nRanks; r += nthreads) {
|
||||
gin.put(world, r,
|
||||
recvwin, recvoffset + world.rank * size,
|
||||
sendwin, sendoffset + r * size,
|
||||
size, ncclGin_SignalInc{signalIndex});
|
||||
}
|
||||
|
||||
/* handle local peers with LSA */
|
||||
T* sendLocal = (T*)ncclGetLocalPointer(sendwin, sendoffset);
|
||||
for (size_t offset = tid; offset < count; offset += nthreads) {
|
||||
for (int lp = 0; lp < lsa.nRanks; lp++) {
|
||||
int wr = startLsa + lp;
|
||||
T* recvPtr = (T*)ncclGetLsaPointer(recvwin, recvoffset, lp);
|
||||
recvPtr[world.rank * count + offset] = sendLocal[wr * count + offset];
|
||||
}
|
||||
}
|
||||
|
||||
int numRemotePeers = world.nRanks - lsa.nRanks;
|
||||
gin.waitSignal(ncclCoopCta(), signalIndex, signalValue + numRemotePeers);
|
||||
gin.flush(ncclCoopCta());
|
||||
|
||||
bar.sync(ncclCoopCta(), cuda::memory_order_release, ncclGinFenceLevel::Relaxed);
|
||||
}
|
||||
#endif
|
||||
|
||||
testResult_t AlltoAllRunColl(void* sendbuff, size_t sendoffset, void* recvbuff, size_t recvoffset, size_t count, ncclDataType_t type, ncclRedOp_t op, int root, ncclComm_t comm, cudaStream_t stream, int deviceImpl) {
|
||||
@ -184,10 +282,16 @@ testResult_t AlltoAllRunColl(void* sendbuff, size_t sendoffset, void* recvbuff,
|
||||
} else {
|
||||
switch(deviceImpl) {
|
||||
case 1:
|
||||
TESTCHECK(testLaunchDeviceKernel(SPECIALIZE_KERNEL(NvlAlltoAllKernel, type, op), sendbuff, sendoffset, recvbuff, recvoffset, count, type, op, root, comm, stream, 0));
|
||||
TESTCHECK(testLaunchDeviceKernel(SPECIALIZE_KERNEL(NvlAlltoAllKernel, type, op), sendbuff, sendoffset, recvbuff, recvoffset, count, type, op, root, comm, stream));
|
||||
return testSuccess;
|
||||
case 2:
|
||||
TESTCHECK(testLaunchDeviceKernel(SPECIALIZE_KERNEL(NvlAlltoAllKernelOptimized, type, op), sendbuff, sendoffset, recvbuff, recvoffset, count, type, op, root, comm, stream, 0));
|
||||
TESTCHECK(testLaunchDeviceKernel(SPECIALIZE_KERNEL(NvlAlltoAllKernelOptimized, type, op), sendbuff, sendoffset, recvbuff, recvoffset, count, type, op, root, comm, stream));
|
||||
return testSuccess;
|
||||
case 3:
|
||||
TESTCHECK(testLaunchDeviceKernel(SPECIALIZE_KERNEL(GinAlltoAllKernel, type, op), sendbuff, sendoffset, recvbuff, recvoffset, count, type, op, root, comm, stream));
|
||||
return testSuccess;
|
||||
case 4:
|
||||
TESTCHECK(testLaunchDeviceKernel(SPECIALIZE_KERNEL(HybridAlltoAllKernel, type, op), sendbuff, sendoffset, recvbuff, recvoffset, count, type, op, root, comm, stream));
|
||||
return testSuccess;
|
||||
default:
|
||||
return testNotImplemented;
|
||||
@ -232,8 +336,11 @@ testResult_t AlltoAllRunTest(struct threadArgs* args, int root, ncclDataType_t t
|
||||
}
|
||||
|
||||
struct testEngine alltoAllEngine = {
|
||||
AlltoAllGetBuffSize,
|
||||
AlltoAllRunTest
|
||||
.getBuffSize = AlltoAllGetBuffSize,
|
||||
.runTest = AlltoAllRunTest,
|
||||
#if NCCL_VERSION_CODE >= NCCL_VERSION(2,28,0)
|
||||
.getDevCommRequirements = AlltoAllGetDevCommRequirements
|
||||
#endif
|
||||
};
|
||||
|
||||
#pragma weak ncclTestEngine=alltoAllEngine
|
||||
|
||||
@ -107,8 +107,8 @@ testResult_t BroadcastRunTest(struct threadArgs* args, int root, ncclDataType_t
|
||||
}
|
||||
|
||||
struct testEngine broadcastEngine = {
|
||||
BroadcastGetBuffSize,
|
||||
BroadcastRunTest
|
||||
.getBuffSize = BroadcastGetBuffSize,
|
||||
.runTest = BroadcastRunTest
|
||||
};
|
||||
|
||||
#pragma weak ncclTestEngine=broadcastEngine
|
||||
|
||||
@ -100,7 +100,6 @@ static int report_cputime = 0;
|
||||
static int deviceImpl = 0;
|
||||
|
||||
int deviceCtaCount = 16; // Default number of CTAs for device implementation
|
||||
bool deviceMultimemEnabled = false; // Track whether multimem was successfully enabled
|
||||
|
||||
// Report average iteration time: (0=RANK0,1=AVG,2=MIN,3=MAX)
|
||||
static int average = 1;
|
||||
@ -768,49 +767,18 @@ testResult_t threadInit(struct threadArgs* args) {
|
||||
NCCLCHECK(ncclGroupEnd());
|
||||
#endif
|
||||
#if NCCL_VERSION_CODE >= NCCL_VERSION(2,28,0)
|
||||
/* Create device communicators with multimem fallback */
|
||||
/* Create device communicators based on test-specific requirements */
|
||||
if (deviceImpl) {
|
||||
// Duplicate comms so our checks here do not affect the originals
|
||||
ncclComm_t tmpComms[args->nGpus];
|
||||
memset(tmpComms, 0, sizeof(tmpComms));
|
||||
NCCLCHECK(ncclGroupStart());
|
||||
for (int i = 0; i < args->nGpus; i++) {
|
||||
int rank;
|
||||
NCCLCHECK(ncclCommUserRank(args->comms[i], &rank));
|
||||
NCCLCHECK(ncclCommSplit(args->comms[i], 0, rank, &tmpComms[i], NULL));
|
||||
ncclDevCommRequirements reqs;
|
||||
if (!ncclTestEngine.getDevCommRequirements ||
|
||||
!ncclTestEngine.getDevCommRequirements(deviceImpl, &reqs)) {
|
||||
fprintf(stderr, "Device implementation %d is not supported by this test\n", deviceImpl);
|
||||
return testNotImplemented;
|
||||
}
|
||||
NCCLCHECK(ncclGroupEnd());
|
||||
|
||||
// Check multimem support on the duplicated comms
|
||||
bool checkMultimemFailed = false;
|
||||
ncclResult_t result;
|
||||
ncclDevComm tmpDevComms[args->nGpus];
|
||||
memset(tmpDevComms, 0, sizeof(tmpDevComms));
|
||||
NCCLCHECK(ncclGroupStart());
|
||||
for (int i = 0; i < args->nGpus; i++) {
|
||||
ncclDevCommRequirements reqs;
|
||||
memset(&reqs, 0, sizeof(reqs));
|
||||
reqs.lsaBarrierCount = deviceCtaCount;
|
||||
reqs.lsaMultimem = true;
|
||||
result = ncclDevCommCreate(tmpComms[i], &reqs, &tmpDevComms[i]);
|
||||
if (result != ncclInProgress && result != ncclSuccess) {
|
||||
checkMultimemFailed = true;
|
||||
}
|
||||
}
|
||||
result = ncclGroupEnd();
|
||||
if (result != ncclSuccess) checkMultimemFailed = true;
|
||||
deviceMultimemEnabled = !checkMultimemFailed;
|
||||
|
||||
// Create final dev comms with correct multimem setting and cleanup temps
|
||||
NCCLCHECK(ncclGroupStart());
|
||||
for (int i = 0; i < args->nGpus; i++) {
|
||||
ncclDevCommRequirements reqs;
|
||||
memset(&reqs, 0, sizeof(reqs));
|
||||
reqs.lsaBarrierCount = deviceCtaCount;
|
||||
reqs.lsaMultimem = deviceMultimemEnabled;
|
||||
NCCLCHECK(ncclDevCommCreate(args->comms[i], &reqs, args->devComms+i));
|
||||
NCCLCHECK(ncclDevCommDestroy(tmpComms[i], &tmpDevComms[i]));
|
||||
NCCLCHECK(ncclCommDestroy(tmpComms[i]));
|
||||
}
|
||||
NCCLCHECK(ncclGroupEnd());
|
||||
}
|
||||
@ -1320,49 +1288,18 @@ testResult_t run() {
|
||||
NCCLCHECK(ncclGroupEnd());
|
||||
#endif
|
||||
#if NCCL_VERSION_CODE >= NCCL_VERSION(2,28,0)
|
||||
/* Create device communicators with multimem fallback */
|
||||
/* Create device communicators based on test-specific requirements */
|
||||
if (deviceImpl) {
|
||||
// Duplicate comms so our checks here do not affect the originals
|
||||
ncclComm_t tmpComms[nGpus * nThreads];
|
||||
memset(tmpComms, 0, sizeof(tmpComms));
|
||||
NCCLCHECK(ncclGroupStart());
|
||||
for (int i = 0; i < nGpus * nThreads; i++) {
|
||||
int rank;
|
||||
NCCLCHECK(ncclCommUserRank(comms[i], &rank));
|
||||
NCCLCHECK(ncclCommSplit(comms[i], 0, rank, &tmpComms[i], NULL));
|
||||
ncclDevCommRequirements reqs;
|
||||
if (!ncclTestEngine.getDevCommRequirements ||
|
||||
!ncclTestEngine.getDevCommRequirements(deviceImpl, &reqs)) {
|
||||
fprintf(stderr, "Device implementation %d is not supported by this test\n", deviceImpl);
|
||||
return testNotImplemented;
|
||||
}
|
||||
NCCLCHECK(ncclGroupEnd());
|
||||
|
||||
// Check multimem support on the duplicated comms
|
||||
bool checkMultimemFailed = false;
|
||||
ncclResult_t result;
|
||||
ncclDevComm tmpDevComms[nGpus * nThreads];
|
||||
memset(tmpDevComms, 0, sizeof(tmpDevComms));
|
||||
NCCLCHECK(ncclGroupStart());
|
||||
for (int i = 0; i < nGpus * nThreads; i++) {
|
||||
ncclDevCommRequirements reqs;
|
||||
memset(&reqs, 0, sizeof(reqs));
|
||||
reqs.lsaBarrierCount = deviceCtaCount;
|
||||
reqs.lsaMultimem = true;
|
||||
result = ncclDevCommCreate(tmpComms[i], &reqs, &tmpDevComms[i]);
|
||||
if (result != ncclInProgress && result != ncclSuccess) {
|
||||
checkMultimemFailed = true;
|
||||
}
|
||||
}
|
||||
result = ncclGroupEnd();
|
||||
if (result != ncclSuccess) checkMultimemFailed = true;
|
||||
deviceMultimemEnabled = !checkMultimemFailed;
|
||||
|
||||
// Create final dev comms with correct multimem setting and cleanup temps
|
||||
NCCLCHECK(ncclGroupStart());
|
||||
for (int i = 0; i < nGpus * nThreads; i++) {
|
||||
ncclDevCommRequirements reqs;
|
||||
memset(&reqs, 0, sizeof(reqs));
|
||||
reqs.lsaBarrierCount = deviceCtaCount;
|
||||
reqs.lsaMultimem = deviceMultimemEnabled;
|
||||
NCCLCHECK(ncclDevCommCreate(comms[i], &reqs, devComms+i));
|
||||
NCCLCHECK(ncclDevCommDestroy(tmpComms[i], &tmpDevComms[i]));
|
||||
NCCLCHECK(ncclCommDestroy(tmpComms[i]));
|
||||
}
|
||||
NCCLCHECK(ncclGroupEnd());
|
||||
}
|
||||
|
||||
21
src/common.h
21
src/common.h
@ -109,6 +109,9 @@ struct testEngine {
|
||||
void (*getBuffSize)(size_t *sendcount, size_t *recvcount, size_t count, int nranks);
|
||||
testResult_t (*runTest)(struct threadArgs* args, int root, ncclDataType_t type,
|
||||
const char* typeName, ncclRedOp_t op, const char* opName);
|
||||
#if NCCL_VERSION_CODE >= NCCL_VERSION(2,28,0)
|
||||
bool (*getDevCommRequirements)(int deviceImpl, ncclDevCommRequirements* reqs);
|
||||
#endif
|
||||
};
|
||||
|
||||
extern struct testEngine ncclTestEngine;
|
||||
@ -276,7 +279,6 @@ static size_t wordSize(ncclDataType_t type) {
|
||||
|
||||
extern int test_ncclVersion; // init'd with ncclGetVersion()
|
||||
extern int deviceCtaCount; // number of CTAs for device implementation
|
||||
extern bool deviceMultimemEnabled; // whether multimem was successfully enabled
|
||||
constexpr int test_opNumMax = (int)ncclNumOps + (NCCL_VERSION_CODE >= NCCL_VERSION(2,11,0) ? 1 : 0);
|
||||
extern int test_opnum;
|
||||
extern int test_typenum;
|
||||
@ -317,23 +319,10 @@ extern thread_local int is_main_thread;
|
||||
|
||||
#if NCCL_VERSION_CODE >= NCCL_VERSION(2,28,0)
|
||||
template <typename F>
|
||||
testResult_t testLaunchDeviceKernel(F kernel, void* sendbuff, size_t sendoffset, void* recvbuff, size_t recvoffset, size_t count, ncclDataType_t type, ncclRedOp_t op, int root, ncclComm_t comm, cudaStream_t stream, int useMultimem) {
|
||||
testResult_t testLaunchDeviceKernel(F kernel, void* sendbuff, size_t sendoffset, void* recvbuff, size_t recvoffset, size_t count, ncclDataType_t type, ncclRedOp_t op, int root, ncclComm_t comm, cudaStream_t stream) {
|
||||
if (kernel == nullptr) return testNotImplemented;
|
||||
ncclDevComm* devComm = (ncclDevComm*)comm;
|
||||
|
||||
// Check if multimem is enabled for this kernel
|
||||
if (useMultimem && !deviceMultimemEnabled) {
|
||||
printf("[KERNEL_LAUNCH_ERROR] Device kernel requires multimem but it was not available during "
|
||||
"DevComm creation. Multimem support may not be available on this hardware.\n");
|
||||
return testInternalError;
|
||||
}
|
||||
|
||||
// Only check mcBasePtr if multimem is active for this kernel
|
||||
if (useMultimem && devComm->lsaMultimem.mcBasePtr == nullptr) {
|
||||
printf("[KERNEL_LAUNCH_ERROR] Device kernel requires multimem, which may not be available.\n");
|
||||
return testInternalError;
|
||||
}
|
||||
|
||||
ncclWindow_t sendwin = (ncclWindow_t)sendbuff;
|
||||
ncclWindow_t recvwin = (ncclWindow_t)recvbuff;
|
||||
kernel<<<deviceCtaCount, 512, 0, stream>>>(sendwin, sendoffset, recvwin, recvoffset, count, root, *devComm);
|
||||
@ -355,7 +344,7 @@ testResult_t testLaunchDeviceKernel(F kernel, void* sendbuff, size_t sendoffset,
|
||||
)
|
||||
#else
|
||||
template <typename F>
|
||||
testResult_t testLaunchDeviceKernel(F kernel, void* sendbuff, size_t sendoffset, void* recvbuff, size_t recvoffset, size_t count, ncclDataType_t type, ncclRedOp_t op, int root, ncclComm_t comm, cudaStream_t stream, int useMultimem) {
|
||||
testResult_t testLaunchDeviceKernel(F kernel, void* sendbuff, size_t sendoffset, void* recvbuff, size_t recvoffset, size_t count, ncclDataType_t type, ncclRedOp_t op, int root, ncclComm_t comm, cudaStream_t stream) {
|
||||
return testNotImplemented;
|
||||
}
|
||||
#define SPECIALIZE_KERNEL(kernel, type, op) nullptr
|
||||
|
||||
@ -121,8 +121,8 @@ testResult_t GatherRunTest(struct threadArgs* args, int root, ncclDataType_t typ
|
||||
}
|
||||
|
||||
struct testEngine gatherEngine = {
|
||||
GatherGetBuffSize,
|
||||
GatherRunTest
|
||||
.getBuffSize = GatherGetBuffSize,
|
||||
.runTest = GatherRunTest
|
||||
};
|
||||
|
||||
#pragma weak ncclTestEngine=gatherEngine
|
||||
|
||||
@ -115,8 +115,8 @@ testResult_t HyperCubeRunTest(struct threadArgs* args, int root, ncclDataType_t
|
||||
}
|
||||
|
||||
struct testEngine hyperCubeEngine = {
|
||||
HyperCubeGetBuffSize,
|
||||
HyperCubeRunTest
|
||||
.getBuffSize = HyperCubeGetBuffSize,
|
||||
.runTest = HyperCubeRunTest
|
||||
};
|
||||
|
||||
#pragma weak ncclTestEngine=hyperCubeEngine
|
||||
|
||||
@ -109,8 +109,8 @@ testResult_t ReduceRunTest(struct threadArgs* args, int root, ncclDataType_t typ
|
||||
}
|
||||
|
||||
struct testEngine reduceEngine = {
|
||||
ReduceGetBuffSize,
|
||||
ReduceRunTest
|
||||
.getBuffSize = ReduceGetBuffSize,
|
||||
.runTest = ReduceRunTest
|
||||
};
|
||||
|
||||
#pragma weak ncclTestEngine=reduceEngine
|
||||
|
||||
@ -102,8 +102,8 @@ testResult_t ReduceScatterRunTest(struct threadArgs* args, int root, ncclDataTyp
|
||||
}
|
||||
|
||||
struct testEngine reduceScatterEngine = {
|
||||
ReduceScatterGetBuffSize,
|
||||
ReduceScatterRunTest
|
||||
.getBuffSize = ReduceScatterGetBuffSize,
|
||||
.runTest = ReduceScatterRunTest
|
||||
};
|
||||
|
||||
#pragma weak ncclTestEngine=reduceScatterEngine
|
||||
|
||||
@ -117,8 +117,8 @@ testResult_t ScatterRunTest(struct threadArgs* args, int root, ncclDataType_t ty
|
||||
}
|
||||
|
||||
struct testEngine scatterEngine = {
|
||||
ScatterGetBuffSize,
|
||||
ScatterRunTest
|
||||
.getBuffSize = ScatterGetBuffSize,
|
||||
.runTest = ScatterRunTest
|
||||
};
|
||||
|
||||
#pragma weak ncclTestEngine=scatterEngine
|
||||
|
||||
@ -113,8 +113,8 @@ testResult_t SendRecvRunTest(struct threadArgs* args, int root, ncclDataType_t t
|
||||
}
|
||||
|
||||
struct testEngine sendRecvEngine = {
|
||||
SendRecvGetBuffSize,
|
||||
SendRecvRunTest
|
||||
.getBuffSize = SendRecvGetBuffSize,
|
||||
.runTest = SendRecvRunTest
|
||||
};
|
||||
|
||||
#pragma weak ncclTestEngine=sendRecvEngine
|
||||
|
||||
Loading…
Reference in New Issue
Block a user