TensorRT-LLM/tests/_torch/test_fp4_linear.py
import os
import sys

import pytest
import torch

import tensorrt_llm.quantization.utils.fp4_utils as fp4_utils
from tensorrt_llm._torch.modules.linear import Linear
from tensorrt_llm.models.modeling_utils import QuantAlgo, QuantConfig

# Make the local test utilities importable before pulling in skip_pre_blackwell.
sys.path.append(os.path.join(os.path.dirname(__file__), '..'))
from utils.util import skip_pre_blackwell

scaling_vector_size = 16


@skip_pre_blackwell
@pytest.mark.parametrize(
    "dtype", [torch.float16, torch.bfloat16]
)  # TODO: Do we need float32 test case? fp4_quantize only supports fp16, bf16, fp8_e4m3
def test_fp4_linear(dtype):
    SEQ_LEN = 10
    HIDDEN_SIZE = 128
    torch.manual_seed(0)

    # Activation and weight with their NVFP4 global scales: (448 * 6) / amax.
    x = torch.randn((SEQ_LEN, HIDDEN_SIZE), dtype=dtype).cuda()
    x_sf_global = (448 * 6) / x.abs().max().float()
    w = torch.randn((HIDDEN_SIZE, HIDDEN_SIZE), dtype=dtype).cuda()
    w_sf_global = (448 * 6) / w.abs().max().float()

    # Quantize the reference weight to FP4 with per-16-element block scales.
    w_fp4, w_sf_block = torch.ops.trtllm.fp4_quantize(w, w_sf_global,
                                                      scaling_vector_size,
                                                      False)

    qc = QuantConfig(quant_algo=QuantAlgo.NVFP4)

    l_fp4 = Linear(in_features=HIDDEN_SIZE,
                   out_features=HIDDEN_SIZE,
                   bias=False,
                   dtype=dtype,
                   quant_config=qc)
    assert l_fp4.weight.dtype == fp4_utils.float4_e2m1x2
    assert l_fp4.weight_scale.dtype == fp4_utils.float4_sf_dtype

    # Reverse the swizzled block-scale layout to mimic the linear layout
    # stored in a checkpoint.
    w_sf_block_unswizzled = (
        torch.ops.tensorrt_llm.nvfp4_block_scale_interleave_reverse(
            w_sf_block.cpu().view(HIDDEN_SIZE, -1)))

    l_fp4.load_weights([{
        'input_scale':
        1.0 / x_sf_global.cpu(),  # Simulates amax/(448*6) in modelopt ckpt
        'weight':
        w_fp4.cpu(),
        'weight_scale':
        w_sf_block_unswizzled.view(
            torch.float8_e4m3fn),  # Simulates float8_e4m3fn in modelopt ckpt
        'weight_scale_2':
        1.0 / w_sf_global.cpu()  # Simulates amax/(448*6) in modelopt ckpt
    }])
    l_fp4 = l_fp4.cuda()

    torch.testing.assert_close(l_fp4.weight, w_fp4)
    torch.testing.assert_close(l_fp4.input_scale[0], x_sf_global)
    torch.testing.assert_close(l_fp4.weight_scale, w_sf_block)

    # alpha folds both global scales back into the GEMM output.
    alpha_ref = 1.0 / (w_sf_global * x_sf_global)
    torch.testing.assert_close(l_fp4.alpha[0], alpha_ref)

    with torch.inference_mode():
        output = l_fp4.forward(x)

    # ref linear
    with torch.inference_mode():
        x_fp4, x_sf_block = torch.ops.trtllm.fp4_quantize(
            x, x_sf_global, scaling_vector_size, False)
        output_ref = torch.ops.trtllm.fp4_gemm(x_fp4, w_fp4, x_sf_block,
                                               w_sf_block, alpha_ref, False,
                                               dtype)

    # compare
    torch.cuda.synchronize()
    torch.testing.assert_close(output, output_ref)
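

# ---------------------------------------------------------------------------
# Illustrative sketch only, not executed by the test above. It restates the
# two-level NVFP4 scaling the test relies on: 448 is the largest normal value
# of torch.float8_e4m3fn (the block-scale dtype) and 6 is the largest FP4
# E2M1 value, so (448 * 6) / amax is the global scale, and its reciprocal,
# amax / (448 * 6), is what a modelopt checkpoint stores as input_scale /
# weight_scale_2. The helper name is made up for illustration and is not a
# TensorRT-LLM API.
def _nvfp4_global_scale_sketch(t: torch.Tensor) -> torch.Tensor:
    # Same expression as x_sf_global / w_sf_global above; load_weights is fed
    # the reciprocal to match the checkpoint convention.
    return (448 * 6) / t.abs().max().float()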