# tensorrt_llm/_torch/modules/attention.py

import math
import weakref
from typing import Optional, Union, cast
import torch
from torch import nn
from tensorrt_llm._utils import (get_sm_version, is_sm_100f, nvtx_range,
nvtx_range_debug)
from tensorrt_llm.logger import logger
from tensorrt_llm.mapping import Mapping
from ..attention_backend import (AttentionInputType, AttentionMetadata,
FlashInferAttentionMetadata, TrtllmAttention,
TrtllmAttentionMetadata)
from ..attention_backend.interface import (AttentionBackend, AttentionMask,
PositionalEmbeddingParams,
PredefinedAttentionMask)
from ..attention_backend.sparse.dsa import (
DSAtrtllmAttentionMetadata, transform_local_topk_and_prepare_pool_view)
from ..attention_backend.utils import create_attention, get_attention_backend
from ..distributed import AllReduceParams, alltoall_helix
from ..model_config import ModelConfig
from ..peft.lora.layer import LoraLayer, LoraModuleType
from ..utils import (Fp4QuantizedTensor, get_model_extra_attrs,
is_torch_compiling, maybe_compile)
from .linear import Linear, TensorParallelMode, WeightMode, WeightsLoadingConfig
from .multi_stream_utils import maybe_execute_in_parallel
from .rms_norm import RMSNorm
from .rotary_embedding import MRotaryEmbedding, RotaryEmbedding
# Import FlashMLA sparse attention kernel
try:
from tensorrt_llm.flash_mla import flash_mla_sparse_fwd
except ImportError:
flash_mla_sparse_fwd = None
def extract_extra_attrs(layer_idx: str, attn_type: str):
assert attn_type in ["mla", "attn"], "Invalid attention type"
extra_attrs = get_model_extra_attrs()
assert extra_attrs is not None, "Model extra attrs is not set"
metadata_ref = extra_attrs.get("attention_metadata", None)
assert metadata_ref is not None, "Attention metadata is not set"
metadata = metadata_ref()
if attn_type == "mla":
assert isinstance(
metadata,
TrtllmAttentionMetadata,
)
else:
assert isinstance(
metadata,
FlashInferAttentionMetadata,
) or isinstance(
metadata,
TrtllmAttentionMetadata,
)
attn_layers = extra_attrs.get(attn_type + "_layers", None)
assert attn_layers is not None, "Attention layer is not registered"
attn_layer_ref = attn_layers.get(layer_idx, None)
assert attn_layer_ref is not None, f"Cannot find attention layer for layer {layer_idx}"
attn_layer = attn_layer_ref()
if attn_type == "mla":
        assert isinstance(
            attn_layer,
            MLA), "The attention layer must be an instance of MLA (or a subclass)"
    elif attn_type == "attn":
        assert isinstance(
            attn_layer, Attention
        ), "The attention layer must be an instance of Attention (or a subclass)"
return metadata, attn_layer
@maybe_compile
def maybe_compiled_copy_(dst, src):
dst.copy_(src)
@maybe_compile
def maybe_compiled_cat(tensors, dim):
return torch.cat(tensors, dim)
@torch.library.custom_op("trtllm::attn_custom_op_inplace",
mutates_args=("output", ))
def attn_custom_op_inplace(
q: torch.Tensor,
k: Optional[torch.Tensor],
v: Optional[torch.Tensor],
attention_mask: str,
mrope_rotary_cos_sin: Optional[torch.Tensor],
mrope_position_deltas: Optional[torch.Tensor],
attention_window_size: Optional[int],
attention_mask_data: Optional[torch.Tensor],
attention_sinks: Optional[torch.Tensor],
layer_idx: str,
output: torch.Tensor,
) -> None:
metadata, attn_layer = extract_extra_attrs(layer_idx, "attn")
    # NVFP4 attention output is not supported under torch.compile for the TRTLLM backend.
attn_layer._attn_impl(q,
k,
v,
metadata,
PredefinedAttentionMask(attention_mask),
mrope_rotary_cos_sin,
mrope_position_deltas,
attention_window_size,
attention_mask_data,
enable_attn_nvfp4_output=False,
output=output,
attention_sinks=attention_sinks)
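
# Note: the op body resolves the target layer and its metadata at call time from the
# weakrefs registered in ModelConfig.extra_attrs (see extract_extra_attrs), so
# torch.compile can treat the whole attention call as a single opaque node that
# mutates ``output`` in place rather than tracing into the backend implementation.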
class Attention(nn.Module):
def __init__(
self,
*,
hidden_size: int,
num_attention_heads: int,
num_key_value_heads: int,
max_position_embeddings: int,
bias: bool,
pos_embd_params: Optional[PositionalEmbeddingParams] = None,
rope_fusion: Optional[bool] = None,
layer_idx: Optional[int] = None,
dtype: torch.dtype = None,
dense_bias: Optional[bool] = None,
config: Optional[ModelConfig] = None,
q_scaling: float = 1.0,
attention_chunk_size: Optional[int] = None,
disable_deep_gemm: bool = False,
attn_output_gate: Optional[bool] = None,
use_custom_cublas_mm: bool = False,
):
"""
Initialize the Attention module.
Args:
hidden_size (int): The size of the hidden dimension.
num_attention_heads (int): The number of attention heads.
num_key_value_heads (int): The number of key value heads.
max_position_embeddings (int): The maximum position embeddings.
bias (bool): Whether to use bias in the linear layers.
pos_embd_params (Optional[PositionalEmbeddingParams]): The positional embedding parameters.
rope_fusion (Optional[bool]): Whether to fuse RoPE into the attention OP and skip applying unfused RoPE. If None, whether to fuse is decided by the capability of the attention backend.
layer_idx (Optional[int]): The layer index.
dtype (torch.dtype): The data type.
dense_bias (Optional[bool]): Whether to use bias in the output projection layer.
config (Optional[ModelConfig]): The model configuration.
            q_scaling (float): Extra scaling factor applied to the attention logits. The definition is $O = softmax(QK^T * qk_scale) * V$ with $qk_scale = 1 / (sqrt(head_dim) * q_scaling)$. The default value is 1.0.
attention_chunk_size (Optional[int]): See [Chunked Attention] below.
disable_deep_gemm (bool): Whether to disable the use of DeepGEMM in Linear layers (currently only matters on SM100 + FP8).
            attn_output_gate (Optional[bool]): Whether to apply a sigmoid output gate to the attention output. When enabled, the QKV projection additionally produces per-head gate values that are applied to the attention output.
            use_custom_cublas_mm (bool): Whether to use a custom cuBLAS-based matmul in the QKV and output projection Linear layers.
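
        Example:
            A minimal construction sketch (values are illustrative; ``model_config`` is
            assumed to be an already-built ModelConfig)::

                attn = Attention(
                    hidden_size=4096,
                    num_attention_heads=32,
                    num_key_value_heads=8,
                    max_position_embeddings=8192,
                    bias=False,
                    layer_idx=0,
                    dtype=torch.bfloat16,
                    config=model_config,
                )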
"""
super().__init__()
self.layer_idx = layer_idx
self.layer_idx_str = str(layer_idx)
self.register_to_config = False
# We only register TRTLLM attention layers to config.
if config is not None:
if "attn_layers" not in config.extra_attrs:
config.extra_attrs["attn_layers"] = {}
config.extra_attrs["attn_layers"][self.layer_idx_str] = weakref.ref(
self)
self.register_to_config = True
config = config or ModelConfig()
self.hidden_size = hidden_size
self.num_heads = num_attention_heads
self.head_dim = getattr(config.pretrained_config, 'head_dim', None)
if not isinstance(self.head_dim, int):
self.head_dim = self.hidden_size // self.num_heads
self.num_key_value_heads = num_key_value_heads
self.num_key_value_groups = self.num_heads // self.num_key_value_heads
self.max_position_embeddings = max_position_embeddings
self.pos_embd_params = pos_embd_params
self.dense_bias = dense_bias
self.q_scaling = q_scaling
self.attn_output_gate = attn_output_gate
if self.attn_output_gate:
logger.info_once("using attn output gate!", key="attn_output_gate")
# [Chunked Attention]
# Chunked attention is applied to context requests only. Chunked attention will be
# applied when this field is specified and mMaskType == CAUSAL.
#
# In chunked attention, we break context requests into chunks of a specified size. Tokens can only
# attend to tokens in the same chunk. So, for example, if the chunk size is 3, we might have a mask
# that looks like this:
#
# 1 0 0 0 0 0
# 1 1 0 0 0 0
# 1 1 1 0 0 0
# 0 0 0 1 0 0
# 0 0 0 1 1 0
# 0 0 0 1 1 1
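        #
        # Equivalently (illustrative rule, not executed here): with chunk size c,
        # token i may attend to token j iff j <= i and i // c == j // c.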
self.attention_chunk_size = attention_chunk_size
if dense_bias is None:
self.dense_bias = bias
# tensor parallel
tp_size = config.mapping.tp_size
pp_size = config.mapping.pp_size
cp_size = config.mapping.cp_size
if config.mapping.enable_attention_dp:
tp_size = 1
mapping = Mapping(
world_size=tp_size * pp_size * cp_size,
tp_size=tp_size,
pp_size=pp_size,
cp_size=cp_size,
cp_config=config.mapping.cp_config,
rank=config.mapping.rank,
gpus_per_node=config.mapping.gpus_per_node,
enable_attention_dp=config.mapping.enable_attention_dp,
)
self.tp_size = tp_size
self.tp_rank = mapping.tp_rank
assert self.num_heads % tp_size == 0
self.num_heads = self.num_heads // tp_size
self.num_key_value_heads = (self.num_key_value_heads + tp_size -
1) // tp_size
self.q_size = self.num_heads * self.head_dim
self.kv_size = self.num_key_value_heads * self.head_dim
self.qkv_proj = Linear(
self.hidden_size,
tp_size * self.q_size * (2 if self.attn_output_gate else 1) +
2 * tp_size * self.kv_size,
bias=bias,
dtype=dtype,
mapping=mapping,
tensor_parallel_mode=TensorParallelMode.COLUMN,
weights_loading_config=WeightsLoadingConfig(
weight_mode=WeightMode.FUSED_QKV_LINEAR),
quant_config=config.get_quant_config(),
skip_create_weights_in_init=config.skip_create_weights_in_init,
allreduce_strategy=config.allreduce_strategy,
force_dynamic_quantization=config.force_dynamic_quantization,
disable_deep_gemm=disable_deep_gemm,
use_custom_cublas_mm=use_custom_cublas_mm)
self.o_lora = LoraLayer([LoraModuleType.ATTENTION_DENSE],
[self.hidden_size])
self.o_proj = Linear(
tp_size * self.q_size,
self.hidden_size,
bias=self.dense_bias,
dtype=dtype,
mapping=mapping,
tensor_parallel_mode=TensorParallelMode.ROW,
quant_config=config.get_quant_config(),
skip_create_weights_in_init=config.skip_create_weights_in_init,
lora=self.o_lora,
allreduce_strategy=config.allreduce_strategy,
force_dynamic_quantization=config.force_dynamic_quantization,
disable_deep_gemm=disable_deep_gemm,
use_custom_cublas_mm=use_custom_cublas_mm)
self.quant_config = config.get_quant_config()
self.attn_backend = config.attn_backend
attn_cls = get_attention_backend(
self.attn_backend,
sparse_attn_config=config.sparse_attention_config)
# These two modules are mutually exclusive - either splitted_qkv_lora or fused_qkv_lora will be used,
# but never both at the same time. splitted_qkv_lora handles Q,K,V separately while fused_qkv_lora
# handles them as a single fused operation.
self.splitted_qkv_lora = LoraLayer([
LoraModuleType.ATTENTION_Q, LoraModuleType.ATTENTION_K,
LoraModuleType.ATTENTION_V
], [self.q_size, self.kv_size, self.kv_size])
self.fused_qkv_lora = LoraLayer([LoraModuleType.ATTENTION_QKV],
[self.q_size + 2 * self.kv_size])
self.o_lora = LoraLayer([LoraModuleType.ATTENTION_DENSE],
[self.hidden_size])
# Whether to fuse RoPE into the attention OP.
# If true, RoPE will be applied in self.attn.forward.
# If false, RoPE will be applied in self.apply_rope.
if config.sparse_attention_config is not None:
logger.warning("disable rope_fusion for sparse attention.")
rope_fusion = False
self.rope_fusion = rope_fusion
if self.rope_fusion and not attn_cls.support_fused_rope():
logger.warning(
"rope_fusion is true but the attention backend does not support it. Will disable rope_fusion."
)
self.rope_fusion = False
# If rope_fusion is not specified, enable if the attention backend supports it.
if self.rope_fusion is None:
self.rope_fusion = attn_cls.support_fused_rope()
self.rotary_emb = None
if not self.rope_fusion and self.pos_embd_params is not None:
if self.pos_embd_params.type.is_mrope():
self.rotary_emb = MRotaryEmbedding(
self.pos_embd_params.rope,
head_dim=self.head_dim,
is_neox=self.pos_embd_params.is_neox,
mrope_section=self.pos_embd_params.mrope_section,
)
else:
self.rotary_emb = RotaryEmbedding(
self.pos_embd_params.rope,
head_dim=self.head_dim,
is_neox=self.pos_embd_params.is_neox,
)
self.attn = create_attention(
self.attn_backend,
self.layer_idx,
self.num_heads,
self.head_dim,
self.num_key_value_heads,
pos_embd_params=self.pos_embd_params if self.rope_fusion else None,
quant_config=self.quant_config,
skip_create_weights_in_init=config.skip_create_weights_in_init,
q_scaling=self.q_scaling,
attention_chunk_size=self.attention_chunk_size,
sparse_attention_config=config.sparse_attention_config,
)
self.support_fused_qkv = self.attn.support_fused_qkv()
self.support_nvfp4_output = self.attn.support_nvfp4_output()
if not config.skip_create_weights_in_init:
self.create_weights()
def create_weights(self):
# self.attn has no weights but has states that are related to quant_config,
# which could be modified after __init__
self.attn.update_quant_config(self.quant_config)
self.o_proj.create_weights()
self.has_quant_scale = (self.o_proj.has_fp8_qdq or self.o_proj.has_nvfp4
or self.o_proj.has_fp8_block_scales
or self.o_proj.has_fp8_rowwise
or self.o_proj.has_w4a8_nvfp4_fp8)
def split_qkv(self, q, k=None, v=None):
if k is None and v is None:
q, k, v = q.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
return q, k, v
def convert_qkv(self, q, k, v):
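        # Adapt the (q, k, v) layout to what the attention backend expects: split a
        # fused qkv tensor when the backend wants separate tensors, or concatenate
        # separate q/k/v when the backend supports a fused input.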
if k is None and v is None and not self.support_fused_qkv:
q, k, v = self.split_qkv(q)
elif k is not None and v is not None and self.support_fused_qkv:
qkv = torch.concat([q, k, v], dim=-1)
q, k, v = qkv, None, None
return q, k, v
def create_output(self, q: torch.Tensor):
num_tokens = q.shape[0]
hidden_size = self.o_proj.in_features
out_dtype = q.dtype
if self.attn_backend == "TRTLLM":
if self.has_quant_scale and (self.attn.has_fp8_kv_cache
or self.attn.has_fp4_kv_cache):
out_dtype = torch.float8_e4m3fn
output = q.new_empty([num_tokens, hidden_size], dtype=out_dtype)
return output
def _attn_impl(
self,
q: torch.Tensor,
k: Optional[torch.Tensor],
v: Optional[torch.Tensor],
attn_metadata: AttentionMetadata,
attention_mask: AttentionMask,
mrope_rotary_cos_sin: Optional[torch.Tensor],
mrope_position_deltas: Optional[torch.Tensor],
attention_window_size: Optional[int],
attention_mask_data: Optional[torch.Tensor],
enable_attn_nvfp4_output: bool = True,
output: Optional[torch.Tensor] = None,
output_sf: Optional[torch.Tensor] = None,
attention_sinks: Optional[torch.Tensor] = None,
):
num_tokens = attn_metadata.num_tokens
q = q[:num_tokens, :]
if k is not None:
k = k[:num_tokens, :]
if v is not None:
v = v[:num_tokens, :]
out_scale = None
out_scale_sf = None
if self.has_quant_scale and not self.attn_output_gate:
out_scale = self.o_proj.inv_input_scale
if self.o_proj.has_nvfp4 and self.support_nvfp4_output and enable_attn_nvfp4_output and not self.attn_output_gate:
out_scale_sf = self.o_proj.input_scale
kv_scales_sf = None
kv_scales_sf_inv = None
if self.quant_config is not None and self.quant_config.layer_quant_mode.has_fp4_kv_cache(
):
kv_scales_sf = self.qkv_proj.kv_scales
kv_scales_sf_inv = self.qkv_proj.inv_kv_scales
mrope_config = None
if mrope_rotary_cos_sin is not None or mrope_position_deltas is not None:
mrope_config = dict()
if mrope_rotary_cos_sin is not None:
mrope_config["mrope_rotary_cos_sin"] = mrope_rotary_cos_sin
if mrope_position_deltas is not None:
mrope_config["mrope_position_deltas"] = mrope_position_deltas
attn_output = self.attn.forward(
q,
k,
v,
attn_metadata,
out_scale=out_scale,
out_scale_sf=out_scale_sf,
kv_scales_sf=kv_scales_sf,
kv_scales_sf_inv=kv_scales_sf_inv,
attention_mask=attention_mask,
mrope_config=mrope_config,
attention_window_size=attention_window_size,
attention_mask_data=attention_mask_data,
enable_attn_nvfp4_output=enable_attn_nvfp4_output,
output=output[:num_tokens, :] if output is not None else None,
output_sf=output_sf,
attention_sinks=attention_sinks)
if isinstance(attn_output, tuple):
assert len(
attn_output
) == 2, "attn_output should be a tuple of (output, output_sf)"
return attn_output[0], attn_output[1]
return attn_output, None
def forward_impl(
self,
q: torch.Tensor,
k: Optional[torch.Tensor],
v: Optional[torch.Tensor],
attn_metadata: AttentionMetadata,
attention_mask: AttentionMask,
attention_window_size: Optional[int],
attention_mask_data: Optional[torch.Tensor],
mrope_config: Optional[dict],
attention_sinks: Optional[torch.Tensor] = None,
):
mrope_rotary_cos_sin = None
mrope_position_deltas = None
if mrope_config is not None:
if "mrope_rotary_cos_sin" in mrope_config:
mrope_rotary_cos_sin = mrope_config["mrope_rotary_cos_sin"]
if "mrope_position_deltas" in mrope_config:
mrope_position_deltas = mrope_config["mrope_position_deltas"]
# Currently only TRTLLM and FLASHINFER are torch compile compatible backends.
# Only enable custom inplace op when torch compiling.
use_custom_inplace_op = (self.register_to_config
and (self.attn_backend == "TRTLLM"
or self.attn_backend == "FLASHINFER")
and is_torch_compiling())
if use_custom_inplace_op:
output = self.create_output(q)
attn_custom_op_inplace(
q,
k,
v,
attention_mask,
mrope_rotary_cos_sin,
mrope_position_deltas,
attention_window_size,
attention_mask_data,
attention_sinks,
self.layer_idx_str,
output,
)
else:
output, output_sf = self._attn_impl(q,
k,
v,
attn_metadata,
attention_mask,
mrope_rotary_cos_sin,
mrope_position_deltas,
attention_window_size,
attention_mask_data,
attention_sinks=attention_sinks)
if output_sf is not None:
output = Fp4QuantizedTensor(output, output_sf)
return output
def forward(
self,
position_ids: Optional[torch.IntTensor],
hidden_states: Union[torch.Tensor, Fp4QuantizedTensor],
attn_metadata: AttentionMetadata,
attention_mask: AttentionMask = PredefinedAttentionMask.CAUSAL,
mrope_config: Optional[dict] = None,
all_reduce_params: Optional[AllReduceParams] = None,
lora_params: Optional[dict] = None,
attention_window_size: Optional[int] = None,
attention_mask_data: Optional[torch.Tensor] = None,
attention_sinks: Optional[torch.Tensor] = None,
**kwargs,
) -> torch.Tensor:
"""
Forward pass for the Attention module.
Args:
position_ids (Optional[torch.IntTensor]): The position IDs.
hidden_states (torch.Tensor): The hidden states.
attn_metadata (AttentionMetadata): The attention metadata.
attention_mask (AttentionMask): The attention mask type.
mrope_config (Optional[dict]): The MROPE configuration.
all_reduce_params (Optional[AllReduceParams]): The all reduce parameters.
lora_params (Optional[dict]): The LoRA parameters.
attention_window_size (Optional[int]): The attention window size.
            attention_mask_data (Optional[torch.Tensor]): The attention mask data.
            attention_sinks (Optional[torch.Tensor]): Optional attention sink values (supported by the TRTLLM backend only).
Returns:
torch.Tensor: The output tensor.
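
        Example:
            A minimal call sketch (illustrative; ``attn_metadata`` is the metadata object
            prepared by the active attention backend for the current batch)::

                out = attn(position_ids, hidden_states, attn_metadata)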
"""
qkv = self.qkv_proj(hidden_states)
if bool(lora_params):
qkv_lora = self.splitted_qkv_lora(hidden_states, lora_params,
self.layer_idx)
if qkv_lora is not None:
qkv = qkv + qkv_lora
qkv_lora = self.fused_qkv_lora(hidden_states, lora_params,
self.layer_idx)
if qkv_lora is not None:
qkv = qkv + qkv_lora
if self.attn_output_gate:
q_gate, k, v = qkv.split(
[self.q_size * 2, self.kv_size, self.kv_size], dim=-1)
orig_shape = q_gate.shape[:-1]
# Single line: view -> chunk -> reshape both q and gate
q, gate = [
t.reshape(*orig_shape, -1) for t in torch.chunk(
q_gate.view(*orig_shape, self.num_heads, -1), 2, dim=-1)
]
else:
q, k, v = qkv, None, None
q, k, v = self.apply_rope(q, k, v, position_ids)
q, k, v = self.convert_qkv(q, k, v)
if attention_sinks is not None:
assert self.attn_backend == "TRTLLM", "Attention sinks are only supported for TRTLLM backend."
attn_output = self.forward_impl(q,
k,
v,
attn_metadata,
attention_mask,
attention_window_size,
attention_mask_data,
mrope_config=mrope_config,
attention_sinks=attention_sinks)
if self.attn_output_gate:
gate = torch.sigmoid(gate)
attn_output = attn_output * gate
attn_output = self.o_proj(attn_output,
all_reduce_params=all_reduce_params,
lora_params=lora_params,
layer_idx=self.layer_idx)
return attn_output
def apply_rope(self, q: torch.Tensor, k: Optional[torch.Tensor],
v: Optional[torch.Tensor], position_ids: torch.Tensor):
"""
Apply RoPE to the query and key.
Depending on the implementation, q, k, v could be either fused (q, k, v = concat(q, k, v), None, None) or unfused (none of q, k, v is None).
Before self.attn.forward, convert_qkv will be called to make sure that the format of (q, k, v) satisfies the requirement of self.attn.
This method could be overridden in the subclass, in which extra functionalities such as q_norm/k_norm could be added.
Args:
q (torch.Tensor): The query tensor.
k (Optional[torch.Tensor]): The key tensor.
v (Optional[torch.Tensor]): The value tensor.
position_ids (torch.Tensor): The position IDs of each token for RoPE.
Returns:
tuple: A tuple of (q, k, v).
"""
# If RoPE is fused into the attention OP, do not apply RoPE here.
if not self.rope_fusion and position_ids is not None:
q, k, v = self.split_qkv(q, k, v)
q, k = self.rotary_emb(position_ids, [q, k])
return q, k, v
def apply_qk_norm(self, q, k):
raise NotImplementedError(
f"QK norm is not implemented for {self.__class__.__name__}."
"Please override the `apply_qk_norm` method in the subclass.")
@torch.library.custom_op("trtllm::mla_custom_op_inplace",
mutates_args=("output", ))
def mla_custom_op_inplace(
hidden_states: torch.Tensor,
position_ids: Optional[torch.Tensor],
layer_idx: str,
output: torch.Tensor,
latent_cache_gen: Optional[torch.Tensor],
) -> None:
metadata, mla_layer = extract_extra_attrs(layer_idx, "mla")
mla_layer.forward_impl(position_ids,
hidden_states,
metadata,
output=output,
latent_cache_gen=latent_cache_gen)
def fp8_block_scaling_bmm_out(
mat1: torch.Tensor,
mat2_fp8: torch.Tensor,
mat2_scale: torch.Tensor,
out: torch.Tensor,
mat2_dequant: Optional[torch.Tensor] = None,
) -> torch.Tensor:
sm_version = get_sm_version()
if sm_version == 90 or sm_version == 89 or sm_version == 120:
mat1_fp8, mat1_scale = torch.ops.trtllm.fp8_batched_quantize_1x128_permute102(
mat1)
output = out.new_empty(out.shape, dtype=out.dtype, device=out.device)
torch.ops.trtllm.fp8_block_scaling_bmm_out(mat1_fp8, mat2_fp8,
mat1_scale, mat2_scale,
output)
out.copy_(output)
elif is_sm_100f(sm_version):
torch.bmm(mat1.transpose(0, 1), mat2_dequant.transpose(1, 2), out=out)
else:
raise NotImplementedError(f"SM{sm_version} is not supported")
class MLA(nn.Module):
def __init__(
self,
*,
hidden_size: int,
num_attention_heads: int,
num_key_value_heads: int,
qk_nope_head_dim: int,
qk_rope_head_dim: int,
v_head_dim: int,
q_lora_rank: int,
kv_lora_rank: int,
predicted_tokens_per_seq: int,
max_position_embeddings: int,
bias: bool,
aux_stream: Optional[torch.cuda.Stream] = None,
pos_embd_params: Optional[PositionalEmbeddingParams] = None,
layer_idx: Optional[int] = None,
dtype: torch.dtype = None,
dense_bias: Optional[bool] = None,
config: Optional[ModelConfig] = None,
enable_unit_test: bool = False,
):
"""
Initialize the MLA module.
Args:
hidden_size (int): The size of the hidden dimension.
num_attention_heads (int): The number of attention heads.
num_key_value_heads (int): The number of key value heads.
qk_nope_head_dim (int): The dimension of the query and key without Rope.
qk_rope_head_dim (int): The dimension of the Rope of query and key.
v_head_dim (int): The dimension of the value.
q_lora_rank (int): The dimension of the compressed query.
kv_lora_rank (int): The dimension of the compressed key and value.
predicted_tokens_per_seq (int): The number of predicted tokens per sequence.
max_position_embeddings (int): The maximum position embeddings.
bias (bool): Whether to use bias in the linear layers.
aux_stream (Optional[torch.cuda.Stream]): The auxiliary CUDA stream for running operations in two parallel streams.
pos_embd_params (PositionalEmbeddingParams): The positional embedding parameters.
layer_idx (int): The layer index.
dtype (torch.dtype): The data type.
dense_bias (bool): Whether to use bias in the output projection layer.
config (ModelConfig): The model configuration.
enable_unit_test (bool): Whether to enable unit test.
"""
super().__init__()
self.layer_idx = layer_idx
self.layer_idx_str = str(layer_idx)
self.dtype = dtype
self.hidden_size = hidden_size
self.num_heads = num_attention_heads
self.num_key_value_heads = num_key_value_heads
self.num_key_value_groups = self.num_heads // self.num_key_value_heads
self.qk_nope_head_dim = qk_nope_head_dim
self.qk_rope_head_dim = qk_rope_head_dim
self.qk_head_dim = qk_nope_head_dim + qk_rope_head_dim
self.v_head_dim = v_head_dim
self.q_lora_rank = q_lora_rank
self.kv_lora_rank = kv_lora_rank
self.predicted_tokens_per_seq = predicted_tokens_per_seq
self.max_position_embeddings = max_position_embeddings
self.pos_embd_params = pos_embd_params
self.dense_bias = dense_bias
self.enable_unit_test = enable_unit_test
if dense_bias is None:
self.dense_bias = bias
if self.q_lora_rank is None:
self.q_lora_rank = hidden_size
self.is_lite = True
else:
self.is_lite = False
assert pos_embd_params is not None, "pos_embd_params must be provided in MLA"
self.register_to_config = False
if config is not None:
if "mla_layers" not in config.extra_attrs:
config.extra_attrs["mla_layers"] = {}
config.extra_attrs["mla_layers"][self.layer_idx_str] = weakref.ref(
self)
self.register_to_config = True
# only support one kind of sparse attention, dsa now.
if config is not None and config.sparse_attention_config is not None:
self.is_dsa = True
else:
self.is_dsa = False
# tensor parallel
config = config or ModelConfig()
self.mapping = config.mapping
tp_size = self.mapping.tp_size
pp_size = self.mapping.pp_size
cp_size = self.mapping.cp_size
if self.mapping.enable_attention_dp:
tp_size = 1
if self.mapping.has_cp_ulysses():
raise NotImplementedError("MLA doesn't support CP Ulyssees yet")
mapping = Mapping(
world_size=tp_size * pp_size * cp_size,
tp_size=tp_size,
pp_size=pp_size,
cp_size=cp_size,
cp_config=self.mapping.cp_config,
rank=self.mapping.rank,
gpus_per_node=self.mapping.gpus_per_node,
enable_attention_dp=self.mapping.enable_attention_dp,
)
assert self.num_heads % (tp_size * cp_size) == 0
self.num_heads_tp = self.num_heads // tp_size
self.num_heads_tp_cp = self.num_heads_tp // cp_size
self.num_key_value_heads_tp = (self.num_key_value_heads + tp_size -
1) // tp_size
if self.enable_unit_test:
rms_norm_eps = getattr(config.pretrained_config, "rms_norm_eps",
1e-6)
else:
rms_norm_eps = config.pretrained_config.rms_norm_eps
quant_config = config.get_quant_config()
self.quant_config = quant_config
if not self.is_lite:
self.kv_a_proj_with_mqa = Linear(
hidden_size,
self.q_lora_rank + self.kv_lora_rank + self.qk_rope_head_dim,
bias=bias,
dtype=dtype,
quant_config=quant_config,
skip_create_weights_in_init=config.skip_create_weights_in_init,
use_custom_cublas_mm=True,
force_dynamic_quantization=config.force_dynamic_quantization)
self.q_a_layernorm = RMSNorm(hidden_size=self.q_lora_rank,
eps=rms_norm_eps,
dtype=dtype)
self.q_b_proj = Linear(
self.q_lora_rank,
self.num_heads * self.qk_head_dim,
bias=bias,
dtype=dtype,
mapping=mapping,
tensor_parallel_mode=TensorParallelMode.COLUMN,
quant_config=quant_config,
skip_create_weights_in_init=config.skip_create_weights_in_init,
allreduce_strategy=config.allreduce_strategy,
force_dynamic_quantization=config.force_dynamic_quantization)
else:
self.kv_a_proj_with_mqa = Linear(
hidden_size,
self.kv_lora_rank + self.qk_rope_head_dim,
bias=bias,
dtype=dtype,
quant_config=quant_config,
skip_create_weights_in_init=config.skip_create_weights_in_init,
use_custom_cublas_mm=True,
force_dynamic_quantization=config.force_dynamic_quantization)
self.q_proj = Linear(
self.q_lora_rank,
self.num_heads * self.qk_head_dim,
bias=bias,
dtype=dtype,
mapping=mapping,
tensor_parallel_mode=TensorParallelMode.COLUMN,
quant_config=quant_config,
skip_create_weights_in_init=config.skip_create_weights_in_init,
allreduce_strategy=config.allreduce_strategy,
force_dynamic_quantization=config.force_dynamic_quantization)
self.q_b_proj = self.q_proj
self.kv_a_layernorm = RMSNorm(hidden_size=kv_lora_rank,
dtype=dtype,
eps=rms_norm_eps)
self.kv_b_proj = Linear(
self.kv_lora_rank,
self.num_heads * (self.qk_nope_head_dim + self.v_head_dim),
bias=bias,
dtype=dtype,
mapping=mapping,
tensor_parallel_mode=TensorParallelMode.COLUMN,
quant_config=quant_config,
skip_create_weights_in_init=config.skip_create_weights_in_init,
allreduce_strategy=config.allreduce_strategy,
force_dynamic_quantization=config.force_dynamic_quantization)
# This parameter will view into self.kv_b_proj.weight after loading weights.
# For dummy weight initialization, this parameter is initialized with empty tensor.
# Used in forward_absorption only
self.v_b_proj = nn.Parameter(
torch.empty(
(self.num_heads_tp_cp, self.v_head_dim, self.kv_lora_rank),
dtype=dtype,
),
requires_grad=False,
)
mapping_o = Mapping(
world_size=tp_size * pp_size * cp_size,
tp_size=tp_size * cp_size,
pp_size=pp_size,
cp_size=1,
rank=self.mapping.rank,
gpus_per_node=self.mapping.gpus_per_node,
enable_attention_dp=self.mapping.enable_attention_dp,
)
self.o_proj = Linear(
self.num_key_value_heads * self.v_head_dim,
self.hidden_size,
bias=self.dense_bias,
dtype=dtype,
mapping=mapping_o,
tensor_parallel_mode=TensorParallelMode.ROW,
quant_config=quant_config,
skip_create_weights_in_init=config.skip_create_weights_in_init,
allreduce_strategy=config.allreduce_strategy,
force_dynamic_quantization=config.force_dynamic_quantization)
def yarn_get_mscale(scale=1, mscale=1):
if scale <= 1:
return 1.0
return 0.1 * mscale * math.log(scale) + 1.0
mscale_all_dim = pos_embd_params.rope.mscale_all_dim
scaling_factor = pos_embd_params.rope.scale
mscale = yarn_get_mscale(scaling_factor, mscale_all_dim)
q_scaling = 1.0 / (mscale * mscale)
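        # Worked example (illustrative): with rope scale = 40 and mscale_all_dim = 1.0,
        # mscale = 0.1 * 1.0 * ln(40) + 1.0 ~= 1.369, so q_scaling = 1 / mscale**2 ~= 0.534
        # and the effective softmax scale is 1 / (sqrt(qk_head_dim) * q_scaling).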
if not self.is_dsa:
self.mha = create_attention(
config.attn_backend,
self.layer_idx,
self.num_heads_tp,
head_dim=self.qk_head_dim,
num_kv_heads=self.num_key_value_heads_tp,
pos_embd_params=pos_embd_params,
quant_config=quant_config,
q_scaling=q_scaling,
is_mla_enable=True,
q_lora_rank=self.q_lora_rank,
kv_lora_rank=self.kv_lora_rank,
qk_nope_head_dim=self.qk_nope_head_dim,
qk_rope_head_dim=self.qk_rope_head_dim,
v_head_dim=self.v_head_dim,
predicted_tokens_per_seq=self.predicted_tokens_per_seq,
skip_create_weights_in_init=config.skip_create_weights_in_init,
sparse_attention_config=config.sparse_attention_config,
)
else:
self.mha = None
self.mqa = create_attention(
config.attn_backend,
self.layer_idx,
self.num_heads_tp,
head_dim=self.kv_lora_rank + self.qk_rope_head_dim,
num_kv_heads=1,
pos_embd_params=pos_embd_params,
quant_config=quant_config,
q_scaling=q_scaling,
is_mla_enable=True,
q_lora_rank=self.q_lora_rank,
kv_lora_rank=self.kv_lora_rank,
qk_nope_head_dim=self.qk_nope_head_dim,
qk_rope_head_dim=self.qk_rope_head_dim,
v_head_dim=self.kv_lora_rank,
hidden_size=self.hidden_size,
predicted_tokens_per_seq=self.predicted_tokens_per_seq,
skip_create_weights_in_init=config.skip_create_weights_in_init,
sparse_attention_config=config.sparse_attention_config,
dtype=dtype,
aux_stream=aux_stream,
)
self.softmax_scale = 1.0 / (math.sqrt(self.qk_head_dim) * q_scaling)
self.aux_stream = aux_stream
self.ln_events = [torch.cuda.Event(), torch.cuda.Event()]
self.rope_fusion = self.mqa.support_fused_rope()
self.rotary_emb = None
self.apply_rotary_emb = not self.rope_fusion
if self.apply_rotary_emb:
self.rotary_emb = RotaryEmbedding(
pos_embd_params.rope,
head_dim=self.qk_rope_head_dim,
is_neox=pos_embd_params.is_neox,
)
if not config.skip_create_weights_in_init:
self.create_weights()
def create_weights(self):
# self.mha/mqa has no weights but has states that are related to quant_config,
# which could be modified after __init__
if not self.is_dsa:
self.mha.update_quant_config(self.quant_config)
self.mqa.update_quant_config(self.quant_config)
        # Although FP8 MLA is used for the context/generation phases, the output is still in BF16
self.out_scale = None
# k_b_proj_trans's dtype must be consistent with self.kv_b_proj,
# which can be modified after __init__
has_fp8_block_scales = (
self.kv_b_proj.quant_config
and self.kv_b_proj.quant_config.quant_mode.has_fp8_block_scales())
mla_weight_dtype = torch.float8_e4m3fn if has_fp8_block_scales else self.dtype
self.k_b_proj_trans = nn.Parameter(
torch.empty(
(self.num_heads_tp, self.kv_lora_rank, self.qk_nope_head_dim),
dtype=mla_weight_dtype,
),
requires_grad=False,
)
self.k_b_proj_trans_dequant = None
self.v_b_proj_dequant = None
if has_fp8_block_scales:
self.k_b_proj_trans_scale = nn.Parameter(
torch.empty(
(
self.num_heads_tp,
self.kv_lora_rank // 128,
self.qk_nope_head_dim // 128,
),
dtype=torch.float32,
),
requires_grad=False,
)
# This parameter will view into self.kv_b_proj.weight_scale after loading weights.
# For dummy weight initialization, this parameter is initialized with empty tensor.
self.v_b_proj_scale = nn.Parameter(
torch.empty(
(
self.num_heads_tp_cp,
self.v_head_dim // 128,
self.kv_lora_rank // 128,
),
dtype=torch.float32,
),
requires_grad=False,
)
if is_sm_100f():
assert self.dtype == torch.bfloat16
self.k_b_proj_trans_dequant = nn.Parameter(
torch.empty(
(self.num_heads_tp, self.kv_lora_rank,
self.qk_nope_head_dim),
dtype=self.dtype,
),
requires_grad=False,
)
self.v_b_proj_dequant = nn.Parameter(
torch.empty(
(self.num_heads_tp_cp, self.v_head_dim,
self.kv_lora_rank),
dtype=self.dtype,
),
requires_grad=False,
)
else:
self.k_b_proj_trans_scale = None
self.v_b_proj_scale = None
def apply_rope(
self,
q: torch.Tensor,
k_pe: torch.Tensor,
position_ids: torch.Tensor,
) -> torch.Tensor:
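        # Rotate the RoPE slice of q in place (the nope part is left untouched) and
        # return the rotated k_pe.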
q = q.view(-1, self.num_heads_tp, self.qk_head_dim)
q_pe = q[..., self.qk_nope_head_dim:].reshape(
-1, self.num_heads_tp * self.qk_rope_head_dim)
q_pe, k_pe = self.rotary_emb(position_ids, [q_pe, k_pe])
q[..., self.qk_nope_head_dim:] = q_pe.view(-1, self.num_heads_tp,
self.qk_rope_head_dim)
return k_pe
def _attn_forward_gen(self, attn_backend: AttentionBackend, q: torch.Tensor,
k: torch.Tensor, v: torch.Tensor,
position_ids: Optional[torch.Tensor],
attn_metadata: AttentionMetadata, **kwargs):
if self.mapping.cp_size > 1:
# partial_o: [num_tokens, num_heads_tp * kv_lora_rank]
# softmax_stats: [num_tokens, num_heads_tp, 2]
softmax_stats = torch.empty((q.shape[0], self.num_heads_tp, 2),
device=q.device,
dtype=torch.float32)
partial_o = attn_backend.forward(
q,
k,
v,
attn_metadata,
softmax_stats_tensor=softmax_stats,
helix_position_offsets=position_ids,
**kwargs)
# this is the post-processing of helix parallel attention,
# similar to the post-processing of ring attention
kv_lora_rank = partial_o.shape[-1] // self.num_heads_tp
assert self.kv_lora_rank == kv_lora_rank
chunks_o = [
t.contiguous() for t in torch.split(partial_o,
partial_o.shape[-1] //
self.mapping.cp_size,
dim=-1)
]
chunks_stats = [
t.contiguous() for t in torch.split(softmax_stats,
softmax_stats.shape[1] //
self.mapping.cp_size,
dim=1)
]
gathered_o, gathered_stats = alltoall_helix(
chunks_o + chunks_stats,
self.mapping.cp_group,
)
return torch.ops.trtllm.helix_post_process(gathered_o,
gathered_stats, 1.0)
else:
attn_output = attn_backend.forward(q, k, v, attn_metadata, **kwargs)
return attn_output
def create_output(self, hidden_states: torch.Tensor, num_contexts: int):
num_tokens = hidden_states.shape[0]
hidden_size = self.o_proj.in_features
if self.enable_unit_test and num_contexts > 0:
# note: for testing Helix parallelism, we ensure that the output is
# large enough for the context phase, but we then cut it again in
# `forward_context`
hidden_size *= self.mapping.cp_size
return hidden_states.new_empty([num_tokens, hidden_size],
dtype=hidden_states.dtype)
def forward_impl(self,
position_ids: Optional[torch.Tensor],
hidden_states: torch.Tensor,
attn_metadata: AttentionMetadata,
output: torch.Tensor,
latent_cache_gen: Optional[torch.Tensor] = None) -> None:
"""
Forward pass for the MLA module.
Args:
position_ids (Optional[torch.IntTensor]): The position IDs.
hidden_states (torch.Tensor): The hidden states.
attn_metadata (AttentionMetadata): The attention metadata.
            output (torch.Tensor): The pre-allocated tensor that the attention result is written into.
            latent_cache_gen (Optional[torch.Tensor]): The latent cache used in generation.
        Returns:
            None. The result is written in place into ``output``.
"""
# split q, k, v into context and gen batches
num_contexts = attn_metadata.num_contexts
num_generations = attn_metadata.num_generations
num_ctx_tokens = attn_metadata.num_ctx_tokens
num_tokens = attn_metadata.num_tokens
hidden_states = hidden_states[:num_tokens, ...]
if position_ids is not None:
position_ids = position_ids[..., :num_tokens]
if self.is_lite:
compressed_kv, k_pe = self.kv_a_proj_with_mqa(hidden_states).split(
[self.kv_lora_rank, self.qk_rope_head_dim], -1)
compressed_kv = self.kv_a_layernorm(compressed_kv)
q = hidden_states
else:
q, compressed_kv, k_pe = self.kv_a_proj_with_mqa(
hidden_states).split([
self.q_lora_rank, self.kv_lora_rank, self.qk_rope_head_dim
], -1)
q, compressed_kv = maybe_execute_in_parallel(
lambda: self.q_a_layernorm(q),
lambda: self.kv_a_layernorm(compressed_kv),
self.ln_events[0],
self.ln_events[1],
self.aux_stream,
)
q, latent_cache = maybe_execute_in_parallel(
lambda: self.q_b_proj(q),
lambda: torch.concat([compressed_kv, k_pe], dim=-1),
self.ln_events[0],
self.ln_events[1],
self.aux_stream,
)
assert q.shape[
0] == num_tokens, f"Expect q.shape[0] to be {num_tokens}, but got {q.shape[0]}"
assert output is not None, "output must be provided"
if num_contexts > 0:
q_ctx = q[:num_ctx_tokens, ...]
compressed_kv_ctx = compressed_kv[:num_ctx_tokens, ...]
k_pe_ctx = k_pe[:num_ctx_tokens, ...]
latent_cache_ctx = latent_cache[:num_ctx_tokens, ...]
if self.apply_rotary_emb:
assert position_ids is not None
k_pe_ctx = self.apply_rope(q_ctx, k_pe_ctx, position_ids)
self.forward_context(
q_ctx,
compressed_kv_ctx,
k_pe_ctx,
position_ids,
attn_metadata,
output[:num_ctx_tokens, :],
latent_cache_ctx,
)
if num_generations > 0:
q_gen = q[num_ctx_tokens:, ...]
compressed_kv_gen = compressed_kv[num_ctx_tokens:, ...]
k_pe_gen = k_pe[num_ctx_tokens:, ...]
if latent_cache_gen is None:
latent_cache_gen = latent_cache[num_ctx_tokens:, ...]
if self.apply_rotary_emb:
assert position_ids is not None
k_pe_gen = self.apply_rope(q_gen, k_pe_gen, position_ids)
self.forward_absorption_generation(
q_gen,
compressed_kv_gen,
k_pe_gen,
attn_metadata,
output[num_ctx_tokens:num_tokens, :],
position_ids=position_ids,
latent_cache=latent_cache_gen,
)
def forward_impl_with_dsa(self, position_ids: Optional[torch.Tensor],
hidden_states: torch.Tensor,
attn_metadata: AttentionMetadata,
output: torch.Tensor) -> None:
"""
Forward pass for the MLA module with DSA (always in MQA mode).
Args:
position_ids (Optional[torch.IntTensor]): The position IDs.
hidden_states (torch.Tensor): The hidden states.
            attn_metadata (AttentionMetadata): The attention metadata.
            output (torch.Tensor): The pre-allocated tensor that the attention result is written into.
        Returns:
            None. The result is written in place into ``output``.
"""
assert self.mha is None and self.mqa is not None, "DSA is only supported in MQA mode"
# split q, k, v into context and gen batches
num_contexts = attn_metadata.num_contexts
num_generations = attn_metadata.num_generations
num_ctx_tokens = attn_metadata.num_ctx_tokens
num_tokens = attn_metadata.num_tokens
hidden_states = hidden_states[:num_tokens, ...]
if position_ids is not None:
position_ids = position_ids[..., :num_tokens]
if self.fuse_a_indexer_k_weight:
q, compressed_kv, k_pe, indexer_k, indexer_weights = self.kv_a_proj_with_mqa(
hidden_states).split([
self.q_lora_rank, self.kv_lora_rank, self.qk_rope_head_dim,
self.indexer.head_dim, self.indexer.n_heads
], -1)
else:
q, compressed_kv, k_pe = self.kv_a_proj_with_mqa(
hidden_states).split([
self.q_lora_rank, self.kv_lora_rank, self.qk_rope_head_dim
], -1)
indexer_k = None
indexer_weights = None
# TODO: possibly overlap/fuse q_a_rmsnorm + kv_a_rmsnorm + indexer.k_layernorm?
q, compressed_kv = maybe_execute_in_parallel(
lambda: self.q_a_layernorm(q),
lambda: self.kv_a_layernorm(compressed_kv),
self.ln_events[0],
self.ln_events[1],
self.aux_stream,
)
qr = q
latent_cache = torch.concat([compressed_kv, k_pe], dim=-1)
# TODO: fuse wq_b + (indexer) wlq here
q = self.q_b_proj(q)
# Indexer
topk_indices = self.indexer(
qr,
hidden_states,
attn_metadata,
position_ids,
indexer_k=indexer_k, # indexer K proj
indexer_weights=indexer_weights, # indexer weights proj
)
assert q.shape[
0] == num_tokens, f"Expect q.shape[0] to be {num_tokens}, but got {q.shape[0]}"
assert output is not None, "output must be provided"
if num_contexts > 0:
q_ctx = q[:num_ctx_tokens, ...]
compressed_kv_ctx = compressed_kv[:num_ctx_tokens, ...]
k_pe_ctx = k_pe[:num_ctx_tokens, ...]
latent_cache_ctx = latent_cache[:num_ctx_tokens, ...]
if self.apply_rotary_emb:
assert position_ids is not None
k_pe_ctx = self.apply_rope(q_ctx, k_pe_ctx, position_ids)
self.forward_context_dsa(
q_ctx,
compressed_kv_ctx,
k_pe_ctx,
attn_metadata,
output[:num_ctx_tokens, :],
latent_cache_ctx,
topk_indices=topk_indices[:num_ctx_tokens, :],
)
if num_generations > 0:
q_gen = q[num_ctx_tokens:, ...]
compressed_kv_gen = compressed_kv[num_ctx_tokens:, ...]
k_pe_gen = k_pe[num_ctx_tokens:, ...]
latent_cache_gen = latent_cache[num_ctx_tokens:, ...]
if self.apply_rotary_emb:
assert position_ids is not None
k_pe_gen = self.apply_rope(q_gen, k_pe_gen, position_ids)
self.forward_generation_dsa(
q_gen,
compressed_kv_gen,
k_pe_gen,
attn_metadata,
output[num_ctx_tokens:num_tokens, :],
latent_cache_gen,
topk_indices=topk_indices[num_ctx_tokens:num_tokens, :],
)
def forward_context_default(
self,
q: torch.Tensor,
compressed_kv: torch.Tensor,
k_pe: torch.Tensor,
position_ids: torch.Tensor,
attn_metadata: AttentionMetadata,
output: torch.Tensor,
latent_cache: Optional[torch.Tensor] = None,
) -> torch.Tensor:
kv = self.kv_b_proj(compressed_kv)
k_nope, v = kv.split(
[
self.num_heads_tp * self.qk_nope_head_dim,
self.num_heads_tp * self.v_head_dim
],
-1,
)
k = torch.empty_like(q).view(-1, self.num_heads_tp, self.qk_head_dim)
maybe_compiled_copy_(
k[..., :self.qk_nope_head_dim],
k_nope.view(-1, self.num_heads_tp, self.qk_nope_head_dim))
if self.apply_rotary_emb:
k[..., self.qk_nope_head_dim:] = k_pe.view(-1, 1,
self.qk_rope_head_dim)
k = k.view(-1, self.num_heads_tp * self.qk_head_dim)
helix_position_offsets = position_ids if self.mapping.cp_size > 1 else None
attn_output = self.mha.forward(
q,
k,
v,
attn_metadata,
attention_input_type=AttentionInputType.context_only,
latent_cache=latent_cache,
helix_position_offsets=helix_position_offsets,
out_scale=self.out_scale,
output=output,
)
return attn_output
def forward_context_dsa(
self,
q: torch.Tensor,
compressed_kv: torch.Tensor,
k_pe: torch.Tensor,
attn_metadata: AttentionMetadata,
output: torch.Tensor,
latent_cache: Optional[torch.Tensor] = None,
topk_indices: Optional[torch.Tensor] = None,
) -> torch.Tensor:
if get_sm_version() >= 100:
return self.forward_absorption_context(q,
compressed_kv,
k_pe,
attn_metadata,
output,
latent_cache=latent_cache,
topk_indices=topk_indices)
else:
return self.forward_sparse_mla_kvcache_bf16(q,
latent_cache,
attn_metadata,
output,
topk_indices,
is_generation=False)
def forward_generation_dsa(
self,
q: torch.Tensor,
compressed_kv: torch.Tensor,
k_pe: torch.Tensor,
attn_metadata: AttentionMetadata,
output: torch.Tensor,
latent_cache: Optional[torch.Tensor] = None,
topk_indices: Optional[torch.Tensor] = None,
) -> torch.Tensor:
if get_sm_version() >= 100:
return self.forward_absorption_generation(q,
compressed_kv,
k_pe,
attn_metadata,
output,
latent_cache=latent_cache,
topk_indices=topk_indices)
else:
return self.forward_sparse_mla_kvcache_bf16(q,
latent_cache,
attn_metadata,
output,
topk_indices,
is_generation=True)
def forward_context_with_cached_kv(
self,
q: torch.Tensor,
latent_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
output: torch.Tensor,
) -> torch.Tensor:
assert latent_cache is not None
trtllm_attention = cast(TrtllmAttention, self.mha)
# apply RoPE, append compressed_kv + k_pe to paged kv cache and assign q_pe to q
trtllm_attention.mla_rope_append_paged_kv_assign_q(
q, latent_cache, attn_metadata)
# copy full_compressed_kv and full_k_pe from paged kv cache
full_compressed_kv, full_k_pe = trtllm_attention.load_paged_kv_cache_for_mla(
attn_metadata, q.dtype)
assert full_compressed_kv.shape[
0] == attn_metadata.num_ctx_cached_tokens + attn_metadata.num_ctx_tokens
assert full_compressed_kv.shape[1] == self.kv_lora_rank
assert full_k_pe.shape[
0] == attn_metadata.num_ctx_cached_tokens + attn_metadata.num_ctx_tokens
assert full_k_pe.shape[1] == self.qk_rope_head_dim
assert full_compressed_kv.is_contiguous()
assert full_k_pe.is_contiguous()
# compute full_k_nope and full_v from full_compressed_kv
full_kv = self.kv_b_proj(full_compressed_kv)
full_k_nope, full_v = full_kv.split(
[
self.num_heads_tp * self.qk_nope_head_dim,
self.num_heads_tp * self.v_head_dim
],
-1,
)
full_k_nope = full_k_nope.view(-1, self.num_heads_tp,
self.qk_nope_head_dim)
full_k_pe = full_k_pe.view(-1, 1, self.qk_rope_head_dim)
full_k = maybe_compiled_cat(
(full_k_nope, full_k_pe.expand(-1, self.num_heads_tp, -1)), dim=-1)
full_k = full_k.view(-1, self.num_heads_tp * self.qk_head_dim)
# release pytorch activation memory
full_compressed_kv = None
full_k_pe = None
full_kv = None
full_k_nope = None
# latent_cache must be None to differentiate from normal context phase,
# so that we can skip applying RoPE and appending KV cache inside attention op
attn_output = self.mha.forward(
q,
full_k,
full_v,
attn_metadata,
attention_input_type=AttentionInputType.context_only,
latent_cache=None,
out_scale=self.out_scale,
output=output,
)
return attn_output
def forward_context_with_chunked_prefill(
self,
q: torch.Tensor,
compressed_kv: torch.Tensor,
latent_cache: torch.
Tensor, # compressed_kv + k_pe [context_tokens, 1, lora_size + rope_size]
attn_metadata: TrtllmAttentionMetadata,
output: torch.Tensor,
) -> torch.Tensor:
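        # Chunked-prefill strategy: iterate over the already-cached KV in fixed-size
        # chunks, run context attention against each chunk with a FULL mask, and merge
        # the partial results via their softmax statistics; a final causal pass over
        # the uncached (current) tokens is then merged in the same way.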
trtllm_attention = cast(TrtllmAttention, self.mha)
# apply RoPE, append compressed_kv + k_pe to paged kv cache and assign q_pe to q
trtllm_attention.mla_rope_append_paged_kv_assign_q(
q, latent_cache, attn_metadata)
        # determine the number of chunked loops
        # currently we assume that the chunk size is the same as max_num_tokens
        chunked_loop_num = attn_metadata.chunked_loop_num
        # [total_token_q, num_heads, 2] -> [total_token_q, num_heads] float2
self.softmax_stats_tensor = torch.empty(
(attn_metadata.num_ctx_tokens, self.num_heads_tp, 2),
dtype=torch.float,
device='cuda',
)
self.temp_softmax_stats_tensor = torch.empty(
(attn_metadata.num_ctx_tokens, self.num_heads_tp, 2),
dtype=torch.float,
device='cuda',
)
attn_output = output
temp_attn_output = q.new_empty(
(q.size(0), self.num_heads_tp * self.v_head_dim), dtype=q.dtype)
# use fake cached_cu_seq_len for chunked loop
origin_kv_lens_cuda_runtime = attn_metadata.kv_lens_cuda_runtime
origin_kv_lens_runtime = attn_metadata.kv_lens_runtime
origin_ctx_total_kv_len = attn_metadata.host_total_kv_lens[0]
for loop_idx in range(chunked_loop_num):
# {b, chunked_unit_size, h, kv_lora_rank + qk_rope_head_dim} zero padded
# fetch `loop_idx` chunk from kv cache
temp_cu_chunked_seq_len = attn_metadata.cu_chunked_seq_len[loop_idx]
total_ctx_chunked_tokens = attn_metadata.host_cu_chunked_seq_len[
loop_idx, attn_metadata.num_contexts]
chunked_global_offset = attn_metadata.chunked_global_offset[
loop_idx]
chunked_max_seq_len = attn_metadata.max_chunk_len_per_loop[loop_idx]
chunked_compressed_kv, chunked_k_pe = trtllm_attention.load_chunked_kv_cache_for_mla(
metadata=attn_metadata,
num_ctx_cached_tokens=total_ctx_chunked_tokens,
cu_chunked_seq_len=temp_cu_chunked_seq_len,
chunked_global_offset=chunked_global_offset,
chunked_max_seq_len=chunked_max_seq_len,
out_dtype=q.dtype)
# up proj to uncompressed kv
# [tokens, 2, h, kv_dim], without rope_dim
chunked_kv = self.kv_b_proj(chunked_compressed_kv)
chunked_k_nope, chunked_v = chunked_kv.split(
[
self.num_heads_tp * self.qk_nope_head_dim,
self.num_heads_tp * self.v_head_dim
],
-1,
)
chunked_k_nope = chunked_k_nope.view(-1, self.num_heads_tp,
self.qk_nope_head_dim)
chunked_k_pe = chunked_k_pe.view(-1, 1, self.qk_rope_head_dim)
chunked_k = maybe_compiled_cat(
(chunked_k_nope, chunked_k_pe.expand(-1, self.num_heads_tp,
-1)),
dim=-1)
chunked_k = chunked_k.view(-1, self.num_heads_tp * self.qk_head_dim)
# release pytorch activation memory
chunked_compressed_kv = None
chunked_k_pe = None
chunked_kv = None
chunked_k_nope = None
# copy chunked_seq_len to replace kv_lens_runtime
attn_metadata.kv_lens_runtime = attn_metadata.host_chunked_seq_len[
loop_idx]
attn_metadata.kv_lens_cuda_runtime = attn_metadata.chunked_seq_len[
loop_idx]
attn_metadata.host_total_kv_lens[0] = total_ctx_chunked_tokens
# do not apply mask for attention within loop
# latent_cache must be None to differentiate from normal context phase,
# so that we can skip applying RoPE and appending KV cache inside attention op
temp_attn_output = self.mha.forward(
q,
chunked_k,
chunked_v,
attn_metadata,
attention_input_type=AttentionInputType.context_only,
latent_cache=None,
out_scale=self.out_scale,
attention_mask=PredefinedAttentionMask.FULL,
softmax_stats_tensor=self.temp_softmax_stats_tensor,
chunked_prefill_buffer_batch_size=attn_metadata.
runtime_features.chunked_prefill_buffer_batch_size,
output=temp_attn_output,
)
# merge attn result
temp_merge_op = attn_metadata.merge_op_tensor[loop_idx]
trtllm_attention.merge_attention_for_mla(
attn_output, temp_attn_output, self.softmax_stats_tensor,
self.temp_softmax_stats_tensor, temp_merge_op, attn_metadata)
# deal with the uncached kv
kv = self.kv_b_proj(compressed_kv)
_, k_pe = latent_cache.view([
-1, self.kv_lora_rank + self.qk_rope_head_dim
]).split([self.kv_lora_rank, self.qk_rope_head_dim], -1)
# final round of attention
k_nope, v = kv.split(
[
self.num_heads_tp * self.qk_nope_head_dim,
self.num_heads_tp * self.v_head_dim
],
-1,
)
k_nope = k_nope.view(-1, self.num_heads_tp, self.qk_nope_head_dim)
k_pe = k_pe.view(-1, 1, self.qk_rope_head_dim)
k = maybe_compiled_cat((k_nope, k_pe.expand(-1, self.num_heads_tp, -1)),
dim=-1)
k = k.view(-1, self.num_heads_tp * self.qk_head_dim)
# copy q_lens to replace kv_lens_runtime
attn_metadata.kv_lens_runtime = attn_metadata.prompt_lens_cpu_runtime
attn_metadata.kv_lens_cuda_runtime = attn_metadata.prompt_lens_cuda_runtime
attn_metadata.host_total_kv_lens[
0] = attn_metadata.prompt_lens_cpu_runtime[:attn_metadata.
num_contexts].sum().item(
)
# latent_cache must be None to differentiate from normal context phase,
# so that we can skip applying RoPE and appending KV cache inside attention op
temp_attn_output = self.mha.forward(
q,
k,
v,
attn_metadata,
attention_input_type=AttentionInputType.context_only,
latent_cache=None,
out_scale=self.out_scale,
softmax_stats_tensor=self.temp_softmax_stats_tensor,
chunked_prefill_buffer_batch_size=attn_metadata.runtime_features.
chunked_prefill_buffer_batch_size,
output=temp_attn_output,
)
temp_merge_op = attn_metadata.merge_op_tensor[chunked_loop_num]
trtllm_attention.merge_attention_for_mla(attn_output, temp_attn_output,
self.softmax_stats_tensor,
self.temp_softmax_stats_tensor,
temp_merge_op, attn_metadata)
# copy back kv_lens_runtime and kv_lens_cuda_runtime
attn_metadata.kv_lens_runtime = origin_kv_lens_runtime
attn_metadata.kv_lens_cuda_runtime = origin_kv_lens_cuda_runtime
attn_metadata.host_total_kv_lens[0] = origin_ctx_total_kv_len
return attn_output
def forward_context(
self,
q: torch.Tensor,
compressed_kv: torch.Tensor,
k_pe: torch.Tensor,
position_ids: torch.Tensor,
attn_metadata: AttentionMetadata,
output: torch.Tensor,
latent_cache: Optional[torch.Tensor] = None,
) -> torch.Tensor:
if isinstance(self.mha, TrtllmAttention):
assert isinstance(attn_metadata, TrtllmAttentionMetadata)
trtllm_attention = cast(TrtllmAttention, self.mha)
if trtllm_attention.is_chunked_prefill_for_mla_context(
attn_metadata):
return self.forward_context_with_chunked_prefill(
q, compressed_kv, latent_cache, attn_metadata, output)
elif trtllm_attention.has_cached_kv_for_mla_context(attn_metadata):
return self.forward_context_with_cached_kv(
q, latent_cache, attn_metadata, output)
return self.forward_context_default(q, compressed_kv, k_pe,
position_ids, attn_metadata, output,
latent_cache)
def forward_absorption_generation(
self,
q: torch.Tensor,
compressed_kv: torch.Tensor,
k_pe: torch.Tensor,
attn_metadata: AttentionMetadata,
output: torch.Tensor,
position_ids: Optional[torch.Tensor] = None,
latent_cache: Optional[torch.Tensor] = None,
topk_indices: Optional[torch.Tensor] = None,
) -> torch.Tensor:
num_tokens = q.shape[0]
q_nope, q_pe = q.view([-1, self.num_heads_tp, self.qk_head_dim]).split(
[self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
# fused_q contains 1) the result of the following bmm with shape [num_tokens, num_heads, kv_lora_rank]
# 2) rope(q_pe) with shape [num_tokens, num_heads, qk_rope_head_dim]. rope is applied inside AttentionOp
num_seqs = attn_metadata.kv_lens_cuda_runtime.size(0)
cu_q_seqlens = torch.empty(num_seqs + 1,
dtype=torch.int32,
device=q.device)
cu_kv_seqlens = torch.empty(num_seqs + 1,
dtype=torch.int32,
device=q.device)
fmha_scheduler_counter = torch.empty(1,
dtype=torch.uint32,
device=q.device)
has_fp8_kv_cache = self.mqa.has_fp8_kv_cache if hasattr(
self.mqa, 'has_fp8_kv_cache') else False
mla_bmm1_scale = None
mla_bmm2_scale = None
quant_q_buffer = None
if has_fp8_kv_cache:
mla_bmm1_scale = torch.empty(2,
dtype=torch.float32,
device=q.device)
mla_bmm2_scale = torch.empty(1,
dtype=torch.float32,
device=q.device)
quant_q_buffer = torch.empty(
num_tokens,
self.num_heads_tp, (self.kv_lora_rank + self.qk_rope_head_dim),
dtype=torch.uint8,
device=q.device)
fused_q = torch.empty(
[
num_tokens, self.num_heads_tp,
(self.kv_lora_rank + self.qk_rope_head_dim)
],
dtype=q.dtype,
device=q.device,
)
rope_stream = self.aux_stream if not has_fp8_kv_cache else None
if self.k_b_proj_trans.dtype == torch.bfloat16:
# [num_heads, num_tokens, self.qk_nope_head_dim]
q_nope_t = q_nope.transpose(0, 1)
# [num_heads, num_tokens, self.kv_lora_rank]
q_nope_out = fused_q[..., :self.kv_lora_rank].transpose(0, 1)
# [num_heads, num_tokens, self.qk_nope_head_dim] x [num_heads, kv_lora_rank, qk_nope_head_dim]
# -> [num_heads, num_tokens, kv_lora_rank] -> [num_tokens, num_heads, kv_lora_rank]
# The output of bmm is written directly into fused_q
maybe_execute_in_parallel(
lambda: torch.ops.trtllm.bmm_out(
q_nope_t, self.k_b_proj_trans.transpose(1, 2), q_nope_out),
lambda: self.mqa.mla_rope_generation(
fused_q, q_pe, latent_cache, attn_metadata, cu_q_seqlens,
cu_kv_seqlens, fmha_scheduler_counter, mla_bmm1_scale,
mla_bmm2_scale, quant_q_buffer),
self.ln_events[0],
self.ln_events[1],
rope_stream,
)
elif self.k_b_proj_trans.dtype == torch.float8_e4m3fn:
# [num_heads, num_tokens, self.kv_lora_rank]
q_nope_out = fused_q[..., :self.kv_lora_rank].transpose(0, 1)
maybe_execute_in_parallel(
lambda: fp8_block_scaling_bmm_out(
q_nope,
self.k_b_proj_trans,
self.k_b_proj_trans_scale,
q_nope_out,
self.k_b_proj_trans_dequant,
),
lambda: self.mqa.mla_rope_generation(
fused_q, q_pe, latent_cache, attn_metadata, cu_q_seqlens,
cu_kv_seqlens, fmha_scheduler_counter, mla_bmm1_scale,
mla_bmm2_scale, quant_q_buffer),
self.ln_events[0],
self.ln_events[1],
rope_stream,
)
else:
raise NotImplementedError(
f"Missing bmm impl for dtype: {self.k_b_proj_trans.dtype}.")
fused_q = fused_q.view([
num_tokens,
self.num_heads_tp * (self.kv_lora_rank + self.qk_rope_head_dim)
])
# Use generation_only for generation phase and context_only for context phase in DSA attention
attention_input_type = AttentionInputType.generation_only
attn_out_latent = self._attn_forward_gen(
self.mqa,
fused_q,
None,
None,
position_ids,
attn_metadata,
attention_input_type=attention_input_type,
out_scale=self.out_scale,
latent_cache=latent_cache, # kvcache and k_pe
q_pe=q_pe, # used by `invokeMLARopeGeneration`
topk_indices=topk_indices, # used by DSA attention
is_generation=True, # used by DSA attention
cu_q_seqlens=cu_q_seqlens, # used by `mlaGeneration`
cu_kv_seqlens=cu_kv_seqlens, # used by `mlaGeneration`
fmha_scheduler_counter=
fmha_scheduler_counter, # used by `mlaGeneration`
mla_bmm1_scale=mla_bmm1_scale, # used by `mlaGeneration`
mla_bmm2_scale=mla_bmm2_scale, # used by `mlaGeneration`
quant_q_buffer=quant_q_buffer, # used by `mlaGeneration`
)
fused_q = None
# note: if we do not have CP, then num_heads_tp_cp == num_heads_tp
assert (attn_out_latent.shape[0] == q.shape[0]
and attn_out_latent.shape[1]
== self.num_heads_tp_cp * self.kv_lora_rank)
# [seq, num_heads, kv_lora_rank]
attn_out_latent = attn_out_latent.view(
[-1, self.num_heads_tp_cp, self.kv_lora_rank])
attn_output = output.view(
[num_tokens, self.num_heads_tp_cp, self.v_head_dim])
if self.v_b_proj.dtype == torch.bfloat16:
# [num_heads, seq, kv_lora_rank] x [num_heads, kv_lora_rank, v_head_dim]
# -> [num_heads, seq, v_head_dim]
torch.ops.trtllm.bmm_out(attn_out_latent.transpose(0, 1),
self.v_b_proj.transpose(1, 2),
attn_output.transpose(0, 1))
elif self.v_b_proj.dtype == torch.float8_e4m3fn:
fp8_block_scaling_bmm_out(
attn_out_latent,
self.v_b_proj,
self.v_b_proj_scale,
attn_output.transpose(0, 1),
self.v_b_proj_dequant,
)
else:
raise NotImplementedError(
f"Missing bmm impl for dtype: {self.v_b_proj.dtype}.")
return output
def forward_absorption_context(
self,
q: torch.Tensor,
compressed_kv: torch.Tensor,
k_pe: torch.Tensor,
attn_metadata: AttentionMetadata,
output: torch.Tensor,
position_ids: Optional[torch.Tensor] = None,
latent_cache: Optional[torch.Tensor] = None,
topk_indices: Optional[torch.Tensor] = None,
) -> torch.Tensor:
num_tokens = q.shape[0]
q_nope, q_pe = q.view([-1, self.num_heads_tp, self.qk_head_dim]).split(
[self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
# fused_q contains 1) the result of the following bmm with shape [num_tokens, num_heads, kv_lora_rank]
# 2) rope(q_pe) with shape [num_tokens, num_heads, qk_rope_head_dim]. rope is applied inside AttentionOp
fused_q = torch.empty(
[
num_tokens, self.num_heads_tp,
(self.kv_lora_rank + self.qk_rope_head_dim)
],
dtype=q.dtype,
device=q.device,
)
if self.k_b_proj_trans.dtype == torch.bfloat16:
# [num_heads, num_tokens, self.qk_nope_head_dim]
q_nope_t = q_nope.transpose(0, 1)
# [num_heads, num_tokens, self.kv_lora_rank]
q_nope_out = fused_q[..., :self.kv_lora_rank].transpose(0, 1)
            # [num_heads, num_tokens, qk_nope_head_dim] x [num_heads, qk_nope_head_dim, kv_lora_rank]
            # (k_b_proj_trans transposed on dims 1 and 2)
            # -> [num_heads, num_tokens, kv_lora_rank], i.e. [num_tokens, num_heads, kv_lora_rank] in fused_q
            # The output of the bmm is written directly into fused_q.
torch.ops.trtllm.bmm_out(q_nope_t,
self.k_b_proj_trans.transpose(1, 2),
q_nope_out)
elif self.k_b_proj_trans.dtype == torch.float8_e4m3fn:
# [num_heads, num_tokens, self.kv_lora_rank]
q_nope_out = fused_q[..., :self.kv_lora_rank].transpose(0, 1)
fp8_block_scaling_bmm_out(
q_nope,
self.k_b_proj_trans,
self.k_b_proj_trans_scale,
q_nope_out,
self.k_b_proj_trans_dequant,
)
else:
raise NotImplementedError(
f"Missing bmm impl for dtype: {self.k_b_proj_trans.dtype}.")
if self.apply_rotary_emb:
fused_q[..., self.kv_lora_rank:] = q_pe
fused_q = fused_q.view([
num_tokens,
self.num_heads_tp * (self.kv_lora_rank + self.qk_rope_head_dim)
])
# Use generation_only for generation phase and context_only for context phase in DSA attention
attention_input_type = AttentionInputType.context_only
attn_out_latent = self._attn_forward_gen(
self.mqa,
fused_q,
None,
None,
position_ids,
attn_metadata,
attention_input_type=attention_input_type,
out_scale=self.out_scale,
latent_cache=latent_cache, # kvcache and k_pe
q_pe=q_pe, # used by `invokeMLARopeGeneration`
topk_indices=topk_indices, # used by DSA attention
is_generation=False, # used by DSA attention
)
fused_q = None
# note: if we do not have CP, then num_heads_tp_cp == num_heads_tp
assert (attn_out_latent.shape[0] == q.shape[0]
and attn_out_latent.shape[1]
== self.num_heads_tp_cp * self.kv_lora_rank)
# [seq, num_heads, kv_lora_rank]
attn_out_latent = attn_out_latent.view(
[-1, self.num_heads_tp_cp, self.kv_lora_rank])
attn_output = output.view(
[num_tokens, self.num_heads_tp_cp, self.v_head_dim])
if self.v_b_proj.dtype == torch.bfloat16:
# [num_heads, seq, kv_lora_rank] x [num_heads, kv_lora_rank, v_head_dim]
# -> [num_heads, seq, v_head_dim]
torch.ops.trtllm.bmm_out(attn_out_latent.transpose(0, 1),
self.v_b_proj.transpose(1, 2),
attn_output.transpose(0, 1))
elif self.v_b_proj.dtype == torch.float8_e4m3fn:
fp8_block_scaling_bmm_out(
attn_out_latent,
self.v_b_proj,
self.v_b_proj_scale,
attn_output.transpose(0, 1),
self.v_b_proj_dequant,
)
else:
raise NotImplementedError(
f"Missing bmm impl for dtype: {self.v_b_proj.dtype}.")
return output
@nvtx_range("forward_sparse_mla_kvcache_bf16")
def forward_sparse_mla_kvcache_bf16(
self,
q: torch.Tensor,
latent_cache: torch.Tensor,
attn_metadata: DSAtrtllmAttentionMetadata,
output: torch.Tensor,
topk_indices: torch.Tensor,
is_generation: bool = False,
) -> torch.Tensor:
"""
        Forward for sparse MLA (DSA) with a BF16 KV cache, covering both the context and generation phases using FlashMLA kernels.
        To form the input for the FlashMLA kernel and adapt it to our KV cache manager, we need to:
1. Append current tokens to paged cache and apply rope to q/k via mla_rope_append_paged_kv_assign_q
2. Load full kv cache from paged memory (with k rope applied)
3. Call FlashMLA sparse attention kernel for sparse prefill/decode
"""
assert isinstance(attn_metadata, DSAtrtllmAttentionMetadata), \
"DSA requires DSAtrtllmAttentionMetadata"
# Append current tokens to paged cache and apply RoPE to q
# This writes latent_cache to paged KV and modifies q in-place
trtllm_attention = self.mqa
with nvtx_range_debug(
f"mla_rope_append_paged_kv_assign_q_is_generation={is_generation}"
):
trtllm_attention.mla_rope_append_paged_kv_assign_q(
q, latent_cache, attn_metadata, is_generation=is_generation)
num_tokens = q.shape[0]
q_nope, q_rope = q.view(-1, self.num_heads_tp, self.qk_head_dim).split(
[self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
q_nope_out = torch.empty(
[num_tokens, self.num_heads_tp, (self.kv_lora_rank)],
dtype=q.dtype,
device=q.device,
)
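        # q_nope_out will hold the absorbed (latent-space) query:
        # [num_tokens, num_heads_tp, kv_lora_rank]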
if self.k_b_proj_trans.dtype == torch.bfloat16:
# [num_heads, num_tokens, self.qk_nope_head_dim]
q_nope_t = q_nope.transpose(0, 1)
# [num_heads, num_tokens, self.kv_lora_rank]
q_nope_out = q_nope_out.transpose(0, 1)
            # [num_heads, num_tokens, qk_nope_head_dim] x [num_heads, qk_nope_head_dim, kv_lora_rank]
            # (k_b_proj_trans transposed on dims 1 and 2)
            # -> [num_heads, num_tokens, kv_lora_rank]
            # The output of the bmm is written directly into q_nope_out.
torch.ops.trtllm.bmm_out(q_nope_t,
self.k_b_proj_trans.transpose(1, 2),
q_nope_out)
elif self.k_b_proj_trans.dtype == torch.float8_e4m3fn:
# [num_heads, num_tokens, self.kv_lora_rank]
q_nope_out = q_nope_out.transpose(0, 1)
fp8_block_scaling_bmm_out(
q_nope,
self.k_b_proj_trans,
self.k_b_proj_trans_scale,
q_nope_out,
self.k_b_proj_trans_dequant,
)
else:
raise NotImplementedError(
f"Missing bmm impl for dtype: {self.k_b_proj_trans.dtype}.")
q_nope_out = q_nope_out.transpose(0, 1)
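        # Back to [num_tokens, num_heads_tp, kv_lora_rank]; concatenating q_rope yields a query
        # head_dim of kv_lora_rank + qk_rope_head_dim (e.g. 512 + 64 = 576 for DeepSeek-style MLA).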
q_concat = torch.cat([q_nope_out, q_rope], dim=-1)
sm_version = get_sm_version()
        # The FlashMLA sparse kernel (bf16) requires num_heads == 128 on SM100 and a multiple of 64 on SM90
if sm_version >= 100:
padding = 128
            assert self.num_heads_tp <= padding, (
                f"SM100 FlashMLA sparse kernel requires exactly {padding} heads after padding, "
                f"but got {self.num_heads_tp} > {padding}; padding down is not supported."
            )
else: # SM90
padding = ((self.num_heads_tp + 63) // 64) * 64 # multiple of 64
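            # e.g. num_heads_tp = 80 -> ((80 + 63) // 64) * 64 = 128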
if self.num_heads_tp != padding:
logger.warning_once(
f"Padding num_heads from {self.num_heads_tp} to {padding} "
f"due to FlashMLA sparse attention kernel requirement",
key="sparse_mla_padding_warning")
# Create padded tensor with zeros for extra heads
q_padded = q_concat.new_empty(
(num_tokens, padding, q_concat.shape[2]))
q_padded[:, :self.num_heads_tp, :] = q_concat
q_concat = q_padded
# Convert indices and return all-layer KV pool
# Note: underlying pool is layer-interleaved: [num_blocks, num_layers, kv_factor, tokens_per_block, num_kv_heads, head_dim]
        # To avoid a reshape (copy) of the per-layer KV cache, we return the all-layer KV pool
        # with topk indices adjusted by stride_factor = num_layers * tokens_per_block.
topk_indices_pool, kv_cache_pool = transform_local_topk_and_prepare_pool_view(
topk_indices,
attn_metadata,
layer_idx=self.layer_idx,
is_generation=is_generation,
)
topk_indices_pool = topk_indices_pool.view(num_tokens, 1, -1)
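        # topk_indices_pool holds per-token indices into the all-layer kv_cache_pool view
        # prepared above, reshaped to [num_tokens, 1, topk].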
if flash_mla_sparse_fwd is not None:
attn_out_latent = flash_mla_sparse_fwd(q_concat, kv_cache_pool,
topk_indices_pool,
self.softmax_scale)[0]
else:
raise RuntimeError(
"flash_mla_sparse_fwd not available. Please ensure FlashMLA module is built."
)
# [seq, num_heads, kv_lora_rank], account for padding
attn_out_latent = attn_out_latent[:, :self.num_heads_tp, :]
        # TODO: it seems we need .contiguous() here when padding is enabled, before passing to the bmm.
attn_out_latent = attn_out_latent.view(
[-1, self.num_heads_tp, self.kv_lora_rank])
assert (attn_out_latent.shape[0] == q.shape[0]
and attn_out_latent.shape[1] == self.num_heads_tp)
attn_output = output.view(
[num_tokens, self.num_heads_tp, self.v_head_dim])
if self.v_b_proj.dtype == torch.bfloat16:
# [num_heads, seq, kv_lora_rank] x [num_heads, kv_lora_rank, v_head_dim]
# -> [num_heads, seq, v_head_dim]
torch.ops.trtllm.bmm_out(attn_out_latent.transpose(0, 1),
self.v_b_proj.transpose(1, 2),
attn_output.transpose(0, 1))
elif self.v_b_proj.dtype == torch.float8_e4m3fn:
fp8_block_scaling_bmm_out(
attn_out_latent,
self.v_b_proj,
self.v_b_proj_scale,
attn_output.transpose(0, 1),
self.v_b_proj_dequant,
)
else:
raise NotImplementedError(
f"Missing bmm impl for dtype: {self.v_b_proj.dtype}.")
return output
def forward(
self,
position_ids: Optional[torch.Tensor],
hidden_states: torch.Tensor,
attn_metadata: AttentionMetadata,
all_reduce_params: Optional[AllReduceParams] = None,
latent_cache_gen: Optional[torch.Tensor] = None,
) -> torch.Tensor:
attn_output = self.create_output(hidden_states,
attn_metadata.num_contexts)
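        # The output buffer is preallocated here and filled in place by one of the
        # forward paths below.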
if self.is_dsa:
self.forward_impl_with_dsa(position_ids,
hidden_states,
attn_metadata,
output=attn_output)
elif self.register_to_config:
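            # Dispatch through the in-place MLA custom op registered for this layer;
            # the op re-resolves this layer from layer_idx_str.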
torch.ops.trtllm.mla_custom_op_inplace(hidden_states, position_ids,
self.layer_idx_str,
attn_output,
latent_cache_gen)
else:
self.forward_impl(position_ids,
hidden_states,
attn_metadata,
output=attn_output,
latent_cache_gen=latent_cache_gen)
if self.enable_unit_test and self.mapping.cp_size > 1:
            # note: to allow testing Helix parallelism, we ensure that
            # the output is compatible with o_proj even in the context phase,
            # so we slice it down to num_heads_tp_cp * v_head_dim
attn_output = attn_output[:, :self.num_heads_tp_cp *
self.v_head_dim].contiguous()
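        # Output projection back to the hidden size; the tensor-parallel all-reduce
        # is controlled by all_reduce_params.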
attn_output = self.o_proj(attn_output,
all_reduce_params=all_reduce_params)
return attn_output