mirror of
https://github.com/NVIDIA/TensorRT-LLM.git
synced 2026-01-24 12:42:54 +08:00
626 lines
27 KiB
Python
626 lines
27 KiB
Python
# SPDX-FileCopyrightText: Copyright (c) 2022-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
# SPDX-License-Identifier: Apache-2.0
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
import math
|
|
from typing import List, Optional
|
|
|
|
import numpy as np
|
|
import tensorrt as trt
|
|
|
|
from .._common import default_net, precision
|
|
from ..functional import (AttentionMaskType, PositionEmbeddingType,
|
|
RotaryScalingType, Tensor, bert_attention, cast, clip,
|
|
concat, constant, expand_mask, generate_alibi_biases,
|
|
generate_alibi_slopes, gpt_attention, matmul, round,
|
|
shape, slice, softmax, split)
|
|
from ..module import Module
|
|
from ..parameter import Parameter
|
|
from ..quantization import QuantMode
|
|
from ..quantization.layers import FP8Linear, FP8RowLinear
|
|
from .linear import ColumnLinear, RowLinear
|
|
|
|
|
|
class AttentionParams:
|
|
|
|
def __init__(self,
|
|
sequence_length: Tensor = None,
|
|
context_lengths: Tensor = None,
|
|
host_context_lengths: Tensor = None,
|
|
max_context_length: int = None,
|
|
host_request_types: Tensor = None,
|
|
encoder_input_lengths: Tensor = None,
|
|
encoder_max_input_length: Tensor = None):
|
|
self.sequence_length = sequence_length
|
|
self.context_lengths = context_lengths
|
|
self.host_context_lengths = host_context_lengths
|
|
# max allowed context length. Required to
|
|
# compute scratch memory size.
|
|
self.max_context_length = max_context_length
|
|
self.host_request_types = host_request_types
|
|
|
|
self.encoder_input_lengths = encoder_input_lengths
|
|
self.encoder_max_input_length = encoder_max_input_length
|
|
|
|
def is_valid_cross_attn(self, do_cross_attention):
|
|
if do_cross_attention:
|
|
if self.encoder_input_lengths is None:
|
|
return False
|
|
if self.encoder_max_input_length is None:
|
|
return False
|
|
|
|
def is_valid(self, gpt_attention_plugin, remove_input_padding):
|
|
if gpt_attention_plugin:
|
|
if self.sequence_length is None:
|
|
return False
|
|
if self.context_lengths is None:
|
|
return False
|
|
if self.host_request_types is None:
|
|
return False
|
|
if self.max_context_length is None:
|
|
return False
|
|
|
|
if remove_input_padding:
|
|
if self.host_context_lengths is None:
|
|
return False
|
|
if not gpt_attention_plugin:
|
|
return False
|
|
|
|
return True
|
|
|
|
|
|
class KeyValueCacheParams:
|
|
|
|
def __init__(self,
|
|
past_key_value: List[Tensor] = None,
|
|
host_past_key_value_lengths: Tensor = None,
|
|
kv_cache_block_pointers: List[Tensor] = None,
|
|
cache_indirection: Tensor = None,
|
|
past_key_value_length: Tensor = None):
|
|
self.past_key_value = past_key_value
|
|
self.host_past_key_value_lengths = host_past_key_value_lengths
|
|
self.kv_cache_block_pointers = kv_cache_block_pointers
|
|
self.cache_indirection = cache_indirection
|
|
# self.past_key_value_length = past_key_value_length
|
|
|
|
def get_first_past_key_value(self):
|
|
if self.past_key_value is None:
|
|
return None
|
|
return self.past_key_value[0]
|
|
|
|
def get_first_kv_cache_block_pointers(self):
|
|
if self.kv_cache_block_pointers is None:
|
|
return None
|
|
return self.kv_cache_block_pointers[0]
|
|
|
|
def is_valid(self, gpt_attention_plugin):
|
|
if gpt_attention_plugin:
|
|
if self.host_past_key_value_lengths is None:
|
|
return False
|
|
if self.cache_indirection is None:
|
|
return False
|
|
|
|
return True
|
|
|
|
|
|
class Attention(Module):
|
|
|
|
def __init__(self,
|
|
hidden_size,
|
|
num_attention_heads,
|
|
num_kv_heads=None,
|
|
max_position_embeddings=1024,
|
|
num_layers=1,
|
|
apply_query_key_layer_scaling=False,
|
|
attention_mask_type=AttentionMaskType.padding,
|
|
bias=True,
|
|
dtype=None,
|
|
position_embedding_type=PositionEmbeddingType.learned_absolute,
|
|
rotary_embedding_base=10000.0,
|
|
rotary_embedding_scaling=None,
|
|
use_int8_kv_cache=False,
|
|
rotary_embedding_percentage=1.0,
|
|
tp_group=None,
|
|
tp_size=1,
|
|
tp_rank=0,
|
|
multi_block_mode=False,
|
|
quant_mode: QuantMode = QuantMode(0),
|
|
q_scaling=1.0,
|
|
cross_attention=False,
|
|
relative_attention=False,
|
|
max_distance=0,
|
|
num_buckets=0,
|
|
instance_id: int = 0):
|
|
super().__init__()
|
|
|
|
self.cross_attention = cross_attention
|
|
self.attention_mask_type = attention_mask_type
|
|
self.attention_head_size = hidden_size // num_attention_heads
|
|
assert num_attention_heads % tp_size == 0, \
|
|
"num_attention_heads must be divisible by tp_size"
|
|
self.num_attention_heads = num_attention_heads // tp_size
|
|
self.num_attention_kv_heads = (
|
|
num_kv_heads + tp_size - 1
|
|
) // tp_size if num_kv_heads is not None else self.num_attention_heads
|
|
self.hidden_size = hidden_size // tp_size
|
|
self.max_position_embeddings = max_position_embeddings
|
|
self.tp_size = tp_size
|
|
self.tp_rank = tp_rank
|
|
|
|
self.num_layers = num_layers
|
|
self.apply_query_key_layer_scaling = apply_query_key_layer_scaling
|
|
self.norm_factor = math.sqrt(self.attention_head_size)
|
|
self.q_scaling = q_scaling
|
|
if self.apply_query_key_layer_scaling:
|
|
self.norm_factor *= self.num_layers
|
|
self.q_scaling *= self.num_layers
|
|
# Whether to scale ALiBi bias. Mathematically, it's equivalent to
|
|
# normalizing QK after adding bias.
|
|
# - False, inv_sqrt_Dh * Q*K^T + alibi_bias
|
|
# - True, inv_sqrt_Dh * Q*K^T + inv_sqrt_Dh * alibi_bias
|
|
self.scale_alibi_bias = position_embedding_type == PositionEmbeddingType.alibi_with_scale
|
|
|
|
self.position_embedding_type = position_embedding_type
|
|
self.multi_block_mode = multi_block_mode
|
|
|
|
self.relative_attention = relative_attention
|
|
self.max_distance = max_distance
|
|
|
|
self.rotary_embedding_base = rotary_embedding_base
|
|
self.rotary_embedding_scale_type = RotaryScalingType.none
|
|
self.rotary_embedding_scale = 1.0
|
|
if rotary_embedding_scaling is not None:
|
|
assert rotary_embedding_scaling["type"] in ["linear", "dynamic"]
|
|
self.rotary_embedding_scale_type = RotaryScalingType.linear if rotary_embedding_scaling[
|
|
"type"] == "linear" else RotaryScalingType.dynamic
|
|
self.rotary_embedding_scale = rotary_embedding_scaling["factor"]
|
|
assert self.rotary_embedding_scale > 1.0
|
|
self.rotary_embedding_dim = 0
|
|
if self.position_embedding_type.is_rope():
|
|
self.rotary_embedding_dim = int(self.attention_head_size *
|
|
rotary_embedding_percentage)
|
|
# TODO: Once we add RotaryEmbedding outside GPTAttention plugin,
|
|
# we need to set it up here
|
|
|
|
self.dtype = dtype
|
|
self.quant_mode = quant_mode
|
|
if use_int8_kv_cache:
|
|
# TODO: remove use_int8_kv_cache as can be replaced by quant_mode.has_kv_cache_quant()
|
|
# Merge int8 setting into quant_mode
|
|
self.quant_mode = self.quant_mode.set_int8_kv_cache()
|
|
self.use_int8_kv_cache = use_int8_kv_cache
|
|
if self.quant_mode.has_kv_cache_quant():
|
|
self.kv_orig_quant_scale = Parameter(shape=(1, ), dtype='float32')
|
|
self.kv_quant_orig_scale = Parameter(shape=(1, ), dtype='float32')
|
|
else:
|
|
self.register_parameter('kv_orig_quant_scale', None)
|
|
self.register_parameter('kv_quant_orig_scale', None)
|
|
|
|
# The output feature size is therefore (h/tp + 2*kvh/tp) * d, where h is num_heads,
|
|
# d is head_size, kvh is the num_kv_heads and tp is tensor_parallel_size.
|
|
# In ColumnLinear op, the output dim is calculated by (h + 2*kvh) * d / tp,
|
|
# which matches the desired output size (h/tp + 2*kvh/tp) * d after splitting
|
|
self.use_fp8_qdq = self.quant_mode.has_fp8_qdq()
|
|
if self.use_fp8_qdq:
|
|
self.qkv = FP8Linear(hidden_size,
|
|
hidden_size +
|
|
(2 * tp_size * self.num_attention_kv_heads *
|
|
self.attention_head_size),
|
|
bias=bias,
|
|
dtype=dtype,
|
|
tp_group=tp_group,
|
|
tp_size=tp_size,
|
|
gather_output=False)
|
|
self.dense = FP8RowLinear(hidden_size,
|
|
hidden_size,
|
|
bias=bias,
|
|
dtype=dtype,
|
|
tp_group=tp_group,
|
|
tp_size=tp_size,
|
|
instance_id=instance_id)
|
|
else:
|
|
self.qkv = ColumnLinear(hidden_size,
|
|
hidden_size +
|
|
(2 * tp_size * self.num_attention_kv_heads *
|
|
self.attention_head_size),
|
|
bias=bias,
|
|
dtype=dtype,
|
|
tp_group=tp_group,
|
|
tp_size=tp_size,
|
|
gather_output=False)
|
|
self.dense = RowLinear(hidden_size,
|
|
hidden_size,
|
|
bias=bias,
|
|
dtype=dtype,
|
|
tp_group=tp_group,
|
|
tp_size=tp_size,
|
|
instance_id=instance_id)
|
|
|
|
# per-layer relative attention table
|
|
if relative_attention:
|
|
self.rel_attn_table = Parameter(shape=(num_attention_heads //
|
|
tp_size, num_buckets),
|
|
dtype=dtype)
|
|
|
|
def forward(
|
|
self,
|
|
hidden_states: Tensor,
|
|
attention_mask=None,
|
|
use_cache=False,
|
|
kv_cache_params=None,
|
|
attention_params=None,
|
|
encoder_output: Optional[Tensor] = None,
|
|
workspace=None,
|
|
):
|
|
|
|
assert isinstance(hidden_states, Tensor)
|
|
|
|
alibi_slopes = None
|
|
if self.position_embedding_type.is_rope():
|
|
if not default_net().plugin_config.gpt_attention_plugin:
|
|
raise ValueError(
|
|
'RoPE is only supported with GPTAttention plugin')
|
|
elif self.position_embedding_type.is_alibi():
|
|
dtype = trt.float32
|
|
if default_net().plugin_config.gpt_attention_plugin:
|
|
dtype = hidden_states.dtype
|
|
alibi_scale = 1. / self.norm_factor if self.scale_alibi_bias else 1.
|
|
alibi_slopes = generate_alibi_slopes(self.num_attention_heads *
|
|
self.tp_size,
|
|
dtype=dtype,
|
|
tp_size=self.tp_size,
|
|
tp_rank=self.tp_rank,
|
|
alibi_scale=alibi_scale)
|
|
|
|
qkv = self.qkv(hidden_states)
|
|
|
|
paged_kv_cache = default_net().plugin_config.paged_kv_cache
|
|
|
|
assert attention_params is None or attention_params.is_valid(
|
|
default_net().plugin_config.gpt_attention_plugin,
|
|
default_net().plugin_config.remove_input_padding)
|
|
assert kv_cache_params is None or kv_cache_params.is_valid(
|
|
default_net().plugin_config.gpt_attention_plugin)
|
|
|
|
past_key_value = None if kv_cache_params is None else kv_cache_params.get_first_past_key_value(
|
|
)
|
|
if self.cross_attention and (past_key_value is not None):
|
|
past_key_value = kv_cache_params.past_key_value[1]
|
|
|
|
# if cross attention, cross QKV only needs to be calculated once in the
|
|
# 1st decoding step --> write to cross KV cache --> remains constant
|
|
# during the entire decoding. 1st and >1 steps are distinguished by
|
|
# whether past_key_value exists or not
|
|
# also, cross KV cache max length is set from encoder output seqlen,
|
|
# this maps to the max context length concept in decoder-only models
|
|
cross_qkv = None
|
|
# get length data in every run
|
|
if encoder_output:
|
|
assert isinstance(encoder_output, Tensor)
|
|
# but only do projection once at 1st decoding step
|
|
if self.cross_attention and encoder_output:
|
|
cross_qkv = self.qkv(encoder_output)
|
|
|
|
if default_net().plugin_config.gpt_attention_plugin:
|
|
assert self.attention_mask_type in [
|
|
AttentionMaskType.causal, AttentionMaskType.bidirectional
|
|
], 'Plugin only support masked MHA.'
|
|
kv_orig_quant_scale = self.kv_orig_quant_scale.value if self.quant_mode.has_kv_cache_quant(
|
|
) else None
|
|
kv_quant_orig_scale = self.kv_quant_orig_scale.value if self.quant_mode.has_kv_cache_quant(
|
|
) else None
|
|
context, past_key_value = gpt_attention(
|
|
tensor=qkv,
|
|
past_key_value=past_key_value,
|
|
sequence_length=attention_params.sequence_length,
|
|
host_past_key_value_lengths=kv_cache_params.
|
|
host_past_key_value_lengths,
|
|
context_lengths=attention_params.context_lengths,
|
|
cache_indirection=kv_cache_params.cache_indirection,
|
|
host_request_types=attention_params.host_request_types,
|
|
num_heads=self.num_attention_heads,
|
|
num_kv_heads=self.num_attention_kv_heads,
|
|
hidden_size_per_head=self.attention_head_size,
|
|
q_scaling=self.q_scaling,
|
|
rotary_embedding_dim=self.rotary_embedding_dim,
|
|
rotary_embedding_base=self.rotary_embedding_base,
|
|
rotary_embedding_scale_type=self.rotary_embedding_scale_type,
|
|
rotary_embedding_scale=self.rotary_embedding_scale,
|
|
rotary_embedding_max_positions=self.max_position_embeddings,
|
|
position_embedding_type=self.position_embedding_type,
|
|
multi_block_mode=self.multi_block_mode,
|
|
kv_orig_quant_scale=kv_orig_quant_scale,
|
|
kv_quant_orig_scale=kv_quant_orig_scale,
|
|
kv_cache_quant_mode=self.quant_mode,
|
|
max_context_length=attention_params.max_context_length,
|
|
mask_type=self.attention_mask_type,
|
|
alibi_slopes=alibi_slopes,
|
|
tp_size=self.tp_size,
|
|
tp_rank=self.tp_rank,
|
|
kv_cache_block_pointers=kv_cache_params.
|
|
get_first_kv_cache_block_pointers(),
|
|
do_cross_attention=self.cross_attention,
|
|
cross_qkv=cross_qkv,
|
|
cross_qkv_length=attention_params.encoder_max_input_length,
|
|
encoder_input_lengths=attention_params.encoder_input_lengths,
|
|
relative_attention_bias=self.rel_attn_table.value
|
|
if self.relative_attention else None,
|
|
max_distance=self.max_distance,
|
|
host_context_lengths=attention_params.host_context_lengths,
|
|
)
|
|
|
|
else:
|
|
# plain TensorRT mode
|
|
assert paged_kv_cache == False
|
|
|
|
def transpose_for_scores(x, is_kv: bool = False):
|
|
_num_attention_heads = self.num_attention_kv_heads if is_kv else self.num_attention_heads
|
|
new_x_shape = concat([
|
|
shape(x, 0),
|
|
shape(x, 1), _num_attention_heads, self.attention_head_size
|
|
])
|
|
return x.view(new_x_shape).permute([0, 2, 1, 3])
|
|
|
|
# qkv after projection is of shape
|
|
# [bs, seqlen, (num_attention_heads + 2 * num_attention_kv_heads), attention_head_size].
|
|
# The projected and split qkv after transpose_for_scores():
|
|
# Q[bs, num_attention_heads, seqlen, attention_head_size]
|
|
# K[bs, num_attention_kv_heads, seqlen, attention_head_size]
|
|
# V[bs, num_attention_kv_heads, seqlen, attention_head_size]
|
|
kv_size = self.attention_head_size * self.num_attention_kv_heads
|
|
query, key, value = split(qkv, [self.hidden_size, kv_size, kv_size],
|
|
dim=2)
|
|
|
|
# in cross attention mode, replace kv by encoder_output
|
|
if self.cross_attention and encoder_output is not None:
|
|
encoder_qkv = self.qkv(encoder_output)
|
|
_, key, value = split(encoder_qkv,
|
|
[self.hidden_size, kv_size, kv_size],
|
|
dim=2)
|
|
|
|
query = transpose_for_scores(query)
|
|
key = transpose_for_scores(key, is_kv=True)
|
|
value = transpose_for_scores(value, is_kv=True)
|
|
|
|
if past_key_value is not None:
|
|
|
|
def dequantize_tensor(x, scale):
|
|
# Cast from int8 to dtype
|
|
casted_x = cast(x, self.dtype)
|
|
return casted_x * scale
|
|
|
|
if self.use_int8_kv_cache:
|
|
past_key_value = dequantize_tensor(
|
|
past_key_value, self.kv_quant_orig_scale.value)
|
|
|
|
# past_key_value [bs, 2, num_heads, max_seq_len, head_dim]
|
|
past_key, past_value = split(past_key_value, 1, dim=1)
|
|
|
|
key_shape = concat([
|
|
shape(past_key, 0),
|
|
shape(past_key, 2),
|
|
shape(past_key, 3),
|
|
shape(past_key, 4)
|
|
])
|
|
past_key = past_key.view(key_shape, zero_is_placeholder=False)
|
|
past_value = past_value.view(key_shape,
|
|
zero_is_placeholder=False)
|
|
|
|
key = concat([past_key, key], dim=2).cast(self.dtype)
|
|
value = concat([past_value, value], dim=2).cast(self.dtype)
|
|
|
|
if use_cache:
|
|
key_inflated_shape = concat([
|
|
shape(key, 0), 1,
|
|
shape(key, 1),
|
|
shape(key, 2),
|
|
shape(key, 3)
|
|
])
|
|
inflated_key = key.view(key_inflated_shape,
|
|
zero_is_placeholder=False)
|
|
inflated_value = value.view(key_inflated_shape,
|
|
zero_is_placeholder=False)
|
|
past_key_value = concat([inflated_key, inflated_value], dim=1)
|
|
|
|
if self.use_int8_kv_cache:
|
|
|
|
def quantize_tensor(x, scale):
|
|
scaled = x * scale
|
|
rounded = round(scaled)
|
|
clipped = clip(rounded, -128, 127)
|
|
quantized = cast(clipped, 'int8')
|
|
return quantized
|
|
|
|
past_key_value = quantize_tensor(
|
|
past_key_value, self.kv_orig_quant_scale.value)
|
|
|
|
key_length = shape(key, 2)
|
|
|
|
# The following code creates a 2D tensor with 0s in the lower triangular (including the diagonal) and
|
|
# +INF in the upper triangular parts. This bias tensor will be added to the output of the Q*K^T matrix
|
|
# multiplication (BMM1). The +INF elements will be transformed to 0s by the Softmax operator that
|
|
# follows. The elements that corresponds to 0s in the bias are unaffected by the bias tensor.
|
|
#
|
|
# Note that when we added to another bias tensor B (for example, with AliBi), the values in the lower-
|
|
# triangular part of the B tensor are not affected and the upper-triangular ones are set to +INF.
|
|
if self.attention_mask_type == AttentionMaskType.causal:
|
|
query_length = shape(query, 2)
|
|
starts = concat([0, 0, key_length - query_length, 0])
|
|
sizes = concat([1, 1, query_length, key_length])
|
|
select_buf = np.expand_dims(
|
|
np.tril(
|
|
np.ones((self.max_position_embeddings,
|
|
self.max_position_embeddings))).astype(bool),
|
|
(0, 1))
|
|
|
|
select_buf = np.logical_not(select_buf)
|
|
mask_buf = np.zeros_like(select_buf, np.float32)
|
|
mask_buf[select_buf] = float('-inf')
|
|
buffer = constant(mask_buf)
|
|
causal_mask = slice(buffer, starts, sizes)
|
|
|
|
if attention_mask is not None:
|
|
attention_mask = expand_mask(attention_mask, shape(query, 2))
|
|
bias = attention_mask
|
|
if self.position_embedding_type.is_alibi():
|
|
alibi_biases = generate_alibi_biases(alibi_slopes, key_length)
|
|
bias = alibi_biases if bias is None else bias + alibi_biases
|
|
|
|
key = key.permute([0, 1, 3, 2])
|
|
with precision('float32'):
|
|
attention_scores = matmul(cast(query, 'float32'),
|
|
cast(key, 'float32'))
|
|
|
|
attention_scores = attention_scores / self.norm_factor
|
|
|
|
if self.attention_mask_type == AttentionMaskType.causal:
|
|
bias = causal_mask if bias is None else bias + causal_mask
|
|
|
|
if bias is not None and not self.cross_attention:
|
|
attention_scores = attention_scores + bias
|
|
|
|
attention_probs = softmax(attention_scores, dim=-1)
|
|
|
|
context = matmul(attention_probs, value).permute([0, 2, 1, 3])
|
|
context = context.view(
|
|
concat([shape(context, 0),
|
|
shape(context, 1), self.hidden_size]))
|
|
|
|
context = self.dense(context, workspace)
|
|
|
|
if use_cache:
|
|
return (context, past_key_value)
|
|
else:
|
|
return context
|
|
|
|
|
|
class BertAttention(Module):
|
|
|
|
def __init__(self,
|
|
hidden_size,
|
|
num_attention_heads,
|
|
num_kv_heads=None,
|
|
max_position_embeddings=1024,
|
|
num_layers=1,
|
|
q_scaling=1.0,
|
|
apply_query_key_layer_scaling=False,
|
|
bias=True,
|
|
dtype=None,
|
|
tp_group=None,
|
|
tp_size=1,
|
|
tp_rank=0,
|
|
relative_attention=False,
|
|
max_distance=0,
|
|
num_buckets=0):
|
|
super().__init__()
|
|
|
|
self.attention_head_size = hidden_size // num_attention_heads
|
|
self.num_attention_heads = num_attention_heads // tp_size
|
|
self.num_attention_kv_heads = (
|
|
num_kv_heads + tp_size - 1
|
|
) // tp_size if num_kv_heads is not None else self.num_attention_heads
|
|
self.hidden_size = hidden_size // tp_size
|
|
self.max_position_embeddings = max_position_embeddings
|
|
self.norm_factor = math.sqrt(self.attention_head_size)
|
|
self.tp_size = tp_size
|
|
self.tp_rank = tp_rank
|
|
|
|
self.num_layers = num_layers
|
|
self.apply_query_key_layer_scaling = apply_query_key_layer_scaling
|
|
self.norm_factor = math.sqrt(self.attention_head_size)
|
|
self.q_scaling = q_scaling
|
|
if self.apply_query_key_layer_scaling:
|
|
self.norm_factor *= self.num_layers
|
|
self.q_scaling *= self.num_layers
|
|
|
|
self.dtype = dtype
|
|
|
|
self.relative_attention = relative_attention
|
|
self.max_distance = max_distance
|
|
|
|
self.qkv = ColumnLinear(hidden_size,
|
|
hidden_size +
|
|
(2 * tp_size * self.num_attention_kv_heads *
|
|
self.attention_head_size),
|
|
bias=bias,
|
|
dtype=dtype,
|
|
tp_group=tp_group,
|
|
tp_size=tp_size,
|
|
gather_output=False)
|
|
self.dense = RowLinear(hidden_size,
|
|
hidden_size,
|
|
bias=bias,
|
|
dtype=dtype,
|
|
tp_group=tp_group,
|
|
tp_size=tp_size)
|
|
|
|
# per-layer relative attention table
|
|
if relative_attention:
|
|
self.rel_attn_table = Parameter(shape=(num_attention_heads //
|
|
tp_size, num_buckets),
|
|
dtype=dtype)
|
|
|
|
def forward(self,
|
|
hidden_states: Tensor,
|
|
attention_mask=None,
|
|
input_lengths=None):
|
|
assert isinstance(hidden_states, Tensor)
|
|
|
|
qkv = self.qkv(hidden_states)
|
|
|
|
if default_net().plugin_config.bert_attention_plugin:
|
|
# TRT plugin mode
|
|
assert input_lengths is not None
|
|
context = bert_attention(
|
|
qkv,
|
|
input_lengths,
|
|
self.num_attention_heads,
|
|
self.attention_head_size,
|
|
q_scaling=self.q_scaling,
|
|
relative_attention=self.relative_attention,
|
|
max_distance=self.max_distance,
|
|
relative_attention_bias=self.rel_attn_table.value
|
|
if self.relative_attention else None)
|
|
else:
|
|
# plain TRT mode
|
|
def transpose_for_scores(x):
|
|
new_x_shape = concat([
|
|
shape(x, 0),
|
|
shape(x, 1), self.num_attention_heads,
|
|
self.attention_head_size
|
|
])
|
|
return x.view(new_x_shape).permute([0, 2, 1, 3])
|
|
|
|
query, key, value = split(qkv, self.hidden_size, dim=2)
|
|
query = transpose_for_scores(query)
|
|
key = transpose_for_scores(key)
|
|
value = transpose_for_scores(value)
|
|
|
|
key = key.permute([0, 1, 3, 2])
|
|
attention_scores = matmul(query, key)
|
|
attention_scores = attention_scores / self.norm_factor
|
|
|
|
if attention_mask is not None:
|
|
attention_scores = attention_scores + attention_mask
|
|
|
|
attention_probs = softmax(attention_scores, dim=-1)
|
|
|
|
context = matmul(attention_probs, value).permute([0, 2, 1, 3])
|
|
context = context.view(
|
|
concat([shape(context, 0),
|
|
shape(context, 1), self.hidden_size]))
|
|
|
|
context = self.dense(context)
|
|
|
|
return context
|