TensorRT-LLM/cpp/tensorrt_llm/runtime/promptTuningParams.cpp

/*
 * Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "tensorrt_llm/runtime/promptTuningParams.h"

namespace tensorrt_llm::runtime
{
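
// Fills the `tasks` GPU tensor from the per-request task ids in `tasksHost`: one entry per
// prompt token for packed context requests, and one entry per sequence or beam otherwise.
// Requests with prompt tuning disabled fall back to task id 0.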
void PromptTuningParams::fillTasksTensor(TensorPtr tasksHost, SizeType32 const batchSize,
    SizeType32 const numContextRequests, std::vector<SizeType32> const& reqBeamWidths,
    std::vector<SizeType32> const& reqPromptLengths, BufferManager const& manager, bool packedInput)
{
    auto const& tasksHostShape = tasksHost->getShape();
    TLLM_CHECK_WITH_INFO(tasksHostShape.nbDims == 1, "tasksHost expected to have dimension [batchSize]");
    TLLM_CHECK_WITH_INFO(tasksHostShape.d[0] == batchSize, "tasksHost expected to have dimension [batchSize]");

    auto const tasksHostPtr = bufferCast<SizeType32 const>(*tasksHost);
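
    // With padded (non-packed) inputs the batch must be homogeneous: only context requests or
    // only generation requests. Packed batches may mix the two, since each request contributes
    // exactly as many task entries as it has prompt tokens (context) or beams (generation).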
    bool validInput = packedInput || numContextRequests == batchSize || numContextRequests == 0;
    TLLM_CHECK_WITH_INFO(validInput,
        "fillTasksTensor with non-packed inputs must be called with only context requests or only generation "
        "requests.");

    bool validShapes = (static_cast<SizeType32>(reqBeamWidths.size()) == batchSize
        && static_cast<SizeType32>(reqPromptLengths.size()) == numContextRequests
        && static_cast<SizeType32>(promptTuningEnabled.size()) == batchSize);
    TLLM_CHECK_WITH_INFO(validShapes,
        "Invalid inputs to fillTasksTensor function. reqBeamWidths and promptTuningEnabled size must be batchSize "
        "and reqPromptLengths size must be numContextRequests");
    SizeType32 totalInputSize = 0;
    std::vector<SizeType32> promptTasksHost;
    for (SizeType32 bid = 0; bid < batchSize; bid++)
    {
        SizeType32 const taskId = promptTuningEnabled[bid] ? tasksHostPtr[bid] : 0;
        if (packedInput)
        {
            if (bid < numContextRequests)
            {
                totalInputSize += reqPromptLengths[bid];
                promptTasksHost.insert(promptTasksHost.end(), reqPromptLengths[bid], taskId);
            }
            else
            {
                promptTasksHost.insert(promptTasksHost.end(), reqBeamWidths[bid], taskId);
                totalInputSize += reqBeamWidths[bid];
            }
        }
        else
        {
            if (bid < numContextRequests)
            {
                promptTasksHost.push_back(taskId);
                ++totalInputSize;
            }
            else
            {
                promptTasksHost.insert(promptTasksHost.end(), reqBeamWidths[bid], taskId);
                totalInputSize += reqBeamWidths[bid];
            }
        }
    }
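
    // Upload to the GPU: the packed layout is a flat [totalInputSize] vector, while the padded
    // layout is [totalInputSize, 1] so that each sequence/beam row carries its own task id.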
    if (packedInput)
    {
        tasks = manager.copyFrom(
            promptTasksHost, runtime::ITensor::makeShape({totalInputSize}), runtime::MemoryType::kGPU);
    }
    else
    {
        tasks = manager.copyFrom(
            promptTasksHost, runtime::ITensor::makeShape({totalInputSize, 1}), runtime::MemoryType::kGPU);
    }
}
} // namespace tensorrt_llm::runtime
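
// A minimal usage sketch (hedged: the identifiers below, such as `manager` and the literal task
// ids and lengths, are illustrative assumptions rather than values taken from this file). It
// fills a two-request, context-only, packed batch:
//
//     using namespace tensorrt_llm::runtime;
//     PromptTuningParams params;
//     params.promptTuningEnabled = {true, false};  // request 0 uses a tuned prompt, request 1 does not
//     auto tasksHost = BufferManager::cpu(ITensor::makeShape({2}), nvinfer1::DataType::kINT32);
//     bufferCast<SizeType32>(*tasksHost)[0] = 3;   // task id for request 0
//     bufferCast<SizeType32>(*tasksHost)[1] = 0;   // ignored, prompt tuning disabled
//     params.fillTasksTensor(std::move(tasksHost), /*batchSize=*/2, /*numContextRequests=*/2,
//         /*reqBeamWidths=*/{1, 1}, /*reqPromptLengths=*/{8, 5}, manager, /*packedInput=*/true);
//     // params.tasks now holds 13 task ids on the GPU: 8 copies of 3 followed by 5 zeros.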