TensorRT-LLMs/cpp/tensorrt_llm/thop/moeLoadBalanceOp.cpp

/*
* SPDX-FileCopyrightText: Copyright (c) 1993-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
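
// Torch custom-op bindings for the MoE (Mixture-of-Experts) load balancer: GPU/CPU stage
// synchronization, expert-token statistics (single-level and hierarchical), load-balanced
// token-to-slot routing, and migration of tensors into host-accessible device memory.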
#include "tensorrt_llm/common/opUtils.h"
#include "tensorrt_llm/runtime/torchUtils.h"
#include "tensorrt_llm/thop/thUtils.h"
#include <c10/core/Allocator.h> // for c10::DataPtr
#include <c10/core/StorageImpl.h> // for c10::StorageImpl and use_byte_size_t()
#include <c10/cuda/CUDAStream.h>
#include <c10/util/intrusive_ptr.h> // for c10::make_intrusive
#include <torch/extension.h>
#include <vector>
#include "tensorrt_llm/kernels/moeLoadBalance/moeLoadBalanceKernels.h"
#include "tensorrt_llm/runtime/moeLoadBalancer/hostAccessibleDeviceAllocator.h"
#include "tensorrt_llm/runtime/moeLoadBalancer/moeLoadBalancer.h"
TRTLLM_NAMESPACE_BEGIN
namespace torch_ext
{
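
// Wait on the current CUDA stream until the load balancer signals that the GPU stage may start.
// Returns a 1-element int32 CUDA tensor holding the `enabled` flag for this iteration.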
torch::Tensor moeLoadBalanceWaitGpuStage(int64_t singleLayerLoadBalancerPtr)
{
    TORCH_CHECK(singleLayerLoadBalancerPtr != 0, "singleLayerLoadBalancerPtr must be non-null");
    auto* loadBalancer
        = reinterpret_cast<tensorrt_llm::runtime::SingleLayerMoeLoadBalancer*>(singleLayerLoadBalancerPtr);
    auto enabled = torch::empty({1}, torch::dtype(torch::kInt32).device(torch::kCUDA));
    auto signal = loadBalancer->getSignal();
    auto stream = at::cuda::getCurrentCUDAStream();
    tensorrt_llm::kernels::moeWaitSignalForGpuStageDevice(signal, enabled.data_ptr<int>(), stream);
    return enabled;
}
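
// Signal from the current CUDA stream that the GPU stage is finished and the CPU stage may proceed.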
void moeLoadBalanceSetCpuStage(int64_t singleLayerLoadBalancerPtr)
{
    TORCH_CHECK(singleLayerLoadBalancerPtr != 0, "singleLayerLoadBalancerPtr must be non-null");
    auto* loadBalancer
        = reinterpret_cast<tensorrt_llm::runtime::SingleLayerMoeLoadBalancer*>(singleLayerLoadBalancerPtr);
    auto signal = loadBalancer->getSignal();
    auto stream = at::cuda::getCurrentCUDAStream();
    tensorrt_llm::kernels::moeSetSignalForCpuStageDevice(signal, stream);
}
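
// Accumulate per-expert token statistics from gathered raw expert ids of shape [numTotalTokens, topK].
// `enabled` is the 1-element flag returned by moeLoadBalanceWaitGpuStage; isFirstStage/isLastStage mark
// whether this call is the first/last statistic pass of the current iteration.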
void moeLoadBalanceStatistic(torch::Tensor gatheredRawExpertIds, torch::Tensor enabled,
    int64_t singleLayerLoadBalancerPtr, int64_t isFirstStage, int64_t isLastStage)
{
    CHECK_INPUT(gatheredRawExpertIds, torch::kInt32);
    CHECK_INPUT(enabled, torch::kInt32);
    TORCH_CHECK(gatheredRawExpertIds.dim() == 2, "gatheredRawExpertIds must be a 2D tensor");
    int topK = gatheredRawExpertIds.size(1);
    TORCH_CHECK(enabled.dim() == 1, "enabled must be a 1D tensor");
    TORCH_CHECK(enabled.size(0) == 1, "enabled must have 1 element");
    TORCH_CHECK(isFirstStage == 0 || isFirstStage == 1, "isFirstStage must be 0 or 1");
    TORCH_CHECK(isLastStage == 0 || isLastStage == 1, "isLastStage must be 0 or 1");
    TORCH_CHECK(singleLayerLoadBalancerPtr != 0, "singleLayerLoadBalancerPtr must be non-null");
    auto* loadBalancer
        = reinterpret_cast<tensorrt_llm::runtime::SingleLayerMoeLoadBalancer*>(singleLayerLoadBalancerPtr);
    auto stream = at::cuda::getCurrentCUDAStream();
    tensorrt_llm::kernels::MoeLoadBalanceMetaInfo metaInfo = loadBalancer->getMetaInfo();
    TORCH_CHECK(topK == metaInfo.topK, "topK must be equal to metaInfo.topK");
    auto statisticInfo = loadBalancer->getStatisticInfo();
    int numTotalTokens = gatheredRawExpertIds.size(0);
    tensorrt_llm::kernels::moeStatisticDevice(metaInfo, *statisticInfo, numTotalTokens, enabled.data_ptr<int>(),
        static_cast<bool>(isFirstStage), static_cast<bool>(isLastStage), gatheredRawExpertIds.data_ptr<int>(), stream);
}
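
// Hierarchical statistics, local step: count tokens per expert from this rank's raw expert ids into
// localExpertTokenCount (shape [expertCount]); the cross-rank reduction and the final update are done
// separately (see moeHierarchicalStatisticUpdate below).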
void moeHierarchicalStatisticLocalDevice(torch::Tensor localRawExpertIds, torch::Tensor localExpertTokenCount,
    torch::Tensor enabled, int64_t singleLayerLoadBalancerPtr, int64_t isFirstStage, int64_t isLastStage)
{
    CHECK_INPUT(localRawExpertIds, torch::kInt32);
    CHECK_INPUT(localExpertTokenCount, torch::kInt32);
    CHECK_INPUT(enabled, torch::kInt32);
    TORCH_CHECK(localRawExpertIds.dim() == 2, "localRawExpertIds must be a 2D tensor");
    TORCH_CHECK(localExpertTokenCount.dim() == 1, "localExpertTokenCount must be a 1D tensor");
    int topK = localRawExpertIds.size(1);
    TORCH_CHECK(enabled.dim() == 1, "enabled must be a 1D tensor");
    TORCH_CHECK(enabled.size(0) == 1, "enabled must have 1 element");
    TORCH_CHECK(isFirstStage == 0 || isFirstStage == 1, "isFirstStage must be 0 or 1");
    TORCH_CHECK(isLastStage == 0 || isLastStage == 1, "isLastStage must be 0 or 1");
    TORCH_CHECK(singleLayerLoadBalancerPtr != 0, "singleLayerLoadBalancerPtr must be non-null");
    auto* loadBalancer
        = reinterpret_cast<tensorrt_llm::runtime::SingleLayerMoeLoadBalancer*>(singleLayerLoadBalancerPtr);
    auto stream = at::cuda::getCurrentCUDAStream();
    tensorrt_llm::kernels::MoeLoadBalanceMetaInfo metaInfo = loadBalancer->getMetaInfo();
    TORCH_CHECK(localExpertTokenCount.size(0) == metaInfo.expertCount, "localExpertTokenCount should have shape (",
        metaInfo.expertCount, ",)");
    TORCH_CHECK(topK == metaInfo.topK, "topK must be equal to metaInfo.topK");
    int numTotalTokens = localRawExpertIds.size(0);
    tensorrt_llm::kernels::moeHierarchicalStatisticLocalDevice(metaInfo, numTotalTokens,
        localExpertTokenCount.data_ptr<int>(), enabled.data_ptr<int>(), static_cast<bool>(isFirstStage),
        static_cast<bool>(isLastStage), localRawExpertIds.data_ptr<int>(), stream);
}
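
// Hierarchical statistics, update step: fold the globally reduced per-expert token counts
// (globalExpertTokenCount, shape [expertCount]) into the layer's statistic info.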
void moeHierarchicalStatisticUpdate(
    torch::Tensor globalExpertTokenCount, torch::Tensor enabled, int64_t singleLayerLoadBalancerPtr)
{
    CHECK_INPUT(globalExpertTokenCount, torch::kInt32);
    CHECK_INPUT(enabled, torch::kInt32);
    TORCH_CHECK(globalExpertTokenCount.dim() == 1, "globalExpertTokenCount must be a 1D tensor");
    TORCH_CHECK(enabled.dim() == 1, "enabled must be a 1D tensor");
    TORCH_CHECK(enabled.size(0) == 1, "enabled must have 1 element");
    TORCH_CHECK(singleLayerLoadBalancerPtr != 0, "singleLayerLoadBalancerPtr must be non-null");
    auto* loadBalancer
        = reinterpret_cast<tensorrt_llm::runtime::SingleLayerMoeLoadBalancer*>(singleLayerLoadBalancerPtr);
    auto stream = at::cuda::getCurrentCUDAStream();
    tensorrt_llm::kernels::MoeLoadBalanceMetaInfo metaInfo = loadBalancer->getMetaInfo();
    auto statisticInfo = loadBalancer->getStatisticInfo();
    TORCH_CHECK(globalExpertTokenCount.size(0) == metaInfo.expertCount, "globalExpertTokenCount should have shape (",
        metaInfo.expertCount, ",)");
    tensorrt_llm::kernels::moeHierarchicalStatisticUpdate(
        metaInfo, *statisticInfo, globalExpertTokenCount.data_ptr<int>(), enabled.data_ptr<int>(), stream);
}
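
// Map each token's selected expert ids onto expert-slot ids according to the load balancer's current
// GPU placement; the result has the same shape as tokenSelectedExperts. When offsetByEpRank is true,
// slot selection is offset by the expert-parallel rank.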
torch::Tensor moeLoadBalanceRouting(
    torch::Tensor tokenSelectedExperts, bool offsetByEpRank, int64_t singleLayerLoadBalancerPtr)
{
    CHECK_INPUT(tokenSelectedExperts, torch::kInt32);
    TORCH_CHECK(tokenSelectedExperts.dim() == 2, "tokenSelectedExperts must be a 2D tensor");
    int topK = tokenSelectedExperts.size(1);
    TORCH_CHECK(singleLayerLoadBalancerPtr != 0, "singleLayerLoadBalancerPtr must be non-null");
    auto* loadBalancer
        = reinterpret_cast<tensorrt_llm::runtime::SingleLayerMoeLoadBalancer*>(singleLayerLoadBalancerPtr);
    auto stream = at::cuda::getCurrentCUDAStream();
    tensorrt_llm::kernels::MoeLoadBalanceMetaInfo metaInfo = loadBalancer->getMetaInfo();
    TORCH_CHECK(topK == metaInfo.topK, "topK must be equal to metaInfo.topK");
    int tokenCount = tokenSelectedExperts.size(0);
    auto tokenRoutedSlotIds = torch::empty_like(tokenSelectedExperts);
    tensorrt_llm::kernels::moeComputeRouteDevice(metaInfo, loadBalancer->getGpuPlacementInfo(),
        tokenSelectedExperts.data_ptr<int>(), tokenRoutedSlotIds.data_ptr<int>(), tokenCount, offsetByEpRank, stream);
    return tokenRoutedSlotIds;
}
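
// Re-allocate a CUDA tensor's storage through the host-accessible device allocator (GDRCopy based),
// copy the existing contents over, and swap the new storage into the tensor in place; dtype, shape
// and strides are preserved.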
void migrateToHostAccessible(at::Tensor& tensor)
{
    TORCH_CHECK(tensor.device().is_cuda(), "only CUDA tensors are supported");
    TLLM_CHECK_WITH_INFO(tensorrt_llm::runtime::HostAccessibleDeviceAllocator::getInstance().isSupported(),
        "host accessible allocator is not supported on this system, please install GDRCopy.");
    // 1) compute total bytes
    size_t byte_size = tensor.numel() * tensor.element_size();
    // 2) allocate host accessible memory
    void* devPtr = tensorrt_llm::runtime::HostAccessibleDeviceAllocator::getInstance().allocate(byte_size);
    // 3) copy old data to new memory
    TLLM_CUDA_CHECK(cudaMemcpy(devPtr, tensor.data_ptr(), byte_size, cudaMemcpyDeviceToDevice));
    // 4) build a new DataPtr/StorageImpl around the host-accessible allocation: devPtr serves as both the
    //    data pointer and the context, and the host-accessible allocator's free() is used as the deleter
    c10::DataPtr dp(
        devPtr, devPtr,
        [](void* ptr) { tensorrt_llm::runtime::HostAccessibleDeviceAllocator::getInstance().free(ptr); },
        tensor.device());
    auto allocator = c10::GetAllocator(tensor.device().type());
    auto storage_impl = c10::make_intrusive<c10::StorageImpl>(c10::StorageImpl::use_byte_size_t(), byte_size,
        std::move(dp), allocator,
        /*resizable=*/false);
    at::Storage new_storage(storage_impl);
    // Finally, swap the new storage into the tensor: storage offset = 0, shape and strides unchanged.
    tensor.set_(new_storage,
        /*storage_offset=*/0, tensor.sizes().vec(), tensor.strides().vec());
}
} // namespace torch_ext
TRTLLM_NAMESPACE_END
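
// Registration of the functions above as Torch custom ops in the `trtllm` namespace. Once this
// extension is loaded they are callable from Python as torch.ops.trtllm.<op_name>, e.g.
// torch.ops.trtllm.moe_load_balance_wait_gpu_stage(ptr). The stage-control ops use the
// CompositeExplicitAutograd key; the statistic, routing and migration ops are registered for CUDA.
// As the op names suggest, a typical iteration runs wait_gpu_stage, then the statistic/routing ops,
// then set_cpu_stage.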
TORCH_LIBRARY_FRAGMENT(trtllm, m)
{
    m.def("moe_load_balance_wait_gpu_stage(int single_layer_load_balancer_ptr) -> Tensor");
}

TORCH_LIBRARY_IMPL(trtllm, CompositeExplicitAutograd, m)
{
    m.impl("moe_load_balance_wait_gpu_stage", &tensorrt_llm::torch_ext::moeLoadBalanceWaitGpuStage);
}

TORCH_LIBRARY_FRAGMENT(trtllm, m)
{
    m.def("moe_load_balance_set_cpu_stage(int single_layer_load_balancer_ptr) -> ()");
}

TORCH_LIBRARY_IMPL(trtllm, CompositeExplicitAutograd, m)
{
    m.impl("moe_load_balance_set_cpu_stage", &tensorrt_llm::torch_ext::moeLoadBalanceSetCpuStage);
}

TORCH_LIBRARY_FRAGMENT(trtllm, m)
{
    m.def(
        "moe_load_balance_statistic(Tensor gathered_raw_expert_ids, Tensor enabled, int "
        "single_layer_load_balancer_ptr, int is_first_stage, int is_last_stage) -> ()");
}

TORCH_LIBRARY_IMPL(trtllm, CUDA, m)
{
    m.impl("moe_load_balance_statistic", &tensorrt_llm::torch_ext::moeLoadBalanceStatistic);
}

TORCH_LIBRARY_FRAGMENT(trtllm, m)
{
    m.def(
        "moe_hierarchical_statistic_local_device(Tensor local_raw_expert_ids, Tensor local_expert_token_count, "
        "Tensor enabled, int single_layer_load_balancer_ptr, int is_first_stage, int is_last_stage) -> ()");
}

TORCH_LIBRARY_IMPL(trtllm, CUDA, m)
{
    m.impl("moe_hierarchical_statistic_local_device", &tensorrt_llm::torch_ext::moeHierarchicalStatisticLocalDevice);
}

TORCH_LIBRARY_FRAGMENT(trtllm, m)
{
    m.def(
        "moe_hierarchical_statistic_update(Tensor global_expert_token_count, Tensor enabled, int "
        "single_layer_load_balancer_ptr) -> ()");
}

TORCH_LIBRARY_IMPL(trtllm, CUDA, m)
{
    m.impl("moe_hierarchical_statistic_update", &tensorrt_llm::torch_ext::moeHierarchicalStatisticUpdate);
}

TORCH_LIBRARY_FRAGMENT(trtllm, m)
{
    m.def(
        "moe_load_balance_routing(Tensor token_selected_experts, bool offset_by_ep_rank, "
        "int single_layer_load_balancer_ptr) -> Tensor");
}

TORCH_LIBRARY_IMPL(trtllm, CUDA, m)
{
    m.impl("moe_load_balance_routing", &tensorrt_llm::torch_ext::moeLoadBalanceRouting);
}

TORCH_LIBRARY_FRAGMENT(trtllm, m)
{
    m.def("migrate_to_host_accessible(Tensor tensor) -> ()");
}

TORCH_LIBRARY_IMPL(trtllm, CUDA, m)
{
    m.impl("migrate_to_host_accessible", &tensorrt_llm::torch_ext::migrateToHostAccessible);
}