/*
 * SPDX-FileCopyrightText: Copyright (c) 1993-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: Apache-2.0
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

// Include the fixture with the actual benchmark code
#include "mixtureOfExpertsBackendBenchmarkFixture.h"

/*
 * Below is all the setup for parameterising the benchmarks
 */

template <class DataType_, class WeightType_ = DataType_, class OutputType_ = DataType_>
struct WeightParams
{
    using DataType = DataType_;
    using WeightType = WeightType_;
    using OutputType = OutputType_;
};

#define BENCHMARK_BASIC(atype, wtype, otype)                                                                           \
    BENCHMARK_TEMPLATE_DEFINE_F(                                                                                       \
        MixtureOfExpertsBenchmark, Basic_##atype##_##wtype, WeightParams<atype, wtype, otype>)                         \
    (benchmark::State & state)                                                                                         \
    {                                                                                                                  \
        runBenchmark(state);                                                                                           \
    }

#define BENCHMARK_BASIC_DO_REGISTER(atype, wtype, otype)                                                               \
    BENCHMARK_REGISTER_F(MixtureOfExpertsBenchmark, Basic_##atype##_##wtype)                                           \
        ->Apply(argGen<MixtureOfExpertsBenchmark<WeightParams<atype, wtype, otype>>>)

template <class BenchClass>
auto listAllTactics()
{
    int const sm = getSMVersion();
    using RunnerType = decltype(BenchClass::mMoERunner);
    return RunnerType::getTactics(sm);
}
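// For illustration: BENCHMARK_BASIC(half, half, half) expands (roughly) to a fixture
// specialisation plus a benchmark body, i.e. something equivalent to:
//
//   BENCHMARK_TEMPLATE_DEFINE_F(MixtureOfExpertsBenchmark, Basic_half_half, WeightParams<half, half, half>)
//   (benchmark::State& state)
//   {
//       runBenchmark(state);
//   }
//
// Registration is deferred to BENCHMARK_BASIC_DO_REGISTER so that the workload file
// (if any) is known before google-benchmark generates the argument lists.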
template <class BenchClass>
int parseTacticToId(nlohmann::json tactic_config)
{
    bool is_tma_warp_specialized = tactic_config.at("is_tma_warp_specialized").get<bool>();
    int tile_shape_id = -1;
    std::array<int, 3> tile_shape;
    if (tactic_config.at("tile_shape").is_array())
        tactic_config.at("tile_shape").get_to(tile_shape);
    else
        tile_shape_id = tactic_config.at("tile_shape").get<int>();

    std::vector<CutlassGemmConfig> confs = listAllTactics<BenchClass>();
    try
    {
        for (int i = 0; i < confs.size(); i++)
        {
            auto const& c = confs[i];
            if (c.is_tma_warp_specialized != is_tma_warp_specialized)
                continue;

            if (!is_tma_warp_specialized)
            {
                int stages = tactic_config.at("stages").get<int>();
                if (c.stages != stages)
                    continue;
            }

            if (tile_shape_id != -1)
            {
                int comp = c.getTileConfigAsInt();
                if (tile_shape_id != comp)
                    continue;
                if (is_tma_warp_specialized && (int) c.cluster_shape != tactic_config.at("cluster_shape").get<int>())
                    continue;

                // Found matching config
                return i;
            }

            // Handle if the user provided a shape instead of the enum value
            if (is_tma_warp_specialized)
            {
                // TODO Add cases for blackwell shapes
                using Kv = uint64_t;
                constexpr static auto K = [](int m, int n) { return (uint64_t(m) << 32) | uint64_t(n); };
                static std::unordered_map<Kv, CutlassTileConfigSM90> const tile_map{
                    {K(64, 16), CutlassTileConfigSM90::CtaShape64x16x128B},
                    {K(64, 32), CutlassTileConfigSM90::CtaShape64x32x128B},
                    {K(64, 64), CutlassTileConfigSM90::CtaShape64x64x128B},
                    {K(64, 128), CutlassTileConfigSM90::CtaShape64x128x128B},
                    {K(64, 256), CutlassTileConfigSM90::CtaShape64x256x128B},
                    {K(128, 16), CutlassTileConfigSM90::CtaShape128x16x128B},
                    {K(128, 32), CutlassTileConfigSM90::CtaShape128x32x128B},
                    {K(128, 64), CutlassTileConfigSM90::CtaShape128x64x128B},
                    {K(128, 128), CutlassTileConfigSM90::CtaShape128x128x128B},
                    {K(128, 256), CutlassTileConfigSM90::CtaShape128x256x128B},
                    {K(256, 128), CutlassTileConfigSM90::CtaShape256x128x128B},
                };
                if (c.getTileConfigAsInt() != (int) tile_map.at(K(tile_shape[0], tile_shape[1])))
                    continue;

                static std::unordered_map<Kv, ClusterShape> const cluster_map{
                    // CTA configs for M=64
                    {K(1, 1), ClusterShape::ClusterShape_1x1x1},
                    {K(2, 1), ClusterShape::ClusterShape_2x1x1},
                    {K(1, 2), ClusterShape::ClusterShape_1x2x1},
                    {K(2, 2), ClusterShape::ClusterShape_2x2x1},
                };
                std::array<int, 3> cluster_shape;
                tactic_config.at("cluster_shape").get_to(cluster_shape);
                if (c.cluster_shape != cluster_map.at(K(cluster_shape[0], cluster_shape[1])))
                    continue;

                // Found matching config
                return i;
            }
            else
            {
                std::array<int, 3> warp_shape;
                tactic_config.at("warp_shape").get_to(warp_shape);
                using Kv = uint64_t;
                constexpr static auto K = [](std::array<int, 3> a, std::array<int, 3> b)
                {
                    uint64_t sum = 0;
                    for (auto v : a)
                        sum = sum * 512 + v;
                    for (auto v : b)
                        sum = sum * 256 + v;
                    return sum;
                };
                static std::unordered_map<Kv, CutlassTileConfig> const tile_map{
                    {K({128, 128, 8}, {64, 64, 8}), CutlassTileConfig::CtaShape128x128x8_WarpShape64x64x8},
                    {K({16, 128, 64}, {16, 32, 64}), CutlassTileConfig::CtaShape16x128x64_WarpShape16x32x64},
                    {K({32, 128, 64}, {32, 32, 64}), CutlassTileConfig::CtaShape32x128x64_WarpShape32x32x64},
                    {K({64, 128, 64}, {32, 64, 64}), CutlassTileConfig::CtaShape64x128x64_WarpShape32x64x64},
                    {K({64, 64, 128}, {32, 64, 64}), CutlassTileConfig::CtaShape64x64x128_WarpShape32x64x64},
                    {K({64, 128, 64}, {64, 32, 64}), CutlassTileConfig::CtaShape64x128x64_WarpShape64x32x64},
                    {K({128, 64, 64}, {64, 32, 64}), CutlassTileConfig::CtaShape128x64x64_WarpShape64x32x64},
                    {K({128, 128, 64}, {64, 32, 64}), CutlassTileConfig::CtaShape128x128x64_WarpShape64x32x64},
                    {K({128, 128, 64}, {64, 64, 64}), CutlassTileConfig::CtaShape128x128x64_WarpShape64x64x64},
                    {K({128, 128, 64}, {128, 32, 64}), CutlassTileConfig::CtaShape128x128x64_WarpShape128x32x64},
                    {K({128, 256, 64}, {64, 64, 64}), CutlassTileConfig::CtaShape128x256x64_WarpShape64x64x64},
                    {K({256, 128, 64}, {64, 64, 64}), CutlassTileConfig::CtaShape256x128x64_WarpShape64x64x64},
                    {K({16, 256, 64}, {16, 64, 64}), CutlassTileConfig::CtaShape16x256x64_WarpShape16x64x64},
                };
                if (c.tile_config_sm80 != tile_map.at(K(tile_shape, warp_shape)))
                    continue;

                // Found matching config
                return i;
            }
        }
    }
    catch (std::out_of_range const& e)
    {
        std::cerr << "Warning: error parsing tactic " << tactic_config.dump(2) << std::endl;
    }

    return -1;
}

template <class BenchClass>
void parseTacticToVectorID(nlohmann::json& tactic, std::vector<int>& tactic_ids)
{
    if (tactic.is_number_integer())
    {
        tactic_ids.push_back(tactic.get<int>());
    }
    else if (tactic.is_array())
    {
        for (auto c : tactic)
        {
            parseTacticToVectorID<BenchClass>(c, tactic_ids);
        }
    }
    else if (tactic.is_object())
    {
        tactic_ids.push_back(parseTacticToId<BenchClass>(tactic));
    }
    else if (tactic.is_string())
    {
        auto tactic_name = tactic.get<std::string>();
        if (tactic_name == "all")
        {
            auto all_tactics = listAllTactics<BenchClass>();
            tactic_ids.resize(all_tactics.size());
            std::iota(tactic_ids.begin(), tactic_ids.end(), 0);
        }
        else
        {
            assert(tactic_name == "auto");
            tactic_ids.push_back(-1);
        }
    }
    else
    {
        throw std::invalid_argument("Invalid tactic format");
    }
}
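// For reference, a hypothetical tactic object that parseTacticToId() would accept for an sm90
// (TMA warp-specialised) kernel. The exact shapes must still match one of the configs reported
// by listAllTactics() for the current SM, or the tactic is rejected and -1 is returned:
//
//   {
//       "is_tma_warp_specialized": true,
//       "tile_shape": [128, 128, 128],
//       "cluster_shape": [1, 1, 1]
//   }
//
// and an equivalent pre-sm90 example, which needs warp_shape and stages instead:
//
//   {
//       "is_tma_warp_specialized": false,
//       "tile_shape": [128, 128, 64],
//       "warp_shape": [64, 64, 64],
//       "stages": 4
//   }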
// This interdependence of globals could be better, but it works ok for this limited case.
std::unordered_map<std::string, std::pair<int, int>> name_info_map{
    {routingConfigCache[LOAD_BALANCED_ROUTING_CONFIG]->getName(), {-1, LOAD_BALANCED_ROUTING_CONFIG}},
    {routingConfigCache[UNIFORM_ROUTING_CONFIG]->getName(), {-1, UNIFORM_ROUTING_CONFIG}},
};

int getNameCacheIdx(std::string const& name)
{
    if (name_info_map.find(name) == name_info_map.end())
    {
        return -1;
    }
    return name_info_map.at(name).second;
}

void setNameCacheIdx(std::string const& name, int id)
{
    name_info_map.at(name).second = id;
}

template <class DataType, class RoutingConfigType>
std::optional<int> loadRoutingValues(nlohmann::json entry, int64_t num_experts, int64_t k, std::string config_name)
{
    std::optional<int> routing_config;
    if (entry.is_string())
    {
        routing_config = getNameCacheIdx(entry.get<std::string>());
        if (routing_config < 0)
        {
            throw std::invalid_argument("Invalid routing value, could not find valid config");
        }
    }
    else
    {
        if (config_name.empty())
        {
            throw std::invalid_argument("Explicit routing configurations must specify a name");
        }
        std::vector<DataType> values;
        entry.get_to(values);
        routingConfigCache.push_back(
            std::make_shared<RoutingConfigType>(std::move(values), num_experts, k, config_name));
        routing_config = routingConfigCache.size() - 1;
    }

    auto conf = routingConfigCache[*routing_config];
    bool const is_supported = conf->supportsConfig(num_experts, k, {});
    auto conf_derived = std::dynamic_pointer_cast<RoutingConfigType>(conf);
    auto conf_default = std::dynamic_pointer_cast<std::conditional_t<std::is_same_v<RoutingConfigType,
        VectoredRoutingConfig>, LoadBalancedRoutingConfig, UniformRoutingConfig>>(conf);
    bool const is_valid_type = conf_derived || conf_default;
    if (!is_supported || !is_valid_type)
    {
        throw std::invalid_argument("Incompatible config selected. "
            + ((conf_derived) ? "Expected experts: " + std::to_string(num_experts) + " and k: " + std::to_string(k)
                    + " in routing configuration. Found: " + std::to_string(conf_derived->num_experts) + " and "
                    + std::to_string(conf_derived->k)
                              : "Found incompatible routing config type"));
    }

    return routing_config;
}
// This is suboptimal for large benchmark files as we reread it for every data type
template <class BenchClass>
void argGenLoadFile(benchmark::internal::Benchmark* benchmark)
{
    /*
     * See help text for schema description
     */
    std::ifstream file{workloadFile};
    std::stringstream buffer;
    buffer << file.rdbuf();
    auto file_contents = buffer.str();
    if (LOG_LEVEL >= INFO)
        std::cout << "Loaded benchmark file: " << file_contents << std::endl;
    auto source_data = nlohmann::json::parse(file_contents);
    int i = 0;
    for (auto run_config : source_data)
    {
        if (LOG_LEVEL >= VERBOSE)
            std::cout << "Parsing run config: " << run_config.dump(2) << std::endl;

        std::string config_name = "config_" + std::to_string(i);

        // WARNING: Process the routing configuration immediately, so we can guarantee all configs get processed for
        // all data types. We should not skip any test cases as a later test config may depend on this config
        if (run_config.contains("routing_name"))
        {
            run_config["routing_name"].get_to(config_name);
            if (!run_config.contains("selected_experts") && !run_config.contains("expert_distribution"))
            {
                throw std::invalid_argument("Setting routing value configuration name but missing routing values");
            }
        }

        if (run_config.contains("routing_values") || run_config.contains("routing_distribution"))
        {
            throw std::invalid_argument(
                "Using deprecated routing_values or routing_distribution. Please use selected_experts or "
                "expert_distribution instead.");
        }

        std::optional<int> routing_config;
        auto res = name_info_map.emplace(config_name, std::pair{i, -1});
        // We must check i is not equal since this function gets called for each data type
        if (!res.second && res.first->second.first != i)
        {
            throw std::invalid_argument("Redefinition of routing_name " + config_name + " at config "
                + std::to_string(i) + ". First declared at " + std::to_string(res.first->second.first));
        }
        else if (!res.second)
        {
            // Reuse the existing config from a previous parse
            routing_config = getNameCacheIdx(config_name);
        }
        i++;

        int num_experts = run_config.at("num_experts").get<int>();
        int k = run_config.at("k").get<int>();
        if (!routing_config)
        {
            if (run_config.contains("selected_experts"))
            {
                routing_config = loadRoutingValues<int, VectoredRoutingConfig>(
                    run_config["selected_experts"], num_experts, k, config_name);
            }
            else if (run_config.contains("expert_distribution"))
            {
                routing_config = loadRoutingValues<float, RandomDistributionRoutingConfig>(
                    run_config["expert_distribution"], num_experts, k, config_name);
            }
        }
        // Use the selected config or fall back to balanced
        routing_config = routing_config.value_or(LOAD_BALANCED_ROUTING_CONFIG);
        setNameCacheIdx(config_name, *routing_config);

        // Filter out the types we don't care about testing
        if (run_config.contains("dtypes"))
        {
            std::vector<std::string> dtypes;
            run_config["dtypes"].get_to(dtypes);
            auto hasDtype = [&](char const* d)
            { return std::any_of(dtypes.begin(), dtypes.end(), [&](auto const& n) { return n == d; }); };

            if (BenchClass::FP8 && !hasDtype("fp8"))
            {
                continue;
            }
            else if (BenchClass::NVFP4 && !hasDtype("fp4"))
            {
                continue;
            }
            else if (BenchClass::INT4 && !hasDtype("int4"))
            {
                continue;
            }
            else if (!BenchClass::INT4 && BenchClass::INT_QUANT && !hasDtype("int8"))
            {
                continue;
            }
            // else if (std::is_same_v<typename BenchClass::DataType, float> && !hasDtype("float")
            //     && !hasDtype("float32"))
            // {
            //     continue;
            // }
            else if (std::is_same_v<typename BenchClass::DataType, half> && !hasDtype("float16") && !hasDtype("half"))
            {
                continue;
            }
            else if (std::is_same_v<typename BenchClass::DataType, nv_bfloat16> && !hasDtype("bfloat16")
                && !hasDtype("bf16"))
            {
                continue;
            }
            else if (BenchClass::WFP4AFP8 && !hasDtype("wfp4afp8"))
            {
                continue;
            }
        }

        // Do this after filtering datatypes as tactics only make sense if we know the data type
        bool has_tactic_ids2 = false;
        std::vector<int> tactic_ids1{};
        std::vector<int> tactic_ids2{};
        if (run_config.contains("tactic_id1") || run_config.contains("tactic_id2"))
        {
            if (run_config.contains("tactic_id"))
            {
                throw std::invalid_argument("Cannot use tactic_id and tactic_idX");
            }
            has_tactic_ids2 = true;
            parseTacticToVectorID<BenchClass>(run_config["tactic_id1"], tactic_ids1);
            parseTacticToVectorID<BenchClass>(run_config["tactic_id2"], tactic_ids2);
        }
        else
        {
            parseTacticToVectorID<BenchClass>(run_config["tactic_id"], tactic_ids1);
            has_tactic_ids2 = false;
            tactic_ids2.resize(1); // Dummy value so we loop exactly once below
        }
        if (tactic_ids1.empty() || tactic_ids2.empty())
        {
            std::cerr << "Warning: Skipping benchmark, no valid tactic found" << std::endl;
            static bool printed = false;
            if (!printed)
            {
                printed = true;
                std::cerr << __PRETTY_FUNCTION__ << ": Valid Tactics are:\n";
                auto confs = listAllTactics<BenchClass>();
                for (auto c : confs)
                    std::cerr << c.toString();
            }
            continue;
        }

        auto get_or = [&](auto name, auto def)
        { return run_config.contains(name) ? run_config[name].template get<decltype(def)>() : def; };
        int tp_size = get_or("tp_size", 1);
        int ep_size = get_or("ep_size", 1);
        int world_rank = get_or("world_rank", 0);
        int bias = get_or("bias", 0);
        int do_final_scale = get_or("do_final_scale", 1); // Default to scales on
        int gemm_to_profile = get_or("gemm_to_profile", (int) GemmToProfile::LAYER);
        TLLM_CHECK_WITH_INFO(world_rank < tp_size * ep_size, "Rank is out of bounds of tp*ep");

        if (gemm_to_profile != (int) GemmToProfile::LAYER && routing_config != UNIFORM_ROUTING_CONFIG)
        {
            static bool info_printed = false;
            if (!info_printed && LOG_LEVEL >= INFO)
            {
                std::cerr << "Warning: GEMM profiling is experimental, results may be inaccurate" << std::endl;
                info_printed = true;
            }
            static bool printed = false;
            if (LOG_LEVEL >= ERROR && !printed)
            {
                std::cerr << "Warning: Profiling a specific GEMM will always use uniform random token distribution"
                          << std::endl;
                printed = true;
            }
            routing_config = UNIFORM_ROUTING_CONFIG;
            if (gemm_to_profile == (int) GemmToProfile::GEMM_1)
            {
                tactic_ids2 = {-1};
            }
            else if (gemm_to_profile == (int) GemmToProfile::GEMM_2)
            {
                if (!has_tactic_ids2)
                    tactic_ids2 = std::move(tactic_ids1);
                tactic_ids1 = {-1};
            }
        }

        auto get_range = [&](std::string name, int min = 1, int max = INT32_MAX)
        {
            auto val = run_config.at(name).get<int>();
            if (val < min || val > max)
            {
                throw std::invalid_argument(name + " must be a positive integer");
            }
            return val;
        };

        for (auto t1 : tactic_ids1)
        {
            // tactic_ids2 will have one dummy value if has_tactic_ids2 = false
            for (auto t2 : tactic_ids2)
            {
                if (!has_tactic_ids2)
                    t2 = t1;
                benchmark->Args({num_experts,                               //
                    get_range("k"),                                         //
                    get_range("hidden_size"),                               //
                    get_range("inter_size"),                                //
                    tp_size, ep_size, world_rank,                           //
                    get_range("num_tokens"),                                //
                    bias, do_final_scale,                                   //
                    get_range("act_fn", 0, (int) ActivationType::Identity), //
                    t1,                                                     //
                    t2,                                                     //
                    *routing_config, gemm_to_profile});
            }
        }
    }
}

template <class BenchClass>
void argGenHardcoded(benchmark::internal::Benchmark* benchmark)
{
    auto num_experts = {1, 8, 9, 64, 65, 257};            // {1, 8, 64, 65, 1024};
    auto top_k = {1, 2, 3, 16};                           // {1, 2, 3, 42};
    auto hidden_size = {4096};
    auto inter_size_mul = {4.f};                          // {7.f/2.f, 4.f};
    auto num_tokens = {2048};                             // {1, 20, 200, 2048};
    auto use_bias = {0};                                  // {0, 1};
    auto activation_type = {ActivationType::Gelu};        // {ActivationType::Relu, ActivationType::Gelu,
                                                          //  ActivationType::Silu, ActivationType::Geglu,
                                                          //  ActivationType::Swiglu};
    auto cutlass_tactic = {-1};                           // {0,..., listAllTactics<BenchClass>().size()};
    auto routing_config = {LOAD_BALANCED_ROUTING_CONFIG}; // {0, 1, 2};

    for (auto num_expert : num_experts)
        for (auto k : top_k)
            if (k <= num_expert)
                for (auto size : hidden_size)
                    for (auto inter_mul : inter_size_mul)
                    {
                        auto inter_size = static_cast<int>(size * inter_mul);
                        for (auto tokens : num_tokens)
                            for (auto bias : use_bias)
                                for (auto act : activation_type)
                                    for (auto tactic1 : cutlass_tactic)
                                        for (auto tactic2 : cutlass_tactic)
                                            for (auto routing : routing_config)
                                                benchmark->Args({num_expert, k, size, inter_size, 1, 1, 0, tokens,
                                                    bias, 1, (int) act, tactic1, tactic2, routing,
                                                    (int) GemmToProfile::LAYER});
                    }
}

template <class BenchClass>
void argGen(benchmark::internal::Benchmark* benchmark)
{
    if (LOG_LEVEL >= VERBOSE)
    {
        std::cout << "List of all tactics for dtype " << (int) BenchClass::toDTypeID() << ":\n";
        int i = 0;
        for (auto& t : listAllTactics<BenchClass>())
        {
            std::cout << "Tactic " << i << ":\n";
            std::cout << t.toString() << std::endl;
            i++;
        }
    }

    // Generic setup
    benchmark->UseManualTime();
    benchmark->ArgNames({"Num Experts", "K", "Hidden Size", "Inter Size", "TP Size", "EP Size", "World Rank",
        "Num Tokens", "Use Bias", "Use Final Scale", "Activation Function", "Tactic ID 1", "Tactic ID 2", "Routing ID",
        "Gemm To Profile"});

    if (workloadFile)
        argGenLoadFile<BenchClass>(benchmark);
    else
        argGenHardcoded<BenchClass>(benchmark);
}

// No one cares about float32
// BENCHMARK_BASIC(float, float, float)
BENCHMARK_BASIC(half, half, half)
using uint8 = uint8_t;
BENCHMARK_BASIC(half, uint8, half)
using cutlass::uint4b_t;
BENCHMARK_BASIC(half, uint4b_t, half)

#ifdef ENABLE_BF16
BENCHMARK_BASIC(nv_bfloat16, nv_bfloat16, nv_bfloat16)
#endif

#ifdef ENABLE_FP8
BENCHMARK_BASIC(SafeFP8, SafeFP8, half)
#endif

#ifdef ENABLE_FP4
BENCHMARK_BASIC(SafeFP4, SafeFP4, half)
BENCHMARK_BASIC(SafeFP8, SafeFP4, half)
#endif

void delayedRegisterBenchmark()
{
    BENCHMARK_BASIC_DO_REGISTER(half, half, half);
#ifdef ENABLE_FP8
    BENCHMARK_BASIC_DO_REGISTER(SafeFP8, SafeFP8, half);
#endif

    if (workloadFile)
    {
        // Extra ones we don't want for hardcoded runs
        // BENCHMARK_BASIC_DO_REGISTER(float, float, float);
        BENCHMARK_BASIC_DO_REGISTER(half, uint8, half);
        BENCHMARK_BASIC_DO_REGISTER(half, uint4b_t, half);
#ifdef ENABLE_BF16
        BENCHMARK_BASIC_DO_REGISTER(nv_bfloat16, nv_bfloat16, nv_bfloat16);
#endif
#ifdef ENABLE_FP4
        BENCHMARK_BASIC_DO_REGISTER(SafeFP4, SafeFP4, half);
        BENCHMARK_BASIC_DO_REGISTER(SafeFP8, SafeFP4, half);
#endif
    }
}

void doCleanup()
{
    bufferManager.reset();
    streamPtr.reset();
}
void help()
{
    std::cout << "**Disclaimer: This benchmark is intended for developers to help evaluate the impact of new "
                 "optimisations. This benchmark does not meet the same quality standards as other parts of TRT-LLM. "
                 "Please use with caution**\n\n";
    std::cout << "Usage: mixtureOfExpertsBackendBenchmark [--disable_cuda_graphs] [--input_file <file>] [benchmark "
                 "options]\n";
    std::cout << "--disable_cuda_graphs\t\tPrevent the benchmark from using cuda graphs. Useful for getting the "
                 "performance of specific kernels\n"
              << "--input_file\t\tA JSON file describing the benchmark configurations\n\n"
              << "File schema\n"
                 "[\n"
                 "  {\n"
                 "    \"num_experts\": int,\n"
                 "    \"k\": int,\n"
                 "    \"hidden_size\": int,\n"
                 "    \"inter_size\": int,\n"
                 "    \"tp_size\": int, (optional)\n"
                 "    \"ep_size\": int, (optional)\n"
                 "    \"world_rank\": int, (optional)\n"
                 "    \"num_tokens\": int,\n"
                 "    \"bias\": int, (optional)\n"
                 "    \"do_final_scale\": int, (optional)\n"
                 "    \"act_fn\": int,\n"
                 "    \"tactic_id\": tactic, (see below)\n"
                 "    \"tactic_id1\": tactic, (see below)\n"
                 "    \"tactic_id2\": tactic, (see below)\n"
                 "    \"dtypes\": [string, ...], (optional)\n"
                 "    \"routing_name\": string, (optional)\n"
                 "    \"selected_experts\": [int, ...] or string, (optional, length is a multiple of k)\n"
                 "    \"expert_distribution\": [float, ...] or string, (optional, length is num_experts)\n"
                 "    \"gemm_to_profile\": int, (experimental, optional, 1 = gemm1, 2 = gemm2, 3 = layer)\n"
                 "  },\n"
                 "  ...\n"
                 "]\n"
                 "Explanation:\n"
                 "- \"num_experts\" - The number of experts\n"
                 "- \"k\" - The top k\n"
                 "- \"hidden_size\" - The hidden size\n"
                 "- \"inter_size\" - The inter size\n"
                 "- \"tp_size\" - The TP size to use\n"
                 "- \"ep_size\" - The EP size to use\n"
                 "- \"world_rank\" - The world rank = tp_rank * ep_size + ep_rank\n"
                 "- \"num_tokens\" - The total number of tokens to benchmark\n"
                 "- \"bias\" - If bias should be used, 0 = no bias, 1 = bias\n"
                 "- \"do_final_scale\" - If final scales should be applied, 0 = no scale, 1 = scale\n"
                 "- \"act_fn\" - The activation function to use, 0 = identity, 1 = relu, 2 = gelu, 3 = silu, "
                 "4 = geglu, 5 = swiglu\n"
                 "- \"tactic_id, tactic_id1, tactic_id2\"\n"
                 "The config for the CUTLASS GEMM. tactic_id sets the same tactic for both GEMMs (except in auto "
                 "mode).\n"
                 "Use tactic_idX to set the tactic for the corresponding GEMM.\n"
                 "Valid tactics are:\n"
                 "  - An object:\n"
                 "    {\n"
                 "      \"is_tma_warp_specialized\": bool,\n"
                 "      \"tile_shape\": [int, int, int] or int,\n"
                 "      \"cluster_shape\": [int, int, int] or int, (required for sm90, type must be an int if "
                 "tile_shape is an int)\n"
                 "      \"warp_shape\": [int, int, int], (required for non-sm90 if tile_shape is an array)\n"
                 "      \"stages\": int, (required for non-sm90)\n"
                 "    },\n"
                 "  - An integer: corresponds to an index in the tactics array. WARNING this is not stable between "
                 "test configurations\n"
                 "  - An array: of integers or objects, forms a list of tactics to sweep\n"
                 "  - The string \"all\": This will sweep through all possible tactics\n"
                 "  - The string \"auto\": This runs a short benchmark to pick the fastest tactic before each "
                 "benchmark case. Useful for quick perf tests; prefer a full sweep and manually setting the tactic "
                 "for more accurate results\n"
                 "- \"dtypes\" - A list of dtypes to run this config through.\n"
                 "Allowed values are: fp8, fp4, wfp4afp8, int4, int8, half, bfloat16\n"
                 "If this argument is omitted all dtypes will be run. Note, not all tactics are supported for all "
                 "dtypes,\n"
                 "unsupported tactics will be skipped with a warning.\n"
                 "- \"routing_name\" - a name to help identify the routing pattern. This can be used by later "
                 "benchmarks to reuse the config\n"
                 "- \"selected_experts\" - a flat array of selected experts to define a new config,\n"
                 "or a string referencing the name of a previous config. Defaults to pre-defined config "
                 "\"balanced\",\n"
                 "which is short-hand for a perfectly balanced expert distribution\n"
                 "These define the routing values used as input to the moe backend, and are intended to allow "
                 "comparing different routing behaviours.\n"
                 "When defining an array, it must have `T*k` integer values. Each set of\n"
                 "`k` values defines the input for a single token. If `num_tokens` is greater than `T` it will "
                 "repeat from the beginning\n"
                 "- \"expert_distribution\" - instead of explicitly setting selected_experts, define a random "
                 "distribution that experts will be randomly sampled from.\n"
                 "There is also pre-defined config \"uniform\", which is short-hand for a random uniform "
                 "distribution\n"
                 "- \"gemm_to_profile\" - the gemm to profile, 1 = gemm1, 2 = gemm2, 3 = full layer (default layer). "
                 "If a specific GEMM is profiled, it will always use uniform random token distribution\n"
                 "\n";
    std::cout << "benchmark options:\n";
    benchmark::PrintDefaultHelp();
}
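// Example invocations (the workload file name is illustrative):
//
//   ./mixtureOfExpertsBackendBenchmark                          # hardcoded sweep
//   ./mixtureOfExpertsBackendBenchmark --input_file moe.json    # file-driven sweep
//   ./mixtureOfExpertsBackendBenchmark --disable_cuda_graphs --benchmark_filter='Basic_half_half.*'
//
// Any arguments not consumed by this launcher are forwarded to google-benchmark.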
If a " "specific GEMM is profiled, it will always use uniform random token distribution\n" "\n"; std::cout << "benchmark options:\n"; benchmark::PrintDefaultHelp(); } void gbenchCustomHelp() { help(); // google-benchmark calls exit() so we need to cleanup manually doCleanup(); } int parseArgsAndRunBench(int argc, char** argv) { try { int shift = 0; for (int i = 1; i < argc; i++) { argv[i - shift] = argv[i]; if (strcmp("--input_file", argv[i]) == 0) { i += 1; if (i == argc) { std::cerr << "Missing file name for input_file\n"; return -1; } workloadFile = argv[i]; if (workloadFile[0] == '-') { std::cerr << "Workload file " << workloadFile << " not a valid file name\n"; return -2; } shift += 2; } else if (strcmp("--disable_cuda_graphs", argv[i]) == 0) { useCudaGraph = false; shift++; } else if (strcmp("--help", argv[i]) == 0 || strcmp("-h", argv[i]) == 0) { help(); return 0; } } argc -= shift; // Delay after we know if the user passed a config file delayedRegisterBenchmark(); benchmark::Initialize(&argc, argv, &gbenchCustomHelp); if (argc > 1) { help(); std::cout << std::flush; // Force flush // Print the error second, so it's easy to see std::cerr << "\nUnrecognised argument: " << argv[1] << std::endl; return -4; } benchmark::RunSpecifiedBenchmarks(); benchmark::Shutdown(); return 0; } catch (std::exception const& e) { std::cerr << "Exiting benchmarks with exception: " << e.what() << std::endl; return -3; } } int main(int argc, char** argv) { deviceCount = getDeviceCount(); if (deviceCount < 0) return 0; streamPtr = std::make_shared(); bufferManager = std::make_unique(streamPtr); int res = -1; try { res = parseArgsAndRunBench(argc, argv); } catch (std::exception const& e) { std::cout << "Benchmark exited with unhandled exception: " << e.what() << std::endl; } doCleanup(); return res; }