/*
 * Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "tensorrt_llm/runtime/gptJsonConfig.h"

#include "tensorrt_llm/common/assert.h"
#include "tensorrt_llm/common/stringUtils.h"

#include <fstream>
#include <nlohmann/json.hpp>
#include <string>

using namespace tensorrt_llm::runtime;
namespace tc = tensorrt_llm::common;

namespace
{

using Json = typename nlohmann::json::basic_json;

template <typename FieldType>
FieldType parseJsonFieldOr(Json const& json, std::string_view name, FieldType defaultValue)
{
    auto value = defaultValue;
    try
    {
        value = json.at(name).template get<FieldType>();
    }
    catch (nlohmann::json::out_of_range&)
    {
        // Optional field is absent: keep the default value.
    }
    return value;
}

template <typename InputType>
GptJsonConfig parseJson(InputType&& i)
{
    auto constexpr allowExceptions = true;
    auto constexpr ignoreComments = true;
    auto json = nlohmann::json::parse(i, nullptr, allowExceptions, ignoreComments);

    auto const& builderConfig = json.at("builder_config");
    auto const name = builderConfig.at("name").template get<std::string>();
    auto const precision = builderConfig.at("precision").template get<std::string>();
    auto const worldSize = builderConfig.at("tensor_parallel").template get<SizeType>();
    // The builder stores full model dimensions; heads and hidden size are
    // sharded across the tensor-parallel ranks, so divide by the world size.
    auto const numHeads = builderConfig.at("num_heads").template get<SizeType>() / worldSize;
    auto const hiddenSize = builderConfig.at("hidden_size").template get<SizeType>() / worldSize;
    auto const vocabSize = builderConfig.at("vocab_size").template get<SizeType>();
    auto const numLayers = builderConfig.at("num_layers").template get<SizeType>();

    auto dataType = nvinfer1::DataType::kFLOAT;
    if (!precision.compare("float32"))
        dataType = nvinfer1::DataType::kFLOAT;
    else if (!precision.compare("float16"))
        dataType = nvinfer1::DataType::kHALF;
    else if (!precision.compare("bfloat16"))
        dataType = nvinfer1::DataType::kBF16;
    else
        TLLM_CHECK_WITH_INFO(false, tc::fmtstr("Model data type '%s' not supported", precision.c_str()));

    auto const pagedKvCache = parseJsonFieldOr(builderConfig, "paged_kv_cache", false);
    auto const tokensPerBlock = parseJsonFieldOr(builderConfig, "tokens_per_block", 0);
    auto const quantMode = tc::QuantMode(parseJsonFieldOr(builderConfig, "quant_mode", tc::QuantMode::none().value()));
    auto const numKvHeads = parseJsonFieldOr(builderConfig, "num_kv_heads", numHeads * worldSize) / worldSize;

    auto const& pluginConfig = json.at("plugin_config");
    // The attention plugin field is either `false` (disabled) or a string
    // naming the plugin dtype, so any non-boolean value means "enabled".
    auto const& gptAttentionPlugin = pluginConfig.at("gpt_attention_plugin");
    auto const useGptAttentionPlugin = !gptAttentionPlugin.is_boolean() || gptAttentionPlugin.template get<bool>();
    auto const removeInputPadding = pluginConfig.at("remove_input_padding").template get<bool>();
    auto const inflightBatching = pluginConfig.at("in_flight_batching").template get<bool>();

    auto modelConfig = GptModelConfig{vocabSize, numLayers, numHeads, hiddenSize, dataType};
    modelConfig.useGptAttentionPlugin(useGptAttentionPlugin);
    modelConfig.usePackedInput(removeInputPadding);
    modelConfig.usePagedKvCache(pagedKvCache);
    modelConfig.useInflightBatching(inflightBatching);
    modelConfig.setTokensPerBlock(tokensPerBlock);
    modelConfig.setQuantMode(quantMode);
    modelConfig.setNbKvHeads(numKvHeads);

    return GptJsonConfig{name, precision, worldSize, modelConfig};
}

} // namespace

std::string GptJsonConfig::engineFilename(WorldConfig const& worldConfig, std::string const& model) const
{
    TLLM_CHECK_WITH_INFO(getWorldSize() == worldConfig.getSize(), "world size mismatch");
    return model + "_" + getPrecision() + "_tp" + std::to_string(worldConfig.getSize()) + "_rank"
        + std::to_string(worldConfig.getRank()) + ".engine";
}

GptJsonConfig GptJsonConfig::parse(std::string const& json)
{
    return parseJson(json);
}

GptJsonConfig GptJsonConfig::parse(std::istream& json)
{
    return parseJson(json);
}

GptJsonConfig GptJsonConfig::parse(std::filesystem::path const& path)
{
    TLLM_CHECK_WITH_INFO(std::filesystem::exists(path), std::string("File does not exist: ") + path.string());
    std::ifstream json(path);
    return parse(json);
}
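
// ---------------------------------------------------------------------------
// Usage sketch (illustrative only; not compiled into this translation unit).
// It shows how a caller might parse a builder-generated config.json and
// derive the per-rank engine filename. The engine directory, model name, and
// the WorldConfig construction below are hypothetical placeholders; see
// worldConfig.h for the actual interface.
//
//   auto const jsonConfig
//       = GptJsonConfig::parse(std::filesystem::path{"/path/to/engines/config.json"});
//   // Assume a 2-way tensor-parallel build and that this process is rank 0.
//   auto const worldConfig = WorldConfig{/*tensorParallelism=*/2, /*rank=*/0};
//   // For a float16 build this yields "my_model_float16_tp2_rank0.engine".
//   auto const engine = jsonConfig.engineFilename(worldConfig, "my_model");
// ---------------------------------------------------------------------------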