mirror of
https://github.com/vllm-project/vllm.git
synced 2026-06-06 00:16:14 +00:00
[Compile] Fix compile warning with topk softplus sqrt (#41261)
Signed-off-by: yewentao256 <zhyanwentao@126.com>
This commit is contained in:
@@ -298,130 +298,131 @@ __launch_bounds__(WARPS_PER_CTA* WARP_SIZE_PARAM) __global__
|
||||
asm volatile("griddepcontrol.launch_dependents;");
|
||||
#endif
|
||||
return;
|
||||
}
|
||||
|
||||
} else {
|
||||
#pragma unroll
|
||||
for (int ii = 0; ii < VPT; ++ii) {
|
||||
float val = row_chunk[ii];
|
||||
float val_b = val * beta;
|
||||
// Compute softplus: log(1 + exp(val)) with numerical stability
|
||||
// When val > threshold, softplus(x) ≈ x to avoid exp overflow
|
||||
val = (val_b > threshold) ? val : (__logf(1.0f + __expf(val_b))) / beta;
|
||||
val = sqrtf(val);
|
||||
if (correction_bias) {
|
||||
const int group_id = ii / ELTS_PER_LDG;
|
||||
const int local_id = ii % ELTS_PER_LDG;
|
||||
const int expert_idx = first_elt_read_by_thread +
|
||||
group_id * THREADS_PER_ROW * ELTS_PER_LDG +
|
||||
local_id;
|
||||
val = val + correction_bias[expert_idx];
|
||||
for (int ii = 0; ii < VPT; ++ii) {
|
||||
float val = row_chunk[ii];
|
||||
float val_b = val * beta;
|
||||
// Compute softplus: log(1 + exp(val)) with numerical stability
|
||||
// When val > threshold, softplus(x) ≈ x to avoid exp overflow
|
||||
val = (val_b > threshold) ? val : (__logf(1.0f + __expf(val_b))) / beta;
|
||||
val = sqrtf(val);
|
||||
if (correction_bias) {
|
||||
const int group_id = ii / ELTS_PER_LDG;
|
||||
const int local_id = ii % ELTS_PER_LDG;
|
||||
const int expert_idx = first_elt_read_by_thread +
|
||||
group_id * THREADS_PER_ROW * ELTS_PER_LDG +
|
||||
local_id;
|
||||
val = val + correction_bias[expert_idx];
|
||||
}
|
||||
row_chunk[ii] = val;
|
||||
}
|
||||
row_chunk[ii] = val;
|
||||
}
|
||||
|
||||
// Original TopK path: find top-k experts by score
|
||||
// Now, sigmoid_res contains the sigmoid of the row chunk. Now, I want to find
|
||||
// the topk elements in each row, along with the max index.
|
||||
int start_col = first_elt_read_by_thread;
|
||||
static constexpr int COLS_PER_GROUP_LDG = ELTS_PER_LDG * THREADS_PER_ROW;
|
||||
// Original TopK path: find top-k experts by score
|
||||
// Now, sigmoid_res contains the sigmoid of the row chunk. Now, I want to
|
||||
// find the topk elements in each row, along with the max index.
|
||||
int start_col = first_elt_read_by_thread;
|
||||
static constexpr int COLS_PER_GROUP_LDG = ELTS_PER_LDG * THREADS_PER_ROW;
|
||||
|
||||
float selected_sum = 0.f;
|
||||
for (int k_idx = 0; k_idx < k; ++k_idx) {
|
||||
// First, each thread does the local argmax
|
||||
float max_val = row_chunk[0];
|
||||
int expert = start_col;
|
||||
float selected_sum = 0.f;
|
||||
for (int k_idx = 0; k_idx < k; ++k_idx) {
|
||||
// First, each thread does the local argmax
|
||||
float max_val = row_chunk[0];
|
||||
int expert = start_col;
|
||||
#pragma unroll
|
||||
for (int ldg = 0, col = start_col; ldg < LDG_PER_THREAD;
|
||||
++ldg, col += COLS_PER_GROUP_LDG) {
|
||||
for (int ldg = 0, col = start_col; ldg < LDG_PER_THREAD;
|
||||
++ldg, col += COLS_PER_GROUP_LDG) {
|
||||
#pragma unroll
|
||||
for (int ii = 0; ii < ELTS_PER_LDG; ++ii) {
|
||||
float val = row_chunk[ldg * ELTS_PER_LDG + ii];
|
||||
for (int ii = 0; ii < ELTS_PER_LDG; ++ii) {
|
||||
float val = row_chunk[ldg * ELTS_PER_LDG + ii];
|
||||
|
||||
// No check on the experts here since columns with the smallest index
|
||||
// are processed first and only updated if > (not >=)
|
||||
if (val > max_val) {
|
||||
max_val = val;
|
||||
expert = col + ii;
|
||||
// No check on the experts here since columns with the smallest index
|
||||
// are processed first and only updated if > (not >=)
|
||||
if (val > max_val) {
|
||||
max_val = val;
|
||||
expert = col + ii;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Now, we perform the argmax reduce. We use the butterfly pattern so threads
|
||||
// reach consensus about the max. This will be useful for K > 1 so that the
|
||||
// threads can agree on "who" had the max value. That thread can then blank out
|
||||
// their max with -inf and the warp can run more iterations...
|
||||
#pragma unroll
|
||||
for (int mask = THREADS_PER_ROW / 2; mask > 0; mask /= 2) {
|
||||
float other_max =
|
||||
VLLM_SHFL_XOR_SYNC_WIDTH(max_val, mask, THREADS_PER_ROW);
|
||||
int other_expert =
|
||||
VLLM_SHFL_XOR_SYNC_WIDTH(expert, mask, THREADS_PER_ROW);
|
||||
for (int mask = THREADS_PER_ROW / 2; mask > 0; mask /= 2) {
|
||||
float other_max =
|
||||
VLLM_SHFL_XOR_SYNC_WIDTH(max_val, mask, THREADS_PER_ROW);
|
||||
int other_expert =
|
||||
VLLM_SHFL_XOR_SYNC_WIDTH(expert, mask, THREADS_PER_ROW);
|
||||
|
||||
// We want lower indices to "win" in every thread so we break ties this
|
||||
// way
|
||||
if (other_max > max_val ||
|
||||
(other_max == max_val && other_expert < expert)) {
|
||||
max_val = other_max;
|
||||
expert = other_expert;
|
||||
// We want lower indices to "win" in every thread so we break ties this
|
||||
// way
|
||||
if (other_max > max_val ||
|
||||
(other_max == max_val && other_expert < expert)) {
|
||||
max_val = other_max;
|
||||
expert = other_expert;
|
||||
}
|
||||
}
|
||||
|
||||
// Write the max for this k iteration to global memory.
|
||||
if (thread_group_idx == 0) {
|
||||
// Add a guard to ignore experts not included by this node
|
||||
const bool node_uses_expert =
|
||||
expert >= start_expert && expert < end_expert;
|
||||
const bool should_process_row = row_is_active && node_uses_expert;
|
||||
|
||||
// The lead thread from each sub-group will write out the final results
|
||||
// to global memory. (This will be a single) thread per row of the
|
||||
// input/output matrices.
|
||||
const int idx = k * thread_row + k_idx;
|
||||
if (correction_bias != nullptr) {
|
||||
max_val -= correction_bias[expert];
|
||||
}
|
||||
output[idx] = max_val;
|
||||
indices[idx] =
|
||||
should_process_row ? (expert - start_expert) : NUM_EXPERTS;
|
||||
source_rows[idx] = k_idx * num_rows + thread_row;
|
||||
if (renormalize) {
|
||||
selected_sum += max_val;
|
||||
}
|
||||
}
|
||||
|
||||
// Finally, we clear the value in the thread with the current max if there
|
||||
// is another iteration to run.
|
||||
if (k_idx + 1 < k) {
|
||||
const int ldg_group_for_expert = expert / COLS_PER_GROUP_LDG;
|
||||
const int thread_to_clear_in_group =
|
||||
(expert / ELTS_PER_LDG) % THREADS_PER_ROW;
|
||||
|
||||
// Only the thread in the group which produced the max will reset the
|
||||
// "winning" value to -inf.
|
||||
if (thread_group_idx == thread_to_clear_in_group) {
|
||||
const int offset_for_expert = expert % ELTS_PER_LDG;
|
||||
// Safe to set to any negative value since row_chunk values must be
|
||||
// between 0 and 1.
|
||||
row_chunk[ldg_group_for_expert * ELTS_PER_LDG + offset_for_expert] =
|
||||
-10000.f;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Write the max for this k iteration to global memory.
|
||||
// Apply renormalization and routed scaling factor to final weights.
|
||||
if (thread_group_idx == 0) {
|
||||
// Add a guard to ignore experts not included by this node
|
||||
const bool node_uses_expert =
|
||||
expert >= start_expert && expert < end_expert;
|
||||
const bool should_process_row = row_is_active && node_uses_expert;
|
||||
|
||||
// The lead thread from each sub-group will write out the final results to
|
||||
// global memory. (This will be a single) thread per row of the
|
||||
// input/output matrices.
|
||||
const int idx = k * thread_row + k_idx;
|
||||
if (correction_bias != nullptr) {
|
||||
max_val -= correction_bias[expert];
|
||||
}
|
||||
output[idx] = max_val;
|
||||
indices[idx] = should_process_row ? (expert - start_expert) : NUM_EXPERTS;
|
||||
source_rows[idx] = k_idx * num_rows + thread_row;
|
||||
float scale = static_cast<float>(routed_scaling_factor);
|
||||
if (renormalize) {
|
||||
selected_sum += max_val;
|
||||
const float denom = selected_sum > 0.f ? selected_sum : 1.f;
|
||||
scale /= denom;
|
||||
}
|
||||
for (int k_idx = 0; k_idx < k; ++k_idx) {
|
||||
const int idx = k * thread_row + k_idx;
|
||||
output[idx] = output[idx] * scale;
|
||||
}
|
||||
}
|
||||
|
||||
// Finally, we clear the value in the thread with the current max if there
|
||||
// is another iteration to run.
|
||||
if (k_idx + 1 < k) {
|
||||
const int ldg_group_for_expert = expert / COLS_PER_GROUP_LDG;
|
||||
const int thread_to_clear_in_group =
|
||||
(expert / ELTS_PER_LDG) % THREADS_PER_ROW;
|
||||
|
||||
// Only the thread in the group which produced the max will reset the
|
||||
// "winning" value to -inf.
|
||||
if (thread_group_idx == thread_to_clear_in_group) {
|
||||
const int offset_for_expert = expert % ELTS_PER_LDG;
|
||||
// Safe to set to any negative value since row_chunk values must be
|
||||
// between 0 and 1.
|
||||
row_chunk[ldg_group_for_expert * ELTS_PER_LDG + offset_for_expert] =
|
||||
-10000.f;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Apply renormalization and routed scaling factor to final weights.
|
||||
if (thread_group_idx == 0) {
|
||||
float scale = static_cast<float>(routed_scaling_factor);
|
||||
if (renormalize) {
|
||||
const float denom = selected_sum > 0.f ? selected_sum : 1.f;
|
||||
scale /= denom;
|
||||
}
|
||||
for (int k_idx = 0; k_idx < k; ++k_idx) {
|
||||
const int idx = k * thread_row + k_idx;
|
||||
output[idx] = output[idx] * scale;
|
||||
}
|
||||
}
|
||||
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
|
||||
asm volatile("griddepcontrol.launch_dependents;");
|
||||
asm volatile("griddepcontrol.launch_dependents;");
|
||||
#endif
|
||||
}
|
||||
}
|
||||
|
||||
namespace detail {
|
||||
|
||||
Reference in New Issue
Block a user