mirror of
https://github.com/NVIDIA/TensorRT-LLM.git
synced 2026-01-28 22:56:13 +08:00
77 lines
2.7 KiB
Python
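# Unit test for the NVFP4-quantized Linear module: load ModelOpt-style checkpoint
# weights and check the module output against a direct FP4 GEMM reference.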
import pytest
import torch
from utils.util import skip_pre_blackwell

import tensorrt_llm.quantization.utils.fp4_utils as fp4_utils
from tensorrt_llm._torch.modules.linear import Linear
from tensorrt_llm.models.modeling_utils import QuantAlgo, QuantConfig

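# NVFP4 stores one FP8 (e4m3) scale factor per 16-element block of FP4 (e2m1)
# values; this block ("scaling vector") size is passed to fp4_quantize below.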
scaling_vector_size = 16
@skip_pre_blackwell
@pytest.mark.parametrize(
    "dtype", [torch.float16, torch.bfloat16]
)  # TODO: Do we need float32 test case? fp4_quantize only supports fp16, bf16, fp8_e4m3
def test_fp4_linear(dtype):
    SEQ_LEN = 10
    HIDDEN_SIZE = 128
    torch.manual_seed(0)

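    # Global scales: (448 * 6) / amax, where 448 is the max magnitude of FP8 e4m3
    # (the block-scale dtype) and 6 is the max magnitude of FP4 e2m1. The
    # checkpoint-style dict below stores the reciprocal, amax / (448 * 6).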
    x = torch.randn((SEQ_LEN, HIDDEN_SIZE), dtype=dtype).cuda()
    x_sf_global = (448 * 6) / x.abs().max().float()

    w = torch.randn((HIDDEN_SIZE, HIDDEN_SIZE), dtype=dtype).cuda()
    w_sf_global = (448 * 6) / w.abs().max().float()
    w_fp4, w_sf_block = torch.ops.trtllm.fp4_quantize(w, w_sf_global,
                                                      scaling_vector_size,
                                                      False)

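    # With QuantAlgo.NVFP4, Linear allocates a packed-FP4 weight and an FP4
    # block-scale buffer (checked by the dtype asserts below).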
    qc = QuantConfig(quant_algo=QuantAlgo.NVFP4)
    l_fp4 = Linear(in_features=HIDDEN_SIZE,
                   out_features=HIDDEN_SIZE,
                   bias=False,
                   dtype=dtype,
                   quant_config=qc)

    assert l_fp4.weight.dtype == fp4_utils.float4_e2m1x2
    assert l_fp4.weight_scale.dtype == fp4_utils.float4_sf_dtype

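    # Undo the device-side interleaving so the block scales look like the plain
    # row-major layout a checkpoint would store; load_weights is expected to
    # re-interleave them (see the weight_scale check below).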
    w_sf_block_unswizzled = (
        torch.ops.tensorrt_llm.nvfp4_block_scale_interleave_reverse(
            w_sf_block.cpu().view(HIDDEN_SIZE, -1)))

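    # Load weights in the format a ModelOpt NVFP4 checkpoint would provide:
    # global scales stored as reciprocals, block scales viewed as float8_e4m3fn.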
    l_fp4.load_weights([{
        'input_scale':
        1.0 / x_sf_global.cpu(),  # Simulates amax/(448*6) in modelopt ckpt
        'weight':
        w_fp4.cpu(),
        'weight_scale':
        w_sf_block_unswizzled.view(
            torch.float8_e4m3fn),  # Simulates float8_e4m3fn in modelopt ckpt
        'weight_scale_2':
        1.0 / w_sf_global.cpu()  # Simulates amax/(448*6) in modelopt ckpt
    }])
    l_fp4 = l_fp4.cuda()

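    # After loading, the module buffers should reproduce the directly quantized
    # tensors, and alpha should be the combined dequantization factor
    # 1 / (x_sf_global * w_sf_global).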
    torch.testing.assert_close(l_fp4.weight, w_fp4)
    torch.testing.assert_close(l_fp4.input_scale[0], x_sf_global)
    torch.testing.assert_close(l_fp4.weight_scale, w_sf_block)
    alpha_ref = 1.0 / (w_sf_global * x_sf_global)
    torch.testing.assert_close(l_fp4.alpha[0], alpha_ref)

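    # Forward pass through the quantized module; it is expected to quantize the
    # activation with the stored input scale and run the FP4 GEMM, matching the
    # reference path below.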
    with torch.inference_mode():
        output = l_fp4.forward(x)

    # Reference path: quantize the activation with the same global scale and
    # call the FP4 GEMM op directly.
    with torch.inference_mode():
        x_fp4, x_sf_block = torch.ops.trtllm.fp4_quantize(
            x, x_sf_global, scaling_vector_size, False)
        output_ref = torch.ops.trtllm.fp4_gemm(x_fp4, w_fp4, x_sf_block,
                                               w_sf_block, alpha_ref, False,
                                               dtype)

    # Compare the module output against the reference GEMM.
    torch.cuda.synchronize()
    torch.testing.assert_close(output, output_ref)