TensorRT-LLM/cpp/tensorrt_llm/batch_manager/trtEncoderModel.h

/*
* SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#pragma once

#include "tensorrt_llm/runtime/iGptDecoderBatched.h"
#include "tensorrt_llm/runtime/rawEngine.h"
#include "tensorrt_llm/runtime/utils/mpiUtils.h"
#include "trtGptModel.h"
#include <NvInferRuntime.h>

namespace tensorrt_llm::runtime
{
class TllmRuntime;
class NcclCommunicator;
} // namespace tensorrt_llm::runtime

namespace tensorrt_llm::batch_manager
{
class CapacityScheduler;
class MicroBatchScheduler;
class EncoderBuffers;
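
// Batch-manager model that runs a TensorRT encoder engine over scheduled requests.
// Illustrative usage sketch only; the config objects, engine, logger and request list
// are assumed to be built elsewhere by the surrounding batch manager / executor code:
//
//   TrtEncoderModel model{modelConfig, worldConfig, rawEngine, logger, optionalParams};
//   model.forwardAsync(activeRequests); // schedule and launch one micro batch
//   model.forwardSync();                // synchronize with the launched micro batch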
class TrtEncoderModel : public TrtGptModel
{
public:
using SizeType32 = tensorrt_llm::runtime::SizeType32;
using TokenIdType = tensorrt_llm::runtime::TokenIdType;
using BufferManager = tensorrt_llm::runtime::BufferManager;
using TensorMap = runtime::StringPtrMap<runtime::ITensor>;
using TensorPtr = runtime::ITensor::SharedPtr;
TrtEncoderModel(runtime::ModelConfig const& modelConfig, runtime::WorldConfig const& worldConfig,
runtime::RawEngine const& rawEngine, std::shared_ptr<nvinfer1::ILogger> logger,
TrtGptModelOptionalParams const& optionalParams);
~TrtEncoderModel() override;
void terminateRequest(std::shared_ptr<LlmRequest> const& llmRequest, bool pause = false) override;
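// Execution entry points: forwardAsync() schedules the active requests and launches engine
// work for one micro batch; forwardSync() synchronizes with the previously launched work.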
void forward(RequestVector& activeRequests);
void forwardSync() override;
void forwardAsync(RequestList const& activeRequests) override;
[[nodiscard]] runtime::BufferManager const& getBufferManager() const override;
[[nodiscard]] runtime::BufferManager::CudaStreamPtr getRuntimeStreamPtr() const override;
runtime::ModelConfig const& getModelConfig() const override
{
return mModelConfig;
}
[[nodiscard]] bool getGatherGenerationLogits() const override
{
return getModelConfig().computeGenerationLogits();
}
runtime::WorldConfig const& getWorldConfig() const override
{
return mWorldConfig;
}
[[nodiscard]] SizeType32 getHiddenSize() const override
{
return mHiddenSize;
}
[[nodiscard]] SizeType32 getMaxInputLen() const override
{
return mMaxInputLen;
}
[[nodiscard]] SizeType32 getNumMicroBatches() const override
{
return mNumMicroBatches;
}
[[nodiscard]] nvinfer1::DataType getLogitDataType() const override
{
return getModelConfig().getDataType();
}
nvinfer1::DataType getTensorDataType(std::string const& name) const override
{
auto const& engine = mRuntime->getEngine();
return engine.getTensorDataType(name.c_str());
}
nvinfer1::Dims getTensorShape(std::string const& name) const override
{
auto const& engine = mRuntime->getEngine();
return engine.getTensorShape(name.c_str());
}
[[nodiscard]] TrtGptModelType getModelType() const override
{
throw std::runtime_error("TrtEncoderModel does not have model type."); // FIXME:
}
[[nodiscard]] executor::IterationType getIterCounter() const noexcept override
{
return mIterCounter;
}
void updatePeftCache(std::shared_ptr<LlmRequest> const& /*llmRequest*/) override
{
throw std::runtime_error("TrtEncoderModel does not have Peft Cache.");
}
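// Statistics, debug-tensor and layer-profiling hooks, plus logits post-processor configuration.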
void getCurrentIterationStats(executor::IterationStats& stats) const override;
void getCurrentRequestStats(executor::RequestStatsPerIteration& stats) const override;
[[nodiscard]] executor::DebugTensorsPerIteration getCurrentDebugTensors() const override;
void setLayerProfiler() override;
std::string getLayerProfileInfo() const override;
void setLogitsPostProcessorBatched(std::optional<LogitsPostProcessorBatched> logitsPostProcessorBatched) override;
void setReplicateLogitsPostProcessor(bool replicateLogitsPostProcessor) override;
[[nodiscard]] bool getReplicateLogitsPostProcessor() const override;
void resetIterationStats() override {}
[[nodiscard]] SizeType32 getMaxCapacityBatchSize(SizeType32 inputLength, SizeType32 outputLength) const override
{
return 0;
}

protected:
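// The encoder model has no KV cache and no PEFT cache; these accessors always throw.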
std::shared_ptr<kv_cache_manager::BaseKVCacheManager> getKVCacheManager() override
{
throw std::runtime_error("TrtEncoderModel does not have KVCache.");
}
[[nodiscard]] std::shared_ptr<kv_cache_manager::BaseKVCacheManager const> getKVCacheManager() const override
{
throw std::runtime_error("TrtEncoderModel does not have KVCache.");
}
[[nodiscard]] std::shared_ptr<BasePeftCacheManager> getPeftCacheManager() override
{
throw std::runtime_error("TrtEncoderModel does not use PEFT.");
}
[[nodiscard]] std::shared_ptr<BasePeftCacheManager const> getPeftCacheManager() const override
{
throw std::runtime_error("TrtEncoderModel does not use PEFT.");
}

private:
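// Buffers are indexed per micro batch, so the current micro batch id also serves as the buffer id.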
[[nodiscard]] SizeType32 getBufferId() const
{
return mMicroBatchId;
}
void createRuntimeContexts();
void executeContext(SizeType32 runtimeContextId);
void createBuffers();
void executeBatch(RequestVector const& requestList);
void executeBatch(ScheduledRequests const& scheduledRequests);
void rearrangeOutputs(ScheduledRequests const& scheduledRequests);
void createCustomAllReduceWorkspace();
void fillEncoderOutputSync(RequestVector const& requestList, TensorMap outputTensors);
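// Configuration and runtime handles fixed at construction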
runtime::ModelConfig const mModelConfig;
runtime::WorldConfig const mWorldConfig;
int mDevice{-1};
std::shared_ptr<tensorrt_llm::mpi::MpiComm> mMpiCommPipelinePara;
std::shared_ptr<nvinfer1::ILogger> mLogger;
std::shared_ptr<runtime::TllmRuntime> mRuntime;
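// Micro-batch scheduling state: current buffer index, per-batch buffers and scheduled
// requests, in-flight / to-be-paused request ids, and the capacity and micro-batch schedulers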
SizeType32 mMicroBatchId;
// TODO: Add runtime buffers for async PP
std::vector<std::shared_ptr<EncoderBuffers>> mBuffers;
SizeType32 mNumMicroBatches;
SizeType32 mNumBuffers;
std::vector<ScheduledRequests> mMicroBatchScheduledRequests;
ReqIdsSet mInflightReqIds;
ReqIdsSet mReqIdsToPause;
std::unique_ptr<tensorrt_llm::batch_manager::CapacityScheduler const> mCapacityScheduler;
std::unique_ptr<tensorrt_llm::batch_manager::MicroBatchScheduler const> mMicroBatchScheduler;
SizeType32 mHiddenSize; // already divided by Tensor Parallelism
SizeType32 mMaxInputLen; // WAR for max_input_len == max_seq_len under all circumstances
runtime::BufferManager mCopyBufferManager;
// Iteration counter used to distinguish debug output
executor::IterationType mIterCounter{0};
};
} // namespace tensorrt_llm::batch_manager