import torch

# The declarations must be aligned with thUtils.h
SF_DTYPE = torch.uint8
FLOAT4_E2M1X2 = torch.uint8


def pad_up(x: int, y: int) -> int:
    """Round ``x`` up to the nearest multiple of ``y``."""
    return ((x + y - 1) // y) * y


# For GEMM autotuning.
# Taken from https://github.com/NVIDIA/TensorRT-LLM/blob/main/cpp/include/tensorrt_llm/runtime//modelConfig.h#L38
# TODO: move to model config, tune for blackwell hardware
FP4_BUCKETS = [64, 128, 256, 512, 1024]

# Export
float4_e2m1x2 = FLOAT4_E2M1X2
float4_sf_dtype = SF_DTYPE
fp4_buckets = FP4_BUCKETS

__all__ = ['float4_e2m1x2', 'float4_sf_dtype', 'pad_up', 'fp4_buckets']


def get_fp4_shape(input_shape, sf_vec_size):
    """
    Compute the packed output shape and the scale-factor element count.

    Args:
        input_shape: Sequence of dimension sizes; must be non-empty.
        sf_vec_size: Divisor applied to the last dimension to get the number
            of scale factors per row.

    Returns:
        A tuple ``(output_shape, scale_shape)``:
        - ``output_shape`` is a list copy of ``input_shape`` with the last
          dimension halved (floor division).
        - ``scale_shape`` is an int: the product of all leading dimensions
          padded up to a multiple of 128, times
          ``input_shape[-1] // sf_vec_size`` padded up to a multiple of 4.
    """
    # Collapse every dimension except the last into a single row count.
    m = 1
    for dim in input_shape[:-1]:
        m *= dim

    output_shape = list(input_shape)
    output_shape[-1] //= 2

    scale_shape = pad_up(m, 128) * pad_up(input_shape[-1] // sf_vec_size, 4)
    return output_shape, scale_shape


def reorder_rows_for_gated_act_gemm(x):
    """
    PyTorch implementation of trt-llm gen `reorderRowsForGatedActGemm`

    Reorders rows in the gemm/MOE_gemm weight matrix for min-latency
    [r0, r1, r2, r3, ..., rN/2, r(N/2+1), .. r(N-1)]
    to
    [r0, rN/2, r1, rN/2+1, ..., r(N/2-1), r(N-1)]

    Args:
        x: 2D tensor with an even number of rows.

    Returns:
        A new tensor (same shape/dtype/device as ``x``) whose even rows are
        the top half of ``x`` and whose odd rows are the bottom half.
    """
    assert x.dim() == 2, f"x should be a 2D tensor, not {x.dim()}"
    M = x.shape[0]
    assert M % 2 == 0, f"x.shape[0] must be even, not {M}"

    # M is even (asserted above), so both halves have exactly M // 2 rows.
    half = M // 2

    # Interleave: top-half rows go to even positions, bottom-half rows to odd.
    out = torch.empty_like(x)
    out[0::2] = x[:half]
    out[1::2] = x[half:]
    return out


# yapf: disable
srcToDstBlk16RowMap = [
    0,  8,
    1,  9,
    2, 10,
    3, 11,
    4, 12,
    5, 13,
    6, 14,
    7, 15
]

srcToDstBlk32RowMap = [
    0,  8, 16, 24,
    1,  9, 17, 25,
    2, 10, 18, 26,
    3, 11, 19, 27,
    4, 12, 20, 28,
    5, 13, 21, 29,
    6, 14, 22, 30,
    7, 15, 23, 31
]
# yapf: enable


def get_shuffle_block_size(epilogue_tile_m: int) -> int:
    """Return 32 when ``epilogue_tile_m`` is a multiple of 128, else 16."""
    shuffle_block_size = 16
    if epilogue_tile_m % 128 == 0:
        shuffle_block_size = 32
    return shuffle_block_size


def shuffle_matrix_a(input_tensor: torch.Tensor,
                     epilogue_tile_m: int) -> torch.Tensor:
    """
    PyTorch equivalent of trtllm-gen `shuffleMatrixA`

    Higher-level PyTorch approach to reorder the rows in blocks of size 16 or 32.
    - We do NOT try to handle custom e2m1 memory usage (i.e. no 'K/2' bytes).
    - Instead, we purely reorder rows in a standard PyTorch shape [M, K].

    Args:
        input_tensor: 2D tensor of shape [M, K]; M must be a multiple of the
            shuffle block size.
        epilogue_tile_m: Selects the block size via
            ``get_shuffle_block_size``.

    Returns:
        A new tensor where, inside every block of ``shuffle_block_size`` rows,
        source row ``i`` has been moved to row ``row_map[i]``.
    """
    assert input_tensor.dim(
    ) == 2, f"input_tensor should be a 2D tensor, not {input_tensor.dim()}"

    # Only the row count matters for the permutation.
    M = input_tensor.shape[0]

    # Choose block size 16 or 32
    shuffle_block_size = get_shuffle_block_size(epilogue_tile_m)
    row_map = (srcToDstBlk16RowMap
               if shuffle_block_size == 16 else srcToDstBlk32RowMap)

    assert M % shuffle_block_size == 0, f"input_tensor.shape[0] must be multiples of {shuffle_block_size}"

    # row_map sends a source row to its destination row inside a block.
    # Gathering needs the opposite direction (destination -> source), so
    # invert the small per-block permutation once on the host.
    dst_to_src = [0] * shuffle_block_size
    for src_row, dst_row in enumerate(row_map):
        dst_to_src[dst_row] = src_row

    device = input_tensor.device
    in_block_src = torch.tensor(dst_to_src, dtype=torch.long, device=device)
    block_starts = torch.arange(0,
                                M,
                                shuffle_block_size,
                                dtype=torch.long,
                                device=device)

    # row_indices[new_row] = old_row, built for all blocks at once instead of
    # issuing one scalar tensor write per row.
    row_indices = (block_starts[:, None] + in_block_src[None, :]).reshape(-1)

    # Then gather rows in that new order
    # out[new_row, :] = input_tensor[old_row, :]
    out = input_tensor[row_indices, :]

    return out


def shuffle_matrix_sf_a(
    input_tensor: torch.Tensor,
    epilogue_tile_m: int,
    num_elts_per_sf: int = 16,
):
    """
    Cuda implementation of trtllm-gen `shuffleMatrixSfA` but with a caveat.
    `shuffleMatrixSfA` expects the input to be in 128x4 layout and then
    apply the same shuffling in `shuffleMatrixA` and writes out in 128x4
    layout.
    This function expects the input to be in linear layout. It's done this
    way because the scaling factors in the NVFP4 checkpoints are quantized
    and are in linear layout.
    This function doesn't add padding.

    Args:
        input_tensor: 2D tensor of dtype ``float4_sf_dtype`` with shape
            [M, K], where M is a multiple of 128 and K a multiple of 4.
        epilogue_tile_m: Forwarded to ``shuffle_matrix_a``.
        num_elts_per_sf: Must be 16.

    Returns:
        The result of ``torch.ops.tensorrt_llm.nvfp4_block_scale_interleave``
        applied to the row-shuffled tensor.
    """
    assert input_tensor.dtype == float4_sf_dtype
    assert num_elts_per_sf == 16

    assert input_tensor.dim(
    ) == 2, f"input_tensor should be a 2D tensor, not {input_tensor.dim()}"

    # M, K from the input
    M, K = input_tensor.shape
    assert M % 128 == 0
    assert K % 4 == 0

    w_shuffled = shuffle_matrix_a(input_tensor, epilogue_tile_m)

    # 128x4
    return torch.ops.tensorrt_llm.nvfp4_block_scale_interleave(w_shuffled)