/*
* SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "tensorrt_llm/batch_manager/rnnStateManager.h"
#include "tensorrt_llm/common/assert.h"
#include "tensorrt_llm/runtime/cudaStream.h"
#include "tensorrt_llm/runtime/utils/runtimeUtils.h"
#include <numeric>
#include <unordered_set>

using namespace tensorrt_llm::runtime;

namespace tensorrt_llm::batch_manager::rnn_state_manager
{
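
// Engine-flow constructor: sizes the paged conv / RNN state pools from the engine's
// RnnConfig and prepares per-layer device-pointer tensors for the TensorRT plugins.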
RnnStateManager::RnnStateManager(SizeType32 maxNumSequences, tensorrt_llm::runtime::ModelConfig const& modelConfig,
runtime::WorldConfig const& worldConfig, tensorrt_llm::runtime::BufferManager const& bufferManager)
: mMaxNumSequences(maxNumSequences)
, mMaxBeamWidth{modelConfig.getMaxBeamWidth()}
{
TLLM_CHECK_WITH_INFO(modelConfig.usePagedState(), "RnnStateManager should be used with Paged State enabled.");
TLLM_CHECK_WITH_INFO(modelConfig.useMambaConv1dPlugin(), "RnnStateManager should be used with MambaConv1dPlugin.");
    TLLM_CHECK_WITH_INFO(mMaxBeamWidth == 1, "Beam search is not supported for Mamba yet.");
    // If we need to support beam search, we may need mMaxBeamWidth + 1 slots per sequence
    // and separate input / output states.
    mBeamSlotsPerSequence = mMaxBeamWidth == 1 ? mMaxBeamWidth : mMaxBeamWidth + 1;
auto const& rnnConfig = modelConfig.getRnnConfig();
    TLLM_CHECK_WITH_INFO(rnnConfig.has_value(), "RnnStateManager should be used with rnnConfig.");
auto const convKernel = rnnConfig->convKernel;
auto const stateSize = rnnConfig->stateSize;
auto const rnnHiddenSize = rnnConfig->rnnHiddenSize;
auto const rnnHeadSize = rnnConfig->rnnHeadSize;
auto const rnnConvDimSize = rnnConfig->rnnConvDimSize;
auto const localNbLayers
= modelConfig.getNbRnnLayers(worldConfig.getPipelineParallelism(), worldConfig.getPipelineParallelRank());
auto const dataType = modelConfig.getDataType();
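    // State layout: per-head [layers, slots, nHeads, stateSize, headSize] when rnnHeadSize > 0,
    // flat [layers, slots, stateSize, hiddenSize] otherwise.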
auto const rnnStateShape = [&]()
{
if (rnnHeadSize > 0)
{
return tensorrt_llm::runtime::ITensor::makeShape({localNbLayers, mMaxNumSequences * mBeamSlotsPerSequence,
rnnHiddenSize / rnnHeadSize, stateSize, rnnHeadSize});
}
else
{
return tensorrt_llm::runtime::ITensor::makeShape(
{localNbLayers, mMaxNumSequences * mBeamSlotsPerSequence, stateSize, rnnHiddenSize});
}
}();
auto const convStateShape = tensorrt_llm::runtime::ITensor::makeShape(
{localNbLayers, mMaxNumSequences * mBeamSlotsPerSequence, convKernel - 1, rnnConvDimSize});
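    // Note: the SSM state pool is always allocated in FP32; conv states use the model dtype.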
pagedRnnStates = bufferManager.gpu(rnnStateShape, nvinfer1::DataType::kFLOAT);
pagedConvStates = bufferManager.gpu(convStateShape, dataType);
auto const statePtrsShape = tensorrt_llm::runtime::ITensor::makeShape({localNbLayers});
rnnStatePtrs = tensorrt_llm::runtime::BufferManager::cpu(statePtrsShape, TRTDataType<void*>::value);
convStatePtrs = tensorrt_llm::runtime::BufferManager::cpu(statePtrsShape, TRTDataType<void*>::value);
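    // Host-side tables of per-layer device pointers; each layer consumes a one-element
    // slice of these tables as its engine input.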
auto* rnnStatePtrArray = bufferCast<void*>(*rnnStatePtrs);
auto* convStatePtrArray = bufferCast<void*>(*convStatePtrs);
rnnStatePtr.resize(localNbLayers);
convStatePtr.resize(localNbLayers);
    for (SizeType32 i = 0; i < localNbLayers; i++)
{
auto layerRnnStates = tensorrt_llm::runtime::ITensor::slice(pagedRnnStates, i, 1);
auto layerConvStates = tensorrt_llm::runtime::ITensor::slice(pagedConvStates, i, 1);
rnnStatePtrArray[i] = layerRnnStates->data();
convStatePtrArray[i] = layerConvStates->data();
rnnStatePtr[i] = tensorrt_llm::runtime::ITensor::slice(rnnStatePtrs, i, 1);
convStatePtr[i] = tensorrt_llm::runtime::ITensor::slice(convStatePtrs, i, 1);
}
}
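
// Block-managed constructor (used by the C++ MambaCacheManager path): allocates
// per-request conv / SSM state blocks on the given CUDA stream and tracks free blocks
// and request-to-block assignments itself.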
RnnStateManager::RnnStateManager(SizeType32 dState, SizeType32 dConv, SizeType32 numHeads, SizeType32 nGroups,
SizeType32 headDim, SizeType32 maxBatchSize, WorldConfig const& worldConfig, int64_t stream,
nvinfer1::DataType dtype, nvinfer1::DataType ssmCacheDtype, std::vector<SizeType32> const& ppLayers)
: mMaxNumSequences(maxBatchSize)
, mMaxBeamWidth{1}
, mBeamSlotsPerSequence{1}
, mBufferManager{std::make_shared<CudaStream>(reinterpret_cast<cudaStream_t>(stream))}
{
auto const tpSize = worldConfig.getTensorParallelism();
auto const dInner = headDim * numHeads;
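    // Conv channels cover the inner projection plus the B and C group projections.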
auto convDim = dInner + 2 * nGroups * dState;
auto nheads = numHeads;
TLLM_CHECK_WITH_INFO(nheads % tpSize == 0, "nheads must be divisible by tp_size");
TLLM_CHECK_WITH_INFO(convDim % tpSize == 0, "conv_dim must be divisible by tp_size");
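    // Shard conv channels and SSM heads across tensor-parallel ranks.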
convDim = convDim / tpSize;
nheads = nheads / tpSize;
auto const numLocalLayers = static_cast<SizeType32>(ppLayers.size());
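    // Map global layer ids on this pipeline-parallel rank to local offsets into the pools.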
for (SizeType32 offset = 0; offset < numLocalLayers; ++offset)
{
mLayerOffsets[ppLayers[offset]] = offset;
}
auto const convStateShape = ITensor::makeShape({numLocalLayers, maxBatchSize, convDim, dConv - 1});
pagedConvStates = mBufferManager->gpu(convStateShape, dtype);
auto const rnnStateShape = ITensor::makeShape({numLocalLayers, maxBatchSize, nheads, headDim, dState});
pagedRnnStates = mBufferManager->gpu(rnnStateShape, ssmCacheDtype);
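    // Every block starts out free; allocateCacheBlocks assigns blocks to requests on demand.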
mFreeBlocks.reserve(maxBatchSize);
for (SizeType32 i = 0; i < maxBatchSize; ++i)
{
mFreeBlocks.push_back(i);
}
auto const statePtrsShape = ITensor::makeShape({numLocalLayers});
rnnStatePtrs = BufferManager::cpu(statePtrsShape, TRTDataType<void*>::value);
convStatePtrs = BufferManager::cpu(statePtrsShape, TRTDataType<void*>::value);
auto* rnnStatePtrArray = bufferCast<void*>(*rnnStatePtrs);
auto* convStatePtrArray = bufferCast<void*>(*convStatePtrs);
rnnStatePtr.resize(numLocalLayers);
convStatePtr.resize(numLocalLayers);
for (SizeType32 i = 0; i < numLocalLayers; i++)
{
auto layerRnnStates = ITensor::slice(pagedRnnStates, i, 1);
auto layerConvStates = ITensor::slice(pagedConvStates, i, 1);
rnnStatePtrArray[i] = layerRnnStates->data();
convStatePtrArray[i] = layerConvStates->data();
rnnStatePtr[i] = ITensor::slice(rnnStatePtrs, i, 1);
convStatePtr[i] = ITensor::slice(convStatePtrs, i, 1);
}
}
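
// Publishes the per-layer state pointer tensors ("conv_state_ptr_*" / "rnn_state_ptr_*")
// as engine input buffers for the recurrent layers.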
void RnnStateManager::getPtrBuffers(
TensorMap& inputBuffers, runtime::ModelConfig const& modelConfig, runtime::WorldConfig const& worldConfig) const
{
auto const firstLayerId
= modelConfig.getFirstLocalLayer(worldConfig.getPipelineParallelism(), worldConfig.getPipelineParallelRank());
auto const& layerTypes = modelConfig.getLayerTypes();
utils::insertTensorVector(
inputBuffers, "conv_state_ptr_", convStatePtr, firstLayerId, layerTypes, ModelConfig::LayerType::kRECURRENT);
utils::insertTensorVector(
inputBuffers, "rnn_state_ptr_", rnnStatePtr, firstLayerId, layerTypes, ModelConfig::LayerType::kRECURRENT);
}
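
// Writes the state-slot index (or indices, for beam search) of the given sequence slot
// into dstPointers, starting at dstSlotOffset.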
void RnnStateManager::fillSlotMapping(
runtime::ITensor& dstPointers, SizeType32 dstSlotOffset, SizeType32 seqSlotIdx, SizeType32 beamWidth) const
{
TLLM_CHECK(seqSlotIdx < mMaxNumSequences);
TLLM_CHECK(beamWidth <= mMaxBeamWidth);
auto* dstPtr = bufferCast<SizeType32>(dstPointers);
if (beamWidth == 1)
{
dstPtr[dstSlotOffset] = seqSlotIdx * mBeamSlotsPerSequence;
}
else
{
        // Leave the first slot for the context state; beam slots follow it.
std::iota(dstPtr + dstSlotOffset, dstPtr + dstSlotOffset + beamWidth, seqSlotIdx * mBeamSlotsPerSequence + 1);
}
}
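
// Assigns a free state block to every request that does not already own one.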
void RnnStateManager::allocateCacheBlocks(std::vector<RequestIdType> const& requestIds)
{
for (auto const& requestId : requestIds)
{
auto it = mCacheIndex.find(requestId);
if (it == mCacheIndex.end())
{
            TLLM_CHECK_WITH_INFO(!mFreeBlocks.empty(), "Ran out of RNN state cache blocks.");
SizeType32 const block = mFreeBlocks.back();
mFreeBlocks.pop_back();
mCacheIndex[requestId] = block;
}
}
}
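
// Returns a request's state block to the free list; unknown request ids are ignored.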
void RnnStateManager::freeCacheBlock(RequestIdType requestId)
{
auto it = mCacheIndex.find(requestId);
if (it != mCacheIndex.end())
{
mFreeBlocks.push_back(it->second);
mCacheIndex.erase(it);
}
}
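
// Looks up the state block assigned to a request; the request must have one.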
RnnStateManager::SizeType32 RnnStateManager::getCacheIndex(RequestIdType requestId) const
{
auto it = mCacheIndex.find(requestId);
TLLM_CHECK_WITH_INFO(it != mCacheIndex.end(), "Request ID not found in cache index");
return it->second;
}
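
// Returns the state-block index for each request. Padding entries are mapped to blocks
// not owned by any request in the batch, so kernels can write to them without
// clobbering live state.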
std::vector<RnnStateManager::SizeType32> RnnStateManager::getStateIndices(
std::vector<RequestIdType> const& requestIds, std::vector<bool> const& isPadding)
{
TLLM_CHECK_WITH_INFO(requestIds.size() == isPadding.size(), "requestIds and isPadding must have the same size");
std::unordered_set<SizeType32> availableSlots;
availableSlots.reserve(mMaxNumSequences);
for (SizeType32 i = 0; i < mMaxNumSequences; ++i)
{
availableSlots.insert(i);
}
for (size_t i = 0; i < requestIds.size(); ++i)
{
if (!isPadding[i])
{
availableSlots.erase(getCacheIndex(requestIds[i]));
}
}
std::vector<SizeType32> result;
result.reserve(requestIds.size());
auto availableIt = availableSlots.begin();
for (size_t i = 0; i < requestIds.size(); ++i)
{
if (isPadding[i])
{
            TLLM_CHECK_WITH_INFO(availableIt != availableSlots.end(), "Ran out of available slots for padding.");
result.push_back(*availableIt);
++availableIt;
}
else
{
result.push_back(getCacheIndex(requestIds[i]));
}
}
return result;
}
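
// Returns the conv states of a single layer (addressed by global layer id) with the
// layer dimension squeezed away.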
RnnStateManager::TensorPtr RnnStateManager::getConvStates(SizeType32 layerIdx) const
{
auto it = mLayerOffsets.find(layerIdx);
TLLM_CHECK_WITH_INFO(it != mLayerOffsets.end(), "Layer index not found in layer offsets");
auto result = ITensor::slice(pagedConvStates, it->second, 1);
result->squeeze(0);
return result;
}
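
// Returns the SSM states of a single layer (addressed by global layer id) with the
// layer dimension squeezed away.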
RnnStateManager::TensorPtr RnnStateManager::getSsmStates(SizeType32 layerIdx) const
{
auto it = mLayerOffsets.find(layerIdx);
TLLM_CHECK_WITH_INFO(it != mLayerOffsets.end(), "Layer index not found in layer offsets");
auto result = ITensor::slice(pagedRnnStates, it->second, 1);
result->squeeze(0);
return result;
}

} // namespace tensorrt_llm::batch_manager::rnn_state_manager