mirror of https://github.com/NVIDIA/TensorRT-LLM.git (synced 2026-01-13 22:18:36 +08:00)
perf: Optimize swizzle_sf, unswizzle_sf, reswizzle_sf (#5318)
Signed-off-by: Bo Li <22713281+bobboli@users.noreply.github.com>
This commit is contained in: parent 7e681fbe52, commit 1bab9000a6
@@ -226,18 +226,22 @@ void invokeBatchedFP4Quantization(int b, int m, int n, T const* input, float con
     }
 }
 
-__global__ void nvfp4_block_scale_interleave_kernel(
-    int numbatches, int numRows, int numCols, uint8_t const* SFIn, uint8_t* SFOutput)
+__global__ void nvfp4_block_scale_interleave_kernel(int numBatches, int numRows, int numRowsPadded, int numCols,
+    int numColsPadded, uint8_t const* SFIn, uint8_t* SFOutput)
 {
     constexpr int SF_VEC_SIZE = 16;
-    for (int rowIdx = blockIdx.x; rowIdx < numRows; rowIdx += gridDim.x)
+    for (int rowIdx = blockIdx.x; rowIdx < numRowsPadded; rowIdx += gridDim.x)
     {
-        for (int batchIdx = 0; batchIdx < numbatches; batchIdx++)
+        for (int batchIdx = 0; batchIdx < numBatches; batchIdx++)
         {
-            for (int colIdx = threadIdx.x; colIdx < numCols; colIdx += blockDim.x)
+            for (int colIdx = threadIdx.x; colIdx < numColsPadded; colIdx += blockDim.x)
             {
-                int64_t inOffset = batchIdx * numRows * numCols + rowIdx * numCols + colIdx;
-                auto sf = SFIn[inOffset];
+                uint8_t sf = 0;
+                if (rowIdx < numRows && colIdx < numCols)
+                {
+                    int64_t inOffset = batchIdx * numRows * numCols + rowIdx * numCols + colIdx;
+                    sf = SFIn[inOffset];
+                }
 
                 std::optional<int> batchIdxOpt = batchIdx;
                 std::optional<int> numRowsOpt = numRows;
@@ -246,16 +250,55 @@ __global__ void nvfp4_block_scale_interleave_kernel(
                 // int const numSfTilesK = (numCols + 4 - 1) / 4;
                 // int const tileOffset = ((mi / 128) * numSfTilesK + ki / 4) * 512;
                 // int const dstIdx = tileOffset + (mi % 32) * 16 + ((mi % 128) / 32) * 4 + ki % 4;
-                auto dstIdx
-                    = get_sf_out_offset_128x4<SF_VEC_SIZE>(batchIdxOpt, rowIdx, colIdx, numRowsOpt, numCols * 16);
+                auto dstIdx = get_sf_out_offset_128x4<SF_VEC_SIZE>(
+                    batchIdxOpt, rowIdx, colIdx, numRowsOpt, numCols * SF_VEC_SIZE);
                 SFOutput[dstIdx] = sf;
             }
         }
     }
 }
 
+__global__ void nvfp4_block_scale_interleave_reverse_kernel(
+    int numBatches, int numRows, int numCols, uint8_t const* SFIn, uint8_t* SFOutput)
+{
+    constexpr int SF_VEC_SIZE = 16;
+    for (int rowIdx = blockIdx.x; rowIdx < numRows; rowIdx += gridDim.x)
+    {
+        for (int batchIdx = 0; batchIdx < numBatches; batchIdx++)
+        {
+            for (int colIdx = threadIdx.x; colIdx < numCols; colIdx += blockDim.x)
+            {
+                std::optional<int> batchIdxOpt = batchIdx;
+                std::optional<int> numRowsOpt = numRows;
+
+                // Get the swizzled input index using the same swizzling pattern
+                auto srcIdx = get_sf_out_offset_128x4<SF_VEC_SIZE>(
+                    batchIdxOpt, rowIdx, colIdx, numRowsOpt, numCols * SF_VEC_SIZE);
+                auto sf = SFIn[srcIdx];
+
+                // Output goes to linear layout
+                int64_t outOffset = batchIdx * numRows * numCols + rowIdx * numCols + colIdx;
+                SFOutput[outOffset] = sf;
+            }
+        }
+    }
+}
+
 // This is intended for weight loading, so m and n are large, b <= 256
-void invokeNVFP4BlockScaleInterleave(
+void invokeNVFP4BlockScaleInterleave(int b, int m, int m_padded, int n, int n_padded, uint8_t const* SFIn,
+    uint8_t* SFOutput, int multiProcessorCount, cudaStream_t stream)
+{
+    // Each thread reads 1 int8 value
+    dim3 block(std::min(n_padded, 1024));
+    // Get number of blocks per SM (assume we can fully utilize the SM).
+    int const numBlocksPerSM = std::max(1u, 4096u / block.x);
+    dim3 grid(std::min(m_padded, multiProcessorCount * numBlocksPerSM));
+
+    nvfp4_block_scale_interleave_kernel<<<grid, block, 0, stream>>>(b, m, m_padded, n, n_padded, SFIn, SFOutput);
+}
+
+// This is intended for weight loading, so m and n are large, b <= 256
+void invokeNVFP4BlockScaleInterleaveReverse(
     int b, int m, int n, uint8_t const* SFIn, uint8_t* SFOutput, int multiProcessorCount, cudaStream_t stream)
 {
     // Each thread reads 1 int8 value
@@ -264,7 +307,7 @@ void invokeNVFP4BlockScaleInterleave(
     int const numBlocksPerSM = std::max(1u, 4096u / block.x);
     dim3 grid(std::min(m, multiProcessorCount * numBlocksPerSM));
 
-    nvfp4_block_scale_interleave_kernel<<<grid, block, 0, stream>>>(b, m, n, SFIn, SFOutput);
+    nvfp4_block_scale_interleave_reverse_kernel<<<grid, block, 0, stream>>>(b, m, n, SFIn, SFOutput);
 }
 
 // Instantiate the function.
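As a rough illustration of the 128x4 swizzled layout these kernels produce, the commented-out index math in nvfp4_block_scale_interleave_kernel can be written as a small Python helper. This is only a sketch of the offset that get_sf_out_offset_128x4 is expected to compute for SF_VEC_SIZE = 16; the helper name and argument names below are invented for the example and are not part of the diff.

# Sketch of the swizzled offset for SF row mi and SF column ki, assuming the SF
# matrix is padded to a multiple of 128 rows and 4 columns (as this commit does).
def swizzled_sf_offset(mi: int, ki: int, num_sf_cols_padded: int) -> int:
    num_sf_tiles_k = num_sf_cols_padded // 4                      # 4 SF columns per 128x4 tile
    tile_offset = ((mi // 128) * num_sf_tiles_k + ki // 4) * 512  # each 128x4 tile holds 512 bytes
    return tile_offset + (mi % 32) * 16 + ((mi % 128) // 32) * 4 + ki % 4

With the padded kernel, every (rowIdx, colIdx) outside the original numRows x numCols region is written as a zero scaling factor, so the per-batch output is pad_up(numRows, 128) * pad_up(numCols, 4) bytes.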
@@ -74,7 +74,10 @@ template <typename T, int SF_VEC_SIZE = 16>
 void invokeBatchedFP4Quantization(int b, int m, int n, T const* input, float const* globalScale, int64_t* output,
     int32_t* SFOuput, bool useUE8M0, int multiProcessorCount, cudaStream_t stream = 0);
 
-void invokeNVFP4BlockScaleInterleave(
+void invokeNVFP4BlockScaleInterleave(int b, int m, int m_padded, int n, int n_padded, uint8_t const* SFIn,
+    uint8_t* SFOutput, int multiProcessorCount, cudaStream_t stream = 0);
+
+void invokeNVFP4BlockScaleInterleaveReverse(
     int b, int m, int n, uint8_t const* SFIn, uint8_t* SFOutput, int multiProcessorCount, cudaStream_t stream = 0);
 
 } // namespace kernels
@@ -104,6 +104,7 @@ float e2M1ToFloat(uint8_t value)
     return result;
 }
 
+// Given the rowIdx and colIdx in the unswizzled SFMatrix, compute the 1D offset in the swizzled SFMatrix.
 // colIdx and totalCloumn should be in SFMatrix, not activation Matrix, so no sfVecSize needed.
 int computeSFIndex(int rowIdx, int colIdx, int totalRow, int totalColumn, tensorrt_llm::FP4QuantizationSFLayout layout)
 {
@@ -272,7 +273,9 @@ torch::autograd::variable_list HalfToE2M1AndUFP8SFScale(
     return {valueE2M1, scaleFP8SF};
 }
 
-// Interleave the weights block scaling factor.
+// Interleave (and possibly pad) the weights block scaling factor.
+// blockScale: [num_experts, rows, cols] or [rows, cols]
+// Return: num_experts * pad_up(rows, 128) * pad_up(cols, 4)
 th::Tensor NVFP4BlockScaleInterleave(th::Tensor const& blockScale)
 {
     bool is_cuda = blockScale.device().is_cuda();
@@ -291,31 +294,40 @@ th::Tensor NVFP4BlockScaleInterleave(th::Tensor const& blockScale)
     auto cols = blockScaleShape.size() == 3 ? blockScaleShape[2] : blockScaleShape[1];
 
     auto expert_out_size = tensorrt_llm::computeFP4SwizzledLayoutSFSize(rows, cols);
-    th::Tensor interleavedBlockScale = th::zeros(
+    auto rows_padded = PadUpFn(rows, 128);
+    auto cols_padded = PadUpFn(cols, 4);
+    TORCH_CHECK(
+        expert_out_size == rows_padded * cols_padded, "expert_out_size should be equal to rows_padded * cols_padded.");
+    th::Tensor interleavedBlockScale = th::empty(
         {expert_out_size * num_experts}, th::dtype(SF_DTYPE).device(blockScale.device()).requires_grad(false));
 
     if (is_cuda)
     {
         const thread_local int smCount = tensorrt_llm::common::getMultiProcessorCount();
         auto stream = at::cuda::getCurrentCUDAStream(blockScale.get_device());
-        tensorrt_llm::kernels::invokeNVFP4BlockScaleInterleave(num_experts, rows, cols, blockScale.data_ptr<uint8_t>(),
-            static_cast<uint8_t*>(interleavedBlockScale.data_ptr()), smCount, stream);
+        tensorrt_llm::kernels::invokeNVFP4BlockScaleInterleave(num_experts, rows, rows_padded, cols, cols_padded,
+            blockScale.data_ptr<uint8_t>(), static_cast<uint8_t*>(interleavedBlockScale.data_ptr()), smCount, stream);
     }
     else
     {
-        for (size_t eIdx = 0; eIdx < static_cast<size_t>(num_experts); eIdx++)
+        for (int eIdx = 0; eIdx < static_cast<int>(num_experts); eIdx++)
         {
             uint8_t* interleavedBlockScalePtr
                 = static_cast<uint8_t*>(interleavedBlockScale.data_ptr()) + eIdx * expert_out_size;
-            for (size_t rIdx = 0; rIdx < static_cast<size_t>(rows); ++rIdx)
+            for (int rIdx = 0; rIdx < static_cast<int>(rows_padded); ++rIdx)
             {
                 auto globalRowIdx = eIdx * rows + rIdx;
                 uint8_t* blockScalePtr = blockScale.data_ptr<uint8_t>() + globalRowIdx * cols;
-                for (int cIdx = 0; cIdx < cols; ++cIdx)
+                for (int cIdx = 0; cIdx < static_cast<int>(cols_padded); ++cIdx)
                 {
+                    uint8_t sf_ori = 0;
+                    if (rIdx < static_cast<int>(rows) && cIdx < static_cast<int>(cols))
+                    {
+                        sf_ori = blockScalePtr[cIdx];
+                    }
                     int sf_index
                         = computeSFIndex(rIdx, cIdx, rows, cols, tensorrt_llm::FP4QuantizationSFLayout::SWIZZLED);
-                    interleavedBlockScalePtr[sf_index] = blockScalePtr[cIdx];
+                    interleavedBlockScalePtr[sf_index] = sf_ori;
                 }
             }
         }
@@ -324,41 +336,68 @@ th::Tensor NVFP4BlockScaleInterleave(th::Tensor const& blockScale)
     return interleavedBlockScale;
 }
 
-// Reverse nterleave the weights block scaling factor.
-th::Tensor NVFP4BlockScaleInterleaveReverse(th::Tensor blockScale)
+// Reverse interleave the weights block scaling factor.
+// blockScale: [num_experts, rows, cols] or [rows, cols]
+// Note: rows and cols are the dimensions of the original unswizzled SFMatrix, so reshape input before passing into
+// this function! Return: The same shape as blockScale
+th::Tensor NVFP4BlockScaleInterleaveReverse(th::Tensor const& blockScale)
 {
-    CHECK_CPU_INPUT(blockScale, SF_DTYPE);
+    bool is_cuda = blockScale.device().is_cuda();
+    if (is_cuda)
+    {
+        CHECK_INPUT(blockScale, SF_DTYPE);
+    }
+    else
+    {
+        CHECK_CPU_INPUT(blockScale, SF_DTYPE);
+    }
     auto blockScaleShape = blockScale.sizes();
     TORCH_CHECK(blockScaleShape.size() == 2 || blockScaleShape.size() == 3, "Block Scale should be 2D or 3D tensor.");
     auto num_experts = blockScaleShape.size() == 3 ? blockScaleShape[0] : 1;
     auto rows = blockScaleShape.size() == 3 ? blockScaleShape[1] : blockScaleShape[0];
     auto cols = blockScaleShape.size() == 3 ? blockScaleShape[2] : blockScaleShape[1];
-    auto expert_out_size = tensorrt_llm::computeFP4SwizzledLayoutSFSize(rows, cols);
+    TORCH_CHECK(rows % 128 == 0, "rows of Interleaved block scales should be multiple of 128.");
+    TORCH_CHECK(cols % 4 == 0, "cols of Interleaved block scales should be multiple of 4.");
+    auto expert_out_size = rows * cols;
+    th::Tensor reversedBlockScale
+        = th::empty(blockScaleShape, th::dtype(SF_DTYPE).device(blockScale.device()).requires_grad(false));
 
-    th::Tensor reversedBlockScale = th::zeros(blockScaleShape, th::dtype(SF_DTYPE).requires_grad(false));
-    std::map<int, std::array<int, 3>> identity;
-    for (int eIdx = 0; eIdx < num_experts; eIdx++)
+    if (is_cuda)
     {
-        for (int rIdx = 0; rIdx < rows; ++rIdx)
+        const thread_local int smCount = tensorrt_llm::common::getMultiProcessorCount();
+        auto stream = at::cuda::getCurrentCUDAStream(blockScale.get_device());
+        tensorrt_llm::kernels::invokeNVFP4BlockScaleInterleaveReverse(num_experts, rows, cols,
+            blockScale.data_ptr<uint8_t>(), static_cast<uint8_t*>(reversedBlockScale.data_ptr()), smCount, stream);
+    }
+    else
+    {
+        // index in the swizzled SFMatrix -> (eIdx, rIdx, cIdx) in the unswizzled SFMatrix
+        std::map<int, std::array<int, 3>> identity;
+        for (int eIdx = 0; eIdx < num_experts; eIdx++)
         {
-            for (int cIdx = 0; cIdx < cols; ++cIdx)
+            for (int rIdx = 0; rIdx < rows; ++rIdx)
             {
-                int sf_index = computeSFIndex(rIdx, cIdx, rows, cols, tensorrt_llm::FP4QuantizationSFLayout::SWIZZLED);
-                identity[eIdx * expert_out_size + sf_index] = std::array<int, 3>{eIdx, rIdx, cIdx};
+                for (int cIdx = 0; cIdx < cols; ++cIdx)
+                {
+                    int sf_index
+                        = computeSFIndex(rIdx, cIdx, rows, cols, tensorrt_llm::FP4QuantizationSFLayout::SWIZZLED);
+                    identity[eIdx * expert_out_size + sf_index] = std::array<int, 3>{eIdx, rIdx, cIdx};
+                }
             }
         }
+        uint8_t* blockScalePtr = static_cast<uint8_t*>(blockScale.data_ptr());
+        for (int i = 0; i < expert_out_size * num_experts; i++)
+        {
+            auto loc_2d = identity[i];
+            if (loc_2d[1] < rows && loc_2d[2] < cols)
+            {
+                uint8_t* reversedBlockScalePtr
+                    = reversedBlockScale.data_ptr<uint8_t>() + (loc_2d[0] * rows + loc_2d[1]) * cols + loc_2d[2];
+                *reversedBlockScalePtr = blockScalePtr[i];
+            }
+        }
     }
-    uint8_t* blockScalePtr = static_cast<uint8_t*>(blockScale.data_ptr());
-    for (int i = 0; i < expert_out_size * num_experts; i++)
-    {
-        auto loc_2d = identity[i];
-        if (loc_2d[1] < rows && loc_2d[2] < cols)
-        {
-            uint8_t* reversedBlockScalePtr
-                = reversedBlockScale.data_ptr<uint8_t>() + (loc_2d[0] * rows + loc_2d[1]) * cols + loc_2d[2];
-            *reversedBlockScalePtr = blockScalePtr[i];
-        }
-    }
 
     return reversedBlockScale;
 }
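For readers who prefer Python, the CPU fallback above boils down to recomputing each element's swizzled offset and scattering the byte back into a row-major matrix. The sketch below assumes the same 128x4 layout as the kernel comments; it is illustrative only and is not the code path TensorRT-LLM uses (the real CPU path goes through computeSFIndex, and CUDA inputs use the new invokeNVFP4BlockScaleInterleaveReverse kernel).

import torch

def unswizzle_sf_cpu(swizzled: torch.Tensor, rows: int, sf_cols: int) -> torch.Tensor:
    # Gather each byte of a row-major [rows, sf_cols] SF matrix from its swizzled position.
    sf_cols_padded = ((sf_cols + 3) // 4) * 4
    num_sf_tiles_k = sf_cols_padded // 4
    out = torch.zeros(rows, sf_cols, dtype=torch.uint8)
    flat = swizzled.flatten()
    for r in range(rows):
        for c in range(sf_cols):
            src = (((r // 128) * num_sf_tiles_k + c // 4) * 512
                   + (r % 32) * 16 + ((r % 128) // 32) * 4 + c % 4)
            out[r, c] = flat[src]
    return out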
@@ -30,6 +30,7 @@ _add_trt_llm_dll_directory()
 import sys
 
 import tensorrt_llm.functional as functional
+import tensorrt_llm.math_utils as math_utils
 import tensorrt_llm.models as models
 import tensorrt_llm.quantization as quantization
 import tensorrt_llm.runtime as runtime
@@ -104,6 +105,7 @@ __all__ = [
     'SamplingParams',
     'DisaggregatedParams',
     'KvCacheConfig',
+    'math_utils',
     '__version__',
 ]
 
@@ -8,6 +8,7 @@ from typing import Dict, List
 import torch
 
 from tensorrt_llm._utils import TensorWrapper, convert_to_torch_tensor
+from tensorrt_llm.math_utils import ceil_div, pad_up
 from tensorrt_llm.quantization.utils import fp4_utils
 
 is_torch_compiling_flag = False
@@ -106,75 +107,85 @@ def disable_fp4_allgather():
 
 
 def compute_swizzled_sf_shape(row: int, col: int):
-    padded_row = (row + 128 - 1) // 128 * 128
-    padded_col = (col + 4 - 1) // 4 * 4
+    padded_row = pad_up(row, 128)
+    padded_col = pad_up(col, 4)
     return padded_row, padded_col
 
 
 def swizzle_sf(sf: torch.Tensor,
-               row: int,
-               col: int,
+               rows: int,
+               cols: int,
                scaling_vector_size: int = 16):
-    factor = scaling_vector_size * 4
-    num_m_tiles = (row + 128 - 1) // 128
-    num_k_tiles = (col + factor - 1) // factor
-    # SF layout [num_m_tiles, num_k_tiles, 32 (m_tile column major), 4 (m_tile column major), 4(k_tile)]
-    sf_full = torch.zeros(num_m_tiles * 32 * 4,
-                          num_k_tiles * 4,
-                          dtype=sf.dtype,
-                          device=sf.device)
-    sf_full[:row, :(col //
-                    scaling_vector_size)] = sf[:row, :(col //
-                                                       scaling_vector_size)]
-    sf_full_reshaped = sf_full.view(num_m_tiles, 4, 32, num_k_tiles, 4)
-    sf_full_swizzle = sf_full_reshaped.transpose(1, 3)
-    sf_swizzle = sf_full_swizzle.reshape(-1)
-    return sf_swizzle
+    """Swizzle FP4 scaling factors using C++ torch op implementation
+    Args:
+        sf: [b, rows, cols_sf] or [rows, cols_sf]. The original unswizzled scaling factors.
+        rows: rows of the original unquantized tensor
+        cols_sf: ceil_div(cols, scaling_vector_size) where cols is the number of columns of the original unquantized tensor
+        scaling_vector_size: the size of the scaling vector
+    Returns:
+        [b * pad_up(rows, 128) * pad_up(cols_sf, 4), ] 1D swizzled scaling factors, possibly with rows and cols padded.
+    """
+    sf_cols = ceil_div(cols, scaling_vector_size)
+    sf = sf.view(-1, rows, sf_cols)
+    return torch.ops.tensorrt_llm.nvfp4_block_scale_interleave(sf)
 
 
 def unswizzle_sf(sf: torch.Tensor,
-                 row: int,
-                 col: int,
+                 rows: int,
+                 cols: int,
                  scaling_vector_size: int = 16):
-    factor = scaling_vector_size * 4
-    num_m_tiles = (row + 128 - 1) // 128
-    num_k_tiles = (col + factor - 1) // factor
-    # SF layout [num_m_tiles, num_k_tiles, 32 (m_tile column major), 4 (m_tile column major), 4(k_tile)]
-    sf_reshaped = sf.view(num_m_tiles, num_k_tiles, 32, 4, 4)
-    sf_unswizzle = sf_reshaped.transpose(1, 3)
-    sf_unswizzle = sf_unswizzle.reshape(num_m_tiles * 32 * 4, num_k_tiles * 4)
-    sf_unswizzle_sliced = sf_unswizzle[:row, :(col // scaling_vector_size)]
-    return sf_unswizzle_sliced.contiguous()
+    """Swizzle FP4 scaling factors using C++ torch op implementation
+    Args:
+        sf: The (padded and) swizzled scaling factors.
+        rows: rows of the original unquantized tensor
+        cols: cols of the original unquantized tensor
+        scaling_vector_size: the size of the scaling vector
+    Returns:
+        2D unswizzled scaling factors
+    """
+    sf_cols = ceil_div(cols, scaling_vector_size)
+    sf = sf.view(-1, rows, sf_cols)
+    return torch.ops.tensorrt_llm.nvfp4_block_scale_interleave_reverse(sf).view(
+        -1, sf_cols)
 
 
 def reswizzle_sf(sf: torch.Tensor,
-                 row: int,
-                 col: int,
+                 rows: int,
+                 cols: int,
                  scaling_vector_size: int = 16):
-    factor = scaling_vector_size * 4
-    num_m_tiles = (row + 128 - 1) // 128
-    num_k_tiles = (col + factor - 1) // factor
-    partition_size = num_m_tiles * num_k_tiles * 32 * 4 * 4
-    num_partitions = sf.numel() // partition_size
-    sf_reshaped = sf.view(num_partitions, num_m_tiles, num_k_tiles, 32, 4, 4)
-    sf_unswizzle = sf_reshaped.transpose(2, 4)
-    sf_unswizzle = sf_unswizzle.reshape(num_partitions, num_m_tiles * 32 * 4,
-                                        num_k_tiles * 4)
-    total_rows = num_partitions * row
-    num_m_tiles_out = (total_rows + 128 - 1) // 128
-    sf_out = torch.zeros(
-        num_m_tiles_out,
-        4,
-        32,
-        num_k_tiles,
-        4,
-        dtype=sf.dtype,
-        device=sf.device,
-    )
-    sf_out_reshaped = sf_out.view(num_m_tiles_out * 32 * 4, num_k_tiles * 4)
-    sf_out_reshaped[:total_rows] = sf_unswizzle[:, :row].reshape(total_rows, -1)
-    sf_out_swizzle = sf_out.transpose(1, 3).reshape(-1)
-    return sf_out_swizzle
+    """Reswizzle FP4 scaling factors using C++ torch op implementation.
+    It unswizzles the scaling factors in each partition first, then concatenates them together, and finally swizzles them back.
+    Args:
+        sf: The (padded and) swizzled scaling factors.
+        rows: rows of the original unquantized tensor
+        cols: cols of the original unquantized tensor
+        scaling_vector_size: the size of the scaling vector
+    Returns:
+        1D reswizzled scaling factors
+    """
+    sf_cols = ceil_div(cols, scaling_vector_size)
+    padded_rows, padded_sf_cols = compute_swizzled_sf_shape(rows, sf_cols)
+    padded_cols = padded_sf_cols * scaling_vector_size
+
+    assert sf.numel() % (padded_rows * padded_sf_cols) == 0
+    num_partitions = sf.numel() // (padded_rows * padded_sf_cols)
+
+    sf_reshaped = sf.view(num_partitions, padded_rows, padded_sf_cols)
+
+    # Unswizzle each partition
+    sf_unswizzled = unswizzle_sf(sf_reshaped, padded_rows, padded_cols,
+                                 scaling_vector_size)
+
+    # Brings the unswizzled scaling factors in each partition together
+    total_rows = num_partitions * rows
+    sf_unswizzled = sf_unswizzled.view(num_partitions, padded_rows,
+                                       padded_sf_cols)
+    sf_concatenated = sf_unswizzled[:, :rows, :sf_cols].contiguous(
+    )  # TODO: This will incur a elementwise kernel
+    sf_concatenated = sf_concatenated.view(total_rows, sf_cols)
+
+    # Finally swizzle the concatenated scaling factors
+    return swizzle_sf(sf_concatenated, total_rows, cols, scaling_vector_size)
 
 
 def next_positive_power_of_2(x: int) -> int:
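The net effect of the Python-side change is that swizzle_sf, unswizzle_sf and reswizzle_sf now delegate to the C++ torch ops and return padded buffers. Below is a usage sketch based on the docstrings and on the round-trip unit test further down; it assumes a Blackwell-class GPU and a TensorRT-LLM build in which these torch ops are registered.

import torch

from tensorrt_llm._torch.utils import swizzle_sf, unswizzle_sf
from tensorrt_llm.math_utils import ceil_div, pad_up

rows, cols, vec = 100, 96, 16                # shape of the unquantized tensor
sf_cols = ceil_div(cols, vec)                # one scaling factor per 16 elements
sf = torch.randint(0, 256, (rows, sf_cols), dtype=torch.uint8, device="cuda")

swizzled = swizzle_sf(sf, rows, cols, vec)   # 1D, padded to 128 rows x 4 SF columns
assert swizzled.numel() == pad_up(rows, 128) * pad_up(sf_cols, 4)

# unswizzle_sf expects the padded dimensions that swizzle_sf produced
restored = unswizzle_sf(swizzled, pad_up(rows, 128), pad_up(sf_cols, 4) * vec, vec)
assert torch.equal(restored[:rows, :sf_cols], sf)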
tensorrt_llm/math_utils.py (new file, 20 lines)
@@ -0,0 +1,20 @@
+# SPDX-FileCopyrightText: Copyright (c) 2022-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+def ceil_div(x: int, y: int) -> int:
+    return (x + y - 1) // y
+
+
+def pad_up(x: int, y: int) -> int:
+    return ((x + y - 1) // y) * y
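The new tensorrt_llm.math_utils helpers centralize the rounding that used to be re-implemented inline (for example in compute_swizzled_sf_shape above). A quick sanity check of their semantics:

from tensorrt_llm.math_utils import ceil_div, pad_up

assert ceil_div(96, 16) == 6     # scaling factors per row for 96 columns
assert ceil_div(100, 16) == 7    # rounds up
assert pad_up(100, 128) == 128   # next multiple of 128 (swizzled SF rows)
assert pad_up(6, 4) == 8         # next multiple of 4 (swizzled SF columns)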
@@ -6,11 +6,6 @@ import torch
 SF_DTYPE = torch.uint8
 FLOAT4_E2M1X2 = torch.uint8
 
-
-def pad_up(x: int, y: int) -> int:
-    return ((x + y - 1) // y) * y
-
-
 # For GEMM autotuning.
 # Taken from https://github.com/NVIDIA/TensorRT-LLM/blob/main/cpp/include/tensorrt_llm/runtime//modelConfig.h#L38
 # TODO: move to model config, tune for blackwell hardware
@@ -24,6 +19,10 @@ fp4_buckets = FP4_BUCKETS
 __all__ = ['float4_e2m1x2', 'float4_sf_dtype', 'pad_up', 'fp4_buckets']
 
 
+def pad_up(x: int, y: int) -> int:
+    return ((x + y - 1) // y) * y
+
+
 class FP4GemmType(IntEnum):
     W4A4_NVFP4_NVFP4 = 0
     W4A8_MXFP4_MXFP8 = 1
tests/unittest/_torch/thop/test_fp4_swizzle.py (new file, 226 lines)
@@ -0,0 +1,226 @@
+import pytest
+import torch
+from utils.util import skip_pre_blackwell
+
+import tensorrt_llm.quantization.utils.fp4_utils as fp4_utils
+from tensorrt_llm._torch.utils import (compute_swizzled_sf_shape, reswizzle_sf,
+                                       swizzle_sf, unswizzle_sf)
+from tensorrt_llm.math_utils import ceil_div
+
+
+# Reference PyTorch implementations (original)
+def swizzle_sf_ref(sf: torch.Tensor,
+                   row: int,
+                   col: int,
+                   scaling_vector_size: int = 16):
+    """Reference PyTorch implementation of swizzle_sf"""
+    col_sf = ceil_div(col, scaling_vector_size)
+    num_m_tiles = ceil_div(row, 128)
+    num_k_tiles = ceil_div(col_sf, 4)
+    # SF layout [num_m_tiles, num_k_tiles, 32 (m_tile column major), 4 (m_tile column major), 4(k_tile)]
+    sf_full = torch.zeros(num_m_tiles * 32 * 4,
+                          num_k_tiles * 4,
+                          dtype=sf.dtype,
+                          device=sf.device)
+    sf_full[:row, :col_sf] = sf[:row, :col_sf]
+    sf_full_reshaped = sf_full.view(num_m_tiles, 4, 32, num_k_tiles, 4)
+    sf_full_swizzle = sf_full_reshaped.transpose(1, 3)
+    sf_swizzle = sf_full_swizzle.reshape(-1)
+    return sf_swizzle
+
+
+def unswizzle_sf_ref(sf: torch.Tensor,
+                     row: int,
+                     col: int,
+                     scaling_vector_size: int = 16):
+    """Reference PyTorch implementation of unswizzle_sf"""
+    cols_sf = ceil_div(col, scaling_vector_size)
+    num_m_tiles = ceil_div(row, 128)
+    num_k_tiles = ceil_div(cols_sf, 4)
+    # SF layout [num_m_tiles, num_k_tiles, 32 (m_tile column major), 4 (m_tile column major), 4(k_tile)]
+    sf_reshaped = sf.view(num_m_tiles, num_k_tiles, 32, 4, 4)
+    sf_unswizzle = sf_reshaped.transpose(1, 3)
+    sf_unswizzle = sf_unswizzle.reshape(num_m_tiles * 32 * 4, num_k_tiles * 4)
+    return sf_unswizzle.contiguous()
+
+
+def reswizzle_sf_ref(sf: torch.Tensor,
+                     row: int,
+                     col: int,
+                     scaling_vector_size: int = 16):
+    """Reference PyTorch implementation of reswizzle_sf"""
+    cols_sf = ceil_div(col, scaling_vector_size)
+    num_m_tiles = ceil_div(row, 128)
+    num_k_tiles = ceil_div(cols_sf, 4)
+    partition_size = num_m_tiles * num_k_tiles * 32 * 4 * 4
+    num_partitions = sf.numel() // partition_size
+    sf_reshaped = sf.view(num_partitions, num_m_tiles, num_k_tiles, 32, 4, 4)
+    sf_unswizzle = sf_reshaped.transpose(2, 4)
+    sf_unswizzle = sf_unswizzle.reshape(num_partitions, num_m_tiles * 32 * 4,
+                                        num_k_tiles * 4)
+    total_rows = num_partitions * row
+    num_m_tiles_out = ceil_div(total_rows, 128)
+    sf_out = torch.zeros(
+        num_m_tiles_out,
+        4,
+        32,
+        num_k_tiles,
+        4,
+        dtype=sf.dtype,
+        device=sf.device,
+    )
+    sf_out_reshaped = sf_out.view(num_m_tiles_out * 32 * 4, num_k_tiles * 4)
+    sf_out_reshaped[:total_rows] = sf_unswizzle[:, :row].reshape(total_rows, -1)
+    sf_out_swizzle = sf_out.transpose(1, 3).reshape(-1)
+    return sf_out_swizzle
+
+
+@skip_pre_blackwell
+@pytest.mark.parametrize(
+    "rows,cols",
+    [
+        (1, 1),  # test padding
+        (1, 16),
+        (1, 63),
+        (127, 1),
+        (127, 16),
+        (127, 63),
+        (128, 64),  # 1x1 tiles
+        (128, 128),  # 1×2 tiles
+        (256, 64),  # 2×1 tiles
+        (512, 256),  # 4×4 tiles
+    ])
+def test_swizzle_sf(rows, cols):
+    """Test C++ swizzle_sf against PyTorch reference implementation"""
+    scaling_vector_size = 16
+    sf_cols = ceil_div(cols, scaling_vector_size)
+
+    # Create scaling factor data using fp4_sf_dtype
+    sf_data = torch.randint(0,
+                            256, (rows * sf_cols, ),
+                            dtype=fp4_utils.float4_sf_dtype,
+                            device="cuda").view(rows, sf_cols)
+
+    # Apply reference implementation
+    ref_result = swizzle_sf_ref(sf_data, rows, cols, scaling_vector_size)
+
+    # Apply C++ implementation
+    result = swizzle_sf(sf_data, rows, cols, scaling_vector_size)
+
+    # Verify results are equivalent
+    torch.testing.assert_close(result, ref_result)
+
+
+@skip_pre_blackwell
+@pytest.mark.parametrize(
+    "rows,cols",
+    [
+        (128, 64),  # 1x1 tiles
+        (128, 128),  # 1×2 tiles
+        (256, 64),  # 2×1 tiles
+        (512, 256),  # 4×4 tiles
+    ])
+def test_unswizzle_sf(rows, cols):
+    """Test C++ unswizzle_sf against PyTorch reference implementation"""
+    scaling_vector_size = 16
+    sf_cols = ceil_div(cols, scaling_vector_size)
+
+    # Create scaling factor data by first swizzling with reference implementation
+    original_sf_data = torch.randint(0,
+                                     256, (rows * sf_cols, ),
+                                     dtype=fp4_utils.float4_sf_dtype,
+                                     device="cuda").view(rows, sf_cols)
+    swizzled_sf_data = swizzle_sf_ref(original_sf_data, rows, cols,
+                                      scaling_vector_size)
+    # Apply reference unswizzle
+    ref_result = unswizzle_sf_ref(swizzled_sf_data, rows, cols,
+                                  scaling_vector_size)
+
+    # Note that unlike swizzle_sf, unswizzle_sf does not return a 1D tensor
+    result = unswizzle_sf(swizzled_sf_data, rows, cols, scaling_vector_size)
+
+    # Verify C++ result matches reference result
+    torch.testing.assert_close(result, ref_result)
+
+
+@skip_pre_blackwell
+@pytest.mark.parametrize(
+    "rows,cols",
+    [
+        (1, 1),  # test padding
+        (1, 16),
+        (1, 63),
+        (127, 1),
+        (127, 16),
+        (127, 63),
+        (128, 64),  # 1x1 tiles
+        (128, 128),  # 1×2 tiles
+        (256, 64),  # 2×1 tiles
+        (512, 256),  # 4×4 tiles
+    ])
+def test_swizzle_round_trip(rows, cols):
+    """Test that swizzle/unswizzle operations are inverse of each other"""
+    scaling_vector_size = 16
+    sf_cols = ceil_div(cols, scaling_vector_size)
+
+    # Create scaling factor data
+    original_sf_data = torch.randint(0,
+                                     256, (rows * sf_cols, ),
+                                     dtype=fp4_utils.float4_sf_dtype,
+                                     device="cuda")
+
+    # Apply swizzle then unswizzle using the utils functions
+    swizzled_sf = swizzle_sf(original_sf_data, rows, cols, scaling_vector_size)
+
+    padded_rows, padded_sf_cols = compute_swizzled_sf_shape(rows, sf_cols)
+
+    # Check that the swizzled scaling factor data is padded correctly
+    assert padded_rows * padded_sf_cols == swizzled_sf.numel()
+
+    padded_cols = padded_sf_cols * scaling_vector_size
+    unswizzled_sf = unswizzle_sf(swizzled_sf, padded_rows, padded_cols,
+                                 scaling_vector_size)
+    unswizzled_sf = unswizzled_sf[:rows, :sf_cols]
+
+    # Verify round-trip preserves original scaling factor data
+    torch.testing.assert_close(original_sf_data.view(rows, sf_cols),
+                               unswizzled_sf[:rows, :sf_cols])
+
+
+@skip_pre_blackwell
+@pytest.mark.parametrize("num_partitions", [1, 2, 3])
+@pytest.mark.parametrize(
+    "rows,cols",
+    [
+        (1, 1),  # test padding
+        (1, 16),
+        (1, 63),
+        (127, 1),
+        (127, 16),
+        (127, 63),
+        (128, 64),  # 1x1 tiles
+        (128, 128),  # 1×2 tiles
+        (256, 64),  # 2×1 tiles
+        (512, 256),  # 4×4 tiles
+    ])
+def test_reswizzle_sf(num_partitions, rows, cols):
+    """Test C++ reswizzle_sf against PyTorch reference implementation"""
+    scaling_vector_size = 16
+    sf_cols = ceil_div(cols, scaling_vector_size)
+
+    original_sf_data = torch.randint(0,
+                                     256, (num_partitions, rows, sf_cols),
+                                     dtype=fp4_utils.float4_sf_dtype,
+                                     device="cuda")
+    swizzled_sf_data = swizzle_sf(original_sf_data, rows, cols,
+                                  scaling_vector_size)
+
+    # Apply reference reswizzle
+    ref_result = reswizzle_sf_ref(swizzled_sf_data, rows, cols,
+                                  scaling_vector_size)
+
+    # Apply C++ reswizzle
+    result = reswizzle_sf(swizzled_sf_data, rows, cols, scaling_vector_size)
+
+    # Verify results are equivalent
+    torch.testing.assert_close(result, ref_result)
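The new unit tests compare the C++ torch ops against the previous pure-PyTorch implementations (kept here as the *_ref functions) and exercise the padding paths. They can be run with pytest (for example, pytest tests/unittest/_torch/thop/test_fp4_swizzle.py) on a Blackwell-class GPU; the skip_pre_blackwell marker skips them on older architectures.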