[None][chore] Optimize MOE export by tracing with reduced experts and expanding graph (#11504)
Signed-off-by: Suyog Gupta <41447211+suyoggupta@users.noreply.github.com>
Parent: f164669c04
Commit: b4e9669d2c
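The optimization is opt-in through a new num_moe_experts_for_export argument on torch_export_to_gm (mirrored by an ExportToGMConfig field). A minimal usage sketch, reusing the SimpleMoEForExport toy model added to the tests below; the import path is an assumption and may differ on your branch::

    import torch

    # Assumed import location for the AutoDeploy export helper.
    from tensorrt_llm._torch.auto_deploy.export import torch_export_to_gm

    mod = SimpleMoEForExport(num_experts=8).to("cuda")  # toy MOE model defined in the tests below
    x = torch.randn(4, mod.hidden_dim, device="cuda")

    # Trace with only 2 experts per MOE layer; the exported graph is expanded back to all
    # 8 experts afterwards, and the source model's expert lists are restored.
    gm = torch_export_to_gm(mod, (x,), num_moe_experts_for_export=2)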
@@ -10,6 +10,7 @@ import torch
import torch.export as te
import torch.nn as nn
from torch import fx
from torch.utils._python_dispatch import TorchDispatchMode

from ..utils._graph import canonicalize_graph, lift_to_meta, load_buffers_and_params, tree_to
from ..utils.logger import ad_logger
@@ -25,6 +26,309 @@ except ImportError:
    torch_export_context = nullcontext


# =====================================================================
# MOE export optimization: reduce experts for faster tracing, then
# expand the graph back to include all experts after export.
# =====================================================================
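# In outline, the flow wired into torch_export_to_gm further below is:
#
#     reductions = _reduce_moe_experts(model, num_moe_experts_for_export, args, kwargs)
#     egm = <torch.export capture of the reduced model>
#     _restore_moe_experts(reductions)                   # give the source model all experts back
#     _expand_moe_experts_in_graph(egm, model, reductions)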


def _infer_target_pattern(target_0: str, target_1: str) -> Tuple[str, str]:
    """Infer ``(prefix, suffix)`` from two consecutive expert-weight targets.

    Compares two ``get_attr`` targets that differ only in the expert index and
    returns ``(prefix, suffix)`` such that ``target == prefix + str(idx) + suffix``.

    Example::

        >>> _infer_target_pattern('experts.0.gate.weight', 'experts.1.gate.weight')
        ('experts.', '.gate.weight')
    """
    parts_0 = target_0.split(".")
    parts_1 = target_1.split(".")
    if len(parts_0) != len(parts_1):
        raise ValueError(f"Target structure mismatch: {target_0} vs {target_1}")

    diff_positions = [i for i, (a, b) in enumerate(zip(parts_0, parts_1)) if a != b]
    if len(diff_positions) != 1:
        raise ValueError(
            f"Expected exactly one differing part, found {len(diff_positions)}: "
            f"{target_0} vs {target_1}"
        )

    idx = diff_positions[0]
    prefix = ".".join(parts_0[:idx]) + "." if idx > 0 else ""
    suffix = "." + ".".join(parts_0[idx + 1 :]) if idx < len(parts_0) - 1 else ""
    return prefix, suffix


def _infer_single_target_pattern(target: str, expert_prefix: str) -> Tuple[str, str]:
    """Infer ``(prefix, suffix)`` when only one expert target is available.

    Uses the known *expert_prefix* to locate the expert index position.

    Example::

        >>> _infer_single_target_pattern('layer.0.experts.0.w.weight', 'layer.0.experts')
        ('layer.0.experts.', '.w.weight')
    """
    full_prefix = expert_prefix + "."
    if not target.startswith(full_prefix):
        raise ValueError(f"Target '{target}' does not start with '{full_prefix}'")
    remainder = target[len(full_prefix) :]  # e.g. '0.w.weight'
    _idx_str, _, after_idx = remainder.partition(".")
    suffix = "." + after_idx if after_idx else ""
    return full_prefix, suffix


def _register_nested_parameter(gm: fx.GraphModule, dotted_name: str, param: nn.Parameter) -> None:
    """Register a parameter at a nested dotted path, creating intermediate modules as needed."""
    parts = dotted_name.split(".")
    current: nn.Module = gm
    for part in parts[:-1]:
        if hasattr(current, part):
            current = getattr(current, part)
        else:
            new_mod = nn.Module()
            current.add_module(part, new_mod)
            current = new_mod
    current.register_parameter(parts[-1], param)
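
# For example, registering "experts.7.gate_proj.weight" on a GraphModule that only holds
# experts 0 and 1 creates the intermediate "experts.7" and "experts.7.gate_proj" modules on
# the fly before attaching the weight (attribute names here are illustrative only).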


class _MoeExpertProbe(TorchDispatchMode):
    """Dispatch mode that records parameter tensor IDs flowing into ``torch_moe``-family ops.

    Used by :func:`_find_moe_module_lists` to discover which ``nn.ModuleList``
    instances provide expert weights without relying on attribute naming conventions.
    """

    # MOE custom ops whose list arguments represent per-expert weight tensors.
    _MOE_OP_NAMES = ("torch_moe", "torch_quant_fp8_moe", "torch_quant_nvfp4_moe")

    def __init__(self):
        super().__init__()
        self.captured_param_ids: set = set()
        self._moe_ops = self._collect_moe_ops()

    @classmethod
    def _collect_moe_ops(cls) -> set:
        ops: set = set()
        for name in cls._MOE_OP_NAMES:
            try:
                ops.add(getattr(torch.ops.auto_deploy, name).default)
            except AttributeError:
                pass
        return ops

    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        if func in self._moe_ops:
            for arg in list(args) + list(kwargs.values()):
                if isinstance(arg, (list, tuple)):
                    for item in arg:
                        if isinstance(item, torch.Tensor):
                            self.captured_param_ids.add(id(item))
        return func(*args, **kwargs)


def _find_moe_module_lists(
    model: nn.Module,
    args: Optional[Tuple[Any, ...]] = None,
    kwargs: Optional[Dict[str, Any]] = None,
) -> Dict[str, Tuple[nn.Module, str, nn.ModuleList]]:
    """Identify ``nn.ModuleList`` instances whose parameters feed into ``torch_moe`` ops.

    Runs a lightweight forward pass with :class:`_MoeExpertProbe` active to
    discover which ``nn.ModuleList`` children of the model contribute
    per-expert weight tensors to ``torch_moe``-family custom ops.

    Returns:
        Mapping of *module_list_path* → ``(parent_module, attr_name, module_list)``.
    """
    # Build reverse map: id(param) → (parent_module, attr_name, module_list, full_path)
    param_to_modlist: Dict[int, Tuple[nn.Module, str, nn.ModuleList, str]] = {}
    for name, module in model.named_modules():
        for attr_name, child in module.named_children():
            if isinstance(child, nn.ModuleList) and len(child) > 0:
                ml_path = f"{name}.{attr_name}" if name else attr_name
                for param in child.parameters():
                    param_to_modlist[id(param)] = (module, attr_name, child, ml_path)

    # Run a quick forward pass to see which params flow into MOE ops.
    probe = _MoeExpertProbe()
    with torch.inference_mode(), probe:
        model(*(args or ()), **(kwargs or {}))

    # Cross-reference captured tensor IDs with ModuleList parameters.
    result: Dict[str, Tuple[nn.Module, str, nn.ModuleList]] = {}
    for pid in probe.captured_param_ids:
        if pid in param_to_modlist:
            parent, attr_name, mod_list, path = param_to_modlist[pid]
            if path not in result:
                result[path] = (parent, attr_name, mod_list)

    return result
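
# For a single-MOE-layer model the result might look like, e.g.,
#     {"model.layers.1.mlp.experts": (<parent MoE module>, "experts", ModuleList(...))}
# (the path is illustrative; discovery relies only on the probe, not on attribute names).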


def _reduce_moe_experts(
    model: nn.Module,
    min_num_experts: int,
    args: Optional[Tuple[Any, ...]] = None,
    kwargs: Optional[Dict[str, Any]] = None,
) -> List[Dict[str, Any]]:
    """Reduce MOE expert ``nn.ModuleList``s for faster export tracing.

    Uses a probe forward pass to identify which ``nn.ModuleList`` instances
    feed into ``torch_moe``-family custom ops (see :func:`_find_moe_module_lists`),
    then truncates each to *min_num_experts* entries. The returned list of dicts
    carries the metadata needed by :func:`_restore_moe_experts` and
    :func:`_expand_moe_experts_in_graph`.
    """
    if min_num_experts < 1:
        raise ValueError(f"min_num_experts must be >= 1, got {min_num_experts}")

    moe_lists = _find_moe_module_lists(model, args, kwargs)

    reductions: List[Dict[str, Any]] = []
    for path, (parent, attr_name, mod_list) in moe_lists.items():
        orig_count = len(mod_list)
        if orig_count <= min_num_experts:
            continue

        reductions.append(
            {
                "module": parent,
                "attr_name": attr_name,
                "original_list": mod_list,
                "original_count": orig_count,
                "expert_prefix": path,
            }
        )
        setattr(parent, attr_name, nn.ModuleList(list(mod_list[:min_num_experts])))
        ad_logger.info(
            f"Reduced MOE experts in '{path}' from {orig_count} to "
            f"{min_num_experts} for faster export"
        )
    return reductions


def _restore_moe_experts(reductions: List[Dict[str, Any]]) -> None:
    """Restore MOE expert ``nn.ModuleList``s to their original state."""
    for info in reductions:
        setattr(info["module"], info["attr_name"], info["original_list"])


def _find_original_num_experts(target: str, reductions: List[Dict[str, Any]]) -> Optional[int]:
    """Return the original expert count for a ``get_attr`` *target*, or ``None``."""
    for info in reductions:
        if target.startswith(info["expert_prefix"] + "."):
            return info["original_count"]
    return None


def _find_expert_prefix(target: str, reductions: List[Dict[str, Any]]) -> Optional[str]:
    """Return the ``expert_prefix`` that matches *target*, or ``None``."""
    for info in reductions:
        if target.startswith(info["expert_prefix"] + "."):
            return info["expert_prefix"]
    return None


def _expand_moe_experts_in_graph(
    gm: fx.GraphModule,
    model: nn.Module,
    reductions: List[Dict[str, Any]],
) -> None:
    """Expand MOE expert weights in *gm* to match the full *model*.

    After exporting with a reduced number of experts this function:

    1. Finds every ``torch_moe``-family node whose weight-list arguments are
       shorter than the original expert count.
    2. Registers the missing expert parameters on *gm* (copied from the
       already-restored *model*).
    3. Creates the corresponding ``get_attr`` nodes and extends the weight
       lists in the call node so the graph is equivalent to a full export.
    """
    if not reductions:
        return

    # MOE ops whose arguments include per-expert weight lists (from index 3 onward)
    moe_ops = {
        torch.ops.auto_deploy.torch_moe,
        torch.ops.auto_deploy.torch_quant_fp8_moe,
        torch.ops.auto_deploy.torch_quant_nvfp4_moe,
    }

    graph = gm.graph
    num_expanded = 0

    for node in list(graph.nodes):
        if not is_op(node, moe_ops):
            continue

        # Collect indices of list-of-node arguments (expert weight/scale lists)
        list_arg_indices = [
            i
            for i in range(3, len(node.args))
            if isinstance(node.args[i], (list, tuple)) and len(node.args[i]) > 0
        ]
        if not list_arg_indices:
            continue

        first_list = node.args[list_arg_indices[0]]
        current_num = len(first_list)
        first_target = first_list[0].target
        original_num = _find_original_num_experts(first_target, reductions)

        if original_num is None or original_num <= current_num:
            continue

        ad_logger.debug(
            f"Expanding MOE node '{node.name}': {current_num} -> {original_num} experts"
        )

        # Insert new get_attr nodes at the very beginning of the graph
        first_graph_node = next(iter(graph.nodes))

        new_args = list(node.args)
        for li in list_arg_indices:
            weight_list = list(node.args[li])

            # Determine the naming pattern: prefix + <expert_idx> + suffix
            if len(weight_list) >= 2:
                prefix, suffix = _infer_target_pattern(weight_list[0].target, weight_list[1].target)
            else:
                ep = _find_expert_prefix(weight_list[0].target, reductions)
                assert ep is not None, (
                    f"Could not find expert prefix for target '{weight_list[0].target}'"
                )
                prefix, suffix = _infer_single_target_pattern(weight_list[0].target, ep)

            # Add the missing expert weights
            for expert_idx in range(current_num, original_num):
                new_target = f"{prefix}{expert_idx}{suffix}"

                # Copy the parameter from the restored model
                orig_param = model.get_parameter(new_target)
                _register_nested_parameter(gm, new_target, nn.Parameter(orig_param.data))

                # Create a get_attr node
                with graph.inserting_before(first_graph_node):
                    new_node = graph.get_attr(new_target)
                    new_node.meta["val"] = gm.get_parameter(new_target)

                weight_list.append(new_node)

            new_args[li] = weight_list

        node.args = tuple(new_args)
        num_expanded += 1

    if num_expanded:
        canonicalize_graph(gm)
        ad_logger.info(f"Expanded {num_expanded} MOE node(s) in the exported graph")


def _clean_up_device_info(gm: fx.GraphModule) -> None:
    """Correct device information in the graph."""
    devices = {t.device for _, t in gm.named_parameters()}
@@ -330,6 +634,7 @@ def torch_export_to_gm(
    strict: bool = False,
    patch_configs: Optional[Dict[str, Union[dict, Any]]] = None,
    patch_list: Optional[List[str]] = None,
    num_moe_experts_for_export: Optional[int] = None,
) -> fx.GraphModule:
    """torch's export with wrapping into GraphModule + useful additions to the resulting module.

@@ -341,6 +646,8 @@ def torch_export_to_gm(
    4. Retain load hooks for state_dict loading from the original module.
    5. Manage parameter aliasing in the model.
    6. Remove assertions from the graph.
    7. Optionally speed up export for MOE models by tracing with fewer experts
       and expanding the graph afterward.

    Args:
        model: The model to export
@@ -353,6 +660,10 @@ def torch_export_to_gm(
            will be applied with default settings.
        patch_list: Optional list of patch names to apply with default settings.
            Cannot be used together with patch_configs.
        num_moe_experts_for_export: If set, only this many experts are traced during
            ``torch.export`` (the graph is expanded to include all experts afterward).
            This can dramatically speed up export for large MOE models.
            Recommended value: 2.
    """

    def _capture_fn(model, args, kwargs):
@@ -361,11 +672,23 @@ def torch_export_to_gm(
        assert isinstance(egm, fx.GraphModule)
        return egm

    # Optionally reduce MOE experts for faster export tracing
    moe_reductions: List[Dict[str, Any]] = []
    if num_moe_experts_for_export is not None:
        moe_reductions = _reduce_moe_experts(model, num_moe_experts_for_export, args, kwargs)

    # run capture with export
    egm = run_forward_for_capture(
        model, _capture_fn, args, kwargs, clone, patch_list=patch_list, patch_configs=patch_configs
    )

    # Restore full expert lists on the source model and expand the graph to include
    # all expert weights. This must happen before the load-hook / deduplication
    # post-processing so that those steps see the complete set of parameters.
    if moe_reductions:
        _restore_moe_experts(moe_reductions)
        _expand_moe_experts_in_graph(egm, model, moe_reductions)

    # Export strips away all methods not traced during forward. The model could have
    # load hooks that contain logic for correct state_dict loading. We need to add those
    # hooks back to the exported graph module.
@@ -41,6 +41,13 @@ class ExportToGMConfig(TransformConfig):
        "Default is to apply all registered patches.",
        default=None,
    )
    num_moe_experts_for_export: Optional[int] = Field(
        description="If set, only this many MOE experts are traced during torch.export, "
        "and the graph is expanded to include all experts afterwards. "
        "This can dramatically speed up export for large MOE models (e.g. 256 experts). "
        "Recommended value: 2.",
        default=None,
    )
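    # The ExportToGM transform forwards this value to torch_export_to_gm; leaving it at the
    # default None keeps full-expert tracing, while setting e.g. num_moe_experts_for_export=2
    # in the transform config enables the reduced-expert export path.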
@contextmanager
@@ -183,6 +190,7 @@ class ExportToGM(BaseTransform):
            clone=self.config.clone_state_dict,
            strict=self.config.strict,
            patch_list=self.config.patch_list,
            num_moe_experts_for_export=self.config.num_moe_experts_for_export,
        )

        # post process the sub graph module
@@ -239,3 +239,223 @@ def test_deduplicate_during_export(model_cls: Type[nn.Module], device_export: st

    # Test loading fc1.weight into gm.fc1.weight. State dict does not contain fc3.weight
    check_parameter_loading("fc3.weight", "fc1.weight")


# ---------------------------------------------------------------------------
# MOE export with reduced experts
# ---------------------------------------------------------------------------

# Ensure the auto_deploy::torch_moe custom op is registered
import tensorrt_llm._torch.auto_deploy.custom_ops  # noqa: F401, E402


class SimpleMoEForExport(nn.Module):
    """A minimal MOE model using the ``auto_deploy::torch_moe`` custom op.

    The *expert_attr_name* parameter controls the attribute under which the
    expert ``nn.ModuleList`` is stored. This lets tests verify that the
    probe-based expert discovery does **not** rely on the name ``"experts"``.
    """

    def __init__(
        self,
        num_experts: int = 8,
        hidden_dim: int = 16,
        inter_dim: int = 32,
        expert_attr_name: str = "experts",
    ):
        super().__init__()
        self.num_experts = num_experts
        self.hidden_dim = hidden_dim
        self.expert_attr_name = expert_attr_name
        expert_list = nn.ModuleList(
            [
                nn.ModuleDict(
                    {
                        "gate_proj": nn.Linear(hidden_dim, inter_dim, bias=False),
                        "down_proj": nn.Linear(inter_dim, hidden_dim, bias=False),
                        "up_proj": nn.Linear(hidden_dim, inter_dim, bias=False),
                    }
                )
                for _ in range(num_experts)
            ]
        )
        # Store under a configurable attribute name so tests can verify
        # that the probe does NOT rely on the name being "experts".
        setattr(self, expert_attr_name, expert_list)
        self.gate = nn.Linear(hidden_dim, num_experts, bias=False)

    @property
    def _expert_list(self) -> nn.ModuleList:
        return getattr(self, self.expert_attr_name)

    def forward(self, x):
        experts = self._expert_list
        router_logits = self.gate(x)
        routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
        routing_weights, selected_experts = torch.topk(routing_weights, 2, dim=-1)
        routing_weights = routing_weights.to(x.dtype)
        return torch.ops.auto_deploy.torch_moe(
            x,
            selected_experts,
            routing_weights,
            w1_weight=[e["gate_proj"].weight for e in experts],
            w2_weight=[e["down_proj"].weight for e in experts],
            w3_weight=[e["up_proj"].weight for e in experts],
        )


@pytest.mark.parametrize("expert_attr_name", ["experts", "mlp_bank"])
@pytest.mark.parametrize("num_experts", [4, 8])
@pytest.mark.parametrize("num_moe_experts_for_export", [1, 2])
@pytest.mark.parametrize("device", ["cuda"])
def test_moe_export_with_reduced_experts(
    num_experts, num_moe_experts_for_export, device, expert_attr_name
):
    """Export with fewer experts, then expand: the result must match a full export."""
    mod = SimpleMoEForExport(num_experts=num_experts, expert_attr_name=expert_attr_name).to(device)
    mod.eval()
    x = torch.randn(4, mod.hidden_dim, device=device)

    # Full export (baseline)
    gm_full = torch_export_to_gm(mod, (x,))

    # Export with reduced experts
    gm_reduced = torch_export_to_gm(
        mod,
        (x,),
        num_moe_experts_for_export=num_moe_experts_for_export,
    )

    # --- structural check: both graphs must have the right expert count ---
    def _count_moe_experts(gm):
        for node in gm.graph.nodes:
            if node.op == "call_function" and "torch_moe" in str(node.target):
                return len(node.args[3])  # w1_weight list length
        return 0

    assert _count_moe_experts(gm_full) == num_experts
    assert _count_moe_experts(gm_reduced) == num_experts

    # --- numerical check: outputs must match ---
    y_full = gm_full(x)
    y_reduced = gm_reduced(x)
    assert all_close(y_full, y_reduced), "Reduced-expert export output differs from full export"

    # --- state-dict round-trip: loading the original weights must work ---
    sd = mod.state_dict()
    gm_reduced.load_state_dict(sd, strict=False)
    y_loaded = gm_reduced(x)
    assert all_close(y_loaded, y_full), "Output after state-dict reload differs"

    # --- verify source model is unmodified ---
    assert len(mod._expert_list) == num_experts, "Source model experts were not restored"


# ---------------------------------------------------------------------------
# Real-model MOE export: GLM4 MoE Lite
# ---------------------------------------------------------------------------

try:
    from tensorrt_llm._torch.auto_deploy.models.custom.modeling_glm4_moe_lite import (
        Glm4MoeLiteConfig,
        Glm4MoeLiteForCausalLM,
    )

    _HAS_GLM4 = True
except ImportError:
    _HAS_GLM4 = False


if _HAS_GLM4:

    def _make_tiny_glm4_config(n_routed_experts: int = 8) -> Glm4MoeLiteConfig:
        """Create a minimal ``Glm4MoeLiteConfig`` suitable for unit tests."""
        return Glm4MoeLiteConfig(
            vocab_size=256,
            hidden_size=64,
            intermediate_size=128,
            num_hidden_layers=2,
            num_attention_heads=4,
            num_key_value_heads=4,
            q_lora_rank=32,
            kv_lora_rank=32,
            qk_nope_head_dim=12,
            qk_rope_head_dim=4,
            v_head_dim=16,
            n_routed_experts=n_routed_experts,
            n_shared_experts=1,
            num_experts_per_tok=2,
            moe_intermediate_size=64,
            n_group=1,
            topk_group=1,
            routed_scaling_factor=1.0,
            norm_topk_prob=True,
            first_k_dense_replace=1,  # layer 0 = dense MLP, layer 1 = MoE
            max_position_embeddings=128,
            rope_scaling=None,
            pad_token_id=0,
        )

    def _count_moe_experts_in_graph(gm: GraphModule) -> int:
        """Return the number of experts in the first ``torch_moe`` call in *gm*."""
        for node in gm.graph.nodes:
            if node.op == "call_function" and "torch_moe" in str(node.target):
                return len(node.args[3])  # w1_weight list length
        return 0

    @pytest.mark.skipif(not _HAS_GLM4, reason="GLM4 MoE Lite model not available on this branch")
    @pytest.mark.skipif(
        not torch.cuda.is_available(), reason="GLM4 MoE Lite requires CUDA (uses noaux_tc_op)"
    )
    @pytest.mark.parametrize("n_routed_experts", [8, 16])
    @pytest.mark.parametrize("num_moe_experts_for_export", [2])
    def test_glm4_moe_lite_export_with_reduced_experts(
        n_routed_experts, num_moe_experts_for_export
    ):
        """Export a tiny ``Glm4MoeLiteForCausalLM`` with reduced experts and verify
        that the expanded graph has the correct structure and accepts the original
        state dict.
        """
        # GLM4 MoE Lite uses noaux_tc_op which is CUDA-only, so we must use CUDA device
        device = "cuda"
        config = _make_tiny_glm4_config(n_routed_experts=n_routed_experts)
        model = Glm4MoeLiteForCausalLM(config).to(device)
        model.eval()

        input_ids = torch.randint(0, config.vocab_size, (1, 8), device=device)
        position_ids = torch.arange(8, device=device).unsqueeze(0)
        sample_kwargs = {"input_ids": input_ids, "position_ids": position_ids}

        # --- full export (baseline) ---
        gm_full = torch_export_to_gm(model, kwargs=sample_kwargs)

        # --- export with reduced experts ---
        gm_reduced = torch_export_to_gm(
            model,
            kwargs=sample_kwargs,
            num_moe_experts_for_export=num_moe_experts_for_export,
        )

        # Structural: both graphs must expose all experts
        assert _count_moe_experts_in_graph(gm_full) == n_routed_experts
        assert _count_moe_experts_in_graph(gm_reduced) == n_routed_experts

        # State-dict keys must match between full and reduced exports
        full_keys = set(gm_full.state_dict().keys())
        reduced_keys = set(gm_reduced.state_dict().keys())
        assert full_keys == reduced_keys, (
            f"State-dict key mismatch.\n"
            f" Only in full: {full_keys - reduced_keys}\n"
            f" Only in reduced: {reduced_keys - full_keys}"
        )

        # Load the original model weights into the reduced export graph
        gm_reduced.load_state_dict(model.state_dict(), strict=False)

        # Source model must be fully restored
        for name, mod in model.named_modules():
            if hasattr(mod, "experts") and isinstance(mod.experts, nn.ModuleList):
                assert len(mod.experts) == n_routed_experts, (
                    f"Expert list in '{name}' was not restored to {n_routed_experts}"
                )