From e4054682301c47fe03f7adaa877200e0f36a5b9f Mon Sep 17 00:00:00 2001 From: Bo Li <22713281+bobboli@users.noreply.github.com> Date: Mon, 26 Jan 2026 17:59:03 +0800 Subject: [PATCH] [TRTLLM-10048][feat] Fuse the AllGather for expert statistics required by the EPLB. (#10885) Signed-off-by: Bo Li <22713281+bobboli@users.noreply.github.com> --- .../moeAlltoAllKernels.cu | 76 +++++---- .../communicationKernels/moeAlltoAllKernels.h | 18 ++- cpp/tensorrt_llm/thop/moeAlltoAllMeta.h | 6 +- cpp/tensorrt_llm/thop/moeAlltoAllOp.cpp | 94 +++++++++-- .../_torch/distributed/moe_alltoall.py | 88 ++++++++--- .../modules/fused_moe/communication/base.py | 6 + .../communication/communication_factory.py | 5 + .../communication/nvlink_one_sided.py | 117 ++++++++++---- .../modules/fused_moe/configurable_moe.py | 31 +++- .../modules/fused_moe/fused_moe_cutlass.py | 52 +++++-- .../modules/fused_moe/fused_moe_trtllm_gen.py | 54 +++++-- .../unittest/_torch/multi_gpu/test_moe_a2a.py | 147 +++++++++++------- 12 files changed, 508 insertions(+), 186 deletions(-) diff --git a/cpp/tensorrt_llm/kernels/communicationKernels/moeAlltoAllKernels.cu b/cpp/tensorrt_llm/kernels/communicationKernels/moeAlltoAllKernels.cu index 303255e6a7..da1aed6a37 100644 --- a/cpp/tensorrt_llm/kernels/communicationKernels/moeAlltoAllKernels.cu +++ b/cpp/tensorrt_llm/kernels/communicationKernels/moeAlltoAllKernels.cu @@ -361,19 +361,15 @@ __global__ void moeA2APrepareDispatchKernel( } // ============================================================================ -// Generic Dispatch Kernel Implementation -// One warp per token design: -// - Each CTA has 256 threads = 8 warps -// - Each warp independently processes one token and all its payloads -// - Better GPU utilization and reduced synchronization overhead +// Dispatch Kernels // ============================================================================ -template +template __global__ void moeA2ADispatchKernel(int32_t const* token_selected_experts, // [local_num_tokens, TOP_K] const DispatchKernelPointers ptrs, // Struct containing all kernel pointers int num_payloads, // Number of payloads int max_tokens_per_rank, // Maximum tokens per rank - int local_num_tokens, int rank_id, int ep_size, int num_experts_per_rank) + int local_num_tokens, int rank_id, int ep_size, int num_experts, int eplb_stats_num_experts) { int thread_idx = ThreadingPolicy::offset(); @@ -411,6 +407,7 @@ __global__ void moeA2ADispatchKernel(int32_t const* token_selected_experts, // [ } uint64_t already_copied = 0; + int num_experts_per_rank = num_experts / ep_size; for (int k = 0; k < TOP_K; k++) { int expert_id = token_selected_experts[local_token_idx * TOP_K + k]; @@ -501,6 +498,21 @@ __global__ void moeA2ADispatchKernel(int32_t const* token_selected_experts, // [ ptrs.recv_counters[target_rank][rank_id] = send_count; } + if constexpr (ENABLE_EPLB) + { + // Write local stats into peer buffers before the release fence below. 
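+        // Each rank contributes one row of the gathered tensor: lanes stride over
+        // eplb_stats_num_experts and copy eplb_local_stats into row rank_id of every
+        // peer's buffer, so the [ep_size, eplb_stats_num_experts] result is produced
+        // by the dispatch kernel itself instead of a separate AllGather.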
+#pragma unroll 1 + for (int target_rank = 0; target_rank < ep_size; ++target_rank) + { + int* target_stats = ptrs.eplb_gathered_stats[target_rank]; + for (int expert_id = lane_id; expert_id < eplb_stats_num_experts; expert_id += warpSize) + { + int stat_val = ptrs.eplb_local_stats[expert_id]; + target_stats[rank_id * eplb_stats_num_experts + expert_id] = stat_val; + } + } + } + #if !DISABLE_SYNC_FOR_PROFILING uint32_t expected_value = *ptrs.flag_val; @@ -588,6 +600,7 @@ void moe_a2a_dispatch_launch(MoeA2ADispatchParams const& params) for (int target_rank = 0; target_rank < params.ep_size; target_rank++) { kernel_ptrs.recv_counters[target_rank] = params.recv_counters[target_rank]; + kernel_ptrs.eplb_gathered_stats[target_rank] = params.eplb_gathered_stats[target_rank]; for (int payload = 0; payload < params.num_payloads; payload++) { kernel_ptrs.recv_buffers[target_rank][payload] = params.recv_buffers[target_rank][payload]; @@ -606,6 +619,7 @@ void moe_a2a_dispatch_launch(MoeA2ADispatchParams const& params) kernel_ptrs.local_token_counter = params.local_token_counter; kernel_ptrs.topk_target_ranks = params.topk_target_ranks; kernel_ptrs.topk_send_indices = params.topk_send_indices; + kernel_ptrs.eplb_local_stats = params.eplb_local_stats; int const kBlockSize = tensorrt_llm::common::getEnvMoeA2ADispatchBlockSize(); constexpr int kWarpSize = 32; @@ -621,10 +635,12 @@ void moe_a2a_dispatch_launch(MoeA2ADispatchParams const& params) grid_size = 1; } int shared_bytes = 2 * params.top_k * (int) sizeof(int); - SWITCH_TOP_K(params.top_k, TOP_K, - moeA2ADispatchKernel<<>>( - params.token_selected_experts, kernel_ptrs, params.num_payloads, params.max_tokens_per_rank, - params.local_num_tokens, params.ep_rank, params.ep_size, params.num_experts_per_rank)) + SWITCH_BOOL(params.enable_eplb, EPLB_STATS, + SWITCH_TOP_K(params.top_k, TOP_K, + moeA2ADispatchKernel + <<>>(params.token_selected_experts, kernel_ptrs, + params.num_payloads, params.max_tokens_per_rank, params.local_num_tokens, params.ep_rank, + params.ep_size, params.num_experts, params.eplb_stats_num_experts))) } else { @@ -635,10 +651,12 @@ void moe_a2a_dispatch_launch(MoeA2ADispatchParams const& params) grid_size = 1; } int shared_bytes = 2 * kWarpsPerBlock * params.top_k * (int) sizeof(int); - SWITCH_TOP_K(params.top_k, TOP_K, - moeA2ADispatchKernel<<>>( - params.token_selected_experts, kernel_ptrs, params.num_payloads, params.max_tokens_per_rank, - params.local_num_tokens, params.ep_rank, params.ep_size, params.num_experts_per_rank)) + SWITCH_BOOL(params.enable_eplb, EPLB_STATS, + SWITCH_TOP_K(params.top_k, TOP_K, + moeA2ADispatchKernel + <<>>(params.token_selected_experts, kernel_ptrs, + params.num_payloads, params.max_tokens_per_rank, params.local_num_tokens, params.ep_rank, + params.ep_size, params.num_experts, params.eplb_stats_num_experts))) } } @@ -980,24 +998,24 @@ __global__ void moeA2APrepareCombineKernel(uint8_t* recv_buffer_bytes, uint8_t c if (payload_bytes == nullptr) return; - int slot_idx = ThreadingPolicy::token_idx(); + int global_token_idx = ThreadingPolicy::token_idx(); - int total_slots = ep_size * max_tokens_per_rank; - if (slot_idx >= total_slots) + int global_token_num = ep_size * max_tokens_per_rank; + if (global_token_idx >= global_token_num) return; - // Map global token to (source_rank, token_idx) - int source_rank = slot_idx / max_tokens_per_rank; - int token_idx = slot_idx % max_tokens_per_rank; + // Map global_token_idx to (rank_idx, local_token_idx) + int rank_idx = global_token_idx / max_tokens_per_rank; + 
int local_token_idx = global_token_idx % max_tokens_per_rank; - // Skip invalid tokens beyond per-source recv count - if (token_idx >= recv_counters[source_rank]) + // Skip invalid tokens beyond per-rank recv count + if (local_token_idx >= recv_counters[rank_idx]) return; // Calculate source and destination pointers for this token - size_t slot_offset = static_cast(slot_idx) * bytes_per_token; - uint8_t* dst_ptr = recv_buffer_bytes + slot_offset; - uint8_t const* src_ptr = payload_bytes + slot_offset; + size_t offset = static_cast(global_token_idx) * bytes_per_token; + uint8_t* dst_ptr = recv_buffer_bytes + offset; + uint8_t const* src_ptr = payload_bytes + offset; // Copy one token's data using vectorized copy with policy vectorized_copy(dst_ptr, src_ptr, bytes_per_token); @@ -1118,9 +1136,9 @@ void moe_a2a_prepare_combine_launch(MoeA2ACombineParams const& params) } int bytes_per_token = params.elements_per_token * element_size; - int total_slots = params.prepare_payload == nullptr ? 1 : params.ep_size * params.max_tokens_per_rank; - int grid_size_warp = ceilDiv(total_slots, kWarpsPerBlock); - int grid_size_block = total_slots; // one block per token + int global_token_num = params.prepare_payload == nullptr ? 1 : params.ep_size * params.max_tokens_per_rank; + int grid_size_warp = ceilDiv(global_token_num, kWarpsPerBlock); + int grid_size_block = global_token_num; // one block per token if (params.one_block_per_token) { diff --git a/cpp/tensorrt_llm/kernels/communicationKernels/moeAlltoAllKernels.h b/cpp/tensorrt_llm/kernels/communicationKernels/moeAlltoAllKernels.h index 20e68657fc..942b2424bb 100644 --- a/cpp/tensorrt_llm/kernels/communicationKernels/moeAlltoAllKernels.h +++ b/cpp/tensorrt_llm/kernels/communicationKernels/moeAlltoAllKernels.h @@ -61,6 +61,10 @@ struct DispatchKernelPointers // Top-K compact routing info per local token (size: [local_num_tokens, top_k]) int* topk_target_ranks; // target rank per k, -1 for duplicates int* topk_send_indices; // dst index per k, -1 for duplicates + + // Optional: Statistics for EPLB + int const* eplb_local_stats; // [eplb_stats_num_experts] + int* eplb_gathered_stats[kMaxRanks]; // [ep_size, eplb_stats_num_experts] per rank }; // Combine kernel pointers - non-const output in src_data_ptrs[0], const recv buffers @@ -83,13 +87,13 @@ struct CombineKernelPointers // Dispatch phase parameters struct MoeA2ADispatchParams { + // Threading policy bool one_block_per_token; // True: one block per token, False: one warp per token - // Threading policy // EP configuration - int ep_size; // Number of EP ranks - int ep_rank; // Current EP rank - int num_experts_per_rank; // Number of experts per rank (num_experts / ep_size) + int ep_size; // Number of EP ranks + int ep_rank; // Current EP rank + int num_experts; // Total number of experts // Token configuration int local_num_tokens; // Number of tokens on this rank @@ -118,6 +122,12 @@ struct MoeA2ADispatchParams // rank has signaled the target rank void* recv_buffers[kMaxRanks][kMaxPayloads]; // Per-rank receive buffers for each payload + // Optional: Statistics for EPLB + bool enable_eplb; // Whether to enable EPLB + int eplb_stats_num_experts; // Number of experts for EPLB stats + int const* eplb_local_stats; // [eplb_stats_num_experts] + int* eplb_gathered_stats[kMaxRanks]; // [ep_size, eplb_stats_num_experts] per rank + // CUDA stream cudaStream_t stream; }; diff --git a/cpp/tensorrt_llm/thop/moeAlltoAllMeta.h b/cpp/tensorrt_llm/thop/moeAlltoAllMeta.h index d8634e6a4f..76f083fde4 100644 --- 
a/cpp/tensorrt_llm/thop/moeAlltoAllMeta.h +++ b/cpp/tensorrt_llm/thop/moeAlltoAllMeta.h @@ -43,8 +43,9 @@ enum MoeA2AMetaInfoIndex : int64_t COMBINE_COMPLETION_FLAGS_OFFSET_INDEX = 5, TOPK_TARGET_RANKS_OFFSET_INDEX = 6, TOPK_SEND_INDICES_OFFSET_INDEX = 7, - PAYLOAD_DATA_OFFSET_INDEX = 8, - NUM_METAINFO_FIELDS = 9 + EPLB_GATHERED_STATS_OFFSET_INDEX = 8, + PAYLOAD_DATA_OFFSET_INDEX = 9, + NUM_METAINFO_FIELDS = 10 }; using MoeA2ADataOffsets = std::array; @@ -60,6 +61,7 @@ inline std::vector> getMoeA2AMetaInfoIndexPairs( {"MOE_A2A_COMBINE_COMPLETION_FLAGS_OFFSET_INDEX", COMBINE_COMPLETION_FLAGS_OFFSET_INDEX}, {"MOE_A2A_TOPK_TARGET_RANKS_OFFSET_INDEX", TOPK_TARGET_RANKS_OFFSET_INDEX}, {"MOE_A2A_TOPK_SEND_INDICES_OFFSET_INDEX", TOPK_SEND_INDICES_OFFSET_INDEX}, + {"MOE_A2A_EPLB_GATHERED_STATS_OFFSET_INDEX", EPLB_GATHERED_STATS_OFFSET_INDEX}, {"MOE_A2A_PAYLOAD_DATA_OFFSET_INDEX", PAYLOAD_DATA_OFFSET_INDEX}, {"MOE_A2A_NUM_METAINFO_FIELDS", NUM_METAINFO_FIELDS}, }; diff --git a/cpp/tensorrt_llm/thop/moeAlltoAllOp.cpp b/cpp/tensorrt_llm/thop/moeAlltoAllOp.cpp index 29ad780d4c..d81ae4e399 100644 --- a/cpp/tensorrt_llm/thop/moeAlltoAllOp.cpp +++ b/cpp/tensorrt_llm/thop/moeAlltoAllOp.cpp @@ -43,7 +43,7 @@ inline size_t alignOffset(size_t offset, size_t alignment) } // Calculate auxiliary data offsets -MoeA2ADataOffsets calculateOffsets(int epSize, int maxNumTokens) +MoeA2ADataOffsets calculateOffsets(int epSize, int maxNumTokens, int eplbStatsNumExperts) { // TODO: Use lambdas to encapsulate offset and alignment for each entry, which is less error prone and easier to // read. @@ -90,6 +90,11 @@ MoeA2ADataOffsets calculateOffsets(int epSize, int maxNumTokens) offset += static_cast(maxNumTokens) * static_cast(tensorrt_llm::kernels::moe_comm::kMaxTopK) * SIZEOF_INT32; + // eplb gathered stats: [epSize, eplbStatsNumExperts] + offset = alignOffset(offset, CACHELINE_ALIGNMENT); + offsets[EPLB_GATHERED_STATS_OFFSET_INDEX] = offset; + offset += static_cast(epSize) * static_cast(eplbStatsNumExperts) * SIZEOF_INT32; + // payload data offset = alignOffset(offset, CACHELINE_ALIGNMENT); offsets[PAYLOAD_DATA_OFFSET_INDEX] = offset; @@ -105,10 +110,12 @@ MoeA2ADataOffsets calculateOffsets(int epSize, int maxNumTokens) // - epRank: Current expert parallel rank // - epSize: Total expert parallel size // - maxNumTokens: Maximum number of tokens supported +// - eplbStatsNumExperts: (Optional) Number of experts used for EPLB stats // // Returns: // - metainfo: Tensor containing offsets for auxiliary data -torch::Tensor moeA2AInitializeOp(torch::Tensor const& workspace, int64_t epRank, int64_t epSize, int64_t maxNumTokens) +torch::Tensor moeA2AInitializeOp(torch::Tensor const& workspace, int64_t epRank, int64_t epSize, int64_t maxNumTokens, + torch::optional eplbStatsNumExperts) { // Validate inputs CHECK_TH_CUDA(workspace); @@ -120,8 +127,11 @@ torch::Tensor moeA2AInitializeOp(torch::Tensor const& workspace, int64_t epRank, // Initialize workspace to zero workspace[epRank].zero_(); + int64_t eplbStatsNumExpertsValue = eplbStatsNumExperts.value_or(0); + TORCH_CHECK(eplbStatsNumExpertsValue >= 0, "eplbStatsNumExperts must be positive if not None."); + // Calculate auxiliary data offsets - MoeA2ADataOffsets offsets = calculateOffsets(epSize, maxNumTokens); + MoeA2ADataOffsets offsets = calculateOffsets(epSize, maxNumTokens, static_cast(eplbStatsNumExpertsValue)); // Return metainfo as a tensor containing offsets torch::Tensor metainfo = torch::empty( @@ -155,18 +165,23 @@ torch::Tensor moeA2AInitializeOp(torch::Tensor const& 
workspace, int64_t epRank, // - epRank: Current expert parallel rank // - epSize: Total expert parallel size // - topK: Number of experts selected per token -// - numExperts: Total number of experts (must be divisible by epSize) +// - numExperts: Total number of routing slots (tokenSelectedExperts values are in [0, numExperts)) +// - eplbStatsNumExperts: Number of experts used for EPLB stats (may be <= numExperts) +// - eplbLocalStats: [eplbStatsNumExperts] tensor containing local statistics for EPLB. // // Return values: // - recvTensors: Vector of receive buffers (one tensor per payload), each [ep_size, runtimeMaxTokensPerRank, // elements_per_token] // - combinePayloadOffset: Offset into workspace for the combine payload region, to be used by the combine operation +// - eplbGatheredStats: (Optional) [ep_size, eplbStatsNumExperts] tensor containing gathered statistics for EPLB, or +// an empty tensor if eplbLocalStats is None. // // Note: token_selected_experts is used for routing but is NOT automatically included as a payload. // If you want to dispatch token_selected_experts, include it explicitly in inputPayloads. -std::tuple, int64_t> moeA2ADispatchOp(torch::Tensor const& tokenSelectedExperts, - std::vector const& inputPayloads, torch::Tensor const& workspace, torch::Tensor const& metainfo, - int64_t runtimeMaxTokensPerRank, int64_t epRank, int64_t epSize, int64_t topK, int64_t numExperts) +std::tuple, int64_t, torch::Tensor> moeA2ADispatchOp( + torch::Tensor const& tokenSelectedExperts, std::vector const& inputPayloads, + torch::Tensor const& workspace, torch::Tensor const& metainfo, int64_t runtimeMaxTokensPerRank, int64_t epRank, + int64_t epSize, int64_t topK, int64_t numExperts, torch::optional eplbLocalStats) { using tensorrt_llm::kernels::moe_comm::PayloadDescriptor; using tensorrt_llm::kernels::moe_comm::MoeA2ADispatchParams; @@ -194,6 +209,19 @@ std::tuple, int64_t> moeA2ADispatchOp(torch::Tensor c TORCH_CHECK(inputPayloads.size() <= kMaxPayloads, "Too many input payloads"); TORCH_CHECK(numExperts >= epSize, "numExperts must be greater than or equal to epSize"); TORCH_CHECK(numExperts % epSize == 0, "numExperts must be divisible by epSize for contiguous partitioning"); + bool enableEplb = eplbLocalStats.has_value(); + int64_t eplbStatsNumExperts = 0; + if (enableEplb) + { + TORCH_CHECK(eplbLocalStats.has_value(), "enable_eplb requires eplb_local_stats"); + torch::Tensor const& eplbLocalStatsTensor = eplbLocalStats.value(); + eplbStatsNumExperts = eplbLocalStatsTensor.size(0); + TORCH_CHECK(eplbStatsNumExperts > 0, "eplb_local_stats must not be empty"); + TORCH_CHECK(eplbStatsNumExperts <= numExperts, "eplb_local_stats size must be <= numExperts (slots)"); + CHECK_INPUT(eplbLocalStatsTensor, torch::kInt32); + TORCH_CHECK(eplbLocalStatsTensor.is_contiguous(), "eplb_local_stats must be contiguous"); + TORCH_CHECK(eplbLocalStatsTensor.dim() == 1, "eplb_local_stats must be a 1D tensor"); + } // All input payloads must have the same first dimension (localNumTokens) for (auto const& payload : inputPayloads) @@ -280,10 +308,12 @@ std::tuple, int64_t> moeA2ADispatchOp(torch::Tensor c = tensorrt_llm::common::getEnvMoeA2AOneBlockPerToken(); // TODO: Decide this based on the workload params.ep_size = static_cast(epSize); params.ep_rank = static_cast(epRank); - params.num_experts_per_rank = static_cast(numExperts) / static_cast(epSize); + params.num_experts = static_cast(numExperts); params.local_num_tokens = static_cast(localNumTokens); params.max_tokens_per_rank = 
static_cast(runtimeMaxTokensPerRank); params.top_k = static_cast(topK); + params.enable_eplb = enableEplb; + params.eplb_stats_num_experts = static_cast(eplbStatsNumExperts); params.token_selected_experts = tokenSelectedExperts.data_ptr(); @@ -304,6 +334,15 @@ std::tuple, int64_t> moeA2ADispatchOp(torch::Tensor c = reinterpret_cast(targetWorkSpacePtr + offsets[RECV_COUNTERS_OFFSET_INDEX]); params.completion_flags[target_rank] = reinterpret_cast(targetWorkSpacePtr + offsets[DISPATCH_COMPLETION_FLAGS_OFFSET_INDEX]); + if (enableEplb) + { + params.eplb_gathered_stats[target_rank] + = reinterpret_cast(targetWorkSpacePtr + offsets[EPLB_GATHERED_STATS_OFFSET_INDEX]); + } + else + { + params.eplb_gathered_stats[target_rank] = nullptr; + } for (int payload_idx = 0; payload_idx < num_payloads; payload_idx++) { @@ -312,6 +351,15 @@ std::tuple, int64_t> moeA2ADispatchOp(torch::Tensor c } } + if (enableEplb) + { + params.eplb_local_stats = eplbLocalStats.value().data_ptr(); + } + else + { + params.eplb_local_stats = nullptr; + } + params.stream = at::cuda::getCurrentCUDAStream(); // Prepare for dispatch (zero counters/indices and increment flag_val) @@ -336,7 +384,20 @@ std::tuple, int64_t> moeA2ADispatchOp(torch::Tensor c // Compute aligned offset after dispatch payloads for combine payload region int64_t combinePayloadOffset = static_cast(alignOffset(currentOffset, CACHELINE_ALIGNMENT)); - return std::make_tuple(std::move(recvTensors), combinePayloadOffset); + torch::Tensor eplbGatheredStats; + if (enableEplb) + { + int* gatheredStatsPtr = reinterpret_cast(rankWorkSpacePtr + offsets[EPLB_GATHERED_STATS_OFFSET_INDEX]); + auto statsOptions = workspace.options().dtype(torch::kInt32); + eplbGatheredStats = torch::from_blob( + gatheredStatsPtr, {static_cast(epSize), static_cast(eplbStatsNumExperts)}, statsOptions); + } + else + { + eplbGatheredStats = torch::empty({0}, workspace.options().dtype(torch::kInt32)); + } + + return std::make_tuple(std::move(recvTensors), combinePayloadOffset, std::move(eplbGatheredStats)); } // MoE All-to-All Combine Operation @@ -525,9 +586,12 @@ torch::Tensor moeA2AGetCombinePayloadTensorOp(torch::Tensor const& workspace, in } // Return the size of auxiliary data in workspace -int64_t moeA2AGetAuxDataSizeOp(int64_t epSize, int64_t maxNumTokens) +int64_t moeA2AGetAuxDataSizeOp(int64_t epSize, int64_t maxNumTokens, torch::optional eplbStatsNumExperts) { - MoeA2ADataOffsets offsets = calculateOffsets(static_cast(epSize), static_cast(maxNumTokens)); + int64_t eplbStatsNumExpertsValue = eplbStatsNumExperts.value_or(0); + TORCH_CHECK(eplbStatsNumExpertsValue >= 0, "eplbStatsNumExperts must be positive if not None."); + MoeA2ADataOffsets offsets = calculateOffsets( + static_cast(epSize), static_cast(maxNumTokens), static_cast(eplbStatsNumExpertsValue)); return static_cast(offsets[PAYLOAD_DATA_OFFSET_INDEX]); } @@ -546,14 +610,16 @@ TORCH_LIBRARY_FRAGMENT(trtllm, module) module.def( "moe_a2a_dispatch(Tensor token_selected_experts, Tensor[] input_payloads, " "Tensor(a!->*) workspace, Tensor metainfo, int runtime_max_tokens_per_rank, " - "int ep_rank, int ep_size, int top_k, int num_experts) -> (Tensor(a!)[], int)"); + "int ep_rank, int ep_size, int top_k, int num_experts, " + "Tensor? eplb_local_stats=None) -> (Tensor(a!)[], int, Tensor(a!))"); module.def( "moe_a2a_combine(Tensor(a) payload, int local_num_tokens," "Tensor(a!) 
workspace, Tensor metainfo, int runtime_max_tokens_per_rank, " "int ep_rank, int ep_size, int top_k, int combine_payload_offset, " "bool payload_in_workspace) -> Tensor"); module.def( - "moe_a2a_initialize(Tensor(a!) workspace, int ep_rank, int ep_size, int max_num_tokens_per_rank) -> Tensor"); + "moe_a2a_initialize(Tensor(a!) workspace, int ep_rank, int ep_size, int max_num_tokens_per_rank, " + "int? eplb_stats_num_experts=None) -> Tensor"); module.def( "moe_a2a_sanitize_expert_ids(Tensor(a!) expert_ids, Tensor(a!) workspace, Tensor metainfo, int ep_rank, int " "invalid_expert_id) -> ()"); @@ -561,7 +627,7 @@ TORCH_LIBRARY_FRAGMENT(trtllm, module) "moe_a2a_get_combine_payload_tensor(Tensor(a) workspace, int ep_rank, int ep_size, int " "runtime_max_tokens_per_rank, " "int combine_payload_offset, ScalarType out_dtype, int hidden_size) -> Tensor(a)"); - module.def("moe_a2a_get_aux_data_size(int ep_size, int max_num_tokens) -> int", + module.def("moe_a2a_get_aux_data_size(int ep_size, int max_num_tokens, int? eplb_stats_num_experts=None) -> int", &tensorrt_llm::torch_ext::moe_comm::moeA2AGetAuxDataSizeOp); } diff --git a/tensorrt_llm/_torch/distributed/moe_alltoall.py b/tensorrt_llm/_torch/distributed/moe_alltoall.py index 231a671f3b..cc593499c7 100644 --- a/tensorrt_llm/_torch/distributed/moe_alltoall.py +++ b/tensorrt_llm/_torch/distributed/moe_alltoall.py @@ -25,6 +25,7 @@ class _A2AState: phase: str = "idle" # idle | dispatched local_num_tokens: int | None = None combine_payload_offset: int | None = None + eplb_gathered_stats: torch.Tensor | None = None class MoeAlltoAll: @@ -41,9 +42,13 @@ class MoeAlltoAll: _METAINFO_INDEX: Dict[str, int] | None = None @staticmethod - def get_aux_data_size(ep_size: int, max_num_tokens: int) -> int: + def get_aux_data_size( + ep_size: int, + max_num_tokens: int, + eplb_stats_num_experts: Optional[int] = None, + ) -> int: return torch.ops.trtllm.moe_a2a_get_aux_data_size( - ep_size, max_num_tokens) + ep_size, max_num_tokens, eplb_stats_num_experts) @staticmethod def calculate_required_workspace_size( @@ -52,11 +57,13 @@ class MoeAlltoAll: max_num_tokens: int, hidden_size: int, dtype: torch.dtype, + eplb_stats_num_experts: Optional[int] = None, extra_payload_bytes_per_token: int = 0) -> int: element_size = dtype.itemsize # Auxiliary data size - workspace_size = MoeAlltoAll.get_aux_data_size(ep_size, max_num_tokens) + workspace_size = MoeAlltoAll.get_aux_data_size(ep_size, max_num_tokens, + eplb_stats_num_experts) # Dispatch needs workspace for [ep_size, max_tokens] tokens, # but due to the variety of quantization recipes, we cannot know the exact size, so we conservatively estimate assuming no quantization. @@ -103,6 +110,8 @@ class MoeAlltoAll: int(thop.MOE_A2A_TOPK_TARGET_RANKS_OFFSET_INDEX), "TOPK_SEND_INDICES_OFFSET_INDEX": int(thop.MOE_A2A_TOPK_SEND_INDICES_OFFSET_INDEX), + "EPLB_GATHERED_STATS_OFFSET_INDEX": + int(thop.MOE_A2A_EPLB_GATHERED_STATS_OFFSET_INDEX), "PAYLOAD_DATA_OFFSET_INDEX": int(thop.MOE_A2A_PAYLOAD_DATA_OFFSET_INDEX), "NUM_METAINFO_FIELDS": @@ -114,8 +123,9 @@ class MoeAlltoAll: mapping: Mapping, max_num_tokens: int, top_k: int, - num_experts: int, + num_slots: int, workspace_size_per_rank: int, + num_experts: Optional[int] = None, ): """ Initialize MoeAlltoAll with workspace allocation. @@ -124,6 +134,10 @@ class MoeAlltoAll: mapping: TensorRT-LLM Mapping object containing rank information max_num_tokens: Maximum number of tokens supported. Should be ModelConfig.max_num_tokens. 
workspace_size_per_rank: Size of workspace per rank in bytes + num_slots: Number of routing slots (token_selected_experts values are in [0, num_slots)). + Note: The terminology is mapped to `num_experts` in this class and the kernels. + num_experts: (Optional) Number of experts for EPLB stats (must be <= num_slots). DO NOT provide this parameter if EPLB is not enabled. + Note: The terminology is mapped to `eplb_stats_num_experts` in this class and the kernels. """ # Check for environment variable override workspace_mb_env = os.environ.get("TRTLLM_MOE_A2A_WORKSPACE_MB") @@ -148,42 +162,52 @@ class MoeAlltoAll: self.ep_rank = mapping.moe_ep_rank self.top_k = top_k - self.num_experts = num_experts + self.num_experts = num_slots + if not isinstance(self.top_k, int) or self.top_k <= 0: raise ValueError("top_k must be a positive int") if not isinstance(self.num_experts, int) or self.num_experts <= 0: - raise ValueError("num_experts must be a positive int") + raise ValueError("num_slots must be a positive int") + + if num_experts is not None: + assert num_experts > 0 and num_experts <= num_slots, "num_experts must be in (0, num_slots]" + tllm_logger.info( + "NVLinkOneSided AlltoAll: EPLB is enabled, with num_slots=" + f"{num_slots} and num_experts={num_experts}") + self.enable_eplb = num_experts is not None + self.eplb_stats_num_experts = num_experts if self._WORKSPACE is None: tllm_logger.info( - f"nvlink_one_sided AlltoAll: Allocating workspace with size {workspace_size_per_rank} bytes. ep_rank: {self.ep_rank}, ep_size: {self.ep_size}, max_num_tokens: {self.max_num_tokens}" + f"NVLinkOneSided AlltoAll: Allocating workspace with size {workspace_size_per_rank} bytes. ep_rank: {self.ep_rank}, ep_size: {self.ep_size}, max_num_tokens: {self.max_num_tokens}" ) mnnvl_mem = MnnvlMemory(mapping, workspace_size_per_rank) workspace = mnnvl_mem.as_torch_strided_tensor(torch.uint8) metainfo = torch.ops.trtllm.moe_a2a_initialize( - workspace, - self.ep_rank, - self.ep_size, - self.max_num_tokens, - ) + workspace, self.ep_rank, self.ep_size, self.max_num_tokens, + self.eplb_stats_num_experts) MoeAlltoAll._WORKSPACE = { "workspace_size_per_rank": workspace_size_per_rank, "max_num_tokens": self.max_num_tokens, "ep_rank": self.ep_rank, "ep_size": self.ep_size, + "eplb_stats_num_experts": self.eplb_stats_num_experts, "mnnvl_mem": mnnvl_mem, "workspace": workspace, "metainfo": metainfo, } else: assert self._WORKSPACE[ - "workspace_size_per_rank"] == workspace_size_per_rank, "reuse workspace with different workspace_size_per_rank" + "workspace_size_per_rank"] == workspace_size_per_rank, "mistakenly reusing workspace with different workspace_size_per_rank" assert self._WORKSPACE[ - "max_num_tokens"] == self.max_num_tokens, "reuse workspace with different max_num_tokens" + "max_num_tokens"] == self.max_num_tokens, "mistakenly reusing workspace with different max_num_tokens" assert self._WORKSPACE[ - "ep_rank"] == self.ep_rank, "reuse workspace with different ep_rank" + "ep_rank"] == self.ep_rank, "mistakenly reusing workspace with different ep_rank" assert self._WORKSPACE[ - "ep_size"] == self.ep_size, "reuse workspace with different ep_size" + "ep_size"] == self.ep_size, "mistakenly reusing workspace with different ep_size" + assert self._WORKSPACE[ + "eplb_stats_num_experts"] == self.eplb_stats_num_experts, ( + "reuse workspace with different eplb_stats_num_experts") self.mnnvl_mem = self._WORKSPACE["mnnvl_mem"] self.workspace = self._WORKSPACE["workspace"] @@ -196,7 +220,8 @@ class MoeAlltoAll: input_payloads: 
list[torch.Tensor], runtime_max_tokens_per_rank: int, invalid_token_expert_id: Optional[int] = None, - expert_id_payload_index: Optional[int] = None): + expert_id_payload_index: Optional[int] = None, + eplb_local_stats: Optional[torch.Tensor] = None): """ Perform MoE all-to-all dispatch operation. @@ -206,19 +231,40 @@ class MoeAlltoAll: runtime_max_tokens_per_rank: Maximum of the number of tokens of each DP rank's local batch. invalid_token_expert_id: If not None, set the token_selected_experts of the invalid tokens to this expert id. This is used to notify the MoE to skip these tokens for GroupGEMM. expert_id_payload_index: The index of token_selected_experts in the input_payloads. Must be provided if invalid_token_expert_id is not None. + eplb_local_stats: (Optional) [num_experts] tensor containing local statistics for EPLB Returns: recv_tensors: List of tensors received, each has shape [ep_size, max_tokens_per_rank, payload_num_elements_per_token] """ assert self._state.phase == "idle", "dispatch called twice without an intervening combine" assert runtime_max_tokens_per_rank <= self.max_num_tokens, "runtime_max_tokens_per_rank must not exceed max_num_tokens" - recv_tensors, combine_payload_offset = torch.ops.trtllm.moe_a2a_dispatch( - token_selected_experts, input_payloads, self.workspace, - self.metainfo, runtime_max_tokens_per_rank, self.ep_rank, - self.ep_size, self.top_k, self.num_experts) + if eplb_local_stats is not None: + assert self.enable_eplb, "eplb_local_stats provided but enable_eplb is False" + assert eplb_local_stats.dim( + ) == 1, "eplb_local_stats must be a 1D tensor" + assert eplb_local_stats.size( + 0 + ) == self.eplb_stats_num_experts, "eplb_local_stats size must match eplb_stats_num_experts" + + recv_tensors, combine_payload_offset, eplb_gathered_stats = torch.ops.trtllm.moe_a2a_dispatch( + token_selected_experts, + input_payloads, + self.workspace, + self.metainfo, + runtime_max_tokens_per_rank, + self.ep_rank, + self.ep_size, + self.top_k, + self.num_experts, + eplb_local_stats, + ) + if eplb_gathered_stats.numel() == 0: + eplb_gathered_stats = None + # Update state together after successful dispatch self._state.local_num_tokens = token_selected_experts.size(0) self._state.combine_payload_offset = combine_payload_offset + self._state.eplb_gathered_stats = eplb_gathered_stats self._state.phase = "dispatched" if invalid_token_expert_id is not None: diff --git a/tensorrt_llm/_torch/modules/fused_moe/communication/base.py b/tensorrt_llm/_torch/modules/fused_moe/communication/base.py index b98f92830e..8868b25fc0 100644 --- a/tensorrt_llm/_torch/modules/fused_moe/communication/base.py +++ b/tensorrt_llm/_torch/modules/fused_moe/communication/base.py @@ -185,3 +185,9 @@ class Communication(ABC): Combined output tensor [local_num_tokens, hidden_size] """ raise NotImplementedError + + def get_eplb_gathered_statistics(self) -> Optional[torch.Tensor]: + """ + Return gathered EPLB statistics from the last dispatch, if available. 
+ """ + raise NotImplementedError diff --git a/tensorrt_llm/_torch/modules/fused_moe/communication/communication_factory.py b/tensorrt_llm/_torch/modules/fused_moe/communication/communication_factory.py index b11fd19a7f..c7f5b22a4a 100644 --- a/tensorrt_llm/_torch/modules/fused_moe/communication/communication_factory.py +++ b/tensorrt_llm/_torch/modules/fused_moe/communication/communication_factory.py @@ -76,6 +76,7 @@ class CommunicationFactory: expert_size_per_partition: Number of experts per partition (required for DeepEP) payload_in_workspace: If True, final_hidden_states is already in workspace (for NVLinkOneSided) alltoall_result_do_sum: If True, sum the alltoall results (for NVLinkTwoSided) + # TODO: Need a way to indicate whether EPLB is enabled. Returns: The selected communication method, or None if attention does not use DP @@ -122,6 +123,7 @@ class CommunicationFactory: # Priority: NVLinkOneSided > NVLinkTwoSided > DeepEP > DeepEPLowLatency > AllGather try: + enable_eplb = model_config.moe_load_balancer is not None strategy = NVLinkOneSided( mapping, num_slots, @@ -130,6 +132,7 @@ class CommunicationFactory: payload_in_workspace, hidden_size=hidden_size, dtype=act_dtype, + num_experts=num_experts if enable_eplb else None, ) logger.info("Selected communication strategy: NVLinkOneSided") return strategy @@ -231,6 +234,7 @@ class CommunicationFactory: alltoall_result_do_sum=alltoall_result_do_sum, ) elif method in ["NVLINK_ONE_SIDED"]: + enable_eplb = model_config.moe_load_balancer is not None return NVLinkOneSided( mapping, num_slots, @@ -239,6 +243,7 @@ class CommunicationFactory: payload_in_workspace, hidden_size=hidden_size, dtype=act_dtype, + num_experts=num_experts if enable_eplb else None, ) elif method == "DEEPEP": return DeepEP( diff --git a/tensorrt_llm/_torch/modules/fused_moe/communication/nvlink_one_sided.py b/tensorrt_llm/_torch/modules/fused_moe/communication/nvlink_one_sided.py index 2cbe066204..df5b834e28 100644 --- a/tensorrt_llm/_torch/modules/fused_moe/communication/nvlink_one_sided.py +++ b/tensorrt_llm/_torch/modules/fused_moe/communication/nvlink_one_sided.py @@ -65,11 +65,18 @@ class NVLinkOneSided(Communication): RECV_COUNTERS_OFFSET_INDEX = None DISPATCH_COMPLETION_FLAGS_OFFSET_INDEX = None COMBINE_COMPLETION_FLAGS_OFFSET_INDEX = None + EPLB_GATHERED_STATS_OFFSET_INDEX = None PAYLOAD_DATA_OFFSET_INDEX = None @staticmethod - def get_aux_data_size(ep_size: int, max_num_tokens: int) -> int: - return torch.ops.trtllm.moe_a2a_get_aux_data_size(ep_size, max_num_tokens) + def get_aux_data_size( + ep_size: int, + max_num_tokens: int, + eplb_stats_num_experts: Optional[int] = None, + ) -> int: + return torch.ops.trtllm.moe_a2a_get_aux_data_size( + ep_size, max_num_tokens, eplb_stats_num_experts + ) @staticmethod def calculate_required_workspace_size( @@ -78,12 +85,15 @@ class NVLinkOneSided(Communication): max_num_tokens: int, hidden_size: int, dtype: torch.dtype, + eplb_stats_num_experts: Optional[int] = None, extra_payload_bytes_per_token: int = 0, ) -> int: element_size = dtype.itemsize # Auxiliary data size - workspace_size = NVLinkOneSided.get_aux_data_size(ep_size, max_num_tokens) + workspace_size = NVLinkOneSided.get_aux_data_size( + ep_size, max_num_tokens, eplb_stats_num_experts + ) # Dispatch needs workspace for [ep_size, max_tokens] tokens, # but due to the variety of quantization recipes, we cannot know the exact size, so we conservatively estimate assuming no quantization. 
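A minimal sizing sketch (hypothetical dimensions; only the signatures introduced by this patch are assumed): the EPLB statistics simply enlarge the auxiliary-data region by a cacheline-aligned [ep_size, eplb_stats_num_experts] int32 buffer, so the delta is easy to check.

import torch
from tensorrt_llm._torch.modules.fused_moe.communication.nvlink_one_sided import NVLinkOneSided

ep_size, top_k = 8, 8
max_num_tokens, hidden_size = 2048, 7168    # hypothetical model dimensions
num_experts = 256                           # experts tracked by EPLB statistics

base = NVLinkOneSided.calculate_required_workspace_size(
    ep_size, top_k, max_num_tokens, hidden_size, torch.bfloat16)
with_eplb = NVLinkOneSided.calculate_required_workspace_size(
    ep_size, top_k, max_num_tokens, hidden_size, torch.bfloat16,
    eplb_stats_num_experts=num_experts)
# The gathered-stats region is [ep_size, num_experts] int32, cacheline-aligned.
assert with_eplb - base >= ep_size * num_experts * 4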
@@ -97,13 +107,12 @@ class NVLinkOneSided(Communication): # token_final_scales workspace_size += ep_size * max_num_tokens * top_k * 4 workspace_size = pad_up(workspace_size, 128) - # extra payload bytes per token - workspace_size += ep_size * max_num_tokens * extra_payload_bytes_per_token - workspace_size = pad_up(workspace_size, 128) - # Required workspace for combine [ep_size, max_tokens] tokens workspace_size += ep_size * max_num_tokens * hidden_size * element_size workspace_size = pad_up(workspace_size, 128) + # extra payload bytes per token + workspace_size += ep_size * max_num_tokens * extra_payload_bytes_per_token + workspace_size = pad_up(workspace_size, 128) return workspace_size @@ -124,29 +133,36 @@ class NVLinkOneSided(Communication): cls.COMBINE_COMPLETION_FLAGS_OFFSET_INDEX = int( thop.MOE_A2A_COMBINE_COMPLETION_FLAGS_OFFSET_INDEX ) + cls.EPLB_GATHERED_STATS_OFFSET_INDEX = int( + thop.MOE_A2A_EPLB_GATHERED_STATS_OFFSET_INDEX + ) cls.PAYLOAD_DATA_OFFSET_INDEX = int(thop.MOE_A2A_PAYLOAD_DATA_OFFSET_INDEX) def __init__( self, mapping: Mapping, - num_experts: int, + num_slots: int, top_k: int, max_num_tokens_per_rank: int, payload_in_workspace: bool = False, hidden_size: Optional[int] = None, dtype: Optional[torch.dtype] = None, + num_experts: Optional[int] = None, ): """ Initialize NVLinkOneSided with workspace allocation. Args: mapping: TensorRT-LLM Mapping object containing rank information - num_experts: Total number of experts + num_slots: Number of routing slots (token_selected_experts values are in [0, num_slots)). + Note: The terminology is mapped to `num_experts` in this class and the kernels. top_k: Number of experts per token max_num_tokens_per_rank: Maximum number of tokens per rank (for workspace allocation) payload_in_workspace: If True, final_hidden_states is already in workspace hidden_size: Hidden dimension size (optional, for auto workspace calculation) dtype: Data type (optional, for auto workspace calculation) + num_experts: (Optional) Number of experts for EPLB stats (must be <= num_slots). DO NOT provide this parameter if EPLB is not enabled. + Note: The terminology is mapped to `eplb_stats_num_experts` in this class and the kernels. 
""" super().__init__(mapping) @@ -154,10 +170,20 @@ class NVLinkOneSided(Communication): raise RuntimeError("Currently NVLinkOneSided only supports pure EP for MoE.") # Store needed parameters - self.num_experts = num_experts + self.num_experts = num_slots self.top_k = top_k self.max_num_tokens_per_rank = max_num_tokens_per_rank self.payload_in_workspace = payload_in_workspace + if num_experts is not None: + assert num_experts > 0 and num_experts <= num_slots, ( + "num_experts must be in (0, num_slots]" + ) + tllm_logger.info( + "NVLinkOneSided AlltoAll: EPLB is enabled, with num_slots=" + f"{num_slots} and num_experts={num_experts}" + ) + self.enable_eplb = num_experts is not None + self.eplb_stats_num_experts = num_experts # Initialize constants from C++ self._init_constants() @@ -166,7 +192,12 @@ class NVLinkOneSided(Communication): auto_workspace_size = None if hidden_size is not None and dtype is not None: auto_workspace_size = self.calculate_required_workspace_size( - self.ep_size, self.top_k, max_num_tokens_per_rank, hidden_size, dtype + self.ep_size, + self.top_k, + max_num_tokens_per_rank, + hidden_size, + dtype, + eplb_stats_num_experts=self.eplb_stats_num_experts, ) workspace_mb_env = os.environ.get("TRTLLM_MOE_A2A_WORKSPACE_MB") if workspace_mb_env: @@ -200,12 +231,14 @@ class NVLinkOneSided(Communication): self.ep_rank, self.ep_size, self.max_num_tokens_per_rank, + self.eplb_stats_num_experts, ) NVLinkOneSided._WORKSPACE = { "workspace_size_per_rank": self.workspace_size_per_rank, "max_num_tokens_per_rank": self.max_num_tokens_per_rank, "ep_rank": self.ep_rank, "ep_size": self.ep_size, + "eplb_stats_num_experts": self.eplb_stats_num_experts, "mnnvl_mem": mnnvl_mem, "workspace": workspace, "metainfo": metainfo, @@ -223,6 +256,9 @@ class NVLinkOneSided(Communication): assert self._WORKSPACE["ep_size"] == self.ep_size, ( "reuse workspace with different ep_size" ) + assert self._WORKSPACE["eplb_stats_num_experts"] == self.eplb_stats_num_experts, ( + "reuse workspace with different eplb_stats_num_experts" + ) self.mnnvl_mem = self._WORKSPACE["mnnvl_mem"] self.workspace = self._WORKSPACE["workspace"] @@ -230,10 +266,7 @@ class NVLinkOneSided(Communication): self.max_num_tokens_per_rank = self._WORKSPACE["max_num_tokens_per_rank"] # Initialize dispatch state - self._dispatch_state = {} - - # Internal state - self._state: str = "idle" # idle | dispatched + self._dispatch_state = {"phase": "idle"} # Invalid token expert ID (default to -1), the kernels in TRTLLM-gen is hard-code to support -1 only. self.invalid_token_expert_id: int = -1 @@ -286,7 +319,7 @@ class NVLinkOneSided(Communication): Tuple of (hidden_states, hidden_states_sf, token_selected_slots, token_final_scales) Each tensor has shape [ep_size, max_tokens_per_rank, ...] 
""" - if self._state == "dispatched": + if self._dispatch_state.get("phase") == "dispatched": raise RuntimeError("dispatch called twice without an intervening combine") # Calculate runtime_max_tokens_per_rank from all_rank_num_tokens @@ -304,23 +337,35 @@ class NVLinkOneSided(Communication): if token_final_scales is not None: payloads.append(token_final_scales) - recv_buffers, combine_payload_offset = torch.ops.trtllm.moe_a2a_dispatch( - token_selected_slots, - payloads, - self.workspace, - self.moe_a2a_metainfo, - runtime_max_tokens_per_rank, - self.ep_rank, - self.ep_size, - self.top_k, - self.num_experts, + eplb_local_stats = kwargs.get("eplb_local_stats") + if eplb_local_stats is not None: + assert self.enable_eplb, "eplb_local_stats provided but enable_eplb is False" + assert eplb_local_stats.dim() == 1, "eplb_local_stats must be a 1D tensor" + assert eplb_local_stats.size(0) == self.eplb_stats_num_experts, ( + "eplb_local_stats size must match eplb_stats_num_experts" + ) + + recv_buffers, combine_payload_offset, eplb_gathered_stats = ( + torch.ops.trtllm.moe_a2a_dispatch( + token_selected_slots, + payloads, + self.workspace, + self.moe_a2a_metainfo, + runtime_max_tokens_per_rank, + self.ep_rank, + self.ep_size, + self.top_k, + self.num_experts, + eplb_local_stats, + ) ) - - self._state = "dispatched" - + if eplb_gathered_stats.numel() == 0: + eplb_gathered_stats = None + self._dispatch_state["eplb_gathered_stats"] = eplb_gathered_stats self._dispatch_state["combine_payload_offset"] = int(combine_payload_offset) self._dispatch_state["local_num_tokens"] = token_selected_slots.size(0) self._dispatch_state["runtime_max_tokens_per_rank"] = runtime_max_tokens_per_rank + self._dispatch_state["phase"] = "dispatched" # Extract results from recv_buffers # Payload order matches input: @@ -363,6 +408,13 @@ class NVLinkOneSided(Communication): token_final_scales_recv, ) + def get_eplb_gathered_statistics(self) -> Optional[torch.Tensor]: + """ + Return gathered EPLB statistics from the last dispatch, if available. 
+ """ + assert self.enable_eplb, "EPLB is not enabled" + return self._dispatch_state.get("eplb_gathered_stats") + def combine( self, final_hidden_states: torch.Tensor, @@ -380,7 +432,7 @@ class NVLinkOneSided(Communication): Combined output tensor [local_num_tokens, hidden_size] """ - if self._state != "dispatched": + if self._dispatch_state.get("phase") != "dispatched": raise RuntimeError("combine called before a successful dispatch") local_num_tokens = self._dispatch_state.get("local_num_tokens") @@ -423,8 +475,7 @@ class NVLinkOneSided(Communication): ) # Reset state for next round - self._state = "idle" - self._dispatch_state.clear() + self._dispatch_state = {"phase": "idle"} return output @@ -444,7 +495,7 @@ class NVLinkOneSided(Communication): Returns: Tensor view into workspace [ep_size, max_tokens_per_rank, hidden_size] """ - if self._state != "dispatched": + if self._dispatch_state.get("phase") != "dispatched": raise RuntimeError( "get_combine_payload_tensor_in_workspace called before a successful dispatch" ) diff --git a/tensorrt_llm/_torch/modules/fused_moe/configurable_moe.py b/tensorrt_llm/_torch/modules/fused_moe/configurable_moe.py index c711babe2b..ce251234e5 100644 --- a/tensorrt_llm/_torch/modules/fused_moe/configurable_moe.py +++ b/tensorrt_llm/_torch/modules/fused_moe/configurable_moe.py @@ -606,9 +606,11 @@ class ConfigurableMoE(MoE): if self.layer_load_balancer and token_selected_experts is not None: self._load_balancer_done_wait_gpu_stage(is_first_call) - # Update EPLB statistics (method depends on whether using NVLINK two-sided) - # Use base class method: ignore_allreduce=True for NVLINK two-sided (uses local stats only) - ignore_allreduce = self._is_using_nvlink_two_sided() + # Update EPLB statistics (method depends on communication strategy) + # Use base class method: ignore_allreduce=True for NVLINK two-sided/one-sided (uses local stats only) + ignore_allreduce = ( + self._is_using_nvlink_two_sided() or self._is_using_nvlink_one_sided() + ) self._load_balancer_update_statistic( token_selected_experts, is_first_call, @@ -628,6 +630,9 @@ class ConfigurableMoE(MoE): # ========== Step 3.5: Communication Prepare Phase (BEFORE quantization) ========== # NVLINK two-sided has a prepare phase to gather EPLB statistics + local_statistic_tensor_for_dispatch = None + eplb_dispatch_kwargs = {} + should_update_eplb_after_dispatch = False # Only NVLINK two-sided needs prepare_dispatch if self._is_using_nvlink_two_sided(): # Get local statistic info if this is the last call and EPLB is enabled @@ -645,6 +650,16 @@ class ConfigurableMoE(MoE): if gathered_stats is not None: gathered_stats = gathered_stats.view((self.mapping.moe_ep_size, self.num_experts)) self._load_balancer_update_statistic_with_gathered_statistic(gathered_stats) + # TODO: The abstract does not work well as NVLinkTwoSided gathers EPLB stats in prepare_dispatch, + # while NVLinkOneSided gathers EPLB stats in dispatch. 
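+            # In the one-sided path the local stats ride along with dispatch via
+            # eplb_dispatch_kwargs and are read back through get_eplb_gathered_statistics()
+            # right after the dispatch call.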
+ elif self._is_using_nvlink_one_sided(): + if self.layer_load_balancer and is_last_call: + local_statistic_tensor_for_dispatch = ( + self._load_balancer_get_local_statistic_tensor() + ) + if local_statistic_tensor_for_dispatch is not None: + eplb_dispatch_kwargs["eplb_local_stats"] = local_statistic_tensor_for_dispatch + should_update_eplb_after_dispatch = True # ========== Step 4 & 5: Quantization and Communication Dispatch ========== # Order depends on whether strategy supports post-quant dispatch @@ -666,11 +681,10 @@ class ConfigurableMoE(MoE): # Step 4b: Dispatch AFTER quantization # Get pre_quant_scale for W4AFP8 if available (only DeepEPLowLatency needs it) # Other strategies will ignore this via **kwargs, so it's safe to pass unconditionally - dispatch_kwargs = {} + dispatch_kwargs = dict(eplb_dispatch_kwargs) if hasattr(self, "quant_scales") and self.quant_scales is not None: if hasattr(self.quant_scales, "pre_quant_scale_1"): dispatch_kwargs["pre_quant_scale"] = self.quant_scales.pre_quant_scale_1 - x, x_sf, token_selected_slots, token_final_scales = self.comm.dispatch( hidden_states=x, hidden_states_sf=x_sf, @@ -680,6 +694,9 @@ class ConfigurableMoE(MoE): use_dp_padding=use_dp_padding, **dispatch_kwargs, ) + if should_update_eplb_after_dispatch: + gathered_stats = self.comm.get_eplb_gathered_statistics() + self._load_balancer_update_statistic_with_gathered_statistic(gathered_stats) else: # ===== Pre-quant flow: Dispatch → Quantize ===== @@ -960,6 +977,10 @@ class ConfigurableMoE(MoE): """Check if using NVLinkTwoSided communication strategy""" return isinstance(self.comm, NVLinkTwoSided) + def _is_using_nvlink_one_sided(self) -> bool: + """Check if using NVLinkOneSided communication strategy""" + return isinstance(self.comm, NVLinkOneSided) + def _get_nvlink_onesided_moe_output( self, all_rank_num_tokens: Optional[List[int]], diff --git a/tensorrt_llm/_torch/modules/fused_moe/fused_moe_cutlass.py b/tensorrt_llm/_torch/modules/fused_moe/fused_moe_cutlass.py index 215a4a63b2..e392bfadee 100755 --- a/tensorrt_llm/_torch/modules/fused_moe/fused_moe_cutlass.py +++ b/tensorrt_llm/_torch/modules/fused_moe/fused_moe_cutlass.py @@ -166,15 +166,22 @@ class CutlassFusedMoE(MoE): dtype = self.dtype or torch.float16 workspace_size = MoeAlltoAll.calculate_required_workspace_size( - ep_size, self.routing_method.experts_per_token, - max_num_tokens, hidden_size, dtype) + ep_size, + self.routing_method.experts_per_token, + max_num_tokens, + hidden_size, + dtype, + self.num_experts if self.layer_load_balancer else None, + ) self.moe_a2a = MoeAlltoAll( mapping=self.mapping, max_num_tokens=model_config.max_num_tokens, top_k=self.routing_method.experts_per_token, - num_experts=self.num_slots, + num_slots=self.num_slots, workspace_size_per_rank=workspace_size, + num_experts=self.num_experts + if self.layer_load_balancer else None, ) elif self.alltoall_method_type == AlltoallMethodType.DeepEP or self.alltoall_method_type == AlltoallMethodType.DeepEPLowLatency: raise NotImplementedError( @@ -509,7 +516,10 @@ class CutlassFusedMoE(MoE): if self.layer_load_balancer: self._load_balancer_done_wait_gpu_stage(is_first_call) - ignore_allreduce = self.alltoall_method_type == AlltoallMethodType.NVLinkTwoSided + ignore_allreduce = self.enable_alltoall and self.alltoall_method_type in ( + AlltoallMethodType.NVLinkTwoSided, + AlltoallMethodType.NVLinkOneSided, + ) self._load_balancer_update_statistic( token_selected_experts, is_first_call, @@ -608,14 +618,32 @@ class CutlassFusedMoE(MoE): 
payloads.append(token_selected_slots) payloads.append(token_final_scales) - recv_tensors = self.moe_a2a.dispatch( - token_selected_slots, - payloads, - runtime_max_tokens_per_rank, - invalid_token_expert_id=self. - num_slots, # Caution: Cutlass MoE uses num_slots as invalid token expert id - expert_id_payload_index=expert_id_payload_index, - ) + loadbalancer_local_statistic_info = None + if self.layer_load_balancer and is_last_call: + loadbalancer_local_statistic_info = self._load_balancer_get_local_statistic_tensor( + ) + if loadbalancer_local_statistic_info is not None: + recv_tensors = self.moe_a2a.dispatch( + token_selected_slots, + payloads, + runtime_max_tokens_per_rank, + invalid_token_expert_id=self. + num_slots, # Caution: Cutlass MoE uses num_slots as invalid token expert id + expert_id_payload_index=expert_id_payload_index, + eplb_local_stats=loadbalancer_local_statistic_info, + ) + gathered_stats = self.moe_a2a._state.eplb_gathered_stats + self._load_balancer_update_statistic_with_gathered_statistic( + gathered_stats) + else: + recv_tensors = self.moe_a2a.dispatch( + token_selected_slots, + payloads, + runtime_max_tokens_per_rank, + invalid_token_expert_id=self. + num_slots, # Caution: Cutlass MoE uses num_slots as invalid token expert id + expert_id_payload_index=expert_id_payload_index, + ) if x_sf is not None: x_recv, x_sf_recv, token_selected_slots_recv, token_final_scales_recv = recv_tensors diff --git a/tensorrt_llm/_torch/modules/fused_moe/fused_moe_trtllm_gen.py b/tensorrt_llm/_torch/modules/fused_moe/fused_moe_trtllm_gen.py index 44b31deb32..879daf4a6d 100644 --- a/tensorrt_llm/_torch/modules/fused_moe/fused_moe_trtllm_gen.py +++ b/tensorrt_llm/_torch/modules/fused_moe/fused_moe_trtllm_gen.py @@ -140,15 +140,22 @@ class TRTLLMGenFusedMoE(MoE): dtype = self.dtype or torch.bfloat16 workspace_size = MoeAlltoAll.calculate_required_workspace_size( - ep_size, self.routing_method.experts_per_token, - max_num_tokens, hidden_size, dtype) + ep_size, + self.routing_method.experts_per_token, + max_num_tokens, + hidden_size, + dtype, + self.num_experts if self.layer_load_balancer else None, + ) self.moe_a2a = MoeAlltoAll( mapping=self.mapping, max_num_tokens=model_config.max_num_tokens, top_k=self.routing_method.experts_per_token, - num_experts=self.num_slots, - workspace_size_per_rank=workspace_size) + num_slots=self.num_slots, + workspace_size_per_rank=workspace_size, + num_experts=self.num_experts + if self.layer_load_balancer else None) elif self.alltoall_method_type == AlltoallMethodType.DeepEP or self.alltoall_method_type == AlltoallMethodType.DeepEPLowLatency: raise NotImplementedError( "DeepEP and DeepEPLowLatency are not supported for TRTLLMGenFusedMoE yet" @@ -690,7 +697,10 @@ class TRTLLMGenFusedMoE(MoE): self._load_balancer_done_wait_gpu_stage(is_first_call) - ignore_allreduce = self.alltoall_method_type == AlltoallMethodType.NVLinkTwoSided + ignore_allreduce = self.enable_alltoall and self.alltoall_method_type in ( + AlltoallMethodType.NVLinkTwoSided, + AlltoallMethodType.NVLinkOneSided, + ) self._load_balancer_update_statistic( token_selected_experts, is_first_call, @@ -778,14 +788,32 @@ class TRTLLMGenFusedMoE(MoE): payloads.append(token_selected_experts) payloads.append(token_final_scales) - recv_tensors = self.moe_a2a.dispatch( - token_selected_experts, - payloads, - runtime_max_tokens_per_rank, - invalid_token_expert_id= - -1, # Caution: TRTLLM-Gen uses -1 as invalid token expert id - expert_id_payload_index=expert_id_payload_index, - ) + 
loadbalancer_local_statistic_info = None + if self.layer_load_balancer and is_last_call: + loadbalancer_local_statistic_info = self._load_balancer_get_local_statistic_tensor( + ) + if loadbalancer_local_statistic_info is not None: + recv_tensors = self.moe_a2a.dispatch( + token_selected_experts, + payloads, + runtime_max_tokens_per_rank, + invalid_token_expert_id= + -1, # Caution: TRTLLM-Gen uses -1 as invalid token expert id + expert_id_payload_index=expert_id_payload_index, + eplb_local_stats=loadbalancer_local_statistic_info, + ) + gathered_stats = self.moe_a2a._state.eplb_gathered_stats + self._load_balancer_update_statistic_with_gathered_statistic( + gathered_stats) + else: + recv_tensors = self.moe_a2a.dispatch( + token_selected_experts, + payloads, + runtime_max_tokens_per_rank, + invalid_token_expert_id= + -1, # Caution: TRTLLM-Gen uses -1 as invalid token expert id + expert_id_payload_index=expert_id_payload_index, + ) if x_sf is not None: x_recv, x_sf_recv, token_selected_experts_recv, token_final_scales_recv = recv_tensors diff --git a/tests/unittest/_torch/multi_gpu/test_moe_a2a.py b/tests/unittest/_torch/multi_gpu/test_moe_a2a.py index 49ed032aff..ddca5fb8ab 100644 --- a/tests/unittest/_torch/multi_gpu/test_moe_a2a.py +++ b/tests/unittest/_torch/multi_gpu/test_moe_a2a.py @@ -56,24 +56,23 @@ def compute_target_rank_id(expert_id, num_experts_per_rank): return expert_id // num_experts_per_rank -def generate_token_selected_experts(local_num_tokens: int, ep_size: int, - num_experts_per_rank: int, +def generate_token_selected_experts(local_num_tokens: int, num_experts: int, top_k: int) -> torch.Tensor: """Generate global expert IDs tensor, aligned with single-GPU test semantics.""" return torch.randint( 0, - ep_size * num_experts_per_rank, + num_experts, (local_num_tokens, top_k), dtype=torch.int32, device='cuda', ) -def create_experts(num_experts_per_rank, - hidden_size, - ep_rank, - device, - dtype=torch.bfloat16): +def create_experts_per_rank(num_experts_per_rank, + hidden_size, + ep_rank, + device, + dtype=torch.bfloat16): """ Create a 3D tensor of expert weights for a given rank. 
@@ -226,9 +225,9 @@ def make_bfloat16_payloads( def run_moe_a2a_dispatch_single_rank(ep_size, all_num_tokens, top_k, - workspace_size_per_rank, - num_experts_per_rank, hidden_size, - invalid_token_expert_id): + workspace_size_per_rank, num_experts, + hidden_size, invalid_token_expert_id, + enable_eplb): """Worker function for MPIPoolExecutor.""" rank = tllm.mpi_rank() torch.cuda.set_device(rank) @@ -244,25 +243,40 @@ def run_moe_a2a_dispatch_single_rank(ep_size, all_num_tokens, top_k, # Create MoeAlltoAll manager max_num_tokens = max(all_num_tokens) - moe_a2a = MoeAlltoAll(mapping, max_num_tokens, top_k, - ep_size * num_experts_per_rank, - workspace_size_per_rank) + eplb_stats_num_experts = ( + num_experts // 2 if enable_eplb else None + ) # Use half of the experts for testing EPLB stats + moe_a2a = MoeAlltoAll( + mapping=mapping, + max_num_tokens=max_num_tokens, + top_k=top_k, + num_slots=num_experts, + workspace_size_per_rank=workspace_size_per_rank, + num_experts=eplb_stats_num_experts, + ) # Get the number of tokens for this specific rank (same as single-GPU) rank_local_tokens = all_num_tokens[rank] # Generate data using helper functions token_selected_experts = generate_token_selected_experts( - rank_local_tokens, ep_size, num_experts_per_rank, top_k) + rank_local_tokens, num_experts, top_k) payloads, expert_id_payload_index = make_nvfp4_payloads( rank_local_tokens, hidden_size, top_k, rank, token_selected_experts) + eplb_local_stats = None + if enable_eplb: + eplb_local_stats = (torch.arange( + eplb_stats_num_experts, dtype=torch.int32, device="cuda") + + rank * 1000) + recv_tensors = moe_a2a.dispatch( token_selected_experts, payloads, max_num_tokens, invalid_token_expert_id=invalid_token_expert_id, - expert_id_payload_index=expert_id_payload_index) + expert_id_payload_index=expert_id_payload_index, + eplb_local_stats=eplb_local_stats) # Verify completion flags after dispatch completion_flags_offset = moe_a2a.metainfo[MoeAlltoAll._METAINFO_INDEX[ @@ -306,10 +320,16 @@ def run_moe_a2a_dispatch_single_rank(ep_size, all_num_tokens, top_k, max_num_tokens, top_k).cpu() # Return results to be collected (move to CPU for MPI transfer) + eplb_gathered_stats = moe_a2a._state.eplb_gathered_stats + if eplb_gathered_stats is not None: + eplb_gathered_stats = eplb_gathered_stats.cpu() + if eplb_local_stats is not None: + eplb_local_stats = eplb_local_stats.cpu() + return (token_selected_experts.cpu(), [p.cpu() for p in payloads], - [rt.cpu() - for rt in recv_tensors], send_counters, topk_send_indices, - topk_target_ranks, recv_counters, expert_id_payload_index) + [rt.cpu() for rt in recv_tensors], send_counters, + topk_send_indices, topk_target_ranks, recv_counters, + expert_id_payload_index, eplb_local_stats, eplb_gathered_stats) except Exception: traceback.print_exc() raise @@ -318,11 +338,12 @@ def run_moe_a2a_dispatch_single_rank(ep_size, all_num_tokens, top_k, def verify_dispatch(all_token_selected_experts, all_payloads, all_recv_tensors, all_send_counters, all_topk_send_indices, all_topk_target_ranks, all_recv_counters, ep_size, - all_num_tokens, top_k, num_experts_per_rank, - expert_id_payload_index, invalid_token_expert_id): + all_num_tokens, top_k, num_experts, expert_id_payload_index, + invalid_token_expert_id): """Verify dispatch results including actual content verification""" max_num_tokens = max(all_num_tokens) + num_experts_per_rank = num_experts // ep_size # Verify dimensions and dtypes for send_rank in range(ep_size): local_num_tokens = all_num_tokens[send_rank] @@ -475,26 +496,32 
@@ class TestMoEAlltoAll: enabled=False ) # MPI pool executors have known thread cleanup timing issues @pytest.mark.parametrize( - "mpi_pool_executor,all_num_tokens,top_k", + "mpi_pool_executor,all_num_tokens,top_k,enable_eplb", [ # (num_workers, all_num_tokens, top_k) # Basic configurations - (4, [32, 32, 32, 32], 2), # Four ranks with uniform distribution + (4, [32, 32, 32, 32], 2, False + ), # Four ranks with uniform distribution (4, [16, 32, 64, 48 - ], 2), # Four ranks with non-uniform distribution - (2, [100, 50], 2), # Two ranks with different loads + ], 2, False), # Four ranks with non-uniform distribution + (2, [100, 50], 2, False), # Two ranks with different loads (8, [10, 20, 30, 40, 50, 60, 70, 80 - ], 2), # Eight ranks with increasing load + ], 2, False), # Eight ranks with increasing load # Different top_k values - (4, [32, 32, 32, 32], 4), # Four ranks with top_k = 4 - (4, [32, 32, 32, 32], 8), # Four ranks with top_k = 8 + (4, [32, 32, 32, 32], 4, False), # Four ranks with top_k = 4 + (4, [32, 32, 32, 32], 8, False), # Four ranks with top_k = 8 # Edge cases - (4, [1, 1, 1, 1], 2), # Four ranks with single token per rank + (4, [1, 1, 1, 1], 2, False + ), # Four ranks with single token per rank + + # EPLB stats path + (4, [32, 32, 32, 32], 2, True), ], indirect=["mpi_pool_executor"]) - def test_dispatch(self, mpi_pool_executor, all_num_tokens, top_k): + def test_dispatch(self, mpi_pool_executor, all_num_tokens, top_k, + enable_eplb): """Test MoE A2A dispatch with MNNVL across multiple GPUs""" try: @@ -511,7 +538,7 @@ class TestMoEAlltoAll: ) >= ep_size, f"Need at least {ep_size} GPUs, found {torch.cuda.device_count()}" hidden_size = 1024 - num_experts_per_rank = 8 + num_experts = 32 # Large enough workspace workspace_size_per_rank = 512 * 1024 * 1024 @@ -523,8 +550,8 @@ class TestMoEAlltoAll: results = mpi_pool_executor.map( run_moe_a2a_dispatch_single_rank, *zip(*[(ep_size, all_num_tokens, top_k, workspace_size_per_rank, - num_experts_per_rank, hidden_size, - invalid_token_expert_id)] * ep_size), + num_experts, hidden_size, invalid_token_expert_id, + enable_eplb)] * ep_size), ) # Collect results from all ranks (same as single-GPU collecting from emulated ranks) @@ -540,6 +567,8 @@ class TestMoEAlltoAll: all_recv_counters = [r[6] for r in all_results] all_expert_id_payload_index = [r[7] for r in all_results] expert_id_payload_index = all_expert_id_payload_index[0] + all_eplb_local_stats = [r[8] for r in all_results] + all_eplb_gathered_stats = [r[9] for r in all_results] assert all(i == expert_id_payload_index for i in all_expert_id_payload_index @@ -550,9 +579,18 @@ class TestMoEAlltoAll: all_recv_tensors, all_send_counters, all_topk_send_indices, all_topk_target_ranks, all_recv_counters, ep_size, all_num_tokens, top_k, - num_experts_per_rank, expert_id_payload_index, + num_experts, expert_id_payload_index, invalid_token_expert_id) + if enable_eplb: + expected_stats = torch.stack(all_eplb_local_stats, dim=0) + for rank in range(ep_size): + gathered_stats = all_eplb_gathered_stats[rank] + assert gathered_stats is not None + assert torch.equal( + gathered_stats, + expected_stats), (f"Rank {rank} gathered_stats mismatch") + @pytest.mark.skipif(torch.cuda.device_count() < 8, reason='needs at least 8 GPUs to run multi-GPU test') @pytest.mark.threadleak(enabled=False) @@ -587,8 +625,9 @@ class TestMoEAlltoAll: assert torch.cuda.device_count( ) >= ep_size, f"Need at least {ep_size} GPUs, found {torch.cuda.device_count()}" - hidden_size = 2880 # gpt-oss - num_experts_per_rank = 8 + 
# gpt-oss-20b + hidden_size = 2880 + num_experts = 32 # Large enough workspace workspace_size_per_rank = 512 * 1024 * 1024 @@ -599,8 +638,8 @@ class TestMoEAlltoAll: results = mpi_pool_executor.map( run_moe_a2a_dispatch_moe_combine_single_rank, *zip(*[(ep_size, all_num_tokens, top_k, workspace_size_per_rank, - num_experts_per_rank, hidden_size, - invalid_token_expert_id)] * ep_size), + num_experts, hidden_size, invalid_token_expert_id)] * + ep_size), ) # Collect results @@ -617,14 +656,12 @@ class TestMoEAlltoAll: # Verify combine results print("Starting verification...") - verify_combine_results(all_results, ep_size, all_num_tokens, top_k, - hidden_size, num_experts_per_rank) + verify_combine(all_results, ep_size) def run_moe_a2a_dispatch_moe_combine_single_rank(ep_size, all_num_tokens, top_k, workspace_size_per_rank, - num_experts_per_rank, - hidden_size, + num_experts, hidden_size, invalid_token_expert_id): """Worker function for dispatch and combine test.""" rank = tllm.mpi_rank() @@ -639,15 +676,19 @@ def run_moe_a2a_dispatch_moe_combine_single_rank(ep_size, all_num_tokens, top_k, world_size=ep_size) # Create MoeAlltoAll manager - moe_a2a = MoeAlltoAll(mapping, max_num_tokens, top_k, - ep_size * num_experts_per_rank, - workspace_size_per_rank) + moe_a2a = MoeAlltoAll( + mapping=mapping, + max_num_tokens=max_num_tokens, + top_k=top_k, + num_slots=num_experts, + workspace_size_per_rank=workspace_size_per_rank, + ) rank_local_tokens = all_num_tokens[rank] # Generate test data - use simpler payload for combine test token_selected_experts = generate_token_selected_experts( - rank_local_tokens, ep_size, num_experts_per_rank, top_k) + rank_local_tokens, num_experts, top_k) payloads, expert_id_payload_index = make_bfloat16_payloads( rank_local_tokens, hidden_size, top_k, rank, token_selected_experts) @@ -670,11 +711,12 @@ def run_moe_a2a_dispatch_moe_combine_single_rank(ep_size, all_num_tokens, top_k, # emulate MoE computation on the received data # Create experts for this rank - rank_experts = create_experts(num_experts_per_rank, - hidden_size, - rank, - device, - dtype=torch.bfloat16) + num_experts_per_rank = num_experts // ep_size + rank_experts = create_experts_per_rank(num_experts_per_rank, + hidden_size, + rank, + device, + dtype=torch.bfloat16) hidden_states_recv = fake_moe( hidden_states_recv.view(ep_size * max_num_tokens, @@ -721,8 +763,7 @@ def run_moe_a2a_dispatch_moe_combine_single_rank(ep_size, all_num_tokens, top_k, raise -def verify_combine_results(all_results, ep_size, all_num_tokens, top_k, - hidden_size, num_experts_per_rank): +def verify_combine(all_results, ep_size): """Verify that combine correctly sums the dispatched tokens.""" # Extract results
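Taken together, the fused statistics path can be exercised roughly as in the sketch below (one process per EP rank; `mapping` is assumed to be a tensorrt_llm Mapping configured for pure EP as in the worker functions above, and the shapes are hypothetical):

import torch
from tensorrt_llm._torch.distributed.moe_alltoall import MoeAlltoAll

num_slots, num_experts, top_k = 32, 16, 2   # routing slots vs. EPLB-tracked experts
max_num_tokens, local_num_tokens = 128, 32

moe_a2a = MoeAlltoAll(
    mapping=mapping,                          # assumed: pure-EP Mapping for this rank
    max_num_tokens=max_num_tokens,
    top_k=top_k,
    num_slots=num_slots,
    workspace_size_per_rank=512 * 1024 * 1024,
    num_experts=num_experts,                  # enables the fused EPLB stats gather
)

token_selected_experts = torch.randint(
    0, num_slots, (local_num_tokens, top_k), dtype=torch.int32, device="cuda")
local_stats = (torch.arange(num_experts, dtype=torch.int32, device="cuda")
               + 1000 * mapping.moe_ep_rank)
recv_tensors = moe_a2a.dispatch(
    token_selected_experts,
    [token_selected_experts],                 # dispatched as a payload, as the op allows
    max_num_tokens,
    eplb_local_stats=local_stats,
)
# Every rank now holds the same [ep_size, num_experts] int32 tensor in which
# row r is rank r's local_stats; no separate AllGather kernel is launched.
gathered = moe_a2a._state.eplb_gathered_stats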