Mirror of https://github.com/NVIDIA/TensorRT-LLM.git
fix: correct scaling factor calculation in fp4 quant
Signed-off-by: muse-coder <935522618@qq.com>
parent 0f083b9daf
commit 019ab04531
@@ -83,8 +83,8 @@ def random_fp4_tensor_and_sf_v2(shape, sf_vec_size):
     float_tensor = torch.randn(shape, dtype=torch.float32)
     half_tensor = float_tensor.to(torch.float16).cuda()
 
-    # global scale trick for int4 quantization.
-    alpha = 448.0 / (torch.max(float_tensor) / 6.0)
+    # global scale trick for fp4 quantization.
+    alpha = (448 * 6) / float_tensor.abs().max().float()
     sf_scale_tensor = torch.FloatTensor([alpha]).cuda()
     gemm_alpha_tensor = torch.FloatTensor([1.0 / alpha])
 
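Note that the two alpha expressions are algebraically identical: 448.0 / (amax / 6.0) == (448 * 6) / amax. The substantive change is taking the absolute maximum instead of the signed maximum (plus an explicit float32 cast). With illustrative values, if the largest positive entry is 2.0 but the tensor also contains -10.0, the old code gives alpha = 448.0 / (2.0 / 6.0) = 1344, while the corrected code gives alpha = (448 * 6) / 10.0 = 268.8; the overestimated scale from the signed max can push per-block scale factors past the FP8 E4M3 maximum of 448.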
@@ -282,9 +282,8 @@ class TestFunctional(unittest.TestCase):
 
         input_fp16 = input_fp32.to(torch.float16).cuda()
 
-        # global scale trick for int4 quantization.
-        alpha = 448.0 / (torch.max(input_fp32) / 6.0)
-
+        # global scale trick for fp4 quantization.
+        alpha = (448 * 6) / input_fp32.abs().max().float()
         weights_fp32_transposed = torch.transpose(weights_fp32, 0, 1)
         sf_scale_tensor = torch.FloatTensor([alpha]).cuda()
         act_unscale_tensor = torch.FloatTensor([1.0 / alpha]).cuda()
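For context, here is a minimal sketch (not part of the patch) of the global-scale computation the new code implements, assuming the NVFP4 recipe in which per-block scale factors are stored as FP8 E4M3 (maximum 448) and the quantized elements as FP4 E2M1 (maximum magnitude 6). The helper name fp4_global_scale is hypothetical, and the sketch runs on CPU, omitting the .cuda() calls used in the tests.

# Minimal sketch, not part of the patch: the NVFP4 "global scale trick",
# assuming FP8 E4M3 scale factors (max 448) and FP4 E2M1 elements (max 6).
import torch

def fp4_global_scale(t: torch.Tensor) -> float:
    # Take the absolute maximum in float32 so a large negative entry is
    # accounted for; the signed max alone would miss it (the bug fixed here).
    amax = t.abs().max().float().item()
    return (448 * 6) / amax

x = torch.randn(128, 64, dtype=torch.float32)
alpha = fp4_global_scale(x)
sf_scale = torch.FloatTensor([alpha])          # global scale fed to the quant kernel
gemm_alpha = torch.FloatTensor([1.0 / alpha])  # de-scale applied after the GEMM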