import os
import sys

import pytest
import torch

from tensorrt_llm._torch.modules.linear import Linear
from tensorrt_llm.models.modeling_utils import QuantAlgo, QuantConfig

sys.path.append(os.path.join(os.path.dirname(__file__), '..'))
from utils.util import skip_pre_hopper
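

# This test loads pre-quantized FP8 (e4m3) weights and per-tensor scales into
# the PyTorch-backend Linear module and compares its output against a reference
# computed with torch._scaled_mm using the same per-tensor scales.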
@skip_pre_hopper
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
def test_fp8_linear(dtype):
    SEQ_LEN = 10
    HIDDEN_SIZE = 128
    torch.manual_seed(0)
    # Quantize a random activation and a random weight to FP8 (e4m3) with
    # per-tensor scales using the tensorrt_llm custom op.
    x = torch.randn((SEQ_LEN, HIDDEN_SIZE), dtype=dtype).cuda()
    x_fp8, x_scale = torch.ops.tensorrt_llm.quantize_e4m3_per_tensor(x)
    x_fp8 = x_fp8.view(torch.float8_e4m3fn)
    x_scale = x_scale.float().squeeze()
    w = torch.randn((HIDDEN_SIZE, HIDDEN_SIZE), dtype=dtype).cuda()
    w_fp8, w_scale = torch.ops.tensorrt_llm.quantize_e4m3_per_tensor(w)
    w_fp8 = w_fp8.view(torch.float8_e4m3fn)
    w_scale = w_scale.float().squeeze()
    # Build an FP8-quantized Linear module, load the pre-quantized weight and
    # scales, and verify they are stored unchanged.
    qc = QuantConfig(quant_algo=QuantAlgo.FP8)
    l0 = Linear(in_features=HIDDEN_SIZE,
                out_features=HIDDEN_SIZE,
                bias=False,
                dtype=dtype,
                quant_config=qc)
    assert l0.weight.dtype == torch.float8_e4m3fn
    l0.load_weights([{
        'weight': w_fp8,
        'weight_scale': w_scale,
        'input_scale': x_scale
    }])
    l0.cuda()
    torch.testing.assert_close(l0.weight, w_fp8)
    torch.testing.assert_close(l0.weight_scale, w_scale)
    torch.testing.assert_close(l0.input_scale, x_scale)
    with torch.inference_mode():
        output = l0.forward(x)
    # Reference path in plain PyTorch: re-quantize x with the same per-tensor
    # scale and run an FP8 scaled matmul via torch._scaled_mm.
    def ref_quant(x_, x_scale_):
        x_ = x_.float()
        finfo = torch.finfo(torch.float8_e4m3fn)
        inv_scale = x_scale_.reciprocal()
        x_fp8_ = (x_ * inv_scale).clamp(min=finfo.min, max=finfo.max)
        return x_fp8_.to(torch.float8_e4m3fn)

    def ref_linear():
        ref_x_fp8 = ref_quant(x, x_scale)
        ref_output = torch._scaled_mm(ref_x_fp8,
                                      w_fp8.t(),
                                      out_dtype=dtype,
                                      scale_a=x_scale,
                                      scale_b=w_scale,
                                      bias=l0.bias)
        return ref_output
    with torch.inference_mode():
        ref_output = ref_linear()

    # Compare the module output against the reference.
    torch.cuda.synchronize()
    torch.testing.assert_close(output, ref_output)


if __name__ == '__main__':
    test_fp8_linear(torch.float16)
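
# Note: the test is parametrized over float16 and bfloat16; running this file
# directly only exercises the float16 case, while pytest collects both dtypes.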