fix: correct scaling factor calculation in fp4 quant

Signed-off-by: muse-coder <935522618@qq.com>
muse-coder 2025-07-31 00:01:17 +08:00
parent 0f083b9daf
commit 019ab04531

@@ -83,8 +83,8 @@ def random_fp4_tensor_and_sf_v2(shape, sf_vec_size):
     float_tensor = torch.randn(shape, dtype=torch.float32)
     half_tensor = float_tensor.to(torch.float16).cuda()
-    # global scale trick for int4 quantization.
-    alpha = 448.0 / (torch.max(float_tensor) / 6.0)
+    # global scale trick for fp4 quantization.
+    alpha = (448 * 6) / float_tensor.abs().max().float()
     sf_scale_tensor = torch.FloatTensor([alpha]).cuda()
     gemm_alpha_tensor = torch.FloatTensor([1.0 / alpha])
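Note (illustration, not part of the commit): the 448 and 6 in the formula match the largest finite value of FP8 E4M3 (448) and the largest magnitude representable in FP4 E2M1 (6), which suggests the per-block scale factors are stored in E4M3. With alpha = (448 * 6) / amax, the block containing the global absolute maximum lands exactly on the E4M3 limit, whereas the old expression used torch.max() and could overshoot when the largest-magnitude element is negative. A minimal sketch of that reasoning, with purely illustrative tensor shapes and variable names:

import torch

FP4_MAX = 6.0    # largest magnitude representable in FP4 (E2M1)
FP8_MAX = 448.0  # largest finite value in FP8 (E4M3), assumed storage format of block scales

x = torch.randn(128, 64)
amax = x.abs().max().float()

alpha_old = 448.0 / (torch.max(x) / 6.0)   # buggy: uses the signed maximum, ignores negative extremes
alpha_new = (FP8_MAX * FP4_MAX) / amax     # fixed: uses the absolute maximum

# The block holding the global amax needs a per-block scale of (amax / FP4_MAX) * alpha;
# with alpha_new this is exactly the E4M3 limit, so no block scale can overflow.
worst_block_sf = (amax / FP4_MAX) * alpha_new
assert torch.isclose(worst_block_sf, torch.tensor(FP8_MAX))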
@@ -282,9 +282,8 @@ class TestFunctional(unittest.TestCase):
         input_fp16 = input_fp32.to(torch.float16).cuda()
-        # global scale trick for int4 quantization.
-        alpha = 448.0 / (torch.max(input_fp32) / 6.0)
+        # global scale trick for fp4 quantization.
+        alpha = (448 * 6) / input_fp32.abs().max().float()
         weights_fp32_transposed = torch.transpose(weights_fp32, 0, 1)
         sf_scale_tensor = torch.FloatTensor([alpha]).cuda()
         act_unscale_tensor = torch.FloatTensor([1.0 / alpha]).cuda()
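Note (illustration, not part of the commit): act_unscale_tensor and gemm_alpha_tensor carry 1.0 / alpha, which presumably folds the global scale back out after the quantized GEMM. A tiny sketch of that round trip under the corrected alpha, with illustrative shapes and names:

import torch

x = torch.randn(4, 8, dtype=torch.float32)
alpha = (448.0 * 6.0) / x.abs().max().float()  # corrected global scale

x_scaled = x * alpha               # scaling applied on the quantization path
x_restored = x_scaled * (1.0 / alpha)  # role of the 1.0 / alpha unscale factor
assert torch.allclose(x_restored, x)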