import math
import weakref
from typing import Optional, Union, cast

import torch
from torch import nn

from tensorrt_llm._utils import get_sm_version
from tensorrt_llm.logger import logger
from tensorrt_llm.mapping import Mapping

from ..attention_backend import (AttentionInputType, AttentionMetadata,
                                 TrtllmAttention, TrtllmAttentionMetadata)
from ..attention_backend.interface import (AttentionMask,
                                           PositionalEmbeddingParams,
                                           PredefinedAttentionMask)
from ..attention_backend.utils import create_attention, get_attention_backend
from ..distributed import AllReduceParams
from ..model_config import ModelConfig
from ..peft.lora.layer import LoraLayer, LoraModuleType
from ..utils import Fp4QuantizedTensor, get_model_extra_attrs
from .linear import Linear, TensorParallelMode, WeightMode, WeightsLoadingConfig
from .multi_stream_utils import maybe_execute_in_parallel
from .rms_norm import RMSNorm
from .rotary_embedding import RotaryEmbedding


class Attention(nn.Module):

    def __init__(
        self,
        *,
        hidden_size: int,
        num_attention_heads: int,
        num_key_value_heads: int,
        max_position_embeddings: int,
        bias: bool,
        pos_embd_params: Optional[PositionalEmbeddingParams] = None,
        rope_fusion: Optional[bool] = None,
        layer_idx: Optional[int] = None,
        dtype: torch.dtype = None,
        dense_bias: Optional[bool] = None,
        config: Optional[ModelConfig] = None,
        q_scaling: float = 1.0,
        attention_chunk_size: Optional[int] = None,
    ):
        """
        Initialize the Attention module.

        Args:
            hidden_size (int): The size of the hidden dimension.
            num_attention_heads (int): The number of attention heads.
            num_key_value_heads (int): The number of key value heads.
            max_position_embeddings (int): The maximum position embeddings.
            bias (bool): Whether to use bias in the linear layers.
            pos_embd_params (Optional[PositionalEmbeddingParams]): The positional embedding parameters.
            rope_fusion (Optional[bool]): Whether to fuse RoPE into the attention OP and skip
                applying unfused RoPE. If None, whether to fuse is decided by the capability
                of the attention backend.
            layer_idx (Optional[int]): The layer index.
            dtype (torch.dtype): The data type.
            dense_bias (Optional[bool]): Whether to use bias in the output projection layer.
            config (Optional[ModelConfig]): The model configuration.
            q_scaling (float): The scaling factor applied to qk_scale. The definition is
                $O = softmax(QK^T * qk_scale) * V$, where
                $qk_scale = 1 / (sqrt(head_dim) * q_scaling)$. The default value is 1.0.
            attention_chunk_size (Optional[int]): See [Chunked Attention] below.
        """
        super().__init__()
        self.layer_idx = layer_idx

        config = config or ModelConfig()
        self.hidden_size = hidden_size
        self.num_heads = num_attention_heads
        self.head_dim = getattr(config.pretrained_config, 'head_dim', None)
        if not isinstance(self.head_dim, int):
            self.head_dim = self.hidden_size // self.num_heads
        self.num_key_value_heads = num_key_value_heads
        self.num_key_value_groups = self.num_heads // self.num_key_value_heads
        self.max_position_embeddings = max_position_embeddings
        self.pos_embd_params = pos_embd_params
        self.dense_bias = dense_bias
        self.q_scaling = q_scaling

        # [Chunked Attention]
        # Chunked attention is applied to context requests only. Chunked attention will be
        # applied when this field is specified and mMaskType == CAUSAL.
        #
        # In chunked attention, we break context requests into chunks of a specified size.
        # Tokens can only attend to tokens in the same chunk. So, for example, if the chunk
        # size is 3, we might have a mask that looks like this:
        #
        # 1 0 0 0 0 0
        # 1 1 0 0 0 0
        # 1 1 1 0 0 0
        # 0 0 0 1 0 0
        # 0 0 0 1 1 0
        # 0 0 0 1 1 1
        self.attention_chunk_size = attention_chunk_size

        if dense_bias is None:
            self.dense_bias = bias

        # tensor parallel
        tp_size = config.mapping.tp_size
        pp_size = config.mapping.pp_size
        if config.mapping.enable_attention_dp:
            tp_size = 1
        mapping = Mapping(
            world_size=tp_size * pp_size,
            tp_size=tp_size,
            pp_size=pp_size,
            rank=config.mapping.rank,
            gpus_per_node=config.mapping.gpus_per_node,
            enable_attention_dp=config.mapping.enable_attention_dp,
        )

        assert self.num_heads % tp_size == 0
        self.num_heads = self.num_heads // tp_size
        self.num_key_value_heads = (self.num_key_value_heads + tp_size -
                                    1) // tp_size
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_key_value_heads * self.head_dim

        self.qkv_proj = Linear(
            self.hidden_size,
            tp_size * self.q_size + 2 * tp_size * self.kv_size,
            bias=bias,
            dtype=dtype,
            mapping=mapping,
            tensor_parallel_mode=TensorParallelMode.COLUMN,
            weights_loading_config=WeightsLoadingConfig(
                weight_mode=WeightMode.FUSED_QKV_LINEAR),
            quant_config=config.get_quant_config(),
            skip_create_weights_in_init=config.skip_create_weights_in_init,
            allreduce_strategy=config.allreduce_strategy,
            force_dynamic_quantization=config.force_dynamic_quantization)
        self.o_lora = LoraLayer([LoraModuleType.ATTENTION_DENSE],
                                [self.hidden_size])

        self.o_proj = Linear(
            tp_size * self.q_size,
            self.hidden_size,
            bias=self.dense_bias,
            dtype=dtype,
            mapping=mapping,
            tensor_parallel_mode=TensorParallelMode.ROW,
            quant_config=config.get_quant_config(),
            skip_create_weights_in_init=config.skip_create_weights_in_init,
            lora=self.o_lora,
            allreduce_strategy=config.allreduce_strategy,
            force_dynamic_quantization=config.force_dynamic_quantization)

        self.quant_config = config.get_quant_config()
        self.attn_backend = config.attn_backend
        attn_cls = get_attention_backend(self.attn_backend)

        # These two modules are mutually exclusive - either splitted_qkv_lora or fused_qkv_lora
        # will be used, but never both at the same time. splitted_qkv_lora handles Q, K, V
        # separately while fused_qkv_lora handles them as a single fused operation.
        self.splitted_qkv_lora = LoraLayer([
            LoraModuleType.ATTENTION_Q, LoraModuleType.ATTENTION_K,
            LoraModuleType.ATTENTION_V
        ], [self.q_size, self.kv_size, self.kv_size])
        self.fused_qkv_lora = LoraLayer([LoraModuleType.ATTENTION_QKV],
                                        [self.q_size + 2 * self.kv_size])

        # Whether to fuse RoPE into the attention OP.
        # If true, RoPE will be applied in self.attn.forward.
        # If false, RoPE will be applied in self.apply_rope.
        self.rope_fusion = rope_fusion
        if self.rope_fusion and not attn_cls.support_fused_rope():
            logger.warning(
                "rope_fusion is true but the attention backend does not support it. Will disable rope_fusion."
            )
            self.rope_fusion = False
        # If rope_fusion is not specified, enable it if the attention backend supports it.
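        # Illustrative summary (comment only) of how rope_fusion is resolved, assuming
        # only the support_fused_rope() capability check used above and below:
        #   rope_fusion=True,  backend supports fused RoPE -> fused, applied inside self.attn.forward
        #   rope_fusion=True,  backend lacks fused RoPE    -> warning above, falls back to unfused
        #   rope_fusion=False                              -> unfused, applied by self.rotary_emb in apply_rope
        #   rope_fusion=None                               -> follows attn_cls.support_fused_rope()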
if self.rope_fusion is None: self.rope_fusion = attn_cls.support_fused_rope() self.rotary_emb = None if not self.rope_fusion and self.pos_embd_params is not None: self.rotary_emb = RotaryEmbedding( self.pos_embd_params.rope, head_dim=self.head_dim, is_neox=self.pos_embd_params.is_neox, ) self.attn = create_attention( self.attn_backend, self.layer_idx, self.num_heads, self.head_dim, self.num_key_value_heads, pos_embd_params=self.pos_embd_params if self.rope_fusion else None, quant_config=self.quant_config, skip_create_weights_in_init=config.skip_create_weights_in_init, q_scaling=self.q_scaling, attention_chunk_size=self.attention_chunk_size, ) self.support_fused_qkv = self.attn.support_fused_qkv() self.support_nvfp4_output = self.attn.support_nvfp4_output() if not config.skip_create_weights_in_init: self.create_weights() def create_weights(self): # self.attn has no weights but has states that are related to quant_config, # which could be modified after __init__ self.attn.update_quant_config(self.quant_config) def split_qkv(self, q, k=None, v=None): if k is None and v is None: q, k, v = q.split([self.q_size, self.kv_size, self.kv_size], dim=-1) return q, k, v def convert_qkv(self, q, k, v): if k is None and v is None and not self.support_fused_qkv: q, k, v = self.split_qkv(q) elif k is not None and v is not None and self.support_fused_qkv: qkv = torch.concat([q, k, v], dim=-1) q, k, v = qkv, None, None return q, k, v def forward( self, position_ids: Optional[torch.IntTensor], hidden_states: Union[torch.Tensor, Fp4QuantizedTensor], attn_metadata: AttentionMetadata, attention_mask: AttentionMask = PredefinedAttentionMask.CAUSAL, mrope_config: Optional[dict] = None, all_reduce_params: Optional[AllReduceParams] = None, lora_params: Optional[dict] = None, attention_window_size: Optional[int] = None, attention_mask_data: Optional[torch.Tensor] = None, **kwargs, ) -> torch.Tensor: """ Forward pass for the Attention module. Args: position_ids (Optional[torch.IntTensor]): The position IDs. hidden_states (torch.Tensor): The hidden states. attn_metadata (AttentionMetadata): The attention metadata. attention_mask (AttentionMask): The attention mask type. mrope_config (Optional[dict]): The MROPE configuration. all_reduce_params (Optional[AllReduceParams]): The all reduce parameters. lora_params (Optional[dict]): The LoRA parameters. attention_window_size (Optional[int]): The attention window size. attention_mask_data (Optional[torch.Tensor]): The attention mask data. Returns: torch.Tensor: The output tensor. 
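
        Example (illustrative sketch only; `model_config`, `pos_ids`, `hidden`, and
        `metadata` are assumed to be constructed elsewhere and are not provided by
        this module):

            attn = Attention(hidden_size=4096,
                             num_attention_heads=32,
                             num_key_value_heads=8,
                             max_position_embeddings=8192,
                             bias=False,
                             layer_idx=0,
                             dtype=torch.bfloat16,
                             config=model_config)
            out = attn(position_ids=pos_ids,
                       hidden_states=hidden,
                       attn_metadata=metadata)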
""" qkv = self.qkv_proj(hidden_states) if bool(lora_params): qkv_lora = self.splitted_qkv_lora(hidden_states, lora_params, self.layer_idx) if qkv_lora is not None: qkv = qkv + qkv_lora qkv_lora = self.fused_qkv_lora(hidden_states, lora_params, self.layer_idx) if qkv_lora is not None: qkv = qkv + qkv_lora q, k, v = qkv, None, None q, k, v = self.apply_rope(q, k, v, position_ids) out_scale = None out_scale_sf = None if self.o_proj.has_fp8_qdq or self.o_proj.has_nvfp4 or self.o_proj.has_fp8_block_scales or self.o_proj.has_fp8_rowwise: out_scale = self.o_proj.inv_input_scale if self.o_proj.has_nvfp4 and self.support_nvfp4_output: out_scale_sf = self.o_proj.input_scale q, k, v = self.convert_qkv(q, k, v) attn_output = self.attn.forward( q, k, v, attn_metadata, out_scale=out_scale, out_scale_sf=out_scale_sf, attention_mask=attention_mask, mrope_config=mrope_config, attention_window_size=attention_window_size, attention_mask_data=attention_mask_data) hidden_states = attn_output attn_output = self.o_proj(attn_output, all_reduce_params=all_reduce_params, lora_params=lora_params, layer_idx=self.layer_idx) return attn_output def apply_rope(self, q: torch.Tensor, k: Optional[torch.Tensor], v: Optional[torch.Tensor], position_ids: torch.Tensor): """ Apply RoPE to the query and key. Depending on the implementation, q, k, v could be either fused (q, k, v = concat(q, k, v), None, None) or unfused (none of q, k, v is None). Before self.attn.forward, convert_qkv will be called to make sure that the format of (q, k, v) satisfies the requirement of self.attn. This method could be overridden in the subclass, in which extra functionalities such as q_norm/k_norm could be added. Args: q (torch.Tensor): The query tensor. k (Optional[torch.Tensor]): The key tensor. v (Optional[torch.Tensor]): The value tensor. position_ids (torch.Tensor): The position IDs of each token for RoPE. Returns: tuple: A tuple of (q, k, v). """ q, k, v = self.split_qkv(q, k, v) # If RoPE is fused into the attention OP, do not apply RoPE here. 
if not self.rope_fusion and position_ids is not None: q, k = self.rotary_emb(position_ids, [q, k]) return q, k, v def extract_extra_attrs(layer_idx: str): extra_attrs = get_model_extra_attrs() assert extra_attrs is not None, "Model extra attrs is not set" metadata_ref = extra_attrs.get("attention_metadata", None) assert metadata_ref is not None, "Attention metadata is not set" metadata = metadata_ref() assert isinstance( metadata, TrtllmAttentionMetadata, ) mla_layers = extra_attrs.get("mla_layers", None) assert mla_layers is not None, "MLA layers is not registered" mla_layer_ref = mla_layers.get(layer_idx, None) assert mla_layer_ref is not None, f"Cannot find MLA layer for layer {layer_idx}" mla_layer = mla_layer_ref() assert isinstance( mla_layer, MLA), "MLA layer must be a subclass of MLA or an instance of MLA" return metadata, mla_layer @torch.library.custom_op("trtllm::mla_custom_op_inplace", mutates_args=("output", )) def mla_custom_op_inplace( hidden_states: torch.Tensor, position_ids: Optional[torch.Tensor], layer_idx: str, output: torch.Tensor, ) -> None: metadata, mla_layer = extract_extra_attrs(layer_idx) mla_layer.forward_impl(position_ids, hidden_states, metadata, output=output) def fp8_block_scaling_bmm_out( mat1: torch.Tensor, mat2_fp8: torch.Tensor, mat2_scale: torch.Tensor, out: torch.Tensor, ) -> torch.Tensor: sm_version = get_sm_version() if sm_version == 90 or sm_version == 89: mat1_fp8, mat1_scale = torch.ops.trtllm.fp8_batched_quantize_1x128_permute102( mat1) torch.ops.trtllm.fp8_block_scaling_bmm_out(mat1_fp8, mat2_fp8, mat1_scale, mat2_scale, out) elif sm_version == 100: low_latency = True use_deep_seek_fp8 = True tile_size = 8 epilogue_tile_m = 64 if use_deep_seek_fp8 else 128 m_size = mat1.shape[0] if m_size % tile_size != 0: tiled_shape = ((m_size + tile_size - 1) // tile_size) * tile_size mat1 = torch.nn.functional.pad( mat1, (0, 0, 0, 0, 0, tiled_shape - m_size), "constant", 0) mat1_fp8, mat1_scale = torch.ops.trtllm.fp8_batched_quantize_1x128_permute102( mat1) output, output_sf = torch.ops.trtllm.fp8_batched_gemm_trtllmgen( mat1_fp8, mat2_fp8, tile_size=tile_size, epilogue_tile_m=epilogue_tile_m, use_deep_seek_fp8=use_deep_seek_fp8, low_latency=low_latency, dq_sfs_a=mat1_scale.reshape(mat1.shape[-1] // 128, -1), dq_sfs_b=mat2_scale, out_dtype=out.dtype, ) out.copy_(output[:, :m_size]) else: raise NotImplementedError(f"SM{sm_version} is not supported") class MLA(nn.Module): def __init__( self, *, hidden_size: int, num_attention_heads: int, num_key_value_heads: int, qk_nope_head_dim: int, qk_rope_head_dim: int, v_head_dim: int, q_lora_rank: int, kv_lora_rank: int, predicted_tokens_per_seq: int, max_position_embeddings: int, bias: bool, aux_stream: Optional[torch.cuda.Stream] = None, pos_embd_params: Optional[PositionalEmbeddingParams] = None, layer_idx: Optional[int] = None, dtype: torch.dtype = None, dense_bias: Optional[bool] = None, config: Optional[ModelConfig] = None, ): """ Initialize the MLA module. Args: hidden_size (int): The size of the hidden dimension. num_attention_heads (int): The number of attention heads. num_key_value_heads (int): The number of key value heads. qk_nope_head_dim (int): The dimension of the query and key without Rope. qk_rope_head_dim (int): The dimension of the Rope of query and key. v_head_dim (int): The dimension of the value. q_lora_rank (int): The dimension of the compressed query. kv_lora_rank (int): The dimension of the compressed key and value. predicted_tokens_per_seq (int): The number of predicted tokens per sequence. 
max_position_embeddings (int): The maximum position embeddings. bias (bool): Whether to use bias in the linear layers. aux_stream (Optional[torch.cuda.Stream]): The auxiliary CUDA stream for running operations in two parallel streams. pos_embd_params (PositionalEmbeddingParams): The positional embedding parameters. layer_idx (int): The layer index. dtype (torch.dtype): The data type. dense_bias (bool): Whether to use bias in the output projection layer. config (ModelConfig): The model configuration. """ super().__init__() self.layer_idx = layer_idx self.layer_idx_str = str(layer_idx) self.dtype = dtype self.hidden_size = hidden_size self.num_heads = num_attention_heads self.num_key_value_heads = num_key_value_heads self.num_key_value_groups = self.num_heads // self.num_key_value_heads self.qk_nope_head_dim = qk_nope_head_dim self.qk_rope_head_dim = qk_rope_head_dim self.qk_head_dim = qk_nope_head_dim + qk_rope_head_dim self.v_head_dim = v_head_dim self.q_lora_rank = q_lora_rank self.kv_lora_rank = kv_lora_rank self.predicted_tokens_per_seq = predicted_tokens_per_seq self.max_position_embeddings = max_position_embeddings self.pos_embd_params = pos_embd_params self.dense_bias = dense_bias if dense_bias is None: self.dense_bias = bias if self.q_lora_rank is None: self.q_lora_rank = hidden_size self.is_lite = True else: self.is_lite = False assert pos_embd_params is not None, "pos_embd_params must be provided in MLA" self.register_to_config = False if config is not None: if "mla_layers" not in config.extra_attrs: config.extra_attrs["mla_layers"] = {} config.extra_attrs["mla_layers"][self.layer_idx_str] = weakref.ref( self) self.register_to_config = True # tensor parallel config = config or ModelConfig() tp_size = config.mapping.tp_size pp_size = config.mapping.pp_size if config.mapping.enable_attention_dp: tp_size = 1 mapping = Mapping( world_size=tp_size * pp_size, tp_size=tp_size, pp_size=pp_size, rank=config.mapping.rank, gpus_per_node=config.mapping.gpus_per_node, enable_attention_dp=config.mapping.enable_attention_dp, ) assert self.num_heads % tp_size == 0 self.num_heads = self.num_heads // tp_size self.num_key_value_heads = (self.num_key_value_heads + tp_size - 1) // tp_size rms_norm_eps = config.pretrained_config.rms_norm_eps quant_config = config.get_quant_config() self.quant_config = quant_config if not self.is_lite: self.fused_a = Linear( hidden_size, self.q_lora_rank + self.kv_lora_rank + self.qk_rope_head_dim, bias=bias, dtype=dtype, quant_config=quant_config, skip_create_weights_in_init=config.skip_create_weights_in_init, use_custom_cublas_mm=True, force_dynamic_quantization=config.force_dynamic_quantization) self.q_a_layernorm = RMSNorm(hidden_size=self.q_lora_rank, eps=rms_norm_eps, dtype=dtype) self.q_b_proj = Linear( self.q_lora_rank, tp_size * self.num_heads * self.qk_head_dim, bias=bias, dtype=dtype, mapping=mapping, tensor_parallel_mode=TensorParallelMode.COLUMN, quant_config=quant_config, skip_create_weights_in_init=config.skip_create_weights_in_init, allreduce_strategy=config.allreduce_strategy, force_dynamic_quantization=config.force_dynamic_quantization) else: self.fused_a = Linear( hidden_size, self.kv_lora_rank + self.qk_rope_head_dim, bias=bias, dtype=dtype, quant_config=quant_config, skip_create_weights_in_init=config.skip_create_weights_in_init, use_custom_cublas_mm=True, force_dynamic_quantization=config.force_dynamic_quantization) self.q_proj = Linear( self.q_lora_rank, tp_size * self.num_heads * self.qk_head_dim, bias=bias, dtype=dtype, mapping=mapping, 
tensor_parallel_mode=TensorParallelMode.COLUMN, quant_config=quant_config, skip_create_weights_in_init=config.skip_create_weights_in_init, allreduce_strategy=config.allreduce_strategy, force_dynamic_quantization=config.force_dynamic_quantization) self.q_b_proj = self.q_proj self.kv_a_layernorm = RMSNorm(hidden_size=kv_lora_rank, dtype=dtype, eps=rms_norm_eps) self.kv_b_proj = Linear( self.kv_lora_rank, tp_size * self.num_heads * (self.qk_nope_head_dim + self.v_head_dim), bias=bias, dtype=dtype, mapping=mapping, tensor_parallel_mode=TensorParallelMode.COLUMN, quant_config=quant_config, skip_create_weights_in_init=config.skip_create_weights_in_init, allreduce_strategy=config.allreduce_strategy, force_dynamic_quantization=config.force_dynamic_quantization) # This parameter will view into self.kv_b_proj.weight after loading weights. # For dummy weight initialization, this parameter is initialized with empty tensor. # Used in forward_generation only self.v_b_proj = nn.Parameter( torch.empty( (self.num_heads, self.v_head_dim, self.kv_lora_rank), dtype=dtype, ), requires_grad=False, ) self.o_proj = Linear( self.num_key_value_heads * self.v_head_dim * tp_size, self.hidden_size, bias=self.dense_bias, dtype=dtype, mapping=mapping, tensor_parallel_mode=TensorParallelMode.ROW, quant_config=quant_config, skip_create_weights_in_init=config.skip_create_weights_in_init, allreduce_strategy=config.allreduce_strategy, force_dynamic_quantization=config.force_dynamic_quantization) def yarn_get_mscale(scale=1, mscale=1): if scale <= 1: return 1.0 return 0.1 * mscale * math.log(scale) + 1.0 mscale_all_dim = pos_embd_params.rope.mscale_all_dim scaling_factor = pos_embd_params.rope.scale mscale = yarn_get_mscale(scaling_factor, mscale_all_dim) q_scaling = 1.0 / (mscale * mscale) self.mha = create_attention( config.attn_backend, self.layer_idx, self.num_heads, head_dim=self.qk_head_dim, num_kv_heads=self.num_key_value_heads, pos_embd_params=pos_embd_params, quant_config=quant_config, q_scaling=q_scaling, is_mla_enable=True, q_lora_rank=self.q_lora_rank, kv_lora_rank=self.kv_lora_rank, qk_nope_head_dim=self.qk_nope_head_dim, qk_rope_head_dim=self.qk_rope_head_dim, v_head_dim=self.v_head_dim, predicted_tokens_per_seq=self.predicted_tokens_per_seq, skip_create_weights_in_init=config.skip_create_weights_in_init, ) self.mqa = create_attention( config.attn_backend, self.layer_idx, self.num_heads, head_dim=self.kv_lora_rank + self.qk_rope_head_dim, num_kv_heads=1, pos_embd_params=pos_embd_params, quant_config=quant_config, q_scaling=q_scaling, is_mla_enable=True, q_lora_rank=self.q_lora_rank, kv_lora_rank=self.kv_lora_rank, qk_nope_head_dim=self.qk_nope_head_dim, qk_rope_head_dim=self.qk_rope_head_dim, v_head_dim=self.kv_lora_rank, predicted_tokens_per_seq=self.predicted_tokens_per_seq, skip_create_weights_in_init=config.skip_create_weights_in_init, ) self.aux_stream = aux_stream self.ln_events = [torch.cuda.Event(), torch.cuda.Event()] self.rope_fusion = self.mha.support_fused_rope() self.support_fused_qkv = self.mha.support_fused_qkv() self.rotary_emb = None self.apply_rotary_emb = not self.rope_fusion if self.apply_rotary_emb: self.rotary_emb = RotaryEmbedding( pos_embd_params.rope, head_dim=self.qk_rope_head_dim, is_neox=pos_embd_params.is_neox, ) if not config.skip_create_weights_in_init: self.create_weights() def create_weights(self): # self.mha/mqa has no weights but has states that are related to quant_config, # which could be modified after __init__ self.mha.update_quant_config(self.quant_config) 
self.mqa.update_quant_config(self.quant_config) # k_b_proj_trans's dtype must be consistent with self.kv_b_proj, # which can be modified after __init__ has_fp8_block_scales = ( self.kv_b_proj.quant_config and self.kv_b_proj.quant_config.quant_mode.has_fp8_block_scales()) mla_weight_dtype = torch.float8_e4m3fn if has_fp8_block_scales else self.dtype self.k_b_proj_trans = nn.Parameter( torch.empty( (self.num_heads, self.kv_lora_rank, self.qk_nope_head_dim), dtype=mla_weight_dtype, ), requires_grad=False, ) if has_fp8_block_scales: self.k_b_proj_trans_scale = nn.Parameter( torch.empty( ( self.num_heads, self.kv_lora_rank // 128, self.qk_nope_head_dim // 128, ), dtype=torch.float32, ), requires_grad=False, ) # This parameter will view into self.kv_b_proj.weight_scale after loading weights. # For dummy weight initialization, this parameter is initialized with empty tensor. self.v_b_proj_scale = nn.Parameter( torch.empty( ( self.num_heads, self.v_head_dim // 128, self.kv_lora_rank // 128, ), dtype=torch.float32, ), requires_grad=False, ) else: self.k_b_proj_trans_scale = None self.v_b_proj_scale = None def apply_rope( self, q: torch.Tensor, k_pe: torch.Tensor, position_ids: torch.Tensor, ) -> torch.Tensor: q = q.view(-1, self.num_heads, self.qk_head_dim) q_pe = q[..., self.qk_nope_head_dim:].reshape( -1, self.num_heads * self.qk_rope_head_dim) q_pe, k_pe = self.rotary_emb(position_ids, [q_pe, k_pe]) q[..., self.qk_nope_head_dim:] = q_pe.view(-1, self.num_heads, self.qk_rope_head_dim) return k_pe def create_output(self, hidden_states: torch.Tensor): num_tokens = hidden_states.shape[0] hidden_size = self.o_proj.in_features return hidden_states.new_empty([num_tokens, hidden_size], dtype=hidden_states.dtype) def forward_impl(self, position_ids: Optional[torch.Tensor], hidden_states: torch.Tensor, attn_metadata: AttentionMetadata, output: Optional[torch.Tensor] = None) -> torch.Tensor: """ Forward pass for the MLA module. Args: position_ids (Optional[torch.IntTensor]): The position IDs. hidden_states (torch.Tensor): The hidden states. attn_metadata (AttentionMetadata): The attention metadata. all_reduce_params (Optional[AllReduceParams]): The all reduce parameters. Returns: torch.Tensor: The output tensor. """ if self.is_lite: compressed_kv, k_pe = self.fused_a(hidden_states).split( [self.kv_lora_rank, self.qk_rope_head_dim], -1) compressed_kv = self.kv_a_layernorm(compressed_kv) q = hidden_states else: q, compressed_kv, k_pe = self.fused_a(hidden_states).split( [self.q_lora_rank, self.kv_lora_rank, self.qk_rope_head_dim], -1) q, compressed_kv = maybe_execute_in_parallel( lambda: self.q_a_layernorm(q), lambda: self.kv_a_layernorm(compressed_kv), self.ln_events[0], self.ln_events[1], self.aux_stream, ) q, latent_cache = maybe_execute_in_parallel( lambda: self.q_b_proj(q), lambda: torch.concat([compressed_kv, k_pe], dim=-1), self.ln_events[0], self.ln_events[1], self.aux_stream, ) # split q, k, v into context and gen batches num_contexts = attn_metadata.num_contexts num_generations = attn_metadata.num_generations num_ctx_tokens = attn_metadata.num_ctx_tokens num_tokens = attn_metadata.num_tokens assert q.shape[ 0] == num_tokens, f"Expect q.shape[0] to be {num_tokens}, but got {q.shape[0]}" if num_contexts > 0: q_ctx = q[:num_ctx_tokens, ...] compressed_kv_ctx = compressed_kv[:num_ctx_tokens, ...] k_pe_ctx = k_pe[:num_ctx_tokens, ...] latent_cache_ctx = latent_cache[:num_ctx_tokens, ...] 
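            # Expected shapes at this point (illustrative; num_heads is the
            # per-TP-rank head count):
            #   q_ctx:             [num_ctx_tokens, num_heads * qk_head_dim]
            #   compressed_kv_ctx: [num_ctx_tokens, kv_lora_rank]
            #   k_pe_ctx:          [num_ctx_tokens, qk_rope_head_dim]
            #   latent_cache_ctx:  [num_ctx_tokens, kv_lora_rank + qk_rope_head_dim]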
if self.apply_rotary_emb: assert position_ids is not None k_pe_ctx = self.apply_rope(q_ctx, k_pe_ctx, position_ids) attn_output_context = self.forward_context( q_ctx, compressed_kv_ctx, k_pe_ctx, attn_metadata, latent_cache_ctx, output=output if num_generations == 0 else None) if num_generations == 0: return attn_output_context else: attn_output_context = None if num_generations > 0: q_gen = q[num_ctx_tokens:, ...] compressed_kv_gen = compressed_kv[num_ctx_tokens:, ...] k_pe_gen = k_pe[num_ctx_tokens:, ...] latent_cache_gen = latent_cache[num_ctx_tokens:, ...] if self.apply_rotary_emb: assert position_ids is not None k_pe_gen = self.apply_rope(q_gen, k_pe_gen, position_ids) attn_output_gen = self.forward_generation( q_gen, compressed_kv_gen, k_pe_gen, attn_metadata, latent_cache_gen, output=output if num_contexts == 0 else None) if num_contexts == 0: return attn_output_gen else: attn_output_gen = None # release pytorch activation memory q = None compressed_kv = None k_pe = None assert attn_output_context is not None and attn_output_gen is not None assert ( len(attn_output_context.shape) == 2 ), f"attn_output_context must be rank 2, not {len(attn_output_context.shape)}" assert ( len(attn_output_gen.shape) == 2 ), f"attn_output_gen must be rank 2, not {len(attn_output_gen.shape)}" output = output if output is not None else torch.empty( (num_tokens, attn_output_context.shape[1]), dtype=attn_output_context.dtype, device=attn_output_context.device) output[:attn_output_context.shape[0], :] = attn_output_context output[attn_output_context.shape[0]:, :] = attn_output_gen attn_output_context = None attn_output_gen = None return output def _maybe_concat_qkv(self, q, k, v): if k is not None and v is not None and self.support_fused_qkv: qkv = torch.concat([q, k, v], dim=-1) q, k, v = qkv, None, None return q, k, v def forward_context_default( self, q: torch.Tensor, compressed_kv: torch.Tensor, k_pe: torch.Tensor, attn_metadata: AttentionMetadata, latent_cache: Optional[torch.Tensor] = None, output: Optional[torch.Tensor] = None) -> torch.Tensor: kv = self.kv_b_proj(compressed_kv) k_nope, v = kv.split( [ self.num_heads * self.qk_nope_head_dim, self.num_heads * self.v_head_dim ], -1, ) k = torch.empty_like(q).view(-1, self.num_heads, self.qk_head_dim) k[..., :self.qk_nope_head_dim] = k_nope.view(-1, self.num_heads, self.qk_nope_head_dim) if self.apply_rotary_emb: k[..., self.qk_nope_head_dim:] = k_pe.view(-1, 1, self.qk_rope_head_dim) k = k.view(-1, self.num_heads * self.qk_head_dim) # May concat q(including q_pe), k + k_pe, v together q, k, v = self._maybe_concat_qkv(q, k, v) # out_scale = getattr(self.o_proj, "inv_input_scale", None) out_scale = None # Currently we use BF16 MHA for context phase attn_output = self.mha.forward( q, k, v, attn_metadata, attention_input_type=AttentionInputType.context_only, latent_cache=latent_cache, out_scale=out_scale, output=output, ) return attn_output def forward_context_with_cached_kv( self, q: torch.Tensor, latent_cache: torch.Tensor, attn_metadata: AttentionMetadata, output: Optional[torch.Tensor] = None, ) -> torch.Tensor: assert latent_cache is not None trtllm_attention = cast(TrtllmAttention, self.mha) # apply RoPE, append compressed_kv + k_pe to paged kv cache and assign q_pe to q trtllm_attention.mla_rope_append_paged_kv_assign_q( q, latent_cache, attn_metadata) # copy full_compressed_kv and full_k_pe from paged kv cache full_compressed_kv, full_k_pe = trtllm_attention.load_paged_kv_cache_for_mla( attn_metadata, q.dtype) assert full_compressed_kv.shape[ 0] == 
attn_metadata.num_ctx_cached_tokens + attn_metadata.num_ctx_tokens assert full_compressed_kv.shape[1] == self.kv_lora_rank assert full_k_pe.shape[ 0] == attn_metadata.num_ctx_cached_tokens + attn_metadata.num_ctx_tokens assert full_k_pe.shape[1] == self.qk_rope_head_dim assert full_compressed_kv.is_contiguous() assert full_k_pe.is_contiguous() # compute full_k_nope and full_v from full_compressed_kv full_kv = self.kv_b_proj(full_compressed_kv) full_k_nope, full_v = full_kv.split( [ self.num_heads * self.qk_nope_head_dim, self.num_heads * self.v_head_dim ], -1, ) full_k_nope = full_k_nope.view(-1, self.num_heads, self.qk_nope_head_dim) full_v = full_v.view(-1, self.num_heads, self.v_head_dim) # build paged_full_kv tokens_per_block = attn_metadata.kv_cache_manager.tokens_per_block # paged_full_kv will be initialized to 0 in the kernel to avoid NaN paged_full_kv = torch.empty([ attn_metadata.num_contexts, 2, (attn_metadata.max_ctx_kv_len + tokens_per_block - 1) // tokens_per_block, self.num_heads, tokens_per_block, max(self.qk_nope_head_dim + self.qk_rope_head_dim, self.v_head_dim) ], dtype=q.dtype, device=q.device) mla_context_kv_cache_block_offsets = trtllm_attention.set_paged_kv_cache_for_mla( paged_full_kv, full_k_nope, full_v, full_k_pe, attn_metadata, ) # release pytorch activation memory full_compressed_kv = None full_k_pe = None full_kv = None full_k_nope = None full_v = None # out_scale = getattr(self.o_proj, "inv_input_scale", None) out_scale = None # Currently we use BF16 MHA for context phase attn_output = self.mha.forward( q, None, None, attn_metadata, attention_input_type=AttentionInputType.context_only, latent_cache=None, out_scale=out_scale, mla_context_paged_kv=paged_full_kv, mla_context_kv_cache_block_offsets= mla_context_kv_cache_block_offsets, output=output, ) return attn_output def forward_context_with_chunked_prefill( self, q: torch.Tensor, compressed_kv: torch.Tensor, latent_cache: torch. 
        Tensor,  # compressed_kv + k_pe [context_tokens, 1, lora_size + rope_size]
        attn_metadata: TrtllmAttentionMetadata,
        output: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        trtllm_attention = cast(TrtllmAttention, self.mha)
        # apply RoPE, append compressed_kv + k_pe to the paged kv cache and assign q_pe to q
        trtllm_attention.mla_rope_append_paged_kv_assign_q(
            q, latent_cache, attn_metadata)

        # determine the number of loops;
        # currently we assume that the chunk size is the same as max_num_tokens
        chunk_size = attn_metadata.runtime_features.chunk_size
        chunked_loop_num = attn_metadata.chunked_loop_num

        # [total_token_q, num_heads, 2] -> [total_token_q, num_heads] float2
        self.softmax_stats_tensor = torch.empty(
            (attn_metadata.num_ctx_tokens, self.num_heads, 2),
            dtype=torch.float,
            device='cuda',
        )
        self.temp_softmax_stats_tensor = torch.empty(
            (attn_metadata.num_ctx_tokens, self.num_heads, 2),
            dtype=torch.float,
            device='cuda',
        )
        if output is None:
            attn_output = q.new_empty(
                (q.size(0), self.num_heads * self.v_head_dim), dtype=q.dtype)
        else:
            attn_output = output
        temp_attn_output = q.new_empty(
            (q.size(0), self.num_heads * self.v_head_dim), dtype=q.dtype)

        # use fake cached_cu_seq_len for the chunked loop
        origin_kv_lens_cuda_runtime = attn_metadata.kv_lens_cuda_runtime
        origin_kv_lens_runtime = attn_metadata.kv_lens_runtime
        for loop_idx in range(chunked_loop_num):
            # {b, chunked_unit_size, h, kv_lora_rank + qk_rope_head_dim}, zero padded
            # fetch the `loop_idx`-th chunk from the kv cache
            temp_cu_chunked_seq_len = attn_metadata.cu_chunked_seq_len[loop_idx]
            total_ctx_chunked_tokens = attn_metadata.host_cu_chunked_seq_len[
                loop_idx, attn_metadata.num_contexts]
            chunked_compressed_kv, chunked_k_pe = trtllm_attention.load_chunked_kv_cache_for_mla(
                metadata=attn_metadata,
                chunked_idx=loop_idx,
                num_ctx_cached_tokens=total_ctx_chunked_tokens,
                cu_chunked_seq_len=temp_cu_chunked_seq_len,
                out_dtype=q.dtype)

            # up-project to uncompressed kv
            # [tokens, 2, h, kv_dim], without rope_dim
            chunked_kv = self.kv_b_proj(chunked_compressed_kv)

            # build full_kv
            # full_kv {B, 2, chunk_size / tokens_per_block, h, tokens_per_block, kv_dim + rope_dim}
            tokens_per_block = attn_metadata.kv_cache_manager.tokens_per_block
            full_kv = torch.zeros([
                attn_metadata.num_contexts, 2,
                (chunk_size + tokens_per_block - 1) // tokens_per_block,
                self.num_heads, tokens_per_block,
                max(self.qk_nope_head_dim + self.qk_rope_head_dim,
                    self.v_head_dim)
            ],
                                  dtype=q.dtype,
                                  device=q.device)
            mla_kv_cache_block_offsets = trtllm_attention.set_chunked_kv_cache_for_mla(
                full_kv,
                chunked_kv,
                chunked_k_pe,
                cu_chunked_seq_len=temp_cu_chunked_seq_len,
                cached=True,
                metadata=attn_metadata)

            # copy chunked_seq_len to replace kv_lens_runtime
            attn_metadata.kv_lens_runtime = attn_metadata.host_chunked_seq_len[
                loop_idx]
            attn_metadata.kv_lens_cuda_runtime = attn_metadata.chunked_seq_len[
                loop_idx]

            out_scale = None
            # do not apply a mask to the attention inside the loop
            temp_attn_output = self.mha.forward(
                q,
                None,
                None,
                attn_metadata,
                attention_input_type=AttentionInputType.context_only,
                latent_cache=None,
                out_scale=out_scale,
                attention_mask=PredefinedAttentionMask.FULL,
                mla_context_paged_kv=full_kv,
                mla_context_kv_cache_block_offsets=mla_kv_cache_block_offsets,
                softmax_stats_tensor=self.temp_softmax_stats_tensor,
                output=temp_attn_output,
            )

            # merge the attention result
            temp_merge_op = attn_metadata.merge_op_tensor[loop_idx]
            trtllm_attention.merge_attention_for_mla(
                attn_output, temp_attn_output, self.softmax_stats_tensor,
                self.temp_softmax_stats_tensor, temp_merge_op, attn_metadata)

        # deal with the uncached kv
        kv =
self.kv_b_proj(compressed_kv) _, k_pe = latent_cache.view([ -1, self.kv_lora_rank + self.qk_rope_head_dim ]).split([self.kv_lora_rank, self.qk_rope_head_dim], -1) k_pe = k_pe.contiguous() # final round of attention # out_scale = getattr(self.o_proj, "inv_input_scale", None) out_scale = None # Currently we use BF16 MHA for context phase tokens_per_block = attn_metadata.kv_cache_manager.tokens_per_block full_kv = torch.zeros([ attn_metadata.num_contexts, 2, (attn_metadata.max_ctx_seq_len + tokens_per_block - 1) // tokens_per_block, self.num_heads, tokens_per_block, max(self.qk_nope_head_dim + self.qk_rope_head_dim, self.v_head_dim) ], dtype=q.dtype, device=q.device) mla_kv_cache_block_offsets = trtllm_attention.set_chunked_kv_cache_for_mla( full_kv, kv, k_pe, cu_chunked_seq_len=None, cached=False, metadata=attn_metadata) # copy q_lens to replace kv_lens_runtime attn_metadata.kv_lens_runtime = attn_metadata.prompt_lens_cpu_runtime attn_metadata.kv_lens_cuda_runtime = attn_metadata.prompt_lens_cuda_runtime temp_attn_output = self.mha.forward( q, None, None, attn_metadata, attention_input_type=AttentionInputType.context_only, latent_cache=None, out_scale=out_scale, mla_context_paged_kv=full_kv, mla_context_kv_cache_block_offsets=mla_kv_cache_block_offsets, softmax_stats_tensor=self.temp_softmax_stats_tensor, output=temp_attn_output, ) temp_merge_op = attn_metadata.merge_op_tensor[chunked_loop_num] trtllm_attention.merge_attention_for_mla(attn_output, temp_attn_output, self.softmax_stats_tensor, self.temp_softmax_stats_tensor, temp_merge_op, attn_metadata) # copy back kv_lens_runtime and kv_lens_cuda_runtime attn_metadata.kv_lens_runtime = origin_kv_lens_runtime attn_metadata.kv_lens_cuda_runtime = origin_kv_lens_cuda_runtime return attn_output def forward_context( self, q: torch.Tensor, compressed_kv: torch.Tensor, k_pe: torch.Tensor, attn_metadata: AttentionMetadata, latent_cache: Optional[torch.Tensor] = None, output: Optional[torch.Tensor] = None, ) -> torch.Tensor: if isinstance(self.mha, TrtllmAttention): assert isinstance(attn_metadata, TrtllmAttentionMetadata) trtllm_attention = cast(TrtllmAttention, self.mha) if trtllm_attention.is_chunked_prefill_for_mla_context( attn_metadata): return self.forward_context_with_chunked_prefill( q, compressed_kv, latent_cache, attn_metadata, output) elif trtllm_attention.has_cached_kv_for_mla_context(attn_metadata): return self.forward_context_with_cached_kv( q, latent_cache, attn_metadata, output) return self.forward_context_default(q, compressed_kv, k_pe, attn_metadata, latent_cache, output) def forward_generation( self, q: torch.Tensor, compressed_kv: torch.Tensor, k_pe: torch.Tensor, attn_metadata: AttentionMetadata, latent_cache: Optional[torch.Tensor] = None, output: Optional[torch.Tensor] = None) -> torch.Tensor: num_tokens = q.shape[0] q_nope, q_pe = q.view([-1, self.num_heads, self.qk_head_dim]).split( [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1) # fused_q contains 1) the result of the following bmm with shape [num_tokens, num_heads, kv_lora_rank] # 2) rope(q_pe) with shape [num_tokens, num_heads, qk_rope_head_dim]. 
rope is applied inside AttentionOp fused_q = torch.empty( [ num_tokens, self.num_heads, (self.kv_lora_rank + self.qk_rope_head_dim) ], dtype=q.dtype, device=q.device, ) if self.k_b_proj_trans.dtype == torch.bfloat16: # [num_heads, num_tokens, self.qk_nope_head_dim] q_nope_t = q_nope.transpose(0, 1) # [num_heads, num_tokens, self.kv_lora_rank] q_nope_out = fused_q[..., :self.kv_lora_rank].transpose(0, 1) # [num_heads, num_tokens, self.qk_nope_head_dim] x [num_heads, kv_lora_rank, qk_nope_head_dim] # -> [num_heads, num_tokens, kv_lora_rank] -> [num_tokens, num_heads, kv_lora_rank] # The output of bmm is written directly into fused_q torch.ops.trtllm.bmm_out(q_nope_t, self.k_b_proj_trans.transpose(1, 2), q_nope_out) elif self.k_b_proj_trans.dtype == torch.float8_e4m3fn: # [num_heads, num_tokens, self.kv_lora_rank] q_nope_out = fused_q[..., :self.kv_lora_rank].transpose(0, 1) fp8_block_scaling_bmm_out(q_nope, self.k_b_proj_trans, self.k_b_proj_trans_scale, q_nope_out) else: raise NotImplementedError( f"Missing bmm impl for dtype: {self.k_b_proj_trans.dtype}.") if self.apply_rotary_emb: fused_q[..., self.kv_lora_rank:] = q_pe fused_q = fused_q.view([ num_tokens, self.num_heads * (self.kv_lora_rank + self.qk_rope_head_dim) ]) # out_scale = getattr(self.o_proj, "inv_input_scale", None) out_scale = None # Although we use FP8 MLA for generation phase, the output is still in BF16 attn_out_latent = self.mqa.forward( fused_q, None, None, attn_metadata, attention_input_type=AttentionInputType.generation_only, out_scale=out_scale, latent_cache=latent_cache, # kvcache and k_pe q_pe=q_pe, # used by `invokeMLARopeGeneration` ) fused_q = None assert (attn_out_latent.shape[0] == q.shape[0] and attn_out_latent.shape[1] == self.num_heads * self.kv_lora_rank) # [seq, num_heads, kv_lora_rank] attn_out_latent = attn_out_latent.view( [-1, self.num_heads, self.kv_lora_rank]) # [seq, num_heads * v_head_dim] output = output if output is not None else torch.empty( [num_tokens, self.num_heads * self.v_head_dim], dtype=attn_out_latent.dtype, device=attn_out_latent.device) attn_output = output.view([num_tokens, self.num_heads, self.v_head_dim]) if self.v_b_proj.dtype == torch.bfloat16: # [num_heads, seq, kv_lora_rank] x [num_heads, kv_lora_rank, v_head_dim] # -> [num_heads, seq, v_head_dim] torch.ops.trtllm.bmm_out(attn_out_latent.transpose(0, 1), self.v_b_proj.transpose(1, 2), attn_output.transpose(0, 1)) elif self.v_b_proj.dtype == torch.float8_e4m3fn: fp8_block_scaling_bmm_out(attn_out_latent, self.v_b_proj, self.v_b_proj_scale, attn_output.transpose(0, 1)) else: raise NotImplementedError( f"Missing bmm impl for dtype: {self.v_b_proj.dtype}.") return output def forward( self, position_ids: Optional[torch.Tensor], hidden_states: torch.Tensor, attn_metadata: AttentionMetadata, all_reduce_params: Optional[AllReduceParams] = None, ) -> torch.Tensor: attn_output = self.create_output(hidden_states) if self.register_to_config: torch.ops.trtllm.mla_custom_op_inplace(hidden_states, position_ids, self.layer_idx_str, attn_output) else: self.forward_impl(position_ids, hidden_states, attn_metadata, output=attn_output) attn_output = self.o_proj(attn_output, all_reduce_params=all_reduce_params) return attn_output
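

# The helper below is an illustrative sketch only (not used by this module): it
# materializes the block-diagonal causal mask described in the [Chunked Attention]
# comment of Attention.__init__, so the chunking behaviour can be checked by hand.
# The function name and the __main__ guard are additions for demonstration purposes.
def _example_chunked_attention_mask(seq_len: int,
                                    chunk_size: int) -> torch.Tensor:
    """Return a [seq_len, seq_len] 0/1 mask where token i may attend to token j
    only if j <= i and both positions fall into the same chunk of `chunk_size`."""
    pos = torch.arange(seq_len)
    # same_chunk[i, j] is True when i and j belong to the same chunk
    same_chunk = (pos[:, None] // chunk_size) == (pos[None, :] // chunk_size)
    # causal[i, j] is True when j <= i
    causal = pos[None, :] <= pos[:, None]
    return (same_chunk & causal).int()


if __name__ == "__main__":
    # Reproduces the 6x6 example with chunk size 3 shown in the comment
    # inside Attention.__init__.
    print(_example_chunked_attention_mask(6, 3))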