Mirror of https://github.com/NVIDIA/TensorRT-LLM.git, synced 2026-01-14 06:27:45 +08:00
Support dynamic per-tensor FP8 (#4250)
* Support dynamic per-tensor FP8

Signed-off-by: Tracin <10434017+Tracin@users.noreply.github.com>

* Update test cases.

Signed-off-by: Tracin <10434017+Tracin@users.noreply.github.com>

---------

Signed-off-by: Tracin <10434017+Tracin@users.noreply.github.com>
parent 11aa50d1ea
commit 46c5a56444
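For context on what the diff below adds: with static per-tensor FP8, the activation scale is calibrated offline and shipped in the checkpoint; with dynamic per-tensor FP8, the scale is recomputed from each incoming activation tensor. A minimal plain-PyTorch sketch of the two modes follows; the helper names are made up for illustration, and the real code path uses the TensorRT-LLM custom ops shown in the diff.

import torch

FP8_MAX = torch.finfo(torch.float8_e4m3fn).max  # 448.0 for e4m3

def fp8_quantize_static(x: torch.Tensor, input_scale: torch.Tensor):
    # Static mode: reuse a scale calibrated offline and stored in the checkpoint.
    q = (x / input_scale).clamp(-FP8_MAX, FP8_MAX).to(torch.float8_e4m3fn)
    return q, input_scale

def fp8_quantize_dynamic(x: torch.Tensor):
    # Dynamic mode: derive the per-tensor scale from the current activations.
    scale = x.abs().amax().float() / FP8_MAX
    scale = torch.clamp(scale, min=torch.finfo(torch.float32).tiny)  # guard against all-zero input
    q = (x / scale).clamp(-FP8_MAX, FP8_MAX).to(torch.float8_e4m3fn)
    return q, scale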
@@ -306,16 +306,24 @@ class Linear(nn.Module):
         if self.has_any_quant:
             qc = self.quant_config
             if self.has_fp8_qdq:
+                cur_input_scale = self.input_scale
                 if input.dtype != torch.float8_e4m3fn:
-                    qinput, _ = torch.ops.tensorrt_llm.static_quantize_e4m3_per_tensor(
-                        input, self.input_scale)
+                    if self.input_scale is not None:
+                        # Static quantization
+                        qinput, _ = torch.ops.tensorrt_llm.static_quantize_e4m3_per_tensor(
+                            input, self.input_scale)
+                    else:
+                        # Dynamic quantization
+                        qinput, cur_input_scale = torch.ops.tensorrt_llm.quantize_e4m3_per_tensor(
+                            input)
+                        cur_input_scale = cur_input_scale.to(torch.float32)
                 else:
                     qinput = input
                 # This op does not support bias now.
                 output = torch.ops.trtllm.cublas_scaled_mm(
                     qinput,
                     weight.t(),
-                    scale_a=self.input_scale,
+                    scale_a=cur_input_scale,
                     scale_b=self.weight_scale,
                     bias=None,
                     out_dtype=self.dtype or input.dtype,
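To make the control flow above concrete, here is a reference emulation of the new branch, reusing the hypothetical helpers from the sketch near the top. It dequantizes and uses a plain matmul instead of the cuBLAS FP8 GEMM, so it is only a numerics reference, not how the kernel actually runs.

def linear_fp8_reference(x, q_weight, weight_scale, input_scale=None,
                         out_dtype=torch.bfloat16):
    # Mirror of the new branch: use the calibrated scale when it exists,
    # otherwise quantize dynamically and carry the freshly computed scale.
    if input_scale is not None:
        qx, cur_scale = fp8_quantize_static(x, input_scale)
    else:
        qx, cur_scale = fp8_quantize_dynamic(x)
    # Emulate cublas_scaled_mm(qx, w.t(), scale_a=cur_scale, scale_b=weight_scale)
    # by dequantizing both operands and doing the matmul in float32.
    a = qx.to(torch.float32) * cur_scale
    b = q_weight.to(torch.float32) * weight_scale
    return (a @ b.t()).to(out_dtype)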
@@ -452,9 +460,15 @@ class Linear(nn.Module):
         if quant_mode.has_fp8_qdq():
             input_scale, weight_scale = load_weight_scales_fp8_qdq(
                 weights)
-            _copy(self.input_scale, input_scale[0])
+            if len(input_scale) != 0:
+                # Static quantization
+                _copy(self.input_scale, input_scale[0])
+                self.inv_input_scale.data = 1.0 / self.input_scale
+            else:
+                # Dynamic quantization
+                self.input_scale = None
+                self.inv_input_scale = None
             _copy(self.weight_scale, weight_scale[0])
-            self.inv_input_scale.data = 1.0 / self.input_scale
         elif quant_mode.has_nvfp4():
             input_scale, weight_scale, alpha = load_weight_scales_nvfp4(
                 weights,
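The load path above keys off whether the checkpoint supplied any activation scales at all: a non-empty list means statically calibrated FP8, while an empty list means the module drops its input_scale and inv_input_scale buffers so the forward path falls back to dynamic quantization. A condensed, hypothetical sketch of that decision:

import torch

def resolve_input_scale(scales: list[torch.Tensor]):
    # Non-empty: statically calibrated checkpoint, keep the fixed scale.
    # Empty: no activation scales were exported, signal dynamic mode with None.
    return scales[0] if scales else None

assert resolve_input_scale([torch.tensor(0.02)]) is not None  # static checkpoint
assert resolve_input_scale([]) is None                        # dynamic fallback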
@@ -499,7 +513,12 @@ class Linear(nn.Module):
         if quant_mode.has_fp8_qdq():
             input_scale, weight_scale = load_weight_scales_fp8_qdq(
                 weights)
-            _copy(self.input_scale, max(input_scale))
+            if len(input_scale) != 0:
+                # Static quantization
+                _copy(self.input_scale, max(input_scale))
+            else:
+                # Dynamic quantization
+                self.input_scale = None
             _copy(self.weight_scale, max(weight_scale))
             q_weight = q_weight.to(self.dtype) * weight_scale[0]
             k_weight = k_weight.to(self.dtype) * weight_scale[1]
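For the fused QKV projection, one activation tensor feeds all three sub-projections, so a single shared scale has to cover the largest per-projection range; taking max(input_scale) avoids clipping any of them, while each weight shard is still dequantized with its own weight scale before fusion. A small illustration with made-up numbers:

# Made-up calibrated activation scales for the q/k/v projections.
q_scale, k_scale, v_scale = 0.010, 0.025, 0.015

# One fused GEMM needs one activation scale; the maximum is the safe choice.
fused_input_scale = max(q_scale, k_scale, v_scale)  # 0.025

# In dynamic mode there is nothing to fuse at load time: the scale comes from
# the live activations, so the branch simply leaves input_scale as None.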
@@ -561,7 +580,12 @@ class Linear(nn.Module):
         if quant_mode.has_fp8_qdq():
             input_scale, weight_scale = load_weight_scales_fp8_qdq(
                 weights)
-            _copy(self.input_scale, max(input_scale))
+            if len(input_scale) != 0:
+                # Static quantization
+                _copy(self.input_scale, max(input_scale))
+            else:
+                # Dynamic quantization
+                self.input_scale = None
             _copy(self.weight_scale, max(weight_scale))
             gate_weight = gate_weight.to(self.dtype) * weight_scale[0]
             up_weight = up_weight.to(self.dtype) * weight_scale[1]
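A quick end-to-end sanity check, again using the reference helpers sketched earlier (illustrative only; the actual coverage is in the updated test cases mentioned in the commit message): both the static and the dynamic path should track an unquantized matmul within FP8 rounding error.

torch.manual_seed(0)
x = torch.randn(4, 64)
w = torch.randn(32, 64)

# Quantize the weight once, per tensor, the way an FP8 checkpoint would.
qw, w_scale = fp8_quantize_dynamic(w)

ref = x @ w.t()
dyn = linear_fp8_reference(x, qw, w_scale, input_scale=None,
                           out_dtype=torch.float32)
sta = linear_fp8_reference(x, qw, w_scale,
                           input_scale=x.abs().amax().float() / FP8_MAX,
                           out_dtype=torch.float32)

print((dyn - ref).abs().max(), (sta - ref).abs().max())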