Layers

Activation

class tensorrt_llm.layers.activation.Mish[source]

Bases: Module

forward(input)[source]
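
As a usage illustration only, the sketch below wraps Mish between two Linear projections inside a custom Module; the block name, layer sizes, and the Module import path (tensorrt_llm.module) are assumptions, not something this page prescribes.

# Illustrative sketch: Mish used inside a hypothetical feed-forward block.
from tensorrt_llm.module import Module
from tensorrt_llm.layers.activation import Mish
from tensorrt_llm.layers.linear import Linear

class MishFFN(Module):
    def __init__(self, hidden_size=1024, ffn_size=4096):
        super().__init__()
        self.fc = Linear(hidden_size, ffn_size)
        self.act = Mish()
        self.proj = Linear(ffn_size, hidden_size)

    def forward(self, x):
        # forward() is traced at network-build time, so x is a
        # tensorrt_llm Tensor rather than a framework tensor.
        return self.proj(self.act(self.fc(x)))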

Attention

class tensorrt_llm.layers.attention.Attention(hidden_size, num_attention_heads, num_kv_heads=None, max_position_embeddings=1024, num_layers=1, apply_query_key_layer_scaling=False, attention_mask_type=AttentionMaskType.padding, bias=True, dtype=None, position_embedding_type=PositionEmbeddingType.learned_absolute, rotary_embedding_base=10000.0, rotary_embedding_scaling=None, use_int8_kv_cache=False, rotary_embedding_percentage=1.0, tp_group=None, tp_size=1, tp_rank=0, multi_block_mode=False, quant_mode: QuantMode = QuantMode.None, q_scaling=1.0, cross_attention=False, relative_attention=False, max_distance=0, num_buckets=0, instance_id: int = 0)[source]

Bases: Module

forward(hidden_states: Tensor, attention_mask=None, use_cache=False, kv_cache_params=None, attention_params=None, encoder_output: Tensor | None = None, workspace=None)[source]
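
As an illustration of how the documented constructor arguments fit together, the hedged sketch below configures a self-attention layer for a hypothetical decoder block; the head counts, sizes, RoPE settings, and the enum import path (tensorrt_llm.functional) are assumptions.

# Illustrative sketch: a causal self-attention layer with hypothetical sizes.
from tensorrt_llm.functional import AttentionMaskType, PositionEmbeddingType
from tensorrt_llm.layers.attention import Attention

self_attn = Attention(
    hidden_size=4096,
    num_attention_heads=32,
    num_kv_heads=8,        # grouped-query attention: fewer KV heads than query heads
    max_position_embeddings=4096,
    attention_mask_type=AttentionMaskType.causal,
    position_embedding_type=PositionEmbeddingType.rope_gpt_neox,
    rotary_embedding_base=10000.0,
    bias=False,
    tp_group=None,         # single-GPU build; pass a group and tp_size > 1 for tensor parallelism
    tp_size=1,
)
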
class tensorrt_llm.layers.attention.AttentionParams(sequence_length: Tensor | None = None, context_lengths: Tensor | None = None, host_context_lengths: Tensor | None = None, max_context_length: int | None = None, host_request_types: Tensor | None = None, encoder_input_lengths: Tensor | None = None, encoder_max_input_length: Tensor | None = None)[source]

Bases: object

is_valid(gpt_attention_plugin, remove_input_padding)[source]
is_valid_cross_attn(do_cross_attention)[source]
class tensorrt_llm.layers.attention.BertAttention(hidden_size, num_attention_heads, num_kv_heads=None, max_position_embeddings=1024, num_layers=1, q_scaling=1.0, apply_query_key_layer_scaling=False, bias=True, dtype=None, tp_group=None, tp_size=1, tp_rank=0, relative_attention=False, max_distance=0, num_buckets=0)[source]

Bases: Module

forward(hidden_states: Tensor, attention_mask=None, input_lengths=None)[source]
class tensorrt_llm.layers.attention.KeyValueCacheParams(past_key_value: List[Tensor] | None = None, host_past_key_value_lengths: Tensor | None = None, kv_cache_block_pointers: List[Tensor] | None = None, cache_indirection: Tensor | None = None, past_key_value_length: Tensor | None = None)[source]

Bases: object

get_first_kv_cache_block_pointers()[source]
get_first_past_key_value()[source]
is_valid(gpt_attention_plugin)[source]
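
For orientation only, the following hedged sketch shows one way the two parameter bundles above might be filled and validated before being handed to Attention.forward; every tensor argument is assumed to be a network input declared elsewhere, and the plugin flags are hypothetical.

# Illustrative sketch: bundling runtime tensors for Attention.forward().
# Every argument is assumed to be a tensor/value declared elsewhere while
# the network inputs are set up; nothing here is created from scratch.
from tensorrt_llm.layers.attention import AttentionParams, KeyValueCacheParams

def bundle_attention_inputs(sequence_length, context_lengths, host_context_lengths,
                            max_context_length, host_request_types,
                            past_key_value, host_past_key_value_lengths,
                            cache_indirection):
    attention_params = AttentionParams(
        sequence_length=sequence_length,
        context_lengths=context_lengths,
        host_context_lengths=host_context_lengths,
        max_context_length=max_context_length,
        host_request_types=host_request_types,
    )
    kv_cache_params = KeyValueCacheParams(
        past_key_value=past_key_value,
        host_past_key_value_lengths=host_past_key_value_lengths,
        cache_indirection=cache_indirection,
    )
    # is_valid() checks that the supplied tensors are consistent with the
    # plugin configuration before the attention layer consumes them.
    assert attention_params.is_valid(gpt_attention_plugin=True,
                                     remove_input_padding=False)
    assert kv_cache_params.is_valid(gpt_attention_plugin=True)
    return attention_params, kv_cache_params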

Cast

class tensorrt_llm.layers.cast.Cast(output_dtype: str = 'float32')[source]

Bases: Module

forward(x)[source]
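
Purely as an illustration, the sketch below inserts a Cast so that lower-precision hidden states can feed a hypothetical float32 output head; the layer names and sizes are invented for the example.

# Illustrative sketch: cast hidden states to float32 before an output projection.
from tensorrt_llm.module import Module
from tensorrt_llm.layers.cast import Cast
from tensorrt_llm.layers.linear import Linear

class Float32Head(Module):
    def __init__(self, hidden_size=1024, vocab_size=32000):
        super().__init__()
        self.cast = Cast(output_dtype='float32')
        self.head = Linear(hidden_size, vocab_size, bias=False)

    def forward(self, hidden_states):
        return self.head(self.cast(hidden_states))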

Conv

class tensorrt_llm.layers.conv.Conv2d(in_channels: int, out_channels: int, kernel_size: Tuple[int, int], stride: Tuple[int, int] = (1, 1), padding: Tuple[int, int] = (0, 0), dilation: Tuple[int, int] = (1, 1), groups: int = 1, bias: bool = True, padding_mode: str = 'zeros', dtype=None)[source]

Bases: Module

forward(input)[source]
class tensorrt_llm.layers.conv.ConvTranspose2d(in_channels: int, out_channels: int, kernel_size: Tuple[int, int], stride: Tuple[int, int] = (1, 1), padding: Tuple[int, int] = (0, 0), output_padding: Tuple[int, int] = (0, 0), dilation: Tuple[int, int] = (1, 1), groups: int = 1, bias: bool = True, padding_mode: str = 'zeros', dtype=None)[source]

Bases: Module

forward(input, output_size=None)[source]
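
As a hedged illustration of the two convolution layers, the sketch below pairs a stride-2 Conv2d with a ConvTranspose2d that restores the spatial size; the channel counts and kernel shapes are arbitrary choices for the example.

# Illustrative sketch: downsample with Conv2d, upsample with ConvTranspose2d.
from tensorrt_llm.module import Module
from tensorrt_llm.layers.conv import Conv2d, ConvTranspose2d

class DownUpBlock(Module):
    def __init__(self, channels=64):
        super().__init__()
        # Stride-2 convolution halves the spatial resolution.
        self.down = Conv2d(channels, 2 * channels, kernel_size=(3, 3),
                           stride=(2, 2), padding=(1, 1))
        # The transposed convolution doubles it again.
        self.up = ConvTranspose2d(2 * channels, channels, kernel_size=(4, 4),
                                  stride=(2, 2), padding=(1, 1))

    def forward(self, x):
        return self.up(self.down(x))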

Embedding

class tensorrt_llm.layers.embedding.Embedding(num_embeddings, embedding_dim, dtype=None, tp_size=1, tp_group=None, sharding_dim=0, tp_rank=None)[source]

Bases: Module

The embedding layer takes input indices (x) and the embedding lookup table (weight) as input and outputs the corresponding embeddings for those indices. The shape of weight is [num_embeddings, embedding_dim].

Four parameters (tp_size, tp_group, sharding_dim, tp_rank) are involved in tensor parallelism; it is enabled only when “tp_size > 1 and tp_group is not None”.

When “sharding_dim == 0”, the weight is sharded along the vocabulary dimension.

tp_rank must be set when sharding_dim == 0.

When “sharding_dim == 1”, the weight is sharded along the hidden dimension.

forward(x)[source]
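
For illustration only, the sketch below shows both sharding choices described above with hypothetical sizes; the two-way tensor-parallel group and rank values are assumptions and would normally come from the build script's parallelism mapping.

# Illustrative sketch: vocabulary- vs. hidden-dimension sharding of the lookup table.
from tensorrt_llm.layers.embedding import Embedding

tp_size, tp_rank, tp_group = 2, 0, [0, 1]   # hypothetical 2-way tensor parallelism

# sharding_dim == 0: each rank holds a slice of the rows (vocabulary),
# so tp_rank is required to select the right slice.
vocab_sharded = Embedding(32000, 4096, tp_size=tp_size, tp_group=tp_group,
                          sharding_dim=0, tp_rank=tp_rank)

# sharding_dim == 1: each rank holds a slice of the columns (hidden size);
# tp_rank is not needed for this split.
hidden_sharded = Embedding(32000, 4096, tp_size=tp_size, tp_group=tp_group,
                           sharding_dim=1)
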
class tensorrt_llm.layers.embedding.PromptTuningEmbedding(num_embeddings, embedding_dim, vocab_size=None, dtype=None, tp_size=1, tp_group=None, sharding_dim=0, tp_rank=0)[source]

Bases: Embedding

Passes all tokens through both the normal and the prompt embedding tables, then combines the results based on whether each token is “normal” or “prompt/virtual”.

forward(tokens, prompt_embedding_table, tasks, task_vocab_size)[source]
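
As a hedged orientation, not a prescription: the sketch below constructs the layer and notes how its forward inputs are usually interpreted, assuming the common prompt-tuning convention that token ids at or above vocab_size denote virtual tokens; the sizes are hypothetical.

# Illustrative sketch: a prompt-tuning-aware embedding with hypothetical sizes.
from tensorrt_llm.layers.embedding import PromptTuningEmbedding

embedding = PromptTuningEmbedding(
    num_embeddings=32000,   # rows of the regular lookup table
    embedding_dim=4096,
    vocab_size=32000,       # assumed convention: ids >= vocab_size are prompt/virtual tokens
)

# At network-build time the extra runtime inputs are passed to forward():
#   embedding(tokens, prompt_embedding_table, tasks, task_vocab_size)
# where tasks maps each token to its prompt-tuning task and task_vocab_size
# gives the number of virtual tokens per task.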

Linear

tensorrt_llm.layers.linear.ColumnLinear

alias of Linear

class tensorrt_llm.layers.linear.Linear(in_features, out_features, bias=True, dtype=None, tp_group=None, tp_size=1, gather_output=True, share_weight=None)[source]

Bases: Module

forward(x)[source]
multiply_gather(x, weight, gemm_plugin, use_fp8=False)[source]
class tensorrt_llm.layers.linear.RowLinear(in_features, out_features, bias=True, dtype=None, tp_group=None, tp_size=1, instance_id: int = 0)[source]

Bases: Module

forward(x, workspace=None)[source]
multiply_reduce(x, weight, gemm_plugin, use_fp8=False, workspace=None)[source]
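
For illustration, the hedged sketch below shows the usual tensor-parallel pairing: a ColumnLinear that shards its output dimension (with gather_output=False so no all-gather is inserted) feeding a RowLinear that shards its input dimension and reduces the partial results. The sizes, the 2-way group, and the use of gelu from tensorrt_llm.functional are assumptions.

# Illustrative sketch: ColumnLinear followed by RowLinear under tensor parallelism.
from tensorrt_llm.functional import gelu
from tensorrt_llm.module import Module
from tensorrt_llm.layers.linear import ColumnLinear, RowLinear

class ParallelFFN(Module):
    def __init__(self, hidden_size=4096, ffn_size=16384, tp_group=None, tp_size=1):
        super().__init__()
        # Shards the output (ffn) dimension across ranks; keep the shards local.
        self.fc = ColumnLinear(hidden_size, ffn_size, tp_group=tp_group,
                               tp_size=tp_size, gather_output=False)
        # Shards the input (ffn) dimension and reduces the partial outputs.
        self.proj = RowLinear(ffn_size, hidden_size, tp_group=tp_group,
                              tp_size=tp_size)

    def forward(self, x):
        return self.proj(gelu(self.fc(x)))

# Hypothetical 2-way instantiation:
ffn = ParallelFFN(tp_group=[0, 1], tp_size=2)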

MLP

class tensorrt_llm.layers.mlp.GatedMLP(hidden_size, ffn_hidden_size, hidden_act, bias=True, dtype=None, tp_group=None, tp_size=1, quant_mode=QuantMode.None, instance_id: int = 0)[source]

Bases: MLP

forward(hidden_states, workspace=None)[source]
class tensorrt_llm.layers.mlp.MLP(hidden_size, ffn_hidden_size, hidden_act, bias=True, dtype=None, tp_group=None, tp_size=1, quant_mode=QuantMode.None, instance_id: int = 0)[source]

Bases: Module

forward(hidden_states, workspace=None)[source]
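
For a quick comparison, offered only as a sketch: MLP applies an up-projection, the chosen activation, and a down-projection, while GatedMLP adds a parallel gate projection whose activated output multiplies the main branch before the down-projection. The sizes below and the activation names ('gelu', 'silu') are assumptions about what the library accepts.

# Illustrative sketch: plain vs. gated MLP with hypothetical sizes.
from tensorrt_llm.layers.mlp import MLP, GatedMLP

mlp = MLP(hidden_size=4096, ffn_hidden_size=16384, hidden_act='gelu')

# Gated variant (e.g. SwiGLU-style blocks in LLaMA-like models).
gated_mlp = GatedMLP(hidden_size=4096, ffn_hidden_size=11008,
                     hidden_act='silu', bias=False)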

Normalization

class tensorrt_llm.layers.normalization.GroupNorm(num_groups, num_channels, eps=1e-05, affine=True, dtype=None)[source]

Bases: Module

forward(x)[source]
class tensorrt_llm.layers.normalization.LayerNorm(normalized_shape, eps=1e-05, elementwise_affine=True, dtype=None)[source]

Bases: Module

forward(x)[source]
class tensorrt_llm.layers.normalization.RmsNorm(normalized_shape, eps=1e-06, elementwise_affine=True, dtype=None)[source]

Bases: Module

forward(x)[source]
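
As a hedged example of where these layers typically sit, the sketch below uses RmsNorm for pre-normalization around a residual MLP branch; the sizes, epsilon, and the residual add via the overloaded + operator are assumptions for the illustration.

# Illustrative sketch: pre-normalization with a residual connection.
from tensorrt_llm.module import Module
from tensorrt_llm.layers.normalization import RmsNorm
from tensorrt_llm.layers.mlp import MLP

class PreNormMLP(Module):
    def __init__(self, hidden_size=4096):
        super().__init__()
        self.norm = RmsNorm(normalized_shape=hidden_size, eps=1e-06)
        self.mlp = MLP(hidden_size, 4 * hidden_size, hidden_act='gelu')

    def forward(self, x):
        # Residual connection around the normalized MLP branch.
        return x + self.mlp(self.norm(x))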

Pooling

class tensorrt_llm.layers.pooling.AvgPool2d(kernel_size: Tuple[int], stride: Tuple[int] | None = None, padding: Tuple[int] | None = (0, 0), ceil_mode: bool = False, count_include_pad: bool = True)[source]

Bases: Module

forward(input)[source]
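
To close with a usage illustration: the hedged sketch below combines a Conv2d stem with 2x2 average pooling; the channel counts and window sizes are arbitrary example values.

# Illustrative sketch: convolution followed by 2x2 average pooling.
from tensorrt_llm.module import Module
from tensorrt_llm.layers.conv import Conv2d
from tensorrt_llm.layers.pooling import AvgPool2d

class ConvPool(Module):
    def __init__(self, in_channels=3, out_channels=32):
        super().__init__()
        self.conv = Conv2d(in_channels, out_channels, kernel_size=(3, 3),
                           padding=(1, 1))
        self.pool = AvgPool2d(kernel_size=(2, 2), stride=(2, 2))

    def forward(self, x):
        return self.pool(self.conv(x))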