[TRTLLM-10048][feat] Fuse the AllGather for expert statistics required by the EPLB. (#10885)

Signed-off-by: Bo Li <22713281+bobboli@users.noreply.github.com>
Bo Li 2026-01-26 17:59:03 +08:00 committed by GitHub
parent 5efee01da1
commit e405468230
12 changed files with 508 additions and 186 deletions

View File

@ -361,19 +361,15 @@ __global__ void moeA2APrepareDispatchKernel(
}
// ============================================================================
// Generic Dispatch Kernel Implementation
// One warp per token design:
// - Each CTA has 256 threads = 8 warps
// - Each warp independently processes one token and all its payloads
// - Better GPU utilization and reduced synchronization overhead
// Dispatch Kernels
// ============================================================================
template <typename ThreadingPolicy, int TOP_K>
template <typename ThreadingPolicy, int TOP_K, bool ENABLE_EPLB>
__global__ void moeA2ADispatchKernel(int32_t const* token_selected_experts, // [local_num_tokens, TOP_K]
const DispatchKernelPointers ptrs, // Struct containing all kernel pointers
int num_payloads, // Number of payloads
int max_tokens_per_rank, // Maximum tokens per rank
int local_num_tokens, int rank_id, int ep_size, int num_experts_per_rank)
int local_num_tokens, int rank_id, int ep_size, int num_experts, int eplb_stats_num_experts)
{
int thread_idx = ThreadingPolicy::offset();
@ -411,6 +407,7 @@ __global__ void moeA2ADispatchKernel(int32_t const* token_selected_experts, // [
}
uint64_t already_copied = 0;
int num_experts_per_rank = num_experts / ep_size;
for (int k = 0; k < TOP_K; k++)
{
int expert_id = token_selected_experts[local_token_idx * TOP_K + k];
@ -501,6 +498,21 @@ __global__ void moeA2ADispatchKernel(int32_t const* token_selected_experts, // [
ptrs.recv_counters[target_rank][rank_id] = send_count;
}
if constexpr (ENABLE_EPLB)
{
// Write local stats into peer buffers before the release fence below.
#pragma unroll 1
for (int target_rank = 0; target_rank < ep_size; ++target_rank)
{
int* target_stats = ptrs.eplb_gathered_stats[target_rank];
for (int expert_id = lane_id; expert_id < eplb_stats_num_experts; expert_id += warpSize)
{
int stat_val = ptrs.eplb_local_stats[expert_id];
target_stats[rank_id * eplb_stats_num_experts + expert_id] = stat_val;
}
}
}
#if !DISABLE_SYNC_FOR_PROFILING
uint32_t expected_value = *ptrs.flag_val;
@ -588,6 +600,7 @@ void moe_a2a_dispatch_launch(MoeA2ADispatchParams const& params)
for (int target_rank = 0; target_rank < params.ep_size; target_rank++)
{
kernel_ptrs.recv_counters[target_rank] = params.recv_counters[target_rank];
kernel_ptrs.eplb_gathered_stats[target_rank] = params.eplb_gathered_stats[target_rank];
for (int payload = 0; payload < params.num_payloads; payload++)
{
kernel_ptrs.recv_buffers[target_rank][payload] = params.recv_buffers[target_rank][payload];
@ -606,6 +619,7 @@ void moe_a2a_dispatch_launch(MoeA2ADispatchParams const& params)
kernel_ptrs.local_token_counter = params.local_token_counter;
kernel_ptrs.topk_target_ranks = params.topk_target_ranks;
kernel_ptrs.topk_send_indices = params.topk_send_indices;
kernel_ptrs.eplb_local_stats = params.eplb_local_stats;
int const kBlockSize = tensorrt_llm::common::getEnvMoeA2ADispatchBlockSize();
constexpr int kWarpSize = 32;
@ -621,10 +635,12 @@ void moe_a2a_dispatch_launch(MoeA2ADispatchParams const& params)
grid_size = 1;
}
int shared_bytes = 2 * params.top_k * (int) sizeof(int);
SWITCH_TOP_K(params.top_k, TOP_K,
moeA2ADispatchKernel<BlockPolicy, TOP_K><<<grid_size, kBlockSize, shared_bytes, params.stream>>>(
params.token_selected_experts, kernel_ptrs, params.num_payloads, params.max_tokens_per_rank,
params.local_num_tokens, params.ep_rank, params.ep_size, params.num_experts_per_rank))
SWITCH_BOOL(params.enable_eplb, EPLB_STATS,
SWITCH_TOP_K(params.top_k, TOP_K,
moeA2ADispatchKernel<BlockPolicy, TOP_K, EPLB_STATS>
<<<grid_size, kBlockSize, shared_bytes, params.stream>>>(params.token_selected_experts, kernel_ptrs,
params.num_payloads, params.max_tokens_per_rank, params.local_num_tokens, params.ep_rank,
params.ep_size, params.num_experts, params.eplb_stats_num_experts)))
}
else
{
@ -635,10 +651,12 @@ void moe_a2a_dispatch_launch(MoeA2ADispatchParams const& params)
grid_size = 1;
}
int shared_bytes = 2 * kWarpsPerBlock * params.top_k * (int) sizeof(int);
SWITCH_TOP_K(params.top_k, TOP_K,
moeA2ADispatchKernel<WarpPolicy, TOP_K><<<grid_size, kBlockSize, shared_bytes, params.stream>>>(
params.token_selected_experts, kernel_ptrs, params.num_payloads, params.max_tokens_per_rank,
params.local_num_tokens, params.ep_rank, params.ep_size, params.num_experts_per_rank))
SWITCH_BOOL(params.enable_eplb, EPLB_STATS,
SWITCH_TOP_K(params.top_k, TOP_K,
moeA2ADispatchKernel<WarpPolicy, TOP_K, EPLB_STATS>
<<<grid_size, kBlockSize, shared_bytes, params.stream>>>(params.token_selected_experts, kernel_ptrs,
params.num_payloads, params.max_tokens_per_rank, params.local_num_tokens, params.ep_rank,
params.ep_size, params.num_experts, params.eplb_stats_num_experts)))
}
}
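
Note on the new ENABLE_EPLB path above: each rank pushes its local expert counters into row `rank_id` of every peer's gathered-stats buffer before the release fence that publishes the token payloads, so no separate AllGather collective is needed afterwards. A minimal plain-PyTorch emulation of the resulting data movement (illustrative CPU tensors and sizes, not the CUDA kernel):

```python
import torch

# Emulation of the fused EPLB AllGather inside moeA2ADispatchKernel
# (illustrative only; the real kernel writes through NVLink-mapped
# peer buffers before the dispatch release fence).
ep_size, eplb_stats_num_experts = 4, 8

# Per-rank local expert-load counters, shape [eplb_stats_num_experts].
local_stats = [
    torch.randint(0, 100, (eplb_stats_num_experts, ), dtype=torch.int32)
    for _ in range(ep_size)
]

# Each rank owns a gathered buffer of shape [ep_size, eplb_stats_num_experts].
gathered = [
    torch.zeros(ep_size, eplb_stats_num_experts, dtype=torch.int32)
    for _ in range(ep_size)
]

# During dispatch, rank `rank_id` writes its row into every peer's buffer.
for rank_id in range(ep_size):
    for target_rank in range(ep_size):
        gathered[target_rank][rank_id].copy_(local_stats[rank_id])

# Once the dispatch completion flags are observed, every rank holds the full
# AllGather result without a separate collective.
expected = torch.stack(local_stats)
assert all(torch.equal(g, expected) for g in gathered)
```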
@ -980,24 +998,24 @@ __global__ void moeA2APrepareCombineKernel(uint8_t* recv_buffer_bytes, uint8_t c
if (payload_bytes == nullptr)
return;
int slot_idx = ThreadingPolicy::token_idx();
int global_token_idx = ThreadingPolicy::token_idx();
int total_slots = ep_size * max_tokens_per_rank;
if (slot_idx >= total_slots)
int global_token_num = ep_size * max_tokens_per_rank;
if (global_token_idx >= global_token_num)
return;
// Map global token to (source_rank, token_idx)
int source_rank = slot_idx / max_tokens_per_rank;
int token_idx = slot_idx % max_tokens_per_rank;
// Map global_token_idx to (rank_idx, local_token_idx)
int rank_idx = global_token_idx / max_tokens_per_rank;
int local_token_idx = global_token_idx % max_tokens_per_rank;
// Skip invalid tokens beyond per-source recv count
if (token_idx >= recv_counters[source_rank])
// Skip invalid tokens beyond per-rank recv count
if (local_token_idx >= recv_counters[rank_idx])
return;
// Calculate source and destination pointers for this token
size_t slot_offset = static_cast<size_t>(slot_idx) * bytes_per_token;
uint8_t* dst_ptr = recv_buffer_bytes + slot_offset;
uint8_t const* src_ptr = payload_bytes + slot_offset;
size_t offset = static_cast<size_t>(global_token_idx) * bytes_per_token;
uint8_t* dst_ptr = recv_buffer_bytes + offset;
uint8_t const* src_ptr = payload_bytes + offset;
// Copy one token's data using vectorized copy with policy
vectorized_copy<ThreadingPolicy>(dst_ptr, src_ptr, bytes_per_token);
@ -1118,9 +1136,9 @@ void moe_a2a_prepare_combine_launch(MoeA2ACombineParams const& params)
}
int bytes_per_token = params.elements_per_token * element_size;
int total_slots = params.prepare_payload == nullptr ? 1 : params.ep_size * params.max_tokens_per_rank;
int grid_size_warp = ceilDiv(total_slots, kWarpsPerBlock);
int grid_size_block = total_slots; // one block per token
int global_token_num = params.prepare_payload == nullptr ? 1 : params.ep_size * params.max_tokens_per_rank;
int grid_size_warp = ceilDiv(global_token_num, kWarpsPerBlock);
int grid_size_block = global_token_num; // one block per token
if (params.one_block_per_token)
{

View File

@ -61,6 +61,10 @@ struct DispatchKernelPointers
// Top-K compact routing info per local token (size: [local_num_tokens, top_k])
int* topk_target_ranks; // target rank per k, -1 for duplicates
int* topk_send_indices; // dst index per k, -1 for duplicates
// Optional: Statistics for EPLB
int const* eplb_local_stats; // [eplb_stats_num_experts]
int* eplb_gathered_stats[kMaxRanks]; // [ep_size, eplb_stats_num_experts] per rank
};
// Combine kernel pointers - non-const output in src_data_ptrs[0], const recv buffers
@ -83,13 +87,13 @@ struct CombineKernelPointers
// Dispatch phase parameters
struct MoeA2ADispatchParams
{
// Threading policy
bool one_block_per_token; // True: one block per token, False: one warp per token
// Threading policy
// EP configuration
int ep_size; // Number of EP ranks
int ep_rank; // Current EP rank
int num_experts_per_rank; // Number of experts per rank (num_experts / ep_size)
int ep_size; // Number of EP ranks
int ep_rank; // Current EP rank
int num_experts; // Total number of experts
// Token configuration
int local_num_tokens; // Number of tokens on this rank
@ -118,6 +122,12 @@ struct MoeA2ADispatchParams
// rank has signaled the target rank
void* recv_buffers[kMaxRanks][kMaxPayloads]; // Per-rank receive buffers for each payload
// Optional: Statistics for EPLB
bool enable_eplb; // Whether to enable EPLB
int eplb_stats_num_experts; // Number of experts for EPLB stats
int const* eplb_local_stats; // [eplb_stats_num_experts]
int* eplb_gathered_stats[kMaxRanks]; // [ep_size, eplb_stats_num_experts] per rank
// CUDA stream
cudaStream_t stream;
};

View File

@ -43,8 +43,9 @@ enum MoeA2AMetaInfoIndex : int64_t
COMBINE_COMPLETION_FLAGS_OFFSET_INDEX = 5,
TOPK_TARGET_RANKS_OFFSET_INDEX = 6,
TOPK_SEND_INDICES_OFFSET_INDEX = 7,
PAYLOAD_DATA_OFFSET_INDEX = 8,
NUM_METAINFO_FIELDS = 9
EPLB_GATHERED_STATS_OFFSET_INDEX = 8,
PAYLOAD_DATA_OFFSET_INDEX = 9,
NUM_METAINFO_FIELDS = 10
};
using MoeA2ADataOffsets = std::array<int64_t, NUM_METAINFO_FIELDS>;
@ -60,6 +61,7 @@ inline std::vector<std::pair<char const*, int64_t>> getMoeA2AMetaInfoIndexPairs(
{"MOE_A2A_COMBINE_COMPLETION_FLAGS_OFFSET_INDEX", COMBINE_COMPLETION_FLAGS_OFFSET_INDEX},
{"MOE_A2A_TOPK_TARGET_RANKS_OFFSET_INDEX", TOPK_TARGET_RANKS_OFFSET_INDEX},
{"MOE_A2A_TOPK_SEND_INDICES_OFFSET_INDEX", TOPK_SEND_INDICES_OFFSET_INDEX},
{"MOE_A2A_EPLB_GATHERED_STATS_OFFSET_INDEX", EPLB_GATHERED_STATS_OFFSET_INDEX},
{"MOE_A2A_PAYLOAD_DATA_OFFSET_INDEX", PAYLOAD_DATA_OFFSET_INDEX},
{"MOE_A2A_NUM_METAINFO_FIELDS", NUM_METAINFO_FIELDS},
};

View File

@ -43,7 +43,7 @@ inline size_t alignOffset(size_t offset, size_t alignment)
}
// Calculate auxiliary data offsets
MoeA2ADataOffsets calculateOffsets(int epSize, int maxNumTokens)
MoeA2ADataOffsets calculateOffsets(int epSize, int maxNumTokens, int eplbStatsNumExperts)
{
// TODO: Use lambdas to encapsulate offset and alignment for each entry, which is less error prone and easier to
// read.
@ -90,6 +90,11 @@ MoeA2ADataOffsets calculateOffsets(int epSize, int maxNumTokens)
offset += static_cast<size_t>(maxNumTokens) * static_cast<size_t>(tensorrt_llm::kernels::moe_comm::kMaxTopK)
* SIZEOF_INT32;
// eplb gathered stats: [epSize, eplbStatsNumExperts]
offset = alignOffset(offset, CACHELINE_ALIGNMENT);
offsets[EPLB_GATHERED_STATS_OFFSET_INDEX] = offset;
offset += static_cast<size_t>(epSize) * static_cast<size_t>(eplbStatsNumExperts) * SIZEOF_INT32;
// payload data
offset = alignOffset(offset, CACHELINE_ALIGNMENT);
offsets[PAYLOAD_DATA_OFFSET_INDEX] = offset;
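
For reference, the new region reserves `epSize * eplbStatsNumExperts` int32 entries in the auxiliary data, cache-line aligned, between the top-k send indices and the payload data. A small sketch of the arithmetic, where the 128-byte alignment is an assumed stand-in for CACHELINE_ALIGNMENT:

```python
SIZEOF_INT32 = 4
CACHELINE_ALIGNMENT = 128  # assumed value of the C++ constant


def align_offset(offset: int, alignment: int) -> int:
    # Round offset up to the next multiple of alignment.
    return (offset + alignment - 1) // alignment * alignment


def eplb_gathered_stats_region(offset: int, ep_size: int, eplb_stats_num_experts: int):
    # Mirrors the block added to calculateOffsets(): align, record the start,
    # then advance by ep_size * eplb_stats_num_experts int32 entries.
    start = align_offset(offset, CACHELINE_ALIGNMENT)
    end = start + ep_size * eplb_stats_num_experts * SIZEOF_INT32
    return start, end


# Example: 8 ranks, 256 experts tracked for EPLB, previous region ending at byte 1000.
print(eplb_gathered_stats_region(1000, ep_size=8, eplb_stats_num_experts=256))  # (1024, 9216)
```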
@ -105,10 +110,12 @@ MoeA2ADataOffsets calculateOffsets(int epSize, int maxNumTokens)
// - epRank: Current expert parallel rank
// - epSize: Total expert parallel size
// - maxNumTokens: Maximum number of tokens supported
// - eplbStatsNumExperts: (Optional) Number of experts used for EPLB stats
//
// Returns:
// - metainfo: Tensor containing offsets for auxiliary data
torch::Tensor moeA2AInitializeOp(torch::Tensor const& workspace, int64_t epRank, int64_t epSize, int64_t maxNumTokens)
torch::Tensor moeA2AInitializeOp(torch::Tensor const& workspace, int64_t epRank, int64_t epSize, int64_t maxNumTokens,
torch::optional<int64_t> eplbStatsNumExperts)
{
// Validate inputs
CHECK_TH_CUDA(workspace);
@ -120,8 +127,11 @@ torch::Tensor moeA2AInitializeOp(torch::Tensor const& workspace, int64_t epRank,
// Initialize workspace to zero
workspace[epRank].zero_();
int64_t eplbStatsNumExpertsValue = eplbStatsNumExperts.value_or(0);
TORCH_CHECK(eplbStatsNumExpertsValue >= 0, "eplbStatsNumExperts must be non-negative if provided.");
// Calculate auxiliary data offsets
MoeA2ADataOffsets offsets = calculateOffsets(epSize, maxNumTokens);
MoeA2ADataOffsets offsets = calculateOffsets(epSize, maxNumTokens, static_cast<int>(eplbStatsNumExpertsValue));
// Return metainfo as a tensor containing offsets
torch::Tensor metainfo = torch::empty(
@ -155,18 +165,23 @@ torch::Tensor moeA2AInitializeOp(torch::Tensor const& workspace, int64_t epRank,
// - epRank: Current expert parallel rank
// - epSize: Total expert parallel size
// - topK: Number of experts selected per token
// - numExperts: Total number of experts (must be divisible by epSize)
// - numExperts: Total number of routing slots (tokenSelectedExperts values are in [0, numExperts))
// - eplbStatsNumExperts: Number of experts used for EPLB stats (may be <= numExperts)
// - eplbLocalStats: [eplbStatsNumExperts] tensor containing local statistics for EPLB.
//
// Return values:
// - recvTensors: Vector of receive buffers (one tensor per payload), each [ep_size, runtimeMaxTokensPerRank,
// elements_per_token]
// - combinePayloadOffset: Offset into workspace for the combine payload region, to be used by the combine operation
// - eplbGatheredStats: (Optional) [ep_size, eplbStatsNumExperts] tensor containing gathered statistics for EPLB, or
// an empty tensor if eplbLocalStats is None.
//
// Note: token_selected_experts is used for routing but is NOT automatically included as a payload.
// If you want to dispatch token_selected_experts, include it explicitly in inputPayloads.
std::tuple<std::vector<torch::Tensor>, int64_t> moeA2ADispatchOp(torch::Tensor const& tokenSelectedExperts,
std::vector<torch::Tensor> const& inputPayloads, torch::Tensor const& workspace, torch::Tensor const& metainfo,
int64_t runtimeMaxTokensPerRank, int64_t epRank, int64_t epSize, int64_t topK, int64_t numExperts)
std::tuple<std::vector<torch::Tensor>, int64_t, torch::Tensor> moeA2ADispatchOp(
torch::Tensor const& tokenSelectedExperts, std::vector<torch::Tensor> const& inputPayloads,
torch::Tensor const& workspace, torch::Tensor const& metainfo, int64_t runtimeMaxTokensPerRank, int64_t epRank,
int64_t epSize, int64_t topK, int64_t numExperts, torch::optional<torch::Tensor> eplbLocalStats)
{
using tensorrt_llm::kernels::moe_comm::PayloadDescriptor;
using tensorrt_llm::kernels::moe_comm::MoeA2ADispatchParams;
@ -194,6 +209,19 @@ std::tuple<std::vector<torch::Tensor>, int64_t> moeA2ADispatchOp(torch::Tensor c
TORCH_CHECK(inputPayloads.size() <= kMaxPayloads, "Too many input payloads");
TORCH_CHECK(numExperts >= epSize, "numExperts must be greater than or equal to epSize");
TORCH_CHECK(numExperts % epSize == 0, "numExperts must be divisible by epSize for contiguous partitioning");
bool enableEplb = eplbLocalStats.has_value();
int64_t eplbStatsNumExperts = 0;
if (enableEplb)
{
TORCH_CHECK(eplbLocalStats.has_value(), "enable_eplb requires eplb_local_stats");
torch::Tensor const& eplbLocalStatsTensor = eplbLocalStats.value();
eplbStatsNumExperts = eplbLocalStatsTensor.size(0);
TORCH_CHECK(eplbStatsNumExperts > 0, "eplb_local_stats must not be empty");
TORCH_CHECK(eplbStatsNumExperts <= numExperts, "eplb_local_stats size must be <= numExperts (slots)");
CHECK_INPUT(eplbLocalStatsTensor, torch::kInt32);
TORCH_CHECK(eplbLocalStatsTensor.is_contiguous(), "eplb_local_stats must be contiguous");
TORCH_CHECK(eplbLocalStatsTensor.dim() == 1, "eplb_local_stats must be a 1D tensor");
}
// All input payloads must have the same first dimension (localNumTokens)
for (auto const& payload : inputPayloads)
@ -280,10 +308,12 @@ std::tuple<std::vector<torch::Tensor>, int64_t> moeA2ADispatchOp(torch::Tensor c
= tensorrt_llm::common::getEnvMoeA2AOneBlockPerToken(); // TODO: Decide this based on the workload
params.ep_size = static_cast<int>(epSize);
params.ep_rank = static_cast<int>(epRank);
params.num_experts_per_rank = static_cast<int>(numExperts) / static_cast<int>(epSize);
params.num_experts = static_cast<int>(numExperts);
params.local_num_tokens = static_cast<int>(localNumTokens);
params.max_tokens_per_rank = static_cast<int>(runtimeMaxTokensPerRank);
params.top_k = static_cast<int>(topK);
params.enable_eplb = enableEplb;
params.eplb_stats_num_experts = static_cast<int>(eplbStatsNumExperts);
params.token_selected_experts = tokenSelectedExperts.data_ptr<int32_t>();
@ -304,6 +334,15 @@ std::tuple<std::vector<torch::Tensor>, int64_t> moeA2ADispatchOp(torch::Tensor c
= reinterpret_cast<int*>(targetWorkSpacePtr + offsets[RECV_COUNTERS_OFFSET_INDEX]);
params.completion_flags[target_rank]
= reinterpret_cast<uint32_t*>(targetWorkSpacePtr + offsets[DISPATCH_COMPLETION_FLAGS_OFFSET_INDEX]);
if (enableEplb)
{
params.eplb_gathered_stats[target_rank]
= reinterpret_cast<int*>(targetWorkSpacePtr + offsets[EPLB_GATHERED_STATS_OFFSET_INDEX]);
}
else
{
params.eplb_gathered_stats[target_rank] = nullptr;
}
for (int payload_idx = 0; payload_idx < num_payloads; payload_idx++)
{
@ -312,6 +351,15 @@ std::tuple<std::vector<torch::Tensor>, int64_t> moeA2ADispatchOp(torch::Tensor c
}
}
if (enableEplb)
{
params.eplb_local_stats = eplbLocalStats.value().data_ptr<int32_t>();
}
else
{
params.eplb_local_stats = nullptr;
}
params.stream = at::cuda::getCurrentCUDAStream();
// Prepare for dispatch (zero counters/indices and increment flag_val)
@ -336,7 +384,20 @@ std::tuple<std::vector<torch::Tensor>, int64_t> moeA2ADispatchOp(torch::Tensor c
// Compute aligned offset after dispatch payloads for combine payload region
int64_t combinePayloadOffset = static_cast<int64_t>(alignOffset(currentOffset, CACHELINE_ALIGNMENT));
return std::make_tuple(std::move(recvTensors), combinePayloadOffset);
torch::Tensor eplbGatheredStats;
if (enableEplb)
{
int* gatheredStatsPtr = reinterpret_cast<int*>(rankWorkSpacePtr + offsets[EPLB_GATHERED_STATS_OFFSET_INDEX]);
auto statsOptions = workspace.options().dtype(torch::kInt32);
eplbGatheredStats = torch::from_blob(
gatheredStatsPtr, {static_cast<int64_t>(epSize), static_cast<int64_t>(eplbStatsNumExperts)}, statsOptions);
}
else
{
eplbGatheredStats = torch::empty({0}, workspace.options().dtype(torch::kInt32));
}
return std::make_tuple(std::move(recvTensors), combinePayloadOffset, std::move(eplbGatheredStats));
}
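
The `eplbGatheredStats` output above is a zero-copy view into the calling rank's workspace at the EPLB_GATHERED_STATS offset (via `torch::from_blob`), or an empty int32 tensor when EPLB is disabled. A plain-PyTorch illustration of such a view, with made-up offset and sizes:

```python
import torch

ep_size, eplb_stats_num_experts = 4, 8
workspace = torch.zeros(1 << 16, dtype=torch.uint8)  # stand-in for the MNNVL workspace
offset = 4096  # stands in for offsets[EPLB_GATHERED_STATS_OFFSET_INDEX]

nbytes = ep_size * eplb_stats_num_experts * 4  # int32 elements
eplb_gathered_stats = (workspace[offset:offset + nbytes]
                       .view(torch.int32)
                       .view(ep_size, eplb_stats_num_experts))

eplb_gathered_stats[0, :] = 7  # writes land directly in the workspace bytes
assert int(workspace[offset]) == 7  # low byte of the first int32 on a little-endian host
```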
// MoE All-to-All Combine Operation
@ -525,9 +586,12 @@ torch::Tensor moeA2AGetCombinePayloadTensorOp(torch::Tensor const& workspace, in
}
// Return the size of auxiliary data in workspace
int64_t moeA2AGetAuxDataSizeOp(int64_t epSize, int64_t maxNumTokens)
int64_t moeA2AGetAuxDataSizeOp(int64_t epSize, int64_t maxNumTokens, torch::optional<int64_t> eplbStatsNumExperts)
{
MoeA2ADataOffsets offsets = calculateOffsets(static_cast<int>(epSize), static_cast<int>(maxNumTokens));
int64_t eplbStatsNumExpertsValue = eplbStatsNumExperts.value_or(0);
TORCH_CHECK(eplbStatsNumExpertsValue >= 0, "eplbStatsNumExperts must be non-negative if provided.");
MoeA2ADataOffsets offsets = calculateOffsets(
static_cast<int>(epSize), static_cast<int>(maxNumTokens), static_cast<int>(eplbStatsNumExpertsValue));
return static_cast<int64_t>(offsets[PAYLOAD_DATA_OFFSET_INDEX]);
}
@ -546,14 +610,16 @@ TORCH_LIBRARY_FRAGMENT(trtllm, module)
module.def(
"moe_a2a_dispatch(Tensor token_selected_experts, Tensor[] input_payloads, "
"Tensor(a!->*) workspace, Tensor metainfo, int runtime_max_tokens_per_rank, "
"int ep_rank, int ep_size, int top_k, int num_experts) -> (Tensor(a!)[], int)");
"int ep_rank, int ep_size, int top_k, int num_experts, "
"Tensor? eplb_local_stats=None) -> (Tensor(a!)[], int, Tensor(a!))");
module.def(
"moe_a2a_combine(Tensor(a) payload, int local_num_tokens,"
"Tensor(a!) workspace, Tensor metainfo, int runtime_max_tokens_per_rank, "
"int ep_rank, int ep_size, int top_k, int combine_payload_offset, "
"bool payload_in_workspace) -> Tensor");
module.def(
"moe_a2a_initialize(Tensor(a!) workspace, int ep_rank, int ep_size, int max_num_tokens_per_rank) -> Tensor");
"moe_a2a_initialize(Tensor(a!) workspace, int ep_rank, int ep_size, int max_num_tokens_per_rank, "
"int? eplb_stats_num_experts=None) -> Tensor");
module.def(
"moe_a2a_sanitize_expert_ids(Tensor(a!) expert_ids, Tensor(a!) workspace, Tensor metainfo, int ep_rank, int "
"invalid_expert_id) -> ()");
@ -561,7 +627,7 @@ TORCH_LIBRARY_FRAGMENT(trtllm, module)
"moe_a2a_get_combine_payload_tensor(Tensor(a) workspace, int ep_rank, int ep_size, int "
"runtime_max_tokens_per_rank, "
"int combine_payload_offset, ScalarType out_dtype, int hidden_size) -> Tensor(a)");
module.def("moe_a2a_get_aux_data_size(int ep_size, int max_num_tokens) -> int",
module.def("moe_a2a_get_aux_data_size(int ep_size, int max_num_tokens, int? eplb_stats_num_experts=None) -> int",
&tensorrt_llm::torch_ext::moe_comm::moeA2AGetAuxDataSizeOp);
}
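
For context, a call-shape sketch of the `moe_a2a_dispatch` schema as registered above. It assumes a TensorRT-LLM build containing this change and omits the workspace/metainfo setup performed by `moe_a2a_initialize`; the wrapper name is invented for illustration:

```python
from typing import List, Optional, Tuple

import torch


def moe_a2a_dispatch_with_eplb(
        token_selected_experts: torch.Tensor,
        input_payloads: List[torch.Tensor],
        workspace: torch.Tensor,
        metainfo: torch.Tensor,
        runtime_max_tokens_per_rank: int,
        ep_rank: int,
        ep_size: int,
        top_k: int,
        num_experts: int,
        eplb_local_stats: Optional[torch.Tensor] = None,
) -> Tuple[List[torch.Tensor], int, Optional[torch.Tensor]]:
    # Thin wrapper over the registered op. The op now returns a third tensor with
    # the gathered EPLB stats ([ep_size, eplb_stats_num_experts]); it comes back
    # empty when eplb_local_stats is None, which we normalize to None here.
    recv_tensors, combine_payload_offset, eplb_gathered_stats = (
        torch.ops.trtllm.moe_a2a_dispatch(
            token_selected_experts, input_payloads, workspace, metainfo,
            runtime_max_tokens_per_rank, ep_rank, ep_size, top_k, num_experts,
            eplb_local_stats))
    if eplb_gathered_stats.numel() == 0:
        eplb_gathered_stats = None
    return recv_tensors, combine_payload_offset, eplb_gathered_stats
```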

View File

@ -25,6 +25,7 @@ class _A2AState:
phase: str = "idle" # idle | dispatched
local_num_tokens: int | None = None
combine_payload_offset: int | None = None
eplb_gathered_stats: torch.Tensor | None = None
class MoeAlltoAll:
@ -41,9 +42,13 @@ class MoeAlltoAll:
_METAINFO_INDEX: Dict[str, int] | None = None
@staticmethod
def get_aux_data_size(ep_size: int, max_num_tokens: int) -> int:
def get_aux_data_size(
ep_size: int,
max_num_tokens: int,
eplb_stats_num_experts: Optional[int] = None,
) -> int:
return torch.ops.trtllm.moe_a2a_get_aux_data_size(
ep_size, max_num_tokens)
ep_size, max_num_tokens, eplb_stats_num_experts)
@staticmethod
def calculate_required_workspace_size(
@ -52,11 +57,13 @@ class MoeAlltoAll:
max_num_tokens: int,
hidden_size: int,
dtype: torch.dtype,
eplb_stats_num_experts: Optional[int] = None,
extra_payload_bytes_per_token: int = 0) -> int:
element_size = dtype.itemsize
# Auxiliary data size
workspace_size = MoeAlltoAll.get_aux_data_size(ep_size, max_num_tokens)
workspace_size = MoeAlltoAll.get_aux_data_size(ep_size, max_num_tokens,
eplb_stats_num_experts)
# Dispatch needs workspace for [ep_size, max_tokens] tokens,
# but due to the variety of quantization recipes, we cannot know the exact size, so we conservatively estimate assuming no quantization.
@ -103,6 +110,8 @@ class MoeAlltoAll:
int(thop.MOE_A2A_TOPK_TARGET_RANKS_OFFSET_INDEX),
"TOPK_SEND_INDICES_OFFSET_INDEX":
int(thop.MOE_A2A_TOPK_SEND_INDICES_OFFSET_INDEX),
"EPLB_GATHERED_STATS_OFFSET_INDEX":
int(thop.MOE_A2A_EPLB_GATHERED_STATS_OFFSET_INDEX),
"PAYLOAD_DATA_OFFSET_INDEX":
int(thop.MOE_A2A_PAYLOAD_DATA_OFFSET_INDEX),
"NUM_METAINFO_FIELDS":
@ -114,8 +123,9 @@ class MoeAlltoAll:
mapping: Mapping,
max_num_tokens: int,
top_k: int,
num_experts: int,
num_slots: int,
workspace_size_per_rank: int,
num_experts: Optional[int] = None,
):
"""
Initialize MoeAlltoAll with workspace allocation.
@ -124,6 +134,10 @@ class MoeAlltoAll:
mapping: TensorRT-LLM Mapping object containing rank information
max_num_tokens: Maximum number of tokens supported. Should be ModelConfig.max_num_tokens.
workspace_size_per_rank: Size of workspace per rank in bytes
num_slots: Number of routing slots (token_selected_experts values are in [0, num_slots)).
Note: referred to as `num_experts` inside this class and the kernels.
num_experts: (Optional) Number of experts for EPLB stats (must be <= num_slots). DO NOT provide this parameter if EPLB is not enabled.
Note: referred to as `eplb_stats_num_experts` inside this class and the kernels.
"""
# Check for environment variable override
workspace_mb_env = os.environ.get("TRTLLM_MOE_A2A_WORKSPACE_MB")
@ -148,42 +162,52 @@ class MoeAlltoAll:
self.ep_rank = mapping.moe_ep_rank
self.top_k = top_k
self.num_experts = num_experts
self.num_experts = num_slots
if not isinstance(self.top_k, int) or self.top_k <= 0:
raise ValueError("top_k must be a positive int")
if not isinstance(self.num_experts, int) or self.num_experts <= 0:
raise ValueError("num_experts must be a positive int")
raise ValueError("num_slots must be a positive int")
if num_experts is not None:
assert num_experts > 0 and num_experts <= num_slots, "num_experts must be in (0, num_slots]"
tllm_logger.info(
"NVLinkOneSided AlltoAll: EPLB is enabled, with num_slots="
f"{num_slots} and num_experts={num_experts}")
self.enable_eplb = num_experts is not None
self.eplb_stats_num_experts = num_experts
if self._WORKSPACE is None:
tllm_logger.info(
f"nvlink_one_sided AlltoAll: Allocating workspace with size {workspace_size_per_rank} bytes. ep_rank: {self.ep_rank}, ep_size: {self.ep_size}, max_num_tokens: {self.max_num_tokens}"
f"NVLinkOneSided AlltoAll: Allocating workspace with size {workspace_size_per_rank} bytes. ep_rank: {self.ep_rank}, ep_size: {self.ep_size}, max_num_tokens: {self.max_num_tokens}"
)
mnnvl_mem = MnnvlMemory(mapping, workspace_size_per_rank)
workspace = mnnvl_mem.as_torch_strided_tensor(torch.uint8)
metainfo = torch.ops.trtllm.moe_a2a_initialize(
workspace,
self.ep_rank,
self.ep_size,
self.max_num_tokens,
)
workspace, self.ep_rank, self.ep_size, self.max_num_tokens,
self.eplb_stats_num_experts)
MoeAlltoAll._WORKSPACE = {
"workspace_size_per_rank": workspace_size_per_rank,
"max_num_tokens": self.max_num_tokens,
"ep_rank": self.ep_rank,
"ep_size": self.ep_size,
"eplb_stats_num_experts": self.eplb_stats_num_experts,
"mnnvl_mem": mnnvl_mem,
"workspace": workspace,
"metainfo": metainfo,
}
else:
assert self._WORKSPACE[
"workspace_size_per_rank"] == workspace_size_per_rank, "reuse workspace with different workspace_size_per_rank"
"workspace_size_per_rank"] == workspace_size_per_rank, "mistakenly reusing workspace with different workspace_size_per_rank"
assert self._WORKSPACE[
"max_num_tokens"] == self.max_num_tokens, "reuse workspace with different max_num_tokens"
"max_num_tokens"] == self.max_num_tokens, "mistakenly reusing workspace with different max_num_tokens"
assert self._WORKSPACE[
"ep_rank"] == self.ep_rank, "reuse workspace with different ep_rank"
"ep_rank"] == self.ep_rank, "mistakenly reusing workspace with different ep_rank"
assert self._WORKSPACE[
"ep_size"] == self.ep_size, "reuse workspace with different ep_size"
"ep_size"] == self.ep_size, "mistakenly reusing workspace with different ep_size"
assert self._WORKSPACE[
"eplb_stats_num_experts"] == self.eplb_stats_num_experts, (
"reuse workspace with different eplb_stats_num_experts")
self.mnnvl_mem = self._WORKSPACE["mnnvl_mem"]
self.workspace = self._WORKSPACE["workspace"]
@ -196,7 +220,8 @@ class MoeAlltoAll:
input_payloads: list[torch.Tensor],
runtime_max_tokens_per_rank: int,
invalid_token_expert_id: Optional[int] = None,
expert_id_payload_index: Optional[int] = None):
expert_id_payload_index: Optional[int] = None,
eplb_local_stats: Optional[torch.Tensor] = None):
"""
Perform MoE all-to-all dispatch operation.
@ -206,19 +231,40 @@ class MoeAlltoAll:
runtime_max_tokens_per_rank: Maximum of the number of tokens of each DP rank's local batch.
invalid_token_expert_id: If not None, set the token_selected_experts of the invalid tokens to this expert id. This is used to notify the MoE to skip these tokens for GroupGEMM.
expert_id_payload_index: The index of token_selected_experts in the input_payloads. Must be provided if invalid_token_expert_id is not None.
eplb_local_stats: (Optional) [num_experts] tensor containing local statistics for EPLB
Returns:
recv_tensors: List of tensors received, each has shape [ep_size, max_tokens_per_rank, payload_num_elements_per_token]
"""
assert self._state.phase == "idle", "dispatch called twice without an intervening combine"
assert runtime_max_tokens_per_rank <= self.max_num_tokens, "runtime_max_tokens_per_rank must not exceed max_num_tokens"
recv_tensors, combine_payload_offset = torch.ops.trtllm.moe_a2a_dispatch(
token_selected_experts, input_payloads, self.workspace,
self.metainfo, runtime_max_tokens_per_rank, self.ep_rank,
self.ep_size, self.top_k, self.num_experts)
if eplb_local_stats is not None:
assert self.enable_eplb, "eplb_local_stats provided but enable_eplb is False"
assert eplb_local_stats.dim(
) == 1, "eplb_local_stats must be a 1D tensor"
assert eplb_local_stats.size(
0
) == self.eplb_stats_num_experts, "eplb_local_stats size must match eplb_stats_num_experts"
recv_tensors, combine_payload_offset, eplb_gathered_stats = torch.ops.trtllm.moe_a2a_dispatch(
token_selected_experts,
input_payloads,
self.workspace,
self.metainfo,
runtime_max_tokens_per_rank,
self.ep_rank,
self.ep_size,
self.top_k,
self.num_experts,
eplb_local_stats,
)
if eplb_gathered_stats.numel() == 0:
eplb_gathered_stats = None
# Update state together after successful dispatch
self._state.local_num_tokens = token_selected_experts.size(0)
self._state.combine_payload_offset = combine_payload_offset
self._state.eplb_gathered_stats = eplb_gathered_stats
self._state.phase = "dispatched"
if invalid_token_expert_id is not None:

View File

@ -185,3 +185,9 @@ class Communication(ABC):
Combined output tensor [local_num_tokens, hidden_size]
"""
raise NotImplementedError
def get_eplb_gathered_statistics(self) -> Optional[torch.Tensor]:
"""
Return gathered EPLB statistics from the last dispatch, if available.
"""
raise NotImplementedError
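
A self-contained mock (not the TensorRT-LLM classes) of the contract this method introduces: a strategy caches whatever stats were gathered during `dispatch` and serves them from `get_eplb_gathered_statistics` until the next dispatch/combine cycle:

```python
from typing import Optional

import torch


class MockOneSidedComm:
    """Mock strategy illustrating the EPLB stats contract (not TensorRT-LLM code)."""

    def __init__(self, ep_size: int) -> None:
        self.ep_size = ep_size
        self._gathered: Optional[torch.Tensor] = None

    def dispatch(self, eplb_local_stats: Optional[torch.Tensor] = None) -> None:
        # Stand-in for the fused gather: the real kernel collects one row per rank;
        # this single-process mock just replicates the local row.
        self._gathered = (None if eplb_local_stats is None else
                          eplb_local_stats.repeat(self.ep_size, 1))

    def get_eplb_gathered_statistics(self) -> Optional[torch.Tensor]:
        # [ep_size, eplb_stats_num_experts] from the last dispatch, or None.
        return self._gathered


comm = MockOneSidedComm(ep_size=4)
comm.dispatch(eplb_local_stats=torch.arange(8, dtype=torch.int32))
print(comm.get_eplb_gathered_statistics().shape)  # torch.Size([4, 8])
```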

View File

@ -76,6 +76,7 @@ class CommunicationFactory:
expert_size_per_partition: Number of experts per partition (required for DeepEP)
payload_in_workspace: If True, final_hidden_states is already in workspace (for NVLinkOneSided)
alltoall_result_do_sum: If True, sum the alltoall results (for NVLinkTwoSided)
# TODO: Need a way to indicate whether EPLB is enabled.
Returns:
The selected communication method, or None if attention does not use DP
@ -122,6 +123,7 @@ class CommunicationFactory:
# Priority: NVLinkOneSided > NVLinkTwoSided > DeepEP > DeepEPLowLatency > AllGather
try:
enable_eplb = model_config.moe_load_balancer is not None
strategy = NVLinkOneSided(
mapping,
num_slots,
@ -130,6 +132,7 @@ class CommunicationFactory:
payload_in_workspace,
hidden_size=hidden_size,
dtype=act_dtype,
num_experts=num_experts if enable_eplb else None,
)
logger.info("Selected communication strategy: NVLinkOneSided")
return strategy
@ -231,6 +234,7 @@ class CommunicationFactory:
alltoall_result_do_sum=alltoall_result_do_sum,
)
elif method in ["NVLINK_ONE_SIDED"]:
enable_eplb = model_config.moe_load_balancer is not None
return NVLinkOneSided(
mapping,
num_slots,
@ -239,6 +243,7 @@ class CommunicationFactory:
payload_in_workspace,
hidden_size=hidden_size,
dtype=act_dtype,
num_experts=num_experts if enable_eplb else None,
)
elif method == "DEEPEP":
return DeepEP(

View File

@ -65,11 +65,18 @@ class NVLinkOneSided(Communication):
RECV_COUNTERS_OFFSET_INDEX = None
DISPATCH_COMPLETION_FLAGS_OFFSET_INDEX = None
COMBINE_COMPLETION_FLAGS_OFFSET_INDEX = None
EPLB_GATHERED_STATS_OFFSET_INDEX = None
PAYLOAD_DATA_OFFSET_INDEX = None
@staticmethod
def get_aux_data_size(ep_size: int, max_num_tokens: int) -> int:
return torch.ops.trtllm.moe_a2a_get_aux_data_size(ep_size, max_num_tokens)
def get_aux_data_size(
ep_size: int,
max_num_tokens: int,
eplb_stats_num_experts: Optional[int] = None,
) -> int:
return torch.ops.trtllm.moe_a2a_get_aux_data_size(
ep_size, max_num_tokens, eplb_stats_num_experts
)
@staticmethod
def calculate_required_workspace_size(
@ -78,12 +85,15 @@ class NVLinkOneSided(Communication):
max_num_tokens: int,
hidden_size: int,
dtype: torch.dtype,
eplb_stats_num_experts: Optional[int] = None,
extra_payload_bytes_per_token: int = 0,
) -> int:
element_size = dtype.itemsize
# Auxiliary data size
workspace_size = NVLinkOneSided.get_aux_data_size(ep_size, max_num_tokens)
workspace_size = NVLinkOneSided.get_aux_data_size(
ep_size, max_num_tokens, eplb_stats_num_experts
)
# Dispatch needs workspace for [ep_size, max_tokens] tokens,
# but due to the variety of quantization recipes, we cannot know the exact size, so we conservatively estimate assuming no quantization.
@ -97,13 +107,12 @@ class NVLinkOneSided(Communication):
# token_final_scales
workspace_size += ep_size * max_num_tokens * top_k * 4
workspace_size = pad_up(workspace_size, 128)
# extra payload bytes per token
workspace_size += ep_size * max_num_tokens * extra_payload_bytes_per_token
workspace_size = pad_up(workspace_size, 128)
# Required workspace for combine [ep_size, max_tokens] tokens
workspace_size += ep_size * max_num_tokens * hidden_size * element_size
workspace_size = pad_up(workspace_size, 128)
# extra payload bytes per token
workspace_size += ep_size * max_num_tokens * extra_payload_bytes_per_token
workspace_size = pad_up(workspace_size, 128)
return workspace_size
@ -124,29 +133,36 @@ class NVLinkOneSided(Communication):
cls.COMBINE_COMPLETION_FLAGS_OFFSET_INDEX = int(
thop.MOE_A2A_COMBINE_COMPLETION_FLAGS_OFFSET_INDEX
)
cls.EPLB_GATHERED_STATS_OFFSET_INDEX = int(
thop.MOE_A2A_EPLB_GATHERED_STATS_OFFSET_INDEX
)
cls.PAYLOAD_DATA_OFFSET_INDEX = int(thop.MOE_A2A_PAYLOAD_DATA_OFFSET_INDEX)
def __init__(
self,
mapping: Mapping,
num_experts: int,
num_slots: int,
top_k: int,
max_num_tokens_per_rank: int,
payload_in_workspace: bool = False,
hidden_size: Optional[int] = None,
dtype: Optional[torch.dtype] = None,
num_experts: Optional[int] = None,
):
"""
Initialize NVLinkOneSided with workspace allocation.
Args:
mapping: TensorRT-LLM Mapping object containing rank information
num_experts: Total number of experts
num_slots: Number of routing slots (token_selected_experts values are in [0, num_slots)).
Note: referred to as `num_experts` inside this class and the kernels.
top_k: Number of experts per token
max_num_tokens_per_rank: Maximum number of tokens per rank (for workspace allocation)
payload_in_workspace: If True, final_hidden_states is already in workspace
hidden_size: Hidden dimension size (optional, for auto workspace calculation)
dtype: Data type (optional, for auto workspace calculation)
num_experts: (Optional) Number of experts for EPLB stats (must be <= num_slots). DO NOT provide this parameter if EPLB is not enabled.
Note: referred to as `eplb_stats_num_experts` inside this class and the kernels.
"""
super().__init__(mapping)
@ -154,10 +170,20 @@ class NVLinkOneSided(Communication):
raise RuntimeError("Currently NVLinkOneSided only supports pure EP for MoE.")
# Store needed parameters
self.num_experts = num_experts
self.num_experts = num_slots
self.top_k = top_k
self.max_num_tokens_per_rank = max_num_tokens_per_rank
self.payload_in_workspace = payload_in_workspace
if num_experts is not None:
assert num_experts > 0 and num_experts <= num_slots, (
"num_experts must be in (0, num_slots]"
)
tllm_logger.info(
"NVLinkOneSided AlltoAll: EPLB is enabled, with num_slots="
f"{num_slots} and num_experts={num_experts}"
)
self.enable_eplb = num_experts is not None
self.eplb_stats_num_experts = num_experts
# Initialize constants from C++
self._init_constants()
@ -166,7 +192,12 @@ class NVLinkOneSided(Communication):
auto_workspace_size = None
if hidden_size is not None and dtype is not None:
auto_workspace_size = self.calculate_required_workspace_size(
self.ep_size, self.top_k, max_num_tokens_per_rank, hidden_size, dtype
self.ep_size,
self.top_k,
max_num_tokens_per_rank,
hidden_size,
dtype,
eplb_stats_num_experts=self.eplb_stats_num_experts,
)
workspace_mb_env = os.environ.get("TRTLLM_MOE_A2A_WORKSPACE_MB")
if workspace_mb_env:
@ -200,12 +231,14 @@ class NVLinkOneSided(Communication):
self.ep_rank,
self.ep_size,
self.max_num_tokens_per_rank,
self.eplb_stats_num_experts,
)
NVLinkOneSided._WORKSPACE = {
"workspace_size_per_rank": self.workspace_size_per_rank,
"max_num_tokens_per_rank": self.max_num_tokens_per_rank,
"ep_rank": self.ep_rank,
"ep_size": self.ep_size,
"eplb_stats_num_experts": self.eplb_stats_num_experts,
"mnnvl_mem": mnnvl_mem,
"workspace": workspace,
"metainfo": metainfo,
@ -223,6 +256,9 @@ class NVLinkOneSided(Communication):
assert self._WORKSPACE["ep_size"] == self.ep_size, (
"reuse workspace with different ep_size"
)
assert self._WORKSPACE["eplb_stats_num_experts"] == self.eplb_stats_num_experts, (
"reuse workspace with different eplb_stats_num_experts"
)
self.mnnvl_mem = self._WORKSPACE["mnnvl_mem"]
self.workspace = self._WORKSPACE["workspace"]
@ -230,10 +266,7 @@ class NVLinkOneSided(Communication):
self.max_num_tokens_per_rank = self._WORKSPACE["max_num_tokens_per_rank"]
# Initialize dispatch state
self._dispatch_state = {}
# Internal state
self._state: str = "idle" # idle | dispatched
self._dispatch_state = {"phase": "idle"}
# Invalid token expert ID (defaults to -1); the TRTLLM-Gen kernels are hard-coded to support -1 only.
self.invalid_token_expert_id: int = -1
@ -286,7 +319,7 @@ class NVLinkOneSided(Communication):
Tuple of (hidden_states, hidden_states_sf, token_selected_slots, token_final_scales)
Each tensor has shape [ep_size, max_tokens_per_rank, ...]
"""
if self._state == "dispatched":
if self._dispatch_state.get("phase") == "dispatched":
raise RuntimeError("dispatch called twice without an intervening combine")
# Calculate runtime_max_tokens_per_rank from all_rank_num_tokens
@ -304,23 +337,35 @@ class NVLinkOneSided(Communication):
if token_final_scales is not None:
payloads.append(token_final_scales)
recv_buffers, combine_payload_offset = torch.ops.trtllm.moe_a2a_dispatch(
token_selected_slots,
payloads,
self.workspace,
self.moe_a2a_metainfo,
runtime_max_tokens_per_rank,
self.ep_rank,
self.ep_size,
self.top_k,
self.num_experts,
eplb_local_stats = kwargs.get("eplb_local_stats")
if eplb_local_stats is not None:
assert self.enable_eplb, "eplb_local_stats provided but enable_eplb is False"
assert eplb_local_stats.dim() == 1, "eplb_local_stats must be a 1D tensor"
assert eplb_local_stats.size(0) == self.eplb_stats_num_experts, (
"eplb_local_stats size must match eplb_stats_num_experts"
)
recv_buffers, combine_payload_offset, eplb_gathered_stats = (
torch.ops.trtllm.moe_a2a_dispatch(
token_selected_slots,
payloads,
self.workspace,
self.moe_a2a_metainfo,
runtime_max_tokens_per_rank,
self.ep_rank,
self.ep_size,
self.top_k,
self.num_experts,
eplb_local_stats,
)
)
self._state = "dispatched"
if eplb_gathered_stats.numel() == 0:
eplb_gathered_stats = None
self._dispatch_state["eplb_gathered_stats"] = eplb_gathered_stats
self._dispatch_state["combine_payload_offset"] = int(combine_payload_offset)
self._dispatch_state["local_num_tokens"] = token_selected_slots.size(0)
self._dispatch_state["runtime_max_tokens_per_rank"] = runtime_max_tokens_per_rank
self._dispatch_state["phase"] = "dispatched"
# Extract results from recv_buffers
# Payload order matches input:
@ -363,6 +408,13 @@ class NVLinkOneSided(Communication):
token_final_scales_recv,
)
def get_eplb_gathered_statistics(self) -> Optional[torch.Tensor]:
"""
Return gathered EPLB statistics from the last dispatch, if available.
"""
assert self.enable_eplb, "EPLB is not enabled"
return self._dispatch_state.get("eplb_gathered_stats")
def combine(
self,
final_hidden_states: torch.Tensor,
@ -380,7 +432,7 @@ class NVLinkOneSided(Communication):
Combined output tensor [local_num_tokens, hidden_size]
"""
if self._state != "dispatched":
if self._dispatch_state.get("phase") != "dispatched":
raise RuntimeError("combine called before a successful dispatch")
local_num_tokens = self._dispatch_state.get("local_num_tokens")
@ -423,8 +475,7 @@ class NVLinkOneSided(Communication):
)
# Reset state for next round
self._state = "idle"
self._dispatch_state.clear()
self._dispatch_state = {"phase": "idle"}
return output
@ -444,7 +495,7 @@ class NVLinkOneSided(Communication):
Returns:
Tensor view into workspace [ep_size, max_tokens_per_rank, hidden_size]
"""
if self._state != "dispatched":
if self._dispatch_state.get("phase") != "dispatched":
raise RuntimeError(
"get_combine_payload_tensor_in_workspace called before a successful dispatch"
)

View File

@ -606,9 +606,11 @@ class ConfigurableMoE(MoE):
if self.layer_load_balancer and token_selected_experts is not None:
self._load_balancer_done_wait_gpu_stage(is_first_call)
# Update EPLB statistics (method depends on whether using NVLINK two-sided)
# Use base class method: ignore_allreduce=True for NVLINK two-sided (uses local stats only)
ignore_allreduce = self._is_using_nvlink_two_sided()
# Update EPLB statistics (method depends on communication strategy)
# Use base class method: ignore_allreduce=True for NVLINK two-sided/one-sided (uses local stats only)
ignore_allreduce = (
self._is_using_nvlink_two_sided() or self._is_using_nvlink_one_sided()
)
self._load_balancer_update_statistic(
token_selected_experts,
is_first_call,
@ -628,6 +630,9 @@ class ConfigurableMoE(MoE):
# ========== Step 3.5: Communication Prepare Phase (BEFORE quantization) ==========
# NVLINK two-sided has a prepare phase to gather EPLB statistics
local_statistic_tensor_for_dispatch = None
eplb_dispatch_kwargs = {}
should_update_eplb_after_dispatch = False
# Only NVLINK two-sided needs prepare_dispatch
if self._is_using_nvlink_two_sided():
# Get local statistic info if this is the last call and EPLB is enabled
@ -645,6 +650,16 @@ class ConfigurableMoE(MoE):
if gathered_stats is not None:
gathered_stats = gathered_stats.view((self.mapping.moe_ep_size, self.num_experts))
self._load_balancer_update_statistic_with_gathered_statistic(gathered_stats)
# TODO: The abstraction does not hold cleanly here, as NVLinkTwoSided gathers EPLB stats in prepare_dispatch,
# while NVLinkOneSided gathers EPLB stats in dispatch.
elif self._is_using_nvlink_one_sided():
if self.layer_load_balancer and is_last_call:
local_statistic_tensor_for_dispatch = (
self._load_balancer_get_local_statistic_tensor()
)
if local_statistic_tensor_for_dispatch is not None:
eplb_dispatch_kwargs["eplb_local_stats"] = local_statistic_tensor_for_dispatch
should_update_eplb_after_dispatch = True
# ========== Step 4 & 5: Quantization and Communication Dispatch ==========
# Order depends on whether strategy supports post-quant dispatch
@ -666,11 +681,10 @@ class ConfigurableMoE(MoE):
# Step 4b: Dispatch AFTER quantization
# Get pre_quant_scale for W4AFP8 if available (only DeepEPLowLatency needs it)
# Other strategies will ignore this via **kwargs, so it's safe to pass unconditionally
dispatch_kwargs = {}
dispatch_kwargs = dict(eplb_dispatch_kwargs)
if hasattr(self, "quant_scales") and self.quant_scales is not None:
if hasattr(self.quant_scales, "pre_quant_scale_1"):
dispatch_kwargs["pre_quant_scale"] = self.quant_scales.pre_quant_scale_1
x, x_sf, token_selected_slots, token_final_scales = self.comm.dispatch(
hidden_states=x,
hidden_states_sf=x_sf,
@ -680,6 +694,9 @@ class ConfigurableMoE(MoE):
use_dp_padding=use_dp_padding,
**dispatch_kwargs,
)
if should_update_eplb_after_dispatch:
gathered_stats = self.comm.get_eplb_gathered_statistics()
self._load_balancer_update_statistic_with_gathered_statistic(gathered_stats)
else:
# ===== Pre-quant flow: Dispatch → Quantize =====
@ -960,6 +977,10 @@ class ConfigurableMoE(MoE):
"""Check if using NVLinkTwoSided communication strategy"""
return isinstance(self.comm, NVLinkTwoSided)
def _is_using_nvlink_one_sided(self) -> bool:
"""Check if using NVLinkOneSided communication strategy"""
return isinstance(self.comm, NVLinkOneSided)
def _get_nvlink_onesided_moe_output(
self,
all_rank_num_tokens: Optional[List[int]],

View File

@ -166,15 +166,22 @@ class CutlassFusedMoE(MoE):
dtype = self.dtype or torch.float16
workspace_size = MoeAlltoAll.calculate_required_workspace_size(
ep_size, self.routing_method.experts_per_token,
max_num_tokens, hidden_size, dtype)
ep_size,
self.routing_method.experts_per_token,
max_num_tokens,
hidden_size,
dtype,
self.num_experts if self.layer_load_balancer else None,
)
self.moe_a2a = MoeAlltoAll(
mapping=self.mapping,
max_num_tokens=model_config.max_num_tokens,
top_k=self.routing_method.experts_per_token,
num_experts=self.num_slots,
num_slots=self.num_slots,
workspace_size_per_rank=workspace_size,
num_experts=self.num_experts
if self.layer_load_balancer else None,
)
elif self.alltoall_method_type == AlltoallMethodType.DeepEP or self.alltoall_method_type == AlltoallMethodType.DeepEPLowLatency:
raise NotImplementedError(
@ -509,7 +516,10 @@ class CutlassFusedMoE(MoE):
if self.layer_load_balancer:
self._load_balancer_done_wait_gpu_stage(is_first_call)
ignore_allreduce = self.alltoall_method_type == AlltoallMethodType.NVLinkTwoSided
ignore_allreduce = self.enable_alltoall and self.alltoall_method_type in (
AlltoallMethodType.NVLinkTwoSided,
AlltoallMethodType.NVLinkOneSided,
)
self._load_balancer_update_statistic(
token_selected_experts,
is_first_call,
@ -608,14 +618,32 @@ class CutlassFusedMoE(MoE):
payloads.append(token_selected_slots)
payloads.append(token_final_scales)
recv_tensors = self.moe_a2a.dispatch(
token_selected_slots,
payloads,
runtime_max_tokens_per_rank,
invalid_token_expert_id=self.
num_slots, # Caution: Cutlass MoE uses num_slots as invalid token expert id
expert_id_payload_index=expert_id_payload_index,
)
loadbalancer_local_statistic_info = None
if self.layer_load_balancer and is_last_call:
loadbalancer_local_statistic_info = self._load_balancer_get_local_statistic_tensor(
)
if loadbalancer_local_statistic_info is not None:
recv_tensors = self.moe_a2a.dispatch(
token_selected_slots,
payloads,
runtime_max_tokens_per_rank,
invalid_token_expert_id=self.
num_slots, # Caution: Cutlass MoE uses num_slots as invalid token expert id
expert_id_payload_index=expert_id_payload_index,
eplb_local_stats=loadbalancer_local_statistic_info,
)
gathered_stats = self.moe_a2a._state.eplb_gathered_stats
self._load_balancer_update_statistic_with_gathered_statistic(
gathered_stats)
else:
recv_tensors = self.moe_a2a.dispatch(
token_selected_slots,
payloads,
runtime_max_tokens_per_rank,
invalid_token_expert_id=self.
num_slots, # Caution: Cutlass MoE uses num_slots as invalid token expert id
expert_id_payload_index=expert_id_payload_index,
)
if x_sf is not None:
x_recv, x_sf_recv, token_selected_slots_recv, token_final_scales_recv = recv_tensors

View File

@ -140,15 +140,22 @@ class TRTLLMGenFusedMoE(MoE):
dtype = self.dtype or torch.bfloat16
workspace_size = MoeAlltoAll.calculate_required_workspace_size(
ep_size, self.routing_method.experts_per_token,
max_num_tokens, hidden_size, dtype)
ep_size,
self.routing_method.experts_per_token,
max_num_tokens,
hidden_size,
dtype,
self.num_experts if self.layer_load_balancer else None,
)
self.moe_a2a = MoeAlltoAll(
mapping=self.mapping,
max_num_tokens=model_config.max_num_tokens,
top_k=self.routing_method.experts_per_token,
num_experts=self.num_slots,
workspace_size_per_rank=workspace_size)
num_slots=self.num_slots,
workspace_size_per_rank=workspace_size,
num_experts=self.num_experts
if self.layer_load_balancer else None)
elif self.alltoall_method_type == AlltoallMethodType.DeepEP or self.alltoall_method_type == AlltoallMethodType.DeepEPLowLatency:
raise NotImplementedError(
"DeepEP and DeepEPLowLatency are not supported for TRTLLMGenFusedMoE yet"
@ -690,7 +697,10 @@ class TRTLLMGenFusedMoE(MoE):
self._load_balancer_done_wait_gpu_stage(is_first_call)
ignore_allreduce = self.alltoall_method_type == AlltoallMethodType.NVLinkTwoSided
ignore_allreduce = self.enable_alltoall and self.alltoall_method_type in (
AlltoallMethodType.NVLinkTwoSided,
AlltoallMethodType.NVLinkOneSided,
)
self._load_balancer_update_statistic(
token_selected_experts,
is_first_call,
@ -778,14 +788,32 @@ class TRTLLMGenFusedMoE(MoE):
payloads.append(token_selected_experts)
payloads.append(token_final_scales)
recv_tensors = self.moe_a2a.dispatch(
token_selected_experts,
payloads,
runtime_max_tokens_per_rank,
invalid_token_expert_id=
-1, # Caution: TRTLLM-Gen uses -1 as invalid token expert id
expert_id_payload_index=expert_id_payload_index,
)
loadbalancer_local_statistic_info = None
if self.layer_load_balancer and is_last_call:
loadbalancer_local_statistic_info = self._load_balancer_get_local_statistic_tensor(
)
if loadbalancer_local_statistic_info is not None:
recv_tensors = self.moe_a2a.dispatch(
token_selected_experts,
payloads,
runtime_max_tokens_per_rank,
invalid_token_expert_id=
-1, # Caution: TRTLLM-Gen uses -1 as invalid token expert id
expert_id_payload_index=expert_id_payload_index,
eplb_local_stats=loadbalancer_local_statistic_info,
)
gathered_stats = self.moe_a2a._state.eplb_gathered_stats
self._load_balancer_update_statistic_with_gathered_statistic(
gathered_stats)
else:
recv_tensors = self.moe_a2a.dispatch(
token_selected_experts,
payloads,
runtime_max_tokens_per_rank,
invalid_token_expert_id=
-1, # Caution: TRTLLM-Gen uses -1 as invalid token expert id
expert_id_payload_index=expert_id_payload_index,
)
if x_sf is not None:
x_recv, x_sf_recv, token_selected_experts_recv, token_final_scales_recv = recv_tensors

View File

@ -56,24 +56,23 @@ def compute_target_rank_id(expert_id, num_experts_per_rank):
return expert_id // num_experts_per_rank
def generate_token_selected_experts(local_num_tokens: int, ep_size: int,
num_experts_per_rank: int,
def generate_token_selected_experts(local_num_tokens: int, num_experts: int,
top_k: int) -> torch.Tensor:
"""Generate global expert IDs tensor, aligned with single-GPU test semantics."""
return torch.randint(
0,
ep_size * num_experts_per_rank,
num_experts,
(local_num_tokens, top_k),
dtype=torch.int32,
device='cuda',
)
def create_experts(num_experts_per_rank,
hidden_size,
ep_rank,
device,
dtype=torch.bfloat16):
def create_experts_per_rank(num_experts_per_rank,
hidden_size,
ep_rank,
device,
dtype=torch.bfloat16):
"""
Create a 3D tensor of expert weights for a given rank.
@ -226,9 +225,9 @@ def make_bfloat16_payloads(
def run_moe_a2a_dispatch_single_rank(ep_size, all_num_tokens, top_k,
workspace_size_per_rank,
num_experts_per_rank, hidden_size,
invalid_token_expert_id):
workspace_size_per_rank, num_experts,
hidden_size, invalid_token_expert_id,
enable_eplb):
"""Worker function for MPIPoolExecutor."""
rank = tllm.mpi_rank()
torch.cuda.set_device(rank)
@ -244,25 +243,40 @@ def run_moe_a2a_dispatch_single_rank(ep_size, all_num_tokens, top_k,
# Create MoeAlltoAll manager
max_num_tokens = max(all_num_tokens)
moe_a2a = MoeAlltoAll(mapping, max_num_tokens, top_k,
ep_size * num_experts_per_rank,
workspace_size_per_rank)
eplb_stats_num_experts = (
num_experts // 2 if enable_eplb else None
) # Use half of the experts for testing EPLB stats
moe_a2a = MoeAlltoAll(
mapping=mapping,
max_num_tokens=max_num_tokens,
top_k=top_k,
num_slots=num_experts,
workspace_size_per_rank=workspace_size_per_rank,
num_experts=eplb_stats_num_experts,
)
# Get the number of tokens for this specific rank (same as single-GPU)
rank_local_tokens = all_num_tokens[rank]
# Generate data using helper functions
token_selected_experts = generate_token_selected_experts(
rank_local_tokens, ep_size, num_experts_per_rank, top_k)
rank_local_tokens, num_experts, top_k)
payloads, expert_id_payload_index = make_nvfp4_payloads(
rank_local_tokens, hidden_size, top_k, rank, token_selected_experts)
eplb_local_stats = None
if enable_eplb:
eplb_local_stats = (torch.arange(
eplb_stats_num_experts, dtype=torch.int32, device="cuda") +
rank * 1000)
recv_tensors = moe_a2a.dispatch(
token_selected_experts,
payloads,
max_num_tokens,
invalid_token_expert_id=invalid_token_expert_id,
expert_id_payload_index=expert_id_payload_index)
expert_id_payload_index=expert_id_payload_index,
eplb_local_stats=eplb_local_stats)
# Verify completion flags after dispatch
completion_flags_offset = moe_a2a.metainfo[MoeAlltoAll._METAINFO_INDEX[
@ -306,10 +320,16 @@ def run_moe_a2a_dispatch_single_rank(ep_size, all_num_tokens, top_k,
max_num_tokens, top_k).cpu()
# Return results to be collected (move to CPU for MPI transfer)
eplb_gathered_stats = moe_a2a._state.eplb_gathered_stats
if eplb_gathered_stats is not None:
eplb_gathered_stats = eplb_gathered_stats.cpu()
if eplb_local_stats is not None:
eplb_local_stats = eplb_local_stats.cpu()
return (token_selected_experts.cpu(), [p.cpu() for p in payloads],
[rt.cpu()
for rt in recv_tensors], send_counters, topk_send_indices,
topk_target_ranks, recv_counters, expert_id_payload_index)
[rt.cpu() for rt in recv_tensors], send_counters,
topk_send_indices, topk_target_ranks, recv_counters,
expert_id_payload_index, eplb_local_stats, eplb_gathered_stats)
except Exception:
traceback.print_exc()
raise
@ -318,11 +338,12 @@ def run_moe_a2a_dispatch_single_rank(ep_size, all_num_tokens, top_k,
def verify_dispatch(all_token_selected_experts, all_payloads, all_recv_tensors,
all_send_counters, all_topk_send_indices,
all_topk_target_ranks, all_recv_counters, ep_size,
all_num_tokens, top_k, num_experts_per_rank,
expert_id_payload_index, invalid_token_expert_id):
all_num_tokens, top_k, num_experts, expert_id_payload_index,
invalid_token_expert_id):
"""Verify dispatch results including actual content verification"""
max_num_tokens = max(all_num_tokens)
num_experts_per_rank = num_experts // ep_size
# Verify dimensions and dtypes
for send_rank in range(ep_size):
local_num_tokens = all_num_tokens[send_rank]
@ -475,26 +496,32 @@ class TestMoEAlltoAll:
enabled=False
) # MPI pool executors have known thread cleanup timing issues
@pytest.mark.parametrize(
"mpi_pool_executor,all_num_tokens,top_k",
"mpi_pool_executor,all_num_tokens,top_k,enable_eplb",
[
# (num_workers, all_num_tokens, top_k)
# Basic configurations
(4, [32, 32, 32, 32], 2), # Four ranks with uniform distribution
(4, [32, 32, 32, 32], 2, False
), # Four ranks with uniform distribution
(4, [16, 32, 64, 48
], 2), # Four ranks with non-uniform distribution
(2, [100, 50], 2), # Two ranks with different loads
], 2, False), # Four ranks with non-uniform distribution
(2, [100, 50], 2, False), # Two ranks with different loads
(8, [10, 20, 30, 40, 50, 60, 70, 80
], 2), # Eight ranks with increasing load
], 2, False), # Eight ranks with increasing load
# Different top_k values
(4, [32, 32, 32, 32], 4), # Four ranks with top_k = 4
(4, [32, 32, 32, 32], 8), # Four ranks with top_k = 8
(4, [32, 32, 32, 32], 4, False), # Four ranks with top_k = 4
(4, [32, 32, 32, 32], 8, False), # Four ranks with top_k = 8
# Edge cases
(4, [1, 1, 1, 1], 2), # Four ranks with single token per rank
(4, [1, 1, 1, 1], 2, False
), # Four ranks with single token per rank
# EPLB stats path
(4, [32, 32, 32, 32], 2, True),
],
indirect=["mpi_pool_executor"])
def test_dispatch(self, mpi_pool_executor, all_num_tokens, top_k):
def test_dispatch(self, mpi_pool_executor, all_num_tokens, top_k,
enable_eplb):
"""Test MoE A2A dispatch with MNNVL across multiple GPUs"""
try:
@ -511,7 +538,7 @@ class TestMoEAlltoAll:
) >= ep_size, f"Need at least {ep_size} GPUs, found {torch.cuda.device_count()}"
hidden_size = 1024
num_experts_per_rank = 8
num_experts = 32
# Large enough workspace
workspace_size_per_rank = 512 * 1024 * 1024
@ -523,8 +550,8 @@ class TestMoEAlltoAll:
results = mpi_pool_executor.map(
run_moe_a2a_dispatch_single_rank,
*zip(*[(ep_size, all_num_tokens, top_k, workspace_size_per_rank,
num_experts_per_rank, hidden_size,
invalid_token_expert_id)] * ep_size),
num_experts, hidden_size, invalid_token_expert_id,
enable_eplb)] * ep_size),
)
# Collect results from all ranks (same as single-GPU collecting from emulated ranks)
@ -540,6 +567,8 @@ class TestMoEAlltoAll:
all_recv_counters = [r[6] for r in all_results]
all_expert_id_payload_index = [r[7] for r in all_results]
expert_id_payload_index = all_expert_id_payload_index[0]
all_eplb_local_stats = [r[8] for r in all_results]
all_eplb_gathered_stats = [r[9] for r in all_results]
assert all(i == expert_id_payload_index
for i in all_expert_id_payload_index
@ -550,9 +579,18 @@ class TestMoEAlltoAll:
all_recv_tensors, all_send_counters,
all_topk_send_indices, all_topk_target_ranks,
all_recv_counters, ep_size, all_num_tokens, top_k,
num_experts_per_rank, expert_id_payload_index,
num_experts, expert_id_payload_index,
invalid_token_expert_id)
if enable_eplb:
expected_stats = torch.stack(all_eplb_local_stats, dim=0)
for rank in range(ep_size):
gathered_stats = all_eplb_gathered_stats[rank]
assert gathered_stats is not None
assert torch.equal(
gathered_stats,
expected_stats), (f"Rank {rank} gathered_stats mismatch")
@pytest.mark.skipif(torch.cuda.device_count() < 8,
reason='needs at least 8 GPUs to run multi-GPU test')
@pytest.mark.threadleak(enabled=False)
@ -587,8 +625,9 @@ class TestMoEAlltoAll:
assert torch.cuda.device_count(
) >= ep_size, f"Need at least {ep_size} GPUs, found {torch.cuda.device_count()}"
hidden_size = 2880 # gpt-oss
num_experts_per_rank = 8
# gpt-oss-20b
hidden_size = 2880
num_experts = 32
# Large enough workspace
workspace_size_per_rank = 512 * 1024 * 1024
@ -599,8 +638,8 @@ class TestMoEAlltoAll:
results = mpi_pool_executor.map(
run_moe_a2a_dispatch_moe_combine_single_rank,
*zip(*[(ep_size, all_num_tokens, top_k, workspace_size_per_rank,
num_experts_per_rank, hidden_size,
invalid_token_expert_id)] * ep_size),
num_experts, hidden_size, invalid_token_expert_id)] *
ep_size),
)
# Collect results
@ -617,14 +656,12 @@ class TestMoEAlltoAll:
# Verify combine results
print("Starting verification...")
verify_combine_results(all_results, ep_size, all_num_tokens, top_k,
hidden_size, num_experts_per_rank)
verify_combine(all_results, ep_size)
def run_moe_a2a_dispatch_moe_combine_single_rank(ep_size, all_num_tokens, top_k,
workspace_size_per_rank,
num_experts_per_rank,
hidden_size,
num_experts, hidden_size,
invalid_token_expert_id):
"""Worker function for dispatch and combine test."""
rank = tllm.mpi_rank()
@ -639,15 +676,19 @@ def run_moe_a2a_dispatch_moe_combine_single_rank(ep_size, all_num_tokens, top_k,
world_size=ep_size)
# Create MoeAlltoAll manager
moe_a2a = MoeAlltoAll(mapping, max_num_tokens, top_k,
ep_size * num_experts_per_rank,
workspace_size_per_rank)
moe_a2a = MoeAlltoAll(
mapping=mapping,
max_num_tokens=max_num_tokens,
top_k=top_k,
num_slots=num_experts,
workspace_size_per_rank=workspace_size_per_rank,
)
rank_local_tokens = all_num_tokens[rank]
# Generate test data - use simpler payload for combine test
token_selected_experts = generate_token_selected_experts(
rank_local_tokens, ep_size, num_experts_per_rank, top_k)
rank_local_tokens, num_experts, top_k)
payloads, expert_id_payload_index = make_bfloat16_payloads(
rank_local_tokens, hidden_size, top_k, rank, token_selected_experts)
@ -670,11 +711,12 @@ def run_moe_a2a_dispatch_moe_combine_single_rank(ep_size, all_num_tokens, top_k,
# emulate MoE computation on the received data
# Create experts for this rank
rank_experts = create_experts(num_experts_per_rank,
hidden_size,
rank,
device,
dtype=torch.bfloat16)
num_experts_per_rank = num_experts // ep_size
rank_experts = create_experts_per_rank(num_experts_per_rank,
hidden_size,
rank,
device,
dtype=torch.bfloat16)
hidden_states_recv = fake_moe(
hidden_states_recv.view(ep_size * max_num_tokens,
@ -721,8 +763,7 @@ def run_moe_a2a_dispatch_moe_combine_single_rank(ep_size, all_num_tokens, top_k,
raise
def verify_combine_results(all_results, ep_size, all_num_tokens, top_k,
hidden_size, num_experts_per_rank):
def verify_combine(all_results, ep_size):
"""Verify that combine correctly sums the dispatched tokens."""
# Extract results