[None][fix] impl fused triton kernel for e8m0 resmooth to reduce memory footprint (#10327)

Signed-off-by: Nekofish-L <liuxiangyang@mail.ustc.edu.cn>
Co-authored-by: Kanghwan <861393+karljang@users.noreply.github.com>
This commit is contained in:
Necofish 2026-01-16 14:13:18 +08:00 committed by GitHub
parent f001c4946d
commit 03cdf5804f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -51,45 +51,85 @@ def per_token_cast_to_fp8_e8m0(
g, m, n), sf
def per_block_cast_to_fp8_e8m0(
x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
if x.dim() == 2:
m, n = x.shape
x_padded = torch.zeros((align(m, 128), align(n, 128)),
dtype=x.dtype,
device=x.device)
x_padded[:m, :n] = x
x_view = x_padded.view(-1, 128, x_padded.size(1) // 128, 128)
x_amax = x_view.abs().float().amax(dim=(1, 3), keepdim=True).clamp(1e-4)
sf = ceil_to_ue8m0(x_amax / 448.0)
x_scaled = (x_view * (1.0 / sf)).to(torch.float8_e4m3fn)
return x_scaled.view_as(x_padded)[:m, :n].contiguous(), sf.view(
x_view.size(0), x_view.size(2))
else:
g, m, n = x.shape
x_padded = torch.zeros((g, align(m, 128), align(n, 128)),
dtype=x.dtype,
device=x.device)
x_padded[:, :m, :n] = x
x_view = x_padded.view(g, -1, 128, x_padded.size(-1) // 128, 128)
x_amax = x_view.abs().float().amax(dim=(2, 4), keepdim=True).clamp(1e-4)
sf = ceil_to_ue8m0(x_amax / 448.0)
x_scaled = (x_view * (1.0 / sf)).to(torch.float8_e4m3fn)
return x_scaled.view_as(x_padded)[:, :m, :n].contiguous(), sf.view(
x_view.size(0), x_view.size(1), x_view.size(3))
@triton.jit
def _resmooth_kernel(
w_ptr,
s_ptr,
M,
K,
stride_wb,
stride_wm,
stride_wk,
stride_sb,
stride_sm,
stride_sk,
BLOCK_M: tl.constexpr,
BLOCK_K: tl.constexpr,
):
batch_idx = tl.program_id(0)
pid_m = tl.program_id(1)
pid_k = tl.program_id(2)
curr_w_ptr = w_ptr + batch_idx * stride_wb
curr_s_ptr = s_ptr + batch_idx * stride_sb
rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
rk = pid_k * BLOCK_K + tl.arange(0, BLOCK_K)
s_offset = pid_m * stride_sm + pid_k * stride_sk
old_scale = tl.load(curr_s_ptr + s_offset)
w_mask = (rm[:, None] < M) & (rk[None, :] < K)
w_offsets = rm[:, None] * stride_wm + rk[None, :] * stride_wk
w_fp8 = tl.load(curr_w_ptr + w_offsets, mask=w_mask, other=0.0)
w_fp32 = w_fp8.to(tl.float32)
w_val = w_fp32 * old_scale
block_amax = tl.maximum(tl.max(tl.abs(w_val)), 1e-4)
# UE8M0 sf = 2 ^ ceil(log2(sf))
new_scale = tl.math.exp2(tl.math.ceil(tl.math.log2(block_amax / 448.0)))
w_requant = w_val * (1.0 / new_scale)
tl.store(curr_w_ptr + w_offsets, w_requant, mask=w_mask)
tl.store(curr_s_ptr + s_offset, new_scale)
def resmooth_to_fp8_e8m0(weight: torch.Tensor,
sf: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
weight = weight.cuda()
sf = sf.cuda()
if weight.dim() == 2:
x = weight.float() * sf.repeat_interleave(128, dim=0).repeat_interleave(
128, dim=1)[:weight.shape[0], :weight.shape[1]]
else:
x = weight.float() * sf.repeat_interleave(128, dim=1).repeat_interleave(
128, dim=2)[:weight.shape[0], :weight.shape[1], :weight.shape[2]]
return per_block_cast_to_fp8_e8m0(x)
def resmooth_to_fp8_e8m0(
weight: torch.Tensor,
weight_scale: torch.Tensor,
block_size: tuple[int, int] = (128, 128),
):
assert weight.dtype == torch.float8_e4m3fn
assert weight_scale.dtype == torch.float32
orig_shape = weight.shape
M, K = orig_shape[-2:]
w_view = weight.view(-1, M, K)
s_view = weight_scale.view(-1, weight_scale.shape[-2],
weight_scale.shape[-1])
num_batches = w_view.shape[0]
BLOCK_M, BLOCK_K = block_size
grid = (num_batches, triton.cdiv(M, BLOCK_M), triton.cdiv(K, BLOCK_K))
_resmooth_kernel[grid](
w_view,
s_view,
M,
K,
w_view.stride(0),
w_view.stride(1),
w_view.stride(2),
s_view.stride(0),
s_view.stride(1),
s_view.stride(2),
BLOCK_M=BLOCK_M,
BLOCK_K=BLOCK_K,
)
# this is an in-place operation, however, we return for simplicity
return weight, weight_scale
def get_m_alignment_for_contiguous_layout():