mirror of
https://github.com/NVIDIA/TensorRT-LLM.git
synced 2026-01-14 06:27:45 +08:00
Signed-off-by: Nave Assaf <nassaf@nvidia.com> Signed-off-by: Wangshanshan <30051912+dominicshanshan@users.noreply.github.com> Signed-off-by: yechank <161688079+yechank-nvidia@users.noreply.github.com> Signed-off-by: Balaram Buddharaju <169953907+brb-nv@users.noreply.github.com> Signed-off-by: Iman Tabrizian <10105175+tabrizian@users.noreply.github.com> Signed-off-by: qqiao <qqiao@nvidia.com> Signed-off-by: Superjomn <328693+Superjomn@users.noreply.github.com> Signed-off-by: Bo Deng <deemod@nvidia.com> Signed-off-by: Jin Li <59594262+liji-nv@users.noreply.github.com> Signed-off-by: Yifei Zhang <219273404+yifeizhang-c@users.noreply.github.com> Signed-off-by: Amit Zuker <203509407+amitz-nv@users.noreply.github.com> Signed-off-by: Erin Ho <14718778+hchings@users.noreply.github.com> Signed-off-by: Chenfei Zhang <chenfeiz@nvidia.com> Signed-off-by: Christina Zhang <83400082+ChristinaZ@users.noreply.github.com> Signed-off-by: Venky Ganesh <23023424+venkywonka@users.noreply.github.com> Signed-off-by: Pamela <179191831+pamelap-nvidia@users.noreply.github.com> Signed-off-by: Hui Gao <huig@nvidia.com> Signed-off-by: Alexandre Milesi <30204471+milesial@users.noreply.github.com> Signed-off-by: Shixiaowei02 <39303645+Shixiaowei02@users.noreply.github.com> Signed-off-by: Michal Guzek <mguzek@nvidia.com> Signed-off-by: peaceh <103117813+peaceh-nv@users.noreply.github.com> Signed-off-by: nv-guomingz <137257613+nv-guomingz@users.noreply.github.com> Signed-off-by: Wanli Jiang <35160485+Wanli-Jiang@users.noreply.github.com> Signed-off-by: Patrice Castonguay <55748270+pcastonguay@users.noreply.github.com> Signed-off-by: ruodil <200874449+ruodil@users.noreply.github.com> Signed-off-by: Linda-Stadter <57756729+Linda-Stadter@users.noreply.github.com> Signed-off-by: Yuxian Qiu <142763828+yuxianq@users.noreply.github.com> Signed-off-by: Jiagan Cheng <jiaganc@nvidia.com> Signed-off-by: William Zhang <133824995+2ez4bz@users.noreply.github.com> Signed-off-by: Dom Brown 
<3886319+DomBrown@users.noreply.github.com> Co-authored-by: Nave Assaf <55059536+Naveassaf@users.noreply.github.com> Co-authored-by: Yechan Kim <161688079+yechank-nvidia@users.noreply.github.com> Co-authored-by: brb-nv <169953907+brb-nv@users.noreply.github.com> Co-authored-by: Iman Tabrizian <10105175+Tabrizian@users.noreply.github.com> Co-authored-by: Emma Qiao <qqiao@nvidia.com> Co-authored-by: Yan Chunwei <328693+Superjomn@users.noreply.github.com> Co-authored-by: Bo Deng <deemod@nvidia.com> Co-authored-by: Jin Li <59594262+liji-nv@users.noreply.github.com> Co-authored-by: yifeizhang-c <219273404+yifeizhang-c@users.noreply.github.com> Co-authored-by: amitz-nv <203509407+amitz-nv@users.noreply.github.com> Co-authored-by: Erin <14718778+hchings@users.noreply.github.com> Co-authored-by: chenfeiz0326 <chenfeiz@nvidia.com> Co-authored-by: ChristinaZ <83400082+ChristinaZ@users.noreply.github.com> Co-authored-by: Venky <23023424+venkywonka@users.noreply.github.com> Co-authored-by: Pamela Peng <179191831+pamelap-nvidia@users.noreply.github.com> Co-authored-by: HuiGao-NV <huig@nvidia.com> Co-authored-by: milesial <milesial@users.noreply.github.com> Co-authored-by: Shi Xiaowei <39303645+Shixiaowei02@users.noreply.github.com> Co-authored-by: Michal Guzek <moraxu@users.noreply.github.com> Co-authored-by: peaceh-nv <103117813+peaceh-nv@users.noreply.github.com> Co-authored-by: Guoming Zhang <137257613+nv-guomingz@users.noreply.github.com> Co-authored-by: Wanli Jiang <35160485+Wanli-Jiang@users.noreply.github.com> Co-authored-by: pcastonguay <55748270+pcastonguay@users.noreply.github.com> Co-authored-by: ruodil <200874449+ruodil@users.noreply.github.com> Co-authored-by: Linda <57756729+Linda-Stadter@users.noreply.github.com> Co-authored-by: Zhanrui Sun <184402041+ZhanruiSunCh@users.noreply.github.com> Co-authored-by: Yuxian Qiu <142763828+yuxianq@users.noreply.github.com> Co-authored-by: Jiagan Cheng <jiaganc@nvidia.com> Co-authored-by: William Zhang 
<133824995+2ez4bz@users.noreply.github.com> Co-authored-by: Larry <197874197+LarryXFly@users.noreply.github.com> Co-authored-by: Sharan Chetlur <116769508+schetlur-nv@users.noreply.github.com> Co-authored-by: Dom Brown <3886319+DomBrown@users.noreply.github.com>
89 lines
3.3 KiB
C++
89 lines
3.3 KiB
C++
/*
 * SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: Apache-2.0
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
|
|
|
|
#include "tensorrt_llm/batch_manager/logitsPostProcessor.h"

#include "tensorrt_llm/batch_manager/decoderBuffers.h"
#include "tensorrt_llm/batch_manager/llmRequest.h"
#include "tensorrt_llm/batch_manager/runtimeBuffers.h"
#include "tensorrt_llm/common/assert.h"
#include "tensorrt_llm/common/nvtxUtils.h"
#include "tensorrt_llm/runtime/iTensor.h"
|
|
|
namespace tr = tensorrt_llm::runtime;
|
|
|
|
namespace tensorrt_llm::batch_manager
|
|
{
|
|
|
|
using TensorPtr = runtime::ITensor::SharedPtr;
|
|
using ITensor = runtime::ITensor;
|
|
using SizeType32 = tensorrt_llm::runtime::SizeType32;
|
|
|
|
bool LogitsPostProcessor::operator()(DecoderInputBuffers& inputBuffers, bool replicateLogitsPostProcessor,
|
|
tr::WorldConfig const& worldConfig, CudaStreamPtr const& stream,
|
|
std::optional<LogitsPostProcessorBatched> const& logitsPostProcessorBatched) const
|
|
{
|
|
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
|
|
NVTX3_SCOPED_RANGE(LogitsPostProcessor);
|
|
|
|
// Arguments for batched processor
|
|
std::vector<LlmRequest::RequestIdType> reqIdsVec;
|
|
std::vector<LlmRequest::TensorPtr> logitsVec;
|
|
std::vector<std::reference_wrapper<LlmRequest::BeamTokens const>> beamTokensVec;
|
|
std::vector<std::optional<LlmRequest::RequestIdType>> clientIdsVec;
|
|
|
|
bool logitsPostProcessorIsApplied = false;
|
|
for (size_t batchIdx = 0; batchIdx < inputBuffers.decoderRequests.size(); ++batchIdx)
|
|
{
|
|
auto const& llmReq = inputBuffers.decoderRequests.at(batchIdx);
|
|
auto& logits = inputBuffers.logits.at(batchIdx);
|
|
|
|
// Invoke non-batched processor or collect arguments for batched processor
|
|
if (llmReq->mLogitsPostProcessor)
|
|
{
|
|
logitsPostProcessorIsApplied = true;
|
|
if (replicateLogitsPostProcessor || worldConfig.isFirstTensorParallelRank())
|
|
{
|
|
(*llmReq->mLogitsPostProcessor)(
|
|
llmReq->mRequestId, logits, llmReq->getTokens(), stream, llmReq->mClientId);
|
|
}
|
|
}
|
|
else if (llmReq->mApplyLogitsPostProcessorBatched)
|
|
{
|
|
reqIdsVec.push_back(llmReq->mRequestId);
|
|
logitsVec.push_back(logits);
|
|
beamTokensVec.emplace_back(llmReq->getTokens());
|
|
clientIdsVec.push_back(llmReq->mClientId);
|
|
}
|
|
}
|
|
|
|
// Invoke batched processor
|
|
if (!reqIdsVec.empty())
|
|
{
|
|
logitsPostProcessorIsApplied = true;
|
|
if (replicateLogitsPostProcessor || worldConfig.isFirstTensorParallelRank())
|
|
{
|
|
(*logitsPostProcessorBatched)(reqIdsVec, logitsVec, beamTokensVec, stream, clientIdsVec);
|
|
}
|
|
}
|
|
|
|
TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
|
|
|
|
return logitsPostProcessorIsApplied;
|
|
}
|
|
|
|
} // namespace tensorrt_llm::batch_manager
|