# SPDX-FileCopyrightText: Copyright (c) 2021-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: NVIDIA TensorRT Source Code License Agreement
#
# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
# property and proprietary rights in and to this material, related
# documentation and any modifications thereto. Any use, reproduction,
# disclosure or distribution of this material and related documentation
# without an express license agreement from NVIDIA CORPORATION or
# its affiliates is strictly prohibited.
import math
from typing import Optional, Tuple

import numpy as np
import torch


def widen_cells():
    """Widen Jupyter notebook cells and relax numpy print truncation for easier inspection."""
    from IPython.core.display import HTML, display
    display(HTML("<style>.container { width:85% !important; }</style>"))
    np.set_printoptions(edgeitems=1000, linewidth=1000000)


def unpad(x: torch.Tensor, cu_seqlens: torch.IntTensor) -> torch.Tensor:
    """Interpret the first two dims as B x S and remove the padding."""
    shape = list(x.shape)
    b = len(cu_seqlens) - 1
    assert b == shape[0]
    total = cu_seqlens[-1].item()
    new_shape = [total] + shape[2:]
    y = torch.empty(new_shape, dtype=x.dtype, device=x.device)
    for bi in range(b):
        start = cu_seqlens[bi]
        end = cu_seqlens[bi + 1]
        si = end - start
        y[start:end, :] = x[bi, :si, :]
    return y
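
# Illustration (assumed values): with cu_seqlens = [0, 2, 5] and x of shape
# [2, 4, d] (a batch of 2 sequences padded to length 4), unpad returns a
# [5, d] tensor holding the 2 valid rows of sequence 0 followed by the 3
# valid rows of sequence 1.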


def pad(x: torch.Tensor, cu_seqlens: torch.Tensor, s: int) -> torch.Tensor:
    """Inverse of unpad: scatter the packed rows back into a zero-padded B x S layout."""
    shape = list(x.shape)
    b = len(cu_seqlens) - 1
    total = cu_seqlens[-1].item()
    assert total == shape[0]
    new_shape = [b, s] + shape[1:]
    y = torch.zeros(new_shape, dtype=x.dtype, device=x.device)
    for bi in range(b):
        start = cu_seqlens[bi]
        end = cu_seqlens[bi + 1]
        si = end - start
        y[bi, :si, :] = x[start:end, :]
    return y
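
# Illustration (assumed values): pad(unpad(x, cu_seqlens), cu_seqlens, s)
# reproduces x at the valid positions and zero-fills the padded tail of each
# sequence, so pad is the inverse of unpad up to the padding values.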


def mask_softmax(S: torch.Tensor, cu_seqlens: torch.Tensor) -> torch.Tensor:
    """Keep only the valid si x si block of each batch entry's scores; zero the rest."""
    Snew = torch.zeros_like(S)
    shape = S.shape
    b = len(cu_seqlens) - 1
    assert b == shape[0]
    for bi in range(b):
        start = cu_seqlens[bi]
        end = cu_seqlens[bi + 1]
        si = end - start
        Snew[bi, :, :si, :si] = S[bi, :, :si, :si]
    return Snew
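
# Illustration (assumed shapes): S is [b, h, s, s]; for a batch entry with
# si = cu_seqlens[bi + 1] - cu_seqlens[bi] valid tokens, only the top-left
# [si x si] block of its score matrix is kept and everything else is zeroed.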


def reshape_softmax(S: torch.Tensor,
                    b: int,
                    s: int,
                    h: int,
                    d: int,
                    warps_m: int,
                    warps_n: int,
                    extract_dmask: bool = True
                    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
    m = 16
    if warps_m == 1:
        m = 16
    if warps_n == 1 and s == 1024:
        m = 64
    if warps_n == 1 and s == 2048:
        m = 64
    if warps_n == 1 and s == 4096:
        m = 128
    n = s

    m_per_cta, n_per_cta = warps_m * 16, warps_n * 16
    mmas_m, mmas_n = m // m_per_cta, n // n_per_cta
    loops = s // (mmas_m * m_per_cta)
    assert (loops == 8 and s == 128) or \
           (loops == 16 and s == 256) or \
           (loops == 32 and s == 512) or \
           (loops == 24 and s == 384) or \
           (loops == 64 and s == 1024) or \
           (loops == 128 and s == 2048) or \
           (loops == 256 and s == 4096) or \
           (loops == 32 and s == 4096) or \
           (loops == 16 and s == 1024) or \
           (loops == 32 and s == 2048)

    quads, lohi, lr, vals = 8, 2, 2, 2
    # MMA register layout (dims 0..11):
    #   B x H x LOOPS x MMAS_M x MMAS_N x WARPS_N x WARPS_M x QUADS x 4 x LOHI x LR x 2
    # permuted with (0, 1, 2, 3, 6, 9, 7, 4, 5, 10, 8, 11) to
    #   B x H x LOOPS x MMAS_M x WARPS_M x LOHI x QUADS x MMAS_N x WARPS_N x LR x 4 x 2
    # and finally reshaped to the expected B x H x S x S format.
    S = S.reshape((b, h, loops, mmas_m, mmas_n, warps_n, warps_m, quads, 4, lohi, lr, vals)) \
         .permute(0, 1, 2, 3, 6, 9, 7, 4, 5, 10, 8, 11) \
         .reshape((b, h, s, s))

    if not extract_dmask:
        return S, None

    # torch.signbit does not handle -0 here, so compute the mask in numpy.
    # Positive (sign bit clear) means the element was kept.
    SS = S.to(torch.float16)
    dmask = torch.tensor(np.logical_not(np.signbit(SS.cpu().numpy())),
                         dtype=SS.dtype,
                         device=SS.device)
    dmask = dmask.to(S.dtype)
    # S = S.abs()
    return S, dmask
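
# Note on the dmask extraction above (a sketch of the convention it relies on,
# inferred from the sign-bit check and the commented-out S.abs()): dropped
# elements are assumed to carry a set sign bit, so a stored -0.0 still marks a
# dropped position; np.signbit distinguishes -0.0 from +0.0, which a plain
# comparison such as (SS >= 0) would not.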


def reshape_softmax_new(S: torch.Tensor, b: int, s: int, h: int, d: int,
                        warps_m: int,
                        warps_n: int) -> torch.Tensor:
    m = s if s == 128 else 16
    n = s
    m_per_cta, n_per_cta = warps_m * 16, warps_n * 16
    mmas_m, mmas_n = m // m_per_cta, n // n_per_cta
    loops = s // (mmas_m * m_per_cta)
    assert (loops == 1 and s == 128) or \
           (loops == 16 and s == 256) or \
           (loops == 32 and s == 512) or \
           (loops == 24 and s == 384) or \
           (loops == 64 and s == 1024) or \
           (loops == 128 and s == 2048) or \
           (loops == 256 and s == 4096)

    quads, lohi, lr, vals = 8, 2, 2, 2
    # Inverse direction of reshape_softmax: start from
    #   B x H x LOOPS x MMAS_M x WARPS_M x LOHI x QUADS x MMAS_N x WARPS_N x LR x 4 x 2
    # permute with (0, 1, 2, 3, 7, 8, 4, 6, 10, 5, 9, 11) back to the MMA register layout
    #   B x H x LOOPS x MMAS_M x MMAS_N x WARPS_N x WARPS_M x QUADS x 4 x LOHI x LR x 2
    # and return it flattened.
    S = S.reshape((b, h, loops, mmas_m, warps_m, lohi, quads, mmas_n, warps_n, lr, 4, vals)) \
         .permute(0, 1, 2, 3, 7, 8, 4, 6, 10, 5, 9, 11) \
         .reshape(-1)
    return S


def mha_ref(qkv, amask, D, b, s, h, d, p_dropout, is_causal, alibi_bias):
    qkv = qkv.view(b, s, h, 3, d)
    # [b, h, s, d]
    q = qkv[:, :, :, 0, :].permute(0, 2, 1, 3)
    k = qkv[:, :, :, 1, :].permute(0, 2, 1, 3)
    v = qkv[:, :, :, 2, :].permute(0, 2, 1, 3)
    # This way the multiplication is also done in FP32, and we avoid one downcast to FP16.
    p = torch.matmul(q.float(), k.permute(0, 1, 3, 2).float())
    # [b, h, s, s]
    p_masked = p / math.sqrt(d) + (1.0 - amask) * -10000.0
    if is_causal:
        causal_mask = torch.triu(
            torch.ones(s, s, dtype=torch.bool, device=qkv.device), 1)
        p_masked.masked_fill_(causal_mask, float('-inf'))
    if alibi_bias is not None:
        # [b, h, s, s] + [b, h, 1, s]
        p_masked = p_masked + alibi_bias
    mm = torch.max(p_masked, -1)[0].to(torch.float)  # [b, h, s]
    mm_kd = torch.max(p_masked, -1, True)  # [b, h, s, 1]
    exp = torch.exp(p_masked - mm_kd[0])
    ll = torch.sum(exp, -1).to(torch.float)
    softmax_lse = mm + torch.log(ll)  # [b, h, s]
    sm = torch.softmax(p_masked, -1)
    sm = sm.to(qkv.dtype)
    rp_keep = 1.0 / (1.0 - p_dropout)
    # d = sm * rp_keep
    d = sm * D.to(sm.dtype) * rp_keep  # TODO
    ctx = torch.matmul(d, v)
    ctx = ctx.permute(0, 2, 1, 3).contiguous()
    ctx.retain_grad()
    p.retain_grad()
    sm.retain_grad()
    return ctx, p, sm, softmax_lse
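
# Illustration (assumed shapes): mha_ref expects qkv view-able as
# [b, s, h, 3, d], amask broadcastable to [b, h, s, s] with 1 for valid and 0
# for padded positions, and D as a [b, h, s, s] dropout keep-mask. It returns
# the context of shape [b, s, h, d] together with the raw scores p, the
# softmax probabilities sm and softmax_lse of shape [b, h, s].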


def perr(a, b):
    """Relative L1 error: sum(|a - b|) / sum(|a|), computed in float32."""
    a, b = a.float(), b.float()
    diff = a - b
    return (diff.abs().sum() / a.abs().sum()).item()


def mae(a, b):
    """Mean absolute error, computed in float32."""
    a, b = a.float(), b.float()
    diff = a - b
    return diff.abs().mean().item()
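
# Illustration (assumed values): for a = [1., 1.] and b = [1., 0.5],
# perr(a, b) = 0.5 / 2 = 0.25 (relative L1 error) and mae(a, b) = 0.25
# (mean absolute error).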


def build_alibi_tensor(max_seq_len, num_attention_heads, batch_size):
    """Returns a tensor shaped (batch_size, num_attention_heads, 1, max_seq_len)."""
    # Based on https://github.com/ofirpress/attention_with_linear_biases/blob/a35aaca144e0eb6b789dfcb46784c4b8e31b7983/fairseq/models/transformer.py#L742

    def get_slopes(n):

        def get_slopes_power_of_2(n):
            # 2 ** (-8 / n)
            start = (2**(-2**-(math.log2(n) - 3)))
            ratio = start
            return [start * ratio**i for i in range(n)]

        if math.log2(n).is_integer():
            return get_slopes_power_of_2(n)
        else:
            closest_power_of_2 = 2**math.floor(math.log2(n))
            return get_slopes_power_of_2(closest_power_of_2) + get_slopes(
                2 * closest_power_of_2)[0::2][:n - closest_power_of_2]

    # [h]
    slopes = torch.Tensor(get_slopes(num_attention_heads))
    # [h, 1, 1] * [1, 1, s] --> [h, 1, s]
    alibi = slopes.unsqueeze(1).unsqueeze(
        1) * torch.arange(max_seq_len).unsqueeze(0).unsqueeze(0).expand(
            num_attention_heads, -1, -1)
    # Repeat across the batch dimension: [b, h, 1, s]
    alibi = alibi.repeat(batch_size, 1, 1, 1)
    return alibi
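

# A minimal smoke-test sketch (illustrative only; the values below are
# assumptions, not taken from the kernels): exercises the pad/unpad round
# trip, the error metrics and the ALiBi tensor shape on CPU.
if __name__ == "__main__":
    torch.manual_seed(0)
    b_, s_, d_ = 2, 4, 8
    cu = torch.tensor([0, 2, 5], dtype=torch.int32)
    x_padded = torch.randn(b_, s_, d_)
    packed = unpad(x_padded, cu)    # [5, d_]
    assert packed.shape == (5, d_)
    repadded = pad(packed, cu, s_)  # [b_, s_, d_], zero-filled padding
    assert repadded.shape == (b_, s_, d_)
    # Only the valid (unpadded) positions are expected to round-trip exactly.
    assert perr(repadded[0, :2], x_padded[0, :2]) == 0.0
    assert mae(repadded[1, :3], x_padded[1, :3]) == 0.0
    alibi = build_alibi_tensor(max_seq_len=s_, num_attention_heads=4,
                               batch_size=b_)
    assert alibi.shape == (b_, 4, 1, s_)
    print("my_utils smoke test passed")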