/*
* Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "eagleLayerTest.h"
#include "tensorrt_llm/common/memoryUtils.h"
#include "tensorrt_llm/kernels/decodingCommon.h"
#include "tensorrt_llm/kernels/speculativeDecoding/eagleDecodingKernels.h"
#include "tensorrt_llm/runtime/iBuffer.h"
#include "tensorrt_llm/runtime/iTensor.h"
#include "tensorrt_llm/runtime/runtimeKernels.h"
#include "tensorrt_llm/runtime/speculativeDecodingModule.h"
#include "tensorrt_llm/runtime/tllmLogger.h"
#include <NvInferRuntimeBase.h>
#include <algorithm>
#include <cstdint>
#include <random>
#include <string>
#include <unordered_map>
namespace tensorrt_llm::tests::layers
{
using namespace tensorrt_llm::runtime;
using namespace tensorrt_llm::layers;
using namespace tensorrt_llm::common;
namespace tk = tensorrt_llm::kernels;
namespace tksd = tensorrt_llm::kernels::speculative_decoding;
namespace trk = tensorrt_llm::runtime::kernels;
/////////////////////////////////////////////////////////////////////////////////////////////////////////////////
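// EagleDummyNetwork is a host-side reference implementation of a single Eagle
// speculative-decoding step. It maps single characters to token ids so that draft trees
// can be written as plain strings, e.g. tokenize("he") == {104, 101} and
// detokenize({104, 101}) == "he".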
TokensVec EagleDummyNetwork::tokenize(std::string const& letters) const
{
TokensVec tokens;
for (char c : letters)
{
tokens.push_back(static_cast<TokenIdType>(c));
}
return tokens;
}
std::string EagleDummyNetwork::detokenize(TokensVec const& tokens) const
{
std::string letters;
for (int token : tokens)
{
letters += static_cast<char>(token);
}
return letters;
}
DraftTokensVec EagleDummyNetwork::draftLettersToTokens(DraftLettersVec const& draftLetters) const
{
DraftTokensVec draftTokens(draftLetters.size());
for (SizeType32 bi = 0; bi < draftLetters.size(); ++bi)
{
draftTokens[bi] = tokenize(draftLetters[bi]);
}
return draftTokens;
}
SizeType32 EagleDummyNetwork::longestCommonPrefixLength(TokensVec const& a, TokensVec const& b) const
{
SizeType32 minLength = std::min(a.size(), b.size());
SizeType32 idx = 0;
while (idx < minLength && a[idx] == b[idx])
{
++idx;
}
return idx;
}
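// Builds a [maxDecodingTokens x maxPathLen] path table from per-path draft strings.
// Row pi lists the flattened tree node indices visited by path pi; entry 0 is always the
// root (node 0) and unused entries are padded with -1. Nodes are enumerated level by
// level and shared prefixes are deduplicated, so sibling paths share node indices.
// For example (see pathFromDraftTokensTest below), the drafts
// {{1, 4, 8}, {1, 5}, {2, 6, 9}, {2, 7}, {3}} yield the paths
// {0, 1, 4, 8}, {0, 1, 5, -1}, {0, 2, 6, 9}, {0, 2, 7, -1}, {0, 3, -1, -1}.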
DraftPath EagleDummyNetwork::pathFromDraftTokens(
DraftTokensVec const& tokens, SizeType32 maxDecodingTokens, SizeType32 maxPathLen) const
{
DraftPath path(maxDecodingTokens);
for (SizeType32 pi = 0; pi < maxDecodingTokens; ++pi)
{
path[pi].resize(maxPathLen);
for (SizeType32 ti = 0; ti < maxPathLen; ++ti)
{
path[pi][ti] = -1;
}
}
SizeType32 draftPosCounter = 1;
for (SizeType32 ti = 1; ti < maxPathLen; ++ti)
{
std::unordered_map<std::string, SizeType32> tokenPosMap;
for (SizeType32 pi = 0; pi < tokens.size(); ++pi)
{
if (tokens[pi].size() > ti - 1)
{
path[pi][0] = 0;
auto const token = tokens[pi][ti - 1];
auto const draftPrefix = detokenize(tokens[pi]).substr(0, ti);
if (tokenPosMap.count(draftPrefix) == 0)
{
tokenPosMap[draftPrefix] = draftPosCounter++;
}
path[pi][ti] = tokenPosMap[draftPrefix];
}
}
}
return path;
}
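// Flattens per-path tokens into a single deduplicated vector ordered by tree node index.
// For draft tokens the root holds no draft token, so node index i maps to slot i - 1 and
// the result has one entry per non-root node; for prediction tokens the root is included
// and node index i maps to slot i.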
TokensVec EagleDummyNetwork::flattenTokens(
DraftTokensVec const& tokens, DraftPath const& path, bool isDraftTokens) const
{
SizeType32 maxPathIdx{-1};
for (SizeType32 pi = 0; pi < path.size(); ++pi)
{
for (SizeType32 ti = 0; ti < path[pi].size(); ++ti)
{
auto const pathIdx = path[pi][ti];
maxPathIdx = std::max(pathIdx, maxPathIdx);
}
}
if (!isDraftTokens)
{
maxPathIdx++;
}
TokensVec flattenedTokens(maxPathIdx);
for (SizeType32 pi = 0; pi < path.size(); ++pi)
{
for (SizeType32 ti = 0; ti < path[pi].size(); ++ti)
{
if (isDraftTokens && ti == 0)
{
continue;
}
auto const pathIdx = path[pi][ti];
if (pathIdx != -1)
{
if (isDraftTokens)
{
flattenedTokens[pathIdx - 1] = tokens[pi][ti - 1];
}
else
{
flattenedTokens[pathIdx] = tokens[pi][ti];
}
}
}
}
return flattenedTokens;
}
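// Builds one boolean attention mask per request from its paths: mask[to][from] is true
// iff node 'from' lies on the path from the root to node 'to', i.e. every draft token
// attends only to itself and to its ancestors in the draft tree.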
std::vector<std::vector<std::vector<bool>>> EagleDummyNetwork::createMasks(DraftPaths const& paths) const
{
std::vector<std::vector<std::vector<bool>>> masks;
for (SizeType32 bi = 0; bi < paths.size(); ++bi)
{
std::vector<std::vector<bool>> localMask(paths[bi].size());
for (SizeType32 ti = 0; ti < paths[bi].size(); ++ti)
{
localMask[ti].resize(paths[bi].size());
}
localMask[0][0] = true;
for (SizeType32 pi = 0; pi < paths[bi].size(); ++pi)
{
for (SizeType32 ti = 1; ti < paths[bi][pi].size(); ++ti)
{
auto const to = paths[bi][pi][ti];
if (to == -1)
{
break;
}
localMask[to][to] = true;
for (SizeType32 fi = 0; fi < ti; ++fi)
{
auto const from = paths[bi][pi][fi];
localMask[to][from] = true;
}
}
}
masks.push_back(localMask);
}
return masks;
}
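// Greedy acceptance of draft tokens: for each path, the drafted tokens are compared with
// the tokens predicted along that path, and the path with the longest common prefix wins.
// The accepted length is that prefix length plus one, because the prediction that follows
// the last matched draft token (the bonus token) is accepted as well.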
void EagleDummyNetwork::acceptTokens(std::vector<TokensVec> const& predictionTokens,
DraftTokensVec const& lastDraftTokens, DraftPaths const& lastDraftPaths)
{
TLLM_CHECK_WITH_INFO(predictionTokens.size() == lastDraftTokens.size(),
"Batch size of predictions (%d) does not match the batch size of last draft tokens (%d)",
static_cast<SizeType32>(predictionTokens.size()), static_cast<SizeType32>(lastDraftTokens.size()));
TLLM_CHECK_WITH_INFO(predictionTokens.size() == lastDraftPaths.size(),
"Batch size of predictions (%d) does not match the batch size of last draft paths (%d)",
static_cast<SizeType32>(predictionTokens.size()), static_cast<SizeType32>(lastDraftPaths.size()));
mAcceptedTokens.resize(predictionTokens.size());
mAcceptedLens.resize(predictionTokens.size());
mAcceptedPathIds.resize(predictionTokens.size());
// Only needed for the standalone EagleDummyNetwork unit tests, where forward() is not called.
if (mOutputIds.size() == 0)
{
mOutputIds.resize(lastDraftTokens.size());
}
for (SizeType32 bi = 0; bi < lastDraftPaths.size(); ++bi)
{
SizeType32 maxMatchLen = -1;
SizeType32 maxMatchIdx = -1;
std::vector<TokenIdType> bestDraftPath;
// Find the path with the longest common prefix with the predicted tokens.
for (SizeType32 pi = 0; pi < lastDraftPaths[bi].size(); ++pi)
{
TokensVec predictedPath;
TokensVec draftPath;
for (SizeType32 ti = 0; ti < lastDraftPaths[bi][pi].size(); ++ti)
{
auto const pathIdx = lastDraftPaths[bi][pi][ti];
// -1 entries pad the path to maxPathLen; stop before indexing with them.
if (pathIdx == -1)
{
break;
}
predictedPath.push_back(predictionTokens[bi][pathIdx]);
if (ti > 0)
{
draftPath.push_back(lastDraftTokens[bi][pathIdx - 1]);
}
}
auto const matchLen = longestCommonPrefixLength(draftPath, predictedPath);
if (matchLen > maxMatchLen)
{
maxMatchLen = matchLen;
maxMatchIdx = pi;
bestDraftPath = predictedPath;
}
}
mAcceptedTokens[bi] = bestDraftPath;
mAcceptedLens[bi] = maxMatchLen + 1;
mAcceptedPathIds[bi] = maxMatchIdx;
// Update output ids: append the prediction at the path root plus the maxMatchLen accepted
// draft positions; the last appended token is the bonus token.
mOutputIds[bi].insert(mOutputIds[bi].end(), bestDraftPath.begin(), bestDraftPath.begin() + maxMatchLen + 1);
}
}
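// Reference implementation of one full decoding step: tokenizes the prompts and the
// letter-encoded draft trees, builds the path tables, flattens drafts and predictions,
// accepts tokens against the last drafts and creates the masks for the next drafts.
// The results are stored as members and copied into the layer inputs by the tests below.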
void EagleDummyNetwork::forward(SamplingParams const& params, std::vector<std::string> const& prompts,
std::vector<std::vector<std::string>> const& predictionLetters,
std::vector<DraftLettersVec> const& nextDraftLetters, std::vector<DraftLettersVec> const& lastDraftLetters)
{
mSamplingParams = params;
TLLM_CHECK(params.getBatchSize() == nextDraftLetters.size());
TLLM_CHECK(params.getBatchSize() == lastDraftLetters.size());
DraftPaths lastDraftPaths;
DraftPaths nextDraftPaths;
DraftTokensVec lastDraftTokensFlattened;
DraftTokensVec nextDraftTokensFlattened;
std::vector<TokensVec> predictionTokensFlattened;
for (SizeType32 bi = 0; bi < params.getBatchSize(); ++bi)
{
auto const lastDraftTokens = draftLettersToTokens(lastDraftLetters[bi]);
auto const nextDraftTokens = draftLettersToTokens(nextDraftLetters[bi]);
auto const lastDraftPath
= pathFromDraftTokens(lastDraftTokens, params.getMaxDecodingTokens(), params.getMaxPathLen());
auto const nextDraftPath
= pathFromDraftTokens(nextDraftTokens, params.getMaxDecodingTokens(), params.getMaxPathLen());
auto const predictionTokens = draftLettersToTokens(predictionLetters[bi]);
lastDraftPaths.push_back(lastDraftPath);
nextDraftPaths.push_back(nextDraftPath);
lastDraftTokensFlattened.push_back(flattenTokens(lastDraftTokens, lastDraftPath, /* isDraftTokens */ true));
nextDraftTokensFlattened.push_back(flattenTokens(nextDraftTokens, nextDraftPath, /* isDraftTokens */ true));
predictionTokensFlattened.push_back(flattenTokens(predictionTokens, lastDraftPath, /* isDraftTokens */ false));
}
mNextDraftTokens = nextDraftTokensFlattened;
mLastDraftTokens = lastDraftTokensFlattened;
mNextDraftPaths = nextDraftPaths;
mLastDraftPaths = lastDraftPaths;
mNextDraftLens.resize(params.getBatchSize());
mLastDraftLens.resize(params.getBatchSize());
for (SizeType32 bi = 0; bi < params.getBatchSize(); ++bi)
{
mNextDraftLens[bi] = mNextDraftTokens[bi].size();
mLastDraftLens[bi] = mLastDraftTokens[bi].size();
}
for (SizeType32 bi = 0; bi < predictionLetters.size(); ++bi)
{
mPrompts.push_back(tokenize(prompts[bi]));
}
mOutputIds = mPrompts;
acceptTokens(predictionTokensFlattened, mLastDraftTokens, lastDraftPaths);
mMasks = createMasks(mNextDraftPaths);
}
TEST(EagleDummyNetworkTest, tokenizeTest)
{
EagleDummyNetwork network;
{
auto tokens = network.tokenize("hello world");
EXPECT_EQ(tokens, std::vector<TokenIdType>({104, 101, 108, 108, 111, 32, 119, 111, 114, 108, 100}));
}
{
DraftLettersVec lettersVec = {{"hello world"}, {"world"}};
auto draftTokens = network.draftLettersToTokens(lettersVec);
ASSERT_EQ(draftTokens.size(), 2);
ASSERT_EQ(draftTokens[0].size(), 11);
ASSERT_EQ(draftTokens[1].size(), 5);
EXPECT_EQ(draftTokens[0], std::vector<TokenIdType>({104, 101, 108, 108, 111, 32, 119, 111, 114, 108, 100}));
EXPECT_EQ(draftTokens[1], std::vector<TokenIdType>({119, 111, 114, 108, 100}));
}
}
TEST(EagleDummyNetworkTest, detokenizeTest)
{
EagleDummyNetwork network;
{
auto letters
= network.detokenize(std::vector<TokenIdType>({104, 101, 108, 108, 111, 32, 119, 111, 114, 108, 100}));
EXPECT_EQ(letters, "hello world");
}
}
TEST(EagleDummyNetworkTest, longestCommonPrefixLengthTest)
{
EagleDummyNetwork network;
EXPECT_EQ(network.longestCommonPrefixLength({1, 2, 3}, {1, 2}), 2);
EXPECT_EQ(network.longestCommonPrefixLength({1, 2, 3}, {1, 2, 3}), 3);
EXPECT_EQ(network.longestCommonPrefixLength({1, 2, 3}, {1, 5, 6}), 1);
EXPECT_EQ(network.longestCommonPrefixLength({1, 2, 3}, {2, 5, 6}), 0);
EXPECT_EQ(network.longestCommonPrefixLength({1, 2, 3}, {}), 0);
}
TEST(EagleDummyNetworkTest, pathFromDraftTokensTest)
{
EagleDummyNetwork network;
{
SizeType32 const maxDecodingTokens = 5;
SizeType32 const maxPathLen = 4;
DraftTokensVec draftTokens = {{1, 4, 8}, {1, 5}, {2, 6, 9}, {2, 7}, {3}};
auto const paths = network.pathFromDraftTokens(draftTokens, maxDecodingTokens, maxPathLen);
ASSERT_EQ(paths.size(), maxDecodingTokens);
for (SizeType32 pi = 0; pi < maxDecodingTokens; ++pi)
{
ASSERT_EQ(paths[pi].size(), maxPathLen);
if (pi < draftTokens.size())
{
for (SizeType32 ti = 0; ti < maxPathLen; ++ti)
{
if (ti == 0)
{
EXPECT_EQ(paths[pi][ti], 0);
}
else if (ti - 1 < draftTokens[pi].size())
{
EXPECT_EQ(paths[pi][ti], draftTokens[pi][ti - 1]);
}
else
{
EXPECT_EQ(paths[pi][ti], -1);
}
}
}
else
{
for (SizeType32 ti = 0; ti < maxPathLen; ++ti)
{
EXPECT_EQ(paths[pi][ti], -1);
}
}
}
}
}
TEST(EagleDummyNetworkTest, flattenedTokensTest)
{
{
EagleDummyNetwork network;
DraftTokensVec draftTokens = {{1, 4, 8}, {1, 5}, {2, 6, 9}, {2, 7}, {3}};
DraftPath path = {{0, 1, 4, 8}, {0, 1, 5, -1}, {0, 2, 6, 9}, {0, 2, 7, -1}, {0, 3, -1, -1}, {-1, -1, -1, -1},
{-1, -1, -1, -1}, {-1, -1, -1, -1}, {-1, -1, -1, -1}, {-1, -1, -1, -1}};
auto const flattenTokens = network.flattenTokens(draftTokens, path, /* isDraftTokens*/ true);
EXPECT_EQ(flattenTokens, TokensVec({1, 2, 3, 4, 5, 6, 7, 8, 9}));
}
{
EagleDummyNetwork network;
DraftTokensVec predictionTokens = {{0, 1, 4, 8}, {0, 1, 5}, {0, 2, 6, 9}, {0, 2, 7}, {0, 3}};
DraftPath path = {{0, 1, 4, 8}, {0, 1, 5, -1}, {0, 2, 6, 9}, {0, 2, 7, -1}, {0, 3, -1, -1}, {-1, -1, -1, -1},
{-1, -1, -1, -1}, {-1, -1, -1, -1}, {-1, -1, -1, -1}, {-1, -1, -1, -1}};
auto const flattenTokens = network.flattenTokens(predictionTokens, path, /* isDraftTokens*/ false);
EXPECT_EQ(flattenTokens, TokensVec({0, 1, 2, 3, 4, 5, 6, 7, 8, 9}));
}
}
TEST(EagleDummyNetworkTest, createMasksTest)
{
{
EagleDummyNetwork network;
DraftPaths paths = {{{0, 1, -1, -1}, {-1, -1, -1, -1}}};
auto const mask = network.createMasks(paths);
std::vector<std::vector<std::vector<bool>>> refMask = {{{true, false}, {true, true}}};
EXPECT_EQ(mask, refMask);
}
{
EagleDummyNetwork network;
DraftPaths paths = {{{0, 1, 4, 8}, {0, 1, 5, -1}, {0, 2, 6, 9}, {0, 2, 7, -1}, {0, 3, -1, -1}, {-1, -1, -1, -1},
{-1, -1, -1, -1}, {-1, -1, -1, -1}, {-1, -1, -1, -1}, {-1, -1, -1, -1}}};
auto const mask = network.createMasks(paths);
std::vector<std::vector<std::vector<bool>>> refMask
= {{{true, false, false, false, false, false, false, false, false, false},
{true, true, false, false, false, false, false, false, false, false},
{true, false, true, false, false, false, false, false, false, false},
{true, false, false, true, false, false, false, false, false, false},
{true, true, false, false, true, false, false, false, false, false},
{true, true, false, false, false, true, false, false, false, false},
{true, false, true, false, false, false, true, false, false, false},
{true, false, true, false, false, false, false, true, false, false},
{true, true, false, false, true, false, false, false, true, false},
{true, false, true, false, false, false, true, false, false, true}}};
EXPECT_EQ(mask, refMask);
}
{
EagleDummyNetwork network;
DraftPaths paths = {{{0, 1, 3}, {0, 2, -1}, {-1, -1, -1}, {-1, -1, -1}, {-1, -1, -1}},
{{0, 1, 3}, {0, 2, 4}, {-1, -1, -1}, {-1, -1, -1}, {-1, -1, -1}}};
auto const mask = network.createMasks(paths);
std::vector<std::vector<std::vector<bool>>> refMask = {
{{true, false, false, false, false}, {true, true, false, false, false}, {true, false, true, false, false},
{true, true, false, true, false}, {false, false, false, false, false}},
{{true, false, false, false, false}, {true, true, false, false, false}, {true, false, true, false, false},
{true, true, false, true, false}, {true, false, true, false, true}}};
EXPECT_EQ(mask, refMask);
}
}
TEST(EagleDummyNetworkTest, acceptTokensTest)
{
{
EagleDummyNetwork network;
SizeType32 const batchSize{1};
SizeType32 const maxDecodingTokens{10};
SizeType32 const maxPathLen{4};
std::vector<DraftLettersVec> predictionLetters = {{"howe", "hoc", "hecl", "hea", "hu"}};
std::vector<DraftLettersVec> lastDraftLetters = {{"how", "he", "wow", "we", "a"}};
DraftPaths lastDraftPaths;
DraftTokensVec lastDraftTokensFlattened;
std::vector<TokensVec> predictionTokensFlattened;
for (SizeType32 bi = 0; bi < batchSize; ++bi)
{
auto const lastDraftTokens = network.draftLettersToTokens(lastDraftLetters[bi]);
auto const lastDraftPath = network.pathFromDraftTokens(lastDraftTokens, maxDecodingTokens, maxPathLen);
auto const predictionTokens = network.draftLettersToTokens(predictionLetters[bi]);
lastDraftPaths.push_back(lastDraftPath);
lastDraftTokensFlattened.push_back(
network.flattenTokens(lastDraftTokens, lastDraftPath, /* isDraftTokens */ true));
predictionTokensFlattened.push_back(
network.flattenTokens(predictionTokens, lastDraftPath, /* isDraftTokens */ false));
}
network.acceptTokens(predictionTokensFlattened, lastDraftTokensFlattened, lastDraftPaths);
auto acceptedLens = network.getAcceptedLens();
auto acceptedPathIds = network.getAcceptedPathIds();
auto outputIds = network.getOutputIds();
ASSERT_EQ(acceptedLens.size(), 1);
ASSERT_EQ(acceptedPathIds.size(), 1);
ASSERT_EQ(outputIds.size(), 1);
EXPECT_EQ(acceptedLens[0], 4);
EXPECT_EQ(acceptedPathIds[0], 0);
EXPECT_EQ(network.detokenize(outputIds[0]), "howe");
}
{
EagleDummyNetwork network;
SizeType32 const batchSize{2};
SizeType32 const maxDecodingTokens{10};
SizeType32 const maxPathLen{4};
std::vector<DraftLettersVec> predictionLetters
= {{"howe", "hoc", "hecl", "hea", "hu"}, {"bcde", "bcdc", "bca", "bcc", "bo"}};
std::vector<DraftLettersVec> lastDraftLetters
= {{"how", "he", "wow", "we", "a"}, {"inc", "inf", "ir", "im", "b"}};
DraftPaths lastDraftPaths;
DraftTokensVec lastDraftTokensFlattened;
std::vector<TokensVec> predictionTokensFlattened;
for (SizeType32 bi = 0; bi < batchSize; ++bi)
{
auto const lastDraftTokens = network.draftLettersToTokens(lastDraftLetters[bi]);
auto const lastDraftPath = network.pathFromDraftTokens(lastDraftTokens, maxDecodingTokens, maxPathLen);
auto const predictionTokens = network.draftLettersToTokens(predictionLetters[bi]);
lastDraftPaths.push_back(lastDraftPath);
lastDraftTokensFlattened.push_back(
network.flattenTokens(lastDraftTokens, lastDraftPath, /* isDraftTokens */ true));
predictionTokensFlattened.push_back(
network.flattenTokens(predictionTokens, lastDraftPath, /* isDraftTokens */ false));
}
network.acceptTokens(predictionTokensFlattened, lastDraftTokensFlattened, lastDraftPaths);
auto acceptedLens = network.getAcceptedLens();
auto acceptedPathIds = network.getAcceptedPathIds();
auto outputIds = network.getOutputIds();
ASSERT_EQ(acceptedLens.size(), 2);
ASSERT_EQ(acceptedPathIds.size(), 2);
ASSERT_EQ(outputIds.size(), 2);
EXPECT_EQ(acceptedLens[0], 4);
EXPECT_EQ(acceptedLens[1], 2);
EXPECT_EQ(acceptedPathIds[0], 0);
EXPECT_EQ(acceptedPathIds[1], 4);
EXPECT_EQ(network.detokenize(outputIds[0]), "howe");
EXPECT_EQ(network.detokenize(outputIds[1]), "bo");
}
}
/////////////////////////////////////////////////////////////////////////////////////////////////////////////////
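// The typed tests below run EagleDecodingLayer<T> on the GPU and check every output
// tensor against the EagleDummyNetwork reference above.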
template <typename T>
void EagleDecodingLayerTest<T>::SetUp()
{
mStream = std::make_shared<tensorrt_llm::runtime::CudaStream>();
mBufferManager = std::make_shared<tensorrt_llm::runtime::BufferManager>(mStream);
}
template <typename T>
void EagleDecodingLayerTest<T>::allocateBuffers()
{
auto speculativeDecodingModule = std::make_shared<SpeculativeDecodingModule>(mSamplingParams.getMaxDraftPathLen(),
mSamplingParams.getMaxDecodingDraftTokens(), mSamplingParams.getMaxDecodingTokens());
auto const decodingDomain = tensorrt_llm::layers::DecoderDomain(mSamplingParams.getMaxBatchSize(), 1,
mSamplingParams.getVocabSize(), mSamplingParams.getVocabSize(), speculativeDecodingModule);
mEagleLayer = std::make_shared<tensorrt_llm::layers::EagleDecodingLayer<T>>(decodingDomain, mBufferManager);
// outputs
mOutputIds = BufferManager::pinnedPool(
ITensor::makeShape({mSamplingParams.getMaxBatchSize(), mSamplingParams.getMaxSeqLen()}),
nvinfer1::DataType::kINT32);
mSeqLengths = BufferManager::pinnedPool(
ITensor::makeShape({mSamplingParams.getMaxBatchSize()}), nvinfer1::DataType::kINT32);
mOutputNextDraftTokens = BufferManager::pinnedPool(
ITensor::makeShape({mSamplingParams.getMaxBatchSize(), mSamplingParams.getMaxDecodingDraftTokens()}),
nvinfer1::DataType::kINT32);
mOutputUnpackedNextDraftTokens = BufferManager::pinnedPool(
ITensor::makeShape({mSamplingParams.getMaxBatchSize(), mSamplingParams.getMaxDecodingDraftTokens()}),
nvinfer1::DataType::kINT32);
mAcceptedLengths = BufferManager::pinnedPool(
ITensor::makeShape({mSamplingParams.getMaxBatchSize()}), nvinfer1::DataType::kINT32);
mNextPosIds = BufferManager::pinnedPool(
ITensor::makeShape({mSamplingParams.getMaxBatchSize(), mSamplingParams.getMaxDecodingTokens()}),
nvinfer1::DataType::kINT32);
mPrevDraftLengths = BufferManager::pinnedPool(
ITensor::makeShape({mSamplingParams.getMaxBatchSize()}), nvinfer1::DataType::kINT32);
mNextDraftLengths = BufferManager::pinnedPool(
ITensor::makeShape({mSamplingParams.getMaxBatchSize()}), nvinfer1::DataType::kINT32);
mNextGenerationLengths
= mBufferManager->gpu(ITensor::makeShape({mSamplingParams.getMaxBatchSize()}), nvinfer1::DataType::kINT32);
mNextGenerationLengthsHost = BufferManager::pinnedPool(
ITensor::makeShape({mSamplingParams.getMaxBatchSize()}), nvinfer1::DataType::kINT32);
mAcceptedLengthCumSum = BufferManager::pinnedPool(
ITensor::makeShape({mSamplingParams.getMaxBatchSize() + 1}), nvinfer1::DataType::kINT32);
mPathsOffsets = BufferManager::pinnedPool(
ITensor::makeShape({mSamplingParams.getMaxBatchSize() * mSamplingParams.getMaxDraftPathLen()}),
nvinfer1::DataType::kINT32);
mPackedMasks = BufferManager::pinnedPool(
ITensor::makeShape({mSamplingParams.getMaxBatchSize(), mSamplingParams.getMaxDecodingTokens(),
static_cast<SizeType32>(divUp(mSamplingParams.getMaxDecodingTokens(), 32))}),
nvinfer1::DataType::kINT32);
mRandomDataSample = BufferManager::pinnedPool(
ITensor::makeShape({mSamplingParams.getMaxBatchSize()}), nvinfer1::DataType::kFLOAT);
mRandomDataValidation = BufferManager::pinnedPool(
ITensor::makeShape({mSamplingParams.getMaxBatchSize(), mSamplingParams.getMaxDecodingTokens()}),
nvinfer1::DataType::kFLOAT);
mOutputTemperatures = BufferManager::pinnedPool(
ITensor::makeShape({mSamplingParams.getMaxBatchSize()}), nvinfer1::DataType::kFLOAT);
mOutputNextDraftPaths
= BufferManager::pinnedPool(ITensor::makeShape({mSamplingParams.getMaxBatchSize(),
mSamplingParams.getMaxDecodingTokens(), mSamplingParams.getMaxPathLen()}),
nvinfer1::DataType::kINT32);
mEagleNetCtxRequestTypesHost = BufferManager::pinnedPool(
ITensor::makeShape({mSamplingParams.getMaxBatchSize()}), nvinfer1::DataType::kINT32);
mEagleNetCtxContextLengthsHost = BufferManager::pinnedPool(
ITensor::makeShape({mSamplingParams.getMaxBatchSize()}), nvinfer1::DataType::kINT32);
mEagleNetCtxPastKeyValueLengthsHost = BufferManager::pinnedPool(
ITensor::makeShape({mSamplingParams.getMaxBatchSize()}), nvinfer1::DataType::kINT32);
mEagleNetGenRequestTypesHost = BufferManager::pinnedPool(
ITensor::makeShape({mSamplingParams.getMaxBatchSize()}), nvinfer1::DataType::kINT32);
mEagleNetGenContextLengthsHost = BufferManager::pinnedPool(
ITensor::makeShape({mSamplingParams.getMaxBatchSize()}), nvinfer1::DataType::kINT32);
mEagleNetGenPastKeyValueLengthsHost = BufferManager::pinnedPool(
ITensor::makeShape({mSamplingParams.getMaxBatchSize()}), nvinfer1::DataType::kINT32);
// inputs
mBatchSlots
= BufferManager::pinnedPool(ITensor::makeShape({mSamplingParams.getBatchSize()}), nvinfer1::DataType::kINT32);
mEndIds
= BufferManager::pinnedPool(ITensor::makeShape({mSamplingParams.getBatchSize()}), nvinfer1::DataType::kINT32);
mInputNextDraftTokens = BufferManager::pinnedPool(
ITensor::makeShape({mSamplingParams.getBatchSize(), mSamplingParams.getMaxDecodingDraftTokens()}),
nvinfer1::DataType::kINT32);
mInputNextDraftLens
= BufferManager::pinnedPool(ITensor::makeShape({mSamplingParams.getBatchSize()}), nvinfer1::DataType::kINT32);
mInputNextDraftPaths = BufferManager::pinnedPool(
ITensor::makeShape(
{mSamplingParams.getBatchSize(), mSamplingParams.getMaxDecodingTokens(), mSamplingParams.getMaxPathLen()}),
nvinfer1::DataType::kINT32);
mInputLastDraftTokens = BufferManager::pinnedPool(
ITensor::makeShape({mSamplingParams.getBatchSize(), mSamplingParams.getMaxDecodingDraftTokens()}),
nvinfer1::DataType::kINT32);
mInputLastDraftLens
= BufferManager::pinnedPool(ITensor::makeShape({mSamplingParams.getBatchSize()}), nvinfer1::DataType::kINT32);
mInputLastDraftPaths = BufferManager::pinnedPool(
ITensor::makeShape(
{mSamplingParams.getBatchSize(), mSamplingParams.getMaxDecodingTokens(), mSamplingParams.getMaxPathLen()}),
nvinfer1::DataType::kINT32);
mInputAcceptedTokens = BufferManager::pinnedPool(
ITensor::makeShape({mSamplingParams.getBatchSize(), mSamplingParams.getMaxPathLen()}),
nvinfer1::DataType::kINT32);
mInputAcceptedLens
= BufferManager::pinnedPool(ITensor::makeShape({mSamplingParams.getBatchSize()}), nvinfer1::DataType::kINT32);
mInputAcceptedPathIds
= BufferManager::pinnedPool(ITensor::makeShape({mSamplingParams.getBatchSize()}), nvinfer1::DataType::kINT32);
mChunkedContextNextTokens
= BufferManager::pinnedPool(ITensor::makeShape({mSamplingParams.getBatchSize()}), nvinfer1::DataType::kINT32);
mDecodingWorkspace = std::make_shared<tensorrt_llm::runtime::DecodingLayerWorkspace>(mBufferManager, decodingDomain,
TRTDataType<float>::value, mSamplingParams.getMaxBatchSize() * sizeof(curandState_t));
}
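// setup() fills the buffers with default values, draws per-request random seeds and
// temperatures, runs the layer's setup() and copies the reference network's drafts,
// paths, lengths and accepted tokens into the layer inputs. Request bi is mapped to
// batch slot 2 * bi to exercise non-contiguous batch slots.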
template <typename T>
void EagleDecodingLayerTest<T>::setup()
{
// Fill outputs (and constant inputs such as end ids) with default values.
trk::invokeFill(*mOutputIds, TokenIdType{-1}, *mStream);
trk::invokeFill(*mSeqLengths, SizeType32{0}, *mStream);
trk::invokeFill(*mOutputNextDraftTokens, TokenIdType{-1}, *mStream);
trk::invokeFill(*mOutputUnpackedNextDraftTokens, TokenIdType{-1}, *mStream);
trk::invokeFill(*mAcceptedLengths, SizeType32{0}, *mStream);
trk::invokeFill(*mNextPosIds, SizeType32{0}, *mStream);
trk::invokeFill(*mPrevDraftLengths, SizeType32{0}, *mStream);
trk::invokeFill(*mNextDraftLengths, SizeType32{0}, *mStream);
trk::invokeFill(*mNextGenerationLengths, SizeType32{0}, *mStream);
trk::invokeFill(*mNextGenerationLengthsHost, SizeType32{0}, *mStream);
trk::invokeFill(*mAcceptedLengthCumSum, SizeType32{-1}, *mStream);
trk::invokeFill(*mPathsOffsets, SizeType32{0}, *mStream);
trk::invokeFill(*mPackedMasks, SizeType32{0}, *mStream);
trk::invokeFill(*mEndIds, TokenIdType{-1}, *mStream);
trk::invokeFill(*mRandomDataSample, float{0}, *mStream);
trk::invokeFill(*mRandomDataValidation, float{0}, *mStream);
trk::invokeFill(*mOutputTemperatures, float{0}, *mStream);
trk::invokeFill(*mOutputNextDraftPaths, SizeType32{0}, *mStream);
trk::invokeFill(*mChunkedContextNextTokens, SizeType32{-1}, *mStream);
auto batchSlotsPtr = bufferCast<SizeType32>(*mBatchSlots);
for (SizeType32 bi = 0; bi < mSamplingParams.getBatchSize(); ++bi)
{
batchSlotsPtr[bi] = 2 * bi;
}
auto setupParams = std::make_shared<EagleSetupParams>();
mRandomSeeds = std::vector<uint64_t>(mSamplingParams.getBatchSize());
mTemperatures = std::vector<float>(mSamplingParams.getBatchSize());
std::mt19937 generator(42);
std::uniform_int_distribution<uint64_t> seedDistr(1, 1000);
std::uniform_real_distribution<float> temperatureDistr(0.001f, 1.f);
std::generate(
mRandomSeeds.begin(), mRandomSeeds.end(), [&generator, &seedDistr]() { return seedDistr(generator); });
std::generate(mTemperatures.begin(), mTemperatures.end(),
[&generator, &temperatureDistr]() { return temperatureDistr(generator); });
setupParams->randomSeed = mRandomSeeds;
setupParams->temperature = mTemperatures;
setupParams->randomDataSample = mRandomDataSample;
setupParams->temperatures = mOutputTemperatures;
mDecodingWorkspace->setDeviceBatchSlots(mBatchSlots);
mEagleLayer->setup(mSamplingParams.getBatchSize(), 1, mBatchSlots, setupParams, mDecodingWorkspace);
mStream->synchronize();
mInputAcceptedLens = mBufferManager->copyFrom(mNetwork.getAcceptedLens(),
ITensor::makeShape({mSamplingParams.getBatchSize()}), runtime::MemoryType::kPINNEDPOOL);
mInputAcceptedPathIds = mBufferManager->copyFrom(mNetwork.getAcceptedPathIds(),
ITensor::makeShape({mSamplingParams.getBatchSize()}), runtime::MemoryType::kPINNEDPOOL);
auto const nextDraftTokens = mNetwork.getNextDraftTokens();
auto const lastDraftTokens = mNetwork.getLastDraftTokens();
auto const nextDraftPaths = mNetwork.getNextDraftPaths();
auto const lastDraftPaths = mNetwork.getLastDraftPaths();
auto const nextDraftLens = mNetwork.getNextDraftLens();
auto const lastDraftLens = mNetwork.getLastDraftLens();
auto const acceptedTokens = mNetwork.getAcceptedTokens();
auto sequenceLength = BufferRange<SizeType32>(*mSeqLengths);
auto inputNextDraftTokensRange = BufferRange<TokenIdType>(*mInputNextDraftTokens);
auto inputLastDraftTokensRange = BufferRange<TokenIdType>(*mInputLastDraftTokens);
auto inputNextDraftPathsRange = BufferRange<SizeType32>(*mInputNextDraftPaths);
auto inputLastDraftPathsRange = BufferRange<SizeType32>(*mInputLastDraftPaths);
auto inputNextDraftLensRange = BufferRange<SizeType32>(*mInputNextDraftLens);
auto inputLastDraftLensRange = BufferRange<SizeType32>(*mInputLastDraftLens);
auto inputAcceptedTokensRange = BufferRange<SizeType32>(*mInputAcceptedTokens);
auto outputIds = BufferRange<TokenIdType>(*mOutputIds);
auto prompts = mNetwork.getPrompts();
for (SizeType32 bi = 0; bi < nextDraftTokens.size(); ++bi)
{
for (SizeType32 ti = 0; ti < nextDraftTokens[bi].size(); ++ti)
{
auto idx = flat_index2(bi, ti, mSamplingParams.getMaxDecodingDraftTokens());
inputNextDraftTokensRange[idx] = nextDraftTokens[bi][ti];
}
for (SizeType32 ti = 0; ti < lastDraftTokens[bi].size(); ++ti)
{
auto idx = flat_index2(bi, ti, mSamplingParams.getMaxDecodingDraftTokens());
inputLastDraftTokensRange[idx] = lastDraftTokens[bi][ti];
}
for (SizeType32 pi = 0; pi < nextDraftPaths[bi].size(); ++pi)
{
for (SizeType32 ti = 0; ti < nextDraftPaths[bi][pi].size(); ++ti)
{
auto idx
= flat_index3(bi, pi, ti, mSamplingParams.getMaxDecodingTokens(), mSamplingParams.getMaxPathLen());
inputNextDraftPathsRange[idx] = nextDraftPaths[bi][pi][ti];
}
}
for (SizeType32 pi = 0; pi < lastDraftPaths[bi].size(); ++pi)
{
for (SizeType32 ti = 0; ti < lastDraftPaths[bi][pi].size(); ++ti)
{
auto idx
= flat_index3(bi, pi, ti, mSamplingParams.getMaxDecodingTokens(), mSamplingParams.getMaxPathLen());
inputLastDraftPathsRange[idx] = lastDraftPaths[bi][pi][ti];
}
}
inputNextDraftLensRange[bi] = nextDraftLens[bi];
inputLastDraftLensRange[bi] = lastDraftLens[bi];
sequenceLength[batchSlotsPtr[bi]] = prompts[bi].size();
}
for (SizeType32 bi = 0; bi < acceptedTokens.size(); ++bi)
{
for (SizeType32 ti = 0; ti < acceptedTokens[bi].size(); ++ti)
{
auto idx = flat_index2(bi, ti, mSamplingParams.getMaxPathLen());
inputAcceptedTokensRange[idx] = acceptedTokens[bi][ti];
}
}
}
template <typename T>
std::shared_ptr<EagleInputs> EagleDecodingLayerTest<T>::createInputTensors()
{
auto forwardParams
= std::make_shared<EagleInputs>(mEndIds, mBatchSlots, mSamplingParams.getBatchSize(), mInputNextDraftTokens,
mInputNextDraftLens, mInputNextDraftPaths, mInputLastDraftTokens, mInputLastDraftLens, mInputLastDraftPaths,
mInputAcceptedTokens, mInputAcceptedLens, mInputAcceptedPathIds, mChunkedContextNextTokens, mBatchSlots);
return forwardParams;
}
template <typename T>
std::shared_ptr<EagleOutputs> EagleDecodingLayerTest<T>::createOutputTensors()
{
auto outputParams = std::make_shared<EagleOutputs>(mOutputIds);
outputParams->sequenceLength = mSeqLengths;
outputParams->unpackedNextDraftTokens = mOutputUnpackedNextDraftTokens;
outputParams->nextDraftTokens = mOutputNextDraftTokens;
outputParams->numNewTokens = mAcceptedLengths;
outputParams->nextDraftPosIds = mNextPosIds;
outputParams->prevDraftLengths = mPrevDraftLengths;
outputParams->nextDraftLengths = mNextDraftLengths;
outputParams->generationLengths = mNextGenerationLengths;
outputParams->generationLengthsHost = mNextGenerationLengthsHost;
outputParams->numNewTokensCumSum = mAcceptedLengthCumSum;
outputParams->pathsOffsets = mPathsOffsets;
outputParams->packedMasks = mPackedMasks;
outputParams->randomDataSample = mRandomDataSample;
outputParams->randomDataValidation = mRandomDataValidation;
outputParams->temperatures = mOutputTemperatures;
outputParams->nextDraftPaths = mOutputNextDraftPaths;
outputParams->eagleNetCtxRequestTypesHost = mEagleNetCtxRequestTypesHost;
outputParams->eagleNetCtxContextLengthsHost = mEagleNetCtxContextLengthsHost;
outputParams->eagleNetCtxPastKeyValueLengthsHost = mEagleNetCtxPastKeyValueLengthsHost;
outputParams->eagleNetGenRequestTypesHost = mEagleNetGenRequestTypesHost;
outputParams->eagleNetGenContextLengthsHost = mEagleNetGenContextLengthsHost;
outputParams->eagleNetGenPastKeyValueLengthsHost = mEagleNetGenPastKeyValueLengthsHost;
return outputParams;
}
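// Packs a row of booleans into 32-bit words, least significant bit first, mirroring the
// packed mask layout produced by the layer. For example, {true, false, true} packs to
// {0b101} == {5}.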
std::vector<int32_t> boolArrayToBitmask(std::vector<bool>::iterator boolIterator, size_t pathLen)
{
std::vector<int32_t> bitmask(divUp(pathLen, 32));
for (size_t bi = 0; bi < pathLen; ++bi)
{
auto slice = bi / 32;
if (boolIterator[bi])
{
bitmask[slice] |= (1 << (bi % 32));
}
}
return bitmask;
}
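// Compares all layer outputs against the reference network: sampled random data, packed
// masks, accepted tokens and sequence lengths, next draft tokens/lengths/paths, position
// ids, the accepted-length cumulative sum with its paths offsets, temperatures and the
// EagleNet host buffers.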
template <typename T>
void EagleDecodingLayerTest<T>::checkLayerResult()
{
auto const batchSlots = BufferRange<SizeType32>(*mBatchSlots);
// Check generated random data
{
auto const randomDataSample = BufferRange<float>(*mRandomDataSample);
for (SizeType32 bi = 0; bi < mSamplingParams.getBatchSize(); ++bi)
{
auto const batchSlot = batchSlots[bi];
// Check that all fields are filled with non-zero data
EXPECT_NE(randomDataSample[batchSlot], float{0}) << " bi: " << bi;
}
}
// Check masks
{
auto const randomDataValidation = BufferRange<float>(*mRandomDataValidation);
auto const packedMasks = BufferRange<int32_t>(*mPackedMasks);
auto masks = mNetwork.getNextMasks();
for (SizeType32 bi = 0; bi < mSamplingParams.getBatchSize(); ++bi)
{
for (SizeType32 ti = 0; ti < mSamplingParams.getMaxDecodingTokens(); ++ti)
{
auto const batchSlot = batchSlots[bi];
auto const bitmask = boolArrayToBitmask(masks[bi][ti].begin(), mSamplingParams.getMaxDecodingTokens());
EXPECT_NE(randomDataValidation[batchSlot * mSamplingParams.getMaxDecodingTokens() + ti], float{0})
<< " bi: " << bi;
for (SizeType32 mi = 0; mi < bitmask.size(); ++mi)
{
auto const packedMaskIdx = flat_index3(batchSlot, ti, mi, mSamplingParams.getMaxDecodingTokens(),
static_cast<SizeType32>(divUp(mSamplingParams.getMaxDecodingTokens(), 32)));
EXPECT_EQ(bitmask[mi], packedMasks[packedMaskIdx]) << " bi: " << bi << " ti: " << ti;
}
}
}
}
// Check accepted tokens
auto const outputIds = BufferRange<TokenIdType>(*mOutputIds);
auto const refOutputIds = mNetwork.getOutputIds();
auto const promptIds = mNetwork.getPrompts();
auto const seqLengths = BufferRange<SizeType32>(*mSeqLengths);
auto const acceptedLengths = BufferRange<SizeType32>(*mAcceptedLengths);
auto const inputAcceptedLens = BufferRange<SizeType32>(*mInputAcceptedLens);
for (SizeType32 bi = 0; bi < mSamplingParams.getBatchSize(); ++bi)
{
auto const batchSlot = batchSlots[bi];
// Check accepted length.
EXPECT_EQ(inputAcceptedLens[bi], acceptedLengths[batchSlot]) << " bi:" << bi;
// The updated sequence length is the prompt length plus the number of newly accepted tokens.
EXPECT_EQ(seqLengths[batchSlot], promptIds[bi].size() + acceptedLengths[batchSlot]) << " bi: " << bi;
// Check that the output ids contain the accepted tokens appended after the prompt.
for (auto ti = static_cast<SizeType32>(promptIds[bi].size());
ti < static_cast<SizeType32>(promptIds[bi].size()) + acceptedLengths[batchSlot]; ++ti)
{
EXPECT_EQ(outputIds[batchSlot * mSamplingParams.getMaxSeqLen() + ti], refOutputIds[bi][ti])
<< " bi: " << bi << " ti: " << ti;
}
}
// Check new draft tokens
{
auto const outputNextDraftTokens = BufferRange<TokenIdType>(*mOutputNextDraftTokens);
auto const outputUnpackedNextDraftTokens = BufferRange<TokenIdType>(*mOutputUnpackedNextDraftTokens);
auto const nextDraftLens = mNetwork.getNextDraftLens();
auto const prevDraftLens = mNetwork.getLastDraftLens();
auto const nextDraftTokens = mNetwork.getNextDraftTokens();
for (SizeType32 bi = 0; bi < mSamplingParams.getBatchSize(); ++bi)
{
auto const batchSlot = batchSlots[bi];
auto const nextDraftLen = nextDraftLens[bi];
auto const prevDraftLen = prevDraftLens[bi];
// Check draft tokens for the next iteration.
for (SizeType32 ti = 0; ti < nextDraftLen; ++ti)
{
auto const idx = flat_index2(batchSlot, ti, mSamplingParams.getMaxDecodingDraftTokens());
EXPECT_EQ(outputNextDraftTokens[idx], nextDraftTokens[bi][ti]) << " bi: " << bi << " ti: " << ti;
EXPECT_EQ(outputUnpackedNextDraftTokens[idx], nextDraftTokens[bi][ti])
<< " bi: " << bi << " ti: " << ti;
}
// Check length of the draft tokens.
EXPECT_EQ(BufferRange<SizeType32>(*mNextGenerationLengthsHost)[batchSlot], nextDraftLen + 1)
<< " bi: " << bi;
EXPECT_EQ(BufferRange<SizeType32>(*mNextDraftLengths)[batchSlot], nextDraftLen) << " bi: " << bi;
EXPECT_EQ(BufferRange<SizeType32>(*mPrevDraftLengths)[batchSlot], prevDraftLen) << " bi: " << bi;
for (SizeType32 pi = 0; pi < mSamplingParams.getMaxDecodingTokens(); ++pi)
{
for (SizeType32 ti = 0; ti < mSamplingParams.getMaxPathLen(); ++ti)
{
auto const idx = flat_index3(
bi, pi, ti, mSamplingParams.getMaxDecodingTokens(), mSamplingParams.getMaxPathLen());
auto const idxSlot = flat_index3(
batchSlot, pi, ti, mSamplingParams.getMaxDecodingTokens(), mSamplingParams.getMaxPathLen());
EXPECT_EQ(BufferRange<SizeType32>(*mOutputNextDraftPaths)[idxSlot],
BufferRange<SizeType32>(*mInputNextDraftPaths)[idx])
<< " bi: " << bi << " pi:" << pi << " ti: " << ti;
}
}
}
}
// Check position ids
{
auto const nextPosIds = BufferRange<SizeType32>(*mNextPosIds);
auto const nextDraftPaths = mNetwork.getNextDraftPaths();
for (SizeType32 bi = 0; bi < mSamplingParams.getBatchSize(); ++bi)
{
auto const batchSlot = batchSlots[bi];
// Check pos ids for the next iteration.
for (SizeType32 pi = 0; pi < mSamplingParams.getMaxDecodingTokens(); ++pi)
{
for (SizeType32 li = 0; li < mSamplingParams.getMaxPathLen(); ++li)
{
auto const pathIdx = nextDraftPaths[bi][pi][li];
auto const idx = flat_index2(batchSlot, pathIdx, mSamplingParams.getMaxDecodingTokens());
if (pathIdx != -1)
{
EXPECT_EQ(nextPosIds[idx], li) << " bi: " << bi << " pi: " << pi << " li: " << li;
}
}
}
}
}
// Check the cumulative sum of accepted lengths and the paths offsets
{
auto const accumulatedCumSum = BufferRange<SizeType32>(*mAcceptedLengthCumSum);
auto const pathsOffsets = BufferRange<SizeType32>(*mPathsOffsets);
auto const acceptedLengths = BufferRange<SizeType32>(*mAcceptedLengths);
auto const inputAcceptedPathIds = BufferRange<SizeType32>(*mInputAcceptedPathIds);
auto const lastDraftPaths = mNetwork.getLastDraftPaths();
SizeType32 sum = 0;
for (SizeType32 bi = 0; bi < mSamplingParams.getBatchSize(); ++bi)
{
auto const batchSlot = batchSlots[bi];
EXPECT_EQ(sum, accumulatedCumSum[bi]) << "bi: " << bi;
auto const acceptedLength = acceptedLengths[batchSlot] - 1;
for (SizeType32 ti = 0; ti < acceptedLength; ++ti)
{
EXPECT_EQ(pathsOffsets[sum + ti], lastDraftPaths[bi][inputAcceptedPathIds[bi]][ti + 1] - 1)
<< "bi: " << bi << " ti: " << ti;
}
sum += acceptedLength;
}
EXPECT_EQ(sum, accumulatedCumSum[mSamplingParams.getBatchSize()]);
}
// Check temperature
{
for (SizeType32 bi = 0; bi < mSamplingParams.getBatchSize(); ++bi)
{
auto const batchSlot = batchSlots[bi];
EXPECT_EQ(BufferRange<float>(*mOutputTemperatures)[batchSlot], static_cast<float>(mTemperatures[bi]))
<< " bi: " << bi;
}
}
// Check EagleNet host buffers
{
for (SizeType32 bi = 0; bi < mSamplingParams.getBatchSize(); ++bi)
{
auto const batchSlot = batchSlots[bi];
EXPECT_EQ(BufferRange<SizeType32>(*mEagleNetCtxRequestTypesHost)[batchSlot], 0) << " bi: " << bi;
EXPECT_EQ(BufferRange<SizeType32>(*mEagleNetGenRequestTypesHost)[batchSlot], 1) << " bi: " << bi;
EXPECT_EQ(
BufferRange<SizeType32>(*mEagleNetCtxContextLengthsHost)[batchSlot], mSamplingParams.getMaxPathLen())
<< " bi: " << bi;
EXPECT_EQ(BufferRange<SizeType32>(*mEagleNetGenContextLengthsHost)[batchSlot],
seqLengths[batchSlot] + mSamplingParams.getMaxPathLen())
<< " bi: " << bi;
EXPECT_EQ(BufferRange<SizeType32>(*mEagleNetCtxPastKeyValueLengthsHost)[batchSlot],
seqLengths[batchSlot] + mSamplingParams.getMaxPathLen())
<< " bi: " << bi;
EXPECT_EQ(BufferRange<SizeType32>(*mEagleNetGenPastKeyValueLengthsHost)[batchSlot],
seqLengths[batchSlot] + mSamplingParams.getMaxPathLen() - 1)
<< " bi: " << bi;
}
}
}
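// Drives one test case: runs the reference network on the letter-encoded inputs,
// allocates and initializes all tensors, invokes the layer's forwardAsync() and checks
// the results after synchronizing on the stream.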
template <typename T>
void EagleDecodingLayerTest<T>::runTest(std::vector<std::string> const& prompts,
std::vector<DraftLettersVec> const& predictions, std::vector<DraftLettersVec> const& nextDraftLetters,
std::vector<DraftLettersVec> const& lastDraftLetters, SamplingParams& params)
{
mSamplingParams = params;
mNetwork.forward(params, prompts, predictions, nextDraftLetters, lastDraftLetters);
allocateBuffers();
setup();
auto inputTensors = createInputTensors();
auto outputTensors = createOutputTensors();
mDecodingWorkspace->setDeviceBatchSlots(mBatchSlots);
mEagleLayer->forwardAsync(outputTensors, inputTensors, mDecodingWorkspace);
mStream->synchronize();
checkLayerResult();
}
TYPED_TEST_SUITE(EagleDecodingLayerTest, FloatAndHalfTypes);
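// The scenarios cover: identical draft trees in and out, different trees, no draft token
// accepted, and batch size 2 with per-request trees.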
TYPED_TEST(EagleDecodingLayerTest, IOSamePathsBs1)
{
SamplingParams params;
params.setBatchSize(1);
params.setMaxPathLen(4);
params.setMaxDecodingTokens(10);
std::vector<std::string> prompts = {"Hi mate, "};
std::vector<DraftLettersVec> predictionLetters = {{"how ", "hoc", "hecl", "hea", "hu"}};
std::vector<DraftLettersVec> lastDraftLetters = {{"how", "he", "wow", "we", "a"}};
std::vector<DraftLettersVec> nextDraftLetters = {{"are", "ap", "cre", "co", "i"}};
this->runTest(prompts, predictionLetters, nextDraftLetters, lastDraftLetters, params);
}
TYPED_TEST(EagleDecodingLayerTest, IODifferentPathsBs1)
{
SamplingParams params;
params.setBatchSize(1);
params.setMaxPathLen(4);
params.setMaxDecodingTokens(10);
std::vector<std::string> prompts = {"Hi mate, "};
std::vector<DraftLettersVec> predictionLetters = {{"how ", "hoc", "hecl", "hea", "hu"}};
std::vector<DraftLettersVec> lastDraftLetters = {{"how", "he", "wow", "we", "a"}};
std::vector<DraftLettersVec> nextDraftLetters = {{"are", "is", "imp", "do"}};
this->runTest(prompts, predictionLetters, nextDraftLetters, lastDraftLetters, params);
}
TYPED_TEST(EagleDecodingLayerTest, IODifferentPathsNoDraftAcceptedBs1)
{
SamplingParams params;
params.setBatchSize(1);
params.setMaxPathLen(4);
params.setMaxDecodingTokens(10);
std::vector<std::string> prompts = {"Hi mate, "};
std::vector<DraftLettersVec> predictionLetters = {{"how ", "hoc", "hecl", "hea", "hu"}};
std::vector<DraftLettersVec> lastDraftLetters = {{"my", "I'd", "wow", "we", "a"}};
std::vector<DraftLettersVec> nextDraftLetters = {{"are", "ap", "cre", "co", "i"}};
this->runTest(prompts, predictionLetters, nextDraftLetters, lastDraftLetters, params);
}
TYPED_TEST(EagleDecodingLayerTest, IODifferentPathsBs2)
{
SamplingParams params;
params.setBatchSize(2);
params.setMaxPathLen(4);
params.setMaxDecodingTokens(10);
std::vector<std::string> prompts = {"Hi mate, ", "Let's go "};
std::vector<DraftLettersVec> predictionLetters
= {{"how ", "hoc", "hecl", "hea", "hu"}, {"bcde", "bcdc", "bca", "bcc", "bo"}};
std::vector<DraftLettersVec> lastDraftLetters = {{"how", "he", "wow", "we", "a"}, {"inc", "inf", "ir", "im", "b"}};
std::vector<DraftLettersVec> nextDraftLetters = {{"are", "is", "imp", "do"}, {"wli", "mbi", "ard"}};
this->runTest(prompts, predictionLetters, nextDraftLetters, lastDraftLetters, params);
}
} // namespace tensorrt_llm::tests::layers