[None][fix] Add a timeout in MNNVL throughput to prevent hangs if one rank crashes (#9532)

Signed-off-by: djns99 <40156487+djns99@users.noreply.github.com>
Signed-off-by: Bo Li <22713281+bobboli@users.noreply.github.com>
Co-authored-by: Bo Li <22713281+bobboli@users.noreply.github.com>
This commit is contained in:
Daniel Stokes 2026-01-21 15:14:39 +13:00 committed by GitHub
parent 3c39b1faa9
commit 2f3b2a3172
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -32,6 +32,10 @@ namespace kernels::moe_comm
#define ENABLE_DEBUG_PRINT 0
#define DISABLE_SYNC_FOR_PROFILING 0
// Build-time switch for the completion-flag wait timeout. Defaults to 0
// unless the build already supplies its own DISABLE_TIMEOUT definition.
#if !defined(DISABLE_TIMEOUT)
#define DISABLE_TIMEOUT 0
#endif // !defined(DISABLE_TIMEOUT)
// Macros for concise launch-time specialization
#define SWITCH_BOOL(flag, NAME, ...) \
if (flag) \
@ -141,6 +145,13 @@ namespace kernels::moe_comm
__VA_ARGS__ \
}
#if DISABLE_TIMEOUT
#define check_timeout(s) false
#else
// 300 * 2000 MHz - should be high enough on any GPU but will prevent a hang
#define check_timeout(s) ((clock64() - (s)) > (300ll * 2000ll * 1000ll * 1000ll))
#endif
// ============================================================================
// Helper Functions for Expert-to-Rank Mapping
// ============================================================================
@ -515,6 +526,7 @@ __global__ void moeA2ADispatchKernel(int32_t const* token_selected_experts, // [
for (int peer_rank = lane_id; peer_rank < ep_size; peer_rank += warpSize)
{
bool flag_set = false;
auto s = clock64();
do
{
uint32_t* flag_ptr = &ptrs.completion_flags[rank_id][peer_rank];
@ -528,7 +540,15 @@ __global__ void moeA2ADispatchKernel(int32_t const* token_selected_experts, // [
rank_id, peer_rank, flag_value, expected_value, flag_ptr);
#endif
flag_set = flag_value == expected_value;
} while (!flag_set);
} while (!flag_set && !check_timeout(s));
if (__builtin_expect(!flag_set, 0))
{
printf("dispatch: ---Rank %d timed out waiting for completion flag from rank %d\n", rank_id,
peer_rank);
asm volatile("trap;");
return;
}
}
#endif
}
@ -1038,6 +1058,7 @@ __global__ void moeA2ACombineKernel(
for (int peer_rank = lane_id; peer_rank < ep_size; peer_rank += warpSize)
{
bool flag_set = false;
auto s = clock64();
do
{
uint32_t* flag_ptr = &ptrs.completion_flags[rank_id][peer_rank];
@ -1046,12 +1067,20 @@ __global__ void moeA2ACombineKernel(
asm volatile("ld.relaxed.sys.u32 %0, [%1];" : "=r"(flag_value) : "l"(flag_ptr));
#if ENABLE_DEBUG_PRINT
printf(
"combine: ---Rank %d received completion flag from rank %d, flag_value: %d, expected_value: %d, "
"combine: ---Rank %d received completion flag from rank %d, flag_value: %d, expected_value: "
"%d, "
"address: %p\n",
rank_id, peer_rank, flag_value, expected_value, flag_ptr);
#endif
flag_set = flag_value == expected_value;
} while (!flag_set);
} while (!flag_set && !check_timeout(s));
if (__builtin_expect(!flag_set, 0))
{
printf("combine: ---Rank %d timed out waiting for completion flag from rank %d\n", rank_id, peer_rank);
asm volatile("trap;");
return;
}
}
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900
// .acquire and .release qualifiers for fence instruction require sm_90 or higher.