[Compile] Fix compile warning with topk softplus sqrt (#41261)

Signed-off-by: yewentao256 <zhyanwentao@126.com>
This commit is contained in:
Wentao Ye
2026-05-14 08:12:50 -04:00
committed by GitHub
parent 0a65d46628
commit 6548560496
+100 -99
View File
@@ -298,130 +298,131 @@ __launch_bounds__(WARPS_PER_CTA* WARP_SIZE_PARAM) __global__
asm volatile("griddepcontrol.launch_dependents;");
#endif
return;
}
} else {
#pragma unroll
for (int ii = 0; ii < VPT; ++ii) {
float val = row_chunk[ii];
float val_b = val * beta;
// Compute softplus: log(1 + exp(val)) with numerical stability
// When val > threshold, softplus(x) ≈ x to avoid exp overflow
val = (val_b > threshold) ? val : (__logf(1.0f + __expf(val_b))) / beta;
val = sqrtf(val);
if (correction_bias) {
const int group_id = ii / ELTS_PER_LDG;
const int local_id = ii % ELTS_PER_LDG;
const int expert_idx = first_elt_read_by_thread +
group_id * THREADS_PER_ROW * ELTS_PER_LDG +
local_id;
val = val + correction_bias[expert_idx];
for (int ii = 0; ii < VPT; ++ii) {
float val = row_chunk[ii];
float val_b = val * beta;
// Compute softplus: log(1 + exp(val)) with numerical stability
// When val > threshold, softplus(x) ≈ x to avoid exp overflow
val = (val_b > threshold) ? val : (__logf(1.0f + __expf(val_b))) / beta;
val = sqrtf(val);
if (correction_bias) {
const int group_id = ii / ELTS_PER_LDG;
const int local_id = ii % ELTS_PER_LDG;
const int expert_idx = first_elt_read_by_thread +
group_id * THREADS_PER_ROW * ELTS_PER_LDG +
local_id;
val = val + correction_bias[expert_idx];
}
row_chunk[ii] = val;
}
row_chunk[ii] = val;
}
// Original TopK path: find top-k experts by score
// Now, sigmoid_res contains the sigmoid of the row chunk. Now, I want to find
// the topk elements in each row, along with the max index.
int start_col = first_elt_read_by_thread;
static constexpr int COLS_PER_GROUP_LDG = ELTS_PER_LDG * THREADS_PER_ROW;
// Original TopK path: find top-k experts by score
// Now, sigmoid_res contains the sigmoid of the row chunk. Now, I want to
// find the topk elements in each row, along with the max index.
int start_col = first_elt_read_by_thread;
static constexpr int COLS_PER_GROUP_LDG = ELTS_PER_LDG * THREADS_PER_ROW;
float selected_sum = 0.f;
for (int k_idx = 0; k_idx < k; ++k_idx) {
// First, each thread does the local argmax
float max_val = row_chunk[0];
int expert = start_col;
float selected_sum = 0.f;
for (int k_idx = 0; k_idx < k; ++k_idx) {
// First, each thread does the local argmax
float max_val = row_chunk[0];
int expert = start_col;
#pragma unroll
for (int ldg = 0, col = start_col; ldg < LDG_PER_THREAD;
++ldg, col += COLS_PER_GROUP_LDG) {
for (int ldg = 0, col = start_col; ldg < LDG_PER_THREAD;
++ldg, col += COLS_PER_GROUP_LDG) {
#pragma unroll
for (int ii = 0; ii < ELTS_PER_LDG; ++ii) {
float val = row_chunk[ldg * ELTS_PER_LDG + ii];
for (int ii = 0; ii < ELTS_PER_LDG; ++ii) {
float val = row_chunk[ldg * ELTS_PER_LDG + ii];
// No check on the experts here since columns with the smallest index
// are processed first and only updated if > (not >=)
if (val > max_val) {
max_val = val;
expert = col + ii;
// No check on the experts here since columns with the smallest index
// are processed first and only updated if > (not >=)
if (val > max_val) {
max_val = val;
expert = col + ii;
}
}
}
}
// Now, we perform the argmax reduce. We use the butterfly pattern so threads
// reach consensus about the max. This will be useful for K > 1 so that the
// threads can agree on "who" had the max value. That thread can then blank out
// their max with -inf and the warp can run more iterations...
#pragma unroll
for (int mask = THREADS_PER_ROW / 2; mask > 0; mask /= 2) {
float other_max =
VLLM_SHFL_XOR_SYNC_WIDTH(max_val, mask, THREADS_PER_ROW);
int other_expert =
VLLM_SHFL_XOR_SYNC_WIDTH(expert, mask, THREADS_PER_ROW);
for (int mask = THREADS_PER_ROW / 2; mask > 0; mask /= 2) {
float other_max =
VLLM_SHFL_XOR_SYNC_WIDTH(max_val, mask, THREADS_PER_ROW);
int other_expert =
VLLM_SHFL_XOR_SYNC_WIDTH(expert, mask, THREADS_PER_ROW);
// We want lower indices to "win" in every thread so we break ties this
// way
if (other_max > max_val ||
(other_max == max_val && other_expert < expert)) {
max_val = other_max;
expert = other_expert;
// We want lower indices to "win" in every thread so we break ties this
// way
if (other_max > max_val ||
(other_max == max_val && other_expert < expert)) {
max_val = other_max;
expert = other_expert;
}
}
// Write the max for this k iteration to global memory.
if (thread_group_idx == 0) {
// Add a guard to ignore experts not included by this node
const bool node_uses_expert =
expert >= start_expert && expert < end_expert;
const bool should_process_row = row_is_active && node_uses_expert;
// The lead thread from each sub-group will write out the final results
// to global memory. (This will be a single) thread per row of the
// input/output matrices.
const int idx = k * thread_row + k_idx;
if (correction_bias != nullptr) {
max_val -= correction_bias[expert];
}
output[idx] = max_val;
indices[idx] =
should_process_row ? (expert - start_expert) : NUM_EXPERTS;
source_rows[idx] = k_idx * num_rows + thread_row;
if (renormalize) {
selected_sum += max_val;
}
}
// Finally, we clear the value in the thread with the current max if there
// is another iteration to run.
if (k_idx + 1 < k) {
const int ldg_group_for_expert = expert / COLS_PER_GROUP_LDG;
const int thread_to_clear_in_group =
(expert / ELTS_PER_LDG) % THREADS_PER_ROW;
// Only the thread in the group which produced the max will reset the
// "winning" value to -inf.
if (thread_group_idx == thread_to_clear_in_group) {
const int offset_for_expert = expert % ELTS_PER_LDG;
// Safe to set to any negative value since row_chunk values must be
// between 0 and 1.
row_chunk[ldg_group_for_expert * ELTS_PER_LDG + offset_for_expert] =
-10000.f;
}
}
}
// Write the max for this k iteration to global memory.
// Apply renormalization and routed scaling factor to final weights.
if (thread_group_idx == 0) {
// Add a guard to ignore experts not included by this node
const bool node_uses_expert =
expert >= start_expert && expert < end_expert;
const bool should_process_row = row_is_active && node_uses_expert;
// The lead thread from each sub-group will write out the final results to
// global memory. (This will be a single) thread per row of the
// input/output matrices.
const int idx = k * thread_row + k_idx;
if (correction_bias != nullptr) {
max_val -= correction_bias[expert];
}
output[idx] = max_val;
indices[idx] = should_process_row ? (expert - start_expert) : NUM_EXPERTS;
source_rows[idx] = k_idx * num_rows + thread_row;
float scale = static_cast<float>(routed_scaling_factor);
if (renormalize) {
selected_sum += max_val;
const float denom = selected_sum > 0.f ? selected_sum : 1.f;
scale /= denom;
}
for (int k_idx = 0; k_idx < k; ++k_idx) {
const int idx = k * thread_row + k_idx;
output[idx] = output[idx] * scale;
}
}
// Finally, we clear the value in the thread with the current max if there
// is another iteration to run.
if (k_idx + 1 < k) {
const int ldg_group_for_expert = expert / COLS_PER_GROUP_LDG;
const int thread_to_clear_in_group =
(expert / ELTS_PER_LDG) % THREADS_PER_ROW;
// Only the thread in the group which produced the max will reset the
// "winning" value to -inf.
if (thread_group_idx == thread_to_clear_in_group) {
const int offset_for_expert = expert % ELTS_PER_LDG;
// Safe to set to any negative value since row_chunk values must be
// between 0 and 1.
row_chunk[ldg_group_for_expert * ELTS_PER_LDG + offset_for_expert] =
-10000.f;
}
}
}
// Apply renormalization and routed scaling factor to final weights.
if (thread_group_idx == 0) {
float scale = static_cast<float>(routed_scaling_factor);
if (renormalize) {
const float denom = selected_sum > 0.f ? selected_sum : 1.f;
scale /= denom;
}
for (int k_idx = 0; k_idx < k; ++k_idx) {
const int idx = k * thread_row + k_idx;
output[idx] = output[idx] * scale;
}
}
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
asm volatile("griddepcontrol.launch_dependents;");
asm volatile("griddepcontrol.launch_dependents;");
#endif
}
}
namespace detail {