# TensorRT-LLMs/tests/_torch/test_fp8_linear.py
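"""Unit test for the FP8 (e4m3, per-tensor scale) path of the tensorrt_llm._torch Linear module.

The test quantizes random activations and weights with the
torch.ops.tensorrt_llm.quantize_e4m3_per_tensor op, loads them into a Linear
configured with QuantAlgo.FP8, and checks the module output against a plain
PyTorch reference built on torch._scaled_mm.
"""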
import os
import sys

import pytest
import torch

from tensorrt_llm._torch.modules.linear import Linear
from tensorrt_llm.models.modeling_utils import QuantAlgo, QuantConfig

# Make the shared test utilities (utils.util) importable from the parent directory.
sys.path.append(os.path.join(os.path.dirname(__file__), '..'))
from utils.util import skip_pre_hopper


@skip_pre_hopper
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
def test_fp8_linear(dtype):
    SEQ_LEN = 10
    HIDDEN_SIZE = 128
    torch.manual_seed(0)

    # Quantize a random activation tensor to FP8 (e4m3) with a per-tensor scale.
    x = torch.randn((SEQ_LEN, HIDDEN_SIZE), dtype=dtype).cuda()
    x_fp8, x_scale = torch.ops.tensorrt_llm.quantize_e4m3_per_tensor(x)
    x_fp8 = x_fp8.view(torch.float8_e4m3fn)
    x_scale = x_scale.float().squeeze()

    # Quantize a random weight matrix the same way.
    w = torch.randn((HIDDEN_SIZE, HIDDEN_SIZE), dtype=dtype).cuda()
    w_fp8, w_scale = torch.ops.tensorrt_llm.quantize_e4m3_per_tensor(w)
    w_fp8 = w_fp8.view(torch.float8_e4m3fn)
    w_scale = w_scale.float().squeeze()

    # Build an FP8-quantized Linear module and load the quantized weight and scales.
    qc = QuantConfig(quant_algo=QuantAlgo.FP8)
    l0 = Linear(in_features=HIDDEN_SIZE,
                out_features=HIDDEN_SIZE,
                bias=False,
                dtype=dtype,
                quant_config=qc)
    assert l0.weight.dtype == torch.float8_e4m3fn
    l0.load_weights([{
        'weight': w_fp8,
        'weight_scale': w_scale,
        'input_scale': x_scale
    }])
    l0.cuda()

    # The loaded parameters should match the quantized tensors exactly.
    torch.testing.assert_close(l0.weight, w_fp8)
    torch.testing.assert_close(l0.weight_scale, w_scale)
    torch.testing.assert_close(l0.input_scale, x_scale)

    with torch.inference_mode():
        output = l0.forward(x)
    # Reference path in plain PyTorch: quantize the input with the same
    # per-tensor scale, then run an FP8 GEMM via torch._scaled_mm.
    def ref_quant(x_, x_scale_):
        # Quantize to e4m3 by dividing by the scale and clamping to the FP8 range.
        x_ = x_.float()
        finfo = torch.finfo(torch.float8_e4m3fn)
        inv_scale = x_scale_.reciprocal()
        x_fp8_ = (x_ * inv_scale).clamp(min=finfo.min, max=finfo.max)
        return x_fp8_.to(torch.float8_e4m3fn)

    def ref_linear():
        ref_x_fp8 = ref_quant(x, x_scale)
        ref_output = torch._scaled_mm(ref_x_fp8,
                                      w_fp8.t(),
                                      out_dtype=dtype,
                                      scale_a=x_scale,
                                      scale_b=w_scale,
                                      bias=l0.bias)
        return ref_output

    with torch.inference_mode():
        ref_output = ref_linear()

    # Compare the module output against the reference result.
    torch.cuda.synchronize()
    torch.testing.assert_close(output, ref_output)


if __name__ == '__main__':
    test_fp8_linear(torch.float16)
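# Running the file directly only exercises the float16 case; invoking it through
# pytest (e.g. `pytest tests/_torch/test_fp8_linear.py`, assuming pytest is installed
# and a Hopper-or-newer GPU is available, since skip_pre_hopper is expected to skip
# older architectures) covers both float16 and bfloat16.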