Mirror of https://github.com/NVIDIA/TensorRT-LLM.git (synced 2026-02-16 07:53:55 +08:00)

[TRTLLM-10273][feat] Move MambaCacheManager from Python to C++ (#10540)
Signed-off-by: Iman Tabrizian <10105175+tabrizian@users.noreply.github.com>
This commit is contained in: parent d6e49542bd, commit 7d992972b2
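The commit keeps the Python cache manager as the default and gates the new C++ path behind an environment variable. A minimal sketch of how a caller would opt in, based on the use_cpp_mamba_cache_manager() helper added in this commit (nothing else here is new API):

import os

# Opt in to the C++ RnnStateManager-backed cache before constructing the executor;
# the default ('0') keeps the existing PythonMambaCacheManager.
os.environ["TRTLLM_USE_CPP_MAMBA"] = "1"

from tensorrt_llm._torch.pyexecutor.mamba_cache_manager import use_cpp_mamba_cache_manager

assert use_cpp_mamba_cache_manager()  # True only when TRTLLM_USE_CPP_MAMBA='1'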
@@ -16,11 +16,16 @@

#pragma once

#include "tensorrt_llm/batch_manager/common.h"
#include "tensorrt_llm/runtime/bufferManager.h"
#include "tensorrt_llm/runtime/iTensor.h"
#include "tensorrt_llm/runtime/modelConfig.h"
#include "tensorrt_llm/runtime/worldConfig.h"

#include <optional>
#include <unordered_map>
#include <vector>

namespace tensorrt_llm::batch_manager::rnn_state_manager
{

@@ -30,16 +35,34 @@ public:
    using TensorPtr = runtime::ITensor::SharedPtr;
    using SizeType32 = tensorrt_llm::runtime::SizeType32;
    using TensorMap = runtime::StringPtrMap<runtime::ITensor>;
    using RequestIdType = tensorrt_llm::batch_manager::RequestIdType;

    RnnStateManager(SizeType32 maxNumSequences, tensorrt_llm::runtime::ModelConfig const& modelConfig,
        runtime::WorldConfig const& worldConfig, tensorrt_llm::runtime::BufferManager const& bufferManager);

    RnnStateManager(SizeType32 dState, SizeType32 dConv, SizeType32 numHeads, SizeType32 nGroups, SizeType32 headDim,
        SizeType32 maxBatchSize, runtime::WorldConfig const& worldConfig, int64_t stream, nvinfer1::DataType dtype,
        nvinfer1::DataType ssmCacheDtype, std::vector<SizeType32> const& ppLayers);

    void getPtrBuffers(TensorMap& inputBuffers, runtime::ModelConfig const& modelConfig,
        runtime::WorldConfig const& worldConfig) const;

    void fillSlotMapping(
        runtime::ITensor& dstPointers, SizeType32 dstSlotOffset, SizeType32 seqSlotIdx, SizeType32 beamWidth) const;

    void allocateCacheBlocks(std::vector<RequestIdType> const& requestIds);

    void freeCacheBlock(RequestIdType requestId);

    [[nodiscard]] SizeType32 getCacheIndex(RequestIdType requestId) const;

    [[nodiscard]] std::vector<SizeType32> getStateIndices(
        std::vector<RequestIdType> const& requestIds, std::vector<bool> const& isPadding);

    [[nodiscard]] TensorPtr getConvStates(SizeType32 layerIdx) const;

    [[nodiscard]] TensorPtr getSsmStates(SizeType32 layerIdx) const;

private:
    // If we need to support beam search, we may need mMaxBeamWidth + 1 slots and use separate input / output states.
    TensorPtr pagedRnnStates; // [local_nb_layers, max_seq_num * max_beam_width, state_size, rnn_hidden_size] or

@@ -55,6 +78,10 @@ private:
    SizeType32 mMaxNumSequences = 0;
    SizeType32 mMaxBeamWidth = 0;
    SizeType32 mBeamSlotsPerSequence = 0;
    std::unordered_map<SizeType32, SizeType32> mLayerOffsets;
    std::vector<SizeType32> mFreeBlocks;
    std::unordered_map<RequestIdType, SizeType32> mCacheIndex;
    std::optional<runtime::BufferManager> mBufferManager;
};

} // namespace tensorrt_llm::batch_manager::rnn_state_manager
@@ -17,8 +17,11 @@

#include "tensorrt_llm/batch_manager/rnnStateManager.h"
#include "tensorrt_llm/common/assert.h"
#include "tensorrt_llm/runtime/cudaStream.h"
#include "tensorrt_llm/runtime/utils/runtimeUtils.h"

#include <unordered_set>

using namespace tensorrt_llm::runtime;

namespace tensorrt_llm::batch_manager::rnn_state_manager

@@ -82,6 +85,64 @@ RnnStateManager::RnnStateManager(SizeType32 maxNumSequences, tensorrt_llm::runti
    }
}

RnnStateManager::RnnStateManager(SizeType32 dState, SizeType32 dConv, SizeType32 numHeads, SizeType32 nGroups,
    SizeType32 headDim, SizeType32 maxBatchSize, WorldConfig const& worldConfig, int64_t stream,
    nvinfer1::DataType dtype, nvinfer1::DataType ssmCacheDtype, std::vector<SizeType32> const& ppLayers)
    : mMaxNumSequences(maxBatchSize)
    , mMaxBeamWidth{1}
    , mBeamSlotsPerSequence{1}
    , mBufferManager{std::make_shared<CudaStream>(reinterpret_cast<cudaStream_t>(stream))}
{
    auto const tpSize = worldConfig.getTensorParallelism();

    auto const dInner = headDim * numHeads;
    auto convDim = dInner + 2 * nGroups * dState;
    auto nheads = numHeads;

    TLLM_CHECK_WITH_INFO(nheads % tpSize == 0, "nheads must be divisible by tp_size");
    TLLM_CHECK_WITH_INFO(convDim % tpSize == 0, "conv_dim must be divisible by tp_size");

    convDim = convDim / tpSize;
    nheads = nheads / tpSize;

    auto const numLocalLayers = static_cast<SizeType32>(ppLayers.size());

    for (SizeType32 offset = 0; offset < numLocalLayers; ++offset)
    {
        mLayerOffsets[ppLayers[offset]] = offset;
    }

    auto const convStateShape = ITensor::makeShape({numLocalLayers, maxBatchSize, convDim, dConv - 1});
    pagedConvStates = mBufferManager->gpu(convStateShape, dtype);

    auto const rnnStateShape = ITensor::makeShape({numLocalLayers, maxBatchSize, nheads, headDim, dState});
    pagedRnnStates = mBufferManager->gpu(rnnStateShape, ssmCacheDtype);

    mFreeBlocks.reserve(maxBatchSize);
    for (SizeType32 i = 0; i < maxBatchSize; ++i)
    {
        mFreeBlocks.push_back(i);
    }

    auto const statePtrsShape = ITensor::makeShape({numLocalLayers});
    rnnStatePtrs = BufferManager::cpu(statePtrsShape, TRTDataType<void*>::value);
    convStatePtrs = BufferManager::cpu(statePtrsShape, TRTDataType<void*>::value);
    auto* rnnStatePtrArray = bufferCast<void*>(*rnnStatePtrs);
    auto* convStatePtrArray = bufferCast<void*>(*convStatePtrs);

    rnnStatePtr.resize(numLocalLayers);
    convStatePtr.resize(numLocalLayers);
    for (SizeType32 i = 0; i < numLocalLayers; i++)
    {
        auto layerRnnStates = ITensor::slice(pagedRnnStates, i, 1);
        auto layerConvStates = ITensor::slice(pagedConvStates, i, 1);
        rnnStatePtrArray[i] = layerRnnStates->data();
        convStatePtrArray[i] = layerConvStates->data();
        rnnStatePtr[i] = ITensor::slice(rnnStatePtrs, i, 1);
        convStatePtr[i] = ITensor::slice(convStatePtrs, i, 1);
    }
}

void RnnStateManager::getPtrBuffers(
    TensorMap& inputBuffers, runtime::ModelConfig const& modelConfig, runtime::WorldConfig const& worldConfig) const
{
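For reference, a small Python sketch of the per-rank cache shapes the new constructor allocates; the model dimensions below are hypothetical, only the shape arithmetic mirrors the C++ above:

# Hypothetical Mamba2-style dimensions, for illustration only.
d_state, d_conv, num_heads, n_groups, head_dim = 128, 4, 128, 8, 64
max_batch_size, tp_size, num_local_layers = 32, 4, 26

d_inner = head_dim * num_heads
conv_dim = (d_inner + 2 * n_groups * d_state) // tp_size  # must divide evenly, as the TLLM_CHECKs enforce
nheads = num_heads // tp_size

conv_state_shape = (num_local_layers, max_batch_size, conv_dim, d_conv - 1)
ssm_state_shape = (num_local_layers, max_batch_size, nheads, head_dim, d_state)
print(conv_state_shape, ssm_state_shape)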
@@ -113,4 +174,95 @@ void RnnStateManager::fillSlotMapping(
    }
}

void RnnStateManager::allocateCacheBlocks(std::vector<RequestIdType> const& requestIds)
{
    for (auto const& requestId : requestIds)
    {
        auto it = mCacheIndex.find(requestId);
        if (it == mCacheIndex.end())
        {
            TLLM_CHECK_WITH_INFO(!mFreeBlocks.empty(), "Run out of RNN state cache blocks");
            SizeType32 const block = mFreeBlocks.back();
            mFreeBlocks.pop_back();
            mCacheIndex[requestId] = block;
        }
    }
}

void RnnStateManager::freeCacheBlock(RequestIdType requestId)
{
    auto it = mCacheIndex.find(requestId);
    if (it != mCacheIndex.end())
    {
        mFreeBlocks.push_back(it->second);
        mCacheIndex.erase(it);
    }
}

RnnStateManager::SizeType32 RnnStateManager::getCacheIndex(RequestIdType requestId) const
{
    auto it = mCacheIndex.find(requestId);
    TLLM_CHECK_WITH_INFO(it != mCacheIndex.end(), "Request ID not found in cache index");
    return it->second;
}

std::vector<RnnStateManager::SizeType32> RnnStateManager::getStateIndices(
    std::vector<RequestIdType> const& requestIds, std::vector<bool> const& isPadding)
{
    TLLM_CHECK_WITH_INFO(requestIds.size() == isPadding.size(), "requestIds and isPadding must have the same size");

    std::unordered_set<SizeType32> availableSlots;
    availableSlots.reserve(mMaxNumSequences);
    for (SizeType32 i = 0; i < mMaxNumSequences; ++i)
    {
        availableSlots.insert(i);
    }

    for (size_t i = 0; i < requestIds.size(); ++i)
    {
        if (!isPadding[i])
        {
            availableSlots.erase(getCacheIndex(requestIds[i]));
        }
    }

    std::vector<SizeType32> result;
    result.reserve(requestIds.size());
    auto availableIt = availableSlots.begin();

    for (size_t i = 0; i < requestIds.size(); ++i)
    {
        if (isPadding[i])
        {
            TLLM_CHECK_WITH_INFO(availableIt != availableSlots.end(), "Run out of available slots for padding");
            result.push_back(*availableIt);
            ++availableIt;
        }
        else
        {
            result.push_back(getCacheIndex(requestIds[i]));
        }
    }

    return result;
}

RnnStateManager::TensorPtr RnnStateManager::getConvStates(SizeType32 layerIdx) const
{
    auto it = mLayerOffsets.find(layerIdx);
    TLLM_CHECK_WITH_INFO(it != mLayerOffsets.end(), "Layer index not found in layer offsets");
    auto result = ITensor::slice(pagedConvStates, it->second, 1);
    result->squeeze(0);
    return result;
}

RnnStateManager::TensorPtr RnnStateManager::getSsmStates(SizeType32 layerIdx) const
{
    auto it = mLayerOffsets.find(layerIdx);
    TLLM_CHECK_WITH_INFO(it != mLayerOffsets.end(), "Layer index not found in layer offsets");
    auto result = ITensor::slice(pagedRnnStates, it->second, 1);
    result->squeeze(0);
    return result;
}

} // namespace tensorrt_llm::batch_manager::rnn_state_manager
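To make the padding behavior of getStateIndices concrete, here is a tiny Python model of the slot-selection logic (illustration only; the real manager iterates an unordered_set, so which free slot a padding request receives is unspecified):

def get_state_indices(cache_index, max_num_sequences, request_ids, is_padding):
    # Slots not owned by any live (non-padding) request in this batch.
    available = set(range(max_num_sequences))
    for rid, pad in zip(request_ids, is_padding):
        if not pad:
            available.discard(cache_index[rid])
    free = iter(sorted(available))
    # Live requests keep their allocated slot; padding requests borrow a free one.
    return [next(free) if pad else cache_index[rid]
            for rid, pad in zip(request_ids, is_padding)]

# Example: requests 7 and 9 hold slots 0 and 2; the third entry is a CUDA graph padding request.
print(get_state_indices({7: 0, 9: 2}, 4, [7, 9, 12345], [False, False, True]))  # -> [0, 2, 1]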
@@ -416,7 +416,36 @@ void initBindings(nb::module_& m)

    nb::class_<tb::rnn_state_manager::RnnStateManager>(m, "RnnStateManager")
        .def(nb::init<tr::SizeType32, tr::ModelConfig, tr::WorldConfig, tr::BufferManager>(),
            nb::arg("max_num_sequences"), nb::arg("model_config"), nb::arg("world_config"), nb::arg("buffer_manager"));
            nb::arg("max_num_sequences"), nb::arg("model_config"), nb::arg("world_config"), nb::arg("buffer_manager"),
            nb::call_guard<nb::gil_scoped_release>())
        .def(nb::init<tr::SizeType32, tr::SizeType32, tr::SizeType32, tr::SizeType32, tr::SizeType32, tr::SizeType32,
                 tr::WorldConfig const&, int64_t, nvinfer1::DataType, nvinfer1::DataType,
                 std::vector<tr::SizeType32> const&>(),
            nb::arg("d_state"), nb::arg("d_conv"), nb::arg("num_heads"), nb::arg("n_groups"), nb::arg("head_dim"),
            nb::arg("max_batch_size"), nb::arg("world_config"), nb::arg("stream"), nb::arg("dtype"),
            nb::arg("ssm_cache_dtype"), nb::arg("pp_layers"), nb::call_guard<nb::gil_scoped_release>())
        .def(
            "get_conv_states",
            [](tb::rnn_state_manager::RnnStateManager& self, tr::SizeType32 layerIdx) -> at::Tensor
            {
                auto tensor = self.getConvStates(layerIdx);
                return tr::Torch::tensor(tensor);
            },
            nb::arg("layer_idx"), nb::call_guard<nb::gil_scoped_release>())
        .def(
            "get_ssm_states",
            [](tb::rnn_state_manager::RnnStateManager& self, tr::SizeType32 layerIdx) -> at::Tensor
            {
                auto tensor = self.getSsmStates(layerIdx);
                return tr::Torch::tensor(tensor);
            },
            nb::arg("layer_idx"), nb::call_guard<nb::gil_scoped_release>())
        .def("allocate_cache_blocks", &tb::rnn_state_manager::RnnStateManager::allocateCacheBlocks,
            nb::arg("request_ids"), nb::call_guard<nb::gil_scoped_release>())
        .def("free_cache_block", &tb::rnn_state_manager::RnnStateManager::freeCacheBlock, nb::arg("request_id"),
            nb::call_guard<nb::gil_scoped_release>())
        .def("get_state_indices", &tb::rnn_state_manager::RnnStateManager::getStateIndices, nb::arg("request_ids"),
            nb::arg("is_padding"), nb::call_guard<nb::gil_scoped_release>());

    m.def(
        "add_new_tokens_to_requests",
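The new keyword-only binding is used by CppMambaCacheManager further down in this commit. A rough usage sketch, assuming the bindings are built and using placeholder dimensions (argument names follow the nb::arg(...) list above):

import torch
import tensorrt_llm.bindings as _tb
from tensorrt_llm._utils import torch_dtype_to_binding

# Hypothetical single-rank setup.
world_config = _tb.WorldConfig(tensor_parallelism=1, pipeline_parallelism=1, rank=0, gpus_per_node=1)
mgr = _tb.internal.batch_manager.RnnStateManager(
    d_state=128, d_conv=4, num_heads=8, n_groups=1, head_dim=64,
    max_batch_size=8, world_config=world_config,
    stream=torch.cuda.current_stream().cuda_stream,
    dtype=torch_dtype_to_binding(torch.float16),
    ssm_cache_dtype=torch_dtype_to_binding(torch.float16),
    pp_layers=list(range(4)))

mgr.allocate_cache_blocks(request_ids=[101, 102])
print(mgr.get_state_indices(request_ids=[101, 102], is_padding=[False, False]))
conv_states = mgr.get_conv_states(layer_idx=0)  # torch.Tensor view of that layer's conv cache
mgr.free_cache_block(request_id=101)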
@@ -204,6 +204,7 @@ class FlashInferAttentionMetadata(AttentionMetadata):
        return self.kv_cache_manager.tokens_per_block

    def prepare(self) -> None:
        super().prepare()
        extra_attrs = get_model_extra_attrs()
        if extra_attrs is None:
            get_global_attrs().attention_metadata = weakref.ref(self)

@@ -3,7 +3,7 @@ import weakref
from collections import namedtuple
from dataclasses import dataclass, field
from enum import Enum, IntEnum
from typing import (TYPE_CHECKING, Dict, Generic, List, Optional, Protocol,
from typing import (TYPE_CHECKING, Any, Dict, Generic, List, Optional, Protocol,
                    Tuple, Type, TypeVar, Union)

import torch

@@ -21,6 +21,7 @@ from tensorrt_llm.models.modeling_utils import QuantConfig

from ..memory_buffer_utils import Buffers
from ..metadata import KVCacheParams
from ..pyexecutor.mamba_cache_manager import MambaCacheManager
from ..pyexecutor.resource_manager import KVCacheManager, KVCacheManagerV2
from ..utils import get_model_extra_attrs

@@ -147,6 +148,9 @@ class AttentionMetadata:
    _num_ctx_tokens: int = field(init=False, default=0, repr=False)
    _num_tokens: int = field(init=False, default=0, repr=False)

    mamba_metadata: Optional[Any] = None
    mamba_chunk_size: int = 128

    # The number of tokens in the padded sequence.
    padded_num_tokens: Optional[int] = None
@@ -290,6 +294,23 @@ class AttentionMetadata:
        """
        Hook to be called before the forward step of the model.
        """
        self._prepare_mamba_metadata()

    def _prepare_mamba_metadata(self):
        if self.mamba_metadata is False:
            return

        if self.mamba_metadata is None:
            if (self.kv_cache_manager is not None
                    and isinstance(self.kv_cache_manager, MambaCacheManager)):
                from ..modules.mamba.mamba2_metadata import Mamba2Metadata
                self.mamba_metadata = Mamba2Metadata(self.max_num_requests,
                                                     self.mamba_chunk_size)
            else:
                self.mamba_metadata = False
                return

        self.mamba_metadata.prepare(self)

    def create_cuda_graph_metadata(self,
                                   max_batch_size: int,
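The None / False / object handling in _prepare_mamba_metadata above is a memoized capability check: None means "not probed yet", False means "probed, not a Mamba model", and an object means "probed, reuse it". A generic sketch of the same idiom, with hypothetical names:

from typing import Optional, Union

class LazyFeature:
    """Caches an expensive, possibly-unavailable helper after the first probe."""

    def __init__(self) -> None:
        self._helper: Union[None, bool, dict] = None  # None=unknown, False=unavailable, dict=ready

    def prepare(self, available: bool) -> Optional[dict]:
        if self._helper is False:       # probed earlier: feature not applicable
            return None
        if self._helper is None:        # first call: decide once
            self._helper = {"state": 0} if available else False
            if self._helper is False:
                return None
        self._helper["state"] += 1      # normal path: reuse the cached helper
        return self._helper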
@@ -933,6 +933,7 @@ class TrtllmAttentionMetadata(AttentionMetadata):
            self.helix_is_inactive_rank_cpu[:batch_size], non_blocking=True)

    def prepare(self) -> None:
        super().prepare()
        extra_attrs = get_model_extra_attrs()
        # If model extra attrs is set, attention_metadata is setup in executor.
        if extra_attrs is None:

@@ -59,6 +59,7 @@ def generate_sliding_window_mask(batch_size: int, target_length: int,
class VanillaAttentionMetadata(AttentionMetadata):

    def prepare(self) -> None:
        super().prepare()
        # indices of used cache blocks for each sequence
        assert self.request_ids is not None
        self.block_ids_per_seq = self.kv_cache_manager.get_batch_cache_indices(

@@ -22,7 +22,6 @@ from transformers import AutoConfig, PretrainedConfig

from tensorrt_llm._torch.models.checkpoints.base_weight_mapper import \
    BaseWeightMapper
from tensorrt_llm._torch.modules.mamba.mamba2_metadata import Mamba2Metadata
from tensorrt_llm._torch.utils import ActivationType, relu2

from ..attention_backend import AttentionMetadata

@@ -424,8 +423,6 @@ class NemotronHModel(DecoderModel):
            dtype=config.torch_dtype,
        )

        self.mamba_metadata: Optional[Mamba2Metadata] = None

    def forward(
        self,
        attn_metadata: AttentionMetadata,

@@ -440,11 +437,7 @@ class NemotronHModel(DecoderModel):
                "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one"
            )

        if self.mamba_metadata is None or self.mamba_metadata.max_batch_size != attn_metadata.max_num_requests:
            self.mamba_metadata = Mamba2Metadata(
                attn_metadata.max_num_requests,
                chunk_size=self.model_config.pretrained_config.chunk_size)
        self.mamba_metadata.prepare(attn_metadata)
        mamba_metadata = attn_metadata.mamba_metadata

        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)

@@ -456,7 +449,7 @@ class NemotronHModel(DecoderModel):
            hidden_states,
            attn_metadata,
            spec_metadata=spec_metadata,
            mamba_metadata=self.mamba_metadata)
            mamba_metadata=mamba_metadata)

        hidden_states = self.norm_f(hidden_states)
@@ -31,6 +31,8 @@ from tensorrt_llm._torch.modules.fla.chunk import chunk_gated_delta_rule
from tensorrt_llm._torch.modules.fla.fused_sigmoid_gating_recurrent import \
    fused_sigmoid_gating_delta_rule_update
from tensorrt_llm._torch.modules.mamba.mamba2_metadata import Mamba2Metadata
from tensorrt_llm._torch.pyexecutor.mamba_cache_manager import \
    use_cpp_mamba_cache_manager
from tensorrt_llm.mapping import Mapping

from ..attention_backend import AttentionMetadata

@@ -754,8 +756,12 @@ class Qwen3NextGatedDeltaNet(nn.Module):
        batch_split_size = [num_prefills, num_decodes]
        has_initial_states = mamba_metadata.has_initial_states

        state_indices = attn_metadata.kv_cache_manager.get_state_indices(
        )[:num_prefills + num_decodes]
        batch_size = num_prefills + num_decodes
        if use_cpp_mamba_cache_manager():
            state_indices = mamba_metadata.state_indices[:batch_size]
        else:
            state_indices = attn_metadata.kv_cache_manager.get_state_indices(
            )[:batch_size]

        state_indices_p, state_indices_d = torch.split(state_indices,
                                                       batch_split_size)

@@ -1197,8 +1203,6 @@ class Qwen3NextModel(DecoderModel):
            use_gemma=True,
        )

        self.mamba_metadata: Optional[Mamba2Metadata] = None

    def forward(
        self,
        attn_metadata: AttentionMetadata,

@@ -1213,13 +1217,10 @@ class Qwen3NextModel(DecoderModel):
                "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one"
            )

        if self.mamba_metadata is None or self.mamba_metadata.max_batch_size != attn_metadata.max_num_requests:
            self.mamba_metadata = Mamba2Metadata(
                attn_metadata.max_num_requests,
                # chunk_size=self.model_config.pretrained_config.mamba2_chunk_size)
                # TODO check how to get the correct chunk_size
                chunk_size=128)
        self.mamba_metadata.prepare(attn_metadata)
        mamba_metadata = attn_metadata.mamba_metadata
        if mamba_metadata.max_batch_size != attn_metadata.max_num_requests:
            attn_metadata.mamba_metadata = Mamba2Metadata(
                attn_metadata.max_num_requests, chunk_size=128)

        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)

@@ -1233,7 +1234,7 @@ class Qwen3NextModel(DecoderModel):
                attn_metadata=attn_metadata,
                residual=residual,
                spec_metadata=spec_metadata,
                mamba_metadata=self.mamba_metadata)
                mamba_metadata=mamba_metadata)
        return hidden_states
@@ -19,6 +19,10 @@ from typing import Tuple
import torch

from tensorrt_llm._torch.attention_backend.interface import AttentionMetadata
from tensorrt_llm._torch.pyexecutor.cuda_graph_runner import \
    CUDA_GRAPH_DUMMY_REQUEST_ID
from tensorrt_llm._torch.pyexecutor.mamba_cache_manager import \
    use_cpp_mamba_cache_manager


def cu_seqlens_to_chunk_indices_offsets(

@@ -107,11 +111,47 @@ class Mamba2Metadata:
        self.chunk_indices: torch.Tensor = None
        self.chunk_offsets: torch.Tensor = None

        self.state_indices_cpu = torch.zeros(max_batch_size,
                                             dtype=torch.int32,
                                             pin_memory=True)
        self.state_indices = torch.zeros(max_batch_size,
                                         dtype=torch.int32,
                                         device="cuda")
        self._query_start_loc_long_buf = torch.arange(0,
                                                      max_batch_size + 1,
                                                      dtype=torch.long,
                                                      device="cuda")
        self._query_start_loc_buf = torch.zeros(max_batch_size + 1,
                                                dtype=torch.int,
                                                device="cuda")
        self.query_start_loc_long = self._query_start_loc_long_buf
        self.query_start_loc = self._query_start_loc_buf

    def prepare(self, attn_metadata: AttentionMetadata):
        batch_size = attn_metadata.seq_lens.shape[0]
        num_contexts = attn_metadata.num_contexts
        context_lens = attn_metadata.seq_lens_cuda[:num_contexts]
        num_ctx_tokens = attn_metadata.num_ctx_tokens

        kv_cache_manager = attn_metadata.kv_cache_manager
        request_ids = attn_metadata.request_ids

        if (kv_cache_manager is not None
                and hasattr(kv_cache_manager, 'get_state_indices')
                and request_ids is not None):
            if use_cpp_mamba_cache_manager():
                batch_request_ids = request_ids[:batch_size]
                is_padding = [
                    req_id == CUDA_GRAPH_DUMMY_REQUEST_ID
                    for req_id in batch_request_ids
                ]
                indices = kv_cache_manager.get_state_indices(
                    batch_request_ids, is_padding)
                for i, idx in enumerate(indices):
                    self.state_indices_cpu[i] = idx
                self.state_indices[:batch_size].copy_(
                    self.state_indices_cpu[:batch_size], non_blocking=True)

        if num_contexts > 0:
            torch.cumsum(context_lens,
                         dim=0,

@@ -125,8 +165,14 @@ class Mamba2Metadata:
                         out=self.cu_seqlens[num_contexts + 1:batch_size + 1])
            # Need both `query_start_loc` and `query_start_loc_long` because `causal_conv1d_fn`
            # accepts only `int32` while `chunk_gated_delta_rule` accepts only `long`.
            self.query_start_loc = self.cu_seqlens[:batch_size + 1]
            self.query_start_loc_long = self.query_start_loc.to(torch.long)
            self._query_start_loc_buf[:batch_size +
                                      1] = self.cu_seqlens[:batch_size + 1]
            self.query_start_loc = self._query_start_loc_buf[:batch_size + 1]
            self._query_start_loc_long_buf[:batch_size + 1].copy_(
                self.query_start_loc.to(torch.long), non_blocking=True)
            self.query_start_loc_long = self._query_start_loc_long_buf[:
                                                                       batch_size
                                                                       + 1]
            self.seq_idx = torch.repeat_interleave(
                torch.arange(num_contexts,
                             dtype=torch.int,

@@ -148,8 +194,11 @@ class Mamba2Metadata:
                self.chunk_offsets = None
        else:
            self.query_start_loc = None
            self.query_start_loc_long = torch.arange(
                0,
                batch_size + 1,
                dtype=torch.long,
                device=self.cu_seqlens.device)
            torch.arange(0,
                         batch_size + 1,
                         dtype=torch.long,
                         device=self.cu_seqlens.device,
                         out=self._query_start_loc_long_buf[:batch_size + 1])
            self.query_start_loc_long = self._query_start_loc_long_buf[:
                                                                       batch_size
                                                                       + 1]
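The Mamba2Metadata buffer changes above follow a common CUDA-graph-friendly pattern: stage host values in a pinned CPU tensor, copy them asynchronously into a preallocated device buffer, and hand out slices of that buffer instead of allocating new tensors each step. A generic sketch of the pattern (names are illustrative, not part of the API):

import torch

max_batch_size = 64
staging_cpu = torch.zeros(max_batch_size, dtype=torch.int32, pin_memory=True)
persistent_gpu = torch.zeros(max_batch_size, dtype=torch.int32, device="cuda")

def publish(values):
    # Write into the same device storage every step so captured CUDA graphs keep
    # reading a valid address; only the contents change between replays.
    n = len(values)
    staging_cpu[:n] = torch.as_tensor(values, dtype=torch.int32)
    persistent_gpu[:n].copy_(staging_cpu[:n], non_blocking=True)
    return persistent_gpu[:n]  # a view, no new allocation

indices = publish([3, 0, 5])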
@@ -21,6 +21,8 @@ from flashinfer.mamba import selective_state_update as selective_state_update_fi
from torch import nn

from tensorrt_llm._torch.modules.mamba.mamba2_metadata import Mamba2Metadata
from tensorrt_llm._torch.pyexecutor.mamba_cache_manager import \
    use_cpp_mamba_cache_manager
from tensorrt_llm.logger import logger
from tensorrt_llm.mapping import Mapping
@@ -204,15 +206,24 @@ class Mamba2Mixer(nn.Module):
        seqlen_split_size = [num_prefill_tokens, num_decode_tokens]
        batch_split_size = [num_prefills, num_decodes]

        state_indices = attn_metadata.kv_cache_manager.get_state_indices(
        )[:num_prefills + num_decodes]
        if use_cpp_mamba_cache_manager():
            state_indices = mamba_metadata.state_indices[:num_prefills +
                                                         num_decodes]
            conv_states = attn_metadata.kv_cache_manager.get_conv_states(
                self.layer_idx)
            ssm_states = attn_metadata.kv_cache_manager.get_ssm_states(
                self.layer_idx)
            layer_cache = None  # Not used in C++ path
        else:
            state_indices = attn_metadata.kv_cache_manager.get_state_indices(
            )[:num_prefills + num_decodes]
            layer_cache = attn_metadata.kv_cache_manager.mamba_layer_cache(
                self.layer_idx)
            conv_states = layer_cache.conv
            ssm_states = layer_cache.temporal

        state_indices_p, state_indices_d = torch.split(state_indices,
                                                       batch_split_size)
        layer_cache = attn_metadata.kv_cache_manager.mamba_layer_cache(
            self.layer_idx)
        conv_states = layer_cache.conv
        ssm_states = layer_cache.temporal

        # in_proj
        zxbcdt = self.in_proj(hidden_states)

@@ -310,6 +321,9 @@ class Mamba2Mixer(nn.Module):
        is_target_verify = attn_metadata.kv_cache_manager.is_speculative(
        ) and spec_metadata is not None
        if is_target_verify:
            # Speculative decoding only supported with Python path
            assert layer_cache is not None, \
                "Speculative decoding requires Python MambaCacheManager"
            # TODO: support dynamic speculation, will add current_draft_len later [TRTLLM-10319]
            draft_token_num = spec_metadata.max_draft_len + 1
            intermediate_conv_states = layer_cache.intermediate_conv_window

@@ -19,7 +19,7 @@ from ..speculative.mtp import SampleStateTensorsMTP
from ..speculative.utils import get_draft_kv_cache_manager
from ..utils import make_weak_ref, piecewise_cuda_graph
from .llm_request import get_draft_token_length
from .mamba_cache_manager import MambaCacheManager
from .mamba_cache_manager import MambaCacheManager, use_cpp_mamba_cache_manager
from .resource_manager import (BaseResourceManager, ResourceManager,
                               ResourceManagerType)
from .sampler import SampleStateTensors

@@ -460,8 +460,8 @@ class CUDAGraphRunner:
            if spec_res_mgr:
                spec_res_mgr.add_dummy_requests([CUDA_GRAPH_DUMMY_REQUEST_ID])

            # handle special cases of padding requests + MambaCacheManager or MambaHybridCacheManager
            if isinstance(kv_cache_manager, MambaCacheManager):
            if (isinstance(kv_cache_manager, MambaCacheManager)
                    and not use_cpp_mamba_cache_manager()):
                kv_cache_manager.reorder_state_indices_when_padding_requests(
                    batch_size, padding_size)
@@ -13,22 +13,28 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import os
from dataclasses import dataclass
from typing import TYPE_CHECKING, Dict, List, Optional, Union

import torch

from tensorrt_llm._torch.pyexecutor.llm_request import LlmRequest
from tensorrt_llm._torch.pyexecutor.resource_manager import (
    BaseResourceManager, CacheTypeCpp, DataType, KVCacheManager, get_pp_layers)
from tensorrt_llm._torch.pyexecutor.scheduler import ScheduledRequests
from tensorrt_llm.llmapi.llm_args import KvCacheConfig
from tensorrt_llm.logger import logger
from tensorrt_llm.mapping import Mapping
import tensorrt_llm.bindings

if TYPE_CHECKING:
    from tensorrt_llm._torch.attention_backend.interface import \
        AttentionMetadata
from tensorrt_llm._torch.pyexecutor.llm_request import LlmRequest
from tensorrt_llm._torch.pyexecutor.resource_manager import (
    BaseResourceManager, CacheTypeCpp, DataType, KVCacheManager, get_pp_layers)
from tensorrt_llm._torch.pyexecutor.scheduler import ScheduledRequests
from tensorrt_llm._utils import torch_dtype_to_binding
from tensorrt_llm.llmapi.llm_args import KvCacheConfig
from tensorrt_llm.logger import logger
from tensorrt_llm.mapping import Mapping

RnnStateManagerCpp = tensorrt_llm.bindings.internal.batch_manager.RnnStateManager
WorldConfig = tensorrt_llm.bindings.WorldConfig

GB = 1 << 30

@@ -42,7 +48,117 @@ def get_tensor_size_bytes(tensor):
    return 0


class MambaCacheManager(BaseResourceManager):
def use_cpp_mamba_cache_manager() -> bool:
    """Check if C++ MambaCacheManager should be used.

    Returns True if TRTLLM_USE_CPP_MAMBA='1' is set, False otherwise.
    By default, PythonMambaCacheManager is used.
    """
    return os.environ.get('TRTLLM_USE_CPP_MAMBA', '0') == '1'


class CppMambaCacheManager(BaseResourceManager):
    """C++ backed Mamba cache manager using RnnStateManager bindings."""

    def __init__(
        self,
        d_state: int,
        d_conv: int,
        num_heads: int,
        n_groups: int,
        head_dim: int,
        num_layers: int,
        max_num_sequences: int,
        mapping: Mapping,
        dtype: torch.dtype,
        ssm_cache_dtype: torch.dtype,
        layer_mask: Optional[List[bool]] = None,
        stream: Optional[torch.cuda.Stream] = None,
    ) -> None:
        self.mamba_ssm_cache_dtype = ssm_cache_dtype

        # get tp size
        tp_size = mapping.tp_size if not mapping.enable_attention_dp else 1
        world_config = WorldConfig(
            tensor_parallelism=tp_size,
            pipeline_parallelism=mapping.pp_size,
            rank=mapping.rank,
            gpus_per_node=mapping.gpus_per_node,
        )

        dtype_binding = torch_dtype_to_binding(dtype)
        ssm_cache_dtype_binding = torch_dtype_to_binding(
            ssm_cache_dtype if ssm_cache_dtype is not None else dtype)

        self._stream = stream if stream is not None else torch.cuda.current_stream(
        )

        pp_layers, _ = get_pp_layers(num_layers, mapping, layer_mask=layer_mask)

        self.mamba_impl = RnnStateManagerCpp(
            d_state=d_state,
            d_conv=d_conv,
            num_heads=num_heads,
            n_groups=n_groups,
            head_dim=head_dim,
            max_batch_size=max_num_sequences,
            world_config=world_config,
            stream=self._stream.cuda_stream,
            dtype=dtype_binding,
            ssm_cache_dtype=ssm_cache_dtype_binding,
            pp_layers=pp_layers,
        )
        self._max_num_sequences = max_num_sequences

    def get_max_resource_count(self) -> int:
        # Return the maximum number of sequences that can be cached.
        return self._max_num_sequences

    def get_needed_resource_to_completion(self, request: LlmRequest) -> int:
        # For Mamba cache manager, we always need one slot per request.
        return 1

    def prepare_resources(self, scheduled_batch: ScheduledRequests):
        context_ids = [
            i.py_request_id for i in scheduled_batch.context_requests
        ]
        generation_ids = [
            i.py_request_id for i in scheduled_batch.generation_requests
        ]
        request_ids = context_ids + generation_ids
        self.mamba_impl.allocate_cache_blocks(request_ids)

    def free_resources(self, request: LlmRequest):
        self.mamba_impl.free_cache_block(request.py_request_id)

    def add_dummy_requests(self, request_ids: List[int], **kwargs):
        # For CUDA graph dummy requests, the blocks will be allocated
        # when get_state_indices is called.
        from .cuda_graph_runner import CUDA_GRAPH_DUMMY_REQUEST_ID
        request_ids = [
            rid for rid in request_ids if rid != CUDA_GRAPH_DUMMY_REQUEST_ID
        ]
        if request_ids:
            self.mamba_impl.allocate_cache_blocks(request_ids)

    def get_state_indices(self, request_ids: List[int],
                          is_padding: List[bool]) -> List[int]:
        return self.mamba_impl.get_state_indices(request_ids, is_padding)

    def get_conv_states(self, layer_idx: int) -> torch.Tensor:
        return self.mamba_impl.get_conv_states(layer_idx)

    def get_ssm_states(self, layer_idx: int) -> torch.Tensor:
        return self.mamba_impl.get_ssm_states(layer_idx)

    def get_mamba_ssm_cache_dtype(self) -> torch.dtype:
        return self.mamba_ssm_cache_dtype

    def shutdown(self):
        torch.cuda.empty_cache()


class PythonMambaCacheManager(BaseResourceManager):

    @dataclass(frozen=True, kw_only=True)
    class State:
@@ -193,6 +309,17 @@ class MambaCacheManager(BaseResourceManager):
                                          dtype=torch.int32,
                                          device=device)

        # Store max_batch_size for resource management
        self._max_batch_size = max_batch_size

    def get_max_resource_count(self) -> int:
        """Return the maximum number of sequences that can be cached."""
        return self._max_batch_size

    def get_needed_resource_to_completion(self, request: LlmRequest) -> int:
        """For Mamba cache manager, we always need one slot per request."""
        return 1

    @torch.inference_mode()
    def _prepare_mamba_cache_blocks(self, request_ids: List[int]):
        self.state_indices_list.clear()

@@ -261,7 +388,9 @@ class MambaCacheManager(BaseResourceManager):
            block = self.mamba_cache_index.pop(request_id)
            self.mamba_cache_free_blocks.append(block)

    def get_state_indices(self) -> torch.Tensor:
    def get_state_indices(self,
                          request_ids: List[int] = None,
                          is_padding: List[bool] = None) -> torch.Tensor:
        return self.state_indices

    def get_conv_states(self, layer_idx: int) -> torch.Tensor:

@@ -296,6 +425,9 @@ class MambaCacheManager(BaseResourceManager):
        layer_offset = self.mamba_layer_offsets[layer_idx]
        return self.mamba_cache.at_layer_idx(layer_offset)

    def get_mamba_ssm_cache_dtype(self) -> torch.dtype:
        return self.mamba_ssm_cache_dtype

    def shutdown(self):
        """Release tensor memory."""
        # Clear state indices

@@ -346,6 +478,140 @@ class MambaCacheManager(BaseResourceManager):
        conv_states[:, state_indices_d, :] = accepted_conv_state


class MambaCacheManager(BaseResourceManager):

    def __init__(
        self,
        d_state: int,
        d_conv: int,
        num_heads: int,
        n_groups: int,
        head_dim: int,
        num_layers: int,
        max_batch_size: int,
        spec_state_size: int,
        mapping: Mapping,
        dtype: torch.dtype,
        ssm_cache_dtype: torch.dtype,
        layer_mask: Optional[List[bool]] = None,
        stream: Optional[torch.cuda.Stream] = None,
        speculative_num_draft_tokens: Optional[int] = None,
    ) -> None:
        max_num_sequences = max_batch_size * mapping.pp_size
        self._use_cpp = use_cpp_mamba_cache_manager()

        if self._use_cpp:
            assert speculative_num_draft_tokens is None, \
                "speculative_num_draft_tokens is not supported in CppMambaCacheManager"
            self._impl = CppMambaCacheManager(
                d_state=d_state,
                d_conv=d_conv,
                num_heads=num_heads,
                n_groups=n_groups,
                head_dim=head_dim,
                num_layers=num_layers,
                max_num_sequences=max_num_sequences,
                mapping=mapping,
                dtype=dtype,
                ssm_cache_dtype=ssm_cache_dtype,
                layer_mask=layer_mask,
                stream=stream,
            )
        else:
            self._impl = PythonMambaCacheManager(
                d_state=d_state,
                d_conv=d_conv,
                num_heads=num_heads,
                n_groups=n_groups,
                head_dim=head_dim,
                num_layers=num_layers,
                max_batch_size=max_batch_size,
                spec_state_size=spec_state_size,
                mapping=mapping,
                dtype=dtype,
                ssm_cache_dtype=ssm_cache_dtype,
                layer_mask=layer_mask,
                speculative_num_draft_tokens=speculative_num_draft_tokens,
            )

    def get_max_resource_count(self) -> int:
        return self._impl.get_max_resource_count()

    def get_needed_resource_to_completion(self, request: LlmRequest) -> int:
        return self._impl.get_needed_resource_to_completion(request)

    def prepare_resources(self, scheduled_batch: ScheduledRequests):
        self._impl.prepare_resources(scheduled_batch)

    def free_resources(self, request: LlmRequest):
        self._impl.free_resources(request)

    def add_dummy_requests(self, request_ids: List[int], **kwargs):
        if self._use_cpp:
            self._impl.add_dummy_requests(request_ids, **kwargs)

    def get_state_indices(
        self,
        request_ids: Optional[List[int]] = None,
        is_padding: Optional[List[bool]] = None
    ) -> Union[torch.Tensor, List[int]]:
        return self._impl.get_state_indices(request_ids, is_padding)

    def reorder_state_indices_when_padding_requests(self, request_size: int,
                                                    padding_size: int):
        assert not self._use_cpp, "reorder_state_indices_when_padding_requests is not supported in CppMambaCacheManager"
        self._impl.reorder_state_indices_when_padding_requests(
            request_size, padding_size)

    @property
    def mamba_cache_free_blocks(self) -> List[int]:
        assert not self._use_cpp, "mamba_cache_free_blocks is not supported in CppMambaCacheManager"
        return self._impl.mamba_cache_free_blocks

    @property
    def mamba_cache_index(self) -> Dict[int, int]:
        assert not self._use_cpp, "mamba_cache_index is not supported in CppMambaCacheManager"
        return self._impl.mamba_cache_index

    def get_conv_states(self, layer_idx: int) -> torch.Tensor:
        return self._impl.get_conv_states(layer_idx)

    def get_ssm_states(self, layer_idx: int) -> torch.Tensor:
        return self._impl.get_ssm_states(layer_idx)

    def get_mamba_ssm_cache_dtype(self) -> torch.dtype:
        return self._impl.get_mamba_ssm_cache_dtype()

    def get_intermediate_ssm_states(self,
                                    layer_idx: int) -> Optional[torch.Tensor]:
        assert not self._use_cpp, "get_intermediate_ssm_states is not supported in CppMambaCacheManager"
        return self._impl.get_intermediate_ssm_states(layer_idx)

    def get_intermediate_conv_states(self,
                                     layer_idx: int) -> Optional[torch.Tensor]:
        assert not self._use_cpp, "get_intermediate_conv_states is not supported in CppMambaCacheManager"
        return self._impl.get_intermediate_conv_states(layer_idx)

    def is_speculative(self) -> bool:
        assert not self._use_cpp, "is_speculative is not supported in CppMambaCacheManager"
        return self._impl.is_speculative()

    def mamba_layer_cache(
        self, layer_idx: int
    ) -> Union[PythonMambaCacheManager.State,
               PythonMambaCacheManager.SpeculativeState, None]:
        assert not self._use_cpp, "mamba_layer_cache is not supported in CppMambaCacheManager"
        return self._impl.mamba_layer_cache(layer_idx)

    def shutdown(self):
        self._impl.shutdown()

    def update_mamba_states(self, attn_metadata: "AttentionMetadata",
                            num_accepted_tokens: torch.Tensor):
        assert self._use_cpp, "update_mamba_states is not supported in PythonMambaCacheManager"
        self._impl.update_mamba_states(attn_metadata, num_accepted_tokens)


class MambaHybridCacheManager(KVCacheManager, MambaCacheManager):

    def __init__(

@@ -399,6 +665,7 @@ class MambaHybridCacheManager(KVCacheManager, MambaCacheManager):
            mamba_cache_dtype,
            mamba_ssm_cache_dtype,
            mamba_layer_mask,
            execution_stream,
            speculative_num_draft_tokens=spec_config.max_draft_len
            if spec_config is not None else None,
        )

@@ -430,6 +697,10 @@ class MambaHybridCacheManager(KVCacheManager, MambaCacheManager):
        MambaCacheManager.free_resources(self, request)
        KVCacheManager.free_resources(self, request)

    def add_dummy_requests(self, request_ids: List[int], **kwargs):
        MambaCacheManager.add_dummy_requests(self, request_ids)
        return KVCacheManager.add_dummy_requests(self, request_ids, **kwargs)

    def shutdown(self):
        MambaCacheManager.shutdown(self)
        KVCacheManager.shutdown(self)
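Taken together, model code keeps constructing a single MambaCacheManager and the environment variable decides which backend sits behind it. A rough sketch of the selection, with placeholder dimensions and a single-GPU Mapping (illustrative only, not taken from the commit):

import os
import torch
from tensorrt_llm.mapping import Mapping
from tensorrt_llm._torch.pyexecutor.mamba_cache_manager import MambaCacheManager

os.environ["TRTLLM_USE_CPP_MAMBA"] = "1"   # or "0" for the Python backend

mgr = MambaCacheManager(
    d_state=128, d_conv=4, num_heads=8, n_groups=1, head_dim=64,
    num_layers=4, max_batch_size=8, spec_state_size=0,
    mapping=Mapping(world_size=1, rank=0, tp_size=1, pp_size=1),
    dtype=torch.float16, ssm_cache_dtype=torch.float16)

# Same facade API in both modes; backend-specific helpers assert internally.
conv_states = mgr.get_conv_states(layer_idx=0)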
@@ -2777,6 +2777,8 @@ class PyTorchModelEngine(ModelEngine):
            num_extra_kv_tokens=get_num_extra_kv_tokens(spec_config))
        attn_metadata.kv_cache_manager = kv_cache_manager

        if hasattr(self.model.model_config.pretrained_config, 'chunk_size'):
            attn_metadata.mamba_chunk_size = self.model.model_config.pretrained_config.chunk_size
        attn_metadata.prepare()

        peft_cache_manager = resource_manager and resource_manager.get_resource_manager(

@@ -884,7 +884,7 @@ class Runner:
        else:
            raise NotImplementedError("Unsupported config")
        kv_cache_manager.add_dummy_requests(
            list(range(max_batch_size)), [max_seq_len] * max_batch_size
            list(range(max_batch_size)), token_nums=[max_seq_len] * max_batch_size
        )
        return kv_cache_manager

@@ -5690,6 +5690,17 @@ class TestNemotronV3Super(LlmapiAccuracyTestHarness):
    @skip_pre_hopper
    @pytest.mark.skip_less_mpi_world_size(4)
    @pytest.mark.skip_less_device_memory(40000)
    @pytest.mark.parametrize(
        "use_cpp_mamba",
        [
            False,
            True,
        ],
        ids=[
            "python_mamba_cache",
            "cpp_mamba_cache",
        ],
    )
    @pytest.mark.parametrize(
        "attention_dp",
        [

@@ -5701,7 +5712,10 @@ class TestNemotronV3Super(LlmapiAccuracyTestHarness):
            "attention_dp_on",
        ],
    )
    def test_fp8_4gpus(self, attention_dp):
    def test_fp8_4gpus(self, attention_dp, use_cpp_mamba, monkeypatch):
        monkeypatch.setenv("TRTLLM_USE_CPP_MAMBA",
                           "1" if use_cpp_mamba else "0")

        with LLM(
                f"{llm_models_root()}/Nemotron-SuperV3-phase1-mtp-fp8-fp8kv",
                kv_cache_config=KvCacheConfig(

@@ -264,8 +264,10 @@ accuracy/test_llm_api_pytorch.py::TestNemotronV3Super::test_auto_dtype_4gpus[4-1
accuracy/test_llm_api_pytorch.py::TestNemotronV3Super::test_auto_dtype_4gpus[4-1-False-False-True]
accuracy/test_llm_api_pytorch.py::TestNemotronV3Super::test_auto_dtype_4gpus[4-4-True-False-True]
accuracy/test_llm_api_pytorch.py::TestNemotronV3Super::test_auto_dtype_4gpus[4-4-True-True-True]
accuracy/test_llm_api_pytorch.py::TestNemotronV3Super::test_fp8_4gpus[attention_dp_on]
accuracy/test_llm_api_pytorch.py::TestNemotronV3Super::test_fp8_4gpus[attention_dp_off]
accuracy/test_llm_api_pytorch.py::TestNemotronV3Super::test_fp8_4gpus[attention_dp_off-python_mamba_cache]
accuracy/test_llm_api_pytorch.py::TestNemotronV3Super::test_fp8_4gpus[attention_dp_off-cpp_mamba_cache]
accuracy/test_llm_api_pytorch.py::TestNemotronV3Super::test_fp8_4gpus[attention_dp_on-python_mamba_cache]
accuracy/test_llm_api_pytorch.py::TestNemotronV3Super::test_fp8_4gpus[attention_dp_on-cpp_mamba_cache]
accuracy/test_llm_api_pytorch.py::TestNemotronV3Super::test_nvfp4_8gpus[attention_dp_on-trtllm]

# multimodal accuracy tests

@@ -104,8 +104,10 @@ l0_dgx_h100:
- accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_fp8_4gpus[tp2pp2-fp8kv=True-attn_backend=TRTLLM-torch_compile=False]
- accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_fp8_4gpus[tp2pp2-fp8kv=True-attn_backend=TRTLLM-torch_compile=True]
- accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_guided_decoding_4gpus[xgrammar]
- accuracy/test_llm_api_pytorch.py::TestNemotronV3Super::test_fp8_4gpus[attention_dp_off]
- accuracy/test_llm_api_pytorch.py::TestNemotronV3Super::test_fp8_4gpus[attention_dp_on]
- accuracy/test_llm_api_pytorch.py::TestNemotronV3Super::test_fp8_4gpus[attention_dp_off-python_mamba_cache]
- accuracy/test_llm_api_pytorch.py::TestNemotronV3Super::test_fp8_4gpus[attention_dp_off-cpp_mamba_cache]
- accuracy/test_llm_api_pytorch.py::TestNemotronV3Super::test_fp8_4gpus[attention_dp_on-python_mamba_cache]
- accuracy/test_llm_api_pytorch.py::TestNemotronV3Super::test_fp8_4gpus[attention_dp_on-cpp_mamba_cache]
- test_e2e.py::test_ptp_quickstart_advanced_bs1
- test_e2e.py::test_ptp_quickstart_advanced_deepseek_v3_lite_4gpus_adp_balance[DeepSeek-V3-Lite-FP8-DeepSeek-V3-Lite/fp8]
- test_e2e.py::test_trtllm_bench_llmapi_launch[pytorch_backend-llama-v3-llama3-8b]