/*
 * SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: Apache-2.0
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "tensorrt_llm/executor/disaggServerUtil.h"
#include "tensorrt_llm/common/utils.h"
#include "tensorrt_llm/executor/types.h"
#include "tensorrt_llm/runtime/utils/mpiUtils.h"

#include <atomic>
#include <chrono>
#include <condition_variable>
#include <cstdint>
#include <filesystem>
#include <memory>
#include <mutex>
#include <optional>
#include <thread>
#include <unordered_map>
#include <vector>

namespace tensorrt_llm::executor::disagg_executor
{

class DisaggExecutorOrchestrator::Impl
{
public:
    Impl(std::vector<std::filesystem::path> const& ctxEnginePaths,
        std::vector<std::filesystem::path> const& genEnginePaths,
        std::vector<texec::ExecutorConfig> const& ctxExecutorConfigs,
        std::vector<texec::ExecutorConfig> const& genExecutorConfigs, bool hasContextAwaitThreads,
        bool hasGenAwaitThreads)
        : mhasContextAwaitThreads(hasContextAwaitThreads)
        , mhasGenAwaitThreads(hasGenAwaitThreads)
    {
        TLLM_CHECK(ctxEnginePaths.size() == ctxExecutorConfigs.size());
        TLLM_CHECK(genEnginePaths.size() == genExecutorConfigs.size());
        TLLM_CHECK(!(ctxEnginePaths.empty() || genEnginePaths.empty()));

        // Only the world rank-0 process acts as the orchestrator and may enqueue requests.
        int worldRank = tensorrt_llm::mpi::MpiComm::world().getRank();
        mIsOrchestrator = (worldRank == 0);

        auto contextNum = ctxEnginePaths.size();
        mContextReqIdToGlobalId = std::vector<std::unordered_map<IdType, IdType>>(contextNum);
        mContextMapMutexs = std::vector<std::mutex>(contextNum);
        auto genNum = genEnginePaths.size();
        mGenerationReqIdToGlobalId = std::vector<std::unordered_map<IdType, IdType>>(genNum);
        mGenerationMapMutexs = std::vector<std::mutex>(genNum);

        for (size_t cN = 0; cN < contextNum; cN++)
        {
            mContextExecutors.push_back(std::make_unique<texec::Executor>(
                ctxEnginePaths[cN], texec::ModelType::kDECODER_ONLY, ctxExecutorConfigs[cN]));
        }
        for (size_t gN = 0; gN < genNum; gN++)
        {
            mGenerationExecutors.push_back(std::make_unique<texec::Executor>(
                genEnginePaths[gN], texec::ModelType::kDECODER_ONLY, genExecutorConfigs[gN]));
        }
        if (mIsOrchestrator)
        {
            if (mhasContextAwaitThreads)
            {
                for (size_t contextIdx = 0; contextIdx < contextNum; contextIdx++)
                {
                    mContextThreads.emplace_back(
                        [this, contextIdx]() { this->waitResponseAndAppendThreadFun(true, contextIdx); });
                }
            }
            if (mhasGenAwaitThreads)
            {
                for (size_t genIdx = 0; genIdx < genNum; genIdx++)
                {
                    mGenerationThreads.emplace_back(
                        [this, genIdx]() { this->waitResponseAndAppendThreadFun(false, genIdx); });
                }
            }
        }
        tensorrt_llm::mpi::MpiComm::world().barrier();
    }

    // Enqueues context-only requests and returns one orchestrator-global id per request. With
    // `batch` set, all requests go to a single context executor; otherwise each request is routed
    // individually (to `selectContextId` when provided, else round-robin).
    std::vector<IdType> enqueueContext(std::vector<texec::Request> const& requests,
        std::optional<int> selectContextId = std::nullopt, bool batch = false)
    {
        std::vector<IdType> globalReqIds;
        for (auto const& request : requests)
        {
            globalReqIds.push_back(generatedGlobalId());
            TLLM_CHECK(request.getRequestType() == tensorrt_llm::executor::RequestType::REQUEST_TYPE_CONTEXT_ONLY);
        }
        if (batch)
        {
            size_t contextId = selectContextId.has_value() ? selectContextId.value() : selectContextExecutor();
            auto contextReqIds = mContextExecutors[contextId]->enqueueRequests(requests);
            {
                std::scoped_lock lock{mContextMapMutexs[contextId]};
                for (size_t i = 0; i < requests.size(); ++i)
                {
                    mContextReqIdToGlobalId[contextId][contextReqIds[i]] = globalReqIds[i];
                }
            }
        }
        else
        {
            for (size_t i = 0; i < requests.size(); ++i)
            {
                size_t contextId = selectContextId.has_value() ? selectContextId.value() : selectContextExecutor();
                auto contextReqId = mContextExecutors[contextId]->enqueueRequest(requests[i]);
                {
                    std::scoped_lock lock{mContextMapMutexs[contextId]};
                    mContextReqIdToGlobalId[contextId][contextReqId] = globalReqIds[i];
                }
            }
        }
        return globalReqIds;
    }
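    // Enqueues generation-only requests on the generation executors. `globalRequestIds` must be
    // the ids returned by enqueueContext() for the corresponding requests, so that context and
    // generation responses can be correlated across the two phases. Routing mirrors
    // enqueueContext(): one executor for the whole batch when `batch` is set, otherwise
    // per-request selection.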
    void enqueueGeneration(std::vector<texec::Request> const& requests,
        std::vector<IdType> const& globalRequestIds, std::optional<int> selectGenIdx = std::nullopt,
        bool batch = false)
    {
        TLLM_CHECK(globalRequestIds.size() == requests.size());
        for (auto const& request : requests)
        {
            TLLM_CHECK(
                request.getRequestType() == tensorrt_llm::executor::RequestType::REQUEST_TYPE_GENERATION_ONLY);
        }
        if (batch)
        {
            size_t genIdx = selectGenIdx.has_value() ? selectGenIdx.value() : selectGenerationExecutor();
            auto genReqIds = mGenerationExecutors[genIdx]->enqueueRequests(requests);
            {
                std::scoped_lock lock{mGenerationMapMutexs[genIdx]};
                for (size_t i = 0; i < requests.size(); ++i)
                {
                    mGenerationReqIdToGlobalId[genIdx][genReqIds[i]] = globalRequestIds[i];
                }
            }
        }
        else
        {
            for (size_t i = 0; i < requests.size(); ++i)
            {
                size_t genIdx = selectGenIdx.has_value() ? selectGenIdx.value() : selectGenerationExecutor();
                auto genReqId = mGenerationExecutors[genIdx]->enqueueRequest(requests[i]);
                {
                    std::scoped_lock lock{mGenerationMapMutexs[genIdx]};
                    mGenerationReqIdToGlobalId[genIdx][genReqId] = globalRequestIds[i];
                }
            }
        }
    }

    // Awaits context responses. With background await threads, drains the shared context response
    // queue (contextIdx must not be given). With an explicit contextIdx, polls that executor only.
    // Otherwise polls every context executor, giving each an equal share of the (required) timeout.
    std::vector<ResponseWithId> awaitContextResponses(
        std::optional<int> contextIdx, std::optional<std::chrono::milliseconds> const& timeout)
    {
        std::vector<ResponseWithId> responses;
        if (mhasContextAwaitThreads)
        {
            std::unique_lock lock(mResponsesContextMtx);
            auto pred = [&mShutdown = mShutdown, &resp = this->mContextResponses]() -> bool
            { return !resp.empty() || mShutdown; };
            auto storeResponses = [&resp = this->mContextResponses, &responses]()
            {
                responses = std::move(resp);
                resp.clear();
            };
            if (timeout)
            {
                if (mContextResponsesCV.wait_for(lock, timeout.value(), pred))
                {
                    storeResponses();
                }
            }
            else
            {
                mContextResponsesCV.wait(lock, pred);
                storeResponses();
            }
            TLLM_CHECK_WITH_INFO(
                !contextIdx.has_value(), "contextIdx should not be provided when mhasContextAwaitThreads is true");
            return responses;
        }
        if (contextIdx.has_value())
        {
            TLLM_CHECK(!mhasContextAwaitThreads);
            auto responseFromExecutor = mContextExecutors[contextIdx.value()]->awaitResponses(timeout);
            for (auto&& resp : responseFromExecutor)
            {
                auto reqId = resp.getRequestId();
                IdType globalId{0};
                {
                    std::scoped_lock lock{mContextMapMutexs.at(contextIdx.value())};
                    globalId = mContextReqIdToGlobalId.at(contextIdx.value()).at(reqId);
                }
                TLLM_CHECK(globalId != 0);
                responses.emplace_back(std::move(resp), globalId);
            }
            return responses;
        }
        TLLM_CHECK(timeout.has_value());
        auto perExecutorTimeout = timeout.value() / mContextExecutors.size();
        for (size_t ci = 0; ci < mContextExecutors.size(); ci++)
        {
            auto responseFromExecutor = mContextExecutors.at(ci)->awaitResponses(perExecutorTimeout);
            for (auto&& resp : responseFromExecutor)
            {
                auto reqId = resp.getRequestId();
                IdType globalId{0};
                {
                    std::scoped_lock lock{mContextMapMutexs.at(ci)};
                    globalId = mContextReqIdToGlobalId.at(ci).at(reqId);
                }
                TLLM_CHECK(globalId != 0);
                responses.emplace_back(std::move(resp), globalId);
            }
        }
        return responses;
    }
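    // Awaits generation responses. The three modes mirror awaitContextResponses(): drain the
    // shared queue when await threads are enabled, poll a single executor when genIdx is given,
    // or poll all executors with the timeout split evenly across them.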
    std::vector<ResponseWithId> awaitGenerationResponses(
        std::optional<int> genIdx, std::optional<std::chrono::milliseconds> const& timeout)
    {
        std::vector<ResponseWithId> responses;
        if (mhasGenAwaitThreads)
        {
            std::unique_lock lock(mResponseGenerationMtx);
            auto pred = [&mShutdown = mShutdown, &resp = this->mGenerationResponses]() -> bool
            { return !resp.empty() || mShutdown; };
            auto storeResponses = [&resp = this->mGenerationResponses, &responses]()
            {
                responses = std::move(resp);
                resp.clear();
            };
            if (timeout)
            {
                if (mGenerationResponsesCv.wait_for(lock, timeout.value(), pred))
                {
                    storeResponses();
                }
            }
            else
            {
                mGenerationResponsesCv.wait(lock, pred);
                storeResponses();
            }
            TLLM_CHECK_WITH_INFO(
                !genIdx.has_value(), "genIdx should not be provided when mhasGenAwaitThreads is true");
            return responses;
        }
        if (genIdx.has_value())
        {
            TLLM_CHECK(!mhasGenAwaitThreads);
            auto responseFromExecutor = mGenerationExecutors[genIdx.value()]->awaitResponses(timeout);
            for (auto&& resp : responseFromExecutor)
            {
                auto reqId = resp.getRequestId();
                IdType globalId{0};
                {
                    std::scoped_lock lock{mGenerationMapMutexs.at(genIdx.value())};
                    globalId = mGenerationReqIdToGlobalId.at(genIdx.value()).at(reqId);
                }
                TLLM_CHECK(globalId != 0);
                responses.emplace_back(std::move(resp), globalId);
            }
            return responses;
        }
        TLLM_CHECK(timeout.has_value());
        auto perExecutorTimeout = timeout.value() / mGenerationExecutors.size();
        for (size_t gi = 0; gi < mGenerationExecutors.size(); gi++)
        {
            auto responseFromExecutor = mGenerationExecutors.at(gi)->awaitResponses(perExecutorTimeout);
            for (auto&& resp : responseFromExecutor)
            {
                auto reqId = resp.getRequestId();
                IdType globalId{0};
                {
                    std::scoped_lock lock{mGenerationMapMutexs.at(gi)};
                    globalId = mGenerationReqIdToGlobalId.at(gi).at(reqId);
                }
                TLLM_CHECK(globalId != 0);
                responses.emplace_back(std::move(resp), globalId);
            }
        }
        return responses;
    }

    [[nodiscard]] bool canEnqueue() const
    {
        return mIsOrchestrator;
    }

    [[nodiscard]] std::vector<std::unique_ptr<texec::Executor>> const& getContextExecutors() const
    {
        return mContextExecutors;
    }

    [[nodiscard]] std::vector<std::unique_ptr<texec::Executor>> const& getGenExecutors() const
    {
        return mGenerationExecutors;
    }

    ~Impl()
    {
        mShutdown = true;
        mContextResponsesCV.notify_all();
        mGenerationResponsesCv.notify_all();
        for (auto&& executor : mContextExecutors)
        {
            executor->shutdown();
        }
        for (auto&& executor : mGenerationExecutors)
        {
            executor->shutdown();
        }
        if (mIsOrchestrator)
        {
            if (mhasContextAwaitThreads)
            {
                for (auto&& contextThread : mContextThreads)
                {
                    if (contextThread.joinable())
                    {
                        contextThread.join();
                    }
                }
            }
            if (mhasGenAwaitThreads)
            {
                for (auto&& genThread : mGenerationThreads)
                {
                    if (genThread.joinable())
                    {
                        genThread.join();
                    }
                }
            }
        }
    }

private:
    // Global ids start at 1, so 0 can serve as an "unset" sentinel when resolving ids.
    IdType generatedGlobalId()
    {
        return (++mLastId % UINT64_MAX);
    }

    // Round-robin selection over the context executors.
    size_t selectContextExecutor()
    {
        static size_t selectContextId = 0;
        auto contextId = (selectContextId++) % mContextExecutors.size();
        if (selectContextId >= mContextExecutors.size())
        {
            selectContextId = 0;
        }
        return contextId;
    }

    // Round-robin selection over the generation executors.
    size_t selectGenerationExecutor()
    {
        static size_t selectGenerationId = 0;
        auto generationIdx = (selectGenerationId++) % mGenerationExecutors.size();
        if (selectGenerationId >= mGenerationExecutors.size())
        {
            selectGenerationId = 0;
        }
        return generationIdx;
    }

    void appendNewContextResponse(std::vector<ResponseWithId>&& newResponses)
    {
        {
            std::scoped_lock lock(mResponsesContextMtx);
            for (auto&& response : newResponses)
            {
                mContextResponses.emplace_back(std::move(response));
            }
        }
        mContextResponsesCV.notify_all();
    }

    void appendNewGenerationResponse(std::vector<ResponseWithId>&& newResponses)
    {
        {
            std::scoped_lock lock(mResponseGenerationMtx);
            for (auto&& response : newResponses)
            {
                mGenerationResponses.emplace_back(std::move(response));
            }
        }
        mGenerationResponsesCv.notify_all();
    }
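    // Body of the optional background await threads (one per executor): blocks on
    // Executor::awaitResponses(), maps executor-local request ids back to global ids, and
    // publishes each batch to the shared queue drained by the await* methods above.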
    void waitResponseAndAppendThreadFun(bool isContext, int executorIdx)
    {
        tensorrt_llm::common::setThreadName("waitResponseAndAppendThreadFun");
        auto& executor = isContext ? mContextExecutors[executorIdx] : mGenerationExecutors[executorIdx];
        while (!mShutdown)
        {
            auto responses = executor->awaitResponses();
            if (responses.empty())
            {
                continue;
            }
            std::vector<ResponseWithId> responseWithIds;
            if (isContext)
            {
                for (auto&& response : responses)
                {
                    auto reqId = response.getRequestId();
                    IdType globalId{0};
                    {
                        std::scoped_lock lock{mContextMapMutexs.at(executorIdx)};
                        globalId = mContextReqIdToGlobalId.at(executorIdx).at(reqId);
                    }
                    TLLM_CHECK(globalId != 0);
                    responseWithIds.emplace_back(std::move(response), globalId);
                }
                if (!responseWithIds.empty())
                {
                    appendNewContextResponse(std::move(responseWithIds));
                }
            }
            else
            {
                for (auto&& response : responses)
                {
                    auto reqId = response.getRequestId();
                    IdType globalId{0};
                    {
                        std::scoped_lock lock{mGenerationMapMutexs.at(executorIdx)};
                        globalId = mGenerationReqIdToGlobalId.at(executorIdx).at(reqId);
                    }
                    TLLM_CHECK(globalId != 0);
                    responseWithIds.emplace_back(std::move(response), globalId);
                }
                if (!responseWithIds.empty())
                {
                    appendNewGenerationResponse(std::move(responseWithIds));
                }
            }
        }
    }

    std::vector<std::unique_ptr<texec::Executor>> mContextExecutors;
    std::vector<std::unique_ptr<texec::Executor>> mGenerationExecutors;
    std::vector<std::thread> mContextThreads;
    std::vector<std::thread> mGenerationThreads;
    std::atomic<IdType> mLastId{0};
    std::vector<std::unordered_map<IdType, IdType>> mContextReqIdToGlobalId;
    std::vector<std::unordered_map<IdType, IdType>> mGenerationReqIdToGlobalId;
    std::vector<std::mutex> mContextMapMutexs;
    std::vector<std::mutex> mGenerationMapMutexs;
    std::vector<ResponseWithId> mContextResponses;
    std::condition_variable mContextResponsesCV;
    std::mutex mResponsesContextMtx;
    std::vector<ResponseWithId> mGenerationResponses;
    std::condition_variable mGenerationResponsesCv;
    std::mutex mResponseGenerationMtx;
    std::atomic<bool> mShutdown{false};
    std::atomic<bool> mhasContextAwaitThreads{false};
    std::atomic<bool> mhasGenAwaitThreads{false};
    bool mIsOrchestrator{false};
};

DisaggExecutorOrchestrator::DisaggExecutorOrchestrator(std::vector<std::filesystem::path> const& ctxEnginePaths,
    std::vector<std::filesystem::path> const& genEnginePaths,
    std::vector<texec::ExecutorConfig> const& ctxExecutorConfigs,
    std::vector<texec::ExecutorConfig> const& genExecutorConfigs, bool hasContextAwaitThreads,
    bool hasGenAwaitThreads)
    : mImpl(std::make_unique<Impl>(ctxEnginePaths, genEnginePaths, ctxExecutorConfigs, genExecutorConfigs,
        hasContextAwaitThreads, hasGenAwaitThreads))
{
}

std::vector<IdType> DisaggExecutorOrchestrator::enqueueContext(
    std::vector<texec::Request> const& requests, std::optional<int> selectContextId, bool batch)
{
    return mImpl->enqueueContext(requests, selectContextId, batch);
}

void DisaggExecutorOrchestrator::enqueueGeneration(std::vector<texec::Request> const& requests,
    std::vector<IdType> const& globalRequestIds, std::optional<int> selectGenIdx, bool batch)
{
    mImpl->enqueueGeneration(requests, globalRequestIds, selectGenIdx, batch);
}

std::vector<ResponseWithId> DisaggExecutorOrchestrator::awaitContextResponses(
    std::optional<std::chrono::milliseconds> const& timeout, std::optional<int> contextIdx)
{
    return mImpl->awaitContextResponses(contextIdx, timeout);
}

std::vector<ResponseWithId> DisaggExecutorOrchestrator::awaitGenerationResponses(
    std::optional<std::chrono::milliseconds> const& timeout, std::optional<int> genIdx)
{
    return mImpl->awaitGenerationResponses(genIdx, timeout);
}

bool DisaggExecutorOrchestrator::canEnqueue() const
{
    return mImpl->canEnqueue();
}

std::vector<std::unique_ptr<texec::Executor>> const& DisaggExecutorOrchestrator::getContextExecutors() const
{
    return mImpl->getContextExecutors();
}

std::vector<std::unique_ptr<texec::Executor>> const& DisaggExecutorOrchestrator::getGenExecutors() const
{
    return mImpl->getGenExecutors();
}

DisaggExecutorOrchestrator::~DisaggExecutorOrchestrator() = default;

} // namespace tensorrt_llm::executor::disagg_executor
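// A minimal usage sketch (illustrative only; the engine paths, configs, and request variables
// below — ctxConfig, genConfig, contextOnlyRequests, generationOnlyRequests, matchingGlobalIds —
// are placeholders, not part of this file's API):
//
//     using namespace tensorrt_llm::executor;
//     disagg_executor::DisaggExecutorOrchestrator orchestrator(
//         {"ctx/engine/dir"}, {"gen/engine/dir"}, {ctxConfig}, {genConfig},
//         /*hasContextAwaitThreads=*/true, /*hasGenAwaitThreads=*/true);
//     if (orchestrator.canEnqueue()) // true only on the orchestrator (world rank 0)
//     {
//         auto globalIds = orchestrator.enqueueContext(contextOnlyRequests);
//         auto ctxResponses = orchestrator.awaitContextResponses(std::nullopt);
//         // ... build one generation-only request per context response; each ResponseWithId
//         //     carries the global id to pass back as matchingGlobalIds ...
//         orchestrator.enqueueGeneration(generationOnlyRequests, matchingGlobalIds);
//         auto genResponses = orchestrator.awaitGenerationResponses(std::nullopt);
//     }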