TensorRT-LLMs/tensorrt_llm/quantization/utils/fp4_utils.py
hlu1 8207d5fd39
[None] [feat] Add model gpt-oss (#6645)
Signed-off-by: Hao Lu <14827759+hlu1@users.noreply.github.com>
2025-08-07 03:04:18 -04:00

207 lines
5.9 KiB
Python

from enum import IntEnum
import torch
# The declarations must be aligned with thUtils.h
# Element dtype used for scaling-factor tensors in this module.
SF_DTYPE = torch.uint8
# Element dtype used for FP4 data. NOTE(review): the "X2" suffix suggests two
# E2M1 values packed per uint8 (get_fp4_shape halves the last dimension) --
# confirm against thUtils.h.
FLOAT4_E2M1X2 = torch.uint8
# For GEMM autotuning.
# Taken from https://github.com/NVIDIA/TensorRT-LLM/blob/main/cpp/include/tensorrt_llm/runtime//modelConfig.h#L38
# TODO: move to model config, tune for blackwell hardware
FP4_BUCKETS = [64, 128, 256, 512, 1024]
# Export
# Lower-case aliases of the constants above; these aliases (not the upper-case
# names) are what __all__ lists.
float4_e2m1x2 = FLOAT4_E2M1X2
float4_sf_dtype = SF_DTYPE
fp4_buckets = FP4_BUCKETS
# Only these names are picked up by a star-import of this module; anything
# else defined here (e.g. get_fp4_shape) has to be imported by name.
__all__ = ['float4_e2m1x2', 'float4_sf_dtype', 'pad_up', 'fp4_buckets']
def pad_up(x: int, y: int) -> int:
    """Round ``x`` up to the nearest multiple of ``y``.

    Args:
        x: Value to round up.
        y: Granularity to round to. ``y == 0`` raises ``ZeroDivisionError``.

    Returns:
        The smallest multiple of ``y`` that is greater than or equal to ``x``
        (for positive ``y``).
    """
    # Ceil-division expressed with integer arithmetic only.
    num_chunks = (x + y - 1) // y
    return num_chunks * y
class FP4GemmType(IntEnum):
    """Integer identifiers for the FP4 GEMM variants.

    NOTE(review): the member names read as "<weight bits><activation bits>_
    <format>_<format>"; the exact pairing of formats to operands is not
    visible in this file -- confirm against the consumers of this enum.
    """

    # 4-bit weights / 4-bit activations, NVFP4 on both sides (per the name).
    W4A4_NVFP4_NVFP4 = 0

    # 4-bit weights / 8-bit activations, MXFP4 + MXFP8 (per the name).
    W4A8_MXFP4_MXFP8 = 1
def get_fp4_shape(input_shape, sf_vec_size, is_swizzled_layout=True):
    """Derive the FP4 data shape and scale-factor element count for a shape.

    All dimensions except the last are flattened into a single row count.

    Args:
        input_shape: Sequence of dimension sizes of the source tensor.
        sf_vec_size: Number of last-dimension elements covered by one scaling
            factor.
        is_swizzled_layout: When true, the scale-factor count is computed with
            the row count padded to a multiple of 128 and the per-row
            scale-factor count padded to a multiple of 4; otherwise no padding
            is applied.

    Returns:
        A tuple ``(output_shape, scale_shape)`` where ``output_shape`` is a
        list equal to ``input_shape`` with its last entry floor-divided by 2,
        and ``scale_shape`` is the total number of scale-factor elements
        (a single integer, not a list).
    """
    # Collapse every leading dimension into the row count.
    num_rows = 1
    for dim in input_shape[:-1]:
        num_rows *= dim

    output_shape = list(input_shape)
    output_shape[-1] //= 2

    # Scale factors per row before any padding.
    sf_per_row = input_shape[-1] // sf_vec_size
    if is_swizzled_layout:
        scale_shape = pad_up(num_rows, 128) * pad_up(sf_per_row, 4)
    else:
        scale_shape = num_rows * sf_per_row
    return output_shape, scale_shape
def get_reorder_rows_for_gated_act_gemm_row_indices(x) -> torch.Tensor:
    """
    Build the row permutation that interleaves the two halves of ``x``.

    Reorders rows in the gemm/MOE_gemm weight matrix for min-latency from
    [r0, r1, ..., r(N/2-1), r(N/2), r(N/2+1), ..., r(N-1)]
    to
    [r0, r(N/2), r1, r(N/2+1), ..., r(N/2-1), r(N-1)]

    Args:
        x: Object whose ``shape[0]`` is the number of rows N. N must be even.

    Returns:
        A 1-D ``torch.long`` tensor ``idx`` of length N such that ``x[idx]``
        is the reordered matrix. It is created without an explicit device
        argument.

    Raises:
        AssertionError: If ``x.shape[0]`` is odd.
    """
    M = x.shape[0]
    assert M % 2 == 0, f"x.shape[0] must be even, not {M}"

    # M is even here, so both halves have exactly M // 2 rows.
    half = M // 2
    row_indices = torch.arange(M, dtype=torch.long)
    top = row_indices[:half]
    bot = row_indices[half:]

    # Pair top[i] with bot[i] and flatten: [top0, bot0, top1, bot1, ...].
    return torch.stack((top, bot), dim=1).flatten()
def reorder_rows_for_gated_act_gemm(x):
    """
    PyTorch implementation of trt-llm gen `reorderRowsForGatedActGemm`

    Args:
        x: Tensor whose rows (dimension 0) are to be reordered.

    Returns:
        ``x`` indexed along dimension 0 with the permutation produced by
        ``get_reorder_rows_for_gated_act_gemm_row_indices``.
    """
    interleaved_order = get_reorder_rows_for_gated_act_gemm_row_indices(x)
    return x[interleaved_order]
# Row permutation tables for shuffling within blocks of 16 / 32 rows.
# Entry ``i`` is the destination row (inside the block) of source row ``i``.
# Read in groups, the tables are:
#   16-row block:  0, 8 | 1, 9 | 2, 10 | ... | 7, 15
#   32-row block:  0, 8, 16, 24 | 1, 9, 17, 25 | ... | 7, 15, 23, 31
srcToDstBlk16RowMap = [
    base + 8 * step for base in range(8) for step in range(2)
]
srcToDstBlk32RowMap = [
    base + 8 * step for base in range(8) for step in range(4)
]
def get_shuffle_block_size(epilogue_tile_m: int) -> int:
    """Pick the row-shuffle block size for a given epilogue tile height.

    Args:
        epilogue_tile_m: Epilogue tile size along M.

    Returns:
        32 when ``epilogue_tile_m`` is a multiple of 128, otherwise 16.
    """
    return 32 if epilogue_tile_m % 128 == 0 else 16
def get_shuffle_matrix_a_row_indices(input_tensor: torch.Tensor,
                                     epilogue_tile_m: int) -> torch.Tensor:
    """
    Higher-level PyTorch approach to reorder the rows in blocks of size 16 or 32.
    - We do NOT try to handle custom e2m1 memory usage (i.e. no 'K/2' bytes).
    - Instead, we purely reorder rows in a standard PyTorch shape [M, K].

    Args:
        input_tensor: Tensor whose first dimension M is the row count; only
            ``shape[0]`` is read.
        epilogue_tile_m: Epilogue tile size along M; selects the block size
            (16 or 32) via ``get_shuffle_block_size``.

    Returns:
        A 1-D ``torch.long`` tensor ``row_indices`` of length M where
        ``row_indices[new_row] = old_row``, i.e. it tells from which old row
        each new row should be taken. It is created without an explicit
        device argument.

    Raises:
        AssertionError: If M is not a multiple of the chosen block size.
    """
    # M from the input
    M = input_tensor.shape[0]

    # Choose block size 16 or 32
    shuffle_block_size = get_shuffle_block_size(epilogue_tile_m)
    row_map = (srcToDstBlk16RowMap
               if shuffle_block_size == 16 else srcToDstBlk32RowMap)

    assert M % shuffle_block_size == 0, f"input_tensor.shape[0] must be multiples of {shuffle_block_size}"

    # Vectorized form of: for every old_row, look up its destination inside
    # its block and record row_indices[new_row] = old_row.
    old_rows = torch.arange(M, dtype=torch.long)
    row_in_block = old_rows % shuffle_block_size
    block_start = old_rows - row_in_block
    mapped_row_in_block = torch.tensor(row_map, dtype=torch.long)[row_in_block]
    new_rows = block_start + mapped_row_in_block

    # row_map is a permutation of the block, so new_rows covers every index
    # in [0, M) exactly once and the scatter fills the whole output.
    row_indices = torch.empty(M, dtype=torch.long)
    row_indices[new_rows] = old_rows
    return row_indices
def shuffle_matrix_a(input_tensor: torch.Tensor,
                     epilogue_tile_m: int) -> torch.Tensor:
    """
    PyTorch equivalent of trtllm-gen `shuffleMatrixA`

    Args:
        input_tensor: Matrix whose rows are shuffled.
        epilogue_tile_m: Epilogue tile size along M, forwarded to
            ``get_shuffle_matrix_a_row_indices``.

    Returns:
        The result of ``torch.ops.trtllm.shuffle_matrix`` applied to
        ``input_tensor`` with the computed row indices.
    """
    row_order = get_shuffle_matrix_a_row_indices(input_tensor, epilogue_tile_m)
    # The indices are moved to the input's device before the op is called.
    row_order_on_device = row_order.to(input_tensor.device)
    return torch.ops.trtllm.shuffle_matrix(input_tensor, row_order_on_device)
def get_shuffle_matrix_sf_a_row_indices(
        input_tensor: torch.Tensor,
        epilogue_tile_m: int,
        num_elts_per_sf: int = 16) -> torch.Tensor:
    """Validate a scale-factor matrix and return its row-shuffle indices.

    Args:
        input_tensor: 2-D tensor of dtype ``float4_sf_dtype`` with a row count
            divisible by 128 and a column count divisible by 4.
        epilogue_tile_m: Epilogue tile size along M, forwarded to
            ``get_shuffle_matrix_a_row_indices``.
        num_elts_per_sf: Must be 16 or 32. It is only validated here; it does
            not influence the returned indices.

    Returns:
        The row indices produced by ``get_shuffle_matrix_a_row_indices`` for
        ``input_tensor``.

    Raises:
        AssertionError: If any of the preconditions above is violated.
    """
    assert input_tensor.dtype == float4_sf_dtype
    assert num_elts_per_sf in (16, 32)

    rank = input_tensor.dim()
    assert rank == 2, f"input_tensor should be a 2D tensor, not {rank}"

    # M, K from the input
    num_rows, num_cols = input_tensor.shape
    assert num_rows % 128 == 0
    assert num_cols % 4 == 0

    return get_shuffle_matrix_a_row_indices(input_tensor, epilogue_tile_m)
def shuffle_matrix_sf_a(
    input_tensor: torch.Tensor,
    epilogue_tile_m: int,
    num_elts_per_sf: int = 16,
):
    """
    Implementation of trtllm-gen `shuffleMatrixSfA`, with a caveat.

    `shuffleMatrixSfA` expects the input to be in 128x4 layout, applies the
    same shuffling as `shuffleMatrixA`, and writes the result out in 128x4
    layout.

    This function expects the input to be in linear layout instead. It is
    done this way because the scaling factors in the NVFP4 checkpoints are
    quantized and are in linear layout.

    This function doesn't add padding.

    Args:
        input_tensor: 2-D scale-factor tensor in linear layout; validated by
            ``get_shuffle_matrix_sf_a_row_indices``.
        epilogue_tile_m: Epilogue tile size along M, used to pick the shuffle
            block size.
        num_elts_per_sf: Number of elements per scaling factor; must be 16 or
            32.

    Returns:
        The output of ``torch.ops.trtllm.block_scale_interleave`` applied to
        the row-shuffled scale factors.

    Raises:
        AssertionError: If the input or ``num_elts_per_sf`` fails validation
            in ``get_shuffle_matrix_sf_a_row_indices``.
    """
    # Forward num_elts_per_sf so the caller's value is validated rather than
    # being silently replaced by the helper's default.
    row_indices = get_shuffle_matrix_sf_a_row_indices(input_tensor,
                                                      epilogue_tile_m,
                                                      num_elts_per_sf)

    w_shuffled = torch.ops.trtllm.shuffle_matrix(
        input_tensor, row_indices.to(input_tensor.device))

    # 128x4
    return torch.ops.trtllm.block_scale_interleave(w_shuffled)