//
// Created by martinma on 5/24/23.
//
/*
 * Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "tensorrt_llm/runtime/gptSession.h"
|
|
|
|
#include "iBuffer.h"
|
|
#include "tensorrt_llm/batch_manager/kvCacheManager.h"
|
|
#include "tensorrt_llm/kernels/decodingKernels.h"
|
|
#include "tensorrt_llm/runtime/gptDecoderBatch.h"
|
|
#include "tensorrt_llm/runtime/ipcUtils.h"
|
|
#include "tensorrt_llm/runtime/ncclCommunicator.h"
|
|
#include "tensorrt_llm/runtime/runtimeBuffers.h"
|
|
#include "tensorrt_llm/runtime/runtimeKernels.h"
|
|
#include "tensorrt_llm/runtime/statefulGptDecoder.h"
|
|
#include "tensorrt_llm/runtime/tllmLogger.h"
|
|
#include "tensorrt_llm/runtime/tllmRuntime.h"
|
|
#include "tensorrt_llm/runtime/utils/sessionUtils.h"
|
|
|
|
#include <algorithm>
|
|
#include <cstdint>
|
|
#include <fstream>
|
|
#include <memory>
|
|
|
|
using namespace tensorrt_llm::runtime;
|
|
|
|
namespace tc = tensorrt_llm::common;
|
|
namespace bmkv = tensorrt_llm::batch_manager::kv_cache_manager;
|
|
|
|
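// The session owns the TRT-LLM runtime state for one rank: it deserializes the TRT engine,
// initializes the CUDA device for this rank, and, for pipeline-parallel runs, creates the NCCL
// pipeline communicator together with a dedicated stream for pipeline communication.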
GptSession::GptSession(GptModelConfig const& modelConfig, WorldConfig const& worldConfig, void const* engineBuffer,
    std::size_t engineSize, LoggerPtr logger)
    : mModelConfig{modelConfig}
    , mWorldConfig{worldConfig}
    , mDevice{utils::initDevice(worldConfig)}
    , mLogger{logger ? std::move(logger) : std::make_shared<TllmLogger>()}
    , mRuntime{std::make_shared<TllmRuntime>(engineBuffer, engineSize, *mLogger)}
    , mNumMicroBatches{worldConfig.getPipelineParallelism()}
    , mDecoders{}
    , mBuffers{}
    , mCudaGraphInstances{}
{
    if (mWorldConfig.isPipelineParallel())
    {
        mPipelineComm = NcclCommunicator::createPipelineComm(mWorldConfig, *mLogger);
        mCommStream = std::make_shared<CudaStream>();
    }

    // TODO compare expected and runtime tensor names?
}

nvinfer1::ILogger& GptSession::getLogger() const
{
    return *mLogger;
}

BufferManager& GptSession::getBufferManager() const
{
    return mRuntime->getBufferManager();
}

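// Creates the TRT execution contexts. With two optimization profiles, the generation-phase
// contexts (profile 1) occupy indices [0, numContextsPerPhase) and the context-phase contexts
// (profile 0) follow them; with a single profile, all contexts use profile 0. The generate*()
// functions rely on this ordering when they pick a context id for the context phase.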
void GptSession::createContexts(SizeType numMicroBatches)
{
    TLLM_LOG_DEBUG("%s start", __PRETTY_FUNCTION__);
    mRuntime->clearContexts();
    // Instantiate multiple contexts for flip-flopping
    auto const numContextsPerPhase = std::max(2, numMicroBatches);
    auto const numProfiles = mRuntime->getNbProfiles();
    TLLM_CHECK_WITH_INFO(
        numProfiles == 1 || numProfiles == 2, "GPT only expects one or two optimization profiles");
    auto constexpr ctxContextId = 0;
    auto constexpr genContextId = 1;
    if (numProfiles == 2)
    {
        for (auto i = 0; i < numContextsPerPhase; ++i)
            mRuntime->addContext(genContextId);
    }
    for (auto i = 0; i < numContextsPerPhase; ++i)
        mRuntime->addContext(ctxContextId);

    TLLM_LOG_DEBUG("%s stop", __PRETTY_FUNCTION__);
}

void GptSession::createBuffers(SizeType numMicroBatches)
{
    TLLM_LOG_DEBUG("%s start", __PRETTY_FUNCTION__);
    mBuffers.clear();

    for (SizeType i = 0; i < numMicroBatches; ++i)
    {
        mBuffers.emplace_back(std::make_shared<RuntimeBuffers>());
        mBuffers.back()->create(*mRuntime, mModelConfig, mWorldConfig);
    }
    TLLM_LOG_DEBUG("%s stop", __PRETTY_FUNCTION__);
}

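// Creates one decoder per micro batch. With decoderPerRequest, a GptDecoderBatch is used so that
// each request is decoded by its own decoder instance; otherwise a single StatefulGptDecoder
// handles the whole batch.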
void GptSession::createDecoders(SizeType batchSize, SizeType beamWidth, SizeType maxSequenceLength,
    nvinfer1::DataType logitsType, bool decoderPerRequest, SizeType numMicroBatches)
{
    TLLM_LOG_DEBUG("%s start", __PRETTY_FUNCTION__);
    auto const vocabSize = mModelConfig.getVocabSize();
    auto const vocabSizePadded = mModelConfig.getVocabSizePadded(mWorldConfig.getSize());
    auto const& stream = mRuntime->getStreamPtr();

    mDecoders.clear();

    for (SizeType i = 0; i < numMicroBatches; ++i)
    {
        if (decoderPerRequest)
            mDecoders.emplace_back(std::make_shared<GptDecoderBatch>(vocabSize, vocabSizePadded, stream));
        else
            mDecoders.emplace_back(std::make_shared<StatefulGptDecoder>(vocabSize, vocabSizePadded, stream));
        mDecoders.back()->setup(batchSize, beamWidth, maxSequenceLength, logitsType);
    }

    TLLM_LOG_DEBUG("%s stop", __PRETTY_FUNCTION__);
}

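// Creates one paged KV cache manager per micro batch. The cache is organized in blocks of
// tokensPerBlock tokens; unless maxTokensInPagedKvCache is given, enough blocks are reserved to
// hold batchSize * beamWidth sequences of maxSequenceLength tokens. The cache data type follows
// the quantization mode (FP8 or INT8 KV cache) and falls back to the model data type otherwise.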
void GptSession::createKvCacheManagers(SizeType batchSize, SizeType beamWidth, SizeType maxSequenceLength,
    SizeType numMicroBatches, std::optional<SizeType> maxTokensInPagedKvCache)
{
    TLLM_LOG_DEBUG("%s start", __PRETTY_FUNCTION__);
    auto const localNbLayers = mModelConfig.getNbLayers(mWorldConfig.getPipelineParallelism());
    auto const nbHeads = mModelConfig.getNbHeads();
    auto const nbKvHeads = mModelConfig.getNbKvHeads();
    auto const hiddenSize = mModelConfig.getHiddenSize();
    auto const tokensPerBlock = mModelConfig.getTokensPerBlock();

    auto const maxBlocksPerSeq = tc::divUp(maxSequenceLength, tokensPerBlock);
    auto const maxNumTokens
        = maxTokensInPagedKvCache.value_or(batchSize * beamWidth * maxBlocksPerSeq * tokensPerBlock);
    auto const maxNumBlocks = tc::divUp(maxNumTokens, tokensPerBlock);

    nvinfer1::DataType kvDtype;
    if (mModelConfig.getQuantMode().hasFp8KvCache())
    {
        kvDtype = nvinfer1::DataType::kFP8;
    }
    else if (mModelConfig.getQuantMode().hasInt8KvCache())
    {
        kvDtype = nvinfer1::DataType::kINT8;
    }
    else
    {
        kvDtype = mModelConfig.getDataType();
    }

    mKvCacheManagers.clear();

    for (SizeType i = 0; i < numMicroBatches; ++i)
    {
        mKvCacheManagers.emplace_back(
            std::make_shared<bmkv::KVCacheManager>(localNbLayers, nbHeads, nbKvHeads, hiddenSize, tokensPerBlock,
                maxNumBlocks, batchSize, beamWidth, maxBlocksPerSeq, kvDtype, mRuntime->getStreamPtr()));
    }
    TLLM_LOG_DEBUG("%s stop", __PRETTY_FUNCTION__);
}

void GptSession::createCustomAllReduceWorkspace(
    SizeType maxBatchSize, SizeType maxBeamWidth, SizeType maxSequenceLength)
{
    setPeerAccess(mWorldConfig, true);

    auto& manager = mRuntime->getBufferManager();
    for (auto const& buffer : mBuffers)
    {
        buffer->createCustomAllReduceWorkspace(
            maxBatchSize, maxBeamWidth, maxSequenceLength, mModelConfig.getHiddenSize(), mWorldConfig, manager);
    }
}

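// Pre-allocates contexts, buffers, decoders and KV cache managers for the given maximum shapes.
// Buffers are sized per micro batch (maxBatchSize split across mNumMicroBatches), and generate*()
// later checks its inputs against these limits. Must be called before the first call to generate*().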
void GptSession::setup(SizeType maxBatchSize, SizeType maxBeamWidth, SizeType maxSequenceLength, bool decoderPerRequest,
    std::optional<SizeType> maxTokensInPagedKvCache, std::optional<SizeType> numMicroBatches)
{
    TLLM_LOG_DEBUG("%s start", __PRETTY_FUNCTION__);

    if (numMicroBatches)
        mNumMicroBatches = numMicroBatches.value();
    createContexts(mNumMicroBatches);
    createBuffers(mNumMicroBatches);

    auto const microBatchSize = tc::ceilDiv(maxBatchSize, mNumMicroBatches);
    // Store this param, which determines the decoder buffer and kv cache manager sizes, so it can be
    // checked against the input shapes given in generate().
    // gptDecoderBatch does not resize buffers, but allows smaller batchSize and beamWidth.
    // TODO refactor batch manager to remove dependency on maxSequenceLength.
    mDecoderMaxSequenceLength = maxSequenceLength;

    if (mModelConfig.usePagedKvCache())
    {
        createKvCacheManagers(
            microBatchSize, maxBeamWidth, maxSequenceLength, mNumMicroBatches, maxTokensInPagedKvCache);
    }

    if (mWorldConfig.isLastPipelineParallelRank())
    {
        auto const logitsType = mRuntime->getEngine().getTensorDataType("logits");
        createDecoders(
            microBatchSize, maxBeamWidth, maxSequenceLength, logitsType, decoderPerRequest, mNumMicroBatches);
    }

    if (mWorldConfig.isPipelineParallel())
    {
        mReceivedEvents.clear();
        for (SizeType i = 0; i < mNumMicroBatches; ++i)
            mReceivedEvents.emplace_back();
    }

    if (mWorldConfig.isTensorParallel() && mModelConfig.useCustomAllReduce())
    {
        createCustomAllReduceWorkspace(microBatchSize, maxBeamWidth, maxSequenceLength);
    }

    // we don't know maxInputLength and maxNewTokens yet and ignore those for pre-allocation
    auto const generationConfig
        = RuntimeBuffers::GenerationConfig{microBatchSize, maxBeamWidth, 0, 0, maxSequenceLength};

    for (auto& buffers : mBuffers)
        buffers->reshape(generationConfig, mModelConfig, mWorldConfig);
    TLLM_LOG_DEBUG("%s stop", __PRETTY_FUNCTION__);
}

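// Runs generation for a single batch on the micro batch 0 resources: the context phase executes
// the engine once on the full prompt, then the generation phase alternates between two engine
// contexts (replayed as CUDA graphs when enabled) and the asynchronous decoder until all sequences
// finish or maxNewTokens steps have been taken.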
void GptSession::generateSingleBatch(
    GenerationOutput& outputs, GenerationInput const& inputs, SamplingConfig const& samplingConfig)
{
    TLLM_LOG_DEBUG("%s start", __PRETTY_FUNCTION__);

    TLLM_CHECK_WITH_INFO(inputs.packed == mModelConfig.usePackedInput(),
        "The chosen model requires a packed input tensor (did you set packed?).");
    auto const& inputLengths = inputs.lengths;
    TLLM_CHECK_WITH_INFO(inputLengths->getShape().nbDims == 1, "Input lengths tensor must be one-dimensional.");

    auto constexpr microBatchId = 0;
    auto& manager = mRuntime->getBufferManager();

    // Initialize and reshape buffers
    auto& buffers = *mBuffers.at(microBatchId);
    TLLM_CHECK_WITH_INFO(buffers.allocated, "Buffers not allocated, please call setup first!");
    buffers.initContextLengths(inputLengths, manager);
    auto const generationConfig = RuntimeBuffers::GenerationConfig::fromInput(*inputs.ids, *buffers.contextLengthsHost,
        inputs.packed, samplingConfig.beamWidth, mDecoderMaxSequenceLength, inputs.maxNewTokens);

    auto const batchSize = generationConfig.batchSize;
    auto const beamWidth = generationConfig.beamWidth;
    auto const maxInputLength = generationConfig.maxInputLength;
    auto const maxNewTokens = generationConfig.maxNewTokens;

    if (mWorldConfig.isLastPipelineParallelRank() && mModelConfig.computeContextLogits())
    {
        auto const vocabSizePadded = mModelConfig.getVocabSizePadded(mWorldConfig.getSize());

        TLLM_CHECK_WITH_INFO(outputs.contextLogits,
            "outputs.contextLogits is nullptr. It must be allocated when computeContextLogits() is enabled.");
        outputs.contextLogits->reshape(ITensor::makeShape({batchSize, maxInputLength, vocabSizePadded}));
        auto const contextLogitsShape = outputs.contextLogits->getShape();
        TLLM_CHECK_WITH_INFO(contextLogitsShape.d[0] == batchSize, "Invalid dim[0]");
        TLLM_CHECK_WITH_INFO(contextLogitsShape.d[1] == maxInputLength, "Invalid dim[1]");
        TLLM_CHECK_WITH_INFO(contextLogitsShape.d[2] == vocabSizePadded, "Invalid dim[2]");

        buffers.logits = outputs.contextLogits;
    }

    buffers.reshape(generationConfig, mModelConfig, mWorldConfig);
    kvCacheAddSequences(beamWidth, microBatchId);
    ITensor::SharedPtr newTokens{initNewTokens(inputs, samplingConfig, microBatchId)};

    auto& onTokenGenerated = outputs.onTokenGenerated;
    outputs.ids->reshape(ITensor::makeShape({batchSize, beamWidth, mDecoderMaxSequenceLength}));

    auto kvCacheManager = mModelConfig.usePagedKvCache() ? mKvCacheManagers.at(microBatchId).get() : nullptr;
    RuntimeBuffers::TensorMap inputBuffers[2];
    RuntimeBuffers::TensorMap outputBuffers[2];
    for (SizeType step = 0; step < maxNewTokens; ++step)
    {
        auto const contextId = step % 2;
        if (step == 0)
        {
            SizeType const contextIdForContextPhase
                = mRuntime->getNbProfiles() == 2 ? mRuntime->getNbContexts() / 2 : 0;
            buffers.prepareContextStep(
                inputs.ids, inputs.padId, manager, kvCacheManager, generationConfig, mModelConfig, mWorldConfig);
            buffers.getRuntimeBuffers(
                inputBuffers[contextId], outputBuffers[contextId], step, inputs.ids, mModelConfig, mWorldConfig);
            mRuntime->setInputTensors(contextIdForContextPhase, inputBuffers[contextId]);
            mRuntime->setOutputTensors(contextIdForContextPhase, outputBuffers[contextId]);

            if (isCudaGraphMode())
            {
                for (auto& instance : mCudaGraphInstances)
                {
                    instance.clear();
                }
            }
            TLLM_CHECK_WITH_INFO(
                mRuntime->executeContext(contextIdForContextPhase), "Executing TRT engine in context phase failed!");
        }
        else
        {
            if (isCudaGraphMode() && mCudaGraphInstances[contextId].hasInstance())
            {
                mCudaGraphInstances[contextId].launch(mRuntime->getStream());
            }
            else
            {
                TLLM_CHECK_WITH_INFO(
                    mRuntime->executeContext(contextId), "Executing TRT engine in generation phase failed!");
            }
        }

        sync_check_cuda_error();

        if (step == 0)
        {
            buffers.postContextStep(manager, generationConfig, mModelConfig, mWorldConfig);
        }

        std::swap(buffers.cacheIndirectionDecoderInput, buffers.cacheIndirectionDecoderOutput);

        if (step < maxNewTokens - 1) // this is not the last step
        { // preparing the next step
            auto const nextStep = step + 1;
            auto const nextContextId = nextStep % 2;
            auto nextInputIds = buffers.prepareNextStep(
                step, newTokens, manager, kvCacheManager, generationConfig, mModelConfig, mWorldConfig);
            buffers.getRuntimeBuffers(inputBuffers[nextContextId], outputBuffers[nextContextId], nextStep, nextInputIds,
                mModelConfig, mWorldConfig);
            mRuntime->setInputTensors(nextContextId, inputBuffers[nextContextId]);
            mRuntime->setOutputTensors(nextContextId, outputBuffers[nextContextId]);

            if (isCudaGraphMode())
            {
                mCudaGraphInstances[nextContextId].prepareNextGraph(*mRuntime, nextContextId);
            }
        }

        sync_check_cuda_error();

        // FIXME: this synchronize is important to get logits right
        // manager.getStream().synchronize();

        decoderStepAsync(outputs.ids, newTokens, maxInputLength + step, microBatchId);
        auto const shouldStop = shouldStopSync(batchSize, beamWidth, microBatchId);

        if (mWorldConfig.isFirstPipelineParallelRank())
        {
            if (onTokenGenerated)
            {
                // TODO use getNewTokens(), remove step from Callback?
                ITensor::SharedPtr outputIds
                    = mWorldConfig.isPipelineParallel() ? outputs.ids : mDecoders.at(microBatchId)->getOutputIds();
                onTokenGenerated(outputIds, step, shouldStop || step == maxNewTokens - 1);
            }
        }

        if (shouldStop)
        {
            mLogger->log(nvinfer1::ILogger::Severity::kVERBOSE, "GPT decoding finished early");
            break;
        }
    }

    if (mModelConfig.usePagedKvCache())
    {
        for (auto batchIdx = 0; batchIdx < batchSize; ++batchIdx)
        {
            kvCacheManager->removeSequence(batchIdx);
        }
    }

    finalizeOutputIds(*outputs.ids, microBatchId);
    manager.getStream().synchronize();
    TLLM_LOG_DEBUG("%s stop", __PRETTY_FUNCTION__);
}

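// Registers every input sequence of the given micro batch with its paged KV cache manager so that
// cache blocks are reserved for the context tokens and the requested beam width.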
void GptSession::kvCacheAddSequences(SizeType beamWidth, SizeType microBatchId)
{
    if (mModelConfig.usePagedKvCache())
    {
        auto& kvCacheManager = mKvCacheManagers.at(microBatchId);
        TLLM_CHECK(kvCacheManager);
        auto contextLengthsHost = mBuffers.at(microBatchId)->contextLengthsHost;
        TLLM_CHECK(contextLengthsHost);
        auto const contextLengthsPtr = bufferCast<SizeType const>(*contextLengthsHost);
        auto const contextLengthsSize = static_cast<SizeType>(contextLengthsHost->getSize());
        for (SizeType batchIdx = 0; batchIdx < contextLengthsSize; ++batchIdx)
        {
            kvCacheManager->addSequence(batchIdx, contextLengthsPtr[batchIdx], beamWidth);
        }
    }
}

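// Returns the tensor that will hold the tokens produced in each step: the last pipeline rank starts
// a new decoder batch and exposes the decoder's newTokens tensor, the first rank (when it is not
// also the last) allocates a GPU tensor to receive the tokens sent by the last rank, and
// intermediate ranks do not need one.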
ITensor::SharedPtr GptSession::initNewTokens(
    GenerationInput const& inputs, SamplingConfig const& samplingConfig, SizeType microBatchId)
{
    if (mWorldConfig.isLastPipelineParallelRank())
    {
        auto& decoder = mDecoders.at(microBatchId);
        decoder->newBatch(inputs, samplingConfig);
        return decoder->getNewTokens();
    }
    else if (mWorldConfig.isFirstPipelineParallelRank())
    {
        auto const beamWidth = samplingConfig.beamWidth;
        auto const batchSize = static_cast<SizeType>(inputs.lengths->getSize());
        return mRuntime->getBufferManager().gpu(ITensor::makeShape({batchSize, beamWidth}), nvinfer1::DataType::kINT32);
    }
    else
    {
        return ITensor::SharedPtr{};
    }
}

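// Helpers for generateMultiBatch: splitInputs slices a GenerationInput into micro batches (by
// cumulative context lengths for packed inputs, by batch index otherwise, with the optional tensors
// shared across all micro batches), and updateOutputIds writes the tokens decoded in one step into
// the output ids tensor at the given step offset.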
namespace
{
std::vector<GenerationInput> splitInputs(
    GenerationInput const& inputs, SizeType numMicroBatches, BufferManager& manager)
{
    std::vector<GenerationInput> inputBatches;
    auto const numRequests = inputs.lengths->getShape().d[0];
    auto const microBatchSize = tc::ceilDiv(numRequests, numMicroBatches);

    if (inputs.packed)
    {
        auto contextLengthsHost = manager.copyFrom(*inputs.lengths, MemoryType::kCPU);
        ITensor::SharedPtr inputIdsView = ITensor::view(inputs.ids);
        inputIdsView->squeeze(0);
        auto contextLengthsRange = BufferRange<SizeType>(*contextLengthsHost);

        auto tokensBegin = 0;
        for (auto offset = 0; offset < numRequests; offset += microBatchSize)
        {
            auto batchSize = std::min(microBatchSize, numRequests - offset);
            auto numTokens = std::accumulate(
                contextLengthsRange.begin() + offset, contextLengthsRange.begin() + offset + batchSize, 0);

            ITensor::SharedPtr batchInputs = ITensor::slice(inputIdsView, tokensBegin, numTokens);
            batchInputs->reshape(ITensor::makeShape({1, numTokens}));

            inputBatches.emplace_back(inputs.endId, inputs.padId, batchInputs,
                ITensor::slice(inputs.lengths, offset, batchSize), inputs.packed);

            tokensBegin += numTokens;
        }
    }
    else
    {
        for (auto offset = 0; offset < numRequests; offset += microBatchSize)
        {
            auto batchSize = std::min(microBatchSize, numRequests - offset);
            inputBatches.emplace_back(inputs.endId, inputs.padId, ITensor::slice(inputs.ids, offset, batchSize),
                ITensor::slice(inputs.lengths, offset, batchSize), inputs.packed);
        }
    }

    for (auto& batch : inputBatches)
    {
        if (inputs.embeddingBiasOpt)
            batch.embeddingBiasOpt = inputs.embeddingBiasOpt;
        if (inputs.badWordsList)
            batch.badWordsList = inputs.badWordsList;
        if (inputs.stopWordsList)
            batch.stopWordsList = inputs.stopWordsList;
        if (inputs.maxNewTokens)
            batch.maxNewTokens = inputs.maxNewTokens;
    }

    return inputBatches;
}

void updateOutputIds(
    ITensor::SharedPtr& outputIds, ITensor::SharedPtr& newTokens, SizeType decoderStep, CudaStream const& stream)
{ // assemble outputIds of all micro batches
    auto const& newTokensShape = newTokens->getShape();
    auto newTokensView = ITensor::view(newTokens, ITensor::makeShape({1, newTokensShape.d[0] * newTokensShape.d[1]}));
    auto const& outputIdsShape = outputIds->getShape();
    auto outputIdsView = ITensor::view(
        outputIds, ITensor::makeShape({outputIdsShape.d[0] * outputIdsShape.d[1], outputIdsShape.d[2]}));
    kernels::invokeTransposeWithOutputOffset(*outputIdsView, *newTokensView, decoderStep, stream);
}
} // namespace

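// Runs generation with micro batching: the batch is split into up to mNumMicroBatches micro batches
// that are pushed through the engine in turn, so decoding of one micro batch can overlap with the
// engine execution of the next; the stop condition of each micro batch is checked one step later.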
void GptSession::generateMultiBatch(
    GenerationOutput& outputs, GenerationInput const& inputs, SamplingConfig const& samplingConfig)
{
    TLLM_LOG_DEBUG("%s start", __PRETTY_FUNCTION__);

    TLLM_CHECK_WITH_INFO(inputs.packed == mModelConfig.usePackedInput(),
        "The chosen model requires a packed input tensor (did you set packed?).");
    auto const& inputLengths = inputs.lengths;
    TLLM_CHECK_WITH_INFO(inputLengths->getShape().nbDims == 1, "Input lengths tensor must be one-dimensional.");

    auto& manager = mRuntime->getBufferManager();

    auto const batchSize = static_cast<SizeType>(inputLengths->getSize());
    auto const beamWidth = samplingConfig.beamWidth;
    outputs.ids->reshape(ITensor::makeShape({batchSize, beamWidth, mDecoderMaxSequenceLength}));
    auto& onTokenGenerated = outputs.onTokenGenerated;

    auto const numMicroBatches = std::min(batchSize, mNumMicroBatches);
    auto microBatches = splitInputs(inputs, numMicroBatches, manager);

    std::vector<RuntimeBuffers::GenerationConfig> generationConfigs;
    std::vector<ITensor::SharedPtr> newTokensPerBatch;
    std::vector<ITensor::SharedPtr> outputIdsPerBatch;

    for (auto microBatchId = 0; microBatchId < numMicroBatches; ++microBatchId)
    {
        auto const& microBatchInputs = microBatches.at(microBatchId);

        // Initialize and reshape buffers
        auto& buffers = *mBuffers.at(microBatchId);
        TLLM_CHECK_WITH_INFO(buffers.allocated, "Buffers not allocated, please call setup first!");
        buffers.initContextLengths(microBatchInputs.lengths, manager);
        generationConfigs.emplace_back(RuntimeBuffers::GenerationConfig::fromInput(*microBatchInputs.ids,
            *buffers.contextLengthsHost, microBatchInputs.packed, samplingConfig.beamWidth, mDecoderMaxSequenceLength,
            microBatchInputs.maxNewTokens));
        auto const& generationConfig = generationConfigs.back();

        auto const beamWidth = generationConfig.beamWidth;

        buffers.reshape(generationConfig, mModelConfig, mWorldConfig);
        kvCacheAddSequences(beamWidth, microBatchId);
        newTokensPerBatch.emplace_back(initNewTokens(microBatchInputs, samplingConfig, microBatchId));
    }

    auto maxNewTokens = generationConfigs.front().maxNewTokens;
    auto microBatchSize = generationConfigs.front().batchSize;
    auto offset = 0;
    outputIdsPerBatch.emplace_back(ITensor::slice(outputs.ids, offset, microBatchSize));
    offset += microBatchSize;
    for (auto microBatchId = 1; microBatchId < numMicroBatches; ++microBatchId)
    {
        maxNewTokens = std::min(maxNewTokens, generationConfigs.at(microBatchId).maxNewTokens);
        auto microBatchSize = generationConfigs.at(microBatchId).batchSize;
        outputIdsPerBatch.emplace_back(ITensor::slice(outputs.ids, offset, microBatchSize));
        offset += microBatchSize;
    }

    // TODO(micro batching) do we need 1 or 2 per micro batch?
    std::vector<RuntimeBuffers::TensorMap> inputBuffers(numMicroBatches * 2);
    std::vector<RuntimeBuffers::TensorMap> outputBuffers(numMicroBatches * 2);
    std::vector<bool> microBatchesFinished(numMicroBatches, false);
    for (SizeType step = 0; step < maxNewTokens; ++step)
    {
        if (std::all_of(microBatchesFinished.begin(), microBatchesFinished.end(), [](bool x) { return x; }))
            break;

        for (auto microBatchId = 0; microBatchId < numMicroBatches; ++microBatchId)
        {
            auto& buffers = *mBuffers.at(microBatchId);
            auto kvCacheManager = mModelConfig.usePagedKvCache() ? mKvCacheManagers.at(microBatchId).get() : nullptr;
            auto& newTokens = newTokensPerBatch.at(microBatchId);
            auto& generationConfig = generationConfigs.at(microBatchId);
            auto& outputIds = outputIdsPerBatch.at(microBatchId);

            if (microBatchesFinished.at(microBatchId))
                continue;

            if (step > 0)
            {
                auto const microBatchSize = generationConfig.batchSize;
                auto const beamWidth = generationConfig.beamWidth;
                auto const shouldStop = shouldStopSync(microBatchSize, beamWidth, microBatchId);

                if (mWorldConfig.isFirstPipelineParallelRank() && onTokenGenerated
                    && microBatchId == numMicroBatches - 1)
                {
                    onTokenGenerated(outputs.ids, step - 1, shouldStop);
                }

                if (shouldStop)
                {
                    mLogger->log(nvinfer1::ILogger::Severity::kVERBOSE, "GPT decoding finished early");
                    microBatchesFinished.at(microBatchId) = true;
                    continue;
                }
            }

            auto const contextId = microBatchId;
            if (step == 0)
            {
                SizeType const contextIdForContextPhase
                    = contextId + (mRuntime->getNbProfiles() == 2 ? mNumMicroBatches : 0);
                auto const& inputs = microBatches.at(microBatchId);
                buffers.prepareContextStep(
                    inputs.ids, inputs.padId, manager, kvCacheManager, generationConfig, mModelConfig, mWorldConfig);
                buffers.getRuntimeBuffers(
                    inputBuffers[contextId], outputBuffers[contextId], step, inputs.ids, mModelConfig, mWorldConfig);
                mRuntime->setInputTensors(contextIdForContextPhase, inputBuffers[contextId]);
                mRuntime->setOutputTensors(contextIdForContextPhase, outputBuffers[contextId]);

                TLLM_CHECK_WITH_INFO(
                    mRuntime->executeContext(contextIdForContextPhase), "Executing TRT engine failed!");

                buffers.postContextStep(manager, generationConfig, mModelConfig, mWorldConfig);
            }
            else
            {
                auto nextInputIds = buffers.prepareNextStep(
                    step - 1, newTokens, manager, kvCacheManager, generationConfig, mModelConfig, mWorldConfig);
                buffers.getRuntimeBuffers(
                    inputBuffers[contextId], outputBuffers[contextId], step, nextInputIds, mModelConfig, mWorldConfig);
                mRuntime->setInputTensors(contextId, inputBuffers[contextId]);
                mRuntime->setOutputTensors(contextId, outputBuffers[contextId]);

                TLLM_CHECK_WITH_INFO(mRuntime->executeContext(contextId), "Executing TRT engine failed!");
            }
            sync_check_cuda_error();

            std::swap(buffers.cacheIndirectionDecoderInput, buffers.cacheIndirectionDecoderOutput);

            auto const maxInputLength = generationConfigs.at(microBatchId).maxInputLength;
            auto const decoderStep = maxInputLength + step;
            decoderStepAsync(outputIds, newTokens, decoderStep, microBatchId);
            if (!mWorldConfig.isPipelineParallel() && mNumMicroBatches > 1)
            {
                updateOutputIds(outputIds, newTokens, decoderStep, mRuntime->getStream());
            }
        }
    }

    // TODO(micro batching) move into loop above?
    for (auto microBatchId = 0; microBatchId < numMicroBatches; ++microBatchId)
    {
        auto const& generationConfig = generationConfigs.at(microBatchId);
        auto const microBatchSize = generationConfig.batchSize;

        auto kvCacheManager = mModelConfig.usePagedKvCache() ? mKvCacheManagers.at(microBatchId).get() : nullptr;
        auto& outputIds = outputIdsPerBatch.at(microBatchId);

        // TODO(micro batching) sync receive event
        if (mWorldConfig.isFirstPipelineParallelRank() && onTokenGenerated && microBatchId == numMicroBatches - 1)
        {
            onTokenGenerated(outputs.ids, maxNewTokens - 1, true);
        }

        if (mModelConfig.usePagedKvCache())
        {
            for (auto batchIdx = 0; batchIdx < microBatchSize; ++batchIdx)
            {
                kvCacheManager->removeSequence(batchIdx);
            }
        }

        // TODO(micro batching) use mCommStream?
        finalizeOutputIds(*outputIds, microBatchId);
    }
    manager.getStream().synchronize();
    TLLM_LOG_DEBUG("%s stop", __PRETTY_FUNCTION__);
}

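// Launches one decoder step without blocking the main stream. On the last pipeline rank the decoder
// runs asynchronously and its results (number of finished sequences, sequence lengths, the cache
// indirection for beam search, and the new tokens for the first rank) are sent to the other ranks
// on mCommStream. The other ranks post the matching receives and record an event that
// shouldStopSync() waits on.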
void GptSession::decoderStepAsync(
    ITensor::SharedPtr& outputIds, ITensor::SharedPtr& newTokens, SizeType decoderStep, SizeType microBatchId)
{
    TLLM_LOG_DEBUG("%s start", __PRETTY_FUNCTION__);
    auto& stream = mRuntime->getStream();
    auto& buffers = *mBuffers.at(microBatchId);

    if (mWorldConfig.isLastPipelineParallelRank())
    {
        auto& decoder = *mDecoders.at(microBatchId);

        decoder::Input decodingInput{buffers.logits};
        decoder::Output decodingOutput{};
        decodingInput.cacheIndirection = buffers.cacheIndirectionDecoderInput;
        decodingOutput.cacheIndirection = buffers.cacheIndirectionDecoderOutput;
        decodingOutput.sequenceLengths = buffers.sequenceLengths;

        decoder.forwardAsync(decodingOutput, decodingInput);
        if (mWorldConfig.isPipelineParallel())
        { // send shouldStop to all previous ranks and newTokens to the first rank
            stream.record(mCommEvent.get());
            mCommStream->wait(mCommEvent.get());
            auto const pipelineGroup = mWorldConfig.getPipelineParallelGroup();

            auto& cacheIndirection = *buffers.cacheIndirectionDecoderOutput;
            auto& sequenceLengths = *buffers.sequenceLengths;
            auto const beamWidth = cacheIndirection.getShape().d[1];
            for (auto peerIdx = 0; peerIdx < mWorldConfig.getPipelineParallelism() - 1; ++peerIdx)
            {
                mPipelineComm->send<SizeType>(*decoder.getNbFinished(), pipelineGroup[peerIdx], *mCommStream, *mLogger);
                if (beamWidth > 1)
                {
                    mPipelineComm->send<SizeType>(cacheIndirection, pipelineGroup[peerIdx], *mCommStream, *mLogger);
                }
                mPipelineComm->send<SizeType>(sequenceLengths, pipelineGroup[peerIdx], *mCommStream, *mLogger);
            }
            mPipelineComm->send<TokenIdType>(*decoder.getNewTokens(), pipelineGroup.front(), *mCommStream, *mLogger);
        }
    }
    else // pipeline parallel mode, not the last rank
    { // receive shouldStop from the last rank on a separate stream
        stream.record(mCommEvent.get());
        mCommStream->wait(mCommEvent.get());
        auto const pipelineGroup = mWorldConfig.getPipelineParallelGroup();
        auto const peer = pipelineGroup.back();
        mPipelineComm->receive<SizeType>(*buffers.nbFinished, peer, *mCommStream, *mLogger);

        auto& cacheIndirection = *buffers.cacheIndirectionDecoderOutput;
        auto& sequenceLengths = *buffers.sequenceLengths;
        auto const beamWidth = cacheIndirection.getShape().d[1];
        if (beamWidth > 1)
        {
            mPipelineComm->receive<SizeType>(cacheIndirection, peer, *mCommStream, *mLogger);
        }
        mPipelineComm->receive<SizeType>(sequenceLengths, peer, *mCommStream, *mLogger);
        if (mWorldConfig.isFirstPipelineParallelRank())
        { // receive newTokens from last rank on a separate stream
            mPipelineComm->receive<TokenIdType>(*newTokens, peer, *mCommStream, *mLogger);
            updateOutputIds(outputIds, newTokens, decoderStep, *mCommStream);
        }
        mCommStream->record(mReceivedEvents.at(microBatchId).get());
    }
    TLLM_LOG_DEBUG("%s stop", __PRETTY_FUNCTION__);
}

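// Blocks until the results of the current decoder step are available (directly from the decoder on
// the last pipeline rank, via the received event on the other ranks) and reports whether all
// batchSize * beamWidth sequences have finished.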
bool GptSession::shouldStopSync(SizeType batchSize, SizeType beamWidth, SizeType microBatchId)
{
    TLLM_LOG_DEBUG("%s start", __PRETTY_FUNCTION__);

    SizeType nbFinished = 0;

    if (mWorldConfig.isLastPipelineParallelRank())
    { // read the Finished flag from the decoder
        auto& decoder = *mDecoders.at(microBatchId);
        decoder.isFinishedSync();
        nbFinished = *bufferCast<SizeType>(*decoder.getNbFinished());
    }
    else
    { // ensure all information has been received
        TLLM_CUDA_CHECK(cudaEventSynchronize(mReceivedEvents.at(microBatchId).get()));
        nbFinished = *bufferCast<SizeType>(*mBuffers.at(microBatchId)->nbFinished);
    }
    sync_check_cuda_error();
    TLLM_LOG_DEBUG("%s stop", __PRETTY_FUNCTION__);
    return nbFinished == batchSize * beamWidth;
}

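// Gathers the final output ids after generation: with pipeline parallelism the last rank sends the
// decoder's gathered ids to the first rank, which receives them into outputIds; without it they are
// simply copied from the decoder.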
void GptSession::finalizeOutputIds(ITensor& outputIds, SizeType microBatchId)
{
    TLLM_LOG_DEBUG("%s start", __PRETTY_FUNCTION__);
    auto& manager = mRuntime->getBufferManager();

    if (mWorldConfig.isPipelineParallel())
    {
        auto& stream = mRuntime->getStream();
        auto const pipelineGroup = mWorldConfig.getPipelineParallelGroup();

        if (mWorldConfig.isLastPipelineParallelRank())
        { // send ids from the last rank to the first rank
            auto const peer = pipelineGroup.front();
            auto const finalOutputIds = mDecoders.at(microBatchId)->getFinalOutputIds();
            mPipelineComm->send(
                bufferCast<std::int32_t>(*finalOutputIds), finalOutputIds->getSize(), peer, stream, *mLogger);
        }
        else if (mWorldConfig.isFirstPipelineParallelRank())
        { // receive ids from the last rank on the first rank
            auto const peer = pipelineGroup.back();
            mPipelineComm->receive(bufferCast<std::int32_t>(outputIds), outputIds.getSize(), peer, stream, *mLogger);
        }
    }
    else
    {
        manager.copy(*mDecoders.at(microBatchId)->getFinalOutputIds(), outputIds);
    }

    sync_check_cuda_error();
    TLLM_LOG_DEBUG("%s stop", __PRETTY_FUNCTION__);
}

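// CudaGraphExecutor wraps a cudaGraphExec_t that replays one generation step. prepareNextGraph()
// captures the engine execution for the next context into a CUDA graph, tries to update the
// existing executable graph in place (update() returns true if that fails and the instance must be
// re-created), and uploads the result to the stream so launch() can replay it.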
void GptSession::CudaGraphExecutor::create(cudaGraph_t const& graph)
{
    TLLM_LOG_DEBUG("%s start", __PRETTY_FUNCTION__);
    assert(mInstance == nullptr);
    TLLM_CUDA_CHECK(cudaGraphInstantiate(&mInstance, graph, nullptr, nullptr, 0));
    TLLM_LOG_DEBUG("%s stop", __PRETTY_FUNCTION__);
}

void GptSession::CudaGraphExecutor::uploadToStream(CudaStream const& stream)
{
    TLLM_LOG_DEBUG("%s start", __PRETTY_FUNCTION__);
    assert(hasInstance());
    TLLM_CUDA_CHECK(cudaGraphUpload(mInstance, stream.get()));
    TLLM_LOG_DEBUG("%s stop", __PRETTY_FUNCTION__);
}

void GptSession::CudaGraphExecutor::launch(CudaStream const& stream)
{
    TLLM_LOG_DEBUG("%s start", __PRETTY_FUNCTION__);
    TLLM_CUDA_CHECK(cudaGraphLaunch(mInstance, stream.get()));
    TLLM_LOG_DEBUG("%s stop", __PRETTY_FUNCTION__);
}

bool GptSession::CudaGraphExecutor::update(cudaGraph_t const& graph)
{
    TLLM_LOG_DEBUG("%s start", __PRETTY_FUNCTION__);
    return cudaGraphExecUpdate(mInstance, graph, nullptr) != cudaSuccess;
}

void GptSession::CudaGraphExecutor::clear()
{
    TLLM_LOG_DEBUG("%s start", __PRETTY_FUNCTION__);
    if (mInstance != nullptr)
    {
        TLLM_CUDA_CHECK(cudaGraphExecDestroy(mInstance));
        mInstance = nullptr;
    }
    TLLM_LOG_DEBUG("%s stop", __PRETTY_FUNCTION__);
}

void GptSession::CudaGraphExecutor::prepareNextGraph(TllmRuntime const& runtime, SizeType nextContextId)
{
    TLLM_LOG_DEBUG("%s start", __PRETTY_FUNCTION__);
    auto& stream = runtime.getStream();

    cudaGraph_t nextGraph;
    TLLM_CUDA_CHECK(cudaStreamBeginCapture(stream.get(), cudaStreamCaptureModeThreadLocal));
    runtime.executeContext(nextContextId);
    TLLM_CUDA_CHECK(cudaStreamEndCapture(stream.get(), &nextGraph));

    if (hasInstance())
    {
        if (update(nextGraph))
        {
            clear();
            create(nextGraph);
        }
    }
    else
    {
        create(nextGraph);
    }

    TLLM_CUDA_CHECK(cudaGraphDestroy(nextGraph));
    uploadToStream(stream);
    TLLM_LOG_DEBUG("%s stop", __PRETTY_FUNCTION__);
}