import math
import weakref
from enum import IntEnum
from typing import Optional, cast

import torch
from torch import nn

from tensorrt_llm.mapping import Mapping

from ..attention_backend import (AttentionInputType, AttentionMetadata,
                                 TrtllmAttention, TrtllmAttentionMetadata)
from ..attention_backend.interface import (PositionalEmbeddingParams,
                                           PredefinedAttentionMask)
from ..attention_backend.utils import create_attention, get_attention_backend
from ..distributed import AllReduceParams
from ..model_config import ModelConfig
from ..peft.lora.layer import LoraLayer, LoraModuleType
from ..utils import get_model_extra_attrs
from .linear import Linear, TensorParallelMode, WeightMode, WeightsLoadingConfig
from .multi_stream_utils import maybe_execute_in_parallel
from .rms_norm import RMSNorm
from .rotary_embedding import RotaryEmbedding


class QkNormType(IntEnum):
    """
    The type of QK normalization.
    """
    none = 0  # No normalization applied to Q and K
    pre_rope = 1  # Apply normalization before RoPE
    post_rope = 2  # Apply normalization after RoPE


class Attention(nn.Module):

    def __init__(
        self,
        *,
        hidden_size: int,
        num_attention_heads: int,
        num_key_value_heads: int,
        max_position_embeddings: int,
        bias: bool,
        pos_embd_params: Optional[PositionalEmbeddingParams] = None,
        layer_idx: Optional[int] = None,
        dtype: torch.dtype = None,
        dense_bias: Optional[bool] = None,
        config: Optional[ModelConfig] = None,
        qk_norm_type: QkNormType = QkNormType.none,
        q_scaling: float = 1.0,
    ):
        """
        Initialize the Attention module.

        Args:
            hidden_size (int): The size of the hidden dimension.
            num_attention_heads (int): The number of attention heads.
            num_key_value_heads (int): The number of key value heads.
            max_position_embeddings (int): The maximum position embeddings.
            bias (bool): Whether to use bias in the linear layers.
            pos_embd_params (PositionalEmbeddingParams): The positional embedding parameters.
            layer_idx (int): The layer index.
            dtype (torch.dtype): The data type.
            dense_bias (bool): Whether to use bias in the output projection layer.
            config (ModelConfig): The model configuration.
            qk_norm_type (QkNormType): The type of QK normalization.
            q_scaling (float): The scaling factor for qk_scale. The definition is
                $O = softmax(QK^T * qk_scale) * V$ with $qk_scale = 1 / (sqrt(head_dim) * q_scaling)$.
                The default value is 1.0.
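                For example, with head_dim=64 and the default q_scaling=1.0, this gives
                qk_scale = 1/sqrt(64) = 0.125; q_scaling=2.0 would halve it to 0.0625.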
""" super().__init__() self.layer_idx = layer_idx config = config or ModelConfig() self.hidden_size = hidden_size self.num_heads = num_attention_heads self.head_dim = getattr(config.pretrained_config, "head_dim", self.hidden_size // self.num_heads) self.num_key_value_heads = num_key_value_heads self.num_key_value_groups = self.num_heads // self.num_key_value_heads self.max_position_embeddings = max_position_embeddings self.pos_embd_params = pos_embd_params self.qk_norm_type = qk_norm_type self.dense_bias = dense_bias self.q_scaling = q_scaling if dense_bias is None: self.dense_bias = bias # tensor parallel tp_size = config.mapping.tp_size pp_size = config.mapping.pp_size if config.mapping.enable_attention_dp: tp_size = 1 mapping = Mapping( world_size=tp_size * pp_size, tp_size=tp_size, pp_size=pp_size, rank=config.mapping.rank, gpus_per_node=config.mapping.gpus_per_node, enable_attention_dp=config.mapping.enable_attention_dp, ) assert self.num_heads % tp_size == 0 self.num_heads = self.num_heads // tp_size self.num_key_value_heads = (self.num_key_value_heads + tp_size - 1) // tp_size self.q_size = self.num_heads * self.head_dim self.kv_size = self.num_key_value_heads * self.head_dim self.qkv_proj = Linear( self.hidden_size, tp_size * self.q_size + 2 * tp_size * self.kv_size, bias=bias, dtype=dtype, mapping=mapping, tensor_parallel_mode=TensorParallelMode.COLUMN, weights_loading_config=WeightsLoadingConfig( weight_mode=WeightMode.FUSED_QKV_LINEAR), quant_config=config.get_quant_config(), skip_create_weights_in_init=config.skip_create_weights_in_init, ) self.o_lora = LoraLayer([LoraModuleType.ATTENTION_DENSE], [self.hidden_size]) self.o_proj = Linear( tp_size * self.q_size, self.hidden_size, bias=self.dense_bias, dtype=dtype, mapping=mapping, tensor_parallel_mode=TensorParallelMode.ROW, quant_config=config.get_quant_config(), skip_create_weights_in_init=config.skip_create_weights_in_init, lora=self.o_lora, ) self.quant_config = config.get_quant_config() self.attn_backend = config.attn_backend self.pos_embd_params = pos_embd_params # These two modules are mutually exclusive - either splitted_qkv_lora or fused_qkv_lora will be used, # but never both at the same time. splitted_qkv_lora handles Q,K,V separately while fused_qkv_lora # handles them as a single fused operation. 
        self.splitted_qkv_lora = LoraLayer([
            LoraModuleType.ATTENTION_Q, LoraModuleType.ATTENTION_K,
            LoraModuleType.ATTENTION_V
        ], [self.q_size, self.kv_size, self.kv_size])
        self.fused_qkv_lora = LoraLayer([LoraModuleType.ATTENTION_QKV],
                                        [self.q_size + 2 * self.kv_size])

        attn_cls = get_attention_backend(self.attn_backend)
        self.enable_rope_fusion = attn_cls.support_fused_rope(
        ) and qk_norm_type != QkNormType.post_rope
        self.attn = create_attention(
            self.attn_backend,
            self.layer_idx,
            self.num_heads,
            self.head_dim,
            self.num_key_value_heads,
            pos_embd_params=self.pos_embd_params
            if self.enable_rope_fusion else None,
            quant_config=self.quant_config,
            skip_create_weights_in_init=config.skip_create_weights_in_init,
            q_scaling=self.q_scaling,
        )
        self.support_fused_qkv = self.attn.support_fused_qkv()

        self.rotary_emb = None
        self.apply_rotary_emb = (not self.enable_rope_fusion
                                 and pos_embd_params is not None)
        if self.apply_rotary_emb:
            self.rotary_emb = RotaryEmbedding(
                pos_embd_params.rope,
                head_dim=self.head_dim,
                is_neox=pos_embd_params.is_neox,
            )

        if not config.skip_create_weights_in_init:
            self.create_weights()

    def create_weights(self):
        # self.attn has no weights but has states that are related to quant_config,
        # which could be modified after __init__
        self.attn.update_quant_config(self.quant_config)

    def split_qkv(self, q, k=None, v=None):
        if k is None and v is None:
            q, k, v = q.split([self.q_size, self.kv_size, self.kv_size],
                              dim=-1)
        return q, k, v

    def convert_qkv(self, q, k, v):
        if k is None and v is None and not self.support_fused_qkv:
            q, k, v = self.split_qkv(q)
        elif k is not None and v is not None and self.support_fused_qkv:
            qkv = torch.concat([q, k, v], dim=-1)
            q, k, v = qkv, None, None
        return q, k, v

    def forward(
        self,
        position_ids: Optional[torch.LongTensor],
        hidden_states: torch.Tensor,
        attn_metadata: AttentionMetadata,
        attention_mask: PredefinedAttentionMask = PredefinedAttentionMask.CAUSAL,
        mrope_config: Optional[dict] = None,
        all_reduce_params: Optional[AllReduceParams] = None,
        lora_params: Optional[dict] = None,
        attention_window_size: Optional[int] = None,
        **kwargs,
    ) -> torch.Tensor:
        """
        Forward pass for the Attention module.

        Args:
            position_ids (Optional[torch.LongTensor]): The position IDs.
            hidden_states (torch.Tensor): The hidden states.
            attn_metadata (AttentionMetadata): The attention metadata.
            attention_mask (PredefinedAttentionMask): The attention mask type.
            mrope_config (Optional[dict]): The MROPE configuration.
            all_reduce_params (Optional[AllReduceParams]): The all reduce parameters.
            lora_params (Optional[dict]): The LoRA parameters.
            attention_window_size (Optional[int]): The attention window size.

        Returns:
            torch.Tensor: The output tensor.
""" qkv = self.qkv_proj(hidden_states) if bool(lora_params): qkv_lora = self.splitted_qkv_lora(hidden_states, lora_params, self.layer_idx) if qkv_lora is not None: qkv = qkv + qkv_lora qkv_lora = self.fused_qkv_lora(hidden_states, lora_params, self.layer_idx) if qkv_lora is not None: qkv = qkv + qkv_lora q, k, v = qkv, None, None if self.qk_norm_type == QkNormType.pre_rope: q, k, v = self.split_qkv(q, k, v) q, k = self.apply_qk_norm(q, k) if self.apply_rotary_emb and position_ids is not None: q, k, v = self.split_qkv(q, k, v) q, k = self.rotary_emb(position_ids, [q, k]) if self.qk_norm_type == QkNormType.post_rope: q, k = self.apply_qk_norm(q, k) out_scale = None if self.o_proj.has_fp8_qdq or self.o_proj.has_nvfp4 or self.o_proj.has_fp8_block_scales: out_scale = self.o_proj.inv_input_scale q, k, v = self.convert_qkv(q, k, v) attn_output = self.attn.forward( q, k, v, attn_metadata, out_scale=out_scale, attention_mask=attention_mask, mrope_config=mrope_config, attention_window_size=attention_window_size) hidden_states = attn_output attn_output = self.o_proj(attn_output, all_reduce_params=all_reduce_params, lora_params=lora_params, layer_idx=self.layer_idx) return attn_output def apply_qk_norm(self, q, k): raise NotImplementedError( f"QK norm is not implemented for {self.__class__.__name__}." "Please override the `apply_qk_norm` method in the subclass.") def extract_extra_attrs(layer_idx: str): extra_attrs = get_model_extra_attrs() assert extra_attrs is not None, "Model extra attrs is not set" metadata_ref = extra_attrs.get("attention_metadata", None) assert metadata_ref is not None, "Attention metadata is not set" metadata = metadata_ref() assert isinstance( metadata, TrtllmAttentionMetadata, ) mla_layers = extra_attrs.get("mla_layers", None) assert mla_layers is not None, "MLA layers is not registered" mla_layer_ref = mla_layers.get(layer_idx, None) assert mla_layer_ref is not None, f"Cannot find MLA layer for layer {layer_idx}" mla_layer = mla_layer_ref() assert isinstance( mla_layer, MLA), "MLA layer must be a subclass of MLA or an instance of MLA" return metadata, mla_layer @torch.library.custom_op("trtllm::mla_custom_op", mutates_args=()) def mla_custom_op( position_ids: Optional[torch.Tensor], hidden_states: torch.Tensor, layer_idx: str, ) -> torch.Tensor: metadata, mla_layer = extract_extra_attrs(layer_idx) return mla_layer.forward_impl(position_ids, hidden_states, metadata) @mla_custom_op.register_fake def _(position_ids, hidden_states, layer_idx): _, mla_layer = extract_extra_attrs(layer_idx) return mla_layer.forward_impl_fake(hidden_states) class MLA(nn.Module): def __init__( self, *, hidden_size: int, num_attention_heads: int, num_key_value_heads: int, qk_nope_head_dim: int, qk_rope_head_dim: int, v_head_dim: int, q_lora_rank: int, kv_lora_rank: int, predicted_tokens_per_seq: int, max_position_embeddings: int, bias: bool, aux_stream: Optional[torch.cuda.Stream] = None, pos_embd_params: Optional[PositionalEmbeddingParams] = None, layer_idx: Optional[int] = None, dtype: torch.dtype = None, dense_bias: Optional[bool] = None, config: Optional[ModelConfig] = None, ): """ Initialize the MLA module. Args: hidden_size (int): The size of the hidden dimension. num_attention_heads (int): The number of attention heads. num_key_value_heads (int): The number of key value heads. qk_nope_head_dim (int): The dimension of the query and key without Rope. qk_rope_head_dim (int): The dimension of the Rope of query and key. v_head_dim (int): The dimension of the value. 
            q_lora_rank (int): The dimension of the compressed query.
            kv_lora_rank (int): The dimension of the compressed key and value.
            predicted_tokens_per_seq (int): The number of predicted tokens per sequence.
            max_position_embeddings (int): The maximum position embeddings.
            bias (bool): Whether to use bias in the linear layers.
            aux_stream (Optional[torch.cuda.Stream]): The auxiliary CUDA stream for running
                operations in two parallel streams.
            pos_embd_params (PositionalEmbeddingParams): The positional embedding parameters.
            layer_idx (int): The layer index.
            dtype (torch.dtype): The data type.
            dense_bias (bool): Whether to use bias in the output projection layer.
            config (ModelConfig): The model configuration.
        """
        super().__init__()
        self.layer_idx = layer_idx
        self.layer_idx_str = str(layer_idx)
        self.dtype = dtype
        self.hidden_size = hidden_size
        self.num_heads = num_attention_heads
        self.num_key_value_heads = num_key_value_heads
        self.num_key_value_groups = self.num_heads // self.num_key_value_heads
        self.qk_nope_head_dim = qk_nope_head_dim
        self.qk_rope_head_dim = qk_rope_head_dim
        self.qk_head_dim = qk_nope_head_dim + qk_rope_head_dim
        self.v_head_dim = v_head_dim
        self.q_lora_rank = q_lora_rank
        self.kv_lora_rank = kv_lora_rank
        self.predicted_tokens_per_seq = predicted_tokens_per_seq
        self.max_position_embeddings = max_position_embeddings
        self.pos_embd_params = pos_embd_params
        self.dense_bias = dense_bias
        if dense_bias is None:
            self.dense_bias = bias

        if self.q_lora_rank is None:
            self.q_lora_rank = hidden_size
            self.is_lite = True
        else:
            self.is_lite = False

        assert pos_embd_params is not None, "pos_embd_params must be provided in MLA"

        self.register_to_config = False
        if config is not None:
            if "mla_layers" not in config.extra_attrs:
                config.extra_attrs["mla_layers"] = {}
            config.extra_attrs["mla_layers"][self.layer_idx_str] = weakref.ref(
                self)
            self.register_to_config = True

        # tensor parallel
        config = config or ModelConfig()
        tp_size = config.mapping.tp_size
        pp_size = config.mapping.pp_size
        if config.mapping.enable_attention_dp:
            tp_size = 1
        mapping = Mapping(
            world_size=tp_size * pp_size,
            tp_size=tp_size,
            pp_size=pp_size,
            rank=config.mapping.rank,
            gpus_per_node=config.mapping.gpus_per_node,
            enable_attention_dp=config.mapping.enable_attention_dp,
        )

        assert self.num_heads % tp_size == 0
        self.num_heads = self.num_heads // tp_size
        self.num_key_value_heads = (self.num_key_value_heads + tp_size -
                                    1) // tp_size

        rms_norm_eps = config.pretrained_config.rms_norm_eps
        quant_config = config.get_quant_config()
        self.quant_config = quant_config

        if not self.is_lite:
            self.fused_a = Linear(
                hidden_size,
                self.q_lora_rank + self.kv_lora_rank + self.qk_rope_head_dim,
                bias=bias,
                dtype=dtype,
                quant_config=quant_config,
                skip_create_weights_in_init=config.skip_create_weights_in_init,
                use_custom_cublas_mm=True)
            self.q_a_layernorm = RMSNorm(hidden_size=self.q_lora_rank,
                                         eps=rms_norm_eps,
                                         dtype=dtype)
            self.q_b_proj = Linear(
                self.q_lora_rank,
                tp_size * self.num_heads * self.qk_head_dim,
                bias=bias,
                dtype=dtype,
                mapping=mapping,
                tensor_parallel_mode=TensorParallelMode.COLUMN,
                quant_config=quant_config,
                skip_create_weights_in_init=config.skip_create_weights_in_init)
        else:
            self.fused_a = Linear(
                hidden_size,
                self.kv_lora_rank + self.qk_rope_head_dim,
                bias=bias,
                dtype=dtype,
                quant_config=quant_config,
                skip_create_weights_in_init=config.skip_create_weights_in_init,
                use_custom_cublas_mm=True)
            self.q_proj = Linear(
                self.q_lora_rank,
                tp_size * self.num_heads * self.qk_head_dim,
                bias=bias,
                dtype=dtype,
                mapping=mapping,
                tensor_parallel_mode=TensorParallelMode.COLUMN,
                quant_config=quant_config,
                skip_create_weights_in_init=config.skip_create_weights_in_init,
            )
            self.q_b_proj = self.q_proj

        self.kv_a_layernorm = RMSNorm(hidden_size=kv_lora_rank,
                                      dtype=dtype,
                                      eps=rms_norm_eps)
        self.kv_b_proj = Linear(
            self.kv_lora_rank,
            tp_size * self.num_heads *
            (self.qk_nope_head_dim + self.v_head_dim),
            bias=bias,
            dtype=dtype,
            mapping=mapping,
            tensor_parallel_mode=TensorParallelMode.COLUMN,
            quant_config=quant_config,
            skip_create_weights_in_init=config.skip_create_weights_in_init)
        # This parameter will view into self.kv_b_proj.weight after loading weights.
        # For dummy weight initialization, this parameter is initialized with empty tensor.
        # Used in forward_generation only
        self.v_b_proj = nn.Parameter(
            torch.empty(
                (self.num_heads, self.v_head_dim, self.kv_lora_rank),
                dtype=dtype,
            ),
            requires_grad=False,
        )

        self.o_proj = Linear(
            self.num_key_value_heads * self.v_head_dim * tp_size,
            self.hidden_size,
            bias=self.dense_bias,
            dtype=dtype,
            mapping=mapping,
            tensor_parallel_mode=TensorParallelMode.ROW,
            quant_config=quant_config,
            skip_create_weights_in_init=config.skip_create_weights_in_init,
        )

        def yarn_get_mscale(scale=1, mscale=1):
            if scale <= 1:
                return 1.0
            return 0.1 * mscale * math.log(scale) + 1.0

        mscale_all_dim = pos_embd_params.rope.mscale_all_dim
        scaling_factor = pos_embd_params.rope.scale
        mscale = yarn_get_mscale(scaling_factor, mscale_all_dim)
        q_scaling = 1.0 / (mscale * mscale)

        self.mha = create_attention(
            config.attn_backend,
            self.layer_idx,
            self.num_heads,
            head_dim=self.qk_head_dim,
            num_kv_heads=self.num_key_value_heads,
            pos_embd_params=pos_embd_params,
            quant_config=quant_config,
            q_scaling=q_scaling,
            is_mla_enable=True,
            q_lora_rank=self.q_lora_rank,
            kv_lora_rank=self.kv_lora_rank,
            qk_nope_head_dim=self.qk_nope_head_dim,
            qk_rope_head_dim=self.qk_rope_head_dim,
            v_head_dim=self.v_head_dim,
            predicted_tokens_per_seq=self.predicted_tokens_per_seq,
            skip_create_weights_in_init=config.skip_create_weights_in_init,
        )

        self.mqa = create_attention(
            config.attn_backend,
            self.layer_idx,
            self.num_heads,
            head_dim=self.kv_lora_rank + self.qk_rope_head_dim,
            num_kv_heads=1,
            pos_embd_params=pos_embd_params,
            quant_config=quant_config,
            q_scaling=q_scaling,
            is_mla_enable=True,
            q_lora_rank=self.q_lora_rank,
            kv_lora_rank=self.kv_lora_rank,
            qk_nope_head_dim=self.qk_nope_head_dim,
            qk_rope_head_dim=self.qk_rope_head_dim,
            v_head_dim=self.kv_lora_rank,
            predicted_tokens_per_seq=self.predicted_tokens_per_seq,
            skip_create_weights_in_init=config.skip_create_weights_in_init,
        )

        self.aux_stream = aux_stream
        self.ln_events = [torch.cuda.Event(), torch.cuda.Event()]

        self.enable_rope_fusion = self.mha.support_fused_rope()
        self.support_fused_qkv = self.mha.support_fused_qkv()
        self.rotary_emb = RotaryEmbedding(
            pos_embd_params.rope,
            head_dim=self.qk_rope_head_dim,
            is_neox=pos_embd_params.is_neox,
        )
        self.apply_rotary_emb = not self.enable_rope_fusion

        if not config.skip_create_weights_in_init:
            self.create_weights()

    def create_weights(self):
        # self.mha/mqa has no weights but has states that are related to quant_config,
        # which could be modified after __init__
        self.mha.update_quant_config(self.quant_config)
        self.mqa.update_quant_config(self.quant_config)

        # k_b_proj_trans's dtype must be consistent with self.kv_b_proj,
        # which can be modified after __init__
        has_fp8_block_scales = (
            self.kv_b_proj.quant_config
            and self.kv_b_proj.quant_config.quant_mode.has_fp8_block_scales())
        mla_weight_dtype = torch.float8_e4m3fn if has_fp8_block_scales else self.dtype
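        # k_b_proj_trans holds, per head, the transposed K-nope slice of kv_b_proj's weight,
        # shaped (num_heads, kv_lora_rank, qk_nope_head_dim). forward_generation bmm's q_nope
        # against it so queries are projected into the compressed latent space instead of
        # up-projecting the cached keys (the MLA weight-absorption trick).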
        self.k_b_proj_trans = nn.Parameter(
            torch.empty(
                (self.num_heads, self.kv_lora_rank, self.qk_nope_head_dim),
                dtype=mla_weight_dtype,
            ),
            requires_grad=False,
        )
        if has_fp8_block_scales:
            self.k_b_proj_trans_scale = nn.Parameter(
                torch.empty(
                    (
                        self.num_heads,
                        self.kv_lora_rank // 128,
                        self.qk_nope_head_dim // 128,
                    ),
                    dtype=torch.float32,
                ),
                requires_grad=False,
            )
            # This parameter will view into self.kv_b_proj.weight_scale after loading weights.
            # For dummy weight initialization, this parameter is initialized with empty tensor.
            self.v_b_proj_scale = nn.Parameter(
                torch.empty(
                    (
                        self.num_heads,
                        self.v_head_dim // 128,
                        self.kv_lora_rank // 128,
                    ),
                    dtype=torch.float32,
                ),
                requires_grad=False,
            )
        else:
            self.k_b_proj_trans_scale = None
            self.v_b_proj_scale = None

    def apply_rope(
        self,
        q: torch.Tensor,
        k_pe: torch.Tensor,
        position_ids: torch.Tensor,
    ) -> torch.Tensor:
        q = q.view(-1, self.num_heads, self.qk_head_dim)
        q_pe = q[..., self.qk_nope_head_dim:].reshape(
            -1, self.num_heads * self.qk_rope_head_dim)
        q_pe, k_pe = self.rotary_emb(position_ids, [q_pe, k_pe])
        q[..., self.qk_nope_head_dim:] = q_pe.view(-1, self.num_heads,
                                                   self.qk_rope_head_dim)
        return k_pe

    def forward_impl_fake(self, hidden_states: torch.Tensor):
        num_tokens = hidden_states.shape[0]
        hidden_size = self.o_proj.in_features
        return hidden_states.new_empty([num_tokens, hidden_size],
                                       dtype=hidden_states.dtype)

    def forward_impl(
        self,
        position_ids: Optional[torch.Tensor],
        hidden_states: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """
        Forward pass for the MLA module.

        Args:
            position_ids (Optional[torch.LongTensor]): The position IDs.
            hidden_states (torch.Tensor): The hidden states.
            attn_metadata (AttentionMetadata): The attention metadata.

        Returns:
            torch.Tensor: The output tensor.
        """
        if self.is_lite:
            compressed_kv, k_pe = self.fused_a(hidden_states).split(
                [self.kv_lora_rank, self.qk_rope_head_dim], -1)
            compressed_kv = self.kv_a_layernorm(compressed_kv)
            q = hidden_states
        else:
            q, compressed_kv, k_pe = self.fused_a(hidden_states).split(
                [self.q_lora_rank, self.kv_lora_rank, self.qk_rope_head_dim],
                -1)

            q, compressed_kv = maybe_execute_in_parallel(
                lambda: self.q_a_layernorm(q),
                lambda: self.kv_a_layernorm(compressed_kv),
                self.ln_events[0],
                self.ln_events[1],
                self.aux_stream,
            )

        q, latent_cache = maybe_execute_in_parallel(
            lambda: self.q_b_proj(q),
            lambda: torch.concat([compressed_kv, k_pe], dim=-1),
            self.ln_events[0],
            self.ln_events[1],
            self.aux_stream,
        )

        # split q, k, v into context and gen batches
        num_contexts = attn_metadata.num_contexts
        num_generations = attn_metadata.num_generations
        num_ctx_tokens = attn_metadata.num_ctx_tokens
        num_tokens = attn_metadata.num_tokens
        assert q.shape[
            0] == num_tokens, f"Expect q.shape[0] to be {num_tokens}, but got {q.shape[0]}"

        if num_contexts > 0:
            q_ctx = q[:num_ctx_tokens, ...]
            compressed_kv_ctx = compressed_kv[:num_ctx_tokens, ...]
            k_pe_ctx = k_pe[:num_ctx_tokens, ...]
            latent_cache_ctx = latent_cache[:num_ctx_tokens, ...]
            if self.apply_rotary_emb:
                assert position_ids is not None
                k_pe_ctx = self.apply_rope(q_ctx, k_pe_ctx, position_ids)

            attn_output_context = self.forward_context(q_ctx,
                                                       compressed_kv_ctx,
                                                       k_pe_ctx, attn_metadata,
                                                       latent_cache_ctx,
                                                       position_ids)
        else:
            attn_output_context = None

        if num_generations > 0:
            q_gen = q[num_ctx_tokens:, ...]
            compressed_kv_gen = compressed_kv[num_ctx_tokens:, ...]
            k_pe_gen = k_pe[num_ctx_tokens:, ...]
            latent_cache_gen = latent_cache[num_ctx_tokens:, ...]
            if self.apply_rotary_emb:
                assert position_ids is not None
                k_pe_gen = self.apply_rope(q_gen, k_pe_gen, position_ids)

            attn_output_gen = self.forward_generation(q_gen,
                                                      compressed_kv_gen,
                                                      k_pe_gen, attn_metadata,
                                                      latent_cache_gen)
        else:
            attn_output_gen = None

        # release pytorch activation memory
        q = None
        compressed_kv = None
        k_pe = None

        # merge context and gen batches
        if attn_output_context is not None and attn_output_gen is not None:
            assert (
                len(attn_output_context.shape) == 2
            ), f"attn_output_context must be rank 2, not {len(attn_output_context.shape)}"
            assert (
                len(attn_output_gen.shape) == 2
            ), f"attn_output_gen must be rank 2, not {len(attn_output_gen.shape)}"
            attn_output = torch.cat([attn_output_context, attn_output_gen],
                                    dim=0)
            # release pytorch activation memory
            attn_output_context = None
            attn_output_gen = None
        elif attn_output_gen is None:
            attn_output = attn_output_context
        else:
            attn_output = attn_output_gen
        return attn_output

    def _maybe_concat_qkv(self, q, k, v):
        if k is not None and v is not None and self.support_fused_qkv:
            qkv = torch.concat([q, k, v], dim=-1)
            q, k, v = qkv, None, None
        return q, k, v

    def forward_context_default(
        self,
        q: torch.Tensor,
        compressed_kv: torch.Tensor,
        k_pe: torch.Tensor,
        attn_metadata: AttentionMetadata,
        latent_cache: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        kv = self.kv_b_proj(compressed_kv)
        k_nope, v = kv.split(
            [
                self.num_heads * self.qk_nope_head_dim,
                self.num_heads * self.v_head_dim
            ],
            -1,
        )

        k = torch.empty_like(q).view(-1, self.num_heads, self.qk_head_dim)
        k[..., :self.qk_nope_head_dim] = k_nope.view(-1, self.num_heads,
                                                     self.qk_nope_head_dim)
        if self.apply_rotary_emb:
            k[..., self.qk_nope_head_dim:] = k_pe.view(-1, 1,
                                                       self.qk_rope_head_dim)
        k = k.view(-1, self.num_heads * self.qk_head_dim)

        # May concat q (including q_pe), k (including k_pe), and v together
        q, k, v = self._maybe_concat_qkv(q, k, v)

        # out_scale = getattr(self.o_proj, "inv_input_scale", None)
        out_scale = None
        # Currently we use BF16 MHA for context phase
        attn_output = self.mha.forward(
            q,
            k,
            v,
            attn_metadata,
            attention_input_type=AttentionInputType.context_only,
            latent_cache=latent_cache,
            out_scale=out_scale,
        )

        return attn_output

    def forward_context_with_cached_kv(
        self,
        q: torch.Tensor,
        compressed_kv: torch.Tensor,
        k_pe: torch.Tensor,
        attn_metadata: AttentionMetadata,
        position_ids: Optional[torch.LongTensor] = None,
    ) -> torch.Tensor:
        trtllm_attention = cast(TrtllmAttention, self.mha)

        # copy past_compressed_kv and past_k_pe from paged kv cache
        past_latent_cache = trtllm_attention.load_paged_kv_cache_for_mla(
            attn_metadata, q.dtype)
        assert past_latent_cache.shape[0] == attn_metadata.num_ctx_cached_tokens
        assert past_latent_cache.shape[
            1] == self.kv_lora_rank + self.qk_rope_head_dim
        past_compressed_kv, past_k_pe = past_latent_cache.split(
            [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)

        # compute past_k_nope and past_v from past_compressed_kv
        # TODO: remove this contiguous by returning two tensors from load_paged_kv_cache_for_mla
        past_compressed_kv = past_compressed_kv.contiguous()
        past_kv = self.kv_b_proj(past_compressed_kv)
        past_k_nope, past_v = past_kv.split(
            [
                self.num_heads * self.qk_nope_head_dim,
                self.num_heads * self.v_head_dim
            ],
            -1,
        )
        past_k_nope = past_k_nope.view(-1, self.num_heads,
                                       self.qk_nope_head_dim)
        past_v = past_v.view(-1, self.num_heads, self.v_head_dim)

        # compute current k_nope and v from compressed_kv
        kv = self.kv_b_proj(compressed_kv)
        k_nope, v = kv.split([
            self.num_heads * self.qk_nope_head_dim,
            self.num_heads * self.v_head_dim
        ],
                             dim=-1)

        # split current q into q_nope and q_pe
        q_nope, q_pe = q.view([
            -1, self.num_heads, self.qk_nope_head_dim + self.qk_rope_head_dim
        ]).split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)

        # apply rope to current q_pe and k_pe
        assert position_ids is not None
        assert position_ids.dim() == 1 or (position_ids.dim() == 2
                                           and position_ids.shape[0] == 1)
        assert self.rotary_emb is not None
        assert self.rotary_emb.head_dim == self.qk_rope_head_dim
        assert q_pe.shape[0] == k_pe.shape[0]
        q_pe = q_pe.contiguous().view(-1,
                                      self.num_heads * self.qk_rope_head_dim)
        q_pe, k_pe = self.rotary_emb(
            position_ids[..., :attn_metadata.num_ctx_tokens], [q_pe, k_pe])
        k_pe = k_pe.contiguous()

        # build q for attention op
        q_view = q.view(-1, self.num_heads,
                        self.qk_nope_head_dim + self.qk_rope_head_dim)
        q_view[:, :, self.qk_nope_head_dim:] = q_pe.view(
            -1, self.num_heads, self.qk_rope_head_dim)
        q = q_view.view(
            -1,
            self.num_heads * (self.qk_nope_head_dim + self.qk_rope_head_dim))
        assert q.is_contiguous()

        # append paged kv cache for mla
        # we may finish it inside the attention op by passing latent_cache
        trtllm_attention.append_paged_kv_cache_for_mla(
            compressed_kv,
            k_pe,
            attn_metadata,
        )

        # build full_k and full_v
        k_nope = k_nope.view(-1, self.num_heads, self.qk_nope_head_dim)
        v = v.view(-1, self.num_heads, self.v_head_dim)

        tokens_per_block = attn_metadata.kv_cache_manager.tokens_per_block
        # paged kv cache should be initialized to 0 to avoid NaN
        full_kv = torch.zeros([
            attn_metadata.num_contexts, 2,
            (attn_metadata.max_ctx_kv_len + tokens_per_block - 1) //
            tokens_per_block, self.num_heads, tokens_per_block,
            max(self.qk_nope_head_dim + self.qk_rope_head_dim, self.v_head_dim)
        ],
                              dtype=q.dtype,
                              device=q.device)
        mla_context_kv_cache_block_offsets = trtllm_attention.set_paged_kv_cache_v2_for_mla(
            full_kv,
            past_k_nope,
            past_v,
            past_k_pe,
            k_nope,
            v,
            k_pe,
            attn_metadata,
        )

        # out_scale = getattr(self.o_proj, "inv_input_scale", None)
        out_scale = None
        # Currently we use BF16 MHA for context phase
        attn_output = self.mha.forward(
            q,
            None,
            None,
            attn_metadata,
            attention_input_type=AttentionInputType.context_only,
            latent_cache=None,
            out_scale=out_scale,
            mla_context_paged_kv=full_kv,
            mla_context_kv_cache_block_offsets=
            mla_context_kv_cache_block_offsets,
        )

        return attn_output

    def forward_context(
        self,
        q: torch.Tensor,
        compressed_kv: torch.Tensor,
        k_pe: torch.Tensor,
        attn_metadata: AttentionMetadata,
        latent_cache: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
    ) -> torch.Tensor:
        if isinstance(self.mha, TrtllmAttention):
            assert isinstance(attn_metadata, TrtllmAttentionMetadata)
            trtllm_attention = cast(TrtllmAttention, self.mha)
            if trtllm_attention.has_cached_kv_for_mla_context(attn_metadata):
                return self.forward_context_with_cached_kv(
                    q, compressed_kv, k_pe, attn_metadata, position_ids)
        return self.forward_context_default(q, compressed_kv, k_pe,
                                            attn_metadata, latent_cache)

    def forward_generation(
        self,
        q: torch.Tensor,
        compressed_kv: torch.Tensor,
        k_pe: torch.Tensor,
        attn_metadata: AttentionMetadata,
        latent_cache: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        num_tokens = q.shape[0]
        q_nope, q_pe = q.view([-1, self.num_heads, self.qk_head_dim]).split(
            [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)

        # fused_q contains 1) the result of the following bmm with shape
        # [num_tokens, num_heads, kv_lora_rank], and 2) rope(q_pe) with shape
        # [num_tokens, num_heads, qk_rope_head_dim]. rope is applied inside the AttentionOp.
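        # Illustrative layout (hypothetical DeepSeek-style sizes: kv_lora_rank=512,
        # qk_rope_head_dim=64): each head's fused query row is 512 + 64 = 576 wide; the first
        # 512 entries come from absorbing q_nope through k_b_proj_trans, the last 64 are q_pe.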
        fused_q = torch.empty(
            [
                num_tokens, self.num_heads,
                (self.kv_lora_rank + self.qk_rope_head_dim)
            ],
            dtype=q.dtype,
            device=q.device,
        )

        if self.k_b_proj_trans.dtype == torch.bfloat16:
            # [num_heads, num_tokens, self.qk_nope_head_dim]
            q_nope_t = q_nope.transpose(0, 1)
            # [num_heads, num_tokens, self.kv_lora_rank]
            q_nope_out = fused_q[..., :self.kv_lora_rank].transpose(0, 1)

            # [num_heads, num_tokens, self.qk_nope_head_dim] x [num_heads, kv_lora_rank, qk_nope_head_dim]
            # -> [num_heads, num_tokens, kv_lora_rank] -> [num_tokens, num_heads, kv_lora_rank]
            # The output of bmm is written directly into fused_q
            torch.ops.trtllm.bmm_out(q_nope_t,
                                     self.k_b_proj_trans.transpose(1, 2),
                                     q_nope_out)
        elif self.k_b_proj_trans.dtype == torch.float8_e4m3fn:
            q_nope_fp8, q_nope_scales = torch.ops.trtllm.fp8_batched_quantize_1x128_permute102(
                q_nope)
            # [num_heads, num_tokens, self.kv_lora_rank]
            q_nope_out = fused_q[..., :self.kv_lora_rank].transpose(0, 1)

            torch.ops.trtllm.fp8_block_scaling_bmm_out(
                q_nope_fp8, self.k_b_proj_trans, q_nope_scales,
                self.k_b_proj_trans_scale, q_nope_out)
            q_nope_scales = None
        else:
            raise NotImplementedError(
                f"Missing bmm impl for dtype: {self.k_b_proj_trans.dtype}.")

        if self.apply_rotary_emb:
            fused_q[..., self.kv_lora_rank:] = q_pe

        fused_q = fused_q.view([
            num_tokens,
            self.num_heads * (self.kv_lora_rank + self.qk_rope_head_dim)
        ])

        # out_scale = getattr(self.o_proj, "inv_input_scale", None)
        out_scale = None
        # Although we use FP8 MLA for generation phase, the output is still in BF16
        attn_out_latent = self.mqa.forward(
            fused_q,
            None,
            None,
            attn_metadata,
            attention_input_type=AttentionInputType.generation_only,
            out_scale=out_scale,
            latent_cache=latent_cache,  # kvcache and k_pe
            q_pe=q_pe,  # used by `invokeMLARopeGeneration`
        )
        fused_q = None

        assert (attn_out_latent.shape[0] == q.shape[0]
                and attn_out_latent.shape[1]
                == self.num_heads * self.kv_lora_rank)

        # [seq, num_heads, kv_lora_rank]
        attn_out_latent = attn_out_latent.view(
            [-1, self.num_heads, self.kv_lora_rank])

        attn_output = torch.empty(
            [num_tokens, self.num_heads, self.v_head_dim],
            dtype=attn_out_latent.dtype,
            device=attn_out_latent.device)

        if self.v_b_proj.dtype == torch.bfloat16:
            # [num_heads, seq, kv_lora_rank] x [num_heads, kv_lora_rank, v_head_dim]
            # -> [num_heads, seq, v_head_dim]
            torch.ops.trtllm.bmm_out(attn_out_latent.transpose(0, 1),
                                     self.v_b_proj.transpose(1, 2),
                                     attn_output.transpose(0, 1))
        elif self.v_b_proj.dtype == torch.float8_e4m3fn:
            attn_out_latent, attn_out_latent_scales = torch.ops.trtllm.fp8_batched_quantize_1x128_permute102(
                attn_out_latent)

            torch.ops.trtllm.fp8_block_scaling_bmm_out(
                attn_out_latent, self.v_b_proj, attn_out_latent_scales,
                self.v_b_proj_scale, attn_output.transpose(0, 1))
            attn_out_latent_scales = None
        else:
            raise NotImplementedError(
                f"Missing bmm impl for dtype: {self.v_b_proj.dtype}.")

        # [seq, num_heads * v_head_dim]
        return attn_output.flatten(1, 2)

    def forward(
        self,
        position_ids: Optional[torch.Tensor],
        hidden_states: torch.Tensor,
        attn_metadata: AttentionMetadata,
        all_reduce_params: Optional[AllReduceParams] = None,
    ) -> torch.Tensor:
        if self.register_to_config:
            attn_output = torch.ops.trtllm.mla_custom_op(
                position_ids, hidden_states, self.layer_idx_str)
        else:
            attn_output = self.forward_impl(position_ids, hidden_states,
                                            attn_metadata)
        attn_output = self.o_proj(attn_output,
                                  all_reduce_params=all_reduce_params)
        return attn_output