[Bugfix] Fix RMSNorm kernels to multiply in weight's native dtype (#42379)

Signed-off-by: Lanze Liu <lanzetech@gmail.com>
Co-authored-by: mergify[bot] <37929162+mergify[bot]@users.noreply.github.com>
This commit is contained in:
Lanze Liu
2026-05-29 23:16:53 -07:00
committed by GitHub
parent e9499996df
commit 124fac10cb
2 changed files with 10 additions and 23 deletions
+4 -6
View File
@@ -78,8 +78,7 @@ __global__ void rms_norm_kernel(
#pragma unroll
for (int j = 0; j < VEC_SIZE; j++) {
float x = static_cast<float>(src1.val[j]);
float w = static_cast<float>(src2.val[j]);
dst.val[j] = static_cast<scalar_t>(x * s_variance * w);
dst.val[j] = static_cast<scalar_t>(x * s_variance) * src2.val[j];
}
v_out[i] = dst;
}
@@ -143,8 +142,7 @@ fused_add_rms_norm_kernel(
#pragma unroll
for (int j = 0; j < width; ++j) {
float x = Converter::convert(res.data[j]);
float wf = Converter::convert(w.data[j]);
out.data[j] = Converter::convert(x * s_variance * wf);
out.data[j] = Converter::convert(x * s_variance) * w.data[j];
}
input_v[strided_id] = out;
}
@@ -183,8 +181,8 @@ fused_add_rms_norm_kernel(
for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
float x = (float)residual[blockIdx.x * hidden_size + idx];
float w = (float)weight[idx];
input[blockIdx.x * input_stride + idx] = (scalar_t)(x * s_variance * w);
input[blockIdx.x * input_stride + idx] =
(scalar_t)(x * s_variance) * weight[idx];
}
}
@@ -66,13 +66,8 @@ __global__ void rms_norm_static_fp8_quant_kernel(
#pragma unroll
for (int j = 0; j < VEC_SIZE; j++) {
float x = static_cast<float>(src1.val[j]);
float w = static_cast<float>(src2.val[j]);
// Round normalized result through scalar_t to match the precision of the
// unfused composite (rms_norm writes scalar_t, then
// static_scaled_fp8_quant re-loads it as float before FP8 conversion).
// Without this round, the fused path is strictly more accurate and
// disagrees with the composite at exact E4M3 quantization tie boundaries.
scalar_t out_norm = static_cast<scalar_t>(x * s_variance * w);
// Multiply in weight's native dtype to match rms_norm_kernel.
scalar_t out_norm = static_cast<scalar_t>(x * s_variance) * src2.val[j];
out[blockIdx.x * hidden_size + idx * VEC_SIZE + j] =
scaled_fp8_conversion<true, fp8_type>(static_cast<float>(out_norm),
scale_inv);
@@ -142,12 +137,8 @@ fused_add_rms_norm_static_fp8_quant_kernel(
#pragma unroll
for (int i = 0; i < width; ++i) {
float x = Converter::convert(res.data[i]);
float wf = Converter::convert(w.data[i]);
// See note in rms_norm_static_fp8_quant_kernel: round through scalar_t
// to match the unfused composite path at FP8 boundaries. We use the
// backend's hip_type for the intermediate since c10::Half/BFloat16 has
// ambiguous conversions on CUDA and no implicit conversion on ROCm.
HipT out_norm_h = Converter::convert(x * s_variance * wf);
// Multiply in weight's native dtype to match fused_add_rms_norm_kernel.
HipT out_norm_h = Converter::convert(x * s_variance) * w.data[i];
out[id * width + i] = scaled_fp8_conversion<true, fp8_type>(
Converter::convert(out_norm_h), scale_inv);
}
@@ -192,10 +183,8 @@ fused_add_rms_norm_static_fp8_quant_kernel(
for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
float x = (float)residual[blockIdx.x * hidden_size + idx];
float w = (float)weight[idx];
// See note in rms_norm_static_fp8_quant_kernel: round through scalar_t
// to match the unfused composite path at FP8 boundaries.
scalar_t out_norm = static_cast<scalar_t>(x * s_variance * w);
// Multiply in weight's native dtype to match fused_add_rms_norm_kernel.
scalar_t out_norm = static_cast<scalar_t>(x * s_variance) * weight[idx];
out[blockIdx.x * hidden_size + idx] = scaled_fp8_conversion<true, fp8_type>(
static_cast<float>(out_norm), scale_inv);
}