TensorRT-LLM/cpp/tensorrt_llm/executor/model.h
Robin Kobus ccff86068e
fix: request termination in pipeline parallelism (#3892)
* feat: Implement synchronous request termination in batch manager

- Added a `terminateRequestSync` method to `TrtEncoderModel` and `TrtGptModelInflightBatching` that handles request termination in the next `forwardSync` call (see the sketch below).
- Updated the existing request termination logic to use the new synchronous method, ensuring generated tokens are cleared appropriately.
- Enhanced logging for clarity in token management during request processing.

* fix: MockedModelCancelRequest

* fix: terminate with timeout

* docs: Update doc string for allottedTimeMs

Signed-off-by: Robin Kobus <19427718+Funatiq@users.noreply.github.com>
2025-05-05 21:51:41 +08:00
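
As a rough illustration of the intended flow (not code from this change: `serveIteration`, `wasCancelledByClient` and `exceededAllottedTime` are hypothetical stand-ins, and `FinishReason::kCANCELLED`/`kTIMED_OUT` are assumed here to be the cancellation/timeout finish reasons): a request is only *marked* for termination, and the model actually finishes it in the next `forwardSync` call that includes it.

```cpp
#include "tensorrt_llm/batch_manager/common.h"
#include "tensorrt_llm/executor/model.h"
#include "tensorrt_llm/executor/types.h"

#include <memory>

namespace tb = tensorrt_llm::batch_manager;
namespace tle = tensorrt_llm::executor;

// Hypothetical predicates standing in for however the executor detects a
// client-side cancellation and expiry of a request's allottedTimeMs.
bool wasCancelledByClient(std::shared_ptr<tb::LlmRequest> const& req);
bool exceededAllottedTime(std::shared_ptr<tb::LlmRequest> const& req);

// Illustrative serving iteration: requests are only *marked* for termination here;
// the model finishes them in the next forwardSync call that includes them.
void serveIteration(tle::Model& model, tb::RequestList const& activeRequests)
{
    for (auto const& llmRequest : activeRequests)
    {
        if (wasCancelledByClient(llmRequest))
        {
            model.terminateRequestSync(llmRequest, tle::FinishReason::kCANCELLED);
        }
        else if (exceededAllottedTime(llmRequest))
        {
            model.terminateRequestSync(llmRequest, tle::FinishReason::kTIMED_OUT);
        }
    }

    model.forwardAsync(activeRequests); // schedule one iteration over the active requests
    model.forwardSync();                // decoder sync; requests marked above are terminated here
}
```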

/*
* SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include "tensorrt_llm/batch_manager/common.h"
#include "tensorrt_llm/batch_manager/logitsPostProcessor.h"
#include "tensorrt_llm/runtime/bufferManager.h"
#include "tensorrt_llm/runtime/modelConfig.h"
#include "tensorrt_llm/runtime/worldConfig.h"
#include <nlohmann/json.hpp>
namespace tensorrt_llm::executor
{
class Model
{
using LlmRequestPtr = std::shared_ptr<batch_manager::LlmRequest>;
public:
Model() = default;
virtual ~Model() = default;
/// @brief Function that marks a request as complete and cleans up associated state
virtual void terminateRequest(LlmRequestPtr const& llmRequest, bool pause) = 0;
void terminateRequest(LlmRequestPtr const& llmRequest)
{
terminateRequest(llmRequest, false);
}
/// @brief Terminate request in the next forwardSync call that includes the request.
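/// @param llmRequest The request to terminate
/// @param finishReason The finish reason to set for the terminated request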
virtual void terminateRequestSync(LlmRequestPtr const& llmRequest, FinishReason finishReason) = 0;
/// @brief Function that synchronizes the decoder
virtual void forwardSync() = 0;
/// @brief Function that tries to advance the active requests
/// Depending on resources available, it's possible that not all requests will get advanced
/// @param activeRequests The list of requests to try to advance
virtual void forwardAsync(batch_manager::RequestList const& activeRequests) = 0;
/// @brief Override the runtime batch size for the model
virtual void setRuntimeBatchSize(SizeType32 runtimeBatchSize)
{
// By default, we ignore the runtimeBatchSize unless the model actively supports it
}
/// @brief Get the runtime batch size for the model
[[nodiscard]] virtual SizeType32 getRuntimeBatchSize() const
{
TLLM_CHECK_WITH_INFO(false, "getRuntimeBatchSize is not implemented");
}
/// @brief Override the runtime max num tokens for the model
virtual void setRuntimeMaxNumTokens(SizeType32 runtimeMaxNumTokens)
{
// By default, we ignore the runtimeMaxNumTokens unless the model actively supports it
}
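/// @brief Update the PEFT cache with the adapter weights associated with the request, if any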
virtual void updatePeftCache(LlmRequestPtr const& llmRequest) = 0;
/// @brief Reset the iteration stats when there are no inflight requests
virtual void resetIterationStats() = 0;
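/// @brief Get the maximum number of sequences that can be processed concurrently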
[[nodiscard]] virtual SizeType32 getMaxNumSequences() const = 0;
[[nodiscard]] virtual SizeType32 getMaxInputLen() const = 0;
[[nodiscard]] virtual SizeType32 getHiddenSize() const = 0;
[[nodiscard]] virtual SizeType32 getMaxSequenceLen() const = 0;
[[nodiscard]] virtual SizeType32 getVocabSizePadded() const = 0;
[[nodiscard]] virtual SizeType32 getMaxDraftLen() const = 0;
[[nodiscard]] virtual SizeType32 getNumMicroBatches() const = 0;
[[nodiscard]] virtual SizeType32 getOperatingBeamWidth() const = 0;
[[nodiscard]] virtual nvinfer1::DataType getLogitDataType() const = 0;
[[nodiscard]] virtual runtime::WorldConfig const& getWorldConfig() const = 0;
[[nodiscard]] virtual runtime::ModelConfig const& getModelConfig() const = 0;
[[nodiscard]] virtual runtime::BufferManager const& getBufferManager() const = 0;
[[nodiscard]] virtual runtime::BufferManager::CudaStreamPtr getRuntimeStreamPtr() const = 0;
[[nodiscard]] virtual IterationType getIterCounter() const noexcept = 0;
[[nodiscard]] virtual bool hasSpeculativeDecodingFastLogits() const noexcept = 0;
[[nodiscard]] virtual bool getGatherGenerationLogits() const = 0;
[[nodiscard]] virtual nvinfer1::DataType getTensorDataType(std::string const& name) const = 0;
[[nodiscard]] virtual nvinfer1::Dims getTensorShape(std::string const& name) const = 0;
/// @brief Function that provides per iteration stats specific to a certain model
/// @param stats The iteration stats to be updated
virtual void getCurrentIterationStats(IterationStats& stats) const = 0;
/// @brief Function that provides per request stats specific to a certain model
/// @param stats The request stats to be updated
virtual void getCurrentRequestStats(RequestStatsPerIteration& stats) const = 0;
[[nodiscard]] virtual DebugTensorsPerIteration getCurrentDebugTensors() const = 0;
using LogitsPostProcessorBatched = tensorrt_llm::batch_manager::LogitsPostProcessor::LogitsPostProcessorBatched;
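/// @brief Set the batched logits post-processor callback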
virtual void setLogitsPostProcessorBatched(std::optional<LogitsPostProcessorBatched> logitsPostProcessorBatched)
= 0;
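/// @brief Set whether the logits post-processor is replicated across ranks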
virtual void setReplicateLogitsPostProcessor(bool replicateLogitsPostProcessor) = 0;
[[nodiscard]] virtual bool getReplicateLogitsPostProcessor() const = 0;
[[nodiscard]] virtual bool hasGuidedDecoder() const noexcept = 0;
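/// @brief Get the KV cache manager of the model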
[[nodiscard]] virtual std::shared_ptr<tensorrt_llm::batch_manager::kv_cache_manager::BaseKVCacheManager>
getKVCacheManager() = 0;
[[nodiscard]] virtual std::shared_ptr<tensorrt_llm::batch_manager::kv_cache_manager::BaseKVCacheManager const>
getKVCacheManager() const = 0;
/// @brief Get the batch size that can fill the KV cache to maximum capacity given the sequence lengths
/// @param inputLength The input sequence length
/// @param outputLength The output sequence length
/// @return The batch size that can fill the KV cache to maximum capacity. If unsupported, returns 0.
[[nodiscard]] virtual SizeType32 getMaxCapacityBatchSize(SizeType32 inputLength, SizeType32 outputLength) const = 0;
};
} // namespace tensorrt_llm::executor