import os
import sys

import pytest
import torch
from utils.util import skip_pre_blackwell

import tensorrt_llm.quantization.utils.fp4_utils as fp4_utils
from tensorrt_llm._torch.modules.linear import Linear
from tensorrt_llm.models.modeling_utils import QuantAlgo, QuantConfig

sys.path.append(os.path.join(os.path.dirname(__file__), '..'))

# Number of elements that share one FP8 block scale in NVFP4 quantization.
scaling_vector_size = 16


@skip_pre_blackwell
@pytest.mark.parametrize(
    "dtype", [torch.float16, torch.bfloat16]
)  # TODO: Do we need a float32 test case? fp4_quantize only supports fp16, bf16, fp8_e4m3.
def test_fp4_linear(dtype):
    SEQ_LEN = 10
    HIDDEN_SIZE = 128
    torch.manual_seed(0)
    x = torch.randn((SEQ_LEN, HIDDEN_SIZE), dtype=dtype).cuda()
    # Global scale: (max FP8 E4M3 value 448) * (max FP4 E2M1 value 6) / amax.
    x_sf_global = (448 * 6) / x.abs().max().float()
    w = torch.randn((HIDDEN_SIZE, HIDDEN_SIZE), dtype=dtype).cuda()
    w_sf_global = (448 * 6) / w.abs().max().float()
    w_fp4, w_sf_block = torch.ops.trtllm.fp4_quantize(w, w_sf_global,
                                                      scaling_vector_size,
                                                      False)

    qc = QuantConfig(quant_algo=QuantAlgo.NVFP4)

    l_fp4 = Linear(in_features=HIDDEN_SIZE,
                   out_features=HIDDEN_SIZE,
                   bias=False,
                   dtype=dtype,
                   quant_config=qc)
    assert l_fp4.weight.dtype == fp4_utils.float4_e2m1x2
    assert l_fp4.weight_scale.dtype == fp4_utils.float4_sf_dtype

    # Undo the swizzled block-scale layout so the scales look like a checkpoint
    # export; load_weights re-applies the interleaving (checked below).
    w_sf_block_unswizzled = (
        torch.ops.tensorrt_llm.nvfp4_block_scale_interleave_reverse(
            w_sf_block.cpu().view(HIDDEN_SIZE, -1)))

    l_fp4.load_weights([{
        'input_scale':
        1.0 / x_sf_global.cpu(),  # Simulates amax/(448*6) in modelopt ckpt
        'weight':
        w_fp4.cpu(),
        'weight_scale':
        w_sf_block_unswizzled.view(
            torch.float8_e4m3fn),  # Simulates float8_e4m3fn in modelopt ckpt
        'weight_scale_2':
        1.0 / w_sf_global.cpu()  # Simulates amax/(448*6) in modelopt ckpt
    }])
    l_fp4 = l_fp4.cuda()
    torch.testing.assert_close(l_fp4.weight, w_fp4)
    torch.testing.assert_close(l_fp4.input_scale[0], x_sf_global)
    torch.testing.assert_close(l_fp4.weight_scale, w_sf_block)
    alpha_ref = 1.0 / (w_sf_global * x_sf_global)
    torch.testing.assert_close(l_fp4.alpha[0], alpha_ref)

    with torch.inference_mode():
        output = l_fp4.forward(x)

    # Reference: quantize the activation and call the FP4 GEMM kernel directly.
    with torch.inference_mode():
        x_fp4, x_sf_block = torch.ops.trtllm.fp4_quantize(
            x, x_sf_global, scaling_vector_size, False)
        output_ref = torch.ops.trtllm.fp4_gemm(x_fp4, w_fp4, x_sf_block,
                                               w_sf_block, alpha_ref, False,
                                               dtype)

    # Compare the module output against the direct kernel call.
    torch.cuda.synchronize()
    torch.testing.assert_close(output, output_ref)
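

# A minimal invocation sketch, assuming this file lives in the unittest tree and
# a Blackwell-class GPU is available (otherwise @skip_pre_blackwell skips it);
# the path below is a placeholder, not the actual repo location:
#   pytest -q -k test_fp4_linear <path_to_this_file>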