TensorRT-LLMs/cpp/tensorrt_llm/plugins/quantizeToFP4Plugin/quantizeToFP4Plugin.cpp
Guoming Zhang 57079cecb3
[None][chroe] Rename TensorRT-LLM to TensorRT LLM for source code. (#7851)
Signed-off-by: nv-guomingz <137257613+nv-guomingz@users.noreply.github.com>
2025-09-22 10:05:47 -07:00

302 lines
9.6 KiB
C++

/*
* SPDX-FileCopyrightText: Copyright (c) 1993-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "quantizeToFP4Plugin.h"
#include "pluginUtils.h"
#include "tensorrt_llm/kernels/quantization.h"
#include <NvInferRuntimeBase.h>
using namespace nvinfer1;
using namespace tensorrt_llm::kernels;
using namespace tensorrt_llm::common;
using tensorrt_llm::plugins::QuantizeToFP4PluginCreator;
using tensorrt_llm::plugins::QuantizeToFP4Plugin;
constexpr nvinfer1::DataType FP4_DTYPE = nvinfer1::DataType::kFP4;
constexpr nvinfer1::DataType FP8_DTYPE = nvinfer1::DataType::kFP8;
static char const* QUANT_FP4_PLUGIN_VERSION{"1"};
static char const* QUANT_FP4_PLUGIN_NAME{"QuantizeToFP4"};
PluginFieldCollection QuantizeToFP4PluginCreator::mFC{};
std::vector<nvinfer1::PluginField> QuantizeToFP4PluginCreator::mPluginAttributes;
QuantizeToFP4Plugin::QuantizeToFP4Plugin(){};
// Parameterized constructor
QuantizeToFP4Plugin::QuantizeToFP4Plugin(void const* data, size_t length)
{
char const *d = reinterpret_cast<char const*>(data), *a = d;
TLLM_CHECK_WITH_INFO(d == a + length,
"Expected length (%d) != real length (%d). This is often "
"caused by using different TensorRT LLM version to build "
"engine and run engine.",
(int) length, (int) (d - a));
}
// IPluginV2DynamicExt Methods
nvinfer1::IPluginV2DynamicExt* QuantizeToFP4Plugin::clone() const noexcept
{
auto* plugin = new QuantizeToFP4Plugin();
plugin->setPluginNamespace(mNamespace.c_str());
return plugin;
}
nvinfer1::DimsExprs QuantizeToFP4Plugin::getOutputDimensions(
int outputIndex, nvinfer1::DimsExprs const* inputs, int nbInputs, nvinfer1::IExprBuilder& exprBuilder) noexcept
{
// Quantized output in FP4 datatype.
if (outputIndex == 0)
{
DimsExprs ret;
ret.nbDims = inputs[0].nbDims;
for (int di = 0; di < ret.nbDims; ++di)
{
ret.d[di] = inputs[0].d[di];
}
// // Div up by 16 as the storage type has 16 FP4 values per element.
// ret.d[ret.nbDims - 1]
// = exprBuilder.operation(DimensionOperation::kCEIL_DIV, *ret.d[ret.nbDims - 1],
// *exprBuilder.constant(16));
return ret;
}
// Scaling Factors in FP8.
else if (outputIndex == 1)
{
DimsExprs ret;
ret.nbDims = inputs[0].nbDims;
for (int di = 0; di < ret.nbDims; ++di)
{
ret.d[di] = inputs[0].d[di];
}
// Sequence dimension or token dimension.
// Pad to multiple of 128.
auto dimM
= exprBuilder.operation(DimensionOperation::kCEIL_DIV, *ret.d[ret.nbDims - 2], *exprBuilder.constant(128));
ret.d[ret.nbDims - 2] = exprBuilder.operation(DimensionOperation::kPROD, *dimM, *exprBuilder.constant(128));
// Hidden size dimension.
// Div (rounding up) by 16 since 16 elements share one SF and SF padded to k%4==0.
ret.d[ret.nbDims - 1]
= exprBuilder.operation(DimensionOperation::kCEIL_DIV, *ret.d[ret.nbDims - 1], *exprBuilder.constant(16));
return ret;
}
return DimsExprs{};
}
bool QuantizeToFP4Plugin::supportsFormatCombination(
int pos, nvinfer1::PluginTensorDesc const* inOut, int nbInputs, int nbOutputs) noexcept
{
// half input + float global_sf + fp4 output (e2m1) + fp8 SF output.
int const totalPoses = 2 + 2;
TLLM_CHECK(0 <= pos && pos < totalPoses);
TLLM_CHECK(nbInputs == 2);
switch (pos)
{
case 0:
return (inOut[pos].type == nvinfer1::DataType::kHALF || inOut[pos].type == nvinfer1::DataType::kBF16
|| inOut[pos].type == nvinfer1::DataType::kFP8)
&& (inOut[pos].format == TensorFormat::kLINEAR);
case 1: return (inOut[pos].type == nvinfer1::DataType::kFLOAT) && (inOut[pos].format == TensorFormat::kLINEAR);
case 2: return (inOut[pos].type == FP4_DTYPE) && (inOut[pos].format == TensorFormat::kLINEAR);
case 3: return (inOut[pos].type == FP8_DTYPE) && (inOut[pos].format == TensorFormat::kLINEAR);
default: break;
}
return false;
}
void QuantizeToFP4Plugin::configurePlugin(nvinfer1::DynamicPluginTensorDesc const* in, int nbInputs,
nvinfer1::DynamicPluginTensorDesc const* out, int nbOutputs) noexcept
{
}
size_t QuantizeToFP4Plugin::getWorkspaceSize(nvinfer1::PluginTensorDesc const* inputs, int nbInputs,
nvinfer1::PluginTensorDesc const* outputs, int nbOutputs) const noexcept
{
return 0;
}
int QuantizeToFP4Plugin::enqueue(nvinfer1::PluginTensorDesc const* inputDesc,
nvinfer1::PluginTensorDesc const* outputDesc, void const* const* inputs, void* const* outputs, void* workspace,
cudaStream_t stream) noexcept
{
// inputs
// input [M(*), N] half data type
// SF scale [1] float data type
// used to scale SF from input range to fp8 range (448.f / (MaxVal of input / 6.f))
// outputs
// output [M(*), N] fp4 storage (E2M1)
// SF output [M, N / 16] fp8 storage (UE4M3)
int64_t m64 = 1;
for (int i = 0; i < inputDesc[0].dims.nbDims - 1; ++i)
{
m64 *= inputDesc[0].dims.d[i];
}
int const m = TLLM_INT32_CAST(m64);
int const n = TLLM_INT32_CAST(inputDesc[0].dims.d[inputDesc[0].dims.nbDims - 1]);
TLLM_CHECK_WITH_INFO(n % 16 == 0, "the N dimension must be multiple of 16.");
float const* SFScale = static_cast<float const*>(inputs[1]);
int64_t* output = reinterpret_cast<int64_t*>(outputs[0]);
int32_t* SFoutput = reinterpret_cast<int32_t*>(outputs[1]);
DataType inputDtype = inputDesc[0].type;
switch (inputDtype)
{
case DataType::kHALF:
{
auto input = reinterpret_cast<half const*>(inputs[0]);
invokeFP4Quantization(m, n, input, SFScale, output, SFoutput, false, FP4QuantizationSFLayout::SWIZZLED,
mMultiProcessorCount, stream);
break;
}
case DataType::kBF16:
{
auto input = reinterpret_cast<__nv_bfloat16 const*>(inputs[0]);
invokeFP4Quantization(m, n, input, SFScale, output, SFoutput, false, FP4QuantizationSFLayout::SWIZZLED,
mMultiProcessorCount, stream);
break;
}
case DataType::kFP8:
{
auto input = reinterpret_cast<__nv_fp8_e4m3 const*>(inputs[0]);
invokeFP4Quantization(m, n, input, SFScale, output, SFoutput, false, FP4QuantizationSFLayout::SWIZZLED,
mMultiProcessorCount, stream);
break;
}
default: TLLM_LOG_ERROR("only half, bfloat16 and fp8 data type are supported."); break;
}
// Use UE4M3 scales by default.
return 0;
}
// IPluginV2Ext Methods
nvinfer1::DataType QuantizeToFP4Plugin::getOutputDataType(
int index, nvinfer1::DataType const* inputTypes, int nbInputs) const noexcept
{
if (index == 0)
{
// Output 0 quantized output.
return FP4_DTYPE;
}
// Output 1 SF (scaling factors).
return FP8_DTYPE;
}
// IPluginV2 Methods
char const* QuantizeToFP4Plugin::getPluginType() const noexcept
{
return QUANT_FP4_PLUGIN_NAME;
}
char const* QuantizeToFP4Plugin::getPluginVersion() const noexcept
{
return QUANT_FP4_PLUGIN_VERSION;
}
int QuantizeToFP4Plugin::getNbOutputs() const noexcept
{
return 2;
}
int QuantizeToFP4Plugin::initialize() noexcept
{
return 0;
}
void QuantizeToFP4Plugin::terminate() noexcept {}
size_t QuantizeToFP4Plugin::getSerializationSize() const noexcept
{
return 0;
}
void QuantizeToFP4Plugin::serialize(void* buffer) const noexcept
{
char *d = static_cast<char*>(buffer), *a = d;
TLLM_CHECK(d == a + getSerializationSize());
}
void QuantizeToFP4Plugin::destroy() noexcept
{
// This gets called when the network containing plugin is destroyed
delete this;
}
///////////////
QuantizeToFP4PluginCreator::QuantizeToFP4PluginCreator()
{
// Fill PluginFieldCollection with PluginField arguments metadata
mPluginAttributes.clear();
mFC.nbFields = mPluginAttributes.size();
mFC.fields = mPluginAttributes.data();
}
char const* QuantizeToFP4PluginCreator::getPluginName() const noexcept
{
return QUANT_FP4_PLUGIN_NAME;
}
char const* QuantizeToFP4PluginCreator::getPluginVersion() const noexcept
{
return QUANT_FP4_PLUGIN_VERSION;
}
PluginFieldCollection const* QuantizeToFP4PluginCreator::getFieldNames() noexcept
{
return &mFC;
}
IPluginV2* QuantizeToFP4PluginCreator::createPlugin(char const* name, PluginFieldCollection const* fc) noexcept
{
try
{
auto* obj = new QuantizeToFP4Plugin();
obj->setPluginNamespace(mNamespace.c_str());
return obj;
}
catch (std::exception const& e)
{
caughtError(e);
}
return nullptr;
}
IPluginV2* QuantizeToFP4PluginCreator::deserializePlugin(
char const* name, void const* serialData, size_t serialLength) noexcept
{
// This object will be deleted when the network is destroyed, which will
// call QuantizeToFP4Plugin::destroy()
try
{
auto* obj = new QuantizeToFP4Plugin(serialData, serialLength);
obj->setPluginNamespace(mNamespace.c_str());
return obj;
}
catch (std::exception const& e)
{
caughtError(e);
}
return nullptr;
}