diff --git a/tensorrt_llm/_torch/pyexecutor/cuda_graph_runner.py b/tensorrt_llm/_torch/pyexecutor/cuda_graph_runner.py
index 187566f62e..852f2e063d 100644
--- a/tensorrt_llm/_torch/pyexecutor/cuda_graph_runner.py
+++ b/tensorrt_llm/_torch/pyexecutor/cuda_graph_runner.py
@@ -18,6 +18,7 @@ from ..speculative.eagle3 import Eagle3ResourceManager
 from ..speculative.mtp import SampleStateTensorsMTP
 from ..utils import make_weak_ref, piecewise_cuda_graph
 from .llm_request import get_draft_token_length
+from .mamba_cache_manager import MambaCacheManager
 from .resource_manager import (BaseResourceManager, ResourceManager,
                                ResourceManagerType)
 from .sampler import SampleStateTensors
@@ -450,6 +451,11 @@ class CUDAGraphRunner:
             if spec_res_mgr:
                 spec_res_mgr.add_dummy_requests([CUDA_GRAPH_DUMMY_REQUEST_ID])
 
+        # Padded requests need dedicated mamba state slots when the cache manager is a MambaCacheManager or MambaHybridCacheManager.
+        if isinstance(kv_cache_manager, MambaCacheManager):
+            kv_cache_manager.reorder_state_indices_when_padding_requests(
+                batch_size, padding_size)
+
         self.padding_dummy_request.py_draft_tokens = [0] * runtime_draft_len
         batch.generation_requests.extend([self.padding_dummy_request] *
                                          padding_size)
diff --git a/tensorrt_llm/_torch/pyexecutor/mamba_cache_manager.py b/tensorrt_llm/_torch/pyexecutor/mamba_cache_manager.py
index 228b5ef0f8..dfc497f18e 100644
--- a/tensorrt_llm/_torch/pyexecutor/mamba_cache_manager.py
+++ b/tensorrt_llm/_torch/pyexecutor/mamba_cache_manager.py
@@ -109,23 +109,59 @@ class MambaCacheManager(BaseResourceManager):
         self.state_indices: torch.Tensor = torch.arange(max_batch_size,
                                                         device=device,
                                                         dtype=torch.int32)
+        # Mamba state indices assigned to the requests of the current batch.
+        self.state_indices_list: List[int] = []
 
     def _prepare_mamba_cache_blocks(self, request_ids: List[int]):
-        state_indices = []
+        self.state_indices_list.clear()
         for r in request_ids:
             # cache hit
             if r in self.mamba_cache_index:
-                state_indices.append(self.mamba_cache_index[r])
+                self.state_indices_list.append(self.mamba_cache_index[r])
             # cache miss
             else:
                 if len(self.mamba_cache_free_blocks) == 0:
                     raise Exception("run out of mamba cache blocks")
                 block = self.mamba_cache_free_blocks.pop()
                 self.mamba_cache_index[r] = block
-                state_indices.append(block)
-        self.state_indices[:len(state_indices)].copy_(torch.tensor(
-            state_indices, dtype=torch.int32, pin_memory=True),
-                                                      non_blocking=True)
+                self.state_indices_list.append(block)
+        self.state_indices[:len(self.state_indices_list)].copy_(
+            torch.tensor(self.state_indices_list,
+                         dtype=torch.int32,
+                         pin_memory=True),
+            non_blocking=True)
+
+    # When the batch contains padded requests, their state indices must not repeat those of the real requests.
+    def reorder_state_indices_when_padding_requests(self, request_size,
+                                                    padding_size):
+        if padding_size == 0:
+            return
+
+        assert request_size + padding_size <= self.state_indices.numel(
+        ), "Not enough mamba cache blocks available for padded requests"
+        # If there are enough free blocks, reuse mamba_cache_free_blocks for the padded requests.
+        if padding_size <= len(self.mamba_cache_free_blocks):
+            self.state_indices[request_size:request_size +
+                               padding_size] = torch.tensor(
+                                   self.mamba_cache_free_blocks[:padding_size],
+                                   dtype=self.state_indices.dtype,
+                                   pin_memory=True).to(
+                                       self.state_indices.device,
+                                       non_blocking=True)
+        # Otherwise, derive free indices directly: requests that just finished do not release their blocks immediately,
+        # since PyExecutor runs self.scheduler.schedule_request, self._forward_step() and then self._process_previous_batch(),
+        # so this forward step no longer schedules those requests while their mamba cache blocks are still marked as used.
+        else:
+            all_mamba_cache_indices = set(range(self.state_indices.numel()))
+            allocated_indices = set(self.state_indices_list)
+            free_indices = list(all_mamba_cache_indices - allocated_indices)
+            self.state_indices[request_size:request_size +
+                               padding_size] = torch.tensor(
+                                   free_indices[:padding_size],
+                                   dtype=self.state_indices.dtype,
+                                   pin_memory=True).to(
+                                       self.state_indices.device,
+                                       non_blocking=True)
 
     def prepare_resources(self, scheduled_batch: ScheduledRequests):
         context_ids = [
diff --git a/tests/integration/defs/accuracy/test_llm_api_pytorch.py b/tests/integration/defs/accuracy/test_llm_api_pytorch.py
index 1a32e333b5..01f586accf 100644
--- a/tests/integration/defs/accuracy/test_llm_api_pytorch.py
+++ b/tests/integration/defs/accuracy/test_llm_api_pytorch.py
@@ -4846,10 +4846,12 @@ class TestQwen3NextInstruct(LlmapiAccuracyTestHarness):
         model_path = f"{self.MODEL_PATH}/Qwen3-Next-80B-A3B-Instruct"
         kv_cache_config = KvCacheConfig(free_gpu_memory_fraction=0.6,
                                         enable_block_reuse=False)
-        pytorch_config = dict(disable_overlap_scheduler=not overlap_scheduler,
-                              cuda_graph_config=CudaGraphConfig(
-                                  max_batch_size=512, enable_padding=True)
-                              if cuda_graph else None)
+        pytorch_config = dict(
+            disable_overlap_scheduler=not overlap_scheduler,
+            cuda_graph_config=CudaGraphConfig(
+                enable_padding=True,
+                batch_sizes=[1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048])
+            if cuda_graph else None)
 
         with LLM(
                 model_path,
@@ -4864,6 +4866,7 @@ class TestQwen3NextInstruct(LlmapiAccuracyTestHarness):
             task.evaluate(llm)
 
             mocker.patch.object(GSM8K, "MAX_OUTPUT_LEN",
                                 self.GSM8K_MAX_OUTPUT_LEN)
+            mocker.patch.object(GSM8K, "NUM_SAMPLES", 1319)
             task = GSM8K(self.MODEL_NAME)
             task.evaluate(llm)
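Reviewer note: the sketch below (not part of the patch) restates the slot-selection rule that the new MambaCacheManager.reorder_state_indices_when_padding_requests method implements, so the two branches are easier to compare. The helper name pick_padding_slots and the plain-list interface are illustrative assumptions; the real method writes the chosen indices directly into the pinned self.state_indices tensor.

# Illustrative sketch only; mirrors the selection rule of
# reorder_state_indices_when_padding_requests with assumed names.
from typing import List


def pick_padding_slots(num_slots: int, used_slots: List[int],
                       free_blocks: List[int], padding_size: int) -> List[int]:
    """Return padding_size state-slot indices that no real request is using."""
    if padding_size == 0:
        return []
    assert len(used_slots) + padding_size <= num_slots, (
        "Not enough mamba cache blocks available for padded requests")
    # Fast path: the explicit free list already holds enough blocks.
    if padding_size <= len(free_blocks):
        return free_blocks[:padding_size]
    # Fallback: blocks of just-finished requests are not back in the free list
    # yet, so derive free slots from the set difference over all slots.
    free = sorted(set(range(num_slots)) - set(used_slots))
    return free[:padding_size]


if __name__ == "__main__":
    # 8 slots, 5 active requests, 3 padded requests, but only one block in the
    # free list because finished requests have not been processed yet.
    print(pick_padding_slots(8, [0, 1, 2, 3, 4], [7], 3))  # -> [5, 6, 7]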
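For context on where padding_size comes from when enable_padding=True in the updated test: with CUDA graph padding, a batch is rounded up to a captured graph batch size and the difference is filled with the dummy request. The helper below is a hedged, standalone sketch of that rounding; round_up_to_captured_size is an assumed name, and the actual logic lives in the CUDA graph runner, not in CudaGraphConfig.

# Illustrative sketch; round_up_to_captured_size is an assumed helper,
# not part of the patch or of the CudaGraphConfig API.
import bisect
from typing import List, Optional


def round_up_to_captured_size(batch_size: int,
                              captured_sizes: List[int]) -> Optional[int]:
    """Return the smallest captured CUDA graph batch size >= batch_size."""
    sizes = sorted(captured_sizes)
    i = bisect.bisect_left(sizes, batch_size)
    return sizes[i] if i < len(sizes) else None


captured = [1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048]
for real_batch in (3, 97):
    padded = round_up_to_captured_size(real_batch, captured)
    # prints: real batch size, padded batch size, resulting padding_size
    print(real_batch, padded, padded - real_batch)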