diff --git a/src/common.cu b/src/common.cu index 6d103d7..9277ea2 100644 --- a/src/common.cu +++ b/src/common.cu @@ -10,6 +10,8 @@ #include #include #include +#include +#include #include "cuda.h" #include "../verifiable/verifiable.h" @@ -892,6 +894,26 @@ int main(int argc, char* argv[]) { return 0; } +#ifdef MPI_SUPPORT +// parse int for base 2/10/16, will ignore first whitespaces +static bool parseInt(char *s, int *num) { + char *p = NULL; + if (!s || !num) + return false; + while (*s && isspace(*s)) ++s; + if (!*s) return false; + + if (strncasecmp(s, "0b", 2) == 0) + *num = (int)strtoul(s + 2, &p, 2); + else + *num = (int)strtoul(s, &p, 0); + + if (p == s) + return false; + return true; +} +#endif + testResult_t run() { int totalProcs = 1, proc = 0, ncclProcs = 1, ncclProc = 0, color = 0; int localRank = 0; @@ -909,10 +931,33 @@ testResult_t run() { if (hostHashs[p] == hostHashs[proc]) localRank++; } - char* str = getenv("NCCL_TESTS_SPLIT_MASK"); - uint64_t mask = str ? strtoul(str, NULL, 16) : 0; + char *splitMaskEnv = NULL; + if (splitMaskEnv = getenv("NCCL_TESTS_SPLIT_MASK")) { + color = proc & strtoul(splitMaskEnv, NULL, 16); + } else if (splitMaskEnv = getenv("NCCL_TESTS_SPLIT")) { + if ( + (strncasecmp(splitMaskEnv, "AND", strlen("AND")) == 0 && parseInt(splitMaskEnv + strlen("AND"), &color)) || + (strncasecmp(splitMaskEnv, "&", strlen("&")) == 0 && parseInt(splitMaskEnv + strlen("&"), &color)) + ) + color = proc & color; + if ( + (strncasecmp(splitMaskEnv, "OR", strlen("OR")) == 0 && parseInt(splitMaskEnv + strlen("OR"), &color)) || + (strncasecmp(splitMaskEnv, "|", strlen("|")) == 0 && parseInt(splitMaskEnv + strlen("|"), &color)) + ) + color = proc | color; + if ( + (strncasecmp(splitMaskEnv, "MOD", strlen("MOD")) == 0 && parseInt(splitMaskEnv + strlen("MOD"), &color)) || + (strncasecmp(splitMaskEnv, "%", strlen("%")) == 0 && parseInt(splitMaskEnv + strlen("%"), &color)) + ) + color = proc % color; + if ( + (strncasecmp(splitMaskEnv, "DIV", strlen("DIV")) == 0 && parseInt(splitMaskEnv + strlen("DIV"), &color)) || + (strncasecmp(splitMaskEnv, "/", strlen("/")) == 0 && parseInt(splitMaskEnv + strlen("/"), &color)) + ) + color = proc / color; + } + MPI_Comm mpi_comm; - color = proc & mask; MPI_Comm_split(MPI_COMM_WORLD, color, proc, &mpi_comm); MPI_Comm_size(mpi_comm, &ncclProcs); MPI_Comm_rank(mpi_comm, &ncclProc);