from typing import Optional, Tuple

import torch
import triton
import triton.language as tl

from tensorrt_llm._utils import nvtx_range


def ceil_div(x: int, y: int) -> int:
    """
    Perform ceiling division of two integers.

    Args:
        x: the dividend.
        y: the divisor.

    Returns:
        The result of the ceiling division.
    """
    return (x + y - 1) // y


def align(x: int, y: int) -> int:
    """Round `x` up to the nearest multiple of `y`."""
    return ceil_div(x, y) * y


def ceil_to_ue8m0(x: torch.Tensor):
    """Round the magnitude of `x` up to the nearest power of two, i.e. a scale that is
    exactly representable in the UE8M0 format."""
    return torch.pow(2.0, torch.ceil(torch.log2(x.abs())))
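
# Illustrative example: the result is always a power of two, so it is exactly
# representable as a UE8M0 exponent, e.g.
#
#   ceil_to_ue8m0(torch.tensor([0.3, 1.0, 5.0]))  # -> tensor([0.5000, 1.0000, 8.0000])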
@nvtx_range("[DG] quantization")
|
|
@torch.compile(dynamic=True)
|
|
def per_token_cast_to_fp8_e8m0(
|
|
x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
|
|
if x.dim() == 2:
|
|
assert x.size(1) % 128 == 0
|
|
m, n = x.shape
|
|
x_view = x.view(m, -1, 128)
|
|
x_amax = x_view.abs().float().amax(dim=2).view(m, -1).clamp(1e-4)
|
|
sf = ceil_to_ue8m0(x_amax / 448.0)
|
|
return (x_view * (1.0 / sf.unsqueeze(2))).to(torch.float8_e4m3fn).view(
|
|
m, n), sf
|
|
else:
|
|
assert x.size(2) % 128 == 0
|
|
g, m, n = x.shape
|
|
x_view = x.view(g, m, -1, 128)
|
|
x_amax = x_view.abs().float().amax(dim=3).view(g, m, -1).clamp(1e-4)
|
|
sf = ceil_to_ue8m0(x_amax / 448.0)
|
|
return (x_view * (1.0 / sf.unsqueeze(3))).to(torch.float8_e4m3fn).view(
|
|
g, m, n), sf
|
|
|
|
|
|
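
# Usage sketch (illustrative; assumes a CUDA device and a hidden size that is a
# multiple of 128):
#
#   x = torch.randn(64, 1024, device="cuda")
#   x_fp8, sf = per_token_cast_to_fp8_e8m0(x)
#   # x_fp8: [64, 1024] float8_e4m3fn; sf: [64, 8] float32 power-of-two scales,
#   # one per 1x128 group (1024 / 128 = 8 groups per token).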


def per_block_cast_to_fp8_e8m0(
        x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    if x.dim() == 2:
        m, n = x.shape
        x_padded = torch.zeros((align(m, 128), align(n, 128)),
                               dtype=x.dtype,
                               device=x.device)
        x_padded[:m, :n] = x
        x_view = x_padded.view(-1, 128, x_padded.size(1) // 128, 128)
        x_amax = x_view.abs().float().amax(dim=(1, 3), keepdim=True).clamp(1e-4)
        sf = ceil_to_ue8m0(x_amax / 448.0)
        x_scaled = (x_view * (1.0 / sf)).to(torch.float8_e4m3fn)
        return x_scaled.view_as(x_padded)[:m, :n].contiguous(), sf.view(
            x_view.size(0), x_view.size(2))
    else:
        g, m, n = x.shape
        x_padded = torch.zeros((g, align(m, 128), align(n, 128)),
                               dtype=x.dtype,
                               device=x.device)
        x_padded[:, :m, :n] = x
        x_view = x_padded.view(g, -1, 128, x_padded.size(-1) // 128, 128)
        x_amax = x_view.abs().float().amax(dim=(2, 4), keepdim=True).clamp(1e-4)
        sf = ceil_to_ue8m0(x_amax / 448.0)
        x_scaled = (x_view * (1.0 / sf)).to(torch.float8_e4m3fn)
        return x_scaled.view_as(x_padded)[:, :m, :n].contiguous(), sf.view(
            x_view.size(0), x_view.size(1), x_view.size(3))
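
# Usage sketch (illustrative): weights are quantized per 128x128 block, padding the
# dimensions that are not multiples of 128:
#
#   w = torch.randn(511, 2048, device="cuda")
#   w_fp8, sf = per_block_cast_to_fp8_e8m0(w)
#   # w_fp8: [511, 2048] float8_e4m3fn; sf: [4, 16] power-of-two scales,
#   # one per 128x128 block (ceil(511 / 128) x (2048 / 128)).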


def resmooth_to_fp8_e8m0(weight: torch.Tensor,
                         sf: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    weight = weight.cuda()
    sf = sf.cuda()
    if weight.dim() == 2:
        x = weight.float() * sf.repeat_interleave(128, dim=0).repeat_interleave(
            128, dim=1)[:weight.shape[0], :weight.shape[1]]
    else:
        x = weight.float() * sf.repeat_interleave(128, dim=1).repeat_interleave(
            128, dim=2)[:weight.shape[0], :weight.shape[1], :weight.shape[2]]
    return per_block_cast_to_fp8_e8m0(x)
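
# Illustrative note (a reading of the function above): the weight is first dequantized
# with its existing 128x128 block scales `sf`, then requantized so that the new block
# scales are powers of two (UE8M0):
#
#   w_requant, sf_ue8m0 = resmooth_to_fp8_e8m0(w_fp8, w_block_scales)
#   # `w_fp8` / `w_block_scales` are placeholder names for an existing FP8 weight
#   # and its FP32 block scales.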


def get_m_alignment_for_contiguous_layout():
    return 128


def get_tma_aligned_size(x: int, element_size: int) -> int:
    tma_alignment_bytes = 16
    assert tma_alignment_bytes % element_size == 0
    alignment = tma_alignment_bytes // element_size
    return align(x, alignment)
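
# Example (illustrative): with 16-byte TMA alignment, an FP32 scale factor tensor
# (element_size == 4) needs its MN extent rounded up to a multiple of 4:
#
#   get_tma_aligned_size(130, 4)  # -> 132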


def get_col_major_tma_aligned_packed_tensor(x: torch.Tensor) -> torch.Tensor:
    # NOTES: for extreme performance, you may rewrite/fuse this function in CUDA
    assert x.dtype == torch.float and x.dim() in (2, 3)

    # First, convert into UE8M0 `uint8_t`
    ue8m0_tensor = (x.view(torch.int) >> 23).to(torch.uint8)

    # Second, make padded packed tensors
    mn, k = x.shape[-2], x.shape[-1]
    remove_dim = False
    if x.dim() == 2:
        x, remove_dim = x.unsqueeze(0), True
    b = x.shape[0]
    aligned_mn = get_tma_aligned_size(mn, 4)
    aligned_k = align(k, 4)
    padded = torch.zeros((b, aligned_mn, aligned_k),
                         device=x.device,
                         dtype=torch.uint8)
    padded[:, :mn, :k] = ue8m0_tensor
    padded = padded.view(-1).view(dtype=torch.int).view(b, aligned_mn,
                                                        aligned_k // 4)

    # Finally, transpose
    transposed = torch.transpose(
        torch.empty((b, aligned_k // 4, aligned_mn),
                    device=x.device,
                    dtype=torch.int), 1, 2)
    transposed[:, :, :] = padded
    aligned_x = transposed[:, :mn, :]
    return aligned_x.squeeze(0) if remove_dim else aligned_x
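
# Layout sketch (illustrative reading of the function above): each FP32 scale is
# reduced to its 8-bit exponent (UE8M0), four consecutive K-groups are packed into one
# int32, and the result is materialized MN-major (stride 1 along MN), e.g.
#
#   sf = ceil_to_ue8m0(torch.rand(64, 8, device="cuda") + 0.5)   # [mn=64, k_groups=8]
#   packed = get_col_major_tma_aligned_packed_tensor(sf)
#   # packed: [64, 2] torch.int32 with packed.stride() == (1, 64)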


def check_sf_layout(sf: torch.Tensor,
                    mn: int,
                    k: int,
                    gran: Tuple[int, int],
                    num_groups: Optional[int],
                    tma_stride_check: bool = False,
                    type_check: Optional[torch.dtype] = None) -> torch.Tensor:
    # Type check
    if type_check is not None:
        assert sf.dtype == type_check

    # Always do shape checks
    assert sf.dtype in (torch.float, torch.int)
    assert sf.dim() == int(num_groups is not None) + 2
    if num_groups is not None:
        assert sf.size(-3) == num_groups
    assert sf.size(-2) == ceil_div(mn, gran[0])
    assert sf.size(-1) == ceil_div(
        k, gran[1] * (1 if sf.dtype == torch.float else 4))

    # TMA stride checks: TMA aligned and MN-major
    if tma_stride_check:
        if num_groups is not None:
            assert sf.stride(-3) == sf.stride(-1) * sf.size(-1)
        assert sf.stride(-2) == 1
        assert sf.stride(-1) == get_tma_aligned_size(mn, sf.element_size())

    return sf
@nvtx_range("[DG] transform_sf_into_required_layout")
|
|
def transform_sf_into_required_layout(sf: torch.Tensor,
|
|
mn: int,
|
|
k: int,
|
|
recipe: Tuple[int, int, int],
|
|
num_groups: Optional[int] = None,
|
|
is_sfa: bool = False):
|
|
gran = (recipe[0 if is_sfa else 1], recipe[2])
|
|
|
|
should_skip_transform = ((sf.dtype == torch.int and gran == (1, 128))
|
|
or (sf.dtype == torch.int and gran == (128, 128)))
|
|
|
|
if not should_skip_transform:
|
|
# Pre-transform checks
|
|
check_sf_layout(sf, mn=mn, k=k, gran=gran, num_groups=num_groups)
|
|
|
|
# (FP32, 1, 128) on SM100: transform to (INT, 1, 128), TMA-aligned and MN-major
|
|
if sf.dtype == torch.float and gran == (1, 128):
|
|
sf = get_col_major_tma_aligned_packed_tensor(sf)
|
|
return check_sf_layout(sf,
|
|
mn=mn,
|
|
k=k,
|
|
gran=(1, 128),
|
|
num_groups=num_groups,
|
|
tma_stride_check=True,
|
|
type_check=torch.int)
|
|
|
|
# (FP32, 128, 128) on SM100: transform to (INT, 1, 128), TMA-aligned and MN-major
|
|
if sf.dtype == torch.float and gran == (128, 128):
|
|
sf = sf.index_select(-2, torch.arange(mn, device=sf.device) // 128)
|
|
sf = get_col_major_tma_aligned_packed_tensor(sf)
|
|
return check_sf_layout(sf,
|
|
mn=mn,
|
|
k=k,
|
|
gran=(1, 128),
|
|
num_groups=num_groups,
|
|
tma_stride_check=True,
|
|
type_check=torch.int)
|
|
|
|
if should_skip_transform:
|
|
# TODO: add transpose kernel if SF layout is not satisfied
|
|
return check_sf_layout(sf,
|
|
mn=mn,
|
|
k=k,
|
|
gran=(1, 128),
|
|
num_groups=num_groups,
|
|
tma_stride_check=True,
|
|
type_check=torch.int)
|
|
|
|
assert False, f'Unknown cases: {sf.dtype=}, {gran=}'
|
|
|
|
|
|
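
# Usage sketch (illustrative): with a (1, 128, 128) recipe, per-token FP32 scales for
# an activation of shape [m, k] are repacked into the TMA-aligned, MN-major int32
# layout checked above:
#
#   x = torch.randn(64, 1024, device="cuda")
#   _, sf = per_token_cast_to_fp8_e8m0(x)        # [64, 8] FP32, granularity (1, 128)
#   sf_t = transform_sf_into_required_layout(sf, mn=64, k=1024,
#                                            recipe=(1, 128, 128), is_sfa=True)
#   # sf_t: [64, 2] torch.int32, MN-major (stride 1 along the token dimension)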


# copy from https://github.com/ModelTC/lightllm/blob/a000ab69098654df4731f5b12587dd4e7f0a4f41/lightllm/common/fused_moe/moe_silu_and_mul_mix_quant_ep.py
@triton.jit
def _silu_and_mul_post_quant_kernel(
    input_ptr,
    stride_input_0,
    stride_input_1,
    stride_input_2,
    output_ptr,
    stride_output_0,
    stride_output_1,
    stride_output_2,
    output_scale_ptr,
    stride_output_scale_0,
    stride_output_scale_1,
    stride_output_scale_2,
    masked_m_ptr,
    size_k,
    fp8_max,
    fp8_min,
    BLOCK: tl.constexpr,
    NUM_STAGE: tl.constexpr,
    SCALE_UE8M0: tl.constexpr,
):
    expert_id = tl.program_id(2)
    token_id = tl.program_id(1)
    hidden_dim_block_index = tl.program_id(0)

    block_num_per_expert = tl.num_programs(1)

    token_num_cur_expert = tl.load(masked_m_ptr + expert_id)

    stride_input_0 = tl.cast(stride_input_0, dtype=tl.int64)
    stride_output_0 = tl.cast(stride_output_0, dtype=tl.int64)
    stride_input_1 = tl.cast(stride_input_1, dtype=tl.int64)
    stride_output_1 = tl.cast(stride_output_1, dtype=tl.int64)

    offs_in_d = hidden_dim_block_index * BLOCK + tl.arange(0, BLOCK // 4)
    input_ptr_offs = input_ptr + expert_id * stride_input_0 + offs_in_d
    output_ptr_offs = output_ptr + expert_id * stride_output_0 + offs_in_d
    output_scale_offs = (output_scale_ptr + expert_id * stride_output_scale_0 +
                         hidden_dim_block_index * stride_output_scale_1)

    for token_index in tl.range(token_id,
                                token_num_cur_expert,
                                block_num_per_expert,
                                num_stages=NUM_STAGE):
        output_s_int32 = 0
        for pack_index in tl.range(4):
            local_mask = offs_in_d + pack_index * 128
            up = tl.load(
                input_ptr_offs + token_index * stride_input_1 +
                pack_index * 128,
                mask=local_mask < size_k,
                other=0.0,
            )
            gate = tl.load(
                input_ptr_offs + token_index * stride_input_1 + size_k +
                pack_index * 128,
                mask=local_mask < size_k,
                other=0.0,
            ).to(tl.float32)
            gate = gate / (1 + tl.exp(-gate))
            gate = gate.to(input_ptr.dtype.element_ty)
            gate_up = up * gate
            _absmax = tl.maximum(tl.max(tl.abs(gate_up)), 1e-10)
            output_s = _absmax / fp8_max
            if SCALE_UE8M0:
                output_s = tl.exp2(tl.ceil(tl.log2(tl.abs(output_s))))
            output_q = tl.clamp(gate_up / output_s, fp8_min,
                                fp8_max).to(output_ptr.dtype.element_ty)
            output_s_int32 += ((output_s.to(tl.int32, bitcast=True) >> 23) <<
                               (8 * pack_index))
            tl.store(
                output_ptr_offs + token_index * stride_output_1 +
                pack_index * 128,
                output_q,
                mask=local_mask < size_k,
            )
        tl.store(
            output_scale_offs + token_index * stride_output_scale_2,
            output_s_int32,
        )
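
# Note on the scale layout above (a reading of the kernel): each int32 written to
# `output_scale` packs the exponent bytes (UE8M0) of four consecutive 128-wide groups,
# with group `pack_index` stored in byte `pack_index`; on little-endian hosts this
# matches the packing produced by `get_col_major_tma_aligned_packed_tensor`.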


def silu_and_mul_masked_post_quant_fwd(
    output: torch.Tensor,
    output_scale: torch.Tensor,
    input: torch.Tensor,
    quant_group_size: int,
    masked_m: torch.Tensor,
    scale_ue8m0: bool = False,
):
    """
    input shape [g, m, k]
    output shape [g, m, k // 2], dtype fp8
    output_scale shape [g, k // 2 // quant_group_size // 4, m] (m padded to a
        multiple of 4), dtype int32, each element packing 4 ue8m0 group scales
    quant_group_size int
    masked_m shape [g]
    """

    assert input.is_contiguous()
    assert len(input.shape) == 3
    assert input.shape[0] == masked_m.shape[0]
    assert input.shape[-1] % 2 == 0

    # FP8 quantization parameters
    finfo = torch.finfo(torch.float8_e4m3fn)
    fp8_max = finfo.max
    fp8_min = finfo.min

    g, m, k = input.shape
    k = k // 2

    # Get block/grid/stage/warp
    expert_num = len(masked_m)

    if expert_num < 4:
        BLOCK_NUM_PER_EXPERT = 64
    else:
        BLOCK_NUM_PER_EXPERT = 128

    BLOCK = quant_group_size * 4
    num_warps = 1
    NUM_STAGES = 6
    hidden_dim_split_block_num = triton.cdiv(k, BLOCK)
    grid = (
        hidden_dim_split_block_num,
        BLOCK_NUM_PER_EXPERT,
        expert_num,
    )
    _silu_and_mul_post_quant_kernel[grid](
        input,
        *input.stride(),
        output,
        *output.stride(),
        output_scale,
        *output_scale.stride(),
        masked_m,
        k,
        fp8_max,
        fp8_min,
        BLOCK=BLOCK,
        NUM_STAGE=NUM_STAGES,
        num_warps=num_warps,
        SCALE_UE8M0=scale_ue8m0,
    )
    output_scale = output_scale.transpose(1, 2)[:, :m, :]
    check_sf_layout(
        output_scale,
        m,
        k,
        (1, 128),
        g,
        tma_stride_check=True,
    )
    return output_scale
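
# Usage sketch (illustrative; the shapes are hypothetical, with g experts, m padded
# tokens per expert and a concatenated [up | gate] projection of width k2):
#
#   g, m, k2 = 8, 128, 4096
#   x = torch.randn(g, m, k2, dtype=torch.bfloat16, device="cuda")
#   masked_m = torch.full((g,), 100, dtype=torch.int32, device="cuda")
#   out = torch.empty(g, m, k2 // 2, dtype=torch.float8_e4m3fn, device="cuda")
#   out_sf = torch.zeros(g, k2 // 2 // 128 // 4, align(m, 4),
#                        dtype=torch.int32, device="cuda")
#   sf = silu_and_mul_masked_post_quant_fwd(out, out_sf, x, 128, masked_m,
#                                           scale_ue8m0=True)
#   # `out` now holds silu(gate) * up quantized to FP8; `sf` is the MN-major scale view.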


@triton.jit
def _per_token_quant_and_transform_kernel(
    input_ptr,
    stride_input_0,
    stride_input_1,
    output_ptr,
    stride_output_0,
    stride_output_1,
    output_scale_ptr,
    stride_output_scale_0,
    stride_output_scale_1,
    token_num_cur_expert,
    size_k,
    fp8_max,
    fp8_min,
    BLOCK: tl.constexpr,
    NUM_STAGE: tl.constexpr,
    SCALE_UE8M0: tl.constexpr,
):
    tl.program_id(2)
    token_id = tl.program_id(1)
    hidden_dim_block_index = tl.program_id(0)

    block_num_per_expert = tl.num_programs(1)

    stride_input_0 = tl.cast(stride_input_0, dtype=tl.int64)
    stride_output_0 = tl.cast(stride_output_0, dtype=tl.int64)
    stride_input_1 = tl.cast(stride_input_1, dtype=tl.int64)
    stride_output_1 = tl.cast(stride_output_1, dtype=tl.int64)

    offs_in_d = hidden_dim_block_index * BLOCK + tl.arange(0, BLOCK // 4)
    input_ptr_offs = input_ptr + offs_in_d
    output_ptr_offs = output_ptr + offs_in_d
    output_scale_offs = (output_scale_ptr +
                         hidden_dim_block_index * stride_output_scale_0)

    for token_index in tl.range(token_id,
                                token_num_cur_expert,
                                block_num_per_expert,
                                num_stages=NUM_STAGE):
        output_s_int32 = 0
        for pack_index in tl.range(4):
            local_mask = offs_in_d + pack_index * 128
            act = tl.load(
                input_ptr_offs + token_index * stride_input_0 +
                pack_index * 128,
                mask=local_mask < size_k,
                other=0.0,
            ).to(tl.float32)
            _absmax = tl.maximum(tl.max(tl.abs(act)), 1e-10)
            output_s = _absmax / fp8_max
            if SCALE_UE8M0:
                output_s = tl.exp2(tl.ceil(tl.log2(tl.abs(output_s))))
            output_q = tl.clamp(act / output_s, fp8_min,
                                fp8_max).to(output_ptr.dtype.element_ty)
            output_s_int32 += ((output_s.to(tl.int32, bitcast=True) >> 23) <<
                               (8 * pack_index))
            tl.store(
                output_ptr_offs + token_index * stride_output_0 +
                pack_index * 128,
                output_q,
                mask=local_mask < size_k,
            )
        tl.store(
            output_scale_offs + token_index * stride_output_scale_1,
            output_s_int32,
        )


def per_token_quant_and_transform(
    input: torch.Tensor,
    quant_group_size: int = 128,
    scale_ue8m0: bool = True,
):
    """
    input shape [m, k]
    output shape [m, k], dtype fp8
    output_scale shape [m, k // quant_group_size // 4], dtype int32, each element
        packing 4 ue8m0 group scales, stored mn-major (stride 1 along m)
    quant_group_size int
    """

    assert input.is_contiguous()
    assert len(input.shape) == 2
    assert input.shape[-1] % 2 == 0

    # FP8 quantization parameters
    finfo = torch.finfo(torch.float8_e4m3fn)
    fp8_max = finfo.max
    fp8_min = -fp8_max

    m, k = input.shape

    # Create output
    output = torch.empty((m, k), dtype=torch.float8_e4m3fn, device="cuda")

    # Create output scale
    alignment = 4
    scale_k = ceil_div(k, quant_group_size)
    m_padded = align(m, alignment)
    scale_k_padded = align(scale_k, alignment)
    output_scale = torch.zeros((scale_k_padded // 4, m_padded),
                               dtype=torch.int32,
                               device='cuda')

    # Get block/grid/stage/warp
    BLOCK_NUM_PER_EXPERT = 64

    BLOCK = quant_group_size * 4
    num_warps = 1
    NUM_STAGES = 6
    hidden_dim_split_block_num = triton.cdiv(k, BLOCK)
    grid = (
        hidden_dim_split_block_num,
        BLOCK_NUM_PER_EXPERT,
        1,
    )
    _per_token_quant_and_transform_kernel[grid](
        input,
        *input.stride(),
        output,
        *output.stride(),
        output_scale,
        *output_scale.stride(),
        m,
        k,
        fp8_max,
        fp8_min,
        BLOCK=BLOCK,
        NUM_STAGE=NUM_STAGES,
        num_warps=num_warps,
        SCALE_UE8M0=scale_ue8m0,
    )
    output_scale = output_scale.transpose(0, 1)[:m, :]
    check_sf_layout(
        output_scale,
        m,
        k,
        (1, 128),
        num_groups=None,
        tma_stride_check=True,
    )
    return output, output_scale
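
# Usage sketch (illustrative): fused per-token (1x128) FP8 quantization that directly
# emits the packed, MN-major UE8M0 scale layout validated by `check_sf_layout`:
#
#   x = torch.randn(64, 1024, device="cuda")
#   x_fp8, sf = per_token_quant_and_transform(x)
#   # x_fp8: [64, 1024] float8_e4m3fn; sf: [64, 2] torch.int32 with sf.stride() == (1, 64)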