TensorRT-LLMs/tensorrt_llm/models/generation_mixin.py
2023-10-15 21:26:20 +08:00

389 lines
16 KiB
Python

# SPDX-FileCopyrightText: Copyright (c) 2022-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from collections import OrderedDict
import tensorrt as trt
from ..functional import Tensor
from ..mapping import Mapping
class GenerationMixin:
    """Mixin with helpers that declare the basic network inputs of a
    generation model: token ids or hidden states, KV cache tensors and
    attention metadata.
    """

    def get_transformer_layers(self, mapping, num_layers):
        """Return the global indices of the layers owned by this pipeline rank.

        The ``num_layers`` layers are split into ``mapping.pp_size``
        contiguous chunks of ``num_layers // mapping.pp_size`` layers, and
        chunk number ``mapping.pp_rank`` is returned. When ``num_layers`` is
        not a multiple of ``pp_size``, the trailing remainder layers are not
        assigned to any rank.

        Args:
            mapping: Parallelism description; only ``pp_size`` and
                ``pp_rank`` are read.
            num_layers: Total number of transformer layers in the model.

        Returns:
            list[int]: Consecutive layer indices for this rank.
        """
        layers_per_pipeline_stage = num_layers // mapping.pp_size
        first_layer = mapping.pp_rank * layers_per_pipeline_stage
        return list(
            range(first_layer, first_layer + layers_per_pipeline_stage))

    def prepare_basic_inputs(self,
                             max_batch_size,
                             max_beam_width,
                             max_input_len,
                             max_new_tokens,
                             num_kv_heads,
                             head_size,
                             num_layers,
                             kv_dtype,
                             remove_input_padding=False,
                             use_gpt_attention_plugin=False,
                             use_gemm_plugin=False,
                             use_custom_all_reduce=False,
                             paged_kv_cache=False,
                             tokens_per_block=64,
                             gather_all_token_logits=False,
                             dtype=None,
                             num_heads=None,
                             mapping=None,
                             max_num_tokens=None):
        """Declare the basic input tensors of a generation network.

        Every dynamic dimension gets one ``[min, opt, max]`` range per
        optimization profile. Two profiles are used (the first built from the
        ``*_ctx`` ranges, the second from the ``*_gen`` ranges) when either
        plugin flag equals ``False``, unless the combination GPT attention
        plugin + ``remove_input_padding`` + ``paged_kv_cache`` is active.
        Otherwise a single profile covers everything.

        Args:
            max_batch_size: Largest batch size the ranges must cover.
            max_beam_width: Largest beam width the ranges must cover.
            max_input_len: Largest input length the ranges must cover.
            max_new_tokens: Added to ``max_input_len`` to form the maximum
                total sequence length.
            num_kv_heads: Total number of KV heads. It is divided (rounded
                up) by ``mapping.tp_size`` for the per-rank cache shape.
            head_size: Size of each attention head.
            num_layers: Total number of transformer layers.
            kv_dtype: Data type of the ``past_key_value_*`` tensors.
            remove_input_padding: Use the packed ``[1, num_tokens]`` token
                layout instead of ``[batch_size_beam_width, input_len]``.
            use_gpt_attention_plugin: Declare ``sequence_length``,
                ``host_request_types``, ``host_past_key_value_lengths`` and
                ``context_lengths`` instead of ``attention_mask``.
            use_gemm_plugin: Only consulted to choose the profile count.
            use_custom_all_reduce: Declare ``all_reduce_workspace`` when
                ``mapping.tp_size > 1``.
            paged_kv_cache: Declare per-layer ``kv_cache_block_pointers_*``
                instead of ``past_key_value_*`` tensors.
            tokens_per_block: Tokens per paged KV cache block.
            gather_all_token_logits: When true, ``last_token_ids`` is not
                declared.
            dtype: Data type of ``hidden_states_input``; required on ranks
                that are not the first pipeline rank.
            num_heads: Number of attention heads; required on ranks that are
                not the first pipeline rank (hidden size is
                ``head_size * num_heads``).
            mapping: Parallelism description. Defaults to a fresh
                ``Mapping()`` per call.
            max_num_tokens: Optional bound for the ``num_tokens`` dimension
                of the packed (``remove_input_padding``) layout.

        Returns:
            dict: Input name -> declared ``Tensor``. Entries that do not apply
            to this configuration are ``None``. ``past_key_value`` and
            ``kv_cache_block_pointers_list`` hold one entry (possibly
            ``None``) per layer owned by this pipeline rank.
        """
        if mapping is None:
            # Build the default per call instead of sharing one instance that
            # was evaluated once at function-definition time.
            mapping = Mapping()

        max_len = max_input_len + max_new_tokens

        # Each range below is a single [min, opt, max] triple.
        bb_range_cxt = [1, (max_batch_size + 1) // 2, max_batch_size]
        bb_range_gen = [
            1, (max_batch_size * max_beam_width + 1) // 2,
            max_batch_size * max_beam_width
        ]
        _bs_range = [1, (max_batch_size + 1) // 2, max_batch_size]
        _beam_width_range = [1, (max_beam_width + 1) // 2, max_beam_width]
        inlen_range_cxt = [1, (max_input_len + 1) // 2, max_input_len]
        inlen_range_gen = [1, 1, 1]
        _mask_len_ctx = [1, (max_input_len + 1) // 2, max_input_len]
        _mask_len_gen = [2, (max_len + 1) // 2 + 1, max_len + 1]
        _kv_cache_range_ctx = [0, 0, 0]
        _kv_cache_range_gen = [1, (max_len + 1) // 2, max_len]
        _max_len_range = [0, (max_len + 1) // 2, max_len]
        if max_num_tokens is None:
            num_tokens_range_ctx = [
                1, (max_input_len * max_batch_size + 1) // 2,
                max_input_len * max_batch_size
            ]
            num_tokens_range_gen = [
                1, max_batch_size * max_beam_width,
                max_beam_width * max_batch_size
            ]
        else:
            # Flat triples, exactly like the branch above: both profile
            # layouts below wrap them into per-profile lists themselves.
            num_tokens_range_ctx = [
                1, (max_num_tokens + 1) // 2, max_num_tokens
            ]
            num_tokens_range_gen = [
                1, (max_num_tokens + 1) // 2, max_num_tokens
            ]

        # Compared with ``== False`` on purpose (rather than ``not``): only a
        # flag value that equals False counts as "plugin disabled" here.
        enable_two_optimization_profiles = False
        if use_gpt_attention_plugin == False or use_gemm_plugin == False:  # noqa: E712
            use_in_flight_batching = (use_gpt_attention_plugin
                                      and remove_input_padding
                                      and paged_kv_cache)
            enable_two_optimization_profiles = not use_in_flight_batching
        num_profiles = 2 if enable_two_optimization_profiles else 1

        def per_profile(value):
            # Repeat ``value`` once per optimization profile.
            return [value] * num_profiles

        bs_range = per_profile(_bs_range)
        beam_width_range = per_profile(_beam_width_range)
        max_len_range = per_profile(_max_len_range)
        if enable_two_optimization_profiles:
            bb_range = [bb_range_cxt, bb_range_gen]
            inlen_range = [inlen_range_cxt, inlen_range_gen]
            mask_len_range = [_mask_len_ctx, _mask_len_gen]
            if use_gpt_attention_plugin:
                kv_cache_range = [_kv_cache_range_gen, _kv_cache_range_gen]
            else:
                kv_cache_range = [_kv_cache_range_ctx, _kv_cache_range_gen]
            num_tokens_range = [num_tokens_range_ctx, num_tokens_range_gen]
        else:
            bb_range = [bb_range_gen]
            inlen_range = [[1, 1, max_input_len]]
            mask_len_range = [[1, (max_len + 1) // 2 + 1, max_len + 1]]
            kv_cache_range = [[0, (max_len + 1) // 2, max_len]]
            if max_num_tokens is None:
                num_tokens_range = [[
                    1, max_batch_size * max_beam_width,
                    max(max_input_len * max_batch_size,
                        max_beam_width * max_batch_size)
                ]]
            else:
                num_tokens_range = [num_tokens_range_ctx]

        def token_dim_range(*extra_dims):
            # Fresh dim_range for the two leading token dims, followed by
            # ``extra_dims``; the layout depends on ``remove_input_padding``.
            if remove_input_padding:
                dims = [('batch_size_fake', per_profile(1)),
                        ('num_tokens', num_tokens_range)]
            else:
                dims = [('batch_size_beam_width', bb_range),
                        ('input_len', inlen_range)]
            return OrderedDict(dims + list(extra_dims))

        token_shape = [1, -1] if remove_input_padding else [-1, -1]

        input_ids = None
        position_ids = None
        hidden_states = None
        if mapping.is_first_pp_rank():
            input_ids = Tensor(name='input_ids',
                               dtype=trt.int32,
                               shape=list(token_shape),
                               dim_range=token_dim_range())
            position_ids = Tensor(name='position_ids',
                                  dtype=trt.int32,
                                  shape=list(token_shape),
                                  dim_range=token_dim_range())
        else:
            # Ranks other than the first pipeline rank declare
            # 'hidden_states_input' instead of token/position ids.
            assert dtype is not None, 'dtype is required on non-first pp ranks'
            assert num_heads is not None, \
                'num_heads is required on non-first pp ranks'
            hidden_size = head_size * num_heads
            hidden_states = Tensor(
                name='hidden_states_input',
                dtype=dtype,
                shape=token_shape + [hidden_size],
                dim_range=token_dim_range(
                    ('hidden_size', per_profile(hidden_size))))

        # Per-rank KV head count: ceil division by the tensor-parallel size.
        num_kv_heads = (num_kv_heads + mapping.tp_size - 1) // mapping.tp_size
        layers_range = self.get_transformer_layers(mapping, num_layers)
        past_key_value = []
        kv_cache_block_pointers_list = []
        if not paged_kv_cache:
            for i in layers_range:
                kv_dim_range = OrderedDict([
                    ('batch_size_beam_width', bb_range),
                    ('kv', per_profile(2)),
                    ('num_heads', per_profile(num_kv_heads)),
                    ('past_key_len', kv_cache_range),
                    ('head_size', per_profile(head_size)),
                ])
                kv = Tensor(name=f'past_key_value_{i}',
                            dtype=kv_dtype,
                            shape=[-1, 2, num_kv_heads, -1, head_size],
                            dim_range=kv_dim_range)
                past_key_value.append(kv)
                kv_cache_block_pointers_list.append(None)
        else:
            # Blocks per sequence needed for each bound of each profile's
            # KV length range (rounded up to whole blocks).
            max_blocks_per_seq_range = [[
                math.ceil(seq_len / tokens_per_block)
                for seq_len in profile_range
            ] for profile_range in kv_cache_range]
            for i in layers_range:
                kv_cache_block_pointers = Tensor(
                    name=f'kv_cache_block_pointers_{i}',
                    dtype=trt.int64,
                    shape=[-1, 2, -1],
                    dim_range=OrderedDict([
                        ('batch_size_beam_width', bb_range),
                        ('kv', per_profile(2)),
                        ('max_blocks_per_seq', max_blocks_per_seq_range),
                    ]))
                kv_cache_block_pointers_list.append(kv_cache_block_pointers)
                past_key_value.append(None)

        def batch_vector(name):
            # 1-D int32 input whose only dim is 'batch_size_beam_width'.
            return Tensor(
                name=name,
                dtype=trt.int32,
                shape=[-1],
                dim_range=OrderedDict([('batch_size_beam_width', bb_range)]),
            )

        sequence_length = None
        context_lengths = None
        host_context_lengths = None
        host_past_key_value_lengths = None
        attention_mask = None
        host_request_types = None
        if use_gpt_attention_plugin:
            sequence_length = batch_vector('sequence_length')
            host_request_types = batch_vector('host_request_types')
            host_past_key_value_lengths = batch_vector(
                'host_past_key_value_lengths')
            context_lengths = batch_vector('context_lengths')
        else:
            attention_mask = Tensor(
                name='attention_mask',
                dtype=trt.int32,
                shape=[-1, -1],
                dim_range=OrderedDict([
                    ('batch_size_beam_width', bb_range),
                    ('mask_len', mask_len_range),
                ]),
            )
        if use_gpt_attention_plugin and remove_input_padding:
            host_context_lengths = batch_vector('host_context_lengths')

        last_token_ids = None
        if mapping.is_last_pp_rank() and not gather_all_token_logits:
            last_token_ids = Tensor(
                name='last_token_ids',
                dtype=trt.int32,
                shape=[-1],
                dim_range=OrderedDict([
                    ('batch_size_last_token_ids', bb_range),
                ]),
            )

        cache_indirection = Tensor(
            name='cache_indirection',
            dtype=trt.int32,
            shape=[-1, -1, -1],
            dim_range=OrderedDict([
                ('batch_size_cache', bs_range),
                ('beam_width', beam_width_range),
                ('max_seq_len', max_len_range),
            ]),
        )

        all_reduce_workspace = None
        if use_custom_all_reduce and mapping.tp_size > 1:
            # 3 (= buffer + signals_in + signals_out)
            workspace_size = 3 * mapping.tp_size
            all_reduce_workspace = Tensor(
                name='all_reduce_workspace',
                dtype=trt.int64,
                shape=[workspace_size],
                dim_range=OrderedDict([
                    ('all_reduce_size', per_profile(workspace_size)),
                ]))

        return {
            'input_ids': input_ids,
            'hidden_states_input': hidden_states,
            'position_ids': position_ids,
            'attention_mask': attention_mask,
            'sequence_length': sequence_length,
            'host_past_key_value_lengths': host_past_key_value_lengths,
            'past_key_value': past_key_value,
            'last_token_ids': last_token_ids,
            'cache_indirection': cache_indirection,
            'kv_cache_block_pointers_list': kv_cache_block_pointers_list,
            'context_lengths': context_lengths,
            'host_context_lengths': host_context_lengths,
            'host_request_types': host_request_types,
            'all_reduce_workspace': all_reduce_workspace,
        }