/*
 * Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "tests/unit_tests/layers/explicitDraftTokensLayerTest.h"

#include "tensorrt_llm/common/memoryUtils.h"
#include "tensorrt_llm/kernels/decodingCommon.h"
#include "tensorrt_llm/kernels/speculativeDecoding/explicitDraftTokensKernels.h"
#include "tensorrt_llm/runtime/iBuffer.h"
#include "tensorrt_llm/runtime/iTensor.h"
#include "tensorrt_llm/runtime/runtimeKernels.h"
#include "tensorrt_llm/runtime/speculativeDecodingModule.h"
#include "tensorrt_llm/runtime/tllmLogger.h"

#include <algorithm>
#include <random>
#include <string>
#include <vector>

namespace tensorrt_llm::tests::layers
{

// TODO(nkorobov): verify context + gen mix

using namespace tensorrt_llm::runtime;
using namespace tensorrt_llm::layers;
using namespace tensorrt_llm::common;

namespace tk = tensorrt_llm::kernels;
namespace tksd = tensorrt_llm::kernels::speculative_decoding;
namespace trk = tensorrt_llm::runtime::kernels;

TokensVec ExplicitDraftTokensDummyNetwork::tokenize(std::string const& letters) const
{
    TokensVec tokens;
    for (char c : letters)
    {
        tokens.push_back(static_cast<TokenIdType>(c));
    }
    return tokens;
}

std::string ExplicitDraftTokensDummyNetwork::detokenize(TokensVec const& tokens) const
{
    std::string letters;
    for (int token : tokens)
    {
        letters += static_cast<char>(token);
    }
    return letters;
}

DraftTokensVec ExplicitDraftTokensDummyNetwork::draftLettersToTokens(DraftLettersVec const& draftLetters) const
{
    DraftTokensVec draftTokens(draftLetters.size());
    for (SizeType32 bi = 0; bi < draftLetters.size(); ++bi)
    {
        draftTokens[bi].resize(draftLetters[bi].size());
        for (SizeType32 pi = 0; pi < draftLetters[bi].size(); ++pi)
        {
            draftTokens[bi][pi] = tokenize(draftLetters[bi][pi]);
        }
    }
    return draftTokens;
}

SizeType32 ExplicitDraftTokensDummyNetwork::longestCommonPrefixLength(TokensVec const& a, TokensVec const& b) const
{
    SizeType32 minLength = std::min(a.size(), b.size());
    SizeType32 idx = 0;
    while (idx < minLength && a[idx] == b[idx])
    {
        ++idx;
    }
    return idx;
}

SizeType32 ExplicitDraftTokensDummyNetwork::computeCompressedVectorAndIndices(TokensVec& compressedVector,
    std::vector<SizeType32>& packedPosIds, DraftTokensIndices& indices, std::vector<TokensVec> const& vectors,
    SizeType32 basePosId)
{
    TokensVec localCompressedVector;
    std::vector<SizeType32> localPackedPosIds;
    std::vector<std::vector<SizeType32>> localIndices;

    // FIXME: always take the 1st beam as the reference. Is that correct?
    // Add the whole first vector to the compressed vector.
    localCompressedVector = vectors[0];
    // All indices of the first vector.
    localIndices.push_back(std::vector<SizeType32>(localCompressedVector.size()));
    for (SizeType32 ti = 0; ti < localCompressedVector.size(); ++ti)
    {
        localIndices[0][ti] = ti;
        // Set local to batch packed pos ids.
        localPackedPosIds.push_back(basePosId + ti);
    }

    // Starting from the 1st path.
    for (SizeType32 pi = 1; pi < vectors.size(); ++pi)
    {
        // Match path to compressed vector (aka path 0).
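        // Worked example (mirrors computeCompressedVectorAndIndicesTest below): for paths
        // {0, 1, 2, 3} and {0, 2, 3, 4} with basePosId = 0, path 1 shares the 1-token prefix {0}
        // with the reference path, so only its tail {2, 3, 4} is appended. The result is
        //   compressedVector = {0, 1, 2, 3, 2, 3, 4}
        //   packedPosIds     = {0, 1, 2, 3, 1, 2, 3}     (per-path positions, offset by basePosId)
        //   indices          = {{0, 1, 2, 3}, {0, 4, 5, 6}} (per-path offsets into the compressed vector)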
        auto const prefixLength = longestCommonPrefixLength(localCompressedVector, vectors[pi]);
        localIndices.push_back(std::vector<SizeType32>(vectors[pi].size()));
        // Set indices of the matched prefix.
        for (SizeType32 ti = 0; ti < prefixLength; ++ti)
        {
            localIndices[pi][ti] = ti;
        }
        // For the non-matched part.
        for (SizeType32 ti = prefixLength; ti < vectors[pi].size(); ++ti)
        {
            // Add new tokens to the compressed vector.
            localCompressedVector.push_back(vectors[pi][ti]);
            // Set new pos ids.
            localPackedPosIds.push_back(basePosId + ti);
            // Set their indices.
            localIndices[pi][ti] = localCompressedVector.size() - 1;
        }
    }

    compressedVector.insert(compressedVector.end(), localCompressedVector.begin(), localCompressedVector.end());
    packedPosIds.insert(packedPosIds.end(), localPackedPosIds.begin(), localPackedPosIds.end());
    indices.push_back(localIndices);

    return static_cast<SizeType32>(localCompressedVector.size());
}

void ExplicitDraftTokensDummyNetwork::createNextMasks(
    DraftTokensIndices const& indices, DraftTokensVec const& draftTokens, SizeType32 maxGenLength)
{
    for (SizeType32 bi = 0; bi < indices.size(); ++bi)
    {
        std::vector<std::vector<bool>> localMask(maxGenLength, std::vector<bool>(maxGenLength));
        // Fill the diagonal.
        for (SizeType32 ti = 0; ti < maxGenLength; ++ti)
        {
            localMask[ti][ti] = true;
        }
        SizeType32 rowIdx = 0;
        for (SizeType32 pi = 0; pi < draftTokens[bi].size(); ++pi)
        {
            auto const prefixLength
                = pi == 0 ? 0 : longestCommonPrefixLength(draftTokens[bi][0], draftTokens[bi][pi]);
            for (SizeType32 ti = 0; ti < draftTokens[bi][pi].size(); ++ti)
            {
                auto const index = indices[bi][pi][ti];
                // If we are in the "prefix" part of the sequence, skip it, as it does not represent a real mask row.
                if (ti < prefixLength)
                {
                    continue;
                }
                // Fill the lower triangular part according to the prefix.
                for (SizeType32 tti = 0; tti < ti; ++tti)
                {
                    localMask[rowIdx][indices[bi][pi][tti]] = true;
                }
                rowIdx++;
            }
        }
        mMasks.push_back(localMask);
    }
}

void ExplicitDraftTokensDummyNetwork::compressTokens(TokensVec& compressedVector,
    std::vector<SizeType32>& packedPosIds, DraftTokensIndices& indices, std::vector<SizeType32>& generationLengths,
    DraftTokensVec const& draftTokens, std::vector<SizeType32> const& basePosIds)
{
    generationLengths.resize(draftTokens.size());
    for (SizeType32 bi = 0; bi < draftTokens.size(); ++bi)
    {
        auto numGeneratedTokens = computeCompressedVectorAndIndices(
            compressedVector, packedPosIds, indices, draftTokens[bi], basePosIds[bi]);
        generationLengths[bi] = numGeneratedTokens;
    }

    // Pad vectors to the maximum size.
    auto const padSize
        = mSamplingParams.getMaxDecodingTokens() * mSamplingParams.getBatchSize() - compressedVector.size();
    compressedVector.insert(compressedVector.end(), padSize, mSamplingParams.getPadId());
    packedPosIds.insert(packedPosIds.end(), padSize, 0);
}

void ExplicitDraftTokensDummyNetwork::acceptTokens(std::vector<TokensVec> const& predictionTokens,
    DraftTokensVec const& lastDraftTokens, DraftTokensVec const& nextDraftTokens)
{
    TLLM_CHECK_WITH_INFO(predictionTokens.size() == lastDraftTokens.size(),
        "Batch size of predictions (%d) does not match the batch size of last draft tokens (%d)",
        static_cast<SizeType32>(predictionTokens.size()), static_cast<SizeType32>(lastDraftTokens.size()));
    TLLM_CHECK_WITH_INFO(predictionTokens.size() == nextDraftTokens.size(),
        "Batch size of predictions (%d) does not match the batch size of next draft tokens (%d)",
        static_cast<SizeType32>(predictionTokens.size()), static_cast<SizeType32>(nextDraftTokens.size()));

    mBestPathLengths.resize(predictionTokens.size());
    mBestPathIndices.resize(predictionTokens.size());
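    // Greedy path selection: for each request, pick the path of lastDraftTokens whose common
    // prefix with the prediction is longest; that prefix length is the number of accepted tokens.
    // E.g. (see acceptTokensTest below), for prediction "how things" and paths
    // {"how do ", "how are", "however", "hello w"}, path 0 wins with a match length of 4 ("how ").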
    // Needed for the unit test of ExplicitDraftTokensDummyNetwork only.
    if (mOutputIds.size() == 0)
    {
        mOutputIds.resize(lastDraftTokens.size());
    }

    for (SizeType32 bi = 0; bi < predictionTokens.size(); ++bi)
    {
        SizeType32 maxMatchLen = -1;
        SizeType32 maxMatchIdx = -1;
        // Find the path with the largest prefix shared with the predicted tokens.
        for (SizeType32 pi = 0; pi < lastDraftTokens[bi].size(); ++pi)
        {
            TLLM_CHECK_WITH_INFO(predictionTokens[bi][0] == lastDraftTokens[bi][pi][0],
                "First token of prediction and draft token must match");
            auto const matchLen = longestCommonPrefixLength(lastDraftTokens[bi][pi], predictionTokens[bi]);
            if (matchLen > maxMatchLen)
            {
                maxMatchLen = matchLen;
                maxMatchIdx = pi;
            }
        }
        mBestPathLengths[bi] = maxMatchLen;
        mBestPathIndices[bi] = maxMatchIdx;

        // Update output ids. The first draft token is already counted in the outputs.
        mOutputIds[bi].insert(mOutputIds[bi].end(), lastDraftTokens[bi][maxMatchIdx].begin() + 1,
            lastDraftTokens[bi][maxMatchIdx].begin() + maxMatchLen);
        mOutputIds[bi].push_back(nextDraftTokens[bi][0][0]);
    }
}

void ExplicitDraftTokensDummyNetwork::forward(SamplingParams const& params,
    std::vector<std::string> const& promptsLetters, std::vector<std::string> const& predictionLetters,
    DraftLettersVec const& nextDraftLetters, DraftLettersVec const& lastDraftLetters)
{
    mSamplingParams = params;

    TLLM_CHECK(params.getBatchSize() == promptsLetters.size());
    TLLM_CHECK(params.getBatchSize() == predictionLetters.size());
    TLLM_CHECK(params.getBatchSize() == nextDraftLetters.size());
    TLLM_CHECK(params.getBatchSize() == lastDraftLetters.size());

    // Tokenize.
    mNextDraftTokens = draftLettersToTokens(nextDraftLetters);
    mLastDraftTokens = draftLettersToTokens(lastDraftLetters);

    std::vector<TokensVec> predictionTokens;
    for (SizeType32 bi = 0; bi < predictionLetters.size(); ++bi)
    {
        predictionTokens.push_back(tokenize(predictionLetters[bi]));
        mPrompts.push_back(tokenize(promptsLetters[bi]));
    }

    std::vector<SizeType32> basePosIds;
    for (auto const& prompt : mPrompts)
    {
        basePosIds.push_back(prompt.size());
    }

    mOutputIds = mPrompts;

    // Make compressed tensors and pos ids for the current and next tokens.
    compressTokens(mNextCompressedVector, mNextPackedPosIds, mNextDraftTokenIndices, mNextGenerationLengths,
        mNextDraftTokens, basePosIds);
    compressTokens(mLastCompressedVector, mLastPackedPosIds, mLastDraftTokenIndices, mLastGenerationLengths,
        mLastDraftTokens, basePosIds);

    mMaxNextGenLength = *std::max_element(mNextGenerationLengths.begin(), mNextGenerationLengths.end());

    acceptTokens(predictionTokens, mLastDraftTokens, mNextDraftTokens);

    createNextMasks(mNextDraftTokenIndices, mNextDraftTokens, mMaxNextGenLength);
}

TEST(ExplicitDraftTokensDummyNetworkTest, tokenizeTest)
{
    ExplicitDraftTokensDummyNetwork network;
    {
        auto tokens = network.tokenize("hello world");
        EXPECT_EQ(tokens, std::vector<TokenIdType>({104, 101, 108, 108, 111, 32, 119, 111, 114, 108, 100}));
    }
    {
        DraftLettersVec lettersVec = {{"hello world", "hello"}, {"world"}};
        auto draftTokens = network.draftLettersToTokens(lettersVec);
        ASSERT_EQ(draftTokens.size(), 2);
        ASSERT_EQ(draftTokens[0].size(), 2);
        ASSERT_EQ(draftTokens[1].size(), 1);
        EXPECT_EQ(
            draftTokens[0][0], std::vector<TokenIdType>({104, 101, 108, 108, 111, 32, 119, 111, 114, 108, 100}));
        EXPECT_EQ(draftTokens[0][1], std::vector<TokenIdType>({104, 101, 108, 108, 111}));
        EXPECT_EQ(draftTokens[1][0], std::vector<TokenIdType>({119, 111, 114, 108, 100}));
    }
}

TEST(ExplicitDraftTokensDummyNetworkTest, detokenizeTest)
{
    ExplicitDraftTokensDummyNetwork network;
    {
        auto letters
            = network.detokenize(std::vector<TokenIdType>({104, 101, 108, 108, 111, 32, 119, 111, 114, 108, 100}));
        EXPECT_EQ(letters, "hello world");
    }
}
TEST(ExplicitDraftTokensDummyNetworkTest, longestCommonPrefixLengthTest)
{
    ExplicitDraftTokensDummyNetwork network;
    EXPECT_EQ(network.longestCommonPrefixLength({1, 2, 3}, {1, 2}), 2);
    EXPECT_EQ(network.longestCommonPrefixLength({1, 2, 3}, {1, 2, 3}), 3);
    EXPECT_EQ(network.longestCommonPrefixLength({1, 2, 3}, {1, 5, 6}), 1);
    EXPECT_EQ(network.longestCommonPrefixLength({1, 2, 3}, {2, 5, 6}), 0);
    EXPECT_EQ(network.longestCommonPrefixLength({1, 2, 3}, {}), 0);
}

TEST(ExplicitDraftTokensDummyNetworkTest, computeCompressedVectorAndIndicesTest)
{
    ExplicitDraftTokensDummyNetwork network;
    {
        std::vector<TokenIdType> compressedVector;
        std::vector<SizeType32> packedPosIds;
        DraftTokensIndices indices;
        SizeType32 basePosId{0};
        std::vector<std::vector<TokenIdType>> tokens = {{0, 1, 2, 3}};
        auto const totalGen
            = network.computeCompressedVectorAndIndices(compressedVector, packedPosIds, indices, tokens, basePosId);
        EXPECT_EQ(totalGen, 4);
        EXPECT_EQ(compressedVector, std::vector<TokenIdType>({0, 1, 2, 3}));
        EXPECT_EQ(packedPosIds, std::vector<SizeType32>({0, 1, 2, 3}));
        ASSERT_EQ(indices.size(), 1);
        ASSERT_EQ(indices[0].size(), 1);
        EXPECT_EQ(indices[0][0], std::vector<SizeType32>({0, 1, 2, 3}));
    }
    {
        std::vector<TokenIdType> compressedVector;
        std::vector<SizeType32> packedPosIds;
        DraftTokensIndices indices;
        SizeType32 basePosId{0};
        std::vector<std::vector<TokenIdType>> tokens = {{0, 1, 2, 3}, {0, 2, 3, 4}};
        auto const totalGen
            = network.computeCompressedVectorAndIndices(compressedVector, packedPosIds, indices, tokens, basePosId);
        EXPECT_EQ(totalGen, 7);
        EXPECT_EQ(compressedVector, std::vector<TokenIdType>({0, 1, 2, 3, 2, 3, 4}));
        EXPECT_EQ(packedPosIds, std::vector<SizeType32>({0, 1, 2, 3, 1, 2, 3}));
        ASSERT_EQ(indices.size(), 1);
        ASSERT_EQ(indices[0].size(), 2);
        EXPECT_EQ(indices[0][0], std::vector<SizeType32>({0, 1, 2, 3}));
        EXPECT_EQ(indices[0][1], std::vector<SizeType32>({0, 4, 5, 6}));
    }
    {
        std::vector<TokenIdType> compressedVector;
        std::vector<SizeType32> packedPosIds;
        DraftTokensIndices indices;
        SizeType32 basePosId{0};
        std::vector<std::vector<TokenIdType>> tokens = {{0, 1, 2, 3}, {0, 1, 6, 2}, {0, 5, 6, 2}};
        auto const totalGen
            = network.computeCompressedVectorAndIndices(compressedVector, packedPosIds, indices, tokens, basePosId);
        EXPECT_EQ(totalGen, 9);
        EXPECT_EQ(compressedVector, std::vector<TokenIdType>({0, 1, 2, 3, 6, 2, 5, 6, 2}));
        EXPECT_EQ(packedPosIds, std::vector<SizeType32>({0, 1, 2, 3, 2, 3, 1, 2, 3}));
        ASSERT_EQ(indices.size(), 1);
        ASSERT_EQ(indices[0].size(), 3);
        EXPECT_EQ(indices[0][0], std::vector<SizeType32>({0, 1, 2, 3}));
        EXPECT_EQ(indices[0][1], std::vector<SizeType32>({0, 1, 4, 5}));
        EXPECT_EQ(indices[0][2], std::vector<SizeType32>({0, 6, 7, 8}));
    }
    {
        std::vector<TokenIdType> compressedVector;
        std::vector<SizeType32> packedPosIds;
        DraftTokensIndices indices;
        SizeType32 basePosId{10};
        std::vector<std::vector<TokenIdType>> tokens = {{0, 1, 2, 3}, {0, 1, 6, 2}, {0, 5, 6, 2}};
        auto const totalGen
            = network.computeCompressedVectorAndIndices(compressedVector, packedPosIds, indices, tokens, basePosId);
        EXPECT_EQ(totalGen, 9);
        EXPECT_EQ(compressedVector, std::vector<TokenIdType>({0, 1, 2, 3, 6, 2, 5, 6, 2}));
        EXPECT_EQ(packedPosIds, std::vector<SizeType32>({10, 11, 12, 13, 12, 13, 11, 12, 13}));
        ASSERT_EQ(indices.size(), 1);
        ASSERT_EQ(indices[0].size(), 3);
        EXPECT_EQ(indices[0][0], std::vector<SizeType32>({0, 1, 2, 3}));
        EXPECT_EQ(indices[0][1], std::vector<SizeType32>({0, 1, 4, 5}));
        EXPECT_EQ(indices[0][2], std::vector<SizeType32>({0, 6, 7, 8}));
    }
}

TEST(ExplicitDraftTokensDummyNetworkTest, compressTokensTest)
{
    {
        ExplicitDraftTokensDummyNetwork network;
        std::vector<TokenIdType> compressedVector;
        std::vector<SizeType32> packedPosIds;
        DraftTokensIndices indices;
        std::vector<SizeType32> genLengths;
        SamplingParams params;
        params.setBatchSize(1);
        params.setMaxNumPaths(1);
        params.setMaxDraftPathLen(6);
        network.setSamplingParams(params);

        DraftTokensVec tokens = {{{0, 1, 2, 3}}};
        std::vector<SizeType32> basePosIds = {0};

        network.compressTokens(compressedVector, packedPosIds, indices, genLengths, tokens, basePosIds);
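        // compressTokens pads the flattened result up to getMaxDecodingTokens() * getBatchSize()
        // tokens; with these params that works out to 7, so the 4 real tokens are followed by
        // 3 pad tokens (getPadId() == -1) and 3 zero position ids, as checked below.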
        EXPECT_EQ(compressedVector, std::vector<TokenIdType>({0, 1, 2, 3, -1, -1, -1}));
        EXPECT_EQ(packedPosIds, std::vector<SizeType32>({0, 1, 2, 3, 0, 0, 0}));
        ASSERT_EQ(indices.size(), 1);
        ASSERT_EQ(indices[0].size(), 1);
        EXPECT_EQ(indices[0][0], std::vector<SizeType32>({0, 1, 2, 3}));
        ASSERT_EQ(genLengths.size(), 1);
        EXPECT_EQ(genLengths[0], 4);

        network.createNextMasks(indices, tokens, 4);

        auto masks = network.getNextMasks();
        ASSERT_EQ(masks.size(), 1);
        ASSERT_EQ(masks[0].size(), 4);
        ASSERT_EQ(masks[0][0].size(), 4);
        EXPECT_EQ(masks[0][0], std::vector<bool>({true, false, false, false}));
        EXPECT_EQ(masks[0][1], std::vector<bool>({true, true, false, false}));
        EXPECT_EQ(masks[0][2], std::vector<bool>({true, true, true, false}));
        EXPECT_EQ(masks[0][3], std::vector<bool>({true, true, true, true}));
    }
    {
        ExplicitDraftTokensDummyNetwork network;
        std::vector<TokenIdType> compressedVector;
        std::vector<SizeType32> packedPosIds;
        DraftTokensIndices indices;
        std::vector<SizeType32> genLengths;
        SamplingParams params;
        params.setBatchSize(2);
        params.setMaxNumPaths(1);
        params.setMaxDraftPathLen(6);
        network.setSamplingParams(params);

        std::vector<SizeType32> basePosIds = {10, 10};
        DraftTokensVec tokens = {{{0, 1, 2, 3}}, {{0, 1, 2, 3}}};

        network.compressTokens(compressedVector, packedPosIds, indices, genLengths, tokens, basePosIds);

        EXPECT_EQ(compressedVector, std::vector<TokenIdType>({0, 1, 2, 3, 0, 1, 2, 3, -1, -1, -1, -1, -1, -1}));
        EXPECT_EQ(packedPosIds, std::vector<SizeType32>({10, 11, 12, 13, 10, 11, 12, 13, 0, 0, 0, 0, 0, 0}));
        ASSERT_EQ(indices.size(), 2);
        ASSERT_EQ(indices[0].size(), 1);
        ASSERT_EQ(indices[1].size(), 1);
        EXPECT_EQ(indices[0][0], std::vector<SizeType32>({0, 1, 2, 3}));
        EXPECT_EQ(indices[1][0], std::vector<SizeType32>({0, 1, 2, 3}));
        ASSERT_EQ(genLengths.size(), 2);
        EXPECT_EQ(genLengths[0], 4);
        EXPECT_EQ(genLengths[1], 4);

        network.createNextMasks(indices, tokens, 4);

        auto masks = network.getNextMasks();
        ASSERT_EQ(masks.size(), 2);
        ASSERT_EQ(masks[0].size(), 4);
        ASSERT_EQ(masks[1].size(), 4);
        ASSERT_EQ(masks[0][0].size(), 4);
        ASSERT_EQ(masks[1][0].size(), 4);
        EXPECT_EQ(masks[0][0], std::vector<bool>({true, false, false, false}));
        EXPECT_EQ(masks[0][1], std::vector<bool>({true, true, false, false}));
        EXPECT_EQ(masks[0][2], std::vector<bool>({true, true, true, false}));
        EXPECT_EQ(masks[0][3], std::vector<bool>({true, true, true, true}));
        EXPECT_EQ(masks[1][0], std::vector<bool>({true, false, false, false}));
        EXPECT_EQ(masks[1][1], std::vector<bool>({true, true, false, false}));
        EXPECT_EQ(masks[1][2], std::vector<bool>({true, true, true, false}));
        EXPECT_EQ(masks[1][3], std::vector<bool>({true, true, true, true}));
    }
    {
        ExplicitDraftTokensDummyNetwork network;
        std::vector<TokenIdType> compressedVector;
        std::vector<SizeType32> packedPosIds;
        DraftTokensIndices indices;
        std::vector<SizeType32> genLengths;
        SamplingParams params;
        params.setBatchSize(2);
        params.setMaxNumPaths(3);
        params.setMaxDraftPathLen(4);
        network.setSamplingParams(params);

        std::vector<SizeType32> basePosIds = {10, 0};
        DraftTokensVec tokens = {{{0, 1, 2, 3}, {0, 1, 6, 2}, {0, 5, 6, 2}}, {{0, 1, 2, 3}, {0, 1, 2, 4}}};

        network.compressTokens(compressedVector, packedPosIds, indices, genLengths, tokens, basePosIds);

        EXPECT_EQ(compressedVector,
            std::vector<TokenIdType>(
                {0, 1, 2, 3, 6, 2, 5, 6, 2, 0, 1, 2, 3, 4, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1}));
        EXPECT_EQ(packedPosIds,
            std::vector<SizeType32>(
                {10, 11, 12, 13, 12, 13, 11, 12, 13, 0, 1, 2, 3, 3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}));
        ASSERT_EQ(indices.size(), 2);
        ASSERT_EQ(indices[0].size(), 3);
        ASSERT_EQ(indices[1].size(), 2);
        EXPECT_EQ(indices[0][0], std::vector<SizeType32>({0, 1, 2, 3}));
        EXPECT_EQ(indices[0][1], std::vector<SizeType32>({0, 1, 4, 5}));
        EXPECT_EQ(indices[0][2], std::vector<SizeType32>({0, 6, 7, 8}));
        EXPECT_EQ(indices[1][0], std::vector<SizeType32>({0, 1, 2, 3}));
        EXPECT_EQ(indices[1][1], std::vector<SizeType32>({0, 1, 2, 4}));
        ASSERT_EQ(genLengths.size(), 2);
        EXPECT_EQ(genLengths[0], 9);
        EXPECT_EQ(genLengths[1], 5);
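        // Mask layout for this case: request 0 compresses 3 paths into 9 rows, request 1 only 5.
        // Rows past a request's generation length keep just the diagonal bit set (see masks[1][5..8]
        // below), since the mask tensor is sized to the maximum generation length of the batch.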
        network.createNextMasks(indices, tokens, 9);

        auto masks = network.getNextMasks();
        ASSERT_EQ(masks.size(), 2);
        ASSERT_EQ(masks[0].size(), 9);
        ASSERT_EQ(masks[1].size(), 9);
        ASSERT_EQ(masks[0][0].size(), 9);
        ASSERT_EQ(masks[1][0].size(), 9);
        EXPECT_EQ(masks[0][0], std::vector<bool>({true, false, false, false, false, false, false, false, false}));
        EXPECT_EQ(masks[0][1], std::vector<bool>({true, true, false, false, false, false, false, false, false}));
        EXPECT_EQ(masks[0][2], std::vector<bool>({true, true, true, false, false, false, false, false, false}));
        EXPECT_EQ(masks[0][3], std::vector<bool>({true, true, true, true, false, false, false, false, false}));
        EXPECT_EQ(masks[0][4], std::vector<bool>({true, true, false, false, true, false, false, false, false}));
        EXPECT_EQ(masks[0][5], std::vector<bool>({true, true, false, false, true, true, false, false, false}));
        EXPECT_EQ(masks[0][6], std::vector<bool>({true, false, false, false, false, false, true, false, false}));
        EXPECT_EQ(masks[0][7], std::vector<bool>({true, false, false, false, false, false, true, true, false}));
        EXPECT_EQ(masks[0][8], std::vector<bool>({true, false, false, false, false, false, true, true, true}));
        EXPECT_EQ(masks[1][0], std::vector<bool>({true, false, false, false, false, false, false, false, false}));
        EXPECT_EQ(masks[1][1], std::vector<bool>({true, true, false, false, false, false, false, false, false}));
        EXPECT_EQ(masks[1][2], std::vector<bool>({true, true, true, false, false, false, false, false, false}));
        EXPECT_EQ(masks[1][3], std::vector<bool>({true, true, true, true, false, false, false, false, false}));
        EXPECT_EQ(masks[1][4], std::vector<bool>({true, true, true, false, true, false, false, false, false}));
        EXPECT_EQ(masks[1][5], std::vector<bool>({false, false, false, false, false, true, false, false, false}));
        EXPECT_EQ(masks[1][6], std::vector<bool>({false, false, false, false, false, false, true, false, false}));
        EXPECT_EQ(masks[1][7], std::vector<bool>({false, false, false, false, false, false, false, true, false}));
        EXPECT_EQ(masks[1][8], std::vector<bool>({false, false, false, false, false, false, false, false, true}));
    }
}

TEST(ExplicitDraftTokensDummyNetworkTest, acceptTokensTest)
{
    {
        ExplicitDraftTokensDummyNetwork network;
        std::vector<TokensVec> predictionTokens = {network.tokenize("how things")};
        DraftLettersVec lastDraftLetters = {{"how do ", "how are", "however", "hello w"}};
        DraftLettersVec nextDraftLetters = {{"things ", "that is", "to crea", "touchab"}};
        auto lastDraftTokens = network.draftLettersToTokens(lastDraftLetters);
        auto nextDraftTokens = network.draftLettersToTokens(nextDraftLetters);

        network.acceptTokens(predictionTokens, lastDraftTokens, nextDraftTokens);

        auto bestPathLengths = network.getBestPathLengths();
        auto bestPathIndices = network.getBestPathIndices();
        auto outputIds = network.getOutputIds();
        ASSERT_EQ(bestPathLengths.size(), 1);
        ASSERT_EQ(bestPathIndices.size(), 1);
        ASSERT_EQ(outputIds.size(), 1);
        EXPECT_EQ(bestPathLengths[0], 4);
        EXPECT_EQ(bestPathIndices[0], 0);
        EXPECT_EQ(network.detokenize(outputIds[0]), "ow t");
    }
    {
        ExplicitDraftTokensDummyNetwork network;
        std::vector<TokensVec> predictionTokens = {network.tokenize("however you")};
        DraftLettersVec lastDraftLetters = {{"how do ", "how tho", "however", "hello w"}};
        DraftLettersVec nextDraftLetters = {{" increme", " introdu", " i = 0; ", " importa"}};
        auto lastDraftTokens = network.draftLettersToTokens(lastDraftLetters);
        auto nextDraftTokens = network.draftLettersToTokens(nextDraftLetters);
        network.acceptTokens(predictionTokens, lastDraftTokens, nextDraftTokens);

        auto bestPathLengths = network.getBestPathLengths();
        auto bestPathIndices = network.getBestPathIndices();
        auto outputIds = network.getOutputIds();
        ASSERT_EQ(bestPathLengths.size(), 1);
        ASSERT_EQ(bestPathIndices.size(), 1);
        ASSERT_EQ(outputIds.size(), 1);
        EXPECT_EQ(bestPathLengths[0], 7);
        EXPECT_EQ(bestPathIndices[0], 2);
        EXPECT_EQ(network.detokenize(outputIds[0]), "owever ");
    }
    {
        ExplicitDraftTokensDummyNetwork network;
        std::vector<TokensVec> predictionTokens = {network.tokenize("how things")};
        DraftLettersVec lastDraftLetters = {{"heruist", "habit i", "handove", "hammer "}};
        DraftLettersVec nextDraftLetters = {{"oatmeal", "ocean b", "occupat", "oblivio"}};
        auto lastDraftTokens = network.draftLettersToTokens(lastDraftLetters);
        auto nextDraftTokens = network.draftLettersToTokens(nextDraftLetters);

        network.acceptTokens(predictionTokens, lastDraftTokens, nextDraftTokens);

        auto bestPathLengths = network.getBestPathLengths();
        auto bestPathIndices = network.getBestPathIndices();
        auto outputIds = network.getOutputIds();
        ASSERT_EQ(bestPathLengths.size(), 1);
        ASSERT_EQ(bestPathIndices.size(), 1);
        ASSERT_EQ(outputIds.size(), 1);
        EXPECT_EQ(bestPathLengths[0], 1);
        EXPECT_EQ(bestPathIndices[0], 0);
        EXPECT_EQ(network.detokenize(outputIds[0]), "o");
    }
}

/////////////////////////////////////////////////////////////////////////////////////////////////////////////////

template <typename T>
void ExplicitDraftTokensLayerTest<T>::SetUp()
{
    mStream = std::make_shared<CudaStream>();
    mBufferManager = std::make_shared<BufferManager>(mStream);
}

template <typename T>
void ExplicitDraftTokensLayerTest<T>::allocateBuffers()
{
    using DataType = typename T::DataType;
    auto const dataType = TRTDataType<DataType>::value;

    auto speculativeDecodingModule = std::make_shared<SpeculativeDecodingModule>(mSamplingParams.getMaxDraftPathLen(),
        mSamplingParams.getMaxDecodingDraftTokens(), mSamplingParams.getMaxNumPaths());
    auto const decodingDomain = tensorrt_llm::layers::DecoderDomain(mSamplingParams.getMaxBatchSize(), 1,
        mSamplingParams.getVocabSize(), mSamplingParams.getVocabSize(), speculativeDecodingModule);

    mExplicitDraftTokensLayer
        = std::make_shared<tensorrt_llm::layers::ExplicitDraftTokensLayer<DataType>>(decodingDomain, mBufferManager);

    // outputs
    mOutputIds = BufferManager::pinnedPool(
        ITensor::makeShape({mSamplingParams.getMaxBatchSize(), mSamplingParams.getMaxSeqLen()}),
        nvinfer1::DataType::kINT32);
    mSeqLengths = BufferManager::pinnedPool(
        ITensor::makeShape({mSamplingParams.getMaxBatchSize()}), nvinfer1::DataType::kINT32);
    mAcceptedLengths = BufferManager::pinnedPool(
        ITensor::makeShape({mSamplingParams.getMaxBatchSize()}), nvinfer1::DataType::kINT32);
    mNextDraftLengths = BufferManager::pinnedPool(
        ITensor::makeShape({mSamplingParams.getMaxBatchSize()}), nvinfer1::DataType::kINT32);
    mPrevDraftLengths = BufferManager::pinnedPool(
        ITensor::makeShape({mSamplingParams.getMaxBatchSize()}), nvinfer1::DataType::kINT32);
    mAcceptedLengthCumSum = BufferManager::pinnedPool(
        ITensor::makeShape({mSamplingParams.getMaxBatchSize() + 1}), nvinfer1::DataType::kINT32);
    mOutputNextDraftTokens = BufferManager::pinnedPool(
        ITensor::makeShape({mSamplingParams.getMaxBatchSize(), mSamplingParams.getMaxDecodingDraftTokens()}),
        nvinfer1::DataType::kINT32);
    mOutputPositionIdsBase = BufferManager::pinnedPool(
        ITensor::makeShape({mSamplingParams.getMaxBatchSize()}), nvinfer1::DataType::kINT32);
    mRandomDataSample = BufferManager::pinnedPool(ITensor::makeShape({mSamplingParams.getMaxBatchSize()}), dataType);
    mRandomDataValidation = BufferManager::pinnedPool(ITensor::makeShape({mSamplingParams.getMaxBatchSize(),
                                                          mSamplingParams.getMaxNumPaths(),
                                                          mSamplingParams.getMaxDraftPathLen()}),
        dataType);
    mPackedMasks = BufferManager::pinnedPool(
        ITensor::makeShape({mSamplingParams.getMaxBatchSize(), mSamplingParams.getMaxDecodingTokens(),
            static_cast<SizeType32>(divUp(mSamplingParams.getMaxDecodingTokens(), 32))}),
        nvinfer1::DataType::kINT32);
    mNextPosIds = BufferManager::pinnedPool(
        ITensor::makeShape({mSamplingParams.getMaxBatchSize(), mSamplingParams.getMaxDecodingTokens()}),
        nvinfer1::DataType::kINT32);
    mOutputUnpackedNextDraftTokens = BufferManager::pinnedPool(
        ITensor::makeShape(
            {mSamplingParams.getMaxBatchSize(), mSamplingParams.getMaxNumPaths(), mSamplingParams.getMaxPathLen()}),
        nvinfer1::DataType::kINT32);
    mOutputUnpackedNextDraftIndices = BufferManager::pinnedPool(
        ITensor::makeShape(
            {mSamplingParams.getMaxBatchSize(), mSamplingParams.getMaxNumPaths(), mSamplingParams.getMaxPathLen()}),
        nvinfer1::DataType::kINT32);
    mOutputDraftProbs = BufferManager::pinnedPool(
        ITensor::makeShape({mSamplingParams.getMaxBatchSize(), mSamplingParams.getMaxNumPaths(),
            mSamplingParams.getMaxDraftPathLen(), mSamplingParams.getVocabSize()}),
        dataType);
    mOutputTemperatures = BufferManager::pinnedPool(ITensor::makeShape({mSamplingParams.getMaxBatchSize()}), dataType);
    mOutputGenerationLengths = BufferManager::pinnedPool(
        ITensor::makeShape({mSamplingParams.getMaxBatchSize()}), nvinfer1::DataType::kINT32);
    mOutputGenerationLengthsHost = BufferManager::pinnedPool(
        ITensor::makeShape({mSamplingParams.getMaxBatchSize()}), nvinfer1::DataType::kINT32);
    mMaxGenLengthHost = BufferManager::pinnedPool(ITensor::makeShape({1}), nvinfer1::DataType::kINT32);

    // inputs
    mBatchSlots
        = BufferManager::pinnedPool(ITensor::makeShape({mSamplingParams.getBatchSize()}), nvinfer1::DataType::kINT32);
    mTokensPerStep = BufferManager::pinnedPool(
        ITensor::makeShape({mSamplingParams.getMaxBatchSize()}), nvinfer1::DataType::kINT32);
    mPathsOffsets = BufferManager::pinnedPool(
        ITensor::makeShape({mSamplingParams.getMaxBatchSize() * mSamplingParams.getMaxDraftPathLen()}),
        nvinfer1::DataType::kINT32);
    mMasks = BufferManager::pinnedPool(
        ITensor::makeShape({mSamplingParams.getMaxBatchSize(), mSamplingParams.getMaxDecodingTokens(),
            mSamplingParams.getMaxDecodingTokens()}),
        nvinfer1::DataType::kBOOL);
    mInputNextDraftTokens = BufferManager::pinnedPool(
        ITensor::makeShape(
            {mSamplingParams.getMaxBatchSize(), mSamplingParams.getMaxNumPaths(), mSamplingParams.getMaxPathLen()}),
        nvinfer1::DataType::kINT32);
    mLastDraftTokens = BufferManager::pinnedPool(
        ITensor::makeShape(
            {mSamplingParams.getBatchSize(), mSamplingParams.getMaxNumPaths(), mSamplingParams.getMaxPathLen()}),
        nvinfer1::DataType::kINT32);
    mPackedPosIds = BufferManager::pinnedPool(
        ITensor::makeShape({mSamplingParams.getMaxBatchSize(), mSamplingParams.getMaxDecodingTokens()}),
        nvinfer1::DataType::kINT32);
    mBestPathLengths = BufferManager::pinnedPool(
        ITensor::makeShape({mSamplingParams.getMaxBatchSize()}), nvinfer1::DataType::kINT32);
    mBestPathIndices = BufferManager::pinnedPool(
        ITensor::makeShape({mSamplingParams.getMaxBatchSize()}), nvinfer1::DataType::kINT32);
    mSpecDecodingGenerationLengths = BufferManager::pinnedPool(
        ITensor::makeShape({mSamplingParams.getMaxBatchSize()}), nvinfer1::DataType::kINT32);
    mNextFlatTokens = BufferManager::pinnedPool(
        ITensor::makeShape({mSamplingParams.getMaxBatchSize() * mSamplingParams.getMaxDecodingTokens()}),
        nvinfer1::DataType::kINT32);
    mInputPositionIdsBase = BufferManager::pinnedPool(
        ITensor::makeShape({mSamplingParams.getMaxBatchSize()}), nvinfer1::DataType::kINT32);
    mNextDraftIndices = BufferManager::pinnedPool(
        ITensor::makeShape(
            {mSamplingParams.getMaxBatchSize(), mSamplingParams.getMaxNumPaths(), mSamplingParams.getMaxPathLen()}),
        nvinfer1::DataType::kINT32);
    mLastDraftIndices = BufferManager::pinnedPool(
        ITensor::makeShape(
            {mSamplingParams.getMaxBatchSize(), mSamplingParams.getMaxNumPaths(), mSamplingParams.getMaxPathLen()}),
        nvinfer1::DataType::kINT32);
    mNextDraftProbs = BufferManager::pinnedPool(
        ITensor::makeShape({mSamplingParams.getMaxBatchSize(), mSamplingParams.getMaxNumPaths(),
            mSamplingParams.getMaxDraftPathLen(), mSamplingParams.getVocabSize()}),
        dataType);
    mEndIds = BufferManager::pinnedPool(
        ITensor::makeShape({mSamplingParams.getMaxBatchSize()}), nvinfer1::DataType::kINT32);
    mMaxGenLengthDevice = BufferManager::pinnedPool(ITensor::makeShape({1}), nvinfer1::DataType::kINT32);

    // Packed inputs
    mMaxGenerationLength = BufferManager::pinnedPool(ITensor::makeShape({1}), nvinfer1::DataType::kINT32);
    mCumSumGenerationLengths
        = BufferManager::pinnedPool(ITensor::makeShape({mSamplingParams.getBatchSize()}), nvinfer1::DataType::kINT32);

    // Packed outputs
    mPackedPositionIdsBase
        = BufferManager::pinnedPool(ITensor::makeShape({mSamplingParams.getBatchSize()}), nvinfer1::DataType::kINT32);
    mPackedGenerationLengths
        = BufferManager::pinnedPool(ITensor::makeShape({mSamplingParams.getBatchSize()}), nvinfer1::DataType::kINT32);
    mPackedRandomDataSample = BufferManager::pinnedPool(ITensor::makeShape({mSamplingParams.getBatchSize()}), dataType);
    mPackedRandomDataVerification = BufferManager::pinnedPool(
        ITensor::makeShape(
            {mSamplingParams.getBatchSize(), mSamplingParams.getMaxNumPaths(), mSamplingParams.getMaxDraftPathLen()}),
        dataType);
    mPackedNextDraftTokens = BufferManager::pinnedPool(
        ITensor::makeShape(
            {mSamplingParams.getBatchSize(), mSamplingParams.getMaxNumPaths(), mSamplingParams.getMaxPathLen()}),
        nvinfer1::DataType::kINT32);
    mPackedNextDraftIndices = BufferManager::pinnedPool(
        ITensor::makeShape(
            {mSamplingParams.getBatchSize(), mSamplingParams.getMaxNumPaths(), mSamplingParams.getMaxPathLen()}),
        nvinfer1::DataType::kINT32);
    mPackedPackedMasks = BufferManager::pinnedPool(
        ITensor::makeShape({mSamplingParams.getBatchSize(), mSamplingParams.getMaxDecodingTokens(),
            static_cast<SizeType32>(divUp(mSamplingParams.getMaxDecodingTokens(), 32))}),
        nvinfer1::DataType::kINT32);
    mPackedPositionOffsets = BufferManager::pinnedPool(
        ITensor::makeShape({mSamplingParams.getBatchSize(), mSamplingParams.getMaxDecodingTokens()}),
        nvinfer1::DataType::kINT32);
    mPackedPackedPosIds = BufferManager::pinnedPool(
        ITensor::makeShape({mSamplingParams.getBatchSize(), mSamplingParams.getMaxDecodingTokens()}),
        nvinfer1::DataType::kINT32);
    mPackedDraftProbs = BufferManager::pinnedPool(
        ITensor::makeShape({mSamplingParams.getBatchSize(), mSamplingParams.getMaxNumPaths(),
            mSamplingParams.getMaxDraftPathLen(), mSamplingParams.getVocabSize()}),
        dataType);
    mPackedTemperatures = BufferManager::pinnedPool(ITensor::makeShape({mSamplingParams.getBatchSize()}), dataType);

    mDecodingWorkspace = std::make_shared<tensorrt_llm::runtime::DecodingLayerWorkspace>(
        mBufferManager, decodingDomain, TRTDataType<DataType>::value, mExplicitDraftTokensLayer->getWorkspaceSize());
}

template <typename T>
void ExplicitDraftTokensLayerTest<T>::setup()
{
    using DataType = typename T::DataType;

    // outputs
    trk::invokeFill(*mOutputIds, TokenIdType{-1}, *mStream);
    trk::invokeFill(*mSeqLengths, SizeType32{0}, *mStream);
    trk::invokeFill(*mAcceptedLengths, SizeType32{0}, *mStream);
    trk::invokeFill(*mAcceptedLengthCumSum, SizeType32{-1}, *mStream);
    trk::invokeFill(*mOutputNextDraftTokens, TokenIdType{-1}, *mStream);
    trk::invokeFill(*mOutputPositionIdsBase, SizeType32{0}, *mStream);
    trk::invokeFill(*mRandomDataSample, DataType{0}, *mStream);
    trk::invokeFill(*mRandomDataValidation, DataType{0}, *mStream);
    trk::invokeFill(*mPackedMasks, SizeType32{0}, *mStream);
    trk::invokeFill(*mNextPosIds, SizeType32{0}, *mStream);
    trk::invokeFill(*mOutputUnpackedNextDraftTokens, TokenIdType{-1}, *mStream);
    trk::invokeFill(*mOutputUnpackedNextDraftIndices, SizeType32{0}, *mStream);
    trk::invokeFill(*mEndIds, TokenIdType{-1}, *mStream);

    auto inDraftProbs = BufferRange<DataType>(*mNextDraftProbs);

    std::mt19937 gen(42);
    std::uniform_real_distribution<float> distr(0.f, 1.f);
    std::generate(
        inDraftProbs.begin(), inDraftProbs.end(), [&gen, &distr]() { return static_cast<DataType>(distr(gen)); });

    auto batchSlotsPtr = bufferCast<SizeType32>(*mBatchSlots);
    for (SizeType32 bi = 0; bi < mSamplingParams.getBatchSize(); ++bi)
    {
        batchSlotsPtr[bi] = 2 * bi;
    }

    auto setupParams = std::make_shared<ExplicitDraftTokensSetupParams>();

    mRandomSeeds = std::vector<uint64_t>(mSamplingParams.getBatchSize());
    mTemperatures = std::vector<float>(mSamplingParams.getBatchSize());

    std::mt19937 generator(42);
    std::uniform_int_distribution<uint64_t> seedDistr(1, 1000);
    std::uniform_real_distribution<float> temperatureDistr(0.001f, 1.f);
    std::generate(
        mRandomSeeds.begin(), mRandomSeeds.end(), [&generator, &seedDistr]() { return seedDistr(generator); });
    std::generate(mTemperatures.begin(), mTemperatures.end(),
        [&generator, &temperatureDistr]() { return temperatureDistr(generator); });
    setupParams->randomSeed = mRandomSeeds;
    setupParams->temperature = mTemperatures;
    setupParams->randomDataSample = mRandomDataSample;
    setupParams->temperatures = mOutputTemperatures;
    setupParams->dtype = TRTDataType<DataType>::value;

    mDecodingWorkspace->setDeviceBatchSlots(mBatchSlots);
    mExplicitDraftTokensLayer->setup(mSamplingParams.getBatchSize(), 1, mBatchSlots, setupParams, mDecodingWorkspace);

    mStream->synchronize();

    mBestPathLengths = mBufferManager->copyFrom(mNetwork.getBestPathLengths(),
        ITensor::makeShape({mSamplingParams.getBatchSize()}), runtime::MemoryType::kPINNEDPOOL);
    mBestPathIndices = mBufferManager->copyFrom(mNetwork.getBestPathIndices(),
        ITensor::makeShape({mSamplingParams.getBatchSize()}), runtime::MemoryType::kPINNEDPOOL);
    mPackedPosIds = mBufferManager->copyFrom(mNetwork.getNextPackedPosId(),
        ITensor::makeShape({mSamplingParams.getMaxDecodingTokens() * mSamplingParams.getBatchSize()}),
        runtime::MemoryType::kPINNEDPOOL);

    auto const nextDraftTokens = mNetwork.getNextDraftTokens();
    auto const lastDraftTokens = mNetwork.getLastDraftTokens();
    auto const nextDraftIndices = mNetwork.getNextDraftIndices();
    auto const lastDraftIndices = mNetwork.getLastDraftIndices();

    auto sequenceLength = BufferRange<SizeType32>(*mSeqLengths);
    auto nextDraftTokensRange = BufferRange<TokenIdType>(*mInputNextDraftTokens);
    auto lastDraftTokensRange = BufferRange<TokenIdType>(*mLastDraftTokens);
    auto nextDraftIndicesRange = BufferRange<SizeType32>(*mNextDraftIndices);
    auto lastDraftIndicesRange = BufferRange<SizeType32>(*mLastDraftIndices);
    auto inputPositionIdsBase = BufferRange<SizeType32>(*mInputPositionIdsBase);
    auto outputIds = BufferRange<TokenIdType>(*mOutputIds);
    auto generationLengths = mNetwork.getNextGenerationLengths();
    auto prompts = mNetwork.getPrompts();

    for (SizeType32 bi = 0; bi < nextDraftTokens.size(); ++bi)
    {
        for (SizeType32 pi = 0; pi < nextDraftTokens[bi].size(); ++pi)
        {
            for (SizeType32 ti = 0; ti < nextDraftTokens[bi][pi].size(); ++ti)
            {
                auto idx = flat_index3(
                    bi, pi, ti, mSamplingParams.getMaxNumPaths(), mSamplingParams.getMaxPathLen());
                nextDraftTokensRange[idx] = nextDraftTokens[bi][pi][ti];
                lastDraftTokensRange[idx] = lastDraftTokens[bi][pi][ti];
                nextDraftIndicesRange[idx] = nextDraftIndices[bi][pi][ti];
                lastDraftIndicesRange[idx] = lastDraftIndices[bi][pi][ti];
            }
        }
        bufferCast<SizeType32>(*mSpecDecodingGenerationLengths)[bi] = generationLengths[bi];
        sequenceLength[batchSlotsPtr[bi]] = prompts[bi].size();
        std::copy(prompts[bi].begin(), prompts[bi].end(),
            outputIds.begin() + batchSlotsPtr[bi] * mSamplingParams.getMaxSeqLen());
        inputPositionIdsBase[bi] = prompts[bi].size();
    }

    auto nextFlatTokens = mNetwork.getNextFlatTokens();
    TLLM_LOG_DEBUG("Next flat tokens are \"%s\"", mNetwork.detokenize(nextFlatTokens).c_str());
    auto nextFlatTokensRange = BufferRange<TokenIdType>(*mNextFlatTokens);
    std::copy(nextFlatTokens.begin(), nextFlatTokens.end(), nextFlatTokensRange.begin());

    auto const masks = mNetwork.getNextMasks();
    auto masksRange = BufferRange<bool>(*mMasks);
    auto const maxGenLength = mNetwork.getMaxNextGenerationLength();
    bufferCast<SizeType32>(*mMaxGenerationLength)[0] = maxGenLength;
    for (SizeType32 bi = 0; bi < masks.size(); ++bi)
    {
        TLLM_CHECK(maxGenLength == masks[bi].size());
        for (SizeType32 ri = 0; ri < masks[bi].size(); ++ri)
        {
            TLLM_CHECK(maxGenLength == masks[bi][ri].size());
            for (SizeType32 ci = 0; ci < masks[bi][ri].size(); ++ci)
            {
                masksRange[bi * maxGenLength * maxGenLength + ri * maxGenLength + ci] = masks[bi][ri][ci];
            }
        }
    }
}

template <typename T>
std::shared_ptr<ExplicitDraftTokensInputs> ExplicitDraftTokensLayerTest<T>::createInputTensors()
{
    auto forwardParams
        = std::make_shared<ExplicitDraftTokensInputs>(mEndIds, mBatchSlots, mSamplingParams.getBatchSize());

    forwardParams->seqSlots = mBatchSlots;
    forwardParams->masks = mMasks;
    forwardParams->nextDraftTokens = mInputNextDraftTokens;
    forwardParams->nextDraftIndices = mNextDraftIndices;
    forwardParams->lastDraftTokens = mLastDraftTokens;
    forwardParams->lastDraftIndices = mLastDraftIndices;
    forwardParams->packedPosIds = mPackedPosIds;
    forwardParams->bestPathLengths = mBestPathLengths;
    forwardParams->bestPathIndices = mBestPathIndices;
    forwardParams->generationLengths = mSpecDecodingGenerationLengths;
    forwardParams->nextFlatTokens = mNextFlatTokens;
    forwardParams->positionIdsBase = mInputPositionIdsBase;
    forwardParams->nextDraftProbs = mNextDraftProbs;
    forwardParams->maxGenLengthDevice = mMaxGenLengthDevice;

    return forwardParams;
}

template <typename T>
std::shared_ptr<ExplicitDraftTokensOutputs> ExplicitDraftTokensLayerTest<T>::createOutputTensors()
{
    auto outputParams = std::make_shared<ExplicitDraftTokensOutputs>(mOutputIds);

    outputParams->sequenceLength = mSeqLengths;
    outputParams->nextDraftTokens = mOutputNextDraftTokens;
    outputParams->numNewTokens = mAcceptedLengths;
    outputParams->nextDraftLengths = mNextDraftLengths;
    outputParams->prevDraftLengths = mPrevDraftLengths;
    outputParams->numNewTokensCumSum = mAcceptedLengthCumSum;
    outputParams->pathsOffsets = mPathsOffsets;
    outputParams->nextDraftPosIds = mNextPosIds;
    outputParams->positionIdsBase = mOutputPositionIdsBase;
    outputParams->randomDataSample = mRandomDataSample;
    outputParams->randomDataValidation = mRandomDataValidation;
    outputParams->packedMasks = mPackedMasks;
    outputParams->unpackedNextDraftTokens = mOutputUnpackedNextDraftTokens;
    outputParams->unpackedNextDraftIndices = mOutputUnpackedNextDraftIndices;
    outputParams->nextDraftProbs = mOutputDraftProbs;
    outputParams->temperatures = mOutputTemperatures;
    outputParams->generationLengths = mOutputGenerationLengths;
    outputParams->generationLengthsHost = mOutputGenerationLengthsHost;
    outputParams->maxGenLengthHost = mMaxGenLengthHost;

    return outputParams;
}
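// Packs a row of booleans into 32-bit words, matching the layer's packed-mask layout: bit (bi % 32)
// of word (bi / 32) is set when the bool at position bi is true. E.g. the row {1, 1, 0, 1} packs
// into the single word 0b1011 = 11.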
std::vector<int32_t> boolArrayToBitmask(BufferRange<bool>::iterator boolIterator, size_t pathLen)
{
    std::vector<int32_t> bitmask(divUp(pathLen, 32));
    for (size_t bi = 0; bi < pathLen; ++bi)
    {
        auto slice = bi / 32;
        if (boolIterator[bi])
        {
            bitmask[slice] |= (1 << (bi % 32));
        }
    }
    return bitmask;
}

template <typename T>
void ExplicitDraftTokensLayerTest<T>::checkLayerResult()
{
    using DataType = typename T::DataType;

    auto const batchSlots = BufferRange<SizeType32>(*mBatchSlots);

    // Check generated random data
    {
        auto const randomDataSample = BufferRange<DataType>(*mRandomDataSample);
        auto const randomDataValidation = BufferRange<DataType>(*mRandomDataValidation);
        for (SizeType32 bi = 0; bi < mSamplingParams.getBatchSize(); ++bi)
        {
            auto const batchSlot = batchSlots[bi];
            // Check that all fields are filled with non-zero data.
            EXPECT_NE(randomDataSample[batchSlot], DataType{0}) << " bi: " << bi;
            auto const stride = mSamplingParams.getMaxNumPaths() * mSamplingParams.getMaxDraftPathLen();
            EXPECT_FALSE(std::any_of(randomDataValidation.begin() + batchSlot * stride,
                randomDataValidation.begin() + (batchSlot + 1) * stride,
                [](DataType val) { return val == DataType{0}; }))
                << " bi: " << bi;
        }
    }

    // Check masks
    {
        auto const packedMasks = BufferRange<int32_t>(*mPackedMasks);
        auto masks = BufferRange<bool>(*mMasks);
        auto generationLengths = mNetwork.getNextGenerationLengths();
        for (SizeType32 bi = 0; bi < mSamplingParams.getBatchSize(); ++bi)
        {
            for (SizeType32 ti = 0; ti < generationLengths[bi]; ++ti)
            {
                auto const batchSlot = batchSlots[bi];
                auto const maskIdx = flat_index3(
                    bi, ti, 0, mNetwork.getMaxNextGenerationLength(), mNetwork.getMaxNextGenerationLength());
                auto const bitmask
                    = boolArrayToBitmask(masks.begin() + maskIdx, mNetwork.getMaxNextGenerationLength());
                for (SizeType32 mi = 0; mi < bitmask.size(); ++mi)
                {
                    auto const packedMaskIdx = flat_index3(batchSlot, ti, mi, mSamplingParams.getMaxDecodingTokens(),
                        static_cast<SizeType32>(divUp(mSamplingParams.getMaxDecodingTokens(), 32)));
                    EXPECT_EQ(bitmask[mi], packedMasks[packedMaskIdx]) << " bi: " << bi << " ti: " << ti;
                }
            }
        }
    }

    // Check accepted tokens
    auto const outputIds = BufferRange<TokenIdType>(*mOutputIds);
    auto const refOutputIds = mNetwork.getOutputIds();
    auto const promptIds = mNetwork.getPrompts();
    auto const seqLengths = BufferRange<SizeType32>(*mSeqLengths);
    auto const lastDraftTokens = BufferRange<TokenIdType>(*mLastDraftTokens);
    auto const bestPathLengths = BufferRange<SizeType32>(*mBestPathLengths);
    auto const bestPathIndices = BufferRange<SizeType32>(*mBestPathIndices);
    for (SizeType32 bi = 0; bi < mSamplingParams.getBatchSize(); ++bi)
    {
        auto const batchSlot = batchSlots[bi];
        // The updated seq length is the prompt length plus the newly accepted tokens.
        EXPECT_EQ(seqLengths[batchSlot], promptIds[bi].size() + bestPathLengths[bi]) << " bi: " << bi;

        // Check that output ids contain the accepted tokens.
        for (SizeType32 ti = 0; ti < promptIds[bi].size() + bestPathLengths[bi]; ++ti)
        {
            EXPECT_EQ(outputIds[batchSlot * mSamplingParams.getMaxSeqLen() + ti], refOutputIds[bi][ti])
                << " bi: " << bi << " ti: " << ti;
        }

        auto outputIter = outputIds.begin() + batchSlot * mSamplingParams.getMaxSeqLen();
        std::vector<TokenIdType> outputVec(outputIter, outputIter + seqLengths[batchSlot]);
        TLLM_LOG_DEBUG("Output ids at %d request is \"%s\"", bi, mNetwork.detokenize(outputVec).c_str());
        TLLM_LOG_DEBUG("Ref output ids at %d request is \"%s\"", bi, mNetwork.detokenize(refOutputIds[bi]).c_str());
    }

    // Check new draft tokens
    {
        auto const outputNextDraftTokens = BufferRange<TokenIdType>(*mOutputNextDraftTokens);
        auto const generationLengths = BufferRange<SizeType32>(*mSpecDecodingGenerationLengths);
        auto const compressedDraftTokens = mNetwork.getNextFlatTokens();
        TLLM_LOG_DEBUG(
            "Next compressed draft tokens are \"%s\"", mNetwork.detokenize(compressedDraftTokens).c_str());
        SizeType32 compressedIdx = 0;
        for (SizeType32 bi = 0; bi < mSamplingParams.getBatchSize(); ++bi)
        {
            auto const batchSlot = batchSlots[bi];
            auto const generatedLength = generationLengths[bi];
            // Check draft tokens for the next iteration.
            for (SizeType32 ti = 0; ti < generatedLength - 1; ++ti)
            {
                auto const idx = flat_index2(batchSlot, ti, mSamplingParams.getMaxDecodingDraftTokens());
                EXPECT_EQ(outputNextDraftTokens[idx], compressedDraftTokens[compressedIdx + ti + 1])
                    << " bi: " << bi << " ti: " << ti;
            }
            // Check the length of the draft tokens.
            EXPECT_EQ(BufferRange<SizeType32>(*mNextDraftLengths)[batchSlot], generatedLength - 1) << " bi: " << bi;
            // Check the accepted length.
            EXPECT_EQ(BufferRange<SizeType32>(*mAcceptedLengths)[batchSlot], bestPathLengths[bi]) << " bi: " << bi;
            compressedIdx += generatedLength;
        }
    }

    // Check position ids
    {
        auto const outputPositionIdsBase = BufferRange<SizeType32>(*mOutputPositionIdsBase);
        auto const nextPosIds = BufferRange<SizeType32>(*mNextPosIds);
        auto const generationLengths = BufferRange<SizeType32>(*mSpecDecodingGenerationLengths);
        auto const packedPosIds = mNetwork.getNextPackedPosId();
        SizeType32 compressedIdx = 0;
        for (SizeType32 bi = 0; bi < mSamplingParams.getBatchSize(); ++bi)
        {
            auto const batchSlot = batchSlots[bi];
            EXPECT_EQ(outputPositionIdsBase[batchSlot], seqLengths[batchSlot]);
            auto const generatedLength = generationLengths[bi];
            // Check pos ids for the next iteration.
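            // The packed pos ids were built as basePosId + ti with basePosId equal to the prompt
            // length, while the layer emits position ids shifted down by one for the context-phase
            // correction checked below (e.g. a packed pos id of 10 comes back as 9).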
            for (SizeType32 ti = 0; ti < generatedLength; ++ti)
            {
                auto const idx = flat_index2(batchSlot, ti, mSamplingParams.getMaxDecodingTokens());
                // Subtract 1 to account for the context-phase correction of pos ids.
                EXPECT_EQ(nextPosIds[idx], packedPosIds[compressedIdx + ti] - 1) << " bi: " << bi << " ti: " << ti;
            }
            compressedIdx += generatedLength;
        }
    }

    // Check unpacked indices and tokens
    {
        auto const nextDraftTokens = mNetwork.getNextDraftTokens();
        auto const nextDraftIndices = mNetwork.getNextDraftIndices();
        auto const nextDraftTokensRange = BufferRange<TokenIdType>(*mOutputUnpackedNextDraftTokens);
        auto const nextDraftIndicesRange = BufferRange<SizeType32>(*mOutputUnpackedNextDraftIndices);
        for (SizeType32 bi = 0; bi < nextDraftTokens.size(); ++bi)
        {
            auto const batchSlot = batchSlots[bi];
            for (SizeType32 pi = 0; pi < nextDraftTokens[bi].size(); ++pi)
            {
                for (SizeType32 ti = 0; ti < nextDraftTokens[bi][pi].size(); ++ti)
                {
                    auto idx = flat_index3(
                        batchSlot, pi, ti, mSamplingParams.getMaxNumPaths(), mSamplingParams.getMaxPathLen());
                    EXPECT_EQ(nextDraftTokensRange[idx], nextDraftTokens[bi][pi][ti])
                        << "bi: " << bi << " pi: " << pi << " ti: " << ti;
                    EXPECT_EQ(nextDraftIndicesRange[idx], nextDraftIndices[bi][pi][ti])
                        << "bi: " << bi << " pi: " << pi << " ti: " << ti;
                }
            }
        }
    }

    // Check accumulated cum sum and paths offsets
    {
        auto const accumulatedCumSum = BufferRange<SizeType32>(*mAcceptedLengthCumSum);
        auto const pathsOffsets = BufferRange<SizeType32>(*mPathsOffsets);
        auto const acceptedLengths = BufferRange<SizeType32>(*mAcceptedLengths);
        auto const bestPathIndices = BufferRange<SizeType32>(*mBestPathIndices);
        auto const lastDraftIndices = mNetwork.getLastDraftIndices();
        SizeType32 sum = 0;
        for (SizeType32 bi = 0; bi < mSamplingParams.getBatchSize(); ++bi)
        {
            auto const batchSlot = batchSlots[bi];
            EXPECT_EQ(sum, accumulatedCumSum[bi]) << "bi: " << bi;
            auto const acceptedLength = acceptedLengths[batchSlot] - 1;
            for (SizeType32 ti = 0; ti < acceptedLength; ++ti)
            {
                EXPECT_EQ(pathsOffsets[sum + ti], lastDraftIndices[bi][bestPathIndices[bi]][ti + 1] - 1)
                    << "bi: " << bi << " ti: " << ti;
            }
            sum += acceptedLength;
        }
        EXPECT_EQ(sum, accumulatedCumSum[mSamplingParams.getBatchSize()]);
    }

    // Check draft probs
    {
        auto const outDraftProbs = BufferRange<DataType>(*mOutputDraftProbs);
        auto const inDraftProbs = BufferRange<DataType>(*mNextDraftProbs);
        for (SizeType32 bi = 0; bi < mSamplingParams.getBatchSize(); ++bi)
        {
            auto const batchSlot = batchSlots[bi];
            for (SizeType32 pi = 0; pi < mSamplingParams.getMaxNumPaths(); ++pi)
            {
                for (SizeType32 ti = 0; ti < mSamplingParams.getMaxDraftPathLen(); ++ti)
                {
                    for (SizeType32 vi = 0; vi < mSamplingParams.getVocabSize(); ++vi)
                    {
                        auto const outProbIdx = flat_index4(batchSlot, pi, ti, vi, mSamplingParams.getMaxNumPaths(),
                            mSamplingParams.getMaxDraftPathLen(), mSamplingParams.getVocabSize());
                        auto const inProbIdx = flat_index4(bi, pi, ti, vi, mSamplingParams.getMaxNumPaths(),
                            mSamplingParams.getMaxDraftPathLen(), mSamplingParams.getVocabSize());
                        EXPECT_EQ(outDraftProbs[outProbIdx], inDraftProbs[inProbIdx])
                            << "bi: " << bi << " pi: " << pi << " ti: " << ti << " vi: " << vi;
                    }
                }
            }
        }
    }

    // Check temperature
    {
        for (SizeType32 bi = 0; bi < mSamplingParams.getBatchSize(); ++bi)
        {
            auto const batchSlot = batchSlots[bi];
            EXPECT_EQ(BufferRange<DataType>(*mOutputTemperatures)[batchSlot],
                static_cast<DataType>(1.f / mTemperatures[bi]))
                << " bi: " << bi;
        }
    }
}

template <typename T>
void ExplicitDraftTokensLayerTest<T>::packData()
{
    using DataType = typename T::DataType;
    tksd::PackExplicitDraftTokensParams<DataType> params;
    params.batchSlots = bufferCast<SizeType32>(*mBatchSlots);
    params.cumSumGenerationLengths = bufferCast<SizeType32>(*mCumSumGenerationLengths);
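    // The params below map each tensor from its batch-slot-indexed form (inputs, written at slot
    // 2 * bi in setup()) to a contiguous, batch-indexed form (outputs); checkPackResult() compares
    // position [bi] of each packed tensor against position [batchSlot] of its unpacked source.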
    params.maxGenerationLength = bufferCast<SizeType32>(*mMaxGenerationLength);
    params.outputPositionIdsBase = bufferCast<SizeType32>(*mPackedPositionIdsBase);
    params.inputPositionIdsBase = bufferCast<SizeType32>(*mOutputPositionIdsBase);
    params.outputGenerationLengths = bufferCast<SizeType32>(*mPackedGenerationLengths);
    params.inputGenerationLengths = bufferCast<SizeType32>(*mSpecDecodingGenerationLengths);
    params.outputRandomDataSample = bufferCast<DataType>(*mPackedRandomDataSample);
    params.inputRandomDataSample = bufferCast<DataType>(*mRandomDataSample);
    params.outputRandomDataValidation = bufferCast<DataType>(*mPackedRandomDataVerification);
    params.inputRandomDataValidation = bufferCast<DataType>(*mRandomDataValidation);
    params.outputNextDraftTokens = bufferCast<TokenIdType>(*mPackedNextDraftTokens);
    params.inputNextDraftTokens = bufferCast<TokenIdType>(*mOutputUnpackedNextDraftTokens);
    params.outputNextDraftIndices = bufferCast<SizeType32>(*mPackedNextDraftIndices);
    params.inputNextDraftIndices = bufferCast<SizeType32>(*mOutputUnpackedNextDraftIndices);
    params.outputPackedMask = bufferCast<int32_t>(*mPackedPackedMasks);
    params.inputPackedMask = bufferCast<int32_t>(*mPackedMasks);
    params.inputPositionIds = bufferCast<SizeType32>(*mNextPosIds);
    params.outputPositionOffsets = bufferCast<SizeType32>(*mPackedPositionOffsets);
    params.outputPositionIds = bufferCast<SizeType32>(*mPackedPackedPosIds);
    params.outputDraftProbs = bufferCast<DataType>(*mPackedDraftProbs);
    params.inputDraftProbs = bufferCast<DataType>(*mOutputDraftProbs);
    params.outputTemperatures = bufferCast<DataType>(*mPackedTemperatures);
    params.inputTemperatures = bufferCast<DataType>(*mOutputTemperatures);
    params.batchSize = mSamplingParams.getBatchSize();
    params.numPaths = mSamplingParams.getMaxNumPaths();
    params.maxPathLength = mSamplingParams.getMaxPathLen();
    params.vocabSize = mSamplingParams.getVocabSize();
    params.numGenerationRequests = mSamplingParams.getBatchSize();
    params.numContextTokens = 0;

    params.checkParams();

    tksd::invokePackGenerationLengths(params, mStream->get());

    // Compute the inclusive sum.
    auto reduceTempStorageBytes = tksd::invokeScanGenerationLengths(
        nullptr, 0, nullptr, nullptr, mSamplingParams.getBatchSize(), mStream->get());
    auto reduceMaxTempStorage = mBufferManager->gpu(reduceTempStorageBytes);
    tksd::invokeScanGenerationLengths(bufferCast<uint8_t>(*reduceMaxTempStorage), reduceTempStorageBytes,
        bufferCast<SizeType32>(*mSpecDecodingGenerationLengths), bufferCast<SizeType32>(*mCumSumGenerationLengths),
        mSamplingParams.getBatchSize(), mStream->get());

    // Pack tensors from batch slot positions into a contiguous array.
    tksd::invokePackExplicitDraftTokens(params, mStream->get());

    // Copy draft probs.
    tksd::invokeCopyProbs(params, mStream->get());
}

template <typename T>
void ExplicitDraftTokensLayerTest<T>::checkPackResult()
{
    using DataType = typename T::DataType;
    auto const batchSlots = BufferRange<SizeType32>(*mBatchSlots);
    auto const maxGenLength = mNetwork.getMaxNextGenerationLength();
    auto const numPackedMasks = static_cast<SizeType32>(divUp(mSamplingParams.getMaxDecodingTokens(), 32));
    for (SizeType32 bi = 0; bi < mSamplingParams.getBatchSize(); ++bi)
    {
        auto const batchSlot = batchSlots[bi];
        EXPECT_EQ(BufferRange<SizeType32>(*mPackedPositionIdsBase)[bi],
            BufferRange<SizeType32>(*mOutputPositionIdsBase)[batchSlot])
            << "bi: " << bi;
        EXPECT_EQ(BufferRange<SizeType32>(*mPackedGenerationLengths)[bi],
            BufferRange<SizeType32>(*mSpecDecodingGenerationLengths)[batchSlot])
            << "bi: " << bi;
        EXPECT_EQ(
            BufferRange<DataType>(*mPackedRandomDataSample)[bi], BufferRange<DataType>(*mRandomDataSample)[batchSlot])
            << "bi: " << bi;
        EXPECT_EQ(
            BufferRange<DataType>(*mPackedTemperatures)[bi], BufferRange<DataType>(*mOutputTemperatures)[batchSlot])
            << "bi: " << bi;
        for (SizeType32 pi = 0; pi < mSamplingParams.getMaxNumPaths(); ++pi)
        {
            for (SizeType32 ti = 0; ti < mSamplingParams.getMaxDraftPathLen(); ++ti)
            {
                EXPECT_EQ(bufferCast<DataType>(*ITensor::at(mPackedRandomDataVerification, {bi, pi, ti}))[0],
                    bufferCast<DataType>(*ITensor::at(mRandomDataValidation, {batchSlot, pi, ti}))[0])
                    << "bi: " << bi << " pi: " << pi << " ti: " << ti;
                for (SizeType32 vi = 0; vi < mSamplingParams.getVocabSize(); ++vi)
                {
                    EXPECT_EQ(bufferCast<DataType>(*ITensor::at(mPackedDraftProbs, {bi, pi, ti, vi}))[0],
                        bufferCast<DataType>(*ITensor::at(mOutputDraftProbs, {batchSlot, pi, ti, vi}))[0])
                        << "bi: " << bi << " pi: " << pi << " ti: " << ti << " vi: " << vi;
                }
            }
            for (SizeType32 ti = 0; ti < mSamplingParams.getMaxPathLen(); ++ti)
            {
                EXPECT_EQ(bufferCast<TokenIdType>(*ITensor::at(mPackedNextDraftTokens, {bi, pi, ti}))[0],
                    bufferCast<TokenIdType>(*ITensor::at(mOutputUnpackedNextDraftTokens, {batchSlot, pi, ti}))[0])
                    << "bi: " << bi << " pi: " << pi << " ti: " << ti;
                EXPECT_EQ(bufferCast<SizeType32>(*ITensor::at(mPackedNextDraftIndices, {bi, pi, ti}))[0],
                    bufferCast<SizeType32>(*ITensor::at(mOutputUnpackedNextDraftIndices, {batchSlot, pi, ti}))[0])
                    << "bi: " << bi << " pi: " << pi << " ti: " << ti;
            }
        }
        auto const basePosId = BufferRange<SizeType32>(*mPackedPositionIdsBase)[bi];
        for (SizeType32 ti = 0; ti < maxGenLength; ++ti)
        {
            auto const outPosOffsetIdx = flat_index2(bi, ti, maxGenLength);
            auto const inPosOffsetIdx = flat_index2(batchSlot, ti, mSamplingParams.getMaxDecodingTokens());
            EXPECT_EQ(BufferRange<SizeType32>(*mPackedPositionOffsets)[outPosOffsetIdx],
                BufferRange<SizeType32>(*mNextPosIds)[inPosOffsetIdx] - basePosId + 1)
                << "bi: " << bi << " ti: " << ti;
        }
        auto const outputMaskStartId = (bi == 0) ? 0 : BufferRange<SizeType32>(*mCumSumGenerationLengths)[bi - 1];
        auto const numTokens = (bi == 0) ? BufferRange<SizeType32>(*mCumSumGenerationLengths)[0]
                                         : BufferRange<SizeType32>(*mCumSumGenerationLengths)[bi]
                - BufferRange<SizeType32>(*mCumSumGenerationLengths)[bi - 1];
        for (SizeType32 mi = 0; mi < numTokens * numPackedMasks; ++mi)
        {
            auto const outMaskIdx = outputMaskStartId * numPackedMasks + mi;
            auto const inMaskIdx = flat_index2(batchSlot, mi, mSamplingParams.getMaxDecodingTokens() * numPackedMasks);
            EXPECT_EQ(
                BufferRange<int32_t>(*mPackedPackedMasks)[outMaskIdx], BufferRange<int32_t>(*mPackedMasks)[inMaskIdx])
                << "bi: " << bi << " mi: " << mi;
        }
    }
}

template <typename T>
void ExplicitDraftTokensLayerTest<T>::runTest(std::vector<std::string> const& prompts,
    std::vector<std::string> const& predictions, DraftLettersVec const& nextDraftLetters,
    DraftLettersVec const& lastDraftLetters, SamplingParams& params)
{
    mSamplingParams = params;

    mNetwork.forward(params, prompts, predictions, nextDraftLetters, lastDraftLetters);

    allocateBuffers();

    setup();

    auto inputTensors = createInputTensors();
    auto outputTensors = createOutputTensors();

    mDecodingWorkspace->setDeviceBatchSlots(mBatchSlots);
    mExplicitDraftTokensLayer->forwardAsync(outputTensors, inputTensors, mDecodingWorkspace);

    mStream->synchronize();

    checkLayerResult();

    packData();

    mStream->synchronize();

    checkPackResult();
}

template class ExplicitDraftTokensLayerTest<TypePair<float>>;
template class ExplicitDraftTokensLayerTest<TypePair<half>>;
#ifdef ENABLE_BF16
template class ExplicitDraftTokensLayerTest<TypePair<__nv_bfloat16>>;
#endif // ENABLE_BF16

TYPED_TEST_SUITE(ExplicitDraftTokensLayerTest, TestTypes);

TYPED_TEST(ExplicitDraftTokensLayerTest, SimpleTestBS1)
{
    SamplingParams params;
    std::vector<std::string> prompt = {"Hi mate, h"};
    std::vector<std::string> predictions = {"how things"};
    DraftLettersVec lastDraftLetters = {{"how do ", "how are", "however", "hello w"}};
    DraftLettersVec nextDraftLetters = {{"things ", "that is", "to crea", "touchab"}};
    params.setBatchSize(1);
    this->runTest(prompt, predictions, nextDraftLetters, lastDraftLetters, params);
}

TYPED_TEST(ExplicitDraftTokensLayerTest, SimpleTestBS1OnePaths)
{
    SamplingParams params;
    std::vector<std::string> prompt = {"Hi mate, h"};
    std::vector<std::string> predictions = {"how things"};
    DraftLettersVec lastDraftLetters = {{"how do "}};
    DraftLettersVec nextDraftLetters = {{"things "}};
    params.setBatchSize(1);
    params.setMaxNumPaths(1);
    this->runTest(prompt, predictions, nextDraftLetters, lastDraftLetters, params);
}

TYPED_TEST(ExplicitDraftTokensLayerTest, SimpleTestSecondPathAcceptedBS1)
{
    SamplingParams params;
    std::vector<std::string> prompt = {"Hi mate, h"};
    std::vector<std::string> predictions = {"how things"};
    DraftLettersVec lastDraftLetters = {{"howdy f", "how are", "however", "hello w"}};
    DraftLettersVec nextDraftLetters = {{"things ", "that is", "to crea", "touchab"}};
    params.setBatchSize(1);
    this->runTest(prompt, predictions, nextDraftLetters, lastDraftLetters, params);
}

TYPED_TEST(ExplicitDraftTokensLayerTest, SimpleTestNoDraftAcceptedBS1)
{
    SamplingParams params;
    std::vector<std::string> prompt = {"Hi mate, h"};
    std::vector<std::string> predictions = {"how things"};
    DraftLettersVec lastDraftLetters = {{"handove", "human f", "heavy l", "hello h"}};
    DraftLettersVec nextDraftLetters = {{"oatmeal", "ocean b", "occupat", "oblivio"}};
    params.setBatchSize(1);
    this->runTest(prompt, predictions, nextDraftLetters, lastDraftLetters, params);
}

TYPED_TEST(ExplicitDraftTokensLayerTest, SimpleTestBS2SameSequence)
{
    SamplingParams params;
    std::vector<std::string> prompt = {"Hi mate, h", "Hi mate, h"};
    std::vector<std::string> predictions = {"how things", "how things"};
    DraftLettersVec lastDraftLetters
        = {{"how do ", "how are", "however", "hello w"}, {"how do ", "how are", "however", "hello w"}};
    DraftLettersVec nextDraftLetters
        = {{"things ", "that is", "to crea", "touchab"}, {"things ", "that is", "to crea", "touchab"}};
    params.setBatchSize(2);
    this->runTest(prompt, predictions, nextDraftLetters, lastDraftLetters, params);
}

TYPED_TEST(ExplicitDraftTokensLayerTest, SimpleTestBS2Long)
{
    SamplingParams params;
    std::vector<std::string> prompt = {"Hi mate, h", "London is t"};
    std::vector<std::string> predictions = {"how things are going", "the capital of Great Britain"};
    DraftLettersVec lastDraftLetters = {{"how do you ", "how are you", "however you", "hello world"},
        {"the bar and", "the best ci", "the capital", "thoughest p"}};
    DraftLettersVec nextDraftLetters = {{"things are ", "that is sad", "to create a", "touchable y"},
        {" of Great B", " and the ma", " of country", " also known"}};
    params.setBatchSize(2);
    // ceil(4 * 10 / 32) = 2 masks per request
    params.setMaxNumPaths(4);
    params.setMaxDraftPathLen(10);
    this->runTest(prompt, predictions, nextDraftLetters, lastDraftLetters, params);
}

TYPED_TEST(ExplicitDraftTokensLayerTest, SimpleTestBS2DifferentSequences)
{
    SamplingParams params;
    std::vector<std::string> prompt = {"Hi mate, h", "London is t"};
    std::vector<std::string> predictions = {"how things", "the cap"};
    DraftLettersVec lastDraftLetters
        = {{"how do ", "how are", "however", "hello w"}, {"the bar", "the bes", "the cap", "thoughe"}};
    DraftLettersVec nextDraftLetters
        = {{"things ", "that is", "to crea", "touchab"}, {"itan of", "iteract", "ital of", "importa"}};
    params.setBatchSize(2);
    this->runTest(prompt, predictions, nextDraftLetters, lastDraftLetters, params);
}

TYPED_TEST(ExplicitDraftTokensLayerTest, SimpleTestB4DifferentSequences)
{
    SamplingParams params;
    std::vector<std::string> prompt = {"Hi mate, h", "London is t", "Short", "Very long prompt but should not m"};
    std::vector<std::string> predictions = {"how things", "the cap", "twave o", "matter "};
    DraftLettersVec lastDraftLetters = {{"how do ", "how are", "however", "hello w"},
        {"the bar", "the bes", "the cap", "thoughe"}, {"t promp", "ts on Y", "ter out", "twave o"},
        {"matter ", "mean an", "make th", "modify "}};
    DraftLettersVec nextDraftLetters = {{"things ", "that is", "to crea", "touchab"},
        {"itan of", "iteract", "ital of", "importa"}, {" chips ", " oil an", " semico", " exampl"},
        {"at all ", "anythin", "above a", "albeit "}};
    params.setBatchSize(4);
    this->runTest(prompt, predictions, nextDraftLetters, lastDraftLetters, params);
}

template <typename T>
class FillRandDataTest : public ::testing::Test // NOLINT(cppcoreguidelines-pro-type-member-init)
{
protected:
    static auto constexpr mDataType{TRTDataType<T>::value};

    FillRandDataTest() {}

    void SetUp() override
    {
        mLogger = std::make_shared<TllmLogger>();
        mStream = std::make_shared<CudaStream>();
        mBufferManager = std::make_shared<BufferManager>(mStream);
    }

    void TearDown() override {}

    void runTest(SizeType32 batchSize, SizeType32 numPaths, SizeType32 draftLength, bool skipVerification,
        uint64_t randomSeed, bool batchInit)
    {
        SizeType32* batchSlotsPtr{nullptr};

        auto curandState = mBufferManager->gpu(ITensor::makeShape({batchSize, 48}), nvinfer1::DataType::kUINT8);
        auto* curandStatePtr = reinterpret_cast<curandState_t*>(bufferCast<uint8_t>(*curandState));
        if (batchInit)
        {
            auto randomSeeds = mBufferManager->gpu(ITensor::makeShape({batchSize}), nvinfer1::DataType::kINT64);
            trk::invokeFill(*randomSeeds, static_cast<int64_t>(randomSeed), *mStream);
            auto* randomSeedsPtr = bufferCast<uint64_t>(*randomSeeds);
            tk::invokeCurandBatchInitialize(curandStatePtr, batchSlotsPtr, batchSize, randomSeedsPtr, mStream->get());
        }
        else
        {
            tk::invokeCurandInitialize(curandStatePtr, batchSlotsPtr, batchSize, randomSeed, mStream->get());
        }
        mStream->synchronize();

        tksd::FillRandDataExplicitDraftTokensParams<T> params;
        params.batchSize = batchSize;
        params.numPaths = numPaths;
        params.draftLength = draftLength;
        params.skipVerification = skipVerification;

        auto randDataSample = mBufferManager->gpu(ITensor::makeShape({batchSize}), mDataType);
        auto randDataValidation
            = mBufferManager->gpu(ITensor::makeShape({batchSize, numPaths, draftLength}), mDataType);

        params.randDataSample = bufferCast<T>(*randDataSample);
        params.randDataVerification = bufferCast<T>(*randDataValidation);
        params.curandState = curandStatePtr;
        params.batchSlots = batchSlotsPtr;

        tksd::invokeFillRandData(params, mStream->get());

        mStream->synchronize();

        auto randDataSampleHost = mBufferManager->copyFrom(*randDataSample, MemoryType::kCPU);
        auto randDataSampleHostPtr = bufferCast<T>(*randDataSampleHost);
        EXPECT_GE(randDataSampleHostPtr[0], T(0));
        EXPECT_LE(randDataSampleHostPtr[0], T(1));

        auto randDataValidationHost = mBufferManager->copyFrom(*randDataValidation, MemoryType::kCPU);
        auto randDataValidationHostRange = BufferRange<T>(*randDataValidationHost);
        for (size_t i = 0; i < randDataValidationHostRange.size(); ++i)
        {
            EXPECT_GE(randDataValidationHostRange[i], T(0)) << "index " << i;
            EXPECT_LE(randDataValidationHostRange[i], T(1)) << "index " << i;
        }
    }

private:
    std::shared_ptr<TllmLogger> mLogger;
    std::shared_ptr<CudaStream> mStream;
    std::shared_ptr<BufferManager> mBufferManager;
};

#ifdef ENABLE_BF16
using FloatHalfBfloatTypes = testing::Types<float, half, __nv_bfloat16>;
TYPED_TEST_SUITE(FillRandDataTest, FloatHalfBfloatTypes);
#else
TYPED_TEST_SUITE(FillRandDataTest, FloatAndHalfTypes);
#endif

TYPED_TEST(FillRandDataTest, SimpleTest)
{
    SizeType32 constexpr batchSize{2};
    SizeType32 constexpr numPaths{3};
    SizeType32 constexpr draftLength{4};
    bool constexpr skipVerification{false};
    uint64_t randomSeed{0};

    this->runTest(batchSize, numPaths, draftLength, skipVerification, randomSeed, false);
}

TYPED_TEST(FillRandDataTest, BatchInit)
{
    SizeType32 constexpr batchSize{3};
    SizeType32 constexpr numPaths{2};
    SizeType32 constexpr draftLength{5};
    bool constexpr skipVerification{false};
    uint64_t randomSeed{42};

    this->runTest(batchSize, numPaths, draftLength, skipVerification, randomSeed, true);
}

} // namespace tensorrt_llm::tests::layers