/*
 * Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef TOP_LEVEL_DIR
#error "Define TOP_LEVEL_DIR"
#endif

#include <gtest/gtest.h>

#include "tensorrt_llm/common/memoryUtils.h"
#include "tensorrt_llm/kernels/decodingCommon.h"
#include "tensorrt_llm/kernels/decodingKernels.h"
#include "tensorrt_llm/runtime/bufferManager.h"
#include "tensorrt_llm/runtime/runtimeKernels.h"

#include <curand_kernel.h>

#include <algorithm>
#include <cmath>
#include <cstdint>
#include <iostream>
#include <limits>
#include <memory>
#include <random>
#include <sstream>
#include <unordered_set>
#include <vector>

namespace tk = tensorrt_llm::kernels;
namespace tc = tensorrt_llm::common;
namespace trk = tensorrt_llm::runtime::kernels;

using namespace tensorrt_llm::runtime;

namespace
{

inline bool almostEqual(float a, float b, float atol = 1e-2f, float rtol = 1e-3f)
{
    // Params: a = value to compare, b = reference value.
    // This function follows the implementation of numpy.isclose(), which checks
    // abs(a - b) <= (atol + rtol * abs(b)).
    // Note that the inequality is asymmetric: b is treated as the reference
    // value. Absolute and relative tolerances are combined to account for both
    // absolute and relative errors; the default atol and rtol are borrowed from
    // numpy.isclose(). If both values are NaN, the result is true.
    if (std::isnan(a) && std::isnan(b))
    {
        return true;
    }
    return std::fabs(a - b) <= (atol + rtol * std::fabs(b));
}
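
// For example, with the defaults almostEqual(100.05f, 100.f) is true because
// |100.05 - 100| = 0.05 <= 0.01 + 0.001 * 100 = 0.11.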

std::vector<float> calculateGaussianKernel(float sigma, int size)
{
    std::vector<float> kernel(size);
    float sum = 0.f;

    for (int i = 0; i < size; ++i)
    {
        int x = i - size / 2;
        kernel[i] = std::exp(-0.5f * (x * x) / (sigma * sigma));
        sum += kernel[i];
    }

    // Normalize the kernel
    for (int i = 0; i < size; ++i)
    {
        kernel[i] /= sum;
    }

    return kernel;
}
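
// The kernel above is the sampled 1D Gaussian kernel[i] = exp(-x^2 / (2 * sigma^2)) / Z,
// with x = i - size / 2 and Z chosen so the weights sum to 1. For sigma = 1 the caller
// below picks size = ceil(6 * sigma) rounded up to odd, i.e. a 7-tap kernel.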

template <typename T>
void applyGaussianFilter(T* result, float const* input, int n, float sigma)
{
    int size = static_cast<int>(std::ceil(6.f * sigma));
    size = (size % 2 == 0) ? size + 1 : size;

    std::vector<float> kernel = calculateGaussianKernel(sigma, size);
    int halfSize = size / 2;

    for (int i = 0; i < n; ++i)
    {
        result[i] = T{0};
    }

    // Convolution operation
    for (int i = 0; i < n; ++i)
    {
        for (int j = 0; j < size; ++j)
        {
            int k = i - halfSize + j;
            if (k >= 0 && k < n)
            {
                result[i] += input[k] * kernel[j];
            }
        }
    }
}

template void applyGaussianFilter(float* result, float const* input, int n, float sigma);
template void applyGaussianFilter(__half* result, float const* input, int n, float sigma);
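
// In these tests the filter turns a one-hot "peak" probability vector into a smooth
// distribution concentrated around the chosen token, emulating a realistic
// (non-degenerate) model output before it is converted to logits.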

template <typename T>
void probsToLogits(T const* probs, T* logits, SizeType n)
{
    constexpr float eps = 1e-6f;
    for (SizeType ni = 0; ni < n; ++ni)
    {
        auto const prob = std::max(eps, static_cast<float>(probs[ni]));
        logits[ni] = std::log(prob / (1.f - prob));
    }
}

template <typename T>
void softmax(T const* logits, T* probs, int n)
{
    float epsilon = 1e-6f;

    // Find the maximum logit value
    float maxLogits = -std::numeric_limits<float>::max();
    for (int ii = 0; ii < n; ++ii)
    {
        maxLogits = std::max(maxLogits, static_cast<float>(logits[ii]));
    }

    // Accumulate the normalizer (denominator) of the softmax formula
    float expSum = 0.0f;
    for (int ii = 0; ii < n; ++ii)
    {
        expSum += std::exp(static_cast<float>(logits[ii]) - maxLogits);
    }

    // Calculate softmax probabilities
    for (int ii = 0; ii < n; ++ii)
    {
        float prob = std::exp(static_cast<float>(logits[ii]) - maxLogits) / (expSum + epsilon);
        probs[ii] = prob;
    }
}
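
// probsToLogits is the elementwise inverse sigmoid logit(p) = log(p / (1 - p)), and
// softmax uses the standard max-subtraction trick for numerical stability. The test
// round-trips probs -> logits -> softmax so the reference probabilities go through
// the same floating-point path as the kernels under test.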

template void probsToLogits(float const* probs, float* logits, SizeType n);
template void probsToLogits(__half const* probs, __half* logits, SizeType n);

enum AcceptKernelMode
{
    BY_IDS,          // compare draft and target token ids
    BY_LOGITS,       // accept/reject based on draft and target distributions
    BY_IDS_WITH_PATH // compare token ids along multiple candidate (Medusa) paths
};
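
// Each mode exercises one kernel: BY_IDS -> invokeAcceptDraftTokensByIds,
// BY_LOGITS -> acceptDraftTokensByLogits, and BY_IDS_WITH_PATH ->
// acceptDraftTokensByIdsWithPaths (see callTestedKernel below).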

struct DecodingKernelTestParam
{
    SizeType mBatchSize{128};
    SizeType mMaxBatchSize{2 * mBatchSize};
    SizeType mBeamWidth{1};
    SizeType mMaxSeqLen{16};
    SizeType mVocabSize{32};
    SizeType mMaxDraftTokens{8};
    SizeType mMaxNumHeads{0};
    SizeType mMaxDraftSeqPerStep{1};
    AcceptKernelMode mAcceptMode{AcceptKernelMode::BY_IDS};

    DecodingKernelTestParam& setBatchSize(SizeType bs)
    {
        mBatchSize = bs;
        mMaxBatchSize = 2 * mBatchSize;
        return *this;
    }

    DecodingKernelTestParam& setVocabSize(SizeType vs)
    {
        mVocabSize = vs;
        return *this;
    }

    DecodingKernelTestParam& setMaxSeqLen(SizeType msl)
    {
        mMaxSeqLen = msl;
        return *this;
    }

    DecodingKernelTestParam& setMaxDraftTokens(SizeType dt)
    {
        mMaxDraftTokens = dt;
        return *this;
    }

    DecodingKernelTestParam& setMaxNumHeads(SizeType mnh)
    {
        mMaxNumHeads = mnh;
        return *this;
    }

    DecodingKernelTestParam& setMaxDraftSeqPerStep(SizeType tps)
    {
        mMaxDraftSeqPerStep = tps;
        return *this;
    }

    DecodingKernelTestParam& setAcceptMode(AcceptKernelMode const& mode)
    {
        mAcceptMode = mode;
        return *this;
    }
};
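
// The setters return *this so parameters can be chained fluently, e.g.
// DecodingKernelTestParam().setBatchSize(1).setVocabSize(32).setAcceptMode(AcceptKernelMode::BY_IDS),
// as done by the TYPED_TESTs at the bottom of this file.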

template <typename T>
class DecodingKernelsTest : public testing::Test
{
public:
    using TensorPtr = tensorrt_llm::runtime::ITensor::SharedPtr;

    void SetUp() override
    {
        mStream = std::make_shared<tensorrt_llm::runtime::CudaStream>();
        mBufferManager = std::make_shared<tensorrt_llm::runtime::BufferManager>(mStream);
    }

    void TearDown() override {}

    void createBuffers()
    {
        auto const dataType = TRTDataType<T>::value;
        auto const ptrType = TRTDataType<T*>::value;

        mDraftTokens
            = BufferManager::pinned(ITensor::makeShape({mMaxBatchSize, mMaxDraftSeqlen}), nvinfer1::DataType::kINT32);
        mTargetTokens
            = BufferManager::pinned(ITensor::makeShape({mMaxBatchSize, mMaxTargetSeqlen}), nvinfer1::DataType::kINT32);
        mOutputTokens
            = BufferManager::pinned(ITensor::makeShape({mMaxBatchSize, mMaxSeqLen}), nvinfer1::DataType::kINT32);
        mNumsDraftTokens = BufferManager::pinned(
            ITensor::makeShape({mMaxBatchSize, mMaxDraftSeqPerStep}), nvinfer1::DataType::kINT32);
        mSequenceLengths = BufferManager::pinned(ITensor::makeShape({mMaxBatchSize}), nvinfer1::DataType::kINT32);
        mAcceptedLengths = BufferManager::pinned(ITensor::makeShape({mMaxBatchSize}), nvinfer1::DataType::kINT32);
        mContextLengths = BufferManager::pinned(ITensor::makeShape({mMaxBatchSize}), nvinfer1::DataType::kINT32);
        mFinishedSteps = BufferManager::pinned(ITensor::makeShape({mMaxDraftTokens + 1, mMaxBatchSize}),
            TRTDataType<tk::FinishedState::UnderlyingType>::value);
        mFinishedFinal = BufferManager::pinned(
            ITensor::makeShape({mMaxBatchSize}), TRTDataType<tk::FinishedState::UnderlyingType>::value);
        mFinishedSum = BufferManager::pinned(ITensor::makeShape({mMaxBatchSize}), nvinfer1::DataType::kINT32);

        mPaths = BufferManager::pinned(
            ITensor::makeShape({mMaxBatchSize, mMaxDraftSeqPerStep, mMaxDraftTokens}), nvinfer1::DataType::kINT32);
        mEndIds = BufferManager::pinned(ITensor::makeShape({mMaxBatchSize}), nvinfer1::DataType::kINT32);

        mBatchSlots = BufferManager::pinned(ITensor::makeShape({mBatchSize}), nvinfer1::DataType::kINT32);

        mCurandStates = mBufferManager->gpu(
            ITensor::makeShape({mMaxBatchSize, sizeof(curandState_t)}), nvinfer1::DataType::kINT8);

        mAcceptedLen.resize(mMaxBatchSize);
        mOutputLen.resize(mMaxBatchSize);
        mAcceptedFinished.resize(mMaxBatchSize, tk::FinishedState::empty());

        // Buffers only for logits comparison
        if (mAcceptMode == AcceptKernelMode::BY_LOGITS)
        {
            mDraftLogits = BufferManager::pinned(
                ITensor::makeShape({mMaxBatchSize, mMaxTotalDraftTokens, mVocabSize}), dataType);
            mTargetLogits = BufferManager::pinned(
                ITensor::makeShape({mMaxBatchSize, mMaxTotalDraftTokens, mVocabSize}), dataType);
            mTargetLogitsPtrs = BufferManager::pinned(ITensor::makeShape({mMaxBatchSize}), ptrType);
            mRefTargetLogits = BufferManager::pinned(
                ITensor::makeShape({mMaxBatchSize, mMaxTotalDraftTokens, mVocabSize}), dataType);

            mDraftProbs = BufferManager::pinned(
                ITensor::makeShape({mMaxBatchSize, mMaxTotalDraftTokens, mVocabSize}), dataType);
            mTargetProbs = BufferManager::pinned(
                ITensor::makeShape({mMaxBatchSize, mMaxTotalDraftTokens, mVocabSize}), dataType);
        }

        if (mAcceptMode == AcceptKernelMode::BY_IDS_WITH_PATH)
        {
            mMedusaLogitsPtrs = BufferManager::pinned(
                ITensor::makeShape({mMaxBatchSize, mMaxDraftSeqPerStep, mMaxNumHeads}), ptrType);
            mMedusaInputLogitsPtrs = BufferManager::pinned(ITensor::makeShape({mMaxBatchSize, mMaxNumHeads}), ptrType);
            mTokensPerStep = BufferManager::pinned(ITensor::makeShape({mMaxBatchSize}), nvinfer1::DataType::kINT32);
            mBestPaths = BufferManager::pinned(ITensor::makeShape({mMaxBatchSize}), nvinfer1::DataType::kINT32);
        }
    }
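
    // Layout notes: most buffers are sized for mMaxBatchSize "slots", while the
    // kernels operate on a compact batch of mBatchSize entries indirected through
    // mBatchSlots (initData assigns slot 2 * bi to batch index bi). mFinishedSteps
    // is laid out [draft step][slot], so step di of slot bi lives at
    // di * mMaxBatchSize + bi.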

    void initData(SizeType seed)
    {
        std::mt19937 generator(seed);
        std::uniform_int_distribution<SizeType> contextLenDistr(0, std::max(mMaxSeqLen - mMaxTotalDraftTokens, 0));
        std::uniform_int_distribution<SizeType> numTotalDraftTokensDistr(1, mMaxTotalDraftTokens);
        std::uniform_int_distribution<SizeType> numDraftTokensDistr(0, mMaxDraftTokens);
        std::uniform_int_distribution<SizeType> vocabDistr(1, mVocabSize - 1);

        trk::invokeFill(*mPaths, int32_t{-1}, *mStream);
        trk::invokeFill(*mFinishedFinal, tk::FinishedState::UnderlyingType{0}, *mStream);

        auto sequenceLengthsPtr = BufferRange<SizeType>(*mSequenceLengths);
        auto contextLengthsPtr = BufferRange<SizeType>(*mContextLengths);
        auto numsDraftTokensPtr = BufferRange<SizeType>(*mNumsDraftTokens);
        auto draftTokensPtr = BufferRange<SizeType>(*mDraftTokens);
        auto targetTokensPtr = BufferRange<SizeType>(*mTargetTokens);
        auto finishedStepsPtr
            = reinterpret_cast<tk::FinishedState*>(bufferCast<tk::FinishedState::UnderlyingType>(*mFinishedSteps));
        auto pathsPtr = BufferRange<SizeType>(*mPaths);
        auto endIdsPtr = BufferRange<SizeType>(*mEndIds);

        auto batchSlotsPtr = BufferRange<SizeType>(*mBatchSlots);

        tk::invokeCurandInitialize(reinterpret_cast<curandState_t*>(bufferCast<int8_t>(*mCurandStates)), nullptr,
            mMaxBatchSize, seed, this->mStream->get());

        // Draw a random value from distr, retrying while it falls into tokensToAvoid.
        // After maxTries failed attempts (if maxTries >= 0), fall back to defaultValue.
        auto generateAvoidingValues
            = [&generator](std::uniform_int_distribution<SizeType>& distr,
                  std::unordered_set<SizeType> const& tokensToAvoid, SizeType maxTries = -1, SizeType defaultValue = -1)
        {
            auto token = distr(generator);
            SizeType tries = 0;
            while (tokensToAvoid.count(token) != 0 && ((maxTries >= 0 && tries < maxTries) || maxTries < 0))
            {
                token = distr(generator);
                tries++;
            }
            if (tries == maxTries)
            {
                token = defaultValue;
            }
            return token;
        };

        // Init batch slots
        for (SizeType bi = 0; bi < mBatchSize; ++bi)
        {
            batchSlotsPtr[bi] = 2 * bi;
        }

        // Init end ids
        for (SizeType bi = 0; bi < mMaxBatchSize; ++bi)
        {
            endIdsPtr[bi] = generateAvoidingValues(vocabDistr, {mPadId});
            TLLM_LOG_DEBUG("bi %d endIdsPtr[bi] %d", bi, endIdsPtr[bi]);

            // Randomly init context len for target and draft
            contextLengthsPtr[bi] = contextLenDistr(generator);
        }

        std::fill(draftTokensPtr.begin(), draftTokensPtr.begin() + mMaxBatchSize * mMaxDraftSeqlen, mPadId);
        std::fill(targetTokensPtr.begin(), targetTokensPtr.begin() + mMaxBatchSize * mMaxTargetSeqlen, mPadId);
        std::fill(pathsPtr.begin(), pathsPtr.begin() + mMaxBatchSize * mMaxDraftSeqPerStep * mMaxDraftTokens, -1);

        // Generate paths
        for (SizeType bi = 0; bi < mMaxBatchSize; ++bi)
        {
            auto const numTotalDraftTokens = std::min(mMaxDraftTokens, numTotalDraftTokensDistr(generator));
            std::uniform_int_distribution<SizeType> pathIdDistr(0, numTotalDraftTokens);
            for (SizeType pi = 0; pi < mMaxDraftSeqPerStep; ++pi)
            {
                std::unordered_set<SizeType> pathIds;
                auto const numDraftTokensAtStep = numDraftTokensDistr(generator);
                numsDraftTokensPtr[bi * mMaxDraftSeqPerStep + pi] = numDraftTokensAtStep;

                for (SizeType ti = 0; ti < numDraftTokensAtStep; ++ti)
                {
                    auto const pathIdx = tc::flat_index3(bi, pi, ti, mMaxDraftSeqPerStep, mMaxDraftTokens);
                    // Single linear path for BY_IDS and BY_LOGITS modes
                    auto const pathId = mAcceptMode == AcceptKernelMode::BY_IDS_WITH_PATH && ti != 0
                        ? generateAvoidingValues(pathIdDistr, pathIds, mMaxDraftTokens * 5, -1)
                        : ti;
                    pathsPtr[pathIdx] = pathId;
                    pathIds.insert(pathId);
                }
                if (bi == 2)
                {
                    TLLM_LOG_DEBUG("bi %d pi %d numsDraftTokensPtr[bi] %d", bi, pi, numDraftTokensAtStep);
                }
            }
        }

        for (SizeType ti = 0; ti < mMaxDraftSeqPerStep; ++ti)
        {
            std::vector<SizeType> targetPredictedLen(mMaxBatchSize);
            std::vector<SizeType> targetAcceptedLen(mMaxBatchSize);

            // Init number of draft tokens
            for (SizeType bi = 0; bi < mMaxBatchSize; ++bi)
            {
                // It can be shorter than the number of draft tokens due to EOS generation
                std::uniform_int_distribution<SizeType> realDraftTokensDistr(
                    0, numsDraftTokensPtr[bi * mMaxDraftSeqPerStep + ti]);
                targetPredictedLen[bi] = realDraftTokensDistr(generator);
                // Accept ~half of the tokens on average
                std::poisson_distribution<SizeType> targetAcceptedDistr(targetPredictedLen[bi] / 2);
                targetAcceptedLen[bi] = std::min(targetAcceptedDistr(generator), targetPredictedLen[bi]);
                if (bi == 2)
                {
                    TLLM_LOG_DEBUG("bi %d ti %d targetPredictedLen[bi] %d targetAcceptedLen[bi] %d", bi, ti,
                        targetPredictedLen[bi], targetAcceptedLen[bi]);
                }
            }

            // Fill draft tokens
            for (SizeType bi = 0; bi < mMaxBatchSize; ++bi)
            {
                for (SizeType si = 0; si < numsDraftTokensPtr[bi * mMaxDraftSeqPerStep + ti]; ++si)
                {
                    auto const pathIdx = tc::flat_index3(bi, ti, si, mMaxDraftSeqPerStep, mMaxDraftTokens);
                    if (pathsPtr[pathIdx] == -1)
                    {
                        continue;
                    }
                    auto const draftTokenIdx = bi * mMaxDraftSeqlen + pathsPtr[pathIdx];
                    // Avoid generating endId. We'll insert it manually later if needed.
                    draftTokensPtr[draftTokenIdx] = generateAvoidingValues(vocabDistr, {mPadId, endIdsPtr[bi]});
                    if (bi == 2)
                    {
                        TLLM_LOG_DEBUG("bi %d ti %d si %d pathId %d draftToken %d", bi, ti, si, pathsPtr[pathIdx],
                            draftTokensPtr[draftTokenIdx]);
                    }
                }
            }

            for (SizeType bi = 0; bi < mMaxBatchSize; ++bi)
            {
                sequenceLengthsPtr[bi] = contextLengthsPtr[bi] + targetPredictedLen[bi];

                // Initialize finished states
                for (int di = 0; di < numsDraftTokensPtr[bi * mMaxDraftSeqPerStep + ti]; ++di)
                {
                    finishedStepsPtr[di * mMaxBatchSize + bi]
                        = (di < targetPredictedLen[bi]) ? tk::FinishedState::empty() : tk::FinishedState::finished();
                }

                // Init helper vectors
                mAcceptedLen[bi] = contextLengthsPtr[bi] + std::max(targetAcceptedLen[bi], 0);
                mOutputLen[bi] = std::min(sequenceLengthsPtr[bi], std::min(mAcceptedLen[bi] + 1, mMaxSeqLen));
                mAcceptedFinished[bi] = finishedStepsPtr[std::max(targetAcceptedLen[bi], 0) * mMaxBatchSize + bi];
                if (bi == 2)
                {
                    TLLM_LOG_DEBUG(
                        "bi %d ti %d contextLengthsPtr[bi] %d sequenceLengthsPtr[bi] %d mAcceptedLen[bi] %d "
                        "mOutputLen[bi] %d",
                        bi, ti, contextLengthsPtr[bi], sequenceLengthsPtr[bi], mAcceptedLen[bi], mOutputLen[bi]);
                }
            }

            // Fill token arrays
            for (SizeType bi = 0; bi < mMaxBatchSize; ++bi)
            {
                // Draft:  [d0, d1, d2, ... for numsDraftTokensPtr[bi] ..., dK,
                //          padId, padId, ... up to mMaxDraftSeqlen]
                // Target: [padId, padId, ... for contextLengthsPtr[bi] ... padId,
                //          d0, d1, d2, ... for targetAcceptedLen[bi],
                //          ti (!= di), ti+1 (!= di+1), ... for (targetPredictedLen[bi] - targetAcceptedLen[bi]),
                //          EOS, EOS, EOS, ... for (numsDraftTokensPtr[bi] - targetPredictedLen[bi]),
                //          padId, padId, ... up to mMaxSeqLen]
                auto numDraftTokens = numsDraftTokensPtr[bi * mMaxDraftSeqPerStep + ti];
                for (SizeType si = 0; si < numDraftTokens; ++si)
                {
                    auto const curPathIdx = tc::flat_index3(bi, ti, si, mMaxDraftSeqPerStep, mMaxDraftTokens);
                    auto const nextPathIdx = si + 1 < numDraftTokens
                        ? tc::flat_index3(bi, ti, si + 1, mMaxDraftSeqPerStep, mMaxDraftTokens)
                        : -1;
                    auto const curPathId = pathsPtr[curPathIdx];
                    auto nextPathId = curPathId;
                    if (mAcceptMode == AcceptKernelMode::BY_IDS_WITH_PATH)
                    {
                        nextPathId = nextPathIdx > -1 ? pathsPtr[nextPathIdx] : -1;
                    }

                    if (curPathId == -1)
                    {
                        continue;
                    }
                    auto const contextLen
                        = mAcceptMode == AcceptKernelMode::BY_IDS_WITH_PATH ? 0 : contextLengthsPtr[bi];
                    auto const draftTokenIdx = bi * mMaxDraftSeqlen + nextPathId;
                    auto const targetTokenIdx = bi * mMaxTargetSeqlen + contextLen + curPathId;
                    auto targetToken = mPadId;
                    if (0 <= si && si < targetAcceptedLen[bi] && nextPathId != -1)
                    {
                        // Use the draft token up to the accepted len
                        targetToken = draftTokensPtr[draftTokenIdx];
                    }
                    else if (0 <= si && si < targetPredictedLen[bi])
                    {
                        // Do not use the draft token up to the generated len
                        std::unordered_set<SizeType> avoidValues = {mPadId, endIdsPtr[bi]};
                        if (nextPathId != -1)
                        {
                            avoidValues.insert(draftTokensPtr[draftTokenIdx]);
                        }
                        targetToken = generateAvoidingValues(vocabDistr, avoidValues);
                    }
                    else if (targetPredictedLen[bi] <= si && si < numDraftTokens)
                    {
                        // Fill with EOS from the generated len to the draft len
                        targetToken = endIdsPtr[bi];
                    }
                    targetTokensPtr[targetTokenIdx] = targetToken;
                    if (bi == 2)
                    {
                        TLLM_LOG_DEBUG(
                            "bi %d ti %d si %d pathId %d targetToken %d", bi, ti, si, curPathId, targetToken);
                    }
                }
            }
        }

        if (mAcceptMode == AcceptKernelMode::BY_LOGITS)
        {
            initDataAndReferenceAcceptByLogits();
        }

        if (mAcceptMode == AcceptKernelMode::BY_IDS_WITH_PATH)
        {
            initDataAndReferenceAcceptByIdsWithPaths();
        }

        mSequenceLengthsCopy = mBufferManager->copyFrom(*mSequenceLengths, MemoryType::kCPU);
    }

    void initDataAndReferenceAcceptByIdsWithPaths()
    {
        auto pathsPtr = BufferRange<SizeType>(*mPaths);
        auto endIdsPtr = BufferRange<SizeType>(*mEndIds);
        auto draftTokensPtr = BufferRange<SizeType>(*mDraftTokens);
        auto targetTokensPtr = BufferRange<SizeType>(*mTargetTokens);
        auto medusaInputLogitsPtr = BufferRange<T*>(*mMedusaInputLogitsPtrs);

        trk::invokeFill(*mMedusaLogitsPtrs, int64_t{0}, *mStream);
        trk::invokeFill(*mTokensPerStep, int32_t{mMaxTotalDraftTokens}, *mStream);
        trk::invokeFill(*mBestPaths, int32_t{-1}, *mStream);

        mAcceptedLen.resize(mMaxBatchSize);
        mAcceptedPathIdx.resize(mMaxBatchSize);
        mRefAcceptedTokens.resize(mMaxBatchSize);
        mFinishedByIdsPaths.resize(mMaxBatchSize);
        mLastTargetIdx.resize(mMaxBatchSize);
        for (SizeType bi = 0; bi < mMaxBatchSize; ++bi)
        {
            SizeType maxAcceptedLen = -1;
            SizeType maxAcceptedPath = -1;
            SizeType maxNextTargetTokenIdx = -1;
            bool maxFinished = false;
            std::vector<SizeType> maxAcceptedTokens;
            for (SizeType ti = 0; ti < mMaxDraftSeqPerStep; ++ti)
            {
                std::vector<SizeType> acceptedTokens;
                SizeType curAcceptedLen = mMaxDraftTokens;
                SizeType curAcceptedPath = ti;
                bool curFinished = false;

                auto const pathIdx = tc::flat_index3(bi, ti, 0, mMaxDraftSeqPerStep, mMaxDraftTokens);
                auto const pathId = pathsPtr[pathIdx];
                if (pathId == -1)
                {
                    continue;
                }
                auto targetTokenIdx = bi * mMaxTargetSeqlen + pathId;
                auto targetToken = targetTokensPtr[targetTokenIdx];
                auto curNextTargetTokenIdx = pathId;
                for (SizeType di = 1; di < mMaxDraftTokens; ++di)
                {
                    auto const pathIdx = tc::flat_index3(bi, ti, di, mMaxDraftSeqPerStep, mMaxDraftTokens);
                    auto const pathId = pathsPtr[pathIdx];
                    if (pathId == -1)
                    {
                        curAcceptedLen = di;
                        curAcceptedPath = ti;
                        curFinished = false;
                        acceptedTokens.push_back(targetToken);
                        break;
                    }
                    auto const draftTokenIdx = bi * mMaxDraftSeqlen + pathId - 1;
                    auto const targetTokenIdx = bi * mMaxTargetSeqlen + pathId;
                    auto const draftToken = draftTokensPtr[draftTokenIdx];
                    bool const hasEnd = targetToken == endIdsPtr[bi];
                    if (!hasEnd)
                    {
                        acceptedTokens.push_back(targetToken);
                    }
                    if (draftToken != targetToken || hasEnd)
                    {
                        auto const curLen = hasEnd ? di - 1 : di;
                        curAcceptedLen = curLen;
                        curAcceptedPath = ti;
                        curFinished = hasEnd;
                        break;
                    }
                    targetToken = targetTokensPtr[targetTokenIdx];
                    curNextTargetTokenIdx = pathId;
                }
                if (curAcceptedLen == mMaxDraftTokens)
                {
                    acceptedTokens.push_back(targetToken);
                }
                if (curAcceptedLen > maxAcceptedLen)
                {
                    maxAcceptedLen = curAcceptedLen;
                    maxAcceptedPath = curAcceptedPath;
                    maxAcceptedTokens = acceptedTokens;
                    maxFinished = curFinished;
                    maxNextTargetTokenIdx = curNextTargetTokenIdx;
                }
            }
            mAcceptedLen[bi] = maxAcceptedLen;
            mAcceptedPathIdx[bi] = maxAcceptedPath;
            mRefAcceptedTokens[bi] = maxAcceptedTokens;
            mFinishedByIdsPaths[bi] = maxFinished;
            mLastTargetIdx[bi] = maxNextTargetTokenIdx;
            for (SizeType hi = 0; hi < mMaxNumHeads; ++hi)
            {
                // Encode the logits offset of head hi for this slot as a pointer relative
                // to nullptr; the verification later decodes it back to an offset.
                medusaInputLogitsPtr[bi * mMaxNumHeads + hi] = static_cast<T*>(nullptr)
                    + tc::flat_index4(hi, bi, 0, 0, mMaxBatchSize, mMaxDraftSeqPerStep, mVocabSize);
            }
            if (bi == 2)
            {
                TLLM_LOG_DEBUG("bi %d maxAcceptedLen %d maxAcceptedPath %d maxNextTargetTokenIdx %d", bi,
                    maxAcceptedLen, maxAcceptedPath, maxNextTargetTokenIdx);
                std::ostringstream ss;
                for (auto& token : maxAcceptedTokens)
                {
                    ss << token << " ";
                }
                TLLM_LOG_DEBUG(ss.str().c_str());
            }
        }
    }
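
    // The reference path selection above is greedy: for every candidate path it walks
    // draft vs. target tokens until the first mismatch or EOS, then keeps the path with
    // the largest accepted length, mirroring what acceptDraftTokensByIdsWithPaths is
    // expected to pick.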

    void initDataAndReferenceAcceptByLogits()
    {
        auto contextLengthsPtr = BufferRange<SizeType>(*mContextLengths);
        auto numsDraftTokensPtr = BufferRange<SizeType>(*mNumsDraftTokens);
        auto draftTokensPtr = BufferRange<SizeType>(*mDraftTokens);
        auto targetTokensPtr = BufferRange<SizeType>(*mTargetTokens);

        auto draftProbsPtr = BufferRange<T>(*mDraftProbs);
        auto targetProbsPtr = BufferRange<T>(*mTargetProbs);

        auto draftLogitsPtr = BufferRange<T>(*mDraftLogits);
        auto targetLogitsPtr = BufferRange<T>(*mTargetLogits);
        auto targetLogitsPtrsPtr = BufferRange<T*>(*mTargetLogitsPtrs);
        auto refTargetLogitsPtr = BufferRange<T>(*mRefTargetLogits);
        auto batchSlotsPtr = BufferRange<SizeType>(*mBatchSlots);

        for (SizeType bi = 0; bi < mMaxBatchSize; ++bi)
        {
            // Init draft and target logits and probabilities
            for (SizeType si = 0; si < numsDraftTokensPtr[bi]; ++si)
            {
                std::vector<float> peakDraftProb(mVocabSize, 0.f);
                std::vector<float> peakTargetProb(mVocabSize, 0.f);

                auto const targetToken = targetTokensPtr[bi * mMaxSeqLen + contextLengthsPtr[bi] + si] % mVocabSize;
                auto const draftToken = draftTokensPtr[bi * mMaxDraftTokens + si] % mVocabSize;

                peakDraftProb[draftToken] = 1.f;
                peakTargetProb[targetToken] = 1.f;

                auto const logitsOffset = bi * mMaxDraftTokens * mVocabSize + si * mVocabSize;
                // Emulate some distribution around the target token
                applyGaussianFilter(
                    draftProbsPtr.begin() + logitsOffset, peakDraftProb.data(), peakDraftProb.size(), 1.0f);
                applyGaussianFilter(
                    targetProbsPtr.begin() + logitsOffset, peakTargetProb.data(), peakTargetProb.size(), 1.0f);

                // Probabilities to logits
                probsToLogits(draftProbsPtr.begin() + logitsOffset, draftLogitsPtr.begin() + logitsOffset, mVocabSize);
                probsToLogits(
                    targetProbsPtr.begin() + logitsOffset, targetLogitsPtr.begin() + logitsOffset, mVocabSize);

                // Do the softmax conversion back to emulate the kernels' accuracy
                softmax(draftLogitsPtr.begin() + logitsOffset, draftProbsPtr.begin() + logitsOffset, mVocabSize);
                softmax(targetLogitsPtr.begin() + logitsOffset, targetProbsPtr.begin() + logitsOffset, mVocabSize);
            }
        }

        for (SizeType bi = 0; bi < mMaxBatchSize; ++bi)
        {
            for (SizeType si = 0; si < mMaxDraftTokens; ++si)
            {
                auto const logitsOffset = bi * mMaxDraftTokens * mVocabSize + si * mVocabSize;
                auto const acceptedLen = mAcceptedLen[bi] - contextLengthsPtr[bi];
                if (si < acceptedLen)
                {
                    auto logitsStart = targetLogitsPtr.begin() + logitsOffset;
                    std::copy(logitsStart, logitsStart + mVocabSize, refTargetLogitsPtr.begin() + logitsOffset);
                }
                else if (si == acceptedLen)
                {
                    // When the token is not accepted, correct the probabilities and compute updated logits
                    float sumProb = 1e-6f;
                    for (SizeType vi = 0; vi < mVocabSize; ++vi)
                    {
                        auto const correctedProb = std::max(
                            static_cast<float>(targetProbsPtr[logitsOffset + vi] - draftProbsPtr[logitsOffset + vi]),
                            0.f);
                        sumProb += correctedProb;
                    }
                    for (SizeType vi = 0; vi < mVocabSize; ++vi)
                    {
                        auto prob = std::max(static_cast<float>(
                                                 targetProbsPtr[logitsOffset + vi] - draftProbsPtr[logitsOffset + vi]),
                                        0.f)
                            / sumProb;
                        if (prob < 1e-8)
                        {
                            prob = 0.f;
                        }
                        refTargetLogitsPtr[logitsOffset + vi] = std::log(prob / (1.f - prob));
                    }
                }
            }
        }
        for (SizeType bi = 0; bi < mBatchSize; ++bi)
        {
            targetLogitsPtrsPtr[bi] = targetLogitsPtr.begin() + batchSlotsPtr[bi] * mMaxDraftTokens * mVocabSize;
        }
    }
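
    // The si == acceptedLen branch reconstructs the residual distribution used by
    // speculative-sampling style rejection when a draft token is not accepted:
    //     p'(x) = max(p_target(x) - p_draft(x), 0) / sum_y max(p_target(y) - p_draft(y), 0),
    // which the reference then maps back to logits via log(p / (1 - p)).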

    void callAcceptByIds()
    {
        tk::invokeAcceptDraftTokensByIds(bufferCast<SizeType>(*mDraftTokens), bufferCast<SizeType>(*mTargetTokens),
            bufferCast<SizeType>(*mContextLengths), bufferCast<SizeType>(*mNumsDraftTokens),
            bufferCast<SizeType>(*mSequenceLengths),
            reinterpret_cast<tk::FinishedState*>(bufferCast<tk::FinishedState::UnderlyingType>(*mFinishedSteps)),
            reinterpret_cast<tk::FinishedState*>(bufferCast<tk::FinishedState::UnderlyingType>(*mFinishedFinal)),
            bufferCast<SizeType>(*mFinishedSum), bufferCast<SizeType>(*mBatchSlots), mBatchSize, mMaxBatchSize,
            mBeamWidth, mMaxSeqLen, mMaxDraftTokens, mStream->get());
    }

    void callAcceptByLogits()
    {
        tk::acceptDraftTokensByLogits(bufferCast<T>(*mDraftLogits),
            reinterpret_cast<T**>(bufferCast<int64_t>(*mTargetLogitsPtrs)), bufferCast<T>(*mDraftProbs),
            bufferCast<T>(*mTargetProbs), bufferCast<SizeType>(*mNumsDraftTokens),
            reinterpret_cast<tk::FinishedState*>(bufferCast<tk::FinishedState::UnderlyingType>(*mFinishedSteps)),
            reinterpret_cast<curandState_t*>(bufferCast<int8_t>(*mCurandStates)), bufferCast<SizeType>(*mBatchSlots),
            mBatchSize, mMaxBatchSize, mBeamWidth, mVocabSize, mVocabSize, mMaxDraftTokens, false, 0.9f,
            mStream->get());
    }

    void callAcceptByIdsWithPaths()
    {
        tk::acceptDraftTokensByIdsWithPaths(bufferCast<SizeType>(*mOutputTokens), bufferCast<SizeType>(*mDraftTokens),
            bufferCast<SizeType>(*mTargetTokens), bufferCast<SizeType>(*mSequenceLengths),
            bufferCast<SizeType>(*mAcceptedLengths),
            reinterpret_cast<tk::FinishedState*>(bufferCast<tk::FinishedState::UnderlyingType>(*mFinishedFinal)),
            bufferCast<SizeType>(*mBatchSlots), bufferCast<SizeType>(*mPaths), bufferCast<SizeType>(*mEndIds),
            reinterpret_cast<T const**>(bufferCast<int64_t>(*mMedusaInputLogitsPtrs)),
            reinterpret_cast<T const**>(bufferCast<int64_t>(*mMedusaLogitsPtrs)), bufferCast<SizeType>(*mTokensPerStep),
            bufferCast<SizeType>(*mTokensPerStep), bufferCast<SizeType>(*mBestPaths), mBatchSize, mVocabSize,
            mMaxBatchSize, mMaxTargetSeqlen, mMaxSeqLen, mMaxNumHeads, mMaxDraftSeqPerStep, mStream->get());
    }

    void callTestedKernel()
    {
        switch (mAcceptMode)
        {
        case AcceptKernelMode::BY_IDS: callAcceptByIds(); break;
        case AcceptKernelMode::BY_LOGITS: callAcceptByLogits(); break;
        case AcceptKernelMode::BY_IDS_WITH_PATH: callAcceptByIdsWithPaths(); break;
        default: TLLM_CHECK(false); // Should never be here
        }
    }

    void verifyAcceptByIdsResults(SizeType seed)
    {
        auto finishedFinalPtr
            = reinterpret_cast<tk::FinishedState*>(bufferCast<tk::FinishedState::UnderlyingType>(*mFinishedFinal));
        auto sequenceLengthsPtr = BufferRange<SizeType>(*mSequenceLengths);
        auto finishedSumPtr = BufferRange<SizeType>(*mFinishedSum);
        auto batchSlotsPtr = BufferRange<SizeType>(*mBatchSlots);
        // Verify seqLen for accepted tokens
        for (SizeType bi = 0; bi < mBatchSize; ++bi)
        {
            auto const batchSlot = batchSlotsPtr[bi];
            EXPECT_EQ(mOutputLen[batchSlot], sequenceLengthsPtr[batchSlot]) << " bi " << bi << " seed " << seed;
            EXPECT_EQ(mAcceptedFinished[batchSlot].isFinished(), finishedFinalPtr[batchSlot].isFinished())
                << " bi " << bi << " seed " << seed;
            EXPECT_EQ(mAcceptedFinished[batchSlot].isSkipDecoding(), finishedFinalPtr[batchSlot].isSkipDecoding())
                << " bi " << bi << " seed " << seed;
            EXPECT_EQ(static_cast<SizeType>(mAcceptedFinished[batchSlot].isFinished()), finishedSumPtr[batchSlot]);
        }
    }

    void verifyAcceptByLogitsResults(SizeType seed)
    {
        auto finishedStepsPtr
            = reinterpret_cast<tk::FinishedState*>(bufferCast<tk::FinishedState::UnderlyingType>(*mFinishedSteps));
        auto contextLengthsPtr = BufferRange<SizeType>(*mContextLengths);
        auto outLogitsPtr = BufferRange<T>(*mTargetLogits);
        auto refLogitsPtr = BufferRange<T>(*mRefTargetLogits);
        auto numsDraftTokensPtr = BufferRange<SizeType>(*mNumsDraftTokens);
        auto batchSlotsPtr = BufferRange<SizeType>(*mBatchSlots);

        for (SizeType bi = 0; bi < mBatchSize; ++bi)
        {
            auto const batchSlot = batchSlotsPtr[bi];
            for (SizeType si = 0; si < numsDraftTokensPtr[batchSlot]; ++si)
            {
                auto const outFinishedState = finishedStepsPtr[si * mMaxBatchSize + batchSlot];
                auto const logitsOffset = batchSlot * mMaxDraftTokens * mVocabSize + si * mVocabSize;
                if (si <= mAcceptedLen[batchSlot] - contextLengthsPtr[batchSlot])
                {
                    EXPECT_FALSE(outFinishedState.isSkipDecoding())
                        << " bi: " << bi << " si: " << si << " seed: " << seed;
                    for (SizeType vi = 0; vi < mVocabSize; ++vi)
                    {
                        auto const outLogit = static_cast<float>(outLogitsPtr[logitsOffset + vi]);
                        auto const refLogit = static_cast<float>(refLogitsPtr[logitsOffset + vi]);
                        // Logits <= -10 correspond to near-zero probabilities; both sides must
                        // agree on which entries are masked out before values are compared.
                        EXPECT_FALSE((refLogit > -10) ^ (outLogit > -10))
                            << " bi: " << bi << " si: " << si << " vi: " << vi << " seed: " << seed;
                        if (refLogit > -10 && outLogit > -10)
                        {
                            if (!almostEqual(outLogit, refLogit, 1e-1f, 1e-2f))
                            {
                                std::cout << refLogit << " " << outLogit << std::endl;
                            }
                            ASSERT_TRUE(almostEqual(outLogit, refLogit, 1e-1f, 1e-2f))
                                << " bi: " << bi << " si: " << si << " vi: " << vi << " seed: " << seed;
                        }
                    }
                }
                else
                {
                    EXPECT_TRUE(outFinishedState.isSkipDecoding())
                        << " bi: " << bi << " si: " << si << " seed: " << seed;
                }
            }
        }
    }

    void verifyAcceptByIdsWithPathsResults(SizeType seed)
    {
        auto medusaLogitsPtrsPtr = BufferRange<T*>(*mMedusaLogitsPtrs);
        auto batchSlotsPtr = BufferRange<SizeType>(*mBatchSlots);
        auto draftContextLengths = BufferRange<SizeType>(*mSequenceLengths);
        auto draftContextLengthsInit = BufferRange<SizeType>(*mSequenceLengthsCopy);
        auto acceptedLengths = BufferRange<SizeType>(*mAcceptedLengths);
        auto outputIdsPtr = BufferRange<TokenIdType>(*mOutputTokens);
        auto bestPathIds = BufferRange<SizeType>(*mBestPaths);
        auto finishedFinalPtr
            = reinterpret_cast<tk::FinishedState*>(bufferCast<tk::FinishedState::UnderlyingType>(*mFinishedFinal));

        for (SizeType bi = 0; bi < mBatchSize; ++bi)
        {
            auto const batchSlot = batchSlotsPtr[bi];
            auto const bestPathIdx = mAcceptedPathIdx[batchSlot];
            auto const lastTargetIdx = mLastTargetIdx[batchSlot];
            if (lastTargetIdx < 0)
            {
                continue;
            }

            auto const acceptedLen = mAcceptedLen[batchSlot];
            auto acceptedTokens = mRefAcceptedTokens[batchSlot];

            EXPECT_EQ(bestPathIds[batchSlot], bestPathIdx) << "bi: " << bi << " seed: " << seed;

            for (int32_t hi = 0; hi < mMaxNumHeads; ++hi)
            {
                auto refOffset
                    = tc::flat_index4(hi, batchSlot, lastTargetIdx, 0, mMaxBatchSize, mMaxDraftSeqPerStep, mVocabSize);
                // Decode the offset that was encoded as a pointer relative to nullptr
                auto outOffset
                    = static_cast<SizeType>(medusaLogitsPtrsPtr[bi * mMaxNumHeads + hi] - static_cast<T*>(nullptr));
                EXPECT_EQ(outOffset, refOffset) << " bi: " << bi << " hi: " << hi << " seed: " << seed;
            }
            EXPECT_EQ(acceptedLengths[batchSlot], acceptedLen) << " bi: " << bi << " seed: " << seed;
            EXPECT_EQ(draftContextLengths[batchSlot], draftContextLengthsInit[batchSlot] + acceptedLen)
                << " bi: " << bi << " seed: " << seed << " out: " << draftContextLengths[batchSlot]
                << " ref: " << draftContextLengthsInit[batchSlot] + acceptedLen;

            for (SizeType ti = 0; ti < acceptedLen; ++ti)
            {
                ASSERT_EQ(mRefAcceptedTokens[batchSlot].size(), acceptedLen)
                    << " bi: " << bi << " ti: " << ti << " seed: " << seed;
                EXPECT_EQ(outputIdsPtr[batchSlot * mMaxSeqLen + draftContextLengthsInit[batchSlot] + ti],
                    mRefAcceptedTokens[batchSlot][ti])
                    << " bi: " << bi << " ti: " << ti << " seed: " << seed;
            }
            EXPECT_EQ(finishedFinalPtr[batchSlot].isFinished(), mFinishedByIdsPaths[batchSlot])
                << " bi: " << bi << " seed: " << seed;
        }
    }

    void verifyResult(SizeType seed)
    {
        switch (mAcceptMode)
        {
        case AcceptKernelMode::BY_IDS: verifyAcceptByIdsResults(seed); break;
        case AcceptKernelMode::BY_LOGITS: verifyAcceptByLogitsResults(seed); break;
        case AcceptKernelMode::BY_IDS_WITH_PATH: verifyAcceptByIdsWithPathsResults(seed); break;
        default: TLLM_CHECK(false); // Should never be here
        }
    }

    void runTest(DecodingKernelTestParam const& params)
    {
        mAcceptMode = params.mAcceptMode;

        mBatchSize = params.mBatchSize;
        mMaxBatchSize = params.mMaxBatchSize;
        mBeamWidth = params.mBeamWidth;
        mVocabSize = params.mVocabSize;
        mMaxDraftTokens = params.mMaxDraftTokens;
        mMaxSeqLen = params.mMaxSeqLen;

        mMaxNumHeads = params.mMaxNumHeads;
        if (mMaxNumHeads > 1 && mAcceptMode != AcceptKernelMode::BY_IDS_WITH_PATH)
        {
            GTEST_SKIP() << "MaxNumHeads > 1 is only supported for AcceptKernelMode::BY_IDS_WITH_PATH";
        }

        mMaxDraftSeqPerStep = params.mMaxDraftSeqPerStep;
        if (mMaxDraftSeqPerStep > 1 && mAcceptMode != AcceptKernelMode::BY_IDS_WITH_PATH)
        {
            GTEST_SKIP() << "MaxDraftSeqPerStep > 1 is only supported for AcceptKernelMode::BY_IDS_WITH_PATH";
        }

        mMaxTotalDraftTokens = mMaxDraftSeqPerStep * mMaxDraftTokens;
        mPadId = mVocabSize - 1;

        mMaxDraftSeqlen = mAcceptMode == AcceptKernelMode::BY_IDS_WITH_PATH ? mMaxDraftTokens - 1 : mMaxDraftTokens;
        mMaxTargetSeqlen = mAcceptMode == AcceptKernelMode::BY_IDS_WITH_PATH ? mMaxDraftTokens : mMaxSeqLen;

        createBuffers();

        for (SizeType seed = 0; seed < mSeeds; ++seed)
        {
            TLLM_LOG_DEBUG("Seed %d", seed);

            initData(seed);

            mStream->synchronize();

            callTestedKernel();

            mStream->synchronize();

            verifyResult(seed);
        }
    }

protected:
    std::shared_ptr<tensorrt_llm::runtime::BufferManager> mBufferManager;
    std::shared_ptr<tensorrt_llm::runtime::CudaStream> mStream;

    TensorPtr mDraftTokens;
    TensorPtr mTargetTokens;
    TensorPtr mOutputTokens;

    TensorPtr mDraftLogits;
    TensorPtr mTargetLogits;
    TensorPtr mTargetLogitsPtrs;
    TensorPtr mRefTargetLogits;

    TensorPtr mDraftProbs;
    TensorPtr mTargetProbs;

    TensorPtr mNumsDraftTokens;
    TensorPtr mSequenceLengths;
    TensorPtr mSequenceLengthsCopy;
    TensorPtr mAcceptedLengths;
    TensorPtr mContextLengths;
    TensorPtr mFinishedSteps;
    TensorPtr mFinishedFinal;
    TensorPtr mFinishedSum;
    TensorPtr mBatchSlots;

    TensorPtr mPaths;
    TensorPtr mEndIds;
    TensorPtr mMedusaLogitsPtrs;
    TensorPtr mMedusaInputLogitsPtrs;
    TensorPtr mTokensPerStep;
    TensorPtr mBestPaths;

    TensorPtr mCurandStates;

    std::vector<SizeType> mAcceptedLen;
    std::vector<SizeType> mOutputLen;
    std::vector<tk::FinishedState> mAcceptedFinished;
    std::vector<SizeType> mAcceptedPathIdx;
    std::vector<SizeType> mLastTargetIdx;
    std::vector<std::vector<SizeType>> mRefAcceptedTokens;
    std::vector<bool> mFinishedByIdsPaths;

    SizeType mBatchSize;
    SizeType mMaxBatchSize;
    SizeType mBeamWidth;
    SizeType mMaxSeqLen;
    SizeType mVocabSize;
    SizeType mMaxDraftTokens;
    SizeType mMaxTotalDraftTokens;
    SizeType mMaxDraftSeqlen;
    SizeType mMaxTargetSeqlen;
    SizeType mMaxNumHeads;
    SizeType mMaxDraftSeqPerStep;
    AcceptKernelMode mAcceptMode;
    SizeType mPadId;
    static constexpr SizeType mSeeds = 64;
};

template class DecodingKernelsTest<float>;
template class DecodingKernelsTest<half>;

typedef testing::Types<float, half> FloatAndHalfTypes;

TYPED_TEST_SUITE(DecodingKernelsTest, FloatAndHalfTypes);

TYPED_TEST(DecodingKernelsTest, acceptDraftTokensByIdsKernelSmall)
{
    this->runTest(DecodingKernelTestParam()
                      .setBatchSize(1)
                      .setMaxSeqLen(16)
                      .setVocabSize(32)
                      .setMaxDraftTokens(8)
                      .setMaxDraftSeqPerStep(1)
                      .setAcceptMode(AcceptKernelMode::BY_IDS));
}

TYPED_TEST(DecodingKernelsTest, acceptDraftTokensByIdsKernelLarge)
{
    this->runTest(DecodingKernelTestParam()
                      .setBatchSize(128)
                      .setMaxSeqLen(128)
                      .setVocabSize(52000)
                      .setMaxDraftTokens(8)
                      .setMaxDraftSeqPerStep(1)
                      .setAcceptMode(AcceptKernelMode::BY_IDS));
}

TYPED_TEST(DecodingKernelsTest, acceptDraftTokensByLogitsKernelSmall)
{
    this->runTest(DecodingKernelTestParam()
                      .setBatchSize(1)
                      .setMaxSeqLen(16)
                      .setVocabSize(32)
                      .setMaxDraftTokens(8)
                      .setMaxDraftSeqPerStep(1)
                      .setAcceptMode(AcceptKernelMode::BY_LOGITS));
}

TYPED_TEST(DecodingKernelsTest, acceptDraftTokensByLogitsKernelLarge)
{
    this->runTest(DecodingKernelTestParam()
                      .setBatchSize(64)
                      .setMaxSeqLen(64)
                      .setVocabSize(4000)
                      .setMaxDraftTokens(8)
                      .setMaxDraftSeqPerStep(1)
                      .setAcceptMode(AcceptKernelMode::BY_LOGITS));
}

TYPED_TEST(DecodingKernelsTest, acceptDraftTokensByIdsWithPathsKernelSmall)
{
    this->runTest(DecodingKernelTestParam()
                      .setBatchSize(1)
                      .setMaxSeqLen(128)
                      .setVocabSize(32)
                      .setMaxDraftTokens(3)
                      .setMaxDraftSeqPerStep(4)
                      .setMaxNumHeads(2)
                      .setAcceptMode(AcceptKernelMode::BY_IDS_WITH_PATH));
}

TYPED_TEST(DecodingKernelsTest, acceptDraftTokensByIdsWithPathsKernelLarge)
{
    this->runTest(DecodingKernelTestParam()
                      .setBatchSize(128)
                      .setMaxSeqLen(1024)
                      .setVocabSize(4000)
                      .setMaxDraftTokens(8)
                      .setMaxDraftSeqPerStep(64)
                      .setMaxNumHeads(7)
                      .setAcceptMode(AcceptKernelMode::BY_IDS_WITH_PATH));
}

} // end of namespace