/*
 * SPDX-FileCopyrightText: Copyright (c) 1993-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: Apache-2.0
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "tensorrt_llm/common/opUtils.h"
#include "tensorrt_llm/runtime/torchUtils.h"
#include "tensorrt_llm/runtime/utils/mpiUtils.h"

#include <ATen/cuda/CUDAContext.h>
#include <torch/extension.h>
#if ENABLE_MULTI_DEVICE
#include <nccl.h>
#endif // ENABLE_MULTI_DEVICE
#include <functional>
#include <numeric>
#include <set>
#include <vector>

namespace torch_ext
{

#if ENABLE_MULTI_DEVICE

namespace
{

class ReducescatterOp
{
public:
    ReducescatterOp(std::set<int> group)
        : mGroup(std::move(group))
    {
    }

    ~ReducescatterOp() = default;

    int initialize()
    {
        TLLM_LOG_TRACE("%s start for rank %d", __PRETTY_FUNCTION__, COMM_SESSION.getRank());
        mNcclComm = getComm(mGroup);
        TLLM_LOG_TRACE("%s stop for rank %d", __PRETTY_FUNCTION__, COMM_SESSION.getRank());
        return 0;
    }

    std::vector<torch::Tensor> run_list(torch::TensorList input_list, torch::optional<torch::List<int64_t>> sizes)
    {
        TLLM_CHECK_WITH_INFO(mNcclComm.get() != nullptr, "mNcclComm should be initialized before used");
        // ncclReduceScatter requires equal-sized chunks; fall back to per-rank ncclReduce calls
        // when the requested split sizes are uneven.
        bool use_nccl_reducescatter = !sizes.has_value()
            || std::all_of(sizes.value().begin(), sizes.value().end(),
                [&sizes](int64_t size) { return size == sizes.value()[0]; });
        int groupRank = 0;
        if (sizes.has_value())
        {
            // Map the global rank to its position within the group to pick the right split size.
            auto rank = COMM_SESSION.getRank();
            for (auto const& currentRank : mGroup)
            {
                if (rank == currentRank)
                    break;
                ++groupRank;
            }
            TLLM_CHECK(static_cast<size_t>(groupRank) < mGroup.size());
        }
        std::vector<torch::Tensor> output_list;
        output_list.reserve(input_list.size());
        ncclGroupStart();
        for (auto const& input : input_list)
        {
            auto stream = at::cuda::getCurrentCUDAStream(input.get_device());
            auto type = tensorrt_llm::runtime::TorchUtils::dataType(input.scalar_type());
            std::vector<int64_t> outputShape = input.sizes().vec();
            if (sizes.has_value())
            {
                outputShape[0] = sizes.value()[groupRank];
            }
            else
            {
                outputShape[0] = outputShape[0] / mGroup.size();
            }
            auto output = torch::empty(outputShape, input.options());
            if (use_nccl_reducescatter)
            {
                ncclReduceScatter(input.data_ptr(), output.mutable_data_ptr(), output.numel(),
                    (*getDtypeMap())[type], ncclSum, *mNcclComm, stream);
            }
            else
            {
                // Uneven splits: reduce each rank's slice separately, rooting the result at that rank.
                size_t numel_base
                    = std::accumulate(outputShape.cbegin() + 1, outputShape.cend(), 1, std::multiplies<>{});
                int64_t split_offset = 0;
                for (int root = 0; root < static_cast<int>(mGroup.size()); ++root)
                {
                    auto split_size = sizes.value()[root];
                    ncclReduce(input.index({torch::indexing::Slice(split_offset, torch::indexing::None)}).data_ptr(),
                        output.mutable_data_ptr(), numel_base * split_size, (*getDtypeMap())[type], ncclSum, root,
                        *mNcclComm, stream);
                    split_offset += split_size;
                }
            }
            output_list.push_back(output);
        }
        NCCLCHECK_THROW(ncclGroupEnd());
        return output_list;
    }

    torch::Tensor run(torch::Tensor const& input, torch::optional<torch::List<int64_t>> sizes)
    {
        return run_list({input}, sizes)[0];
    }

private:
    std::set<int> mGroup;
    std::shared_ptr<ncclComm_t> mNcclComm;
};

} // namespace

#endif // ENABLE_MULTI_DEVICE

extern torch::Tensor reducescatter(
    torch::Tensor input, torch::optional<torch::List<int64_t>> sizes, torch::List<int64_t> group_)
{
#if ENABLE_MULTI_DEVICE
    std::set<int> group;
    for (int64_t rank : group_)
    {
        group.insert(static_cast<int>(rank));
    }
    ReducescatterOp op(group);
    op.initialize();
    auto output = op.run(input, sizes);
    return output;
#else
    return input;
#endif // ENABLE_MULTI_DEVICE
}

extern std::vector<torch::Tensor> reducescatter_list(
    torch::TensorList input_list, torch::optional<torch::List<int64_t>> sizes, torch::List<int64_t> group_)
{
#if ENABLE_MULTI_DEVICE
    std::set<int> group;
    for (int64_t rank : group_)
    {
        group.insert(static_cast<int>(rank));
    }
    ReducescatterOp op(group);
    op.initialize();
    auto output_list = op.run_list(input_list, sizes);
    return output_list;
#else
    return input_list.vec();
#endif // ENABLE_MULTI_DEVICE
}

} // namespace torch_ext

TORCH_LIBRARY_FRAGMENT(trtllm, m)
{
    m.def("reducescatter(Tensor input, SymInt[]? sizes, int[] group) -> Tensor");
    m.def("reducescatter_list(Tensor[] input_list, SymInt[]? sizes, int[] group) -> Tensor[]");
}

TORCH_LIBRARY_IMPL(trtllm, CUDA, m)
{
    m.impl("reducescatter", &torch_ext::reducescatter);
    m.impl("reducescatter_list", &torch_ext::reducescatter_list);
}
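// Usage sketch (assumes the extension library containing this fragment has been loaded):
// ops registered under the `trtllm` namespace are reachable through the PyTorch dispatcher,
// e.g. from Python as
//
//   out = torch.ops.trtllm.reducescatter(inp, None, [0, 1, 2, 3])
//
// where `inp` and the group ranks [0, 1, 2, 3] are hypothetical example values. Passing
// None for `sizes` takes the equal-chunk ncclReduceScatter path; a list of per-rank split
// sizes selects the per-rank ncclReduce fallback for uneven splits.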