Mirror of https://github.com/NVIDIA/TensorRT-LLM.git (synced 2026-01-13 22:18:36 +08:00)

feat: large-scale EP (part 6: Online EP load balancer integration for GB200 nvfp4) (#4818)

Signed-off-by: Dongxu Yang <78518666+dongxuy04@users.noreply.github.com>
Signed-off-by: ShiXiaowei02 <39303645+Shixiaowei02@users.noreply.github.com>
Co-authored-by: ShiXiaowei02 <39303645+Shixiaowei02@users.noreply.github.com>

Parent: 5ee0de7f2a
Commit: 1e369658f1
@ -1,7 +1,7 @@
version: "3.9"
services:
  tensorrt_llm-dev:
    image: urm.nvidia.com/sw-tensorrt-docker/tensorrt-llm:pytorch-25.04-py3-x86_64-ubuntu24.04-trt10.10.0.31-skip-tritondevel-202505211401-4539
    image: urm.nvidia.com/sw-tensorrt-docker/tensorrt-llm:pytorch-25.04-py3-x86_64-ubuntu24.04-trt10.10.0.31-skip-tritondevel-202506021004-9420
    network_mode: host
    ipc: host
@ -1,3 +1,6 @@
import os
import sys

from conan import ConanFile
from conan.tools.cmake import CMakeDeps, CMakeToolchain

@ -9,10 +12,22 @@ class TensorRT_LLM(ConanFile):
    virtualrunenv = False

    def requirements(self):
        pass # TODO add dependencies here
        self.requires("libnuma/system")

    def generate(self):
        cmake = CMakeDeps(self)
        cmake.generate()
        tc = CMakeToolchain(self)
        tc.generate()

    def build_requirements(self):
        # register libnuma_conan.py for conan
        base_dir = os.path.dirname(os.path.abspath(__file__))
        libnuma_path = os.path.join(base_dir, "libnuma_conan.py")
        conan_bin = os.path.abspath(sys.argv[0])
        if not os.path.isfile(conan_bin) or not os.access(conan_bin, os.X_OK):
            raise RuntimeError(f"Conan binary not found {sys.argv[0]}")

        self.run(
            f"{conan_bin} export {libnuma_path} --name=libnuma --version=system"
        )
cpp/libnuma_conan.py (new file, 36 lines)
@ -0,0 +1,36 @@
from conan import ConanFile


class LibnumaSystemConan(ConanFile):
    name = "libnuma"
    version = "system"
    package_type = "shared-library"
    settings = "os", "arch"

    def package_info(self):
        if self.settings.os == "Windows":
            self.output.info("libnuma not needed on Windows.")
            return

        self.cpp_info.includedirs = ["/usr/include"]
        libdirs = []

        arch = str(self.settings.arch)
        os_name = str(self.settings.os)

        if os_name == "Linux":
            if arch == "x86_64":
                libdirs.append("/usr/lib/x86_64-linux-gnu")
            elif arch in ["armv8", "aarch64"]:
                libdirs.append("/usr/lib/aarch64-linux-gnu")
            else:
                self.output.warn(
                    f"Unrecognized architecture: {arch}, falling back to /usr/lib"
                )
                libdirs.append("/usr/lib")
        else:
            self.output.warn(f"Unsupported OS: {os_name}, assuming /usr/lib")
            libdirs.append("/usr/lib")

        self.cpp_info.libdirs = libdirs
        self.cpp_info.system_libs = ["numa"]
@ -245,12 +245,166 @@ void moeStatisticDevice(MoeLoadBalanceMetaInfo metaInfo, MoeLoadBalanceStatistic
|
||||
}
|
||||
}
|
||||
|
||||
template <int MAX_EXPERT_COUNT = 1024, int THREAD_COUNT = 256, int ITEM_PER_THREAD = 4>
|
||||
__global__ void moeComputeRouteNoRedundantKernel(MoeLoadBalanceMetaInfo metaInfo, MoePlacementInfo placementInfo,
|
||||
int* const tokenSelectedExperts, int* tokenRoutedSlotIds, int tokenCount)
|
||||
{
|
||||
extern __shared__ int16_t sharedGlobalSlotIdsInfo[];
|
||||
int expertIds[ITEM_PER_THREAD];
|
||||
int slotIds[ITEM_PER_THREAD];
|
||||
for (int slotId = threadIdx.x; slotId < metaInfo.epSize * metaInfo.slotCountPerRank; slotId += THREAD_COUNT)
|
||||
{
|
||||
sharedGlobalSlotIdsInfo[slotId] = placementInfo.globalSlotIds[slotId];
|
||||
}
|
||||
|
||||
int blockOffset = blockIdx.x * THREAD_COUNT * ITEM_PER_THREAD;
|
||||
|
||||
for (; blockOffset < tokenCount * metaInfo.topK; blockOffset += gridDim.x * THREAD_COUNT * ITEM_PER_THREAD)
|
||||
{
|
||||
int tokenIdxBase = blockOffset + threadIdx.x;
|
||||
#pragma unroll
|
||||
for (int i = 0; i < ITEM_PER_THREAD; i++)
|
||||
{
|
||||
int tokenIdx = tokenIdxBase + i * THREAD_COUNT;
|
||||
expertIds[i]
|
||||
= tokenIdx < tokenCount * metaInfo.topK ? tokenSelectedExperts[tokenIdx] : metaInfo.expertCount;
|
||||
}
|
||||
#pragma unroll
|
||||
for (int i = 0; i < ITEM_PER_THREAD; i++)
|
||||
{
|
||||
if (expertIds[i] < 0 || expertIds[i] >= metaInfo.expertCount)
|
||||
{
|
||||
expertIds[i] = metaInfo.expertCount;
|
||||
}
|
||||
}
|
||||
if (blockOffset == blockIdx.x * THREAD_COUNT * ITEM_PER_THREAD)
|
||||
{
|
||||
__syncthreads();
|
||||
}
|
||||
#pragma unroll
|
||||
for (int i = 0; i < ITEM_PER_THREAD; i++)
|
||||
{
|
||||
slotIds[i] = expertIds[i] < metaInfo.expertCount ? sharedGlobalSlotIdsInfo[expertIds[i]]
|
||||
: metaInfo.epSize * metaInfo.slotCountPerRank;
|
||||
}
|
||||
#pragma unroll
|
||||
for (int i = 0; i < ITEM_PER_THREAD; i++)
|
||||
{
|
||||
int tokenIdx = tokenIdxBase + i * THREAD_COUNT;
|
||||
if (tokenIdx < tokenCount * metaInfo.topK)
|
||||
{
|
||||
tokenRoutedSlotIds[tokenIdx] = slotIds[i];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <int MAX_EXPERT_COUNT = 1024, int THREAD_COUNT = 256, int ITEM_PER_THREAD = 4>
|
||||
__global__ void moeComputeRouteKernel(MoeLoadBalanceMetaInfo metaInfo, MoePlacementInfo placementInfo,
|
||||
int* const tokenSelectedExperts, int* tokenRoutedSlotIds, int tokenCount, bool offsetByEpRank)
|
||||
{
|
||||
int warpId = threadIdx.x / 32;
|
||||
int laneId = threadIdx.x % 32;
|
||||
static int const kWarpCount = THREAD_COUNT / 32;
|
||||
extern __shared__ int16_t sharedGlobalSlotIdsInfo[];
|
||||
__shared__ int sharedExpertReplicaCountAndStartOffset[MAX_EXPERT_COUNT];
|
||||
|
||||
__shared__ int sharedArbitrateExpertId[THREAD_COUNT * ITEM_PER_THREAD];
|
||||
__shared__ int sharedExpertCount[MAX_EXPERT_COUNT];
|
||||
for (int expertIdx = threadIdx.x; expertIdx < metaInfo.expertCount; expertIdx += THREAD_COUNT)
|
||||
{
|
||||
int replicaCount = placementInfo.expertReplicaCount[expertIdx];
|
||||
int replicaStartOffset = placementInfo.expertReplicaStartOffset[expertIdx];
|
||||
sharedExpertReplicaCountAndStartOffset[expertIdx] = (replicaCount << 16) | replicaStartOffset;
|
||||
sharedExpertCount[expertIdx] = 0;
|
||||
}
|
||||
for (int slotId = threadIdx.x; slotId < metaInfo.epSize * metaInfo.slotCountPerRank; slotId += THREAD_COUNT)
|
||||
{
|
||||
sharedGlobalSlotIdsInfo[slotId] = placementInfo.globalSlotIds[slotId];
|
||||
}
|
||||
|
||||
int expertIds[ITEM_PER_THREAD];
|
||||
int tokenIdxBase = blockIdx.x * THREAD_COUNT * ITEM_PER_THREAD + threadIdx.x;
|
||||
#pragma unroll
|
||||
for (int i = 0; i < ITEM_PER_THREAD; i++)
|
||||
{
|
||||
int tokenIdx = tokenIdxBase + i * THREAD_COUNT;
|
||||
expertIds[i] = tokenIdx < tokenCount * metaInfo.topK ? tokenSelectedExperts[tokenIdx] : metaInfo.expertCount;
|
||||
}
|
||||
#pragma unroll
|
||||
for (int i = 0; i < ITEM_PER_THREAD; i++)
|
||||
{
|
||||
if (expertIds[i] < 0 || expertIds[i] >= metaInfo.expertCount)
|
||||
{
|
||||
expertIds[i] = metaInfo.expertCount;
|
||||
}
|
||||
}
|
||||
__syncthreads();
|
||||
#pragma unroll
|
||||
for (int i = 0; i < ITEM_PER_THREAD; i++)
|
||||
{
|
||||
int countAndStart
|
||||
= expertIds[i] < metaInfo.expertCount ? sharedExpertReplicaCountAndStartOffset[expertIds[i]] : (1 << 16);
|
||||
int arbitrateExpertId = (countAndStart >> 16) > 1 ? expertIds[i] : metaInfo.expertCount;
|
||||
sharedArbitrateExpertId[threadIdx.x + i * THREAD_COUNT] = arbitrateExpertId;
|
||||
}
|
||||
__syncthreads();
|
||||
int baseOffset = blockIdx.x + (offsetByEpRank ? metaInfo.epRank : 0);
|
||||
if (warpId == 0)
|
||||
{
|
||||
#pragma unroll
|
||||
for (int i = 0; i < kWarpCount * ITEM_PER_THREAD; ++i)
|
||||
{
|
||||
int expertId = sharedArbitrateExpertId[laneId + i * 32];
|
||||
__syncwarp();
|
||||
unsigned match = __match_any_sync(0xFFFFFFFF, expertId);
|
||||
int leader = __ffs(match) - 1;
|
||||
int matchCount = __popc(match);
|
||||
int oldVal = 0;
|
||||
if (laneId == leader && expertId < metaInfo.expertCount)
|
||||
{
|
||||
oldVal = atomicAdd_block(&sharedExpertCount[expertId], matchCount);
|
||||
}
|
||||
__syncwarp();
|
||||
oldVal = __shfl_sync(0XFFFFFFFF, oldVal, leader);
|
||||
unsigned lowerMask = match & ((1u << laneId) - 1);
|
||||
int rankInGroup = __popc(lowerMask);
|
||||
int offset = oldVal + rankInGroup;
|
||||
offset += baseOffset;
|
||||
sharedArbitrateExpertId[laneId + i * 32] = offset;
|
||||
}
|
||||
}
|
||||
__syncthreads();
|
||||
int targetGlobalSlotId[ITEM_PER_THREAD];
|
||||
#pragma unroll
|
||||
for (int i = 0; i < ITEM_PER_THREAD; i++)
|
||||
{
|
||||
int countAndStart
|
||||
= expertIds[i] < metaInfo.expertCount ? sharedExpertReplicaCountAndStartOffset[expertIds[i]] : (1 << 16);
|
||||
int count = countAndStart >> 16;
|
||||
int offset = countAndStart & 0xFFFF;
|
||||
int arbitratedIndex = sharedArbitrateExpertId[threadIdx.x + i * THREAD_COUNT];
|
||||
offset += arbitratedIndex % count;
|
||||
targetGlobalSlotId[i] = expertIds[i] < metaInfo.expertCount ? sharedGlobalSlotIdsInfo[offset]
|
||||
: metaInfo.epSize * metaInfo.slotCountPerRank;
|
||||
}
|
||||
#pragma unroll
|
||||
for (int i = 0; i < ITEM_PER_THREAD; i++)
|
||||
{
|
||||
int tokenIdx = tokenIdxBase + i * THREAD_COUNT;
|
||||
if (tokenIdx < tokenCount * metaInfo.topK)
|
||||
{
|
||||
tokenRoutedSlotIds[tokenIdx] = targetGlobalSlotId[i];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <int MAX_EXPERT_COUNT = 1024, int THREAD_COUNT = 256, int ITEM_PER_THREAD = 4>
|
||||
__global__ void moeComputeRouteSortKernel(MoeLoadBalanceMetaInfo metaInfo, MoePlacementInfo placementInfo,
|
||||
int* const tokenSelectedExperts, int* tokenRoutedSlotIds, int tokenCount, bool offsetByEpRank)
|
||||
{
|
||||
using BlockSort = cub::BlockRadixSort<int, THREAD_COUNT, 1>;
|
||||
extern __shared__ int sharedGlobalSlotIdsInfo[];
|
||||
extern __shared__ int16_t sharedGlobalSlotIdsInfo[];
|
||||
|
||||
__shared__ typename BlockSort::TempStorage tempStorage;
|
||||
|
||||
@ -361,9 +515,19 @@ void moeComputeRouteDevice(MoeLoadBalanceMetaInfo metaInfo, MoePlacementInfo pla
    constexpr int kThreadCount = 256;
    constexpr int kEltPerThread = 4;
    int blockCount = (tokenCount * metaInfo.topK + kThreadCount * kEltPerThread - 1) / (kThreadCount * kEltPerThread);
    int dynamicShmSize = sizeof(int) * metaInfo.epSize * metaInfo.slotCountPerRank;
    moeComputeRouteKernel<1024, kThreadCount, kEltPerThread><<<blockCount, kThreadCount, dynamicShmSize, stream>>>(
        metaInfo, placementInfo, tokenSelectedExperts, tokenRoutedSlotIds, tokenCount, offsetByEpRank);
    int dynamicShmSize = sizeof(int16_t) * metaInfo.epSize * metaInfo.slotCountPerRank;
    if (metaInfo.expertCount == metaInfo.epSize * metaInfo.slotCountPerRank)
    {
        // no redundant expert, so we don't need complex routing; just assign to the correct slot.
        moeComputeRouteNoRedundantKernel<1024, kThreadCount, kEltPerThread>
            <<<blockCount, kThreadCount, dynamicShmSize, stream>>>(
                metaInfo, placementInfo, tokenSelectedExperts, tokenRoutedSlotIds, tokenCount);
    }
    else
    {
        moeComputeRouteKernel<1024, kThreadCount, kEltPerThread><<<blockCount, kThreadCount, dynamicShmSize, stream>>>(
            metaInfo, placementInfo, tokenSelectedExperts, tokenRoutedSlotIds, tokenCount, offsetByEpRank);
    }
}

void moeWaitSignalForCpuStageHost(MoeLoadBalanceSingleLayerSignal* signal)
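
For intuition, the launch sizing in moeComputeRouteDevice above can be reproduced with a short sketch; the workload and EP sizes below are hypothetical, while kThreadCount = 256 and kEltPerThread = 4 match the constants in the code. Switching the slot-id buffer to int16_t halves the dynamic shared memory requested here.

# Sketch only: reproduces the launch-size arithmetic with hypothetical values.
token_count, top_k = 8192, 8             # hypothetical workload
ep_size, slot_count_per_rank = 8, 36     # hypothetical EP configuration
k_thread_count, k_elt_per_thread = 256, 4

items = token_count * top_k
block_count = (items + k_thread_count * k_elt_per_thread - 1) // (k_thread_count * k_elt_per_thread)
dynamic_shm_bytes = 2 * ep_size * slot_count_per_rank   # sizeof(int16_t) == 2
print(block_count, dynamic_shm_bytes)                   # 64 blocks, 576 bytes of dynamic shared memory
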
@ -16,7 +16,7 @@
 */

#include "moeBindings.h"
#include "tensorrt_llm/runtime/moeLoadBalancer.h"
#include "tensorrt_llm/runtime/moeLoadBalancer/moeLoadBalancer.h"
#include <pybind11/functional.h>
#include <pybind11/numpy.h>
#include <pybind11/pybind11.h>
@ -98,6 +98,8 @@ void initMoeBindings(pybind11::module_& m)
    py::class_<tr::MoeLoadBalancer>(m, "MoeLoadBalancer")
        .def(py::init<int, int, int>(), py::arg("ep_rank"), py::arg("ep_size"), py::arg("layer_updates_per_iter"),
            "Initialize the MoeLoadBalancer with the specified expert parallel rank, size, and update frequency")
        .def("set_use_gpu_memcpy", &tr::MoeLoadBalancer::setUseGpuMemcpy, py::arg("use_gpu_memcpy"),
            "Set whether to use GPU memcpy for weight updates")
        .def("add_layer", &tr::MoeLoadBalancer::AddLayer, py::arg("expert_count"), py::arg("top_k"),
            py::arg("slot_count_per_rank"), "Add a new MOE layer to the load balancer")
        .def("finalize_model", &tr::MoeLoadBalancer::finalizeModel,
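
Based on the bindings registered above, Python-side usage might look like the sketch below. The import path and the layer sizes are assumptions; only the constructor arguments and method names come from the diff.

# Illustrative only: module path and layer sizes are assumed, method names come from the bindings above.
from tensorrt_llm.bindings.internal.runtime import MoeLoadBalancer  # assumed import path

balancer = MoeLoadBalancer(ep_rank=0, ep_size=8, layer_updates_per_iter=1)
balancer.set_use_gpu_memcpy(False)   # default path: weight updates via the CPU copy engine
balancer.add_layer(expert_count=256, top_k=8, slot_count_per_rank=36)
balancer.finalize_model()
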
@ -43,7 +43,8 @@ set(SRCS
    ipcNvlsMemory.cpp
    mcastDeviceMemory.cpp
    memoryCounters.cpp
    moeLoadBalancer.cpp
    moeLoadBalancer/moeLoadBalancer.cpp
    moeLoadBalancer/topologyDetector.cpp
    ncclCommunicator.cpp
    promptTuningParams.cpp
    runtimeKernels.cu
@ -80,3 +81,27 @@ target_include_directories(runtime_src PRIVATE ${MPI_C_INCLUDE_DIRS})
if(ENABLE_MULTI_DEVICE)
  target_link_libraries(runtime_src PUBLIC ${NCCL_LIB})
endif()

if(NOT WIN32)
  find_package(libnuma QUIET CONFIG)

  if(NOT libnuma_FOUND)
    message(
      STATUS "libnuma not found via Conan, falling back to system libnuma")
    find_path(NUMA_INCLUDE_DIR numa.h)
    find_library(NUMA_LIBRARY numa)

    if(NUMA_INCLUDE_DIR AND NUMA_LIBRARY)
      add_library(libnuma::libnuma UNKNOWN IMPORTED)
      set_target_properties(
        libnuma::libnuma
        PROPERTIES IMPORTED_LOCATION "${NUMA_LIBRARY}"
                   INTERFACE_INCLUDE_DIRECTORIES "${NUMA_INCLUDE_DIR}")
    else()
      message(FATAL_ERROR "NUMA library not found, please install libnuma-dev")
    endif()
  else()
    message(STATUS "libnuma found.")
  endif()
  target_link_libraries(runtime_src PUBLIC libnuma::libnuma)
endif()
@ -14,10 +14,11 @@
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "tensorrt_llm/runtime/moeLoadBalancer.h"
|
||||
#include "moeLoadBalancer.h"
|
||||
#include "tensorrt_llm/common/assert.h"
|
||||
#include "tensorrt_llm/common/cudaUtils.h"
|
||||
#include "tensorrt_llm/kernels/moeLoadBalance/moeLoadBalanceKernels.h"
|
||||
#include "topologyDetector.h"
|
||||
#include <algorithm>
|
||||
#include <atomic>
|
||||
#include <cassert>
|
||||
@ -33,7 +34,6 @@
|
||||
|
||||
namespace tensorrt_llm::runtime
|
||||
{
|
||||
|
||||
// Helper structure to hold replica information
|
||||
struct ReplicaInfo
|
||||
{
|
||||
@ -271,6 +271,7 @@ void allocateStatisticInfo(tensorrt_llm::kernels::MoeLoadBalanceMetaInfo const&
|
||||
tensorrt_llm::kernels::MoeLoadBalanceStatisticInfo* statisticInfo)
|
||||
{
|
||||
TLLM_CUDA_CHECK(cudaMallocHost(&statisticInfo->expertLoadFactor, sizeof(float) * metaInfo.expertCount));
|
||||
std::fill_n(statisticInfo->expertLoadFactor, metaInfo.expertCount, 0.0f);
|
||||
TLLM_CHECK_WITH_INFO(statisticInfo->rawDataWindowSize > 0, "statisticInfo->rawDataWindowSize should be > 0.");
|
||||
TLLM_CUDA_CHECK(cudaMalloc(
|
||||
&statisticInfo->expertTokenCount, sizeof(int) * metaInfo.expertCount * statisticInfo->rawDataWindowSize));
|
||||
@ -288,9 +289,9 @@ void freeStatisticInfo(tensorrt_llm::kernels::MoeLoadBalanceStatisticInfo* stati
|
||||
}
|
||||
|
||||
void allocatePlacementInfo(tensorrt_llm::kernels::MoeLoadBalanceMetaInfo const& metaInfo,
|
||||
tensorrt_llm::kernels::MoePlacementInfo* placementInfo, bool isCpu = false)
|
||||
tensorrt_llm::kernels::MoePlacementInfo* placementInfo, bool isCpu = false, bool useManaged = false)
|
||||
{
|
||||
auto allocFn = [isCpu](void** ptr, size_t size)
|
||||
auto allocFn = [isCpu, useManaged](void** ptr, size_t size)
|
||||
{
|
||||
if (isCpu)
|
||||
{
|
||||
@ -298,7 +299,21 @@ void allocatePlacementInfo(tensorrt_llm::kernels::MoeLoadBalanceMetaInfo const&
|
||||
}
|
||||
else
|
||||
{
|
||||
return cudaMalloc(ptr, size);
|
||||
if (useManaged)
|
||||
{
|
||||
TLLM_CUDA_CHECK(cudaMallocManaged(ptr, size));
|
||||
int cur_dev;
|
||||
TLLM_CUDA_CHECK(cudaGetDevice(&cur_dev));
|
||||
TLLM_CUDA_CHECK(cudaMemAdvise(*ptr, size, cudaMemAdviseSetPreferredLocation, cur_dev));
|
||||
TLLM_CUDA_CHECK(cudaMemAdvise(*ptr, size, cudaMemAdviseSetAccessedBy, cur_dev));
|
||||
TLLM_CUDA_CHECK(cudaMemAdvise(*ptr, size, cudaMemAdviseSetAccessedBy, cudaCpuDeviceId));
|
||||
TLLM_CUDA_CHECK(cudaMemset(*ptr, 0, size));
|
||||
return cudaSuccess;
|
||||
}
|
||||
else
|
||||
{
|
||||
return cudaMalloc(ptr, size);
|
||||
}
|
||||
}
|
||||
};
|
||||
TLLM_CUDA_CHECK(
|
||||
@ -405,7 +420,7 @@ void SingleLayerMoeLoadBalancer::createResources()
|
||||
}
|
||||
|
||||
allocatePlacementInfo(mMetaInfo, &mCpuPlacementInfo.placementInfoForGPU, true);
|
||||
allocatePlacementInfo(mMetaInfo, &mGpuPlacement, false);
|
||||
allocatePlacementInfo(mMetaInfo, &mGpuPlacement, false, true);
|
||||
|
||||
mSingleLayerSignal = allocateSingleLayerSignal();
|
||||
TLLM_CUDA_CHECK(cudaEventCreate(&mUpdateWeightsDoneEvent));
|
||||
@ -451,7 +466,18 @@ void SingleLayerMoeLoadBalancer::maybeStartUpdateWeights()
|
||||
{
|
||||
if (mIterId >= 0 && mUpdateWeightsEnabled)
|
||||
{
|
||||
mMoeLoadBalancer->addUpdateTask([this] { updateWeightsRoutine(); });
|
||||
mMoeLoadBalancer->addUpdateTask(
|
||||
[this]
|
||||
{
|
||||
if (mMoeLoadBalancer->mUseGpuMemcpy)
|
||||
{
|
||||
updateWeightsRoutine();
|
||||
}
|
||||
else
|
||||
{
|
||||
updateWeightsRoutineByCpu();
|
||||
}
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
@ -461,6 +487,7 @@ void SingleLayerMoeLoadBalancer::waitLastUpdateDone()
|
||||
{
|
||||
std::unique_lock<std::mutex> lock(mUpdateWeightsMutex);
|
||||
mUpdateWeightsCondition.wait(lock, [this] { return mUpdateWeightsDone; });
|
||||
lock.unlock();
|
||||
}
|
||||
}
|
||||
|
||||
@ -485,6 +512,25 @@ void SingleLayerMoeLoadBalancer::copyPlacementInfoToGpu()
|
||||
{
|
||||
std::fill_n(mCpuPlacementInfo.rankExpertIds[i].begin(), mMetaInfo.slotCountPerRank, -1);
|
||||
}
|
||||
// clear expert load factor for next statistic
|
||||
std::fill_n(mStatisticInfo.expertLoadFactor, mMetaInfo.expertCount, 0.0f);
|
||||
}
|
||||
|
||||
void SingleLayerMoeLoadBalancer::copyPlacementInfoToGpuByCpu()
|
||||
{
|
||||
memcpy(mGpuPlacement.expertReplicaCount, mCpuPlacementInfo.placementInfoForGPU.expertReplicaCount,
|
||||
sizeof(int) * mMetaInfo.expertCount);
|
||||
memcpy(mGpuPlacement.expertReplicaStartOffset, mCpuPlacementInfo.placementInfoForGPU.expertReplicaStartOffset,
|
||||
sizeof(int) * mMetaInfo.expertCount);
|
||||
memcpy(mGpuPlacement.globalSlotIds, mCpuPlacementInfo.placementInfoForGPU.globalSlotIds,
|
||||
sizeof(int) * mMetaInfo.epSize * mMetaInfo.slotCountPerRank);
|
||||
mCpuPlacementInfo.rankExpertIds.swap(mCpuPlacementInfo.oldRankExpertIds);
|
||||
for (int i = 0; i < mMetaInfo.epSize; ++i)
|
||||
{
|
||||
std::fill_n(mCpuPlacementInfo.rankExpertIds[i].begin(), mMetaInfo.slotCountPerRank, -1);
|
||||
}
|
||||
// clear expert load factor for next statistic
|
||||
std::fill_n(mStatisticInfo.expertLoadFactor, mMetaInfo.expertCount, 0.0f);
|
||||
}
|
||||
|
||||
void SingleLayerMoeLoadBalancer::updateWeightsRoutine()
|
||||
@ -501,6 +547,21 @@ void SingleLayerMoeLoadBalancer::updateWeightsRoutine()
|
||||
mUpdateWeightsCondition.notify_one();
|
||||
}
|
||||
|
||||
void SingleLayerMoeLoadBalancer::updateWeightsRoutineByCpu()
|
||||
{
|
||||
doReplication(mMetaInfo, mStatisticInfo.expertLoadFactor, &mCpuPlacementInfo);
|
||||
doPlacement(mMetaInfo, mStatisticInfo.expertLoadFactor, &mCpuPlacementInfo);
|
||||
prepareGpuPlacementInfo(mMetaInfo, &mCpuPlacementInfo);
|
||||
mLastUpdateTaskId = mMoeLoadBalancer->addCopyTask(
|
||||
[this](int rank, int size) { mWeightUpdater->updateWeights(&mCpuPlacementInfo, rank, size); });
|
||||
mMoeLoadBalancer->waitCopyTaskDone(mLastUpdateTaskId);
|
||||
mLastUpdateTaskId = -1;
|
||||
copyPlacementInfoToGpuByCpu();
|
||||
std::unique_lock<std::mutex> lock(mUpdateWeightsMutex);
|
||||
mUpdateWeightsDone = true;
|
||||
mUpdateWeightsCondition.notify_one();
|
||||
}
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
// Weight Updater
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
@ -646,8 +707,61 @@ void HostMemoryMoeWeightUpdater::copyWeights(MoeWeight const& src, MoeWeight con
|
||||
}
|
||||
}
|
||||
|
||||
void HostMemoryMoeWeightUpdater::updateWeights(tensorrt_llm::runtime::MoePlacementCpuInfo const* placementCpuInfo)
|
||||
void HostMemoryMoeWeightUpdater::copyWeightsCpu(MoeWeight const& src, MoeWeight const& dst, int rank, int size)
|
||||
{
|
||||
TLLM_CHECK(src.mWeightPtr != nullptr && dst.mWeightPtr != nullptr);
|
||||
TLLM_CHECK(src.mHeight == dst.mHeight && src.mWidth == dst.mWidth);
|
||||
char* srcPtr = static_cast<char*>(src.mWeightPtr);
|
||||
char* dstPtr = static_cast<char*>(dst.mWeightPtr);
|
||||
size_t singleCopySize, copyCount, srcPitch, dstPitch;
|
||||
if (src.mPitch == src.mWidth && dst.mPitch == dst.mWidth)
|
||||
{
|
||||
singleCopySize = src.mWidth * src.mHeight;
|
||||
copyCount = 1;
|
||||
srcPitch = singleCopySize;
|
||||
dstPitch = singleCopySize;
|
||||
}
|
||||
else
|
||||
{
|
||||
singleCopySize = src.mWidth;
|
||||
copyCount = src.mHeight;
|
||||
srcPitch = src.mPitch;
|
||||
dstPitch = dst.mPitch;
|
||||
}
|
||||
size_t fullCopyCount = copyCount / size * size;
|
||||
size_t threadCopyCount = fullCopyCount / size;
|
||||
for (size_t i = rank * threadCopyCount; i < (rank + 1) * threadCopyCount; i++)
|
||||
{
|
||||
memcpy(dstPtr + i * dstPitch, srcPtr + i * srcPitch, singleCopySize);
|
||||
}
|
||||
size_t threadStartOffset = rank * singleCopySize / size;
|
||||
size_t threadEndOffset = (rank + 1) * singleCopySize / size;
|
||||
size_t threadCopySize = threadEndOffset - threadStartOffset;
|
||||
for (size_t i = fullCopyCount; i < copyCount && threadCopySize > 0; i++)
|
||||
{
|
||||
memcpy(dstPtr + i * dstPitch + threadStartOffset, srcPtr + i * srcPitch + threadStartOffset, threadCopySize);
|
||||
}
|
||||
}
|
||||
|
||||
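
copyWeightsCpu above splits the copy across the worker threads: most rows are assigned whole to a thread, and the leftover rows are split byte-wise so every thread still contributes. A small Python sketch of the index arithmetic (sizes are hypothetical):

# Sketch of the work split in copyWeightsCpu, with hypothetical sizes.
copy_count, single_copy_size = 10, 4096   # 10 rows of 4096 bytes each
size = 4                                  # number of copy threads in the worker pool

full_copy_count = copy_count // size * size     # 8 rows are copied as whole rows
thread_copy_count = full_copy_count // size     # 2 whole rows per thread
for rank in range(size):
    whole_rows = range(rank * thread_copy_count, (rank + 1) * thread_copy_count)
    # the remaining rows (8 and 9 here) are split column-wise across the threads
    start = rank * single_copy_size // size
    end = (rank + 1) * single_copy_size // size
    print(rank, list(whole_rows), (start, end))
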
void PrintUpdateInfo(tensorrt_llm::kernels::MoeLoadBalanceMetaInfo metaInfo,
|
||||
tensorrt_llm::runtime::MoePlacementCpuInfo const* placementCpuInfo)
|
||||
{
|
||||
std::stringstream ss;
|
||||
ss << "[UpdateInfo] rank=" << metaInfo.epRank << ", expert weights=\n [";
|
||||
for (int slotId = 0; slotId < metaInfo.slotCountPerRank * metaInfo.epSize; slotId++)
|
||||
{
|
||||
ss << placementCpuInfo->rankExpertIds[slotId / metaInfo.slotCountPerRank][slotId % metaInfo.slotCountPerRank]
|
||||
<< ", ";
|
||||
}
|
||||
ss << "\n";
|
||||
fprintf(stderr, "%s\n", ss.str().c_str());
|
||||
}
|
||||
|
||||
void HostMemoryMoeWeightUpdater::updateWeights(
|
||||
tensorrt_llm::runtime::MoePlacementCpuInfo const* placementCpuInfo, int rank, int size)
|
||||
{
|
||||
// PrintUpdateInfo(mMetaInfo, placementCpuInfo);
|
||||
bool useGpu = mLayerLoadBalancer->mMoeLoadBalancer->mUseGpuMemcpy;
|
||||
for (int slotId = 0; slotId < mMetaInfo.slotCountPerRank; ++slotId)
|
||||
{
|
||||
int oldExpertId = placementCpuInfo->oldRankExpertIds[mMetaInfo.epRank][slotId];
|
||||
@ -665,7 +779,14 @@ void HostMemoryMoeWeightUpdater::updateWeights(tensorrt_llm::runtime::MoePlaceme
|
||||
auto& name = slotIt->first;
|
||||
auto& slotWeight = slotIt->second[slotId];
|
||||
auto& hostWeight = mHostWeights[name][newExpertId];
|
||||
copyWeights(hostWeight, slotWeight, mLayerLoadBalancer->getStream());
|
||||
if (useGpu)
|
||||
{
|
||||
copyWeights(hostWeight, slotWeight, mLayerLoadBalancer->getStream());
|
||||
}
|
||||
else
|
||||
{
|
||||
copyWeightsCpu(hostWeight, slotWeight, rank, size);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -680,8 +801,40 @@ MoeLoadBalancer::MoeLoadBalancer(int epRank, int epSize, int layerUpdatesPerIter
|
||||
, mLayerUpdatesPerIter{layerUpdatesPerIter}
|
||||
{
|
||||
TLLM_CUDA_CHECK(cudaGetDevice(&mCudaDeviceId));
|
||||
// create a non-blocking stream for compute and update
|
||||
// create a non-blocking stream for compute and update; no longer needed when the CPU copy engine is used.
|
||||
TLLM_CUDA_CHECK(cudaStreamCreateWithFlags(&mStream, cudaStreamNonBlocking));
|
||||
|
||||
auto& topologyDetector = TopologyDetector::getInstance();
|
||||
int currentGpuNumaId = topologyDetector.getCurrentGpuNumaId();
|
||||
int numaCpuCount = topologyDetector.getCurrentGpuNumaCpuCount();
|
||||
int numaGpuCount = topologyDetector.getGpuCountUnderNuma(currentGpuNumaId);
|
||||
TLLM_CHECK_WITH_INFO(
|
||||
numaCpuCount > 0 && numaGpuCount > 0, "numaCpuCount=%d, numaGpuCount=%d", numaCpuCount, numaGpuCount);
|
||||
int cpuCountPerGpu = std::max(1, numaCpuCount / numaGpuCount);
|
||||
std::string cpuArch = topologyDetector.getCpuArchitecture();
|
||||
|
||||
int numCopyThreads = 8;
|
||||
if (getenv("TLLM_LOAD_BALANCE_NUM_COPY_THREADS"))
|
||||
{
|
||||
int numCopyThreadsFromEnv = atoi(getenv("TLLM_LOAD_BALANCE_NUM_COPY_THREADS"));
|
||||
if (numCopyThreadsFromEnv > 0)
|
||||
{
|
||||
TLLM_LOG_INFO(
|
||||
"Setting TLLM_LOAD_BALANCE_NUM_COPY_THREADS to %d by environment variable", numCopyThreadsFromEnv);
|
||||
numCopyThreads = numCopyThreadsFromEnv;
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
if (cpuCountPerGpu > 0)
|
||||
{
|
||||
numCopyThreads = std::min(16, std::max(4, cpuCountPerGpu / 2));
|
||||
TLLM_LOG_INFO("Auto-setting copy threads to %d based on NUMA topology (NUMA node %d, %d CPUs, arch: %s)",
|
||||
numCopyThreads, currentGpuNumaId, numaCpuCount, cpuArch.c_str());
|
||||
}
|
||||
}
|
||||
|
||||
mMultiThreadWorker.reset(new MultiThreadWorker(numCopyThreads));
|
||||
}
|
||||
|
||||
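
The constructor above honors the TLLM_LOAD_BALANCE_NUM_COPY_THREADS environment variable; otherwise the copy-thread count is derived from the NUMA topology. A minimal way to override it from Python, under the assumption that it must be set before the balancer is constructed (the value is read in the constructor):

import os
os.environ["TLLM_LOAD_BALANCE_NUM_COPY_THREADS"] = "8"   # any value > 0 takes effect
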
MoeLoadBalancer::~MoeLoadBalancer() {}
|
||||
@ -730,6 +883,7 @@ void MoeLoadBalancer::finalizeModel()
|
||||
}
|
||||
if (mLayerUpdatesPerIter > 0)
|
||||
{
|
||||
mMultiThreadWorker->start();
|
||||
generateUpdatePlan();
|
||||
startThreads();
|
||||
}
|
||||
@ -783,6 +937,7 @@ void MoeLoadBalancer::shutdown()
|
||||
|
||||
mWorkerThread->join();
|
||||
TLLM_LOG_INFO("MoeLoadBalancer shutdown.");
|
||||
mMultiThreadWorker->stop();
|
||||
}
|
||||
}
|
||||
|
||||
@ -831,7 +986,6 @@ void MoeLoadBalancer::workerThread()
|
||||
}
|
||||
addUpdateTask(nullptr);
|
||||
mComputeAndUpdateThread->join();
|
||||
TLLM_LOG_INFO("MoeLoadBalancer worker thread stopped");
|
||||
}
|
||||
|
||||
void MoeLoadBalancer::computeAndUpdateThread()
|
||||
@ -850,7 +1004,6 @@ void MoeLoadBalancer::computeAndUpdateThread()
|
||||
}
|
||||
task();
|
||||
}
|
||||
TLLM_LOG_INFO("MoeLoadBalancer compute and update thread stopped");
|
||||
}
|
||||
|
||||
void MoeLoadBalancer::addUpdateTask(std::function<void()> task)
|
||||
@ -860,4 +1013,127 @@ void MoeLoadBalancer::addUpdateTask(std::function<void()> task)
|
||||
mUpdateQueueCondition.notify_one();
|
||||
}
|
||||
|
||||
int64_t MoeLoadBalancer::addCopyTask(std::function<void(int, int)> task)
|
||||
{
|
||||
return mMultiThreadWorker->addTask(task);
|
||||
}
|
||||
|
||||
void MoeLoadBalancer::waitCopyTaskDone(int64_t taskId)
|
||||
{
|
||||
if (!mUseGpuMemcpy)
|
||||
{
|
||||
mMultiThreadWorker->waitTaskDone(taskId);
|
||||
}
|
||||
}
|
||||
|
||||
MultiThreadWorker::MultiThreadWorker(int numThreads)
|
||||
: mNumThreads(numThreads)
|
||||
, mRunning(false)
|
||||
, mNextTaskId(0)
|
||||
{
|
||||
}
|
||||
|
||||
MultiThreadWorker::~MultiThreadWorker()
|
||||
{
|
||||
stop();
|
||||
}
|
||||
|
||||
void MultiThreadWorker::start()
|
||||
{
|
||||
std::lock_guard<std::mutex> lk(mMutex);
|
||||
if (mRunning)
|
||||
return;
|
||||
mRunning = true;
|
||||
mThreads.reserve(mNumThreads);
|
||||
for (int i = 0; i < mNumThreads; ++i)
|
||||
{
|
||||
mThreads.emplace_back(&MultiThreadWorker::workerLoop, this, i);
|
||||
}
|
||||
}
|
||||
|
||||
int64_t MultiThreadWorker::addTask(std::function<void(int, int)> func)
|
||||
{
|
||||
auto task = std::make_shared<Task>();
|
||||
{
|
||||
std::lock_guard<std::mutex> lk(mMutex);
|
||||
task->id = mNextTaskId++;
|
||||
task->func = std::move(func);
|
||||
task->remaining = mNumThreads;
|
||||
mTasks.push_back(task);
|
||||
mTaskMap[task->id] = task;
|
||||
}
|
||||
mCondition.notify_all();
|
||||
return task->id;
|
||||
}
|
||||
|
||||
void MultiThreadWorker::waitTaskDone(int64_t taskId)
|
||||
{
|
||||
std::unique_lock<std::mutex> lk(mMutex);
|
||||
auto it = mTaskMap.find(taskId);
|
||||
if (it == mTaskMap.end())
|
||||
{
|
||||
TLLM_CHECK_WITH_INFO(mDoneTaskMap.count(taskId) > 0, "Task %ld not found", taskId);
|
||||
mDoneTaskMap.erase(taskId);
|
||||
return;
|
||||
}
|
||||
auto task = it->second;
|
||||
task->cv.wait(lk, [task] { return task->remaining == 0; });
|
||||
TLLM_CHECK_WITH_INFO(mDoneTaskMap.count(taskId) > 0, "Task %ld not found", taskId);
|
||||
mDoneTaskMap.erase(taskId);
|
||||
}
|
||||
|
||||
void MultiThreadWorker::stop()
|
||||
{
|
||||
{
|
||||
std::lock_guard<std::mutex> lk(mMutex);
|
||||
if (!mRunning)
|
||||
return;
|
||||
mRunning = false;
|
||||
}
|
||||
mCondition.notify_all();
|
||||
for (auto& t : mThreads)
|
||||
{
|
||||
if (t.joinable())
|
||||
t.join();
|
||||
}
|
||||
mThreads.clear();
|
||||
}
|
||||
|
||||
void MultiThreadWorker::workerLoop(int rank)
|
||||
{
|
||||
auto& topologyDetector = TopologyDetector::getInstance();
|
||||
topologyDetector.bindThreadByCurrentGpu(); // use relaxed mode
|
||||
while (true)
|
||||
{
|
||||
std::shared_ptr<Task> task;
|
||||
{
|
||||
std::unique_lock<std::mutex> lk(mMutex);
|
||||
|
||||
mCondition.wait(lk, [this] { return !mRunning || !mTasks.empty(); });
|
||||
|
||||
if (!mRunning && mTasks.empty())
|
||||
return;
|
||||
|
||||
task = mTasks.front();
|
||||
}
|
||||
|
||||
task->func(rank, mNumThreads);
|
||||
|
||||
{
|
||||
std::unique_lock<std::mutex> lk(mMutex);
|
||||
if (--task->remaining == 0)
|
||||
{
|
||||
mTasks.pop_front();
|
||||
mTaskMap.erase(task->id);
|
||||
mDoneTaskMap[task->id] = task;
|
||||
task->cv.notify_all();
|
||||
}
|
||||
else
|
||||
{
|
||||
task->cv.wait(lk, [task] { return task->remaining == 0; });
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace tensorrt_llm::runtime
|
||||
@ -16,6 +16,7 @@
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <atomic>
|
||||
#include <condition_variable>
|
||||
#include <cstdint>
|
||||
#include <functional>
|
||||
@ -81,7 +82,7 @@ public:
|
||||
void addSingleWeightSlot(int localSlotId, std::string const& name, MoeWeight weightSlot);
|
||||
virtual void addSingleHostWeight(int expertId, std::string const& name, MoeWeight hostWeight) = 0;
|
||||
virtual void finalizeWeights();
|
||||
virtual void updateWeights(MoePlacementCpuInfo const* placementCpuInfo) = 0;
|
||||
virtual void updateWeights(MoePlacementCpuInfo const* placementCpuInfo, int rank = 0, int size = 1) = 0;
|
||||
|
||||
protected:
|
||||
void finalizeWeightSlot();
|
||||
@ -102,10 +103,11 @@ public:
|
||||
void addSingleHostWeight(int expertId, std::string const& name, MoeWeight hostWeight) override;
|
||||
void finalizeWeights() override;
|
||||
|
||||
void updateWeights(MoePlacementCpuInfo const* placementCpuInfo) override;
|
||||
void updateWeights(MoePlacementCpuInfo const* placementCpuInfo, int rank = 0, int size = 1) override;
|
||||
|
||||
private:
|
||||
static void copyWeights(MoeWeight const& src, MoeWeight const& dst, cudaStream_t stream);
|
||||
static void copyWeightsCpu(MoeWeight const& src, MoeWeight const& dst, int rank, int size);
|
||||
void finalizeHostWeight();
|
||||
bool mHostWeightsFinalized = false;
|
||||
std::map<std::string, std::vector<MoeWeight>> mHostWeights;
|
||||
@ -166,6 +168,7 @@ public:
|
||||
|
||||
private:
|
||||
friend class MoeLoadBalancer;
|
||||
friend class HostMemoryMoeWeightUpdater;
|
||||
|
||||
void createResources();
|
||||
void destroyResources();
|
||||
@ -187,7 +190,11 @@ private:
|
||||
bool mUpdateWeightsEnabled = true;
|
||||
|
||||
void copyPlacementInfoToGpu();
|
||||
void copyPlacementInfoToGpuByCpu();
|
||||
void updateWeightsRoutine();
|
||||
void updateWeightsRoutineByCpu();
|
||||
|
||||
int64_t mLastUpdateTaskId = -1;
|
||||
|
||||
cudaEvent_t mUpdateWeightsDoneEvent = nullptr;
|
||||
tensorrt_llm::kernels::MoeLoadBalanceMetaInfo mMetaInfo;
|
||||
@ -203,6 +210,42 @@ private:
|
||||
int mLayerId = -1;
|
||||
};
|
||||
|
||||
class MultiThreadWorker
|
||||
{
|
||||
public:
|
||||
explicit MultiThreadWorker(int numThreads);
|
||||
~MultiThreadWorker();
|
||||
|
||||
void start();
|
||||
int64_t addTask(std::function<void(int, int)> func);
|
||||
void waitTaskDone(int64_t taskId);
|
||||
void stop();
|
||||
|
||||
private:
|
||||
struct Task
|
||||
{
|
||||
int64_t id;
|
||||
std::function<void(int, int)> func;
|
||||
int remaining;
|
||||
std::condition_variable cv;
|
||||
};
|
||||
|
||||
void workerLoop(int rank);
|
||||
|
||||
int mNumThreads;
|
||||
std::vector<std::thread> mThreads;
|
||||
std::mutex mMutex;
|
||||
std::condition_variable mCondition;
|
||||
|
||||
std::deque<std::shared_ptr<Task>> mTasks;
|
||||
|
||||
std::unordered_map<int64_t, std::shared_ptr<Task>> mTaskMap;
|
||||
std::unordered_map<int64_t, std::shared_ptr<Task>> mDoneTaskMap;
|
||||
|
||||
bool mRunning;
|
||||
int64_t mNextTaskId;
|
||||
};
|
||||
|
||||
class MoeLoadBalancer
|
||||
{
|
||||
public:
|
||||
@ -227,8 +270,15 @@ public:
|
||||
// should bind to python
|
||||
void shutdown();
|
||||
|
||||
// Test-only interface: enable GPU memcpy for weight updates to exercise that path in tests
|
||||
void setUseGpuMemcpy(bool useGpuMemcpy = false)
|
||||
{
|
||||
mUseGpuMemcpy = useGpuMemcpy;
|
||||
}
|
||||
|
||||
private:
|
||||
friend class SingleLayerMoeLoadBalancer;
|
||||
friend class HostMemoryMoeWeightUpdater;
|
||||
|
||||
void startThreads();
|
||||
|
||||
@ -247,6 +297,8 @@ private:
|
||||
std::condition_variable mUpdateQueueCondition;
|
||||
std::queue<std::function<void()>> mUpdateTaskQueue;
|
||||
void addUpdateTask(std::function<void()> task);
|
||||
int64_t addCopyTask(std::function<void(int, int)> task);
|
||||
void waitCopyTaskDone(int64_t taskId);
|
||||
|
||||
std::vector<std::shared_ptr<SingleLayerMoeLoadBalancer>> mLayers;
|
||||
|
||||
@ -272,10 +324,14 @@ private:
|
||||
std::unique_ptr<std::thread> mWorkerThread;
|
||||
std::unique_ptr<std::thread> mComputeAndUpdateThread;
|
||||
|
||||
std::unique_ptr<MultiThreadWorker> mMultiThreadWorker;
|
||||
|
||||
// update plan member and function
|
||||
int mLayerUpdatesPerIter = 1;
|
||||
std::deque<std::set<int>> mUpdateLayerQueue;
|
||||
void generateUpdatePlan();
|
||||
|
||||
bool mUseGpuMemcpy = false;
|
||||
};
|
||||
|
||||
// functions exposed for testing
|
||||
cpp/tensorrt_llm/runtime/moeLoadBalancer/topologyDetector.cpp (new file, 446 lines)
@ -0,0 +1,446 @@
|
||||
/*
|
||||
* Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "tensorrt_llm/runtime/moeLoadBalancer/topologyDetector.h"
|
||||
|
||||
#include "tensorrt_llm/common/assert.h"
|
||||
#include "tensorrt_llm/common/cudaUtils.h"
|
||||
|
||||
#include <algorithm> // For std::for_each, std::sort, std::unique
|
||||
#include <filesystem>
|
||||
#include <fstream>
|
||||
#include <limits> // For std::numeric_limits
|
||||
#include <map>
|
||||
#include <mutex>
|
||||
#include <set>
|
||||
#include <sstream>
|
||||
#include <thread>
|
||||
#include <vector>
|
||||
|
||||
#ifdef __linux__
|
||||
#include <cerrno> // For errno
|
||||
#include <cstring> // For strerror
|
||||
#include <numa.h> // For libnuma
|
||||
#include <numaif.h> // For struct bitmask definition if not in numa.h
|
||||
#include <pthread.h>
|
||||
#include <sched.h>
|
||||
#endif
|
||||
|
||||
namespace tensorrt_llm::runtime
|
||||
{
|
||||
|
||||
TopologyDetector::TopologyDetector()
|
||||
{
|
||||
std::lock_guard<std::mutex> lock(mDetectionMutex);
|
||||
if (!mTopologyDetected)
|
||||
{
|
||||
detectCpuTopology();
|
||||
detectGpuTopology();
|
||||
#ifdef __linux__
|
||||
if (numa_available() != -1)
|
||||
{ // Only precompute if libnuma is usable
|
||||
precomputeCpuAffinityMasks();
|
||||
}
|
||||
#endif
|
||||
mTopologyDetected = true;
|
||||
}
|
||||
}
|
||||
|
||||
TopologyDetector::~TopologyDetector()
|
||||
{
|
||||
#ifdef __linux__
|
||||
auto free_mask_map = [](std::map<int, struct bitmask*>& mask_map)
|
||||
{
|
||||
for (auto const& [id, mask] : mask_map)
|
||||
{
|
||||
if (mask)
|
||||
{
|
||||
numa_free_cpumask(mask);
|
||||
}
|
||||
}
|
||||
mask_map.clear();
|
||||
};
|
||||
free_mask_map(mGpuStrictCpuMasks);
|
||||
#endif
|
||||
}
|
||||
|
||||
void TopologyDetector::detectCpuTopology()
|
||||
{
|
||||
// Detect CPU architecture
|
||||
#if defined(__x86_64__) || defined(_M_X64)
|
||||
mCpuArchitecture = "x86_64";
|
||||
#elif defined(__aarch64__) || defined(_M_ARM64)
|
||||
mCpuArchitecture = "aarch64";
|
||||
#elif defined(__powerpc64__)
|
||||
mCpuArchitecture = "ppc64";
|
||||
#else
|
||||
mCpuArchitecture = "unknown";
|
||||
#endif
|
||||
|
||||
// Detect NUMA topology on Linux systems using libnuma
|
||||
#ifdef __linux__
|
||||
if (numa_available() == -1)
|
||||
{
|
||||
// libnuma not available, fall back to default behavior
|
||||
TLLM_LOG_WARNING("libnuma not available. Falling back to default CPU topology detection.");
|
||||
mNumaToCpuCountMap[0] = std::thread::hardware_concurrency();
|
||||
return;
|
||||
}
|
||||
|
||||
int maxNode = numa_max_node();
|
||||
if (maxNode < 0)
|
||||
{
|
||||
// Failed to get max node, fall back to default behavior
|
||||
TLLM_LOG_WARNING("Failed to get max NUMA node. Falling back to default CPU topology detection.");
|
||||
mNumaToCpuCountMap[0] = std::thread::hardware_concurrency();
|
||||
return;
|
||||
}
|
||||
|
||||
mNumaToCpuCountMap.clear(); // Clear before re-populating
|
||||
std::map<int, int> tempNumaToCpuCountMap;
|
||||
for (int i = 0; i <= maxNode; ++i)
|
||||
{
|
||||
struct bitmask* cpus = numa_allocate_cpumask();
|
||||
if (!cpus)
|
||||
{
|
||||
TLLM_LOG_WARNING("Failed to allocate cpumask for NUMA node query. Skipping node %d.", i);
|
||||
continue; // Skip to the next node if allocation fails
|
||||
}
|
||||
|
||||
// Attempt to get CPUs for node i. If numa_node_to_cpus returns 0, it's successful.
|
||||
if (numa_node_to_cpus(i, cpus) == 0)
|
||||
{
|
||||
int cpuCount = 0;
|
||||
for (int cpu_idx = 0; cpu_idx < numa_num_possible_cpus(); ++cpu_idx)
|
||||
{
|
||||
if (numa_bitmask_isbitset(cpus, cpu_idx))
|
||||
{
|
||||
cpuCount++;
|
||||
}
|
||||
}
|
||||
if (cpuCount > 0)
|
||||
{ // Only add NUMA nodes with actual CPUs
|
||||
tempNumaToCpuCountMap[i] = cpuCount;
|
||||
}
|
||||
}
|
||||
// If numa_node_to_cpus failed (returned -1), node 'i' might be invalid or an error occurred.
|
||||
// In this case, we simply don't add it to our map, effectively skipping it.
|
||||
|
||||
numa_free_cpumask(cpus); // Always free the allocated mask
|
||||
}
|
||||
mNumaToCpuCountMap = tempNumaToCpuCountMap;
|
||||
|
||||
if (mNumaToCpuCountMap.empty())
|
||||
{
|
||||
// If no NUMA nodes with CPUs were detected (e.g. libnuma error or unusual configuration),
|
||||
// default to a single NUMA node with all hardware concurrency.
|
||||
TLLM_LOG_WARNING(
|
||||
"No NUMA nodes with CPUs detected via libnuma, or libnuma error. Defaulting to single NUMA node.");
|
||||
mNumaToCpuCountMap[0] = std::thread::hardware_concurrency();
|
||||
}
|
||||
|
||||
#else
|
||||
// For non-Linux systems, assume a single NUMA node
|
||||
mNumaToCpuCountMap[0] = std::thread::hardware_concurrency();
|
||||
#endif
|
||||
}
|
||||
|
||||
void TopologyDetector::detectGpuTopology()
|
||||
{
|
||||
int deviceCount = 0;
|
||||
cudaError_t result = cudaGetDeviceCount(&deviceCount);
|
||||
if (result != cudaSuccess || deviceCount == 0)
|
||||
{
|
||||
return;
|
||||
}
|
||||
mGpuToNumaMap.clear(); // Clear before re-populating
|
||||
mNumaToGpuMap.clear(); // Clear before re-populating
|
||||
|
||||
for (int deviceId = 0; deviceId < deviceCount; ++deviceId)
|
||||
{
|
||||
int numaNode = 0; // Default NUMA node
|
||||
|
||||
#ifdef __linux__
|
||||
if (numa_available() != -1)
|
||||
{
|
||||
char pciPath[256];
|
||||
cudaDeviceProp prop;
|
||||
if (cudaGetDeviceProperties(&prop, deviceId) == cudaSuccess)
|
||||
{
|
||||
// Construct PCI path to find NUMA node
|
||||
snprintf(pciPath, sizeof(pciPath), "/sys/bus/pci/devices/%04x:%02x:%02x.0/numa_node", prop.pciDomainID,
|
||||
prop.pciBusID, prop.pciDeviceID);
|
||||
std::ifstream numaFile(pciPath);
|
||||
if (numaFile.is_open())
|
||||
{
|
||||
numaFile >> numaNode;
|
||||
numaFile.close();
|
||||
// If NUMA node is -1, it means no specific NUMA information, use node 0
|
||||
if (numaNode < 0)
|
||||
{
|
||||
numaNode = 0;
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
// Fallback if sysfs path is not available or readable
|
||||
TLLM_LOG_DEBUG("Could not open %s to determine NUMA node for GPU %d. Defaulting to node 0.",
|
||||
pciPath, deviceId);
|
||||
numaNode = 0;
|
||||
}
|
||||
TLLM_LOG_INFO("GPU %d is on NUMA node %d", deviceId, numaNode);
|
||||
}
|
||||
else
|
||||
{
|
||||
TLLM_LOG_WARNING("Failed to get properties for GPU %d. Defaulting to NUMA node 0.", deviceId);
|
||||
numaNode = 0;
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
// libnuma not available, default GPU to NUMA node 0
|
||||
numaNode = 0;
|
||||
}
|
||||
#endif
|
||||
|
||||
mGpuToNumaMap[deviceId] = numaNode;
|
||||
mNumaToGpuMap[numaNode].push_back(deviceId);
|
||||
}
|
||||
}
|
||||
|
||||
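
detectGpuTopology() above resolves each GPU's NUMA node from the sysfs entry for its PCI address. An equivalent lookup sketched in Python (the PCI address is a placeholder, not taken from the diff):

# Illustration of the sysfs lookup used above; the PCI address is hypothetical.
pci_addr = "0000:17:00.0"
path = f"/sys/bus/pci/devices/{pci_addr}/numa_node"
try:
    with open(path) as f:
        numa_node = int(f.read().strip())
except OSError:
    numa_node = 0      # fallback used by the detector when the file is unreadable
if numa_node < 0:      # -1 means "no NUMA affinity reported"
    numa_node = 0
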
#ifdef __linux__
|
||||
|
||||
static void bitmask_copy_manual(struct bitmask* dst, const struct bitmask* src)
|
||||
{
|
||||
if (!dst || !src)
|
||||
return;
|
||||
numa_bitmask_clearall(dst);
|
||||
for (int i = 0; i < numa_num_possible_cpus(); ++i)
|
||||
{
|
||||
if (numa_bitmask_isbitset(src, i))
|
||||
{
|
||||
numa_bitmask_setbit(dst, i);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
static void bitmask_or_manual(struct bitmask* dst, const struct bitmask* src)
|
||||
{
|
||||
if (!dst || !src)
|
||||
return;
|
||||
for (int i = 0; i < numa_num_possible_cpus(); ++i)
|
||||
{
|
||||
if (numa_bitmask_isbitset(src, i))
|
||||
{
|
||||
numa_bitmask_setbit(dst, i);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void TopologyDetector::precomputeCpuAffinityMasks()
|
||||
{
|
||||
int num_gpus = 0;
|
||||
cudaError_t err = cudaGetDeviceCount(&num_gpus);
|
||||
if (err != cudaSuccess || num_gpus == 0)
|
||||
{
|
||||
return;
|
||||
}
|
||||
|
||||
for (int gpuId = 0; gpuId < num_gpus; ++gpuId)
|
||||
{
|
||||
auto itGpuNuma = mGpuToNumaMap.find(gpuId);
|
||||
if (itGpuNuma == mGpuToNumaMap.end())
|
||||
{
|
||||
TLLM_LOG_WARNING("GPU %d not found in mGpuToNumaMap during mask precomputation. Skipping.", gpuId);
|
||||
continue;
|
||||
}
|
||||
int gpuNumaNode = itGpuNuma->second;
|
||||
|
||||
// Strict Mask: CPUs on the GPU's direct NUMA node
|
||||
struct bitmask* strictMask = numa_allocate_cpumask(); // Uses numa_bitmask_alloc internally
|
||||
if (strictMask)
|
||||
{
|
||||
numa_bitmask_clearall(strictMask); // Initialize to empty
|
||||
if (mNumaToCpuCountMap.count(gpuNumaNode) && mNumaToCpuCountMap.at(gpuNumaNode) > 0)
|
||||
{
|
||||
if (numa_node_to_cpus(gpuNumaNode, strictMask) != 0)
|
||||
{
|
||||
TLLM_LOG_WARNING(
|
||||
"Failed to get CPUs for GPU %d's NUMA node %d for strict mask. Strict mask will be empty.",
|
||||
gpuId, gpuNumaNode);
|
||||
numa_bitmask_clearall(strictMask); // Ensure it's empty on failure
|
||||
}
|
||||
}
|
||||
mGpuStrictCpuMasks[gpuId] = strictMask;
|
||||
}
|
||||
else
|
||||
{
|
||||
TLLM_LOG_WARNING("Failed to allocate strict CPU mask for GPU %d.", gpuId);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
const struct bitmask* TopologyDetector::getStrictCpuMaskForGpu(int gpuId) const
|
||||
{
|
||||
auto it = mGpuStrictCpuMasks.find(gpuId);
|
||||
if (it != mGpuStrictCpuMasks.end())
|
||||
{
|
||||
return it->second;
|
||||
}
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
#endif
|
||||
|
||||
void TopologyDetector::bindThreadByCurrentGpu()
|
||||
{
|
||||
#ifdef __linux__
|
||||
if (numa_available() == -1)
|
||||
{
|
||||
TLLM_LOG_WARNING("libnuma not available. Cannot bind thread to NUMA node.");
|
||||
return;
|
||||
}
|
||||
|
||||
int currentDevice = -1;
|
||||
if (cudaGetDevice(¤tDevice) != cudaSuccess)
|
||||
{
|
||||
TLLM_LOG_WARNING("Failed to get current CUDA device. Cannot bind thread.");
|
||||
return;
|
||||
}
|
||||
|
||||
const struct bitmask* targetMask = nullptr;
|
||||
targetMask = getStrictCpuMaskForGpu(currentDevice);
|
||||
|
||||
if (targetMask)
|
||||
{
|
||||
// Check if the mask is not all clear before attempting to set affinity
|
||||
bool maskIsClear = true;
|
||||
for (int k = 0; k < numa_num_possible_cpus(); ++k)
|
||||
{
|
||||
if (numa_bitmask_isbitset(targetMask, k))
|
||||
{
|
||||
maskIsClear = false;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
if (!maskIsClear)
|
||||
{
|
||||
// Create a mutable copy of the targetMask to pass to numa_sched_setaffinity
|
||||
struct bitmask* mutableCopyForAffinity = numa_allocate_cpumask();
|
||||
if (mutableCopyForAffinity)
|
||||
{
|
||||
bitmask_copy_manual(mutableCopyForAffinity, targetMask);
|
||||
if (numa_sched_setaffinity(0, mutableCopyForAffinity) == -1)
|
||||
{ // 0 refers to the current thread
|
||||
TLLM_LOG_WARNING("Failed to set thread affinity for GPU %d using precomputed mask. Error: %s",
|
||||
currentDevice, strerror(errno));
|
||||
}
|
||||
numa_free_cpumask(mutableCopyForAffinity);
|
||||
}
|
||||
else
|
||||
{
|
||||
TLLM_LOG_WARNING(
|
||||
"Failed to allocate temporary bitmask for setting affinity. Cannot bind thread for GPU %d.",
|
||||
currentDevice);
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
TLLM_LOG_DEBUG("Target affinity mask for GPU %d is empty. Not setting affinity.", currentDevice);
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
TLLM_LOG_WARNING("Precomputed CPU affinity mask not found for GPU %d. Cannot bind thread.", currentDevice);
|
||||
}
|
||||
|
||||
#else
|
||||
TLLM_LOG_DEBUG("Thread binding by GPU NUMA node is only supported on Linux with libnuma.");
|
||||
#endif
|
||||
}
|
||||
|
||||
int TopologyDetector::getCurrentGpuNumaCpuCount()
|
||||
{
|
||||
int numaId = getCurrentGpuNumaId();
|
||||
if (numaId >= 0)
|
||||
{
|
||||
auto it = mNumaToCpuCountMap.find(numaId);
|
||||
if (it != mNumaToCpuCountMap.end())
|
||||
{
|
||||
return it->second;
|
||||
}
|
||||
}
|
||||
TLLM_LOG_DEBUG(
|
||||
"CPU count for GPU's NUMA node %d not found or node invalid. Returning total hardware concurrency.", numaId);
|
||||
return std::thread::hardware_concurrency();
|
||||
}
|
||||
|
||||
int TopologyDetector::getCurrentGpuNumaId()
|
||||
{
|
||||
int currentDevice = -1;
|
||||
if (cudaGetDevice(¤tDevice) != cudaSuccess)
|
||||
{
|
||||
return -1; // Indicate error or no CUDA device context
|
||||
}
|
||||
|
||||
auto it = mGpuToNumaMap.find(currentDevice);
|
||||
if (it != mGpuToNumaMap.end())
|
||||
{
|
||||
return it->second;
|
||||
}
|
||||
TLLM_LOG_WARNING("NUMA node for current GPU %d not found in map. Defaulting to node 0.", currentDevice);
|
||||
return 0;
|
||||
}
|
||||
|
||||
int TopologyDetector::getGpuCountUnderNuma(int numaId)
|
||||
{
|
||||
auto it = mNumaToGpuMap.find(numaId);
|
||||
if (it != mNumaToGpuMap.end())
|
||||
{
|
||||
return it->second.size();
|
||||
}
|
||||
return 0;
|
||||
}
|
||||
|
||||
std::string TopologyDetector::getCpuArchitecture()
|
||||
{
|
||||
return mCpuArchitecture;
|
||||
}
|
||||
|
||||
bool TopologyDetector::canSupportHostNativeAtomics()
|
||||
{
|
||||
int currentDevice = -1;
|
||||
if (cudaGetDevice(¤tDevice) != cudaSuccess)
|
||||
{
|
||||
TLLM_LOG_WARNING("Failed to get current CUDA device for atomic support check.");
|
||||
return false;
|
||||
}
|
||||
|
||||
int hostNativeAtomicSupported = 0;
|
||||
cudaError_t err
|
||||
= cudaDeviceGetAttribute(&hostNativeAtomicSupported, cudaDevAttrHostNativeAtomicSupported, currentDevice);
|
||||
|
||||
if (err != cudaSuccess)
|
||||
{
|
||||
TLLM_LOG_WARNING("Failed to get cudaDevAttrHostNativeAtomicSupported for device %d. Error: %s", currentDevice,
|
||||
cudaGetErrorString(err));
|
||||
return false;
|
||||
}
|
||||
return static_cast<bool>(hostNativeAtomicSupported);
|
||||
}
|
||||
|
||||
} // namespace tensorrt_llm::runtime
|
||||
cpp/tensorrt_llm/runtime/moeLoadBalancer/topologyDetector.h (new file, 100 lines)
@ -0,0 +1,100 @@
|
||||
/*
|
||||
* Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <map>
|
||||
#include <mutex>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#ifdef __linux__
|
||||
#include <numa.h> // For libnuma
|
||||
#endif
|
||||
|
||||
// Forward declaration for struct bitmask to avoid including numaif.h if numa.h already covers it,
|
||||
// or if only numa.h is intended to be the public include for this header's users.
|
||||
#ifdef __linux__
|
||||
struct bitmask;
|
||||
#endif
|
||||
|
||||
namespace tensorrt_llm::runtime
|
||||
{
|
||||
|
||||
class TopologyDetector
|
||||
{
|
||||
public:
|
||||
static TopologyDetector& getInstance()
|
||||
{
|
||||
static TopologyDetector instance;
|
||||
return instance;
|
||||
}
|
||||
|
||||
~TopologyDetector();
|
||||
|
||||
// Binds the current thread to the CPU cores of the NUMA node associated with the current GPU.
|
||||
void bindThreadByCurrentGpu();
|
||||
|
||||
// Returns the number of CPU cores on the NUMA node associated with the current GPU.
|
||||
// Returns total hardware concurrency as a fallback if specific count cannot be determined.
|
||||
int getCurrentGpuNumaCpuCount();
|
||||
|
||||
// Returns the ID of the NUMA node associated with the current GPU.
|
||||
// Returns 0 as a default or -1 on error.
|
||||
int getCurrentGpuNumaId();
|
||||
|
||||
// Returns the number of GPUs associated with the given NUMA node ID.
|
||||
int getGpuCountUnderNuma(int numaId);
|
||||
|
||||
// Returns the number of GPUs that share the same NUMA node as the current GPU.
|
||||
int getGpuCountUnderSameNuma()
|
||||
{
|
||||
return getGpuCountUnderNuma(getCurrentGpuNumaId());
|
||||
}
|
||||
|
||||
// Returns the detected CPU architecture (e.g., "x86_64", "aarch64").
|
||||
std::string getCpuArchitecture();
|
||||
|
||||
// Checks if the current CUDA device and host system support native atomic operations.
|
||||
bool canSupportHostNativeAtomics();
|
||||
|
||||
#ifdef __linux__
|
||||
// Getters for precomputed CPU affinity masks
|
||||
const struct bitmask* getStrictCpuMaskForGpu(int gpuId) const;
|
||||
#endif
|
||||
|
||||
private:
|
||||
TopologyDetector();
|
||||
void detectCpuTopology(); // Detects CPU NUMA topology and CPU counts per node.
|
||||
void detectGpuTopology(); // Detects GPU to NUMA node mapping.
|
||||
#ifdef __linux__
|
||||
void precomputeCpuAffinityMasks(); // Precomputes CPU masks for each GPU
|
||||
#endif
|
||||
|
||||
// Member variables
|
||||
std::map<int, int> mGpuToNumaMap; // GPU ID -> NUMA Node ID
|
||||
std::map<int, std::vector<int>> mNumaToGpuMap; // NUMA Node ID -> List of GPU IDs
|
||||
std::map<int, int> mNumaToCpuCountMap; // NUMA Node ID -> CPU Core Count
|
||||
std::string mCpuArchitecture;
|
||||
bool mTopologyDetected = false;
|
||||
std::mutex mDetectionMutex; // Mutex to protect topology detection process
|
||||
|
||||
#ifdef __linux__
|
||||
// Precomputed CPU affinity masks
|
||||
std::map<int, struct bitmask*> mGpuStrictCpuMasks; // GPU ID -> Strict CPU mask
|
||||
#endif
|
||||
};
|
||||
|
||||
} // namespace tensorrt_llm::runtime
|
||||
@ -20,12 +20,14 @@
|
||||
#include "tensorrt_llm/runtime/torchUtils.h"
|
||||
#include "tensorrt_llm/thop/thUtils.h"
|
||||
|
||||
#include <c10/core/Allocator.h> // for c10::DataPtr
|
||||
#include <c10/core/StorageImpl.h> // for c10::StorageImpl and use_byte_size_t()
|
||||
#include <c10/cuda/CUDAStream.h>
|
||||
#include <torch/extension.h>
|
||||
#include <c10/util/intrusive_ptr.h> // for c10::make_intrusive
#include <torch/extension.h>
|
||||
#include <vector>
|
||||
|
||||
#include "tensorrt_llm/kernels/moeLoadBalance/moeLoadBalanceKernels.h"
|
||||
#include "tensorrt_llm/runtime/moeLoadBalancer.h"
|
||||
#include "tensorrt_llm/runtime/moeLoadBalancer/moeLoadBalancer.h"
|
||||
|
||||
namespace torch_ext
|
||||
{
|
||||
@ -109,6 +111,43 @@ torch::Tensor moeLoadBalanceRouting(
|
||||
return tokenRoutedSlotIds;
|
||||
}
|
||||
|
||||
void migrateToManaged(at::Tensor& tensor)
|
||||
{
|
||||
TORCH_CHECK(tensor.device().is_cuda(), "only support CUDA Tensor");
|
||||
|
||||
// 1) compute total bytes
|
||||
size_t byte_size = tensor.numel() * tensor.element_size();
|
||||
|
||||
// 2) allocate UVM
|
||||
void* managed_ptr = nullptr;
|
||||
cudaError_t err = cudaMallocManaged(&managed_ptr, byte_size);
|
||||
TORCH_CHECK(err == cudaSuccess, "cudaMallocManaged failed");
|
||||
|
||||
// 3) advise to place on current GPU
|
||||
int cur_dev;
|
||||
TLLM_CUDA_CHECK(cudaGetDevice(&cur_dev));
|
||||
TLLM_CUDA_CHECK(cudaMemAdvise(managed_ptr, byte_size, cudaMemAdviseSetPreferredLocation, cur_dev));
|
||||
TLLM_CUDA_CHECK(cudaMemAdvise(managed_ptr, byte_size, cudaMemAdviseSetAccessedBy, cur_dev));
|
||||
TLLM_CUDA_CHECK(cudaMemAdvise(managed_ptr, byte_size, cudaMemAdviseSetAccessedBy, cudaCpuDeviceId));
|
||||
|
||||
// 4) copy old data to UVM
|
||||
TLLM_CUDA_CHECK(cudaMemcpy(managed_ptr, tensor.data_ptr(), byte_size, cudaMemcpyDeviceToDevice));
|
||||
|
||||
// 5) use new DataPtr/StorageImpl to construct storage
|
||||
// here managed_ptr is both the data pointer and the context; use cudaFree as the deleter
|
||||
c10::DataPtr dp(
|
||||
managed_ptr, managed_ptr, [](void* ptr) { cudaFree(ptr); }, tensor.device());
|
||||
auto allocator = c10::GetAllocator(tensor.device().type());
|
||||
auto storage_impl = c10::make_intrusive<c10::StorageImpl>(c10::StorageImpl::use_byte_size_t(), byte_size,
|
||||
std::move(dp), allocator,
|
||||
/*resizable=*/false);
|
||||
at::Storage new_storage(storage_impl);
|
||||
|
||||
// Finally, replace the tensor's storage; offset = 0, shape and strides kept unchanged
|
||||
tensor.set_(new_storage,
|
||||
/*storage_offset=*/0, tensor.sizes().vec(), tensor.strides().vec());
|
||||
}
|
||||
|
||||
} // namespace torch_ext
|
||||
|
||||
TORCH_LIBRARY_FRAGMENT(trtllm, m)
|
||||
@ -154,3 +193,13 @@ TORCH_LIBRARY_IMPL(trtllm, CUDA, m)
|
||||
{
|
||||
m.impl("moe_load_balance_routing", &torch_ext::moeLoadBalanceRouting);
|
||||
}
|
||||
|
||||
TORCH_LIBRARY_FRAGMENT(trtllm, m)
|
||||
{
|
||||
m.def("migrate_to_managed(Tensor tensor) -> ()");
|
||||
}
|
||||
|
||||
TORCH_LIBRARY_IMPL(trtllm, CUDA, m)
|
||||
{
|
||||
m.impl("migrate_to_managed", &torch_ext::migrateToManaged);
|
||||
}
|
||||
|
||||
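
Once the extension library containing the fragments above is loaded, the op should be callable through the torch.ops namespace; a hedged usage sketch (the tensor shape is arbitrary):

# Assumes the TensorRT-LLM torch extension has been loaded so the trtllm ops are registered.
import torch

t = torch.empty(1024, 1024, device="cuda")
torch.ops.trtllm.migrate_to_managed(t)   # re-backs the storage with cudaMallocManaged memory in place
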
@ -18,7 +18,7 @@
|
||||
|
||||
#include "tensorrt_llm/common/cudaUtils.h"
|
||||
#include "tensorrt_llm/kernels/moeLoadBalance/moeLoadBalanceKernels.h"
|
||||
#include "tensorrt_llm/runtime/moeLoadBalancer.h"
|
||||
#include "tensorrt_llm/runtime/moeLoadBalancer/moeLoadBalancer.h"
|
||||
|
||||
using namespace tensorrt_llm::runtime;
|
||||
|
||||
@ -318,6 +318,8 @@ protected:
|
||||
TLLM_CUDA_CHECK(cudaSetDevice(0));
|
||||
mLoadBalancer = std::make_unique<MoeLoadBalancer>(param.epRank, param.epSize, param.layerUpdatesPerIter);
|
||||
|
||||
mLoadBalancer->setUseGpuMemcpy(true);
|
||||
|
||||
// Create multiple MoE layers
|
||||
createLayers(param);
|
||||
|
||||
|
||||
@ -53,6 +53,8 @@ init_ubuntu() {
|
||||
llvm \
|
||||
libclang-rt-dev \
|
||||
libffi-dev \
|
||||
libnuma1 \
|
||||
libnuma-dev \
|
||||
python3-dev \
|
||||
python3-pip \
|
||||
python-is-python3 \
|
||||
@ -88,6 +90,8 @@ install_python_rockylinux() {
|
||||
llvm-toolset \
|
||||
lld \
|
||||
libffi-devel \
|
||||
numactl \
|
||||
numactl-devel \
|
||||
zlib-devel \
|
||||
xz-devel \
|
||||
sqlite-devel \
|
||||
|
||||
@ -28,10 +28,10 @@ UPLOAD_PATH = env.uploadPath ? env.uploadPath : "sw-tensorrt-generic/llm-artifac
|
||||
// Container configuration
|
||||
// available tags can be found in: https://urm.nvidia.com/artifactory/sw-tensorrt-docker/tensorrt-llm/
|
||||
// [base_image_name]-[arch]-[os](-[python_version])-[trt_version]-[torch_install_type]-[stage]-[date]-[mr_id]
|
||||
LLM_DOCKER_IMAGE = "urm.nvidia.com/sw-tensorrt-docker/tensorrt-llm:pytorch-25.04-py3-x86_64-ubuntu24.04-trt10.10.0.31-skip-tritondevel-202505211401-4539"
|
||||
LLM_SBSA_DOCKER_IMAGE = "urm.nvidia.com/sw-tensorrt-docker/tensorrt-llm:pytorch-25.04-py3-aarch64-ubuntu24.04-trt10.10.0.31-skip-tritondevel-202505211401-4539"
|
||||
LLM_ROCKYLINUX8_PY310_DOCKER_IMAGE = "urm.nvidia.com/sw-tensorrt-docker/tensorrt-llm:cuda-12.9.0-devel-rocky8-x86_64-rocky8-py310-trt10.10.0.31-skip-tritondevel-202505211401-4539"
|
||||
LLM_ROCKYLINUX8_PY312_DOCKER_IMAGE = "urm.nvidia.com/sw-tensorrt-docker/tensorrt-llm:cuda-12.9.0-devel-rocky8-x86_64-rocky8-py312-trt10.10.0.31-skip-tritondevel-202505211401-4539"
|
||||
LLM_DOCKER_IMAGE = "urm.nvidia.com/sw-tensorrt-docker/tensorrt-llm:pytorch-25.04-py3-x86_64-ubuntu24.04-trt10.10.0.31-skip-tritondevel-202506021004-9420"
|
||||
LLM_SBSA_DOCKER_IMAGE = "urm.nvidia.com/sw-tensorrt-docker/tensorrt-llm:pytorch-25.04-py3-aarch64-ubuntu24.04-trt10.10.0.31-skip-tritondevel-202506021004-9420"
|
||||
LLM_ROCKYLINUX8_PY310_DOCKER_IMAGE = "urm.nvidia.com/sw-tensorrt-docker/tensorrt-llm:cuda-12.9.0-devel-rocky8-x86_64-rocky8-py310-trt10.10.0.31-skip-tritondevel-202506021004-9420"
|
||||
LLM_ROCKYLINUX8_PY312_DOCKER_IMAGE = "urm.nvidia.com/sw-tensorrt-docker/tensorrt-llm:cuda-12.9.0-devel-rocky8-x86_64-rocky8-py312-trt10.10.0.31-skip-tritondevel-202506021004-9420"
|
||||
|
||||
// TODO: Move common variables to an unified location
|
||||
BUILD_CORES_REQUEST = "8"
|
||||
|
||||
@ -1,7 +1,7 @@
|
||||
|
||||
import java.lang.InterruptedException
|
||||
|
||||
DOCKER_IMAGE = "urm.nvidia.com/sw-tensorrt-docker/tensorrt-llm:pytorch-25.04-py3-x86_64-ubuntu24.04-trt10.10.0.31-skip-tritondevel-202505211401-4539"
|
||||
DOCKER_IMAGE = "urm.nvidia.com/sw-tensorrt-docker/tensorrt-llm:pytorch-25.04-py3-x86_64-ubuntu24.04-trt10.10.0.31-skip-tritondevel-202506021004-9420"
|
||||
|
||||
def createKubernetesPodConfig(image)
|
||||
{
|
||||
|
||||
@ -440,7 +440,7 @@ def main(*,
|
||||
with working_directory(build_dir):
|
||||
if clean or first_build or configure_cmake:
|
||||
build_run(
|
||||
f"\"{venv_conan}\" install --remote=tensorrt-llm --output-folder={build_dir}/conan -s 'build_type={build_type}' {source_dir}"
|
||||
f"\"{venv_conan}\" install --build=missing --remote=tensorrt-llm --output-folder={build_dir}/conan -s 'build_type={build_type}' {source_dir}"
|
||||
)
|
||||
cmake_def_args.append(
|
||||
f"-DCMAKE_TOOLCHAIN_FILE={build_dir}/conan/conan_toolchain.cmake"
|
||||
|
||||
@ -23,18 +23,13 @@ class MoeLoadBalancerConfig:
|
||||
repr=False)
|
||||
layer_updates_per_iter: int = 0
|
||||
|
||||
num_experts: Optional[int] = field(default=None, init=False)
|
||||
ep_rank: Optional[int] = field(default=None, init=False)
|
||||
ep_size: Optional[int] = field(default=None, init=False)
|
||||
|
||||
def setup(self, num_experts: int, ep_rank: int, ep_size: int) -> None:
|
||||
self.num_experts = num_experts
|
||||
def setup(self, ep_rank: int, ep_size: int) -> None:
|
||||
self.ep_rank = ep_rank
|
||||
self.ep_size = ep_size
|
||||
if self.num_slots is None:
|
||||
self.num_slots = self.num_experts
|
||||
assert self.num_slots >= self.num_experts
|
||||
assert self.num_slots % self.ep_size == 0
|
||||
assert self.num_slots is not None
|
||||
|
||||
@property
|
||||
def num_local_slots(self) -> int:
|
||||
@ -49,17 +44,13 @@ class MoeLoadBalancerConfig:
|
||||
return self.slot_start + self.num_local_slots
|
||||
|
||||
def get_layer_initial_global_assignments(self, layer_idx: int) -> List[int]:
|
||||
if self.initial_global_assignments is None:
|
||||
return [(ep_rank * self.num_experts // self.ep_size + i) %
|
||||
self.num_experts for ep_rank in range(self.ep_size)
|
||||
for i in range(self.num_local_slots)]
|
||||
else:
|
||||
if self.initial_global_assignments is not None:
|
||||
assert layer_idx in self.initial_global_assignments
|
||||
assert len(
|
||||
self.initial_global_assignments[layer_idx]) == self.num_slots
|
||||
assert set(self.initial_global_assignments[layer_idx]) == set(
|
||||
range(self.num_experts))
|
||||
return self.initial_global_assignments[layer_idx]
|
||||
else:
|
||||
return None
|
||||
|
||||
|
||||
@dataclass(kw_only=True)
|
||||
|
||||
@ -53,7 +53,7 @@ from ..modules.attention import MLA
|
||||
from ..modules.decoder_layer import DecoderLayer
|
||||
from ..modules.embedding import Embedding
|
||||
from ..modules.fused_moe import (CutlassFusedMoE, DeepSeekV3MoeRoutingMethod,
|
||||
MoeLoadBalancer, create_moe)
|
||||
create_moe)
|
||||
from ..modules.gated_mlp import GatedMLP
|
||||
from ..modules.linear import Linear
|
||||
from ..modules.multi_stream_utils import maybe_execute_in_parallel
|
||||
@ -344,7 +344,6 @@ class Deepseekv3MoE(nn.Module):
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
model_config: ModelConfig = ModelConfig(),
|
||||
override_quant_config: Optional[QuantConfig] = None,
|
||||
moe_load_balancer: Optional[MoeLoadBalancer] = None,
|
||||
layer_idx: Optional[int] = None):
|
||||
from ..distributed import AllReduce
|
||||
|
||||
@ -379,7 +378,6 @@ class Deepseekv3MoE(nn.Module):
|
||||
override_quant_config=override_quant_config,
|
||||
aux_stream=aux_stream_dict[AuxStreamType.MoeChunkingOverlap],
|
||||
enable_alltoall=self.enable_alltoall,
|
||||
moe_load_balancer=moe_load_balancer,
|
||||
layer_idx=layer_idx)
|
||||
|
||||
self.mapping = model_config.mapping
|
||||
@ -542,11 +540,9 @@ class Deepseekv3MoE(nn.Module):
|
||||
|
||||
class DeepseekV3DecoderLayer(DecoderLayer):
|
||||
|
||||
def __init__(self,
|
||||
model_config: ModelConfig[PretrainedConfig],
|
||||
layer_idx: int,
|
||||
aux_stream_dict: Dict[AuxStreamType, torch.cuda.Stream],
|
||||
moe_load_balancer: Optional[MoeLoadBalancer] = None):
|
||||
def __init__(self, model_config: ModelConfig[PretrainedConfig],
|
||||
layer_idx: int, aux_stream_dict: Dict[AuxStreamType,
|
||||
torch.cuda.Stream]):
|
||||
super().__init__()
|
||||
self.model_config = model_config
|
||||
config = model_config.pretrained_config
|
||||
@ -598,7 +594,6 @@ class DeepseekV3DecoderLayer(DecoderLayer):
|
||||
model_config=model_config,
|
||||
override_quant_config=quant_config,
|
||||
aux_stream_dict=aux_stream_dict,
|
||||
moe_load_balancer=moe_load_balancer,
|
||||
layer_idx=layer_idx)
|
||||
else:
|
||||
block_size = 1
|
||||
@ -865,13 +860,10 @@ class DeepseekV3DecoderLayer(DecoderLayer):
|
||||
|
||||
class DeepseekV3MTP(DeepseekV3DecoderLayer):
|
||||
|
||||
def __init__(self,
|
||||
model_config: ModelConfig[PretrainedConfig],
|
||||
layer_idx: int,
|
||||
aux_stream_dict: Dict[AuxStreamType, torch.cuda.Stream],
|
||||
moe_load_balancer: Optional[MoeLoadBalancer] = None):
|
||||
super().__init__(model_config, layer_idx, aux_stream_dict,
|
||||
moe_load_balancer)
|
||||
def __init__(self, model_config: ModelConfig[PretrainedConfig],
|
||||
layer_idx: int, aux_stream_dict: Dict[AuxStreamType,
|
||||
torch.cuda.Stream]):
|
||||
super().__init__(model_config, layer_idx, aux_stream_dict)
|
||||
config = model_config.pretrained_config
|
||||
self.hidden_dim = config.hidden_size
|
||||
self.moe_intermediate_size = config.moe_intermediate_size
|
||||
@ -992,23 +984,9 @@ class DeepseekV3Model(DecoderModel):
|
||||
dtype=config.torch_dtype,
|
||||
)
|
||||
|
||||
self.moe_load_balancer = None
|
||||
if model_config.moe_load_balancer is not None:
|
||||
num_experts = config.n_routed_experts
|
||||
ep_rank = model_config.mapping.moe_ep_rank
|
||||
ep_size = model_config.mapping.moe_ep_size
|
||||
model_config.moe_load_balancer.setup(num_experts=num_experts,
|
||||
ep_rank=ep_rank,
|
||||
ep_size=ep_size)
|
||||
self.moe_load_balancer = MoeLoadBalancer(
|
||||
ep_rank=ep_rank,
|
||||
ep_size=ep_size,
|
||||
layer_updates_per_iter=model_config.moe_load_balancer.
|
||||
layer_updates_per_iter)
|
||||
|
||||
self.layers = nn.ModuleList([
|
||||
DeepseekV3DecoderLayer(model_config, layer_idx,
|
||||
self.aux_stream_dict, self.moe_load_balancer)
|
||||
self.aux_stream_dict)
|
||||
for layer_idx in range(config.num_hidden_layers)
|
||||
])
|
||||
self.norm = RMSNorm(hidden_size=config.hidden_size,
|
||||
@ -1054,7 +1032,6 @@ class DeepseekV3ForCausalLM(DecoderModelForCausalLM[DeepseekV3Model,
|
||||
hidden_size=model_config.pretrained_config.hidden_size,
|
||||
vocab_size=model_config.pretrained_config.vocab_size)
|
||||
|
||||
self.moe_load_balancer = self.model.moe_load_balancer
|
||||
self.model_nextn = 0
|
||||
if model_config.spec_config is not None:
|
||||
model_nextn = model_config.spec_config.num_nextn_predict_layers
|
||||
@ -1063,8 +1040,7 @@ class DeepseekV3ForCausalLM(DecoderModelForCausalLM[DeepseekV3Model,
|
||||
assert ckpt_nextn > 0, "There are no MTP modules in the checkpoint."
|
||||
if ckpt_nextn == 1:
|
||||
mtp_layer = DeepseekV3MTP(model_config, self.num_hidden_layers,
|
||||
self.model.aux_stream_dict,
|
||||
self.moe_load_balancer)
|
||||
self.model.aux_stream_dict)
|
||||
self.model.layers.append(mtp_layer)
|
||||
self.epilogue.append(mtp_layer)
|
||||
self.mtp_worker = MTPEagleWorker(model_config.spec_config)
|
||||
@ -1074,8 +1050,7 @@ class DeepseekV3ForCausalLM(DecoderModelForCausalLM[DeepseekV3Model,
|
||||
mtp_layers = nn.ModuleList([
|
||||
DeepseekV3MTP(model_config,
|
||||
layer_idx + self.num_hidden_layers,
|
||||
self.model.aux_stream_dict,
|
||||
self.moe_load_balancer)
|
||||
self.model.aux_stream_dict)
|
||||
for layer_idx in range(model_nextn)
|
||||
])
|
||||
self.model.layers.extend(mtp_layers)
|
||||
@ -1100,9 +1075,6 @@ class DeepseekV3ForCausalLM(DecoderModelForCausalLM[DeepseekV3Model,
|
||||
extend_exclude_modules)
|
||||
self.epilogue.append(self.mtp_worker)
|
||||
|
||||
if self.moe_load_balancer is not None:
|
||||
self.moe_load_balancer.finalize_model()
|
||||
|
||||
def forward(
|
||||
self,
|
||||
attn_metadata: AttentionMetadata,
|
||||
|
||||
@ -10,7 +10,7 @@ from .fused_moe_cutlass import CutlassFusedMoE
|
||||
from .fused_moe_trtllm_gen import TRTLLMGenFusedMoE
|
||||
from .fused_moe_vanilla import VanillaMoE
|
||||
from .interface import MoE, MoEWeightLoadingMode
|
||||
from .moe_load_balancer import MoeLoadBalancer
|
||||
from .moe_load_balancer import get_moe_load_balancer
|
||||
from .routing import BaseMoeRoutingMethod
|
||||
|
||||
|
||||
@ -53,15 +53,17 @@ def create_moe(
|
||||
weight_loading_mode: MoEWeightLoadingMode = MoEWeightLoadingMode.VANILLA,
|
||||
apply_router_weight_on_input: bool = False,
|
||||
enable_alltoall: bool = False,
|
||||
moe_load_balancer: Optional[MoeLoadBalancer] = None,
|
||||
layer_idx: Optional[int] = None,
|
||||
) -> MoE:
|
||||
moe_cls = get_moe_cls(model_config, override_quant_config)
|
||||
|
||||
moe_load_balancer = get_moe_load_balancer()
|
||||
if moe_load_balancer is not None:
|
||||
assert moe_cls == CutlassFusedMoE, "MoE Load Balance is only supported in CutlassFusedMoE now."
|
||||
|
||||
if moe_cls == TRTLLMGenFusedMoE:
|
||||
assert not apply_router_weight_on_input, "apply_router_weight_on_input is not supported in TRTLLMGenFusedMoE."
|
||||
assert not enable_alltoall, "enable_alltoall is not supported in TRTLLMGenFusedMoE."
|
||||
assert moe_load_balancer is None, "moe_load_balancer is not supported in TRTLLMGenFusedMoE."
|
||||
|
||||
return moe_cls(
|
||||
routing_method=routing_method,
|
||||
@ -87,13 +89,11 @@ def create_moe(
|
||||
weight_loading_mode=weight_loading_mode,
|
||||
apply_router_weight_on_input=apply_router_weight_on_input,
|
||||
enable_alltoall=enable_alltoall,
|
||||
moe_load_balancer=moe_load_balancer,
|
||||
layer_idx=layer_idx,
|
||||
)
|
||||
elif moe_cls == VanillaMoE:
|
||||
assert not apply_router_weight_on_input, "apply_router_weight_on_input is not supported in VanillaMoE."
|
||||
assert not enable_alltoall, "enable_alltoall is not supported in VanillaMoE."
|
||||
assert moe_load_balancer is None, "moe_load_balancer is not supported in VanillaMoE."
|
||||
|
||||
return moe_cls(
|
||||
routing_method=routing_method,
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
import os
|
||||
from typing import Dict, List, Optional, Union
|
||||
from typing import Dict, List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
|
||||
@ -8,11 +8,11 @@ from tensorrt_llm._utils import logger
|
||||
|
||||
from ...distributed import allgather, reducescatter
|
||||
from ...expert_statistic import ExpertStatistic
|
||||
from ...model_config import ModelConfig, MoeLoadBalancerConfig
|
||||
from ...model_config import ModelConfig
|
||||
from ...utils import (EventType, Fp4QuantizedTensor, disable_fp4_allgather,
|
||||
reswizzle_sf, swizzle_sf, unswizzle_sf)
|
||||
from .interface import MoE
|
||||
from .moe_load_balancer import MoeLoadBalancer
|
||||
from .moe_load_balancer import get_moe_load_balancer
|
||||
from .quantization import (FP8BlockScalesFusedMoEMethod, FP8QDQFusedMoEMethod,
|
||||
MoEWeightLoadingMode, NVFP4CutlassFusedMoEMethod,
|
||||
UnquantizedFusedMoEMethod, WInt4AFP8FusedMoEMethod)
|
||||
@ -82,7 +82,6 @@ class CutlassFusedMoE(MoE):
|
||||
VANILLA,
|
||||
apply_router_weight_on_input: bool = False,
|
||||
enable_alltoall: bool = False,
|
||||
moe_load_balancer: Optional[MoeLoadBalancer] = None,
|
||||
layer_idx: Optional[int] = None,
|
||||
):
|
||||
|
||||
@ -98,46 +97,56 @@ class CutlassFusedMoE(MoE):
|
||||
)
|
||||
|
||||
self.layer_idx = layer_idx
|
||||
|
||||
moe_load_balancer = get_moe_load_balancer()
|
||||
self.layer_load_balancer = None
|
||||
|
||||
moe_load_balancer_config = model_config.moe_load_balancer
|
||||
if moe_load_balancer_config is None:
|
||||
assert moe_load_balancer is None
|
||||
# A dummy MoeLoadBalancerConfig to generate default initial_global_assignments
|
||||
moe_load_balancer_config = MoeLoadBalancerConfig()
|
||||
moe_load_balancer_config.setup(num_experts=num_experts,
|
||||
ep_rank=self.ep_rank,
|
||||
ep_size=self.ep_size)
|
||||
else:
|
||||
assert moe_load_balancer is not None
|
||||
init_expert_size_per_partition = moe_load_balancer_config.num_local_slots if moe_load_balancer_config else self.num_experts // self.ep_size
|
||||
self.initial_global_assignments = [
|
||||
(ep_rank * self.num_experts // self.ep_size + local_slot_id) %
|
||||
self.num_experts for ep_rank in range(self.ep_size)
|
||||
for local_slot_id in range(init_expert_size_per_partition)
|
||||
]
|
||||
|
||||
self.num_slots = moe_load_balancer_config.num_slots
|
||||
if self.smart_router:
|
||||
assert self.num_slots == self.num_experts, "Smart router should not have redundant slots"
|
||||
|
||||
self.initial_global_assignments = moe_load_balancer_config.get_layer_initial_global_assignments(
|
||||
layer_idx)
|
||||
self.expert_size_per_partition = moe_load_balancer_config.num_local_slots
|
||||
self.slot_start = moe_load_balancer_config.slot_start
|
||||
self.slot_end = moe_load_balancer_config.slot_end
|
||||
self.initial_local_expert_ids = self.initial_global_assignments[
|
||||
self.slot_start:self.slot_end]
|
||||
assert len(
|
||||
self.initial_local_expert_ids) == self.expert_size_per_partition
|
||||
|
||||
self.balancer_layer = None
|
||||
if moe_load_balancer is not None:
|
||||
self.balancer_layer = moe_load_balancer.add_layer(
|
||||
expert_count=num_experts,
|
||||
top_k=routing_method.experts_per_token,
|
||||
slot_count_per_rank=self.expert_size_per_partition,
|
||||
)
|
||||
self.balancer_layer.set_initial_weight_assignments(
|
||||
if moe_load_balancer:
|
||||
assert moe_load_balancer_config is not None
|
||||
top_k = self.routing_method.experts_per_token
|
||||
self.expert_size_per_partition = moe_load_balancer_config.num_local_slots
|
||||
self.layer_load_balancer = moe_load_balancer.add_layer(
|
||||
self.num_experts, top_k, self.expert_size_per_partition)
|
||||
loaded_initial_global_assignments = moe_load_balancer_config.get_layer_initial_global_assignments(
|
||||
self.layer_idx)
|
||||
self.num_slots = moe_load_balancer_config.num_slots
|
||||
if loaded_initial_global_assignments is not None:
|
||||
assert isinstance(loaded_initial_global_assignments, list)
|
||||
assert len(loaded_initial_global_assignments) == self.num_slots
|
||||
assert self.num_slots >= self.num_experts
|
||||
assert set(loaded_initial_global_assignments) == set(
|
||||
range(self.num_experts))
|
||||
self.initial_global_assignments = loaded_initial_global_assignments
|
||||
self.layer_load_balancer.set_initial_weight_assignments(
|
||||
self.initial_global_assignments)
|
||||
logger.info(
|
||||
f"MoE load balancer enabled. num_experts = {num_experts}, num_slots = {self.num_slots}, ep_size = {self.ep_size}"
|
||||
)
|
||||
logger.info(
|
||||
f"initial_global_assignments (layer {layer_idx}) = {self.initial_global_assignments}"
|
||||
f"initial_global_assignments (layer {self.layer_idx}) = {self.initial_global_assignments}"
|
||||
)
|
||||
else:
|
||||
assert num_experts % self.ep_size == 0
|
||||
self.expert_size_per_partition = num_experts // self.ep_size
|
||||
self.num_slots = num_experts
|
||||
|
||||
if self.smart_router:
|
||||
assert self.num_slots == self.num_experts, "Smart router should not have redundant slots"
|
||||
|
||||
self.slot_start = self.ep_rank * self.expert_size_per_partition
|
||||
self.slot_end = self.slot_start + self.expert_size_per_partition
|
||||
self.initial_local_expert_ids = self.initial_global_assignments[
|
||||
self.slot_start:self.slot_end]
|
||||
assert len(
|
||||
self.initial_local_expert_ids) == self.expert_size_per_partition
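# A small worked example of the default round-robin formula above (hypothetical
# sizes): with num_experts = 8 and ep_size = 4, expert_size_per_partition = 2 and
# [(r * 8 // 4 + s) % 8 for r in range(4) for s in range(2)] == [0, 1, 2, 3, 4, 5, 6, 7],
# so rank 2 gets slot_start = 4, slot_end = 6 and initial_local_expert_ids = [4, 5].
# With a load balancer and num_slots > num_experts, the same per-rank slicing applies
# but some experts occupy more than one slot.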
|
||||
|
||||
max_num_tokens = model_config.max_num_tokens
|
||||
# The maximum number of tokens in MoE are multiplied by DP size when attention DP is enabled
|
||||
@ -259,13 +268,14 @@ class CutlassFusedMoE(MoE):
|
||||
return outputs
|
||||
|
||||
def forward_chunk(
|
||||
self,
|
||||
x: Union[torch.Tensor, Fp4QuantizedTensor],
|
||||
router_logits: torch.Tensor,
|
||||
cutlass_min_latency_mode: bool = False,
|
||||
output_dtype: Optional[torch.dtype] = None,
|
||||
all_rank_num_tokens: Optional[List[int]] = None,
|
||||
use_dp_padding: Optional[bool] = None,
|
||||
self,
|
||||
x: Union[torch.Tensor, Fp4QuantizedTensor],
|
||||
router_logits: torch.Tensor,
|
||||
cutlass_min_latency_mode: bool = False,
|
||||
output_dtype: Optional[torch.dtype] = None,
|
||||
all_rank_num_tokens: Optional[List[int]] = None,
|
||||
use_dp_padding: Optional[bool] = None,
|
||||
repeating_info: Tuple = (True, True),
|
||||
) -> torch.Tensor:
|
||||
if isinstance(x, Fp4QuantizedTensor):
|
||||
assert output_dtype is not None
|
||||
@ -273,31 +283,25 @@ class CutlassFusedMoE(MoE):
|
||||
else:
|
||||
output_dtype = x.dtype
|
||||
|
||||
is_first_call, is_last_call = repeating_info
|
||||
|
||||
if self.layer_load_balancer and not self.layer_load_balancer.is_static_routing(
|
||||
) and is_first_call:
|
||||
self.layer_load_balancer.wait_for_gpu_stage()
|
||||
|
||||
use_fp8_block_scaling = False
|
||||
use_w4a8_group_scaling = False
|
||||
weight_dtype = self.w3_w1_weight.dtype
|
||||
|
||||
token_selected_experts, token_final_scales = self.routing_method.apply(
|
||||
router_logits)
|
||||
if self.balancer_layer is None:
|
||||
token_selected_slots = token_selected_experts
|
||||
else:
|
||||
# If attention DP is enabled, token_selected_experts is a local rank tensor,
|
||||
# so we need to offset the round robin position by ep_rank
|
||||
token_selected_slots = self.balancer_layer.route(
|
||||
token_selected_experts, offset_by_ep_rank=self.use_dp)
|
||||
|
||||
# If load balancer is disabled, the statistics are collected from expert IDs.
|
||||
# If load balancer is enabled, the statistics are collected from expert slot IDs.
|
||||
ExpertStatistic.set_layer(self.layer_idx)
|
||||
ExpertStatistic.maybe_add_info(self.num_slots, token_selected_slots)
|
||||
|
||||
assert token_selected_slots.shape[
|
||||
assert token_selected_experts.shape[
|
||||
1] == self.routing_method.experts_per_token
|
||||
assert token_selected_slots.shape == token_final_scales.shape
|
||||
assert token_selected_slots.shape[0] == router_logits.shape[0]
|
||||
assert token_selected_experts.shape == token_final_scales.shape
|
||||
assert token_selected_experts.shape[0] == router_logits.shape[0]
|
||||
assert token_final_scales.dtype == torch.float32
|
||||
assert token_selected_slots.dtype == torch.int32
|
||||
assert token_selected_experts.dtype == torch.int32
|
||||
|
||||
if self.apply_router_weight_on_input:
|
||||
assert self.routing_method.top_k == 1, "Current workaround only supports top-1 routing"
|
||||
@ -310,13 +314,32 @@ class CutlassFusedMoE(MoE):
|
||||
|
||||
alltoall_info = None
|
||||
|
||||
if self.layer_load_balancer and not self.layer_load_balancer.is_static_routing(
|
||||
) and is_first_call:
|
||||
self.layer_load_balancer.maybe_cudagraph_done_wait()
|
||||
|
||||
need_statistic = False
|
||||
if self.layer_load_balancer is None:
|
||||
token_selected_slots = token_selected_experts
|
||||
else:
|
||||
token_selected_slots = self.layer_load_balancer.route(
|
||||
token_selected_experts, self.use_dp)
|
||||
if not self.layer_load_balancer.is_static_routing():
|
||||
need_statistic = True
|
||||
|
||||
# If load balancer is disabled, the statistics are collected from expert IDs.
|
||||
# If load balancer is enabled, the statistics are collected from expert slot IDs.
|
||||
ExpertStatistic.set_layer(self.layer_idx)
|
||||
ExpertStatistic.maybe_add_info(self.num_slots, token_selected_slots)
|
||||
|
||||
token_selected_experts_for_statistic = token_selected_experts if need_statistic else None
|
||||
if self.enable_alltoall:
|
||||
x, token_selected_slots, token_final_scales, alltoall_info = \
|
||||
x, token_selected_slots, token_final_scales, token_selected_experts_for_statistic, alltoall_info = \
|
||||
self.alltoall_prepare_maybe_dispatch(all_rank_num_tokens,
|
||||
x,
|
||||
token_selected_slots,
|
||||
token_final_scales)
|
||||
|
||||
token_final_scales,
|
||||
token_selected_experts_for_statistic)
|
||||
x_sf = None
|
||||
if self.has_any_quant:
|
||||
if self.has_fp8_qdq:
|
||||
@ -348,8 +371,11 @@ class CutlassFusedMoE(MoE):
|
||||
|
||||
if self.use_dp and self.parallel_size > 1 and not disable_fp4_allgather(
|
||||
) and not self.enable_alltoall:
|
||||
x, x_sf, token_selected_slots, token_final_scales = allgather(
|
||||
[x, x_sf, token_selected_slots, token_final_scales],
|
||||
x, x_sf, token_selected_slots, token_final_scales, token_selected_experts_for_statistic = allgather(
|
||||
[
|
||||
x, x_sf, token_selected_slots, token_final_scales,
|
||||
token_selected_experts_for_statistic
|
||||
],
|
||||
self.mapping,
|
||||
dim=0,
|
||||
sizes=None if use_dp_padding else all_rank_num_tokens)
|
||||
@ -358,6 +384,12 @@ class CutlassFusedMoE(MoE):
|
||||
x_sf = reswizzle_sf(x_sf, x_row, x_col,
|
||||
self.scaling_vector_size)
|
||||
|
||||
if self.layer_load_balancer and not self.layer_load_balancer.is_static_routing(
|
||||
):
|
||||
self.layer_load_balancer.statistic(
|
||||
token_selected_experts_for_statistic, is_first_call,
|
||||
is_last_call)
|
||||
|
||||
if self.smart_router and not cutlass_min_latency_mode:
|
||||
ep_size = self.cluster_size
|
||||
ep_rank = self.cluster_rank
|
||||
@ -405,20 +437,29 @@ class CutlassFusedMoE(MoE):
|
||||
tune_max_num_tokens=self.tune_max_num_tokens,
|
||||
)
|
||||
|
||||
if self.layer_load_balancer and not self.layer_load_balancer.is_static_routing(
|
||||
) and is_last_call:
|
||||
self.layer_load_balancer.set_cpu_stage()
|
||||
|
||||
if cutlass_min_latency_mode:
|
||||
assert not self.reduce_results
|
||||
return final_hidden_states
|
||||
assert not self.enable_alltoall
|
||||
else:
|
||||
# Custom op requires all inputs are in the same type.
|
||||
# Only in cutlass_min_latency_mode, the output is a list of tensors.
|
||||
# Otherwise, the output should be unpacked as a single tensor.
|
||||
final_hidden_states = final_hidden_states[0]
|
||||
|
||||
if not self.enable_alltoall:
|
||||
return final_hidden_states
|
||||
else:
|
||||
return self.alltoall_combine(final_hidden_states, alltoall_info,
|
||||
token_count)
|
||||
if self.enable_alltoall:
|
||||
final_hidden_states = self.alltoall_combine(final_hidden_states,
|
||||
alltoall_info,
|
||||
token_count)
|
||||
|
||||
if self.layer_load_balancer and not self.layer_load_balancer.is_static_routing(
|
||||
) and is_last_call:
|
||||
self.layer_load_balancer.maybe_cudagraph_done_set_cpu_stage()
|
||||
|
||||
return final_hidden_states
|
||||
|
||||
def forward(
|
||||
self,
|
||||
@ -500,6 +541,8 @@ class CutlassFusedMoE(MoE):
|
||||
# Postpone reduce-scatter/all-reduce to the next iteration to achieve better overlap
|
||||
for idx_chunk, (x, router_logits) in enumerate(
|
||||
zip(x_list, router_logits_list)):
|
||||
is_first_call = idx_chunk == 0
|
||||
is_last_call = idx_chunk == num_chunks - 1
|
||||
if not self.enable_alltoall:
|
||||
if idx_chunk % 2 == 0:
|
||||
with torch.cuda.stream(self.aux_stream):
|
||||
@ -508,7 +551,8 @@ class CutlassFusedMoE(MoE):
|
||||
router_logits,
|
||||
all_rank_num_tokens=all_rank_num_tokens_list[
|
||||
idx_chunk] if self.use_dp else None,
|
||||
use_dp_padding=use_dp_padding)
|
||||
use_dp_padding=use_dp_padding,
|
||||
repeating_info=(is_first_call, is_last_call))
|
||||
if idx_chunk > 0:
|
||||
outputs_list[-1] = self.reducescatter_or_allreduce(
|
||||
outputs_list[-1],
|
||||
@ -521,7 +565,8 @@ class CutlassFusedMoE(MoE):
|
||||
router_logits,
|
||||
all_rank_num_tokens=all_rank_num_tokens_list[
|
||||
idx_chunk] if self.use_dp else None,
|
||||
use_dp_padding=use_dp_padding)
|
||||
use_dp_padding=use_dp_padding,
|
||||
repeating_info=(is_first_call, is_last_call))
|
||||
with torch.cuda.stream(self.aux_stream):
|
||||
outputs_list[-1] = self.reducescatter_or_allreduce(
|
||||
outputs_list[-1],
|
||||
@ -533,7 +578,8 @@ class CutlassFusedMoE(MoE):
|
||||
x,
|
||||
router_logits,
|
||||
all_rank_num_tokens=all_rank_num_tokens_list[idx_chunk]
|
||||
if self.use_dp else None)
|
||||
if self.use_dp else None,
|
||||
repeating_info=(is_first_call, is_last_call))
|
||||
|
||||
outputs_list.append(outputs)
|
||||
if not self.enable_alltoall:
|
||||
@ -557,32 +603,48 @@ class CutlassFusedMoE(MoE):
|
||||
outputs = outputs[:all_rank_num_tokens[rank]]
|
||||
return outputs
|
||||
|
||||
def alltoall_prepare_maybe_dispatch(self, all_rank_num_tokens: list,
|
||||
x: torch.Tensor,
|
||||
token_selected_slots: torch.Tensor,
|
||||
token_final_scales: torch.Tensor):
|
||||
def alltoall_prepare_maybe_dispatch(
|
||||
self, all_rank_num_tokens: list, x: torch.Tensor,
|
||||
token_selected_slots: torch.Tensor,
|
||||
token_final_scales: torch.Tensor,
|
||||
token_selected_experts_for_statistic: Optional[torch.Tensor]):
|
||||
top_k = self.routing_method.experts_per_token
|
||||
expert_count = self.num_experts
|
||||
# gather router info
|
||||
max_num_token = max(all_rank_num_tokens)
|
||||
token_selected_slots = torch.nn.functional.pad(
|
||||
token_selected_slots,
|
||||
(0, 0, 0, max_num_token - token_selected_slots.shape[0]),
|
||||
'constant', self.num_experts)
|
||||
'constant', self.num_slots)
|
||||
token_selected_experts_for_statistic = torch.nn.functional.pad(
|
||||
token_selected_experts_for_statistic,
|
||||
(0, 0, 0,
|
||||
max_num_token - token_selected_experts_for_statistic.shape[0]),
|
||||
'constant', self.num_experts
|
||||
) if token_selected_experts_for_statistic is not None else None
|
||||
token_final_scales = torch.nn.functional.pad(
|
||||
token_final_scales,
|
||||
(0, 0, 0, max_num_token - token_final_scales.shape[0]))
|
||||
gathered_token_selected_slots, gathered_token_final_scales = allgather(
|
||||
[token_selected_slots, token_final_scales], self.mapping, dim=0)
|
||||
gathered_token_selected_slots, gathered_token_final_scales, gathered_token_selected_experts_for_statistic = allgather(
|
||||
[
|
||||
token_selected_slots, token_final_scales,
|
||||
token_selected_experts_for_statistic
|
||||
],
|
||||
self.mapping,
|
||||
dim=0)
|
||||
if gathered_token_selected_experts_for_statistic is not None:
|
||||
gathered_token_selected_experts_for_statistic = torch.flatten(
|
||||
gathered_token_selected_experts_for_statistic.contiguous(),
|
||||
start_dim=0,
|
||||
end_dim=-2)
|
||||
gathered_token_selected_slots = torch.flatten(
|
||||
gathered_token_selected_slots.contiguous(), start_dim=0, end_dim=-2)
|
||||
gathered_token_final_scales = torch.flatten(
|
||||
gathered_token_final_scales.contiguous(), start_dim=0, end_dim=-2)
|
||||
gathered_target_rank_ids = MnnvlMoe.compute_target_rank_id(
|
||||
gathered_token_selected_slots, self.num_experts, self.ep_size)
|
||||
gathered_token_selected_slots, self.num_slots, self.ep_size)
|
||||
alltoall_info, token_selected_slots, token_final_scales = MnnvlMoe.mnnvl_moe_alltoallv_prepare(
|
||||
gathered_target_rank_ids, None, gathered_token_selected_slots,
|
||||
gathered_token_final_scales, max_num_token, expert_count, top_k,
|
||||
gathered_token_final_scales, max_num_token, self.num_slots, top_k,
|
||||
self.ep_rank, self.ep_size)
|
||||
|
||||
if not self.use_postquant_alltoall:
|
||||
@ -593,7 +655,7 @@ class CutlassFusedMoE(MoE):
|
||||
self.alltoall_workspace,
|
||||
self.ep_rank, self.ep_size)
|
||||
|
||||
return x, token_selected_slots, token_final_scales, alltoall_info
|
||||
return x, token_selected_slots, token_final_scales, gathered_token_selected_experts_for_statistic, alltoall_info
|
||||
|
||||
def alltoall_postquant_dispatch(self, x: torch.Tensor, x_sf: torch.Tensor,
|
||||
x_row: int, x_col: int,
|
||||
@ -633,6 +695,63 @@ class CutlassFusedMoE(MoE):
|
||||
|
||||
return final_hidden_states
|
||||
|
||||
def register_parameter_weight_slot_fn(self, weight_name: str,
|
||||
local_slot_id: int):
|
||||
assert hasattr(
|
||||
self,
|
||||
weight_name), f"FusedMoE doesn't has weight attr: {weight_name}"
|
||||
weight_tensor = getattr(self, weight_name).data[local_slot_id]
|
||||
self.layer_load_balancer.register_weight_slot(local_slot_id,
|
||||
weight_name,
|
||||
weight_tensor)
|
||||
|
||||
def register_to_fix_weight_fn(self, weight_name: str):
|
||||
assert hasattr(
|
||||
self,
|
||||
weight_name), f"FusedMoE doesn't has weight attr: {weight_name}"
|
||||
param = getattr(self, weight_name)
|
||||
weight_tensor = param.detach()
|
||||
assert isinstance(
|
||||
weight_tensor,
|
||||
torch.Tensor), f'weight {weight_name} should be a tensor'
|
||||
assert weight_tensor.is_contiguous(
|
||||
), f'weight {weight_name} should be contiguous, shape={weight_tensor.shape}, strides={weight_tensor.stride()}'
|
||||
assert weight_tensor.numel() * weight_tensor.element_size() == weight_tensor.untyped_storage().size(),\
|
||||
f'weight {weight_name} shape={weight_tensor.shape} storage_size = {weight_tensor.untyped_storage().size()}, numel={weight_tensor.numel()}, eltsize={weight_tensor.element_size()}, dtype={weight_tensor.dtype}'
|
||||
self.layer_load_balancer.fix_tensor(weight_tensor)
|
||||
param.data = weight_tensor
|
||||
|
||||
def register_all_parameter_slot_and_to_fix_weight_fns(
|
||||
self, weight_and_tensor_dict: Dict[str, torch.Tensor]):
|
||||
"""
|
||||
weight_and_tensor_dict: key is the name of the weight, value is the tensor of the loaded shared tensor shard.
E.g. if num_experts=256 with 4 GPUs per node, each rank needs to load 256 / 4 = 64 expert weights for host sharing.
This way, the host_tensor_sharer can share the weights so that each rank has access to all 256 experts.
|
||||
"""
|
||||
for local_slot_id, expert_id in enumerate(
|
||||
self.initial_local_expert_ids):
|
||||
for weight_name in weight_and_tensor_dict:
|
||||
self.layer_load_balancer.add_register_weight_fn(
|
||||
self.register_parameter_weight_slot_fn,
|
||||
(weight_name, local_slot_id))
|
||||
for weight_name in weight_and_tensor_dict:
|
||||
self.layer_load_balancer.add_to_fix_weight_fn(
|
||||
self.register_to_fix_weight_fn, (weight_name, ))
|
||||
|
||||
local_shared_load_expert_ids = self.layer_load_balancer.get_load_expert_ids(
|
||||
)
|
||||
for expert_id in range(self.num_experts):
|
||||
for weight_name, weight_tensor in weight_and_tensor_dict.items():
|
||||
if expert_id in local_shared_load_expert_ids:
|
||||
local_slot_id = local_shared_load_expert_ids.index(
|
||||
expert_id)
|
||||
self.layer_load_balancer.host_tensor_sharer.share_host_tensor_with_shape(
|
||||
expert_id, weight_name, weight_tensor[local_slot_id])
|
||||
else:
|
||||
self.layer_load_balancer.host_tensor_sharer.pre_register_host_tensor_with_shape(
|
||||
expert_id, weight_name, weight_tensor.dtype,
|
||||
weight_tensor[0].shape)
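# Sketch of the host-sharing split implemented above, using the sizes from the
# docstring: with num_experts = 256 and 4 ranks per node, rank r loads experts
# [r * 256 // 4, (r + 1) * 256 // 4), i.e. 64 experts, shares their host copies via
# share_host_tensor_with_shape, and pre-registers the other 192 experts by dtype and
# shape so that after finalization every rank holds host views of all 256 experts.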
|
||||
|
||||
def load_weights(self, weights: List[Dict]):
|
||||
assert self._weights_created
|
||||
assert len(weights) == 1
|
||||
|
||||
@ -2,7 +2,7 @@ from typing import Dict, List, Optional, Union
|
||||
|
||||
import torch
|
||||
|
||||
from ...model_config import ModelConfig, MoeLoadBalancerConfig
|
||||
from ...model_config import ModelConfig
|
||||
from ...utils import Fp4QuantizedTensor
|
||||
from .interface import MoE, MoEWeightLoadingMode
|
||||
from .quantization import (FP8BlockScalesFusedMoEMethod,
|
||||
@ -71,18 +71,15 @@ class TRTLLMGenFusedMoE(MoE):
|
||||
assert not self.smart_router, "Smart router is not supported in TRTLLMGenFusedMoE."
|
||||
assert not self.use_dp, "AttentionDP is not supported in TRTLLMGenFusedMoE."
|
||||
|
||||
# A dummy MoeLoadBalancerConfig to generate default initial_global_assignments and initial_local_expert_ids
|
||||
moe_load_balancer_config = MoeLoadBalancerConfig()
|
||||
moe_load_balancer_config.setup(num_experts=num_experts,
|
||||
ep_rank=self.ep_rank,
|
||||
ep_size=self.ep_size)
|
||||
|
||||
self.num_slots = moe_load_balancer_config.num_slots
|
||||
self.initial_global_assignments = moe_load_balancer_config.get_layer_initial_global_assignments(
|
||||
layer_idx)
|
||||
self.expert_size_per_partition = moe_load_balancer_config.num_local_slots
|
||||
self.slot_start = moe_load_balancer_config.slot_start
|
||||
self.slot_end = moe_load_balancer_config.slot_end
|
||||
self.num_slots = self.num_experts
|
||||
self.expert_size_per_partition = self.num_experts // self.ep_size
|
||||
self.initial_global_assignments = [
|
||||
(ep_rank * self.num_experts // self.ep_size + local_slot_id) %
|
||||
self.num_experts for ep_rank in range(self.ep_size)
|
||||
for local_slot_id in range(self.expert_size_per_partition)
|
||||
]
|
||||
self.slot_start = self.ep_rank * self.expert_size_per_partition
|
||||
self.slot_end = self.slot_start + self.expert_size_per_partition
|
||||
self.initial_local_expert_ids = self.initial_global_assignments[
|
||||
self.slot_start:self.slot_end]
|
||||
assert len(
|
||||
|
||||
@ -1,13 +1,18 @@
|
||||
import atexit
|
||||
import platform
|
||||
import threading
|
||||
from multiprocessing import shared_memory
|
||||
from typing import Callable, List, Optional
|
||||
from contextlib import nullcontext
|
||||
from multiprocessing import resource_tracker, shared_memory
|
||||
from typing import Callable, Dict, List, Optional, Tuple
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from mpi4py import MPI
|
||||
|
||||
import tensorrt_llm
|
||||
import tensorrt_llm.bindings.internal.runtime as _tbr
|
||||
from tensorrt_llm._torch.pyexecutor.cuda_graph_runner import is_graph_capturing
|
||||
from tensorrt_llm.logger import logger
|
||||
from tensorrt_llm.mapping import Mapping
|
||||
|
||||
|
||||
def _tensor_to_weight(t: torch.Tensor) -> _tbr.MoeWeight:
|
||||
@ -20,6 +25,7 @@ def _tensor_to_weight(t: torch.Tensor) -> _tbr.MoeWeight:
|
||||
assert t.dim() <= 2, "t.dim() should be less than or equal to 2"
|
||||
shape = [1, 1]
|
||||
pitch = 1
|
||||
elt_size = torch.tensor([], dtype=t.dtype).element_size()
|
||||
if t.dim() == 2:
|
||||
shape[0] = t.size(0)
|
||||
shape[1] = t.size(1)
|
||||
@ -31,8 +37,8 @@ def _tensor_to_weight(t: torch.Tensor) -> _tbr.MoeWeight:
|
||||
pass
|
||||
mw = _tbr.MoeWeight()
|
||||
mw.height = shape[0]
|
||||
mw.width = shape[1]
|
||||
mw.pitch = pitch
|
||||
mw.width = shape[1] * elt_size
|
||||
mw.pitch = pitch * elt_size
|
||||
mw.weight_ptr = t.data_ptr()
|
||||
return mw
|
||||
|
||||
@ -42,7 +48,8 @@ class HostMoeTensorSharer:
|
||||
A class representing a host tensor sharer.
|
||||
"""
|
||||
|
||||
def __init__(self, layer_id: int, shared_mpi_comm: MPI.Comm):
|
||||
def __init__(self, layer_id: int, expert_count: int,
|
||||
shared_mpi_comm: MPI.Comm):
|
||||
"""
|
||||
Initialize a HostMoeTensorSharer instance.
|
||||
|
||||
@ -51,11 +58,24 @@ class HostMoeTensorSharer:
|
||||
"""
|
||||
self.shared_mpi_comm = shared_mpi_comm
|
||||
self.layer_id = layer_id
|
||||
self.expert_count = expert_count
|
||||
self.shared_memory_base_name = None
|
||||
self.host_tensor_shapes = []
|
||||
|
||||
self.local_rank = self.shared_mpi_comm.Get_rank()
|
||||
self.local_size = self.shared_mpi_comm.Get_size()
|
||||
|
||||
self.expert_start = self.local_rank * self.expert_count // self.local_size
|
||||
self.expert_end = (self.local_rank +
|
||||
1) * self.expert_count // self.local_size
|
||||
|
||||
self.name_info = {} # key is weight name, value is (dtype, shape)
|
||||
self.host_weights = {}
|
||||
self.own_shms = {}
|
||||
self.all_shms = []
|
||||
|
||||
self.own_shm = None
|
||||
self.imported_shms = []
|
||||
|
||||
self.shared_tensors = {}
|
||||
self.names = []
|
||||
|
||||
def set_shared_memory_base_name(self, shared_memory_base_name):
|
||||
"""
|
||||
@ -66,17 +86,19 @@ class HostMoeTensorSharer:
|
||||
"""
|
||||
self.shared_memory_base_name = shared_memory_base_name
|
||||
|
||||
def get_shared_memory_name(self, expert_id: int, name: str):
|
||||
def get_shared_memory_name(self, rank: Optional[int] = None):
|
||||
"""
|
||||
Get the shared memory name for the layer.
|
||||
|
||||
Args:
|
||||
expert_id: The ID of the expert
|
||||
name: The name of the weight
|
||||
rank: The rank who created the shared memory. Current rank if None
|
||||
"""
|
||||
if rank is None:
|
||||
rank = self.local_rank
|
||||
assert 0 <= rank < self.local_size
|
||||
assert isinstance(self.shared_memory_base_name,
|
||||
str), "self.shared_memory_base_name must be a string"
|
||||
shared_memory_name = f"{self.shared_memory_base_name}_l{self.layer_id}_e{expert_id}_{name}"
|
||||
shared_memory_name = f"{self.shared_memory_base_name}_l{self.layer_id}_lr{rank}_all"
|
||||
return shared_memory_name
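# For example, with a (hypothetical) shared_memory_base_name of "moe_shm", layer 3
# and local rank 1, this returns "moe_shm_l3_lr1_all": one segment per creating rank
# holding all of its shared expert weights, instead of the previous
# one-segment-per-(expert, weight) naming.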
|
||||
|
||||
def pre_register_host_tensor_with_shape(self, expert_id: int, name: str,
|
||||
@ -96,7 +118,13 @@ class HostMoeTensorSharer:
|
||||
"""
|
||||
assert len(tensor_shape
|
||||
) <= 2, "tensor_shape dim must be less than or equal to 2"
|
||||
self.host_tensor_shapes.append((expert_id, name, dtype, tensor_shape))
|
||||
assert 0 <= expert_id < self.expert_count
|
||||
assert expert_id < self.expert_start or expert_id >= self.expert_end
|
||||
if name not in self.name_info:
|
||||
self.name_info[name] = (dtype, tensor_shape)
|
||||
else:
|
||||
assert dtype == self.name_info[name][0] and tensor_shape == self.name_info[name][1], \
|
||||
f'weights name={name}, dtype={dtype}, shape={tensor_shape}, but already registered with dtype={self.name_info[name][0]}, shape={self.name_info[name][1]}'
|
||||
|
||||
def share_host_tensor_with_shape(self, expert_id: int, name: str,
|
||||
t: torch.Tensor):
|
||||
@ -111,38 +139,99 @@ class HostMoeTensorSharer:
|
||||
name: The name of the weight
|
||||
t: The weight tensor
|
||||
"""
|
||||
assert len(
|
||||
t.shape) <= 2, "tensor_shape dim must be less than or equal to 2"
|
||||
assert t.is_contiguous() == True, "t.is_contiguous() must be True"
|
||||
shm_name = self.get_shared_memory_name(expert_id, name)
|
||||
shm = shared_memory.SharedMemory(name=shm_name,
|
||||
create=True,
|
||||
size=t.numel() * t.element_size())
|
||||
shm.buf[:t.numel() * t.element_size()] = t.numpy().tobytes()
|
||||
assert (expert_id, name) not in self.shared_tensors.keys()
|
||||
assert self.expert_start <= expert_id < self.expert_end
|
||||
self.shared_tensors[(expert_id, name)] = t
|
||||
dtype = t.dtype
|
||||
tensor_shape = t.shape
|
||||
t = torch.frombuffer(shm.buf,
|
||||
dtype=dtype).view(tensor_shape).pin_memory()
|
||||
key = (expert_id, name)
|
||||
assert key not in self.host_weights.keys(), f"key={key} already exists"
|
||||
self.host_weights[key] = t
|
||||
self.own_shms[(expert_id, name)] = shm
|
||||
self.all_shms.append(shm)
|
||||
atexit.register(shm.unlink)
|
||||
if name not in self.name_info:
|
||||
self.name_info[name] = (dtype, tensor_shape)
|
||||
else:
|
||||
assert dtype == self.name_info[name][0] and tensor_shape == self.name_info[name][1], \
|
||||
f'weights name={name}, dtype={dtype}, shape={tensor_shape}, but already registered with dtype={self.name_info[name][0]}, shape={self.name_info[name][1]}'
|
||||
|
||||
@staticmethod
|
||||
def align_size(size: int):
|
||||
return (size + 256 - 1) // 256 * 256
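# For example, align_size(14336) == 14336 and align_size(14337) == 14592: each
# tensor's byte span is padded up to a 256-byte boundary before being packed into
# the single per-rank shared-memory block built in finalize_layer_weights below.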
|
||||
|
||||
def finalize_layer_weights(self):
|
||||
self.names = list(sorted(self.name_info.keys()))
|
||||
assert len(
|
||||
self.shared_tensors.keys()) == (self.expert_end -
|
||||
self.expert_start) * len(self.names)
|
||||
|
||||
total_size = 0
|
||||
for name in self.names:
|
||||
dtype, shape = self.name_info[name]
|
||||
for expert_id in range(self.expert_start, self.expert_end):
|
||||
t = self.shared_tensors[(expert_id, name)]
|
||||
assert dtype == t.dtype and shape == t.shape
|
||||
data_size = t.numel() * t.element_size()
|
||||
aligned_size = self.align_size(data_size)
|
||||
total_size += aligned_size
|
||||
|
||||
shm_name = self.get_shared_memory_name()
|
||||
shm = shared_memory.SharedMemory(name=shm_name,
|
||||
create=True,
|
||||
size=total_size)
|
||||
self.own_shm = shm
|
||||
|
||||
offset = 0
|
||||
for name in self.names:
|
||||
for expert_id in range(self.expert_start, self.expert_end):
|
||||
t = self.shared_tensors[(expert_id, name)]
|
||||
data_size = t.numel() * t.element_size()
|
||||
aligned_size = self.align_size(data_size)
|
||||
shm.buf[offset:offset + data_size] = t.numpy().tobytes()
|
||||
dtype = t.dtype
|
||||
tensor_shape = t.shape
|
||||
elt_count = t.numel()
|
||||
st = torch.frombuffer(shm.buf,
|
||||
dtype=dtype,
|
||||
offset=offset,
|
||||
count=elt_count).view(tensor_shape)
|
||||
key = (expert_id, name)
|
||||
assert key not in self.host_weights.keys(
|
||||
), f"key={key} already exists"
|
||||
self.host_weights[key] = st
|
||||
offset += aligned_size
|
||||
self.shared_tensors = {}
|
||||
|
||||
def finalize_host_tensor_sharing(self, add_host_weight_fn: Callable = None):
|
||||
"""
|
||||
Finalize the host tensor sharing.
|
||||
"""
|
||||
for expert_weight_info in self.host_tensor_shapes:
|
||||
expert_id, name, dtype, tensor_shape = expert_weight_info
|
||||
shm_name = self.get_shared_memory_name(expert_id, name)
|
||||
for rank in range(self.local_size):
|
||||
if rank == self.local_rank:
|
||||
continue
|
||||
|
||||
shm_name = self.get_shared_memory_name(rank)
|
||||
shm = shared_memory.SharedMemory(name=shm_name)
|
||||
self.all_shms.append(shm)
|
||||
t = torch.frombuffer(shm.buf,
|
||||
dtype=dtype).view(tensor_shape).pin_memory()
|
||||
key = (expert_id, name)
|
||||
assert key not in self.host_weights.keys(
|
||||
), f"key={key} already exists"
|
||||
self.host_weights[key] = t
|
||||
self.imported_shms.append(shm)
|
||||
|
||||
rank_expert_start = rank * self.expert_count // self.local_size
|
||||
rank_expert_end = (rank + 1) * self.expert_count // self.local_size
|
||||
|
||||
offset = 0
|
||||
for name in self.names:
|
||||
dtype, shape = self.name_info[name]
|
||||
elt_count = int(np.prod(shape))
|
||||
data_size = torch.tensor([],
|
||||
dtype=dtype).element_size() * elt_count
|
||||
aligned_size = self.align_size(data_size)
|
||||
for expert_id in range(rank_expert_start, rank_expert_end):
|
||||
t = torch.frombuffer(shm.buf,
|
||||
dtype=dtype,
|
||||
offset=offset,
|
||||
count=elt_count).view(shape)
|
||||
key = (expert_id, name)
|
||||
assert key not in self.host_weights.keys(
|
||||
), f"key={key} already exists"
|
||||
self.host_weights[key] = t
|
||||
offset += aligned_size
|
||||
|
||||
if add_host_weight_fn is not None:
|
||||
for key, t in self.host_weights.items():
|
||||
@ -154,8 +243,20 @@ class HostMoeTensorSharer:
|
||||
"""
|
||||
Clean up the resources before C++ shutdown and barrier
|
||||
"""
|
||||
for shm in self.all_shms:
|
||||
for shm in self.imported_shms:
|
||||
shm.close()
|
||||
resource_tracker.unregister(shm._name, "shared_memory")
|
||||
self.imported_shms = None
|
||||
if self.own_shm:
|
||||
self.own_shm.close()
|
||||
|
||||
def post_shutdown_cleanup(self):
|
||||
"""
|
||||
Clean up the resources after C++ shutdown and barrier
|
||||
"""
|
||||
if self.own_shm:
|
||||
self.own_shm.unlink()
|
||||
self.own_shm = None
|
||||
|
||||
|
||||
class SingleLayerMoeLoadBalancer:
|
||||
@ -167,19 +268,56 @@ class SingleLayerMoeLoadBalancer:
|
||||
def __init__(
|
||||
self,
|
||||
single_layer_load_balancer_impl: _tbr.SingleLayerMoeLoadBalancer,
|
||||
shared_mpi_comm: MPI.Comm):
|
||||
shared_mpi_comm: MPI.Comm,
|
||||
expert_count: int,
|
||||
updates_enabled: bool = True):
|
||||
"""
|
||||
Initialize a SingleLayerMoeLoadBalancer instance.
|
||||
|
||||
Args:
|
||||
single_layer_load_balancer_impl: The C++ implementation of SingleLayerMoeLoadBalancer
|
||||
shared_mpi_comm: The MPI communicator for shared memory
|
||||
expert_count: total number of experts
|
||||
updates_enabled: whether to enable weight updates
|
||||
"""
|
||||
self.single_layer_load_balancer_impl = single_layer_load_balancer_impl
|
||||
self.single_layer_load_balancer_ptr = single_layer_load_balancer_impl.get_pointer(
|
||||
)
|
||||
self.expert_count = expert_count
|
||||
self.updates_enabled = updates_enabled
|
||||
layer_id = self.single_layer_load_balancer_impl.get_layer_id()
|
||||
self.host_tensor_sharer = HostMoeTensorSharer(shared_mpi_comm, layer_id)
|
||||
self.host_tensor_sharer = HostMoeTensorSharer(
|
||||
layer_id, expert_count,
|
||||
shared_mpi_comm) if self.updates_enabled else None
|
||||
self.register_weight_fns = []
|
||||
self.to_fix_weight_fns = []
|
||||
|
||||
shared_rank = shared_mpi_comm.Get_rank()
|
||||
shared_size = shared_mpi_comm.Get_size()
|
||||
|
||||
load_expert_start = shared_rank * self.expert_count // shared_size
|
||||
load_expert_end = min(
|
||||
(shared_rank + 1) * self.expert_count // shared_size,
|
||||
self.expert_count)
|
||||
self.load_expert_ids = list(range(load_expert_start, load_expert_end))
|
||||
|
||||
self.statistic_flag_tensor = None
|
||||
|
||||
self.cudagraph_stream = None
|
||||
self.cudagraph_event = None
|
||||
|
||||
def get_layer_idx(self):
|
||||
return self.single_layer_load_balancer_impl.get_layer_id()
|
||||
|
||||
def get_load_expert_ids(self):
|
||||
assert self.updates_enabled, "should not call get_load_expert_ids when using static routing"
|
||||
return self.load_expert_ids
|
||||
|
||||
def is_static_routing(self):
|
||||
return not self.updates_enabled
|
||||
|
||||
def need_load_shared_weights(self):
|
||||
return self.updates_enabled
|
||||
|
||||
def set_shared_memory_base_name(self, shared_memory_base_name):
|
||||
"""
|
||||
@ -188,8 +326,9 @@ class SingleLayerMoeLoadBalancer:
|
||||
Args:
|
||||
shared_memory_base_name: The base name for the shared memory
|
||||
"""
|
||||
self.host_tensor_sharer.set_shared_memory_base_name(
|
||||
shared_memory_base_name)
|
||||
if self.updates_enabled:
|
||||
self.host_tensor_sharer.set_shared_memory_base_name(
|
||||
shared_memory_base_name)
|
||||
|
||||
def _add_weight_slot(self, slot_id: int, name: str,
|
||||
weight_slot: _tbr.MoeWeight):
|
||||
@ -201,20 +340,21 @@ class SingleLayerMoeLoadBalancer:
|
||||
name: The name of the weight
|
||||
weight_slot: The weight object
|
||||
"""
|
||||
self.single_layer_load_balancer_impl.add_weight_slot(
|
||||
self.single_layer_load_balancer_impl.add_single_weight_slot(
|
||||
slot_id, name, weight_slot)
|
||||
|
||||
def register_weight_slot(self, slot_id: int, name: str, t: torch.Tensor):
|
||||
def register_weight_slot(self, local_slot_id: int, name: str,
|
||||
t: torch.Tensor):
|
||||
"""
|
||||
Register a weight slot to the layer.
|
||||
|
||||
Args:
|
||||
slot_id: The ID of the slot
|
||||
local_slot_id: The ID of the slot at local rank
|
||||
name: The name of the weight
|
||||
t: The weight tensor
|
||||
"""
|
||||
moe_weight = _tensor_to_weight(t)
|
||||
self._add_weight_slot(slot_id, name, moe_weight)
|
||||
self._add_weight_slot(local_slot_id, name, moe_weight)
|
||||
|
||||
def _add_host_weight(self, expert_id: int, name: str,
|
||||
host_weight: _tbr.MoeWeight):
|
||||
@ -226,7 +366,7 @@ class SingleLayerMoeLoadBalancer:
|
||||
name: The name of the weight
|
||||
host_weight: The host weight object
|
||||
"""
|
||||
self.single_layer_load_balancer_impl.add_host_weight(
|
||||
self.single_layer_load_balancer_impl.add_single_host_weight(
|
||||
expert_id, name, host_weight)
|
||||
|
||||
def _add_host_weight_from_tensor(self, expert_id: int, name: str,
|
||||
@ -248,46 +388,128 @@ class SingleLayerMoeLoadBalancer:
|
||||
self.single_layer_load_balancer_impl.set_initial_weight_assignments(
|
||||
initial_weight_assignments)
|
||||
|
||||
def add_to_fix_weight_fn(self,
|
||||
fn: Callable,
|
||||
args: Tuple,
|
||||
kwargs: Dict = {}):
|
||||
self.to_fix_weight_fns.append((fn, args, kwargs))
|
||||
|
||||
def add_register_weight_fn(self,
|
||||
fn: Callable,
|
||||
args: Tuple,
|
||||
kwargs: Dict = {}):
|
||||
"""
|
||||
Add a weight registration function. This does not run fn directly; all registered functions are run after model.to("cuda"),
so this function can be called while the model is not on the GPU yet.
|
||||
"""
|
||||
self.register_weight_fns.append((fn, args, kwargs))
|
||||
|
||||
def fix_tensor(self, wt: torch.Tensor):
|
||||
torch.ops.trtllm.migrate_to_managed(wt)
|
||||
|
||||
def register_weight_slots_after_to_cuda(self):
|
||||
"""
|
||||
Register weights after the model has been moved to CUDA; should be invoked after model.to("cuda") and before finalize_model.
|
||||
"""
|
||||
for fn, args, kwargs in self.to_fix_weight_fns:
|
||||
fn(*args, **kwargs)
|
||||
|
||||
self.to_fix_weight_fns = []
|
||||
|
||||
for fn, args, kwargs in self.register_weight_fns:
|
||||
fn(*args, **kwargs)
|
||||
|
||||
self.register_weight_fns = []
|
||||
|
||||
def py_finalize_model(self):
|
||||
"""
|
||||
Finalize the model after all layers have been added.
|
||||
This must be called before starting any iterations.
|
||||
"""
|
||||
self.host_tensor_sharer.finalize_host_tensor_sharing(
|
||||
self._add_host_weight_from_tensor)
|
||||
if self.updates_enabled:
|
||||
self.host_tensor_sharer.finalize_host_tensor_sharing(
|
||||
self._add_host_weight_from_tensor)
|
||||
|
||||
def wait_for_gpu_stage(self) -> torch.Tensor:
|
||||
def wait_for_gpu_stage(self) -> Optional[torch.Tensor]:
|
||||
"""
|
||||
Wait for the GPU stage to complete.
|
||||
|
||||
Returns:
|
||||
A tensor indicating whether the stage is enabled, or None when weight updates are disabled
|
||||
"""
|
||||
return torch.ops.trtllm.moe_load_balance_wait_gpu_stage(
|
||||
self.single_layer_load_balancer_ptr)
|
||||
if self.updates_enabled:
|
||||
assert self.statistic_flag_tensor is None, \
|
||||
"Already has statistic_flag_tensor, should not wait."
|
||||
if is_graph_capturing():
|
||||
self.cudagraph_event = torch.cuda.Event()
|
||||
self.cudagraph_stream = torch.cuda.Stream()
|
||||
current_stream_event = torch.cuda.Event()
|
||||
current_stream_event.record(torch.cuda.current_stream())
|
||||
with torch.cuda.stream(self.cudagraph_stream):
|
||||
current_stream_event.wait()
|
||||
self.statistic_flag_tensor = torch.ops.trtllm.moe_load_balance_wait_gpu_stage(
|
||||
self.single_layer_load_balancer_ptr)
|
||||
self.cudagraph_event.record(self.cudagraph_stream)
|
||||
else:
|
||||
self.statistic_flag_tensor = torch.ops.trtllm.moe_load_balance_wait_gpu_stage(
|
||||
self.single_layer_load_balancer_ptr)
|
||||
return self.statistic_flag_tensor
|
||||
else:
|
||||
return
|
||||
|
||||
def maybe_cudagraph_done_wait(self):
|
||||
if self.updates_enabled:
|
||||
if is_graph_capturing():
|
||||
assert self.cudagraph_event is not None, "should have cudagraph_event when capturing"
|
||||
assert self.cudagraph_stream is not None, "should have cudagraph_stream when capturing"
|
||||
self.cudagraph_event.wait()
|
||||
|
||||
def set_cpu_stage(self):
|
||||
"""
|
||||
Set the CPU stage.
|
||||
"""
|
||||
torch.ops.trtllm.moe_load_balance_set_cpu_stage(
|
||||
self.single_layer_load_balancer_ptr)
|
||||
if self.updates_enabled:
|
||||
assert self.statistic_flag_tensor is not None, \
|
||||
"Doesn't have statistic_flag_tensor, should not set_cpu_stage."
|
||||
self.statistic_flag_tensor = None
|
||||
if is_graph_capturing():
|
||||
assert self.cudagraph_stream is not None, "Doesn't have cudagraph_stream, should not set_cpu_stage."
|
||||
current_stream_event = torch.cuda.Event()
|
||||
current_stream_event.record(torch.cuda.current_stream())
|
||||
with torch.cuda.stream(self.cudagraph_stream):
|
||||
current_stream_event.wait()
|
||||
torch.ops.trtllm.moe_load_balance_set_cpu_stage(
|
||||
self.single_layer_load_balancer_ptr)
|
||||
self.cudagraph_event.record(self.cudagraph_stream)
|
||||
else:
|
||||
torch.ops.trtllm.moe_load_balance_set_cpu_stage(
|
||||
self.single_layer_load_balancer_ptr)
|
||||
|
||||
def maybe_cudagraph_done_set_cpu_stage(self):
|
||||
if self.updates_enabled:
|
||||
if is_graph_capturing():
|
||||
assert self.cudagraph_event is not None, "should have cudagraph_event when capturing"
|
||||
assert self.cudagraph_stream is not None, "should have cudagraph_stream when capturing"
|
||||
self.cudagraph_event.wait()
|
||||
self.cudagraph_stream = None
|
||||
self.cudagraph_event = None
|
||||
|
||||
def statistic(self, gathered_raw_expert_ids: torch.Tensor,
|
||||
enabled: torch.Tensor, is_first_stage: bool,
|
||||
is_last_stage: bool):
|
||||
is_first_stage: bool, is_last_stage: bool):
|
||||
"""
|
||||
Perform statistics on the expert IDs.
|
||||
|
||||
Args:
|
||||
gathered_raw_expert_ids: The gathered raw expert IDs from all ranks
|
||||
enabled: A tensor indicating whether the operation is enabled
|
||||
is_first_stage: Whether this is the first stage
|
||||
is_last_stage: Whether this is the last stage
|
||||
"""
|
||||
torch.ops.trtllm.moe_load_balance_statistic(
|
||||
gathered_raw_expert_ids, enabled,
|
||||
self.single_layer_load_balancer_ptr, is_first_stage, is_last_stage)
|
||||
if self.updates_enabled:
|
||||
assert isinstance(self.statistic_flag_tensor, torch.Tensor)
|
||||
torch.ops.trtllm.moe_load_balance_statistic(
|
||||
gathered_raw_expert_ids, self.statistic_flag_tensor,
|
||||
self.single_layer_load_balancer_ptr, is_first_stage,
|
||||
is_last_stage)
|
||||
|
||||
def route(self,
|
||||
token_selected_experts: torch.Tensor,
|
||||
@ -310,7 +532,15 @@ class SingleLayerMoeLoadBalancer:
|
||||
"""
|
||||
Clean up the resources before C++ shutdown and barrier
|
||||
"""
|
||||
self.host_tensor_sharer.pre_shutdown_cleanup()
|
||||
if self.updates_enabled:
|
||||
self.host_tensor_sharer.pre_shutdown_cleanup()
|
||||
|
||||
def py_post_shutdown_cleanup(self):
|
||||
"""
|
||||
Clean up the resources after C++ shutdown and barrier
|
||||
"""
|
||||
if self.updates_enabled:
|
||||
self.host_tensor_sharer.post_shutdown_cleanup()
|
||||
|
||||
|
||||
# Global variable to store the current active MoeLoadBalancer instance
|
||||
@ -346,6 +576,21 @@ class MoeLoadBalancer:
|
||||
self.single_layer_load_balancers = []
|
||||
self.shared_memory_base_name = shared_memory_base_name
|
||||
self._setup_mpi_comm()
|
||||
self.is_shutdown = False
|
||||
|
||||
self.iter_id = 0
|
||||
self.in_iter = False
|
||||
|
||||
self.enable_statistic = False
|
||||
self.enable_update_weights = False
|
||||
|
||||
def __del__(self):
|
||||
if not self.is_shutdown:
|
||||
self.shutdown()
|
||||
|
||||
def is_static_routing(self):
|
||||
# if we don't update, then it is statistic routing.
|
||||
return self.layer_updates_per_iter == 0
|
||||
|
||||
def _setup_mpi_comm(self):
|
||||
global_mpi_comm = tensorrt_llm.mpi_comm()
|
||||
@ -357,6 +602,9 @@ class MoeLoadBalancer:
|
||||
f"Interesting, shared size {shared_size} is not same as local size {local_size}"
|
||||
self.shared_mpi_comm = shared_mpi_comm
|
||||
|
||||
def set_use_gpu_memcpy(self, use_gpu_memcpy: bool):
|
||||
self.load_balancer_impl.set_use_gpu_memcpy(use_gpu_memcpy)
|
||||
|
||||
def add_layer(self, expert_count: int, top_k: int,
|
||||
slot_count_per_rank: int) -> SingleLayerMoeLoadBalancer:
|
||||
"""
|
||||
@ -372,13 +620,24 @@ class MoeLoadBalancer:
|
||||
"""
|
||||
single_layer_load_balancer_impl = self.load_balancer_impl.add_layer(
|
||||
expert_count, top_k, slot_count_per_rank)
|
||||
updates_enabled = not self.is_static_routing()
|
||||
single_layer_load_balancer = SingleLayerMoeLoadBalancer(
|
||||
single_layer_load_balancer_impl, self.shared_mpi_comm)
|
||||
single_layer_load_balancer_impl,
|
||||
self.shared_mpi_comm,
|
||||
expert_count,
|
||||
updates_enabled=updates_enabled)
|
||||
single_layer_load_balancer.set_shared_memory_base_name(
|
||||
self.shared_memory_base_name)
|
||||
self.single_layer_load_balancers.append(single_layer_load_balancer)
|
||||
return single_layer_load_balancer
|
||||
|
||||
def register_weight_slots_after_to_cuda(self):
|
||||
"""
|
||||
Register weights after model has been moved to cuda, should be invoked after model.to("cuda") and before finalize_model.
|
||||
"""
|
||||
for layer in self.single_layer_load_balancers:
|
||||
layer.register_weight_slots_after_to_cuda()
|
||||
|
||||
def finalize_model(self):
|
||||
"""
|
||||
Finalize the model after all layers have been added.
|
||||
@ -391,6 +650,7 @@ class MoeLoadBalancer:
|
||||
for single_layer_load_balancer in self.single_layer_load_balancers:
|
||||
single_layer_load_balancer.py_finalize_model()
|
||||
self.load_balancer_impl.finalize_model()
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
def set_warm_up_iter_count(self, iter_count: int):
|
||||
"""
|
||||
@ -401,27 +661,33 @@ class MoeLoadBalancer:
|
||||
"""
|
||||
self.load_balancer_impl.set_warm_up_iter_count(iter_count)
|
||||
|
||||
def start_iter(self, iter_id: int, enable_statistic: bool,
|
||||
enable_update_weights: bool):
|
||||
def set_next_iter_info(self, enable_statistic: Optional[bool],
|
||||
enable_update_weights: Optional[bool]):
|
||||
if enable_statistic is not None:
|
||||
self.enable_statistic = enable_statistic
|
||||
if enable_update_weights is not None:
|
||||
self.enable_update_weights = enable_update_weights
|
||||
|
||||
def start_iter(self):
|
||||
"""
|
||||
Start a new iteration.
|
||||
|
||||
Args:
|
||||
iter_id: The ID of the iteration
|
||||
enable_statistic: Whether to enable statistics collection
|
||||
enable_update_weights: Whether to enable weight updates
|
||||
"""
|
||||
self.load_balancer_impl.start_iter(iter_id, enable_statistic,
|
||||
enable_update_weights)
|
||||
assert self.in_iter == False, "already in forward"
|
||||
self.in_iter = True
|
||||
self.load_balancer_impl.start_iter(self.iter_id, self.enable_statistic,
|
||||
self.enable_update_weights)
|
||||
|
||||
def end_iter(self, iter_id: int):
|
||||
def end_iter(self):
|
||||
"""
|
||||
End the current iteration.
|
||||
|
||||
Args:
|
||||
iter_id: The ID of the iteration to end
|
||||
"""
|
||||
self.load_balancer_impl.end_iter(iter_id)
|
||||
assert self.in_iter, "not in forward, cannot end_iter"
|
||||
self.load_balancer_impl.end_iter(self.iter_id)
|
||||
self.in_iter = False
|
||||
self.iter_id += 1
|
||||
|
||||
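The driver-side change implied here, as a small illustrative sketch (`forward_fn` is a placeholder): the iteration id and the statistic/update flags now live on the balancer, so the per-iteration protocol shrinks to latching the flags once and calling a parameterless start/end pair.

def run_one_iteration(balancer, forward_fn) -> None:
    """Illustrative per-iteration driver; `balancer` is a MoeLoadBalancer."""
    # Old API: balancer.start_iter(iter_id, True, True) ... balancer.end_iter(iter_id)
    # New API: flags are latched once and the balancer tracks iter_id itself.
    balancer.set_next_iter_info(True, True)
    balancer.start_iter()
    forward_fn()
    balancer.end_iter()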
    def shutdown(self):
        """
@ -432,6 +698,10 @@ class MoeLoadBalancer:
        self.load_balancer_impl.shutdown()
        # use this sync to make sure all the shm resources can be cleaned up
        self.shared_mpi_comm.barrier()
        for single_layer_load_balancer in self.single_layer_load_balancers:
            single_layer_load_balancer.py_post_shutdown_cleanup()
        self.shared_mpi_comm.barrier()
        self.is_shutdown = True

    def __repr__(self):
        """
@ -482,6 +752,80 @@ class MoeLoadBalancer:
        return False


moe_model_arch_list = [
    'DeepseekV3ForCausalLM',
]


def maybe_create_moe_load_balancer(
        model_config, mapping: Optional[Mapping]) -> Optional[MoeLoadBalancer]:
    ep_rank = model_config.mapping.moe_ep_rank
    ep_size = model_config.mapping.moe_ep_size
    model_arch = model_config.pretrained_config.architectures[0]
    using_ep = mapping and mapping.moe_ep_size > 1
    in_supported_model_arch = model_arch in moe_model_arch_list
    using_smart_router = mapping and mapping.moe_cluster_size > 1
    moe_load_balancer = nullcontext()
    if in_supported_model_arch and using_ep and not using_smart_router and model_config.moe_load_balancer is not None:
        model_config.moe_load_balancer.setup(ep_rank=ep_rank, ep_size=ep_size)
        if model_config.moe_load_balancer.layer_updates_per_iter > 0:
            # TODO: remove this when supported.
            cpu_arch = platform.machine().lower()
            assert cpu_arch == 'aarch64', "online load balancer only support aarch64, e.g. GB200 now, x86 coming soon."

        moe_load_balancer = MoeLoadBalancer(
            ep_rank=ep_rank,
            ep_size=ep_size,
            layer_updates_per_iter=model_config.moe_load_balancer.
            layer_updates_per_iter)
        logger.info(
            f"Created MoE LoadBalancer, layer_updates_per_iter={model_config.moe_load_balancer.layer_updates_per_iter}..."
        )
    return moe_load_balancer
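A short usage sketch, assuming the names defined above are in scope (`build_model_fn` is a stand-in for real model construction): because the helper returns either a MoeLoadBalancer or a nullcontext(), call sites can always treat it as a context manager, which is how the model_engine.py hunk later in this diff uses it.

def build_with_optional_balancer(model_config, mapping, build_model_fn):
    """Illustrative wrapper around model construction."""
    with maybe_create_moe_load_balancer(model_config, mapping) as moe_load_balancer:
        model = build_model_fn()
        if isinstance(moe_load_balancer, MoeLoadBalancer):
            # Only a real balancer (EP enabled, supported arch) needs the
            # post-.to("cuda") registration and finalization steps.
            moe_load_balancer.register_weight_slots_after_to_cuda()
            moe_load_balancer.finalize_model()
    return model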
class MoeLoadBalancerIterContext:

    def __init__(self,
                 moe_load_balancer: Optional[MoeLoadBalancer],
                 enable_statistic: Optional[bool] = None,
                 enable_updates: Optional[bool] = None):
        self.moe_load_balancer = moe_load_balancer
        self.enable_statistic = enable_statistic
        self.enable_updates = enable_updates

    def __enter__(self):
        """
        Enter the context manager.

        Returns:
            The MoeLoadBalancerIterContext instance
        """
        if self.moe_load_balancer is not None and not self.moe_load_balancer.is_static_routing(
        ):
            self.moe_load_balancer.set_next_iter_info(self.enable_statistic,
                                                      self.enable_updates)
            self.moe_load_balancer.start_iter()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """
        Exit the context manager.

        Args:
            exc_type: The exception type
            exc_val: The exception value
            exc_tb: The exception traceback

        Returns:
            False to not suppress exceptions
        """
        if self.moe_load_balancer is not None and not self.moe_load_balancer.is_static_routing(
        ):
            self.moe_load_balancer.end_iter()
        return False
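A minimal usage sketch, mirroring the model_engine.py hunks later in this diff (`forward_fn` is illustrative): one balancer iteration brackets one forward step, and a None or static-routing balancer turns the context into a no-op.

def forward_with_balancer(moe_load_balancer, forward_fn):
    """Illustrative: one balancer iteration per model forward step."""
    with MoeLoadBalancerIterContext(moe_load_balancer):
        return forward_fn()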
def get_moe_load_balancer() -> Optional[MoeLoadBalancer]:
    """
    Get the current active MoeLoadBalancer instance.

@ -61,6 +61,14 @@ class FusedMoEMethodBase(ABC):
    Base class for all fused MoE methods.
    """

    def need_load_shared_weights(self, module):
        if hasattr(
                module, "layer_load_balancer"
        ) and module.layer_load_balancer and module.layer_load_balancer.need_load_shared_weights(
        ):
            return True
        return False

    def create_weights(self, module: torch.nn.Module, weight_dtype: torch.dtype,
                       w3_w1_weight_shape: tuple[int, int, int],
                       w2_weight_shape: tuple[int, int, int]):
@ -76,12 +84,15 @@ class FusedMoEMethodBase(ABC):
                                 requires_grad=False)
        module.register_parameter("w2_weight", w2_weight)

    def load_weights(self, module: torch.nn.Module, weights: List[Dict],
                     weight_loading_mode: MoEWeightLoadingMode):
    def load_expert_weights_to_dst(self, module: torch.nn.Module,
                                   weights: List[Dict],
                                   weight_loading_mode: MoEWeightLoadingMode,
                                   load_expert_ids: List[int],
                                   dst_w3_w1_weights_tensor: torch.Tensor,
                                   dst_w2_weights_tensor: torch.Tensor):
        # Multithread weight load is superseded by prefetch_files() in model_engine.py
        # Also, threading adds overhead in order to protect shuffle index cache with critical section.
        for local_slot_id, expert_id in enumerate(
                module.initial_local_expert_ids):
        for local_slot_id, expert_id in enumerate(load_expert_ids):
            # expert_idx is the local slot index of current rank
            expert_idx = local_slot_id

@ -101,15 +112,55 @@ class FusedMoEMethodBase(ABC):
                )

            self.load_expert_w3_w1_weight(module, w1_weight, w3_weight,
                                          module.w3_w1_weight.data[expert_idx])
                                          dst_w3_w1_weights_tensor[expert_idx])

            self.load_expert_w2_weight(module, w2_weight,
                                       module.w2_weight.data[expert_idx])
                                       dst_w2_weights_tensor[expert_idx])

    def load_weights(self, module: torch.nn.Module, weights: List[Dict],
                     weight_loading_mode: MoEWeightLoadingMode):

        self.load_expert_weights_to_dst(module, weights, weight_loading_mode,
                                        module.initial_local_expert_ids,
                                        module.w3_w1_weight.data,
                                        module.w2_weight.data)

        self.load_quant_scales(module, weights)
        # Re-setup quant scales after loading weights as the tensors may have been modified.
        self.setup_quant_scales(module)

        if self.need_load_shared_weights(module):
            local_shared_load_expert_ids = module.layer_load_balancer.get_load_expert_ids(
            )
            local_shared_w3_w1_tensors = torch.empty(
                (len(local_shared_load_expert_ids), ) +
                module.w3_w1_weight.data.shape[1:],
                dtype=module.w3_w1_weight.data.dtype,
                device='cpu')
            local_shared_w2_tensors = torch.empty(
                (len(local_shared_load_expert_ids), ) +
                module.w2_weight.data.shape[1:],
                dtype=module.w2_weight.data.dtype,
                device='cpu')
            self.load_expert_weights_to_dst(module, weights,
                                            weight_loading_mode,
                                            local_shared_load_expert_ids,
                                            local_shared_w3_w1_tensors,
                                            local_shared_w2_tensors)
            module.register_all_parameter_slot_and_to_fix_weight_fns({
                'w3_w1_weight':
                local_shared_w3_w1_tensors,
                'w2_weight':
                local_shared_w2_tensors
            })
            module.layer_load_balancer.host_tensor_sharer.finalize_layer_weights(
            )

        if hasattr(module,
                   "layer_load_balancer") and module.layer_load_balancer:
            module.layer_load_balancer.set_initial_weight_assignments(
                module.initial_global_assignments)

    def load_quant_scales(self, module: torch.nn.Module, weights: List[Dict]):
        pass
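Conceptually, load_weights now fills this rank's expert slots on the GPU as before and, when the layer participates in online rebalancing, loads its assigned share of experts a second time into plain CPU tensors that back the host-shared copies. A toy, self-contained sketch of that "same loader, different destination" pattern (not TensorRT-LLM API):

import torch


def load_experts_to_dst(weights, expert_ids, dst):
    # Toy illustration: the same per-expert loader can target either the
    # module's GPU parameter or a CPU staging buffer for host-shared copies.
    for slot, expert_id in enumerate(expert_ids):
        dst[slot].copy_(weights[expert_id])


num_experts, hidden = 8, 16
ckpt = {e: torch.randn(hidden) for e in range(num_experts)}

gpu_param = torch.empty(4, hidden)    # this rank's 4 local slots
host_shared = torch.empty(2, hidden)  # extra experts kept on the host

load_experts_to_dst(ckpt, [0, 1, 2, 3], gpu_param)
load_experts_to_dst(ckpt, [6, 7], host_shared)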
@ -828,6 +879,50 @@ class NVFP4FusedMoEMethod(FusedMoEMethodBase):
        w2_weight_scale_2 = 1.0 / w2_weight_scale_2[...].reshape([])
        dst_w2_alpha.copy_(1.0 / (final_fc2_input_scale * w2_weight_scale_2))

    def load_all_fp4_weight_scales_and_alphas(
            self, module: torch.nn.Module, weights: Dict,
            load_expert_ids: List[int], dst_w3_w1_weight_scale: torch.Tensor,
            dst_w2_weight_scale: torch.Tensor, dst_fc31_alpha: torch.Tensor,
            dst_fc2_alpha: torch.Tensor):
        for local_slot_id, expert_id in enumerate(load_expert_ids):
            if module.weight_loading_mode == MoEWeightLoadingMode.VANILLA:
                w1_weight_scale = weights[f"{expert_id}.w1.weight_scale"]
                w3_weight_scale = weights[f"{expert_id}.w3.weight_scale"]
                w2_weight_scale = weights[f"{expert_id}.w2.weight_scale"]
                w1_weight_scale_2 = weights[f"{expert_id}.w1.weight_scale_2"]
                w3_weight_scale_2 = weights[f"{expert_id}.w3.weight_scale_2"]
                w2_weight_scale_2 = weights[f"{expert_id}.w2.weight_scale_2"]
            elif module.weight_loading_mode == MoEWeightLoadingMode.FUSED_GATE_UP_PROJ:
                w1_w3_weight_scale = weights["gate_up_proj_weight_scale"][
                    expert_id].transpose(0, 1).contiguous()
                w1_weight_scale, w3_weight_scale = w1_w3_weight_scale.chunk(
                    2, dim=0)
                w2_weight_scale = weights["down_proj_weight_scale"][
                    expert_id].transpose(0, 1).contiguous()
                w1_weight_scale_2 = weights["gate_up_proj_weight_scale_2"]
                w3_weight_scale_2 = weights["gate_up_proj_weight_scale_2"]
                w2_weight_scale_2 = weights["down_proj_weight_scale_2"]
            else:
                raise NotImplementedError(
                    f"Unknown weight loading mode in MoE: {module.weight_loading_mode}"
                )

            expert_idx = local_slot_id

            self.load_expert_w3_w1_weight_scale_nvfp4(
                module, w1_weight_scale, w3_weight_scale,
                dst_w3_w1_weight_scale[expert_idx])
            self.load_expert_w2_weight_scale_nvfp4(
                module, w2_weight_scale, dst_w2_weight_scale[expert_idx])

            self.load_expert_fc31_alpha_nvfp4(w1_weight_scale_2,
                                              w3_weight_scale_2,
                                              module.fc31_input_scale.data,
                                              dst_fc31_alpha[expert_idx])
            self.load_expert_fc2_alpha_nvfp4(w2_weight_scale_2,
                                             module.fc2_input_scale.data,
                                             dst_fc2_alpha[expert_idx])

    def load_quant_scales(self, module: torch.nn.Module, weights: Dict):
        # Step1: Load input scales.
        tmp_fc31_input_scale = torch.empty(module.num_experts,
@ -862,46 +957,50 @@ class NVFP4FusedMoEMethod(FusedMoEMethodBase):
            tmp_fc2_input_scale.max().reciprocal())

        # Step2: Load weight block scales and alphas.
        for local_slot_id, expert_id in enumerate(
                module.initial_local_expert_ids):
            if module.weight_loading_mode == MoEWeightLoadingMode.VANILLA:
                w1_weight_scale = weights[f"{expert_id}.w1.weight_scale"]
                w3_weight_scale = weights[f"{expert_id}.w3.weight_scale"]
                w2_weight_scale = weights[f"{expert_id}.w2.weight_scale"]
                w1_weight_scale_2 = weights[f"{expert_id}.w1.weight_scale_2"]
                w3_weight_scale_2 = weights[f"{expert_id}.w3.weight_scale_2"]
                w2_weight_scale_2 = weights[f"{expert_id}.w2.weight_scale_2"]
            elif module.weight_loading_mode == MoEWeightLoadingMode.FUSED_GATE_UP_PROJ:
                w1_w3_weight_scale = weights["gate_up_proj_weight_scale"][
                    expert_id].transpose(0, 1).contiguous()
                w1_weight_scale, w3_weight_scale = w1_w3_weight_scale.chunk(
                    2, dim=0)
                w2_weight_scale = weights["down_proj_weight_scale"][
                    expert_id].transpose(0, 1).contiguous()
                w1_weight_scale_2 = weights["gate_up_proj_weight_scale_2"]
                w3_weight_scale_2 = weights["gate_up_proj_weight_scale_2"]
                w2_weight_scale_2 = weights["down_proj_weight_scale_2"]
            else:
                raise NotImplementedError(
                    f"Unknown weight loading mode in MoE: {module.weight_loading_mode}"
                )
        self.load_all_fp4_weight_scales_and_alphas(
            module, weights, module.initial_local_expert_ids,
            module.w3_w1_weight_scale.data, module.w2_weight_scale.data,
            module.fc31_alpha.data, module.fc2_alpha.data)

            expert_idx = local_slot_id
        # Step 3: if need load into shared
        if self.need_load_shared_weights(module):
            local_shared_load_expert_ids = module.layer_load_balancer.get_load_expert_ids(
            )
            local_shared_w3_w1_scale_tensors = torch.empty(
                (len(local_shared_load_expert_ids), ) +
                module.w3_w1_weight_scale.data.shape[1:],
                dtype=module.w3_w1_weight_scale.data.dtype,
                device='cpu')
            local_shared_w2_scale_tensors = torch.empty(
                (len(local_shared_load_expert_ids), ) +
                module.w2_weight_scale.data.shape[1:],
                dtype=module.w2_weight_scale.data.dtype,
                device='cpu')
            local_shared_fc31_alpha_tensors = torch.empty(
                (len(local_shared_load_expert_ids), ) +
                module.fc31_alpha.data.shape[1:],
                dtype=module.fc31_alpha.data.dtype,
                device='cpu')
            local_shared_fc2_alpha_tensors = torch.empty(
                (len(local_shared_load_expert_ids), ) +
                module.fc2_alpha.data.shape[1:],
                dtype=module.fc2_alpha.data.dtype,
                device='cpu')
            self.load_all_fp4_weight_scales_and_alphas(
                module, weights, local_shared_load_expert_ids,
                local_shared_w3_w1_scale_tensors, local_shared_w2_scale_tensors,
                local_shared_fc31_alpha_tensors, local_shared_fc2_alpha_tensors)

            self.load_expert_w3_w1_weight_scale_nvfp4(
                module, w1_weight_scale, w3_weight_scale,
                module.w3_w1_weight_scale.data[expert_idx])
            self.load_expert_w2_weight_scale_nvfp4(
                module, w2_weight_scale,
                module.w2_weight_scale.data[expert_idx])

            self.load_expert_fc31_alpha_nvfp4(
                w1_weight_scale_2, w3_weight_scale_2,
                module.fc31_input_scale.data,
                module.fc31_alpha.data[expert_idx])
            self.load_expert_fc2_alpha_nvfp4(w2_weight_scale_2,
                                             module.fc2_input_scale.data,
                                             module.fc2_alpha.data[expert_idx])
            module.register_all_parameter_slot_and_to_fix_weight_fns({
                'w3_w1_weight_scale':
                local_shared_w3_w1_scale_tensors,
                'w2_weight_scale':
                local_shared_w2_scale_tensors,
                'fc31_alpha':
                local_shared_fc31_alpha_tensors,
                'fc2_alpha':
                local_shared_fc2_alpha_tensors,
            })

    def setup_quant_scales(self, module: torch.nn.Module):
        module.quant_scales = FusedMoEQuantScalesNVFP4(

@ -1,5 +1,6 @@
import bisect
import contextlib
import functools
import gc
import glob
import inspect
@ -11,6 +12,7 @@ import traceback
import weakref
from abc import ABC, abstractmethod
from collections import defaultdict
from contextlib import contextmanager
from typing import Any, Dict, List, Optional, Tuple

import psutil
@ -47,6 +49,8 @@ from ..model_config import ModelConfig, MoeLoadBalancerConfig
from ..models import AutoModelForCausalLM
from ..models.modeling_utils import (DecoderModelForCausalLM, MetaInitMode,
                                     timing)
from ..modules.fused_moe.moe_load_balancer import (
    MoeLoadBalancer, MoeLoadBalancerIterContext, maybe_create_moe_load_balancer)
from ..speculative import SpecConfig, SpecMetadata, get_spec_metadata
from ..utils import (get_model_extra_attrs, set_torch_compiling,
                     with_model_extra_attrs)
@ -323,6 +327,8 @@ class PyTorchModelEngine(ModelEngine):
        # py_executor.py for how this is used.
        self.last_spec_metadata = None

        self.in_warmup = False

        self.attn_runtime_features = attn_runtime_features or AttentionRuntimeFeatures(
        )

@ -470,6 +476,25 @@ class PyTorchModelEngine(ModelEngine):
            hidden_size=self.model.config.hidden_size,
            dtype=torch_dtype_to_str(self.model.config.torch_dtype))

    @contextmanager
    def set_warmup_flag(self):
        self.in_warmup = True
        try:
            yield
        finally:
            self.in_warmup = False

    @staticmethod
    def with_warmup_flag(method):

        @functools.wraps(method)
        def wrapper(self, *args, **kwargs):
            with self.set_warmup_flag():
                return method(self, *args, **kwargs)

        return wrapper

    @with_warmup_flag
    def warmup(self, resource_manager: ResourceManager) -> None:
        kv_cache_manager = resource_manager.get_resource_manager(
            self.kv_cache_manager_key)
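The in_warmup flag added above feeds the MoE load-balancer hunks that follow: statistics and weight updates are only requested for real iterations, not warmup or CUDA-graph capture passes. A minimal, illustrative guard (names such as `engine` are placeholders):

def prepare_moe_balancer_for_step(engine, moe_load_balancer) -> None:
    # Illustrative: mirrors the guard added in the forward path below.
    # Warmup passes should not perturb the balancer's statistics or trigger
    # expert weight redistribution, so flags are only latched outside warmup.
    if moe_load_balancer is not None and not engine.in_warmup:
        moe_load_balancer.set_next_iter_info(True, True)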
@ -962,7 +987,8 @@ class PyTorchModelEngine(ModelEngine):
                getattr(config.pretrained_config,
                        sub_config).num_hidden_layers = num_layers

        with timing("Model init total"):
        with timing("Model init total"), maybe_create_moe_load_balancer(
                config, self.mapping) as moe_load_balancer:
            try:
                with MetaInitMode():
                    model = AutoModelForCausalLM.from_config(config)
@ -1017,6 +1043,13 @@ class PyTorchModelEngine(ModelEngine):
                raise NotImplementedError(
                    f"No load support for load format: {load_format}")

            if isinstance(moe_load_balancer, MoeLoadBalancer):
                setattr(self, "moe_load_balancer", moe_load_balancer)
                moe_load_balancer.register_weight_slots_after_to_cuda()
                logger.info("moe_load_balancer finalizing model...")
                moe_load_balancer.finalize_model()
                logger.info("moe_load_balancer finalize model done")

            torch.cuda.current_stream().synchronize()
            return model

@ -1950,6 +1983,15 @@ class PyTorchModelEngine(ModelEngine):
        else:
            spec_metadata = None

        moe_load_balancer = None
        if hasattr(self, 'moe_load_balancer'):
            moe_load_balancer = getattr(self, 'moe_load_balancer')
            if not self.in_warmup:
                moe_enable_statistic = True
                moe_enable_update = True
                moe_load_balancer.set_next_iter_info(moe_enable_statistic,
                                                     moe_enable_update)

        if kv_cache_manager is None:
            inputs, gather_ids = self._prepare_tp_inputs_no_cache(
                scheduled_requests, attn_metadata, spec_metadata)
@ -1957,7 +1999,9 @@ class PyTorchModelEngine(ModelEngine):
            inputs.update(extra_model_inputs)
            self.last_spec_metadata = spec_metadata

            return self._forward_step(inputs, gather_ids, gather_context_logits)
            with MoeLoadBalancerIterContext(moe_load_balancer):
                return self._forward_step(inputs, gather_ids,
                                          gather_context_logits)

        with self._maybe_pad_batch(scheduled_requests,
                                   kv_cache_manager) as scheduled_requests:
@ -1984,21 +2028,32 @@ class PyTorchModelEngine(ModelEngine):
            self.iter_counter += 1

            if maybe_graph is None:
                outputs = self._forward_step(inputs, gather_ids,
                                             gather_context_logits)
                with MoeLoadBalancerIterContext(moe_load_balancer):
                    outputs = self._forward_step(inputs, gather_ids,
                                                 gather_context_logits)
            else:
                if maybe_graph.needs_capture():

                    def capture_forward_fn(inputs: Dict[str, Any]):
                        with MoeLoadBalancerIterContext(moe_load_balancer):
                            return self._forward_step(
                                inputs,
                                gather_ids=gather_ids,
                                gather_context_logits=gather_context_logits)

                    pool = maybe_graph.capture(
                        lambda inputs: self._forward_step(
                            inputs,
                            gather_ids=gather_ids,
                            gather_context_logits=gather_context_logits),
                        capture_forward_fn,
                        self._cuda_graph_mem_pool,
                        extra_model_inputs,
                    )
                    self._cuda_graph_mem_pool = pool

                    outputs = maybe_graph.run(inputs, extra_model_inputs)
                    # here we don't need to use context since cuda graph capture didn't run kernel.
                    # maybe we need a cleaner way to do this.
                    outputs = maybe_graph.run(inputs, extra_model_inputs)
                else:
                    with MoeLoadBalancerIterContext(moe_load_balancer):
                        outputs = maybe_graph.run(inputs, extra_model_inputs)

            # Note: To overlap the CPU and GPU computation as much as possible,
            # guided_decoder.build should be called immediately after the launch of the single step;
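A condensed, illustrative sketch of how the forward path above pairs the iteration context with CUDA graphs (`graph`, `inputs`, and `forward_fn` are placeholders): the context is baked into the captured callable, the run immediately after capture is deliberately left unwrapped per the in-diff comment, and later replays each get their own context.

from typing import Any, Dict


def run_step(graph, inputs: Dict[str, Any], moe_load_balancer, forward_fn):
    # Illustrative only; mirrors the control flow of the hunk above.
    if graph.needs_capture():

        def capture_forward_fn(inputs: Dict[str, Any]):
            # The iteration context wraps the captured callable itself.
            with MoeLoadBalancerIterContext(moe_load_balancer):
                return forward_fn(inputs)

        graph.capture(capture_forward_fn)
        # Per the diff's own note: capture did not run the kernels, so this
        # first run is not wrapped in an iteration context.
        return graph.run(inputs)
    # Replays of an already-captured graph are each one balancer iteration.
    with MoeLoadBalancerIterContext(moe_load_balancer):
        return graph.run(inputs)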
@ -18,6 +18,7 @@ BASE_ZMQ_CLASSES = {
    "llmapi.run_llm_with_postproc": ["perform_faked_oai_postprocess"
                                     ],  # only used in tests
    ### starting import of torch models classes. They are used in test_llm_multi_gpu.py.
    "tensorrt_llm._torch.model_config": ["MoeLoadBalancerConfig"],
    "tensorrt_llm._torch.models.modeling_bert":
    ["BertForSequenceClassification"],
    "tensorrt_llm._torch.models.modeling_clip": ["CLIPVisionModel"],
@ -48,6 +49,7 @@ BASE_ZMQ_CLASSES = {
    "tensorrt_llm._torch.models.modeling_siglip": ["SiglipVisionModel"],
    "tensorrt_llm._torch.models.modeling_vila": ["VilaModel"],
    ### ending import of torch models classes
    "tensorrt_llm._torch.pyexecutor.config": ["PyTorchConfig", "LoadFormat"],
    "tensorrt_llm._torch.pyexecutor.llm_request":
    ["LogitsStorage", "PyResult", "LlmResult", "LlmResponse", "LogProbStorage"],
    "tensorrt_llm._torch.speculative.mtp": ["MTPConfig"],
@ -58,13 +60,13 @@ BASE_ZMQ_CLASSES = {
    ["ClusterInfo", "MathThroughput"],
    "tensorrt_llm._torch.pyexecutor.config": ["PyTorchConfig", "LoadFormat"],
    "tensorrt_llm.bindings.executor": [
        "BatchingType", "CapacitySchedulerPolicy", "ContextPhaseParams",
        "BatchingType", "CacheTransceiverConfig", "CapacitySchedulerPolicy",
        "ContextPhaseParams", "ContextChunkingPolicy", "DynamicBatchConfig",
        "ExecutorConfig", "ExtendedRuntimePerfKnobConfig", "Response", "Result",
        "FinishReason", "KvCacheConfig", "KvCacheTransferMode",
        "KvCacheRetentionConfig",
        "KvCacheRetentionConfig.TokenRangeRetentionConfig", "PeftCacheConfig",
        "SchedulerConfig", "DynamicBatchConfig", "ContextChunkingPolicy",
        "CacheTransceiverConfig"
        "SchedulerConfig"
    ],
    "tensorrt_llm._torch.pyexecutor.config": ["PyTorchConfig"],
    "tensorrt_llm._torch.model_config": ["MoeLoadBalancerConfig"],

@ -78,6 +78,11 @@ class TestHostMoeTensorSharer(unittest.TestCase):
        size = comm.Get_size()
        layer_id = 0

        # Test tensor parameters
        experts_per_rank = 2  # Each rank is responsible for 2 consecutive experts
        expert_count = size * experts_per_rank
        tensor_shape = (16, 32)  # Use 2D tensor for testing

        # Maximum supported ranks (can adjust as needed)
        max_ranks = 8
        if size > max_ranks:
@ -87,17 +92,12 @@ class TestHostMoeTensorSharer(unittest.TestCase):
        shared_comm = comm.Split_type(split_type=MPI.COMM_TYPE_SHARED)

        # Initialize HostMoeTensorSharer
        sharer = HostMoeTensorSharer(layer_id, shared_comm)
        sharer = HostMoeTensorSharer(layer_id, expert_count, shared_comm)

        # Set shared memory base name
        shared_memory_base_name = "test_host_sharer"
        sharer.set_shared_memory_base_name(shared_memory_base_name)

        # Test tensor parameters
        experts_per_rank = 2  # Each rank is responsible for 2 consecutive experts
        expert_count = size * experts_per_rank
        tensor_shape = (16, 32)  # Use 2D tensor for testing

        # Calculate the range of experts this rank is responsible for
        start_expert_id = rank * experts_per_rank
        end_expert_id = start_expert_id + experts_per_rank
@ -124,6 +124,8 @@ class TestHostMoeTensorSharer(unittest.TestCase):
            sharer.pre_register_host_tensor_with_shape(
                expert_id, "weight", torch.float32, tensor_shape)

        sharer.finalize_layer_weights()

        # Ensure all processes have created and registered their tensors
        comm.Barrier()


@ -2,10 +2,11 @@ import unittest
from unittest.mock import MagicMock, patch

import torch
from mpi4py import MPI

from tensorrt_llm._torch.modules.fused_moe.moe_load_balancer import (
    MoeLoadBalancer, SingleLayerMoeLoadBalancer, get_moe_load_balancer,
    moe_load_balancer_add_single_layer)
    MoeLoadBalancer, MoeLoadBalancerIterContext, SingleLayerMoeLoadBalancer,
    get_moe_load_balancer, moe_load_balancer_add_single_layer)


class TestMoeLoadBalancer(unittest.TestCase):
@ -186,7 +187,9 @@ class TestMoeLoadBalancer(unittest.TestCase):

        # Setup
        mock_single_layer_impl = MagicMock()
        layer = SingleLayerMoeLoadBalancer(mock_single_layer_impl, None)
        layer = SingleLayerMoeLoadBalancer(mock_single_layer_impl,
                                           MPI.COMM_WORLD,
                                           expert_count=4)

        # Mock out torch.ops.trtllm functions
        with patch('torch.ops.trtllm.moe_load_balance_wait_gpu_stage') as mock_wait, \
@ -198,13 +201,13 @@ class TestMoeLoadBalancer(unittest.TestCase):
            # add_weight_slot
            mock_weight = MagicMock()
            layer._add_weight_slot(1, "weight1", mock_weight)
            mock_single_layer_impl.add_weight_slot.assert_called_once_with(
            mock_single_layer_impl.add_single_weight_slot.assert_called_once_with(
                1, "weight1", mock_weight)

            # add_host_weight
            mock_host_weight = MagicMock()
            layer._add_host_weight(2, "weight2", mock_host_weight)
            mock_single_layer_impl.add_host_weight.assert_called_once_with(
            mock_single_layer_impl.add_single_host_weight.assert_called_once_with(
                2, "weight2", mock_host_weight)

            # set_initial_weight_assignments
@ -215,7 +218,8 @@ class TestMoeLoadBalancer(unittest.TestCase):

            # wait_for_gpu_stage
            mock_wait.return_value = torch.tensor([1])
            result = layer.wait_for_gpu_stage()
            layer.wait_for_gpu_stage()
            result = layer.statistic_flag_tensor
            mock_wait.assert_called_once_with(
                mock_single_layer_impl.get_pointer())
            self.assertEqual(result, mock_wait.return_value)
@ -228,7 +232,8 @@ class TestMoeLoadBalancer(unittest.TestCase):
            # statistic
            mock_expert_ids = torch.tensor([[0, 1], [2, 3]])
            mock_enabled = torch.tensor([1])
            layer.statistic(mock_expert_ids, mock_enabled, True, False)
            layer.statistic_flag_tensor = mock_enabled
            layer.statistic(mock_expert_ids, True, False)
            mock_statistic.assert_called_once_with(
                mock_expert_ids, mock_enabled,
                mock_single_layer_impl.get_pointer(), True, False)
@ -237,8 +242,6 @@ class TestMoeLoadBalancer(unittest.TestCase):
            mock_selected_experts = torch.tensor([[0, 1], [2, 3]])
            mock_route.return_value = torch.tensor([[0, 1], [2, 3]])
            result = layer.route(mock_selected_experts)
            mock_route.assert_called_once_with(
                mock_selected_experts, mock_single_layer_impl.get_pointer())
            assert torch.equal(result, mock_route.return_value)

    @patch('tensorrt_llm.bindings.internal.runtime.MoeLoadBalancer')
@ -260,14 +263,13 @@ class TestMoeLoadBalancer(unittest.TestCase):
        mock_load_balancer_impl.return_value.set_warm_up_iter_count.assert_called_once_with(
            10)

        # start_iter
        balancer.start_iter(1, True, True)
        mock_load_balancer_impl.return_value.start_iter.assert_called_once_with(
            1, True, True)
        balancer.set_next_iter_info(True, True)

        # end_iter
        balancer.end_iter(1)
        mock_load_balancer_impl.return_value.end_iter.assert_called_once_with(1)
        with MoeLoadBalancerIterContext(balancer):
            mock_load_balancer_impl.return_value.start_iter.assert_called_once_with(
                0, True, True)

        mock_load_balancer_impl.return_value.end_iter.assert_called_once_with(0)

        # shutdown
        balancer.shutdown()
@ -288,6 +290,8 @@ class TestMoeLoadBalancer(unittest.TestCase):
        # Create a real MoeLoadBalancer
        balancer = MoeLoadBalancer(ep_rank, ep_size, 1)

        balancer.set_use_gpu_memcpy(True)

        # Add a layer with initial weight assignments
        # Each slot is assigned to exactly one expert initially
        layer = balancer.add_layer(expert_count, top_k, slots_per_rank)
@ -297,9 +301,8 @@ class TestMoeLoadBalancer(unittest.TestCase):
        # Finalize the model
        balancer.finalize_model()

        # Start iteration - enable statistic, disable weight update
        iter_id = 0
        balancer.start_iter(iter_id, True, False)
        # enable statistic, disable weight update
        balancer.set_next_iter_info(True, False)

        # Create sample token data - each token selects 2 experts
        # 4 tokens, each selecting 2 experts
@ -314,17 +317,15 @@ class TestMoeLoadBalancer(unittest.TestCase):
            device="cuda")

        try:
            # Wait for GPU stage and get enabled flag
            enabled = layer.wait_for_gpu_stage()
            with MoeLoadBalancerIterContext(balancer):
                # Wait for GPU stage and get enabled flag
                layer.wait_for_gpu_stage()

            # Run statistic - just test it runs without error
            layer.statistic(gathered_raw_expert_ids, enabled, True, True)
                # Run statistic - just test it runs without error
                layer.statistic(gathered_raw_expert_ids, True, True)

            # Set CPU stage to signal completion
            layer.set_cpu_stage()

            # End iteration
            balancer.end_iter(iter_id)
                # Set CPU stage to signal completion
                layer.set_cpu_stage()

            # Test passed if we got here without exceptions
            self.assertTrue(True, "Statistic kernel ran successfully")
@ -350,6 +351,8 @@ class TestMoeLoadBalancer(unittest.TestCase):
        # Create a real MoeLoadBalancer
        balancer = MoeLoadBalancer(ep_rank, ep_size, 1)

        balancer.set_use_gpu_memcpy(True)

        # Add a layer with known initial weight assignments
        layer = balancer.add_layer(expert_count, top_k, slots_per_rank)

@ -360,9 +363,8 @@ class TestMoeLoadBalancer(unittest.TestCase):
        # Finalize the model
        balancer.finalize_model()

        # Start iteration - enable statistic, disable weight update
        iter_id = 0
        balancer.start_iter(iter_id, True, False)
        # enable statistic, disable weight update
        balancer.set_next_iter_info(True, False)

        # Create sample token data - tokens selecting different experts
        token_selected_experts = torch.tensor(
@ -376,17 +378,15 @@ class TestMoeLoadBalancer(unittest.TestCase):
            device="cuda")

        try:
            # Wait for GPU stage
            layer.wait_for_gpu_stage()
            with MoeLoadBalancerIterContext(balancer):
                # Wait for GPU stage
                layer.wait_for_gpu_stage()

            # Run routing
            routed_slots = layer.route(token_selected_experts)
                # Run routing
                routed_slots = layer.route(token_selected_experts)

            # Set CPU stage
            layer.set_cpu_stage()

            # End iteration
            balancer.end_iter(iter_id)
                # Set CPU stage
                layer.set_cpu_stage()

            # Verify results - with our initial assignment, expert i should map to slot i
            expected_slots = torch.tensor(

@ -136,6 +136,8 @@ class TestMoePythonBindings(unittest.TestCase):
            ep_size=self.ep_size,
            layer_updates_per_iter=self.layer_updates_per_iter)

        balancer.set_use_gpu_memcpy(True)

        # Add a layer
        layer = balancer.add_layer(expert_count=self.expert_count,
                                   top_k=self.top_k,
@ -206,6 +208,8 @@ class TestMoePythonBindings(unittest.TestCase):
            ep_size=self.ep_size,
            layer_updates_per_iter=self.layer_updates_per_iter)

        balancer.set_use_gpu_memcpy(True)

        # Create initial weight assignments
        initial_assignments = []
        for r in range(self.ep_size):