/*
 * Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "tensorrt_llm/layers/lookaheadAlgorithm.h"
#include "tensorrt_llm/common/assert.h"
#include "tensorrt_llm/common/logger.h"
#include "tensorrt_llm/executor/executor.h"
#include "tensorrt_llm/layers/decodingParams.h"
#include "tensorrt_llm/layers/lookaheadDecodingUtils.h"
#include "tensorrt_llm/runtime/common.h"
#include "tensorrt_llm/runtime/iTensor.h"
#include "tensorrt_llm/runtime/lookaheadModule.h"
// Standard headers for the facilities this file uses directly.
#include <algorithm>
#include <cstdint>
#include <cstdlib>
#include <iterator>
#include <list>
#include <memory>
#include <tuple>
#include <unordered_map>
#include <utility>
#include <vector>

namespace tensorrt_llm::layers
{

using namespace tensorrt_llm::runtime;

//! Allocates the CPU buffers sized for the largest supported configuration.
//! \param maxW upper bound for the window size w accepted by setup().
//! \param maxN upper bound for the n-gram size n accepted by setup().
//! \param maxG upper bound for the guess-set size g accepted by setup().
//! \param id identifier stored in mId.
LookaheadAlgorithm::LookaheadAlgorithm(
    runtime::SizeType32 maxW, runtime::SizeType32 maxN, runtime::SizeType32 maxG, runtime::SizeType32 id)
    : mMaxW(maxW)
    , mMaxN(maxN)
    , mMaxG(maxG)
    , mFilling(0)
    , mPoolManager(maxG)
    , mId(id)
    , mGoldenTokensMax(
          runtime::BufferManager::cpu(runtime::ITensor::makeShape({maxN * 2 - 1}), nvinfer1::DataType::kINT32))
    , mPrefillsMax(runtime::BufferManager::cpu(
          runtime::ITensor::makeShape({(maxN <= 1 ? 0 : maxN - 2)}), nvinfer1::DataType::kINT32))
    , mKeyTokensMax(runtime::BufferManager::cpu(runtime::ITensor::makeShape({maxW}), nvinfer1::DataType::kINT32))
    , mPastTokensMax(
          runtime::BufferManager::cpu(runtime::ITensor::makeShape({maxW * (maxN - 1)}), nvinfer1::DataType::kINT32))
    , mGuessTokensMax(
          runtime::BufferManager::cpu(runtime::ITensor::makeShape({maxG * (maxN - 1)}), nvinfer1::DataType::kINT32))
{
    // The remaining buffers are sized from the speculative-resource numbers of the maximum configuration.
    runtime::SizeType32 maxGeneratedLen{0};
    runtime::SizeType32 maxDraftLen{0};
    std::tie(maxGeneratedLen, std::ignore, maxDraftLen, std::ignore)
        = executor::LookaheadDecodingConfig(maxW, maxN, maxG).calculateSpeculativeResource();
    mAttentionMask = runtime::BufferManager::cpu(
        runtime::ITensor::makeShape({maxDraftLen, maxDraftLen}), nvinfer1::DataType::kBOOL);
    mDraftTokensMax
        = runtime::BufferManager::cpu(runtime::ITensor::makeShape({maxDraftLen}), nvinfer1::DataType::kINT32);
    mSampledTokensMax
        = runtime::BufferManager::cpu(runtime::ITensor::makeShape({maxGeneratedLen}), nvinfer1::DataType::kINT32);
    mEncodeMapMax
        = runtime::BufferManager::cpu(runtime::ITensor::makeShape({maxDraftLen}), nvinfer1::DataType::kINT32);
}

//! Configures the algorithm for one request and seeds its state from the prompt.
//! \param prompt prompt tokens; used to feed the pool, to draw random filler tokens and to seed the golden window.
//! \param w window size, must be <= the maxW given at construction.
//! \param n n-gram size, must be <= the maxN given at construction.
//! \param g guess-set size, must be <= the maxG given at construction.
//! \param seed seed passed to srand() before the random filler tokens are drawn.
void LookaheadAlgorithm::setup(TensorConstPtr const& prompt, SizeType32 w, SizeType32 n, SizeType32 g, uint64_t seed)
{
    TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
    TLLM_CHECK_WITH_INFO(w <= mMaxW, "lookahead requires setup w (%d) <= max_w (%d)", w, mMaxW);
    TLLM_CHECK_WITH_INFO(n <= mMaxN, "lookahead requires setup n (%d) <= max_n (%d)", n, mMaxN);
    TLLM_CHECK_WITH_INFO(g <= mMaxG, "lookahead requires setup g (%d) <= max_g (%d)", g, mMaxG);
    mW = w;
    mN = n;
    mG = g;
    std::tie(std::ignore, std::ignore, mRuntimeMaxDraftLen, mRuntimeMaxDraftPathLen)
        = executor::LookaheadDecodingConfig(mW, mN, mG).calculateSpeculativeResource();

    mPoolManager.setup(mG);
    mPoolManager.accept(prompt, mN);

    // Runtime views are slices of the maximum-size buffers allocated in the constructor.
    mGoldenTokens = ITensor::slice(mGoldenTokensMax, 0, mN * 2 - 1);
    mPrefills = ITensor::slice(mPrefillsMax, 0, mN <= 1 ? 0 : mN - 2);
    mKeyTokens = ITensor::slice(mKeyTokensMax, 0, mW);
    mPastTokens = ITensor::slice(mPastTokensMax, 0, mW * (mN - 1));
    mPastTokens->reshape(ITensor::makeShape({mW, mN - 1}));

    BufferRange<TokenIdType const> promptRange(*prompt);
    BufferRange<TokenIdType> prefillRange(*mPrefills);
    BufferRange<TokenIdType> pastRange(*mPastTokens);
    BufferRange<TokenIdType> goldRange(*mGoldenTokens);

    // srand() takes an unsigned int; the narrowing from uint64_t is made explicit.
    srand(static_cast<unsigned int>(seed));
    auto randToken = [&promptRange](auto& item)
    {
        // rand() % 0 is undefined behaviour, so an empty prompt is rejected before drawing.
        TLLM_CHECK_WITH_INFO(
            promptRange.size() > 0, "lookahead requires a non-empty prompt to draw random filler tokens");
        item = promptRange[rand() % promptRange.size()];
    };

    // Order of the random draws: all prefill slots first, then the first column of each window row.
    std::for_each(prefillRange.begin(), prefillRange.end(), randToken);
    std::for_each(pastRange.begin(), pastRange.end(), [](auto& a) { a = -1; });
    for (SizeType32 i = 0; i < mW; i++)
    {
        if (mN - 1 > 0)
        {
            randToken(pastRange[i * (mN - 1)]);
        }
    }

    // Seed the golden window with the last (mN - 1) prompt tokens. When the prompt is shorter than that,
    // stepping std::prev() past the start of the buffer would be undefined behaviour, so only the available
    // tail is copied (right-aligned) and the missing leading slots get -1, the same "unset" value written
    // into the past-token matrix above.
    auto const tailLen = std::max<SizeType32>(mN - 1, 0);
    auto const copyLen = std::min<SizeType32>(tailLen, static_cast<SizeType32>(promptRange.size()));
    auto const padLen = tailLen - copyLen;
    std::fill(goldRange.begin(), goldRange.begin() + padLen, -1);
    std::copy(std::prev(promptRange.end(), copyLen), promptRange.end(), goldRange.begin() + padLen);

    mGuessTokens = ITensor::slice(mGuessTokensMax, 0, 0);
    mFilling = (mN - 1) > 0 ? 1 : 0;

    TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}

//! Appends newly generated tokens to the golden window, feeds the pool and slides the window.
//! \param generatedTokens at most mN tokens accepted in the current step.
void LookaheadAlgorithm::accept(TensorConstPtr const& generatedTokens)
{
    TLLM_CHECK(ITensor::volume(generatedTokens->getShape()) <= mN);
    BufferRange<TokenIdType const> generatedRange(*generatedTokens);
    BufferRange<TokenIdType> goldRange(*mGoldenTokens);
    auto genLen = generatedTokens->getShape().d[0];
    TLLM_CHECK(genLen <= mN);

    // The first (mN - 1) golden slots hold the previous tail; the new tokens go right after them.
    std::copy(generatedRange.begin(), generatedRange.end(), goldRange.begin() + mN - 1);
    TensorPtr newGold = ITensor::slice(mGoldenTokens, 0, mN - 1 + genLen);
    mPoolManager.accept(newGold, mN);

    // Keep only the last (mN - 1) tokens as the tail for the next call.
    std::copy(goldRange.begin() + genLen, goldRange.begin() + genLen + mN - 1, goldRange.begin());
}

//! lookahead has two phases: prefilling the past tokens matrix and maintaining the past tokens matrix.
runtime::SizeType32 LookaheadAlgorithm::lookahead( TensorPtr const& draftTokens, TensorPtr const& positionIds, runtime::SizeType32 startPosId) { TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__); SizeType32 prefill = mN - 2 - mFilling; SizeType32 len = prefill + mFilling * mW; TLLM_CHECK(len <= ITensor::volume(draftTokens->getShape())); TLLM_CHECK(len <= ITensor::volume(positionIds->getShape())); BufferRange prefillRange(*mPrefills); BufferRange pastRange(*mPastTokens); BufferRange draftRange(*draftTokens); PRINT_TOKENS(mPrefills); if (mFilling < mN - 1) { // prefilling std::copy(prefillRange.begin() + mFilling, prefillRange.end(), draftRange.begin()); for (SizeType32 i = 0; i < mW; i++) { auto start = pastRange.begin() + i * (mN - 1); auto end = pastRange.begin() + i * (mN - 1) + mFilling; std::copy(start, end, draftRange.begin() + prefill + i * mFilling); } } else { // shift up std::copy(pastRange.begin() + 1, pastRange.begin() + mFilling * mW, draftRange.begin()); } BufferRange positionIdsRange(*positionIds); SizeType32 idx = 0, wj = 0; auto fillPosition = [&positionIdsRange, &idx](SizeType32 start, SizeType32 len) { for (SizeType32 i = start; i < start + len; i++) { positionIdsRange[idx++] = i; } }; if (prefill >= 0) { fillPosition(startPosId, prefill); for (wj = 0; wj < mW; wj++) { fillPosition(startPosId + prefill + wj, mFilling); } } else { fillPosition(startPosId, mFilling - 1); for (wj = 1; wj < mW; wj++) { fillPosition(startPosId - 1 + wj, mFilling); } } PRINT_VALUES(positionIds); TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__); return len; } runtime::SizeType32 LookaheadAlgorithm::guess(TensorPtr const& guessTokens, TensorPtr const& guessIds, runtime::SizeType32 startPosId, runtime::TokenIdType lastToken) { auto guesses = mPoolManager.guess(lastToken, mW); SizeType32 len = 0; std::for_each(guesses.begin(), guesses.end(), [&len](auto& a) { len += ITensor::volume(a->getShape()); }); TLLM_CHECK(len <= ITensor::volume(guessTokens->getShape())); TLLM_CHECK(len 
<= ITensor::volume(guessIds->getShape())); BufferRange guessTokensRange(*guessTokens); BufferRange guessIdsRange(*guessIds); SizeType32 cur = 0; for (auto guess : guesses) { BufferRange guessRange(*guess); std::copy(guessRange.begin(), guessRange.end(), guessTokensRange.begin() + cur); SizeType32 tmp = startPosId; std::for_each( guessIdsRange.begin() + cur, guessIdsRange.begin() + cur + mN - 1, [&tmp](auto& v) { v = tmp++; }); cur += ITensor::volume(guess->getShape()); } return len; } void LookaheadAlgorithm::posIdsToMask(TensorPtr const& mask, TensorConstPtr const& posIds) { auto len = ITensor::volume(posIds->getShape()); TLLM_CHECK(mask->getDimension<0>() >= len); TLLM_CHECK(mask->getDimension<1>() >= len); auto posIdsRange = BufferRange(*posIds); auto maskLocation = BufferLocation(*mask); for (auto& item : maskLocation) { item = false; } if (len > 0) { std::vector> stack; for (auto i = 0; i < len; i++) { auto cur = posIdsRange[i]; while (stack.size() > 0 && cur <= stack.back().second) { stack.pop_back(); } TLLM_CHECK(stack.size() > 0 ? 
cur == stack.back().second + 1 : true); stack.push_back(std::make_pair(i, cur)); for (auto prev : stack) { maskLocation.at(i, prev.first) = true; } } } } struct TreeValue; using TreeMap = std::unordered_map; struct TreeValue { TreeValue() : nexts(std::make_shared()) { } using Nexts = std::shared_ptr; Nexts nexts{nullptr}; std::list sources; }; using TreeNode = TreeMap::value_type; template void treeDFS(TreeNode& node, BF const& visitBefore, AF const& visitAfter) { visitBefore(node); for (auto& next : *(node.second.nexts)) { treeDFS(next, visitBefore, visitAfter); } visitAfter(node); } SizeType32 LookaheadAlgorithm::treeEncode( TensorPtr const& tokens, TensorPtr const& posIds, TensorPtr const& mask, TensorPtr const& encodeMap) { TLLM_CHECK(ITensor::volume(tokens->getShape()) == ITensor::volume(posIds->getShape())); auto len = ITensor::volume(tokens->getShape()); BufferRange tokensRange(*tokens); BufferRange posIdsRange(*posIds); BufferLocation maskLocation(*mask); BufferRange mapRange(*encodeMap); auto branches = std::make_shared(); for (auto i = 0; i < len; i++) { auto nexts = branches; for (auto j = 0; j <= i; j++) { if (maskLocation.at(i, j)) { auto pos = posIdsRange[j]; auto tok = tokensRange[j]; auto found = nexts->find(tok); if (found != nexts->end()) { found->second.sources.push_back(j); nexts = found->second.nexts; } else { auto [inserted, ok] = nexts->insert({tok, TreeValue()}); inserted->second.sources.push_back(j); nexts = inserted->second.nexts; } } } } for (auto& item : maskLocation) { item = 0; } std::vector> stack; SizeType32 offset = 0; SizeType32 posId = posIdsRange.size() ? 
posIdsRange[0] : 0; auto visitBefore = [&stack, &maskLocation, &tokensRange, &posIdsRange, &posId, &offset, &mapRange](TreeNode const& node) { stack.push_back(std::make_pair(offset, node.first)); for (auto const& source : node.second.sources) { mapRange[source] = offset; } for (auto const& prev : stack) { maskLocation.at(offset, prev.first) = true; } tokensRange[offset] = node.first; posIdsRange[offset] = posId; offset++; posId++; }; auto visitAfter = [&stack, &posId](TreeNode const& node) { stack.pop_back(); posId--; }; for (auto& next : *branches) { treeDFS(next, visitBefore, visitAfter); } for (SizeType32 i = offset; i < len; i++) { tokensRange[i] = 0; posIdsRange[i] = 0; } for (SizeType32 i = 0; i < len; i++) { for (SizeType32 j = i < offset ? offset : 0; j < len; j++) { maskLocation.at(i, j) = false; } } return offset; } void LookaheadAlgorithm::prepare(TensorPtr const& draftTokens, TensorPtr const& positionIds, TensorPtr const& draftLengthPtr, TensorPtr const& attentionMask, SizeType32 attentionMaskOffset, TensorConstPtr const& lastPositionIdPtr, TensorConstPtr const& lastTokenPtr) { TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__); if (mRuntimeMaxDraftLen == 0) { mDraftTokens = ITensor::slice(mDraftTokensMax, 0, 0); mEncodeMap = ITensor::slice(mEncodeMapMax, 0, 0); (BufferRange(*draftLengthPtr))[0] = 0; return; } auto lastToken = BufferRange(*lastTokenPtr)[0]; auto offset = BufferRange(*lastPositionIdPtr)[0]; SizeType32 inputLen = ITensor::volume(draftTokens->getShape()); TLLM_CHECK(inputLen >= mRuntimeMaxDraftLen); BufferRange draftRange(*draftTokens); BufferRange positionRange(*positionIds); SizeType32 filledLen = 0; filledLen += lookahead(ITensor::slice(draftTokens, filledLen, mRuntimeMaxDraftLen - filledLen), ITensor::slice(positionIds, filledLen, mRuntimeMaxDraftLen - filledLen), offset); auto guessStart = filledLen; filledLen += guess(ITensor::slice(draftTokens, filledLen, mRuntimeMaxDraftLen - filledLen), ITensor::slice(positionIds, filledLen, 
mRuntimeMaxDraftLen - filledLen), offset, lastToken); auto guessEnd = filledLen; std::copy(draftRange.begin() + guessStart, draftRange.begin() + guessEnd, BufferRange(*mGuessTokensMax).begin()); mGuessTokens = ITensor::slice(mGuessTokensMax, 0, guessEnd - guessStart); posIdsToMask(mAttentionMask, ITensor::slice(positionIds, 0, filledLen)); auto draftLen = treeEncode(ITensor::slice(draftTokens, 0, filledLen), ITensor::slice(positionIds, 0, filledLen), mAttentionMask, mEncodeMapMax); for (SizeType32 i = 0; i < draftLen; i++) { BufferRange srcRange(*ITensor::at(mAttentionMask, {i})); BufferRange dstRange(*ITensor::slice(attentionMask, {i + attentionMaskOffset, attentionMaskOffset})); std::copy(srcRange.begin(), srcRange.end(), dstRange.begin()); } std::copy(draftRange.begin(), draftRange.begin() + draftLen, BufferRange(*mDraftTokensMax).begin()); mDraftTokens = ITensor::slice(mDraftTokensMax, 0, draftLen); (BufferRange(*draftLengthPtr))[0] = draftLen; mEncodeMap = ITensor::slice(mEncodeMapMax, 0, filledLen); TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__); } void LookaheadAlgorithm::verify(TensorPtr const& accepted, TensorPtr const& acceptedOffsets, TensorPtr const& acceptedLength, TokenIdType newLastToken, TensorConstPtr const& goldenTokens, TensorConstPtr const& endToken) { TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__); TLLM_CHECK(ITensor::volume(goldenTokens->getShape()) == ITensor::volume(mDraftTokens->getShape())); BufferRange goldRange(*goldenTokens); BufferRange draftRange(*mDraftTokens); BufferLocation maskLocation(*mAttentionMask); auto draftSize = ITensor::volume(mDraftTokens->getShape()); auto end = *BufferRange(*endToken).begin(); SizeType32 maxHit = 0, hitIdx = 0; for (SizeType32 i = 0; i < draftSize; i++) { SizeType32 hit = 0; TokenIdType cur = newLastToken; for (SizeType32 j = 0; j < draftSize; j++) { if (maskLocation.at(i, j)) { if (draftRange[j] == cur && draftRange[j] != end) { hit++; cur = goldRange[j]; } else { break; } } } if (hit > maxHit) { 
maxHit = hit; hitIdx = i; } } maxHit = maxHit > mRuntimeMaxDraftPathLen ? mRuntimeMaxDraftPathLen : maxHit; SizeType32 acceptedIdx = 0; BufferRange acceptedRange(*accepted); BufferRange acceptedOffsetsRange(*acceptedOffsets); acceptedRange[acceptedIdx] = newLastToken; for (SizeType32 j = 0; j < draftSize; j++) { if (maskLocation.at(hitIdx, j) && acceptedIdx < maxHit) { acceptedOffsetsRange[acceptedIdx++] = j; acceptedRange[acceptedIdx] = goldRange[j]; } } *BufferRange(*acceptedLength).begin() = maxHit + 1; TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__); } //! lookahead Jacobi matrix has prefilling phase and maintenance phase. //! W=5, N=5. //! *prefilling phase* //! mFilling = 1->2, Tokens initialized from prompt. To fill the second line. //! 0>1 2 3 * //! 4 * //! 5 * //! 6 * //! 7 * //! mFilling = 2->3. //! 0 1>2 3 4 * //! 4 5 * //! 5 6 * //! 6 7 * //! 7 8 * //! mFilling = 3->4. //! 0 1 2>3 4 5 * //! 4 5 6 * //! 5 6 7 * //! 6 7 9 * //! 7 8 a * //! *maintenance phase* //! mFilling = 4->4. shift up and generate five n-grams. //! 0 1 2 3>4 5 6 * //! 4 5 6 7 * //! 5 6 7 8 * //! 6 7 8 9 * //! 7 8 9 a * //! mFilling = 4. //! 0 1 2 3 4>5 6 7 * //! 5 6 7 8 * //! 6 7 8 9 * //! 7 8 9 a * //! 8 9 a b * //! mFilling = 4. //! 0 1 2 3 4 5>6 7 8 * //! 6 7 8 9 * //! 7 8 9 a * //! 8 9 a b * //! 
9 a b c * void LookaheadAlgorithm::update(TensorPtr const& acceptedTokens, TensorPtr const& acceptedOffsets, TensorPtr const& acceptedLength, TensorConstPtr const& sampledTokens, TensorConstPtr const& endToken) { TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__); TLLM_CHECK(ITensor::volume(acceptedTokens->getShape()) >= mN); BufferRange zippedTokensRange(*sampledTokens); BufferRange sampledRange(*mSampledTokensMax); BufferRange mapRange(*mEncodeMap); BufferRange unzipRange(*mSampledTokensMax); mSampledTokens = ITensor::slice(mSampledTokensMax, 0, mEncodeMap->getShape().d[0] + 1); unzipRange[0] = zippedTokensRange[0]; for (SizeType32 i = 0; i < mapRange.size(); i++) { unzipRange[i + 1] = zippedTokensRange[mapRange[i] + 1]; } BufferRange keyRange(*mKeyTokens); BufferRange pastRange(*mPastTokens); auto newLastToken = sampledRange[0]; SizeType32 prefill = mN - 2 - mFilling; for (SizeType32 i = 0; i < mW; i++) { keyRange[i] = sampledRange[prefill + i * mFilling + mFilling]; } if (mFilling < mN - 1) { for (SizeType32 i = 0; i < mW; i++) { pastRange[i * (mN - 1) + mFilling] = keyRange[i]; } } else if (mN > 1) { for (SizeType32 i = 0; i < mW; i++) { auto begin = pastRange.begin() + i * (mN - 1); auto end = pastRange.begin() + i * (mN - 1) + mN - 1; auto key = *begin; std::copy(begin + 1, end, begin); *(std::prev(end, 1)) = keyRange[i]; keyRange[i] = key; } keyRange[0] = newLastToken; mPoolManager.update(mKeyTokens, mPastTokens); } auto guessSize = ITensor::volume(mGuessTokens->getShape()); auto outputSize = ITensor::volume(mSampledTokens->getShape()); auto lookSize = 1 + (mN > 1 ? 
mN - 2 : 0) - mFilling + mFilling * mW; TLLM_CHECK(guessSize + lookSize == outputSize); TensorConstPtr goldenTokens = ITensor::slice(mSampledTokens, lookSize, guessSize); auto& acptLen = *BufferRange(*acceptedLength).begin(); verify(acceptedTokens, acceptedOffsets, acceptedLength, newLastToken, ITensor::slice(sampledTokens, 1), endToken); accept(ITensor::slice(acceptedTokens, 0, *BufferRange(*acceptedLength).begin())); if (mFilling < mN - 1) { mFilling++; } TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__); } } // namespace tensorrt_llm::layers