/*
 * SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: Apache-2.0
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "tensorrt_llm/batch_manager/runtimeBuffers.h"

#include "tensorrt_llm/batch_manager/encoderBuffers.h"
#include "tensorrt_llm/batch_manager/kvCacheManager.h"
#include "tensorrt_llm/batch_manager/loraBuffers.h"
#include "tensorrt_llm/batch_manager/medusaBuffers.h"
#include "tensorrt_llm/batch_manager/promptTuningBuffers.h"
#include "tensorrt_llm/batch_manager/rnnStateBuffers.h"
#include "tensorrt_llm/batch_manager/rnnStateManager.h"
#include "tensorrt_llm/batch_manager/transformerBuffers.h"
#include "tensorrt_llm/common/assert.h"
#include "tensorrt_llm/common/logger.h"
#include "tensorrt_llm/common/nvtxUtils.h"
#include "tensorrt_llm/common/stlUtils.h"
#include "tensorrt_llm/runtime/bufferManager.h"
#include "tensorrt_llm/runtime/common.h"
#include "tensorrt_llm/runtime/decoderState.h"
#include "tensorrt_llm/runtime/iBuffer.h"
#include "tensorrt_llm/runtime/iTensor.h"
#include "tensorrt_llm/runtime/runtimeKernels.h"
#include "tensorrt_llm/runtime/tllmRuntime.h"

#include <algorithm>
#include <iterator>
#include <memory>
#include <numeric>
#include <vector>

using namespace tensorrt_llm::runtime;

namespace tensorrt_llm::batch_manager
{

RuntimeBuffers::RuntimeBuffers(SizeType32 maxBatchSize, SizeType32 maxBeamWidth,
    std::vector<SizeType32> const& maxAttentionWindowVec, SizeType32 maxAttentionWindow, SizeType32 sinkTokenLen,
    TllmRuntime const& runtime, ModelConfig const& modelConfig, WorldConfig const& worldConfig,
    executor::DecodingConfig const& decodingConfig, bool gatherGenerationLogits, std::optional<SizeType32> maxNumTokens,
    std::optional<std::vector<executor::AdditionalModelOutput>> const& additionalModelOutputs,
    bool promptTableOffloadingParam)
{
    TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);

    promptTableOffloading = promptTableOffloadingParam;

    create(maxBatchSize, maxBeamWidth, maxAttentionWindowVec, maxAttentionWindow, sinkTokenLen, runtime, modelConfig,
        worldConfig, decodingConfig, gatherGenerationLogits, additionalModelOutputs);

    // pre-allocate
    setMaxBufferSizes(maxBatchSize, maxBeamWidth, modelConfig, maxNumTokens);
    reshape(runtime, modelConfig, worldConfig, gatherGenerationLogits);

    TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}

RuntimeBuffers::~RuntimeBuffers() = default;

void RuntimeBuffers::create(SizeType32 maxBatchSize, SizeType32 maxBeamWidth,
    std::vector<SizeType32> const& maxAttentionWindowVec, SizeType32 maxAttentionWindow, SizeType32 sinkTokenLen,
    TllmRuntime const& runtime, ModelConfig const& modelConfig, WorldConfig const& worldConfig,
    executor::DecodingConfig const& decodingConfig, bool gatherGenerationLogits,
    std::optional<std::vector<executor::AdditionalModelOutput>> const& additionalModelOutputs)
{
    TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);

    auto const& manager = runtime.getBufferManager();
    auto const& engine = runtime.getEngine();

    if (modelConfig.isTransformerBased())
    {
        transformerBuffers = std::make_unique<TransformerBuffers>(maxBatchSize, maxBeamWidth, maxAttentionWindowVec,
            maxAttentionWindow, sinkTokenLen, runtime, modelConfig, worldConfig);
    }
    if (modelConfig.isRnnBased())
    {
        rnnStateBuffers = std::make_unique<RnnStateBuffers>(maxBatchSize, runtime);
    }

    auto constexpr nvTokenIdType = TRTDataType<TokenIdType>::value;
    inputsIds = manager.emptyTensor(MemoryType::kGPU, nvTokenIdType);

    mropeRotaryCosSin = manager.emptyTensor(MemoryType::kGPU, nvinfer1::DataType::kFLOAT);
    mropePositionDeltas = manager.emptyTensor(MemoryType::kGPU, nvinfer1::DataType::kINT32);

    if (worldConfig.isLastPipelineParallelRank())
    {
        auto const logitsType = engine.getTensorDataType(batch_manager::RuntimeBuffers::kLogitsTensorName);
        logits = manager.emptyTensor(MemoryType::kGPU, logitsType);
    }

    // TODO: check which tensors can be allocated as pinned for max size
    requestTypes = manager.emptyTensor(MemoryType::kCPU, TRTDataType<runtime::RequestType>::value);

    contextLengthsHost = manager.emptyTensor(MemoryType::kCPU, nvinfer1::DataType::kINT32);
    contextLengthsDevice = manager.emptyTensor(MemoryType::kGPU, nvinfer1::DataType::kINT32);
    sequenceLengthsHost = manager.emptyTensor(MemoryType::kCPU, nvinfer1::DataType::kINT32);
    sequenceLengthsDevice = manager.emptyTensor(MemoryType::kGPU, nvinfer1::DataType::kINT32);

    lastTokenIdsHost = manager.emptyTensor(MemoryType::kCPU, nvinfer1::DataType::kINT32);
    lastTokenIdsDevice = manager.emptyTensor(MemoryType::kGPU, nvinfer1::DataType::kINT32);
    logitsIdsHost = manager.emptyTensor(MemoryType::kCPU, nvinfer1::DataType::kINT32);

    inputsIds = manager.emptyTensor(MemoryType::kGPU, nvinfer1::DataType::kINT32);

    if (worldConfig.isPipelineParallel())
    {
        hiddenStates = manager.emptyTensor(MemoryType::kGPU, modelConfig.getDataType());
    }

    auto const maxBatchSizeShape = ITensor::makeShape({maxBatchSize});
    seqSlots = tensorrt_llm::runtime::BufferManager::pinnedPool(maxBatchSizeShape, nvinfer1::DataType::kINT32);
    seqSlotsDevice = manager.gpu(maxBatchSizeShape, nvinfer1::DataType::kINT32);

    cacheIndirDecoderIOBatchedCopySrcOffsets
        = tensorrt_llm::runtime::BufferManager::pinnedPool(maxBatchSizeShape, nvinfer1::DataType::kINT64);
    cacheIndirDecoderIOBatchedCopyDstOffsets
        = tensorrt_llm::runtime::BufferManager::pinnedPool(maxBatchSizeShape, nvinfer1::DataType::kINT64);
    cacheIndirDecoderIOBatchedCopySizes
        = tensorrt_llm::runtime::BufferManager::pinnedPool(maxBatchSizeShape, nvinfer1::DataType::kINT64);
    mCacheIndirDecoderIOBatchedCopySrcOffsetsSliceDevice = manager.gpu(maxBatchSizeShape, nvinfer1::DataType::kINT64);
    mCacheIndirDecoderIOBatchedCopyDstOffsetsSliceDevice = manager.gpu(maxBatchSizeShape, nvinfer1::DataType::kINT64);
    mCacheIndirDecoderIOBatchedCopyCopySizesDevice = manager.gpu(maxBatchSizeShape, nvinfer1::DataType::kINT64);
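    // The six buffers above hold per-request source/destination offsets and copy sizes (pinned host arrays plus
    // their device-side slices) used to describe the batched copies of the beam-search cache indirection.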

    // Pre-allocate buffer for saving generation logits for model w/o draft tokens
    if (gatherGenerationLogits
        && (modelConfig.getSpeculativeDecodingMode().isDraftTokensExternal()
            || modelConfig.getSpeculativeDecodingMode().isNone())
        && worldConfig.isLastPipelineParallelRank())
    {
        auto const vocabSizePadded = modelConfig.getVocabSizePadded(worldConfig.getSize());
        auto const logitsType = engine.getTensorDataType(batch_manager::RuntimeBuffers::kLogitsTensorName);
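        // The cache below holds GenerationLogitsCache::kCACHE_LENGTH steps of generation logits; the int64
        // fragment pointer buffers store per-step addresses into that cache so the logits can be gathered back
        // in batched copies.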

        generationLogitsCache.transposedLogits = manager.gpu(
            ITensor::makeShape({maxBeamWidth, GenerationLogitsCache::kCACHE_LENGTH, vocabSizePadded}), logitsType);
        generationLogitsCache.logits = manager.gpu(
            ITensor::makeShape({GenerationLogitsCache::kCACHE_LENGTH, maxBatchSize * maxBeamWidth, vocabSizePadded}),
            logitsType);

        generationLogitsCache.fragmentPointerDevice
            = manager.gpu(ITensor::makeShape({GenerationLogitsCache::kCACHE_LENGTH}), nvinfer1::DataType::kINT64);
        generationLogitsCache.fragmentPointerHost = tensorrt_llm::runtime::BufferManager::pinnedPool(
            ITensor::makeShape({maxBatchSize, GenerationLogitsCache::kCACHE_LENGTH}), nvinfer1::DataType::kINT64);
    }

    if (modelConfig.useCrossAttention())
    {
        encoderBuffers = std::make_unique<EncoderBuffers>();
        encoderBuffers->create(maxBatchSize, modelConfig, runtime);
    }

    if (modelConfig.usePromptTuning())
    {
        promptTuningBuffers = std::make_unique<PromptTuningBuffers>(
            maxBatchSize, manager, modelConfig, worldConfig, promptTableOffloading);
    }

    if (modelConfig.useLoraPlugin())
    {
        loraBuffers = std::make_unique<LoraBuffers>(maxBatchSize, maxBeamWidth, runtime, modelConfig, worldConfig);
    }

    if (modelConfig.getSpeculativeDecodingMode().isMedusa())
    {
        mMedusaBuffers = std::make_unique<MedusaBuffers>(
            maxBatchSize, maxBeamWidth, manager, modelConfig, worldConfig, decodingConfig, runtime);
    }
    else if (modelConfig.getSpeculativeDecodingMode().isLookaheadDecoding())
    {
        mLookaheadBuffers = std::make_unique<runtime::LookaheadRuntimeBuffers>(
            maxBatchSize, maxBeamWidth, manager, modelConfig, worldConfig, decodingConfig, runtime);
    }
    else if (modelConfig.getSpeculativeDecodingMode().isExplicitDraftTokens())
    {
        mExplicitDraftTokensBuffers = std::make_unique<runtime::ExplicitDraftTokensBuffers>(
            maxBatchSize, maxBeamWidth, manager, modelConfig, worldConfig);
    }
    else if (modelConfig.getSpeculativeDecodingMode().isEagle())
    {
        mEagleBuffers = std::make_unique<runtime::EagleBuffers>(
            maxBatchSize, maxBeamWidth, manager, modelConfig, worldConfig, decodingConfig);
    }

    if (modelConfig.useLanguageAdapter())
    {
        languageAdapterRoutings = manager.emptyTensor(MemoryType::kGPU, TRTDataType<SizeType32>::value);
    }

    for (auto const& output : additionalModelOutputs.value_or(std::vector<executor::AdditionalModelOutput>{}))
    {
        auto const& engine = runtime.getEngine();
        auto const dataType = engine.getTensorDataType(output.name.c_str());
        mAdditionalOutputTensors.emplace(output.name, manager.emptyTensor(runtime::MemoryType::kGPU, dataType));
    }

    TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}

void RuntimeBuffers::setMaxBufferSizes(SizeType32 maxBatchSize, SizeType32 maxBeamWidth,
    runtime::ModelConfig const& modelConfig, std::optional<SizeType32> maxNumRuntimeTokens)
{
    TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);

    // `maxNumSequences` is reached when all requests are in generation
    numContextRequests = 0;
    numGenRequests = maxBatchSize;
    numGenSequences = maxBatchSize * maxBeamWidth;

    auto const maxDraftTokens = modelConfig.getMaxDecodingDraftTokens();
    // Draft-Tokens and Beam-Search are mutually exclusive
    numLogits = maxBatchSize * std::max(1 + maxDraftTokens, maxBeamWidth);
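    // Each draft token adds one logit on top of the sampled token, e.g. maxBatchSize = 8, maxBeamWidth = 1 and
    // maxDraftTokens = 4 gives numLogits = 8 * max(5, 1) = 40 (illustrative values only).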
    auto const maxNumModelTokens = modelConfig.getMaxNumTokens();
    auto const maxNumContextTokens = maxBatchSize * modelConfig.getMaxInputLen();
    auto const maxNumGenTokens = numLogits;
    // For pre-allocation
    numContextTokens = 0; // Set in `setBufferSizes` rather than here for `computeContextLogits`
    numGenTokens
        = maxNumRuntimeTokens.value_or(maxNumModelTokens.value_or(std::max(maxNumContextTokens, maxNumGenTokens)));
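    // Precedence for pre-allocation: an explicit runtime max-num-tokens wins, otherwise the model's max-num-tokens,
    // otherwise the larger of the worst-case context and generation token counts.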

    if (modelConfig.useCrossAttention())
    {
        encoderBuffers->setMaxBufferSizes(maxBatchSize, modelConfig);
    }

    TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}

void RuntimeBuffers::setBufferSizes(RequestVector const& contextRequests, RequestVector const& genRequests)
{
    TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
    NVTX3_SCOPED_RANGE(runtimeBuffersSetBufferSizes);

    // set context sizes
    numContextRequests = static_cast<SizeType32>(contextRequests.size());
    auto numContextLogits = numContextRequests;
    numContextTokens = 0;
    maxContextLength = 0;
    for (auto const& llmReq : contextRequests)
    {
        auto const draftLength = llmReq->isLastContextChunk() ? llmReq->getNumDraftTokens() : 0;
        numContextLogits += draftLength;

        auto const contextChunkSize = llmReq->getContextChunkSize();
        numContextTokens += contextChunkSize + draftLength;
        if (maxContextLength < llmReq->mPromptLen)
        {
            maxContextLength = llmReq->mPromptLen;
        }
    }

    // set generation sizes
    numGenRequests = static_cast<SizeType32>(genRequests.size());
    numGenSequences = 0;
    numGenTokens = 0;
    for (auto const& llmReq : genRequests)
    {
        auto const reqBeamWidth = llmReq->getBeamWidthByIter();
        numGenSequences += reqBeamWidth;
        auto const draftLen = llmReq->getNumDraftTokens();
        numGenTokens += draftLen + reqBeamWidth;
    }

    numLogits = numContextLogits + numGenTokens;
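    // numContextLogits counts one logit per context request plus its last-chunk draft tokens; numGenTokens counts
    // one token per generation beam plus its draft tokens, each of which produces a logit.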

    if (encoderBuffers)
    {
        encoderBuffers->setBufferSizes(contextRequests, genRequests);
    }

    TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}

void RuntimeBuffers::reshape(TllmRuntime const& runtime, ModelConfig const& modelConfig, WorldConfig const& worldConfig,
    bool gatherGenerationLogits)
{
    TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
    NVTX3_SCOPED_RANGE(runtimeBuffersReshape);

    if (worldConfig.isLastPipelineParallelRank())
    {
        auto const vocabSizePadded = modelConfig.getVocabSizePadded(worldConfig.getSize());

        if (modelConfig.computeContextLogits() && (numContextRequests > 0))
        {
            // Only when context logits must be returned and there are new requests executing the context phase does
            // the logits buffer need to be re-allocated to [numContextTokens + numGenSequences, vocabSizePadded].
            auto const& engine = runtime.getEngine();
            auto const& manager = runtime.getBufferManager();
            auto const logitsType = engine.getTensorDataType(kLogitsTensorName);
            logits = manager.gpu(ITensor::makeShape({numContextTokens + numGenSequences, vocabSizePadded}), logitsType);
        }
        else if (gatherGenerationLogits && modelConfig.getSpeculativeDecodingMode().isNone())
        {
            // If generation logits must be returned, re-point the logits buffer to avoid overwriting it, so that the
            // logits of GenerationLogitsCache::kCACHE_LENGTH steps can be written back together.
            // logits shape: [1, maxBatchSize * maxBeamWidth, vocabSizePadded],
            // which is large enough to cover both numContextRequests and numGenSequences.
            logits = ITensor::slice(generationLogitsCache.logits, generationLogitsCache.offset, 1);
            generationLogitsCache.offset = (generationLogitsCache.offset + 1) % GenerationLogitsCache::kCACHE_LENGTH;
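            // The offset advances modulo kCACHE_LENGTH, so the cache acts as a ring buffer over the most recent
            // generation steps.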
            logits->squeeze(0);
        }
        else
        {
            logits->reshape(ITensor::makeShape({numLogits, vocabSizePadded}));
        }
    }

    auto const numSequences = getNumSequences();
    auto const numSequencesShape = ITensor::makeShape({numSequences});
    requestTypes->reshape(numSequencesShape);
    contextLengthsHost->reshape(numSequencesShape);
    contextLengthsDevice->reshape(numSequencesShape);
    sequenceLengthsHost->reshape(numSequencesShape);
    sequenceLengthsDevice->reshape(numSequencesShape);

    auto const numLogitsShape = ITensor::makeShape({numLogits});
    lastTokenIdsHost->reshape(numLogitsShape);
    lastTokenIdsDevice->reshape(numLogitsShape);
    logitsIdsHost->reshape(numLogitsShape);

    if (transformerBuffers)
    {
        transformerBuffers->reshape(numSequences, numContextTokens + numGenTokens);
    }

    if (rnnStateBuffers)
    {
        rnnStateBuffers->reshape(numSequences);
    }

    if (modelConfig.useCrossAttention())
    {
        encoderBuffers->reshape();
    }

    if (modelConfig.useLoraPlugin())
    {
        loraBuffers->reshape(numSequences);
    }

    if (mMedusaBuffers)
    {
        mMedusaBuffers->reshape(
            numContextRequests, numGenRequests, modelConfig.getSpeculativeDecodingModulePtr()->getMaxDecodingTokens());
    }

    if (mLookaheadBuffers && modelConfig.getSpeculativeDecodingMode().isLookaheadDecoding())
    {
        mLookaheadBuffers->reshape(
            numContextRequests, numGenRequests, modelConfig.getSpeculativeDecodingModulePtr()->getMaxDecodingTokens());
    }

    if (mExplicitDraftTokensBuffers)
    {
        mExplicitDraftTokensBuffers->reshape(numContextRequests, numGenRequests, modelConfig);
    }

    if (mEagleBuffers)
    {
        mEagleBuffers->reshape(numContextRequests, numGenRequests, modelConfig);
    }

    auto const numRequests = getNumRequests();
    auto const numRequestsShape = ITensor::makeShape({numRequests});
    seqSlots->reshape(numRequestsShape);
    seqSlotsDevice->reshape(numRequestsShape);

    auto const numTokens = getNumTokens();
    inputsIds->reshape(ITensor::makeShape({numTokens}));

    if (modelConfig.useMrope())
    {
        auto const mropeRotaryCosSinSize = modelConfig.getMaxPositionEmbeddings() * modelConfig.getRotaryEmbeddingDim();
        mropeRotaryCosSin->reshape(ITensor::makeShape({numSequences, mropeRotaryCosSinSize}));
        mropePositionDeltas->reshape(ITensor::makeShape({numSequences, 1}));
    }

    if (worldConfig.isPipelineParallel())
    {
        auto const hiddenSize = (!modelConfig.getPpReduceScatter() || worldConfig.isFirstPipelineParallelRank())
            ? modelConfig.getHiddenSize() * worldConfig.getTensorParallelism()
            : modelConfig.getHiddenSize();

        auto const hiddenStatesShape = ITensor::makeShape({numTokens, hiddenSize});
        hiddenStates->reshape(hiddenStatesShape);
    }

    if (modelConfig.useLanguageAdapter())
    {
        languageAdapterRoutings->reshape(ITensor::makeShape({numTokens, 1}));
    }

    for (auto const& outputTensor : mAdditionalOutputTensors)
    {
        auto const& [name, tensor] = outputTensor;
        auto const& engine = runtime.getEngine();
        auto shape = engine.getTensorShape(name.c_str());
        TLLM_CHECK_WITH_INFO(
            shape.d[0] == -1, "First dimension of additional output tensor '%s' must be dynamic", name.c_str());
        shape.d[0] = numTokens;
        tensor->reshape(shape);
    }

    TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}

void RuntimeBuffers::prepareBuffersForCudaGraph(SizeType32 maxSequenceLength)
{
    TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
    NVTX3_SCOPED_RANGE(prepareBuffersForCudaGraph);

    TLLM_CHECK(numContextRequests == 0);

    if (transformerBuffers)
    {
        // Set pastKeyValueLength for graph capturing. This way we will capture the graph with
        // maxKvCacheLengthRounded rounded up to the next kKV_CACHE_LEN_CUDA_GRAPH_ROUND_SIZE.
        // MMHA will launch an excessive number of blocks and some of them will exit early during the actual launch.
        // We can reuse the same graph for the next kKV_CACHE_LEN_CUDA_GRAPH_ROUND_SIZE iterations.

        // Make sure the size does not overflow the maximum allowed pastKvCacheLength.
        auto const pastKvCacheLength = std::min(maxSequenceLength - 1, maxKvCacheLengthRounded);

        auto* pastKeyValueLengthsPtr = bufferCast<SizeType32>(*transformerBuffers->pastKeyValueLengths);
        std::fill_n(pastKeyValueLengthsPtr, getNumSequences(), pastKvCacheLength);
    }

    TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}

void RuntimeBuffers::setFromInputs(RequestVector const& contextRequests, RequestVector const& genRequests,
    SizeType32 maxBeamWidth, SizeType32 maxAttentionWindow, runtime::decoder::DecoderState const& decoderState,
    kv_cache_manager::BaseKVCacheManager* kvCacheManagerPtr,
    kv_cache_manager::BaseKVCacheManager* crossKvCacheManagerPtr,
    rnn_state_manager::RnnStateManager* rnnStateManagerPtr, PeftTable const& peftTable,
    runtime::TllmRuntime const& runtime, runtime::ModelConfig const& modelConfig,
    runtime::WorldConfig const& worldConfig, bool trtOverlap, OptionalRef<runtime::ITensor const> newOutputTokens)
{
    TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
    NVTX3_SCOPED_RANGE(runtimeBuffersSetFromInputs);

    auto const& manager = runtime.getBufferManager();
    auto const& stream = runtime.getStream();

    // Fill requestTypes
    {
        auto* hostRequestTypes = bufferCast<runtime::RequestType>(*requestTypes);
        std::fill_n(hostRequestTypes, numContextRequests, runtime::RequestType::kCONTEXT);
        std::fill_n(hostRequestTypes + numContextRequests, numGenSequences, runtime::RequestType::kGENERATION);
    }

    SizeType32 totalInputSize = 0;
    std::vector<TokenIdType> inputHost;
    std::vector<SizeType32> positionIdsHost;
    std::vector<SizeType32> positionIdsHostRow2;
    std::vector<SizeType32> mropePositionDeltasHost;
    std::vector<SizeType32> languageAdapterRoutingsHost;

    auto* contextLengthsHostPtr = bufferCast<SizeType32>(*contextLengthsHost);
    auto* sequenceLengthsHostPtr = bufferCast<SizeType32>(*sequenceLengthsHost);
    auto* pastKeyValueLengthsPtr
        = transformerBuffers ? bufferCast<SizeType32>(*transformerBuffers->pastKeyValueLengths) : nullptr;
    SizeType32 totalNumLogits{0};
    auto* logitsIdsHostPtr = bufferCast<SizeType32>(*logitsIdsHost);
    bool const isChatGlm = modelConfig.getModelVariant() == ModelConfig::ModelVariant::kChatGlm;
    bool const isGlm = modelConfig.getModelVariant() == ModelConfig::ModelVariant::kGlm;
    auto const mropeRotaryCosSinSize = modelConfig.getMaxPositionEmbeddings() * modelConfig.getRotaryEmbeddingDim();

    {
        NVTX3_SCOPED_RANGE(seqSlotsLoop);
        auto* seqSlotIndices = bufferCast<SizeType32>(*seqSlots);

        SizeType32 batchIdx{0};
        for (auto const& requests : {contextRequests, genRequests})
        {
            for (auto const& llmReq : requests)
            {
                // Get position of the current sequence in the decoder
                auto const seqSlot = llmReq->mSeqSlot.value();
                seqSlotIndices[batchIdx] = seqSlot;
                ++batchIdx;
            }
        }

        TLLM_CHECK(seqSlots->getSize() == static_cast<std::size_t>(batchIdx));
        manager.copy(*seqSlots, *seqSlotsDevice);
    }

    // context preparation loop
    if (!contextRequests.empty())
    {
        NVTX3_SCOPED_RANGE(contextPrepareLoop);
        numContextLogits.resize(contextRequests.size());

        SizeType32 batchIdx{0};
        for (auto const& llmReq : contextRequests)
        {
            TLLM_CHECK_WITH_INFO(llmReq->isContextInitState() || llmReq->isDisaggGenerationTransmissionComplete(),
                "The request should be in context phase or disaggregated generation transmissionComplete phase.");
            TLLM_CHECK_WITH_INFO(
                llmReq->getMaxNumGeneratedTokens() == 0, "Context request should not have generated tokens.");

            auto const& reqTokens = llmReq->getTokens(0);
            auto const& draftTokens = llmReq->getDraftTokens();
            auto const draftLength = llmReq->getNumDraftTokens();
            auto const& positionIds = llmReq->getPositionIds();

            auto const contextChunkSize = llmReq->getContextChunkSize();
            auto const beginCompute = llmReq->getContextCurrentPosition();
            auto const endCompute = beginCompute + contextChunkSize;
            inputHost.insert(inputHost.end(), reqTokens.begin() + beginCompute, reqTokens.begin() + endCompute);

            logitsIdsHostPtr[totalNumLogits++] = contextChunkSize;
            numContextLogits.at(batchIdx) = modelConfig.computeContextLogits() ? contextChunkSize : 1;

            if (llmReq->isLastContextChunk())
            {
                inputHost.insert(inputHost.end(), draftTokens->begin(), draftTokens->end());
                std::fill_n(logitsIdsHostPtr + totalNumLogits, draftLength, 1);
                totalNumLogits += draftLength;
            }
            auto const inputLength = contextChunkSize + (llmReq->isLastContextChunk() ? draftLength : 0);
            contextLengthsHostPtr[batchIdx] = inputLength;
            auto const sequenceLen = inputLength + llmReq->getContextCurrentPosition();
            sequenceLengthsHostPtr[batchIdx] = sequenceLen;

            if (static_cast<bool>(pastKeyValueLengthsPtr))
            {
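                // With chunked context, beginCompute tokens from earlier chunks are already in the KV cache, so the
                // past length covers those plus the tokens computed in this chunk.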
                pastKeyValueLengthsPtr[batchIdx] = beginCompute + inputLength;
            }

            if (positionIds.has_value())
            {
                TLLM_CHECK_WITH_INFO(!(isChatGlm || isGlm), "ChatGLM-6B and Glm only use the default initialization");
                positionIdsHost.insert(positionIdsHost.end(), positionIds.value()->begin() + beginCompute,
                    positionIds.value()->begin() + endCompute);
            }
            else
            {
                if (isChatGlm)
                {
                    // Specialize for ChatGLM-6B with 2D-Position-Embedding
                    positionIdsHost.resize(totalInputSize + inputLength);
                    std::iota(std::begin(positionIdsHost) + totalInputSize, std::end(positionIdsHost), 0);
                    positionIdsHost.back() = positionIdsHost.back() - 1;

                    positionIdsHostRow2.resize(totalInputSize + inputLength);
                    positionIdsHostRow2.back() = 1;
                }
                else if (isGlm)
                {
                    // Specialize for GLM-10B with 2D-Position-Embedding and special value of the mask id position
                    auto start = inputHost.begin() + totalInputSize;
                    auto end = start + inputLength;
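                    // Token ids 50260, 50263 and 50264 are the hard-coded mask-style special tokens of GLM-10B;
                    // the first one found marks the mask position for this request.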
                    auto it = std::find_if(
                        start, end, [](SizeType32 id) { return id == 50260 || id == 50263 || id == 50264; });
                    llmReq->mMaskPosition = (it != end) ? std::distance(start, it) : maxContextLength;

                    positionIdsHost.resize(totalInputSize + inputLength);
                    std::iota(std::begin(positionIdsHost) + totalInputSize, std::end(positionIdsHost), 0);
                    positionIdsHost.back() = llmReq->mMaskPosition;

                    positionIdsHostRow2.resize(totalInputSize + inputLength);
                    positionIdsHostRow2.back() = 1;
                }
                else
                {
                    // Other models
                    positionIdsHost.resize(totalInputSize + inputLength);
                    std::iota(std::begin(positionIdsHost) + totalInputSize,
                        std::begin(positionIdsHost) + totalInputSize + inputLength, beginCompute);
                }
            }
            if (modelConfig.useMrope())
            {
                auto optMropeRotaryCosSin = llmReq->getMropeRotaryCosSin().value();
                TLLM_CHECK_WITH_INFO(optMropeRotaryCosSin->getShape().d[0] == mropeRotaryCosSinSize,
                    "Provided MropeRotarySinCos is %ld and expected is %d.\n", optMropeRotaryCosSin->getShape().d[0],
                    int(mropeRotaryCosSinSize));

                auto const mropeRotaryCosSinCtx = ITensor::slice(mropeRotaryCosSin, batchIdx, 1);
                manager.copy(*optMropeRotaryCosSin, *mropeRotaryCosSinCtx);
            }

            if (modelConfig.useLanguageAdapter())
            {
                auto const languageAdapterRouting = llmReq->getLanguageAdapterRouting(
                    modelConfig.getNumLanguages().value(), endCompute - beginCompute);
                languageAdapterRoutingsHost.insert(languageAdapterRoutingsHost.end(),
                    std::begin(languageAdapterRouting), std::end(languageAdapterRouting));
            }
            totalInputSize += inputLength;
            ++batchIdx;
        }

        if (rnnStateBuffers)
        {
            rnnStateBuffers->fillSlotMappings(contextRequests, rnnStateManagerPtr);
        }
    }

    // generation preparation loop
    if (!genRequests.empty())
    {
        NVTX3_SCOPED_RANGE(genPrepareLoop);

        auto const numContextRequests = static_cast<SizeType32>(contextRequests.size());
        auto numSequences = numContextRequests;
        for (auto const& llmReq : genRequests)
        {
            auto const reqBeamWidth = llmReq->getBeamWidthByIter();
            auto const draftLength = llmReq->getNumDraftTokens();
            auto const& draftTokens = llmReq->getDraftTokens();
            auto const numLogits = draftLength + reqBeamWidth;
            TLLM_CHECK(draftLength == 0 || reqBeamWidth == 1);

            auto const promptLen = llmReq->mPromptLen;
            auto const sequenceLen
                = promptLen + llmReq->getMaxNumGeneratedTokens() + static_cast<SizeType32>(trtOverlap);
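            // With TRT overlap, the token decoded by the overlapped step has not been appended to the request yet,
            // so the lengths are advanced by one here.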
            auto const& positionIds = llmReq->getPositionIds();
            for (int beam = 0; beam < reqBeamWidth; ++beam)
            {
                auto const numTokens = llmReq->getNumTokens(beam) + static_cast<SizeType32>(trtOverlap);
                // TODO: can this be removed completely?
                if (!trtOverlap)
                {
                    auto const lastToken = llmReq->getLastTokens(beam);
                    inputHost.push_back(lastToken);
                    if (draftLength > 0)
                    {
                        inputHost.insert(inputHost.end(), draftTokens->begin(), draftTokens->end());
                    }
                }

                // If the model updates generation position ids, do not append them here.
                if (!modelConfig.getSpeculativeDecodingMode().updatesPositionIds())
                {
                    if (positionIds.has_value())
                    {
                        TLLM_CHECK_WITH_INFO(
                            !(isChatGlm || isGlm), "ChatGLM-6B and Glm only use the default initialization");
                        auto last_context_position_id = positionIds.value()->back();
                        positionIdsHost.push_back(
                            static_cast<SizeType32>(last_context_position_id + sequenceLen - promptLen));
                    }
                    else
                    {
                        if (isChatGlm) // ChatGLM-6B
                        {
                            positionIdsHost.push_back(static_cast<SizeType32>(promptLen - 2));
                            positionIdsHostRow2.push_back(static_cast<SizeType32>(sequenceLen - promptLen + 1));
                        }
                        else if (isGlm)
                        {
                            positionIdsHost.push_back(llmReq->mMaskPosition);
                            positionIdsHostRow2.push_back(static_cast<SizeType32>(sequenceLen - promptLen + 1));
                        }
                        else // GPT / ChatGLM2-6B / ChatGLM3-6B / BART
                        {
                            // The position id is just the token count minus 1.
                            positionIdsHost.push_back(numTokens - 1);
                        }
                    }
                }

                if (modelConfig.useMrope())
                {
                    auto optMropePositionDeltas = llmReq->getMropePositionDeltas().value();
                    mropePositionDeltasHost.push_back(optMropePositionDeltas);
                }

                if (modelConfig.useLanguageAdapter())
                {
                    // Generation requests only have one token per sequence
                    auto const languageAdapterRouting
                        = llmReq->getLanguageAdapterRouting(modelConfig.getNumLanguages().value(), 1);
                    languageAdapterRoutingsHost.insert(languageAdapterRoutingsHost.end(),
                        std::begin(languageAdapterRouting), std::end(languageAdapterRouting));
                }
            }

            if (static_cast<bool>(pastKeyValueLengthsPtr))
            {
                SizeType32 pastKeyValueLength = sequenceLen - 1;
                std::fill_n(pastKeyValueLengthsPtr + numSequences, reqBeamWidth, pastKeyValueLength);
            }
            totalInputSize += numLogits;

            std::fill_n(logitsIdsHostPtr + totalNumLogits, numLogits, 1);

            totalNumLogits += numLogits;

            if (rnnStateBuffers)
            {
                auto const seqSlot = llmReq->mSeqSlot.value();
                auto& rnnStateManager = *rnnStateManagerPtr;
                rnnStateManager.fillSlotMapping(*rnnStateBuffers->slotMappingHost, numSequences, seqSlot, reqBeamWidth);
            }
            numSequences += reqBeamWidth;
        }

        if (transformerBuffers && maxBeamWidth > 1)
        {
            transformerBuffers->copyCacheIndirection(genRequests, decoderState.getCacheIndirectionOutput(), stream);
        }

        numSequences = numContextRequests;
        for (auto const& llmReq : genRequests)
        {
            auto const reqBeamWidth = llmReq->getBeamWidthByIter();
            auto const draftLength = llmReq->getNumDraftTokens();

            auto const contextQLength = llmReq->mPromptLen + draftLength;
            auto const sequenceLen
                = contextQLength + llmReq->getMaxNumGeneratedTokens() + static_cast<SizeType32>(trtOverlap);

            std::fill_n(contextLengthsHostPtr + numSequences, reqBeamWidth, contextQLength);
            std::fill_n(sequenceLengthsHostPtr + numSequences, reqBeamWidth, sequenceLen);
            numSequences += reqBeamWidth;
        }
        if (modelConfig.getSpeculativeDecodingMode().isLookaheadDecoding())
        {
            // copy from lookahead decoding buffer
            mLookaheadBuffers->setFromInputs(numContextRequests, numGenRequests, *requestTypes, *seqSlots,
                decoderState.getLookaheadBuffers(), runtime, modelConfig, worldConfig);
        }
    }

    // check skipCrossAttnBlocks
    if (transformerBuffers && modelConfig.skipCrossAttnBlocks())
    {
        bool isSkipCrossAttn = true;
        for (auto const& requests : {contextRequests, genRequests})
        {
            for (auto const& llmReq : requests)
            {
                bool tmpValue = false;
                if (llmReq->getSkipCrossAttnBlocks() != nullptr)
                {
                    manager.copy(*llmReq->getSkipCrossAttnBlocks(), &tmpValue);
                }
                isSkipCrossAttn &= tmpValue;
            }
        }
        transformerBuffers->copySkipCrossAttnBlocks(isSkipCrossAttn, runtime);
    }

    if (isChatGlm || isGlm)
    {
        positionIdsHost.reserve(totalInputSize * 2);
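        // The 2D position ids of ChatGLM-6B / GLM-10B are flattened row by row: the second row is appended after the
        // first so that copyPositionIds receives both rows in one contiguous host vector.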
        positionIdsHost.insert(positionIdsHost.end(), positionIdsHostRow2.begin(), positionIdsHostRow2.end());
    }

    if (modelConfig.useCrossAttention())
    {
        encoderBuffers->fill(contextRequests, genRequests, manager);
    }
    if (modelConfig.usePromptTuning())
    {
        promptTuningBuffers->fill(contextRequests, genRequests, manager, modelConfig.usePackedInput());
    }
    if (modelConfig.useLoraPlugin())
    {
        loraBuffers->fill(contextRequests, genRequests, peftTable, manager, modelConfig, worldConfig);
    }
    if (modelConfig.useMrope())
    {
        if (!mropePositionDeltasHost.empty())
        {
            auto mropePositionDeltasGen = ITensor::slice(mropePositionDeltas, 0, numGenSequences);
            manager.copy(mropePositionDeltasHost.data(), *mropePositionDeltasGen);
        }
    }

    {
        NVTX3_SCOPED_RANGE(bufferCopies);
        if (trtOverlap)
        {
            auto contextInputsIds = ITensor::slice(inputsIds, 0, numContextTokens);
            manager.copy(inputHost.data(), *contextInputsIds);

            if (!genRequests.empty())
            {
                auto generationInputsIds = ITensor::slice(inputsIds, numContextTokens);
                auto seqSlotsDeviceSlice = ITensor::slice(seqSlotsDevice, numContextRequests);
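                // Gather this step's freshly decoded tokens directly from the decoder output on the device, indexed
                // by sequence slot, instead of staging them through inputHost.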
                runtime::kernels::invokeGatherBatch(
                    *generationInputsIds, *newOutputTokens, *seqSlotsDeviceSlice, maxBeamWidth, stream);
            }
        }
        else
        {
            manager.copy(inputHost.data(), *inputsIds);
        }
        // In the generation phase, the device pointer of context lengths needs to be tiled.
        manager.copy(*contextLengthsHost, *contextLengthsDevice);
        manager.copy(*sequenceLengthsHost, *sequenceLengthsDevice);
        auto const logitsIdsHostRange = BufferRange<SizeType32>(*logitsIdsHost);
        auto lastTokenIdsHostRange = BufferRange<SizeType32>(*lastTokenIdsHost);
        common::stl_utils::inclusiveScan(
            logitsIdsHostRange.begin(), logitsIdsHostRange.end(), lastTokenIdsHostRange.begin());
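        // The inclusive scan turns per-logit token counts into cumulative last-token indices,
        // e.g. logitsIds [5, 1, 1, 1] -> lastTokenIds [5, 6, 7, 8] (illustrative values only).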
        manager.copy(*lastTokenIdsHost, *lastTokenIdsDevice);
        if (transformerBuffers)
        {
            TensorPtr decoderPositionIds = modelConfig.getSpeculativeDecodingMode().isLookaheadDecoding()
                ? mLookaheadBuffers->positionIdsDevice
                : nullptr;
            transformerBuffers->copyPositionIds(runtime, positionIdsHost, isChatGlm || isGlm, decoderPositionIds);
        }
        if (rnnStateBuffers)
        {
            rnnStateBuffers->copySlotMappingH2D(runtime);
        }
        if (modelConfig.useLanguageAdapter())
        {
            manager.copy(languageAdapterRoutingsHost.data(), *languageAdapterRoutings);
        }
    }

    if (transformerBuffers && static_cast<bool>(kvCacheManagerPtr))
    {
        transformerBuffers->copyKvBlockOffsets(
            contextRequests, genRequests, kvCacheManagerPtr, crossKvCacheManagerPtr, manager);
    }

    if (modelConfig.useCrossAttention())
    {
        transformerBuffers->copyCrossAttentionMasks(contextRequests, genRequests, contextLengthsDevice,
            encoderBuffers->inputLengths, maxContextLength, encoderBuffers->getMaxInputLengthInBatch(), runtime);
    }

    maxKvCacheLengthRounded = 0;
    if (static_cast<bool>(pastKeyValueLengthsPtr))
    {
        auto const maxKvCacheLength
            = *std::max_element(pastKeyValueLengthsPtr, pastKeyValueLengthsPtr + getNumSequences());
        // Round up kv cache length
        maxKvCacheLengthRounded = common::ceilDiv(maxKvCacheLength, kKV_CACHE_LEN_CUDA_GRAPH_ROUND_SIZE)
            * kKV_CACHE_LEN_CUDA_GRAPH_ROUND_SIZE;
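        // Rounding up lets prepareBuffersForCudaGraph() reuse a single captured CUDA graph for up to
        // kKV_CACHE_LEN_CUDA_GRAPH_ROUND_SIZE subsequent iterations.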
    }

    if (modelConfig.getSpeculativeDecodingMode().needsDecoderPrologue())
    {
        if (modelConfig.getSpeculativeDecodingMode().isExplicitDraftTokens())
        {
            prepareExplicitDraftTokenBuffers(
                decoderState.getExplicitDraftTokensBuffers(), runtime, modelConfig, worldConfig);
        }
        if (modelConfig.getSpeculativeDecodingMode().isEagle())
        {
            prepareEagleBuffers(
                contextRequests, genRequests, decoderState.getEagleBuffers(), runtime, modelConfig, worldConfig);
        }
    }

    sync_check_cuda_error(stream.get());

    TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}

void RuntimeBuffers::prepareExplicitDraftTokenBuffers(
    runtime::ExplicitDraftTokensBuffers::Inputs const& explicitDraftTokensBuffers, TllmRuntime const& runtime,
    ModelConfig const& modelConfig, WorldConfig const& worldConfig)
{
    TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);

    TLLM_CHECK(mExplicitDraftTokensBuffers);

    mExplicitDraftTokensBuffers->setFromInputs(numContextRequests, numGenRequests, *requestTypes, *seqSlots,
        explicitDraftTokensBuffers, *transformerBuffers->positionIds, modelConfig, worldConfig,
        runtime.getBufferManager(), runtime.getStream());

    TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}

void RuntimeBuffers::prepareEagleBuffers(RequestVector const& contextRequests, RequestVector const& genRequests,
    runtime::EagleBuffers::Inputs const& eagleBuffers, TllmRuntime const& runtime, ModelConfig const& modelConfig,
    WorldConfig const& worldConfig)
{
    TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);

    TLLM_CHECK(mEagleBuffers);

    mEagleBuffers->setFromInputs(contextRequests, genRequests, *requestTypes, *seqSlots, eagleBuffers,
        runtime.getBufferManager(), modelConfig, worldConfig);

    TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}

std::tuple<SizeType32, RuntimeBuffers::TensorMap const&, RuntimeBuffers::TensorMap&> RuntimeBuffers::prepareStep(
    RequestVector const& contextRequests, RequestVector const& genRequests, SizeType32 maxBeamWidth,
    SizeType32 maxAttentionWindow, runtime::decoder::DecoderState const& decoderState,
    kv_cache_manager::BaseKVCacheManager* kvCacheManager, kv_cache_manager::BaseKVCacheManager* crossKvCacheManager,
    rnn_state_manager::RnnStateManager* rnnStateManager, PeftTable const& peftTable, TllmRuntime const& runtime,
    ModelConfig const& modelConfig, WorldConfig const& worldConfig, bool gatherGenerationLogits, bool trtOverlap,
    OptionalRef<runtime::ITensor const> newOutputTokens)
{
    TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
    NVTX3_SCOPED_RANGE(runtimeBuffersPrepareStep);

    setBufferSizes(contextRequests, genRequests);
    reshape(runtime, modelConfig, worldConfig, gatherGenerationLogits);

    setFromInputs(contextRequests, genRequests, maxBeamWidth, maxAttentionWindow, decoderState, kvCacheManager,
        crossKvCacheManager, rnnStateManager, peftTable, runtime, modelConfig, worldConfig, trtOverlap,
        newOutputTokens);

    fillIOMaps(modelConfig, worldConfig);

    auto const numTokens = getNumTokens();
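    // Pick the TensorRT optimization profile whose supported token range covers this step's total token count.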
    auto const optProfileId = runtime.getOptProfileId(numTokens, ModelConfig::getOptProfilesSplitPoints());
    setContextIndex(optProfileId);
    TLLM_LOG_DEBUG("numTokens: %d, optProfileId: %d", numTokens, optProfileId);

    TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
    return {optProfileId, inputMap, outputMap};
}

void RuntimeBuffers::fillIOMaps(ModelConfig const& modelConfig, WorldConfig const& worldConfig)
{
    TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
    NVTX3_SCOPED_RANGE(runtimeBuffersFillIOMaps);

    inputMap.clear();
    outputMap.clear();
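    // Re-populate the tensor-name -> buffer bindings handed to the TensorRT engine for this step.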

    if (transformerBuffers)
    {
        transformerBuffers->getBuffers(inputMap, outputMap, modelConfig);
    }
    if (rnnStateBuffers)
    {
        rnnStateBuffers->getBuffers(inputMap);
    }

    if (worldConfig.isLastPipelineParallelRank())
    {
        // feed a view to TensorRT runtime so reshaping does not change logits buffer
        outputMap.insert_or_assign(kLogitsTensorName, ITensor::view(logits));
    }
    else
    {
        outputMap.insert_or_assign(kHiddenStatesOutputTensorName, hiddenStates);
    }

    if (worldConfig.isFirstPipelineParallelRank())
    {
        inputMap.insert_or_assign(kInputIdsTensorName, inputsIds);
    }
    else
    {
        inputMap.insert_or_assign(kHiddenStatesInputTensorName, hiddenStates);
    }

    inputMap.insert_or_assign(kLastTokenIdsTensorName, lastTokenIdsDevice);

    inputMap.insert_or_assign(kHostRequestTypesTensorName, requestTypes);
    // In the generation phase, we still pass context lengths.
    inputMap.insert_or_assign(kContextLengthsTensorName, contextLengthsDevice);
    inputMap.insert_or_assign(kHostContextLengthsTensorName, contextLengthsHost);
    inputMap.insert_or_assign(kSequenceLengthsTensorName, sequenceLengthsDevice);

    if (modelConfig.useCrossAttention())
    {
        encoderBuffers->insertInputTensors(inputMap);
    }
    if (modelConfig.usePromptTuning())
    {
        auto const& promptTuningParams = promptTuningBuffers->mPromptTuningParams;
        inputMap.insert_or_assign(kPromptEmbeddingTableTensorName, promptTuningParams.embeddingTable);
        inputMap.insert_or_assign(kTasksTensorName, promptTuningParams.tasks);
        inputMap.insert_or_assign(kPromptVocabSizeTensorName, promptTuningParams.vocabSize);
    }
    if (modelConfig.useMrope())
    {
        inputMap.insert_or_assign(kMRopeRotaryCosSinTensorName, mropeRotaryCosSin);
        inputMap.insert_or_assign(kMRopePositionDeltasTensorName, mropePositionDeltas);
    }
    if (modelConfig.useLoraPlugin())
    {
        loraBuffers->insertInputTensors(inputMap, loraBuffers->mLoraWeightsPointersHost,
            loraBuffers->mLoraAdapterSizesHost, modelConfig, worldConfig);
    }
    if (modelConfig.useLanguageAdapter())
    {
        inputMap.insert_or_assign("language_adapter_routings", languageAdapterRoutings);
    }

    if (mMedusaBuffers)
    {
        mMedusaBuffers->insertInputTensors(inputMap, outputMap, worldConfig);
    }
    if (mLookaheadBuffers)
    {
        mLookaheadBuffers->insertInputTensors(inputMap, outputMap, worldConfig);
    }
    if (mExplicitDraftTokensBuffers)
    {
        mExplicitDraftTokensBuffers->insertInputTensors(inputMap, outputMap, worldConfig);
    }
    if (mEagleBuffers)
    {
        mEagleBuffers->insertInputTensors(inputMap, outputMap, worldConfig);
    }

    for (auto const& outputTensor : mAdditionalOutputTensors)
    {
        outputMap.insert_or_assign(outputTensor.first, outputTensor.second);
    }

    TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}

} // namespace tensorrt_llm::batch_manager