/*
 * SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: Apache-2.0
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "tensorrt_llm/common/opUtils.h"
#include "tensorrt_llm/runtime/torchUtils.h"
#include "tensorrt_llm/runtime/utils/mpiUtils.h"

#include <ATen/cuda/CUDAContext.h>
#include <torch/extension.h>

#include <cstdint>
#include <memory>
#include <set>
#include <utility>
#include <vector>

#if ENABLE_MULTI_DEVICE
#include <nccl.h>
#endif // ENABLE_MULTI_DEVICE

namespace torch_ext
{

#if ENABLE_MULTI_DEVICE

namespace
{

class AllToAllHelixOp
{
public:
    AllToAllHelixOp(std::set<int> group)
        : mGroup(std::move(group))
    {
    }

    ~AllToAllHelixOp() = default;

    int initialize()
    {
        TLLM_LOG_TRACE("%s start for rank %d", __PRETTY_FUNCTION__, COMM_SESSION.getRank());
        mNcclComm = getComm(mGroup);
        TLLM_CHECK_WITH_INFO(mGroup.size() > 0, "group size should be greater than 0");
        TLLM_LOG_TRACE("%s stop for rank %d", __PRETTY_FUNCTION__, COMM_SESSION.getRank());
        return 0;
    }

    std::vector<torch::Tensor> run(torch::TensorList input_list, torch::optional<int64_t> num_lists)
    {
        TLLM_CHECK_WITH_INFO(mNcclComm.get() != nullptr, "mNcclComm should be initialized before used");
        auto num_lists_ = static_cast<int>(num_lists.value_or(1));
        auto num_ranks = static_cast<int>(mGroup.size());
        // note: ensures that input_list size > 0
        TLLM_CHECK_WITH_INFO(static_cast<int>(input_list.size()) == num_ranks * num_lists_,
            "input_list size should be equal to group size * num_lists");
        std::vector<torch::Tensor> output_list(static_cast<size_t>(num_lists_));
        auto stream = at::cuda::getCurrentCUDAStream(input_list[0].get_device());
        ncclGroupStart();
        for (int il = 0; il < num_lists_; ++il)
        {
            auto off = il * num_ranks;
            // Each output gathers one tensor from every rank: shape [num_ranks, *input_shape].
            auto output_shape = input_list[off].sizes().vec();
            output_shape.insert(output_shape.begin(), num_ranks);
            auto output = torch::empty(output_shape, input_list[off].options());
            output_list[il] = output;
            auto type = tensorrt_llm::runtime::TorchUtils::dataType(input_list[off].scalar_type());
            auto nccl_type = (*getDtypeMap())[type];
            for (int r = 0; r < num_ranks; ++r)
            {
                // Send input_list[off + r] to rank r; receive rank r's contribution into output[r].
                auto const& input = input_list[off + r];
                ncclSend(input.data_ptr(), input.numel(), nccl_type, r, *mNcclComm, stream);
                ncclRecv(output[r].mutable_data_ptr(), output[r].numel(), nccl_type, r, *mNcclComm, stream);
            }
        }
        NCCLCHECK_THROW(ncclGroupEnd());
        return output_list;
    }

private:
    std::set<int> mGroup;
    std::shared_ptr<ncclComm_t> mNcclComm;
};

} // namespace

#endif // ENABLE_MULTI_DEVICE

std::vector<torch::Tensor> alltoall_helix(
    torch::TensorList input_list, torch::List<int64_t> group_, torch::optional<int64_t> num_lists)
{
#if ENABLE_MULTI_DEVICE
    std::set<int> group;
    for (int64_t rank : group_)
    {
        group.insert(static_cast<int>(rank));
    }
    AllToAllHelixOp op(group);
    op.initialize();
    return op.run(input_list, num_lists);
#else
    return {};
#endif // ENABLE_MULTI_DEVICE
}

} // namespace torch_ext

TORCH_LIBRARY_FRAGMENT(trtllm, m)
{
    m.def("alltoall_helix(Tensor[] input_list, int[] group, int? num_lists) -> Tensor[]");
}

TORCH_LIBRARY_IMPL(trtllm, CUDA, m)
{
    m.impl("alltoall_helix", &torch_ext::alltoall_helix);
}
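
// Usage sketch (illustrative, an assumption rather than part of this file): once this
// extension library is loaded into a PyTorch process built with ENABLE_MULTI_DEVICE,
// the op is reachable through the dispatcher as trtllm::alltoall_helix, e.g. from Python:
//
//   outs = torch.ops.trtllm.alltoall_helix(tensors, group, num_lists)
//
// where `group` lists the participating ranks, `tensors` holds len(group) * num_lists
// CUDA tensors (the tensors within one list are expected to share shape and dtype so the
// send/receive sizes match across ranks), and `num_lists` may be None, defaulting to 1.
// Each of the num_lists returned tensors has shape [len(group), *tensors[il * len(group)].shape].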