Mirror of https://github.com/NVIDIA/TensorRT-LLM.git (synced 2026-01-14 06:27:45 +08:00)

[TRTLLM-4923][feat] Paged mamba cache (#4822)

Signed-off-by: Tomer Asida <57313761+tomeras91@users.noreply.github.com>

Parent: e71de2a13e
Commit: 8d31e16877
@@ -201,14 +201,15 @@ class Mamba2Mixer(nn.Module):
         for req_type in batch:

             if not is_warmup:
-                indices = split_indices[req_type].to(torch.long).to(
-                    torch.device("cuda"))
+                indices = split_indices[req_type].to(torch.device("cuda"))
                 conv_states = attn_metadata.kv_cache_manager.get_conv_states(
                     self.layer_idx)
                 ssm_states = attn_metadata.kv_cache_manager.get_ssm_states(
                     self.layer_idx)
             else:
                 indices = None
+                conv_states = None
+                ssm_states = None

             z, xbc, dt = torch.split(
                 split_zxbcdt[req_type],
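Note: the change above makes the mixer read each layer's conv/ssm state pools once and address them through per-request indices. A minimal sketch of that paged layout, with illustrative names only (conv_state_pool, num_blocks are assumptions, not the TRT-LLM API):

import torch

num_blocks, dim, dconv = 8, 16, 4                  # assumed pool geometry
conv_state_pool = torch.zeros(num_blocks, dim, dconv - 1)

# two active requests, each mapped to an arbitrary block of the pool
indices = torch.tensor([5, 2], dtype=torch.int32)

current = conv_state_pool[indices.long()]          # gather the active requests' states
conv_state_pool[indices.long()] = current + 1.0    # scatter updated states back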
@@ -234,15 +235,13 @@ class Mamba2Mixer(nn.Module):
                     cu_seqlens.diff(),
                     output_size=cu_seqlens[-1]).unsqueeze(0)

-                current_conv_states = torch.empty_like(
-                    conv_states[indices, ...]) if not is_warmup else None
                 xbc = causal_conv1d_fn(xbc.transpose(0, 1),
                                        self.conv1d.weight,
                                        self.conv1d.bias,
                                        activation="silu",
-                                       conv_states=current_conv_states,
-                                       query_start_loc=cu_seqlens).transpose(
-                                           0, 1)
+                                       conv_states=conv_states,
+                                       query_start_loc=cu_seqlens,
+                                       cache_indices=indices).transpose(0, 1)

                 x, B, C = torch.split(xbc.unsqueeze(0), [
                     self.tp_d_inner,
@@ -278,18 +277,18 @@ class Mamba2Mixer(nn.Module):
                 )
                 y = rearrange(y, "b l h p -> (b l) (h p)")

+                # copy new ssm state
+                if not is_warmup:
+                    ssm_states[indices] = current_ssm_states
+
             # decode
             else:
-
-                # get conv and ssm states for decode
-                if not is_warmup:
-                    current_conv_states = conv_states[indices]
-                    current_ssm_states = ssm_states[indices]
-
-                # update conv states
-                xbc = causal_conv1d_update(xbc, current_conv_states,
-                                           self.conv1d.weight, self.conv1d.bias,
-                                           "silu")
+                xbc = causal_conv1d_update(xbc,
+                                           conv_states,
+                                           self.conv1d.weight,
+                                           self.conv1d.bias,
+                                           activation="silu",
+                                           conv_state_indices=indices)

                 x, B, C = torch.split(
                     xbc,
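Note: with cache_indices / conv_state_indices, the conv kernels read and write each request's slot of the padded state pool directly, so the explicit gather and copy-back passes disappear. A pure-PyTorch emulation of a single decode step under this layout; shapes and names are assumptions for illustration, and bias is omitted:

import torch
import torch.nn.functional as F

batch, dim, dconv, pool = 2, 16, 4, 8
conv_states = torch.zeros(pool, dim, dconv - 1)    # paged state pool
weight = torch.randn(dim, dconv)                   # one depthwise filter per channel
indices = torch.tensor([6, 1])                     # pool slot owned by each request

x = torch.randn(batch, dim)                        # one new token per request
window = torch.cat([conv_states[indices], x.unsqueeze(-1)], dim=-1)
y = F.silu((window * weight).sum(-1))              # causal conv + SiLU for this step
conv_states[indices] = window[..., 1:]             # updated state lands in the pool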
@@ -314,7 +313,7 @@ class Mamba2Mixer(nn.Module):
                 z = rearrange(z, "b (h p) -> b h p", p=self.head_dim)

                 y = selective_state_update(
-                    current_ssm_states,
+                    ssm_states,
                     x_reshaped,
                     dt,
                     A,
@@ -324,6 +323,7 @@ class Mamba2Mixer(nn.Module):
                     z=z,
                     dt_bias=dt_bias,
                     dt_softplus=self.delta_softplus,
+                    state_batch_indices=indices,
                 )

                 y = rearrange(y, "b h p -> b (h p)")
@@ -331,11 +331,6 @@ class Mamba2Mixer(nn.Module):
             # gated norm
             y = self.norm(y)

-            # copy new conv and ssm states
-            if not is_warmup:
-                conv_states.index_copy_(0, indices, current_conv_states)
-                ssm_states.index_copy_(0, indices, current_ssm_states)
-
             # append output
             out.append(y)

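Note: passing state_batch_indices lets selective_state_update advance each request's SSM state in place inside the paged pool, which is why the trailing index_copy_ block above is removed. A simplified sketch of that gather/update/scatter pattern (scalar-A recurrence, assumed shapes and names, not the kernel itself):

import torch

pool, nheads, headdim, dstate = 8, 2, 4, 16
ssm_states = torch.zeros(pool, nheads, headdim, dstate)
indices = torch.tensor([3, 7])                     # one pool slot per request

x = torch.randn(2, nheads, headdim)
dt = torch.rand(2, nheads, 1, 1)
A = -torch.rand(nheads, 1, 1)
B = torch.randn(2, dstate)
C = torch.randn(2, dstate)

h = ssm_states[indices]                            # gather the requests' states
h = h * torch.exp(dt * A) + dt * x.unsqueeze(-1) * B[:, None, None, :]
ssm_states[indices] = h                            # scatter back into the pool
y = (h * C[:, None, None, :]).sum(-1)              # per-step output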
@@ -593,7 +593,7 @@ class MambaCacheManager(BaseResourceManager):
                 block = self.mamba_cache_free_blocks.pop()
                 self.mamba_cache_index[r] = block
             state_indices.append(block)
-        self.state_indices = torch.as_tensor(state_indices, dtype=torch.long)
+        self.state_indices = torch.as_tensor(state_indices, dtype=torch.int32)

     def free_mamba_cache_blocks(self, request_id: int):
         if request_id in self.mamba_cache_index:
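Note: the cache manager hands each request a free block and records the mapping; the resulting index tensor is now int32 so it can be passed straight to the paged kernels. A standalone sketch of that bookkeeping (hypothetical helper names, not the MambaCacheManager API):

import torch

free_blocks = list(range(8))
request_to_block = {}

def assign(request_ids):
    state_indices = []
    for r in request_ids:
        if r not in request_to_block:
            request_to_block[r] = free_blocks.pop()
        state_indices.append(request_to_block[r])
    return torch.as_tensor(state_indices, dtype=torch.int32)

def release(request_id):
    if request_id in request_to_block:
        free_blocks.append(request_to_block.pop(request_id))

print(assign([0, 1]))   # e.g. tensor([7, 6], dtype=torch.int32)
release(0)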
@@ -26,11 +26,11 @@ from tensorrt_llm.llmapi.utils import get_total_gpu_memory


 @pytest.mark.parametrize(
-    "dim, dconv, req_type, dtype, batch_size, max_seq_len, remove_padding, apply_silu",
+    "dim, dconv, req_type, dtype, batch_size, max_seq_len, remove_padding, apply_silu, paged_cache",
     list(
         product([2048], [4], ['context', 'generation'],
                 ['float16', 'float32', 'bfloat16'], [5], [16], [False, True],
-                [False, True])) +
+                [False, True], [False, True])) +
     # long sequence tests to cover the int overflow issue
     list(
         map(
@@ -42,9 +42,9 @@ from tensorrt_llm.llmapi.utils import get_total_gpu_memory
                     "The long sequence test needs at least 33GB memory, skipping"
                 )),
             product([5376], [4], ['context'], ['float16', 'bfloat16'], [2],
-                    [131072], [False, True], [False, True]))))
+                    [131072], [False, True], [False, True], [False]))))
 def test_causal_conv1d(dim, dconv, req_type, dtype, batch_size, max_seq_len,
-                       remove_padding, apply_silu):
+                       remove_padding, apply_silu, paged_cache):
     device = "cuda"
     seq_len = max_seq_len if req_type == "context" else 1
     mean = 0.0
@@ -94,11 +94,22 @@ def test_causal_conv1d(dim, dconv, req_type, dtype, batch_size, max_seq_len,
     else:
         x_in_out = x.detach().clone()

-    conv_state_in_out = conv_state.detach().clone()
+    if paged_cache:
+        padded_batch_size = 2 * batch_size
+        cache_indices = torch.randperm(padded_batch_size,
+                                       device=device,
+                                       dtype=torch.int32)[:batch_size]
+        conv_state_in_out = torch.empty([padded_batch_size, dim, dconv - 1],
+                                        dtype=torch_dtype,
+                                        device=device)
+        conv_state_in_out[cache_indices] = conv_state.detach().clone()
+    else:
+        cache_indices = None
+        conv_state_in_out = conv_state.detach().clone()

     conv_weight_input = conv_weight.squeeze(1).contiguous()

     if req_type == "context":
-        cache_indices = None
         has_initial_state = None

         torch.ops.trtllm.causal_conv1d_fwd(
@@ -112,11 +123,12 @@ def test_causal_conv1d(dim, dconv, req_type, dtype, batch_size, max_seq_len,
             apply_silu,
             PAD_SLOT_ID,
         )
-        outputs = (x_in_out, conv_state_in_out)
+        outputs = (x_in_out, conv_state_in_out[cache_indices]
+                   if cache_indices is not None else conv_state_in_out)

     else:
+        conv_state_indices = cache_indices
         cache_seqlens = None
-        conv_state_indices = None

         torch.ops.trtllm.causal_conv1d_update(
             x_in_out,
@@ -128,7 +140,8 @@ def test_causal_conv1d(dim, dconv, req_type, dtype, batch_size, max_seq_len,
             conv_state_indices,
             PAD_SLOT_ID,
         )
-        outputs = (x_in_out, conv_state_in_out)
+        outputs = (x_in_out, conv_state_in_out[cache_indices]
+                   if cache_indices is not None else conv_state_in_out)

     out_ref = torch.zeros_like(x)
     conv_state_ref = torch.zeros_like(conv_state)
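Note: the new paged_cache branch of this test reduces to a scatter → run → gather pattern: the reference state is scattered into random slots of an oversized pool, the op runs against the pool plus indices, and the slots are gathered back before comparison. A self-contained sketch of that pattern with a stand-in op (the real test calls torch.ops.trtllm.causal_conv1d_*):

import torch

batch, dim, dconv = 5, 8, 4
padded_batch_size = 2 * batch
cache_indices = torch.randperm(padded_batch_size, dtype=torch.int32)[:batch]

conv_state = torch.randn(batch, dim, dconv - 1)
pool = torch.zeros(padded_batch_size, dim, dconv - 1)
pool[cache_indices.long()] = conv_state            # scatter the reference state

pool[cache_indices.long()] += 1.0                  # stand-in for the kernel's in-place update

gathered = pool[cache_indices.long()]              # gather before comparing
torch.testing.assert_close(gathered, conv_state + 1.0)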
@@ -30,21 +30,21 @@ from tensorrt_llm.llmapi.utils import get_total_gpu_memory


 @pytest.mark.parametrize(
-    "dim, headdim, ngroups, dstate, req_type, dtype, batch_size, max_seq_len, has_z, remove_padding",
+    "dim, headdim, ngroups, dstate, req_type, dtype, batch_size, max_seq_len, has_z, remove_padding, paged_cache",
     # P=8x and H=2x
     list(
         product([160, 320, 640], [80], [1], [128], ['context'], ['float16'],
-                [1, 2, 8, 16], [16, 64, 256], [True], [True])) +
+                [1, 2, 8, 16], [16, 64, 256], [True], [True], [False])) +
     # normal tests
     list(
         product([2048], [64], [1, 4], [128], ['context', 'generation'],
-                ['float32', 'float16', 'bfloat16'], [3], [16], [True, False],
-                [True, False])) +
+                ['float32', 'float16', 'bfloat16'], [3], [16], [False],
+                [True, False], [True, False])) +
     # arbitrary N generation tests
     list(
         product([2048], [64], [1, 4], [16, 32, 48, 64, 80, 96, 128, 256],
                 ['generation'], ['float32', 'float16'], [3], [16], [True],
-                [True])) +
+                [True], [False])) +
     # long sequence tests to cover the int overflow issue
     list(
         map(
@@ -56,16 +56,17 @@ from tensorrt_llm.llmapi.utils import get_total_gpu_memory
                     "The long sequence test needs at least 68GB memory, skipping"
                 )),
             product([5120], [64], [1], [128], ['context'], ['float16'], [2],
-                    [131072], [True, False], [True, False]))) +
+                    [131072], [True, False], [True, False], [False]))) +
     # P=8x and H=2x
     list(
         product([144], [72], [1], [64, 128, 256], ['context', 'generation'],
-                ['float16'], [16], [16384], [True, False], [True, False])),
+                ['float16'], [16], [16384], [True, False], [True, False],
+                [False])),
 )
 def test_mamba2_chunk_scan_selective_state_update(dim, headdim, ngroups, dstate,
                                                   req_type, dtype, batch_size,
                                                   max_seq_len, has_z,
-                                                  remove_padding):
+                                                  remove_padding, paged_cache):
     # configs
     device = "cuda"
     seq_len = max_seq_len if req_type == 'context' else 1
@@ -190,6 +191,19 @@ def test_mamba2_chunk_scan_selective_state_update(dim, headdim, ngroups, dstate,
         outputs = (out, ssm_state)

     else:
+        if paged_cache:
+            padded_batch_size = 2 * batch_size
+            state_batch_indices = torch.randperm(padded_batch_size,
+                                                 device=device,
+                                                 dtype=torch.int32)[:batch_size]
+            orig_state = state.detach().clone()
+            state = torch.empty([padded_batch_size, nheads, headdim, dstate],
+                                dtype=torch_dtype,
+                                device=device)
+            state[state_batch_indices] = orig_state
+        else:
+            state_batch_indices = None
+
         y = selective_state_update(
             state,
             x,
@@ -201,8 +215,10 @@ def test_mamba2_chunk_scan_selective_state_update(dim, headdim, ngroups, dstate,
             z=z if has_z else None,
             dt_bias=dt_bias,
             dt_softplus=delta_softplus,
+            state_batch_indices=state_batch_indices,
         )
-        outputs = (y, state)
+        outputs = (y, state[state_batch_indices]
+                   if state_batch_indices is not None else state)

     # pytorch run
     if req_type == 'context':
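Note: gathering state[state_batch_indices] before building outputs keeps a single reference check for both the dense and the paged layouts. The normalization step, factored out as a small sketch (function name and variables are illustrative, mirroring the test locals):

import torch

def dense_view(state, state_batch_indices=None):
    # gather to a per-request view only when a paged layout was used
    return state if state_batch_indices is None else state[state_batch_indices.long()]

pool = torch.arange(12.).reshape(6, 2)
idx = torch.tensor([4, 1], dtype=torch.int32)
print(dense_view(pool, idx))    # rows 4 and 1 of the pool
print(dense_view(pool[:2]))     # dense layout passes through unchanged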