From bb7bcc75c23eb9cb40f6b44fd97407cbad55964b Mon Sep 17 00:00:00 2001 From: Yukun He <23156053+hyukn@users.noreply.github.com> Date: Fri, 9 May 2025 02:13:13 +0800 Subject: [PATCH] feat: Fallback to NCCL for various patterns when input size is large. (#4080) * Fallback to NCCL for various patterns when input size is large. Move the previous implementation to cpp side. Signed-off-by: Yukun He <23156053+hyukn@users.noreply.github.com> * Revising. Signed-off-by: Yukun He <23156053+hyukn@users.noreply.github.com> --------- Signed-off-by: Yukun He <23156053+hyukn@users.noreply.github.com> --- cpp/tensorrt_llm/thop/allreduceOp.cpp | 391 +++++++++++++++---------- cpp/tensorrt_llm/thop/fp4Quantize.cpp | 1 + cpp/tensorrt_llm/thop/fp4Quantize.h | 28 ++ cpp/tensorrt_llm/thop/fp8Op.cpp | 1 + cpp/tensorrt_llm/thop/fp8Op.h | 44 +++ tensorrt_llm/_torch/distributed/ops.py | 75 +---- 6 files changed, 306 insertions(+), 234 deletions(-) create mode 100644 cpp/tensorrt_llm/thop/fp4Quantize.h create mode 100644 cpp/tensorrt_llm/thop/fp8Op.h diff --git a/cpp/tensorrt_llm/thop/allreduceOp.cpp b/cpp/tensorrt_llm/thop/allreduceOp.cpp index 8b6abffd66..b09f9f087e 100644 --- a/cpp/tensorrt_llm/thop/allreduceOp.cpp +++ b/cpp/tensorrt_llm/thop/allreduceOp.cpp @@ -27,9 +27,11 @@ #include "tensorrt_llm/kernels/userbuffers/ub_interface.h" #include "tensorrt_llm/runtime/torchUtils.h" #include "tensorrt_llm/runtime/utils/mpiUtils.h" +#include "tensorrt_llm/thop/fp4Quantize.h" +#include "tensorrt_llm/thop/fp8Op.h" #include "tensorrt_llm/thop/thUtils.h" +#include "tensorrt_llm/thop/userbuffersTensor.h" -#include "userbuffersTensor.h" #if ENABLE_MULTI_DEVICE #include #include @@ -160,21 +162,21 @@ public: // If strategy is set to UB, UB must be used as UB impl output is special and cannot be used // by others. 
- AllReduceStrategyType runtimeStrategy = getRuntimeStrategy(seq_len, size); + AllReduceStrategyType runtime_strategy = getRuntimeStrategy(seq_len, size); // Log runtime strategy auto const rank = COMM_SESSION.getRank(); - logRunTimeStrategy(runtimeStrategy, rank); + logRunTimeStrategy(runtime_strategy, rank); // Dispatch to different allreduce implementations - switch (runtimeStrategy) + switch (runtime_strategy) { case AllReduceStrategyType::UB: return runUBAllReduce(input, residual, norm_weight, scale, bias); case AllReduceStrategyType::NCCL: return runNCCLAllReduce(input, residual, norm_weight, scale, bias); case AllReduceStrategyType::MIN_LATENCY: case AllReduceStrategyType::ONESHOT: case AllReduceStrategyType::TWOSHOT: - return runFusionAllReduce(input, residual, norm_weight, scale, bias, workspace, runtimeStrategy); + return runFusionAllReduce(input, residual, norm_weight, scale, bias, workspace, runtime_strategy); default: TORCH_CHECK(false, "Invalid runtime strategy"); return {}; } } @@ -280,37 +282,18 @@ private: auto stream = at::cuda::getCurrentCUDAStream(input.get_device()); int size = input.numel(); - int hidden_size = input.size(-1); - if (mOp == AllReduceFusionOp::RESIDUAL_RMS_NORM) + torch::Tensor reduce_output = torch::empty_like(input); + NCCLCHECK(ncclAllReduce(input.data_ptr(), reduce_output.mutable_data_ptr(), size, (*getDtypeMap())[mType], + ncclSum, *mNcclComm, stream)); + + if (mOp == AllReduceFusionOp::NONE) { - torch::Tensor norm_out = torch::empty_like(input); - torch::Tensor residual_out = torch::empty_like(input); - - NCCLCHECK(ncclAllReduce(input.data_ptr(), residual_out.mutable_data_ptr(), size, (*getDtypeMap())[mType], - ncclSum, *mNcclComm, stream)); - tensorrt_llm::kernels::AllReduceParams params; - params.fusion_params.bias_buffer = bias ? bias.value().data_ptr() : nullptr; - params.fusion_params.residual_buffer = residual ? residual.value().data_ptr() : nullptr; - params.fusion_params.weight_buffer = norm_weight ? norm_weight.value().data_ptr() : nullptr; - params.local_output_buffer_ptr = norm_out.mutable_data_ptr(); - params.elts_total = size; - - params.fusion_params.hidden_size = hidden_size; - params.fusion_params.eps = mEps; - params.fusion_params.intermediate_buffer = residual_out.mutable_data_ptr(); - tensorrt_llm::kernels::residualRmsNorm(params, mType, stream, mOp); - return {norm_out, residual_out}; + return {reduce_output}; } - else if (mOp == AllReduceFusionOp::NONE) - { - torch::Tensor output = torch::empty_like(input); - NCCLCHECK(ncclAllReduce(input.data_ptr(), output.mutable_data_ptr(), size, (*getDtypeMap())[mType], ncclSum, - *mNcclComm, stream)); - return {output}; - } - TORCH_CHECK(false, "NCCL encounters unsupported fusion operation: " + tensorrt_llm::kernels::toString(mOp)); - return {}; + + // Treat any other patterns as fallback cases. 
+ return fallbackRunSubsequentOps(input, residual, norm_weight, scale, bias, reduce_output); } std::vector<torch::Tensor> runFusionAllReduce(torch::Tensor const& input, @@ -413,22 +396,22 @@ private: || mOp == AllReduceFusionOp::RESIDUAL_RMS_NORM_OUT_QUANT_NVFP4) { // TODO: Better check for each pattern - int64_t sfVecSize = 16; + int64_t sf_vec_size = 16; int64_t m = 1; - auto const& inputShape = input.sizes(); - auto const& r = inputShape.size(); + auto const& input_shape = input.sizes(); + auto const& r = input_shape.size(); TORCH_CHECK(r >= 2, "Input should be >=2D tensor."); for (size_t i = 0; i < r - 1; i++) { - m *= inputShape[i]; + m *= input_shape[i]; } - auto const k = inputShape[r - 1]; - TORCH_CHECK(k % sfVecSize == 0, "Input should be divisible by sfVecSize."); - std::vector<int64_t> outputShape(inputShape.begin(), inputShape.end()); - outputShape[r - 1] = k / 2; + auto const k = input_shape[r - 1]; + TORCH_CHECK(k % sf_vec_size == 0, "Input should be divisible by sfVecSize."); + std::vector<int64_t> output_shape(input_shape.begin(), input_shape.end()); + output_shape[r - 1] = k / 2; - quant_out = at::detail::empty_cuda(outputShape, FLOAT4_E2M1X2, input.device(), std::nullopt); - scale_out = at::detail::empty_cuda({tensorrt_llm::computeFP4SwizzledLayoutSFSize(m, k / sfVecSize)}, + quant_out = at::detail::empty_cuda(output_shape, FLOAT4_E2M1X2, input.device(), std::nullopt); + scale_out = at::detail::empty_cuda({tensorrt_llm::computeFP4SwizzledLayoutSFSize(m, k / sf_vec_size)}, SF_DTYPE, input.device(), std::nullopt); residual_out = torch::empty_like(residual.value()); @@ -495,18 +478,86 @@ private: return {}; } + std::vector<torch::Tensor> fallbackRunSubsequentOps(torch::Tensor const& input, + torch::optional<torch::Tensor> const& residual, torch::optional<torch::Tensor> const& norm_weight, + torch::optional<torch::Tensor> const& scale, torch::optional<torch::Tensor> const& bias, + torch::Tensor& reduce_output) noexcept + { + // If we reach here, the extra fallback operations are required. + // All patterns are broken into AllReduce + residual_rms_norm + subsequent operations (quantization, etc.). + auto const size = input.numel(); + auto const hidden_size = input.size(-1); + auto const stream = at::cuda::getCurrentCUDAStream(input.get_device()); + + torch::Tensor norm_out = torch::empty_like(input); + + tensorrt_llm::kernels::AllReduceParams params; + params.fusion_params.bias_buffer = bias ? bias.value().data_ptr() : nullptr; + params.fusion_params.residual_buffer = residual ? residual.value().data_ptr() : nullptr; + params.fusion_params.weight_buffer = norm_weight ? norm_weight.value().data_ptr() : nullptr; + params.local_output_buffer_ptr = norm_out.mutable_data_ptr(); + params.elts_total = size; + + params.fusion_params.hidden_size = hidden_size; + params.fusion_params.eps = mEps; + params.fusion_params.intermediate_buffer = reduce_output.mutable_data_ptr(); + tensorrt_llm::kernels::residualRmsNorm(params, mType, stream, AllReduceFusionOp::RESIDUAL_RMS_NORM); + + // If no quantization is needed, return the norm and residual outputs. + if (mOp == AllReduceFusionOp::RESIDUAL_RMS_NORM) + { + return {norm_out, reduce_output}; + } + + const int64_t sf_vecsize = 16; + bool const sf_use_ue8m0 = false; + bool const is_sf_swizzled_layout = true; + TORCH_CHECK(scale, "scale is required for quantization ops"); + + // Attach the subsequent operations after the residual RMS norm all-reduce and return the final outputs.
+ switch (mOp) + { + case AllReduceFusionOp::RESIDUAL_RMS_NORM_QUANT_FP8: + { + auto [quant_out, scale_out] = torch_ext::symmetric_static_quantize_per_tensor(norm_out, scale.value()); + return {quant_out, reduce_output}; + } + case AllReduceFusionOp::RESIDUAL_RMS_NORM_QUANT_NVFP4: + { + auto [quant_out, scale_out] + = torch_ext::fp4_quantize(norm_out, scale.value(), sf_vecsize, sf_use_ue8m0, is_sf_swizzled_layout); + return {quant_out, scale_out, reduce_output}; + } + case AllReduceFusionOp::RESIDUAL_RMS_NORM_OUT_QUANT_FP8: + { + auto [quant_out, scale_out] = torch_ext::symmetric_static_quantize_per_tensor(norm_out, scale.value()); + return {norm_out, quant_out, reduce_output}; + } + case AllReduceFusionOp::RESIDUAL_RMS_NORM_OUT_QUANT_NVFP4: + { + auto [quant_out, scale_out] + = torch_ext::fp4_quantize(norm_out, scale.value(), sf_vecsize, sf_use_ue8m0, is_sf_swizzled_layout); + return {norm_out, quant_out, scale_out, reduce_output}; + } + default: break; + } + + TORCH_CHECK(false, "Unsupported fusion operation: " + tensorrt_llm::kernels::toString(mOp)); + return {}; + } + AllReduceStrategyType getRuntimeStrategy(size_t seq_len, size_t size) noexcept { - static char* forceNcclAllReduceStrategyChar = std::getenv("FORCE_NCCL_ALL_REDUCE_STRATEGY"); - bool forceNcclAllReduceStrategy = (forceNcclAllReduceStrategyChar != nullptr); - AllReduceStrategyType runtimeStrategy; + static char* force_nccl_all_reduce_strategy_char = std::getenv("FORCE_NCCL_ALL_REDUCE_STRATEGY"); + bool force_nccl_all_reduce_strategy = (force_nccl_all_reduce_strategy_char != nullptr); + AllReduceStrategyType runtime_strategy; if (mStrategy == AllReduceStrategyType::UB) { - runtimeStrategy = AllReduceStrategyType::UB; + runtime_strategy = AllReduceStrategyType::UB; } - else if (forceNcclAllReduceStrategy || mStrategy == AllReduceStrategyType::NCCL) + else if (force_nccl_all_reduce_strategy || mStrategy == AllReduceStrategyType::NCCL) { - runtimeStrategy = AllReduceStrategyType::NCCL; + runtime_strategy = AllReduceStrategyType::NCCL; } else { @@ -514,14 +565,14 @@ private: static char* ifForBenchMark = std::getenv("OVERRIDE_HEURISTIC_ALLREDUCE_STRATEGY"); if (ifForBenchMark != nullptr) { - runtimeStrategy = mStrategy; + runtime_strategy = mStrategy; } else { - runtimeStrategy = selectImplementation(seq_len, size, mGroup.size(), mType); + runtime_strategy = selectImplementation(seq_len, size, mGroup.size(), mType); } } - return runtimeStrategy; + return runtime_strategy; } void logRunTimeStrategy(AllReduceStrategyType strategy, int rank) noexcept @@ -557,9 +608,9 @@ private: static std::map, std::tuple> cache; if (cache.find(mGroup) != cache.end()) { - auto [isNVLINKSupported, isP2PSupported] = cache[mGroup]; - mIsNVLINKSupported = isNVLINKSupported; - mIsP2PSupported = isP2PSupported; + auto [is_NVLINK_supported, is_P2P_supported] = cache[mGroup]; + mIsNVLINKSupported = is_NVLINK_supported; + mIsP2PSupported = is_P2P_supported; return; } setGroupTopology(); @@ -570,8 +621,8 @@ private: { auto const rank = COMM_SESSION.getRank(); TLLM_LOG_INFO("Detecting local TP group for rank %d", rank); - std::set localGroup = getLocalGroup(mGroup); - if (mGroup.size() != localGroup.size()) + std::set local_group = getLocalGroup(mGroup); + if (mGroup.size() != local_group.size()) { mIsP2PSupported = false; mIsNVLINKSupported = false; @@ -580,26 +631,27 @@ private: } TLLM_LOG_INFO("TP group is intra-node for rank %d", rank); - NvmlManager nvmlManager; - std::unordered_set visitedDevice; + NvmlManager nvml_manager; + std::unordered_set 
visited_device; mIsP2PSupported = true; mIsNVLINKSupported = true; // Use cudaDeviceCanAccessPeer to determine whether p2p is supported, // and use nvml to determine whether there are nvlink links between ranks. - for (int firstDeviceId : localGroup) + for (int first_device_id : local_group) { - for (int secondDeviceId : localGroup) + for (int second_device_id : local_group) { - if (firstDeviceId == secondDeviceId || visitedDevice.find(secondDeviceId) != visitedDevice.end()) + if (first_device_id == second_device_id + || visited_device.find(second_device_id) != visited_device.end()) { continue; } - int canAccessPeer = 0; - TLLM_CUDA_CHECK(cudaDeviceCanAccessPeer(&canAccessPeer, firstDeviceId, secondDeviceId)); + int can_access_peer = 0; + TLLM_CUDA_CHECK(cudaDeviceCanAccessPeer(&can_access_peer, first_device_id, second_device_id)); - if (!canAccessPeer) + if (!can_access_peer) { mIsP2PSupported = false; mIsNVLINKSupported = false; @@ -607,31 +659,31 @@ private: return; } - nvmlDevice_t firstDevice; - NVML_CHECK(nvmlDeviceGetHandleByIndex(firstDeviceId, &firstDevice)); + nvmlDevice_t first_device; + NVML_CHECK(nvmlDeviceGetHandleByIndex(first_device_id, &first_device)); - bool isNVLINK = false; + bool is_NVLINK = false; for (unsigned int link = 0; link < NVML_NVLINK_MAX_LINKS; link++) { - nvmlPciInfo_t remotePciInfo; - if (nvmlDeviceGetNvLinkRemotePciInfo_v2(firstDevice, link, &remotePciInfo) != NVML_SUCCESS) + nvmlPciInfo_t remote_pci_info; + if (nvmlDeviceGetNvLinkRemotePciInfo_v2(first_device, link, &remote_pci_info) != NVML_SUCCESS) { continue; } - nvmlDevice_t remoteDevice; - auto const result = nvmlDeviceGetHandleByPciBusId_v2(remotePciInfo.busId, &remoteDevice); + nvmlDevice_t remote_device; + auto const result = nvmlDeviceGetHandleByPciBusId_v2(remote_pci_info.busId, &remote_device); if (result == NVML_SUCCESS) { // Two GPUs are connected directly through nvlink - unsigned int remoteDeviceId; - NVML_CHECK(nvmlDeviceGetIndex(remoteDevice, &remoteDeviceId)); + unsigned int remote_device_id; + NVML_CHECK(nvmlDeviceGetIndex(remote_device, &remote_device_id)); - if (remoteDeviceId == static_cast(secondDeviceId)) + if (remote_device_id == static_cast(second_device_id)) { - isNVLINK = true; + is_NVLINK = true; } } else if (result == NVML_ERROR_NOT_FOUND) @@ -640,21 +692,21 @@ private: // now remotePciInfo represents the pci information of nvswitch, // determine whether nvlink is supported by whether two GPUs are connected to the same // nvswitch. 
- nvmlDevice_t secondDevice; - NVML_CHECK(nvmlDeviceGetHandleByIndex(secondDeviceId, &secondDevice)); + nvmlDevice_t second_device; + NVML_CHECK(nvmlDeviceGetHandleByIndex(second_device_id, &second_device)); - for (unsigned int secondLink = 0; secondLink < NVML_NVLINK_MAX_LINKS; secondLink++) + for (unsigned int second_link = 0; second_link < NVML_NVLINK_MAX_LINKS; second_link++) { - nvmlPciInfo_t secondRemotePciInfo; - if (nvmlDeviceGetNvLinkRemotePciInfo_v2(secondDevice, secondLink, &secondRemotePciInfo) + nvmlPciInfo_t second_remote_pci_info; + if (nvmlDeviceGetNvLinkRemotePciInfo_v2(second_device, second_link, &second_remote_pci_info) != NVML_SUCCESS) { continue; } - if (strcmp(remotePciInfo.busId, secondRemotePciInfo.busId) == 0) + if (strcmp(remote_pci_info.busId, second_remote_pci_info.busId) == 0) { - isNVLINK = true; + is_NVLINK = true; break; } } @@ -664,28 +716,74 @@ private: NVML_CHECK(result); } - if (isNVLINK) + if (is_NVLINK) { break; } } - mIsNVLINKSupported &= isNVLINK; + mIsNVLINKSupported &= is_NVLINK; } - visitedDevice.insert(firstDeviceId); + visited_device.insert(first_device_id); } } + bool ifFallbackToNCCL(size_t seq_len, size_t message_size_bytes, size_t max_workspace_size, bool is_auto) noexcept + { + // If messageSize is greater than maxWorkspaceSize, fall back to NCCL regardless of the fusion type. + if (message_size_bytes > max_workspace_size) + { + if (!is_auto) + { + TLLM_LOG_WARNING( + "Since messageSize is greater than maxWorkspaceSize, fallback to AllReduceStrategy: NCCL"); + } + return true; + } + + // If Peer to Peer is not supported, fall back to NCCL. + if (!mIsP2PSupported) + { + if (!is_auto) + { + TLLM_LOG_WARNING("Since Peer to Peer not supported, fallback to AllReduceStrategy: NCCL"); + } + return true; + } + + // If NVLINK is not supported, fall back to NCCL. + if (!mIsNVLINKSupported) + { + if (!is_auto) + { + TLLM_LOG_WARNING("Since NVLINK not supported, fallback to AllReduceStrategy: NCCL"); + } + return true; + } + return false; + } + AllReduceStrategyType selectImplementation( - size_t seq_len, size_t messageSize, int worldSize, nvinfer1::DataType type) noexcept + size_t seq_len, size_t message_size, int world_size, nvinfer1::DataType type) noexcept { // Check that heuristic is only applied when AUTO is set. - bool const isAuto = (mStrategy == AllReduceStrategyType::AUTO); + bool const is_auto = (mStrategy == AllReduceStrategyType::AUTO); + auto const message_size_bytes = message_size * tensorrt_llm::common::getDTypeSize(type); + auto const max_workspace_size + = tensorrt_llm::utils::customAllReduceUtils::getMaxRequiredWorkspaceSize(world_size); - // This rule based heuristic only chooses NCCL and MIN_LATENCY strategies. + if (ifFallbackToNCCL(seq_len, message_size_bytes, max_workspace_size, is_auto)) + { + return AllReduceStrategyType::NCCL; + } - // Only the intersection of the supported fusion types of two implementations will go through the heuristic. - // Otherwise, MIN_LATENCY strategy will be returned due to more fusion patterns it can support. + // This rule-based heuristic only chooses between the NCCL and MIN_LATENCY strategies. + + // The heuristic is only applied to the NONE and RESIDUAL_RMS_NORM fusion types, + // because NCCL might be faster for some large message sizes. + // Otherwise, the MIN_LATENCY strategy is returned directly since it supports more fusion patterns. + // TODO: NCCL AllReduce + subsequent quantization ops (as fallback) can also support these fusion types.
+ // This should be compared with the MIN_LATENCY fused kernels to determine the best strategy. switch (mOp) { case AllReduceFusionOp::NONE: @@ -693,89 +791,60 @@ private: case AllReduceFusionOp::RESIDUAL_RMS_NORM_QUANT_FP8: case AllReduceFusionOp::RESIDUAL_RMS_NORM_OUT_QUANT_FP8: case AllReduceFusionOp::RESIDUAL_RMS_NORM_QUANT_NVFP4: - case AllReduceFusionOp::RESIDUAL_RMS_NORM_OUT_QUANT_NVFP4: - default: return AllReduceStrategyType::MIN_LATENCY; + case AllReduceFusionOp::RESIDUAL_RMS_NORM_OUT_QUANT_NVFP4: return AllReduceStrategyType::MIN_LATENCY; + // Assume NCCL has fallback implementations for all remaining fusion types. + default: return AllReduceStrategyType::NCCL; } // Check mOp to be supported by the heuristic. TORCH_CHECK(mOp == AllReduceFusionOp::NONE || mOp == AllReduceFusionOp::RESIDUAL_RMS_NORM, - "Only NONE and RESIDUAL_RMS_NORM are supported for heuristic."); + "Only NONE and RESIDUAL_RMS_NORM are supported for NCCL/MIN_LATENCY heuristic."); - // If AUTO is set, but P2P is not supported, fallback to NCCL. - if (!mIsP2PSupported) + // Default to NCCL. + AllReduceStrategyType strategy = AllReduceStrategyType::NCCL; + + // ONESHOT and TWOSHOT are currently kept in the strategy list, + // but torch-flow users should use AUTO or MIN_LATENCY instead. + // NOTICE: When the strategy is not AUTO and the fusion type is not supported by that strategy, + // the user must guarantee the correctness of the fusion pattern dispatching. + if (!is_auto) { - if (!isAuto) + if (mStrategy == AllReduceStrategyType::ONESHOT || mStrategy == AllReduceStrategyType::TWOSHOT) { - TLLM_LOG_WARNING("Since Peer to Peer not supported, fallback to AllReduceStrategy: NCCL"); - } - return AllReduceStrategyType::NCCL; - } - - // If AUTO is set, but NVLINK is not supported, fallback to NCCL. - if (isAuto && !mIsNVLINKSupported) - { - return AllReduceStrategyType::NCCL; - } - - auto const maxWorkspaceSize = tensorrt_llm::utils::customAllReduceUtils::getMaxRequiredWorkspaceSize(worldSize); - - AllReduceStrategyType strat = AllReduceStrategyType::NCCL; - auto const messageSizeBytes = messageSize * tensorrt_llm::common::getDTypeSize(type); - - if (messageSizeBytes <= maxWorkspaceSize) - { - // Currently we will not remove ONESHOT and TWOSHOT from the strategy list - // But torch flow user should not use them, but use AUTO or MIN_LATENCY instead. - // NOTICE: When a fusion type is not supported by the corresponding strategy but strategy is not AUTO, - // user should guarantee the correctness of the fusion pattern dispatching.
- if (!isAuto) - { - if (mStrategy == AllReduceStrategyType::ONESHOT || mStrategy == AllReduceStrategyType::TWOSHOT) - { - strat = AllReduceStrategyType::MIN_LATENCY; - } - else - { - strat = mStrategy; - } - } - else if (worldSize <= 2) - { - strat = AllReduceStrategyType::MIN_LATENCY; - } - else if (worldSize <= 4) - { - if (messageSizeBytes < 1 * 1000 * 1000) - { - strat = AllReduceStrategyType::MIN_LATENCY; - } - else - { - strat = AllReduceStrategyType::NCCL; - } + strategy = AllReduceStrategyType::MIN_LATENCY; } else { - if (messageSizeBytes < 500 * 1000) - { - strat = AllReduceStrategyType::MIN_LATENCY; - } - else - { - strat = AllReduceStrategyType::NCCL; - } + strategy = mStrategy; + } + } + else if (world_size <= 2) + { + strategy = AllReduceStrategyType::MIN_LATENCY; + } + else if (world_size <= 4) + { + if (message_size_bytes < 1 * 1000 * 1000) + { + strategy = AllReduceStrategyType::MIN_LATENCY; + } + else + { + strategy = AllReduceStrategyType::NCCL; } } else { - if (!isAuto) + if (message_size_bytes < 500 * 1000) { - TLLM_LOG_WARNING("Since messageSize > maxWorkspace, fallback to AllReduceStrategy: NCCL"); + strategy = AllReduceStrategyType::MIN_LATENCY; + } + else + { + strategy = AllReduceStrategyType::NCCL; } - strat = AllReduceStrategyType::NCCL; } - - return strat; + return strategy; } private: @@ -793,10 +862,10 @@ private: #endif // ENABLE_MULTI_DEVICE -std::vector allreduce(torch::Tensor input, torch::optional residual, - torch::optional norm_weight, torch::optional scale, - torch::optional bias, torch::optional workspace, torch::List group_, - int64_t const strategy_, int64_t const fusion_op_, double const eps_) +std::vector allreduce(torch::Tensor const& input, torch::optional const& residual, + torch::optional const& norm_weight, torch::optional const& scale, + torch::optional const& bias, torch::optional const& workspace, + torch::List const& group_, int64_t const strategy_, int64_t const fusion_op_, double const eps_) { #if ENABLE_MULTI_DEVICE auto const dtype = tensorrt_llm::runtime::TorchUtils::dataType(input.scalar_type()); diff --git a/cpp/tensorrt_llm/thop/fp4Quantize.cpp b/cpp/tensorrt_llm/thop/fp4Quantize.cpp index bbb5f28a70..030fe3e06d 100644 --- a/cpp/tensorrt_llm/thop/fp4Quantize.cpp +++ b/cpp/tensorrt_llm/thop/fp4Quantize.cpp @@ -14,6 +14,7 @@ * limitations under the License. */ +#include "tensorrt_llm/thop/fp4Quantize.h" #include "tensorrt_llm/common/cudaUtils.h" #include "tensorrt_llm/kernels/quantization.h" #include "tensorrt_llm/thop/thUtils.h" diff --git a/cpp/tensorrt_llm/thop/fp4Quantize.h b/cpp/tensorrt_llm/thop/fp4Quantize.h new file mode 100644 index 0000000000..ea6dc1f59f --- /dev/null +++ b/cpp/tensorrt_llm/thop/fp4Quantize.h @@ -0,0 +1,28 @@ +/* + * Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#pragma once +#include "tensorrt_llm/common/cudaUtils.h" +#include "tensorrt_llm/thop/thUtils.h" + +#include +#include + +namespace torch_ext +{ +std::tuple fp4_quantize(torch::Tensor const& self, torch::Tensor const& globalScale, + int64_t sfVecSize, bool sfUseUE8M0, bool isSfSwizzledLayout); +} diff --git a/cpp/tensorrt_llm/thop/fp8Op.cpp b/cpp/tensorrt_llm/thop/fp8Op.cpp index bb197a0296..afd6e388d8 100644 --- a/cpp/tensorrt_llm/thop/fp8Op.cpp +++ b/cpp/tensorrt_llm/thop/fp8Op.cpp @@ -14,6 +14,7 @@ * limitations under the License. */ +#include "tensorrt_llm/thop/fp8Op.h" #include "tensorrt_llm/common/cudaBf16Wrapper.h" #include "tensorrt_llm/common/cudaFp8Utils.h" #include "tensorrt_llm/thop/thUtils.h" diff --git a/cpp/tensorrt_llm/thop/fp8Op.h b/cpp/tensorrt_llm/thop/fp8Op.h new file mode 100644 index 0000000000..f12f166c4e --- /dev/null +++ b/cpp/tensorrt_llm/thop/fp8Op.h @@ -0,0 +1,44 @@ +/* + * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include "tensorrt_llm/common/cudaUtils.h" +#include "tensorrt_llm/kernels/quantization.h" +#include "tensorrt_llm/thop/thUtils.h" + +#include + +#include + +#include + +namespace torch_ext +{ +std::tuple symmetric_quantize_weight(torch::Tensor weight); +std::tuple symmetric_quantize_activation(torch::Tensor activation); +std::tuple symmetric_quantize_per_tensor(torch::Tensor input); +std::tuple symmetric_static_quantize_weight(torch::Tensor weight, torch::Tensor scales); +std::tuple symmetric_static_quantize_activation( + torch::Tensor activation, torch::Tensor scales); +std::tuple symmetric_static_quantize_per_tensor( + torch::Tensor input, torch::Tensor scales); + +torch::Tensor symmetric_dequantize_weight(torch::Tensor weight, torch::Tensor scales); +torch::Tensor symmetric_dequantize_activation(torch::Tensor activation, torch::Tensor scales); +torch::Tensor symmetric_dequantize_per_tensor(torch::Tensor input, torch::Tensor scales); + +} // namespace torch_ext diff --git a/tensorrt_llm/_torch/distributed/ops.py b/tensorrt_llm/_torch/distributed/ops.py index 847c21bd7e..f26e2c0114 100644 --- a/tensorrt_llm/_torch/distributed/ops.py +++ b/tensorrt_llm/_torch/distributed/ops.py @@ -142,67 +142,12 @@ class AllReduce(nn.Module): self.mapping = mapping self.workspace = None self.strategy = strategy - self.max_workspace_size = CustomAllReduceHelper.max_workspace_size_auto( - self.mapping.tp_size, support_deterministic=False) - - self.fallback_func_mapping = { - AllReduceFusionOp.RESIDUAL_RMS_NORM_QUANT_FP8: - self.fallback_residual_rms_norm_quant_fp8, - AllReduceFusionOp.RESIDUAL_RMS_NORM_OUT_QUANT_FP8: - self.fallback_residual_rms_norm_out_quant_fp8, - AllReduceFusionOp.RESIDUAL_RMS_NORM_QUANT_NVFP4: - self.fallback_residual_rms_norm_quant_nvfp4, - AllReduceFusionOp.RESIDUAL_RMS_NORM_OUT_QUANT_NVFP4: - self.fallback_residual_rms_norm_out_quant_nvfp4, - } if self.mapping.tp_size > 1: # When Strategy is UB, it is guaranteed that the workspace is 
not used. if self.strategy != AllReduceStrategy.UB: self.workspace = get_allreduce_workspace(self.mapping) - @staticmethod - def fallback_residual_rms_norm_quant_fp8( - output: Tuple[torch.Tensor, ...], - all_reduce_params: AllReduceParams, - ): - norm_out, residual_out = output - quant_fp8, _ = torch.ops.tensorrt_llm.static_quantize_e4m3_per_tensor( - norm_out, all_reduce_params.scale) - return quant_fp8, residual_out - - @staticmethod - def fallback_residual_rms_norm_out_quant_fp8( - output: Tuple[torch.Tensor, ...], - all_reduce_params: AllReduceParams, - ): - norm_out, residual_out = output - quant_fp8, _ = torch.ops.tensorrt_llm.static_quantize_e4m3_per_tensor( - norm_out, all_reduce_params.scale) - return norm_out, quant_fp8, residual_out - - @staticmethod - def fallback_residual_rms_norm_quant_nvfp4( - output: Tuple[torch.Tensor, ...], - all_reduce_params: AllReduceParams, - ): - norm_out, residual_out = output - quant_fp4, scale_factor = torch.ops.trtllm.fp4_quantize( - norm_out, all_reduce_params.scale, 16, False) - - return quant_fp4, scale_factor, residual_out - - @staticmethod - def fallback_residual_rms_norm_out_quant_nvfp4( - output: Tuple[torch.Tensor, ...], - all_reduce_params: AllReduceParams, - ): - norm_out, residual_out = output - quant_fp4, scale_factor = torch.ops.trtllm.fp4_quantize( - norm_out, all_reduce_params.scale, 16, False) - - return norm_out, quant_fp4, scale_factor, residual_out - def forward( self, input: torch.Tensor, @@ -241,16 +186,6 @@ class AllReduce(nn.Module): if all_reduce_params is None: all_reduce_params = AllReduceParams() - strategy = self.strategy - fusion_op = all_reduce_params.fusion_op - - # If the input size is larger than the max workspace size, fallback to NCCL strategy - if input.numel() > self.max_workspace_size \ - and all_reduce_params.fusion_op != AllReduceFusionOp.NONE \ - and all_reduce_params.fusion_op != AllReduceFusionOp.RESIDUAL_RMS_NORM: - strategy = AllReduceStrategy.NCCL - fusion_op = AllReduceFusionOp.RESIDUAL_RMS_NORM - output = torch.ops.trtllm.allreduce( input=input, residual=all_reduce_params.residual, @@ -259,17 +194,11 @@ class AllReduce(nn.Module): bias=all_reduce_params.bias, workspace=self.workspace, group=self.mapping.tp_group, - strategy=strategy, - op=fusion_op, + strategy=self.strategy, + op=all_reduce_params.fusion_op, eps=all_reduce_params.eps, ) - if input.numel() > self.max_workspace_size \ - and all_reduce_params.fusion_op != AllReduceFusionOp.NONE \ - and all_reduce_params.fusion_op != AllReduceFusionOp.RESIDUAL_RMS_NORM: - output = self.fallback_func_mapping[all_reduce_params.fusion_op]( - output, all_reduce_params) - return output if len(output) > 1 else output[0]
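Note for reviewers: below is a minimal usage sketch (not part of the patch) of the simplified Python path. With this change the NCCL fallback decision is made entirely inside the C++ torch.ops.trtllm.allreduce op, so AllReduce.forward simply passes the requested fusion op through; for oversized messages the op internally falls back to NCCL + residual RMSNorm + quantization. The import paths, the 2-rank Mapping, and the tensor shapes are illustrative assumptions; the snippet is only meaningful when launched with one process per GPU (e.g. via mpirun) and an initialized TP group.

# Illustrative sketch only, not part of this patch. Assumes a 2-way TP job with
# one process per GPU; `rank` would normally come from the launcher.
import torch

from tensorrt_llm._torch.distributed import (AllReduce, AllReduceFusionOp,
                                             AllReduceParams, AllReduceStrategy)
from tensorrt_llm.mapping import Mapping

rank = 0  # placeholder; set per process in a real multi-GPU run
mapping = Mapping(world_size=2, tp_size=2, rank=rank)
all_reduce = AllReduce(mapping, strategy=AllReduceStrategy.AUTO)

hidden_size = 4096
x = torch.randn(8192, hidden_size, dtype=torch.bfloat16, device="cuda")
residual = torch.zeros_like(x)
norm_weight = torch.ones(hidden_size, dtype=torch.bfloat16, device="cuda")
scale = torch.ones(1, dtype=torch.float32, device="cuda")

# Even for a fused-quantization pattern on a large input there is no Python-side
# strategy switch anymore: the C++ op either runs the fused MIN_LATENCY kernel or
# falls back to NCCL AllReduce + residual RMSNorm + FP8 quantization internally.
quant_fp8, residual_out = all_reduce(
    x,
    all_reduce_params=AllReduceParams(
        fusion_op=AllReduceFusionOp.RESIDUAL_RMS_NORM_QUANT_FP8,
        residual=residual,
        norm_weight=norm_weight,
        scale=scale,
        eps=1e-5,
    ))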