/*
 * SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: Apache-2.0
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "tensorrt_llm/batch_manager/rnnStateManager.h"
|
|
#include "tensorrt_llm/common/assert.h"
|
|
#include "tensorrt_llm/runtime/utils/runtimeUtils.h"
|
|
|
|
using namespace tensorrt_llm::runtime;
|
|
|
|
namespace tensorrt_llm::batch_manager::rnn_state_manager
|
|
{
|
|
|
|
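// Allocates the paged SSM and conv state buffers for all sequence slots of this pipeline rank and
// builds the per-layer pointer tables that getPtrBuffers() later exposes to the engine.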
RnnStateManager::RnnStateManager(SizeType32 maxNumSequences, tensorrt_llm::runtime::ModelConfig const& modelConfig,
    runtime::WorldConfig const& worldConfig, tensorrt_llm::runtime::BufferManager const& bufferManager)
    : mMaxNumSequences(maxNumSequences)
    , mMaxBeamWidth{modelConfig.getMaxBeamWidth()}
{
    TLLM_CHECK_WITH_INFO(modelConfig.usePagedState(), "RnnStateManager should be used with paged state enabled.");
    TLLM_CHECK_WITH_INFO(modelConfig.useMambaConv1dPlugin(), "RnnStateManager should be used with MambaConv1dPlugin.");
    TLLM_CHECK_WITH_INFO(mMaxBeamWidth == 1, "Beam search is not supported for Mamba yet.");
    mBeamSlotsPerSequence = mMaxBeamWidth == 1 ? mMaxBeamWidth : mMaxBeamWidth + 1;
    // If we need to support beam search, we may need mMaxBeamWidth + 1 slots and separate input / output states.
    auto const& rnnConfig = modelConfig.getRnnConfig();
    TLLM_CHECK_WITH_INFO(rnnConfig.has_value(), "RnnStateManager should be used with rnnConfig.");
    auto const convKernel = rnnConfig->convKernel;
    auto const stateSize = rnnConfig->stateSize;
    auto const rnnHiddenSize = rnnConfig->rnnHiddenSize;
    auto const rnnHeadSize = rnnConfig->rnnHeadSize;
    auto const rnnConvDimSize = rnnConfig->rnnConvDimSize;
    auto const localNbLayers = modelConfig.getNbRnnLayers(worldConfig.getPipelineParallelism());
    auto const dataType = modelConfig.getDataType();

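    // Layout of the paged SSM state: [localNbLayers, maxNumSequences * beamSlotsPerSequence,
    // rnnHiddenSize / rnnHeadSize, stateSize, rnnHeadSize] when rnnHeadSize > 0, otherwise
    // [localNbLayers, maxNumSequences * beamSlotsPerSequence, stateSize, rnnHiddenSize].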
    auto const rnnStateShape = [&]()
    {
        if (rnnHeadSize > 0)
        {
            return tensorrt_llm::runtime::ITensor::makeShape({localNbLayers, mMaxNumSequences * mBeamSlotsPerSequence,
                rnnHiddenSize / rnnHeadSize, stateSize, rnnHeadSize});
        }
        else
        {
            return tensorrt_llm::runtime::ITensor::makeShape(
                {localNbLayers, mMaxNumSequences * mBeamSlotsPerSequence, stateSize, rnnHiddenSize});
        }
    }();
    auto const convStateShape = tensorrt_llm::runtime::ITensor::makeShape(
        {localNbLayers, mMaxNumSequences * mBeamSlotsPerSequence, convKernel - 1, rnnConvDimSize});
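    // Allocate the paged state buffers on the GPU; the SSM state is always kept in float,
    // while the conv state uses the model's data type.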
    pagedRnnStates = bufferManager.gpu(rnnStateShape, nvinfer1::DataType::kFLOAT);
    pagedConvStates = bufferManager.gpu(convStateShape, dataType);

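    // Host-side tensors holding one device pointer per local layer; filled in the loop below and
    // exposed to the engine via getPtrBuffers().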
    auto const statePtrsShape = tensorrt_llm::runtime::ITensor::makeShape({localNbLayers});
    rnnStatePtrs = tensorrt_llm::runtime::BufferManager::cpu(statePtrsShape, TRTDataType<void*>::value);
    convStatePtrs = tensorrt_llm::runtime::BufferManager::cpu(statePtrsShape, TRTDataType<void*>::value);
    auto* rnnStatePtrArray = bufferCast<void*>(*rnnStatePtrs);
    auto* convStatePtrArray = bufferCast<void*>(*convStatePtrs);

    rnnStatePtr.resize(localNbLayers);
    convStatePtr.resize(localNbLayers);
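    // For each local layer, record the device pointer of its state slice and keep a one-element
    // view of the corresponding entry in the pointer tensors.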
    for (int i = 0; i < localNbLayers; i++)
    {
        auto layerRnnStates = tensorrt_llm::runtime::ITensor::slice(pagedRnnStates, i, 1);
        auto layerConvStates = tensorrt_llm::runtime::ITensor::slice(pagedConvStates, i, 1);
        rnnStatePtrArray[i] = layerRnnStates->data();
        convStatePtrArray[i] = layerConvStates->data();
        rnnStatePtr[i] = tensorrt_llm::runtime::ITensor::slice(rnnStatePtrs, i, 1);
        convStatePtr[i] = tensorrt_llm::runtime::ITensor::slice(convStatePtrs, i, 1);
    }
}

void RnnStateManager::getPtrBuffers(
    TensorMap& inputBuffers, runtime::ModelConfig const& modelConfig, runtime::WorldConfig const& worldConfig) const
{
    auto const localNbLayers = modelConfig.getNbRnnLayers(worldConfig.getPipelineParallelism());
    auto const firstLayerId = worldConfig.getPipelineParallelRank() * localNbLayers;
    auto const& layerTypes = modelConfig.getLayerTypes();

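    // Bind the per-layer state pointer tensors as inputs for the recurrent layers owned by this rank.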
    utils::insertTensorVector(
        inputBuffers, "conv_state_ptr_", convStatePtr, firstLayerId, layerTypes, ModelConfig::LayerType::kRECURRENT);
    utils::insertTensorVector(
        inputBuffers, "rnn_state_ptr_", rnnStatePtr, firstLayerId, layerTypes, ModelConfig::LayerType::kRECURRENT);
}

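// Writes the state slot indices for the given sequence slot (and its beams) into dstPointers,
// starting at dstSlotOffset.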
void RnnStateManager::fillSlotMapping(
    runtime::ITensor& dstPointers, SizeType32 dstSlotOffset, SizeType32 seqSlotIdx, SizeType32 beamWidth) const
{
    TLLM_CHECK(seqSlotIdx < mMaxNumSequences);
    TLLM_CHECK(beamWidth <= mMaxBeamWidth);

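    // Each sequence owns mBeamSlotsPerSequence consecutive slots in the paged state buffers.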
    auto* dstPtr = bufferCast<SizeType32>(dstPointers);
    if (beamWidth == 1)
    {
        dstPtr[dstSlotOffset] = seqSlotIdx * mBeamSlotsPerSequence;
    }
    else
    {
        // Leave the first slot for the context.
        std::iota(dstPtr + dstSlotOffset, dstPtr + dstSlotOffset + beamWidth, seqSlotIdx * mBeamSlotsPerSequence + 1);
    }
}

} // namespace tensorrt_llm::batch_manager::rnn_state_manager