perf: better heuristic for allreduce (#5432)

Signed-off-by: Yilin Zhang <18275976+yilin-void@users.noreply.github.com>
This commit is contained in:
yilin-void 2025-07-02 10:56:06 +08:00 committed by GitHub
parent 10c50515c2
commit 7992869798
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 21 additions and 18 deletions

View File

@@ -440,7 +440,7 @@ public:
};
template <AllReduceFusionPattern Pattern, typename DType, int NRanks, bool Fp32Acc, bool TriggerCompletionAtEnd = true>
__global__ void allreduce_fusion_kernel_oneshot_lamport(AllReduceFusionParams params)
__global__ void __launch_bounds__(1024) allreduce_fusion_kernel_oneshot_lamport(AllReduceFusionParams params)
{
IndexHelper<DType> index_helper(params);
int token_id = index_helper.token_id;
@@ -666,10 +666,16 @@ void allreduce_fusion_kernel_launcher(AllReduceFusionParams const& params)
threads_per_block *= 2;
cluster_size /= 2;
}
int sm_count = get_sm_count();
while (cluster_num * cluster_size > sm_count && cluster_size > 1 && threads_per_block <= 512)
{
threads_per_block *= 2;
cluster_size /= 2;
}
TLLM_CHECK(oneshot || threads_per_block >= params.nranks);
int block_size = threads_per_block;
TLLM_CHECK(block_size <= 1024 && cluster_size > 0);
int sm_count = get_sm_count();
int grid_size = (std::min(sm_count, cluster_num * cluster_size) / cluster_size) * cluster_size;
cudaLaunchConfig_t cfg;
cudaLaunchAttribute attribute[2];

View File

@@ -910,26 +910,23 @@ private:
{
strategy = AllReduceStrategyType::MIN_LATENCY;
}
else if (world_size <= 4)
{
if (message_size_bytes < 1 * 1000 * 1000)
{
strategy = AllReduceStrategyType::MIN_LATENCY;
}
else
{
strategy = AllReduceStrategyType::NCCL;
}
}
else
{
if (message_size_bytes < 500 * 1000)
static char* threshold_ptr = std::getenv("ALLREDUCE_AUTO_HEURISTIC_MIN_LATENCY_THRESHOLD_TOKEN_NUM");
size_t threshold = 128;
if (threshold_ptr)
{
strategy = AllReduceStrategyType::MIN_LATENCY;
threshold = static_cast<size_t>(std::atoi(threshold_ptr));
}
// Generally, NCCL is faster than MIN_LATENCY when the token number is greater than 256. I conservatively
// set the threshold here to 128 tokens.
if (seq_len > threshold)
{
strategy = AllReduceStrategyType::NCCL;
}
else
{
strategy = AllReduceStrategyType::NCCL;
strategy = AllReduceStrategyType::MIN_LATENCY;
}
}
return strategy;

View File

@@ -569,8 +569,8 @@ def getMultiGpuFileChanged(pipeline, testFilter, globalVars)
"cpp/tensorrt_llm/executor/executorImpl.cpp",
"cpp/tensorrt_llm/executor/executorImpl.h",
"cpp/tensorrt_llm/runtime/ncclCommunicator.cpp",
"cpp/tensorrt_llm/kernels/allReduceFusionKernels.h",
"cpp/tensorrt_llm/kernels/allReduceFusionKernels.cu",
"cpp/tensorrt_llm/kernels/communicationKernels/",
"cpp/tensorrt_llm/thop/allreduceOp.cpp",
"cpp/tensorrt_llm/kernels/customAllReduceKernels.h",
"cpp/tensorrt_llm/kernels/customAllReduceKernels.cu",
"cpp/tensorrt_llm/kernels/gptKernels.h",