/*
 * SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: Apache-2.0
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "inflightBatchingUtils.h"
|
|
#include "tensorrt_llm/runtime/runtimeKernels.h"
|
|
|
|
namespace tensorrt_llm::batch_manager::utils
|
|
{
|
|
using ITensor = runtime::ITensor;
|
|
|
|
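// Collects the request IDs of all scheduled requests into a single CPU tensor,
// ordered as context requests followed by generation requests.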
TensorPtr collectRequestIds(RequestVector const& contextRequests, RequestVector const& generationRequests)
{
    auto const numRequests = static_cast<ITensor::DimType64>(contextRequests.size() + generationRequests.size());
    auto requestIds
        = runtime::BufferManager::cpu(ITensor::makeShape({numRequests}), runtime::TRTDataType<RequestIdType>::value);
    auto requestIdsRange = runtime::BufferRange<RequestIdType>(*requestIds);
    auto batchIdx{0};
    for (auto const& requests : {contextRequests, generationRequests})
    {
        for (auto const& request : requests)
        {
            requestIdsRange[batchIdx++] = request->mRequestId;
        }
    }
    return requestIds;
}

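// Sorts the scheduled context and generation requests by LoRA task id, so that requests
// using the same LoRA adapter end up adjacent in the batch.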
void sortByLoraId(ScheduledRequests& scheduledRequests)
{
    auto sortRequests = [](RequestVector& requests)
    {
        std::sort(requests.begin(), requests.end(),
            [](auto const& lhs, auto const& rhs) { return lhs->getLoraTaskId() < rhs->getLoraTaskId(); });
    };
    sortRequests(scheduledRequests.contextRequests);
    sortRequests(scheduledRequests.generationRequests);
}

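// Merges the cached generation logits fragments of llmReq on device and copies them to the
// request's host buffer, beam by beam. `beforeDecoder` selects whether the token generated in
// the current iteration is already counted by getNumTokens(); `numDroppedTokens` shrinks the
// copied fragment per beam.
//
// Illustrative example (hypothetical numbers): with fragmentSize == 4, no dropped tokens,
// mPromptLen == 10 and getNumTokens(beam) == 16 after the decoder (beforeDecoder == false),
// numGenerationToken == 6, so the 4 merged logits rows land at host offsets [2, 6) of that beam.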
void copyGenerationLogits(RuntimeBuffers::GenerationLogitsCache& generationLogitsCache,
    runtime::BufferManager const& bufferManager, LlmRequest& llmReq, bool beforeDecoder,
    std::vector<SizeType32> const& numDroppedTokens)
{
    TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);

    TLLM_CHECK_WITH_INFO(
        !beforeDecoder || numDroppedTokens.empty(), "numDroppedTokens are only possible after decoder.");

    auto const reqBeamWidth = llmReq.mSamplingConfig.beamWidth;
    TLLM_CHECK_WITH_INFO(numDroppedTokens.empty() || numDroppedTokens.size() == static_cast<size_t>(reqBeamWidth),
        "Dropped tokens have to be defined for all beams.");

    auto const fragmentSize = llmReq.getGenerationLogitsFragmentsSize();

    // Merge logits fragments on device
    auto const& transposeBufferPtr = generationLogitsCache.transposedLogits;
    auto const& cachePointerDevice = generationLogitsCache.fragmentPointerDevice;
    auto const& cachePointerHost = generationLogitsCache.getFragmentPointerHost();
    tensorrt_llm::runtime::kernels::mergeLogitsFragments(bufferManager, *transposeBufferPtr,
        llmReq.getGenerationLogitsFragments(), *cachePointerDevice, *cachePointerHost, 0, 1, reqBeamWidth,
        bufferManager.getStream(), 0);
    llmReq.clearGenerationLogitsFragments();

    // Copy logits to host
    for (SizeType32 beam = 0; beam < reqBeamWidth; beam++)
    {
        auto const droppedSize = !numDroppedTokens.empty() ? numDroppedTokens.at(beam) : 0;
        // Ignore logits of dropped tokens
        auto const beamFragmentSize = fragmentSize - droppedSize;
        // If this function is called before the decoder, the request does not contain the generated token of the
        // current iteration, so we add 1 to the number of tokens.
        auto const numGenerationToken
            = static_cast<SizeType32>(beforeDecoder) + llmReq.getNumTokens(beam) - llmReq.mPromptLen;
        auto const hostOffset = numGenerationToken - beamFragmentSize;

        // [beamWidth, GENERATION_LOGITS_BUFFER_LENGTH, vocabSizePadded] -> [beamFragmentSize, vocabSizePadded]
        auto beamDeviceTensorPtr = ITensor::slice(transposeBufferPtr, {beam, 0}, beamFragmentSize);
        // [beamWidth, mMaxNewTokens, vocabSizePadded] -> [beamFragmentSize, vocabSizePadded]
        auto beamHostTensorPtr = ITensor::slice(llmReq.getGenerationLogitsHost(), {beam, hostOffset}, beamFragmentSize);
        bufferManager.copy(*beamDeviceTensorPtr, *beamHostTensorPtr);
    }

    TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}

namespace
{

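// Looks up an additional output by name in the configured additionalModelOutputs and in the
// runtime output map. Returns the runtime tensor together with its gatherContext flag and
// fails if the tensor is missing, or if a context output was requested without gatherContext.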
std::pair<TensorPtr const&, bool> findOutputTensor(std::string const& outputTensorName,
|
|
std::vector<executor::AdditionalModelOutput> const& additionalModelOutputs,
|
|
RuntimeBuffers::TensorMap const& outputMap, bool isContext)
|
|
{
|
|
auto const aoIter = std::find_if(additionalModelOutputs.cbegin(), additionalModelOutputs.cend(),
|
|
[&outputTensorName](auto const& ao) { return ao.name == outputTensorName; });
|
|
TLLM_CHECK_WITH_INFO(aoIter != additionalModelOutputs.cend(), "Additional %s output tensor not found: %s",
|
|
isContext ? "context" : "generation", outputTensorName.c_str());
|
|
|
|
auto const gatherContext = aoIter->gatherContext;
|
|
if (isContext)
|
|
{
|
|
TLLM_CHECK_WITH_INFO(
|
|
gatherContext, "Additional context output tensor not gathered: %s", outputTensorName.c_str());
|
|
}
|
|
|
|
auto const tensorIt = outputMap.find(outputTensorName);
|
|
TLLM_CHECK_WITH_INFO(tensorIt != outputMap.end(), "Additional %s output tensor not found: %s",
|
|
isContext ? "context" : "generation", outputTensorName.c_str());
|
|
|
|
return {tensorIt->second, gatherContext};
|
|
}
|
|
|
|
} // namespace
|
|
|
|
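// Copies the additional model outputs of the current batch from the runtime output tensors into
// the per-request output buffers. Tensors with gatherContext hold one row per context token,
// tensors without gatherContext hold one row per context request; both hold one row per beam for
// generation requests. The two running indices below track the read position in each layout, and
// the final-index check at the end verifies that the batch exactly fills the first dimension of
// every additional output tensor.
//
// Illustrative example (hypothetical batch): two context requests with chunk sizes 3 and 5 plus
// one generation request with beam width 2 give a first dimension of 3 + 5 + 2 == 10 for tensors
// with gatherContext and 1 + 1 + 2 == 4 for tensors without it.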
void copyAdditionalOutputs(std::vector<executor::AdditionalModelOutput> const& additionalModelOutputs,
    RequestVector const& contextRequests, RequestVector const& generationRequests,
    RuntimeBuffers::TensorMap const& outputMap, runtime::BufferManager const& manager)
{
    TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);

    // One index shared across all output tensors that have gatherContext
    SizeType32 srcTensorIndexWithContext{0};
    // One index shared across all output tensors that do not have gatherContext
    SizeType32 srcTensorIndexWithoutContext{0};

    for (auto const& llmReq : contextRequests)
    {
        auto numContextTokens = llmReq->getContextChunkSize();
        for (auto const& outputTensor : llmReq->getAdditionalContextOutputs())
        {
            auto const& [tensor, gatherContext]
                = findOutputTensor(outputTensor.first, additionalModelOutputs, outputMap, true);

            auto const srcTensorIndex = srcTensorIndexWithContext;
            auto srcView = ITensor::slice(tensor, srcTensorIndex, numContextTokens);
            auto dstView = ITensor::slice(outputTensor.second, llmReq->getContextCurrentPosition(), numContextTokens);
            manager.copy(*srcView, *dstView);
        }
        srcTensorIndexWithContext += numContextTokens;
        srcTensorIndexWithoutContext += 1;

        // Copy output of last token to generation outputs
        if (llmReq->isLastContextChunk())
        {
            for (auto const& outputTensor : llmReq->getAdditionalGenerationOutputs())
            {
                auto const& [tensor, gatherContext]
                    = findOutputTensor(outputTensor.first, additionalModelOutputs, outputMap, false);

                auto const srcTensorIndex = gatherContext ? srcTensorIndexWithContext : srcTensorIndexWithoutContext;
                auto srcView = ITensor::slice(tensor, srcTensorIndex - 1, 1);
                for (SizeType32 beam = 0; beam < llmReq->mSamplingConfig.beamWidth; beam++)
                {
                    auto dstView = ITensor::slice(outputTensor.second, {beam, 0}, 1);
                    manager.copy(*srcView, *dstView);
                }
            }
        }
    }

    for (auto const& llmReq : generationRequests)
    {
        auto const reqBeamWidth = llmReq->mSamplingConfig.beamWidth;
        for (auto const& outputTensor : llmReq->getAdditionalGenerationOutputs())
        {
            auto const& [tensor, gatherContext]
                = findOutputTensor(outputTensor.first, additionalModelOutputs, outputMap, false);

            auto const srcTensorIndex = gatherContext ? srcTensorIndexWithContext : srcTensorIndexWithoutContext;
            for (SizeType32 beam = 0; beam < reqBeamWidth; beam++)
            {
                auto const generatedLength = llmReq->getNumTokens(beam) - llmReq->getPromptLen();
                TLLM_CHECK(generatedLength >= 1);
                auto srcView = ITensor::slice(tensor, srcTensorIndex + beam, 1);
                auto dstView = ITensor::slice(outputTensor.second, {beam, generatedLength}, 1);
                manager.copy(*srcView, *dstView);
            }
        }
        srcTensorIndexWithContext += reqBeamWidth;
        srcTensorIndexWithoutContext += reqBeamWidth;
    }

    // Check final indices
    for (auto const& outputTensor : additionalModelOutputs)
    {
        auto const& outputTensorName = outputTensor.name;
        auto const gatherContext = outputTensor.gatherContext;

        auto const tensorIt = outputMap.find(outputTensorName);
        TLLM_CHECK_WITH_INFO(
            tensorIt != outputMap.end(), "Additional output tensor not found: %s", outputTensorName.c_str());

        auto const& outputShape = tensorIt->second->getShape();
        auto const outputSize = outputShape.d[0];
        auto const finalIndex = gatherContext ? srcTensorIndexWithContext : srcTensorIndexWithoutContext;

        TLLM_CHECK_WITH_INFO(finalIndex == outputSize, "Additional %s output tensor final index mismatch %d != %ld: %s",
            gatherContext ? "context" : "generation", finalIndex, outputSize, outputTensorName.c_str());
    }

    TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}

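// Releases the resources associated with a finished or paused request: frees its sequence slot,
// removes its sequences from the (cross-)KV cache managers, and notifies the PEFT cache manager.
// If `pause` is set and generation has not completed, the request is paused for later resumption
// instead of being fully terminated.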
void terminateRequest(SequenceSlotManager& seqSlotManager, LlmRequest& llmReq, SizeType32 maxInputLen,
    OptionalRef<kv_cache_manager::BaseKVCacheManager> kvCacheManager,
    OptionalRef<kv_cache_manager::BaseKVCacheManager> crossKvCacheManager,
    OptionalRef<BasePeftCacheManager> peftCacheManager, bool pause)
{
    TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
    // If a sequence slot is associated with this request id, free it
    seqSlotManager.freeSequenceSlot(llmReq.mRequestId);
    // Remove the sequence from kvCacheManager
    auto const requestId = llmReq.mRequestId;
    if (kvCacheManager)
    {
        kvCacheManager->removeSequence(requestId, llmReq);
    }
    if (crossKvCacheManager)
    {
        crossKvCacheManager->removeSequence(requestId, llmReq);
    }
    if (pause && !llmReq.isGenerationCompleteState())
    {
        llmReq.pause(maxInputLen);
    }
    else
    {
        TLLM_LOG_DEBUG("terminated: request ID %lu, paused: %d", requestId, pause);
    }

    if (peftCacheManager)
    {
        peftCacheManager->markRequestDone(llmReq, pause);
    }
    TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}

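// CudaGraphExecutor: thin wrappers around the CUDA graph API. create() instantiates an executable
// graph from a captured graph, uploadToStream() uploads it to a stream, and launch() runs it;
// all CUDA calls are checked with TLLM_CUDA_CHECK.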
void CudaGraphExecutor::create(cudaGraph_t const& graph)
{
    TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
    assert(mInstance == nullptr);
    TLLM_CUDA_CHECK(cudaGraphInstantiate(&mInstance, graph, nullptr, nullptr, 0));
    TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}

void CudaGraphExecutor::uploadToStream(runtime::CudaStream const& stream)
{
    TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
    assert(hasInstance());
    TLLM_CUDA_CHECK(cudaGraphUpload(mInstance, stream.get()));
    TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}

void CudaGraphExecutor::launch(runtime::CudaStream const& stream)
{
    TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
    TLLM_CUDA_CHECK(cudaGraphLaunch(mInstance, stream.get()));
    TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}

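// Tries to update the existing executable graph in place with a newly captured graph.
// Returns true if the update failed and the executor has to be re-instantiated
// (see prepareNextGraph, which then calls clear() followed by create()).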
bool CudaGraphExecutor::update(cudaGraph_t const& graph)
{
    TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
    return cudaGraphExecUpdate(mInstance, graph, nullptr) != cudaSuccess;
}

void CudaGraphExecutor::clear()
{
    TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
    if (mInstance != nullptr)
    {
        TLLM_CUDA_CHECK(cudaGraphExecDestroy(mInstance));
        mInstance = nullptr;
    }
    TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}

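// Captures the execution of the given TensorRT context into a CUDA graph, then either updates the
// existing executable graph or instantiates a new one, and finally uploads it to the runtime
// stream. The temporary captured graph is destroyed after instantiation.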
void CudaGraphExecutor::prepareNextGraph(std::shared_ptr<runtime::TllmRuntime>& runtime, SizeType32 nextContextId)
{
    TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
    auto& stream = runtime->getStream();

    cudaGraph_t nextGraph;
    TLLM_CUDA_CHECK(cudaStreamBeginCapture(stream.get(), cudaStreamCaptureModeThreadLocal));
    runtime->executeContext(nextContextId);
    TLLM_CUDA_CHECK(cudaStreamEndCapture(stream.get(), &nextGraph));

    if (hasInstance())
    {
        if (update(nextGraph))
        {
            clear();
            create(nextGraph);
        }
    }
    else
    {
        create(nextGraph);
    }

    TLLM_CUDA_CHECK(cudaGraphDestroy(nextGraph));
    uploadToStream(stream);
    TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}

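// Looks up the graph executor cached for the given batch state. On a hit, the entry is moved to
// the front of the LRU list and returned; on a miss, std::nullopt is returned.
//
// Illustrative caller sketch (hypothetical surrounding code and variable names):
//   if (auto executor = graphCache.get(batchState))
//   {
//       (*executor)->launch(stream);
//   }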
std::optional<std::shared_ptr<CudaGraphExecutor>> CudaGraphExecutorCache::get(BatchState const& state)
|
|
{
|
|
auto it = mMap.find(state);
|
|
if (it == mMap.end())
|
|
{
|
|
return std::nullopt;
|
|
}
|
|
mCache.splice(mCache.begin(), mCache, it->second);
|
|
return it->second->second;
|
|
}
|
|
|
|
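// Inserts (or refreshes) the executor for the given batch state as the most recently used entry.
// When the cache grows beyond mCapacity, the least recently used entry is evicted.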
void CudaGraphExecutorCache::put(BatchState const& state, std::shared_ptr<CudaGraphExecutor> const& value)
{
    auto it = mMap.find(state);
    if (it != mMap.end())
    {
        mCache.erase(it->second);
    }
    mCache.emplace_front(BatchStateGraphExecutorPair{state, value});
    mMap[state] = mCache.begin();

    if (static_cast<runtime::SizeType32>(mMap.size()) > mCapacity)
    {
        auto lastState = mCache.back().first;
        mCache.pop_back();
        mMap.erase(lastState);
    }
}

} // namespace tensorrt_llm::batch_manager::utils