TensorRT-LLM/tensorrt_llm/quantization/utils/fp8_utils.py
Necofish 03cdf5804f
[None][fix] impl fused triton kernel for e8m0 resmooth to reduce memory footprint (#10327)
Signed-off-by: Nekofish-L <liuxiangyang@mail.ustc.edu.cn>
Co-authored-by: Kanghwan <861393+karljang@users.noreply.github.com>
2026-01-15 22:13:18 -08:00


from typing import Optional, Tuple
import torch
import triton
import triton.language as tl
from tensorrt_llm._utils import nvtx_range
def ceil_div(x: int, y: int) -> int:
"""
Perform ceiling division of two integers.
Args:
x: the dividend.
y: the divisor.
Returns:
The result of the ceiling division.
"""
return (x + y - 1) // y
def align(x: int, y: int) -> int:
    """Round `x` up to the nearest multiple of `y`."""
    return ceil_div(x, y) * y
def ceil_to_ue8m0(x: torch.Tensor):
    """Round the magnitude of `x` up to the nearest power of two (UE8M0)."""
    return torch.pow(2.0, torch.ceil(torch.log2(x.abs())))
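# Worked example (illustrative): ceil_to_ue8m0 rounds a scale up to the next
# power of two, which is exactly representable as a UE8M0 exponent.
#   ceil_to_ue8m0(torch.tensor(0.3))  # log2(0.3) ~ -1.74 -> ceil -> -1 -> 2**-1 = 0.5
#   ceil_to_ue8m0(torch.tensor(3.0))  # log2(3.0) ~  1.58 -> ceil ->  2 -> 2**2  = 4.0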
@nvtx_range("[DG] quantization")
@torch.compile(dynamic=True)
def per_token_cast_to_fp8_e8m0(
        x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    """Quantize to FP8 (E4M3) in 1x128 groups with UE8M0 (power-of-two) scales."""
if x.dim() == 2:
assert x.size(1) % 128 == 0
m, n = x.shape
x_view = x.view(m, -1, 128)
x_amax = x_view.abs().float().amax(dim=2).view(m, -1).clamp(1e-4)
sf = ceil_to_ue8m0(x_amax / 448.0)
return (x_view * (1.0 / sf.unsqueeze(2))).to(torch.float8_e4m3fn).view(
m, n), sf
else:
assert x.size(2) % 128 == 0
g, m, n = x.shape
x_view = x.view(g, m, -1, 128)
x_amax = x_view.abs().float().amax(dim=3).view(g, m, -1).clamp(1e-4)
sf = ceil_to_ue8m0(x_amax / 448.0)
return (x_view * (1.0 / sf.unsqueeze(3))).to(torch.float8_e4m3fn).view(
g, m, n), sf
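# Usage sketch (illustrative, not part of the module API): quantizing a [m, k]
# activation whose last dimension is a multiple of 128. Tensor names are
# hypothetical.
#
#   x = torch.randn(4, 256, device="cuda")
#   x_fp8, sf = per_token_cast_to_fp8_e8m0(x)
#   # x_fp8: [4, 256] float8_e4m3fn; sf: [4, 2] float32, every entry a power of two
#   x_deq = (x_fp8.view(4, -1, 128).float() * sf.unsqueeze(2)).view(4, 256)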
@triton.jit
def _resmooth_kernel(
w_ptr,
s_ptr,
M,
K,
stride_wb,
stride_wm,
stride_wk,
stride_sb,
stride_sm,
stride_sk,
BLOCK_M: tl.constexpr,
BLOCK_K: tl.constexpr,
):
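    """Requantize FP8 weights in place from FP32 per-block scales to UE8M0 scales.

    Each program dequantizes one (BLOCK_M, BLOCK_K) tile with its old scale,
    recomputes the tile amax, rounds the new scale up to a power of two, and
    stores the requantized tile and the new scale back in place.
    """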
batch_idx = tl.program_id(0)
pid_m = tl.program_id(1)
pid_k = tl.program_id(2)
curr_w_ptr = w_ptr + batch_idx * stride_wb
curr_s_ptr = s_ptr + batch_idx * stride_sb
rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
rk = pid_k * BLOCK_K + tl.arange(0, BLOCK_K)
s_offset = pid_m * stride_sm + pid_k * stride_sk
old_scale = tl.load(curr_s_ptr + s_offset)
w_mask = (rm[:, None] < M) & (rk[None, :] < K)
w_offsets = rm[:, None] * stride_wm + rk[None, :] * stride_wk
w_fp8 = tl.load(curr_w_ptr + w_offsets, mask=w_mask, other=0.0)
w_fp32 = w_fp8.to(tl.float32)
w_val = w_fp32 * old_scale
block_amax = tl.maximum(tl.max(tl.abs(w_val)), 1e-4)
    # UE8M0 scale: round amax / 448 up to a power of two, sf = 2 ** ceil(log2(amax / 448))
new_scale = tl.math.exp2(tl.math.ceil(tl.math.log2(block_amax / 448.0)))
w_requant = w_val * (1.0 / new_scale)
tl.store(curr_w_ptr + w_offsets, w_requant, mask=w_mask)
tl.store(curr_s_ptr + s_offset, new_scale)
def resmooth_to_fp8_e8m0(
weight: torch.Tensor,
weight_scale: torch.Tensor,
block_size: tuple[int, int] = (128, 128),
):
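    """Resmooth block-quantized FP8 weights in place so their scales are UE8M0 (powers of two)."""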
assert weight.dtype == torch.float8_e4m3fn
assert weight_scale.dtype == torch.float32
orig_shape = weight.shape
M, K = orig_shape[-2:]
w_view = weight.view(-1, M, K)
s_view = weight_scale.view(-1, weight_scale.shape[-2],
weight_scale.shape[-1])
num_batches = w_view.shape[0]
BLOCK_M, BLOCK_K = block_size
grid = (num_batches, triton.cdiv(M, BLOCK_M), triton.cdiv(K, BLOCK_K))
_resmooth_kernel[grid](
w_view,
s_view,
M,
K,
w_view.stride(0),
w_view.stride(1),
w_view.stride(2),
s_view.stride(0),
s_view.stride(1),
s_view.stride(2),
BLOCK_M=BLOCK_M,
BLOCK_K=BLOCK_K,
)
    # The kernel updates `weight` and `weight_scale` in place; they are returned for convenience.
return weight, weight_scale
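# Usage sketch (illustrative): resmoothing a 128x128-block-quantized FP8 weight
# so its scales become powers of two. Shapes below are hypothetical.
#
#   w = torch.randn(512, 256, device="cuda").to(torch.float8_e4m3fn)
#   s = torch.rand(4, 2, device="cuda") + 0.5  # [512/128, 256/128] FP32 block scales
#   w, s = resmooth_to_fp8_e8m0(w, s)          # in place; s now holds powers of two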
def get_m_alignment_for_contiguous_layout():
return 128
def get_tma_aligned_size(x: int, element_size: int) -> int:
tma_alignment_bytes = 16
assert tma_alignment_bytes % element_size == 0
alignment = tma_alignment_bytes // element_size
return align(x, alignment)
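# Worked example (illustrative), given the 16-byte TMA alignment:
#   get_tma_aligned_size(100, element_size=1) == 112  # align to 16 elements
#   get_tma_aligned_size(100, element_size=4) == 100  # align to 4 elements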
def get_col_major_tma_aligned_packed_tensor(x: torch.Tensor) -> torch.Tensor:
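    """Pack FP32 1x128 scale factors into UE8M0 bytes, four per int32, in a TMA-aligned, MN-major layout."""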
    # NOTE: for extreme performance, this function may be rewritten/fused in CUDA
assert x.dtype == torch.float and x.dim() in (2, 3)
# First, convert into UE8M0 `uint8_t`
ue8m0_tensor = (x.view(torch.int) >> 23).to(torch.uint8)
# Second, make padded packed tensors
mn, k = x.shape[-2], x.shape[-1]
remove_dim = False
if x.dim() == 2:
x, remove_dim = x.unsqueeze(0), True
b = x.shape[0]
aligned_mn = get_tma_aligned_size(mn, 4)
aligned_k = align(k, 4)
padded = torch.zeros((b, aligned_mn, aligned_k),
device=x.device,
dtype=torch.uint8)
padded[:, :mn, :k] = ue8m0_tensor
padded = padded.view(-1).view(dtype=torch.int).view(b, aligned_mn,
aligned_k // 4)
# Finally, transpose
transposed = torch.transpose(
torch.empty((b, aligned_k // 4, aligned_mn),
device=x.device,
dtype=torch.int), 1, 2)
transposed[:, :, :] = padded
aligned_x = transposed[:, :mn, :]
return aligned_x.squeeze(0) if remove_dim else aligned_x
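# Shape sketch (illustrative): a [mn, k_groups] FP32 scale tensor becomes an
# int32 tensor of shape [mn, align(k_groups, 4) // 4] whose mn dimension is the
# fastest-varying one (stride(-2) == 1), as the TMA layout checks below expect.
#
#   sf = torch.rand(100, 4, device="cuda") + 0.5
#   packed = get_col_major_tma_aligned_packed_tensor(sf)
#   # packed.shape == (100, 1), packed.dtype == torch.int32, packed.stride() == (1, 100)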
def check_sf_layout(sf: torch.Tensor,
mn: int,
k: int,
gran: Tuple[int, int],
num_groups: Optional[int],
tma_stride_check: bool = False,
type_check: Optional[torch.dtype] = None) -> torch.Tensor:
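    """Validate the shape, dtype and (optionally) the TMA stride layout of a scale-factor tensor."""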
# Type check
if type_check is not None:
assert sf.dtype == type_check
# Always do shape checks
assert sf.dtype in (torch.float, torch.int)
assert sf.dim() == int(num_groups is not None) + 2
if num_groups is not None:
assert sf.size(-3) == num_groups
assert sf.size(-2) == ceil_div(mn, gran[0])
assert sf.size(-1) == ceil_div(
k, gran[1] * (1 if sf.dtype == torch.float else 4))
# TMA stride checks: TMA aligned and MN-major
if tma_stride_check:
if num_groups is not None:
assert sf.stride(-3) == sf.stride(-1) * sf.size(-1)
assert sf.stride(-2) == 1
assert sf.stride(-1) == get_tma_aligned_size(mn, sf.element_size())
return sf
@nvtx_range("[DG] transform_sf_into_required_layout")
def transform_sf_into_required_layout(sf: torch.Tensor,
mn: int,
k: int,
recipe: Tuple[int, int, int],
num_groups: Optional[int] = None,
is_sfa: bool = False):
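    """Transform scale factors into the required (INT, 1, 128), TMA-aligned, MN-major layout.

    `recipe` holds the MN granularity of SFA, the MN granularity of SFB and the
    K granularity; `is_sfa` selects which MN granularity applies to `sf`.
    """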
gran = (recipe[0 if is_sfa else 1], recipe[2])
    should_skip_transform = (sf.dtype == torch.int
                             and gran in ((1, 128), (128, 128)))
if not should_skip_transform:
# Pre-transform checks
check_sf_layout(sf, mn=mn, k=k, gran=gran, num_groups=num_groups)
# (FP32, 1, 128) on SM100: transform to (INT, 1, 128), TMA-aligned and MN-major
if sf.dtype == torch.float and gran == (1, 128):
sf = get_col_major_tma_aligned_packed_tensor(sf)
return check_sf_layout(sf,
mn=mn,
k=k,
gran=(1, 128),
num_groups=num_groups,
tma_stride_check=True,
type_check=torch.int)
# (FP32, 128, 128) on SM100: transform to (INT, 1, 128), TMA-aligned and MN-major
if sf.dtype == torch.float and gran == (128, 128):
sf = sf.index_select(-2, torch.arange(mn, device=sf.device) // 128)
sf = get_col_major_tma_aligned_packed_tensor(sf)
return check_sf_layout(sf,
mn=mn,
k=k,
gran=(1, 128),
num_groups=num_groups,
tma_stride_check=True,
type_check=torch.int)
if should_skip_transform:
# TODO: add transpose kernel if SF layout is not satisfied
return check_sf_layout(sf,
mn=mn,
k=k,
gran=(1, 128),
num_groups=num_groups,
tma_stride_check=True,
type_check=torch.int)
    assert False, f'Unsupported case: {sf.dtype=}, {gran=}'
# Copied from https://github.com/ModelTC/lightllm/blob/a000ab69098654df4731f5b12587dd4e7f0a4f41/lightllm/common/fused_moe/moe_silu_and_mul_mix_quant_ep.py
@triton.jit
def _silu_and_mul_post_quant_kernel(
input_ptr,
stride_input_0,
stride_input_1,
stride_input_2,
output_ptr,
stride_output_0,
stride_output_1,
stride_output_2,
output_scale_ptr,
stride_output_scale_0,
stride_output_scale_1,
stride_output_scale_2,
masked_m_ptr,
size_k,
fp8_max,
fp8_min,
BLOCK: tl.constexpr,
NUM_STAGE: tl.constexpr,
SCALE_UE8M0: tl.constexpr,
):
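    """Fused SiLU-and-mul plus grouped FP8 quantization for masked (MoE) tokens.

    The last input dimension packs [up | gate]: SiLU is applied to the gate
    half and multiplied with the up half, then every 128-wide group is
    quantized to FP8; the four per-group scale exponents (UE8M0 when
    SCALE_UE8M0 is set) are packed into one int32.
    """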
expert_id = tl.program_id(2)
token_id = tl.program_id(1)
hidden_dim_block_index = tl.program_id(0)
block_num_per_expert = tl.num_programs(1)
token_num_cur_expert = tl.load(masked_m_ptr + expert_id)
stride_input_0 = tl.cast(stride_input_0, dtype=tl.int64)
stride_output_0 = tl.cast(stride_output_0, dtype=tl.int64)
stride_input_1 = tl.cast(stride_input_1, dtype=tl.int64)
stride_output_1 = tl.cast(stride_output_1, dtype=tl.int64)
offs_in_d = hidden_dim_block_index * BLOCK + tl.arange(0, BLOCK // 4)
input_ptr_offs = input_ptr + expert_id * stride_input_0 + offs_in_d
output_ptr_offs = output_ptr + expert_id * stride_output_0 + offs_in_d
output_scale_offs = (output_scale_ptr + expert_id * stride_output_scale_0 +
hidden_dim_block_index * stride_output_scale_1)
for token_index in tl.range(token_id,
token_num_cur_expert,
block_num_per_expert,
num_stages=NUM_STAGE):
output_s_int32 = 0
for pack_index in tl.range(4):
local_mask = offs_in_d + pack_index * 128
up = tl.load(
input_ptr_offs + token_index * stride_input_1 +
pack_index * 128,
mask=local_mask < size_k,
other=0.0,
)
gate = tl.load(
input_ptr_offs + token_index * stride_input_1 + size_k +
pack_index * 128,
mask=local_mask < size_k,
other=0.0,
).to(tl.float32)
gate = gate / (1 + tl.exp(-gate))
gate = gate.to(input_ptr.dtype.element_ty)
gate_up = up * gate
_absmax = tl.maximum(tl.max(tl.abs(gate_up)), 1e-10)
output_s = _absmax / fp8_max
if SCALE_UE8M0:
output_s = tl.exp2(tl.ceil(tl.log2(tl.abs(output_s))))
output_q = tl.clamp(gate_up / output_s, fp8_min,
fp8_max).to(output_ptr.dtype.element_ty)
output_s_int32 += ((output_s.to(tl.int32, bitcast=True) >> 23) <<
(8 * pack_index))
tl.store(
output_ptr_offs + token_index * stride_output_1 +
pack_index * 128,
output_q,
mask=local_mask < size_k,
)
tl.store(
output_scale_offs + token_index * stride_output_scale_2,
output_s_int32,
)
def silu_and_mul_masked_post_quant_fwd(
output: torch.Tensor,
output_scale: torch.Tensor,
input: torch.Tensor,
quant_group_size: int,
masked_m: torch.Tensor,
scale_ue8m0: bool = False,
):
"""
input shape [g, m, k]
output shape [g, m, k // 2], dtype fp8
output_scale [g, k // 4, m // 2 // 128], dtype int32
quant_group_size int
masked_m shape [g]
"""
assert input.is_contiguous()
assert len(input.shape) == 3
assert input.shape[0] == masked_m.shape[0]
assert input.shape[-1] % 2 == 0
# FP8 quantization parameters
finfo = torch.finfo(torch.float8_e4m3fn)
fp8_max = finfo.max
fp8_min = finfo.min
g, m, k = input.shape
k = k // 2
# Get block/grid/stage/warp
expert_num = len(masked_m)
if expert_num < 4:
BLOCK_NUM_PER_EXPERT = 64
else:
BLOCK_NUM_PER_EXPERT = 128
BLOCK = quant_group_size * 4
num_warps = 1
NUM_STAGES = 6
hidden_dim_split_block_num = triton.cdiv(k, BLOCK)
grid = (
hidden_dim_split_block_num,
BLOCK_NUM_PER_EXPERT,
expert_num,
)
_silu_and_mul_post_quant_kernel[grid](
input,
*input.stride(),
output,
*output.stride(),
output_scale,
*output_scale.stride(),
masked_m,
k,
fp8_max,
fp8_min,
BLOCK=BLOCK,
NUM_STAGE=NUM_STAGES,
num_warps=num_warps,
SCALE_UE8M0=scale_ue8m0,
)
output_scale = output_scale.transpose(1, 2)[:, :m, :]
check_sf_layout(
output_scale,
m,
k,
(1, 128),
g,
tma_stride_check=True,
)
return output_scale
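# Usage sketch (illustrative): shapes for a 2-expert, 8-token, 1024-wide gate/up
# activation with 128-element quantization groups. Tensor names are hypothetical.
#
#   g, m, k = 2, 8, 1024  # k packs [up | gate]
#   x = torch.randn(g, m, k, device="cuda", dtype=torch.bfloat16)
#   masked_m = torch.tensor([8, 5], device="cuda", dtype=torch.int32)
#   out = torch.empty(g, m, k // 2, device="cuda", dtype=torch.float8_e4m3fn)
#   out_sf = torch.empty(g, ceil_div(k // 2, 128 * 4), align(m, 4), device="cuda", dtype=torch.int32)
#   out_sf = silu_and_mul_masked_post_quant_fwd(out, out_sf, x, 128, masked_m, scale_ue8m0=True)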
@triton.jit
def _per_token_quant_and_transform_kernel(
input_ptr,
stride_input_0,
stride_input_1,
stride_input_2,
output_ptr,
stride_output_0,
stride_output_1,
stride_output_2,
output_scale_ptr,
stride_output_scale_0,
stride_output_scale_1,
stride_output_scale_2,
token_num_cur_expert,
size_k,
fp8_max,
fp8_min,
BLOCK: tl.constexpr,
NUM_STAGE: tl.constexpr,
SCALE_UE8M0: tl.constexpr,
):
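    """Per-token, per-128-group FP8 quantization that packs the four per-group scale exponents into one int32."""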
batch_id = tl.program_id(2)
token_id = tl.program_id(1)
hidden_dim_block_index = tl.program_id(0)
block_num_per_expert = tl.num_programs(1)
stride_input_0 = tl.cast(stride_input_0, dtype=tl.int64)
stride_output_0 = tl.cast(stride_output_0, dtype=tl.int64)
stride_output_scale_0 = tl.cast(stride_output_scale_0, dtype=tl.int64)
stride_input_1 = tl.cast(stride_input_1, dtype=tl.int64)
stride_output_1 = tl.cast(stride_output_1, dtype=tl.int64)
stride_output_scale_1 = tl.cast(stride_output_scale_1, dtype=tl.int64)
stride_input_2 = tl.cast(stride_input_2, dtype=tl.int64)
stride_output_2 = tl.cast(stride_output_2, dtype=tl.int64)
stride_output_scale_2 = tl.cast(stride_output_scale_2, dtype=tl.int64)
offs_in_d = hidden_dim_block_index * BLOCK + tl.arange(0, BLOCK // 4)
input_ptr_offs = input_ptr + batch_id * stride_input_0 + offs_in_d
output_ptr_offs = output_ptr + batch_id * stride_output_0 + offs_in_d
output_scale_offs = (output_scale_ptr + batch_id * stride_output_scale_0 +
hidden_dim_block_index * stride_output_scale_1)
for token_index in tl.range(token_id,
token_num_cur_expert,
block_num_per_expert,
num_stages=NUM_STAGE):
output_s_int32 = 0
for pack_index in tl.range(4):
local_mask = offs_in_d + pack_index * 128
act = tl.load(
input_ptr_offs + token_index * stride_input_1 +
pack_index * 128,
mask=local_mask < size_k,
other=0.0,
).to(tl.float32)
_absmax = tl.maximum(tl.max(tl.abs(act)), 1e-10)
output_s = _absmax / fp8_max
if SCALE_UE8M0:
output_s = tl.exp2(tl.ceil(tl.log2(tl.abs(output_s))))
output_q = tl.clamp(act / output_s, fp8_min,
fp8_max).to(output_ptr.dtype.element_ty)
output_s_int32 += ((output_s.to(tl.int32, bitcast=True) >> 23) <<
(8 * pack_index))
tl.store(
output_ptr_offs + token_index * stride_output_1 +
pack_index * 128,
output_q,
mask=local_mask < size_k,
)
tl.store(
output_scale_offs + token_index * stride_output_scale_2,
output_s_int32,
)
def per_token_quant_and_transform(
input: torch.Tensor,
quant_group_size: int = 128,
scale_ue8m0: bool = True,
need_permute102: bool = False,
):
"""
input shape [g, m, k]
output shape [g, m, k // 2], dtype fp8
output_scale [g, k // 4, m // 2 // 128], dtype int32
quant_group_size int
masked_m shape [g]
"""
assert input.shape[-1] % 2 == 0
# FP8 quantization parameters
finfo = torch.finfo(torch.float8_e4m3fn)
fp8_max = finfo.max
fp8_min = -fp8_max
b = 1
original_input_rank = len(input.shape)
if (original_input_rank == 2):
assert input.is_contiguous()
input = input.unsqueeze(0)
b, m, k = input.shape
elif (original_input_rank == 3):
if need_permute102:
input = input.transpose(0, 1)
b, m, k = input.shape
else:
raise AssertionError(
f"Unsupported input shape rank: {original_input_rank}")
# Create output
output = torch.empty((b, m, k), dtype=torch.float8_e4m3fn, device="cuda")
# Create output scale
alignment = 4
scale_k = ceil_div(k, quant_group_size)
m_padded = align(m, alignment)
scale_k_padded = align(scale_k, alignment)
output_scale = torch.empty((b, scale_k_padded // 4, m_padded),
dtype=torch.int32,
device='cuda')
# Get block/grid/stage/warp
BLOCK_NUM_PER_EXPERT = 64
BLOCK = quant_group_size * 4
num_warps = 1
NUM_STAGES = 6
hidden_dim_split_block_num = triton.cdiv(k, BLOCK)
grid = (
hidden_dim_split_block_num,
BLOCK_NUM_PER_EXPERT,
b,
)
_per_token_quant_and_transform_kernel[grid](
input,
*input.stride(),
output,
*output.stride(),
output_scale,
*output_scale.stride(),
m,
k,
fp8_max,
fp8_min,
BLOCK=BLOCK,
NUM_STAGE=NUM_STAGES,
num_warps=num_warps,
SCALE_UE8M0=scale_ue8m0,
)
if (original_input_rank == 2):
output = output.squeeze(0)
output_scale = output_scale.squeeze(0)
output_scale = output_scale.transpose(0, 1)[:m, :]
else:
output_scale = output_scale.transpose(1, 2)[:, :m, :]
check_sf_layout(
output_scale,
m,
k,
(1, 128),
num_groups=b if original_input_rank == 3 else None,
tma_stride_check=True,
)
return output, output_scale
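# Usage sketch (illustrative): quantizing a 2-D activation. Tensor names are
# hypothetical.
#
#   x = torch.randn(8, 1024, device="cuda", dtype=torch.bfloat16)
#   x_fp8, x_sf = per_token_quant_and_transform(x)
#   # x_fp8: [8, 1024] float8_e4m3fn
#   # x_sf:  [8, 2] int32 -- ceil(1024 / 128) = 8 UE8M0 exponents per token, packed 4 per int32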
def fp8_quantize_1x128_sf_transpose(
x: torch.Tensor,
use_ue8m0: bool = False) -> Tuple[torch.Tensor, torch.Tensor]:
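    """Quantize `x` with torch.ops.trtllm.fp8_quantize_1x128 and transpose the returned 1x128 scales (handling the SM90 1-D padded layout)."""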
x_fp8, x_scale = torch.ops.trtllm.fp8_quantize_1x128(x, use_ue8m0=use_ue8m0)
if x_scale.ndim == 1: # Handle SM version differences (SM90: 1D padded, SM100+: 2D)
x_padded = (x.shape[0] + 3) // 4 * 4
num_blocks = (x.shape[1] + 127) // 128
x_scale = x_scale[:x_padded * num_blocks].view(num_blocks,
x_padded)[:, :x.shape[0]]
x_scale = x_scale.contiguous().transpose(0, 1)
return x_fp8, x_scale
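# Usage sketch (illustrative); the exact scale layout depends on the underlying op:
#
#   x = torch.randn(8, 1024, device="cuda", dtype=torch.bfloat16)
#   x_fp8, x_sf = fp8_quantize_1x128_sf_transpose(x, use_ue8m0=True)
#   # x_fp8 keeps x's shape in float8_e4m3fn; x_sf holds one scale per (token, 128-wide group)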