TensorRT-LLMs/cpp/kernels/fmha_v2/train_ops/te_mha.py
# SPDX-FileCopyrightText: Copyright (c) 2023-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: NVIDIA TensorRT Source Code License Agreement
#
# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
# property and proprietary rights in and to this material, related
# documentation and any modifications thereto. Any use, reproduction,
# disclosure or distribution of this material and related documentation
# without an express license agreement from NVIDIA CORPORATION or
# its affiliates is strictly prohibited.
from typing import Any, Dict, Tuple, Union
import fp8_mha_api
import torch
import transformer_engine.pytorch.cpp_extensions as ext
import transformer_engine.pytorch.fp8 as fp8
import transformer_engine_extensions as tex
from torch.nn.parameter import Parameter
from transformer_engine.pytorch.module import TransformerEngineBaseModule
_CUBLASLT_WORKSPACE_SIZE_BYTES = 33_554_432 # 32MiB
_2X_ACC_FPROP = False
_2X_ACC_DGRAD = True
_2X_ACC_WGRAD = True
# FP8Tensors indices used by this module (8 of the predefined entries):
# GEMM1_INPUT  - unrelated
# GEMM1_WEIGHT - unrelated
# GEMM2_WEIGHT - unrelated
# GRAD_OUTPUT2
# GEMM1_OUTPUT - should be QKV
# GEMM2_INPUT  - should be O
# GRAD_INPUT1  - should be dO
# GRAD_OUTPUT1 - should be dQKV
# Two additional indices are needed beyond the predefined enum:
#   S  -> META_S  (10)
#   dP -> META_DP (11)
# Make sure no unintended scales are accessed.
for name in tex.FP8Tensors.__entries:
val = int(tex.FP8Tensors.__dict__[name])
if val >= 10:
print(name, val)
assert all([
int(tex.FP8Tensors.__dict__[name]) < 10 for name in tex.FP8Tensors.__entries
])
# Map names to make it easier to read.
META_QKV = tex.FP8Tensors.GEMM1_OUTPUT
META_O = tex.FP8Tensors.GEMM2_INPUT
META_DO = tex.FP8Tensors.GRAD_INPUT1
META_DQKV = tex.FP8Tensors.GRAD_OUTPUT1
# New scales.
META_S = 10
META_DP = 11 #TODO this is E5M2!
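# Illustrative example of how these indices address the fp8_meta scaling tensors
# (this mirrors the arguments passed to fp8_mha_api.fwd/bwd further below):
#   d_scale_qkv = fp8_meta["scaling"].scale_inv[META_QKV]
#   q_scale_o   = fp8_meta["scaling"].scale[META_O]
#   amax_s      = fp8_meta["scaling"].amax_history[0][META_S]
#   amax_o      = fp8_meta["scaling"].amax_history[0][META_O]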
class _MHA(torch.autograd.Function):
@staticmethod
def forward(ctx, inp: torch.Tensor, qkv_weight: torch.Tensor,
qkv_bias: torch.Tensor, proj_weight: torch.Tensor,
proj_bias: torch.Tensor, cu_seqlens: torch.Tensor,
num_attention_heads: int, p_dropout: float, max_s: int,
set_zero: bool, fp8_meta: Dict[str, Any],
workspace: torch.Tensor, is_training: bool) -> torch.Tensor:
assert inp.dim() == 2
# Make sure input dimensions are compatible
in_features = qkv_weight.shape[-1]
assert inp.shape[-1] == in_features, "GEMM not possible"
h = num_attention_heads
d = in_features // h
n_tokens = inp.shape[0]
fp8_dtype_forward = fp8.get_fp8_te_dtype(fp8_meta["recipe"],
fprop_tensor=True)
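# Pad the token dimension up to a multiple of 256 before the FP8 GEMMs
# (presumably an alignment requirement of the kernels); the padding is
# stripped from qkv_out and from the final projection output before returning.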
npad = 256 - (n_tokens % 256)
if npad < 256:
inp = torch.nn.functional.pad(inp, (0, 0, 0, npad))
inputmat, inputmat_t = ext.fp8_cast_transpose_fused(
inp,
fp8_meta["scaling"],
tex.FP8Tensors.GEMM1_INPUT,
fp8_dtype_forward,
)
ext.fp8_cast_transpose_fused(
qkv_weight,
fp8_meta["scaling"],
tex.FP8Tensors.GEMM1_WEIGHT,
fp8_dtype_forward,
cast_out=qkv_weight.cast,
transpose_out=qkv_weight.transposed,
)
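# The QKV GEMM writes its output directly in FP8 (int8 storage), quantized with
# the GEMM1_OUTPUT scale, so it can be consumed by the FP8 FMHA kernel without
# an intermediate FP16 cast.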
qkv_out = torch.empty(
inputmat.shape[0],
qkv_weight.shape[0],
dtype=torch.int8,
device="cuda",
)
ext.fp8_gemm(
qkv_weight.cast,
tex.FP8Tensors.GEMM1_WEIGHT,
fp8_dtype_forward,
inputmat,
tex.FP8Tensors.GEMM1_INPUT,
fp8_dtype_forward,
fp8_meta["scaling"],
torch.int8,
workspace,
bias=qkv_bias,
use_bias=True,
out=qkv_out,
out_index=tex.FP8Tensors.GEMM1_OUTPUT,
use_split_accumulator=_2X_ACC_FPROP,
)
##################FP8_FMHA change begins for FPROP ##############################
#### [FP8_FMHA] The FP16 path below (cast QKV to FP16, then run the FP16 FMHA) is replaced by a direct FP8 FMHA call.
#qkv_out = ext.cast_from_fp8(
# qkv_out,
# fp8_meta["scaling"],
# tex.FP8Tensors.GEMM1_OUTPUT,
# fp8_dtype_forward,
# ext.TE_DType[torch.float16]
#)
#qkv_out = qkv_out[:n_tokens,:]
## FMHA
#b = cu_seqlens.numel() - 1
#is_nl = False
#if b < 4 and b > 1:
# max_s = 512
# is_nl = True
#qkv_out = qkv_out.view(-1, 3, h, d)
#context, S_dmask = fmha.fwd(qkv_out, cu_seqlens, p_dropout, max_s, is_training, is_nl, set_zero, None)
#context = context.view(-1, in_features)
#if npad < 256:
# context = torch.nn.functional.pad(context, (0,0,0,npad))
#context, context_t = ext.fp8_cast_transpose_fused(
# context,
# fp8_meta["scaling"],
# tex.FP8Tensors.GEMM2_INPUT,
# fp8_dtype_forward,
#)
qkv_out = qkv_out[:n_tokens, :]
qkv_out = qkv_out.view(-1, 3, h, d)
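# Capture the current RNG state so backward can restore it before the FP8 FMHA
# dgrad call (presumably to reproduce the same dropout pattern as forward).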
rng_state = torch.get_rng_state()
context_, M, Z = fp8_mha_api.fwd(
qkv_out,
cu_seqlens,
fp8_meta["scaling"].scale_inv[META_QKV], #d_scale_qkv
fp8_meta["scaling"].scale[META_O], #q_scale_o
fp8_meta["scaling"].amax_history[0][META_S], #amax_s
fp8_meta["scaling"].amax_history[0][META_O], #amax_o
p_dropout,
max_s,
is_training,
set_zero,
None, # gen
)
context = context_.view(-1, in_features)
if npad < 256:
context = torch.nn.functional.pad(context, (0, 0, 0, npad))
# Unfortunately this transpose cannot be avoided: the transposed context tensor is needed for the projection wgrad GEMM in backward.
context_t = tex.fp8_transpose(
context,
fp8_dtype_forward,
)
##################FP8_FMHA change ends for FPROP ##############################
ext.fp8_cast_transpose_fused(
proj_weight,
fp8_meta["scaling"],
tex.FP8Tensors.GEMM2_WEIGHT,
fp8_dtype_forward,
cast_out=proj_weight.cast,
transpose_out=proj_weight.transposed,
)
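# The projection GEMM dequantizes back to FP16, so the module returns a regular
# half-precision tensor.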
proj_out = ext.fp8_gemm(
proj_weight.cast,
tex.FP8Tensors.GEMM2_WEIGHT,
fp8_dtype_forward,
context,
tex.FP8Tensors.GEMM2_INPUT,
fp8_dtype_forward,
fp8_meta["scaling"],
torch.float16,
workspace,
bias=proj_bias,
use_bias=True,
use_split_accumulator=_2X_ACC_FPROP,
)
proj_out = proj_out[:n_tokens, :]
ctx.save_for_backward(
inputmat_t,
qkv_weight,
workspace,
fp8_meta["scaling"].scale_inv[
tex.FP8Tensors.GEMM1_WEIGHT].clone().detach(),
fp8_meta["scaling"].scale_inv[
tex.FP8Tensors.GEMM1_INPUT].clone().detach(),
qkv_out,
M,
Z, #S_dmask,
context_,
context_t,
proj_weight,
fp8_meta["scaling"].scale_inv[
tex.FP8Tensors.GEMM2_WEIGHT].clone().detach(),
fp8_meta["scaling"].scale_inv[
tex.FP8Tensors.GEMM2_INPUT].clone().detach(),
#TODO remove duplicates.
fp8_meta["scaling"].scale_inv[META_QKV].clone().detach(
), # d_scale_qkv
fp8_meta["scaling"].scale_inv[META_S].clone().detach(), # d_scale_s
fp8_meta["scaling"].scale_inv[META_O].clone().detach(), # d_scale_o
fp8_meta["scaling"].scale[META_S].clone().detach(), # q_scale_s
)
ctx.fp8_meta = fp8_meta
ctx.cu_seqlens = cu_seqlens
ctx.p_dropout = p_dropout
ctx.max_s = max_s
ctx.set_zero = set_zero
#ctx.is_nl = is_nl
ctx.hidden_size = in_features
ctx.num_attention_heads = num_attention_heads
ctx.rng_state = rng_state
return proj_out
@staticmethod
def backward(
ctx,
grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], ...]:
(
inputmat_t,
qkv_weight,
workspace,
qkv_fwd_weight_scale_inv,
qkv_fwd_inp_scale_inv,
qkv_out,
M,
Z, #S_dmask,
context,
context_t,
proj_weight,
proj_fwd_weight_scale_inv,
proj_fwd_inp_scale_inv,
d_scale_qkv,
d_scale_s,
d_scale_o,
q_scale_s,
) = ctx.saved_tensors
#grad_output, grad_output_c, grad_output_t, grad_bias = grad_output_preprocess(
# ctx, grad_output, ctx.parallel_mode == "row"
#)
fp8_dtype_forward = fp8.get_fp8_te_dtype(ctx.fp8_meta["recipe"],
fprop_tensor=True)
fp8_dtype_backward = fp8.get_fp8_te_dtype(ctx.fp8_meta["recipe"],
fprop_tensor=False)
n_tokens = grad_output.shape[0]
npad = 256 - (n_tokens % 256)
if npad < 256:
grad_output = torch.nn.functional.pad(grad_output, (0, 0, 0, npad))
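# Fused call: computes the projection bias gradient and produces both the FP8
# cast and the FP8 transpose of the incoming gradient.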
proj_bgrad, proj_grad_output_c, proj_grad_output_t = ext.fp8_cast_transpose_bgrad_fused(
grad_output,
ctx.fp8_meta["scaling"],
tex.FP8Tensors.GRAD_OUTPUT2,
fp8_dtype_backward,
)
# PROJ DGRAD
proj_dgrad = torch.empty(
grad_output.shape[0],
ctx.hidden_size,
dtype=torch.int8,
device="cuda",
)
# print ('PROJ_DGRAD')
ext.fp8_gemm(
proj_weight.transposed,
tex.FP8Tensors.GEMM2_WEIGHT,
fp8_dtype_forward,
proj_grad_output_c,
tex.FP8Tensors.GRAD_OUTPUT2,
fp8_dtype_backward,
ctx.fp8_meta["scaling"],
torch.int8, #float16,
workspace,
bias=proj_bgrad,
use_bias=False,
out=proj_dgrad,
out_index=tex.FP8Tensors.GRAD_INPUT1,
use_split_accumulator=_2X_ACC_DGRAD,
A_scale_inv_override=proj_fwd_weight_scale_inv,
)
# proj_dgrad = ext.cast_to_fp8(
# proj_dgrad,
# ctx.fp8_meta["scaling"],
# tex.FP8Tensors.GRAD_INPUT1,
# fp8_dtype_backward)
# PROJ WGRAD
proj_wgrad = ext.fp8_gemm(
context_t,
tex.FP8Tensors.GEMM2_INPUT,
fp8_dtype_forward,
proj_grad_output_t,
tex.FP8Tensors.GRAD_OUTPUT2,
fp8_dtype_backward,
ctx.fp8_meta["scaling"],
torch.float16,
workspace,
use_split_accumulator=_2X_ACC_WGRAD,
A_scale_inv_override=proj_fwd_inp_scale_inv,
)
####################################################################################
##################FP8_FMHA change begins for BPROP #################################
#### [FP8_FMHA] The FP16 path below (cast dO to FP16, then run the FP16 FMHA dgrad) is replaced by a direct FP8 FMHA dgrad call.
#proj_dgrad = ext.cast_from_fp8(
# proj_dgrad,
# ctx.fp8_meta["scaling"],
# tex.FP8Tensors.GRAD_INPUT1,
# fp8_dtype_backward,
# ext.TE_DType[torch.float16]
#)
#proj_dgrad = proj_dgrad[:n_tokens,:]
#proj_dgrad = proj_dgrad.view(-1, ctx.num_attention_heads, ctx.hidden_size//ctx.num_attention_heads)
#if ctx.is_nl:
# dqkv, dp, dkv = fmha.bwd_nl(proj_dgrad, qkv_out, S_dmask, ctx.cu_seqlens, ctx.p_dropout, ctx.max_s, ctx.set_zero)
#else:
# dqkv, dp = fmha.bwd(proj_dgrad, qkv_out, S_dmask, ctx.cu_seqlens, ctx.p_dropout, ctx.max_s, ctx.set_zero)
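# Restore the RNG state saved in forward so the FP8 FMHA dgrad sees the same
# random state (presumably regenerating the identical dropout mask), then put
# the previous state back afterwards.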
rng_state_old = torch.get_rng_state()
torch.set_rng_state(ctx.rng_state)
dqkv, = fp8_mha_api.bwd(
proj_dgrad.view_as(context),
qkv_out,
context,
M,
Z,
ctx.cu_seqlens,
d_scale_qkv,
d_scale_s,
d_scale_o,
ctx.fp8_meta['scaling'].scale_inv[META_DO], # d_scale_do
ctx.fp8_meta['scaling'].scale_inv[META_DP], # d_scale_dp
q_scale_s,
ctx.fp8_meta['scaling'].scale[META_DP], # q_scale_dp
ctx.fp8_meta['scaling'].scale[META_DQKV], # q_scale_dqkv
ctx.fp8_meta['scaling'].amax_history[0][META_DP], # amax_dp
ctx.fp8_meta['scaling'].amax_history[0][META_DQKV], # amax_dqkv
ctx.p_dropout,
ctx.max_s,
ctx.set_zero,
None)
torch.set_rng_state(rng_state_old)
dqkv = dqkv.view(-1, 3 * ctx.hidden_size)
if npad < 256:
dqkv = torch.nn.functional.pad(dqkv, (0, 0, 0, npad))
####################################################################################
qkv_bgrad, dqkv_grad_output_c, dqkv_grad_output_t = ext.fp8_cast_transpose_bgrad_fused(
dqkv,
ctx.fp8_meta["scaling"],
tex.FP8Tensors.GRAD_OUTPUT1,
fp8_dtype_backward,
)
# QKV DGRAD
qkv_dgrad = ext.fp8_gemm(
qkv_weight.transposed,
tex.FP8Tensors.GEMM1_WEIGHT,
fp8_dtype_forward,
dqkv_grad_output_c,
tex.FP8Tensors.GRAD_OUTPUT1,
fp8_dtype_backward,
ctx.fp8_meta["scaling"],
torch.float16,
workspace,
use_split_accumulator=_2X_ACC_DGRAD,
A_scale_inv_override=qkv_fwd_weight_scale_inv,
)
# QKV WGRAD
qkv_wgrad = ext.fp8_gemm(
inputmat_t,
tex.FP8Tensors.GEMM1_INPUT,
fp8_dtype_forward,
dqkv_grad_output_t,
tex.FP8Tensors.GRAD_OUTPUT1,
fp8_dtype_backward,
ctx.fp8_meta["scaling"],
torch.float16,
workspace,
use_split_accumulator=_2X_ACC_WGRAD,
A_scale_inv_override=qkv_fwd_inp_scale_inv,
)
qkv_dgrad = qkv_dgrad[:n_tokens, :]
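# Trigger transformer_engine's FP8 scale/amax bookkeeping for the backward pass
# once the backward GEMMs are done.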
fp8.fp8_updates(
ctx.fp8_meta,
reduce_amax_across_tp_group=False,
tp_group=None,
fwd_bwd_update=False,
fwd_only_update=False,
)
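# Gradients are returned in the same order as forward()'s inputs; the remaining
# non-parameter inputs (cu_seqlens, num_attention_heads, p_dropout, max_s,
# set_zero, fp8_meta, workspace, is_training) each get None.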
return (qkv_dgrad, qkv_wgrad, qkv_bgrad, proj_wgrad, proj_bgrad, None,
None, None, None, None, None, None, None)
#grad_output_c, grad_output_t = fp8_cast_transpose_fused(
# grad_output,
# ctx.fp8_meta["scaling"],
# tex.FP8Tensors.GRAD_OUTPUT1,
# fp8_dtype_backward,
#)
class FP8_MHA(TransformerEngineBaseModule):
def __init__(self, config, params_dtype: torch.dtype = torch.float32):
super().__init__()
self.p_dropout = config.attention_probs_dropout_prob
self.h = config.num_attention_heads
self.hidden_size = config.hidden_size
self.d = self.hidden_size // self.h
self.set_zero = config.packed_samples # TODO read this from config
assert self.d * self.h == self.hidden_size, "Invalid hidden size/num_heads"
self.qkv_weight = Parameter(
torch.empty(
self.hidden_size * 3,
self.hidden_size,
device=torch.cuda.current_device(),
dtype=params_dtype,
))
self.qkv_bias = Parameter(
torch.empty(
self.hidden_size * 3,
device=torch.cuda.current_device(),
dtype=params_dtype,
))
self.proj_weight = Parameter(
torch.empty(
self.hidden_size,
self.hidden_size,
device=torch.cuda.current_device(),
dtype=params_dtype,
))
self.proj_bias = Parameter(
torch.empty(
self.hidden_size,
device=torch.cuda.current_device(),
dtype=params_dtype,
))
with torch.no_grad():
self.qkv_bias.zero_()
self.qkv_weight.fill_(1.0)
self.proj_bias.zero_()
self.proj_weight.fill_(1.0)
# workspace for cublasLt
self.workspace = torch.empty(_CUBLASLT_WORKSPACE_SIZE_BYTES,
dtype=torch.int8,
device="cuda")
self.max_adjusted = False
def fp8_init(self, num_gemms: int = 1) -> None:
"""Initialize fp8 related metadata and tensors during fprop"""
super().fp8_init(num_gemms)
if self.max_adjusted:
return
self.fp8_meta['fp8_max'][META_DP] = 57344.0  # max representable E5M2 value (see the META_DP TODO above)
self.max_adjusted = True
def forward(self, inp: torch.Tensor, cu_seqlens, max_s) -> torch.Tensor:
self.pre_forward(inp, num_gemms=3)
out = _MHA.apply(inp, self.qkv_weight, self.qkv_bias, self.proj_weight,
self.proj_bias, cu_seqlens, self.h, self.p_dropout,
max_s, self.set_zero, self.fp8_meta, self.workspace,
self.training)
if torch.is_grad_enabled() and self.training:
fp8.fp8_updates(
self.fp8_meta,
reduce_amax_across_tp_group=False,
tp_group=None,
fwd_bwd_update=False,
fwd_only_update=True,
)
# out = out.view(-1, self.hidden_size)
return out #, self.fp8_meta["scaling"].amax_history
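# The commented-out block below is a standalone smoke test for FP8_MHA, kept
# for reference only. Running it would additionally require importing
# fp8_autocast from transformer_engine.pytorch and the recipe module from
# transformer_engine.common, which this file does not import.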
#fp8_recipe = recipe.DelayedScaling(
# margin=0,
# interval=1,
# fp8_format=recipe.Format.E4M3,
# amax_history_len=1,
# amax_compute_algo="most_recent",
#)
#
#bs = 1
#seq_len = 333
#a = torch.empty(bs*seq_len,1024,dtype=torch.half).cuda()
#a.fill_(0.1)
#seqlen = torch.empty(bs,dtype=torch.int32).cuda()
#seqlen.fill_(seq_len)
##A_index = tex.FP8Tensors.GEMM1_INPUT
##b = torch.ones(20,10,dtype=torch.half).cuda()
##B_index = tex.FP8Tensors.GEMM1_WEIGHT
#class Config():
# def __init__(self):
# self.hidden_size = 1024
# self.attention_probs_dropout_prob = 0.1
# self.num_attention_heads = 16
# self.d = self.hidden_size // self.num_attention_heads
# self.packed_samples = False # TODO read this from config
#mha = FP8_MHA(Config()).half()
#
#with fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
# cu_seqlens = torch.zeros(bs+1, device=a.device, dtype=torch.int32)
# cu_seqlens[1:] = torch.cumsum(seqlen, dim=0)
# op = mha(a, cu_seqlens, seq_len)
# op_grad = torch.ones(bs*seq_len, 1024, dtype=torch.float16).cuda()
# op.backward(op_grad)
# print (mha.qkv_weight.grad)
#print ('op {}:{} {} '.format(op.shape, op.dtype, op))