/*
* SPDX-FileCopyrightText: Copyright (c) 2023-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "tensorrt_llm/batch_manager/llmRequest.h"
#include "tensorrt_llm/executor/executor.h"
#include "tensorrt_llm/executor/types.h"
#include <gmock/gmock.h>
#include <gtest/gtest.h>
#include <functional>
#include <list>
#include <memory>
#include <numeric>
#include <string>
#include <tuple>
#include <utility>
#include <vector>

namespace tr = tensorrt_llm::runtime;
namespace tc = tensorrt_llm::common;
namespace texec = tensorrt_llm::executor;
namespace tb = tensorrt_llm::batch_manager;
using VecTokens = tb::LlmRequest::VecTokens;
using SizeType32 = tb::LlmRequest::SizeType32;
using VecTokenExtraIds = tb::LlmRequest::VecTokenExtraIds;
using VecUniqueTokens = tb::LlmRequest::VecUniqueTokens;

class LlmRequestTest : public ::testing::Test // NOLINT(cppcoreguidelines-pro-type-member-init)
{
protected:
void SetUp() override {}
void TearDown() override {}
};

TEST_F(LlmRequestTest, fromExecutorRequest)
{
VecTokens inputTokens{1, 2, 3, 4, 5};
SizeType32 maxNewTokens(66);
texec::IdType requestId{77};
{
texec::Request execReq(inputTokens, maxNewTokens);
tb::LlmRequest llmReq(requestId, execReq);
EXPECT_EQ(llmReq.getTokens().size(), 1);
EXPECT_EQ(llmReq.getTokens().at(0), inputTokens);
EXPECT_EQ(llmReq.mMaxNewTokens, maxNewTokens);
EXPECT_EQ(llmReq.mSamplingConfig.numReturnSequences, execReq.getSamplingConfig().getNumReturnSequences());
EXPECT_EQ(llmReq.getOrigPromptLen(), inputTokens.size());
EXPECT_EQ(llmReq.getMaxSentTokenLen(), inputTokens.size());
EXPECT_EQ(llmReq.getState(), tb::LlmRequestState::kCONTEXT_INIT);
EXPECT_FALSE(llmReq.mSeqSlot);
// No speculative decoding config, draft tokens should be empty
EXPECT_EQ(llmReq.getNumDraftTokens(), 0);
EXPECT_FALSE(llmReq.getEmbeddingBias().has_value());
EXPECT_FALSE(llmReq.getBadWordsList().has_value());
EXPECT_FALSE(llmReq.getStopWordsList().has_value());
EXPECT_FALSE(llmReq.getPromptEmbeddingTable().has_value());
EXPECT_FALSE(llmReq.getPromptVocabSize().has_value());
}
// Embedding bias
{
texec::Request execReq(inputTokens, maxNewTokens);
SizeType32 vocabSize = 100;
// Try adding embedding bias
auto embeddingBias = texec::Tensor::cpu(texec::DataType::kFP32, {vocabSize});
execReq.setEmbeddingBias(embeddingBias);
tb::LlmRequest llmReq(requestId, execReq);
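        // The 1-D bias of length vocabSize is stored with an added leading dimension, giving shape [1, vocabSize].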
EXPECT_TRUE(llmReq.getEmbeddingBias().has_value());
EXPECT_EQ(llmReq.getEmbeddingBias().value()->getShape().nbDims, 2);
EXPECT_EQ(llmReq.getEmbeddingBias().value()->getShape().d[0], 1);
EXPECT_EQ(llmReq.getEmbeddingBias().value()->getShape().d[1], vocabSize);
}
// bad/stop words
{
texec::Request execReq(inputTokens, maxNewTokens);
SizeType32 vocabSize = 100;
        // Try adding bad and stop words
std::list<VecTokens> badWords{{1, 2, 3}, {4, 5}, {9}};
std::list<VecTokens> stopWords{{1, 3}, {4}};
execReq.setBadWords(badWords);
execReq.setStopWords(stopWords);
tb::LlmRequest llmReq(requestId, execReq);
EXPECT_TRUE(llmReq.getBadWordsList().has_value());
EXPECT_TRUE(llmReq.getStopWordsList().has_value());
{
auto badWordsTensor = llmReq.getBadWordsList().value();
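            // Packed layout: row 0 holds the concatenated word tokens, row 1 the cumulative end offset of each word, padded with -1.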
EXPECT_EQ(badWordsTensor->getDataType(), nvinfer1::DataType::kINT32);
EXPECT_EQ(badWordsTensor->getShape().nbDims, 3);
EXPECT_EQ(badWordsTensor->getShape().d[0], 1);
EXPECT_EQ(badWordsTensor->getShape().d[1], 2);
EXPECT_EQ(badWordsTensor->getShape().d[2], 6);
auto data = tr::bufferCast<int32_t>(*badWordsTensor);
EXPECT_EQ(data[0], 1);
EXPECT_EQ(data[1], 2);
EXPECT_EQ(data[2], 3);
EXPECT_EQ(data[3], 4);
EXPECT_EQ(data[4], 5);
EXPECT_EQ(data[5], 9);
EXPECT_EQ(data[6 + 0], 3);
EXPECT_EQ(data[6 + 1], 5);
EXPECT_EQ(data[6 + 2], 6);
EXPECT_EQ(data[6 + 3], -1);
EXPECT_EQ(data[6 + 4], -1);
EXPECT_EQ(data[6 + 5], -1);
}
{
auto stopWordsTensor = llmReq.getStopWordsList().value();
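            // Stop words use the same packed layout of concatenated tokens and cumulative end offsets.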
EXPECT_EQ(stopWordsTensor->getDataType(), nvinfer1::DataType::kINT32);
EXPECT_EQ(stopWordsTensor->getShape().nbDims, 3);
EXPECT_EQ(stopWordsTensor->getShape().d[0], 1);
EXPECT_EQ(stopWordsTensor->getShape().d[1], 2);
EXPECT_EQ(stopWordsTensor->getShape().d[2], 3);
auto data = tr::bufferCast<int32_t>(*stopWordsTensor);
EXPECT_EQ(data[0], 1);
EXPECT_EQ(data[1], 3);
EXPECT_EQ(data[2], 4);
EXPECT_EQ(data[3 + 0], 2);
EXPECT_EQ(data[3 + 1], 3);
EXPECT_EQ(data[3 + 2], -1);
}
}
// Prompt tuning
{
texec::Request execReq(inputTokens, maxNewTokens);
SizeType32 vocabSize = 100;
SizeType32 hiddenSize = 64;
auto embeddingTable = texec::Tensor::cpu(texec::DataType::kFP32, {vocabSize, hiddenSize});
VecTokenExtraIds extraIds{1, 1, 1, 0, 0};
texec::PromptTuningConfig config(embeddingTable, extraIds);
execReq.setPromptTuningConfig(config);
tb::LlmRequest llmReq(requestId, execReq);
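        // The embedding table gains a leading batch dimension, and each input token is paired with its extra id to form a unique token.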
EXPECT_TRUE(llmReq.getPromptEmbeddingTable().has_value());
EXPECT_TRUE(llmReq.getPromptVocabSize().has_value());
EXPECT_EQ(llmReq.getPromptEmbeddingTable().value()->getShape().nbDims, 3);
EXPECT_EQ(llmReq.getPromptEmbeddingTable().value()->getShape().d[0], 1);
EXPECT_EQ(llmReq.getPromptEmbeddingTable().value()->getShape().d[1], vocabSize);
EXPECT_EQ(llmReq.getPromptEmbeddingTable().value()->getShape().d[2], hiddenSize);
EXPECT_EQ(llmReq.getPromptEmbeddingTable().value()->getDataType(), nvinfer1::DataType::kFLOAT);
EXPECT_EQ(llmReq.getPromptVocabSize().value(), vocabSize);
VecUniqueTokens uniqueTokens;
for (size_t i = 0; i < inputTokens.size(); ++i)
{
uniqueTokens.push_back({inputTokens[i], extraIds[i]});
}
EXPECT_EQ(llmReq.getUniqueTokens(0), uniqueTokens);
}
}

TEST_F(LlmRequestTest, invalidExecRequest)
{
VecTokens inputTokens{1, 2, 3, 4, 5};
SizeType32 maxNewTokens(66);
texec::IdType requestId{77};
// Input is too long
std::list<std::pair<std::function<void()>, std::string>> lambdaErrMsgs;
{
auto lambda = [&inputTokens, maxNewTokens, requestId]()
{
texec::Request execReq(inputTokens, maxNewTokens);
tb::LlmRequest llmReq(requestId, execReq);
llmReq.validate(2, 1000, 0, 32000);
};
lambdaErrMsgs.emplace_back(lambda, "exceeds maximum input");
}
// Invalid beam width
{
auto lambda = [&inputTokens, maxNewTokens, requestId]()
{
texec::Request execReq(inputTokens, maxNewTokens);
execReq.setSamplingConfig(texec::SamplingConfig(-1));
tb::LlmRequest llmReq(requestId, execReq);
llmReq.validate(500, 1000, 0, 32000);
};
lambdaErrMsgs.emplace_back(lambda, "beamWidth > 0");
}
// Invalid input draft len
{
auto lambda = [&inputTokens, maxNewTokens, requestId]()
{
texec::Request execReq(inputTokens, maxNewTokens);
execReq.setExternalDraftTokensConfig(texec::ExternalDraftTokensConfig({1, 2}));
tb::LlmRequest llmReq(requestId, execReq);
llmReq.validate(500, 1000, 1, 32000);
};
lambdaErrMsgs.emplace_back(lambda, "exceeds maximum draft");
}
// Invalid ptable shape
{
auto lambda = [&inputTokens, maxNewTokens, requestId]()
{
texec::Request execReq(inputTokens, maxNewTokens);
auto embeddingTable = texec::Tensor::cpu(texec::DataType::kFP32, {17, 32, 69});
texec::PromptTuningConfig config(embeddingTable);
execReq.setPromptTuningConfig(config);
tb::LlmRequest llmReq(requestId, execReq);
};
lambdaErrMsgs.emplace_back(lambda, "Expected prompt embedding table to have shape");
}
// Invalid extra id vector's size
{
auto lambda = [&inputTokens, maxNewTokens, requestId]()
{
texec::Request execReq(inputTokens, maxNewTokens);
auto embeddingTable = texec::Tensor::cpu(texec::DataType::kFP32, {4, 8});
VecTokenExtraIds extraIds(inputTokens.size() - 1, 0);
texec::PromptTuningConfig config(embeddingTable, extraIds);
execReq.setPromptTuningConfig(config);
tb::LlmRequest llmReq(requestId, execReq);
};
lambdaErrMsgs.emplace_back(lambda, "must be the same as input token vector size");
}
// Extra ids not provided when enabling kv cache reuse with prompt table
{
auto lambda = [&inputTokens, maxNewTokens, requestId]()
{
texec::Request execReq(inputTokens, maxNewTokens);
auto embeddingTable = texec::Tensor::cpu(texec::DataType::kFP32, {4, 8});
texec::PromptTuningConfig config(embeddingTable);
execReq.setPromptTuningConfig(config);
tb::LlmRequest llmReq(requestId, execReq);
llmReq.validate(500, 1000, 1, 32000, std::nullopt, true);
};
lambdaErrMsgs.emplace_back(lambda, "Input token extra ids must be provided");
}
// Invalid endId
{
auto lambda = [&inputTokens, maxNewTokens, requestId]()
{
texec::Request execReq(
inputTokens, maxNewTokens, false, texec::SamplingConfig(), texec::OutputConfig(), -2);
tb::LlmRequest llmReq(requestId, execReq);
llmReq.validate(500, 1000, 1, 32000);
};
lambdaErrMsgs.emplace_back(lambda, "EndId (-2) is not within acceptable range [-1, 32000)");
}
for (auto& lambdaErrMsg : lambdaErrMsgs)
{
auto& lambda = lambdaErrMsg.first;
auto& errMsg = lambdaErrMsg.second;
try
{
lambda();
FAIL() << "Expected failure with " << errMsg;
}
catch (tc::TllmException const& e)
{
EXPECT_THAT(e.what(), testing::HasSubstr(errMsg));
}
catch (std::exception const& e)
{
FAIL() << "Expected TllmException with " << errMsg << " got " << e.what();
}
}
{
// Validate output len truncation w/o draft tokens
texec::Request execReq(inputTokens, maxNewTokens);
tb::LlmRequest llmReq(requestId, execReq);
llmReq.validate(10, 60, 0, 32000);
EXPECT_EQ(llmReq.mMaxNewTokens, 60 - inputTokens.size());
}
{
// Validate output len truncation w draft tokens
texec::Request execReq(inputTokens, maxNewTokens);
tb::LlmRequest llmReq(requestId, execReq);
llmReq.validate(10, 60, 2, 32000);
EXPECT_EQ(llmReq.mMaxNewTokens, 60 - inputTokens.size() - 2);
}
{
// Validate extra ids when enabling kv cache reuse with prompt table
texec::Request execReq(inputTokens, maxNewTokens);
auto embeddingTable = texec::Tensor::cpu(texec::DataType::kFP32, {6, 42});
VecTokenExtraIds extraIds(inputTokens.size(), 1);
texec::PromptTuningConfig config(embeddingTable, extraIds);
execReq.setPromptTuningConfig(config);
tb::LlmRequest llmReq(requestId, execReq);
EXPECT_EQ(static_cast<size_t>(llmReq.getOrigPromptLen()), inputTokens.size());
llmReq.validate(500, 1000, 1, 32000, std::nullopt, true);
}
{
using AdditionalModelOutput = texec::AdditionalModelOutput;
// Validate additional context and gen outputs
texec::Request execReq(inputTokens, maxNewTokens);
std::vector<AdditionalModelOutput> additionalModelOutputs{
AdditionalModelOutput{"context_gen_output", true}, AdditionalModelOutput{"gen_output", false}};
texec::OutputConfig outputConfig;
outputConfig.additionalModelOutputs = additionalModelOutputs;
execReq.setOutputConfig(outputConfig);
tb::LlmRequest llmReq(requestId, execReq);
llmReq.validate(10, 60, 2, 32000, std::nullopt, false);
auto const& additionalContextOutputs = llmReq.getAdditionalContextOutputs();
EXPECT_EQ(additionalContextOutputs.count("context_gen_output"), 1);
EXPECT_EQ(additionalContextOutputs.count("gen_output"), 0);
auto const& additionalGenerationOutputs = llmReq.getAdditionalGenerationOutputs();
EXPECT_EQ(additionalGenerationOutputs.count("context_gen_output"), 1);
EXPECT_EQ(additionalGenerationOutputs.count("gen_output"), 1);
}
}

TEST_F(LlmRequestTest, pause)
{
auto inputTokens = std::make_shared<VecTokens>(VecTokens{1, 2, 3, 4, 5});
SizeType32 maxNewTokens(66);
tb::LlmRequest::RequestIdType requestId{77};
tb::LlmRequest llmReq(requestId, maxNewTokens, inputTokens, tr::SamplingConfig(1), false);
llmReq.addNewToken(1, 0);
llmReq.addNewToken(1, 0);
llmReq.addNewToken(1, 0);
llmReq.addNewToken(1, 0);
llmReq.addNewToken(1, 0);
EXPECT_EQ(llmReq.getMaxNumGeneratedTokens(), 5);
    // maxInput is larger than num tokens
llmReq.pause(12);
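    // The 5 generated tokens are folded back into the prompt (5 + 5 = 10) and mMaxNewTokens shrinks by the tokens already produced (66 - 5 = 61).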
EXPECT_EQ(llmReq.mPromptLen, 10);
EXPECT_EQ(llmReq.mMaxNewTokens, 61);
EXPECT_EQ(llmReq.getState(), tb::LlmRequestState::kCONTEXT_INIT);
EXPECT_EQ(llmReq.getMaxNumGeneratedTokens(), 0);
llmReq.addNewToken(1, 0);
llmReq.addNewToken(1, 0);
llmReq.addNewToken(1, 0);
llmReq.addNewToken(1, 0);
EXPECT_EQ(llmReq.getMaxNumGeneratedTokens(), 4);
llmReq.pause(12);
    // maxInput is now smaller than num tokens
EXPECT_EQ(llmReq.mPromptLen, 12);
EXPECT_EQ(llmReq.mMaxNewTokens, 59);
EXPECT_EQ(llmReq.getState(), tb::LlmRequestState::kCONTEXT_INIT);
EXPECT_EQ(llmReq.getMaxNumGeneratedTokens(), 0);
}

TEST_F(LlmRequestTest, testAllocateLogitsBuffer)
{
auto inputTokens = std::make_shared<VecTokens>(VecTokens{1, 2, 3, 4, 5});
SizeType32 maxNewTokens(60);
tb::LlmRequest::RequestIdType requestId{77};
tb::LlmRequest llmReq(requestId, maxNewTokens, inputTokens, tr::SamplingConfig(1), false);
EXPECT_EQ(llmReq.mPromptLen, 5);
SizeType32 vocabSizePadded = 32000;
nvinfer1::DataType logitsDataType = nvinfer1::DataType::kFLOAT;
// Test the allocation of context logits
EXPECT_EQ(llmReq.getContextLogitsHost(), nullptr);
llmReq.allocContextLogitsHost(vocabSizePadded, logitsDataType);
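    // Context logits are allocated per prompt token: [mPromptLen, vocabSizePadded].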
auto contextLogitsHostShape = llmReq.getContextLogitsHost()->getShape();
EXPECT_EQ(contextLogitsHostShape.nbDims, 2);
EXPECT_EQ(contextLogitsHostShape.d[0], 5);
EXPECT_EQ(contextLogitsHostShape.d[1], vocabSizePadded);
// Test the allocation of generation logits
EXPECT_EQ(llmReq.getGenerationLogitsHost(), nullptr);
llmReq.allocGenerationLogitsHost(vocabSizePadded, logitsDataType);
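    // Generation logits cover every requested new token: [1, maxNewTokens, vocabSizePadded] here.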
auto generationLogitsHostShape = llmReq.getGenerationLogitsHost()->getShape();
EXPECT_EQ(generationLogitsHostShape.nbDims, 3);
EXPECT_EQ(generationLogitsHostShape.d[0], 1);
EXPECT_EQ(generationLogitsHostShape.d[1], maxNewTokens);
EXPECT_EQ(generationLogitsHostShape.d[2], vocabSizePadded);
// Test the allocation of target model's accepted token logits
// Set draft token
EXPECT_EQ(llmReq.getNumDraftTokens(), 0);
auto draftTokens = std::make_shared<VecTokens>(VecTokens{7, 8, 9});
llmReq.setDraftTokens(draftTokens);
EXPECT_EQ(llmReq.getNumDraftTokens(), 3);
// Clean the generation logits
llmReq.setGenerationLogitsHost(nullptr);
EXPECT_EQ(llmReq.getGenerationLogitsHost(), nullptr);
llmReq.allocTargetModelAcceptedTokenLogitsHost(vocabSizePadded, logitsDataType);
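    // With 3 draft tokens set, the accepted-token logits buffer holds numDraftTokens + 1 == 4 token slots.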
auto targetModelAcceptedTokenLogitShape = llmReq.getGenerationLogitsHost()->getShape();
EXPECT_EQ(targetModelAcceptedTokenLogitShape.nbDims, 3);
EXPECT_EQ(targetModelAcceptedTokenLogitShape.d[0], 1);
EXPECT_EQ(targetModelAcceptedTokenLogitShape.d[1], 4);
EXPECT_EQ(targetModelAcceptedTokenLogitShape.d[2], vocabSizePadded);
}

TEST_F(LlmRequestTest, testLastTokensSetIndependence)
{
tb::LlmRequest::RequestIdType requestId{77};
SizeType32 maxNewTokens(66);
auto inputTokens = std::make_shared<VecTokens>(VecTokens{1, 2, 3, 4, 5});
SizeType32 beamWidth = 3;
bool streaming = false;
tb::LlmRequest::BeamTokens expectedInitialOutput
= {{1, 2, 3, 4, 5, 10, 20}, {1, 2, 3, 4, 5, 11, 21}, {1, 2, 3, 4, 5, 12, 22}};
tb::LlmRequest::BeamTokens expectedOverwrittenOutput
= {{1, 2, 3, 4, 5, 100, 200}, {1, 2, 3, 4, 5, 101, 201}, {1, 2, 3, 4, 5, 102, 202}};
tb::LlmRequest llmReq(requestId, maxNewTokens, inputTokens, tr::SamplingConfig(beamWidth), streaming);
// check individually set tokens
llmReq.addNewToken(10, 0);
llmReq.addNewToken(11, 1);
llmReq.addNewToken(12, 2);
auto lastTokens = llmReq.getLastTokens();
EXPECT_EQ(lastTokens.size(), beamWidth);
EXPECT_THAT(lastTokens, testing::ElementsAreArray({10, 11, 12}));
// check tokens set all-at-once
VecTokens expectedLastTokens = VecTokens({20, 21, 22});
llmReq.addNewTokens(expectedLastTokens);
for (SizeType32 beam = 0; beam < beamWidth; beam++)
{
EXPECT_EQ(llmReq.getLastTokens(beam), expectedLastTokens[beam]);
}
// check mTokens when written by addNewToken
for (SizeType32 beam = 0; beam < beamWidth; beam++)
{
EXPECT_THAT(llmReq.getTokens(beam), testing::ElementsAreArray(expectedInitialOutput[beam]));
}
// check that setGeneratedTokens sets mTokens, but doesn't change lastTokens
tb::LlmRequest::BeamTokens overwriteTokens = {{100, 200}, {101, 201}, {102, 202}};
llmReq.setGeneratedTokens(overwriteTokens);
for (SizeType32 beam = 0; beam < beamWidth; beam++)
{
EXPECT_THAT(llmReq.getTokens(beam), testing::ElementsAreArray(expectedOverwrittenOutput[beam]));
}
EXPECT_THAT(llmReq.getLastTokens(), testing::ElementsAreArray({20, 21, 22}));
}

TEST_F(LlmRequestTest, testCreateRequests)
{
auto inputTokens = std::make_shared<VecTokens>(VecTokens{1, 2, 3, 4, 5});
SizeType32 maxNewTokens{60};
tb::LlmRequest::RequestIdType requestId{77};
SizeType32 vocabSize{32};
nvinfer1::DataType dtype{nvinfer1::DataType::kHALF};
tr::SamplingConfig samplingConfig(1);
samplingConfig.randomSeed = std::vector<texec::RandomSeedType>{7};
tb::LlmRequest llmReq(requestId, maxNewTokens, inputTokens, samplingConfig, false);
try
{
auto childReq = llmReq.createChildRequest(1837);
FAIL() << "Expected an exception.";
}
catch (tc::TllmException const& e)
{
EXPECT_THAT(e.what(), testing::HasSubstr("Cannot create child requests more than"));
}
samplingConfig.numReturnSequences = 3;
tb::LlmRequest llmReq2(requestId, maxNewTokens, inputTokens, samplingConfig, false);
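    // Each child request copies the sampling config with the random seed advanced (7 -> 8, 9), while the parent keeps seed 7.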
auto childReq1 = llmReq2.createChildRequest(78);
{
EXPECT_EQ(llmReq2.getChildRequests().size(), 1);
EXPECT_EQ(childReq1->mRequestId, 78);
EXPECT_EQ(childReq1->getTokens().at(0), *inputTokens);
EXPECT_EQ(childReq1->getNumTokens(0), llmReq.getNumTokens(0));
EXPECT_EQ(childReq1->getOrigPromptLen(), llmReq.getOrigPromptLen());
EXPECT_EQ(childReq1->mMaxNewTokens, llmReq.mMaxNewTokens);
EXPECT_EQ(childReq1->getState(), llmReq.getState());
EXPECT_EQ(childReq1->mSamplingConfig.randomSeed.value(), std::vector<texec::RandomSeedType>{8});
EXPECT_EQ(llmReq2.mSamplingConfig.randomSeed.value(), std::vector<texec::RandomSeedType>{7});
EXPECT_FALSE(childReq1->mSeqSlot);
}
{
auto childReq2 = llmReq2.createChildRequest(79);
auto childRequests = llmReq2.getChildRequests();
EXPECT_EQ(childRequests.size(), 2);
EXPECT_EQ(childRequests.at(0), childReq1);
EXPECT_EQ(childRequests.at(1), childReq2);
EXPECT_EQ(childReq2->mSamplingConfig.randomSeed.value(), std::vector<texec::RandomSeedType>{9});
EXPECT_EQ(childReq1->mSamplingConfig.randomSeed.value(), std::vector<texec::RandomSeedType>{8});
EXPECT_EQ(llmReq2.mSamplingConfig.randomSeed.value(), std::vector<texec::RandomSeedType>{7});
}
}

using ParamType = std::tuple<bool, bool, bool, SizeType32, SizeType32, SizeType32>;
std::string generateTestName(testing::TestParamInfo<ParamType> const& info)
{
auto const streaming = std::get<0>(info.param);
auto const excludeInputFromOutput = std::get<1>(info.param);
auto const returnAllGeneratedTokens = std::get<2>(info.param);
    auto const beamWidth = std::get<3>(info.param);
auto const tokensPerIteration = std::get<4>(info.param);
auto const numReturnSequences = std::get<5>(info.param);
std::string name = "llmRequestTest";
if (streaming)
{
name += "Streaming";
}
if (excludeInputFromOutput)
{
name += "ExclInput";
}
if (returnAllGeneratedTokens)
{
name += "RetAllTokens";
}
name += "Bw" + std::to_string(beamWdith);
name += "TokensPerIt" + std::to_string(tokensPerIteration);
name += "N" + std::to_string(numReturnSequences);
return name;
}

class ParamTest : public LlmRequestTest, public ::testing::WithParamInterface<ParamType>
{
};

TEST_P(ParamTest, createResponse)
{
bool const streaming{std::get<0>(GetParam())};
bool const excludeInputFromOutput{std::get<1>(GetParam())};
bool const returnAllGeneratedTokens{std::get<2>(GetParam())};
SizeType32 const beamWidth{std::get<3>(GetParam())};
SizeType32 const tokensPerIteration{std::get<4>(GetParam())};
SizeType32 const numReturnSequences{std::get<5>(GetParam())};
auto inputTokens = std::make_shared<VecTokens>(VecTokens{1, 2, 3, 4, 5});
SizeType32 maxNewTokens(66);
tb::LlmRequest::RequestIdType requestId{77};
tr::SamplingConfig samplingConfig(beamWidth);
    // Set numReturnSequences explicitly only in these cases; leave it as nullopt otherwise.
if (beamWidth == 1 || numReturnSequences < beamWidth)
{
samplingConfig.numReturnSequences = numReturnSequences;
}
auto numReturnBeams = samplingConfig.getNumReturnBeams();
// Expect one sequence per request in beam search.
auto numSequences = beamWidth > 1 ? 1 : numReturnSequences;
std::vector<std::shared_ptr<tb::LlmRequest>> llmRequests;
llmRequests.emplace_back(
std::make_shared<tb::LlmRequest>(requestId, maxNewTokens, inputTokens, samplingConfig, streaming));
{
auto llmReq = llmRequests.at(0);
llmReq->setExcludeInputFromOutput(excludeInputFromOutput);
if (streaming && beamWidth > 1 && !returnAllGeneratedTokens)
{
EXPECT_THROW(
llmReq->setReturnAllGeneratedTokens(returnAllGeneratedTokens), tensorrt_llm::common::TllmException);
return;
}
llmReq->setReturnAllGeneratedTokens(returnAllGeneratedTokens);
}
if (beamWidth == 1)
{
auto llmReq = llmRequests.at(0);
for (auto seqIdx = 1; seqIdx < numReturnSequences; seqIdx++)
{
tb::LlmRequest::RequestIdType childReqId{77 + static_cast<tb::LlmRequest::RequestIdType>(seqIdx)};
auto childReq = llmReq->createChildRequest(childReqId);
EXPECT_EQ(childReq->getReturnAllGeneratedTokens(), llmReq->getReturnAllGeneratedTokens());
EXPECT_TRUE(childReq->isChild());
llmRequests.emplace_back(std::move(childReq));
}
}
for (auto& llmReq : llmRequests)
{
auto response = llmReq->createResponse();
EXPECT_FALSE(response);
}
SizeType32 constexpr numIterations{5};
std::vector<texec::TokenIdType> newTokens(numSequences);
std::iota(newTokens.begin(), newTokens.end(), 1);
for (auto seqIdx = 0; seqIdx < numSequences; seqIdx++)
{
auto llmReq = llmRequests.at(seqIdx);
for (int i = 0; i < numIterations - 1; ++i)
{
for (int j = 0; j < tokensPerIteration; ++j)
{
llmReq->addNewTokens(VecTokens(numReturnBeams, newTokens.at(seqIdx)));
}
llmReq->setState(tb::LlmRequestState::kGENERATION_IN_PROGRESS);
auto response = llmReq->createResponse();
EXPECT_TRUE(streaming == response.has_value());
for (int beamIdx = 0; beamIdx < numReturnBeams; ++beamIdx)
{
if (streaming)
{
EXPECT_EQ(response.value().getRequestId(), requestId);
auto result = response.value().getResult();
EXPECT_EQ(result.outputTokenIds.size(), numReturnBeams);
auto const& beamTokens = result.outputTokenIds.at(beamIdx);
if (returnAllGeneratedTokens)
{
auto const expectedSize = (i + 1) * tokensPerIteration;
EXPECT_EQ(beamTokens.size(), expectedSize);
VecTokens expectedTokens(expectedSize, newTokens.at(seqIdx));
EXPECT_THAT(beamTokens, testing::ElementsAreArray(expectedTokens));
}
else
{
auto const expectedSize = tokensPerIteration;
EXPECT_EQ(beamTokens.size(), expectedSize);
VecTokens expectedTokens(expectedSize, newTokens.at(seqIdx));
EXPECT_THAT(beamTokens, testing::ElementsAreArray(expectedTokens));
}
}
}
response = llmReq->createResponse();
EXPECT_FALSE(response);
}
}
for (auto seqIdx = 0; seqIdx < numSequences; seqIdx++)
{
for (int j = 0; j < tokensPerIteration; ++j)
{
llmRequests.at(seqIdx)->addNewTokens(VecTokens(numReturnBeams, newTokens.at(seqIdx)));
}
}
llmRequests.at(0)->setState(tb::LlmRequestState::kGENERATION_COMPLETE);
auto const numNewTokens = numIterations * tokensPerIteration;
for (auto seqIdx = 0; seqIdx < numSequences; seqIdx++)
{
auto llmReq = llmRequests.at(seqIdx);
auto response = llmReq->createResponse();
if (!streaming && llmRequests.at(seqIdx)->getState() != tb::LlmRequestState::kGENERATION_COMPLETE)
{
EXPECT_FALSE(response);
continue;
}
EXPECT_TRUE(response) << "seqIdx " << seqIdx;
EXPECT_FALSE(response.value().hasError()) << "seqIdx " << seqIdx;
// All response should have the same request id of the original request.
EXPECT_EQ(response.value().getRequestId(), requestId);
auto result = response.value().getResult();
EXPECT_EQ(result.outputTokenIds.size(), numReturnBeams);
// Only the first sequence has finished.
EXPECT_EQ(result.isSequenceFinal, seqIdx == 0) << "seqIdx " << seqIdx;
EXPECT_EQ(result.isFinal, numSequences == 1) << "seqIdx " << seqIdx;
auto newToken = newTokens.at(seqIdx);
for (int beamIdx = 0; beamIdx < numReturnBeams; ++beamIdx)
{
auto const& beamTokens = result.outputTokenIds.at(beamIdx);
if (!streaming)
{
if (excludeInputFromOutput)
{
EXPECT_EQ(beamTokens.size(), numNewTokens);
VecTokens expectedTokens(numNewTokens, newToken);
EXPECT_THAT(beamTokens, testing::ElementsAreArray(expectedTokens));
}
else
{
auto const expectedSize = inputTokens->size() + numNewTokens;
EXPECT_EQ(beamTokens.size(), expectedSize);
VecTokens expectedTokens(*inputTokens);
expectedTokens.resize(expectedSize, newToken);
EXPECT_THAT(beamTokens, testing::ElementsAreArray(expectedTokens));
}
}
else
{
if (returnAllGeneratedTokens)
{
EXPECT_EQ(beamTokens.size(), numNewTokens);
VecTokens expectedTokens(numNewTokens, newToken);
EXPECT_THAT(beamTokens, testing::ElementsAreArray(expectedTokens));
}
else
{
auto const expectedSize = tokensPerIteration;
EXPECT_EQ(beamTokens.size(), expectedSize);
VecTokens expectedTokens(expectedSize, newToken);
EXPECT_THAT(beamTokens, testing::ElementsAreArray(expectedTokens));
}
}
}
}
if (numSequences > 1)
{
for (auto seqIdx = 1; seqIdx < numSequences; seqIdx++)
{
auto llmReq = llmRequests.at(seqIdx);
for (int j = 0; j < tokensPerIteration; ++j)
{
llmReq->addNewTokens(VecTokens(beamWidth, newTokens.at(seqIdx)));
}
llmReq->setState(tb::LlmRequestState::kGENERATION_COMPLETE);
}
for (auto seqIdx = 1; seqIdx < numSequences; seqIdx++)
{
auto response = llmRequests.at(seqIdx)->createResponse();
EXPECT_TRUE(response) << "seqIdx " << seqIdx;
EXPECT_FALSE(response.value().hasError()) << "seqIdx " << seqIdx;
auto result = response.value().getResult();
// All sequences have finished.
EXPECT_TRUE(result.isSequenceFinal) << "seqIdx " << seqIdx;
EXPECT_TRUE(result.isFinal) << "seqIdx " << seqIdx;
}
}
}

INSTANTIATE_TEST_SUITE_P(LlmRequestTest, ParamTest,
testing::Combine(
        // TODO: Support and add coverage for streaming
testing::Values(false),
// excludeInputFromOutput
testing::Values(false, true),
// returnAllGeneratedTokens
testing::Values(false, true),
// beamWidth
testing::Values(1, 2),
// tokensPerIteration
testing::Values(1, 3),
// numReturnSequences
testing::Values(1, 2)),
generateTestName);