diff --git a/src/common.cu b/src/common.cu index ff4e1fd..25fc7da 100644 --- a/src/common.cu +++ b/src/common.cu @@ -40,22 +40,40 @@ static int ncclroot = 0; static int parallel_init = 0; static int blocking_coll = 0; -double parsesize(char *value) { +static double parsesize(const char *value) { long long int units; double size; + char size_lit; - if (strchr(value, 'G') != NULL) { - units=1024*1024*1024; - } else if (strchr(value, 'M') != NULL) { - units=1024*1024; - } else if (strchr(value, 'K') != NULL) { - units=1024; - } else { - units=1; + int count = sscanf(value, "%lf %1s", &size, &size_lit); + + switch (count) { + case 2: + switch (size_lit) { + case 'G': + case 'g': + units = 1024*1024*1024; + break; + case 'M': + case 'm': + units = 1024*1024; + break; + case 'K': + case 'k': + units = 1024; + break; + default: + return -1.0; + }; + break; + case 1: + units = 1; + break; + default: + return -1.0; } - size = atof(value)*units; - return size; + return size * units; } double DeltaMaxValue(ncclDataType_t type) { @@ -570,6 +588,7 @@ int main(int argc, char* argv[]) { setlinebuf(stdout); // Parse args + double parsed; int longindex; static struct option longopts[] = { {"nthreads", required_argument, 0, 't'}, @@ -605,10 +624,20 @@ int main(int argc, char* argv[]) { nGpus = strtol(optarg, NULL, 0); break; case 'b': - minBytes = (size_t)parsesize(optarg); + parsed = parsesize(optarg); + if (parsed < 0) { + fprintf(stderr, "invalid size specified for 'minbytes'\n"); + return -1; + } + minBytes = (size_t)parsed; break; case 'e': - maxBytes = (size_t)parsesize(optarg); + parsed = parsesize(optarg); + if (parsed < 0) { + fprintf(stderr, "invalid size specified for 'maxbytes'\n"); + return -1; + } + maxBytes = (size_t)parsed; break; case 'i': stepBytes = strtol(optarg, NULL, 0); @@ -623,7 +652,7 @@ int main(int argc, char* argv[]) { #if NCCL_MAJOR >= 2 && NCCL_MINOR >= 2 agg_iters = (int)strtol(optarg, NULL, 0); #else - printf("Option -m not supported before NCCL 2.2. Ignoring\n"); + fprintf(stderr, "Option -m not supported before NCCL 2.2. Ignoring\n"); #endif break; case 'w': @@ -648,7 +677,7 @@ int main(int argc, char* argv[]) { blocking_coll = strtol(optarg, NULL, 0); break; case 'h': - printf("USAGE: %s \n\t" + fprintf(stderr, "USAGE: %s \n\t" "[-t,--nthreads ] \n\t" "[-g,--ngpus ] \n\t" "[-b,--minbytes ] \n\t" @@ -668,8 +697,8 @@ int main(int argc, char* argv[]) { basename(argv[0])); return 0; default: - printf("invalid option \n"); - printf("USAGE: %s \n\t" + fprintf(stderr, "invalid option \n"); + fprintf(stderr, "USAGE: %s \n\t" "[-t,--nthreads ] \n\t" "[-g,--ngpus ] \n\t" "[-b,--minbytes ] \n\t" @@ -690,6 +719,12 @@ int main(int argc, char* argv[]) { return 0; } } + if (minBytes > maxBytes) { + fprintf(stderr, "invalid sizes for 'minbytes' and 'maxbytes': %llu > %llu\n", + (unsigned long long)minBytes, + (unsigned long long)maxBytes); + return -1; + } #ifdef MPI_SUPPORT MPI_Init(&argc, &argv); #endif