mirror of
https://github.com/NVIDIA/TensorRT-LLM.git
synced 2026-02-04 18:21:52 +08:00
[TRTLLM-10048][feat] Fuse the AllGather for expert statistics required by the EPLB. (#10885)
Signed-off-by: Bo Li <22713281+bobboli@users.noreply.github.com>
This commit is contained in:
parent
5efee01da1
commit
e405468230
@ -361,19 +361,15 @@ __global__ void moeA2APrepareDispatchKernel(
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Generic Dispatch Kernel Implementation
|
||||
// One warp per token design:
|
||||
// - Each CTA has 256 threads = 8 warps
|
||||
// - Each warp independently processes one token and all its payloads
|
||||
// - Better GPU utilization and reduced synchronization overhead
|
||||
// Dispatch Kernels
|
||||
// ============================================================================
|
||||
|
||||
template <typename ThreadingPolicy, int TOP_K>
|
||||
template <typename ThreadingPolicy, int TOP_K, bool ENABLE_EPLB>
|
||||
__global__ void moeA2ADispatchKernel(int32_t const* token_selected_experts, // [local_num_tokens, TOP_K]
|
||||
const DispatchKernelPointers ptrs, // Struct containing all kernel pointers
|
||||
int num_payloads, // Number of payloads
|
||||
int max_tokens_per_rank, // Maximum tokens per rank
|
||||
int local_num_tokens, int rank_id, int ep_size, int num_experts_per_rank)
|
||||
int local_num_tokens, int rank_id, int ep_size, int num_experts, int eplb_stats_num_experts)
|
||||
{
|
||||
|
||||
int thread_idx = ThreadingPolicy::offset();
|
||||
@ -411,6 +407,7 @@ __global__ void moeA2ADispatchKernel(int32_t const* token_selected_experts, // [
|
||||
}
|
||||
|
||||
uint64_t already_copied = 0;
|
||||
int num_experts_per_rank = num_experts / ep_size;
|
||||
for (int k = 0; k < TOP_K; k++)
|
||||
{
|
||||
int expert_id = token_selected_experts[local_token_idx * TOP_K + k];
|
||||
@ -501,6 +498,21 @@ __global__ void moeA2ADispatchKernel(int32_t const* token_selected_experts, // [
|
||||
ptrs.recv_counters[target_rank][rank_id] = send_count;
|
||||
}
|
||||
|
||||
if constexpr (ENABLE_EPLB)
|
||||
{
|
||||
// Write local stats into peer buffers before the release fence below.
|
||||
#pragma unroll 1
|
||||
for (int target_rank = 0; target_rank < ep_size; ++target_rank)
|
||||
{
|
||||
int* target_stats = ptrs.eplb_gathered_stats[target_rank];
|
||||
for (int expert_id = lane_id; expert_id < eplb_stats_num_experts; expert_id += warpSize)
|
||||
{
|
||||
int stat_val = ptrs.eplb_local_stats[expert_id];
|
||||
target_stats[rank_id * eplb_stats_num_experts + expert_id] = stat_val;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#if !DISABLE_SYNC_FOR_PROFILING
|
||||
uint32_t expected_value = *ptrs.flag_val;
|
||||
|
||||
@ -588,6 +600,7 @@ void moe_a2a_dispatch_launch(MoeA2ADispatchParams const& params)
|
||||
for (int target_rank = 0; target_rank < params.ep_size; target_rank++)
|
||||
{
|
||||
kernel_ptrs.recv_counters[target_rank] = params.recv_counters[target_rank];
|
||||
kernel_ptrs.eplb_gathered_stats[target_rank] = params.eplb_gathered_stats[target_rank];
|
||||
for (int payload = 0; payload < params.num_payloads; payload++)
|
||||
{
|
||||
kernel_ptrs.recv_buffers[target_rank][payload] = params.recv_buffers[target_rank][payload];
|
||||
@ -606,6 +619,7 @@ void moe_a2a_dispatch_launch(MoeA2ADispatchParams const& params)
|
||||
kernel_ptrs.local_token_counter = params.local_token_counter;
|
||||
kernel_ptrs.topk_target_ranks = params.topk_target_ranks;
|
||||
kernel_ptrs.topk_send_indices = params.topk_send_indices;
|
||||
kernel_ptrs.eplb_local_stats = params.eplb_local_stats;
|
||||
|
||||
int const kBlockSize = tensorrt_llm::common::getEnvMoeA2ADispatchBlockSize();
|
||||
constexpr int kWarpSize = 32;
|
||||
@ -621,10 +635,12 @@ void moe_a2a_dispatch_launch(MoeA2ADispatchParams const& params)
|
||||
grid_size = 1;
|
||||
}
|
||||
int shared_bytes = 2 * params.top_k * (int) sizeof(int);
|
||||
SWITCH_TOP_K(params.top_k, TOP_K,
|
||||
moeA2ADispatchKernel<BlockPolicy, TOP_K><<<grid_size, kBlockSize, shared_bytes, params.stream>>>(
|
||||
params.token_selected_experts, kernel_ptrs, params.num_payloads, params.max_tokens_per_rank,
|
||||
params.local_num_tokens, params.ep_rank, params.ep_size, params.num_experts_per_rank))
|
||||
SWITCH_BOOL(params.enable_eplb, EPLB_STATS,
|
||||
SWITCH_TOP_K(params.top_k, TOP_K,
|
||||
moeA2ADispatchKernel<BlockPolicy, TOP_K, EPLB_STATS>
|
||||
<<<grid_size, kBlockSize, shared_bytes, params.stream>>>(params.token_selected_experts, kernel_ptrs,
|
||||
params.num_payloads, params.max_tokens_per_rank, params.local_num_tokens, params.ep_rank,
|
||||
params.ep_size, params.num_experts, params.eplb_stats_num_experts)))
|
||||
}
|
||||
else
|
||||
{
|
||||
@ -635,10 +651,12 @@ void moe_a2a_dispatch_launch(MoeA2ADispatchParams const& params)
|
||||
grid_size = 1;
|
||||
}
|
||||
int shared_bytes = 2 * kWarpsPerBlock * params.top_k * (int) sizeof(int);
|
||||
SWITCH_TOP_K(params.top_k, TOP_K,
|
||||
moeA2ADispatchKernel<WarpPolicy, TOP_K><<<grid_size, kBlockSize, shared_bytes, params.stream>>>(
|
||||
params.token_selected_experts, kernel_ptrs, params.num_payloads, params.max_tokens_per_rank,
|
||||
params.local_num_tokens, params.ep_rank, params.ep_size, params.num_experts_per_rank))
|
||||
SWITCH_BOOL(params.enable_eplb, EPLB_STATS,
|
||||
SWITCH_TOP_K(params.top_k, TOP_K,
|
||||
moeA2ADispatchKernel<WarpPolicy, TOP_K, EPLB_STATS>
|
||||
<<<grid_size, kBlockSize, shared_bytes, params.stream>>>(params.token_selected_experts, kernel_ptrs,
|
||||
params.num_payloads, params.max_tokens_per_rank, params.local_num_tokens, params.ep_rank,
|
||||
params.ep_size, params.num_experts, params.eplb_stats_num_experts)))
|
||||
}
|
||||
}
|
||||
|
||||
@ -980,24 +998,24 @@ __global__ void moeA2APrepareCombineKernel(uint8_t* recv_buffer_bytes, uint8_t c
|
||||
if (payload_bytes == nullptr)
|
||||
return;
|
||||
|
||||
int slot_idx = ThreadingPolicy::token_idx();
|
||||
int global_token_idx = ThreadingPolicy::token_idx();
|
||||
|
||||
int total_slots = ep_size * max_tokens_per_rank;
|
||||
if (slot_idx >= total_slots)
|
||||
int global_token_num = ep_size * max_tokens_per_rank;
|
||||
if (global_token_idx >= global_token_num)
|
||||
return;
|
||||
|
||||
// Map global token to (source_rank, token_idx)
|
||||
int source_rank = slot_idx / max_tokens_per_rank;
|
||||
int token_idx = slot_idx % max_tokens_per_rank;
|
||||
// Map global_token_idx to (rank_idx, local_token_idx)
|
||||
int rank_idx = global_token_idx / max_tokens_per_rank;
|
||||
int local_token_idx = global_token_idx % max_tokens_per_rank;
|
||||
|
||||
// Skip invalid tokens beyond per-source recv count
|
||||
if (token_idx >= recv_counters[source_rank])
|
||||
// Skip invalid tokens beyond per-rank recv count
|
||||
if (local_token_idx >= recv_counters[rank_idx])
|
||||
return;
|
||||
|
||||
// Calculate source and destination pointers for this token
|
||||
size_t slot_offset = static_cast<size_t>(slot_idx) * bytes_per_token;
|
||||
uint8_t* dst_ptr = recv_buffer_bytes + slot_offset;
|
||||
uint8_t const* src_ptr = payload_bytes + slot_offset;
|
||||
size_t offset = static_cast<size_t>(global_token_idx) * bytes_per_token;
|
||||
uint8_t* dst_ptr = recv_buffer_bytes + offset;
|
||||
uint8_t const* src_ptr = payload_bytes + offset;
|
||||
|
||||
// Copy one token's data using vectorized copy with policy
|
||||
vectorized_copy<ThreadingPolicy>(dst_ptr, src_ptr, bytes_per_token);
|
||||
@ -1118,9 +1136,9 @@ void moe_a2a_prepare_combine_launch(MoeA2ACombineParams const& params)
|
||||
}
|
||||
|
||||
int bytes_per_token = params.elements_per_token * element_size;
|
||||
int total_slots = params.prepare_payload == nullptr ? 1 : params.ep_size * params.max_tokens_per_rank;
|
||||
int grid_size_warp = ceilDiv(total_slots, kWarpsPerBlock);
|
||||
int grid_size_block = total_slots; // one block per token
|
||||
int global_token_num = params.prepare_payload == nullptr ? 1 : params.ep_size * params.max_tokens_per_rank;
|
||||
int grid_size_warp = ceilDiv(global_token_num, kWarpsPerBlock);
|
||||
int grid_size_block = global_token_num; // one block per token
|
||||
|
||||
if (params.one_block_per_token)
|
||||
{
|
||||
|
||||
@ -61,6 +61,10 @@ struct DispatchKernelPointers
|
||||
// Top-K compact routing info per local token (size: [local_num_tokens, top_k])
|
||||
int* topk_target_ranks; // target rank per k, -1 for duplicates
|
||||
int* topk_send_indices; // dst index per k, -1 for duplicates
|
||||
|
||||
// Optional: Statistics for EPLB
|
||||
int const* eplb_local_stats; // [eplb_stats_num_experts]
|
||||
int* eplb_gathered_stats[kMaxRanks]; // [ep_size, eplb_stats_num_experts] per rank
|
||||
};
|
||||
|
||||
// Combine kernel pointers - non-const output in src_data_ptrs[0], const recv buffers
|
||||
@ -83,13 +87,13 @@ struct CombineKernelPointers
|
||||
// Dispatch phase parameters
|
||||
struct MoeA2ADispatchParams
|
||||
{
|
||||
// Threading policy
|
||||
bool one_block_per_token; // True: one block per token, False: one warp per token
|
||||
|
||||
// Threading policy
|
||||
// EP configuration
|
||||
int ep_size; // Number of EP ranks
|
||||
int ep_rank; // Current EP rank
|
||||
int num_experts_per_rank; // Number of experts per rank (num_experts / ep_size)
|
||||
int ep_size; // Number of EP ranks
|
||||
int ep_rank; // Current EP rank
|
||||
int num_experts; // Total number of experts
|
||||
|
||||
// Token configuration
|
||||
int local_num_tokens; // Number of tokens on this rank
|
||||
@ -118,6 +122,12 @@ struct MoeA2ADispatchParams
|
||||
// rank has signaled the target rank
|
||||
void* recv_buffers[kMaxRanks][kMaxPayloads]; // Per-rank receive buffers for each payload
|
||||
|
||||
// Optional: Statistics for EPLB
|
||||
bool enable_eplb; // Whether to enable EPLB
|
||||
int eplb_stats_num_experts; // Number of experts for EPLB stats
|
||||
int const* eplb_local_stats; // [eplb_stats_num_experts]
|
||||
int* eplb_gathered_stats[kMaxRanks]; // [ep_size, eplb_stats_num_experts] per rank
|
||||
|
||||
// CUDA stream
|
||||
cudaStream_t stream;
|
||||
};
|
||||
|
||||
@ -43,8 +43,9 @@ enum MoeA2AMetaInfoIndex : int64_t
|
||||
COMBINE_COMPLETION_FLAGS_OFFSET_INDEX = 5,
|
||||
TOPK_TARGET_RANKS_OFFSET_INDEX = 6,
|
||||
TOPK_SEND_INDICES_OFFSET_INDEX = 7,
|
||||
PAYLOAD_DATA_OFFSET_INDEX = 8,
|
||||
NUM_METAINFO_FIELDS = 9
|
||||
EPLB_GATHERED_STATS_OFFSET_INDEX = 8,
|
||||
PAYLOAD_DATA_OFFSET_INDEX = 9,
|
||||
NUM_METAINFO_FIELDS = 10
|
||||
};
|
||||
|
||||
using MoeA2ADataOffsets = std::array<int64_t, NUM_METAINFO_FIELDS>;
|
||||
@ -60,6 +61,7 @@ inline std::vector<std::pair<char const*, int64_t>> getMoeA2AMetaInfoIndexPairs(
|
||||
{"MOE_A2A_COMBINE_COMPLETION_FLAGS_OFFSET_INDEX", COMBINE_COMPLETION_FLAGS_OFFSET_INDEX},
|
||||
{"MOE_A2A_TOPK_TARGET_RANKS_OFFSET_INDEX", TOPK_TARGET_RANKS_OFFSET_INDEX},
|
||||
{"MOE_A2A_TOPK_SEND_INDICES_OFFSET_INDEX", TOPK_SEND_INDICES_OFFSET_INDEX},
|
||||
{"MOE_A2A_EPLB_GATHERED_STATS_OFFSET_INDEX", EPLB_GATHERED_STATS_OFFSET_INDEX},
|
||||
{"MOE_A2A_PAYLOAD_DATA_OFFSET_INDEX", PAYLOAD_DATA_OFFSET_INDEX},
|
||||
{"MOE_A2A_NUM_METAINFO_FIELDS", NUM_METAINFO_FIELDS},
|
||||
};
|
||||
|
||||
@ -43,7 +43,7 @@ inline size_t alignOffset(size_t offset, size_t alignment)
|
||||
}
|
||||
|
||||
// Calculate auxiliary data offsets
|
||||
MoeA2ADataOffsets calculateOffsets(int epSize, int maxNumTokens)
|
||||
MoeA2ADataOffsets calculateOffsets(int epSize, int maxNumTokens, int eplbStatsNumExperts)
|
||||
{
|
||||
// TODO: Use lambdas to encapsulate offset and alignment for each entry, which is less error prone and easier to
|
||||
// read.
|
||||
@ -90,6 +90,11 @@ MoeA2ADataOffsets calculateOffsets(int epSize, int maxNumTokens)
|
||||
offset += static_cast<size_t>(maxNumTokens) * static_cast<size_t>(tensorrt_llm::kernels::moe_comm::kMaxTopK)
|
||||
* SIZEOF_INT32;
|
||||
|
||||
// eplb gathered stats: [epSize, eplbStatsNumExperts]
|
||||
offset = alignOffset(offset, CACHELINE_ALIGNMENT);
|
||||
offsets[EPLB_GATHERED_STATS_OFFSET_INDEX] = offset;
|
||||
offset += static_cast<size_t>(epSize) * static_cast<size_t>(eplbStatsNumExperts) * SIZEOF_INT32;
|
||||
|
||||
// payload data
|
||||
offset = alignOffset(offset, CACHELINE_ALIGNMENT);
|
||||
offsets[PAYLOAD_DATA_OFFSET_INDEX] = offset;
|
||||
@ -105,10 +110,12 @@ MoeA2ADataOffsets calculateOffsets(int epSize, int maxNumTokens)
|
||||
// - epRank: Current expert parallel rank
|
||||
// - epSize: Total expert parallel size
|
||||
// - maxNumTokens: Maximum number of tokens supported
|
||||
// - eplbStatsNumExperts: (Optional) Number of experts used for EPLB stats
|
||||
//
|
||||
// Returns:
|
||||
// - metainfo: Tensor containing offsets for auxiliary data
|
||||
torch::Tensor moeA2AInitializeOp(torch::Tensor const& workspace, int64_t epRank, int64_t epSize, int64_t maxNumTokens)
|
||||
torch::Tensor moeA2AInitializeOp(torch::Tensor const& workspace, int64_t epRank, int64_t epSize, int64_t maxNumTokens,
|
||||
torch::optional<int64_t> eplbStatsNumExperts)
|
||||
{
|
||||
// Validate inputs
|
||||
CHECK_TH_CUDA(workspace);
|
||||
@ -120,8 +127,11 @@ torch::Tensor moeA2AInitializeOp(torch::Tensor const& workspace, int64_t epRank,
|
||||
// Initialize workspace to zero
|
||||
workspace[epRank].zero_();
|
||||
|
||||
int64_t eplbStatsNumExpertsValue = eplbStatsNumExperts.value_or(0);
|
||||
TORCH_CHECK(eplbStatsNumExpertsValue >= 0, "eplbStatsNumExperts must be positive if not None.");
|
||||
|
||||
// Calculate auxiliary data offsets
|
||||
MoeA2ADataOffsets offsets = calculateOffsets(epSize, maxNumTokens);
|
||||
MoeA2ADataOffsets offsets = calculateOffsets(epSize, maxNumTokens, static_cast<int>(eplbStatsNumExpertsValue));
|
||||
|
||||
// Return metainfo as a tensor containing offsets
|
||||
torch::Tensor metainfo = torch::empty(
|
||||
@ -155,18 +165,23 @@ torch::Tensor moeA2AInitializeOp(torch::Tensor const& workspace, int64_t epRank,
|
||||
// - epRank: Current expert parallel rank
|
||||
// - epSize: Total expert parallel size
|
||||
// - topK: Number of experts selected per token
|
||||
// - numExperts: Total number of experts (must be divisible by epSize)
|
||||
// - numExperts: Total number of routing slots (tokenSelectedExperts values are in [0, numExperts))
|
||||
// - eplbStatsNumExperts: Number of experts used for EPLB stats (may be <= numExperts)
|
||||
// - eplbLocalStats: [eplbStatsNumExperts] tensor containing local statistics for EPLB.
|
||||
//
|
||||
// Return values:
|
||||
// - recvTensors: Vector of receive buffers (one tensor per payload), each [ep_size, runtimeMaxTokensPerRank,
|
||||
// elements_per_token]
|
||||
// - combinePayloadOffset: Offset into workspace for the combine payload region, to be used by the combine operation
|
||||
// - eplbGatheredStats: (Optional) [ep_size, eplbStatsNumExperts] tensor containing gathered statistics for EPLB, or
|
||||
// an empty tensor if eplbLocalStats is None.
|
||||
//
|
||||
// Note: token_selected_experts is used for routing but is NOT automatically included as a payload.
|
||||
// If you want to dispatch token_selected_experts, include it explicitly in inputPayloads.
|
||||
std::tuple<std::vector<torch::Tensor>, int64_t> moeA2ADispatchOp(torch::Tensor const& tokenSelectedExperts,
|
||||
std::vector<torch::Tensor> const& inputPayloads, torch::Tensor const& workspace, torch::Tensor const& metainfo,
|
||||
int64_t runtimeMaxTokensPerRank, int64_t epRank, int64_t epSize, int64_t topK, int64_t numExperts)
|
||||
std::tuple<std::vector<torch::Tensor>, int64_t, torch::Tensor> moeA2ADispatchOp(
|
||||
torch::Tensor const& tokenSelectedExperts, std::vector<torch::Tensor> const& inputPayloads,
|
||||
torch::Tensor const& workspace, torch::Tensor const& metainfo, int64_t runtimeMaxTokensPerRank, int64_t epRank,
|
||||
int64_t epSize, int64_t topK, int64_t numExperts, torch::optional<torch::Tensor> eplbLocalStats)
|
||||
{
|
||||
using tensorrt_llm::kernels::moe_comm::PayloadDescriptor;
|
||||
using tensorrt_llm::kernels::moe_comm::MoeA2ADispatchParams;
|
||||
@ -194,6 +209,19 @@ std::tuple<std::vector<torch::Tensor>, int64_t> moeA2ADispatchOp(torch::Tensor c
|
||||
TORCH_CHECK(inputPayloads.size() <= kMaxPayloads, "Too many input payloads");
|
||||
TORCH_CHECK(numExperts >= epSize, "numExperts must be greater than or equal to epSize");
|
||||
TORCH_CHECK(numExperts % epSize == 0, "numExperts must be divisible by epSize for contiguous partitioning");
|
||||
bool enableEplb = eplbLocalStats.has_value();
|
||||
int64_t eplbStatsNumExperts = 0;
|
||||
if (enableEplb)
|
||||
{
|
||||
TORCH_CHECK(eplbLocalStats.has_value(), "enable_eplb requires eplb_local_stats");
|
||||
torch::Tensor const& eplbLocalStatsTensor = eplbLocalStats.value();
|
||||
eplbStatsNumExperts = eplbLocalStatsTensor.size(0);
|
||||
TORCH_CHECK(eplbStatsNumExperts > 0, "eplb_local_stats must not be empty");
|
||||
TORCH_CHECK(eplbStatsNumExperts <= numExperts, "eplb_local_stats size must be <= numExperts (slots)");
|
||||
CHECK_INPUT(eplbLocalStatsTensor, torch::kInt32);
|
||||
TORCH_CHECK(eplbLocalStatsTensor.is_contiguous(), "eplb_local_stats must be contiguous");
|
||||
TORCH_CHECK(eplbLocalStatsTensor.dim() == 1, "eplb_local_stats must be a 1D tensor");
|
||||
}
|
||||
|
||||
// All input payloads must have the same first dimension (localNumTokens)
|
||||
for (auto const& payload : inputPayloads)
|
||||
@ -280,10 +308,12 @@ std::tuple<std::vector<torch::Tensor>, int64_t> moeA2ADispatchOp(torch::Tensor c
|
||||
= tensorrt_llm::common::getEnvMoeA2AOneBlockPerToken(); // TODO: Decide this based on the workload
|
||||
params.ep_size = static_cast<int>(epSize);
|
||||
params.ep_rank = static_cast<int>(epRank);
|
||||
params.num_experts_per_rank = static_cast<int>(numExperts) / static_cast<int>(epSize);
|
||||
params.num_experts = static_cast<int>(numExperts);
|
||||
params.local_num_tokens = static_cast<int>(localNumTokens);
|
||||
params.max_tokens_per_rank = static_cast<int>(runtimeMaxTokensPerRank);
|
||||
params.top_k = static_cast<int>(topK);
|
||||
params.enable_eplb = enableEplb;
|
||||
params.eplb_stats_num_experts = static_cast<int>(eplbStatsNumExperts);
|
||||
|
||||
params.token_selected_experts = tokenSelectedExperts.data_ptr<int32_t>();
|
||||
|
||||
@ -304,6 +334,15 @@ std::tuple<std::vector<torch::Tensor>, int64_t> moeA2ADispatchOp(torch::Tensor c
|
||||
= reinterpret_cast<int*>(targetWorkSpacePtr + offsets[RECV_COUNTERS_OFFSET_INDEX]);
|
||||
params.completion_flags[target_rank]
|
||||
= reinterpret_cast<uint32_t*>(targetWorkSpacePtr + offsets[DISPATCH_COMPLETION_FLAGS_OFFSET_INDEX]);
|
||||
if (enableEplb)
|
||||
{
|
||||
params.eplb_gathered_stats[target_rank]
|
||||
= reinterpret_cast<int*>(targetWorkSpacePtr + offsets[EPLB_GATHERED_STATS_OFFSET_INDEX]);
|
||||
}
|
||||
else
|
||||
{
|
||||
params.eplb_gathered_stats[target_rank] = nullptr;
|
||||
}
|
||||
|
||||
for (int payload_idx = 0; payload_idx < num_payloads; payload_idx++)
|
||||
{
|
||||
@ -312,6 +351,15 @@ std::tuple<std::vector<torch::Tensor>, int64_t> moeA2ADispatchOp(torch::Tensor c
|
||||
}
|
||||
}
|
||||
|
||||
if (enableEplb)
|
||||
{
|
||||
params.eplb_local_stats = eplbLocalStats.value().data_ptr<int32_t>();
|
||||
}
|
||||
else
|
||||
{
|
||||
params.eplb_local_stats = nullptr;
|
||||
}
|
||||
|
||||
params.stream = at::cuda::getCurrentCUDAStream();
|
||||
|
||||
// Prepare for dispatch (zero counters/indices and increment flag_val)
|
||||
@ -336,7 +384,20 @@ std::tuple<std::vector<torch::Tensor>, int64_t> moeA2ADispatchOp(torch::Tensor c
|
||||
// Compute aligned offset after dispatch payloads for combine payload region
|
||||
int64_t combinePayloadOffset = static_cast<int64_t>(alignOffset(currentOffset, CACHELINE_ALIGNMENT));
|
||||
|
||||
return std::make_tuple(std::move(recvTensors), combinePayloadOffset);
|
||||
torch::Tensor eplbGatheredStats;
|
||||
if (enableEplb)
|
||||
{
|
||||
int* gatheredStatsPtr = reinterpret_cast<int*>(rankWorkSpacePtr + offsets[EPLB_GATHERED_STATS_OFFSET_INDEX]);
|
||||
auto statsOptions = workspace.options().dtype(torch::kInt32);
|
||||
eplbGatheredStats = torch::from_blob(
|
||||
gatheredStatsPtr, {static_cast<int64_t>(epSize), static_cast<int64_t>(eplbStatsNumExperts)}, statsOptions);
|
||||
}
|
||||
else
|
||||
{
|
||||
eplbGatheredStats = torch::empty({0}, workspace.options().dtype(torch::kInt32));
|
||||
}
|
||||
|
||||
return std::make_tuple(std::move(recvTensors), combinePayloadOffset, std::move(eplbGatheredStats));
|
||||
}
|
||||
|
||||
// MoE All-to-All Combine Operation
|
||||
@ -525,9 +586,12 @@ torch::Tensor moeA2AGetCombinePayloadTensorOp(torch::Tensor const& workspace, in
|
||||
}
|
||||
|
||||
// Return the size of auxiliary data in workspace
|
||||
int64_t moeA2AGetAuxDataSizeOp(int64_t epSize, int64_t maxNumTokens)
|
||||
int64_t moeA2AGetAuxDataSizeOp(int64_t epSize, int64_t maxNumTokens, torch::optional<int64_t> eplbStatsNumExperts)
|
||||
{
|
||||
MoeA2ADataOffsets offsets = calculateOffsets(static_cast<int>(epSize), static_cast<int>(maxNumTokens));
|
||||
int64_t eplbStatsNumExpertsValue = eplbStatsNumExperts.value_or(0);
|
||||
TORCH_CHECK(eplbStatsNumExpertsValue >= 0, "eplbStatsNumExperts must be positive if not None.");
|
||||
MoeA2ADataOffsets offsets = calculateOffsets(
|
||||
static_cast<int>(epSize), static_cast<int>(maxNumTokens), static_cast<int>(eplbStatsNumExpertsValue));
|
||||
return static_cast<int64_t>(offsets[PAYLOAD_DATA_OFFSET_INDEX]);
|
||||
}
|
||||
|
||||
@ -546,14 +610,16 @@ TORCH_LIBRARY_FRAGMENT(trtllm, module)
|
||||
module.def(
|
||||
"moe_a2a_dispatch(Tensor token_selected_experts, Tensor[] input_payloads, "
|
||||
"Tensor(a!->*) workspace, Tensor metainfo, int runtime_max_tokens_per_rank, "
|
||||
"int ep_rank, int ep_size, int top_k, int num_experts) -> (Tensor(a!)[], int)");
|
||||
"int ep_rank, int ep_size, int top_k, int num_experts, "
|
||||
"Tensor? eplb_local_stats=None) -> (Tensor(a!)[], int, Tensor(a!))");
|
||||
module.def(
|
||||
"moe_a2a_combine(Tensor(a) payload, int local_num_tokens,"
|
||||
"Tensor(a!) workspace, Tensor metainfo, int runtime_max_tokens_per_rank, "
|
||||
"int ep_rank, int ep_size, int top_k, int combine_payload_offset, "
|
||||
"bool payload_in_workspace) -> Tensor");
|
||||
module.def(
|
||||
"moe_a2a_initialize(Tensor(a!) workspace, int ep_rank, int ep_size, int max_num_tokens_per_rank) -> Tensor");
|
||||
"moe_a2a_initialize(Tensor(a!) workspace, int ep_rank, int ep_size, int max_num_tokens_per_rank, "
|
||||
"int? eplb_stats_num_experts=None) -> Tensor");
|
||||
module.def(
|
||||
"moe_a2a_sanitize_expert_ids(Tensor(a!) expert_ids, Tensor(a!) workspace, Tensor metainfo, int ep_rank, int "
|
||||
"invalid_expert_id) -> ()");
|
||||
@ -561,7 +627,7 @@ TORCH_LIBRARY_FRAGMENT(trtllm, module)
|
||||
"moe_a2a_get_combine_payload_tensor(Tensor(a) workspace, int ep_rank, int ep_size, int "
|
||||
"runtime_max_tokens_per_rank, "
|
||||
"int combine_payload_offset, ScalarType out_dtype, int hidden_size) -> Tensor(a)");
|
||||
module.def("moe_a2a_get_aux_data_size(int ep_size, int max_num_tokens) -> int",
|
||||
module.def("moe_a2a_get_aux_data_size(int ep_size, int max_num_tokens, int? eplb_stats_num_experts=None) -> int",
|
||||
&tensorrt_llm::torch_ext::moe_comm::moeA2AGetAuxDataSizeOp);
|
||||
}
|
||||
|
||||
|
||||
@ -25,6 +25,7 @@ class _A2AState:
|
||||
phase: str = "idle" # idle | dispatched
|
||||
local_num_tokens: int | None = None
|
||||
combine_payload_offset: int | None = None
|
||||
eplb_gathered_stats: torch.Tensor | None = None
|
||||
|
||||
|
||||
class MoeAlltoAll:
|
||||
@ -41,9 +42,13 @@ class MoeAlltoAll:
|
||||
_METAINFO_INDEX: Dict[str, int] | None = None
|
||||
|
||||
@staticmethod
|
||||
def get_aux_data_size(ep_size: int, max_num_tokens: int) -> int:
|
||||
def get_aux_data_size(
|
||||
ep_size: int,
|
||||
max_num_tokens: int,
|
||||
eplb_stats_num_experts: Optional[int] = None,
|
||||
) -> int:
|
||||
return torch.ops.trtllm.moe_a2a_get_aux_data_size(
|
||||
ep_size, max_num_tokens)
|
||||
ep_size, max_num_tokens, eplb_stats_num_experts)
|
||||
|
||||
@staticmethod
|
||||
def calculate_required_workspace_size(
|
||||
@ -52,11 +57,13 @@ class MoeAlltoAll:
|
||||
max_num_tokens: int,
|
||||
hidden_size: int,
|
||||
dtype: torch.dtype,
|
||||
eplb_stats_num_experts: Optional[int] = None,
|
||||
extra_payload_bytes_per_token: int = 0) -> int:
|
||||
element_size = dtype.itemsize
|
||||
|
||||
# Auxiliary data size
|
||||
workspace_size = MoeAlltoAll.get_aux_data_size(ep_size, max_num_tokens)
|
||||
workspace_size = MoeAlltoAll.get_aux_data_size(ep_size, max_num_tokens,
|
||||
eplb_stats_num_experts)
|
||||
|
||||
# Dispatch needs workspace for [ep_size, max_tokens] tokens,
|
||||
# but due to the variety of quantization recipes, we cannot know the exact size, so we conservatively estimate assuming no quantization.
|
||||
@ -103,6 +110,8 @@ class MoeAlltoAll:
|
||||
int(thop.MOE_A2A_TOPK_TARGET_RANKS_OFFSET_INDEX),
|
||||
"TOPK_SEND_INDICES_OFFSET_INDEX":
|
||||
int(thop.MOE_A2A_TOPK_SEND_INDICES_OFFSET_INDEX),
|
||||
"EPLB_GATHERED_STATS_OFFSET_INDEX":
|
||||
int(thop.MOE_A2A_EPLB_GATHERED_STATS_OFFSET_INDEX),
|
||||
"PAYLOAD_DATA_OFFSET_INDEX":
|
||||
int(thop.MOE_A2A_PAYLOAD_DATA_OFFSET_INDEX),
|
||||
"NUM_METAINFO_FIELDS":
|
||||
@ -114,8 +123,9 @@ class MoeAlltoAll:
|
||||
mapping: Mapping,
|
||||
max_num_tokens: int,
|
||||
top_k: int,
|
||||
num_experts: int,
|
||||
num_slots: int,
|
||||
workspace_size_per_rank: int,
|
||||
num_experts: Optional[int] = None,
|
||||
):
|
||||
"""
|
||||
Initialize MoeAlltoAll with workspace allocation.
|
||||
@ -124,6 +134,10 @@ class MoeAlltoAll:
|
||||
mapping: TensorRT-LLM Mapping object containing rank information
|
||||
max_num_tokens: Maximum number of tokens supported. Should be ModelConfig.max_num_tokens.
|
||||
workspace_size_per_rank: Size of workspace per rank in bytes
|
||||
num_slots: Number of routing slots (token_selected_experts values are in [0, num_slots)).
|
||||
Note: The terminology is mapped to `num_experts` in this class and the kernels.
|
||||
num_experts: (Optional) Number of experts for EPLB stats (must be <= num_slots). DO NOT provide this parameter if EPLB is not enabled.
|
||||
Note: The terminology is mapped to `eplb_stats_num_experts` in this class and the kernels.
|
||||
"""
|
||||
# Check for environment variable override
|
||||
workspace_mb_env = os.environ.get("TRTLLM_MOE_A2A_WORKSPACE_MB")
|
||||
@ -148,42 +162,52 @@ class MoeAlltoAll:
|
||||
self.ep_rank = mapping.moe_ep_rank
|
||||
|
||||
self.top_k = top_k
|
||||
self.num_experts = num_experts
|
||||
self.num_experts = num_slots
|
||||
|
||||
if not isinstance(self.top_k, int) or self.top_k <= 0:
|
||||
raise ValueError("top_k must be a positive int")
|
||||
if not isinstance(self.num_experts, int) or self.num_experts <= 0:
|
||||
raise ValueError("num_experts must be a positive int")
|
||||
raise ValueError("num_slots must be a positive int")
|
||||
|
||||
if num_experts is not None:
|
||||
assert num_experts > 0 and num_experts <= num_slots, "num_experts must be in (0, num_slots]"
|
||||
tllm_logger.info(
|
||||
"NVLinkOneSided AlltoAll: EPLB is enabled, with num_slots="
|
||||
f"{num_slots} and num_experts={num_experts}")
|
||||
self.enable_eplb = num_experts is not None
|
||||
self.eplb_stats_num_experts = num_experts
|
||||
|
||||
if self._WORKSPACE is None:
|
||||
tllm_logger.info(
|
||||
f"nvlink_one_sided AlltoAll: Allocating workspace with size {workspace_size_per_rank} bytes. ep_rank: {self.ep_rank}, ep_size: {self.ep_size}, max_num_tokens: {self.max_num_tokens}"
|
||||
f"NVLinkOneSided AlltoAll: Allocating workspace with size {workspace_size_per_rank} bytes. ep_rank: {self.ep_rank}, ep_size: {self.ep_size}, max_num_tokens: {self.max_num_tokens}"
|
||||
)
|
||||
mnnvl_mem = MnnvlMemory(mapping, workspace_size_per_rank)
|
||||
workspace = mnnvl_mem.as_torch_strided_tensor(torch.uint8)
|
||||
metainfo = torch.ops.trtllm.moe_a2a_initialize(
|
||||
workspace,
|
||||
self.ep_rank,
|
||||
self.ep_size,
|
||||
self.max_num_tokens,
|
||||
)
|
||||
workspace, self.ep_rank, self.ep_size, self.max_num_tokens,
|
||||
self.eplb_stats_num_experts)
|
||||
MoeAlltoAll._WORKSPACE = {
|
||||
"workspace_size_per_rank": workspace_size_per_rank,
|
||||
"max_num_tokens": self.max_num_tokens,
|
||||
"ep_rank": self.ep_rank,
|
||||
"ep_size": self.ep_size,
|
||||
"eplb_stats_num_experts": self.eplb_stats_num_experts,
|
||||
"mnnvl_mem": mnnvl_mem,
|
||||
"workspace": workspace,
|
||||
"metainfo": metainfo,
|
||||
}
|
||||
else:
|
||||
assert self._WORKSPACE[
|
||||
"workspace_size_per_rank"] == workspace_size_per_rank, "reuse workspace with different workspace_size_per_rank"
|
||||
"workspace_size_per_rank"] == workspace_size_per_rank, "mistakenly reusing workspace with different workspace_size_per_rank"
|
||||
assert self._WORKSPACE[
|
||||
"max_num_tokens"] == self.max_num_tokens, "reuse workspace with different max_num_tokens"
|
||||
"max_num_tokens"] == self.max_num_tokens, "mistakenly reusing workspace with different max_num_tokens"
|
||||
assert self._WORKSPACE[
|
||||
"ep_rank"] == self.ep_rank, "reuse workspace with different ep_rank"
|
||||
"ep_rank"] == self.ep_rank, "mistakenly reusing workspace with different ep_rank"
|
||||
assert self._WORKSPACE[
|
||||
"ep_size"] == self.ep_size, "reuse workspace with different ep_size"
|
||||
"ep_size"] == self.ep_size, "mistakenly reusing workspace with different ep_size"
|
||||
assert self._WORKSPACE[
|
||||
"eplb_stats_num_experts"] == self.eplb_stats_num_experts, (
|
||||
"reuse workspace with different eplb_stats_num_experts")
|
||||
|
||||
self.mnnvl_mem = self._WORKSPACE["mnnvl_mem"]
|
||||
self.workspace = self._WORKSPACE["workspace"]
|
||||
@ -196,7 +220,8 @@ class MoeAlltoAll:
|
||||
input_payloads: list[torch.Tensor],
|
||||
runtime_max_tokens_per_rank: int,
|
||||
invalid_token_expert_id: Optional[int] = None,
|
||||
expert_id_payload_index: Optional[int] = None):
|
||||
expert_id_payload_index: Optional[int] = None,
|
||||
eplb_local_stats: Optional[torch.Tensor] = None):
|
||||
"""
|
||||
Perform MoE all-to-all dispatch operation.
|
||||
|
||||
@ -206,19 +231,40 @@ class MoeAlltoAll:
|
||||
runtime_max_tokens_per_rank: Maximum of the number of tokens of each DP rank's local batch.
|
||||
invalid_token_expert_id: If not None, set the token_selected_experts of the invalid tokens to this expert id. This is used to notify the MoE to skip these tokens for GroupGEMM.
|
||||
expert_id_payload_index: The index of token_selected_experts in the input_payloads. Must be provided if invalid_token_expert_id is not None.
|
||||
eplb_local_stats: (Optional) [num_experts] tensor containing local statistics for EPLB
|
||||
|
||||
Returns:
|
||||
recv_tensors: List of tensors received, each has shape [ep_size, max_tokens_per_rank, payload_num_elements_per_token]
|
||||
"""
|
||||
assert self._state.phase == "idle", "dispatch called twice without an intervening combine"
|
||||
assert runtime_max_tokens_per_rank <= self.max_num_tokens, "runtime_max_tokens_per_rank must not exceed max_num_tokens"
|
||||
recv_tensors, combine_payload_offset = torch.ops.trtllm.moe_a2a_dispatch(
|
||||
token_selected_experts, input_payloads, self.workspace,
|
||||
self.metainfo, runtime_max_tokens_per_rank, self.ep_rank,
|
||||
self.ep_size, self.top_k, self.num_experts)
|
||||
if eplb_local_stats is not None:
|
||||
assert self.enable_eplb, "eplb_local_stats provided but enable_eplb is False"
|
||||
assert eplb_local_stats.dim(
|
||||
) == 1, "eplb_local_stats must be a 1D tensor"
|
||||
assert eplb_local_stats.size(
|
||||
0
|
||||
) == self.eplb_stats_num_experts, "eplb_local_stats size must match eplb_stats_num_experts"
|
||||
|
||||
recv_tensors, combine_payload_offset, eplb_gathered_stats = torch.ops.trtllm.moe_a2a_dispatch(
|
||||
token_selected_experts,
|
||||
input_payloads,
|
||||
self.workspace,
|
||||
self.metainfo,
|
||||
runtime_max_tokens_per_rank,
|
||||
self.ep_rank,
|
||||
self.ep_size,
|
||||
self.top_k,
|
||||
self.num_experts,
|
||||
eplb_local_stats,
|
||||
)
|
||||
if eplb_gathered_stats.numel() == 0:
|
||||
eplb_gathered_stats = None
|
||||
|
||||
# Update state together after successful dispatch
|
||||
self._state.local_num_tokens = token_selected_experts.size(0)
|
||||
self._state.combine_payload_offset = combine_payload_offset
|
||||
self._state.eplb_gathered_stats = eplb_gathered_stats
|
||||
self._state.phase = "dispatched"
|
||||
|
||||
if invalid_token_expert_id is not None:
|
||||
|
||||
@ -185,3 +185,9 @@ class Communication(ABC):
|
||||
Combined output tensor [local_num_tokens, hidden_size]
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def get_eplb_gathered_statistics(self) -> Optional[torch.Tensor]:
|
||||
"""
|
||||
Return gathered EPLB statistics from the last dispatch, if available.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@ -76,6 +76,7 @@ class CommunicationFactory:
|
||||
expert_size_per_partition: Number of experts per partition (required for DeepEP)
|
||||
payload_in_workspace: If True, final_hidden_states is already in workspace (for NVLinkOneSided)
|
||||
alltoall_result_do_sum: If True, sum the alltoall results (for NVLinkTwoSided)
|
||||
# TODO: Need a way to indicate whether EPLB is enabled.
|
||||
|
||||
Returns:
|
||||
The selected communication method, or None if attention does not use DP
|
||||
@ -122,6 +123,7 @@ class CommunicationFactory:
|
||||
# Priority: NVLinkOneSided > NVLinkTwoSided > DeepEP > DeepEPLowLatency > AllGather
|
||||
|
||||
try:
|
||||
enable_eplb = model_config.moe_load_balancer is not None
|
||||
strategy = NVLinkOneSided(
|
||||
mapping,
|
||||
num_slots,
|
||||
@ -130,6 +132,7 @@ class CommunicationFactory:
|
||||
payload_in_workspace,
|
||||
hidden_size=hidden_size,
|
||||
dtype=act_dtype,
|
||||
num_experts=num_experts if enable_eplb else None,
|
||||
)
|
||||
logger.info("Selected communication strategy: NVLinkOneSided")
|
||||
return strategy
|
||||
@ -231,6 +234,7 @@ class CommunicationFactory:
|
||||
alltoall_result_do_sum=alltoall_result_do_sum,
|
||||
)
|
||||
elif method in ["NVLINK_ONE_SIDED"]:
|
||||
enable_eplb = model_config.moe_load_balancer is not None
|
||||
return NVLinkOneSided(
|
||||
mapping,
|
||||
num_slots,
|
||||
@ -239,6 +243,7 @@ class CommunicationFactory:
|
||||
payload_in_workspace,
|
||||
hidden_size=hidden_size,
|
||||
dtype=act_dtype,
|
||||
num_experts=num_experts if enable_eplb else None,
|
||||
)
|
||||
elif method == "DEEPEP":
|
||||
return DeepEP(
|
||||
|
||||
@ -65,11 +65,18 @@ class NVLinkOneSided(Communication):
|
||||
RECV_COUNTERS_OFFSET_INDEX = None
|
||||
DISPATCH_COMPLETION_FLAGS_OFFSET_INDEX = None
|
||||
COMBINE_COMPLETION_FLAGS_OFFSET_INDEX = None
|
||||
EPLB_GATHERED_STATS_OFFSET_INDEX = None
|
||||
PAYLOAD_DATA_OFFSET_INDEX = None
|
||||
|
||||
@staticmethod
|
||||
def get_aux_data_size(ep_size: int, max_num_tokens: int) -> int:
|
||||
return torch.ops.trtllm.moe_a2a_get_aux_data_size(ep_size, max_num_tokens)
|
||||
def get_aux_data_size(
|
||||
ep_size: int,
|
||||
max_num_tokens: int,
|
||||
eplb_stats_num_experts: Optional[int] = None,
|
||||
) -> int:
|
||||
return torch.ops.trtllm.moe_a2a_get_aux_data_size(
|
||||
ep_size, max_num_tokens, eplb_stats_num_experts
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def calculate_required_workspace_size(
|
||||
@ -78,12 +85,15 @@ class NVLinkOneSided(Communication):
|
||||
max_num_tokens: int,
|
||||
hidden_size: int,
|
||||
dtype: torch.dtype,
|
||||
eplb_stats_num_experts: Optional[int] = None,
|
||||
extra_payload_bytes_per_token: int = 0,
|
||||
) -> int:
|
||||
element_size = dtype.itemsize
|
||||
|
||||
# Auxiliary data size
|
||||
workspace_size = NVLinkOneSided.get_aux_data_size(ep_size, max_num_tokens)
|
||||
workspace_size = NVLinkOneSided.get_aux_data_size(
|
||||
ep_size, max_num_tokens, eplb_stats_num_experts
|
||||
)
|
||||
|
||||
# Dispatch needs workspace for [ep_size, max_tokens] tokens,
|
||||
# but due to the variety of quantization recipes, we cannot know the exact size, so we conservatively estimate assuming no quantization.
|
||||
@ -97,13 +107,12 @@ class NVLinkOneSided(Communication):
|
||||
# token_final_scales
|
||||
workspace_size += ep_size * max_num_tokens * top_k * 4
|
||||
workspace_size = pad_up(workspace_size, 128)
|
||||
# extra payload bytes per token
|
||||
workspace_size += ep_size * max_num_tokens * extra_payload_bytes_per_token
|
||||
workspace_size = pad_up(workspace_size, 128)
|
||||
|
||||
# Required workspace for combine [ep_size, max_tokens] tokens
|
||||
workspace_size += ep_size * max_num_tokens * hidden_size * element_size
|
||||
workspace_size = pad_up(workspace_size, 128)
|
||||
# extra payload bytes per token
|
||||
workspace_size += ep_size * max_num_tokens * extra_payload_bytes_per_token
|
||||
workspace_size = pad_up(workspace_size, 128)
|
||||
|
||||
return workspace_size
|
||||
|
||||
@ -124,29 +133,36 @@ class NVLinkOneSided(Communication):
|
||||
cls.COMBINE_COMPLETION_FLAGS_OFFSET_INDEX = int(
|
||||
thop.MOE_A2A_COMBINE_COMPLETION_FLAGS_OFFSET_INDEX
|
||||
)
|
||||
cls.EPLB_GATHERED_STATS_OFFSET_INDEX = int(
|
||||
thop.MOE_A2A_EPLB_GATHERED_STATS_OFFSET_INDEX
|
||||
)
|
||||
cls.PAYLOAD_DATA_OFFSET_INDEX = int(thop.MOE_A2A_PAYLOAD_DATA_OFFSET_INDEX)
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
mapping: Mapping,
|
||||
num_experts: int,
|
||||
num_slots: int,
|
||||
top_k: int,
|
||||
max_num_tokens_per_rank: int,
|
||||
payload_in_workspace: bool = False,
|
||||
hidden_size: Optional[int] = None,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
num_experts: Optional[int] = None,
|
||||
):
|
||||
"""
|
||||
Initialize NVLinkOneSided with workspace allocation.
|
||||
|
||||
Args:
|
||||
mapping: TensorRT-LLM Mapping object containing rank information
|
||||
num_experts: Total number of experts
|
||||
num_slots: Number of routing slots (token_selected_experts values are in [0, num_slots)).
|
||||
Note: The terminology is mapped to `num_experts` in this class and the kernels.
|
||||
top_k: Number of experts per token
|
||||
max_num_tokens_per_rank: Maximum number of tokens per rank (for workspace allocation)
|
||||
payload_in_workspace: If True, final_hidden_states is already in workspace
|
||||
hidden_size: Hidden dimension size (optional, for auto workspace calculation)
|
||||
dtype: Data type (optional, for auto workspace calculation)
|
||||
num_experts: (Optional) Number of experts for EPLB stats (must be <= num_slots). DO NOT provide this parameter if EPLB is not enabled.
|
||||
Note: The terminology is mapped to `eplb_stats_num_experts` in this class and the kernels.
|
||||
"""
|
||||
super().__init__(mapping)
|
||||
|
||||
@ -154,10 +170,20 @@ class NVLinkOneSided(Communication):
|
||||
raise RuntimeError("Currently NVLinkOneSided only supports pure EP for MoE.")
|
||||
|
||||
# Store needed parameters
|
||||
self.num_experts = num_experts
|
||||
self.num_experts = num_slots
|
||||
self.top_k = top_k
|
||||
self.max_num_tokens_per_rank = max_num_tokens_per_rank
|
||||
self.payload_in_workspace = payload_in_workspace
|
||||
if num_experts is not None:
|
||||
assert num_experts > 0 and num_experts <= num_slots, (
|
||||
"num_experts must be in (0, num_slots]"
|
||||
)
|
||||
tllm_logger.info(
|
||||
"NVLinkOneSided AlltoAll: EPLB is enabled, with num_slots="
|
||||
f"{num_slots} and num_experts={num_experts}"
|
||||
)
|
||||
self.enable_eplb = num_experts is not None
|
||||
self.eplb_stats_num_experts = num_experts
|
||||
|
||||
# Initialize constants from C++
|
||||
self._init_constants()
|
||||
@ -166,7 +192,12 @@ class NVLinkOneSided(Communication):
|
||||
auto_workspace_size = None
|
||||
if hidden_size is not None and dtype is not None:
|
||||
auto_workspace_size = self.calculate_required_workspace_size(
|
||||
self.ep_size, self.top_k, max_num_tokens_per_rank, hidden_size, dtype
|
||||
self.ep_size,
|
||||
self.top_k,
|
||||
max_num_tokens_per_rank,
|
||||
hidden_size,
|
||||
dtype,
|
||||
eplb_stats_num_experts=self.eplb_stats_num_experts,
|
||||
)
|
||||
workspace_mb_env = os.environ.get("TRTLLM_MOE_A2A_WORKSPACE_MB")
|
||||
if workspace_mb_env:
|
||||
@ -200,12 +231,14 @@ class NVLinkOneSided(Communication):
|
||||
self.ep_rank,
|
||||
self.ep_size,
|
||||
self.max_num_tokens_per_rank,
|
||||
self.eplb_stats_num_experts,
|
||||
)
|
||||
NVLinkOneSided._WORKSPACE = {
|
||||
"workspace_size_per_rank": self.workspace_size_per_rank,
|
||||
"max_num_tokens_per_rank": self.max_num_tokens_per_rank,
|
||||
"ep_rank": self.ep_rank,
|
||||
"ep_size": self.ep_size,
|
||||
"eplb_stats_num_experts": self.eplb_stats_num_experts,
|
||||
"mnnvl_mem": mnnvl_mem,
|
||||
"workspace": workspace,
|
||||
"metainfo": metainfo,
|
||||
@ -223,6 +256,9 @@ class NVLinkOneSided(Communication):
|
||||
assert self._WORKSPACE["ep_size"] == self.ep_size, (
|
||||
"reuse workspace with different ep_size"
|
||||
)
|
||||
assert self._WORKSPACE["eplb_stats_num_experts"] == self.eplb_stats_num_experts, (
|
||||
"reuse workspace with different eplb_stats_num_experts"
|
||||
)
|
||||
|
||||
self.mnnvl_mem = self._WORKSPACE["mnnvl_mem"]
|
||||
self.workspace = self._WORKSPACE["workspace"]
|
||||
@ -230,10 +266,7 @@ class NVLinkOneSided(Communication):
|
||||
self.max_num_tokens_per_rank = self._WORKSPACE["max_num_tokens_per_rank"]
|
||||
|
||||
# Initialize dispatch state
|
||||
self._dispatch_state = {}
|
||||
|
||||
# Internal state
|
||||
self._state: str = "idle" # idle | dispatched
|
||||
self._dispatch_state = {"phase": "idle"}
|
||||
|
||||
# Invalid token expert ID (default to -1), the kernels in TRTLLM-gen is hard-code to support -1 only.
|
||||
self.invalid_token_expert_id: int = -1
|
||||
@ -286,7 +319,7 @@ class NVLinkOneSided(Communication):
|
||||
Tuple of (hidden_states, hidden_states_sf, token_selected_slots, token_final_scales)
|
||||
Each tensor has shape [ep_size, max_tokens_per_rank, ...]
|
||||
"""
|
||||
if self._state == "dispatched":
|
||||
if self._dispatch_state.get("phase") == "dispatched":
|
||||
raise RuntimeError("dispatch called twice without an intervening combine")
|
||||
|
||||
# Calculate runtime_max_tokens_per_rank from all_rank_num_tokens
|
||||
@ -304,23 +337,35 @@ class NVLinkOneSided(Communication):
|
||||
if token_final_scales is not None:
|
||||
payloads.append(token_final_scales)
|
||||
|
||||
recv_buffers, combine_payload_offset = torch.ops.trtllm.moe_a2a_dispatch(
|
||||
token_selected_slots,
|
||||
payloads,
|
||||
self.workspace,
|
||||
self.moe_a2a_metainfo,
|
||||
runtime_max_tokens_per_rank,
|
||||
self.ep_rank,
|
||||
self.ep_size,
|
||||
self.top_k,
|
||||
self.num_experts,
|
||||
eplb_local_stats = kwargs.get("eplb_local_stats")
|
||||
if eplb_local_stats is not None:
|
||||
assert self.enable_eplb, "eplb_local_stats provided but enable_eplb is False"
|
||||
assert eplb_local_stats.dim() == 1, "eplb_local_stats must be a 1D tensor"
|
||||
assert eplb_local_stats.size(0) == self.eplb_stats_num_experts, (
|
||||
"eplb_local_stats size must match eplb_stats_num_experts"
|
||||
)
|
||||
|
||||
recv_buffers, combine_payload_offset, eplb_gathered_stats = (
|
||||
torch.ops.trtllm.moe_a2a_dispatch(
|
||||
token_selected_slots,
|
||||
payloads,
|
||||
self.workspace,
|
||||
self.moe_a2a_metainfo,
|
||||
runtime_max_tokens_per_rank,
|
||||
self.ep_rank,
|
||||
self.ep_size,
|
||||
self.top_k,
|
||||
self.num_experts,
|
||||
eplb_local_stats,
|
||||
)
|
||||
)
|
||||
|
||||
self._state = "dispatched"
|
||||
|
||||
if eplb_gathered_stats.numel() == 0:
|
||||
eplb_gathered_stats = None
|
||||
self._dispatch_state["eplb_gathered_stats"] = eplb_gathered_stats
|
||||
self._dispatch_state["combine_payload_offset"] = int(combine_payload_offset)
|
||||
self._dispatch_state["local_num_tokens"] = token_selected_slots.size(0)
|
||||
self._dispatch_state["runtime_max_tokens_per_rank"] = runtime_max_tokens_per_rank
|
||||
self._dispatch_state["phase"] = "dispatched"
|
||||
|
||||
# Extract results from recv_buffers
|
||||
# Payload order matches input:
|
||||
@ -363,6 +408,13 @@ class NVLinkOneSided(Communication):
|
||||
token_final_scales_recv,
|
||||
)
|
||||
|
||||
def get_eplb_gathered_statistics(self) -> Optional[torch.Tensor]:
|
||||
"""
|
||||
Return gathered EPLB statistics from the last dispatch, if available.
|
||||
"""
|
||||
assert self.enable_eplb, "EPLB is not enabled"
|
||||
return self._dispatch_state.get("eplb_gathered_stats")
|
||||
|
||||
def combine(
|
||||
self,
|
||||
final_hidden_states: torch.Tensor,
|
||||
@ -380,7 +432,7 @@ class NVLinkOneSided(Communication):
|
||||
Combined output tensor [local_num_tokens, hidden_size]
|
||||
|
||||
"""
|
||||
if self._state != "dispatched":
|
||||
if self._dispatch_state.get("phase") != "dispatched":
|
||||
raise RuntimeError("combine called before a successful dispatch")
|
||||
|
||||
local_num_tokens = self._dispatch_state.get("local_num_tokens")
|
||||
@ -423,8 +475,7 @@ class NVLinkOneSided(Communication):
|
||||
)
|
||||
|
||||
# Reset state for next round
|
||||
self._state = "idle"
|
||||
self._dispatch_state.clear()
|
||||
self._dispatch_state = {"phase": "idle"}
|
||||
|
||||
return output
|
||||
|
||||
@ -444,7 +495,7 @@ class NVLinkOneSided(Communication):
|
||||
Returns:
|
||||
Tensor view into workspace [ep_size, max_tokens_per_rank, hidden_size]
|
||||
"""
|
||||
if self._state != "dispatched":
|
||||
if self._dispatch_state.get("phase") != "dispatched":
|
||||
raise RuntimeError(
|
||||
"get_combine_payload_tensor_in_workspace called before a successful dispatch"
|
||||
)
|
||||
|
||||
@ -606,9 +606,11 @@ class ConfigurableMoE(MoE):
|
||||
if self.layer_load_balancer and token_selected_experts is not None:
|
||||
self._load_balancer_done_wait_gpu_stage(is_first_call)
|
||||
|
||||
# Update EPLB statistics (method depends on whether using NVLINK two-sided)
|
||||
# Use base class method: ignore_allreduce=True for NVLINK two-sided (uses local stats only)
|
||||
ignore_allreduce = self._is_using_nvlink_two_sided()
|
||||
# Update EPLB statistics (method depends on communication strategy)
|
||||
# Use base class method: ignore_allreduce=True for NVLINK two-sided/one-sided (uses local stats only)
|
||||
ignore_allreduce = (
|
||||
self._is_using_nvlink_two_sided() or self._is_using_nvlink_one_sided()
|
||||
)
|
||||
self._load_balancer_update_statistic(
|
||||
token_selected_experts,
|
||||
is_first_call,
|
||||
@ -628,6 +630,9 @@ class ConfigurableMoE(MoE):
|
||||
# ========== Step 3.5: Communication Prepare Phase (BEFORE quantization) ==========
|
||||
# NVLINK two-sided has a prepare phase to gather EPLB statistics
|
||||
|
||||
local_statistic_tensor_for_dispatch = None
|
||||
eplb_dispatch_kwargs = {}
|
||||
should_update_eplb_after_dispatch = False
|
||||
# Only NVLINK two-sided needs prepare_dispatch
|
||||
if self._is_using_nvlink_two_sided():
|
||||
# Get local statistic info if this is the last call and EPLB is enabled
|
||||
@ -645,6 +650,16 @@ class ConfigurableMoE(MoE):
|
||||
if gathered_stats is not None:
|
||||
gathered_stats = gathered_stats.view((self.mapping.moe_ep_size, self.num_experts))
|
||||
self._load_balancer_update_statistic_with_gathered_statistic(gathered_stats)
|
||||
# TODO: The abstract does not work well as NVLinkTwoSided gathers EPLB stats in prepare_dispatch,
|
||||
# while NVLinkOneSided gathers EPLB stats in dispatch.
|
||||
elif self._is_using_nvlink_one_sided():
|
||||
if self.layer_load_balancer and is_last_call:
|
||||
local_statistic_tensor_for_dispatch = (
|
||||
self._load_balancer_get_local_statistic_tensor()
|
||||
)
|
||||
if local_statistic_tensor_for_dispatch is not None:
|
||||
eplb_dispatch_kwargs["eplb_local_stats"] = local_statistic_tensor_for_dispatch
|
||||
should_update_eplb_after_dispatch = True
|
||||
|
||||
# ========== Step 4 & 5: Quantization and Communication Dispatch ==========
|
||||
# Order depends on whether strategy supports post-quant dispatch
|
||||
@ -666,11 +681,10 @@ class ConfigurableMoE(MoE):
|
||||
# Step 4b: Dispatch AFTER quantization
|
||||
# Get pre_quant_scale for W4AFP8 if available (only DeepEPLowLatency needs it)
|
||||
# Other strategies will ignore this via **kwargs, so it's safe to pass unconditionally
|
||||
dispatch_kwargs = {}
|
||||
dispatch_kwargs = dict(eplb_dispatch_kwargs)
|
||||
if hasattr(self, "quant_scales") and self.quant_scales is not None:
|
||||
if hasattr(self.quant_scales, "pre_quant_scale_1"):
|
||||
dispatch_kwargs["pre_quant_scale"] = self.quant_scales.pre_quant_scale_1
|
||||
|
||||
x, x_sf, token_selected_slots, token_final_scales = self.comm.dispatch(
|
||||
hidden_states=x,
|
||||
hidden_states_sf=x_sf,
|
||||
@ -680,6 +694,9 @@ class ConfigurableMoE(MoE):
|
||||
use_dp_padding=use_dp_padding,
|
||||
**dispatch_kwargs,
|
||||
)
|
||||
if should_update_eplb_after_dispatch:
|
||||
gathered_stats = self.comm.get_eplb_gathered_statistics()
|
||||
self._load_balancer_update_statistic_with_gathered_statistic(gathered_stats)
|
||||
else:
|
||||
# ===== Pre-quant flow: Dispatch → Quantize =====
|
||||
|
||||
@ -960,6 +977,10 @@ class ConfigurableMoE(MoE):
|
||||
"""Check if using NVLinkTwoSided communication strategy"""
|
||||
return isinstance(self.comm, NVLinkTwoSided)
|
||||
|
||||
def _is_using_nvlink_one_sided(self) -> bool:
|
||||
"""Check if using NVLinkOneSided communication strategy"""
|
||||
return isinstance(self.comm, NVLinkOneSided)
|
||||
|
||||
def _get_nvlink_onesided_moe_output(
|
||||
self,
|
||||
all_rank_num_tokens: Optional[List[int]],
|
||||
|
||||
@ -166,15 +166,22 @@ class CutlassFusedMoE(MoE):
|
||||
dtype = self.dtype or torch.float16
|
||||
|
||||
workspace_size = MoeAlltoAll.calculate_required_workspace_size(
|
||||
ep_size, self.routing_method.experts_per_token,
|
||||
max_num_tokens, hidden_size, dtype)
|
||||
ep_size,
|
||||
self.routing_method.experts_per_token,
|
||||
max_num_tokens,
|
||||
hidden_size,
|
||||
dtype,
|
||||
self.num_experts if self.layer_load_balancer else None,
|
||||
)
|
||||
|
||||
self.moe_a2a = MoeAlltoAll(
|
||||
mapping=self.mapping,
|
||||
max_num_tokens=model_config.max_num_tokens,
|
||||
top_k=self.routing_method.experts_per_token,
|
||||
num_experts=self.num_slots,
|
||||
num_slots=self.num_slots,
|
||||
workspace_size_per_rank=workspace_size,
|
||||
num_experts=self.num_experts
|
||||
if self.layer_load_balancer else None,
|
||||
)
|
||||
elif self.alltoall_method_type == AlltoallMethodType.DeepEP or self.alltoall_method_type == AlltoallMethodType.DeepEPLowLatency:
|
||||
raise NotImplementedError(
|
||||
@ -509,7 +516,10 @@ class CutlassFusedMoE(MoE):
|
||||
|
||||
if self.layer_load_balancer:
|
||||
self._load_balancer_done_wait_gpu_stage(is_first_call)
|
||||
ignore_allreduce = self.alltoall_method_type == AlltoallMethodType.NVLinkTwoSided
|
||||
ignore_allreduce = self.enable_alltoall and self.alltoall_method_type in (
|
||||
AlltoallMethodType.NVLinkTwoSided,
|
||||
AlltoallMethodType.NVLinkOneSided,
|
||||
)
|
||||
self._load_balancer_update_statistic(
|
||||
token_selected_experts,
|
||||
is_first_call,
|
||||
@ -608,14 +618,32 @@ class CutlassFusedMoE(MoE):
|
||||
payloads.append(token_selected_slots)
|
||||
payloads.append(token_final_scales)
|
||||
|
||||
recv_tensors = self.moe_a2a.dispatch(
|
||||
token_selected_slots,
|
||||
payloads,
|
||||
runtime_max_tokens_per_rank,
|
||||
invalid_token_expert_id=self.
|
||||
num_slots, # Caution: Cutlass MoE uses num_slots as invalid token expert id
|
||||
expert_id_payload_index=expert_id_payload_index,
|
||||
)
|
||||
loadbalancer_local_statistic_info = None
|
||||
if self.layer_load_balancer and is_last_call:
|
||||
loadbalancer_local_statistic_info = self._load_balancer_get_local_statistic_tensor(
|
||||
)
|
||||
if loadbalancer_local_statistic_info is not None:
|
||||
recv_tensors = self.moe_a2a.dispatch(
|
||||
token_selected_slots,
|
||||
payloads,
|
||||
runtime_max_tokens_per_rank,
|
||||
invalid_token_expert_id=self.
|
||||
num_slots, # Caution: Cutlass MoE uses num_slots as invalid token expert id
|
||||
expert_id_payload_index=expert_id_payload_index,
|
||||
eplb_local_stats=loadbalancer_local_statistic_info,
|
||||
)
|
||||
gathered_stats = self.moe_a2a._state.eplb_gathered_stats
|
||||
self._load_balancer_update_statistic_with_gathered_statistic(
|
||||
gathered_stats)
|
||||
else:
|
||||
recv_tensors = self.moe_a2a.dispatch(
|
||||
token_selected_slots,
|
||||
payloads,
|
||||
runtime_max_tokens_per_rank,
|
||||
invalid_token_expert_id=self.
|
||||
num_slots, # Caution: Cutlass MoE uses num_slots as invalid token expert id
|
||||
expert_id_payload_index=expert_id_payload_index,
|
||||
)
|
||||
|
||||
if x_sf is not None:
|
||||
x_recv, x_sf_recv, token_selected_slots_recv, token_final_scales_recv = recv_tensors
|
||||
|
||||
@ -140,15 +140,22 @@ class TRTLLMGenFusedMoE(MoE):
|
||||
dtype = self.dtype or torch.bfloat16
|
||||
|
||||
workspace_size = MoeAlltoAll.calculate_required_workspace_size(
|
||||
ep_size, self.routing_method.experts_per_token,
|
||||
max_num_tokens, hidden_size, dtype)
|
||||
ep_size,
|
||||
self.routing_method.experts_per_token,
|
||||
max_num_tokens,
|
||||
hidden_size,
|
||||
dtype,
|
||||
self.num_experts if self.layer_load_balancer else None,
|
||||
)
|
||||
|
||||
self.moe_a2a = MoeAlltoAll(
|
||||
mapping=self.mapping,
|
||||
max_num_tokens=model_config.max_num_tokens,
|
||||
top_k=self.routing_method.experts_per_token,
|
||||
num_experts=self.num_slots,
|
||||
workspace_size_per_rank=workspace_size)
|
||||
num_slots=self.num_slots,
|
||||
workspace_size_per_rank=workspace_size,
|
||||
num_experts=self.num_experts
|
||||
if self.layer_load_balancer else None)
|
||||
elif self.alltoall_method_type == AlltoallMethodType.DeepEP or self.alltoall_method_type == AlltoallMethodType.DeepEPLowLatency:
|
||||
raise NotImplementedError(
|
||||
"DeepEP and DeepEPLowLatency are not supported for TRTLLMGenFusedMoE yet"
|
||||
@ -690,7 +697,10 @@ class TRTLLMGenFusedMoE(MoE):
|
||||
|
||||
self._load_balancer_done_wait_gpu_stage(is_first_call)
|
||||
|
||||
ignore_allreduce = self.alltoall_method_type == AlltoallMethodType.NVLinkTwoSided
|
||||
ignore_allreduce = self.enable_alltoall and self.alltoall_method_type in (
|
||||
AlltoallMethodType.NVLinkTwoSided,
|
||||
AlltoallMethodType.NVLinkOneSided,
|
||||
)
|
||||
self._load_balancer_update_statistic(
|
||||
token_selected_experts,
|
||||
is_first_call,
|
||||
@ -778,14 +788,32 @@ class TRTLLMGenFusedMoE(MoE):
|
||||
payloads.append(token_selected_experts)
|
||||
payloads.append(token_final_scales)
|
||||
|
||||
recv_tensors = self.moe_a2a.dispatch(
|
||||
token_selected_experts,
|
||||
payloads,
|
||||
runtime_max_tokens_per_rank,
|
||||
invalid_token_expert_id=
|
||||
-1, # Caution: TRTLLM-Gen uses -1 as invalid token expert id
|
||||
expert_id_payload_index=expert_id_payload_index,
|
||||
)
|
||||
loadbalancer_local_statistic_info = None
|
||||
if self.layer_load_balancer and is_last_call:
|
||||
loadbalancer_local_statistic_info = self._load_balancer_get_local_statistic_tensor(
|
||||
)
|
||||
if loadbalancer_local_statistic_info is not None:
|
||||
recv_tensors = self.moe_a2a.dispatch(
|
||||
token_selected_experts,
|
||||
payloads,
|
||||
runtime_max_tokens_per_rank,
|
||||
invalid_token_expert_id=
|
||||
-1, # Caution: TRTLLM-Gen uses -1 as invalid token expert id
|
||||
expert_id_payload_index=expert_id_payload_index,
|
||||
eplb_local_stats=loadbalancer_local_statistic_info,
|
||||
)
|
||||
gathered_stats = self.moe_a2a._state.eplb_gathered_stats
|
||||
self._load_balancer_update_statistic_with_gathered_statistic(
|
||||
gathered_stats)
|
||||
else:
|
||||
recv_tensors = self.moe_a2a.dispatch(
|
||||
token_selected_experts,
|
||||
payloads,
|
||||
runtime_max_tokens_per_rank,
|
||||
invalid_token_expert_id=
|
||||
-1, # Caution: TRTLLM-Gen uses -1 as invalid token expert id
|
||||
expert_id_payload_index=expert_id_payload_index,
|
||||
)
|
||||
|
||||
if x_sf is not None:
|
||||
x_recv, x_sf_recv, token_selected_experts_recv, token_final_scales_recv = recv_tensors
|
||||
|
||||
@ -56,24 +56,23 @@ def compute_target_rank_id(expert_id, num_experts_per_rank):
|
||||
return expert_id // num_experts_per_rank
|
||||
|
||||
|
||||
def generate_token_selected_experts(local_num_tokens: int, ep_size: int,
|
||||
num_experts_per_rank: int,
|
||||
def generate_token_selected_experts(local_num_tokens: int, num_experts: int,
|
||||
top_k: int) -> torch.Tensor:
|
||||
"""Generate global expert IDs tensor, aligned with single-GPU test semantics."""
|
||||
return torch.randint(
|
||||
0,
|
||||
ep_size * num_experts_per_rank,
|
||||
num_experts,
|
||||
(local_num_tokens, top_k),
|
||||
dtype=torch.int32,
|
||||
device='cuda',
|
||||
)
|
||||
|
||||
|
||||
def create_experts(num_experts_per_rank,
|
||||
hidden_size,
|
||||
ep_rank,
|
||||
device,
|
||||
dtype=torch.bfloat16):
|
||||
def create_experts_per_rank(num_experts_per_rank,
|
||||
hidden_size,
|
||||
ep_rank,
|
||||
device,
|
||||
dtype=torch.bfloat16):
|
||||
"""
|
||||
Create a 3D tensor of expert weights for a given rank.
|
||||
|
||||
@ -226,9 +225,9 @@ def make_bfloat16_payloads(
|
||||
|
||||
|
||||
def run_moe_a2a_dispatch_single_rank(ep_size, all_num_tokens, top_k,
|
||||
workspace_size_per_rank,
|
||||
num_experts_per_rank, hidden_size,
|
||||
invalid_token_expert_id):
|
||||
workspace_size_per_rank, num_experts,
|
||||
hidden_size, invalid_token_expert_id,
|
||||
enable_eplb):
|
||||
"""Worker function for MPIPoolExecutor."""
|
||||
rank = tllm.mpi_rank()
|
||||
torch.cuda.set_device(rank)
|
||||
@ -244,25 +243,40 @@ def run_moe_a2a_dispatch_single_rank(ep_size, all_num_tokens, top_k,
|
||||
# Create MoeAlltoAll manager
|
||||
max_num_tokens = max(all_num_tokens)
|
||||
|
||||
moe_a2a = MoeAlltoAll(mapping, max_num_tokens, top_k,
|
||||
ep_size * num_experts_per_rank,
|
||||
workspace_size_per_rank)
|
||||
eplb_stats_num_experts = (
|
||||
num_experts // 2 if enable_eplb else None
|
||||
) # Use half of the experts for testing EPLB stats
|
||||
moe_a2a = MoeAlltoAll(
|
||||
mapping=mapping,
|
||||
max_num_tokens=max_num_tokens,
|
||||
top_k=top_k,
|
||||
num_slots=num_experts,
|
||||
workspace_size_per_rank=workspace_size_per_rank,
|
||||
num_experts=eplb_stats_num_experts,
|
||||
)
|
||||
|
||||
# Get the number of tokens for this specific rank (same as single-GPU)
|
||||
rank_local_tokens = all_num_tokens[rank]
|
||||
|
||||
# Generate data using helper functions
|
||||
token_selected_experts = generate_token_selected_experts(
|
||||
rank_local_tokens, ep_size, num_experts_per_rank, top_k)
|
||||
rank_local_tokens, num_experts, top_k)
|
||||
payloads, expert_id_payload_index = make_nvfp4_payloads(
|
||||
rank_local_tokens, hidden_size, top_k, rank, token_selected_experts)
|
||||
|
||||
eplb_local_stats = None
|
||||
if enable_eplb:
|
||||
eplb_local_stats = (torch.arange(
|
||||
eplb_stats_num_experts, dtype=torch.int32, device="cuda") +
|
||||
rank * 1000)
|
||||
|
||||
recv_tensors = moe_a2a.dispatch(
|
||||
token_selected_experts,
|
||||
payloads,
|
||||
max_num_tokens,
|
||||
invalid_token_expert_id=invalid_token_expert_id,
|
||||
expert_id_payload_index=expert_id_payload_index)
|
||||
expert_id_payload_index=expert_id_payload_index,
|
||||
eplb_local_stats=eplb_local_stats)
|
||||
|
||||
# Verify completion flags after dispatch
|
||||
completion_flags_offset = moe_a2a.metainfo[MoeAlltoAll._METAINFO_INDEX[
|
||||
@ -306,10 +320,16 @@ def run_moe_a2a_dispatch_single_rank(ep_size, all_num_tokens, top_k,
|
||||
max_num_tokens, top_k).cpu()
|
||||
|
||||
# Return results to be collected (move to CPU for MPI transfer)
|
||||
eplb_gathered_stats = moe_a2a._state.eplb_gathered_stats
|
||||
if eplb_gathered_stats is not None:
|
||||
eplb_gathered_stats = eplb_gathered_stats.cpu()
|
||||
if eplb_local_stats is not None:
|
||||
eplb_local_stats = eplb_local_stats.cpu()
|
||||
|
||||
return (token_selected_experts.cpu(), [p.cpu() for p in payloads],
|
||||
[rt.cpu()
|
||||
for rt in recv_tensors], send_counters, topk_send_indices,
|
||||
topk_target_ranks, recv_counters, expert_id_payload_index)
|
||||
[rt.cpu() for rt in recv_tensors], send_counters,
|
||||
topk_send_indices, topk_target_ranks, recv_counters,
|
||||
expert_id_payload_index, eplb_local_stats, eplb_gathered_stats)
|
||||
except Exception:
|
||||
traceback.print_exc()
|
||||
raise
|
||||
@ -318,11 +338,12 @@ def run_moe_a2a_dispatch_single_rank(ep_size, all_num_tokens, top_k,
|
||||
def verify_dispatch(all_token_selected_experts, all_payloads, all_recv_tensors,
|
||||
all_send_counters, all_topk_send_indices,
|
||||
all_topk_target_ranks, all_recv_counters, ep_size,
|
||||
all_num_tokens, top_k, num_experts_per_rank,
|
||||
expert_id_payload_index, invalid_token_expert_id):
|
||||
all_num_tokens, top_k, num_experts, expert_id_payload_index,
|
||||
invalid_token_expert_id):
|
||||
"""Verify dispatch results including actual content verification"""
|
||||
|
||||
max_num_tokens = max(all_num_tokens)
|
||||
num_experts_per_rank = num_experts // ep_size
|
||||
# Verify dimensions and dtypes
|
||||
for send_rank in range(ep_size):
|
||||
local_num_tokens = all_num_tokens[send_rank]
|
||||
@ -475,26 +496,32 @@ class TestMoEAlltoAll:
|
||||
enabled=False
|
||||
) # MPI pool executors have known thread cleanup timing issues
|
||||
@pytest.mark.parametrize(
|
||||
"mpi_pool_executor,all_num_tokens,top_k",
|
||||
"mpi_pool_executor,all_num_tokens,top_k,enable_eplb",
|
||||
[
|
||||
# (num_workers, all_num_tokens, top_k)
|
||||
# Basic configurations
|
||||
(4, [32, 32, 32, 32], 2), # Four ranks with uniform distribution
|
||||
(4, [32, 32, 32, 32], 2, False
|
||||
), # Four ranks with uniform distribution
|
||||
(4, [16, 32, 64, 48
|
||||
], 2), # Four ranks with non-uniform distribution
|
||||
(2, [100, 50], 2), # Two ranks with different loads
|
||||
], 2, False), # Four ranks with non-uniform distribution
|
||||
(2, [100, 50], 2, False), # Two ranks with different loads
|
||||
(8, [10, 20, 30, 40, 50, 60, 70, 80
|
||||
], 2), # Eight ranks with increasing load
|
||||
], 2, False), # Eight ranks with increasing load
|
||||
|
||||
# Different top_k values
|
||||
(4, [32, 32, 32, 32], 4), # Four ranks with top_k = 4
|
||||
(4, [32, 32, 32, 32], 8), # Four ranks with top_k = 8
|
||||
(4, [32, 32, 32, 32], 4, False), # Four ranks with top_k = 4
|
||||
(4, [32, 32, 32, 32], 8, False), # Four ranks with top_k = 8
|
||||
|
||||
# Edge cases
|
||||
(4, [1, 1, 1, 1], 2), # Four ranks with single token per rank
|
||||
(4, [1, 1, 1, 1], 2, False
|
||||
), # Four ranks with single token per rank
|
||||
|
||||
# EPLB stats path
|
||||
(4, [32, 32, 32, 32], 2, True),
|
||||
],
|
||||
indirect=["mpi_pool_executor"])
|
||||
def test_dispatch(self, mpi_pool_executor, all_num_tokens, top_k):
|
||||
def test_dispatch(self, mpi_pool_executor, all_num_tokens, top_k,
|
||||
enable_eplb):
|
||||
"""Test MoE A2A dispatch with MNNVL across multiple GPUs"""
|
||||
|
||||
try:
|
||||
@ -511,7 +538,7 @@ class TestMoEAlltoAll:
|
||||
) >= ep_size, f"Need at least {ep_size} GPUs, found {torch.cuda.device_count()}"
|
||||
|
||||
hidden_size = 1024
|
||||
num_experts_per_rank = 8
|
||||
num_experts = 32
|
||||
|
||||
# Large enough workspace
|
||||
workspace_size_per_rank = 512 * 1024 * 1024
|
||||
@ -523,8 +550,8 @@ class TestMoEAlltoAll:
|
||||
results = mpi_pool_executor.map(
|
||||
run_moe_a2a_dispatch_single_rank,
|
||||
*zip(*[(ep_size, all_num_tokens, top_k, workspace_size_per_rank,
|
||||
num_experts_per_rank, hidden_size,
|
||||
invalid_token_expert_id)] * ep_size),
|
||||
num_experts, hidden_size, invalid_token_expert_id,
|
||||
enable_eplb)] * ep_size),
|
||||
)
|
||||
|
||||
# Collect results from all ranks (same as single-GPU collecting from emulated ranks)
|
||||
@ -540,6 +567,8 @@ class TestMoEAlltoAll:
|
||||
all_recv_counters = [r[6] for r in all_results]
|
||||
all_expert_id_payload_index = [r[7] for r in all_results]
|
||||
expert_id_payload_index = all_expert_id_payload_index[0]
|
||||
all_eplb_local_stats = [r[8] for r in all_results]
|
||||
all_eplb_gathered_stats = [r[9] for r in all_results]
|
||||
|
||||
assert all(i == expert_id_payload_index
|
||||
for i in all_expert_id_payload_index
|
||||
@ -550,9 +579,18 @@ class TestMoEAlltoAll:
|
||||
all_recv_tensors, all_send_counters,
|
||||
all_topk_send_indices, all_topk_target_ranks,
|
||||
all_recv_counters, ep_size, all_num_tokens, top_k,
|
||||
num_experts_per_rank, expert_id_payload_index,
|
||||
num_experts, expert_id_payload_index,
|
||||
invalid_token_expert_id)
|
||||
|
||||
if enable_eplb:
|
||||
expected_stats = torch.stack(all_eplb_local_stats, dim=0)
|
||||
for rank in range(ep_size):
|
||||
gathered_stats = all_eplb_gathered_stats[rank]
|
||||
assert gathered_stats is not None
|
||||
assert torch.equal(
|
||||
gathered_stats,
|
||||
expected_stats), (f"Rank {rank} gathered_stats mismatch")
|
||||
|
||||
@pytest.mark.skipif(torch.cuda.device_count() < 8,
|
||||
reason='needs at least 8 GPUs to run multi-GPU test')
|
||||
@pytest.mark.threadleak(enabled=False)
|
||||
@ -587,8 +625,9 @@ class TestMoEAlltoAll:
|
||||
assert torch.cuda.device_count(
|
||||
) >= ep_size, f"Need at least {ep_size} GPUs, found {torch.cuda.device_count()}"
|
||||
|
||||
hidden_size = 2880 # gpt-oss
|
||||
num_experts_per_rank = 8
|
||||
# gpt-oss-20b
|
||||
hidden_size = 2880
|
||||
num_experts = 32
|
||||
|
||||
# Large enough workspace
|
||||
workspace_size_per_rank = 512 * 1024 * 1024
|
||||
@ -599,8 +638,8 @@ class TestMoEAlltoAll:
|
||||
results = mpi_pool_executor.map(
|
||||
run_moe_a2a_dispatch_moe_combine_single_rank,
|
||||
*zip(*[(ep_size, all_num_tokens, top_k, workspace_size_per_rank,
|
||||
num_experts_per_rank, hidden_size,
|
||||
invalid_token_expert_id)] * ep_size),
|
||||
num_experts, hidden_size, invalid_token_expert_id)] *
|
||||
ep_size),
|
||||
)
|
||||
|
||||
# Collect results
|
||||
@ -617,14 +656,12 @@ class TestMoEAlltoAll:
|
||||
|
||||
# Verify combine results
|
||||
print("Starting verification...")
|
||||
verify_combine_results(all_results, ep_size, all_num_tokens, top_k,
|
||||
hidden_size, num_experts_per_rank)
|
||||
verify_combine(all_results, ep_size)
|
||||
|
||||
|
||||
def run_moe_a2a_dispatch_moe_combine_single_rank(ep_size, all_num_tokens, top_k,
|
||||
workspace_size_per_rank,
|
||||
num_experts_per_rank,
|
||||
hidden_size,
|
||||
num_experts, hidden_size,
|
||||
invalid_token_expert_id):
|
||||
"""Worker function for dispatch and combine test."""
|
||||
rank = tllm.mpi_rank()
|
||||
@ -639,15 +676,19 @@ def run_moe_a2a_dispatch_moe_combine_single_rank(ep_size, all_num_tokens, top_k,
|
||||
world_size=ep_size)
|
||||
|
||||
# Create MoeAlltoAll manager
|
||||
moe_a2a = MoeAlltoAll(mapping, max_num_tokens, top_k,
|
||||
ep_size * num_experts_per_rank,
|
||||
workspace_size_per_rank)
|
||||
moe_a2a = MoeAlltoAll(
|
||||
mapping=mapping,
|
||||
max_num_tokens=max_num_tokens,
|
||||
top_k=top_k,
|
||||
num_slots=num_experts,
|
||||
workspace_size_per_rank=workspace_size_per_rank,
|
||||
)
|
||||
|
||||
rank_local_tokens = all_num_tokens[rank]
|
||||
|
||||
# Generate test data - use simpler payload for combine test
|
||||
token_selected_experts = generate_token_selected_experts(
|
||||
rank_local_tokens, ep_size, num_experts_per_rank, top_k)
|
||||
rank_local_tokens, num_experts, top_k)
|
||||
|
||||
payloads, expert_id_payload_index = make_bfloat16_payloads(
|
||||
rank_local_tokens, hidden_size, top_k, rank, token_selected_experts)
|
||||
@ -670,11 +711,12 @@ def run_moe_a2a_dispatch_moe_combine_single_rank(ep_size, all_num_tokens, top_k,
|
||||
|
||||
# emulate MoE computation on the received data
|
||||
# Create experts for this rank
|
||||
rank_experts = create_experts(num_experts_per_rank,
|
||||
hidden_size,
|
||||
rank,
|
||||
device,
|
||||
dtype=torch.bfloat16)
|
||||
num_experts_per_rank = num_experts // ep_size
|
||||
rank_experts = create_experts_per_rank(num_experts_per_rank,
|
||||
hidden_size,
|
||||
rank,
|
||||
device,
|
||||
dtype=torch.bfloat16)
|
||||
|
||||
hidden_states_recv = fake_moe(
|
||||
hidden_states_recv.view(ep_size * max_num_tokens,
|
||||
@ -721,8 +763,7 @@ def run_moe_a2a_dispatch_moe_combine_single_rank(ep_size, all_num_tokens, top_k,
|
||||
raise
|
||||
|
||||
|
||||
def verify_combine_results(all_results, ep_size, all_num_tokens, top_k,
|
||||
hidden_size, num_experts_per_rank):
|
||||
def verify_combine(all_results, ep_size):
|
||||
"""Verify that combine correctly sums the dispatched tokens."""
|
||||
|
||||
# Extract results
|
||||
|
||||
Loading…
Reference in New Issue
Block a user