diff --git a/cpp/tensorrt_llm/common/customAllReduceUtils.h b/cpp/tensorrt_llm/common/customAllReduceUtils.h
index 5f35ff5f89..f69eb0cb25 100644
--- a/cpp/tensorrt_llm/common/customAllReduceUtils.h
+++ b/cpp/tensorrt_llm/common/customAllReduceUtils.h
@@ -115,7 +115,7 @@ inline AllReduceStrategyType selectStrategyLookUpTable(
         || num_token_index
             >= AllReduceBestStrategyTable.at(sm_version).at(tp_index).at(fusion_op_index).at(hidden_size_index).size())
     {
-        return AllReduceStrategyType::NCCL_SYMMETRIC;
+        return AllReduceStrategyType::NCCL;
     }
 
     return static_cast<AllReduceStrategyType>(
diff --git a/cpp/tensorrt_llm/thop/allreduceOp.cpp b/cpp/tensorrt_llm/thop/allreduceOp.cpp
index 00fe82c6fd..e9812b52c0 100644
--- a/cpp/tensorrt_llm/thop/allreduceOp.cpp
+++ b/cpp/tensorrt_llm/thop/allreduceOp.cpp
@@ -282,7 +282,7 @@ public:
     std::vector<torch::Tensor> run(torch::Tensor const& input, torch::optional<torch::Tensor> const& residual,
         torch::optional<torch::Tensor> const& norm_weight, torch::optional<torch::Tensor> const& scale,
         torch::optional<torch::Tensor> const& bias, bool trigger_completion_at_end,
-        torch::optional<torch::Tensor> workspace) noexcept
+        torch::optional<torch::Tensor> workspace)
     {
         size_t size = input.numel();
         size_t seq_len = input.size(0);
@@ -582,7 +582,7 @@ private:
 
     std::vector<torch::Tensor> runLowPrecisionAllReduce(torch::Tensor const& input,
         torch::optional<torch::Tensor> const& residual, torch::optional<torch::Tensor> const& norm_weight,
-        torch::optional<torch::Tensor> const& scale, torch::optional<torch::Tensor> const& bias) noexcept
+        torch::optional<torch::Tensor> const& scale, torch::optional<torch::Tensor> const& bias)
     {
 #ifdef ENABLE_FP8
         auto stream = at::cuda::getCurrentCUDAStream(input.get_device());
@@ -650,8 +650,7 @@ private:
     std::vector<torch::Tensor> runFusionAllReduce(torch::Tensor const& input,
         torch::optional<torch::Tensor> const& residual, torch::optional<torch::Tensor> const& norm_weight,
         torch::optional<torch::Tensor> const& scale, torch::optional<torch::Tensor> const& bias,
-        bool trigger_completion_at_end, torch::optional<torch::Tensor> workspace,
-        AllReduceStrategyType strategy) noexcept
+        bool trigger_completion_at_end, torch::optional<torch::Tensor> workspace, AllReduceStrategyType strategy)
     {
         // Should handle only Lamport implementation
         auto stream = at::cuda::getCurrentCUDAStream(input.get_device());
@@ -1224,7 +1223,7 @@ private:
 
         if (ifFallbackToNCCL(seq_len, message_size_bytes, max_workspace_size))
        {
-            return AllReduceStrategyType::NCCL_SYMMETRIC;
+            return AllReduceStrategyType::NCCL;
         }
 
         // This rule based heuristic only chooses between NCCL_SYMMETRIC and MIN_LATENCY strategies.
@@ -1250,7 +1249,8 @@ private:
 
     bool ifFallbackToNCCL(size_t seq_len, size_t message_size_bytes, size_t max_workspace_size)
     {
-        // If messageSize is greater than maxWorkspaceSize or topology is unsuitable, use NCCL_SYMMETRIC fallback.
+        // If messageSize is greater than maxWorkspaceSize or topology is unsuitable, use NCCL fallback.
+        // TODO: Use NCCL_SYMMETRIC once the memory allocation issue is resolved.
         if (message_size_bytes > max_workspace_size || !mIsP2PSupported || !mIsNVLINKSupported)
         {
             return true;
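
For reviewers, a minimal self-contained sketch of the selection flow this patch produces is shown below. Only the enum values, the `mIsP2PSupported`/`mIsNVLINKSupported` flags, and the fallback condition come from the diff; `AllreduceOpSketch`, the heuristic stub, and `main()` are hypothetical scaffolding, not the actual TensorRT-LLM implementation.

```cpp
// Hedged sketch of the fallback behavior after this patch.
#include <cstddef>
#include <iostream>

enum class AllReduceStrategyType
{
    NCCL,
    NCCL_SYMMETRIC,
    MIN_LATENCY
};

class AllreduceOpSketch
{
public:
    // Mirrors the diff's fallback test: fall back when the message does not
    // fit in the preallocated workspace or the topology lacks P2P/NVLink.
    // (seq_len is part of the real signature but unused in the shown check.)
    bool ifFallbackToNCCL(std::size_t /*seq_len*/, std::size_t message_size_bytes, std::size_t max_workspace_size) const
    {
        return message_size_bytes > max_workspace_size || !mIsP2PSupported || !mIsNVLINKSupported;
    }

    AllReduceStrategyType selectStrategy(
        std::size_t seq_len, std::size_t message_size_bytes, std::size_t max_workspace_size) const
    {
        if (ifFallbackToNCCL(seq_len, message_size_bytes, max_workspace_size))
        {
            // After this patch the fallback is plain NCCL rather than
            // NCCL_SYMMETRIC (see the TODO about the memory allocation issue).
            return AllReduceStrategyType::NCCL;
        }
        // Placeholder for the rule-based heuristic that chooses between
        // NCCL_SYMMETRIC and MIN_LATENCY in the real implementation.
        return AllReduceStrategyType::MIN_LATENCY;
    }

private:
    bool mIsP2PSupported = true;
    bool mIsNVLINKSupported = true;
};

int main()
{
    AllreduceOpSketch op;
    // A 16 MiB message with only a 1 MiB workspace triggers the NCCL fallback.
    auto s = op.selectStrategy(/*seq_len=*/128, /*message_size_bytes=*/16u << 20, /*max_workspace_size=*/1u << 20);
    std::cout << "fallback to NCCL: " << (s == AllReduceStrategyType::NCCL) << '\n';
}
```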