[TRTLLM-4923][feat] Paged mamba cache (#4822)

Signed-off-by: Tomer Asida <57313761+tomeras91@users.noreply.github.com>
tomeras91 2025-06-04 09:27:08 +03:00 committed by GitHub
parent e71de2a13e
commit 8d31e16877
4 changed files with 66 additions and 42 deletions

View File

@@ -201,14 +201,15 @@ class Mamba2Mixer(nn.Module):
for req_type in batch:
if not is_warmup:
indices = split_indices[req_type].to(torch.long).to(
torch.device("cuda"))
indices = split_indices[req_type].to(torch.device("cuda"))
conv_states = attn_metadata.kv_cache_manager.get_conv_states(
self.layer_idx)
ssm_states = attn_metadata.kv_cache_manager.get_ssm_states(
self.layer_idx)
else:
indices = None
conv_states = None
ssm_states = None
z, xbc, dt = torch.split(
split_zxbcdt[req_type],
@@ -234,15 +235,13 @@ class Mamba2Mixer(nn.Module):
cu_seqlens.diff(),
output_size=cu_seqlens[-1]).unsqueeze(0)
current_conv_states = torch.empty_like(
conv_states[indices, ...]) if not is_warmup else None
xbc = causal_conv1d_fn(xbc.transpose(0, 1),
self.conv1d.weight,
self.conv1d.bias,
activation="silu",
conv_states=current_conv_states,
query_start_loc=cu_seqlens).transpose(
0, 1)
conv_states=conv_states,
query_start_loc=cu_seqlens,
cache_indices=indices).transpose(0, 1)
x, B, C = torch.split(xbc.unsqueeze(0), [
self.tp_d_inner,
@@ -278,18 +277,18 @@ class Mamba2Mixer(nn.Module):
)
y = rearrange(y, "b l h p -> (b l) (h p)")
# copy new ssm state
if not is_warmup:
ssm_states[indices] = current_ssm_states
# decode
else:
# get conv and ssm states for decode
if not is_warmup:
current_conv_states = conv_states[indices]
current_ssm_states = ssm_states[indices]
# update conv states
xbc = causal_conv1d_update(xbc, current_conv_states,
self.conv1d.weight, self.conv1d.bias,
"silu")
xbc = causal_conv1d_update(xbc,
conv_states,
self.conv1d.weight,
self.conv1d.bias,
activation="silu",
conv_state_indices=indices)
x, B, C = torch.split(
xbc,
@@ -314,7 +313,7 @@ class Mamba2Mixer(nn.Module):
z = rearrange(z, "b (h p) -> b h p", p=self.head_dim)
y = selective_state_update(
current_ssm_states,
ssm_states,
x_reshaped,
dt,
A,
@@ -324,6 +323,7 @@ class Mamba2Mixer(nn.Module):
z=z,
dt_bias=dt_bias,
dt_softplus=self.delta_softplus,
state_batch_indices=indices,
)
y = rearrange(y, "b h p -> b (h p)")
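
For orientation, here is a minimal pure-PyTorch sketch of what the added state_batch_indices argument means in the decode path above: the single-step SSM recurrence is applied in place to the rows of the paged state pool that the indices select, rather than to a gathered per-batch copy. The shapes are assumptions for illustration (ngroups = 1), z-gating, dt_bias and dt_softplus are omitted, and this is not the Triton kernel itself.

import torch

def selective_state_update_sketch(state_pool, x, dt, A, B, C, D,
                                  state_batch_indices):
    # state_pool: [num_blocks, nheads, headdim, dstate] -- paged SSM cache
    # x, dt:      [batch, nheads, headdim]
    # A:          [nheads, headdim, dstate]
    # B, C:       [batch, dstate]   (ngroups = 1, shared across heads)
    # D:          [nheads, headdim]
    idx = state_batch_indices.long()
    s = state_pool[idx]                               # rows owned by this batch
    dA = torch.exp(dt.unsqueeze(-1) * A)              # discretized decay
    dBx = dt.unsqueeze(-1) * B[:, None, None, :] * x.unsqueeze(-1)
    s = s * dA + dBx                                  # state_t = state_{t-1} * dA + dB * x
    state_pool[idx] = s                               # write back into the pool in place
    return (s * C[:, None, None, :]).sum(-1) + D * x  # y = C . state + D * x

# toy shapes: 6 cache blocks, 2 heads, headdim 4, dstate 8, batch 3
pool = torch.zeros(6, 2, 4, 8)
idx = torch.tensor([5, 1, 3], dtype=torch.int32)
x, dt = torch.randn(3, 2, 4), torch.rand(3, 2, 4)
A, D = -torch.rand(2, 4, 8), torch.randn(2, 4)
B, C = torch.randn(3, 8), torch.randn(3, 8)
y = selective_state_update_sketch(pool, x, dt, A, B, C, D, idx)  # y: [3, 2, 4]
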
@@ -331,11 +331,6 @@ class Mamba2Mixer(nn.Module):
# gated norm
y = self.norm(y)
# copy new conv and ssm states
if not is_warmup:
conv_states.index_copy_(0, indices, current_conv_states)
ssm_states.index_copy_(0, indices, current_ssm_states)
# append output
out.append(y)
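
Taken together, the hunks in this file replace the gather-compute-scatter pattern (slice the states with indices, run the kernels on the copies, then write back with index_copy_) with kernels that receive the block indices and read/write the paged pools in place. A rough standalone sketch of the two patterns, using made-up tensors instead of the real cache manager:

import torch

num_blocks, state_dim = 8, 4
indices = torch.tensor([5, 2, 7])                # cache block assigned to each request
delta = torch.randn(indices.numel(), state_dim)  # stand-in for a kernel's state update

# previous pattern: gather, update the copy, scatter back
pool_old = torch.zeros(num_blocks, state_dim)
current = pool_old[indices]                      # extra copy out of the pool
current = current + delta                        # kernel works on the copy
pool_old.index_copy_(0, indices, current)        # extra copy back into the pool

# paged pattern: the kernel gets the indices and updates the pool rows directly
# (what cache_indices / conv_state_indices / state_batch_indices enable)
pool_new = torch.zeros(num_blocks, state_dim)
pool_new[indices] += delta

assert torch.equal(pool_old, pool_new)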

View File

@@ -593,7 +593,7 @@ class MambaCacheManager(BaseResourceManager):
block = self.mamba_cache_free_blocks.pop()
self.mamba_cache_index[r] = block
state_indices.append(block)
self.state_indices = torch.as_tensor(state_indices, dtype=torch.long)
self.state_indices = torch.as_tensor(state_indices, dtype=torch.int32)
def free_mamba_cache_blocks(self, request_id: int):
if request_id in self.mamba_cache_index:
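
The cache-manager side of the change is a small block allocator: each request takes a block from a free list, and the per-step state_indices tensor (now int32, so the CUDA kernels can consume it directly) maps batch slots to blocks. A rough sketch of that bookkeeping; every name not visible in the hunk above is an assumption:

import torch

class TinyMambaBlockPool:
    def __init__(self, num_blocks: int):
        self.mamba_cache_free_blocks = list(range(num_blocks))
        self.mamba_cache_index = {}          # request_id -> block

    def prepare_blocks(self, request_ids):
        state_indices = []
        for r in request_ids:
            if r not in self.mamba_cache_index:
                block = self.mamba_cache_free_blocks.pop()
                self.mamba_cache_index[r] = block
            state_indices.append(self.mamba_cache_index[r])
        # int32 so the indices can be handed straight to the kernels
        self.state_indices = torch.as_tensor(state_indices, dtype=torch.int32)
        return self.state_indices

    def free_mamba_cache_blocks(self, request_id: int):
        if request_id in self.mamba_cache_index:
            self.mamba_cache_free_blocks.append(
                self.mamba_cache_index.pop(request_id))

pool = TinyMambaBlockPool(num_blocks=4)
print(pool.prepare_blocks([10, 11]))    # tensor([3, 2], dtype=torch.int32)
pool.free_mamba_cache_blocks(10)        # block 3 returns to the free list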

View File

@@ -26,11 +26,11 @@ from tensorrt_llm.llmapi.utils import get_total_gpu_memory
@pytest.mark.parametrize(
"dim, dconv, req_type, dtype, batch_size, max_seq_len, remove_padding, apply_silu",
"dim, dconv, req_type, dtype, batch_size, max_seq_len, remove_padding, apply_silu, paged_cache",
list(
product([2048], [4], ['context', 'generation'],
['float16', 'float32', 'bfloat16'], [5], [16], [False, True],
[False, True])) +
[False, True], [False, True])) +
# long sequence tests to cover the int overflow issue
list(
map(
@@ -42,9 +42,9 @@ from tensorrt_llm.llmapi.utils import get_total_gpu_memory
"The long sequence test needs at least 33GB memory, skipping"
)),
product([5376], [4], ['context'], ['float16', 'bfloat16'], [2],
[131072], [False, True], [False, True]))))
[131072], [False, True], [False, True], [False]))))
def test_causal_conv1d(dim, dconv, req_type, dtype, batch_size, max_seq_len,
remove_padding, apply_silu):
remove_padding, apply_silu, paged_cache):
device = "cuda"
seq_len = max_seq_len if req_type == "context" else 1
mean = 0.0
@@ -94,11 +94,22 @@ def test_causal_conv1d(dim, dconv, req_type, dtype, batch_size, max_seq_len,
else:
x_in_out = x.detach().clone()
conv_state_in_out = conv_state.detach().clone()
if paged_cache:
padded_batch_size = 2 * batch_size
cache_indices = torch.randperm(padded_batch_size,
device=device,
dtype=torch.int32)[:batch_size]
conv_state_in_out = torch.empty([padded_batch_size, dim, dconv - 1],
dtype=torch_dtype,
device=device)
conv_state_in_out[cache_indices] = conv_state.detach().clone()
else:
cache_indices = None
conv_state_in_out = conv_state.detach().clone()
conv_weight_input = conv_weight.squeeze(1).contiguous()
if req_type == "context":
cache_indices = None
has_initial_state = None
torch.ops.trtllm.causal_conv1d_fwd(
@@ -112,11 +123,12 @@ def test_causal_conv1d(dim, dconv, req_type, dtype, batch_size, max_seq_len,
apply_silu,
PAD_SLOT_ID,
)
outputs = (x_in_out, conv_state_in_out)
outputs = (x_in_out, conv_state_in_out[cache_indices]
if cache_indices is not None else conv_state_in_out)
else:
conv_state_indices = cache_indices
cache_seqlens = None
conv_state_indices = None
torch.ops.trtllm.causal_conv1d_update(
x_in_out,
@@ -128,7 +140,8 @@ def test_causal_conv1d(dim, dconv, req_type, dtype, batch_size, max_seq_len,
conv_state_indices,
PAD_SLOT_ID,
)
outputs = (x_in_out, conv_state_in_out)
outputs = (x_in_out, conv_state_in_out[cache_indices]
if cache_indices is not None else conv_state_in_out)
out_ref = torch.zeros_like(x)
conv_state_ref = torch.zeros_like(conv_state)
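
In the paged_cache branch above, the test stages the conv state in an over-allocated pool of 2 * batch_size slots, scatters it to a random set of distinct slots chosen with randperm, and afterwards compares only conv_state_in_out[cache_indices] against the dense reference. A self-contained sketch of that staging, minus the actual torch.ops.trtllm calls:

import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
batch_size, dim, dconv = 5, 2048, 4
dtype = torch.float16

# dense per-request conv state, as in the non-paged path
conv_state = torch.randn(batch_size, dim, dconv - 1, dtype=dtype, device=device)

# over-allocate the pool and scatter the real states to random distinct slots
padded_batch_size = 2 * batch_size
cache_indices = torch.randperm(padded_batch_size, device=device,
                               dtype=torch.int32)[:batch_size]
conv_state_pool = torch.empty(padded_batch_size, dim, dconv - 1,
                              dtype=dtype, device=device)
conv_state_pool[cache_indices.long()] = conv_state

# only the rows at cache_indices would be compared after the kernel runs
updated = conv_state_pool[cache_indices.long()]
assert updated.shape == conv_state.shape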

View File

@@ -30,21 +30,21 @@ from tensorrt_llm.llmapi.utils import get_total_gpu_memory
@pytest.mark.parametrize(
"dim, headdim, ngroups, dstate, req_type, dtype, batch_size, max_seq_len, has_z, remove_padding",
"dim, headdim, ngroups, dstate, req_type, dtype, batch_size, max_seq_len, has_z, remove_padding, paged_cache",
# P=8x and H=2x
list(
product([160, 320, 640], [80], [1], [128], ['context'], ['float16'],
[1, 2, 8, 16], [16, 64, 256], [True], [True])) +
[1, 2, 8, 16], [16, 64, 256], [True], [True], [False])) +
# normal tests
list(
product([2048], [64], [1, 4], [128], ['context', 'generation'],
['float32', 'float16', 'bfloat16'], [3], [16], [True, False],
[True, False])) +
['float32', 'float16', 'bfloat16'], [3], [16], [False],
[True, False], [True, False])) +
# arbitrary N generation tests
list(
product([2048], [64], [1, 4], [16, 32, 48, 64, 80, 96, 128, 256],
['generation'], ['float32', 'float16'], [3], [16], [True],
[True])) +
[True], [False])) +
# long sequence tests to cover the int overflow issue
list(
map(
@@ -56,16 +56,17 @@ from tensorrt_llm.llmapi.utils import get_total_gpu_memory
"The long sequence test needs at least 68GB memory, skipping"
)),
product([5120], [64], [1], [128], ['context'], ['float16'], [2],
[131072], [True, False], [True, False]))) +
[131072], [True, False], [True, False], [False]))) +
# P=8x and H=2x
list(
product([144], [72], [1], [64, 128, 256], ['context', 'generation'],
['float16'], [16], [16384], [True, False], [True, False])),
['float16'], [16], [16384], [True, False], [True, False],
[False])),
)
def test_mamba2_chunk_scan_selective_state_update(dim, headdim, ngroups, dstate,
req_type, dtype, batch_size,
max_seq_len, has_z,
remove_padding):
remove_padding, paged_cache):
# configs
device = "cuda"
seq_len = max_seq_len if req_type == 'context' else 1
@@ -190,6 +191,19 @@ def test_mamba2_chunk_scan_selective_state_update(dim, headdim, ngroups, dstate,
outputs = (out, ssm_state)
else:
if paged_cache:
padded_batch_size = 2 * batch_size
state_batch_indices = torch.randperm(padded_batch_size,
device=device,
dtype=torch.int32)[:batch_size]
orig_state = state.detach().clone()
state = torch.empty([padded_batch_size, nheads, headdim, dstate],
dtype=torch_dtype,
device=device)
state[state_batch_indices] = orig_state
else:
state_batch_indices = None
y = selective_state_update(
state,
x,
@@ -201,8 +215,10 @@ def test_mamba2_chunk_scan_selective_state_update(dim, headdim, ngroups, dstate,
z=z if has_z else None,
dt_bias=dt_bias,
dt_softplus=delta_softplus,
state_batch_indices=state_batch_indices,
)
outputs = (y, state)
outputs = (y, state[state_batch_indices]
if state_batch_indices is not None else state)
# pytorch run
if req_type == 'context':
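
Both tests end with the same conditional gather: when an index tensor is supplied, only the pool rows it selects are expected to match the dense reference. A tiny, hypothetical helper capturing that pattern (not part of the test files):

from typing import Optional

import torch

def gather_if_paged(pool: torch.Tensor,
                    indices: Optional[torch.Tensor]) -> torch.Tensor:
    # return the rows the kernel should have touched, or the whole tensor
    return pool if indices is None else pool[indices.long()]

# e.g. an over-allocated pool twice the batch size, populated at random slots
reference = torch.randn(3, 4)
pool = torch.zeros(6, 4)
idx = torch.randperm(6, dtype=torch.int32)[:3]
pool[idx.long()] = reference
torch.testing.assert_close(gather_if_paged(pool, idx), reference)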