mirror of
https://github.com/NVIDIA/TensorRT-LLM.git
synced 2026-02-04 18:21:52 +08:00
[https://nvbugs/5636916][fix] Cherry-pick #10654: Fix accuracy issue of TWO-SHOT AllReduce kernel (#10841)
Signed-off-by: Yukun He <23156053+hyukn@users.noreply.github.com>
This commit is contained in:
parent
165dd360b9
commit
bf7303c7f1
@ -126,20 +126,20 @@ inline AllReduceStrategyType selectStrategyLookUpTable(
|
||||
inline AllReduceBestStrategyTableType AllReduceBestStrategyTableSM90
|
||||
= {{ // TP=2
|
||||
{// Fusion=NONE
|
||||
{4, 4, 4, 4, 4, 5, 5, 5, 5, 5, 4, 4, 0, 0, 0}, {4, 4, 4, 4, 4, 4, 5, 4, 4, 4, 4, 4, 4, 0, 0},
|
||||
{4, 4, 5, 4, 4, 5, 5, 5, 4, 5, 4, 4, 4, 0, 0}, {4, 4, 4, 4, 5, 4, 5, 4, 4, 4, 4, 4, 0, 0, 0},
|
||||
{4, 4, 4, 4, 4, 5, 5, 5, 5, 4, 4, 5, 0, 0, 0}, {4, 5, 4, 5, 4, 4, 4, 4, 4, 4, 4, 0, 0, 0, 0},
|
||||
{4, 4, 5, 5, 4, 4, 4, 4, 4, 4, 0, 0, 0, 0, 0}},
|
||||
{// Fusion=RESIDUAL_RMS_NORM
|
||||
{4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 0, 0, 0}, {4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 0, 0},
|
||||
{4, 4, 4, 4, 5, 5, 5, 5, 4, 4, 4, 4, 0, 0, 0}, {4, 4, 4, 4, 4, 5, 5, 4, 5, 4, 4, 4, 0, 0, 0},
|
||||
{4, 4, 4, 5, 4, 4, 4, 4, 4, 4, 4, 4, 0, 0, 0}, {4, 4, 4, 4, 5, 4, 4, 4, 4, 4, 4, 0, 0, 0, 0},
|
||||
{4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 0, 0, 0, 0, 0}},
|
||||
{4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 0, 0}, {4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 0, 0, 0},
|
||||
{4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 0, 0, 0}, {4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 0, 0, 0},
|
||||
{4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 0, 0, 0}},
|
||||
{// Fusion=RESIDUAL_RMS_NORM
|
||||
{4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 0, 0, 0, 0}, {4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 0, 0},
|
||||
{4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 0, 0}, {4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 0, 0, 0},
|
||||
{4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 0, 0, 0}, {4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 0, 0},
|
||||
{4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 0, 0}},
|
||||
{// Fusion=RESIDUAL_RMS_NORM_QUANT_FP8
|
||||
{4, 4, 4, 4, 4, 5, 5, 5, 4, 4, 4, 4, 0, 0, 0}, {4, 4, 4, 5, 4, 4, 5, 4, 4, 5, 4, 4, 4, 0, 0},
|
||||
{4, 4, 4, 4, 5, 5, 5, 5, 5, 4, 4, 4, 4, 0, 0}, {4, 4, 4, 5, 5, 5, 5, 4, 5, 4, 4, 4, 4, 0, 0},
|
||||
{4, 4, 4, 5, 5, 4, 4, 4, 5, 4, 4, 4, 0, 0, 0}, {4, 4, 5, 4, 4, 4, 4, 4, 4, 4, 4, 0, 0, 0, 0},
|
||||
{4, 4, 5, 4, 4, 4, 4, 4, 4, 4, 0, 0, 0, 0, 0}},
|
||||
{4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 0, 0, 0, 0}, {4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 0, 0},
|
||||
{4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 0, 0}, {4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 0, 0, 0},
|
||||
{4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4}, {4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 0},
|
||||
{4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 0, 0}},
|
||||
{// Fusion=RESIDUAL_RMS_NORM_QUANT_NVFP4
|
||||
{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0},
|
||||
{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0},
|
||||
@ -147,20 +147,20 @@ inline AllReduceBestStrategyTableType AllReduceBestStrategyTableSM90
|
||||
{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}}},
|
||||
{ // TP=4
|
||||
{// Fusion=NONE
|
||||
{4, 4, 4, 4, 5, 4, 4, 5, 4, 5, 4, 4, 0, 0, 0}, {4, 4, 4, 4, 4, 5, 5, 5, 5, 5, 4, 0, 0, 0, 0},
|
||||
{4, 4, 4, 4, 5, 4, 5, 5, 5, 4, 4, 5, 0, 0, 0}, {4, 4, 4, 4, 4, 5, 5, 4, 5, 4, 5, 5, 0, 0, 0},
|
||||
{4, 4, 4, 4, 4, 4, 4, 5, 4, 5, 5, 0, 0, 0, 0}, {4, 4, 4, 5, 5, 5, 5, 4, 5, 5, 0, 0, 0, 0, 0},
|
||||
{4, 4, 4, 4, 4, 5, 4, 5, 5, 0, 0, 0, 0, 0, 0}},
|
||||
{4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 0, 0, 0}, {4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 0, 0, 0},
|
||||
{4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 0, 0, 0}, {4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 5, 5, 0, 0, 0},
|
||||
{4, 4, 4, 4, 4, 4, 4, 4, 4, 5, 5, 5, 0, 0, 0}, {4, 4, 4, 4, 4, 4, 4, 4, 5, 5, 5, 0, 0, 0, 0},
|
||||
{4, 4, 4, 4, 4, 4, 4, 4, 5, 0, 0, 0, 0, 0, 0}},
|
||||
{// Fusion=RESIDUAL_RMS_NORM
|
||||
{4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 0, 0, 0}, {4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 0, 0, 0, 0},
|
||||
{4, 4, 4, 4, 4, 4, 5, 4, 4, 4, 4, 5, 0, 0, 0}, {4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 5, 5, 0, 0, 0},
|
||||
{4, 4, 4, 4, 4, 4, 4, 4, 4, 5, 5, 0, 0, 0, 0}, {4, 4, 4, 4, 4, 4, 4, 4, 4, 5, 0, 0, 0, 0, 0},
|
||||
{4, 4, 4, 4, 4, 4, 4, 4, 5, 0, 0, 0, 0, 0, 0}},
|
||||
{4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 0, 0, 0, 0}, {4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 5, 5, 0, 0, 0},
|
||||
{4, 4, 4, 4, 4, 4, 4, 4, 4, 5, 5, 0, 0, 0, 0}, {4, 4, 4, 4, 4, 4, 4, 4, 4, 5, 5, 0, 0, 0, 0},
|
||||
{4, 4, 4, 4, 4, 4, 4, 4, 0, 0, 0, 0, 0, 0, 0}},
|
||||
{// Fusion=RESIDUAL_RMS_NORM_QUANT_FP8
|
||||
{4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 0, 0, 0}, {4, 4, 4, 4, 4, 5, 5, 4, 5, 4, 4, 4, 0, 0, 0},
|
||||
{4, 4, 4, 4, 4, 5, 4, 5, 5, 4, 4, 5, 0, 0, 0}, {4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 5, 5, 0, 0, 0},
|
||||
{4, 4, 4, 4, 4, 4, 4, 4, 4, 5, 5, 0, 0, 0, 0}, {4, 4, 4, 4, 4, 4, 4, 4, 4, 5, 0, 0, 0, 0, 0},
|
||||
{4, 4, 4, 4, 4, 4, 4, 4, 5, 0, 0, 0, 0, 0, 0}},
|
||||
{4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 0, 0, 0}, {4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 0, 0, 0},
|
||||
{4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 0, 0, 0, 0}, {4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 5, 5, 0, 0, 0},
|
||||
{4, 4, 4, 4, 4, 4, 4, 4, 4, 5, 5, 5, 0, 0, 0}, {4, 4, 4, 4, 4, 4, 4, 4, 4, 5, 5, 5, 5, 0, 0},
|
||||
{4, 4, 4, 4, 4, 4, 4, 4, 5, 5, 5, 5, 5, 0, 0}},
|
||||
{// Fusion=RESIDUAL_RMS_NORM_QUANT_NVFP4
|
||||
{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0},
|
||||
{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0},
|
||||
@ -168,20 +168,20 @@ inline AllReduceBestStrategyTableType AllReduceBestStrategyTableSM90
|
||||
{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}}},
|
||||
{ // TP=8
|
||||
{// Fusion=NONE
|
||||
{4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 0, 0, 0, 0}, {4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 5, 5, 0, 0, 0},
|
||||
{4, 4, 4, 4, 4, 4, 4, 4, 4, 5, 5, 5, 5, 0, 0}, {4, 4, 4, 4, 4, 4, 4, 4, 4, 5, 5, 5, 0, 0, 0},
|
||||
{4, 4, 4, 4, 4, 4, 4, 5, 5, 5, 5, 0, 0, 0, 0}, {4, 4, 4, 4, 4, 4, 5, 5, 5, 5, 0, 0, 0, 0, 0},
|
||||
{4, 4, 4, 4, 4, 5, 5, 5, 5, 0, 0, 0, 0, 0, 0}},
|
||||
{4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 0, 0, 0, 0}, {4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 0, 0, 0, 0},
|
||||
{4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 5, 5, 0, 0, 0}, {4, 4, 4, 4, 4, 4, 4, 4, 4, 0, 5, 5, 5, 0, 0},
|
||||
{4, 4, 4, 4, 4, 4, 4, 4, 0, 5, 5, 5, 0, 0, 0}, {4, 4, 4, 4, 4, 4, 4, 0, 5, 5, 0, 0, 0, 0, 0},
|
||||
{4, 4, 4, 4, 4, 4, 0, 5, 5, 0, 0, 0, 0, 0, 0}},
|
||||
{// Fusion=RESIDUAL_RMS_NORM
|
||||
{4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 0, 0, 0, 0}, {4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 5, 0, 0, 0, 0},
|
||||
{4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 5, 5, 0, 0, 0}, {4, 4, 4, 4, 4, 4, 4, 4, 4, 5, 5, 5, 0, 0, 0},
|
||||
{4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 5, 5, 0, 0, 0}, {4, 4, 4, 4, 4, 4, 4, 4, 4, 0, 5, 5, 0, 0, 0},
|
||||
{4, 4, 4, 4, 4, 4, 4, 4, 0, 5, 5, 0, 0, 0, 0}, {4, 4, 4, 4, 4, 4, 4, 0, 0, 0, 0, 0, 0, 0, 0},
|
||||
{4, 4, 4, 4, 4, 4, 0, 0, 0, 0, 0, 0, 0, 0, 0}},
|
||||
{// Fusion=RESIDUAL_RMS_NORM_QUANT_FP8
|
||||
{4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 0, 0, 0, 0}, {4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 5, 5, 0, 0, 0},
|
||||
{4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 5, 5, 5, 0, 0}, {4, 4, 4, 4, 4, 4, 4, 4, 4, 5, 5, 5, 0, 0, 0},
|
||||
{4, 4, 4, 4, 4, 4, 4, 4, 4, 5, 5, 0, 0, 0, 0}, {4, 4, 4, 4, 4, 4, 4, 0, 5, 5, 0, 0, 0, 0, 0},
|
||||
{4, 4, 4, 4, 4, 4, 0, 0, 5, 0, 0, 0, 0, 0, 0}},
|
||||
{4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 0, 0, 0, 0}, {4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 0, 0, 0, 0},
|
||||
{4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 5, 5, 0, 0, 0}, {4, 4, 4, 4, 4, 4, 4, 4, 4, 0, 5, 5, 0, 0, 0},
|
||||
{4, 4, 4, 4, 4, 4, 4, 4, 0, 5, 5, 5, 0, 0, 0}, {4, 4, 4, 4, 4, 4, 4, 0, 0, 0, 5, 0, 0, 0, 0},
|
||||
{4, 4, 4, 4, 4, 4, 0, 0, 0, 0, 0, 0, 0, 0, 0}},
|
||||
{// Fusion=RESIDUAL_RMS_NORM_QUANT_NVFP4
|
||||
{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0},
|
||||
{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0},
|
||||
@ -191,67 +191,67 @@ inline AllReduceBestStrategyTableType AllReduceBestStrategyTableSM90
|
||||
inline AllReduceBestStrategyTableType AllReduceBestStrategyTableSM100
|
||||
= {{ // TP=2
|
||||
{// Fusion=NONE
|
||||
{4, 4, 4, 4, 4, 4, 4, 4, 5, 4, 4, 4, 0, 0, 0}, {4, 4, 4, 4, 5, 4, 4, 4, 4, 4, 4, 4, 0, 0, 0},
|
||||
{4, 4, 4, 5, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 0}, {4, 4, 4, 5, 4, 4, 4, 4, 4, 4, 4, 4, 4, 0, 0},
|
||||
{4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 0, 0, 0}, {4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 0, 0, 0, 0},
|
||||
{4, 4, 5, 4, 4, 4, 4, 4, 4, 4, 0, 0, 0, 0, 0}},
|
||||
{4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 0, 0, 0}, {4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 0},
|
||||
{4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 0}, {4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 0, 0},
|
||||
{4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 0, 0}, {4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 0, 0, 0},
|
||||
{4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 0, 0, 0}},
|
||||
{// Fusion=RESIDUAL_RMS_NORM
|
||||
{4, 4, 4, 4, 5, 4, 4, 4, 4, 4, 4, 4, 0, 0, 0}, {4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 0, 0, 0},
|
||||
{4, 4, 4, 4, 5, 4, 4, 4, 5, 4, 4, 4, 4, 0, 0}, {4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 0, 0},
|
||||
{4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 0, 0, 0}, {4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 0, 0, 0, 0},
|
||||
{4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 0, 0, 0, 0, 0}},
|
||||
{4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 0, 0, 0}, {4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 0, 0, 0},
|
||||
{4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 0, 0}, {4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 0, 0},
|
||||
{4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 0, 0, 0}, {4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 0, 0, 0},
|
||||
{4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 0, 0}},
|
||||
{// Fusion=RESIDUAL_RMS_NORM_QUANT_FP8
|
||||
{4, 4, 4, 4, 5, 4, 5, 5, 4, 4, 4, 4, 0, 0, 0}, {4, 4, 4, 4, 4, 4, 5, 4, 4, 4, 4, 4, 0, 0, 0},
|
||||
{4, 4, 4, 4, 5, 4, 4, 4, 4, 4, 4, 4, 4, 0, 0}, {4, 4, 4, 4, 4, 5, 5, 4, 4, 4, 4, 4, 4, 0, 0},
|
||||
{4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 0, 0, 0}, {4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 0, 0, 0, 0},
|
||||
{4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 0, 0, 0, 0, 0}},
|
||||
{4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 0, 0, 0}, {4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 0, 0, 0},
|
||||
{4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 0, 0}, {4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 0, 0},
|
||||
{4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 0, 0}, {4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 0},
|
||||
{4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 0, 0}},
|
||||
{// Fusion=RESIDUAL_RMS_NORM_QUANT_NVFP4
|
||||
{4, 4, 4, 5, 4, 4, 4, 4, 4, 4, 4, 4, 0, 0, 0}, {4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 0, 0, 0},
|
||||
{4, 4, 4, 4, 5, 5, 4, 4, 4, 4, 4, 4, 4, 0, 0}, {4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 0, 0},
|
||||
{4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 0, 0, 0}, {4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 0, 0, 0, 0},
|
||||
{4, 4, 5, 4, 4, 4, 4, 4, 4, 4, 0, 0, 0, 0, 0}}},
|
||||
{4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 0, 0, 0}, {4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 0, 0, 0},
|
||||
{4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 0, 0}, {4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 0, 0},
|
||||
{4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 0, 0}, {4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 0},
|
||||
{4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 0, 0}}},
|
||||
{ // TP=4
|
||||
{// Fusion=NONE
|
||||
{4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 0, 0, 0}, {4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 0, 0, 0},
|
||||
{4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 0, 0, 0}, {4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 0, 0, 0},
|
||||
{4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 5, 0, 0, 0, 0}, {4, 4, 4, 4, 4, 4, 4, 4, 4, 5, 0, 0, 0, 0, 0},
|
||||
{4, 4, 4, 4, 4, 4, 4, 4, 5, 0, 0, 0, 0, 0, 0}},
|
||||
{4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 0, 0, 0}, {4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 0, 0},
|
||||
{4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 0, 0}, {4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 0, 0, 0},
|
||||
{4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 0, 0, 0, 0}, {4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 5, 5, 0, 0, 0},
|
||||
{4, 4, 4, 4, 4, 4, 4, 4, 4, 5, 5, 5, 0, 0, 0}},
|
||||
{// Fusion=RESIDUAL_RMS_NORM
|
||||
{4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 0, 0, 0}, {4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 0, 0, 0},
|
||||
{4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 0, 0, 0}, {4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 5, 0, 0, 0},
|
||||
{4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 5, 0, 0, 0, 0}, {4, 4, 4, 4, 4, 4, 4, 4, 4, 5, 0, 0, 0, 0, 0},
|
||||
{4, 4, 4, 4, 4, 4, 4, 4, 4, 0, 0, 0, 0, 0, 0}},
|
||||
{4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 0, 0, 0}, {4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 0, 0, 0},
|
||||
{4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 5, 0, 0, 0}, {4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 5, 5, 0, 0, 0},
|
||||
{4, 4, 4, 4, 4, 4, 4, 4, 4, 5, 5, 5, 0, 0, 0}},
|
||||
{// Fusion=RESIDUAL_RMS_NORM_QUANT_FP8
|
||||
{4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 0, 0, 0}, {4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 0, 0, 0},
|
||||
{4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 0, 0, 0}, {4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 5, 0, 0, 0},
|
||||
{4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 5, 0, 0, 0, 0}, {4, 4, 4, 4, 4, 4, 4, 4, 4, 5, 0, 0, 0, 0, 0},
|
||||
{4, 4, 4, 4, 4, 4, 4, 4, 4, 0, 0, 0, 0, 0, 0}},
|
||||
{4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 0, 0, 0}, {4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 0, 0, 0},
|
||||
{4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 5, 5, 0, 0}, {4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 5, 5, 5, 5, 0},
|
||||
{4, 4, 4, 4, 4, 4, 4, 4, 4, 5, 5, 5, 5, 0, 0}},
|
||||
{// Fusion=RESIDUAL_RMS_NORM_QUANT_NVFP4
|
||||
{4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 0, 0, 0}, {4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 0, 0, 0},
|
||||
{4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 0, 0, 0}, {4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 0, 0, 0},
|
||||
{4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 0, 0, 0, 0}, {4, 4, 4, 4, 4, 4, 4, 4, 4, 5, 0, 0, 0, 0, 0},
|
||||
{4, 4, 4, 4, 4, 4, 4, 4, 4, 0, 0, 0, 0, 0, 0}}},
|
||||
{4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 5, 0, 0, 0}, {4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 5, 5, 5, 0, 0},
|
||||
{4, 4, 4, 4, 4, 4, 4, 4, 4, 5, 5, 5, 5, 0, 0}}},
|
||||
{ // TP=8
|
||||
{// Fusion=NONE
|
||||
{4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 0, 0, 0}, {4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 0, 0, 0},
|
||||
{4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 0, 0, 0, 0}, {4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 5, 5, 0, 0, 0},
|
||||
{4, 4, 4, 4, 4, 4, 4, 4, 4, 5, 5, 0, 0, 0, 0}, {4, 4, 4, 4, 4, 4, 4, 4, 5, 5, 0, 0, 0, 0, 0},
|
||||
{4, 4, 4, 4, 4, 4, 4, 5, 5, 0, 0, 0, 0, 0, 0}},
|
||||
{4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 0, 0, 0, 0}, {4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 5, 5, 0, 0},
|
||||
{4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 5, 5, 5, 5, 0}, {4, 4, 4, 4, 4, 4, 4, 4, 4, 5, 5, 5, 5, 0, 0},
|
||||
{4, 4, 4, 4, 4, 4, 4, 4, 5, 5, 5, 5, 5, 0, 0}},
|
||||
{// Fusion=RESIDUAL_RMS_NORM
|
||||
{4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 0, 0, 0}, {4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 0, 0, 0},
|
||||
{4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 0, 0, 0, 0}, {4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 5, 5, 0, 0, 0},
|
||||
{4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 5, 0, 0, 0, 0}, {4, 4, 4, 4, 4, 4, 4, 4, 4, 5, 0, 0, 0, 0, 0},
|
||||
{4, 4, 4, 4, 4, 4, 4, 4, 5, 0, 0, 0, 0, 0, 0}},
|
||||
{4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 0, 0, 0, 0}, {4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 5, 5, 0, 0},
|
||||
{4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 5, 5, 5, 0, 0}, {4, 4, 4, 4, 4, 4, 4, 4, 4, 5, 5, 5, 5, 0, 0},
|
||||
{4, 4, 4, 4, 4, 4, 4, 4, 5, 5, 5, 5, 5, 0, 0}},
|
||||
{// Fusion=RESIDUAL_RMS_NORM_QUANT_FP8
|
||||
{4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 0, 0, 0}, {4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 0, 0, 0},
|
||||
{4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 5, 5, 0, 0}, {4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 5, 5, 0, 0, 0},
|
||||
{4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 5, 0, 0, 0, 0}, {4, 4, 4, 4, 4, 4, 4, 4, 4, 5, 0, 0, 0, 0, 0},
|
||||
{4, 4, 4, 4, 4, 4, 4, 4, 5, 0, 0, 0, 0, 0, 0}},
|
||||
{4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 0, 0, 0}, {4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 5, 5, 0, 0},
|
||||
{4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 5, 5, 5, 5, 0}, {4, 4, 4, 4, 4, 4, 4, 4, 4, 5, 5, 5, 5, 5, 0},
|
||||
{4, 4, 4, 4, 4, 4, 4, 4, 5, 5, 5, 5, 5, 0, 0}},
|
||||
{// Fusion=RESIDUAL_RMS_NORM_QUANT_NVFP4
|
||||
{4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 0, 0, 0}, {4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 0, 0, 0},
|
||||
{4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 0, 0, 0, 0}, {4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 5, 5, 0, 0, 0},
|
||||
{4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 5, 0, 0, 0, 0}, {4, 4, 4, 4, 4, 4, 4, 4, 4, 5, 0, 0, 0, 0, 0},
|
||||
{4, 4, 4, 4, 4, 4, 4, 4, 5, 0, 0, 0, 0, 0, 0}}}};
|
||||
{4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 0, 0, 0, 0}, {4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 5, 5, 0, 0},
|
||||
{4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 5, 5, 5, 0, 0}, {4, 4, 4, 4, 4, 4, 4, 4, 4, 5, 5, 5, 5, 5, 0},
|
||||
{4, 4, 4, 4, 4, 4, 4, 4, 5, 5, 5, 5, 5, 0, 0}}}};
|
||||
|
||||
inline const std::unordered_map<int, AllReduceBestStrategyTableType> AllReduceBestStrategyTable = {
|
||||
{90, AllReduceBestStrategyTableSM90},
|
||||
|
||||
@ -137,17 +137,8 @@ public:
|
||||
// corresponding CTA has not been launched.
|
||||
for (int flag_idx = blockIdx.x; flag_idx < kBarrierFlagCount; flag_idx += gridDim.x)
|
||||
{
|
||||
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
|
||||
asm volatile(
|
||||
"st.global.relaxed.sys.b32 [%1], %0;" ::"r"(m_flag_value), "l"(m_target_flag + flag_idx * NRanks));
|
||||
#else
|
||||
st_flag(m_target_flag + flag_idx * NRanks, m_flag_value);
|
||||
#endif
|
||||
}
|
||||
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
|
||||
// Single release fence
|
||||
asm volatile("fence.release.sys;");
|
||||
#endif
|
||||
|
||||
while (ld_flag(m_current_flag) == prev_flag(m_flag_value))
|
||||
{
|
||||
|
||||
@ -141,6 +141,7 @@ def allreduce_benchmark(
|
||||
logger.set_rank(mapping.rank)
|
||||
|
||||
AutoTuner.get().setup_distributed_state(mapping)
|
||||
dist = Distributed.get(mapping)
|
||||
|
||||
sm_version = get_sm_version()
|
||||
|
||||
|
||||
@ -176,7 +176,7 @@ def main():
|
||||
parser.add_argument("--enable_auto", action="store_true", default=False)
|
||||
|
||||
args = parser.parse_args()
|
||||
tp_size_list = [2]
|
||||
tp_size_list = [2, 4, 8]
|
||||
|
||||
# Process the benchmark data
|
||||
# combine all the data into one dataframe
|
||||
|
||||
@ -282,15 +282,12 @@ def compare_logprobs(logprobs_list, ref_new_token_logprobs_list):
|
||||
def test_accuracy_with_allreduce_strategy(model_dir, sampler_type, allreduce_strategy):
|
||||
"""Test accuracy with different allreduce strategies.
|
||||
|
||||
The default allreduce_strategy (AUTO) produced wrong logprobs with large batch size,
|
||||
causing VeRL integration to fail to converge. There may be an issue with the
|
||||
customAllReduce kernels.
|
||||
|
||||
Tracked: NVBug (https://nvbugs/5727691)
|
||||
This test validates that both NCCL and AUTO allreduce strategies produce
|
||||
correct logprobs compared to HuggingFace reference implementation.
|
||||
|
||||
Expected behavior:
|
||||
- allreduce_strategy="NCCL": Accuracy assertion PASSES
|
||||
- allreduce_strategy="AUTO": Accuracy assertion FAILS
|
||||
- allreduce_strategy="AUTO": Accuracy assertion PASSES
|
||||
"""
|
||||
model_dir = str(llm_models_root() / model_dir)
|
||||
|
||||
@ -401,8 +398,4 @@ def test_accuracy_with_allreduce_strategy(model_dir, sampler_type, allreduce_str
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
# Compare LLM logprobs vs HF reference
|
||||
if allreduce_strategy == "AUTO":
|
||||
with pytest.raises(AssertionError, match=r"Final Min diff: .* is below threshold -2\.30"):
|
||||
compare_logprobs(llm_logprobs, ref_new_token_logprobs)
|
||||
else:
|
||||
compare_logprobs(llm_logprobs, ref_new_token_logprobs)
|
||||
compare_logprobs(llm_logprobs, ref_new_token_logprobs)
|
||||
|
||||
Loading…
Reference in New Issue
Block a user