# SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import math
from collections import OrderedDict
from typing import List, Optional

import tensorrt as trt

from ..functional import Tensor
from ..layers import MropeParams, SpecDecodingParams
from ..llmapi.kv_cache_type import KVCacheType
from ..mapping import Mapping
from ..plugin import current_all_reduce_helper


class GenerationMixin:

    @staticmethod
    def has_ctx_gen_opt_profiles(
            use_gpt_attention_plugin: bool = False,
            use_gemm_plugin: bool = False,
            use_mamba_conv1d_plugin: bool = False,
            remove_input_padding: bool = False,
            paged_state: bool = False,
            kv_cache_type: KVCacheType = KVCacheType.CONTINUOUS) -> bool:
        res = False
        if not use_gpt_attention_plugin or not use_gemm_plugin:
            use_in_flight_batching = False
            # Refer to modelConfig.h: supportsInflightBatching(); this should stay
            # consistent with its implementation. We skip the transformer vs. rnn
            # architecture check for simplification.
            if remove_input_padding and use_gpt_attention_plugin:
                use_in_flight_batching = kv_cache_type in [
                    KVCacheType.PAGED, KVCacheType.DISABLED
                ]
            elif remove_input_padding and use_mamba_conv1d_plugin:
                use_in_flight_batching = paged_state == True
            res = not use_in_flight_batching
        return res

    @staticmethod
    def default_range(max_range, offset=0, min_range=1, opt_offset=0):
        result = [
            min_range, (max_range + min_range + opt_offset) // 2, max_range
        ]
        return [elem + offset for elem in result]

    @staticmethod
    def split_num_tokens_range(max_num_tokens):
        split_point = [64, 128, 256, 512, 1024]
        num_tokens_ranges = []
        for i, p in enumerate(split_point):
            if i == 0 and max_num_tokens <= p:
                return [[1, max_num_tokens, max_num_tokens]]
            elif max_num_tokens <= p:
                num_tokens_ranges.append(
                    [split_point[i - 1], max_num_tokens, max_num_tokens])
                return num_tokens_ranges
            elif i == 0 and max_num_tokens > p:
                num_tokens_ranges = [[1, 64, 64]]
            else:
                num_tokens_ranges.append(
                    [split_point[i - 1], split_point[i], split_point[i]])
        num_tokens_ranges.append(
            [split_point[-1], max_num_tokens, max_num_tokens])
        return num_tokens_ranges

    @staticmethod
    def get_profiles_ranges(
            *,
            max_batch_size,
            max_beam_width,
            max_input_len,
            max_num_tokens,
            max_draft_len,
            opt_batch_size,
            opt_num_tokens,
            enable_ctx_gen_opt_profiles,
            multiple_profiles,
            kv_cache_type: KVCacheType = KVCacheType.CONTINUOUS):
        default_range = GenerationMixin.default_range
        if opt_batch_size:
            bb_range_cxt = [1, opt_batch_size, max_batch_size]
            bb_range_gen = [
                1, opt_batch_size * max_beam_width,
                max_batch_size * max_beam_width
            ]
        else:
            bb_range_cxt = default_range(max_batch_size)
            bb_range_gen = default_range(max_batch_size * max_beam_width)
        tokens_per_engine_step = max_draft_len + 1
        tokens_per_engine_step_range = [
            1, tokens_per_engine_step, tokens_per_engine_step
        ]
        bbd_range_ctx = [
            bb_range_cxt[i] * (tokens_per_engine_step if i != 0 else 1)
            for i in range(len(bb_range_cxt))
        ]
        bbd_range_gen = [
            bb_range_gen[i] * (tokens_per_engine_step if i != 0 else 1)
            for i in range(len(bb_range_gen))
        ]
        inlen_range_cxt = default_range(max_input_len)
        inlen_range_gen = [1, 1, tokens_per_engine_step]
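        # Illustrative (hypothetical) values: with max_batch_size=8, max_beam_width=1,
        # max_input_len=1024 and max_draft_len=0, default_range() yields
        # bb_range_cxt == bb_range_gen == [1, 4, 8] (min/opt/max of a TensorRT
        # optimization profile), inlen_range_cxt == [1, 512, 1024] and
        # inlen_range_gen == [1, 1, 1] (each generation step sees one new token).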
        if enable_ctx_gen_opt_profiles:
            num_profiles = 2
            bb_range = [bb_range_cxt, bb_range_gen]
            bbd_range = [bbd_range_ctx, bbd_range_gen]
            inlen_range = [inlen_range_cxt, inlen_range_gen]
            position_ids_inlen_range = [inlen_range_cxt, [1, 1, 1]]
            num_tokens_range_ctx = default_range(max_batch_size * max_input_len)
            # Draft tokens cannot be combined with beam search.
            num_tokens_range_gen = default_range(
                max_batch_size * max(tokens_per_engine_step, max_beam_width))
            num_tokens_range = [num_tokens_range_ctx, num_tokens_range_gen]

            # Only keep the context ranges when the KV cache is disabled.
            if kv_cache_type == KVCacheType.DISABLED:
                num_profiles = 1
                bb_range = [bb_range[0]]
                bbd_range = [bbd_range[0]]
                inlen_range = [inlen_range[0]]
                position_ids_inlen_range = [position_ids_inlen_range[0]]
                num_tokens_range_ctx = [num_tokens_range_ctx[0]]
                # Draft tokens cannot be combined with beam search.
                num_tokens_range_gen = [num_tokens_range_gen[0]]
                num_tokens_range = [num_tokens_range[0]]
        else:
            if multiple_profiles:
                num_tokens_range = GenerationMixin.split_num_tokens_range(
                    max_num_tokens)
            else:
                if opt_num_tokens is None:
                    opt_num_tokens = min(max_num_tokens,
                                         max_batch_size * max_beam_width)
                num_tokens_range = [[1, opt_num_tokens, max_num_tokens]]
            num_profiles = len(num_tokens_range)
            bb_range = [bb_range_gen] * num_profiles
            bbd_range = [bbd_range_gen] * num_profiles
            inlen_range = [[1, 1, max_input_len]] * num_profiles
            position_ids_inlen_range = [[1, 1, max_input_len]] * num_profiles
            tokens_per_engine_step_range = [tokens_per_engine_step_range
                                            ] * num_profiles
        position_ids_num_tokens_range = num_tokens_range

        # If max_draft_len != 0, input_ids may include draft tokens, so the length of
        # position_ids may differ from that of input_ids. In the extreme case,
        # input_ids contains (max_draft_len + 1) * N tokens while position_ids only
        # holds N values. Therefore, adjust the min value in the ranges of
        # position_ids.
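        # Worked example with hypothetical values: max_draft_len=3 gives
        # tokens_per_engine_step=4, so a num_tokens range of [64, 8192, 8192] maps to
        # a position_ids range of [ceil(64 / 4), 8192, 8192] == [16, 8192, 8192].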
        if max_draft_len != 0:
            position_ids_num_tokens_range = list(
                map(
                    lambda x:
                    [math.ceil(x[0] / (max_draft_len + 1)), x[1], x[2]],
                    num_tokens_range))

        ranges = {
            'bb_range': bb_range,
            'bbd_range': bbd_range,
            'inlen_range': inlen_range,
            'position_ids_inlen_range': position_ids_inlen_range,
            'num_tokens_range': num_tokens_range,
            'tokens_per_engine_step_range': tokens_per_engine_step_range,
            'position_ids_num_tokens_range': position_ids_num_tokens_range,
        }
        return num_profiles, ranges

    def prepare_attention_inputs(
            self,
            *,
            max_batch_size,
            max_beam_width,
            max_input_len,
            max_seq_len,
            num_kv_heads,
            head_size,
            num_layers,
            kv_dtype,
            kv_cache_type: KVCacheType,
            num_profiles=1,
            enable_ctx_gen_opt_profiles=False,
            remove_input_padding=False,
            use_gpt_attention_plugin=False,
            tokens_per_block=32,
            mapping=Mapping(),
            streamingllm=False,
            attn_layer_idx=None,
            opt_batch_size=None,
            num_kv_heads_per_layer: Optional[List[int]] = None):

        if attn_layer_idx is not None and num_kv_heads_per_layer is not None:
            assert len(attn_layer_idx) == len(num_kv_heads_per_layer), (
                f"Expected len(attn_layer_idx) ({len(attn_layer_idx)})"
                f" == len(num_kv_heads_per_layer) ({len(num_kv_heads_per_layer)})"
            )

        default_range = GenerationMixin.default_range

        if opt_batch_size:
            bb_range_cxt = [1, opt_batch_size, max_batch_size]
            bb_range_gen = [
                1, opt_batch_size * max_beam_width,
                max_batch_size * max_beam_width
            ]
        else:
            bb_range_cxt = default_range(max_batch_size)
            bb_range_gen = default_range(max_batch_size * max_beam_width)

        _bs_range = default_range(max_batch_size)
        _beam_width_range = default_range(max_beam_width)
        _max_len_range = default_range(max_seq_len)
        _mask_len_ctx = default_range(max_input_len)
        _kv_cache_range_ctx = [0, 0, 0]
        _kv_cache_range_gen = default_range(max_seq_len, -1)

        if kv_cache_type == KVCacheType.DISABLED:
            _kv_cache_range = default_range(max_seq_len)
        else:
            kv_max_seq_len = max_seq_len
            if streamingllm:
                # add the max bubble length
                kv_max_seq_len += tokens_per_block - 1
                if max_beam_width > 1:
                    # support cyclic kv cache cases that use one more block
                    kv_max_seq_len += tokens_per_block
            _kv_cache_range = default_range(kv_max_seq_len)

        if enable_ctx_gen_opt_profiles:
            if kv_cache_type != KVCacheType.DISABLED:
                assert num_profiles == 2
                bb_range = [bb_range_cxt, bb_range_gen]
                mask_len_range = [_mask_len_ctx, _max_len_range]
                if use_gpt_attention_plugin:
                    kv_cache_range = [_kv_cache_range, _kv_cache_range]
                else:
                    kv_cache_range = [_kv_cache_range_ctx, _kv_cache_range_gen]
            else:
                assert num_profiles == 1
                bb_range = [bb_range_cxt]
                mask_len_range = [_mask_len_ctx]
                if use_gpt_attention_plugin:
                    kv_cache_range = [_kv_cache_range]
                else:
                    kv_cache_range = [_kv_cache_range_ctx]
        else:
            bb_range = [bb_range_gen] * num_profiles
            mask_len_range = [_max_len_range] * num_profiles
            kv_cache_range = [_kv_cache_range] * num_profiles

        bs_range = [_bs_range] * num_profiles
        beam_width_range = [_beam_width_range] * num_profiles
        max_len_range = [_max_len_range] * num_profiles

        num_kv_heads = (num_kv_heads + mapping.tp_size - 1) // mapping.tp_size
        if num_kv_heads_per_layer is not None:
            num_kv_heads_per_layer = [
                (nheads + mapping.tp_size - 1) // mapping.tp_size
                for nheads in num_kv_heads_per_layer
            ]

        layers_range = mapping.pp_layers(num_layers)
        if attn_layer_idx is None:
            attn_layer_idx = [i for i in range(num_layers)]
        # layer indices of attention layers local to the current pp rank
        local_attn_layers = [i for i in layers_range if i in attn_layer_idx]
        # number of attention layers local to previous pp ranks
        num_attn_layers_lower_ranks = attn_layer_idx.index(
            local_attn_layers[0])
        num_attn_layers = len(local_attn_layers)
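        # Illustrative (hypothetical) configuration: with num_layers=8, pp_size=2 and
        # pp_rank=1, layers_range is [4, 5, 6, 7]; if attn_layer_idx=[1, 3, 5, 7],
        # then local_attn_layers=[5, 7], num_attn_layers_lower_ranks=2 and
        # num_attn_layers=2.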
        num_layers_prev_rank = layers_range[
            0] // mapping.pp_rank if mapping.pp_rank != 0 else len(layers_range)

        past_key_value = []
        kv_cache_block_offsets = None
        host_kv_cache_block_offsets = None
        host_kv_cache_pool_pointers = None
        host_kv_cache_pool_mapping = None
        if kv_cache_type == KVCacheType.DISABLED:
            past_key_value = [None] * num_layers_prev_rank
        else:
            if kv_cache_type != KVCacheType.PAGED:
                for layer_idx in layers_range:
                    if layer_idx not in local_attn_layers:
                        # not an attention layer ==> give it None pkv input
                        past_key_value.append(None)
                        continue

                    attn_idx = local_attn_layers.index(layer_idx)
                    if num_kv_heads_per_layer is not None:
                        heads_dim_name = f"num_heads_{layer_idx}"
                        kv_heads = num_kv_heads_per_layer[
                            num_attn_layers_lower_ranks + attn_idx]
                    else:
                        heads_dim_name = "num_heads"
                        kv_heads = num_kv_heads
                    kv_dim_range = OrderedDict([
                        ('batch_size_beam_width', bb_range),
                        ('kv', [2] * num_profiles),
                        (heads_dim_name, [kv_heads] * num_profiles),
                        ('past_key_len', kv_cache_range),
                        ('head_size', [head_size] * num_profiles),
                    ])
                    kv = Tensor(name=f'past_key_value_{layer_idx}',
                                dtype=kv_dtype,
                                shape=[-1, 2, kv_heads, -1, head_size],
                                dim_range=kv_dim_range)
                    past_key_value.append(kv)
            else:
                if enable_ctx_gen_opt_profiles:
                    max_blocks_per_seq_range = [
                        [
                            math.ceil(kv_cache_range[0][0] / tokens_per_block),
                            math.ceil(kv_cache_range[0][1] / tokens_per_block),
                            math.ceil(kv_cache_range[0][2] / tokens_per_block)
                        ],
                        [
                            math.ceil(kv_cache_range[1][0] / tokens_per_block),
                            math.ceil(kv_cache_range[1][1] / tokens_per_block),
                            math.ceil(kv_cache_range[1][2] / tokens_per_block)
                        ],
                    ]
                else:
                    max_blocks_per_seq_range = [[
                        math.ceil(kv_cache_range[0][0] / tokens_per_block),
                        math.ceil(kv_cache_range[0][1] / tokens_per_block),
                        math.ceil(kv_cache_range[0][2] / tokens_per_block)
                    ]] * num_profiles
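                # For example (hypothetical values): with max_seq_len=4096 and
                # tokens_per_block=32, each sequence may need up to
                # ceil(4096 / 32) == 128 block offsets per profile.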
                # The number of unique variable attention-window sizes, which is only
                # known at runtime, determines the number of KV cache pools.
                # kv_pools_range = (min=1, opt=1 (the usual case: no variable GQA and
                # no variable sliding window), max=the number of local attention layers).
                # TODO(nhaber): Benchmark whether making NUM_KV_CACHE_POOLS dynamic has
                # a significant performance hit.
                NUM_KV_CACHE_POOLS = -1
                kv_pools_range = [[1, 1, len(local_attn_layers)]
                                  ] * num_profiles

                kv_cache_block_offsets = Tensor(
                    name='kv_cache_block_offsets',
                    dtype=trt.int32,
                    shape=[NUM_KV_CACHE_POOLS, -1, 2, -1],
                    dim_range=OrderedDict([
                        ('num_kv_cache_pools', kv_pools_range),
                        ('batch_size_beam_width', bb_range),
                        ('kv', [2] * num_profiles),
                        ('max_blocks_per_seq', max_blocks_per_seq_range),
                    ]))
                host_kv_cache_block_offsets = Tensor(
                    name='host_kv_cache_block_offsets',
                    dtype=trt.int32,
                    shape=[NUM_KV_CACHE_POOLS, -1, 2, -1],
                    dim_range=OrderedDict([
                        ('num_kv_cache_pools', kv_pools_range),
                        ('batch_size_beam_width', bb_range),
                        ('kv', [2] * num_profiles),
                        ('max_blocks_per_seq', max_blocks_per_seq_range),
                    ]))
                host_kv_cache_pool_pointers = Tensor(
                    name='host_kv_cache_pool_pointers',
                    dtype=trt.int64,
                    shape=[NUM_KV_CACHE_POOLS, 2],
                    dim_range=OrderedDict([
                        ('num_pools_layers', kv_pools_range),
                        ('num_pools_kv', [2] * num_profiles),
                    ]))
                host_kv_cache_pool_mapping = Tensor(
                    name='host_kv_cache_pool_mapping',
                    dtype=trt.int32,
                    shape=[num_attn_layers, 2],  # 2: (index of pool, index of layer within pool)
                    dim_range=OrderedDict([
                        ('pools_mapping', [num_attn_layers] * num_profiles),
                        ('layer_cache_pool_locator', [2] * num_profiles)
                    ]))
                past_key_value = [None] * num_layers_prev_rank

        assert len(past_key_value) == num_layers_prev_rank

        sequence_length = None
        context_lengths = None
        host_context_lengths = None
        host_past_key_value_lengths = None
        host_max_attention_window_sizes = None
        host_sink_token_length = None
        attention_mask = None
        cache_indirection = None
        host_request_types = None
        runtime_perf_knobs = None
        context_progress = None

        if use_gpt_attention_plugin:
            if kv_cache_type != KVCacheType.DISABLED:
                sequence_length = Tensor(
                    name='sequence_length',
                    dtype=trt.int32,
                    shape=[-1],
                    dim_range=OrderedDict([('batch_size_beam_width', bb_range)
                                           ]),
                )
            host_request_types = Tensor(
                name='host_request_types',
                dtype=trt.int32,
                shape=[-1],
                dim_range=OrderedDict([('batch_size_beam_width', bb_range)]),
            )
            if kv_cache_type != KVCacheType.DISABLED:
                host_past_key_value_lengths = Tensor(
                    name='host_past_key_value_lengths',
                    dtype=trt.int32,
                    shape=[-1],
                    dim_range=OrderedDict([('batch_size_beam_width', bb_range)
                                           ]),
                )
            context_lengths = Tensor(
                name='context_lengths',
                dtype=trt.int32,
                shape=[-1],
                dim_range=OrderedDict([('batch_size_beam_width', bb_range)]),
            )
            runtime_perf_knobs = Tensor(name='host_runtime_perf_knobs',
                                        dtype=trt.int64,
                                        shape=[16],
                                        dim_range=OrderedDict([
                                            ('perf_knob_size',
                                             [16] * num_profiles)
                                        ]))
            context_progress = Tensor(name='host_context_progress',
                                      dtype=trt.int64,
                                      shape=[1],
                                      dim_range=OrderedDict([
                                          ('context_progress_size',
                                           [1] * num_profiles)
                                      ]))
        else:
            attention_mask = Tensor(
                name='attention_mask',
                dtype=trt.int32,
                shape=[-1, -1],
                dim_range=OrderedDict([
                    ('batch_size_beam_width', bb_range),
                    ('mask_len', mask_len_range),
                ]),
            )

        if use_gpt_attention_plugin and remove_input_padding:
            host_context_lengths = Tensor(
                name='host_context_lengths',
                dtype=trt.int32,
                shape=[-1],
                dim_range=OrderedDict([('batch_size_beam_width', bb_range)]),
            )

        if use_gpt_attention_plugin:
            # TODO: change shape to [1]
            host_max_attention_window_sizes = Tensor(
                name='host_max_attention_window_sizes',
                dtype=trt.int32,
                shape=[num_attn_layers],
                dim_range=OrderedDict([('num_layers',
                                        [num_attn_layers] * num_profiles)]))
            host_sink_token_length = Tensor(name='host_sink_token_length',
                                            dtype=trt.int32,
                                            shape=[1],
                                            dim_range=OrderedDict([
                                                ('scalar', [1] * num_profiles)
                                            ]))
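        # cache_indirection is consumed by beam search: for each (batch, beam, step)
        # it records which beam's KV cache entry should be read back.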
        if kv_cache_type != KVCacheType.DISABLED:
            cache_indirection = Tensor(
                name='cache_indirection',
                dtype=trt.int32,
                shape=[-1, -1, -1],
                dim_range=OrderedDict([
                    ('batch_size_cache', bs_range),
                    ('beam_width', beam_width_range),
                    ('max_seq_len', max_len_range),
                ]),
            )

        return {
            'attention_mask': attention_mask,
            'sequence_length': sequence_length,
            'host_past_key_value_lengths': host_past_key_value_lengths,
            'host_max_attention_window_sizes': host_max_attention_window_sizes,
            'host_sink_token_length': host_sink_token_length,
            'past_key_value': past_key_value,
            'cache_indirection': cache_indirection,
            'kv_cache_block_offsets': kv_cache_block_offsets,
            'host_kv_cache_block_offsets': host_kv_cache_block_offsets,
            'host_kv_cache_pool_pointers': host_kv_cache_pool_pointers,
            'host_kv_cache_pool_mapping': host_kv_cache_pool_mapping,
            'context_lengths': context_lengths,
            'host_context_lengths': host_context_lengths,
            'host_request_types': host_request_types,
            'host_runtime_perf_knobs': runtime_perf_knobs,
            'host_context_progress': context_progress,
        }

    def prepare_basic_inputs(
            self,
            *,
            max_batch_size,
            max_beam_width,
            max_input_len,
            max_seq_len,
            max_num_tokens,
            hidden_size,
            num_kv_heads,
            head_size,
            num_layers,
            kv_dtype,
            kv_cache_type: KVCacheType,
            remove_input_padding=False,
            use_gpt_attention_plugin=False,
            use_gemm_plugin=False,
            tokens_per_block=32,
            gather_context_logits=False,
            dtype=None,
            num_heads=None,
            mapping=Mapping(),
            opt_num_tokens=None,
            prompt_embedding_table_size: int = 0,
            position_encoding_2d=False,
            use_lora_plugin: bool = False,
            lora_target_modules: List[str] = None,
            speculative_decoding_draft_tokens_external: bool = False,
            spec_decoding_is_generation_length_variable: bool = False,
            max_draft_len=0,
            multiple_profiles: bool = False,
            streamingllm: bool = False,
            opt_batch_size=None,
            pp_reduce_scatter: bool = False,
            mrope_rotary_cos_sin_size: int = None,
    ):

        enable_ctx_gen_opt_profiles = GenerationMixin.has_ctx_gen_opt_profiles(
            use_gpt_attention_plugin=use_gpt_attention_plugin,
            use_gemm_plugin=use_gemm_plugin,
            remove_input_padding=remove_input_padding,
            kv_cache_type=kv_cache_type)

        num_profiles, ranges = GenerationMixin.get_profiles_ranges(
            max_batch_size=max_batch_size,
            max_beam_width=max_beam_width,
            max_input_len=max_input_len,
            max_num_tokens=max_num_tokens,
            max_draft_len=max_draft_len,
            opt_batch_size=opt_batch_size,
            opt_num_tokens=opt_num_tokens,
            enable_ctx_gen_opt_profiles=enable_ctx_gen_opt_profiles,
            multiple_profiles=multiple_profiles,
            kv_cache_type=kv_cache_type)

        bb_range = ranges['bb_range']
        bbd_range = ranges['bbd_range']
        inlen_range = ranges['inlen_range']
        num_tokens_range = ranges['num_tokens_range']
        position_ids_inlen_range = ranges['position_ids_inlen_range']
        tokens_per_engine_step_range = ranges['tokens_per_engine_step_range']
        position_ids_num_tokens_range = ranges['position_ids_num_tokens_range']

        input_ids = None
        position_ids = None
        hidden_states = None
        if remove_input_padding:
            if mapping.is_first_pp_rank():
                input_ids = Tensor(name='input_ids',
                                   dtype=trt.int32,
                                   shape=[-1],
                                   dim_range=OrderedDict([
                                       ('num_tokens', num_tokens_range),
                                   ]))
                if position_encoding_2d:
                    position_ids = Tensor(
                        name='position_ids',
                        dtype=trt.int32,
                        shape=[2, -1],
                        dim_range=OrderedDict([
                            ('2', [2] * num_profiles),
                            ('position_ids_num_tokens_range',
                             position_ids_num_tokens_range),
                        ]),
                    )
                else:
                    position_ids = Tensor(
                        name='position_ids',
                        dtype=trt.int32,
                        shape=[-1],
                        dim_range=OrderedDict([
                            ('position_ids_num_tokens_range',
                             position_ids_num_tokens_range),
                        ]),
                    )
            else:
                assert dtype is not None
                assert num_heads is not None
                pp_hidden_size = hidden_size // mapping.tp_size if pp_reduce_scatter else hidden_size
                hidden_states = Tensor(
                    name='hidden_states_input',
                    dtype=dtype,
                    shape=[-1, pp_hidden_size],
                    dim_range=OrderedDict([
                        ('num_tokens', num_tokens_range),
                        ('hidden_size', [pp_hidden_size] * num_profiles),
                    ]),
                )
        else:
            if mapping.is_first_pp_rank():
                input_ids = Tensor(name='input_ids',
                                   dtype=trt.int32,
                                   shape=[-1, -1],
                                   dim_range=OrderedDict([
                                       ('batch_size_beam_width', bb_range),
                                       ('input_len', inlen_range),
                                   ]))
                if position_encoding_2d:
                    position_ids = Tensor(
                        name='position_ids',
                        dtype=trt.int32,
                        shape=[-1, 2, -1],
                        dim_range=OrderedDict([
                            ('batch_size_beam_width', bb_range),
                            ('2', [2] * num_profiles),
                            ('position_ids_inlen_range',
                             position_ids_inlen_range),
                        ]),
                    )
                else:
                    position_ids = Tensor(
                        name='position_ids',
                        dtype=trt.int32,
                        shape=[-1, -1],
                        dim_range=OrderedDict([
                            ('batch_size_beam_width', bb_range),
                            ('position_ids_inlen_range',
                             position_ids_inlen_range),
                        ]),
                    )
            else:
                assert dtype is not None
                assert num_heads is not None
                hidden_states = Tensor(
                    name='hidden_states_input',
                    dtype=dtype,
                    shape=[-1, -1, hidden_size],
                    dim_range=OrderedDict([
                        ('batch_size_beam_width', bb_range),
                        ('input_len', inlen_range),
                        ('hidden_size', [hidden_size] * num_profiles),
                    ]),
                )

        if mapping.tp_size > 1:
            current_all_reduce_helper().set_workspace_tensor(
                mapping, num_profiles)

        prompt_embedding_table = None
        tasks = None
        prompt_vocab_size = None
        if prompt_embedding_table_size > 0:
            _p_embedding_range = [
                1, prompt_embedding_table_size // 2,
                prompt_embedding_table_size
            ]
            p_embedding_range = [_p_embedding_range] * num_profiles

            prompt_embedding_table = Tensor(name='prompt_embedding_table',
                                            dtype=dtype,
                                            shape=[-1, hidden_size],
                                            dim_range=OrderedDict([
                                                ('prompt_embedding_table_size',
                                                 p_embedding_range),
                                                ('hidden_size',
                                                 [hidden_size] * num_profiles),
                                            ]))
            if remove_input_padding:
                tasks = Tensor(name='tasks',
                               dtype=trt.int32,
                               shape=[-1],
                               dim_range=OrderedDict([
                                   ('input_len_task', num_tokens_range),
                               ]))
            else:
                tasks = Tensor(name='tasks',
                               dtype=trt.int32,
                               shape=[-1, 1],
                               dim_range=OrderedDict([
                                   ('batch_size_beam_width', bb_range),
                                   ('broadcast_dim', [1] * num_profiles),
                               ]))
            prompt_vocab_size = Tensor(name='prompt_vocab_size',
                                       dtype=trt.int32,
                                       shape=[1],
                                       dim_range=OrderedDict([
                                           ('size', [1] * num_profiles)
                                       ]))

        lora_weights_pointers = None
        lora_ranks = None
        if use_lora_plugin:
            lora_weights_pointers = []
            lora_ranks = []
            layers_range = mapping.pp_layers(num_layers)
            for i in layers_range:
                lora_weight_pointer_dict = {}
                lora_rank_dict = {}
                for lora_module in lora_target_modules:
                    lora_weight_pointer = Tensor(
                        name=f'{lora_module}_lora_weights_pointers_{i}',
                        dtype=trt.int64,
                        shape=[-1, 3],
                        dim_range=OrderedDict([
                            ('batch_size_beam_width', bb_range),
                            ('in_out_scales', [3] * num_profiles),
                        ]))
                    lora_weight_pointer_dict.update({
                        f"{lora_module}_lora_weights_pointers":
                        lora_weight_pointer
                    })

                    lora_rank = Tensor(
                        name=f'{lora_module}_lora_ranks_{i}',
                        dtype=trt.int32,
                        shape=[-1],
                        dim_range=OrderedDict([('batch_size_beam_width',
                                                bb_range)]),
                    )
                    lora_rank_dict.update(
                        {f"{lora_module}_lora_ranks": lora_rank})

                lora_weights_pointers.append(lora_weight_pointer_dict)
                lora_ranks.append(lora_rank_dict)

        last_token_ids = None
        if mapping.is_last_pp_rank() and not gather_context_logits:
            if not remove_input_padding and max_draft_len > 0:
                last_token_ids = Tensor(
                    name='last_token_ids',
                    dtype=trt.int32,
                    shape=[-1, -1],
                    dim_range=OrderedDict([
                        ('batch_size_beam_width', bb_range),
                        ('last_token_ids', tokens_per_engine_step_range),
                    ]),
                )
            else:
                last_token_ids = Tensor(
                    name='last_token_ids',
                    dtype=trt.int32,
                    shape=[-1],
                    dim_range=OrderedDict([
                        ('batch_size_last_token_ids', bbd_range),
                    ]),
                )
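        # Worked example (hypothetical values): with max_draft_len=4 the engine
        # handles tokens_per_engine_step=5 tokens per step, and each row of the
        # spec_decoding_packed_mask built below needs ceil(5 / 32) == 1 packed
        # int32 word.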
        spec_decoding_params = None
        # Use positional offsets and the packed mask only when not in SpS spec decoding.
        if speculative_decoding_draft_tokens_external == False and max_draft_len > 0:
            tokens_per_engine_step = max_draft_len + 1
            # The packed mask is aligned to 32-bit words.
            num_packed_masks = (tokens_per_engine_step + 32 - 1) // 32
            packed_mask_len_range = [[0, 1, num_packed_masks]] * num_profiles
            # Total number of spec decoding tokens for all sequences (the sequence
            # length can be variable).
            num_gen_tokens_range = [
                GenerationMixin.default_range(
                    max_batch_size * max_beam_width * tokens_per_engine_step,
                    min_range=0)
            ] * num_profiles
            bb_range_0 = [[0] + bbr[1:] for bbr in bb_range]
            # support variable sequence lengths for medusa.
            spec_decoding_generation_lengths = Tensor(
                name='spec_decoding_generation_lengths',
                dtype=trt.int32,
                shape=[-1],
                dim_range=OrderedDict([('batch_size_beam_width_0', bb_range_0)
                                       ]),
            )
            # Position offsets that are fixed during the whole session; they are
            # shared among all sequences.
            spec_decoding_position_offsets = Tensor(
                name='spec_decoding_position_offsets',
                dtype=trt.int32,
                shape=[-1, -1],
                dim_range=OrderedDict([
                    ('batch_size_beam_width_0', bb_range_0),
                    ('spec_decoding_position_ids_dim0',
                     tokens_per_engine_step_range),
                ]),
            )
            spec_decoding_packed_mask = Tensor(
                name='spec_decoding_packed_mask',
                dtype=trt.int32,
                shape=[-1, -1],
                dim_range=OrderedDict([
                    ('spec_decoding_packed_mask_dim0', num_gen_tokens_range),
                    ('spec_decoding_packed_mask_dim1', packed_mask_len_range),
                ]),
            )
            spec_decoding_use = Tensor(name='spec_decoding_use',
                                       dtype=trt.int32,
                                       shape=[1],
                                       dim_range=OrderedDict([
                                           ('spec_decoding_use_dim',
                                            [1] * num_profiles),
                                       ]))
            spec_decoding_params = SpecDecodingParams(
                spec_decoding_is_generation_length_variable=
                spec_decoding_is_generation_length_variable,
                spec_decoding_max_generation_length=tokens_per_engine_step,
                spec_decoding_generation_lengths=
                spec_decoding_generation_lengths,
                spec_decoding_position_offsets=spec_decoding_position_offsets,
                spec_decoding_packed_mask=spec_decoding_packed_mask,
                spec_decoding_use=spec_decoding_use)

        mrope_params = None
        if mrope_rotary_cos_sin_size is not None:
            mrope_rotary_cos_sin = Tensor(
                name='mrope_rotary_cos_sin',
                dtype=trt.float32,
                shape=[-1, mrope_rotary_cos_sin_size],
                dim_range=OrderedDict([
                    ('batch_size_beam_width', bb_range),
                    ('mult_dim', [mrope_rotary_cos_sin_size] * num_profiles),
                ]),
            )
            mrope_position_deltas = Tensor(
                name='mrope_position_deltas',
                dtype=trt.int32,
                shape=[-1, 1],
                dim_range=OrderedDict([('batch_size_beam_width', bb_range),
                                       ('mult_dim_delta',
                                        [1] * num_profiles)]),
            )
            mrope_params = MropeParams(
                mrope_rotary_cos_sin=mrope_rotary_cos_sin,
                mrope_position_deltas=mrope_position_deltas,
            )

        basic_inputs = {
            'input_ids': input_ids,
            'hidden_states_input': hidden_states,
            'position_ids': position_ids,
            'last_token_ids': last_token_ids,
            'prompt_embedding_table': prompt_embedding_table,
            'tasks': tasks,
            'prompt_vocab_size': prompt_vocab_size,
            'lora_ranks': lora_ranks,
            'lora_weights_pointers': lora_weights_pointers,
            'spec_decoding_params': spec_decoding_params,
            'mrope_params': mrope_params,
        }

        attention_inputs = self.prepare_attention_inputs(
            max_batch_size=max_batch_size,
            max_beam_width=max_beam_width,
            max_input_len=max_input_len,
            max_seq_len=max_seq_len,
            num_kv_heads=num_kv_heads,
            head_size=head_size,
            num_layers=num_layers,
            kv_dtype=kv_dtype,
            num_profiles=num_profiles,
            enable_ctx_gen_opt_profiles=enable_ctx_gen_opt_profiles,
            remove_input_padding=remove_input_padding,
            use_gpt_attention_plugin=use_gpt_attention_plugin,
            kv_cache_type=kv_cache_type,
            tokens_per_block=tokens_per_block,
            mapping=mapping,
            streamingllm=streamingllm,
            opt_batch_size=opt_batch_size)

        for key, value in attention_inputs.items():
            basic_inputs[key] = value

        return basic_inputs
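

# A minimal, self-contained sketch of how the pure static helpers above can be
# exercised. It is not part of the library API and the build-time values below are
# hypothetical; it only runs when this module is executed directly (e.g. via
# `python -m ...`), so importing the module is unaffected.
if __name__ == "__main__":
    enable_ctx_gen = GenerationMixin.has_ctx_gen_opt_profiles(
        use_gpt_attention_plugin=True,
        use_gemm_plugin=True,
        remove_input_padding=True,
        kv_cache_type=KVCacheType.PAGED)
    num_profiles, ranges = GenerationMixin.get_profiles_ranges(
        max_batch_size=8,
        max_beam_width=1,
        max_input_len=1024,
        max_num_tokens=8192,
        max_draft_len=0,
        opt_batch_size=None,
        opt_num_tokens=None,
        enable_ctx_gen_opt_profiles=enable_ctx_gen,
        multiple_profiles=True,
        kv_cache_type=KVCacheType.PAGED)
    print(f"enable_ctx_gen_opt_profiles = {enable_ctx_gen}")
    print(f"num_profiles = {num_profiles}")
    for name, value in ranges.items():
        print(f"{name}: {value}")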