Support dynamic per-tensor FP8 (#4250)

* Support dynamic per-tensor FP8

Signed-off-by: Tracin <10434017+Tracin@users.noreply.github.com>

* Update test cases.

Signed-off-by: Tracin <10434017+Tracin@users.noreply.github.com>

---------

Signed-off-by: Tracin <10434017+Tracin@users.noreply.github.com>
Tracin 2025-05-16 13:33:58 +08:00 committed by GitHub
parent 11aa50d1ea
commit 46c5a56444

@@ -306,16 +306,24 @@ class Linear(nn.Module):
         if self.has_any_quant:
             qc = self.quant_config
             if self.has_fp8_qdq:
+                cur_input_scale = self.input_scale
                 if input.dtype != torch.float8_e4m3fn:
-                    qinput, _ = torch.ops.tensorrt_llm.static_quantize_e4m3_per_tensor(
-                        input, self.input_scale)
+                    if self.input_scale is not None:
+                        # Static quantization
+                        qinput, _ = torch.ops.tensorrt_llm.static_quantize_e4m3_per_tensor(
+                            input, self.input_scale)
+                    else:
+                        # Dynamic quantization
+                        qinput, cur_input_scale = torch.ops.tensorrt_llm.quantize_e4m3_per_tensor(
+                            input)
+                        cur_input_scale = cur_input_scale.to(torch.float32)
                 else:
                     qinput = input
                 # This op does not support bias now.
                 output = torch.ops.trtllm.cublas_scaled_mm(
                     qinput,
                     weight.t(),
-                    scale_a=self.input_scale,
+                    scale_a=cur_input_scale,
                     scale_b=self.weight_scale,
                     bias=None,
                     out_dtype=self.dtype or input.dtype,
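
For context, the new dynamic branch lets `quantize_e4m3_per_tensor` derive the FP8 scale from the activation's runtime absolute maximum (the usual amax / 448 recipe for E4M3) instead of a calibrated constant, and `cublas_scaled_mm` then dequantizes the GEMM with `scale_a * scale_b`. The snippet below is a minimal pure-PyTorch sketch of that math under those assumptions; `dynamic_quantize_e4m3_per_tensor` and `scaled_mm_reference` are illustrative names, not the TensorRT-LLM kernels themselves.

```python
import torch

def dynamic_quantize_e4m3_per_tensor(x: torch.Tensor):
    # Sketch only: scale = amax(|x|) / 448 so the largest activation maps to the
    # e4m3 maximum; the returned scale is the fp32 per-tensor dequantization factor.
    fmax = torch.finfo(torch.float8_e4m3fn).max
    scale = x.abs().amax().float().clamp(min=1e-12) / fmax
    q = (x.float() / scale).clamp(-fmax, fmax).to(torch.float8_e4m3fn)
    return q, scale

def scaled_mm_reference(q_a, q_b_t, scale_a, scale_b, out_dtype=torch.bfloat16):
    # Emulates the math of an FP8 scaled matmul: dequantize per tensor, matmul,
    # cast. The real kernel folds scale_a * scale_b into the cuBLASLt epilogue.
    return ((q_a.float() * scale_a) @ (q_b_t.float() * scale_b)).to(out_dtype)
```
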
@@ -452,9 +460,15 @@ class Linear(nn.Module):
         if quant_mode.has_fp8_qdq():
             input_scale, weight_scale = load_weight_scales_fp8_qdq(
                 weights)
-            _copy(self.input_scale, input_scale[0])
+            if len(input_scale) != 0:
+                # Static quantization
+                _copy(self.input_scale, input_scale[0])
+                self.inv_input_scale.data = 1.0 / self.input_scale
+            else:
+                # Dynamic quantization
+                self.input_scale = None
+                self.inv_input_scale = None
             _copy(self.weight_scale, weight_scale[0])
-            self.inv_input_scale.data = 1.0 / self.input_scale
         elif quant_mode.has_nvfp4():
             input_scale, weight_scale, alpha = load_weight_scales_nvfp4(
                 weights,
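
The loading change above treats a missing `input_scale` in the checkpoint as the signal to quantize dynamically: `self.input_scale` and `self.inv_input_scale` are set to `None`, which is exactly what the forward path checks when choosing between the static and dynamic branches. Below is a hedged sketch of that decision; `resolve_input_scale` is a hypothetical helper for illustration, the real logic lives in `load_weight_scales_fp8_qdq` and the lines above.

```python
from typing import Optional
import torch

def resolve_input_scale(shard_weights: dict) -> Optional[torch.Tensor]:
    # Hypothetical helper mirroring the decision above: a checkpoint that ships
    # "input_scale" entries selects static quantization; an FP8 checkpoint
    # without them returns None, which the forward pass treats as "derive the
    # scale from the activations at runtime".
    scales = [v.float() for k, v in shard_weights.items() if k.endswith("input_scale")]
    if scales:
        return torch.stack(scales).max()   # fold shard scales into one tensor
    return None
```
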
@@ -499,7 +513,12 @@ class Linear(nn.Module):
         if quant_mode.has_fp8_qdq():
             input_scale, weight_scale = load_weight_scales_fp8_qdq(
                 weights)
-            _copy(self.input_scale, max(input_scale))
+            if len(input_scale) != 0:
+                # Static quantization
+                _copy(self.input_scale, max(input_scale))
+            else:
+                # Dynamic quantization
+                self.input_scale = None
             _copy(self.weight_scale, max(weight_scale))
             q_weight = q_weight.to(self.dtype) * weight_scale[0]
             k_weight = k_weight.to(self.dtype) * weight_scale[1]
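
In this fused-QKV path, each shard's weight is dequantized with its own `weight_scale` while the module keeps only the maximum as the shared scale; the fused tensor is then, presumably, requantized under that shared scale just below the hunk so a single `scale_b` covers the GEMM. A rough sketch of that pattern, with `fuse_fp8_shards` as an illustrative name rather than an actual TensorRT-LLM function:

```python
import torch

def fuse_fp8_shards(shards, shard_scales, dtype=torch.bfloat16):
    # Dequantize each FP8 shard with its own weight scale, concatenate along the
    # output dimension, then requantize the fused tensor with the shared max
    # scale. Sketch only, under the assumptions stated in the text above.
    shared = torch.stack([s.float() for s in shard_scales]).max()
    fused = torch.cat([w.to(dtype) * s for w, s in zip(shards, shard_scales)], dim=0)
    fmax = torch.finfo(torch.float8_e4m3fn).max
    q_fused = (fused.float() / shared).clamp(-fmax, fmax).to(torch.float8_e4m3fn)
    return q_fused, shared
```
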
@@ -561,7 +580,12 @@ class Linear(nn.Module):
         if quant_mode.has_fp8_qdq():
             input_scale, weight_scale = load_weight_scales_fp8_qdq(
                 weights)
-            _copy(self.input_scale, max(input_scale))
+            if len(input_scale) != 0:
+                # Static quantization
+                _copy(self.input_scale, max(input_scale))
+            else:
+                # Dynamic quantization
+                self.input_scale = None
             _copy(self.weight_scale, max(weight_scale))
             gate_weight = gate_weight.to(self.dtype) * weight_scale[0]
             up_weight = up_weight.to(self.dtype) * weight_scale[1]
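
The fused gate/up path mirrors the QKV one. As a rough stand-in for the updated test cases mentioned in the commit message, a loose pure-PyTorch round-trip check of the dynamic per-tensor recipe might look like the sketch below; it is illustrative only and does not exercise the TensorRT-LLM ops or the real test file.

```python
import torch

def test_dynamic_fp8_matmul_reference():
    # Toy check in the spirit of the updated tests: a dynamically quantized
    # per-tensor FP8 matmul should track the fp32 reference within a loose
    # relative-error bound.
    torch.manual_seed(0)
    x, w = torch.randn(8, 64), torch.randn(32, 64)
    fmax = torch.finfo(torch.float8_e4m3fn).max
    sx, sw = x.abs().amax() / fmax, w.abs().amax() / fmax      # dynamic scales
    qx = (x / sx).clamp(-fmax, fmax).to(torch.float8_e4m3fn)
    qw = (w / sw).clamp(-fmax, fmax).to(torch.float8_e4m3fn)
    out = (qx.float() * sx) @ (qw.float() * sw).t()            # dequantized GEMM
    ref = x @ w.t()
    assert (out - ref).norm() / ref.norm() < 0.1
```
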