diff --git a/csrc/moe/topk_softplus_sqrt_kernels.cu b/csrc/moe/topk_softplus_sqrt_kernels.cu index 43d461a0179..d5bb8edadc6 100644 --- a/csrc/moe/topk_softplus_sqrt_kernels.cu +++ b/csrc/moe/topk_softplus_sqrt_kernels.cu @@ -298,130 +298,131 @@ __launch_bounds__(WARPS_PER_CTA* WARP_SIZE_PARAM) __global__ asm volatile("griddepcontrol.launch_dependents;"); #endif return; - } - + } else { #pragma unroll - for (int ii = 0; ii < VPT; ++ii) { - float val = row_chunk[ii]; - float val_b = val * beta; - // Compute softplus: log(1 + exp(val)) with numerical stability - // When val > threshold, softplus(x) ≈ x to avoid exp overflow - val = (val_b > threshold) ? val : (__logf(1.0f + __expf(val_b))) / beta; - val = sqrtf(val); - if (correction_bias) { - const int group_id = ii / ELTS_PER_LDG; - const int local_id = ii % ELTS_PER_LDG; - const int expert_idx = first_elt_read_by_thread + - group_id * THREADS_PER_ROW * ELTS_PER_LDG + - local_id; - val = val + correction_bias[expert_idx]; + for (int ii = 0; ii < VPT; ++ii) { + float val = row_chunk[ii]; + float val_b = val * beta; + // Compute softplus: log(1 + exp(val)) with numerical stability + // When val > threshold, softplus(x) ≈ x to avoid exp overflow + val = (val_b > threshold) ? val : (__logf(1.0f + __expf(val_b))) / beta; + val = sqrtf(val); + if (correction_bias) { + const int group_id = ii / ELTS_PER_LDG; + const int local_id = ii % ELTS_PER_LDG; + const int expert_idx = first_elt_read_by_thread + + group_id * THREADS_PER_ROW * ELTS_PER_LDG + + local_id; + val = val + correction_bias[expert_idx]; + } + row_chunk[ii] = val; } - row_chunk[ii] = val; - } - // Original TopK path: find top-k experts by score - // Now, sigmoid_res contains the sigmoid of the row chunk. Now, I want to find - // the topk elements in each row, along with the max index. - int start_col = first_elt_read_by_thread; - static constexpr int COLS_PER_GROUP_LDG = ELTS_PER_LDG * THREADS_PER_ROW; + // Original TopK path: find top-k experts by score + // Now, sigmoid_res contains the sigmoid of the row chunk. Now, I want to + // find the topk elements in each row, along with the max index. + int start_col = first_elt_read_by_thread; + static constexpr int COLS_PER_GROUP_LDG = ELTS_PER_LDG * THREADS_PER_ROW; - float selected_sum = 0.f; - for (int k_idx = 0; k_idx < k; ++k_idx) { - // First, each thread does the local argmax - float max_val = row_chunk[0]; - int expert = start_col; + float selected_sum = 0.f; + for (int k_idx = 0; k_idx < k; ++k_idx) { + // First, each thread does the local argmax + float max_val = row_chunk[0]; + int expert = start_col; #pragma unroll - for (int ldg = 0, col = start_col; ldg < LDG_PER_THREAD; - ++ldg, col += COLS_PER_GROUP_LDG) { + for (int ldg = 0, col = start_col; ldg < LDG_PER_THREAD; + ++ldg, col += COLS_PER_GROUP_LDG) { #pragma unroll - for (int ii = 0; ii < ELTS_PER_LDG; ++ii) { - float val = row_chunk[ldg * ELTS_PER_LDG + ii]; + for (int ii = 0; ii < ELTS_PER_LDG; ++ii) { + float val = row_chunk[ldg * ELTS_PER_LDG + ii]; - // No check on the experts here since columns with the smallest index - // are processed first and only updated if > (not >=) - if (val > max_val) { - max_val = val; - expert = col + ii; + // No check on the experts here since columns with the smallest index + // are processed first and only updated if > (not >=) + if (val > max_val) { + max_val = val; + expert = col + ii; + } } } - } // Now, we perform the argmax reduce. We use the butterfly pattern so threads // reach consensus about the max. This will be useful for K > 1 so that the // threads can agree on "who" had the max value. That thread can then blank out // their max with -inf and the warp can run more iterations... #pragma unroll - for (int mask = THREADS_PER_ROW / 2; mask > 0; mask /= 2) { - float other_max = - VLLM_SHFL_XOR_SYNC_WIDTH(max_val, mask, THREADS_PER_ROW); - int other_expert = - VLLM_SHFL_XOR_SYNC_WIDTH(expert, mask, THREADS_PER_ROW); + for (int mask = THREADS_PER_ROW / 2; mask > 0; mask /= 2) { + float other_max = + VLLM_SHFL_XOR_SYNC_WIDTH(max_val, mask, THREADS_PER_ROW); + int other_expert = + VLLM_SHFL_XOR_SYNC_WIDTH(expert, mask, THREADS_PER_ROW); - // We want lower indices to "win" in every thread so we break ties this - // way - if (other_max > max_val || - (other_max == max_val && other_expert < expert)) { - max_val = other_max; - expert = other_expert; + // We want lower indices to "win" in every thread so we break ties this + // way + if (other_max > max_val || + (other_max == max_val && other_expert < expert)) { + max_val = other_max; + expert = other_expert; + } + } + + // Write the max for this k iteration to global memory. + if (thread_group_idx == 0) { + // Add a guard to ignore experts not included by this node + const bool node_uses_expert = + expert >= start_expert && expert < end_expert; + const bool should_process_row = row_is_active && node_uses_expert; + + // The lead thread from each sub-group will write out the final results + // to global memory. (This will be a single) thread per row of the + // input/output matrices. + const int idx = k * thread_row + k_idx; + if (correction_bias != nullptr) { + max_val -= correction_bias[expert]; + } + output[idx] = max_val; + indices[idx] = + should_process_row ? (expert - start_expert) : NUM_EXPERTS; + source_rows[idx] = k_idx * num_rows + thread_row; + if (renormalize) { + selected_sum += max_val; + } + } + + // Finally, we clear the value in the thread with the current max if there + // is another iteration to run. + if (k_idx + 1 < k) { + const int ldg_group_for_expert = expert / COLS_PER_GROUP_LDG; + const int thread_to_clear_in_group = + (expert / ELTS_PER_LDG) % THREADS_PER_ROW; + + // Only the thread in the group which produced the max will reset the + // "winning" value to -inf. + if (thread_group_idx == thread_to_clear_in_group) { + const int offset_for_expert = expert % ELTS_PER_LDG; + // Safe to set to any negative value since row_chunk values must be + // between 0 and 1. + row_chunk[ldg_group_for_expert * ELTS_PER_LDG + offset_for_expert] = + -10000.f; + } } } - // Write the max for this k iteration to global memory. + // Apply renormalization and routed scaling factor to final weights. if (thread_group_idx == 0) { - // Add a guard to ignore experts not included by this node - const bool node_uses_expert = - expert >= start_expert && expert < end_expert; - const bool should_process_row = row_is_active && node_uses_expert; - - // The lead thread from each sub-group will write out the final results to - // global memory. (This will be a single) thread per row of the - // input/output matrices. - const int idx = k * thread_row + k_idx; - if (correction_bias != nullptr) { - max_val -= correction_bias[expert]; - } - output[idx] = max_val; - indices[idx] = should_process_row ? (expert - start_expert) : NUM_EXPERTS; - source_rows[idx] = k_idx * num_rows + thread_row; + float scale = static_cast(routed_scaling_factor); if (renormalize) { - selected_sum += max_val; + const float denom = selected_sum > 0.f ? selected_sum : 1.f; + scale /= denom; + } + for (int k_idx = 0; k_idx < k; ++k_idx) { + const int idx = k * thread_row + k_idx; + output[idx] = output[idx] * scale; } } - - // Finally, we clear the value in the thread with the current max if there - // is another iteration to run. - if (k_idx + 1 < k) { - const int ldg_group_for_expert = expert / COLS_PER_GROUP_LDG; - const int thread_to_clear_in_group = - (expert / ELTS_PER_LDG) % THREADS_PER_ROW; - - // Only the thread in the group which produced the max will reset the - // "winning" value to -inf. - if (thread_group_idx == thread_to_clear_in_group) { - const int offset_for_expert = expert % ELTS_PER_LDG; - // Safe to set to any negative value since row_chunk values must be - // between 0 and 1. - row_chunk[ldg_group_for_expert * ELTS_PER_LDG + offset_for_expert] = - -10000.f; - } - } - } - - // Apply renormalization and routed scaling factor to final weights. - if (thread_group_idx == 0) { - float scale = static_cast(routed_scaling_factor); - if (renormalize) { - const float denom = selected_sum > 0.f ? selected_sum : 1.f; - scale /= denom; - } - for (int k_idx = 0; k_idx < k; ++k_idx) { - const int idx = k * thread_row + k_idx; - output[idx] = output[idx] * scale; - } - } #if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) - asm volatile("griddepcontrol.launch_dependents;"); + asm volatile("griddepcontrol.launch_dependents;"); #endif + } } namespace detail {