# SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from typing import List, Optional

import numpy as np
import tensorrt as trt

from .._common import default_net, precision
from .._utils import (fp32_array, int32_array, is_same_dtype,
                      numpy_fp32_to_bf16, preview_trt_version,
                      trt_dtype_to_np, trt_dtype_to_str)
from ..functional import (AttentionMaskType, PositionEmbeddingType,
                          RopeEmbeddingUtils, RotaryScalingType, Tensor,
                          arange, bert_attention, cast, clip, concat,
                          conditional, constant, embedding, expand,
                          expand_dims, expand_mask, generate_alibi_biases,
                          generate_alibi_slopes, gpt_attention, matmul,
                          minimum, repeat_interleave, shape, slice, softmax,
                          split, unsqueeze, where)
from ..module import Module
from ..parameter import Parameter
from ..quantization import QuantMode
from ..quantization.functional import dequantize, quantize
from .linear import ColumnLinear, QKVColumnLinear, RowLinear
from .lora import LoraRuntimeParams


def make_causal_mask(bsz, tgt_len, past_key_values_length, dtype):
    _range = arange(start=constant(int32_array(0)),
                    end=tgt_len,
                    dtype=trt_dtype_to_str(dtype))
    mask = repeat_interleave(_range, tgt_len,
                             0).view(concat([tgt_len, tgt_len]))
    mask = where(mask < mask.transpose(-1, -2), 1.0, 0.0)

    zero = constant(fp32_array(0))
    zero = expand_dims(zero, [0, 1])
    zero = expand(zero, concat([tgt_len, past_key_values_length]))
    mask = concat([zero, mask], dim=1)
    mask *= np.finfo(trt_dtype_to_np(dtype)).min.item()
    mask = mask.view(concat([1, 1, tgt_len,
                             tgt_len + past_key_values_length]))
    mask = expand(mask,
                  concat([bsz, 1, tgt_len, tgt_len + past_key_values_length]))
    return mask
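
# Illustrative numpy reference for make_causal_mask above (a minimal sketch,
# not used by the runtime graph; assumes float32 and small static shapes).
# Masked-out future positions get the dtype's most-negative value so a later
# softmax maps them to ~0; the `past` block is all zeros because generated
# tokens may always attend to previously cached positions.
def _make_causal_mask_reference(bsz, tgt_len, past_key_values_length):
    idx = np.arange(tgt_len)
    # 1.0 on the strict upper triangle (future positions), 0.0 elsewhere.
    mask = np.where(idx[:, None] < idx[None, :], 1.0, 0.0).astype(np.float32)
    mask *= np.finfo(np.float32).min
    past = np.zeros((tgt_len, past_key_values_length), dtype=np.float32)
    mask = np.concatenate([past, mask], axis=1)
    return np.broadcast_to(
        mask, (bsz, 1, tgt_len, tgt_len + past_key_values_length))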

def compute_relative_bias(query_length,
                          key_length,
                          num_buckets,
                          max_distance,
                          bidirectional,
                          rel_attn_table,
                          tp_size=1,
                          tp_group=None,
                          tp_rank=None):

    def make_relative_position_bucket(relative_position, bidirectional,
                                      num_buckets, max_distance):
        relative_buckets = 0
        if bidirectional:
            num_buckets //= 2
            relative_buckets += where(relative_position > 0, num_buckets, 0)
            relative_position = relative_position.abs()
        else:
            relative_position = 0 - minimum(relative_position, 0)

        max_exact = num_buckets // 2
        is_small = relative_position < max_exact

        max_exact_fp = constant(fp32_array(max_exact))
        tmp = cast(relative_position, "float32") / max_exact_fp
        tmp = tmp.log()
        const1 = math.log(max_distance / max_exact)
        const2 = constant(fp32_array(num_buckets - max_exact))
        relative_position_if_large = tmp / const1 * const2
        relative_position_if_large = cast(relative_position_if_large, "int32")
        relative_position_if_large = max_exact + relative_position_if_large
        relative_position_if_large = minimum(relative_position_if_large,
                                             num_buckets - 1)

        relative_buckets += where(is_small, relative_position,
                                  relative_position_if_large)
        return relative_buckets

    context_position = arange(start=constant(int32_array(0)),
                              end=query_length,
                              dtype=trt_dtype_to_str(trt.int32))
    context_position = unsqueeze(context_position, -1)
    memory_position = arange(start=constant(int32_array(0)),
                             end=key_length,
                             dtype=trt_dtype_to_str(trt.int32))
    memory_position = unsqueeze(memory_position, 0)
    relative_position = memory_position - context_position

    relative_position_bucket = make_relative_position_bucket(
        relative_position,  # shape (query_length, key_length)
        bidirectional,
        num_buckets,
        max_distance,
    )

    # shape (query_length, key_length, num_heads)
    values = embedding(relative_position_bucket,
                       rel_attn_table,
                       tp_size=tp_size,
                       tp_group=tp_group,
                       tp_rank=tp_rank)
    # shape (1, num_heads, query_length, key_length)
    values = unsqueeze(values.permute([2, 0, 1]), 0)
    return values
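
# Illustrative numpy reference for the bucketing in compute_relative_bias
# above (mirrors the T5 scheme; a sketch, not part of the TRT graph).
# `relative_position` is an integer array of key_pos - query_pos values.
def _relative_position_bucket_reference(relative_position, bidirectional,
                                        num_buckets, max_distance):
    relative_buckets = np.zeros_like(relative_position)
    if bidirectional:
        num_buckets //= 2
        relative_buckets = relative_buckets + np.where(
            relative_position > 0, num_buckets, 0)
        relative_position = np.abs(relative_position)
    else:
        relative_position = -np.minimum(relative_position, 0)
    # Small distances each get their own bucket ...
    max_exact = num_buckets // 2
    is_small = relative_position < max_exact
    # ... while larger distances map logarithmically into the remaining
    # buckets; the epsilon only avoids log(0) warnings in this sketch.
    rel_if_large = max_exact + (
        np.log(relative_position / max_exact + 1e-6) /
        math.log(max_distance / max_exact) *
        (num_buckets - max_exact)).astype(np.int64)
    rel_if_large = np.minimum(rel_if_large, num_buckets - 1)
    return relative_buckets + np.where(is_small, relative_position,
                                       rel_if_large)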

class AttentionParams(object):

    def __init__(self,
                 sequence_length: Tensor = None,
                 context_lengths: Tensor = None,
                 host_context_lengths: Tensor = None,
                 max_context_length: int = None,
                 host_request_types: Tensor = None,
                 encoder_input_lengths: Tensor = None,
                 encoder_max_input_length: Tensor = None):
        self.sequence_length = sequence_length
        self.context_lengths = context_lengths
        self.host_context_lengths = host_context_lengths
        # max allowed context length. Required to
        # compute scratch memory size.
        self.max_context_length = max_context_length
        self.host_request_types = host_request_types

        self.encoder_input_lengths = encoder_input_lengths
        self.encoder_max_input_length = encoder_max_input_length

    def is_valid_cross_attn(self, do_cross_attention):
        if do_cross_attention:
            if self.encoder_input_lengths is None:
                return False
            if self.encoder_max_input_length is None:
                return False
        return True

    def is_valid(self, gpt_attention_plugin, remove_input_padding):
        if gpt_attention_plugin:
            if self.sequence_length is None:
                return False
            if self.context_lengths is None:
                return False
            if self.host_request_types is None:
                return False
            if self.max_context_length is None:
                return False

        if remove_input_padding:
            if self.host_context_lengths is None:
                return False
            if not gpt_attention_plugin:
                return False

        return True


class KeyValueCacheParams:

    def __init__(self,
                 past_key_value: List[Tensor] = None,
                 host_past_key_value_lengths: Tensor = None,
                 host_max_attention_window_sizes: Tensor = None,
                 host_sink_token_length: Tensor = None,
                 kv_cache_block_offsets: Tensor = None,
                 host_kv_cache_block_offsets: Tensor = None,
                 host_kv_cache_pool_pointers: Tensor = None,
                 cache_indirection: Tensor = None,
                 past_key_value_length: Tensor = None):
        self.past_key_value = past_key_value
        self.host_past_key_value_lengths = host_past_key_value_lengths
        self.host_max_attention_window_sizes = host_max_attention_window_sizes
        self.host_sink_token_length = host_sink_token_length
        self.kv_cache_block_offsets = kv_cache_block_offsets
        self.host_kv_cache_block_offsets = host_kv_cache_block_offsets
        self.host_kv_cache_pool_pointers = host_kv_cache_pool_pointers
        self.cache_indirection = cache_indirection
        # self.past_key_value_length = past_key_value_length

    def get_first_past_key_value(self):
        if self.past_key_value is None:
            return None
        return self.past_key_value[0]

    def fill_none_tensor_list(self, list_size):
        if self.past_key_value is None:
            self.past_key_value = tuple([None] * list_size)

    def is_valid(self, gpt_attention_plugin):
        if gpt_attention_plugin:
            if self.host_past_key_value_lengths is None:
                return False
            if self.host_max_attention_window_sizes is None:
                return False
            if self.host_sink_token_length is None:
                return False
            if self.cache_indirection is None:
                return False

        return True
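
# Minimal illustrative smoke test of the validity checks above (not invoked
# by the runtime; real Tensor inputs come from the runtime at build time).
def _attention_params_require_plugin_inputs():
    params = AttentionParams()
    # The runtime tensors bundled here are only consumed by the
    # gpt_attention plugin path, so is_valid() reports False in plain-TRT
    # mode regardless of contents ...
    assert not params.is_valid(gpt_attention_plugin=False,
                               remove_input_padding=False)
    # ... and with the plugin enabled, every required tensor must be set.
    assert not params.is_valid(gpt_attention_plugin=True,
                               remove_input_padding=False)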

class Attention(Module):

    def __init__(self,
                 *,
                 local_layer_idx,
                 hidden_size,
                 num_attention_heads,
                 num_kv_heads=None,
                 max_position_embeddings=1024,
                 num_layers=1,
                 apply_query_key_layer_scaling=False,
                 attention_head_size=None,
                 attention_mask_type=AttentionMaskType.padding,
                 bias=True,
                 dtype=None,
                 position_embedding_type=PositionEmbeddingType.learned_absolute,
                 rotary_embedding_base=10000.0,
                 rotary_embedding_scaling=None,
                 rotary_embedding_percentage=1.0,
                 tp_group=None,
                 tp_size=1,
                 tp_rank=0,
                 quant_mode: QuantMode = QuantMode(0),
                 q_scaling=1.0,
                 cross_attention=False,
                 relative_attention=False,
                 max_distance=0,
                 num_buckets=0,
                 dense_bias=None,
                 clip_qkv=None,
                 alibi_bias_max=8,
                 skip_cross_qkv=False):
        super().__init__()

        self.layer_idx = local_layer_idx
        self.cross_attention = cross_attention
        self.attention_mask_type = attention_mask_type
        self.attention_head_size = hidden_size // num_attention_heads if attention_head_size is None else attention_head_size
        assert num_attention_heads % tp_size == 0, \
            "num_attention_heads must be divisible by tp_size"
        self.num_attention_heads = num_attention_heads // tp_size
        self.num_attention_kv_heads = (
            num_kv_heads + tp_size - 1
        ) // tp_size if num_kv_heads is not None else self.num_attention_heads
        self.hidden_size = hidden_size
        self.attention_hidden_size = self.attention_head_size * self.num_attention_heads
        self.max_position_embeddings = max_position_embeddings
        self.bias = bias
        self.tp_group = tp_group
        self.tp_size = tp_size
        self.tp_rank = tp_rank
        self.dtype = dtype
        if dense_bias is None:
            dense_bias = bias

        self.unfuse_qkv_gemm = False

        self.num_layers = num_layers
        self.apply_query_key_layer_scaling = apply_query_key_layer_scaling
        self.norm_factor = math.sqrt(self.attention_head_size)
        self.q_scaling = q_scaling
        if self.apply_query_key_layer_scaling:
            self.norm_factor *= self.num_layers
            self.q_scaling *= self.num_layers

        # Whether to scale the ALiBi bias. Mathematically, scaling is
        # equivalent to normalizing QK after adding the bias:
        #   - False: inv_sqrt_Dh * Q*K^T + alibi_bias
        #   - True:  inv_sqrt_Dh * Q*K^T + inv_sqrt_Dh * alibi_bias
        self.scale_alibi_bias = position_embedding_type == PositionEmbeddingType.alibi_with_scale
        self.alibi_bias_max = alibi_bias_max
        self.position_embedding_type = position_embedding_type
        self.relative_attention = relative_attention
        self.max_distance = max_distance
        self.num_buckets = num_buckets

        self.rotary_embedding_base = rotary_embedding_base
        self.rotary_embedding_scaling = rotary_embedding_scaling
        self.rotary_embedding_scale_type = RotaryScalingType.none
        self.rotary_embedding_scale = 1.0
        self.rotary_embedding_percentage = rotary_embedding_percentage
        if rotary_embedding_scaling is not None:
            assert rotary_embedding_scaling["type"] in ["linear", "dynamic"]
            self.rotary_embedding_scale_type = RotaryScalingType.linear if rotary_embedding_scaling[
                "type"] == "linear" else RotaryScalingType.dynamic
            self.rotary_embedding_scale = rotary_embedding_scaling["factor"]
            assert self.rotary_embedding_scale > 1.0

        self.embed_positions = None
        self.rotary_enabled = False
        self.rotary_embedding_dim = 0

        if self.position_embedding_type.is_rope():
            self.rotary_embedding_dim = int(self.attention_head_size *
                                            rotary_embedding_percentage)
            self.rotary_enabled = True
            self.embed_positions = RopeEmbeddingUtils.create_sinusoidal_positions(
                self.max_position_embeddings,
                self.rotary_embedding_dim,
            )
            self.embed_positions_for_gpt_attention = RopeEmbeddingUtils.create_sinusoidal_positions_for_attention_plugin(
                self.max_position_embeddings, self.rotary_embedding_dim,
                self.rotary_embedding_base, self.rotary_embedding_scale,
                self.rotary_embedding_scale_type)

        self.quant_mode = quant_mode
        self.register_parameter('kv_cache_scaling_factor', None)

        # The fused QKV projection produces (h/tp + 2*kvh/tp) * d output
        # features per rank, where h is num_heads, d is head_size, kvh is
        # num_kv_heads and tp is tensor_parallel_size. In the ColumnLinear op
        # the output dim is calculated as (h + 2*kvh) * d / tp, which matches
        # the desired per-rank size after splitting.
        # Note that the out dim is not necessarily hidden_size + a KV-specific
        # size (in MQA/GQA), but num_heads * head_size;
        # example: d_model != num_heads * head_size in Flan-T5.
        self.qkv = QKVColumnLinear(
            hidden_size,
            tp_size * self.num_attention_heads * self.attention_head_size +
            (2 * tp_size * self.num_attention_kv_heads *
             self.attention_head_size),
            bias=bias,
            dtype=dtype,
            tp_group=tp_group,
            tp_size=tp_size,
            gather_output=False)
        self.dense = RowLinear(tp_size * self.num_attention_heads *
                               self.attention_head_size,
                               hidden_size,
                               bias=dense_bias,
                               dtype=dtype,
                               tp_group=tp_group,
                               tp_size=tp_size)

        # per-layer relative attention table
        if relative_attention:
            self.rel_attn_table = Parameter(shape=(num_attention_heads //
                                                   tp_size, num_buckets),
                                            dtype=dtype)

        if clip_qkv is not None:
            self.clip_qkv = fp32_array([clip_qkv])
        else:
            self.clip_qkv = None

        self.skip_cross_qkv = skip_cross_qkv
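
    # A hedged worked example of the per-rank QKV width above (illustrative
    # comment only; the numbers are hypothetical): with num_heads=32,
    # num_kv_heads=8 (GQA), head_size=128 and tp_size=4, the constructor
    # passes 32*128 + 2*8*128 = 4096 + 2048 = 6144 global output features,
    # and each rank keeps 6144/4 = 1536 columns, i.e.
    # (32/4 + 2*8/4) * 128 = (8 + 4) * 128.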

    def forward(self,
                hidden_states: Tensor,
                attention_mask=None,
                medusa_packed_mask=None,
                medusa_position_offsets=None,
                use_cache=False,
                kv_cache_params=None,
                attention_params=None,
                encoder_output: Optional[Tensor] = None,
                position_embedding=None,
                norm_before_bmm1=False,
                lora_layer_params=None,
                cross_kv_cache_gen: Optional[Tensor] = None,
                cross_qkv_reuse: Optional[Tensor] = None):

        assert isinstance(hidden_states, Tensor)

        alibi_slopes = None
        if self.position_embedding_type.is_alibi():
            dtype = trt.float32
            if default_net().plugin_config.gpt_attention_plugin:
                dtype = hidden_states.dtype
            alibi_scale = 1. / self.norm_factor if self.scale_alibi_bias else 1.
            alibi_slopes = generate_alibi_slopes(
                self.num_attention_heads * self.tp_size,
                dtype=dtype,
                tp_size=self.tp_size,
                tp_rank=self.tp_rank,
                alibi_scale=alibi_scale,
                alibi_bias_max=self.alibi_bias_max)
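
        # For reference (a sketch of the standard ALiBi scheme, assuming
        # num_heads is a power of two): head i, 1-indexed, gets the slope
        #     slope_i = 2 ** (-alibi_bias_max * i / num_heads)
        # and the bias added to the score of query position q against key
        # position k is slope_i * (k - q). With scale_alibi_bias set,
        # alibi_scale folds the 1/sqrt(head_size) normalization into the
        # slopes as well.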

        qkv_lora_params = None
        if lora_layer_params is not None:
            if not self.cross_attention:
                qkv_lora_params = lora_layer_params.get_runtime_params(
                    0, "attn_qkv")
            else:
                qkv_lora_params = lora_layer_params.get_runtime_params(
                    0, "cross_attn_qkv")

        unfuse_qkv_gemm = self.unfuse_qkv_gemm
        if unfuse_qkv_gemm:
            qkv_gemm = [self.q, self.k, self.v]
            qkv = [gemm(hidden_states) for gemm in qkv_gemm]
            if default_net(
            ).plugin_config.lora_plugin and qkv_lora_params is not None:
                lora = self.qkv.lora(hidden_states, qkv_lora_params)
                kv_size = self.attention_head_size * self.num_attention_kv_heads
                qkv_lora = split(
                    lora, [self.attention_hidden_size, kv_size, kv_size],
                    dim=1)
                qkv = [tensor + lora for tensor, lora in zip(qkv, qkv_lora)]
        else:
            qkv = self.qkv(hidden_states, qkv_lora_params)

        if self.clip_qkv is not None:
            qkv = clip(qkv, -self.clip_qkv, self.clip_qkv)

        if default_net().plugin_config.remove_input_padding:
            if unfuse_qkv_gemm:
                for tensor in qkv:
                    assert tensor.ndim() == 2
            else:
                assert qkv.ndim() == 2

        if default_net(
        ).plugin_config.lora_plugin and qkv_lora_params is None and lora_layer_params is not None:
            if not self.cross_attention:
                q_lora_params = lora_layer_params.get_runtime_params(
                    0, "attn_q")
                k_lora_params = lora_layer_params.get_runtime_params(
                    0, "attn_k")
                v_lora_params = lora_layer_params.get_runtime_params(
                    0, "attn_v")
            else:
                q_lora_params = lora_layer_params.get_runtime_params(
                    0, "cross_attn_q")
                k_lora_params = lora_layer_params.get_runtime_params(
                    0, "cross_attn_k")
                v_lora_params = lora_layer_params.get_runtime_params(
                    0, "cross_attn_v")

            assert (q_lora_params is not None and k_lora_params is not None and v_lora_params is not None) or \
                (q_lora_params is None and k_lora_params is None and v_lora_params is None), \
                "q_lora_params, k_lora_params and v_lora_params should be all enabled or all disabled at the same time."

            if q_lora_params is not None and k_lora_params is not None and v_lora_params is not None:
                qkv_lora_runtime_params = LoraRuntimeParams(
                    lora_ranks=[
                        q_lora_params.lora_ranks[0],
                        k_lora_params.lora_ranks[0],
                        v_lora_params.lora_ranks[0]
                    ],
                    lora_weights_pointers=[
                        q_lora_params.lora_weights_pointers[0],
                        k_lora_params.lora_weights_pointers[0],
                        v_lora_params.lora_weights_pointers[0]
                    ],
                    host_request_types=q_lora_params.host_request_types,
                    host_context_lengths=q_lora_params.host_context_lengths,
                    max_context_length=q_lora_params.max_context_length,
                    max_encoder_context_length=q_lora_params.
                    max_encoder_context_length,
                    host_encoder_input_lengths=q_lora_params.
                    host_encoder_input_lengths,
                )

                q_lora, k_lora, v_lora = self.qkv_lora(
                    hidden_states, qkv_lora_runtime_params)
                qkv_lora = concat([q_lora, k_lora, v_lora],
                                  dim=q_lora.rank() - 1)
                qkv = qkv + qkv_lora

        if self.position_embedding_type == PositionEmbeddingType.chatglm:
            qkv = RopeEmbeddingUtils.apply_rotary_pos_emb_chatglm(
                qkv,
                position_embedding,
                self.num_attention_heads,
                self.attention_head_size,
                self.max_position_embeddings,
                self.rotary_embedding_scale,
                default_net().plugin_config.remove_input_padding,
            )
            self.rotary_embedding_scale_type = RotaryScalingType.none
            self.rotary_embedding_scale = 1.0

        paged_kv_cache = default_net().plugin_config.paged_kv_cache

        assert attention_params is None or attention_params.is_valid(
            default_net().plugin_config.gpt_attention_plugin,
            default_net().plugin_config.remove_input_padding)
        assert kv_cache_params is None or kv_cache_params.is_valid(
            default_net().plugin_config.gpt_attention_plugin)

        past_key_value = None if kv_cache_params is None else kv_cache_params.get_first_past_key_value(
        )

        # In cross attention, cross QKV only needs to be computed once in the
        # 1st decoding step --> written to the cross KV cache --> it remains
        # constant for all subsequent decoding steps. The 1st and later steps
        # are distinguished by a boolean tensor `cross_kv_cache_gen` passed at
        # runtime. Also, the cross KV cache max length is set from the encoder
        # output seqlen; this maps to the max context length concept in
        # decoder-only models.
        cross_qkv = None
        if self.cross_attention and encoder_output:
            assert isinstance(encoder_output, Tensor)

            ## True branch: context phase, compute cross qkv
            cross_qkv_true = self.qkv(encoder_output, qkv_lora_params)

            if default_net(
            ).plugin_config.lora_plugin and qkv_lora_params is None and lora_layer_params is not None:
                cross_q_lora, cross_k_lora, cross_v_lora = self.qkv_lora(
                    encoder_output,
                    qkv_lora_runtime_params,
                    is_cross_attention=True)
                cross_qkv_lora = concat(
                    [cross_q_lora, cross_k_lora, cross_v_lora],
                    dim=cross_q_lora.rank() - 1)
                cross_qkv_true = cross_qkv_true + cross_qkv_lora
            ## End True branch

            ## False branch: generation phase, no compute, but the shape
            ## constraints must be obeyed because TRT's IfConditional requires
            ## the output shapes of the two subgraphs to be identical. Our 1st
            ## attempt was to stack encoder_output [B, S, H] or [N, H] -->
            ## cross qkv [B, S, 3*H] or [N, 3*H], but that still introduces an
            ## unnecessary concat. A better solution is to create a dummy
            ## torch tensor `cross_qkv_reuse` with the correct shape and reuse
            ## it in every generation step.
            cross_qkv_false = cross_qkv_reuse
            ## End False branch

            # IfConditional layer
            if self.skip_cross_qkv:
                cross_qkv = conditional(cross_kv_cache_gen, cross_qkv_true,
                                        cross_qkv_false)
            else:
                cross_qkv = cross_qkv_true

        if default_net().plugin_config.gpt_attention_plugin:
            if self.cross_attention and (past_key_value is not None):
                past_key_value = kv_cache_params.past_key_value[1]
            assert self.attention_mask_type in [
                AttentionMaskType.causal, AttentionMaskType.bidirectional,
                AttentionMaskType.bidirectionalglm
            ], 'The attention plugin only supports masked MHA.'

            # KV cache scales.
            if self.kv_cache_scaling_factor is not None:
                kv_orig_quant_scale = constant(
                    fp32_array([1.0])) / self.kv_cache_scaling_factor.value
                kv_quant_orig_scale = self.kv_cache_scaling_factor.value
            else:
                kv_orig_quant_scale = None
                kv_quant_orig_scale = None

            # Attention output scales
            assert (
                not default_net().plugin_config.use_fp8_context_fmha
            ) or self.quant_mode.has_fp8_qdq(
            ), "FP8 Context FMHA must be used together with the fp8 quantization workflow."
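
            # For reference: with a scaling factor s, the plugin quantizes
            # cache entries as q = clip(x * (1/s)) using
            # kv_orig_quant_scale = 1/s, and dequantizes them as x ~= q * s
            # using kv_quant_orig_scale = s, so the two scales are
            # reciprocals of each other.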

            if self.quant_mode.has_fp8_qdq() and default_net(
            ).plugin_config.use_fp8_context_fmha:
                # The attention plugin only quantizes the output when fp8
                # context fmha is enabled.
                attention_output_orig_quant_scale = constant(
                    fp32_array([1.0] /
                               self.dense.activation_scaling_factor.raw_value))
            else:
                attention_output_orig_quant_scale = None

            # Rotary cos/sin cache.
            rotary_cos_sin = constant(
                self.embed_positions_for_gpt_attention
            ) if self.position_embedding_type.is_rope() else None

            context, past_key_value = gpt_attention(
                qkv=qkv,
                past_key_value=past_key_value,
                sequence_length=attention_params.sequence_length,
                host_past_key_value_lengths=kv_cache_params.
                host_past_key_value_lengths,
                host_max_attention_window_sizes=kv_cache_params.
                host_max_attention_window_sizes,
                host_sink_token_length=kv_cache_params.host_sink_token_length,
                context_lengths=attention_params.context_lengths,
                cache_indirection=kv_cache_params.cache_indirection,
                host_request_types=attention_params.host_request_types,
                layer_idx=self.layer_idx,
                num_heads=self.num_attention_heads,
                num_kv_heads=self.num_attention_kv_heads,
                hidden_size_per_head=self.attention_head_size,
                q_scaling=self.q_scaling,
                rotary_embedding_dim=self.rotary_embedding_dim,
                rotary_embedding_base=self.rotary_embedding_base,
                rotary_embedding_scale_type=self.rotary_embedding_scale_type,
                rotary_embedding_scale=self.rotary_embedding_scale,
                rotary_embedding_max_positions=self.max_position_embeddings,
                position_embedding_type=self.position_embedding_type,
                rotary_cos_sin=rotary_cos_sin,
                kv_orig_quant_scale=kv_orig_quant_scale,
                kv_quant_orig_scale=kv_quant_orig_scale,
                attention_output_orig_quant_scale=
                attention_output_orig_quant_scale,
                kv_cache_quant_mode=self.quant_mode,
                max_context_length=attention_params.max_context_length,
                mask_type=self.attention_mask_type,
                alibi_slopes=alibi_slopes,
                tp_size=self.tp_size,
                tp_rank=self.tp_rank,
                kv_cache_block_offsets=kv_cache_params.kv_cache_block_offsets,
                host_kv_cache_block_offsets=kv_cache_params.
                host_kv_cache_block_offsets,
                host_kv_cache_pool_pointers=kv_cache_params.
                host_kv_cache_pool_pointers,
                do_cross_attention=self.cross_attention,
                cross_qkv=cross_qkv,
                cross_qkv_length=attention_params.encoder_max_input_length,
                encoder_input_lengths=attention_params.encoder_input_lengths,
                relative_attention_bias=self.rel_attn_table.value
                if self.relative_attention else None,
                max_distance=self.max_distance,
                host_context_lengths=attention_params.host_context_lengths,
                use_cache=use_cache,
                medusa_position_offsets=medusa_position_offsets,
                medusa_packed_mask=medusa_packed_mask,
            )
        else:
            # plain TensorRT mode
            assert paged_kv_cache == False

            def transpose_for_scores(x,
                                     rotary: bool = False,
                                     is_kv: bool = False):
                _num_attention_heads = self.num_attention_kv_heads if is_kv else self.num_attention_heads
                new_x_shape = concat([
                    shape(x, 0),
                    shape(x, 1), _num_attention_heads,
                    self.attention_head_size
                ])
                if rotary:
                    return x.view(new_x_shape)
                else:
                    return x.view(new_x_shape).permute([0, 2, 1, 3])
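
            # Numpy analogue of transpose_for_scores, for reference:
            #     x.reshape(b, s, heads, head_size).transpose(0, 2, 1, 3)
            # The rotary=True path keeps [b, s, heads, head_size] so RoPE can
            # be applied over the last dimension before the final permute.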

            # qkv after projection is of shape
            #   [bs, seqlen, (num_attention_heads + 2 * num_attention_kv_heads), attention_head_size].
            # The projected and split qkv after transpose_for_scores():
            #   Q [bs, num_attention_heads,    seqlen, attention_head_size]
            #   K [bs, num_attention_kv_heads, seqlen, attention_head_size]
            #   V [bs, num_attention_kv_heads, seqlen, attention_head_size]
            kv_size = self.attention_head_size * self.num_attention_kv_heads
            if unfuse_qkv_gemm:
                query, key, value = qkv[0], qkv[1], qkv[2]
            else:
                query, key, value = split(
                    qkv, [self.attention_hidden_size, kv_size, kv_size],
                    dim=2)

            # In cross attention mode, replace kv by encoder_output
            if self.cross_attention and encoder_output is not None:
                encoder_qkv = self.qkv(encoder_output)
                _, key, value = split(
                    encoder_qkv,
                    [self.attention_hidden_size, kv_size, kv_size],
                    dim=2)

            query = transpose_for_scores(query, rotary=self.rotary_enabled)
            key = transpose_for_scores(key,
                                       is_kv=True,
                                       rotary=self.rotary_enabled)
            value = transpose_for_scores(value, is_kv=True)

            if self.rotary_enabled:
                if is_same_dtype(self.dtype, trt.bfloat16):
                    embed_positions = numpy_fp32_to_bf16(
                        self.embed_positions.astype(np.float32))
                    embed_positions = constant(embed_positions)
                else:
                    embed_positions = constant(
                        self.embed_positions.astype(
                            trt_dtype_to_np(query.dtype)))

                if self.rotary_embedding_dim is not None:
                    # When shape(hidden_states, 1) > 1 (context phase), the
                    # embedding starts from position 0; otherwise (generation
                    # phase) it starts from the current past-KV length.
                    start = where(
                        shape(hidden_states, 1) > 1, 0,
                        shape(past_key_value, 3))
                    size = where(
                        shape(hidden_states, 1) > 1, shape(hidden_states, 1),
                        1)
                    sincos = slice(
                        embed_positions, concat([0, start, 0]),
                        concat([1, size, self.rotary_embedding_dim]))
                    sin, cos = split(sincos,
                                     self.rotary_embedding_dim // 2,
                                     dim=-1)

                    key_rot_size = concat([
                        shape(key, 0),
                        shape(key, 1),
                        shape(key, 2), self.rotary_embedding_dim
                    ])
                    query_rot_size = concat([
                        shape(query, 0),
                        shape(query, 1),
                        shape(query, 2), self.rotary_embedding_dim
                    ])
                    remaining = shape(key, 3) - self.rotary_embedding_dim
                    key_pass_size = concat([
                        shape(key, 0),
                        shape(key, 1),
                        shape(key, 2), remaining
                    ])
                    query_pass_size = concat([
                        shape(query, 0),
                        shape(query, 1),
                        shape(query, 2), remaining
                    ])
                    k_rot = slice(key, [0, 0, 0, 0], key_rot_size)
                    k_pass = slice(key, [0, 0, 0, self.rotary_embedding_dim],
                                   key_pass_size)

                    q_rot = slice(query, [0, 0, 0, 0], query_rot_size)
                    q_pass = slice(query,
                                   [0, 0, 0, self.rotary_embedding_dim],
                                   query_pass_size)

                    k_rot = RopeEmbeddingUtils.apply_rotary_pos_emb(
                        k_rot, [cos, sin], self.position_embedding_type)
                    q_rot = RopeEmbeddingUtils.apply_rotary_pos_emb(
                        q_rot, [cos, sin], self.position_embedding_type)

                    key = concat([k_rot, k_pass], dim=3)
                    query = concat([q_rot, q_pass], dim=3)
                else:
                    key = RopeEmbeddingUtils.apply_rotary_pos_emb(
                        key, [cos, sin], self.position_embedding_type)
                    query = RopeEmbeddingUtils.apply_rotary_pos_emb(
                        query, [cos, sin], self.position_embedding_type)

                key = key.permute([0, 2, 1, 3])
                query = query.permute([0, 2, 1, 3])
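
            # RoPE reference (a sketch; the exact pairing of feature
            # dimensions depends on position_embedding_type): given
            # per-position angles theta, rotary embedding rotates feature
            # pairs (x1, x2) as
            #     x1' = x1 * cos(theta) - x2 * sin(theta)
            #     x2' = x2 * cos(theta) + x1 * sin(theta)
            # so that the Q.K dot product after rotation depends only on the
            # relative distance between the two positions.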

            if past_key_value is not None and not self.cross_attention:
                if self.kv_cache_scaling_factor is not None:
                    past_key_value = dequantize(
                        past_key_value, self.kv_cache_scaling_factor.value)

                # past_key_value [bs, 2, num_heads, max_seq_len, head_dim]
                past_key, past_value = split(past_key_value, 1, dim=1)

                key_shape = concat([
                    shape(past_key, 0),
                    shape(past_key, 2),
                    shape(past_key, 3),
                    shape(past_key, 4)
                ])
                past_key = past_key.view(key_shape, zero_is_placeholder=False)
                past_value = past_value.view(key_shape,
                                             zero_is_placeholder=False)

                key = concat([past_key, key], dim=2)
                value = concat([past_value, value], dim=2)

            if use_cache:
                key_inflated_shape = concat([
                    shape(key, 0), 1,
                    shape(key, 1),
                    shape(key, 2),
                    shape(key, 3)
                ])
                inflated_key = key.view(key_inflated_shape,
                                        zero_is_placeholder=False)
                inflated_value = value.view(key_inflated_shape,
                                            zero_is_placeholder=False)
                past_key_value = concat([inflated_key, inflated_value], dim=1)

                # TRT quantizes the tensor value by doing `cast(clip(fp_value / scale))` while
                # the plugin quantizes it by doing `cast(clip(fp_value * scale))`.
                if self.kv_cache_scaling_factor is not None:
                    past_key_value = quantize(
                        past_key_value,
                        self.kv_cache_scaling_factor.value,
                        dtype='fp8'
                        if self.quant_mode.has_fp8_kv_cache() else 'int8')

            # MQA broadcast
            if self.num_attention_heads // self.num_attention_kv_heads > 1:
                key = repeat_interleave(
                    key,
                    self.num_attention_heads // self.num_attention_kv_heads,
                    1)
                value = repeat_interleave(
                    value,
                    self.num_attention_heads // self.num_attention_kv_heads,
                    1)

            key_length = shape(key, 2)

            # The following code creates a 2D tensor with 0s in the lower triangular (including the
            # diagonal) and -inf in the upper triangular part. This bias tensor is added to the
            # output of the Q*K^T matrix multiplication (BMM1). The -inf elements are transformed
            # to 0s by the Softmax operator that follows, while the elements that correspond to 0s
            # in the bias are unaffected by it.
            #
            # Note that when this bias is added to another bias tensor B (for example, with ALiBi),
            # the values in the lower-triangular part of B are not affected and the
            # upper-triangular ones are set to -inf.
            if self.attention_mask_type == AttentionMaskType.causal and not self.cross_attention:
                if self.position_embedding_type.is_alibi():
                    query_length = shape(query, 2)
                    # bsz, target_length, past_key_value_length
                    buffer = make_causal_mask(shape(query, 0), query_length,
                                              key_length - query_length,
                                              dtype)
                    starts = concat([0, 0, 0, 0])
                    sizes = concat([1, 1, query_length, key_length])
                    generated_mask = slice(buffer, starts, sizes)
                else:
                    query_length = shape(query, 2)
                    starts = concat([0, 0, key_length - query_length, 0])
                    sizes = concat([1, 1, query_length, key_length])
                    select_buf = np.expand_dims(
                        np.tril(
                            np.ones((self.max_position_embeddings,
                                     self.max_position_embeddings))).astype(
                                         bool), (0, 1))

                    select_buf = np.logical_not(select_buf)
                    mask_buf = np.zeros_like(select_buf, np.float32)
                    mask_buf[select_buf] = float('-inf')
                    buffer = constant(mask_buf)
                    generated_mask = slice(buffer, starts, sizes)

            elif self.attention_mask_type == AttentionMaskType.bidirectional and not self.cross_attention:
                query_length = shape(query, 2)
                zero_buf = np.expand_dims(
                    np.zeros((self.max_position_embeddings,
                              self.max_position_embeddings),
                             dtype=np.float32), (0, 1))
                zero_buf[:, :, :-1, -1] = 1
                zero_buf *= -10000
                mask = constant(zero_buf)

                # context phase, query_length
                mask_size = where(query_length > 1, query_length, 1)
                mask_start = where(query_length > 1,
                                   self.max_position_embeddings - mask_size,
                                   1)
                start = concat([0, 0, mask_start, mask_start])
                size = concat([1, 1, mask_size, mask_size])
                generated_mask = slice(mask, start, size)
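
            # For reference, a small numeric example of the additive mask:
            # softmax([2.0, 1.0, -inf]) equals softmax([2.0, 1.0]) padded
            # with a 0 weight in the last slot, i.e. roughly
            # [0.73, 0.27, 0.0], so masked positions contribute nothing to
            # the context vector.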

            if attention_mask is not None:
                if self.cross_attention:
                    batch_size = shape(attention_mask, 0)
                    query_len = shape(attention_mask, 1)
                    encoder_input_len = shape(attention_mask, 2)
                    attention_mask = attention_mask.view(
                        concat([batch_size, 1, query_len, encoder_input_len]))
                    attention_mask = where(attention_mask == 0,
                                           float('-inf'), 0.0)
                else:
                    attention_mask = expand_mask(attention_mask,
                                                 shape(query, 2))
            bias = attention_mask
            if self.position_embedding_type.is_alibi():
                alibi_biases = generate_alibi_biases(alibi_slopes, key_length)
                bias = alibi_biases if bias is None else bias + alibi_biases

            if self.relative_attention:
                query_length = shape(query, 2)
                relative_bias = compute_relative_bias(
                    query_length + key_length - 1,
                    key_length,
                    self.num_buckets,
                    self.max_distance,
                    False,  # bidirectional
                    self.rel_attn_table.value.transpose(1, 0),
                    tp_size=self.tp_size,
                    tp_group=self.tp_group,
                    tp_rank=self.tp_rank)
                start = concat([0, 0, query_length + key_length - 2, 0])
                size = concat([
                    shape(relative_bias, 0),
                    shape(relative_bias, 1), 1,
                    shape(relative_bias, 3)
                ])
                relative_bias = slice(relative_bias, start, size)

            key = key.permute([0, 1, 3, 2])
            with precision('float32'):
                if norm_before_bmm1:
                    # Apply norm on query earlier to prevent matmul fp16 overflow.
                    query /= (self.q_scaling * self.norm_factor)
                if preview_trt_version(
                ) or self.position_embedding_type.is_alibi():
                    attention_scores = matmul(query, key)
                else:
                    # For TRT 9.x, OOTB attention needs this WAR to fuse MHA.
                    attention_scores = matmul(cast(query, 'float32'),
                                              cast(key, 'float32'))
                if not norm_before_bmm1:
                    attention_scores = attention_scores / (self.q_scaling *
                                                           self.norm_factor)

                if self.attention_mask_type in [
                        AttentionMaskType.causal,
                        AttentionMaskType.bidirectional
                ] and not self.cross_attention:
                    bias = generated_mask if bias is None else bias + generated_mask

                if bias is not None:
                    bias = cast(bias, attention_scores.dtype)
                    attention_scores = attention_scores + bias

                if self.relative_attention:
                    attention_scores = attention_scores + relative_bias

            attention_probs = softmax(attention_scores, dim=-1)

            if preview_trt_version() or self.position_embedding_type.is_alibi(
            ):
                # For TRT 9.x with ALiBi, TRT has GPU buffer management issues,
                # so this WAR is needed to avoid a peak GPU memory regression.
                # For TRT 10.0, it is a dummy reshape WAR for MHA fusion.
                attention_probs = attention_probs.view(
                    concat([
                        shape(attention_probs, 0),
                        shape(attention_probs, 1),
                        shape(attention_probs, 2),
                        shape(value, 2)
                    ]))
                context = matmul(attention_probs, value,
                                 use_fp32_acc=False).permute([0, 2, 1, 3])
            else:
                # For TRT 9.x, this WAR is needed to fuse MHA.
                context = matmul(attention_probs,
                                 cast(value,
                                      'float32')).permute([0, 2, 1, 3])
                if context.dtype != value.dtype:
                    context = cast(context, value.dtype)
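
            # In summary, the plain-TRT path computes, per head,
            #     context = softmax(Q @ K^T / (q_scaling * sqrt(head_size)) + bias) @ V
            # with bias collecting the causal/bidirectional mask, ALiBi and
            # relative-attention terms. See _plain_attention_reference below
            # for a numpy sketch of the same computation.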

            context = context.view(
                concat([
                    shape(context, 0),
                    shape(context, 1), self.attention_hidden_size
                ]))

        dense_lora_params = None
        if lora_layer_params is not None:
            dense_lora_params = lora_layer_params.get_runtime_params(
                0, "attn_dense")
        context = self.dense(context, lora_runtime_params=dense_lora_params)

        if use_cache:
            return (context, past_key_value)
        else:
            return context
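
# Numpy sketch of the plain-TRT attention math above (illustrative only;
# single batch element, no KV cache, no tensor parallelism). `bias` is an
# optional additive mask broadcastable to [heads, seq, seq].
def _plain_attention_reference(q, k, v, bias=None, q_scaling=1.0):
    # q, k, v: [heads, seq, head_size]
    head_size = q.shape[-1]
    scores = q @ k.transpose(0, 2, 1) / (q_scaling * math.sqrt(head_size))
    if bias is not None:
        scores = scores + bias
    # Numerically stable softmax over the key dimension; -inf bias entries
    # become exactly 0 weights.
    scores = scores - scores.max(axis=-1, keepdims=True)
    probs = np.exp(scores)
    probs /= probs.sum(axis=-1, keepdims=True)
    return probs @ v  # [heads, seq, head_size]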

class BertAttention(Module):

    def __init__(self,
                 hidden_size,
                 num_attention_heads,
                 max_position_embeddings=1024,
                 num_layers=1,
                 attention_head_size=None,
                 num_kv_heads=None,
                 q_scaling=1.0,
                 apply_query_key_layer_scaling=False,
                 bias=True,
                 dtype=None,
                 tp_group=None,
                 tp_size=1,
                 tp_rank=0,
                 relative_attention=False,
                 max_distance=0,
                 num_buckets=0):
        super().__init__()

        self.attention_head_size = hidden_size // num_attention_heads if attention_head_size is None else attention_head_size
        self.num_attention_heads = num_attention_heads // tp_size
        self.num_attention_kv_heads = (
            num_kv_heads + tp_size - 1
        ) // tp_size if num_kv_heads is not None else self.num_attention_heads
        self.hidden_size = hidden_size
        self.attention_hidden_size = self.attention_head_size * self.num_attention_heads
        self.max_position_embeddings = max_position_embeddings
        self.tp_group = tp_group
        self.tp_size = tp_size
        self.tp_rank = tp_rank

        self.num_layers = num_layers
        self.apply_query_key_layer_scaling = apply_query_key_layer_scaling
        self.norm_factor = math.sqrt(self.attention_head_size)
        self.q_scaling = q_scaling
        if self.apply_query_key_layer_scaling:
            self.norm_factor *= self.num_layers
            self.q_scaling *= self.num_layers

        self.dtype = dtype

        self.relative_attention = relative_attention
        self.max_distance = max_distance
        self.num_buckets = num_buckets

        # The out dim is not necessarily hidden_size + a KV-specific size
        # (in MQA/GQA), but num_heads * head_size;
        # example: d_model != num_heads * head_size in Flan-T5.
        self.qkv = ColumnLinear(hidden_size,
                                tp_size * self.attention_hidden_size +
                                (2 * tp_size * self.num_attention_kv_heads *
                                 self.attention_head_size),
                                bias=bias,
                                dtype=dtype,
                                tp_group=tp_group,
                                tp_size=tp_size,
                                gather_output=False)
        self.dense = RowLinear(tp_size * self.num_attention_heads *
                               self.attention_head_size,
                               hidden_size,
                               bias=bias,
                               dtype=dtype,
                               tp_group=tp_group,
                               tp_size=tp_size)

        # per-layer relative attention table
        if relative_attention:
            self.rel_attn_table = Parameter(shape=(num_attention_heads //
                                                   tp_size, num_buckets),
                                            dtype=dtype)

    def forward(self,
                hidden_states: Tensor,
                attention_mask=None,
                input_lengths=None,
                max_input_length=None,
                lora_layer_params=None):
        assert isinstance(hidden_states, Tensor)

        qkv_lora_params = None
        if lora_layer_params is not None:
            qkv_lora_params = lora_layer_params.get_runtime_params(
                0, "attn_qkv")

        qkv = self.qkv(hidden_states, qkv_lora_params)

        if default_net().plugin_config.remove_input_padding:
            assert qkv.ndim() == 2

        if default_net(
        ).plugin_config.lora_plugin and qkv_lora_params is None and lora_layer_params is not None:
            q_lora_params = lora_layer_params.get_runtime_params(0, "attn_q")
            k_lora_params = lora_layer_params.get_runtime_params(0, "attn_k")
            v_lora_params = lora_layer_params.get_runtime_params(0, "attn_v")

            assert (q_lora_params is not None and k_lora_params is not None and v_lora_params is not None) or \
                (q_lora_params is None and k_lora_params is None and v_lora_params is None), \
                "q_lora_params, k_lora_params and v_lora_params should be all enabled or all disabled at the same time."

            if q_lora_params is not None and k_lora_params is not None and v_lora_params is not None:
                qkv_lora_params = LoraRuntimeParams(
                    lora_ranks=[
                        q_lora_params.lora_ranks[0],
                        k_lora_params.lora_ranks[0],
                        v_lora_params.lora_ranks[0]
                    ],
                    lora_weights_pointers=[
                        q_lora_params.lora_weights_pointers[0],
                        k_lora_params.lora_weights_pointers[0],
                        v_lora_params.lora_weights_pointers[0]
                    ],
                    host_request_types=q_lora_params.host_request_types,
                    host_context_lengths=q_lora_params.host_context_lengths,
                    max_context_length=q_lora_params.max_context_length)

                q_lora, k_lora, v_lora = self.qkv_lora(hidden_states,
                                                       qkv_lora_params)
                qkv_lora = concat([q_lora, k_lora, v_lora],
                                  dim=q_lora.rank() - 1)
                qkv = qkv + qkv_lora

        if default_net().plugin_config.bert_attention_plugin:
            # TRT plugin mode
            assert input_lengths is not None
            context = bert_attention(
                qkv,
                input_lengths,
                self.num_attention_heads,
                self.attention_head_size,
                q_scaling=self.q_scaling,
                relative_attention=self.relative_attention,
                max_distance=self.max_distance,
                relative_attention_bias=self.rel_attn_table.value
                if self.relative_attention else None,
                max_input_length=max_input_length)
        else:
            # plain TRT mode
            def transpose_for_scores(x):
                new_x_shape = concat([
                    shape(x, 0),
                    shape(x, 1), self.num_attention_heads,
                    self.attention_head_size
                ])
                return x.view(new_x_shape).permute([0, 2, 1, 3])

            kv_size = self.attention_head_size * self.num_attention_kv_heads
            query, key, value = split(
                qkv, [self.attention_hidden_size, kv_size, kv_size], dim=2)
            query = transpose_for_scores(query)
            key = transpose_for_scores(key)
            value = transpose_for_scores(value)

            key = key.permute([0, 1, 3, 2])
            attention_scores = matmul(query, key, use_fp32_acc=False)
            attention_scores = attention_scores / (self.q_scaling *
                                                   self.norm_factor)

            if self.relative_attention:
                query_len = shape(attention_scores, 2)
                key_len = shape(attention_scores, 3)
                bias = compute_relative_bias(
                    query_len,
                    key_len,
                    self.num_buckets,
                    self.max_distance,
                    True,  # bidirectional
                    self.rel_attn_table.value.transpose(1, 0),
                    tp_size=self.tp_size,
                    tp_group=self.tp_group,
                    tp_rank=self.tp_rank)
                attention_scores = attention_scores + bias

            if attention_mask is not None:
                attention_mask = expand_mask(attention_mask, shape(query, 2))
                attention_mask = cast(attention_mask, attention_scores.dtype)
                attention_scores = attention_scores + attention_mask

            attention_probs = softmax(attention_scores, dim=-1)

            context = matmul(attention_probs, value,
                             use_fp32_acc=False).permute([0, 2, 1, 3])
            context = context.view(
                concat([
                    shape(context, 0),
                    shape(context, 1), self.attention_hidden_size
                ]))

        dense_lora_params = None
        if lora_layer_params is not None:
            dense_lora_params = lora_layer_params.get_runtime_params(
                0, "attn_dense")

        context = self.dense(context, lora_runtime_params=dense_lora_params)
        return context
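
# Numpy sketch of turning a padding mask into the additive bias consumed by
# the plain-mode attention above (illustrative only; mirrors the usual
# expand_mask semantics, assuming mask entries are 1 for real tokens and 0
# for padding).
def _padding_mask_to_bias_reference(mask, tgt_len=None):
    # mask: [batch, src_len] -> bias: [batch, 1, tgt_len, src_len]
    batch, src_len = mask.shape
    tgt_len = src_len if tgt_len is None else tgt_len
    bias = np.where(mask[:, None, None, :].astype(bool), 0.0,
                    float('-inf')).astype(np.float32)
    return np.broadcast_to(bias, (batch, 1, tgt_len, src_len))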