From 7d992972b2e0b3b7f680cdb7173bf30eb5d6f285 Mon Sep 17 00:00:00 2001 From: Iman Tabrizian <10105175+Tabrizian@users.noreply.github.com> Date: Tue, 10 Feb 2026 07:20:56 -0800 Subject: [PATCH] [TRTLLM-10273][feat] Move MambaCacheManager from Python to C++ (#10540) Signed-off-by: Iman Tabrizian <10105175+tabrizian@users.noreply.github.com> --- .../batch_manager/rnnStateManager.h | 27 ++ .../batch_manager/rnnStateManager.cpp | 152 +++++++++ .../nanobind/batch_manager/bindings.cpp | 31 +- .../_torch/attention_backend/flashinfer.py | 1 + .../_torch/attention_backend/interface.py | 23 +- .../_torch/attention_backend/trtllm.py | 1 + .../_torch/attention_backend/vanilla.py | 1 + .../_torch/models/modeling_nemotron_h.py | 11 +- .../_torch/models/modeling_qwen3_next.py | 25 +- .../_torch/modules/mamba/mamba2_metadata.py | 63 +++- .../_torch/modules/mamba/mamba2_mixer.py | 26 +- .../_torch/pyexecutor/cuda_graph_runner.py | 6 +- .../_torch/pyexecutor/mamba_cache_manager.py | 289 +++++++++++++++++- .../_torch/pyexecutor/model_engine.py | 2 + .../tools/layer_wise_benchmarks/runner.py | 2 +- .../defs/accuracy/test_llm_api_pytorch.py | 16 +- .../test_lists/qa/llm_function_core.txt | 6 +- .../test_lists/test-db/l0_dgx_h100.yml | 6 +- 18 files changed, 634 insertions(+), 54 deletions(-) diff --git a/cpp/include/tensorrt_llm/batch_manager/rnnStateManager.h b/cpp/include/tensorrt_llm/batch_manager/rnnStateManager.h index 1b2f97fa1f..555cbc4a98 100644 --- a/cpp/include/tensorrt_llm/batch_manager/rnnStateManager.h +++ b/cpp/include/tensorrt_llm/batch_manager/rnnStateManager.h @@ -16,11 +16,16 @@ #pragma once +#include "tensorrt_llm/batch_manager/common.h" #include "tensorrt_llm/runtime/bufferManager.h" #include "tensorrt_llm/runtime/iTensor.h" #include "tensorrt_llm/runtime/modelConfig.h" #include "tensorrt_llm/runtime/worldConfig.h" +#include <optional> +#include <unordered_map> +#include <vector> + namespace tensorrt_llm::batch_manager::rnn_state_manager { @@ -30,16 +35,34 @@ public: using TensorPtr = runtime::ITensor::SharedPtr; using SizeType32 = tensorrt_llm::runtime::SizeType32; using TensorMap = runtime::StringPtrMap<runtime::ITensor>; + using RequestIdType = tensorrt_llm::batch_manager::RequestIdType; RnnStateManager(SizeType32 maxNumSequences, tensorrt_llm::runtime::ModelConfig const& modelConfig, runtime::WorldConfig const& worldConfig, tensorrt_llm::runtime::BufferManager const& bufferManager); + RnnStateManager(SizeType32 dState, SizeType32 dConv, SizeType32 numHeads, SizeType32 nGroups, SizeType32 headDim, + SizeType32 maxBatchSize, runtime::WorldConfig const& worldConfig, int64_t stream, nvinfer1::DataType dtype, + nvinfer1::DataType ssmCacheDtype, std::vector<SizeType32> const& ppLayers); + void getPtrBuffers(TensorMap& inputBuffers, runtime::ModelConfig const& modelConfig, runtime::WorldConfig const& worldConfig) const; void fillSlotMapping( runtime::ITensor& dstPointers, SizeType32 dstSlotOffset, SizeType32 seqSlotIdx, SizeType32 beamWidth) const; + void allocateCacheBlocks(std::vector<RequestIdType> const& requestIds); + + void freeCacheBlock(RequestIdType requestId); + + [[nodiscard]] SizeType32 getCacheIndex(RequestIdType requestId) const; + + [[nodiscard]] std::vector<SizeType32> getStateIndices( + std::vector<RequestIdType> const& requestIds, std::vector<bool> const& isPadding); + + [[nodiscard]] TensorPtr getConvStates(SizeType32 layerIdx) const; + + [[nodiscard]] TensorPtr getSsmStates(SizeType32 layerIdx) const; + private: // If we need support beam search, we may need mMaxBeamWidth + 1 slots and use separate input / output states. 
TensorPtr pagedRnnStates; // [local_nb_layers, max_seq_num * max_beam_width, state_size, rnn_hidden_size] or @@ -55,6 +78,10 @@ private: SizeType32 mMaxNumSequences = 0; SizeType32 mMaxBeamWidth = 0; SizeType32 mBeamSlotsPerSequence = 0; + std::unordered_map<SizeType32, SizeType32> mLayerOffsets; + std::vector<SizeType32> mFreeBlocks; + std::unordered_map<RequestIdType, SizeType32> mCacheIndex; + std::optional<runtime::BufferManager> mBufferManager; }; } // namespace tensorrt_llm::batch_manager::rnn_state_manager diff --git a/cpp/tensorrt_llm/batch_manager/rnnStateManager.cpp b/cpp/tensorrt_llm/batch_manager/rnnStateManager.cpp index 736458d98d..60019ca422 100644 --- a/cpp/tensorrt_llm/batch_manager/rnnStateManager.cpp +++ b/cpp/tensorrt_llm/batch_manager/rnnStateManager.cpp @@ -17,8 +17,11 @@ #include "tensorrt_llm/batch_manager/rnnStateManager.h" #include "tensorrt_llm/common/assert.h" +#include "tensorrt_llm/runtime/cudaStream.h" #include "tensorrt_llm/runtime/utils/runtimeUtils.h" +#include <unordered_set> + using namespace tensorrt_llm::runtime; namespace tensorrt_llm::batch_manager::rnn_state_manager { @@ -82,6 +85,64 @@ RnnStateManager::RnnStateManager(SizeType32 maxNumSequences, tensorrt_llm::runti } } +RnnStateManager::RnnStateManager(SizeType32 dState, SizeType32 dConv, SizeType32 numHeads, SizeType32 nGroups, + SizeType32 headDim, SizeType32 maxBatchSize, WorldConfig const& worldConfig, int64_t stream, + nvinfer1::DataType dtype, nvinfer1::DataType ssmCacheDtype, std::vector<SizeType32> const& ppLayers) + : mMaxNumSequences(maxBatchSize) + , mMaxBeamWidth{1} + , mBeamSlotsPerSequence{1} + , mBufferManager{BufferManager(std::make_shared<CudaStream>(reinterpret_cast<cudaStream_t>(stream)))} +{ + auto const tpSize = worldConfig.getTensorParallelism(); + + auto const dInner = headDim * numHeads; + auto convDim = dInner + 2 * nGroups * dState; + auto nheads = numHeads; + + TLLM_CHECK_WITH_INFO(nheads % tpSize == 0, "nheads must be divisible by tp_size"); + TLLM_CHECK_WITH_INFO(convDim % tpSize == 0, "conv_dim must be divisible by tp_size"); + + convDim = convDim / tpSize; + nheads = nheads / tpSize; + + auto const numLocalLayers = static_cast<SizeType32>(ppLayers.size()); + + for (SizeType32 offset = 0; offset < numLocalLayers; ++offset) + { + mLayerOffsets[ppLayers[offset]] = offset; + } + + auto const convStateShape = ITensor::makeShape({numLocalLayers, maxBatchSize, convDim, dConv - 1}); + pagedConvStates = mBufferManager->gpu(convStateShape, dtype); + + auto const rnnStateShape = ITensor::makeShape({numLocalLayers, maxBatchSize, nheads, headDim, dState}); + pagedRnnStates = mBufferManager->gpu(rnnStateShape, ssmCacheDtype); + + mFreeBlocks.reserve(maxBatchSize); + for (SizeType32 i = 0; i < maxBatchSize; ++i) + { + mFreeBlocks.push_back(i); + } + + auto const statePtrsShape = ITensor::makeShape({numLocalLayers}); + rnnStatePtrs = BufferManager::cpu(statePtrsShape, TRTDataType<void*>::value); + convStatePtrs = BufferManager::cpu(statePtrsShape, TRTDataType<void*>::value); + auto* rnnStatePtrArray = bufferCast<void*>(*rnnStatePtrs); + auto* convStatePtrArray = bufferCast<void*>(*convStatePtrs); + + rnnStatePtr.resize(numLocalLayers); + convStatePtr.resize(numLocalLayers); + for (SizeType32 i = 0; i < numLocalLayers; i++) + { + auto layerRnnStates = ITensor::slice(pagedRnnStates, i, 1); + auto layerConvStates = ITensor::slice(pagedConvStates, i, 1); + rnnStatePtrArray[i] = layerRnnStates->data(); + convStatePtrArray[i] = layerConvStates->data(); + rnnStatePtr[i] = ITensor::slice(rnnStatePtrs, i, 1); + convStatePtr[i] = ITensor::slice(convStatePtrs, i, 1); + } +} + void RnnStateManager::getPtrBuffers( TensorMap& inputBuffers, runtime::ModelConfig const& modelConfig, 
runtime::WorldConfig const& worldConfig) const { @@ -113,4 +174,95 @@ void RnnStateManager::fillSlotMapping( } } +void RnnStateManager::allocateCacheBlocks(std::vector<RequestIdType> const& requestIds) +{ + for (auto const& requestId : requestIds) + { + auto it = mCacheIndex.find(requestId); + if (it == mCacheIndex.end()) + { + TLLM_CHECK_WITH_INFO(!mFreeBlocks.empty(), "Run out of RNN state cache blocks"); + SizeType32 const block = mFreeBlocks.back(); + mFreeBlocks.pop_back(); + mCacheIndex[requestId] = block; + } + } +} + +void RnnStateManager::freeCacheBlock(RequestIdType requestId) +{ + auto it = mCacheIndex.find(requestId); + if (it != mCacheIndex.end()) + { + mFreeBlocks.push_back(it->second); + mCacheIndex.erase(it); + } +} + +RnnStateManager::SizeType32 RnnStateManager::getCacheIndex(RequestIdType requestId) const +{ + auto it = mCacheIndex.find(requestId); + TLLM_CHECK_WITH_INFO(it != mCacheIndex.end(), "Request ID not found in cache index"); + return it->second; +} + +std::vector<RnnStateManager::SizeType32> RnnStateManager::getStateIndices( + std::vector<RequestIdType> const& requestIds, std::vector<bool> const& isPadding) +{ + TLLM_CHECK_WITH_INFO(requestIds.size() == isPadding.size(), "requestIds and isPadding must have the same size"); + + std::unordered_set<SizeType32> availableSlots; + availableSlots.reserve(mMaxNumSequences); + for (SizeType32 i = 0; i < mMaxNumSequences; ++i) + { + availableSlots.insert(i); + } + + for (size_t i = 0; i < requestIds.size(); ++i) + { + if (!isPadding[i]) + { + availableSlots.erase(getCacheIndex(requestIds[i])); + } + } + + std::vector<SizeType32> result; + result.reserve(requestIds.size()); + auto availableIt = availableSlots.begin(); + + for (size_t i = 0; i < requestIds.size(); ++i) + { + if (isPadding[i]) + { + TLLM_CHECK_WITH_INFO(availableIt != availableSlots.end(), "Run out of available slots for padding"); + result.push_back(*availableIt); + ++availableIt; + } + else + { + result.push_back(getCacheIndex(requestIds[i])); + } + } + + return result; +} + +RnnStateManager::TensorPtr RnnStateManager::getConvStates(SizeType32 layerIdx) const +{ + auto it = mLayerOffsets.find(layerIdx); + TLLM_CHECK_WITH_INFO(it != mLayerOffsets.end(), "Layer index not found in layer offsets"); + auto result = ITensor::slice(pagedConvStates, it->second, 1); + result->squeeze(0); + return result; +} + +RnnStateManager::TensorPtr RnnStateManager::getSsmStates(SizeType32 layerIdx) const +{ + auto it = mLayerOffsets.find(layerIdx); + TLLM_CHECK_WITH_INFO(it != mLayerOffsets.end(), "Layer index not found in layer offsets"); + auto result = ITensor::slice(pagedRnnStates, it->second, 1); + result->squeeze(0); + return result; +} + } // namespace tensorrt_llm::batch_manager::rnn_state_manager diff --git a/cpp/tensorrt_llm/nanobind/batch_manager/bindings.cpp b/cpp/tensorrt_llm/nanobind/batch_manager/bindings.cpp index efd73e5caf..a9c4b3e03a 100644 --- a/cpp/tensorrt_llm/nanobind/batch_manager/bindings.cpp +++ b/cpp/tensorrt_llm/nanobind/batch_manager/bindings.cpp @@ -416,7 +416,36 @@ void initBindings(nb::module_& m) nb::class_<tb::rnn_state_manager::RnnStateManager>(m, "RnnStateManager") .def(nb::init<tr::SizeType32, tr::ModelConfig const&, tr::WorldConfig const&, tr::BufferManager const&>(), - nb::arg("max_num_sequences"), nb::arg("model_config"), nb::arg("world_config"), nb::arg("buffer_manager")); + nb::arg("max_num_sequences"), nb::arg("model_config"), nb::arg("world_config"), nb::arg("buffer_manager"), + nb::call_guard<nb::gil_scoped_release>()) + .def(nb::init<tr::SizeType32, tr::SizeType32, tr::SizeType32, tr::SizeType32, tr::SizeType32, tr::SizeType32, + tr::WorldConfig const&, int64_t, nvinfer1::DataType, nvinfer1::DataType, + std::vector<tr::SizeType32> const&>(), + nb::arg("d_state"), nb::arg("d_conv"), nb::arg("num_heads"), nb::arg("n_groups"), nb::arg("head_dim"), + nb::arg("max_batch_size"), nb::arg("world_config"), nb::arg("stream"), nb::arg("dtype"), + nb::arg("ssm_cache_dtype"), 
nb::arg("pp_layers"), nb::call_guard()) + .def( + "get_conv_states", + [](tb::rnn_state_manager::RnnStateManager& self, tr::SizeType32 layerIdx) -> at::Tensor + { + auto tensor = self.getConvStates(layerIdx); + return tr::Torch::tensor(tensor); + }, + nb::arg("layer_idx"), nb::call_guard()) + .def( + "get_ssm_states", + [](tb::rnn_state_manager::RnnStateManager& self, tr::SizeType32 layerIdx) -> at::Tensor + { + auto tensor = self.getSsmStates(layerIdx); + return tr::Torch::tensor(tensor); + }, + nb::arg("layer_idx"), nb::call_guard()) + .def("allocate_cache_blocks", &tb::rnn_state_manager::RnnStateManager::allocateCacheBlocks, + nb::arg("request_ids"), nb::call_guard()) + .def("free_cache_block", &tb::rnn_state_manager::RnnStateManager::freeCacheBlock, nb::arg("request_id"), + nb::call_guard()) + .def("get_state_indices", &tb::rnn_state_manager::RnnStateManager::getStateIndices, nb::arg("request_ids"), + nb::arg("is_padding"), nb::call_guard()); m.def( "add_new_tokens_to_requests", diff --git a/tensorrt_llm/_torch/attention_backend/flashinfer.py b/tensorrt_llm/_torch/attention_backend/flashinfer.py index 4766b49a6c..5edc65933b 100644 --- a/tensorrt_llm/_torch/attention_backend/flashinfer.py +++ b/tensorrt_llm/_torch/attention_backend/flashinfer.py @@ -204,6 +204,7 @@ class FlashInferAttentionMetadata(AttentionMetadata): return self.kv_cache_manager.tokens_per_block def prepare(self) -> None: + super().prepare() extra_attrs = get_model_extra_attrs() if extra_attrs is None: get_global_attrs().attention_metadata = weakref.ref(self) diff --git a/tensorrt_llm/_torch/attention_backend/interface.py b/tensorrt_llm/_torch/attention_backend/interface.py index 0f4c5bac42..8ef247d069 100644 --- a/tensorrt_llm/_torch/attention_backend/interface.py +++ b/tensorrt_llm/_torch/attention_backend/interface.py @@ -3,7 +3,7 @@ import weakref from collections import namedtuple from dataclasses import dataclass, field from enum import Enum, IntEnum -from typing import (TYPE_CHECKING, Dict, Generic, List, Optional, Protocol, +from typing import (TYPE_CHECKING, Any, Dict, Generic, List, Optional, Protocol, Tuple, Type, TypeVar, Union) import torch @@ -21,6 +21,7 @@ from tensorrt_llm.models.modeling_utils import QuantConfig from ..memory_buffer_utils import Buffers from ..metadata import KVCacheParams +from ..pyexecutor.mamba_cache_manager import MambaCacheManager from ..pyexecutor.resource_manager import KVCacheManager, KVCacheManagerV2 from ..utils import get_model_extra_attrs @@ -147,6 +148,9 @@ class AttentionMetadata: _num_ctx_tokens: int = field(init=False, default=0, repr=False) _num_tokens: int = field(init=False, default=0, repr=False) + mamba_metadata: Optional[Any] = None + mamba_chunk_size: int = 128 + # The number of tokens in the padded sequence. padded_num_tokens: Optional[int] = None @@ -290,6 +294,23 @@ class AttentionMetadata: """ Hook to be called before the forward step of the model. 
""" + self._prepare_mamba_metadata() + + def _prepare_mamba_metadata(self): + if self.mamba_metadata is False: + return + + if self.mamba_metadata is None: + if (self.kv_cache_manager is not None + and isinstance(self.kv_cache_manager, MambaCacheManager)): + from ..modules.mamba.mamba2_metadata import Mamba2Metadata + self.mamba_metadata = Mamba2Metadata(self.max_num_requests, + self.mamba_chunk_size) + else: + self.mamba_metadata = False + return + + self.mamba_metadata.prepare(self) def create_cuda_graph_metadata(self, max_batch_size: int, diff --git a/tensorrt_llm/_torch/attention_backend/trtllm.py b/tensorrt_llm/_torch/attention_backend/trtllm.py index 9620746137..90e608fb19 100644 --- a/tensorrt_llm/_torch/attention_backend/trtllm.py +++ b/tensorrt_llm/_torch/attention_backend/trtllm.py @@ -933,6 +933,7 @@ class TrtllmAttentionMetadata(AttentionMetadata): self.helix_is_inactive_rank_cpu[:batch_size], non_blocking=True) def prepare(self) -> None: + super().prepare() extra_attrs = get_model_extra_attrs() # If model extra attrs is set, attention_metadata is setup in executor. if extra_attrs is None: diff --git a/tensorrt_llm/_torch/attention_backend/vanilla.py b/tensorrt_llm/_torch/attention_backend/vanilla.py index 46c765e23c..c2b2c7f1a1 100644 --- a/tensorrt_llm/_torch/attention_backend/vanilla.py +++ b/tensorrt_llm/_torch/attention_backend/vanilla.py @@ -59,6 +59,7 @@ def generate_sliding_window_mask(batch_size: int, target_length: int, class VanillaAttentionMetadata(AttentionMetadata): def prepare(self) -> None: + super().prepare() # indices of used cache blocks for each sequence assert self.request_ids is not None self.block_ids_per_seq = self.kv_cache_manager.get_batch_cache_indices( diff --git a/tensorrt_llm/_torch/models/modeling_nemotron_h.py b/tensorrt_llm/_torch/models/modeling_nemotron_h.py index d1a4974e74..ee77603ba5 100644 --- a/tensorrt_llm/_torch/models/modeling_nemotron_h.py +++ b/tensorrt_llm/_torch/models/modeling_nemotron_h.py @@ -22,7 +22,6 @@ from transformers import AutoConfig, PretrainedConfig from tensorrt_llm._torch.models.checkpoints.base_weight_mapper import \ BaseWeightMapper -from tensorrt_llm._torch.modules.mamba.mamba2_metadata import Mamba2Metadata from tensorrt_llm._torch.utils import ActivationType, relu2 from ..attention_backend import AttentionMetadata @@ -424,8 +423,6 @@ class NemotronHModel(DecoderModel): dtype=config.torch_dtype, ) - self.mamba_metadata: Optional[Mamba2Metadata] = None - def forward( self, attn_metadata: AttentionMetadata, @@ -440,11 +437,7 @@ class NemotronHModel(DecoderModel): "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one" ) - if self.mamba_metadata is None or self.mamba_metadata.max_batch_size != attn_metadata.max_num_requests: - self.mamba_metadata = Mamba2Metadata( - attn_metadata.max_num_requests, - chunk_size=self.model_config.pretrained_config.chunk_size) - self.mamba_metadata.prepare(attn_metadata) + mamba_metadata = attn_metadata.mamba_metadata if inputs_embeds is None: inputs_embeds = self.embed_tokens(input_ids) @@ -456,7 +449,7 @@ class NemotronHModel(DecoderModel): hidden_states, attn_metadata, spec_metadata=spec_metadata, - mamba_metadata=self.mamba_metadata) + mamba_metadata=mamba_metadata) hidden_states = self.norm_f(hidden_states) diff --git a/tensorrt_llm/_torch/models/modeling_qwen3_next.py b/tensorrt_llm/_torch/models/modeling_qwen3_next.py index 799d4076ab..373e64f924 100644 --- a/tensorrt_llm/_torch/models/modeling_qwen3_next.py +++ 
b/tensorrt_llm/_torch/models/modeling_qwen3_next.py @@ -31,6 +31,8 @@ from tensorrt_llm._torch.modules.fla.chunk import chunk_gated_delta_rule from tensorrt_llm._torch.modules.fla.fused_sigmoid_gating_recurrent import \ fused_sigmoid_gating_delta_rule_update from tensorrt_llm._torch.modules.mamba.mamba2_metadata import Mamba2Metadata +from tensorrt_llm._torch.pyexecutor.mamba_cache_manager import \ + use_cpp_mamba_cache_manager from tensorrt_llm.mapping import Mapping from ..attention_backend import AttentionMetadata @@ -754,8 +756,12 @@ class Qwen3NextGatedDeltaNet(nn.Module): batch_split_size = [num_prefills, num_decodes] has_initial_states = mamba_metadata.has_initial_states - state_indices = attn_metadata.kv_cache_manager.get_state_indices( - )[:num_prefills + num_decodes] + batch_size = num_prefills + num_decodes + if use_cpp_mamba_cache_manager(): + state_indices = mamba_metadata.state_indices[:batch_size] + else: + state_indices = attn_metadata.kv_cache_manager.get_state_indices( + )[:batch_size] state_indices_p, state_indices_d = torch.split(state_indices, batch_split_size) @@ -1197,8 +1203,6 @@ class Qwen3NextModel(DecoderModel): use_gemma=True, ) - self.mamba_metadata: Optional[Mamba2Metadata] = None - def forward( self, attn_metadata: AttentionMetadata, @@ -1213,13 +1217,10 @@ class Qwen3NextModel(DecoderModel): "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one" ) - if self.mamba_metadata is None or self.mamba_metadata.max_batch_size != attn_metadata.max_num_requests: - self.mamba_metadata = Mamba2Metadata( - attn_metadata.max_num_requests, - # chunk_size=self.model_config.pretrained_config.mamba2_chunk_size) - # TODO check how to get the correct chunk_size - chunk_size=128) - self.mamba_metadata.prepare(attn_metadata) + mamba_metadata = attn_metadata.mamba_metadata + if mamba_metadata.max_batch_size != attn_metadata.max_num_requests: + attn_metadata.mamba_metadata = Mamba2Metadata( + attn_metadata.max_num_requests, chunk_size=128) if inputs_embeds is None: inputs_embeds = self.embed_tokens(input_ids) @@ -1233,7 +1234,7 @@ class Qwen3NextModel(DecoderModel): attn_metadata=attn_metadata, residual=residual, spec_metadata=spec_metadata, - mamba_metadata=self.mamba_metadata) + mamba_metadata=mamba_metadata) return hidden_states diff --git a/tensorrt_llm/_torch/modules/mamba/mamba2_metadata.py b/tensorrt_llm/_torch/modules/mamba/mamba2_metadata.py index 81b461fa1c..6888dbfaf4 100644 --- a/tensorrt_llm/_torch/modules/mamba/mamba2_metadata.py +++ b/tensorrt_llm/_torch/modules/mamba/mamba2_metadata.py @@ -19,6 +19,10 @@ from typing import Tuple import torch from tensorrt_llm._torch.attention_backend.interface import AttentionMetadata +from tensorrt_llm._torch.pyexecutor.cuda_graph_runner import \ + CUDA_GRAPH_DUMMY_REQUEST_ID +from tensorrt_llm._torch.pyexecutor.mamba_cache_manager import \ + use_cpp_mamba_cache_manager def cu_seqlens_to_chunk_indices_offsets( @@ -107,11 +111,47 @@ class Mamba2Metadata: self.chunk_indices: torch.Tensor = None self.chunk_offsets: torch.Tensor = None + self.state_indices_cpu = torch.zeros(max_batch_size, + dtype=torch.int32, + pin_memory=True) + self.state_indices = torch.zeros(max_batch_size, + dtype=torch.int32, + device="cuda") + self._query_start_loc_long_buf = torch.arange(0, + max_batch_size + 1, + dtype=torch.long, + device="cuda") + self._query_start_loc_buf = torch.zeros(max_batch_size + 1, + dtype=torch.int, + device="cuda") + self.query_start_loc_long = self._query_start_loc_long_buf + 
self.query_start_loc = self._query_start_loc_buf + def prepare(self, attn_metadata: AttentionMetadata): batch_size = attn_metadata.seq_lens.shape[0] num_contexts = attn_metadata.num_contexts context_lens = attn_metadata.seq_lens_cuda[:num_contexts] num_ctx_tokens = attn_metadata.num_ctx_tokens + + kv_cache_manager = attn_metadata.kv_cache_manager + request_ids = attn_metadata.request_ids + + if (kv_cache_manager is not None + and hasattr(kv_cache_manager, 'get_state_indices') + and request_ids is not None): + if use_cpp_mamba_cache_manager(): + batch_request_ids = request_ids[:batch_size] + is_padding = [ + req_id == CUDA_GRAPH_DUMMY_REQUEST_ID + for req_id in batch_request_ids + ] + indices = kv_cache_manager.get_state_indices( + batch_request_ids, is_padding) + for i, idx in enumerate(indices): + self.state_indices_cpu[i] = idx + self.state_indices[:batch_size].copy_( + self.state_indices_cpu[:batch_size], non_blocking=True) + if num_contexts > 0: torch.cumsum(context_lens, dim=0, @@ -125,8 +165,14 @@ class Mamba2Metadata: out=self.cu_seqlens[num_contexts + 1:batch_size + 1]) # Need both `query_start_loc` and `query_start_loc_long` because `causal_conv1d_fn` # accepts only `int32` while `chunk_gated_delta_rule` accepts only `long`. - self.query_start_loc = self.cu_seqlens[:batch_size + 1] - self.query_start_loc_long = self.query_start_loc.to(torch.long) + self._query_start_loc_buf[:batch_size + + 1] = self.cu_seqlens[:batch_size + 1] + self.query_start_loc = self._query_start_loc_buf[:batch_size + 1] + self._query_start_loc_long_buf[:batch_size + 1].copy_( + self.query_start_loc.to(torch.long), non_blocking=True) + self.query_start_loc_long = self._query_start_loc_long_buf[: + batch_size + + 1] self.seq_idx = torch.repeat_interleave( torch.arange(num_contexts, dtype=torch.int, @@ -148,8 +194,11 @@ class Mamba2Metadata: self.chunk_offsets = None else: self.query_start_loc = None - self.query_start_loc_long = torch.arange( - 0, - batch_size + 1, - dtype=torch.long, - device=self.cu_seqlens.device) + torch.arange(0, + batch_size + 1, + dtype=torch.long, + device=self.cu_seqlens.device, + out=self._query_start_loc_long_buf[:batch_size + 1]) + self.query_start_loc_long = self._query_start_loc_long_buf[: + batch_size + + 1] diff --git a/tensorrt_llm/_torch/modules/mamba/mamba2_mixer.py b/tensorrt_llm/_torch/modules/mamba/mamba2_mixer.py index 192e304419..1874ea65d4 100644 --- a/tensorrt_llm/_torch/modules/mamba/mamba2_mixer.py +++ b/tensorrt_llm/_torch/modules/mamba/mamba2_mixer.py @@ -21,6 +21,8 @@ from flashinfer.mamba import selective_state_update as selective_state_update_fi from torch import nn from tensorrt_llm._torch.modules.mamba.mamba2_metadata import Mamba2Metadata +from tensorrt_llm._torch.pyexecutor.mamba_cache_manager import \ + use_cpp_mamba_cache_manager from tensorrt_llm.logger import logger from tensorrt_llm.mapping import Mapping @@ -204,15 +206,24 @@ class Mamba2Mixer(nn.Module): seqlen_split_size = [num_prefill_tokens, num_decode_tokens] batch_split_size = [num_prefills, num_decodes] - state_indices = attn_metadata.kv_cache_manager.get_state_indices( - )[:num_prefills + num_decodes] + if use_cpp_mamba_cache_manager(): + state_indices = mamba_metadata.state_indices[:num_prefills + + num_decodes] + conv_states = attn_metadata.kv_cache_manager.get_conv_states( + self.layer_idx) + ssm_states = attn_metadata.kv_cache_manager.get_ssm_states( + self.layer_idx) + layer_cache = None # Not used in C++ path + else: + state_indices = attn_metadata.kv_cache_manager.get_state_indices( + 
)[:num_prefills + num_decodes] + layer_cache = attn_metadata.kv_cache_manager.mamba_layer_cache( + self.layer_idx) + conv_states = layer_cache.conv + ssm_states = layer_cache.temporal state_indices_p, state_indices_d = torch.split(state_indices, batch_split_size) - layer_cache = attn_metadata.kv_cache_manager.mamba_layer_cache( - self.layer_idx) - conv_states = layer_cache.conv - ssm_states = layer_cache.temporal # in_proj zxbcdt = self.in_proj(hidden_states) @@ -310,6 +321,9 @@ class Mamba2Mixer(nn.Module): is_target_verify = attn_metadata.kv_cache_manager.is_speculative( ) and spec_metadata is not None if is_target_verify: + # Speculative decoding only supported with Python path + assert layer_cache is not None, \ + "Speculative decoding requires Python MambaCacheManager" # TODO: support dynamic speculation, will add current_draft_len later [TRTLLM-10319] draft_token_num = spec_metadata.max_draft_len + 1 intermediate_conv_states = layer_cache.intermediate_conv_window diff --git a/tensorrt_llm/_torch/pyexecutor/cuda_graph_runner.py b/tensorrt_llm/_torch/pyexecutor/cuda_graph_runner.py index 7878215855..0ae1443138 100644 --- a/tensorrt_llm/_torch/pyexecutor/cuda_graph_runner.py +++ b/tensorrt_llm/_torch/pyexecutor/cuda_graph_runner.py @@ -19,7 +19,7 @@ from ..speculative.mtp import SampleStateTensorsMTP from ..speculative.utils import get_draft_kv_cache_manager from ..utils import make_weak_ref, piecewise_cuda_graph from .llm_request import get_draft_token_length -from .mamba_cache_manager import MambaCacheManager +from .mamba_cache_manager import MambaCacheManager, use_cpp_mamba_cache_manager from .resource_manager import (BaseResourceManager, ResourceManager, ResourceManagerType) from .sampler import SampleStateTensors @@ -460,8 +460,8 @@ class CUDAGraphRunner: if spec_res_mgr: spec_res_mgr.add_dummy_requests([CUDA_GRAPH_DUMMY_REQUEST_ID]) - # handle special cases of padding requests + MambaCacheManager or MambaHybridCacheManager - if isinstance(kv_cache_manager, MambaCacheManager): + if (isinstance(kv_cache_manager, MambaCacheManager) + and not use_cpp_mamba_cache_manager()): kv_cache_manager.reorder_state_indices_when_padding_requests( batch_size, padding_size) diff --git a/tensorrt_llm/_torch/pyexecutor/mamba_cache_manager.py b/tensorrt_llm/_torch/pyexecutor/mamba_cache_manager.py index 56d3cbdc1c..b8e4a04575 100644 --- a/tensorrt_llm/_torch/pyexecutor/mamba_cache_manager.py +++ b/tensorrt_llm/_torch/pyexecutor/mamba_cache_manager.py @@ -13,22 +13,28 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+import os from dataclasses import dataclass from typing import TYPE_CHECKING, Dict, List, Optional, Union import torch -from tensorrt_llm._torch.pyexecutor.llm_request import LlmRequest -from tensorrt_llm._torch.pyexecutor.resource_manager import ( - BaseResourceManager, CacheTypeCpp, DataType, KVCacheManager, get_pp_layers) -from tensorrt_llm._torch.pyexecutor.scheduler import ScheduledRequests -from tensorrt_llm.llmapi.llm_args import KvCacheConfig -from tensorrt_llm.logger import logger -from tensorrt_llm.mapping import Mapping +import tensorrt_llm.bindings if TYPE_CHECKING: from tensorrt_llm._torch.attention_backend.interface import \ AttentionMetadata +from tensorrt_llm._torch.pyexecutor.llm_request import LlmRequest +from tensorrt_llm._torch.pyexecutor.resource_manager import ( + BaseResourceManager, CacheTypeCpp, DataType, KVCacheManager, get_pp_layers) +from tensorrt_llm._torch.pyexecutor.scheduler import ScheduledRequests +from tensorrt_llm._utils import torch_dtype_to_binding +from tensorrt_llm.llmapi.llm_args import KvCacheConfig +from tensorrt_llm.logger import logger +from tensorrt_llm.mapping import Mapping + +RnnStateManagerCpp = tensorrt_llm.bindings.internal.batch_manager.RnnStateManager +WorldConfig = tensorrt_llm.bindings.WorldConfig GB = 1 << 30 @@ -42,7 +48,117 @@ def get_tensor_size_bytes(tensor): return 0 -class MambaCacheManager(BaseResourceManager): +def use_cpp_mamba_cache_manager() -> bool: + """Check if C++ MambaCacheManager should be used. + + Returns True if TRTLLM_USE_CPP_MAMBA='1' is set, False otherwise. + By default, PythonMambaCacheManager is used. + """ + return os.environ.get('TRTLLM_USE_CPP_MAMBA', '0') == '1' + + +class CppMambaCacheManager(BaseResourceManager): + """C++ backed Mamba cache manager using RnnStateManager bindings.""" + + def __init__( + self, + d_state: int, + d_conv: int, + num_heads: int, + n_groups: int, + head_dim: int, + num_layers: int, + max_num_sequences: int, + mapping: Mapping, + dtype: torch.dtype, + ssm_cache_dtype: torch.dtype, + layer_mask: Optional[List[bool]] = None, + stream: Optional[torch.cuda.Stream] = None, + ) -> None: + self.mamba_ssm_cache_dtype = ssm_cache_dtype + + # get tp size + tp_size = mapping.tp_size if not mapping.enable_attention_dp else 1 + world_config = WorldConfig( + tensor_parallelism=tp_size, + pipeline_parallelism=mapping.pp_size, + rank=mapping.rank, + gpus_per_node=mapping.gpus_per_node, + ) + + dtype_binding = torch_dtype_to_binding(dtype) + ssm_cache_dtype_binding = torch_dtype_to_binding( + ssm_cache_dtype if ssm_cache_dtype is not None else dtype) + + self._stream = stream if stream is not None else torch.cuda.current_stream( + ) + + pp_layers, _ = get_pp_layers(num_layers, mapping, layer_mask=layer_mask) + + self.mamba_impl = RnnStateManagerCpp( + d_state=d_state, + d_conv=d_conv, + num_heads=num_heads, + n_groups=n_groups, + head_dim=head_dim, + max_batch_size=max_num_sequences, + world_config=world_config, + stream=self._stream.cuda_stream, + dtype=dtype_binding, + ssm_cache_dtype=ssm_cache_dtype_binding, + pp_layers=pp_layers, + ) + self._max_num_sequences = max_num_sequences + + def get_max_resource_count(self) -> int: + # Return the maximum number of sequences that can be cached. + return self._max_num_sequences + + def get_needed_resource_to_completion(self, request: LlmRequest) -> int: + # For Mamba cache manager, we always need one slot per request. 
+ return 1 + + def prepare_resources(self, scheduled_batch: ScheduledRequests): + context_ids = [ + i.py_request_id for i in scheduled_batch.context_requests + ] + generation_ids = [ + i.py_request_id for i in scheduled_batch.generation_requests + ] + request_ids = context_ids + generation_ids + self.mamba_impl.allocate_cache_blocks(request_ids) + + def free_resources(self, request: LlmRequest): + self.mamba_impl.free_cache_block(request.py_request_id) + + def add_dummy_requests(self, request_ids: List[int], **kwargs): + # For CUDA graph dummy requests, the blocks will be allocated + # when get_state_indices is called. + from .cuda_graph_runner import CUDA_GRAPH_DUMMY_REQUEST_ID + request_ids = [ + rid for rid in request_ids if rid != CUDA_GRAPH_DUMMY_REQUEST_ID + ] + if request_ids: + self.mamba_impl.allocate_cache_blocks(request_ids) + + def get_state_indices(self, request_ids: List[int], + is_padding: List[bool]) -> List[int]: + return self.mamba_impl.get_state_indices(request_ids, is_padding) + + def get_conv_states(self, layer_idx: int) -> torch.Tensor: + return self.mamba_impl.get_conv_states(layer_idx) + + def get_ssm_states(self, layer_idx: int) -> torch.Tensor: + return self.mamba_impl.get_ssm_states(layer_idx) + + def get_mamba_ssm_cache_dtype(self) -> torch.dtype: + return self.mamba_ssm_cache_dtype + + def shutdown(self): + torch.cuda.empty_cache() + + +class PythonMambaCacheManager(BaseResourceManager): @dataclass(frozen=True, kw_only=True) class State: @@ -193,6 +309,17 @@ class MambaCacheManager(BaseResourceManager): dtype=torch.int32, device=device) + # Store max_batch_size for resource management + self._max_batch_size = max_batch_size + + def get_max_resource_count(self) -> int: + """Return the maximum number of sequences that can be cached.""" + return self._max_batch_size + + def get_needed_resource_to_completion(self, request: LlmRequest) -> int: + """For Mamba cache manager, we always need one slot per request.""" + return 1 + @torch.inference_mode() def _prepare_mamba_cache_blocks(self, request_ids: List[int]): self.state_indices_list.clear() @@ -261,7 +388,9 @@ class MambaCacheManager(BaseResourceManager): block = self.mamba_cache_index.pop(request_id) self.mamba_cache_free_blocks.append(block) - def get_state_indices(self) -> torch.Tensor: + def get_state_indices(self, + request_ids: List[int] = None, + is_padding: List[bool] = None) -> torch.Tensor: return self.state_indices def get_conv_states(self, layer_idx: int) -> torch.Tensor: @@ -296,6 +425,9 @@ class MambaCacheManager(BaseResourceManager): layer_offset = self.mamba_layer_offsets[layer_idx] return self.mamba_cache.at_layer_idx(layer_offset) + def get_mamba_ssm_cache_dtype(self) -> torch.dtype: + return self.mamba_ssm_cache_dtype + def shutdown(self): """Release tensor memory.""" # Clear state indices @@ -346,6 +478,140 @@ class MambaCacheManager(BaseResourceManager): conv_states[:, state_indices_d, :] = accepted_conv_state +class MambaCacheManager(BaseResourceManager): + + def __init__( + self, + d_state: int, + d_conv: int, + num_heads: int, + n_groups: int, + head_dim: int, + num_layers: int, + max_batch_size: int, + spec_state_size: int, + mapping: Mapping, + dtype: torch.dtype, + ssm_cache_dtype: torch.dtype, + layer_mask: Optional[List[bool]] = None, + stream: Optional[torch.cuda.Stream] = None, + speculative_num_draft_tokens: Optional[int] = None, + ) -> None: + max_num_sequences = max_batch_size * mapping.pp_size + self._use_cpp = use_cpp_mamba_cache_manager() + + if self._use_cpp: + assert 
speculative_num_draft_tokens is None, \ + "speculative_num_draft_tokens is not supported in CppMambaCacheManager" + self._impl = CppMambaCacheManager( + d_state=d_state, + d_conv=d_conv, + num_heads=num_heads, + n_groups=n_groups, + head_dim=head_dim, + num_layers=num_layers, + max_num_sequences=max_num_sequences, + mapping=mapping, + dtype=dtype, + ssm_cache_dtype=ssm_cache_dtype, + layer_mask=layer_mask, + stream=stream, + ) + else: + self._impl = PythonMambaCacheManager( + d_state=d_state, + d_conv=d_conv, + num_heads=num_heads, + n_groups=n_groups, + head_dim=head_dim, + num_layers=num_layers, + max_batch_size=max_batch_size, + spec_state_size=spec_state_size, + mapping=mapping, + dtype=dtype, + ssm_cache_dtype=ssm_cache_dtype, + layer_mask=layer_mask, + speculative_num_draft_tokens=speculative_num_draft_tokens, + ) + + def get_max_resource_count(self) -> int: + return self._impl.get_max_resource_count() + + def get_needed_resource_to_completion(self, request: LlmRequest) -> int: + return self._impl.get_needed_resource_to_completion(request) + + def prepare_resources(self, scheduled_batch: ScheduledRequests): + self._impl.prepare_resources(scheduled_batch) + + def free_resources(self, request: LlmRequest): + self._impl.free_resources(request) + + def add_dummy_requests(self, request_ids: List[int], **kwargs): + if self._use_cpp: + self._impl.add_dummy_requests(request_ids, **kwargs) + + def get_state_indices( + self, + request_ids: Optional[List[int]] = None, + is_padding: Optional[List[bool]] = None + ) -> Union[torch.Tensor, List[int]]: + return self._impl.get_state_indices(request_ids, is_padding) + + def reorder_state_indices_when_padding_requests(self, request_size: int, + padding_size: int): + assert not self._use_cpp, "reorder_state_indices_when_padding_requests is not supported in CppMambaCacheManager" + self._impl.reorder_state_indices_when_padding_requests( + request_size, padding_size) + + @property + def mamba_cache_free_blocks(self) -> List[int]: + assert not self._use_cpp, "mamba_cache_free_blocks is not supported in CppMambaCacheManager" + return self._impl.mamba_cache_free_blocks + + @property + def mamba_cache_index(self) -> Dict[int, int]: + assert not self._use_cpp, "mamba_cache_index is not supported in CppMambaCacheManager" + return self._impl.mamba_cache_index + + def get_conv_states(self, layer_idx: int) -> torch.Tensor: + return self._impl.get_conv_states(layer_idx) + + def get_ssm_states(self, layer_idx: int) -> torch.Tensor: + return self._impl.get_ssm_states(layer_idx) + + def get_mamba_ssm_cache_dtype(self) -> torch.dtype: + return self._impl.get_mamba_ssm_cache_dtype() + + def get_intermediate_ssm_states(self, + layer_idx: int) -> Optional[torch.Tensor]: + assert not self._use_cpp, "get_intermediate_ssm_states is not supported in CppMambaCacheManager" + return self._impl.get_intermediate_ssm_states(layer_idx) + + def get_intermediate_conv_states(self, + layer_idx: int) -> Optional[torch.Tensor]: + assert not self._use_cpp, "get_intermediate_conv_states is not supported in CppMambaCacheManager" + return self._impl.get_intermediate_conv_states(layer_idx) + + def is_speculative(self) -> bool: + assert not self._use_cpp, "is_speculative is not supported in CppMambaCacheManager" + return self._impl.is_speculative() + + def mamba_layer_cache( + self, layer_idx: int + ) -> Union[PythonMambaCacheManager.State, + PythonMambaCacheManager.SpeculativeState, None]: + assert not self._use_cpp, "mamba_layer_cache is not supported in CppMambaCacheManager" + return 
self._impl.mamba_layer_cache(layer_idx) + + def shutdown(self): + self._impl.shutdown() + + def update_mamba_states(self, attn_metadata: "AttentionMetadata", + num_accepted_tokens: torch.Tensor): + assert not self._use_cpp, "update_mamba_states is not supported in CppMambaCacheManager" + self._impl.update_mamba_states(attn_metadata, num_accepted_tokens) + + class MambaHybridCacheManager(KVCacheManager, MambaCacheManager): def __init__( @@ -399,6 +665,7 @@ class MambaHybridCacheManager(KVCacheManager, MambaCacheManager): mamba_cache_dtype, mamba_ssm_cache_dtype, mamba_layer_mask, + execution_stream, speculative_num_draft_tokens=spec_config.max_draft_len if spec_config is not None else None, ) @@ -430,6 +697,10 @@ class MambaHybridCacheManager(KVCacheManager, MambaCacheManager): MambaCacheManager.free_resources(self, request) KVCacheManager.free_resources(self, request) + def add_dummy_requests(self, request_ids: List[int], **kwargs): + MambaCacheManager.add_dummy_requests(self, request_ids) + return KVCacheManager.add_dummy_requests(self, request_ids, **kwargs) + def shutdown(self): MambaCacheManager.shutdown(self) KVCacheManager.shutdown(self) diff --git a/tensorrt_llm/_torch/pyexecutor/model_engine.py b/tensorrt_llm/_torch/pyexecutor/model_engine.py index de14006bfd..6b2a77dcf5 100644 --- a/tensorrt_llm/_torch/pyexecutor/model_engine.py +++ b/tensorrt_llm/_torch/pyexecutor/model_engine.py @@ -2777,6 +2777,8 @@ class PyTorchModelEngine(ModelEngine): num_extra_kv_tokens=get_num_extra_kv_tokens(spec_config)) attn_metadata.kv_cache_manager = kv_cache_manager + if hasattr(self.model.model_config.pretrained_config, 'chunk_size'): + attn_metadata.mamba_chunk_size = self.model.model_config.pretrained_config.chunk_size attn_metadata.prepare() peft_cache_manager = resource_manager and resource_manager.get_resource_manager( diff --git a/tensorrt_llm/tools/layer_wise_benchmarks/runner.py b/tensorrt_llm/tools/layer_wise_benchmarks/runner.py index ce6b20c37f..5b5355a28d 100644 --- a/tensorrt_llm/tools/layer_wise_benchmarks/runner.py +++ b/tensorrt_llm/tools/layer_wise_benchmarks/runner.py @@ -884,7 +884,7 @@ class Runner: else: raise NotImplementedError("Unsupported config") kv_cache_manager.add_dummy_requests( - list(range(max_batch_size)), [max_seq_len] * max_batch_size + list(range(max_batch_size)), token_nums=[max_seq_len] * max_batch_size ) return kv_cache_manager diff --git a/tests/integration/defs/accuracy/test_llm_api_pytorch.py b/tests/integration/defs/accuracy/test_llm_api_pytorch.py index 93b5367140..55e5a6147c 100644 --- a/tests/integration/defs/accuracy/test_llm_api_pytorch.py +++ b/tests/integration/defs/accuracy/test_llm_api_pytorch.py @@ -5690,6 +5690,17 @@ class TestNemotronV3Super(LlmapiAccuracyTestHarness): @skip_pre_hopper @pytest.mark.skip_less_mpi_world_size(4) @pytest.mark.skip_less_device_memory(40000) + @pytest.mark.parametrize( + "use_cpp_mamba", + [ + False, + True, + ], + ids=[ + "python_mamba_cache", + "cpp_mamba_cache", + ], + ) @pytest.mark.parametrize( "attention_dp", [ @@ -5701,7 +5712,10 @@ class TestNemotronV3Super(LlmapiAccuracyTestHarness): "attention_dp_on", ], ) - def test_fp8_4gpus(self, attention_dp): + def test_fp8_4gpus(self, attention_dp, use_cpp_mamba, monkeypatch): + monkeypatch.setenv("TRTLLM_USE_CPP_MAMBA", + "1" if use_cpp_mamba else "0") + with LLM( f"{llm_models_root()}/Nemotron-SuperV3-phase1-mtp-fp8-fp8kv", kv_cache_config=KvCacheConfig( diff --git a/tests/integration/test_lists/qa/llm_function_core.txt 
b/tests/integration/test_lists/qa/llm_function_core.txt index 088ba536c7..de47aab6ec 100644 --- a/tests/integration/test_lists/qa/llm_function_core.txt +++ b/tests/integration/test_lists/qa/llm_function_core.txt @@ -264,8 +264,10 @@ accuracy/test_llm_api_pytorch.py::TestNemotronV3Super::test_auto_dtype_4gpus[4-1 accuracy/test_llm_api_pytorch.py::TestNemotronV3Super::test_auto_dtype_4gpus[4-1-False-False-True] accuracy/test_llm_api_pytorch.py::TestNemotronV3Super::test_auto_dtype_4gpus[4-4-True-False-True] accuracy/test_llm_api_pytorch.py::TestNemotronV3Super::test_auto_dtype_4gpus[4-4-True-True-True] -accuracy/test_llm_api_pytorch.py::TestNemotronV3Super::test_fp8_4gpus[attention_dp_on] -accuracy/test_llm_api_pytorch.py::TestNemotronV3Super::test_fp8_4gpus[attention_dp_off] +accuracy/test_llm_api_pytorch.py::TestNemotronV3Super::test_fp8_4gpus[attention_dp_off-python_mamba_cache] +accuracy/test_llm_api_pytorch.py::TestNemotronV3Super::test_fp8_4gpus[attention_dp_off-cpp_mamba_cache] +accuracy/test_llm_api_pytorch.py::TestNemotronV3Super::test_fp8_4gpus[attention_dp_on-python_mamba_cache] +accuracy/test_llm_api_pytorch.py::TestNemotronV3Super::test_fp8_4gpus[attention_dp_on-cpp_mamba_cache] accuracy/test_llm_api_pytorch.py::TestNemotronV3Super::test_nvfp4_8gpus[attention_dp_on-trtllm] # multimodal accuracy tests diff --git a/tests/integration/test_lists/test-db/l0_dgx_h100.yml b/tests/integration/test_lists/test-db/l0_dgx_h100.yml index 398342c0a3..aa9d560d77 100644 --- a/tests/integration/test_lists/test-db/l0_dgx_h100.yml +++ b/tests/integration/test_lists/test-db/l0_dgx_h100.yml @@ -104,8 +104,10 @@ l0_dgx_h100: - accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_fp8_4gpus[tp2pp2-fp8kv=True-attn_backend=TRTLLM-torch_compile=False] - accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_fp8_4gpus[tp2pp2-fp8kv=True-attn_backend=TRTLLM-torch_compile=True] - accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_guided_decoding_4gpus[xgrammar] - - accuracy/test_llm_api_pytorch.py::TestNemotronV3Super::test_fp8_4gpus[attention_dp_off] - - accuracy/test_llm_api_pytorch.py::TestNemotronV3Super::test_fp8_4gpus[attention_dp_on] + - accuracy/test_llm_api_pytorch.py::TestNemotronV3Super::test_fp8_4gpus[attention_dp_off-python_mamba_cache] + - accuracy/test_llm_api_pytorch.py::TestNemotronV3Super::test_fp8_4gpus[attention_dp_off-cpp_mamba_cache] + - accuracy/test_llm_api_pytorch.py::TestNemotronV3Super::test_fp8_4gpus[attention_dp_on-python_mamba_cache] + - accuracy/test_llm_api_pytorch.py::TestNemotronV3Super::test_fp8_4gpus[attention_dp_on-cpp_mamba_cache] - test_e2e.py::test_ptp_quickstart_advanced_bs1 - test_e2e.py::test_ptp_quickstart_advanced_deepseek_v3_lite_4gpus_adp_balance[DeepSeek-V3-Lite-FP8-DeepSeek-V3-Lite/fp8] - test_e2e.py::test_trtllm_bench_llmapi_launch[pytorch_backend-llama-v3-llama3-8b]
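
Reviewer note (not part of the patch): the C++ cache manager is opt-in via the TRTLLM_USE_CPP_MAMBA environment variable read by use_cpp_mamba_cache_manager(); the Python manager remains the default. Below is a minimal sketch of how the toggle might be exercised end to end. The checkpoint path and prompt are placeholders, and the LLM/SamplingParams usage assumes the existing tensorrt_llm LLM API rather than anything introduced by this change.

# Illustrative sketch only: exercises the opt-in C++ MambaCacheManager path.
# The model path below is a placeholder; TRTLLM_USE_CPP_MAMBA comes from this
# patch, LLM/SamplingParams from the existing tensorrt_llm package.
import os

# Opt in to the C++-backed RnnStateManager; "0" (the default) keeps the
# Python cache manager.
os.environ["TRTLLM_USE_CPP_MAMBA"] = "1"

from tensorrt_llm import LLM, SamplingParams


def main():
    llm = LLM(model="/path/to/a/mamba-hybrid-checkpoint")  # placeholder path
    outputs = llm.generate(["The capital of France is"],
                           SamplingParams(max_tokens=16))
    print(outputs[0].outputs[0].text)


if __name__ == "__main__":
    main()

The new python_mamba_cache/cpp_mamba_cache parametrization of test_fp8_4gpus covers both settings of this toggle through monkeypatch.setenv.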