From d8a3f523c8da0111b8bcf2d73ca50819579fca4b Mon Sep 17 00:00:00 2001 From: someoneinjd Date: Mon, 15 Jun 2026 15:10:12 +0800 Subject: [PATCH] sycl: fix soft_max_f32 max reduction (#24451) --- ggml/src/ggml-sycl/softmax.cpp | 15 +++++++-------- 1 file changed, 7 insertions(+), 8 deletions(-) diff --git a/ggml/src/ggml-sycl/softmax.cpp b/ggml/src/ggml-sycl/softmax.cpp index fdf9b843e0..18bf379bbe 100644 --- a/ggml/src/ggml-sycl/softmax.cpp +++ b/ggml/src/ggml-sycl/softmax.cpp @@ -56,7 +56,7 @@ static void soft_max_f32(const float * x, : block_size_template; const int nthreads = block_size; const int nwarps = nthreads / WARP_SIZE; - size_t nreduce = nwarps / WARP_SIZE; + const size_t nreduce = nwarps / WARP_SIZE; const int tid = item_ct1.get_local_id(2); @@ -105,17 +105,15 @@ static void soft_max_f32(const float * x, max_val = warp_reduce_max(max_val); if (block_size > WARP_SIZE) { - if (warp_id == 0) { - buf_iw[lane_id] = -INFINITY; - } - item_ct1.barrier(); - if (lane_id == 0) { buf_iw[warp_id] = max_val; } item_ct1.barrier(); - max_val = buf_iw[lane_id]; + max_val = -INFINITY; + for (int i = lane_id; i < nwarps; i += WARP_SIZE) { + max_val = sycl::max(max_val, buf_iw[i]); + } max_val = warp_reduce_max(max_val); } float tmp = 0.0f; // partial sum @@ -290,7 +288,8 @@ static void soft_max_f32_sycl(const float *x, const T *mask, cgh.parallel_for( sycl::nd_range<3>(block_nums * block_dims, block_dims), - [=](sycl::nd_item<3> item_ct1) { + [=](sycl::nd_item<3> item_ct1) + [[sycl::reqd_sub_group_size(WARP_SIZE)]] { soft_max_f32( x, mask, sinks, dst, params, dpct_local_acc_ct1