/* * SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: Apache-2.0 * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ #include #include #include #include #include #include #include "tensorrt_llm/common/logger.h" #include "tensorrt_llm/executor/executor.h" #include "tensorrt_llm/plugins/api/tllmPlugin.h" #include namespace tlc = tensorrt_llm::common; namespace tle = tensorrt_llm::executor; namespace fs = std::filesystem; struct RuntimeOptions { std::string trtEnginePath; tle::SizeType32 numSysPrompts; tle::SizeType32 sysPromptTokens; tle::SizeType32 contextTokens; tle::SizeType32 maxTokensMean; tle::SizeType32 maxTokensStddev; tle::SizeType32 numRequests; size_t hostCacheSize; size_t maxTokensInPagedKvCache; }; struct KVCacheBlock { KVCacheBlock(size_t hash, int cacheLevel, int priority, std::optional loraId = std::nullopt, std::shared_ptr prevBlock = nullptr); size_t hash; int cacheLevel; int priority; std::optional loraId; std::shared_ptr prevBlock; std::unordered_map> nextBlocks; }; class RadixTree { public: explicit RadixTree(tle::Executor& executor); // Check the executor for new events. void pollEvents(); private: std::shared_ptr mCacheEventManager; // The root block of the radix tree std::shared_ptr root; // A table mapping block hashes to their pointers std::unordered_map> blockTable; // Event counter size_t eventCounter; }; // Utility function to parse input arguments RuntimeOptions parseArgs(int argc, char* argv[]); // Create a tle::Request tle::Request makeRequest(int sysPromptTokens, int contextTokens, std::uniform_int_distribution sysPromptSelector, std::normal_distribution maxNumTokensSelector); std::default_random_engine gen; int main(int argc, char* argv[]) { // Register the TRT-LLM plugins initTrtLlmPlugins(); auto runtimeOpts = parseArgs(argc, argv); // Create the executor for this engine auto executorConfig = tle::ExecutorConfig(1); // Beam width 1 is required for cache block reuse auto kvCacheConfig = tle::KvCacheConfig(true, runtimeOpts.maxTokensInPagedKvCache ? std::optional(runtimeOpts.maxTokensInPagedKvCache) : std::nullopt); // Enable cache block reuse kvCacheConfig.setHostCacheSize(runtimeOpts.hostCacheSize); kvCacheConfig.setEventBufferMaxSize(32768); executorConfig.setKvCacheConfig(kvCacheConfig); auto executor = tle::Executor(runtimeOpts.trtEnginePath, tle::ModelType::kDECODER_ONLY, executorConfig); auto radixTree = RadixTree(executor); auto activeRequests = runtimeOpts.numRequests; std::uniform_int_distribution sysPromptSelector( 1, runtimeOpts.numSysPrompts); // Select a system prompt between 1 and `runtimeOpts.numSysPrompts` std::normal_distribution maxNumTokensSelector(runtimeOpts.maxTokensMean, runtimeOpts.maxTokensStddev); // Create and enqueue the requests for (int i = 0; i < runtimeOpts.numRequests; i++) { std::ignore = executor.enqueueRequest(makeRequest( runtimeOpts.sysPromptTokens, runtimeOpts.contextTokens, sysPromptSelector, maxNumTokensSelector)); } while (activeRequests > 0) { auto responses = executor.awaitResponses(std::chrono::milliseconds(20)); for (auto const& response : responses) { if (response.getResult().isFinal) activeRequests--; } // Only call pollEvents once every 20ms. Events are only added to the queue once per iteration, so no need to // poll faster than this. radixTree.pollEvents(); } return 0; } RuntimeOptions parseArgs(int argc, char* argv[]) { RuntimeOptions runtimeOpts; cxxopts::Options options(argv[0], "Example that demonstrates how to use the ExecutorKVCacheManager API"); options.add_options()("h,help", "Print usage"); options.add_options()("engine_dir", "Directory that store the engines.", cxxopts::value()); options.add_options()("num_sys_prompts", "Amount of unique simulated system prompts to use", cxxopts::value()->default_value("10")); options.add_options()( "sys_prompt_tokens", "Size of the simulated system prompts", cxxopts::value()->default_value("256")); options.add_options()("context_tokens", "Amount of varying context tokens coming after the system prompts", cxxopts::value()->default_value("128")); options.add_options()( "max_tokens_mean", "Mean number of max output tokens", cxxopts::value()->default_value("128")); options.add_options()( "max_tokens_stddev", "Standard deviation of max output tokens", cxxopts::value()->default_value("32")); options.add_options()( "num_requests", "Amount of requests to send to the engine", cxxopts::value()->default_value("100")); options.add_options()("host_cache_size", "Size of the KV Cache in host memory in bytes", cxxopts::value()->default_value("0")); options.add_options()("max_tokens_in_paged_kv_cache", "Amount of tokens in the kv cache", cxxopts::value()->default_value("0")); auto parsedOptions = options.parse(argc, argv); // Argument: help if (parsedOptions.count("help")) { TLLM_LOG_ERROR(options.help()); exit(0); } // Argument: Engine directory if (!parsedOptions.count("engine_dir")) { TLLM_LOG_ERROR(options.help()); TLLM_LOG_ERROR("Please specify engine directory."); exit(1); } runtimeOpts.trtEnginePath = parsedOptions["engine_dir"].as(); if (!fs::exists(runtimeOpts.trtEnginePath) || !fs::is_directory(runtimeOpts.trtEnginePath)) { TLLM_LOG_ERROR("Engine directory doesn't exist."); exit(1); } runtimeOpts.numSysPrompts = parsedOptions["num_sys_prompts"].as(); runtimeOpts.sysPromptTokens = parsedOptions["sys_prompt_tokens"].as(); runtimeOpts.contextTokens = parsedOptions["context_tokens"].as(); runtimeOpts.maxTokensMean = parsedOptions["max_tokens_mean"].as(); runtimeOpts.maxTokensStddev = parsedOptions["max_tokens_stddev"].as(); runtimeOpts.numRequests = parsedOptions["num_requests"].as(); runtimeOpts.hostCacheSize = parsedOptions["host_cache_size"].as(); runtimeOpts.maxTokensInPagedKvCache = parsedOptions["max_tokens_in_paged_kv_cache"].as(); return runtimeOpts; } KVCacheBlock::KVCacheBlock( size_t hash, int cacheLevel, int priority, std::optional loraId, std::shared_ptr prevBlock) : hash{hash} , cacheLevel{cacheLevel} , priority{priority} , loraId{loraId} , prevBlock{prevBlock} , nextBlocks{} { } RadixTree::RadixTree(tle::Executor& executor) : mCacheEventManager(*executor.getKVCacheEventManager()) , eventCounter{1} { // Use id=-1 for the root block. Doesn't matter what exact id is used, just that it is unique. root = std::make_shared(-1, -1, -1); blockTable[-1] = root; // Wait for the `CREATED` event to be emitted. while (true) { auto events = mCacheEventManager->getLatestEvents(); if (events.size() == 1) { auto const& eventData = std::get(events.front().data); TLLM_LOG_INFO("Event ID %d: KV Cache Manager initialized with blocks per level of: %s", events.front().eventId, tlc::vec2str(eventData.numBlocksPerCacheLevel).c_str()); break; } } }; void RadixTree::pollEvents() { auto events = mCacheEventManager->getLatestEvents(std::chrono::milliseconds(20)); for (tle::KVCacheEvent const& event : events) { TLLM_CHECK(event.eventId == eventCounter++); if (std::holds_alternative(event.data)) { // Blocks have been stored into the radix tree auto const& eventData = std::get(event.data); auto prevBlock = blockTable[eventData.parentHash.value_or(-1)]; // This block should be in the tree TLLM_CHECK(blockTable.find(prevBlock->hash) != blockTable.end()); for (auto& block : eventData.blocks) { TLLM_LOG_INFO("Event ID %d: Block %04x was inserted into the radix tree with parent %04x.", event.eventId, block.blockHash, prevBlock->hash); // This block shouldn't already exist in the tree, and should have tokens associated with it TLLM_CHECK(blockTable.find(block.blockHash) == blockTable.end()); TLLM_CHECK(block.tokens.size() > 0); auto thisBlock = std::make_shared( block.blockHash, block.cacheLevel, block.priority, block.loraId, prevBlock); blockTable[block.blockHash] = thisBlock; // Link the parent to the new block prevBlock->nextBlocks[block.blockHash] = thisBlock; prevBlock = thisBlock; } } else if (std::holds_alternative(event.data)) { auto const& eventData = std::get(event.data); for (auto const& hash : eventData.blockHashes) { TLLM_LOG_INFO("Event ID %d: Block %04x was removed from the radix tree.", event.eventId, hash); // This block should exist in the tree TLLM_CHECK(blockTable.find(hash) != blockTable.end()); auto& block = blockTable[hash]; // Check that the block has no children, and that the parent has the block listed as a child TLLM_CHECK(block->nextBlocks.size() == 0); TLLM_CHECK(block->prevBlock->nextBlocks.find(block->hash) != block->prevBlock->nextBlocks.end()); // Remove the block from it's parent, and remove the entry in the block table block->prevBlock->nextBlocks.erase(block->hash); blockTable.erase(hash); } } else if (std::holds_alternative(event.data)) { auto const& eventData = std::get(event.data); if (eventData.priority.has_value()) { // The block priority was updated TLLM_LOG_INFO("Event ID %d: Block %04x priority was changed from %d to %d", event.eventId, eventData.blockHash, eventData.priority->oldValue, eventData.priority->newValue); TLLM_CHECK(blockTable[eventData.blockHash]->priority == eventData.priority->oldValue); blockTable[eventData.blockHash]->priority = eventData.priority->newValue; } if (eventData.cacheLevel.has_value()) { // The block cache level was updated TLLM_LOG_INFO("Event ID %d: Block %04x cache level was changed from %d to %d", event.eventId, eventData.blockHash, eventData.cacheLevel->oldValue, eventData.cacheLevel->newValue); TLLM_CHECK(blockTable[eventData.blockHash]->cacheLevel == eventData.cacheLevel->oldValue); blockTable[eventData.blockHash]->cacheLevel = eventData.cacheLevel->newValue; } } else { TLLM_LOG_ERROR("Unsupported event type. This shouldn't happen!"); } } } tle::Request makeRequest(int sysPromptTokens, int contextTokens, std::uniform_int_distribution sysPromptSelector, std::normal_distribution maxNumTokensSelector) { int sysPromptVersion = sysPromptSelector(gen); tle::VecTokens inputTokens; // Add `sysPromptTokens` tokens. Add the version to the token ids to create a unique system prompt for (int i = 0; i < sysPromptTokens; i++) { inputTokens.emplace_back(sysPromptVersion + i); } // Add random context tokens for (int i = 0; i < contextTokens; i++) { inputTokens.emplace_back(rand() % 1000); } return tle::Request(inputTokens, maxNumTokensSelector(gen)); }