# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ruff: noqa: E501
"""Unified MoE communication microbenchmark (MPI).

This benchmarks ONLY the communication kernels, specifically:
- Communication.dispatch()
- Communication.combine()

Launch (examples):

```bash
# Basic usage
python tests/microbenchmarks/bench_moe_comm.py --ep_size 8 --backend DEEPEP --profile deepseek_v3

# Show per-kernel breakdown and stats across iterations
python tests/microbenchmarks/bench_moe_comm.py --ep_size 8 --backend NVLINK_ONE_SIDED --kernel_breakdown --iter_stats

# With batch size sweeping
python tests/microbenchmarks/bench_moe_comm.py --ep_size 8 --backend NVLINK_ONE_SIDED -b 640 -e 2048 -f 2

# Run all supported strategies for the given profile
python tests/microbenchmarks/bench_moe_comm.py --ep_size 8 --profile deepseek_v3
```
"""

from __future__ import annotations

import argparse
import json
import os
import pickle
import sys
import traceback
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple

import cloudpickle
import torch
from mpi4py import MPI
from mpi4py.futures import MPIPoolExecutor
from torch.autograd import DeviceType

import tensorrt_llm as tllm
from tensorrt_llm._torch.model_config import ModelConfig
from tensorrt_llm._torch.modules.fused_moe import CutlassFusedMoE, MoE
from tensorrt_llm._torch.modules.fused_moe.communication import Communication, CommunicationFactory
from tensorrt_llm._torch.modules.fused_moe.routing import DefaultMoeRoutingMethod
from tensorrt_llm._utils import local_mpi_rank, mpi_allgather, mpi_barrier, mpi_rank, mpi_world_size
from tensorrt_llm.mapping import Mapping
from tensorrt_llm.models.modeling_utils import QuantConfig
from tensorrt_llm.quantization.mode import QuantAlgo


@dataclass(frozen=True)
class _DummyPretrainedConfig:
    # Exposes only the two attributes consumed by CommunicationFactory / ModelConfig.torch_dtype.
hidden_size: int
torch_dtype: torch.dtype


@dataclass(frozen=True)
class Profile:
name: str
hidden_size: int
top_k: int
num_experts: int
quant_algo: QuantAlgo


PROFILES: Dict[str, Profile] = {
# DeepSeek-V3: hidden_size 7168, router_topk 8 (public config)
"deepseek_v3": Profile(
name="deepseek_v3",
hidden_size=7168,
top_k=8,
# Previously: experts_per_rank=32 with recommended ep_size=8 => 256 total experts.
num_experts=256,
quant_algo=QuantAlgo.FP8_BLOCK_SCALES,
),
    # The repo already uses hidden_size=2880 for "gpt-oss" in the MoE A2A unit test.
"gpt_oss": Profile(
name="gpt_oss",
hidden_size=2880,
top_k=4,
num_experts=128,
quant_algo=QuantAlgo.W4A8_MXFP4_MXFP8,
),
}
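# To benchmark a different model shape, add a Profile entry to PROFILES above.
# The entry below is a purely illustrative (hypothetical) sketch, not taken from
# any real model config:
#
#   "my_model": Profile(
#       name="my_model",
#       hidden_size=4096,
#       top_k=8,
#       num_experts=64,
#       quant_algo=QuantAlgo.NO_QUANT,
#   ),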


def _maybe_warn_rank0(msg: str):
if mpi_rank() == 0:
print(msg, flush=True)


def _sync():
torch.cuda.synchronize()
mpi_barrier()


def _set_device_from_local_rank():
if not torch.cuda.is_available():
raise RuntimeError("CUDA is required for this benchmark")
dev = local_mpi_rank() % torch.cuda.device_count()
torch.cuda.set_device(dev)
return dev


def _make_inputs(
local_num_tokens: int,
hidden_size: int,
top_k: int,
num_experts_total: int,
act_dtype: torch.dtype,
device: torch.device,
quant_algo: QuantAlgo,
backend: Communication,
moe: Optional[MoE] = None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], torch.Tensor, torch.Tensor]:
# Hidden states: payload we want to communicate.
hidden_states = torch.randn(local_num_tokens, hidden_size, dtype=act_dtype, device=device)
    # Scaling factors are optional; most strategies ignore them.
hidden_states_sf = None
# Post-quant communication: Quantize → Dispatch (mirrors ConfigurableMoE ordering),
# using Cutlass' quantize_input() (outside the timed comm region).
if quant_algo != QuantAlgo.NO_QUANT and backend.supports_post_quant_dispatch():
hidden_states, hidden_states_sf = moe.quantize_input(hidden_states, post_quant_comm=True)
# Routing IDs: global expert IDs in [0, num_experts_total).
token_selected_slots = torch.randint(
0,
num_experts_total,
(local_num_tokens, top_k),
dtype=torch.int32,
device=device,
)
# Router weights/scales.
# DeepEP expects router weights/topk_weights to be float32.
token_final_scales = torch.rand(local_num_tokens, top_k, dtype=torch.float32, device=device)
return hidden_states, hidden_states_sf, token_selected_slots, token_final_scales


def _create_mapping(ep_size: int) -> Mapping:
rank = mpi_rank()
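    # Pure expert-parallel mapping with attention DP enabled: e.g. ep_size=8 gives
    # world_size=8, tp_size=8, moe_ep_size=8 (one benchmark rank per EP rank).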
return Mapping(
rank=rank,
tp_size=ep_size,
moe_ep_size=ep_size,
enable_attention_dp=True,
world_size=ep_size,
)


def _create_model_config(
*,
mapping: Mapping,
hidden_size: int,
act_dtype: torch.dtype,
max_num_tokens_per_rank: int,
quant_config: Optional[QuantConfig],
) -> ModelConfig:
# Keep it minimal: just enough fields for CommunicationFactory.
return ModelConfig(
pretrained_config=_DummyPretrainedConfig(hidden_size=hidden_size, torch_dtype=act_dtype),
mapping=mapping,
quant_config=(quant_config or QuantConfig()),
max_num_tokens=int(max_num_tokens_per_rank),
moe_max_num_tokens=int(max_num_tokens_per_rank),
use_cuda_graph=False,
use_low_precision_moe_combine=False,
)


def _time_dispatch_and_combine(
backend: Communication,
*,
hidden_states: torch.Tensor,
hidden_states_sf: Optional[torch.Tensor],
token_selected_slots: torch.Tensor,
token_final_scales: Optional[torch.Tensor],
all_rank_num_tokens: List[int],
hidden_size: int,
warmup: int,
iters: int,
flush_l2: bool = True,
) -> Tuple[List[float], List[float], Dict[str, Any]]:
"""Time dispatch and combine using Kineto (torch.profiler with CUPTI).

    Returns:
        dispatch_times_us: Per-iteration dispatch GPU times in microseconds
        combine_times_us: Per-iteration combine GPU times in microseconds
        detailed_stats: Dict containing per-kernel timing breakdown
    """
device = hidden_states.device
# Prepare dispatch once (excluded from timing)
_ = backend.prepare_dispatch(token_selected_slots, all_rank_num_tokens, None)
# L2 cache flushing buffer
l2_buffer = None
if flush_l2:
l2_size = torch.cuda.get_device_properties(device).L2_cache_size
# Use 2x L2 size to ensure complete flush
l2_flush_size = (l2_size * 2) // 4 # Size in int32 elements
l2_buffer = torch.empty(l2_flush_size, dtype=torch.int32, device=device)
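        # Illustrative sizing: on a GPU with a 50 MiB L2 (e.g. H100), this is a
        # ~100 MiB int32 buffer (~26M elements), zeroed before each iteration.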
# Warmup iterations (not profiled)
for _ in range(warmup):
if l2_buffer is not None:
l2_buffer.zero_()
recv_hidden_states, _, _, _ = backend.dispatch(
hidden_states,
hidden_states_sf,
token_selected_slots,
token_final_scales,
all_rank_num_tokens,
)
shape = list(recv_hidden_states.shape)
shape[-1] = hidden_size
recv_hidden_states_moe = torch.empty(
tuple(shape), dtype=torch.bfloat16, device=recv_hidden_states.device
)
_ = backend.combine(
recv_hidden_states_moe, all_rank_max_num_tokens=max(all_rank_num_tokens)
)
# Profile with Kineto
with torch.profiler.profile(
# Include CPU so `record_function("dispatch_iter..."/"combine_iter...")` ranges
# appear in key_averages() / events(). Without CPU activity those ranges are
# missing, causing dispatch/combine attribution to fail.
activities=[torch.profiler.ProfilerActivity.CUDA, torch.profiler.ProfilerActivity.CPU],
record_shapes=False,
with_stack=False,
) as prof:
_sync()
for _ in range(iters):
# L2 cache flushing
if l2_buffer is not None:
l2_buffer.zero_()
# Mark dispatch operation for aggregated timing
with torch.profiler.record_function("dispatch"):
recv_hidden_states, _, _, _ = backend.dispatch(
hidden_states,
hidden_states_sf,
token_selected_slots,
token_final_scales,
all_rank_num_tokens,
)
# Simulate MoE computation output
shape = list(recv_hidden_states.shape)
shape[-1] = hidden_size
recv_hidden_states_moe = torch.empty(
tuple(shape), dtype=torch.bfloat16, device=recv_hidden_states.device
)
# Mark combine operation for aggregated timing
with torch.profiler.record_function("combine"):
_ = backend.combine(
recv_hidden_states_moe, all_rank_max_num_tokens=max(all_rank_num_tokens)
)
_sync()
# if mpi_rank() == 0:
# print("########################################################")
# print(prof.key_averages())
# print("########################################################")
# ------------------------------------------------------------------
# Categorize GPU kernels by enclosing record_function scope
# ("dispatch" or "combine").
#
# Each record_function marker produces both a CPU event and a CUDA
# range event. We collect the GPU-side "dispatch"/"combine" ranges
# and check whether each GPU kernel's start timestamp falls inside
# one of them (GPU time-range containment).
# ------------------------------------------------------------------
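    # Illustrative example: if one iteration's GPU-side "dispatch" range spans
    # [t0, t1], every CUDA kernel/memcpy whose start timestamp t satisfies
    # t0 <= t <= t1 is attributed to that iteration's dispatch (likewise for
    # "combine"); GPU events outside both kinds of ranges land in "other".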
events_list = list(prof.events())
# if mpi_rank() == 0:
# print("++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++")
# for evt in events_list:
# print(evt)
# print("++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++")
def _is_gpu_event(evt) -> bool:
return getattr(evt, "device_type", None) == DeviceType.CUDA
# Step 1: Collect GPU time ranges of "dispatch"/"combine" CUDA ranges.
gpu_dispatch_intervals: List[Tuple[int, int]] = []
gpu_combine_intervals: List[Tuple[int, int]] = []
for evt in events_list:
if not _is_gpu_event(evt) or evt.name not in ("dispatch", "combine"):
continue
tr = getattr(evt, "time_range", None)
if tr is None:
continue
assert tr.end > tr.start
(gpu_dispatch_intervals if evt.name == "dispatch" else gpu_combine_intervals).append(
(tr.start, tr.end)
)
gpu_dispatch_intervals.sort()
gpu_combine_intervals.sort()
# Step 2: Scope resolver (GPU events only) ---------------------------------
def _find_scope(evt) -> Optional[str]:
"""Return 'dispatch', 'combine', or None based on GPU time-range containment."""
tr = getattr(evt, "time_range", None)
if tr is None:
return None
t = tr.start
for s, e in gpu_dispatch_intervals:
if s <= t <= e:
return "dispatch"
for s, e in gpu_combine_intervals:
if s <= t <= e:
return "combine"
return None
# Step 3: Iterate events and bucket by scope ----------------------------
dispatch_kernel_times: Dict[str, List[float]] = {}
combine_kernel_times: Dict[str, List[float]] = {}
other_kernel_times: Dict[str, List[float]] = {}
for evt in events_list:
if not _is_gpu_event(evt):
continue
if evt.device_time <= 0:
continue
if evt.name in ("dispatch", "combine"):
continue # skip record_function range markers
scope = _find_scope(evt)
if scope == "dispatch":
dispatch_kernel_times.setdefault(evt.name, []).append(evt.device_time)
elif scope == "combine":
combine_kernel_times.setdefault(evt.name, []).append(evt.device_time)
else:
other_kernel_times.setdefault(evt.name, []).append(evt.device_time)
# Step 4: Build per-kernel stats ----------------------------------------
def _build_kernel_list(kernel_times: Dict[str, List[float]]) -> List[Dict[str, Any]]:
result = []
for name, times in kernel_times.items():
result.append(
{
"name": name,
"count": len(times),
"_times": times, # raw per-iteration times, gathered across ranks later
}
)
return result
dispatch_kernels = _build_kernel_list(dispatch_kernel_times)
combine_kernels = _build_kernel_list(combine_kernel_times)
other_kernels = _build_kernel_list(other_kernel_times)
# Step 5: Collect per-iteration dispatch/combine times (us) ---------------
# Use the CUDA-side "dispatch"/"combine" range events (device_type=CUDA)
# for direct GPU time measurement.
dispatch_times_us: List[float] = []
combine_times_us: List[float] = []
for evt in events_list:
if not _is_gpu_event(evt):
continue
if evt.name == "dispatch":
dispatch_times_us.append(evt.device_time)
elif evt.name == "combine":
combine_times_us.append(evt.device_time)
# Sort each category by mean time descending
dispatch_kernels.sort(
key=lambda x: sum(x["_times"]) / len(x["_times"]) if x["_times"] else 0, reverse=True
)
combine_kernels.sort(
key=lambda x: sum(x["_times"]) / len(x["_times"]) if x["_times"] else 0, reverse=True
)
other_kernels.sort(
key=lambda x: sum(x["_times"]) / len(x["_times"]) if x["_times"] else 0, reverse=True
)
detailed_stats = {
"dispatch_kernels": dispatch_kernels,
"combine_kernels": combine_kernels,
"other_kernels": other_kernels,
}
return dispatch_times_us, combine_times_us, detailed_stats


def _compute_stats(values: List[float]) -> Dict[str, float]:
"""Compute summary statistics over a list of values."""
if not values:
return {"mean": 0.0, "median": 0.0, "stdev": 0.0, "min": 0.0, "max": 0.0}
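    # Worked example (hypothetical input): _compute_stats([1.0, 2.0, 3.0]) returns
    # mean=2.0, median=2.0, stdev~=0.816 (population stdev), min=1.0, max=3.0.
    # Note: "median" is s[n // 2], i.e. the upper-middle element for even-length input.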
s = sorted(values)
n = len(s)
mean = sum(s) / n
variance = sum((x - mean) ** 2 for x in s) / n
return {
"mean": mean,
"median": s[n // 2],
"stdev": variance**0.5,
"min": s[0],
"max": s[-1],
}


def _gather_per_rank(times_us: List[float], iter_stats: bool = False) -> Dict[str, Any]:
"""Allgather per-iteration times from each rank, return per-rank results.

    If iter_stats=True, return full stats (mean/median/stdev/min/max).
    If iter_stats=False, return just the mean.
    """
all_times = mpi_allgather(times_us)
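    # Illustrative return shapes (values made up):
    #   iter_stats=False -> {"rank0": 12.3, "rank1": 11.8, ...}   (mean us per rank)
    #   iter_stats=True  -> {"rank0": {"mean": ..., "median": ..., "stdev": ..., "min": ..., "max": ...}, ...}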
if iter_stats:
return {f"rank{i}": _compute_stats(t) for i, t in enumerate(all_times)}
return {f"rank{i}": (sum(t) / len(t) if t else 0.0) for i, t in enumerate(all_times)}


def parse_args() -> argparse.Namespace:
parser = argparse.ArgumentParser(description="Unified MoE communication microbenchmark (MPI).")
parser.add_argument(
"--ep_size",
type=int,
default=None,
        help="Number of MPI worker ranks to spawn via MPIPoolExecutor (default: 8).",
)
parser.add_argument(
"--backend",
type=lambda s: str(s).upper(),
default=None,
choices=[
"ALLGATHER",
"NVLINK_ONE_SIDED",
"NVLINK_TWO_SIDED",
"DEEPEP",
"DEEPEPLOWLATENCY",
],
help="Which communication backend to benchmark (default: run all backends).",
)
parser.add_argument(
"--profile",
type=str,
default="deepseek_v3",
choices=sorted(PROFILES.keys()),
        help="Named profile providing defaults for hidden_size/top_k/num_experts/quant_algo (default: deepseek_v3).",
)
parser.add_argument("--hidden_size", type=int, default=None, help="Custom hidden size.")
# Sizes to scan (NCCL-tests style, adapted from bytes -> local_batch_size in tokens).
parser.add_argument(
"-b",
"--minbatch",
type=int,
default=640,
        help="Minimum local_batch_size (tokens) to start with. Default: 640. For a single size, pass -b N (or -e N).",
)
parser.add_argument(
"-e",
"--maxbatch",
type=int,
default=None,
        help="Maximum local_batch_size (tokens) to end at (defaults to --minbatch). For a single size, pass -e N (or -b N).",
)
    # The increment is either a fixed step (-i) or a multiplication factor (-f); only one may be given.
parser.add_argument(
"-i",
"--stepbatch",
type=int,
default=None,
        help="Fixed increment between local_batch_size values (tokens). Default: 128 (used when --stepfactor is not given).",
)
parser.add_argument(
"-f",
"--stepfactor",
type=float,
default=None,
help="Multiplication factor between local_batch_size values. Default: disabled.",
)
parser.add_argument("--top_k", type=int, default=None, help="Custom router top-k.")
parser.add_argument(
"--num_experts",
type=int,
default=None,
help="Total number of experts.",
)
parser.add_argument(
"--quant",
type=lambda s: str(s).upper(),
default=None,
choices=[q.name for q in QuantAlgo],
help="Override quantization algo (defaults to profile.quant_algo).",
)
parser.add_argument("--iters", type=int, default=200, help="Timed iterations.")
    parser.add_argument(
        "--warmup", type=int, default=20, help="Warmup iterations (not profiled)."
    )
parser.add_argument(
"--max_num_tokens_per_rank",
type=int,
default=None,
help="Max tokens per rank for workspace allocation (defaults to max scanned local_batch_size).",
)
parser.add_argument(
"--kernel_breakdown",
action="store_true",
help="Show per-kernel timing breakdown.",
)
parser.add_argument(
"--iter_stats",
action="store_true",
help="Show detailed per-iteration stats (mean/median/stdev/min/max) instead of just mean.",
)
parser.add_argument(
"--output_file",
type=str,
default=None,
help="Path to write JSON report file (default: None, stdout only).",
)
return parser.parse_args()


def _iter_local_batch_sizes(args: argparse.Namespace) -> List[int]:
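    """Expand -b/-e/-i/-f into the list of local batch sizes (tokens) to scan.

    Illustrative examples (assuming the CLI defaults above):
        -b 640 -e 2048 -f 2   -> [640, 1280]   (geometric; stops once a value exceeds maxbatch)
        -b 640 -e 1024 -i 128 -> [640, 768, 896, 1024]
        -b 640                -> [640]         (single size)
    """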
scanning_enabled = any(
x is not None for x in (args.minbatch, args.maxbatch, args.stepbatch, args.stepfactor)
)
if not scanning_enabled:
raise ValueError("Must specify -b/--minbatch and/or -e/--maxbatch (tokens).")
if args.stepbatch is not None and args.stepfactor is not None:
raise ValueError("Only one of -i/--stepbatch or -f/--stepfactor should be used.")
if args.minbatch is None and args.maxbatch is None:
raise ValueError("Must specify at least one of -b/--minbatch or -e/--maxbatch.")
    # For single-size mode, allow passing just -b N or -e N (or both).
minb = int(args.minbatch) if args.minbatch is not None else int(args.maxbatch)
if minb <= 0:
raise ValueError("--minbatch must be > 0")
maxb = int(args.maxbatch) if args.maxbatch is not None else minb
if maxb < minb:
raise ValueError("--maxbatch must be >= --minbatch")
if args.stepfactor is not None:
factor = float(args.stepfactor)
if factor <= 1.0:
raise ValueError("--stepfactor must be > 1.0")
out: List[int] = []
cur = float(minb)
while True:
v = int(cur)
if v > maxb:
break
if v > 0 and (not out or out[-1] != v):
out.append(v)
cur *= factor
return out
step = int(args.stepbatch or 128)
if step <= 0:
raise ValueError("--stepbatch must be > 0")
return list(range(minb, maxb + 1, step))


def _resolve_profile_args(args: argparse.Namespace) -> Tuple[int, int, int, int, QuantAlgo]:
"""Returns (hidden_size, local_num_tokens, top_k, num_experts_total, quant_algo)."""
local_num_tokens = _iter_local_batch_sizes(args)[0]
# If a profile is provided, it supplies defaults; any explicit CLI values override.
if args.profile is not None:
prof = PROFILES[args.profile]
hidden_size = int(args.hidden_size or prof.hidden_size)
top_k = int(args.top_k or prof.top_k)
num_experts_total = int(args.num_experts or prof.num_experts)
quant_algo = prof.quant_algo
return hidden_size, local_num_tokens, top_k, num_experts_total, quant_algo
# No profile: all fields must be provided explicitly.
if args.hidden_size is None or args.top_k is None or args.num_experts is None:
raise ValueError(
"No --profile specified; must provide --hidden_size, --top_k, --num_experts."
)
hidden_size = int(args.hidden_size)
top_k = int(args.top_k)
num_experts_total = int(args.num_experts)
return hidden_size, local_num_tokens, top_k, num_experts_total, QuantAlgo.NO_QUANT


def _run_benchmark_worker_under_current_mpi(args: argparse.Namespace) -> None:
# Keep benchmark output clean.
tllm.logger.set_level("error")
# MPI-spawned workers may not inherit the parent's mutated environment reliably.
# Opt-in to DeepEP backends by default (does not override an explicit user setting).
os.environ.setdefault("TRTLLM_CAN_USE_DEEP_EP", "1")
ep_size = mpi_world_size()
rank = mpi_rank()
_ = _set_device_from_local_rank()
device = torch.device("cuda")
hidden_size, _, top_k, num_experts_total, profile_quant_algo = _resolve_profile_args(args)
local_batch_sizes = _iter_local_batch_sizes(args)
act_dtype = torch.bfloat16
quant_algo = QuantAlgo[args.quant] if args.quant is not None else profile_quant_algo
quant_config = (
QuantConfig(quant_algo=None)
if quant_algo == QuantAlgo.NO_QUANT
else QuantConfig(quant_algo=quant_algo)
)
max_scanned_tokens = max(local_batch_sizes)
max_num_tokens_per_rank = int(args.max_num_tokens_per_rank or max_scanned_tokens)
if max_num_tokens_per_rank < max_scanned_tokens:
raise ValueError(
"--max_num_tokens_per_rank must be >= the maximum scanned local_batch_size "
"(from -b/--minbatch..-e/--maxbatch)."
)
mapping = _create_mapping(ep_size)
if num_experts_total % ep_size != 0:
raise ValueError(
f"num_experts ({num_experts_total}) must be divisible by ep_size ({ep_size})"
)
experts_per_rank = num_experts_total // ep_size
num_slots = num_experts_total
if rank == 0:
print(
json.dumps(
{
"bench": "bench_moe_comm",
"launcher": "spawn",
"profile": args.profile,
"backend": args.backend,
"ep_size": ep_size,
"hidden_size": hidden_size,
"local_batch_size": local_batch_sizes,
"top_k": top_k,
"dtype": str(act_dtype),
"quant_algo": quant_algo.name,
"experts_per_rank": experts_per_rank,
"num_experts_total": num_experts_total,
"max_num_tokens_per_rank": max_num_tokens_per_rank,
"device_count": torch.cuda.device_count(),
},
indent=2,
),
flush=True,
)
backends = (
["ALLGATHER", "NVLINK_ONE_SIDED", "NVLINK_TWO_SIDED", "DEEPEP", "DEEPEPLOWLATENCY"]
if args.backend is None
else [args.backend]
)
all_results: List[Dict[str, Any]] = []
for backend_name in backends:
try:
model_config = _create_model_config(
mapping=mapping,
hidden_size=hidden_size,
act_dtype=act_dtype,
max_num_tokens_per_rank=max_num_tokens_per_rank,
quant_config=quant_config,
)
backend = CommunicationFactory._create_forced_method( # pylint: disable=protected-access
backend_name,
model_config,
num_experts_total,
num_slots,
top_k,
experts_per_rank,
payload_in_workspace=False,
alltoall_result_do_sum=True,
)
if backend is None:
_maybe_warn_rank0(
f"[bench_moe_comm] Skipping {backend_name}: factory returned None."
)
mpi_barrier()
continue
except Exception as e:
_maybe_warn_rank0(f"[bench_moe_comm] Skipping {backend_name}: {type(e).__name__}: {e}")
mpi_barrier()
continue
# Post-quant communication: Quantize → Dispatch (mirrors ConfigurableMoE ordering),
# using Cutlass' quantize_input() (outside the timed comm region).
moe = None
if quant_algo != QuantAlgo.NO_QUANT and backend.supports_post_quant_dispatch():
routing_method = DefaultMoeRoutingMethod(top_k=top_k)
moe = CutlassFusedMoE(
routing_method=routing_method,
num_experts=num_experts_total,
hidden_size=hidden_size,
intermediate_size=hidden_size * 4,
dtype=torch.bfloat16,
reduce_results=False,
model_config=model_config,
init_load_balancer=False,
without_comm=True,
)
# Ensure quantization params (e.g., NVFP4 global scale) live on CUDA.
moe = moe.to(device)
for local_num_tokens in local_batch_sizes:
all_rank_num_tokens = mpi_allgather(int(local_num_tokens))
try:
if not backend.is_workload_feasible(all_rank_num_tokens, num_chunks=1):
_maybe_warn_rank0(
f"[bench_moe_comm] Skipping {backend_name} @ local_batch_size={local_num_tokens}: workload not feasible."
)
mpi_barrier()
continue
except Exception as e:
_maybe_warn_rank0(
f"[bench_moe_comm] Skipping {backend_name} @ local_batch_size={local_num_tokens}: feasibility check failed: {e}"
)
mpi_barrier()
continue
hidden_states, hidden_states_sf, token_selected_slots, token_final_scales = (
_make_inputs(
local_num_tokens,
hidden_size,
top_k,
num_experts_total,
act_dtype,
device,
quant_algo,
backend,
moe,
)
)
# Time dispatch and combine with Kineto
dispatch_times_us, combine_times_us, detailed_stats = _time_dispatch_and_combine(
backend,
hidden_states=hidden_states,
hidden_states_sf=hidden_states_sf,
token_selected_slots=token_selected_slots,
token_final_scales=token_final_scales,
all_rank_num_tokens=all_rank_num_tokens,
hidden_size=hidden_size,
warmup=int(args.warmup),
iters=int(args.iters),
flush_l2=True,
)
iter_stats = bool(args.iter_stats)
dispatch_stats = _gather_per_rank(dispatch_times_us, iter_stats=iter_stats)
combine_stats = _gather_per_rank(combine_times_us, iter_stats=iter_stats)
# Prepare output
output = {
"backend": backend_name,
"local_batch_size": int(local_num_tokens),
"dispatch_us": dispatch_stats,
"combine_us": combine_stats,
}
# Add kernel breakdown if requested and available
if args.kernel_breakdown and detailed_stats is not None:
for category in ("dispatch_kernels", "combine_kernels", "other_kernels"):
kernels = detailed_stats.get(category, [])
for kernel in kernels:
kernel["per_rank"] = _gather_per_rank(
kernel.pop("_times", []), iter_stats=iter_stats
)
output[category] = kernels
if rank == 0:
print(json.dumps(output, indent=2), flush=True)
all_results.append(output)
mpi_barrier()
# Write JSON report if requested
if rank == 0 and args.output_file and all_results:
with open(args.output_file, "w") as f:
json.dump(all_results, f, indent=2)
print(f"Report written to {args.output_file}", flush=True)
return


def _spawn_worker_main(args_blob: bytes) -> List[Dict[str, Any]]:
"""Worker entrypoint for MPIPoolExecutor.

    Note: MPIPoolExecutor typically serializes callables by name; for scripts executed
    as __main__, we set MPI's pickler to cloudpickle in the parent so this function is
    available to workers (same pattern as our unit tests).
    """
args = pickle.loads(args_blob)
# In spawned workers, we are already inside an MPI world of size == ep_size.
try:
_run_benchmark_worker_under_current_mpi(args)
except Exception as e:
# Make worker-side stack trace visible at the parent.
rank = mpi_rank()
size = mpi_world_size()
msg = (
"[bench_moe_comm worker] uncaught exception:\n"
f"rank={rank}/{size} local_rank={local_mpi_rank()} pid={os.getpid()}\n"
"Worker traceback:\n"
f"{traceback.format_exc()}"
)
try:
sys.stderr.write(msg + "\n")
sys.stderr.flush()
except Exception:
pass
raise RuntimeError(msg) from e
return []


def main() -> None:
args = parse_args()
# Important: parent is NOT part of the worker MPI world; do not call mpi_barrier here.
# Make functions/classes in this script pickleable to spawned workers.
# (Same pattern used in our MPI unit tests, but adapted for a script entrypoint.)
cloudpickle.register_pickle_by_value(sys.modules[__name__])
MPI.pickle.__init__( # type: ignore[attr-defined]
cloudpickle.dumps,
cloudpickle.loads,
pickle.HIGHEST_PROTOCOL,
)
ep_size = int(args.ep_size or 8)
if ep_size <= 0:
raise ValueError("--ep_size must be > 0")
if mpi_world_size() != 1:
raise RuntimeError("bench_moe_comm should be run from a non-MPI parent (world_size==1).")
if mpi_rank() == 0:
print(
json.dumps(
{
"bench": "bench_moe_comm",
"launcher": "spawn",
"ep_size": ep_size,
"profile": args.profile,
"backend": args.backend,
},
indent=2,
),
flush=True,
)
args_blob = cloudpickle.dumps(args)
with MPIPoolExecutor(max_workers=ep_size) as executor:
# Map the same args to all workers; each worker uses its own mpi_rank() and participates
# in collectives within its spawned MPI world.
_ = list(executor.map(_spawn_worker_main, [args_blob] * ep_size))


if __name__ == "__main__":
# Opt-in to DeepEP backends by default. This does not override an explicit user setting.
os.environ.setdefault("TRTLLM_CAN_USE_DEEP_EP", "1")
main()