/*
 * SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: Apache-2.0
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "trtEncoderModel.h"

#include "encoderBuffers.h"
#include "tensorrt_llm/batch_manager/capacityScheduler.h"
#include "tensorrt_llm/batch_manager/microBatchScheduler.h"
#include "tensorrt_llm/common/assert.h"
#include "tensorrt_llm/common/logger.h"
#include "tensorrt_llm/common/nvtxUtils.h"
#include "tensorrt_llm/runtime/iTensor.h"
#include "tensorrt_llm/runtime/tllmLogger.h"
#include "tensorrt_llm/runtime/tllmRuntime.h"
#include "tensorrt_llm/runtime/utils/sessionUtils.h"

#include <algorithm>
#include <numeric>
#include <stdexcept>

using namespace tensorrt_llm::runtime;
using namespace tensorrt_llm::mpi;

namespace tensorrt_llm::batch_manager
{

TrtEncoderModel::TrtEncoderModel(runtime::ModelConfig const& modelConfig, WorldConfig const& worldConfig,
    runtime::RawEngine const& rawEngine, std::shared_ptr<nvinfer1::ILogger> logger,
    TrtGptModelOptionalParams const& optionalParams)
    : TrtGptModel(modelConfig, worldConfig, optionalParams)
    , mModelConfig{modelConfig}
    , mWorldConfig{worldConfig}
    , mDevice{runtime::utils::initDevice(worldConfig)}
    , mLogger{logger ? std::move(logger) : std::make_shared<TllmLogger>()}
    , mRuntime{std::make_shared<TllmRuntime>(
          rawEngine, mLogger.get(), optionalParams.useGpuDirectStorage, optionalParams.gpuWeightsPercent)}
    , mMicroBatchId(0)
    , mCopyBufferManager{std::make_shared<CudaStream>()}
{
    TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);

    if (mWorldConfig.isPipelineParallel())
    {
        TLLM_THROW("Pipeline parallelism is currently not supported for encoder models.");
        mNumMicroBatches = mWorldConfig.getPipelineParallelism();
    }
    else
    {
        mNumMicroBatches = isTrtOverlap() ? 2 : 1;
    }
    mNumBuffers = mNumMicroBatches;

    createRuntimeContexts();
    createBuffers();

    if (mWorldConfig.isPipelineParallel())
    {
        auto const& commSession = COMM_SESSION;
        mMpiCommPipelinePara = std::make_shared<MpiComm>(
            commSession.split(mWorldConfig.getTensorParallelRank(), mWorldConfig.getPipelineParallelRank()));
    }

    mMicroBatchScheduledRequests.resize(mNumMicroBatches);
    // mEncoderWaitEvents.resize(mNumMicroBatches);

    // Set noScheduleUntilState to LlmRequestState::kENCODER_INIT for the encoder model.
    // When a null KV cache manager is given, the request scheduler uses MaxRequests as the capacity scheduler,
    // i.e. no handling of maximizing utilization or pause/evict.
    // TODO: finer control on encoder requests scheduling
    mCapacityScheduler = std::make_unique<CapacityScheduler>(getMaxBatchSize() * mNumMicroBatches,
        optionalParams.schedulerConfig.getCapacitySchedulerPolicy(), false, std::nullopt,
        LlmRequestState::kENCODER_INIT, LlmRequestState::kCONTEXT_INIT);

    mMicroBatchScheduler = std::make_unique<MicroBatchScheduler>(std::nullopt, mModelConfig.getMaxInputLen(),
        LlmRequestState::kENCODER_INIT, LlmRequestState::kCONTEXT_INIT);

    mHiddenSize = modelConfig.getHiddenSize();
    mMaxInputLen = mModelConfig.getMaxInputLen();
    TLLM_LOG_INFO("TRTEncoderModel mMaxInputLen: reset to %d from build config.", mMaxInputLen);

    TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}
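
// High-level flow of one encoder iteration (a summary of the methods below, not an API contract):
//   1. mCapacityScheduler and mMicroBatchScheduler pick active requests in state kENCODER_INIT.
//   2. executeBatch() prepares engine I/O through EncoderBuffers and runs the single TRT context.
//   3. rearrangeOutputs() hands the batched encoder output back to the individual requests.
//   4. Finished requests transition to kCONTEXT_INIT so that a downstream decoder model can pick them up.
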
BufferManager const& TrtEncoderModel::getBufferManager() const
{
    return mRuntime->getBufferManager();
}

BufferManager::CudaStreamPtr TrtEncoderModel::getRuntimeStreamPtr() const
{
    return mRuntime->getStreamPtr();
}

nvinfer1::DataType TrtEncoderModel::getTensorDataType(std::string const& name) const
{
    auto const& engine = mRuntime->getEngine();
    return engine.getTensorDataType(name.c_str());
}

nvinfer1::Dims TrtEncoderModel::getTensorShape(std::string const& name) const
{
    auto const& engine = mRuntime->getEngine();
    return engine.getTensorShape(name.c_str());
}

void TrtEncoderModel::getCurrentIterationStats(executor::IterationStats& stats) const
{
    stats.iter = mIterCounter;
}

void TrtEncoderModel::getCurrentRequestStats(executor::RequestStatsPerIteration& stats) const
{
    stats.iter = mIterCounter;
}

executor::DebugTensorsPerIteration TrtEncoderModel::getCurrentDebugTensors() const
{
    executor::DebugTensorsPerIteration debugTensors;
    debugTensors.iter = mIterCounter;

    TLLM_LOG_WARNING("TrtEncoderModel doesn't support getting debug tensors.");

    return debugTensors;
}

void TrtEncoderModel::setLayerProfiler()
{
    TLLM_CHECK(mRuntime);
    mRuntime->setLayerProfiler();
}

std::string TrtEncoderModel::getLayerProfileInfo() const
{
    TLLM_CHECK(mRuntime);
    return mRuntime->getLayerProfileInfo();
}

void TrtEncoderModel::createRuntimeContexts()
{
    TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
    mRuntime->clearContexts();
    auto const numProfiles = mRuntime->getNbProfiles();
    TLLM_CHECK_WITH_INFO(numProfiles == 1, "Encoder only expects one optimization profile");
    for (auto i = 0; i < numProfiles; ++i)
    {
        mRuntime->addContext(i);
    }
    TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}

void TrtEncoderModel::executeContext(SizeType32 runtimeContextId)
{
    TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
    NVTX3_SCOPED_RANGE(executeContext);
    auto enqueueSuccessful = mRuntime->executeContext(runtimeContextId);
    if (!enqueueSuccessful)
    {
        throw std::runtime_error("Executing TRT engine failed!");
    }
    TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}

void TrtEncoderModel::createBuffers()
{
    TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
    for (SizeType32 i = 0; i < mNumBuffers; ++i)
    {
        mBuffers.emplace_back(
            std::make_shared<EncoderBuffers>(getMaxBatchSize(), mModelConfig, mWorldConfig, *mRuntime));
    }
    TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}

void TrtEncoderModel::executeBatch(ScheduledRequests const& scheduledRequests)
{
    TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
    NVTX3_SCOPED_RANGE(executeBatch);

    // The encoder model only has one optimization profile for now, so there is no profile switch.
    SizeType32 optProfileIndex = 0;
    auto const bufferId = getBufferId();
    if (!scheduledRequests.contextRequests.empty())
    {
        // engine I/O
        auto [inputMap, outputMap] = mBuffers[bufferId]->prepareIO(
            scheduledRequests.contextRequests, mModelConfig, mWorldConfig, *mRuntime);
        mRuntime->setInputTensors(optProfileIndex, inputMap);
        mRuntime->setOutputTensors(optProfileIndex, outputMap);

        // engine run
        executeContext(optProfileIndex);
    }
    TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}
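
// Distributes the batched engine output of the current micro batch back to the per-request buffers;
// the actual copy logic lives in EncoderBuffers::rearrangeOutputs().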
void TrtEncoderModel::rearrangeOutputs(ScheduledRequests const& scheduledRequests)
{
    TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
    NVTX3_SCOPED_RANGE(rearrangeOutputs);
    auto const bufferId = getBufferId();
    if (!scheduledRequests.contextRequests.empty())
    {
        mBuffers[bufferId]->rearrangeOutputs(scheduledRequests.contextRequests, mModelConfig, mWorldConfig, *mRuntime);
    }
    TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}

void TrtEncoderModel::forwardSync()
{
    TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
    NVTX3_SCOPED_RANGE_WITH_NAME(range, "TrtEncoderModel::forwardSync");

    auto const device = mWorldConfig.getDevice();
    TLLM_CUDA_CHECK(cudaSetDevice(device));

    auto& currRequests = mMicroBatchScheduledRequests.at(mMicroBatchId);
    // auto& encoderWaitEvent = mEncoderWaitEvents.at(mMicroBatchId);
    if (!currRequests.empty())
    {
        if (!mWorldConfig.isPipelineParallel() || !mWorldConfig.isLastPipelineParallelRank())
        {
            // TLLM_CHECK_WITH_INFO(mEncStepAsyncSndHdl.get() == nullptr, "encoderSync handle must be nullptr.");
            // // Wait for encoding for requests in flight for the current micro batch
            // mEncStepAsyncSndHdl = encoderSync(currRequests, encoderWaitEvent);
        }
        else
        {
        }

        NVTX3_SCOPED_RANGE(pauseFlaggedCurrRequests);
        for (auto const& requests : {currRequests.contextRequests})
        {
            for (auto const& llmReq : requests)
            {
                auto const reqId = llmReq->mRequestId;
                mInflightReqIds.erase(reqId);
                TLLM_LOG_DEBUG("request ID %lu removed from ENCODER inflight set", reqId);

                // If a request in encoder phase had been flagged to be paused, pause it right away
                if (mReqIdsToPause.find(reqId) != mReqIdsToPause.end())
                {
                    terminateRequest(llmReq, true);
                    mReqIdsToPause.erase(reqId);
                }
            }
        }
    }
    TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}

void TrtEncoderModel::forwardAsync(RequestList const& activeRequests)
{
    TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
    NVTX3_SCOPED_RANGE_WITH_NAME(range, "TrtEncoderModel::ForwardAsync");

    auto const device = mWorldConfig.getDevice();
    TLLM_CUDA_CHECK(cudaSetDevice(device));

    try
    {
        auto& currRequests = mMicroBatchScheduledRequests.at(mMicroBatchId);
        // auto& encoderWaitEvent = mEncoderWaitEvents.at(mMicroBatchId);

        // Get a new set of requests for the encoder.
        // The scheduler will not include any requests that are already in flight for encoder models.
        // TODO: add pause handling logic
        TLLM_LOG_DEBUG("Running ENCODER request scheduler");
        auto [fittingRequests, fittingDisaggeGenInitReuqests, requestsToPause]
            = (*mCapacityScheduler)(activeRequests);
        TLLM_CHECK_WITH_INFO(
            fittingDisaggeGenInitReuqests.empty(), "Disaggregated serving is not supported by the encoder model.");
        std::tie(currRequests.contextRequests, std::ignore) = (*mMicroBatchScheduler)(
            fittingRequests, mInflightReqIds, getMaxBatchSize(), mModelConfig.getMaxNumTokens());

        {
            NVTX3_SCOPED_RANGE(pauseRequestsFlaggedByScheduler);
            // Loop over requests flagged to be paused; if a request is not in flight, pause it right away
            for (auto const& llmReq : requestsToPause)
            {
                auto const reqId = llmReq->mRequestId;
                if (mInflightReqIds.find(reqId) == mInflightReqIds.end())
                {
                    // Not in flight, can terminate right away
                    terminateRequest(llmReq, true);
                }
                else
                {
                    // In flight, add to set for pausing later
                    mReqIdsToPause.insert(reqId);
                }
            }
        }

        TLLM_CHECK(currRequests.size() <= static_cast<size_t>(getMaxBatchSize()));
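
        // Run the scheduled micro batch. Note that executeBatch/rearrangeOutputs below are effectively
        // synchronous on the runtime stream; the commented-out event handling would be required for a truly
        // asynchronous encoder step.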
        if (!currRequests.empty())
        {
            TLLM_LOG_DEBUG("Running ENCODER model with batch size: %lu", currRequests.size());
            {
                NVTX3_SCOPED_RANGE(updateInflightReqIds);
                // Add to set of requests in flight
                for (auto const& requests : {currRequests.contextRequests})
                {
                    for (auto const& llmReq : requests)
                    {
                        TLLM_LOG_DEBUG("request ID %lu added to ENCODER inflight set", llmReq->mRequestId);
                        mInflightReqIds.insert(llmReq->mRequestId);
                    }
                }
            }

            executeBatch(currRequests);
            sync_check_cuda_error(mRuntime->getStream().get());

            rearrangeOutputs(currRequests);
            sync_check_cuda_error(mRuntime->getStream().get());
            // encoderWaitEvent = encoderStepAsync(currRequests);

            for (auto const& requests : {currRequests.contextRequests})
            {
                for (auto const& llmReq : requests)
                {
                    if (llmReq->isEncoderInitState())
                    {
                        llmReq->setState(LlmRequestState::kCONTEXT_INIT);
                        TLLM_LOG_DEBUG("request ID: %lu finishes encoder phase", llmReq->mRequestId);
                    }
                }
            }
        }

        // TODO: PP handling
        if (!currRequests.empty())
        {
            if (mWorldConfig.isPipelineParallel() && mWorldConfig.isLastPipelineParallelRank())
            {
                // TLLM_CHECK_WITH_INFO(mEncStepAsyncSndHdl.get() == nullptr, "decoderSync handle must be nullptr.");
                // Wait for encoding for requests in flight for the current micro batch
                // mEncStepAsyncSndHdl = encoderSync(currRequests, encoderWaitEvent);
            }
        }

        // Update the micro batch ID
        mMicroBatchId = (mMicroBatchId + 1) % mNumMicroBatches;
    }
    // In case of error, we need to free the batch slot associated with those requests
    catch (std::exception const& e)
    {
        for (auto const& llmReq : activeRequests)
        {
            terminateRequest(llmReq);
        }
        throw;
    }

    ++mIterCounter;
    TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}

void TrtEncoderModel::terminateRequest(std::shared_ptr<LlmRequest> const& llmReq, bool pause)
{
    // For encoder-only models, just change the request state here; might need to do more when using an async forward.
    // For enc-dec models, only remove the cross KV cache after decoder generation has finished.
    TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
    if (llmReq->isEncoderInitState())
    {
        llmReq->setState(LlmRequestState::kCONTEXT_INIT);
    }
    else
    {
        TLLM_LOG_DEBUG("Non-encoder request terminated in encoder model: id %lu", llmReq->mRequestId);
    }
    TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}

void TrtEncoderModel::terminateRequestSync(
    std::shared_ptr<LlmRequest> const& llmReq, executor::FinishReason finishReason)
{
    terminateRequest(llmReq, false);
    llmReq->finishByReason(finishReason);
    llmReq->clearGeneratedTokens();
}
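
// Copies the packed "encoder_output" tensor to host, slices it per request and stores each slice via
// LlmRequest::setEncoderOutputHost(). This is a blocking copy; see the TODO below about making it asynchronous.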
void TrtEncoderModel::fillEncoderOutputSync(RequestVector const& requestList, TensorMap outputTensors)
{
    auto const totalTokensNb = outputTensors["encoder_output"]->getShape().d[0];
    auto const encoderOutputDtype = mRuntime->getEngine().getTensorDataType("encoder_output");
    SizeType32 const bytesPerValue = (encoderOutputDtype == nvinfer1::DataType::kFLOAT) ? 4 : 2;
    std::vector<char> encoderOutputHost(
        totalTokensNb * mHiddenSize * bytesPerValue * mWorldConfig.getTensorParallelism());
    TLLM_CHECK_WITH_INFO(encoderOutputHost.size() > 0, "Encoder output size is 0!");
    getBufferManager().copy(*(outputTensors["encoder_output"]), reinterpret_cast<void*>(encoderOutputHost.data()));
    getBufferManager().getStream().synchronize();
    // TODO: change engine call to async to improve perf. Also need to store output buffers, cuda events, etc.

    auto encoderOutputHostPtr = encoderOutputHost.data();
    for (auto const& llmReq : requestList)
    {
        SizeType32 const seqLen = llmReq->getEncoderOutputLen();
        TensorPtr currentEncoderOutput
            = mCopyBufferManager.copyFrom(reinterpret_cast<half const*>(encoderOutputHostPtr),
                ITensor::makeShape({seqLen, mHiddenSize * mWorldConfig.getTensorParallelism()}), MemoryType::kCPU);
        llmReq->setEncoderOutputHost(currentEncoderOutput);
        encoderOutputHostPtr += seqLen * mHiddenSize * bytesPerValue * mWorldConfig.getTensorParallelism();

        if (llmReq->isEncoderInitState())
        {
            llmReq->setState(LlmRequestState::kCONTEXT_INIT);
        }
        else
        {
            TLLM_LOG_DEBUG("Non-encoder request terminated in encoder model: id %lu", llmReq->mRequestId);
        }
    }
}

void TrtEncoderModel::executeBatch(RequestVector const& requestList)
{
    TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
    NVTX3_SCOPED_RANGE(executeBatch);

    auto const modelName = mModelConfig.getModelName();
    TLLM_CHECK_WITH_INFO(modelName == "EncoderModel" || modelName == "WhisperEncoder", "Model not supported.");

    TensorMap inputTensors;
    TensorMap outputTensors;
    TensorPtr rankOutput;
    std::vector<TokenIdType> inputIdsHost;
    std::vector<SizeType32> positionIdsHost;
    SizeType32 totalOutputLength = 0;
    SizeType32 totalInputLength = 0;
    std::vector<SizeType32> inputLengthsHost;
    std::vector<char> inputFeaturesHost;
    inputLengthsHost.reserve(requestList.size());
    SizeType32 maxInputLengthHost = 0;
    for (auto const& llmReq : requestList)
    {
        SizeType32 length = 0;
        if (mModelConfig.getModelName() == "EncoderModel")
        {
            auto const& reqTokens = *(llmReq->getEncoderTokens().value());
            length = reqTokens.size();
            inputIdsHost.insert(inputIdsHost.end(), reqTokens.begin(), reqTokens.end());
            maxInputLengthHost = std::max(maxInputLengthHost, static_cast<SizeType32>(length));
        }
        else if (mModelConfig.getModelName() == "WhisperEncoder")
        {
            auto const& reqFeatures = llmReq->getEncoderInputFeatures(); // [length, featureDim]
            length = reqFeatures->getShape().d[0];
            auto const curFeatureBytes = reqFeatures->getSizeInBytes();
            auto const srcPtr = reinterpret_cast<char const*>(reqFeatures->data());
            inputFeaturesHost.insert(inputFeaturesHost.end(), srcPtr, srcPtr + curFeatureBytes);
        }

        auto const oldPosSize = positionIdsHost.size();
        positionIdsHost.resize(oldPosSize + length);
        std::iota(positionIdsHost.begin() + oldPosSize, positionIdsHost.end(), 0);

        totalOutputLength += llmReq->getEncoderOutputLen();
        totalInputLength += length;
        inputLengthsHost.push_back(length);
    }

    TensorPtr hiddenStatesInput;
    TensorPtr inputLengths = getBufferManager().copyFrom(
        inputLengthsHost, ITensor::makeShape({static_cast<SizeType32>(inputLengthsHost.size())}), MemoryType::kGPU);
    inputTensors.emplace("input_lengths", inputLengths);
    if (mModelConfig.getModelName() == "EncoderModel")
    {
        // The shape of maxInputLength indicates the max length; its content is not important.
        TensorPtr maxInputLength
            = getBufferManager().gpu(ITensor::makeShape({maxInputLengthHost}), nvinfer1::DataType::kINT32);
        inputTensors.emplace("max_input_length", maxInputLength);
    }

    // engine outputs
    rankOutput = getBufferManager().gpu(
        ITensor::makeShape({totalOutputLength, mHiddenSize * mWorldConfig.getTensorParallelism()}),
        mModelConfig.getDataType());
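
    // Engine inputs depend on the pipeline-parallel rank: the first PP rank feeds token ids (EncoderModel)
    // or audio features (WhisperEncoder), all other ranks consume the previous rank's hidden states.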
inputTensors.emplace("input_ids", inputIds); inputTensors.emplace("position_ids", positionIds); } else if (mModelConfig.getModelName() == "WhisperEncoder") { auto inputFeaturesHostPtr = inputFeaturesHost.data(); auto const featureDim = requestList.front()->getEncoderInputFeatures()->getShape().d[1]; auto const dtype = requestList.front()->getEncoderInputFeatures()->getDataType(); TensorPtr inputFeatures = getBufferManager().gpu(ITensor::makeShape({totalInputLength, featureDim}), dtype); getBufferManager().copy( reinterpret_cast(inputFeaturesHostPtr), *inputFeatures, runtime::MemoryType::kCPU); TensorPtr positionIds = getBufferManager().copyFrom( positionIdsHost, ITensor::makeShape({totalOutputLength}), MemoryType::kGPU); inputTensors.emplace("input_features", inputFeatures); inputTensors.emplace("position_ids", positionIds); } } else { SizeType32 length = mModelConfig.getModelName() == "WhisperEncoder" ? totalOutputLength : totalInputLength; hiddenStatesInput = getBufferManager().gpu(ITensor::makeShape({length, mHiddenSize * mWorldConfig.getTensorParallelism()}), mModelConfig.getDataType()); inputTensors.emplace("hidden_states_input", hiddenStatesInput); } auto const outputName = mWorldConfig.isLastPipelineParallelRank() ? "encoder_output" : "hidden_states_output"; outputTensors.emplace(outputName, rankOutput); // Set input / output tensors to context, encoder model only have one context mRuntime->setInputTensors(0, inputTensors); mRuntime->setOutputTensors(0, outputTensors); executeContext(0); // copy encoder output to llmRequest, if last PP rank // dispatch result to each llmReq, only needed by the last PP rank // TODO: more dtypes support if (mWorldConfig.isLastPipelineParallelRank()) { fillEncoderOutputSync(requestList, outputTensors); } else { getBufferManager().getStream().synchronize(); } // Update the micro batch ID for next microbatches mMicroBatchId = (mMicroBatchId + 1) % mWorldConfig.getPipelineParallelism(); TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__); } void TrtEncoderModel::forward(RequestVector& activeRequests) { TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__); auto const device = mWorldConfig.getDevice(); TLLM_CUDA_CHECK(cudaSetDevice(device)); try { if (activeRequests.empty()) { return; } executeBatch(activeRequests); } catch (std::exception const& e) { for (auto& req : activeRequests) { terminateRequest(req); } throw; } TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__); } void TrtEncoderModel::setLogitsPostProcessorBatched( std::optional logitsPostProcessorBatched) { TLLM_CHECK_WITH_INFO(!logitsPostProcessorBatched.has_value(), "TrtEncoderModel does not use logits processor."); } void TrtEncoderModel::setReplicateLogitsPostProcessor(bool replicateLogitsPostProcessor) { TLLM_THROW("TrtEncoderModel does not use logits processor."); } bool TrtEncoderModel::getReplicateLogitsPostProcessor() const { TLLM_THROW("TrtEncoderModel does not use logits processor."); } TrtEncoderModel::~TrtEncoderModel() = default; } // namespace tensorrt_llm::batch_manager