/*
 * SPDX-FileCopyrightText: Copyright (c) 2022-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: Apache-2.0
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include <algorithm>
#include <chrono>
#include <cstdlib>
#include <filesystem>
#include <fstream>
#include <optional>
#include <sstream>
#include <string>
#include <unordered_map>
#include <vector>

#include "tensorrt_llm/common/assert.h"
#include "tensorrt_llm/common/logger.h"
#include "tensorrt_llm/executor/executor.h"
#include "tensorrt_llm/executor/types.h"
#include "tensorrt_llm/plugins/api/tllmPlugin.h"
#include "tensorrt_llm/runtime/utils/mpiUtils.h"

#include <cuda_runtime_api.h>
#include <cxxopts.hpp>

namespace tle = tensorrt_llm::executor;
namespace fs = std::filesystem;

struct RuntimeOptions
{
    std::string trtContextEnginePath;
    std::string trtGenerationEnginePath;
    std::string inputTokensCsvFile;
    std::string outputTokensCsvFile;

    bool streaming;
    bool excludeInputFromOutput;
    int contextRankSize;
    int generationRankSize;
    tle::SizeType32 maxNewTokens;
    tle::SizeType32 beamWidth;
    std::optional<tle::SizeType32> numReturnSequences;
    tle::SizeType32 timeoutMs;
};

RuntimeOptions parseArgs(int argc, char* argv[]);

// Function that enqueues requests into the context executor and the generation executor
std::unordered_map<tle::IdType, tle::IdType> enqueueRequests(
    RuntimeOptions const& runtimeOpts, tle::Executor& contextExecutor, tle::Executor& generationExecutor);

// Function that waits for generation responses and stores the output tokens
std::unordered_map<tle::IdType, tle::BeamTokens> waitForGenResponses(RuntimeOptions const& runtimeOpts,
    std::unordered_map<tle::IdType, tle::IdType> const& genRequestIdToContextRequestId,
    tle::Executor& generationExecutor);

// Utility function to read input tokens from a csv file
std::vector<tle::VecTokens> readInputTokens(std::string const& path);

// Utility function to write output tokens to a csv file
void writeOutputTokens(std::string const& path,
    std::unordered_map<tle::IdType, tle::IdType>& genRequestIdToContextRequestId,
    std::unordered_map<tle::IdType, tle::BeamTokens> const& outputTokens, tle::SizeType32 beamWidth);
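// A sketch of how this example might be launched (the binary name and engine paths are
// illustrative, not taken from this source). Rank 0 acts as the orchestrator, so the MPI
// world must contain at least contextRankSize + generationRankSize + 1 processes; with the
// default context_rank_size = 1 and generation_rank_size = 1 that means 3 ranks:
//
//   mpirun -n 3 ./executorExampleDisaggregated \
//       --context_engine_dir /path/to/context_engine \
//       --generation_engine_dir /path/to/generation_engine \
//       --input_tokens_csv_file inputTokens.csv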
int main(int argc, char* argv[])
{
    // Register the TRT-LLM plugins
    initTrtLlmPlugins();

    auto runtimeOpts = parseArgs(argc, argv);
    TLLM_CHECK_WITH_INFO(runtimeOpts.beamWidth == 1, "Only beamWidth == 1 is supported");
    TLLM_CHECK_WITH_INFO(
        !runtimeOpts.numReturnSequences.has_value() || runtimeOpts.numReturnSequences.value() == 1,
        "Only numReturnSequences == 1 is supported");

    // Create one executor config per engine
    auto contextExecutorConfig = tle::ExecutorConfig(runtimeOpts.beamWidth);
    auto generationExecutorConfig = tle::ExecutorConfig(runtimeOpts.beamWidth);

    // Rank 0 is the orchestrator; it does not run a model itself but drives both executors
    bool isOrchestrator = (tensorrt_llm::mpi::MpiComm::world().getRank() == 0);
    auto orchestratorConfig = tle::OrchestratorConfig(isOrchestrator, "", nullptr, false);

    int contextRankSize = runtimeOpts.contextRankSize;
    int generationRankSize = runtimeOpts.generationRankSize;
    TLLM_CHECK_WITH_INFO(tensorrt_llm::mpi::MpiComm::world().getSize() >= contextRankSize + generationRankSize + 1,
        "MPI should launch at least contextRankSize + generationRankSize + 1 = %d processes",
        contextRankSize + generationRankSize + 1);

    int deviceCount = -1;
    TLLM_CHECK(cudaGetDeviceCount(&deviceCount) == cudaSuccess);

    std::vector<tle::SizeType32> contextRankIds(contextRankSize);
    std::vector<tle::SizeType32> contextDeviceIds(contextRankSize);
    std::vector<tle::SizeType32> generationRankIds(generationRankSize);
    std::vector<tle::SizeType32> generationDeviceIds(generationRankSize);

    // Context ranks directly follow the orchestrator rank; devices are assigned round-robin
    for (int i = 0; i < contextRankSize; i++)
    {
        contextRankIds[i] = i + 1;
        contextDeviceIds[i] = i % deviceCount;
        TLLM_LOG_INFO("context rank %d on device %d", contextRankIds[i], contextDeviceIds[i]);
    }
    tle::ParallelConfig contextParallelConfig{tensorrt_llm::executor::CommunicationType::kMPI,
        tensorrt_llm::executor::CommunicationMode::kORCHESTRATOR, contextDeviceIds, contextRankIds,
        orchestratorConfig};

    // Generation ranks follow the context ranks
    for (int i = 0; i < generationRankSize; i++)
    {
        generationRankIds[i] = i + 1 + contextRankSize;
        generationDeviceIds[i] = (i + contextRankSize) % deviceCount;
        TLLM_LOG_INFO("generation rank %d on device %d", generationRankIds[i], generationDeviceIds[i]);
    }
    tle::ParallelConfig generationParallelConfig{tensorrt_llm::executor::CommunicationType::kMPI,
        tensorrt_llm::executor::CommunicationMode::kORCHESTRATOR, generationDeviceIds, generationRankIds,
        orchestratorConfig};

    contextExecutorConfig.setParallelConfig(contextParallelConfig);
    generationExecutorConfig.setParallelConfig(generationParallelConfig);

    auto contextExecutor
        = tle::Executor(runtimeOpts.trtContextEnginePath, tle::ModelType::kDECODER_ONLY, contextExecutorConfig);
    auto generationExecutor
        = tle::Executor(runtimeOpts.trtGenerationEnginePath, tle::ModelType::kDECODER_ONLY, generationExecutorConfig);

    tensorrt_llm::mpi::MpiComm::world().barrier();

    // Only the orchestrator rank enqueues requests and collects responses
    if (tensorrt_llm::mpi::MpiComm::world().getRank() == 0)
    {
        TLLM_CHECK_WITH_INFO(contextExecutor.canEnqueueRequests(), "contextExecutor can't enqueue requests");
        TLLM_CHECK_WITH_INFO(generationExecutor.canEnqueueRequests(), "generationExecutor can't enqueue requests");

        auto genRequestIdsToContextRequestIds = enqueueRequests(runtimeOpts, contextExecutor, generationExecutor);
        auto outputTokens = waitForGenResponses(runtimeOpts, genRequestIdsToContextRequestIds, generationExecutor);

        TLLM_LOG_INFO("Writing output tokens to %s", runtimeOpts.outputTokensCsvFile.c_str());
        writeOutputTokens(
            runtimeOpts.outputTokensCsvFile, genRequestIdsToContextRequestIds, outputTokens, runtimeOpts.beamWidth);
    }

    tensorrt_llm::mpi::MpiComm::world().barrier();
    TLLM_LOG_INFO("Exiting.");
    return 0;
}
For streaming = true, " "output tokens are not included.", cxxopts::value()->default_value("false")); options.add_options()( "max_new_tokens", "The maximum number of tokens to generate", cxxopts::value()->default_value("10")); options.add_options()( "input_tokens_csv_file", "Path to a csv file that contains input tokens", cxxopts::value()); options.add_options()("output_tokens_csv_file", "Path to a csv file that will contain the output tokens", cxxopts::value()->default_value("outputTokens.csv")); options.add_options()("timeout_ms", "The maximum time to wait for all responses, in milliseconds.", cxxopts::value()->default_value("10000")); auto parsedOptions = options.parse(argc, argv); // Argument: help if (parsedOptions.count("help")) { TLLM_LOG_ERROR(options.help()); exit(0); } runtimeOpts.trtContextEnginePath = parsedOptions["context_engine_dir"].as(); if (!fs::exists(runtimeOpts.trtContextEnginePath) || !fs::is_directory(runtimeOpts.trtContextEnginePath)) { TLLM_LOG_ERROR("Context engine directory doesn't exist."); exit(1); } runtimeOpts.trtGenerationEnginePath = parsedOptions["generation_engine_dir"].as(); if (!fs::exists(runtimeOpts.trtGenerationEnginePath) || !fs::is_directory(runtimeOpts.trtGenerationEnginePath)) { TLLM_LOG_ERROR("Generation engine directory doesn't exist."); exit(1); } // Argument: Input tokens csv file if (!parsedOptions.count("input_tokens_csv_file")) { TLLM_LOG_ERROR(options.help()); TLLM_LOG_ERROR("Please specify input_tokens_csv_file"); exit(1); } runtimeOpts.inputTokensCsvFile = parsedOptions["input_tokens_csv_file"].as(); runtimeOpts.streaming = parsedOptions["streaming"].as(); runtimeOpts.excludeInputFromOutput = parsedOptions["exclude_input_from_output"].as(); runtimeOpts.maxNewTokens = parsedOptions["max_new_tokens"].as(); runtimeOpts.beamWidth = parsedOptions["beam_width"].as(); runtimeOpts.contextRankSize = parsedOptions["context_rank_size"].as(); runtimeOpts.generationRankSize = parsedOptions["generation_rank_size"].as(); if (parsedOptions.count("num_return_sequences") > 0) { runtimeOpts.numReturnSequences = parsedOptions["num_return_sequences"].as>(); } runtimeOpts.timeoutMs = parsedOptions["timeout_ms"].as(); runtimeOpts.outputTokensCsvFile = parsedOptions["output_tokens_csv_file"].as(); return runtimeOpts; } std::unordered_map enqueueRequests( RuntimeOptions const& runtimeOpts, tle::Executor& contextExecutor, tle::Executor& generationExecutor) { tle::OutputConfig outputConfig; outputConfig.excludeInputFromOutput = runtimeOpts.excludeInputFromOutput; tle::SamplingConfig samplingConfig(runtimeOpts.beamWidth); std::unordered_map genRequestIdToContextRequestId; if (runtimeOpts.numReturnSequences && runtimeOpts.beamWidth == 1) { samplingConfig.setTopP(0.9); } samplingConfig.setNumReturnSequences(runtimeOpts.numReturnSequences); TLLM_LOG_INFO("Reading input tokens from %s", runtimeOpts.inputTokensCsvFile.c_str()); auto inputTokens = readInputTokens(runtimeOpts.inputTokensCsvFile); TLLM_LOG_INFO("Number of requests: %d", inputTokens.size()); std::vector requests; for (auto& tokens : inputTokens) { TLLM_LOG_INFO("Creating request with %d input tokens", tokens.size()); requests.emplace_back( std::move(tokens), runtimeOpts.maxNewTokens, runtimeOpts.streaming, samplingConfig, outputConfig); requests.back().setRequestType(tensorrt_llm::executor::RequestType::REQUEST_TYPE_CONTEXT_ONLY); } auto contextRequestIds = contextExecutor.enqueueRequests(requests); for (size_t i = 0; i < requests.size(); i++) { TLLM_LOG_INFO("waiting response for Context request id: 
%lu,", contextRequestIds[i]); auto response = contextExecutor.awaitResponses(contextRequestIds[i]); TLLM_LOG_INFO("response received for Context request id: %lu", contextRequestIds[i]); TLLM_CHECK(response.size() == 1); TLLM_CHECK(response.back().getResult().contextPhaseParams.has_value()); requests.at(i).setContextPhaseParams(response.back().getResult().contextPhaseParams.value()); requests.at(i).setRequestType(tensorrt_llm::executor::RequestType::REQUEST_TYPE_GENERATION_ONLY); auto genRequestId = generationExecutor.enqueueRequest(requests.at(i)); genRequestIdToContextRequestId[genRequestId] = contextRequestIds[i]; TLLM_LOG_INFO("enqueuing generation request for Context request id: %lu, generation request id: %lu", contextRequestIds[i], genRequestId); } return genRequestIdToContextRequestId; } std::unordered_map waitForGenResponses(RuntimeOptions const& runtimeOpts, std::unordered_map const& genRequestIdToContextRequestId, tle::Executor& generationExecutor) { // Map that will be used to store output tokens for requests std::unordered_map outputTokens; std::vector contextRequestIds{}; std::vector genRequestIds{}; for (auto const& [key, value] : genRequestIdToContextRequestId) { genRequestIds.push_back(key); contextRequestIds.push_back(value); } for (auto contextRequestId : contextRequestIds) { outputTokens[contextRequestId] = tle::BeamTokens(runtimeOpts.beamWidth); } tle::SizeType32 numFinished{0}; tle::SizeType32 iter{0}; // Get the new tokens for each request while (numFinished < static_cast(genRequestIds.size()) && iter < runtimeOpts.timeoutMs) { std::chrono::milliseconds waitTime(1); // Wait for any response auto responses = generationExecutor.awaitResponses(waitTime); auto insertResponseTokens = [&outputTokens, &genRequestIdToContextRequestId](tle::IdType genRequestId, tle::SizeType32 seqIdx, tle::VecTokens const& respTokens) { TLLM_LOG_INFO("Got %d tokens for seqIdx %d for genRequestId %d,contextRequestId %d", respTokens.size(), seqIdx, genRequestId, genRequestIdToContextRequestId.at(genRequestId)); // Store the output tokens for that request id auto& outTokens = outputTokens.at(genRequestIdToContextRequestId.at(genRequestId)).at(seqIdx); outTokens.insert(outTokens.end(), std::make_move_iterator(respTokens.begin()), std::make_move_iterator(respTokens.end())); }; // Loop over the responses for (auto const& response : responses) { auto genRequestId = response.getRequestId(); if (!response.hasError()) { auto result = response.getResult(); numFinished += result.isFinal; if (runtimeOpts.beamWidth > 1) { for (tle::SizeType32 beam = 0; beam < runtimeOpts.beamWidth; ++beam) { insertResponseTokens(genRequestId, beam, result.outputTokenIds.at(beam)); } } else { insertResponseTokens(genRequestId, result.sequenceIndex, result.outputTokenIds.at(0)); } if (result.isFinal) { TLLM_LOG_INFO("genRequest id %lu ,contextRequestId %lu is completed.", genRequestId, genRequestIdToContextRequestId.at(genRequestId)); } } else { // Allow response with error only if awaitResponse processed a terminated request id std::string err = "genReqId " + std::to_string(response.getRequestId()) + " has already been processed and was terminated."; if (response.getErrorMsg() != err) { TLLM_THROW("GenRequest id %lu encountered error: %s", genRequestId, response.getErrorMsg().c_str()); } } } ++iter; } if (iter == runtimeOpts.timeoutMs) { TLLM_THROW("Timeout exceeded."); } return outputTokens; } std::vector readInputTokens(std::string const& path) { std::vector data; std::ifstream file(path); if (!file.is_open()) { auto const 
err = std::string{"Failed to open file: "} + path; TLLM_LOG_ERROR(err); TLLM_THROW(err); } std::string line; while (std::getline(file, line)) { std::vector row; std::stringstream ss(line); std::string token; while (std::getline(ss, token, ',')) { try { row.push_back(std::stoi(token)); } catch (std::invalid_argument const& e) { TLLM_LOG_ERROR("Invalid argument: %s", e.what()); } catch (std::out_of_range const& e) { TLLM_LOG_ERROR("Out of range: %s", e.what()); } } data.push_back(row); } file.close(); return data; } void writeOutputTokens(std::string const& path, std::unordered_map& genRequestIdToContextRequestId, std::unordered_map const& outputTokens, tle::SizeType32 beamWidth) { std::ofstream file(path); if (!file.is_open()) { TLLM_LOG_ERROR("Failed to open file %s", path.c_str()); return; } std::vector requestIds; for (auto const& [key, value] : genRequestIdToContextRequestId) { requestIds.push_back(value); } std::sort(requestIds.begin(), requestIds.end()); for (auto requestId : requestIds) { auto const& outTokens = outputTokens.at(requestId); for (tle::SizeType32 beam = 0; beam < beamWidth; ++beam) { auto const& beamTokens = outTokens.at(beam); for (size_t i = 0; i < beamTokens.size(); ++i) { file << beamTokens[i]; if (i < beamTokens.size() - 1) { file << ", "; } } file << "\n"; } } file.close(); }