[TRTLLM-10273][feat] Move MambaCacheManager from Python to C++ (#10540)

Signed-off-by: Iman Tabrizian <10105175+tabrizian@users.noreply.github.com>
Iman Tabrizian 2026-02-10 07:20:56 -08:00 committed by GitHub
parent d6e49542bd
commit 7d992972b2
18 changed files with 634 additions and 54 deletions


@ -16,11 +16,16 @@
#pragma once
#include "tensorrt_llm/batch_manager/common.h"
#include "tensorrt_llm/runtime/bufferManager.h"
#include "tensorrt_llm/runtime/iTensor.h"
#include "tensorrt_llm/runtime/modelConfig.h"
#include "tensorrt_llm/runtime/worldConfig.h"
#include <optional>
#include <unordered_map>
#include <vector>
namespace tensorrt_llm::batch_manager::rnn_state_manager
{
@ -30,16 +35,34 @@ public:
using TensorPtr = runtime::ITensor::SharedPtr;
using SizeType32 = tensorrt_llm::runtime::SizeType32;
using TensorMap = runtime::StringPtrMap<runtime::ITensor>;
using RequestIdType = tensorrt_llm::batch_manager::RequestIdType;
RnnStateManager(SizeType32 maxNumSequences, tensorrt_llm::runtime::ModelConfig const& modelConfig,
runtime::WorldConfig const& worldConfig, tensorrt_llm::runtime::BufferManager const& bufferManager);
RnnStateManager(SizeType32 dState, SizeType32 dConv, SizeType32 numHeads, SizeType32 nGroups, SizeType32 headDim,
SizeType32 maxBatchSize, runtime::WorldConfig const& worldConfig, int64_t stream, nvinfer1::DataType dtype,
nvinfer1::DataType ssmCacheDtype, std::vector<SizeType32> const& ppLayers);
void getPtrBuffers(TensorMap& inputBuffers, runtime::ModelConfig const& modelConfig,
runtime::WorldConfig const& worldConfig) const;
void fillSlotMapping(
runtime::ITensor& dstPointers, SizeType32 dstSlotOffset, SizeType32 seqSlotIdx, SizeType32 beamWidth) const;
void allocateCacheBlocks(std::vector<RequestIdType> const& requestIds);
void freeCacheBlock(RequestIdType requestId);
[[nodiscard]] SizeType32 getCacheIndex(RequestIdType requestId) const;
[[nodiscard]] std::vector<SizeType32> getStateIndices(
std::vector<RequestIdType> const& requestIds, std::vector<bool> const& isPadding);
[[nodiscard]] TensorPtr getConvStates(SizeType32 layerIdx) const;
[[nodiscard]] TensorPtr getSsmStates(SizeType32 layerIdx) const;
private:
// If we need to support beam search, we may need mMaxBeamWidth + 1 slots and use separate input / output states.
TensorPtr pagedRnnStates; // [local_nb_layers, max_seq_num * max_beam_width, state_size, rnn_hidden_size] or
@ -55,6 +78,10 @@ private:
SizeType32 mMaxNumSequences = 0;
SizeType32 mMaxBeamWidth = 0;
SizeType32 mBeamSlotsPerSequence = 0;
std::unordered_map<SizeType32, SizeType32> mLayerOffsets;
std::vector<SizeType32> mFreeBlocks;
std::unordered_map<RequestIdType, SizeType32> mCacheIndex;
std::optional<runtime::BufferManager> mBufferManager;
};
} // namespace tensorrt_llm::batch_manager::rnn_state_manager


@ -17,8 +17,11 @@
#include "tensorrt_llm/batch_manager/rnnStateManager.h"
#include "tensorrt_llm/common/assert.h"
#include "tensorrt_llm/runtime/cudaStream.h"
#include "tensorrt_llm/runtime/utils/runtimeUtils.h"
#include <unordered_set>
using namespace tensorrt_llm::runtime;
namespace tensorrt_llm::batch_manager::rnn_state_manager
@ -82,6 +85,64 @@ RnnStateManager::RnnStateManager(SizeType32 maxNumSequences, tensorrt_llm::runti
}
}
RnnStateManager::RnnStateManager(SizeType32 dState, SizeType32 dConv, SizeType32 numHeads, SizeType32 nGroups,
SizeType32 headDim, SizeType32 maxBatchSize, WorldConfig const& worldConfig, int64_t stream,
nvinfer1::DataType dtype, nvinfer1::DataType ssmCacheDtype, std::vector<SizeType32> const& ppLayers)
: mMaxNumSequences(maxBatchSize)
, mMaxBeamWidth{1}
, mBeamSlotsPerSequence{1}
, mBufferManager{std::make_shared<CudaStream>(reinterpret_cast<cudaStream_t>(stream))}
{
auto const tpSize = worldConfig.getTensorParallelism();
auto const dInner = headDim * numHeads;
auto convDim = dInner + 2 * nGroups * dState;
auto nheads = numHeads;
TLLM_CHECK_WITH_INFO(nheads % tpSize == 0, "nheads must be divisible by tp_size");
TLLM_CHECK_WITH_INFO(convDim % tpSize == 0, "conv_dim must be divisible by tp_size");
convDim = convDim / tpSize;
nheads = nheads / tpSize;
auto const numLocalLayers = static_cast<SizeType32>(ppLayers.size());
for (SizeType32 offset = 0; offset < numLocalLayers; ++offset)
{
mLayerOffsets[ppLayers[offset]] = offset;
}
auto const convStateShape = ITensor::makeShape({numLocalLayers, maxBatchSize, convDim, dConv - 1});
pagedConvStates = mBufferManager->gpu(convStateShape, dtype);
auto const rnnStateShape = ITensor::makeShape({numLocalLayers, maxBatchSize, nheads, headDim, dState});
pagedRnnStates = mBufferManager->gpu(rnnStateShape, ssmCacheDtype);
mFreeBlocks.reserve(maxBatchSize);
for (SizeType32 i = 0; i < maxBatchSize; ++i)
{
mFreeBlocks.push_back(i);
}
auto const statePtrsShape = ITensor::makeShape({numLocalLayers});
rnnStatePtrs = BufferManager::cpu(statePtrsShape, TRTDataType<void*>::value);
convStatePtrs = BufferManager::cpu(statePtrsShape, TRTDataType<void*>::value);
auto* rnnStatePtrArray = bufferCast<void*>(*rnnStatePtrs);
auto* convStatePtrArray = bufferCast<void*>(*convStatePtrs);
rnnStatePtr.resize(numLocalLayers);
convStatePtr.resize(numLocalLayers);
for (SizeType32 i = 0; i < numLocalLayers; i++)
{
auto layerRnnStates = ITensor::slice(pagedRnnStates, i, 1);
auto layerConvStates = ITensor::slice(pagedConvStates, i, 1);
rnnStatePtrArray[i] = layerRnnStates->data();
convStatePtrArray[i] = layerConvStates->data();
rnnStatePtr[i] = ITensor::slice(rnnStatePtrs, i, 1);
convStatePtr[i] = ITensor::slice(convStatePtrs, i, 1);
}
}
void RnnStateManager::getPtrBuffers(
TensorMap& inputBuffers, runtime::ModelConfig const& modelConfig, runtime::WorldConfig const& worldConfig) const
{
@ -113,4 +174,95 @@ void RnnStateManager::fillSlotMapping(
}
}
void RnnStateManager::allocateCacheBlocks(std::vector<RequestIdType> const& requestIds)
{
for (auto const& requestId : requestIds)
{
auto it = mCacheIndex.find(requestId);
if (it == mCacheIndex.end())
{
TLLM_CHECK_WITH_INFO(!mFreeBlocks.empty(), "Ran out of RNN state cache blocks");
SizeType32 const block = mFreeBlocks.back();
mFreeBlocks.pop_back();
mCacheIndex[requestId] = block;
}
}
}
void RnnStateManager::freeCacheBlock(RequestIdType requestId)
{
auto it = mCacheIndex.find(requestId);
if (it != mCacheIndex.end())
{
mFreeBlocks.push_back(it->second);
mCacheIndex.erase(it);
}
}
RnnStateManager::SizeType32 RnnStateManager::getCacheIndex(RequestIdType requestId) const
{
auto it = mCacheIndex.find(requestId);
TLLM_CHECK_WITH_INFO(it != mCacheIndex.end(), "Request ID not found in cache index");
return it->second;
}
std::vector<RnnStateManager::SizeType32> RnnStateManager::getStateIndices(
std::vector<RequestIdType> const& requestIds, std::vector<bool> const& isPadding)
{
TLLM_CHECK_WITH_INFO(requestIds.size() == isPadding.size(), "requestIds and isPadding must have the same size");
std::unordered_set<SizeType32> availableSlots;
availableSlots.reserve(mMaxNumSequences);
for (SizeType32 i = 0; i < mMaxNumSequences; ++i)
{
availableSlots.insert(i);
}
for (size_t i = 0; i < requestIds.size(); ++i)
{
if (!isPadding[i])
{
availableSlots.erase(getCacheIndex(requestIds[i]));
}
}
std::vector<SizeType32> result;
result.reserve(requestIds.size());
auto availableIt = availableSlots.begin();
for (size_t i = 0; i < requestIds.size(); ++i)
{
if (isPadding[i])
{
TLLM_CHECK_WITH_INFO(availableIt != availableSlots.end(), "Ran out of available slots for padding");
result.push_back(*availableIt);
++availableIt;
}
else
{
result.push_back(getCacheIndex(requestIds[i]));
}
}
return result;
}
RnnStateManager::TensorPtr RnnStateManager::getConvStates(SizeType32 layerIdx) const
{
auto it = mLayerOffsets.find(layerIdx);
TLLM_CHECK_WITH_INFO(it != mLayerOffsets.end(), "Layer index not found in layer offsets");
auto result = ITensor::slice(pagedConvStates, it->second, 1);
result->squeeze(0);
return result;
}
RnnStateManager::TensorPtr RnnStateManager::getSsmStates(SizeType32 layerIdx) const
{
auto it = mLayerOffsets.find(layerIdx);
TLLM_CHECK_WITH_INFO(it != mLayerOffsets.end(), "Layer index not found in layer offsets");
auto result = ITensor::slice(pagedRnnStates, it->second, 1);
result->squeeze(0);
return result;
}
} // namespace tensorrt_llm::batch_manager::rnn_state_manager
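
For readers tracking the Python-side call in Mamba2Metadata.prepare further down, the slot bookkeeping implemented by allocateCacheBlocks/getStateIndices above boils down to the following standalone Python sketch; it illustrates the same logic and is not code from this commit.

def get_state_indices_sketch(cache_index, max_num_sequences, request_ids, is_padding):
    # A real request returns the slot recorded for it when its block was allocated;
    # a padding request (e.g. a CUDA graph dummy) borrows any slot that no real
    # request in this batch is currently using.
    assert len(request_ids) == len(is_padding)
    used = {cache_index[rid] for rid, pad in zip(request_ids, is_padding) if not pad}
    free = iter(slot for slot in range(max_num_sequences) if slot not in used)
    return [next(free) if pad else cache_index[rid]
            for rid, pad in zip(request_ids, is_padding)]

The C++ version additionally checks that padding never exhausts the free slots and that every non-padding request id was allocated beforehand.
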


@ -416,7 +416,36 @@ void initBindings(nb::module_& m)
nb::class_<tb::rnn_state_manager::RnnStateManager>(m, "RnnStateManager")
.def(nb::init<tr::SizeType32, tr::ModelConfig, tr::WorldConfig, tr::BufferManager>(),
nb::arg("max_num_sequences"), nb::arg("model_config"), nb::arg("world_config"), nb::arg("buffer_manager"));
nb::arg("max_num_sequences"), nb::arg("model_config"), nb::arg("world_config"), nb::arg("buffer_manager"),
nb::call_guard<nb::gil_scoped_release>())
.def(nb::init<tr::SizeType32, tr::SizeType32, tr::SizeType32, tr::SizeType32, tr::SizeType32, tr::SizeType32,
tr::WorldConfig const&, int64_t, nvinfer1::DataType, nvinfer1::DataType,
std::vector<tr::SizeType32> const&>(),
nb::arg("d_state"), nb::arg("d_conv"), nb::arg("num_heads"), nb::arg("n_groups"), nb::arg("head_dim"),
nb::arg("max_batch_size"), nb::arg("world_config"), nb::arg("stream"), nb::arg("dtype"),
nb::arg("ssm_cache_dtype"), nb::arg("pp_layers"), nb::call_guard<nb::gil_scoped_release>())
.def(
"get_conv_states",
[](tb::rnn_state_manager::RnnStateManager& self, tr::SizeType32 layerIdx) -> at::Tensor
{
auto tensor = self.getConvStates(layerIdx);
return tr::Torch::tensor(tensor);
},
nb::arg("layer_idx"), nb::call_guard<nb::gil_scoped_release>())
.def(
"get_ssm_states",
[](tb::rnn_state_manager::RnnStateManager& self, tr::SizeType32 layerIdx) -> at::Tensor
{
auto tensor = self.getSsmStates(layerIdx);
return tr::Torch::tensor(tensor);
},
nb::arg("layer_idx"), nb::call_guard<nb::gil_scoped_release>())
.def("allocate_cache_blocks", &tb::rnn_state_manager::RnnStateManager::allocateCacheBlocks,
nb::arg("request_ids"), nb::call_guard<nb::gil_scoped_release>())
.def("free_cache_block", &tb::rnn_state_manager::RnnStateManager::freeCacheBlock, nb::arg("request_id"),
nb::call_guard<nb::gil_scoped_release>())
.def("get_state_indices", &tb::rnn_state_manager::RnnStateManager::getStateIndices, nb::arg("request_ids"),
nb::arg("is_padding"), nb::call_guard<nb::gil_scoped_release>());
m.def(
"add_new_tokens_to_requests",


@ -204,6 +204,7 @@ class FlashInferAttentionMetadata(AttentionMetadata):
return self.kv_cache_manager.tokens_per_block
def prepare(self) -> None:
super().prepare()
extra_attrs = get_model_extra_attrs()
if extra_attrs is None:
get_global_attrs().attention_metadata = weakref.ref(self)


@ -3,7 +3,7 @@ import weakref
from collections import namedtuple
from dataclasses import dataclass, field
from enum import Enum, IntEnum
from typing import (TYPE_CHECKING, Dict, Generic, List, Optional, Protocol,
from typing import (TYPE_CHECKING, Any, Dict, Generic, List, Optional, Protocol,
Tuple, Type, TypeVar, Union)
import torch
@ -21,6 +21,7 @@ from tensorrt_llm.models.modeling_utils import QuantConfig
from ..memory_buffer_utils import Buffers
from ..metadata import KVCacheParams
from ..pyexecutor.mamba_cache_manager import MambaCacheManager
from ..pyexecutor.resource_manager import KVCacheManager, KVCacheManagerV2
from ..utils import get_model_extra_attrs
@ -147,6 +148,9 @@ class AttentionMetadata:
_num_ctx_tokens: int = field(init=False, default=0, repr=False)
_num_tokens: int = field(init=False, default=0, repr=False)
mamba_metadata: Optional[Any] = None
mamba_chunk_size: int = 128
# The number of tokens in the padded sequence.
padded_num_tokens: Optional[int] = None
@ -290,6 +294,23 @@ class AttentionMetadata:
"""
Hook to be called before the forward step of the model.
"""
self._prepare_mamba_metadata()
def _prepare_mamba_metadata(self):
if self.mamba_metadata is False:
return
if self.mamba_metadata is None:
if (self.kv_cache_manager is not None
and isinstance(self.kv_cache_manager, MambaCacheManager)):
from ..modules.mamba.mamba2_metadata import Mamba2Metadata
self.mamba_metadata = Mamba2Metadata(self.max_num_requests,
self.mamba_chunk_size)
else:
self.mamba_metadata = False
return
self.mamba_metadata.prepare(self)
def create_cuda_graph_metadata(self,
max_batch_size: int,


@ -933,6 +933,7 @@ class TrtllmAttentionMetadata(AttentionMetadata):
self.helix_is_inactive_rank_cpu[:batch_size], non_blocking=True)
def prepare(self) -> None:
super().prepare()
extra_attrs = get_model_extra_attrs()
# If model extra attrs is set, attention_metadata is setup in executor.
if extra_attrs is None:


@ -59,6 +59,7 @@ def generate_sliding_window_mask(batch_size: int, target_length: int,
class VanillaAttentionMetadata(AttentionMetadata):
def prepare(self) -> None:
super().prepare()
# indices of used cache blocks for each sequence
assert self.request_ids is not None
self.block_ids_per_seq = self.kv_cache_manager.get_batch_cache_indices(


@ -22,7 +22,6 @@ from transformers import AutoConfig, PretrainedConfig
from tensorrt_llm._torch.models.checkpoints.base_weight_mapper import \
BaseWeightMapper
from tensorrt_llm._torch.modules.mamba.mamba2_metadata import Mamba2Metadata
from tensorrt_llm._torch.utils import ActivationType, relu2
from ..attention_backend import AttentionMetadata
@ -424,8 +423,6 @@ class NemotronHModel(DecoderModel):
dtype=config.torch_dtype,
)
self.mamba_metadata: Optional[Mamba2Metadata] = None
def forward(
self,
attn_metadata: AttentionMetadata,
@ -440,11 +437,7 @@ class NemotronHModel(DecoderModel):
"You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one"
)
if self.mamba_metadata is None or self.mamba_metadata.max_batch_size != attn_metadata.max_num_requests:
self.mamba_metadata = Mamba2Metadata(
attn_metadata.max_num_requests,
chunk_size=self.model_config.pretrained_config.chunk_size)
self.mamba_metadata.prepare(attn_metadata)
mamba_metadata = attn_metadata.mamba_metadata
if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids)
@ -456,7 +449,7 @@ class NemotronHModel(DecoderModel):
hidden_states,
attn_metadata,
spec_metadata=spec_metadata,
mamba_metadata=self.mamba_metadata)
mamba_metadata=mamba_metadata)
hidden_states = self.norm_f(hidden_states)


@ -31,6 +31,8 @@ from tensorrt_llm._torch.modules.fla.chunk import chunk_gated_delta_rule
from tensorrt_llm._torch.modules.fla.fused_sigmoid_gating_recurrent import \
fused_sigmoid_gating_delta_rule_update
from tensorrt_llm._torch.modules.mamba.mamba2_metadata import Mamba2Metadata
from tensorrt_llm._torch.pyexecutor.mamba_cache_manager import \
use_cpp_mamba_cache_manager
from tensorrt_llm.mapping import Mapping
from ..attention_backend import AttentionMetadata
@ -754,8 +756,12 @@ class Qwen3NextGatedDeltaNet(nn.Module):
batch_split_size = [num_prefills, num_decodes]
has_initial_states = mamba_metadata.has_initial_states
state_indices = attn_metadata.kv_cache_manager.get_state_indices(
)[:num_prefills + num_decodes]
batch_size = num_prefills + num_decodes
if use_cpp_mamba_cache_manager():
state_indices = mamba_metadata.state_indices[:batch_size]
else:
state_indices = attn_metadata.kv_cache_manager.get_state_indices(
)[:batch_size]
state_indices_p, state_indices_d = torch.split(state_indices,
batch_split_size)
@ -1197,8 +1203,6 @@ class Qwen3NextModel(DecoderModel):
use_gemma=True,
)
self.mamba_metadata: Optional[Mamba2Metadata] = None
def forward(
self,
attn_metadata: AttentionMetadata,
@ -1213,13 +1217,10 @@ class Qwen3NextModel(DecoderModel):
"You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one"
)
if self.mamba_metadata is None or self.mamba_metadata.max_batch_size != attn_metadata.max_num_requests:
self.mamba_metadata = Mamba2Metadata(
attn_metadata.max_num_requests,
# chunk_size=self.model_config.pretrained_config.mamba2_chunk_size)
# TODO check how to get the correct chunk_size
chunk_size=128)
self.mamba_metadata.prepare(attn_metadata)
mamba_metadata = attn_metadata.mamba_metadata
if mamba_metadata.max_batch_size != attn_metadata.max_num_requests:
attn_metadata.mamba_metadata = Mamba2Metadata(
attn_metadata.max_num_requests, chunk_size=128)
if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids)
@ -1233,7 +1234,7 @@ class Qwen3NextModel(DecoderModel):
attn_metadata=attn_metadata,
residual=residual,
spec_metadata=spec_metadata,
mamba_metadata=self.mamba_metadata)
mamba_metadata=mamba_metadata)
return hidden_states


@ -19,6 +19,10 @@ from typing import Tuple
import torch
from tensorrt_llm._torch.attention_backend.interface import AttentionMetadata
from tensorrt_llm._torch.pyexecutor.cuda_graph_runner import \
CUDA_GRAPH_DUMMY_REQUEST_ID
from tensorrt_llm._torch.pyexecutor.mamba_cache_manager import \
use_cpp_mamba_cache_manager
def cu_seqlens_to_chunk_indices_offsets(
@ -107,11 +111,47 @@ class Mamba2Metadata:
self.chunk_indices: torch.Tensor = None
self.chunk_offsets: torch.Tensor = None
self.state_indices_cpu = torch.zeros(max_batch_size,
dtype=torch.int32,
pin_memory=True)
self.state_indices = torch.zeros(max_batch_size,
dtype=torch.int32,
device="cuda")
self._query_start_loc_long_buf = torch.arange(0,
max_batch_size + 1,
dtype=torch.long,
device="cuda")
self._query_start_loc_buf = torch.zeros(max_batch_size + 1,
dtype=torch.int,
device="cuda")
self.query_start_loc_long = self._query_start_loc_long_buf
self.query_start_loc = self._query_start_loc_buf
def prepare(self, attn_metadata: AttentionMetadata):
batch_size = attn_metadata.seq_lens.shape[0]
num_contexts = attn_metadata.num_contexts
context_lens = attn_metadata.seq_lens_cuda[:num_contexts]
num_ctx_tokens = attn_metadata.num_ctx_tokens
kv_cache_manager = attn_metadata.kv_cache_manager
request_ids = attn_metadata.request_ids
if (kv_cache_manager is not None
and hasattr(kv_cache_manager, 'get_state_indices')
and request_ids is not None):
if use_cpp_mamba_cache_manager():
batch_request_ids = request_ids[:batch_size]
is_padding = [
req_id == CUDA_GRAPH_DUMMY_REQUEST_ID
for req_id in batch_request_ids
]
indices = kv_cache_manager.get_state_indices(
batch_request_ids, is_padding)
for i, idx in enumerate(indices):
self.state_indices_cpu[i] = idx
self.state_indices[:batch_size].copy_(
self.state_indices_cpu[:batch_size], non_blocking=True)
if num_contexts > 0:
torch.cumsum(context_lens,
dim=0,
@ -125,8 +165,14 @@ class Mamba2Metadata:
out=self.cu_seqlens[num_contexts + 1:batch_size + 1])
# Need both `query_start_loc` and `query_start_loc_long` because `causal_conv1d_fn`
# accepts only `int32` while `chunk_gated_delta_rule` accepts only `long`.
self.query_start_loc = self.cu_seqlens[:batch_size + 1]
self.query_start_loc_long = self.query_start_loc.to(torch.long)
self._query_start_loc_buf[:batch_size +
1] = self.cu_seqlens[:batch_size + 1]
self.query_start_loc = self._query_start_loc_buf[:batch_size + 1]
self._query_start_loc_long_buf[:batch_size + 1].copy_(
self.query_start_loc.to(torch.long), non_blocking=True)
self.query_start_loc_long = self._query_start_loc_long_buf[:
batch_size
+ 1]
self.seq_idx = torch.repeat_interleave(
torch.arange(num_contexts,
dtype=torch.int,
@ -148,8 +194,11 @@ class Mamba2Metadata:
self.chunk_offsets = None
else:
self.query_start_loc = None
self.query_start_loc_long = torch.arange(
0,
batch_size + 1,
dtype=torch.long,
device=self.cu_seqlens.device)
torch.arange(0,
batch_size + 1,
dtype=torch.long,
device=self.cu_seqlens.device,
out=self._query_start_loc_long_buf[:batch_size + 1])
self.query_start_loc_long = self._query_start_loc_long_buf[:
batch_size
+ 1]
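
The reason query_start_loc and query_start_loc_long now write into the preallocated _query_start_loc_buf/_query_start_loc_long_buf slices instead of creating fresh tensors each step is to keep the device addresses stable across iterations, which CUDA graph replay requires. A minimal standalone sketch of that pattern (buffer size and names here are illustrative):

import torch

MAX_BATCH_SIZE = 64
# Allocated once; graph capture records this storage address.
_query_start_loc_long_buf = torch.zeros(MAX_BATCH_SIZE + 1, dtype=torch.long, device="cuda")

def prepare_query_start_loc(cu_seqlens: torch.Tensor, batch_size: int) -> torch.Tensor:
    # Copy into a slice of the persistent buffer instead of materializing a new tensor,
    # so replays of a captured graph see up-to-date values at the same address.
    _query_start_loc_long_buf[:batch_size + 1].copy_(
        cu_seqlens[:batch_size + 1].to(torch.long), non_blocking=True)
    return _query_start_loc_long_buf[:batch_size + 1]
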


@ -21,6 +21,8 @@ from flashinfer.mamba import selective_state_update as selective_state_update_fi
from torch import nn
from tensorrt_llm._torch.modules.mamba.mamba2_metadata import Mamba2Metadata
from tensorrt_llm._torch.pyexecutor.mamba_cache_manager import \
use_cpp_mamba_cache_manager
from tensorrt_llm.logger import logger
from tensorrt_llm.mapping import Mapping
@ -204,15 +206,24 @@ class Mamba2Mixer(nn.Module):
seqlen_split_size = [num_prefill_tokens, num_decode_tokens]
batch_split_size = [num_prefills, num_decodes]
state_indices = attn_metadata.kv_cache_manager.get_state_indices(
)[:num_prefills + num_decodes]
if use_cpp_mamba_cache_manager():
state_indices = mamba_metadata.state_indices[:num_prefills +
num_decodes]
conv_states = attn_metadata.kv_cache_manager.get_conv_states(
self.layer_idx)
ssm_states = attn_metadata.kv_cache_manager.get_ssm_states(
self.layer_idx)
layer_cache = None # Not used in C++ path
else:
state_indices = attn_metadata.kv_cache_manager.get_state_indices(
)[:num_prefills + num_decodes]
layer_cache = attn_metadata.kv_cache_manager.mamba_layer_cache(
self.layer_idx)
conv_states = layer_cache.conv
ssm_states = layer_cache.temporal
state_indices_p, state_indices_d = torch.split(state_indices,
batch_split_size)
layer_cache = attn_metadata.kv_cache_manager.mamba_layer_cache(
self.layer_idx)
conv_states = layer_cache.conv
ssm_states = layer_cache.temporal
# in_proj
zxbcdt = self.in_proj(hidden_states)
@ -310,6 +321,9 @@ class Mamba2Mixer(nn.Module):
is_target_verify = attn_metadata.kv_cache_manager.is_speculative(
) and spec_metadata is not None
if is_target_verify:
# Speculative decoding is only supported on the Python path
assert layer_cache is not None, \
"Speculative decoding requires Python MambaCacheManager"
# TODO: support dynamic speculation, will add current_draft_len later [TRTLLM-10319]
draft_token_num = spec_metadata.max_draft_len + 1
intermediate_conv_states = layer_cache.intermediate_conv_window


@ -19,7 +19,7 @@ from ..speculative.mtp import SampleStateTensorsMTP
from ..speculative.utils import get_draft_kv_cache_manager
from ..utils import make_weak_ref, piecewise_cuda_graph
from .llm_request import get_draft_token_length
from .mamba_cache_manager import MambaCacheManager
from .mamba_cache_manager import MambaCacheManager, use_cpp_mamba_cache_manager
from .resource_manager import (BaseResourceManager, ResourceManager,
ResourceManagerType)
from .sampler import SampleStateTensors
@ -460,8 +460,8 @@ class CUDAGraphRunner:
if spec_res_mgr:
spec_res_mgr.add_dummy_requests([CUDA_GRAPH_DUMMY_REQUEST_ID])
# handle special cases of padding requests + MambaCacheManager or MambaHybridCacheManager
if isinstance(kv_cache_manager, MambaCacheManager):
if (isinstance(kv_cache_manager, MambaCacheManager)
and not use_cpp_mamba_cache_manager()):
kv_cache_manager.reorder_state_indices_when_padding_requests(
batch_size, padding_size)


@ -13,22 +13,28 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from dataclasses import dataclass
from typing import TYPE_CHECKING, Dict, List, Optional, Union
import torch
from tensorrt_llm._torch.pyexecutor.llm_request import LlmRequest
from tensorrt_llm._torch.pyexecutor.resource_manager import (
BaseResourceManager, CacheTypeCpp, DataType, KVCacheManager, get_pp_layers)
from tensorrt_llm._torch.pyexecutor.scheduler import ScheduledRequests
from tensorrt_llm.llmapi.llm_args import KvCacheConfig
from tensorrt_llm.logger import logger
from tensorrt_llm.mapping import Mapping
import tensorrt_llm.bindings
if TYPE_CHECKING:
from tensorrt_llm._torch.attention_backend.interface import \
AttentionMetadata
from tensorrt_llm._torch.pyexecutor.llm_request import LlmRequest
from tensorrt_llm._torch.pyexecutor.resource_manager import (
BaseResourceManager, CacheTypeCpp, DataType, KVCacheManager, get_pp_layers)
from tensorrt_llm._torch.pyexecutor.scheduler import ScheduledRequests
from tensorrt_llm._utils import torch_dtype_to_binding
from tensorrt_llm.llmapi.llm_args import KvCacheConfig
from tensorrt_llm.logger import logger
from tensorrt_llm.mapping import Mapping
RnnStateManagerCpp = tensorrt_llm.bindings.internal.batch_manager.RnnStateManager
WorldConfig = tensorrt_llm.bindings.WorldConfig
GB = 1 << 30
@ -42,7 +48,117 @@ def get_tensor_size_bytes(tensor):
return 0
class MambaCacheManager(BaseResourceManager):
def use_cpp_mamba_cache_manager() -> bool:
"""Check if C++ MambaCacheManager should be used.
Returns True if TRTLLM_USE_CPP_MAMBA='1' is set, False otherwise.
By default, PythonMambaCacheManager is used.
"""
return os.environ.get('TRTLLM_USE_CPP_MAMBA', '0') == '1'
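
Because the flag is read from the environment at call time, it can be flipped per process before the cache manager is built; a small illustrative snippet (the new accuracy tests below do the same via monkeypatch.setenv):

import os
os.environ["TRTLLM_USE_CPP_MAMBA"] = "1"   # opt in to the C++-backed RnnStateManager path
assert use_cpp_mamba_cache_manager()
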
class CppMambaCacheManager(BaseResourceManager):
"""C++ backed Mamba cache manager using RnnStateManager bindings."""
def __init__(
self,
d_state: int,
d_conv: int,
num_heads: int,
n_groups: int,
head_dim: int,
num_layers: int,
max_num_sequences: int,
mapping: Mapping,
dtype: torch.dtype,
ssm_cache_dtype: torch.dtype,
layer_mask: Optional[List[bool]] = None,
stream: Optional[torch.cuda.Stream] = None,
) -> None:
self.mamba_ssm_cache_dtype = ssm_cache_dtype
# get tp size
tp_size = mapping.tp_size if not mapping.enable_attention_dp else 1
world_config = WorldConfig(
tensor_parallelism=tp_size,
pipeline_parallelism=mapping.pp_size,
rank=mapping.rank,
gpus_per_node=mapping.gpus_per_node,
)
dtype_binding = torch_dtype_to_binding(dtype)
ssm_cache_dtype_binding = torch_dtype_to_binding(
ssm_cache_dtype if ssm_cache_dtype is not None else dtype)
self._stream = stream if stream is not None else torch.cuda.current_stream(
)
pp_layers, _ = get_pp_layers(num_layers, mapping, layer_mask=layer_mask)
self.mamba_impl = RnnStateManagerCpp(
d_state=d_state,
d_conv=d_conv,
num_heads=num_heads,
n_groups=n_groups,
head_dim=head_dim,
max_batch_size=max_num_sequences,
world_config=world_config,
stream=self._stream.cuda_stream,
dtype=dtype_binding,
ssm_cache_dtype=ssm_cache_dtype_binding,
pp_layers=pp_layers,
)
self._max_num_sequences = max_num_sequences
def get_max_resource_count(self) -> int:
# Return the maximum number of sequences that can be cached.
return self._max_num_sequences
def get_needed_resource_to_completion(self, request: LlmRequest) -> int:
# For Mamba cache manager, we always need one slot per request.
return 1
def prepare_resources(self, scheduled_batch: ScheduledRequests):
context_ids = [
i.py_request_id for i in scheduled_batch.context_requests
]
generation_ids = [
i.py_request_id for i in scheduled_batch.generation_requests
]
request_ids = context_ids + generation_ids
self.mamba_impl.allocate_cache_blocks(request_ids)
def free_resources(self, request: LlmRequest):
self.mamba_impl.free_cache_block(request.py_request_id)
def add_dummy_requests(self, request_ids: List[int], **kwargs):
# CUDA graph dummy requests do not get a dedicated block here;
# get_state_indices assigns them an unused slot on the fly.
from .cuda_graph_runner import CUDA_GRAPH_DUMMY_REQUEST_ID
request_ids = [
rid for rid in request_ids if rid != CUDA_GRAPH_DUMMY_REQUEST_ID
]
if request_ids:
self.mamba_impl.allocate_cache_blocks(request_ids)
def get_state_indices(self, request_ids: List[int],
is_padding: List[bool]) -> List[int]:
return self.mamba_impl.get_state_indices(request_ids, is_padding)
def get_conv_states(self, layer_idx: int) -> torch.Tensor:
return self.mamba_impl.get_conv_states(layer_idx)
def get_ssm_states(self, layer_idx: int) -> torch.Tensor:
return self.mamba_impl.get_ssm_states(layer_idx)
def get_mamba_ssm_cache_dtype(self) -> torch.dtype:
return self.mamba_ssm_cache_dtype
def shutdown(self):
torch.cuda.empty_cache()
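
The CUDA-graph padding flow ties the pieces above together: add_dummy_requests deliberately skips CUDA_GRAPH_DUMMY_REQUEST_ID, and Mamba2Metadata.prepare (earlier in this diff) later flags those batch positions as padding so get_state_indices hands them an unused slot. A rough sketch with made-up request ids, assuming `cache_mgr` is a constructed CppMambaCacheManager:

from tensorrt_llm._torch.pyexecutor.cuda_graph_runner import CUDA_GRAPH_DUMMY_REQUEST_ID

# Warm-up: real dummy ids get dedicated blocks, the CUDA graph sentinel does not.
cache_mgr.add_dummy_requests([101, CUDA_GRAPH_DUMMY_REQUEST_ID])

# Per-step metadata preparation: the sentinel position is marked as padding and
# borrows whichever slot no real request in the batch is using.
request_ids = [101, CUDA_GRAPH_DUMMY_REQUEST_ID]
is_padding = [rid == CUDA_GRAPH_DUMMY_REQUEST_ID for rid in request_ids]
state_indices = cache_mgr.get_state_indices(request_ids, is_padding)
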
class PythonMambaCacheManager(BaseResourceManager):
@dataclass(frozen=True, kw_only=True)
class State:
@ -193,6 +309,17 @@ class MambaCacheManager(BaseResourceManager):
dtype=torch.int32,
device=device)
# Store max_batch_size for resource management
self._max_batch_size = max_batch_size
def get_max_resource_count(self) -> int:
"""Return the maximum number of sequences that can be cached."""
return self._max_batch_size
def get_needed_resource_to_completion(self, request: LlmRequest) -> int:
"""For Mamba cache manager, we always need one slot per request."""
return 1
@torch.inference_mode()
def _prepare_mamba_cache_blocks(self, request_ids: List[int]):
self.state_indices_list.clear()
@ -261,7 +388,9 @@ class MambaCacheManager(BaseResourceManager):
block = self.mamba_cache_index.pop(request_id)
self.mamba_cache_free_blocks.append(block)
def get_state_indices(self) -> torch.Tensor:
def get_state_indices(self,
request_ids: List[int] = None,
is_padding: List[bool] = None) -> torch.Tensor:
return self.state_indices
def get_conv_states(self, layer_idx: int) -> torch.Tensor:
@ -296,6 +425,9 @@ class MambaCacheManager(BaseResourceManager):
layer_offset = self.mamba_layer_offsets[layer_idx]
return self.mamba_cache.at_layer_idx(layer_offset)
def get_mamba_ssm_cache_dtype(self) -> torch.dtype:
return self.mamba_ssm_cache_dtype
def shutdown(self):
"""Release tensor memory."""
# Clear state indices
@ -346,6 +478,140 @@ class MambaCacheManager(BaseResourceManager):
conv_states[:, state_indices_d, :] = accepted_conv_state
class MambaCacheManager(BaseResourceManager):
def __init__(
self,
d_state: int,
d_conv: int,
num_heads: int,
n_groups: int,
head_dim: int,
num_layers: int,
max_batch_size: int,
spec_state_size: int,
mapping: Mapping,
dtype: torch.dtype,
ssm_cache_dtype: torch.dtype,
layer_mask: Optional[List[bool]] = None,
stream: Optional[torch.cuda.Stream] = None,
speculative_num_draft_tokens: Optional[int] = None,
) -> None:
max_num_sequences = max_batch_size * mapping.pp_size
self._use_cpp = use_cpp_mamba_cache_manager()
if self._use_cpp:
assert speculative_num_draft_tokens is None, \
"speculative_num_draft_tokens is not supported in CppMambaCacheManager"
self._impl = CppMambaCacheManager(
d_state=d_state,
d_conv=d_conv,
num_heads=num_heads,
n_groups=n_groups,
head_dim=head_dim,
num_layers=num_layers,
max_num_sequences=max_num_sequences,
mapping=mapping,
dtype=dtype,
ssm_cache_dtype=ssm_cache_dtype,
layer_mask=layer_mask,
stream=stream,
)
else:
self._impl = PythonMambaCacheManager(
d_state=d_state,
d_conv=d_conv,
num_heads=num_heads,
n_groups=n_groups,
head_dim=head_dim,
num_layers=num_layers,
max_batch_size=max_batch_size,
spec_state_size=spec_state_size,
mapping=mapping,
dtype=dtype,
ssm_cache_dtype=ssm_cache_dtype,
layer_mask=layer_mask,
speculative_num_draft_tokens=speculative_num_draft_tokens,
)
def get_max_resource_count(self) -> int:
return self._impl.get_max_resource_count()
def get_needed_resource_to_completion(self, request: LlmRequest) -> int:
return self._impl.get_needed_resource_to_completion(request)
def prepare_resources(self, scheduled_batch: ScheduledRequests):
self._impl.prepare_resources(scheduled_batch)
def free_resources(self, request: LlmRequest):
self._impl.free_resources(request)
def add_dummy_requests(self, request_ids: List[int], **kwargs):
if self._use_cpp:
self._impl.add_dummy_requests(request_ids, **kwargs)
def get_state_indices(
self,
request_ids: Optional[List[int]] = None,
is_padding: Optional[List[bool]] = None
) -> Union[torch.Tensor, List[int]]:
return self._impl.get_state_indices(request_ids, is_padding)
def reorder_state_indices_when_padding_requests(self, request_size: int,
padding_size: int):
assert not self._use_cpp, "reorder_state_indices_when_padding_requests is not supported in CppMambaCacheManager"
self._impl.reorder_state_indices_when_padding_requests(
request_size, padding_size)
@property
def mamba_cache_free_blocks(self) -> List[int]:
assert not self._use_cpp, "mamba_cache_free_blocks is not supported in CppMambaCacheManager"
return self._impl.mamba_cache_free_blocks
@property
def mamba_cache_index(self) -> Dict[int, int]:
assert not self._use_cpp, "mamba_cache_index is not supported in CppMambaCacheManager"
return self._impl.mamba_cache_index
def get_conv_states(self, layer_idx: int) -> torch.Tensor:
return self._impl.get_conv_states(layer_idx)
def get_ssm_states(self, layer_idx: int) -> torch.Tensor:
return self._impl.get_ssm_states(layer_idx)
def get_mamba_ssm_cache_dtype(self) -> torch.dtype:
return self._impl.get_mamba_ssm_cache_dtype()
def get_intermediate_ssm_states(self,
layer_idx: int) -> Optional[torch.Tensor]:
assert not self._use_cpp, "get_intermediate_ssm_states is not supported in CppMambaCacheManager"
return self._impl.get_intermediate_ssm_states(layer_idx)
def get_intermediate_conv_states(self,
layer_idx: int) -> Optional[torch.Tensor]:
assert not self._use_cpp, "get_intermediate_conv_states is not supported in CppMambaCacheManager"
return self._impl.get_intermediate_conv_states(layer_idx)
def is_speculative(self) -> bool:
assert not self._use_cpp, "is_speculative is not supported in CppMambaCacheManager"
return self._impl.is_speculative()
def mamba_layer_cache(
self, layer_idx: int
) -> Union[PythonMambaCacheManager.State,
PythonMambaCacheManager.SpeculativeState, None]:
assert not self._use_cpp, "mamba_layer_cache is not supported in CppMambaCacheManager"
return self._impl.mamba_layer_cache(layer_idx)
def shutdown(self):
self._impl.shutdown()
def update_mamba_states(self, attn_metadata: "AttentionMetadata",
num_accepted_tokens: torch.Tensor):
assert not self._use_cpp, "update_mamba_states is not supported in CppMambaCacheManager"
self._impl.update_mamba_states(attn_metadata, num_accepted_tokens)
class MambaHybridCacheManager(KVCacheManager, MambaCacheManager):
def __init__(
@ -399,6 +665,7 @@ class MambaHybridCacheManager(KVCacheManager, MambaCacheManager):
mamba_cache_dtype,
mamba_ssm_cache_dtype,
mamba_layer_mask,
execution_stream,
speculative_num_draft_tokens=spec_config.max_draft_len
if spec_config is not None else None,
)
@ -430,6 +697,10 @@ class MambaHybridCacheManager(KVCacheManager, MambaCacheManager):
MambaCacheManager.free_resources(self, request)
KVCacheManager.free_resources(self, request)
def add_dummy_requests(self, request_ids: List[int], **kwargs):
MambaCacheManager.add_dummy_requests(self, request_ids)
return KVCacheManager.add_dummy_requests(self, request_ids, **kwargs)
def shutdown(self):
MambaCacheManager.shutdown(self)
KVCacheManager.shutdown(self)


@ -2777,6 +2777,8 @@ class PyTorchModelEngine(ModelEngine):
num_extra_kv_tokens=get_num_extra_kv_tokens(spec_config))
attn_metadata.kv_cache_manager = kv_cache_manager
if hasattr(self.model.model_config.pretrained_config, 'chunk_size'):
attn_metadata.mamba_chunk_size = self.model.model_config.pretrained_config.chunk_size
attn_metadata.prepare()
peft_cache_manager = resource_manager and resource_manager.get_resource_manager(


@ -884,7 +884,7 @@ class Runner:
else:
raise NotImplementedError("Unsupported config")
kv_cache_manager.add_dummy_requests(
list(range(max_batch_size)), [max_seq_len] * max_batch_size
list(range(max_batch_size)), token_nums=[max_seq_len] * max_batch_size
)
return kv_cache_manager


@ -5690,6 +5690,17 @@ class TestNemotronV3Super(LlmapiAccuracyTestHarness):
@skip_pre_hopper
@pytest.mark.skip_less_mpi_world_size(4)
@pytest.mark.skip_less_device_memory(40000)
@pytest.mark.parametrize(
"use_cpp_mamba",
[
False,
True,
],
ids=[
"python_mamba_cache",
"cpp_mamba_cache",
],
)
@pytest.mark.parametrize(
"attention_dp",
[
@ -5701,7 +5712,10 @@ class TestNemotronV3Super(LlmapiAccuracyTestHarness):
"attention_dp_on",
],
)
def test_fp8_4gpus(self, attention_dp):
def test_fp8_4gpus(self, attention_dp, use_cpp_mamba, monkeypatch):
monkeypatch.setenv("TRTLLM_USE_CPP_MAMBA",
"1" if use_cpp_mamba else "0")
with LLM(
f"{llm_models_root()}/Nemotron-SuperV3-phase1-mtp-fp8-fp8kv",
kv_cache_config=KvCacheConfig(


@ -264,8 +264,10 @@ accuracy/test_llm_api_pytorch.py::TestNemotronV3Super::test_auto_dtype_4gpus[4-1
accuracy/test_llm_api_pytorch.py::TestNemotronV3Super::test_auto_dtype_4gpus[4-1-False-False-True]
accuracy/test_llm_api_pytorch.py::TestNemotronV3Super::test_auto_dtype_4gpus[4-4-True-False-True]
accuracy/test_llm_api_pytorch.py::TestNemotronV3Super::test_auto_dtype_4gpus[4-4-True-True-True]
accuracy/test_llm_api_pytorch.py::TestNemotronV3Super::test_fp8_4gpus[attention_dp_on]
accuracy/test_llm_api_pytorch.py::TestNemotronV3Super::test_fp8_4gpus[attention_dp_off]
accuracy/test_llm_api_pytorch.py::TestNemotronV3Super::test_fp8_4gpus[attention_dp_off-python_mamba_cache]
accuracy/test_llm_api_pytorch.py::TestNemotronV3Super::test_fp8_4gpus[attention_dp_off-cpp_mamba_cache]
accuracy/test_llm_api_pytorch.py::TestNemotronV3Super::test_fp8_4gpus[attention_dp_on-python_mamba_cache]
accuracy/test_llm_api_pytorch.py::TestNemotronV3Super::test_fp8_4gpus[attention_dp_on-cpp_mamba_cache]
accuracy/test_llm_api_pytorch.py::TestNemotronV3Super::test_nvfp4_8gpus[attention_dp_on-trtllm]
# multimodal accuracy tests


@ -104,8 +104,10 @@ l0_dgx_h100:
- accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_fp8_4gpus[tp2pp2-fp8kv=True-attn_backend=TRTLLM-torch_compile=False]
- accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_fp8_4gpus[tp2pp2-fp8kv=True-attn_backend=TRTLLM-torch_compile=True]
- accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_guided_decoding_4gpus[xgrammar]
- accuracy/test_llm_api_pytorch.py::TestNemotronV3Super::test_fp8_4gpus[attention_dp_off]
- accuracy/test_llm_api_pytorch.py::TestNemotronV3Super::test_fp8_4gpus[attention_dp_on]
- accuracy/test_llm_api_pytorch.py::TestNemotronV3Super::test_fp8_4gpus[attention_dp_off-python_mamba_cache]
- accuracy/test_llm_api_pytorch.py::TestNemotronV3Super::test_fp8_4gpus[attention_dp_off-cpp_mamba_cache]
- accuracy/test_llm_api_pytorch.py::TestNemotronV3Super::test_fp8_4gpus[attention_dp_on-python_mamba_cache]
- accuracy/test_llm_api_pytorch.py::TestNemotronV3Super::test_fp8_4gpus[attention_dp_on-cpp_mamba_cache]
- test_e2e.py::test_ptp_quickstart_advanced_bs1
- test_e2e.py::test_ptp_quickstart_advanced_deepseek_v3_lite_4gpus_adp_balance[DeepSeek-V3-Lite-FP8-DeepSeek-V3-Lite/fp8]
- test_e2e.py::test_trtllm_bench_llmapi_launch[pytorch_backend-llama-v3-llama3-8b]