/*
* SPDX-FileCopyrightText: Copyright (c) 1993-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "tensorrt_llm/common/opUtils.h"
#include "tensorrt_llm/kernels/moeCommKernels.h"
#include "tensorrt_llm/kernels/moePrepareKernels.h"
#include "tensorrt_llm/runtime/torchUtils.h"
#include "tensorrt_llm/thop/thUtils.h"
#include <c10/cuda/CUDAStream.h>
#include <torch/extension.h>
#include <vector>
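
// Torch custom ops for MoE expert-parallel (EP) all-to-all communication:
// index preparation, local gather of expert ids/scales, the payload
// all-to-all itself, and workspace-size / SM-count helpers.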
namespace torch_ext
{
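
// Builds the index mappings used by the MoE all-to-all from the gathered
// per-token target rank ids. Returns (localGatherIndices, sendRankCountCumSum,
// sendRankLocalIndices, recvRankCountCumSum, recvRankLocalIndices,
// backwardRecvRankLocalIndices). When realRankTokenCountCumSum is absent,
// gatheredTargetRankIds must cover epSize * maxTokenCountPerRank rows.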
std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>
moeCommPrepareIndicesOp(torch::Tensor gatheredTargetRankIds, c10::optional<torch::Tensor> realRankTokenCountCumSum,
int64_t maxTokenCountPerRank, int64_t expertCount, int64_t topK, int64_t epRank, int64_t epSize)
{
CHECK_INPUT(gatheredTargetRankIds, torch::kInt32);
TORCH_CHECK(gatheredTargetRankIds.dim() == 2, "gatheredTargetRankIds must be a 2D tensor");
TORCH_CHECK(gatheredTargetRankIds.size(1) == topK, "gatheredTargetRankIds must have topK columns");
int const* realRankTokenCountCumSumPtr = nullptr;
if (realRankTokenCountCumSum.has_value())
{
TORCH_CHECK(realRankTokenCountCumSum.value().dim() == 1, "realRankTokenCountCumSum must be a 1D tensor");
TORCH_CHECK(realRankTokenCountCumSum.value().dtype() == torch::kInt32,
"realRankTokenCountCumSum must be a int32 tensor");
TORCH_CHECK(
realRankTokenCountCumSum.value().size(0) == epSize, "realRankTokenCountCumSum must have epSize elements");
realRankTokenCountCumSumPtr = realRankTokenCountCumSum.value().data_ptr<int>();
}
else
{
TORCH_CHECK(gatheredTargetRankIds.size(0) == epSize * maxTokenCountPerRank,
"gatheredTargetRankIds should have shape (epSize * maxTokenCountPerRank, topK)");
}
TORCH_CHECK(maxTokenCountPerRank > 0, "maxTokenCountPerRank must be greater than 0");
TORCH_CHECK(expertCount > 0, "expertCount must be greater than 0");
TORCH_CHECK(topK > 0, "topK must be greater than 0");
TORCH_CHECK(topK <= expertCount, "topK must be less than or equal to expertCount");
TORCH_CHECK(epRank >= 0 && epRank < epSize, "epRank must be in the range [0, epSize)");
auto stream = at::cuda::getCurrentCUDAStream();
int maxSendRanksPerToken = std::max(epSize, topK);
torch::Tensor localGatherIndices
= torch::empty({maxTokenCountPerRank * epSize}, gatheredTargetRankIds.options().dtype(torch::kInt32));
torch::Tensor sendRankCountCumSum = torch::empty({epSize}, gatheredTargetRankIds.options().dtype(torch::kInt32));
torch::Tensor sendRankLocalIndices = torch::empty(
{maxTokenCountPerRank * maxSendRanksPerToken}, gatheredTargetRankIds.options().dtype(torch::kInt32));
torch::Tensor recvRankCountCumSum = torch::empty({epSize}, gatheredTargetRankIds.options().dtype(torch::kInt32));
torch::Tensor recvRankLocalIndices
= torch::empty({maxTokenCountPerRank * epSize}, gatheredTargetRankIds.options().dtype(torch::kInt32));
torch::Tensor backwardRecvRankLocalIndices = torch::empty(
{maxTokenCountPerRank * maxSendRanksPerToken}, gatheredTargetRankIds.options().dtype(torch::kInt32));
tensorrt_llm::kernels::MoeExpertParallelInfo expertParallelInfo;
expertParallelInfo.expertCount = expertCount;
expertParallelInfo.topK = topK;
tensorrt_llm::kernels::MoeEpWorldInfo worldInfo = {static_cast<int>(epSize), static_cast<int>(epRank)};
tensorrt_llm::kernels::moeAllToAllPrepareIndices(worldInfo, expertParallelInfo, maxTokenCountPerRank,
gatheredTargetRankIds.data_ptr<int>(), realRankTokenCountCumSumPtr, localGatherIndices.data_ptr<int>(),
sendRankCountCumSum.data_ptr<int>(), sendRankLocalIndices.data_ptr<int>(), recvRankCountCumSum.data_ptr<int>(),
recvRankLocalIndices.data_ptr<int>(), backwardRecvRankLocalIndices.data_ptr<int>(), stream);
return std::make_tuple(localGatherIndices, sendRankCountCumSum, sendRankLocalIndices, recvRankCountCumSum,
recvRankLocalIndices, backwardRecvRankLocalIndices);
}
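
// Gathers expert ids and scales for the locally received tokens: rows of
// gatheredExpertIds / gatheredScales selected by localGatherIndices are copied
// into localExpertIds / localScales, using the per-rank receive cumulative sums.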
void moeLocalGatherOp(torch::Tensor recvRankCumSum, torch::Tensor localGatherIndices, torch::Tensor gatheredExpertIds,
torch::Tensor gatheredScales, torch::Tensor localExpertIds, torch::Tensor localScales, int64_t maxTokenCountPerRank,
int64_t expertCount, int64_t topK, int64_t epRank, int64_t epSize)
{
CHECK_INPUT(recvRankCumSum, torch::kInt32);
CHECK_INPUT(localGatherIndices, torch::kInt32);
CHECK_INPUT(gatheredExpertIds, torch::kInt32);
CHECK_INPUT(gatheredScales, torch::kFloat32);
CHECK_INPUT(localExpertIds, torch::kInt32);
CHECK_INPUT(localScales, torch::kFloat32);
TORCH_CHECK(maxTokenCountPerRank > 0, "maxTokenCountPerRank must be greater than 0");
TORCH_CHECK(expertCount > 0, "expertCount must be greater than 0");
TORCH_CHECK(topK > 0, "topK must be greater than 0");
TORCH_CHECK(topK <= expertCount, "topK must be less than or equal to expertCount");
TORCH_CHECK(epRank >= 0 && epRank < epSize, "epRank must be in the range [0, epSize)");
TORCH_CHECK(recvRankCumSum.dim() == 1, "recvRankCumSum must be a 1D tensor");
TORCH_CHECK(recvRankCumSum.size(0) == epSize, "recvRankCumSum must have epSize elements");
TORCH_CHECK(localGatherIndices.dim() == 1, "localGatherIndices must be a 1D tensor");
TORCH_CHECK(gatheredExpertIds.dim() == 2, "gatheredExpertIds must be a 2D tensor");
TORCH_CHECK(gatheredScales.dim() == 2, "gatheredScales must be a 2D tensor");
TORCH_CHECK(localExpertIds.dim() == 2, "localExpertIds must be a 2D tensor");
TORCH_CHECK(localScales.dim() == 2, "localScales must be a 2D tensor");
TORCH_CHECK(gatheredExpertIds.size(1) == topK, "gatheredExpertIds must have topK columns");
TORCH_CHECK(gatheredScales.size(1) == topK, "gatheredScales must have topK columns");
TORCH_CHECK(localExpertIds.size(1) == topK, "localExpertIds must have topK columns");
TORCH_CHECK(localScales.size(1) == topK, "localScales must have topK columns");
int localMaxTokenCount = static_cast<int>(localGatherIndices.size(0));
TORCH_CHECK(localExpertIds.size(0) == localMaxTokenCount, "localExpertIds must have localMaxTokenCount rows");
TORCH_CHECK(localScales.size(0) == localMaxTokenCount, "localScales must have localMaxTokenCount rows");
auto stream = at::cuda::getCurrentCUDAStream();
tensorrt_llm::kernels::MoeExpertParallelInfo expertParallelInfo;
expertParallelInfo.expertCount = expertCount;
expertParallelInfo.topK = topK;
tensorrt_llm::kernels::MoeEpWorldInfo worldInfo = {static_cast<int>(epSize), static_cast<int>(epRank)};
tensorrt_llm::kernels::moeLocalGather(worldInfo, expertParallelInfo, maxTokenCountPerRank, localMaxTokenCount,
recvRankCumSum.data_ptr<int>(), localGatherIndices.data_ptr<int>(), gatheredExpertIds.data_ptr<int>(),
gatheredScales.data_ptr<float>(), localExpertIds.data_ptr<int>(), localScales.data_ptr<float>(), stream);
}
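
// Runs the MoE all-to-all exchange of token payloads. Rows of `input` selected
// by sendIndices (grouped per destination rank via sendRankCumSum) are sent,
// and received rows are written to `output` at the positions in recvIndices.
// allWorkspaces provides one communication workspace row per rank.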
void moeCommOp(torch::Tensor input, torch::Tensor sendRankCumSum, torch::Tensor sendIndices, torch::Tensor output,
torch::Tensor recvRankCumSum, torch::Tensor recvIndices, torch::Tensor allWorkspaces, int64_t epRank,
int64_t epSize)
{
CHECK_INPUT(sendRankCumSum, torch::kInt32);
CHECK_INPUT(sendIndices, torch::kInt32);
CHECK_INPUT(recvRankCumSum, torch::kInt32);
CHECK_INPUT(recvIndices, torch::kInt32);
// allWorkspaces is a uint64 tensor, but may not be contiguous
TORCH_CHECK(allWorkspaces.dtype() == torch::kUInt64, "allWorkspaces must be a uint64 tensor");
TORCH_CHECK(input.dim() == 2, "input must be a 2D tensor");
TORCH_CHECK(output.dim() == 2, "output must be a 2D tensor");
TORCH_CHECK(sendRankCumSum.dim() == 1, "sendRankCumSum must be a 1D tensor");
TORCH_CHECK(sendIndices.dim() == 1, "sendIndices must be a 1D tensor");
TORCH_CHECK(recvRankCumSum.dim() == 1, "recvRankCumSum must be a 1D tensor");
TORCH_CHECK(recvIndices.dim() == 1, "recvIndices must be a 1D tensor");
TORCH_CHECK(allWorkspaces.dim() == 2, "allWorkspaces must be a 2D tensor");
TORCH_CHECK(input.size(1) == output.size(1), "input and output must have the same second dimension");
TORCH_CHECK(sendRankCumSum.size(0) == epSize, "sendRankCumSum must have epSize elements");
TORCH_CHECK(recvRankCumSum.size(0) == epSize, "recvRankCumSum must have epSize elements");
    TORCH_CHECK(allWorkspaces.size(0) == epSize, "allWorkspaces must have epSize rows");
TORCH_CHECK(epRank >= 0 && epRank < epSize, "epRank must be in the range [0, epSize)");
tensorrt_llm::kernels::MoeEpWorldInfo worldInfo = {static_cast<int>(epSize), static_cast<int>(epRank)};
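    // Payload rows are moved as packed uint64 words, so the element size must
    // divide 8 bytes and each row must be a multiple of 16 bytes.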
tensorrt_llm::kernels::SendRecvDataInfo sendRecvDataInfo;
size_t eltSize = input.dtype().itemsize();
size_t eltCountPerU64 = sizeof(uint64_t) / eltSize;
    TORCH_CHECK(input.size(1) % (eltCountPerU64 * 2) == 0, "each input row must be a multiple of 16 bytes");
sendRecvDataInfo.vectorSizeInU64 = input.size(1) / eltCountPerU64;
sendRecvDataInfo.DoPreCompute();
tensorrt_llm::kernels::SendRecvDispls sendDispls, recvDispls;
sendDispls.dataPtr = static_cast<uint64_t*>(input.data_ptr());
sendDispls.rankCountCumSum = sendRankCumSum.data_ptr<int>();
sendDispls.rankLocalIndices = sendIndices.data_ptr<int>();
sendDispls.vectorStrideInU64 = input.stride(0) / eltCountPerU64;
recvDispls.dataPtr = static_cast<uint64_t*>(output.data_ptr());
recvDispls.rankCountCumSum = recvRankCumSum.data_ptr<int>();
recvDispls.rankLocalIndices = recvIndices.data_ptr<int>();
recvDispls.vectorStrideInU64 = output.stride(0) / eltCountPerU64;
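    // Each rank owns one row of allWorkspaces; stride(0) is the per-rank
    // stride in uint64 elements (the tensor need not be contiguous).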
tensorrt_llm::kernels::MoeCommWorkspace workspace;
workspace.workspacePtr = allWorkspaces.data_ptr<uint64_t>();
workspace.rankStrideInU64 = allWorkspaces.stride(0);
auto stream = at::cuda::getCurrentCUDAStream();
tensorrt_llm::kernels::moeAllToAll(worldInfo, sendRecvDataInfo, sendDispls, recvDispls, workspace, stream);
}
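
// Returns the per-rank communication workspace size required by moe_comm.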
int64_t getWorkspaceSizePerRank(int64_t epSize)
{
int epSize32 = static_cast<int>(epSize);
return tensorrt_llm::kernels::getMoeCommWorkspaceSize(epSize32);
}
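
// Caps the number of SMs the MoE communication kernels are allowed to use.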
void setMaxUsableSmCount(int64_t maxSmCount)
{
tensorrt_llm::kernels::setMaxUsableSmCount(maxSmCount);
}
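
// Returns the per-rank workspace size required by the prepare (metadata
// all-to-all) kernels.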
int64_t getPrepareWorkspaceSizePerRank(int64_t epSize)
{
int epSize32 = static_cast<int>(epSize);
return tensorrt_llm::kernels::moe_prepare::getMoePrepareWorkspaceSize(epSize32);
}
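
// Prepares the MoE all-to-all metadata without a preceding allgather: counts
// and indexes tokens per destination rank, exchanges expert ids, scales and
// (optionally) per-expert statistics, and returns the prepared tensors along
// with the send/recv cumulative sums and gathered index lists.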
std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor,
c10::optional<torch::Tensor>>
moePrepareOp(torch::Tensor expertsIds, torch::Tensor scales, c10::optional<torch::Tensor> expertsStatics,
torch::Tensor allWorkspaces, int64_t maxTokenCountPerRank, int64_t epRank, int64_t epSize, int64_t expertCount,
int64_t slotCount, int64_t topK)
{
CHECK_INPUT(expertsIds, torch::kInt32);
CHECK_INPUT(scales, torch::kFloat32);
TORCH_CHECK(expertCount % 4 == 0, "expertCount must be divisible by 4");
TORCH_CHECK(slotCount % 4 == 0, "slotCount must be divisible by 4");
int64_t maxSendRanksPerToken = std::max(epSize, topK);
int64_t tokenCount = expertsIds.size(0);
torch::Tensor preparedLocalExpertIds
= torch::empty({maxTokenCountPerRank * epSize, topK}, expertsIds.options().dtype(torch::kInt32));
torch::Tensor preparedLocalScales
= torch::empty({maxTokenCountPerRank * epSize, topK}, expertsIds.options().dtype(torch::kFloat32));
torch::Tensor sendRankCountCumSum = torch::empty({epSize}, expertsIds.options().dtype(torch::kInt32));
    torch::Tensor recvRankCountCumSum = torch::empty({epSize}, expertsIds.options().dtype(torch::kInt32));
torch::Tensor gatherRecvRankIndices
= torch::empty({maxTokenCountPerRank * epSize}, expertsIds.options().dtype(torch::kInt32));
torch::Tensor recvRankIndices
= torch::empty({maxTokenCountPerRank * epSize}, expertsIds.options().dtype(torch::kInt32));
torch::Tensor gatherBackwardRecvRankIndices
= torch::empty({maxTokenCountPerRank * maxSendRanksPerToken}, expertsIds.options().dtype(torch::kInt32));
torch::Tensor backwardRecvRankIndices
= torch::empty({maxTokenCountPerRank * maxSendRanksPerToken}, expertsIds.options().dtype(torch::kInt32));
torch::Tensor gatherSendRankIndices
= torch::empty({maxTokenCountPerRank * maxSendRanksPerToken}, expertsIds.options().dtype(torch::kInt32));
torch::Tensor sendRankIndices
= torch::empty({maxTokenCountPerRank * maxSendRanksPerToken}, expertsIds.options().dtype(torch::kInt32));
int* localExpertStaticsPtr = nullptr;
int* gatheredExpertStaticsPtr = nullptr;
c10::optional<torch::Tensor> gatheredExpertStatics;
if (expertsStatics.has_value())
{
localExpertStaticsPtr = expertsStatics.value().data_ptr<int>();
gatheredExpertStatics = torch::empty({epSize, expertCount}, expertsIds.options().dtype(torch::kInt32));
gatheredExpertStaticsPtr = gatheredExpertStatics.value().data_ptr<int>();
}
tensorrt_llm::kernels::moe_prepare::MoeCommWorkspace workspace;
workspace.workspacePtr = allWorkspaces.data_ptr<uint64_t>();
workspace.rankStrideInU64 = allWorkspaces.stride(0);
auto stream = at::cuda::getCurrentCUDAStream();
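    // Count tokens per destination rank and build the raw send/recv index lists.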
tensorrt_llm::kernels::moe_prepare::computeCountAndIndice(expertsIds.data_ptr<int>(),
        sendRankCountCumSum.data_ptr<int>(), recvRankCountCumSum.data_ptr<int>(), sendRankIndices.data_ptr<int>(),
backwardRecvRankIndices.data_ptr<int>(), recvRankIndices.data_ptr<int>(), workspace, tokenCount,
maxTokenCountPerRank, topK, slotCount, epRank, epSize, stream);
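    // Turn the per-rank counts into cumulative sums.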
tensorrt_llm::kernels::moe_prepare::computeCumsum(
        sendRankCountCumSum.data_ptr<int>(), recvRankCountCumSum.data_ptr<int>(), epRank, epSize, stream);
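    // Reorder the raw index lists into their gathered layouts using the cumulative sums.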
tensorrt_llm::kernels::moe_prepare::moveIndice(sendRankCountCumSum.data_ptr<int>(),
        recvRankCountCumSum.data_ptr<int>(), sendRankIndices.data_ptr<int>(), gatherSendRankIndices.data_ptr<int>(),
backwardRecvRankIndices.data_ptr<int>(), gatherBackwardRecvRankIndices.data_ptr<int>(),
recvRankIndices.data_ptr<int>(), gatherRecvRankIndices.data_ptr<int>(), epRank, epSize, maxTokenCountPerRank,
stream);
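    // Exchange expert ids, scales and optional per-expert statistics across ranks.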
tensorrt_llm::kernels::moe_prepare::allToAllMetadata(expertsIds.data_ptr<int>(),
preparedLocalExpertIds.data_ptr<int>(), scales.data_ptr<float>(), preparedLocalScales.data_ptr<float>(),
localExpertStaticsPtr, gatheredExpertStaticsPtr, workspace, sendRankCountCumSum.data_ptr<int>(),
        sendRankIndices.data_ptr<int>(), recvRankCountCumSum.data_ptr<int>(), recvRankIndices.data_ptr<int>(),
tokenCount, maxTokenCountPerRank, topK, expertCount, slotCount, epRank, epSize, stream);
return std::make_tuple(preparedLocalExpertIds, preparedLocalScales, sendRankCountCumSum, gatherSendRankIndices,
        recvRankCountCumSum, gatherRecvRankIndices, gatherBackwardRecvRankIndices, gatheredExpertStatics);
}
} // namespace torch_ext
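
// The fragments below register the op schemas in the "trtllm" namespace; CUDA
// implementations are bound for the device ops and CompositeExplicitAutograd
// implementations for the host-side helpers.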
TORCH_LIBRARY_FRAGMENT(trtllm, m)
{
m.def(
"moe_comm_prepare_indices(Tensor gathered_target_rank_ids, Tensor? real_rank_token_count_cum_sum, int "
"max_token_count_per_rank, int expert_count, int top_k, int ep_rank, int ep_size) -> (Tensor, Tensor, Tensor, "
"Tensor, "
"Tensor, Tensor)");
}
TORCH_LIBRARY_IMPL(trtllm, CUDA, m)
{
m.impl("moe_comm_prepare_indices", &torch_ext::moeCommPrepareIndicesOp);
}
TORCH_LIBRARY_FRAGMENT(trtllm, m)
{
m.def(
"moe_local_gather(Tensor recv_rank_cum_sum, Tensor local_gather_indices, Tensor gathered_expert_ids, Tensor "
"gathered_scales, Tensor local_expert_ids, Tensor local_scales, int max_token_count_per_rank, int "
"expert_count, int "
"top_k, int ep_rank, int ep_size) -> ()");
}
TORCH_LIBRARY_IMPL(trtllm, CUDA, m)
{
m.impl("moe_local_gather", &torch_ext::moeLocalGatherOp);
}
TORCH_LIBRARY_FRAGMENT(trtllm, m)
{
m.def(
"moe_comm(Tensor input, Tensor send_rank_cum_sum, Tensor send_indices, Tensor output, Tensor "
"recv_rank_cum_sum, "
"Tensor recv_indices, Tensor all_workspaces, int ep_rank, int ep_size) -> ()");
}
TORCH_LIBRARY_IMPL(trtllm, CUDA, m)
{
m.impl("moe_comm", &torch_ext::moeCommOp);
}
TORCH_LIBRARY_FRAGMENT(trtllm, m)
{
m.def("get_moe_commworkspace_size_per_rank(int ep_size) -> int");
}
TORCH_LIBRARY_IMPL(trtllm, CompositeExplicitAutograd, m)
{
m.impl("get_moe_commworkspace_size_per_rank", &torch_ext::getWorkspaceSizePerRank);
}
TORCH_LIBRARY_FRAGMENT(trtllm, m)
{
m.def("set_moe_max_usable_sm_count(int max_sm_count) -> ()");
}
TORCH_LIBRARY_IMPL(trtllm, CompositeExplicitAutograd, m)
{
m.impl("set_moe_max_usable_sm_count", &torch_ext::setMaxUsableSmCount);
}
TORCH_LIBRARY_FRAGMENT(trtllm, m)
{
m.def(
"mnnvl_moe_alltoallv_prepare_without_allgather(Tensor experts_ids, Tensor scales, Tensor? experts_statics, "
"Tensor allWorkspace, int "
"max_token_count_per_rank, int ep_rank, int ep_size, int expert_count, int slot_count, int top_k) -> (Tensor, "
"Tensor, Tensor, "
"Tensor, "
"Tensor, Tensor, Tensor, Tensor?)");
}
TORCH_LIBRARY_IMPL(trtllm, CUDA, m)
{
m.impl("mnnvl_moe_alltoallv_prepare_without_allgather", &torch_ext::moePrepareOp);
}
TORCH_LIBRARY_FRAGMENT(trtllm, m)
{
m.def("get_moe_prepare_workspace_size_per_rank(int ep_size) -> int");
}
TORCH_LIBRARY_IMPL(trtllm, CompositeExplicitAutograd, m)
{
m.impl("get_moe_prepare_workspace_size_per_rank", &torch_ext::getPrepareWorkspaceSizePerRank);
}
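
// A minimal usage sketch from Python once this extension is loaded (hedged:
// how `all_workspaces` is allocated and shared across EP ranks, e.g. via
// MNNVL, is outside this file; shapes follow the checks above):
//
//   ws_size = torch.ops.trtllm.get_moe_commworkspace_size_per_rank(ep_size)
//   # all_workspaces: 2D uint64 tensor with ep_size rows, sized from ws_size
//   torch.ops.trtllm.moe_comm(inp, send_cumsum, send_indices, out,
//       recv_cumsum, recv_indices, all_workspaces, ep_rank, ep_size)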