diff --git a/src/common.cu b/src/common.cu index b8afa8d..4da475f 100644 --- a/src/common.cu +++ b/src/common.cu @@ -23,6 +23,7 @@ #pragma weak ncclCommWindowDeregister #pragma weak ncclDevCommCreate #pragma weak ncclDevCommDestroy +#pragma weak ncclCommQueryProperties #define DIVUP(x, y) \ (((x)+(y)-1)/(y)) @@ -814,6 +815,13 @@ testResult_t threadInit(struct threadArgs* args) { /* Create device communicators based on test-specific requirements */ if (deviceImpl) { #if NCCL_VERSION_CODE >= NCCL_VERSION(2,29,0) + if (test_ncclVersion < NCCL_VERSION(2,29,0)) { + fprintf(stderr, + "Incompatible NCCL versions. nccl-tests was compiled with NCCL %d, but is running with NCCL %d. " + "The %d Device API is not compatible with versions before 2.29.\n", + NCCL_VERSION_CODE, test_ncclVersion, NCCL_VERSION_CODE); + return testInvalidUsage; + } ncclDevCommRequirements reqs = NCCL_DEV_COMM_REQUIREMENTS_INITIALIZER; if (!ncclTestEngine.getDevCommRequirements) { fprintf(stderr, "Device implementation %d is not supported by this test\n", deviceImpl); @@ -823,6 +831,12 @@ testResult_t threadInit(struct threadArgs* args) { NCCLCHECK(ncclCommQueryProperties(args->comms[0], &commProperties)); TESTCHECK(ncclTestEngine.getDevCommRequirements(deviceImpl, &reqs, &commProperties)); #else + if (test_ncclVersion >= NCCL_VERSION(2,29,0)) { + fprintf(stderr, "Incompatible NCCL versions. nccl-tests was compiled with NCCL 2.28, but is running with NCCL %d. 
" + "The 2.28 Device API is not compatible with later versions.\n", + test_ncclVersion); + return testInvalidUsage; + } ncclDevCommRequirements reqs = {}; if (!ncclTestEngine.getDevCommRequirements || !ncclTestEngine.getDevCommRequirements(deviceImpl, &reqs)) { @@ -1396,21 +1410,33 @@ testResult_t run() { /* Create device communicators based on test-specific requirements */ if (deviceImpl) { #if NCCL_VERSION_CODE >= NCCL_VERSION(2,29,0) - ncclDevCommRequirements reqs = NCCL_DEV_COMM_REQUIREMENTS_INITIALIZER; - if (!ncclTestEngine.getDevCommRequirements) { - fprintf(stderr, "Device implementation %d is not supported by this test\n", deviceImpl); - return testNotImplemented; - } - ncclCommProperties commProperties = NCCL_COMM_PROPERTIES_INITIALIZER; - NCCLCHECK(ncclCommQueryProperties(comms[0], &commProperties)); - TESTCHECK(ncclTestEngine.getDevCommRequirements(deviceImpl, &reqs, &commProperties)); + if (test_ncclVersion < NCCL_VERSION(2,29,0)) { + fprintf(stderr, + "Incompatible NCCL versions. nccl-tests was compiled with NCCL %d, but is running with NCCL %d. 
" + "The %d Device API is not compatible with versions before 2.29.\n", + NCCL_VERSION_CODE, test_ncclVersion, NCCL_VERSION_CODE); + return testInvalidUsage; + } + ncclDevCommRequirements reqs = NCCL_DEV_COMM_REQUIREMENTS_INITIALIZER; + if (!ncclTestEngine.getDevCommRequirements) { + fprintf(stderr, "Device implementation %d is not supported by this test\n", deviceImpl); + return testNotImplemented; + } + ncclCommProperties commProperties = NCCL_COMM_PROPERTIES_INITIALIZER; + NCCLCHECK(ncclCommQueryProperties(comms[0], &commProperties)); + TESTCHECK(ncclTestEngine.getDevCommRequirements(deviceImpl, &reqs, &commProperties)); #else - ncclDevCommRequirements reqs = {}; - if (!ncclTestEngine.getDevCommRequirements || - !ncclTestEngine.getDevCommRequirements(deviceImpl, &reqs)) { - fprintf(stderr, "Device implementation %d is not supported by this test\n", deviceImpl); - return testNotImplemented; - } + if (test_ncclVersion >= NCCL_VERSION(2,29,0)) { + fprintf(stderr, "Incompatible NCCL versions. nccl-tests was compiled with NCCL 2.28, but is running with NCCL %d. " + "The 2.28 Device API is not compatible with later versions.\n", test_ncclVersion); + return testInvalidUsage; + } + ncclDevCommRequirements reqs = {}; + if (!ncclTestEngine.getDevCommRequirements || + !ncclTestEngine.getDevCommRequirements(deviceImpl, &reqs)) { + fprintf(stderr, "Device implementation %d is not supported by this test\n", deviceImpl); + return testNotImplemented; + } #endif NCCLCHECK(ncclGroupStart()); diff --git a/src/common.h b/src/common.h index ab9d3cb..8e3b516 100644 --- a/src/common.h +++ b/src/common.h @@ -72,7 +72,8 @@ typedef enum { testNcclError = 3, testTimeout = 4, testNotImplemented = 5, - testNumResults = 6 + testInvalidUsage = 6, + testNumResults = 7, // Must be last } testResult_t; // Relay errors up and trace