/*
 * Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "tests/layers/explicitDraftTokensLayerTest.h"
#include "tensorrt_llm/common/memoryUtils.h"
#include "tensorrt_llm/kernels/decodingCommon.h"
#include "tensorrt_llm/kernels/speculativeDecoding/explicitDraftTokensKernels.h"
#include "tensorrt_llm/runtime/runtimeKernels.h"
#include "tensorrt_llm/runtime/speculativeDecodingModule.h"

// The two include targets below were stripped by extraction; <algorithm> and <random>
// cover the std::min/std::generate/std::copy and std::mt19937/distribution usage in this file.
#include <algorithm>
#include <random>

namespace tensorrt_llm::tests::layers
{

// TODO(nkorobov) verify context + gen mix
using namespace tensorrt_llm::runtime;
using namespace tensorrt_llm::layers;
using namespace tensorrt_llm::common;

namespace tk = tensorrt_llm::kernels;
namespace tksd = tensorrt_llm::kernels::speculative_decoding;
namespace tcc = tensorrt_llm::common::conversion;
namespace trk = tensorrt_llm::runtime::kernels;

TokensVec ExplicitDraftTokensDummyNetwork::tokenize(std::string const& letters) const
{
    TokensVec tokens;
    for (char c : letters)
    {
        tokens.push_back(static_cast<TokenIdType>(c));
    }
    return tokens;
}

std::string ExplicitDraftTokensDummyNetwork::detokenize(TokensVec const& tokens) const
{
    std::string letters;
    for (int token : tokens)
    {
        letters += static_cast<char>(token);
    }
    return letters;
}

DraftTokensVec ExplicitDraftTokensDummyNetwork::draftLettersToTokens(DraftLettersVec const& draftLetters) const
{
    DraftTokensVec draftTokens(draftLetters.size());
    for (SizeType32 bi = 0; bi < draftLetters.size(); ++bi)
    {
        draftTokens[bi].resize(draftLetters[bi].size());
        for (SizeType32 pi = 0; pi < draftLetters[bi].size(); ++pi)
        {
            draftTokens[bi][pi] = tokenize(draftLetters[bi][pi]);
        }
    }
    return draftTokens;
}

SizeType32 ExplicitDraftTokensDummyNetwork::longestCommonPrefixLength(TokensVec const& a, TokensVec const& b) const
{
    SizeType32 const minLength = std::min(a.size(), b.size());
    SizeType32 idx = 0;
    while (idx < minLength && a[idx] == b[idx])
    {
        ++idx;
    }
    return idx;
}
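// The compression below flattens the several draft paths of one request into a single
// token vector: path 0 is copied verbatim, and every further path contributes only the
// suffix that diverges from path 0, with per-path indices mapping back into the flat
// vector. For example (mirroring computeCompressedVectorAndIndicesTest below), paths
// {0, 1, 2, 3} and {0, 1, 6, 2} share the prefix {0, 1}, so they compress to
// {0, 1, 2, 3, 6, 2} with indices {0, 1, 2, 3} and {0, 1, 4, 5}.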
SizeType32 ExplicitDraftTokensDummyNetwork::computeCompressedVectorAndIndices(TokensVec& compressedVector,
    std::vector<SizeType32>& packedPosIds, DraftTokensIndices& indices, std::vector<TokensVec> const& vectors,
    SizeType32 basePosId)
{
    TokensVec localCompressedVector;
    std::vector<SizeType32> localPackedPosIds;
    std::vector<std::vector<SizeType32>> localIndices;

    // FIXME always take the 1st beam as the reference. Is that correct?
    // Add the whole first vector to the compressed vector.
    localCompressedVector = vectors[0];
    // All indices of the first vector.
    localIndices.push_back(std::vector<SizeType32>(localCompressedVector.size()));
    for (SizeType32 ti = 0; ti < localCompressedVector.size(); ++ti)
    {
        localIndices[0][ti] = ti;
        // Set local to batch packed pos ids.
        localPackedPosIds.push_back(basePosId + ti);
    }
    // Starting from the 1st path.
    for (SizeType32 pi = 1; pi < vectors.size(); ++pi)
    {
        // Match path to compressed vector (aka path 0).
        auto const prefixLength = longestCommonPrefixLength(localCompressedVector, vectors[pi]);
        localIndices.push_back(std::vector<SizeType32>(vectors[pi].size()));
        // Set indices of the matched prefix.
        for (SizeType32 ti = 0; ti < prefixLength; ++ti)
        {
            localIndices[pi][ti] = ti;
        }
        // For the non-matched part.
        for (SizeType32 ti = prefixLength; ti < vectors[pi].size(); ++ti)
        {
            // Add new tokens to the compressed vector.
            localCompressedVector.push_back(vectors[pi][ti]);
            // Set new pos ids.
            localPackedPosIds.push_back(basePosId + ti);
            // Set their indices.
            localIndices[pi][ti] = localCompressedVector.size() - 1;
        }
    }

    compressedVector.insert(compressedVector.end(), localCompressedVector.begin(), localCompressedVector.end());
    packedPosIds.insert(packedPosIds.end(), localPackedPosIds.begin(), localPackedPosIds.end());
    indices.push_back(localIndices);

    return static_cast<SizeType32>(localCompressedVector.size());
}

void ExplicitDraftTokensDummyNetwork::createNextMasks(
    DraftTokensIndices const& indices, DraftTokensVec const& draftTokens, SizeType32 maxGenLength)
{
    for (SizeType32 bi = 0; bi < indices.size(); ++bi)
    {
        std::vector<std::vector<bool>> localMask(maxGenLength, std::vector<bool>(maxGenLength));
        // Prefill the diagonal.
        for (SizeType32 ti = 0; ti < maxGenLength; ++ti)
        {
            localMask[ti][ti] = true;
        }
        SizeType32 rowIdx = 0;
        for (SizeType32 pi = 0; pi < draftTokens[bi].size(); ++pi)
        {
            auto const prefixLength = pi == 0 ? 0 : longestCommonPrefixLength(draftTokens[bi][0], draftTokens[bi][pi]);
            for (SizeType32 ti = 0; ti < draftTokens[bi][pi].size(); ++ti)
            {
                auto const index = indices[bi][pi][ti];
                // If we are in the "prefix" part of the sequence, skip it as it does not represent a real mask row.
                if (ti < prefixLength)
                {
                    continue;
                }
                // Fill the lower triangular part according to the prefix.
                for (SizeType32 tti = 0; tti < ti; ++tti)
                {
                    localMask[rowIdx][indices[bi][pi][tti]] = true;
                }
                rowIdx++;
            }
        }
        mMasks.push_back(localMask);
    }
}

void ExplicitDraftTokensDummyNetwork::compressTokens(TokensVec& compressedVector,
    std::vector<SizeType32>& packedPosIds, DraftTokensIndices& indices, std::vector<SizeType32>& generationLengths,
    DraftTokensVec const& draftTokens, std::vector<SizeType32> const& basePosIds)
{
    generationLengths.resize(draftTokens.size());
    for (SizeType32 bi = 0; bi < draftTokens.size(); ++bi)
    {
        auto const numGeneratedTokens = computeCompressedVectorAndIndices(
            compressedVector, packedPosIds, indices, draftTokens[bi], basePosIds[bi]);
        generationLengths[bi] = numGeneratedTokens;
    }

    // Pad vectors to the maximum size.
    auto const padSize
        = mSamplingParams.getMaxDecodingTokens() * mSamplingParams.getBatchSize() - compressedVector.size();
    compressedVector.insert(compressedVector.end(), padSize, mSamplingParams.getPadId());
    packedPosIds.insert(packedPosIds.end(), padSize, 0);
}
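// Token acceptance below is greedy: for each request, the path of the last draft with
// the longest common prefix against the prediction wins (ties keep the first such
// path), its accepted suffix is appended to the output ids, and the first token of the
// next draft becomes the newly generated token.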
void ExplicitDraftTokensDummyNetwork::acceptTokens(std::vector<TokensVec> const& predictionTokens,
    DraftTokensVec const& lastDraftTokens, DraftTokensVec const& nextDraftTokens)
{
    TLLM_CHECK_WITH_INFO(predictionTokens.size() == lastDraftTokens.size(),
        "Batch size of predictions (%d) does not match the batch size of last draft tokens (%d)",
        static_cast<SizeType32>(predictionTokens.size()), static_cast<SizeType32>(lastDraftTokens.size()));
    TLLM_CHECK_WITH_INFO(predictionTokens.size() == nextDraftTokens.size(),
        "Batch size of predictions (%d) does not match the batch size of next draft tokens (%d)",
        static_cast<SizeType32>(predictionTokens.size()), static_cast<SizeType32>(nextDraftTokens.size()));

    mBestPathLengths.resize(predictionTokens.size());
    mBestPathIndices.resize(predictionTokens.size());

    // Needed for the unit test of ExplicitDraftTokensDummyNetwork only.
    if (mOutputIds.size() == 0)
    {
        mOutputIds.resize(lastDraftTokens.size());
    }

    for (SizeType32 bi = 0; bi < predictionTokens.size(); ++bi)
    {
        SizeType32 maxMatchLen = -1;
        SizeType32 maxMatchIdx = -1;
        // Find the path with the largest prefix shared with the predicted tokens.
        for (SizeType32 pi = 0; pi < lastDraftTokens[bi].size(); ++pi)
        {
            TLLM_CHECK_WITH_INFO(predictionTokens[bi][0] == lastDraftTokens[bi][pi][0],
                "First token of prediction and draft token must match");
            auto const matchLen = longestCommonPrefixLength(lastDraftTokens[bi][pi], predictionTokens[bi]);
            if (matchLen > maxMatchLen)
            {
                maxMatchLen = matchLen;
                maxMatchIdx = pi;
            }
        }
        mBestPathLengths[bi] = maxMatchLen;
        mBestPathIndices[bi] = maxMatchIdx;

        // Update output ids. The first draft token is already counted in the outputs.
        mOutputIds[bi].insert(mOutputIds[bi].end(), lastDraftTokens[bi][maxMatchIdx].begin() + 1,
            lastDraftTokens[bi][maxMatchIdx].begin() + maxMatchLen);
        mOutputIds[bi].push_back(nextDraftTokens[bi][0][0]);
    }
}

void ExplicitDraftTokensDummyNetwork::forward(SamplingParams const& params,
    std::vector<std::string> const& promptsLetters, std::vector<std::string> const& predictionLetters,
    DraftLettersVec const& nextDraftLetters, DraftLettersVec const& lastDraftLetters)
{
    mSamplingParams = params;

    TLLM_CHECK(params.getBatchSize() == promptsLetters.size());
    TLLM_CHECK(params.getBatchSize() == predictionLetters.size());
    TLLM_CHECK(params.getBatchSize() == nextDraftLetters.size());
    TLLM_CHECK(params.getBatchSize() == lastDraftLetters.size());

    // Tokenize.
    mNextDraftTokens = draftLettersToTokens(nextDraftLetters);
    mLastDraftTokens = draftLettersToTokens(lastDraftLetters);

    std::vector<TokensVec> predictionTokens;
    for (SizeType32 bi = 0; bi < predictionLetters.size(); ++bi)
    {
        predictionTokens.push_back(tokenize(predictionLetters[bi]));
        mPrompts.push_back(tokenize(promptsLetters[bi]));
    }

    std::vector<SizeType32> basePosIds;
    for (auto const& prompt : mPrompts)
    {
        basePosIds.push_back(prompt.size());
    }

    mOutputIds = mPrompts;

    // Make compressed tensors and pos ids for the current and next tokens.
    compressTokens(mNextCompressedVector, mNextPackedPosIds, mNextDraftTokenIndices, mNextGenerationLengths,
        mNextDraftTokens, basePosIds);
    compressTokens(mLastCompressedVector, mLastPackedPosIds, mLastDraftTokenIndices, mLastGenerationLengths,
        mLastDraftTokens, basePosIds);

    mMaxNextGenLength = *std::max_element(mNextGenerationLengths.begin(), mNextGenerationLengths.end());

    acceptTokens(predictionTokens, mLastDraftTokens, mNextDraftTokens);

    createNextMasks(mNextDraftTokenIndices, mNextDraftTokens, mMaxNextGenLength);
}

TEST(ExplicitDraftTokensDummyNetworkTest, tokenizeTest)
{
    ExplicitDraftTokensDummyNetwork network;
    {
        auto tokens = network.tokenize("hello world");
        EXPECT_EQ(tokens, std::vector<TokenIdType>({104, 101, 108, 108, 111, 32, 119, 111, 114, 108, 100}));
    }
    {
        DraftLettersVec lettersVec = {{"hello world", "hello"}, {"world"}};
        auto draftTokens = network.draftLettersToTokens(lettersVec);
        ASSERT_EQ(draftTokens.size(), 2);
        ASSERT_EQ(draftTokens[0].size(), 2);
        ASSERT_EQ(draftTokens[1].size(), 1);
        EXPECT_EQ(draftTokens[0][0], std::vector<TokenIdType>({104, 101, 108, 108, 111, 32, 119, 111, 114, 108, 100}));
        EXPECT_EQ(draftTokens[0][1], std::vector<TokenIdType>({104, 101, 108, 108, 111}));
        EXPECT_EQ(draftTokens[1][0], std::vector<TokenIdType>({119, 111, 114, 108, 100}));
    }
}

TEST(ExplicitDraftTokensDummyNetworkTest, detokenizeTest)
{
    ExplicitDraftTokensDummyNetwork network;
    {
        auto letters
            = network.detokenize(std::vector<TokenIdType>({104, 101, 108, 108, 111, 32, 119, 111, 114, 108, 100}));
        EXPECT_EQ(letters, "hello world");
    }
}
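// Illustrative addition (not part of the original suite): tokenize and detokenize are
// inverses for ASCII strings, since each character maps to its integer code and back.
TEST(ExplicitDraftTokensDummyNetworkTest, tokenizeDetokenizeRoundTripTest)
{
    ExplicitDraftTokensDummyNetwork network;
    std::string const text = "speculative decoding";
    EXPECT_EQ(network.detokenize(network.tokenize(text)), text);
}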
TEST(ExplicitDraftTokensDummyNetworkTest, longestCommonPrefixLengthTest)
{
    ExplicitDraftTokensDummyNetwork network;
    EXPECT_EQ(network.longestCommonPrefixLength({1, 2, 3}, {1, 2}), 2);
    EXPECT_EQ(network.longestCommonPrefixLength({1, 2, 3}, {1, 2, 3}), 3);
    EXPECT_EQ(network.longestCommonPrefixLength({1, 2, 3}, {1, 5, 6}), 1);
    EXPECT_EQ(network.longestCommonPrefixLength({1, 2, 3}, {2, 5, 6}), 0);
    EXPECT_EQ(network.longestCommonPrefixLength({1, 2, 3}, {}), 0);
}

TEST(ExplicitDraftTokensDummyNetworkTest, computeCompressedVectorAndIndicesTest)
{
    ExplicitDraftTokensDummyNetwork network;
    {
        std::vector<TokenIdType> compressedVector;
        std::vector<SizeType32> packedPosIds;
        DraftTokensIndices indices;
        SizeType32 basePosId{0};
        std::vector<TokensVec> tokens = {{0, 1, 2, 3}};
        auto const totalGen
            = network.computeCompressedVectorAndIndices(compressedVector, packedPosIds, indices, tokens, basePosId);
        EXPECT_EQ(totalGen, 4);
        EXPECT_EQ(compressedVector, std::vector<TokenIdType>({0, 1, 2, 3}));
        EXPECT_EQ(packedPosIds, std::vector<SizeType32>({0, 1, 2, 3}));
        ASSERT_EQ(indices.size(), 1);
        ASSERT_EQ(indices[0].size(), 1);
        EXPECT_EQ(indices[0][0], std::vector<SizeType32>({0, 1, 2, 3}));
    }
    {
        std::vector<TokenIdType> compressedVector;
        std::vector<SizeType32> packedPosIds;
        DraftTokensIndices indices;
        SizeType32 basePosId{0};
        std::vector<TokensVec> tokens = {{0, 1, 2, 3}, {0, 2, 3, 4}};
        auto const totalGen
            = network.computeCompressedVectorAndIndices(compressedVector, packedPosIds, indices, tokens, basePosId);
        EXPECT_EQ(totalGen, 7);
        EXPECT_EQ(compressedVector, std::vector<TokenIdType>({0, 1, 2, 3, 2, 3, 4}));
        EXPECT_EQ(packedPosIds, std::vector<SizeType32>({0, 1, 2, 3, 1, 2, 3}));
        ASSERT_EQ(indices.size(), 1);
        ASSERT_EQ(indices[0].size(), 2);
        EXPECT_EQ(indices[0][0], std::vector<SizeType32>({0, 1, 2, 3}));
        EXPECT_EQ(indices[0][1], std::vector<SizeType32>({0, 4, 5, 6}));
    }
    {
        std::vector<TokenIdType> compressedVector;
        std::vector<SizeType32> packedPosIds;
        DraftTokensIndices indices;
        SizeType32 basePosId{0};
        std::vector<TokensVec> tokens = {{0, 1, 2, 3}, {0, 1, 6, 2}, {0, 5, 6, 2}};
        auto const totalGen
            = network.computeCompressedVectorAndIndices(compressedVector, packedPosIds, indices, tokens, basePosId);
        EXPECT_EQ(totalGen, 9);
        EXPECT_EQ(compressedVector, std::vector<TokenIdType>({0, 1, 2, 3, 6, 2, 5, 6, 2}));
        EXPECT_EQ(packedPosIds, std::vector<SizeType32>({0, 1, 2, 3, 2, 3, 1, 2, 3}));
        ASSERT_EQ(indices.size(), 1);
        ASSERT_EQ(indices[0].size(), 3);
        EXPECT_EQ(indices[0][0], std::vector<SizeType32>({0, 1, 2, 3}));
        EXPECT_EQ(indices[0][1], std::vector<SizeType32>({0, 1, 4, 5}));
        EXPECT_EQ(indices[0][2], std::vector<SizeType32>({0, 6, 7, 8}));
    }
    {
        std::vector<TokenIdType> compressedVector;
        std::vector<SizeType32> packedPosIds;
        DraftTokensIndices indices;
        SizeType32 basePosId{10};
        std::vector<TokensVec> tokens = {{0, 1, 2, 3}, {0, 1, 6, 2}, {0, 5, 6, 2}};
        auto const totalGen
            = network.computeCompressedVectorAndIndices(compressedVector, packedPosIds, indices, tokens, basePosId);
        EXPECT_EQ(totalGen, 9);
        EXPECT_EQ(compressedVector, std::vector<TokenIdType>({0, 1, 2, 3, 6, 2, 5, 6, 2}));
        EXPECT_EQ(packedPosIds, std::vector<SizeType32>({10, 11, 12, 13, 12, 13, 11, 12, 13}));
        ASSERT_EQ(indices.size(), 1);
        ASSERT_EQ(indices[0].size(), 3);
        EXPECT_EQ(indices[0][0], std::vector<SizeType32>({0, 1, 2, 3}));
        EXPECT_EQ(indices[0][1], std::vector<SizeType32>({0, 1, 4, 5}));
        EXPECT_EQ(indices[0][2], std::vector<SizeType32>({0, 6, 7, 8}));
    }
}
TEST(ExplicitDraftTokensDummyNetworkTest, compressTokensTest)
{
    {
        ExplicitDraftTokensDummyNetwork network;
        std::vector<TokenIdType> compressedVector;
        std::vector<SizeType32> packedPosIds;
        DraftTokensIndices indices;
        std::vector<SizeType32> genLengths;
        SamplingParams params;
        params.setBatchSize(1);
        params.setMaxNumPaths(1);
        params.setMaxDraftPathLen(6);
        network.setSamplingParams(params);
        DraftTokensVec tokens = {{{0, 1, 2, 3}}};
        std::vector<SizeType32> basePosIds = {0};
        network.compressTokens(compressedVector, packedPosIds, indices, genLengths, tokens, basePosIds);
        EXPECT_EQ(compressedVector, std::vector<TokenIdType>({0, 1, 2, 3, -1, -1, -1}));
        EXPECT_EQ(packedPosIds, std::vector<SizeType32>({0, 1, 2, 3, 0, 0, 0}));
        ASSERT_EQ(indices.size(), 1);
        ASSERT_EQ(indices[0].size(), 1);
        EXPECT_EQ(indices[0][0], std::vector<SizeType32>({0, 1, 2, 3}));
        ASSERT_EQ(genLengths.size(), 1);
        EXPECT_EQ(genLengths[0], 4);

        network.createNextMasks(indices, tokens, 4);
        auto masks = network.getNextMasks();
        ASSERT_EQ(masks.size(), 1);
        ASSERT_EQ(masks[0].size(), 4);
        ASSERT_EQ(masks[0][0].size(), 4);
        EXPECT_EQ(masks[0][0], std::vector<bool>({true, false, false, false}));
        EXPECT_EQ(masks[0][1], std::vector<bool>({true, true, false, false}));
        EXPECT_EQ(masks[0][2], std::vector<bool>({true, true, true, false}));
        EXPECT_EQ(masks[0][3], std::vector<bool>({true, true, true, true}));
    }
    {
        ExplicitDraftTokensDummyNetwork network;
        std::vector<TokenIdType> compressedVector;
        std::vector<SizeType32> packedPosIds;
        DraftTokensIndices indices;
        std::vector<SizeType32> genLengths;
        SamplingParams params;
        params.setBatchSize(2);
        params.setMaxNumPaths(1);
        params.setMaxDraftPathLen(6);
        network.setSamplingParams(params);
        std::vector<SizeType32> basePosIds = {10, 10};
        DraftTokensVec tokens = {{{0, 1, 2, 3}}, {{0, 1, 2, 3}}};
        network.compressTokens(compressedVector, packedPosIds, indices, genLengths, tokens, basePosIds);
        EXPECT_EQ(compressedVector, std::vector<TokenIdType>({0, 1, 2, 3, 0, 1, 2, 3, -1, -1, -1, -1, -1, -1}));
        EXPECT_EQ(packedPosIds, std::vector<SizeType32>({10, 11, 12, 13, 10, 11, 12, 13, 0, 0, 0, 0, 0, 0}));
        ASSERT_EQ(indices.size(), 2);
        ASSERT_EQ(indices[0].size(), 1);
        ASSERT_EQ(indices[1].size(), 1);
        EXPECT_EQ(indices[0][0], std::vector<SizeType32>({0, 1, 2, 3}));
        EXPECT_EQ(indices[1][0], std::vector<SizeType32>({0, 1, 2, 3}));
        ASSERT_EQ(genLengths.size(), 2);
        EXPECT_EQ(genLengths[0], 4);
        EXPECT_EQ(genLengths[1], 4);

        network.createNextMasks(indices, tokens, 4);
        auto masks = network.getNextMasks();
        ASSERT_EQ(masks.size(), 2);
        ASSERT_EQ(masks[0].size(), 4);
        ASSERT_EQ(masks[1].size(), 4);
        ASSERT_EQ(masks[0][0].size(), 4);
        ASSERT_EQ(masks[1][0].size(), 4);
        EXPECT_EQ(masks[0][0], std::vector<bool>({true, false, false, false}));
        EXPECT_EQ(masks[0][1], std::vector<bool>({true, true, false, false}));
        EXPECT_EQ(masks[0][2], std::vector<bool>({true, true, true, false}));
        EXPECT_EQ(masks[0][3], std::vector<bool>({true, true, true, true}));
        EXPECT_EQ(masks[1][0], std::vector<bool>({true, false, false, false}));
        EXPECT_EQ(masks[1][1], std::vector<bool>({true, true, false, false}));
        EXPECT_EQ(masks[1][2], std::vector<bool>({true, true, true, false}));
        EXPECT_EQ(masks[1][3], std::vector<bool>({true, true, true, true}));
    }
    {
        ExplicitDraftTokensDummyNetwork network;
        std::vector<TokenIdType> compressedVector;
        std::vector<SizeType32> packedPosIds;
        DraftTokensIndices indices;
        std::vector<SizeType32> genLengths;
        SamplingParams params;
        params.setBatchSize(2);
        params.setMaxNumPaths(3);
        params.setMaxDraftPathLen(4);
        network.setSamplingParams(params);
        std::vector<SizeType32> basePosIds = {10, 0};
        DraftTokensVec tokens = {{{0, 1, 2, 3}, {0, 1, 6, 2}, {0, 5, 6, 2}}, {{0, 1, 2, 3}, {0, 1, 2, 4}}};
        network.compressTokens(compressedVector, packedPosIds, indices, genLengths, tokens, basePosIds);
        EXPECT_EQ(compressedVector,
            std::vector<TokenIdType>(
                {0, 1, 2, 3, 6, 2, 5, 6, 2, 0, 1, 2, 3, 4, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1}));
        EXPECT_EQ(packedPosIds,
            std::vector<SizeType32>(
                {10, 11, 12, 13, 12, 13, 11, 12, 13, 0, 1, 2, 3, 3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}));
        ASSERT_EQ(indices.size(), 2);
        ASSERT_EQ(indices[0].size(), 3);
        ASSERT_EQ(indices[1].size(), 2);
        EXPECT_EQ(indices[0][0], std::vector<SizeType32>({0, 1, 2, 3}));
        EXPECT_EQ(indices[0][1], std::vector<SizeType32>({0, 1, 4, 5}));
        EXPECT_EQ(indices[0][2], std::vector<SizeType32>({0, 6, 7, 8}));
        EXPECT_EQ(indices[1][0], std::vector<SizeType32>({0, 1, 2, 3}));
        EXPECT_EQ(indices[1][1], std::vector<SizeType32>({0, 1, 2, 4}));
        ASSERT_EQ(genLengths.size(), 2);
        EXPECT_EQ(genLengths[0], 9);
        EXPECT_EQ(genLengths[1], 5);

        network.createNextMasks(indices, tokens, 9);
        auto masks = network.getNextMasks();
        ASSERT_EQ(masks.size(), 2);
        ASSERT_EQ(masks[0].size(), 9);
        ASSERT_EQ(masks[1].size(), 9);
        ASSERT_EQ(masks[0][0].size(), 9);
        ASSERT_EQ(masks[1][0].size(), 9);
        EXPECT_EQ(masks[0][0], std::vector<bool>({true, false, false, false, false, false, false, false, false}));
        EXPECT_EQ(masks[0][1], std::vector<bool>({true, true, false, false, false, false, false, false, false}));
        EXPECT_EQ(masks[0][2], std::vector<bool>({true, true, true, false, false, false, false, false, false}));
        EXPECT_EQ(masks[0][3], std::vector<bool>({true, true, true, true, false, false, false, false, false}));
        EXPECT_EQ(masks[0][4], std::vector<bool>({true, true, false, false, true, false, false, false, false}));
        EXPECT_EQ(masks[0][5], std::vector<bool>({true, true, false, false, true, true, false, false, false}));
        EXPECT_EQ(masks[0][6], std::vector<bool>({true, false, false, false, false, false, true, false, false}));
        EXPECT_EQ(masks[0][7], std::vector<bool>({true, false, false, false, false, false, true, true, false}));
        EXPECT_EQ(masks[0][8], std::vector<bool>({true, false, false, false, false, false, true, true, true}));
        EXPECT_EQ(masks[1][0], std::vector<bool>({true, false, false, false, false, false, false, false, false}));
        EXPECT_EQ(masks[1][1], std::vector<bool>({true, true, false, false, false, false, false, false, false}));
        EXPECT_EQ(masks[1][2], std::vector<bool>({true, true, true, false, false, false, false, false, false}));
        EXPECT_EQ(masks[1][3], std::vector<bool>({true, true, true, true, false, false, false, false, false}));
        EXPECT_EQ(masks[1][4], std::vector<bool>({true, true, true, false, true, false, false, false, false}));
        EXPECT_EQ(masks[1][5], std::vector<bool>({false, false, false, false, false, true, false, false, false}));
        EXPECT_EQ(masks[1][6], std::vector<bool>({false, false, false, false, false, false, true, false, false}));
        EXPECT_EQ(masks[1][7], std::vector<bool>({false, false, false, false, false, false, false, true, false}));
        EXPECT_EQ(masks[1][8], std::vector<bool>({false, false, false, false, false, false, false, false, true}));
    }
}
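// Illustrative addition (not part of the original suite), assuming the default pad id
// of -1 that the cases above rely on: with the same sampling params as the first case,
// a path shorter than the padded size shows how compressTokens pads the flattened
// tokens with the pad id and the packed position ids with zeros.
TEST(ExplicitDraftTokensDummyNetworkTest, compressTokensPaddingTest)
{
    ExplicitDraftTokensDummyNetwork network;
    std::vector<TokenIdType> compressedVector;
    std::vector<SizeType32> packedPosIds;
    DraftTokensIndices indices;
    std::vector<SizeType32> genLengths;
    SamplingParams params;
    params.setBatchSize(1);
    params.setMaxNumPaths(1);
    params.setMaxDraftPathLen(6);
    network.setSamplingParams(params);
    // One path of two tokens after a prompt of length 5.
    DraftTokensVec tokens = {{{7, 8}}};
    std::vector<SizeType32> basePosIds = {5};
    network.compressTokens(compressedVector, packedPosIds, indices, genLengths, tokens, basePosIds);
    // Two real tokens, padded up to the same 7 slots as in the first case above.
    EXPECT_EQ(compressedVector, std::vector<TokenIdType>({7, 8, -1, -1, -1, -1, -1}));
    EXPECT_EQ(packedPosIds, std::vector<SizeType32>({5, 6, 0, 0, 0, 0, 0}));
    ASSERT_EQ(indices.size(), 1);
    EXPECT_EQ(indices[0][0], std::vector<SizeType32>({0, 1}));
    ASSERT_EQ(genLengths.size(), 1);
    EXPECT_EQ(genLengths[0], 2);
}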
TEST(ExplicitDraftTokensDummyNetworkTest, acceptTokensTest)
{
    {
        ExplicitDraftTokensDummyNetwork network;
        std::vector<TokensVec> predictionTokens = {network.tokenize("how things")};
        DraftLettersVec lastDraftLetters = {{"how do ", "how are", "however", "hello w"}};
        DraftLettersVec nextDraftLetters = {{"things ", "that is", "to crea", "touchab"}};
        auto lastDraftTokens = network.draftLettersToTokens(lastDraftLetters);
        auto nextDraftTokens = network.draftLettersToTokens(nextDraftLetters);
        network.acceptTokens(predictionTokens, lastDraftTokens, nextDraftTokens);
        auto bestPathLengths = network.getBestPathLengths();
        auto bestPathIndices = network.getBestPathIndices();
        auto outputIds = network.getOutputIds();
        ASSERT_EQ(bestPathLengths.size(), 1);
        ASSERT_EQ(bestPathIndices.size(), 1);
        ASSERT_EQ(outputIds.size(), 1);
        EXPECT_EQ(bestPathLengths[0], 4);
        EXPECT_EQ(bestPathIndices[0], 0);
        EXPECT_EQ(network.detokenize(outputIds[0]), "ow t");
    }
    {
        ExplicitDraftTokensDummyNetwork network;
        std::vector<TokensVec> predictionTokens = {network.tokenize("however you")};
        DraftLettersVec lastDraftLetters = {{"how do ", "how tho", "however", "hello w"}};
        DraftLettersVec nextDraftLetters = {{" increme", " introdu", " i = 0; ", " importa"}};
        auto lastDraftTokens = network.draftLettersToTokens(lastDraftLetters);
        auto nextDraftTokens = network.draftLettersToTokens(nextDraftLetters);
        network.acceptTokens(predictionTokens, lastDraftTokens, nextDraftTokens);
        auto bestPathLengths = network.getBestPathLengths();
        auto bestPathIndices = network.getBestPathIndices();
        auto outputIds = network.getOutputIds();
        ASSERT_EQ(bestPathLengths.size(), 1);
        ASSERT_EQ(bestPathIndices.size(), 1);
        ASSERT_EQ(outputIds.size(), 1);
        EXPECT_EQ(bestPathLengths[0], 7);
        EXPECT_EQ(bestPathIndices[0], 2);
        EXPECT_EQ(network.detokenize(outputIds[0]), "owever ");
    }
    {
        ExplicitDraftTokensDummyNetwork network;
        std::vector<TokensVec> predictionTokens = {network.tokenize("how things")};
        DraftLettersVec lastDraftLetters = {{"heruist", "habit i", "handove", "hammer "}};
        DraftLettersVec nextDraftLetters = {{"oatmeal", "ocean b", "occupat", "oblivio"}};
        auto lastDraftTokens = network.draftLettersToTokens(lastDraftLetters);
        auto nextDraftTokens = network.draftLettersToTokens(nextDraftLetters);
        network.acceptTokens(predictionTokens, lastDraftTokens, nextDraftTokens);
        auto bestPathLengths = network.getBestPathLengths();
        auto bestPathIndices = network.getBestPathIndices();
        auto outputIds = network.getOutputIds();
        ASSERT_EQ(bestPathLengths.size(), 1);
        ASSERT_EQ(bestPathIndices.size(), 1);
        ASSERT_EQ(outputIds.size(), 1);
        EXPECT_EQ(bestPathLengths[0], 1);
        EXPECT_EQ(bestPathIndices[0], 0);
        EXPECT_EQ(network.detokenize(outputIds[0]), "o");
    }
}

/////////////////////////////////////////////////////////////////////////////////////////////////////////////////

template <typename T>
void ExplicitDraftTokensLayerTest<T>::SetUp()
{
    mStream = std::make_shared<CudaStream>();
    mBufferManager = std::make_shared<BufferManager>(mStream);
    mAllocator = std::make_shared<CudaAllocator>(*mBufferManager);
}
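// Note: slot-indexed buffers below are sized with getMaxBatchSize() rather than
// getBatchSize() because setup() places request bi at batch slot 2 * bi, leaving
// gaps between active slots.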
template <typename T>
void ExplicitDraftTokensLayerTest<T>::allocateBuffers()
{
    auto const dataType = TRTDataType<T>::value;

    auto speculativeDecodingModule = std::make_shared<SpeculativeDecodingModule>(mSamplingParams.getMaxDraftPathLen(),
        mSamplingParams.getMaxDecodingDraftTokens(), mSamplingParams.getMaxNumPaths());
    auto const decodingDomain = tensorrt_llm::layers::DecoderDomain(mSamplingParams.getMaxBatchSize(), 1,
        mSamplingParams.getVocabSize(), mSamplingParams.getVocabSize(), speculativeDecodingModule);

    mExplicitDraftTokensLayer
        = std::make_shared<ExplicitDraftTokensLayer<T>>(decodingDomain, mStream->get(), mAllocator);

    // outputs
    mOutputIds = BufferManager::pinned(
        ITensor::makeShape({mSamplingParams.getMaxBatchSize(), mSamplingParams.getMaxSeqLen()}),
        nvinfer1::DataType::kINT32);
    mSeqLengths
        = BufferManager::pinned(ITensor::makeShape({mSamplingParams.getMaxBatchSize()}), nvinfer1::DataType::kINT32);
    mAcceptedLengths
        = BufferManager::pinned(ITensor::makeShape({mSamplingParams.getMaxBatchSize()}), nvinfer1::DataType::kINT32);
    mNextDraftLengths
        = BufferManager::pinned(ITensor::makeShape({mSamplingParams.getMaxBatchSize()}), nvinfer1::DataType::kINT32);
    mPrevDraftLengths
        = BufferManager::pinned(ITensor::makeShape({mSamplingParams.getMaxBatchSize()}), nvinfer1::DataType::kINT32);
    mAcceptedLengthCumSum = BufferManager::pinned(
        ITensor::makeShape({mSamplingParams.getMaxBatchSize() + 1}), nvinfer1::DataType::kINT32);
    mOutputNextDraftTokens = BufferManager::pinned(
        ITensor::makeShape({mSamplingParams.getMaxBatchSize(), mSamplingParams.getMaxDecodingDraftTokens()}),
        nvinfer1::DataType::kINT32);
    mOutputPositionIdsBase
        = BufferManager::pinned(ITensor::makeShape({mSamplingParams.getMaxBatchSize()}), nvinfer1::DataType::kINT32);
    mRandomDataSample = BufferManager::pinned(ITensor::makeShape({mSamplingParams.getMaxBatchSize()}), dataType);
    mRandomDataValidation = BufferManager::pinned(
        ITensor::makeShape(
            {mSamplingParams.getMaxBatchSize(), mSamplingParams.getMaxNumPaths(), mSamplingParams.getMaxDraftPathLen()}),
        dataType);
    mPackedMasks = BufferManager::pinned(
        ITensor::makeShape({mSamplingParams.getMaxBatchSize(), mSamplingParams.getMaxDecodingTokens(),
            static_cast<SizeType32>(divUp(mSamplingParams.getMaxDecodingTokens(), 32))}),
        nvinfer1::DataType::kINT32);
    mNextPosIds = BufferManager::pinned(
        ITensor::makeShape({mSamplingParams.getMaxBatchSize(), mSamplingParams.getMaxDecodingTokens()}),
        nvinfer1::DataType::kINT32);
    mOutputUnpackedNextDraftTokens = BufferManager::pinned(
        ITensor::makeShape(
            {mSamplingParams.getMaxBatchSize(), mSamplingParams.getMaxNumPaths(), mSamplingParams.getMaxPathLen()}),
        nvinfer1::DataType::kINT32);
    mOutputUnpackedNextDraftIndices = BufferManager::pinned(
        ITensor::makeShape(
            {mSamplingParams.getMaxBatchSize(), mSamplingParams.getMaxNumPaths(), mSamplingParams.getMaxPathLen()}),
        nvinfer1::DataType::kINT32);
    mOutputDraftProbs = BufferManager::pinned(
        ITensor::makeShape({mSamplingParams.getMaxBatchSize(), mSamplingParams.getMaxNumPaths(),
            mSamplingParams.getMaxDraftPathLen(), mSamplingParams.getVocabSize()}),
        dataType);
    mOutputTemperatures = BufferManager::pinned(ITensor::makeShape({mSamplingParams.getMaxBatchSize()}), dataType);
    mOutputGenerationLengths
        = BufferManager::pinned(ITensor::makeShape({mSamplingParams.getMaxBatchSize()}), nvinfer1::DataType::kINT32);
    mMaxGenLengthHost = BufferManager::pinned(ITensor::makeShape({1}), nvinfer1::DataType::kINT32);

    // inputs
    mBatchSlots
        = BufferManager::pinned(ITensor::makeShape({mSamplingParams.getBatchSize()}), nvinfer1::DataType::kINT32);
    mTokensPerStep
        = BufferManager::pinned(ITensor::makeShape({mSamplingParams.getMaxBatchSize()}), nvinfer1::DataType::kINT32);
    mPathsOffsets = BufferManager::pinned(
        ITensor::makeShape({mSamplingParams.getMaxBatchSize() * mSamplingParams.getMaxDraftPathLen()}),
        nvinfer1::DataType::kINT32);
    mMasks = BufferManager::pinned(
        ITensor::makeShape({mSamplingParams.getMaxBatchSize(), mSamplingParams.getMaxDecodingTokens(),
            mSamplingParams.getMaxDecodingTokens()}),
        nvinfer1::DataType::kBOOL);
    mInputNextDraftTokens = BufferManager::pinned(
        ITensor::makeShape(
            {mSamplingParams.getMaxBatchSize(), mSamplingParams.getMaxNumPaths(), mSamplingParams.getMaxPathLen()}),
        nvinfer1::DataType::kINT32);
    mLastDraftTokens = BufferManager::pinned(
        ITensor::makeShape(
            {mSamplingParams.getBatchSize(), mSamplingParams.getMaxNumPaths(), mSamplingParams.getMaxPathLen()}),
        nvinfer1::DataType::kINT32);
    mPackedPosIds = BufferManager::pinned(
        ITensor::makeShape({mSamplingParams.getMaxBatchSize(), mSamplingParams.getMaxDecodingTokens()}),
        nvinfer1::DataType::kINT32);
    mBestPathLengths
        = BufferManager::pinned(ITensor::makeShape({mSamplingParams.getMaxBatchSize()}), nvinfer1::DataType::kINT32);
    mBestPathIndices
        = BufferManager::pinned(ITensor::makeShape({mSamplingParams.getMaxBatchSize()}), nvinfer1::DataType::kINT32);
    mSpecDecodingGenerationLengths
        = BufferManager::pinned(ITensor::makeShape({mSamplingParams.getMaxBatchSize()}), nvinfer1::DataType::kINT32);
    mNextFlatTokens = BufferManager::pinned(
        ITensor::makeShape({mSamplingParams.getMaxBatchSize() * mSamplingParams.getMaxDecodingTokens()}),
        nvinfer1::DataType::kINT32);
    mInputPositionIdsBase
        = BufferManager::pinned(ITensor::makeShape({mSamplingParams.getMaxBatchSize()}), nvinfer1::DataType::kINT32);
    mNextDraftIndices = BufferManager::pinned(
        ITensor::makeShape(
            {mSamplingParams.getMaxBatchSize(), mSamplingParams.getMaxNumPaths(), mSamplingParams.getMaxPathLen()}),
        nvinfer1::DataType::kINT32);
    mLastDraftIndices = BufferManager::pinned(
        ITensor::makeShape(
            {mSamplingParams.getMaxBatchSize(), mSamplingParams.getMaxNumPaths(), mSamplingParams.getMaxPathLen()}),
        nvinfer1::DataType::kINT32);
    mNextDraftProbs = BufferManager::pinned(
        ITensor::makeShape({mSamplingParams.getMaxBatchSize(), mSamplingParams.getMaxNumPaths(),
            mSamplingParams.getMaxDraftPathLen(), mSamplingParams.getVocabSize()}),
        dataType);
    mEndIds
        = BufferManager::pinned(ITensor::makeShape({mSamplingParams.getMaxBatchSize()}), nvinfer1::DataType::kINT32);
    mMaxGenLengthDevice = BufferManager::pinned(ITensor::makeShape({1}), nvinfer1::DataType::kINT32);

    // Packed inputs
    mMaxGenerationLength = BufferManager::pinned(ITensor::makeShape({1}), nvinfer1::DataType::kINT32);
    mCumSumGenerationLengths
        = BufferManager::pinned(ITensor::makeShape({mSamplingParams.getBatchSize()}), nvinfer1::DataType::kINT32);

    // Packed outputs
    mPackedPositionIdsBase
        = BufferManager::pinned(ITensor::makeShape({mSamplingParams.getBatchSize()}), nvinfer1::DataType::kINT32);
    mPackedGenerationLengths
        = BufferManager::pinned(ITensor::makeShape({mSamplingParams.getBatchSize()}), nvinfer1::DataType::kINT32);
    mPackedRandomDataSample = BufferManager::pinned(ITensor::makeShape({mSamplingParams.getBatchSize()}), dataType);
    mPackedRandomDataVerification = BufferManager::pinned(
        ITensor::makeShape(
            {mSamplingParams.getBatchSize(), mSamplingParams.getMaxNumPaths(), mSamplingParams.getMaxDraftPathLen()}),
        dataType);
    mPackedNextDraftTokens = BufferManager::pinned(
        ITensor::makeShape(
            {mSamplingParams.getBatchSize(), mSamplingParams.getMaxNumPaths(), mSamplingParams.getMaxPathLen()}),
        nvinfer1::DataType::kINT32);
    mPackedNextDraftIndices = BufferManager::pinned(
        ITensor::makeShape(
            {mSamplingParams.getBatchSize(), mSamplingParams.getMaxNumPaths(), mSamplingParams.getMaxPathLen()}),
        nvinfer1::DataType::kINT32);
    mPackedPackedMasks = BufferManager::pinned(
        ITensor::makeShape({mSamplingParams.getBatchSize(), mSamplingParams.getMaxDecodingTokens(),
            static_cast<SizeType32>(divUp(mSamplingParams.getMaxDecodingTokens(), 32))}),
        nvinfer1::DataType::kINT32);
    mPackedPositionOffsets = BufferManager::pinned(
        ITensor::makeShape({mSamplingParams.getBatchSize(), mSamplingParams.getMaxDecodingTokens()}),
        nvinfer1::DataType::kINT32);
    mPackedPackedPosIds = BufferManager::pinned(
        ITensor::makeShape({mSamplingParams.getBatchSize(), mSamplingParams.getMaxDecodingTokens()}),
        nvinfer1::DataType::kINT32);
    mPackedDraftProbs = BufferManager::pinned(
        ITensor::makeShape({mSamplingParams.getBatchSize(), mSamplingParams.getMaxNumPaths(),
            mSamplingParams.getMaxDraftPathLen(), mSamplingParams.getVocabSize()}),
        dataType);
    mPackedTemperatures = BufferManager::pinned(ITensor::makeShape({mSamplingParams.getBatchSize()}), dataType);
}
template <typename T>
void ExplicitDraftTokensLayerTest<T>::setup()
{
    // outputs
    trk::invokeFill(*mOutputIds, TokenIdType{-1}, *mStream);
    trk::invokeFill(*mSeqLengths, SizeType32{0}, *mStream);
    trk::invokeFill(*mAcceptedLengths, SizeType32{0}, *mStream);
    trk::invokeFill(*mAcceptedLengthCumSum, SizeType32{-1}, *mStream);
    trk::invokeFill(*mOutputNextDraftTokens, TokenIdType{-1}, *mStream);
    trk::invokeFill(*mOutputPositionIdsBase, SizeType32{0}, *mStream);
    trk::invokeFill(*mRandomDataSample, T{0}, *mStream);
    trk::invokeFill(*mRandomDataValidation, T{0}, *mStream);
    trk::invokeFill(*mPackedMasks, SizeType32{0}, *mStream);
    trk::invokeFill(*mNextPosIds, SizeType32{0}, *mStream);
    trk::invokeFill(*mOutputUnpackedNextDraftTokens, TokenIdType{-1}, *mStream);
    trk::invokeFill(*mOutputUnpackedNextDraftIndices, SizeType32{0}, *mStream);
    trk::invokeFill(*mEndIds, TokenIdType{-1}, *mStream);

    auto inDraftProbs = BufferRange<T>(*mNextDraftProbs);
    std::mt19937 gen(42);
    std::uniform_real_distribution<float> distr(0.0, 1.0);
    std::generate(inDraftProbs.begin(), inDraftProbs.end(), [&gen, &distr]() { return static_cast<T>(distr(gen)); });

    auto batchSlotsPtr = bufferCast<SizeType32>(*mBatchSlots);
    for (SizeType32 bi = 0; bi < mSamplingParams.getBatchSize(); ++bi)
    {
        batchSlotsPtr[bi] = 2 * bi;
    }

    auto setupParams = std::make_shared<ExplicitDraftTokensSetupParams>();
    mRandomSeeds = std::vector<uint64_t>(mSamplingParams.getBatchSize());
    mTemperatures = std::vector<float>(mSamplingParams.getBatchSize());

    std::mt19937 generator(42);
    std::uniform_int_distribution<uint64_t> seedDistr(1, 1000);
    std::uniform_real_distribution<float> temperatureDistr(0.001f, 1.f);
    std::generate(
        mRandomSeeds.begin(), mRandomSeeds.end(), [&generator, &seedDistr]() { return seedDistr(generator); });
    std::generate(mTemperatures.begin(), mTemperatures.end(),
        [&generator, &temperatureDistr]() { return temperatureDistr(generator); });
    setupParams->randomSeed = mRandomSeeds;
    setupParams->temperature = mTemperatures;
    setupParams->randomDataSample = tcc::toTllmTensor(*mRandomDataSample);
    setupParams->temperatures = tcc::toTllmTensor(*mOutputTemperatures);

    mExplicitDraftTokensLayer->setup(mSamplingParams.getBatchSize(), 1, batchSlotsPtr, setupParams);

    mStream->synchronize();

    mBestPathLengths = mBufferManager->copyFrom(mNetwork.getBestPathLengths(),
        ITensor::makeShape({mSamplingParams.getBatchSize()}), runtime::MemoryType::kPINNED);
    mBestPathIndices = mBufferManager->copyFrom(mNetwork.getBestPathIndices(),
        ITensor::makeShape({mSamplingParams.getBatchSize()}), runtime::MemoryType::kPINNED);
    mPackedPosIds = mBufferManager->copyFrom(mNetwork.getNextPackedPosId(),
        ITensor::makeShape({mSamplingParams.getMaxDecodingTokens() * mSamplingParams.getBatchSize()}),
        runtime::MemoryType::kPINNED);

    auto const nextDraftTokens = mNetwork.getNextDraftTokens();
    auto const lastDraftTokens = mNetwork.getLastDraftTokens();
    auto const nextDraftIndices = mNetwork.getNextDraftIndices();
    auto const lastDraftIndices = mNetwork.getLastDraftIndices();

    auto sequenceLength = BufferRange<SizeType32>(*mSeqLengths);
    auto nextDraftTokensRange = BufferRange<TokenIdType>(*mInputNextDraftTokens);
    auto lastDraftTokensRange = BufferRange<TokenIdType>(*mLastDraftTokens);
    auto nextDraftIndicesRange = BufferRange<SizeType32>(*mNextDraftIndices);
    auto lastDraftIndicesRange = BufferRange<SizeType32>(*mLastDraftIndices);
    auto inputPositionIdsBase = BufferRange<SizeType32>(*mInputPositionIdsBase);
    auto outputIds = BufferRange<TokenIdType>(*mOutputIds);
    auto generationLengths = mNetwork.getNextGenerationLengths();
    auto prompts = mNetwork.getPrompts();

    for (SizeType32 bi = 0; bi < nextDraftTokens.size(); ++bi)
    {
        for (SizeType32 pi = 0; pi < nextDraftTokens[bi].size(); ++pi)
        {
            for (SizeType32 ti = 0; ti < nextDraftTokens[bi][pi].size(); ++ti)
            {
                auto idx = tc::flat_index3(
                    bi, pi, ti, mSamplingParams.getMaxNumPaths(), mSamplingParams.getMaxPathLen());
                nextDraftTokensRange[idx] = nextDraftTokens[bi][pi][ti];
                lastDraftTokensRange[idx] = lastDraftTokens[bi][pi][ti];
                nextDraftIndicesRange[idx] = nextDraftIndices[bi][pi][ti];
                lastDraftIndicesRange[idx] = lastDraftIndices[bi][pi][ti];
            }
        }
        bufferCast<SizeType32>(*mSpecDecodingGenerationLengths)[bi] = generationLengths[bi];
        sequenceLength[batchSlotsPtr[bi]] = prompts[bi].size();
        std::copy(prompts[bi].begin(), prompts[bi].end(),
            outputIds.begin() + batchSlotsPtr[bi] * mSamplingParams.getMaxSeqLen());
        inputPositionIdsBase[bi] = prompts[bi].size();
    }

    auto nextFlatTokens = mNetwork.getNextFlatTokens();
    TLLM_LOG_DEBUG("Next flat tokens are \"%s\"", mNetwork.detokenize(nextFlatTokens).c_str());
    auto nextFlatTokensRange = BufferRange<TokenIdType>(*mNextFlatTokens);
    std::copy(nextFlatTokens.begin(), nextFlatTokens.end(), nextFlatTokensRange.begin());

    auto const masks = mNetwork.getNextMasks();
    auto masksRange = BufferRange<bool>(*mMasks);
    auto const maxGenLength = mNetwork.getMaxNextGenerationLength();
    bufferCast<SizeType32>(*mMaxGenerationLength)[0] = maxGenLength;
    for (SizeType32 bi = 0; bi < masks.size(); ++bi)
    {
        TLLM_CHECK(maxGenLength == masks[bi].size());
        for (SizeType32 ri = 0; ri < masks[bi].size(); ++ri)
        {
            TLLM_CHECK(maxGenLength == masks[bi][ri].size());
            for (SizeType32 ci = 0; ci < masks[bi][ri].size(); ++ci)
            {
                masksRange[bi * maxGenLength * maxGenLength + ri * maxGenLength + ci] = masks[bi][ri][ci];
            }
        }
    }
}
template <typename T>
std::shared_ptr<ExplicitDraftTokensInputs> ExplicitDraftTokensLayerTest<T>::createInputTensors()
{
    auto forwardParams
        = std::make_shared<ExplicitDraftTokensInputs>(tcc::toTllmTensor(*mEndIds), mSamplingParams.getBatchSize());

    forwardParams->batchSlots = tcc::toTllmTensor(*mBatchSlots);
    forwardParams->seqSlots = tcc::toTllmTensor(*mBatchSlots);
    forwardParams->masks = tcc::toTllmTensor(*mMasks);
    forwardParams->nextDraftTokens = tcc::toTllmTensor(*mInputNextDraftTokens);
    forwardParams->nextDraftIndices = tcc::toTllmTensor(*mNextDraftIndices);
    forwardParams->lastDraftTokens = tcc::toTllmTensor(*mLastDraftTokens);
    forwardParams->lastDraftIndices = tcc::toTllmTensor(*mLastDraftIndices);
    forwardParams->packedPosIds = tcc::toTllmTensor(*mPackedPosIds);
    forwardParams->bestPathLengths = tcc::toTllmTensor(*mBestPathLengths);
    forwardParams->bestPathIndices = tcc::toTllmTensor(*mBestPathIndices);
    forwardParams->generationLengths = tcc::toTllmTensor(*mSpecDecodingGenerationLengths);
    forwardParams->nextFlatTokens = tcc::toTllmTensor(*mNextFlatTokens);
    forwardParams->positionIdsBase = tcc::toTllmTensor(*mInputPositionIdsBase);
    forwardParams->nextDraftProbs = tcc::toTllmTensor(*mNextDraftProbs);
    forwardParams->maxGenLengthDevice = tcc::toTllmTensor(*mMaxGenLengthDevice);

    return forwardParams;
}
template <typename T>
std::shared_ptr<ExplicitDraftTokensOutputs> ExplicitDraftTokensLayerTest<T>::createOutputTensors()
{
    auto outputParams = std::make_shared<ExplicitDraftTokensOutputs>(tcc::toTllmTensor(*mOutputIds));

    outputParams->sequenceLength = tcc::toTllmTensor(*mSeqLengths);
    outputParams->nextDraftTokens = tcc::toTllmTensor(*mOutputNextDraftTokens);
    outputParams->numNewTokens = tcc::toTllmTensor(*mAcceptedLengths);
    outputParams->nextDraftLengths = tcc::toTllmTensor(*mNextDraftLengths);
    outputParams->prevDraftLengths = tcc::toTllmTensor(*mPrevDraftLengths);
    outputParams->numNewTokensCumSum = tcc::toTllmTensor(*mAcceptedLengthCumSum);
    outputParams->pathsOffsets = tcc::toTllmTensor(*mPathsOffsets);
    outputParams->nextDraftPosIds = tcc::toTllmTensor(*mNextPosIds);
    outputParams->positionIdsBase = tcc::toTllmTensor(*mOutputPositionIdsBase);
    outputParams->randomDataSample = tcc::toTllmTensor(*mRandomDataSample);
    outputParams->randomDataValidation = tcc::toTllmTensor(*mRandomDataValidation);
    outputParams->packedMasks = tcc::toTllmTensor(*mPackedMasks);
    outputParams->unpackedNextDraftTokens = tcc::toTllmTensor(*mOutputUnpackedNextDraftTokens);
    outputParams->unpackedNextDraftIndices = tcc::toTllmTensor(*mOutputUnpackedNextDraftIndices);
    outputParams->nextDraftProbs = tcc::toTllmTensor(*mOutputDraftProbs);
    outputParams->temperatures = tcc::toTllmTensor(*mOutputTemperatures);
    outputParams->generationLengths = tcc::toTllmTensor(*mOutputGenerationLengths);
    outputParams->maxGenLengthHost = tcc::toTllmTensor(*mMaxGenLengthHost);

    return outputParams;
}
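// boolArrayToBitmask (below) packs a row of booleans into 32-bit words, LSB first:
// bit (i % 32) of word (i / 32) is set iff the i-th boolean is true. For example,
// the row {true, true, false, true} packs into a single word with value 0b1011 = 11.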
std::vector<int32_t> boolArrayToBitmask(BufferRange<bool>::iterator boolIterator, size_t pathLen)
{
    std::vector<int32_t> bitmask(divUp(pathLen, 32));
    for (size_t bi = 0; bi < pathLen; ++bi)
    {
        auto const slice = bi / 32;
        if (boolIterator[bi])
        {
            bitmask[slice] |= (1 << (bi % 32));
        }
    }
    return bitmask;
}

template <typename T>
void ExplicitDraftTokensLayerTest<T>::checkLayerResult()
{
    auto const batchSlots = BufferRange<SizeType32>(*mBatchSlots);

    // Check generated random data.
    {
        auto const randomDataSample = BufferRange<T>(*mRandomDataSample);
        auto const randomDataValidation = BufferRange<T>(*mRandomDataValidation);
        for (SizeType32 bi = 0; bi < mSamplingParams.getBatchSize(); ++bi)
        {
            auto const batchSlot = batchSlots[bi];
            // Check that all fields are filled with non-zero data.
            EXPECT_NE(randomDataSample[batchSlot], T{0}) << " bi: " << bi;
            auto const stride = mSamplingParams.getMaxNumPaths() * mSamplingParams.getMaxDraftPathLen();
            EXPECT_FALSE(std::any_of(randomDataValidation.begin() + batchSlot * stride,
                randomDataValidation.begin() + (batchSlot + 1) * stride, [](T val) { return val == T{0}; }))
                << " bi: " << bi;
        }
    }

    // Check masks.
    {
        auto const packedMasks = BufferRange<SizeType32>(*mPackedMasks);
        auto masks = BufferRange<bool>(*mMasks);
        auto generationLengths = mNetwork.getNextGenerationLengths();
        for (SizeType32 bi = 0; bi < mSamplingParams.getBatchSize(); ++bi)
        {
            for (SizeType32 ti = 0; ti < generationLengths[bi]; ++ti)
            {
                auto const batchSlot = batchSlots[bi];
                auto const maskIdx = tc::flat_index3(
                    bi, ti, 0, mNetwork.getMaxNextGenerationLength(), mNetwork.getMaxNextGenerationLength());
                auto const bitmask
                    = boolArrayToBitmask(masks.begin() + maskIdx, mNetwork.getMaxNextGenerationLength());
                for (SizeType32 mi = 0; mi < bitmask.size(); ++mi)
                {
                    auto const packedMaskIdx
                        = tc::flat_index3(batchSlot, ti, mi, mSamplingParams.getMaxDecodingTokens(),
                            static_cast<SizeType32>(divUp(mSamplingParams.getMaxDecodingTokens(), 32)));
                    EXPECT_EQ(bitmask[mi], packedMasks[packedMaskIdx]) << " bi: " << bi << " ti: " << ti;
                }
            }
        }
    }
    // Check accepted tokens.
    auto const outputIds = BufferRange<TokenIdType>(*mOutputIds);
    auto const refOutputIds = mNetwork.getOutputIds();
    auto const promptIds = mNetwork.getPrompts();
    auto const seqLengths = BufferRange<SizeType32>(*mSeqLengths);
    auto const lastDraftTokens = BufferRange<TokenIdType>(*mLastDraftTokens);
    auto const bestPathLengths = BufferRange<SizeType32>(*mBestPathLengths);
    auto const bestPathIndices = BufferRange<SizeType32>(*mBestPathIndices);
    for (SizeType32 bi = 0; bi < mSamplingParams.getBatchSize(); ++bi)
    {
        auto const batchSlot = batchSlots[bi];
        // The updated sequence length is the prompt length plus the newly accepted tokens.
        EXPECT_EQ(seqLengths[batchSlot], promptIds[bi].size() + bestPathLengths[bi]) << " bi: " << bi;
        // Check that the output ids contain the accepted tokens.
        for (SizeType32 ti = 0; ti < promptIds[bi].size() + bestPathLengths[bi]; ++ti)
        {
            EXPECT_EQ(outputIds[batchSlot * mSamplingParams.getMaxSeqLen() + ti], refOutputIds[bi][ti])
                << " bi: " << bi << " ti: " << ti;
        }
        auto outputIter = outputIds.begin() + batchSlot * mSamplingParams.getMaxSeqLen();
        std::vector<TokenIdType> outputVec(outputIter, outputIter + seqLengths[batchSlot]);
        TLLM_LOG_DEBUG("Output ids for request %d are \"%s\"", bi, mNetwork.detokenize(outputVec).c_str());
        TLLM_LOG_DEBUG("Ref output ids for request %d are \"%s\"", bi, mNetwork.detokenize(refOutputIds[bi]).c_str());
    }

    // Check new draft tokens.
    {
        auto const outputNextDraftTokens = BufferRange<TokenIdType>(*mOutputNextDraftTokens);
        auto const generationLengths = BufferRange<SizeType32>(*mSpecDecodingGenerationLengths);
        auto const compressedDraftTokens = mNetwork.getNextFlatTokens();
        TLLM_LOG_DEBUG(
            "Next compressed draft tokens are \"%s\"", mNetwork.detokenize(compressedDraftTokens).c_str());
        SizeType32 compressedIdx = 0;
        for (SizeType32 bi = 0; bi < mSamplingParams.getBatchSize(); ++bi)
        {
            auto const batchSlot = batchSlots[bi];
            auto const generatedLength = generationLengths[bi];
            // Check draft tokens for the next iteration.
            for (SizeType32 ti = 0; ti < generatedLength - 1; ++ti)
            {
                auto const idx = tc::flat_index2(batchSlot, ti, mSamplingParams.getMaxDecodingDraftTokens());
                EXPECT_EQ(outputNextDraftTokens[idx], compressedDraftTokens[compressedIdx + ti + 1])
                    << " bi: " << bi << " ti: " << ti;
            }
            // Check the length of the draft tokens.
            EXPECT_EQ(BufferRange<SizeType32>(*mNextDraftLengths)[batchSlot], generatedLength - 1) << " bi: " << bi;
            // Check the accepted length.
            EXPECT_EQ(BufferRange<SizeType32>(*mAcceptedLengths)[batchSlot], bestPathLengths[bi]) << " bi: " << bi;
            compressedIdx += generatedLength;
        }
    }
    // Check position ids.
    {
        auto const outputPositionIdsBase = BufferRange<SizeType32>(*mOutputPositionIdsBase);
        auto const nextPosIds = BufferRange<SizeType32>(*mNextPosIds);
        auto const generationLengths = BufferRange<SizeType32>(*mSpecDecodingGenerationLengths);
        auto const packedPosIds = mNetwork.getNextPackedPosId();
        SizeType32 compressedIdx = 0;
        for (SizeType32 bi = 0; bi < mSamplingParams.getBatchSize(); ++bi)
        {
            auto const batchSlot = batchSlots[bi];
            EXPECT_EQ(outputPositionIdsBase[batchSlot], seqLengths[batchSlot]);
            auto const generatedLength = generationLengths[bi];
            // Check pos ids for the next iteration.
            for (SizeType32 ti = 0; ti < generatedLength; ++ti)
            {
                auto const idx = tc::flat_index2(batchSlot, ti, mSamplingParams.getMaxDecodingTokens());
                // Subtract 1 to account for the context-phase correction of the pos ids.
                EXPECT_EQ(nextPosIds[idx], packedPosIds[compressedIdx + ti] - 1) << " bi: " << bi << " ti: " << ti;
            }
            compressedIdx += generatedLength;
        }
    }

    // Check unpacked indices and tokens.
    {
        auto const nextDraftTokens = mNetwork.getNextDraftTokens();
        auto const nextDraftIndices = mNetwork.getNextDraftIndices();
        auto const nextDraftTokensRange = BufferRange<TokenIdType>(*mOutputUnpackedNextDraftTokens);
        auto const nextDraftIndicesRange = BufferRange<SizeType32>(*mOutputUnpackedNextDraftIndices);
        for (SizeType32 bi = 0; bi < nextDraftTokens.size(); ++bi)
        {
            auto const batchSlot = batchSlots[bi];
            for (SizeType32 pi = 0; pi < nextDraftTokens[bi].size(); ++pi)
            {
                for (SizeType32 ti = 0; ti < nextDraftTokens[bi][pi].size(); ++ti)
                {
                    auto idx = tc::flat_index3(
                        batchSlot, pi, ti, mSamplingParams.getMaxNumPaths(), mSamplingParams.getMaxPathLen());
                    EXPECT_EQ(nextDraftTokensRange[idx], nextDraftTokens[bi][pi][ti])
                        << "bi: " << bi << " pi: " << pi << " ti: " << ti;
                    EXPECT_EQ(nextDraftIndicesRange[idx], nextDraftIndices[bi][pi][ti])
                        << "bi: " << bi << " pi: " << pi << " ti: " << ti;
                }
            }
        }
    }

    // Check the accumulated cum sum and paths offsets.
    {
        auto const accumulatedCumSum = BufferRange<SizeType32>(*mAcceptedLengthCumSum);
        auto const pathsOffsets = BufferRange<SizeType32>(*mPathsOffsets);
        auto const acceptedLengths = BufferRange<SizeType32>(*mAcceptedLengths);
        auto const bestPathIndices = BufferRange<SizeType32>(*mBestPathIndices);
        auto const lastDraftIndices = mNetwork.getLastDraftIndices();
        SizeType32 sum = 0;
        for (SizeType32 bi = 0; bi < mSamplingParams.getBatchSize(); ++bi)
        {
            auto const batchSlot = batchSlots[bi];
            EXPECT_EQ(sum, accumulatedCumSum[bi]) << "bi: " << bi;
            auto const acceptedLength = acceptedLengths[batchSlot] - 1;
            for (SizeType32 ti = 0; ti < acceptedLength; ++ti)
            {
                EXPECT_EQ(pathsOffsets[sum + ti], lastDraftIndices[bi][bestPathIndices[bi]][ti + 1] - 1)
                    << "bi: " << bi << " ti: " << ti;
            }
            sum += acceptedLength;
        }
        EXPECT_EQ(sum, accumulatedCumSum[mSamplingParams.getBatchSize()]);
    }

    // Check draft probs.
    {
        auto const outDraftProbs = BufferRange<T>(*mOutputDraftProbs);
        auto const inDraftProbs = BufferRange<T>(*mNextDraftProbs);
        for (SizeType32 bi = 0; bi < mSamplingParams.getBatchSize(); ++bi)
        {
            auto const batchSlot = batchSlots[bi];
            for (SizeType32 pi = 0; pi < mSamplingParams.getMaxNumPaths(); ++pi)
            {
                for (SizeType32 ti = 0; ti < mSamplingParams.getMaxDraftPathLen(); ++ti)
                {
                    for (SizeType32 vi = 0; vi < mSamplingParams.getVocabSize(); ++vi)
                    {
                        auto const outProbIdx = tc::flat_index4(batchSlot, pi, ti, vi,
                            mSamplingParams.getMaxNumPaths(), mSamplingParams.getMaxDraftPathLen(),
                            mSamplingParams.getVocabSize());
                        auto const inProbIdx = tc::flat_index4(bi, pi, ti, vi, mSamplingParams.getMaxNumPaths(),
                            mSamplingParams.getMaxDraftPathLen(), mSamplingParams.getVocabSize());
                        EXPECT_EQ(outDraftProbs[outProbIdx], inDraftProbs[inProbIdx])
                            << "bi: " << bi << " pi: " << pi << " ti: " << ti << " vi: " << vi;
                    }
                }
            }
        }
    }

    // Check temperatures.
    {
        for (SizeType32 bi = 0; bi < mSamplingParams.getBatchSize(); ++bi)
        {
            auto const batchSlot = batchSlots[bi];
            EXPECT_EQ(BufferRange<T>(*mOutputTemperatures)[batchSlot], static_cast<T>(1.f / mTemperatures[bi]))
                << " bi: " << bi;
        }
    }
}
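// packData gathers the per-slot layer outputs checked above into contiguous,
// batch-major tensors (position id bases, generation lengths, random data, draft
// tokens/indices, packed masks, draft probs and temperatures), roughly mimicking
// how the runtime prepares packed engine inputs between decoder iterations.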
template <typename T>
void ExplicitDraftTokensLayerTest<T>::packData()
{
    tksd::PackExplicitDraftTokensParams<T> params;
    params.batchSlots = bufferCast<SizeType32>(*mBatchSlots);
    params.cumSumGenerationLengths = bufferCast<SizeType32>(*mCumSumGenerationLengths);
    params.maxGenerationLength = bufferCast<SizeType32>(*mMaxGenerationLength);
    params.outputPositionIdsBase = bufferCast<SizeType32>(*mPackedPositionIdsBase);
    params.inputPositionIdsBase = bufferCast<SizeType32>(*mOutputPositionIdsBase);
    params.outputGenerationLengths = bufferCast<SizeType32>(*mPackedGenerationLengths);
    params.inputGenerationLengths = bufferCast<SizeType32>(*mSpecDecodingGenerationLengths);
    params.outputRandomDataSample = bufferCast<T>(*mPackedRandomDataSample);
    params.inputRandomDataSample = bufferCast<T>(*mRandomDataSample);
    params.outputRandomDataValidation = bufferCast<T>(*mPackedRandomDataVerification);
    params.inputRandomDataValidation = bufferCast<T>(*mRandomDataValidation);
    params.outputNextDraftTokens = bufferCast<TokenIdType>(*mPackedNextDraftTokens);
    params.inputNextDraftTokens = bufferCast<TokenIdType>(*mOutputUnpackedNextDraftTokens);
    params.outputNextDraftIndices = bufferCast<SizeType32>(*mPackedNextDraftIndices);
    params.inputNextDraftIndices = bufferCast<SizeType32>(*mOutputUnpackedNextDraftIndices);
    params.outputPackedMask = bufferCast<int32_t>(*mPackedPackedMasks);
    params.inputPackedMask = bufferCast<int32_t>(*mPackedMasks);
    params.inputPositionIds = bufferCast<SizeType32>(*mNextPosIds);
    params.outputPositionOffsets = bufferCast<SizeType32>(*mPackedPositionOffsets);
    params.outputPositionIds = bufferCast<SizeType32>(*mPackedPackedPosIds);
    params.outputDraftProbs = bufferCast<T>(*mPackedDraftProbs);
    params.inputDraftProbs = bufferCast<T>(*mOutputDraftProbs);
    params.outputTemperatures = bufferCast<T>(*mPackedTemperatures);
    params.inputTemperatures = bufferCast<T>(*mOutputTemperatures);
    params.batchSize = mSamplingParams.getBatchSize();
    params.numPaths = mSamplingParams.getMaxNumPaths();
    params.maxPathLength = mSamplingParams.getMaxPathLen();
    params.vocabSize = mSamplingParams.getVocabSize();
    params.numGenerationRequests = mSamplingParams.getBatchSize();
    params.numContextTokens = 0;

    params.checkParams();

    tksd::invokePackGenerationLengths(params, mStream->get());

    // Compute the inclusive sum of the generation lengths.
    auto reduceTempStorageBytes = tksd::invokeScanGenerationLengths(
        nullptr, 0, nullptr, nullptr, mSamplingParams.getBatchSize(), mStream->get());
    auto reduceMaxTempStorage = mBufferManager->gpu(reduceTempStorageBytes);
    tksd::invokeScanGenerationLengths(bufferCast<uint8_t>(*reduceMaxTempStorage), reduceTempStorageBytes,
        bufferCast<SizeType32>(*mSpecDecodingGenerationLengths), bufferCast<SizeType32>(*mCumSumGenerationLengths),
        mSamplingParams.getBatchSize(), mStream->get());

    // Pack tensors from the batch slot positions into a contiguous array.
    tksd::invokePackExplicitDraftTokens(params, mStream->get());

    // Copy draft probs.
    tksd::invokeCopyProbs(params, mStream->get());
}
template <typename T>
void ExplicitDraftTokensLayerTest<T>::checkPackResult()
{
    auto const batchSlots = BufferRange<SizeType32>(*mBatchSlots);
    auto const maxGenLength = mNetwork.getMaxNextGenerationLength();
    auto const numPackedMasks = static_cast<SizeType32>(divUp(mSamplingParams.getMaxDecodingTokens(), 32));
    for (SizeType32 bi = 0; bi < mSamplingParams.getBatchSize(); ++bi)
    {
        auto const batchSlot = batchSlots[bi];
        EXPECT_EQ(BufferRange<SizeType32>(*mPackedPositionIdsBase)[bi],
            BufferRange<SizeType32>(*mOutputPositionIdsBase)[batchSlot])
            << "bi: " << bi;
        EXPECT_EQ(BufferRange<SizeType32>(*mPackedGenerationLengths)[bi],
            BufferRange<SizeType32>(*mSpecDecodingGenerationLengths)[batchSlot])
            << "bi: " << bi;
        EXPECT_EQ(BufferRange<T>(*mPackedRandomDataSample)[bi], BufferRange<T>(*mRandomDataSample)[batchSlot])
            << "bi: " << bi;
        EXPECT_EQ(BufferRange<T>(*mPackedTemperatures)[bi], BufferRange<T>(*mOutputTemperatures)[batchSlot])
            << "bi: " << bi;
        for (SizeType32 pi = 0; pi < mSamplingParams.getMaxNumPaths(); ++pi)
        {
            for (SizeType32 ti = 0; ti < mSamplingParams.getMaxDraftPathLen(); ++ti)
            {
                EXPECT_EQ(bufferCast<T>(*ITensor::at(mPackedRandomDataVerification, {bi, pi, ti}))[0],
                    bufferCast<T>(*ITensor::at(mRandomDataValidation, {batchSlot, pi, ti}))[0])
                    << "bi: " << bi << " pi: " << pi << " ti: " << ti;
                for (SizeType32 vi = 0; vi < mSamplingParams.getVocabSize(); ++vi)
                {
                    EXPECT_EQ(bufferCast<T>(*ITensor::at(mPackedDraftProbs, {bi, pi, ti, vi}))[0],
                        bufferCast<T>(*ITensor::at(mOutputDraftProbs, {batchSlot, pi, ti, vi}))[0])
                        << "bi: " << bi << " pi: " << pi << " ti: " << ti << " vi: " << vi;
                }
            }
            for (SizeType32 ti = 0; ti < mSamplingParams.getMaxPathLen(); ++ti)
            {
                EXPECT_EQ(bufferCast<TokenIdType>(*ITensor::at(mPackedNextDraftTokens, {bi, pi, ti}))[0],
                    bufferCast<TokenIdType>(*ITensor::at(mOutputUnpackedNextDraftTokens, {batchSlot, pi, ti}))[0])
                    << "bi: " << bi << " pi: " << pi << " ti: " << ti;
                EXPECT_EQ(bufferCast<SizeType32>(*ITensor::at(mPackedNextDraftIndices, {bi, pi, ti}))[0],
                    bufferCast<SizeType32>(*ITensor::at(mOutputUnpackedNextDraftIndices, {batchSlot, pi, ti}))[0])
                    << "bi: " << bi << " pi: " << pi << " ti: " << ti;
            }
        }
        auto const basePosId = BufferRange<SizeType32>(*mPackedPositionIdsBase)[bi];
        for (SizeType32 ti = 0; ti < maxGenLength; ++ti)
        {
            auto const outPosOffsetIdx = tc::flat_index2(bi, ti, maxGenLength);
            auto const inPosOffsetIdx = tc::flat_index2(batchSlot, ti, mSamplingParams.getMaxDecodingTokens());
            EXPECT_EQ(BufferRange<SizeType32>(*mPackedPositionOffsets)[outPosOffsetIdx],
                BufferRange<SizeType32>(*mNextPosIds)[inPosOffsetIdx] - basePosId + 1)
                << "bi: " << bi << " ti: " << ti;
        }
        auto const outputMaskStartId = (bi == 0) ? 0 : BufferRange<SizeType32>(*mCumSumGenerationLengths)[bi - 1];
        auto const numTokens = (bi == 0) ? BufferRange<SizeType32>(*mCumSumGenerationLengths)[0]
                                         : BufferRange<SizeType32>(*mCumSumGenerationLengths)[bi]
                - BufferRange<SizeType32>(*mCumSumGenerationLengths)[bi - 1];
        for (SizeType32 mi = 0; mi < numTokens * numPackedMasks; ++mi)
        {
            auto const outMaskIdx = outputMaskStartId * numPackedMasks + mi;
            auto const inMaskIdx
                = tc::flat_index2(batchSlot, mi, mSamplingParams.getMaxDecodingTokens() * numPackedMasks);
            EXPECT_EQ(BufferRange<int32_t>(*mPackedPackedMasks)[outMaskIdx],
                BufferRange<int32_t>(*mPackedMasks)[inMaskIdx])
                << "bi: " << bi << " mi: " << mi;
        }
    }
}

template <typename T>
void ExplicitDraftTokensLayerTest<T>::runTest(std::vector<std::string> const& prompts,
    std::vector<std::string> const& predictions, DraftLettersVec const& nextDraftLetters,
    DraftLettersVec const& lastDraftLetters, SamplingParams& params)
{
    mSamplingParams = params;

    mNetwork.forward(params, prompts, predictions, nextDraftLetters, lastDraftLetters);

    allocateBuffers();

    setup();

    auto inputTensors = createInputTensors();
    auto outputTensors = createOutputTensors();

    mExplicitDraftTokensLayer->forwardAsync(outputTensors, inputTensors);

    mStream->synchronize();

    checkLayerResult();

    packData();

    mStream->synchronize();

    checkPackResult();
}

template class ExplicitDraftTokensLayerTest<float>;
template class ExplicitDraftTokensLayerTest<half>;

TYPED_TEST_SUITE(ExplicitDraftTokensLayerTest, FloatAndHalfTypes);
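// Conventions for the tests below: each character is one token and each fixed-width
// string is one draft path; `predictions` holds the tokens the target model is
// assumed to produce, against which the last draft paths are accepted.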
TYPED_TEST(ExplicitDraftTokensLayerTest, SimpleTestBS1)
{
    SamplingParams params;

    std::vector<std::string> prompt = {"Hi mate, h"};
    std::vector<std::string> predictions = {"how things"};
    DraftLettersVec lastDraftLetters = {{"how do ", "how are", "however", "hello w"}};
    DraftLettersVec nextDraftLetters = {{"things ", "that is", "to crea", "touchab"}};
    params.setBatchSize(1);

    this->runTest(prompt, predictions, nextDraftLetters, lastDraftLetters, params);
}

TYPED_TEST(ExplicitDraftTokensLayerTest, SimpleTestBS1OnePaths)
{
    SamplingParams params;

    std::vector<std::string> prompt = {"Hi mate, h"};
    std::vector<std::string> predictions = {"how things"};
    DraftLettersVec lastDraftLetters = {{"how do "}};
    DraftLettersVec nextDraftLetters = {{"things "}};
    params.setBatchSize(1);
    params.setMaxNumPaths(1);

    this->runTest(prompt, predictions, nextDraftLetters, lastDraftLetters, params);
}

TYPED_TEST(ExplicitDraftTokensLayerTest, SimpleTestSecondPathAcceptedBS1)
{
    SamplingParams params;

    std::vector<std::string> prompt = {"Hi mate, h"};
    std::vector<std::string> predictions = {"how things"};
    DraftLettersVec lastDraftLetters = {{"howdy f", "how are", "however", "hello w"}};
    DraftLettersVec nextDraftLetters = {{"things ", "that is", "to crea", "touchab"}};
    params.setBatchSize(1);

    this->runTest(prompt, predictions, nextDraftLetters, lastDraftLetters, params);
}

TYPED_TEST(ExplicitDraftTokensLayerTest, SimpleTestNoDraftAcceptedBS1)
{
    SamplingParams params;

    std::vector<std::string> prompt = {"Hi mate, h"};
    std::vector<std::string> predictions = {"how things"};
    DraftLettersVec lastDraftLetters = {{"handove", "human f", "heavy l", "hello h"}};
    DraftLettersVec nextDraftLetters = {{"oatmeal", "ocean b", "occupat", "oblivio"}};
    params.setBatchSize(1);

    this->runTest(prompt, predictions, nextDraftLetters, lastDraftLetters, params);
}

TYPED_TEST(ExplicitDraftTokensLayerTest, SimpleTestBS2SameSequence)
{
    SamplingParams params;

    std::vector<std::string> prompt = {"Hi mate, h", "Hi mate, h"};
    std::vector<std::string> predictions = {"how things", "how things"};
    DraftLettersVec lastDraftLetters
        = {{"how do ", "how are", "however", "hello w"}, {"how do ", "how are", "however", "hello w"}};
    DraftLettersVec nextDraftLetters
        = {{"things ", "that is", "to crea", "touchab"}, {"things ", "that is", "to crea", "touchab"}};
    params.setBatchSize(2);

    this->runTest(prompt, predictions, nextDraftLetters, lastDraftLetters, params);
}

TYPED_TEST(ExplicitDraftTokensLayerTest, SimpleTestBS2Long)
{
    SamplingParams params;

    std::vector<std::string> prompt = {"Hi mate, h", "London is t"};
    std::vector<std::string> predictions = {"how things are going", "the capital of Great Britain"};
    DraftLettersVec lastDraftLetters = {{"how do you ", "how are you", "however you", "hello world"},
        {"the bar and", "the best ci", "the capital", "thoughest p"}};
    DraftLettersVec nextDraftLetters = {{"things are ", "that is sad", "to create a", "touchable y"},
        {" of Great B", " and the ma", " of country", " also known"}};
    params.setBatchSize(2);
    // ceil(4 * 10 / 32) = 2 packed masks per request.
    params.setMaxNumPaths(4);
    params.setMaxDraftPathLen(10);

    this->runTest(prompt, predictions, nextDraftLetters, lastDraftLetters, params);
}

TYPED_TEST(ExplicitDraftTokensLayerTest, SimpleTestBS2DifferentSequences)
{
    SamplingParams params;

    std::vector<std::string> prompt = {"Hi mate, h", "London is t"};
    std::vector<std::string> predictions = {"how things", "the cap"};
    DraftLettersVec lastDraftLetters
        = {{"how do ", "how are", "however", "hello w"}, {"the bar", "the bes", "the cap", "thoughe"}};
    DraftLettersVec nextDraftLetters
        = {{"things ", "that is", "to crea", "touchab"}, {"itan of", "iteract", "ital of", "importa"}};
    params.setBatchSize(2);

    this->runTest(prompt, predictions, nextDraftLetters, lastDraftLetters, params);
}
semico", " exampl"}, {"at all ", "anythin", "above a", "albeit "}}; params.setBatchSize(4); this->runTest(prompt, predictions, nextDraftLetters, lastDraftLetters, params); } } // namespace tensorrt_llm::tests::layers