/*
 * SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: Apache-2.0
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "tensorrt_llm/executor/executorImpl.h"

#include "tensorrt_llm/batch_manager/trtEncoderModel.h"
#include "tensorrt_llm/batch_manager/trtGptModelFactory.h"
#include "tensorrt_llm/common/assert.h"
#include "tensorrt_llm/common/cudaProfilerUtils.h"
#include "tensorrt_llm/common/logger.h"
#include "tensorrt_llm/common/nvtxUtils.h"
#include "tensorrt_llm/common/timestampUtils.h"
#include "tensorrt_llm/common/utils.h"
#include "tensorrt_llm/executor/dataTransceiverState.h"
#include "tensorrt_llm/executor/executor.h"
#include "tensorrt_llm/executor/orchestratorUtils.h"
#include "tensorrt_llm/executor/requestUtils.h"
#include "tensorrt_llm/executor/serialization.h"
#include "tensorrt_llm/executor/serializeUtils.h"
#include "tensorrt_llm/executor/types.h"
#include "tensorrt_llm/executor/version.h"
#include "tensorrt_llm/runtime/loraCache.h"
#include "tensorrt_llm/runtime/memoryCounters.h"
#include "tensorrt_llm/runtime/utils/mpiTags.h"
#include "tensorrt_llm/runtime/utils/mpiUtils.h"

// NOTE: the angle-bracketed tokens in this file (standard-library header names and template
// arguments) were lost in extraction; they have been reconstructed from usage below and may
// differ from the original source.
#include <algorithm>
#include <chrono>
#include <cstdint>
#include <memory>
#include <numeric>
#include <optional>
#include <sstream>
#include <thread>

namespace tensorrt_llm::executor
{

namespace
{

[[nodiscard]] bool executorConfigIsValid(
    ::tensorrt_llm::executor::ExecutorConfig const& executorConfig, runtime::ModelConfig const& modelConfig)
{
    // Make sure logic in this function matches fixExecutorConfig
    if (executorConfig.getEnableChunkedContext())
    {
        if (modelConfig.isRnnBased() || !modelConfig.isKVCacheEnabled() || !modelConfig.getPagedContextFMHA())
        {
            return false;
        }
    }
    return true;
}

[[nodiscard]] ::tensorrt_llm::executor::ExecutorConfig fixExecutorConfig(
    ::tensorrt_llm::executor::ExecutorConfig const& executorConfig, runtime::ModelConfig const& modelConfig)
{
    // Make sure logic in this function matches executorConfigIsValid
    auto fixedExecutorConfig = executorConfig;

    // Disable chunked context when not supported
    if (executorConfig.getEnableChunkedContext())
    {
        if (modelConfig.isRnnBased() || !modelConfig.isKVCacheEnabled() || !modelConfig.getPagedContextFMHA())
        {
            fixedExecutorConfig.setEnableChunkedContext(false);
            TLLM_LOG_WARNING(
                "Chunked context is not supported for this configuration and will be disabled. "
                "Related configs: RNNBased: %d, KVCacheEnabled: %d, PagedContextFMHA: %d",
                modelConfig.isRnnBased(), modelConfig.isKVCacheEnabled(), modelConfig.getPagedContextFMHA());
        }
    }
    return fixedExecutorConfig;
}

SizeType32 getNumChildRequests(Request const& request)
{
    auto samplingConfig = request.getSamplingConfig();
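    // Beam search returns all beams within a single request/response, so child requests are only
    // created for sampling (beamWidth == 1) with more than one returned sequence.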
    return samplingConfig.getBeamWidth() > 1 ? 0 : samplingConfig.getNumReturnSequences().value_or(1) - 1;
}

} // namespace

/// @brief Version of TRT-LLM as defined in tensorrt_llm/version.py
char const* version() noexcept
{
    return kTensorRtLlmVersion;
}

class CancelledRequestsAsyncSend
{
public:
    CancelledRequestsAsyncSend(std::shared_ptr<mpi::MpiComm> const& commSession,
        std::unordered_set<IdType> const& cancelledReqIds, int peer)
    {
        TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
        mNumReq = static_cast<int64_t>(cancelledReqIds.size());
        TLLM_LOG_DEBUG("start send %ld cancelled requests to rank %d", mNumReq, peer);
        mRequest1
            = commSession->sendAsync(&mNumReq, 1, mpi::MpiType::kINT64, peer, mpi::MpiTag::kCancelledRequestsNumReq);
        if (mNumReq > 0)
        {
            mIds.assign(cancelledReqIds.begin(), cancelledReqIds.end());
            mRequest2 = commSession->sendAsync(
                mIds.data(), mIds.size(), mpi::MpiType::kUINT64, peer, mpi::MpiTag::kCancelledRequestsIds);
        }
        TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
    }

    ~CancelledRequestsAsyncSend()
    {
        TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
        mRequest1->wait();
        if (mRequest2)
        {
            mRequest2->wait();
        }
        TLLM_LOG_DEBUG("end send cancelled requests");
        TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
    }

    CancelledRequestsAsyncSend(CancelledRequestsAsyncSend const& executor) = delete;
    CancelledRequestsAsyncSend& operator=(CancelledRequestsAsyncSend const& executor) = delete;
    CancelledRequestsAsyncSend(CancelledRequestsAsyncSend&&) = delete;
    CancelledRequestsAsyncSend& operator=(CancelledRequestsAsyncSend&&) = delete;

    static std::unordered_set<IdType> cancelledRequestsRecv(std::shared_ptr<mpi::MpiComm> const& commSession, int peer)
    {
        TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
        TLLM_LOG_DEBUG("start recv cancelled requests from rank %d", peer);
        std::unordered_set<IdType> cancelledReqIds;
        int64_t numReq{0};
        commSession->recv(&numReq, 1, mpi::MpiType::kINT64, peer, mpi::MpiTag::kCancelledRequestsNumReq);
        TLLM_LOG_DEBUG("recv %ld cancelled requests", numReq);
        if (numReq > 0)
        {
            std::vector<IdType> buffer(numReq);
            commSession->recv(
                buffer.data(), buffer.size(), mpi::MpiType::kUINT64, peer, mpi::MpiTag::kCancelledRequestsIds);
            cancelledReqIds = std::unordered_set<IdType>(buffer.begin(), buffer.end());
        }
        TLLM_LOG_DEBUG("end recv cancelled requests from rank %d", peer);
        TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
        return cancelledReqIds;
    }

private:
    int64_t mNumReq;
    std::vector<IdType> mIds;
    std::shared_ptr<mpi::MpiRequest> mRequest1;
    std::shared_ptr<mpi::MpiRequest> mRequest2;
};
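// Mirrors CancelledRequestsAsyncSend for new requests: the constructor posts up to three async
// sends (request count, packed-buffer size, packed bytes) and the destructor waits on them, so the
// buffers owned by this handle must stay alive until the sends complete.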
class RequestWithIdAsyncSend
{
public:
    RequestWithIdAsyncSend(
        std::shared_ptr<mpi::MpiComm> const& commSession, std::vector<RequestWithId> const& reqWithIds, int peer)
    {
        TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
        TLLM_LOG_DEBUG("start send requests to rank %d", peer);
        mNumReq = static_cast<int64_t>(reqWithIds.size());
        mRequest1 = commSession->sendAsync(&mNumReq, 1, mpi::MpiType::kINT64, peer, mpi::MpiTag::kRequestWithIdNumReq);
        if (mNumReq > 0)
        {
            mPacked = RequestWithId::serializeReqWithIds(reqWithIds);
            mVecSize = static_cast<int64_t>(mPacked.size());
            mRequest2
                = commSession->sendAsync(&mVecSize, 1, mpi::MpiType::kINT64, peer, mpi::MpiTag::kRequestWithIdVecSize);
            mRequest3 = commSession->sendAsync(
                mPacked.data(), mPacked.size(), mpi::MpiType::kCHAR, peer, mpi::MpiTag::kRequestWithIdPacked);
        }
        TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
    }

    ~RequestWithIdAsyncSend()
    {
        TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
        mRequest1->wait();
        if (mRequest2)
        {
            mRequest2->wait();
        }
        if (mRequest3)
        {
            mRequest3->wait();
        }
        TLLM_LOG_DEBUG("end send requests");
        TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
    }

    RequestWithIdAsyncSend(RequestWithIdAsyncSend const& executor) = delete;
    RequestWithIdAsyncSend& operator=(RequestWithIdAsyncSend const& executor) = delete;
    RequestWithIdAsyncSend(RequestWithIdAsyncSend&&) = delete;
    RequestWithIdAsyncSend& operator=(RequestWithIdAsyncSend&&) = delete;

    static std::vector<RequestWithId> requestWithIdRecv(std::shared_ptr<mpi::MpiComm> const& commSession, int peer)
    {
        TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
        TLLM_LOG_DEBUG("start recv requests from rank %d", peer);
        std::vector<RequestWithId> reqWithIds;
        int64_t numReq{0};
        commSession->recv(&numReq, 1, mpi::MpiType::kINT64, peer, mpi::MpiTag::kRequestWithIdNumReq);
        if (numReq > 0)
        {
            std::vector<char> buffer;
            int64_t vecSize = 0;
            commSession->recv(&vecSize, 1, mpi::MpiType::kINT64, peer, mpi::MpiTag::kRequestWithIdVecSize);
            buffer.resize(vecSize);
            commSession->recv(
                buffer.data(), buffer.size(), mpi::MpiType::kCHAR, peer, mpi::MpiTag::kRequestWithIdPacked);
            reqWithIds = RequestWithId::deserializeReqWithIds(buffer);
        }
        TLLM_LOG_DEBUG("end recv requests from rank %d", peer);
        TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
        return reqWithIds;
    }

private:
    int64_t mNumReq;
    int64_t mVecSize;
    std::vector<char> mPacked;
    std::shared_ptr<mpi::MpiRequest> mRequest1;
    std::shared_ptr<mpi::MpiRequest> mRequest2;
    std::shared_ptr<mpi::MpiRequest> mRequest3;
};

void Executor::Impl::loadModel(std::optional<std::filesystem::path> const& modelPathOpt,
    std::optional<BufferView> const& engineBufferOpt, runtime::GptJsonConfig const& jsonConfig,
    ::tensorrt_llm::executor::ExecutorConfig const& executorConfig, bool isEncoder,
    std::optional<std::map<std::string, Tensor>> const& managedWeightsOpt)
{
    auto const gpusPerNode = jsonConfig.getGpusPerNode();
    auto const tp = jsonConfig.getTensorParallelism();
    auto const pp = jsonConfig.getPipelineParallelism();
    auto const cp = jsonConfig.getContextParallelism();
    auto parallelConfig = executorConfig.getParallelConfig().value_or(ParallelConfig());
    auto worldConfig = runtime::WorldConfig::mpi(gpusPerNode, tp, pp, cp, parallelConfig.getDeviceIds());

    TLLM_CHECK_WITH_INFO(modelPathOpt.has_value() || engineBufferOpt.has_value(),
        "Either engine path or deserialized engine buffer should be given to load the model properly.");
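    // Prefer the in-memory engine buffer when one was provided; otherwise fall back to the engine
    // file matching this rank's world configuration.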
    auto rawEngine = engineBufferOpt.has_value()
        ? runtime::RawEngine(engineBufferOpt.value().data(), engineBufferOpt.value().size())
        : runtime::RawEngine(modelPathOpt.value() / jsonConfig.engineFilename(worldConfig));
    if (rawEngine.getType() != tensorrt_llm::runtime::RawEngine::FilePath)
    {
        if (modelPathOpt.has_value())
        {
            rawEngine.setPath(modelPathOpt.value() / jsonConfig.engineFilename(worldConfig));
            if (managedWeightsOpt.has_value())
            {
                TLLM_LOG_WARNING(
                    "Executor::Impl::loadModel: managedWeightsOpt argument is ignored when loading engine from file.");
            }
        }
        else if (managedWeightsOpt.has_value())
        {
            rawEngine.setManagedWeightsMap(managedWeightsOpt.value());
        }
    }

    auto const& modelConfig = jsonConfig.getModelConfig();
    if (isEncoder)
    {
        mEncoderModel = createEncoderModel(rawEngine, modelConfig, worldConfig, executorConfig);
    }
    else
    {
        mModel = createModel(rawEngine, modelConfig, worldConfig, executorConfig);
    }
}

Executor::Impl::Impl(std::filesystem::path const& modelPath,
    std::optional<std::filesystem::path> const& encoderModelPath, ModelType const modelType,
    ::tensorrt_llm::executor::ExecutorConfig const& executorConfig)
{
    auto decoderJsonConfig = runtime::GptJsonConfig::parse(modelPath / "config.json");
    // for now, assume encoder & decoder models share the same MPI config
    auto const tp = decoderJsonConfig.getTensorParallelism();
    auto const pp = decoderJsonConfig.getPipelineParallelism();
    auto const cp = decoderJsonConfig.getContextParallelism();
    initializeCommAndWorkers(tp, pp, cp, executorConfig, modelType, modelPath, std::nullopt, decoderJsonConfig);

    if (mIsWorker)
    {
        if (modelType == ModelType::kENCODER_DECODER)
        {
            if (encoderModelPath.has_value())
            {
                auto const encoderJsonConfig = runtime::GptJsonConfig::parse(encoderModelPath.value() / "config.json");
                auto const encoderMaxInputLen = encoderJsonConfig.getModelConfig().getMaxInputLen();
                auto const encoderHiddenSize = encoderJsonConfig.getModelConfig().getHiddenSize()
                    * encoderJsonConfig.getTensorParallelism(); // recover full hidden size
                // add encoder info to decoder for encoder-decoder models
                // note: GptJsonConfig can no longer have modelConfig as const member since it must be mutable here
                decoderJsonConfig.getModelConfigMutable().setMaxEncoderLen(encoderMaxInputLen);
                decoderJsonConfig.getModelConfigMutable().setEncoderHiddenSize(encoderHiddenSize);

                loadModel(
                    encoderModelPath.value(), std::nullopt, encoderJsonConfig, executorConfig, true, std::nullopt);
            }
            else
            {
                TLLM_LOG_WARNING("Encoder model path not provided. Skipping Encoder Run.");
            }
        }
        loadModel(modelPath, std::nullopt, decoderJsonConfig, executorConfig, false, std::nullopt);
    }
    initialize(executorConfig);
}

Executor::Impl::Impl(BufferView const& engineBufferView, std::string const& jsonConfigStr,
    std::optional<BufferView> const& encoderEngineBufferView, std::optional<std::string> const& encoderJsonConfigStr,
    ModelType const modelType, ::tensorrt_llm::executor::ExecutorConfig const& executorConfig,
    std::optional<std::map<std::string, Tensor>> const& managedWeightsOpt)
{
    auto decoderJsonConfig = runtime::GptJsonConfig::parse(jsonConfigStr);
    // for now, assume encoder & decoder models share the same MPI config
    auto const tp = decoderJsonConfig.getTensorParallelism();
    auto const pp = decoderJsonConfig.getPipelineParallelism();
    auto const cp = decoderJsonConfig.getContextParallelism();
    initializeCommAndWorkers(tp, pp, cp, executorConfig, modelType, std::nullopt, std::nullopt, decoderJsonConfig);

    if (mIsWorker)
    {
        if (modelType == ModelType::kENCODER_DECODER)
        {
            TLLM_CHECK(encoderEngineBufferView.has_value() && encoderJsonConfigStr.has_value());
            TLLM_CHECK_WITH_INFO(
                !managedWeightsOpt.has_value(), "Managed weights are not supported for enc-dec models");
            auto const encoderJsonConfig = runtime::GptJsonConfig::parse(encoderJsonConfigStr.value());
            auto const encoderMaxInputLen = encoderJsonConfig.getModelConfig().getMaxInputLen();
            auto const encoderHiddenSize = encoderJsonConfig.getModelConfig().getHiddenSize()
                * encoderJsonConfig.getTensorParallelism(); // recover full hidden size
            // add encoder info to decoder for encoder-decoder models
            // note: GptJsonConfig can no longer have modelConfig as const member since it must be mutable here
            decoderJsonConfig.getModelConfigMutable().setMaxEncoderLen(encoderMaxInputLen);
            decoderJsonConfig.getModelConfigMutable().setEncoderHiddenSize(encoderHiddenSize);

            loadModel(
                std::nullopt, encoderEngineBufferView.value(), encoderJsonConfig, executorConfig, true, std::nullopt);
        }
        loadModel(std::nullopt, engineBufferView, decoderJsonConfig, executorConfig, false, managedWeightsOpt);
    }
    initialize(executorConfig);
}

Executor::Impl::Impl(std::shared_ptr<Model> model, std::optional<std::shared_ptr<Model>> encoderModel,
    ::tensorrt_llm::executor::ExecutorConfig const& executorConfig)
{
    auto const& worldConfig = model->getWorldConfig();
    auto const tp = worldConfig.getTensorParallelism();
    auto const pp = worldConfig.getPipelineParallelism();
    auto const cp = worldConfig.getContextParallelism();
    auto const modelType = encoderModel.has_value() ? ModelType::kENCODER_DECODER : ModelType::kDECODER_ONLY;
    initializeCommAndWorkers(tp, pp, cp, executorConfig, modelType, std::nullopt, worldConfig);
    if (modelType == ModelType::kENCODER_DECODER)
    {
        mEncoderModel = encoderModel.value();
    }
    mModel = std::move(model);
    initialize(executorConfig);
}

Executor::Impl::~Impl()
{
    shutdown();
}
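// Shared initialization for all constructors: records scheduler/runtime options, wires up logits
// post-processors, splits the MPI session into TP/CP/PP communicators, and launches the
// worker-side execution loop.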
void Executor::Impl::initialize(::tensorrt_llm::executor::ExecutorConfig const& executorConfig)
{
    TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
    mShutdown = false;
    mShutdownCalled = false;
    mIterStatsMaxIterations = executorConfig.getIterStatsMaxIterations();
    mRequestStatsMaxIterations = executorConfig.getRequestStatsMaxIterations();
    mDebugTensorsMaxIterations
        = executorConfig.getDebugConfig() ? executorConfig.getDebugConfig()->getDebugTensorsMaxIterations() : 0;
    TLLM_CHECK_WITH_INFO(mDebugTensorsMaxIterations == 0 || mCommMode == CommunicationMode::kLEADER,
        "debugTensorsMaxIterations > 0 is only allowed in leader mode.");
    mBatchingType = executorConfig.getBatchingType();
    mIsSchedulerMaxUtilization = (executorConfig.getSchedulerConfig().getCapacitySchedulerPolicy()
        == CapacitySchedulerPolicy::kMAX_UTILIZATION);
    mIsSchedulerGuaranteedNoEvict = (executorConfig.getSchedulerConfig().getCapacitySchedulerPolicy()
        == CapacitySchedulerPolicy::kGUARANTEED_NO_EVICT);
    mIsChunkedContext = executorConfig.getEnableChunkedContext();
    mPromptTableOffloading = executorConfig.getPromptTableOffloading();

    mMaxQueueSize = executorConfig.getMaxQueueSize();
    mLastReqId = 1;

    auto const& logitsProcConfig = executorConfig.getLogitsPostProcessorConfig();
    if (logitsProcConfig.has_value())
    {
        mLogitsPostProcessorMap = logitsProcConfig.value().getProcessorMap().value_or(LogitsPostProcessorMap{});
        initializeLogitsPostProcessorBatched(logitsProcConfig.value());
        if (!logitsProcConfig.value().getReplicate())
        {
            mModel->setReplicateLogitsPostProcessor(false);
        }
    }

    auto const& commComm = COMM_SESSION;
    int32_t const commSize = commComm.getSize();

    if (mIsWorker)
    {
        if (commSize > 1)
        {
            auto const& worldConfig = mModel->getWorldConfig();
            auto const& commSession = COMM_SESSION;
            auto const& rank = commSession.getRank();
            auto const& tp = worldConfig.getTensorParallelism();
            auto const& cp = worldConfig.getContextParallelism();
            // Split the session communicator into per-dimension communicators: ranks sharing a TP
            // group, a CP group, and a PP stage respectively.
            mCommTensorParallel = std::make_shared<tensorrt_llm::mpi::MpiComm>(
                commSession.split(rank / tp, worldConfig.getTensorParallelRank()));
            mCommContextParallel = std::make_shared<tensorrt_llm::mpi::MpiComm>(
                commSession.split(rank / (tp * cp) * tp + rank % tp, worldConfig.getContextParallelRank()));
            mCommPipelineParallel = std::make_shared<tensorrt_llm::mpi::MpiComm>(
                commSession.split(rank % (tp * cp), worldConfig.getPipelineParallelRank()));
            if (worldConfig.isPipelineParallel())
            {
                mRequestWithIdWaitThread = std::make_unique<WaitThread>(
                    "requestWithIdWaitThread", [this]() { mRequestWithIdAsyncSndHdl.reset(nullptr); });
                mCancelledRequestsWaitThread = std::make_unique<WaitThread>(
                    "cancelledRequestsWaitThread", [this]() { mCancelledRequestsAsyncSndHdl.reset(nullptr); });
                if (mIsLeader)
                {
                    mRequestWithIdLeaderThread
                        = std::make_unique<std::thread>(&Executor::Impl::requestWithIdLeaderThread, this);
                    mCancelledRequestsLeaderThread
                        = std::make_unique<std::thread>(&Executor::Impl::cancelledRequestsLeaderThread, this);
                }
            }
        }

        // Launch the execution thread
        mMaxNumActiveRequests = mModel->getMaxNumSequences();
        mExecutionThread = std::thread(&Impl::executionLoop, this);
    }
    mEnableBlockReuse = executorConfig.getKvCacheConfig().getEnableBlockReuse();

    auto const& dynamicBatchConfig = executorConfig.getSchedulerConfig().getDynamicBatchConfig();
    if (dynamicBatchConfig)
    {
        if (mIsWorker)
        {
            if (mModel->getModelConfig().isTransformerBased() && mModel->getModelConfig().isKVCacheEnabled())
            {
                mDynamicBatchTuner = std::make_shared<DynamicBatchTuner>(dynamicBatchConfig.value());
            }
            else
            {
                TLLM_LOG_WARNING("Dynamic batch tuner can only support transformer models that use KV cache.");
            }
        }
    }

    TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}
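// Chooses the in-flight batching model type from the configured batching mode and repairs
// unsupported option combinations (see executorConfigIsValid/fixExecutorConfig above) before
// handing off to the batch-manager factory.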
std::shared_ptr<Model> Executor::Impl::createModel(runtime::RawEngine const& rawEngine,
    runtime::ModelConfig const& modelConfig, runtime::WorldConfig const& worldConfig,
    ::tensorrt_llm::executor::ExecutorConfig const& executorConfig)
{
    auto const gptModelType = [&executorConfig, &modelConfig]()
    {
        switch (executorConfig.getBatchingType())
        {
        case BatchingType::kSTATIC:
            TLLM_THROW(
                "Static batching type is deprecated. Please use in-flight batching with "
                "CapacitySchedulerPolicy::kSTATIC_BATCH instead.");
        case BatchingType::kINFLIGHT:
            return modelConfig.isRnnBased() ? batch_manager::TrtGptModelType::InflightBatching
                                            : batch_manager::TrtGptModelType::InflightFusedBatching;
        default: TLLM_THROW("Invalid batching strategy");
        }
    }();
    bool const isLeaderInOrchMode = (mCommMode == CommunicationMode::kORCHESTRATOR) && mIsLeader;
    auto const& fixedExecutorConfig = executorConfigIsValid(executorConfig, modelConfig)
        ? executorConfig
        : fixExecutorConfig(executorConfig, modelConfig);
    return batch_manager::TrtGptModelFactory::create(
        rawEngine, modelConfig, worldConfig, gptModelType, fixedExecutorConfig, isLeaderInOrchMode);
}

std::shared_ptr<Model> Executor::Impl::createEncoderModel(runtime::RawEngine const& rawEngine,
    runtime::ModelConfig const& modelConfig, runtime::WorldConfig const& worldConfig,
    ::tensorrt_llm::executor::ExecutorConfig const& executorConfig)
{
    auto fixedExecutorConfig = ExecutorConfig{};
    fixedExecutorConfig.setSchedulerConfig(executorConfig.getSchedulerConfig());
    // logger type inferred; the original template argument was lost in extraction
    return std::make_shared<batch_manager::TrtEncoderModel>(
        modelConfig, worldConfig, rawEngine, std::make_shared<runtime::TllmLogger>(), fixedExecutorConfig);
}

void Executor::Impl::setOrchLeaderComm(SizeType32 tp, SizeType32 pp, SizeType32 cp, ParallelConfig const& parallelConfig)
{
#if ENABLE_MULTI_DEVICE
    auto optOrchestratorConfig = parallelConfig.getOrchestratorConfig();
    if (optOrchestratorConfig.value().getIsOrchestrator())
    {
        TLLM_CHECK_WITH_INFO(mWorldRank == 0, "Rank 0 must be orchestrator");
    }
    TLLM_CHECK_WITH_INFO(parallelConfig.getParticipantIds(),
        "When not spawning processes in orchestrator mode, participant IDs must be provided");
    auto participantIds = parallelConfig.getParticipantIds().value();
    TLLM_CHECK_WITH_INFO(static_cast<SizeType32>(participantIds.size()) == tp * pp * cp,
        "When specifying participantIds, participantIds size must be equal to tp*pp*cp");

    bool isLeader = (mWorldRank == participantIds.front());
    bool isOrchestrator = (mWorldRank == 0);

    // OrchLeaderComm rank 0 is orchestrator, rank 1 is leader
    mOrchRank = 0;
    mLeaderRank = 1;

    // Create a leaderOrch comm
    std::vector<int> leaderOrchRanks{0, participantIds.front()};
    MPI_Group worldGroup = nullptr;
    MPICHECK(MPI_Comm_group(MPI_COMM_WORLD, &worldGroup)); // NOLINT
    int worldGroupRank = 0;
    MPI_Group_rank(worldGroup, &worldGroupRank);
    int worldSize = 0;
    MPICHECK(MPI_Group_size(worldGroup, &worldSize)); // NOLINT
    TLLM_CHECK_WITH_INFO(participantIds.front() < worldSize, "Not enough ranks in world");

    MPI_Group leaderOrchCommGroup = nullptr;
    MPICHECK(
        MPI_Group_incl(worldGroup, leaderOrchRanks.size(), leaderOrchRanks.data(), &leaderOrchCommGroup)); // NOLINT
    int leaderOrchGroupRank = 0;
    int leaderOrchGroupSize = 0;
    MPI_Group_rank(leaderOrchCommGroup, &leaderOrchGroupRank);
    MPI_Group_size(leaderOrchCommGroup, &leaderOrchGroupSize);

    if (isOrchestrator || isLeader)
    {
        MPI_Comm leaderOrchComm = nullptr;
        MPICHECK(MPI_Comm_create_group(
            MPI_COMM_WORLD, leaderOrchCommGroup, participantIds.front(), &leaderOrchComm)); // NOLINT
        mOrchLeaderComm = std::make_shared<tensorrt_llm::mpi::MpiComm>(leaderOrchComm, false);
    }
    else
    {
        mOrchLeaderComm = nullptr;
    }
#endif // ENABLE_MULTI_DEVICE
}
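// Entry point for communicator setup. Every rank classifies itself here as orchestrator, leader,
// and/or worker, then runs the matching initialization path below.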
"Encoder-Decoder C++ runtime doesn't support Pipeline Parallelism currently. Please switch to Python " "runtime for PP mode, if necessary."); } tensorrt_llm::mpi::initialize(tensorrt_llm::mpi::MpiThreadSupport::THREAD_MULTIPLE); mWorldRank = tensorrt_llm::mpi::MpiComm::world().getRank(); mUsePipelineParallel = pp > 1; auto parallelConfig = executorConfig.getParallelConfig().value_or(ParallelConfig()); validateParallelConfig(parallelConfig, modelType, modelPath); mCommMode = parallelConfig.getCommunicationMode(); auto optOrchestratorConfig = parallelConfig.getOrchestratorConfig(); mRecvPollPeriodMs = executorConfig.getRecvPollPeriodMs(); // Need to create communicator between orchestrator and leader if not spawning processes in orchestrator mode if (mCommMode == CommunicationMode::kORCHESTRATOR && !optOrchestratorConfig.value().getSpawnProcesses()) { setOrchLeaderComm(tp, pp, cp, parallelConfig); } if (mCommMode == CommunicationMode::kORCHESTRATOR && optOrchestratorConfig.value().getIsOrchestrator()) { initializeOrchestrator(tp, pp, cp, executorConfig, parallelConfig, modelType.value(), modelPath.value()); } else { initializeWorkers(tp, pp, cp, parallelConfig, worldConfig, decoderGptJsonConfig); } } void Executor::Impl::validateParallelConfig(ParallelConfig const& parallelConfig, std::optional modelType, std::optional const& modelPath) { TLLM_CHECK_WITH_INFO(parallelConfig.getCommunicationType() == CommunicationType::kMPI, "Only CommunicationType kMPI is supported for now."); auto optOrchestratorConfig = parallelConfig.getOrchestratorConfig(); if (parallelConfig.getCommunicationMode() == CommunicationMode::kORCHESTRATOR) { TLLM_CHECK_WITH_INFO( optOrchestratorConfig, "OrchestratorConfig must be set when using ORCHESTRATOR communication mode."); TLLM_CHECK_WITH_INFO(modelPath, "OrchestratorMode only supports reading model weight from disk currently."); TLLM_CHECK_WITH_INFO(modelType, "OrchestratorMode requires modelType to be specified."); } } void Executor::Impl::initializeOrchestrator(SizeType32 tp, SizeType32 pp, SizeType32 cp, ::tensorrt_llm::executor::ExecutorConfig const& executorConfig, ParallelConfig parallelConfig, ModelType modelType, std::filesystem::path const& modelPath) { #if ENABLE_MULTI_DEVICE namespace su = tensorrt_llm::executor::serialize_utils; auto const& worldComm = tensorrt_llm::mpi::MpiComm::world(); int32_t const worldSize = worldComm.getSize(); auto orchestratorConfig = parallelConfig.getOrchestratorConfig().value(); mIsWorker = false; mIsLeader = false; mIsPipelineLeader = false; mIsOrchestrator = true; // Verify that worldSize is 1 if (orchestratorConfig.getSpawnProcesses()) { TLLM_CHECK_WITH_INFO(worldSize == 1, "When using the orchestrator mode and isOrchestrator is true, expect MPI worldSize to be 1."); // Spawn the worker threads auto workerExecPath = orchestratorConfig.getWorkerExecutablePath(); MPI_Comm intercomm = nullptr; MPI_Info mpiInfo = nullptr; MPICHECK(MPI_Info_create(&mpiInfo)); MPICHECK(MPI_Info_set(mpiInfo, "env", "FORCE_NCCL_ALL_REDUCE_STRATEGY")); // Binding policy is not inherited for dynamically spawned jobs, resulting in the worker being bound // to a single core. 
        MPICHECK(MPI_Info_set(mpiInfo, "bind_to", "none"));
        MPICHECK(MPI_Comm_spawn(workerExecPath.c_str(), MPI_ARGV_NULL, tp * pp * cp, mpiInfo, 0, MPI_COMM_SELF,
            &intercomm, MPI_ERRCODES_IGNORE));
        mOrchLeaderComm = std::make_shared<tensorrt_llm::mpi::MpiComm>(intercomm, true);
        // With intercomm, leader is rank 0 in the local group
        mLeaderRank = 0;
        mOrchRank = 0;

        // Copy the executor config, but set the orchestrator flag to false
        auto newOrchConfig = OrchestratorConfig(false, orchestratorConfig.getWorkerExecutablePath());
        parallelConfig.setOrchestratorConfig(newOrchConfig);
        auto execConfig = executorConfig;
        execConfig.setParallelConfig(parallelConfig);

        // Serialize and send the executorConfig, the modelType and the modelPath
        std::ostringstream oStream;
        su::serialize(modelPath.string(), oStream);
        su::serialize(modelType, oStream);
        su::serialize(execConfig, oStream);
        auto str = oStream.str();
        std::vector<char> buffer(str.begin(), str.end());
        auto bufferSize = static_cast<int64_t>(buffer.size());
        mOrchLeaderComm->bcast(&bufferSize, 1, mpi::MpiType::kINT64, MPI_ROOT);
        mOrchLeaderComm->bcast(buffer.data(), buffer.size(), mpi::MpiType::kCHAR, MPI_ROOT);

        // Wait for workers to have created their executor instance
        MPICHECK(MPI_Barrier(intercomm));
    }

    // Spawn the thread responsible for sending new requests to the leader of the model
    mOrchSendReqThread = std::thread(&Impl::orchSendReqThread, this);
    // Spawn the thread responsible for receiving new responses from the leader of the model
    mOrchRecvThread
        = std::thread([&]() { this->orchRecvThread(mpi::MpiTag::kOrchestratorId, mpi::MpiTag::kOrchestratorData); });
#endif // ENABLE_MULTI_DEVICE
}
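// Worker-side communicator setup: derives the participant ranks, validates device placement, and
// builds the session communicator (with participants reordered so the leader lands on the last
// pipeline stage).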
void Executor::Impl::initializeWorkers(SizeType32 tp, SizeType32 pp, SizeType32 cp, ParallelConfig& parallelConfig,
    std::optional<runtime::WorldConfig> const& worldConfig,
    std::optional<runtime::GptJsonConfig> const& decoderGptJsonConfig)
{
    auto const& worldComm = tensorrt_llm::mpi::MpiComm::world();
    int32_t const worldSize = worldComm.getSize();

    auto const& orchestratorConfig = parallelConfig.getOrchestratorConfig();
    mIsOrchestrator = mCommMode == CommunicationMode::kORCHESTRATOR && orchestratorConfig.value().getIsOrchestrator();
    TLLM_CHECK_WITH_INFO(mCommMode != CommunicationMode::kORCHESTRATOR || orchestratorConfig.has_value(),
        "When using ORCHESTRATOR mode, orchestrator config must be set");
    if (mCommMode == CommunicationMode::kORCHESTRATOR && !orchestratorConfig.value().getSpawnProcesses())
    {
        TLLM_CHECK_WITH_INFO(parallelConfig.getParticipantIds(),
            "When not spawning processes in orchestrator mode, participant IDs must be provided");
        // Check that rank 0 is reserved for the orchestrator
        auto const participantIds = parallelConfig.getParticipantIds().value();
        for (auto const& participantId : participantIds)
        {
            TLLM_CHECK_WITH_INFO(participantId != 0, "Rank 0 is reserved for the orchestrator");
        }
    }

    // Participant ids
    std::vector<SizeType32> participantIds;
    if (!parallelConfig.getParticipantIds())
    {
        TLLM_CHECK_WITH_INFO(worldSize == tp * pp * cp,
            "With communicationMode kLEADER, MPI worldSize is expected to be equal to tp*pp*cp when "
            "participantIds are not specified");
        participantIds.resize(tp * pp * cp);
        std::iota(participantIds.begin(), participantIds.end(), 0);
    }
    else
    {
        if (mCommMode == CommunicationMode::kORCHESTRATOR && orchestratorConfig.value().getSpawnProcesses())
        {
            TLLM_THROW(
                "Participant ids should not be set when using CommunicationMode::kORCHESTRATOR with "
                "spawnProcesses=true");
        }
        participantIds = parallelConfig.getParticipantIds().value();
        TLLM_CHECK_WITH_INFO(static_cast<SizeType32>(participantIds.size()) == tp * pp * cp,
            tensorrt_llm::common::fmtstr("When specifying participantIds, participantIds size (%lu) must be equal to "
                                         "tp*pp*cp (tp is %u, pp is %u, cp is %u)",
                participantIds.size(), tp, pp, cp));
    }

    // If deviceIds are specified, check that they match tp*pp*cp
    if (parallelConfig.getDeviceIds())
    {
        auto deviceIds = parallelConfig.getDeviceIds().value();
        auto const hasNumNodes = parallelConfig.getNumNodes().has_value();
        if (hasNumNodes || static_cast<SizeType32>(deviceIds.size()) != tp * pp * cp)
        {
            auto const numNodes = hasNumNodes ? parallelConfig.getNumNodes().value() : tensorrt_llm::mpi::getNumNodes();
            TLLM_CHECK_WITH_INFO(static_cast<SizeType32>(deviceIds.size() * numNodes) == tp * pp * cp,
                tensorrt_llm::common::fmtstr("When specifying deviceIds, deviceIds (%lu) * numNodes (%u) must be equal "
                                             "to tp*pp*cp (tp is %u, pp is %u, cp is %u)",
                    deviceIds.size(), numNodes, tp, pp, cp));
        }
    }

    // Bool that indicates if current process is worker for this model or not
    auto participantIt = std::find(participantIds.begin(), participantIds.end(), mWorldRank);
    mIsWorker = participantIt != participantIds.end();
    // Bool that indicates if current rank is leader for this model
    mIsLeader = (mWorldRank == participantIds.front());
    mIsPipelineLeader = (mWorldRank == participantIds[tp * (pp - 1)]);

#if ENABLE_MULTI_DEVICE
    if (mIsWorker)
    {
        // Create a session, but only assign to COMM_SESSION for ranks participating in this model
        MPI_Group worldGroup = MPI_GROUP_NULL;
        MPICHECK(MPI_Comm_group(MPI_COMM_WORLD, &worldGroup)); // NOLINT
        MPI_Group sessionGroup = MPI_GROUP_NULL;
        if (pp > 1)
        {
            // reverse participantIds to move leader to last pp rank. retain order in each tp group
            std::reverse(participantIds.begin(), participantIds.end());
            if (tp > 1)
            {
                for (SizeType32 ppRank = 0; ppRank < pp; ppRank++)
                {
                    std::reverse(participantIds.begin() + ppRank * tp, participantIds.begin() + (ppRank + 1) * tp);
                }
            }
        }
        MPICHECK(MPI_Group_incl(worldGroup, participantIds.size(), participantIds.data(), &sessionGroup)); // NOLINT
        MPI_Comm sessionComm = MPI_COMM_NULL;
        MPICHECK(
            MPI_Comm_create_group(MPI_COMM_WORLD, sessionGroup, 1000 + participantIds.front(), &sessionComm)); // NOLINT
        tensorrt_llm::mpi::MpiComm::setSession(tensorrt_llm::mpi::MpiComm(sessionComm, false));
    }

    if (mIsLeader && mCommMode == CommunicationMode::kORCHESTRATOR)
    {
        auto optOrchestratorConfig = parallelConfig.getOrchestratorConfig();
        if (orchestratorConfig.has_value() && orchestratorConfig.value().getSpawnProcesses())
        {
            mOrchLeaderComm = optOrchestratorConfig.value().getOrchLeaderComm();
        }
        else
        {
            // mOrchLeaderComm has already been created
        }
        TLLM_CHECK(mOrchLeaderComm.get() != nullptr);
        TLLM_CHECK(worldConfig.has_value() || decoderGptJsonConfig.has_value());
        if (worldConfig.has_value())
        {
            mDeviceId = worldConfig->getDevice();
        }
        else
        {
            auto gpusPerNode = decoderGptJsonConfig->getGpusPerNode();
            auto worldConfig = runtime::WorldConfig::mpi(gpusPerNode, tp, pp, cp, parallelConfig.getDeviceIds());
            mDeviceId = worldConfig.getDevice();
        }

        // Spawn the thread responsible for receiving new requests from the orchestrator
        mLeaderRecvReqThread = std::thread(&Impl::leaderRecvReqThread, this);
        // Spawn the thread responsible for sending new responses to the orchestrator
        mLeaderSendThread = std::thread(
            [&]() { this->leaderSendThread(mSendQueue, mpi::MpiTag::kOrchestratorId, mpi::MpiTag::kOrchestratorData); });
    }
#endif // ENABLE_MULTI_DEVICE
}
void Executor::Impl::initializeLogitsPostProcessorBatched(LogitsPostProcessorConfig const& logitsProcConfig)
{
    if (logitsProcConfig.getProcessorBatched().has_value())
    {
        mLogitsPostProcessorBatched = [cb = logitsProcConfig.getProcessorBatched().value()](
                                          std::vector<IdType> const& reqIdsVec, std::vector<RtTensorPtr>& logitsVec,
                                          std::vector<std::reference_wrapper<BeamTokens const>> const& beamTokensVec,
                                          CudaStreamPtr const& cudaStreamPtr,
                                          std::vector<std::optional<IdType>> const& clientIdsVec)
        {
            std::vector<Tensor> cbLogitsVec;
            cbLogitsVec.reserve(logitsVec.size());
            for (auto& logits : logitsVec)
            {
                cbLogitsVec.emplace_back(executor::detail::ofITensor(logits));
            }
            cb(reqIdsVec, cbLogitsVec, beamTokensVec, cudaStreamPtr, clientIdsVec);
        };
        mModel->setLogitsPostProcessorBatched(mLogitsPostProcessorBatched);
    }
}

IdType Executor::Impl::enqueueRequest(Request const& request)
{
    return enqueueRequests({&request, 1}).at(0);
}

std::vector<IdType> Executor::Impl::enqueueRequests(std::vector<Request> const& requests)
{
    return enqueueRequests({requests.data(), requests.size()});
}

std::vector<IdType> Executor::Impl::enqueueRequests(common::ArrayView<Request const> const& requests)
{
    TLLM_CHECK_WITH_INFO(!mShutdownCalled, "Shutdown called, cannot enqueue requests");
    checkParallelApiUsage(__func__);

    TLLM_LOG_DEBUG("Enqueuing %lu requests", requests.size());
    std::vector<RequestWithId> requestWithIds;
    requestWithIds.reserve(requests.size());

    // First check the validity of each request in the enqueue thread, so exceptions can be thrown to the user.
    for (auto const& req : requests)
    {
        auto logitsPostProcessorName = req.getLogitsPostProcessorName();
        if (logitsPostProcessorName && logitsPostProcessorName.value() != Request::kBatchedPostProcessorName)
        {
            getLogitsPostProcessor(*logitsPostProcessorName);
        }
    }

    std::vector<IdType> ids;
    {
        auto now = std::chrono::steady_clock::now();
        for (auto const& req : requests)
        {
            ids.emplace_back(generateReqId());
            TLLM_LOG_DEBUG("Enqueue new request with id %lu", ids.back());
            std::vector<IdType> childReqIds;
            auto numChildRequests = getNumChildRequests(req);
            if (numChildRequests > 0)
            {
                childReqIds.reserve(numChildRequests);
                for (int childId = 0; childId < numChildRequests; childId++)
                {
                    childReqIds.emplace_back(generateReqId());
                    TLLM_LOG_DEBUG("Add new child request with id %lu", childReqIds.back());
                }
            }
            requestWithIds.emplace_back(RequestWithId{req, ids.back(), std::move(childReqIds), now});
        }
    }

    if (mCommMode == CommunicationMode::kLEADER)
    {
        {
            std::scoped_lock const lck(mQueuedReqMtx);
            if (mMaxQueueSize)
            {
                auto const maxQueueSize = mMaxQueueSize.value();
                auto totalRequestSize = 0;
                for (auto&& reqWithId : requestWithIds)
                {
                    totalRequestSize += (getNumChildRequests(reqWithId.req) + 1);
                }
                if (maxQueueSize > 0 && mQueuedRequests.size() + totalRequestSize > static_cast<size_t>(maxQueueSize))
                {
                    TLLM_THROW("Maximum queue size of %d has been reached, please try again later", maxQueueSize);
                }
            }
            for (auto&& req : requestWithIds)
            {
                insertRequestInOrder(mQueuedRequests, std::move(req));
            }
        }
        mQueuedReqCv.notify_one();
    }
    else if (mCommMode == CommunicationMode::kORCHESTRATOR)
    {
        MpiMessage message(MpiId::PENDING_REQUEST);
        message.data = PendingRequestData{std::move(requestWithIds)};
        mSendQueue.push(std::move(message));
    }

    return ids;
}
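// Response retrieval: responses are buffered per request id under mResponsesMtx; once the final
// response for an id has been handed out, the id moves to mTerminatedReqIds so later waits on it
// fail fast instead of blocking.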
std::vector<Response> Executor::Impl::awaitResponses(std::optional<std::chrono::milliseconds> const& timeout)
{
    TLLM_CHECK_WITH_INFO(!mShutdownCalled, "Shutdown called");
    checkParallelApiUsage(__func__);

    std::unique_lock lck(mResponsesMtx);
    auto pred = [this]() -> bool { return !mResponses.empty() || mShutdown; };
    auto storeResponses = [this]()
    {
        std::vector<Response> responses;
        for (auto it = mResponses.begin(); it != mResponses.end();)
        {
            responses.insert(responses.end(), it->second.begin(), it->second.end());
            addTerminatedReqId(it->second, it->first);
            it = mResponses.erase(it);
        }
        return responses;
    };

    std::vector<Response> responses;
    if (timeout)
    {
        if (mResponsesCv.wait_for(lck, timeout.value(), pred))
        {
            responses = storeResponses();
        }
    }
    else
    {
        mResponsesCv.wait(lck, pred);
        responses = storeResponses();
    }
    return responses;
}

std::vector<Response> Executor::Impl::awaitResponses(
    IdType const& reqId, std::optional<std::chrono::milliseconds> const& timeout)
{
    TLLM_CHECK_WITH_INFO(!mShutdownCalled, "Shutdown called");
    checkParallelApiUsage(__func__);

    std::unique_lock lck(mResponsesMtx);
    auto pred = [this, reqId]() -> bool
    { return (mResponses.find(reqId) != mResponses.end() && !mResponses.at(reqId).empty()) || mShutdown; };
    auto storeIdResponse = [this, reqId]()
    {
        std::vector<Response> responses;
        responses.swap(mResponses.at(reqId));
        mResponses.erase(reqId);
        addTerminatedReqId(responses, reqId);
        return responses;
    };

    // We don't process a terminated request again. Terminated request is defined as a response
    // with isFinal = true for a given requestId.
    if (mTerminatedReqIds.contains(reqId))
    {
        if (mResponses.find(reqId) != mResponses.end())
        {
            TLLM_THROW("ReqId should already be removed from responses!");
        }
        std::string const err = "ReqId " + std::to_string(reqId) + " has already been processed and was terminated.";
        TLLM_LOG_ERROR("%s", err.c_str());
        return {Response(reqId, err)};
    }

    std::vector<Response> responses;
    if (timeout)
    {
        if (mResponsesCv.wait_for(lck, timeout.value(), pred))
        {
            responses = storeIdResponse();
        }
    }
    else
    {
        mResponsesCv.wait(lck, pred);
        responses = storeIdResponse();
    }
    return responses;
}

std::vector<std::vector<Response>> Executor::Impl::awaitResponses(
    std::vector<IdType> const& requestIds, std::optional<std::chrono::milliseconds> const& timeout)
{
    TLLM_CHECK_WITH_INFO(!mShutdownCalled, "Shutdown called");
    checkParallelApiUsage(__func__);

    std::vector<std::vector<Response>> responses;
    responses.reserve(requestIds.size());
    if (timeout)
    {
        auto const start_time = std::chrono::high_resolution_clock::now();
        for (auto const requestId : requestIds)
        {
            // Budget each per-id wait with whatever remains of the overall timeout
            auto const elapsed_ms = std::chrono::duration_cast<std::chrono::milliseconds>(
                std::chrono::high_resolution_clock::now() - start_time);
            responses.emplace_back(awaitResponses(requestId,
                timeout.value() > elapsed_ms ? timeout.value() - elapsed_ms : std::chrono::milliseconds{0}));
        }
    }
    else
    {
        for (auto const requestId : requestIds)
        {
            responses.emplace_back(awaitResponses(requestId));
        }
    }
    return responses;
}
SizeType32 Executor::Impl::getNumResponsesReady(std::optional<IdType> const& optId) const
{
    TLLM_CHECK_WITH_INFO(!mShutdownCalled, "Shutdown called");
    checkParallelApiUsage(__func__);

    std::scoped_lock lck(mResponsesMtx);
    SizeType32 numResponsesReady = 0;
    if (optId)
    {
        auto const reqId = optId.value();
        auto const respIt = mResponses.find(reqId);
        if (respIt != mResponses.end())
        {
            numResponsesReady = static_cast<SizeType32>(respIt->second.size());
        }
    }
    else
    {
        for (auto const& [id, responses] : mResponses)
        {
            numResponsesReady += static_cast<SizeType32>(responses.size());
        }
    }
    return numResponsesReady;
}

void Executor::Impl::shutdown()
{
    // Cannot call shutdown multiple times
    if (mShutdownCalled)
    {
        return;
    }
    mShutdownCalled = true;

    if (!mShutdown)
    {
        if (mCommMode == CommunicationMode::kLEADER && mIsLeader)
        {
            // Enqueue a request to indicate to other ranks to terminate
            enqueueTerminateRequest();
        }
        else if (mCommMode == CommunicationMode::kORCHESTRATOR)
        {
            if (mIsOrchestrator)
            {
                // Send to the leader the termination signal
                mShutdown = true;
                mResponsesCv.notify_all();
                mSendQueue.push(MpiMessage(MpiId::TERMINATION));
                // Wait for sender thread to exit
                if (mOrchSendReqThread.joinable())
                {
                    mOrchSendReqThread.join();
                }
                // Wait for recv response thread to exit
                if (mOrchRecvThread.joinable())
                {
                    mOrchRecvThread.join();
                }
            }
            else if (mIsLeader)
            {
                // Wait for sender thread to exit
                if (mLeaderRecvReqThread.joinable())
                {
                    mLeaderRecvReqThread.join();
                }
                // Wait for send response thread to exit
                if (mLeaderSendThread.joinable())
                {
                    mLeaderSendThread.join();
                }
            }
        }
    }

    // Wait for execution thread to terminate
    if (mExecutionThread.joinable())
    {
        mExecutionThread.join();
    }

    // If we overwrote COMM_SESSION with split, free it now. Otherwise, since
    // COMM_SESSION is a global static object, it will be destroyed in an
    // undefined order and can cause crashes on program exit.
    if (mIsWorker)
    {
        tensorrt_llm::mpi::MpiComm::setSession(tensorrt_llm::mpi::MpiComm(MPI_COMM_WORLD, false));
    }
}
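// Cancellation is funnelled through the leader: with pipeline parallelism the id is parked in
// mPipelineCancelledReqIds (drained by cancelledRequestsLeaderThread), otherwise the execution
// loop picks it up from mCancelledReqIds; in orchestrator mode the id is relayed over MPI instead.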
void Executor::Impl::cancelRequest(IdType requestId)
{
    TLLM_CHECK_WITH_INFO(!mShutdownCalled, "Shutdown called");
    checkParallelApiUsage(__func__);

    // Check if the request is terminated already. If so, return
    {
        std::scoped_lock lckResp(mResponsesMtx);
        if (mTerminatedReqIds.contains(requestId))
        {
            TLLM_LOG_INFO("Ignoring already terminated request %lu", requestId);
            return;
        }
    }

    if (mCommMode == CommunicationMode::kLEADER)
    {
        std::scoped_lock lck(mCancelReqMtx);
        auto& selCancelledReqIds = mUsePipelineParallel ? mPipelineCancelledReqIds : mCancelledReqIds;
        selCancelledReqIds.insert(requestId);
    }
    else if (mCommMode == CommunicationMode::kORCHESTRATOR)
    {
        MpiMessage message(MpiId::CANCEL_REQUEST);
        std::vector<IdType> cancelledReqIds{requestId};
        message.data = RequestIdsData{std::move(cancelledReqIds)};
        mSendQueue.push(std::move(message));
    }
}

std::deque<IterationStats> Executor::Impl::getLatestIterationStats()
{
    TLLM_CHECK_WITH_INFO(!mShutdownCalled, "Shutdown called");
    checkParallelApiUsage(__func__);
    std::scoped_lock lck(mIterStatsMtx);
    return std::exchange(mIterationStats, {});
}

std::deque<RequestStatsPerIteration> Executor::Impl::getLatestRequestStats()
{
    TLLM_CHECK_WITH_INFO(!mShutdownCalled, "Shutdown called");
    checkParallelApiUsage(__func__);
    std::scoped_lock lck(mRequestStatsMtx);
    return std::exchange(mRequestStats, {});
}

std::deque<DebugTensorsPerIteration> Executor::Impl::getLatestDebugTensors()
{
    TLLM_CHECK_WITH_INFO(!mShutdownCalled, "Shutdown called");
    if (mCommMode == CommunicationMode::kORCHESTRATOR)
    {
        TLLM_LOG_WARNING("getLatestDebugTensors is not supported in ORCHESTRATOR mode yet");
        return {};
    }
    if (mEncoderModel)
    {
        TLLM_LOG_WARNING("getLatestDebugTensors is not supported for encoder model yet");
    }
    std::scoped_lock lck(mDebugTensorsMtx);
    return std::exchange(mDebugTensors, {});
}

bool Executor::Impl::canEnqueueRequests() const
{
    return !mShutdownCalled
        && ((mCommMode == CommunicationMode::kLEADER && mIsLeader)
            || (mCommMode == CommunicationMode::kORCHESTRATOR && mIsOrchestrator));
}

bool Executor::Impl::isParticipant() const
{
    return mIsWorker;
}

std::optional<std::shared_ptr<KVCacheEventManager>> Executor::Impl::getKVCacheEventManager() const
{
    if (!mModel)
    {
        return std::nullopt;
    }
    auto cacheEventManager = mModel->getKVCacheManager();
    return cacheEventManager ? std::optional(std::make_shared<KVCacheEventManager>(cacheEventManager)) : std::nullopt;
}

void Executor::Impl::requestWithIdLeaderThread()
{
    TLLM_CUDA_CHECK(cudaSetDevice(mModel->getWorldConfig().getDevice()));
    auto constexpr peer = 0;
    while (true)
    {
        int64_t numActiveRequests;
        mCommPipelineParallel->recv(
            &numActiveRequests, 1, mpi::MpiType::kINT64, peer, mpi::MpiTag::kExecutorNumActiveRequests);
        if (numActiveRequests < 0)
        {
            break;
        }
        bool lowestPriorityActiveHasValue;
        std::optional<PriorityType> lowestPriorityActive;
        mCommPipelineParallel->recv(&lowestPriorityActiveHasValue, 1, mpi::MpiType::kBOOL, peer,
            mpi::MpiTag::kExecutorLowestPriorityActiveHasValue);
        if (lowestPriorityActiveHasValue)
        {
            PriorityType lowestPriorityActiveValue;
            mCommPipelineParallel->recv(
                &lowestPriorityActiveValue, 1, mpi::MpiType::kFLOAT, peer, mpi::MpiTag::kExecutorLowestPriorityActive);
            lowestPriorityActive = lowestPriorityActiveValue;
        }
        auto reqWithIds = getLeaderNewReqWithIds(numActiveRequests, lowestPriorityActive);
        setupDynamicLogitsPostProcessors(reqWithIds);
        auto requestWithIdAsyncSndHdl
            = std::make_unique<RequestWithIdAsyncSend>(mCommPipelineParallel, reqWithIds, peer);
        requestWithIdAsyncSndHdl.reset(nullptr);
    }
}

void Executor::Impl::cancelledRequestsLeaderThread()
{
    TLLM_CUDA_CHECK(cudaSetDevice(mModel->getWorldConfig().getDevice()));
    auto constexpr peer = 0;
    while (true)
    {
        bool shouldExit;
        mCommPipelineParallel->recv(&shouldExit, 1, mpi::MpiType::kBOOL, peer, mpi::MpiTag::kExecutorShouldExit);
        if (shouldExit)
        {
            break;
        }
        std::unique_ptr<CancelledRequestsAsyncSend> cancelledRequestsAsyncSndHdl;
        {
            std::scoped_lock lck(mCancelReqMtx);
            cancelledRequestsAsyncSndHdl
                = std::make_unique<CancelledRequestsAsyncSend>(mCommPipelineParallel, mPipelineCancelledReqIds, peer);
            mPipelineCancelledReqIds.clear();
        }
        cancelledRequestsAsyncSndHdl.reset(nullptr);
    }
}

std::vector<RequestWithId> Executor::Impl::getLeaderNewReqWithIds(
    SizeType32 numActiveRequests, std::optional<PriorityType> lowestPriorityActive)
{
    std::unique_lock lck(mQueuedReqMtx);
    mQueuedReqCv.wait(lck, [&]() { return (!mQueuedRequests.empty() || numActiveRequests > 0 || mShutdown); });

    std::vector<RequestWithId> reqWithIds;
    if (mQueuedRequests.empty() || mShutdown)
    {
        return reqWithIds;
    }

    if (mQueuedRequests.front().id == mTerminateReqId)
    {
        reqWithIds.emplace_back(std::move(mQueuedRequests.front()));
        mQueuedRequests.pop_front();
        return reqWithIds;
    }

    auto const& firstRequest = mQueuedRequests.front();
    auto const firstBeamWidth = firstRequest.req.getSamplingConfig().getBeamWidth();
    auto const operatingBeamWidth = numActiveRequests > 0 ? mModel->getOperatingBeamWidth() : firstBeamWidth;

    auto const tryInsertQueuedRequestIntoReqWithIds = [this, &reqWithIds, operatingBeamWidth]() -> bool
    {
        auto& nextRequest = mQueuedRequests.front();
        auto const beamWidth = nextRequest.req.getSamplingConfig().getBeamWidth();
        if (beamWidth != operatingBeamWidth)
        {
            TLLM_LOG_INFO(
                "Can't dequeue request with ID %ld because beam width %d differs from operating beam width %d.",
                nextRequest.id, beamWidth, operatingBeamWidth);
            return false;
        }
        TLLM_LOG_DEBUG("Dequeue request with ID %ld", nextRequest.id);
        reqWithIds.emplace_back(std::move(nextRequest));
        mQueuedRequests.pop_front();
        return true;
    };

    auto const maxNewRequests = static_cast<size_t>(std::max(mMaxNumActiveRequests - numActiveRequests, 0));
    for (size_t req = 0; !mQueuedRequests.empty() && req < maxNewRequests;)
    {
        // Account for child requests so that a parent plus its children never exceed the budget
        req += (getNumChildRequests(mQueuedRequests.front().req) + 1);
        if (req > maxNewRequests)
        {
            break;
        }
        if (!tryInsertQueuedRequestIntoReqWithIds())
        {
            break;
        }
    }

    if (lowestPriorityActive)
    {
        // Also dequeue queued requests with a higher priority than the lowest-priority active request
        while (!mQueuedRequests.empty() && mQueuedRequests.front().req.getPriority() > (*lowestPriorityActive))
        {
            if (!tryInsertQueuedRequestIntoReqWithIds())
            {
                break;
            }
        }
    }

    return reqWithIds;
}
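// Fans the leader's freshly dequeued requests out to every rank: the pipeline leader exchanges
// them with the last pipeline stage, each stage forwards them downstream, and TP/CP ranks receive
// the serialized batch via broadcast.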
std::vector<RequestWithId> Executor::Impl::getNewReqWithIds(
    SizeType32 numActiveRequests, std::optional<PriorityType> lowestPriorityActive)
{
    TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
    auto const& worldConfig = mModel->getWorldConfig();
    if (worldConfig.isPipelineParallel())
    {
        mRequestWithIdWaitThread->waitStop();
    }
    TLLM_CUDA_CHECK(cudaSetDevice(mModel->getWorldConfig().getDevice()));
    std::vector<RequestWithId> reqWithIds;
    if (mIsPipelineLeader)
    {
        if (!worldConfig.isPipelineParallel())
        {
            reqWithIds = getLeaderNewReqWithIds(numActiveRequests, lowestPriorityActive);
            setupDynamicLogitsPostProcessors(reqWithIds);
        }
        else
        {
            auto const peer = worldConfig.getPipelineParallelism() - 1;
            auto numActiveRequestsValue = static_cast<int64_t>(numActiveRequests);
            auto request1 = mCommPipelineParallel->sendAsync(
                &numActiveRequestsValue, 1, mpi::MpiType::kINT64, peer, mpi::MpiTag::kExecutorNumActiveRequests);
            bool lowestPriorityActiveHasValue = lowestPriorityActive.has_value();
            auto request2 = mCommPipelineParallel->sendAsync(&lowestPriorityActiveHasValue, 1, mpi::MpiType::kBOOL,
                peer, mpi::MpiTag::kExecutorLowestPriorityActiveHasValue);
            auto request3 = lowestPriorityActiveHasValue
                ? mCommPipelineParallel->sendAsync(&lowestPriorityActive.value(), 1, mpi::MpiType::kFLOAT, peer,
                    mpi::MpiTag::kExecutorLowestPriorityActive)
                : nullptr;
            request1->wait();
            request2->wait();
            if (request3)
            {
                request3->wait();
            }
            reqWithIds = RequestWithIdAsyncSend::requestWithIdRecv(mCommPipelineParallel, peer);
        }
        if (worldConfig.isTensorParallel() || worldConfig.isContextParallel())
        {
            auto packed = RequestWithId::serializeReqWithIds(reqWithIds);
            if (worldConfig.isTensorParallel())
            {
                mCommTensorParallel->bcast(packed, 0);
            }
            if (worldConfig.isContextParallel())
            {
                mCommContextParallel->bcast(packed, 0);
            }
        }
    }
    else
    {
        if (worldConfig.isFirstPipelineParallelRank())
        {
            std::vector<char> buffer;
            mCommTensorParallel->bcast(buffer, 0);
            mCommContextParallel->bcast(buffer, 0);
            reqWithIds = RequestWithId::deserializeReqWithIds(buffer);
        }
        else
        {
            auto const peer = worldConfig.getPipelineParallelRank() - 1;
            reqWithIds = RequestWithIdAsyncSend::requestWithIdRecv(mCommPipelineParallel, peer);
        }
    }
    if (!worldConfig.isLastPipelineParallelRank())
    {
        auto const peer = worldConfig.getPipelineParallelRank() + 1;
        mRequestWithIdAsyncSndHdl = std::make_unique<RequestWithIdAsyncSend>(mCommPipelineParallel, reqWithIds, peer);
        mRequestWithIdWaitThread->notifyStart();
    }
    TLLM_CUDA_CHECK(cudaSetDevice(mModel->getWorldConfig().getDevice()));
    TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
    return reqWithIds;
}

std::tuple<RequestList, double> Executor::Impl::fetchNewRequests(
    SizeType32 numActiveRequests, std::optional<PriorityType> lowestPriorityActive)
{
    TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
    NVTX3_SCOPED_RANGE(fetchNewRequests);

    // If grab requests from queue, do exchange between ranks
    auto reqWithIds = getNewReqWithIds(numActiveRequests, lowestPriorityActive);

    RequestList newRequests;
    double newActiveRequestsQueueLatencyMS{0.};
    for (auto& reqWithId : reqWithIds)
    {
        if (reqWithId.id == mTerminateReqId)
        {
            mShutdown = true;
            mResponsesCv.notify_all();
            return {};
        }
        try
        {
            std::optional<batch_manager::LlmRequest::LogitsPostProcessor> llmRequestLogitsPostProcessor;
            bool applyLogitsPostProcessorBatched{false};
            if (mModel->getWorldConfig().isLastPipelineParallelRank())
            {
                auto logitsPostProcessorName = reqWithId.req.getLogitsPostProcessorName();
                if (logitsPostProcessorName)
                {
                    if (logitsPostProcessorName.value() == Request::kBatchedPostProcessorName)
                    {
                        TLLM_CHECK_WITH_INFO(
                            mLogitsPostProcessorBatched, "Batched logits post processor is not defined.");
                        applyLogitsPostProcessorBatched = true;
                    }
                    else
                    {
                        if (logitsPostProcessorName->compare(0,
                                std::char_traits<char>::length(Request::kDynamicPostProcessorNamePrefix),
                                Request::kDynamicPostProcessorNamePrefix)
                            == 0)
                        {
                            TLLM_CHECK_WITH_INFO(!mModel->getReplicateLogitsPostProcessor()
                                    || mModel->getWorldConfig().getTensorParallelism() == 1,
                                "Dynamic logits postprocessor must be used with replicate=false or no tensor "
                                "parallelism.");
                        }
                        if (mModel->getWorldConfig().isFirstTensorParallelRank()
                            || mModel->getReplicateLogitsPostProcessor())
                        {
                            llmRequestLogitsPostProcessor = getLogitsPostProcessor(logitsPostProcessorName.value());
                        }
                        else
                        {
                            // Non-replicating, non-first TP ranks install a no-op post-processor
                            llmRequestLogitsPostProcessor
                                = [](IdType reqId, RtTensorPtr& logits, BeamTokens const& beamTokens,
                                      CudaStreamPtr const& cudaStreamPtr, std::optional<IdType> clientId) {};
                        }
                    }
                }
            }
            auto newLlmReq = std::make_shared<batch_manager::LlmRequest>(
                reqWithId.id, reqWithId.req, llmRequestLogitsPostProcessor, applyLogitsPostProcessorBatched);
            auto numReturnSequences = newLlmReq->getNumSubRequests();
            if (numReturnSequences > 1)
            {
                TLLM_CHECK(reqWithId.childReqIds.size() == static_cast<size_t>(numReturnSequences - 1));
                mChildReqIdsMap[reqWithId.id] = reqWithId.childReqIds;
            }
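            // Expand the request into the parent plus (numReturnSequences - 1) child requests; all
            // of them go through the same validation and output-buffer allocation below.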
            for (auto seqIdx = 0; seqIdx < numReturnSequences; seqIdx++)
            {
                auto newReq
                    = seqIdx == 0 ? newLlmReq : newLlmReq->createChildRequest(reqWithId.childReqIds.at(seqIdx - 1));

                // If static batching and streaming, disable streaming and exclude input
                if (mBatchingType == BatchingType::kSTATIC && newReq->isStreaming())
                {
                    newReq->setStreaming(false);
                    newReq->setExcludeInputFromOutput(true);
                }

                // Validate the request parameters
                newReq->validate(mModel->getMaxInputLen(), mModel->getMaxSequenceLen(), mModel->getMaxDraftLen(),
                    mModel->getVocabSizePadded(),
                    mEncoderModel ? std::optional(mEncoderModel->getMaxInputLen()) : std::nullopt, mEnableBlockReuse);

                TLLM_CHECK_WITH_INFO(!mEncoderModel || !mIsSchedulerMaxUtilization,
                    "Encoder or Encoder-Decoder model don't support max utilization scheduler yet. Only max requests "
                    "or guaranteed no evict.");

                // When streaming is enabled and scheduling policy permits evict/restart, need to guard against the
                // case where the sequence is truncated on eviction (to respect maxInputLen limits), resulting in loss
                // of some tokens that have been streamed out. In this case, resuming generation may result in
                // different completion for locations whose tokens have already been returned. There is no way to
                // protect against this, so disallowing.
                if (newReq->isStreaming() && !mIsSchedulerGuaranteedNoEvict && !mIsChunkedContext)
                {
                    auto const maxReqSeqLen = newReq->mPromptLen + newReq->mMaxNewTokens;
                    auto const maxRestartLen = maxReqSeqLen - 1;
                    TLLM_CHECK_WITH_INFO(maxRestartLen <= mModel->getMaxInputLen(),
                        "Request sequence length is potentially greater than max input length. This cannot be run "
                        "unless streaming is disabled, context chunking is enabled or the GUARANTEED_NO_EVICT "
                        "scheduling policy is used");
                }

                // Create the encoder output tensor
                if (mEncoderModel)
                {
                    TLLM_CHECK_WITH_INFO(mModel || (!mModel && newReq->getReturnEncoderOutput()),
                        "Encoder-Decoder models allow optionally returning encoder output. But if it is Encoder-only "
                        "models, please make sure returnEncoderOutput is always true.");
                    // gpu buffers for passing to the next phase
                    newReq->allocEncoderOutput(mEncoderModel->getBufferManager(), mEncoderModel->getLogitDataType());
                    newReq->allocEncoderHiddenStates(
                        mEncoderModel->getBufferManager(), mEncoderModel->getLogitDataType());
                    // pinned buffers for returning results to host
                    if (newReq->getReturnEncoderOutput())
                    {
                        newReq->allocEncoderOutputHost(
                            mEncoderModel->getHiddenSize() * mEncoderModel->getWorldConfig().getTensorParallelism(),
                            mEncoderModel->getLogitDataType());
                    }
                }

                if (!mEncoderModel && newReq->getEncoderInputFeatures())
                {
                    TLLM_LOG_INFO("Allocating buffers for encoder output");
                    // gpu buffers for passing to the next phase
                    newReq->allocEncoderOutput(mModel->getBufferManager(), mModel->getLogitDataType());
                    newReq->allocEncoderHiddenStates(mModel->getBufferManager(), mModel->getLogitDataType());
                }

                // Create the context logits tensor
                if (newReq->getReturnContextLogits())
                {
                    TLLM_CHECK_WITH_INFO(mModel->getModelConfig().computeContextLogits(),
                        "Return context logit need to build engine with gather_context_logits");
                    newReq->allocContextLogitsHost(mModel->getVocabSizePadded(), mModel->getLogitDataType());
                }

                // Create the generation logits tensor
                if (newReq->getReturnGenerationLogits())
                {
                    TLLM_CHECK_WITH_INFO(mModel->getGatherGenerationLogits(),
                        "To return generation logits, gather_generation_logits must be enabled in ExecutorConfig");
                    if (mModel->getModelConfig().getSpeculativeDecodingMode().isDraftTokensExternal()
                        && newReq->hasDraftTokens())
                    {
                        newReq->allocTargetModelAcceptedTokenLogitsHost(
                            mModel->getVocabSizePadded(), mModel->getLogitDataType());
                    }
                    else
                    {
                        newReq->allocGenerationLogitsHost(mModel->getVocabSizePadded(), mModel->getLogitDataType());
                    }
                }

                if (mModel->getWorldConfig().isLastPipelineParallelRank() && newReq->getGuidedDecodingParams())
                {
                    TLLM_CHECK_WITH_INFO(mModel->hasGuidedDecoder(),
                        "Request is specified with GuidedDecodingParams, but GuidedDecoder is not setup. Please "
                        "provide a valid GuidedDecodingConfig to setup GuidedDecoder.");
                }

                if (mModel->getWorldConfig().isLastPipelineParallelRank() && newReq->hasAdditionalOutputs())
                {
                    newReq->allocAdditionalOutputs(
                        [this](std::string const& name) { return mModel->getTensorDataType(name); },
                        [this](std::string const& name) { return mModel->getTensorShape(name); });
                }

                mModel->updatePeftCache(newReq);
                newRequests.emplace_back(std::move(newReq));
            }
            auto queuedEnd = std::chrono::steady_clock::now();
            auto reqQueueLatencyMS
                = std::chrono::duration<double, std::milli>(queuedEnd - reqWithId.queuedStart).count();
            newActiveRequestsQueueLatencyMS += reqQueueLatencyMS;
        }
        catch (runtime::LoraExpectedException const& e)
        {
            if (mIsLeader)
            {
                // In case of an expected LoRA exception (e.g. cache full, cache miss), log a warning and enqueue
                // response
                TLLM_LOG_WARNING("%s", e.what());
                enqueueNewResponses({{reqWithId.id, e.what(), reqWithId.req.getClientId()}});
            }
        }
        catch (std::exception const& e)
        {
            if (mIsLeader)
            {
                // In case of error, create a response with error for this request
                auto err = std::string("Encountered an error when fetching new request: ") + e.what();
                TLLM_LOG_ERROR("%s", err.c_str());
                enqueueNewResponses({{reqWithId.id, err, reqWithId.req.getClientId()}});
            }
        }
    }
    TLLM_LOG_DEBUG("[RANK %d] num new requests fetched from queue: %lu", COMM_SESSION.getRank(), newRequests.size());

    TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
    return {newRequests, newActiveRequestsQueueLatencyMS};
}

void Executor::Impl::terminateActiveRequests(RequestList& activeRequests, std::string const& err)
{
    TLLM_LOG_ERROR("%s", err.c_str());

    // Create a response for all requests and add to queue
    for (auto it = activeRequests.cbegin(); it != activeRequests.cend();)
    {
        auto llmReq = (*it);
        llmReq->setState(batch_manager::LlmRequestState::kGENERATION_COMPLETE);
        mModel->terminateRequest(llmReq);
        if (mIsLeader)
        {
            enqueueNewResponses({{llmReq->mRequestId, err, llmReq->mClientId}});
        }
        // Remove from the requestList
        it = activeRequests.erase(it);
    }
}

void Executor::Impl::forwardSync(RequestList& activeRequests)
{
    TLLM_LOG_TRACE("[RANK %d] %s start", COMM_SESSION.getRank(), __PRETTY_FUNCTION__);

    try
    {
        if (mEncoderModel)
        {
            mEncoderModel->forwardSync();
        }
        mModel->forwardSync();
    }
    catch (std::exception const& e)
    {
        std::string const err = std::string("Encountered an error in forwardSync function: ") + e.what();
        terminateActiveRequests(activeRequests, err);
    }

    TLLM_LOG_TRACE("[RANK %d] %s stop", COMM_SESSION.getRank(), __PRETTY_FUNCTION__);
}

// The function is used to change the state of a request to context_init from encoder_init for enc-dec model whose
// encoder is skipped. The encoder output is populated accordingly with input features given through model executor of
// decoder.
void Executor::Impl::prepRequestsForEncoderSkip(RequestList& activeRequests)
{
    for (auto& req : activeRequests)
    {
        if (req->isEncoderInitState() && req->getEncoderInputFeatures())
        {
            TLLM_LOG_INFO("Changing state of request and setting encoder output to skip encoder run");
            req->setState(batch_manager::LlmRequestState::kCONTEXT_INIT);
            req->setEncoderOutput(req->getEncoderInputFeatures());
        }
    }
}
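// Timed-out requests reuse the leader's cancellation path rather than cancelRequest(), which
// would throw on non-orchestrator ranks when mCommMode == CommunicationMode::kORCHESTRATOR.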
void Executor::Impl::finishTimedOutRequests(RequestList const& activeRequests)
{
    if (mIsLeader)
    {
        for (auto const& request : activeRequests)
        {
            if (request->isTimedOut() && !request->isFinished())
            {
                // Workaround for cancelRequest, since it throws an error if
                // mCommMode == CommunicationMode::kORCHESTRATOR && !mIsOrchestrator
                {
                    std::scoped_lock lck(mCancelReqMtx);
                    auto& selCancelledReqIds = mUsePipelineParallel ? mPipelineCancelledReqIds : mCancelledReqIds;
                    selCancelledReqIds.insert(request->mRequestId);
                }
            }
        }
    }
}

void Executor::Impl::forwardAsync(RequestList& activeRequests)
{
    try
    {
        TLLM_LOG_DEBUG("num active requests in scope: %zu", activeRequests.size());

        if (mDynamicBatchTuner)
        {
            auto const averageInputLength = static_cast<SizeType32>(mDynamicBatchTuner->getAverageInputLength());
            auto const averageOutputLength = static_cast<SizeType32>(mDynamicBatchTuner->getAverageOutputLength());
            auto const maxCapacityBatchSize
                = mModel->getMaxCapacityBatchSize(averageInputLength, averageOutputLength);

            if (mDynamicBatchTuner->isBatchSizeTuningEnabled())
            {
                auto runtimeBatchSize = mDynamicBatchTuner->getRuntimeBatchSize(maxCapacityBatchSize);
                mModel->setRuntimeBatchSize(runtimeBatchSize);
            }
            if (mDynamicBatchTuner->isMaxNumTokensTuningEnabled())
            {
                auto runtimeBatchSize = mModel->getRuntimeBatchSize();
                auto runtimeMaxNumTokens = mDynamicBatchTuner->getRuntimeMaxNumTokens(runtimeBatchSize);
                mModel->setRuntimeMaxNumTokens(runtimeMaxNumTokens);
            }
        }

        if (mEncoderModel)
        {
            mEncoderModel->forwardAsync(activeRequests);
            // Make the decoder stream wait until the encoder stream has finished
            auto const& encoderStream = *(mEncoderModel->getRuntimeStreamPtr());
            auto const& decoderStream = *(mModel->getRuntimeStreamPtr());
            runtime::CudaEvent encoderFinished;
            encoderStream.record(encoderFinished);
            decoderStream.wait(encoderFinished);
        }
        if (!mEncoderModel)
        {
            prepRequestsForEncoderSkip(activeRequests);
        }
        mModel->forwardAsync(activeRequests);
    }
    catch (std::exception const& e)
    {
        std::string err = std::string("Encountered an error in forwardAsync function: ") + e.what();
        terminateActiveRequests(activeRequests, err);
    }
}

IterationStats Executor::Impl::getCurrentIterationStats(RequestList const& activeRequests, double iterLatencyMS,
    SizeType32 numNewActiveRequests, double newActiveRequestsQueueLatencyMS, SizeType32 numCompletedRequests)
{
    IterationStats stats;
    // Timestamp
    stats.timestamp = tensorrt_llm::common::getCurrentTimestamp();
    stats.numNewActiveRequests = numNewActiveRequests;
    stats.iterLatencyMS = iterLatencyMS;
    stats.newActiveRequestsQueueLatencyMS = newActiveRequestsQueueLatencyMS;
    // Active request count
    stats.numActiveRequests = static_cast<SizeType32>(activeRequests.size());
    // Queued request count
    {
        std::scoped_lock lck(mQueuedReqMtx);
        stats.numQueuedRequests = static_cast<SizeType32>(mQueuedRequests.size());
    }
    stats.numCompletedRequests = numCompletedRequests;
    // Max number of requests
    stats.maxNumActiveRequests = mMaxNumActiveRequests;
    // Runtime memory allocation statistics
    auto const& memoryCounters = runtime::MemoryCounters::getInstance();
    stats.gpuMemUsage = memoryCounters.getGpu();
    stats.cpuMemUsage = memoryCounters.getCpu();
    stats.pinnedMemUsage = memoryCounters.getPinned();
    // Model specific stats
    mModel->getCurrentIterationStats(stats);
    return stats;
}
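// Request stats are gathered from three populations: requests currently being processed (with live progress
// counters), requests still waiting in the queue (reported as kQUEUED with zeroed counters), and requests that
// finished this iteration (reported once as kGENERATION_COMPLETE). Context-only and generation-only requests
// additionally carry disaggregated-serving stats (KV-cache transfer time and size).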
RequestStatsPerIteration Executor::Impl::getCurrentRequestStats(
    RequestList const& activeRequests, RequestList const& finishedRequests)
{
    std::vector<RequestStats> requestStatsVec;

    auto includeDisServingStats
        = [](LlmRequestPtr const& request, tensorrt_llm::executor::RequestStats& requestStats)
    {
        auto requestType = request->getLlmRequestType();
        if (requestType == batch_manager::LlmRequestType::LLMREQUEST_TYPE_CONTEXT_ONLY
            || requestType == batch_manager::LlmRequestType::LLMREQUEST_TYPE_GENERATION_ONLY)
        {
            requestStats.disServingStats
                = executor::DisServingRequestStats{request->getKvCacheTransferTimeMS(), request->getKvCacheSize()};
        }
    };

    for (auto const& request : activeRequests)
    {
        RequestStats requestStats;
        requestStats.id = request->mRequestId;
        requestStats.stage = request->getRequestStage();
        requestStats.contextPrefillPosition = request->getContextCurrentPosition();
        requestStats.numGeneratedTokens = request->getMaxBeamNumTokens() - request->getOrigPromptLen();
        requestStats.avgNumDecodedTokensPerIter = request->getAvgDecodedTokensPerIter();
        includeDisServingStats(request, requestStats);
        requestStats.allocTotalBlocksPerRequest = request->getAllocTotalBlocksPerRequest();
        requestStats.allocNewBlocksPerRequest = request->getAllocNewBlocksPerRequest();
        requestStats.reusedBlocksPerRequest = request->getReusedBlocksPerRequest();
        requestStats.missedBlocksPerRequest = request->getMissedBlocksPerRequest();
        requestStats.kvCacheHitRatePerRequest = request->getKVCacheHitRatePerRequest();
        requestStatsVec.emplace_back(requestStats);
    }

    {
        std::unique_lock lck(mQueuedReqMtx);
        for (auto const& request : mQueuedRequests)
        {
            // Still waiting for the first scheduling
            RequestStats requestStats;
            requestStats.id = static_cast<IdType>(request.id);
            requestStats.stage = executor::RequestStage::kQUEUED;
            requestStats.contextPrefillPosition = 0;
            requestStats.numGeneratedTokens = 0;
            requestStats.avgNumDecodedTokensPerIter = 0;
            requestStats.allocTotalBlocksPerRequest = 0;
            requestStats.allocNewBlocksPerRequest = 0;
            requestStats.reusedBlocksPerRequest = 0;
            requestStats.missedBlocksPerRequest = 0;
            requestStats.kvCacheHitRatePerRequest = 0;
            requestStatsVec.emplace_back(requestStats);
        }
    }

    for (auto const& request : finishedRequests)
    {
        // Requests that completed during this iteration
        RequestStats requestStats;
        requestStats.id = static_cast<IdType>(request->mRequestId);
        requestStats.stage = executor::RequestStage::kGENERATION_COMPLETE;
        requestStats.contextPrefillPosition = request->getContextCurrentPosition();
        requestStats.numGeneratedTokens = request->getMaxBeamNumTokens() - request->getOrigPromptLen();
        requestStats.avgNumDecodedTokensPerIter = request->getAvgDecodedTokensPerIter();
        includeDisServingStats(request, requestStats);
        requestStats.allocTotalBlocksPerRequest = request->getAllocTotalBlocksPerRequest();
        requestStats.allocNewBlocksPerRequest = request->getAllocNewBlocksPerRequest();
        requestStats.reusedBlocksPerRequest = request->getReusedBlocksPerRequest();
        requestStats.missedBlocksPerRequest = request->getMissedBlocksPerRequest();
        requestStats.kvCacheHitRatePerRequest = request->getKVCacheHitRatePerRequest();
        requestStatsVec.emplace_back(requestStats);
    }

    RequestStatsPerIteration stats{0, std::move(requestStatsVec)};
    // Model specific stats
    mModel->getCurrentRequestStats(stats);
    return stats;
}

void Executor::Impl::appendCurrentIterStats(IterationStats&& currentIterStats)
{
    std::scoped_lock lck(mIterStatsMtx);
    if (mIterationStats.size() >= mIterStatsMaxIterations)
    {
        mIterationStats.pop_front();
    }
    mIterationStats.emplace_back(std::move(currentIterStats));
}

void Executor::Impl::appendMultipleIterStats(std::vector<IterationStats>&& currentIterStatsVec)
{
    std::scoped_lock lck(mIterStatsMtx);
    if (mIterationStats.size() + currentIterStatsVec.size() > mIterStatsMaxIterations)
    {
        size_t removeCount = mIterationStats.size() + currentIterStatsVec.size() - mIterStatsMaxIterations;
        for (size_t i = 0; i < removeCount; i++)
        {
            mIterationStats.pop_front();
        }
    }
    mIterationStats.insert(mIterationStats.end(), std::make_move_iterator(currentIterStatsVec.begin()),
        std::make_move_iterator(currentIterStatsVec.end()));
}
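// In ORCHESTRATOR mode, iteration stats are buffered on the leader and only flushed to the orchestrator when an
// iteration actually scheduled requests (or when a final flush is forced), which avoids sending one MPI message
// per idle iteration.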
void Executor::Impl::updateIterationStats(RequestList const& activeRequests, double iterLatencyMS,
    SizeType32 numNewActiveRequests, double newActiveRequestsQueueLatencyMS, SizeType32 numCompletedRequests,
    bool flushToOrchestrator)
{
    NVTX3_SCOPED_RANGE(updateIterationStats);
    if (mIterStatsMaxIterations > 0 && mIsLeader)
    {
        auto currentIterStats = getCurrentIterationStats(
            activeRequests, iterLatencyMS, numNewActiveRequests, newActiveRequestsQueueLatencyMS, numCompletedRequests);

        // Send the stats to the orchestrator
        if (mCommMode == CommunicationMode::kORCHESTRATOR)
        {
            bool hasSchedThisIter = (currentIterStats.inflightBatchingStats
                                        && currentIterStats.inflightBatchingStats->numScheduledRequests > 0)
                || (currentIterStats.staticBatchingStats
                    && currentIterStats.staticBatchingStats->numScheduledRequests > 0);
            appendCurrentIterStats(std::move(currentIterStats));
            if (hasSchedThisIter || flushToOrchestrator)
            {
                std::deque<IterationStats> iterStatsQueue;
                {
                    std::scoped_lock lck(mIterStatsMtx);
                    iterStatsQueue = std::exchange(mIterationStats, {});
                }
                MpiMessage message(MpiId::ITER_STATS);
                std::vector<IterationStats> iterStates(
                    std::make_move_iterator(iterStatsQueue.begin()), std::make_move_iterator(iterStatsQueue.end()));
                message.data = IterStatsData{std::move(iterStates)};
                mSendQueue.push(std::move(message));
            }
        }
        else
        {
            // Add current iteration stats
            appendCurrentIterStats(std::move(currentIterStats));
        }
    }
}

void Executor::Impl::appendCurrentRequestStats(RequestStatsPerIteration&& currentRequestStats)
{
    std::scoped_lock lck(mRequestStatsMtx);
    if (mRequestStats.size() >= mRequestStatsMaxIterations)
    {
        mRequestStats.pop_front();
    }
    mRequestStats.emplace_back(std::move(currentRequestStats));
}

void Executor::Impl::appendMultipleRequestStats(std::vector<RequestStatsPerIteration>&& currentRequestStatsVec)
{
    std::scoped_lock lck(mRequestStatsMtx);
    if (mRequestStats.size() + currentRequestStatsVec.size() > mRequestStatsMaxIterations)
    {
        size_t removeCount = mRequestStats.size() + currentRequestStatsVec.size() - mRequestStatsMaxIterations;
        for (size_t i = 0; i < removeCount; i++)
        {
            mRequestStats.pop_front();
        }
    }
    mRequestStats.insert(mRequestStats.end(), std::make_move_iterator(currentRequestStatsVec.begin()),
        std::make_move_iterator(currentRequestStatsVec.end()));
}

void Executor::Impl::updateRequestStats(
    RequestList const& activeRequests, RequestList const& finishedRequests, bool flushToOrchestrator)
{
    NVTX3_SCOPED_RANGE(updateRequestStats);
    if (mRequestStatsMaxIterations > 0 && mIsLeader)
    {
        // Add current iteration request stats
        auto currentRequestStats = getCurrentRequestStats(activeRequests, finishedRequests);

        // Send the stats to the orchestrator
        if (mCommMode == CommunicationMode::kORCHESTRATOR)
        {
            bool hasScheduledReqs = false;
            if (!flushToOrchestrator)
            {
                size_t activeSize = activeRequests.size();
                TLLM_CHECK_WITH_INFO(currentRequestStats.requestStats.size() >= activeSize,
                    "currentRequestStats num %zu should be >= activeRequest num %zu",
                    currentRequestStats.requestStats.size(), activeSize);
                hasScheduledReqs = std::any_of(currentRequestStats.requestStats.begin(),
                    currentRequestStats.requestStats.begin() + static_cast<std::ptrdiff_t>(activeSize),
                    [](RequestStats const& requestStat) { return requestStat.scheduled; });
            }
            appendCurrentRequestStats(std::move(currentRequestStats));
            if (hasScheduledReqs || flushToOrchestrator)
            {
                std::deque<RequestStatsPerIteration> requestStatsQueue;
                {
                    std::scoped_lock lck(mRequestStatsMtx);
                    requestStatsQueue = std::exchange(mRequestStats, {});
                }
                std::vector<RequestStatsPerIteration> requestIterStates(
                    std::make_move_iterator(requestStatsQueue.begin()),
                    std::make_move_iterator(requestStatsQueue.end()));
                MpiMessage message(MpiId::REQUEST_ITER_STATS);
                message.data = RequestStatsPerIterationData{std::move(requestIterStates)};
                mSendQueue.push(std::move(message));
            }
        }
        else
        {
            // Add current iteration stats
            appendCurrentRequestStats(std::move(currentRequestStats));
        }
    }
}
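// Debug tensors follow the same bounded-buffer pattern as the stats deques above: at most
// mDebugTensorsMaxIterations snapshots are retained, with the oldest dropped first.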
void Executor::Impl::appendCurrentDebugTensors()
{
    if (mDebugTensorsMaxIterations > 0)
    {
        std::scoped_lock lck(mDebugTensorsMtx);
        if (mDebugTensors.size() >= mDebugTensorsMaxIterations)
        {
            mDebugTensors.pop_front();
        }
        mDebugTensors.emplace_back(mModel->getCurrentDebugTensors());
    }
}

void Executor::Impl::terminateCancelledRequests(RequestList& activeRequests)
{
    NVTX3_SCOPED_RANGE(terminateCancelledRequests);
    auto const& worldConfig = mModel->getWorldConfig();

    auto const broadcastCancelledRequests = [this, &activeRequests, &worldConfig]
    {
        auto const& commSession = COMM_SESSION;
        if (worldConfig.isPipelineParallel())
        {
            mCancelledRequestsWaitThread->waitStop();
        }
        if (commSession.getSize() > 1 && !activeRequests.empty())
        {
            if (mIsPipelineLeader)
            {
                if (worldConfig.isPipelineParallel())
                {
                    auto const peer = worldConfig.getPipelineParallelism() - 1;
                    bool shouldExit = false;
                    mCommPipelineParallel->send(
                        &shouldExit, 1, mpi::MpiType::kBOOL, peer, mpi::MpiTag::kExecutorShouldExit);
                    auto pipelineCancelledReqIds
                        = CancelledRequestsAsyncSend::cancelledRequestsRecv(mCommPipelineParallel, peer);
                    mCancelledReqIds.insert(pipelineCancelledReqIds.begin(), pipelineCancelledReqIds.end());
                }
                auto numCancelledRequests = static_cast<int64_t>(mCancelledReqIds.size());
                if (worldConfig.isTensorParallel())
                {
                    mCommTensorParallel->bcastValue(numCancelledRequests, 0);
                    if (numCancelledRequests > 0)
                    {
                        std::vector<uint64_t> cancelledReqIdsVec(mCancelledReqIds.begin(), mCancelledReqIds.end());
                        mCommTensorParallel->bcast(
                            cancelledReqIdsVec.data(), cancelledReqIdsVec.size(), mpi::MpiType::kUINT64, 0);
                    }
                }
                if (worldConfig.isContextParallel())
                {
                    mCommContextParallel->bcastValue(numCancelledRequests, 0);
                    if (numCancelledRequests > 0)
                    {
                        std::vector<uint64_t> cancelledReqIdsVec(mCancelledReqIds.begin(), mCancelledReqIds.end());
                        mCommContextParallel->bcast(
                            cancelledReqIdsVec.data(), cancelledReqIdsVec.size(), mpi::MpiType::kUINT64, 0);
                    }
                }
            }
            // If not leader
            else
            {
                if (worldConfig.isFirstPipelineParallelRank())
                {
                    int64_t numCancelledRequests = 0;
                    mCommTensorParallel->bcastValue(numCancelledRequests, 0);
                    mCommContextParallel->bcastValue(numCancelledRequests, 0);
                    if (numCancelledRequests > 0)
                    {
                        std::vector<uint64_t> cancelledReqIdsVec(numCancelledRequests);
                        mCommTensorParallel->bcast(
                            cancelledReqIdsVec.data(), cancelledReqIdsVec.size(), mpi::MpiType::kUINT64, 0);
                        mCommContextParallel->bcast(
                            cancelledReqIdsVec.data(), cancelledReqIdsVec.size(), mpi::MpiType::kUINT64, 0);
                        mCancelledReqIds
                            = std::unordered_set<IdType>(cancelledReqIdsVec.begin(), cancelledReqIdsVec.end());
                    }
                }
                else
                {
                    auto const peer = worldConfig.getPipelineParallelRank() - 1;
                    mCancelledReqIds = CancelledRequestsAsyncSend::cancelledRequestsRecv(mCommPipelineParallel, peer);
                }
            }
            if (!worldConfig.isLastPipelineParallelRank())
            {
                auto const peer = worldConfig.getPipelineParallelRank() + 1;
                mCancelledRequestsAsyncSndHdl
                    = std::make_unique<CancelledRequestsAsyncSend>(mCommPipelineParallel, mCancelledReqIds, peer);
                mCancelledRequestsWaitThread->notifyStart();
            }
        }
    };

    std::unique_lock lck{mCancelReqMtx, std::defer_lock};
    if (!worldConfig.isPipelineParallel())
    {
        lck.lock();
    }
    broadcastCancelledRequests();

    if (!mCancelledReqIds.empty())
    {
        // Loop over active requests and terminate those that have been cancelled
        std::unordered_set<IdType> terminatedReqIds;
        for (auto& req : activeRequests)
        {
            auto reqId = req->isChild() ? req->getParentRequestId() : req->mRequestId;
            if (mCancelledReqIds.find(reqId) != mCancelledReqIds.end())
            {
                auto finishReason = req->isTimedOut() ? FinishReason::kTIMED_OUT : FinishReason::kCANCELLED;
                mModel->terminateRequestSync(req, finishReason);
                // Parent and child requests share the same request id.
                // Mark it terminated first and remove it from the set later.
                terminatedReqIds.insert(reqId);
            }
        }
        for (auto const& reqId : terminatedReqIds)
        {
            mCancelledReqIds.erase(reqId);
        }
    }
}
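// In disaggregated serving, context-only requests stay alive after their context phase so the KV cache can be
// transmitted to the generation instance. Once the transfer completes, either the pinned reuse blocks are
// unpinned (when block reuse is enabled) or the request is terminated to free its resources.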
void Executor::Impl::terminateContextFinishedRequests(InTransList& inTransmissionRequests)
{
    NVTX3_SCOPED_RANGE(terminateContextFinishedRequests);
    for (auto it = inTransmissionRequests.begin(); it != inTransmissionRequests.end();)
    {
        auto& item = *it;
        auto req = item.request;
        if (req->isDisaggContextCompleteState())
        {
            // If pinned block ids were tracked, unpin them. Otherwise, just terminate the request.
            auto kvMgr = mModel->getKVCacheManager();
            if (kvMgr && !item.pinnedBlockIds.empty())
            {
                kvMgr->unpinBlocksById(item.pinnedBlockIds);
            }
            else
            {
                mModel->terminateRequest(req);
            }
            it = inTransmissionRequests.erase(it);
        }
        else
        {
            ++it;
        }
    }
}

void Executor::Impl::appendNewResponses(std::vector<Response>&& newResponses)
{
    {
        std::scoped_lock lck(mResponsesMtx);
        for (auto& response : newResponses)
        {
            mResponses[response.getRequestId()].emplace_back(std::move(response));
        }
    }
    mResponsesCv.notify_all();
}

Executor::Impl::RequestList Executor::Impl::populateNewResponses(
    RequestList& activeRequests, InTransList& inTransmissionRequests, std::vector<Response>& newResponses)
{
    NVTX3_SCOPED_RANGE(populateNewResponses);
    RequestList finishedRequests;
    for (auto it = activeRequests.begin(); it != activeRequests.end();)
    {
        auto const& llmReq = (*it);
        bool const requestDone = llmReq->isFinished();
        // Only the leader should store responses
        if (mIsLeader)
        {
            auto response = llmReq->createResponse(mModel->hasSpeculativeDecodingFastLogits(), mWorldRank);
            if (response)
            {
                newResponses.emplace_back(std::move(response.value()));
            }
        }
        // Remove from active requests if the last response has been generated
        if (requestDone)
        {
            // Move in-transmission requests to a separate tracker
            if (llmReq->isDisaggContextTransmissionState())
            {
                std::vector<SizeType32> pinnedBlockIds{};
                auto kvMgr = mModel->getKVCacheManager();
                if (kvMgr && kvMgr->isEnableBlockReuse() && !kvMgr->getBlockManager().isVariableWindow())
                {
                    pinnedBlockIds = kvMgr->storeBlocksForReuse(llmReq->mRequestId, llmReq, /*pinBlocks=*/true);
                    mModel->terminateRequest(llmReq);
                }
                inTransmissionRequests.push_back(InTransmissionItem{*it, pinnedBlockIds});
            }
            finishedRequests.push_back(*it);
            it = activeRequests.erase(it);
        }
        else
        {
            ++it;
        }
    }
    return finishedRequests;
}
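// Main worker loop, run until shutdown is requested and all active requests have drained. Each iteration:
// (1) finish timed-out and cancelled requests and sync the previous forward pass, (2) turn finished requests
// into responses, (3) fetch new requests from the queue in priority order, and (4) launch the next forward
// pass asynchronously so that request fetching overlaps with model execution.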
void Executor::Impl::executionLoop()
{
    TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
    tensorrt_llm::common::setThreadName("executionLoop");
    auto const& worldConfig = mModel->getWorldConfig();
    TLLM_CUDA_CHECK(cudaSetDevice(worldConfig.getDevice()));

    auto const [profileIterIdxs, stopIterIdxs] = tensorrt_llm::common::populateIterationIndexes(
        kPROFILE_START_STOP_ENV_VAR_NAME, kLEGACY_PROFILE_START_STOP_ENV_VAR_NAME);

    SizeType32 numNewActiveRequests{0};
    std::chrono::time_point<std::chrono::steady_clock> iterStart;
    std::chrono::time_point<std::chrono::steady_clock> iterEnd;
    bool firstIteration{true};
    RequestList activeRequests;
    InTransList inTransmissionRequests;
    std::vector<Response> newResponses;
    while (!mShutdown || !activeRequests.empty())
    {
        double iterLatencyMS{0.0};
        double newActiveRequestsQueueLatencyMS{0.0};
        bool reportFinishedRequests = true;
        RequestList finishedRequests;
        if (!activeRequests.empty())
        {
            finishTimedOutRequests(activeRequests);
            terminateCancelledRequests(activeRequests);
            forwardSync(activeRequests);
            finishedRequests = populateNewResponses(activeRequests, inTransmissionRequests, newResponses);
            cleanupDynamicLogitsPostProcessors(finishedRequests);

            auto const iterCounter = mModel->getIterCounter();
            auto const stopIter = !stopIterIdxs.empty() && (stopIterIdxs.count(iterCounter - 1) > 0);
            if (stopIter)
            {
                cudaProfilerStop();
            }

            // When there are no active or in-flight requests, update the stats before calling
            // fetchNewRequests to make sure that the stats are reported accurately.
            if (activeRequests.empty() && (!firstIteration))
            {
                mModel->resetIterationStats();
                updateIterationStats(activeRequests, iterLatencyMS, numNewActiveRequests,
                    newActiveRequestsQueueLatencyMS, static_cast<SizeType32>(finishedRequests.size()), true);
                updateRequestStats(activeRequests, finishedRequests, true);
                reportFinishedRequests = false;
            }

            if (!newResponses.empty())
            {
                enqueueNewResponses(std::move(newResponses));
                newResponses.clear();
            }
            iterEnd = std::chrono::steady_clock::now();
            iterLatencyMS = std::chrono::duration<double, std::milli>(iterEnd - iterStart).count();
        }
        if (!inTransmissionRequests.empty())
        {
            terminateContextFinishedRequests(inTransmissionRequests);
        }
        if (!mShutdown)
        {
            auto const iterCounter = mModel->getIterCounter();
            auto const profileIter = !profileIterIdxs.empty() && (profileIterIdxs.count(iterCounter) > 0);
            if (profileIter)
            {
                cudaProfilerStart();
            }
            iterStart = std::chrono::steady_clock::now();

            std::optional<PriorityType> lowestPriority = std::nullopt;
            if (!activeRequests.empty())
            {
                lowestPriority = activeRequests.back()->priority();
            }
            auto [newRequests, newActiveRequestsQueueLatency]
                = fetchNewRequests(static_cast<SizeType32>(activeRequests.size()), lowestPriority);
            newActiveRequestsQueueLatencyMS = newActiveRequestsQueueLatency;
            numNewActiveRequests = newRequests.size();
            if (firstIteration)
            {
                firstIteration = false;
            }
            for (auto const& newRequest : newRequests)
            {
                insertRequestInOrder(activeRequests, newRequest);
            }

            // Update dynamic tuning stats
            if (mDynamicBatchTuner)
            {
                for (auto const& req : activeRequests)
                {
                    auto const inputLength = req->mPromptLen;
                    auto const outputLength = req->mMaxNewTokens;
                    mDynamicBatchTuner->updateStats(inputLength, outputLength);
                }
            }
        }
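        // forwardAsync below only schedules the next step; the matching forwardSync at the top of the next
        // loop iteration completes it, so response handling and request fetching overlap with model execution.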
        if (!activeRequests.empty())
        {
            forwardAsync(activeRequests);
            updateIterationStats(activeRequests, iterLatencyMS, numNewActiveRequests,
                newActiveRequestsQueueLatencyMS, static_cast<SizeType32>(finishedRequests.size()), false);
            // Finished requests were already reported once above; avoid reporting them twice.
            if (reportFinishedRequests)
            {
                updateRequestStats(activeRequests, finishedRequests, false);
            }
            else
            {
                updateRequestStats(activeRequests, {}, false);
            }
            appendCurrentDebugTensors();
        }
    }

    if (mCancelledRequestsWaitThread)
    {
        mCancelledRequestsWaitThread.reset(nullptr);
    }
    if (mRequestWithIdWaitThread)
    {
        mRequestWithIdWaitThread.reset(nullptr);
    }
    if (worldConfig.isPipelineParallel() && mIsPipelineLeader)
    {
        // Tell the last pipeline rank to exit
        auto const peer = worldConfig.getPipelineParallelism() - 1;
        int64_t numActiveRequests = -1;
        mCommPipelineParallel->send(
            &numActiveRequests, 1, mpi::MpiType::kINT64, peer, mpi::MpiTag::kExecutorNumActiveRequests);
        bool shouldExit = true;
        mCommPipelineParallel->send(&shouldExit, 1, mpi::MpiType::kBOOL, peer, mpi::MpiTag::kExecutorShouldExit);
    }
    if (mRequestWithIdLeaderThread)
    {
        mRequestWithIdLeaderThread->join();
        mRequestWithIdLeaderThread.reset(nullptr);
    }
    if (mCancelledRequestsLeaderThread)
    {
        mCancelledRequestsLeaderThread->join();
        mCancelledRequestsLeaderThread.reset(nullptr);
    }
    TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}

void Executor::Impl::enqueueTerminateRequest()
{
    {
        std::scoped_lock lck(mQueuedReqMtx);
        Request dummyReq({1}, 1);
        RequestWithId reqWithId{std::move(dummyReq), mTerminateReqId};
        mQueuedRequests.emplace_back(reqWithId);
    }
    mQueuedReqCv.notify_one();
}

void Executor::Impl::enqueueNewResponses(std::vector<Response>&& newResponses)
{
    TLLM_CHECK_WITH_INFO(mIsLeader, "Only the leader should store responses");
    if (mCommMode == CommunicationMode::kLEADER)
    {
        appendNewResponses(std::move(newResponses));
    }
    else if (mCommMode == CommunicationMode::kORCHESTRATOR)
    {
        MpiMessage message(MpiId::RESPONSE);
        message.data = ResponseData{std::move(newResponses)};
        mSendQueue.push(std::move(message));
    }
}
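// Messages between the orchestrator and the leader travel as a pair of MPI sends: a fixed-size MpiId header on
// the id tag, followed (for payload-carrying message types) by a serialized buffer on the data tag;
// TERMINATION consists of the header alone. As an illustrative sketch of the framing used below:
//
//   comm.send(&id, 1, mpi::MpiType::kUINT64, peer, idTag);                  // header
//   comm.send(buf.data(), buf.size(), mpi::MpiType::kCHAR, peer, dataTag);  // serialized payload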
" "This issue will be resolved in a future version of TRT-LLM."); } mResponsesCv.notify_all(); } } else { mOrchLeaderComm->send(&message.id, 1, mpi::MpiType::kUINT64, mLeaderRank, mpi::MpiTag::kOrchestratorId); mOrchLeaderComm->send( packed.data(), packed.size(), mpi::MpiType::kCHAR, mLeaderRank, mpi::MpiTag::kOrchestratorData); } } else if (message.id == MpiId::CANCEL_REQUEST) { auto& data = std::get(message.data); mOrchLeaderComm->send(&message.id, 1, mpi::MpiType::kUINT64, mLeaderRank, mpi::MpiTag::kOrchestratorId); mOrchLeaderComm->send( data.ids.data(), data.ids.size(), mpi::MpiType::kUINT64, mLeaderRank, mpi::MpiTag::kOrchestratorData); } else { TLLM_THROW("Invalid message id"); } } } // Leader thread receiving new requests from orchestrator void Executor::Impl::leaderRecvReqThread() { tensorrt_llm::common::setThreadName("leaderRecvReq"); TLLM_CUDA_CHECK(cudaSetDevice(mDeviceId)); #if ENABLE_MULTI_DEVICE auto& selCancelledReqIds = mUsePipelineParallel ? mPipelineCancelledReqIds : mCancelledReqIds; while (true) { if (mRecvPollPeriodMs > 0) { mOrchLeaderComm->recvPoll(mOrchRank, mpi::MpiTag::kOrchestratorId, mRecvPollPeriodMs); } // Blocking is okay: terminate message is expected to arrive here MPI_Message msg = nullptr; MPI_Status status; mOrchLeaderComm->mprobe(mOrchRank, mpi::MpiTag::kOrchestratorId, &msg, &status); int32_t count = 0; MPICHECK(MPI_Get_count(&status, MPI_UINT64_T, &count)); // NOLINT TLLM_CHECK(count == 1); MpiId mpiId{}; MPICHECK(MPI_Mrecv(&mpiId, count, MPI_UINT64_T, &msg, &status)); // NOLINT // EXIT condition from receiving TERMINATE msg if (mpiId == MpiId::TERMINATION) { // Enqueue a request to indicate to other ranks to terminate enqueueTerminateRequest(); // Send message to orchestrator to indicate to terminate orch recv thread mSendQueue.push(MpiMessage(mpiId)); TLLM_LOG_INFO("Leader recvReq thread exiting"); break; } if (mpiId == MpiId::PENDING_REQUEST) { mOrchLeaderComm->mprobe(mOrchRank, mpi::MpiTag::kOrchestratorData, &msg, &status); MPICHECK(MPI_Get_count(&status, MPI_CHAR, &count)); // NOLINT std::vector buffer(count); MPICHECK(MPI_Mrecv(buffer.data(), count, MPI_CHAR, &msg, &status)); // NOLINT auto requestWithIds = RequestWithId::deserializeReqWithIds(buffer); TLLM_LOG_DEBUG("Leader recvReq thread receiving %d pending requests", requestWithIds.size()); { std::scoped_lock lck(mQueuedReqMtx); if (mMaxQueueSize) { auto const maxQueueSize = mMaxQueueSize.value(); if (maxQueueSize > 0 && mQueuedRequests.size() >= static_cast(maxQueueSize)) { auto err = tensorrt_llm::common::fmtstr( "Maximum queue size of %d has been reached, please try again later", maxQueueSize); TLLM_LOG_ERROR("%s", err.c_str()); std::vector responses; responses.reserve(requestWithIds.size()); for (auto const& reqWithId : requestWithIds) { responses.emplace_back(reqWithId.id, err); } enqueueNewResponses(std::move(responses)); continue; } } for (auto&& req : requestWithIds) { req.queuedStart = std::chrono::steady_clock::now(); insertRequestInOrder(mQueuedRequests, std::move(req)); } } mQueuedReqCv.notify_one(); } else if (mpiId == MpiId::CANCEL_REQUEST) { // Prepare receiving data mOrchLeaderComm->mprobe(mOrchRank, mpi::MpiTag::kOrchestratorData, &msg, &status); MPICHECK(MPI_Get_count(&status, MPI_UINT64_T, &count)); // NOLINT std::vector cancelledReqIds(count); MPICHECK(MPI_Mrecv(cancelledReqIds.data(), count, MPI_UINT64_T, &msg, &status)); // NOLINT std::scoped_lock lck(mCancelReqMtx); selCancelledReqIds.insert(cancelledReqIds.begin(), cancelledReqIds.end()); } else { 
TLLM_THROW("Invalid message id"); } } #endif // ENABLE_MULTI_DEVICE } // Leader thread sending responses to orchestrator void Executor::Impl::leaderSendThread(MpiMessageQueue& sendQueue, mpi::MpiTag idTag, mpi::MpiTag dataTag) { tensorrt_llm::common::setThreadName("leaderSend"); TLLM_CUDA_CHECK(cudaSetDevice(mDeviceId)); #if ENABLE_MULTI_DEVICE while (true) { auto message = sendQueue.pop(); if (message.id == MpiId::TERMINATION) { mOrchLeaderComm->send(&message.id, 1, mpi::MpiType::kUINT64, mOrchRank, idTag); TLLM_LOG_INFO("Leader sendThread exiting"); break; } if (message.id == MpiId::RESPONSE || message.id == MpiId::ITER_STATS || message.id == MpiId ::REQUEST_ITER_STATS) { std::vector buffer; if (message.id == MpiId::RESPONSE) { auto& responseData = std::get(message.data); TLLM_LOG_DEBUG("Leader sendResp thread sending %d responses", responseData.responses.size()); buffer = Serialization::serialize(responseData.responses); } else if (message.id == MpiId::ITER_STATS) { auto& iterStatsData = std::get(message.data); TLLM_LOG_DEBUG("Leader sendResp thread sending iter stats"); buffer = Serialization::serialize(iterStatsData.iterStatsVec); } else if (message.id == MpiId::REQUEST_ITER_STATS) { auto& requestIterStatsData = std::get(message.data); TLLM_LOG_DEBUG("Leader sendResp thread sending iter request stats"); buffer = Serialization::serialize(requestIterStatsData.requestStatsPerIterationVec); } mOrchLeaderComm->send(&message.id, 1, mpi::MpiType::kUINT64, mOrchRank, idTag); mOrchLeaderComm->send(buffer.data(), buffer.size(), mpi::MpiType::kCHAR, mOrchRank, dataTag); } else { TLLM_THROW("Invalid message id"); } } #endif // ENABLE_MULTI_DEVICE } void Executor::Impl::orchRecvThread(mpi::MpiTag idTag, mpi::MpiTag dataTag) { tensorrt_llm::common::setThreadName("orchRecv"); #if ENABLE_MULTI_DEVICE while (true) { if (mRecvPollPeriodMs > 0) { mOrchLeaderComm->recvPoll(mOrchRank, mpi::MpiTag::kOrchestratorId, mRecvPollPeriodMs); } MPI_Message msg = nullptr; MPI_Status status; mOrchLeaderComm->mprobe(mLeaderRank, idTag, &msg, &status); int32_t count = 0; MPICHECK(MPI_Get_count(&status, MPI_UINT64_T, &count)); // NOLINT TLLM_CHECK(count == 1); MpiId mpiId{}; MPICHECK(MPI_Mrecv(&mpiId, count, MPI_UINT64_T, &msg, &status)); // NOLINT if (mpiId == MpiId::TERMINATION) { TLLM_LOG_INFO("Orchestrator recv thread exiting"); break; } if (mpiId == MpiId::RESPONSE || mpiId == MpiId::ITER_STATS || mpiId == MpiId::REQUEST_ITER_STATS) { mOrchLeaderComm->mprobe(mLeaderRank, dataTag, &msg, &status); MPICHECK(MPI_Get_count(&status, MPI_CHAR, &count)); // NOLINT std::vector buffer(count); MPICHECK(MPI_Mrecv(buffer.data(), count, MPI_CHAR, &msg, &status)); // NOLINT if (mpiId == MpiId::RESPONSE) { auto newResponses = Serialization::deserializeResponses(buffer); TLLM_LOG_DEBUG("Orchestrator recv thread receiving %d responses", newResponses.size()); appendNewResponses(std::move(newResponses)); } else if (mpiId == MpiId::ITER_STATS) { appendMultipleIterStats(Serialization::deserializeIterationStatsVec(buffer)); } else if (mpiId == MpiId::REQUEST_ITER_STATS) { appendMultipleRequestStats(Serialization::deserializeRequestStatsPerIterationVec(buffer)); } } else { TLLM_THROW("Invalid message id"); } } #endif // ENABLE_MULTI_DEVICE } Executor::Impl::LlmRequestLogitsPostProcessor Executor::Impl::getLogitsPostProcessor(std::string const& name) { auto const postProcIt = mLogitsPostProcessorMap.find(name); TLLM_CHECK_WITH_INFO( postProcIt != mLogitsPostProcessorMap.end(), "LogitsPostProcessor %s not found.", name.c_str()); auto 
Executor::Impl::LlmRequestLogitsPostProcessor Executor::Impl::getLogitsPostProcessor(std::string const& name)
{
    auto const postProcIt = mLogitsPostProcessorMap.find(name);
    TLLM_CHECK_WITH_INFO(
        postProcIt != mLogitsPostProcessorMap.end(), "LogitsPostProcessor %s not found.", name.c_str());
    auto executorLogitsPostProcessor = postProcIt->second;
    return [executorLogitsPostProcessor](IdType reqId, RtTensorPtr& logits, BeamTokens const& beamTokens,
               CudaStreamPtr const& cudaStreamPtr, std::optional<IdType> clientId)
    {
        auto logitsTensor = executor::detail::ofITensor(logits);
        executorLogitsPostProcessor(reqId, logitsTensor, beamTokens, cudaStreamPtr, clientId);
    };
}

void Executor::Impl::setupDynamicLogitsPostProcessors(std::vector<RequestWithId>& newReqWithIds)
{
    for (auto& reqWithId : newReqWithIds)
    {
        auto logitsPostProcessor = reqWithId.req.getLogitsPostProcessor();
        if (logitsPostProcessor)
        {
            // Register the per-request post-processor under a unique dynamic name and replace the callable
            // stored in the request by that name.
            std::string const name = Request::kDynamicPostProcessorNamePrefix + std::to_string(reqWithId.id);
            mLogitsPostProcessorMap[name] = logitsPostProcessor.value();
            reqWithId.req.setLogitsPostProcessor(std::nullopt);
            reqWithId.req.setLogitsPostProcessorName(name);
        }
    }
}

void Executor::Impl::cleanupDynamicLogitsPostProcessors(RequestList const& finishedRequests)
{
    for (auto& req : finishedRequests)
    {
        std::string const name = Request::kDynamicPostProcessorNamePrefix + std::to_string(req->mRequestId);
        auto const postProcIt = mLogitsPostProcessorMap.find(name);
        if (postProcIt != mLogitsPostProcessorMap.end())
        {
            mLogitsPostProcessorMap.erase(name);
        }
    }
}

void Executor::Impl::addTerminatedReqId(std::vector<Response> const& responses, IdType const& reqId)
{
    for (auto const& response : responses)
    {
        // A request is terminated once it has produced an error or its final result
        if (response.hasError() || response.getResult().isFinal)
        {
            mTerminatedReqIds.insert(reqId);
            if (mChildReqIdsMap.find(reqId) != mChildReqIdsMap.end())
            {
                for (auto childReqId : mChildReqIdsMap.at(reqId))
                {
                    mTerminatedReqIds.insert(childReqId);
                }
                mChildReqIdsMap.erase(reqId);
            }
        }
    }
}

void Executor::Impl::checkParallelApiUsage(std::string const& methodName) const
{
    // In leader mode, only the leader rank may call this method
    if (mCommMode == CommunicationMode::kLEADER && !mIsLeader)
    {
        TLLM_THROW("With LEADER communication mode, only the leader rank is expected to call %s", methodName.c_str());
    }
    if (mCommMode == CommunicationMode::kORCHESTRATOR && !mIsOrchestrator)
    {
        TLLM_THROW("With ORCHESTRATOR communication mode, only the orchestrator rank is expected to call %s",
            methodName.c_str());
    }
}

} // namespace tensorrt_llm::executor