# SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
from itertools import product

import numpy as np
import pytest
import torch
from einops import rearrange, repeat
from parameterized import parameterized
from utils.torch_ref import (selective_scan_ref, selective_state_update_ref,
ssd_chunk_scan_combined_ref)
from utils.util import unittest_name_func

import tensorrt_llm
from tensorrt_llm import Tensor
from tensorrt_llm._utils import str_dtype_to_torch


class TestFunctional(unittest.TestCase):
def setUp(self):
tensorrt_llm.logger.set_level('error')

@parameterized.expand(list(
product([2048], [16], ['context', 'generation'],
['float16', 'float32', 'bfloat16'], [3], [16], [False, True])),
name_func=unittest_name_func)
def test_selective_scan(self, dim, dstate, req_type, dtype, batch_size,
max_seq_len, remove_padding):
# configs
device = "cuda"
seq_len = max_seq_len if req_type == 'context' else 1
dt_rank = 160
delta_softplus = True
# test data
torch.random.manual_seed(0)
if remove_padding:
last_token_ids = torch.randint(1,
seq_len + 1, (batch_size, ),
dtype=torch.int32)
host_context_lengths = last_token_ids.detach().clone().cpu()
last_token_ids = torch.cumsum(last_token_ids,
dim=0,
dtype=torch.int32).to(device)
total_num_tokens = last_token_ids[batch_size - 1]
else:
last_token_ids = torch.ones(
(batch_size, ), dtype=torch.int32, device=device) * seq_len
host_context_lengths = last_token_ids.detach().clone().cpu()
total_num_tokens = batch_size * seq_len
state = torch.randn(batch_size,
dstate,
dim,
device=device,
dtype=str_dtype_to_torch(dtype))
x = torch.randn(total_num_tokens,
dim,
device=device,
dtype=str_dtype_to_torch(dtype))
dt = torch.randn(total_num_tokens,
dim,
device=device,
dtype=str_dtype_to_torch(dtype))
dt_bias = torch.rand(dim, device=device) - 4.0
A = -torch.rand(dstate, dim, device=device) - 1.0
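# Keep A strictly negative so the discretized state transition exp(dt * A) stays bounded.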
BC = torch.randn(total_num_tokens,
dt_rank + dstate * 2,
device=device,
dtype=str_dtype_to_torch(dtype))
D = torch.randn(dim, device=device)
z = torch.randn_like(x)
host_request_types = torch.tensor([0 if req_type == 'context' else 1] *
batch_size,
dtype=torch.int32)
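# Per-sequence request type: 0 = context (prefill), 1 = generation (decode).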
if not remove_padding or req_type == 'generation':
x = x.view(-1, seq_len, dim)
dt = dt.view(-1, seq_len, dim)
BC = BC.view(-1, seq_len, dt_rank + dstate * 2)
z = z.view(-1, seq_len, dim)
if remove_padding and req_type == "generation":
output = torch.zeros(x.squeeze(1).shape,
device=device,
dtype=str_dtype_to_torch(dtype))
else:
output = torch.zeros(x.shape,
device=device,
dtype=str_dtype_to_torch(dtype))
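# Detached copies for the PyTorch reference; B and C are sliced out of the packed
# BC buffer (layout: [dt_rank | B (dstate) | C (dstate)]).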
state_ref = state.detach().clone()
x_ref = x.detach().clone()
dt_ref = dt.detach().clone()
dt_bias_ref = dt_bias.detach().clone()
A_ref = A.detach().clone()
B_ref = BC[..., dt_rank:dt_rank + dstate].detach().clone()
C_ref = BC[..., dt_rank + dstate:].detach().clone()
D_ref = D.detach().clone()
z_ref = z.detach().clone()
if remove_padding and req_type == "generation":
x = x.squeeze(1)
# construct trt network
builder = tensorrt_llm.Builder()
net = builder.create_network()
if remove_padding:
net.plugin_config.remove_input_padding = True
else:
net.plugin_config.remove_input_padding = False
net.plugin_config.paged_state = False
with tensorrt_llm.net_guard(net):
x_tensor = Tensor(name='input',
shape=x.shape,
dtype=tensorrt_llm.str_dtype_to_trt(dtype))
state_tensor = Tensor(name='state',
shape=state.shape,
dtype=tensorrt_llm.str_dtype_to_trt(dtype))
dt_tensor = Tensor(name='delta',
shape=dt.shape,
dtype=tensorrt_llm.str_dtype_to_trt(dtype))
dt_bias_tensor = Tensor(
name='delta_bias',
shape=dt_bias.shape,
dtype=tensorrt_llm.str_dtype_to_trt('float32'))
A_tensor = Tensor(name='A',
shape=A.shape,
dtype=tensorrt_llm.str_dtype_to_trt('float32'))
BC_tensor = Tensor(name='BC',
shape=BC.shape,
dtype=tensorrt_llm.str_dtype_to_trt(dtype))
D_tensor = Tensor(name='D',
shape=D.shape,
dtype=tensorrt_llm.str_dtype_to_trt('float32'))
z_tensor = Tensor(name='z',
shape=z.shape,
dtype=tensorrt_llm.str_dtype_to_trt(dtype))
host_request_types_tensor = Tensor(
name='host_request_types',
shape=host_request_types.shape,
dtype=tensorrt_llm.str_dtype_to_trt('int32'))
last_token_ids_tensor = Tensor(
name='last_token_ids',
shape=last_token_ids.shape,
dtype=tensorrt_llm.str_dtype_to_trt('int32'))
host_context_lengths_tensor = None
if remove_padding:
host_context_lengths_tensor = Tensor(
name='host_context_lengths',
shape=host_context_lengths.shape,
dtype=tensorrt_llm.str_dtype_to_trt('int32'))
outputs = tensorrt_llm.functional.selective_scan(
x_tensor,
state_tensor,
dt_tensor,
dt_bias_tensor,
A_tensor,
BC_tensor,
D_tensor,
host_request_types_tensor,
last_token_ids_tensor,
dim,
dstate,
dt_rank,
delta_softplus,
dtype,
host_context_lengths=host_context_lengths_tensor,
z=z_tensor)
net._mark_output(outputs[0],
'output',
dtype=tensorrt_llm.str_dtype_to_trt(dtype))
net._mark_output(outputs[1],
'present_state',
dtype=tensorrt_llm.str_dtype_to_trt(dtype))
# trt run
inputs = {
'input': x,
'state': state,
'delta': dt,
'delta_bias': dt_bias,
'A': A,
'BC': BC,
'D': D,
'z': z,
'host_request_types': host_request_types,
'last_token_ids': last_token_ids
}
if remove_padding:
inputs['host_context_lengths'] = host_context_lengths
outputs = {'output': output, 'present_state': state}
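# Reuse the input 'state' tensor as the 'present_state' output buffer, so the SSM
# state is updated in place.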
stream = torch.cuda.current_stream()
builder_config = builder.create_builder_config(precision=dtype, )
engine = builder.build_engine(net, builder_config)
session = tensorrt_llm.runtime.Session.from_serialized_engine(engine)
session.run(inputs=inputs, outputs=outputs, stream=stream.cuda_stream)
if remove_padding and req_type == "generation":
out_ref = torch.zeros(output.unsqueeze(1).shape,
device=device,
dtype=str_dtype_to_torch(dtype))
else:
out_ref = torch.zeros(output.shape,
device=device,
dtype=str_dtype_to_torch(dtype))
if req_type == 'context':
# pytorch run
if remove_padding:
for i in range(batch_size):
start_id = 0 if i == 0 else last_token_ids[i - 1]
end_id = last_token_ids[i]
part_out_ref, part_state_ref = selective_scan_ref(
x_ref[start_id:end_id].unsqueeze(0),
dt_ref[start_id:end_id].unsqueeze(0),
A_ref,
B_ref[start_id:end_id].unsqueeze(0),
C_ref[start_id:end_id].unsqueeze(0),
D=D_ref,
z=z_ref[start_id:end_id].unsqueeze(0),
delta_bias=dt_bias_ref,
delta_softplus=True)
out_ref[start_id:end_id][:] = part_out_ref.squeeze(0)
state_ref[i][:][:] = part_state_ref.squeeze(0)
else:
out_ref, state_ref = selective_scan_ref(x_ref,
dt_ref,
A_ref,
B_ref,
C_ref,
D=D_ref,
z=z_ref,
delta_bias=dt_bias_ref,
delta_softplus=True)
elif req_type == 'generation':
# pytorch run
out_ref = selective_state_update_ref(state_ref,
x_ref.squeeze(1),
dt_ref.squeeze(1),
A_ref,
B_ref.squeeze(1),
C_ref.squeeze(1),
D=D_ref,
z=z_ref.squeeze(1),
dt_bias=dt_bias_ref,
dt_softplus=True)
out_ref = out_ref.unsqueeze(1)
dtype_atol = {"float16": 5e-3, "float32": 2e-3, "bfloat16": 5e-2}
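# Check the plugin outputs against the PyTorch reference within the per-dtype
# absolute tolerances above.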
if remove_padding and req_type == "generation":
out_ref = out_ref.squeeze(1)
output_cpu = outputs['output'].to(torch.float32).cpu()
present_state_cpu = outputs['present_state'].to(torch.float32).cpu()
np.testing.assert_allclose(out_ref.to(torch.float32).cpu().numpy(),
output_cpu.numpy(),
atol=dtype_atol[dtype])
np.testing.assert_allclose(state_ref.to(torch.float32).cpu().numpy(),
present_state_cpu.numpy(),
atol=dtype_atol[dtype])

@parameterized.expand(
# P=8x and H=2x
list(
product([160, 320, 640], [80], [1], [128], ['context'], ['float16'],
[1, 2, 8, 16], [16, 64, 256], [True], [True])) +
# normal tests
list(
product([2048], [64], [1, 4], [128], ['context', 'generation'],
['float32', 'float16', 'bfloat16'], [3], [16],
[True, False], [True, False])) +
# arbitrary N generation tests
list(
product([2048], [64], [1, 4], [16, 32, 48, 64, 80, 96, 128, 256],
['generation'], ['float32', 'float16'], [3], [16], [True],
[True])) +
# long sequence tests to cover the int overflow issue
list(
product([5120], [64], [1], [128], ['context'], ['float16'], [2],
[131072], [True, False], [True, False])) +
# P=8x and H=2x
list(
product([144], [72], [1], [64, 128, 256], ['context', 'generation'],
['float16'], [16], [16384], [True, False], [True, False])),
name_func=unittest_name_func)
def test_selective_scan_v2(self, dim, headdim, ngroups, dstate, req_type,
dtype, batch_size, max_seq_len, has_z,
remove_padding):
# Skip tests that are not supported
if dtype == 'float32' and req_type == 'context':
pytest.skip(
"Mamba2 chunk scan kernel only support float16 and bfloat16")
if max_seq_len >= 128 * 1024:
total_gpu_mem = torch.cuda.get_device_properties(0).total_memory
if total_gpu_mem <= 68 * 1024**3:
pytest.skip(
"The long sequence test needs at least 68GB memory, skipping"
)
# configs
device = "cuda"
seq_len = max_seq_len if req_type == 'context' else 1
long_context = max_seq_len >= 128 * 1024
chunk_size = 256
nheads = dim // headdim
nheads_pad0 = (nheads + 7) // 8 * 8 - nheads
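# Number of dummy heads needed to round nheads up to a multiple of 8; the dt slice
# of zxBCdt is padded with these below.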
delta_softplus = True
mean = 0.0
if long_context:
std_dev = 0.05
elif dtype == "float32":
std_dev = 0.5
else:
std_dev = 0.1
# test data
torch.random.manual_seed(0)
if remove_padding:
last_token_ids = torch.randint(1,
seq_len + 1, (batch_size, ),
dtype=torch.int32)
last_token_ids[0] = seq_len
host_context_lengths = last_token_ids.detach().clone().cpu()
last_token_ids = torch.cumsum(last_token_ids,
dim=0,
dtype=torch.int32).to(device)
total_num_tokens = last_token_ids[batch_size - 1]
else:
last_token_ids = torch.ones(
(batch_size, ), dtype=torch.int32, device=device) * seq_len
host_context_lengths = last_token_ids.detach().clone().cpu()
total_num_tokens = batch_size * seq_len
state = torch.empty(batch_size,
nheads,
dstate,
headdim,
device=device,
dtype=str_dtype_to_torch(dtype))
x = torch.empty(total_num_tokens,
dim,
device=device,
dtype=str_dtype_to_torch(dtype))
x.normal_(mean, std_dev)
state.normal_(mean, std_dev)
dt = torch.randn(total_num_tokens,
nheads,
device=device,
dtype=str_dtype_to_torch(dtype))
if nheads_pad0:
dt_pad0 = torch.randn(total_num_tokens,
nheads_pad0,
device=device,
dtype=str_dtype_to_torch(dtype))
dt_bias = torch.rand(nheads, device=device) - 4.0
A = -torch.rand(nheads, device=device) - 1.0
BC = torch.randn(total_num_tokens,
ngroups * dstate * 2,
device=device,
dtype=str_dtype_to_torch(dtype))
D = torch.randn(nheads, device=device)
if has_z:
z = torch.randn_like(x)
host_request_types = torch.tensor([0 if req_type == 'context' else 1] *
batch_size,
dtype=torch.int32)
if not remove_padding or req_type == 'generation':
x = x.view(-1, seq_len, dim)
dt = dt.view(-1, seq_len, nheads)
if nheads_pad0:
dt_pad0 = dt_pad0.view(-1, seq_len, nheads_pad0)
BC = BC.view(-1, seq_len, ngroups * dstate * 2)
if has_z:
z = z.view(-1, seq_len, dim)
xBC = torch.concat([x, BC], dim=-1).contiguous()
if has_z:
zxBCdt = torch.concat([z, torch.randn_like(xBC), dt] + ([
dt_pad0,
] if nheads_pad0 else []),
dim=-1).contiguous()
else:
zxBCdt = torch.concat([torch.randn_like(xBC), dt] + ([
dt_pad0,
] if nheads_pad0 else []),
dim=-1).contiguous()
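# The Mamba2 path packs z, xBC and dt (plus head padding) into one fused zxBCdt
# buffer; its xBC slice is random filler because the real xBC is fed through the
# separate 'input'/'BC' bindings.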
if remove_padding and req_type == "generation":
xBC = xBC.squeeze(1)
if remove_padding and req_type == "generation":
output = torch.zeros(x.squeeze(1).shape,
device=device,
dtype=str_dtype_to_torch(dtype))
else:
output = torch.zeros(x.shape,
device=device,
dtype=str_dtype_to_torch(dtype))
state_ref = state.detach().clone()
x_ref = x.detach().clone()
dt_ref = dt.detach().clone()
dt_bias_ref = dt_bias.detach().clone()
A_ref = A.detach().clone()
B_ref = BC[..., 0:ngroups * dstate].detach().clone()
C_ref = BC[..., ngroups * dstate:].detach().clone()
D_ref = D.detach().clone()
z_ref = z.detach().clone() if has_z else None
# construct trt network
builder = tensorrt_llm.Builder()
net = builder.create_network()
if remove_padding:
net.plugin_config.remove_input_padding = True
else:
net.plugin_config.remove_input_padding = False
net.plugin_config.paged_state = False
with tensorrt_llm.net_guard(net):
x_tensor = Tensor(name='input',
shape=xBC.shape,
dtype=tensorrt_llm.str_dtype_to_trt(dtype))
state_tensor = Tensor(name='state',
shape=state.shape,
dtype=tensorrt_llm.str_dtype_to_trt(dtype))
dt_tensor = Tensor(name='delta',
shape=zxBCdt.shape,
dtype=tensorrt_llm.str_dtype_to_trt(dtype))
dt_bias_tensor = Tensor(
name='delta_bias',
shape=dt_bias.shape,
dtype=tensorrt_llm.str_dtype_to_trt('float32'))
A_tensor = Tensor(name='A',
shape=A.shape,
dtype=tensorrt_llm.str_dtype_to_trt('float32'))
BC_tensor = Tensor(name='BC',
shape=xBC.shape,
dtype=tensorrt_llm.str_dtype_to_trt(dtype))
D_tensor = Tensor(name='D',
shape=D.shape,
dtype=tensorrt_llm.str_dtype_to_trt('float32'))
if has_z:
z_tensor = Tensor(name='z',
shape=zxBCdt.shape,
dtype=tensorrt_llm.str_dtype_to_trt(dtype))
host_request_types_tensor = Tensor(
name='host_request_types',
shape=host_request_types.shape,
dtype=tensorrt_llm.str_dtype_to_trt('int32'))
last_token_ids_tensor = Tensor(
name='last_token_ids',
shape=last_token_ids.shape,
dtype=tensorrt_llm.str_dtype_to_trt('int32'))
host_context_lengths_tensor = None
if remove_padding:
host_context_lengths_tensor = Tensor(
name='host_context_lengths',
shape=host_context_lengths.shape,
dtype=tensorrt_llm.str_dtype_to_trt('int32'))
outputs = tensorrt_llm.functional.selective_scan(
x_tensor,
state_tensor,
dt_tensor,
dt_bias_tensor,
A_tensor,
BC_tensor,
D_tensor,
host_request_types_tensor,
last_token_ids_tensor,
dim,
dstate,
0,
delta_softplus,
dtype,
z=z_tensor if has_z else None,
host_context_lengths=host_context_lengths_tensor,
nheads=nheads,
ngroups=ngroups,
chunk_size=chunk_size,
mamba_version='Mamba2')
net._mark_output(outputs[0],
'output',
dtype=tensorrt_llm.str_dtype_to_trt(dtype))
net._mark_output(outputs[1],
'present_state',
dtype=tensorrt_llm.str_dtype_to_trt(dtype))
# trt run
inputs = {
'input': xBC,
'state': state,
'delta': zxBCdt,
'delta_bias': dt_bias,
'A': A,
'BC': xBC,
'D': D,
'host_request_types': host_request_types,
'last_token_ids': last_token_ids
}
if remove_padding:
inputs['host_context_lengths'] = host_context_lengths
if has_z:
inputs['z'] = zxBCdt
outputs = {'output': output, 'present_state': state}
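# As in the Mamba1 test, the input 'state' tensor doubles as the 'present_state'
# output buffer.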
stream = torch.cuda.current_stream()
builder_config = builder.create_builder_config(precision=dtype, )
engine = builder.build_engine(net, builder_config)
session = tensorrt_llm.runtime.Session.from_serialized_engine(engine)
session.run(inputs=inputs, outputs=outputs, stream=stream.cuda_stream)
if remove_padding and req_type == "generation":
out_ref = torch.zeros(output.unsqueeze(1).shape,
device=device,
dtype=str_dtype_to_torch(dtype))
else:
out_ref = torch.zeros(output.shape,
device=device,
dtype=str_dtype_to_torch(dtype))
# pytorch run
if req_type == 'context':
if remove_padding:
for i in range(batch_size):
start = 0 if i == 0 else last_token_ids[i - 1]
end = last_token_ids[i]
x_reshaped = rearrange(x_ref[start:end].unsqueeze(0),
"b l (h p) -> b l h p",
p=headdim)
B_ref_reshaped = rearrange(B_ref[start:end].unsqueeze(0),
"b l (g n) -> b l g n",
g=ngroups)
C_ref_reshaped = rearrange(C_ref[start:end].unsqueeze(0),
"b l (g n) -> b l g n",
g=ngroups)
z_ref_reshaped = rearrange(z_ref[start:end].unsqueeze(0),
"b l (h p) -> b l h p",
p=headdim) if has_z else None
part_out_ref, part_state_ref = ssd_chunk_scan_combined_ref(
x_reshaped,
dt_ref[start:end].unsqueeze(0),
A_ref,
B_ref_reshaped,
C_ref_reshaped,
chunk_size,
D=D_ref,
z=z_ref_reshaped,
dt_bias=dt_bias_ref,
dt_softplus=delta_softplus)
part_out_ref = rearrange(part_out_ref,
"b l h p -> b l (h p)")
out_ref[
start:end,
] = part_out_ref.squeeze(0)
state_ref[
i,
] = part_state_ref.squeeze(0)
elif long_context:
# to save memory
for i in range(batch_size):
x_reshaped = rearrange(x_ref[
i:i + 1,
],
"b l (h p) -> b l h p",
p=headdim)
B_ref_reshaped = rearrange(B_ref[
i:i + 1,
],
"b l (g n) -> b l g n",
g=ngroups)
C_ref_reshaped = rearrange(C_ref[
i:i + 1,
],
"b l (g n) -> b l g n",
g=ngroups)
z_ref_reshaped = rearrange(
z_ref[
i:i + 1,
], "b l (h p) -> b l h p", p=headdim) if has_z else None
part_out_ref, part_state_ref = ssd_chunk_scan_combined_ref(
x_reshaped,
dt_ref[
i:i + 1,
],
A_ref,
B_ref_reshaped,
C_ref_reshaped,
chunk_size,
D=D_ref,
z=z_ref_reshaped,
dt_bias=dt_bias_ref,
dt_softplus=delta_softplus)
part_out_ref = rearrange(part_out_ref,
"b l h p -> b l (h p)")
out_ref[
i,
] = part_out_ref.squeeze(0)
state_ref[
i,
] = part_state_ref.squeeze(0)
else:
x_reshaped = rearrange(x_ref, "b l (h p) -> b l h p", p=headdim)
B_ref_reshaped = rearrange(B_ref,
"b l (g n) -> b l g n",
g=ngroups)
C_ref_reshaped = rearrange(C_ref,
"b l (g n) -> b l g n",
g=ngroups)
z_ref_reshaped = rearrange(
z_ref, "b l (h p) -> b l h p", p=headdim) if has_z else None
out_ref, state_ref = ssd_chunk_scan_combined_ref(
x_reshaped,
dt_ref,
A_ref,
B_ref_reshaped,
C_ref_reshaped,
chunk_size,
D=D_ref,
z=z_ref_reshaped,
dt_bias=dt_bias_ref,
dt_softplus=delta_softplus)
out_ref = rearrange(out_ref, "b l h p -> b l (h p)")
elif req_type == 'generation':
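# Broadcast the per-head parameters to the per-element shapes expected by the
# reference state update.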
A_ref = repeat(A_ref, "h -> h n p", p=headdim,
n=dstate).to(dtype=torch.float32)
dt_ref = repeat(dt_ref.squeeze(1), "b h -> b h p", p=headdim)
dt_bias_ref = repeat(dt_bias_ref, "h -> h p", p=headdim)
D_ref = repeat(D_ref, "h -> h p", p=headdim)
B_ref = rearrange(B_ref.squeeze(1), "b (g n) -> b g n", g=ngroups)
C_ref = rearrange(C_ref.squeeze(1), "b (g n) -> b g n", g=ngroups)
x_reshaped = rearrange(x_ref.squeeze(1),
"b (h p) -> b h p",
p=headdim)
if has_z:
z_ref = rearrange(z_ref.squeeze(1),
"b (h p) -> b h p",
p=headdim)
out_ref = selective_state_update_ref(state_ref,
x_reshaped,
dt_ref,
A_ref,
B_ref,
C_ref,
D=D_ref,
z=z_ref,
dt_bias=dt_bias_ref,
dt_softplus=delta_softplus)
out_ref = rearrange(out_ref, "b h p -> b (h p)").unsqueeze(1)
dtype_atol = {"float16": 5e-2, "float32": 2e-3, "bfloat16": 1e-1}
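# Check the plugin outputs against the reference within the per-dtype tolerances
# above; bfloat16 needs the loosest bound.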
if remove_padding and req_type == "generation":
out_ref = out_ref.squeeze(1)
output_cpu = outputs['output'].to(torch.float32).cpu()
present_state_cpu = outputs['present_state'].to(torch.float32).cpu()
np.testing.assert_allclose(out_ref.to(torch.float32).cpu().numpy(),
output_cpu.numpy(),
atol=dtype_atol[dtype])
np.testing.assert_allclose(state_ref.to(torch.float32).cpu().numpy(),
present_state_cpu.numpy(),
atol=dtype_atol[dtype])