mirror of
https://github.com/vllm-project/vllm.git
synced 2026-06-06 00:16:14 +00:00
[Bugfix] Fix RMSNorm kernels to multiply in weight's native dtype (#42379)
Signed-off-by: Lanze Liu <lanzetech@gmail.com>
Co-authored-by: mergify[bot] <37929162+mergify[bot]@users.noreply.github.com>
This commit is contained in:
@@ -78,8 +78,7 @@ __global__ void rms_norm_kernel(
|
||||
#pragma unroll
|
||||
for (int j = 0; j < VEC_SIZE; j++) {
|
||||
float x = static_cast<float>(src1.val[j]);
|
||||
float w = static_cast<float>(src2.val[j]);
|
||||
dst.val[j] = static_cast<scalar_t>(x * s_variance * w);
|
||||
dst.val[j] = static_cast<scalar_t>(x * s_variance) * src2.val[j];
|
||||
}
|
||||
v_out[i] = dst;
|
||||
}
|
||||
@@ -143,8 +142,7 @@ fused_add_rms_norm_kernel(
|
||||
#pragma unroll
|
||||
for (int j = 0; j < width; ++j) {
|
||||
float x = Converter::convert(res.data[j]);
|
||||
float wf = Converter::convert(w.data[j]);
|
||||
out.data[j] = Converter::convert(x * s_variance * wf);
|
||||
out.data[j] = Converter::convert(x * s_variance) * w.data[j];
|
||||
}
|
||||
input_v[strided_id] = out;
|
||||
}
|
||||
@@ -183,8 +181,8 @@ fused_add_rms_norm_kernel(
|
||||
|
||||
for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
|
||||
float x = (float)residual[blockIdx.x * hidden_size + idx];
|
||||
float w = (float)weight[idx];
|
||||
input[blockIdx.x * input_stride + idx] = (scalar_t)(x * s_variance * w);
|
||||
input[blockIdx.x * input_stride + idx] =
|
||||
(scalar_t)(x * s_variance) * weight[idx];
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -66,13 +66,8 @@ __global__ void rms_norm_static_fp8_quant_kernel(
|
||||
#pragma unroll
|
||||
for (int j = 0; j < VEC_SIZE; j++) {
|
||||
float x = static_cast<float>(src1.val[j]);
|
||||
float w = static_cast<float>(src2.val[j]);
|
||||
// Round normalized result through scalar_t to match the precision of the
|
||||
// unfused composite (rms_norm writes scalar_t, then
|
||||
// static_scaled_fp8_quant re-loads it as float before FP8 conversion).
|
||||
// Without this round, the fused path is strictly more accurate and
|
||||
// disagrees with the composite at exact E4M3 quantization tie boundaries.
|
||||
scalar_t out_norm = static_cast<scalar_t>(x * s_variance * w);
|
||||
// Multiply in weight's native dtype to match rms_norm_kernel.
|
||||
scalar_t out_norm = static_cast<scalar_t>(x * s_variance) * src2.val[j];
|
||||
out[blockIdx.x * hidden_size + idx * VEC_SIZE + j] =
|
||||
scaled_fp8_conversion<true, fp8_type>(static_cast<float>(out_norm),
|
||||
scale_inv);
|
||||
@@ -142,12 +137,8 @@ fused_add_rms_norm_static_fp8_quant_kernel(
|
||||
#pragma unroll
|
||||
for (int i = 0; i < width; ++i) {
|
||||
float x = Converter::convert(res.data[i]);
|
||||
float wf = Converter::convert(w.data[i]);
|
||||
// See note in rms_norm_static_fp8_quant_kernel: round through scalar_t
|
||||
// to match the unfused composite path at FP8 boundaries. We use the
|
||||
// backend's hip_type for the intermediate since c10::Half/BFloat16 has
|
||||
// ambiguous conversions on CUDA and no implicit conversion on ROCm.
|
||||
HipT out_norm_h = Converter::convert(x * s_variance * wf);
|
||||
// Multiply in weight's native dtype to match fused_add_rms_norm_kernel.
|
||||
HipT out_norm_h = Converter::convert(x * s_variance) * w.data[i];
|
||||
out[id * width + i] = scaled_fp8_conversion<true, fp8_type>(
|
||||
Converter::convert(out_norm_h), scale_inv);
|
||||
}
|
||||
@@ -192,10 +183,8 @@ fused_add_rms_norm_static_fp8_quant_kernel(
|
||||
|
||||
for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
|
||||
float x = (float)residual[blockIdx.x * hidden_size + idx];
|
||||
float w = (float)weight[idx];
|
||||
// See note in rms_norm_static_fp8_quant_kernel: round through scalar_t
|
||||
// to match the unfused composite path at FP8 boundaries.
|
||||
scalar_t out_norm = static_cast<scalar_t>(x * s_variance * w);
|
||||
// Multiply in weight's native dtype to match fused_add_rms_norm_kernel.
|
||||
scalar_t out_norm = static_cast<scalar_t>(x * s_variance) * weight[idx];
|
||||
out[blockIdx.x * hidden_size + idx] = scaled_fp8_conversion<true, fp8_type>(
|
||||
static_cast<float>(out_norm), scale_inv);
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user