TensorRT-LLMs/cpp/tests/kernels/decodingKernelTest.cpp
Kaiyu Xie 75057cd036
Update TensorRT-LLM (#2333)
* Update TensorRT-LLM

---------

Co-authored-by: Puneesh Khanna <puneesh.khanna@tii.ae>
Co-authored-by: Ethan Zhang <26497102+ethnzhng@users.noreply.github.com>
2024-10-15 15:28:40 +08:00

1688 lines
73 KiB
C++

/*
* Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef TOP_LEVEL_DIR
#error "Define TOP_LEVEL_DIR"
#endif
#include <gtest/gtest.h>
#include "tensorrt_llm/common/memoryUtils.h"
#include "tensorrt_llm/kernels/decodingCommon.h"
#include "tensorrt_llm/kernels/decodingKernels.h"
#include "tensorrt_llm/kernels/speculativeDecoding/externalDraftTokensKernels.h"
#include "tensorrt_llm/kernels/speculativeDecoding/medusaDecodingKernels.h"
#include "tensorrt_llm/runtime/bufferManager.h"
#include "tensorrt_llm/runtime/runtimeKernels.h"
#include <curand_kernel.h>
#include <cmath>
#include <cstddef>
#include <random>
#include <unordered_set>
namespace tk = tensorrt_llm::kernels;
namespace tksp = tensorrt_llm::kernels::speculative_decoding;
namespace tc = tensorrt_llm::common;
namespace trk = tensorrt_llm::runtime::kernels;
using namespace tensorrt_llm::runtime;
namespace
{
/**
 * @brief Tolerance-based comparison of a value against a reference.
 *
 * Uses the same asymmetric inequality as numpy.isclose():
 *     abs(a - b) <= atol + rtol * abs(b)
 * where `b` is treated as the reference value, so swapping the arguments can
 * change the result. Absolute and relative tolerance are applied together.
 *
 * Note that the default tolerances here (atol = 1e-2, rtol = 1e-3) are much
 * looser than numpy's own defaults; only the formula is shared.
 *
 * @param a    Value under test.
 * @param b    Reference value.
 * @param atol Absolute tolerance.
 * @param rtol Relative tolerance, scaled by abs(b).
 * @return true if `a` is within tolerance of `b`, or if BOTH are NaN.
 *         A NaN on only one side compares as not equal.
 */
inline bool almostEqual(float a, float b, float atol = 1e-2, float rtol = 1e-3)
{
    // Two NaNs are considered equal; a single NaN falls through and fails
    // the inequality below because every comparison with NaN is false.
    if (std::isnan(a) && std::isnan(b))
    {
        return true;
    }
    // std:: qualification guarantees the float overloads are selected,
    // independent of which C headers happen to expose global names.
    return std::fabs(a - b) <= (atol + rtol * std::fabs(b));
}
/**
 * @brief Builds a 1-D Gaussian kernel whose taps sum to one.
 *
 * Tap i is evaluated at offset (i - size / 2) from the centre, then every tap
 * is divided by the sum of all taps.
 *
 * @param sigma Standard deviation of the Gaussian.
 * @param size  Number of taps.
 * @return The normalized kernel, `size` elements long.
 */
std::vector<float> calculateGaussianKernel(float sigma, int size)
{
    std::vector<float> weights(size);
    int const center = size / 2;

    // First pass: unnormalized weights and their running total.
    float total = 0.f;
    int offset = -center;
    for (float& weight : weights)
    {
        weight = std::exp(-0.5f * (offset * offset) / (sigma * sigma));
        total += weight;
        ++offset;
    }

    // Second pass: scale so the weights add up to one.
    for (float& weight : weights)
    {
        weight /= total;
    }
    return weights;
}
/**
 * @brief Smooths `input` with a normalized Gaussian kernel and writes the result to `result`.
 *
 * The window width is ceil(6 * sigma), bumped to the next odd number when even.
 * Samples that would fall outside [0, n) are skipped, so the output near the
 * borders is a partial (not re-normalized) sum.
 *
 * @tparam T     Element type of the output buffer; accumulation happens in T.
 * @param result Output buffer with room for `n` elements.
 * @param input  Input samples, `n` elements.
 * @param n      Number of samples.
 * @param sigma  Standard deviation of the Gaussian.
 */
template <typename T>
void applyGaussianFilter(T* result, float const* input, int n, float sigma)
{
    int width = static_cast<int>(std::ceil(6.f * sigma));
    if (width % 2 == 0)
    {
        ++width;
    }
    int const radius = width / 2;
    std::vector<float> const weights = calculateGaussianKernel(sigma, width);

    // Clear the whole output before accumulating anything.
    for (int out = 0; out < n; ++out)
    {
        result[out] = T{0};
    }

    // Convolution: visit only the in-range input samples of each window, in ascending order.
    for (int out = 0; out < n; ++out)
    {
        int const firstSrc = std::max(out - radius, 0);
        int const lastSrc = std::min(out + radius, n - 1);
        for (int src = firstSrc; src <= lastSrc; ++src)
        {
            result[out] += input[src] * weights[src - out + radius];
        }
    }
}

template void applyGaussianFilter(float* result, float const* input, int n, float sigma);
template void applyGaussianFilter(__half* result, float const* input, int n, float sigma);
template <typename T>
void probsToLogits(T const* probs, T* logits, SizeType32 n)
{
constexpr float eps = 1e-6f;
for (SizeType32 ni = 0; ni < n; ++ni)
{
auto const prob = std::max(eps, static_cast<float>(probs[ni]));
logits[ni] = std::log(prob / (1.f - prob));
}
}
/**
 * @brief Computes a max-shifted softmax over `n` logits.
 *
 * The largest logit is subtracted before exponentiation, and 1e-6 is added to
 * the denominator, so the outputs sum to slightly less than one.
 *
 * @tparam T     Element type of both buffers; the math is done in float.
 * @param logits Input logits, `n` elements.
 * @param probs  Output probabilities, `n` elements.
 * @param n      Number of elements.
 */
template <typename T>
void softmax(T const* logits, T* probs, int n)
{
    constexpr float kDenomEps = 1e-6f;
    auto const logitAt = [logits](int idx) { return static_cast<float>(logits[idx]); };

    // Largest logit, used as the shift.
    float peak = -std::numeric_limits<float>::max();
    for (int idx = 0; idx < n; ++idx)
    {
        peak = std::max(peak, logitAt(idx));
    }

    // Sum of shifted exponentials.
    float total = 0.0f;
    for (int idx = 0; idx < n; ++idx)
    {
        total += std::exp(logitAt(idx) - peak);
    }

    // Normalize each shifted exponential by the padded sum.
    float const denom = total + kDenomEps;
    for (int idx = 0; idx < n; ++idx)
    {
        float const value = std::exp(logitAt(idx) - peak) / denom;
        probs[idx] = value;
    }
}
// Explicit instantiations of probsToLogits for the float and __half element types.
template void probsToLogits(float const* probs, float* logits, SizeType32 n);
template void probsToLogits(__half const* probs, __half* logits, SizeType32 n);
/**
 * @brief Expects two tensors to hold element-wise identical values of type T.
 *
 * Both tensors are copied to pinned-pool host buffers through `bufferManager`,
 * the manager's stream is synchronized, and every element is compared with
 * EXPECT_EQ. A size mismatch is reported as a fatal failure before any element
 * is read, so the destination is never indexed out of bounds.
 *
 * @tparam T          Element type used to interpret both buffers.
 * @param src         Reference tensor.
 * @param dst         Tensor expected to equal `src`.
 * @param bufferName  Name used in failure messages.
 * @param bufferManager Manager used for the host copies and stream synchronization.
 */
template <typename T>
void checkEquality(DecodingOutput::TensorPtr src, DecodingOutput::TensorPtr dst, char const* bufferName,
    tensorrt_llm::runtime::BufferManager& bufferManager)
{
    // Guard the element loop: it indexes both buffers with src's element count.
    ASSERT_EQ(src->getSize(), dst->getSize()) << "Size mismatch in buffer " << bufferName;

    auto srcHost = bufferManager.copyFrom(*src, MemoryType::kPINNEDPOOL);
    auto dstHost = bufferManager.copyFrom(*dst, MemoryType::kPINNEDPOOL);
    // The copies are issued on the manager's stream; wait before touching host memory.
    bufferManager.getStream().synchronize();

    auto srcPtr = bufferCast<T>(*srcHost);
    auto dstPtr = bufferCast<T>(*dstHost);
    auto const numElems = src->getSize();
    for (std::size_t ii = 0; ii < numElems; ++ii)
    {
        // since it's a simple copy, floats support the simple equality
        EXPECT_EQ(srcPtr[ii], dstPtr[ii]) << "Unequal values in buffer " << bufferName << " at ii: " << ii
                                          << " with values: src " << srcPtr[ii] << " dst " << dstPtr[ii] << '\n';
    }
}
/**
 * @brief Fills `buffer` with pseudo-random values drawn uniformly from [0, 255] and cast to T.
 *
 * The values are generated into a CPU staging buffer of the same shape and then
 * copied into `buffer` through `bufferManager`.
 *
 * @tparam T            Element type of `buffer`.
 * @param buffer        Destination tensor.
 * @param bufferManager Manager used to allocate the staging buffer and perform the copy.
 * @param randGen       Random engine; advanced once per element.
 */
template <typename T>
void fillBufferWithRandom(ITensor& buffer, tensorrt_llm::runtime::BufferManager& bufferManager, std::mt19937& randGen)
{
    auto hostStaging = bufferManager.cpu(buffer.getShape(), TRTDataType<T>::value);
    auto const numElems = hostStaging->getSize();
    auto* const hostData = bufferCast<T>(*hostStaging);

    std::uniform_int_distribution<> byteValue(0, 255);
    for (auto* cursor = hostData; cursor != hostData + numElems; ++cursor)
    {
        *cursor = static_cast<T>(byteValue(randGen));
    }

    bufferManager.copy(*hostStaging, buffer);
}
class TestBeamHypothesesCopy : public ::testing::Test
{
public:
DecodingOutput::BeamHypotheses srcBeams;
DecodingOutput::BeamHypotheses dstBeams;
DecodingOutput::TensorPtr mSrcCumLogProbs;
DecodingOutput::TensorPtr mDstCumLogProbs;
SizeType32 mNumSMs;
std::shared_ptr<tensorrt_llm::runtime::CudaStream> mStream;
std::shared_ptr<tensorrt_llm::runtime::BufferManager> mBufferManager;
std::mt19937 gen;
void SetUp() override
{
mStream = std::make_shared<tensorrt_llm::runtime::CudaStream>();
mBufferManager = std::make_shared<tensorrt_llm::runtime::BufferManager>(mStream);
int device;
cudaGetDevice(&device);
cudaDeviceProp deviceProp;
cudaGetDeviceProperties(&deviceProp, device);
mNumSMs = deviceProp.multiProcessorCount;
gen.seed(42U);
}
void initializeBuffers(SizeType32 batchSize, SizeType32 beamWidth, SizeType32 maxSeqLen)
{
srcBeams.empty(*mBufferManager);
srcBeams.reshape(batchSize, beamWidth, maxSeqLen);
mSrcCumLogProbs = mBufferManager->gpu(ITensor::makeShape({batchSize, beamWidth}), nvinfer1::DataType::kFLOAT);
setBuffers(srcBeams, mSrcCumLogProbs, 2);
dstBeams.empty(*mBufferManager);
dstBeams.reshape(batchSize, beamWidth, maxSeqLen);
mDstCumLogProbs = mBufferManager->gpu(ITensor::makeShape({batchSize, beamWidth}), nvinfer1::DataType::kFLOAT);
setBuffers(dstBeams, mDstCumLogProbs, 1);
}
void setBuffers(DecodingOutput::BeamHypotheses currBeams, DecodingOutput::TensorPtr cumLogProbs, int value)
{
fillBufferWithRandom<SizeType32>(*currBeams.outputIdsCBA, *mBufferManager, gen);
fillBufferWithRandom<float>(*currBeams.logProbsCBA, *mBufferManager, gen);
fillBufferWithRandom<SizeType32>(*currBeams.sequenceLengthsCBA, *mBufferManager, gen);
fillBufferWithRandom<float>(*currBeams.cumLogProbsCBA, *mBufferManager, gen);
fillBufferWithRandom<float>(*currBeams.normedScoresCBA, *mBufferManager, gen);
fillBufferWithRandom<SizeType32>(*currBeams.numBeamsCBA, *mBufferManager, gen);
fillBufferWithRandom<float>(*currBeams.minNormedScoresCBA, *mBufferManager, gen);
fillBufferWithRandom<bool>(*currBeams.batchDones, *mBufferManager, gen);
fillBufferWithRandom<float>(*cumLogProbs, *mBufferManager, gen);
}
void checkAllEqual()
{
checkEquality<SizeType32>(srcBeams.outputIdsCBA, dstBeams.outputIdsCBA, "outputIdsCBA", *mBufferManager);
checkEquality<float>(srcBeams.logProbsCBA, dstBeams.logProbsCBA, "logProbsCBA", *mBufferManager);
checkEquality<SizeType32>(
srcBeams.sequenceLengthsCBA, dstBeams.sequenceLengthsCBA, "sequenceLengthsCBA", *mBufferManager);
checkEquality<float>(srcBeams.cumLogProbsCBA, dstBeams.cumLogProbsCBA, "cumLogProbsCBA", *mBufferManager);
checkEquality<float>(srcBeams.normedScoresCBA, dstBeams.normedScoresCBA, "normedScoresCBA", *mBufferManager);
checkEquality<SizeType32>(srcBeams.numBeamsCBA, dstBeams.numBeamsCBA, "numBeamsCBA", *mBufferManager);
checkEquality<float>(
srcBeams.minNormedScoresCBA, dstBeams.minNormedScoresCBA, "minNormedScoresCBA", *mBufferManager);
checkEquality<bool>(srcBeams.batchDones, dstBeams.batchDones, "batchDones", *mBufferManager);
checkEquality<float>(mSrcCumLogProbs, mDstCumLogProbs, "cumLogProbs", *mBufferManager);
}
};
// invokeCopyBeamHypotheses with a large batch: after the copy every destination
// buffer must equal its source.
TEST_F(TestBeamHypothesesCopy, FullBatchTest)
{
    constexpr SizeType32 kBatchSize = 1024;
    constexpr SizeType32 kBeamWidth = 64;
    constexpr SizeType32 kMaxSeqLen = 2048;

    initializeBuffers(kBatchSize, kBeamWidth, kMaxSeqLen);
    // Make sure the random fill has finished before launching the copy.
    mStream->synchronize();

    tk::invokeCopyBeamHypotheses(srcBeams, dstBeams, *mSrcCumLogProbs, *mDstCumLogProbs, *mStream, mNumSMs);
    mStream->synchronize();

    checkAllEqual();
}
// invokeCopyBeamHypotheses with a single batch entry and a long sequence: after
// the copy every destination buffer must equal its source.
TEST_F(TestBeamHypothesesCopy, SingleBatchTest)
{
    constexpr SizeType32 kBatchSize = 1;
    constexpr SizeType32 kBeamWidth = 64;
    constexpr SizeType32 kMaxSeqLen = 16384;

    initializeBuffers(kBatchSize, kBeamWidth, kMaxSeqLen);
    // Make sure the random fill has finished before launching the copy.
    mStream->synchronize();

    tk::invokeCopyBeamHypotheses(srcBeams, dstBeams, *mSrcCumLogProbs, *mDstCumLogProbs, *mStream, mNumSMs);
    mStream->synchronize();

    checkAllEqual();
}
/**
* @brief Fills a slice of a tensor with data from a source array.
*
* This function writes to `tensor` from source array `src` at index `idx.
* It optionally flattens the tensor before performing the insertion.
* For example tensor if we wanted to write 5 values in the 3rd row of [1,10,100]
* We will use (tensor, 2, 5, src, true, mBufferManager) where src is a buffer with at least 5 elems.
*
* @tparam T The type of elements in the source array.
* @param tensor A shared pointer to the tensor to be modified. Also need to be of type T.
* @param idx The index at which to start inserting data into the tensor.
* @param insertLen The number of elements to insert from the source array into the tensor.
* @param src An array containing the data to be inserted into the tensor.
* @param flattenFirst A boolean flag indicating whether to flatten the first dimension of the tensor before insertion.
* @param bufferManager A shared pointer to a BufferManager responsible for managing memory operations.
*/
template <typename T>
void fillTensorAtIndex(ITensor::SharedPtr tensor, SizeType32 idx, std::vector<T> src, bool flattenFirst,
std::shared_ptr<tensorrt_llm::runtime::BufferManager> bufferManager)
{
SizeType32 insertLen = src.size();
ITensor::SharedPtr target = ITensor::view(tensor);
if (flattenFirst)
{
target->squeeze(0);
}
target = ITensor::slice(target, idx, 1);
target->squeeze(0);
target = ITensor::slice(target, 0, insertLen);
bufferManager->copy(src.data(), *target);
}
/**
 * Fixture for the gatherTree tests.
 *
 * createBuffers() allocates and zeroes the DecodingInput/DecodingOutput tensors
 * for batchSize x beamWidth x maxSeqLen, one of the hardcodeBuffersLen*() methods
 * loads a fixed scenario together with the expected result (`targetOut`), and
 * checkResult() compares `decodingOutput->gatheredIds` against `targetOut`.
 */
class TestGatherTree : public ::testing::Test
{
public:
    SizeType32 batchSize{1};
    SizeType32 beamWidth{5};
    SizeType32 maxSeqLen{20};

    using TensorPtr = ITensor::SharedPtr;

    using DecodingOutputPtr = std::unique_ptr<DecodingOutput>;
    DecodingOutputPtr decodingOutput{nullptr};

    // NOTE: not referenced by anything visible in this fixture; mSamplingConfig is the one passed to gatherTree.
    SamplingConfig samplingConfig = SamplingConfig();

    std::shared_ptr<tensorrt_llm::runtime::CudaStream> mStream{nullptr};
    std::shared_ptr<tensorrt_llm::runtime::BufferManager> mBufferManager{nullptr};

    SamplingConfig mSamplingConfig;

    using DecodingInputPtr = std::unique_ptr<DecodingInput>;
    DecodingInputPtr decodingInput{nullptr};

    // Expected content of decodingOutput->gatheredIds.
    TensorPtr targetOut{nullptr};

    void SetUp() override
    {
        mStream = std::make_shared<tensorrt_llm::runtime::CudaStream>();
        mBufferManager = std::make_shared<tensorrt_llm::runtime::BufferManager>(mStream);
    }

    // create the empty buffers with the correct shapes and zero them
    void createBuffers()
    {
        auto constexpr nvTokenIdType = TRTDataType<TokenIdType>::value;
        auto constexpr nvSizeType = TRTDataType<SizeType32>::value;
        auto constexpr nvFloatType = TRTDataType<float>::value;

        auto const maxBatchSizeShape = ITensor::makeShape({batchSize});
        auto const maxBatchSizeXmaxBeamWidth = ITensor::makeShape({batchSize, beamWidth});
        auto const jointOutputIdsShape = ITensor::makeShape({batchSize, beamWidth, maxSeqLen});

        { // prevent reusing these vars after std::move
            auto dummyLogits = mBufferManager->emptyTensor(MemoryType::kGPU, nvFloatType);
            auto endIds = mBufferManager->emptyTensor(MemoryType::kGPU, nvTokenIdType);
            auto batchSlots = mBufferManager->emptyTensor(MemoryType::kPINNED, nvSizeType);
            decodingInput = std::make_unique<DecodingInput>(
                0, 0, 0, 0, std::move(dummyLogits), std::move(endIds), std::move(batchSlots));
        }

        auto& dInput = *decodingInput;
        dInput.maxLength = maxSeqLen;

        const_cast<ITensor&>(*dInput.endIds).reshape(maxBatchSizeShape);
        const_cast<ITensor&>(*dInput.batchSlots).reshape(maxBatchSizeShape);

        // `lengths` is assigned here for the first time; it must not be
        // dereferenced before this point.
        dInput.lengths = mBufferManager->gpu(maxBatchSizeXmaxBeamWidth, nvSizeType);
        mBufferManager->setZero(const_cast<ITensor&>(*dInput.lengths));

        { // prevent reusing these vars after std::move
            auto ids = mBufferManager->gpu(jointOutputIdsShape, nvTokenIdType);
            mBufferManager->setZero(*ids);
            auto gatheredIds = mBufferManager->gpu(jointOutputIdsShape, nvTokenIdType);
            mBufferManager->setZero(*gatheredIds);
            decodingOutput = std::make_unique<DecodingOutput>(std::move(ids), std::move(gatheredIds));
        }

        auto& dOutput = *decodingOutput;

        dOutput.logProbs = mBufferManager->gpu(jointOutputIdsShape, nvFloatType);
        mBufferManager->setZero(*dOutput.logProbs);

        dOutput.logProbsTiled = mBufferManager->gpu(ITensor::makeShape({maxSeqLen, batchSize, beamWidth}), nvFloatType);
        mBufferManager->setZero(*dOutput.logProbsTiled);

        dOutput.lengths = mBufferManager->gpu(ITensor::makeShape({batchSize, beamWidth}), nvSizeType);
        mBufferManager->setZero(*dOutput.lengths);

        dOutput.cumLogProbs = mBufferManager->gpu(maxBatchSizeXmaxBeamWidth, nvFloatType);
        mBufferManager->setZero(*dOutput.cumLogProbs);

        dOutput.beamHypotheses.empty(*mBufferManager);
        dOutput.beamHypotheses.reshape(batchSize, beamWidth, maxSeqLen);

        dOutput.finishReasons
            = mBufferManager->gpu(maxBatchSizeXmaxBeamWidth, TRTDataType<tk::FinishedState::UnderlyingType>::value);
        mBufferManager->setZero(*dOutput.finishReasons);

        dOutput.parentIds = mBufferManager->gpu(jointOutputIdsShape, nvTokenIdType);
        mBufferManager->setZero(*dOutput.parentIds);

        targetOut = mBufferManager->gpu(jointOutputIdsShape, nvTokenIdType);
        mBufferManager->setZero(*targetOut);
    }

    // clang-format off
    // hardcode the input data for the output_len = 10 case
    // this should not cause any beam swapping from the CBAs, just reorder the beams
    void hardcodeBuffersLen10()
    {
        hardcodeInputLengths();

        std::vector<std::vector<float>> logProbs =
        {
            {-2.96689, -1.63675, -2.31329, -0.0377979, -2.2442, -1.57552, -0.310524, -0.696636, -2.41985},
            {-2.96689, -1.63675, -2.31329, -0.0377979, -1.31451, -2.63339, -0.534199, -0.493615, -2.61479},
            {-2.96689, -1.63675, -2.31329, -0.0377979, -2.2442, -1.57552, -0.310524, -3.11851, -1.01671},
            {-2.96689, -1.63675, -2.31329, -0.0377979, -1.31451, -2.63339, -0.534199, 0, 0},
            {-2.96689, -1.63675, -2.31329, -0.0377979, -2.2442, -1.57552, -0.310524, -0.696636, -3.62298}
        };
        fillRows(decodingOutput->logProbs, logProbs);

        std::vector<std::vector<float>> logProbsTiled =
        {
            {-2.70907, -2.96689, -3.27157, -3.37314, -3.50595},
            {-1.84733, -1.8942, -1.63675, -1.9567, -1.47513},
            {-0.305059, -0.765237, -2.31329, -2.37162, -2.48475},
            {-1.97517, -0.0377979, -2.0169, -2.42439, -2.27471},
            {-1.31451, -2.2442, -1.5831, -2.44732, -2.02409},
            {-1.57552, -2.63339, -2.11286, -2.57304, -3.85214},
            {-0.310524, -0.534199, -0.74379, -2.86232, -1.72914},
            {-0.696636, -0.493615, -0.237725, -3.07164, -3.11851},
            {-2.41985, -2.61479, -1.01671, -3.62298, -1.26586},
            {-0.844337, -0.922832, -0.427682, -0.419985, -1.85996}
        };
        fillLogProbsTiled(logProbsTiled);

        std::vector<SizeType32> outputLengths = {13, 13, 13, 13, 13};
        mBufferManager->copy(outputLengths.data(), *decodingOutput->lengths);

        std::vector<float> cumLogProbs = {-15.0458, -15.4681, -15.8323, -15.8424, -16.0614};
        mBufferManager->copy(cumLogProbs.data(), *decodingOutput->cumLogProbs);

        hardcodeSharedCBABuffers();

        std::vector<std::vector<TokenIdType>> ids =
        {
            {1, 864, 304, 1073, 825, 1048, 278, 278, 3815, 29973, 13, 4806, 526},
            {1, 864, 304, 367, 920, 304, 310, 1749, 3815, 29973, 13, 4806, 526},
            {1, 864, 304, 679, 263, 760, 679, 263, 29973, 13, 310, 526, 502},
            {1, 864, 304, 1207, 901, 278, 1749, 445, 3889, 393, 591, 13443, 276},
            {1, 864, 304, 1074, 263, 29973, 1207, 263, 2446, 12623, 1334, 29915, 30010}
        };
        fillRows(decodingOutput->ids, ids);

        std::vector<std::vector<SizeType32>> parentIds =
        {
            {0, 0, 0, 0, 0, 3, 0, 1, 1, 0, 0, 0, 0},
            {0, 0, 0, 0, 0, 1, 2, 1, 0, 1, 1, 1, 1},
            {0, 0, 0, 0, 1, 2, 1, 4, 3, 2, 4, 4, 3},
            {0, 0, 0, 0, 0, 0, 0, 1, 4, 1, 0, 0, 4},
            {0, 0, 0, 0, 3, 3, 1, 2, 0, 4, 0, 3, 0}
        };
        fillRows(decodingOutput->parentIds, parentIds);

        std::vector<std::vector<TokenIdType>> targetOutput =
        {
            {1, 864, 304, 367, 263, 760, 310, 1749, 3815, 29973, 13, 4806, 526},
            {1, 864, 304, 367, 263, 760, 310, 278, 3815, 29973, 13, 4806, 526},
            {1, 864, 304, 367, 263, 760, 310, 1749, 3815, 29973, 13, 13443, 502},
            {1, 864, 304, 367, 263, 760, 310, 1749, 3815, 29973, 591, 29915, 276},
            {1, 864, 304, 367, 263, 760, 310, 1749, 3815, 29973, 13, 4806, 30010}
        };
        fillRows(targetOut, targetOutput);
    }

    // this case has the output_len = 8, and tests that the beams from the CBAs are correctly swapped.
    void hardcodeBuffersLen8()
    {
        hardcodeInputLengths();

        std::vector<std::vector<float>> logProbs =
        {
            {-2.96689, -1.63675, -2.31329, -0.0377979, -2.2442, -1.57552, -0.310524},
            {-2.96689, -1.63675, -2.31329, -0.0377979, -1.31451, -2.63339, -0.534199},
            {-2.96689, -1.63675, -2.31329, -0.0377979, -2.44732, -2.11286, -0.74379},
            {-2.96689, -1.63675, -2.31329, -0.0377979, -1.31451, -2.63339, -2.86232},
            {-2.96689, -1.63675, -2.31329, -0.0377979, -1.31451, -3.85214, -1.72914}
        };
        fillRows(decodingOutput->logProbs, logProbs);

        std::vector<std::vector<float>> logProbsTiled =
        {
            {-2.70907, -2.96689, -3.27157, -3.37314, -3.50595},
            {-1.84733, -1.8942, -1.63675, -1.9567, -1.47513},
            {-0.305059, -0.765237, -2.31329, -2.37162, -2.48475},
            {-1.97517, -0.0377979, -2.0169, -2.42439, -2.27471},
            {-1.31451, -2.2442, -1.5831, -2.44732, -2.02409},
            {-1.57552, -2.63339, -2.11286, -2.57304, -3.85214},
            {-0.310524, -0.534199, -0.74379, -2.86232, -1.72914},
            {-0.696636, -0.493615, -0.237725, -3.07164, -3.11851}
        };
        fillLogProbsTiled(logProbsTiled);

        std::vector<SizeType32> outputLengths = {11, 11, 11, 11, 11};
        mBufferManager->copy(outputLengths.data(), *decodingOutput->lengths);

        std::vector<float> cumLogProbs = {-11.7816, -11.9304, -14.0883, -14.1566, -14.2035};
        mBufferManager->copy(cumLogProbs.data(), *decodingOutput->cumLogProbs);

        hardcodeSharedCBABuffers();

        std::vector<std::vector<TokenIdType>> ids =
        {
            {1, 864, 304, 1073, 825, 1048, 278, 278, 3815, 29973, 13},
            {1, 864, 304, 367, 920, 304, 310, 1749, 3815, 29973, 13},
            {1, 864, 304, 679, 263, 760, 679, 263, 29973, 13, 310},
            {1, 864, 304, 1207, 901, 278, 1749, 445, 3889, 393, 591},
            {1, 864, 304, 1074, 263, 29973, 1207, 263, 2446, 12623, 1334}
        };
        fillRows(decodingOutput->ids, ids);

        std::vector<std::vector<SizeType32>> parentIds =
        {
            {0, 0, 0, 0, 0, 3, 0, 1, 1, 0, 0},
            {0, 0, 0, 0, 0, 1, 2, 1, 0, 1, 1},
            {0, 0, 0, 0, 1, 2, 1, 4, 3, 2, 4},
            {0, 0, 0, 0, 0, 0, 0, 1, 4, 1, 0},
            {0, 0, 0, 0, 3, 3, 1, 2, 0, 4, 0}
        };
        fillRows(decodingOutput->parentIds, parentIds);

        std::vector<std::vector<TokenIdType>> targetOutput =
        {
            {1, 864, 304, 367, 263, 760, 310, 1749, 3815, 29973, 13},
            {1, 864, 304, 367, 263, 760, 310, 278, 3815, 29973, 13},
            {1, 864, 304, 367, 263, 760, 310, 278, 3815, 29973, 0},
            {1, 864, 304, 367, 263, 760, 310, 1749, 3815, 29973, 0},
            {1, 864, 304, 367, 263, 760, 310, 278, 2446, 12623, 310}
        };
        fillRows(targetOut, targetOutput);
    }
    // clang-format on

    // Compares decodingOutput->gatheredIds with targetOut element by element,
    // logs every mismatch and returns true only if all elements match.
    bool checkResult()
    {
        TensorPtr reference = this->mBufferManager->copyFrom((*targetOut), tensorrt_llm::runtime::MemoryType::kCPU);
        TensorPtr real
            = this->mBufferManager->copyFrom((*decodingOutput->gatheredIds), tensorrt_llm::runtime::MemoryType::kCPU);
        // Wait for the copies issued through the buffer manager before reading host memory.
        mStream->synchronize();

        auto referencePtr = bufferCast<TokenIdType>(*reference);
        auto realPtr = bufferCast<TokenIdType>(*real);

        bool allEqual = true;
        for (SizeType32 iAssert = 0; iAssert < batchSize * beamWidth * maxSeqLen; iAssert++)
        {
            if (referencePtr[iAssert] != realPtr[iAssert])
            {
                TLLM_LOG_ERROR("Mismatch input value. Position of inputs: %d, expected value: %d, output value: %d",
                    iAssert, referencePtr[iAssert], realPtr[iAssert]);
                allEqual = false;
            }
        }
        return allEqual;
    }

private:
    // Writes each vector of `rows` to the start of the row with the same index in `tensor`.
    template <typename T>
    void fillRows(TensorPtr const& tensor, std::vector<std::vector<T>> const& rows)
    {
        for (SizeType32 row = 0; row < static_cast<SizeType32>(rows.size()); ++row)
        {
            fillTensorAtIndex(tensor, row, rows[row], true, mBufferManager);
        }
    }

    // Sets the input length of all five beams of the first batch entry to 3.
    void hardcodeInputLengths()
    {
        std::vector<SizeType32> len = {3, 3, 3, 3, 3};
        TensorPtr inputLengths{ITensor::slice(constPointerCast(decodingInput->lengths), 0, 1)};
        mBufferManager->copy(len.data(), *inputLengths);
    }

    // Copies `rows` into logProbsTiled viewed as [maxSeqLen * batchSize, beamWidth], starting at row 3
    // (presumably the first generated step after the hard-coded input length of 3 -- confirm).
    void fillLogProbsTiled(std::vector<std::vector<float>> const& rows)
    {
        TensorPtr tiledView = ITensor::view(
            decodingOutput->logProbsTiled, ITensor::makeShape({maxSeqLen * batchSize, beamWidth}));
        for (SizeType32 row = 0; row < static_cast<SizeType32>(rows.size()); ++row)
        {
            auto rowSlice = ITensor::slice(tiledView, row + 3, 1);
            mBufferManager->copy(rows[row].data(), *rowSlice);
        }
    }

    // clang-format off
    // Loads the beamHypotheses (CBA) buffers and finishReasons; this data is identical in both scenarios.
    void hardcodeSharedCBABuffers()
    {
        std::vector<std::vector<TokenIdType>> outputIdsCBA =
        {
            {1, 864, 304, 367, 263, 760, 310, 278, 3815, 29973},
            {1, 864, 304, 367, 263, 760, 310, 1749, 3815, 29973}
        };
        fillRows(decodingOutput->beamHypotheses.outputIdsCBA, outputIdsCBA);

        std::vector<std::vector<float>> logProbsCBA =
        {
            {0, 0, 0, -2.96689, -1.63675, -2.31329, -0.0377979, -1.31451, -2.63339, -0.534199, -2.19674},
            {0, 0, 0, -2.96689, -1.63675, -2.31329, -0.0377979, -2.2442, -1.57552, -0.310524, -2.81382}
        };
        fillRows(decodingOutput->beamHypotheses.logProbsCBA, logProbsCBA);

        std::vector<SizeType32> sequenceLengthsCBA = {10, 10, 0, 0, 0, 0, 0, 0, 0, 0};
        mBufferManager->copy(sequenceLengthsCBA.data(), *decodingOutput->beamHypotheses.sequenceLengthsCBA);

        std::vector<float> cumLogProbsCBA = {-13.6336, -13.8988, 0, 0, 0, 0, 0, 0, 0, 0};
        mBufferManager->copy(cumLogProbsCBA.data(), *decodingOutput->beamHypotheses.cumLogProbsCBA);

        std::vector<float> normedScoresCBA = {-1.7042, -1.73735, 0, 0, 0, 0, 0, 0, 0, 0};
        mBufferManager->copy(normedScoresCBA.data(), *decodingOutput->beamHypotheses.normedScoresCBA);

        std::vector<SizeType32> numBeamsCBA = {2};
        mBufferManager->copy(numBeamsCBA.data(), *decodingOutput->beamHypotheses.numBeamsCBA);

        std::vector<float> minNormedScoresCBA = {-1.73735};
        mBufferManager->copy(minNormedScoresCBA.data(), *decodingOutput->beamHypotheses.minNormedScoresCBA);

        std::vector<SizeType32> batchDones = {0};
        mBufferManager->copy(batchDones.data(), *decodingOutput->beamHypotheses.batchDones);

        std::vector<uint8_t> finishReasons = {4, 4, 4, 4, 4};
        mBufferManager->copy(finishReasons.data(), *decodingOutput->finishReasons);
    }
    // clang-format on
};
// gatherTree on the output_len = 10 scenario: gatheredIds must match the hard-coded target.
TEST_F(TestGatherTree, GatherTreeNoSwap)
{
    // Arrange: zeroed buffers, then the hard-coded scenario.
    createBuffers();
    hardcodeBuffersLen10();
    cudaDeviceSynchronize();

    // Act.
    trk::gatherTree(*decodingOutput, *decodingInput, *mBufferManager, mSamplingConfig);
    cudaDeviceSynchronize();

    // Assert.
    EXPECT_TRUE(checkResult());
}
// gatherTree on the output_len = 8 scenario: gatheredIds must match the hard-coded target.
TEST_F(TestGatherTree, GatherTreeWithSwap)
{
    // Arrange: zeroed buffers, then the hard-coded scenario.
    createBuffers();
    hardcodeBuffersLen8();
    cudaDeviceSynchronize();

    // Act.
    trk::gatherTree(*decodingOutput, *decodingInput, *mBufferManager, mSamplingConfig);
    cudaDeviceSynchronize();

    // Assert.
    EXPECT_TRUE(checkResult());
}
// Selects which acceptance scenario a DecodingKernelsTest run is configured for.
// The mode decides which optional buffers DecodingKernelsTest::createBuffers() allocates.
enum AcceptKernelMode
{
// Default mode of DecodingKernelTestParam; no mode-specific buffers are allocated.
BY_IDS,
// Additionally allocates the draft/target logits and probability buffers.
BY_LOGITS,
// Additionally allocates the Medusa logits pointer, tokens-per-step and best-path buffers.
BY_IDS_WITH_PATH
};
// Parameter bundle for the decoding kernel tests. Every setter returns *this so
// that calls can be chained.
struct DecodingKernelTestParam
{
    SizeType32 mBatchSize{128};
    // Always kept at twice the batch size (see setBatchSize).
    SizeType32 mMaxBatchSize{2 * mBatchSize};
    SizeType32 mBeamWidth{1};
    SizeType32 mMaxSeqLen{16};
    SizeType32 mVocabSize{32};
    SizeType32 mMaxDraftTokens{8};
    SizeType32 mMaxNumHeads{0};
    SizeType32 mMaxDraftSeqPerStep{1};
    AcceptKernelMode mAcceptMode{AcceptKernelMode::BY_IDS};

    // Sets the batch size and updates the max batch size to twice that value.
    DecodingKernelTestParam& setBatchSize(SizeType32 bs)
    {
        mBatchSize = bs;
        mMaxBatchSize = 2 * bs;
        return *this;
    }

    // Sets the vocabulary size.
    DecodingKernelTestParam& setVocabSize(SizeType32 vs)
    {
        mVocabSize = vs;
        return *this;
    }

    // Sets the maximum sequence length.
    DecodingKernelTestParam& setMaxSeqLen(SizeType32 msl)
    {
        mMaxSeqLen = msl;
        return *this;
    }

    // Sets the maximum number of draft tokens.
    DecodingKernelTestParam& setMaxDraftTokens(SizeType32 dt)
    {
        mMaxDraftTokens = dt;
        return *this;
    }

    // Sets the maximum number of heads.
    DecodingKernelTestParam& setMaxNumHeads(SizeType32 mnh)
    {
        mMaxNumHeads = mnh;
        return *this;
    }

    // Sets the maximum number of draft sequences per step.
    DecodingKernelTestParam& setMaxDraftSeqPerStep(SizeType32 tps)
    {
        mMaxDraftSeqPerStep = tps;
        return *this;
    }

    // Sets the acceptance mode.
    DecodingKernelTestParam& setAcceptMode(AcceptKernelMode const& mode)
    {
        mAcceptMode = mode;
        return *this;
    }
};
template <typename T>
class DecodingKernelsTest : public testing::Test
{
public:
using TensorPtr = tensorrt_llm::runtime::ITensor::SharedPtr;
void SetUp() override
{
mStream = std::make_shared<tensorrt_llm::runtime::CudaStream>();
mBufferManager = std::make_shared<tensorrt_llm::runtime::BufferManager>(mStream);
}
void TearDown() override {}
void createBuffers()
{
auto const dataType = TRTDataType<T>::value;
auto const ptrType = TRTDataType<T*>::value;
mDraftTokens = mBufferManager->pinnedPool(
ITensor::makeShape({mMaxBatchSize, mMaxDraftSeqlen}), nvinfer1::DataType::kINT32);
mTargetTokens = mBufferManager->pinnedPool(
ITensor::makeShape({mMaxBatchSize, mMaxTargetSeqlen}), nvinfer1::DataType::kINT32);
mOutputTokens
= mBufferManager->pinnedPool(ITensor::makeShape({mMaxBatchSize, mMaxSeqLen}), nvinfer1::DataType::kINT32);
mNumsDraftTokens = mBufferManager->pinnedPool(
ITensor::makeShape({mMaxBatchSize, mMaxDraftSeqPerStep}), nvinfer1::DataType::kINT32);
mSequenceLengths = mBufferManager->pinnedPool(ITensor::makeShape({mMaxBatchSize}), nvinfer1::DataType::kINT32);
mAcceptedLengths = mBufferManager->pinnedPool(ITensor::makeShape({mMaxBatchSize}), nvinfer1::DataType::kINT32);
mContextLengths = mBufferManager->pinnedPool(ITensor::makeShape({mMaxBatchSize}), nvinfer1::DataType::kINT32);
mFinishedSteps = mBufferManager->pinnedPool(ITensor::makeShape({mMaxDraftTokens + 1, mMaxBatchSize}),
TRTDataType<tk::FinishedState::UnderlyingType>::value);
mFinishedFinal = mBufferManager->pinnedPool(
ITensor::makeShape({mMaxBatchSize}), TRTDataType<tk::FinishedState::UnderlyingType>::value);
mFinishedSum = mBufferManager->pinnedPool(ITensor::makeShape({mMaxBatchSize}), nvinfer1::DataType::kINT32);
mPaths = mBufferManager->pinnedPool(
ITensor::makeShape({mMaxBatchSize, mMaxDraftSeqPerStep, mMaxDraftTokens}), nvinfer1::DataType::kINT32);
mEndIds = mBufferManager->pinnedPool(ITensor::makeShape({mMaxBatchSize}), nvinfer1::DataType::kINT32);
mBatchSlots = mBufferManager->pinnedPool(ITensor::makeShape({mMaxBatchSize}), nvinfer1::DataType::kINT32);
auto batchSlotsRange = BufferRange<SizeType32>(*mBatchSlots);
std::iota(batchSlotsRange.begin(), batchSlotsRange.end(), 0);
mCurandStates = mBufferManager->gpu(
ITensor::makeShape({mMaxBatchSize, sizeof(curandState_t)}), nvinfer1::DataType::kINT8);
mAcceptedLen.resize(mMaxBatchSize);
mOutputLen.resize(mMaxBatchSize);
mAcceptedFinished.resize(mMaxBatchSize, tk::FinishedState::empty());
// Buffers only for Logits comparison
if (mAcceptMode == AcceptKernelMode::BY_LOGITS)
{
mDraftLogits = mBufferManager->pinnedPool(
ITensor::makeShape({mMaxBatchSize, mMaxTotalDraftTokens, mVocabSize}), dataType);
mTargetLogits = mBufferManager->pinnedPool(
ITensor::makeShape({mMaxBatchSize, mMaxTotalDraftTokens, mVocabSize}), dataType);
mTargetLogitsPtrs = mBufferManager->pinnedPool(ITensor::makeShape({mMaxBatchSize}), ptrType);
mRefTargetLogits = mBufferManager->pinnedPool(
ITensor::makeShape({mMaxBatchSize, mMaxTotalDraftTokens, mVocabSize}), dataType);
mDraftProbs = mBufferManager->pinnedPool(
ITensor::makeShape({mMaxBatchSize, mMaxTotalDraftTokens, mVocabSize}), dataType);
mTargetProbs = mBufferManager->pinnedPool(
ITensor::makeShape({mMaxBatchSize, mMaxTotalDraftTokens, mVocabSize}), dataType);
}
if (mAcceptMode == AcceptKernelMode::BY_IDS_WITH_PATH)
{
mMedusaLogitsPtrs = mBufferManager->pinnedPool(
ITensor::makeShape({mMaxBatchSize, mMaxDraftSeqPerStep, mMaxNumHeads}), ptrType);
mMedusaInputLogitsPtrs
= mBufferManager->pinnedPool(ITensor::makeShape({mMaxBatchSize, mMaxNumHeads}), ptrType);
mTokensPerStep
= mBufferManager->pinnedPool(ITensor::makeShape({mMaxBatchSize}), nvinfer1::DataType::kINT32);
mBestPaths = mBufferManager->pinnedPool(ITensor::makeShape({mMaxBatchSize}), nvinfer1::DataType::kINT32);
}
}
// Generates the randomized inputs for one test iteration and the host-side expectations derived from them.
//
// All host randomness comes from a std::mt19937 seeded with `seed`; the same seed is also passed to
// tk::invokeCurandInitialize. The function fills the paths, draft/target token, context/sequence length,
// end id, batch slot and finished-state buffers, stores the expected results in mAcceptedLen, mOutputLen and
// mAcceptedFinished, and finally calls the mode-specific reference builder (BY_LOGITS / BY_IDS_WITH_PATH).
//
// The order of the calls into `generator` determines the generated data, so statements that draw random
// numbers must not be reordered.
void initData(SizeType32 seed)
{
    std::mt19937 generator(seed);
    std::uniform_int_distribution<SizeType32> contextLenDistr(0, std::max(mMaxSeqLen - mMaxTotalDraftTokens, 0));
    std::uniform_int_distribution<SizeType32> numTotalDraftTokensDistr(1, mMaxTotalDraftTokens);
    std::uniform_int_distribution<SizeType32> numDraftTokensDistr(0, mMaxDraftTokens);
    std::uniform_int_distribution<SizeType32> vocabDistr(1, mVocabSize - 1);

    trk::invokeFill(*mPaths, int32_t{-1}, *mStream);
    trk::invokeFill(*mFinishedFinal, tk::FinishedState::UnderlyingType{0}, *mStream);

    auto sequenceLengthsPtr = BufferRange<SizeType32>(*mSequenceLengths);
    auto contextLengthsPtr = BufferRange<SizeType32>(*mContextLengths);
    auto numsDraftTokensPtr = BufferRange<SizeType32>(*mNumsDraftTokens);
    auto draftTokensPtr = BufferRange<SizeType32>(*mDraftTokens);
    auto targetTokensPtr = BufferRange<SizeType32>(*mTargetTokens);
    auto finishedStepsPtr
        = reinterpret_cast<tk::FinishedState*>(bufferCast<tk::FinishedState::UnderlyingType>(*mFinishedSteps));
    auto pathsPtr = BufferRange<SizeType32>(*mPaths);
    auto endIdsPtr = BufferRange<SizeType32>(*mEndIds);
    auto batchSlotsPtr = bufferCast<SizeType32>(*mBatchSlots);

    // NOTE(review): this is enqueued with the batch slots as they are *before* the "Init batch slots" loop
    // below rewrites them - confirm that this ordering is intended.
    tk::invokeCurandInitialize(reinterpret_cast<curandState_t*>(bufferCast<int8_t>(*mCurandStates)), batchSlotsPtr,
        mMaxBatchSize, seed, this->mStream->get());

    // Draws from `distr` until the value is not in `tokensToAvoid`. A negative `maxTries` retries without
    // limit; otherwise `defaultValue` is returned once `maxTries` re-draws have been performed.
    auto generateAvoidingValues = [&generator](std::uniform_int_distribution<SizeType32>& distr,
                                      std::unordered_set<SizeType32> const& tokensToAvoid, SizeType32 maxTries = -1,
                                      SizeType32 defaultValue = -1)
    {
        auto token = distr(generator);
        SizeType32 tries = 0;
        while (tokensToAvoid.count(token) != 0 && ((maxTries >= 0 && tries < maxTries) || maxTries < 0))
        {
            token = distr(generator);
            tries++;
        }
        if (tries == maxTries)
        {
            token = defaultValue;
        }
        return token;
    };

    // Init batch slots
    for (SizeType32 bi = 0; bi < mBatchSize; ++bi)
    {
        batchSlotsPtr[bi] = 2 * bi;
    }

    // Init end ids
    for (SizeType32 bi = 0; bi < mMaxBatchSize; ++bi)
    {
        endIdsPtr[bi] = generateAvoidingValues(vocabDistr, {mPadId});
        TLLM_LOG_DEBUG("bi %d endIdsPtr[bi] %d", bi, endIdsPtr[bi]);
        // Randomly init context len for target and draft
        contextLengthsPtr[bi] = contextLenDistr(generator);
    }

    std::fill(draftTokensPtr.begin(), draftTokensPtr.begin() + mMaxBatchSize * mMaxDraftSeqlen, mPadId);
    std::fill(targetTokensPtr.begin(), targetTokensPtr.begin() + mMaxBatchSize * mMaxTargetSeqlen, mPadId);
    std::fill(pathsPtr.begin(), pathsPtr.begin() + mMaxBatchSize * mMaxDraftSeqPerStep * mMaxDraftTokens, -1);

    // Generate paths
    for (SizeType32 bi = 0; bi < mMaxBatchSize; ++bi)
    {
        auto const numTotalDraftTokens = std::min(mMaxDraftTokens, numTotalDraftTokensDistr(generator));
        // NOTE(review): the upper bound is inclusive, so a path id equal to numTotalDraftTokens can be drawn -
        // confirm this stays inside the per-sequence token rows.
        std::uniform_int_distribution<SizeType32> pathIdDistr(0, numTotalDraftTokens);
        for (SizeType32 pi = 0; pi < mMaxDraftSeqPerStep; ++pi)
        {
            std::unordered_set<SizeType32> pathIds;
            auto const numDraftTokensAtStep = numDraftTokensDistr(generator);
            numsDraftTokensPtr[bi * mMaxDraftSeqPerStep + pi] = numDraftTokensAtStep;

            for (SizeType32 ti = 0; ti < numDraftTokensAtStep; ++ti)
            {
                auto const pathIdx = tc::flat_index3(bi, pi, ti, mMaxDraftSeqPerStep, mMaxDraftTokens);
                // Single linear path for BY_IDS and BY_LOGITS modes
                auto const pathId = mAcceptMode == AcceptKernelMode::BY_IDS_WITH_PATH && ti != 0
                    ? generateAvoidingValues(pathIdDistr, pathIds, mMaxDraftTokens * 5, -1)
                    : ti;
                pathsPtr[pathIdx] = pathId;
                pathIds.insert(pathId);
            }
            if (bi == 2)
            {
                TLLM_LOG_DEBUG("bi %d pi %d numsDraftTokensPtr[bi] %d", bi, pi, numDraftTokensAtStep);
            }
        }
    }

    for (SizeType32 ti = 0; ti < mMaxDraftSeqPerStep; ++ti)
    {
        std::vector<SizeType32> targetPredictedLen(mMaxBatchSize);
        std::vector<SizeType32> targetAcceptedLen(mMaxBatchSize);

        // Init number of draft tokens
        for (SizeType32 bi = 0; bi < mMaxBatchSize; ++bi)
        {
            // It can be shorter than num of draft tokens due to the EOS generation
            std::uniform_int_distribution<SizeType32> realDraftTokensDistr(
                0, numsDraftTokensPtr[bi * mMaxDraftSeqPerStep + ti]);
            targetPredictedLen[bi] = realDraftTokensDistr(generator);

            // Accept ~ half of the tokens on average.
            // std::poisson_distribution requires a strictly positive mean, so a zero mean is handled without
            // sampling and simply accepts nothing.
            auto const meanAcceptedLen = targetPredictedLen[bi] / 2;
            SizeType32 sampledAcceptedLen = 0;
            if (meanAcceptedLen > 0)
            {
                std::poisson_distribution<SizeType32> targetAcceptedDistr(meanAcceptedLen);
                sampledAcceptedLen = targetAcceptedDistr(generator);
            }
            targetAcceptedLen[bi] = std::min(sampledAcceptedLen, targetPredictedLen[bi]);

            if (bi == 2)
            {
                TLLM_LOG_DEBUG("bi %d ti %d targetPredictedLen[bi] %d targetAcceptedLen[bi] %d", bi, ti,
                    targetPredictedLen[bi], targetAcceptedLen[bi]);
            }
        }

        // Fill draft tokens
        for (SizeType32 bi = 0; bi < mMaxBatchSize; ++bi)
        {
            for (SizeType32 si = 0; si < numsDraftTokensPtr[bi * mMaxDraftSeqPerStep + ti]; ++si)
            {
                auto const pathIdx = tc::flat_index3(bi, ti, si, mMaxDraftSeqPerStep, mMaxDraftTokens);
                if (pathsPtr[pathIdx] == -1)
                {
                    continue;
                }
                auto const draftTokenIdx = bi * mMaxDraftSeqlen + pathsPtr[pathIdx];
                // Avoid generating endId. We'll insert it manually later if needed.
                draftTokensPtr[draftTokenIdx] = generateAvoidingValues(vocabDistr, {mPadId, endIdsPtr[bi]});
                if (bi == 2)
                {
                    TLLM_LOG_DEBUG("bi %d ti %d si %d pathId %d draftToken %d", bi, ti, si, pathsPtr[pathIdx],
                        draftTokensPtr[draftTokenIdx]);
                }
            }
        }

        for (SizeType32 bi = 0; bi < mMaxBatchSize; ++bi)
        {
            sequenceLengthsPtr[bi] = contextLengthsPtr[bi] + targetPredictedLen[bi];

            // Initialize finished states
            for (int di = 0; di < numsDraftTokensPtr[bi * mMaxDraftSeqPerStep + ti]; ++di)
            {
                finishedStepsPtr[di * mMaxBatchSize + bi]
                    = (di < targetPredictedLen[bi]) ? tk::FinishedState::empty() : tk::FinishedState::finished();
            }

            // Init helper vectors
            mAcceptedLen[bi] = contextLengthsPtr[bi] + std::max(targetAcceptedLen[bi], 0);
            mOutputLen[bi] = std::min(sequenceLengthsPtr[bi], std::min(mAcceptedLen[bi] + 1, mMaxSeqLen));
            mAcceptedFinished[bi] = finishedStepsPtr[std::max(targetAcceptedLen[bi], 0) * mMaxBatchSize + bi];

            if (bi == 2)
            {
                TLLM_LOG_DEBUG(
                    "bi %d ti %d contextLengthsPtr[bi] %d sequenceLengthsPtr[bi] %d mAcceptedLen[bi] %d "
                    "mOutputLen[bi] "
                    "%d",
                    bi, ti, contextLengthsPtr[bi], sequenceLengthsPtr[bi], mAcceptedLen[bi], mOutputLen[bi]);
            }
        }

        // Fill token arrays
        for (SizeType32 bi = 0; bi < mMaxBatchSize; ++bi)
        {
            // Draft:  [d0, d1, d2, ... for numDraftTokens ... , dK,
            //         padId, padId, .. to mMaxDraftSeqlen]
            // Target: [padId, padId, ... for contextLengthsPtr[bi] ... padId,
            //         d0, d1, d2, ... for targetAcceptedLen[bi],
            //         ti (!= di), ti+1 (!= di+1), ... for (targetPredictedLen[bi] - targetAcceptedLen[bi]),
            //         EOS, EOS, EOS, ... for (numDraftTokens - targetPredictedLen[bi])
            //         padId, padId, .. to mMaxSeqLen]
            // where numDraftTokens is the draft token count of batch entry bi on path ti.
            auto numDraftTokens = numsDraftTokensPtr[bi * mMaxDraftSeqPerStep + ti];
            for (SizeType32 si = 0; si < numDraftTokens; ++si)
            {
                auto const curPathIdx = tc::flat_index3(bi, ti, si, mMaxDraftSeqPerStep, mMaxDraftTokens);
                auto const nextPathIdx = si + 1 < numDraftTokens
                    ? tc::flat_index3(bi, ti, si + 1, mMaxDraftSeqPerStep, mMaxDraftTokens)
                    : -1;
                auto const curPathId = pathsPtr[curPathIdx];
                auto nextPathId = curPathId;
                if (mAcceptMode == AcceptKernelMode::BY_IDS_WITH_PATH)
                {
                    nextPathId = nextPathIdx > -1 ? pathsPtr[nextPathIdx] : -1;
                }

                if (curPathId == -1)
                {
                    continue;
                }

                auto const contextLen
                    = mAcceptMode == AcceptKernelMode::BY_IDS_WITH_PATH ? 0 : contextLengthsPtr[bi];
                auto const draftTokenIdx = bi * mMaxDraftSeqlen + nextPathId;
                auto const targetTokenIdx = bi * mMaxTargetSeqlen + contextLen + curPathId;

                auto targetToken = mPadId;
                if (0 <= si && si < targetAcceptedLen[bi] && nextPathId != -1)
                {
                    // Use draft token up to the accepted len
                    targetToken = draftTokensPtr[draftTokenIdx];
                }
                else if (0 <= si && si < targetPredictedLen[bi])
                {
                    // Do not use the draft token up to the generated len
                    std::unordered_set<SizeType32> avoidValues = {mPadId, endIdsPtr[bi]};
                    if (nextPathId != -1)
                    {
                        avoidValues.insert(draftTokensPtr[draftTokenIdx]);
                    }
                    targetToken = generateAvoidingValues(vocabDistr, avoidValues);
                }
                else if (targetPredictedLen[bi] <= si && si < numDraftTokens)
                {
                    // Fill with EOS from generated len to the draft len of this (bi, ti) path
                    targetToken = endIdsPtr[bi];
                }
                targetTokensPtr[targetTokenIdx] = targetToken;
                if (bi == 2)
                {
                    TLLM_LOG_DEBUG(
                        "bi %d ti %d si %d pathId %d targetToken %d", bi, ti, si, curPathId, targetToken);
                }
            }
        }
    }

    if (mAcceptMode == AcceptKernelMode::BY_LOGITS)
    {
        initDataAndReferenceAcceptByLogits();
    }

    if (mAcceptMode == AcceptKernelMode::BY_IDS_WITH_PATH)
    {
        initDataAndReferenceAcceptByIdsWithPaths();
    }

    // Host snapshot of the initial sequence lengths, taken before the tested kernel runs.
    mSequenceLengthsCopy = mBufferManager->copyFrom(*mSequenceLengths, MemoryType::kCPU);
}
void initDataAndReferenceAcceptByIdsWithPaths()
{
auto const dataType = TRTDataType<T>::value;
auto const ptrType = TRTDataType<T*>::value;
auto pathsPtr = BufferRange<SizeType32>(*mPaths);
auto endIdsPtr = BufferRange<SizeType32>(*mEndIds);
auto contextLengthsPtr = BufferRange<SizeType32>(*mContextLengths);
auto draftTokensPtr = BufferRange<SizeType32>(*mDraftTokens);
auto targetTokensPtr = BufferRange<SizeType32>(*mTargetTokens);
auto medusaInputLogitsPtr = BufferRange<T*>(*mMedusaInputLogitsPtrs);
trk::invokeFill(*mMedusaLogitsPtrs, int64_t{0}, *mStream);
trk::invokeFill(*mTokensPerStep, int32_t{mMaxTotalDraftTokens}, *mStream);
trk::invokeFill(*mBestPaths, int32_t{-1}, *mStream);
mAcceptedLen.resize(mMaxBatchSize);
mAcceptedPathIdx.resize(mMaxBatchSize);
mRefAcceptedTokens.resize(mMaxBatchSize);
mFinishedByIdsPaths.resize(mMaxBatchSize);
mLastTargetIdx.resize(mMaxBatchSize);
for (SizeType32 bi = 0; bi < mMaxBatchSize; ++bi)
{
SizeType32 maxAcceptedLen = -1;
SizeType32 maxAcceptedPath = -1;
SizeType32 maxNextTargetTokenIdx = -1;
bool maxFinished = false;
std::vector<SizeType32> maxAcceptedTokens;
for (SizeType32 ti = 0; ti < mMaxDraftSeqPerStep; ++ti)
{
std::vector<SizeType32> acceptedTokens;
SizeType32 curAcceptedLen = mMaxDraftTokens;
SizeType32 curAcceptedPath = ti;
bool curFinished = false;
auto const pathIdx = tc::flat_index3(bi, ti, 0, mMaxDraftSeqPerStep, mMaxDraftTokens);
auto const pathId = pathsPtr[pathIdx];
if (pathId == -1)
{
continue;
}
auto targetTokenIdx = bi * mMaxTargetSeqlen + pathId;
auto targetToken = targetTokensPtr[targetTokenIdx];
auto curNextTargetTokenIdx = pathId;
for (SizeType32 di = 1; di < mMaxDraftTokens; ++di)
{
auto const pathIdx = tc::flat_index3(bi, ti, di, mMaxDraftSeqPerStep, mMaxDraftTokens);
auto const pathId = pathsPtr[pathIdx];
if (pathId == -1)
{
curAcceptedLen = di;
curAcceptedPath = ti;
curFinished = false;
acceptedTokens.push_back(targetToken);
break;
}
auto const draftTokenIdx = bi * mMaxDraftSeqlen + pathId - 1;
auto const targetTokenIdx = bi * mMaxTargetSeqlen + pathId;
auto const draftToken = draftTokensPtr[draftTokenIdx];
bool const hasEnd = targetToken == endIdsPtr[bi];
if (!hasEnd)
{
acceptedTokens.push_back(targetToken);
}
if (draftToken != targetToken || hasEnd)
{
auto const curLen = hasEnd ? di - 1 : di;
curAcceptedLen = curLen;
curAcceptedPath = ti;
curFinished = hasEnd;
break;
}
targetToken = targetTokensPtr[targetTokenIdx];
curNextTargetTokenIdx = pathId;
}
if (curAcceptedLen == mMaxDraftTokens)
{
acceptedTokens.push_back(targetToken);
}
if (curAcceptedLen > maxAcceptedLen)
{
maxAcceptedLen = curAcceptedLen;
maxAcceptedPath = curAcceptedPath;
maxAcceptedTokens = acceptedTokens;
maxFinished = curFinished;
maxNextTargetTokenIdx = curNextTargetTokenIdx;
}
}
mAcceptedLen[bi] = maxAcceptedLen;
mAcceptedPathIdx[bi] = maxAcceptedPath;
mRefAcceptedTokens[bi] = maxAcceptedTokens;
mFinishedByIdsPaths[bi] = maxFinished;
mLastTargetIdx[bi] = maxNextTargetTokenIdx;
for (SizeType32 hi = 0; hi < mMaxNumHeads; ++hi)
{
medusaInputLogitsPtr[bi * mMaxNumHeads + hi] = static_cast<T*>(nullptr)
+ tc::flat_index4(hi, bi, 0, 0, mMaxBatchSize, mMaxDraftSeqPerStep, mVocabSize);
}
if (bi == 2)
{
TLLM_LOG_DEBUG("bi %d maxAcceptedLen %d maxAcceptedPath %d maxNextTargetTokenIdx %d", bi,
maxAcceptedLen, maxAcceptedPath, maxNextTargetTokenIdx);
std::ostringstream ss;
for (auto& tk : maxAcceptedTokens)
{
ss << tk << " ";
}
TLLM_LOG_DEBUG(ss.str().c_str());
}
}
}
void initDataAndReferenceAcceptByLogits()
{
auto contextLengthsPtr = BufferRange<SizeType32>(*mContextLengths);
auto numsDraftTokensPtr = BufferRange<SizeType32>(*mNumsDraftTokens);
auto draftTokensPtr = BufferRange<SizeType32>(*mDraftTokens);
auto targetTokensPtr = BufferRange<SizeType32>(*mTargetTokens);
auto draftProbsPtr = BufferRange<T>(*mDraftProbs);
auto targetProbsPtr = BufferRange<T>(*mTargetProbs);
auto draftLogitsPtr = BufferRange<T>(*mDraftLogits);
auto targetLogitsPtr = BufferRange<T>(*mTargetLogits);
auto targetLogitsPtrsPtr = BufferRange<T*>(*mTargetLogitsPtrs);
auto refTargetLogitsPtr = BufferRange<T>(*mRefTargetLogits);
auto batchSlotsPtr = BufferRange<SizeType32>(*mBatchSlots);
for (SizeType32 bi = 0; bi < mMaxBatchSize; ++bi)
{
// Init draft and target logits and probabilities
for (SizeType32 si = 0; si < numsDraftTokensPtr[bi]; ++si)
{
std::vector<float> peakDraftProb(mVocabSize, 0.f);
std::vector<float> peakTargetProb(mVocabSize, 0.f);
auto const targetToken = targetTokensPtr[bi * mMaxSeqLen + contextLengthsPtr[bi] + si] % mVocabSize;
auto const draftToken = draftTokensPtr[bi * mMaxDraftTokens + si] % mVocabSize;
peakDraftProb[draftToken] = 1.f;
peakTargetProb[targetToken] = 1.f;
auto const logitsOffset = bi * mMaxDraftTokens * mVocabSize + si * mVocabSize;
// Emulate some distribution around target token
applyGaussianFilter(
draftProbsPtr.begin() + logitsOffset, peakDraftProb.data(), peakDraftProb.size(), 1.0f);
applyGaussianFilter(
targetProbsPtr.begin() + logitsOffset, peakTargetProb.data(), peakTargetProb.size(), 1.0f);
// Probabilities to logits
probsToLogits(draftProbsPtr.begin() + logitsOffset, draftLogitsPtr.begin() + logitsOffset, mVocabSize);
probsToLogits(
targetProbsPtr.begin() + logitsOffset, targetLogitsPtr.begin() + logitsOffset, mVocabSize);
// Do softmax conversion back to emulate kernels accuracy
softmax(draftLogitsPtr.begin() + logitsOffset, draftProbsPtr.begin() + logitsOffset, mVocabSize);
softmax(targetLogitsPtr.begin() + logitsOffset, targetProbsPtr.begin() + logitsOffset, mVocabSize);
}
}
for (SizeType32 bi = 0; bi < mMaxBatchSize; ++bi)
{
for (SizeType32 si = 0; si < mMaxDraftTokens; ++si)
{
auto const logitsOffset = bi * mMaxDraftTokens * mVocabSize + si * mVocabSize;
auto const outputLen = mOutputLen[bi] - contextLengthsPtr[bi];
auto const acceptedLen = mAcceptedLen[bi] - contextLengthsPtr[bi];
if (si < acceptedLen)
{
auto logitsStart = targetLogitsPtr.begin() + logitsOffset;
std::copy(logitsStart, logitsStart + mVocabSize, refTargetLogitsPtr.begin() + logitsOffset);
}
else if (si == acceptedLen)
{
// When token is not accepted, correct probabilities and compute updated logits
float sumProb = 1e-6f;
for (SizeType32 vi = 0; vi < mVocabSize; ++vi)
{
auto const correctedProb = std::max(
static_cast<float>(targetProbsPtr[logitsOffset + vi] - draftProbsPtr[logitsOffset + vi]),
0.f);
sumProb += correctedProb;
}
for (SizeType32 vi = 0; vi < mVocabSize; ++vi)
{
auto prob = std::max(static_cast<float>(
targetProbsPtr[logitsOffset + vi] - draftProbsPtr[logitsOffset + vi]),
0.f)
/ sumProb;
if (prob < 1e-8)
{
prob = 0.f;
}
refTargetLogitsPtr[logitsOffset + vi] = std::log(prob / (1.f - prob));
}
}
}
}
for (SizeType32 bi = 0; bi < mBatchSize; ++bi)
{
targetLogitsPtrsPtr[bi] = targetLogitsPtr.begin() + batchSlotsPtr[bi] * mMaxDraftTokens * mVocabSize;
}
}
// Launcher for the AcceptKernelMode::BY_IDS case, called from callTestedKernel().
// Currently a no-op: the only statement, a call to tksp::invokeAcceptDraftTokensByIds, is commented out below.
void callAcceptByIds()
{
// tksp::invokeAcceptDraftTokensByIds(bufferCast<SizeType32>(*mDraftTokens),
// bufferCast<SizeType32>(*mTargetTokens), bufferCast<SizeType32>(*mContextLengths),
// bufferCast<SizeType32>(*mNumsDraftTokens), bufferCast<SizeType32>(*mSequenceLengths),
// reinterpret_cast<tk::FinishedState*>(bufferCast<tk::FinishedState::UnderlyingType>(*mFinishedSteps)),
// reinterpret_cast<tk::FinishedState*>(bufferCast<tk::FinishedState::UnderlyingType>(*mFinishedFinal)),
// bufferCast<SizeType32>(*mFinishedSum), bufferCast<SizeType32>(*mBatchSlots), mBatchSize, mMaxBatchSize,
// mBeamWidth, mMaxSeqLen, mMaxDraftTokens, mStream->get());
}
// Launcher for the AcceptKernelMode::BY_LOGITS case, called from callTestedKernel().
// Currently a no-op: the only statement, a call to tksp::acceptDraftTokensByLogits, is commented out below.
void callAcceptByLogits()
{
// tksp::acceptDraftTokensByLogits(bufferCast<T>(*mDraftLogits),
// reinterpret_cast<T**>(bufferCast<int64_t>(*mTargetLogitsPtrs)), bufferCast<T>(*mDraftProbs),
// bufferCast<T>(*mTargetProbs), bufferCast<SizeType32>(*mNumsDraftTokens),
// reinterpret_cast<tk::FinishedState*>(bufferCast<tk::FinishedState::UnderlyingType>(*mFinishedSteps)),
// reinterpret_cast<curandState_t*>(bufferCast<int8_t>(*mCurandStates)),
// bufferCast<SizeType32>(*mBatchSlots), mBatchSize, mMaxBatchSize, mBeamWidth, mVocabSize, mVocabSize,
// mMaxDraftTokens, false, 0.9f, mStream->get());
}
void callAcceptByIdsWithPaths()
{
tksp::AcceptDraftTokensByIdsWithPathsParams<T> params;
params.outputIds = bufferCast<SizeType32>(*mOutputTokens);
params.draftIds = bufferCast<SizeType32>(*mDraftTokens);
params.targetIds = bufferCast<SizeType32>(*mTargetTokens);
params.sequenceLengths = bufferCast<SizeType32>(*mSequenceLengths);
params.acceptedLengths = bufferCast<SizeType32>(*mAcceptedLengths);
params.finishedFinal
= reinterpret_cast<tk::FinishedState*>(bufferCast<tk::FinishedState::UnderlyingType>(*mFinishedFinal));
params.batchSlots = bufferCast<SizeType32>(*mBatchSlots);
params.paths = bufferCast<SizeType32>(*mPaths);
params.endIds = bufferCast<SizeType32>(*mEndIds);
params.medusaLogits = reinterpret_cast<T const**>(bufferCast<int64_t>(*mMedusaInputLogitsPtrs));
params.logitsPtrs = reinterpret_cast<T const**>(bufferCast<int64_t>(*mMedusaLogitsPtrs));
params.curTokensPerStep = bufferCast<SizeType32>(*mTokensPerStep);
params.targetTokensPerStep = bufferCast<SizeType32>(*mTokensPerStep);
params.bestPathIds = bufferCast<SizeType32>(*mBestPaths);
params.batchSize = mBatchSize;
params.maxBatchSize = mMaxBatchSize;
params.vocabSize = mVocabSize;
params.maxSeqLen = mMaxSeqLen;
params.maxDraftPathLen = mMaxNumHeads;
params.maxDecodingTokens = mMaxDraftSeqPerStep;
params.stream = mStream->get();
params.checkParams();
tksp::acceptDraftTokensByIdsWithPaths(params);
}
// Runs the kernel launcher that corresponds to the configured accept mode.
void callTestedKernel()
{
    if (mAcceptMode == AcceptKernelMode::BY_IDS)
    {
        callAcceptByIds();
    }
    else if (mAcceptMode == AcceptKernelMode::BY_LOGITS)
    {
        callAcceptByLogits();
    }
    else if (mAcceptMode == AcceptKernelMode::BY_IDS_WITH_PATH)
    {
        callAcceptByIdsWithPaths();
    }
    else
    {
        // Unknown mode: should never be here
        TLLM_CHECK(false);
    }
}
// Checks the outputs of the BY_IDS mode against the expectations prepared by initData().
//
// For every active request (addressed through its batch slot) the resulting sequence length, the final
// finished state and the finished counter are compared with mOutputLen and mAcceptedFinished.
// `seed` is only used to make failure messages reproducible.
void verifyAcceptByIdsResults(SizeType32 seed)
{
    auto finishedFinalPtr
        = reinterpret_cast<tk::FinishedState*>(bufferCast<tk::FinishedState::UnderlyingType>(*mFinishedFinal));
    auto sequenceLengthsPtr = BufferRange<SizeType32>(*mSequenceLengths);
    auto finishedSumPtr = BufferRange<SizeType32>(*mFinishedSum);
    auto batchSlotsPtr = BufferRange<SizeType32>(*mBatchSlots);

    // Verify seqLen for accepted tokens
    for (SizeType32 bi = 0; bi < mBatchSize; ++bi)
    {
        auto const batchSlot = batchSlotsPtr[bi];
        EXPECT_EQ(mOutputLen[batchSlot], sequenceLengthsPtr[batchSlot]) << " bi " << bi << " seed " << seed;
        EXPECT_EQ(mAcceptedFinished[batchSlot].isFinished(), finishedFinalPtr[batchSlot].isFinished())
            << " bi " << bi << " seed " << seed;
        EXPECT_EQ(mAcceptedFinished[batchSlot].isSkipDecoding(), finishedFinalPtr[batchSlot].isSkipDecoding())
            << " bi " << bi << " seed " << seed;
        // Same failure context as the checks above, so a mismatch can be traced back to its seed.
        EXPECT_EQ(static_cast<SizeType32>(mAcceptedFinished[batchSlot].isFinished()), finishedSumPtr[batchSlot])
            << " bi " << bi << " seed " << seed;
    }
}
// Checks the outputs of the BY_LOGITS mode.
//
// Positions up to and including the accepted length must not be marked skip-decoding and their logits
// (read from mTargetLogits) must match mRefTargetLogits wherever both values are above -10. All later draft
// positions must be marked skip-decoding. `seed` only appears in failure messages.
void verifyAcceptByLogitsResults(SizeType32 seed)
{
    auto batchSlotsPtr = BufferRange<SizeType32>(*mBatchSlots);
    auto numsDraftTokensPtr = BufferRange<SizeType32>(*mNumsDraftTokens);
    auto contextLengthsPtr = BufferRange<SizeType32>(*mContextLengths);
    auto refLogitsPtr = BufferRange<T>(*mRefTargetLogits);
    auto outLogitsPtr = BufferRange<T>(*mTargetLogits);
    auto finishedStepsPtr
        = reinterpret_cast<tk::FinishedState*>(bufferCast<tk::FinishedState::UnderlyingType>(*mFinishedSteps));

    for (SizeType32 bi = 0; bi < mBatchSize; ++bi)
    {
        auto const batchSlot = batchSlotsPtr[bi];
        // Index of the last position whose logits are checked
        auto const lastCheckedStep = mAcceptedLen[batchSlot] - contextLengthsPtr[batchSlot];

        for (SizeType32 si = 0; si < numsDraftTokensPtr[batchSlot]; ++si)
        {
            auto const outFinishedState = finishedStepsPtr[si * mMaxBatchSize + batchSlot];

            if (si > lastCheckedStep)
            {
                // Past the accepted prefix: decoding must be skipped
                EXPECT_TRUE(outFinishedState.isSkipDecoding())
                    << " bi: " << bi << " si: " << si << " seed: " << seed;
                continue;
            }

            EXPECT_FALSE(outFinishedState.isSkipDecoding())
                << " bi: " << bi << " si: " << si << " seed: " << seed;

            auto const logitsOffset = batchSlot * mMaxDraftTokens * mVocabSize + si * mVocabSize;
            for (SizeType32 vi = 0; vi < mVocabSize; ++vi)
            {
                auto const outLogit = static_cast<float>(outLogitsPtr[logitsOffset + vi]);
                auto const refLogit = static_cast<float>(refLogitsPtr[logitsOffset + vi]);

                // Both values have to be on the same side of the -10 threshold
                EXPECT_FALSE((refLogit > -10) ^ (outLogit > -10))
                    << " bi: " << bi << " si: " << si << " vi: " << vi << " seed: " << seed;

                if (!(refLogit > -10 && outLogit > -10))
                {
                    continue;
                }

                if (!almostEqual(outLogit, refLogit, 1e-1, 1e-2))
                {
                    std::cout << refLogit << " " << outLogit << std::endl;
                }
                ASSERT_TRUE(almostEqual(outLogit, refLogit, 1e-1, 1e-2))
                    << " bi: " << bi << " si: " << si << " vi: " << vi << " seed: " << seed;
            }
        }
    }
}
// Checks the outputs of the BY_IDS_WITH_PATH mode against the reference built by
// initDataAndReferenceAcceptByIdsWithPaths().
//
// Requests whose reference has no valid last target index (negative) are skipped. For the others the best
// path id, the selected logits offsets, the accepted length, the updated sequence length, the accepted
// output tokens and the finished flag are compared. `seed` only appears in failure messages.
void verifyAcceptByIdsWithPathsResults(SizeType32 seed)
{
    auto medusaLogitsPtrsPtr = BufferRange<T*>(*mMedusaLogitsPtrs);
    auto batchSlotsPtr = BufferRange<SizeType32>(*mBatchSlots);
    auto draftContextLengths = BufferRange<SizeType32>(*mSequenceLengths);
    auto draftContextLengthsInit = BufferRange<SizeType32>(*mSequenceLengthsCopy);
    auto acceptedLengths = BufferRange<SizeType32>(*mAcceptedLengths);
    auto outputIdsPtr = BufferRange<TokenIdType>(*mOutputTokens);
    auto bestPathIds = BufferRange<SizeType32>(*mBestPaths);
    auto finishedFinalPtr
        = reinterpret_cast<tk::FinishedState*>(bufferCast<tk::FinishedState::UnderlyingType>(*mFinishedFinal));

    for (SizeType32 bi = 0; bi < mBatchSize; ++bi)
    {
        auto const batchSlot = batchSlotsPtr[bi];
        auto const lastTargetIdx = mLastTargetIdx[batchSlot];
        if (lastTargetIdx < 0)
        {
            continue;
        }

        auto const bestPathIdx = mAcceptedPathIdx[batchSlot];
        auto const acceptedLen = mAcceptedLen[batchSlot];
        // Bound by reference: the reference tokens are only read here, no copy is needed.
        auto const& refAcceptedTokens = mRefAcceptedTokens[batchSlot];

        EXPECT_EQ(bestPathIds[batchSlot], bestPathIdx) << "bi: " << bi << " seed: " << seed;

        for (int32_t hi = 0; hi < mMaxNumHeads; ++hi)
        {
            // The logits "pointers" carry flat offsets relative to nullptr, so they are compared as offsets.
            auto refOffset
                = tc::flat_index4(hi, batchSlot, lastTargetIdx, 0, mMaxBatchSize, mMaxDraftSeqPerStep, mVocabSize);
            auto outOffset
                = static_cast<SizeType32>(medusaLogitsPtrsPtr[bi * mMaxNumHeads + hi] - static_cast<T*>(nullptr));
            EXPECT_EQ(outOffset, refOffset) << " bi: " << bi << " hi: " << hi << " seed: " << seed;
        }

        EXPECT_EQ(acceptedLengths[batchSlot], acceptedLen) << " bi: " << bi << " seed: " << seed;
        EXPECT_EQ(draftContextLengths[batchSlot], draftContextLengthsInit[batchSlot] + acceptedLen)
            << " bi: " << bi << " seed: " << seed << " out: " << draftContextLengths[batchSlot]
            << " ref: " << draftContextLengthsInit[batchSlot] + acceptedLen;

        for (SizeType32 ti = 0; ti < acceptedLen; ++ti)
        {
            // Guards the indexing below; the cast keeps the comparison between values of the same signedness.
            ASSERT_EQ(static_cast<SizeType32>(refAcceptedTokens.size()), acceptedLen)
                << " bi: " << bi << " ti: " << ti << " seed: " << seed;
            EXPECT_EQ(outputIdsPtr[batchSlot * mMaxSeqLen + draftContextLengthsInit[batchSlot] + ti],
                refAcceptedTokens[ti])
                << " bi: " << bi << " ti: " << ti << " seed: " << seed;
        }

        EXPECT_EQ(finishedFinalPtr[batchSlot].isFinished(), mFinishedByIdsPaths[batchSlot])
            << " bi: " << bi << " seed: " << seed;
    }
}
// Runs the result verification that corresponds to the configured accept mode.
void verifyResult(SizeType32 seed)
{
    if (mAcceptMode == AcceptKernelMode::BY_IDS)
    {
        verifyAcceptByIdsResults(seed);
        return;
    }
    if (mAcceptMode == AcceptKernelMode::BY_LOGITS)
    {
        verifyAcceptByLogitsResults(seed);
        return;
    }
    if (mAcceptMode == AcceptKernelMode::BY_IDS_WITH_PATH)
    {
        verifyAcceptByIdsWithPathsResults(seed);
        return;
    }
    // Unknown mode: should never be here
    TLLM_CHECK(false);
}
// Test driver: copies the configuration from `params`, allocates the buffers once and then, for every seed
// in [0, mSeeds), generates data, launches the tested kernel and verifies its results.
// The test is skipped when more than one head or more than one draft sequence per step is requested for a
// mode other than BY_IDS_WITH_PATH.
void runTest(DecodingKernelTestParam const& params)
{
    mAcceptMode = params.mAcceptMode;
    bool const usesPaths = mAcceptMode == AcceptKernelMode::BY_IDS_WITH_PATH;

    mBatchSize = params.mBatchSize;
    mMaxBatchSize = params.mMaxBatchSize;
    mBeamWidth = params.mBeamWidth;
    mVocabSize = params.mVocabSize;
    mMaxDraftTokens = params.mMaxDraftTokens;
    mMaxSeqLen = params.mMaxSeqLen;

    mMaxNumHeads = params.mMaxNumHeads;
    if (!usesPaths && mMaxNumHeads > 1)
    {
        GTEST_SKIP() << "MaxNumHeads > 1 is only supported for AcceptKernelMode::BY_IDS_WITH_PATH";
    }

    mMaxDraftSeqPerStep = params.mMaxDraftSeqPerStep;
    if (!usesPaths && mMaxDraftSeqPerStep > 1)
    {
        GTEST_SKIP() << "MaxDraftSeqPerStep > 1 is only supported for AcceptKernelMode::BY_IDS_WITH_PATH";
    }

    // Derived sizes
    mMaxTotalDraftTokens = mMaxDraftSeqPerStep * mMaxDraftTokens;
    mPadId = mVocabSize - 1;
    if (usesPaths)
    {
        mMaxDraftSeqlen = mMaxDraftTokens - 1;
        mMaxTargetSeqlen = mMaxDraftTokens;
    }
    else
    {
        mMaxDraftSeqlen = mMaxDraftTokens;
        mMaxTargetSeqlen = mMaxSeqLen;
    }

    createBuffers();

    // To debug a single failing seed, skip all other values of `seed` at the top of this loop.
    for (SizeType32 seed = 0; seed < mSeeds; ++seed)
    {
        TLLM_LOG_DEBUG("Seed %d", seed);

        initData(seed);
        mStream->synchronize();

        callTestedKernel();
        mStream->synchronize();

        verifyResult(seed);
    }
}
protected:
// Buffer allocator and the CUDA stream on which fills and the tested kernels are issued.
std::shared_ptr<tensorrt_llm::runtime::BufferManager> mBufferManager;
std::shared_ptr<tensorrt_llm::runtime::CudaStream> mStream;
// Token buffers: draft and target tokens are generated in initData(); mOutputTokens is passed to the
// BY_IDS_WITH_PATH kernel as its output ids.
TensorPtr mDraftTokens;
TensorPtr mTargetTokens;
TensorPtr mOutputTokens;
// Logits / probabilities, allocated in createBuffers() only for the BY_LOGITS mode.
// mRefTargetLogits holds the host reference that mTargetLogits is compared against.
TensorPtr mDraftLogits;
TensorPtr mTargetLogits;
TensorPtr mTargetLogitsPtrs;
TensorPtr mRefTargetLogits;
TensorPtr mDraftProbs;
TensorPtr mTargetProbs;
// Number of draft tokens, indexed in initData() as [batch * mMaxDraftSeqPerStep + path].
TensorPtr mNumsDraftTokens;
TensorPtr mSequenceLengths;
// Host copy of mSequenceLengths taken at the end of initData().
TensorPtr mSequenceLengthsCopy;
TensorPtr mAcceptedLengths;
TensorPtr mContextLengths;
// Finished states, indexed in initData() as [step * mMaxBatchSize + batch].
TensorPtr mFinishedSteps;
TensorPtr mFinishedFinal;
TensorPtr mFinishedSum;
// Maps a request index to its batch slot; initData() sets slot 2 * bi for the first mBatchSize requests.
TensorPtr mBatchSlots;
// Path ids, indexed as [batch, path, token]; -1 marks an unused entry.
TensorPtr mPaths;
TensorPtr mEndIds;
// Buffers allocated in createBuffers() only for the BY_IDS_WITH_PATH mode.
TensorPtr mMedusaLogitsPtrs;
TensorPtr mMedusaInputLogitsPtrs;
TensorPtr mTokensPerStep;
TensorPtr mBestPaths;
// GPU storage of mMaxBatchSize * sizeof(curandState_t) bytes.
TensorPtr mCurandStates;
// Host-side expectations filled by initData() and the mode-specific reference builders.
std::vector<SizeType32> mAcceptedLen;
std::vector<SizeType32> mOutputLen;
std::vector<tk::FinishedState> mAcceptedFinished;
std::vector<SizeType32> mAcceptedPathIdx;
std::vector<SizeType32> mLastTargetIdx;
std::vector<std::vector<SizeType32>> mRefAcceptedTokens;
std::vector<bool> mFinishedByIdsPaths;
// Test configuration copied from DecodingKernelTestParam in runTest().
SizeType32 mBatchSize;
SizeType32 mMaxBatchSize;
SizeType32 mBeamWidth;
SizeType32 mMaxSeqLen;
SizeType32 mVocabSize;
SizeType32 mMaxDraftTokens;
// mMaxDraftSeqPerStep * mMaxDraftTokens, computed in runTest().
SizeType32 mMaxTotalDraftTokens;
// Per-batch row strides of the draft / target token buffers; they depend on the accept mode (see runTest()).
SizeType32 mMaxDraftSeqlen;
SizeType32 mMaxTargetSeqlen;
SizeType32 mMaxNumHeads;
SizeType32 mMaxDraftSeqPerStep;
AcceptKernelMode mAcceptMode;
// Padding token id, set to mVocabSize - 1 in runTest().
SizeType32 mPadId;
// Number of seeds (0 .. mSeeds - 1) that runTest() iterates over.
static constexpr SizeType32 mSeeds = 64;
};
// Explicit instantiations of the fixture for both tested element types.
template class DecodingKernelsTest<float>;
template class DecodingKernelsTest<half>;

// Type list over which the typed test suite is instantiated.
using FloatAndHalfTypes = testing::Types<float, half>;
TYPED_TEST_SUITE(DecodingKernelsTest, FloatAndHalfTypes);
// Disabled: BY_IDS acceptance for a single request with a small vocabulary.
TYPED_TEST(DecodingKernelsTest, DISABLED_acceptDraftTokensByIdsKernelSmall)
{
    auto const testParams = DecodingKernelTestParam()
                                .setBatchSize(1)
                                .setMaxSeqLen(16)
                                .setVocabSize(32)
                                .setMaxDraftTokens(8)
                                .setMaxDraftSeqPerStep(1)
                                .setAcceptMode(AcceptKernelMode::BY_IDS);
    this->runTest(testParams);
}
// Disabled: BY_IDS acceptance for 128 requests with a 52000-token vocabulary.
TYPED_TEST(DecodingKernelsTest, DISABLED_acceptDraftTokensByIdsKernelLarge)
{
    auto const testParams = DecodingKernelTestParam()
                                .setBatchSize(128)
                                .setMaxSeqLen(128)
                                .setVocabSize(52000)
                                .setMaxDraftTokens(8)
                                .setMaxDraftSeqPerStep(1)
                                .setAcceptMode(AcceptKernelMode::BY_IDS);
    this->runTest(testParams);
}
// Disabled: BY_LOGITS acceptance for a single request with a small vocabulary.
TYPED_TEST(DecodingKernelsTest, DISABLED_acceptDraftTokensByLogitsKernelSmall)
{
    auto const testParams = DecodingKernelTestParam()
                                .setBatchSize(1)
                                .setMaxSeqLen(16)
                                .setVocabSize(32)
                                .setMaxDraftTokens(8)
                                .setMaxDraftSeqPerStep(1)
                                .setAcceptMode(AcceptKernelMode::BY_LOGITS);
    this->runTest(testParams);
}
// Disabled: BY_LOGITS acceptance for 64 requests with a 4000-token vocabulary.
TYPED_TEST(DecodingKernelsTest, DISABLED_acceptDraftTokensByLogitsKernelLarge)
{
    auto const testParams = DecodingKernelTestParam()
                                .setBatchSize(64)
                                .setMaxSeqLen(64)
                                .setVocabSize(4000)
                                .setMaxDraftTokens(8)
                                .setMaxDraftSeqPerStep(1)
                                .setAcceptMode(AcceptKernelMode::BY_LOGITS);
    this->runTest(testParams);
}
// FIXME(nkorobov): test is incorrect and too complicated.
// Disabled: BY_IDS_WITH_PATH acceptance for a single request, 4 paths of up to 3 tokens and 2 heads.
TYPED_TEST(DecodingKernelsTest, DISABLED_acceptDraftTokensByIdsWithPathsKernelSmall)
{
    auto const testParams = DecodingKernelTestParam()
                                .setBatchSize(1)
                                .setMaxSeqLen(128)
                                .setVocabSize(32)
                                .setMaxDraftTokens(3)
                                .setMaxDraftSeqPerStep(4)
                                .setMaxNumHeads(2)
                                .setAcceptMode(AcceptKernelMode::BY_IDS_WITH_PATH);
    this->runTest(testParams);
}
// FIXME(nkorobov): test is incorrect and too complicated.
// Disabled: BY_IDS_WITH_PATH acceptance for 128 requests, 64 paths of up to 8 tokens and 7 heads.
TYPED_TEST(DecodingKernelsTest, DISABLED_acceptDraftTokensByIdsWithPathsKernelLarge)
{
    auto const testParams = DecodingKernelTestParam()
                                .setBatchSize(128)
                                .setMaxSeqLen(1024)
                                .setVocabSize(4000)
                                .setMaxDraftTokens(8)
                                .setMaxDraftSeqPerStep(64)
                                .setMaxNumHeads(7)
                                .setAcceptMode(AcceptKernelMode::BY_IDS_WITH_PATH);
    this->runTest(testParams);
}
} // end of namespace