[#10345][perf] Enable multi-stream MOE for super. Also adds multi-stream MLA attn (#11520)

Signed-off-by: Suyog Gupta <41447211+suyoggupta@users.noreply.github.com>
Suyog Gupta 2026-02-15 15:07:56 -08:00 committed by GitHub
parent fcb7bea07f
commit f3d784c6f6
23 changed files with 3659 additions and 600 deletions

View File

@ -4,5 +4,19 @@ max_seq_len: 4096
enable_chunked_prefill: true
cuda_graph_batch_sizes: [1, 2, 4, 8, 16, 32, 64]
transforms:
match_swiglu_pattern:
enabled: true
match_nvfp4_swiglu_pattern:
enabled: true
fuse_nvfp4_moe:
allow_different_input_scales: true
fuse_nvfp4_swiglu:
enabled: true
fuse_swiglu:
enabled: true
multi_stream_moe:
stage: compile
enabled: true
multi_stream_mla_attn:
stage: compile
enabled: true

View File

@ -37,7 +37,7 @@ transforms:
"fc2_latent_proj": "gather"
multi_stream_moe:
stage: compile
enabled: false
enabled: true
gather_logits_before_lm_head:
# TODO: fix https://github.com/NVIDIA/TensorRT-LLM/issues/9878 to enable by default
enabled: true

View File

@ -50,6 +50,7 @@ transforms:
expected_layout: bsnd
match_rmsnorm_pattern:
stage: pattern_matcher
run_shape_prop: true
match_l2norm_pattern:
stage: pattern_matcher
############################################################################################
@ -75,6 +76,18 @@ transforms:
stage: pattern_matcher
quantize_nvfp4_from_graph:
stage: pattern_matcher
# SwiGLU pattern matching must run AFTER quantization transforms. For pre-quantized
# checkpoints (e.g., NVFP4), quantization converts torch_linear_simple ops to quantized
# ops first, and then match_nvfp4_swiglu_pattern captures the NVFP4 SwiGLU pattern.
# For non-quantized models, quantization transforms are no-ops, so match_swiglu_pattern
# proceeds normally.
match_swiglu_pattern:
stage: pattern_matcher
enabled: false
match_nvfp4_swiglu_pattern:
stage: pattern_matcher
requires_shape_prop: true
enabled: false
quantize_fp8_moe:
stage: pattern_matcher
quantize_nvfp4_moe:
@ -126,6 +139,8 @@ transforms:
fuse_nvfp4_linear:
stage: post_load_fusion
backend: trtllm
fuse_nvfp4_swiglu:
stage: post_load_fusion
fuse_moe:
stage: post_load_fusion
expect_mem_change: true
@ -149,6 +164,9 @@ transforms:
fuse_l2norm:
stage: post_load_fusion
backend: fla
fuse_swiglu:
stage: post_load_fusion
enabled: false
fuse_add_rms_norm:
stage: post_load_fusion
enabled: true
@ -200,6 +218,9 @@ transforms:
multi_stream_moe:
stage: compile
enabled: false
multi_stream_mla_attn:
stage: compile
enabled: false
compile_model:
stage: compile
expect_mem_change: true

View File

@ -18,9 +18,11 @@
This module provides linear layer implementations:
- linear: Linear layer operations
- torch_router: MoE router operations
- swiglu: SwiGLU MLP custom operations
"""
__all__ = [
"linear",
"torch_router",
"swiglu",
]

View File

@ -0,0 +1,311 @@
# SPDX-FileCopyrightText: Copyright (c) 2022-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""SwiGLU MLP custom operations for graph transformation.
This module provides custom operators for SwiGLU MLP fusion:
- torch_swiglu_mlp: Intermediate representation after pattern matching
- fused_swiglu_mlp: Fused implementation with concatenated gate+up weights
"""
from typing import Optional
import torch
import torch.nn.functional as F
try:
from flashinfer.activation import silu_and_mul as _flashinfer_silu_and_mul
except ImportError:
_flashinfer_silu_and_mul = None
def _silu_and_mul(x: torch.Tensor) -> torch.Tensor:
"""SwiGLU activation: split x in half, apply silu to first half, multiply with second half.
Uses FlashInfer's fused kernel when available, falls back to manual implementation.
"""
if _flashinfer_silu_and_mul is not None:
return _flashinfer_silu_and_mul(x)
gate, up = x.chunk(2, dim=-1)
return F.silu(gate) * up
@torch.library.custom_op("auto_deploy::torch_swiglu_mlp", mutates_args=())
def torch_swiglu_mlp(
input: torch.Tensor,
gate_weight: torch.Tensor,
up_weight: torch.Tensor,
down_weight: torch.Tensor,
gate_bias: Optional[torch.Tensor],
up_bias: Optional[torch.Tensor],
down_bias: Optional[torch.Tensor],
) -> torch.Tensor:
"""Standardized SwiGLU MLP operation.
Computes: silu(x @ gate.T + gate_bias) * (x @ up.T + up_bias) @ down.T + down_bias
This is the intermediate representation used after pattern matching,
before weight fusion is applied.
Args:
input: Input tensor of shape [..., hidden_size].
gate_weight: Gate projection weight of shape [intermediate_size, hidden_size].
up_weight: Up projection weight of shape [intermediate_size, hidden_size].
down_weight: Down projection weight of shape [hidden_size, intermediate_size].
gate_bias: Optional gate projection bias of shape [intermediate_size].
up_bias: Optional up projection bias of shape [intermediate_size].
down_bias: Optional down projection bias of shape [hidden_size].
Returns:
Output tensor of shape [..., hidden_size].
"""
gate_out = F.linear(input, gate_weight, gate_bias)
up_out = F.linear(input, up_weight, up_bias)
hidden = F.silu(gate_out) * up_out
return F.linear(hidden, down_weight, down_bias)
@torch_swiglu_mlp.register_fake
def _(
input: torch.Tensor,
gate_weight: torch.Tensor,
up_weight: torch.Tensor,
down_weight: torch.Tensor,
gate_bias: Optional[torch.Tensor],
up_bias: Optional[torch.Tensor],
down_bias: Optional[torch.Tensor],
) -> torch.Tensor:
"""Fake implementation for tracing."""
# Output shape is [..., hidden_size] where hidden_size = down_weight.shape[0]
output_shape = list(input.shape[:-1]) + [down_weight.shape[0]]
return input.new_empty(output_shape, dtype=input.dtype)
@torch.library.custom_op("auto_deploy::fused_swiglu_mlp", mutates_args=())
def fused_swiglu_mlp(
input: torch.Tensor,
gate_up_weight: torch.Tensor,
down_weight: torch.Tensor,
gate_up_bias: Optional[torch.Tensor],
down_bias: Optional[torch.Tensor],
) -> torch.Tensor:
"""Fused SwiGLU MLP with concatenated gate+up weights.
Performs a single matmul for gate and up projections, then splits the result.
Computes: silu(gate_out) * up_out @ down.T + down_bias
where gate_out, up_out = split(x @ gate_up.T + gate_up_bias)
Args:
input: Input tensor of shape [..., hidden_size].
gate_up_weight: Concatenated gate+up weight of shape [2*intermediate_size, hidden_size].
down_weight: Down projection weight of shape [hidden_size, intermediate_size].
gate_up_bias: Optional concatenated gate+up bias of shape [2*intermediate_size].
down_bias: Optional down projection bias of shape [hidden_size].
Returns:
Output tensor of shape [..., hidden_size].
"""
# Single matmul for both gate and up projections
gate_up_out = F.linear(input, gate_up_weight, gate_up_bias)
# Apply SwiGLU activation: split, silu(gate) * up (uses FlashInfer when available)
hidden = _silu_and_mul(gate_up_out)
# Down projection
return F.linear(hidden, down_weight, down_bias)
@fused_swiglu_mlp.register_fake
def _(
input: torch.Tensor,
gate_up_weight: torch.Tensor,
down_weight: torch.Tensor,
gate_up_bias: Optional[torch.Tensor],
down_bias: Optional[torch.Tensor],
) -> torch.Tensor:
"""Fake implementation for tracing."""
# Output shape is [..., hidden_size] where hidden_size = down_weight.shape[0]
output_shape = list(input.shape[:-1]) + [down_weight.shape[0]]
return input.new_empty(output_shape, dtype=input.dtype)
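# A minimal standalone sketch (not part of this module) of why the gate+up
# fusion is numerically exact: one matmul with the concatenated weight,
# followed by a split, equals the two separate gate/up matmuls.
import torch
import torch.nn.functional as F

hidden_sz, inter_sz = 8, 16
x = torch.randn(2, hidden_sz)
gate_w, up_w = torch.randn(inter_sz, hidden_sz), torch.randn(inter_sz, hidden_sz)
down_w = torch.randn(hidden_sz, inter_sz)
ref = F.linear(F.silu(F.linear(x, gate_w)) * F.linear(x, up_w), down_w)
g, u = F.linear(x, torch.cat([gate_w, up_w], dim=0)).chunk(2, dim=-1)
assert torch.allclose(ref, F.linear(F.silu(g) * u, down_w), atol=1e-6)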
# ── NVFP4 quantized SwiGLU ops ──────────────────────────────────────────────
@torch.library.custom_op("auto_deploy::torch_nvfp4_swiglu_mlp", mutates_args=())
def torch_nvfp4_swiglu_mlp(
input: torch.Tensor,
gate_weight: torch.Tensor,
up_weight: torch.Tensor,
down_weight: torch.Tensor,
gate_input_scale: torch.Tensor,
gate_weight_scale: torch.Tensor,
gate_alpha: torch.Tensor,
up_input_scale: torch.Tensor,
up_weight_scale: torch.Tensor,
up_alpha: torch.Tensor,
down_input_scale: torch.Tensor,
down_weight_scale: torch.Tensor,
down_alpha: torch.Tensor,
) -> torch.Tensor:
"""NVFP4 quantized SwiGLU MLP operation (intermediate representation).
Computes: silu(nvfp4_linear(x, gate)) * nvfp4_linear(x, up) -> nvfp4_linear(down)
This is the intermediate representation used after pattern matching for NVFP4
quantized checkpoints, before gate+up weight fusion is applied.
Args:
input: Input tensor of shape [..., hidden_size].
gate_weight: FP4 packed gate weight [intermediate_size, hidden_size/2] uint8.
up_weight: FP4 packed up weight [intermediate_size, hidden_size/2] uint8.
down_weight: FP4 packed down weight [hidden_size, intermediate_size/2] uint8.
gate_input_scale: Input scale for gate projection.
gate_weight_scale: Per-block weight scale for gate projection.
gate_alpha: Alpha (combined scale) for gate projection.
up_input_scale: Input scale for up projection.
up_weight_scale: Per-block weight scale for up projection.
up_alpha: Alpha (combined scale) for up projection.
down_input_scale: Input scale for down projection.
down_weight_scale: Per-block weight scale for down projection.
down_alpha: Alpha (combined scale) for down projection.
Returns:
Output tensor of shape [..., hidden_size].
"""
gate_out = torch.ops.auto_deploy.torch_fake_quant_nvfp4_linear(
input,
gate_weight,
None,
input_scale=[gate_input_scale],
weight_scale=[gate_weight_scale, gate_alpha],
input_zp=[],
weight_zp=[],
)
up_out = torch.ops.auto_deploy.torch_fake_quant_nvfp4_linear(
input,
up_weight,
None,
input_scale=[up_input_scale],
weight_scale=[up_weight_scale, up_alpha],
input_zp=[],
weight_zp=[],
)
hidden = F.silu(gate_out) * up_out
return torch.ops.auto_deploy.torch_fake_quant_nvfp4_linear(
hidden,
down_weight,
None,
input_scale=[down_input_scale],
weight_scale=[down_weight_scale, down_alpha],
input_zp=[],
weight_zp=[],
)
@torch_nvfp4_swiglu_mlp.register_fake
def _(
input: torch.Tensor,
gate_weight: torch.Tensor,
up_weight: torch.Tensor,
down_weight: torch.Tensor,
gate_input_scale: torch.Tensor,
gate_weight_scale: torch.Tensor,
gate_alpha: torch.Tensor,
up_input_scale: torch.Tensor,
up_weight_scale: torch.Tensor,
up_alpha: torch.Tensor,
down_input_scale: torch.Tensor,
down_weight_scale: torch.Tensor,
down_alpha: torch.Tensor,
) -> torch.Tensor:
"""Fake implementation for tracing."""
# Output shape: [..., hidden_size] where hidden_size = down_weight.shape[0]
output_shape = list(input.shape[:-1]) + [down_weight.shape[0]]
return input.new_empty(output_shape, dtype=input.dtype)
@torch.library.custom_op("auto_deploy::fused_nvfp4_swiglu_mlp", mutates_args=())
def fused_nvfp4_swiglu_mlp(
input: torch.Tensor,
gate_up_weight: torch.Tensor,
down_weight: torch.Tensor,
gate_up_input_scale: torch.Tensor,
gate_up_weight_scale: torch.Tensor,
gate_up_alpha: torch.Tensor,
down_input_scale: torch.Tensor,
down_weight_scale: torch.Tensor,
down_alpha: torch.Tensor,
) -> torch.Tensor:
"""Fused NVFP4 SwiGLU MLP with concatenated gate+up weights.
Performs a single NVFP4 matmul for gate and up projections, then splits,
applies SwiGLU activation, and does the down NVFP4 matmul.
Args:
input: Input tensor of shape [..., hidden_size].
gate_up_weight: Concatenated FP4 packed gate+up weight
[2*intermediate_size, hidden_size/2] uint8.
down_weight: FP4 packed down weight [hidden_size, intermediate_size/2] uint8.
gate_up_input_scale: Shared input scale for gate+up projection.
gate_up_weight_scale: Concatenated per-block weight scale for gate+up.
gate_up_alpha: Shared alpha for gate+up projection.
down_input_scale: Input scale for down projection.
down_weight_scale: Per-block weight scale for down projection.
down_alpha: Alpha for down projection.
Returns:
Output tensor of shape [..., hidden_size].
"""
# Single NVFP4 linear for both gate and up projections
gate_up_out = torch.ops.auto_deploy.torch_quant_nvfp4_linear(
input,
gate_up_weight,
bias=None,
input_scale=gate_up_input_scale,
weight_scale=gate_up_weight_scale,
alpha=gate_up_alpha,
)
# Apply SwiGLU activation: split, silu(gate) * up (uses FlashInfer when available)
hidden = _silu_and_mul(gate_up_out)
# Down projection
return torch.ops.auto_deploy.torch_quant_nvfp4_linear(
hidden,
down_weight,
bias=None,
input_scale=down_input_scale,
weight_scale=down_weight_scale,
alpha=down_alpha,
)
@fused_nvfp4_swiglu_mlp.register_fake
def _(
input: torch.Tensor,
gate_up_weight: torch.Tensor,
down_weight: torch.Tensor,
gate_up_input_scale: torch.Tensor,
gate_up_weight_scale: torch.Tensor,
gate_up_alpha: torch.Tensor,
down_input_scale: torch.Tensor,
down_weight_scale: torch.Tensor,
down_alpha: torch.Tensor,
) -> torch.Tensor:
"""Fake implementation for tracing."""
# Output shape: [..., hidden_size] where hidden_size = down_weight.shape[0]
output_shape = list(input.shape[:-1]) + [down_weight.shape[0]]
return input.new_empty(output_shape, dtype=input.dtype)
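# A quick sketch (assumes this module has been imported so the
# auto_deploy::fused_swiglu_mlp op is registered): the register_fake
# implementations above are what drive shape propagation, e.g. under
# FakeTensorMode.
import torch
from torch._subclasses.fake_tensor import FakeTensorMode

with FakeTensorMode():
    x = torch.empty(2, 64, dtype=torch.float16)
    gate_up_w = torch.empty(256, 64, dtype=torch.float16)  # [2*inter, hidden]
    down_w = torch.empty(64, 128, dtype=torch.float16)     # [hidden, inter]
    out = torch.ops.auto_deploy.fused_swiglu_mlp(x, gate_up_w, down_w, None, None)
    assert out.shape == (2, 64)  # hidden_size comes from down_w.shape[0]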

View File

@ -171,7 +171,7 @@ def _find_moe_module_lists(
def _reduce_moe_experts(
model: nn.Module,
min_num_experts: int,
num_moe_experts_for_export: int,
args: Optional[Tuple[Any, ...]] = None,
kwargs: Optional[Dict[str, Any]] = None,
) -> List[Dict[str, Any]]:
@ -179,19 +179,21 @@ def _reduce_moe_experts(
Uses a probe forward pass to identify which ``nn.ModuleList`` instances
feed into ``torch_moe``-family custom ops (see :func:`_find_moe_module_lists`),
then truncates each to *min_num_experts* entries. The returned list of dicts
carries the metadata needed by :func:`_restore_moe_experts` and
:func:`_expand_moe_experts_in_graph`.
then truncates each to *num_moe_experts_for_export* entries. The returned
list of dicts carries the metadata needed by :func:`_restore_moe_experts`
and :func:`_expand_moe_experts_in_graph`.
"""
if min_num_experts < 1:
raise ValueError(f"min_num_experts must be >= 1, got {min_num_experts}")
if num_moe_experts_for_export < 1:
raise ValueError(
f"num_moe_experts_for_export must be >= 1, got {num_moe_experts_for_export}"
)
moe_lists = _find_moe_module_lists(model, args, kwargs)
reductions: List[Dict[str, Any]] = []
for path, (parent, attr_name, mod_list) in moe_lists.items():
orig_count = len(mod_list)
if orig_count <= min_num_experts:
if orig_count <= num_moe_experts_for_export:
continue
reductions.append(
@ -203,10 +205,10 @@ def _reduce_moe_experts(
"expert_prefix": path,
}
)
setattr(parent, attr_name, nn.ModuleList(list(mod_list[:min_num_experts])))
setattr(parent, attr_name, nn.ModuleList(list(mod_list[:num_moe_experts_for_export])))
ad_logger.info(
f"Reduced MOE experts in '{path}' from {orig_count} to "
f"{min_num_experts} for faster export"
f"{num_moe_experts_for_export} for faster export"
)
return reductions
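# A toy sketch (simplified; not this module's API) of the reduce-for-export
# idea above: truncate each per-expert ModuleList before export and remember
# the original count so the exported graph can be expanded back afterwards.
import torch.nn as nn

experts = nn.ModuleList(nn.Linear(8, 8) for _ in range(64))
num_moe_experts_for_export = 2
reduction = {"orig_count": len(experts), "kept": num_moe_experts_for_export}
experts = nn.ModuleList(list(experts[:num_moe_experts_for_export]))
assert len(experts) == reduction["kept"] and reduction["orig_count"] == 64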
@ -252,7 +254,12 @@ def _expand_moe_experts_in_graph(
if not reductions:
return
# MOE ops whose arguments include per-expert weight lists (from index 3 onward)
# MOE ops whose arguments include per-expert weight lists.
# All these ops share the same first 3 positional args (x, selected_experts,
# routing_weights) which are plain Tensors, followed by one or more
# List[Tensor] args that hold per-expert weights/scales. We use the op
# schema to discover which arguments are Tensor[] rather than hard-coding
# the starting index.
moe_ops = {
torch.ops.auto_deploy.torch_moe,
torch.ops.auto_deploy.torch_quant_fp8_moe,
@ -266,11 +273,18 @@ def _expand_moe_experts_in_graph(
if not is_op(node, moe_ops):
continue
# Collect indices of list-of-node arguments (expert weight/scale lists)
# Collect indices of List[Tensor] arguments from the op schema; these
# are the per-expert weight / scale lists.
op = node.target
schema = op._schema if hasattr(op, "_schema") else next(iter(op._schemas.values()))
_tensor_list_types = ("Tensor[]", "List[Tensor]")
list_arg_indices = [
i
for i in range(3, len(node.args))
if isinstance(node.args[i], (list, tuple)) and len(node.args[i]) > 0
for i, arg_meta in enumerate(schema.arguments)
if any(t in str(arg_meta.type) for t in _tensor_list_types)
and i < len(node.args)
and isinstance(node.args[i], (list, tuple))
and len(node.args[i]) > 0
]
if not list_arg_indices:
continue
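# A minimal sketch (not part of this diff) of the schema-based discovery used
# above: Tensor[] arguments can be read off any op's schema instead of
# hard-coding positions; aten.cat is used here purely as an illustration.
import torch

cat_schema = torch.ops.aten.cat.default._schema
list_types = ("Tensor[]", "List[Tensor]")
tensor_list_idxs = [
    i
    for i, arg in enumerate(cat_schema.arguments)
    if any(t in str(arg.type) for t in list_types)
]
print(tensor_list_idxs)  # [0] -- cat's first argument is its Tensor[] input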
@ -673,6 +687,7 @@ def torch_export_to_gm(
return egm
# Optionally reduce MOE experts for faster export tracing
# TODO (https://github.com/NVIDIA/TensorRT-LLM/issues/7547): Reuse the export patch system
moe_reductions: List[Dict[str, Any]] = []
if num_moe_experts_for_export is not None:
moe_reductions = _reduce_moe_experts(model, num_moe_experts_for_export, args, kwargs)

View File

@ -444,7 +444,7 @@ class Glm4MoeLiteMoE(nn.Module):
if self.shared_experts is not None:
final_hidden_states = final_hidden_states + self.shared_experts(identity)
return final_hidden_states.to(hidden_states.dtype)
return final_hidden_states
class Glm4MoeLiteAttention(nn.Module):

View File

@ -0,0 +1,618 @@
# SPDX-FileCopyrightText: Copyright (c) 2022-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Graph transforms for SwiGLU MLP fusion.
This module provides two-stage transformation for SwiGLU MLP:
1. MatchSwiGLUPattern: Detects SwiGLU patterns and replaces with torch_swiglu_mlp
2. FuseSwiGLU: Fuses gate+up weights into a single concatenated matmul
The SwiGLU pattern is: silu(x @ gate.T) * (x @ up.T) @ down.T
"""
from typing import Tuple, Type
import torch
from pydantic import Field
from torch.fx import GraphModule, Node
# Import the custom ops to ensure they are registered and for use in replacements
from ...custom_ops.linear.swiglu import torch_swiglu_mlp
from ...models.factory import ModelFactory
from ...shim.interface import CachedSequenceInterface
from ...utils._graph import (
del_attr_by_name,
delete_all_unused_submodules,
eliminate_dead_code,
get_attr_by_name,
)
from ...utils.node_utils import is_op
from ...utils.pattern_matcher import ADPatternMatcherPass, register_ad_pattern
from ..interface import (
BaseTransform,
SharedConfig,
TransformConfig,
TransformInfo,
TransformRegistry,
)
def _try_free_attr_node(gm: GraphModule, graph, attr_node: Node) -> None:
"""Erase a get_attr node and eagerly delete its module attribute if it has no users.
This is used to free unfused weight tensors as soon as they are no longer
referenced in the graph, avoiding a temporary memory spike that would occur
if cleanup were deferred until after all nodes are processed.
"""
if attr_node is not None and attr_node.op == "get_attr" and len(attr_node.users) == 0:
target = attr_node.target
graph.erase_node(attr_node)
del_attr_by_name(gm, target)
def _swiglu_pattern_no_bias(x, gate_weight, up_weight, down_weight):
"""Pattern for SwiGLU MLP without biases.
Matches: silu(linear(x, gate_weight, None)) * linear(x, up_weight, None) -> linear(down_weight, None)
"""
gate_out = torch.ops.auto_deploy.torch_linear_simple.default(x, gate_weight, None)
up_out = torch.ops.auto_deploy.torch_linear_simple.default(x, up_weight, None)
silu_out = torch.ops.aten.silu.default(gate_out)
mul_out = torch.ops.aten.mul.Tensor(silu_out, up_out)
down_out = torch.ops.auto_deploy.torch_linear_simple.default(mul_out, down_weight, None)
return down_out
def _swiglu_replacement_no_bias(x, gate_weight, up_weight, down_weight):
"""Replacement for SwiGLU pattern without biases."""
# Call the Python wrapper directly, not via torch.ops.auto_deploy
# This ensures proper FakeTensor mode handling during tracing
return torch_swiglu_mlp(x, gate_weight, up_weight, down_weight, None, None, None)
def _swiglu_pattern_with_bias(
x, gate_weight, up_weight, down_weight, gate_bias, up_bias, down_bias
):
"""Pattern for SwiGLU MLP with biases.
Matches: silu(linear(x, gate_weight, gate_bias)) * linear(x, up_weight, up_bias) -> linear(down_weight, down_bias)
"""
gate_out = torch.ops.auto_deploy.torch_linear_simple.default(x, gate_weight, gate_bias)
up_out = torch.ops.auto_deploy.torch_linear_simple.default(x, up_weight, up_bias)
silu_out = torch.ops.aten.silu.default(gate_out)
mul_out = torch.ops.aten.mul.Tensor(silu_out, up_out)
down_out = torch.ops.auto_deploy.torch_linear_simple.default(mul_out, down_weight, down_bias)
return down_out
def _swiglu_replacement_with_bias(
x, gate_weight, up_weight, down_weight, gate_bias, up_bias, down_bias
):
"""Replacement for SwiGLU pattern with biases."""
# Call the Python wrapper directly, not via torch.ops.auto_deploy
# This ensures proper FakeTensor mode handling during tracing
return torch_swiglu_mlp(x, gate_weight, up_weight, down_weight, gate_bias, up_bias, down_bias)
@TransformRegistry.register("match_swiglu_pattern")
class MatchSwiGLUPattern(BaseTransform):
"""Matches SwiGLU MLP patterns and replaces with torch_swiglu_mlp op.
This transform runs in the pattern_matcher stage and detects the following pattern:
silu(x @ gate.T) * (x @ up.T) @ down.T
And replaces it with a single torch_swiglu_mlp op that can be fused later.
Uses ADPatternMatcherPass for declarative pattern matching.
"""
config: TransformConfig
@classmethod
def get_config_class(cls) -> Type[TransformConfig]:
return TransformConfig
def _apply(
self,
gm: GraphModule,
cm: CachedSequenceInterface,
factory: ModelFactory,
shared_config: SharedConfig,
) -> Tuple[GraphModule, TransformInfo]:
patterns = ADPatternMatcherPass()
# Dummy shapes for tracing - shapes don't matter for matching
hidden, intermediate = 128, 256
# Pattern 1: SwiGLU without biases (most common case)
dummy_args_no_bias = [
torch.randn(2, hidden, device="meta", dtype=torch.float16), # x
torch.randn(intermediate, hidden, device="meta", dtype=torch.float16), # gate_weight
torch.randn(intermediate, hidden, device="meta", dtype=torch.float16), # up_weight
torch.randn(hidden, intermediate, device="meta", dtype=torch.float16), # down_weight
]
register_ad_pattern(
search_fn=_swiglu_pattern_no_bias,
replace_fn=_swiglu_replacement_no_bias,
patterns=patterns,
dummy_args=dummy_args_no_bias,
)
# Pattern 2: SwiGLU with biases
dummy_args_with_bias = [
torch.randn(2, hidden, device="meta", dtype=torch.float16), # x
torch.randn(intermediate, hidden, device="meta", dtype=torch.float16), # gate_weight
torch.randn(intermediate, hidden, device="meta", dtype=torch.float16), # up_weight
torch.randn(hidden, intermediate, device="meta", dtype=torch.float16), # down_weight
torch.randn(intermediate, device="meta", dtype=torch.float16), # gate_bias
torch.randn(intermediate, device="meta", dtype=torch.float16), # up_bias
torch.randn(hidden, device="meta", dtype=torch.float16), # down_bias
]
register_ad_pattern(
search_fn=_swiglu_pattern_with_bias,
replace_fn=_swiglu_replacement_with_bias,
patterns=patterns,
dummy_args=dummy_args_with_bias,
)
num_matches = patterns.apply(gm.graph)
if num_matches > 0:
gm.recompile()
info = TransformInfo(
skipped=False,
num_matches=num_matches,
is_clean=num_matches == 0,
has_valid_shapes=num_matches == 0,
)
return gm, info
class FuseSwiGLUConfig(TransformConfig):
"""Configuration for the SwiGLU fusion transform."""
enabled: bool = Field(
default=True,
description="Whether to enable SwiGLU fusion.",
)
@TransformRegistry.register("fuse_swiglu")
class FuseSwiGLU(BaseTransform):
"""Fuses torch_swiglu_mlp ops by concatenating gate and up weights.
This transform runs in the post_load_fusion stage and replaces torch_swiglu_mlp ops
with fused_swiglu_mlp ops that use a single concatenated gate+up weight matrix.
This reduces memory bandwidth by performing a single matmul instead of two
separate matmuls for gate and up projections.
"""
config: FuseSwiGLUConfig
@classmethod
def get_config_class(cls) -> Type[FuseSwiGLUConfig]:
return FuseSwiGLUConfig
def _apply(
self,
gm: GraphModule,
cm: CachedSequenceInterface,
factory: ModelFactory,
shared_config: SharedConfig,
) -> Tuple[GraphModule, TransformInfo]:
if not self.config.enabled:
return gm, TransformInfo(skipped=True, num_matches=0)
graph = gm.graph
cnt = 0
fused_weight_idx = 0
for node in list(graph.nodes):
if not is_op(node, torch.ops.auto_deploy.torch_swiglu_mlp.default):
continue
# Extract args: (input, gate_weight, up_weight, down_weight, gate_bias, up_bias, down_bias)
input_node = node.args[0]
gate_weight_node = node.args[1]
up_weight_node = node.args[2]
down_weight_node = node.args[3]
gate_bias_node = node.args[4] if len(node.args) > 4 else None
up_bias_node = node.args[5] if len(node.args) > 5 else None
down_bias_node = node.args[6] if len(node.args) > 6 else None
# Get the actual weight tensors
gate_weight = get_attr_by_name(gm, gate_weight_node.target)
up_weight = get_attr_by_name(gm, up_weight_node.target)
# Concatenate gate and up weights: [intermediate, hidden] -> [2*intermediate, hidden]
gate_up_weight = torch.cat([gate_weight, up_weight], dim=0)
# Create new attribute for the fused weight
fused_weight_name = f"fused_swiglu_gate_up_{fused_weight_idx}"
gm.register_buffer(fused_weight_name, gate_up_weight)
# Handle biases
gate_up_bias_node = None
fused_bias_name = None
if gate_bias_node is not None and gate_bias_node.op == "get_attr":
gate_bias = get_attr_by_name(gm, gate_bias_node.target)
up_bias = get_attr_by_name(gm, up_bias_node.target) if up_bias_node else None
if up_bias is not None:
gate_up_bias = torch.cat([gate_bias, up_bias], dim=0)
fused_bias_name = f"fused_swiglu_gate_up_bias_{fused_weight_idx}"
gm.register_buffer(fused_bias_name, gate_up_bias)
# Create get_attr node for the fused weight
with graph.inserting_before(node):
fused_weight_node = graph.get_attr(fused_weight_name)
if fused_bias_name is not None:
gate_up_bias_node = graph.get_attr(fused_bias_name)
# Create the fused_swiglu_mlp node
with graph.inserting_after(node):
fused_node: Node = graph.call_function(
torch.ops.auto_deploy.fused_swiglu_mlp.default,
args=(
input_node,
fused_weight_node,
down_weight_node,
gate_up_bias_node,
down_bias_node,
),
)
# Replace uses and erase old node
node.replace_all_uses_with(fused_node)
graph.erase_node(node)
# Eagerly free unfused weight/bias tensors that are no longer referenced
# to avoid a temporary memory spike from holding both fused and unfused
# copies simultaneously across all layers.
_try_free_attr_node(gm, graph, gate_weight_node)
_try_free_attr_node(gm, graph, up_weight_node)
_try_free_attr_node(gm, graph, gate_bias_node)
_try_free_attr_node(gm, graph, up_bias_node)
fused_weight_idx += 1
cnt += 1
if cnt > 0:
gm.recompile()
# Clean up any remaining dead code and unused submodules
eliminate_dead_code(gm)
delete_all_unused_submodules(gm)
info = TransformInfo(
skipped=False, num_matches=cnt, is_clean=cnt == 0, has_valid_shapes=cnt == 0
)
return gm, info
# ── NVFP4 quantized SwiGLU pattern matching and fusion ──────────────────────
from ...custom_ops.linear.swiglu import torch_nvfp4_swiglu_mlp # noqa: E402
def _nvfp4_swiglu_pattern_no_bias(
x,
gate_weight,
gate_input_scale,
gate_weight_scale,
gate_alpha,
up_weight,
up_input_scale,
up_weight_scale,
up_alpha,
down_weight,
down_input_scale,
down_weight_scale,
down_alpha,
):
"""Pattern for NVFP4 quantized SwiGLU MLP without biases.
Matches: silu(nvfp4_linear(x, gate)) * nvfp4_linear(x, up) -> nvfp4_linear(down)
"""
gate_out = torch.ops.auto_deploy.torch_fake_quant_nvfp4_linear.default(
x,
gate_weight,
None,
input_scale=[gate_input_scale],
weight_scale=[gate_weight_scale, gate_alpha],
input_zp=[],
weight_zp=[],
)
up_out = torch.ops.auto_deploy.torch_fake_quant_nvfp4_linear.default(
x,
up_weight,
None,
input_scale=[up_input_scale],
weight_scale=[up_weight_scale, up_alpha],
input_zp=[],
weight_zp=[],
)
silu_out = torch.ops.aten.silu.default(gate_out)
mul_out = torch.ops.aten.mul.Tensor(silu_out, up_out)
down_out = torch.ops.auto_deploy.torch_fake_quant_nvfp4_linear.default(
mul_out,
down_weight,
None,
input_scale=[down_input_scale],
weight_scale=[down_weight_scale, down_alpha],
input_zp=[],
weight_zp=[],
)
return down_out
def _nvfp4_swiglu_replacement_no_bias(
x,
gate_weight,
gate_input_scale,
gate_weight_scale,
gate_alpha,
up_weight,
up_input_scale,
up_weight_scale,
up_alpha,
down_weight,
down_input_scale,
down_weight_scale,
down_alpha,
):
"""Replacement for NVFP4 quantized SwiGLU pattern without biases."""
return torch_nvfp4_swiglu_mlp(
x,
gate_weight,
up_weight,
down_weight,
gate_input_scale,
gate_weight_scale,
gate_alpha,
up_input_scale,
up_weight_scale,
up_alpha,
down_input_scale,
down_weight_scale,
down_alpha,
)
@TransformRegistry.register("match_nvfp4_swiglu_pattern")
class MatchNVFP4SwiGLUPattern(BaseTransform):
"""Matches NVFP4 quantized SwiGLU MLP patterns and replaces with torch_nvfp4_swiglu_mlp.
This transform runs in the pattern_matcher stage AFTER quantize_nvfp4_linear_from_config
has converted torch_linear_simple ops to torch_fake_quant_nvfp4_linear ops.
It detects the following NVFP4 pattern:
silu(nvfp4_linear(x, gate)) * nvfp4_linear(x, up) -> nvfp4_linear(down)
And replaces it with a single torch_nvfp4_swiglu_mlp op that can be fused later.
"""
config: TransformConfig
@classmethod
def get_config_class(cls) -> Type[TransformConfig]:
return TransformConfig
def _apply(
self,
gm: GraphModule,
cm: CachedSequenceInterface,
factory: ModelFactory,
shared_config: SharedConfig,
) -> Tuple[GraphModule, TransformInfo]:
patterns = ADPatternMatcherPass()
# FP4 shape params for dummy args (shapes don't matter for matching)
N = 32 # intermediate_size
K_packed = 32 # hidden_size / 2 (FP4 packing)
K_eff = 2 * K_packed # actual hidden_size
N_down = K_eff # hidden_size (output of down proj)
K_down_packed = N // 2 # intermediate_size / 2 (down proj input)
# Weight scale sizes (per-block scale: N * K / 16)
gate_cutlass_len = N * (K_eff // 16)
down_cutlass_len = N_down * (N // 16)
x = torch.randn(2, K_eff, device="meta", dtype=torch.float16)
# Gate args
gate_w = torch.randint(0, 255, (N, K_packed), device="meta", dtype=torch.uint8)
gate_is = torch.tensor(0.01, device="meta", dtype=torch.float32)
gate_ws = torch.randint(0, 255, (gate_cutlass_len,), device="meta", dtype=torch.uint8)
gate_a = torch.tensor(1.2345, device="meta", dtype=torch.float32)
# Up args (same shapes as gate)
up_w = torch.randint(0, 255, (N, K_packed), device="meta", dtype=torch.uint8)
up_is = torch.tensor(0.02, device="meta", dtype=torch.float32)
up_ws = torch.randint(0, 255, (gate_cutlass_len,), device="meta", dtype=torch.uint8)
up_a = torch.tensor(2.3456, device="meta", dtype=torch.float32)
# Down args
down_w = torch.randint(0, 255, (N_down, K_down_packed), device="meta", dtype=torch.uint8)
down_is = torch.tensor(0.03, device="meta", dtype=torch.float32)
down_ws = torch.randint(0, 255, (down_cutlass_len,), device="meta", dtype=torch.uint8)
down_a = torch.tensor(3.4567, device="meta", dtype=torch.float32)
dummy_args = [
x,
gate_w,
gate_is,
gate_ws,
gate_a,
up_w,
up_is,
up_ws,
up_a,
down_w,
down_is,
down_ws,
down_a,
]
register_ad_pattern(
search_fn=_nvfp4_swiglu_pattern_no_bias,
replace_fn=_nvfp4_swiglu_replacement_no_bias,
patterns=patterns,
dummy_args=dummy_args,
)
num_matches = patterns.apply(gm.graph)
if num_matches > 0:
gm.recompile()
info = TransformInfo(
skipped=False,
num_matches=num_matches,
is_clean=num_matches == 0,
has_valid_shapes=num_matches == 0,
)
return gm, info
@TransformRegistry.register("fuse_nvfp4_swiglu")
class FuseNVFP4SwiGLU(BaseTransform):
"""Fuses torch_nvfp4_swiglu_mlp ops by concatenating gate and up FP4 weights.
This transform runs in the post_load_fusion stage and replaces torch_nvfp4_swiglu_mlp
ops with fused_nvfp4_swiglu_mlp ops that use a single concatenated gate+up weight matrix.
FP4 weight fusion:
- gate+up packed weights are concatenated along dim=0
- gate+up per-block weight scales are concatenated along dim=0
- gate+up input_scale and alpha must match (shared input)
"""
config: TransformConfig
@classmethod
def get_config_class(cls) -> Type[TransformConfig]:
return TransformConfig
def _apply(
self,
gm: GraphModule,
cm: CachedSequenceInterface,
factory: ModelFactory,
shared_config: SharedConfig,
) -> Tuple[GraphModule, TransformInfo]:
graph = gm.graph
cnt = 0
fused_weight_idx = 0
for node in list(graph.nodes):
if not is_op(node, torch.ops.auto_deploy.torch_nvfp4_swiglu_mlp.default):
continue
# Extract args:
# (input, gate_weight, up_weight, down_weight,
# gate_input_scale, gate_weight_scale, gate_alpha,
# up_input_scale, up_weight_scale, up_alpha,
# down_input_scale, down_weight_scale, down_alpha)
input_node = node.args[0]
gate_weight_node = node.args[1]
up_weight_node = node.args[2]
down_weight_node = node.args[3]
gate_input_scale_node = node.args[4]
gate_weight_scale_node = node.args[5]
gate_alpha_node = node.args[6]
up_input_scale_node = node.args[7]
up_weight_scale_node = node.args[8]
up_alpha_node = node.args[9]
down_input_scale_node = node.args[10]
down_weight_scale_node = node.args[11]
down_alpha_node = node.args[12]
# Get the actual weight tensors
gate_weight = get_attr_by_name(gm, gate_weight_node.target)
up_weight = get_attr_by_name(gm, up_weight_node.target)
# Concatenate gate and up FP4 packed weights along dim=0
gate_up_weight = torch.cat([gate_weight, up_weight], dim=0)
# Get and concatenate weight scales
gate_weight_scale = get_attr_by_name(gm, gate_weight_scale_node.target)
up_weight_scale = get_attr_by_name(gm, up_weight_scale_node.target)
gate_up_weight_scale = torch.cat([gate_weight_scale, up_weight_scale], dim=0)
# Register fused buffers
prefix = f"fused_nvfp4_swiglu_{fused_weight_idx}"
gm.register_buffer(f"{prefix}_gate_up_weight", gate_up_weight)
gm.register_buffer(f"{prefix}_gate_up_weight_scale", gate_up_weight_scale)
# Create get_attr nodes for fused weights/scales
with graph.inserting_before(node):
fused_gate_up_weight_node = graph.get_attr(f"{prefix}_gate_up_weight")
fused_gate_up_weight_scale_node = graph.get_attr(f"{prefix}_gate_up_weight_scale")
# Create the fused_nvfp4_swiglu_mlp node
# Use gate's input_scale and alpha (same as up's since they share input)
with graph.inserting_after(node):
fused_node: Node = graph.call_function(
torch.ops.auto_deploy.fused_nvfp4_swiglu_mlp.default,
args=(
input_node,
fused_gate_up_weight_node,
down_weight_node,
gate_input_scale_node, # shared input_scale for gate+up
fused_gate_up_weight_scale_node,
gate_alpha_node, # shared alpha for gate+up
down_input_scale_node,
down_weight_scale_node,
down_alpha_node,
),
)
# Replace uses and erase old node
node.replace_all_uses_with(fused_node)
graph.erase_node(node)
# Eagerly free unfused weight/scale tensors that are no longer referenced
# to avoid a temporary memory spike from holding both fused and unfused
# copies simultaneously across all layers.
_try_free_attr_node(gm, graph, gate_weight_node)
_try_free_attr_node(gm, graph, up_weight_node)
_try_free_attr_node(gm, graph, gate_weight_scale_node)
_try_free_attr_node(gm, graph, up_weight_scale_node)
_try_free_attr_node(gm, graph, up_input_scale_node)
_try_free_attr_node(gm, graph, up_alpha_node)
fused_weight_idx += 1
cnt += 1
if cnt > 0:
gm.recompile()
# Clean up any remaining dead code and unused submodules
eliminate_dead_code(gm)
delete_all_unused_submodules(gm)
info = TransformInfo(
skipped=False, num_matches=cnt, is_clean=cnt == 0, has_valid_shapes=cnt == 0
)
return gm, info
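# A shape-only sketch (illustrative sizes, random bytes) of the fusion above:
# packed FP4 gate/up weights and their per-block scales are concatenated along
# dim 0, while input_scale and alpha stay shared because both projections
# consume the same activation.
import torch

inter, hidden = 32, 64
gate_w = torch.randint(0, 255, (inter, hidden // 2), dtype=torch.uint8)  # FP4 packed
up_w = torch.randint(0, 255, (inter, hidden // 2), dtype=torch.uint8)
gate_ws = torch.randint(0, 255, (inter * hidden // 16,), dtype=torch.uint8)  # per-block scales
up_ws = torch.randint(0, 255, (inter * hidden // 16,), dtype=torch.uint8)
gate_up_w = torch.cat([gate_w, up_w], dim=0)
gate_up_ws = torch.cat([gate_ws, up_ws], dim=0)
assert gate_up_w.shape == (2 * inter, hidden // 2)
assert gate_up_ws.shape == (2 * inter * hidden // 16,)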

View File

@ -9,31 +9,43 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Transformation for fusing Add + Cast + RMSNorm."""
"""Transformation for fusing Add + (optional Cast) + RMSNorm via direct FX graph manipulation."""
from typing import Tuple
import operator
from typing import List, Optional, Tuple
import torch
from torch.fx import GraphModule
from torch.fx import GraphModule, Node
from ...custom_ops.normalization.flashinfer_fused_add_rms_norm import flashinfer_fused_add_rms_norm
from ...models.factory import ModelFactory
from ...shim.interface import CachedSequenceInterface
from ...utils.pattern_matcher import ADPatternMatcherPass, register_ad_pattern
from ...utils._graph import eliminate_dead_code
from ...utils.node_utils import is_op
from ..interface import BaseTransform, SharedConfig, TransformInfo, TransformRegistry
@TransformRegistry.register("fuse_add_rms_norm")
class FuseAddRMSNorm(BaseTransform):
"""Fuse (add + cast + RMSNorm) into one fused op.
"""Fuse (add + optional cast + RMSNorm) into one fused op.
Matches:
x = add(input, residual)
y = x.to(dtype)
z = flashinfer_rms_norm(y, weight, eps)
Uses direct FX graph manipulation instead of the inductor pattern matcher
to correctly handle patterns where intermediate nodes (add, rms_norm) have
multiple users in the graph.
Replaces with:
z, x = flashinfer_fused_add_rms_norm(input, residual, weight, eps)
Pattern 1 (without cast):
%add = aten.add(%x, %residual)
%norm = flashinfer_rms_norm(%add, %weight, eps)
Pattern 2 (with cast):
%add = aten.add(%x, %residual)
%cast = aten.to.dtype(%add, bfloat16)
%norm = flashinfer_rms_norm(%cast, %weight, eps)
Both are replaced with:
%fused = flashinfer_fused_add_rms_norm(%x, %residual, %weight, eps)
%norm_out = getitem(%fused, 0) # norm result (replaces %norm)
%add_out = getitem(%fused, 1) # add result (replaces %add)
"""
def _apply(
@ -43,42 +55,102 @@ class FuseAddRMSNorm(BaseTransform):
factory: ModelFactory,
shared_config: SharedConfig,
) -> Tuple[GraphModule, TransformInfo]:
patterns = ADPatternMatcherPass()
graph = gm.graph
num_matches = 0
# Dummy shapes for tracing
bsz, hidden = 2, 128
dummy_args = [
torch.randn(bsz, hidden, device="meta", dtype=torch.bfloat16), # x (bf16)
torch.randn(bsz, hidden, device="meta", dtype=torch.bfloat16), # residual (bf16)
torch.randn(hidden, device="meta", dtype=torch.bfloat16), # weight
1e-5, # eps
]
# --- Step 1: collect (add_node, optional cast_node, norm_node) triples ---
matches: List[Tuple[Node, Optional[Node], Node]] = []
op_ignore_types = {torch.ops.aten.to.dtype: (torch.dtype,)}
scalar_workaround = {"eps": 1e-5}
for node in graph.nodes:
# Match flashinfer_rms_norm (handles both overload packet and .default)
if not is_op(node, torch.ops.auto_deploy.flashinfer_rms_norm):
continue
def _fused_add_norm_pattern(x, residual, weight, eps):
added = torch.ops.aten.add.Tensor(x, residual)
cast = torch.ops.aten.to.dtype(added, torch.bfloat16)
# Note: we assume flashinfer_rms_norm is the target
norm = torch.ops.auto_deploy.flashinfer_rms_norm.default(cast, weight, eps)
return norm, added
input_to_norm = node.args[0]
cast_node: Optional[Node] = None
def _fused_add_norm_replacement(x, residual, weight, eps):
# Use the python wrapper directly, not via torch.ops.auto_deploy
return flashinfer_fused_add_rms_norm(x, residual, weight, eps)
# Check for an optional aten.to.dtype cast between add and norm
if isinstance(input_to_norm, Node) and is_op(input_to_norm, torch.ops.aten.to.dtype):
cast_node = input_to_norm
input_to_norm = cast_node.args[0]
# Register pattern
register_ad_pattern(
search_fn=_fused_add_norm_pattern,
replace_fn=_fused_add_norm_replacement,
patterns=patterns,
dummy_args=dummy_args,
op_ignore_types=op_ignore_types,
scalar_workaround=scalar_workaround,
)
# The (possibly unwrapped) input must be an aten.add.Tensor
if not isinstance(input_to_norm, Node) or not is_op(
input_to_norm, torch.ops.aten.add.Tensor
):
continue
num_matches = patterns.apply(gm.graph)
add_node = input_to_norm
matches.append((add_node, cast_node, node))
# --- Step 2: apply fusions ---
erased: set = set() # track erased node ids to skip stale matches
for add_node, cast_node, norm_node in matches:
# Safety: skip if a node in this match was already consumed
if id(add_node) in erased or id(norm_node) in erased:
continue
# Original operands
add_lhs = add_node.args[0] # e.g. previous residual
add_rhs = add_node.args[1] # e.g. attention/MoE output
weight = norm_node.args[1]
eps = norm_node.args[2]
# Insert the fused call right before the norm node. Using
# inserting_before(norm_node) ensures correct topological order:
# fused_node → norm_out → add_out all appear before norm_node.
with graph.inserting_before(norm_node):
# flashinfer_fused_add_rms_norm(x, residual, weight, eps):
# residual += x → residual becomes add result
# x = rms_norm(residual) → x becomes norm result
# returns (x, residual) = (norm_result, add_result)
fused_node = graph.call_function(
flashinfer_fused_add_rms_norm,
args=(add_rhs, add_lhs, weight, eps),
)
norm_out = graph.call_function(operator.getitem, args=(fused_node, 0))
add_out = graph.call_function(operator.getitem, args=(fused_node, 1))
# Rewire all consumers of the original norm → norm_out
norm_node.replace_all_uses_with(norm_out)
# Erase norm first so cast_node (if present) loses its only user
graph.erase_node(norm_node)
erased.add(id(norm_node))
# Erase cast_node *before* replacing add's uses, otherwise
# replace_all_uses_with would rewrite cast's input to add_out
# which sits after cast in the graph → topological violation.
if cast_node is not None:
if len(cast_node.users) == 0:
graph.erase_node(cast_node)
erased.add(id(cast_node))
else:
# Rare: cast has users besides norm. Redirect them to a
# new cast placed after add_out so the ordering is valid.
with graph.inserting_before(list(cast_node.users)[0]):
new_cast = graph.call_function(
cast_node.target,
args=(add_out, *cast_node.args[1:]),
kwargs=cast_node.kwargs,
)
cast_node.replace_all_uses_with(new_cast)
graph.erase_node(cast_node)
erased.add(id(cast_node))
# Rewire all consumers of the original add → add_out
# (includes the residual connection to the *next* layer's add)
add_node.replace_all_uses_with(add_out)
graph.erase_node(add_node)
erased.add(id(add_node))
num_matches += 1
# Clean up any remaining dead code
if num_matches > 0:
eliminate_dead_code(gm)
info = TransformInfo(
skipped=False,

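# A reference sketch (standard RMSNorm math; not this module's FlashInfer
# kernel) of the fused op's semantics, to make the getitem wiring above
# concrete: output 0 replaces the original norm node, output 1 the original add.
import torch

def ref_fused_add_rms_norm(x, residual, weight, eps):
    added = x + residual
    var = added.float().pow(2).mean(dim=-1, keepdim=True)
    normed = (added.float() * torch.rsqrt(var + eps)).to(x.dtype) * weight
    return normed, added  # (norm_result, add_result), matching getitem 0 and 1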
View File

@ -0,0 +1,215 @@
"""Transform for multi-stream execution of Q and KV projection chains in MLA attention.
In DeepSeek-style MLA (Multi-head Latent Attention), the input layernorm output
forks into two independent projection chains that merge at the RoPE + attention op:
- **Q chain** (heavier): q_a_proj -> rms_norm -> q_b_proj -> view -> split
- **KV chain** (lighter): kv_a_proj_with_mqa -> split -> rms_norm + view
The Q chain is ~9x heavier than the KV chain. This transform moves the KV
projection linear onto the auxiliary CUDA stream so it executes concurrently
with the Q chain on the main stream.
"""
from collections import deque
from typing import Callable, List, Tuple
import torch
from torch.fx import GraphModule, Node
from ...models.factory import ModelFactory
# Reuse CachedSequenceInterface for the _apply signature.
from ...shim.interface import CachedSequenceInterface
from ...utils._graph import create_derived_custom_op
from ...utils.multi_stream_utils import (
_make_aux_stream_impl,
cuda_stream_manager,
record_event_passthrough,
)
from ...utils.node_utils import is_op
from ..interface import BaseTransform, SharedConfig, TransformInfo, TransformRegistry
# ---------------------------------------------------------------------------
# Supported linear op targets. Extend this list to cover quantised variants.
# ---------------------------------------------------------------------------
_LINEAR_OPS: List[Callable] = [
torch.ops.auto_deploy.torch_linear_simple,
torch.ops.aten.linear,
]
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _is_linear(node: Node) -> bool:
"""Return ``True`` if *node* is a call to one of the supported linear ops."""
return is_op(node, _LINEAR_OPS)
def _has_downstream_linear(start: Node, max_depth: int = 3) -> bool:
"""BFS from *start* through its users and return ``True`` if a linear op is reachable.
The search only follows *user* edges (downstream in the data-flow graph)
and stops after *max_depth* hops. ``start`` itself is **not** checked.
"""
visited: set[Node] = {start}
queue: deque[Tuple[Node, int]] = deque()
for user in start.users:
queue.append((user, 1))
while queue:
node, depth = queue.popleft()
if node in visited:
continue
visited.add(node)
if _is_linear(node):
return True
if depth < max_depth:
for user in node.users:
queue.append((user, depth + 1))
return False
def _find_kv_proj_linears(gm: GraphModule) -> List[Tuple[Node, Node]]:
"""Find (fork_point, kv_linear) pairs suitable for aux-stream execution.
A *fork point* is a node that directly feeds two or more supported linear
ops. Among these linears the one that does **not** lead to another linear
within a small BFS depth is the KV projection candidate (the lighter
branch).
Returns a list of ``(fork_point, kv_linear_node)`` tuples.
"""
results: List[Tuple[Node, Node]] = []
for node in gm.graph.nodes:
# Collect direct linear users of this node.
linear_users = [u for u in node.users if _is_linear(u)]
if len(linear_users) < 2:
continue
# Separate into "has downstream linear" (Q-like) and "does not" (KV-like).
kv_candidates = [ln for ln in linear_users if not _has_downstream_linear(ln)]
q_candidates = [ln for ln in linear_users if _has_downstream_linear(ln)]
if not kv_candidates or not q_candidates:
continue
# Pick the KV candidate(s). In MLA there is exactly one per fork point.
for kv_linear in kv_candidates:
results.append((node, kv_linear))
return results
def _create_aux_op(base_op: Callable) -> Callable:
"""Create an ``_aux`` variant of a linear op that runs on the auxiliary CUDA stream.
Uses a custom ``make_fake`` that delegates to the base op's registered fake
so that output shapes are computed correctly (linear output shape != input shape).
"""
return create_derived_custom_op(
base_op,
"_aux",
_make_aux_stream_impl,
make_fake=lambda base: lambda *a, **kw: base(*a, **kw),
)
def _execute_kv_proj_in_aux_stream(gm: GraphModule) -> Tuple[GraphModule, int]:
"""Replace KV projection linears with aux-stream variants.
For each matched ``(fork_point, kv_linear)`` the rewriter:
1. Inserts ``record_event_passthrough(fork_point)`` so the main-stream
event is recorded *before* the Q-chain kernels are submitted.
2. Replaces the KV linear's target with its ``_aux`` variant and wires the
``record_event_passthrough`` output as the hidden-state input
(creating a true data dependency).
The remaining KV-chain ops (split, rms_norm, view) stay on the main
stream; they are lightweight and run after the aux wait that is built
into the derived op.
Aux-stream variants are created lazily only for base ops that actually
appear in the matched KV positions.
"""
pairs = _find_kv_proj_linears(gm)
if not pairs:
return gm, 0
graph = gm.graph
node_order = {n: i for i, n in enumerate(graph.nodes)}
# Create aux ops lazily for whatever linear op types are found.
ops_in_graph = {kv_linear.target for _, kv_linear in pairs}
op_dict = {op: _create_aux_op(op) for op in ops_in_graph}
num_replaced = 0
for fork_point, kv_linear in pairs:
# Find the Q-chain linear(s) so we can insert the event record
# *before* the earliest Q-chain op in graph order.
q_linears = [u for u in fork_point.users if _is_linear(u) and u is not kv_linear]
earliest_q = min(q_linears, key=lambda n: node_order.get(n, 0))
# Insert record_event_passthrough right before the first Q-chain
# linear so the event is recorded before Q kernels hit the GPU.
with graph.inserting_before(earliest_q):
rec_node = graph.call_function(
record_event_passthrough,
args=(fork_point,),
)
# Replace KV linear with its aux-stream variant. The hidden-state
# input (args[0]) is rewired to ``rec_node`` to create a data
# dependency that ensures the event is recorded first.
new_args = tuple(rec_node if arg is fork_point else arg for arg in kv_linear.args)
with graph.inserting_after(kv_linear):
new_node = graph.call_function(
op_dict[kv_linear.target], args=new_args, kwargs=kv_linear.kwargs
)
kv_linear.replace_all_uses_with(new_node)
graph.erase_node(kv_linear)
num_replaced += 1
return gm, num_replaced
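# A standalone sketch (raw CUDA streams/events rather than this module's
# derived-op machinery) of the overlap the rewrite above encodes in the graph:
# record an event at the fork point, run the light KV projection on an aux
# stream that waits on it, and sync before the results merge.
import torch
import torch.nn.functional as F

def sketch_q_kv_overlap(x, q_weight, kv_weight):
    main = torch.cuda.current_stream()
    aux = torch.cuda.Stream()
    fork_ready = torch.cuda.Event()
    fork_ready.record(main)            # fork point: x is ready on the main stream
    with torch.cuda.stream(aux):
        aux.wait_event(fork_ready)     # aux only waits for x, not the Q chain
        kv = F.linear(x, kv_weight)    # light KV projection on the aux stream
    q = F.linear(x, q_weight)          # heavy Q chain stays on the main stream
    main.wait_stream(aux)              # merge: main waits for the aux result
    return q, kv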
# ---------------------------------------------------------------------------
# Transform class
# ---------------------------------------------------------------------------
@TransformRegistry.register("multi_stream_mla_attn")
class MultiStreamMLAAttn(BaseTransform):
"""Multi-stream Q/KV projection parallelism for MLA attention blocks."""
def _apply(
self,
gm: GraphModule,
cm: CachedSequenceInterface,
factory: ModelFactory,
shared_config: SharedConfig,
) -> Tuple[GraphModule, TransformInfo]:
# Ensure aux stream and events are set up for the current device.
cuda_stream_manager.add_device(torch.cuda.current_device())
gm, num_matches = _execute_kv_proj_in_aux_stream(gm)
info = TransformInfo(
skipped=False,
num_matches=num_matches,
is_clean=num_matches == 0,
has_valid_shapes=num_matches == 0,
)
return gm, info

View File

@ -1,295 +1,195 @@
"""Transform for multi-stream execution of MoE layers that have shared experts and routed experts."""
from threading import RLock
from typing import Any, Callable, Dict, List, Tuple
from typing import Callable, List, Optional, Set, Tuple
import torch
from torch.fx import GraphModule
from tensorrt_llm._torch.utils import ActivationType
from torch.fx import GraphModule, Node
from ...models.factory import ModelFactory
from ...shim.interface import CachedSequenceInterface
from ...utils.logger import ad_logger
from ...utils.multi_stream_utils import (
begin_aux_stream_passthrough,
cuda_stream_manager,
end_aux_stream_passthrough,
wait_aux_stream_passthrough,
)
from ...utils.node_utils import is_op
from ..interface import BaseTransform, SharedConfig, TransformInfo, TransformRegistry
# Previously, CudaStreamManager and the custom ops that use the cuda streams and events were
# placed in custom_ops folder. However doing so resulted in CudaStreamManager
# being created only in the parent process, but we need each rank to have its own CudaStreamManager that
# manages the cuda streams and events for that rank. Placing the logic to instantiate
# CudaStreamManager and the custom ops that use the cuda streams and events at the transform level ensures that
# each rank has its own CudaStreamManager since each rank applies the transform independently.
class _Singleton(type):
_instances: Dict[type, Any] = {}
_lock = RLock()
def _find_merge_add(moe_node: Node) -> Optional[Node]:
"""Walk forward from a MoE op through users to find the ``aten.add.Tensor`` merge node.
def __call__(cls, *args: Any, **kwargs: Any) -> Any:
if cls not in cls._instances:
with cls._lock:
if cls not in cls._instances: # double-checked locking
cls._instances[cls] = super().__call__(*args, **kwargs)
return cls._instances[cls]
The merge ``add`` is the node where shared-expert output and routed-expert
output are combined. The search is a breadth-first traversal of the user
graph starting from the MoE node.
"""
visited: Set[Node] = set()
queue = list(moe_node.users.keys())
while queue:
n = queue.pop(0)
if n in visited:
continue
visited.add(n)
if is_op(n, torch.ops.aten.add.Tensor):
return n
queue.extend(n.users.keys())
return None
# A singleton that holds the pointers to the cuda streams and events.
# Each device has its own cuda streams and events.
class CudaStreamManager(metaclass=_Singleton):
AUX_STREAM_NAME = "aux"
MAIN_STREAM_NAME = "main"
devices: List[torch.device] = []
events: Dict[torch.device, Dict[str, Any]] = {}
streams: Dict[torch.device, Dict[str, Any]] = {}
def __init__(self) -> None:
# In case __init__ ever gets called twice, guard against re-init
if hasattr(self, "streams"):
return
self._lock = RLock()
self.add_device(torch.cuda.current_device())
def add_device(self, device: int) -> None:
if device not in self.devices:
self.devices.append(device)
with torch.cuda.device(device):
self.events[device] = {
self.AUX_STREAM_NAME: torch.cuda.Event(),
self.MAIN_STREAM_NAME: torch.cuda.Event(),
}
self.streams[device] = {
self.AUX_STREAM_NAME: torch.cuda.Stream(),
self.MAIN_STREAM_NAME: torch.cuda.default_stream(),
}
else:
ad_logger.warning(f"CudaStreamManager: Device {device} already added")
def get_stream(self, device: int, stream_name: str) -> torch.cuda.Stream:
return self.streams[device][stream_name]
def get_event(self, device: int, event_name: str) -> torch.cuda.Event:
return self.events[device][event_name]
def _get_ancestors(node: Node) -> Set[Node]:
"""Return the set of all nodes reachable by walking backwards from *node*."""
ancestors: Set[Node] = set()
queue = list(node.all_input_nodes)
while queue:
n = queue.pop()
if n in ancestors:
continue
ancestors.add(n)
queue.extend(n.all_input_nodes)
return ancestors
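# A toy sketch (symbolic_trace / operator.add stand-ins rather than the
# exported aten.add.Tensor matched by _find_merge_add above) of the branch
# classification: the `add` input that descends from the MoE node is the
# routed branch; the other input is the shared-expert branch.
import operator
import torch
from torch.fx import symbolic_trace

class ToyMoELayer(torch.nn.Module):
    def forward(self, x):
        routed = torch.relu(x)      # stand-in for the fused MoE op
        shared = torch.sigmoid(x)   # stand-in for the shared-expert MLP
        return routed + shared

toy_gm = symbolic_trace(ToyMoELayer())
add = next(n for n in toy_gm.graph.nodes if n.target is operator.add)
moe = next(n for n in toy_gm.graph.nodes if n.target is torch.relu)
routed_arg = next(a for a in add.args if a is moe or moe in _get_ancestors(a))
shared_arg = next(a for a in add.args if a is not routed_arg)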
# Every device will have a singleton instance of CudaStreamManager.
cuda_stream_manager = CudaStreamManager()
@torch.library.custom_op("auto_deploy::record_event", mutates_args=())
def record_event(device: int, stream_name: str) -> None:
event = cuda_stream_manager.get_event(device, stream_name)
event.record()
@torch.library.custom_op("auto_deploy::wait_event", mutates_args=())
def wait_event(device: int, stream_name: str) -> None:
event = cuda_stream_manager.get_event(device, stream_name)
event.wait()
# skip during compilation
@torch._dynamo.disable
def record_event_wrapper(
fn: Callable,
*args: Tuple[Any, ...],
**kwargs: Dict[str, Any],
) -> torch.Tensor:
device = kwargs.pop("device", torch.cuda.current_device())
output = fn(*args, **kwargs)
torch.ops.auto_deploy.record_event(device, cuda_stream_manager.MAIN_STREAM_NAME)
return output
@torch._dynamo.disable
def aux_stream_wrapper(
fn: Callable,
*args: Tuple[Any, ...],
**kwargs: Dict[str, Any],
) -> torch.Tensor:
stream_name = cuda_stream_manager.AUX_STREAM_NAME
device = kwargs.pop("device", torch.cuda.current_device())
with torch.cuda.stream(cuda_stream_manager.get_stream(device, stream_name)):
torch.ops.auto_deploy.wait_event(device, cuda_stream_manager.MAIN_STREAM_NAME)
output = fn(*args, **kwargs)
torch.ops.auto_deploy.record_event(device, cuda_stream_manager.AUX_STREAM_NAME)
torch.ops.auto_deploy.wait_event(device, cuda_stream_manager.AUX_STREAM_NAME)
return output
# trtllm bf16
@torch.library.custom_op("auto_deploy::trtllm_moe_fused_aux", mutates_args=())
def trtllm_moe_fused_aux(
x: torch.Tensor,
selected_experts: torch.Tensor,
routing_weights: torch.Tensor,
w3_w1_stacked_weight: torch.Tensor,
w2_stacked_weight: torch.Tensor,
is_gated_mlp: bool = True,
act_fn: int = int(ActivationType.Silu),
) -> torch.Tensor:
device = torch.cuda.current_device()
with torch.cuda.stream(
cuda_stream_manager.get_stream(device, cuda_stream_manager.AUX_STREAM_NAME)
):
torch.ops.auto_deploy.wait_event(device, cuda_stream_manager.MAIN_STREAM_NAME)
output = torch.ops.auto_deploy.trtllm_moe_fused(
x,
selected_experts,
routing_weights,
w3_w1_stacked_weight,
w2_stacked_weight,
is_gated_mlp,
act_fn,
)
torch.ops.auto_deploy.record_event(device, cuda_stream_manager.AUX_STREAM_NAME)
torch.ops.auto_deploy.wait_event(device, cuda_stream_manager.AUX_STREAM_NAME)
return output
@trtllm_moe_fused_aux.register_fake
def trtllm_moe_fused_aux_fake(
x: torch.Tensor,
selected_experts: torch.Tensor,
routing_weights: torch.Tensor,
w3_w1_stacked_weight: torch.Tensor,
w2_stacked_weight: torch.Tensor,
is_gated_mlp: bool = True,
act_fn: int = int(ActivationType.Silu),
) -> torch.Tensor:
return torch.empty_like(x)
# triton bf16
@torch.library.custom_op("auto_deploy::triton_moe_fused_aux", mutates_args=())
def triton_moe_fused_aux(
x: torch.Tensor,
selected_experts: torch.Tensor,
routing_weights: torch.Tensor,
w1_stacked_weight: torch.Tensor,
w2_stacked_weight: torch.Tensor,
) -> torch.Tensor:
device = torch.cuda.current_device()
with torch.cuda.stream(
cuda_stream_manager.get_stream(device, cuda_stream_manager.AUX_STREAM_NAME)
):
torch.ops.auto_deploy.wait_event(device, cuda_stream_manager.MAIN_STREAM_NAME)
output = torch.ops.auto_deploy.triton_moe_fused(
x,
selected_experts,
routing_weights,
w1_stacked_weight,
w2_stacked_weight,
)
torch.ops.auto_deploy.record_event(device, cuda_stream_manager.AUX_STREAM_NAME)
torch.ops.auto_deploy.wait_event(device, cuda_stream_manager.AUX_STREAM_NAME)
return output
@triton_moe_fused_aux.register_fake
def triton_moe_fused_aux_fake(
x: torch.Tensor,
selected_experts: torch.Tensor,
routing_weights: torch.Tensor,
w1_stacked_weight: torch.Tensor,
w2_stacked_weight: torch.Tensor,
) -> torch.Tensor:
return torch.empty_like(x)
@torch.library.custom_op("auto_deploy::trtllm_quant_fp8_moe_fused_aux", mutates_args=())
def trtllm_quant_fp8_moe_fused_aux(
x: torch.Tensor,
selected_experts: torch.Tensor,
routing_weights: torch.Tensor,
fc1_expert_weights: torch.Tensor,
fc2_expert_weights: torch.Tensor,
fc1_act_scale: torch.Tensor,
fc1_dequant_scale: torch.Tensor,
fc2_act_scale_reciprocal: torch.Tensor,
fc2_dequant_scale: torch.Tensor,
is_gated_mlp: bool = True,
act_fn: int = int(ActivationType.Silu),
) -> torch.Tensor:
device = torch.cuda.current_device()
with torch.cuda.stream(
cuda_stream_manager.get_stream(device, cuda_stream_manager.AUX_STREAM_NAME)
):
torch.ops.auto_deploy.wait_event(device, cuda_stream_manager.MAIN_STREAM_NAME)
output = torch.ops.auto_deploy.trtllm_quant_fp8_moe_fused(
x,
selected_experts,
routing_weights,
fc1_expert_weights,
fc2_expert_weights,
fc1_act_scale,
fc1_dequant_scale,
fc2_act_scale_reciprocal,
fc2_dequant_scale,
is_gated_mlp,
act_fn,
)
torch.ops.auto_deploy.record_event(device, cuda_stream_manager.AUX_STREAM_NAME)
torch.ops.auto_deploy.wait_event(device, cuda_stream_manager.AUX_STREAM_NAME)
return output
@trtllm_quant_fp8_moe_fused_aux.register_fake
def trtllm_quant_fp8_moe_fused_aux_fake(
x: torch.Tensor,
selected_experts: torch.Tensor,
routing_weights: torch.Tensor,
fc1_expert_weights: torch.Tensor,
fc2_expert_weights: torch.Tensor,
fc1_act_scale: torch.Tensor,
fc1_dequant_scale: torch.Tensor,
fc2_act_scale_reciprocal: torch.Tensor,
fc2_dequant_scale: torch.Tensor,
is_gated_mlp: bool = True,
act_fn: int = int(ActivationType.Silu),
) -> torch.Tensor:
return torch.empty_like(x)
def _execute_op_in_aux_stream(
gm: GraphModule, op_dict: Dict[Callable, Callable]
def _execute_shared_expert_in_aux_stream(
gm: GraphModule, moe_ops: List[Callable]
) -> Tuple[GraphModule, int]:
"""Move shared-expert computation to the auxiliary CUDA stream.
For each MoE fused op in the graph:
1. Walk forward to find the ``aten.add.Tensor`` that merges the
shared-expert output and the routed-expert output.
2. Identify which ``add`` input is the routed branch (descended from
the MoE node) and which is the shared-expert branch.
3. Trace the shared-expert branch backwards to collect all its
computation nodes and identify the fork point (the latest common
ancestor shared with the MoE / routing path).
4. Insert ``begin_aux_stream_passthrough`` before the first shared-expert
op to switch to the auxiliary CUDA stream.
5. Insert ``end_aux_stream_passthrough`` after the last shared-expert op
to switch back to the main stream.
6. Insert ``wait_aux_stream_passthrough`` on the routed-branch input
just before the ``add`` so the main stream waits for the auxiliary
stream to finish before merging outputs.
"""
graph = gm.graph
num_replaced = 0
# Collect targets first to avoid mutating while iterating
target_nodes = [n for n in graph.nodes if is_op(n, op_dict.keys())]
# Collect targets first to avoid mutating while iterating.
target_nodes = [n for n in graph.nodes if is_op(n, moe_ops)]
if not target_nodes:
return gm, 0
for n in target_nodes:
target_input_node = None
for input_node in n.all_input_nodes:
if input_node.target == torch.ops.aten.view.default:
target_input_node = input_node
break
# Look through dtype cast nodes (aten.to) to find the view node
if input_node.target == torch.ops.aten.to:
for nested_input in input_node.all_input_nodes:
if nested_input.target == torch.ops.aten.view.default:
target_input_node = nested_input
break
if target_input_node is not None:
break
assert target_input_node is not None, f"Target input node not found for node {n}"
with graph.inserting_before(target_input_node):
kwargs = target_input_node.kwargs.copy()
kwargs["device"] = torch.cuda.current_device()
new_node = graph.call_function(
record_event_wrapper,
args=(target_input_node.target, *target_input_node.args),
kwargs=kwargs,
)
target_input_node.replace_all_uses_with(new_node)
graph.erase_node(target_input_node)
with graph.inserting_after(n):
new_node = graph.call_function(op_dict[n.target], args=n.args, kwargs=n.kwargs)
n.replace_all_uses_with(new_node)
graph.erase_node(n)
node_order = {node: i for i, node in enumerate(graph.nodes)}
for moe_node in target_nodes:
# ---- Step 1: Find the merge ``add`` node. ----
add_node = _find_merge_add(moe_node)
if add_node is None:
ad_logger.warning(
f"No merge add found downstream of MoE node {moe_node.name}; "
"skipping multi-stream transform for this node."
)
continue
# ---- Step 2: Determine which ``add`` input is routed vs. shared. ----
arg0, arg1 = add_node.args[0], add_node.args[1]
arg0_ancestors = _get_ancestors(arg0)
if moe_node in arg0_ancestors or arg0 is moe_node:
routed_output, shared_output = arg0, arg1
else:
routed_output, shared_output = arg1, arg0
# ---- Step 3: Collect shared-expert nodes & find fork point. ----
moe_ancestors = _get_ancestors(moe_node)
moe_ancestors.add(moe_node)
shared_nodes: List[Node] = []
fork_point: Optional[Node] = None
visited: Set[Node] = set()
queue = [shared_output]
while queue:
n = queue.pop(0)
if n in visited:
continue
visited.add(n)
# Skip static weight / parameter nodes.
if n.op == "get_attr":
continue
if n in moe_ancestors:
# This node is on the MoE / routing path — candidate fork point.
if fork_point is None or node_order.get(n, 0) > node_order.get(fork_point, 0):
fork_point = n
continue
shared_nodes.append(n)
for inp in n.all_input_nodes:
queue.append(inp)
if not shared_nodes or fork_point is None:
ad_logger.warning(
f"Could not identify shared-expert subgraph for MoE node "
f"{moe_node.name}; skipping multi-stream transform for this node."
)
continue
# Order shared nodes by their position in the graph.
shared_nodes.sort(key=lambda n: node_order.get(n, 0))
first_shared = shared_nodes[0]
# Sanity check: the first shared op must directly consume the fork
# point so we can wire begin_aux_stream_passthrough into it.
if fork_point not in first_shared.all_input_nodes:
ad_logger.warning(
f"First shared-expert op ({first_shared.name}) does not directly "
f"consume fork point ({fork_point.name}); skipping."
)
continue
# ---- Step 4: Insert begin_aux before the first shared-expert op. ----
# NOTE: do NOT bake ``torch.cuda.current_device()`` into the graph —
# that would hard-code device 0 and break on other ranks in a
# multi-GPU setup. Omitting ``device`` lets the passthrough
# functions resolve the device at **runtime** (default ``-1``).
with graph.inserting_before(first_shared):
begin_aux_node = graph.call_function(
begin_aux_stream_passthrough,
args=(fork_point,),
)
# Create a data dependency: first_shared reads begin_aux output
# instead of fork_point.
first_shared.args = tuple(
begin_aux_node if arg is fork_point else arg for arg in first_shared.args
)
# ---- Step 5: Insert end_aux after the last shared-expert op. ----
with graph.inserting_after(shared_output):
end_aux_node = graph.call_function(
end_aux_stream_passthrough,
args=(shared_output,),
)
# Replace shared-expert input to ``add`` with end_aux output.
add_node.args = tuple(
end_aux_node if arg is shared_output else arg for arg in add_node.args
)
# ---- Step 6: Insert wait_aux before the ``add``. ----
with graph.inserting_before(add_node):
wait_aux_node = graph.call_function(
wait_aux_stream_passthrough,
args=(routed_output,),
)
add_node.args = tuple(
wait_aux_node if arg is routed_output else arg for arg in add_node.args
)
num_replaced += 1
return gm, num_replaced
@ -306,14 +206,16 @@ class MultiStreamMOE(BaseTransform):
factory: ModelFactory,
shared_config: SharedConfig,
) -> Tuple[GraphModule, TransformInfo]:
op_dict = {
torch.ops.auto_deploy.trtllm_moe_fused: torch.ops.auto_deploy.trtllm_moe_fused_aux,
torch.ops.auto_deploy.triton_moe_fused: torch.ops.auto_deploy.triton_moe_fused_aux,
torch.ops.auto_deploy.trtllm_quant_fp8_moe_fused: torch.ops.auto_deploy.trtllm_quant_fp8_moe_fused_aux,
}
base_ops = [
torch.ops.auto_deploy.trtllm_moe_fused,
torch.ops.auto_deploy.triton_moe_fused,
torch.ops.auto_deploy.trtllm_quant_fp8_moe_fused,
torch.ops.auto_deploy.trtllm_quant_nvfp4_moe_fused,
]
# Ensure that aux stream and events for the current device are added to the CudaStreamManager.
cuda_stream_manager.add_device(torch.cuda.current_device())
gm, num_matches = _execute_op_in_aux_stream(gm, op_dict)
gm, num_matches = _execute_shared_expert_in_aux_stream(gm, base_ops)
info = TransformInfo(
skipped=False,
@ -321,5 +223,4 @@ class MultiStreamMOE(BaseTransform):
is_clean=num_matches == 0,
has_valid_shapes=num_matches == 0,
)
return gm, info

View File

@ -20,6 +20,100 @@ from torch.utils._pytree import _LEAF_SPEC, TreeSpec
from .logger import ad_logger
from .node_utils import get_weight_tensor, is_op
# ---------------------------------------------------------------------------
# Dynamic custom-op derivation helpers
# ---------------------------------------------------------------------------
# These are used to create new custom ops that share the schema of an existing
# op but wrap it with additional logic (e.g. stream management). A single
# module-level dict of ``Library`` objects (keyed by namespace) is used so that
# the registrations persist and are visible via ``torch.ops.<namespace>.*``.
# ---------------------------------------------------------------------------
_derived_op_libs: Dict[str, torch.library.Library] = {}
_derived_op_registry: Dict[str, Callable] = {}
def _get_lib(namespace: str) -> torch.library.Library:
"""Return (and lazily create) a ``FRAGMENT`` Library for *namespace*."""
if namespace not in _derived_op_libs:
_derived_op_libs[namespace] = torch.library.Library(namespace, "FRAGMENT")
return _derived_op_libs[namespace]
def create_derived_custom_op(
base_op: Callable,
suffix: str,
make_impl: Callable[[Callable], Callable],
make_fake: Optional[Callable[[Callable], Callable]] = None,
) -> Callable:
"""Dynamically create a new custom op derived from an existing one.
The new op has the **same** schema (arguments, default values, and return
type) as *base_op* but with a different name (``<base_name><suffix>``) and a
custom implementation produced by *make_impl*.
Args:
base_op: The base custom op, either an ``OpOverloadPacket``
(e.g. ``torch.ops.auto_deploy.trtllm_moe_fused``) or an
``OpOverload`` (e.g. ``trtllm_moe_fused.default``).
suffix: Suffix appended to the base op name to form the new op name
(e.g. ``"_aux"``).
make_impl: A factory ``(base_overload) -> impl_fn`` that receives the
resolved base ``OpOverload`` and returns the *implementation*
function for the new op. ``impl_fn`` will be called with the
same positional/keyword arguments as *base_op*.
make_fake: Optional factory ``(base_overload) -> fake_fn`` that returns
the *Meta / fake-tensor* implementation. When ``None`` the
default fake implementation ``torch.empty_like(args[0])`` is used.
Returns:
The newly registered op as an ``OpOverloadPacket``
(e.g. ``torch.ops.auto_deploy.<name><suffix>``). Repeated calls with
the same *base_op* and *suffix* return the cached op.
"""
base_overload = base_op.default if hasattr(base_op, "default") else base_op
schema = base_overload._schema
# e.g. "auto_deploy::trtllm_moe_fused"
qualified_name = schema.name
namespace, base_name = qualified_name.split("::")
new_name = f"{base_name}{suffix}"
new_qualified = f"{namespace}::{new_name}"
# Return the cached op if it was already created.
if new_qualified in _derived_op_registry:
return _derived_op_registry[new_qualified]
# Build the schema string for the derived op. ``str(schema)`` produces a
# fully-qualified string such as
# auto_deploy::trtllm_moe_fused(Tensor x, …) -> Tensor
# We replace the qualified name with the bare new name (the Library already
# knows its namespace).
new_schema_str = str(schema).replace(qualified_name, new_name, 1)
lib = _get_lib(namespace)
lib.define(new_schema_str)
# Register the real implementation for all devices.
# We use "CompositeExplicitAutograd" so that we can provide a separate
# Meta / fake kernel for shape inference.
lib.impl(new_name, make_impl(base_overload), "CompositeExplicitAutograd")
# Register the Meta / fake implementation.
if make_fake is not None:
lib.impl(new_name, make_fake(base_overload), "Meta")
else:
def _default_fake(*args, **kwargs):
return torch.empty_like(args[0])
lib.impl(new_name, _default_fake, "Meta")
new_op = getattr(getattr(torch.ops, namespace), new_name)
_derived_op_registry[new_qualified] = new_op
return new_op
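# Usage sketch (illustrative only, not part of the original code): derive an "_aux"
# variant of a fused-MoE op that runs on the auxiliary stream. The pairing with
# ``_make_aux_stream_impl`` (defined in the multi-stream utilities) is an assumed example:
#
#   aux_moe = create_derived_custom_op(
#       torch.ops.auto_deploy.trtllm_moe_fused,
#       suffix="_aux",
#       make_impl=_make_aux_stream_impl,
#   )
#   # aux_moe shares the base op's schema:
#   # out = aux_moe(x, selected_experts, routing_weights, w3_w1_stacked_weight, w2_stacked_weight)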
_NoValType = type("_NoValType", (), {})

View File

@ -0,0 +1,242 @@
"""Shared CUDA multi-stream utilities for multi-stream transforms.
This module provides the core infrastructure for executing parts of an FX graph
on auxiliary CUDA streams. It is consumed by the multi-stream MoE and MLA
attention transforms in ``..transform.library``.
Key components:
- ``CudaStreamManager``: per-device singleton managing auxiliary streams/events.
- Custom ops ``record_event`` / ``wait_event``: graph-safe event primitives.
- Passthrough helpers that switch streams while preserving the data-flow edges
required by FX graph execution and CUDA graph capture.
- ``_make_aux_stream_impl``: factory for building an implementation that runs
a base op on the auxiliary CUDA stream.
"""
from threading import RLock
from typing import Any, Callable, Dict, List
import torch
from .logger import ad_logger
# ---------------------------------------------------------------------------
# Singleton metaclass
# ---------------------------------------------------------------------------
class _Singleton(type):
_instances: Dict[type, Any] = {}
_lock = RLock()
def __call__(cls, *args: Any, **kwargs: Any) -> Any:
if cls not in cls._instances:
with cls._lock:
if cls not in cls._instances: # double-checked locking
cls._instances[cls] = super().__call__(*args, **kwargs)
return cls._instances[cls]
# ---------------------------------------------------------------------------
# CudaStreamManager
# ---------------------------------------------------------------------------
# Previously, CudaStreamManager and the custom ops that use its CUDA streams and events lived in
# the custom_ops folder. Doing so meant CudaStreamManager was created only in the parent process,
# but each rank needs its own CudaStreamManager to manage the CUDA streams and events for that
# rank. Instantiating CudaStreamManager and these custom ops at the transform level ensures that
# each rank gets its own instance, since each rank applies the transform independently.
class CudaStreamManager(metaclass=_Singleton):
AUX_STREAM_NAME = "aux"
MAIN_STREAM_NAME = "main"
devices: List[int] = []
events: Dict[int, Dict[str, torch.cuda.Event]] = {}
streams: Dict[int, Dict[str, torch.cuda.Stream]] = {}
# Per-device save slot for the caller's stream. ``begin_aux_stream_passthrough``
# saves the real current stream here so that ``end_aux_stream_passthrough`` can
# restore it — this is critical during CUDA graph capture where the capture stream
# differs from ``torch.cuda.default_stream()``.
_caller_streams: Dict[int, Any] = {}
def __init__(self) -> None:
# In case __init__ ever gets called twice, guard against re-init
if hasattr(self, "streams"):
return
self._lock = RLock()
self.add_device(torch.cuda.current_device())
def add_device(self, device: int) -> None:
if device not in self.devices:
self.devices.append(device)
with torch.cuda.device(device):
self.events[device] = {
self.AUX_STREAM_NAME: torch.cuda.Event(),
self.MAIN_STREAM_NAME: torch.cuda.Event(),
}
self.streams[device] = {
self.AUX_STREAM_NAME: torch.cuda.Stream(),
self.MAIN_STREAM_NAME: torch.cuda.default_stream(),
}
else:
ad_logger.warning(f"CudaStreamManager: Device {device} already added")
def get_stream(self, device: int, stream_name: str) -> torch.cuda.Stream:
return self.streams[device][stream_name]
def get_event(self, device: int, event_name: str) -> torch.cuda.Event:
return self.events[device][event_name]
# Module-level singleton: each rank/process creates one instance, which manages per-device streams and events.
cuda_stream_manager = CudaStreamManager()
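# Usage sketch (illustrative, not part of the original code): a transform registers the
# current device, then looks up the aux stream and the named events when wiring in
# multi-stream execution:
#
#   dev = torch.cuda.current_device()
#   cuda_stream_manager.add_device(dev)
#   aux = cuda_stream_manager.get_stream(dev, CudaStreamManager.AUX_STREAM_NAME)
#   main_evt = cuda_stream_manager.get_event(dev, CudaStreamManager.MAIN_STREAM_NAME)
#   with torch.cuda.stream(aux):
#       ...  # work that should overlap with the main stream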
# ---------------------------------------------------------------------------
# Custom ops — graph-safe CUDA event primitives
# ---------------------------------------------------------------------------
@torch.library.custom_op("auto_deploy::record_event", mutates_args=())
def record_event(device: int, stream_name: str) -> None:
event = cuda_stream_manager.get_event(device, stream_name)
event.record()
@torch.library.custom_op("auto_deploy::wait_event", mutates_args=())
def wait_event(device: int, stream_name: str) -> None:
event = cuda_stream_manager.get_event(device, stream_name)
event.wait()
# ---------------------------------------------------------------------------
# Passthrough helpers
# ---------------------------------------------------------------------------
@torch._dynamo.disable
def record_event_passthrough(
x: torch.Tensor,
*,
device: int = -1,
) -> torch.Tensor:
"""Record a CUDA event on the main stream and return the input unchanged.
Inserted after the gating/routing computation to mark a synchronization
point. The aux stream waits for this event before starting the MoE
computation, enabling overlap between the shared expert (main stream)
and routed experts (aux stream).
"""
if device < 0:
device = torch.cuda.current_device()
torch.ops.auto_deploy.record_event(device, cuda_stream_manager.MAIN_STREAM_NAME)
return x
@torch._dynamo.disable
def begin_aux_stream_passthrough(
x: torch.Tensor,
*,
device: int = -1,
) -> torch.Tensor:
"""Record a CUDA event on the main stream, switch to aux, and wait for it.
After this function returns the thread-local current stream is the
auxiliary stream. All subsequent GPU ops dispatched by the FX graph
interpreter will be recorded on aux until ``end_aux_stream_passthrough``
switches back to main.
"""
if device < 0:
device = torch.cuda.current_device()
# Save the *actual* current stream so ``end_aux`` can restore it.
# During CUDA graph capture the current stream is the capture stream,
# which is NOT ``torch.cuda.default_stream()``.
caller_stream = torch.cuda.current_stream(device)
cuda_stream_manager._caller_streams[device] = caller_stream
# Record where the caller's stream has reached so aux knows when data is ready.
main_event = cuda_stream_manager.get_event(device, cuda_stream_manager.MAIN_STREAM_NAME)
main_event.record(caller_stream)
# Switch the thread-local current stream to aux.
aux_stream = cuda_stream_manager.get_stream(device, cuda_stream_manager.AUX_STREAM_NAME)
torch.cuda.set_stream(aux_stream)
# Make aux wait for the main-stream event before executing any work.
aux_stream.wait_event(main_event)
return x
@torch._dynamo.disable
def end_aux_stream_passthrough(
x: torch.Tensor,
*,
device: int = -1,
) -> torch.Tensor:
"""Record a CUDA event on the aux stream and switch back to the caller's stream.
This does **not** make the caller's stream wait for aux. The caller must
insert ``wait_aux_stream_passthrough`` at the point where both branches
need to be synchronised (typically right before the ``add`` that merges
shared-expert and routed-expert outputs).
"""
if device < 0:
device = torch.cuda.current_device()
# Record the aux-stream progress so the caller's stream can wait for it later.
aux_event = cuda_stream_manager.get_event(device, cuda_stream_manager.AUX_STREAM_NAME)
aux_event.record()
# Restore the caller's stream saved by ``begin_aux_stream_passthrough``.
# This is critical during CUDA graph capture where the capture stream
# differs from ``torch.cuda.default_stream()``.
caller_stream = cuda_stream_manager._caller_streams.pop(device, None)
if caller_stream is not None:
torch.cuda.set_stream(caller_stream)
else:
torch.cuda.set_stream(
cuda_stream_manager.get_stream(device, cuda_stream_manager.MAIN_STREAM_NAME)
)
return x
@torch._dynamo.disable
def wait_aux_stream_passthrough(
x: torch.Tensor,
*,
device: int = -1,
) -> torch.Tensor:
"""Make the current stream wait for the auxiliary stream's last recorded event.
This is a GPU-side wait (non-blocking on the CPU). Insert this right
before the ``add`` that merges shared-expert output (computed on aux)
with routed-expert output (computed on main).
Uses ``torch.cuda.current_stream()`` rather than the stored default stream
so that the correct stream is waited on during CUDA graph capture.
"""
if device < 0:
device = torch.cuda.current_device()
aux_event = cuda_stream_manager.get_event(device, cuda_stream_manager.AUX_STREAM_NAME)
torch.cuda.current_stream(device).wait_event(aux_event)
return x
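# End-to-end sketch (illustrative, not part of the original code) of how the three
# passthroughs are ordered around a shared-expert branch; shared_mlp, routed_moe,
# and x are hypothetical:
#
#   y_shared = shared_mlp(begin_aux_stream_passthrough(x))  # dispatched on the aux stream
#   y_shared = end_aux_stream_passthrough(y_shared)         # restore the caller's stream
#   y_routed = routed_moe(x)                                # overlaps on the main stream
#   out = wait_aux_stream_passthrough(y_routed) + y_shared  # main waits for aux, then merges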
# ---------------------------------------------------------------------------
# Aux-stream implementation factory
# ---------------------------------------------------------------------------
def _make_aux_stream_impl(base_overload: Callable) -> Callable:
"""Build an implementation that runs *base_overload* on the auxiliary CUDA stream."""
def _impl(*args, **kwargs):
device = torch.cuda.current_device()
with torch.cuda.stream(
cuda_stream_manager.get_stream(device, cuda_stream_manager.AUX_STREAM_NAME)
):
torch.ops.auto_deploy.wait_event(device, cuda_stream_manager.MAIN_STREAM_NAME)
output = base_overload(*args, **kwargs)
torch.ops.auto_deploy.record_event(device, cuda_stream_manager.AUX_STREAM_NAME)
torch.ops.auto_deploy.wait_event(device, cuda_stream_manager.AUX_STREAM_NAME)
return output
return _impl

View File

@ -343,6 +343,10 @@ class TestNemotronSuperV3(LlmapiAccuracyTestHarness):
"sharding_source": ['factory', 'heuristic'],
"sharding_dims": ['ep', 'bmm'],
},
"multi_stream_moe": {
"stage": "compile",
"enabled": True,
},
}
}
@ -455,7 +459,15 @@ class TestGLM4Flash(LlmapiAccuracyTestHarness):
"fuse_nvfp4_moe": {
"allow_different_input_scales": True,
},
},
"multi_stream_moe": {
"stage": "compile",
"enabled": True,
},
"multi_stream_mla_attn": {
"stage": "compile",
"enabled": True,
},
}
}
if enable_chunked_prefill:
config["enable_chunked_prefill"] = True

View File

@ -4,6 +4,9 @@ import pytest
import torch
from _custom_op_utils import torch_rope_reference
# Import after we've imported torch (to ensure custom ops are registered)
from tensorrt_llm._torch.auto_deploy.custom_ops.rope import triton_rope # noqa: F401
def _precompute_freqs_cis(
seq_len: int, head_dim: int, rope_theta: Optional[float] = None

View File

@ -1,134 +0,0 @@
from typing import Tuple
import torch
import torch.nn as nn
from torch.fx import GraphModule, Node
from tensorrt_llm._torch.auto_deploy.transform.library.multi_stream_moe import (
aux_stream_wrapper,
cuda_stream_manager,
record_event_wrapper,
)
from tensorrt_llm._torch.auto_deploy.utils._graph import canonicalize_graph
from tensorrt_llm._torch.auto_deploy.utils.node_utils import is_op
@torch.library.custom_op("auto_deploy::multi_stream_linear", mutates_args=())
def multi_stream_linear(
input: torch.Tensor, weight0: torch.Tensor, weight1: torch.Tensor
) -> torch.Tensor:
output = torch.ops.aten.linear(input, weight0)
output = torch.ops.aten.linear(output, weight1)
return output
@multi_stream_linear.register_fake
def multi_stream_linear_fake(input, weight0, weight1):
"""Fake implementation of multi_stream_linear."""
output = torch.ops.aten.linear(input, weight0)
return torch.ops.aten.linear(output, weight1)
def replace_multi_stream_linear_with_aux_stream_wrapper(gm: GraphModule) -> Tuple[GraphModule, int]:
"""Traverse ``gm`` and replace all ``auto_deploy::multi_stream_linear`` ops with ``aux_stream_wrapper``.
The replacement preserves the original args/kwargs of the node.
After rewriting, the graph is cleaned and recompiled.
Args:
gm: The FX graph module to transform.
aux_stream_wrapper: A callable to replace the custom op with.
Returns:
A tuple of (gm, num_replaced)
"""
graph = gm.graph
num_replaced = 0
# Collect targets first to avoid mutating while iterating
target_nodes: list[Node] = []
target_nodes = [n for n in graph.nodes if is_op(n, torch.ops.auto_deploy.multi_stream_linear)]
for n in target_nodes:
target_input_node = None
for input_node in n.all_input_nodes:
if len(input_node.users) > 1:
target_input_node = input_node
break
if target_input_node is None:
raise ValueError(f"Target input node not found for node {n}")
with graph.inserting_before(target_input_node):
kwargs = target_input_node.kwargs.copy()
kwargs["device"] = torch.cuda.current_device()
new_node = graph.call_function(
record_event_wrapper,
args=(target_input_node.target, *target_input_node.args),
kwargs=kwargs,
)
target_input_node.replace_all_uses_with(new_node)
graph.erase_node(target_input_node)
with graph.inserting_after(n):
new_node = graph.call_function(
aux_stream_wrapper, args=(n.target, *n.args), kwargs=n.kwargs
)
n.replace_all_uses_with(new_node)
graph.erase_node(n)
num_replaced += 1
if num_replaced:
canonicalize_graph(gm)
return gm, num_replaced
class ParallelTwoLinear(nn.Module):
def __init__(self, in_dim: int, out_dim: int):
super().__init__()
self.fc10 = nn.Linear(in_dim, in_dim)
self.fc11 = nn.Linear(in_dim, out_dim)
self.fc2 = nn.Linear(in_dim, out_dim)
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = torch.nn.functional.relu(x)
y0 = self.fc2(x)
y1 = torch.ops.auto_deploy.multi_stream_linear(x, self.fc10.weight, self.fc11.weight)
return y0 + y1
def test_multi_stream_linear():
in_dim, out_dim = 128, 256
cuda_stream_manager.add_device(torch.cuda.current_device())
model = (
nn.Sequential(ParallelTwoLinear(in_dim, out_dim), ParallelTwoLinear(out_dim, out_dim))
.eval()
.to("cuda")
)
# Example input used for export
example_input = torch.randn(4, in_dim).to("cuda")
# Export the graph
egm = torch.export.export(model, (example_input,))
gm = egm.module()
test_x = torch.randn(4, in_dim).to("cuda")
ref_output = model(test_x)
# pattern matching and replace
gm, num_replaced = replace_multi_stream_linear_with_aux_stream_wrapper(gm)
assert num_replaced == 2
y = gm(test_x)
assert torch.allclose(y, ref_output)
static_x = torch.randn(4, in_dim).to("cuda")
static_output = torch.randn(4, out_dim).to("cuda")
graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(graph):
static_output.copy_(gm(static_x))
static_x.copy_(test_x)
graph.replay()
assert torch.allclose(static_output, ref_output)

View File

@ -0,0 +1,230 @@
"""Tests for multi-stream Q/KV projection parallelism in MLA attention.
The test builds a minimal mock model that mirrors the MLA fork pattern:
a shared input feeds two parallel linear chains (one heavier "Q-like",
one lighter "KV-like") whose outputs are combined with an add.
The transform should:
1. Detect the fork point (shared input with 2+ linear users).
2. Identify the lighter KV-like linear (no downstream linear within
a few hops) vs. the heavier Q-like chain (has a downstream linear).
3. Move the KV linear onto the auxiliary CUDA stream.
4. Preserve numerical correctness.
5. Be compatible with CUDA graph capture & replay.
"""
import torch
import torch.nn as nn
from tensorrt_llm._torch.auto_deploy.transform.library.multi_stream_attn import (
_execute_kv_proj_in_aux_stream,
_find_kv_proj_linears,
)
from tensorrt_llm._torch.auto_deploy.utils.multi_stream_utils import cuda_stream_manager
# ---------------------------------------------------------------------------
# Helpers -- mock MLA-like module
# ---------------------------------------------------------------------------
class MockMLABlock(nn.Module):
"""Simplified MLA-like attention block with Q and KV projection chains.
Q chain (heavier): q_a_proj -> relu (stand-in for rms_norm) -> q_b_proj
KV chain (lighter): kv_a_proj
Merge: add(q_b_proj_output, kv_a_proj_output)
The layernorm at the output simulates the inter-layer distance in a real
transformer (output projection, residual add, layernorm) so that the
next layer's fork point is beyond the BFS max_depth from this layer's
KV linear.
"""
def __init__(self, hidden_dim: int, q_inner_dim: int, kv_out_dim: int):
super().__init__()
# Q chain: two linears with a non-linearity in between
self.q_a_proj = nn.Linear(hidden_dim, q_inner_dim, bias=False)
self.q_b_proj = nn.Linear(q_inner_dim, kv_out_dim, bias=False)
# KV chain: single linear
self.kv_a_proj = nn.Linear(hidden_dim, kv_out_dim, bias=False)
# Inter-layer distance (layernorm + relu simulate residual + norm)
self.layernorm = nn.LayerNorm(kv_out_dim)
def forward(self, x: torch.Tensor) -> torch.Tensor:
# Q chain: q_a_proj -> relu -> q_b_proj
q = self.q_a_proj(x)
q = torch.nn.functional.relu(q)
q = self.q_b_proj(q)
# KV chain: kv_a_proj
kv = self.kv_a_proj(x)
out = q + kv
# Inter-layer distance to push next layer's linears beyond BFS depth
return self.layernorm(torch.nn.functional.relu(out))
# ---------------------------------------------------------------------------
# Tests
# ---------------------------------------------------------------------------
def _build_gm(model, example_input):
"""Export *model* to an FX GraphModule."""
egm = torch.export.export(model, (example_input,))
return egm.module()
def test_pattern_matching_single_block():
"""The pattern matcher should find exactly one pair for a single MLA block."""
model = MockMLABlock(128, 64, 128).eval().to("cuda")
example_input = torch.randn(4, 128, device="cuda")
gm = _build_gm(model, example_input)
pairs = _find_kv_proj_linears(gm)
assert len(pairs) == 1, f"Expected 1 fork-point pair, got {len(pairs)}"
def test_pattern_matching_multi_block():
"""Multiple layers with sufficient inter-layer distance should all be matched."""
hidden_dim, q_inner_dim, kv_out_dim = 128, 64, 128
model = (
nn.Sequential(
MockMLABlock(hidden_dim, q_inner_dim, kv_out_dim),
MockMLABlock(kv_out_dim, q_inner_dim, kv_out_dim),
)
.eval()
.to("cuda")
)
example_input = torch.randn(4, hidden_dim, device="cuda")
gm = _build_gm(model, example_input)
pairs = _find_kv_proj_linears(gm)
assert len(pairs) == 2, f"Expected 2 fork-point pairs, got {len(pairs)}"
def test_numerical_correctness():
"""After the transform the GraphModule must produce the same output as the original model."""
hidden_dim, q_inner_dim, kv_out_dim = 128, 64, 128
cuda_stream_manager.add_device(torch.cuda.current_device())
model = MockMLABlock(hidden_dim, q_inner_dim, kv_out_dim).eval().to("cuda")
example_input = torch.randn(4, hidden_dim, device="cuda")
gm = _build_gm(model, example_input)
test_x = torch.randn(4, hidden_dim, device="cuda")
ref_output = model(test_x)
gm, num_replaced = _execute_kv_proj_in_aux_stream(gm)
assert num_replaced == 1, f"Expected 1 replacement, got {num_replaced}"
y = gm(test_x)
assert torch.allclose(y, ref_output, atol=1e-5), (
f"Output mismatch: max diff = {(y - ref_output).abs().max().item()}"
)
def test_numerical_correctness_multi_block():
"""Multi-block correctness test."""
hidden_dim, q_inner_dim, kv_out_dim = 128, 64, 128
cuda_stream_manager.add_device(torch.cuda.current_device())
model = (
nn.Sequential(
MockMLABlock(hidden_dim, q_inner_dim, kv_out_dim),
MockMLABlock(kv_out_dim, q_inner_dim, kv_out_dim),
)
.eval()
.to("cuda")
)
example_input = torch.randn(4, hidden_dim, device="cuda")
gm = _build_gm(model, example_input)
test_x = torch.randn(4, hidden_dim, device="cuda")
ref_output = model(test_x)
gm, num_replaced = _execute_kv_proj_in_aux_stream(gm)
assert num_replaced == 2, f"Expected 2 replacements, got {num_replaced}"
y = gm(test_x)
assert torch.allclose(y, ref_output, atol=1e-5), (
f"Output mismatch: max diff = {(y - ref_output).abs().max().item()}"
)
def test_cuda_graph_compatibility():
"""The transformed GraphModule must work under CUDA graph capture and replay."""
hidden_dim, q_inner_dim, kv_out_dim = 128, 64, 128
cuda_stream_manager.add_device(torch.cuda.current_device())
model = MockMLABlock(hidden_dim, q_inner_dim, kv_out_dim).eval().to("cuda")
example_input = torch.randn(4, hidden_dim, device="cuda")
gm = _build_gm(model, example_input)
test_x = torch.randn(4, hidden_dim, device="cuda")
ref_output = model(test_x)
gm, num_replaced = _execute_kv_proj_in_aux_stream(gm)
assert num_replaced == 1
# Allocate static buffers for CUDA graph capture.
static_x = torch.randn(4, hidden_dim, device="cuda")
static_output = torch.randn(4, kv_out_dim, device="cuda")
# Warm up (required before capture).
for _ in range(3):
static_output.copy_(gm(static_x))
graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(graph):
static_output.copy_(gm(static_x))
static_x.copy_(test_x)
graph.replay()
assert torch.allclose(static_output, ref_output, atol=1e-5), (
f"CUDA graph output mismatch: max diff = {(static_output - ref_output).abs().max().item()}"
)
def test_no_match_on_single_linear():
"""A node with only one linear user should not be matched."""
class SingleLinear(nn.Module):
def __init__(self, dim):
super().__init__()
self.fc = nn.Linear(dim, dim, bias=False)
def forward(self, x):
return self.fc(x)
model = SingleLinear(64).eval().to("cuda")
example_input = torch.randn(4, 64, device="cuda")
gm = _build_gm(model, example_input)
pairs = _find_kv_proj_linears(gm)
assert len(pairs) == 0, f"Expected 0 matches, got {len(pairs)}"
def test_no_match_when_both_have_downstream_linear():
"""When *both* branches have downstream linears the pattern should not match."""
class BothHeavy(nn.Module):
def __init__(self, dim, inner):
super().__init__()
self.fc_a1 = nn.Linear(dim, inner, bias=False)
self.fc_a2 = nn.Linear(inner, dim, bias=False)
self.fc_b1 = nn.Linear(dim, inner, bias=False)
self.fc_b2 = nn.Linear(inner, dim, bias=False)
def forward(self, x):
a = self.fc_a2(torch.relu(self.fc_a1(x)))
b = self.fc_b2(torch.relu(self.fc_b1(x)))
return a + b
model = BothHeavy(64, 32).eval().to("cuda")
example_input = torch.randn(4, 64, device="cuda")
gm = _build_gm(model, example_input)
pairs = _find_kv_proj_linears(gm)
assert len(pairs) == 0, f"Expected 0 matches, got {len(pairs)}"

View File

@ -0,0 +1,502 @@
"""Tests for multi-stream MoE shared-expert transform across model architectures.
Verifies that ``_execute_shared_expert_in_aux_stream`` correctly identifies the
shared-expert branch and moves it to the auxiliary CUDA stream for the MoE
patterns used in DeepSeek V3, GLM4 MoE Lite, Mixtral, and Nemotron-H (with
and without latent projections).
Architecture patterns tested:
**DeepSeek V3 / GLM4 MoE Lite**: gated-MLP shared expert
(gate_proj + up_proj → SiLU gate → down_proj).
Routed MoE dispatched first; shared expert on ``identity``.
Merge: ``moe_out + shared_out``.
**Mixtral**: pure routed MoE, *no* shared expert.
The transform must produce zero matches (no-op).
**Nemotron-H (no latent)**: simple-MLP shared expert
(up_proj → ReLU² → down_proj).
Shared expert dispatched first; routed MoE second.
Merge: ``shared_out + routed_out``.
**Nemotron-H (with latent projection)**: same shared expert as above,
but the routed path wraps the MoE op with
``fc1_latent_proj → MoE → fc2_latent_proj``.
Tests that the BFS from MoE to the merge ``add`` traverses the extra
projection nodes correctly.
Each architecture is tested for:
1. Pattern matching: correct number of replacements.
2. Graph structure: ``begin_aux``, ``end_aux``, ``wait_aux`` nodes present.
3. Numerical correctness: output matches eager reference within tolerance.
4. CUDA graph compatibility: capture + replay produces correct output.
5. Multi-layer stacking: multiple MoE layers handled independently.
"""
import torch
import torch.nn as nn
from tensorrt_llm._torch.auto_deploy.transform.library.multi_stream_moe import (
_execute_shared_expert_in_aux_stream,
)
from tensorrt_llm._torch.auto_deploy.utils.multi_stream_utils import (
begin_aux_stream_passthrough,
cuda_stream_manager,
end_aux_stream_passthrough,
wait_aux_stream_passthrough,
)
# ---------------------------------------------------------------------------
# Mock fused-MoE custom op (distinct name to avoid conflicts with other tests)
# ---------------------------------------------------------------------------
@torch.library.custom_op("auto_deploy::mock_fused_moe_moe_test", mutates_args=())
def mock_fused_moe(
x: torch.Tensor,
selected_experts: torch.Tensor,
routing_weights: torch.Tensor,
expert_weight: torch.Tensor,
) -> torch.Tensor:
"""Mock fused MoE: a simple linear transform standing in for the real kernel."""
return torch.ops.aten.linear(x, expert_weight)
@mock_fused_moe.register_fake
def _mock_fused_moe_fake(x, selected_experts, routing_weights, expert_weight):
return torch.ops.aten.linear(x, expert_weight)
_MOE_OPS = [torch.ops.auto_deploy.mock_fused_moe_moe_test]
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _build_gm(model, example_input):
"""Export *model* to an FX ``GraphModule``."""
return torch.export.export(model, (example_input,)).module()
def _stream_targets(gm):
"""Return the set of ``call_function`` targets present in *gm*."""
return {n.target for n in gm.graph.nodes if n.op == "call_function"}
def _assert_stream_nodes_present(gm):
"""Assert that the three stream-management passthrough nodes are in the graph."""
targets = _stream_targets(gm)
assert begin_aux_stream_passthrough in targets, "begin_aux_stream_passthrough not in graph"
assert end_aux_stream_passthrough in targets, "end_aux_stream_passthrough not in graph"
assert wait_aux_stream_passthrough in targets, "wait_aux_stream_passthrough not in graph"
def _assert_numerical_correctness(gm, model, test_x, *, atol=1e-5):
"""Assert that *gm* and *model* produce the same output on *test_x*."""
ref = model(test_x)
out = gm(test_x)
assert torch.allclose(out, ref, atol=atol), (
f"Output mismatch: max diff = {(out - ref).abs().max().item()}"
)
def _assert_cuda_graph_correctness(gm, model, test_x, *, atol=1e-5):
"""Assert correctness under CUDA graph capture + replay."""
ref = model(test_x)
out_shape = ref.shape
static_x = torch.randn_like(test_x)
static_out = torch.empty(out_shape, device="cuda", dtype=ref.dtype)
# Warm-up (required before capture).
for _ in range(3):
static_out.copy_(gm(static_x))
cuda_graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(cuda_graph):
static_out.copy_(gm(static_x))
static_x.copy_(test_x)
cuda_graph.replay()
assert torch.allclose(static_out, ref, atol=atol), (
f"CUDA graph output mismatch: max diff = {(static_out - ref).abs().max().item()}"
)
# ---------------------------------------------------------------------------
# Mock modules — shared expert variants
# ---------------------------------------------------------------------------
class _GatedMLP(nn.Module):
"""DeepSeek / GLM4 shared expert: ``down_proj(silu(gate_proj(x)) * up_proj(x))``."""
def __init__(self, hidden_dim: int, intermediate_dim: int):
super().__init__()
self.gate_proj = nn.Linear(hidden_dim, intermediate_dim, bias=False)
self.up_proj = nn.Linear(hidden_dim, intermediate_dim, bias=False)
self.down_proj = nn.Linear(intermediate_dim, hidden_dim, bias=False)
def forward(self, x: torch.Tensor) -> torch.Tensor:
return self.down_proj(torch.nn.functional.silu(self.gate_proj(x)) * self.up_proj(x))
class _SimpleMLP(nn.Module):
"""Nemotron-H shared expert: ``down_proj(relu(up_proj(x)) ** 2)``."""
def __init__(self, hidden_dim: int, intermediate_dim: int):
super().__init__()
self.up_proj = nn.Linear(hidden_dim, intermediate_dim, bias=False)
self.down_proj = nn.Linear(intermediate_dim, hidden_dim, bias=False)
def forward(self, x: torch.Tensor) -> torch.Tensor:
return self.down_proj(torch.relu(self.up_proj(x)) ** 2)
# ---------------------------------------------------------------------------
# Mock MoE layer modules — one per architecture pattern
# ---------------------------------------------------------------------------
class MockDeepSeekGLM4MoELayer(nn.Module):
"""DeepSeek V3 / GLM4 MoE Lite pattern.
Graph topology::
hidden_states → gate → topk
shared_experts (gated MLP) → shared_out
mock_fused_moe → moe_out
moe_out + shared_out → layernorm → out
Routed MoE is dispatched *before* the shared expert in graph order.
The ``add`` has the routed output on the left.
"""
def __init__(self, hidden_dim: int, intermediate_dim: int, num_experts: int = 8):
super().__init__()
self.gate = nn.Linear(hidden_dim, num_experts, bias=False)
self.shared_experts = _GatedMLP(hidden_dim, intermediate_dim)
self.expert_weight = nn.Parameter(torch.randn(hidden_dim, hidden_dim))
self.layernorm = nn.LayerNorm(hidden_dim)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
identity = hidden_states
logits = self.gate(hidden_states)
routing_weights, selected_experts = torch.topk(logits, k=2, dim=-1)
# Routed path first (matches DeepSeek / GLM4 dispatch order).
moe_out = torch.ops.auto_deploy.mock_fused_moe_moe_test(
hidden_states, selected_experts, routing_weights, self.expert_weight
)
# Shared expert on original input.
shared_out = self.shared_experts(identity)
return self.layernorm(moe_out + shared_out)
class MockMixtralMoELayer(nn.Module):
"""Mixtral pattern — pure routed MoE, **no** shared expert.
The transform must return 0 replacements for this topology.
"""
def __init__(self, hidden_dim: int, num_experts: int = 8):
super().__init__()
self.gate = nn.Linear(hidden_dim, num_experts, bias=False)
self.expert_weight = nn.Parameter(torch.randn(hidden_dim, hidden_dim))
self.layernorm = nn.LayerNorm(hidden_dim)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
logits = self.gate(hidden_states)
routing_weights, selected_experts = torch.topk(logits, k=2, dim=-1)
moe_out = torch.ops.auto_deploy.mock_fused_moe_moe_test(
hidden_states, selected_experts, routing_weights, self.expert_weight
)
return self.layernorm(moe_out)
class MockNemotronHMoELayer(nn.Module):
"""Nemotron-H pattern *without* latent projections.
Graph topology::
hidden_states → gate → topk
shared_experts (simple MLP) → shared_out
mock_fused_moe → moe_out
shared_out + moe_out → layernorm → out
Shared expert is dispatched *before* the MoE in graph order.
The ``add`` has the shared output on the left.
"""
def __init__(self, hidden_dim: int, intermediate_dim: int, num_experts: int = 8):
super().__init__()
self.gate = nn.Linear(hidden_dim, num_experts, bias=False)
self.shared_experts = _SimpleMLP(hidden_dim, intermediate_dim)
self.expert_weight = nn.Parameter(torch.randn(hidden_dim, hidden_dim))
self.layernorm = nn.LayerNorm(hidden_dim)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
residuals = hidden_states
logits = self.gate(hidden_states)
routing_weights, selected_experts = torch.topk(logits, k=2, dim=-1)
# Shared expert dispatched first (matches Nemotron-H dispatch order).
shared_out = self.shared_experts(residuals)
# Routed path.
moe_out = torch.ops.auto_deploy.mock_fused_moe_moe_test(
hidden_states, selected_experts, routing_weights, self.expert_weight
)
return self.layernorm(shared_out + moe_out)
class MockNemotronHLatentMoELayer(nn.Module):
"""Nemotron-H pattern *with* latent projections.
Graph topology::
hidden_states → gate → topk
shared_experts (simple MLP) → shared_out
fc1_latent → mock_fused_moe → fc2_latent → routed_out
shared_out + routed_out → ln → out
The latent projections add nodes between the MoE op and the merge ``add``,
testing that the forward BFS from MoE correctly traverses extra projection
nodes.
"""
def __init__(
self,
hidden_dim: int,
intermediate_dim: int,
latent_dim: int,
num_experts: int = 8,
):
super().__init__()
self.gate = nn.Linear(hidden_dim, num_experts, bias=False)
self.shared_experts = _SimpleMLP(hidden_dim, intermediate_dim)
self.fc1_latent_proj = nn.Linear(hidden_dim, latent_dim, bias=False)
self.fc2_latent_proj = nn.Linear(latent_dim, hidden_dim, bias=False)
# Expert weight operates in latent space.
self.expert_weight = nn.Parameter(torch.randn(latent_dim, latent_dim))
self.layernorm = nn.LayerNorm(hidden_dim)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
residuals = hidden_states
logits = self.gate(hidden_states)
routing_weights, selected_experts = torch.topk(logits, k=2, dim=-1)
# Shared expert dispatched first.
shared_out = self.shared_experts(residuals)
# Latent projection → MoE → back-projection.
x_latent = self.fc1_latent_proj(hidden_states)
moe_out = torch.ops.auto_deploy.mock_fused_moe_moe_test(
x_latent, selected_experts, routing_weights, self.expert_weight
)
routed_out = self.fc2_latent_proj(moe_out)
return self.layernorm(shared_out + routed_out)
# ===================================================================
# Tests — DeepSeek V3 / GLM4 MoE Lite (gated-MLP shared expert)
# ===================================================================
def test_deepseek_glm4_pattern_and_correctness():
"""Single-layer: pattern match + graph structure + numerical correctness."""
hidden_dim, intermediate_dim = 128, 256
cuda_stream_manager.add_device(torch.cuda.current_device())
model = MockDeepSeekGLM4MoELayer(hidden_dim, intermediate_dim).eval().to("cuda")
example = torch.randn(4, hidden_dim, device="cuda")
gm = _build_gm(model, example)
gm, num = _execute_shared_expert_in_aux_stream(gm, _MOE_OPS)
assert num == 1, f"Expected 1 replacement, got {num}"
_assert_stream_nodes_present(gm)
_assert_numerical_correctness(gm, model, torch.randn(4, hidden_dim, device="cuda"))
def test_deepseek_glm4_cuda_graph():
"""CUDA graph capture + replay for DeepSeek / GLM4 pattern."""
hidden_dim, intermediate_dim = 128, 256
cuda_stream_manager.add_device(torch.cuda.current_device())
model = MockDeepSeekGLM4MoELayer(hidden_dim, intermediate_dim).eval().to("cuda")
example = torch.randn(4, hidden_dim, device="cuda")
gm = _build_gm(model, example)
gm, num = _execute_shared_expert_in_aux_stream(gm, _MOE_OPS)
assert num == 1
_assert_cuda_graph_correctness(gm, model, torch.randn(4, hidden_dim, device="cuda"))
def test_deepseek_glm4_multi_layer():
"""Two stacked DeepSeek/GLM4 MoE layers — both should be transformed."""
hidden_dim, intermediate_dim = 128, 256
cuda_stream_manager.add_device(torch.cuda.current_device())
model = (
nn.Sequential(
MockDeepSeekGLM4MoELayer(hidden_dim, intermediate_dim),
MockDeepSeekGLM4MoELayer(hidden_dim, intermediate_dim),
)
.eval()
.to("cuda")
)
example = torch.randn(4, hidden_dim, device="cuda")
gm = _build_gm(model, example)
gm, num = _execute_shared_expert_in_aux_stream(gm, _MOE_OPS)
assert num == 2, f"Expected 2 replacements, got {num}"
_assert_numerical_correctness(gm, model, torch.randn(4, hidden_dim, device="cuda"))
# ===================================================================
# Tests — Mixtral (no shared expert → no-op)
# ===================================================================
def test_mixtral_no_shared_expert_no_match():
"""Mixtral has no shared expert; the transform must produce zero matches."""
hidden_dim = 128
cuda_stream_manager.add_device(torch.cuda.current_device())
model = MockMixtralMoELayer(hidden_dim).eval().to("cuda")
example = torch.randn(4, hidden_dim, device="cuda")
gm = _build_gm(model, example)
gm, num = _execute_shared_expert_in_aux_stream(gm, _MOE_OPS)
assert num == 0, f"Expected 0 replacements for Mixtral (no shared expert), got {num}"
# Graph should NOT contain any stream-management nodes.
targets = _stream_targets(gm)
assert begin_aux_stream_passthrough not in targets
assert end_aux_stream_passthrough not in targets
assert wait_aux_stream_passthrough not in targets
# Numerical correctness should still hold (graph unchanged).
_assert_numerical_correctness(gm, model, torch.randn(4, hidden_dim, device="cuda"))
# ===================================================================
# Tests — Nemotron-H without latent projections
# ===================================================================
def test_nemotron_h_pattern_and_correctness():
"""Single-layer Nemotron-H (no latent): pattern + graph + correctness."""
hidden_dim, intermediate_dim = 128, 256
cuda_stream_manager.add_device(torch.cuda.current_device())
model = MockNemotronHMoELayer(hidden_dim, intermediate_dim).eval().to("cuda")
example = torch.randn(4, hidden_dim, device="cuda")
gm = _build_gm(model, example)
gm, num = _execute_shared_expert_in_aux_stream(gm, _MOE_OPS)
assert num == 1, f"Expected 1 replacement, got {num}"
_assert_stream_nodes_present(gm)
_assert_numerical_correctness(gm, model, torch.randn(4, hidden_dim, device="cuda"))
def test_nemotron_h_cuda_graph():
"""CUDA graph capture + replay for Nemotron-H (no latent) pattern."""
hidden_dim, intermediate_dim = 128, 256
cuda_stream_manager.add_device(torch.cuda.current_device())
model = MockNemotronHMoELayer(hidden_dim, intermediate_dim).eval().to("cuda")
example = torch.randn(4, hidden_dim, device="cuda")
gm = _build_gm(model, example)
gm, num = _execute_shared_expert_in_aux_stream(gm, _MOE_OPS)
assert num == 1
_assert_cuda_graph_correctness(gm, model, torch.randn(4, hidden_dim, device="cuda"))
def test_nemotron_h_multi_layer():
"""Two stacked Nemotron-H (no latent) layers — both should be transformed."""
hidden_dim, intermediate_dim = 128, 256
cuda_stream_manager.add_device(torch.cuda.current_device())
model = (
nn.Sequential(
MockNemotronHMoELayer(hidden_dim, intermediate_dim),
MockNemotronHMoELayer(hidden_dim, intermediate_dim),
)
.eval()
.to("cuda")
)
example = torch.randn(4, hidden_dim, device="cuda")
gm = _build_gm(model, example)
gm, num = _execute_shared_expert_in_aux_stream(gm, _MOE_OPS)
assert num == 2, f"Expected 2 replacements, got {num}"
_assert_numerical_correctness(gm, model, torch.randn(4, hidden_dim, device="cuda"))
# ===================================================================
# Tests — Nemotron-H with latent projections
# ===================================================================
def test_nemotron_h_latent_pattern_and_correctness():
"""Single-layer Nemotron-H (with latent): pattern + graph + correctness."""
hidden_dim, intermediate_dim, latent_dim = 128, 256, 64
cuda_stream_manager.add_device(torch.cuda.current_device())
model = MockNemotronHLatentMoELayer(hidden_dim, intermediate_dim, latent_dim).eval().to("cuda")
example = torch.randn(4, hidden_dim, device="cuda")
gm = _build_gm(model, example)
gm, num = _execute_shared_expert_in_aux_stream(gm, _MOE_OPS)
assert num == 1, f"Expected 1 replacement, got {num}"
_assert_stream_nodes_present(gm)
_assert_numerical_correctness(gm, model, torch.randn(4, hidden_dim, device="cuda"))
def test_nemotron_h_latent_cuda_graph():
"""CUDA graph capture + replay for Nemotron-H (with latent) pattern."""
hidden_dim, intermediate_dim, latent_dim = 128, 256, 64
cuda_stream_manager.add_device(torch.cuda.current_device())
model = MockNemotronHLatentMoELayer(hidden_dim, intermediate_dim, latent_dim).eval().to("cuda")
example = torch.randn(4, hidden_dim, device="cuda")
gm = _build_gm(model, example)
gm, num = _execute_shared_expert_in_aux_stream(gm, _MOE_OPS)
assert num == 1
_assert_cuda_graph_correctness(gm, model, torch.randn(4, hidden_dim, device="cuda"))
def test_nemotron_h_latent_multi_layer():
"""Two stacked Nemotron-H (with latent) layers — both should be transformed."""
hidden_dim, intermediate_dim, latent_dim = 128, 256, 64
cuda_stream_manager.add_device(torch.cuda.current_device())
model = (
nn.Sequential(
MockNemotronHLatentMoELayer(hidden_dim, intermediate_dim, latent_dim),
MockNemotronHLatentMoELayer(hidden_dim, intermediate_dim, latent_dim),
)
.eval()
.to("cuda")
)
example = torch.randn(4, hidden_dim, device="cuda")
gm = _build_gm(model, example)
gm, num = _execute_shared_expert_in_aux_stream(gm, _MOE_OPS)
assert num == 2, f"Expected 2 replacements, got {num}"
_assert_numerical_correctness(gm, model, torch.randn(4, hidden_dim, device="cuda"))

View File

@ -0,0 +1,188 @@
import pytest
import torch
from torch.export import Dim
from tensorrt_llm._torch.auto_deploy.custom_ops.linear.swiglu import * # noqa
from tensorrt_llm._torch.auto_deploy.export import torch_export_to_gm
from tensorrt_llm._torch.auto_deploy.transform.optimizer import InferenceOptimizer
from tensorrt_llm._torch.auto_deploy.utils.node_utils import is_op
class SwiGLUMLP(torch.nn.Module):
"""SwiGLU MLP module: silu(x @ gate.T) * (x @ up.T) @ down.T"""
def __init__(self, hidden_size: int, intermediate_size: int):
super().__init__()
self.gate_proj = torch.nn.Linear(hidden_size, intermediate_size, bias=False)
self.up_proj = torch.nn.Linear(hidden_size, intermediate_size, bias=False)
self.down_proj = torch.nn.Linear(intermediate_size, hidden_size, bias=False)
def forward(self, x):
return self.down_proj(torch.nn.functional.silu(self.gate_proj(x)) * self.up_proj(x))
class SwiGLUMLPWithBias(torch.nn.Module):
"""SwiGLU MLP module with biases."""
def __init__(self, hidden_size: int, intermediate_size: int):
super().__init__()
self.gate_proj = torch.nn.Linear(hidden_size, intermediate_size, bias=True)
self.up_proj = torch.nn.Linear(hidden_size, intermediate_size, bias=True)
self.down_proj = torch.nn.Linear(intermediate_size, hidden_size, bias=True)
def forward(self, x):
return self.down_proj(torch.nn.functional.silu(self.gate_proj(x)) * self.up_proj(x))
class SwiGLUTestModel(torch.nn.Module):
"""Test model with SwiGLU MLP sandwiched between linear layers."""
def __init__(
self,
hidden_size: int = 256,
intermediate_size: int = 512,
with_bias: bool = False,
):
super().__init__()
self.linear1 = torch.nn.Linear(hidden_size, hidden_size, device="cuda", dtype=torch.float16)
if with_bias:
self.mlp = SwiGLUMLPWithBias(hidden_size, intermediate_size)
else:
self.mlp = SwiGLUMLP(hidden_size, intermediate_size)
self.mlp = self.mlp.to(device="cuda", dtype=torch.float16)
self.linear2 = torch.nn.Linear(hidden_size, hidden_size, device="cuda", dtype=torch.float16)
def forward(self, x):
x = self.linear1(x)
x = self.mlp(x)
x = self.linear2(x)
return x
class SwiGLUTestModelMultipleMLP(torch.nn.Module):
"""Test model with multiple SwiGLU MLPs to test multiple pattern matches."""
def __init__(
self,
hidden_size: int = 256,
intermediate_size: int = 512,
num_layers: int = 2,
):
super().__init__()
self.layers = torch.nn.ModuleList()
for _ in range(num_layers):
self.layers.append(
torch.nn.ModuleDict(
{
"linear": torch.nn.Linear(
hidden_size, hidden_size, device="cuda", dtype=torch.float16
),
"mlp": SwiGLUMLP(hidden_size, intermediate_size).to(
device="cuda", dtype=torch.float16
),
}
)
)
def forward(self, x):
for layer in self.layers:
x = layer["linear"](x)
x = layer["mlp"](x)
return x
def _run_fusion_test(model, expected_op, expected_num_matches=1):
"""Run the SwiGLU fusion test.
Args:
model: The test model to transform.
expected_op: The expected fused op to find in the transformed graph.
expected_num_matches: Expected number of fused ops.
"""
x = torch.randn(2, 256, device="cuda", dtype=torch.float16)
dynamic_shapes = {0: Dim.DYNAMIC}
gm = torch_export_to_gm(model, args=(x,), dynamic_shapes=(dynamic_shapes,), clone=True)
# Apply transforms
gm_transformed = InferenceOptimizer(
None,
{
"match_swiglu_pattern": {
"stage": "pattern_matcher",
},
"fuse_swiglu": {
"stage": "post_load_fusion",
"enabled": True,
},
},
)(None, gm)
# Move to CUDA if needed
gm_transformed = gm_transformed.to("cuda")
# Check that the expected op is present
count = sum(1 for n in gm_transformed.graph.nodes if is_op(n, expected_op))
assert count == expected_num_matches, (
f"Expected {expected_num_matches} {expected_op} ops, got {count}"
)
# Verify numerical correctness
y_transformed = gm_transformed(x)
y_model = model(x)
torch.testing.assert_close(y_transformed, y_model, atol=1e-2, rtol=1e-2)
# Test with a different batch size
new_input = torch.randn(4, 256, device="cuda", dtype=torch.float16)
y_transformed_2 = gm_transformed(new_input)
y_model_2 = model(new_input)
torch.testing.assert_close(y_transformed_2, y_model_2, atol=1e-2, rtol=1e-2)
def test_swiglu_fusion_basic():
"""Test basic SwiGLU fusion without biases."""
model = SwiGLUTestModel(with_bias=False)
_run_fusion_test(model, torch.ops.auto_deploy.fused_swiglu_mlp.default)
def test_swiglu_fusion_with_bias():
"""Test SwiGLU fusion with biases."""
model = SwiGLUTestModel(with_bias=True)
_run_fusion_test(model, torch.ops.auto_deploy.fused_swiglu_mlp.default)
@pytest.mark.parametrize("num_layers", [2, 3])
def test_swiglu_fusion_multiple_layers(num_layers):
"""Test that multiple SwiGLU patterns are fused correctly."""
model = SwiGLUTestModelMultipleMLP(num_layers=num_layers)
_run_fusion_test(
model, torch.ops.auto_deploy.fused_swiglu_mlp.default, expected_num_matches=num_layers
)
def test_swiglu_pattern_match_only():
"""Test pattern matching stage only (without fusion)."""
model = SwiGLUTestModel()
x = torch.randn(2, 256, device="cuda", dtype=torch.float16)
dynamic_shapes = {0: Dim.DYNAMIC}
gm = torch_export_to_gm(model, args=(x,), dynamic_shapes=(dynamic_shapes,), clone=True)
# Only run pattern matching, not fusion
gm_matched = InferenceOptimizer(
None,
{
"match_swiglu_pattern": {
"stage": "pattern_matcher",
},
},
)(None, gm)
# Check that the intermediate op is present
has_swiglu_op = any(
is_op(n, torch.ops.auto_deploy.torch_swiglu_mlp.default) for n in gm_matched.graph.nodes
)
assert has_swiglu_op, "Pattern matcher should produce torch_swiglu_mlp op"
# Verify numerical correctness
y_matched = gm_matched(x)
y_model = model(x)
torch.testing.assert_close(y_matched, y_model, atol=1e-3, rtol=1e-3)

View File

@ -1,14 +1,26 @@
import operator
import torch
from torch.export import Dim
from tensorrt_llm._torch.auto_deploy.custom_ops.normalization.flashinfer_fused_add_rms_norm import * # noqa
from tensorrt_llm._torch.auto_deploy.custom_ops.normalization.flashinfer_fused_add_rms_norm import ( # noqa
flashinfer_fused_add_rms_norm,
)
from tensorrt_llm._torch.auto_deploy.custom_ops.normalization.rms_norm import * # noqa
from tensorrt_llm._torch.auto_deploy.export import torch_export_to_gm
from tensorrt_llm._torch.auto_deploy.transform.interface import TransformConfig
from tensorrt_llm._torch.auto_deploy.transform.library.fused_add_rms_norm import FuseAddRMSNorm
from tensorrt_llm._torch.auto_deploy.transform.optimizer import InferenceOptimizer
from tensorrt_llm._torch.auto_deploy.utils.node_utils import is_op
# ---------------------------------------------------------------------------
# Models
# ---------------------------------------------------------------------------
class AddCastNormModel(torch.nn.Module):
"""Pattern 1: add + cast(to.dtype) + rms_norm."""
class TestModel(torch.nn.Module):
def __init__(self, hidden_size=128, eps=1e-5):
super().__init__()
self.weight = torch.nn.Parameter(
@ -23,55 +35,263 @@ class TestModel(torch.nn.Module):
return norm, added
def _run_test(model):
# The replacement uses flashinfer_fused_add_rms_norm python wrapper which calls the inplace op
# auto_deploy::flashinfer_fused_add_rms_norm_inplace
op = torch.ops.auto_deploy.flashinfer_fused_add_rms_norm_inplace
class AddNormModel(torch.nn.Module):
"""Pattern 2: add + rms_norm (no intermediate cast)."""
def checker(gm):
return any(is_op(n, op) for n in gm.graph.nodes)
def __init__(self, hidden_size=128, eps=1e-5):
super().__init__()
self.weight = torch.nn.Parameter(
torch.ones(hidden_size, device="cuda", dtype=torch.bfloat16)
)
self.eps = eps
bsz, seq_len, hidden = 2, 8, 128
# Inputs should be bfloat16
x = torch.randn(bsz, seq_len, hidden, device="cuda", dtype=torch.bfloat16)
residual = torch.randn(bsz, seq_len, hidden, device="cuda", dtype=torch.bfloat16)
def forward(self, x, residual):
added = x + residual
norm = torch.ops.auto_deploy.flashinfer_rms_norm(added, self.weight, self.eps)
return norm, added
# Dynamic shapes
dyn_batch_size = Dim.DYNAMIC
ds_x = {0: dyn_batch_size}
ds_res = {0: dyn_batch_size}
gm = torch_export_to_gm(model, args=(x, residual), dynamic_shapes=(ds_x, ds_res), clone=True)
class MultiUserModel(torch.nn.Module):
"""Both add and rms_norm outputs have multiple users (DeepSeek V3 MoE pattern).
gm_transformed = InferenceOptimizer(
add_result has 2 users: rms_norm + next residual add
norm_result has 2 users: linear1 + linear2
"""
def __init__(self, hidden_size=128, eps=1e-5):
super().__init__()
self.weight = torch.nn.Parameter(
torch.ones(hidden_size, device="cuda", dtype=torch.bfloat16)
)
self.linear1 = torch.nn.Linear(
hidden_size, hidden_size, bias=False, device="cuda", dtype=torch.bfloat16
)
self.linear2 = torch.nn.Linear(
hidden_size, hidden_size, bias=False, device="cuda", dtype=torch.bfloat16
)
self.eps = eps
def forward(self, residual, attn_output, moe_output):
# add with 2 users (norm + next_add)
add_result = residual + attn_output
# rms_norm with 2 users (linear1, linear2)
norm_result = torch.ops.auto_deploy.flashinfer_rms_norm(add_result, self.weight, self.eps)
out1 = self.linear1(norm_result)
out2 = self.linear2(norm_result)
combined = out1 + out2
# add_result also feeds into next residual add
next_residual = add_result + moe_output
return combined, next_residual
class ChainedModel(torch.nn.Module):
"""Two consecutive add+norm pairs sharing residual (like transformer layers).
Layer 1: add1 = embed + attn_out, norm1 = rms_norm(add1) -- add1 has 2 users
Layer 2: add2 = add1 + mlp_out, norm2 = rms_norm(add2) -- add2 has 2 users
"""
def __init__(self, hidden_size=128, eps=1e-5):
super().__init__()
self.weight1 = torch.nn.Parameter(
torch.ones(hidden_size, device="cuda", dtype=torch.bfloat16)
)
self.weight2 = torch.nn.Parameter(
torch.ones(hidden_size, device="cuda", dtype=torch.bfloat16)
)
self.linear = torch.nn.Linear(
hidden_size, hidden_size, bias=False, device="cuda", dtype=torch.bfloat16
)
self.eps = eps
def forward(self, embed, attn_out, mlp_out):
add1 = embed + attn_out
norm1 = torch.ops.auto_deploy.flashinfer_rms_norm(add1, self.weight1, self.eps)
branch1 = self.linear(norm1)
add2 = add1 + mlp_out
norm2 = torch.ops.auto_deploy.flashinfer_rms_norm(add2, self.weight2, self.eps)
branch2 = self.linear(norm2)
return branch1 + branch2, add2
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _count_fused_ops(gm):
"""Count flashinfer_fused_add_rms_norm wrapper calls in the graph."""
return sum(
1
for n in gm.graph.nodes
if n.op == "call_function" and n.target is flashinfer_fused_add_rms_norm
)
def _count_rms_norm_ops(gm):
"""Count flashinfer_rms_norm calls in the graph."""
return sum(1 for n in gm.graph.nodes if is_op(n, torch.ops.auto_deploy.flashinfer_rms_norm))
def _count_add_ops(gm):
"""Count aten.add.Tensor calls in the graph."""
return sum(1 for n in gm.graph.nodes if is_op(n, torch.ops.aten.add.Tensor))
def _export_model(model, *inputs, dynamic_dim0=True):
"""Export a model to a GraphModule, optionally with a dynamic batch dimension."""
if dynamic_dim0:
dyn = Dim.DYNAMIC
ds = tuple({0: dyn} for _ in inputs)
else:
ds = None
return torch_export_to_gm(model, args=inputs, dynamic_shapes=ds, clone=True)
def _apply_transform(gm):
"""Apply fuse_add_rms_norm via InferenceOptimizer (integration-style)."""
return InferenceOptimizer(
None,
{
"fuse_add_rms_norm": {
"stage": "post_load_fusion",
},
},
{"fuse_add_rms_norm": {"stage": "post_load_fusion"}},
)(None, gm)
# Check if transform happened
if not checker(gm_transformed):
raise AssertionError(
"flashinfer_fused_add_rms_norm_inplace op not found in transformed graph"
)
# Validation
# Clone inputs because the fused op is inplace
x_in = x.clone()
res_in = residual.clone()
# The fused op is inplace, so inputs x_in and res_in will be modified.
# gm_transformed returns (x_in, res_in) which are the modified tensors.
y_transformed = gm_transformed(x_in, res_in)
y_model = model(x.clone(), residual.clone())
torch.testing.assert_close(y_transformed[0], y_model[0], atol=1e-2, rtol=1e-2)
torch.testing.assert_close(y_transformed[1], y_model[1], atol=1e-2, rtol=1e-2)
def _apply_transform_direct(gm):
"""Apply the transform directly (unit-test style)."""
config = TransformConfig(stage="post_load_fusion")
transform = FuseAddRMSNorm(config=config)
gm, info = transform._apply(gm, None, None, None)
return gm, info
def test_fuse_add_rms_norm():
model = TestModel()
_run_test(model)
# ---------------------------------------------------------------------------
# Tests
# ---------------------------------------------------------------------------
def test_fuse_add_cast_rms_norm():
"""Original test: add + cast(bf16) + rms_norm → fused op."""
model = AddCastNormModel()
bsz, seq_len, hidden = 2, 8, 128
x = torch.randn(bsz, seq_len, hidden, device="cuda", dtype=torch.bfloat16)
residual = torch.randn_like(x)
gm = _export_model(model, x, residual)
gm_t = _apply_transform(gm)
# Structure check
assert _count_fused_ops(gm_t) >= 1, "fused op not found in graph"
assert _count_rms_norm_ops(gm_t) == 0, "unfused rms_norm still in graph"
# Numerical check
y_fused = gm_t(x.clone(), residual.clone())
y_ref = model(x.clone(), residual.clone())
torch.testing.assert_close(y_fused[0], y_ref[0], atol=1e-2, rtol=1e-2)
torch.testing.assert_close(y_fused[1], y_ref[1], atol=1e-2, rtol=1e-2)
def test_fuse_add_rms_norm_no_cast():
"""Pattern 2: add + rms_norm (no cast) → fused op."""
model = AddNormModel()
bsz, seq_len, hidden = 2, 8, 128
x = torch.randn(bsz, seq_len, hidden, device="cuda", dtype=torch.bfloat16)
residual = torch.randn_like(x)
gm = _export_model(model, x, residual)
gm_t = _apply_transform(gm)
# Structure check
assert _count_fused_ops(gm_t) >= 1, "fused op not found in graph"
assert _count_rms_norm_ops(gm_t) == 0, "unfused rms_norm still in graph"
# Numerical check
y_fused = gm_t(x.clone(), residual.clone())
y_ref = model(x.clone(), residual.clone())
torch.testing.assert_close(y_fused[0], y_ref[0], atol=1e-2, rtol=1e-2)
torch.testing.assert_close(y_fused[1], y_ref[1], atol=1e-2, rtol=1e-2)
def test_fuse_add_rms_norm_multi_user():
"""Multi-user: both add (2 users) and rms_norm (2 users) → fused op.
This is the key pattern from the DeepSeek V3 / GLM4-MoE graph that failed
with the old inductor-based pattern matcher due to num_users constraints.
"""
model = MultiUserModel()
bsz, seq_len, hidden = 2, 8, 128
residual = torch.randn(bsz, seq_len, hidden, device="cuda", dtype=torch.bfloat16)
attn_out = torch.randn_like(residual)
moe_out = torch.randn_like(residual)
gm = _export_model(model, residual, attn_out, moe_out)
# Before: 1 add+norm fusible pair, add has 2 users, norm has 2 users
assert _count_rms_norm_ops(gm) == 1
gm_t, info = _apply_transform_direct(gm)
# Structure check
assert info.num_matches == 1, f"Expected 1 match, got {info.num_matches}"
assert _count_fused_ops(gm_t) == 1, "fused op not found in graph"
assert _count_rms_norm_ops(gm_t) == 0, "unfused rms_norm still in graph"
# Verify getitem nodes for both outputs
getitems = [
n
for n in gm_t.graph.nodes
if n.op == "call_function"
and n.target is operator.getitem
and isinstance(n.args[0], torch.fx.Node)
and n.args[0].target is flashinfer_fused_add_rms_norm
]
assert len(getitems) == 2, f"Expected 2 getitem nodes, got {len(getitems)}"
# Numerical check
y_fused = gm_t(residual.clone(), attn_out.clone(), moe_out.clone())
y_ref = model(residual.clone(), attn_out.clone(), moe_out.clone())
torch.testing.assert_close(y_fused[0], y_ref[0], atol=1e-2, rtol=1e-2)
torch.testing.assert_close(y_fused[1], y_ref[1], atol=1e-2, rtol=1e-2)
def test_fuse_add_rms_norm_chained():
"""Chained: two consecutive add+norm pairs across transformer layers."""
model = ChainedModel()
bsz, seq_len, hidden = 2, 8, 128
embed = torch.randn(bsz, seq_len, hidden, device="cuda", dtype=torch.bfloat16)
attn_out = torch.randn_like(embed)
mlp_out = torch.randn_like(embed)
gm = _export_model(model, embed, attn_out, mlp_out)
# Before: 2 add+norm fusible pairs
assert _count_rms_norm_ops(gm) == 2
gm_t, info = _apply_transform_direct(gm)
# Structure check
assert info.num_matches == 2, f"Expected 2 matches, got {info.num_matches}"
assert _count_fused_ops(gm_t) == 2, "Expected 2 fused ops"
assert _count_rms_norm_ops(gm_t) == 0, "unfused rms_norm still in graph"
# Verify second fused op receives add_out from first fused op (residual chain)
fused_nodes = [
n
for n in gm_t.graph.nodes
if n.op == "call_function" and n.target is flashinfer_fused_add_rms_norm
]
assert len(fused_nodes) == 2
# The second fused op's residual arg should be a getitem from the first fused op
second_residual_arg = fused_nodes[1].args[1]
assert (
second_residual_arg.op == "call_function"
and second_residual_arg.target is operator.getitem
and second_residual_arg.args[0] is fused_nodes[0]
), "Second fused op's residual should come from first fused op's add_out"
# Numerical check
y_fused = gm_t(embed.clone(), attn_out.clone(), mlp_out.clone())
y_ref = model(embed.clone(), attn_out.clone(), mlp_out.clone())
torch.testing.assert_close(y_fused[0], y_ref[0], atol=1e-2, rtol=1e-2)
torch.testing.assert_close(y_fused[1], y_ref[1], atol=1e-2, rtol=1e-2)
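# Eager reference for the pattern being fused (illustrative sketch, assuming standard
# RMSNorm semantics for flashinfer_rms_norm; the fused kernel produces both outputs in
# one in-place pass, which is why the tests above clone their inputs):
def _reference_add_rms_norm(x, residual, weight, eps):
    added = x + residual
    variance = added.float().pow(2).mean(-1, keepdim=True)
    norm = (added.float() * torch.rsqrt(variance + eps)).to(added.dtype) * weight
    return norm, added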

View File

@ -0,0 +1,378 @@
"""Tests for NVFP4 quantized SwiGLU pattern matching and fusion transforms.
Tests the parallel NVFP4 SwiGLU path:
1. match_nvfp4_swiglu_pattern: Matches torch_fake_quant_nvfp4_linear SwiGLU -> torch_nvfp4_swiglu_mlp
2. fuse_nvfp4_swiglu: Fuses gate+up FP4 weights -> fused_nvfp4_swiglu_mlp
"""
import pytest
import torch
import torch.nn as nn
from _torch_test_utils import fp4_compatible, trtllm_ops_available
from torch.export import Dim
import tensorrt_llm._torch.auto_deploy.custom_ops # noqa: F401
from tensorrt_llm._torch.auto_deploy.export import torch_export_to_gm
from tensorrt_llm._torch.auto_deploy.transform.optimizer import InferenceOptimizer
from tensorrt_llm._torch.auto_deploy.utils.node_utils import is_op
from tensorrt_llm._torch.auto_deploy.utils.quantization_utils import fp4_global_scale
_skip_reason = "Requires NVFP4 (Blackwell+) and TRT-LLM ops"
_skip_condition = not (fp4_compatible() and trtllm_ops_available())
class NVFP4SwiGLUMLP(nn.Module):
"""SwiGLU MLP using NVFP4 quantized linear ops.
Mimics the graph structure produced by quantize_nvfp4_linear_from_config
applied to a standard SwiGLU MLP: silu(gate(x)) * up(x) -> down(hidden).
"""
def __init__(self, hidden_size: int = 128, intermediate_size: int = 128):
super().__init__()
device = torch.device("cuda")
scaling_vector_size = 16
# Create random weights and quantize them to FP4
gate_weight = (
torch.randn(intermediate_size, hidden_size, dtype=torch.half, device=device) * 0.05
)
up_weight = (
torch.randn(intermediate_size, hidden_size, dtype=torch.half, device=device) * 0.05
)
down_weight = (
torch.randn(hidden_size, intermediate_size, dtype=torch.half, device=device) * 0.05
)
# Quantize gate projection
s_w_gate = fp4_global_scale(gate_weight)
gate_fp4, gate_cutlass = torch.ops.trtllm.fp4_quantize(
gate_weight, s_w_gate, scaling_vector_size, False
)
# Use a shared input scale for gate and up (same input x)
s_in = fp4_global_scale(torch.randn(1, hidden_size, dtype=torch.half, device=device))
gate_alpha = (1.0 / (s_in * s_w_gate)).to(torch.float32)
self.register_buffer("gate_weight", gate_fp4)
self.register_buffer("gate_input_scale", s_in.to(torch.float32))
self.register_buffer("gate_weight_scale", gate_cutlass)
self.register_buffer("gate_alpha", gate_alpha)
# Quantize up projection (same input scale as gate)
s_w_up = fp4_global_scale(up_weight)
up_fp4, up_cutlass = torch.ops.trtllm.fp4_quantize(
up_weight, s_w_up, scaling_vector_size, False
)
up_alpha = (1.0 / (s_in * s_w_up)).to(torch.float32)
self.register_buffer("up_weight", up_fp4)
self.register_buffer("up_input_scale", s_in.to(torch.float32))
self.register_buffer("up_weight_scale", up_cutlass)
self.register_buffer("up_alpha", up_alpha)
# Quantize down projection (different input: the hidden state)
s_in_down = fp4_global_scale(
torch.randn(1, intermediate_size, dtype=torch.half, device=device)
)
s_w_down = fp4_global_scale(down_weight)
down_fp4, down_cutlass = torch.ops.trtllm.fp4_quantize(
down_weight, s_w_down, scaling_vector_size, False
)
down_alpha = (1.0 / (s_in_down * s_w_down)).to(torch.float32)
self.register_buffer("down_weight", down_fp4)
self.register_buffer("down_input_scale", s_in_down.to(torch.float32))
self.register_buffer("down_weight_scale", down_cutlass)
self.register_buffer("down_alpha", down_alpha)
def forward(self, x):
gate_out = torch.ops.auto_deploy.torch_fake_quant_nvfp4_linear(
x,
self.gate_weight,
None,
[self.gate_input_scale],
[self.gate_weight_scale, self.gate_alpha],
[],
[],
)
up_out = torch.ops.auto_deploy.torch_fake_quant_nvfp4_linear(
x,
self.up_weight,
None,
[self.up_input_scale],
[self.up_weight_scale, self.up_alpha],
[],
[],
)
hidden = torch.nn.functional.silu(gate_out) * up_out
return torch.ops.auto_deploy.torch_fake_quant_nvfp4_linear(
hidden,
self.down_weight,
None,
[self.down_input_scale],
[self.down_weight_scale, self.down_alpha],
[],
[],
)
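# Example usage of the quantized MLP above (illustrative; requires NVFP4-capable
# hardware and TRT-LLM ops, mirroring the skip condition defined earlier in this file):
#   mlp = NVFP4SwiGLUMLP(hidden_size=128, intermediate_size=128)
#   y = mlp(torch.randn(2, 128, device="cuda", dtype=torch.float16))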
class NVFP4SwiGLUTestModel(nn.Module):
"""Test model wrapping NVFP4 SwiGLU MLP between linear layers."""
def __init__(self, hidden_size: int = 128, intermediate_size: int = 128):
super().__init__()
device = torch.device("cuda")
self.linear_in = nn.Linear(hidden_size, hidden_size, device=device, dtype=torch.float16)
self.mlp = NVFP4SwiGLUMLP(hidden_size, intermediate_size)
self.linear_out = nn.Linear(hidden_size, hidden_size, device=device, dtype=torch.float16)
def forward(self, x):
x = self.linear_in(x)
x = self.mlp(x)
x = self.linear_out(x)
return x
class NVFP4SwiGLUMultiLayerModel(nn.Module):
"""Test model with multiple NVFP4 SwiGLU MLP layers."""
def __init__(
self,
hidden_size: int = 128,
intermediate_size: int = 128,
num_layers: int = 2,
):
super().__init__()
device = torch.device("cuda")
self.layers = nn.ModuleList()
for _ in range(num_layers):
self.layers.append(
nn.ModuleDict(
{
"linear": nn.Linear(
hidden_size,
hidden_size,
device=device,
dtype=torch.float16,
),
"mlp": NVFP4SwiGLUMLP(hidden_size, intermediate_size),
}
)
)
def forward(self, x):
for layer in self.layers:
x = layer["linear"](x)
x = layer["mlp"](x)
return x
# -- Test helpers --------------------------------------------------------------
def _count_ops(gm, op):
"""Count how many nodes in the graph match the given op."""
return sum(1 for n in gm.graph.nodes if is_op(n, op))
def _has_no_fake_quant_nvfp4(gm):
"""Verify no torch_fake_quant_nvfp4_linear ops remain."""
return _count_ops(gm, torch.ops.auto_deploy.torch_fake_quant_nvfp4_linear) == 0
# -- Tests ---------------------------------------------------------------------
@pytest.mark.skipif(_skip_condition, reason=_skip_reason)
def test_nvfp4_swiglu_pattern_match_only():
"""Test that match_nvfp4_swiglu_pattern produces torch_nvfp4_swiglu_mlp op."""
torch.manual_seed(0)
model = NVFP4SwiGLUMLP().to("cuda")
x = torch.randn(2, 128, device="cuda", dtype=torch.float16)
gm = torch_export_to_gm(model, args=(x,), clone=True)
# Verify the graph has torch_fake_quant_nvfp4_linear ops before transform
assert _count_ops(gm, torch.ops.auto_deploy.torch_fake_quant_nvfp4_linear) == 3, (
"Expected 3 torch_fake_quant_nvfp4_linear ops (gate, up, down) before transform"
)
# Apply only pattern matching
gm_matched = InferenceOptimizer(
None,
{
"match_nvfp4_swiglu_pattern": {
"stage": "pattern_matcher",
},
},
)(None, gm)
# Check the intermediate op is present
nvfp4_swiglu_count = _count_ops(
gm_matched, torch.ops.auto_deploy.torch_nvfp4_swiglu_mlp.default
)
assert nvfp4_swiglu_count == 1, (
f"Expected 1 torch_nvfp4_swiglu_mlp op, got {nvfp4_swiglu_count}"
)
# All 3 fake_quant_nvfp4 ops should be consumed
assert _has_no_fake_quant_nvfp4(gm_matched), (
"torch_fake_quant_nvfp4_linear ops should be consumed by pattern matcher"
)
# Verify numerical correctness
gm_matched = gm_matched.to("cuda")
y_matched = gm_matched(x)
y_model = model(x)
torch.testing.assert_close(y_matched, y_model, atol=1e-3, rtol=1e-3)
@pytest.mark.skipif(_skip_condition, reason=_skip_reason)
def test_nvfp4_swiglu_full_fusion():
"""Test full pipeline: pattern match -> fuse -> fused_nvfp4_swiglu_mlp."""
torch.manual_seed(0)
model = NVFP4SwiGLUTestModel().to("cuda")
x = torch.randn(2, 128, device="cuda", dtype=torch.float16)
gm = torch_export_to_gm(model, args=(x,), clone=True, dynamic_shapes=({0: Dim.DYNAMIC},))
# Apply pattern matching + fusion
gm_fused = InferenceOptimizer(
None,
{
"match_nvfp4_swiglu_pattern": {
"stage": "pattern_matcher",
},
"fuse_nvfp4_swiglu": {
"stage": "post_load_fusion",
},
},
)(None, gm)
gm_fused = gm_fused.to("cuda")
# Check the fused op is present
fused_count = _count_ops(gm_fused, torch.ops.auto_deploy.fused_nvfp4_swiglu_mlp.default)
assert fused_count == 1, f"Expected 1 fused_nvfp4_swiglu_mlp op, got {fused_count}"
# No intermediate or unfused ops should remain
assert _count_ops(gm_fused, torch.ops.auto_deploy.torch_nvfp4_swiglu_mlp.default) == 0, (
"Intermediate torch_nvfp4_swiglu_mlp should be replaced by fused version"
)
assert _has_no_fake_quant_nvfp4(gm_fused), (
"No torch_fake_quant_nvfp4_linear ops should remain after fusion"
)
# Verify numerical correctness (fused uses TRT-LLM kernel, allow wider tolerance)
y_fused = gm_fused(x)
y_model = model(x)
torch.testing.assert_close(y_fused, y_model, atol=0.15, rtol=0.05)
# Test with a different batch size to verify dynamic shapes work
x2 = torch.randn(4, 128, device="cuda", dtype=torch.float16)
y_fused_2 = gm_fused(x2)
y_model_2 = model(x2)
torch.testing.assert_close(y_fused_2, y_model_2, atol=0.15, rtol=0.05)
@pytest.mark.skipif(_skip_condition, reason=_skip_reason)
@pytest.mark.parametrize("num_layers", [2, 3])
def test_nvfp4_swiglu_fusion_multiple_layers(num_layers):
"""Test that multiple NVFP4 SwiGLU patterns are fused correctly."""
torch.manual_seed(0)
model = NVFP4SwiGLUMultiLayerModel(num_layers=num_layers).to("cuda")
x = torch.randn(2, 128, device="cuda", dtype=torch.float16)
gm = torch_export_to_gm(model, args=(x,), clone=True)
# Apply pattern matching + fusion
gm_fused = InferenceOptimizer(
None,
{
"match_nvfp4_swiglu_pattern": {
"stage": "pattern_matcher",
},
"fuse_nvfp4_swiglu": {
"stage": "post_load_fusion",
},
},
)(None, gm)
gm_fused = gm_fused.to("cuda")
# Check that all layers are fused
fused_count = _count_ops(gm_fused, torch.ops.auto_deploy.fused_nvfp4_swiglu_mlp.default)
assert fused_count == num_layers, (
f"Expected {num_layers} fused_nvfp4_swiglu_mlp ops, got {fused_count}"
)
# Verify numerical correctness
y_fused = gm_fused(x)
y_model = model(x)
torch.testing.assert_close(y_fused, y_model, atol=0.2, rtol=0.1)
@pytest.mark.skipif(_skip_condition, reason=_skip_reason)
def test_nvfp4_swiglu_does_not_match_non_swiglu():
"""Test that the NVFP4 SwiGLU matcher does not match non-SwiGLU NVFP4 linears."""
torch.manual_seed(0)
device = torch.device("cuda")
hidden_size = 128
# Model with two sequential NVFP4 linears + relu (NOT a SwiGLU pattern)
class NonSwiGLUModel(nn.Module):
def __init__(self):
super().__init__()
w1 = torch.randn(hidden_size, hidden_size, dtype=torch.half, device=device) * 0.05
w2 = torch.randn(hidden_size, hidden_size, dtype=torch.half, device=device) * 0.05
s_in = fp4_global_scale(torch.randn(1, hidden_size, dtype=torch.half, device=device))
s_w1 = fp4_global_scale(w1)
s_w2 = fp4_global_scale(w2)
w1_fp4, w1_cutlass = torch.ops.trtllm.fp4_quantize(w1, s_w1, 16, False)
w2_fp4, w2_cutlass = torch.ops.trtllm.fp4_quantize(w2, s_w2, 16, False)
self.register_buffer("w1", w1_fp4)
self.register_buffer("w1_is", s_in.to(torch.float32))
self.register_buffer("w1_ws", w1_cutlass)
self.register_buffer("w1_a", (1.0 / (s_in * s_w1)).to(torch.float32))
self.register_buffer("w2", w2_fp4)
self.register_buffer("w2_is", s_in.to(torch.float32))
self.register_buffer("w2_ws", w2_cutlass)
self.register_buffer("w2_a", (1.0 / (s_in * s_w2)).to(torch.float32))
def forward(self, x):
# Sequential linears without SwiGLU pattern
y = torch.ops.auto_deploy.torch_fake_quant_nvfp4_linear(
x, self.w1, None, [self.w1_is], [self.w1_ws, self.w1_a], [], []
)
y = torch.nn.functional.relu(y)
return torch.ops.auto_deploy.torch_fake_quant_nvfp4_linear(
y, self.w2, None, [self.w2_is], [self.w2_ws, self.w2_a], [], []
)
model = NonSwiGLUModel().to("cuda")
x = torch.randn(2, hidden_size, device="cuda", dtype=torch.float16)
gm = torch_export_to_gm(model, args=(x,), clone=True)
gm_result = InferenceOptimizer(
None,
{
"match_nvfp4_swiglu_pattern": {
"stage": "pattern_matcher",
},
},
)(None, gm)
# No SwiGLU ops should be found
assert _count_ops(gm_result, torch.ops.auto_deploy.torch_nvfp4_swiglu_mlp.default) == 0, (
"Non-SwiGLU NVFP4 pattern should not match"
)
# Original NVFP4 linear ops should still be present
assert _count_ops(gm_result, torch.ops.auto_deploy.torch_fake_quant_nvfp4_linear) == 2, (
"Original NVFP4 linear ops should be unchanged"
)

View File

@ -356,106 +356,98 @@ def test_moe_export_with_reduced_experts(
# Real-model MOE export: GLM4 MoE Lite
# ---------------------------------------------------------------------------
try:
from tensorrt_llm._torch.auto_deploy.models.custom.modeling_glm4_moe_lite import (
Glm4MoeLiteConfig,
Glm4MoeLiteForCausalLM,
from tensorrt_llm._torch.auto_deploy.models.custom.modeling_glm4_moe_lite import ( # noqa: E402
Glm4MoeLiteConfig,
Glm4MoeLiteForCausalLM,
)
def _make_tiny_glm4_config(n_routed_experts: int = 8) -> Glm4MoeLiteConfig:
"""Create a minimal ``Glm4MoeLiteConfig`` suitable for unit tests."""
return Glm4MoeLiteConfig(
vocab_size=256,
hidden_size=64,
intermediate_size=128,
num_hidden_layers=2,
num_attention_heads=4,
num_key_value_heads=4,
q_lora_rank=32,
kv_lora_rank=32,
qk_nope_head_dim=12,
qk_rope_head_dim=4,
v_head_dim=16,
n_routed_experts=n_routed_experts,
n_shared_experts=1,
num_experts_per_tok=2,
moe_intermediate_size=64,
n_group=1,
topk_group=1,
routed_scaling_factor=1.0,
norm_topk_prob=True,
first_k_dense_replace=1, # layer 0 = dense MLP, layer 1 = MoE
max_position_embeddings=128,
rope_scaling=None,
pad_token_id=0,
)
_HAS_GLM4 = True
except ImportError:
_HAS_GLM4 = False
def _count_moe_experts_in_graph(gm: GraphModule) -> int:
"""Return the number of experts in the first ``torch_moe`` call in *gm*."""
for node in gm.graph.nodes:
if node.op == "call_function" and "torch_moe" in str(node.target):
return len(node.args[3]) # w1_weight list length
return 0
if _HAS_GLM4:
@pytest.mark.skipif(
not torch.cuda.is_available(), reason="GLM4 MoE Lite requires CUDA (uses noaux_tc_op)"
)
@pytest.mark.parametrize("n_routed_experts", [8, 16])
@pytest.mark.parametrize("num_moe_experts_for_export", [2])
def test_glm4_moe_lite_export_with_reduced_experts(n_routed_experts, num_moe_experts_for_export):
"""Export a tiny ``Glm4MoeLiteForCausalLM`` with reduced experts and verify
that the expanded graph has the correct structure and accepts the original
state dict.
"""
# GLM4 MoE Lite uses noaux_tc_op which is CUDA-only, so we must use CUDA device
device = "cuda"
config = _make_tiny_glm4_config(n_routed_experts=n_routed_experts)
model = Glm4MoeLiteForCausalLM(config).to(device)
model.eval()
def _make_tiny_glm4_config(n_routed_experts: int = 8) -> Glm4MoeLiteConfig:
"""Create a minimal ``Glm4MoeLiteConfig`` suitable for unit tests."""
return Glm4MoeLiteConfig(
vocab_size=256,
hidden_size=64,
intermediate_size=128,
num_hidden_layers=2,
num_attention_heads=4,
num_key_value_heads=4,
q_lora_rank=32,
kv_lora_rank=32,
qk_nope_head_dim=12,
qk_rope_head_dim=4,
v_head_dim=16,
n_routed_experts=n_routed_experts,
n_shared_experts=1,
num_experts_per_tok=2,
moe_intermediate_size=64,
n_group=1,
topk_group=1,
routed_scaling_factor=1.0,
norm_topk_prob=True,
first_k_dense_replace=1, # layer 0 = dense MLP, layer 1 = MoE
max_position_embeddings=128,
rope_scaling=None,
pad_token_id=0,
)
input_ids = torch.randint(0, config.vocab_size, (1, 8), device=device)
position_ids = torch.arange(8, device=device).unsqueeze(0)
sample_kwargs = {"input_ids": input_ids, "position_ids": position_ids}
def _count_moe_experts_in_graph(gm: GraphModule) -> int:
"""Return the number of experts in the first ``torch_moe`` call in *gm*."""
for node in gm.graph.nodes:
if node.op == "call_function" and "torch_moe" in str(node.target):
return len(node.args[3]) # w1_weight list length
return 0
# --- full export (baseline) ---
gm_full = torch_export_to_gm(model, kwargs=sample_kwargs)
@pytest.mark.skipif(not _HAS_GLM4, reason="GLM4 MoE Lite model not available on this branch")
@pytest.mark.skipif(
not torch.cuda.is_available(), reason="GLM4 MoE Lite requires CUDA (uses noaux_tc_op)"
# --- export with reduced experts ---
gm_reduced = torch_export_to_gm(
model,
kwargs=sample_kwargs,
num_moe_experts_for_export=num_moe_experts_for_export,
)
@pytest.mark.parametrize("n_routed_experts", [8, 16])
@pytest.mark.parametrize("num_moe_experts_for_export", [2])
def test_glm4_moe_lite_export_with_reduced_experts(
n_routed_experts, num_moe_experts_for_export
):
"""Export a tiny ``Glm4MoeLiteForCausalLM`` with reduced experts and verify
that the expanded graph has the correct structure and accepts the original
state dict.
"""
# GLM4 MoE Lite uses noaux_tc_op which is CUDA-only, so we must use CUDA device
device = "cuda"
config = _make_tiny_glm4_config(n_routed_experts=n_routed_experts)
model = Glm4MoeLiteForCausalLM(config).to(device)
model.eval()
input_ids = torch.randint(0, config.vocab_size, (1, 8), device=device)
position_ids = torch.arange(8, device=device).unsqueeze(0)
sample_kwargs = {"input_ids": input_ids, "position_ids": position_ids}
# Structural: both graphs must expose all experts
assert _count_moe_experts_in_graph(gm_full) == n_routed_experts
assert _count_moe_experts_in_graph(gm_reduced) == n_routed_experts
# --- full export (baseline) ---
gm_full = torch_export_to_gm(model, kwargs=sample_kwargs)
# State-dict keys must match between full and reduced exports
full_keys = set(gm_full.state_dict().keys())
reduced_keys = set(gm_reduced.state_dict().keys())
assert full_keys == reduced_keys, (
f"State-dict key mismatch.\n"
f" Only in full: {full_keys - reduced_keys}\n"
f" Only in reduced: {reduced_keys - full_keys}"
)
# --- export with reduced experts ---
gm_reduced = torch_export_to_gm(
model,
kwargs=sample_kwargs,
num_moe_experts_for_export=num_moe_experts_for_export,
)
# Load the original model weights into the reduced export graph
gm_reduced.load_state_dict(model.state_dict(), strict=False)
# Structural: both graphs must expose all experts
assert _count_moe_experts_in_graph(gm_full) == n_routed_experts
assert _count_moe_experts_in_graph(gm_reduced) == n_routed_experts
# State-dict keys must match between full and reduced exports
full_keys = set(gm_full.state_dict().keys())
reduced_keys = set(gm_reduced.state_dict().keys())
assert full_keys == reduced_keys, (
f"State-dict key mismatch.\n"
f" Only in full: {full_keys - reduced_keys}\n"
f" Only in reduced: {reduced_keys - full_keys}"
)
# Load the original model weights into the reduced export graph
gm_reduced.load_state_dict(model.state_dict(), strict=False)
# Source model must be fully restored
for name, mod in model.named_modules():
if hasattr(mod, "experts") and isinstance(mod.experts, nn.ModuleList):
assert len(mod.experts) == n_routed_experts, (
f"Expert list in '{name}' was not restored to {n_routed_experts}"
)
# Source model must be fully restored
for name, mod in model.named_modules():
if hasattr(mod, "experts") and isinstance(mod.experts, nn.ModuleList):
assert len(mod.experts) == n_routed_experts, (
f"Expert list in '{name}' was not restored to {n_routed_experts}"
)
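# Hand-rolled check equivalent to the assertions above (illustrative sketch): the
# reduced-expert export must still expose every routed expert's w1 weights in the
# first torch_moe call of the graph.
#   node = next(n for n in gm_reduced.graph.nodes if "torch_moe" in str(n.target))
#   assert len(node.args[3]) == config.n_routed_experts  # w1_weight list length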

View File

@ -0,0 +1,163 @@
"""Tests for ``create_derived_custom_op`` in ``_graph.py``."""
import torch
from torch._subclasses import FakeTensorMode
from tensorrt_llm._torch.auto_deploy.utils._graph import create_derived_custom_op
# ---------------------------------------------------------------------------
# Helpers: tiny custom ops used as base ops for the tests
# ---------------------------------------------------------------------------
@torch.library.custom_op("ad_test_derived::double", mutates_args=())
def _double(x: torch.Tensor) -> torch.Tensor:
return x * 2
@_double.register_fake
def _double_fake(x: torch.Tensor) -> torch.Tensor:
return torch.empty_like(x)
@torch.library.custom_op("ad_test_derived::weighted_add", mutates_args=())
def _weighted_add(x: torch.Tensor, y: torch.Tensor, alpha: float = 1.0) -> torch.Tensor:
return x + alpha * y
@_weighted_add.register_fake
def _weighted_add_fake(x: torch.Tensor, y: torch.Tensor, alpha: float = 1.0) -> torch.Tensor:
return torch.empty_like(x)
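# Quick eager sanity check of the two base ops above (illustrative only):
#   torch.ops.ad_test_derived.double(torch.tensor([1.0, 2.0]))   # -> tensor([2., 4.])
#   torch.ops.ad_test_derived.weighted_add(x, y, alpha=0.5)      # -> x + 0.5 * y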
# ---------------------------------------------------------------------------
# Tests
# ---------------------------------------------------------------------------
class TestCreateDerivedCustomOp:
"""Tests for the ``create_derived_custom_op`` utility."""
def test_basic_derived_op(self):
"""A derived op should be callable and produce correct results."""
def make_impl(base_overload):
# Wrapper that calls the base op then negates the result.
def impl(*args, **kwargs):
return -base_overload(*args, **kwargs)
return impl
base_op = torch.ops.ad_test_derived.double
derived = create_derived_custom_op(base_op, "_neg", make_impl)
x = torch.tensor([1.0, 2.0, 3.0])
result = derived(x)
expected = -(x * 2)
torch.testing.assert_close(result, expected)
def test_derived_op_is_registered(self):
"""The derived op must be accessible via ``torch.ops``."""
def make_impl(base_overload):
return lambda *a, **kw: base_overload(*a, **kw)
create_derived_custom_op(torch.ops.ad_test_derived.double, "_registered", make_impl)
assert hasattr(torch.ops.ad_test_derived, "double_registered")
def test_caching(self):
"""Repeated calls with the same base_op + suffix must return the same object."""
def make_impl(base_overload):
return lambda *a, **kw: base_overload(*a, **kw)
op1 = create_derived_custom_op(torch.ops.ad_test_derived.double, "_cached", make_impl)
op2 = create_derived_custom_op(torch.ops.ad_test_derived.double, "_cached", make_impl)
assert op1 is op2
def test_different_suffix_produces_different_op(self):
"""Different suffixes must create distinct ops."""
def make_impl(base_overload):
return lambda *a, **kw: base_overload(*a, **kw)
op_a = create_derived_custom_op(torch.ops.ad_test_derived.double, "_sfx_a", make_impl)
op_b = create_derived_custom_op(torch.ops.ad_test_derived.double, "_sfx_b", make_impl)
assert op_a is not op_b
def test_default_fake_implementation(self):
"""When *make_fake* is None the default (empty_like) must be used."""
def make_impl(base_overload):
return lambda *a, **kw: base_overload(*a, **kw)
derived = create_derived_custom_op(
torch.ops.ad_test_derived.double, "_dflt_fake", make_impl
)
# Calling the Meta implementation via FakeTensorMode
with FakeTensorMode():
x = torch.empty(4)
out = derived(x)
assert out.shape == x.shape
def test_custom_fake_implementation(self):
"""A user-supplied *make_fake* must override the default."""
def make_impl(base_overload):
return lambda *a, **kw: base_overload(*a, **kw)
# Fake that always returns shape (1,) regardless of input shape.
def make_fake(base_overload):
def fake(*args, **kwargs):
return args[0].new_empty(1)
return fake
derived = create_derived_custom_op(
torch.ops.ad_test_derived.double,
"_custom_fake",
make_impl,
make_fake=make_fake,
)
with FakeTensorMode():
x = torch.empty(10)
out = derived(x)
assert out.shape == (1,)
def test_preserves_schema_with_defaults(self):
"""Derived op must preserve the base op's argument defaults."""
def make_impl(base_overload):
def impl(*args, **kwargs):
return base_overload(*args, **kwargs) * 10
return impl
base_op = torch.ops.ad_test_derived.weighted_add
derived = create_derived_custom_op(base_op, "_x10", make_impl)
x = torch.ones(3)
y = torch.ones(3) * 2.0
# With default alpha=1.0 → (x + 1.0*y) * 10 = 30
result_default = derived(x, y)
torch.testing.assert_close(result_default, torch.full((3,), 30.0))
# With explicit alpha=0.5 → (x + 0.5*y) * 10 = 20
result_alpha = derived(x, y, alpha=0.5)
torch.testing.assert_close(result_alpha, torch.full((3,), 20.0))
def test_accepts_op_overload(self):
"""The function should accept an OpOverload (e.g. ``.default``) as well."""
def make_impl(base_overload):
return lambda *a, **kw: base_overload(*a, **kw) + 1
derived = create_derived_custom_op(
torch.ops.ad_test_derived.double.default, "_from_overload", make_impl
)
x = torch.tensor([5.0])
# double → 10, +1 → 11
torch.testing.assert_close(derived(x), torch.tensor([11.0]))
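# Recap of the call pattern these tests exercise (sketch; post_process is a
# hypothetical stand-in for whatever the derived op adds on top of the base op):
#   def make_impl(base_overload):
#       def impl(*args, **kwargs):
#           return post_process(base_overload(*args, **kwargs))
#       return impl
#   derived = create_derived_custom_op(base_op, "_suffix", make_impl, make_fake=None)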