TensorRT-LLMs/tests/unittest/_torch/test_fp4_linear.py
Yukun He c678774c99
feat: Apply the new torch-flow-compatible AutoTuner to both the Fused MoE and NVFP4 Linear operators. (#3151)
* Several optimizations and fixes for the AutoTuner.

Signed-off-by: Yukun He <23156053+hyukn@users.noreply.github.com>

* Apply the new Python-side AutoTuner to the current Linear module for the NVFP4 data type.

Signed-off-by: Yukun He <23156053+hyukn@users.noreply.github.com>

* Apply the new Python-side AutoTuner to the MoE op
* Remove routers from the cache key to improve inference perf
* Prevent unnecessary profiling. Use the do_preparation keyword to select which part should run once before any tactic is evaluated (see the sketch below).
* Remove the try-catch inside the MoE profiling process.
* Move the default-tactic transform (-1 to 0) into the cpp runner.
* Revise relevant tests.
* Predefine the bucketizing strategy for fused_moe

Signed-off-by: Yukun He <23156053+hyukn@users.noreply.github.com>
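
For illustration, a minimal sketch of a tunable runner that honors do_preparation; the class name, signatures, and tactic list below are assumptions made for this sketch, not the AutoTuner's actual interface:

    from typing import List

    import torch

    class SketchRunner:  # hypothetical runner, not the real interface
        def get_valid_tactics(self, inputs: List[torch.Tensor]) -> List[int]:
            # The default tactic is 0; the -1 -> 0 transform lives in the
            # cpp runner.
            return [0, 1, 2]

        def forward(self,
                    inputs: List[torch.Tensor],
                    tactic: int = 0,
                    do_preparation: bool = False) -> torch.Tensor:
            if do_preparation:
                # Runs once, before any tactic is profiled, so this setup
                # cost never lands inside the timed region.
                self.workspace = torch.empty(1 << 20, device=inputs[0].device)
            a, b = inputs
            return a @ b  # a real runner would dispatch on `tactic`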

* Add specific_profile support to the AutoTuner to bypass the standard cache-search process for perf optimization
* Add specific_profile for MoE
* Add specific_profile for Linear

Signed-off-by: Yukun He <23156053+hyukn@users.noreply.github.com>

* Fix and revise according to the reviewers' suggestions.

Signed-off-by: Yukun He <23156053+hyukn@users.noreply.github.com>

* Use lru_cache for inference perf optimization.
* Revert the gen_custom_cache_key feature

Signed-off-by: Yukun He <23156053+hyukn@users.noreply.github.com>

* Replace the runner with a runner id so the cache is serializable, as sketched below.

Signed-off-by: Yukun He <23156053+hyukn@users.noreply.github.com>
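
For illustration, a minimal sketch of the lru_cache and runner-id ideas from the bullets above; the names and cache layout are assumptions made for this sketch, not the actual implementation:

    from functools import lru_cache

    # Hypothetical profiling results: (runner_id, shape_key) -> best tactic.
    profiling_cache = {(0, (10, 128)): 2}

    @lru_cache(maxsize=None)
    def best_tactic(runner_id: int, shape_key: tuple) -> int:
        # Hot inference path: repeated calls with the same runner id and
        # shapes hit the LRU cache instead of re-searching. Keying on an int
        # runner id (rather than the runner object) keeps the cache
        # hashable and serializable.
        return profiling_cache.get((runner_id, shape_key), 0)

    assert best_tactic(0, (10, 128)) == 2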

* Code cleanup and minor fixes.

Signed-off-by: Yukun He <23156053+hyukn@users.noreply.github.com>

* Move all tunable runners and custom ops into torch_custom_ops.

Signed-off-by: Yukun He <23156053+hyukn@users.noreply.github.com>

* Treat min_latency_mode as an independent dynamic tensor. Modify get_valid_tactics accordingly.

Signed-off-by: Yukun He <23156053+hyukn@users.noreply.github.com>

---------

Signed-off-by: Yukun He <23156053+hyukn@users.noreply.github.com>
2025-04-08 14:28:36 +08:00


import pytest
import torch
from utils.util import skip_pre_blackwell

import tensorrt_llm.quantization.utils.fp4_utils as fp4_utils
from tensorrt_llm._torch.autotuner import autotune
from tensorrt_llm._torch.modules.linear import Linear
from tensorrt_llm.models.modeling_utils import QuantAlgo, QuantConfig
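
# NVFP4 shares one FP8 (E4M3) scale factor per 16-element block of values.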
scaling_vector_size = 16


@skip_pre_blackwell
@pytest.mark.parametrize(
    "dtype", [torch.float16, torch.bfloat16]
)  # TODO: Do we need a float32 test case? fp4_quantize only supports fp16, bf16, fp8_e4m3
def test_fp4_linear(dtype):
    SEQ_LEN = 10
    HIDDEN_SIZE = 128
    torch.manual_seed(0)
    x = torch.randn((SEQ_LEN, HIDDEN_SIZE), dtype=dtype).cuda()
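    # Per-tensor global scale: (448 * 6) / amax, where 448 is the FP8 E4M3
    # maximum (the block-scale range) and 6 is the FP4 E2M1 maximum.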
    x_sf_global = (448 * 6) / x.abs().max().float()
    w = torch.randn((HIDDEN_SIZE, HIDDEN_SIZE), dtype=dtype).cuda()
    w_sf_global = (448 * 6) / w.abs().max().float()
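
    # Quantize the weight to packed FP4 plus FP8 block scales in the
    # swizzled (interleaved) device layout.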
    w_fp4, w_sf_block = torch.ops.trtllm.fp4_quantize(w, w_sf_global,
                                                      scaling_vector_size,
                                                      False)
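
    # Build a Linear module quantized with the NVFP4 recipe.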
    qc = QuantConfig(quant_algo=QuantAlgo.NVFP4)
    l_fp4 = Linear(in_features=HIDDEN_SIZE,
                   out_features=HIDDEN_SIZE,
                   bias=False,
                   dtype=dtype,
                   quant_config=qc)
    assert l_fp4.weight.dtype == fp4_utils.float4_e2m1x2
    assert l_fp4.weight_scale.dtype == fp4_utils.float4_sf_dtype
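
    # Undo the swizzle so the scales match the plain row-major layout of a
    # modelopt checkpoint; load_weights is expected to re-interleave them.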
    w_sf_block_unswizzled = (
        torch.ops.tensorrt_llm.nvfp4_block_scale_interleave_reverse(
            w_sf_block.cpu().view(HIDDEN_SIZE, -1)))
    l_fp4.load_weights([{
        'input_scale':
        1.0 / x_sf_global.cpu(),  # Simulates amax/(448*6) in modelopt ckpt
        'weight':
        w_fp4.cpu(),
        'weight_scale':
        w_sf_block_unswizzled.view(
            torch.float8_e4m3fn),  # Simulates float8_e4m3fn in modelopt ckpt
        'weight_scale_2':
        1.0 / w_sf_global.cpu()  # Simulates amax/(448*6) in modelopt ckpt
    }])
    l_fp4 = l_fp4.cuda()
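
    # The loaded tensors should match the directly quantized references.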
    torch.testing.assert_close(l_fp4.weight, w_fp4)
    torch.testing.assert_close(l_fp4.input_scale[0], x_sf_global)
    torch.testing.assert_close(l_fp4.weight_scale, w_sf_block)
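
    # alpha dequantizes the GEMM output:
    # 1 / (input global scale * weight global scale).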
    alpha_ref = 1.0 / (w_sf_global * x_sf_global)
    torch.testing.assert_close(l_fp4.alpha[0], alpha_ref)
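
    # Run the quantized Linear under the autotune() context so the AutoTuner
    # profiles and caches the best tactic for this shape.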
    with torch.inference_mode(), autotune():
        output = l_fp4.forward(x)

    # Reference: quantize the activation and call the fp4_gemm op directly.
    with torch.inference_mode():
        x_fp4, x_sf_block = torch.ops.trtllm.fp4_quantize(
            x, x_sf_global, scaling_vector_size, False)
        output_ref = torch.ops.trtllm.fp4_gemm(x_fp4, w_fp4, x_sf_block,
                                               w_sf_block, alpha_ref, False,
                                               dtype)

    # compare
    torch.cuda.synchronize()
    torch.testing.assert_close(output, output_ref)