feat: large-scale EP (part 6: Online EP load balancer integration for GB200 nvfp4) (#4818)

Signed-off-by: Dongxu Yang <78518666+dongxuy04@users.noreply.github.com>
Signed-off-by: ShiXiaowei02 <39303645+Shixiaowei02@users.noreply.github.com>
Co-authored-by: ShiXiaowei02 <39303645+Shixiaowei02@users.noreply.github.com>
dongxuy04 2025-06-08 10:25:18 +08:00 committed by GitHub
parent 5ee0de7f2a
commit 1e369658f1
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
28 changed files with 2121 additions and 361 deletions

View File

@ -1,7 +1,7 @@
version: "3.9"
services:
tensorrt_llm-dev:
image: urm.nvidia.com/sw-tensorrt-docker/tensorrt-llm:pytorch-25.04-py3-x86_64-ubuntu24.04-trt10.10.0.31-skip-tritondevel-202505211401-4539
image: urm.nvidia.com/sw-tensorrt-docker/tensorrt-llm:pytorch-25.04-py3-x86_64-ubuntu24.04-trt10.10.0.31-skip-tritondevel-202506021004-9420
network_mode: host
ipc: host

View File

@ -1,3 +1,6 @@
import os
import sys
from conan import ConanFile
from conan.tools.cmake import CMakeDeps, CMakeToolchain
@ -9,10 +12,22 @@ class TensorRT_LLM(ConanFile):
virtualrunenv = False
def requirements(self):
pass # TODO add dependencies here
self.requires("libnuma/system")
def generate(self):
cmake = CMakeDeps(self)
cmake.generate()
tc = CMakeToolchain(self)
tc.generate()
def build_requirements(self):
# register libnuma_conan.py for conan
base_dir = os.path.dirname(os.path.abspath(__file__))
libnuma_path = os.path.join(base_dir, "libnuma_conan.py")
conan_bin = os.path.abspath(sys.argv[0])
if not os.path.isfile(conan_bin) or not os.access(conan_bin, os.X_OK):
raise RuntimeError(f"Conan binary not found {sys.argv[0]}")
self.run(
f"{conan_bin} export {libnuma_path} --name=libnuma --version=system"
)

cpp/libnuma_conan.py (new file, 36 lines)
View File

@ -0,0 +1,36 @@
from conan import ConanFile
class LibnumaSystemConan(ConanFile):
name = "libnuma"
version = "system"
package_type = "shared-library"
settings = "os", "arch"
def package_info(self):
if self.settings.os == "Windows":
self.output.info("libnuma not needed on Windows.")
return
self.cpp_info.includedirs = ["/usr/include"]
libdirs = []
arch = str(self.settings.arch)
os_name = str(self.settings.os)
if os_name == "Linux":
if arch == "x86_64":
libdirs.append("/usr/lib/x86_64-linux-gnu")
elif arch in ["armv8", "aarch64"]:
libdirs.append("/usr/lib/aarch64-linux-gnu")
else:
self.output.warn(
f"Unrecognized architecture: {arch}, falling back to /usr/lib"
)
libdirs.append("/usr/lib")
else:
self.output.warn(f"Unsupported OS: {os_name}, assuming /usr/lib")
libdirs.append("/usr/lib")
self.cpp_info.libdirs = libdirs
self.cpp_info.system_libs = ["numa"]

View File

@ -245,12 +245,166 @@ void moeStatisticDevice(MoeLoadBalanceMetaInfo metaInfo, MoeLoadBalanceStatistic
}
}
template <int MAX_EXPERT_COUNT = 1024, int THREAD_COUNT = 256, int ITEM_PER_THREAD = 4>
__global__ void moeComputeRouteNoRedundantKernel(MoeLoadBalanceMetaInfo metaInfo, MoePlacementInfo placementInfo,
int* const tokenSelectedExperts, int* tokenRoutedSlotIds, int tokenCount)
{
extern __shared__ int16_t sharedGlobalSlotIdsInfo[];
int expertIds[ITEM_PER_THREAD];
int slotIds[ITEM_PER_THREAD];
for (int slotId = threadIdx.x; slotId < metaInfo.epSize * metaInfo.slotCountPerRank; slotId += THREAD_COUNT)
{
sharedGlobalSlotIdsInfo[slotId] = placementInfo.globalSlotIds[slotId];
}
int blockOffset = blockIdx.x * THREAD_COUNT * ITEM_PER_THREAD;
for (; blockOffset < tokenCount * metaInfo.topK; blockOffset += gridDim.x * THREAD_COUNT * ITEM_PER_THREAD)
{
int tokenIdxBase = blockOffset + threadIdx.x;
#pragma unroll
for (int i = 0; i < ITEM_PER_THREAD; i++)
{
int tokenIdx = tokenIdxBase + i * THREAD_COUNT;
expertIds[i]
= tokenIdx < tokenCount * metaInfo.topK ? tokenSelectedExperts[tokenIdx] : metaInfo.expertCount;
}
#pragma unroll
for (int i = 0; i < ITEM_PER_THREAD; i++)
{
if (expertIds[i] < 0 || expertIds[i] >= metaInfo.expertCount)
{
expertIds[i] = metaInfo.expertCount;
}
}
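// The condition below holds only on the first loop iteration and is uniform
// across the block, so this one-time barrier safely ensures that
// sharedGlobalSlotIdsInfo is fully populated before it is read.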
if (blockOffset == blockIdx.x * THREAD_COUNT * ITEM_PER_THREAD)
{
__syncthreads();
}
#pragma unroll
for (int i = 0; i < ITEM_PER_THREAD; i++)
{
slotIds[i] = expertIds[i] < metaInfo.expertCount ? sharedGlobalSlotIdsInfo[expertIds[i]]
: metaInfo.epSize * metaInfo.slotCountPerRank;
}
#pragma unroll
for (int i = 0; i < ITEM_PER_THREAD; i++)
{
int tokenIdx = tokenIdxBase + i * THREAD_COUNT;
if (tokenIdx < tokenCount * metaInfo.topK)
{
tokenRoutedSlotIds[tokenIdx] = slotIds[i];
}
}
}
}
template <int MAX_EXPERT_COUNT = 1024, int THREAD_COUNT = 256, int ITEM_PER_THREAD = 4>
__global__ void moeComputeRouteKernel(MoeLoadBalanceMetaInfo metaInfo, MoePlacementInfo placementInfo,
int* const tokenSelectedExperts, int* tokenRoutedSlotIds, int tokenCount, bool offsetByEpRank)
{
int warpId = threadIdx.x / 32;
int laneId = threadIdx.x % 32;
static int const kWarpCount = THREAD_COUNT / 32;
extern __shared__ int16_t sharedGlobalSlotIdsInfo[];
__shared__ int sharedExpertReplicaCountAndStartOffset[MAX_EXPERT_COUNT];
__shared__ int sharedArbitrateExpertId[THREAD_COUNT * ITEM_PER_THREAD];
__shared__ int sharedExpertCount[MAX_EXPERT_COUNT];
for (int expertIdx = threadIdx.x; expertIdx < metaInfo.expertCount; expertIdx += THREAD_COUNT)
{
int replicaCount = placementInfo.expertReplicaCount[expertIdx];
int replicaStartOffset = placementInfo.expertReplicaStartOffset[expertIdx];
sharedExpertReplicaCountAndStartOffset[expertIdx] = (replicaCount << 16) | replicaStartOffset;
sharedExpertCount[expertIdx] = 0;
}
for (int slotId = threadIdx.x; slotId < metaInfo.epSize * metaInfo.slotCountPerRank; slotId += THREAD_COUNT)
{
sharedGlobalSlotIdsInfo[slotId] = placementInfo.globalSlotIds[slotId];
}
int expertIds[ITEM_PER_THREAD];
int tokenIdxBase = blockIdx.x * THREAD_COUNT * ITEM_PER_THREAD + threadIdx.x;
#pragma unroll
for (int i = 0; i < ITEM_PER_THREAD; i++)
{
int tokenIdx = tokenIdxBase + i * THREAD_COUNT;
expertIds[i] = tokenIdx < tokenCount * metaInfo.topK ? tokenSelectedExperts[tokenIdx] : metaInfo.expertCount;
}
#pragma unroll
for (int i = 0; i < ITEM_PER_THREAD; i++)
{
if (expertIds[i] < 0 || expertIds[i] >= metaInfo.expertCount)
{
expertIds[i] = metaInfo.expertCount;
}
}
__syncthreads();
#pragma unroll
for (int i = 0; i < ITEM_PER_THREAD; i++)
{
int countAndStart
= expertIds[i] < metaInfo.expertCount ? sharedExpertReplicaCountAndStartOffset[expertIds[i]] : (1 << 16);
int arbitrateExpertId = (countAndStart >> 16) > 1 ? expertIds[i] : metaInfo.expertCount;
sharedArbitrateExpertId[threadIdx.x + i * THREAD_COUNT] = arbitrateExpertId;
}
__syncthreads();
int baseOffset = blockIdx.x + (offsetByEpRank ? metaInfo.epRank : 0);
if (warpId == 0)
{
#pragma unroll
for (int i = 0; i < kWarpCount * ITEM_PER_THREAD; ++i)
{
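// Lanes that picked the same expert are grouped with __match_any_sync; the
// lowest lane of each group (the leader) reserves a contiguous range of
// arbitration offsets via atomicAdd_block, and each lane then derives its own
// offset from the number of lower-indexed lanes in its match mask.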
int expertId = sharedArbitrateExpertId[laneId + i * 32];
__syncwarp();
unsigned match = __match_any_sync(0xFFFFFFFF, expertId);
int leader = __ffs(match) - 1;
int matchCount = __popc(match);
int oldVal = 0;
if (laneId == leader && expertId < metaInfo.expertCount)
{
oldVal = atomicAdd_block(&sharedExpertCount[expertId], matchCount);
}
__syncwarp();
oldVal = __shfl_sync(0XFFFFFFFF, oldVal, leader);
unsigned lowerMask = match & ((1u << laneId) - 1);
int rankInGroup = __popc(lowerMask);
int offset = oldVal + rankInGroup;
offset += baseOffset;
sharedArbitrateExpertId[laneId + i * 32] = offset;
}
}
__syncthreads();
int targetGlobalSlotId[ITEM_PER_THREAD];
#pragma unroll
for (int i = 0; i < ITEM_PER_THREAD; i++)
{
int countAndStart
= expertIds[i] < metaInfo.expertCount ? sharedExpertReplicaCountAndStartOffset[expertIds[i]] : (1 << 16);
int count = countAndStart >> 16;
int offset = countAndStart & 0xFFFF;
int arbitratedIndex = sharedArbitrateExpertId[threadIdx.x + i * THREAD_COUNT];
offset += arbitratedIndex % count;
targetGlobalSlotId[i] = expertIds[i] < metaInfo.expertCount ? sharedGlobalSlotIdsInfo[offset]
: metaInfo.epSize * metaInfo.slotCountPerRank;
}
#pragma unroll
for (int i = 0; i < ITEM_PER_THREAD; i++)
{
int tokenIdx = tokenIdxBase + i * THREAD_COUNT;
if (tokenIdx < tokenCount * metaInfo.topK)
{
tokenRoutedSlotIds[tokenIdx] = targetGlobalSlotId[i];
}
}
}
template <int MAX_EXPERT_COUNT = 1024, int THREAD_COUNT = 256, int ITEM_PER_THREAD = 4>
__global__ void moeComputeRouteSortKernel(MoeLoadBalanceMetaInfo metaInfo, MoePlacementInfo placementInfo,
int* const tokenSelectedExperts, int* tokenRoutedSlotIds, int tokenCount, bool offsetByEpRank)
{
using BlockSort = cub::BlockRadixSort<int, THREAD_COUNT, 1>;
extern __shared__ int sharedGlobalSlotIdsInfo[];
extern __shared__ int16_t sharedGlobalSlotIdsInfo[];
__shared__ typename BlockSort::TempStorage tempStorage;
@ -361,9 +515,19 @@ void moeComputeRouteDevice(MoeLoadBalanceMetaInfo metaInfo, MoePlacementInfo pla
constexpr int kThreadCount = 256;
constexpr int kEltPerThread = 4;
int blockCount = (tokenCount * metaInfo.topK + kThreadCount * kEltPerThread - 1) / (kThreadCount * kEltPerThread);
int dynamicShmSize = sizeof(int) * metaInfo.epSize * metaInfo.slotCountPerRank;
moeComputeRouteKernel<1024, kThreadCount, kEltPerThread><<<blockCount, kThreadCount, dynamicShmSize, stream>>>(
metaInfo, placementInfo, tokenSelectedExperts, tokenRoutedSlotIds, tokenCount, offsetByEpRank);
int dynamicShmSize = sizeof(int16_t) * metaInfo.epSize * metaInfo.slotCountPerRank;
if (metaInfo.expertCount == metaInfo.epSize * metaInfo.slotCountPerRank)
{
// No redundant experts, so no complex routing is needed; just assign each token to its expert's slot.
moeComputeRouteNoRedundantKernel<1024, kThreadCount, kEltPerThread>
<<<blockCount, kThreadCount, dynamicShmSize, stream>>>(
metaInfo, placementInfo, tokenSelectedExperts, tokenRoutedSlotIds, tokenCount);
}
else
{
moeComputeRouteKernel<1024, kThreadCount, kEltPerThread><<<blockCount, kThreadCount, dynamicShmSize, stream>>>(
metaInfo, placementInfo, tokenSelectedExperts, tokenRoutedSlotIds, tokenCount, offsetByEpRank);
}
}
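The dispatch above picks the cheaper no-redundant kernel only when every expert owns exactly one slot. A minimal arithmetic sketch of that condition, with illustrative numbers rather than values from this diff:

expert_count = 256
ep_size, slot_count_per_rank = 32, 9                   # 288 slots in total
total_slots = ep_size * slot_count_per_rank
use_no_redundant_kernel = expert_count == total_slots  # False here
redundant_replicas = total_slots - expert_count        # 32 extra replicas need warp-level arbitration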
void moeWaitSignalForCpuStageHost(MoeLoadBalanceSingleLayerSignal* signal)

View File

@ -16,7 +16,7 @@
*/
#include "moeBindings.h"
#include "tensorrt_llm/runtime/moeLoadBalancer.h"
#include "tensorrt_llm/runtime/moeLoadBalancer/moeLoadBalancer.h"
#include <pybind11/functional.h>
#include <pybind11/numpy.h>
#include <pybind11/pybind11.h>
@ -98,6 +98,8 @@ void initMoeBindings(pybind11::module_& m)
py::class_<tr::MoeLoadBalancer>(m, "MoeLoadBalancer")
.def(py::init<int, int, int>(), py::arg("ep_rank"), py::arg("ep_size"), py::arg("layer_updates_per_iter"),
"Initialize the MoeLoadBalancer with the specified expert parallel rank, size, and update frequency")
.def("set_use_gpu_memcpy", &tr::MoeLoadBalancer::setUseGpuMemcpy, py::arg("use_gpu_memcpy"),
"Set whether to use GPU memcpy for weight updates")
.def("add_layer", &tr::MoeLoadBalancer::AddLayer, py::arg("expert_count"), py::arg("top_k"),
py::arg("slot_count_per_rank"), "Add a new MOE layer to the load balancer")
.def("finalize_model", &tr::MoeLoadBalancer::finalizeModel,

View File

@ -43,7 +43,8 @@ set(SRCS
ipcNvlsMemory.cpp
mcastDeviceMemory.cpp
memoryCounters.cpp
moeLoadBalancer.cpp
moeLoadBalancer/moeLoadBalancer.cpp
moeLoadBalancer/topologyDetector.cpp
ncclCommunicator.cpp
promptTuningParams.cpp
runtimeKernels.cu
@ -80,3 +81,27 @@ target_include_directories(runtime_src PRIVATE ${MPI_C_INCLUDE_DIRS})
if(ENABLE_MULTI_DEVICE)
target_link_libraries(runtime_src PUBLIC ${NCCL_LIB})
endif()
if(NOT WIN32)
find_package(libnuma QUIET CONFIG)
if(NOT libnuma_FOUND)
message(
STATUS "libnuma not found via Conan, falling back to system libnuma")
find_path(NUMA_INCLUDE_DIR numa.h)
find_library(NUMA_LIBRARY numa)
if(NUMA_INCLUDE_DIR AND NUMA_LIBRARY)
add_library(libnuma::libnuma UNKNOWN IMPORTED)
set_target_properties(
libnuma::libnuma
PROPERTIES IMPORTED_LOCATION "${NUMA_LIBRARY}"
INTERFACE_INCLUDE_DIRECTORIES "${NUMA_INCLUDE_DIR}")
else()
message(FATAL_ERROR "NUMA library not found, please install libnuma-dev")
endif()
else()
message(STATUS "libnuma found.")
endif()
target_link_libraries(runtime_src PUBLIC libnuma::libnuma)
endif()

View File

@ -14,10 +14,11 @@
* limitations under the License.
*/
#include "tensorrt_llm/runtime/moeLoadBalancer.h"
#include "moeLoadBalancer.h"
#include "tensorrt_llm/common/assert.h"
#include "tensorrt_llm/common/cudaUtils.h"
#include "tensorrt_llm/kernels/moeLoadBalance/moeLoadBalanceKernels.h"
#include "topologyDetector.h"
#include <algorithm>
#include <atomic>
#include <cassert>
@ -33,7 +34,6 @@
namespace tensorrt_llm::runtime
{
// Helper structure to hold replica information
struct ReplicaInfo
{
@ -271,6 +271,7 @@ void allocateStatisticInfo(tensorrt_llm::kernels::MoeLoadBalanceMetaInfo const&
tensorrt_llm::kernels::MoeLoadBalanceStatisticInfo* statisticInfo)
{
TLLM_CUDA_CHECK(cudaMallocHost(&statisticInfo->expertLoadFactor, sizeof(float) * metaInfo.expertCount));
std::fill_n(statisticInfo->expertLoadFactor, metaInfo.expertCount, 0.0f);
TLLM_CHECK_WITH_INFO(statisticInfo->rawDataWindowSize > 0, "statisticInfo->rawDataWindowSize should > 0.");
TLLM_CUDA_CHECK(cudaMalloc(
&statisticInfo->expertTokenCount, sizeof(int) * metaInfo.expertCount * statisticInfo->rawDataWindowSize));
@ -288,9 +289,9 @@ void freeStatisticInfo(tensorrt_llm::kernels::MoeLoadBalanceStatisticInfo* stati
}
void allocatePlacementInfo(tensorrt_llm::kernels::MoeLoadBalanceMetaInfo const& metaInfo,
tensorrt_llm::kernels::MoePlacementInfo* placementInfo, bool isCpu = false)
tensorrt_llm::kernels::MoePlacementInfo* placementInfo, bool isCpu = false, bool useManaged = false)
{
auto allocFn = [isCpu](void** ptr, size_t size)
auto allocFn = [isCpu, useManaged](void** ptr, size_t size)
{
if (isCpu)
{
@ -298,7 +299,21 @@ void allocatePlacementInfo(tensorrt_llm::kernels::MoeLoadBalanceMetaInfo const&
}
else
{
return cudaMalloc(ptr, size);
if (useManaged)
{
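// Managed (UVM) allocation: prefer residency on the current GPU but allow
// direct CPU access, so the CPU copy path can refresh the placement info with
// plain memcpy (see copyPlacementInfoToGpuByCpu).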
TLLM_CUDA_CHECK(cudaMallocManaged(ptr, size));
int cur_dev;
TLLM_CUDA_CHECK(cudaGetDevice(&cur_dev));
TLLM_CUDA_CHECK(cudaMemAdvise(*ptr, size, cudaMemAdviseSetPreferredLocation, cur_dev));
TLLM_CUDA_CHECK(cudaMemAdvise(*ptr, size, cudaMemAdviseSetAccessedBy, cur_dev));
TLLM_CUDA_CHECK(cudaMemAdvise(*ptr, size, cudaMemAdviseSetAccessedBy, cudaCpuDeviceId));
TLLM_CUDA_CHECK(cudaMemset(*ptr, 0, size));
return cudaSuccess;
}
else
{
return cudaMalloc(ptr, size);
}
}
};
TLLM_CUDA_CHECK(
@ -405,7 +420,7 @@ void SingleLayerMoeLoadBalancer::createResources()
}
allocatePlacementInfo(mMetaInfo, &mCpuPlacementInfo.placementInfoForGPU, true);
allocatePlacementInfo(mMetaInfo, &mGpuPlacement, false);
allocatePlacementInfo(mMetaInfo, &mGpuPlacement, false, true);
mSingleLayerSignal = allocateSingleLayerSignal();
TLLM_CUDA_CHECK(cudaEventCreate(&mUpdateWeightsDoneEvent));
@ -451,7 +466,18 @@ void SingleLayerMoeLoadBalancer::maybeStartUpdateWeights()
{
if (mIterId >= 0 && mUpdateWeightsEnabled)
{
mMoeLoadBalancer->addUpdateTask([this] { updateWeightsRoutine(); });
mMoeLoadBalancer->addUpdateTask(
[this]
{
if (mMoeLoadBalancer->mUseGpuMemcpy)
{
updateWeightsRoutine();
}
else
{
updateWeightsRoutineByCpu();
}
});
}
}
@ -461,6 +487,7 @@ void SingleLayerMoeLoadBalancer::waitLastUpdateDone()
{
std::unique_lock<std::mutex> lock(mUpdateWeightsMutex);
mUpdateWeightsCondition.wait(lock, [this] { return mUpdateWeightsDone; });
lock.unlock();
}
}
@ -485,6 +512,25 @@ void SingleLayerMoeLoadBalancer::copyPlacementInfoToGpu()
{
std::fill_n(mCpuPlacementInfo.rankExpertIds[i].begin(), mMetaInfo.slotCountPerRank, -1);
}
// clear expert load factor for the next statistics round
std::fill_n(mStatisticInfo.expertLoadFactor, mMetaInfo.expertCount, 0.0f);
}
void SingleLayerMoeLoadBalancer::copyPlacementInfoToGpuByCpu()
{
memcpy(mGpuPlacement.expertReplicaCount, mCpuPlacementInfo.placementInfoForGPU.expertReplicaCount,
sizeof(int) * mMetaInfo.expertCount);
memcpy(mGpuPlacement.expertReplicaStartOffset, mCpuPlacementInfo.placementInfoForGPU.expertReplicaStartOffset,
sizeof(int) * mMetaInfo.expertCount);
memcpy(mGpuPlacement.globalSlotIds, mCpuPlacementInfo.placementInfoForGPU.globalSlotIds,
sizeof(int) * mMetaInfo.epSize * mMetaInfo.slotCountPerRank);
mCpuPlacementInfo.rankExpertIds.swap(mCpuPlacementInfo.oldRankExpertIds);
for (int i = 0; i < mMetaInfo.epSize; ++i)
{
std::fill_n(mCpuPlacementInfo.rankExpertIds[i].begin(), mMetaInfo.slotCountPerRank, -1);
}
// clear expert load factor for the next statistics round
std::fill_n(mStatisticInfo.expertLoadFactor, mMetaInfo.expertCount, 0.0f);
}
void SingleLayerMoeLoadBalancer::updateWeightsRoutine()
@ -501,6 +547,21 @@ void SingleLayerMoeLoadBalancer::updateWeightsRoutine()
mUpdateWeightsCondition.notify_one();
}
void SingleLayerMoeLoadBalancer::updateWeightsRoutineByCpu()
{
doReplication(mMetaInfo, mStatisticInfo.expertLoadFactor, &mCpuPlacementInfo);
doPlacement(mMetaInfo, mStatisticInfo.expertLoadFactor, &mCpuPlacementInfo);
prepareGpuPlacementInfo(mMetaInfo, &mCpuPlacementInfo);
mLastUpdateTaskId = mMoeLoadBalancer->addCopyTask(
[this](int rank, int size) { mWeightUpdater->updateWeights(&mCpuPlacementInfo, rank, size); });
mMoeLoadBalancer->waitCopyTaskDone(mLastUpdateTaskId);
mLastUpdateTaskId = -1;
copyPlacementInfoToGpuByCpu();
std::unique_lock<std::mutex> lock(mUpdateWeightsMutex);
mUpdateWeightsDone = true;
mUpdateWeightsCondition.notify_one();
}
///////////////////////////////////////////////////////////////////////////////////////////////////
// Weight Updater
///////////////////////////////////////////////////////////////////////////////////////////////////
@ -646,8 +707,61 @@ void HostMemoryMoeWeightUpdater::copyWeights(MoeWeight const& src, MoeWeight con
}
}
void HostMemoryMoeWeightUpdater::updateWeights(tensorrt_llm::runtime::MoePlacementCpuInfo const* placementCpuInfo)
void HostMemoryMoeWeightUpdater::copyWeightsCpu(MoeWeight const& src, MoeWeight const& dst, int rank, int size)
{
TLLM_CHECK(src.mWeightPtr != nullptr && dst.mWeightPtr != nullptr);
TLLM_CHECK(src.mHeight == dst.mHeight && src.mWidth == dst.mWidth);
char* srcPtr = static_cast<char*>(src.mWeightPtr);
char* dstPtr = static_cast<char*>(dst.mWeightPtr);
size_t singleCopySize, copyCount, srcPitch, dstPitch;
if (src.mPitch == src.mWidth && dst.mPitch == dst.mWidth)
{
singleCopySize = src.mWidth * src.mHeight;
copyCount = 1;
srcPitch = singleCopySize;
dstPitch = singleCopySize;
}
else
{
singleCopySize = src.mWidth;
copyCount = src.mHeight;
srcPitch = src.mPitch;
dstPitch = dst.mPitch;
}
size_t fullCopyCount = copyCount / size * size;
size_t threadCopyCount = fullCopyCount / size;
for (size_t i = rank * threadCopyCount; i < (rank + 1) * threadCopyCount; i++)
{
memcpy(dstPtr + i * dstPitch, srcPtr + i * srcPitch, singleCopySize);
}
size_t threadStartOffset = rank * singleCopySize / size;
size_t threadEndOffset = (rank + 1) * singleCopySize / size;
size_t threadCopySize = threadEndOffset - threadStartOffset;
for (size_t i = fullCopyCount; i < copyCount && threadCopySize > 0; i++)
{
memcpy(dstPtr + i * dstPitch + threadStartOffset, srcPtr + i * srcPitch + threadStartOffset, threadCopySize);
}
}
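copyWeightsCpu shards one weight copy across `size` copy threads: whole rows are divided evenly, and any leftover rows are split column-wise into per-rank byte ranges. A small sketch of the same index arithmetic, with illustrative sizes:

copy_count, single_copy_size, size = 10, 4096, 4   # rows, bytes per row, copy threads
full_copy_count = copy_count // size * size        # 8 rows are split evenly, 2 are left over
rows_per_rank = full_copy_count // size            # 2 whole rows per rank
for rank in range(size):
    whole_rows = range(rank * rows_per_rank, (rank + 1) * rows_per_rank)
    byte_lo = rank * single_copy_size // size        # byte slice of each
    byte_hi = (rank + 1) * single_copy_size // size  # leftover row handled by this rank
    print(rank, list(whole_rows), (byte_lo, byte_hi))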
void PrintUpdateInfo(tensorrt_llm::kernels::MoeLoadBalanceMetaInfo metaInfo,
tensorrt_llm::runtime::MoePlacementCpuInfo const* placementCpuInfo)
{
std::stringstream ss;
ss << "[UpdateInfo] rank=" << metaInfo.epRank << ", expert weights=\n [";
for (int slotId = 0; slotId < metaInfo.slotCountPerRank * metaInfo.epSize; slotId++)
{
ss << placementCpuInfo->rankExpertIds[slotId / metaInfo.slotCountPerRank][slotId % metaInfo.slotCountPerRank]
<< ", ";
}
ss << "\n";
fprintf(stderr, "%s\n", ss.str().c_str());
}
void HostMemoryMoeWeightUpdater::updateWeights(
tensorrt_llm::runtime::MoePlacementCpuInfo const* placementCpuInfo, int rank, int size)
{
// PrintUpdateInfo(mMetaInfo, placementCpuInfo);
bool useGpu = mLayerLoadBalancer->mMoeLoadBalancer->mUseGpuMemcpy;
for (int slotId = 0; slotId < mMetaInfo.slotCountPerRank; ++slotId)
{
int oldExpertId = placementCpuInfo->oldRankExpertIds[mMetaInfo.epRank][slotId];
@ -665,7 +779,14 @@ void HostMemoryMoeWeightUpdater::updateWeights(tensorrt_llm::runtime::MoePlaceme
auto& name = slotIt->first;
auto& slotWeight = slotIt->second[slotId];
auto& hostWeight = mHostWeights[name][newExpertId];
copyWeights(hostWeight, slotWeight, mLayerLoadBalancer->getStream());
if (useGpu)
{
copyWeights(hostWeight, slotWeight, mLayerLoadBalancer->getStream());
}
else
{
copyWeightsCpu(hostWeight, slotWeight, rank, size);
}
}
}
}
@ -680,8 +801,40 @@ MoeLoadBalancer::MoeLoadBalancer(int epRank, int epSize, int layerUpdatesPerIter
, mLayerUpdatesPerIter{layerUpdatesPerIter}
{
TLLM_CUDA_CHECK(cudaGetDevice(&mCudaDeviceId));
// create a non-blocking stream for compute and update
// create a non-blocking stream for compute and update; it is no longer needed when the CPU copy engine is used.
TLLM_CUDA_CHECK(cudaStreamCreateWithFlags(&mStream, cudaStreamNonBlocking));
auto& topologyDetector = TopologyDetector::getInstance();
int currentGpuNumaId = topologyDetector.getCurrentGpuNumaId();
int numaCpuCount = topologyDetector.getCurrentGpuNumaCpuCount();
int numaGpuCount = topologyDetector.getGpuCountUnderNuma(currentGpuNumaId);
TLLM_CHECK_WITH_INFO(
numaCpuCount > 0 && numaGpuCount > 0, "numaCpuCount=%d, numaGpuCount=%d", numaCpuCount, numaGpuCount);
int cpuCountPerGpu = std::max(1, numaCpuCount / numaGpuCount);
std::string cpuArch = topologyDetector.getCpuArchitecture();
int numCopyThreads = 8;
if (getenv("TLLM_LOAD_BALANCE_NUM_COPY_THREADS"))
{
int numCopyThreadsFromEnv = atoi(getenv("TLLM_LOAD_BALANCE_NUM_COPY_THREADS"));
if (numCopyThreadsFromEnv > 0)
{
TLLM_LOG_INFO(
"Setting TLLM_LOAD_BALANCE_NUM_COPY_THREADS to %d by environment variable", numCopyThreadsFromEnv);
numCopyThreads = numCopyThreadsFromEnv;
}
}
else
{
if (cpuCountPerGpu > 0)
{
numCopyThreads = std::min(16, std::max(4, cpuCountPerGpu / 2));
TLLM_LOG_INFO("Auto-setting copy threads to %d based on NUMA topology (NUMA node %d, %d CPUs, arch: %s)",
numCopyThreads, currentGpuNumaId, numaCpuCount, cpuArch.c_str());
}
}
mMultiThreadWorker.reset(new MultiThreadWorker(numCopyThreads));
}
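Unless TLLM_LOAD_BALANCE_NUM_COPY_THREADS overrides it, the constructor sizes the CPU copy pool from the detected NUMA topology. The same heuristic, sketched with illustrative counts:

numa_cpu_count, numa_gpu_count = 72, 4                        # CPUs/GPUs on this GPU's NUMA node
cpu_count_per_gpu = max(1, numa_cpu_count // numa_gpu_count)  # 18
num_copy_threads = min(16, max(4, cpu_count_per_gpu // 2))    # 9, clamped to the range [4, 16]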
MoeLoadBalancer::~MoeLoadBalancer() {}
@ -730,6 +883,7 @@ void MoeLoadBalancer::finalizeModel()
}
if (mLayerUpdatesPerIter > 0)
{
mMultiThreadWorker->start();
generateUpdatePlan();
startThreads();
}
@ -783,6 +937,7 @@ void MoeLoadBalancer::shutdown()
mWorkerThread->join();
TLLM_LOG_INFO("MoeLoadBalancer shutdown.");
mMultiThreadWorker->stop();
}
}
@ -831,7 +986,6 @@ void MoeLoadBalancer::workerThread()
}
addUpdateTask(nullptr);
mComputeAndUpdateThread->join();
TLLM_LOG_INFO("MoeLoadBalancer worker thread stopped");
}
void MoeLoadBalancer::computeAndUpdateThread()
@ -850,7 +1004,6 @@ void MoeLoadBalancer::computeAndUpdateThread()
}
task();
}
TLLM_LOG_INFO("MoeLoadBalancer compute and update thread stopped");
}
void MoeLoadBalancer::addUpdateTask(std::function<void()> task)
@ -860,4 +1013,127 @@ void MoeLoadBalancer::addUpdateTask(std::function<void()> task)
mUpdateQueueCondition.notify_one();
}
int64_t MoeLoadBalancer::addCopyTask(std::function<void(int, int)> task)
{
return mMultiThreadWorker->addTask(task);
}
void MoeLoadBalancer::waitCopyTaskDone(int64_t taskId)
{
if (!mUseGpuMemcpy)
{
mMultiThreadWorker->waitTaskDone(taskId);
}
}
MultiThreadWorker::MultiThreadWorker(int numThreads)
: mNumThreads(numThreads)
, mRunning(false)
, mNextTaskId(0)
{
}
MultiThreadWorker::~MultiThreadWorker()
{
stop();
}
void MultiThreadWorker::start()
{
std::lock_guard<std::mutex> lk(mMutex);
if (mRunning)
return;
mRunning = true;
mThreads.reserve(mNumThreads);
for (int i = 0; i < mNumThreads; ++i)
{
mThreads.emplace_back(&MultiThreadWorker::workerLoop, this, i);
}
}
int64_t MultiThreadWorker::addTask(std::function<void(int, int)> func)
{
auto task = std::make_shared<Task>();
{
std::lock_guard<std::mutex> lk(mMutex);
task->id = mNextTaskId++;
task->func = std::move(func);
task->remaining = mNumThreads;
mTasks.push_back(task);
mTaskMap[task->id] = task;
}
mCondition.notify_all();
return task->id;
}
void MultiThreadWorker::waitTaskDone(int64_t taskId)
{
std::unique_lock<std::mutex> lk(mMutex);
auto it = mTaskMap.find(taskId);
if (it == mTaskMap.end())
{
TLLM_CHECK_WITH_INFO(mDoneTaskMap.count(taskId) > 0, "Task %ld not found", taskId);
mDoneTaskMap.erase(taskId);
return;
}
auto task = it->second;
task->cv.wait(lk, [task] { return task->remaining == 0; });
TLLM_CHECK_WITH_INFO(mDoneTaskMap.count(taskId) > 0, "Task %ld not found", taskId);
mDoneTaskMap.erase(taskId);
}
void MultiThreadWorker::stop()
{
{
std::lock_guard<std::mutex> lk(mMutex);
if (!mRunning)
return;
mRunning = false;
}
mCondition.notify_all();
for (auto& t : mThreads)
{
if (t.joinable())
t.join();
}
mThreads.clear();
}
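// Every worker thread executes func(rank, mNumThreads) exactly once per task:
// 'remaining' counts threads still inside the task, and the last thread to
// finish pops it, records it in mDoneTaskMap and wakes both the other workers
// and any waitTaskDone() caller.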
void MultiThreadWorker::workerLoop(int rank)
{
auto& topologyDetector = TopologyDetector::getInstance();
topologyDetector.bindThreadByCurrentGpu(); // use relaxed mode
while (true)
{
std::shared_ptr<Task> task;
{
std::unique_lock<std::mutex> lk(mMutex);
mCondition.wait(lk, [this] { return !mRunning || !mTasks.empty(); });
if (!mRunning && mTasks.empty())
return;
task = mTasks.front();
}
task->func(rank, mNumThreads);
{
std::unique_lock<std::mutex> lk(mMutex);
if (--task->remaining == 0)
{
mTasks.pop_front();
mTaskMap.erase(task->id);
mDoneTaskMap[task->id] = task;
task->cv.notify_all();
}
else
{
task->cv.wait(lk, [task] { return task->remaining == 0; });
}
}
}
}
} // namespace tensorrt_llm::runtime

View File

@ -16,6 +16,7 @@
#pragma once
#include <atomic>
#include <condition_variable>
#include <cstdint>
#include <functional>
@ -81,7 +82,7 @@ public:
void addSingleWeightSlot(int localSlotId, std::string const& name, MoeWeight weightSlot);
virtual void addSingleHostWeight(int expertId, std::string const& name, MoeWeight hostWeight) = 0;
virtual void finalizeWeights();
virtual void updateWeights(MoePlacementCpuInfo const* placementCpuInfo) = 0;
virtual void updateWeights(MoePlacementCpuInfo const* placementCpuInfo, int rank = 0, int size = 1) = 0;
protected:
void finalizeWeightSlot();
@ -102,10 +103,11 @@ public:
void addSingleHostWeight(int expertId, std::string const& name, MoeWeight hostWeight) override;
void finalizeWeights() override;
void updateWeights(MoePlacementCpuInfo const* placementCpuInfo) override;
void updateWeights(MoePlacementCpuInfo const* placementCpuInfo, int rank = 0, int size = 1) override;
private:
static void copyWeights(MoeWeight const& src, MoeWeight const& dst, cudaStream_t stream);
static void copyWeightsCpu(MoeWeight const& src, MoeWeight const& dst, int rank, int size);
void finalizeHostWeight();
bool mHostWeightsFinalized = false;
std::map<std::string, std::vector<MoeWeight>> mHostWeights;
@ -166,6 +168,7 @@ public:
private:
friend class MoeLoadBalancer;
friend class HostMemoryMoeWeightUpdater;
void createResources();
void destroyResources();
@ -187,7 +190,11 @@ private:
bool mUpdateWeightsEnabled = true;
void copyPlacementInfoToGpu();
void copyPlacementInfoToGpuByCpu();
void updateWeightsRoutine();
void updateWeightsRoutineByCpu();
int64_t mLastUpdateTaskId = -1;
cudaEvent_t mUpdateWeightsDoneEvent = nullptr;
tensorrt_llm::kernels::MoeLoadBalanceMetaInfo mMetaInfo;
@ -203,6 +210,42 @@ private:
int mLayerId = -1;
};
class MultiThreadWorker
{
public:
explicit MultiThreadWorker(int numThreads);
~MultiThreadWorker();
void start();
int64_t addTask(std::function<void(int, int)> func);
void waitTaskDone(int64_t taskId);
void stop();
private:
struct Task
{
int64_t id;
std::function<void(int, int)> func;
int remaining;
std::condition_variable cv;
};
void workerLoop(int rank);
int mNumThreads;
std::vector<std::thread> mThreads;
std::mutex mMutex;
std::condition_variable mCondition;
std::deque<std::shared_ptr<Task>> mTasks;
std::unordered_map<int64_t, std::shared_ptr<Task>> mTaskMap;
std::unordered_map<int64_t, std::shared_ptr<Task>> mDoneTaskMap;
bool mRunning;
int64_t mNextTaskId;
};
class MoeLoadBalancer
{
public:
@ -227,8 +270,15 @@ public:
// should bind to python
void shutdown();
// Test interface: use GPU memcpy for weight updates (intended for testing only)
void setUseGpuMemcpy(bool useGpuMemcpy = false)
{
mUseGpuMemcpy = useGpuMemcpy;
}
private:
friend class SingleLayerMoeLoadBalancer;
friend class HostMemoryMoeWeightUpdater;
void startThreads();
@ -247,6 +297,8 @@ private:
std::condition_variable mUpdateQueueCondition;
std::queue<std::function<void()>> mUpdateTaskQueue;
void addUpdateTask(std::function<void()> task);
int64_t addCopyTask(std::function<void(int, int)> task);
void waitCopyTaskDone(int64_t taskId);
std::vector<std::shared_ptr<SingleLayerMoeLoadBalancer>> mLayers;
@ -272,10 +324,14 @@ private:
std::unique_ptr<std::thread> mWorkerThread;
std::unique_ptr<std::thread> mComputeAndUpdateThread;
std::unique_ptr<MultiThreadWorker> mMultiThreadWorker;
// update plan member and function
int mLayerUpdatesPerIter = 1;
std::deque<std::set<int>> mUpdateLayerQueue;
void generateUpdatePlan();
bool mUseGpuMemcpy = false;
};
// functions exposed for testing

View File

@ -0,0 +1,446 @@
/*
* Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "tensorrt_llm/runtime/moeLoadBalancer/topologyDetector.h"
#include "tensorrt_llm/common/assert.h"
#include "tensorrt_llm/common/cudaUtils.h"
#include <algorithm> // For std::for_each, std::sort, std::unique
#include <filesystem>
#include <fstream>
#include <limits> // For std::numeric_limits
#include <map>
#include <mutex>
#include <set>
#include <sstream>
#include <thread>
#include <vector>
#ifdef __linux__
#include <cerrno> // For errno
#include <cstring> // For strerror
#include <numa.h> // For libnuma
#include <numaif.h> // For struct bitmask definition if not in numa.h
#include <pthread.h>
#include <sched.h>
#endif
namespace tensorrt_llm::runtime
{
TopologyDetector::TopologyDetector()
{
std::lock_guard<std::mutex> lock(mDetectionMutex);
if (!mTopologyDetected)
{
detectCpuTopology();
detectGpuTopology();
#ifdef __linux__
if (numa_available() != -1)
{ // Only precompute if libnuma is usable
precomputeCpuAffinityMasks();
}
#endif
mTopologyDetected = true;
}
}
TopologyDetector::~TopologyDetector()
{
#ifdef __linux__
auto free_mask_map = [](std::map<int, struct bitmask*>& mask_map)
{
for (auto const& [id, mask] : mask_map)
{
if (mask)
{
numa_free_cpumask(mask);
}
}
mask_map.clear();
};
free_mask_map(mGpuStrictCpuMasks);
#endif
}
void TopologyDetector::detectCpuTopology()
{
// Detect CPU architecture
#if defined(__x86_64__) || defined(_M_X64)
mCpuArchitecture = "x86_64";
#elif defined(__aarch64__) || defined(_M_ARM64)
mCpuArchitecture = "aarch64";
#elif defined(__powerpc64__)
mCpuArchitecture = "ppc64";
#else
mCpuArchitecture = "unknown";
#endif
// Detect NUMA topology on Linux systems using libnuma
#ifdef __linux__
if (numa_available() == -1)
{
// libnuma not available, fall back to default behavior
TLLM_LOG_WARNING("libnuma not available. Falling back to default CPU topology detection.");
mNumaToCpuCountMap[0] = std::thread::hardware_concurrency();
return;
}
int maxNode = numa_max_node();
if (maxNode < 0)
{
// Failed to get max node, fall back to default behavior
TLLM_LOG_WARNING("Failed to get max NUMA node. Falling back to default CPU topology detection.");
mNumaToCpuCountMap[0] = std::thread::hardware_concurrency();
return;
}
mNumaToCpuCountMap.clear(); // Clear before re-populating
std::map<int, int> tempNumaToCpuCountMap;
for (int i = 0; i <= maxNode; ++i)
{
struct bitmask* cpus = numa_allocate_cpumask();
if (!cpus)
{
TLLM_LOG_WARNING("Failed to allocate cpumask for NUMA node query. Skipping node %d.", i);
continue; // Skip to the next node if allocation fails
}
// Attempt to get CPUs for node i. If numa_node_to_cpus returns 0, it's successful.
if (numa_node_to_cpus(i, cpus) == 0)
{
int cpuCount = 0;
for (int cpu_idx = 0; cpu_idx < numa_num_possible_cpus(); ++cpu_idx)
{
if (numa_bitmask_isbitset(cpus, cpu_idx))
{
cpuCount++;
}
}
if (cpuCount > 0)
{ // Only add NUMA nodes with actual CPUs
tempNumaToCpuCountMap[i] = cpuCount;
}
}
// If numa_node_to_cpus failed (returned -1), node 'i' might be invalid or an error occurred.
// In this case, we simply don't add it to our map, effectively skipping it.
numa_free_cpumask(cpus); // Always free the allocated mask
}
mNumaToCpuCountMap = tempNumaToCpuCountMap;
if (mNumaToCpuCountMap.empty())
{
// If no NUMA nodes with CPUs were detected (e.g. libnuma error or unusual configuration),
// default to a single NUMA node with all hardware concurrency.
TLLM_LOG_WARNING(
"No NUMA nodes with CPUs detected via libnuma, or libnuma error. Defaulting to single NUMA node.");
mNumaToCpuCountMap[0] = std::thread::hardware_concurrency();
}
#else
// For non-Linux systems, assume a single NUMA node
mNumaToCpuCountMap[0] = std::thread::hardware_concurrency();
#endif
}
void TopologyDetector::detectGpuTopology()
{
int deviceCount = 0;
cudaError_t result = cudaGetDeviceCount(&deviceCount);
if (result != cudaSuccess || deviceCount == 0)
{
return;
}
mGpuToNumaMap.clear(); // Clear before re-populating
mNumaToGpuMap.clear(); // Clear before re-populating
for (int deviceId = 0; deviceId < deviceCount; ++deviceId)
{
int numaNode = 0; // Default NUMA node
#ifdef __linux__
if (numa_available() != -1)
{
char pciPath[256];
cudaDeviceProp prop;
if (cudaGetDeviceProperties(&prop, deviceId) == cudaSuccess)
{
// Construct PCI path to find NUMA node
snprintf(pciPath, sizeof(pciPath), "/sys/bus/pci/devices/%04x:%02x:%02x.0/numa_node", prop.pciDomainID,
prop.pciBusID, prop.pciDeviceID);
std::ifstream numaFile(pciPath);
if (numaFile.is_open())
{
numaFile >> numaNode;
numaFile.close();
// If NUMA node is -1, it means no specific NUMA information, use node 0
if (numaNode < 0)
{
numaNode = 0;
}
}
else
{
// Fallback if sysfs path is not available or readable
TLLM_LOG_DEBUG("Could not open %s to determine NUMA node for GPU %d. Defaulting to node 0.",
pciPath, deviceId);
numaNode = 0;
}
TLLM_LOG_INFO("GPU %d is on NUMA node %d", deviceId, numaNode);
}
else
{
TLLM_LOG_WARNING("Failed to get properties for GPU %d. Defaulting to NUMA node 0.", deviceId);
numaNode = 0;
}
}
else
{
// libnuma not available, default GPU to NUMA node 0
numaNode = 0;
}
#endif
mGpuToNumaMap[deviceId] = numaNode;
mNumaToGpuMap[numaNode].push_back(deviceId);
}
}
#ifdef __linux__
static void bitmask_copy_manual(struct bitmask* dst, const struct bitmask* src)
{
if (!dst || !src)
return;
numa_bitmask_clearall(dst);
for (int i = 0; i < numa_num_possible_cpus(); ++i)
{
if (numa_bitmask_isbitset(src, i))
{
numa_bitmask_setbit(dst, i);
}
}
}
static void bitmask_or_manual(struct bitmask* dst, const struct bitmask* src)
{
if (!dst || !src)
return;
for (int i = 0; i < numa_num_possible_cpus(); ++i)
{
if (numa_bitmask_isbitset(src, i))
{
numa_bitmask_setbit(dst, i);
}
}
}
void TopologyDetector::precomputeCpuAffinityMasks()
{
int num_gpus = 0;
cudaError_t err = cudaGetDeviceCount(&num_gpus);
if (err != cudaSuccess || num_gpus == 0)
{
return;
}
for (int gpuId = 0; gpuId < num_gpus; ++gpuId)
{
auto itGpuNuma = mGpuToNumaMap.find(gpuId);
if (itGpuNuma == mGpuToNumaMap.end())
{
TLLM_LOG_WARNING("GPU %d not found in mGpuToNumaMap during mask precomputation. Skipping.", gpuId);
continue;
}
int gpuNumaNode = itGpuNuma->second;
// Strict Mask: CPUs on the GPU's direct NUMA node
struct bitmask* strictMask = numa_allocate_cpumask(); // Uses numa_bitmask_alloc internally
if (strictMask)
{
numa_bitmask_clearall(strictMask); // Initialize to empty
if (mNumaToCpuCountMap.count(gpuNumaNode) && mNumaToCpuCountMap.at(gpuNumaNode) > 0)
{
if (numa_node_to_cpus(gpuNumaNode, strictMask) != 0)
{
TLLM_LOG_WARNING(
"Failed to get CPUs for GPU %d's NUMA node %d for strict mask. Strict mask will be empty.",
gpuId, gpuNumaNode);
numa_bitmask_clearall(strictMask); // Ensure it's empty on failure
}
}
mGpuStrictCpuMasks[gpuId] = strictMask;
}
else
{
TLLM_LOG_WARNING("Failed to allocate strict CPU mask for GPU %d.", gpuId);
}
}
}
const struct bitmask* TopologyDetector::getStrictCpuMaskForGpu(int gpuId) const
{
auto it = mGpuStrictCpuMasks.find(gpuId);
if (it != mGpuStrictCpuMasks.end())
{
return it->second;
}
return nullptr;
}
#endif
void TopologyDetector::bindThreadByCurrentGpu()
{
#ifdef __linux__
if (numa_available() == -1)
{
TLLM_LOG_WARNING("libnuma not available. Cannot bind thread to NUMA node.");
return;
}
int currentDevice = -1;
if (cudaGetDevice(&currentDevice) != cudaSuccess)
{
TLLM_LOG_WARNING("Failed to get current CUDA device. Cannot bind thread.");
return;
}
const struct bitmask* targetMask = nullptr;
targetMask = getStrictCpuMaskForGpu(currentDevice);
if (targetMask)
{
// Check if the mask is not all clear before attempting to set affinity
bool maskIsClear = true;
for (int k = 0; k < numa_num_possible_cpus(); ++k)
{
if (numa_bitmask_isbitset(targetMask, k))
{
maskIsClear = false;
break;
}
}
if (!maskIsClear)
{
// Create a mutable copy of the targetMask to pass to numa_sched_setaffinity
struct bitmask* mutableCopyForAffinity = numa_allocate_cpumask();
if (mutableCopyForAffinity)
{
bitmask_copy_manual(mutableCopyForAffinity, targetMask);
if (numa_sched_setaffinity(0, mutableCopyForAffinity) == -1)
{ // 0 refers to the current thread
TLLM_LOG_WARNING("Failed to set thread affinity for GPU %d using precomputed mask. Error: %s",
currentDevice, strerror(errno));
}
numa_free_cpumask(mutableCopyForAffinity);
}
else
{
TLLM_LOG_WARNING(
"Failed to allocate temporary bitmask for setting affinity. Cannot bind thread for GPU %d.",
currentDevice);
}
}
else
{
TLLM_LOG_DEBUG("Target affinity mask for GPU %d is empty. Not setting affinity.", currentDevice);
}
}
else
{
TLLM_LOG_WARNING("Precomputed CPU affinity mask not found for GPU %d. Cannot bind thread.", currentDevice);
}
#else
TLLM_LOG_DEBUG("Thread binding by GPU NUMA node is only supported on Linux with libnuma.");
#endif
}
int TopologyDetector::getCurrentGpuNumaCpuCount()
{
int numaId = getCurrentGpuNumaId();
if (numaId >= 0)
{
auto it = mNumaToCpuCountMap.find(numaId);
if (it != mNumaToCpuCountMap.end())
{
return it->second;
}
}
TLLM_LOG_DEBUG(
"CPU count for GPU's NUMA node %d not found or node invalid. Returning total hardware concurrency.", numaId);
return std::thread::hardware_concurrency();
}
int TopologyDetector::getCurrentGpuNumaId()
{
int currentDevice = -1;
if (cudaGetDevice(&currentDevice) != cudaSuccess)
{
return -1; // Indicate error or no CUDA device context
}
auto it = mGpuToNumaMap.find(currentDevice);
if (it != mGpuToNumaMap.end())
{
return it->second;
}
TLLM_LOG_WARNING("NUMA node for current GPU %d not found in map. Defaulting to node 0.", currentDevice);
return 0;
}
int TopologyDetector::getGpuCountUnderNuma(int numaId)
{
auto it = mNumaToGpuMap.find(numaId);
if (it != mNumaToGpuMap.end())
{
return it->second.size();
}
return 0;
}
std::string TopologyDetector::getCpuArchitecture()
{
return mCpuArchitecture;
}
bool TopologyDetector::canSupportHostNativeAtomics()
{
int currentDevice = -1;
if (cudaGetDevice(&currentDevice) != cudaSuccess)
{
TLLM_LOG_WARNING("Failed to get current CUDA device for atomic support check.");
return false;
}
int hostNativeAtomicSupported = 0;
cudaError_t err
= cudaDeviceGetAttribute(&hostNativeAtomicSupported, cudaDevAttrHostNativeAtomicSupported, currentDevice);
if (err != cudaSuccess)
{
TLLM_LOG_WARNING("Failed to get cudaDevAttrHostNativeAtomicSupported for device %d. Error: %s", currentDevice,
cudaGetErrorString(err));
return false;
}
return static_cast<bool>(hostNativeAtomicSupported);
}
} // namespace tensorrt_llm::runtime

View File

@ -0,0 +1,100 @@
/*
* Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include <map>
#include <mutex>
#include <string>
#include <vector>
#ifdef __linux__
#include <numa.h> // For libnuma
#endif
// Forward declaration for struct bitmask to avoid including numaif.h if numa.h already covers it,
// or if only numa.h is intended to be the public include for this header's users.
#ifdef __linux__
struct bitmask;
#endif
namespace tensorrt_llm::runtime
{
class TopologyDetector
{
public:
static TopologyDetector& getInstance()
{
static TopologyDetector instance;
return instance;
}
~TopologyDetector();
// Binds the current thread to the CPU cores of the NUMA node associated with the current GPU.
void bindThreadByCurrentGpu();
// Returns the number of CPU cores on the NUMA node associated with the current GPU.
// Returns total hardware concurrency as a fallback if specific count cannot be determined.
int getCurrentGpuNumaCpuCount();
// Returns the ID of the NUMA node associated with the current GPU.
// Returns 0 as a default or -1 on error.
int getCurrentGpuNumaId();
// Returns the number of GPUs associated with the given NUMA node ID.
int getGpuCountUnderNuma(int numaId);
// Returns the number of GPUs that share the current GPU's NUMA node.
int getGpuCountUnderSameNuma()
{
return getGpuCountUnderNuma(getCurrentGpuNumaId());
}
// Returns the detected CPU architecture (e.g., "x86_64", "aarch64").
std::string getCpuArchitecture();
// Checks if the current CUDA device and host system support native atomic operations.
bool canSupportHostNativeAtomics();
#ifdef __linux__
// Getters for precomputed CPU affinity masks
const struct bitmask* getStrictCpuMaskForGpu(int gpuId) const;
#endif
private:
TopologyDetector();
void detectCpuTopology(); // Detects CPU NUMA topology and CPU counts per node.
void detectGpuTopology(); // Detects GPU to NUMA node mapping.
#ifdef __linux__
void precomputeCpuAffinityMasks(); // Precomputes CPU masks for each GPU
#endif
// Member variables
std::map<int, int> mGpuToNumaMap; // GPU ID -> NUMA Node ID
std::map<int, std::vector<int>> mNumaToGpuMap; // NUMA Node ID -> List of GPU IDs
std::map<int, int> mNumaToCpuCountMap; // NUMA Node ID -> CPU Core Count
std::string mCpuArchitecture;
bool mTopologyDetected = false;
std::mutex mDetectionMutex; // Mutex to protect topology detection process
#ifdef __linux__
// Precomputed CPU affinity masks
std::map<int, struct bitmask*> mGpuStrictCpuMasks; // GPU ID -> Strict CPU mask
#endif
};
} // namespace tensorrt_llm::runtime

View File

@ -20,12 +20,14 @@
#include "tensorrt_llm/runtime/torchUtils.h"
#include "tensorrt_llm/thop/thUtils.h"
#include <c10/core/Allocator.h> // for c10::DataPtr
#include <c10/core/StorageImpl.h> // for c10::StorageImpl and use_byte_size_t()
#include <c10/cuda/CUDAStream.h>
#include <torch/extension.h>
#include <c10/util/intrusive_ptr.h> // for c10::make_intrusive
#include <torch/extension.h>
#include <vector>
#include "tensorrt_llm/kernels/moeLoadBalance/moeLoadBalanceKernels.h"
#include "tensorrt_llm/runtime/moeLoadBalancer.h"
#include "tensorrt_llm/runtime/moeLoadBalancer/moeLoadBalancer.h"
namespace torch_ext
{
@ -109,6 +111,43 @@ torch::Tensor moeLoadBalanceRouting(
return tokenRoutedSlotIds;
}
void migrateToManaged(at::Tensor& tensor)
{
TORCH_CHECK(tensor.device().is_cuda(), "only support CUDA Tensor");
// 1) compute total bytes
size_t byte_size = tensor.numel() * tensor.element_size();
// 2) allocate UVM
void* managed_ptr = nullptr;
cudaError_t err = cudaMallocManaged(&managed_ptr, byte_size);
TORCH_CHECK(err == cudaSuccess, "cudaMallocManaged failed");
// 3) advise to place on current GPU
int cur_dev;
TLLM_CUDA_CHECK(cudaGetDevice(&cur_dev));
TLLM_CUDA_CHECK(cudaMemAdvise(managed_ptr, byte_size, cudaMemAdviseSetPreferredLocation, cur_dev));
TLLM_CUDA_CHECK(cudaMemAdvise(managed_ptr, byte_size, cudaMemAdviseSetAccessedBy, cur_dev));
TLLM_CUDA_CHECK(cudaMemAdvise(managed_ptr, byte_size, cudaMemAdviseSetAccessedBy, cudaCpuDeviceId));
// 4) copy old data to UVM
TLLM_CUDA_CHECK(cudaMemcpy(managed_ptr, tensor.data_ptr(), byte_size, cudaMemcpyDeviceToDevice));
// 5) use new DataPtr/StorageImpl to construct storage
// here managed_ptr is both the data pointer and the context; use cudaFree as the deleter
c10::DataPtr dp(
managed_ptr, managed_ptr, [](void* ptr) { cudaFree(ptr); }, tensor.device());
auto allocator = c10::GetAllocator(tensor.device().type());
auto storage_impl = c10::make_intrusive<c10::StorageImpl>(c10::StorageImpl::use_byte_size_t(), byte_size,
std::move(dp), allocator,
/*resizable=*/false);
at::Storage new_storage(storage_impl);
// Finally, replace the tensor's storage; storage_offset = 0, shape and stride kept unchanged
tensor.set_(new_storage,
/*storage_offset=*/0, tensor.sizes().vec(), tensor.strides().vec());
}
} // namespace torch_ext
TORCH_LIBRARY_FRAGMENT(trtllm, m)
@ -154,3 +193,13 @@ TORCH_LIBRARY_IMPL(trtllm, CUDA, m)
{
m.impl("moe_load_balance_routing", &torch_ext::moeLoadBalanceRouting);
}
TORCH_LIBRARY_FRAGMENT(trtllm, m)
{
m.def("migrate_to_managed(Tensor tensor) -> ()");
}
TORCH_LIBRARY_IMPL(trtllm, CUDA, m)
{
m.impl("migrate_to_managed", &torch_ext::migrateToManaged);
}
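The fragment above registers migrate_to_managed, which rebinds a CUDA tensor's storage to cudaMallocManaged memory in place. A hedged usage sketch, assuming the TensorRT-LLM extension that registers the op has already been imported:

import torch

w = torch.empty(1024, 1024, device="cuda")   # ordinary cudaMalloc-backed tensor
torch.ops.trtllm.migrate_to_managed(w)       # storage now lives in managed (UVM) memory
# shape, stride and values are preserved; the underlying pages are now CPU-accessible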

View File

@ -18,7 +18,7 @@
#include "tensorrt_llm/common/cudaUtils.h"
#include "tensorrt_llm/kernels/moeLoadBalance/moeLoadBalanceKernels.h"
#include "tensorrt_llm/runtime/moeLoadBalancer.h"
#include "tensorrt_llm/runtime/moeLoadBalancer/moeLoadBalancer.h"
using namespace tensorrt_llm::runtime;
@ -318,6 +318,8 @@ protected:
TLLM_CUDA_CHECK(cudaSetDevice(0));
mLoadBalancer = std::make_unique<MoeLoadBalancer>(param.epRank, param.epSize, param.layerUpdatesPerIter);
mLoadBalancer->setUseGpuMemcpy(true);
// Create multiple MoE layers
createLayers(param);

View File

@ -53,6 +53,8 @@ init_ubuntu() {
llvm \
libclang-rt-dev \
libffi-dev \
libnuma1 \
libnuma-dev \
python3-dev \
python3-pip \
python-is-python3 \
@ -88,6 +90,8 @@ install_python_rockylinux() {
llvm-toolset \
lld \
libffi-devel \
numactl \
numactl-devel \
zlib-devel \
xz-devel \
sqlite-devel \

View File

@ -28,10 +28,10 @@ UPLOAD_PATH = env.uploadPath ? env.uploadPath : "sw-tensorrt-generic/llm-artifac
// Container configuration
// available tags can be found in: https://urm.nvidia.com/artifactory/sw-tensorrt-docker/tensorrt-llm/
// [base_image_name]-[arch]-[os](-[python_version])-[trt_version]-[torch_install_type]-[stage]-[date]-[mr_id]
LLM_DOCKER_IMAGE = "urm.nvidia.com/sw-tensorrt-docker/tensorrt-llm:pytorch-25.04-py3-x86_64-ubuntu24.04-trt10.10.0.31-skip-tritondevel-202505211401-4539"
LLM_SBSA_DOCKER_IMAGE = "urm.nvidia.com/sw-tensorrt-docker/tensorrt-llm:pytorch-25.04-py3-aarch64-ubuntu24.04-trt10.10.0.31-skip-tritondevel-202505211401-4539"
LLM_ROCKYLINUX8_PY310_DOCKER_IMAGE = "urm.nvidia.com/sw-tensorrt-docker/tensorrt-llm:cuda-12.9.0-devel-rocky8-x86_64-rocky8-py310-trt10.10.0.31-skip-tritondevel-202505211401-4539"
LLM_ROCKYLINUX8_PY312_DOCKER_IMAGE = "urm.nvidia.com/sw-tensorrt-docker/tensorrt-llm:cuda-12.9.0-devel-rocky8-x86_64-rocky8-py312-trt10.10.0.31-skip-tritondevel-202505211401-4539"
LLM_DOCKER_IMAGE = "urm.nvidia.com/sw-tensorrt-docker/tensorrt-llm:pytorch-25.04-py3-x86_64-ubuntu24.04-trt10.10.0.31-skip-tritondevel-202506021004-9420"
LLM_SBSA_DOCKER_IMAGE = "urm.nvidia.com/sw-tensorrt-docker/tensorrt-llm:pytorch-25.04-py3-aarch64-ubuntu24.04-trt10.10.0.31-skip-tritondevel-202506021004-9420"
LLM_ROCKYLINUX8_PY310_DOCKER_IMAGE = "urm.nvidia.com/sw-tensorrt-docker/tensorrt-llm:cuda-12.9.0-devel-rocky8-x86_64-rocky8-py310-trt10.10.0.31-skip-tritondevel-202506021004-9420"
LLM_ROCKYLINUX8_PY312_DOCKER_IMAGE = "urm.nvidia.com/sw-tensorrt-docker/tensorrt-llm:cuda-12.9.0-devel-rocky8-x86_64-rocky8-py312-trt10.10.0.31-skip-tritondevel-202506021004-9420"
// TODO: Move common variables to an unified location
BUILD_CORES_REQUEST = "8"

View File

@ -1,7 +1,7 @@
import java.lang.InterruptedException
DOCKER_IMAGE = "urm.nvidia.com/sw-tensorrt-docker/tensorrt-llm:pytorch-25.04-py3-x86_64-ubuntu24.04-trt10.10.0.31-skip-tritondevel-202505211401-4539"
DOCKER_IMAGE = "urm.nvidia.com/sw-tensorrt-docker/tensorrt-llm:pytorch-25.04-py3-x86_64-ubuntu24.04-trt10.10.0.31-skip-tritondevel-202506021004-9420"
def createKubernetesPodConfig(image)
{

View File

@ -440,7 +440,7 @@ def main(*,
with working_directory(build_dir):
if clean or first_build or configure_cmake:
build_run(
f"\"{venv_conan}\" install --remote=tensorrt-llm --output-folder={build_dir}/conan -s 'build_type={build_type}' {source_dir}"
f"\"{venv_conan}\" install --build=missing --remote=tensorrt-llm --output-folder={build_dir}/conan -s 'build_type={build_type}' {source_dir}"
)
cmake_def_args.append(
f"-DCMAKE_TOOLCHAIN_FILE={build_dir}/conan/conan_toolchain.cmake"

View File

@ -23,18 +23,13 @@ class MoeLoadBalancerConfig:
repr=False)
layer_updates_per_iter: int = 0
num_experts: Optional[int] = field(default=None, init=False)
ep_rank: Optional[int] = field(default=None, init=False)
ep_size: Optional[int] = field(default=None, init=False)
def setup(self, num_experts: int, ep_rank: int, ep_size: int) -> None:
self.num_experts = num_experts
def setup(self, ep_rank: int, ep_size: int) -> None:
self.ep_rank = ep_rank
self.ep_size = ep_size
if self.num_slots is None:
self.num_slots = self.num_experts
assert self.num_slots >= self.num_experts
assert self.num_slots % self.ep_size == 0
assert self.num_slots is not None
@property
def num_local_slots(self) -> int:
@ -49,17 +44,13 @@ class MoeLoadBalancerConfig:
return self.slot_start + self.num_local_slots
def get_layer_initial_global_assignments(self, layer_idx: int) -> List[int]:
if self.initial_global_assignments is None:
return [(ep_rank * self.num_experts // self.ep_size + i) %
self.num_experts for ep_rank in range(self.ep_size)
for i in range(self.num_local_slots)]
else:
if self.initial_global_assignments is not None:
assert layer_idx in self.initial_global_assignments
assert len(
self.initial_global_assignments[layer_idx]) == self.num_slots
assert set(self.initial_global_assignments[layer_idx]) == set(
range(self.num_experts))
return self.initial_global_assignments[layer_idx]
else:
return None
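After this change the config no longer tracks num_experts: setup() only records the EP rank and size, num_slots must already be set, and get_layer_initial_global_assignments() returns None unless explicit assignments were given. A hedged construction sketch; num_slots is assumed to be a regular constructor field (its declaration sits above this hunk):

cfg = MoeLoadBalancerConfig(num_slots=288, layer_updates_per_iter=1)
cfg.setup(ep_rank=0, ep_size=8)
# num_local_slots is expected to be num_slots // ep_size, i.e. 36 slots hosted per rank here
assignments = cfg.get_layer_initial_global_assignments(layer_idx=0)  # None unless provided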
@dataclass(kw_only=True)

View File

@ -53,7 +53,7 @@ from ..modules.attention import MLA
from ..modules.decoder_layer import DecoderLayer
from ..modules.embedding import Embedding
from ..modules.fused_moe import (CutlassFusedMoE, DeepSeekV3MoeRoutingMethod,
MoeLoadBalancer, create_moe)
create_moe)
from ..modules.gated_mlp import GatedMLP
from ..modules.linear import Linear
from ..modules.multi_stream_utils import maybe_execute_in_parallel
@ -344,7 +344,6 @@ class Deepseekv3MoE(nn.Module):
dtype: Optional[torch.dtype] = None,
model_config: ModelConfig = ModelConfig(),
override_quant_config: Optional[QuantConfig] = None,
moe_load_balancer: Optional[MoeLoadBalancer] = None,
layer_idx: Optional[int] = None):
from ..distributed import AllReduce
@ -379,7 +378,6 @@ class Deepseekv3MoE(nn.Module):
override_quant_config=override_quant_config,
aux_stream=aux_stream_dict[AuxStreamType.MoeChunkingOverlap],
enable_alltoall=self.enable_alltoall,
moe_load_balancer=moe_load_balancer,
layer_idx=layer_idx)
self.mapping = model_config.mapping
@ -542,11 +540,9 @@ class Deepseekv3MoE(nn.Module):
class DeepseekV3DecoderLayer(DecoderLayer):
def __init__(self,
model_config: ModelConfig[PretrainedConfig],
layer_idx: int,
aux_stream_dict: Dict[AuxStreamType, torch.cuda.Stream],
moe_load_balancer: Optional[MoeLoadBalancer] = None):
def __init__(self, model_config: ModelConfig[PretrainedConfig],
layer_idx: int, aux_stream_dict: Dict[AuxStreamType,
torch.cuda.Stream]):
super().__init__()
self.model_config = model_config
config = model_config.pretrained_config
@ -598,7 +594,6 @@ class DeepseekV3DecoderLayer(DecoderLayer):
model_config=model_config,
override_quant_config=quant_config,
aux_stream_dict=aux_stream_dict,
moe_load_balancer=moe_load_balancer,
layer_idx=layer_idx)
else:
block_size = 1
@ -865,13 +860,10 @@ class DeepseekV3DecoderLayer(DecoderLayer):
class DeepseekV3MTP(DeepseekV3DecoderLayer):
def __init__(self,
model_config: ModelConfig[PretrainedConfig],
layer_idx: int,
aux_stream_dict: Dict[AuxStreamType, torch.cuda.Stream],
moe_load_balancer: Optional[MoeLoadBalancer] = None):
super().__init__(model_config, layer_idx, aux_stream_dict,
moe_load_balancer)
def __init__(self, model_config: ModelConfig[PretrainedConfig],
layer_idx: int, aux_stream_dict: Dict[AuxStreamType,
torch.cuda.Stream]):
super().__init__(model_config, layer_idx, aux_stream_dict)
config = model_config.pretrained_config
self.hidden_dim = config.hidden_size
self.moe_intermediate_size = config.moe_intermediate_size
@ -992,23 +984,9 @@ class DeepseekV3Model(DecoderModel):
dtype=config.torch_dtype,
)
self.moe_load_balancer = None
if model_config.moe_load_balancer is not None:
num_experts = config.n_routed_experts
ep_rank = model_config.mapping.moe_ep_rank
ep_size = model_config.mapping.moe_ep_size
model_config.moe_load_balancer.setup(num_experts=num_experts,
ep_rank=ep_rank,
ep_size=ep_size)
self.moe_load_balancer = MoeLoadBalancer(
ep_rank=ep_rank,
ep_size=ep_size,
layer_updates_per_iter=model_config.moe_load_balancer.
layer_updates_per_iter)
self.layers = nn.ModuleList([
DeepseekV3DecoderLayer(model_config, layer_idx,
self.aux_stream_dict, self.moe_load_balancer)
self.aux_stream_dict)
for layer_idx in range(config.num_hidden_layers)
])
self.norm = RMSNorm(hidden_size=config.hidden_size,
@ -1054,7 +1032,6 @@ class DeepseekV3ForCausalLM(DecoderModelForCausalLM[DeepseekV3Model,
hidden_size=model_config.pretrained_config.hidden_size,
vocab_size=model_config.pretrained_config.vocab_size)
self.moe_load_balancer = self.model.moe_load_balancer
self.model_nextn = 0
if model_config.spec_config is not None:
model_nextn = model_config.spec_config.num_nextn_predict_layers
@ -1063,8 +1040,7 @@ class DeepseekV3ForCausalLM(DecoderModelForCausalLM[DeepseekV3Model,
assert ckpt_nextn > 0, "There are no MTP modules in the checkpoint."
if ckpt_nextn == 1:
mtp_layer = DeepseekV3MTP(model_config, self.num_hidden_layers,
self.model.aux_stream_dict,
self.moe_load_balancer)
self.model.aux_stream_dict)
self.model.layers.append(mtp_layer)
self.epilogue.append(mtp_layer)
self.mtp_worker = MTPEagleWorker(model_config.spec_config)
@ -1074,8 +1050,7 @@ class DeepseekV3ForCausalLM(DecoderModelForCausalLM[DeepseekV3Model,
mtp_layers = nn.ModuleList([
DeepseekV3MTP(model_config,
layer_idx + self.num_hidden_layers,
self.model.aux_stream_dict,
self.moe_load_balancer)
self.model.aux_stream_dict)
for layer_idx in range(model_nextn)
])
self.model.layers.extend(mtp_layers)
@ -1100,9 +1075,6 @@ class DeepseekV3ForCausalLM(DecoderModelForCausalLM[DeepseekV3Model,
extend_exclude_modules)
self.epilogue.append(self.mtp_worker)
if self.moe_load_balancer is not None:
self.moe_load_balancer.finalize_model()
def forward(
self,
attn_metadata: AttentionMetadata,

View File

@ -10,7 +10,7 @@ from .fused_moe_cutlass import CutlassFusedMoE
from .fused_moe_trtllm_gen import TRTLLMGenFusedMoE
from .fused_moe_vanilla import VanillaMoE
from .interface import MoE, MoEWeightLoadingMode
from .moe_load_balancer import MoeLoadBalancer
from .moe_load_balancer import get_moe_load_balancer
from .routing import BaseMoeRoutingMethod
@ -53,15 +53,17 @@ def create_moe(
weight_loading_mode: MoEWeightLoadingMode = MoEWeightLoadingMode.VANILLA,
apply_router_weight_on_input: bool = False,
enable_alltoall: bool = False,
moe_load_balancer: Optional[MoeLoadBalancer] = None,
layer_idx: Optional[int] = None,
) -> MoE:
moe_cls = get_moe_cls(model_config, override_quant_config)
moe_load_balancer = get_moe_load_balancer()
if moe_load_balancer is not None:
assert moe_cls == CutlassFusedMoE, "MoE load balancing is currently only supported in CutlassFusedMoE."
if moe_cls == TRTLLMGenFusedMoE:
assert not apply_router_weight_on_input, "apply_router_weight_on_input is not supported in TRTLLMGenFusedMoE."
assert not enable_alltoall, "enable_alltoall is not supported in TRTLLMGenFusedMoE."
assert moe_load_balancer is None, "moe_load_balancer is not supported in TRTLLMGenFusedMoE."
return moe_cls(
routing_method=routing_method,
@ -87,13 +89,11 @@ def create_moe(
weight_loading_mode=weight_loading_mode,
apply_router_weight_on_input=apply_router_weight_on_input,
enable_alltoall=enable_alltoall,
moe_load_balancer=moe_load_balancer,
layer_idx=layer_idx,
)
elif moe_cls == VanillaMoE:
assert not apply_router_weight_on_input, "apply_router_weight_on_input is not supported in VanillaMoE."
assert not enable_alltoall, "enable_alltoall is not supported in VanillaMoE."
assert moe_load_balancer is None, "moe_load_balancer is not supported in VanillaMoE."
return moe_cls(
routing_method=routing_method,

View File

@ -1,5 +1,5 @@
import os
from typing import Dict, List, Optional, Union
from typing import Dict, List, Optional, Tuple, Union
import torch
@ -8,11 +8,11 @@ from tensorrt_llm._utils import logger
from ...distributed import allgather, reducescatter
from ...expert_statistic import ExpertStatistic
from ...model_config import ModelConfig, MoeLoadBalancerConfig
from ...model_config import ModelConfig
from ...utils import (EventType, Fp4QuantizedTensor, disable_fp4_allgather,
reswizzle_sf, swizzle_sf, unswizzle_sf)
from .interface import MoE
from .moe_load_balancer import MoeLoadBalancer
from .moe_load_balancer import get_moe_load_balancer
from .quantization import (FP8BlockScalesFusedMoEMethod, FP8QDQFusedMoEMethod,
MoEWeightLoadingMode, NVFP4CutlassFusedMoEMethod,
UnquantizedFusedMoEMethod, WInt4AFP8FusedMoEMethod)
@ -82,7 +82,6 @@ class CutlassFusedMoE(MoE):
VANILLA,
apply_router_weight_on_input: bool = False,
enable_alltoall: bool = False,
moe_load_balancer: Optional[MoeLoadBalancer] = None,
layer_idx: Optional[int] = None,
):
@ -98,46 +97,56 @@ class CutlassFusedMoE(MoE):
)
self.layer_idx = layer_idx
moe_load_balancer = get_moe_load_balancer()
self.layer_load_balancer = None
moe_load_balancer_config = model_config.moe_load_balancer
if moe_load_balancer_config is None:
assert moe_load_balancer is None
# A dummy MoeLoadBalancerConfig to generate default initial_global_assignments
moe_load_balancer_config = MoeLoadBalancerConfig()
moe_load_balancer_config.setup(num_experts=num_experts,
ep_rank=self.ep_rank,
ep_size=self.ep_size)
else:
assert moe_load_balancer is not None
init_expert_size_per_partition = moe_load_balancer_config.num_local_slots if moe_load_balancer_config else self.num_experts // self.ep_size
self.initial_global_assignments = [
(ep_rank * self.num_experts // self.ep_size + local_slot_id) %
self.num_experts for ep_rank in range(self.ep_size)
for local_slot_id in range(init_expert_size_per_partition)
]
self.num_slots = moe_load_balancer_config.num_slots
if self.smart_router:
assert self.num_slots == self.num_experts, "Smart router should not have redundant slots"
self.initial_global_assignments = moe_load_balancer_config.get_layer_initial_global_assignments(
layer_idx)
self.expert_size_per_partition = moe_load_balancer_config.num_local_slots
self.slot_start = moe_load_balancer_config.slot_start
self.slot_end = moe_load_balancer_config.slot_end
self.initial_local_expert_ids = self.initial_global_assignments[
self.slot_start:self.slot_end]
assert len(
self.initial_local_expert_ids) == self.expert_size_per_partition
self.balancer_layer = None
if moe_load_balancer is not None:
self.balancer_layer = moe_load_balancer.add_layer(
expert_count=num_experts,
top_k=routing_method.experts_per_token,
slot_count_per_rank=self.expert_size_per_partition,
)
self.balancer_layer.set_initial_weight_assignments(
if moe_load_balancer:
assert moe_load_balancer_config is not None
top_k = self.routing_method.experts_per_token
self.expert_size_per_partition = moe_load_balancer_config.num_local_slots
self.layer_load_balancer = moe_load_balancer.add_layer(
self.num_experts, top_k, self.expert_size_per_partition)
loaded_initial_global_assignments = moe_load_balancer_config.get_layer_initial_global_assignments(
self.layer_idx)
self.num_slots = moe_load_balancer_config.num_slots
if loaded_initial_global_assignments is not None:
assert isinstance(loaded_initial_global_assignments, list)
assert len(loaded_initial_global_assignments) == self.num_slots
assert self.num_slots >= self.num_experts
assert set(loaded_initial_global_assignments) == set(
range(self.num_experts))
self.initial_global_assignments = loaded_initial_global_assignments
self.layer_load_balancer.set_initial_weight_assignments(
self.initial_global_assignments)
logger.info(
f"MoE load balancer enabled. num_experts = {num_experts}, num_slots = {self.num_slots}, ep_size = {self.ep_size}"
)
logger.info(
f"initial_global_assignments (layer {layer_idx}) = {self.initial_global_assignments}"
f"initial_global_assignments (layer {self.layer_idx}) = {self.initial_global_assignments}"
)
else:
assert num_experts % self.ep_size == 0
self.expert_size_per_partition = num_experts // self.ep_size
self.num_slots = num_experts
if self.smart_router:
assert self.num_slots == self.num_experts, "Smart router should not have redundant slots"
self.slot_start = self.ep_rank * self.expert_size_per_partition
self.slot_end = self.slot_start + self.expert_size_per_partition
self.initial_local_expert_ids = self.initial_global_assignments[
self.slot_start:self.slot_end]
assert len(
self.initial_local_expert_ids) == self.expert_size_per_partition
max_num_tokens = model_config.max_num_tokens
# The maximum number of tokens in MoE is multiplied by the DP size when attention DP is enabled
@ -259,13 +268,14 @@ class CutlassFusedMoE(MoE):
return outputs
def forward_chunk(
self,
x: Union[torch.Tensor, Fp4QuantizedTensor],
router_logits: torch.Tensor,
cutlass_min_latency_mode: bool = False,
output_dtype: Optional[torch.dtype] = None,
all_rank_num_tokens: Optional[List[int]] = None,
use_dp_padding: Optional[bool] = None,
self,
x: Union[torch.Tensor, Fp4QuantizedTensor],
router_logits: torch.Tensor,
cutlass_min_latency_mode: bool = False,
output_dtype: Optional[torch.dtype] = None,
all_rank_num_tokens: Optional[List[int]] = None,
use_dp_padding: Optional[bool] = None,
repeating_info: Tuple = (True, True),
) -> torch.Tensor:
if isinstance(x, Fp4QuantizedTensor):
assert output_dtype is not None
@ -273,31 +283,25 @@ class CutlassFusedMoE(MoE):
else:
output_dtype = x.dtype
is_first_call, is_last_call = repeating_info
if self.layer_load_balancer and not self.layer_load_balancer.is_static_routing(
) and is_first_call:
self.layer_load_balancer.wait_for_gpu_stage()
use_fp8_block_scaling = False
use_w4a8_group_scaling = False
weight_dtype = self.w3_w1_weight.dtype
token_selected_experts, token_final_scales = self.routing_method.apply(
router_logits)
if self.balancer_layer is None:
token_selected_slots = token_selected_experts
else:
# If attention DP is enabled, token_selected_experts is a local rank tensor,
# so we need to offset the round robin position by ep_rank
token_selected_slots = self.balancer_layer.route(
token_selected_experts, offset_by_ep_rank=self.use_dp)
# If load balancer is disabled, the statistics are collected from expert IDs.
# If load balancer is enabled, the statistics are collected from expert slot IDs.
ExpertStatistic.set_layer(self.layer_idx)
ExpertStatistic.maybe_add_info(self.num_slots, token_selected_slots)
assert token_selected_slots.shape[
assert token_selected_experts.shape[
1] == self.routing_method.experts_per_token
assert token_selected_slots.shape == token_final_scales.shape
assert token_selected_slots.shape[0] == router_logits.shape[0]
assert token_selected_experts.shape == token_final_scales.shape
assert token_selected_experts.shape[0] == router_logits.shape[0]
assert token_final_scales.dtype == torch.float32
assert token_selected_slots.dtype == torch.int32
assert token_selected_experts.dtype == torch.int32
if self.apply_router_weight_on_input:
assert self.routing_method.top_k == 1, "Current workaround only supports top-1 routing"
@ -310,13 +314,32 @@ class CutlassFusedMoE(MoE):
alltoall_info = None
if self.layer_load_balancer and not self.layer_load_balancer.is_static_routing(
) and is_first_call:
self.layer_load_balancer.maybe_cudagraph_done_wait()
need_statistic = False
if self.layer_load_balancer is None:
token_selected_slots = token_selected_experts
else:
token_selected_slots = self.layer_load_balancer.route(
token_selected_experts, self.use_dp)
if not self.layer_load_balancer.is_static_routing():
need_statistic = True
# If load balancer is disabled, the statistics are collected from expert IDs.
# If load balancer is enabled, the statistics are collected from expert slot IDs.
ExpertStatistic.set_layer(self.layer_idx)
ExpertStatistic.maybe_add_info(self.num_slots, token_selected_slots)
token_selected_experts_for_statistic = token_selected_experts if need_statistic else None
if self.enable_alltoall:
x, token_selected_slots, token_final_scales, alltoall_info = \
x, token_selected_slots, token_final_scales, token_selected_experts_for_statistic, alltoall_info = \
self.alltoall_prepare_maybe_dispatch(all_rank_num_tokens,
x,
token_selected_slots,
token_final_scales)
token_final_scales,
token_selected_experts_for_statistic)
x_sf = None
if self.has_any_quant:
if self.has_fp8_qdq:
@ -348,8 +371,11 @@ class CutlassFusedMoE(MoE):
if self.use_dp and self.parallel_size > 1 and not disable_fp4_allgather(
) and not self.enable_alltoall:
x, x_sf, token_selected_slots, token_final_scales = allgather(
[x, x_sf, token_selected_slots, token_final_scales],
x, x_sf, token_selected_slots, token_final_scales, token_selected_experts_for_statistic = allgather(
[
x, x_sf, token_selected_slots, token_final_scales,
token_selected_experts_for_statistic
],
self.mapping,
dim=0,
sizes=None if use_dp_padding else all_rank_num_tokens)
@ -358,6 +384,12 @@ class CutlassFusedMoE(MoE):
x_sf = reswizzle_sf(x_sf, x_row, x_col,
self.scaling_vector_size)
if self.layer_load_balancer and not self.layer_load_balancer.is_static_routing(
):
self.layer_load_balancer.statistic(
token_selected_experts_for_statistic, is_first_call,
is_last_call)
if self.smart_router and not cutlass_min_latency_mode:
ep_size = self.cluster_size
ep_rank = self.cluster_rank
@ -405,20 +437,29 @@ class CutlassFusedMoE(MoE):
tune_max_num_tokens=self.tune_max_num_tokens,
)
if self.layer_load_balancer and not self.layer_load_balancer.is_static_routing(
) and is_last_call:
self.layer_load_balancer.set_cpu_stage()
if cutlass_min_latency_mode:
assert not self.reduce_results
return final_hidden_states
assert not self.enable_alltoall
else:
# Custom op requires all inputs to be of the same type.
# Only in cutlass_min_latency_mode, the output is a list of tensors.
# Otherwise, the output should be unpacked as a single tensor.
final_hidden_states = final_hidden_states[0]
if not self.enable_alltoall:
return final_hidden_states
else:
return self.alltoall_combine(final_hidden_states, alltoall_info,
token_count)
if self.enable_alltoall:
final_hidden_states = self.alltoall_combine(final_hidden_states,
alltoall_info,
token_count)
if self.layer_load_balancer and not self.layer_load_balancer.is_static_routing(
) and is_last_call:
self.layer_load_balancer.maybe_cudagraph_done_set_cpu_stage()
return final_hidden_states
def forward(
self,
@ -500,6 +541,8 @@ class CutlassFusedMoE(MoE):
# Postpone reduce-scatter/all-reduce to the next iteration to achieve better overlap
for idx_chunk, (x, router_logits) in enumerate(
zip(x_list, router_logits_list)):
is_first_call = idx_chunk == 0
is_last_call = idx_chunk == num_chunks - 1
if not self.enable_alltoall:
if idx_chunk % 2 == 0:
with torch.cuda.stream(self.aux_stream):
@ -508,7 +551,8 @@ class CutlassFusedMoE(MoE):
router_logits,
all_rank_num_tokens=all_rank_num_tokens_list[
idx_chunk] if self.use_dp else None,
use_dp_padding=use_dp_padding)
use_dp_padding=use_dp_padding,
repeating_info=(is_first_call, is_last_call))
if idx_chunk > 0:
outputs_list[-1] = self.reducescatter_or_allreduce(
outputs_list[-1],
@ -521,7 +565,8 @@ class CutlassFusedMoE(MoE):
router_logits,
all_rank_num_tokens=all_rank_num_tokens_list[
idx_chunk] if self.use_dp else None,
use_dp_padding=use_dp_padding)
use_dp_padding=use_dp_padding,
repeating_info=(is_first_call, is_last_call))
with torch.cuda.stream(self.aux_stream):
outputs_list[-1] = self.reducescatter_or_allreduce(
outputs_list[-1],
@ -533,7 +578,8 @@ class CutlassFusedMoE(MoE):
x,
router_logits,
all_rank_num_tokens=all_rank_num_tokens_list[idx_chunk]
if self.use_dp else None)
if self.use_dp else None,
repeating_info=(is_first_call, is_last_call))
outputs_list.append(outputs)
if not self.enable_alltoall:
@ -557,32 +603,48 @@ class CutlassFusedMoE(MoE):
outputs = outputs[:all_rank_num_tokens[rank]]
return outputs
def alltoall_prepare_maybe_dispatch(self, all_rank_num_tokens: list,
x: torch.Tensor,
token_selected_slots: torch.Tensor,
token_final_scales: torch.Tensor):
def alltoall_prepare_maybe_dispatch(
self, all_rank_num_tokens: list, x: torch.Tensor,
token_selected_slots: torch.Tensor,
token_final_scales: torch.Tensor,
token_selected_experts_for_statistic: Optional[torch.Tensor]):
top_k = self.routing_method.experts_per_token
expert_count = self.num_experts
# gather router info
max_num_token = max(all_rank_num_tokens)
token_selected_slots = torch.nn.functional.pad(
token_selected_slots,
(0, 0, 0, max_num_token - token_selected_slots.shape[0]),
'constant', self.num_experts)
'constant', self.num_slots)
token_selected_experts_for_statistic = torch.nn.functional.pad(
token_selected_experts_for_statistic,
(0, 0, 0,
max_num_token - token_selected_experts_for_statistic.shape[0]),
'constant', self.num_experts
) if token_selected_experts_for_statistic is not None else None
token_final_scales = torch.nn.functional.pad(
token_final_scales,
(0, 0, 0, max_num_token - token_final_scales.shape[0]))
gathered_token_selected_slots, gathered_token_final_scales = allgather(
[token_selected_slots, token_final_scales], self.mapping, dim=0)
gathered_token_selected_slots, gathered_token_final_scales, gathered_token_selected_experts_for_statistic = allgather(
[
token_selected_slots, token_final_scales,
token_selected_experts_for_statistic
],
self.mapping,
dim=0)
if gathered_token_selected_experts_for_statistic is not None:
gathered_token_selected_experts_for_statistic = torch.flatten(
gathered_token_selected_experts_for_statistic.contiguous(),
start_dim=0,
end_dim=-2)
gathered_token_selected_slots = torch.flatten(
gathered_token_selected_slots.contiguous(), start_dim=0, end_dim=-2)
gathered_token_final_scales = torch.flatten(
gathered_token_final_scales.contiguous(), start_dim=0, end_dim=-2)
gathered_target_rank_ids = MnnvlMoe.compute_target_rank_id(
gathered_token_selected_slots, self.num_experts, self.ep_size)
gathered_token_selected_slots, self.num_slots, self.ep_size)
alltoall_info, token_selected_slots, token_final_scales = MnnvlMoe.mnnvl_moe_alltoallv_prepare(
gathered_target_rank_ids, None, gathered_token_selected_slots,
gathered_token_final_scales, max_num_token, expert_count, top_k,
gathered_token_final_scales, max_num_token, self.num_slots, top_k,
self.ep_rank, self.ep_size)
if not self.use_postquant_alltoall:
@ -593,7 +655,7 @@ class CutlassFusedMoE(MoE):
self.alltoall_workspace,
self.ep_rank, self.ep_size)
return x, token_selected_slots, token_final_scales, alltoall_info
return x, token_selected_slots, token_final_scales, gathered_token_selected_experts_for_statistic, alltoall_info
def alltoall_postquant_dispatch(self, x: torch.Tensor, x_sf: torch.Tensor,
x_row: int, x_col: int,
@ -633,6 +695,63 @@ class CutlassFusedMoE(MoE):
return final_hidden_states
def register_parameter_weight_slot_fn(self, weight_name: str,
local_slot_id: int):
assert hasattr(
self,
weight_name), f"FusedMoE doesn't has weight attr: {weight_name}"
weight_tensor = getattr(self, weight_name).data[local_slot_id]
self.layer_load_balancer.register_weight_slot(local_slot_id,
weight_name,
weight_tensor)
def register_to_fix_weight_fn(self, weight_name: str):
assert hasattr(
self,
weight_name), f"FusedMoE doesn't has weight attr: {weight_name}"
param = getattr(self, weight_name)
weight_tensor = param.detach()
assert isinstance(
weight_tensor,
torch.Tensor), f'weight {weight_name} should be a tensor'
assert weight_tensor.is_contiguous(
), f'weight {weight_name} should be contiguous, shape={weight_tensor.shape}, strides={weight_tensor.stride()}'
assert weight_tensor.numel() * weight_tensor.element_size() == weight_tensor.untyped_storage().size(),\
f'weight {weight_name} shape={weight_tensor.shape} storage_size = {weight_tensor.untyped_storage().size()}, numel={weight_tensor.numel()}, eltsize={weight_tensor.element_size()}, dtype={weight_tensor.dtype}'
self.layer_load_balancer.fix_tensor(weight_tensor)
param.data = weight_tensor
def register_all_parameter_slot_and_to_fix_weight_fns(
self, weight_and_tensor_dict: Dict[str, torch.Tensor]):
"""
weight_and_tensor_dict: key is the weight name, value is the loaded shard of the host-shared tensor.
E.g. with num_experts=256 and 4 GPUs per node, each rank needs to load 256 / 4 = 64 expert weights for host sharing.
This way, host_tensor_sharer can share the weights so that each rank has access to all 256 experts (see the sketch after this method).
"""
for local_slot_id, expert_id in enumerate(
self.initial_local_expert_ids):
for weight_name in weight_and_tensor_dict:
self.layer_load_balancer.add_register_weight_fn(
self.register_parameter_weight_slot_fn,
(weight_name, local_slot_id))
for weight_name in weight_and_tensor_dict:
self.layer_load_balancer.add_to_fix_weight_fn(
self.register_to_fix_weight_fn, (weight_name, ))
local_shared_load_expert_ids = self.layer_load_balancer.get_load_expert_ids(
)
for expert_id in range(self.num_experts):
for weight_name, weight_tensor in weight_and_tensor_dict.items():
if expert_id in local_shared_load_expert_ids:
local_slot_id = local_shared_load_expert_ids.index(
expert_id)
self.layer_load_balancer.host_tensor_sharer.share_host_tensor_with_shape(
expert_id, weight_name, weight_tensor[local_slot_id])
else:
self.layer_load_balancer.host_tensor_sharer.pre_register_host_tensor_with_shape(
expert_id, weight_name, weight_tensor.dtype,
weight_tensor[0].shape)
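# Illustrative sketch (not part of this change): host-shared expert weights are
# partitioned across the ranks of a node in contiguous slices, matching the
# docstring above (256 experts / 4 GPUs -> 64 experts per rank) and the
# expert_start/expert_end formula used by HostMoeTensorSharer.
def local_shared_expert_range(expert_count: int, local_rank: int,
                              local_size: int) -> tuple:
    expert_start = local_rank * expert_count // local_size
    expert_end = (local_rank + 1) * expert_count // local_size
    return expert_start, expert_end

assert local_shared_expert_range(256, 0, 4) == (0, 64)
assert local_shared_expert_range(256, 3, 4) == (192, 256)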
def load_weights(self, weights: List[Dict]):
assert self._weights_created
assert len(weights) == 1

View File

@ -2,7 +2,7 @@ from typing import Dict, List, Optional, Union
import torch
from ...model_config import ModelConfig, MoeLoadBalancerConfig
from ...model_config import ModelConfig
from ...utils import Fp4QuantizedTensor
from .interface import MoE, MoEWeightLoadingMode
from .quantization import (FP8BlockScalesFusedMoEMethod,
@ -71,18 +71,15 @@ class TRTLLMGenFusedMoE(MoE):
assert not self.smart_router, "Smart router is not supported in TRTLLMGenFusedMoE."
assert not self.use_dp, "AttentionDP is not supported in TRTLLMGenFusedMoE."
# A dummy MoeLoadBalancerConfig to generate default initial_global_assignments and initial_local_expert_ids
moe_load_balancer_config = MoeLoadBalancerConfig()
moe_load_balancer_config.setup(num_experts=num_experts,
ep_rank=self.ep_rank,
ep_size=self.ep_size)
self.num_slots = moe_load_balancer_config.num_slots
self.initial_global_assignments = moe_load_balancer_config.get_layer_initial_global_assignments(
layer_idx)
self.expert_size_per_partition = moe_load_balancer_config.num_local_slots
self.slot_start = moe_load_balancer_config.slot_start
self.slot_end = moe_load_balancer_config.slot_end
self.num_slots = self.num_experts
self.expert_size_per_partition = self.num_experts // self.ep_size
self.initial_global_assignments = [
(ep_rank * self.num_experts // self.ep_size + local_slot_id) %
self.num_experts for ep_rank in range(self.ep_size)
for local_slot_id in range(self.expert_size_per_partition)
]
self.slot_start = self.ep_rank * self.expert_size_per_partition
self.slot_end = self.slot_start + self.expert_size_per_partition
self.initial_local_expert_ids = self.initial_global_assignments[
self.slot_start:self.slot_end]
assert len(

View File

@ -1,13 +1,18 @@
import atexit
import platform
import threading
from multiprocessing import shared_memory
from typing import Callable, List, Optional
from contextlib import nullcontext
from multiprocessing import resource_tracker, shared_memory
from typing import Callable, Dict, List, Optional, Tuple
import numpy as np
import torch
from mpi4py import MPI
import tensorrt_llm
import tensorrt_llm.bindings.internal.runtime as _tbr
from tensorrt_llm._torch.pyexecutor.cuda_graph_runner import is_graph_capturing
from tensorrt_llm.logger import logger
from tensorrt_llm.mapping import Mapping
def _tensor_to_weight(t: torch.Tensor) -> _tbr.MoeWeight:
@ -20,6 +25,7 @@ def _tensor_to_weight(t: torch.Tensor) -> _tbr.MoeWeight:
assert t.dim() <= 2, "t.dim() should be less than or equal to 2"
shape = [1, 1]
pitch = 1
elt_size = torch.tensor([], dtype=t.dtype).element_size()
if t.dim() == 2:
shape[0] = t.size(0)
shape[1] = t.size(1)
@ -31,8 +37,8 @@ def _tensor_to_weight(t: torch.Tensor) -> _tbr.MoeWeight:
pass
mw = _tbr.MoeWeight()
mw.height = shape[0]
mw.width = shape[1]
mw.pitch = pitch
mw.width = shape[1] * elt_size
mw.pitch = pitch * elt_size
mw.weight_ptr = t.data_ptr()
return mw
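# Illustrative sketch (not part of this change): after this diff, MoeWeight.width
# and MoeWeight.pitch are byte counts instead of element counts. A minimal
# stand-alone check of that arithmetic, assuming a contiguous 2D bfloat16 tensor
# whose pitch is its row stride:
import torch

t = torch.empty(128, 2048, dtype=torch.bfloat16)
elt_size = torch.tensor([], dtype=t.dtype).element_size()  # 2 bytes for bfloat16
width_bytes = t.size(1) * elt_size                          # 4096 bytes per logical row
pitch_bytes = t.stride(0) * elt_size                        # equal to width here; larger if rows are padded
assert width_bytes == 4096 and pitch_bytes >= width_bytes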
@ -42,7 +48,8 @@ class HostMoeTensorSharer:
A class representing a host tensor sharer.
"""
def __init__(self, layer_id: int, shared_mpi_comm: MPI.Comm):
def __init__(self, layer_id: int, expert_count: int,
shared_mpi_comm: MPI.Comm):
"""
Initialize a HostMoeTensorSharer instance.
@ -51,11 +58,24 @@ class HostMoeTensorSharer:
"""
self.shared_mpi_comm = shared_mpi_comm
self.layer_id = layer_id
self.expert_count = expert_count
self.shared_memory_base_name = None
self.host_tensor_shapes = []
self.local_rank = self.shared_mpi_comm.Get_rank()
self.local_size = self.shared_mpi_comm.Get_size()
self.expert_start = self.local_rank * self.expert_count // self.local_size
self.expert_end = (self.local_rank +
1) * self.expert_count // self.local_size
self.name_info = {} # key is weight name, value is (dtype, shape)
self.host_weights = {}
self.own_shms = {}
self.all_shms = []
self.own_shm = None
self.imported_shms = []
self.shared_tensors = {}
self.names = []
def set_shared_memory_base_name(self, shared_memory_base_name):
"""
@ -66,17 +86,19 @@ class HostMoeTensorSharer:
"""
self.shared_memory_base_name = shared_memory_base_name
def get_shared_memory_name(self, expert_id: int, name: str):
def get_shared_memory_name(self, rank: Optional[int] = None):
"""
Get the shared memory name for the layer.
Args:
expert_id: The ID of the expert
name: The name of the weight
rank: The rank who created the shared memory. Current rank if None
"""
if rank is None:
rank = self.local_rank
assert 0 <= rank < self.local_size
assert isinstance(self.shared_memory_base_name,
str), "self.shared_memory_base_name must be a string"
shared_memory_name = f"{self.shared_memory_base_name}_l{self.layer_id}_e{expert_id}_{name}"
shared_memory_name = f"{self.shared_memory_base_name}_l{self.layer_id}_lr{rank}_all"
return shared_memory_name
def pre_register_host_tensor_with_shape(self, expert_id: int, name: str,
@ -96,7 +118,13 @@ class HostMoeTensorSharer:
"""
assert len(tensor_shape
) <= 2, "tensor_shape dim must be less than or equal to 2"
self.host_tensor_shapes.append((expert_id, name, dtype, tensor_shape))
assert 0 <= expert_id < self.expert_count
assert expert_id < self.expert_start or expert_id >= self.expert_end
if name not in self.name_info:
self.name_info[name] = (dtype, tensor_shape)
else:
assert dtype == self.name_info[name][0] and tensor_shape == self.name_info[name][1], \
f'weights name={name}, dtype={dtype}, shape={tensor_shape}, but already registered with dtype={self.name_info[name][0]}, shape={self.name_info[name][1]}'
def share_host_tensor_with_shape(self, expert_id: int, name: str,
t: torch.Tensor):
@ -111,38 +139,99 @@ class HostMoeTensorSharer:
name: The name of the weight
t: The weight tensor
"""
assert len(
t.shape) <= 2, "tensor_shape dim must be less than or equal to 2"
assert t.is_contiguous() == True, "t.is_contiguous() must be True"
shm_name = self.get_shared_memory_name(expert_id, name)
shm = shared_memory.SharedMemory(name=shm_name,
create=True,
size=t.numel() * t.element_size())
shm.buf[:t.numel() * t.element_size()] = t.numpy().tobytes()
assert (expert_id, name) not in self.shared_tensors.keys()
assert self.expert_start <= expert_id < self.expert_end
self.shared_tensors[(expert_id, name)] = t
dtype = t.dtype
tensor_shape = t.shape
t = torch.frombuffer(shm.buf,
dtype=dtype).view(tensor_shape).pin_memory()
key = (expert_id, name)
assert key not in self.host_weights.keys(), f"key={key} already exists"
self.host_weights[key] = t
self.own_shms[(expert_id, name)] = shm
self.all_shms.append(shm)
atexit.register(shm.unlink)
if name not in self.name_info:
self.name_info[name] = (dtype, tensor_shape)
else:
assert dtype == self.name_info[name][0] and tensor_shape == self.name_info[name][1], \
f'weights name={name}, dtype={dtype}, shape={tensor_shape}, but already registered with dtype={self.name_info[name][0]}, shape={self.name_info[name][1]}'
@staticmethod
def align_size(size: int):
return (size + 256 - 1) // 256 * 256
def finalize_layer_weights(self):
self.names = list(sorted(self.name_info.keys()))
assert len(
self.shared_tensors.keys()) == (self.expert_end -
self.expert_start) * len(self.names)
total_size = 0
for name in self.names:
dtype, shape = self.name_info[name]
for expert_id in range(self.expert_start, self.expert_end):
t = self.shared_tensors[(expert_id, name)]
assert dtype == t.dtype and shape == t.shape
data_size = t.numel() * t.element_size()
aligned_size = self.align_size(data_size)
total_size += aligned_size
shm_name = self.get_shared_memory_name()
shm = shared_memory.SharedMemory(name=shm_name,
create=True,
size=total_size)
self.own_shm = shm
offset = 0
for name in self.names:
for expert_id in range(self.expert_start, self.expert_end):
t = self.shared_tensors[(expert_id, name)]
data_size = t.numel() * t.element_size()
aligned_size = self.align_size(data_size)
shm.buf[offset:offset + data_size] = t.numpy().tobytes()
dtype = t.dtype
tensor_shape = t.shape
elt_count = t.numel()
st = torch.frombuffer(shm.buf,
dtype=dtype,
offset=offset,
count=elt_count).view(tensor_shape)
key = (expert_id, name)
assert key not in self.host_weights.keys(
), f"key={key} already exists"
self.host_weights[key] = st
offset += aligned_size
self.shared_tensors = {}
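# Illustrative sketch (not part of this change): finalize_layer_weights() packs
# each (name, expert) block into one shared-memory segment, names in sorted
# order and every block padded to a 256-byte boundary. A minimal offset
# calculator under those assumptions (the name/shape below are hypothetical):
import numpy as np
import torch

def _align_256(size: int) -> int:
    return (size + 256 - 1) // 256 * 256

name_info = {"w2_weight": (torch.bfloat16, (7168, 2048))}  # hypothetical weight
expert_start, expert_end = 0, 2                             # experts owned by this local rank
offsets, offset = {}, 0
for name in sorted(name_info):
    dtype, shape = name_info[name]
    data_size = int(np.prod(shape)) * torch.tensor([], dtype=dtype).element_size()
    for expert_id in range(expert_start, expert_end):
        offsets[(expert_id, name)] = offset
        offset += _align_256(data_size)
# `offset` is now the total size this rank allocates for the layer's segment.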
def finalize_host_tensor_sharing(self, add_host_weight_fn: Callable = None):
"""
Finalize the host tensor sharing.
"""
for expert_weight_info in self.host_tensor_shapes:
expert_id, name, dtype, tensor_shape = expert_weight_info
shm_name = self.get_shared_memory_name(expert_id, name)
for rank in range(self.local_size):
if rank == self.local_rank:
continue
shm_name = self.get_shared_memory_name(rank)
shm = shared_memory.SharedMemory(name=shm_name)
self.all_shms.append(shm)
t = torch.frombuffer(shm.buf,
dtype=dtype).view(tensor_shape).pin_memory()
key = (expert_id, name)
assert key not in self.host_weights.keys(
), f"key={key} already exists"
self.host_weights[key] = t
self.imported_shms.append(shm)
rank_expert_start = rank * self.expert_count // self.local_size
rank_expert_end = (rank + 1) * self.expert_count // self.local_size
offset = 0
for name in self.names:
dtype, shape = self.name_info[name]
elt_count = int(np.prod(shape))
data_size = torch.tensor([],
dtype=dtype).element_size() * elt_count
aligned_size = self.align_size(data_size)
for expert_id in range(rank_expert_start, rank_expert_end):
t = torch.frombuffer(shm.buf,
dtype=dtype,
offset=offset,
count=elt_count).view(shape)
key = (expert_id, name)
assert key not in self.host_weights.keys(
), f"key={key} already exists"
self.host_weights[key] = t
offset += aligned_size
if add_host_weight_fn is not None:
for key, t in self.host_weights.items():
@ -154,8 +243,20 @@ class HostMoeTensorSharer:
"""
Clean up the resources before C++ shutdown and barrier
"""
for shm in self.all_shms:
for shm in self.imported_shms:
shm.close()
resource_tracker.unregister(shm._name, "shared_memory")
self.imported_shms = None
if self.own_shm:
self.own_shm.close()
def post_shutdown_cleanup(self):
"""
Clean up the resources after C++ shutdown and barrier
"""
if self.own_shm:
self.own_shm.unlink()
self.own_shm = None
class SingleLayerMoeLoadBalancer:
@ -167,19 +268,56 @@ class SingleLayerMoeLoadBalancer:
def __init__(
self,
single_layer_load_balancer_impl: _tbr.SingleLayerMoeLoadBalancer,
shared_mpi_comm: MPI.Comm):
shared_mpi_comm: MPI.Comm,
expert_count: int,
updates_enabled: bool = True):
"""
Initialize a SingleLayerMoeLoadBalancer instance.
Args:
single_layer_load_balancer_impl: The C++ implementation of SingleLayerMoeLoadBalancer
shared_mpi_comm: The MPI communicator for shared memory
expert_count: total number of experts
updates_enabled: whether to enable weight updates
"""
self.single_layer_load_balancer_impl = single_layer_load_balancer_impl
self.single_layer_load_balancer_ptr = single_layer_load_balancer_impl.get_pointer(
)
self.expert_count = expert_count
self.updates_enabled = updates_enabled
layer_id = self.single_layer_load_balancer_impl.get_layer_id()
self.host_tensor_sharer = HostMoeTensorSharer(shared_mpi_comm, layer_id)
self.host_tensor_sharer = HostMoeTensorSharer(
layer_id, expert_count,
shared_mpi_comm) if self.updates_enabled else None
self.register_weight_fns = []
self.to_fix_weight_fns = []
shared_rank = shared_mpi_comm.Get_rank()
shared_size = shared_mpi_comm.Get_size()
load_expert_start = shared_rank * self.expert_count // shared_size
load_expert_end = min(
(shared_rank + 1) * self.expert_count // shared_size,
self.expert_count)
self.load_expert_ids = list(range(load_expert_start, load_expert_end))
self.statistic_flag_tensor = None
self.cudagraph_stream = None
self.cudagraph_event = None
def get_layer_idx(self):
return self.single_layer_load_balancer_impl.get_layer_id()
def get_load_expert_ids(self):
assert self.updates_enabled, "should not call get_load_expert_ids when using static routing"
return self.load_expert_ids
def is_static_routing(self):
return not self.updates_enabled
def need_load_shared_weights(self):
return self.updates_enabled
def set_shared_memory_base_name(self, shared_memory_base_name):
"""
@ -188,8 +326,9 @@ class SingleLayerMoeLoadBalancer:
Args:
shared_memory_base_name: The base name for the shared memory
"""
self.host_tensor_sharer.set_shared_memory_base_name(
shared_memory_base_name)
if self.updates_enabled:
self.host_tensor_sharer.set_shared_memory_base_name(
shared_memory_base_name)
def _add_weight_slot(self, slot_id: int, name: str,
weight_slot: _tbr.MoeWeight):
@ -201,20 +340,21 @@ class SingleLayerMoeLoadBalancer:
name: The name of the weight
weight_slot: The weight object
"""
self.single_layer_load_balancer_impl.add_weight_slot(
self.single_layer_load_balancer_impl.add_single_weight_slot(
slot_id, name, weight_slot)
def register_weight_slot(self, slot_id: int, name: str, t: torch.Tensor):
def register_weight_slot(self, local_slot_id: int, name: str,
t: torch.Tensor):
"""
Register a weight slot to the layer.
Args:
slot_id: The ID of the slot
local_slot_id: The ID of the slot at local rank
name: The name of the weight
t: The weight tensor
"""
moe_weight = _tensor_to_weight(t)
self._add_weight_slot(slot_id, name, moe_weight)
self._add_weight_slot(local_slot_id, name, moe_weight)
def _add_host_weight(self, expert_id: int, name: str,
host_weight: _tbr.MoeWeight):
@ -226,7 +366,7 @@ class SingleLayerMoeLoadBalancer:
name: The name of the weight
host_weight: The host weight object
"""
self.single_layer_load_balancer_impl.add_host_weight(
self.single_layer_load_balancer_impl.add_single_host_weight(
expert_id, name, host_weight)
def _add_host_weight_from_tensor(self, expert_id: int, name: str,
@ -248,46 +388,128 @@ class SingleLayerMoeLoadBalancer:
self.single_layer_load_balancer_impl.set_initial_weight_assignments(
initial_weight_assignments)
def add_to_fix_weight_fn(self,
fn: Callable,
args: Tuple,
kwargs: Dict = {}):
self.to_fix_weight_fns.append((fn, args, kwargs))
def add_register_weight_fn(self,
fn: Callable,
args: Tuple,
kwargs: Dict = {}):
"""
Add a weight-register function. This does not run fn directly; all registered functions are run after model.to("cuda"),
so this method can be called while the model is not yet on the GPU.
"""
self.register_weight_fns.append((fn, args, kwargs))
def fix_tensor(self, wt: torch.Tensor):
torch.ops.trtllm.migrate_to_managed(wt)
def register_weight_slots_after_to_cuda(self):
"""
Register weights after the model has been moved to CUDA; should be invoked after model.to("cuda") and before finalize_model.
"""
for fn, args, kwargs in self.to_fix_weight_fns:
fn(*args, **kwargs)
self.to_fix_weight_fns = []
for fn, args, kwargs in self.register_weight_fns:
fn(*args, **kwargs)
self.register_weight_fns = []
def py_finalize_model(self):
"""
Finalize the model after all layers have been added.
This must be called before starting any iterations.
"""
self.host_tensor_sharer.finalize_host_tensor_sharing(
self._add_host_weight_from_tensor)
if self.updates_enabled:
self.host_tensor_sharer.finalize_host_tensor_sharing(
self._add_host_weight_from_tensor)
def wait_for_gpu_stage(self) -> torch.Tensor:
def wait_for_gpu_stage(self) -> Optional[torch.Tensor]:
"""
Wait for the GPU stage to complete.
Returns:
A tensor indicating whether statistics collection is enabled, or None when weight updates are disabled
"""
return torch.ops.trtllm.moe_load_balance_wait_gpu_stage(
self.single_layer_load_balancer_ptr)
if self.updates_enabled:
assert self.statistic_flag_tensor is None, \
"Already has statistic_flag_tensor, should not wait."
if is_graph_capturing():
self.cudagraph_event = torch.cuda.Event()
self.cudagraph_stream = torch.cuda.Stream()
current_stream_event = torch.cuda.Event()
current_stream_event.record(torch.cuda.current_stream())
with torch.cuda.stream(self.cudagraph_stream):
current_stream_event.wait()
self.statistic_flag_tensor = torch.ops.trtllm.moe_load_balance_wait_gpu_stage(
self.single_layer_load_balancer_ptr)
self.cudagraph_event.record(self.cudagraph_stream)
else:
self.statistic_flag_tensor = torch.ops.trtllm.moe_load_balance_wait_gpu_stage(
self.single_layer_load_balancer_ptr)
return self.statistic_flag_tensor
else:
return
def maybe_cudagraph_done_wait(self):
if self.updates_enabled:
if is_graph_capturing():
assert self.cudagraph_event is not None, "should have cudagraph_event when capturing"
assert self.cudagraph_stream is not None, "should have cudagraph_stream when capturing"
self.cudagraph_event.wait()
def set_cpu_stage(self):
"""
Set the CPU stage.
"""
torch.ops.trtllm.moe_load_balance_set_cpu_stage(
self.single_layer_load_balancer_ptr)
if self.updates_enabled:
assert self.statistic_flag_tensor is not None, \
"Doesn't have statistic_flag_tensor, should not set_cpu_stage."
self.statistic_flag_tensor = None
if is_graph_capturing():
assert self.cudagraph_stream is not None, "Doesn't have cudagraph_stream, should not set_cpu_stage."
current_stream_event = torch.cuda.Event()
current_stream_event.record(torch.cuda.current_stream())
with torch.cuda.stream(self.cudagraph_stream):
current_stream_event.wait()
torch.ops.trtllm.moe_load_balance_set_cpu_stage(
self.single_layer_load_balancer_ptr)
self.cudagraph_event.record(self.cudagraph_stream)
else:
torch.ops.trtllm.moe_load_balance_set_cpu_stage(
self.single_layer_load_balancer_ptr)
def maybe_cudagraph_done_set_cpu_stage(self):
if self.updates_enabled:
if is_graph_capturing():
assert self.cudagraph_event is not None, "should have cudagraph_event when capturing"
assert self.cudagraph_stream is not None, "should have cudagraph_stream when capturing"
self.cudagraph_event.wait()
self.cudagraph_stream = None
self.cudagraph_event = None
def statistic(self, gathered_raw_expert_ids: torch.Tensor,
enabled: torch.Tensor, is_first_stage: bool,
is_last_stage: bool):
is_first_stage: bool, is_last_stage: bool):
"""
Perform statistics on the expert IDs.
Args:
gathered_raw_expert_ids: The gathered raw expert IDs from all ranks
is_first_stage: Whether this is the first stage
is_last_stage: Whether this is the last stage
"""
torch.ops.trtllm.moe_load_balance_statistic(
gathered_raw_expert_ids, enabled,
self.single_layer_load_balancer_ptr, is_first_stage, is_last_stage)
if self.updates_enabled:
assert isinstance(self.statistic_flag_tensor, torch.Tensor)
torch.ops.trtllm.moe_load_balance_statistic(
gathered_raw_expert_ids, self.statistic_flag_tensor,
self.single_layer_load_balancer_ptr, is_first_stage,
is_last_stage)
def route(self,
token_selected_experts: torch.Tensor,
@ -310,7 +532,15 @@ class SingleLayerMoeLoadBalancer:
"""
Clean up the resources before C++ shutdown and barrier
"""
self.host_tensor_sharer.pre_shutdown_cleanup()
if self.updates_enabled:
self.host_tensor_sharer.pre_shutdown_cleanup()
def py_post_shutdown_cleanup(self):
"""
Clean up the resources after C++ shutdown and barrier
"""
if self.updates_enabled:
self.host_tensor_sharer.post_shutdown_cleanup()
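# Illustrative sketch (not part of this change) of the per-layer call order that
# CutlassFusedMoE.forward_chunk follows when online updates are enabled. In the
# real code the statistics are fed gathered/alltoall'ed expert ids; the tensors
# here are hypothetical placeholders and the sequence is simplified.
def balanced_routing_step(layer_balancer: SingleLayerMoeLoadBalancer,
                          token_selected_experts: torch.Tensor, use_dp: bool,
                          is_first_call: bool, is_last_call: bool):
    if is_first_call:
        layer_balancer.wait_for_gpu_stage()   # obtain the statistics flag tensor
    token_selected_slots = layer_balancer.route(token_selected_experts, use_dp)
    layer_balancer.statistic(token_selected_experts, is_first_call, is_last_call)
    # ... run the expert computation on token_selected_slots ...
    if is_last_call:
        layer_balancer.set_cpu_stage()        # hand the statistics to the CPU thread
    return token_selected_slots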
# Global variable to store the current active MoeLoadBalancer instance
@ -346,6 +576,21 @@ class MoeLoadBalancer:
self.single_layer_load_balancers = []
self.shared_memory_base_name = shared_memory_base_name
self._setup_mpi_comm()
self.is_shutdown = False
self.iter_id = 0
self.in_iter = False
self.enable_statistic = False
self.enable_update_weights = False
def __del__(self):
if not self.is_shutdown:
self.shutdown()
def is_static_routing(self):
# if we never update weights, routing is static.
return self.layer_updates_per_iter == 0
def _setup_mpi_comm(self):
global_mpi_comm = tensorrt_llm.mpi_comm()
@ -357,6 +602,9 @@ class MoeLoadBalancer:
f"Interesting, shared size {shared_size} is not same as local size {local_size}"
self.shared_mpi_comm = shared_mpi_comm
def set_use_gpu_memcpy(self, use_gpu_memcpy: bool):
self.load_balancer_impl.set_use_gpu_memcpy(use_gpu_memcpy)
def add_layer(self, expert_count: int, top_k: int,
slot_count_per_rank: int) -> SingleLayerMoeLoadBalancer:
"""
@ -372,13 +620,24 @@ class MoeLoadBalancer:
"""
single_layer_load_balancer_impl = self.load_balancer_impl.add_layer(
expert_count, top_k, slot_count_per_rank)
updates_enabled = not self.is_static_routing()
single_layer_load_balancer = SingleLayerMoeLoadBalancer(
single_layer_load_balancer_impl, self.shared_mpi_comm)
single_layer_load_balancer_impl,
self.shared_mpi_comm,
expert_count,
updates_enabled=updates_enabled)
single_layer_load_balancer.set_shared_memory_base_name(
self.shared_memory_base_name)
self.single_layer_load_balancers.append(single_layer_load_balancer)
return single_layer_load_balancer
def register_weight_slots_after_to_cuda(self):
"""
Register weights after the model has been moved to CUDA; should be invoked after model.to("cuda") and before finalize_model.
"""
for layer in self.single_layer_load_balancers:
layer.register_weight_slots_after_to_cuda()
def finalize_model(self):
"""
Finalize the model after all layers have been added.
@ -391,6 +650,7 @@ class MoeLoadBalancer:
for single_layer_load_balancer in self.single_layer_load_balancers:
single_layer_load_balancer.py_finalize_model()
self.load_balancer_impl.finalize_model()
torch.cuda.empty_cache()
def set_warm_up_iter_count(self, iter_count: int):
"""
@ -401,27 +661,33 @@ class MoeLoadBalancer:
"""
self.load_balancer_impl.set_warm_up_iter_count(iter_count)
def start_iter(self, iter_id: int, enable_statistic: bool,
enable_update_weights: bool):
def set_next_iter_info(self, enable_statistic: Optional[bool],
enable_update_weights: Optional[bool]):
if enable_statistic is not None:
self.enable_statistic = enable_statistic
if enable_update_weights is not None:
self.enable_update_weights = enable_update_weights
def start_iter(self):
"""
Start a new iteration.
Args:
iter_id: The ID of the iteration
enable_statistic: Whether to enable statistics collection
enable_update_weights: Whether to enable weight updates
"""
self.load_balancer_impl.start_iter(iter_id, enable_statistic,
enable_update_weights)
assert not self.in_iter, "already in forward"
self.in_iter = True
self.load_balancer_impl.start_iter(self.iter_id, self.enable_statistic,
self.enable_update_weights)
def end_iter(self, iter_id: int):
def end_iter(self):
"""
End the current iteration.
Args:
iter_id: The ID of the iteration to end
"""
self.load_balancer_impl.end_iter(iter_id)
assert self.in_iter, "not in forward, cannot end_iter"
self.load_balancer_impl.end_iter(self.iter_id)
self.in_iter = False
self.iter_id += 1
def shutdown(self):
"""
@ -432,6 +698,10 @@ class MoeLoadBalancer:
self.load_balancer_impl.shutdown()
# use this sync to make sure all the shm resources can be cleaned up
self.shared_mpi_comm.barrier()
for single_layer_load_balancer in self.single_layer_load_balancers:
single_layer_load_balancer.py_post_shutdown_cleanup()
self.shared_mpi_comm.barrier()
self.is_shutdown = True
def __repr__(self):
"""
@ -482,6 +752,80 @@ class MoeLoadBalancer:
return False
moe_model_arch_list = [
'DeepseekV3ForCausalLM',
]
def maybe_create_moe_load_balancer(
model_config, mapping: Optional[Mapping]) -> Optional[MoeLoadBalancer]:
ep_rank = model_config.mapping.moe_ep_rank
ep_size = model_config.mapping.moe_ep_size
model_arch = model_config.pretrained_config.architectures[0]
using_ep = mapping and mapping.moe_ep_size > 1
in_supported_model_arch = model_arch in moe_model_arch_list
using_smart_router = mapping and mapping.moe_cluster_size > 1
moe_load_balancer = nullcontext()
if in_supported_model_arch and using_ep and not using_smart_router and model_config.moe_load_balancer is not None:
model_config.moe_load_balancer.setup(ep_rank=ep_rank, ep_size=ep_size)
if model_config.moe_load_balancer.layer_updates_per_iter > 0:
# TODO: remove this when supported.
cpu_arch = platform.machine().lower()
assert cpu_arch == 'aarch64', "The online load balancer currently only supports aarch64 (e.g. GB200); x86 support is coming soon."
moe_load_balancer = MoeLoadBalancer(
ep_rank=ep_rank,
ep_size=ep_size,
layer_updates_per_iter=model_config.moe_load_balancer.
layer_updates_per_iter)
logger.info(
f"Created MoE LoadBalancer, layer_updates_per_iter={model_config.moe_load_balancer.layer_updates_per_iter}..."
)
return moe_load_balancer
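# Illustrative sketch (not part of this change), following the model_engine.py
# hunk later in this commit: model construction happens inside the (possibly
# null) context, and only a real MoeLoadBalancer needs the post-to-cuda
# registration and finalize_model() calls. `build_model_on_gpu` is a
# hypothetical stand-in for the engine's model loading path.
def load_model_with_balancer(model_config, mapping, build_model_on_gpu):
    with maybe_create_moe_load_balancer(model_config, mapping) as moe_load_balancer:
        model = build_model_on_gpu(model_config)  # weights end up on the GPU here
        if isinstance(moe_load_balancer, MoeLoadBalancer):
            moe_load_balancer.register_weight_slots_after_to_cuda()
            moe_load_balancer.finalize_model()
        return model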
class MoeLoadBalancerIterContext:
def __init__(self,
moe_load_balancer: Optional[MoeLoadBalancer],
enable_statistic: Optional[bool] = None,
enable_updates: Optional[bool] = None):
self.moe_load_balancer = moe_load_balancer
self.enable_statistic = enable_statistic
self.enable_updates = enable_updates
def __enter__(self):
"""
Enter the context manager.
Returns:
The MoeLoadBalancerIterContext instance
"""
if self.moe_load_balancer is not None and not self.moe_load_balancer.is_static_routing(
):
self.moe_load_balancer.set_next_iter_info(self.enable_statistic,
self.enable_updates)
self.moe_load_balancer.start_iter()
return self
def __exit__(self, exc_type, exc_val, exc_tb):
"""
Exit the context manager.
Args:
exc_type: The exception type
exc_val: The exception value
exc_tb: The exception traceback
Returns:
False to not suppress exceptions
"""
if self.moe_load_balancer is not None and not self.moe_load_balancer.is_static_routing(
):
self.moe_load_balancer.end_iter()
return False
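# Illustrative usage sketch (not part of this change), mirroring how
# model_engine.py drives the balancer each iteration: outside warm-up the engine
# enables statistics and updates, then brackets one forward step with the
# context so start_iter()/end_iter() wrap it. `engine` and `forward_step` are
# hypothetical stand-ins.
def run_one_iteration(engine, inputs):
    balancer = getattr(engine, "moe_load_balancer", None)
    if balancer is not None and not engine.in_warmup:
        balancer.set_next_iter_info(enable_statistic=True,
                                    enable_update_weights=True)
    with MoeLoadBalancerIterContext(balancer):
        return engine.forward_step(inputs)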
def get_moe_load_balancer() -> Optional[MoeLoadBalancer]:
"""
Get the current active MoeLoadBalancer instance.

View File

@ -61,6 +61,14 @@ class FusedMoEMethodBase(ABC):
Base class for all fused MoE methods.
"""
def need_load_shared_weights(self, module):
if hasattr(
module, "layer_load_balancer"
) and module.layer_load_balancer and module.layer_load_balancer.need_load_shared_weights(
):
return True
return False
def create_weights(self, module: torch.nn.Module, weight_dtype: torch.dtype,
w3_w1_weight_shape: tuple[int, int, int],
w2_weight_shape: tuple[int, int, int]):
@ -76,12 +84,15 @@ class FusedMoEMethodBase(ABC):
requires_grad=False)
module.register_parameter("w2_weight", w2_weight)
def load_weights(self, module: torch.nn.Module, weights: List[Dict],
weight_loading_mode: MoEWeightLoadingMode):
def load_expert_weights_to_dst(self, module: torch.nn.Module,
weights: List[Dict],
weight_loading_mode: MoEWeightLoadingMode,
load_expert_ids: List[int],
dst_w3_w1_weights_tensor: torch.Tensor,
dst_w2_weights_tensor: torch.Tensor):
# Multithread weight load is superseded by prefetch_files() in model_engine.py
# Also, threading adds overhead in order to protect shuffle index cache with critical section.
for local_slot_id, expert_id in enumerate(
module.initial_local_expert_ids):
for local_slot_id, expert_id in enumerate(load_expert_ids):
# expert_idx is the local slot index of current rank
expert_idx = local_slot_id
@ -101,15 +112,55 @@ class FusedMoEMethodBase(ABC):
)
self.load_expert_w3_w1_weight(module, w1_weight, w3_weight,
module.w3_w1_weight.data[expert_idx])
dst_w3_w1_weights_tensor[expert_idx])
self.load_expert_w2_weight(module, w2_weight,
module.w2_weight.data[expert_idx])
dst_w2_weights_tensor[expert_idx])
def load_weights(self, module: torch.nn.Module, weights: List[Dict],
weight_loading_mode: MoEWeightLoadingMode):
self.load_expert_weights_to_dst(module, weights, weight_loading_mode,
module.initial_local_expert_ids,
module.w3_w1_weight.data,
module.w2_weight.data)
self.load_quant_scales(module, weights)
# Re-setup quant scales after loading weights as the tensors may have been modified.
self.setup_quant_scales(module)
if self.need_load_shared_weights(module):
local_shared_load_expert_ids = module.layer_load_balancer.get_load_expert_ids(
)
local_shared_w3_w1_tensors = torch.empty(
(len(local_shared_load_expert_ids), ) +
module.w3_w1_weight.data.shape[1:],
dtype=module.w3_w1_weight.data.dtype,
device='cpu')
local_shared_w2_tensors = torch.empty(
(len(local_shared_load_expert_ids), ) +
module.w2_weight.data.shape[1:],
dtype=module.w2_weight.data.dtype,
device='cpu')
self.load_expert_weights_to_dst(module, weights,
weight_loading_mode,
local_shared_load_expert_ids,
local_shared_w3_w1_tensors,
local_shared_w2_tensors)
module.register_all_parameter_slot_and_to_fix_weight_fns({
'w3_w1_weight':
local_shared_w3_w1_tensors,
'w2_weight':
local_shared_w2_tensors
})
module.layer_load_balancer.host_tensor_sharer.finalize_layer_weights(
)
if hasattr(module,
"layer_load_balancer") and module.layer_load_balancer:
module.layer_load_balancer.set_initial_weight_assignments(
module.initial_global_assignments)
def load_quant_scales(self, module: torch.nn.Module, weights: List[Dict]):
pass
@ -828,6 +879,50 @@ class NVFP4FusedMoEMethod(FusedMoEMethodBase):
w2_weight_scale_2 = 1.0 / w2_weight_scale_2[...].reshape([])
dst_w2_alpha.copy_(1.0 / (final_fc2_input_scale * w2_weight_scale_2))
def load_all_fp4_weight_scales_and_alphas(
self, module: torch.nn.Module, weights: Dict,
load_expert_ids: List[int], dst_w3_w1_weight_scale: torch.Tensor,
dst_w2_weight_scale: torch.Tensor, dst_fc31_alpha: torch.Tensor,
dst_fc2_alpha: torch.Tensor):
for local_slot_id, expert_id in enumerate(load_expert_ids):
if module.weight_loading_mode == MoEWeightLoadingMode.VANILLA:
w1_weight_scale = weights[f"{expert_id}.w1.weight_scale"]
w3_weight_scale = weights[f"{expert_id}.w3.weight_scale"]
w2_weight_scale = weights[f"{expert_id}.w2.weight_scale"]
w1_weight_scale_2 = weights[f"{expert_id}.w1.weight_scale_2"]
w3_weight_scale_2 = weights[f"{expert_id}.w3.weight_scale_2"]
w2_weight_scale_2 = weights[f"{expert_id}.w2.weight_scale_2"]
elif module.weight_loading_mode == MoEWeightLoadingMode.FUSED_GATE_UP_PROJ:
w1_w3_weight_scale = weights["gate_up_proj_weight_scale"][
expert_id].transpose(0, 1).contiguous()
w1_weight_scale, w3_weight_scale = w1_w3_weight_scale.chunk(
2, dim=0)
w2_weight_scale = weights["down_proj_weight_scale"][
expert_id].transpose(0, 1).contiguous()
w1_weight_scale_2 = weights["gate_up_proj_weight_scale_2"]
w3_weight_scale_2 = weights["gate_up_proj_weight_scale_2"]
w2_weight_scale_2 = weights["down_proj_weight_scale_2"]
else:
raise NotImplementedError(
f"Unknown weight loading mode in MoE: {module.weight_loading_mode}"
)
expert_idx = local_slot_id
self.load_expert_w3_w1_weight_scale_nvfp4(
module, w1_weight_scale, w3_weight_scale,
dst_w3_w1_weight_scale[expert_idx])
self.load_expert_w2_weight_scale_nvfp4(
module, w2_weight_scale, dst_w2_weight_scale[expert_idx])
self.load_expert_fc31_alpha_nvfp4(w1_weight_scale_2,
w3_weight_scale_2,
module.fc31_input_scale.data,
dst_fc31_alpha[expert_idx])
self.load_expert_fc2_alpha_nvfp4(w2_weight_scale_2,
module.fc2_input_scale.data,
dst_fc2_alpha[expert_idx])
def load_quant_scales(self, module: torch.nn.Module, weights: Dict):
# Step1: Load input scales.
tmp_fc31_input_scale = torch.empty(module.num_experts,
@ -862,46 +957,50 @@ class NVFP4FusedMoEMethod(FusedMoEMethodBase):
tmp_fc2_input_scale.max().reciprocal())
# Step2: Load weight block scales and alphas.
for local_slot_id, expert_id in enumerate(
module.initial_local_expert_ids):
if module.weight_loading_mode == MoEWeightLoadingMode.VANILLA:
w1_weight_scale = weights[f"{expert_id}.w1.weight_scale"]
w3_weight_scale = weights[f"{expert_id}.w3.weight_scale"]
w2_weight_scale = weights[f"{expert_id}.w2.weight_scale"]
w1_weight_scale_2 = weights[f"{expert_id}.w1.weight_scale_2"]
w3_weight_scale_2 = weights[f"{expert_id}.w3.weight_scale_2"]
w2_weight_scale_2 = weights[f"{expert_id}.w2.weight_scale_2"]
elif module.weight_loading_mode == MoEWeightLoadingMode.FUSED_GATE_UP_PROJ:
w1_w3_weight_scale = weights["gate_up_proj_weight_scale"][
expert_id].transpose(0, 1).contiguous()
w1_weight_scale, w3_weight_scale = w1_w3_weight_scale.chunk(
2, dim=0)
w2_weight_scale = weights["down_proj_weight_scale"][
expert_id].transpose(0, 1).contiguous()
w1_weight_scale_2 = weights["gate_up_proj_weight_scale_2"]
w3_weight_scale_2 = weights["gate_up_proj_weight_scale_2"]
w2_weight_scale_2 = weights["down_proj_weight_scale_2"]
else:
raise NotImplementedError(
f"Unknown weight loading mode in MoE: {module.weight_loading_mode}"
)
self.load_all_fp4_weight_scales_and_alphas(
module, weights, module.initial_local_expert_ids,
module.w3_w1_weight_scale.data, module.w2_weight_scale.data,
module.fc31_alpha.data, module.fc2_alpha.data)
expert_idx = local_slot_id
# Step 3: load weights into the host-shared tensors if needed
if self.need_load_shared_weights(module):
local_shared_load_expert_ids = module.layer_load_balancer.get_load_expert_ids(
)
local_shared_w3_w1_scale_tensors = torch.empty(
(len(local_shared_load_expert_ids), ) +
module.w3_w1_weight_scale.data.shape[1:],
dtype=module.w3_w1_weight_scale.data.dtype,
device='cpu')
local_shared_w2_scale_tensors = torch.empty(
(len(local_shared_load_expert_ids), ) +
module.w2_weight_scale.data.shape[1:],
dtype=module.w2_weight_scale.data.dtype,
device='cpu')
local_shared_fc31_alpha_tensors = torch.empty(
(len(local_shared_load_expert_ids), ) +
module.fc31_alpha.data.shape[1:],
dtype=module.fc31_alpha.data.dtype,
device='cpu')
local_shared_fc2_alpha_tensors = torch.empty(
(len(local_shared_load_expert_ids), ) +
module.fc2_alpha.data.shape[1:],
dtype=module.fc2_alpha.data.dtype,
device='cpu')
self.load_all_fp4_weight_scales_and_alphas(
module, weights, local_shared_load_expert_ids,
local_shared_w3_w1_scale_tensors, local_shared_w2_scale_tensors,
local_shared_fc31_alpha_tensors, local_shared_fc2_alpha_tensors)
self.load_expert_w3_w1_weight_scale_nvfp4(
module, w1_weight_scale, w3_weight_scale,
module.w3_w1_weight_scale.data[expert_idx])
self.load_expert_w2_weight_scale_nvfp4(
module, w2_weight_scale,
module.w2_weight_scale.data[expert_idx])
self.load_expert_fc31_alpha_nvfp4(
w1_weight_scale_2, w3_weight_scale_2,
module.fc31_input_scale.data,
module.fc31_alpha.data[expert_idx])
self.load_expert_fc2_alpha_nvfp4(w2_weight_scale_2,
module.fc2_input_scale.data,
module.fc2_alpha.data[expert_idx])
module.register_all_parameter_slot_and_to_fix_weight_fns({
'w3_w1_weight_scale':
local_shared_w3_w1_scale_tensors,
'w2_weight_scale':
local_shared_w2_scale_tensors,
'fc31_alpha':
local_shared_fc31_alpha_tensors,
'fc2_alpha':
local_shared_fc2_alpha_tensors,
})
def setup_quant_scales(self, module: torch.nn.Module):
module.quant_scales = FusedMoEQuantScalesNVFP4(

View File

@ -1,5 +1,6 @@
import bisect
import contextlib
import functools
import gc
import glob
import inspect
@ -11,6 +12,7 @@ import traceback
import weakref
from abc import ABC, abstractmethod
from collections import defaultdict
from contextlib import contextmanager
from typing import Any, Dict, List, Optional, Tuple
import psutil
@ -47,6 +49,8 @@ from ..model_config import ModelConfig, MoeLoadBalancerConfig
from ..models import AutoModelForCausalLM
from ..models.modeling_utils import (DecoderModelForCausalLM, MetaInitMode,
timing)
from ..modules.fused_moe.moe_load_balancer import (
MoeLoadBalancer, MoeLoadBalancerIterContext, maybe_create_moe_load_balancer)
from ..speculative import SpecConfig, SpecMetadata, get_spec_metadata
from ..utils import (get_model_extra_attrs, set_torch_compiling,
with_model_extra_attrs)
@ -323,6 +327,8 @@ class PyTorchModelEngine(ModelEngine):
# py_executor.py for how this is used.
self.last_spec_metadata = None
self.in_warmup = False
self.attn_runtime_features = attn_runtime_features or AttentionRuntimeFeatures(
)
@ -470,6 +476,25 @@ class PyTorchModelEngine(ModelEngine):
hidden_size=self.model.config.hidden_size,
dtype=torch_dtype_to_str(self.model.config.torch_dtype))
@contextmanager
def set_warmup_flag(self):
self.in_warmup = True
try:
yield
finally:
self.in_warmup = False
@staticmethod
def with_warmup_flag(method):
@functools.wraps(method)
def wrapper(self, *args, **kwargs):
with self.set_warmup_flag():
return method(self, *args, **kwargs)
return wrapper
@with_warmup_flag
def warmup(self, resource_manager: ResourceManager) -> None:
kv_cache_manager = resource_manager.get_resource_manager(
self.kv_cache_manager_key)
@ -962,7 +987,8 @@ class PyTorchModelEngine(ModelEngine):
getattr(config.pretrained_config,
sub_config).num_hidden_layers = num_layers
with timing("Model init total"):
with timing("Model init total"), maybe_create_moe_load_balancer(
config, self.mapping) as moe_load_balancer:
try:
with MetaInitMode():
model = AutoModelForCausalLM.from_config(config)
@ -1017,6 +1043,13 @@ class PyTorchModelEngine(ModelEngine):
raise NotImplementedError(
f"No load support for load format: {load_format}")
if isinstance(moe_load_balancer, MoeLoadBalancer):
setattr(self, "moe_load_balancer", moe_load_balancer)
moe_load_balancer.register_weight_slots_after_to_cuda()
logger.info("moe_load_balancer finalizing model...")
moe_load_balancer.finalize_model()
logger.info("moe_load_balancer finalize model done")
torch.cuda.current_stream().synchronize()
return model
@ -1950,6 +1983,15 @@ class PyTorchModelEngine(ModelEngine):
else:
spec_metadata = None
moe_load_balancer = None
if hasattr(self, 'moe_load_balancer'):
moe_load_balancer = getattr(self, 'moe_load_balancer')
if not self.in_warmup:
moe_enable_statistic = True
moe_enable_update = True
moe_load_balancer.set_next_iter_info(moe_enable_statistic,
moe_enable_update)
if kv_cache_manager is None:
inputs, gather_ids = self._prepare_tp_inputs_no_cache(
scheduled_requests, attn_metadata, spec_metadata)
@ -1957,7 +1999,9 @@ class PyTorchModelEngine(ModelEngine):
inputs.update(extra_model_inputs)
self.last_spec_metadata = spec_metadata
return self._forward_step(inputs, gather_ids, gather_context_logits)
with MoeLoadBalancerIterContext(moe_load_balancer):
return self._forward_step(inputs, gather_ids,
gather_context_logits)
with self._maybe_pad_batch(scheduled_requests,
kv_cache_manager) as scheduled_requests:
@ -1984,21 +2028,32 @@ class PyTorchModelEngine(ModelEngine):
self.iter_counter += 1
if maybe_graph is None:
outputs = self._forward_step(inputs, gather_ids,
gather_context_logits)
with MoeLoadBalancerIterContext(moe_load_balancer):
outputs = self._forward_step(inputs, gather_ids,
gather_context_logits)
else:
if maybe_graph.needs_capture():
def capture_forward_fn(inputs: Dict[str, Any]):
with MoeLoadBalancerIterContext(moe_load_balancer):
return self._forward_step(
inputs,
gather_ids=gather_ids,
gather_context_logits=gather_context_logits)
pool = maybe_graph.capture(
lambda inputs: self._forward_step(
inputs,
gather_ids=gather_ids,
gather_context_logits=gather_context_logits),
capture_forward_fn,
self._cuda_graph_mem_pool,
extra_model_inputs,
)
self._cuda_graph_mem_pool = pool
outputs = maybe_graph.run(inputs, extra_model_inputs)
# No iteration context is needed here since CUDA graph capture does not actually run the kernels.
# Maybe we need a cleaner way to do this.
outputs = maybe_graph.run(inputs, extra_model_inputs)
else:
with MoeLoadBalancerIterContext(moe_load_balancer):
outputs = maybe_graph.run(inputs, extra_model_inputs)
# Note: To overlap the CPU and GPU computation as much as possible,
# guided_decoder.build should be called immediately after the launch of the single step;

View File

@ -18,6 +18,7 @@ BASE_ZMQ_CLASSES = {
"llmapi.run_llm_with_postproc": ["perform_faked_oai_postprocess"
], # only used in tests
### starting import of torch models classes. They are used in test_llm_multi_gpu.py.
"tensorrt_llm._torch.model_config": ["MoeLoadBalancerConfig"],
"tensorrt_llm._torch.models.modeling_bert":
["BertForSequenceClassification"],
"tensorrt_llm._torch.models.modeling_clip": ["CLIPVisionModel"],
@ -48,6 +49,7 @@ BASE_ZMQ_CLASSES = {
"tensorrt_llm._torch.models.modeling_siglip": ["SiglipVisionModel"],
"tensorrt_llm._torch.models.modeling_vila": ["VilaModel"],
### ending import of torch models classes
"tensorrt_llm._torch.pyexecutor.config": ["PyTorchConfig", "LoadFormat"],
"tensorrt_llm._torch.pyexecutor.llm_request":
["LogitsStorage", "PyResult", "LlmResult", "LlmResponse", "LogProbStorage"],
"tensorrt_llm._torch.speculative.mtp": ["MTPConfig"],
@ -58,13 +60,13 @@ BASE_ZMQ_CLASSES = {
["ClusterInfo", "MathThroughput"],
"tensorrt_llm._torch.pyexecutor.config": ["PyTorchConfig", "LoadFormat"],
"tensorrt_llm.bindings.executor": [
"BatchingType", "CapacitySchedulerPolicy", "ContextPhaseParams",
"BatchingType", "CacheTransceiverConfig", "CapacitySchedulerPolicy",
"ContextPhaseParams", "ContextChunkingPolicy", "DynamicBatchConfig",
"ExecutorConfig", "ExtendedRuntimePerfKnobConfig", "Response", "Result",
"FinishReason", "KvCacheConfig", "KvCacheTransferMode",
"KvCacheRetentionConfig",
"KvCacheRetentionConfig.TokenRangeRetentionConfig", "PeftCacheConfig",
"SchedulerConfig", "DynamicBatchConfig", "ContextChunkingPolicy",
"CacheTransceiverConfig"
"SchedulerConfig"
],
"tensorrt_llm._torch.pyexecutor.config": ["PyTorchConfig"],
"tensorrt_llm._torch.model_config": ["MoeLoadBalancerConfig"],

View File

@@ -78,6 +78,11 @@ class TestHostMoeTensorSharer(unittest.TestCase):
size = comm.Get_size()
layer_id = 0
# Test tensor parameters
experts_per_rank = 2 # Each rank is responsible for 2 consecutive experts
expert_count = size * experts_per_rank
tensor_shape = (16, 32) # Use 2D tensor for testing
# Maximum supported ranks (can adjust as needed)
max_ranks = 8
if size > max_ranks:
@@ -87,17 +92,12 @@ class TestHostMoeTensorSharer(unittest.TestCase):
shared_comm = comm.Split_type(split_type=MPI.COMM_TYPE_SHARED)
# Initialize HostMoeTensorSharer
sharer = HostMoeTensorSharer(layer_id, shared_comm)
sharer = HostMoeTensorSharer(layer_id, expert_count, shared_comm)
# Set shared memory base name
shared_memory_base_name = "test_host_sharer"
sharer.set_shared_memory_base_name(shared_memory_base_name)
# Test tensor parameters
experts_per_rank = 2 # Each rank is responsible for 2 consecutive experts
expert_count = size * experts_per_rank
tensor_shape = (16, 32) # Use 2D tensor for testing
# Calculate the range of experts this rank is responsible for
start_expert_id = rank * experts_per_rank
end_expert_id = start_expert_id + experts_per_rank
@@ -124,6 +124,8 @@ class TestHostMoeTensorSharer(unittest.TestCase):
sharer.pre_register_host_tensor_with_shape(
expert_id, "weight", torch.float32, tensor_shape)
sharer.finalize_layer_weights()
# Ensure all processes have created and registered their tensors
comm.Barrier()
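Condensed, the updated sharing flow the test exercises looks like the sketch below: expert_count is now passed to the HostMoeTensorSharer constructor, each rank pre-registers host tensors only for the experts it owns, and a barrier keeps any rank from reading shared tensors before all ranks have finished registration. The import path for HostMoeTensorSharer is assumed (it is not shown in this hunk); shapes and names mirror the test.

import torch
from mpi4py import MPI

from tensorrt_llm._torch.modules.fused_moe.moe_load_balancer import (  # assumed path
    HostMoeTensorSharer)

comm = MPI.COMM_WORLD
rank, size = comm.Get_rank(), comm.Get_size()
experts_per_rank = 2                      # each rank owns 2 consecutive experts
expert_count = size * experts_per_rank    # now passed to the constructor
tensor_shape = (16, 32)

shared_comm = comm.Split_type(split_type=MPI.COMM_TYPE_SHARED)
sharer = HostMoeTensorSharer(0, expert_count, shared_comm)  # layer_id = 0
sharer.set_shared_memory_base_name("test_host_sharer")

for expert_id in range(rank * experts_per_rank, (rank + 1) * experts_per_rank):
    sharer.pre_register_host_tensor_with_shape(expert_id, "weight", torch.float32,
                                               tensor_shape)
sharer.finalize_layer_weights()
comm.Barrier()  # every rank must finish registration before anyone reads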

View File

@@ -2,10 +2,11 @@ import unittest
from unittest.mock import MagicMock, patch
import torch
from mpi4py import MPI
from tensorrt_llm._torch.modules.fused_moe.moe_load_balancer import (
MoeLoadBalancer, SingleLayerMoeLoadBalancer, get_moe_load_balancer,
moe_load_balancer_add_single_layer)
MoeLoadBalancer, MoeLoadBalancerIterContext, SingleLayerMoeLoadBalancer,
get_moe_load_balancer, moe_load_balancer_add_single_layer)
class TestMoeLoadBalancer(unittest.TestCase):
@@ -186,7 +187,9 @@ class TestMoeLoadBalancer(unittest.TestCase):
# Setup
mock_single_layer_impl = MagicMock()
layer = SingleLayerMoeLoadBalancer(mock_single_layer_impl, None)
layer = SingleLayerMoeLoadBalancer(mock_single_layer_impl,
MPI.COMM_WORLD,
expert_count=4)
# Mock out torch.ops.trtllm functions
with patch('torch.ops.trtllm.moe_load_balance_wait_gpu_stage') as mock_wait, \
@@ -198,13 +201,13 @@ class TestMoeLoadBalancer(unittest.TestCase):
# add_weight_slot
mock_weight = MagicMock()
layer._add_weight_slot(1, "weight1", mock_weight)
mock_single_layer_impl.add_weight_slot.assert_called_once_with(
mock_single_layer_impl.add_single_weight_slot.assert_called_once_with(
1, "weight1", mock_weight)
# add_host_weight
mock_host_weight = MagicMock()
layer._add_host_weight(2, "weight2", mock_host_weight)
mock_single_layer_impl.add_host_weight.assert_called_once_with(
mock_single_layer_impl.add_single_host_weight.assert_called_once_with(
2, "weight2", mock_host_weight)
# set_initial_weight_assignments
@@ -215,7 +218,8 @@ class TestMoeLoadBalancer(unittest.TestCase):
# wait_for_gpu_stage
mock_wait.return_value = torch.tensor([1])
result = layer.wait_for_gpu_stage()
layer.wait_for_gpu_stage()
result = layer.statistic_flag_tensor
mock_wait.assert_called_once_with(
mock_single_layer_impl.get_pointer())
self.assertEqual(result, mock_wait.return_value)
@@ -228,7 +232,8 @@ class TestMoeLoadBalancer(unittest.TestCase):
# statistic
mock_expert_ids = torch.tensor([[0, 1], [2, 3]])
mock_enabled = torch.tensor([1])
layer.statistic(mock_expert_ids, mock_enabled, True, False)
layer.statistic_flag_tensor = mock_enabled
layer.statistic(mock_expert_ids, True, False)
mock_statistic.assert_called_once_with(
mock_expert_ids, mock_enabled,
mock_single_layer_impl.get_pointer(), True, False)
@@ -237,8 +242,6 @@ class TestMoeLoadBalancer(unittest.TestCase):
mock_selected_experts = torch.tensor([[0, 1], [2, 3]])
mock_route.return_value = torch.tensor([[0, 1], [2, 3]])
result = layer.route(mock_selected_experts)
mock_route.assert_called_once_with(
mock_selected_experts, mock_single_layer_impl.get_pointer())
assert torch.equal(result, mock_route.return_value)
@patch('tensorrt_llm.bindings.internal.runtime.MoeLoadBalancer')
@@ -260,14 +263,13 @@ class TestMoeLoadBalancer(unittest.TestCase):
mock_load_balancer_impl.return_value.set_warm_up_iter_count.assert_called_once_with(
10)
# start_iter
balancer.start_iter(1, True, True)
mock_load_balancer_impl.return_value.start_iter.assert_called_once_with(
1, True, True)
balancer.set_next_iter_info(True, True)
# end_iter
balancer.end_iter(1)
mock_load_balancer_impl.return_value.end_iter.assert_called_once_with(1)
with MoeLoadBalancerIterContext(balancer):
mock_load_balancer_impl.return_value.start_iter.assert_called_once_with(
0, True, True)
mock_load_balancer_impl.return_value.end_iter.assert_called_once_with(0)
# shutdown
balancer.shutdown()
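The assertions above capture the API migration: callers no longer invoke start_iter/end_iter with an explicit iteration id. Instead they declare the next iteration's behavior with set_next_iter_info, and MoeLoadBalancerIterContext issues start_iter/end_iter with an internally tracked id (the test expects 0 for the first iteration). A minimal sketch, with an illustrative helper name:

from tensorrt_llm._torch.modules.fused_moe.moe_load_balancer import (
    MoeLoadBalancer, MoeLoadBalancerIterContext)


def run_iterations(balancer: MoeLoadBalancer, num_iters: int) -> None:
    """Drive num_iters iterations under the new API (sketch only)."""
    for _ in range(num_iters):
        # Old flow (removed): balancer.start_iter(iter_id, True, True) ... balancer.end_iter(iter_id)
        balancer.set_next_iter_info(True, True)     # statistic on, weight update on
        with MoeLoadBalancerIterContext(balancer):  # issues start_iter/end_iter internally
            pass                                    # per-layer wait/statistic/route/set_cpu_stage goes here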
@@ -288,6 +290,8 @@ class TestMoeLoadBalancer(unittest.TestCase):
# Create a real MoeLoadBalancer
balancer = MoeLoadBalancer(ep_rank, ep_size, 1)
balancer.set_use_gpu_memcpy(True)
# Add a layer with initial weight assignments
# Each slot is assigned to exactly one expert initially
layer = balancer.add_layer(expert_count, top_k, slots_per_rank)
@@ -297,9 +301,8 @@ class TestMoeLoadBalancer(unittest.TestCase):
# Finalize the model
balancer.finalize_model()
# Start iteration - enable statistic, disable weight update
iter_id = 0
balancer.start_iter(iter_id, True, False)
# enable statistic, disable weight update
balancer.set_next_iter_info(True, False)
# Create sample token data - each token selects 2 experts
# 4 tokens, each selecting 2 experts
@@ -314,17 +317,15 @@ class TestMoeLoadBalancer(unittest.TestCase):
device="cuda")
try:
# Wait for GPU stage and get enabled flag
enabled = layer.wait_for_gpu_stage()
with MoeLoadBalancerIterContext(balancer):
# Wait for GPU stage and get enabled flag
layer.wait_for_gpu_stage()
# Run statistic - just test it runs without error
layer.statistic(gathered_raw_expert_ids, enabled, True, True)
# Run statistic - just test it runs without error
layer.statistic(gathered_raw_expert_ids, True, True)
# Set CPU stage to signal completion
layer.set_cpu_stage()
# End iteration
balancer.end_iter(iter_id)
# Set CPU stage to signal completion
layer.set_cpu_stage()
# Test passed if we got here without exceptions
self.assertTrue(True, "Statistic kernel ran successfully")
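Stripped of the test scaffolding, the per-layer flow inside one iteration is: wait for the GPU stage (the enabled flag now lands in layer.statistic_flag_tensor instead of being returned), feed the gathered expert ids to statistic, then hand control back with set_cpu_stage; the routing test below follows the same bracket. A sketch, assuming balancer and layer were created as in the test (add_layer, initial weight assignments, finalize_model); the tensor shape and dtype are illustrative.

import torch

from tensorrt_llm._torch.modules.fused_moe.moe_load_balancer import (
    MoeLoadBalancerIterContext)

gathered_raw_expert_ids = torch.tensor([[0, 1], [2, 3], [0, 2], [1, 3]],
                                       dtype=torch.int32, device="cuda")

balancer.set_next_iter_info(True, True)    # enable statistic and weight update
with MoeLoadBalancerIterContext(balancer):
    layer.wait_for_gpu_stage()             # enabled flag -> layer.statistic_flag_tensor
    layer.statistic(gathered_raw_expert_ids, True, True)
    # Routing uses the same bracket: routed_slots = layer.route(token_selected_experts)
    layer.set_cpu_stage()                  # signal that the CPU stage is done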
@@ -350,6 +351,8 @@ class TestMoeLoadBalancer(unittest.TestCase):
# Create a real MoeLoadBalancer
balancer = MoeLoadBalancer(ep_rank, ep_size, 1)
balancer.set_use_gpu_memcpy(True)
# Add a layer with known initial weight assignments
layer = balancer.add_layer(expert_count, top_k, slots_per_rank)
@@ -360,9 +363,8 @@ class TestMoeLoadBalancer(unittest.TestCase):
# Finalize the model
balancer.finalize_model()
# Start iteration - enable statistic, disable weight update
iter_id = 0
balancer.start_iter(iter_id, True, False)
# enable statistic, disable weight update
balancer.set_next_iter_info(True, False)
# Create sample token data - tokens selecting different experts
token_selected_experts = torch.tensor(
@@ -376,17 +378,15 @@ class TestMoeLoadBalancer(unittest.TestCase):
device="cuda")
try:
# Wait for GPU stage
layer.wait_for_gpu_stage()
with MoeLoadBalancerIterContext(balancer):
# Wait for GPU stage
layer.wait_for_gpu_stage()
# Run routing
routed_slots = layer.route(token_selected_experts)
# Run routing
routed_slots = layer.route(token_selected_experts)
# Set CPU stage
layer.set_cpu_stage()
# End iteration
balancer.end_iter(iter_id)
# Set CPU stage
layer.set_cpu_stage()
# Verify results - with our initial assignment, expert i should map to slot i
expected_slots = torch.tensor(

View File

@@ -136,6 +136,8 @@ class TestMoePythonBindings(unittest.TestCase):
ep_size=self.ep_size,
layer_updates_per_iter=self.layer_updates_per_iter)
balancer.set_use_gpu_memcpy(True)
# Add a layer
layer = balancer.add_layer(expert_count=self.expert_count,
top_k=self.top_k,
@@ -206,6 +208,8 @@ class TestMoePythonBindings(unittest.TestCase):
ep_size=self.ep_size,
layer_updates_per_iter=self.layer_updates_per_iter)
balancer.set_use_gpu_memcpy(True)
# Create initial weight assignments
initial_assignments = []
for r in range(self.ep_size):