[https://nvbugs/5599086][fix] Fix FP8 Linear module for spark (#8707)

Signed-off-by: Simeng Liu <simengl@nvidia.com>
Author: Simeng Liu
Date: 2025-10-29 13:58:19 -07:00 (committed by GitHub)
Parent: 45b36cc069
Commit: 834a780655

@@ -444,16 +444,14 @@ class FP8QDQLinearMethod(LinearMethodBase):
         copy_weight(module.weight_scale, max(weight_scale))
-        q_weight = q_weight.to(module.dtype) * weight_scale[0]
-        k_weight = k_weight.to(module.dtype) * weight_scale[1]
-        v_weight = v_weight.to(module.dtype) * weight_scale[2]
+        # use in-place multiplication and division to avoid extra memory allocation
+        q_weight = q_weight.to(module.dtype).mul_(weight_scale[0])
+        k_weight = k_weight.to(module.dtype).mul_(weight_scale[1])
+        v_weight = v_weight.to(module.dtype).mul_(weight_scale[2])
         fused_weight = torch.cat((q_weight, k_weight, v_weight))
-        if module.weight_scale.device != fused_weight.device:
-            module.weight_scale = Parameter(
-                module.weight_scale.data.to(fused_weight.device))
-        fused_weight = (fused_weight / module.weight_scale).to(
-            torch.float8_e4m3fn)
+        fused_weight = fused_weight.div_(
+            module.weight_scale.to(fused_weight.device)).to(torch.float8_e4m3fn)
         copy_weight(module.weight, fused_weight)
         # Load k and v scales, used for NVFP4 KV cache
@@ -486,14 +484,12 @@ class FP8QDQLinearMethod(LinearMethodBase):
         gate_weight, up_weight = load_weights_fused_gate_up_helper(
             module, weights)
-        gate_weight = gate_weight.to(module.dtype) * weight_scale[0]
-        up_weight = up_weight.to(module.dtype) * weight_scale[1]
+        # use in-place multiplication and division to avoid extra memory allocation
+        gate_weight = gate_weight.to(module.dtype).mul_(weight_scale[0])
+        up_weight = up_weight.to(module.dtype).mul_(weight_scale[1])
         fused_weight = torch.cat((gate_weight, up_weight))
-        if module.weight_scale.device != fused_weight.device:
-            module.weight_scale = Parameter(
-                module.weight_scale.data.to(fused_weight.device))
-        fused_weight = (fused_weight / module.weight_scale).to(
-            torch.float8_e4m3fn)
+        fused_weight = fused_weight.div_(
+            module.weight_scale.to(fused_weight.device)).to(torch.float8_e4m3fn)
         copy_weight(module.weight, fused_weight)
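
For context, a minimal sketch of the pattern both hunks adopt. The old path allocated a fresh tensor for each `*` and `/`, and rewrapped `module.weight_scale` as a `Parameter` on a new device; the new path reuses the buffer that `.to(module.dtype)` just produced and only moves a temporary copy of the scale for the division. All names below (shapes, scale values) are illustrative stand-ins, not tensorrt_llm code; only PyTorch is assumed.

```python
import torch

dtype = torch.bfloat16
q_weight = torch.randn(4, 8)          # stand-in for a loaded, dequantized shard
per_shard_scale = torch.tensor(0.5)   # stand-in for weight_scale[0]
fused_scale = torch.tensor(0.25)      # stand-in for module.weight_scale

# Old path: `.to(dtype)` allocates one tensor, then `*` allocates a second.
w_out_of_place = q_weight.to(dtype) * per_shard_scale

# New path: `mul_` writes into the tensor `.to(dtype)` just allocated,
# saving one full-sized temporary per shard.
w_in_place = q_weight.to(dtype).mul_(per_shard_scale)
assert torch.equal(w_out_of_place, w_in_place)

# Requantization follows the same idea: move only a temporary copy of the
# fused scale to the weight's device, divide in place, then cast to FP8,
# rather than reassigning the module's Parameter to a new device.
fused = torch.cat((w_in_place, w_in_place))
fused_fp8 = fused.div_(fused_scale.to(fused.device)).to(torch.float8_e4m3fn)
print(fused_fp8.dtype)  # torch.float8_e4m3fn
```

Since the scale factors are scalars (or a temporary copy), the in-place variant is numerically identical to the old code; the change only removes per-shard temporaries and the side effect of relocating `module.weight_scale`.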