TensorRT-LLMs/tensorrt_llm/quantization/utils/fp8_utils.py
Fanrong Li 1bbc0e323b
[None][fix] Pre-allocate workspaces for DeepGEMM MoE to avoid frequent cudaFree/cudaMalloc (#6811)
Signed-off-by: Fanrong Li <23290157+lfr-0531@users.noreply.github.com>
Signed-off-by: Yuxian Qiu <142763828+yuxianq@users.noreply.github.com>
Co-authored-by: Yuxian Qiu <142763828+yuxianq@users.noreply.github.com>
2025-08-13 10:27:57 +08:00

521 lines
17 KiB
Python

from typing import Optional, Tuple
import torch
import triton
import triton.language as tl
from tensorrt_llm._utils import nvtx_range
def ceil_div(x: int, y: int) -> int:
    """
    Divide ``x`` by ``y`` and round the quotient up.

    Args:
        x: the dividend.
        y: the divisor.

    Returns:
        ``ceil(x / y)`` computed with integer floor division only.
    """
    # Biasing the numerator by (y - 1) turns floor division into ceiling
    # division for a positive divisor.
    biased_numerator = x + y - 1
    return biased_numerator // y
def align(x: int, y: int) -> int:
    """
    Round ``x`` up to a multiple of ``y``.

    Args:
        x: the value to round up.
        y: the alignment.

    Returns:
        For a positive ``y``, the smallest multiple of ``y`` that is
        greater than or equal to ``x``.
    """
    num_chunks = (x + y - 1) // y
    return num_chunks * y
def ceil_to_ue8m0(x: torch.Tensor) -> torch.Tensor:
    """
    Round the magnitude of every element of ``x`` up to a power of two.

    A power-of-two scale can be stored as an unsigned 8-bit exponent
    (UE8M0). The sign of ``x`` is discarded.
    """
    exponent = torch.log2(x.abs()).ceil()
    return torch.pow(2.0, exponent)
@nvtx_range("[DG] quantization")
@torch.compile(dynamic=True)
def per_token_cast_to_fp8_e8m0(
        x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Quantize ``x`` to FP8 (E4M3) with one power-of-two scale per 128 elements.

    Every run of 128 consecutive elements along the last dimension shares one
    scale factor, rounded up to a power of two by ``ceil_to_ue8m0``. Any number
    of leading dimensions is accepted (e.g. ``[m, n]`` or ``[g, m, n]``).

    Args:
        x: input tensor whose last dimension is a multiple of 128; it must be
            viewable as ``[..., n // 128, 128]``.

    Returns:
        A tuple ``(x_fp8, sf)``:

        * ``x_fp8`` -- ``x`` quantized to ``torch.float8_e4m3fn``, same shape
          as ``x``;
        * ``sf`` -- float32 scales of shape ``[..., n // 128]``; each 128-element
          group was divided by its scale before the FP8 cast.
    """
    assert x.size(-1) % 128 == 0
    *lead_dims, n = x.shape
    # Group the last dimension into blocks of 128 elements.
    x_view = x.view(*lead_dims, -1, 128)
    # Clamp so an all-zero group still gets a finite, non-zero scale.
    x_amax = x_view.abs().float().amax(dim=-1).clamp(1e-4)
    # 448 is the largest finite float8_e4m3fn value.
    sf = ceil_to_ue8m0(x_amax / 448.0)
    x_fp8 = (x_view * (1.0 / sf.unsqueeze(-1))).to(torch.float8_e4m3fn)
    return x_fp8.view(*lead_dims, n), sf
def per_block_cast_to_fp8_e8m0(
        x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Quantize ``x`` to FP8 (E4M3) with one power-of-two scale per 128x128 tile.

    The last two dimensions are zero-padded up to multiples of 128 and split
    into 128x128 tiles. Each tile gets one scale, rounded up to a power of two
    by ``ceil_to_ue8m0``. Any number of leading (batch) dimensions is accepted
    (e.g. ``[m, n]`` or ``[g, m, n]``).

    Args:
        x: tensor of shape ``[..., m, n]`` (at least 2-D).

    Returns:
        A tuple ``(x_fp8, sf)``:

        * ``x_fp8`` -- ``x`` quantized to ``torch.float8_e4m3fn``, same shape
          as ``x``, contiguous, with the padding removed;
        * ``sf`` -- float32 tile scales of shape
          ``[..., ceil(m / 128), ceil(n / 128)]``.
    """
    *lead_dims, m, n = x.shape
    m_pad, n_pad = align(m, 128), align(n, 128)
    # Zero padding cannot raise a tile's absolute maximum.
    x_padded = torch.zeros((*lead_dims, m_pad, n_pad),
                           dtype=x.dtype,
                           device=x.device)
    x_padded[..., :m, :n] = x
    x_view = x_padded.view(*lead_dims, m_pad // 128, 128, n_pad // 128, 128)
    # Clamp so an all-zero tile still gets a finite, non-zero scale.
    x_amax = x_view.abs().float().amax(dim=(-3, -1),
                                       keepdim=True).clamp(1e-4)
    # 448 is the largest finite float8_e4m3fn value.
    sf = ceil_to_ue8m0(x_amax / 448.0)
    x_scaled = (x_view * (1.0 / sf)).to(torch.float8_e4m3fn)
    x_fp8 = x_scaled.view_as(x_padded)[..., :m, :n].contiguous()
    return x_fp8, sf.view(*lead_dims, m_pad // 128, n_pad // 128)
def resmooth_to_fp8_e8m0(weight: torch.Tensor,
                         sf: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Re-quantize a block-scaled weight so that its scales are powers of two.

    Both tensors are moved to CUDA. The weight is multiplied by its block
    scales ``sf`` (each scale expanded over a 128x128 tile of the last two
    dimensions) and the result is re-quantized by
    ``per_block_cast_to_fp8_e8m0``.

    Args:
        weight: 2-D ``[m, n]`` or 3-D ``[g, m, n]`` weight (presumably FP8
            data that was quantized with ``sf`` -- confirm with callers).
        sf: one scale per 128x128 tile of the last two dims of ``weight``.

    Returns:
        The ``(weight_fp8, sf_ue8m0)`` pair produced by
        ``per_block_cast_to_fp8_e8m0``.
    """
    weight, sf = weight.cuda(), sf.cuda()
    # Tiles span dims (0, 1) of a 2-D weight and dims (1, 2) otherwise.
    row_dim, col_dim = (0, 1) if weight.dim() == 2 else (1, 2)
    expanded_sf = sf.repeat_interleave(128, dim=row_dim).repeat_interleave(
        128, dim=col_dim)
    # Trim the overhang the expansion adds when a dim is not a multiple of 128.
    crop = tuple(slice(None, size) for size in weight.shape[:col_dim + 1])
    dequantized = weight.float() * expanded_sf[crop]
    return per_block_cast_to_fp8_e8m0(dequantized)
def get_m_alignment_for_contiguous_layout() -> int:
    """
    Return the M-dimension alignment used by the contiguous layout.

    The value is always 128.
    """
    m_alignment = 128
    return m_alignment
def get_tma_aligned_size(x: int, element_size: int) -> int:
    """
    Round an element count up so that its byte size is a multiple of 16.

    Args:
        x: number of elements.
        element_size: bytes per element; must evenly divide 16.

    Returns:
        ``x`` rounded up to a multiple of ``16 // element_size`` elements.
    """
    tma_alignment_bytes = 16
    elements_per_chunk, leftover_bytes = divmod(tma_alignment_bytes,
                                                element_size)
    assert leftover_bytes == 0
    return align(x, elements_per_chunk)
def get_col_major_tma_aligned_packed_tensor(x: torch.Tensor) -> torch.Tensor:
    """
    Pack FP32 scale factors into an MN-major, TMA-aligned int32 tensor.

    Only the exponent byte of each FP32 value is kept (UE8M0), and four such
    bytes are reinterpreted as one int32 along the last (K) dimension. In the
    result, the stride along MN is 1 and the stride along packed K is the MN
    size rounded up by ``get_tma_aligned_size(mn, 4)``.

    Args:
        x: float32 tensor of shape ``[mn, k]`` or ``[b, mn, k]``.

    Returns:
        int32 tensor of shape ``[mn, ceil(k / 4)]`` or
        ``[b, mn, ceil(k / 4)]`` (matching the rank of ``x``).
    """
    # NOTES: for the extreme performance, you may rewrite/fuse this function in CUDA
    assert x.dtype == torch.float and x.dim() in (2, 3)

    # Shifting the raw FP32 bits right by 23 drops the mantissa; the uint8
    # cast keeps the exponent byte.
    exponent_bytes = (x.view(torch.int) >> 23).to(torch.uint8)

    mn, k = x.shape[-2], x.shape[-1]
    squeeze_batch = x.dim() == 2
    if squeeze_batch:
        x = x.unsqueeze(0)
    batch = x.shape[0]

    # Zero-pad MN to the TMA alignment and K to a multiple of 4, then
    # reinterpret each run of 4 exponent bytes as a single int32.
    padded_mn = get_tma_aligned_size(mn, 4)
    padded_k = align(k, 4)
    staging = torch.zeros((batch, padded_mn, padded_k),
                          device=x.device,
                          dtype=torch.uint8)
    staging[:, :mn, :k] = exponent_bytes
    packed = staging.view(-1).view(dtype=torch.int).view(
        batch, padded_mn, padded_k // 4)

    # Allocate [b, k/4, mn] storage and view it as [b, mn, k/4]; copying into
    # that view produces MN-major strides.
    mn_major = torch.empty((batch, padded_k // 4, padded_mn),
                           device=x.device,
                           dtype=torch.int).transpose(1, 2)
    mn_major.copy_(packed)
    result = mn_major[:, :mn, :]
    return result.squeeze(0) if squeeze_batch else result
def check_sf_layout(sf: torch.Tensor,
                    mn: int,
                    k: int,
                    gran: Tuple[int, int],
                    num_groups: Optional[int],
                    tma_stride_check: bool = False,
                    type_check: Optional[torch.dtype] = None) -> torch.Tensor:
    """
    Assert that the scale-factor tensor ``sf`` has the expected dtype, shape and strides.

    Args:
        sf: float32 scales, or int32 scales holding four packed values per
            element along K.
        mn: size of the MN dimension covered by ``sf``.
        k: size of the K dimension covered by ``sf``.
        gran: ``(mn_granularity, k_granularity)`` covered by one scale value.
        num_groups: group count for a 3-D ``sf``; ``None`` for a 2-D ``sf``.
        tma_stride_check: additionally require TMA-aligned, MN-major strides.
        type_check: if given, the exact dtype ``sf`` must have.

    Returns:
        ``sf`` unchanged, so the call can be chained.

    Raises:
        AssertionError: if a check fails (checks use ``assert``).
    """
    if type_check is not None:
        assert sf.dtype == type_check

    # Shape checks always run.
    assert sf.dtype in (torch.float, torch.int)
    is_grouped = num_groups is not None
    assert sf.dim() == (3 if is_grouped else 2)
    if is_grouped:
        assert sf.size(-3) == num_groups
    assert sf.size(-2) == ceil_div(mn, gran[0])
    # An int32 element carries four packed scales along K.
    values_per_element = 1 if sf.dtype == torch.float else 4
    assert sf.size(-1) == ceil_div(k, gran[1] * values_per_element)

    if not tma_stride_check:
        return sf

    # TMA stride checks: TMA aligned and MN-major
    if is_grouped:
        assert sf.stride(-3) == sf.stride(-1) * sf.size(-1)
    assert sf.stride(-2) == 1
    assert sf.stride(-1) == get_tma_aligned_size(mn, sf.element_size())
    return sf
@nvtx_range("[DG] transform_sf_into_required_layout")
def transform_sf_into_required_layout(sf: torch.Tensor,
                                      mn: int,
                                      k: int,
                                      recipe: Tuple[int, int, int],
                                      num_groups: Optional[int] = None,
                                      is_sfa: bool = False):
    """
    Convert scale factors to int32, (1, 128) granularity, TMA-aligned, MN-major.

    Supported inputs:

    * int32 ``sf`` with granularity (1, 128) or (128, 128): assumed to be
      transformed already, so only the final layout is validated;
    * float32 ``sf`` with granularity (1, 128): packed via
      ``get_col_major_tma_aligned_packed_tensor``;
    * float32 ``sf`` with granularity (128, 128): each row of scales is first
      replicated to the ``mn`` rows it covers, then packed as above.

    Args:
        sf: scale factors (float32 or int32).
        mn: size of the MN dimension covered by ``sf``.
        k: size of the K dimension covered by ``sf``.
        recipe: granularity triple; ``recipe[0]`` (when ``is_sfa``) or
            ``recipe[1]`` (otherwise) is the MN granularity and ``recipe[2]``
            is the K granularity.
        num_groups: group count for a 3-D ``sf``; ``None`` for 2-D.
        is_sfa: selects ``recipe[0]`` instead of ``recipe[1]`` as the MN
            granularity.

    Returns:
        The int32 scale tensor, validated by ``check_sf_layout``.

    Raises:
        AssertionError: for an unsupported dtype/granularity combination.
            This is raised explicitly so it is not stripped by ``python -O``.
            A failed layout check also raises it.
    """
    gran = (recipe[0 if is_sfa else 1], recipe[2])
    supported_grans = ((1, 128), (128, 128))

    if sf.dtype == torch.int and gran in supported_grans:
        # Already packed: only validate the final layout.
        # TODO: add transpose kernel if SF layout is not satisfied
        return check_sf_layout(sf,
                               mn=mn,
                               k=k,
                               gran=(1, 128),
                               num_groups=num_groups,
                               tma_stride_check=True,
                               type_check=torch.int)

    # Pre-transform checks
    check_sf_layout(sf, mn=mn, k=k, gran=gran, num_groups=num_groups)

    if sf.dtype == torch.float and gran in supported_grans:
        if gran == (128, 128):
            # (FP32, 128, 128): repeat each scale row for the 128 rows it
            # covers, giving an (FP32, 1, 128) tensor.
            sf = sf.index_select(-2, torch.arange(mn, device=sf.device) // 128)
        # (FP32, 1, 128) on SM100: transform to (INT, 1, 128), TMA-aligned and MN-major
        sf = get_col_major_tma_aligned_packed_tensor(sf)
        return check_sf_layout(sf,
                               mn=mn,
                               k=k,
                               gran=(1, 128),
                               num_groups=num_groups,
                               tma_stride_check=True,
                               type_check=torch.int)

    raise AssertionError(f'Unknown cases: {sf.dtype=}, {gran=}')
# copy from https://github.com/ModelTC/lightllm/blob/a000ab69098654df4731f5b12587dd4e7f0a4f41/lightllm/common/fused_moe/moe_silu_and_mul_mix_quant_ep.py
# Fused SiLU-and-mul followed by per-group FP8 quantization of a masked
# [g, m, 2 * size_k] input (see silu_and_mul_masked_post_quant_fwd).
# Launch grid:
#   axis 0 -> hidden-dim block of BLOCK output columns (four 128-column groups,
#             assuming BLOCK == 512),
#   axis 1 -> starting token; tokens are visited with stride num_programs(1),
#   axis 2 -> expert / group index.
# For every valid token it writes FP8 values to `output`. It also writes one
# int32 to `output_scale` that holds the exponent bytes of the four group
# scales of this block.
@triton.jit
def _silu_and_mul_post_quant_kernel(
    input_ptr,
    stride_input_0,
    stride_input_1,
    stride_input_2,  # never read: unit stride on the last dim is assumed
    output_ptr,
    stride_output_0,
    stride_output_1,
    stride_output_2,  # never read: unit stride on the last dim is assumed
    output_scale_ptr,
    stride_output_scale_0,
    stride_output_scale_1,
    stride_output_scale_2,
    masked_m_ptr,
    size_k,
    fp8_max,
    fp8_min,
    BLOCK: tl.constexpr,
    NUM_STAGE: tl.constexpr,
    SCALE_UE8M0: tl.constexpr,
):
    expert_id = tl.program_id(2)
    token_id = tl.program_id(1)
    hidden_dim_block_index = tl.program_id(0)
    block_num_per_expert = tl.num_programs(1)
    # Number of valid tokens for this expert; rows at or beyond it are not
    # touched by this kernel.
    token_num_cur_expert = tl.load(masked_m_ptr + expert_id)
    # Promote to int64 so the offset products below cannot overflow int32.
    stride_input_0 = tl.cast(stride_input_0, dtype=tl.int64)
    stride_output_0 = tl.cast(stride_output_0, dtype=tl.int64)
    stride_input_1 = tl.cast(stride_input_1, dtype=tl.int64)
    stride_output_1 = tl.cast(stride_output_1, dtype=tl.int64)
    # Columns of the first group in this block; the pack loop shifts them by
    # pack_index * 128. NOTE(review): the hard-coded 128 only matches
    # BLOCK // 4 when BLOCK == 512 (quant_group_size == 128).
    offs_in_d = hidden_dim_block_index * BLOCK + tl.arange(0, BLOCK // 4)
    input_ptr_offs = input_ptr + expert_id * stride_input_0 + offs_in_d
    output_ptr_offs = output_ptr + expert_id * stride_output_0 + offs_in_d
    output_scale_offs = (output_scale_ptr + expert_id * stride_output_scale_0 +
                         hidden_dim_block_index * stride_output_scale_1)
    for token_index in tl.range(token_id,
                                token_num_cur_expert,
                                block_num_per_expert,
                                num_stages=NUM_STAGE):
        output_s_int32 = 0
        for pack_index in tl.range(4):
            local_mask = offs_in_d + pack_index * 128
            # The first size_k columns of a row are loaded as `up`...
            up = tl.load(
                input_ptr_offs + token_index * stride_input_1 +
                pack_index * 128,
                mask=local_mask < size_k,
                other=0.0,
            )
            # ...and the next size_k columns as `gate`, in float32.
            gate = tl.load(
                input_ptr_offs + token_index * stride_input_1 + size_k +
                pack_index * 128,
                mask=local_mask < size_k,
                other=0.0,
            ).to(tl.float32)
            # SiLU(gate) = gate * sigmoid(gate), then back to the input dtype.
            gate = gate / (1 + tl.exp(-gate))
            gate = gate.to(input_ptr.dtype.element_ty)
            gate_up = up * gate
            # Per-group scale; the 1e-10 floor keeps it non-zero for an
            # all-zero group.
            _absmax = tl.maximum(tl.max(tl.abs(gate_up)), 1e-10)
            output_s = _absmax / fp8_max
            if SCALE_UE8M0:
                # Round the scale up to a power of two (UE8M0).
                output_s = tl.exp2(tl.ceil(tl.log2(tl.abs(output_s))))
            output_q = tl.clamp(gate_up / output_s, fp8_min,
                                fp8_max).to(output_ptr.dtype.element_ty)
            # Put the scale's FP32 exponent byte (bits >> 23) into byte
            # `pack_index` of the packed int32. NOTE(review): when SCALE_UE8M0
            # is false the mantissa of output_s is dropped here, so the stored
            # scale differs from the one used for output_q -- confirm callers
            # always pass True.
            output_s_int32 += ((output_s.to(tl.int32, bitcast=True) >> 23) <<
                               (8 * pack_index))
            tl.store(
                output_ptr_offs + token_index * stride_output_1 +
                pack_index * 128,
                output_q,
                mask=local_mask < size_k,
            )
        # One packed int32 per (token, hidden-dim block).
        tl.store(
            output_scale_offs + token_index * stride_output_scale_2,
            output_s_int32,
        )
def silu_and_mul_masked_post_quant_fwd(
    output: torch.Tensor,
    output_scale: torch.Tensor,
    input: torch.Tensor,
    quant_group_size: int,
    masked_m: torch.Tensor,
    scale_ue8m0: bool = False,
):
    """
    Fused SiLU-and-mul plus per-group FP8 quantization for masked grouped input.

    For group ``e``, only the first ``masked_m[e]`` rows are processed. Within
    each row, the first half of the last dimension (``up``) is multiplied by
    SiLU of the second half (``gate``). The product is quantized to FP8 with
    one scale per 128 elements.

    Args:
        output: ``[g, m, k // 2]`` FP8 buffer that receives the quantized values.
        output_scale: contiguous int32 buffer of shape
            ``[g, ceil((k // 2) / 512), align(m, 4)]``, which is what the final
            layout check accepts. Each int32 packs the UE8M0 exponent bytes of
            four consecutive 128-element groups.
        input: contiguous ``[g, m, k]`` activations with even ``k``.
        quant_group_size: elements per scale; the kernel only supports 128.
        masked_m: ``[g]`` number of valid rows per group.
        scale_ue8m0: round scales up to powers of two. NOTE(review): the kernel
            always stores only the exponent byte, so ``False`` produces packed
            scales that do not match the quantized values -- confirm callers
            pass ``True``.

    Returns:
        ``output_scale`` viewed as ``[g, m, ceil((k // 2) / 512)]`` (MN-major,
        TMA-aligned), after layout validation.

    Raises:
        AssertionError: on invalid input shapes, an unsupported
            ``quant_group_size``, or a scale layout mismatch.
    """
    assert input.is_contiguous()
    assert len(input.shape) == 3
    assert input.shape[0] == masked_m.shape[0]
    assert input.shape[-1] % 2 == 0
    # The kernel hard-codes a 128-column pack stride. Any other group size
    # would read/write the wrong columns (and scale rows) during the launch,
    # before the layout check below could catch it.
    assert quant_group_size == 128, (
        f"only quant_group_size == 128 is supported, got {quant_group_size}")

    # FP8 quantization parameters
    finfo = torch.finfo(torch.float8_e4m3fn)
    fp8_max = finfo.max
    fp8_min = finfo.min

    g, m, k = input.shape
    # Each output row has half the input width (up * silu(gate)).
    k = k // 2

    # Get block/grid/stage/warp
    expert_num = len(masked_m)
    BLOCK_NUM_PER_EXPERT = 64 if expert_num < 4 else 128
    # One program handles four quantization groups (one packed int32 scale).
    BLOCK = quant_group_size * 4
    num_warps = 1
    NUM_STAGES = 6

    hidden_dim_split_block_num = triton.cdiv(k, BLOCK)
    grid = (
        hidden_dim_split_block_num,
        BLOCK_NUM_PER_EXPERT,
        expert_num,
    )
    _silu_and_mul_post_quant_kernel[grid](
        input,
        *input.stride(),
        output,
        *output.stride(),
        output_scale,
        *output_scale.stride(),
        masked_m,
        k,
        fp8_max,
        fp8_min,
        BLOCK=BLOCK,
        NUM_STAGE=NUM_STAGES,
        num_warps=num_warps,
        SCALE_UE8M0=scale_ue8m0,
    )

    # View the [g, k-blocks, padded m] buffer as [g, m, k-blocks] (MN-major).
    output_scale = output_scale.transpose(1, 2)[:, :m, :]
    check_sf_layout(
        output_scale,
        m,
        k,
        (1, 128),
        g,
        tma_stride_check=True,
    )
    return output_scale
@triton.jit
def _per_token_quant_and_transform_kernel(
    input_ptr,
    stride_input_0,
    stride_input_1,
    output_ptr,
    stride_output_0,
    stride_output_1,
    output_scale_ptr,
    stride_output_scale_0,
    stride_output_scale_1,
    token_num_cur_expert,
    size_k,
    fp8_max,
    fp8_min,
    BLOCK: tl.constexpr,
    NUM_STAGE: tl.constexpr,
    SCALE_UE8M0: tl.constexpr,
):
    # Per-token, per-group FP8 quantization of a 2-D
    # [token_num_cur_expert, size_k] input (see per_token_quant_and_transform).
    # Launch grid:
    #   axis 0 -> hidden-dim block of BLOCK columns, split into four groups of
    #             BLOCK // 4 columns that share one packed int32 scale,
    #   axis 1 -> starting token; tokens are visited with stride num_programs(1).
    # The last dimension is assumed to be contiguous: stride_input_1 and
    # stride_output_1 are accepted for interface symmetry but never read.
    token_id = tl.program_id(1)
    hidden_dim_block_index = tl.program_id(0)
    block_num_per_expert = tl.num_programs(1)

    # Promote the token strides to int64 so token_index * stride cannot
    # overflow int32.
    stride_input_0 = tl.cast(stride_input_0, dtype=tl.int64)
    stride_output_0 = tl.cast(stride_output_0, dtype=tl.int64)

    # Columns of the first group in this block; group `pack_index` is the
    # same range shifted by pack_index * (BLOCK // 4).
    offs_in_d = hidden_dim_block_index * BLOCK + tl.arange(0, BLOCK // 4)
    input_ptr_offs = input_ptr + offs_in_d
    output_ptr_offs = output_ptr + offs_in_d
    output_scale_offs = (output_scale_ptr +
                         hidden_dim_block_index * stride_output_scale_0)

    for token_index in tl.range(token_id,
                                token_num_cur_expert,
                                block_num_per_expert,
                                num_stages=NUM_STAGE):
        output_s_int32 = 0
        for pack_index in tl.range(4):
            local_mask = offs_in_d + pack_index * (BLOCK // 4)
            act = tl.load(
                input_ptr_offs + token_index * stride_input_0 +
                pack_index * (BLOCK // 4),
                mask=local_mask < size_k,
                other=0.0,
            ).to(tl.float32)
            # Per-group scale; the 1e-10 floor keeps it non-zero for an
            # all-zero group.
            _absmax = tl.maximum(tl.max(tl.abs(act)), 1e-10)
            output_s = _absmax / fp8_max
            if SCALE_UE8M0:
                # Round the scale up to a power of two (UE8M0).
                output_s = tl.exp2(tl.ceil(tl.log2(tl.abs(output_s))))
            output_q = tl.clamp(act / output_s, fp8_min,
                                fp8_max).to(output_ptr.dtype.element_ty)
            # Put the scale's FP32 exponent byte (bits >> 23) into byte
            # `pack_index` of the packed int32. NOTE(review): when SCALE_UE8M0
            # is false the mantissa is dropped here, so the stored scale
            # differs from the one used for output_q.
            output_s_int32 += ((output_s.to(tl.int32, bitcast=True) >> 23) <<
                               (8 * pack_index))
            tl.store(
                output_ptr_offs + token_index * stride_output_0 +
                pack_index * (BLOCK // 4),
                output_q,
                mask=local_mask < size_k,
            )
        # One packed int32 per (token, hidden-dim block).
        tl.store(
            output_scale_offs + token_index * stride_output_scale_1,
            output_s_int32,
        )
def per_token_quant_and_transform(
    input: torch.Tensor,
    quant_group_size: int = 128,
    scale_ue8m0: bool = True,
):
    """
    Quantize a 2-D activation to FP8 with one scale per 128 elements of a row.

    Args:
        input: contiguous ``[m, k]`` tensor on a CUDA device, with even ``k``.
        quant_group_size: elements per scale; the kernel only supports 128.
        scale_ue8m0: round each scale up to a power of two. NOTE(review): the
            kernel stores only the exponent byte of each scale, so ``False``
            produces packed scales that do not match the quantized values.

    Returns:
        A tuple ``(output, output_scale)``:

        * ``output`` -- ``[m, k]`` ``torch.float8_e4m3fn`` tensor;
        * ``output_scale`` -- ``[m, ceil(k / 512)]`` int32 view, MN-major and
          TMA-aligned. Each int32 packs the UE8M0 exponent bytes of four
          consecutive 128-element groups of a row.

    Raises:
        AssertionError: on invalid input, an unsupported ``quant_group_size``,
            or a scale layout mismatch.
    """
    assert input.is_contiguous()
    assert len(input.shape) == 2
    assert input.shape[-1] % 2 == 0
    # The kernel and the final layout check are written for 128-element groups.
    assert quant_group_size == 128, (
        f"only quant_group_size == 128 is supported, got {quant_group_size}")

    # FP8 quantization parameters
    finfo = torch.finfo(torch.float8_e4m3fn)
    fp8_max = finfo.max
    fp8_min = -fp8_max

    m, k = input.shape

    # Allocate the outputs on the input's device rather than the current
    # default CUDA device, so the kernel never mixes devices.
    output = torch.empty((m, k), dtype=torch.float8_e4m3fn, device=input.device)

    # Scale buffer stored as [k-groups / 4, padded m]; transposed below into
    # the MN-major [m, k-groups / 4] view. Padding columns stay zero.
    alignment = 4
    scale_k = ceil_div(k, quant_group_size)
    m_padded = align(m, alignment)
    scale_k_padded = align(scale_k, alignment)
    output_scale = torch.zeros((scale_k_padded // 4, m_padded),
                               dtype=torch.int32,
                               device=input.device)

    # Get block/grid/stage/warp
    BLOCK_NUM_PER_EXPERT = 64
    # One program handles four quantization groups (one packed int32 scale).
    BLOCK = quant_group_size * 4
    num_warps = 1
    NUM_STAGES = 6

    hidden_dim_split_block_num = triton.cdiv(k, BLOCK)
    grid = (
        hidden_dim_split_block_num,
        BLOCK_NUM_PER_EXPERT,
        1,
    )
    _per_token_quant_and_transform_kernel[grid](
        input,
        *input.stride(),
        output,
        *output.stride(),
        output_scale,
        *output_scale.stride(),
        m,
        k,
        fp8_max,
        fp8_min,
        BLOCK=BLOCK,
        NUM_STAGE=NUM_STAGES,
        num_warps=num_warps,
        SCALE_UE8M0=scale_ue8m0,
    )

    output_scale = output_scale.transpose(0, 1)[:m, :]
    check_sf_layout(
        output_scale,
        m,
        k,
        (1, 128),
        num_groups=None,
        tma_stride_check=True,
    )
    return output, output_scale