import torch

# The declarations must be aligned with thUtils.h
SF_DTYPE = torch.uint8
FLOAT4_E2M1X2 = torch.uint8


def pad_up(x: int, y: int) -> int:
    return ((x + y - 1) // y) * y
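
# Illustrative usage (added; not in the original file): pad_up rounds x up to
# the nearest multiple of y, e.g. for aligning tensor dims to tile sizes.
#   pad_up(100, 128) == 128
#   pad_up(256, 128) == 256
#   pad_up(300, 128) == 384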

# For GEMM autotuning.
# Taken from https://github.com/NVIDIA/TensorRT-LLM/blob/main/cpp/include/tensorrt_llm/runtime//modelConfig.h#L38
# TODO: move to model config, tune for Blackwell hardware
FP4_BUCKETS = [64, 128, 256, 512, 1024]

# Export
float4_e2m1x2 = FLOAT4_E2M1X2
float4_sf_dtype = SF_DTYPE
fp4_buckets = FP4_BUCKETS

__all__ = ['float4_e2m1x2', 'float4_sf_dtype', 'pad_up', 'fp4_buckets']


def get_fp4_shape(input_shape, sf_vec_size):
    m = 1
    for i in range(len(input_shape) - 1):
        m *= input_shape[i]

    output_shape = [i for i in input_shape]
    output_shape[-1] //= 2

    scale_shape = pad_up(m, 128) * pad_up(input_shape[-1] // sf_vec_size, 4)
    return output_shape, scale_shape
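
# Worked example (added; not in the original file): for a [128, 64] input with
# 16 elements per scale factor, the packed fp4 payload halves the last dim and
# the scale buffer is padded up to 128x4 granularity.
def _example_get_fp4_shape():
    output_shape, scale_shape = get_fp4_shape([128, 64], sf_vec_size=16)
    assert output_shape == [128, 32]  # two e2m1 values per uint8 byte
    assert scale_shape == 512  # pad_up(128, 128) * pad_up(64 // 16, 4)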


def reorder_rows_for_gated_act_gemm(x):
    """
    PyTorch implementation of trt-llm gen `reorderRowsForGatedActGemm`

    Reorders rows in the gemm/MOE_gemm weight matrix for min-latency
    [r0, r1, r2, r3, ..., rN/2, r(N/2+1), .. r(N-1)]
    to
    [r0, rN/2, r1, rN/2+1, ..., r(N/2-1), r(N-1)]
    """
    assert x.dim() == 2, f"x should be a 2D tensor, not {x.dim()}"
    M, K = x.shape
    assert M % 2 == 0, f"x.shape[0] must be even, not {M}"

    # Split into the top and bottom halves. M is asserted even above, so
    # (M + 1) // 2 == M // 2 and both halves have the same number of rows.
    top = x[:(M + 1) // 2]
    bot = x[(M + 1) // 2:]

    # Create the output
    out = torch.empty_like(x)

    # Place rows of `top` and `bot` in alternation
    out[0::2] = top
    out[1::2] = bot

    return out
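
# Illustrative sketch (added; not in the original file): with four rows, the
# top half [r0, r1] and bottom half [r2, r3] are interleaved to
# [r0, r2, r1, r3], matching the docstring pattern [r0, rN/2, r1, rN/2+1, ...].
def _example_reorder_rows_for_gated_act_gemm():
    x = torch.arange(8, dtype=torch.float32).reshape(4, 2)  # rows r0..r3
    out = reorder_rows_for_gated_act_gemm(x)
    assert torch.equal(out, x[torch.tensor([0, 2, 1, 3])])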


# yapf: disable
srcToDstBlk16RowMap = [
    0, 8,
    1, 9,
    2, 10,
    3, 11,
    4, 12,
    5, 13,
    6, 14,
    7, 15
]

srcToDstBlk32RowMap = [
    0, 8, 16, 24,
    1, 9, 17, 25,
    2, 10, 18, 26,
    3, 11, 19, 27,
    4, 12, 20, 28,
    5, 13, 21, 29,
    6, 14, 22, 30,
    7, 15, 23, 31
]
# yapf: enable
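
# Added note (not in the original file): these tables map a source row index
# within one shuffle block to its destination row, i.e. row r of a block lands
# at row srcToDstBlk16RowMap[r] (see shuffle_matrix_a below). For the 16-row
# map, block rows [0..15] are therefore written out in the order
# [0, 2, 4, ..., 14, 1, 3, ..., 15].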


def get_shuffle_block_size(epilogue_tile_m: int) -> int:
    shuffle_block_size = 16
    if epilogue_tile_m % 128 == 0:
        shuffle_block_size = 32
    return shuffle_block_size
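
# Illustrative (added): get_shuffle_block_size(64) == 16, while
# get_shuffle_block_size(128) == get_shuffle_block_size(256) == 32.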


def shuffle_matrix_a(input_tensor: torch.Tensor,
                     epilogue_tile_m: int) -> torch.Tensor:
    """
    PyTorch equivalent of trtllm-gen `shuffleMatrixA`

    Higher-level PyTorch approach to reorder the rows in blocks of size 16 or 32.
    - We do NOT try to handle custom e2m1 memory usage (i.e. no 'K/2' bytes).
    - Instead, we purely reorder rows in a standard PyTorch shape [M, K].
    """
    assert input_tensor.dim() == 2, \
        f"input_tensor should be a 2D tensor, not {input_tensor.dim()}"

    # M, K from the input
    M, K = input_tensor.shape

    # Choose block size 16 or 32
    shuffle_block_size = get_shuffle_block_size(epilogue_tile_m)
    row_map = (srcToDstBlk16RowMap
               if shuffle_block_size == 16 else srcToDstBlk32RowMap)

    assert M % shuffle_block_size == 0, \
        f"input_tensor.shape[0] must be a multiple of {shuffle_block_size}"

    # row_indices[new_row] = old_row,
    # i.e. row_indices is an array of size M telling us from which old_row
    # the new_row should be taken.
    row_indices = torch.empty(M, dtype=torch.long, device=input_tensor.device)

    for old_row in range(M):
        block_idx = old_row // shuffle_block_size
        row_in_block = old_row % shuffle_block_size
        mapped_row_in_block = row_map[row_in_block]

        new_row = block_idx * shuffle_block_size + mapped_row_in_block

        row_indices[new_row] = old_row

    # Then gather rows in that new order:
    # out[new_row, :] = input_tensor[old_row, :]
    out = input_tensor[row_indices, :]

    return out
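
# Illustrative sketch (added; not in the original file): with a 16-row block
# (epilogue_tile_m = 64), row r of the input lands at row srcToDstBlk16RowMap[r]
# of the output, so the even source rows come first, then the odd ones.
def _example_shuffle_matrix_a():
    x = torch.arange(16, dtype=torch.float32).unsqueeze(1)  # 16 rows, 1 col
    out = shuffle_matrix_a(x, epilogue_tile_m=64)
    expected = x[torch.tensor([0, 2, 4, 6, 8, 10, 12, 14,
                               1, 3, 5, 7, 9, 11, 13, 15])]
    assert torch.equal(out, expected)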


def shuffle_matrix_sf_a(
    input_tensor: torch.Tensor,
    epilogue_tile_m: int,
    num_elts_per_sf: int = 16,
):
    """
    CUDA implementation of trtllm-gen `shuffleMatrixSfA`, with one caveat:
    `shuffleMatrixSfA` expects its input in 128x4 layout, applies the same
    shuffling as `shuffleMatrixA`, and writes the output in 128x4 layout.
    This function instead expects the input in linear layout, because the
    scaling factors in NVFP4 checkpoints are quantized and stored in linear
    layout.

    This function doesn't add padding.
    """
    assert input_tensor.dtype == float4_sf_dtype
    assert num_elts_per_sf == 16

    assert input_tensor.dim() == 2, \
        f"input_tensor should be a 2D tensor, not {input_tensor.dim()}"

    # M, K from the input
    M, K = input_tensor.shape
    assert M % 128 == 0
    assert K % 4 == 0

    # Shuffle rows exactly as for the weight matrix, then interleave the
    # result into the 128x4 block-scale layout.
    w_shuffled = shuffle_matrix_a(input_tensor, epilogue_tile_m)

    # 128x4
    return torch.ops.tensorrt_llm.nvfp4_block_scale_interleave(w_shuffled)
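
# Usage note (added; assumptions flagged): the final interleave dispatches to
# the compiled TensorRT-LLM extension op, so this path only runs where
# torch.ops.tensorrt_llm.nvfp4_block_scale_interleave is registered (i.e. with
# the TensorRT-LLM C++/CUDA extension loaded). A plausible call for NVFP4
# scale factors, assuming a [256, 16] uint8 scale tensor:
#
#   sf = torch.randint(0, 256, (256, 16), dtype=torch.uint8, device="cuda")
#   sf_shuffled = shuffle_matrix_sf_a(sf, epilogue_tile_m=128)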