# SPDX-FileCopyrightText: Copyright (c) 2022-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
from itertools import product

import numpy as np
import pytest
import tensorrt as trt
import torch
from functional.torch_ref import attention_qkvpacked_ref
from parameterized import parameterized
from polygraphy.backend.trt import (CreateConfig, EngineFromNetwork, Profile,
                                    TrtRunner)
from transformers.models.bloom.modeling_bloom import build_alibi_tensor
from transformers.models.llama.modeling_llama import (LlamaConfig, LlamaMLP,
                                                      LlamaRMSNorm)
from utils.util import getSMVersion

import tensorrt_llm
from tensorrt_llm import Tensor
from tensorrt_llm._utils import str_dtype_to_torch, torch_to_numpy
from tensorrt_llm.layers import (AttentionParams, KeyValueCacheParams,
                                 PositionEmbeddingType)
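
# Every test below follows the same pattern: build a one-layer TensorRT
# network with tensorrt_llm inside net_guard(), run the engine (through
# Polygraphy's TrtRunner, or a tensorrt_llm.runtime.Session for the
# plugin-based attention test), and compare the result against a PyTorch
# reference with np.testing.assert_allclose.
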
class TestLayer(unittest.TestCase):

    def setUp(self):
        tensorrt_llm.logger.set_level('error')

    def test_group_norm_float32(self):
        # test data
        dtype = 'float32'
        x_data = torch.randn(2, 6, 3, 3)
        m = torch.nn.GroupNorm(3, 6)

        # construct trt network
        builder = tensorrt_llm.Builder()
        net = builder.create_network()
        with tensorrt_llm.net_guard(net):
            network = tensorrt_llm.default_trtnet()
            x = Tensor(name='x',
                       shape=x_data.shape,
                       dtype=tensorrt_llm.str_dtype_to_trt(dtype))

            gm = tensorrt_llm.layers.GroupNorm(3, 6)
            gm.weight.value = m.weight.detach().cpu().numpy()
            gm.bias.value = m.bias.detach().cpu().numpy()

            output = gm.forward(x).trt_tensor
            output.name = 'output'
            network.mark_output(output)

        # trt run
        build_engine = EngineFromNetwork((builder.trt_builder, net.trt_network))
        with TrtRunner(build_engine) as runner:
            outputs = runner.infer(feed_dict={'x': x_data.numpy()})

        # pytorch run
        with torch.no_grad():
            ref = m(x_data)

        # compare diff
        np.testing.assert_allclose(ref.cpu().numpy(),
                                   outputs['output'],
                                   atol=1e-6)

    def test_layer_norm_float32(self):
        # test data
        dtype = 'float32'
        x_data = torch.randn(2, 5, 10, 10)
        m = torch.nn.LayerNorm([5, 10, 10])

        # construct trt network
        builder = tensorrt_llm.Builder()
        net = builder.create_network()
        with tensorrt_llm.net_guard(net):
            network = tensorrt_llm.default_trtnet()
            x = Tensor(name='x',
                       shape=x_data.shape,
                       dtype=tensorrt_llm.str_dtype_to_trt(dtype))

            gm = tensorrt_llm.layers.LayerNorm([5, 10, 10])
            gm.weight.value = m.weight.detach().cpu().numpy()
            gm.bias.value = m.bias.detach().cpu().numpy()

            output = gm.forward(x).trt_tensor
            output.name = 'output'
            network.mark_output(output)

        # trt run
        build_engine = EngineFromNetwork((builder.trt_builder, net.trt_network))
        with TrtRunner(build_engine) as runner:
            outputs = runner.infer(feed_dict={'x': x_data.numpy()})

        # pytorch run
        with torch.no_grad():
            ref = m(x_data)

        # compare diff
        np.testing.assert_allclose(ref.cpu().numpy(),
                                   outputs['output'],
                                   atol=1e-6)

    def test_rms_norm_float32(self):
        # test data
        test_shape = [2, 5, 10, 16]
        dtype = 'float32'
        x_data = torch.randn(*test_shape)
        m = LlamaRMSNorm(test_shape[-1])  # LlamaRMSNorm only supports the last dim
        with torch.no_grad():
            m.weight.copy_(torch.rand([test_shape[-1]]))

        # construct trt network
        builder = tensorrt_llm.Builder()
        net = builder.create_network()
        with tensorrt_llm.net_guard(net):
            network = tensorrt_llm.default_trtnet()
            x = Tensor(name='x',
                       shape=x_data.shape,
                       dtype=tensorrt_llm.str_dtype_to_trt(dtype))

            gm = tensorrt_llm.layers.RmsNorm(test_shape[-1])
            gm.weight.value = m.weight.detach().cpu().numpy()

            output = gm.forward(x).trt_tensor
            output.name = 'output'
            network.mark_output(output)

        # trt run
        build_engine = EngineFromNetwork((builder.trt_builder, net.trt_network))
        with TrtRunner(build_engine) as runner:
            outputs = runner.infer(feed_dict={'x': x_data.numpy()})

        # pytorch run
        with torch.no_grad():
            ref = m(x_data)

        # compare diff
        np.testing.assert_allclose(ref.cpu().numpy(),
                                   outputs['output'],
                                   atol=1e-6)

    def test_gated_mlp_float32(self):
        # test data
        d_h = 8
        ffn_h = 20
        test_shape = [2, 3, 5, d_h]
        dtype = 'float32'
        torch.random.manual_seed(0)
        # need rand for 'normalized' values
        x_data = torch.randn(*test_shape)
        fc = torch.empty(ffn_h, d_h)
        torch.nn.init.xavier_uniform_(fc)
        gate = torch.empty(ffn_h, d_h)
        torch.nn.init.xavier_uniform_(gate)
        proj = torch.empty(d_h, ffn_h)
        torch.nn.init.xavier_uniform_(proj)

        config = LlamaConfig(hidden_size=d_h,
                             intermediate_size=ffn_h,
                             hidden_act='silu')
        m = LlamaMLP(config)
        # Use torch.no_grad() to update the weights of the torch.nn.Linear
        # modules in place
        with torch.no_grad():
            m.gate_proj.weight.copy_(fc)
            m.up_proj.weight.copy_(gate)
            m.down_proj.weight.copy_(proj)

        # construct trt network
        builder = tensorrt_llm.Builder()
        net = builder.create_network()
        with tensorrt_llm.net_guard(net):
            network = tensorrt_llm.default_trtnet()
            x = Tensor(name='x',
                       shape=x_data.shape,
                       dtype=tensorrt_llm.str_dtype_to_trt(dtype))

            gm = tensorrt_llm.layers.GatedMLP(d_h,
                                              ffn_h,
                                              hidden_act='silu',
                                              bias=False)
            # TensorRT-LLM's Linear uses the Parameter class, which has a
            # 'value' setter
            gm.fc.weight.value = fc.cpu().numpy()
            gm.gate.weight.value = gate.cpu().numpy()
            gm.proj.weight.value = proj.cpu().numpy()

            output = gm.forward(x).trt_tensor
            output.name = 'output'
            network.mark_output(output)

        # trt run
        build_engine = EngineFromNetwork((builder.trt_builder, net.trt_network))
        with TrtRunner(build_engine) as runner:
            outputs = runner.infer(feed_dict={'x': x_data.numpy()})

        # pytorch run
        with torch.no_grad():
            ref = m(x_data)

        # compare diff
        np.testing.assert_allclose(ref.cpu().numpy(),
                                   outputs['output'],
                                   atol=1e-5)
@parameterized.expand([["float32", False], ["float32", True],
["bfloat16", False], ["bfloat16", True]])
def test_linear(self, dtype, use_plugin):
# Skip tests that are not supported on V100
if getSMVersion() < 80:
if dtype == 'bfloat16':
pytest.skip(
"bfloat16 is not supported in pre-ampere architecture")
# test data
torch.manual_seed(0)
x_data = torch.randn(128, 20, dtype=str_dtype_to_torch(dtype))
m = torch.nn.Linear(20, 30, dtype=str_dtype_to_torch(dtype))
# construct trt network
builder = tensorrt_llm.Builder()
net = builder.create_network()
if use_plugin:
net.plugin_config.set_gemm_plugin(dtype)
with tensorrt_llm.net_guard(net):
network = tensorrt_llm.default_trtnet()
x = Tensor(name='x',
shape=x_data.shape,
dtype=tensorrt_llm.str_dtype_to_trt(dtype))
gm = tensorrt_llm.layers.Linear(20, 30, dtype=dtype)
gm.weight.value = torch_to_numpy(m.weight.detach().cpu())
gm.bias.value = torch_to_numpy(m.bias.detach().cpu())
output = gm.forward(x).trt_tensor
output.name = 'output'
network.mark_output(output)
# trt run
build_engine = EngineFromNetwork(
(builder.trt_builder, net.trt_network),
CreateConfig(bf16=dtype == "bfloat16",
precision_constraints="obey"))
with TrtRunner(build_engine) as runner:
outputs = runner.infer(feed_dict={'x': x_data})
# pytorch run
with torch.no_grad():
ref = m(x_data)
# The absolute tolerance for bfloat16 is increased marginally because
# a single value (out of 4000) breaks tolerance on a 4090 linux/windows.
atols = {"float32": 1e-6, "bfloat16": 1.03 * 1e-2}
# compare diff
np.testing.assert_allclose(ref.to(torch.float32).cpu().numpy(),
outputs['output'].to(torch.float32).numpy(),
atol=atols[dtype])

    @parameterized.expand(list(product([True, False], [True, False])))
    @pytest.mark.skipif(
        getSMVersion() < 80,
        reason="bfloat16 is not supported in pre-ampere architecture"
    )  # Skip tests that are not supported in pre-ampere architecture
    def test_prompt_tuning_embedding(self, enable_lookup_plugin,
                                     remove_padding):
        torch.random.manual_seed(0)

        dtype = "bfloat16"
        trt_dtype = tensorrt_llm.str_dtype_to_trt(dtype)
        torch_dtype = str_dtype_to_torch(dtype)

        embedding_dim = 64
        batch_size = 8
        seq_len = 12
        vocab_size = 100
        num_embeddings = 128
        num_tasks = 3
        task_vocab_size = 30

        embeddings = torch.randn((num_embeddings, embedding_dim),
                                 dtype=torch_dtype)
        prompt_embedding = torch.randn(
            (num_tasks * task_vocab_size, embedding_dim), dtype=torch_dtype)

        ids = torch.randint(0,
                            vocab_size, (batch_size, seq_len),
                            dtype=torch.int32)
        request_tasks = torch.randint(0,
                                      num_tasks, (batch_size, ),
                                      dtype=torch.int32)
        request_tasks = request_tasks.unsqueeze(-1).expand(*ids.shape)
        v_ids = torch.randint(vocab_size,
                              vocab_size + task_vocab_size,
                              (batch_size, seq_len),
                              dtype=torch.int32)
        mask = torch.bernoulli(torch.full((batch_size, seq_len),
                                          0.5)).to(torch.int32)
        ids = ids * mask + v_ids * (1 - mask)
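
        # At this point `ids` mixes regular vocabulary tokens (< vocab_size)
        # with virtual prompt-tuning tokens (>= vocab_size); the embedding
        # layer is expected to route the latter to the per-task prompt table.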

        if remove_padding:
            input_ids = ids.flatten().unsqueeze(0)
            request_tasks = request_tasks.flatten().unsqueeze(0)
        else:
            input_ids = ids

        builder = tensorrt_llm.Builder()
        net = builder.create_network()
        if enable_lookup_plugin:
            net.plugin_config.lookup_plugin = dtype
        with tensorrt_llm.net_guard(net):
            ids_tensor = Tensor(name='ids',
                                shape=[1, -1] if remove_padding else [-1, -1],
                                dtype=trt.int32)
            prompt_embedding_tensor = Tensor(name='prompt_embedding',
                                             shape=[-1, embedding_dim],
                                             dtype=trt_dtype)
            request_tasks_tensor = Tensor(name='request_tasks',
                                          shape=[-1, -1],
                                          dtype=trt.int32)
            task_vocab_size_tensor = Tensor(name='task_vocab_size',
                                            shape=(1, ),
                                            dtype=trt.int32)

            embedding = tensorrt_llm.layers.PromptTuningEmbedding(
                num_embeddings, embedding_dim, vocab_size, trt_dtype)
            embedding.weight.value = torch_to_numpy(embeddings.detach().cpu())

            output = embedding(ids_tensor, prompt_embedding_tensor,
                               request_tasks_tensor, task_vocab_size_tensor)
            net._mark_output(output, "output", dtype=trt_dtype)
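
        # Polygraphy's Profile.add takes (name, min, opt, max) shapes for each
        # dynamic input; opt and max are pinned to the actual input shapes here.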
        profile = Profile()
        profile.add("ids", (1, 1), input_ids.shape, input_ids.shape)
        profile.add("prompt_embedding", (1, embedding_dim),
                    prompt_embedding.shape, prompt_embedding.shape)
        profile.add("request_tasks", (1, 1), input_ids.shape, input_ids.shape)

        build_engine = EngineFromNetwork(
            (builder.trt_builder, net.trt_network),
            config=CreateConfig(bf16=(dtype == "bfloat16"),
                                fp16=(dtype == "float16"),
                                precision_constraints="obey",
                                profiles=[profile]))
        assert build_engine is not None

        with TrtRunner(build_engine) as runner:
            output = runner.infer(
                feed_dict={
                    'ids': input_ids,
                    'prompt_embedding': prompt_embedding,
                    'request_tasks': request_tasks,
                    'task_vocab_size': torch.tensor([task_vocab_size],
                                                    dtype=torch.int32),
                })['output']

        output = output.to(torch.float32)
        embeddings = embeddings.to(torch.float32)
        prompt_embedding = prompt_embedding.view(
            (num_tasks, task_vocab_size, embedding_dim)).to(torch.float32)

        # use loops for clarity, even if it's non-optimal
        for b in range(input_ids.shape[0]):
            for s in range(input_ids.shape[1]):
                token_id = input_ids[b][s]
                if token_id < vocab_size:
                    np.testing.assert_allclose(output[b][s],
                                               embeddings[token_id])
                else:
                    offset_token_id = token_id - vocab_size
                    task = request_tasks[b][s]
                    np.testing.assert_allclose(
                        output[b][s], prompt_embedding[task][offset_token_id])

    def test_conv2d_float32(self):
        # test data
        dtype = 'float32'
        x_data = torch.randn(20, 16, 50, 100)
        m = torch.nn.Conv2d(16,
                            33, (3, 5),
                            stride=(2, 1),
                            padding=(4, 2),
                            dilation=(3, 1))

        # construct trt network
        builder = tensorrt_llm.Builder()
        net = builder.create_network()
        with tensorrt_llm.net_guard(net):
            network = tensorrt_llm.default_trtnet()
            x = Tensor(name='x',
                       shape=x_data.shape,
                       dtype=tensorrt_llm.str_dtype_to_trt(dtype))

            gm = tensorrt_llm.layers.Conv2d(16,
                                            33, (3, 5),
                                            stride=(2, 1),
                                            padding=(4, 2),
                                            dilation=(3, 1))
            gm.weight.value = m.weight.detach().cpu().numpy()
            gm.bias.value = m.bias.detach().cpu().numpy()

            output = gm.forward(x).trt_tensor
            output.name = 'output'
            network.mark_output(output)

        # trt run
        build_engine = EngineFromNetwork((builder.trt_builder, net.trt_network))
        with TrtRunner(build_engine) as runner:
            outputs = runner.infer(feed_dict={'x': x_data.numpy()})

        # pytorch run
        with torch.no_grad():
            ref = m(x_data)

        # compare diff
        np.testing.assert_allclose(ref.cpu().numpy(),
                                   outputs['output'],
                                   atol=1e-5)

    def test_conv_transpose2d_float32(self):
        # test data
        dtype = 'float32'
        x_data = torch.randn(20, 16, 50, 100)
        m = torch.nn.ConvTranspose2d(16,
                                     33, (3, 5),
                                     stride=(2, 1),
                                     padding=(4, 2))

        # construct trt network
        builder = tensorrt_llm.Builder()
        net = builder.create_network()
        with tensorrt_llm.net_guard(net):
            network = tensorrt_llm.default_trtnet()
            x = Tensor(name='x',
                       shape=x_data.shape,
                       dtype=tensorrt_llm.str_dtype_to_trt(dtype))

            # keep the layer arguments in sync with the torch reference above
            # (the reference uses the default dilation)
            gm = tensorrt_llm.layers.ConvTranspose2d(16,
                                                     33, (3, 5),
                                                     stride=(2, 1),
                                                     padding=(4, 2))
            gm.weight.value = m.weight.detach().cpu().numpy()
            gm.bias.value = m.bias.detach().cpu().numpy()

            output = gm.forward(x).trt_tensor
            output.name = 'output'
            network.mark_output(output)

        # trt run
        build_engine = EngineFromNetwork((builder.trt_builder, net.trt_network))
        with TrtRunner(build_engine) as runner:
            outputs = runner.infer(feed_dict={'x': x_data.numpy()})

        # pytorch run
        with torch.no_grad():
            ref = m(x_data)

        # compare diff
        np.testing.assert_allclose(ref.cpu().numpy(),
                                   outputs['output'],
                                   atol=1e-05)

    def test_avg_pooling_2d_float32(self):
        # test data
        dtype = 'float32'
        x_data = torch.randn(2, 16, 50, 32)
        m = torch.nn.AvgPool2d((3, 2), stride=(2, 1))

        # construct trt network
        builder = tensorrt_llm.Builder()
        net = builder.create_network()
        with tensorrt_llm.net_guard(net):
            network = tensorrt_llm.default_trtnet()
            x = Tensor(name='x',
                       shape=x_data.shape,
                       dtype=tensorrt_llm.str_dtype_to_trt(dtype))

            ap2d = tensorrt_llm.layers.AvgPool2d((3, 2), stride=(2, 1))
            output = ap2d.forward(x).trt_tensor
            output.name = 'output'
            network.mark_output(output)

        # trt run
        build_engine = EngineFromNetwork((builder.trt_builder, net.trt_network))
        with TrtRunner(build_engine) as runner:
            outputs = runner.infer(feed_dict={'x': x_data.numpy()})

        # pytorch run
        with torch.no_grad():
            ref = m(x_data)

        # compare diff
        np.testing.assert_allclose(ref.cpu().numpy(),
                                   outputs['output'],
                                   atol=1e-6)
@parameterized.expand([("bfloat16", "float32"), ("float32", "bfloat16")])
def test_cast_bf16(self, from_dtype, to_dtype):
if getSMVersion() < 80:
pytest.skip("bfloat16 is not supported in pre-ampere architecture")
torch_from_dtype = str_dtype_to_torch(from_dtype)
torch_to_dtype = str_dtype_to_torch(to_dtype)
x_data = torch.randn(2, 2, 3, 6, dtype=torch_from_dtype)
# construct trt network
builder = tensorrt_llm.Builder()
net = builder.create_network()
with tensorrt_llm.net_guard(net):
network = tensorrt_llm.default_trtnet()
x = Tensor(name='x',
shape=x_data.shape,
dtype=tensorrt_llm.str_dtype_to_trt(from_dtype))
cast = tensorrt_llm.layers.Cast(to_dtype)
output = cast.forward(x).trt_tensor
output.name = 'output'
network.mark_output(output)
# trt run
build_engine = EngineFromNetwork(
(builder.trt_builder, net.trt_network),
config=CreateConfig(bf16=True, precision_constraints="obey"))
with TrtRunner(build_engine) as runner:
outputs = runner.infer(feed_dict={'x': x_data})
# pytorch run
ref = x_data.to(torch_to_dtype).to(torch.float32)
# compare diff
np.testing.assert_allclose(ref.cpu().numpy(),
outputs['output'].to(torch.float32),
atol=0)

    def test_cast(self):
        dtype = 'float16'
        x_data = torch.randn(2, 2, 3, 6, dtype=torch.float16)

        # construct trt network
        builder = tensorrt_llm.Builder()
        net = builder.create_network()
        with tensorrt_llm.net_guard(net):
            network = tensorrt_llm.default_trtnet()
            x = Tensor(name='x',
                       shape=x_data.shape,
                       dtype=tensorrt_llm.str_dtype_to_trt(dtype))

            cast = tensorrt_llm.layers.Cast('float32')
            output = cast.forward(x).trt_tensor
            output.name = 'output'
            network.mark_output(output)

        # trt run
        build_engine = EngineFromNetwork((builder.trt_builder, net.trt_network))
        with TrtRunner(build_engine) as runner:
            outputs = runner.infer(feed_dict={'x': x_data.numpy()})

        # pytorch run
        ref = x_data.to(torch.float32)

        # compare diff
        np.testing.assert_allclose(ref.cpu().numpy(),
                                   outputs['output'],
                                   atol=1e-6)

    def test_mish(self):
        # test data
        dtype = 'float32'
        x_data = torch.randn(2, 2, 3, 6)
        m = torch.nn.Mish()

        # construct trt network
        builder = tensorrt_llm.Builder()
        net = builder.create_network()
        with tensorrt_llm.net_guard(net):
            network = tensorrt_llm.default_trtnet()
            x = Tensor(name='x',
                       shape=x_data.shape,
                       dtype=tensorrt_llm.str_dtype_to_trt(dtype))

            mish = tensorrt_llm.layers.Mish()
            output = mish.forward(x).trt_tensor
            output.name = 'output'
            network.mark_output(output)

        # trt run
        build_engine = EngineFromNetwork((builder.trt_builder, net.trt_network))
        with TrtRunner(build_engine) as runner:
            outputs = runner.infer(feed_dict={'x': x_data.numpy()})

        # pytorch run
        with torch.no_grad():
            ref = m(x_data)

        # compare diff
        np.testing.assert_allclose(ref.cpu().numpy(),
                                   outputs['output'],
                                   atol=1e-6)

    @parameterized.expand([
        (12, 512, 16, 64, 'float16', PositionEmbeddingType.alibi, False),
        (128, 128, 12, 32, 'float16', PositionEmbeddingType.alibi, True),
        (1, 200, 8, 128, 'float32', PositionEmbeddingType.alibi, False),
        (48, 30, 24, 80, 'float32', PositionEmbeddingType.alibi, True),
        (2, 128, 4, 64, 'float16', PositionEmbeddingType.learned_absolute, True,
         True),
        (2, 128, 4, 64, 'float32', PositionEmbeddingType.learned_absolute, True,
         True),
    ])
    def test_attention(self,
                       batch_size,
                       seq_len,
                       head_num,
                       head_size,
                       dtype,
                       pos_emb_type,
                       causal_mask,
                       use_plugin=False):
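        # Context-phase attention, with and without the gpt_attention_plugin,
        # checked against a packed-QKV torch reference (alibi or
        # learned-absolute position embeddings, padding or causal mask).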
        hidden_size = head_num * head_size
        torch_dtype = str_dtype_to_torch(dtype)

        mean = 0.0
        std_dev = 0.02 if dtype == "float32" else 0.005
        hidden_states = torch.empty(size=[batch_size, seq_len, hidden_size],
                                    dtype=torch_dtype,
                                    device='cuda')
        hidden_states.normal_(mean, std_dev)

        # TODO: can change to random lengths once the torch ref supports the
        # non-padded format; for now every sequence is full-length.
        context_lengths = torch.full([batch_size],
                                     seq_len,
                                     dtype=torch.int32,
                                     device='cuda')

        if use_plugin:
            # Only generate 1 step
            max_seq_len = seq_len + 1
            # A past length of zero means every request is in the context
            # phase with no previously cached tokens.
            host_past_key_value_lengths = torch.tensor([0] * batch_size,
                                                       dtype=torch.int32)
            sequence_length = torch.full([batch_size],
                                         seq_len,
                                         dtype=torch.int32,
                                         device='cuda')

            # Even in the context phase, the KV cache tensors can not be empty
            # tensors for the plugin, otherwise there will be a cublas
            # execution error; the actual shape info is passed to the plugin
            # by the `sequence_length` tensor.
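            # KV cache layout: [batch, 2 (key then value), head_num,
            # max_seq_len, head_size].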
            kv_shape = (batch_size, 2, head_num, max_seq_len, head_size)
            past_key_value = torch.randn(kv_shape,
                                         dtype=torch_dtype,
                                         device='cuda')

            cache_indirection = torch.full((batch_size, 1, max_seq_len),
                                           0,
                                           dtype=torch.int32,
                                           device='cuda')
            host_request_types = torch.tensor([0] * batch_size,
                                              dtype=torch.int32,
                                              device='cpu')

        q_weight = torch.empty(size=[hidden_size, hidden_size],
                               dtype=torch_dtype)
        torch.nn.init.xavier_uniform_(q_weight)

        # The initialization here is chosen to minimize computation after the
        # QKV BMMs in order to reduce the amount of differences from FP
        # accumulation. We set the K and V weights to the identity matrix so
        # that the input is copied without doing any accumulation.
        # Additionally, we set the output projection to the identity for the
        # same reason. The main purpose of these tests is to check the
        # QK^T BMM + Softmax + SV BMM.
        eye_weight = torch.eye(hidden_size, dtype=torch_dtype)
        qkv_weight = torch.cat([q_weight, eye_weight, eye_weight], dim=-1)
        out_weight = eye_weight

        # construct trt network
        builder = tensorrt_llm.Builder()
        net = builder.create_network()
        if use_plugin:
            net.plugin_config.gpt_attention_plugin = dtype
        with tensorrt_llm.net_guard(net):
            trt_hidden_states = Tensor(
                name='hidden_states',
                shape=hidden_states.shape,
                dtype=tensorrt_llm.str_dtype_to_trt(dtype))
            context_lengths_tensor = Tensor(
                name='context_lengths',
                shape=context_lengths.shape,
                dtype=tensorrt_llm.str_dtype_to_trt('int32'))
            if use_plugin:
                host_request_types_tensor = Tensor(
                    name='host_request_types',
                    shape=host_request_types.shape,
                    dtype=tensorrt_llm.str_dtype_to_trt('int32'))
                past_key_value_tensor = Tensor(
                    name='past_key_value',
                    shape=tuple(past_key_value.shape),
                    dtype=tensorrt_llm.str_dtype_to_trt(dtype))
                sequence_length_tensor = Tensor(
                    name='sequence_length',
                    shape=tuple(sequence_length.shape),
                    dtype=tensorrt_llm.str_dtype_to_trt('int32'))
                host_past_key_value_lengths_tensor = Tensor(
                    name='host_past_key_value_lengths',
                    shape=tuple(host_past_key_value_lengths.shape),
                    dtype=tensorrt_llm.str_dtype_to_trt('int32'))
                cache_indirection_tensor = Tensor(
                    name='cache_indirection',
                    shape=tuple(cache_indirection.shape),
                    dtype=tensorrt_llm.str_dtype_to_trt('int32'))

            mask_type = tensorrt_llm.layers.AttentionMaskType.padding
            if causal_mask:
                mask_type = tensorrt_llm.layers.AttentionMaskType.causal

            attn_layer = tensorrt_llm.layers.Attention(
                hidden_size,
                head_num,
                max_position_embeddings=seq_len,
                attention_mask_type=mask_type,
                position_embedding_type=pos_emb_type,
                bias=False)

            attn_layer.qkv.weight.value = np.ascontiguousarray(
                qkv_weight.cpu().numpy().transpose([1, 0]))
            attn_layer.dense.weight.value = np.ascontiguousarray(
                out_weight.cpu().numpy().transpose([1, 0]))

            input_tensor = trt_hidden_states
            if use_plugin:
                output, present_key_value = attn_layer(
                    input_tensor,
                    use_cache=True,
                    kv_cache_params=KeyValueCacheParams(
                        past_key_value=[past_key_value_tensor],
                        host_past_key_value_lengths=
                        host_past_key_value_lengths_tensor,
                        cache_indirection=cache_indirection_tensor),
                    attention_params=AttentionParams(
                        sequence_length=sequence_length_tensor,
                        context_lengths=context_lengths_tensor,
                        host_request_types=host_request_types_tensor,
                        max_context_length=seq_len))
                assert isinstance(output, Tensor)
                present_key_value.mark_output(
                    'present_key_value', tensorrt_llm.str_dtype_to_trt(dtype))
            else:
                output = attn_layer(input_tensor)

            output.mark_output('output', tensorrt_llm.str_dtype_to_trt(dtype))

        builder_config = builder.create_builder_config(name='attention',
                                                       precision=dtype)
        # Build engine
        engine_buffer = builder.build_engine(net, builder_config)
        session = tensorrt_llm.runtime.Session.from_serialized_engine(
            engine_buffer)
        stream = torch.cuda.current_stream().cuda_stream

        if use_plugin:
            inputs = {
                'hidden_states': hidden_states,
                'past_key_value': past_key_value,
                'sequence_length': sequence_length,
                'host_past_key_value_lengths': host_past_key_value_lengths,
                'context_lengths': context_lengths,
                'host_request_types': host_request_types,
                'cache_indirection': cache_indirection,
            }
            outputs = {
                'output':
                torch.empty(hidden_states.shape,
                            dtype=tensorrt_llm._utils.str_dtype_to_torch(dtype),
                            device='cuda'),
                'present_key_value': past_key_value,
            }
        else:
            inputs = {
                'hidden_states': hidden_states,
                'context_lengths': context_lengths,
            }
            outputs = {
                'output':
                torch.empty(hidden_states.shape,
                            dtype=tensorrt_llm._utils.str_dtype_to_torch(dtype),
                            device='cuda'),
            }

        session.run(inputs=inputs, outputs=outputs, stream=stream)
        torch.cuda.synchronize()

        packed_torch_qkv = hidden_states.to("cuda") @ qkv_weight.to("cuda")
        packed_torch_qkv = packed_torch_qkv.reshape(
            [batch_size, seq_len, 3, head_num, head_size])

        alibi_bias = None
        if pos_emb_type == PositionEmbeddingType.alibi:
            mask = torch.ones(size=[batch_size, seq_len], device="cuda")
            alibi_bias = build_alibi_tensor(mask, head_num, torch.float32)
            alibi_bias = alibi_bias.reshape([batch_size, head_num, 1, seq_len])
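
        # The bias (shape [batch, head_num, 1, seq_len]) is added to the
        # attention scores in the torch reference, broadcasting over the
        # query dimension.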
        mha_out, _ = attention_qkvpacked_ref(packed_torch_qkv,
                                             causal=causal_mask,
                                             upcast=False,
                                             bias=alibi_bias)
        torch_out = mha_out.reshape([batch_size, seq_len, hidden_size])
        trt_output = outputs['output']

        a_tol = 5e-5 if (dtype == "float32" and not use_plugin) else 2e-3
        np.testing.assert_allclose(torch_out.cpu().numpy(),
                                   trt_output.cpu().numpy(),
                                   atol=a_tol,
                                   verbose=True)


if __name__ == '__main__':
    unittest.main()