From 834a7806554ea6a96fda2380cf98d0098e337ae1 Mon Sep 17 00:00:00 2001
From: Simeng Liu <109828133+SimengLiu-nv@users.noreply.github.com>
Date: Wed, 29 Oct 2025 13:58:19 -0700
Subject: [PATCH] [https://nvbugs/5599086][fix] Fix FP8 Linear module for spark (#8707)

Signed-off-by: Simeng Liu
---
 tensorrt_llm/_torch/modules/linear.py | 26 +++++++++++---------------
 1 file changed, 11 insertions(+), 15 deletions(-)

diff --git a/tensorrt_llm/_torch/modules/linear.py b/tensorrt_llm/_torch/modules/linear.py
index 9f4e233aad..75a8a472bf 100644
--- a/tensorrt_llm/_torch/modules/linear.py
+++ b/tensorrt_llm/_torch/modules/linear.py
@@ -444,16 +444,14 @@ class FP8QDQLinearMethod(LinearMethodBase):
 
         copy_weight(module.weight_scale, max(weight_scale))
 
-        q_weight = q_weight.to(module.dtype) * weight_scale[0]
-        k_weight = k_weight.to(module.dtype) * weight_scale[1]
-        v_weight = v_weight.to(module.dtype) * weight_scale[2]
+        # use in-place multiplication and division to avoid extra memory allocation
+        q_weight = q_weight.to(module.dtype).mul_(weight_scale[0])
+        k_weight = k_weight.to(module.dtype).mul_(weight_scale[1])
+        v_weight = v_weight.to(module.dtype).mul_(weight_scale[2])
 
         fused_weight = torch.cat((q_weight, k_weight, v_weight))
-        if module.weight_scale.device != fused_weight.device:
-            module.weight_scale = Parameter(
-                module.weight_scale.data.to(fused_weight.device))
-        fused_weight = (fused_weight / module.weight_scale).to(
-            torch.float8_e4m3fn)
+        fused_weight = fused_weight.div_(
+            module.weight_scale.to(fused_weight.device)).to(torch.float8_e4m3fn)
         copy_weight(module.weight, fused_weight)
 
         # Load k and v scales, used for NVFP4 KV cache
@@ -486,14 +484,12 @@ class FP8QDQLinearMethod(LinearMethodBase):
         gate_weight, up_weight = load_weights_fused_gate_up_helper(
             module, weights)
 
-        gate_weight = gate_weight.to(module.dtype) * weight_scale[0]
-        up_weight = up_weight.to(module.dtype) * weight_scale[1]
+        # use in-place multiplication and division to avoid extra memory allocation
+        gate_weight = gate_weight.to(module.dtype).mul_(weight_scale[0])
+        up_weight = up_weight.to(module.dtype).mul_(weight_scale[1])
         fused_weight = torch.cat((gate_weight, up_weight))
-        if module.weight_scale.device != fused_weight.device:
-            module.weight_scale = Parameter(
-                module.weight_scale.data.to(fused_weight.device))
-        fused_weight = (fused_weight / module.weight_scale).to(
-            torch.float8_e4m3fn)
+        fused_weight = fused_weight.div_(
+            module.weight_scale.to(fused_weight.device)).to(torch.float8_e4m3fn)
         copy_weight(module.weight, fused_weight)
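
Note (not part of the patch): the sketch below illustrates the in-place rescale-and-requantize pattern the fix applies when fusing per-projection FP8 weights. The function name fuse_fp8_qkv, the plain-float scales (standing in for the module's weight_scale Parameter), and the dummy shapes are assumptions for illustration, not TensorRT-LLM API; it assumes a PyTorch build with torch.float8_e4m3fn support.

import torch

def fuse_fp8_qkv(q_w, k_w, v_w, q_s, k_s, v_s, dtype=torch.bfloat16):
    """Dequantize per-projection FP8 weights, fuse them, and requantize with a
    single shared scale, using in-place ops so no extra full-size temporaries
    are allocated beyond the unavoidable upcast copies."""
    fused_scale = max(q_s, k_s, v_s)

    # .to(dtype) allocates the upcast copy; mul_ then rescales that buffer in
    # place, instead of allocating a second tensor as `x * scale` would.
    q = q_w.to(dtype).mul_(q_s)
    k = k_w.to(dtype).mul_(k_s)
    v = v_w.to(dtype).mul_(v_s)

    fused = torch.cat((q, k, v))
    # div_ reuses the concatenated buffer; the final .to() casts it to FP8.
    return fused.div_(fused_scale).to(torch.float8_e4m3fn), fused_scale

if __name__ == "__main__":
    torch.manual_seed(0)
    scales = [0.5, 0.25, 0.125]
    # Illustrative dummy weights already stored in FP8 with per-projection scales.
    weights = [(torch.randn(8, 16) / s).to(torch.float8_e4m3fn) for s in scales]
    fused_w, fused_s = fuse_fp8_qkv(*weights, *scales)
    print(fused_w.shape, fused_w.dtype, fused_s)

The same pattern covers the fused gate/up case in the second hunk; the patch additionally drops the Parameter re-wrap and instead moves weight_scale to the fused weight's device inline inside the div_ call.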