/*
 * SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: Apache-2.0
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "tensorrt_llm/batch_manager/rnnStateManager.h"
#include "tensorrt_llm/common/assert.h"
#include "tensorrt_llm/runtime/cudaStream.h"
#include "tensorrt_llm/runtime/utils/runtimeUtils.h"

#include <numeric>

using namespace tensorrt_llm::runtime;

namespace tensorrt_llm::batch_manager::rnn_state_manager
{

RnnStateManager::RnnStateManager(SizeType32 maxNumSequences, tensorrt_llm::runtime::ModelConfig const& modelConfig,
    runtime::WorldConfig const& worldConfig, tensorrt_llm::runtime::BufferManager const& bufferManager)
    : mMaxNumSequences(maxNumSequences)
    , mMaxBeamWidth{modelConfig.getMaxBeamWidth()}
{
    TLLM_CHECK_WITH_INFO(modelConfig.usePagedState(), "RnnStateManager should be used with Paged State enabled.");
    TLLM_CHECK_WITH_INFO(modelConfig.useMambaConv1dPlugin(), "RnnStateManager should be used with MambaConv1dPlugin.");
    TLLM_CHECK_WITH_INFO(mMaxBeamWidth == 1, "Beam search is not supported for Mamba yet.");
    // If beam search is to be supported, mMaxBeamWidth + 1 slots and separate input / output states may be needed.
    mBeamSlotsPerSequence = mMaxBeamWidth == 1 ? mMaxBeamWidth : mMaxBeamWidth + 1;

    auto const& rnnConfig = modelConfig.getRnnConfig();
    TLLM_CHECK_WITH_INFO(rnnConfig.has_value(), "RnnStateManager should be used with rnnConfig.");
    auto const convKernel = rnnConfig->convKernel;
    auto const stateSize = rnnConfig->stateSize;
    auto const rnnHiddenSize = rnnConfig->rnnHiddenSize;
    auto const rnnHeadSize = rnnConfig->rnnHeadSize;
    auto const rnnConvDimSize = rnnConfig->rnnConvDimSize;
    auto const localNbLayers
        = modelConfig.getNbRnnLayers(worldConfig.getPipelineParallelism(), worldConfig.getPipelineParallelRank());
    auto const dataType = modelConfig.getDataType();

    // The RNN state layout depends on whether the model uses multi-head states (rnnHeadSize > 0).
    auto const rnnStateShape = [&]()
    {
        if (rnnHeadSize > 0)
        {
            return tensorrt_llm::runtime::ITensor::makeShape({localNbLayers,
                mMaxNumSequences * mBeamSlotsPerSequence, rnnHiddenSize / rnnHeadSize, stateSize, rnnHeadSize});
        }
        else
        {
            return tensorrt_llm::runtime::ITensor::makeShape(
                {localNbLayers, mMaxNumSequences * mBeamSlotsPerSequence, stateSize, rnnHiddenSize});
        }
    }();
    auto const convStateShape = tensorrt_llm::runtime::ITensor::makeShape(
        {localNbLayers, mMaxNumSequences * mBeamSlotsPerSequence, convKernel - 1, rnnConvDimSize});
    pagedRnnStates = bufferManager.gpu(rnnStateShape, nvinfer1::DataType::kFLOAT);
    pagedConvStates = bufferManager.gpu(convStateShape, dataType);

    // Record per-layer device pointers so each recurrent layer's state buffer can be passed to the plugins.
    auto const statePtrsShape = tensorrt_llm::runtime::ITensor::makeShape({localNbLayers});
    rnnStatePtrs = tensorrt_llm::runtime::BufferManager::cpu(statePtrsShape, TRTDataType<void*>::value);
    convStatePtrs = tensorrt_llm::runtime::BufferManager::cpu(statePtrsShape, TRTDataType<void*>::value);
    auto* rnnStatePtrArray = bufferCast<void*>(*rnnStatePtrs);
    auto* convStatePtrArray = bufferCast<void*>(*convStatePtrs);
    rnnStatePtr.resize(localNbLayers);
    convStatePtr.resize(localNbLayers);
    for (int i = 0; i < localNbLayers; i++)
    {
        auto layerRnnStates = tensorrt_llm::runtime::ITensor::slice(pagedRnnStates, i, 1);
        auto layerConvStates = tensorrt_llm::runtime::ITensor::slice(pagedConvStates, i, 1);
        rnnStatePtrArray[i] = layerRnnStates->data();
        convStatePtrArray[i] = layerConvStates->data();
        rnnStatePtr[i] = tensorrt_llm::runtime::ITensor::slice(rnnStatePtrs, i, 1);
        convStatePtr[i] = tensorrt_llm::runtime::ITensor::slice(convStatePtrs, i, 1);
    }
}

RnnStateManager::RnnStateManager(SizeType32 dState, SizeType32 dConv, SizeType32 numHeads, SizeType32 nGroups,
    SizeType32 headDim, SizeType32 maxBatchSize, WorldConfig const& worldConfig, int64_t stream,
    nvinfer1::DataType dtype, nvinfer1::DataType ssmCacheDtype, std::vector<SizeType32> const& ppLayers)
    : mMaxNumSequences(maxBatchSize)
    , mMaxBeamWidth{1}
    , mBeamSlotsPerSequence{1}
    , mBufferManager{std::make_shared<BufferManager>(
          std::make_shared<CudaStream>(reinterpret_cast<cudaStream_t>(stream)))}
{
    auto const tpSize = worldConfig.getTensorParallelism();
    auto const dInner = headDim * numHeads;
    auto convDim = dInner + 2 * nGroups * dState;
    auto nheads = numHeads;
    TLLM_CHECK_WITH_INFO(nheads % tpSize == 0, "nheads must be divisible by tp_size");
    TLLM_CHECK_WITH_INFO(convDim % tpSize == 0, "conv_dim must be divisible by tp_size");
    convDim = convDim / tpSize;
    nheads = nheads / tpSize;

    // Map global layer indices to local offsets within this pipeline-parallel rank.
    auto const numLocalLayers = static_cast<SizeType32>(ppLayers.size());
    for (SizeType32 offset = 0; offset < numLocalLayers; ++offset)
    {
        mLayerOffsets[ppLayers[offset]] = offset;
    }

    auto const convStateShape = ITensor::makeShape({numLocalLayers, maxBatchSize, convDim, dConv - 1});
    pagedConvStates = mBufferManager->gpu(convStateShape, dtype);
    auto const rnnStateShape = ITensor::makeShape({numLocalLayers, maxBatchSize, nheads, headDim, dState});
    pagedRnnStates = mBufferManager->gpu(rnnStateShape, ssmCacheDtype);

    // Every state block starts out free and is assigned to a request in allocateCacheBlocks.
    mFreeBlocks.reserve(maxBatchSize);
    for (SizeType32 i = 0; i < maxBatchSize; ++i)
    {
        mFreeBlocks.push_back(i);
    }

    auto const statePtrsShape = ITensor::makeShape({numLocalLayers});
    rnnStatePtrs = BufferManager::cpu(statePtrsShape, TRTDataType<void*>::value);
    convStatePtrs = BufferManager::cpu(statePtrsShape, TRTDataType<void*>::value);
    auto* rnnStatePtrArray = bufferCast<void*>(*rnnStatePtrs);
    auto* convStatePtrArray = bufferCast<void*>(*convStatePtrs);
    rnnStatePtr.resize(numLocalLayers);
    convStatePtr.resize(numLocalLayers);
    for (SizeType32 i = 0; i < numLocalLayers; i++)
    {
        auto layerRnnStates = ITensor::slice(pagedRnnStates, i, 1);
        auto layerConvStates = ITensor::slice(pagedConvStates, i, 1);
        rnnStatePtrArray[i] = layerRnnStates->data();
        convStatePtrArray[i] = layerConvStates->data();
        rnnStatePtr[i] = ITensor::slice(rnnStatePtrs, i, 1);
        convStatePtr[i] = ITensor::slice(convStatePtrs, i, 1);
    }
}

void RnnStateManager::getPtrBuffers(
    TensorMap& inputBuffers, runtime::ModelConfig const& modelConfig, runtime::WorldConfig const& worldConfig) const
{
    auto const firstLayerId
        = modelConfig.getFirstLocalLayer(worldConfig.getPipelineParallelism(), worldConfig.getPipelineParallelRank());
    auto const& layerTypes = modelConfig.getLayerTypes();
    utils::insertTensorVector(
        inputBuffers, "conv_state_ptr_", convStatePtr, firstLayerId, layerTypes, ModelConfig::LayerType::kRECURRENT);
    utils::insertTensorVector(
        inputBuffers, "rnn_state_ptr_", rnnStatePtr, firstLayerId, layerTypes, ModelConfig::LayerType::kRECURRENT);
}

void RnnStateManager::fillSlotMapping(
    runtime::ITensor& dstPointers, SizeType32 dstSlotOffset, SizeType32 seqSlotIdx, SizeType32 beamWidth) const
{
    TLLM_CHECK(seqSlotIdx < mMaxNumSequences);
    TLLM_CHECK(beamWidth <= mMaxBeamWidth);

    auto* dstPtr = bufferCast<SizeType32>(dstPointers);
    if (beamWidth == 1)
    {
        dstPtr[dstSlotOffset] = seqSlotIdx * mBeamSlotsPerSequence;
    }
    else
    {
        // Leave the first slot for the context state.
        std::iota(dstPtr + dstSlotOffset, dstPtr + dstSlotOffset + beamWidth, seqSlotIdx * mBeamSlotsPerSequence + 1);
    }
}

void RnnStateManager::allocateCacheBlocks(std::vector<RequestIdType> const& requestIds)
{
    for (auto const& requestId : requestIds)
    {
        auto it = mCacheIndex.find(requestId);
        if (it == mCacheIndex.end())
        {
            TLLM_CHECK_WITH_INFO(!mFreeBlocks.empty(), "Ran out of RNN state cache blocks");
            SizeType32 const block = mFreeBlocks.back();
            mFreeBlocks.pop_back();
            mCacheIndex[requestId] = block;
        }
    }
}

void RnnStateManager::freeCacheBlock(RequestIdType requestId)
{
    auto it = mCacheIndex.find(requestId);
    if (it != mCacheIndex.end())
    {
        mFreeBlocks.push_back(it->second);
        mCacheIndex.erase(it);
    }
}

RnnStateManager::SizeType32 RnnStateManager::getCacheIndex(RequestIdType requestId) const
{
    auto it = mCacheIndex.find(requestId);
    TLLM_CHECK_WITH_INFO(it != mCacheIndex.end(), "Request ID not found in cache index");
    return it->second;
}

std::vector<SizeType32> RnnStateManager::getStateIndices(
    std::vector<RequestIdType> const& requestIds, std::vector<bool> const& isPadding)
{
    TLLM_CHECK_WITH_INFO(requestIds.size() == isPadding.size(), "requestIds and isPadding must have the same size");

    // Collect the slots that are not owned by any non-padding request in this batch.
    std::unordered_set<SizeType32> availableSlots;
    availableSlots.reserve(mMaxNumSequences);
    for (SizeType32 i = 0; i < mMaxNumSequences; ++i)
    {
        availableSlots.insert(i);
    }
    for (size_t i = 0; i < requestIds.size(); ++i)
    {
        if (!isPadding[i])
        {
            availableSlots.erase(getCacheIndex(requestIds[i]));
        }
    }

    // Real requests keep their assigned slots; padding requests borrow unused ones.
    std::vector<SizeType32> result;
    result.reserve(requestIds.size());
    auto availableIt = availableSlots.begin();
    for (size_t i = 0; i < requestIds.size(); ++i)
    {
        if (isPadding[i])
        {
            TLLM_CHECK_WITH_INFO(availableIt != availableSlots.end(), "Ran out of available slots for padding");
            result.push_back(*availableIt);
            ++availableIt;
        }
        else
        {
            result.push_back(getCacheIndex(requestIds[i]));
        }
    }
    return result;
}

RnnStateManager::TensorPtr RnnStateManager::getConvStates(SizeType32 layerIdx) const
{
    auto it = mLayerOffsets.find(layerIdx);
    TLLM_CHECK_WITH_INFO(it != mLayerOffsets.end(), "Layer index not found in layer offsets");
    auto result = ITensor::slice(pagedConvStates, it->second, 1);
    result->squeeze(0);
    return result;
}

RnnStateManager::TensorPtr RnnStateManager::getSsmStates(SizeType32 layerIdx) const
{
    auto it = mLayerOffsets.find(layerIdx);
    TLLM_CHECK_WITH_INFO(it != mLayerOffsets.end(), "Layer index not found in layer offsets");
    auto result = ITensor::slice(pagedRnnStates, it->second, 1);
    result->squeeze(0);
    return result;
}

} // namespace tensorrt_llm::batch_manager::rnn_state_manager