sycl: fix soft_max_f32 max reduction (#24451)

This commit is contained in:
someoneinjd
2026-06-15 15:10:12 +08:00
committed by GitHub
parent 72be44f1d2
commit d8a3f523c8
+7 -8
View File
@@ -56,7 +56,7 @@ static void soft_max_f32(const float * x,
: block_size_template;
const int nthreads = block_size;
const int nwarps = nthreads / WARP_SIZE;
size_t nreduce = nwarps / WARP_SIZE;
const size_t nreduce = nwarps / WARP_SIZE;
const int tid = item_ct1.get_local_id(2);
@@ -105,17 +105,15 @@ static void soft_max_f32(const float * x,
max_val = warp_reduce_max<WARP_SIZE>(max_val);
if (block_size > WARP_SIZE) {
if (warp_id == 0) {
buf_iw[lane_id] = -INFINITY;
}
item_ct1.barrier();
if (lane_id == 0) {
buf_iw[warp_id] = max_val;
}
item_ct1.barrier();
max_val = buf_iw[lane_id];
max_val = -INFINITY;
for (int i = lane_id; i < nwarps; i += WARP_SIZE) {
max_val = sycl::max(max_val, buf_iw[i]);
}
max_val = warp_reduce_max<WARP_SIZE>(max_val);
}
float tmp = 0.0f; // partial sum
@@ -290,7 +288,8 @@ static void soft_max_f32_sycl(const float *x, const T *mask,
cgh.parallel_for(
sycl::nd_range<3>(block_nums * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1) {
[=](sycl::nd_item<3> item_ct1)
[[sycl::reqd_sub_group_size(WARP_SIZE)]] {
soft_max_f32<false, 0, 0>(
x, mask, sinks, dst, params,
dpct_local_acc_ct1