perf: Optimize swizzle_sf, unswizzle_sf, reswizzle_sf (#5318)

Signed-off-by: Bo Li <22713281+bobboli@users.noreply.github.com>
Bo Li 2025-06-26 14:03:56 +08:00 committed by GitHub
parent 7e681fbe52
commit 1bab9000a6
8 changed files with 446 additions and 103 deletions


@@ -226,18 +226,22 @@ void invokeBatchedFP4Quantization(int b, int m, int n, T const* input, float con
}
}
__global__ void nvfp4_block_scale_interleave_kernel(
int numbatches, int numRows, int numCols, uint8_t const* SFIn, uint8_t* SFOutput)
__global__ void nvfp4_block_scale_interleave_kernel(int numBatches, int numRows, int numRowsPadded, int numCols,
int numColsPadded, uint8_t const* SFIn, uint8_t* SFOutput)
{
constexpr int SF_VEC_SIZE = 16;
for (int rowIdx = blockIdx.x; rowIdx < numRows; rowIdx += gridDim.x)
for (int rowIdx = blockIdx.x; rowIdx < numRowsPadded; rowIdx += gridDim.x)
{
for (int batchIdx = 0; batchIdx < numbatches; batchIdx++)
for (int batchIdx = 0; batchIdx < numBatches; batchIdx++)
{
for (int colIdx = threadIdx.x; colIdx < numCols; colIdx += blockDim.x)
for (int colIdx = threadIdx.x; colIdx < numColsPadded; colIdx += blockDim.x)
{
int64_t inOffset = batchIdx * numRows * numCols + rowIdx * numCols + colIdx;
auto sf = SFIn[inOffset];
uint8_t sf = 0;
if (rowIdx < numRows && colIdx < numCols)
{
int64_t inOffset = batchIdx * numRows * numCols + rowIdx * numCols + colIdx;
sf = SFIn[inOffset];
}
std::optional<int> batchIdxOpt = batchIdx;
std::optional<int> numRowsOpt = numRows;
@@ -246,16 +250,55 @@ __global__ void nvfp4_block_scale_interleave_kernel(
// int const numSfTilesK = (numCols + 4 - 1) / 4;
// int const tileOffset = ((mi / 128) * numSfTilesK + ki / 4) * 512;
// int const dstIdx = tileOffset + (mi % 32) * 16 + ((mi % 128) / 32) * 4 + ki % 4;
auto dstIdx
= get_sf_out_offset_128x4<SF_VEC_SIZE>(batchIdxOpt, rowIdx, colIdx, numRowsOpt, numCols * 16);
auto dstIdx = get_sf_out_offset_128x4<SF_VEC_SIZE>(
batchIdxOpt, rowIdx, colIdx, numRowsOpt, numCols * SF_VEC_SIZE);
SFOutput[dstIdx] = sf;
}
}
}
}
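The commented-out formula above summarizes the swizzled layout: 128-row x 4-column SF tiles of 512 bytes each, with rows interleaved inside every tile. A minimal Python sketch of that offset computation (illustrative only; the kernel itself calls get_sf_out_offset_128x4, which also handles the batch dimension):

def swizzled_sf_offset(mi: int, ki: int, num_cols_padded: int) -> int:
    # mi/ki are row/column indices in the SF matrix; num_cols_padded is a multiple of 4.
    num_sf_tiles_k = (num_cols_padded + 4 - 1) // 4
    tile_offset = ((mi // 128) * num_sf_tiles_k + ki // 4) * 512
    # Within a tile, rows are laid out as 32 x 4 column-major groups.
    return tile_offset + (mi % 32) * 16 + ((mi % 128) // 32) * 4 + ki % 4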
__global__ void nvfp4_block_scale_interleave_reverse_kernel(
int numBatches, int numRows, int numCols, uint8_t const* SFIn, uint8_t* SFOutput)
{
constexpr int SF_VEC_SIZE = 16;
for (int rowIdx = blockIdx.x; rowIdx < numRows; rowIdx += gridDim.x)
{
for (int batchIdx = 0; batchIdx < numBatches; batchIdx++)
{
for (int colIdx = threadIdx.x; colIdx < numCols; colIdx += blockDim.x)
{
std::optional<int> batchIdxOpt = batchIdx;
std::optional<int> numRowsOpt = numRows;
// Get the swizzled input index using the same swizzling pattern
auto srcIdx = get_sf_out_offset_128x4<SF_VEC_SIZE>(
batchIdxOpt, rowIdx, colIdx, numRowsOpt, numCols * SF_VEC_SIZE);
auto sf = SFIn[srcIdx];
// Output goes to linear layout
int64_t outOffset = batchIdx * numRows * numCols + rowIdx * numCols + colIdx;
SFOutput[outOffset] = sf;
}
}
}
}
// This is intended for weight loading, so m and n are large, b <= 256
void invokeNVFP4BlockScaleInterleave(
void invokeNVFP4BlockScaleInterleave(int b, int m, int m_padded, int n, int n_padded, uint8_t const* SFIn,
uint8_t* SFOutput, int multiProcessorCount, cudaStream_t stream)
{
// Each thread reads 1 int8 value
dim3 block(std::min(n_padded, 1024));
// Get number of blocks per SM (assume we can fully utilize the SM).
int const numBlocksPerSM = std::max(1u, 4096u / block.x);
dim3 grid(std::min(m_padded, multiProcessorCount * numBlocksPerSM));
nvfp4_block_scale_interleave_kernel<<<grid, block, 0, stream>>>(b, m, m_padded, n, n_padded, SFIn, SFOutput);
}
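The launch configuration above uses one thread per SF byte along the padded column dimension and assumes roughly 4096 resident threads per SM can be used. A small sketch of the same arithmetic (hypothetical helper, for illustration only):

def interleave_launch_dims(m_padded: int, n_padded: int, sm_count: int):
    block_x = min(n_padded, 1024)            # one thread per int8 scale factor, capped at 1024
    blocks_per_sm = max(1, 4096 // block_x)  # assume ~4096 usable threads per SM
    grid_x = min(m_padded, sm_count * blocks_per_sm)
    return grid_x, block_x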
// This is intended for weight loading, so m and n are large, b <= 256
void invokeNVFP4BlockScaleInterleaveReverse(
int b, int m, int n, uint8_t const* SFIn, uint8_t* SFOutput, int multiProcessorCount, cudaStream_t stream)
{
// Each thread reads 1 int8 value
@@ -264,7 +307,7 @@ void invokeNVFP4BlockScaleInterleave(
int const numBlocksPerSM = std::max(1u, 4096u / block.x);
dim3 grid(std::min(m, multiProcessorCount * numBlocksPerSM));
nvfp4_block_scale_interleave_kernel<<<grid, block, 0, stream>>>(b, m, n, SFIn, SFOutput);
nvfp4_block_scale_interleave_reverse_kernel<<<grid, block, 0, stream>>>(b, m, n, SFIn, SFOutput);
}
// Instantiate the function.


@@ -74,7 +74,10 @@ template <typename T, int SF_VEC_SIZE = 16>
void invokeBatchedFP4Quantization(int b, int m, int n, T const* input, float const* globalScale, int64_t* output,
int32_t* SFOuput, bool useUE8M0, int multiProcessorCount, cudaStream_t stream = 0);
void invokeNVFP4BlockScaleInterleave(
void invokeNVFP4BlockScaleInterleave(int b, int m, int m_padded, int n, int n_padded, uint8_t const* SFIn,
uint8_t* SFOutput, int multiProcessorCount, cudaStream_t stream = 0);
void invokeNVFP4BlockScaleInterleaveReverse(
int b, int m, int n, uint8_t const* SFIn, uint8_t* SFOutput, int multiProcessorCount, cudaStream_t stream = 0);
} // namespace kernels


@@ -104,6 +104,7 @@ float e2M1ToFloat(uint8_t value)
return result;
}
// Given the rowIdx and colIdx in the unswizzled SFMatrix, compute the 1D offset in the swizzled SFMatrix.
// colIdx and totalColumn should be in the SF matrix, not the activation matrix, so no sfVecSize is needed.
int computeSFIndex(int rowIdx, int colIdx, int totalRow, int totalColumn, tensorrt_llm::FP4QuantizationSFLayout layout)
{
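The added comment distinguishes SF-matrix columns from activation-matrix columns; a quick numeric illustration (hypothetical sizes):

activation_cols = 63                                            # columns of the quantized matrix
sf_vec_size = 16                                                # one scale factor per 16 elements
sf_cols = (activation_cols + sf_vec_size - 1) // sf_vec_size    # == 4
# computeSFIndex is then called with totalColumn == sf_cols (4), not 63.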
@@ -272,7 +273,9 @@ torch::autograd::variable_list HalfToE2M1AndUFP8SFScale(
return {valueE2M1, scaleFP8SF};
}
// Interleave the weights block scaling factor.
// Interleave (and possibly pad) the weights block scaling factor.
// blockScale: [num_experts, rows, cols] or [rows, cols]
// Return: num_experts * pad_up(rows, 128) * pad_up(cols, 4)
th::Tensor NVFP4BlockScaleInterleave(th::Tensor const& blockScale)
{
bool is_cuda = blockScale.device().is_cuda();
@@ -291,31 +294,40 @@ th::Tensor NVFP4BlockScaleInterleave(th::Tensor const& blockScale)
auto cols = blockScaleShape.size() == 3 ? blockScaleShape[2] : blockScaleShape[1];
auto expert_out_size = tensorrt_llm::computeFP4SwizzledLayoutSFSize(rows, cols);
th::Tensor interleavedBlockScale = th::zeros(
auto rows_padded = PadUpFn(rows, 128);
auto cols_padded = PadUpFn(cols, 4);
TORCH_CHECK(
expert_out_size == rows_padded * cols_padded, "expert_out_size should be equal to rows_padded * cols_padded.");
th::Tensor interleavedBlockScale = th::empty(
{expert_out_size * num_experts}, th::dtype(SF_DTYPE).device(blockScale.device()).requires_grad(false));
if (is_cuda)
{
const thread_local int smCount = tensorrt_llm::common::getMultiProcessorCount();
auto stream = at::cuda::getCurrentCUDAStream(blockScale.get_device());
tensorrt_llm::kernels::invokeNVFP4BlockScaleInterleave(num_experts, rows, cols, blockScale.data_ptr<uint8_t>(),
static_cast<uint8_t*>(interleavedBlockScale.data_ptr()), smCount, stream);
tensorrt_llm::kernels::invokeNVFP4BlockScaleInterleave(num_experts, rows, rows_padded, cols, cols_padded,
blockScale.data_ptr<uint8_t>(), static_cast<uint8_t*>(interleavedBlockScale.data_ptr()), smCount, stream);
}
else
{
for (size_t eIdx = 0; eIdx < static_cast<size_t>(num_experts); eIdx++)
for (int eIdx = 0; eIdx < static_cast<int>(num_experts); eIdx++)
{
uint8_t* interleavedBlockScalePtr
= static_cast<uint8_t*>(interleavedBlockScale.data_ptr()) + eIdx * expert_out_size;
for (size_t rIdx = 0; rIdx < static_cast<size_t>(rows); ++rIdx)
for (int rIdx = 0; rIdx < static_cast<int>(rows_padded); ++rIdx)
{
auto globalRowIdx = eIdx * rows + rIdx;
uint8_t* blockScalePtr = blockScale.data_ptr<uint8_t>() + globalRowIdx * cols;
for (int cIdx = 0; cIdx < cols; ++cIdx)
for (int cIdx = 0; cIdx < static_cast<int>(cols_padded); ++cIdx)
{
uint8_t sf_ori = 0;
if (rIdx < static_cast<int>(rows) && cIdx < static_cast<int>(cols))
{
sf_ori = blockScalePtr[cIdx];
}
int sf_index
= computeSFIndex(rIdx, cIdx, rows, cols, tensorrt_llm::FP4QuantizationSFLayout::SWIZZLED);
interleavedBlockScalePtr[sf_index] = blockScalePtr[cIdx];
interleavedBlockScalePtr[sf_index] = sf_ori;
}
}
}
@@ -324,41 +336,68 @@ th::Tensor NVFP4BlockScaleInterleave(th::Tensor const& blockScale)
return interleavedBlockScale;
}
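As the comment above the function notes, each expert's swizzled block scale occupies pad_up(rows, 128) * pad_up(cols, 4) bytes. A quick size sketch with hypothetical dimensions:

def pad_up(x: int, y: int) -> int:
    return (x + y - 1) // y * y

num_experts, rows, cols = 3, 127, 3                    # hypothetical SF-matrix shape
expert_out_size = pad_up(rows, 128) * pad_up(cols, 4)  # 128 * 4 = 512
total_numel = num_experts * expert_out_size            # 1536 uint8 values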
// Reverse nterleave the weights block scaling factor.
th::Tensor NVFP4BlockScaleInterleaveReverse(th::Tensor blockScale)
// Reverse interleave the weights block scaling factor.
// blockScale: [num_experts, rows, cols] or [rows, cols]
// Note: rows and cols are the dimensions of the original unswizzled SFMatrix, so reshape input before passing into this function!
// Return: The same shape as blockScale
th::Tensor NVFP4BlockScaleInterleaveReverse(th::Tensor const& blockScale)
{
CHECK_CPU_INPUT(blockScale, SF_DTYPE);
bool is_cuda = blockScale.device().is_cuda();
if (is_cuda)
{
CHECK_INPUT(blockScale, SF_DTYPE);
}
else
{
CHECK_CPU_INPUT(blockScale, SF_DTYPE);
}
auto blockScaleShape = blockScale.sizes();
TORCH_CHECK(blockScaleShape.size() == 2 || blockScaleShape.size() == 3, "Block Scale should be 2D or 3D tensor.");
auto num_experts = blockScaleShape.size() == 3 ? blockScaleShape[0] : 1;
auto rows = blockScaleShape.size() == 3 ? blockScaleShape[1] : blockScaleShape[0];
auto cols = blockScaleShape.size() == 3 ? blockScaleShape[2] : blockScaleShape[1];
auto expert_out_size = tensorrt_llm::computeFP4SwizzledLayoutSFSize(rows, cols);
TORCH_CHECK(rows % 128 == 0, "rows of Interleaved block scales should be multiple of 128.");
TORCH_CHECK(cols % 4 == 0, "cols of Interleaved block scales should be multiple of 4.");
auto expert_out_size = rows * cols;
th::Tensor reversedBlockScale
= th::empty(blockScaleShape, th::dtype(SF_DTYPE).device(blockScale.device()).requires_grad(false));
th::Tensor reversedBlockScale = th::zeros(blockScaleShape, th::dtype(SF_DTYPE).requires_grad(false));
std::map<int, std::array<int, 3>> identity;
for (int eIdx = 0; eIdx < num_experts; eIdx++)
if (is_cuda)
{
for (int rIdx = 0; rIdx < rows; ++rIdx)
const thread_local int smCount = tensorrt_llm::common::getMultiProcessorCount();
auto stream = at::cuda::getCurrentCUDAStream(blockScale.get_device());
tensorrt_llm::kernels::invokeNVFP4BlockScaleInterleaveReverse(num_experts, rows, cols,
blockScale.data_ptr<uint8_t>(), static_cast<uint8_t*>(reversedBlockScale.data_ptr()), smCount, stream);
}
else
{
// index in the swizzled SFMatrix -> (eIdx, rIdx, cIdx) in the unswizzled SFMatrix
std::map<int, std::array<int, 3>> identity;
for (int eIdx = 0; eIdx < num_experts; eIdx++)
{
for (int cIdx = 0; cIdx < cols; ++cIdx)
for (int rIdx = 0; rIdx < rows; ++rIdx)
{
int sf_index = computeSFIndex(rIdx, cIdx, rows, cols, tensorrt_llm::FP4QuantizationSFLayout::SWIZZLED);
identity[eIdx * expert_out_size + sf_index] = std::array<int, 3>{eIdx, rIdx, cIdx};
for (int cIdx = 0; cIdx < cols; ++cIdx)
{
int sf_index
= computeSFIndex(rIdx, cIdx, rows, cols, tensorrt_llm::FP4QuantizationSFLayout::SWIZZLED);
identity[eIdx * expert_out_size + sf_index] = std::array<int, 3>{eIdx, rIdx, cIdx};
}
}
}
uint8_t* blockScalePtr = static_cast<uint8_t*>(blockScale.data_ptr());
for (int i = 0; i < expert_out_size * num_experts; i++)
{
auto loc_2d = identity[i];
if (loc_2d[1] < rows && loc_2d[2] < cols)
{
uint8_t* reversedBlockScalePtr
= reversedBlockScale.data_ptr<uint8_t>() + (loc_2d[0] * rows + loc_2d[1]) * cols + loc_2d[2];
*reversedBlockScalePtr = blockScalePtr[i];
}
}
}
uint8_t* blockScalePtr = static_cast<uint8_t*>(blockScale.data_ptr());
for (int i = 0; i < expert_out_size * num_experts; i++)
{
auto loc_2d = identity[i];
if (loc_2d[1] < rows && loc_2d[2] < cols)
{
uint8_t* reversedBlockScalePtr
= reversedBlockScale.data_ptr<uint8_t>() + (loc_2d[0] * rows + loc_2d[1]) * cols + loc_2d[2];
*reversedBlockScalePtr = blockScalePtr[i];
}
}
return reversedBlockScale;
}
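A hedged usage sketch of the two ops from Python, assuming they are registered as torch.ops.tensorrt_llm.nvfp4_block_scale_interleave and ..._interleave_reverse (as referenced in the Python utilities below) and that the reverse op is given already-padded dimensions:

import torch

sf = torch.randint(0, 256, (1, 128, 4), dtype=torch.uint8)           # rows % 128 == 0, cols % 4 == 0
swizzled = torch.ops.tensorrt_llm.nvfp4_block_scale_interleave(sf)   # 1D, 512 bytes
restored = torch.ops.tensorrt_llm.nvfp4_block_scale_interleave_reverse(
    swizzled.view(1, 128, 4))
assert torch.equal(restored.view_as(sf), sf)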


@@ -30,6 +30,7 @@ _add_trt_llm_dll_directory()
import sys
import tensorrt_llm.functional as functional
import tensorrt_llm.math_utils as math_utils
import tensorrt_llm.models as models
import tensorrt_llm.quantization as quantization
import tensorrt_llm.runtime as runtime
@@ -104,6 +105,7 @@ __all__ = [
'SamplingParams',
'DisaggregatedParams',
'KvCacheConfig',
'math_utils',
'__version__',
]


@@ -8,6 +8,7 @@ from typing import Dict, List
import torch
from tensorrt_llm._utils import TensorWrapper, convert_to_torch_tensor
from tensorrt_llm.math_utils import ceil_div, pad_up
from tensorrt_llm.quantization.utils import fp4_utils
is_torch_compiling_flag = False
@@ -106,75 +107,85 @@ def disable_fp4_allgather():
def compute_swizzled_sf_shape(row: int, col: int):
padded_row = (row + 128 - 1) // 128 * 128
padded_col = (col + 4 - 1) // 4 * 4
padded_row = pad_up(row, 128)
padded_col = pad_up(col, 4)
return padded_row, padded_col
def swizzle_sf(sf: torch.Tensor,
row: int,
col: int,
rows: int,
cols: int,
scaling_vector_size: int = 16):
factor = scaling_vector_size * 4
num_m_tiles = (row + 128 - 1) // 128
num_k_tiles = (col + factor - 1) // factor
# SF layout [num_m_tiles, num_k_tiles, 32 (m_tile column major), 4 (m_tile column major), 4(k_tile)]
sf_full = torch.zeros(num_m_tiles * 32 * 4,
num_k_tiles * 4,
dtype=sf.dtype,
device=sf.device)
sf_full[:row, :(col //
scaling_vector_size)] = sf[:row, :(col //
scaling_vector_size)]
sf_full_reshaped = sf_full.view(num_m_tiles, 4, 32, num_k_tiles, 4)
sf_full_swizzle = sf_full_reshaped.transpose(1, 3)
sf_swizzle = sf_full_swizzle.reshape(-1)
return sf_swizzle
"""Swizzle FP4 scaling factors using C++ torch op implementation
Args:
sf: [b, rows, cols_sf] or [rows, cols_sf]. The original unswizzled scaling factors.
rows: rows of the original unquantized tensor
cols: the number of columns of the original unquantized tensor; the SF matrix then has cols_sf = ceil_div(cols, scaling_vector_size) columns
scaling_vector_size: the size of the scaling vector
Returns:
[b * pad_up(rows, 128) * pad_up(cols_sf, 4), ] 1D swizzled scaling factors, possibly with rows and cols padded.
"""
sf_cols = ceil_div(cols, scaling_vector_size)
sf = sf.view(-1, rows, sf_cols)
return torch.ops.tensorrt_llm.nvfp4_block_scale_interleave(sf)
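An illustrative call with hypothetical sizes, showing the padding reflected in the returned 1D tensor (assumes a CUDA device; the CPU fallback in the C++ op produces the same layout):

import torch
from tensorrt_llm._torch.utils import swizzle_sf

rows, cols, vec = 127, 63, 16          # hypothetical unquantized tensor shape
sf_cols = (cols + vec - 1) // vec      # 4 scale-factor columns
sf = torch.randint(0, 256, (rows, sf_cols), dtype=torch.uint8, device="cuda")
out = swizzle_sf(sf, rows, cols, vec)  # 1D, padded to 128 rows x 4 columns
assert out.numel() == 128 * 4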
def unswizzle_sf(sf: torch.Tensor,
row: int,
col: int,
rows: int,
cols: int,
scaling_vector_size: int = 16):
factor = scaling_vector_size * 4
num_m_tiles = (row + 128 - 1) // 128
num_k_tiles = (col + factor - 1) // factor
# SF layout [num_m_tiles, num_k_tiles, 32 (m_tile column major), 4 (m_tile column major), 4(k_tile)]
sf_reshaped = sf.view(num_m_tiles, num_k_tiles, 32, 4, 4)
sf_unswizzle = sf_reshaped.transpose(1, 3)
sf_unswizzle = sf_unswizzle.reshape(num_m_tiles * 32 * 4, num_k_tiles * 4)
sf_unswizzle_sliced = sf_unswizzle[:row, :(col // scaling_vector_size)]
return sf_unswizzle_sliced.contiguous()
"""Swizzle FP4 scaling factors using C++ torch op implementation
Args:
sf: The (padded and) swizzled scaling factors.
rows: rows of the original unquantized tensor
cols: cols of the original unquantized tensor
scaling_vector_size: the size of the scaling vector
Returns:
2D unswizzled scaling factors
"""
sf_cols = ceil_div(cols, scaling_vector_size)
sf = sf.view(-1, rows, sf_cols)
return torch.ops.tensorrt_llm.nvfp4_block_scale_interleave_reverse(sf).view(
-1, sf_cols)
def reswizzle_sf(sf: torch.Tensor,
row: int,
col: int,
rows: int,
cols: int,
scaling_vector_size: int = 16):
factor = scaling_vector_size * 4
num_m_tiles = (row + 128 - 1) // 128
num_k_tiles = (col + factor - 1) // factor
partition_size = num_m_tiles * num_k_tiles * 32 * 4 * 4
num_partitions = sf.numel() // partition_size
sf_reshaped = sf.view(num_partitions, num_m_tiles, num_k_tiles, 32, 4, 4)
sf_unswizzle = sf_reshaped.transpose(2, 4)
sf_unswizzle = sf_unswizzle.reshape(num_partitions, num_m_tiles * 32 * 4,
num_k_tiles * 4)
total_rows = num_partitions * row
num_m_tiles_out = (total_rows + 128 - 1) // 128
sf_out = torch.zeros(
num_m_tiles_out,
4,
32,
num_k_tiles,
4,
dtype=sf.dtype,
device=sf.device,
)
sf_out_reshaped = sf_out.view(num_m_tiles_out * 32 * 4, num_k_tiles * 4)
sf_out_reshaped[:total_rows] = sf_unswizzle[:, :row].reshape(total_rows, -1)
sf_out_swizzle = sf_out.transpose(1, 3).reshape(-1)
return sf_out_swizzle
"""Reswizzle FP4 scaling factors using C++ torch op implementation.
It unswizzles the scaling factors in each partition first, then concatenates them together, and finally swizzles them back.
Args:
sf: The swizzled (and possibly padded) scaling factors.
rows: rows of the original unquantized tensor
cols: cols of the original unquantized tensor
scaling_vector_size: the size of the scaling vector
Returns:
1D reswizzled scaling factors
"""
sf_cols = ceil_div(cols, scaling_vector_size)
padded_rows, padded_sf_cols = compute_swizzled_sf_shape(rows, sf_cols)
padded_cols = padded_sf_cols * scaling_vector_size
assert sf.numel() % (padded_rows * padded_sf_cols) == 0
num_partitions = sf.numel() // (padded_rows * padded_sf_cols)
sf_reshaped = sf.view(num_partitions, padded_rows, padded_sf_cols)
# Unswizzle each partition
sf_unswizzled = unswizzle_sf(sf_reshaped, padded_rows, padded_cols,
scaling_vector_size)
# Brings the unswizzled scaling factors in each partition together
total_rows = num_partitions * rows
sf_unswizzled = sf_unswizzled.view(num_partitions, padded_rows,
padded_sf_cols)
sf_concatenated = sf_unswizzled[:, :rows, :sf_cols].contiguous(
) # TODO: This will incur an elementwise kernel
sf_concatenated = sf_concatenated.view(total_rows, sf_cols)
# Finally swizzle the concatenated scaling factors
return swizzle_sf(sf_concatenated, total_rows, cols, scaling_vector_size)
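A sketch of the shape flow through reswizzle_sf with hypothetical sizes (assumes a CUDA device, though the ops also have a CPU path): three partitions of a 127 x 63 tensor each swizzle to a padded 128 x 4 SF block, the valid 127 x 4 regions are stacked into 381 rows, and the concatenation is swizzled again into a padded 512 x 4 block:

import torch
from tensorrt_llm._torch.utils import reswizzle_sf, swizzle_sf

rows, cols, vec = 127, 63, 16
parts, sf_cols = 3, (cols + vec - 1) // vec        # 3 partitions, 4 SF columns
sf = torch.randint(0, 256, (parts, rows, sf_cols), dtype=torch.uint8, device="cuda")
swizzled = swizzle_sf(sf, rows, cols, vec)         # 3 * 128 * 4 = 1536 bytes
out = reswizzle_sf(swizzled, rows, cols, vec)      # pad_up(3 * 127, 128) * 4 = 2048 bytes
assert out.numel() == 2048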
def next_positive_power_of_2(x: int) -> int:


@@ -0,0 +1,20 @@
# SPDX-FileCopyrightText: Copyright (c) 2022-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
def ceil_div(x: int, y: int) -> int:
return (x + y - 1) // y
def pad_up(x: int, y: int) -> int:
return ((x + y - 1) // y) * y
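A couple of sanity checks for the two helpers (illustrative values only):

from tensorrt_llm.math_utils import ceil_div, pad_up

assert ceil_div(63, 16) == 4      # 63 columns need 4 scale-factor columns
assert pad_up(127, 128) == 128    # rows padded up to the 128-row tile
assert pad_up(4, 4) == 4          # already-aligned values are unchanged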


@@ -6,11 +6,6 @@ import torch
SF_DTYPE = torch.uint8
FLOAT4_E2M1X2 = torch.uint8
def pad_up(x: int, y: int) -> int:
return ((x + y - 1) // y) * y
# For GEMM autotuning.
# Taken from https://github.com/NVIDIA/TensorRT-LLM/blob/main/cpp/include/tensorrt_llm/runtime//modelConfig.h#L38
# TODO: move to model config, tune for blackwell hardware
@@ -24,6 +19,10 @@ fp4_buckets = FP4_BUCKETS
__all__ = ['float4_e2m1x2', 'float4_sf_dtype', 'pad_up', 'fp4_buckets']
def pad_up(x: int, y: int) -> int:
return ((x + y - 1) // y) * y
class FP4GemmType(IntEnum):
W4A4_NVFP4_NVFP4 = 0
W4A8_MXFP4_MXFP8 = 1


@@ -0,0 +1,226 @@
import pytest
import torch
from utils.util import skip_pre_blackwell
import tensorrt_llm.quantization.utils.fp4_utils as fp4_utils
from tensorrt_llm._torch.utils import (compute_swizzled_sf_shape, reswizzle_sf,
swizzle_sf, unswizzle_sf)
from tensorrt_llm.math_utils import ceil_div
# Reference PyTorch implementations (original)
def swizzle_sf_ref(sf: torch.Tensor,
row: int,
col: int,
scaling_vector_size: int = 16):
"""Reference PyTorch implementation of swizzle_sf"""
col_sf = ceil_div(col, scaling_vector_size)
num_m_tiles = ceil_div(row, 128)
num_k_tiles = ceil_div(col_sf, 4)
# SF layout [num_m_tiles, num_k_tiles, 32 (m_tile column major), 4 (m_tile column major), 4(k_tile)]
sf_full = torch.zeros(num_m_tiles * 32 * 4,
num_k_tiles * 4,
dtype=sf.dtype,
device=sf.device)
sf_full[:row, :col_sf] = sf[:row, :col_sf]
sf_full_reshaped = sf_full.view(num_m_tiles, 4, 32, num_k_tiles, 4)
sf_full_swizzle = sf_full_reshaped.transpose(1, 3)
sf_swizzle = sf_full_swizzle.reshape(-1)
return sf_swizzle
def unswizzle_sf_ref(sf: torch.Tensor,
row: int,
col: int,
scaling_vector_size: int = 16):
"""Reference PyTorch implementation of unswizzle_sf"""
cols_sf = ceil_div(col, scaling_vector_size)
num_m_tiles = ceil_div(row, 128)
num_k_tiles = ceil_div(cols_sf, 4)
# SF layout [num_m_tiles, num_k_tiles, 32 (m_tile column major), 4 (m_tile column major), 4(k_tile)]
sf_reshaped = sf.view(num_m_tiles, num_k_tiles, 32, 4, 4)
sf_unswizzle = sf_reshaped.transpose(1, 3)
sf_unswizzle = sf_unswizzle.reshape(num_m_tiles * 32 * 4, num_k_tiles * 4)
return sf_unswizzle.contiguous()
def reswizzle_sf_ref(sf: torch.Tensor,
row: int,
col: int,
scaling_vector_size: int = 16):
"""Reference PyTorch implementation of reswizzle_sf"""
cols_sf = ceil_div(col, scaling_vector_size)
num_m_tiles = ceil_div(row, 128)
num_k_tiles = ceil_div(cols_sf, 4)
partition_size = num_m_tiles * num_k_tiles * 32 * 4 * 4
num_partitions = sf.numel() // partition_size
sf_reshaped = sf.view(num_partitions, num_m_tiles, num_k_tiles, 32, 4, 4)
sf_unswizzle = sf_reshaped.transpose(2, 4)
sf_unswizzle = sf_unswizzle.reshape(num_partitions, num_m_tiles * 32 * 4,
num_k_tiles * 4)
total_rows = num_partitions * row
num_m_tiles_out = ceil_div(total_rows, 128)
sf_out = torch.zeros(
num_m_tiles_out,
4,
32,
num_k_tiles,
4,
dtype=sf.dtype,
device=sf.device,
)
sf_out_reshaped = sf_out.view(num_m_tiles_out * 32 * 4, num_k_tiles * 4)
sf_out_reshaped[:total_rows] = sf_unswizzle[:, :row].reshape(total_rows, -1)
sf_out_swizzle = sf_out.transpose(1, 3).reshape(-1)
return sf_out_swizzle
@skip_pre_blackwell
@pytest.mark.parametrize(
"rows,cols",
[
(1, 1), # test padding
(1, 16),
(1, 63),
(127, 1),
(127, 16),
(127, 63),
(128, 64), # 1x1 tiles
(128, 128), # 1×2 tiles
(256, 64), # 2×1 tiles
(512, 256), # 4×4 tiles
])
def test_swizzle_sf(rows, cols):
"""Test C++ swizzle_sf against PyTorch reference implementation"""
scaling_vector_size = 16
sf_cols = ceil_div(cols, scaling_vector_size)
# Create scaling factor data using fp4_sf_dtype
sf_data = torch.randint(0,
256, (rows * sf_cols, ),
dtype=fp4_utils.float4_sf_dtype,
device="cuda").view(rows, sf_cols)
# Apply reference implementation
ref_result = swizzle_sf_ref(sf_data, rows, cols, scaling_vector_size)
# Apply C++ implementation
result = swizzle_sf(sf_data, rows, cols, scaling_vector_size)
# Verify results are equivalent
torch.testing.assert_close(result, ref_result)
@skip_pre_blackwell
@pytest.mark.parametrize(
"rows,cols",
[
(128, 64), # 1x1 tiles
(128, 128), # 1×2 tiles
(256, 64), # 2×1 tiles
(512, 256), # 4×4 tiles
])
def test_unswizzle_sf(rows, cols):
"""Test C++ unswizzle_sf against PyTorch reference implementation"""
scaling_vector_size = 16
sf_cols = ceil_div(cols, scaling_vector_size)
# Create scaling factor data by first swizzling with reference implementation
original_sf_data = torch.randint(0,
256, (rows * sf_cols, ),
dtype=fp4_utils.float4_sf_dtype,
device="cuda").view(rows, sf_cols)
swizzled_sf_data = swizzle_sf_ref(original_sf_data, rows, cols,
scaling_vector_size)
# Apply reference unswizzle
ref_result = unswizzle_sf_ref(swizzled_sf_data, rows, cols,
scaling_vector_size)
# Note that unlike swizzle_sf, unswizzle_sf does not return a 1D tensor
result = unswizzle_sf(swizzled_sf_data, rows, cols, scaling_vector_size)
# Verify C++ result matches reference result
torch.testing.assert_close(result, ref_result)
@skip_pre_blackwell
@pytest.mark.parametrize(
"rows,cols",
[
(1, 1), # test padding
(1, 16),
(1, 63),
(127, 1),
(127, 16),
(127, 63),
(128, 64), # 1x1 tiles
(128, 128), # 1×2 tiles
(256, 64), # 2×1 tiles
(512, 256), # 4×4 tiles
])
def test_swizzle_round_trip(rows, cols):
"""Test that swizzle/unswizzle operations are inverse of each other"""
scaling_vector_size = 16
sf_cols = ceil_div(cols, scaling_vector_size)
# Create scaling factor data
original_sf_data = torch.randint(0,
256, (rows * sf_cols, ),
dtype=fp4_utils.float4_sf_dtype,
device="cuda")
# Apply swizzle then unswizzle using the utils functions
swizzled_sf = swizzle_sf(original_sf_data, rows, cols, scaling_vector_size)
padded_rows, padded_sf_cols = compute_swizzled_sf_shape(rows, sf_cols)
# Check that the swizzled scaling factor data is padded correctly
assert padded_rows * padded_sf_cols == swizzled_sf.numel()
padded_cols = padded_sf_cols * scaling_vector_size
unswizzled_sf = unswizzle_sf(swizzled_sf, padded_rows, padded_cols,
scaling_vector_size)
unswizzled_sf = unswizzled_sf[:rows, :sf_cols]
# Verify round-trip preserves original scaling factor data
torch.testing.assert_close(original_sf_data.view(rows, sf_cols),
unswizzled_sf[:rows, :sf_cols])
@skip_pre_blackwell
@pytest.mark.parametrize("num_partitions", [1, 2, 3])
@pytest.mark.parametrize(
"rows,cols",
[
(1, 1), # test padding
(1, 16),
(1, 63),
(127, 1),
(127, 16),
(127, 63),
(128, 64), # 1x1 tiles
(128, 128), # 1×2 tiles
(256, 64), # 2×1 tiles
(512, 256), # 4×4 tiles
])
def test_reswizzle_sf(num_partitions, rows, cols):
"""Test C++ reswizzle_sf against PyTorch reference implementation"""
scaling_vector_size = 16
sf_cols = ceil_div(cols, scaling_vector_size)
original_sf_data = torch.randint(0,
256, (num_partitions, rows, sf_cols),
dtype=fp4_utils.float4_sf_dtype,
device="cuda")
swizzled_sf_data = swizzle_sf(original_sf_data, rows, cols,
scaling_vector_size)
# Apply reference reswizzle
ref_result = reswizzle_sf_ref(swizzled_sf_data, rows, cols,
scaling_vector_size)
# Apply C++ reswizzle
result = reswizzle_sf(swizzled_sf_data, rows, cols, scaling_vector_size)
# Verify results are equivalent
torch.testing.assert_close(result, ref_result)