/*
 * SPDX-FileCopyrightText: Copyright (c) 1993-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: Apache-2.0
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "tensorrt_llm/common/cudaDriverWrapper.h"
#include "tensorrt_llm/common/cudaUtils.h"
#include "tensorrt_llm/common/customAllReduceUtils.h"
#include "tensorrt_llm/common/dataType.h"
#include "tensorrt_llm/common/mcastDevMemUtils.h"
#include "tensorrt_llm/common/ncclUtils.h"
#include "tensorrt_llm/common/opUtils.h"
#include "tensorrt_llm/kernels/communicationKernels/allReduceFusionKernels.h"
#include "tensorrt_llm/kernels/communicationKernels/customLowPrecisionAllReduceKernels.h"
#include "tensorrt_llm/kernels/communicationKernels/mnnvlAllreduceKernels.h"
#include "tensorrt_llm/kernels/communicationKernels/moeAllReduceFusionKernels.h"
#include "tensorrt_llm/kernels/customAllReduceKernels.h"
#include "tensorrt_llm/kernels/quantization.h"
#include "tensorrt_llm/kernels/userbuffers/ub_interface.h"
#include "tensorrt_llm/runtime/mcastDeviceMemory.h"
#include "tensorrt_llm/runtime/torchUtils.h"
#include "tensorrt_llm/runtime/utils/mpiUtils.h"
#include "tensorrt_llm/runtime/utils/pgUtils.h"
#include "tensorrt_llm/thop/fp4Quantize.h"
#include "tensorrt_llm/thop/fp8Op.h"
#include "tensorrt_llm/thop/thUtils.h"
#include "tensorrt_llm/thop/userbuffersTensor.h"

// The original angle-bracket include list was lost; the headers below are reconstructed from the APIs
// used in this file (NCCL, NVML, MPI, CUDA driver, c10d ProcessGroup, CUDA streams/tensors, std library).
#if ENABLE_MULTI_DEVICE
#include <ATen/cuda/EmptyTensor.h>
#include <cuda.h>
#include <mpi.h>
#include <nccl.h>
#include <nvml.h>
#endif // ENABLE_MULTI_DEVICE

#include <c10/cuda/CUDAStream.h>
#include <torch/csrc/distributed/c10d/ProcessGroup.hpp>
#include <torch/library.h>

#include <algorithm>
#include <cstdint>
#include <cstdlib>
#include <cstring>
#include <limits>
#include <map>
#include <memory>
#include <set>
#include <stdexcept>
#include <string>
#include <tuple>
#include <variant>
#include <vector>

// using namespace nvinfer1;

using tensorrt_llm::kernels::AllReduceFusionOp;
using tensorrt_llm::kernels::AllReduceStrategyType;
using tensorrt_llm::mpi::MpiTag;
using tensorrt_llm::pg_utils::get_world_pg;
using tensorrt_llm::pg_utils::get_local_pg;
using tensorrt_llm::pg_utils::PgHelper;

namespace torch_ext
{
#if ENABLE_MULTI_DEVICE

namespace
{

template <class... Ts>
struct overloaded : Ts...
{
    using Ts::operator()...;
};

template <class... Ts>
overloaded(Ts...) -> overloaded<Ts...>;
class NvmlManager
{
public:
    NvmlManager()
    {
        NVML_CHECK_THROW(nvmlInit());
    }

    ~NvmlManager()
    {
        NVML_CHECK(nvmlShutdown());
    }
};

std::set<int> getLocalGroup(std::set<int> const& group)
{
    auto const myRank = COMM_SESSION.getRank();
    auto const myLocalRank = LOCAL_COMM_SESSION.getRank();
    auto const localSize = static_cast<size_t>(LOCAL_COMM_SESSION.getSize());

    std::vector<int32_t> ranks(localSize, 0);
    std::vector<int32_t> localRanks(localSize, 0);
    if (group.size() >= localSize)
    {
        LOCAL_COMM_SESSION.allgather(&myRank, ranks.data(), 1, tensorrt_llm::mpi::MpiType::kINT32);
        LOCAL_COMM_SESSION.allgather(&myLocalRank, localRanks.data(), 1, tensorrt_llm::mpi::MpiType::kINT32);
    }
    else
    {
        if (myRank == *group.begin())
        {
            ranks.clear();
            int rank;
            ranks.push_back(myRank);
            for (auto it = std::next(std::begin(group), 1); it != group.end(); ++it)
            {
                LOCAL_COMM_SESSION.recvValue(rank, *it, MpiTag::kDefault);
                ranks.push_back(rank);
            }
            for (auto it = std::next(std::begin(group), 1); it != group.end(); ++it)
            {
                LOCAL_COMM_SESSION.send(
                    ranks.data(), localSize, tensorrt_llm::mpi::MpiType::kINT32, *it, MpiTag::kDefault);
            }

            localRanks.clear();
            localRanks.push_back(myLocalRank);
            for (auto it = std::next(std::begin(group), 1); it != group.end(); ++it)
            {
                LOCAL_COMM_SESSION.recvValue(rank, *it, MpiTag::kDefault);
                localRanks.push_back(rank);
            }
            for (auto it = std::next(std::begin(group), 1); it != group.end(); ++it)
            {
                LOCAL_COMM_SESSION.send(
                    localRanks.data(), localSize, tensorrt_llm::mpi::MpiType::kINT32, *it, MpiTag::kDefault);
            }
        }
        else
        {
            LOCAL_COMM_SESSION.sendValue(myRank, *group.begin(), MpiTag::kDefault);
            LOCAL_COMM_SESSION.recv(
                ranks.data(), localSize, tensorrt_llm::mpi::MpiType::kINT32, *group.begin(), MpiTag::kDefault);

            LOCAL_COMM_SESSION.sendValue(myLocalRank, *group.begin(), MpiTag::kDefault);
            LOCAL_COMM_SESSION.recv(
                localRanks.data(), localSize, tensorrt_llm::mpi::MpiType::kINT32, *group.begin(), MpiTag::kDefault);
        }
    }

    std::set<int> localGroup;
    for (size_t i = 0; i < ranks.size(); ++i)
    {
        auto const rank = ranks[i];
        if (group.find(rank) != group.end())
        {
            localGroup.insert(localRanks[i]);
        }
    }
    return localGroup;
}
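// Illustrative example (hypothetical 2-node job with 4 GPUs per node, not taken from a real config): if the
// calling process lives on the second node, world ranks {4, 5, 6, 7} map to local ranks {0, 1, 2, 3}, so both
// getLocalGroup above and getLocalGroupTorch below return {0, 1, 2, 3}; for a group that only spans world
// ranks {4, 5}, only the members' local ranks {0, 1} are returned.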
std::set<int> getLocalGroupTorch(std::set<int> const& group)
{
    auto const worldPg = get_world_pg();
    auto const myRank = worldPg->getRank();
    auto const localPg = get_local_pg();
    auto const myLocalRank = localPg->getRank();
    auto const localSize = static_cast<size_t>(localPg->getSize());

    PgHelper pgh_local{localPg};
    PgHelper pgh_world{worldPg}; // for p2p

    std::vector<int> ranks(localSize, -1);
    std::vector<int> localRanks(localSize, -1);

    if (group.size() >= localSize)
    {
        PGCHECK_THROW(pgh_local.allgather(&myRank, ref(ranks), {}));
        PGCHECK_THROW(pgh_local.allgather(&myLocalRank, ref(localRanks), {}));
    }
    else
    {
        int tag = static_cast<int>(MpiTag::kDefault);
        if (myRank == *group.begin())
        {
            // Leader: gather from peers (world ranks), then broadcast full localSize arrays.
            size_t cnt = 0;
            ranks[cnt++] = myRank;
            int tmp;
            for (auto it = std::next(group.begin()); it != group.end(); ++it)
            {
                PGCHECK_THROW(pgh_world.recv(&tmp, *it, tag));
                ranks[cnt++] = tmp;
            }
            for (auto it = std::next(group.begin()); it != group.end(); ++it)
            {
                PGCHECK_THROW(pgh_world.send(ref(ranks), *it, tag));
            }

            cnt = 0;
            localRanks[cnt++] = myLocalRank;
            for (auto it = std::next(group.begin()); it != group.end(); ++it)
            {
                PGCHECK_THROW(pgh_world.recv(&tmp, *it, tag));
                localRanks[cnt++] = tmp;
            }
            for (auto it = std::next(group.begin()); it != group.end(); ++it)
            {
                PGCHECK_THROW(pgh_world.send(ref(localRanks), *it, tag));
            }
        }
        else
        {
            int leader = *group.begin();
            PGCHECK_THROW(pgh_world.send(&myRank, leader, tag));
            PGCHECK_THROW(pgh_world.recv(ref(ranks), leader, tag));
            PGCHECK_THROW(pgh_world.send(&myLocalRank, leader, tag));
            PGCHECK_THROW(pgh_world.recv(ref(localRanks), leader, tag));
        }
    }

    std::set<int> localGroup;
    for (size_t i = 0; i < ranks.size(); ++i)
    {
        int world_r = ranks[i];
        if (group.find(world_r) != group.end())
        {
            localGroup.insert(localRanks[i]);
        }
    }
    return localGroup;
}

class AllreduceOp
{
public:
    AllreduceOp(
        std::set<int> group, nvinfer1::DataType type, AllReduceStrategyType strategy, AllReduceFusionOp op, float eps)
        : mGroup(std::move(group))
        , mIsNVLINKSupported(false)
        , mIsP2PSupported(false)
        , mIsMNNVLSupported(false)
        , mType(type)
        , mStrategy(strategy)
        , mOp(op)
        , mEps(eps)
    {
    }

    AllreduceOp(std::set<int> group, c10::intrusive_ptr<c10d::ProcessGroup> const& process_group_,
        nvinfer1::DataType type, AllReduceStrategyType strategy, AllReduceFusionOp op, float eps)
        : mGroup(std::move(group))
        , mIsNVLINKSupported(false)
        , mIsP2PSupported(false)
        , mIsMNNVLSupported(false)
        , mType(type)
        , mStrategy(strategy)
        , mOp(op)
        , mEps(eps)
        , mNcclComm(process_group_)
    {
    }

    ~AllreduceOp() = default;

    int getRank() const
    {
        return std::visit(
            overloaded{[&](std::shared_ptr<ncclComm_t> const&) { return COMM_SESSION.getRank(); },
                [&](c10::intrusive_ptr<c10d::ProcessGroup> const& torchPg) { return get_world_pg()->getRank(); }},
            mNcclComm);
    }

    std::vector<torch::Tensor> run(torch::Tensor const& input, torch::optional<torch::Tensor> const& residual,
        torch::optional<torch::Tensor> const& norm_weight, torch::optional<torch::Tensor> const& scale,
        torch::optional<torch::Tensor> const& bias, bool trigger_completion_at_end,
        torch::optional<torch::Tensor> workspace) noexcept
    {
        size_t size = input.numel();
        size_t seq_len = input.size(0);
        size_t hidden_size = input.size(-1);
        size_t bytes_per_element = input.element_size();
        TLLM_LOG_DEBUG("All reduce message size is %zu", size * bytes_per_element);

        AllReduceStrategyType runtime_strategy = selectImplementation(seq_len, hidden_size);

        // Log runtime strategy
        auto const rank = getRank();
        TLLM_LOG_DEBUG(
            "AllReduceOp runtime strategy for rank %d: " + tensorrt_llm::kernels::toString(runtime_strategy), rank);

        // Dispatch to different allreduce implementations
        switch (runtime_strategy)
        {
        case AllReduceStrategyType::UB: return runUBAllReduce(input, residual, norm_weight, scale, bias);
        case AllReduceStrategyType::NCCL: return runNCCLAllReduce(input, residual, norm_weight, scale, bias);
        case AllReduceStrategyType::NCCL_SYMMETRIC:
            return runNCCLAllReduceSymmetric(input, residual, norm_weight, scale, bias);
        case AllReduceStrategyType::MIN_LATENCY:
        case AllReduceStrategyType::ONESHOT:
        case AllReduceStrategyType::TWOSHOT:
            return runFusionAllReduce(
                input, residual, norm_weight, scale, bias, trigger_completion_at_end, workspace, runtime_strategy);
        case AllReduceStrategyType::LOWPRECISION: return runLowPrecisionAllReduce(input, residual, norm_weight, scale, bias);
        default: TORCH_CHECK(false, "Invalid runtime strategy"); return {};
        }
    }
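    // Example call path (illustrative, Python-side sketch; the workspace provisioning and the Python enum
    // names AllReduceStrategy/AllReduceFusionOp are assumptions, not defined in this file):
    //   outs = torch.ops.trtllm.allreduce(x, residual, gamma, None, None, workspace,
    //                                     group=[0, 1, 2, 3], strategy=int(AllReduceStrategy.AUTO),
    //                                     op=int(AllReduceFusionOp.RESIDUAL_RMS_NORM), eps=1e-5,
    //                                     trigger_completion_at_end=True)
    //   norm_out, residual_out = outs   # per the RESIDUAL_RMS_NORM packing in run() above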
    int initialize()
    {
        TLLM_LOG_TRACE("%s start for rank %d", __PRETTY_FUNCTION__, getRank());
        if (mNcclComm.index() == 0)
        {
            mNcclComm = getComm(mGroup);
        }
        if (mStrategy != AllReduceStrategyType::NCCL && mStrategy != AllReduceStrategyType::UB)
        {
            initGroupTopology();
        }
        TLLM_LOG_TRACE("%s stop for rank %d", __PRETTY_FUNCTION__, getRank());
        return 0;
    }

private:
    std::vector<torch::Tensor> runUBAllReduce(torch::Tensor const& input,
        torch::optional<torch::Tensor> const& residual, torch::optional<torch::Tensor> const& norm_weight,
        torch::optional<torch::Tensor> const& scale, torch::optional<torch::Tensor> const& bias)
    {
        auto stream = at::cuda::getCurrentCUDAStream(input.get_device());
        int size = input.numel();
        int hidden_size = input.size(-1);
        torch::Tensor residual_out = torch::empty_like(input);

        TLLM_CHECK(mOp == AllReduceFusionOp::RESIDUAL_RMS_NORM || mOp == AllReduceFusionOp::RESIDUAL_RMS_NORM_QUANT_FP8
            || mOp == AllReduceFusionOp::RESIDUAL_RMS_NORM_QUANT_NVFP4);
        TLLM_CHECK_WITH_INFO(tensorrt_llm::runtime::ub::ub_is_initialized(), "UserBuffer has not been initialized!");
        auto& ub_manager = tensorrt_llm::runtime::ub::UserBuffersManager::get_instance();
        auto ub_buffer0 = ub_manager.search_buffer(input.data_ptr());
        TLLM_CHECK(!ub_buffer0.invalid());
        auto ub_comm = ub_manager.comm();
        int m = size / hidden_size;

        if (mOp == AllReduceFusionOp::RESIDUAL_RMS_NORM)
        {
            TORCH_CHECK(norm_weight, "norm_weight is required for residual rms norm allreduce");
            TORCH_CHECK(!bias, "bias is not supported for residual rms norm allreduce");
            TORCH_CHECK(mType == nvinfer1::DataType::kHALF || mType == nvinfer1::DataType::kBF16);
            auto [norm_out, ub_buffer1] = torch_ext::create_userbuffers_tensor(input.sizes(), input.scalar_type());

            tensorrt_llm::kernels::ub::allreduce2_userbuff_rmsnorm_launcher(ub_buffer0.handle, 0, ub_buffer1.handle, 0,
                size, hidden_size, nullptr, norm_weight.value().data_ptr(), mEps, residual.value().data_ptr(),
                residual_out.data_ptr(), mType, ub_comm, stream);
            return {norm_out, residual_out};
        }
        else if (mOp == AllReduceFusionOp::RESIDUAL_RMS_NORM_QUANT_FP8)
        {
            TORCH_CHECK(scale, "scale is required for FP8 allreduce");
            TORCH_CHECK(norm_weight, "norm_weight is required for FP8 allreduce");
            TORCH_CHECK(!bias, "bias is not supported for FP8 allreduce");
            auto [norm_out, ub_buffer1] = torch_ext::create_userbuffers_tensor(input.sizes(), torch::kFloat8_e4m3fn);

            tensorrt_llm::kernels::ub::allreduce2_userbuff_inplace_rmsnorm_quant_launcher(ub_buffer0.handle, 0,
                ub_buffer1.handle, 0, size, hidden_size, nullptr, norm_weight.value().data_ptr(), mEps,
                static_cast<float*>(scale.value().data_ptr()), residual.value().data_ptr(), residual_out.data_ptr(),
                mType, ub_comm, stream);
            return {norm_out, residual_out};
        }
        else if (mOp == AllReduceFusionOp::RESIDUAL_RMS_NORM_QUANT_NVFP4)
        {
            TORCH_CHECK(scale, "scale is required for FP4 allreduce");
            TORCH_CHECK(norm_weight, "norm_weight is required for FP4 allreduce");
            TORCH_CHECK(!bias, "bias is not supported for FP4 allreduce");

            int const sfVecSize = 16;
            int scale_size = tensorrt_llm::common::roundUp(m, 128) * tensorrt_llm::common::roundUp(hidden_size / sfVecSize, 4);
            TORCH_CHECK(hidden_size % sfVecSize == 0, "hidden_size must be divisible by 16 for FP4 allreduce");

            auto output_shape = input.sizes().vec();
            output_shape.back() /= 2;
            auto output_strides = input.strides().vec();
            for (size_t i = 0; i < output_shape.size() - 1; i++)
            {
                output_strides[i] /= 2;
            }

            auto [quant_out, ub_buffer1] = torch_ext::create_userbuffers_tensor(output_shape, torch::kByte);
            auto [scale_out, ub_buffer2] = torch_ext::create_userbuffers_tensor({scale_size}, torch::kByte);
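            // Illustrative sizing, derived only from the roundUp calls above (the swizzled scale-factor layout
            // itself is defined by the kernels): with m = 3 tokens and hidden_size = 4096,
            // scale_size = roundUp(3, 128) * roundUp(4096 / 16, 4) = 128 * 256 = 32768 one-byte scale factors,
            // while quant_out packs two FP4 values per byte (hence the halved last dimension).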
            tensorrt_llm::kernels::ub::allreduce2_userbuff_inplace_rmsnorm_quant_fp4_launcher(ub_buffer0.handle, 0,
                ub_buffer1.handle, 0, ub_buffer2.handle, 0, size, hidden_size, nullptr, norm_weight.value().data_ptr(),
                mEps, static_cast<float*>(scale.value().data_ptr()), residual.value().data_ptr(),
                residual_out.data_ptr(), mType, ub_comm, stream);
            return {quant_out, scale_out, residual_out};
        }
        TORCH_CHECK(false, "UBAllreduce does not support the fusion operation: " + tensorrt_llm::kernels::toString(mOp));
        return {};
    }

    std::vector<torch::Tensor> runNCCLAllReduce(torch::Tensor const& input,
        torch::optional<torch::Tensor> const& residual, torch::optional<torch::Tensor> const& norm_weight,
        torch::optional<torch::Tensor> const& scale, torch::optional<torch::Tensor> const& bias)
    {
        torch::Tensor reduce_output;
        std::visit(overloaded{[&](std::shared_ptr<ncclComm_t>& rawComm)
                       {
                           auto stream = at::cuda::getCurrentCUDAStream(input.get_device());
                           int size = input.numel();
                           reduce_output = torch::empty_like(input);
                           NCCLCHECK_THROW(ncclAllReduce(input.data_ptr(), reduce_output.mutable_data_ptr(), size,
                               (*getDtypeMap())[mType], ncclSum, *rawComm, stream));
                       },
                       [&](c10::intrusive_ptr<c10d::ProcessGroup>& torchPg)
                       {
                           reduce_output = input.clone();
                           // TLLM_LOG_INFO("AllReduce Rank: %d, tensor numel: %d", torchPg->getRank(), reduce_output.numel());
                           std::vector<torch::Tensor> tensors{reduce_output};
                           PGCHECK_THROW(torchPg->allreduce(tensors, {c10d::ReduceOp::SUM}));
                       }},
            mNcclComm);

        if (mOp == AllReduceFusionOp::NONE)
        {
            return {reduce_output};
        }

        // Treat any other patterns as fallback cases.
        return fallbackRunSubsequentOps(input, residual, norm_weight, scale, bias, reduce_output);
    }

    std::vector<torch::Tensor> runNCCLAllReduceSymmetric(torch::Tensor const& input,
        torch::optional<torch::Tensor> const& residual, torch::optional<torch::Tensor> const& norm_weight,
        torch::optional<torch::Tensor> const& scale, torch::optional<torch::Tensor> const& bias)
    {
        // Handle the ProcessGroup path first - we cannot extract the NCCL comm for window registration,
        // so use the ProcessGroup's allreduce directly and return early.
        if (mNcclComm.index() == 1)
        {
            auto torchPg = std::get<1>(mNcclComm);
            torch::Tensor reduceOutput = input.clone();
            std::vector<torch::Tensor> tensors{reduceOutput};
            PGCHECK_THROW(torchPg->allreduce(tensors, {c10d::ReduceOp::SUM}));
            if (mOp == AllReduceFusionOp::NONE)
            {
                return {reduceOutput};
            }
            // Treat any other patterns as fallback cases.
            return fallbackRunSubsequentOps(input, residual, norm_weight, scale, bias, reduceOutput);
        }

        // From here on, we have a raw NCCL comm and can proceed with window registration.
        auto rawComm = std::get<0>(mNcclComm);
        ncclComm_t comm = *rawComm;
        TLLM_CHECK_WITH_INFO(comm != nullptr, "NCCL communicator is null");
        TLLM_LOG_DEBUG("[runNCCLAllReduceSymmetric] Using raw NCCL comm path (not ProcessGroup)");

        using tensorrt_llm::common::nccl_util::NCCLWindowAllocator;
        using tensorrt_llm::common::nccl_util::createNCCLWindowTensor;

        auto stream = at::cuda::getCurrentCUDAStream(input.get_device());
        int size = input.numel();
        size_t bufferSizeBytes = size * input.element_size();

        // Using unregistered input buffers with NCCL symmetric requires a memcpy.
        // This is an overhead introduced by using NCCL_SYMMETRIC over NCCL.
        // Both the memcpy and the perf benefit of NCCL_SYMMETRIC scale linearly with the message size.
        // But a local memcpy is cheaper than the remote operations, so with larger message sizes the benefit is
        // stronger. Additionally, the perf benefit scales with the number of ranks, since multimem enables O(const.)
        // versus O(N) complexity. Hence we model this cutoff with a linear model. The numbers below were obtained on
        // GB200, scanning different message sizes and ranks. One can determine the regression onset for each number
        // of ranks as a single message size, and the formula below was obtained by fitting a linear model to those
        // onsets. It is possible to override this empirical heuristic with the TLLM_NCCL_MIN_REGISTRATION
        // environment variable.
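        // Worked example (derived from the fitted coefficients below, not an additional tuning point): with
        // nRanks = 8 the cutoff is max(0, -4986.43 * 8 + 156716.52) ≈ 116825 elements, i.e. roughly 228 KiB for
        // fp16 inputs. Smaller messages skip window registration; larger ones pay the one-time copy into a
        // registered window because the symmetric-kernel speedup outweighs the memcpy.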
        double const a = -4986.43478503;
        double const b = 156716.52177552;
        int nRanks;
        NCCLCHECK_THROW(ncclCommCount(comm, &nRanks));
        size_t minRegistrationThreshold = static_cast<size_t>(std::max(0.0, a * nRanks + b)) * input.element_size();

        // Disable window registration if neither NVLink nor MNNVL is supported.
        // TODO: replace in NCCL 2.29 with a comm query.
        if (!mIsNVLINKSupported && !mIsMNNVLSupported)
        {
            minRegistrationThreshold = std::numeric_limits<size_t>::max();
        }

        char const* envThreshold = std::getenv("TLLM_NCCL_MIN_REGISTRATION");
        if (envThreshold != nullptr)
        {
            minRegistrationThreshold = static_cast<size_t>(std::atoi(envThreshold)) * input.element_size();
        }

        // Search for an existing registered buffer.
        auto& allocator = NCCLWindowAllocator::getInstance();
        auto windowBuffer0 = allocator.searchBuffer(comm, input.data_ptr());
        torch::Tensor inputTensor = input;
        void* inputPtr = input.data_ptr();

        // If the buffer is not registered, decide whether to register based on size.
        if (!windowBuffer0.isValid())
        {
            if (bufferSizeBytes < minRegistrationThreshold)
            {
                // Small buffer: use the input directly without window registration.
                TLLM_LOG_DEBUG(
                    "[runNCCLAllReduceSymmetric] Buffer size %zu bytes < threshold %zu bytes, "
                    "skipping window registration",
                    bufferSizeBytes, minRegistrationThreshold);
                // inputTensor and inputPtr keep pointing to the original input.
            }
            else
            {
                // Large buffer: create a window buffer and copy the input (can swap the inputTensor reference).
                auto [symmetricInput, symmetricBuffer0]
                    = createNCCLWindowTensor(comm, input.sizes(), input.scalar_type());
                TLLM_CUDA_CHECK(cudaMemcpyAsync(
                    symmetricBuffer0.ptr, input.data_ptr(), bufferSizeBytes, cudaMemcpyDeviceToDevice, stream));
                windowBuffer0 = symmetricBuffer0;
                inputTensor = symmetricInput; // Swap to the window-backed tensor
                inputPtr = windowBuffer0.ptr;
            }
        }
        else
        {
            // Buffer already registered - use it directly.
            inputPtr = windowBuffer0.ptr;
        }

        // Use a window-backed output buffer.
        auto [normOut, windowBuffer1] = createNCCLWindowTensor(comm, input.sizes(), input.scalar_type());
        torch::Tensor outputTensor = normOut;
        void* outputPtr = windowBuffer1.ptr;

        // Perform the allreduce.
        NCCLCHECK_THROW(ncclAllReduce(inputPtr, outputPtr, size, (*getDtypeMap())[mType], ncclSum, comm, stream));

        if (mOp == AllReduceFusionOp::NONE)
        {
            return {outputTensor};
        }

        // Treat any other patterns as fallback cases.
        return fallbackRunSubsequentOps(input, residual, norm_weight, scale, bias, outputTensor);
    }

    std::vector<torch::Tensor> runLowPrecisionAllReduce(torch::Tensor const& input,
        torch::optional<torch::Tensor> const& residual, torch::optional<torch::Tensor> const& norm_weight,
        torch::optional<torch::Tensor> const& scale, torch::optional<torch::Tensor> const& bias) noexcept
    {
#ifdef ENABLE_FP8
        auto stream = at::cuda::getCurrentCUDAStream(input.get_device());
        int size = input.numel();
        int hidden_size = input.size(-1);

        auto const tp_size = mGroup.size();
        auto const cur_rank = getRank();
        int tp_rank = 0;
        for (auto const& currentRank : mGroup)
        {
            if (cur_rank == currentRank)
                break;
            ++tp_rank;
        }

        int bytes_per_element = input.element_size();
        int token_num = size / hidden_size;

        auto parts = tensorrt_llm::kernels::splitNumber(size);
        torch::Tensor reduce_output = torch::empty_like(input);

        size_t global_offset = 0;
        for (size_t i = 0; i < parts.size(); ++i)
        {
            size_t tmp_size = parts[i];
            tensorrt_llm::kernels::LowPrecisionAllReduceParams tmp_param;
            if (tp_size <= 4)
            {
                tmp_param = tensorrt_llm::kernels::LowPrecisionAllReduceParams::deserialize(
                    tp_size, tp_rank, mType, token_num, hidden_size);
            }
            else
            {
                tmp_param = tensorrt_llm::kernels::LowPrecisionAllReduceParams::deserialize_hier(
                    tp_size, tp_rank, mType, token_num, hidden_size);
            }

            tmp_param.local_input_buffer_ptr = reinterpret_cast<void*>(
                reinterpret_cast<char*>(input.data_ptr()) + global_offset * bytes_per_element);
            tmp_param.local_output_buffer_ptr = reinterpret_cast<void*>(
                reinterpret_cast<char*>(reduce_output.mutable_data_ptr()) + global_offset * bytes_per_element);
            tmp_param.elts_total = tmp_size;
            tensorrt_llm::kernels::customLowPrecisionAllReduce(tmp_param, mType, stream);
            global_offset += tmp_size;
        }

        if (mOp == AllReduceFusionOp::NONE)
        {
            return {reduce_output};
        }

        // Treat any other patterns as fallback cases.
        return fallbackRunSubsequentOps(input, residual, norm_weight, scale, bias, reduce_output);
#else
        C10_THROW_ERROR(NotImplementedError, "Can't use LOWPRECISION without compiling with ENABLE_FP8.");
#endif
    }

    std::vector<torch::Tensor> runFusionAllReduce(torch::Tensor const& input,
        torch::optional<torch::Tensor> const& residual, torch::optional<torch::Tensor> const& norm_weight,
        torch::optional<torch::Tensor> const& scale, torch::optional<torch::Tensor> const& bias,
        bool trigger_completion_at_end, torch::optional<torch::Tensor> workspace, AllReduceStrategyType strategy) noexcept
    {
        // Should handle only the Lamport implementation.
        auto stream = at::cuda::getCurrentCUDAStream(input.get_device());
        int size = input.numel();
        int hidden_size = input.size(-1);
        int seq_len = input.size(0);

        auto const tp_size = mGroup.size();
        auto const cur_rank = getRank();
        int tp_rank = 0;
        for (auto const& currentRank : mGroup)
        {
            if (cur_rank == currentRank)
                break;
            ++tp_rank;
        }

        // Use cleaner output assigning
        torch::Tensor reduce_out;
        torch::Tensor residual_out;
        torch::Tensor norm_out;
        torch::Tensor quant_out;
        torch::Tensor scale_out;

        tensorrt_llm::kernels::ar_fusion::AllReduceFusionParams allreduce_fusion_params;
        allreduce_fusion_params.residual_in = nullptr;
        allreduce_fusion_params.rms_gamma = nullptr;
        allreduce_fusion_params.allreduce_out = nullptr;
        allreduce_fusion_params.quant_out = nullptr;
        allreduce_fusion_params.scale_out = nullptr;
        allreduce_fusion_params.residual_out = nullptr;
        allreduce_fusion_params.norm_out = nullptr;
        allreduce_fusion_params.trigger_completion_at_end = trigger_completion_at_end;

        // Determine whether to use the one-shot or two-shot allreduce kernel when using the MIN_LATENCY strategy.
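        // Illustrative values based only on the condition below: with tp_size = 8 and, say,
        // kOneShotMaxToken = 128, a 64-token batch takes the one-shot path, a 512-token batch takes two-shot,
        // and any batch with hidden_size < 8 is forced to one-shot regardless of length. The actual
        // kOneShotMaxToken constant is defined by the ar_fusion kernels, so treat the number here as an
        // example rather than the authoritative threshold.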
        if (strategy == AllReduceStrategyType::MIN_LATENCY)
        {
            allreduce_fusion_params.use_oneshot = seq_len <= tensorrt_llm::kernels::ar_fusion::kOneShotMaxToken
                || hidden_size < static_cast<int>(tp_size);
        }
        else
        {
            allreduce_fusion_params.use_oneshot = strategy == AllReduceStrategyType::ONESHOT;
        }

        // Check some kernel constraints when using the TWOSHOT kernel.
        if (!allreduce_fusion_params.use_oneshot)
        {
            TORCH_CHECK(input.size(0) >= static_cast<int64_t>(tp_size),
                "Sequence length must be greater than or equal to TP size");
        }

        // Handle the no-fusion allreduce here.
        if (mOp == AllReduceFusionOp::NONE)
        {
            reduce_out = torch::empty_like(input);
            allreduce_fusion_params.allreduce_out = reduce_out.mutable_data_ptr();
            allreduce_fusion_params.pattern = tensorrt_llm::kernels::ar_fusion::AllReduceFusionPattern::kAllReduce;
        }
        // Handle allreduce fusion here.
        // Prepare the required output tensors for each fusion pattern.
        else if (mOp == AllReduceFusionOp::RESIDUAL_RMS_NORM)
        {
            norm_out = torch::empty_like(input);
            residual_out = torch::empty_like(residual.value());

            allreduce_fusion_params.norm_out = norm_out.mutable_data_ptr();
            allreduce_fusion_params.residual_out = residual_out.mutable_data_ptr();
            allreduce_fusion_params.pattern = tensorrt_llm::kernels::ar_fusion::AllReduceFusionPattern::kARResidualRMSNorm;
        }
        else if (mOp == AllReduceFusionOp::RESIDUAL_RMS_NORM_QUANT_FP8
            || mOp == AllReduceFusionOp::RESIDUAL_RMS_NORM_OUT_QUANT_FP8)
        {
            quant_out = at::detail::empty_cuda(input.sizes(), torch::kFloat8_e4m3fn, input.device(), std::nullopt);
            residual_out = torch::empty_like(residual.value());

            allreduce_fusion_params.quant_out = quant_out.mutable_data_ptr();
            allreduce_fusion_params.residual_out = residual_out.mutable_data_ptr();
            allreduce_fusion_params.pattern
                = tensorrt_llm::kernels::ar_fusion::AllReduceFusionPattern::kARResidualRMSNormFP8Quant;

            // norm_out is required
            if (mOp == AllReduceFusionOp::RESIDUAL_RMS_NORM_OUT_QUANT_FP8)
            {
                norm_out = torch::empty_like(input);
                allreduce_fusion_params.norm_out = norm_out.mutable_data_ptr();
                allreduce_fusion_params.pattern
                    = tensorrt_llm::kernels::ar_fusion::AllReduceFusionPattern::kARResidualRMSNormOutFP8Quant;
            }
        }
        else if (mOp == AllReduceFusionOp::RESIDUAL_RMS_NORM_QUANT_NVFP4
            || mOp == AllReduceFusionOp::RESIDUAL_RMS_NORM_OUT_QUANT_NVFP4)
        {
            // TODO: Better check for each pattern
            int64_t sf_vec_size = 16;
            int64_t m = 1;
            auto const& input_shape = input.sizes();
            auto const& r = input_shape.size();
            TORCH_CHECK(r >= 2, "Input should be a >=2D tensor.");
            for (size_t i = 0; i < r - 1; i++)
            {
                m *= input_shape[i];
            }
            auto const k = input_shape[r - 1];
            TORCH_CHECK(k % sf_vec_size == 0, "Input should be divisible by sfVecSize.");
            std::vector<int64_t> output_shape(input_shape.begin(), input_shape.end());
            output_shape[r - 1] = k / 2;

            quant_out = at::detail::empty_cuda(output_shape, FLOAT4_E2M1X2, input.device(), std::nullopt);
            scale_out = at::detail::empty_cuda({tensorrt_llm::computeSwizzledLayoutSFSize(m, k / sf_vec_size)},
                SF_DTYPE, input.device(), std::nullopt);
            residual_out = torch::empty_like(residual.value());

            allreduce_fusion_params.quant_out = quant_out.mutable_data_ptr();
            allreduce_fusion_params.scale_out = scale_out.mutable_data_ptr();
            allreduce_fusion_params.residual_out = residual_out.mutable_data_ptr();
            allreduce_fusion_params.pattern
                = tensorrt_llm::kernels::ar_fusion::AllReduceFusionPattern::kARResidualRMSNormFP4Quant;

            // norm_out is required
            if (mOp == AllReduceFusionOp::RESIDUAL_RMS_NORM_OUT_QUANT_NVFP4)
            {
                norm_out = torch::empty_like(input);
                allreduce_fusion_params.norm_out = norm_out.mutable_data_ptr();
                allreduce_fusion_params.pattern
                    = tensorrt_llm::kernels::ar_fusion::AllReduceFusionPattern::kARResidualRMSNormOutFP4Quant;
            }
        }
        else
        {
            TORCH_CHECK(false, "Unsupported fusion operation: " + tensorrt_llm::kernels::toString(mOp));
            return {};
        }

        allreduce_fusion_params.nranks = tp_size;
        allreduce_fusion_params.rank = tp_rank;
        allreduce_fusion_params.dtype = mType;
        allreduce_fusion_params.size = size;
        allreduce_fusion_params.hidden_dim = hidden_size;
        allreduce_fusion_params.workspace = reinterpret_cast<void**>(workspace.value().mutable_data_ptr());
        allreduce_fusion_params.allreduce_in = input.data_ptr();

        if (mOp != AllReduceFusionOp::NONE)
        {
            allreduce_fusion_params.residual_in = residual.value().data_ptr();
            allreduce_fusion_params.rms_gamma = norm_weight.value().data_ptr();
            allreduce_fusion_params.rms_eps = mEps;
        }

        allreduce_fusion_params.stream = stream;

        bool const is_scale_factor_required = mOp == AllReduceFusionOp::RESIDUAL_RMS_NORM_QUANT_FP8
            || mOp == AllReduceFusionOp::RESIDUAL_RMS_NORM_OUT_QUANT_FP8
            || mOp == AllReduceFusionOp::RESIDUAL_RMS_NORM_QUANT_NVFP4
            || mOp == AllReduceFusionOp::RESIDUAL_RMS_NORM_OUT_QUANT_NVFP4;
        allreduce_fusion_params.scale_factor
            = is_scale_factor_required ? static_cast<float*>(scale.value().data_ptr()) : nullptr;

        tensorrt_llm::kernels::ar_fusion::allreduce_fusion_op(allreduce_fusion_params);

        // Pack the output tensors.
        switch (mOp)
        {
        case AllReduceFusionOp::NONE: return {reduce_out};
        case AllReduceFusionOp::RESIDUAL_RMS_NORM: return {norm_out, residual_out};
        case AllReduceFusionOp::RESIDUAL_RMS_NORM_QUANT_FP8: return {quant_out, residual_out};
        case AllReduceFusionOp::RESIDUAL_RMS_NORM_OUT_QUANT_FP8: return {norm_out, quant_out, residual_out};
        case AllReduceFusionOp::RESIDUAL_RMS_NORM_QUANT_NVFP4: return {quant_out, scale_out, residual_out};
        case AllReduceFusionOp::RESIDUAL_RMS_NORM_OUT_QUANT_NVFP4: return {norm_out, quant_out, scale_out, residual_out};
        default: TORCH_CHECK(false, "Unsupported fusion operation: " + tensorrt_llm::kernels::toString(mOp));
        }
        return {};
    }

    std::vector<torch::Tensor> fallbackRunSubsequentOps(torch::Tensor const& input,
        torch::optional<torch::Tensor> const& residual, torch::optional<torch::Tensor> const& norm_weight,
        torch::optional<torch::Tensor> const& scale, torch::optional<torch::Tensor> const& bias,
        torch::Tensor& reduce_output)
    {
        // If we reach here, the extra fallback operations are required.
        // All patterns are broken into AllReduce + residual RMS norm + the subsequent operations (quantization, etc.).
        auto const size = input.numel();
        auto const hidden_size = input.size(-1);
        auto const stream = at::cuda::getCurrentCUDAStream(input.get_device());

        torch::Tensor norm_out = torch::empty_like(input);

        tensorrt_llm::kernels::AllReduceParams params;
        params.fusion_params.bias_buffer = bias ? bias.value().data_ptr() : nullptr;
        params.fusion_params.residual_buffer = residual ? residual.value().data_ptr() : nullptr;
        params.fusion_params.weight_buffer = norm_weight ? norm_weight.value().data_ptr() : nullptr;
        params.local_output_buffer_ptr = norm_out.mutable_data_ptr();
        params.elts_total = size;
        params.fusion_params.hidden_size = hidden_size;
        params.fusion_params.eps = mEps;
        params.fusion_params.intermediate_buffer = reduce_output.mutable_data_ptr();
        tensorrt_llm::kernels::residualRmsNorm(params, mType, stream, AllReduceFusionOp::RESIDUAL_RMS_NORM);
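        // Roughly equivalent torch reference for the fused kernel above (illustrative only, not the kernel itself):
        //   hidden   = reduce_output + residual (+ bias, when provided)
        //   norm_out = hidden / sqrt(mean(hidden**2, dim=-1, keepdim=True) + eps) * norm_weight
        // The pre-norm sum is also written back through intermediate_buffer, which is why reduce_output is
        // returned below as the residual output.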
        // If no quantization is needed, return the norm and residual outputs.
        if (mOp == AllReduceFusionOp::RESIDUAL_RMS_NORM)
        {
            return {norm_out, reduce_output};
        }

        int64_t const sf_vecsize = 16;
        bool const sf_use_ue8m0 = false;
        bool const is_sf_swizzled_layout = true;

        TORCH_CHECK(scale, "scale is required for quantization ops");

        // Attach the subsequent operations after the residual RMS norm all-reduce and return the final outputs.
        switch (mOp)
        {
        case AllReduceFusionOp::RESIDUAL_RMS_NORM_QUANT_FP8:
        {
            auto [quant_out, scale_out] = torch_ext::symmetric_static_quantize_per_tensor(norm_out, scale.value());
            return {quant_out, reduce_output};
        }
        case AllReduceFusionOp::RESIDUAL_RMS_NORM_QUANT_NVFP4:
        {
            auto [quant_out, scale_out]
                = torch_ext::fp4_quantize(norm_out, scale.value(), sf_vecsize, sf_use_ue8m0, is_sf_swizzled_layout);
            return {quant_out, scale_out, reduce_output};
        }
        case AllReduceFusionOp::RESIDUAL_RMS_NORM_OUT_QUANT_FP8:
        {
            auto [quant_out, scale_out] = torch_ext::symmetric_static_quantize_per_tensor(norm_out, scale.value());
            return {norm_out, quant_out, reduce_output};
        }
        case AllReduceFusionOp::RESIDUAL_RMS_NORM_OUT_QUANT_NVFP4:
        {
            auto [quant_out, scale_out]
                = torch_ext::fp4_quantize(norm_out, scale.value(), sf_vecsize, sf_use_ue8m0, is_sf_swizzled_layout);
            return {norm_out, quant_out, scale_out, reduce_output};
        }
        default: break;
        }
        TORCH_CHECK(false, "Unsupported fusion operation: " + tensorrt_llm::kernels::toString(mOp));
        return {};
    }

    void initGroupTopology()
    {
        static std::map<std::set<int>, std::tuple<bool, bool, bool>> cache;
        if (cache.find(mGroup) != cache.end())
        {
            auto [is_NVLINK_supported, is_P2P_supported, is_MNNVL_supported] = cache[mGroup];
            mIsNVLINKSupported = is_NVLINK_supported;
            mIsP2PSupported = is_P2P_supported;
            mIsMNNVLSupported = is_MNNVL_supported;
            return;
        }
        setGroupTopology();
        cache[mGroup] = {mIsNVLINKSupported, mIsP2PSupported, mIsMNNVLSupported};
    }

    bool checkMNNVLSupport(int device_id)
    {
#if ENABLE_MULTI_DEVICE
        // 1. Check the CUDA driver version (needs >= 12.1, i.e. a reported version of at least 12010).
        int cuda_driver_version = -1;
        TLLM_CUDA_CHECK(cudaDriverGetVersion(&cuda_driver_version));
        if (cuda_driver_version < 12010)
        {
            TLLM_LOG_DEBUG("MNNVL check: CUDA Driver version %d < 12010", cuda_driver_version);
            return false;
        }

        // 2. Check multicast support.
        CUdevice cu_device;
        TLLM_CU_CHECK(cuDeviceGet(&cu_device, device_id));
        auto cuda_driver = tensorrt_llm::common::CUDADriverWrapper::getInstance();
        int multicast_supported = 0;
        TLLM_CU_CHECK(cuda_driver->cuDeviceGetAttribute(
            &multicast_supported, CU_DEVICE_ATTRIBUTE_MULTICAST_SUPPORTED, cu_device));
        if (!multicast_supported)
        {
            TLLM_LOG_DEBUG("MNNVL check: Device %d does not support multicast", device_id);
            return false;
        }

        // 3. Check fabric handle support.
        int fabric_handle_supported = 0;
        TLLM_CU_CHECK(cuda_driver->cuDeviceGetAttribute(
            &fabric_handle_supported, CU_DEVICE_ATTRIBUTE_HANDLE_TYPE_FABRIC_SUPPORTED, cu_device));
        if (!fabric_handle_supported)
        {
            TLLM_LOG_DEBUG("MNNVL check: Device %d does not support fabric handles", device_id);
            return false;
        }

        // 4. Check the NVML GPU fabric info.
        nvmlDevice_t nvml_device;
        NVML_CHECK_THROW(nvmlDeviceGetHandleByIndex(device_id, &nvml_device));

        nvmlGpuFabricInfo_t fabric_info;
        NVML_CHECK_THROW(nvmlDeviceGetGpuFabricInfo(nvml_device, &fabric_info));

        // Check whether the fabric is fully initialized.
        if (fabric_info.state != NVML_GPU_FABRIC_STATE_COMPLETED || fabric_info.status != NVML_SUCCESS)
        {
            TLLM_LOG_DEBUG(
                "MNNVL check: Fabric state not complete - state=%u status=%u", fabric_info.state, fabric_info.status);
            return false;
        }
        // 5. Check that the NVLink links are active (similar to the Python support_nvlink(True) check).
        unsigned int active_links = 0;
        unsigned int available_links = 0;
        for (unsigned int link = 0; link < NVML_NVLINK_MAX_LINKS; link++)
        {
            unsigned int cap_p2p = 0;
            nvmlReturn_t cap_result
                = nvmlDeviceGetNvLinkCapability(nvml_device, link, NVML_NVLINK_CAP_P2P_SUPPORTED, &cap_p2p);
            if (cap_result == NVML_SUCCESS && cap_p2p)
            {
                available_links++;
                nvmlEnableState_t link_state;
                if (nvmlDeviceGetNvLinkState(nvml_device, link, &link_state) == NVML_SUCCESS
                    && link_state == NVML_FEATURE_ENABLED)
                {
                    active_links++;
                }
            }
        }
        bool all_links_up = (active_links == available_links && available_links > 0);
        if (!all_links_up)
        {
            TLLM_LOG_DEBUG(
                "MNNVL check: Not all NVLink links active - active=%u available=%u", active_links, available_links);
            return false;
        }

        TLLM_LOG_INFO("MNNVL check: Device %d supports MNNVL (fabric_clique=%u)", device_id, fabric_info.cliqueId);
        return true;
#else
        return false;
#endif
    }

    void setGroupTopology()
    {
        auto const rank = getRank();
        TLLM_LOG_INFO("Detecting local TP group for rank %d", rank);

        std::set<int> local_group = std::visit(
            overloaded{[&](std::shared_ptr<ncclComm_t>&) { return getLocalGroup(mGroup); },
                [&](c10::intrusive_ptr<c10d::ProcessGroup>& torchPg) { return getLocalGroupTorch(mGroup); }},
            mNcclComm);

        bool is_inter_node = (mGroup.size() != local_group.size());

        NvmlManager nvml_manager;
        mIsP2PSupported = true;
        mIsNVLINKSupported = true;
        mIsMNNVLSupported = false;

        // First, check NVLink within the local group (intra-node).
        if (!local_group.empty())
        {
            for (int first_device_id : local_group)
            {
                for (int second_device_id : local_group)
                {
                    if (first_device_id >= second_device_id)
                    {
                        continue;
                    }

                    int can_access_peer = 0;
                    TLLM_CUDA_CHECK(cudaDeviceCanAccessPeer(&can_access_peer, first_device_id, second_device_id));
                    if (!can_access_peer)
                    {
                        mIsP2PSupported = false;
                        mIsNVLINKSupported = false;
                        TLLM_LOG_INFO(
                            "P2P not supported between local devices %d and %d", first_device_id, second_device_id);
                        // Continue checking other pairs, but mark as not supported.
                        continue;
                    }

                    nvmlDevice_t first_device;
                    NVML_CHECK_THROW(nvmlDeviceGetHandleByIndex(first_device_id, &first_device));

                    bool is_NVLINK = false;
                    for (unsigned int link = 0; link < NVML_NVLINK_MAX_LINKS; link++)
                    {
                        nvmlPciInfo_t remote_pci_info;
                        if (nvmlDeviceGetNvLinkRemotePciInfo_v2(first_device, link, &remote_pci_info) != NVML_SUCCESS)
                        {
                            continue;
                        }

                        nvmlDevice_t remote_device;
                        auto const result = nvmlDeviceGetHandleByPciBusId_v2(remote_pci_info.busId, &remote_device);

                        if (result == NVML_SUCCESS)
                        {
                            // The two GPUs are connected directly through NVLink.
                            unsigned int remote_device_id;
                            NVML_CHECK_THROW(nvmlDeviceGetIndex(remote_device, &remote_device_id));

                            if (remote_device_id == static_cast<unsigned int>(second_device_id))
                            {
                                is_NVLINK = true;
                            }
                        }
                        else if (result == NVML_ERROR_NOT_FOUND)
                        {
                            // The two GPUs may be connected via an NVSwitch. In that case remote_pci_info holds
                            // the PCI information of the NVSwitch, so determine whether NVLink is supported by
                            // checking whether the two GPUs are connected to the same NVSwitch.
                            nvmlDevice_t second_device;
                            NVML_CHECK_THROW(nvmlDeviceGetHandleByIndex(second_device_id, &second_device));

                            for (unsigned int second_link = 0; second_link < NVML_NVLINK_MAX_LINKS; second_link++)
                            {
                                nvmlPciInfo_t second_remote_pci_info;
                                if (nvmlDeviceGetNvLinkRemotePciInfo_v2(
                                        second_device, second_link, &second_remote_pci_info)
                                    != NVML_SUCCESS)
                                {
                                    continue;
                                }

                                if (strcmp(remote_pci_info.busId, second_remote_pci_info.busId) == 0)
                                {
                                    is_NVLINK = true;
                                    break;
                                }
                            }
                        }
                        else
                        {
                            NVML_CHECK_THROW(result);
                        }

                        if (is_NVLINK)
                        {
                            break;
                        }
                    }
                    mIsNVLINKSupported &= is_NVLINK;
                }
            }
        }

        // For inter-node groups, check MNNVL support.
        if (is_inter_node)
        {
            TLLM_LOG_INFO("Found inter-node TP group for rank %d, checking MNNVL support", rank);

            // Check MNNVL support on the local device(s).
            bool local_mnnvl_supported = false;
            if (!local_group.empty())
            {
                // Check MNNVL on the first device in the local group
                // (all devices on the same node should have the same MNNVL status).
                int check_device = *local_group.begin();
                local_mnnvl_supported = checkMNNVLSupport(check_device);
            }

            // Gather the MNNVL status from all ranks in the group.
            int local_mnnvl_status = local_mnnvl_supported ? 1 : 0;
            std::vector<int> all_mnnvl_status(mGroup.size());

            std::visit(overloaded{[&](std::shared_ptr<ncclComm_t>& comm_ptr)
                           {
                               // For the raw NCCL comm, use MPI to gather the status:
                               // create a sub-communicator for the group and allgather over it.
                               std::vector<int> group_ranks(mGroup.begin(), mGroup.end());
                               MPI_Group world_group, new_group;
                               MPI_Comm group_comm;
                               MPI_Comm_group(COMM_SESSION, &world_group);
                               MPI_Group_incl(world_group, group_ranks.size(), group_ranks.data(), &new_group);
                               MPI_Comm_create_group(COMM_SESSION, new_group, 0, &group_comm);

                               if (group_comm != MPI_COMM_NULL)
                               {
                                   MPI_Allgather(
                                       &local_mnnvl_status, 1, MPI_INT, all_mnnvl_status.data(), 1, MPI_INT, group_comm);
                                   MPI_Comm_free(&group_comm);
                               }
                               MPI_Group_free(&new_group);
                               MPI_Group_free(&world_group);
                           },
                           [&](c10::intrusive_ptr<c10d::ProcessGroup>& torchPg)
                           {
                               // For a ProcessGroup, use allgather directly.
                               // Note: this assumes the ProcessGroup is already set up for the correct group.
                               std::vector<torch::Tensor> input_tensors
                                   = {torch::tensor({local_mnnvl_status}, torch::kInt32)};
                               std::vector<std::vector<torch::Tensor>> output_tensors(1);
                               output_tensors[0].resize(mGroup.size());
                               auto work = torchPg->allgather(output_tensors, input_tensors);
                               if (work)
                               {
                                   work->wait();
                                   for (size_t i = 0; i < mGroup.size(); ++i)
                                   {
                                       all_mnnvl_status[i] = output_tensors[0][i].item<int>();
                                   }
                               }
                           }},
                mNcclComm);

            // Check whether all ranks support MNNVL.
            bool all_ranks_support_mnnvl = true;
            for (int status : all_mnnvl_status)
            {
                if (status == 0)
                {
                    all_ranks_support_mnnvl = false;
                    break;
                }
            }

            // For inter-node groups, MNNVL support means all nodes have MNNVL.
            // Local NVLink is also needed for optimal performance.
            mIsMNNVLSupported = mIsNVLINKSupported && all_ranks_support_mnnvl;
            mIsP2PSupported = false; // P2P doesn't work across nodes

            TLLM_LOG_INFO("Inter-node topology: local_NVLink=%d, local_MNNVL=%d, all_ranks_MNNVL=%d, final_MNNVL=%d",
                mIsNVLINKSupported ? 1 : 0, local_mnnvl_status, all_ranks_support_mnnvl ? 1 : 0,
                mIsMNNVLSupported ? 1 : 0);
        }
        else
        {
            TLLM_LOG_INFO("TP group is intra-node for rank %d", rank);
        }
    }
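    // How the detected topology feeds into strategy selection below (descriptive summary only; the
    // authoritative logic is selectImplementation/ifFallbackToNCCL):
    //   - !mIsP2PSupported or !mIsNVLINKSupported  -> the custom fusion kernels are skipped, NCCL_SYMMETRIC is used.
    //   - message larger than the custom-allreduce workspace -> NCCL_SYMMETRIC as well.
    //   - otherwise ONESHOT / TWOSHOT / MIN_LATENCY (Lamport fusion kernels) are eligible.
    //   - mIsMNNVLSupported currently only widens the NCCL window-registration path for multi-node NVLink.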
    AllReduceStrategyType selectImplementation(size_t seq_len, size_t hidden_size)
    {
        if (mStrategy != AllReduceStrategyType::AUTO)
        {
            // For UB, NCCL and NCCL_SYMMETRIC, the correctness of the strategy dispatching is guaranteed by the user.
            if (mStrategy == AllReduceStrategyType::UB || mStrategy == AllReduceStrategyType::NCCL
                || mStrategy == AllReduceStrategyType::NCCL_SYMMETRIC)
            {
                return mStrategy;
            }
        }

        // For ONESHOT, TWOSHOT and LOWPRECISION, a fallback is allowed.
        auto const message_size = seq_len * hidden_size;

        // Check whether LOWPRECISION is supported.
        if (isUsingLowPrecision(message_size))
        {
            return AllReduceStrategyType::LOWPRECISION;
        }

        auto const message_size_bytes = message_size * tensorrt_llm::common::getDTypeSize(mType);
        auto const max_workspace_size
            = tensorrt_llm::utils::customAllReduceUtils::getMaxRequiredWorkspaceSize(mGroup.size());

        if (ifFallbackToNCCL(seq_len, message_size_bytes, max_workspace_size))
        {
            return AllReduceStrategyType::NCCL_SYMMETRIC;
        }

        // This rule-based heuristic only chooses between the NCCL_SYMMETRIC and MIN_LATENCY strategies.
        // From this point on, all fusion patterns are supported by all of these strategies: NCCL_SYMMETRIC,
        // ONESHOT, TWOSHOT and MIN_LATENCY.
        if (mStrategy != AllReduceStrategyType::AUTO)
        {
            // Check the TWOSHOT constraint: seq_len >= tp_size.
            if (mStrategy == AllReduceStrategyType::TWOSHOT && seq_len < mGroup.size())
            {
                TLLM_LOG_WARNING("TWOSHOT strategy requires seq_len >= tp_size (%zu < %zu), falling back to ONESHOT",
                    seq_len, mGroup.size());
                return AllReduceStrategyType::ONESHOT;
            }
            return mStrategy;
        }
        else
        {
            return tensorrt_llm::utils::customAllReduceUtils::selectStrategyLookUpTable(
                seq_len, hidden_size, mOp, mGroup.size());
        }
    }

    bool ifFallbackToNCCL(size_t seq_len, size_t message_size_bytes, size_t max_workspace_size)
    {
        // If the message size is greater than maxWorkspaceSize or the topology is unsuitable, use the
        // NCCL_SYMMETRIC fallback.
        if (message_size_bytes > max_workspace_size || !mIsP2PSupported || !mIsNVLINKSupported)
        {
            return true;
        }
        return false;
    }

    bool isUsingLowPrecision(size_t message_size) const noexcept
    {
        bool force_low_precision = mStrategy == AllReduceStrategyType::LOWPRECISION;
#ifdef ENABLE_FP8
        // Use LOWPRECISION only when P2P (but not NVLink) is available and the message size is larger than 2MB.
        constexpr int LowPrecisionMinMessageSize = 2 * 1024 * 1024;
        return force_low_precision && !mIsNVLINKSupported && mIsP2PSupported
            && message_size >= LowPrecisionMinMessageSize;
#else
        // Low precision is not available when FP8 is not enabled.
        return false;
#endif
    }

private:
    std::set<int> mGroup;
    bool mIsNVLINKSupported;
    bool mIsP2PSupported;
    bool mIsMNNVLSupported;
    nvinfer1::DataType mType;
    AllReduceStrategyType mStrategy;
    AllReduceFusionOp mOp;
    float mEps;
    std::variant<std::shared_ptr<ncclComm_t>, c10::intrusive_ptr<c10d::ProcessGroup>> mNcclComm;
};

} // namespace

#endif // ENABLE_MULTI_DEVICE

std::vector<torch::Tensor> allreduce_raw(torch::Tensor const& input, torch::optional<torch::Tensor> const& residual,
    torch::optional<torch::Tensor> const& norm_weight, torch::optional<torch::Tensor> const& scale,
    torch::optional<torch::Tensor> const& bias, torch::optional<torch::Tensor> workspace,
    torch::List<int64_t> const& group_, int64_t const strategy_, int64_t const fusion_op_, double const eps_,
    bool const trigger_completion_at_end_)
{
#if ENABLE_MULTI_DEVICE
    auto const dtype = tensorrt_llm::runtime::TorchUtils::dataType(input.scalar_type());
    auto const strategy = static_cast<AllReduceStrategyType>(int8_t(strategy_));
    auto const fusion_op = static_cast<AllReduceFusionOp>(int8_t(fusion_op_));
    float const eps = eps_;
    std::set<int> group;
    for (int64_t rank : group_)
    {
        group.insert(static_cast<int>(rank));
    }
    AllreduceOp op(group, dtype, strategy, fusion_op, eps);
    op.initialize();
    return op.run(input, residual, norm_weight, scale, bias, trigger_completion_at_end_, workspace);
#else
    return {input};
#endif // ENABLE_MULTI_DEVICE
}
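// allreduce_pg below resolves the caller's world rank to its index within the sorted `group` and checks that
// this index matches pg->getRank(). Illustrative example (hypothetical values, not from a real config):
// group = [2, 5, 7] and rank = 5 yields nccl_rank = 1, so the ProcessGroup passed in must also report rank 1.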
std::vector<torch::Tensor> allreduce_pg(torch::Tensor const& input, torch::optional<torch::Tensor> const& residual,
    torch::optional<torch::Tensor> const& norm_weight, torch::optional<torch::Tensor> const& scale,
    torch::optional<torch::Tensor> const& bias, torch::optional<torch::Tensor> const& workspace,
    torch::List<int64_t> const& group_, int64_t rank, c10::intrusive_ptr<c10d::ProcessGroup> const& pg,
    int64_t const strategy_, int64_t const fusion_op_, double const eps_, bool const trigger_completion_at_end_)
{
#if ENABLE_MULTI_DEVICE
    auto const dtype = tensorrt_llm::runtime::TorchUtils::dataType(input.scalar_type());
    auto const strategy = static_cast<AllReduceStrategyType>(int8_t(strategy_));
    auto const fusion_op = static_cast<AllReduceFusionOp>(int8_t(fusion_op_));
    float const eps = eps_;
    std::set<int> group;
    for (int64_t my_rank : group_)
    {
        group.insert(static_cast<int>(my_rank));
    }

    // Get the NCCL rank for this process within the group.
    auto it = group.find(rank);
    if (it == group.end())
    {
        throw std::runtime_error("Rank not found in group");
    }
    int nccl_rank = std::distance(group.begin(), it);
    if (nccl_rank != pg->getRank())
    {
        throw std::runtime_error("nccl_rank != pg->getRank()");
    }

    AllreduceOp op(group, pg, dtype, strategy, fusion_op, eps);
    op.initialize();
    auto ret = op.run(input, residual, norm_weight, scale, bias, trigger_completion_at_end_, workspace);
    return ret;
#else
    return {input};
#endif // ENABLE_MULTI_DEVICE
}

// residual [m, hidden_dim]
// norm_weight [hidden_dim]
// device_num_experts [1]
// scale_input [global_num_experts, m]
// active_experts_token_input [device_num_experts, m, hidden_dim]
// token_input [m, hidden_dim]
std::vector<torch::Tensor> moe_allreduce(torch::Tensor const& residual, torch::Tensor const& norm_weight,
    torch::Tensor const& device_num_experts, torch::Tensor const& scale_input,
    torch::Tensor const& active_experts_token_input, torch::Tensor const& token_input, torch::Tensor workspace,
    int64_t const rank, int64_t const nranks, double const eps)
{
    auto allreduce_fusion_params = tensorrt_llm::kernels::ar_fusion::moe::MoeReductionAllReduceFusionParams();

    allreduce_fusion_params.quant_out = nullptr;
    allreduce_fusion_params.scale_out = nullptr;
    allreduce_fusion_params.residual_out = nullptr;
    allreduce_fusion_params.norm_out = nullptr;

    allreduce_fusion_params.nranks = static_cast<int>(nranks);
    allreduce_fusion_params.rank = static_cast<int>(rank);
    allreduce_fusion_params.dtype = tensorrt_llm::runtime::TorchUtils::dataType(token_input.scalar_type());
    // size: num_token * hidden_dim
    allreduce_fusion_params.size = static_cast<int>(token_input.numel());
    allreduce_fusion_params.hidden_dim = static_cast<int>(active_experts_token_input.size(-1));

    // workspace: AR scratch space
    allreduce_fusion_params.workspace = reinterpret_cast<void**>(workspace.mutable_data_ptr());
    allreduce_fusion_params.rms_gamma = norm_weight.data_ptr();
    allreduce_fusion_params.rms_eps = static_cast<float>(eps);
    allreduce_fusion_params.stream = at::cuda::getCurrentCUDAStream(norm_weight.get_device());
    allreduce_fusion_params.residual_in = residual.data_ptr();

    // MoE reduction specific params
    allreduce_fusion_params.allreduce_in = nullptr; // for safety, set nullptr
    allreduce_fusion_params.moe_reduction_device_num_experts = static_cast<int*>(device_num_experts.data_ptr());
    allreduce_fusion_params.moe_reduction_scale_input = static_cast<float*>(scale_input.data_ptr());
    allreduce_fusion_params.moe_reduction_active_experts_token_input = active_experts_token_input.data_ptr();
    allreduce_fusion_params.moe_reduction_token_input = token_input.data_ptr();

    // output tensors
    torch::Tensor norm_out = torch::empty_like(token_input);
    torch::Tensor residual_out = torch::empty_like(residual);

    allreduce_fusion_params.norm_out = norm_out.mutable_data_ptr();
    allreduce_fusion_params.residual_out = residual_out.mutable_data_ptr();
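    // Roughly what the fused kernel computes, in torch terms (illustrative sketch inferred from the parameter
    // names above, not a specification of the kernel):
    //   expert_out   = sum over local experts e of (scale_input[e, :, None] * active_experts_token_input[e])
    //   allreduced   = allreduce(expert_out + token_input)          # summed across the nranks ranks
    //   residual_out = allreduced + residual
    //   norm_out     = rmsnorm(residual_out, norm_weight, eps)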
    tensorrt_llm::kernels::ar_fusion::moe::moereduction_allreduce_fusion_op(allreduce_fusion_params);

    return {norm_out, residual_out};
}

// Pattern: MoE Reduction + Add + AR + ADD_RMS, see this torch reference implementation:
// expert_reduction = finalize(input, expanded_idx_to_permuted_idx, expert_scale_factor)
// output_add = expert_reduction + shared_expert_output
// output_residual = output_add + residual
// output_hidden_states = rms_norm(output_residual, norm_weight, eps)
//
// Note:
// input is the output of MoE FC2
// input [maxPermutedPaddedCount, hidden_dim]
// residual [m, hidden_dim]
// norm_weight [hidden_dim]
// expanded_idx_to_permuted_idx [m, top_k]
// expert_scale_factor [m, top_k]
// shared_expert_output [m, hidden_dim]
std::vector<torch::Tensor> moe_finalize_allreduce(torch::Tensor const& input, torch::Tensor const& residual,
    torch::Tensor const& norm_weight, torch::Tensor const& expanded_idx_to_permuted_idx,
    torch::optional<torch::Tensor> const& shared_expert_output,
    torch::optional<torch::Tensor> const& expert_scale_factor, torch::Tensor workspace, int64_t const rank,
    int64_t const nranks, double const eps)
{
    auto allreduce_fusion_params = tensorrt_llm::kernels::ar_fusion::moe::MoeFinalizeAllReduceFusionParams();

    int hidden_dim = residual.size(-1);
    int top_k = expanded_idx_to_permuted_idx.size(-1);

    allreduce_fusion_params.quant_out = nullptr;
    allreduce_fusion_params.scale_out = nullptr;

    allreduce_fusion_params.nranks = static_cast<int>(nranks);
    allreduce_fusion_params.rank = static_cast<int>(rank);
    allreduce_fusion_params.dtype = tensorrt_llm::runtime::TorchUtils::dataType(input.scalar_type());
    // size: num_token * hidden_dim
    allreduce_fusion_params.size = residual.numel();
    allreduce_fusion_params.hidden_dim = hidden_dim;

    // workspace: AR scratch space
    allreduce_fusion_params.workspace = reinterpret_cast<void**>(workspace.mutable_data_ptr());

    allreduce_fusion_params.rms_gamma = norm_weight.data_ptr();
    allreduce_fusion_params.rms_eps = static_cast<float>(eps);
    allreduce_fusion_params.residual_in = residual.data_ptr();
    allreduce_fusion_params.stream = at::cuda::getCurrentCUDAStream(norm_weight.get_device());

    // MoE finalize specific params
    allreduce_fusion_params.top_k = top_k;
    allreduce_fusion_params.allreduce_in = input.data_ptr();
    allreduce_fusion_params.expert_scale_factor
        = expert_scale_factor.has_value() ? expert_scale_factor.value().data_ptr() : nullptr;
    allreduce_fusion_params.scale_dtype = tensorrt_llm::runtime::TorchUtils::dataType(
        expert_scale_factor.has_value() ? expert_scale_factor.value().scalar_type() : input.scalar_type());
    TORCH_CHECK(
        expanded_idx_to_permuted_idx.scalar_type() == torch::kInt32, "expanded_idx_to_permuted_idx must be int32");
    allreduce_fusion_params.expanded_idx_to_permuted_idx
        = static_cast<int32_t*>(expanded_idx_to_permuted_idx.data_ptr());
    allreduce_fusion_params.shared_expert_output
        = shared_expert_output.has_value() ? shared_expert_output.value().data_ptr() : nullptr;
    // output tensors
    torch::Tensor norm_out = torch::empty_like(residual);
    torch::Tensor residual_out = torch::empty_like(residual);

    allreduce_fusion_params.norm_out = norm_out.mutable_data_ptr();
    allreduce_fusion_params.residual_out = residual_out.mutable_data_ptr();

    tensorrt_llm::kernels::ar_fusion::moe::moefinalize_allreduce_fusion_op(allreduce_fusion_params);

    return {norm_out, residual_out};
}

std::vector<torch::Tensor> mnnvlFusionAllReduce(torch::Tensor& input, torch::optional<torch::Tensor> const& gamma,
    torch::optional<torch::Tensor> const& residual_in, torch::optional<double> epsilon, torch::Tensor& comm_buffer,
    torch::Tensor& buffer_flags, bool rmsnorm_fusion)
{
    auto* mcast_mem = tensorrt_llm::common::findMcastDevMemBuffer(comm_buffer.data_ptr());
    TORCH_CHECK(
        mcast_mem != nullptr, "[mnnvlFusionAllReduce] comm_buffer must be obtained from a mcastBuffer instance.");
    TORCH_CHECK(input.is_contiguous(), "[mnnvlFusionAllReduce] input must be contiguous");

    auto const eltsPerThread = sizeof(float4) / input.itemsize();
    auto const hiddenDim = input.size(1);
    auto const numTokens = input.size(0);

    TORCH_CHECK(hiddenDim % eltsPerThread == 0,
        "[mnnvlFusionAllReduce] Hidden dimension must be divisible by " + std::to_string(eltsPerThread) + ", got "
            + std::to_string(hiddenDim));

    auto const dtype = tensorrt_llm::runtime::TorchUtils::dataType(input.scalar_type());

    torch::Tensor output = torch::empty_like(input);
    torch::Tensor residualOut;

    auto allreduce_params = tensorrt_llm::kernels::mnnvl::AllReduceFusionParams();
    allreduce_params.nRanks = mcast_mem->getWorldSize();
    allreduce_params.rank = mcast_mem->getRank();
    allreduce_params.dType = dtype;
    allreduce_params.numTokens = numTokens;
    allreduce_params.tokenDim = hiddenDim;
    allreduce_params.bufferPtrsDev = reinterpret_cast<void**>(mcast_mem->getBufferPtrsDev());
    allreduce_params.bufferPtrLocal = comm_buffer.mutable_data_ptr();
    allreduce_params.multicastPtr = mcast_mem->getMulticastPtr();
    allreduce_params.bufferFlags = reinterpret_cast<uint32_t*>(buffer_flags.mutable_data_ptr());
    allreduce_params.input = input.const_data_ptr();
    allreduce_params.output = output.mutable_data_ptr();

    if (rmsnorm_fusion)
    {
        TORCH_CHECK(residual_in.has_value() && gamma.has_value() && epsilon.has_value(),
            "[mnnvlFusionAllReduce] residual_in, gamma, and epsilon must be provided for rmsnorm fusion");
        TORCH_CHECK(residual_in.value().is_contiguous(), "[mnnvlFusionAllReduce] residual_in must be contiguous");
        TORCH_CHECK(gamma.value().is_contiguous(), "[mnnvlFusionAllReduce] gamma must be contiguous");
        allreduce_params.residualIn = residual_in.value().const_data_ptr();
        allreduce_params.gamma = gamma.value().const_data_ptr();
        allreduce_params.epsilon = static_cast<float>(epsilon.value());
        allreduce_params.rmsNormFusion = true;
        residualOut = torch::empty_like(residual_in.value());
        allreduce_params.residualOut = residualOut.mutable_data_ptr();
    }
    else
    {
        allreduce_params.rmsNormFusion = false;
    }
    allreduce_params.stream = at::cuda::getCurrentCUDAStream(output.get_device());

    // Threshold to switch between the one-shot and two-shot allreduce kernels.
    // Empirical value: message size * world size.
    constexpr size_t kOneShotSizeThreshold = 16 * 4 * 8192;
    if (numTokens * hiddenDim * allreduce_params.nRanks * input.itemsize() <= kOneShotSizeThreshold)
    {
        tensorrt_llm::kernels::mnnvl::oneshotAllreduceFusionOp(allreduce_params);
    }
    else
    {
        tensorrt_llm::kernels::mnnvl::twoshotAllreduceFusionOp(allreduce_params);
    }

    return {output, residualOut};
}

} // namespace torch_ext
"mnnvl_fusion_allreduce(Tensor input, Tensor? residual, Tensor? gamma, " "float? epsilon, Tensor(a!) comm_buffer, Tensor buffer_flags, bool rmsnorm_fusion) -> " "Tensor[]"); m.def( "allreduce(" "Tensor input," "Tensor? residual," "Tensor? norm_weight," "Tensor? scale," "Tensor? bias," "Tensor? workspace," "int[] group," "int strategy," "int op," "float eps," "bool trigger_completion_at_end) -> Tensor[]"); m.def( "allreduce_pg(" "Tensor input," "Tensor? residual," "Tensor? norm_weight," "Tensor? scale," "Tensor? bias," "Tensor? workspace," "int[] group," "int rank," "__torch__.torch.classes.c10d.ProcessGroup pg," "int strategy," "int op," "float eps," "bool trigger_completion_at_end) -> Tensor[]"); m.def( "moe_allreduce(" "Tensor residual," "Tensor norm_weight," "Tensor device_num_experts," "Tensor scale_input," "Tensor active_experts_token_input," "Tensor token_input," "Tensor workspace," "int rank," "int nranks," "float eps) -> Tensor[]"); m.def("initialize_static_lowprecision_buffers(Tensor workspace, int tp_size) -> Tensor[]"); m.def( "moe_finalize_allreduce(" "Tensor input," "Tensor residual," "Tensor norm_weight," "Tensor expanded_idx_to_permuted_idx," "Tensor? shared_expert_output," "Tensor? expert_scale_factor," "Tensor workspace," "int rank," "int nranks," "float eps) -> Tensor[]"); } TORCH_LIBRARY_IMPL(trtllm, CUDA, m) { m.impl("mnnvl_fusion_allreduce", &torch_ext::mnnvlFusionAllReduce); m.impl("allreduce", &torch_ext::allreduce_raw); m.impl("allreduce_pg", &torch_ext::allreduce_pg); m.impl("moe_allreduce", &torch_ext::moe_allreduce); m.impl("moe_finalize_allreduce", &torch_ext::moe_finalize_allreduce); } TORCH_LIBRARY_IMPL(trtllm, CPU, m) { m.impl("initialize_static_lowprecision_buffers", [](at::Tensor const& workspace, int64_t tp_size) { tensorrt_llm::kernels::initialize_static_lowprecision_buffers( reinterpret_cast(workspace.data_ptr()), (int) tp_size); return std::vector{}; }); }