# TensorRT-LLM/tests/functional/test_selective_scan.py
# SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import sys
import unittest
from itertools import product

import numpy as np
import pytest
import torch
from einops import rearrange, repeat
from parameterized import parameterized
from torch_ref import (selective_scan_ref, selective_state_update_ref,
                       ssd_chunk_scan_combined_ref)

import tensorrt_llm
from tensorrt_llm import Tensor
from tensorrt_llm._utils import str_dtype_to_torch

sys.path.append(os.path.join(os.path.dirname(__file__), '..'))
from utils.util import getSMVersion, skip_bf16_pre_ampere, unittest_name_func
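
# These tests build a small TensorRT network around the selective_scan
# operator and compare its output and the updated SSM state against the
# PyTorch references in torch_ref, for both the context (prefill) and
# generation (single-token) phases, with and without remove_input_padding.
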
class TestFunctional(unittest.TestCase):

    def setUp(self):
        tensorrt_llm.logger.set_level('error')

    @parameterized.expand(list(
        product([2048], [16], ['context', 'generation'],
                ['float16', 'float32', 'bfloat16'], [3], [16], [False, True])),
        name_func=unittest_name_func)
    def test_selective_scan(self, dim, dstate, req_type, dtype, batch_size,
                            max_seq_len, remove_padding):
        # Skip tests that are not supported in pre-ampere architecture
        skip_bf16_pre_ampere(dtype)

        # configs
        device = "cuda"
        seq_len = max_seq_len if req_type == 'context' else 1
        dt_rank = 160
        delta_softplus = True

        # test data
        torch.random.manual_seed(0)
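        # With remove_padding the sequences are packed along a single token
        # dimension and last_token_ids holds the cumulative sequence lengths;
        # otherwise inputs stay padded to [batch_size, seq_len, ...] and
        # last_token_ids is simply the per-sequence length.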
        if remove_padding:
            last_token_ids = torch.randint(1,
                                           seq_len + 1, (batch_size, ),
                                           dtype=torch.int32)
            host_context_lengths = last_token_ids.detach().clone().cpu()
            last_token_ids = torch.cumsum(last_token_ids,
                                          dim=0,
                                          dtype=torch.int32).to(device)
            total_num_tokens = last_token_ids[batch_size - 1]
        else:
            last_token_ids = torch.ones(
                (batch_size, ), dtype=torch.int32, device=device) * seq_len
            host_context_lengths = last_token_ids.detach().clone().cpu()
            total_num_tokens = batch_size * seq_len
        state = torch.randn(batch_size,
                            dstate,
                            dim,
                            device=device,
                            dtype=str_dtype_to_torch(dtype))
        x = torch.randn(total_num_tokens,
                        dim,
                        device=device,
                        dtype=str_dtype_to_torch(dtype))
        dt = torch.randn(total_num_tokens,
                         dim,
                         device=device,
                         dtype=str_dtype_to_torch(dtype))
        dt_bias = torch.rand(dim, device=device) - 4.0
        A = -torch.rand(dstate, dim, device=device) - 1.0
        BC = torch.randn(total_num_tokens,
                         dt_rank + dstate * 2,
                         device=device,
                         dtype=str_dtype_to_torch(dtype))
        D = torch.randn(dim, device=device)
        z = torch.randn_like(x)
        host_request_types = torch.tensor([0 if req_type == 'context' else 1] *
                                          batch_size,
                                          dtype=torch.int32)
        if not remove_padding or req_type == 'generation':
            x = x.view(-1, seq_len, dim)
            dt = dt.view(-1, seq_len, dim)
            BC = BC.view(-1, seq_len, dt_rank + dstate * 2)
            z = z.view(-1, seq_len, dim)
        output = torch.zeros(x.shape,
                             device=device,
                             dtype=str_dtype_to_torch(dtype))

        state_ref = state.detach().clone()
        x_ref = x.detach().clone()
        dt_ref = dt.detach().clone()
        dt_bias_ref = dt_bias.detach().clone()
        A_ref = A.detach().clone()
        B_ref = BC[..., dt_rank:dt_rank + dstate].detach().clone()
        C_ref = BC[..., dt_rank + dstate:].detach().clone()
        D_ref = D.detach().clone()
        z_ref = z.detach().clone()
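
        # The state buffer below is bound both as the 'state' input and the
        # 'present_state' output of the plugin, so it is updated in place by
        # the TRT run; the *_ref clones above keep pristine copies for the
        # PyTorch reference.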
        # construct trt network
        builder = tensorrt_llm.Builder()
        net = builder.create_network()
        if remove_padding:
            net.plugin_config.remove_input_padding = True
        else:
            net.plugin_config.remove_input_padding = False
        net.plugin_config.paged_state = False
        with tensorrt_llm.net_guard(net):
            x_tensor = Tensor(name='input',
                              shape=x.shape,
                              dtype=tensorrt_llm.str_dtype_to_trt(dtype))
            state_tensor = Tensor(name='state',
                                  shape=state.shape,
                                  dtype=tensorrt_llm.str_dtype_to_trt(dtype))
            dt_tensor = Tensor(name='delta',
                               shape=dt.shape,
                               dtype=tensorrt_llm.str_dtype_to_trt(dtype))
            dt_bias_tensor = Tensor(
                name='delta_bias',
                shape=dt_bias.shape,
                dtype=tensorrt_llm.str_dtype_to_trt('float32'))
            A_tensor = Tensor(name='A',
                              shape=A.shape,
                              dtype=tensorrt_llm.str_dtype_to_trt('float32'))
            BC_tensor = Tensor(name='BC',
                               shape=BC.shape,
                               dtype=tensorrt_llm.str_dtype_to_trt(dtype))
            D_tensor = Tensor(name='D',
                              shape=D.shape,
                              dtype=tensorrt_llm.str_dtype_to_trt('float32'))
            z_tensor = Tensor(name='z',
                              shape=z.shape,
                              dtype=tensorrt_llm.str_dtype_to_trt(dtype))
            host_request_types_tensor = Tensor(
                name='host_request_types',
                shape=host_request_types.shape,
                dtype=tensorrt_llm.str_dtype_to_trt('int32'))
            last_token_ids_tensor = Tensor(
                name='last_token_ids',
                shape=last_token_ids.shape,
                dtype=tensorrt_llm.str_dtype_to_trt('int32'))
            host_context_lengths_tensor = None
            if remove_padding:
                host_context_lengths_tensor = Tensor(
                    name='host_context_lengths',
                    shape=host_context_lengths.shape,
                    dtype=tensorrt_llm.str_dtype_to_trt('int32'))
            outputs = tensorrt_llm.functional.selective_scan(
                x_tensor,
                state_tensor,
                dt_tensor,
                dt_bias_tensor,
                A_tensor,
                BC_tensor,
                D_tensor,
                host_request_types_tensor,
                last_token_ids_tensor,
                dim,
                dstate,
                dt_rank,
                delta_softplus,
                dtype,
                host_context_lengths=host_context_lengths_tensor,
                z=z_tensor)
            net._mark_output(outputs[0],
                             'output',
                             dtype=tensorrt_llm.str_dtype_to_trt(dtype))
            net._mark_output(outputs[1],
                             'present_state',
                             dtype=tensorrt_llm.str_dtype_to_trt(dtype))
        # trt run
        inputs = {
            'input': x,
            'state': state,
            'delta': dt,
            'delta_bias': dt_bias,
            'A': A,
            'BC': BC,
            'D': D,
            'z': z,
            'host_request_types': host_request_types,
            'last_token_ids': last_token_ids
        }
        if remove_padding:
            inputs['host_context_lengths'] = host_context_lengths
        outputs = {'output': output, 'present_state': state}
        stream = torch.cuda.current_stream()
        builder_config = builder.create_builder_config(precision=dtype, )
        engine = builder.build_engine(net, builder_config)
        session = tensorrt_llm.runtime.Session.from_serialized_engine(engine)
        session.run(inputs=inputs, outputs=outputs, stream=stream.cuda_stream)

        out_ref = torch.zeros(output.shape,
                              device=device,
                              dtype=str_dtype_to_torch(dtype))
        if req_type == 'context':
            # pytorch run
            if remove_padding:
                for i in range(batch_size):
                    start_id = 0 if i == 0 else last_token_ids[i - 1]
                    end_id = last_token_ids[i]
                    part_out_ref, part_state_ref = selective_scan_ref(
                        x_ref[start_id:end_id].unsqueeze(0),
                        dt_ref[start_id:end_id].unsqueeze(0),
                        A_ref,
                        B_ref[start_id:end_id].unsqueeze(0),
                        C_ref[start_id:end_id].unsqueeze(0),
                        D=D_ref,
                        z=z_ref[start_id:end_id].unsqueeze(0),
                        delta_bias=dt_bias_ref,
                        delta_softplus=True)
                    out_ref[start_id:end_id][:] = part_out_ref.squeeze(0)
                    state_ref[i][:][:] = part_state_ref.squeeze(0)
            else:
                out_ref, state_ref = selective_scan_ref(x_ref,
                                                        dt_ref,
                                                        A_ref,
                                                        B_ref,
                                                        C_ref,
                                                        D=D_ref,
                                                        z=z_ref,
                                                        delta_bias=dt_bias_ref,
                                                        delta_softplus=True)
        elif req_type == 'generation':
            # pytorch run
            out_ref = selective_state_update_ref(state_ref,
                                                 x_ref.squeeze(1),
                                                 dt_ref.squeeze(1),
                                                 A_ref,
                                                 B_ref.squeeze(1),
                                                 C_ref.squeeze(1),
                                                 D=D_ref,
                                                 z=z_ref.squeeze(1),
                                                 dt_bias=dt_bias_ref,
                                                 dt_softplus=True)
            out_ref = out_ref.unsqueeze(1)
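
        # Compare in float32 on the host; the absolute tolerance is loosened
        # for the lower-precision dtypes.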
        dtype_atol = {"float16": 5e-3, "float32": 2e-3, "bfloat16": 5e-2}

        output_cpu = outputs['output'].to(torch.float32).cpu()
        present_state_cpu = outputs['present_state'].to(torch.float32).cpu()
        np.testing.assert_allclose(out_ref.to(torch.float32).cpu().numpy(),
                                   output_cpu.numpy(),
                                   atol=dtype_atol[dtype])
        np.testing.assert_allclose(state_ref.to(torch.float32).cpu().numpy(),
                                   present_state_cpu.numpy(),
                                   atol=dtype_atol[dtype])

    @parameterized.expand(
        list(
            product([2048], [64], [1, 4], ['context', 'generation'],
                    ['float32', 'float16', 'bfloat16'], [3], [16],
                    [True, False], [True, False])) +
        # long sequence tests to cover the int overflow issue
        list(
            product([5120], [64], [1], ['context'], ['float16'], [2], [131072],
                    [True, False], [True, False])),
        name_func=unittest_name_func)
    def test_selective_scan_v2(self, dim, headdim, ngroups, req_type, dtype,
                               batch_size, max_seq_len, has_z, remove_padding):
        # Skip tests that are not supported
        skip_bf16_pre_ampere(dtype)
        if dtype == 'float32' and req_type == 'context':
            pytest.skip(
                "Mamba2 chunk scan kernel only support float16 and bfloat16")
        if getSMVersion() < 80:
            pytest.skip("Mamba2 is not supported in pre-Ampere architecture")
        if max_seq_len >= 128 * 1024:
            total_gpu_mem = torch.cuda.get_device_properties(0).total_memory
            if total_gpu_mem <= 68 * 1024**3:
                pytest.skip(
                    "The long sequence test needs at least 68GB memory, skipping"
                )

        # configs
        device = "cuda"
        seq_len = max_seq_len if req_type == 'context' else 1
        long_context = max_seq_len >= 128 * 1024
        dstate = 128
        chunk_size = 256
        nheads = dim // headdim
        delta_softplus = True
        mean = 0.0
        if long_context:
            std_dev = 0.05
        elif dtype == "float32":
            std_dev = 0.5
        else:
            std_dev = 0.1
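
        # Mamba2 layout: dim is split into nheads heads of size headdim, A and
        # D are per-head, dt is per (token, head), and B/C are shared across
        # ngroups groups of size dstate; the SSM state is
        # [batch_size, nheads, dstate, headdim].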
        # test data
        torch.random.manual_seed(0)
        if remove_padding:
            last_token_ids = torch.randint(1,
                                           seq_len + 1, (batch_size, ),
                                           dtype=torch.int32)
            last_token_ids[0] = seq_len
            host_context_lengths = last_token_ids.detach().clone().cpu()
            last_token_ids = torch.cumsum(last_token_ids,
                                          dim=0,
                                          dtype=torch.int32).to(device)
            total_num_tokens = last_token_ids[batch_size - 1]
        else:
            last_token_ids = torch.ones(
                (batch_size, ), dtype=torch.int32, device=device) * seq_len
            host_context_lengths = last_token_ids.detach().clone().cpu()
            total_num_tokens = batch_size * seq_len
        state = torch.empty(batch_size,
                            nheads,
                            dstate,
                            headdim,
                            device=device,
                            dtype=str_dtype_to_torch(dtype))
        x = torch.empty(total_num_tokens,
                        dim,
                        device=device,
                        dtype=str_dtype_to_torch(dtype))
        x.normal_(mean, std_dev)
        state.normal_(mean, std_dev)
        dt = torch.randn(total_num_tokens,
                         nheads,
                         device=device,
                         dtype=str_dtype_to_torch(dtype))
        dt_bias = torch.rand(nheads, device=device) - 4.0
        A = -torch.rand(nheads, device=device) - 1.0
        BC = torch.randn(total_num_tokens,
                         ngroups * dstate * 2,
                         device=device,
                         dtype=str_dtype_to_torch(dtype))
        D = torch.randn(nheads, device=device)
        if has_z:
            z = torch.randn_like(x)
        host_request_types = torch.tensor([0 if req_type == 'context' else 1] *
                                          batch_size,
                                          dtype=torch.int32)
        if not remove_padding or req_type == 'generation':
            x = x.view(-1, seq_len, dim)
            dt = dt.view(-1, seq_len, nheads)
            BC = BC.view(-1, seq_len, ngroups * dstate * 2)
            if has_z:
                z = z.view(-1, seq_len, dim)
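
        # For Mamba2 a packed [z, xBC, dt] buffer is bound as the 'delta' (and
        # 'z') input below; its xBC slice is random filler because x, B and C
        # are passed through the separate 'input'/'BC' bindings, so presumably
        # only the z and dt slices of this buffer are read by the plugin.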
        xBC = torch.concat([x, BC], dim=-1).contiguous()
        if has_z:
            zxBCdt = torch.concat([z, torch.randn_like(xBC), dt],
                                  dim=-1).contiguous()
        else:
            zxBCdt = torch.concat([torch.randn_like(xBC), dt],
                                  dim=-1).contiguous()
        output = torch.zeros(x.shape,
                             device=device,
                             dtype=str_dtype_to_torch(dtype))

        state_ref = state.detach().clone()
        x_ref = x.detach().clone()
        dt_ref = dt.detach().clone()
        dt_bias_ref = dt_bias.detach().clone()
        A_ref = A.detach().clone()
        B_ref = BC[..., 0:ngroups * dstate].detach().clone()
        C_ref = BC[..., ngroups * dstate:].detach().clone()
        D_ref = D.detach().clone()
        z_ref = z.detach().clone() if has_z else None
        # construct trt network
        builder = tensorrt_llm.Builder()
        net = builder.create_network()
        if remove_padding:
            net.plugin_config.remove_input_padding = True
        else:
            net.plugin_config.remove_input_padding = False
        net.plugin_config.paged_state = False
        with tensorrt_llm.net_guard(net):
            x_tensor = Tensor(name='input',
                              shape=xBC.shape,
                              dtype=tensorrt_llm.str_dtype_to_trt(dtype))
            state_tensor = Tensor(name='state',
                                  shape=state.shape,
                                  dtype=tensorrt_llm.str_dtype_to_trt(dtype))
            dt_tensor = Tensor(name='delta',
                               shape=zxBCdt.shape,
                               dtype=tensorrt_llm.str_dtype_to_trt(dtype))
            dt_bias_tensor = Tensor(
                name='delta_bias',
                shape=dt_bias.shape,
                dtype=tensorrt_llm.str_dtype_to_trt('float32'))
            A_tensor = Tensor(name='A',
                              shape=A.shape,
                              dtype=tensorrt_llm.str_dtype_to_trt('float32'))
            BC_tensor = Tensor(name='BC',
                               shape=xBC.shape,
                               dtype=tensorrt_llm.str_dtype_to_trt(dtype))
            D_tensor = Tensor(name='D',
                              shape=D.shape,
                              dtype=tensorrt_llm.str_dtype_to_trt('float32'))
            if has_z:
                z_tensor = Tensor(name='z',
                                  shape=zxBCdt.shape,
                                  dtype=tensorrt_llm.str_dtype_to_trt(dtype))
            host_request_types_tensor = Tensor(
                name='host_request_types',
                shape=host_request_types.shape,
                dtype=tensorrt_llm.str_dtype_to_trt('int32'))
            last_token_ids_tensor = Tensor(
                name='last_token_ids',
                shape=last_token_ids.shape,
                dtype=tensorrt_llm.str_dtype_to_trt('int32'))
            host_context_lengths_tensor = None
            if remove_padding:
                host_context_lengths_tensor = Tensor(
                    name='host_context_lengths',
                    shape=host_context_lengths.shape,
                    dtype=tensorrt_llm.str_dtype_to_trt('int32'))
            outputs = tensorrt_llm.functional.selective_scan(
                x_tensor,
                state_tensor,
                dt_tensor,
                dt_bias_tensor,
                A_tensor,
                BC_tensor,
                D_tensor,
                host_request_types_tensor,
                last_token_ids_tensor,
                dim,
                dstate,
                0,
                delta_softplus,
                dtype,
                z=z_tensor if has_z else None,
                host_context_lengths=host_context_lengths_tensor,
                nheads=nheads,
                ngroups=ngroups,
                chunk_size=chunk_size,
                mamba_version='Mamba2')
            net._mark_output(outputs[0],
                             'output',
                             dtype=tensorrt_llm.str_dtype_to_trt(dtype))
            net._mark_output(outputs[1],
                             'present_state',
                             dtype=tensorrt_llm.str_dtype_to_trt(dtype))
        # trt run
        inputs = {
            'input': xBC,
            'state': state,
            'delta': zxBCdt,
            'delta_bias': dt_bias,
            'A': A,
            'BC': xBC,
            'D': D,
            'host_request_types': host_request_types,
            'last_token_ids': last_token_ids
        }
        if remove_padding:
            inputs['host_context_lengths'] = host_context_lengths
        if has_z:
            inputs['z'] = zxBCdt
        outputs = {'output': output, 'present_state': state}
        stream = torch.cuda.current_stream()
        builder_config = builder.create_builder_config(precision=dtype, )
        engine = builder.build_engine(net, builder_config)
        session = tensorrt_llm.runtime.Session.from_serialized_engine(engine)
        session.run(inputs=inputs, outputs=outputs, stream=stream.cuda_stream)

        out_ref = torch.zeros(output.shape,
                              device=device,
                              dtype=str_dtype_to_torch(dtype))
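
        # Reference: the context phase uses the chunked scan reference, run per
        # sequence when padding is removed (variable lengths) or in the long
        # sequence cases (to bound memory); generation uses the single-token
        # state-update reference.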
        # pytorch run
        if req_type == 'context':
            if remove_padding:
                for i in range(batch_size):
                    start = 0 if i == 0 else last_token_ids[i - 1]
                    end = last_token_ids[i]
                    x_reshaped = rearrange(x_ref[start:end].unsqueeze(0),
                                           "b l (h p) -> b l h p",
                                           p=headdim)
                    B_ref_reshaped = rearrange(B_ref[start:end].unsqueeze(0),
                                               "b l (g n) -> b l g n",
                                               g=ngroups)
                    C_ref_reshaped = rearrange(C_ref[start:end].unsqueeze(0),
                                               "b l (g n) -> b l g n",
                                               g=ngroups)
                    z_ref_reshaped = rearrange(z_ref[start:end].unsqueeze(0),
                                               "b l (h p) -> b l h p",
                                               p=headdim) if has_z else None
                    part_out_ref, part_state_ref = ssd_chunk_scan_combined_ref(
                        x_reshaped,
                        dt_ref[start:end].unsqueeze(0),
                        A_ref,
                        B_ref_reshaped,
                        C_ref_reshaped,
                        chunk_size,
                        D=D_ref,
                        z=z_ref_reshaped,
                        dt_bias=dt_bias_ref,
                        dt_softplus=delta_softplus)
                    part_out_ref = rearrange(part_out_ref,
                                             "b l h p -> b l (h p)")
                    out_ref[start:end, ] = part_out_ref.squeeze(0)
                    state_ref[i, ] = part_state_ref.squeeze(0)
            elif long_context:
                # to save memory
                for i in range(batch_size):
                    x_reshaped = rearrange(x_ref[i:i + 1, ],
                                           "b l (h p) -> b l h p",
                                           p=headdim)
                    B_ref_reshaped = rearrange(B_ref[i:i + 1, ],
                                               "b l (g n) -> b l g n",
                                               g=ngroups)
                    C_ref_reshaped = rearrange(C_ref[i:i + 1, ],
                                               "b l (g n) -> b l g n",
                                               g=ngroups)
                    z_ref_reshaped = rearrange(z_ref[i:i + 1, ],
                                               "b l (h p) -> b l h p",
                                               p=headdim) if has_z else None
                    part_out_ref, part_state_ref = ssd_chunk_scan_combined_ref(
                        x_reshaped,
                        dt_ref[i:i + 1, ],
                        A_ref,
                        B_ref_reshaped,
                        C_ref_reshaped,
                        chunk_size,
                        D=D_ref,
                        z=z_ref_reshaped,
                        dt_bias=dt_bias_ref,
                        dt_softplus=delta_softplus)
                    part_out_ref = rearrange(part_out_ref,
                                             "b l h p -> b l (h p)")
                    out_ref[i, ] = part_out_ref.squeeze(0)
                    state_ref[i, ] = part_state_ref.squeeze(0)
            else:
                x_reshaped = rearrange(x_ref, "b l (h p) -> b l h p", p=headdim)
                B_ref_reshaped = rearrange(B_ref,
                                           "b l (g n) -> b l g n",
                                           g=ngroups)
                C_ref_reshaped = rearrange(C_ref,
                                           "b l (g n) -> b l g n",
                                           g=ngroups)
                z_ref_reshaped = rearrange(
                    z_ref, "b l (h p) -> b l h p", p=headdim) if has_z else None
                out_ref, state_ref = ssd_chunk_scan_combined_ref(
                    x_reshaped,
                    dt_ref,
                    A_ref,
                    B_ref_reshaped,
                    C_ref_reshaped,
                    chunk_size,
                    D=D_ref,
                    z=z_ref_reshaped,
                    dt_bias=dt_bias_ref,
                    dt_softplus=delta_softplus)
                out_ref = rearrange(out_ref, "b l h p -> b l (h p)")
        elif req_type == 'generation':
            A_ref = repeat(A_ref, "h -> h n p", p=headdim,
                           n=dstate).to(dtype=torch.float32)
            dt_ref = repeat(dt_ref.squeeze(1), "b h -> b h p", p=headdim)
            dt_bias_ref = repeat(dt_bias_ref, "h -> h p", p=headdim)
            D_ref = repeat(D_ref, "h -> h p", p=headdim)
            B_ref = rearrange(B_ref.squeeze(1), "b (g n) -> b g n", g=ngroups)
            C_ref = rearrange(C_ref.squeeze(1), "b (g n) -> b g n", g=ngroups)
            x_reshaped = rearrange(x_ref.squeeze(1),
                                   "b (h p) -> b h p",
                                   p=headdim)
            if has_z:
                z_ref = rearrange(z_ref.squeeze(1),
                                  "b (h p) -> b h p",
                                  p=headdim)
            out_ref = selective_state_update_ref(state_ref,
                                                 x_reshaped,
                                                 dt_ref,
                                                 A_ref,
                                                 B_ref,
                                                 C_ref,
                                                 D=D_ref,
                                                 z=z_ref,
                                                 dt_bias=dt_bias_ref,
                                                 dt_softplus=delta_softplus)
            out_ref = rearrange(out_ref, "b h p -> b (h p)").unsqueeze(1)
        if long_context:
            dtype_atol = {"float16": 2e-2, "bfloat16": 1e-1}
        else:
            dtype_atol = {"float16": 5e-3, "float32": 2e-3, "bfloat16": 5e-2}

        output_cpu = outputs['output'].to(torch.float32).cpu()
        present_state_cpu = outputs['present_state'].to(torch.float32).cpu()
        np.testing.assert_allclose(out_ref.to(torch.float32).cpu().numpy(),
                                   output_cpu.numpy(),
                                   atol=dtype_atol[dtype])
        np.testing.assert_allclose(state_ref.to(torch.float32).cpu().numpy(),
                                   present_state_cpu.numpy(),
                                   atol=dtype_atol[dtype])