TensorRT-LLMs/tensorrt_llm/quantization/utils/fp4_utils.py
hlu1 31624b079a
feat: [Deepseek] Add trtllm-gen MOE FP4 MOE backend (#3387)
* Add TRT-LLM Gen MOE to Deepseek

fix fused moe rebase bug.

Fix atol in test_fp4_gemm_quantize.py

Fix FusedMoe.

Disable 2nd routing kernel preexit

Bump routing reduction to fp32

Disable PDL for fc1

[DEBUG] Lift token limit to 16k

[Bugfix] Token limit to 16k + fp32 routing + tanh

Make fp8 tileN 8

Fix FP8 MoE + Remove redundant temp output for FP4

[FP8-only] Avoid wasting CTAs for activation kernel

fix: unblock FP8 weightloading with trtllm-gen

Remove max_token limit for trtllm-gen path

perf: avoid type-conversion and fill_ from aten

Minor fix

Signed-off-by: Hao Lu <haolu@nvidia.com>

* Fix rebase issues

Signed-off-by: Hao Lu <haolu@nvidia.com>

* Fix compile issue

Signed-off-by: Zongfei Jing <20381269+zongfeijing@users.noreply.github.com>

* CI clean

Signed-off-by: Zongfei Jing <20381269+zongfeijing@users.noreply.github.com>

---------

Signed-off-by: Hao Lu <haolu@nvidia.com>
Signed-off-by: Zongfei Jing <20381269+zongfeijing@users.noreply.github.com>
Co-authored-by: Zongfei Jing <20381269+zongfeijing@users.noreply.github.com>
2025-04-21 10:01:33 +08:00

import torch

# The declarations must be aligned with thUtils.h
SF_DTYPE = torch.uint8
FLOAT4_E2M1X2 = torch.uint8


def pad_up(x: int, y: int) -> int:
    return ((x + y - 1) // y) * y
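
# Worked values for `pad_up` (illustrative, not from the original source):
# the helper rounds `x` up to the nearest multiple of `y`.
#   pad_up(300, 128) -> 384   # 3 * 128 is the next multiple of 128
#   pad_up(256, 128) -> 256   # already a multiple, so unchanged
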
# For GEMM autotuning.
# Taken from https://github.com/NVIDIA/TensorRT-LLM/blob/main/cpp/include/tensorrt_llm/runtime//modelConfig.h#L38
# TODO: move to model config, tune for blackwell hardware
FP4_BUCKETS = [64, 128, 256, 512, 1024]

# Export
float4_e2m1x2 = FLOAT4_E2M1X2
float4_sf_dtype = SF_DTYPE
fp4_buckets = FP4_BUCKETS

__all__ = ['float4_e2m1x2', 'float4_sf_dtype', 'pad_up', 'fp4_buckets']

def get_fp4_shape(input_shape, sf_vec_size):
    # Flatten all leading dims into a single M dimension.
    m = 1
    for i in range(len(input_shape) - 1):
        m *= input_shape[i]

    # Two FP4 (e2m1) values are packed per uint8 byte, so the last dim halves.
    output_shape = [i for i in input_shape]
    output_shape[-1] //= 2

    # One scaling factor per `sf_vec_size` elements, padded to the
    # 128x4 swizzled layout.
    scale_shape = pad_up(m, 128) * pad_up(input_shape[-1] // sf_vec_size, 4)
    return output_shape, scale_shape
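
# Worked example for `get_fp4_shape` (illustrative values, not from the
# original source): quantizing a [8, 64] tensor with sf_vec_size=16 packs
# two FP4 values per byte and pads the scale buffer to the 128x4 layout:
#   output_shape = [8, 32]                                 # 64 // 2
#   scale_shape  = pad_up(8, 128) * pad_up(64 // 16, 4)
#                = 128 * 4 = 512
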

def reorder_rows_for_gated_act_gemm(x):
    """
    PyTorch implementation of trtllm-gen `reorderRowsForGatedActGemm`.
    Reorders rows in the gemm/MOE_gemm weight matrix for min-latency
    [r0, r1, r2, r3, ..., rN/2, r(N/2+1), .. r(N-1)]
    to
    [r0, rN/2, r1, rN/2+1, ..., r(N/2-1), r(N-1)]
    """
    assert x.dim() == 2, f"x should be a 2D tensor, not {x.dim()}"
    M, K = x.shape
    assert M % 2 == 0, f"x.shape[0] must be even, not {M}"

    # Split into top and bottom halves; M is even, so both have M // 2 rows.
    top = x[:(M + 1) // 2]
    bot = x[(M + 1) // 2:]

    # Place rows of `top` and `bot` in alternation.
    out = torch.empty_like(x)
    out[0::2] = top
    out[1::2] = bot

    return out
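
# Worked example (illustrative, not from the original source): for a 4-row
# matrix the halves [r0, r1] and [r2, r3] are interleaved, so
#   [r0, r1, r2, r3] -> [r0, r2, r1, r3]
# i.e. row k of the top half lands next to row k of the bottom half,
# presumably pairing each gate row with its matching activation row.
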

# yapf: disable
srcToDstBlk16RowMap = [
    0, 8,
    1, 9,
    2, 10,
    3, 11,
    4, 12,
    5, 13,
    6, 14,
    7, 15
]

srcToDstBlk32RowMap = [
    0, 8, 16, 24,
    1, 9, 17, 25,
    2, 10, 18, 26,
    3, 11, 19, 27,
    4, 12, 20, 28,
    5, 13, 21, 29,
    6, 14, 22, 30,
    7, 15, 23, 31
]
# yapf: enable

def get_shuffle_block_size(epilogue_tile_m: int) -> int:
    shuffle_block_size = 16
    if epilogue_tile_m % 128 == 0:
        shuffle_block_size = 32
    return shuffle_block_size
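
# Worked values (illustrative): epilogue tiles whose M dimension is a
# multiple of 128 use the 32-row map; everything else uses the 16-row map.
#   get_shuffle_block_size(64)  -> 16
#   get_shuffle_block_size(128) -> 32
#   get_shuffle_block_size(256) -> 32
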

def shuffle_matrix_a(input_tensor: torch.Tensor,
                     epilogue_tile_m: int) -> torch.Tensor:
    """
    PyTorch equivalent of trtllm-gen `shuffleMatrixA`.
    Higher-level PyTorch approach to reorder the rows in blocks of size 16 or 32.
    - We do NOT try to handle custom e2m1 memory usage (i.e. no 'K/2' bytes).
    - Instead, we purely reorder rows in a standard PyTorch shape [M, K].
    """
    assert input_tensor.dim() == 2, \
        f"input_tensor should be a 2D tensor, not {input_tensor.dim()}"
    M, K = input_tensor.shape

    # Choose block size 16 or 32.
    shuffle_block_size = get_shuffle_block_size(epilogue_tile_m)
    row_map = (srcToDstBlk16RowMap
               if shuffle_block_size == 16 else srcToDstBlk32RowMap)
    assert M % shuffle_block_size == 0, \
        f"input_tensor.shape[0] must be a multiple of {shuffle_block_size}"

    # row_indices[new_row] = old_row, i.e. an array of size M telling us
    # from which old_row each new_row should be taken.
    row_indices = torch.empty(M, dtype=torch.long, device=input_tensor.device)
    for old_row in range(M):
        block_idx = old_row // shuffle_block_size
        row_in_block = old_row % shuffle_block_size
        mapped_row_in_block = row_map[row_in_block]
        new_row = block_idx * shuffle_block_size + mapped_row_in_block
        row_indices[new_row] = old_row

    # Gather rows in the new order: out[new_row, :] = input_tensor[old_row, :].
    out = input_tensor[row_indices, :]
    return out
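
# Worked example (illustrative, not from the original source): with
# shuffle_block_size == 16, `srcToDstBlk16RowMap` sends old row i of each
# block to new row map[i], so every 16-row block comes out de-interleaved:
#   [r0, r1, ..., r15] -> [r0, r2, r4, ..., r14, r1, r3, ..., r15]
# For a single 16-row block (M == 16) this is exactly the inverse of
# `reorder_rows_for_gated_act_gemm`.
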

def shuffle_matrix_sf_a(
    input_tensor: torch.Tensor,
    epilogue_tile_m: int,
    num_elts_per_sf: int = 16,
):
    """
    CUDA-backed implementation of trtllm-gen `shuffleMatrixSfA`, with a
    caveat: `shuffleMatrixSfA` expects the input to be in 128x4 layout,
    applies the same shuffling as `shuffleMatrixA`, and writes the output
    in 128x4 layout.

    This function expects the input to be in linear layout. It's done this
    way because the scaling factors in the NVFP4 checkpoints are quantized
    and are in linear layout.

    This function doesn't add padding.
    """
    assert input_tensor.dtype == float4_sf_dtype
    assert num_elts_per_sf == 16
    assert input_tensor.dim() == 2, \
        f"input_tensor should be a 2D tensor, not {input_tensor.dim()}"
    M, K = input_tensor.shape
    assert M % 128 == 0
    assert K % 4 == 0

    # Apply the same row shuffle as `shuffleMatrixA` ...
    w_shuffled = shuffle_matrix_a(input_tensor, epilogue_tile_m)
    # ... then interleave the scaling factors into the 128x4 swizzled layout.
    return torch.ops.tensorrt_llm.nvfp4_block_scale_interleave(w_shuffled)
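
# Minimal smoke test for the pure-PyTorch helpers above (a sketch, not part
# of the original file; `shuffle_matrix_sf_a` is skipped because
# `nvfp4_block_scale_interleave` needs the compiled tensorrt_llm extension).
if __name__ == "__main__":
    x = torch.arange(32, dtype=torch.float32).reshape(16, 2)

    # For a single 16-row block, shuffle_matrix_a with the 16-row map undoes
    # the gated-act interleave, so the round trip is the identity.
    y = reorder_rows_for_gated_act_gemm(x)
    z = shuffle_matrix_a(y, epilogue_tile_m=64)  # 64 % 128 != 0 -> 16-row map
    assert torch.equal(z, x)

    # Shape bookkeeping: a [16, 64] input packs to [16, 32] bytes with a
    # 128x4-padded scale buffer of 512 entries.
    out_shape, sf_numel = get_fp4_shape([16, 64], sf_vec_size=16)
    assert out_shape == [16, 32] and sf_numel == 512

    print("fp4_utils smoke test passed")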