from enum import IntEnum import torch # The declarations must be aligned with thUtils.h SF_DTYPE = torch.uint8 FLOAT4_E2M1X2 = torch.uint8 # For GEMM autotuning. # Taken from https://github.com/NVIDIA/TensorRT-LLM/blob/main/cpp/include/tensorrt_llm/runtime//modelConfig.h#L38 # TODO: move to model config, tune for blackwell hardware FP4_BUCKETS = [64, 128, 256, 512, 1024] # Export float4_e2m1x2 = FLOAT4_E2M1X2 float4_sf_dtype = SF_DTYPE fp4_buckets = FP4_BUCKETS __all__ = ['float4_e2m1x2', 'float4_sf_dtype', 'pad_up', 'fp4_buckets'] def pad_up(x: int, y: int) -> int: return ((x + y - 1) // y) * y class FP4GemmType(IntEnum): W4A4_NVFP4_NVFP4 = 0 W4A8_MXFP4_MXFP8 = 1 def get_fp4_shape(input_shape, sf_vec_size, is_swizzled_layout=True): m = 1 for i in range(len(input_shape) - 1): m *= input_shape[i] output_shape = [i for i in input_shape] output_shape[-1] //= 2 scale_shape = pad_up(m, 128) * pad_up( input_shape[-1] // sf_vec_size, 4) if is_swizzled_layout else m * (input_shape[-1] // sf_vec_size) return output_shape, scale_shape def get_reorder_rows_for_gated_act_gemm_row_indices(x) -> torch.Tensor: """ Reorders rows in the gemm/MOE_gemm weight matrix for min-latency [r0, r1, r2, r3, ..., rN/2, r(N/2+1), .. r(N-1)] to [r0, rN/2, r1, rN/2+1, ..., r(N/2-1), r(N-1)] """ M = x.shape[0] assert M % 2 == 0, f"x.shape[0] must be even, not {M}" row_indices = torch.arange(M, dtype=torch.long) # We split into top half and bottom half, but if M is odd, # the bottom half is one row larger. top = row_indices[:(M + 1) // 2] # round up bot = row_indices[(M + 1) // 2:] # remainder # Create the output permuted_row_indices = torch.empty_like(row_indices) # We'll place rows of `top` and `bot` in alternation permuted_row_indices[0::2] = top permuted_row_indices[1::2] = bot return permuted_row_indices def reorder_rows_for_gated_act_gemm(x): """ PyTorch implementation of trt-llm gen `reorderRowsForGatedActGemm` """ row_indices = get_reorder_rows_for_gated_act_gemm_row_indices(x) permute = lambda x: x[row_indices] return permute(x) # yapf: disable srcToDstBlk16RowMap = [ 0, 8, 1, 9, 2, 10, 3, 11, 4, 12, 5, 13, 6, 14, 7, 15 ] srcToDstBlk32RowMap = [ 0, 8, 16, 24, 1, 9, 17, 25, 2, 10, 18, 26, 3, 11, 19, 27, 4, 12, 20, 28, 5, 13, 21, 29, 6, 14, 22, 30, 7, 15, 23, 31 ] # yapf: enable def get_shuffle_block_size(epilogue_tile_m: int) -> int: shuffle_block_size = 16 if epilogue_tile_m % 128 == 0: shuffle_block_size = 32 return shuffle_block_size def get_shuffle_matrix_a_row_indices(input_tensor: torch.Tensor, epilogue_tile_m: int) -> torch.Tensor: """ Higher-level PyTorch approach to reorder the rows in blocks of size 16 or 32. - We do NOT try to handle custom e2m1 memory usage (i.e. no 'K/2' bytes). - Instead, we purely reorder rows in a standard PyTorch shape [M, K]. """ # M from the input M = input_tensor.shape[0] # Choose block size 16 or 32 shuffle_block_size = get_shuffle_block_size(epilogue_tile_m) row_map = (srcToDstBlk16RowMap if shuffle_block_size == 16 else srcToDstBlk32RowMap) assert M % shuffle_block_size == 0, f"input_tensor.shape[0] must be multiples of {shuffle_block_size}" # row_indices[new_row] = old_row # so row_indices is an array of size M telling us from which old_row # the new_row should be taken. row_indices = torch.empty(M, dtype=torch.long) for old_row in range(M): block_idx = old_row // shuffle_block_size row_in_block = old_row % shuffle_block_size mapped_row_in_block = row_map[row_in_block] new_row = block_idx * shuffle_block_size + mapped_row_in_block row_indices[new_row] = old_row return row_indices def shuffle_matrix_a(input_tensor: torch.Tensor, epilogue_tile_m: int) -> torch.Tensor: """ PyTorch equivalent of trtllm-gen `shuffleMatrixA` """ row_indices = get_shuffle_matrix_a_row_indices(input_tensor, epilogue_tile_m) return torch.ops.trtllm.shuffle_matrix(input_tensor, row_indices.to(input_tensor.device)) def get_shuffle_matrix_sf_a_row_indices( input_tensor: torch.Tensor, epilogue_tile_m: int, num_elts_per_sf: int = 16) -> torch.Tensor: assert input_tensor.dtype == float4_sf_dtype assert num_elts_per_sf == 16 or num_elts_per_sf == 32 assert input_tensor.dim( ) == 2, f"input_tensor should be a 2D tensor, not {input_tensor.dim()}" # M, K from the input M, K = input_tensor.shape assert M % 128 == 0 assert K % 4 == 0 row_indices = get_shuffle_matrix_a_row_indices(input_tensor, epilogue_tile_m) return row_indices def shuffle_matrix_sf_a( input_tensor: torch.Tensor, epilogue_tile_m: int, num_elts_per_sf: int = 16, ): """ Cuda implementation of trtllm-gen `shuffleMatrixSfA` but with a caveat. `shuffleMatrixSfA` expects the input to be in 128x4 layout and then apply the same shuffling in `shuffleMatrixA` and writes out in 128x4 layout. This function expects the input to be in linear layout. It's done this way because the scaling factors in the NVFP4 checkpoints are quantized and are in linear layout. This function doesn't add padding. """ row_indices = get_shuffle_matrix_sf_a_row_indices(input_tensor, epilogue_tile_m) w_shuffled = torch.ops.trtllm.shuffle_matrix( input_tensor, row_indices.to(input_tensor.device)) # 128x4 return torch.ops.trtllm.block_scale_interleave(w_shuffled)