mirror of https://github.com/NVIDIA/TensorRT-LLM.git
synced 2026-01-14 06:27:45 +08:00
[https://nvbugs/5599086][fix] Fix FP8 Linear module for Spark (#8707)
Signed-off-by: Simeng Liu <simengl@nvidia.com>
This commit is contained in:
parent 45b36cc069
commit 834a780655
@@ -444,16 +444,14 @@ class FP8QDQLinearMethod(LinearMethodBase):
         copy_weight(module.weight_scale, max(weight_scale))
 
-        q_weight = q_weight.to(module.dtype) * weight_scale[0]
-        k_weight = k_weight.to(module.dtype) * weight_scale[1]
-        v_weight = v_weight.to(module.dtype) * weight_scale[2]
+        # use in-place multiplication and division to avoid extra memory allocation
+        q_weight = q_weight.to(module.dtype).mul_(weight_scale[0])
+        k_weight = k_weight.to(module.dtype).mul_(weight_scale[1])
+        v_weight = v_weight.to(module.dtype).mul_(weight_scale[2])
 
         fused_weight = torch.cat((q_weight, k_weight, v_weight))
-        if module.weight_scale.device != fused_weight.device:
-            module.weight_scale = Parameter(
-                module.weight_scale.data.to(fused_weight.device))
-        fused_weight = (fused_weight / module.weight_scale).to(
-            torch.float8_e4m3fn)
+        fused_weight = fused_weight.div_(
+            module.weight_scale.to(fused_weight.device)).to(torch.float8_e4m3fn)
         copy_weight(module.weight, fused_weight)
 
         # Load k and v scales, used for NVFP4 KV cache
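The first change in this hunk switches the dequantize-and-rescale step from out-of-place `*` to in-place `mul_`, so the rescale reuses the buffer that `.to(module.dtype)` just allocated instead of materializing a second weight-sized tensor. A minimal sketch of the difference (not the TensorRT-LLM code; the shapes, dtype, and scale value are made-up assumptions):

import torch

# Stand-in checkpoint shard and per-shard scale; shape, dtype, and the
# scale value are illustrative assumptions, not values from the repo.
w = torch.randn(256, 256).to(torch.float8_e4m3fn)
scale = torch.tensor(0.5)

# Old pattern: `.to()` allocates one dequantized buffer, then `* scale`
# allocates a second tensor of the same size.
out_of_place = w.to(torch.bfloat16) * scale

# New pattern: `.to()` allocates the same buffer, but `mul_` scales it
# in place, so no second weight-sized tensor is created.
in_place = w.to(torch.bfloat16).mul_(scale)

# Same numerical result either way.
torch.testing.assert_close(out_of_place, in_place)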
@@ -486,14 +484,12 @@ class FP8QDQLinearMethod(LinearMethodBase):
         gate_weight, up_weight = load_weights_fused_gate_up_helper(
             module, weights)
 
-        gate_weight = gate_weight.to(module.dtype) * weight_scale[0]
-        up_weight = up_weight.to(module.dtype) * weight_scale[1]
+        # use in-place multiplication and division to avoid extra memory allocation
+        gate_weight = gate_weight.to(module.dtype).mul_(weight_scale[0])
+        up_weight = up_weight.to(module.dtype).mul_(weight_scale[1])
         fused_weight = torch.cat((gate_weight, up_weight))
-        if module.weight_scale.device != fused_weight.device:
-            module.weight_scale = Parameter(
-                module.weight_scale.data.to(fused_weight.device))
-        fused_weight = (fused_weight / module.weight_scale).to(
-            torch.float8_e4m3fn)
+        fused_weight = fused_weight.div_(
+            module.weight_scale.to(fused_weight.device)).to(torch.float8_e4m3fn)
         copy_weight(module.weight, fused_weight)
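The fused gate/up hunk applies the same in-place pattern. The second change shared by both hunks is in the requantization step: instead of reassigning module.weight_scale as a new Parameter on the fused weight's device, the in-place `div_` moves only a temporary copy of the scale, leaving the parameter itself untouched, so any code already holding a reference to it is unaffected. A rough sketch of that behaviour (device, shape, and value are illustrative assumptions, not repo values):

import torch
from torch.nn import Parameter

device = "cuda" if torch.cuda.is_available() else "cpu"

# Hypothetical scale Parameter living on CPU while the fused weight sits
# on the target device.
weight_scale = Parameter(torch.tensor(0.5), requires_grad=False)
fused_weight = torch.randn(8, 8, device=device)

# New behaviour: a temporary copy of the scale hops devices for the
# division; the Parameter itself is never reassigned.
fused_fp8 = fused_weight.div_(
    weight_scale.to(fused_weight.device)).to(torch.float8_e4m3fn)

assert weight_scale.device.type == "cpu"  # untouched, unlike the old code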