mirror of
https://github.com/NVIDIA/TensorRT-LLM.git
synced 2026-02-04 02:02:01 +08:00
[None][fix] impl fused triton kernel for e8m0 resmooth to reduce memory footprint (#10327)
Signed-off-by: Nekofish-L <liuxiangyang@mail.ustc.edu.cn> Co-authored-by: Kanghwan <861393+karljang@users.noreply.github.com>
This commit is contained in:
parent
f001c4946d
commit
03cdf5804f
@ -51,45 +51,85 @@ def per_token_cast_to_fp8_e8m0(
|
||||
g, m, n), sf
|
||||
|
||||
|
||||
def per_block_cast_to_fp8_e8m0(
        x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    """Quantize ``x`` to ``float8_e4m3fn`` with one scale per 128x128 tile.

    The last two dimensions of ``x`` are zero-padded up to the sizes returned
    by ``align(..., 128)`` and split into 128x128 tiles. For every tile the
    absolute maximum (clamped to at least 1e-4 so the scale is never zero)
    is divided by 448.0 and passed through ``ceil_to_ue8m0`` to obtain the
    tile scale; the tile is then multiplied by the reciprocal of that scale
    and cast to ``torch.float8_e4m3fn``.

    NOTE(review): ``align`` and ``ceil_to_ue8m0`` are defined outside this
    block; presumably they round up to a multiple of 128 and up to a power
    of two respectively -- confirm against their definitions.

    Args:
        x: A 2-D ``(m, n)`` tensor, or otherwise a 3-D ``(g, m, n)`` tensor.

    Returns:
        A tuple ``(x_fp8, sf)``. ``x_fp8`` is contiguous, has the shape of
        ``x`` and dtype ``torch.float8_e4m3fn``. ``sf`` holds the per-tile
        scales with shape ``(m_tiles, n_tiles)`` for 2-D input and
        ``(g, m_tiles, n_tiles)`` for 3-D input.
    """
    # Only the presence of a leading group dimension differs between the
    # 2-D and 3-D cases; everything below indexes from the trailing axes.
    if x.dim() == 2:
        rows, cols = x.shape
        lead = ()
    else:
        num_groups, rows, cols = x.shape
        lead = (num_groups, )

    padded = torch.zeros((*lead, align(rows, 128), align(cols, 128)),
                         dtype=x.dtype,
                         device=x.device)
    padded[..., :rows, :cols] = x

    # Shape: (*lead, m_tiles, 128, n_tiles, 128).
    tiles = padded.view(*lead, -1, 128, padded.size(-1) // 128, 128)
    # Reduce over the two intra-tile axes (the 128-sized ones).
    tile_amax = tiles.abs().float().amax(dim=(-3, -1),
                                         keepdim=True).clamp(1e-4)
    scale = ceil_to_ue8m0(tile_amax / 448.0)
    quantized = (tiles * (1.0 / scale)).to(torch.float8_e4m3fn)

    # Drop the padding again before handing the data back.
    cropped = quantized.view_as(padded)[..., :rows, :cols].contiguous()
    return cropped, scale.view(*lead, tiles.size(-4), tiles.size(-2))
|
||||
@triton.jit
def _resmooth_kernel(
    w_ptr,
    s_ptr,
    M,
    K,
    stride_wb,
    stride_wm,
    stride_wk,
    stride_sb,
    stride_sm,
    stride_sk,
    BLOCK_M: tl.constexpr,
    BLOCK_K: tl.constexpr,
):
    """Re-quantize one (BLOCK_M x BLOCK_K) weight tile in place.

    Each program handles the tile selected by
    ``(program_id(0), program_id(1), program_id(2))`` = (batch, tile row,
    tile column). It loads the tile and its single scale, rebuilds the
    scaled values in float32, derives a new scale
    ``2 ** ceil(log2(amax / 448))`` and writes both the rescaled tile and
    the new scale back to the same locations they were read from.

    Args:
        w_ptr: Pointer to the weight data (read and overwritten).
        s_ptr: Pointer to the per-tile scales (read and overwritten).
        M, K: Extent of the two trailing weight dimensions; used to mask
            out-of-range elements of edge tiles.
        stride_wb, stride_wm, stride_wk: Element strides of the weight.
        stride_sb, stride_sm, stride_sk: Element strides of the scales.
        BLOCK_M, BLOCK_K: Compile-time tile shape.
    """
    # Promote the batch index to 64 bits: batch_idx * stride_wb is an element
    # offset into the whole stacked weight and can exceed the int32 range for
    # large tensors, which would silently wrap and address the wrong memory.
    batch_idx = tl.program_id(0).to(tl.int64)
    pid_m = tl.program_id(1)
    pid_k = tl.program_id(2)

    curr_w_ptr = w_ptr + batch_idx * stride_wb
    curr_s_ptr = s_ptr + batch_idx * stride_sb

    rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    rk = pid_k * BLOCK_K + tl.arange(0, BLOCK_K)

    # One scalar scale per tile.
    s_offset = pid_m * stride_sm + pid_k * stride_sk
    old_scale = tl.load(curr_s_ptr + s_offset)

    # Edge tiles may extend past (M, K); masked lanes load as zero and are
    # never stored, so they do not affect the tile maximum.
    w_mask = (rm[:, None] < M) & (rk[None, :] < K)
    w_offsets = rm[:, None] * stride_wm + rk[None, :] * stride_wk
    w_fp8 = tl.load(curr_w_ptr + w_offsets, mask=w_mask, other=0.0)
    w_fp32 = w_fp8.to(tl.float32)

    w_val = w_fp32 * old_scale
    # Floor the maximum at 1e-4 so an all-zero tile still yields a finite,
    # non-zero scale.
    block_amax = tl.maximum(tl.max(tl.abs(w_val)), 1e-4)

    # UE8M0 sf = 2 ^ ceil(log2(sf))
    new_scale = tl.math.exp2(tl.math.ceil(tl.math.log2(block_amax / 448.0)))
    w_requant = w_val * (1.0 / new_scale)

    # tl.store converts the float32 values to the element type of w_ptr.
    tl.store(curr_w_ptr + w_offsets, w_requant, mask=w_mask)
    tl.store(curr_s_ptr + s_offset, new_scale)
|
||||
|
||||
|
||||
def resmooth_to_fp8_e8m0(weight: torch.Tensor,
                         sf: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    """Rebuild scaled float32 weights and re-quantize them per block.

    Both tensors are moved to the CUDA device, each scale in ``sf`` is
    repeated 128 times along the two block axes, cropped to the shape of
    ``weight`` and multiplied with ``weight`` in float32. The product is
    handed to ``per_block_cast_to_fp8_e8m0``.

    NOTE(review): another definition with the same name appears later in
    this listing and would rebind the name -- confirm which one is live.

    Args:
        weight: Quantized weight; 2-D, or otherwise treated as 3-D with the
            block axes at positions 1 and 2.
        sf: Per-block scales matching the block layout of ``weight``.

    Returns:
        Whatever ``per_block_cast_to_fp8_e8m0`` returns for the rebuilt
        tensor (annotated as a pair of tensors).
    """
    weight = weight.cuda()
    sf = sf.cuda()

    # The block axes are the last two for 2-D input and axes (1, 2) otherwise;
    # the crop covers the first two or three dimensions accordingly.
    if weight.dim() == 2:
        row_axis, col_axis, cropped_dims = 0, 1, 2
    else:
        row_axis, col_axis, cropped_dims = 1, 2, 3

    weight_fp32 = weight.float()
    per_element_sf = sf.repeat_interleave(
        128, dim=row_axis).repeat_interleave(128, dim=col_axis)
    # The expanded scales may be larger than the weight when its sizes are
    # not multiples of 128, so trim them to the weight's extent.
    crop = tuple(
        slice(None, weight.shape[axis]) for axis in range(cropped_dims))
    x = weight_fp32 * per_element_sf[crop]
    return per_block_cast_to_fp8_e8m0(x)
|
||||
def resmooth_to_fp8_e8m0(
|
||||
weight: torch.Tensor,
|
||||
weight_scale: torch.Tensor,
|
||||
block_size: tuple[int, int] = (128, 128),
|
||||
):
|
||||
assert weight.dtype == torch.float8_e4m3fn
|
||||
assert weight_scale.dtype == torch.float32
|
||||
|
||||
orig_shape = weight.shape
|
||||
M, K = orig_shape[-2:]
|
||||
w_view = weight.view(-1, M, K)
|
||||
s_view = weight_scale.view(-1, weight_scale.shape[-2],
|
||||
weight_scale.shape[-1])
|
||||
|
||||
num_batches = w_view.shape[0]
|
||||
BLOCK_M, BLOCK_K = block_size
|
||||
|
||||
grid = (num_batches, triton.cdiv(M, BLOCK_M), triton.cdiv(K, BLOCK_K))
|
||||
|
||||
_resmooth_kernel[grid](
|
||||
w_view,
|
||||
s_view,
|
||||
M,
|
||||
K,
|
||||
w_view.stride(0),
|
||||
w_view.stride(1),
|
||||
w_view.stride(2),
|
||||
s_view.stride(0),
|
||||
s_view.stride(1),
|
||||
s_view.stride(2),
|
||||
BLOCK_M=BLOCK_M,
|
||||
BLOCK_K=BLOCK_K,
|
||||
)
|
||||
# this is an in-place operation, however, we return for simplicity
|
||||
return weight, weight_scale
|
||||
|
||||
|
||||
def get_m_alignment_for_contiguous_layout():
|
||||
|
||||
Loading…
Reference in New Issue
Block a user