Mirror of https://github.com/NVIDIA/TensorRT-LLM.git, synced 2026-01-13 22:18:36 +08:00
[TRTLLM-9676][fix] Fix mamba_cache_manager when enabling cuda_graph_padding and let test cover this case (#9873)
Signed-off-by: JadoTu <107457950+JadoTu@users.noreply.github.com>
parent c5331e6dbb
commit 4c498bfe58
@@ -18,6 +18,7 @@ from ..speculative.eagle3 import Eagle3ResourceManager
 from ..speculative.mtp import SampleStateTensorsMTP
 from ..utils import make_weak_ref, piecewise_cuda_graph
 from .llm_request import get_draft_token_length
+from .mamba_cache_manager import MambaCacheManager
 from .resource_manager import (BaseResourceManager, ResourceManager,
                                ResourceManagerType)
 from .sampler import SampleStateTensors
@@ -450,6 +451,11 @@ class CUDAGraphRunner:
         if spec_res_mgr:
             spec_res_mgr.add_dummy_requests([CUDA_GRAPH_DUMMY_REQUEST_ID])

+        # Handle the special case of padding requests with a MambaCacheManager or
+        # MambaHybridCacheManager: padded slots need state indices of their own.
+        if isinstance(kv_cache_manager, MambaCacheManager):
+            kv_cache_manager.reorder_state_indices_when_padding_requests(
+                batch_size, padding_size)
+
         self.padding_dummy_request.py_draft_tokens = [0] * runtime_draft_len
         batch.generation_requests.extend([self.padding_dummy_request] *
                                          padding_size)
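The call site above only tells the cache manager how many padded slots follow the real requests. A minimal sketch of the invariant it enforces (hypothetical standalone helper, not the actual MambaCacheManager code): the padded tail of state_indices must point at Mamba state blocks that no real request in the batch is using.

    # Illustrative sketch only; names below are hypothetical.
    from typing import List

    import torch


    def pad_state_indices(state_indices: torch.Tensor, real_indices: List[int],
                          padding_size: int) -> torch.Tensor:
        """Fill the padded tail of state_indices with blocks unused by real requests."""
        total = state_indices.numel()
        assert len(real_indices) + padding_size <= total, "not enough mamba cache blocks"
        used = set(real_indices)
        unused = [i for i in range(total) if i not in used]
        state_indices[:len(real_indices)] = torch.tensor(real_indices,
                                                         dtype=state_indices.dtype)
        state_indices[len(real_indices):len(real_indices) +
                      padding_size] = torch.tensor(unused[:padding_size],
                                                   dtype=state_indices.dtype)
        return state_indices


    # Two real requests, CUDA graph captured for batch size 4: the two padded slots get
    # blocks 0 and 2, so graph replay never writes into a live request's Mamba state.
    print(pad_state_indices(torch.arange(4, dtype=torch.int32), [3, 1], padding_size=2))
    # tensor([3, 1, 0, 2], dtype=torch.int32)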
@@ -109,22 +109,58 @@ class MambaCacheManager(BaseResourceManager):
         self.state_indices: torch.Tensor = torch.arange(max_batch_size,
                                                         device=device,
                                                         dtype=torch.int32)
+        # save mamba state indices for requests
+        self.state_indices_list: List[int] = []

     def _prepare_mamba_cache_blocks(self, request_ids: List[int]):
-        state_indices = []
+        self.state_indices_list.clear()
         for r in request_ids:
             # cache hit
             if r in self.mamba_cache_index:
-                state_indices.append(self.mamba_cache_index[r])
+                self.state_indices_list.append(self.mamba_cache_index[r])
             # cache miss
             else:
                 if len(self.mamba_cache_free_blocks) == 0:
                     raise Exception("run out of mamba cache blocks")
                 block = self.mamba_cache_free_blocks.pop()
                 self.mamba_cache_index[r] = block
-                state_indices.append(block)
-        self.state_indices[:len(state_indices)].copy_(torch.tensor(
-            state_indices, dtype=torch.int32, pin_memory=True),
-                                                      non_blocking=True)
+                self.state_indices_list.append(block)
+        self.state_indices[:len(self.state_indices_list)].copy_(
+            torch.tensor(self.state_indices_list,
+                         dtype=torch.int32,
+                         pin_memory=True),
+            non_blocking=True)
+
+    # When there are padded requests, the state indices must not be repeated.
+    def reorder_state_indices_when_padding_requests(self, request_size,
+                                                    padding_size):
+        if padding_size == 0:
+            return
+
+        assert request_size + padding_size <= self.state_indices.numel(
+        ), "Padding requests run out of available mamba cache blocks"
+        # We can use mamba_cache_free_blocks for the padding requests.
+        if padding_size <= len(self.mamba_cache_free_blocks):
+            self.state_indices[request_size:request_size +
+                               padding_size] = torch.tensor(
+                                   self.mamba_cache_free_blocks[:padding_size],
+                                   dtype=self.state_indices.dtype,
+                                   pin_memory=True).to(
+                                       self.state_indices.device,
+                                       non_blocking=True)
+        # But requests that have just finished do not free their resources immediately.
+        # Specifically, the PyExecutor runs self.scheduler.schedule_request, self._forward_step()
+        # and self._process_previous_batch() in that order, so the current forward step already
+        # excludes finished requests while their mamba cache blocks have not been released yet.
+        else:
+            all_mamba_cache_indices = set(range(self.state_indices.numel()))
+            allocated_indices = set(self.state_indices_list)
+            free_indices = list(all_mamba_cache_indices - allocated_indices)
+            self.state_indices[request_size:request_size +
+                               padding_size] = torch.tensor(
+                                   free_indices[:padding_size],
+                                   dtype=self.state_indices.dtype,
+                                   pin_memory=True).to(
+                                       self.state_indices.device,
+                                       non_blocking=True)

     def prepare_resources(self, scheduled_batch: ScheduledRequests):
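The else branch exists because block release lags request completion by one executor step. A compact sketch of the same two-branch selection, with hypothetical standalone names rather than the class attributes used above:

    # Illustrative sketch only, mirroring the selection logic of
    # reorder_state_indices_when_padding_requests.
    from typing import List


    def pick_padding_indices(free_blocks: List[int], allocated: List[int],
                             num_blocks: int, padding_size: int) -> List[int]:
        if padding_size <= len(free_blocks):
            # Fast path: enough explicitly free blocks for every padded slot.
            return free_blocks[:padding_size]
        # Fallback: just-finished requests have not returned their blocks yet, so derive
        # the free set as "all blocks minus blocks referenced by the current batch".
        free = sorted(set(range(num_blocks)) - set(allocated))
        return free[:padding_size]


    # 4 blocks total, current batch uses blocks [0, 2], nothing on the free list yet.
    print(pick_padding_indices(free_blocks=[], allocated=[0, 2],
                               num_blocks=4, padding_size=2))
    # [1, 3]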
@@ -4846,9 +4846,11 @@ class TestQwen3NextInstruct(LlmapiAccuracyTestHarness):
         model_path = f"{self.MODEL_PATH}/Qwen3-Next-80B-A3B-Instruct"
         kv_cache_config = KvCacheConfig(free_gpu_memory_fraction=0.6,
                                         enable_block_reuse=False)
-        pytorch_config = dict(disable_overlap_scheduler=not overlap_scheduler,
-                              cuda_graph_config=CudaGraphConfig(
-                                  max_batch_size=512, enable_padding=True)
-                              if cuda_graph else None)
+        pytorch_config = dict(
+            disable_overlap_scheduler=not overlap_scheduler,
+            cuda_graph_config=CudaGraphConfig(
+                enable_padding=True,
+                batch_sizes=[1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048])
+            if cuda_graph else None)

         with LLM(
@@ -4864,6 +4866,7 @@ class TestQwen3NextInstruct(LlmapiAccuracyTestHarness):
             task.evaluate(llm)
             mocker.patch.object(GSM8K, "MAX_OUTPUT_LEN",
                                 self.GSM8K_MAX_OUTPUT_LEN)
+            mocker.patch.object(GSM8K, "NUM_SAMPLES", 1319)
             task = GSM8K(self.MODEL_NAME)
             task.evaluate(llm)
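For context, a hedged usage sketch of the configuration the test now exercises (placeholder model path and prompts; it assumes the same llmapi imports the accuracy test uses): submitting a request count that falls between two captured batch sizes forces CUDA-graph padding, which is exactly the path the MambaCacheManager fix covers.

    # Hedged sketch, not part of the commit.
    from tensorrt_llm import LLM, SamplingParams
    from tensorrt_llm.llmapi import CudaGraphConfig, KvCacheConfig

    llm = LLM(
        model="Qwen3-Next-80B-A3B-Instruct",  # placeholder path
        kv_cache_config=KvCacheConfig(free_gpu_memory_fraction=0.6,
                                      enable_block_reuse=False),
        cuda_graph_config=CudaGraphConfig(enable_padding=True,
                                          batch_sizes=[1, 2, 4, 8, 16]),
    )
    # Three concurrent requests pad up to the captured batch size of 4, so the padded
    # slot must receive a Mamba state block that no real request owns.
    outputs = llm.generate(["Hello"] * 3, SamplingParams(max_tokens=8))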