[None][chore] Optimize MOE export by tracing with reduced experts and expanding graph (#11504)

Signed-off-by: Suyog Gupta <41447211+suyoggupta@users.noreply.github.com>
Suyog Gupta 2026-02-13 16:59:30 -08:00 committed by GitHub
parent f164669c04
commit b4e9669d2c
3 changed files with 551 additions and 0 deletions


@ -10,6 +10,7 @@ import torch
import torch.export as te
import torch.nn as nn
from torch import fx
from torch.utils._python_dispatch import TorchDispatchMode
from ..utils._graph import canonicalize_graph, lift_to_meta, load_buffers_and_params, tree_to
from ..utils.logger import ad_logger
@ -25,6 +26,309 @@ except ImportError:
torch_export_context = nullcontext
# =====================================================================
# MOE export optimization: reduce experts for faster tracing, then
# expand the graph back to include all experts after export.
# =====================================================================
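# For example, with 256 experts traced using only 2, the exported torch_moe
# node initially carries weight lists like (illustrative targets)
#     [get_attr('experts.0.gate_proj.weight'), get_attr('experts.1.gate_proj.weight')]
# and the expansion step appends get_attr nodes for experts 2..255 by reusing
# the inferred '<prefix><expert_idx><suffix>' target pattern.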
def _infer_target_pattern(target_0: str, target_1: str) -> Tuple[str, str]:
"""Infer ``(prefix, suffix)`` from two consecutive expert-weight targets.
Compares two ``get_attr`` targets that differ only in the expert index and
returns ``(prefix, suffix)`` such that ``target == prefix + str(idx) + suffix``.
Example::
>>> _infer_target_pattern('experts.0.gate.weight', 'experts.1.gate.weight')
('experts.', '.gate.weight')
"""
parts_0 = target_0.split(".")
parts_1 = target_1.split(".")
if len(parts_0) != len(parts_1):
raise ValueError(f"Target structure mismatch: {target_0} vs {target_1}")
diff_positions = [i for i, (a, b) in enumerate(zip(parts_0, parts_1)) if a != b]
if len(diff_positions) != 1:
raise ValueError(
f"Expected exactly one differing part, found {len(diff_positions)}: "
f"{target_0} vs {target_1}"
)
idx = diff_positions[0]
prefix = ".".join(parts_0[:idx]) + "." if idx > 0 else ""
suffix = "." + ".".join(parts_0[idx + 1 :]) if idx < len(parts_0) - 1 else ""
return prefix, suffix
def _infer_single_target_pattern(target: str, expert_prefix: str) -> Tuple[str, str]:
"""Infer ``(prefix, suffix)`` when only one expert target is available.
Uses the known *expert_prefix* to locate the expert index position.
Example::
>>> _infer_single_target_pattern('layer.0.experts.0.w.weight', 'layer.0.experts')
('layer.0.experts.', '.w.weight')
"""
full_prefix = expert_prefix + "."
if not target.startswith(full_prefix):
raise ValueError(f"Target '{target}' does not start with '{full_prefix}'")
remainder = target[len(full_prefix) :] # e.g. '0.w.weight'
_idx_str, _, after_idx = remainder.partition(".")
suffix = "." + after_idx if after_idx else ""
return full_prefix, suffix
def _register_nested_parameter(gm: fx.GraphModule, dotted_name: str, param: nn.Parameter) -> None:
"""Register a parameter at a nested dotted path, creating intermediate modules as needed."""
parts = dotted_name.split(".")
current: nn.Module = gm
for part in parts[:-1]:
if hasattr(current, part):
current = getattr(current, part)
else:
new_mod = nn.Module()
current.add_module(part, new_mod)
current = new_mod
current.register_parameter(parts[-1], param)
class _MoeExpertProbe(TorchDispatchMode):
"""Dispatch mode that records parameter tensor IDs flowing into ``torch_moe``-family ops.
Used by :func:`_find_moe_module_lists` to discover which ``nn.ModuleList``
instances provide expert weights without relying on attribute naming conventions.
"""
# MOE custom ops whose list arguments represent per-expert weight tensors.
_MOE_OP_NAMES = ("torch_moe", "torch_quant_fp8_moe", "torch_quant_nvfp4_moe")
def __init__(self):
super().__init__()
self.captured_param_ids: set = set()
self._moe_ops = self._collect_moe_ops()
@classmethod
def _collect_moe_ops(cls) -> set:
ops: set = set()
for name in cls._MOE_OP_NAMES:
try:
ops.add(getattr(torch.ops.auto_deploy, name).default)
except AttributeError:
pass
return ops
def __torch_dispatch__(self, func, types, args=(), kwargs=None):
kwargs = kwargs or {}
if func in self._moe_ops:
for arg in list(args) + list(kwargs.values()):
if isinstance(arg, (list, tuple)):
for item in arg:
if isinstance(item, torch.Tensor):
self.captured_param_ids.add(id(item))
return func(*args, **kwargs)
def _find_moe_module_lists(
model: nn.Module,
args: Optional[Tuple[Any, ...]] = None,
kwargs: Optional[Dict[str, Any]] = None,
) -> Dict[str, Tuple[nn.Module, str, nn.ModuleList]]:
"""Identify ``nn.ModuleList`` instances whose parameters feed into ``torch_moe`` ops.
Runs a lightweight forward pass with :class:`_MoeExpertProbe` active to
discover which ``nn.ModuleList`` children of the model contribute
per-expert weight tensors to ``torch_moe``-family custom ops.
Returns:
Mapping of *module_list_path* to ``(parent_module, attr_name, module_list)``.
"""
# Build reverse map: id(param) → (parent_module, attr_name, module_list, full_path)
param_to_modlist: Dict[int, Tuple[nn.Module, str, nn.ModuleList, str]] = {}
for name, module in model.named_modules():
for attr_name, child in module.named_children():
if isinstance(child, nn.ModuleList) and len(child) > 0:
ml_path = f"{name}.{attr_name}" if name else attr_name
for param in child.parameters():
param_to_modlist[id(param)] = (module, attr_name, child, ml_path)
# Run a quick forward pass to see which params flow into MOE ops.
probe = _MoeExpertProbe()
with torch.inference_mode(), probe:
model(*(args or ()), **(kwargs or {}))
# Cross-reference captured tensor IDs with ModuleList parameters.
result: Dict[str, Tuple[nn.Module, str, nn.ModuleList]] = {}
for pid in probe.captured_param_ids:
if pid in param_to_modlist:
parent, attr_name, mod_list, path = param_to_modlist[pid]
if path not in result:
result[path] = (parent, attr_name, mod_list)
return result
def _reduce_moe_experts(
model: nn.Module,
min_num_experts: int,
args: Optional[Tuple[Any, ...]] = None,
kwargs: Optional[Dict[str, Any]] = None,
) -> List[Dict[str, Any]]:
"""Reduce MOE expert ``nn.ModuleList``s for faster export tracing.
Uses a probe forward pass to identify which ``nn.ModuleList`` instances
feed into ``torch_moe``-family custom ops (see :func:`_find_moe_module_lists`),
then truncates each to *min_num_experts* entries. The returned list of dicts
carries the metadata needed by :func:`_restore_moe_experts` and
:func:`_expand_moe_experts_in_graph`.
"""
if min_num_experts < 1:
raise ValueError(f"min_num_experts must be >= 1, got {min_num_experts}")
moe_lists = _find_moe_module_lists(model, args, kwargs)
reductions: List[Dict[str, Any]] = []
for path, (parent, attr_name, mod_list) in moe_lists.items():
orig_count = len(mod_list)
if orig_count <= min_num_experts:
continue
reductions.append(
{
"module": parent,
"attr_name": attr_name,
"original_list": mod_list,
"original_count": orig_count,
"expert_prefix": path,
}
)
setattr(parent, attr_name, nn.ModuleList(list(mod_list[:min_num_experts])))
ad_logger.info(
f"Reduced MOE experts in '{path}' from {orig_count} to "
f"{min_num_experts} for faster export"
)
return reductions
def _restore_moe_experts(reductions: List[Dict[str, Any]]) -> None:
"""Restore MOE expert ``nn.ModuleList``s to their original state."""
for info in reductions:
setattr(info["module"], info["attr_name"], info["original_list"])
def _find_original_num_experts(target: str, reductions: List[Dict[str, Any]]) -> Optional[int]:
"""Return the original expert count for a ``get_attr`` *target*, or ``None``."""
for info in reductions:
if target.startswith(info["expert_prefix"] + "."):
return info["original_count"]
return None
def _find_expert_prefix(target: str, reductions: List[Dict[str, Any]]) -> Optional[str]:
"""Return the ``expert_prefix`` that matches *target*, or ``None``."""
for info in reductions:
if target.startswith(info["expert_prefix"] + "."):
return info["expert_prefix"]
return None
def _expand_moe_experts_in_graph(
gm: fx.GraphModule,
model: nn.Module,
reductions: List[Dict[str, Any]],
) -> None:
"""Expand MOE expert weights in *gm* to match the full *model*.
After exporting with a reduced number of experts this function:
1. Finds every ``torch_moe``-family node whose weight-list arguments are
shorter than the original expert count.
2. Registers the missing expert parameters on *gm* (copied from the
already-restored *model*).
3. Creates the corresponding ``get_attr`` nodes and extends the weight
lists in the call node so the graph is equivalent to a full export.
"""
if not reductions:
return
# MOE ops whose arguments include per-expert weight lists (from index 3 onward)
moe_ops = {
torch.ops.auto_deploy.torch_moe,
torch.ops.auto_deploy.torch_quant_fp8_moe,
torch.ops.auto_deploy.torch_quant_nvfp4_moe,
}
graph = gm.graph
num_expanded = 0
for node in list(graph.nodes):
if not is_op(node, moe_ops):
continue
# Collect indices of list-of-node arguments (expert weight/scale lists)
list_arg_indices = [
i
for i in range(3, len(node.args))
if isinstance(node.args[i], (list, tuple)) and len(node.args[i]) > 0
]
if not list_arg_indices:
continue
first_list = node.args[list_arg_indices[0]]
current_num = len(first_list)
first_target = first_list[0].target
original_num = _find_original_num_experts(first_target, reductions)
if original_num is None or original_num <= current_num:
continue
ad_logger.debug(
f"Expanding MOE node '{node.name}': {current_num} -> {original_num} experts"
)
# Insert new get_attr nodes at the very beginning of the graph
first_graph_node = next(iter(graph.nodes))
new_args = list(node.args)
for li in list_arg_indices:
weight_list = list(node.args[li])
# Determine the naming pattern: prefix + <expert_idx> + suffix
if len(weight_list) >= 2:
prefix, suffix = _infer_target_pattern(weight_list[0].target, weight_list[1].target)
else:
ep = _find_expert_prefix(weight_list[0].target, reductions)
assert ep is not None, (
f"Could not find expert prefix for target '{weight_list[0].target}'"
)
prefix, suffix = _infer_single_target_pattern(weight_list[0].target, ep)
# Add the missing expert weights
for expert_idx in range(current_num, original_num):
new_target = f"{prefix}{expert_idx}{suffix}"
# Copy the parameter from the restored model
orig_param = model.get_parameter(new_target)
_register_nested_parameter(gm, new_target, nn.Parameter(orig_param.data))
# Create a get_attr node
with graph.inserting_before(first_graph_node):
new_node = graph.get_attr(new_target)
new_node.meta["val"] = gm.get_parameter(new_target)
weight_list.append(new_node)
new_args[li] = weight_list
node.args = tuple(new_args)
num_expanded += 1
if num_expanded:
canonicalize_graph(gm)
ad_logger.info(f"Expanded {num_expanded} MOE node(s) in the exported graph")
def _clean_up_device_info(gm: fx.GraphModule) -> None:
"""Correct device information in the graph."""
devices = {t.device for _, t in gm.named_parameters()}
@ -330,6 +634,7 @@ def torch_export_to_gm(
strict: bool = False,
patch_configs: Optional[Dict[str, Union[dict, Any]]] = None,
patch_list: Optional[List[str]] = None,
num_moe_experts_for_export: Optional[int] = None,
) -> fx.GraphModule:
"""torch's export with wrapping into GraphModule + useful additions to the resulting module.
@ -341,6 +646,8 @@ def torch_export_to_gm(
4. Retain load hooks for state_dict loading from the original module.
5. Manage parameter aliasing in the model.
6. Remove assertions from the graph.
7. Optionally speed up export for MOE models by tracing with fewer experts
and expanding the graph afterward.
Args:
model: The model to export
@ -353,6 +660,10 @@ def torch_export_to_gm(
will be applied with default settings.
patch_list: Optional list of patch names to apply with default settings.
Cannot be used together with patch_configs.
num_moe_experts_for_export: If set, only this many experts are traced during
``torch.export`` (the graph is expanded to include all experts afterward).
This can dramatically speed up export for large MOE models.
Recommended value: 2.
"""
def _capture_fn(model, args, kwargs):
@ -361,11 +672,23 @@ def torch_export_to_gm(
assert isinstance(egm, fx.GraphModule)
return egm
# Optionally reduce MOE experts for faster export tracing
moe_reductions: List[Dict[str, Any]] = []
if num_moe_experts_for_export is not None:
moe_reductions = _reduce_moe_experts(model, num_moe_experts_for_export, args, kwargs)
# run capture with export
egm = run_forward_for_capture(
model, _capture_fn, args, kwargs, clone, patch_list=patch_list, patch_configs=patch_configs
)
# Restore full expert lists on the source model and expand the graph to include
# all expert weights. This must happen before the load-hook / deduplication
# post-processing so that those steps see the complete set of parameters.
if moe_reductions:
_restore_moe_experts(moe_reductions)
_expand_moe_experts_in_graph(egm, model, moe_reductions)
# Export strips away all methods not traced during forward. The model could have
# load hooks that contain logic for correct state_dict loading. We need to add those
# hooks back to the exported graph module.
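
As a usage illustration, here is a minimal sketch of the new export path; the import path and the ``model``/``x`` placeholders are assumptions, and the call pattern mirrors the tests added below:

# Sketch only: `model` is assumed to be an MOE model whose expert weights feed
# torch.ops.auto_deploy.torch_moe, and `x` a sample input tensor.
# (See SimpleMoEForExport in the test file below for a concrete model definition.)
from tensorrt_llm._torch.auto_deploy.export import torch_export_to_gm  # assumed import path

# Trace torch.export with only 2 experts, then expand the graph back to all experts.
gm = torch_export_to_gm(model, (x,), num_moe_experts_for_export=2)
y = gm(x)  # numerically equivalent to a full export of the same model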


@ -41,6 +41,13 @@ class ExportToGMConfig(TransformConfig):
"Default is to apply all registered patches.",
default=None,
)
num_moe_experts_for_export: Optional[int] = Field(
description="If set, only this many MOE experts are traced during torch.export, "
"and the graph is expanded to include all experts afterwards. "
"This can dramatically speed up export for large MOE models (e.g. 256 experts). "
"Recommended value: 2.",
default=None,
)
@contextmanager
@ -183,6 +190,7 @@ class ExportToGM(BaseTransform):
clone=self.config.clone_state_dict,
strict=self.config.strict,
patch_list=self.config.patch_list,
num_moe_experts_for_export=self.config.num_moe_experts_for_export,
)
# post process the sub graph module


@ -239,3 +239,223 @@ def test_deduplicate_during_export(model_cls: Type[nn.Module], device_export: st
# Test loading fc1.weight into gm.fc1.weight. State dict does not contain fc3.weight
check_parameter_loading("fc3.weight", "fc1.weight")
# ---------------------------------------------------------------------------
# MOE export with reduced experts
# ---------------------------------------------------------------------------
# Ensure the auto_deploy::torch_moe custom op is registered
import tensorrt_llm._torch.auto_deploy.custom_ops # noqa: F401, E402
class SimpleMoEForExport(nn.Module):
"""A minimal MOE model using the ``auto_deploy::torch_moe`` custom op.
The *expert_attr_name* parameter controls the attribute under which the
expert ``nn.ModuleList`` is stored. This lets tests verify that the
probe-based expert discovery does **not** rely on the name ``"experts"``.
"""
def __init__(
self,
num_experts: int = 8,
hidden_dim: int = 16,
inter_dim: int = 32,
expert_attr_name: str = "experts",
):
super().__init__()
self.num_experts = num_experts
self.hidden_dim = hidden_dim
self.expert_attr_name = expert_attr_name
expert_list = nn.ModuleList(
[
nn.ModuleDict(
{
"gate_proj": nn.Linear(hidden_dim, inter_dim, bias=False),
"down_proj": nn.Linear(inter_dim, hidden_dim, bias=False),
"up_proj": nn.Linear(hidden_dim, inter_dim, bias=False),
}
)
for _ in range(num_experts)
]
)
# Store under a configurable attribute name so tests can verify
# that the probe does NOT rely on the name being "experts".
setattr(self, expert_attr_name, expert_list)
self.gate = nn.Linear(hidden_dim, num_experts, bias=False)
@property
def _expert_list(self) -> nn.ModuleList:
return getattr(self, self.expert_attr_name)
def forward(self, x):
experts = self._expert_list
router_logits = self.gate(x)
routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
routing_weights, selected_experts = torch.topk(routing_weights, 2, dim=-1)
routing_weights = routing_weights.to(x.dtype)
return torch.ops.auto_deploy.torch_moe(
x,
selected_experts,
routing_weights,
w1_weight=[e["gate_proj"].weight for e in experts],
w2_weight=[e["down_proj"].weight for e in experts],
w3_weight=[e["up_proj"].weight for e in experts],
)
@pytest.mark.parametrize("expert_attr_name", ["experts", "mlp_bank"])
@pytest.mark.parametrize("num_experts", [4, 8])
@pytest.mark.parametrize("num_moe_experts_for_export", [1, 2])
@pytest.mark.parametrize("device", ["cuda"])
def test_moe_export_with_reduced_experts(
num_experts, num_moe_experts_for_export, device, expert_attr_name
):
"""Export with fewer experts then expand — result must match full export."""
mod = SimpleMoEForExport(num_experts=num_experts, expert_attr_name=expert_attr_name).to(device)
mod.eval()
x = torch.randn(4, mod.hidden_dim, device=device)
# Full export (baseline)
gm_full = torch_export_to_gm(mod, (x,))
# Export with reduced experts
gm_reduced = torch_export_to_gm(
mod,
(x,),
num_moe_experts_for_export=num_moe_experts_for_export,
)
# --- structural check: both graphs must have the right expert count ---
def _count_moe_experts(gm):
for node in gm.graph.nodes:
if node.op == "call_function" and "torch_moe" in str(node.target):
return len(node.args[3]) # w1_weight list length
return 0
assert _count_moe_experts(gm_full) == num_experts
assert _count_moe_experts(gm_reduced) == num_experts
# --- numerical check: outputs must match ---
y_full = gm_full(x)
y_reduced = gm_reduced(x)
assert all_close(y_full, y_reduced), "Reduced-expert export output differs from full export"
# --- state-dict round-trip: loading the original weights must work ---
sd = mod.state_dict()
gm_reduced.load_state_dict(sd, strict=False)
y_loaded = gm_reduced(x)
assert all_close(y_loaded, y_full), "Output after state-dict reload differs"
# --- verify source model is unmodified ---
assert len(mod._expert_list) == num_experts, "Source model experts were not restored"
# ---------------------------------------------------------------------------
# Real-model MOE export: GLM4 MoE Lite
# ---------------------------------------------------------------------------
try:
from tensorrt_llm._torch.auto_deploy.models.custom.modeling_glm4_moe_lite import (
Glm4MoeLiteConfig,
Glm4MoeLiteForCausalLM,
)
_HAS_GLM4 = True
except ImportError:
_HAS_GLM4 = False
if _HAS_GLM4:
def _make_tiny_glm4_config(n_routed_experts: int = 8) -> Glm4MoeLiteConfig:
"""Create a minimal ``Glm4MoeLiteConfig`` suitable for unit tests."""
return Glm4MoeLiteConfig(
vocab_size=256,
hidden_size=64,
intermediate_size=128,
num_hidden_layers=2,
num_attention_heads=4,
num_key_value_heads=4,
q_lora_rank=32,
kv_lora_rank=32,
qk_nope_head_dim=12,
qk_rope_head_dim=4,
v_head_dim=16,
n_routed_experts=n_routed_experts,
n_shared_experts=1,
num_experts_per_tok=2,
moe_intermediate_size=64,
n_group=1,
topk_group=1,
routed_scaling_factor=1.0,
norm_topk_prob=True,
first_k_dense_replace=1, # layer 0 = dense MLP, layer 1 = MoE
max_position_embeddings=128,
rope_scaling=None,
pad_token_id=0,
)
def _count_moe_experts_in_graph(gm: GraphModule) -> int:
"""Return the number of experts in the first ``torch_moe`` call in *gm*."""
for node in gm.graph.nodes:
if node.op == "call_function" and "torch_moe" in str(node.target):
return len(node.args[3]) # w1_weight list length
return 0
@pytest.mark.skipif(not _HAS_GLM4, reason="GLM4 MoE Lite model not available on this branch")
@pytest.mark.skipif(
not torch.cuda.is_available(), reason="GLM4 MoE Lite requires CUDA (uses noaux_tc_op)"
)
@pytest.mark.parametrize("n_routed_experts", [8, 16])
@pytest.mark.parametrize("num_moe_experts_for_export", [2])
def test_glm4_moe_lite_export_with_reduced_experts(
n_routed_experts, num_moe_experts_for_export
):
"""Export a tiny ``Glm4MoeLiteForCausalLM`` with reduced experts and verify
that the expanded graph has the correct structure and accepts the original
state dict.
"""
# GLM4 MoE Lite uses noaux_tc_op, which is CUDA-only, so we must use a CUDA device
device = "cuda"
config = _make_tiny_glm4_config(n_routed_experts=n_routed_experts)
model = Glm4MoeLiteForCausalLM(config).to(device)
model.eval()
input_ids = torch.randint(0, config.vocab_size, (1, 8), device=device)
position_ids = torch.arange(8, device=device).unsqueeze(0)
sample_kwargs = {"input_ids": input_ids, "position_ids": position_ids}
# --- full export (baseline) ---
gm_full = torch_export_to_gm(model, kwargs=sample_kwargs)
# --- export with reduced experts ---
gm_reduced = torch_export_to_gm(
model,
kwargs=sample_kwargs,
num_moe_experts_for_export=num_moe_experts_for_export,
)
# Structural: both graphs must expose all experts
assert _count_moe_experts_in_graph(gm_full) == n_routed_experts
assert _count_moe_experts_in_graph(gm_reduced) == n_routed_experts
# State-dict keys must match between full and reduced exports
full_keys = set(gm_full.state_dict().keys())
reduced_keys = set(gm_reduced.state_dict().keys())
assert full_keys == reduced_keys, (
f"State-dict key mismatch.\n"
f" Only in full: {full_keys - reduced_keys}\n"
f" Only in reduced: {reduced_keys - full_keys}"
)
# Load the original model weights into the reduced export graph
gm_reduced.load_state_dict(model.state_dict(), strict=False)
# Source model must be fully restored
for name, mod in model.named_modules():
if hasattr(mod, "experts") and isinstance(mod.experts, nn.ModuleList):
assert len(mod.experts) == n_routed_experts, (
f"Expert list in '{name}' was not restored to {n_routed_experts}"
)