Fix missing bias add for FP4Linear. (#3361)

Signed-off-by: Tracin <10434017+Tracin@users.noreply.github.com>
Co-authored-by: QI JUN <22017000+QiJune@users.noreply.github.com>
Tracin 2025-04-09 09:17:54 +08:00 committed by GitHub
parent 5bdf997963
commit 2a2b7bfc66


@@ -323,7 +323,8 @@ class Linear(nn.Module):
            output = torch.ops.trtllm.fp8_block_scaling_gemm(
                act_input_fp8, self.weight, act_input_sf, self.weight_scale)
            if bias is not None:
                output = output + bias
        elif self.has_nv_fp4:
            if isinstance(input, Fp4QuantizedTensor):
                act_fp4, act_sf = input.fp4_tensor, input.scaling_factor
@@ -336,6 +337,8 @@ class Linear(nn.Module):
                act_sf, self.weight_scale,
                self.alpha, False,
                self.dtype)
            if bias is not None:
                output = output + bias
        else:
            # TODO(zhenhuanc): support other quant mode
            raise ValueError(f'unsupported quant mode: {qc.quant_mode}')
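
Context for the change: the FP4 GEMM output did not have the layer bias applied, so a bias configured on an FP4Linear layer was silently dropped. The commit adds the bias after the GEMM in the FP4 branch, mirroring the bias handling shown for the FP8 block-scaling branch in the first hunk. Below is a minimal, self-contained sketch of that post-GEMM bias-add pattern; it is not the TensorRT-LLM implementation, and fake_fp4_gemm plus its parameters are hypothetical stand-ins for the real FP4 GEMM kernel, which is only partially visible in the second hunk.

import torch

def fake_fp4_gemm(act, weight, act_sf, weight_sf, alpha, out_dtype):
    # Hypothetical stand-in for the real FP4 GEMM kernel: a plain GEMM with
    # per-tensor scales, used only so this sketch is self-contained.
    return (alpha * (act * act_sf) @ (weight * weight_sf).t()).to(out_dtype)

def fp4_linear_forward(act, act_sf, weight, weight_sf, alpha, out_dtype, bias=None):
    output = fake_fp4_gemm(act, weight, act_sf, weight_sf, alpha, out_dtype)
    # The GEMM output does not include the bias, so it must be added
    # explicitly afterwards -- the step this commit adds for the FP4 branch.
    if bias is not None:
        output = output + bias
    return output

# Usage sketch: a (4, 16) activation against an (8, 16) weight with a bias of 8.
out = fp4_linear_forward(torch.randn(4, 16), 1.0, torch.randn(8, 16), 1.0,
                         1.0, torch.float32, bias=torch.randn(8))
assert out.shape == (4, 8)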