TensorRT-LLMs/tensorrt_llm/_torch/modules/attention.py
Bo Li 515dd0d78f
feat: Add support for FP8 MLA on Hopper and Blackwell. (#3190)
* FP8 KV cache + BF16 context MLA + FP8 generation MLA

Use BF16 for context MLA.
mFP8GenerationMLA and mFP8ContextFMHA shouldn't be enabled together.

Allow mSM==90 for mFP8GenerationMLA==true.
For FMHA, dataTypeKv should be FP8.

For FP8 MLA generation, the output is still in BF16.

Refine debug info for FMHA kernel metadata.

Use inputType, outputType, SM together to hash kernel list.

Add FP8 MLA generation FMHA kernel.

Special workaround (WAR) for NUM_COMPUTE_GROUPS in the MLA generation kernel.

Move the implementation of fused_multihead_attention_v2.h into a .cpp file and print some debug info if checkIfKernelExist fails.

Refine debug info in fused_multihead_attention_v2.cpp

Correct FP8 MLA metadata.

New kernel provided by Yuxin, which outputs BF16.

The smem size was not set correctly, which led to illegal memory access.

Yuxin fixed the error in the FMHA MLA kernel: previously the BF16 output wasn't written correctly, with some parts written repeatedly while others were left untouched.

There are two bmm1 scales that should be set correctly.

New kernel generated by Yuxin.

Modifications to common/attentionOp for FP8 MLA on Hopper using FMHA.

Not necessary: if mFP8GenerationMLA is set, is_fp8_out is false, so mFP8ContextFMHA is false.

Skip a check in fmhaDispatcher.

Modifications in fmhaRunner:
- Debug dump.
- Guard a lot of flag setting with `if (!isFP8GenerationMLA)`.
- TMA descriptor modification for qo (by Yuxin).

Cleanup debug output.

Clean up o tma descriptor modifications.

Signed-off-by: Bo Li <bobboli0202@gmail.com>

* Resolve conflicts.

Signed-off-by: Bo Li <bobboli0202@gmail.com>

* Apply the patch of FP8 FlashMLA and resolve conflicts.

Signed-off-by: Bo Li <bobboli0202@gmail.com>

* Fix compilation error.

Signed-off-by: Bo Li <bobboli0202@gmail.com>

* Fix compile error.

Signed-off-by: Bo Li <bobboli0202@gmail.com>

* pick blackwell support

Signed-off-by: Dylan Chen <191843203+DylanChen-NV@users.noreply.github.com>

* Add copyright notice to fused_multihead_attention_v2.cpp.

Signed-off-by: Bo Li <bobboli0202@gmail.com>

* Add license.

Signed-off-by: Bo Li <bobboli0202@gmail.com>

* Add missing license.

Signed-off-by: Bo Li <bobboli0202@gmail.com>

* Exclude building flashMLA kernels under sm90.

Signed-off-by: Bo Li <bobboli0202@gmail.com>

* Revert "Exclude building flashMLA kernels under sm90."

    This reverts commit f0c859d459.

Signed-off-by: Bo Li <bobboli0202@gmail.com>

* Use macro to skip compiling FlashMLA for non sm90 targets.

Signed-off-by: Bo Li <bobboli0202@gmail.com>

---------

Signed-off-by: Bo Li <bobboli0202@gmail.com>
Signed-off-by: Dylan Chen <191843203+DylanChen-NV@users.noreply.github.com>
Co-authored-by: Dylan Chen <ziqingc@nvidia.com>
Co-authored-by: Dylan Chen <191843203+DylanChen-NV@users.noreply.github.com>
Co-authored-by: QI JUN <22017000+QiJune@users.noreply.github.com>
2025-04-07 15:14:13 +08:00


from typing import Optional

import torch
from torch import nn

from ..attention_backend import (AttentionInputType, AttentionMetadata,
                                 TrtllmAttention)
from ..attention_backend.interface import (PositionalEmbeddingParams,
                                           PredefinedAttentionMask)
from ..attention_backend.utils import create_attention
from ..distributed import AllReduceParams, ParallelConfig, TensorParallelMode
from ..model_config import ModelConfig
from .linear import Linear, WeightMode, WeightsLoadingConfig
from .rms_norm import RMSNorm
from .rotary_embedding import RotaryEmbedding

class Attention(nn.Module):

    def __init__(
        self,
        *,
        hidden_size: int,
        num_attention_heads: int,
        num_key_value_heads: int,
        max_position_embeddings: int,
        bias: bool,
        pos_embd_params: Optional[PositionalEmbeddingParams] = None,
        rotary_emb: Optional[RotaryEmbedding] = None,
        layer_idx: Optional[int] = None,
        dtype: torch.dtype = None,
        dense_bias: Optional[bool] = None,
        config: Optional[ModelConfig] = None,
    ):
        super().__init__()
        self.layer_idx = layer_idx
        self.hidden_size = hidden_size
        self.num_heads = num_attention_heads
        self.head_dim = self.hidden_size // self.num_heads
        self.num_key_value_heads = num_key_value_heads
        self.num_key_value_groups = self.num_heads // self.num_key_value_heads
        self.max_position_embeddings = max_position_embeddings
        self.pos_embd_params = pos_embd_params
        self.dense_bias = dense_bias
        if dense_bias is None:
            self.dense_bias = bias

        if (self.head_dim * self.num_heads) != self.hidden_size:
            raise ValueError(
                f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
                f" and `num_heads`: {self.num_heads}).")

        # tensor parallel
        config = config or ModelConfig()
        tp_size = config.mapping.tp_size
        tp_rank = config.mapping.tp_rank
        gpus_per_node = config.mapping.gpus_per_node
        if config.mapping.enable_attention_dp:
            tp_size = 1
            tp_rank = 0
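        # Shard attention heads across the TP group: each rank keeps
        # num_heads // tp_size query heads and a ceil-divided share of the
        # KV heads (with attention DP, tp_size is 1 and nothing is split).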
        assert self.num_heads % tp_size == 0
        self.num_heads = self.num_heads // tp_size
        self.num_key_value_heads = (self.num_key_value_heads + tp_size -
                                    1) // tp_size
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_key_value_heads * self.head_dim

        self.qkv_proj = Linear(
            self.hidden_size,
            tp_size * self.q_size + 2 * tp_size * self.kv_size,
            bias=bias,
            dtype=dtype,
            parallel_config=ParallelConfig(
                tensor_parallel_size=tp_size,
                tensor_parallel_rank=tp_rank,
                tensor_parallel_mode=TensorParallelMode.COLUMN,
                gpus_per_node=gpus_per_node,
                pipeline_parallel_size=config.mapping.pp_size,
                parallel_rank=config.mapping.rank),
            weights_loading_config=WeightsLoadingConfig(
                weight_mode=WeightMode.FUSED_QKV_LINEAR),
            quant_config=config.get_quant_config(),
            skip_create_weights=config.skip_create_weights,
        )
        self.o_proj = Linear(
            self.hidden_size,
            self.hidden_size,
            bias=self.dense_bias,
            dtype=dtype,
            parallel_config=ParallelConfig(
                tensor_parallel_size=tp_size,
                tensor_parallel_rank=tp_rank,
                tensor_parallel_mode=TensorParallelMode.ROW,
                gpus_per_node=gpus_per_node,
                pipeline_parallel_size=config.mapping.pp_size,
                parallel_rank=config.mapping.rank),
            quant_config=config.get_quant_config(),
            skip_create_weights=config.skip_create_weights,
        )

        self.quant_config = config.get_quant_config()
        self.attn_backend = config.attn_backend
        self.pos_embd_params = pos_embd_params
        self.rotary_emb = rotary_emb

        if not config.skip_create_weights:
            self.create_weights()

    def create_weights(self):
        self.attn = create_attention(
            self.attn_backend,
            self.layer_idx,
            self.num_heads,
            self.head_dim,
            self.num_key_value_heads,
            pos_embd_params=self.pos_embd_params,
            quant_config=self.quant_config,
        )

    def forward(
        self,
        position_ids: Optional[torch.LongTensor],
        hidden_states: torch.Tensor,
        attn_metadata: AttentionMetadata,
        attention_mask: PredefinedAttentionMask = PredefinedAttentionMask.
        CAUSAL,
        mrope_config: Optional[dict] = None,
        **kwargs,
    ) -> torch.Tensor:
        qkv = self.qkv_proj(hidden_states)

        is_fused_qkv = False
        if isinstance(self.attn, TrtllmAttention):
            is_fused_qkv = True
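        # The TRT-LLM backend consumes a single fused QKV tensor; other
        # backends take separate q/k/v tensors.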
        if is_fused_qkv:
            if self.pos_embd_params is None and position_ids is not None:
                q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size],
                                    dim=-1)
                q, k = self.rotary_emb(
                    position_ids,
                    [q.contiguous(), k.contiguous()], attn_metadata)
                qkv = torch.concat(
                    [q.contiguous(),
                     k.contiguous(),
                     v.contiguous()], dim=-1)

            out_scale = None
            if self.o_proj.has_fp8_qdq or self.o_proj.has_nv_fp4 or self.o_proj.has_fp8_block_scales:
                out_scale = self.o_proj.inv_input_scale

            attn_output = self.attn.forward(qkv,
                                            None,
                                            None,
                                            attn_metadata,
                                            out_scale=out_scale,
                                            attention_mask=attention_mask,
                                            mrope_config=mrope_config)
        else:
            q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size],
                                dim=-1)
            if self.pos_embd_params is None and position_ids is not None:
                q, k = self.rotary_emb(
                    position_ids,
                    [q.contiguous(), k.contiguous()], attn_metadata)

            attn_output = self.attn.forward(q.contiguous(),
                                            k.contiguous(),
                                            v.contiguous(),
                                            attn_metadata,
                                            attention_mask=attention_mask,
                                            mrope_config=mrope_config)

        attn_output = self.o_proj(attn_output)
        return attn_output

class MLA(nn.Module):
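    # Multi-head Latent Attention (MLA): `fused_a` projects hidden states into
    # a low-rank query latent (q_lora_rank, skipped in "lite" mode), a
    # compressed KV latent (kv_lora_rank) and a decoupled RoPE part (k_pe).
    # The context phase expands the latent via kv_b_proj and runs regular MHA
    # (self.mha); the generation phase attends directly over the compressed
    # latent as MQA (self.mqa) using the absorbed projections below.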

    def __init__(
        self,
        *,
        hidden_size: int,
        num_attention_heads: int,
        num_key_value_heads: int,
        qk_nope_head_dim: int,
        qk_rope_head_dim: int,
        v_head_dim: int,
        q_lora_rank: int,
        kv_lora_rank: int,
        predicted_tokens_per_seq: int,
        max_position_embeddings: int,
        bias: bool,
        aux_stream: Optional[torch.cuda.Stream] = None,
        pos_embd_params: Optional[PositionalEmbeddingParams] = None,
        rotary_emb: Optional[RotaryEmbedding] = None,
        layer_idx: Optional[int] = None,
        dtype: torch.dtype = None,
        dense_bias: Optional[bool] = None,
        config: Optional[ModelConfig] = None,
    ):
        super().__init__()
        self.layer_idx = layer_idx
        self.hidden_size = hidden_size
        self.num_heads = num_attention_heads
        self.head_dim = self.hidden_size // self.num_heads
        self.num_key_value_heads = num_key_value_heads
        self.num_key_value_groups = self.num_heads // self.num_key_value_heads
        self.qk_nope_head_dim = qk_nope_head_dim
        self.qk_rope_head_dim = qk_rope_head_dim
        self.v_head_dim = v_head_dim
        self.q_lora_rank = q_lora_rank
        self.kv_lora_rank = kv_lora_rank
        self.predicted_tokens_per_seq = predicted_tokens_per_seq
        self.max_position_embeddings = max_position_embeddings
        self.pos_embd_params = pos_embd_params
        self.dense_bias = dense_bias
        if dense_bias is None:
            self.dense_bias = bias

        if self.q_lora_rank is None:
            self.q_lora_rank = hidden_size
            self.is_lite = True
        else:
            self.is_lite = False

        if (self.head_dim * self.num_heads) != self.hidden_size:
            raise ValueError(
                f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
                f" and `num_heads`: {self.num_heads}).")

        # tensor parallel
        config = config or ModelConfig()
        tp_size = config.mapping.tp_size
        tp_rank = config.mapping.tp_rank
        gpus_per_node = config.mapping.gpus_per_node
        if config.mapping.enable_attention_dp:
            tp_size = 1
            tp_rank = 0

        row_parallel_config = ParallelConfig(
            tensor_parallel_rank=tp_rank,
            tensor_parallel_size=tp_size,
            tensor_parallel_mode=TensorParallelMode.ROW,
            gpus_per_node=gpus_per_node,
            pipeline_parallel_size=config.mapping.pp_size,
            parallel_rank=config.mapping.rank,
        )
        col_parallel_config = ParallelConfig(
            tensor_parallel_rank=tp_rank,
            tensor_parallel_size=tp_size,
            tensor_parallel_mode=TensorParallelMode.COLUMN,
            gpus_per_node=gpus_per_node,
            pipeline_parallel_size=config.mapping.pp_size,
            parallel_rank=config.mapping.rank,
        )

        assert self.num_heads % tp_size == 0
        self.num_heads = self.num_heads // tp_size
        self.num_key_value_heads = (self.num_key_value_heads + tp_size -
                                    1) // tp_size

        rms_norm_eps = config.pretrained_config.rms_norm_eps
        quant_config = config.get_quant_config()
        quant_mode = quant_config.quant_mode

        if not self.is_lite:
            self.fused_a = Linear(
                hidden_size,
                self.q_lora_rank + self.kv_lora_rank + self.qk_rope_head_dim,
                bias=bias,
                dtype=dtype,
                quant_config=quant_config,
                skip_create_weights=config.skip_create_weights,
                use_custom_cublas_mm=True)

            self.q_a_layernorm = RMSNorm(hidden_size=self.q_lora_rank,
                                         eps=rms_norm_eps,
                                         dtype=dtype)

            self.q_b_proj = Linear(
                self.q_lora_rank,
                tp_size * self.num_heads *
                (self.qk_nope_head_dim + self.qk_rope_head_dim),
                bias=bias,
                dtype=dtype,
                parallel_config=col_parallel_config,
                quant_config=quant_config,
                skip_create_weights=config.skip_create_weights)
        else:
            self.fused_a = Linear(
                hidden_size,
                self.kv_lora_rank + self.qk_rope_head_dim,
                bias=bias,
                dtype=dtype,
                quant_config=quant_config,
                skip_create_weights=config.skip_create_weights,
                use_custom_cublas_mm=True)

            self.q_proj = Linear(
                self.q_lora_rank,
                tp_size * self.num_heads *
                (self.qk_nope_head_dim + self.qk_rope_head_dim),
                bias=bias,
                dtype=dtype,
                parallel_config=col_parallel_config,
                quant_config=quant_config,
                skip_create_weights=config.skip_create_weights)
            self.q_b_proj = self.q_proj

        self.kv_a_layernorm = RMSNorm(hidden_size=kv_lora_rank,
                                      dtype=dtype,
                                      eps=rms_norm_eps)

        if quant_mode.has_fp8_block_scales():
            mla_weight_dtype = torch.float8_e4m3fn
        else:
            mla_weight_dtype = dtype

        self.kv_b_proj = Linear(self.kv_lora_rank,
                                tp_size * self.num_heads *
                                (self.qk_nope_head_dim + self.v_head_dim),
                                bias=bias,
                                dtype=dtype,
                                parallel_config=col_parallel_config,
                                quant_config=quant_config,
                                skip_create_weights=config.skip_create_weights)

        # This parameter will view into self.kv_b_proj.weight after loading weights.
        # For dummy weight initialization, this parameter is initialized with empty tensor.
        self.v_b_proj = nn.Parameter(
            torch.empty(
                (self.num_heads, self.v_head_dim, self.kv_lora_rank),
                dtype=dtype,
            ),
            requires_grad=False,
        )
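        # Transposed k_nope block of kv_b_proj (derived from its weights during
        # weight loading); it lets generation project q_nope directly into the
        # kv_lora_rank latent space (weight absorption).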
        self.k_b_proj_trans = nn.Parameter(
            torch.empty(
                (self.num_heads, self.kv_lora_rank, self.qk_nope_head_dim),
                dtype=mla_weight_dtype,
            ),
            requires_grad=False,
        )

        if quant_mode.has_fp8_block_scales():
            self.k_b_proj_trans_scale = nn.Parameter(
                torch.empty(
                    (
                        self.num_heads,
                        self.kv_lora_rank // 128,
                        self.qk_nope_head_dim // 128,
                    ),
                    dtype=torch.float32,
                ),
                requires_grad=False,
            )
            # This parameter will view into self.kv_b_proj.weight_scale after loading weights.
            # For dummy weight initialization, this parameter is initialized with empty tensor.
            self.v_b_proj_scale = nn.Parameter(
                torch.empty(
                    (
                        self.num_heads,
                        self.v_head_dim // 128,
                        self.kv_lora_rank // 128,
                    ),
                    dtype=torch.float32,
                ),
                requires_grad=False,
            )
        else:
            self.k_b_proj_trans_scale = None
            self.v_b_proj_scale = None

        self.o_proj = Linear(
            self.num_key_value_heads * self.v_head_dim * tp_size,
            self.hidden_size,
            bias=self.dense_bias,
            dtype=dtype,
            parallel_config=row_parallel_config,
            quant_config=quant_config,
            skip_create_weights=config.skip_create_weights,
        )
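        # Two attention ops back this module: `mha` runs the context phase on
        # expanded heads (head_dim = qk_nope_head_dim + qk_rope_head_dim),
        # while `mqa` runs the generation phase over the compressed latent
        # with a single KV head (head_dim = kv_lora_rank + qk_rope_head_dim).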
        self.mha = create_attention(
            config.attn_backend,
            self.layer_idx,
            self.num_heads,
            self.qk_nope_head_dim + self.qk_rope_head_dim,
            self.num_key_value_heads,
            pos_embd_params=pos_embd_params,
            quant_config=quant_config,
            is_mla_enable=True,
            q_lora_rank=self.q_lora_rank,
            kv_lora_rank=self.kv_lora_rank,
            qk_nope_head_dim=self.qk_nope_head_dim,
            qk_rope_head_dim=self.qk_rope_head_dim,
            v_head_dim=self.v_head_dim,
            predicted_tokens_per_seq=self.predicted_tokens_per_seq,
        )

        self.mqa = create_attention(
            config.attn_backend,
            self.layer_idx,
            self.num_heads,
            self.kv_lora_rank + self.qk_rope_head_dim,
            1,  # num_kv_heads
            pos_embd_params=pos_embd_params,
            quant_config=quant_config,
            is_mla_enable=True,
            q_lora_rank=self.q_lora_rank,
            kv_lora_rank=self.kv_lora_rank,
            qk_nope_head_dim=self.qk_nope_head_dim,
            qk_rope_head_dim=self.qk_rope_head_dim,
            v_head_dim=self.kv_lora_rank,
            predicted_tokens_per_seq=self.predicted_tokens_per_seq,
        )

        self.rotary_emb = rotary_emb
        self.aux_stream = aux_stream
        self.ln_events = [torch.cuda.Event(), torch.cuda.Event()]

    def forward(
        self,
        position_ids: Optional[torch.LongTensor],
        hidden_states: torch.Tensor,
        attn_metadata: AttentionMetadata,
        all_reduce_params: Optional[AllReduceParams] = None,
    ) -> torch.Tensor:
        if self.is_lite:
            compressed_kv, k_pe = self.fused_a(hidden_states).split(
                [self.kv_lora_rank, self.qk_rope_head_dim], -1)
            compressed_kv = self.kv_a_layernorm(compressed_kv)
            compressed_q = hidden_states
        else:
            compressed_q, compressed_kv, k_pe = self.fused_a(
                hidden_states).split([
                    self.q_lora_rank, self.kv_lora_rank, self.qk_rope_head_dim
                ], -1)
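            # The q and kv RMSNorms are independent, so during CUDA graph
            # capture the q-side norm is overlapped on the auxiliary stream.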
            do_multi_stream = torch.cuda.is_current_stream_capturing(
            ) and self.aux_stream is not None
            if do_multi_stream:
                self.ln_events[0].record()
                compressed_kv = self.kv_a_layernorm(compressed_kv)
                with torch.cuda.stream(self.aux_stream):
                    self.ln_events[0].wait()
                    compressed_q = self.q_a_layernorm(compressed_q)
                    self.ln_events[1].record()
                self.ln_events[1].wait()
            else:
                compressed_q = self.q_a_layernorm(compressed_q)
                compressed_kv = self.kv_a_layernorm(compressed_kv)

        q = self.q_b_proj(compressed_q)

        # split q, k, v into context and gen batches
        num_contexts = attn_metadata.num_contexts
        num_generations = attn_metadata.num_generations
        num_ctx_tokens = attn_metadata.num_ctx_tokens
        num_tokens = attn_metadata.num_tokens

        assert q.shape[
            0] == num_tokens, f"Expect q.shape[0] to be {num_tokens}, but got {q.shape[0]}"

        if num_contexts > 0:
            q_ctx = q[:num_ctx_tokens, ...]
            compressed_kv_ctx = compressed_kv[:num_ctx_tokens, ...]
            k_pe_ctx = k_pe[:num_ctx_tokens, ...]

            attn_output_context = self.forward_context(q_ctx, compressed_kv_ctx,
                                                       k_pe_ctx, attn_metadata)
        else:
            attn_output_context = None

        if num_generations > 0:
            q_gen = q[num_ctx_tokens:, ...]
            compressed_kv_gen = compressed_kv[num_ctx_tokens:, ...]
            k_pe_gen = k_pe[num_ctx_tokens:, ...]

            attn_output_gen = self.forward_generation(q_gen, compressed_kv_gen,
                                                      k_pe_gen, attn_metadata)
        else:
            attn_output_gen = None

        # merge context and gen batches
        if attn_output_context is not None and attn_output_gen is not None:
            assert (
                len(attn_output_context.shape) == 2
            ), f"attn_output_context must be rank 2, not {len(attn_output_context.shape)}"
            assert (
                len(attn_output_gen.shape) == 2
            ), f"attn_output_gen must be rank 2, not {len(attn_output_gen.shape)}"
            attn_output = torch.cat([attn_output_context, attn_output_gen],
                                    dim=0)
        elif attn_output_gen is None:
            attn_output = attn_output_context
        else:
            attn_output = attn_output_gen

        attn_output = self.o_proj(attn_output,
                                  all_reduce_params=all_reduce_params)
        return attn_output

    def forward_context(
        self,
        q: torch.Tensor,
        compressed_kv: torch.Tensor,
        k_pe: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
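        # compressed_kv + k_pe form the latent cache handed to the attention
        # backend, which uses it to populate the KV cache; the context phase
        # itself runs BF16 MHA on the expanded k/v (see out_scale below).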
        latent_cache = torch.cat([compressed_kv, k_pe], dim=-1)
        kv = self.kv_b_proj(compressed_kv)
        k_nope, v = kv.split(
            [
                self.num_heads * self.qk_nope_head_dim,
                self.num_heads * self.v_head_dim
            ],
            -1,
        )

        k = torch.empty_like(q).view(
            -1, self.num_heads, (self.qk_nope_head_dim + self.qk_rope_head_dim))
        k[..., :self.qk_nope_head_dim] = k_nope.view(-1, self.num_heads,
                                                     self.qk_nope_head_dim)
        k = k.view(
            -1,
            self.num_heads * (self.qk_nope_head_dim + self.qk_rope_head_dim))

        # Concat q(including q_pe), k + k_pe, v together as input_qkv
        input_qkv = torch.cat([q, k, v], dim=-1)

        # out_scale = getattr(self.o_proj, "inv_input_scale", None)
        out_scale = None  # Currently we use BF16 MHA for context phase

        attn_output = self.mha.forward(
            input_qkv,
            None,
            None,
            attn_metadata,
            attention_input_type=AttentionInputType.context_only,
            latent_cache=latent_cache,
            out_scale=out_scale,
        )

        return attn_output

    def forward_generation(
        self,
        q: torch.Tensor,
        compressed_kv: torch.Tensor,
        k_pe: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        num_tokens = q.shape[0]
        latent_cache = torch.concat([compressed_kv, k_pe], dim=-1)
        q_nope, q_pe = q.view([
            -1, self.num_heads, self.qk_nope_head_dim + self.qk_rope_head_dim
        ]).split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
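        # Weight absorption: q_nope is projected into the kv_lora_rank latent
        # space via k_b_proj_trans, so generation attends directly over the
        # compressed KV cache instead of expanding k/v per token.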
        # fused_q contains 1) the result of the following bmm with shape [num_tokens, num_heads, kv_lora_rank]
        # 2) rope(q_pe) with shape [num_tokens, num_heads, qk_rope_head_dim]. rope is applied inside AttentionOp
        fused_q = torch.empty(
            [
                num_tokens, self.num_heads,
                (self.kv_lora_rank + self.qk_rope_head_dim)
            ],
            dtype=q.dtype,
            device=q.device,
        )

        if self.k_b_proj_trans.dtype == torch.bfloat16:
            # [num_heads, num_tokens, self.qk_nope_head_dim]
            q_nope = q_nope.transpose(0, 1)
            # [num_heads, num_tokens, self.kv_lora_rank]
            q_nope_out = fused_q[..., :self.kv_lora_rank].transpose(0, 1)
            # [num_heads, num_tokens, self.qk_nope_head_dim] x [num_heads, kv_lora_rank, qk_nope_head_dim]
            # -> [num_heads, num_tokens, kv_lora_rank] -> [num_tokens, num_heads, kv_lora_rank]
            # The output of bmm is written directly into fused_q
            torch.ops.trtllm.bmm_out(q_nope,
                                     self.k_b_proj_trans.transpose(1, 2),
                                     q_nope_out)
        elif self.k_b_proj_trans.dtype == torch.float8_e4m3fn:
            q_nope, q_nope_scales = torch.ops.trtllm.fp8_batched_quantize_1x128_permute102(
                q_nope)
            # [num_heads, num_tokens, self.kv_lora_rank]
            q_nope_out = fused_q[..., :self.kv_lora_rank].transpose(0, 1)
            torch.ops.trtllm.fp8_block_scaling_bmm_out(
                q_nope, self.k_b_proj_trans, q_nope_scales,
                self.k_b_proj_trans_scale, q_nope_out)
        else:
            raise NotImplementedError(
                f"Missing bmm impl for dtype: {self.k_b_proj_trans.dtype}.")

        fused_q = fused_q.view([
            num_tokens,
            self.num_heads * (self.kv_lora_rank + self.qk_rope_head_dim)
        ])

        # out_scale = getattr(self.o_proj, "inv_input_scale", None)
        out_scale = None  # Although we use FP8 MLA for generation phase, the output is still in BF16

        attn_out_latent = self.mqa.forward(
            fused_q,
            None,
            None,
            attn_metadata,
            attention_input_type=AttentionInputType.generation_only,
            out_scale=out_scale,
            latent_cache=latent_cache,  # kvcache and k_pe
            q_pe=q_pe,  # used by `invokeMLARopeGeneration`
        )

        assert (attn_out_latent.shape[0] == q.shape[0] and
                attn_out_latent.shape[1] == self.num_heads * self.kv_lora_rank)

        # [seq, num_heads, kv_lora_rank]
        attn_out_latent = attn_out_latent.view(
            [-1, self.num_heads, self.kv_lora_rank])

        attn_output = torch.empty([num_tokens, self.num_heads, self.v_head_dim],
                                  dtype=attn_out_latent.dtype,
                                  device=attn_out_latent.device)

        if self.v_b_proj.dtype == torch.bfloat16:
            # [num_heads, seq, kv_lora_rank] x [num_heads, kv_lora_rank, v_head_dim]
            # -> [num_heads, seq, v_head_dim]
            torch.ops.trtllm.bmm_out(attn_out_latent.transpose(0, 1),
                                     self.v_b_proj.transpose(1, 2),
                                     attn_output.transpose(0, 1))
        elif self.v_b_proj.dtype == torch.float8_e4m3fn:
            attn_out_latent, attn_out_latent_scales = torch.ops.trtllm.fp8_batched_quantize_1x128_permute102(
                attn_out_latent)
            torch.ops.trtllm.fp8_block_scaling_bmm_out(
                attn_out_latent, self.v_b_proj, attn_out_latent_scales,
                self.v_b_proj_scale, attn_output.transpose(0, 1))
        else:
            raise NotImplementedError(
                f"Missing bmm impl for dtype: {self.v_b_proj.dtype}.")

        # [seq, num_heads * v_head_dim]
        attn_output_flatten = attn_output.flatten(1, 2)
        return attn_output_flatten