Mirror of https://github.com/NVIDIA/TensorRT-LLM.git (synced 2026-01-14 06:27:45 +08:00)
Fix missing bias add for FP4Linear. (#3361)
Signed-off-by: Tracin <10434017+Tracin@users.noreply.github.com>
Co-authored-by: QI JUN <22017000+QiJune@users.noreply.github.com>
commit 2a2b7bfc66
parent 5bdf997963
@@ -323,7 +323,8 @@ class Linear(nn.Module):
             output = torch.ops.trtllm.fp8_block_scaling_gemm(
                 act_input_fp8, self.weight, act_input_sf, self.weight_scale)
+
             if bias is not None:
                 output = output + bias
         elif self.has_nv_fp4:
             if isinstance(input, Fp4QuantizedTensor):
                 act_fp4, act_sf = input.fp4_tensor, input.scaling_factor
@@ -336,6 +337,8 @@ class Linear(nn.Module):
                 act_sf, self.weight_scale,
                 self.alpha, False,
                 self.dtype)
+            if bias is not None:
+                output = output + bias
         else:
             # TODO(zhenhuanc): support other quant mode
             raise ValueError(f'unsupported quant mode: {qc.quant_mode}')
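Below is a minimal, self-contained sketch (not the TensorRT-LLM implementation) of the behavior this commit restores in the nv_fp4 branch: the quantized GEMM kernel returns only the matrix product, so the module has to add the bias itself afterwards, just as the FP8 block-scaling branch above already does. The names fake_quant_gemm and linear_forward are hypothetical stand-ins for torch.ops.trtllm.nvfp4_gemm and the Linear module's forward path.

import torch


def fake_quant_gemm(act: torch.Tensor, weight: torch.Tensor) -> torch.Tensor:
    # Hypothetical stand-in for a quantized GEMM kernel such as
    # torch.ops.trtllm.nvfp4_gemm: it computes only act @ weight^T and does
    # not fuse the bias into the kernel.
    return act @ weight.t()


def linear_forward(act: torch.Tensor,
                   weight: torch.Tensor,
                   bias: torch.Tensor | None = None) -> torch.Tensor:
    output = fake_quant_gemm(act, weight)
    # The pattern this commit adds to the nv_fp4 branch: apply the bias
    # after the GEMM when one is present.
    if bias is not None:
        output = output + bias
    return output


# Quick check of the sketch against a plain dense linear layer.
act = torch.randn(4, 16)
weight = torch.randn(8, 16)
bias = torch.randn(8)
assert torch.allclose(linear_forward(act, weight, bias),
                      torch.nn.functional.linear(act, weight, bias))

Without the added bias lines, the FP4 path would silently return an unbiased projection, which is the regression the commit fixes.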