/* * Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ #ifndef TOP_LEVEL_DIR #error "Define TOP_LEVEL_DIR" #endif #include #include "tensorrt_llm/common/memoryUtils.h" #include "tensorrt_llm/kernels/decodingCommon.h" #include "tensorrt_llm/kernels/stopCriteriaKernels.h" #include "tensorrt_llm/runtime/bufferManager.h" #include #include #include namespace tk = tensorrt_llm::kernels; namespace tc = tensorrt_llm::common; using namespace tensorrt_llm::runtime; namespace { class StopCriteriaKernelsTest : public testing::Test { public: using TensorPtr = tensorrt_llm::runtime::ITensor::SharedPtr; void SetUp() override { mStream = std::make_shared(); mBufferManager = std::make_shared(mStream); } void TearDown() override {} void initData(SizeType seed, const std::vector>>& stopWords, SizeType stopWordsLen, SizeType batchSize, SizeType beamWidth) { auto const maxBatchSize = 2 * batchSize; std::mt19937 generator(seed); std::uniform_int_distribution seqLenDistr(0, mMaxSeqLen); mSequenceLengths = mBufferManager->pinned(ITensor::makeShape({maxBatchSize, beamWidth}), nvinfer1::DataType::kINT32); mSequenceLengthLimits = mBufferManager->pinned(ITensor::makeShape({maxBatchSize}), nvinfer1::DataType::kINT32); mFinished = mBufferManager->pinned( ITensor::makeShape({maxBatchSize, beamWidth}), TRTDataType::value); mFinishedSum = mBufferManager->pinned(ITensor::makeShape({1}), nvinfer1::DataType::kINT32); mOutputIds = mBufferManager->pinned( ITensor::makeShape({maxBatchSize, beamWidth, mMaxSeqLen}), nvinfer1::DataType::kINT32); mOutputIdsPtr = mBufferManager->pinned(ITensor::makeShape({maxBatchSize, beamWidth}), nvinfer1::DataType::kINT64); mParentIds = mBufferManager->pinned( ITensor::makeShape({maxBatchSize, beamWidth, mMaxSeqLen}), nvinfer1::DataType::kINT32); mParentIdsPtr = mBufferManager->pinned(ITensor::makeShape({maxBatchSize, beamWidth}), nvinfer1::DataType::kINT64); mRefOutputIds = mBufferManager->pinned( ITensor::makeShape({maxBatchSize, beamWidth, mMaxSeqLen}), nvinfer1::DataType::kINT32); mStopWords = mBufferManager->pinned(ITensor::makeShape({maxBatchSize, 2, stopWordsLen}), nvinfer1::DataType::kINT32); mBatchSlots = mBufferManager->pinned(ITensor::makeShape({batchSize}), nvinfer1::DataType::kINT32); auto batchSlotsPtr = bufferCast(*mBatchSlots); for (SizeType bi = 0; bi < batchSize; ++bi) { batchSlotsPtr[bi] = 2 * bi; } auto sequenceLengthsPtr = bufferCast(*mSequenceLengths); auto sequenceLengthLimitsPtr = bufferCast(*mSequenceLengthLimits); auto finishedPtr = reinterpret_cast(bufferCast(*mFinished)); auto finishedSumPtr = bufferCast(*mFinishedSum); for (SizeType bi = 0; bi < maxBatchSize; ++bi) { for (SizeType ri = 0; ri < beamWidth; ri++) { sequenceLengthsPtr[bi * beamWidth + ri] = stopWordsLen == 0 ? seqLenDistr(generator) : mMaxSeqLen - (static_cast(bi / 2) + ri) % mMaxSeqLen; finishedPtr[bi * beamWidth + ri] = tk::FinishedState::empty(); } } for (SizeType bi = 0; bi < maxBatchSize; ++bi) { sequenceLengthLimitsPtr[bi] = stopWordsLen == 0 ? seqLenDistr(generator) : mMaxSeqLen - static_cast(bi / 2) % mMaxSeqLen; } finishedSumPtr[0] = 0; auto outputIdsPtrsData = reinterpret_cast(bufferCast(*mOutputIdsPtr)); auto parentIdsPtrsData = reinterpret_cast(bufferCast(*mParentIdsPtr)); auto outputIdsData = bufferCast(*mOutputIds); auto refOutputIdsData = bufferCast(*mRefOutputIds); auto parentIdsData = bufferCast(*mParentIds); // Tokens ids are // bi: 0, ri: 0: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15] // bi: 0, ri: 1: [16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30] // bi: 1, ri: 0: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15] // bi: 1, ri: 1: [16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30] // bi: 2, ri: 0: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14] // bi: 2, ri: 1: [16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29] // bi: 3, ri: 0: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14] // bi: 3, ri: 1: [16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29] // bi: 4, ri: 0: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13] // bi: 4, ri: 1: [16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28] // bi: 5, ri: 0: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13] // bi: 5, ri: 1: [16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28] // bi: 6, ri: 0: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12] // bi: 6, ri: 1: [16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27] for (SizeType bi = 0; bi < maxBatchSize; bi++) { for (SizeType ri = 0; ri < beamWidth; ri++) { for (SizeType si = 0; si < mMaxSeqLen; si++) { auto const idx = tc::flat_index3(bi, ri, si, beamWidth, mMaxSeqLen); outputIdsData[idx] = ri * mMaxSeqLen + si; parentIdsData[idx] = 0; } } } for (SizeType bi = 0; bi < maxBatchSize; bi++) { outputIdsPtrsData[bi] = outputIdsData + bi * beamWidth * mMaxSeqLen; parentIdsPtrsData[bi] = parentIdsData + bi * beamWidth * mMaxSeqLen; } // Init stop words tensor auto stopWordsData = bufferCast(*mStopWords); std::fill(stopWordsData, stopWordsData + maxBatchSize * 2 * stopWordsLen, -1); for (SizeType bi = 0; bi < stopWords.size(); bi++) { SizeType totalLen = 0; for (SizeType wi = 0; wi < stopWords[bi].size(); ++wi) { for (SizeType si = 0; si < stopWords[bi][wi].size(); ++si) { stopWordsData[bi * 2 * stopWordsLen + 0 * stopWordsLen + totalLen + si] = stopWords[bi][wi][si]; } totalLen += stopWords[bi][wi].size(); // Do not add value if stop words is empty if (totalLen > 0) { stopWordsData[bi * 2 * stopWordsLen + 1 * stopWordsLen + wi] = totalLen; } } // Special case when all stop words are of single token length if (stopWords[bi].size() == totalLen) { stopWordsData[bi * 2 * stopWordsLen + 1 * stopWordsLen + totalLen] = totalLen + 1; } } } void verifyMaxSeqLenStopCriteriaResults(SizeType seed, SizeType batchSize, SizeType beamWidth) { mStream->synchronize(); auto sequenceLengthsPtr = bufferCast(*mSequenceLengths); auto sequenceLengthLimitsPtr = bufferCast(*mSequenceLengthLimits); auto finishedPtr = reinterpret_cast(bufferCast(*mFinished)); auto finishedSumPtr = bufferCast(*mFinishedSum); auto batchSlotsPtr = bufferCast(*mBatchSlots); int32_t refSumFinished = 0; for (SizeType bi = 0; bi < batchSize * beamWidth; ++bi) { auto const batchIdx = bi / beamWidth; auto const beamIdx = bi % beamWidth; auto const batchSlot = batchSlotsPtr[batchIdx]; auto const batchBeamIdx = batchSlot * beamWidth + beamIdx; const auto limitExceeded = sequenceLengthsPtr[batchBeamIdx] >= sequenceLengthLimitsPtr[batchSlot]; refSumFinished += limitExceeded; if (limitExceeded) { EXPECT_TRUE(finishedPtr[batchBeamIdx].isFinishedMaxLength()) << " batchIdx: " << batchIdx << " beamIdx: " << beamIdx << " seed: " << seed; } } EXPECT_EQ(refSumFinished, finishedSumPtr[0]); } bool isSubsequence(const SizeType* sequence, SizeType n, const std::vector& subsequence) { auto it = std::search(sequence, sequence + n, subsequence.begin(), subsequence.end()); return it != sequence + n; } void verifyStopWordsStopCriteriaResults(SizeType seed, const std::vector>>& stopWords, SizeType stopWordsLen, SizeType batchSize, SizeType beamWidth) { mStream->synchronize(); auto outputIdsData = bufferCast(*mOutputIds); auto finishedPtr = reinterpret_cast(bufferCast(*mFinished)); auto sequenceLengthsPtr = bufferCast(*mSequenceLengths); auto batchSlotsPtr = bufferCast(*mBatchSlots); for (SizeType bi = 0; bi < batchSize; bi++) { auto const batchSlot = batchSlotsPtr[bi]; for (SizeType bwi = 0; bwi < beamWidth; bwi++) { auto outputIdsBatchBeam = outputIdsData + batchSlot * beamWidth * mMaxSeqLen + bwi * mMaxSeqLen; bool found = false; for (SizeType wi = 0; wi < stopWords[batchSlot].size(); ++wi) { auto const wordLen = stopWords[batchSlot][wi].size(); auto const seqLen = sequenceLengthsPtr[batchSlot * beamWidth + bwi]; auto const offset = seqLen - wordLen; found |= isSubsequence(outputIdsBatchBeam + offset, wordLen, stopWords[batchSlot][wi]); if (found) { EXPECT_TRUE(finishedPtr[batchSlot * beamWidth + bwi].isFinishedStopWords()); break; } } if (!found) { EXPECT_FALSE(finishedPtr[batchSlot * beamWidth + bwi].isFinished()); } } } } void runStopWordsCriteriaTest( const std::vector>>& stopWords, SizeType batchSize, SizeType beamWidth) { SizeType maxStopWordsLen = 0; for (const auto& batchStopWords : stopWords) { SizeType stopWordsLen = 0; for (const auto& words : batchStopWords) { stopWordsLen += words.size(); } if (stopWordsLen == batchStopWords.size()) { stopWordsLen += 1; } maxStopWordsLen = std::max(maxStopWordsLen, stopWordsLen); } initData(0, stopWords, maxStopWordsLen, batchSize, beamWidth); tk::invokeStopWordsCriterion(reinterpret_cast(bufferCast(*mOutputIdsPtr)), reinterpret_cast(bufferCast(*mParentIdsPtr)), bufferCast(*mStopWords), reinterpret_cast(bufferCast(*mFinished)), bufferCast(*mSequenceLengths), bufferCast(*mBatchSlots), maxStopWordsLen, batchSize, beamWidth, mMaxSeqLen, mStream->get()); verifyStopWordsStopCriteriaResults(0, stopWords, maxStopWordsLen, batchSize, beamWidth); } void runMaxLengthCriteriaTest(SizeType seed, SizeType batchSize, SizeType beamWidth) { initData(seed, {}, 0, batchSize, beamWidth); tk::invokeLengthCriterion( reinterpret_cast(bufferCast(*mFinished)), bufferCast(*mFinishedSum), reinterpret_cast(bufferCast(*mSequenceLengthLimits)), bufferCast(*mSequenceLengths), bufferCast(*mBatchSlots), batchSize, beamWidth, mStream->get()); verifyMaxSeqLenStopCriteriaResults(seed, batchSize, beamWidth); } protected: std::shared_ptr mBufferManager; std::shared_ptr mStream; TensorPtr mSequenceLengths; TensorPtr mSequenceLengthLimits; TensorPtr mFinished; TensorPtr mFinishedSum; TensorPtr mOutputIds; TensorPtr mRefOutputIds; TensorPtr mOutputIdsPtr; TensorPtr mParentIds; TensorPtr mParentIdsPtr; TensorPtr mStopWords; TensorPtr mBatchSlots; static constexpr SizeType mMaxSeqLen{16}; static constexpr SizeType mVocabSize{32}; }; TEST_F(StopCriteriaKernelsTest, maxLengthCriteriaBS1BW1Test) { constexpr SizeType seeds = 64; constexpr SizeType batchSize = 1; constexpr SizeType beamWidth = 1; for (SizeType seed = 0; seed < seeds; ++seed) { this->runMaxLengthCriteriaTest(seed, batchSize, beamWidth); } } TEST_F(StopCriteriaKernelsTest, maxLengthCriteriaBS1BW2Test) { constexpr SizeType seeds = 64; constexpr SizeType batchSize = 1; constexpr SizeType beamWidth = 2; for (SizeType seed = 0; seed < seeds; ++seed) { this->runMaxLengthCriteriaTest(seed, batchSize, beamWidth); } } TEST_F(StopCriteriaKernelsTest, maxLengthCriteriaBS1024BW1Test) { constexpr SizeType seeds = 64; constexpr SizeType batchSize = 1024; constexpr SizeType beamWidth = 1; for (SizeType seed = 0; seed < seeds; ++seed) { this->runMaxLengthCriteriaTest(seed, batchSize, beamWidth); } } TEST_F(StopCriteriaKernelsTest, maxLengthCriteriaBS1024BW2Test) { constexpr SizeType seeds = 64; constexpr SizeType batchSize = 1024; constexpr SizeType beamWidth = 2; for (SizeType seed = 0; seed < seeds; ++seed) { this->runMaxLengthCriteriaTest(seed, batchSize, beamWidth); } } TEST_F(StopCriteriaKernelsTest, stopWordsCriteriaBS1SingleTokenSingleWordTest) { constexpr SizeType batchSize = 1; constexpr SizeType beamWidth = 1; // Expected to not match any word this->runStopWordsCriteriaTest({{{2}}, {{2}}}, batchSize, beamWidth); } TEST_F(StopCriteriaKernelsTest, stopWordsCriteriaBS1SingleTokenMultipleWordsTest) { constexpr SizeType batchSize = 1; constexpr SizeType beamWidth = 1; // Expected to match 15 this->runStopWordsCriteriaTest({{{145}, {4}, {1}, {15}}, {{}}}, batchSize, beamWidth); } TEST_F(StopCriteriaKernelsTest, stopWordsCriteriaBS1MultipleTokensSingleWordTest) { constexpr SizeType batchSize = 1; constexpr SizeType beamWidth = 1; // Expected to not match any word this->runStopWordsCriteriaTest({{{2, 3}}, {{}}}, batchSize, beamWidth); } TEST_F(StopCriteriaKernelsTest, stopWordsCriteriaBS1MultipleTokensMultipleWordsMatchTest) { constexpr SizeType batchSize = 1; constexpr SizeType beamWidth = 1; // Expected to match {13, 14, 15} this->runStopWordsCriteriaTest({{{1, 4}, {2, 3}, {13, 14, 15}}, {{}}}, batchSize, beamWidth); } TEST_F(StopCriteriaKernelsTest, stopWordsCriteriaBS1MultipleTokensMultipleWordsNotMatchTest) { constexpr SizeType batchSize = 1; constexpr SizeType beamWidth = 1; // Expected to not match any word this->runStopWordsCriteriaTest({{{1, 4}, {2, 3}, {12, 14, 15}}, {{}}}, batchSize, beamWidth); } TEST_F(StopCriteriaKernelsTest, stopWordsCriteriaBS4MultipleTokensMultipleWordsTest) { constexpr SizeType batchSize = 4; constexpr SizeType beamWidth = 1; // Expected to match {12, 13} for the 5th instance in the batch this->runStopWordsCriteriaTest( {{{2}}, {{}}, {{}}, {{}}, {{15}, {12, 13}}, {{}}, {{1}, {8, 9}}, {{}}}, batchSize, beamWidth); } TEST_F(StopCriteriaKernelsTest, stopWordsCriteriaBS4BW2MultipleTokensMultipleWordsTest) { constexpr SizeType batchSize = 4; constexpr SizeType beamWidth = 2; // Expected to match {12, 13} to {bi, bw}={{5, 0}} // Expected to match {11, 12} to {bi, bw}={{7, 0}} // Expected to match {27} to {bi, bw}={{5, 1}} this->runStopWordsCriteriaTest( {{{2}}, {{}}, {{}}, {{}}, {{11}, {12, 13}}, {{}}, {{27}, {11, 12}}, {{}}}, batchSize, beamWidth); } } // end of namespace