[TRTLLM-9676][fix] Fix mamba_cache_manager when enabling cuda_graph_padding and let test cover this case (#9873)

Signed-off-by: JadoTu <107457950+JadoTu@users.noreply.github.com>
JadoTu 2026-01-09 14:50:16 +08:00 committed by GitHub
parent c5331e6dbb
commit 4c498bfe58
3 changed files with 55 additions and 10 deletions


@@ -18,6 +18,7 @@ from ..speculative.eagle3 import Eagle3ResourceManager
 from ..speculative.mtp import SampleStateTensorsMTP
 from ..utils import make_weak_ref, piecewise_cuda_graph
 from .llm_request import get_draft_token_length
+from .mamba_cache_manager import MambaCacheManager
 from .resource_manager import (BaseResourceManager, ResourceManager,
                                ResourceManagerType)
 from .sampler import SampleStateTensors
@@ -450,6 +451,11 @@ class CUDAGraphRunner:
         if spec_res_mgr:
             spec_res_mgr.add_dummy_requests([CUDA_GRAPH_DUMMY_REQUEST_ID])
+        # Handle the special case of padding requests with a MambaCacheManager
+        # or MambaHybridCacheManager.
+        if isinstance(kv_cache_manager, MambaCacheManager):
+            kv_cache_manager.reorder_state_indices_when_padding_requests(
+                batch_size, padding_size)
         self.padding_dummy_request.py_draft_tokens = [0] * runtime_draft_len
         batch.generation_requests.extend([self.padding_dummy_request] *
                                          padding_size)
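Why the runner needs this hook: with enable_padding, the same dummy request is appended padding_size times, so without reordering the padded tail of the manager's state_indices tensor keeps whatever stale values it held, which can alias a real request's mamba state block. A minimal standalone sketch of the idea, using illustrative values and local variables rather than the runner's real ones:

    import torch

    # Stand-in for MambaCacheManager.state_indices after block assignment:
    # two real requests own blocks 3 and 1; the padded positions still hold
    # stale values from an earlier batch and happen to repeat block 1.
    state_indices = torch.tensor([3, 1, 1, 0], dtype=torch.int32)
    request_size, padding_size = 2, 2
    free_blocks = [0, 2]  # blocks currently owned by no request

    # The reordering: padded slots are redirected to free blocks, so a padded
    # dummy request can never write into a real request's state (blocks 3 or 1).
    state_indices[request_size:request_size + padding_size] = torch.tensor(
        free_blocks[:padding_size], dtype=torch.int32)
    print(state_indices.tolist())  # [3, 1, 0, 2]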


@@ -109,23 +109,59 @@ class MambaCacheManager(BaseResourceManager):
         self.state_indices: torch.Tensor = torch.arange(max_batch_size,
                                                         device=device,
                                                         dtype=torch.int32)
 
+        # save mamba state indices for requests
+        self.state_indices_list: List[int] = []
+
     def _prepare_mamba_cache_blocks(self, request_ids: List[int]):
-        state_indices = []
+        self.state_indices_list.clear()
         for r in request_ids:
             # cache hit
             if r in self.mamba_cache_index:
-                state_indices.append(self.mamba_cache_index[r])
+                self.state_indices_list.append(self.mamba_cache_index[r])
             # cache miss
             else:
                 if len(self.mamba_cache_free_blocks) == 0:
                     raise Exception("run out of mamba cache blocks")
                 block = self.mamba_cache_free_blocks.pop()
                 self.mamba_cache_index[r] = block
-                state_indices.append(block)
-        self.state_indices[:len(state_indices)].copy_(torch.tensor(
-            state_indices, dtype=torch.int32, pin_memory=True),
-                                                      non_blocking=True)
+                self.state_indices_list.append(block)
+        self.state_indices[:len(self.state_indices_list)].copy_(
+            torch.tensor(self.state_indices_list,
+                         dtype=torch.int32,
+                         pin_memory=True),
+            non_blocking=True)
+    # When there are padded requests, the state indices must not be repeated.
+    def reorder_state_indices_when_padding_requests(self, request_size,
+                                                    padding_size):
+        if padding_size == 0:
+            return
+        assert request_size + padding_size <= self.state_indices.numel(
+        ), "Padding requests run out of available mamba cache blocks"
+        # If there are enough free blocks, hand mamba_cache_free_blocks to the
+        # padding requests.
+        if padding_size <= len(self.mamba_cache_free_blocks):
+            self.state_indices[request_size:request_size +
+                               padding_size] = torch.tensor(
+                                   self.mamba_cache_free_blocks[:padding_size],
+                                   dtype=self.state_indices.dtype,
+                                   pin_memory=True).to(
+                                       self.state_indices.device,
+                                       non_blocking=True)
+        # Otherwise, fall back to the complement of the currently allocated
+        # indices. Just-finished requests do not release their blocks
+        # immediately: PyExecutor runs self.scheduler.schedule_request,
+        # self._forward_step() and only then self._process_previous_batch(),
+        # so the current forward step no longer includes finished requests,
+        # yet their mamba cache blocks have not been returned to
+        # mamba_cache_free_blocks.
+        else:
+            all_mamba_cache_indices = set(range(self.state_indices.numel()))
+            allocated_indices = set(self.state_indices_list)
+            free_indices = list(all_mamba_cache_indices - allocated_indices)
+            self.state_indices[request_size:request_size +
+                               padding_size] = torch.tensor(
+                                   free_indices[:padding_size],
+                                   dtype=self.state_indices.dtype,
+                                   pin_memory=True).to(
+                                       self.state_indices.device,
+                                       non_blocking=True)
+
     def prepare_resources(self, scheduled_batch: ScheduledRequests):
         context_ids = [
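A self-contained sketch of the fallback branch above, with toy block counts and plain local variables in place of the manager's attributes (the sort is only for a deterministic printout; the real code takes free indices in any order):

    import torch

    # 6 cache blocks, 3 real requests holding blocks [5, 0, 2], 2 padded slots
    # to fill, and an empty free list because just-finished requests have not
    # returned their blocks yet.
    num_blocks, request_size, padding_size = 6, 3, 2
    state_indices = torch.tensor([5, 0, 2, 0, 0, 0], dtype=torch.int32)
    state_indices_list = [5, 0, 2]   # blocks of the real requests this step
    mamba_cache_free_blocks = []     # nothing released yet

    if padding_size <= len(mamba_cache_free_blocks):
        fill = mamba_cache_free_blocks[:padding_size]
    else:
        # Anything not referenced by a scheduled request this step is safe to
        # hand to a padded slot, even if a finishing request nominally owns it.
        free = sorted(set(range(num_blocks)) - set(state_indices_list))
        fill = free[:padding_size]

    state_indices[request_size:request_size + padding_size] = torch.tensor(
        fill, dtype=torch.int32)
    print(state_indices.tolist())  # [5, 0, 2, 1, 3, 0]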


@@ -4846,10 +4846,12 @@ class TestQwen3NextInstruct(LlmapiAccuracyTestHarness):
         model_path = f"{self.MODEL_PATH}/Qwen3-Next-80B-A3B-Instruct"
         kv_cache_config = KvCacheConfig(free_gpu_memory_fraction=0.6,
                                         enable_block_reuse=False)
-        pytorch_config = dict(disable_overlap_scheduler=not overlap_scheduler,
-                              cuda_graph_config=CudaGraphConfig(
-                                  max_batch_size=512, enable_padding=True)
-                              if cuda_graph else None)
+        pytorch_config = dict(
+            disable_overlap_scheduler=not overlap_scheduler,
+            cuda_graph_config=CudaGraphConfig(
+                enable_padding=True,
+                batch_sizes=[1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048])
+            if cuda_graph else None)
         with LLM(
                 model_path,
@@ -4864,6 +4866,7 @@ class TestQwen3NextInstruct(LlmapiAccuracyTestHarness):
             task.evaluate(llm)
             mocker.patch.object(GSM8K, "MAX_OUTPUT_LEN",
                                 self.GSM8K_MAX_OUTPUT_LEN)
+            mocker.patch.object(GSM8K, "NUM_SAMPLES", 1319)
             task = GSM8K(self.MODEL_NAME)
             task.evaluate(llm)
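The test change replaces max_batch_size=512 with an explicit batch_sizes list plus enable_padding=True, so any scheduled batch that falls between the listed sizes is padded up to the next one, which exercises the new reordering path; running the full 1319-sample GSM8K set makes such padded batches actually occur. A rough standalone sketch of the same configuration through the LLM API (the model path is a placeholder and the import paths reflect my understanding of recent TensorRT-LLM releases, so they may need adjusting):

    from tensorrt_llm import LLM
    from tensorrt_llm.llmapi import CudaGraphConfig, KvCacheConfig

    # Hypothetical reproduction of the test's setup outside the harness.
    llm = LLM(
        "Qwen/Qwen3-Next-80B-A3B-Instruct",  # placeholder model path
        kv_cache_config=KvCacheConfig(free_gpu_memory_fraction=0.6,
                                      enable_block_reuse=False),
        cuda_graph_config=CudaGraphConfig(
            enable_padding=True,
            batch_sizes=[1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048]),
        disable_overlap_scheduler=False,
    )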