/*
 * SPDX-FileCopyrightText: Copyright (c) 1993-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: Apache-2.0
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "tensorrt_llm/common/opUtils.h"
#include "tensorrt_llm/kernels/moeCommKernels.h"
#include "tensorrt_llm/kernels/moePrepareKernels.h"
#include "tensorrt_llm/runtime/torchUtils.h"
#include "tensorrt_llm/thop/thUtils.h"

#include <ATen/cuda/CUDAContext.h>
#include <torch/extension.h>

#include <tuple>

namespace torch_ext
{

// Builds the send/recv index mappings for the MoE all-to-all from the allgathered per-token
// target rank ids. Returns (localGatherIndices, sendRankCountCumSum, sendRankLocalIndices,
// recvRankCountCumSum, recvRankLocalIndices, backwardRecvRankLocalIndices).
std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>
moeCommPrepareIndicesOp(torch::Tensor gatheredTargetRankIds, c10::optional<torch::Tensor> realRankTokenCountCumSum,
    int64_t maxTokenCountPerRank, int64_t expertCount, int64_t topK, int64_t epRank, int64_t epSize)
{
    CHECK_INPUT(gatheredTargetRankIds, torch::kInt32);
    TORCH_CHECK(gatheredTargetRankIds.dim() == 2, "gatheredTargetRankIds must be a 2D tensor");
    TORCH_CHECK(gatheredTargetRankIds.size(1) == topK, "gatheredTargetRankIds must have topK columns");

    int const* realRankTokenCountCumSumPtr = nullptr;
    if (realRankTokenCountCumSum.has_value())
    {
        TORCH_CHECK(realRankTokenCountCumSum.value().dim() == 1, "realRankTokenCountCumSum must be a 1D tensor");
        TORCH_CHECK(realRankTokenCountCumSum.value().dtype() == torch::kInt32,
            "realRankTokenCountCumSum must be an int32 tensor");
        TORCH_CHECK(
            realRankTokenCountCumSum.value().size(0) == epSize, "realRankTokenCountCumSum must have epSize elements");
        realRankTokenCountCumSumPtr = realRankTokenCountCumSum.value().data_ptr<int>();
    }
    else
    {
        TORCH_CHECK(gatheredTargetRankIds.size(0) == epSize * maxTokenCountPerRank,
            "gatheredTargetRankIds should have shape (epSize * maxTokenCountPerRank, topK)");
    }
    TORCH_CHECK(maxTokenCountPerRank > 0, "maxTokenCountPerRank must be greater than 0");
    TORCH_CHECK(expertCount > 0, "expertCount must be greater than 0");
    TORCH_CHECK(topK > 0, "topK must be greater than 0");
    TORCH_CHECK(topK <= expertCount, "topK must be less than or equal to expertCount");
    TORCH_CHECK(epRank >= 0 && epRank < epSize, "epRank must be in the range [0, epSize)");

    auto stream = at::cuda::getCurrentCUDAStream();

    int maxSendRanksPerToken = std::max(epSize, topK);
    torch::Tensor localGatherIndices
        = torch::empty({maxTokenCountPerRank * epSize}, gatheredTargetRankIds.options().dtype(torch::kInt32));
    torch::Tensor sendRankCountCumSum = torch::empty({epSize}, gatheredTargetRankIds.options().dtype(torch::kInt32));
    torch::Tensor sendRankLocalIndices = torch::empty(
        {maxTokenCountPerRank * maxSendRanksPerToken}, gatheredTargetRankIds.options().dtype(torch::kInt32));
    torch::Tensor recvRankCountCumSum = torch::empty({epSize}, gatheredTargetRankIds.options().dtype(torch::kInt32));
    torch::Tensor recvRankLocalIndices
        = torch::empty({maxTokenCountPerRank * epSize}, gatheredTargetRankIds.options().dtype(torch::kInt32));
    torch::Tensor backwardRecvRankLocalIndices = torch::empty(
        {maxTokenCountPerRank * maxSendRanksPerToken}, gatheredTargetRankIds.options().dtype(torch::kInt32));

    tensorrt_llm::kernels::MoeExpertParallelInfo expertParallelInfo;
    expertParallelInfo.expertCount = expertCount;
    expertParallelInfo.topK = topK;

    tensorrt_llm::kernels::MoeEpWorldInfo worldInfo = {static_cast<int>(epSize), static_cast<int>(epRank)};
    tensorrt_llm::kernels::moeAllToAllPrepareIndices(worldInfo, expertParallelInfo, maxTokenCountPerRank,
        gatheredTargetRankIds.data_ptr<int>(), realRankTokenCountCumSumPtr, localGatherIndices.data_ptr<int>(),
        sendRankCountCumSum.data_ptr<int>(), sendRankLocalIndices.data_ptr<int>(), recvRankCountCumSum.data_ptr<int>(),
        recvRankLocalIndices.data_ptr<int>(), backwardRecvRankLocalIndices.data_ptr<int>(), stream);

    return std::make_tuple(localGatherIndices, sendRankCountCumSum, sendRankLocalIndices, recvRankCountCumSum,
        recvRankLocalIndices, backwardRecvRankLocalIndices);
}
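// Usage sketch (illustrative only; this helper is hypothetical and not part of the library).
// It shows the tensor shapes the checks above expect when no realRankTokenCountCumSum is
// given: every rank contributes maxTokenCountPerRank rows to the allgathered id tensor.
// All sizes and values are placeholders.
namespace
{
[[maybe_unused]] void moeCommPrepareIndicesExample()
{
    int64_t const epSize = 4, epRank = 0, expertCount = 8, topK = 2, maxTokenCountPerRank = 16;
    auto i32 = torch::dtype(torch::kInt32).device(torch::kCUDA);
    // Allgathered per-token target rank ids: (epSize * maxTokenCountPerRank, topK).
    torch::Tensor gatheredTargetRankIds = torch::randint(0, epSize, {epSize * maxTokenCountPerRank, topK}, i32);
    auto prepared = moeCommPrepareIndicesOp(
        gatheredTargetRankIds, c10::nullopt, maxTokenCountPerRank, expertCount, topK, epRank, epSize);
    // Tuple element 1 is sendRankCountCumSum: cumulative send counts over ranks 0..epSize-1.
    torch::Tensor sendRankCountCumSum = std::get<1>(prepared);
    TORCH_CHECK(sendRankCountCumSum.size(0) == epSize);
}
} // namespace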
// Gathers the expert ids (and optionally the routing scales) of locally received tokens out
// of the allgathered buffers, following the indices typically produced by
// moeCommPrepareIndicesOp.
void moeLocalGatherOp(torch::Tensor recvRankCumSum, torch::Tensor localGatherIndices, torch::Tensor gatheredExpertIds,
    c10::optional<torch::Tensor> gatheredScales, torch::Tensor localExpertIds, c10::optional<torch::Tensor> localScales,
    int64_t maxTokenCountPerRank, int64_t expertCount, int64_t topK, int64_t epRank, int64_t epSize)
{
    CHECK_INPUT(recvRankCumSum, torch::kInt32);
    CHECK_INPUT(localGatherIndices, torch::kInt32);
    CHECK_INPUT(gatheredExpertIds, torch::kInt32);
    CHECK_INPUT(localExpertIds, torch::kInt32);

    TORCH_CHECK(maxTokenCountPerRank > 0, "maxTokenCountPerRank must be greater than 0");
    TORCH_CHECK(expertCount > 0, "expertCount must be greater than 0");
    TORCH_CHECK(topK > 0, "topK must be greater than 0");
    TORCH_CHECK(topK <= expertCount, "topK must be less than or equal to expertCount");
    TORCH_CHECK(epRank >= 0 && epRank < epSize, "epRank must be in the range [0, epSize)");

    TORCH_CHECK(recvRankCumSum.dim() == 1, "recvRankCumSum must be a 1D tensor");
    TORCH_CHECK(recvRankCumSum.size(0) == epSize, "recvRankCumSum must have epSize elements");
    TORCH_CHECK(localGatherIndices.dim() == 1, "localGatherIndices must be a 1D tensor");
    TORCH_CHECK(gatheredExpertIds.dim() == 2, "gatheredExpertIds must be a 2D tensor");
    TORCH_CHECK(localExpertIds.dim() == 2, "localExpertIds must be a 2D tensor");
    TORCH_CHECK(gatheredExpertIds.size(1) == topK, "gatheredExpertIds must have topK columns");
    TORCH_CHECK(localExpertIds.size(1) == topK, "localExpertIds must have topK columns");

    int localMaxTokenCount = static_cast<int>(localGatherIndices.size(0));
    TORCH_CHECK(localExpertIds.size(0) == localMaxTokenCount, "localExpertIds must have localMaxTokenCount rows");

    TORCH_CHECK(gatheredScales.has_value() == localScales.has_value(),
        "gatheredScales and localScales must be both valid or both invalid");
    float const* gatheredScalesPtr = nullptr;
    float* localScalesPtr = nullptr;
    if (gatheredScales.has_value())
    {
        CHECK_INPUT(gatheredScales.value(), torch::kFloat32);
        CHECK_INPUT(localScales.value(), torch::kFloat32);
        TORCH_CHECK(gatheredScales->dim() == 2, "gatheredScales must be a 2D tensor");
        TORCH_CHECK(gatheredScales->size(1) == topK, "gatheredScales must have topK columns");
        TORCH_CHECK(localScales->dim() == 2, "localScales must be a 2D tensor");
        TORCH_CHECK(localScales->size(1) == topK, "localScales must have topK columns");
        TORCH_CHECK(localScales->size(0) == localMaxTokenCount, "localScales must have localMaxTokenCount rows");

        gatheredScalesPtr = gatheredScales->data_ptr<float>();
        localScalesPtr = localScales->data_ptr<float>();
    }

    auto stream = at::cuda::getCurrentCUDAStream();
    tensorrt_llm::kernels::MoeExpertParallelInfo expertParallelInfo;
    expertParallelInfo.expertCount = expertCount;
    expertParallelInfo.topK = topK;

    tensorrt_llm::kernels::MoeEpWorldInfo worldInfo = {static_cast<int>(epSize), static_cast<int>(epRank)};
    tensorrt_llm::kernels::moeLocalGather(worldInfo, expertParallelInfo, maxTokenCountPerRank, localMaxTokenCount,
        recvRankCumSum.data_ptr<int>(), localGatherIndices.data_ptr<int>(), gatheredExpertIds.data_ptr<int>(),
        gatheredScalesPtr, localExpertIds.data_ptr<int>(), localScalesPtr, stream);
}
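// Usage sketch (illustrative only; hypothetical helper with placeholder data). The cumsum
// and gather indices would normally come from moeCommPrepareIndicesOp; here they are made
// self-consistent by hand: 4 tokens received from each of 4 ranks, 16 local rows total.
namespace
{
[[maybe_unused]] void moeLocalGatherExample()
{
    int64_t const epSize = 4, epRank = 0, expertCount = 8, topK = 2, maxTokenCountPerRank = 16;
    int64_t const localTokenCount = 16;
    auto i32 = torch::dtype(torch::kInt32).device(torch::kCUDA);
    torch::Tensor recvRankCumSum = torch::arange(1, epSize + 1, i32) * 4;   // {4, 8, 12, 16}
    torch::Tensor localGatherIndices = torch::arange(localTokenCount, i32); // rows to pull from gathered buffers
    torch::Tensor gatheredExpertIds = torch::zeros({epSize * maxTokenCountPerRank, topK}, i32);
    torch::Tensor localExpertIds = torch::empty({localTokenCount, topK}, i32); // output buffer
    moeLocalGatherOp(recvRankCumSum, localGatherIndices, gatheredExpertIds, c10::nullopt, localExpertIds,
        c10::nullopt, maxTokenCountPerRank, expertCount, topK, epRank, epSize);
}
} // namespace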
// Runs the MoE all-to-all payload exchange on the current stream: rows of `input` selected
// by sendIndices go to peer ranks according to sendRankCumSum, and incoming rows land in
// `output` at the positions given by recvIndices. allWorkspaces holds one communication
// workspace row per rank.
void moeCommOp(torch::Tensor input, torch::Tensor sendRankCumSum, torch::Tensor sendIndices, torch::Tensor output,
    torch::Tensor recvRankCumSum, torch::Tensor recvIndices, torch::Tensor allWorkspaces, int64_t epRank,
    int64_t epSize)
{
    CHECK_INPUT(sendRankCumSum, torch::kInt32);
    CHECK_INPUT(sendIndices, torch::kInt32);
    CHECK_INPUT(recvRankCumSum, torch::kInt32);
    CHECK_INPUT(recvIndices, torch::kInt32);
    // allWorkspaces is a uint64 tensor, but may not be contiguous
    TORCH_CHECK(allWorkspaces.dtype() == torch::kUInt64, "allWorkspaces must be a uint64 tensor");

    TORCH_CHECK(input.dim() == 2, "input must be a 2D tensor");
    TORCH_CHECK(output.dim() == 2, "output must be a 2D tensor");
    TORCH_CHECK(sendRankCumSum.dim() == 1, "sendRankCumSum must be a 1D tensor");
    TORCH_CHECK(sendIndices.dim() == 1, "sendIndices must be a 1D tensor");
    TORCH_CHECK(recvRankCumSum.dim() == 1, "recvRankCumSum must be a 1D tensor");
    TORCH_CHECK(recvIndices.dim() == 1, "recvIndices must be a 1D tensor");
    TORCH_CHECK(allWorkspaces.dim() == 2, "allWorkspaces must be a 2D tensor");

    TORCH_CHECK(input.size(1) == output.size(1), "input and output must have the same second dimension");
    TORCH_CHECK(sendRankCumSum.size(0) == epSize, "sendRankCumSum must have epSize elements");
    TORCH_CHECK(recvRankCumSum.size(0) == epSize, "recvRankCumSum must have epSize elements");
    TORCH_CHECK(allWorkspaces.size(0) == epSize, "allWorkspaces must have epSize rows");

    TORCH_CHECK(epRank >= 0 && epRank < epSize, "epRank must be in the range [0, epSize)");

    tensorrt_llm::kernels::MoeEpWorldInfo worldInfo = {static_cast<int>(epSize), static_cast<int>(epRank)};
    tensorrt_llm::kernels::SendRecvDataInfo sendRecvDataInfo;

    // Payloads are moved as uint64 lanes, two at a time, so each row must be 16-byte aligned.
    size_t eltSize = input.dtype().itemsize();
    size_t eltCountPerU64 = sizeof(uint64_t) / eltSize;
    TORCH_CHECK(input.size(1) % (eltCountPerU64 * 2) == 0, "input.size(1) must be aligned to 16 bytes");
    sendRecvDataInfo.vectorSizeInU64 = input.size(1) / eltCountPerU64;
    sendRecvDataInfo.DoPreCompute();

    tensorrt_llm::kernels::SendRecvDispls sendDispls, recvDispls;
    sendDispls.dataPtr = static_cast<uint64_t*>(input.data_ptr());
    sendDispls.rankCountCumSum = sendRankCumSum.data_ptr<int>();
    sendDispls.rankLocalIndices = sendIndices.data_ptr<int>();
    sendDispls.vectorStrideInU64 = input.stride(0) / eltCountPerU64;

    recvDispls.dataPtr = static_cast<uint64_t*>(output.data_ptr());
    recvDispls.rankCountCumSum = recvRankCumSum.data_ptr<int>();
    recvDispls.rankLocalIndices = recvIndices.data_ptr<int>();
    recvDispls.vectorStrideInU64 = output.stride(0) / eltCountPerU64;

    tensorrt_llm::kernels::MoeCommWorkspace workspace;
    workspace.workspacePtr = allWorkspaces.data_ptr<uint64_t>();
    workspace.rankStrideInU64 = allWorkspaces.stride(0);

    auto stream = at::cuda::getCurrentCUDAStream();
    tensorrt_llm::kernels::moeAllToAll(worldInfo, sendRecvDataInfo, sendDispls, recvDispls, workspace, stream);
}

int64_t getWorkspaceSizePerRank(int64_t epSize)
{
    int epSize32 = static_cast<int>(epSize);
    return tensorrt_llm::kernels::getMoeCommWorkspaceSize(epSize32);
}

void setMaxUsableSmCount(int64_t maxSmCount)
{
    tensorrt_llm::kernels::setMaxUsableSmCount(maxSmCount);
}

int64_t getPrepareWorkspaceSizePerRank(int64_t epSize)
{
    int epSize32 = static_cast<int>(epSize);
    return tensorrt_llm::kernels::moe_prepare::getMoePrepareWorkspaceSize(epSize32);
}
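// Usage sketch (illustrative only; hypothetical helper). A one-rank "loopback" keeps the
// example self-contained; whether the kernel supports it is not guaranteed, and with
// epSize > 1 every rank must run this collectively on workspace buffers its peers can
// access. Assumes the per-rank workspace size is in bytes and divisible by 8 so it can be
// viewed as uint64 lanes.
namespace
{
[[maybe_unused]] void moeCommExample()
{
    int64_t const epSize = 1, epRank = 0, tokenCount = 16, hiddenSize = 1024;
    torch::Tensor allWorkspaces = torch::zeros(
        {epSize, getWorkspaceSizePerRank(epSize) / 8}, torch::dtype(torch::kUInt64).device(torch::kCUDA));
    // For fp16 payloads eltCountPerU64 = 8 / 2 = 4, so hiddenSize must be a multiple of
    // 8 elements (16 bytes); 1024 qualifies.
    auto half = torch::dtype(torch::kFloat16).device(torch::kCUDA);
    auto i32 = torch::dtype(torch::kInt32).device(torch::kCUDA);
    torch::Tensor input = torch::randn({tokenCount, hiddenSize}, half);
    torch::Tensor output = torch::empty({tokenCount, hiddenSize}, half);
    torch::Tensor cumSum = torch::full({epSize}, tokenCount, i32); // all 16 tokens stay on rank 0
    torch::Tensor indices = torch::arange(tokenCount, i32);
    moeCommOp(input, cumSum, indices, output, cumSum.clone(), indices.clone(), allWorkspaces, epRank, epSize);
}
} // namespace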
// Computes send/recv counts and index mappings directly from the local top-k expert ids
// (no allgather needed) and exchanges the routing metadata (expert ids, optional scales,
// optional expert statistics) across the EP ranks.
std::tuple<torch::Tensor, c10::optional<torch::Tensor>, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor,
    torch::Tensor, c10::optional<torch::Tensor>>
moePrepareOp(torch::Tensor expertsIds, c10::optional<torch::Tensor> scales,
    c10::optional<torch::Tensor> expertsStatics, torch::Tensor allWorkspaces, int64_t maxTokenCountPerRank,
    int64_t epRank, int64_t epSize, int64_t expertCount, int64_t slotCount, int64_t topK)
{
    CHECK_INPUT(expertsIds, torch::kInt32);
    TORCH_CHECK(expertCount % 4 == 0, "expertCount must be divisible by 4");
    TORCH_CHECK(slotCount % 4 == 0, "slotCount must be divisible by 4");

    int64_t maxSendRanksPerToken = std::max(epSize, topK);
    int64_t tokenCount = expertsIds.size(0);

    torch::Tensor preparedLocalExpertIds
        = torch::empty({maxTokenCountPerRank * epSize, topK}, expertsIds.options().dtype(torch::kInt32));

    torch::Tensor sendRankCountCumSum = torch::empty({epSize}, expertsIds.options().dtype(torch::kInt32));
    torch::Tensor recvRankCountCumSum = torch::empty({epSize}, expertsIds.options().dtype(torch::kInt32));

    torch::Tensor gatherRecvRankIndices
        = torch::empty({maxTokenCountPerRank * epSize}, expertsIds.options().dtype(torch::kInt32));
    torch::Tensor recvRankIndices
        = torch::empty({maxTokenCountPerRank * epSize}, expertsIds.options().dtype(torch::kInt32));

    torch::Tensor gatherBackwardRecvRankIndices
        = torch::empty({maxTokenCountPerRank * maxSendRanksPerToken}, expertsIds.options().dtype(torch::kInt32));
    torch::Tensor backwardRecvRankIndices
        = torch::empty({maxTokenCountPerRank * maxSendRanksPerToken}, expertsIds.options().dtype(torch::kInt32));

    torch::Tensor gatherSendRankIndices
        = torch::empty({maxTokenCountPerRank * maxSendRanksPerToken}, expertsIds.options().dtype(torch::kInt32));
    torch::Tensor sendRankIndices
        = torch::empty({maxTokenCountPerRank * maxSendRanksPerToken}, expertsIds.options().dtype(torch::kInt32));

    c10::optional<torch::Tensor> preparedLocalScales;
    float* scalesPtr = nullptr;
    float* preparedLocalScalesPtr = nullptr;
    if (scales.has_value())
    {
        CHECK_INPUT(scales.value(), torch::kFloat32);
        scalesPtr = scales->data_ptr<float>();
        preparedLocalScales
            = torch::empty({maxTokenCountPerRank * epSize, topK}, expertsIds.options().dtype(torch::kFloat32));
        preparedLocalScalesPtr = preparedLocalScales->data_ptr<float>();
    }

    int* localExpertStaticsPtr = nullptr;
    int* gatheredExpertStaticsPtr = nullptr;
    c10::optional<torch::Tensor> gatheredExpertStatics;
    if (expertsStatics.has_value())
    {
        localExpertStaticsPtr = expertsStatics.value().data_ptr<int>();
        gatheredExpertStatics = torch::empty({epSize, expertCount}, expertsIds.options().dtype(torch::kInt32));
        gatheredExpertStaticsPtr = gatheredExpertStatics.value().data_ptr<int>();
    }

    tensorrt_llm::kernels::moe_prepare::MoeCommWorkspace workspace;
    workspace.workspacePtr = allWorkspaces.data_ptr<uint64_t>();
    workspace.rankStrideInU64 = allWorkspaces.stride(0);

    auto stream = at::cuda::getCurrentCUDAStream();

    tensorrt_llm::kernels::moe_prepare::computeCountAndIndice(expertsIds.data_ptr<int>(),
        sendRankCountCumSum.data_ptr<int>(), recvRankCountCumSum.data_ptr<int>(), sendRankIndices.data_ptr<int>(),
        backwardRecvRankIndices.data_ptr<int>(), recvRankIndices.data_ptr<int>(), workspace, tokenCount,
        maxTokenCountPerRank, topK, slotCount, epRank, epSize, stream);

    tensorrt_llm::kernels::moe_prepare::computeCumsum(
        sendRankCountCumSum.data_ptr<int>(), recvRankCountCumSum.data_ptr<int>(), epRank, epSize, stream);

    tensorrt_llm::kernels::moe_prepare::moveIndice(sendRankCountCumSum.data_ptr<int>(),
        recvRankCountCumSum.data_ptr<int>(), sendRankIndices.data_ptr<int>(), gatherSendRankIndices.data_ptr<int>(),
        backwardRecvRankIndices.data_ptr<int>(), gatherBackwardRecvRankIndices.data_ptr<int>(),
        recvRankIndices.data_ptr<int>(), gatherRecvRankIndices.data_ptr<int>(), epRank, epSize, maxTokenCountPerRank,
        stream);

    tensorrt_llm::kernels::moe_prepare::allToAllMetadata(expertsIds.data_ptr<int>(),
        preparedLocalExpertIds.data_ptr<int>(), scalesPtr, preparedLocalScalesPtr, localExpertStaticsPtr,
        gatheredExpertStaticsPtr, workspace, sendRankCountCumSum.data_ptr<int>(), sendRankIndices.data_ptr<int>(),
        recvRankCountCumSum.data_ptr<int>(), recvRankIndices.data_ptr<int>(), tokenCount, maxTokenCountPerRank, topK,
        expertCount, slotCount, epRank, epSize, stream);

    return std::make_tuple(preparedLocalExpertIds, preparedLocalScales, sendRankCountCumSum, gatherSendRankIndices,
        recvRankCountCumSum, gatherRecvRankIndices, gatherBackwardRecvRankIndices, gatheredExpertStatics);
}
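// Usage sketch (illustrative only; hypothetical helper, placeholder sizes). In a real run
// every EP rank calls this collectively on a workspace all peers can access; a single call
// is shown just to illustrate shapes. Both expertCount and slotCount must be multiples of
// 4, and the workspace row count is assumed to be the byte size divided by 8.
namespace
{
[[maybe_unused]] void moePrepareExample()
{
    int64_t const epSize = 4, epRank = 0, expertCount = 8, slotCount = 8, topK = 2, maxTokenCountPerRank = 16;
    auto i32 = torch::dtype(torch::kInt32).device(torch::kCUDA);
    torch::Tensor expertsIds = torch::randint(0, slotCount, {maxTokenCountPerRank, topK}, i32);
    torch::Tensor allWorkspaces = torch::zeros(
        {epSize, getPrepareWorkspaceSizePerRank(epSize) / 8}, torch::dtype(torch::kUInt64).device(torch::kCUDA));
    auto prepared = moePrepareOp(expertsIds, c10::nullopt, c10::nullopt, allWorkspaces, maxTokenCountPerRank,
        epRank, epSize, expertCount, slotCount, topK);
    torch::Tensor preparedLocalExpertIds = std::get<0>(prepared);
    TORCH_CHECK(preparedLocalExpertIds.size(1) == topK);
}
} // namespace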
} // namespace torch_ext

TORCH_LIBRARY_FRAGMENT(trtllm, m)
{
    m.def(
        "moe_comm_prepare_indices(Tensor gathered_target_rank_ids, Tensor? real_rank_token_count_cum_sum, int "
        "max_token_count_per_rank, int expert_count, int top_k, int ep_rank, int ep_size) -> (Tensor, Tensor, Tensor, "
        "Tensor, Tensor, Tensor)");
}

TORCH_LIBRARY_IMPL(trtllm, CUDA, m)
{
    m.impl("moe_comm_prepare_indices", &torch_ext::moeCommPrepareIndicesOp);
}

TORCH_LIBRARY_FRAGMENT(trtllm, m)
{
    m.def(
        "moe_local_gather(Tensor recv_rank_cum_sum, Tensor local_gather_indices, Tensor gathered_expert_ids, Tensor? "
        "gathered_scales, Tensor local_expert_ids, Tensor? local_scales, int max_token_count_per_rank, int "
        "expert_count, int top_k, int ep_rank, int ep_size) -> ()");
}

TORCH_LIBRARY_IMPL(trtllm, CUDA, m)
{
    m.impl("moe_local_gather", &torch_ext::moeLocalGatherOp);
}

TORCH_LIBRARY_FRAGMENT(trtllm, m)
{
    m.def(
        "moe_comm(Tensor input, Tensor send_rank_cum_sum, Tensor send_indices, Tensor output, Tensor "
        "recv_rank_cum_sum, Tensor recv_indices, Tensor all_workspaces, int ep_rank, int ep_size) -> ()");
}

TORCH_LIBRARY_IMPL(trtllm, CUDA, m)
{
    m.impl("moe_comm", &torch_ext::moeCommOp);
}

TORCH_LIBRARY_FRAGMENT(trtllm, m)
{
    m.def("get_moe_commworkspace_size_per_rank(int ep_size) -> int");
}

TORCH_LIBRARY_IMPL(trtllm, CompositeExplicitAutograd, m)
{
    m.impl("get_moe_commworkspace_size_per_rank", &torch_ext::getWorkspaceSizePerRank);
}

TORCH_LIBRARY_FRAGMENT(trtllm, m)
{
    m.def("set_moe_max_usable_sm_count(int max_sm_count) -> ()");
}

TORCH_LIBRARY_IMPL(trtllm, CompositeExplicitAutograd, m)
{
    m.impl("set_moe_max_usable_sm_count", &torch_ext::setMaxUsableSmCount);
}

TORCH_LIBRARY_FRAGMENT(trtllm, m)
{
    m.def(
        "mnnvl_moe_alltoallv_prepare_without_allgather(Tensor experts_ids, Tensor? scales, Tensor? experts_statics, "
        "Tensor allWorkspace, int max_token_count_per_rank, int ep_rank, int ep_size, int expert_count, int "
        "slot_count, int top_k) -> (Tensor, Tensor?, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor?)");
}

TORCH_LIBRARY_IMPL(trtllm, CUDA, m)
{
    m.impl("mnnvl_moe_alltoallv_prepare_without_allgather", &torch_ext::moePrepareOp);
}

TORCH_LIBRARY_FRAGMENT(trtllm, m)
{
    m.def("get_moe_prepare_workspace_size_per_rank(int ep_size) -> int");
}

TORCH_LIBRARY_IMPL(trtllm, CompositeExplicitAutograd, m)
{
    m.impl("get_moe_prepare_workspace_size_per_rank", &torch_ext::getPrepareWorkspaceSizePerRank);
}
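// Typical Python-side sequence (sketch; op names follow the schemas registered above, the
// exact orchestration lives in the calling framework):
//   size = torch.ops.trtllm.get_moe_commworkspace_size_per_rank(ep_size)
//   # allocate all_workspaces so every rank can see each peer's row, then:
//   indices = torch.ops.trtllm.moe_comm_prepare_indices(...)
//   torch.ops.trtllm.moe_local_gather(...)
//   torch.ops.trtllm.moe_comm(...)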