TensorRT-LLM/cpp/tensorrt_llm/thop/fp4Op.cpp

/*
* Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "tensorrt_llm/common/cudaUtils.h"
#include "tensorrt_llm/kernels/quantization.h"
#include "tensorrt_llm/thop/fp8Op.h"
#include "tensorrt_llm/thop/thUtils.h"
#include <cuda_fp16.h>
#include <cuda_fp4.h>
#include <cuda_fp8.h>
#include <algorithm>
#include <array>
#include <cmath>
#include <cstdint>
#include <cstring>
#include <map>
namespace th = torch;
TRTLLM_NAMESPACE_BEGIN
namespace torch_ext
{
static int getExp(float v)
{
int vIntRepr;
memcpy(&vIntRepr, &v, sizeof(vIntRepr));
int expBits = (vIntRepr >> 23) & 0xff;
return expBits - 127;
}
static int getMantissaBits(float v)
{
int vIntRepr;
memcpy(&vIntRepr, &v, sizeof(vIntRepr));
int mantissaBits = vIntRepr & 0x7fffff;
return mantissaBits;
}
static bool getSign(float v)
{
int vIntRepr;
memcpy(&vIntRepr, &v, sizeof(vIntRepr));
return vIntRepr >> 31;
}
static float makeExpFloat(int expValue)
{
expValue += 127;
expValue <<= 23;
float vFloat;
memcpy(&vFloat, &expValue, sizeof(vFloat));
return vFloat;
}
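// Worked example for the bit-manipulation helpers above (editorial note): 6.0f has bit pattern 0x40C00000, so
// getExp returns 129 - 127 = 2 and getMantissaBits returns 0x400000; makeExpFloat(2) rebuilds 0x40800000 = 4.0f,
// i.e. 2^2 with a zero mantissa.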
/*
* E2M1 to float
* 0111 -> 6
* 0110 -> 4
* 0101 -> 3
* 0100 -> 2
* 0011 -> 1.5
* 0010 -> 1
* 0001 -> 0.5
* 0000 -> 0
*/
static float const kE2M1ToFloatArray[] = {0, 0.5, 1, 1.5, 2, 3, 4, 6};
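// Rounding midpoints between consecutive E2M1 magnitudes; floatToE2M1 below uses them to round, with ties broken
// toward the even code.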
static float const kE2M1Array[] = {0, 0.25, 0.75, 1.25, 1.75, 2.5, 3.5, 5};
static int const kE2M1Count = 8;
uint8_t floatToE2M1(float value)
{
float absValue = fabs(value);
TORCH_CHECK_LT(absValue, 8.0f);
uint8_t result = getSign(value) ? 8 : 0;
int fp4AbsValue = kE2M1Count - 1;
for (; fp4AbsValue > 0; --fp4AbsValue)
{
if (kE2M1Array[fp4AbsValue] < absValue)
break;
// Tie to even.
if (kE2M1Array[fp4AbsValue] == absValue && !(fp4AbsValue & 1))
break;
}
result |= fp4AbsValue;
return result;
}
float e2M1ToFloat(uint8_t value)
{
bool signBit = value & 8;
auto absValue = value & 7;
float result = kE2M1ToFloatArray[absValue];
if (signBit)
result = -result;
return result;
}
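// Editorial sanity-check sketch (not part of the original op): floatToE2M1 and e2M1ToFloat are inverses on the
// 16 E2M1 codes, and ties round toward the even code.
[[maybe_unused]] static void checkE2M1RoundTripSketch()
{
    for (uint8_t code = 0; code < 16; ++code)
    {
        // -0.0f (code 8) keeps its sign bit, so every code round-trips exactly.
        TORCH_CHECK(floatToE2M1(e2M1ToFloat(code)) == code);
    }
    TORCH_CHECK(floatToE2M1(2.5f) == 0x4);   // halfway between 2 and 3 -> even code 0b0100, i.e. 2.0f
    TORCH_CHECK(floatToE2M1(-0.75f) == 0xA); // sign bit set; 0.75 ties between 0.5 and 1.0 -> even code 0b0010
}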
torch::autograd::variable_list FloatToE2M1AndUFP8SFScale(
th::Tensor floatTensor, int64_t sfVecSize, int64_t sfType, torch::optional<bool> isSfSwizzledLayout)
{
CHECK_CPU_INPUT(floatTensor, th::kFloat32);
auto inputShape = floatTensor.sizes();
TORCH_CHECK(inputShape.size() == 2, "Input should be 2D tensor.");
TORCH_CHECK(inputShape[1] % sfVecSize == 0);
th::Tensor valueE2M1 = th::zeros({inputShape[0], inputShape[1] / 2}, th::dtype(FLOAT4_E2M1X2).requires_grad(false));
th::Tensor scaleFP8SF
= th::zeros({tensorrt_llm::computeSwizzledLayoutSFSize(inputShape[0], inputShape[1] / sfVecSize)},
th::dtype(SF_DTYPE).requires_grad(false));
th::Tensor repFloat = th::zeros(inputShape, th::dtype(th::kFloat32).requires_grad(false));
int hiddenDim = inputShape[1];
int packedFp4HiddenDim = hiddenDim / 2;
int groupsPerHiddenDim = hiddenDim / sfVecSize;
// Note: if isSfSwizzledLayout is provided, use its value; otherwise default to true.
tensorrt_llm::QuantizationSFLayout layout = isSfSwizzledLayout.value_or(true)
? tensorrt_llm::QuantizationSFLayout::SWIZZLED
: tensorrt_llm::QuantizationSFLayout::LINEAR;
for (size_t vIdx = 0; vIdx < static_cast<size_t>(inputShape[0]); ++vIdx)
{
for (int group = 0; group < groupsPerHiddenDim; ++group)
{
float const* inputPtr = floatTensor.data_ptr<float>() + vIdx * hiddenDim + group * sfVecSize;
float* repPtr = repFloat.data_ptr<float>() + vIdx * hiddenDim + group * sfVecSize;
uint8_t* packedFp4Ptr = valueE2M1.data_ptr<uint8_t>() + vIdx * packedFp4HiddenDim + group * sfVecSize / 2;
int8_t* scaleFP8SFPtr = static_cast<int8_t*>(scaleFP8SF.data_ptr());
float maxAbsValue = 0.0f;
for (int i = 0; i < sfVecSize; ++i)
{
maxAbsValue = std::max(maxAbsValue, fabs(inputPtr[i]));
}
int scaleExp = getExp(maxAbsValue);
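// Use 2^(exp(maxAbsValue) - 2) as the group scale so the scaled maximum lands in [4, 8): below floatToE2M1's
// hard limit of 8, with anything above 5 rounding to the top E2M1 magnitude, 6. The clamps below keep the
// exponent inside the range of the chosen scale-factor format.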
scaleExp -= 2;
if (sfType == 0)
{
scaleExp = std::max(scaleExp, -126);
int e8M0Scale = scaleExp + 127;
TORCH_CHECK_GT(e8M0Scale, 0);
TORCH_CHECK_LT(e8M0Scale, 255);
int8_t e8M0ScaleOut = e8M0Scale & 0xff;
scaleFP8SFPtr[computeSFIndex(vIdx, group, inputShape[0], groupsPerHiddenDim, layout)] = e8M0ScaleOut;
}
else
{
scaleExp = std::max(scaleExp, -6);
int e4M3Scale = scaleExp + 7;
TORCH_CHECK_GT(e4M3Scale, 0);
TORCH_CHECK_LT(e4M3Scale, 15);
int8_t e4M3ScaleOut = (e4M3Scale & 0xff) << 3;
scaleFP8SFPtr[computeSFIndex(vIdx, group, inputShape[0], groupsPerHiddenDim, layout)] = e4M3ScaleOut;
}
float scaleFloat = makeExpFloat(scaleExp);
float invScaleFloat = 1.0 / scaleFloat;
// printf("vIdx=%ld, group=%d, maxAbsValue=%f, scaleExp=%d, e8M0Scale=%d, scaleFloat=%f,
// invScaleFloat=%f\n",
// vIdx, group, maxAbsValue, scaleExp, e8M0Scale, scaleFloat, invScaleFloat);
for (int i = 0; i < sfVecSize; ++i)
{
float value = inputPtr[i];
float scaledValue = invScaleFloat * value;
uint8_t fp4Value = floatToE2M1(scaledValue);
float e2M1FloatValue = e2M1ToFloat(fp4Value);
float repResult = e2M1FloatValue * scaleFloat;
repPtr[i] = repResult;
uint8_t packedValue = packedFp4Ptr[i / 2];
if (i % 2 == 0)
{
packedValue &= 0xf0;
packedValue |= fp4Value;
}
else
{
packedValue &= 0x0f;
packedValue |= (fp4Value << 4);
}
packedFp4Ptr[i / 2] = packedValue;
// printf(" i=%d, value=%f, scaledValue=%f, fp4Value=%x, e2M1FloatValue=%f, repResult=%f\n",
// i, value, scaledValue, (int)fp4Value, e2M1FloatValue, repResult);
}
}
}
return {valueE2M1, scaleFP8SF, repFloat};
}
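// Editorial usage sketch for FloatToE2M1AndUFP8SFScale: for a CPU float32 input of shape [128, 64] with
// sfVecSize = 16, it returns valueE2M1 of shape [128, 32] (two FP4 values per byte), a flat FP8 scale-factor
// tensor with one scale per 16-element group (swizzled or linear order per isSfSwizzledLayout), and repFloat of
// shape [128, 64] holding the dequantized values for inspecting the quantization error.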
// Preprocess the weights.
torch::autograd::variable_list HalfToE2M1AndUFP8SFScale(
th::Tensor halfTensor, th::Tensor globalScale, int64_t sfVecSize, int64_t sfType)
{
CHECK_INPUT(halfTensor, th::kFloat16);
CHECK_INPUT(globalScale, th::kFloat32);
auto inputShape = halfTensor.sizes();
TORCH_CHECK(inputShape.size() == 2 || inputShape.size() == 3, "Input should be 2D or 3D tensor.");
bool has_experts = inputShape.size() == 3;
auto num_experts = has_experts ? inputShape[0] : 1;
auto rows = has_experts ? inputShape[1] : inputShape[0];
auto cols = has_experts ? inputShape[2] : inputShape[1];
auto const expert_sf_size = tensorrt_llm::computeSwizzledLayoutSFSize(rows, cols / sfVecSize);
TORCH_CHECK(cols % sfVecSize == 0);
std::array<int64_t, 3> shape{num_experts, rows, cols / 2};
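// Skip the leading expert dimension of `shape` when the input is 2-D, so valueE2M1 matches the input's rank.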
th::IntArrayRef shape_ref(shape.data() + !has_experts, shape.size() - !has_experts);
th::Tensor valueE2M1 = th::zeros(shape_ref, th::dtype(FLOAT4_E2M1X2).device(torch::kCUDA).requires_grad(false));
th::Tensor scaleFP8SF
= th::zeros({num_experts * expert_sf_size}, th::dtype(SF_DTYPE).device(torch::kCUDA).requires_grad(false));
int const mMultiProcessorCount = tensorrt_llm::common::getMultiProcessorCount();
for (size_t eIdx = 0; eIdx < static_cast<size_t>(num_experts); eIdx++)
{
size_t const expert_elem_offset = rows * cols * eIdx;
size_t const expert_sf_offset = expert_sf_size * eIdx;
constexpr int FP4_PER_INT64 = 16;
constexpr int FP8_PER_INT32 = 4;
tensorrt_llm::kernels::invokeFP4Quantization(1, rows, cols,
reinterpret_cast<half*>(halfTensor.data_ptr()) + expert_elem_offset, globalScale.data_ptr<float>() + eIdx,
reinterpret_cast<int64_t*>(valueE2M1.data_ptr()) + expert_elem_offset / FP4_PER_INT64,
reinterpret_cast<int32_t*>(scaleFP8SF.data_ptr()) + expert_sf_offset / FP8_PER_INT32, sfType == 0,
tensorrt_llm::QuantizationSFLayout::SWIZZLED, mMultiProcessorCount, 0);
}
return {valueE2M1, scaleFP8SF};
}
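// Editorial note on HalfToE2M1AndUFP8SFScale: unlike the CPU reference above, this path quantizes through the
// invokeFP4Quantization kernel, issuing one call per expert and reading one globalScale entry per expert.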
// Interleave (and possibly pad) the weights block scaling factor.
// blockScale: [num_experts, rows, cols] or [rows, cols]
// Return: a flat tensor of num_experts * pad_up(rows, 128) * pad_up(cols, 4) elements.
th::Tensor BlockScaleInterleave(th::Tensor const& blockScale)
{
bool is_cuda = blockScale.device().is_cuda();
if (is_cuda)
{
CHECK_INPUT(blockScale, SF_DTYPE);
}
else
{
CHECK_CPU_INPUT(blockScale, SF_DTYPE);
}
auto blockScaleShape = blockScale.sizes();
TORCH_CHECK(blockScaleShape.size() == 2 || blockScaleShape.size() == 3, "Block Scale should be 2D or 3D tensor.");
auto num_experts = blockScaleShape.size() == 3 ? blockScaleShape[0] : 1;
auto rows = blockScaleShape.size() == 3 ? blockScaleShape[1] : blockScaleShape[0];
auto cols = blockScaleShape.size() == 3 ? blockScaleShape[2] : blockScaleShape[1];
auto expert_out_size = tensorrt_llm::computeSwizzledLayoutSFSize(rows, cols);
auto rows_padded = PadUpFn(rows, 128);
auto cols_padded = PadUpFn(cols, 4);
TORCH_CHECK(
expert_out_size == rows_padded * cols_padded, "expert_out_size should be equal to rows_padded * cols_padded.");
th::Tensor interleavedBlockScale = th::empty(
{expert_out_size * num_experts}, th::dtype(SF_DTYPE).device(blockScale.device()).requires_grad(false));
if (is_cuda)
{
const thread_local int smCount = tensorrt_llm::common::getMultiProcessorCount();
auto stream = at::cuda::getCurrentCUDAStream(blockScale.get_device());
tensorrt_llm::kernels::invokeBlockScaleInterleave(num_experts, rows, rows_padded, cols, cols_padded,
blockScale.data_ptr<uint8_t>(), static_cast<uint8_t*>(interleavedBlockScale.data_ptr()), smCount, stream);
}
else
{
for (int eIdx = 0; eIdx < static_cast<int>(num_experts); eIdx++)
{
uint8_t* interleavedBlockScalePtr
= static_cast<uint8_t*>(interleavedBlockScale.data_ptr()) + eIdx * expert_out_size;
for (int rIdx = 0; rIdx < static_cast<int>(rows_padded); ++rIdx)
{
auto globalRowIdx = eIdx * rows + rIdx;
uint8_t* blockScalePtr = blockScale.data_ptr<uint8_t>() + globalRowIdx * cols;
for (int cIdx = 0; cIdx < static_cast<int>(cols_padded); ++cIdx)
{
uint8_t sf_ori = 0;
if (rIdx < static_cast<int>(rows) && cIdx < static_cast<int>(cols))
{
sf_ori = blockScalePtr[cIdx];
}
int sf_index = computeSFIndex(rIdx, cIdx, rows, cols, tensorrt_llm::QuantizationSFLayout::SWIZZLED);
interleavedBlockScalePtr[sf_index] = sf_ori;
}
}
}
}
return interleavedBlockScale;
}
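// Editorial example for BlockScaleInterleave: a [3, 100, 6] block-scale tensor (3 experts) is padded per expert
// to 128 x 8 and swizzled, so the result is a flat uint8 tensor of 3 * 128 * 8 = 3072 elements; the CPU path
// writes 0 into the padded slots.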
// Reverse interleave the weights block scaling factor.
// blockScale: [num_experts, rows, cols] or [rows, cols]
// Note: rows and cols are the dimensions of the original unswizzled SFMatrix, so reshape the input before passing
// it into this function!
// Return: the same shape as blockScale.
th::Tensor BlockScaleInterleaveReverse(th::Tensor const& blockScale)
{
bool is_cuda = blockScale.device().is_cuda();
if (is_cuda)
{
CHECK_INPUT(blockScale, SF_DTYPE);
}
else
{
CHECK_CPU_INPUT(blockScale, SF_DTYPE);
}
auto blockScaleShape = blockScale.sizes();
TORCH_CHECK(blockScaleShape.size() == 2 || blockScaleShape.size() == 3, "Block Scale should be 2D or 3D tensor.");
auto num_experts = blockScaleShape.size() == 3 ? blockScaleShape[0] : 1;
auto rows = blockScaleShape.size() == 3 ? blockScaleShape[1] : blockScaleShape[0];
auto cols = blockScaleShape.size() == 3 ? blockScaleShape[2] : blockScaleShape[1];
TORCH_CHECK(rows % 128 == 0, "rows of Interleaved block scales should be multiple of 128.");
TORCH_CHECK(cols % 4 == 0, "cols of Interleaved block scales should be multiple of 4.");
auto expert_out_size = rows * cols;
th::Tensor reversedBlockScale
= th::empty(blockScaleShape, th::dtype(SF_DTYPE).device(blockScale.device()).requires_grad(false));
if (is_cuda)
{
const thread_local int smCount = tensorrt_llm::common::getMultiProcessorCount();
auto stream = at::cuda::getCurrentCUDAStream(blockScale.get_device());
tensorrt_llm::kernels::invokeBlockScaleInterleaveReverse(num_experts, rows, cols,
blockScale.data_ptr<uint8_t>(), static_cast<uint8_t*>(reversedBlockScale.data_ptr()), smCount, stream);
}
else
{
// index in the swizzled SFMatrix -> (eIdx, rIdx, cIdx) in the unswizzled SFMatrix
std::map<int, std::array<int, 3>> identity;
for (int eIdx = 0; eIdx < num_experts; eIdx++)
{
for (int rIdx = 0; rIdx < rows; ++rIdx)
{
for (int cIdx = 0; cIdx < cols; ++cIdx)
{
int sf_index = computeSFIndex(rIdx, cIdx, rows, cols, tensorrt_llm::QuantizationSFLayout::SWIZZLED);
identity[eIdx * expert_out_size + sf_index] = std::array<int, 3>{eIdx, rIdx, cIdx};
}
}
}
uint8_t* blockScalePtr = static_cast<uint8_t*>(blockScale.data_ptr());
for (int i = 0; i < expert_out_size * num_experts; i++)
{
auto loc_2d = identity[i];
if (loc_2d[1] < rows && loc_2d[2] < cols)
{
uint8_t* reversedBlockScalePtr
= reversedBlockScale.data_ptr<uint8_t>() + (loc_2d[0] * rows + loc_2d[1]) * cols + loc_2d[2];
*reversedBlockScalePtr = blockScalePtr[i];
}
}
}
return reversedBlockScale;
}
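// Editorial note: BlockScaleInterleaveReverse undoes BlockScaleInterleave. If x already has rows % 128 == 0 and
// cols % 4 == 0 (so no padding is added), reshaping BlockScaleInterleave(x) back to x's shape and passing it here
// recovers x.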
th::Tensor E2M1AndUFP8SFScaleToFloat(th::Tensor valueE2M1, th::Tensor scaleFP8SF, int64_t sfVecSize, int64_t sfType)
{
CHECK_CPU_INPUT(valueE2M1, FLOAT4_E2M1X2);
CHECK_CPU_INPUT(scaleFP8SF, SF_DTYPE);
auto packedShape = valueE2M1.sizes();
auto scaleShape = scaleFP8SF.sizes();
TORCH_CHECK(packedShape.size() == 2, "valueE2M1 should be 2D tensor.");
TORCH_CHECK(scaleShape.size() == 1, "scaleFP8SF should be 1D tensor.");
th::Tensor floatTensor
= th::zeros({packedShape[0], packedShape[1] * 2}, th::dtype(th::kFloat32).requires_grad(false));
int hiddenDim = packedShape[1] * 2;
int packedFp4HiddenDim = hiddenDim / 2;
int groupsPerHiddenDim = hiddenDim / sfVecSize;
for (size_t vIdx = 0; vIdx < static_cast<size_t>(packedShape[0]); ++vIdx)
{
for (int group = 0; group < groupsPerHiddenDim; ++group)
{
float* floatPtr = floatTensor.data_ptr<float>() + vIdx * hiddenDim + group * sfVecSize;
uint8_t* packedFp4Ptr = valueE2M1.data_ptr<uint8_t>() + vIdx * packedFp4HiddenDim + group * sfVecSize / 2;
uint8_t* scaleFP8SFPtr = scaleFP8SF.data_ptr<uint8_t>();
uint8_t fp8Scale = scaleFP8SFPtr[computeSFIndex(
vIdx, group, packedShape[0], groupsPerHiddenDim, tensorrt_llm::QuantizationSFLayout::SWIZZLED)];
int scale = fp8Scale;
if (sfType == 0)
{
scale -= 127;
}
else
{
scale >>= 3;
scale -= 7;
}
float scaleFloat = makeExpFloat(scale);
for (int i = 0; i < sfVecSize; ++i)
{
uint8_t packedFp4 = packedFp4Ptr[i / 2];
if (i % 2 == 1)
{
packedFp4 >>= 4;
}
packedFp4 &= 0xf;
float value = e2M1ToFloat(packedFp4) * scaleFloat;
floatPtr[i] = value;
}
}
}
return floatTensor;
}
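// Editorial round-trip sketch for E2M1AndUFP8SFScaleToFloat (assuming a CPU float32 tensor `x` of shape [M, K]
// with K % 16 == 0):
//
//     auto quant = FloatToE2M1AndUFP8SFScale(x, /*sfVecSize=*/16, /*sfType=*/0, /*isSfSwizzledLayout=*/true);
//     auto dequant = E2M1AndUFP8SFScaleToFloat(quant[0], quant[1], /*sfVecSize=*/16, /*sfType=*/0);
//     // dequant reproduces quant[2] (repFloat), since both apply the same scale and E2M1 decode.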
// Used by the (fp16 -> int4) quant layer + int4 gemm network.
th::Tensor E2M1AndUFP8SFScaleToFloatV2(th::Tensor valueE2M1, th::Tensor scaleFP8SF,
std::optional<th::Tensor> globalScale, int64_t sfVecSize, int64_t sfType, bool isSfSwizzledLayout = true)
{
CHECK_CPU_INPUT(valueE2M1, FLOAT4_E2M1X2);
CHECK_CPU_INPUT(scaleFP8SF, SF_DTYPE);
auto packedShape = valueE2M1.sizes();
auto scaleShape = scaleFP8SF.sizes();
TORCH_CHECK(packedShape.size() == 2, "valueE2M1 should be 2D tensor.");
TORCH_CHECK(scaleShape.size() == 1, "scaleFP8SF should be 1D tensor.");
th::Tensor floatTensor
= th::zeros({packedShape[0], packedShape[1] * 2}, th::dtype(th::kFloat32).requires_grad(false));
float globalScaleVal{1.0f};
if (sfType == 1)
{
TORCH_CHECK(globalScale.has_value(), "globalScale is required when sfType is 1.");
CHECK_CPU_INPUT(globalScale.value(), th::kFloat32);
globalScaleVal = globalScale->data_ptr<float>()[0];
}
int hiddenDim = packedShape[1] * 2;
int packedFp4HiddenDim = hiddenDim / 2;
int groupsPerHiddenDim = hiddenDim / sfVecSize;
tensorrt_llm::QuantizationSFLayout layout = isSfSwizzledLayout ? tensorrt_llm::QuantizationSFLayout::SWIZZLED
: tensorrt_llm::QuantizationSFLayout::LINEAR;
for (size_t vIdx = 0; vIdx < static_cast<size_t>(packedShape[0]); ++vIdx)
{
for (int group = 0; group < groupsPerHiddenDim; ++group)
{
float* floatPtr = floatTensor.data_ptr<float>() + vIdx * hiddenDim + group * sfVecSize;
uint8_t* packedFp4Ptr = valueE2M1.data_ptr<uint8_t>() + vIdx * packedFp4HiddenDim + group * sfVecSize / 2;
uint8_t* scaleFP8SFPtr = scaleFP8SF.data_ptr<uint8_t>();
uint8_t fp8Scale = scaleFP8SFPtr[computeSFIndex(vIdx, group, packedShape[0], groupsPerHiddenDim, layout)];
float scaleFloat;
if (sfType == 0)
{
uint32_t tmp = uint32_t(fp8Scale) << 23;
memcpy(&scaleFloat, &tmp, sizeof(scaleFloat));
}
else
{
scaleFloat = float(reinterpret_cast<__nv_fp8_e4m3&>(fp8Scale));
}
scaleFloat *= globalScaleVal;
for (int i = 0; i < sfVecSize; ++i)
{
uint8_t packedFp4 = packedFp4Ptr[i / 2];
if (i % 2 == 1)
{
packedFp4 >>= 4;
}
packedFp4 &= 0xf;
float value = e2M1ToFloat(packedFp4) * scaleFloat;
floatPtr[i] = value;
}
}
}
return floatTensor;
}
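// Editorial note: compared with E2M1AndUFP8SFScaleToFloat above, the V2 variant additionally applies an optional
// per-tensor globalScale when sfType is 1, accepts a linear (non-swizzled) scale-factor layout, and decodes the
// full E4M3 scale (mantissa included) rather than only its exponent bits.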
} // namespace torch_ext
TRTLLM_NAMESPACE_END
static auto float_to_e2m1_and_ufp8sf_scale = torch::RegisterOperators(
"tensorrt_llm::float_to_e2m1_and_ufp8sf_scale", &tensorrt_llm::torch_ext::FloatToE2M1AndUFP8SFScale);
static auto half_to_e2m1_and_ufp8sf_scale = torch::RegisterOperators(
"tensorrt_llm::half_to_e2m1_and_ufp8sf_scale", &tensorrt_llm::torch_ext::HalfToE2M1AndUFP8SFScale);
static auto e2m1_and_ufp8sf_scale_to_float = torch::RegisterOperators(
"tensorrt_llm::e2m1_and_ufp8sf_scale_to_float", &tensorrt_llm::torch_ext::E2M1AndUFP8SFScaleToFloat);
static auto e2m1_and_ufp8sf_scale_to_float_v2 = torch::RegisterOperators(
"tensorrt_llm::e2m1_and_ufp8sf_scale_to_float_v2", &tensorrt_llm::torch_ext::E2M1AndUFP8SFScaleToFloatV2);
TORCH_LIBRARY_FRAGMENT(trtllm, m)
{
m.def("block_scale_interleave(Tensor input) -> Tensor");
m.def("block_scale_interleave_reverse(Tensor input) -> Tensor");
}
TORCH_LIBRARY_IMPL(trtllm, CUDA, m)
{
m.impl("block_scale_interleave", &tensorrt_llm::torch_ext::BlockScaleInterleave);
m.impl("block_scale_interleave_reverse", &tensorrt_llm::torch_ext::BlockScaleInterleaveReverse);
}
TORCH_LIBRARY_IMPL(trtllm, CPU, m)
{
m.impl("block_scale_interleave", &tensorrt_llm::torch_ext::BlockScaleInterleave);
m.impl("block_scale_interleave_reverse", &tensorrt_llm::torch_ext::BlockScaleInterleaveReverse);
}
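// Editorial note: the four conversion helpers are registered under the "tensorrt_llm::" namespace via
// torch::RegisterOperators, while block_scale_interleave / block_scale_interleave_reverse are exposed through the
// "trtllm" library fragment with both CPU and CUDA registrations pointing at the same C++ functions, which branch
// on the input tensor's device internally.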