# Mirror of https://github.com/NVIDIA/TensorRT-LLM.git
# (synced 2026-01-14 06:27:45 +08:00)
# * Update TensorRT-LLM
# Co-authored-by: Timur Abishev <abishev.timur@gmail.com>
# Co-authored-by: MahmoudAshraf97 <hassouna97.ma@gmail.com>
# Co-authored-by: Saeyoon Oh <saeyoon.oh@furiosa.ai>
# Co-authored-by: hattizai <hattizai@gmail.com>
# 245 lines, 11 KiB, Python
# SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
|
|
import os
|
|
import sys
|
|
import unittest
|
|
from itertools import product
|
|
|
|
import numpy as np
|
|
import pytest
|
|
import torch
|
|
from parameterized import parameterized
|
|
|
|
import tensorrt_llm
|
|
from tensorrt_llm import Tensor
|
|
from tensorrt_llm._utils import str_dtype_to_torch
|
|
|
|
sys.path.append(os.path.join(os.path.dirname(__file__), '..'))
|
|
from torch_ref import mamba_conv1d_ref
|
|
from utils.util import skip_bf16_pre_ampere, unittest_name_func
|
|
|
|
|
|
class TestFunctional(unittest.TestCase):
    """Functional tests for ``tensorrt_llm.functional.mamba_conv1d``."""

    def setUp(self):
        # Only surface TensorRT-LLM errors while the tests run.
        tensorrt_llm.logger.set_level('error')

    @parameterized.expand(
        # Argument order: dim, dconv, req_type, dtype, batch_size,
        # max_seq_len, stride_size, remove_padding, apply_silu.
        list(
            product([2048], [4], ['context', 'generation'],
                    ['float16', 'float32', 'bfloat16'], [5], [16], [0, 64],
                    [False, True], [False, True])) +
        # long sequence tests to cover the int overflow issue
        list(
            product([5376], [4], ['context'], ['float16', 'bfloat16'], [2],
                    [131072], [10240], [False, True], [False, True])),
        name_func=unittest_name_func)
    def test_mamba_conv1d(self, dim, dconv, req_type, dtype, batch_size,
                          max_seq_len, stride_size, remove_padding, apply_silu):
        """Compare the TRT ``mamba_conv1d`` op with ``mamba_conv1d_ref``.

        A single-op network wrapping ``tensorrt_llm.functional.mamba_conv1d``
        is built and run on seeded random data. Both engine outputs,
        ``output`` and ``present_conv_state``, are then checked against the
        PyTorch reference with a dtype-dependent absolute tolerance.

        Args:
            dim: Size of the feature axis (second axis of ``x``).
            dconv: Convolution width; the conv state keeps ``dconv - 1``
                columns.
            req_type: ``'context'`` uses ``max_seq_len`` tokens and a zero
                initial state; ``'generation'`` uses 1 token and a random
                initial state.
            dtype: Element type name such as ``'float16'``.
            batch_size: Number of sequences.
            max_seq_len: Sequence length of context requests.
            stride_size: Random columns prepended to every input token
                (``pre_stride``); when non-zero, 64 more are appended
                (``post_stride``).
            remove_padding: Enable ``remove_input_padding`` and pass
                ``host_context_lengths`` to the op.
            apply_silu: Forwarded to both the op and the reference.
        """
        # bfloat16 cases are skipped on pre-Ampere GPUs.
        skip_bf16_pre_ampere(dtype)
        if max_seq_len == 131072:
            total_gpu_mem = torch.cuda.get_device_properties(0).total_memory
            if total_gpu_mem <= 33 * 1024**3:
                pytest.skip(
                    "The long sequence test needs at least 33GB memory, skipping"
                )

        device = "cuda"
        torch_dtype = str_dtype_to_torch(dtype)
        is_context = req_type == 'context'
        # Only context requests are packed into a [num_tokens, dim] layout.
        packed_context = remove_padding and is_context
        seq_len = max_seq_len if is_context else 1
        with_stride = stride_size > 0
        pre_stride = stride_size
        post_stride = 64 if with_stride else 0
        mean = 0.0
        std_dev = 1.0 if dtype == "float32" else 0.5

        # ---- test data --------------------------------------------------
        # The order of the random draws below is fixed relative to the seed;
        # reordering them would change the generated data.
        torch.random.manual_seed(0)
        if packed_context:
            # Random per-sequence lengths in [1, seq_len]; sequence 0 always
            # spans the full length.
            last_token_ids = torch.randint(1,
                                           seq_len + 1, (batch_size, ),
                                           dtype=torch.int32)
            last_token_ids[0] = seq_len
        else:
            last_token_ids = torch.full((batch_size, ),
                                        seq_len,
                                        dtype=torch.int32,
                                        device=device)
        host_context_length = last_token_ids.detach().clone().cpu()
        lengths = host_context_length.tolist()
        # Packed input: the op receives the running sum of the lengths.
        last_token_ids_trt = (torch.cumsum(
            last_token_ids, dim=0, dtype=torch.int32).to(device)
                              if packed_context else last_token_ids)

        conv_state_shape = (batch_size, dim, dconv - 1)
        if is_context:
            # Context requests start from an all-zero conv state.
            past_conv_state = torch.zeros(conv_state_shape,
                                          dtype=torch_dtype,
                                          device=device)
        else:
            past_conv_state = torch.randn(conv_state_shape,
                                          dtype=torch_dtype,
                                          device=device)
            past_conv_state.normal_(mean, std_dev)

        conv_weight = torch.randn((dim, 1, dconv),
                                  dtype=torch_dtype,
                                  device=device)
        conv_bias = torch.randn((dim, ), dtype=torch_dtype, device=device)

        # Request type codes: 0 for context, 1 for generation.
        host_request_types = torch.full((batch_size, ),
                                        0 if is_context else 1,
                                        dtype=torch.int32)

        x = torch.empty((batch_size, dim, seq_len),
                        device=device,
                        dtype=torch_dtype)
        x.normal_(mean, std_dev)

        # TRT-side layouts: x -> [batch, seq, dim] (packed: [tokens, dim]),
        # conv state -> [batch, dconv - 1, dim], weight -> [1, dconv, dim].
        x_trt = x.detach().permute(0, 2, 1).contiguous()
        if packed_context:
            x_trt = torch.cat([seq[:n] for seq, n in zip(x_trt, lengths)],
                              dim=0)
        past_conv_state_trt = past_conv_state.permute(0, 2, 1).contiguous()
        conv_weight_trt = conv_weight.permute(1, 2, 0).contiguous()

        # Output buffers take x_trt's shape *before* stride padding is added.
        output_trt = torch.zeros_like(x_trt)
        present_conv_state_trt = torch.zeros_like(past_conv_state_trt)
        if with_stride:
            # Surround each token's features with random filler columns so
            # the last axis becomes pre_stride + dim + post_stride.
            lead_shape = list(x_trt.shape[:-1])
            pad_pre = torch.randn(lead_shape + [pre_stride],
                                  device=device,
                                  dtype=torch_dtype)
            pad_post = torch.randn(lead_shape + [post_stride],
                                   device=device,
                                   dtype=torch_dtype)
            x_trt = torch.cat([pad_pre, x_trt, pad_post], dim=-1).contiguous()

        # ---- build the TRT network ----------------------------------------
        trt_dtype = tensorrt_llm.str_dtype_to_trt(dtype)
        int32_trt = tensorrt_llm.str_dtype_to_trt('int32')
        builder = tensorrt_llm.Builder()
        net = builder.create_network()
        net.plugin_config.remove_input_padding = bool(remove_padding)
        net.plugin_config.paged_state = False
        with tensorrt_llm.net_guard(net):
            x_tensor = Tensor(name='input',
                              shape=x_trt.shape,
                              dtype=trt_dtype)
            conv_weight_tensor = Tensor(name='conv_weight',
                                        shape=conv_weight_trt.shape,
                                        dtype=trt_dtype)
            conv_bias_tensor = Tensor(name='conv_bias',
                                      shape=conv_bias.shape,
                                      dtype=trt_dtype)
            past_conv_state_tensor = Tensor(name='past_conv_state',
                                            shape=past_conv_state_trt.shape,
                                            dtype=trt_dtype)
            host_request_types_tensor = Tensor(name='host_request_types',
                                               shape=host_request_types.shape,
                                               dtype=int32_trt)
            last_token_ids_tensor = Tensor(name='last_token_ids',
                                           shape=last_token_ids.shape,
                                           dtype=int32_trt)
            host_context_length_tensor = (Tensor(
                name='host_context_lengths',
                shape=host_context_length.shape,
                dtype=int32_trt) if remove_padding else None)
            conv_outputs = tensorrt_llm.functional.mamba_conv1d(
                x_tensor,
                past_conv_state_tensor,
                conv_weight_tensor,
                conv_bias_tensor,
                host_request_types_tensor,
                last_token_ids_tensor,
                dim,
                dconv,
                dtype,
                pre_stride=pre_stride,
                post_stride=post_stride,
                host_context_lengths=host_context_length_tensor,
                apply_silu=apply_silu)
            net._mark_output(conv_outputs[0], 'output', dtype=trt_dtype)
            net._mark_output(conv_outputs[1],
                             'present_conv_state',
                             dtype=trt_dtype)

        # ---- run the engine -----------------------------------------------
        feed = {
            'input': x_trt,
            'conv_weight': conv_weight_trt,
            'conv_bias': conv_bias,
            'past_conv_state': past_conv_state_trt,
            'host_request_types': host_request_types,
            'last_token_ids': last_token_ids_trt,
        }
        if remove_padding:
            feed['host_context_lengths'] = host_context_length
        fetch = {
            'output': output_trt,
            'present_conv_state': present_conv_state_trt,
        }
        stream = torch.cuda.current_stream()
        builder_config = builder.create_builder_config(precision=dtype)
        engine = builder.build_engine(net, builder_config)
        session = tensorrt_llm.runtime.Session.from_serialized_engine(engine)
        session.run(inputs=feed, outputs=fetch, stream=stream.cuda_stream)
        torch.cuda.synchronize()

        # ---- reference ------------------------------------------------------
        out_ref = torch.zeros_like(x)
        present_conv_state_ref = torch.zeros_like(past_conv_state)
        for b, n in enumerate(lengths):
            # Only the first n tokens of sequence b are computed; the rest of
            # out_ref stays zero.
            out_ref[b:b + 1, :, :n], present_conv_state_ref[b:b + 1] = (
                mamba_conv1d_ref(x[b:b + 1, :, :n], past_conv_state[b:b + 1],
                                 conv_weight, conv_bias, apply_silu))
        present_conv_state_ref = present_conv_state_ref.permute(0, 2,
                                                                1).contiguous()
        out_ref = out_ref.permute(0, 2, 1).contiguous()
        if packed_context:
            out_ref = torch.cat([seq[:n] for seq, n in zip(out_ref, lengths)],
                                dim=0)

        # ---- compare ----------------------------------------------------------
        atol = {"float16": 1e-2, "float32": 2e-3, "bfloat16": 1e-1}[dtype]

        def to_numpy(t):
            return t.to(torch.float32).cpu().numpy()

        for expected, actual in ((out_ref, output_trt),
                                 (present_conv_state_ref,
                                  present_conv_state_trt)):
            np.testing.assert_allclose(to_numpy(expected),
                                       to_numpy(actual),
                                       atol=atol)
|