/*
 * Copyright (c) 2022-2024, NVIDIA CORPORATION.  All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "tensorrt_llm/kernels/stopCriteriaKernels.h"
#include "tensorrt_llm/common/memoryUtils.h"
#include "tensorrt_llm/kernels/decodingCommon.h"
#include "tensorrt_llm/runtime/bufferManager.h"

#include <gtest/gtest.h>

#include <algorithm>
#include <cstdint>
#include <limits>
#include <memory>
#include <optional>
#include <random>
#include <vector>

namespace tk = tensorrt_llm::kernels;
namespace tc = tensorrt_llm::common;

using namespace tensorrt_llm::runtime;

namespace
{

// TODO(nkorobov): add tests for numNewTokens for EOS and seqLenLimit

//! \brief Fixture exercising the stop-criteria kernels: max sequence length, stop words and explicit EOS.
//!
//! All buffers are allocated with BufferManager::pinned and their raw pointers are handed to the kernels; results are
//! read back on the host after synchronizing the stream. Active requests live at the even batch slots
//! (slot = 2 * batchIdx) of buffers sized for 2 * batchSize requests, so the kernels are exercised with
//! non-contiguous batch slots.
class StopCriteriaKernelsTest : public testing::Test
{
public:
    using TensorPtr = tensorrt_llm::runtime::ITensor::SharedPtr;

    void SetUp() override
    {
        mStream = std::make_shared<tensorrt_llm::runtime::CudaStream>();
        mBufferManager = std::make_shared<tensorrt_llm::runtime::BufferManager>(mStream);
    }

    void TearDown() override {}

    //! \brief Length of one row of a request's stop-words tensor.
    //!
    //! This is the total token count of all words. One extra slot is added when every word is a single token,
    //! because initData then writes an additional entry at index totalLen of the offsets row.
    static SizeType32 getStopWordsLen(std::vector<std::vector<SizeType32>> const& words)
    {
        SizeType32 len = 0;
        for (auto const& word : words)
        {
            len += static_cast<SizeType32>(word.size());
        }
        if (len == static_cast<SizeType32>(words.size()))
        {
            len += 1;
        }
        return len;
    }

    //! \brief Allocates and fills every buffer used by the tests.
    //!
    //! \param seed RNG seed for sequence lengths/limits (max-length path), end ids and tokens-per-step.
    //! \param stopWords Stop words per batch slot (at most 2 * batchSize entries); may be empty.
    //! \param maxStopWordsLen Row length of the stop-words tensor. When 0, sequence lengths and limits are random;
    //!        otherwise they are deterministic (see the token table below).
    //! \param batchSize Number of active requests; request bi occupies batch slot 2 * bi.
    //! \param beamWidth Beam width.
    //! \param tokensPerStepVec Optional number of new tokens per active request (indexed by request, not slot).
    //!        Every slot first gets a random value in [1, mMaxTokensPerStep]; these values override the active ones.
    void initData(SizeType32 seed, std::vector<std::vector<std::vector<SizeType32>>> const& stopWords,
        SizeType32 maxStopWordsLen, SizeType32 batchSize, SizeType32 beamWidth,
        std::vector<SizeType32> tokensPerStepVec = {})
    {
        auto const maxBatchSize = 2 * batchSize;
        TLLM_CHECK(static_cast<SizeType32>(stopWords.size()) <= maxBatchSize);

        std::mt19937 generator(seed);
        std::uniform_int_distribution<SizeType32> seqLenDistr(1, mMaxSeqLen);
        // uniform_int_distribution bounds are inclusive. The end id must come from a valid position in
        // [0, mMaxSeqLen - 1]; an upper bound of mMaxSeqLen would index past the row, and past the whole buffer
        // for the last slot.
        std::uniform_int_distribution<SizeType32> endIdPosDistr(0, mMaxSeqLen - 1);
        std::uniform_int_distribution<SizeType32> tokensPerStepDistr(1, mMaxTokensPerStep);

        mSequenceLengths
            = BufferManager::pinned(ITensor::makeShape({maxBatchSize, beamWidth}), nvinfer1::DataType::kINT32);
        mSequenceLengthLimits = BufferManager::pinned(ITensor::makeShape({maxBatchSize}), nvinfer1::DataType::kINT32);
        mFinished = BufferManager::pinned(
            ITensor::makeShape({maxBatchSize, beamWidth}), TRTDataType<tk::FinishedState::UnderlyingType>::value);
        mFinishedSum = BufferManager::pinned(ITensor::makeShape({maxBatchSize}), nvinfer1::DataType::kINT32);

        mOutputIds = BufferManager::pinned(
            ITensor::makeShape({maxBatchSize, beamWidth, mMaxSeqLen}), nvinfer1::DataType::kINT32);
        mOutputIdsPtr = BufferManager::pinned(ITensor::makeShape({maxBatchSize, beamWidth}), nvinfer1::DataType::kINT64);

        mParentIds = BufferManager::pinned(
            ITensor::makeShape({maxBatchSize, beamWidth, mMaxSeqLen}), nvinfer1::DataType::kINT32);
        mParentIdsPtr = BufferManager::pinned(ITensor::makeShape({maxBatchSize, beamWidth}), nvinfer1::DataType::kINT64);

        mStopWords
            = BufferManager::pinned(ITensor::makeShape({maxBatchSize, 2, maxStopWordsLen}), nvinfer1::DataType::kINT32);
        mStopWordsPtr = BufferManager::pinned(ITensor::makeShape({maxBatchSize}), nvinfer1::DataType::kINT64);
        mStopWordsLen = BufferManager::pinned(ITensor::makeShape({maxBatchSize}), nvinfer1::DataType::kINT32);

        mBatchSlots = BufferManager::pinned(ITensor::makeShape({batchSize}), nvinfer1::DataType::kINT32);

        mEndIds = BufferManager::pinned(ITensor::makeShape({maxBatchSize}), nvinfer1::DataType::kINT32);

        mTokensPerStep = BufferManager::pinned(ITensor::makeShape({maxBatchSize}), nvinfer1::DataType::kINT32);

        auto batchSlotsPtr = bufferCast<SizeType32>(*mBatchSlots);
        for (SizeType32 bi = 0; bi < batchSize; ++bi)
        {
            batchSlotsPtr[bi] = 2 * bi;
        }

        auto sequenceLengthsPtr = bufferCast<SizeType32>(*mSequenceLengths);
        auto sequenceLengthLimitsPtr = bufferCast<SizeType32>(*mSequenceLengthLimits);
        auto finishedPtr
            = reinterpret_cast<tk::FinishedState*>(bufferCast<tk::FinishedState::UnderlyingType>(*mFinished));
        auto finishedSumPtr = bufferCast<SizeType32>(*mFinishedSum);

        // Stop-words tests use deterministic lengths so the expected matches can be worked out by hand.
        for (SizeType32 bi = 0; bi < maxBatchSize; ++bi)
        {
            for (SizeType32 ri = 0; ri < beamWidth; ri++)
            {
                sequenceLengthsPtr[bi * beamWidth + ri]
                    = maxStopWordsLen == 0 ? seqLenDistr(generator) : mMaxSeqLen - (bi / 2 + ri) % mMaxSeqLen;
                finishedPtr[bi * beamWidth + ri] = tk::FinishedState::empty();
            }
        }
        for (SizeType32 bi = 0; bi < maxBatchSize; ++bi)
        {
            sequenceLengthLimitsPtr[bi]
                = maxStopWordsLen == 0 ? seqLenDistr(generator) : mMaxSeqLen - (bi / 2) % mMaxSeqLen;
            finishedSumPtr[bi] = 0;
        }

        // Host copy of the lengths before any kernel modifies them.
        mInitSequenceLengths = mBufferManager->copyFrom(*mSequenceLengths, MemoryType::kCPU);

        auto outputIdsPtrsData = reinterpret_cast<SizeType32**>(bufferCast<std::int64_t>(*mOutputIdsPtr));
        auto parentIdsPtrsData = reinterpret_cast<SizeType32**>(bufferCast<std::int64_t>(*mParentIdsPtr));
        auto outputIdsData = bufferCast<SizeType32>(*mOutputIds);
        auto parentIdsData = bufferCast<SizeType32>(*mParentIds);

        // Token id at (slot, beam ri, position si) is ri * mMaxSeqLen + si, independent of the slot.
        // With stop words and beamWidth == 2, the valid prefix (first sequenceLength tokens) of each slot is:
        // bi: 0, ri: 0: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]
        // bi: 0, ri: 1: [16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30]
        // bi: 1, ri: 0: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]
        // bi: 1, ri: 1: [16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30]
        // bi: 2, ri: 0: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14]
        // bi: 2, ri: 1: [16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29]
        // bi: 3, ri: 0: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14]
        // bi: 3, ri: 1: [16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29]
        // bi: 4, ri: 0: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13]
        // bi: 4, ri: 1: [16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28]
        // bi: 5, ri: 0: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13]
        // bi: 5, ri: 1: [16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28]
        // bi: 6, ri: 0: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]
        // bi: 6, ri: 1: [16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27]
        for (SizeType32 bi = 0; bi < maxBatchSize; bi++)
        {
            for (SizeType32 ri = 0; ri < beamWidth; ri++)
            {
                for (SizeType32 si = 0; si < mMaxSeqLen; si++)
                {
                    auto const idx = tc::flat_index3(bi, ri, si, beamWidth, mMaxSeqLen);
                    outputIdsData[idx] = ri * mMaxSeqLen + si;
                    parentIdsData[idx] = 0;
                }
            }
        }

        // One pointer per slot to that slot's [beamWidth, mMaxSeqLen] block.
        for (SizeType32 bi = 0; bi < maxBatchSize; bi++)
        {
            outputIdsPtrsData[bi] = outputIdsData + bi * beamWidth * mMaxSeqLen;
            parentIdsPtrsData[bi] = parentIdsData + bi * beamWidth * mMaxSeqLen;
        }

        auto endIds = BufferRange<SizeType32>(*mEndIds);
        for (SizeType32 bi = 0; bi < maxBatchSize; bi++)
        {
            auto const endIdPos = endIdPosDistr(generator);
            auto const idx = tc::flat_index3(bi, /* ri */ 0, endIdPos, beamWidth, mMaxSeqLen);
            endIds[bi] = outputIdsData[idx];
        }

        auto tokensPerStep = BufferRange<SizeType32>(*mTokensPerStep);
        for (SizeType32 bi = 0; bi < maxBatchSize; bi++)
        {
            tokensPerStep[bi] = tokensPerStepDistr(generator);
        }
        if (!tokensPerStepVec.empty())
        {
            TLLM_CHECK(static_cast<SizeType32>(tokensPerStepVec.size()) == batchSize);
            for (SizeType32 bi = 0; bi < batchSize; bi++)
            {
                tokensPerStep[batchSlotsPtr[bi]] = tokensPerStepVec[bi];
            }
        }
        // Host copy of the token counts before the stop-words kernel may modify them.
        mInitTokensPerStep = mBufferManager->copyFrom(*mTokensPerStep, MemoryType::kCPU);

        // Stop-words tensor: per slot, row 0 holds the concatenated tokens and row 1 holds the cumulative end offset
        // of each word. Unused entries are -1.
        auto stopWordsData = bufferCast<SizeType32>(*mStopWords);
        std::fill(stopWordsData, stopWordsData + maxBatchSize * 2 * maxStopWordsLen, -1);
        auto stopWordsPtrs = reinterpret_cast<SizeType32**>(bufferCast<std::int64_t>(*mStopWordsPtr));
        auto stopWordsLensPtr = bufferCast<SizeType32>(*mStopWordsLen);
        // Give every slot a well-defined (empty) entry, even slots the caller provides no stop words for.
        for (SizeType32 bi = 0; bi < maxBatchSize; bi++)
        {
            stopWordsPtrs[bi] = stopWordsData + bi * 2 * maxStopWordsLen;
            stopWordsLensPtr[bi] = 0;
        }
        for (SizeType32 bi = 0; bi < static_cast<SizeType32>(stopWords.size()); bi++)
        {
            auto* tokensRow = stopWordsData + bi * 2 * maxStopWordsLen;
            auto* offsetsRow = tokensRow + maxStopWordsLen;
            SizeType32 totalLen = 0;
            for (size_t wi = 0; wi < stopWords[bi].size(); ++wi)
            {
                auto const& word = stopWords[bi][wi];
                std::copy(word.begin(), word.end(), tokensRow + totalLen);
                totalLen += static_cast<SizeType32>(word.size());
                // Do not add an offset while only empty stop words have been seen.
                if (totalLen > 0)
                {
                    offsetsRow[wi] = totalLen;
                }
            }
            // Special case when all stop words are of single token length.
            if (static_cast<SizeType32>(stopWords[bi].size()) == totalLen)
            {
                offsetsRow[totalLen] = totalLen + 1;
            }
            stopWordsLensPtr[bi] = getStopWordsLen(stopWords[bi]);
        }
    }

    //! \brief Checks the results of invokeLengthCriterion for every active request.
    //!
    //! Beams whose length reached the limit must be finished with max-length; all others must stay unfinished.
    //! The finished-beam count per slot must match.
    void verifyMaxSeqLenStopCriteriaResults(SizeType32 seed, SizeType32 batchSize, SizeType32 beamWidth)
    {
        mStream->synchronize();

        auto sequenceLengthsPtr = bufferCast<SizeType32>(*mSequenceLengths);
        auto sequenceLengthLimitsPtr = bufferCast<SizeType32>(*mSequenceLengthLimits);
        auto finishedPtr
            = reinterpret_cast<tk::FinishedState*>(bufferCast<tk::FinishedState::UnderlyingType>(*mFinished));
        auto finishedSumPtr = bufferCast<SizeType32>(*mFinishedSum);
        auto batchSlotsPtr = bufferCast<SizeType32>(*mBatchSlots);

        for (SizeType32 batchIdx = 0; batchIdx < batchSize; ++batchIdx)
        {
            SizeType32 refSumFinished = 0;
            auto const batchSlot = batchSlotsPtr[batchIdx];
            for (SizeType32 beamIdx = 0; beamIdx < beamWidth; ++beamIdx)
            {
                auto const batchBeamIdx = batchSlot * beamWidth + beamIdx;
                auto const limitExceeded = sequenceLengthsPtr[batchBeamIdx] >= sequenceLengthLimitsPtr[batchSlot];
                refSumFinished += limitExceeded ? 1 : 0;
                if (limitExceeded)
                {
                    EXPECT_TRUE(finishedPtr[batchBeamIdx].isFinishedMaxLength())
                        << " batchIdx: " << batchIdx << " beamIdx: " << beamIdx << " seed: " << seed;
                }
                else
                {
                    // Nothing else can finish a beam here: all states start empty.
                    EXPECT_FALSE(finishedPtr[batchBeamIdx].isFinished())
                        << " batchIdx: " << batchIdx << " beamIdx: " << beamIdx << " seed: " << seed;
                }
            }
            EXPECT_EQ(refSumFinished, finishedSumPtr[batchSlot]) << " batchIdx: " << batchIdx << " seed: " << seed;
        }
    }

    //! \brief Returns the index of the first occurrence of \p subsequence in sequence[0, n), or nullopt.
    static std::optional<SizeType32> isSubsequence(
        SizeType32 const* sequence, SizeType32 n, std::vector<SizeType32> const& subsequence)
    {
        auto const* end = sequence + n;
        auto const it = std::search(sequence, end, subsequence.begin(), subsequence.end());
        if (it == end)
        {
            return std::nullopt;
        }
        return static_cast<SizeType32>(it - sequence);
    }

    //! \brief Checks the results of invokeStopWordsCriterion for every active request and beam.
    //!
    //! A beam must be finished with stop-words iff some non-empty stop word ends within its newest tokens.
    //! There is one newest token per step, or initTokensPerStep[slot] when \p multipleTokensPerStep is set.
    //! With multiple tokens per step, the earliest-ending match must also trim the sequence length and the
    //! number of accepted tokens. \p stopWordsLen is unused.
    void verifyStopWordsStopCriteriaResults(SizeType32 seed,
        std::vector<std::vector<std::vector<SizeType32>>> const& stopWords, [[maybe_unused]] SizeType32 stopWordsLen,
        SizeType32 batchSize, SizeType32 beamWidth, bool multipleTokensPerStep)
    {
        mStream->synchronize();

        auto outputIdsData = bufferCast<SizeType32>(*mOutputIds);
        auto finishedPtr
            = reinterpret_cast<tk::FinishedState*>(bufferCast<tk::FinishedState::UnderlyingType>(*mFinished));
        auto initSequenceLengths = BufferRange<SizeType32>(*mInitSequenceLengths);
        auto sequenceLengths = BufferRange<SizeType32>(*mSequenceLengths);
        auto batchSlots = BufferRange<SizeType32>(*mBatchSlots);
        auto initTokensPerStep = BufferRange<SizeType32>(*mInitTokensPerStep);
        auto tokensPerStep = BufferRange<SizeType32>(*mTokensPerStep);

        for (SizeType32 bi = 0; bi < batchSize; bi++)
        {
            auto const batchSlot = batchSlots[bi];
            auto const numTokens = multipleTokensPerStep ? initTokensPerStep[batchSlot] : 1;
            for (SizeType32 bwi = 0; bwi < beamWidth; bwi++)
            {
                auto const batchBeamIdx = batchSlot * beamWidth + bwi;
                auto const* outputIdsBatchBeam = outputIdsData + batchBeamIdx * mMaxSeqLen;
                auto const seqLen = initSequenceLengths[batchBeamIdx];

                // Earliest-ending match for THIS beam; must not carry over from other slots or beams.
                auto minStopSeqLen = std::numeric_limits<SizeType32>::max();
                auto minMatchIdx = std::numeric_limits<SizeType32>::max();
                bool found = false;
                for (auto const& word : stopWords[batchSlot])
                {
                    auto const wordLen = static_cast<SizeType32>(word.size());
                    if (wordLen == 0)
                    {
                        continue;
                    }
                    // The search window covers every placement of the word that ends inside the numTokens newest
                    // tokens. A match at index i keeps exactly i + 1 of the new tokens.
                    auto const offset = seqLen - wordLen - (numTokens - 1);
                    auto const matchIdx = isSubsequence(outputIdsBatchBeam + offset, wordLen + (numTokens - 1), word);
                    if (!matchIdx.has_value())
                    {
                        continue;
                    }
                    found = true;
                    auto const stopSeqLen = matchIdx.value() + offset + wordLen;
                    if (stopSeqLen < minStopSeqLen)
                    {
                        minStopSeqLen = stopSeqLen;
                        minMatchIdx = matchIdx.value();
                    }
                }

                if (found)
                {
                    EXPECT_TRUE(finishedPtr[batchBeamIdx].isFinishedStopWords())
                        << " batchIdx: " << bi << " beamIdx: " << bwi << " seed: " << seed;
                    if (multipleTokensPerStep)
                    {
                        EXPECT_EQ(sequenceLengths[batchBeamIdx], minStopSeqLen)
                            << " batchIdx: " << bi << " beamIdx: " << bwi << " seed: " << seed;
                        EXPECT_EQ(tokensPerStep[batchSlot], minMatchIdx + 1)
                            << " batchIdx: " << bi << " beamIdx: " << bwi << " seed: " << seed;
                    }
                }
                else
                {
                    EXPECT_FALSE(finishedPtr[batchBeamIdx].isFinished())
                        << " batchIdx: " << bi << " beamIdx: " << bwi << " seed: " << seed;
                }
            }
        }
    }

    //! \brief Checks the results of invokeExplicitEOSCriterion (beamWidth 1) for every active request.
    //!
    //! The first end id among the tokensPerStep[slot] newest tokens must set the sequence length to its position
    //! and mark the request finished with EOS. If there is none, the request must be left untouched.
    void verifyExplicitEOSCriteriaResults(SizeType32 seed, SizeType32 batchSize)
    {
        mStream->synchronize();

        SizeType32 constexpr beamWidth = 1;
        auto outputIdsData = BufferRange<TokenIdType>(*mOutputIds);
        auto finishedPtr
            = reinterpret_cast<tk::FinishedState*>(bufferCast<tk::FinishedState::UnderlyingType>(*mFinished));
        auto sequenceLengths = BufferRange<SizeType32>(*mSequenceLengths);
        auto initSequenceLengths = BufferRange<SizeType32>(*mInitSequenceLengths);
        auto batchSlots = BufferRange<SizeType32>(*mBatchSlots);
        auto endIds = BufferRange<TokenIdType>(*mEndIds);
        auto tokensPerStep = BufferRange<SizeType32>(*mTokensPerStep);

        for (SizeType32 bi = 0; bi < batchSize; bi++)
        {
            auto const batchSlot = batchSlots[bi];
            auto const seqLen = sequenceLengths[batchSlot];
            auto const initSeqLen = initSequenceLengths[batchSlot];
            auto const endId = endIds[batchSlot];
            auto const numTokens = tokensPerStep[batchSlot];

            bool found = false;
            for (SizeType32 ti = 0; ti < numTokens && !found; ++ti)
            {
                auto const pos = std::max(0, initSeqLen - numTokens + ti);
                // Output ids are laid out per batch slot, not per request index.
                auto const idx = tc::flat_index3(batchSlot, /* ri */ 0, /* si */ pos, beamWidth, mMaxSeqLen);
                if (outputIdsData[idx] != endId)
                {
                    continue;
                }
                found = true;
                EXPECT_EQ(seqLen, pos) << " batchIdx: " << bi << " seed: " << seed;
                // Only index with seqLen when it is a valid position, to avoid reading outside the slot's row.
                if (seqLen >= 0 && seqLen < mMaxSeqLen)
                {
                    auto const eosIdx = tc::flat_index3(batchSlot, /* ri */ 0, /* si */ seqLen, beamWidth, mMaxSeqLen);
                    EXPECT_EQ(outputIdsData[eosIdx], endId) << " batchIdx: " << bi << " seed: " << seed;
                }
                EXPECT_TRUE(finishedPtr[batchSlot].isFinishedEOS()) << " batchIdx: " << bi << " seed: " << seed;
            }
            if (!found)
            {
                EXPECT_EQ(seqLen, initSeqLen) << " batchIdx: " << bi << " seed: " << seed;
                EXPECT_FALSE(finishedPtr[batchSlot].isFinished()) << " batchIdx: " << bi << " seed: " << seed;
            }
        }
    }

    //! \brief Runs invokeStopWordsCriterion on deterministic data and verifies it.
    //!
    //! \param stopWords Stop words per batch slot.
    //! \param tokensPerStep When non-empty, per-request new-token counts are passed to the kernel (multi-token mode).
    void runStopWordsCriteriaTest(std::vector<std::vector<std::vector<SizeType32>>> const& stopWords,
        SizeType32 batchSize, SizeType32 beamWidth, std::vector<SizeType32> tokensPerStep = {})
    {
        SizeType32 maxStopWordsLen = 0;
        for (auto const& batchStopWords : stopWords)
        {
            maxStopWordsLen = std::max(maxStopWordsLen, getStopWordsLen(batchStopWords));
        }

        initData(/* seed */ 0, stopWords, maxStopWordsLen, batchSize, beamWidth, tokensPerStep);

        bool const multipleTokensPerStep = !tokensPerStep.empty();
        auto* numNewTokens = multipleTokensPerStep ? bufferCast<SizeType32>(*mTokensPerStep) : nullptr;

        tk::invokeStopWordsCriterion(bufferCast<TokenIdType const*>(*mOutputIdsPtr),
            bufferCast<SizeType32 const*>(*mParentIdsPtr), bufferCast<TokenIdType const*>(*mStopWordsPtr),
            reinterpret_cast<tk::FinishedState*>(bufferCast<tk::FinishedState::UnderlyingType>(*mFinished)),
            bufferCast<SizeType32>(*mSequenceLengths), bufferCast<SizeType32>(*mBatchSlots),
            bufferCast<SizeType32>(*mStopWordsLen), numNewTokens, maxStopWordsLen, batchSize, beamWidth, mMaxSeqLen,
            mStream->get());

        verifyStopWordsStopCriteriaResults(
            /* seed */ 0, stopWords, maxStopWordsLen, batchSize, beamWidth, multipleTokensPerStep);
    }

    //! \brief Runs invokeLengthCriterion on random lengths/limits and verifies it.
    void runMaxLengthCriteriaTest(SizeType32 seed, SizeType32 batchSize, SizeType32 beamWidth)
    {
        initData(seed, {}, 0, batchSize, beamWidth);

        tk::invokeLengthCriterion(
            reinterpret_cast<tk::FinishedState*>(bufferCast<tk::FinishedState::UnderlyingType>(*mFinished)),
            bufferCast<SizeType32>(*mFinishedSum),
            reinterpret_cast<std::uint32_t const*>(bufferCast<SizeType32>(*mSequenceLengthLimits)),
            bufferCast<SizeType32>(*mSequenceLengths), /* numNewTokens */ nullptr, bufferCast<SizeType32>(*mBatchSlots),
            batchSize, beamWidth, mStream->get());

        verifyMaxSeqLenStopCriteriaResults(seed, batchSize, beamWidth);
    }

    //! \brief Runs invokeExplicitEOSCriterion (beamWidth 1) on random data and verifies it.
    void runExplicitEOSCriteriaTest(SizeType32 seed, SizeType32 batchSize)
    {
        initData(seed, {}, 0, batchSize, /* beamWidth */ 1);

        tk::invokeExplicitEOSCriterion(reinterpret_cast<TokenIdType const**>(bufferCast<std::int64_t>(*mOutputIdsPtr)),
            bufferCast<TokenIdType>(*mEndIds),
            reinterpret_cast<tk::FinishedState*>(bufferCast<tk::FinishedState::UnderlyingType>(*mFinished)),
            bufferCast<SizeType32>(*mSequenceLengths), bufferCast<SizeType32>(*mTokensPerStep),
            bufferCast<SizeType32>(*mBatchSlots), batchSize, /* beamWidth */ 1, mMaxTokensPerStep, mStream->get());

        verifyExplicitEOSCriteriaResults(seed, batchSize);
    }

protected:
    std::shared_ptr<tensorrt_llm::runtime::BufferManager> mBufferManager;
    std::shared_ptr<tensorrt_llm::runtime::CudaStream> mStream;

    TensorPtr mSequenceLengths;
    TensorPtr mInitSequenceLengths;
    TensorPtr mSequenceLengthLimits;
    TensorPtr mFinished;
    TensorPtr mFinishedSum;

    TensorPtr mOutputIds;
    TensorPtr mOutputIdsPtr;
    TensorPtr mParentIds;
    TensorPtr mParentIdsPtr;
    TensorPtr mStopWords;
    TensorPtr mStopWordsPtr;
    TensorPtr mStopWordsLen;
    TensorPtr mBatchSlots;
    TensorPtr mEndIds;
    TensorPtr mTokensPerStep;
    TensorPtr mInitTokensPerStep;

    static SizeType32 constexpr mMaxSeqLen{16};
    static SizeType32 constexpr mMaxTokensPerStep{4};
};

TEST_F(StopCriteriaKernelsTest, maxLengthCriteriaBS1BW1Test)
{
    SizeType32 constexpr seeds = 64;
    SizeType32 constexpr batchSize = 1;
    SizeType32 constexpr beamWidth = 1;
    for (SizeType32 seed = 0; seed < seeds; ++seed)
    {
        this->runMaxLengthCriteriaTest(seed, batchSize, beamWidth);
    }
}

TEST_F(StopCriteriaKernelsTest, maxLengthCriteriaBS1BW2Test)
{
    SizeType32 constexpr seeds = 64;
    SizeType32 constexpr batchSize = 1;
    SizeType32 constexpr beamWidth = 2;
    for (SizeType32 seed = 0; seed < seeds; ++seed)
    {
        this->runMaxLengthCriteriaTest(seed, batchSize, beamWidth);
    }
}

TEST_F(StopCriteriaKernelsTest, maxLengthCriteriaBS1024BW1Test)
{
    SizeType32 constexpr seeds = 64;
    SizeType32 constexpr batchSize = 1024;
    SizeType32 constexpr beamWidth = 1;
    for (SizeType32 seed = 0; seed < seeds; ++seed)
    {
        this->runMaxLengthCriteriaTest(seed, batchSize, beamWidth);
    }
}

TEST_F(StopCriteriaKernelsTest, maxLengthCriteriaBS1024BW2Test)
{
    SizeType32 constexpr seeds = 64;
    SizeType32 constexpr batchSize = 1024;
    SizeType32 constexpr beamWidth = 2;
    for (SizeType32 seed = 0; seed < seeds; ++seed)
    {
        this->runMaxLengthCriteriaTest(seed, batchSize, beamWidth);
    }
}

TEST_F(StopCriteriaKernelsTest, stopWordsCriteriaBS1SingleTokenSingleWordTest)
{
    SizeType32 constexpr batchSize = 1;
    SizeType32 constexpr beamWidth = 1;
    // Expected to not match any word
    this->runStopWordsCriteriaTest({{{2}}, {{2}}}, batchSize, beamWidth);
}

TEST_F(StopCriteriaKernelsTest, stopWordsCriteriaBS1SingleTokenMultipleWordsTest)
{
    SizeType32 constexpr batchSize = 1;
    SizeType32 constexpr beamWidth = 1;
    // Expected to match 15
    this->runStopWordsCriteriaTest({{{145}, {4}, {1}, {15}}, {{}}}, batchSize, beamWidth);
}

TEST_F(StopCriteriaKernelsTest, stopWordsCriteriaBS1MultipleTokensSingleWordTest)
{
    SizeType32 constexpr batchSize = 1;
    SizeType32 constexpr beamWidth = 1;
    // Expected to not match any word
    this->runStopWordsCriteriaTest({{{2, 3}}, {{}}}, batchSize, beamWidth);
}

TEST_F(StopCriteriaKernelsTest, stopWordsCriteriaBS1MultipleTokensMultipleWordsMatchTest)
{
    SizeType32 constexpr batchSize = 1;
    SizeType32 constexpr beamWidth = 1;
    // Expected to match {13, 14, 15}
    this->runStopWordsCriteriaTest({{{1, 4}, {2, 3}, {13, 14, 15}}, {{}}}, batchSize, beamWidth);
}

TEST_F(StopCriteriaKernelsTest, stopWordsCriteriaBS1MultipleTokensMultipleWordsNotMatchTest)
{
    SizeType32 constexpr batchSize = 1;
    SizeType32 constexpr beamWidth = 1;
    // Expected to not match any word
    this->runStopWordsCriteriaTest({{{1, 4}, {2, 3}, {12, 14, 15}}, {{}}}, batchSize, beamWidth);
}

TEST_F(StopCriteriaKernelsTest, stopWordsCriteriaBS4MultipleTokensMultipleWordsTest)
{
    SizeType32 constexpr batchSize = 4;
    SizeType32 constexpr beamWidth = 1;
    // Expected to match {12, 13} at batch slot 4 only
    this->runStopWordsCriteriaTest(
        {{{2}}, {{}}, {{}}, {{}}, {{15}, {12, 13}}, {{}}, {{1}, {8, 9}}, {{}}}, batchSize, beamWidth);
}

TEST_F(StopCriteriaKernelsTest, stopWordsCriteriaBS4BW2MultipleTokensMultipleWordsTest)
{
    SizeType32 constexpr batchSize = 4;
    SizeType32 constexpr beamWidth = 2;
    // Expected to match {12, 13} at {slot, beam} = {4, 0}
    // Expected to match {11, 12} at {slot, beam} = {6, 0}
    // Expected to match {27} at {slot, beam} = {6, 1}
    this->runStopWordsCriteriaTest(
        {{{2}}, {{}}, {{}}, {{}}, {{11}, {12, 13}}, {{}}, {{27}, {11, 12}}, {{}}}, batchSize, beamWidth);
}

TEST_F(StopCriteriaKernelsTest, stopWordsCriteriaBS1SingleTokenSingleWordNoMatchTestMultipleTokensPerStep)
{
    SizeType32 constexpr batchSize = 1;
    SizeType32 constexpr beamWidth = 1;
    // Expected to not match any word
    this->runStopWordsCriteriaTest({{{13}}, {{13}}}, batchSize, beamWidth, {2});
}

TEST_F(StopCriteriaKernelsTest, stopWordsCriteriaBS1SingleTokenSingleWordMatchTestMultipleTokensPerStep)
{
    SizeType32 constexpr batchSize = 1;
    SizeType32 constexpr beamWidth = 1;
    // Expected to match the word {13}
    this->runStopWordsCriteriaTest({{{13}}, {{13}}}, batchSize, beamWidth, {3});
}

TEST_F(StopCriteriaKernelsTest, stopWordsCriteriaBS1SingleTokenMultipleWordsSingleMatchTestMultipleTokensPerStep)
{
    SizeType32 constexpr batchSize = 1;
    SizeType32 constexpr beamWidth = 1;
    // Expected to match 15
    this->runStopWordsCriteriaTest({{{145}, {4}, {1}, {15}}, {{}}}, batchSize, beamWidth, {2});
}

TEST_F(StopCriteriaKernelsTest, stopWordsCriteriaBS1SingleTokenMultipleWordsMultipleMatchTestMultipleTokensPerStep)
{
    SizeType32 constexpr batchSize = 1;
    SizeType32 constexpr beamWidth = 1;
    // Expected to match 15 and 14 and stop on 14
    this->runStopWordsCriteriaTest({{{145}, {4}, {1}, {15}, {14}}, {{}}}, batchSize, beamWidth, {2});
}

TEST_F(StopCriteriaKernelsTest, stopWordsCriteriaBS1MultipleTokensMultipleWordsMatchTestMultipleTokensPerStep)
{
    SizeType32 constexpr batchSize = 1;
    SizeType32 constexpr beamWidth = 1;
    // Expected to match {13, 14, 15} and {11, 12} and stop on {11, 12}
    this->runStopWordsCriteriaTest({{{1, 4}, {2, 3}, {13, 14, 15}, {11, 12}}, {{}}}, batchSize, beamWidth, {5});
}

TEST_F(StopCriteriaKernelsTest, stopWordsCriteriaBS4MultipleTokensMultipleWordsTestMultipleTokensPerStep)
{
    SizeType32 constexpr batchSize = 4;
    SizeType32 constexpr beamWidth = 1;
    // Expected to match {12, 13} at batch slot 4 and {8, 9} at batch slot 6
    this->runStopWordsCriteriaTest({{{2}}, {{}}, {{}}, {{}}, {{15}, {12, 13}}, {{}}, {{1}, {8, 9}}, {{}}}, batchSize,
        beamWidth, {2, 5, 3, 4});
}

TEST_F(StopCriteriaKernelsTest, explicitEOSCriteria)
{
    SizeType32 constexpr seeds = 64;
    SizeType32 constexpr batchSize = 1024;
    for (SizeType32 seed = 0; seed < seeds; ++seed)
    {
        this->runExplicitEOSCriteriaTest(seed, batchSize);
    }
}

} // end of namespace