# SPDX-FileCopyrightText: Copyright (c) 2022-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest

import numpy as np
import tensorrt as trt
import torch
from transformers import GPT2Config
from transformers.models.gpt2.modeling_gpt2 import GPT2Attention

import tensorrt_llm
import tensorrt_llm as tllm
from tensorrt_llm import Tensor, net_guard
from tensorrt_llm._utils import torch_to_numpy
from tensorrt_llm.functional import PositionEmbeddingType, gpt_attention
from tensorrt_llm.graph_rewriting import (AnalysisPatternManager,
                                          FLayerInfoMemo,
                                          FuseAttentionWithBiasPass, Layer,
                                          PatternAnalyzer, PatternRewriter,
                                          RewritePatternManager)
from tensorrt_llm.layers import Linear
from tensorrt_llm.quantization import QuantMode


# Borrowed from test_gpt_attention.py, with the unnecessary logic pruned.
def create_gpt_attention_network(attention_type='gpt2_attention',
                                 dtype='float32') -> tensorrt_llm.Network:
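    '''Build a single-layer GPT-style attention TRT network (a QKV Linear
    feeding the gpt_attention plugin); the graph-rewriting tests below all
    run against it.
    '''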

    def _construct_execution(
        shape_dict,
        weight,
        bias,
        num_heads,
        hidden_size,
        dtype,
        enable_multi_query_attention,
        in_len,
    ):
        # construct trt network
        builder = tensorrt_llm.Builder()
        net = builder.create_network()
        net.plugin_config.set_gpt_attention_plugin(dtype)
        net.plugin_config.remove_input_padding = True

        head_size = hidden_size // num_heads
        with tensorrt_llm.net_guard(net):
            x_tensor = Tensor(name='input',
                              shape=shape_dict['input'],
                              dtype=tensorrt_llm.str_dtype_to_trt(dtype))
            past_key_value_tensor = Tensor(
                name='past_key_value',
                shape=tuple(shape_dict['past_key_value']),
                dtype=tensorrt_llm.str_dtype_to_trt(dtype))
            sequence_length_tensor = Tensor(
                name='sequence_length',
                shape=shape_dict['sequence_length'],
                dtype=tensorrt_llm.str_dtype_to_trt('int32'))
            host_past_key_value_lengths_tensor = Tensor(
                name='host_past_key_value_lengths',
                shape=shape_dict['host_past_key_value_lengths'],
                dtype=tensorrt_llm.str_dtype_to_trt('int32'))
            host_max_kv_cache_lengths_tensor = Tensor(
                name='host_max_kv_cache_lengths',
                shape=shape_dict['host_max_kv_cache_lengths'],
                dtype=tensorrt_llm.str_dtype_to_trt('int32'))
            context_lengths_tensor = Tensor(
                name='context_lengths',
                shape=shape_dict['context_lengths'],
                dtype=tensorrt_llm.str_dtype_to_trt('int32'))
            host_request_types_tensor = Tensor(
                name='host_request_types',
                shape=shape_dict['host_request_types'],
                dtype=tensorrt_llm.str_dtype_to_trt('int32'))
            host_context_lengths_tensor = None
            if net.plugin_config.remove_input_padding:
                host_context_lengths_tensor = Tensor(
                    name='host_context_lengths',
                    shape=shape_dict['host_context_lengths'],
                    dtype=tensorrt_llm.str_dtype_to_trt('int32'))
            cache_indirection_tensor = Tensor(
                name='cache_indirection',
                shape=shape_dict['cache_indirection'],
                dtype=tensorrt_llm.str_dtype_to_trt('int32'))
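            # A bias-enabled Linear lowers to a matmul plus an elementwise SUM
            # (the bias add); that SUM is the layer the naive rewrite/analysis
            # patterns in the tests below latch onto.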
            linear = Linear(hidden_size,
                            weight.size()[-1],
                            bias=attention_type in [
                                'gpt2_attention', 'llama_attention',
                                'gpt_bigcode_attention'
                            ])
            linear.weight.value = np.ascontiguousarray(
                torch_to_numpy(weight.T.cpu()))
            if attention_type in [
                    'gpt2_attention', 'llama_attention', 'gpt_bigcode_attention'
            ]:
                linear.bias.value = torch_to_numpy(bias.cpu())
            qkv = linear(x_tensor)

            rotary_embedding_dim = head_size if attention_type in [
                'llama_attention', 'gptj_attention'
            ] else 0

            if attention_type == 'llama_attention':
                position_embedding_type = PositionEmbeddingType.rope_gpt_neox
            elif attention_type == 'gptj_attention':
                position_embedding_type = PositionEmbeddingType.rope_gptj
            else:
                position_embedding_type = PositionEmbeddingType.learned_absolute
            outputs = tensorrt_llm.functional.gpt_attention(
                tensor=qkv,
                past_key_value=past_key_value_tensor,
                sequence_length=sequence_length_tensor,
                host_past_key_value_lengths=host_past_key_value_lengths_tensor,
                host_max_kv_cache_lengths=host_max_kv_cache_lengths_tensor,
                context_lengths=context_lengths_tensor,
                cache_indirection=cache_indirection_tensor,
                host_request_types=host_request_types_tensor,
                num_heads=num_heads,
                num_kv_heads=1 if enable_multi_query_attention else num_heads,
                hidden_size_per_head=head_size,
                q_scaling=1.0,
                max_context_length=in_len,
                rotary_embedding_dim=rotary_embedding_dim,
                position_embedding_type=position_embedding_type,
                kv_orig_quant_scale=None,
                kv_quant_orig_scale=None,
                kv_cache_quant_mode=QuantMode.from_description(
                    use_int8_kv_cache=False),
                kv_cache_block_pointers=None,
                host_context_lengths=host_context_lengths_tensor)

            net._mark_output(outputs[0],
                             'output',
                             dtype=tensorrt_llm.str_dtype_to_trt(dtype))
            net._mark_output(outputs[1],
                             'present_key_value',
                             dtype=tensorrt_llm.str_dtype_to_trt('float16'))

        return net

    enable_multi_query_attention = False
    batch_size = 4
    in_len = 128
    max_seq_len = 148
    hidden_size = 1024
    num_heads = 16
    head_size = hidden_size // num_heads
    kv_num_heads = 1 if enable_multi_query_attention and attention_type == 'gpt_bigcode_attention' else num_heads
    qkv_hidden_size = hidden_size + 2 * kv_num_heads * head_size
    shape_dict = {
        'weight': (hidden_size, qkv_hidden_size),
        'bias': (qkv_hidden_size, ),
        'past_key_value': (batch_size, 2, kv_num_heads, max_seq_len, head_size),
        'present_key_value':
        (batch_size, 2, kv_num_heads, max_seq_len, head_size),
        'sequence_length': (batch_size, ),
        'context_lengths': (batch_size, ),
        'host_context_lengths': (batch_size, ),
        'host_request_types': (batch_size, ),
        'max_input_sequence_length': (in_len, ),
        'cache_indirection': (batch_size, 1, max_seq_len),
        'input': (batch_size, in_len, hidden_size),
        'output': (batch_size, in_len, hidden_size),
        'host_past_key_value_lengths': (batch_size, ),
        'host_max_kv_cache_lengths': (1, )
    }

    weight = torch.randn(shape_dict['weight'],
                         dtype=tllm._utils.str_dtype_to_torch(dtype),
                         device='cuda') * 1e-3
    bias = torch.randn(shape_dict['bias'],
                       dtype=tllm._utils.str_dtype_to_torch(dtype),
                       device='cuda') * 1e-2
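
    # The HuggingFace GPT2Attention below gets the same weights as the TRT
    # network. It is kept from the borrowed accuracy test for reference; the
    # rewrite tests in this file never actually run it.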
    configuration = GPT2Config(
        hidden_size=hidden_size,
        n_layer=1,
        n_head=num_heads,
        vocab_size=51200,
        use_cache=True,
        resid_pdrop=0,
        embd_pdrop=0,
        attn_pdrop=0,
        hidden_act='gelu',
        torch_dtype=dtype,
    )
    attention = GPT2Attention(configuration).cuda().eval()

    attention.c_attn.weight = torch.nn.parameter.Parameter(
        data=weight.clone().detach(), requires_grad=False)
    attention.c_attn.bias = torch.nn.parameter.Parameter(
        data=bias.clone().detach(), requires_grad=False)
    attention.c_proj.weight = torch.nn.parameter.Parameter(
        data=torch.eye(hidden_size,
                       dtype=tllm._utils.str_dtype_to_torch(dtype),
                       device='cuda'),
        requires_grad=False)
    attention.c_proj.bias = torch.nn.parameter.Parameter(data=torch.zeros(
        (hidden_size, ),
        dtype=tllm._utils.str_dtype_to_torch(dtype),
        device='cuda'),
                                                         requires_grad=False)

    return _construct_execution(
        shape_dict=shape_dict,
        weight=weight,
        bias=bias,
        num_heads=num_heads,
        hidden_size=hidden_size,
        dtype='float16',
        enable_multi_query_attention=enable_multi_query_attention,
        in_len=in_len)


class TestNetworkForGraphRewrite(unittest.TestCase):
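    '''Sanity checks for the Network/Layer/Tensor inspection helpers that the
    graph rewriter builds on.
    '''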

    def setUp(self) -> None:
        self.network = create_gpt_attention_network('gpt2_attention', 'float32')

    def test_num_inputs(self):
        num_inputs = self.network.trt_network.num_inputs
        self.assertEqual(num_inputs, len(self.network.get_inputs()))

    def test_is_input(self):
        for x in self.network.get_inputs():
            self.assertTrue(self.network.is_input(x))

    def test_is_output(self):
        for x in self.network.get_outputs():
            self.assertTrue(self.network.is_output(x))

    def test_get_layer_by_name(self):
        for layer in self.network.get_layers():
            self.assertEqual(layer, self.network.get_layer_by_name(layer.name))

    def test_get_tensor_parent(self):
        for layer in self.network.get_layers():
            for tensor in layer.get_outputs():
                self.assertEqual(layer, tensor.get_parent())

    def test_get_tensor_users(self):
        for layer in self.network.get_layers():
            for tensor in layer.get_inputs():
                found_this_user = False
                for user in tensor.get_users():
                    if user == layer:
                        found_this_user = True
                        break
                self.assertTrue(found_this_user)

    def test_to_dot(self):
        dot_code = self.network.to_dot()
        self.assertTrue('digraph' in dot_code)


class TestLayer(unittest.TestCase):

    def setUp(self) -> None:
        self.network = create_gpt_attention_network('gpt2_attention', 'float32')

    def test_as_layer(self):
        for layer in self.network.get_layers():
            self.assertTrue(layer.as_layer())


class NaivePatternRewriter_ReplaceAddWithSub(PatternRewriter):
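    '''A toy rewrite pass that replaces every elementwise SUM with an
    elementwise SUB; not useful in itself, but it exercises the complete
    match/rewrite flow.
    '''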

    def __init__(self):
        super().__init__('replace_add_with_sub',
                         root_layer={trt.LayerType.ELEMENTWISE},
                         seperate_match_rewrite=True)
        self.rewrite_count = 0

    def match(self, layer: Layer):
        # The rewriter stops at each matched layer and then enters rewrite()
        # to do the actual rewriting.
        return layer.as_layer().op == trt.ElementWiseOperation.SUM

    def rewrite(self, layer: Layer) -> None:
        # Just a statistic for testing; feel free to add arbitrary logic here.
        self.rewrite_count += 1

        # The layer here should be an elementwise SUM layer.
        with net_guard(layer.network):
            # There are several stages to replacing one subgraph with another:

            # Step 1: Get the input and output tensors of the subgraph to
            # replace.
            # - For elementwise SUM, there are two inputs and one output.
            a, b = layer.get_inputs(0, 1)
            o = layer.get_outputs(0)[0]

            # Step 2: Create a new subgraph that takes the old one's inputs.
            # - Here we insert an elementwise SUB layer, and c is its output.
            c = a - b

            # Step 3: Redirect all the layers depending on the outputs of the
            # old subgraph to the new subgraph's outputs.
            # - After this, the SUM becomes dangling and will be pruned by
            #   TensorRT when building the engine.
            # - Note that there is no API in the TensorRT Python bindings to
            #   remove a layer explicitly; `replace_all_uses_with` is the only
            #   way to "remove" a layer.
            o.replace_all_uses_with(c)

            # Step 4: Mark all the layers in the old subgraph as removed.
            # - This helps the PatternRewriter to skip the removed layers.
            layer.mark_as_removed()


class NaivePatternAnalyzer_CountAdd(PatternAnalyzer):
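    '''A toy analysis pass that counts the elementwise SUM layers without
    mutating the network.
    '''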

    def __init__(self) -> None:
        super().__init__('count_add', root_layer={trt.LayerType.ELEMENTWISE})
        self.count = 0

    def match(self, layer: Layer) -> bool:
        return layer.as_layer().op == trt.ElementWiseOperation.SUM

    def analyze(self, layer: Layer) -> None:
        self.count += 1


class TestPatternRewriter(unittest.TestCase):

    def setUp(self) -> None:
        self.network = create_gpt_attention_network('gpt2_attention', 'float32')

    def test_pattern_rewriter(self):
        patterns = RewritePatternManager()
        patterns.add(
            label='replace_add_to_sub',
            pattern=NaivePatternRewriter_ReplaceAddWithSub(),
        )

        def get_input_key(layer) -> str:
            inputs = layer.get_inputs(0, 1)
            return '-'.join([x.name for x in inputs])
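
        # Snapshot the users of each SUM layer's output, keyed by its input
        # names, so we can verify after the rewrite that the new SUB layers
        # feed exactly the same consumers.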
        add_input_to_users = {}
        for layer in self.network.get_layers():
            trt_layer = layer.as_layer()
            if (trt_layer.type == trt.LayerType.ELEMENTWISE
                    and trt_layer.op == trt.ElementWiseOperation.SUM):
                output_uses = list(layer.get_outputs(0)[0].get_users())
                add_input_to_users[get_input_key(layer)] = output_uses

        patterns.rewrite(self.network)

        # Only one layer hits this pattern: Linear/ELEMENTWISE_SUM_0
        self.assertEqual(patterns.get('replace_add_to_sub').rewrite_count, 1)

        # Check that the output tensors of the removed layers have no users
        # left; the dangling layers will eventually be pruned by TensorRT.
        for layer in self.network.get_layers():
            trt_layer = layer.as_layer()
            if (trt_layer.type == trt.LayerType.ELEMENTWISE
                    and trt_layer.op == trt.ElementWiseOperation.SUM):
                for out in layer.get_outputs():
                    self.assertEqual(len(list(out.get_users())), 0)

        # Check that the add layers are replaced by sub layers with the same
        # inputs and the same output consumers.
        for layer in self.network.get_layers():
            trt_layer = layer.as_layer()
            if (trt_layer.type == trt.LayerType.ELEMENTWISE
                    and trt_layer.op == trt.ElementWiseOperation.SUB):
                output_uses = layer.get_outputs(0)[0].get_users()
                key = get_input_key(layer)
                if key in add_input_to_users:
                    src = [x.name for x in add_input_to_users[key]]
                    dst = [x.name for x in output_uses]
                    self.assertEqual(src, dst)


class TestPatternAnalyzer(unittest.TestCase):

    def setUp(self) -> None:
        self.network = create_gpt_attention_network('gpt2_attention', 'float32')

    def test_pattern_analyzer(self):
        patterns = AnalysisPatternManager()
        patterns.add(
            label='count_add',
            pattern=NaivePatternAnalyzer_CountAdd(),
        )

        patterns.analyze(self.network)
        # Only one layer hits this pattern: ELEMENTWISE_SUM_0
        self.assertEqual(patterns.get('count_add').count, 1)


class GPTAttentionPluginRemovePaddingRewritePass(PatternRewriter):
    '''
    A demo pass that alters the gpt_attention plugin to enable remove_padding
    by replacing its inputs:
    - tensor: shape goes from [batch_size, in_len, hidden_size] to
      [1, batch_size * in_len, hidden_size]
    '''

    def __init__(self):
        super().__init__('gpt_attention_plugin_remove_padding',
                         root_layer={trt.LayerType.PLUGIN_V2})

        self.count = 0

    def match_and_rewrite(self, layer: Layer) -> bool:
        if layer.as_layer().type != trt.LayerType.PLUGIN_V2 or \
                layer.as_layer().plugin.plugin_namespace != 'tensorrt_llm' or \
                layer.as_layer().plugin.plugin_type != 'GPTAttention':
            return False
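
        # FLayerInfoMemo keeps a functional-level record of every plugin
        # layer, including its original high-level inputs, which the lowered
        # TensorRT layer no longer exposes on its own.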
        flayer = FLayerInfoMemo.instance().get(layer.name)
        assert flayer
        tensor_input: Tensor = flayer.get_input('tensor')
        if tensor_input.shape[0] == 1:  # already in remove-padding mode
            return False

        self.log_info(f'hit gpt_attention plugin: {layer.name}')

        assert self.args is not None, "args should be passed in from RewritePatternManager.rewrite()"
        batch_size, in_len, hidden_size = self.args['batch_size'], self.args[
            'in_len'], self.args['hidden_size']

        # record the number of rewrites
        self.count += 1

        new_inputs = flayer.clone_inputs()
        with net_guard(layer.network):
            # Step 1: create new inputs and replace the original arglist
            input = Tensor(
                name='tensor',
                dtype=trt.float16,
                shape=(1, batch_size * in_len, hidden_size),
            )
            new_inputs['tensor'] = input

            # Step 2: create a new plugin instance
            new_outs = gpt_attention(**new_inputs)

            # Step 3: redirect all users of the old plugin's outputs to the
            # new plugin's outputs
            flayer.replace_outputs_uses_with(layer.network, new_outs)

            # Step 4: remove the old plugin instance
            layer.mark_as_removed()

        return True


class TestGPTAttentionPluginRemovePaddingRewritePass(unittest.TestCase):

    def setUp(self) -> None:
        self.network = create_gpt_attention_network('gpt2_attention', 'float32')

    def test_gpt_attention_plugin_remove_padding_rewrite_pass(self):
        with net_guard(self.network):
            for layer in FLayerInfoMemo.instance().data.values():
                print(layer.raw_inputs)
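
        # Register the pass with the singleton manager and run it; the pass
        # reads batch_size/in_len/hidden_size from `args`, matching the
        # constants used by create_gpt_attention_network above.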
        RewritePatternManager.instance().add(
            "gpt_attention remove padding",
            GPTAttentionPluginRemovePaddingRewritePass())

        RewritePatternManager.instance().rewrite(self.network,
                                                 args=dict(batch_size=4,
                                                           in_len=128,
                                                           hidden_size=1024))

        self.assertEqual(
            RewritePatternManager.instance().get(
                "gpt_attention remove padding").count, 1)


class TestFuseAttentionWithBiasPass(unittest.TestCase):
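    '''Checks the built-in FuseAttentionWithBiasPass, which fuses the bias-add
    elementwise layer in front of the GPTAttention plugin into the plugin
    itself; it should match (and rewrite) at least one layer in the network.
    '''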

    def setUp(self) -> None:
        self.network = create_gpt_attention_network('gpt2_attention', 'float32')

    def test_as_layer(self):
        res = False
        rewriter = FuseAttentionWithBiasPass()
        for layer in self.network.get_layers():
            res = res or rewriter.match_and_rewrite(layer)
        self.assertEqual(res, True)


if __name__ == '__main__':
    unittest.main()