From b4e9669d2cdb6e1f65f658797d52947dad65c94f Mon Sep 17 00:00:00 2001
From: Suyog Gupta <41447211+suyoggupta@users.noreply.github.com>
Date: Fri, 13 Feb 2026 16:59:30 -0800
Subject: [PATCH] [None][chore] Optimize MOE export by tracing with reduced experts and expanding graph (#11504)

Signed-off-by: Suyog Gupta <41447211+suyoggupta@users.noreply.github.com>
---
 .../_torch/auto_deploy/export/export.py       | 323 ++++++++++++++++++
 .../transform/library/export_to_gm.py         |   8 +
 .../singlegpu/transformations/test_export.py  | 220 ++++++++++++
 3 files changed, 551 insertions(+)

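Usage sketch (illustrative only, not part of the diff below): any model whose forward pass reaches the auto_deploy::torch_moe custom op can opt in through the new num_moe_experts_for_export keyword. The toy TinyMoE module and its sizes are hypothetical and simply mirror the SimpleMoEForExport test fixture added further down; the import paths, op call pattern, and keyword argument are the ones this patch touches.

import torch
import torch.nn as nn
import torch.nn.functional as F

import tensorrt_llm._torch.auto_deploy.custom_ops  # noqa: F401  # registers auto_deploy::torch_moe
from tensorrt_llm._torch.auto_deploy.export.export import torch_export_to_gm


class TinyMoE(nn.Module):
    # Hypothetical toy model; mirrors the SimpleMoEForExport fixture in the test diff below.
    def __init__(self, num_experts=8, hidden_dim=16, inter_dim=32):
        super().__init__()
        self.experts = nn.ModuleList(
            nn.ModuleDict(
                {
                    "gate_proj": nn.Linear(hidden_dim, inter_dim, bias=False),
                    "down_proj": nn.Linear(inter_dim, hidden_dim, bias=False),
                    "up_proj": nn.Linear(hidden_dim, inter_dim, bias=False),
                }
            )
            for _ in range(num_experts)
        )
        self.gate = nn.Linear(hidden_dim, num_experts, bias=False)

    def forward(self, x):
        weights = F.softmax(self.gate(x), dim=-1, dtype=torch.float)
        weights, selected = torch.topk(weights, 2, dim=-1)
        return torch.ops.auto_deploy.torch_moe(
            x,
            selected,
            weights.to(x.dtype),
            w1_weight=[e["gate_proj"].weight for e in self.experts],
            w2_weight=[e["down_proj"].weight for e in self.experts],
            w3_weight=[e["up_proj"].weight for e in self.experts],
        )


model = TinyMoE().to("cuda").eval()  # the unit tests exercise this path on CUDA
x = torch.randn(4, 16, device="cuda")

# Trace with only 2 experts; the exported graph is expanded back to all 8 afterwards.
gm = torch_export_to_gm(model, (x,), num_moe_experts_for_export=2)

Tracing then sees only two experts, so torch.export cost no longer scales with the expert count; the expansion step afterwards rebuilds the full per-expert weight lists in the exported graph.
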
diff --git a/tensorrt_llm/_torch/auto_deploy/export/export.py b/tensorrt_llm/_torch/auto_deploy/export/export.py
index 4265fea9b9..b76a72bc39 100644
--- a/tensorrt_llm/_torch/auto_deploy/export/export.py
+++ b/tensorrt_llm/_torch/auto_deploy/export/export.py
@@ -10,6 +10,7 @@ import torch
 import torch.export as te
 import torch.nn as nn
 from torch import fx
+from torch.utils._python_dispatch import TorchDispatchMode
 
 from ..utils._graph import canonicalize_graph, lift_to_meta, load_buffers_and_params, tree_to
 from ..utils.logger import ad_logger
@@ -25,6 +26,309 @@ except ImportError:
     torch_export_context = nullcontext
 
 
+# =====================================================================
+# MOE export optimization: reduce experts for faster tracing, then
+# expand the graph back to include all experts after export.
+# =====================================================================
+
+
+def _infer_target_pattern(target_0: str, target_1: str) -> Tuple[str, str]:
+    """Infer ``(prefix, suffix)`` from two consecutive expert-weight targets.
+
+    Compares two ``get_attr`` targets that differ only in the expert index and
+    returns ``(prefix, suffix)`` such that ``target == prefix + str(idx) + suffix``.
+
+    Example::
+
+        >>> _infer_target_pattern('experts.0.gate.weight', 'experts.1.gate.weight')
+        ('experts.', '.gate.weight')
+    """
+    parts_0 = target_0.split(".")
+    parts_1 = target_1.split(".")
+    if len(parts_0) != len(parts_1):
+        raise ValueError(f"Target structure mismatch: {target_0} vs {target_1}")
+
+    diff_positions = [i for i, (a, b) in enumerate(zip(parts_0, parts_1)) if a != b]
+    if len(diff_positions) != 1:
+        raise ValueError(
+            f"Expected exactly one differing part, found {len(diff_positions)}: "
+            f"{target_0} vs {target_1}"
+        )
+
+    idx = diff_positions[0]
+    prefix = ".".join(parts_0[:idx]) + "." if idx > 0 else ""
+    suffix = "." + ".".join(parts_0[idx + 1 :]) if idx < len(parts_0) - 1 else ""
+    return prefix, suffix
+
+
+def _infer_single_target_pattern(target: str, expert_prefix: str) -> Tuple[str, str]:
+    """Infer ``(prefix, suffix)`` when only one expert target is available.
+
+    Uses the known *expert_prefix* to locate the expert index position.
+
+    Example::
+
+        >>> _infer_single_target_pattern('layer.0.experts.0.w.weight', 'layer.0.experts')
+        ('layer.0.experts.', '.w.weight')
+    """
+    full_prefix = expert_prefix + "."
+    if not target.startswith(full_prefix):
+        raise ValueError(f"Target '{target}' does not start with '{full_prefix}'")
+    remainder = target[len(full_prefix) :]  # e.g. '0.w.weight'
+    _idx_str, _, after_idx = remainder.partition(".")
+    suffix = "." + after_idx if after_idx else ""
+    return full_prefix, suffix
+
+
+def _register_nested_parameter(gm: fx.GraphModule, dotted_name: str, param: nn.Parameter) -> None:
+    """Register a parameter at a nested dotted path, creating intermediate modules as needed."""
+    parts = dotted_name.split(".")
+    current: nn.Module = gm
+    for part in parts[:-1]:
+        if hasattr(current, part):
+            current = getattr(current, part)
+        else:
+            new_mod = nn.Module()
+            current.add_module(part, new_mod)
+            current = new_mod
+    current.register_parameter(parts[-1], param)
+
+
+class _MoeExpertProbe(TorchDispatchMode):
+    """Dispatch mode that records parameter tensor IDs flowing into ``torch_moe``-family ops.
+
+    Used by :func:`_find_moe_module_lists` to discover which ``nn.ModuleList``
+    instances provide expert weights without relying on attribute naming conventions.
+    """
+
+    # MOE custom ops whose list arguments represent per-expert weight tensors.
+    _MOE_OP_NAMES = ("torch_moe", "torch_quant_fp8_moe", "torch_quant_nvfp4_moe")
+
+    def __init__(self):
+        super().__init__()
+        self.captured_param_ids: set = set()
+        self._moe_ops = self._collect_moe_ops()
+
+    @classmethod
+    def _collect_moe_ops(cls) -> set:
+        ops: set = set()
+        for name in cls._MOE_OP_NAMES:
+            try:
+                ops.add(getattr(torch.ops.auto_deploy, name).default)
+            except AttributeError:
+                pass
+        return ops
+
+    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
+        kwargs = kwargs or {}
+        if func in self._moe_ops:
+            for arg in list(args) + list(kwargs.values()):
+                if isinstance(arg, (list, tuple)):
+                    for item in arg:
+                        if isinstance(item, torch.Tensor):
+                            self.captured_param_ids.add(id(item))
+        return func(*args, **kwargs)
+
+
+def _find_moe_module_lists(
+    model: nn.Module,
+    args: Optional[Tuple[Any, ...]] = None,
+    kwargs: Optional[Dict[str, Any]] = None,
+) -> Dict[str, Tuple[nn.Module, str, nn.ModuleList]]:
+    """Identify ``nn.ModuleList`` instances whose parameters feed into ``torch_moe`` ops.
+
+    Runs a lightweight forward pass with :class:`_MoeExpertProbe` active to
+    discover which ``nn.ModuleList`` children of the model contribute
+    per-expert weight tensors to ``torch_moe``-family custom ops.
+
+    Returns:
+        Mapping of *module_list_path* → ``(parent_module, attr_name, module_list)``.
+    """
+    # Build reverse map: id(param) → (parent_module, attr_name, module_list, full_path)
+    param_to_modlist: Dict[int, Tuple[nn.Module, str, nn.ModuleList, str]] = {}
+    for name, module in model.named_modules():
+        for attr_name, child in module.named_children():
+            if isinstance(child, nn.ModuleList) and len(child) > 0:
+                ml_path = f"{name}.{attr_name}" if name else attr_name
+                for param in child.parameters():
+                    param_to_modlist[id(param)] = (module, attr_name, child, ml_path)
+
+    # Run a quick forward pass to see which params flow into MOE ops.
+    probe = _MoeExpertProbe()
+    with torch.inference_mode(), probe:
+        model(*(args or ()), **(kwargs or {}))
+
+    # Cross-reference captured tensor IDs with ModuleList parameters.
+    result: Dict[str, Tuple[nn.Module, str, nn.ModuleList]] = {}
+    for pid in probe.captured_param_ids:
+        if pid in param_to_modlist:
+            parent, attr_name, mod_list, path = param_to_modlist[pid]
+            if path not in result:
+                result[path] = (parent, attr_name, mod_list)
+
+    return result
+
+
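A standalone sketch of the probing mechanism above (illustrative only; it uses a stock embedding op instead of the torch_moe ops, and hypothetical class and variable names): a TorchDispatchMode observes every op call during a forward pass, so recording id()s of tensor arguments and cross-referencing them against named_parameters() reveals which module owns the weights an op consumed. That is how _MoeExpertProbe and _find_moe_module_lists locate the expert nn.ModuleLists without assuming they are named "experts".

import torch
import torch.nn as nn
from torch.utils._python_dispatch import TorchDispatchMode


class EmbeddingWeightProbe(TorchDispatchMode):
    """Record id()s of tensors reaching embedding ops (stand-in for the MOE op check)."""

    def __init__(self):
        super().__init__()
        self.captured = set()

    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        if "embedding" in str(func):  # the real probe matches the collected torch_moe overloads here
            for a in list(args) + list(kwargs.values()):
                if isinstance(a, torch.Tensor):
                    self.captured.add(id(a))
        return func(*args, **kwargs)


model = nn.Sequential(nn.Embedding(10, 4), nn.Linear(4, 4))
probe = EmbeddingWeightProbe()
with torch.inference_mode(), probe:
    model(torch.tensor([[1, 2, 3]]))

# Cross-reference captured ids against the model's parameters, as
# _find_moe_module_lists does against nn.ModuleList members.
hits = [name for name, p in model.named_parameters() if id(p) in probe.captured]
print(hits)  # expected: ['0.weight']

The Linear weight is not reported because only the embedding op is matched; in the patch the same filtering is done by comparing func against the torch_moe op overloads collected in _collect_moe_ops.
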
+def _reduce_moe_experts(
+    model: nn.Module,
+    min_num_experts: int,
+    args: Optional[Tuple[Any, ...]] = None,
+    kwargs: Optional[Dict[str, Any]] = None,
+) -> List[Dict[str, Any]]:
+    """Reduce MOE expert ``nn.ModuleList``s for faster export tracing.
+
+    Uses a probe forward pass to identify which ``nn.ModuleList`` instances
+    feed into ``torch_moe``-family custom ops (see :func:`_find_moe_module_lists`),
+    then truncates each to *min_num_experts* entries. The returned list of dicts
+    carries the metadata needed by :func:`_restore_moe_experts` and
+    :func:`_expand_moe_experts_in_graph`.
+    """
+    if min_num_experts < 1:
+        raise ValueError(f"min_num_experts must be >= 1, got {min_num_experts}")
+
+    moe_lists = _find_moe_module_lists(model, args, kwargs)
+
+    reductions: List[Dict[str, Any]] = []
+    for path, (parent, attr_name, mod_list) in moe_lists.items():
+        orig_count = len(mod_list)
+        if orig_count <= min_num_experts:
+            continue
+
+        reductions.append(
+            {
+                "module": parent,
+                "attr_name": attr_name,
+                "original_list": mod_list,
+                "original_count": orig_count,
+                "expert_prefix": path,
+            }
+        )
+        setattr(parent, attr_name, nn.ModuleList(list(mod_list[:min_num_experts])))
+        ad_logger.info(
+            f"Reduced MOE experts in '{path}' from {orig_count} to "
+            f"{min_num_experts} for faster export"
+        )
+    return reductions
+
+
+def _restore_moe_experts(reductions: List[Dict[str, Any]]) -> None:
+    """Restore MOE expert ``nn.ModuleList``s to their original state."""
+    for info in reductions:
+        setattr(info["module"], info["attr_name"], info["original_list"])
+
+
+def _find_original_num_experts(target: str, reductions: List[Dict[str, Any]]) -> Optional[int]:
+    """Return the original expert count for a ``get_attr`` *target*, or ``None``."""
+    for info in reductions:
+        if target.startswith(info["expert_prefix"] + "."):
+            return info["original_count"]
+    return None
+
+
+def _find_expert_prefix(target: str, reductions: List[Dict[str, Any]]) -> Optional[str]:
+    """Return the ``expert_prefix`` that matches *target*, or ``None``."""
+    for info in reductions:
+        if target.startswith(info["expert_prefix"] + "."):
+            return info["expert_prefix"]
+    return None
+
+
+def _expand_moe_experts_in_graph(
+    gm: fx.GraphModule,
+    model: nn.Module,
+    reductions: List[Dict[str, Any]],
+) -> None:
+    """Expand MOE expert weights in *gm* to match the full *model*.
+
+    After exporting with a reduced number of experts this function:
+
+    1. Finds every ``torch_moe``-family node whose weight-list arguments are
+       shorter than the original expert count.
+    2. Registers the missing expert parameters on *gm* (copied from the
+       already-restored *model*).
+    3. Creates the corresponding ``get_attr`` nodes and extends the weight
+       lists in the call node so the graph is equivalent to a full export.
+    """
+    if not reductions:
+        return
+
+    # MOE ops whose arguments include per-expert weight lists (from index 3 onward)
+    moe_ops = {
+        torch.ops.auto_deploy.torch_moe,
+        torch.ops.auto_deploy.torch_quant_fp8_moe,
+        torch.ops.auto_deploy.torch_quant_nvfp4_moe,
+    }
+
+    graph = gm.graph
+    num_expanded = 0
+
+    for node in list(graph.nodes):
+        if not is_op(node, moe_ops):
+            continue
+
+        # Collect indices of list-of-node arguments (expert weight/scale lists)
+        list_arg_indices = [
+            i
+            for i in range(3, len(node.args))
+            if isinstance(node.args[i], (list, tuple)) and len(node.args[i]) > 0
+        ]
+        if not list_arg_indices:
+            continue
+
+        first_list = node.args[list_arg_indices[0]]
+        current_num = len(first_list)
+        first_target = first_list[0].target
+        original_num = _find_original_num_experts(first_target, reductions)
+
+        if original_num is None or original_num <= current_num:
+            continue
+
+        ad_logger.debug(
+            f"Expanding MOE node '{node.name}': {current_num} -> {original_num} experts"
+        )
+
+        # Insert new get_attr nodes at the very beginning of the graph
+        first_graph_node = next(iter(graph.nodes))
+
+        new_args = list(node.args)
+        for li in list_arg_indices:
+            weight_list = list(node.args[li])
+
+            # Determine the naming pattern: prefix + index + suffix
+            if len(weight_list) >= 2:
+                prefix, suffix = _infer_target_pattern(weight_list[0].target, weight_list[1].target)
+            else:
+                ep = _find_expert_prefix(weight_list[0].target, reductions)
+                assert ep is not None, (
+                    f"Could not find expert prefix for target '{weight_list[0].target}'"
+                )
+                prefix, suffix = _infer_single_target_pattern(weight_list[0].target, ep)
+
+            # Add the missing expert weights
+            for expert_idx in range(current_num, original_num):
+                new_target = f"{prefix}{expert_idx}{suffix}"
+
+                # Copy the parameter from the restored model
+                orig_param = model.get_parameter(new_target)
+                _register_nested_parameter(gm, new_target, nn.Parameter(orig_param.data))
+
+                # Create a get_attr node
+                with graph.inserting_before(first_graph_node):
+                    new_node = graph.get_attr(new_target)
+                    new_node.meta["val"] = gm.get_parameter(new_target)
+
+                weight_list.append(new_node)
+
+            new_args[li] = weight_list
+
+        node.args = tuple(new_args)
+        num_expanded += 1
+
+    if num_expanded:
+        canonicalize_graph(gm)
+        ad_logger.info(f"Expanded {num_expanded} MOE node(s) in the exported graph")
+
+
 def _clean_up_device_info(gm: fx.GraphModule) -> None:
     """Correct device information in the graph."""
     devices = {t.device for _, t in gm.named_parameters()}
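For readers unfamiliar with the fx surgery involved, here is a minimal, self-contained sketch of the same three steps on a toy graph (hypothetical Tiny, w0, and w1 names; a plain torch.stack call stands in for the torch_moe op): register the new parameter, create a get_attr node, and append it to the call's list argument.

import torch
import torch.fx as fx
import torch.nn as nn


class Tiny(nn.Module):
    def __init__(self):
        super().__init__()
        self.w0 = nn.Parameter(torch.randn(4))

    def forward(self, x):
        # A call whose argument is a list of weights, analogous to the
        # per-expert weight lists of the torch_moe ops.
        return x + torch.stack([self.w0]).sum(dim=0)


gm = fx.symbolic_trace(Tiny())

# 1) Register the "missing" weight on the GraphModule (the patch's
#    _register_nested_parameter does the same for nested dotted paths).
gm.register_parameter("w1", nn.Parameter(torch.randn(4)))

# 2) Create a get_attr node and 3) append it to the list argument of the call,
#    mirroring _expand_moe_experts_in_graph.
for node in gm.graph.nodes:
    if node.op == "call_function" and node.target is torch.stack:
        with gm.graph.inserting_before(node):
            w1_node = gm.graph.get_attr("w1")
        weights = list(node.args[0])
        weights.append(w1_node)
        node.args = (weights,) + tuple(node.args[1:])

gm.recompile()
print(gm(torch.zeros(4)))  # w0 + w1: both weights now flow through the graph

The patch additionally copies the real expert weights from the restored source model, handles dotted parameter paths, and re-canonicalizes the graph afterwards.
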
+ """ + if not reductions: + return + + # MOE ops whose arguments include per-expert weight lists (from index 3 onward) + moe_ops = { + torch.ops.auto_deploy.torch_moe, + torch.ops.auto_deploy.torch_quant_fp8_moe, + torch.ops.auto_deploy.torch_quant_nvfp4_moe, + } + + graph = gm.graph + num_expanded = 0 + + for node in list(graph.nodes): + if not is_op(node, moe_ops): + continue + + # Collect indices of list-of-node arguments (expert weight/scale lists) + list_arg_indices = [ + i + for i in range(3, len(node.args)) + if isinstance(node.args[i], (list, tuple)) and len(node.args[i]) > 0 + ] + if not list_arg_indices: + continue + + first_list = node.args[list_arg_indices[0]] + current_num = len(first_list) + first_target = first_list[0].target + original_num = _find_original_num_experts(first_target, reductions) + + if original_num is None or original_num <= current_num: + continue + + ad_logger.debug( + f"Expanding MOE node '{node.name}': {current_num} -> {original_num} experts" + ) + + # Insert new get_attr nodes at the very beginning of the graph + first_graph_node = next(iter(graph.nodes)) + + new_args = list(node.args) + for li in list_arg_indices: + weight_list = list(node.args[li]) + + # Determine the naming pattern: prefix + + suffix + if len(weight_list) >= 2: + prefix, suffix = _infer_target_pattern(weight_list[0].target, weight_list[1].target) + else: + ep = _find_expert_prefix(weight_list[0].target, reductions) + assert ep is not None, ( + f"Could not find expert prefix for target '{weight_list[0].target}'" + ) + prefix, suffix = _infer_single_target_pattern(weight_list[0].target, ep) + + # Add the missing expert weights + for expert_idx in range(current_num, original_num): + new_target = f"{prefix}{expert_idx}{suffix}" + + # Copy the parameter from the restored model + orig_param = model.get_parameter(new_target) + _register_nested_parameter(gm, new_target, nn.Parameter(orig_param.data)) + + # Create a get_attr node + with graph.inserting_before(first_graph_node): + new_node = graph.get_attr(new_target) + new_node.meta["val"] = gm.get_parameter(new_target) + + weight_list.append(new_node) + + new_args[li] = weight_list + + node.args = tuple(new_args) + num_expanded += 1 + + if num_expanded: + canonicalize_graph(gm) + ad_logger.info(f"Expanded {num_expanded} MOE node(s) in the exported graph") + + def _clean_up_device_info(gm: fx.GraphModule) -> None: """Correct device information in the graph.""" devices = {t.device for _, t in gm.named_parameters()} @@ -330,6 +634,7 @@ def torch_export_to_gm( strict: bool = False, patch_configs: Optional[Dict[str, Union[dict, Any]]] = None, patch_list: Optional[List[str]] = None, + num_moe_experts_for_export: Optional[int] = None, ) -> fx.GraphModule: """torch's export with wrapping into GraphModule + useful additions to the resulting module. @@ -341,6 +646,8 @@ def torch_export_to_gm( 4. Retain load hooks for state_dict loading from the original module. 5. Manage parameter aliasing in the model. 6. Remove assertions from the graph. + 7. Optionally speed up export for MOE models by tracing with fewer experts + and expanding the graph afterward. Args: model: The model to export @@ -353,6 +660,10 @@ def torch_export_to_gm( will be applied with default settings. patch_list: Optional list of patch names to apply with default settings. Cannot be used together with patch_configs. + num_moe_experts_for_export: If set, only this many experts are traced during + ``torch.export`` (the graph is expanded to include all experts afterward). 
+            This can dramatically speed up export for large MOE models.
+            Recommended value: 2.
     """
 
     def _capture_fn(model, args, kwargs):
@@ -361,11 +672,23 @@ def torch_export_to_gm(
         assert isinstance(egm, fx.GraphModule)
         return egm
 
+    # Optionally reduce MOE experts for faster export tracing
+    moe_reductions: List[Dict[str, Any]] = []
+    if num_moe_experts_for_export is not None:
+        moe_reductions = _reduce_moe_experts(model, num_moe_experts_for_export, args, kwargs)
+
     # run capture with export
     egm = run_forward_for_capture(
         model, _capture_fn, args, kwargs, clone, patch_list=patch_list, patch_configs=patch_configs
     )
 
+    # Restore full expert lists on the source model and expand the graph to include
+    # all expert weights. This must happen before the load-hook / deduplication
+    # post-processing so that those steps see the complete set of parameters.
+    if moe_reductions:
+        _restore_moe_experts(moe_reductions)
+        _expand_moe_experts_in_graph(egm, model, moe_reductions)
+
     # Export strips away all methods not traced during forward. The model could have
     # load hooks that contain logic for correct state_dict loading. We need to add those
     # hooks back to the exported graph module.
diff --git a/tensorrt_llm/_torch/auto_deploy/transform/library/export_to_gm.py b/tensorrt_llm/_torch/auto_deploy/transform/library/export_to_gm.py
index 5cb152f3da..48f5648168 100644
--- a/tensorrt_llm/_torch/auto_deploy/transform/library/export_to_gm.py
+++ b/tensorrt_llm/_torch/auto_deploy/transform/library/export_to_gm.py
@@ -41,6 +41,13 @@ class ExportToGMConfig(TransformConfig):
         "Default is to apply all registered patches.",
         default=None,
     )
+    num_moe_experts_for_export: Optional[int] = Field(
+        description="If set, only this many MOE experts are traced during torch.export, "
+        "and the graph is expanded to include all experts afterwards. "
+        "This can dramatically speed up export for large MOE models (e.g. 256 experts). "
+        "Recommended value: 2.",
+        default=None,
+    )
 
 
 @contextmanager
@@ -183,6 +190,7 @@ class ExportToGM(BaseTransform):
             clone=self.config.clone_state_dict,
             strict=self.config.strict,
             patch_list=self.config.patch_list,
+            num_moe_experts_for_export=self.config.num_moe_experts_for_export,
         )
 
         # post process the sub graph module
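A minimal sketch of enabling the option through the transform config (assuming the remaining ExportToGMConfig fields keep their declared defaults; how the config reaches the ExportToGM transform in a full AutoDeploy pipeline is outside the scope of this patch):

from tensorrt_llm._torch.auto_deploy.transform.library.export_to_gm import ExportToGMConfig

# Illustrative only: construct the config with the newly added field.
cfg = ExportToGMConfig(num_moe_experts_for_export=2)
assert cfg.num_moe_experts_for_export == 2
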
+ """ + + def __init__( + self, + num_experts: int = 8, + hidden_dim: int = 16, + inter_dim: int = 32, + expert_attr_name: str = "experts", + ): + super().__init__() + self.num_experts = num_experts + self.hidden_dim = hidden_dim + self.expert_attr_name = expert_attr_name + expert_list = nn.ModuleList( + [ + nn.ModuleDict( + { + "gate_proj": nn.Linear(hidden_dim, inter_dim, bias=False), + "down_proj": nn.Linear(inter_dim, hidden_dim, bias=False), + "up_proj": nn.Linear(hidden_dim, inter_dim, bias=False), + } + ) + for _ in range(num_experts) + ] + ) + # Store under a configurable attribute name so tests can verify + # that the probe does NOT rely on the name being "experts". + setattr(self, expert_attr_name, expert_list) + self.gate = nn.Linear(hidden_dim, num_experts, bias=False) + + @property + def _expert_list(self) -> nn.ModuleList: + return getattr(self, self.expert_attr_name) + + def forward(self, x): + experts = self._expert_list + router_logits = self.gate(x) + routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float) + routing_weights, selected_experts = torch.topk(routing_weights, 2, dim=-1) + routing_weights = routing_weights.to(x.dtype) + return torch.ops.auto_deploy.torch_moe( + x, + selected_experts, + routing_weights, + w1_weight=[e["gate_proj"].weight for e in experts], + w2_weight=[e["down_proj"].weight for e in experts], + w3_weight=[e["up_proj"].weight for e in experts], + ) + + +@pytest.mark.parametrize("expert_attr_name", ["experts", "mlp_bank"]) +@pytest.mark.parametrize("num_experts", [4, 8]) +@pytest.mark.parametrize("num_moe_experts_for_export", [1, 2]) +@pytest.mark.parametrize("device", ["cuda"]) +def test_moe_export_with_reduced_experts( + num_experts, num_moe_experts_for_export, device, expert_attr_name +): + """Export with fewer experts then expand — result must match full export.""" + mod = SimpleMoEForExport(num_experts=num_experts, expert_attr_name=expert_attr_name).to(device) + mod.eval() + x = torch.randn(4, mod.hidden_dim, device=device) + + # Full export (baseline) + gm_full = torch_export_to_gm(mod, (x,)) + + # Export with reduced experts + gm_reduced = torch_export_to_gm( + mod, + (x,), + num_moe_experts_for_export=num_moe_experts_for_export, + ) + + # --- structural check: both graphs must have the right expert count --- + def _count_moe_experts(gm): + for node in gm.graph.nodes: + if node.op == "call_function" and "torch_moe" in str(node.target): + return len(node.args[3]) # w1_weight list length + return 0 + + assert _count_moe_experts(gm_full) == num_experts + assert _count_moe_experts(gm_reduced) == num_experts + + # --- numerical check: outputs must match --- + y_full = gm_full(x) + y_reduced = gm_reduced(x) + assert all_close(y_full, y_reduced), "Reduced-expert export output differs from full export" + + # --- state-dict round-trip: loading the original weights must work --- + sd = mod.state_dict() + gm_reduced.load_state_dict(sd, strict=False) + y_loaded = gm_reduced(x) + assert all_close(y_loaded, y_full), "Output after state-dict reload differs" + + # --- verify source model is unmodified --- + assert len(mod._expert_list) == num_experts, "Source model experts were not restored" + + +# --------------------------------------------------------------------------- +# Real-model MOE export: GLM4 MoE Lite +# --------------------------------------------------------------------------- + +try: + from tensorrt_llm._torch.auto_deploy.models.custom.modeling_glm4_moe_lite import ( + Glm4MoeLiteConfig, + Glm4MoeLiteForCausalLM, + ) + + 
+# ---------------------------------------------------------------------------
+# Real-model MOE export: GLM4 MoE Lite
+# ---------------------------------------------------------------------------
+
+try:
+    from tensorrt_llm._torch.auto_deploy.models.custom.modeling_glm4_moe_lite import (
+        Glm4MoeLiteConfig,
+        Glm4MoeLiteForCausalLM,
+    )
+
+    _HAS_GLM4 = True
+except ImportError:
+    _HAS_GLM4 = False
+
+
+if _HAS_GLM4:
+
+    def _make_tiny_glm4_config(n_routed_experts: int = 8) -> Glm4MoeLiteConfig:
+        """Create a minimal ``Glm4MoeLiteConfig`` suitable for unit tests."""
+        return Glm4MoeLiteConfig(
+            vocab_size=256,
+            hidden_size=64,
+            intermediate_size=128,
+            num_hidden_layers=2,
+            num_attention_heads=4,
+            num_key_value_heads=4,
+            q_lora_rank=32,
+            kv_lora_rank=32,
+            qk_nope_head_dim=12,
+            qk_rope_head_dim=4,
+            v_head_dim=16,
+            n_routed_experts=n_routed_experts,
+            n_shared_experts=1,
+            num_experts_per_tok=2,
+            moe_intermediate_size=64,
+            n_group=1,
+            topk_group=1,
+            routed_scaling_factor=1.0,
+            norm_topk_prob=True,
+            first_k_dense_replace=1,  # layer 0 = dense MLP, layer 1 = MoE
+            max_position_embeddings=128,
+            rope_scaling=None,
+            pad_token_id=0,
+        )
+
+    def _count_moe_experts_in_graph(gm: GraphModule) -> int:
+        """Return the number of experts in the first ``torch_moe`` call in *gm*."""
+        for node in gm.graph.nodes:
+            if node.op == "call_function" and "torch_moe" in str(node.target):
+                return len(node.args[3])  # w1_weight list length
+        return 0
+
+    @pytest.mark.skipif(not _HAS_GLM4, reason="GLM4 MoE Lite model not available on this branch")
+    @pytest.mark.skipif(
+        not torch.cuda.is_available(), reason="GLM4 MoE Lite requires CUDA (uses noaux_tc_op)"
+    )
+    @pytest.mark.parametrize("n_routed_experts", [8, 16])
+    @pytest.mark.parametrize("num_moe_experts_for_export", [2])
+    def test_glm4_moe_lite_export_with_reduced_experts(
+        n_routed_experts, num_moe_experts_for_export
+    ):
+        """Export a tiny ``Glm4MoeLiteForCausalLM`` with reduced experts and verify
+        that the expanded graph has the correct structure and accepts the original
+        state dict.
+        """
+        # GLM4 MoE Lite uses noaux_tc_op which is CUDA-only, so we must use CUDA device
+        device = "cuda"
+        config = _make_tiny_glm4_config(n_routed_experts=n_routed_experts)
+        model = Glm4MoeLiteForCausalLM(config).to(device)
+        model.eval()
+
+        input_ids = torch.randint(0, config.vocab_size, (1, 8), device=device)
+        position_ids = torch.arange(8, device=device).unsqueeze(0)
+        sample_kwargs = {"input_ids": input_ids, "position_ids": position_ids}
+
+        # --- full export (baseline) ---
+        gm_full = torch_export_to_gm(model, kwargs=sample_kwargs)
+
+        # --- export with reduced experts ---
+        gm_reduced = torch_export_to_gm(
+            model,
+            kwargs=sample_kwargs,
+            num_moe_experts_for_export=num_moe_experts_for_export,
+        )
+
+        # Structural: both graphs must expose all experts
+        assert _count_moe_experts_in_graph(gm_full) == n_routed_experts
+        assert _count_moe_experts_in_graph(gm_reduced) == n_routed_experts
+
+        # State-dict keys must match between full and reduced exports
+        full_keys = set(gm_full.state_dict().keys())
+        reduced_keys = set(gm_reduced.state_dict().keys())
+        assert full_keys == reduced_keys, (
+            f"State-dict key mismatch.\n"
+            f"  Only in full: {full_keys - reduced_keys}\n"
+            f"  Only in reduced: {reduced_keys - full_keys}"
+        )
+
+        # Load the original model weights into the reduced export graph
+        gm_reduced.load_state_dict(model.state_dict(), strict=False)
+
+        # Source model must be fully restored
+        for name, mod in model.named_modules():
+            if hasattr(mod, "experts") and isinstance(mod.experts, nn.ModuleList):
+                assert len(mod.experts) == n_routed_experts, (
+                    f"Expert list in '{name}' was not restored to {n_routed_experts}"
+                )
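Both new tests are CUDA-only (the simple MOE test parametrizes device over ["cuda"] and the GLM4 MoE Lite test skips without CUDA). They can be selected by name, for example: pytest tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/test_export.py -k moe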