import os
import sys

import pytest
import torch
from utils.util import skip_pre_blackwell

import tensorrt_llm.quantization.utils.fp4_utils as fp4_utils
from tensorrt_llm._torch.modules.linear import Linear
from tensorrt_llm.models.modeling_utils import QuantAlgo, QuantConfig

sys.path.append(os.path.join(os.path.dirname(__file__), '..'))

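# NVFP4 quantizes tensors in contiguous 16-element blocks, each block carrying
# its own FP8 (e4m3) scale factor on top of a per-tensor global scale.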
scaling_vector_size = 16


@skip_pre_blackwell
@pytest.mark.parametrize(
    "dtype", [torch.float16, torch.bfloat16]
)  # TODO: Do we need float32 test case? fp4_quantize only supports fp16, bf16, fp8_e4m3
def test_fp4_linear(dtype):
    SEQ_LEN = 10
    HIDDEN_SIZE = 128
    torch.manual_seed(0)

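    # Global scale = (448 * 6) / amax: 448 is the FP8 e4m3 maximum (used for
    # the per-block scale factors) and 6 is the FP4 e2m1 maximum, so the
    # tensor's largest magnitude maps onto the largest representable value.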
    x = torch.randn((SEQ_LEN, HIDDEN_SIZE), dtype=dtype).cuda()
    x_sf_global = (448 * 6) / x.abs().max().float()

    w = torch.randn((HIDDEN_SIZE, HIDDEN_SIZE), dtype=dtype).cuda()
    w_sf_global = (448 * 6) / w.abs().max().float()
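    # fp4_quantize packs two e2m1 values per byte and returns the per-block
    # scale factors in the interleaved layout expected by the FP4 GEMM.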
    w_fp4, w_sf_block = torch.ops.trtllm.fp4_quantize(w, w_sf_global,
                                                      scaling_vector_size,
                                                      False)

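    # With an NVFP4 QuantConfig, Linear allocates packed FP4 weight storage and
    # block-scale parameters instead of a dense high-precision weight matrix.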
    qc = QuantConfig(quant_algo=QuantAlgo.NVFP4)
    l_fp4 = Linear(in_features=HIDDEN_SIZE,
                   out_features=HIDDEN_SIZE,
                   bias=False,
                   dtype=dtype,
                   quant_config=qc)

    assert l_fp4.weight.dtype == fp4_utils.float4_e2m1x2
    assert l_fp4.weight_scale.dtype == fp4_utils.float4_sf_dtype

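    # Checkpoints store the block scales row-major, while the GEMM kernel uses
    # an interleaved (swizzled) layout; reverse the interleaving here so the
    # tensor passed to load_weights looks like a modelopt checkpoint entry.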
    w_sf_block_unswizzled = (
        torch.ops.tensorrt_llm.nvfp4_block_scale_interleave_reverse(
            w_sf_block.cpu().view(HIDDEN_SIZE, -1)))

    l_fp4.load_weights([{
        'input_scale':
        1.0 / x_sf_global.cpu(),  # Simulates amax/(448*6) in modelopt ckpt
        'weight':
        w_fp4.cpu(),
        'weight_scale':
        w_sf_block_unswizzled.view(
            torch.float8_e4m3fn),  # Simulates float8_e4m3fn in modelopt ckpt
        'weight_scale_2':
        1.0 / w_sf_global.cpu()  # Simulates amax/(448*6) in modelopt ckpt
    }])
    l_fp4 = l_fp4.cuda()

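    # load_weights should recover the packed weight, the re-interleaved block
    # scales and the global scales; alpha is the dequantization factor
    # input_scale * weight_scale_2 == 1 / (x_sf_global * w_sf_global).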
    torch.testing.assert_close(l_fp4.weight, w_fp4)
    torch.testing.assert_close(l_fp4.input_scale[0], x_sf_global)
    torch.testing.assert_close(l_fp4.weight_scale, w_sf_block)
    alpha_ref = 1.0 / (w_sf_global * x_sf_global)
    torch.testing.assert_close(l_fp4.alpha[0], alpha_ref)

    with torch.inference_mode():
        output = l_fp4.forward(x)

    # ref linear
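    # Quantize the activations with the same global scale and call the FP4
    # GEMM op directly; Linear.forward above should reproduce this output.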
    with torch.inference_mode():
        x_fp4, x_sf_block = torch.ops.trtllm.fp4_quantize(
            x, x_sf_global, scaling_vector_size, False)
        output_ref = torch.ops.trtllm.fp4_gemm(x_fp4, w_fp4, x_sf_block,
                                               w_sf_block, alpha_ref, False,
                                               dtype)

    # compare
    torch.cuda.synchronize()
    torch.testing.assert_close(output, output_ref)