Mirror of https://github.com/NVIDIA/TensorRT-LLM.git (synced 2026-02-18 16:55:08 +08:00)
Signed-off-by: Suyog Gupta <41447211+suyoggupta@users.noreply.github.com>
This commit is contained in:
parent fcb7bea07f
commit f3d784c6f6
@@ -4,5 +4,19 @@ max_seq_len: 4096
enable_chunked_prefill: true
cuda_graph_batch_sizes: [1, 2, 4, 8, 16, 32, 64]
transforms:
  match_swiglu_pattern:
    enabled: true
  match_nvfp4_swiglu_pattern:
    enabled: true
  fuse_nvfp4_moe:
    allow_different_input_scales: true
  fuse_nvfp4_swiglu:
    enabled: true
  fuse_swiglu:
    enabled: true
  multi_stream_moe:
    stage: compile
    enabled: true
  multi_stream_mla_attn:
    stage: compile
    enabled: true
@@ -37,7 +37,7 @@ transforms:
    "fc2_latent_proj": "gather"
  multi_stream_moe:
    stage: compile
-   enabled: false
+   enabled: true
  gather_logits_before_lm_head:
    # TODO: fix https://github.com/NVIDIA/TensorRT-LLM/issues/9878 to enable by default
    enabled: true
@@ -50,6 +50,7 @@ transforms:
    expected_layout: bsnd
  match_rmsnorm_pattern:
    stage: pattern_matcher
    run_shape_prop: true
  match_l2norm_pattern:
    stage: pattern_matcher
  ############################################################################################
@@ -75,6 +76,18 @@ transforms:
    stage: pattern_matcher
  quantize_nvfp4_from_graph:
    stage: pattern_matcher
  # SwiGLU pattern matching must run AFTER quantization transforms. For pre-quantized
  # checkpoints (e.g., NVFP4), quantization converts torch_linear_simple ops to quantized
  # ops first, and then match_nvfp4_swiglu_pattern captures the NVFP4 SwiGLU pattern.
  # For non-quantized models, quantization transforms are no-ops, so match_swiglu_pattern
  # proceeds normally.
  match_swiglu_pattern:
    stage: pattern_matcher
    enabled: false
  match_nvfp4_swiglu_pattern:
    stage: pattern_matcher
    requires_shape_prop: true
    enabled: false
  quantize_fp8_moe:
    stage: pattern_matcher
  quantize_nvfp4_moe:
@@ -126,6 +139,8 @@ transforms:
  fuse_nvfp4_linear:
    stage: post_load_fusion
    backend: trtllm
  fuse_nvfp4_swiglu:
    stage: post_load_fusion
  fuse_moe:
    stage: post_load_fusion
    expect_mem_change: true
@@ -149,6 +164,9 @@ transforms:
  fuse_l2norm:
    stage: post_load_fusion
    backend: fla
  fuse_swiglu:
    stage: post_load_fusion
    enabled: false
  fuse_add_rms_norm:
    stage: post_load_fusion
    enabled: true
@@ -200,6 +218,9 @@ transforms:
  multi_stream_moe:
    stage: compile
    enabled: false
  multi_stream_mla_attn:
    stage: compile
    enabled: false
  compile_model:
    stage: compile
    expect_mem_change: true
@@ -18,9 +18,11 @@
This module provides linear layer implementations:
- linear: Linear layer operations
- torch_router: MoE router operations
- swiglu: SwiGLU MLP custom operations
"""

__all__ = [
    "linear",
    "torch_router",
    "swiglu",
]
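For context, a minimal usage sketch (assuming the package is importable as laid out in this diff): importing the new swiglu submodule is what registers the custom ops under torch.ops.auto_deploy.

import torch
from tensorrt_llm._torch.auto_deploy.custom_ops.linear import swiglu  # noqa: F401

print(torch.ops.auto_deploy.torch_swiglu_mlp)   # registered by the import above
print(torch.ops.auto_deploy.fused_swiglu_mlp)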
tensorrt_llm/_torch/auto_deploy/custom_ops/linear/swiglu.py (new file, 311 lines)
@@ -0,0 +1,311 @@
|
||||
# SPDX-FileCopyrightText: Copyright (c) 2022-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""SwiGLU MLP custom operations for graph transformation.
|
||||
|
||||
This module provides custom operators for SwiGLU MLP fusion:
|
||||
- torch_swiglu_mlp: Intermediate representation after pattern matching
|
||||
- fused_swiglu_mlp: Fused implementation with concatenated gate+up weights
|
||||
"""
|
||||
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
try:
|
||||
from flashinfer.activation import silu_and_mul as _flashinfer_silu_and_mul
|
||||
except ImportError:
|
||||
_flashinfer_silu_and_mul = None
|
||||
|
||||
|
||||
def _silu_and_mul(x: torch.Tensor) -> torch.Tensor:
|
||||
"""SwiGLU activation: split x in half, apply silu to first half, multiply with second half.
|
||||
|
||||
Uses FlashInfer's fused kernel when available, falls back to manual implementation.
|
||||
"""
|
||||
if _flashinfer_silu_and_mul is not None:
|
||||
return _flashinfer_silu_and_mul(x)
|
||||
gate, up = x.chunk(2, dim=-1)
|
||||
return F.silu(gate) * up
|
||||
|
||||
|
||||
@torch.library.custom_op("auto_deploy::torch_swiglu_mlp", mutates_args=())
|
||||
def torch_swiglu_mlp(
|
||||
input: torch.Tensor,
|
||||
gate_weight: torch.Tensor,
|
||||
up_weight: torch.Tensor,
|
||||
down_weight: torch.Tensor,
|
||||
gate_bias: Optional[torch.Tensor],
|
||||
up_bias: Optional[torch.Tensor],
|
||||
down_bias: Optional[torch.Tensor],
|
||||
) -> torch.Tensor:
|
||||
"""Standardized SwiGLU MLP operation.
|
||||
|
||||
Computes: (silu(x @ gate.T + gate_bias) * (x @ up.T + up_bias)) @ down.T + down_bias
|
||||
|
||||
This is the intermediate representation used after pattern matching,
|
||||
before weight fusion is applied.
|
||||
|
||||
Args:
|
||||
input: Input tensor of shape [..., hidden_size].
|
||||
gate_weight: Gate projection weight of shape [intermediate_size, hidden_size].
|
||||
up_weight: Up projection weight of shape [intermediate_size, hidden_size].
|
||||
down_weight: Down projection weight of shape [hidden_size, intermediate_size].
|
||||
gate_bias: Optional gate projection bias of shape [intermediate_size].
|
||||
up_bias: Optional up projection bias of shape [intermediate_size].
|
||||
down_bias: Optional down projection bias of shape [hidden_size].
|
||||
|
||||
Returns:
|
||||
Output tensor of shape [..., hidden_size].
|
||||
"""
|
||||
gate_out = F.linear(input, gate_weight, gate_bias)
|
||||
up_out = F.linear(input, up_weight, up_bias)
|
||||
hidden = F.silu(gate_out) * up_out
|
||||
return F.linear(hidden, down_weight, down_bias)
|
||||
|
||||
|
||||
@torch_swiglu_mlp.register_fake
|
||||
def _(
|
||||
input: torch.Tensor,
|
||||
gate_weight: torch.Tensor,
|
||||
up_weight: torch.Tensor,
|
||||
down_weight: torch.Tensor,
|
||||
gate_bias: Optional[torch.Tensor],
|
||||
up_bias: Optional[torch.Tensor],
|
||||
down_bias: Optional[torch.Tensor],
|
||||
) -> torch.Tensor:
|
||||
"""Fake implementation for tracing."""
|
||||
# Output shape is [..., hidden_size] where hidden_size = down_weight.shape[0]
|
||||
output_shape = list(input.shape[:-1]) + [down_weight.shape[0]]
|
||||
return input.new_empty(output_shape, dtype=input.dtype)
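As a hedged illustration of what this fake implementation enables (assuming the module has been imported so the op is registered), tracing the op under FakeTensorMode only propagates shapes and dtypes:

import torch
from torch._subclasses.fake_tensor import FakeTensorMode

with FakeTensorMode():
    x = torch.empty(2, 7, 4096, dtype=torch.float16)       # fake tensors only
    gate = torch.empty(11008, 4096, dtype=torch.float16)
    up = torch.empty(11008, 4096, dtype=torch.float16)
    down = torch.empty(4096, 11008, dtype=torch.float16)
    out = torch.ops.auto_deploy.torch_swiglu_mlp(x, gate, up, down, None, None, None)
    assert tuple(out.shape) == (2, 7, 4096)                 # no real compute happens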
|
||||
|
||||
|
||||
@torch.library.custom_op("auto_deploy::fused_swiglu_mlp", mutates_args=())
|
||||
def fused_swiglu_mlp(
|
||||
input: torch.Tensor,
|
||||
gate_up_weight: torch.Tensor,
|
||||
down_weight: torch.Tensor,
|
||||
gate_up_bias: Optional[torch.Tensor],
|
||||
down_bias: Optional[torch.Tensor],
|
||||
) -> torch.Tensor:
|
||||
"""Fused SwiGLU MLP with concatenated gate+up weights.
|
||||
|
||||
Performs a single matmul for gate and up projections, then splits the result.
|
||||
Computes: (silu(gate_out) * up_out) @ down.T + down_bias
|
||||
where gate_out, up_out = split(x @ gate_up.T + gate_up_bias)
|
||||
|
||||
Args:
|
||||
input: Input tensor of shape [..., hidden_size].
|
||||
gate_up_weight: Concatenated gate+up weight of shape [2*intermediate_size, hidden_size].
|
||||
down_weight: Down projection weight of shape [hidden_size, intermediate_size].
|
||||
gate_up_bias: Optional concatenated gate+up bias of shape [2*intermediate_size].
|
||||
down_bias: Optional down projection bias of shape [hidden_size].
|
||||
|
||||
Returns:
|
||||
Output tensor of shape [..., hidden_size].
|
||||
"""
|
||||
# Single matmul for both gate and up projections
|
||||
gate_up_out = F.linear(input, gate_up_weight, gate_up_bias)
|
||||
|
||||
# Apply SwiGLU activation: split, silu(gate) * up (uses FlashInfer when available)
|
||||
hidden = _silu_and_mul(gate_up_out)
|
||||
|
||||
# Down projection
|
||||
return F.linear(hidden, down_weight, down_bias)
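A self-contained sketch (plain PyTorch, not the registered ops) of the weight-concatenation identity this fused form relies on:

import torch
import torch.nn.functional as F

torch.manual_seed(0)
hidden, inter = 64, 128
x = torch.randn(2, hidden)
gate_w, up_w = torch.randn(inter, hidden), torch.randn(inter, hidden)
down_w = torch.randn(hidden, inter)

# Unfused reference: separate gate and up matmuls.
ref = F.linear(F.silu(F.linear(x, gate_w)) * F.linear(x, up_w), down_w)

# Fused form: one matmul against [gate; up] stacked on dim 0, then split.
gate_out, up_out = F.linear(x, torch.cat([gate_w, up_w], dim=0)).chunk(2, dim=-1)
fused = F.linear(F.silu(gate_out) * up_out, down_w)

assert torch.allclose(ref, fused, atol=1e-5)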
|
||||
|
||||
|
||||
@fused_swiglu_mlp.register_fake
|
||||
def _(
|
||||
input: torch.Tensor,
|
||||
gate_up_weight: torch.Tensor,
|
||||
down_weight: torch.Tensor,
|
||||
gate_up_bias: Optional[torch.Tensor],
|
||||
down_bias: Optional[torch.Tensor],
|
||||
) -> torch.Tensor:
|
||||
"""Fake implementation for tracing."""
|
||||
# Output shape is [..., hidden_size] where hidden_size = down_weight.shape[0]
|
||||
output_shape = list(input.shape[:-1]) + [down_weight.shape[0]]
|
||||
return input.new_empty(output_shape, dtype=input.dtype)
|
||||
|
||||
|
||||
# ── NVFP4 quantized SwiGLU ops ──────────────────────────────────────────────
|
||||
|
||||
|
||||
@torch.library.custom_op("auto_deploy::torch_nvfp4_swiglu_mlp", mutates_args=())
|
||||
def torch_nvfp4_swiglu_mlp(
|
||||
input: torch.Tensor,
|
||||
gate_weight: torch.Tensor,
|
||||
up_weight: torch.Tensor,
|
||||
down_weight: torch.Tensor,
|
||||
gate_input_scale: torch.Tensor,
|
||||
gate_weight_scale: torch.Tensor,
|
||||
gate_alpha: torch.Tensor,
|
||||
up_input_scale: torch.Tensor,
|
||||
up_weight_scale: torch.Tensor,
|
||||
up_alpha: torch.Tensor,
|
||||
down_input_scale: torch.Tensor,
|
||||
down_weight_scale: torch.Tensor,
|
||||
down_alpha: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
"""NVFP4 quantized SwiGLU MLP operation (intermediate representation).
|
||||
|
||||
Computes: silu(nvfp4_linear(x, gate)) * nvfp4_linear(x, up) -> nvfp4_linear(down)
|
||||
|
||||
This is the intermediate representation used after pattern matching for NVFP4
|
||||
quantized checkpoints, before gate+up weight fusion is applied.
|
||||
|
||||
Args:
|
||||
input: Input tensor of shape [..., hidden_size].
|
||||
gate_weight: FP4 packed gate weight [intermediate_size, hidden_size/2] uint8.
|
||||
up_weight: FP4 packed up weight [intermediate_size, hidden_size/2] uint8.
|
||||
down_weight: FP4 packed down weight [hidden_size, intermediate_size/2] uint8.
|
||||
gate_input_scale: Input scale for gate projection.
|
||||
gate_weight_scale: Per-block weight scale for gate projection.
|
||||
gate_alpha: Alpha (combined scale) for gate projection.
|
||||
up_input_scale: Input scale for up projection.
|
||||
up_weight_scale: Per-block weight scale for up projection.
|
||||
up_alpha: Alpha (combined scale) for up projection.
|
||||
down_input_scale: Input scale for down projection.
|
||||
down_weight_scale: Per-block weight scale for down projection.
|
||||
down_alpha: Alpha (combined scale) for down projection.
|
||||
|
||||
Returns:
|
||||
Output tensor of shape [..., hidden_size].
|
||||
"""
|
||||
gate_out = torch.ops.auto_deploy.torch_fake_quant_nvfp4_linear(
|
||||
input,
|
||||
gate_weight,
|
||||
None,
|
||||
input_scale=[gate_input_scale],
|
||||
weight_scale=[gate_weight_scale, gate_alpha],
|
||||
input_zp=[],
|
||||
weight_zp=[],
|
||||
)
|
||||
up_out = torch.ops.auto_deploy.torch_fake_quant_nvfp4_linear(
|
||||
input,
|
||||
up_weight,
|
||||
None,
|
||||
input_scale=[up_input_scale],
|
||||
weight_scale=[up_weight_scale, up_alpha],
|
||||
input_zp=[],
|
||||
weight_zp=[],
|
||||
)
|
||||
hidden = F.silu(gate_out) * up_out
|
||||
return torch.ops.auto_deploy.torch_fake_quant_nvfp4_linear(
|
||||
hidden,
|
||||
down_weight,
|
||||
None,
|
||||
input_scale=[down_input_scale],
|
||||
weight_scale=[down_weight_scale, down_alpha],
|
||||
input_zp=[],
|
||||
weight_zp=[],
|
||||
)
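Illustrative shape bookkeeping for the FP4-packed arguments documented above (sizes are hypothetical; two 4-bit values are packed per uint8 byte, with one per-block scale per 16 elements of the reduction dimension):

hidden_size, intermediate_size = 4096, 11008
gate_packed_shape = (intermediate_size, hidden_size // 2)        # uint8 storage
down_packed_shape = (hidden_size, intermediate_size // 2)        # uint8 storage
gate_weight_scale_len = intermediate_size * (hidden_size // 16)  # per-block scales
down_weight_scale_len = hidden_size * (intermediate_size // 16)
assert gate_packed_shape[1] * 2 == hidden_size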
|
||||
|
||||
|
||||
@torch_nvfp4_swiglu_mlp.register_fake
|
||||
def _(
|
||||
input: torch.Tensor,
|
||||
gate_weight: torch.Tensor,
|
||||
up_weight: torch.Tensor,
|
||||
down_weight: torch.Tensor,
|
||||
gate_input_scale: torch.Tensor,
|
||||
gate_weight_scale: torch.Tensor,
|
||||
gate_alpha: torch.Tensor,
|
||||
up_input_scale: torch.Tensor,
|
||||
up_weight_scale: torch.Tensor,
|
||||
up_alpha: torch.Tensor,
|
||||
down_input_scale: torch.Tensor,
|
||||
down_weight_scale: torch.Tensor,
|
||||
down_alpha: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
"""Fake implementation for tracing."""
|
||||
# Output shape: [..., hidden_size] where hidden_size = down_weight.shape[0]
|
||||
output_shape = list(input.shape[:-1]) + [down_weight.shape[0]]
|
||||
return input.new_empty(output_shape, dtype=input.dtype)
|
||||
|
||||
|
||||
@torch.library.custom_op("auto_deploy::fused_nvfp4_swiglu_mlp", mutates_args=())
|
||||
def fused_nvfp4_swiglu_mlp(
|
||||
input: torch.Tensor,
|
||||
gate_up_weight: torch.Tensor,
|
||||
down_weight: torch.Tensor,
|
||||
gate_up_input_scale: torch.Tensor,
|
||||
gate_up_weight_scale: torch.Tensor,
|
||||
gate_up_alpha: torch.Tensor,
|
||||
down_input_scale: torch.Tensor,
|
||||
down_weight_scale: torch.Tensor,
|
||||
down_alpha: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
"""Fused NVFP4 SwiGLU MLP with concatenated gate+up weights.
|
||||
|
||||
Performs a single NVFP4 matmul for gate and up projections, then splits,
|
||||
applies SwiGLU activation, and does the down NVFP4 matmul.
|
||||
|
||||
Args:
|
||||
input: Input tensor of shape [..., hidden_size].
|
||||
gate_up_weight: Concatenated FP4 packed gate+up weight
|
||||
[2*intermediate_size, hidden_size/2] uint8.
|
||||
down_weight: FP4 packed down weight [hidden_size, intermediate_size/2] uint8.
|
||||
gate_up_input_scale: Shared input scale for gate+up projection.
|
||||
gate_up_weight_scale: Concatenated per-block weight scale for gate+up.
|
||||
gate_up_alpha: Shared alpha for gate+up projection.
|
||||
down_input_scale: Input scale for down projection.
|
||||
down_weight_scale: Per-block weight scale for down projection.
|
||||
down_alpha: Alpha for down projection.
|
||||
|
||||
Returns:
|
||||
Output tensor of shape [..., hidden_size].
|
||||
"""
|
||||
# Single NVFP4 linear for both gate and up projections
|
||||
gate_up_out = torch.ops.auto_deploy.torch_quant_nvfp4_linear(
|
||||
input,
|
||||
gate_up_weight,
|
||||
bias=None,
|
||||
input_scale=gate_up_input_scale,
|
||||
weight_scale=gate_up_weight_scale,
|
||||
alpha=gate_up_alpha,
|
||||
)
|
||||
|
||||
# Apply SwiGLU activation: split, silu(gate) * up (uses FlashInfer when available)
|
||||
hidden = _silu_and_mul(gate_up_out)
|
||||
|
||||
# Down projection
|
||||
return torch.ops.auto_deploy.torch_quant_nvfp4_linear(
|
||||
hidden,
|
||||
down_weight,
|
||||
bias=None,
|
||||
input_scale=down_input_scale,
|
||||
weight_scale=down_weight_scale,
|
||||
alpha=down_alpha,
|
||||
)
|
||||
|
||||
|
||||
@fused_nvfp4_swiglu_mlp.register_fake
|
||||
def _(
|
||||
input: torch.Tensor,
|
||||
gate_up_weight: torch.Tensor,
|
||||
down_weight: torch.Tensor,
|
||||
gate_up_input_scale: torch.Tensor,
|
||||
gate_up_weight_scale: torch.Tensor,
|
||||
gate_up_alpha: torch.Tensor,
|
||||
down_input_scale: torch.Tensor,
|
||||
down_weight_scale: torch.Tensor,
|
||||
down_alpha: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
"""Fake implementation for tracing."""
|
||||
# Output shape: [..., hidden_size] where hidden_size = down_weight.shape[0]
|
||||
output_shape = list(input.shape[:-1]) + [down_weight.shape[0]]
|
||||
return input.new_empty(output_shape, dtype=input.dtype)
|
||||
@@ -171,7 +171,7 @@ def _find_moe_module_lists(
|
||||
|
||||
def _reduce_moe_experts(
|
||||
model: nn.Module,
|
||||
min_num_experts: int,
|
||||
num_moe_experts_for_export: int,
|
||||
args: Optional[Tuple[Any, ...]] = None,
|
||||
kwargs: Optional[Dict[str, Any]] = None,
|
||||
) -> List[Dict[str, Any]]:
|
||||
@@ -179,19 +179,21 @@ def _reduce_moe_experts(
|
||||
|
||||
Uses a probe forward pass to identify which ``nn.ModuleList`` instances
|
||||
feed into ``torch_moe``-family custom ops (see :func:`_find_moe_module_lists`),
|
||||
then truncates each to *min_num_experts* entries. The returned list of dicts
|
||||
carries the metadata needed by :func:`_restore_moe_experts` and
|
||||
:func:`_expand_moe_experts_in_graph`.
|
||||
then truncates each to *num_moe_experts_for_export* entries. The returned
|
||||
list of dicts carries the metadata needed by :func:`_restore_moe_experts`
|
||||
and :func:`_expand_moe_experts_in_graph`.
|
||||
"""
|
||||
if min_num_experts < 1:
|
||||
raise ValueError(f"min_num_experts must be >= 1, got {min_num_experts}")
|
||||
if num_moe_experts_for_export < 1:
|
||||
raise ValueError(
|
||||
f"num_moe_experts_for_export must be >= 1, got {num_moe_experts_for_export}"
|
||||
)
|
||||
|
||||
moe_lists = _find_moe_module_lists(model, args, kwargs)
|
||||
|
||||
reductions: List[Dict[str, Any]] = []
|
||||
for path, (parent, attr_name, mod_list) in moe_lists.items():
|
||||
orig_count = len(mod_list)
|
||||
if orig_count <= min_num_experts:
|
||||
if orig_count <= num_moe_experts_for_export:
|
||||
continue
|
||||
|
||||
reductions.append(
|
||||
@@ -203,10 +205,10 @@ def _reduce_moe_experts(
|
||||
"expert_prefix": path,
|
||||
}
|
||||
)
|
||||
setattr(parent, attr_name, nn.ModuleList(list(mod_list[:min_num_experts])))
|
||||
setattr(parent, attr_name, nn.ModuleList(list(mod_list[:num_moe_experts_for_export])))
|
||||
ad_logger.info(
|
||||
f"Reduced MOE experts in '{path}' from {orig_count} to "
|
||||
f"{min_num_experts} for faster export"
|
||||
f"{num_moe_experts_for_export} for faster export"
|
||||
)
|
||||
return reductions
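A minimal sketch of the truncation step described in the docstring, on a standalone ModuleList (the real transform locates these lists via a probe forward pass and keeps the metadata needed to restore them):

import torch.nn as nn

num_moe_experts_for_export = 2
experts = nn.ModuleList(nn.Linear(16, 16) for _ in range(64))
if len(experts) > num_moe_experts_for_export:
    experts = nn.ModuleList(list(experts[:num_moe_experts_for_export]))
assert len(experts) == num_moe_experts_for_export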
|
||||
|
||||
@@ -252,7 +254,12 @@ def _expand_moe_experts_in_graph(
|
||||
if not reductions:
|
||||
return
|
||||
|
||||
# MOE ops whose arguments include per-expert weight lists (from index 3 onward)
|
||||
# MOE ops whose arguments include per-expert weight lists.
|
||||
# All these ops share the same first 3 positional args (x, selected_experts,
|
||||
# routing_weights) which are plain Tensors, followed by one or more
|
||||
# List[Tensor] args that hold per-expert weights/scales. We use the op
|
||||
# schema to discover which arguments are Tensor[] rather than hard-coding
|
||||
# the starting index.
|
||||
moe_ops = {
|
||||
torch.ops.auto_deploy.torch_moe,
|
||||
torch.ops.auto_deploy.torch_quant_fp8_moe,
|
||||
@@ -266,11 +273,18 @@ def _expand_moe_experts_in_graph(
|
||||
if not is_op(node, moe_ops):
|
||||
continue
|
||||
|
||||
# Collect indices of list-of-node arguments (expert weight/scale lists)
|
||||
# Collect indices of List[Tensor] arguments from the op schema – these
|
||||
# are the per-expert weight / scale lists.
|
||||
op = node.target
|
||||
schema = op._schema if hasattr(op, "_schema") else next(iter(op._schemas.values()))
|
||||
_tensor_list_types = ("Tensor[]", "List[Tensor]")
|
||||
list_arg_indices = [
|
||||
i
|
||||
for i in range(3, len(node.args))
|
||||
if isinstance(node.args[i], (list, tuple)) and len(node.args[i]) > 0
|
||||
for i, arg_meta in enumerate(schema.arguments)
|
||||
if any(t in str(arg_meta.type) for t in _tensor_list_types)
|
||||
and i < len(node.args)
|
||||
and isinstance(node.args[i], (list, tuple))
|
||||
and len(node.args[i]) > 0
|
||||
]
|
||||
if not list_arg_indices:
|
||||
continue
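A standalone illustration of the schema-based discovery described in the comment above, using aten::cat as a stand-in op (its first argument has type Tensor[]):

import torch

op = torch.ops.aten.cat.default
schema = op._schema
_tensor_list_types = ("Tensor[]", "List[Tensor]")
list_arg_indices = [
    i
    for i, arg_meta in enumerate(schema.arguments)
    if any(t in str(arg_meta.type) for t in _tensor_list_types)
]
print(list_arg_indices)  # expected: [0], the `tensors` argument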
|
||||
@@ -673,6 +687,7 @@ def torch_export_to_gm(
|
||||
return egm
|
||||
|
||||
# Optionally reduce MOE experts for faster export tracing
|
||||
# TODO (https://github.com/NVIDIA/TensorRT-LLM/issues/7547): Reuse the export patch system
|
||||
moe_reductions: List[Dict[str, Any]] = []
|
||||
if num_moe_experts_for_export is not None:
|
||||
moe_reductions = _reduce_moe_experts(model, num_moe_experts_for_export, args, kwargs)
|
||||
|
||||
@@ -444,7 +444,7 @@ class Glm4MoeLiteMoE(nn.Module):
|
||||
if self.shared_experts is not None:
|
||||
final_hidden_states = final_hidden_states + self.shared_experts(identity)
|
||||
|
||||
return final_hidden_states.to(hidden_states.dtype)
|
||||
return final_hidden_states
|
||||
|
||||
|
||||
class Glm4MoeLiteAttention(nn.Module):
|
||||
|
||||
tensorrt_llm/_torch/auto_deploy/transform/library/fuse_swiglu.py (new file, 618 lines)
@@ -0,0 +1,618 @@
|
||||
# SPDX-FileCopyrightText: Copyright (c) 2022-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Graph transforms for SwiGLU MLP fusion.
|
||||
|
||||
This module provides two-stage transformation for SwiGLU MLP:
|
||||
1. MatchSwiGLUPattern: Detects SwiGLU patterns and replaces with torch_swiglu_mlp
|
||||
2. FuseSwiGLU: Fuses gate+up weights into a single concatenated matmul
|
||||
|
||||
The SwiGLU pattern is: silu(x @ gate.T) * (x @ up.T) @ down.T
|
||||
"""
|
||||
|
||||
from typing import Tuple, Type
|
||||
|
||||
import torch
|
||||
from pydantic import Field
|
||||
from torch.fx import GraphModule, Node
|
||||
|
||||
# Import the custom ops to ensure they are registered and for use in replacements
|
||||
from ...custom_ops.linear.swiglu import torch_swiglu_mlp
|
||||
from ...models.factory import ModelFactory
|
||||
from ...shim.interface import CachedSequenceInterface
|
||||
from ...utils._graph import (
|
||||
del_attr_by_name,
|
||||
delete_all_unused_submodules,
|
||||
eliminate_dead_code,
|
||||
get_attr_by_name,
|
||||
)
|
||||
from ...utils.node_utils import is_op
|
||||
from ...utils.pattern_matcher import ADPatternMatcherPass, register_ad_pattern
|
||||
from ..interface import (
|
||||
BaseTransform,
|
||||
SharedConfig,
|
||||
TransformConfig,
|
||||
TransformInfo,
|
||||
TransformRegistry,
|
||||
)
|
||||
|
||||
|
||||
def _try_free_attr_node(gm: GraphModule, graph, attr_node: Node) -> None:
|
||||
"""Erase a get_attr node and eagerly delete its module attribute if it has no users.
|
||||
|
||||
This is used to free unfused weight tensors as soon as they are no longer
|
||||
referenced in the graph, avoiding a temporary memory spike that would occur
|
||||
if cleanup were deferred until after all nodes are processed.
|
||||
"""
|
||||
if attr_node is not None and attr_node.op == "get_attr" and len(attr_node.users) == 0:
|
||||
target = attr_node.target
|
||||
graph.erase_node(attr_node)
|
||||
del_attr_by_name(gm, target)
|
||||
|
||||
|
||||
def _swiglu_pattern_no_bias(x, gate_weight, up_weight, down_weight):
|
||||
"""Pattern for SwiGLU MLP without biases.
|
||||
|
||||
Matches: silu(linear(x, gate_weight, None)) * linear(x, up_weight, None) -> linear(down_weight, None)
|
||||
"""
|
||||
gate_out = torch.ops.auto_deploy.torch_linear_simple.default(x, gate_weight, None)
|
||||
up_out = torch.ops.auto_deploy.torch_linear_simple.default(x, up_weight, None)
|
||||
silu_out = torch.ops.aten.silu.default(gate_out)
|
||||
mul_out = torch.ops.aten.mul.Tensor(silu_out, up_out)
|
||||
down_out = torch.ops.auto_deploy.torch_linear_simple.default(mul_out, down_weight, None)
|
||||
return down_out
|
||||
|
||||
|
||||
def _swiglu_replacement_no_bias(x, gate_weight, up_weight, down_weight):
|
||||
"""Replacement for SwiGLU pattern without biases."""
|
||||
# Call the Python wrapper directly, not via torch.ops.auto_deploy
|
||||
# This ensures proper FakeTensor mode handling during tracing
|
||||
return torch_swiglu_mlp(x, gate_weight, up_weight, down_weight, None, None, None)
|
||||
|
||||
|
||||
def _swiglu_pattern_with_bias(
|
||||
x, gate_weight, up_weight, down_weight, gate_bias, up_bias, down_bias
|
||||
):
|
||||
"""Pattern for SwiGLU MLP with biases.
|
||||
|
||||
Matches: silu(linear(x, gate_weight, gate_bias)) * linear(x, up_weight, up_bias) -> linear(down_weight, down_bias)
|
||||
"""
|
||||
gate_out = torch.ops.auto_deploy.torch_linear_simple.default(x, gate_weight, gate_bias)
|
||||
up_out = torch.ops.auto_deploy.torch_linear_simple.default(x, up_weight, up_bias)
|
||||
silu_out = torch.ops.aten.silu.default(gate_out)
|
||||
mul_out = torch.ops.aten.mul.Tensor(silu_out, up_out)
|
||||
down_out = torch.ops.auto_deploy.torch_linear_simple.default(mul_out, down_weight, down_bias)
|
||||
return down_out
|
||||
|
||||
|
||||
def _swiglu_replacement_with_bias(
|
||||
x, gate_weight, up_weight, down_weight, gate_bias, up_bias, down_bias
|
||||
):
|
||||
"""Replacement for SwiGLU pattern with biases."""
|
||||
# Call the Python wrapper directly, not via torch.ops.auto_deploy
|
||||
# This ensures proper FakeTensor mode handling during tracing
|
||||
return torch_swiglu_mlp(x, gate_weight, up_weight, down_weight, gate_bias, up_bias, down_bias)
|
||||
|
||||
|
||||
@TransformRegistry.register("match_swiglu_pattern")
|
||||
class MatchSwiGLUPattern(BaseTransform):
|
||||
"""Matches SwiGLU MLP patterns and replaces with torch_swiglu_mlp op.
|
||||
|
||||
This transform runs in the pattern_matcher stage and detects the following pattern:
|
||||
silu(x @ gate.T) * (x @ up.T) @ down.T
|
||||
|
||||
And replaces it with a single torch_swiglu_mlp op that can be fused later.
|
||||
|
||||
Uses ADPatternMatcherPass for declarative pattern matching.
|
||||
"""
|
||||
|
||||
config: TransformConfig
|
||||
|
||||
@classmethod
|
||||
def get_config_class(cls) -> Type[TransformConfig]:
|
||||
return TransformConfig
|
||||
|
||||
def _apply(
|
||||
self,
|
||||
gm: GraphModule,
|
||||
cm: CachedSequenceInterface,
|
||||
factory: ModelFactory,
|
||||
shared_config: SharedConfig,
|
||||
) -> Tuple[GraphModule, TransformInfo]:
|
||||
patterns = ADPatternMatcherPass()
|
||||
|
||||
# Dummy shapes for tracing - shapes don't matter for matching
|
||||
hidden, intermediate = 128, 256
|
||||
|
||||
# Pattern 1: SwiGLU without biases (most common case)
|
||||
dummy_args_no_bias = [
|
||||
torch.randn(2, hidden, device="meta", dtype=torch.float16), # x
|
||||
torch.randn(intermediate, hidden, device="meta", dtype=torch.float16), # gate_weight
|
||||
torch.randn(intermediate, hidden, device="meta", dtype=torch.float16), # up_weight
|
||||
torch.randn(hidden, intermediate, device="meta", dtype=torch.float16), # down_weight
|
||||
]
|
||||
register_ad_pattern(
|
||||
search_fn=_swiglu_pattern_no_bias,
|
||||
replace_fn=_swiglu_replacement_no_bias,
|
||||
patterns=patterns,
|
||||
dummy_args=dummy_args_no_bias,
|
||||
)
|
||||
|
||||
# Pattern 2: SwiGLU with biases
|
||||
dummy_args_with_bias = [
|
||||
torch.randn(2, hidden, device="meta", dtype=torch.float16), # x
|
||||
torch.randn(intermediate, hidden, device="meta", dtype=torch.float16), # gate_weight
|
||||
torch.randn(intermediate, hidden, device="meta", dtype=torch.float16), # up_weight
|
||||
torch.randn(hidden, intermediate, device="meta", dtype=torch.float16), # down_weight
|
||||
torch.randn(intermediate, device="meta", dtype=torch.float16), # gate_bias
|
||||
torch.randn(intermediate, device="meta", dtype=torch.float16), # up_bias
|
||||
torch.randn(hidden, device="meta", dtype=torch.float16), # down_bias
|
||||
]
|
||||
register_ad_pattern(
|
||||
search_fn=_swiglu_pattern_with_bias,
|
||||
replace_fn=_swiglu_replacement_with_bias,
|
||||
patterns=patterns,
|
||||
dummy_args=dummy_args_with_bias,
|
||||
)
|
||||
|
||||
num_matches = patterns.apply(gm.graph)
|
||||
|
||||
if num_matches > 0:
|
||||
gm.recompile()
|
||||
|
||||
info = TransformInfo(
|
||||
skipped=False,
|
||||
num_matches=num_matches,
|
||||
is_clean=num_matches == 0,
|
||||
has_valid_shapes=num_matches == 0,
|
||||
)
|
||||
|
||||
return gm, info
|
||||
|
||||
|
||||
class FuseSwiGLUConfig(TransformConfig):
|
||||
"""Configuration for the SwiGLU fusion transform."""
|
||||
|
||||
enabled: bool = Field(
|
||||
default=True,
|
||||
description="Whether to enable SwiGLU fusion.",
|
||||
)
|
||||
|
||||
|
||||
@TransformRegistry.register("fuse_swiglu")
|
||||
class FuseSwiGLU(BaseTransform):
|
||||
"""Fuses torch_swiglu_mlp ops by concatenating gate and up weights.
|
||||
|
||||
This transform runs in the post_load_fusion stage and replaces torch_swiglu_mlp ops
|
||||
with fused_swiglu_mlp ops that use a single concatenated gate+up weight matrix.
|
||||
|
||||
This reduces memory bandwidth by performing a single matmul instead of two
|
||||
separate matmuls for gate and up projections.
|
||||
"""
|
||||
|
||||
config: FuseSwiGLUConfig
|
||||
|
||||
@classmethod
|
||||
def get_config_class(cls) -> Type[FuseSwiGLUConfig]:
|
||||
return FuseSwiGLUConfig
|
||||
|
||||
def _apply(
|
||||
self,
|
||||
gm: GraphModule,
|
||||
cm: CachedSequenceInterface,
|
||||
factory: ModelFactory,
|
||||
shared_config: SharedConfig,
|
||||
) -> Tuple[GraphModule, TransformInfo]:
|
||||
if not self.config.enabled:
|
||||
return gm, TransformInfo(skipped=True, num_matches=0)
|
||||
|
||||
graph = gm.graph
|
||||
cnt = 0
|
||||
fused_weight_idx = 0
|
||||
|
||||
for node in list(graph.nodes):
|
||||
if not is_op(node, torch.ops.auto_deploy.torch_swiglu_mlp.default):
|
||||
continue
|
||||
|
||||
# Extract args: (input, gate_weight, up_weight, down_weight, gate_bias, up_bias, down_bias)
|
||||
input_node = node.args[0]
|
||||
gate_weight_node = node.args[1]
|
||||
up_weight_node = node.args[2]
|
||||
down_weight_node = node.args[3]
|
||||
gate_bias_node = node.args[4] if len(node.args) > 4 else None
|
||||
up_bias_node = node.args[5] if len(node.args) > 5 else None
|
||||
down_bias_node = node.args[6] if len(node.args) > 6 else None
|
||||
|
||||
# Get the actual weight tensors
|
||||
gate_weight = get_attr_by_name(gm, gate_weight_node.target)
|
||||
up_weight = get_attr_by_name(gm, up_weight_node.target)
|
||||
|
||||
# Concatenate gate and up weights: [intermediate, hidden] -> [2*intermediate, hidden]
|
||||
gate_up_weight = torch.cat([gate_weight, up_weight], dim=0)
|
||||
|
||||
# Create new attribute for the fused weight
|
||||
fused_weight_name = f"fused_swiglu_gate_up_{fused_weight_idx}"
|
||||
gm.register_buffer(fused_weight_name, gate_up_weight)
|
||||
|
||||
# Handle biases
|
||||
gate_up_bias_node = None
|
||||
fused_bias_name = None
|
||||
|
||||
if gate_bias_node is not None and gate_bias_node.op == "get_attr":
|
||||
gate_bias = get_attr_by_name(gm, gate_bias_node.target)
|
||||
up_bias = get_attr_by_name(gm, up_bias_node.target) if up_bias_node else None
|
||||
|
||||
if up_bias is not None:
|
||||
gate_up_bias = torch.cat([gate_bias, up_bias], dim=0)
|
||||
fused_bias_name = f"fused_swiglu_gate_up_bias_{fused_weight_idx}"
|
||||
gm.register_buffer(fused_bias_name, gate_up_bias)
|
||||
|
||||
# Create get_attr node for the fused weight
|
||||
with graph.inserting_before(node):
|
||||
fused_weight_node = graph.get_attr(fused_weight_name)
|
||||
|
||||
if fused_bias_name is not None:
|
||||
gate_up_bias_node = graph.get_attr(fused_bias_name)
|
||||
|
||||
# Create the fused_swiglu_mlp node
|
||||
with graph.inserting_after(node):
|
||||
fused_node: Node = graph.call_function(
|
||||
torch.ops.auto_deploy.fused_swiglu_mlp.default,
|
||||
args=(
|
||||
input_node,
|
||||
fused_weight_node,
|
||||
down_weight_node,
|
||||
gate_up_bias_node,
|
||||
down_bias_node,
|
||||
),
|
||||
)
|
||||
|
||||
# Replace uses and erase old node
|
||||
node.replace_all_uses_with(fused_node)
|
||||
graph.erase_node(node)
|
||||
|
||||
# Eagerly free unfused weight/bias tensors that are no longer referenced
|
||||
# to avoid a temporary memory spike from holding both fused and unfused
|
||||
# copies simultaneously across all layers.
|
||||
_try_free_attr_node(gm, graph, gate_weight_node)
|
||||
_try_free_attr_node(gm, graph, up_weight_node)
|
||||
_try_free_attr_node(gm, graph, gate_bias_node)
|
||||
_try_free_attr_node(gm, graph, up_bias_node)
|
||||
|
||||
fused_weight_idx += 1
|
||||
cnt += 1
|
||||
|
||||
if cnt > 0:
|
||||
gm.recompile()
|
||||
|
||||
# Clean up any remaining dead code and unused submodules
|
||||
eliminate_dead_code(gm)
|
||||
delete_all_unused_submodules(gm)
|
||||
|
||||
info = TransformInfo(
|
||||
skipped=False, num_matches=cnt, is_clean=cnt == 0, has_valid_shapes=cnt == 0
|
||||
)
|
||||
|
||||
return gm, info
|
||||
|
||||
|
||||
# ── NVFP4 quantized SwiGLU pattern matching and fusion ──────────────────────
|
||||
|
||||
from ...custom_ops.linear.swiglu import torch_nvfp4_swiglu_mlp # noqa: E402
|
||||
|
||||
|
||||
def _nvfp4_swiglu_pattern_no_bias(
|
||||
x,
|
||||
gate_weight,
|
||||
gate_input_scale,
|
||||
gate_weight_scale,
|
||||
gate_alpha,
|
||||
up_weight,
|
||||
up_input_scale,
|
||||
up_weight_scale,
|
||||
up_alpha,
|
||||
down_weight,
|
||||
down_input_scale,
|
||||
down_weight_scale,
|
||||
down_alpha,
|
||||
):
|
||||
"""Pattern for NVFP4 quantized SwiGLU MLP without biases.
|
||||
|
||||
Matches: silu(nvfp4_linear(x, gate)) * nvfp4_linear(x, up) -> nvfp4_linear(down)
|
||||
"""
|
||||
gate_out = torch.ops.auto_deploy.torch_fake_quant_nvfp4_linear.default(
|
||||
x,
|
||||
gate_weight,
|
||||
None,
|
||||
input_scale=[gate_input_scale],
|
||||
weight_scale=[gate_weight_scale, gate_alpha],
|
||||
input_zp=[],
|
||||
weight_zp=[],
|
||||
)
|
||||
up_out = torch.ops.auto_deploy.torch_fake_quant_nvfp4_linear.default(
|
||||
x,
|
||||
up_weight,
|
||||
None,
|
||||
input_scale=[up_input_scale],
|
||||
weight_scale=[up_weight_scale, up_alpha],
|
||||
input_zp=[],
|
||||
weight_zp=[],
|
||||
)
|
||||
silu_out = torch.ops.aten.silu.default(gate_out)
|
||||
mul_out = torch.ops.aten.mul.Tensor(silu_out, up_out)
|
||||
down_out = torch.ops.auto_deploy.torch_fake_quant_nvfp4_linear.default(
|
||||
mul_out,
|
||||
down_weight,
|
||||
None,
|
||||
input_scale=[down_input_scale],
|
||||
weight_scale=[down_weight_scale, down_alpha],
|
||||
input_zp=[],
|
||||
weight_zp=[],
|
||||
)
|
||||
return down_out
|
||||
|
||||
|
||||
def _nvfp4_swiglu_replacement_no_bias(
|
||||
x,
|
||||
gate_weight,
|
||||
gate_input_scale,
|
||||
gate_weight_scale,
|
||||
gate_alpha,
|
||||
up_weight,
|
||||
up_input_scale,
|
||||
up_weight_scale,
|
||||
up_alpha,
|
||||
down_weight,
|
||||
down_input_scale,
|
||||
down_weight_scale,
|
||||
down_alpha,
|
||||
):
|
||||
"""Replacement for NVFP4 quantized SwiGLU pattern without biases."""
|
||||
return torch_nvfp4_swiglu_mlp(
|
||||
x,
|
||||
gate_weight,
|
||||
up_weight,
|
||||
down_weight,
|
||||
gate_input_scale,
|
||||
gate_weight_scale,
|
||||
gate_alpha,
|
||||
up_input_scale,
|
||||
up_weight_scale,
|
||||
up_alpha,
|
||||
down_input_scale,
|
||||
down_weight_scale,
|
||||
down_alpha,
|
||||
)
|
||||
|
||||
|
||||
@TransformRegistry.register("match_nvfp4_swiglu_pattern")
|
||||
class MatchNVFP4SwiGLUPattern(BaseTransform):
|
||||
"""Matches NVFP4 quantized SwiGLU MLP patterns and replaces with torch_nvfp4_swiglu_mlp.
|
||||
|
||||
This transform runs in the pattern_matcher stage AFTER quantize_nvfp4_linear_from_config
|
||||
has converted torch_linear_simple ops to torch_fake_quant_nvfp4_linear ops.
|
||||
|
||||
It detects the following NVFP4 pattern:
|
||||
silu(nvfp4_linear(x, gate)) * nvfp4_linear(x, up) -> nvfp4_linear(down)
|
||||
|
||||
And replaces it with a single torch_nvfp4_swiglu_mlp op that can be fused later.
|
||||
"""
|
||||
|
||||
config: TransformConfig
|
||||
|
||||
@classmethod
|
||||
def get_config_class(cls) -> Type[TransformConfig]:
|
||||
return TransformConfig
|
||||
|
||||
def _apply(
|
||||
self,
|
||||
gm: GraphModule,
|
||||
cm: CachedSequenceInterface,
|
||||
factory: ModelFactory,
|
||||
shared_config: SharedConfig,
|
||||
) -> Tuple[GraphModule, TransformInfo]:
|
||||
patterns = ADPatternMatcherPass()
|
||||
|
||||
# FP4 shape params for dummy args (shapes don't matter for matching)
|
||||
N = 32 # intermediate_size
|
||||
K_packed = 32 # hidden_size / 2 (FP4 packing)
|
||||
K_eff = 2 * K_packed # actual hidden_size
|
||||
N_down = K_eff # hidden_size (output of down proj)
|
||||
K_down_packed = N // 2 # intermediate_size / 2 (down proj input)
|
||||
|
||||
# Weight scale sizes (per-block scale: N * K / 16)
|
||||
gate_cutlass_len = N * (K_eff // 16)
|
||||
down_cutlass_len = N_down * (N // 16)
|
||||
|
||||
x = torch.randn(2, K_eff, device="meta", dtype=torch.float16)
|
||||
|
||||
# Gate args
|
||||
gate_w = torch.randint(0, 255, (N, K_packed), device="meta", dtype=torch.uint8)
|
||||
gate_is = torch.tensor(0.01, device="meta", dtype=torch.float32)
|
||||
gate_ws = torch.randint(0, 255, (gate_cutlass_len,), device="meta", dtype=torch.uint8)
|
||||
gate_a = torch.tensor(1.2345, device="meta", dtype=torch.float32)
|
||||
|
||||
# Up args (same shapes as gate)
|
||||
up_w = torch.randint(0, 255, (N, K_packed), device="meta", dtype=torch.uint8)
|
||||
up_is = torch.tensor(0.02, device="meta", dtype=torch.float32)
|
||||
up_ws = torch.randint(0, 255, (gate_cutlass_len,), device="meta", dtype=torch.uint8)
|
||||
up_a = torch.tensor(2.3456, device="meta", dtype=torch.float32)
|
||||
|
||||
# Down args
|
||||
down_w = torch.randint(0, 255, (N_down, K_down_packed), device="meta", dtype=torch.uint8)
|
||||
down_is = torch.tensor(0.03, device="meta", dtype=torch.float32)
|
||||
down_ws = torch.randint(0, 255, (down_cutlass_len,), device="meta", dtype=torch.uint8)
|
||||
down_a = torch.tensor(3.4567, device="meta", dtype=torch.float32)
|
||||
|
||||
dummy_args = [
|
||||
x,
|
||||
gate_w,
|
||||
gate_is,
|
||||
gate_ws,
|
||||
gate_a,
|
||||
up_w,
|
||||
up_is,
|
||||
up_ws,
|
||||
up_a,
|
||||
down_w,
|
||||
down_is,
|
||||
down_ws,
|
||||
down_a,
|
||||
]
|
||||
|
||||
register_ad_pattern(
|
||||
search_fn=_nvfp4_swiglu_pattern_no_bias,
|
||||
replace_fn=_nvfp4_swiglu_replacement_no_bias,
|
||||
patterns=patterns,
|
||||
dummy_args=dummy_args,
|
||||
)
|
||||
|
||||
num_matches = patterns.apply(gm.graph)
|
||||
|
||||
if num_matches > 0:
|
||||
gm.recompile()
|
||||
|
||||
info = TransformInfo(
|
||||
skipped=False,
|
||||
num_matches=num_matches,
|
||||
is_clean=num_matches == 0,
|
||||
has_valid_shapes=num_matches == 0,
|
||||
)
|
||||
|
||||
return gm, info
|
||||
|
||||
|
||||
@TransformRegistry.register("fuse_nvfp4_swiglu")
|
||||
class FuseNVFP4SwiGLU(BaseTransform):
|
||||
"""Fuses torch_nvfp4_swiglu_mlp ops by concatenating gate and up FP4 weights.
|
||||
|
||||
This transform runs in the post_load_fusion stage and replaces torch_nvfp4_swiglu_mlp
|
||||
ops with fused_nvfp4_swiglu_mlp ops that use a single concatenated gate+up weight matrix.
|
||||
|
||||
FP4 weight fusion:
|
||||
- gate+up packed weights are concatenated along dim=0
|
||||
- gate+up per-block weight scales are concatenated along dim=0
|
||||
- gate+up input_scale and alpha must match (shared input)
|
||||
"""
|
||||
|
||||
config: TransformConfig
|
||||
|
||||
@classmethod
|
||||
def get_config_class(cls) -> Type[TransformConfig]:
|
||||
return TransformConfig
|
||||
|
||||
def _apply(
|
||||
self,
|
||||
gm: GraphModule,
|
||||
cm: CachedSequenceInterface,
|
||||
factory: ModelFactory,
|
||||
shared_config: SharedConfig,
|
||||
) -> Tuple[GraphModule, TransformInfo]:
|
||||
graph = gm.graph
|
||||
cnt = 0
|
||||
fused_weight_idx = 0
|
||||
|
||||
for node in list(graph.nodes):
|
||||
if not is_op(node, torch.ops.auto_deploy.torch_nvfp4_swiglu_mlp.default):
|
||||
continue
|
||||
|
||||
# Extract args:
|
||||
# (input, gate_weight, up_weight, down_weight,
|
||||
# gate_input_scale, gate_weight_scale, gate_alpha,
|
||||
# up_input_scale, up_weight_scale, up_alpha,
|
||||
# down_input_scale, down_weight_scale, down_alpha)
|
||||
input_node = node.args[0]
|
||||
gate_weight_node = node.args[1]
|
||||
up_weight_node = node.args[2]
|
||||
down_weight_node = node.args[3]
|
||||
gate_input_scale_node = node.args[4]
|
||||
gate_weight_scale_node = node.args[5]
|
||||
gate_alpha_node = node.args[6]
|
||||
up_input_scale_node = node.args[7]
|
||||
up_weight_scale_node = node.args[8]
|
||||
up_alpha_node = node.args[9]
|
||||
down_input_scale_node = node.args[10]
|
||||
down_weight_scale_node = node.args[11]
|
||||
down_alpha_node = node.args[12]
|
||||
|
||||
# Get the actual weight tensors
|
||||
gate_weight = get_attr_by_name(gm, gate_weight_node.target)
|
||||
up_weight = get_attr_by_name(gm, up_weight_node.target)
|
||||
|
||||
# Concatenate gate and up FP4 packed weights along dim=0
|
||||
gate_up_weight = torch.cat([gate_weight, up_weight], dim=0)
|
||||
|
||||
# Get and concatenate weight scales
|
||||
gate_weight_scale = get_attr_by_name(gm, gate_weight_scale_node.target)
|
||||
up_weight_scale = get_attr_by_name(gm, up_weight_scale_node.target)
|
||||
gate_up_weight_scale = torch.cat([gate_weight_scale, up_weight_scale], dim=0)
|
||||
|
||||
# Register fused buffers
|
||||
prefix = f"fused_nvfp4_swiglu_{fused_weight_idx}"
|
||||
gm.register_buffer(f"{prefix}_gate_up_weight", gate_up_weight)
|
||||
gm.register_buffer(f"{prefix}_gate_up_weight_scale", gate_up_weight_scale)
|
||||
|
||||
# Create get_attr nodes for fused weights/scales
|
||||
with graph.inserting_before(node):
|
||||
fused_gate_up_weight_node = graph.get_attr(f"{prefix}_gate_up_weight")
|
||||
fused_gate_up_weight_scale_node = graph.get_attr(f"{prefix}_gate_up_weight_scale")
|
||||
|
||||
# Create the fused_nvfp4_swiglu_mlp node
|
||||
# Use gate's input_scale and alpha (same as up's since they share input)
|
||||
with graph.inserting_after(node):
|
||||
fused_node: Node = graph.call_function(
|
||||
torch.ops.auto_deploy.fused_nvfp4_swiglu_mlp.default,
|
||||
args=(
|
||||
input_node,
|
||||
fused_gate_up_weight_node,
|
||||
down_weight_node,
|
||||
gate_input_scale_node, # shared input_scale for gate+up
|
||||
fused_gate_up_weight_scale_node,
|
||||
gate_alpha_node, # shared alpha for gate+up
|
||||
down_input_scale_node,
|
||||
down_weight_scale_node,
|
||||
down_alpha_node,
|
||||
),
|
||||
)
|
||||
|
||||
# Replace uses and erase old node
|
||||
node.replace_all_uses_with(fused_node)
|
||||
graph.erase_node(node)
|
||||
|
||||
# Eagerly free unfused weight/scale tensors that are no longer referenced
|
||||
# to avoid a temporary memory spike from holding both fused and unfused
|
||||
# copies simultaneously across all layers.
|
||||
_try_free_attr_node(gm, graph, gate_weight_node)
|
||||
_try_free_attr_node(gm, graph, up_weight_node)
|
||||
_try_free_attr_node(gm, graph, gate_weight_scale_node)
|
||||
_try_free_attr_node(gm, graph, up_weight_scale_node)
|
||||
_try_free_attr_node(gm, graph, up_input_scale_node)
|
||||
_try_free_attr_node(gm, graph, up_alpha_node)
|
||||
|
||||
fused_weight_idx += 1
|
||||
cnt += 1
|
||||
|
||||
if cnt > 0:
|
||||
gm.recompile()
|
||||
|
||||
# Clean up any remaining dead code and unused submodules
|
||||
eliminate_dead_code(gm)
|
||||
delete_all_unused_submodules(gm)
|
||||
|
||||
info = TransformInfo(
|
||||
skipped=False, num_matches=cnt, is_clean=cnt == 0, has_valid_shapes=cnt == 0
|
||||
)
|
||||
|
||||
return gm, info
|
||||
@@ -9,31 +9,43 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Transformation for fusing Add + Cast + RMSNorm."""
|
||||
"""Transformation for fusing Add + (optional Cast) + RMSNorm via direct FX graph manipulation."""
|
||||
|
||||
from typing import Tuple
|
||||
import operator
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
from torch.fx import GraphModule
|
||||
from torch.fx import GraphModule, Node
|
||||
|
||||
from ...custom_ops.normalization.flashinfer_fused_add_rms_norm import flashinfer_fused_add_rms_norm
|
||||
from ...models.factory import ModelFactory
|
||||
from ...shim.interface import CachedSequenceInterface
|
||||
from ...utils.pattern_matcher import ADPatternMatcherPass, register_ad_pattern
|
||||
from ...utils._graph import eliminate_dead_code
|
||||
from ...utils.node_utils import is_op
|
||||
from ..interface import BaseTransform, SharedConfig, TransformInfo, TransformRegistry
|
||||
|
||||
|
||||
@TransformRegistry.register("fuse_add_rms_norm")
|
||||
class FuseAddRMSNorm(BaseTransform):
|
||||
"""Fuse (add + cast + RMSNorm) into one fused op.
|
||||
"""Fuse (add + optional cast + RMSNorm) into one fused op.
|
||||
|
||||
Matches:
|
||||
x = add(input, residual)
|
||||
y = x.to(dtype)
|
||||
z = flashinfer_rms_norm(y, weight, eps)
|
||||
Uses direct FX graph manipulation instead of the inductor pattern matcher
|
||||
to correctly handle patterns where intermediate nodes (add, rms_norm) have
|
||||
multiple users in the graph.
|
||||
|
||||
Replaces with:
|
||||
z, x = flashinfer_fused_add_rms_norm(input, residual, weight, eps)
|
||||
Pattern 1 (without cast):
|
||||
%add = aten.add(%x, %residual)
|
||||
%norm = flashinfer_rms_norm(%add, %weight, eps)
|
||||
|
||||
Pattern 2 (with cast):
|
||||
%add = aten.add(%x, %residual)
|
||||
%cast = aten.to.dtype(%add, bfloat16)
|
||||
%norm = flashinfer_rms_norm(%cast, %weight, eps)
|
||||
|
||||
Both are replaced with:
|
||||
%fused = flashinfer_fused_add_rms_norm(%x, %residual, %weight, eps)
|
||||
%norm_out = getitem(%fused, 0) # norm result (replaces %norm)
|
||||
%add_out = getitem(%fused, 1) # add result (replaces %add)
|
||||
"""
|
||||
|
||||
def _apply(
|
||||
@@ -43,42 +55,102 @@ class FuseAddRMSNorm(BaseTransform):
|
||||
factory: ModelFactory,
|
||||
shared_config: SharedConfig,
|
||||
) -> Tuple[GraphModule, TransformInfo]:
|
||||
patterns = ADPatternMatcherPass()
|
||||
graph = gm.graph
|
||||
num_matches = 0
|
||||
|
||||
# Dummy shapes for tracing
|
||||
bsz, hidden = 2, 128
|
||||
dummy_args = [
|
||||
torch.randn(bsz, hidden, device="meta", dtype=torch.bfloat16), # x (bf16)
|
||||
torch.randn(bsz, hidden, device="meta", dtype=torch.bfloat16), # residual (bf16)
|
||||
torch.randn(hidden, device="meta", dtype=torch.bfloat16), # weight
|
||||
1e-5, # eps
|
||||
]
|
||||
# --- Step 1: collect (add_node, optional cast_node, norm_node) triples ---
|
||||
matches: List[Tuple[Node, Optional[Node], Node]] = []
|
||||
|
||||
op_ignore_types = {torch.ops.aten.to.dtype: (torch.dtype,)}
|
||||
scalar_workaround = {"eps": 1e-5}
|
||||
for node in graph.nodes:
|
||||
# Match flashinfer_rms_norm (handles both overload packet and .default)
|
||||
if not is_op(node, torch.ops.auto_deploy.flashinfer_rms_norm):
|
||||
continue
|
||||
|
||||
def _fused_add_norm_pattern(x, residual, weight, eps):
|
||||
added = torch.ops.aten.add.Tensor(x, residual)
|
||||
cast = torch.ops.aten.to.dtype(added, torch.bfloat16)
|
||||
# Note: we assume flashinfer_rms_norm is the target
|
||||
norm = torch.ops.auto_deploy.flashinfer_rms_norm.default(cast, weight, eps)
|
||||
return norm, added
|
||||
input_to_norm = node.args[0]
|
||||
cast_node: Optional[Node] = None
|
||||
|
||||
def _fused_add_norm_replacement(x, residual, weight, eps):
|
||||
# Use the python wrapper directly, not via torch.ops.auto_deploy
|
||||
return flashinfer_fused_add_rms_norm(x, residual, weight, eps)
|
||||
# Check for an optional aten.to.dtype cast between add and norm
|
||||
if isinstance(input_to_norm, Node) and is_op(input_to_norm, torch.ops.aten.to.dtype):
|
||||
cast_node = input_to_norm
|
||||
input_to_norm = cast_node.args[0]
|
||||
|
||||
# Register pattern
|
||||
register_ad_pattern(
|
||||
search_fn=_fused_add_norm_pattern,
|
||||
replace_fn=_fused_add_norm_replacement,
|
||||
patterns=patterns,
|
||||
dummy_args=dummy_args,
|
||||
op_ignore_types=op_ignore_types,
|
||||
scalar_workaround=scalar_workaround,
|
||||
)
|
||||
# The (possibly unwrapped) input must be an aten.add.Tensor
|
||||
if not isinstance(input_to_norm, Node) or not is_op(
|
||||
input_to_norm, torch.ops.aten.add.Tensor
|
||||
):
|
||||
continue
|
||||
|
||||
num_matches = patterns.apply(gm.graph)
|
||||
add_node = input_to_norm
|
||||
matches.append((add_node, cast_node, node))
|
||||
|
||||
# --- Step 2: apply fusions ---
|
||||
erased: set = set() # track erased node ids to skip stale matches
|
||||
|
||||
for add_node, cast_node, norm_node in matches:
|
||||
# Safety: skip if a node in this match was already consumed
|
||||
if id(add_node) in erased or id(norm_node) in erased:
|
||||
continue
|
||||
|
||||
# Original operands
|
||||
add_lhs = add_node.args[0] # e.g. previous residual
|
||||
add_rhs = add_node.args[1] # e.g. attention/MoE output
|
||||
weight = norm_node.args[1]
|
||||
eps = norm_node.args[2]
|
||||
|
||||
# Insert the fused call right before the norm node. Using
|
||||
# inserting_before(norm_node) ensures correct topological order:
|
||||
# fused_node → norm_out → add_out all appear before norm_node.
|
||||
with graph.inserting_before(norm_node):
|
||||
# flashinfer_fused_add_rms_norm(x, residual, weight, eps):
|
||||
# residual += x → residual becomes add result
|
||||
# x = rms_norm(residual) → x becomes norm result
|
||||
# returns (x, residual) = (norm_result, add_result)
|
||||
fused_node = graph.call_function(
|
||||
flashinfer_fused_add_rms_norm,
|
||||
args=(add_rhs, add_lhs, weight, eps),
|
||||
)
|
||||
norm_out = graph.call_function(operator.getitem, args=(fused_node, 0))
|
||||
add_out = graph.call_function(operator.getitem, args=(fused_node, 1))
|
||||
|
||||
# Rewire all consumers of the original norm → norm_out
|
||||
norm_node.replace_all_uses_with(norm_out)
|
||||
|
||||
# Erase norm first so cast_node (if present) loses its only user
|
||||
graph.erase_node(norm_node)
|
||||
erased.add(id(norm_node))
|
||||
|
||||
# Erase cast_node *before* replacing add's uses, otherwise
|
||||
# replace_all_uses_with would rewrite cast's input to add_out
|
||||
# which sits after cast in the graph → topological violation.
|
||||
if cast_node is not None:
|
||||
if len(cast_node.users) == 0:
|
||||
graph.erase_node(cast_node)
|
||||
erased.add(id(cast_node))
|
||||
else:
|
||||
# Rare: cast has users besides norm. Redirect them to a
|
||||
# new cast placed after add_out so the ordering is valid.
|
||||
with graph.inserting_before(list(cast_node.users)[0]):
|
||||
new_cast = graph.call_function(
|
||||
cast_node.target,
|
||||
args=(add_out, *cast_node.args[1:]),
|
||||
kwargs=cast_node.kwargs,
|
||||
)
|
||||
cast_node.replace_all_uses_with(new_cast)
|
||||
graph.erase_node(cast_node)
|
||||
erased.add(id(cast_node))
|
||||
|
||||
# Rewire all consumers of the original add → add_out
|
||||
# (includes the residual connection to the *next* layer's add)
|
||||
add_node.replace_all_uses_with(add_out)
|
||||
|
||||
graph.erase_node(add_node)
|
||||
erased.add(id(add_node))
|
||||
|
||||
num_matches += 1
|
||||
|
||||
# Clean up any remaining dead code
|
||||
if num_matches > 0:
|
||||
eliminate_dead_code(gm)
|
||||
|
||||
info = TransformInfo(
|
||||
skipped=False,
|
||||
|
||||
@@ -0,0 +1,215 @@
|
||||
"""Transform for multi-stream execution of Q and KV projection chains in MLA attention.
|
||||
|
||||
In DeepSeek-style MLA (Multi-head Latent Attention), the input layernorm output
|
||||
forks into two independent projection chains that merge at the RoPE + attention op:
|
||||
|
||||
- **Q chain** (heavier): q_a_proj -> rms_norm -> q_b_proj -> view -> split
|
||||
- **KV chain** (lighter): kv_a_proj_with_mqa -> split -> rms_norm + view
|
||||
|
||||
The Q chain is ~9x heavier than the KV chain. This transform moves the KV
|
||||
projection linear onto the auxiliary CUDA stream so it executes concurrently
|
||||
with the Q chain on the main stream.
|
||||
"""
|
||||
|
||||
from collections import deque
|
||||
from typing import Callable, List, Tuple
|
||||
|
||||
import torch
|
||||
from torch.fx import GraphModule, Node
|
||||
|
||||
from ...models.factory import ModelFactory
|
||||
|
||||
# Reuse CachedSequenceInterface for the _apply signature.
|
||||
from ...shim.interface import CachedSequenceInterface
|
||||
from ...utils._graph import create_derived_custom_op
|
||||
from ...utils.multi_stream_utils import (
|
||||
_make_aux_stream_impl,
|
||||
cuda_stream_manager,
|
||||
record_event_passthrough,
|
||||
)
|
||||
from ...utils.node_utils import is_op
|
||||
from ..interface import BaseTransform, SharedConfig, TransformInfo, TransformRegistry
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Supported linear op targets. Extend this list to cover quantised variants.
|
||||
# ---------------------------------------------------------------------------
|
||||
_LINEAR_OPS: List[Callable] = [
|
||||
torch.ops.auto_deploy.torch_linear_simple,
|
||||
torch.ops.aten.linear,
|
||||
]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _is_linear(node: Node) -> bool:
|
||||
"""Return ``True`` if *node* is a call to one of the supported linear ops."""
|
||||
return is_op(node, _LINEAR_OPS)
|
||||
|
||||
|
||||
def _has_downstream_linear(start: Node, max_depth: int = 3) -> bool:
|
||||
"""BFS from *start* through its users and return ``True`` if a linear op is reachable.
|
||||
|
||||
The search only follows *user* edges (downstream in the data-flow graph)
|
||||
and stops after *max_depth* hops. ``start`` itself is **not** checked.
|
||||
"""
|
||||
visited: set[Node] = {start}
|
||||
queue: deque[Tuple[Node, int]] = deque()
|
||||
|
||||
for user in start.users:
|
||||
queue.append((user, 1))
|
||||
|
||||
while queue:
|
||||
node, depth = queue.popleft()
|
||||
if node in visited:
|
||||
continue
|
||||
visited.add(node)
|
||||
|
||||
if _is_linear(node):
|
||||
return True
|
||||
|
||||
if depth < max_depth:
|
||||
for user in node.users:
|
||||
queue.append((user, depth + 1))
|
||||
|
||||
return False
|
||||
|
||||
|
||||
def _find_kv_proj_linears(gm: GraphModule) -> List[Tuple[Node, Node]]:
|
||||
"""Find (fork_point, kv_linear) pairs suitable for aux-stream execution.
|
||||
|
||||
A *fork point* is a node that directly feeds two or more supported linear
|
||||
ops. Among these linears the one that does **not** lead to another linear
|
||||
within a small BFS depth is the KV projection candidate (the lighter
|
||||
branch).
|
||||
|
||||
Returns a list of ``(fork_point, kv_linear_node)`` tuples.
|
||||
"""
|
||||
results: List[Tuple[Node, Node]] = []
|
||||
|
||||
for node in gm.graph.nodes:
|
||||
# Collect direct linear users of this node.
|
||||
linear_users = [u for u in node.users if _is_linear(u)]
|
||||
if len(linear_users) < 2:
|
||||
continue
|
||||
|
||||
# Separate into "has downstream linear" (Q-like) and "does not" (KV-like).
|
||||
kv_candidates = [ln for ln in linear_users if not _has_downstream_linear(ln)]
|
||||
q_candidates = [ln for ln in linear_users if _has_downstream_linear(ln)]
|
||||
|
||||
if not kv_candidates or not q_candidates:
|
||||
continue
|
||||
|
||||
# Pick the KV candidate(s). In MLA there is exactly one per fork point.
|
||||
for kv_linear in kv_candidates:
|
||||
results.append((node, kv_linear))
|
||||
|
||||
return results
|
||||
|
||||
|
||||
def _create_aux_op(base_op: Callable) -> Callable:
|
||||
"""Create an ``_aux`` variant of a linear op that runs on the auxiliary CUDA stream.
|
||||
|
||||
Uses a custom ``make_fake`` that delegates to the base op's registered fake
|
||||
so that output shapes are computed correctly (linear output shape != input shape).
|
||||
"""
|
||||
return create_derived_custom_op(
|
||||
base_op,
|
||||
"_aux",
|
||||
_make_aux_stream_impl,
|
||||
make_fake=lambda base: lambda *a, **kw: base(*a, **kw),
|
||||
)
|
||||
|
||||
|
||||
def _execute_kv_proj_in_aux_stream(gm: GraphModule) -> Tuple[GraphModule, int]:
|
||||
"""Replace KV projection linears with aux-stream variants.
|
||||
|
||||
For each matched ``(fork_point, kv_linear)`` the rewriter:
|
||||
|
||||
1. Inserts ``record_event_passthrough(fork_point)`` so the main-stream
|
||||
event is recorded *before* the Q-chain kernels are submitted.
|
||||
2. Replaces the KV linear's target with its ``_aux`` variant and wires the
|
||||
``record_event_passthrough`` output as the hidden-state input
|
||||
(creating a true data dependency).
|
||||
|
||||
The remaining KV-chain ops (split, rms_norm, view) stay on the main
|
||||
stream — they are lightweight and run after the aux wait that is built
|
||||
into the derived op.
|
||||
|
||||
Aux-stream variants are created lazily — only for base ops that actually
|
||||
appear in the matched KV positions.
|
||||
"""
|
||||
pairs = _find_kv_proj_linears(gm)
|
||||
if not pairs:
|
||||
return gm, 0
|
||||
|
||||
graph = gm.graph
|
||||
node_order = {n: i for i, n in enumerate(graph.nodes)}
|
||||
|
||||
# Create aux ops lazily for whatever linear op types are found.
|
||||
ops_in_graph = {kv_linear.target for _, kv_linear in pairs}
|
||||
op_dict = {op: _create_aux_op(op) for op in ops_in_graph}
|
||||
|
||||
num_replaced = 0
|
||||
|
||||
for fork_point, kv_linear in pairs:
|
||||
# Find the Q-chain linear(s) so we can insert the event record
|
||||
# *before* the earliest Q-chain op in graph order.
|
||||
q_linears = [u for u in fork_point.users if _is_linear(u) and u is not kv_linear]
|
||||
earliest_q = min(q_linears, key=lambda n: node_order.get(n, 0))
|
||||
|
||||
# Insert record_event_passthrough right before the first Q-chain
|
||||
# linear so the event is recorded before Q kernels hit the GPU.
|
||||
with graph.inserting_before(earliest_q):
|
||||
rec_node = graph.call_function(
|
||||
record_event_passthrough,
|
||||
args=(fork_point,),
|
||||
)
|
||||
|
||||
# Replace KV linear with its aux-stream variant. The hidden-state
|
||||
# input (args[0]) is rewired to ``rec_node`` to create a data
|
||||
# dependency that ensures the event is recorded first.
|
||||
new_args = tuple(rec_node if arg is fork_point else arg for arg in kv_linear.args)
|
||||
|
||||
with graph.inserting_after(kv_linear):
|
||||
new_node = graph.call_function(
|
||||
op_dict[kv_linear.target], args=new_args, kwargs=kv_linear.kwargs
|
||||
)
|
||||
|
||||
kv_linear.replace_all_uses_with(new_node)
|
||||
graph.erase_node(kv_linear)
|
||||
num_replaced += 1
|
||||
|
||||
return gm, num_replaced
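For orientation, here is a minimal eager-mode sketch of the dataflow this rewrite produces. It is an illustration only: the transform rewrites FX nodes and never executes this code, the shapes and weights are hypothetical stand-ins, and it assumes ``cuda_stream_manager.add_device(...)`` has already been called for the current device (as ``_apply`` does).

# Illustrative sketch of the rewritten dataflow (not executed by the transform).
x = torch.randn(4, 128, device="cuda")
w_q_a = torch.randn(64, 128, device="cuda")
w_q_b = torch.randn(128, 64, device="cuda")
w_kv = torch.randn(128, 128, device="cuda")
h = record_event_passthrough(x)                   # main-stream event recorded before Q kernels launch
kv = torch.ops.aten.linear(h, w_kv)               # in the rewritten graph this is the derived "_aux" op
q = torch.ops.aten.linear(torch.relu(torch.ops.aten.linear(x, w_q_a)), w_q_b)
out = q + kv                                      # downstream attention consumes both branches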
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Transform class
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@TransformRegistry.register("multi_stream_mla_attn")
|
||||
class MultiStreamMLAAttn(BaseTransform):
|
||||
"""Multi-stream Q/KV projection parallelism for MLA attention blocks."""
|
||||
|
||||
def _apply(
|
||||
self,
|
||||
gm: GraphModule,
|
||||
cm: CachedSequenceInterface,
|
||||
factory: ModelFactory,
|
||||
shared_config: SharedConfig,
|
||||
) -> Tuple[GraphModule, TransformInfo]:
|
||||
# Ensure aux stream and events are set up for the current device.
|
||||
cuda_stream_manager.add_device(torch.cuda.current_device())
|
||||
|
||||
gm, num_matches = _execute_kv_proj_in_aux_stream(gm)
|
||||
|
||||
info = TransformInfo(
|
||||
skipped=False,
|
||||
num_matches=num_matches,
|
||||
is_clean=num_matches == 0,
|
||||
has_valid_shapes=num_matches == 0,
|
||||
)
|
||||
return gm, info
|
||||
@ -1,295 +1,195 @@
|
||||
"""Transform for multi-stream execution of MoE layers that have shared experts and routed experts."""
|
||||
|
||||
from threading import RLock
|
||||
from typing import Any, Callable, Dict, List, Tuple
|
||||
from typing import Callable, List, Optional, Set, Tuple
|
||||
|
||||
import torch
|
||||
from torch.fx import GraphModule
|
||||
|
||||
from tensorrt_llm._torch.utils import ActivationType
|
||||
from torch.fx import GraphModule, Node
|
||||
|
||||
from ...models.factory import ModelFactory
|
||||
from ...shim.interface import CachedSequenceInterface
|
||||
from ...utils.logger import ad_logger
|
||||
from ...utils.multi_stream_utils import (
|
||||
begin_aux_stream_passthrough,
|
||||
cuda_stream_manager,
|
||||
end_aux_stream_passthrough,
|
||||
wait_aux_stream_passthrough,
|
||||
)
|
||||
from ...utils.node_utils import is_op
|
||||
from ..interface import BaseTransform, SharedConfig, TransformInfo, TransformRegistry
|
||||
|
||||
|
||||
# Previously, CudaStreamManager and the custom ops that use the cuda streams and events were
|
||||
# placed in custom_ops folder. However doing so resulted in CudaStreamManager
|
||||
# being created only in the parent process, but we need each rank to have its own CudaStreamManager that
|
||||
# manages the cuda streams and events for that rank. Placing the logic to instantiate
|
||||
# CudaStreamManager and the custom ops that use the cuda streams and events at the transform level ensures that
|
||||
# each rank has its own CudaStreamManager since each rank applies the transform independently.
|
||||
class _Singleton(type):
|
||||
_instances: Dict[type, Any] = {}
|
||||
_lock = RLock()
|
||||
def _find_merge_add(moe_node: Node) -> Optional[Node]:
|
||||
"""Walk forward from a MoE op through users to find the ``aten.add.Tensor`` merge node.
|
||||
|
||||
def __call__(cls, *args: Any, **kwargs: Any) -> Any:
|
||||
if cls not in cls._instances:
|
||||
with cls._lock:
|
||||
if cls not in cls._instances: # double-checked locking
|
||||
cls._instances[cls] = super().__call__(*args, **kwargs)
|
||||
return cls._instances[cls]
|
||||
The merge ``add`` is the node where shared-expert output and routed-expert
|
||||
output are combined. The search is a breadth-first traversal of the user
|
||||
graph starting from the MoE node.
|
||||
"""
|
||||
visited: Set[Node] = set()
|
||||
queue = list(moe_node.users.keys())
|
||||
while queue:
|
||||
n = queue.pop(0)
|
||||
if n in visited:
|
||||
continue
|
||||
visited.add(n)
|
||||
if is_op(n, torch.ops.aten.add.Tensor):
|
||||
return n
|
||||
queue.extend(n.users.keys())
|
||||
return None
|
||||
|
||||
|
||||
# A singleton that holds the pointers to the cuda streams and events.
|
||||
# Each device has its own cuda streams and events.
|
||||
class CudaStreamManager(metaclass=_Singleton):
|
||||
AUX_STREAM_NAME = "aux"
|
||||
MAIN_STREAM_NAME = "main"
|
||||
devices: List[torch.device] = []
|
||||
events: Dict[torch.device, Dict[str, Any]] = {}
|
||||
streams: Dict[torch.device, Dict[str, Any]] = {}
|
||||
|
||||
def __init__(self) -> None:
|
||||
# In case __init__ ever gets called twice, guard against re-init
|
||||
if hasattr(self, "streams"):
|
||||
return
|
||||
|
||||
self._lock = RLock()
|
||||
self.add_device(torch.cuda.current_device())
|
||||
|
||||
def add_device(self, device: int) -> None:
|
||||
if device not in self.devices:
|
||||
self.devices.append(device)
|
||||
with torch.cuda.device(device):
|
||||
self.events[device] = {
|
||||
self.AUX_STREAM_NAME: torch.cuda.Event(),
|
||||
self.MAIN_STREAM_NAME: torch.cuda.Event(),
|
||||
}
|
||||
self.streams[device] = {
|
||||
self.AUX_STREAM_NAME: torch.cuda.Stream(),
|
||||
self.MAIN_STREAM_NAME: torch.cuda.default_stream(),
|
||||
}
|
||||
else:
|
||||
ad_logger.warning(f"CudaStreamManager: Device {device} already added")
|
||||
|
||||
def get_stream(self, device: int, stream_name: str) -> torch.cuda.Stream:
|
||||
return self.streams[device][stream_name]
|
||||
|
||||
def get_event(self, device: int, event_name: str) -> torch.cuda.Event:
|
||||
return self.events[device][event_name]
|
||||
def _get_ancestors(node: Node) -> Set[Node]:
|
||||
"""Return the set of all nodes reachable by walking backwards from *node*."""
|
||||
ancestors: Set[Node] = set()
|
||||
queue = list(node.all_input_nodes)
|
||||
while queue:
|
||||
n = queue.pop()
|
||||
if n in ancestors:
|
||||
continue
|
||||
ancestors.add(n)
|
||||
queue.extend(n.all_input_nodes)
|
||||
return ancestors
|
||||
|
||||
|
||||
# Every device will have a singleton instance of CudaStreamManager.
|
||||
cuda_stream_manager = CudaStreamManager()
|
||||
|
||||
|
||||
@torch.library.custom_op("auto_deploy::record_event", mutates_args=())
|
||||
def record_event(device: int, stream_name: str) -> None:
|
||||
event = cuda_stream_manager.get_event(device, stream_name)
|
||||
event.record()
|
||||
|
||||
|
||||
@torch.library.custom_op("auto_deploy::wait_event", mutates_args=())
|
||||
def wait_event(device: int, stream_name: str) -> None:
|
||||
event = cuda_stream_manager.get_event(device, stream_name)
|
||||
event.wait()
|
||||
|
||||
|
||||
# skip during compilation
|
||||
@torch._dynamo.disable
|
||||
def record_event_wrapper(
|
||||
fn: Callable,
|
||||
*args: Tuple[Any, ...],
|
||||
**kwargs: Dict[str, Any],
|
||||
) -> torch.Tensor:
|
||||
device = kwargs.pop("device", torch.cuda.current_device())
|
||||
output = fn(*args, **kwargs)
|
||||
torch.ops.auto_deploy.record_event(device, cuda_stream_manager.MAIN_STREAM_NAME)
|
||||
return output
|
||||
|
||||
|
||||
@torch._dynamo.disable
|
||||
def aux_stream_wrapper(
|
||||
fn: Callable,
|
||||
*args: Tuple[Any, ...],
|
||||
**kwargs: Dict[str, Any],
|
||||
) -> torch.Tensor:
|
||||
stream_name = cuda_stream_manager.AUX_STREAM_NAME
|
||||
device = kwargs.pop("device", torch.cuda.current_device())
|
||||
with torch.cuda.stream(cuda_stream_manager.get_stream(device, stream_name)):
|
||||
torch.ops.auto_deploy.wait_event(device, cuda_stream_manager.MAIN_STREAM_NAME)
|
||||
output = fn(*args, **kwargs)
|
||||
torch.ops.auto_deploy.record_event(device, cuda_stream_manager.AUX_STREAM_NAME)
|
||||
torch.ops.auto_deploy.wait_event(device, cuda_stream_manager.AUX_STREAM_NAME)
|
||||
return output
|
||||
|
||||
|
||||
# trtllm bf16
|
||||
@torch.library.custom_op("auto_deploy::trtllm_moe_fused_aux", mutates_args=())
|
||||
def trtllm_moe_fused_aux(
|
||||
x: torch.Tensor,
|
||||
selected_experts: torch.Tensor,
|
||||
routing_weights: torch.Tensor,
|
||||
w3_w1_stacked_weight: torch.Tensor,
|
||||
w2_stacked_weight: torch.Tensor,
|
||||
is_gated_mlp: bool = True,
|
||||
act_fn: int = int(ActivationType.Silu),
|
||||
) -> torch.Tensor:
|
||||
device = torch.cuda.current_device()
|
||||
with torch.cuda.stream(
|
||||
cuda_stream_manager.get_stream(device, cuda_stream_manager.AUX_STREAM_NAME)
|
||||
):
|
||||
torch.ops.auto_deploy.wait_event(device, cuda_stream_manager.MAIN_STREAM_NAME)
|
||||
output = torch.ops.auto_deploy.trtllm_moe_fused(
|
||||
x,
|
||||
selected_experts,
|
||||
routing_weights,
|
||||
w3_w1_stacked_weight,
|
||||
w2_stacked_weight,
|
||||
is_gated_mlp,
|
||||
act_fn,
|
||||
)
|
||||
torch.ops.auto_deploy.record_event(device, cuda_stream_manager.AUX_STREAM_NAME)
|
||||
torch.ops.auto_deploy.wait_event(device, cuda_stream_manager.AUX_STREAM_NAME)
|
||||
return output
|
||||
|
||||
|
||||
@trtllm_moe_fused_aux.register_fake
|
||||
def trtllm_moe_fused_aux_fake(
|
||||
x: torch.Tensor,
|
||||
selected_experts: torch.Tensor,
|
||||
routing_weights: torch.Tensor,
|
||||
w3_w1_stacked_weight: torch.Tensor,
|
||||
w2_stacked_weight: torch.Tensor,
|
||||
is_gated_mlp: bool = True,
|
||||
act_fn: int = int(ActivationType.Silu),
|
||||
) -> torch.Tensor:
|
||||
return torch.empty_like(x)
|
||||
|
||||
|
||||
# triton bf16
|
||||
@torch.library.custom_op("auto_deploy::triton_moe_fused_aux", mutates_args=())
|
||||
def triton_moe_fused_aux(
|
||||
x: torch.Tensor,
|
||||
selected_experts: torch.Tensor,
|
||||
routing_weights: torch.Tensor,
|
||||
w1_stacked_weight: torch.Tensor,
|
||||
w2_stacked_weight: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
device = torch.cuda.current_device()
|
||||
with torch.cuda.stream(
|
||||
cuda_stream_manager.get_stream(device, cuda_stream_manager.AUX_STREAM_NAME)
|
||||
):
|
||||
torch.ops.auto_deploy.wait_event(device, cuda_stream_manager.MAIN_STREAM_NAME)
|
||||
output = torch.ops.auto_deploy.triton_moe_fused(
|
||||
x,
|
||||
selected_experts,
|
||||
routing_weights,
|
||||
w1_stacked_weight,
|
||||
w2_stacked_weight,
|
||||
)
|
||||
torch.ops.auto_deploy.record_event(device, cuda_stream_manager.AUX_STREAM_NAME)
|
||||
torch.ops.auto_deploy.wait_event(device, cuda_stream_manager.AUX_STREAM_NAME)
|
||||
return output
|
||||
|
||||
|
||||
@triton_moe_fused_aux.register_fake
|
||||
def triton_moe_fused_aux_fake(
|
||||
x: torch.Tensor,
|
||||
selected_experts: torch.Tensor,
|
||||
routing_weights: torch.Tensor,
|
||||
w1_stacked_weight: torch.Tensor,
|
||||
w2_stacked_weight: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
return torch.empty_like(x)
|
||||
|
||||
|
||||
@torch.library.custom_op("auto_deploy::trtllm_quant_fp8_moe_fused_aux", mutates_args=())
|
||||
def trtllm_quant_fp8_moe_fused_aux(
|
||||
x: torch.Tensor,
|
||||
selected_experts: torch.Tensor,
|
||||
routing_weights: torch.Tensor,
|
||||
fc1_expert_weights: torch.Tensor,
|
||||
fc2_expert_weights: torch.Tensor,
|
||||
fc1_act_scale: torch.Tensor,
|
||||
fc1_dequant_scale: torch.Tensor,
|
||||
fc2_act_scale_reciprocal: torch.Tensor,
|
||||
fc2_dequant_scale: torch.Tensor,
|
||||
is_gated_mlp: bool = True,
|
||||
act_fn: int = int(ActivationType.Silu),
|
||||
) -> torch.Tensor:
|
||||
device = torch.cuda.current_device()
|
||||
with torch.cuda.stream(
|
||||
cuda_stream_manager.get_stream(device, cuda_stream_manager.AUX_STREAM_NAME)
|
||||
):
|
||||
torch.ops.auto_deploy.wait_event(device, cuda_stream_manager.MAIN_STREAM_NAME)
|
||||
output = torch.ops.auto_deploy.trtllm_quant_fp8_moe_fused(
|
||||
x,
|
||||
selected_experts,
|
||||
routing_weights,
|
||||
fc1_expert_weights,
|
||||
fc2_expert_weights,
|
||||
fc1_act_scale,
|
||||
fc1_dequant_scale,
|
||||
fc2_act_scale_reciprocal,
|
||||
fc2_dequant_scale,
|
||||
is_gated_mlp,
|
||||
act_fn,
|
||||
)
|
||||
torch.ops.auto_deploy.record_event(device, cuda_stream_manager.AUX_STREAM_NAME)
|
||||
torch.ops.auto_deploy.wait_event(device, cuda_stream_manager.AUX_STREAM_NAME)
|
||||
return output
|
||||
|
||||
|
||||
@trtllm_quant_fp8_moe_fused_aux.register_fake
|
||||
def trtllm_quant_fp8_moe_fused_aux_fake(
|
||||
x: torch.Tensor,
|
||||
selected_experts: torch.Tensor,
|
||||
routing_weights: torch.Tensor,
|
||||
fc1_expert_weights: torch.Tensor,
|
||||
fc2_expert_weights: torch.Tensor,
|
||||
fc1_act_scale: torch.Tensor,
|
||||
fc1_dequant_scale: torch.Tensor,
|
||||
fc2_act_scale_reciprocal: torch.Tensor,
|
||||
fc2_dequant_scale: torch.Tensor,
|
||||
is_gated_mlp: bool = True,
|
||||
act_fn: int = int(ActivationType.Silu),
|
||||
) -> torch.Tensor:
|
||||
return torch.empty_like(x)
|
||||
|
||||
|
||||
def _execute_op_in_aux_stream(
|
||||
gm: GraphModule, op_dict: Dict[Callable, Callable]
|
||||
def _execute_shared_expert_in_aux_stream(
|
||||
gm: GraphModule, moe_ops: List[Callable]
|
||||
) -> Tuple[GraphModule, int]:
|
||||
"""Move shared-expert computation to the auxiliary CUDA stream.
|
||||
|
||||
For each MoE fused op in the graph:
|
||||
1. Walk forward to find the ``aten.add.Tensor`` that merges the
|
||||
shared-expert output and the routed-expert output.
|
||||
2. Identify which ``add`` input is the routed branch (descended from
|
||||
the MoE node) and which is the shared-expert branch.
|
||||
3. Trace the shared-expert branch backwards to collect all its
|
||||
computation nodes and identify the fork point (the latest common
|
||||
ancestor shared with the MoE / routing path).
|
||||
4. Insert ``begin_aux_stream_passthrough`` before the first shared-expert
|
||||
op to switch to the auxiliary CUDA stream.
|
||||
5. Insert ``end_aux_stream_passthrough`` after the last shared-expert op
|
||||
to switch back to the main stream.
|
||||
6. Insert ``wait_aux_stream_passthrough`` on the routed-branch input
|
||||
just before the ``add`` so the main stream waits for the auxiliary
|
||||
stream to finish before merging outputs.
|
||||
"""
|
||||
graph = gm.graph
|
||||
num_replaced = 0
|
||||
|
||||
# Collect targets first to avoid mutating while iterating
|
||||
target_nodes = [n for n in graph.nodes if is_op(n, op_dict.keys())]
|
||||
# Collect targets first to avoid mutating while iterating.
|
||||
target_nodes = [n for n in graph.nodes if is_op(n, moe_ops)]
|
||||
if not target_nodes:
|
||||
return gm, 0
|
||||
|
||||
for n in target_nodes:
|
||||
target_input_node = None
|
||||
for input_node in n.all_input_nodes:
|
||||
if input_node.target == torch.ops.aten.view.default:
|
||||
target_input_node = input_node
|
||||
break
|
||||
# Look through dtype cast nodes (aten.to) to find the view node
|
||||
if input_node.target == torch.ops.aten.to:
|
||||
for nested_input in input_node.all_input_nodes:
|
||||
if nested_input.target == torch.ops.aten.view.default:
|
||||
target_input_node = nested_input
|
||||
break
|
||||
if target_input_node is not None:
|
||||
break
|
||||
node_order = {node: i for i, node in enumerate(graph.nodes)}
|
||||
|
||||
assert target_input_node is not None, f"Target input node not found for node {n}"
|
||||
with graph.inserting_before(target_input_node):
|
||||
kwargs = target_input_node.kwargs.copy()
|
||||
kwargs["device"] = torch.cuda.current_device()
|
||||
new_node = graph.call_function(
|
||||
record_event_wrapper,
|
||||
args=(target_input_node.target, *target_input_node.args),
|
||||
kwargs=kwargs,
|
||||
for moe_node in target_nodes:
|
||||
# ---- Step 1: Find the merge ``add`` node. ----
|
||||
add_node = _find_merge_add(moe_node)
|
||||
if add_node is None:
|
||||
ad_logger.warning(
|
||||
f"No merge add found downstream of MoE node {moe_node.name}; "
|
||||
"skipping multi-stream transform for this node."
|
||||
)
|
||||
target_input_node.replace_all_uses_with(new_node)
|
||||
graph.erase_node(target_input_node)
|
||||
with graph.inserting_after(n):
|
||||
new_node = graph.call_function(op_dict[n.target], args=n.args, kwargs=n.kwargs)
|
||||
n.replace_all_uses_with(new_node)
|
||||
graph.erase_node(n)
|
||||
continue
|
||||
|
||||
# ---- Step 2: Determine which ``add`` input is routed vs. shared. ----
|
||||
arg0, arg1 = add_node.args[0], add_node.args[1]
|
||||
arg0_ancestors = _get_ancestors(arg0)
|
||||
|
||||
if moe_node in arg0_ancestors or arg0 is moe_node:
|
||||
routed_output, shared_output = arg0, arg1
|
||||
else:
|
||||
routed_output, shared_output = arg1, arg0
|
||||
|
||||
# ---- Step 3: Collect shared-expert nodes & find fork point. ----
|
||||
moe_ancestors = _get_ancestors(moe_node)
|
||||
moe_ancestors.add(moe_node)
|
||||
|
||||
shared_nodes: List[Node] = []
|
||||
fork_point: Optional[Node] = None
|
||||
visited: Set[Node] = set()
|
||||
queue = [shared_output]
|
||||
|
||||
while queue:
|
||||
n = queue.pop(0)
|
||||
if n in visited:
|
||||
continue
|
||||
visited.add(n)
|
||||
|
||||
# Skip static weight / parameter nodes.
|
||||
if n.op == "get_attr":
|
||||
continue
|
||||
|
||||
if n in moe_ancestors:
|
||||
# This node is on the MoE / routing path — candidate fork point.
|
||||
if fork_point is None or node_order.get(n, 0) > node_order.get(fork_point, 0):
|
||||
fork_point = n
|
||||
continue
|
||||
|
||||
shared_nodes.append(n)
|
||||
for inp in n.all_input_nodes:
|
||||
queue.append(inp)
|
||||
|
||||
if not shared_nodes or fork_point is None:
|
||||
ad_logger.warning(
|
||||
f"Could not identify shared-expert subgraph for MoE node "
|
||||
f"{moe_node.name}; skipping multi-stream transform for this node."
|
||||
)
|
||||
continue
|
||||
|
||||
# Order shared nodes by their position in the graph.
|
||||
shared_nodes.sort(key=lambda n: node_order.get(n, 0))
|
||||
first_shared = shared_nodes[0]
|
||||
|
||||
# Sanity check: the first shared op must directly consume the fork
|
||||
# point so we can wire begin_aux_stream_passthrough into it.
|
||||
if fork_point not in first_shared.all_input_nodes:
|
||||
ad_logger.warning(
|
||||
f"First shared-expert op ({first_shared.name}) does not directly "
|
||||
f"consume fork point ({fork_point.name}); skipping."
|
||||
)
|
||||
continue
|
||||
|
||||
# ---- Step 4: Insert begin_aux before the first shared-expert op. ----
|
||||
# NOTE: do NOT bake ``torch.cuda.current_device()`` into the graph —
|
||||
# that would hard-code device 0 and break on other ranks in a
|
||||
# multi-GPU setup. Omitting ``device`` lets the passthrough
|
||||
# functions resolve the device at **runtime** (default ``-1``).
|
||||
with graph.inserting_before(first_shared):
|
||||
begin_aux_node = graph.call_function(
|
||||
begin_aux_stream_passthrough,
|
||||
args=(fork_point,),
|
||||
)
|
||||
|
||||
# Create a data dependency: first_shared reads begin_aux output
|
||||
# instead of fork_point.
|
||||
first_shared.args = tuple(
|
||||
begin_aux_node if arg is fork_point else arg for arg in first_shared.args
|
||||
)
|
||||
|
||||
# ---- Step 5: Insert end_aux after the last shared-expert op. ----
|
||||
with graph.inserting_after(shared_output):
|
||||
end_aux_node = graph.call_function(
|
||||
end_aux_stream_passthrough,
|
||||
args=(shared_output,),
|
||||
)
|
||||
|
||||
# Replace shared-expert input to ``add`` with end_aux output.
|
||||
add_node.args = tuple(
|
||||
end_aux_node if arg is shared_output else arg for arg in add_node.args
|
||||
)
|
||||
|
||||
# ---- Step 6: Insert wait_aux before the ``add``. ----
|
||||
with graph.inserting_before(add_node):
|
||||
wait_aux_node = graph.call_function(
|
||||
wait_aux_stream_passthrough,
|
||||
args=(routed_output,),
|
||||
)
|
||||
|
||||
add_node.args = tuple(
|
||||
wait_aux_node if arg is routed_output else arg for arg in add_node.args
|
||||
)
|
||||
|
||||
num_replaced += 1
|
||||
|
||||
return gm, num_replaced
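For orientation, the rewritten subgraph is equivalent to the following eager-mode sequence. This is an illustrative sketch with hypothetical tensors (it mirrors steps 4-6 above, and assumes ``cuda_stream_manager.add_device(...)`` has already been called), not code used by the transform.

# Illustrative sketch only.
hidden = torch.randn(4, 128, device="cuda")
w_shared = torch.randn(128, 128, device="cuda")
routed_out = torch.randn(4, 128, device="cuda")        # stand-in for the fused-MoE output (main stream)
t = begin_aux_stream_passthrough(hidden)               # step 4: switch this thread to the aux stream
shared_out = torch.ops.aten.linear(t, w_shared)        # shared-expert work now lands on the aux stream
shared_out = end_aux_stream_passthrough(shared_out)    # step 5: record aux event, restore caller's stream
routed_out = wait_aux_stream_passthrough(routed_out)   # step 6: main stream waits for aux before merging
out = routed_out + shared_out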
|
||||
@ -306,14 +206,16 @@ class MultiStreamMOE(BaseTransform):
|
||||
factory: ModelFactory,
|
||||
shared_config: SharedConfig,
|
||||
) -> Tuple[GraphModule, TransformInfo]:
|
||||
op_dict = {
|
||||
torch.ops.auto_deploy.trtllm_moe_fused: torch.ops.auto_deploy.trtllm_moe_fused_aux,
|
||||
torch.ops.auto_deploy.triton_moe_fused: torch.ops.auto_deploy.triton_moe_fused_aux,
|
||||
torch.ops.auto_deploy.trtllm_quant_fp8_moe_fused: torch.ops.auto_deploy.trtllm_quant_fp8_moe_fused_aux,
|
||||
}
|
||||
base_ops = [
|
||||
torch.ops.auto_deploy.trtllm_moe_fused,
|
||||
torch.ops.auto_deploy.triton_moe_fused,
|
||||
torch.ops.auto_deploy.trtllm_quant_fp8_moe_fused,
|
||||
torch.ops.auto_deploy.trtllm_quant_nvfp4_moe_fused,
|
||||
]
|
||||
|
||||
# Ensure that aux stream and events for the current device are added to the CudaStreamManager.
|
||||
cuda_stream_manager.add_device(torch.cuda.current_device())
|
||||
gm, num_matches = _execute_op_in_aux_stream(gm, op_dict)
|
||||
gm, num_matches = _execute_shared_expert_in_aux_stream(gm, base_ops)
|
||||
|
||||
info = TransformInfo(
|
||||
skipped=False,
|
||||
@ -321,5 +223,4 @@ class MultiStreamMOE(BaseTransform):
|
||||
is_clean=num_matches == 0,
|
||||
has_valid_shapes=num_matches == 0,
|
||||
)
|
||||
|
||||
return gm, info
|
||||
|
||||
@ -20,6 +20,100 @@ from torch.utils._pytree import _LEAF_SPEC, TreeSpec
|
||||
from .logger import ad_logger
|
||||
from .node_utils import get_weight_tensor, is_op
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Dynamic custom-op derivation helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
# These are used to create new custom ops that share the schema of an existing
|
||||
# op but wrap it with additional logic (e.g. stream management). A single
|
||||
# module-level dict of ``Library`` objects (keyed by namespace) is used so that
|
||||
# the registrations persist and are visible via ``torch.ops.<namespace>.*``.
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_derived_op_libs: Dict[str, torch.library.Library] = {}
|
||||
_derived_op_registry: Dict[str, Callable] = {}
|
||||
|
||||
|
||||
def _get_lib(namespace: str) -> torch.library.Library:
|
||||
"""Return (and lazily create) a ``FRAGMENT`` Library for *namespace*."""
|
||||
if namespace not in _derived_op_libs:
|
||||
_derived_op_libs[namespace] = torch.library.Library(namespace, "FRAGMENT")
|
||||
return _derived_op_libs[namespace]
|
||||
|
||||
|
||||
def create_derived_custom_op(
|
||||
base_op: Callable,
|
||||
suffix: str,
|
||||
make_impl: Callable[[Callable], Callable],
|
||||
make_fake: Optional[Callable[[Callable], Callable]] = None,
|
||||
) -> Callable:
|
||||
"""Dynamically create a new custom op derived from an existing one.
|
||||
|
||||
The new op has the **same** schema (arguments, default values, and return
|
||||
type) as *base_op* but with a different name (``<base_name><suffix>``) and a
|
||||
custom implementation produced by *make_impl*.
|
||||
|
||||
Args:
|
||||
base_op: The base custom op — either an ``OpOverloadPacket``
|
||||
(e.g. ``torch.ops.auto_deploy.trtllm_moe_fused``) or an
|
||||
``OpOverload`` (e.g. ``…trtllm_moe_fused.default``).
|
||||
suffix: Suffix appended to the base op name to form the new op name
|
||||
(e.g. ``"_aux"``).
|
||||
make_impl: A factory ``(base_overload) -> impl_fn`` that receives the
|
||||
resolved base ``OpOverload`` and returns the *implementation*
|
||||
function for the new op. ``impl_fn`` will be called with the
|
||||
same positional/keyword arguments as *base_op*.
|
||||
make_fake: Optional factory ``(base_overload) -> fake_fn`` that returns
|
||||
the *Meta / fake-tensor* implementation. When ``None`` the
|
||||
default fake implementation ``torch.empty_like(args[0])`` is used.
|
||||
|
||||
Returns:
|
||||
The newly registered op as an ``OpOverloadPacket``
|
||||
(e.g. ``torch.ops.auto_deploy.<name><suffix>``). Repeated calls with
|
||||
the same *base_op* and *suffix* return the cached op.
|
||||
"""
|
||||
base_overload = base_op.default if hasattr(base_op, "default") else base_op
|
||||
schema = base_overload._schema
|
||||
|
||||
# e.g. "auto_deploy::trtllm_moe_fused"
|
||||
qualified_name = schema.name
|
||||
namespace, base_name = qualified_name.split("::")
|
||||
new_name = f"{base_name}{suffix}"
|
||||
new_qualified = f"{namespace}::{new_name}"
|
||||
|
||||
# Return the cached op if it was already created.
|
||||
if new_qualified in _derived_op_registry:
|
||||
return _derived_op_registry[new_qualified]
|
||||
|
||||
# Build the schema string for the derived op. ``str(schema)`` produces a
|
||||
# fully-qualified string such as
|
||||
# auto_deploy::trtllm_moe_fused(Tensor x, …) -> Tensor
|
||||
# We replace the qualified name with the bare new name (the Library already
|
||||
# knows its namespace).
|
||||
new_schema_str = str(schema).replace(qualified_name, new_name, 1)
|
||||
|
||||
lib = _get_lib(namespace)
|
||||
lib.define(new_schema_str)
|
||||
|
||||
# Register the real implementation for all devices.
|
||||
# We use "CompositeExplicitAutograd" so that we can provide a separate
|
||||
# Meta / fake kernel for shape inference.
|
||||
lib.impl(new_name, make_impl(base_overload), "CompositeExplicitAutograd")
|
||||
|
||||
# Register the Meta / fake implementation.
|
||||
if make_fake is not None:
|
||||
lib.impl(new_name, make_fake(base_overload), "Meta")
|
||||
else:
|
||||
|
||||
def _default_fake(*args, **kwargs):
|
||||
return torch.empty_like(args[0])
|
||||
|
||||
lib.impl(new_name, _default_fake, "Meta")
|
||||
|
||||
new_op = getattr(getattr(torch.ops, namespace), new_name)
|
||||
_derived_op_registry[new_qualified] = new_op
|
||||
return new_op
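A hypothetical usage sketch follows; none of the names below are part of this module, and ``trtllm_moe_fused`` is only an example base op.

def _example_derive_logged_op():
    """Illustration only: derive a variant that logs before dispatching to the base op."""

    def _make_logging_impl(base_overload):
        def _impl(*args, **kwargs):
            ad_logger.info(f"dispatching {base_overload}")
            return base_overload(*args, **kwargs)

        return _impl

    # The derived op keeps the exact schema of the base op, so call sites are unchanged.
    return create_derived_custom_op(
        torch.ops.auto_deploy.trtllm_moe_fused, "_logged", _make_logging_impl
    )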
|
||||
|
||||
|
||||
_NoValType = type("_NoValType", (), {})
|
||||
|
||||
|
||||
|
||||
242
tensorrt_llm/_torch/auto_deploy/utils/multi_stream_utils.py
Normal file
@ -0,0 +1,242 @@
|
||||
"""Shared CUDA multi-stream utilities for multi-stream transforms.
|
||||
|
||||
This module provides the core infrastructure for executing parts of an FX graph
|
||||
on auxiliary CUDA streams. It is consumed by the multi-stream MoE and MLA
|
||||
attention transforms in ``..transform.library``.
|
||||
|
||||
Key components:
|
||||
- ``CudaStreamManager``: per-device singleton managing auxiliary streams/events.
|
||||
- Custom ops ``record_event`` / ``wait_event``: graph-safe event primitives.
|
||||
- Passthrough helpers that switch streams while preserving the data-flow edges
|
||||
required by FX graph execution and CUDA graph capture.
|
||||
- ``_make_aux_stream_impl``: factory for building an implementation that runs
|
||||
a base op on the auxiliary CUDA stream.
|
||||
"""
|
||||
|
||||
from threading import RLock
|
||||
from typing import Any, Callable, Dict, List
|
||||
|
||||
import torch
|
||||
|
||||
from .logger import ad_logger
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Singleton metaclass
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class _Singleton(type):
|
||||
_instances: Dict[type, Any] = {}
|
||||
_lock = RLock()
|
||||
|
||||
def __call__(cls, *args: Any, **kwargs: Any) -> Any:
|
||||
if cls not in cls._instances:
|
||||
with cls._lock:
|
||||
if cls not in cls._instances: # double-checked locking
|
||||
cls._instances[cls] = super().__call__(*args, **kwargs)
|
||||
return cls._instances[cls]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# CudaStreamManager
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
# CudaStreamManager and the custom ops that use its CUDA streams and events used to live in the
# custom_ops folder. With that placement the singleton was created only in the parent process,
# but each rank needs its own CudaStreamManager for its streams and events. Instantiating
# CudaStreamManager (and the ops that use it) at the transform level guarantees this, because
# each rank applies the transforms independently.
|
||||
|
||||
|
||||
class CudaStreamManager(metaclass=_Singleton):
|
||||
AUX_STREAM_NAME = "aux"
|
||||
MAIN_STREAM_NAME = "main"
|
||||
devices: List[int] = []
events: Dict[int, Dict[str, Any]] = {}
streams: Dict[int, Dict[str, Any]] = {}
|
||||
# Per-device save slot for the caller's stream. ``begin_aux_stream_passthrough``
|
||||
# saves the real current stream here so that ``end_aux_stream_passthrough`` can
|
||||
# restore it — this is critical during CUDA graph capture where the capture stream
|
||||
# differs from ``torch.cuda.default_stream()``.
|
||||
_caller_streams: Dict[int, Any] = {}
|
||||
|
||||
def __init__(self) -> None:
|
||||
# In case __init__ ever gets called twice, guard against re-init
|
||||
if hasattr(self, "streams"):
|
||||
return
|
||||
|
||||
self._lock = RLock()
|
||||
self.add_device(torch.cuda.current_device())
|
||||
|
||||
def add_device(self, device: int) -> None:
|
||||
if device not in self.devices:
|
||||
self.devices.append(device)
|
||||
with torch.cuda.device(device):
|
||||
self.events[device] = {
|
||||
self.AUX_STREAM_NAME: torch.cuda.Event(),
|
||||
self.MAIN_STREAM_NAME: torch.cuda.Event(),
|
||||
}
|
||||
self.streams[device] = {
|
||||
self.AUX_STREAM_NAME: torch.cuda.Stream(),
|
||||
self.MAIN_STREAM_NAME: torch.cuda.default_stream(),
|
||||
}
|
||||
else:
|
||||
ad_logger.warning(f"CudaStreamManager: Device {device} already added")
|
||||
|
||||
def get_stream(self, device: int, stream_name: str) -> torch.cuda.Stream:
|
||||
return self.streams[device][stream_name]
|
||||
|
||||
def get_event(self, device: int, event_name: str) -> torch.cuda.Event:
|
||||
return self.events[device][event_name]
|
||||
|
||||
|
||||
# Every device will have a singleton instance of CudaStreamManager.
|
||||
cuda_stream_manager = CudaStreamManager()
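A small usage sketch of the manager (illustration only; the transforms perform the ``add_device`` call themselves before relying on it):

dev = torch.cuda.current_device()
cuda_stream_manager.add_device(dev)
aux = cuda_stream_manager.get_stream(dev, CudaStreamManager.AUX_STREAM_NAME)
evt = cuda_stream_manager.get_event(dev, CudaStreamManager.MAIN_STREAM_NAME)
evt.record()          # mark the current (main) stream's progress
aux.wait_event(evt)   # work later enqueued on aux will not start before that point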
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Custom ops — graph-safe CUDA event primitives
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@torch.library.custom_op("auto_deploy::record_event", mutates_args=())
|
||||
def record_event(device: int, stream_name: str) -> None:
|
||||
event = cuda_stream_manager.get_event(device, stream_name)
|
||||
event.record()
|
||||
|
||||
|
||||
@torch.library.custom_op("auto_deploy::wait_event", mutates_args=())
|
||||
def wait_event(device: int, stream_name: str) -> None:
|
||||
event = cuda_stream_manager.get_event(device, stream_name)
|
||||
event.wait()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Passthrough helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@torch._dynamo.disable
|
||||
def record_event_passthrough(
|
||||
x: torch.Tensor,
|
||||
*,
|
||||
device: int = -1,
|
||||
) -> torch.Tensor:
|
||||
"""Record a CUDA event on the main stream and return the input unchanged.
|
||||
|
||||
Inserted at a fork point to mark a synchronization point on the caller's
(main) stream. Derived ``_aux`` op variants wait for this event before
executing, enabling overlap between the branch kept on the main stream
(e.g. the Q projection chain) and the branch moved to the auxiliary
stream (e.g. the KV projection).
|
||||
"""
|
||||
if device < 0:
|
||||
device = torch.cuda.current_device()
|
||||
torch.ops.auto_deploy.record_event(device, cuda_stream_manager.MAIN_STREAM_NAME)
|
||||
return x
|
||||
|
||||
|
||||
@torch._dynamo.disable
|
||||
def begin_aux_stream_passthrough(
|
||||
x: torch.Tensor,
|
||||
*,
|
||||
device: int = -1,
|
||||
) -> torch.Tensor:
|
||||
"""Record a CUDA event on the main stream, switch to aux, and wait for it.
|
||||
|
||||
After this function returns the thread-local current stream is the
|
||||
auxiliary stream. All subsequent GPU ops dispatched by the FX graph
|
||||
interpreter will be recorded on aux until ``end_aux_stream_passthrough``
|
||||
switches back to main.
|
||||
"""
|
||||
if device < 0:
|
||||
device = torch.cuda.current_device()
|
||||
# Save the *actual* current stream so ``end_aux`` can restore it.
|
||||
# During CUDA graph capture the current stream is the capture stream,
|
||||
# which is NOT ``torch.cuda.default_stream()``.
|
||||
caller_stream = torch.cuda.current_stream(device)
|
||||
cuda_stream_manager._caller_streams[device] = caller_stream
|
||||
# Record where the caller's stream has reached so aux knows when data is ready.
|
||||
main_event = cuda_stream_manager.get_event(device, cuda_stream_manager.MAIN_STREAM_NAME)
|
||||
main_event.record(caller_stream)
|
||||
# Switch the thread-local current stream to aux.
|
||||
aux_stream = cuda_stream_manager.get_stream(device, cuda_stream_manager.AUX_STREAM_NAME)
|
||||
torch.cuda.set_stream(aux_stream)
|
||||
# Make aux wait for the main-stream event before executing any work.
|
||||
aux_stream.wait_event(main_event)
|
||||
return x
|
||||
|
||||
|
||||
@torch._dynamo.disable
|
||||
def end_aux_stream_passthrough(
|
||||
x: torch.Tensor,
|
||||
*,
|
||||
device: int = -1,
|
||||
) -> torch.Tensor:
|
||||
"""Record a CUDA event on the aux stream and switch back to the caller's stream.
|
||||
|
||||
This does **not** make the caller's stream wait for aux. The caller must
|
||||
insert ``wait_aux_stream_passthrough`` at the point where both branches
|
||||
need to be synchronised (typically right before the ``add`` that merges
|
||||
shared-expert and routed-expert outputs).
|
||||
"""
|
||||
if device < 0:
|
||||
device = torch.cuda.current_device()
|
||||
# Record the aux-stream progress so the caller's stream can wait for it later.
|
||||
aux_event = cuda_stream_manager.get_event(device, cuda_stream_manager.AUX_STREAM_NAME)
|
||||
aux_event.record()
|
||||
# Restore the caller's stream saved by ``begin_aux_stream_passthrough``.
|
||||
# This is critical during CUDA graph capture where the capture stream
|
||||
# differs from ``torch.cuda.default_stream()``.
|
||||
caller_stream = cuda_stream_manager._caller_streams.pop(device, None)
|
||||
if caller_stream is not None:
|
||||
torch.cuda.set_stream(caller_stream)
|
||||
else:
|
||||
torch.cuda.set_stream(
|
||||
cuda_stream_manager.get_stream(device, cuda_stream_manager.MAIN_STREAM_NAME)
|
||||
)
|
||||
return x
|
||||
|
||||
|
||||
@torch._dynamo.disable
|
||||
def wait_aux_stream_passthrough(
|
||||
x: torch.Tensor,
|
||||
*,
|
||||
device: int = -1,
|
||||
) -> torch.Tensor:
|
||||
"""Make the current stream wait for the auxiliary stream's last recorded event.
|
||||
|
||||
This is a GPU-side wait (non-blocking on the CPU). Insert this right
|
||||
before the ``add`` that merges shared-expert output (computed on aux)
|
||||
with routed-expert output (computed on main).
|
||||
|
||||
Uses ``torch.cuda.current_stream()`` rather than the stored default stream
|
||||
so that the correct stream is waited on during CUDA graph capture.
|
||||
"""
|
||||
if device < 0:
|
||||
device = torch.cuda.current_device()
|
||||
aux_event = cuda_stream_manager.get_event(device, cuda_stream_manager.AUX_STREAM_NAME)
|
||||
torch.cuda.current_stream(device).wait_event(aux_event)
|
||||
return x
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Aux-stream implementation factory
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _make_aux_stream_impl(base_overload: Callable) -> Callable:
|
||||
"""Build an implementation that runs *base_overload* on the auxiliary CUDA stream."""
|
||||
|
||||
def _impl(*args, **kwargs):
|
||||
device = torch.cuda.current_device()
|
||||
with torch.cuda.stream(
|
||||
cuda_stream_manager.get_stream(device, cuda_stream_manager.AUX_STREAM_NAME)
|
||||
):
|
||||
torch.ops.auto_deploy.wait_event(device, cuda_stream_manager.MAIN_STREAM_NAME)
|
||||
output = base_overload(*args, **kwargs)
|
||||
torch.ops.auto_deploy.record_event(device, cuda_stream_manager.AUX_STREAM_NAME)
|
||||
torch.ops.auto_deploy.wait_event(device, cuda_stream_manager.AUX_STREAM_NAME)
|
||||
return output
|
||||
|
||||
return _impl
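Derived ``_aux`` ops are built by pairing this factory with ``create_derived_custom_op``; for example (a sketch, assuming ``create_derived_custom_op`` from the graph utils is importable here, with ``triton_moe_fused`` used only as an example base op):

# Illustration only: derive an aux-stream variant of an existing fused-MoE op.
aux_moe = create_derived_custom_op(
    torch.ops.auto_deploy.triton_moe_fused, "_aux", _make_aux_stream_impl
)
# aux_moe(...) accepts the same arguments as triton_moe_fused but waits for the
# main-stream event and runs on the auxiliary CUDA stream.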
|
||||
@ -343,6 +343,10 @@ class TestNemotronSuperV3(LlmapiAccuracyTestHarness):
|
||||
"sharding_source": ['factory', 'heuristic'],
|
||||
"sharding_dims": ['ep', 'bmm'],
|
||||
},
|
||||
"multi_stream_moe": {
|
||||
"stage": "compile",
|
||||
"enabled": True,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
@ -455,7 +459,15 @@ class TestGLM4Flash(LlmapiAccuracyTestHarness):
|
||||
"fuse_nvfp4_moe": {
|
||||
"allow_different_input_scales": True,
|
||||
},
|
||||
},
|
||||
"multi_stream_moe": {
|
||||
"stage": "compile",
|
||||
"enabled": True,
|
||||
},
|
||||
"multi_stream_mla_attn": {
|
||||
"stage": "compile",
|
||||
"enabled": True,
|
||||
},
|
||||
}
|
||||
}
|
||||
if enable_chunked_prefill:
|
||||
config["enable_chunked_prefill"] = True
|
||||
|
||||
@ -4,6 +4,9 @@ import pytest
|
||||
import torch
|
||||
from _custom_op_utils import torch_rope_reference
|
||||
|
||||
# Import after we've imported torch (to ensure custom ops are registered)
|
||||
from tensorrt_llm._torch.auto_deploy.custom_ops.rope import triton_rope # noqa: F401
|
||||
|
||||
|
||||
def _precompute_freqs_cis(
|
||||
seq_len: int, head_dim: int, rope_theta: Optional[float] = None
|
||||
|
||||
@ -1,134 +0,0 @@
|
||||
from typing import Tuple
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch.fx import GraphModule, Node
|
||||
|
||||
from tensorrt_llm._torch.auto_deploy.transform.library.multi_stream_moe import (
|
||||
aux_stream_wrapper,
|
||||
cuda_stream_manager,
|
||||
record_event_wrapper,
|
||||
)
|
||||
from tensorrt_llm._torch.auto_deploy.utils._graph import canonicalize_graph
|
||||
from tensorrt_llm._torch.auto_deploy.utils.node_utils import is_op
|
||||
|
||||
|
||||
@torch.library.custom_op("auto_deploy::multi_stream_linear", mutates_args=())
|
||||
def multi_stream_linear(
|
||||
input: torch.Tensor, weight0: torch.Tensor, weight1: torch.Tensor
|
||||
) -> torch.Tensor:
|
||||
output = torch.ops.aten.linear(input, weight0)
|
||||
output = torch.ops.aten.linear(output, weight1)
|
||||
return output
|
||||
|
||||
|
||||
@multi_stream_linear.register_fake
|
||||
def multi_stream_linear_fake(input, weight0, weight1):
|
||||
"""Fake implementation of multi_stream_linear."""
|
||||
output = torch.ops.aten.linear(input, weight0)
|
||||
return torch.ops.aten.linear(output, weight1)
|
||||
|
||||
|
||||
def replace_multi_stream_linear_with_aux_stream_wrapper(gm: GraphModule) -> Tuple[GraphModule, int]:
|
||||
"""Traverse ``gm`` and replace all ``auto_deploy::multi_stream_linear`` ops with ``aux_stream_wrapper``.
|
||||
|
||||
The replacement preserves the original args/kwargs of the node.
|
||||
After rewriting, the graph is cleaned and recompiled.
|
||||
|
||||
Args:
|
||||
gm: The FX graph module to transform.
|
||||
aux_stream_wrapper: A callable to replace the custom op with.
|
||||
|
||||
Returns:
|
||||
A tuple of (gm, num_replaced)
|
||||
"""
|
||||
graph = gm.graph
|
||||
num_replaced = 0
|
||||
|
||||
# Collect targets first to avoid mutating while iterating
|
||||
target_nodes: list[Node] = []
|
||||
target_nodes = [n for n in graph.nodes if is_op(n, torch.ops.auto_deploy.multi_stream_linear)]
|
||||
|
||||
for n in target_nodes:
|
||||
target_input_node = None
|
||||
for input_node in n.all_input_nodes:
|
||||
if len(input_node.users) > 1:
|
||||
target_input_node = input_node
|
||||
break
|
||||
if target_input_node is None:
|
||||
raise ValueError(f"Target input node not found for node {n}")
|
||||
with graph.inserting_before(target_input_node):
|
||||
kwargs = target_input_node.kwargs.copy()
|
||||
kwargs["device"] = torch.cuda.current_device()
|
||||
new_node = graph.call_function(
|
||||
record_event_wrapper,
|
||||
args=(target_input_node.target, *target_input_node.args),
|
||||
kwargs=kwargs,
|
||||
)
|
||||
target_input_node.replace_all_uses_with(new_node)
|
||||
graph.erase_node(target_input_node)
|
||||
with graph.inserting_after(n):
|
||||
new_node = graph.call_function(
|
||||
aux_stream_wrapper, args=(n.target, *n.args), kwargs=n.kwargs
|
||||
)
|
||||
n.replace_all_uses_with(new_node)
|
||||
graph.erase_node(n)
|
||||
num_replaced += 1
|
||||
|
||||
if num_replaced:
|
||||
canonicalize_graph(gm)
|
||||
|
||||
return gm, num_replaced
|
||||
|
||||
|
||||
class ParallelTwoLinear(nn.Module):
|
||||
def __init__(self, in_dim: int, out_dim: int):
|
||||
super().__init__()
|
||||
self.fc10 = nn.Linear(in_dim, in_dim)
|
||||
self.fc11 = nn.Linear(in_dim, out_dim)
|
||||
self.fc2 = nn.Linear(in_dim, out_dim)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
x = torch.nn.functional.relu(x)
|
||||
y0 = self.fc2(x)
|
||||
y1 = torch.ops.auto_deploy.multi_stream_linear(x, self.fc10.weight, self.fc11.weight)
|
||||
return y0 + y1
|
||||
|
||||
|
||||
def test_multi_stream_linear():
|
||||
in_dim, out_dim = 128, 256
|
||||
cuda_stream_manager.add_device(torch.cuda.current_device())
|
||||
model = (
|
||||
nn.Sequential(ParallelTwoLinear(in_dim, out_dim), ParallelTwoLinear(out_dim, out_dim))
|
||||
.eval()
|
||||
.to("cuda")
|
||||
)
|
||||
|
||||
# Example input used for export
|
||||
example_input = torch.randn(4, in_dim).to("cuda")
|
||||
|
||||
# Export the graph
|
||||
egm = torch.export.export(model, (example_input,))
|
||||
gm = egm.module()
|
||||
|
||||
test_x = torch.randn(4, in_dim).to("cuda")
|
||||
ref_output = model(test_x)
|
||||
|
||||
# pattern matching and replace
|
||||
gm, num_replaced = replace_multi_stream_linear_with_aux_stream_wrapper(gm)
|
||||
|
||||
assert num_replaced == 2
|
||||
y = gm(test_x)
|
||||
assert torch.allclose(y, ref_output)
|
||||
|
||||
static_x = torch.randn(4, in_dim).to("cuda")
|
||||
static_output = torch.randn(4, out_dim).to("cuda")
|
||||
|
||||
graph = torch.cuda.CUDAGraph()
|
||||
with torch.cuda.graph(graph):
|
||||
static_output.copy_(gm(static_x))
|
||||
|
||||
static_x.copy_(test_x)
|
||||
graph.replay()
|
||||
|
||||
assert torch.allclose(static_output, ref_output)
|
||||
@ -0,0 +1,230 @@
|
||||
"""Tests for multi-stream Q/KV projection parallelism in MLA attention.
|
||||
|
||||
The test builds a minimal mock model that mirrors the MLA fork pattern:
|
||||
a shared input feeds two parallel linear chains (one heavier "Q-like",
|
||||
one lighter "KV-like") whose outputs are combined with an add.
|
||||
|
||||
The transform should:
|
||||
1. Detect the fork point (shared input with 2+ linear users).
|
||||
2. Identify the lighter KV-like linear (no downstream linear within
|
||||
a few hops) vs. the heavier Q-like chain (has a downstream linear).
|
||||
3. Move the KV linear onto the auxiliary CUDA stream.
|
||||
4. Preserve numerical correctness.
|
||||
5. Be compatible with CUDA graph capture & replay.
|
||||
"""
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from tensorrt_llm._torch.auto_deploy.transform.library.multi_stream_attn import (
|
||||
_execute_kv_proj_in_aux_stream,
|
||||
_find_kv_proj_linears,
|
||||
)
|
||||
from tensorrt_llm._torch.auto_deploy.utils.multi_stream_utils import cuda_stream_manager
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers -- mock MLA-like module
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class MockMLABlock(nn.Module):
|
||||
"""Simplified MLA-like attention block with Q and KV projection chains.
|
||||
|
||||
Q chain (heavier): q_a_proj -> relu (stand-in for rms_norm) -> q_b_proj
|
||||
KV chain (lighter): kv_a_proj
|
||||
Merge: add(q_b_proj_output, kv_a_proj_output)
|
||||
|
||||
The layernorm at the output simulates the inter-layer distance in a real
|
||||
transformer (output projection, residual add, layernorm) so that the
|
||||
next layer's fork point is beyond the BFS max_depth from this layer's
|
||||
KV linear.
|
||||
"""
|
||||
|
||||
def __init__(self, hidden_dim: int, q_inner_dim: int, kv_out_dim: int):
|
||||
super().__init__()
|
||||
# Q chain: two linears with a non-linearity in between
|
||||
self.q_a_proj = nn.Linear(hidden_dim, q_inner_dim, bias=False)
|
||||
self.q_b_proj = nn.Linear(q_inner_dim, kv_out_dim, bias=False)
|
||||
# KV chain: single linear
|
||||
self.kv_a_proj = nn.Linear(hidden_dim, kv_out_dim, bias=False)
|
||||
# Inter-layer distance (layernorm + relu simulate residual + norm)
|
||||
self.layernorm = nn.LayerNorm(kv_out_dim)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
# Q chain: q_a_proj -> relu -> q_b_proj
|
||||
q = self.q_a_proj(x)
|
||||
q = torch.nn.functional.relu(q)
|
||||
q = self.q_b_proj(q)
|
||||
# KV chain: kv_a_proj
|
||||
kv = self.kv_a_proj(x)
|
||||
out = q + kv
|
||||
# Inter-layer distance to push next layer's linears beyond BFS depth
|
||||
return self.layernorm(torch.nn.functional.relu(out))
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _build_gm(model, example_input):
|
||||
"""Export *model* to an FX GraphModule."""
|
||||
egm = torch.export.export(model, (example_input,))
|
||||
return egm.module()
|
||||
|
||||
|
||||
def test_pattern_matching_single_block():
|
||||
"""The pattern matcher should find exactly one pair for a single MLA block."""
|
||||
model = MockMLABlock(128, 64, 128).eval().to("cuda")
|
||||
example_input = torch.randn(4, 128, device="cuda")
|
||||
gm = _build_gm(model, example_input)
|
||||
|
||||
pairs = _find_kv_proj_linears(gm)
|
||||
assert len(pairs) == 1, f"Expected 1 fork-point pair, got {len(pairs)}"
|
||||
|
||||
|
||||
def test_pattern_matching_multi_block():
|
||||
"""Multiple layers with sufficient inter-layer distance should all be matched."""
|
||||
hidden_dim, q_inner_dim, kv_out_dim = 128, 64, 128
|
||||
model = (
|
||||
nn.Sequential(
|
||||
MockMLABlock(hidden_dim, q_inner_dim, kv_out_dim),
|
||||
MockMLABlock(kv_out_dim, q_inner_dim, kv_out_dim),
|
||||
)
|
||||
.eval()
|
||||
.to("cuda")
|
||||
)
|
||||
example_input = torch.randn(4, hidden_dim, device="cuda")
|
||||
gm = _build_gm(model, example_input)
|
||||
|
||||
pairs = _find_kv_proj_linears(gm)
|
||||
assert len(pairs) == 2, f"Expected 2 fork-point pairs, got {len(pairs)}"
|
||||
|
||||
|
||||
def test_numerical_correctness():
|
||||
"""After the transform the GraphModule must produce the same output as the original model."""
|
||||
hidden_dim, q_inner_dim, kv_out_dim = 128, 64, 128
|
||||
cuda_stream_manager.add_device(torch.cuda.current_device())
|
||||
|
||||
model = MockMLABlock(hidden_dim, q_inner_dim, kv_out_dim).eval().to("cuda")
|
||||
example_input = torch.randn(4, hidden_dim, device="cuda")
|
||||
gm = _build_gm(model, example_input)
|
||||
|
||||
test_x = torch.randn(4, hidden_dim, device="cuda")
|
||||
ref_output = model(test_x)
|
||||
|
||||
gm, num_replaced = _execute_kv_proj_in_aux_stream(gm)
|
||||
|
||||
assert num_replaced == 1, f"Expected 1 replacement, got {num_replaced}"
|
||||
|
||||
y = gm(test_x)
|
||||
assert torch.allclose(y, ref_output, atol=1e-5), (
|
||||
f"Output mismatch: max diff = {(y - ref_output).abs().max().item()}"
|
||||
)
|
||||
|
||||
|
||||
def test_numerical_correctness_multi_block():
|
||||
"""Multi-block correctness test."""
|
||||
hidden_dim, q_inner_dim, kv_out_dim = 128, 64, 128
|
||||
cuda_stream_manager.add_device(torch.cuda.current_device())
|
||||
|
||||
model = (
|
||||
nn.Sequential(
|
||||
MockMLABlock(hidden_dim, q_inner_dim, kv_out_dim),
|
||||
MockMLABlock(kv_out_dim, q_inner_dim, kv_out_dim),
|
||||
)
|
||||
.eval()
|
||||
.to("cuda")
|
||||
)
|
||||
example_input = torch.randn(4, hidden_dim, device="cuda")
|
||||
gm = _build_gm(model, example_input)
|
||||
|
||||
test_x = torch.randn(4, hidden_dim, device="cuda")
|
||||
ref_output = model(test_x)
|
||||
|
||||
gm, num_replaced = _execute_kv_proj_in_aux_stream(gm)
|
||||
|
||||
assert num_replaced == 2, f"Expected 2 replacements, got {num_replaced}"
|
||||
|
||||
y = gm(test_x)
|
||||
assert torch.allclose(y, ref_output, atol=1e-5), (
|
||||
f"Output mismatch: max diff = {(y - ref_output).abs().max().item()}"
|
||||
)
|
||||
|
||||
|
||||
def test_cuda_graph_compatibility():
|
||||
"""The transformed GraphModule must work under CUDA graph capture and replay."""
|
||||
hidden_dim, q_inner_dim, kv_out_dim = 128, 64, 128
|
||||
cuda_stream_manager.add_device(torch.cuda.current_device())
|
||||
|
||||
model = MockMLABlock(hidden_dim, q_inner_dim, kv_out_dim).eval().to("cuda")
|
||||
example_input = torch.randn(4, hidden_dim, device="cuda")
|
||||
gm = _build_gm(model, example_input)
|
||||
|
||||
test_x = torch.randn(4, hidden_dim, device="cuda")
|
||||
ref_output = model(test_x)
|
||||
|
||||
gm, num_replaced = _execute_kv_proj_in_aux_stream(gm)
|
||||
assert num_replaced == 1
|
||||
|
||||
# Allocate static buffers for CUDA graph capture.
|
||||
static_x = torch.randn(4, hidden_dim, device="cuda")
|
||||
static_output = torch.randn(4, kv_out_dim, device="cuda")
|
||||
|
||||
# Warm up (required before capture).
|
||||
for _ in range(3):
|
||||
static_output.copy_(gm(static_x))
|
||||
|
||||
graph = torch.cuda.CUDAGraph()
|
||||
with torch.cuda.graph(graph):
|
||||
static_output.copy_(gm(static_x))
|
||||
|
||||
static_x.copy_(test_x)
|
||||
graph.replay()
|
||||
|
||||
assert torch.allclose(static_output, ref_output, atol=1e-5), (
|
||||
f"CUDA graph output mismatch: max diff = {(static_output - ref_output).abs().max().item()}"
|
||||
)
|
||||
|
||||
|
||||
def test_no_match_on_single_linear():
|
||||
"""A node with only one linear user should not be matched."""
|
||||
|
||||
class SingleLinear(nn.Module):
|
||||
def __init__(self, dim):
|
||||
super().__init__()
|
||||
self.fc = nn.Linear(dim, dim, bias=False)
|
||||
|
||||
def forward(self, x):
|
||||
return self.fc(x)
|
||||
|
||||
model = SingleLinear(64).eval().to("cuda")
|
||||
example_input = torch.randn(4, 64, device="cuda")
|
||||
gm = _build_gm(model, example_input)
|
||||
|
||||
pairs = _find_kv_proj_linears(gm)
|
||||
assert len(pairs) == 0, f"Expected 0 matches, got {len(pairs)}"
|
||||
|
||||
|
||||
def test_no_match_when_both_have_downstream_linear():
|
||||
"""When *both* branches have downstream linears the pattern should not match."""
|
||||
|
||||
class BothHeavy(nn.Module):
|
||||
def __init__(self, dim, inner):
|
||||
super().__init__()
|
||||
self.fc_a1 = nn.Linear(dim, inner, bias=False)
|
||||
self.fc_a2 = nn.Linear(inner, dim, bias=False)
|
||||
self.fc_b1 = nn.Linear(dim, inner, bias=False)
|
||||
self.fc_b2 = nn.Linear(inner, dim, bias=False)
|
||||
|
||||
def forward(self, x):
|
||||
a = self.fc_a2(torch.relu(self.fc_a1(x)))
|
||||
b = self.fc_b2(torch.relu(self.fc_b1(x)))
|
||||
return a + b
|
||||
|
||||
model = BothHeavy(64, 32).eval().to("cuda")
|
||||
example_input = torch.randn(4, 64, device="cuda")
|
||||
gm = _build_gm(model, example_input)
|
||||
|
||||
pairs = _find_kv_proj_linears(gm)
|
||||
assert len(pairs) == 0, f"Expected 0 matches, got {len(pairs)}"
|
||||
@ -0,0 +1,502 @@
|
||||
"""Tests for multi-stream MoE shared-expert transform across model architectures.
|
||||
|
||||
Verifies that ``_execute_shared_expert_in_aux_stream`` correctly identifies the
|
||||
shared-expert branch and moves it to the auxiliary CUDA stream for the MoE
|
||||
patterns used in DeepSeek V3, GLM4 MoE Lite, Mixtral, and Nemotron-H (with
|
||||
and without latent projections).
|
||||
|
||||
Architecture patterns tested:
|
||||
|
||||
**DeepSeek V3 / GLM4 MoE Lite** — Gated-MLP shared expert
|
||||
(gate_proj + up_proj → SiLU gate → down_proj).
|
||||
Routed MoE dispatched first; the shared expert consumes the ``identity`` (pre-MoE) hidden states.
|
||||
Merge: ``moe_out + shared_out``.
|
||||
|
||||
**Mixtral** — Pure routed MoE, *no* shared expert.
|
||||
The transform must produce zero matches (no-op).
|
||||
|
||||
**Nemotron-H (no latent)** — Simple-MLP shared expert
|
||||
(up_proj → ReLU² → down_proj).
|
||||
Shared expert dispatched first; routed MoE second.
|
||||
Merge: ``shared_out + routed_out``.
|
||||
|
||||
**Nemotron-H (with latent projection)** — Same shared expert as above,
|
||||
but the routed path wraps the MoE op with
|
||||
``fc1_latent_proj → MoE → fc2_latent_proj``.
|
||||
Tests that the BFS from MoE to the merge ``add`` traverses the extra
|
||||
projection nodes correctly.
|
||||
|
||||
Each architecture is tested for:
|
||||
1. Pattern matching — correct number of replacements.
|
||||
2. Graph structure — ``begin_aux``, ``end_aux``, ``wait_aux`` nodes present.
|
||||
3. Numerical correctness — output matches eager reference within tolerance.
|
||||
4. CUDA graph compatibility — capture + replay produces correct output.
|
||||
5. Multi-layer stacking — multiple MoE layers handled independently.
|
||||
"""
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from tensorrt_llm._torch.auto_deploy.transform.library.multi_stream_moe import (
|
||||
_execute_shared_expert_in_aux_stream,
|
||||
)
|
||||
from tensorrt_llm._torch.auto_deploy.utils.multi_stream_utils import (
|
||||
begin_aux_stream_passthrough,
|
||||
cuda_stream_manager,
|
||||
end_aux_stream_passthrough,
|
||||
wait_aux_stream_passthrough,
|
||||
)
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Mock fused-MoE custom op (distinct name to avoid conflicts with other tests)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@torch.library.custom_op("auto_deploy::mock_fused_moe_moe_test", mutates_args=())
|
||||
def mock_fused_moe(
|
||||
x: torch.Tensor,
|
||||
selected_experts: torch.Tensor,
|
||||
routing_weights: torch.Tensor,
|
||||
expert_weight: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
"""Mock fused MoE: a simple linear transform standing in for the real kernel."""
|
||||
return torch.ops.aten.linear(x, expert_weight)
|
||||
|
||||
|
||||
@mock_fused_moe.register_fake
|
||||
def _mock_fused_moe_fake(x, selected_experts, routing_weights, expert_weight):
|
||||
return torch.ops.aten.linear(x, expert_weight)
|
||||
|
||||
|
||||
_MOE_OPS = [torch.ops.auto_deploy.mock_fused_moe_moe_test]
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _build_gm(model, example_input):
|
||||
"""Export *model* to an FX ``GraphModule``."""
|
||||
return torch.export.export(model, (example_input,)).module()
|
||||
|
||||
|
||||
def _stream_targets(gm):
|
||||
"""Return the set of ``call_function`` targets present in *gm*."""
|
||||
return {n.target for n in gm.graph.nodes if n.op == "call_function"}
|
||||
|
||||
|
||||
def _assert_stream_nodes_present(gm):
|
||||
"""Assert that the three stream-management passthrough nodes are in the graph."""
|
||||
targets = _stream_targets(gm)
|
||||
assert begin_aux_stream_passthrough in targets, "begin_aux_stream_passthrough not in graph"
|
||||
assert end_aux_stream_passthrough in targets, "end_aux_stream_passthrough not in graph"
|
||||
assert wait_aux_stream_passthrough in targets, "wait_aux_stream_passthrough not in graph"
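
# Illustrative companion helper (name and usage are assumptions; the assertions above are
# what the tests rely on). It collects the call_function nodes bracketed by the begin/end
# passthroughs, i.e. the subgraph that is expected to run on the auxiliary stream after
# the transform.
def _aux_stream_call_functions(gm):
    region, inside = [], False
    for node in gm.graph.nodes:
        if node.op != "call_function":
            continue
        if node.target is begin_aux_stream_passthrough:
            inside = True
        elif node.target is end_aux_stream_passthrough:
            inside = False
        elif inside:
            region.append(node)
    return region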
|
||||
|
||||
|
||||
def _assert_numerical_correctness(gm, model, test_x, *, atol=1e-5):
|
||||
"""Assert that *gm* and *model* produce the same output on *test_x*."""
|
||||
ref = model(test_x)
|
||||
out = gm(test_x)
|
||||
assert torch.allclose(out, ref, atol=atol), (
|
||||
f"Output mismatch: max diff = {(out - ref).abs().max().item()}"
|
||||
)
|
||||
|
||||
|
||||
def _assert_cuda_graph_correctness(gm, model, test_x, *, atol=1e-5):
|
||||
"""Assert correctness under CUDA graph capture + replay."""
|
||||
ref = model(test_x)
|
||||
out_shape = ref.shape
|
||||
|
||||
static_x = torch.randn_like(test_x)
|
||||
static_out = torch.empty(out_shape, device="cuda", dtype=ref.dtype)
|
||||
|
||||
# Warm-up (required before capture).
|
||||
for _ in range(3):
|
||||
static_out.copy_(gm(static_x))
|
||||
|
||||
cuda_graph = torch.cuda.CUDAGraph()
|
||||
with torch.cuda.graph(cuda_graph):
|
||||
static_out.copy_(gm(static_x))
|
||||
|
||||
static_x.copy_(test_x)
|
||||
cuda_graph.replay()
|
||||
|
||||
assert torch.allclose(static_out, ref, atol=atol), (
|
||||
f"CUDA graph output mismatch: max diff = {(static_out - ref).abs().max().item()}"
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Mock modules — shared expert variants
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class _GatedMLP(nn.Module):
|
||||
"""DeepSeek / GLM4 shared expert: ``down_proj(silu(gate_proj(x)) * up_proj(x))``."""
|
||||
|
||||
def __init__(self, hidden_dim: int, intermediate_dim: int):
|
||||
super().__init__()
|
||||
self.gate_proj = nn.Linear(hidden_dim, intermediate_dim, bias=False)
|
||||
self.up_proj = nn.Linear(hidden_dim, intermediate_dim, bias=False)
|
||||
self.down_proj = nn.Linear(intermediate_dim, hidden_dim, bias=False)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
return self.down_proj(torch.nn.functional.silu(self.gate_proj(x)) * self.up_proj(x))
|
||||
|
||||
|
||||
class _SimpleMLP(nn.Module):
|
||||
"""Nemotron-H shared expert: ``down_proj(relu(up_proj(x)) ** 2)``."""
|
||||
|
||||
def __init__(self, hidden_dim: int, intermediate_dim: int):
|
||||
super().__init__()
|
||||
self.up_proj = nn.Linear(hidden_dim, intermediate_dim, bias=False)
|
||||
self.down_proj = nn.Linear(intermediate_dim, hidden_dim, bias=False)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
return self.down_proj(torch.relu(self.up_proj(x)) ** 2)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Mock MoE layer modules — one per architecture pattern
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class MockDeepSeekGLM4MoELayer(nn.Module):
|
||||
"""DeepSeek V3 / GLM4 MoE Lite pattern.
|
||||
|
||||
Graph topology::
|
||||
|
||||
hidden_states ─┬─ gate ─ topk ─────────────────────────┐
|
||||
├─ shared_experts (gated MLP) ─ shared_out │
|
||||
└─ mock_fused_moe ──────────── moe_out ─┘
|
||||
moe_out + shared_out → layernorm → out
|
||||
|
||||
Routed MoE is dispatched *before* the shared expert in graph order.
|
||||
The ``add`` has the routed output on the left.
|
||||
"""
|
||||
|
||||
def __init__(self, hidden_dim: int, intermediate_dim: int, num_experts: int = 8):
|
||||
super().__init__()
|
||||
self.gate = nn.Linear(hidden_dim, num_experts, bias=False)
|
||||
self.shared_experts = _GatedMLP(hidden_dim, intermediate_dim)
|
||||
self.expert_weight = nn.Parameter(torch.randn(hidden_dim, hidden_dim))
|
||||
self.layernorm = nn.LayerNorm(hidden_dim)
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
identity = hidden_states
|
||||
logits = self.gate(hidden_states)
|
||||
routing_weights, selected_experts = torch.topk(logits, k=2, dim=-1)
|
||||
|
||||
# Routed path first (matches DeepSeek / GLM4 dispatch order).
|
||||
moe_out = torch.ops.auto_deploy.mock_fused_moe_moe_test(
|
||||
hidden_states, selected_experts, routing_weights, self.expert_weight
|
||||
)
|
||||
# Shared expert on original input.
|
||||
shared_out = self.shared_experts(identity)
|
||||
|
||||
return self.layernorm(moe_out + shared_out)
|
||||
|
||||
|
||||
class MockMixtralMoELayer(nn.Module):
|
||||
"""Mixtral pattern — pure routed MoE, **no** shared expert.
|
||||
|
||||
The transform must return 0 replacements for this topology.
|
||||
"""
|
||||
|
||||
def __init__(self, hidden_dim: int, num_experts: int = 8):
|
||||
super().__init__()
|
||||
self.gate = nn.Linear(hidden_dim, num_experts, bias=False)
|
||||
self.expert_weight = nn.Parameter(torch.randn(hidden_dim, hidden_dim))
|
||||
self.layernorm = nn.LayerNorm(hidden_dim)
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
logits = self.gate(hidden_states)
|
||||
routing_weights, selected_experts = torch.topk(logits, k=2, dim=-1)
|
||||
moe_out = torch.ops.auto_deploy.mock_fused_moe_moe_test(
|
||||
hidden_states, selected_experts, routing_weights, self.expert_weight
|
||||
)
|
||||
return self.layernorm(moe_out)
|
||||
|
||||
|
||||
class MockNemotronHMoELayer(nn.Module):
|
||||
"""Nemotron-H pattern *without* latent projections.
|
||||
|
||||
Graph topology::
|
||||
|
||||
hidden_states ─┬─ gate ─ topk ─────────────────────────┐
|
||||
├─ shared_experts (simple MLP) ─ shared_out │
|
||||
└─ mock_fused_moe ──────────── moe_out ─┘
|
||||
shared_out + moe_out → layernorm → out
|
||||
|
||||
Shared expert is dispatched *before* the MoE in graph order.
|
||||
The ``add`` has the shared output on the left.
|
||||
"""
|
||||
|
||||
def __init__(self, hidden_dim: int, intermediate_dim: int, num_experts: int = 8):
|
||||
super().__init__()
|
||||
self.gate = nn.Linear(hidden_dim, num_experts, bias=False)
|
||||
self.shared_experts = _SimpleMLP(hidden_dim, intermediate_dim)
|
||||
self.expert_weight = nn.Parameter(torch.randn(hidden_dim, hidden_dim))
|
||||
self.layernorm = nn.LayerNorm(hidden_dim)
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
residuals = hidden_states
|
||||
logits = self.gate(hidden_states)
|
||||
routing_weights, selected_experts = torch.topk(logits, k=2, dim=-1)
|
||||
|
||||
# Shared expert dispatched first (matches Nemotron-H dispatch order).
|
||||
shared_out = self.shared_experts(residuals)
|
||||
# Routed path.
|
||||
moe_out = torch.ops.auto_deploy.mock_fused_moe_moe_test(
|
||||
hidden_states, selected_experts, routing_weights, self.expert_weight
|
||||
)
|
||||
|
||||
return self.layernorm(shared_out + moe_out)
|
||||
|
||||
|
||||
class MockNemotronHLatentMoELayer(nn.Module):
|
||||
"""Nemotron-H pattern *with* latent projections.
|
||||
|
||||
Graph topology::
|
||||
|
||||
hidden_states ─┬─ gate ─ topk ──────────────────────────────────────┐
|
||||
├─ shared_experts (simple MLP) ────────── shared_out │
|
||||
└─ fc1_latent ─ mock_fused_moe ─ fc2_latent ─ routed_out ┘
|
||||
shared_out + routed_out → ln → out
|
||||
|
||||
The latent projections add nodes between the MoE op and the merge ``add``,
|
||||
testing that the forward BFS from MoE correctly traverses extra projection
|
||||
nodes.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
hidden_dim: int,
|
||||
intermediate_dim: int,
|
||||
latent_dim: int,
|
||||
num_experts: int = 8,
|
||||
):
|
||||
super().__init__()
|
||||
self.gate = nn.Linear(hidden_dim, num_experts, bias=False)
|
||||
self.shared_experts = _SimpleMLP(hidden_dim, intermediate_dim)
|
||||
self.fc1_latent_proj = nn.Linear(hidden_dim, latent_dim, bias=False)
|
||||
self.fc2_latent_proj = nn.Linear(latent_dim, hidden_dim, bias=False)
|
||||
# Expert weight operates in latent space.
|
||||
self.expert_weight = nn.Parameter(torch.randn(latent_dim, latent_dim))
|
||||
self.layernorm = nn.LayerNorm(hidden_dim)
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
residuals = hidden_states
|
||||
logits = self.gate(hidden_states)
|
||||
routing_weights, selected_experts = torch.topk(logits, k=2, dim=-1)
|
||||
|
||||
# Shared expert dispatched first.
|
||||
shared_out = self.shared_experts(residuals)
|
||||
|
||||
# Latent projection → MoE → back-projection.
|
||||
x_latent = self.fc1_latent_proj(hidden_states)
|
||||
moe_out = torch.ops.auto_deploy.mock_fused_moe_moe_test(
|
||||
x_latent, selected_experts, routing_weights, self.expert_weight
|
||||
)
|
||||
routed_out = self.fc2_latent_proj(moe_out)
|
||||
|
||||
return self.layernorm(shared_out + routed_out)
|
||||
|
||||
|
||||
# ===================================================================
|
||||
# Tests — DeepSeek V3 / GLM4 MoE Lite (gated-MLP shared expert)
|
||||
# ===================================================================
|
||||
|
||||
|
||||
def test_deepseek_glm4_pattern_and_correctness():
|
||||
"""Single-layer: pattern match + graph structure + numerical correctness."""
|
||||
hidden_dim, intermediate_dim = 128, 256
|
||||
cuda_stream_manager.add_device(torch.cuda.current_device())
|
||||
|
||||
model = MockDeepSeekGLM4MoELayer(hidden_dim, intermediate_dim).eval().to("cuda")
|
||||
example = torch.randn(4, hidden_dim, device="cuda")
|
||||
gm = _build_gm(model, example)
|
||||
|
||||
gm, num = _execute_shared_expert_in_aux_stream(gm, _MOE_OPS)
|
||||
|
||||
assert num == 1, f"Expected 1 replacement, got {num}"
|
||||
_assert_stream_nodes_present(gm)
|
||||
_assert_numerical_correctness(gm, model, torch.randn(4, hidden_dim, device="cuda"))
|
||||
|
||||
|
||||
def test_deepseek_glm4_cuda_graph():
|
||||
"""CUDA graph capture + replay for DeepSeek / GLM4 pattern."""
|
||||
hidden_dim, intermediate_dim = 128, 256
|
||||
cuda_stream_manager.add_device(torch.cuda.current_device())
|
||||
|
||||
model = MockDeepSeekGLM4MoELayer(hidden_dim, intermediate_dim).eval().to("cuda")
|
||||
example = torch.randn(4, hidden_dim, device="cuda")
|
||||
gm = _build_gm(model, example)
|
||||
gm, num = _execute_shared_expert_in_aux_stream(gm, _MOE_OPS)
|
||||
assert num == 1
|
||||
|
||||
_assert_cuda_graph_correctness(gm, model, torch.randn(4, hidden_dim, device="cuda"))
|
||||
|
||||
|
||||
def test_deepseek_glm4_multi_layer():
|
||||
"""Two stacked DeepSeek/GLM4 MoE layers — both should be transformed."""
|
||||
hidden_dim, intermediate_dim = 128, 256
|
||||
cuda_stream_manager.add_device(torch.cuda.current_device())
|
||||
|
||||
model = (
|
||||
nn.Sequential(
|
||||
MockDeepSeekGLM4MoELayer(hidden_dim, intermediate_dim),
|
||||
MockDeepSeekGLM4MoELayer(hidden_dim, intermediate_dim),
|
||||
)
|
||||
.eval()
|
||||
.to("cuda")
|
||||
)
|
||||
example = torch.randn(4, hidden_dim, device="cuda")
|
||||
gm = _build_gm(model, example)
|
||||
|
||||
gm, num = _execute_shared_expert_in_aux_stream(gm, _MOE_OPS)
|
||||
|
||||
assert num == 2, f"Expected 2 replacements, got {num}"
|
||||
_assert_numerical_correctness(gm, model, torch.randn(4, hidden_dim, device="cuda"))
|
||||
|
||||
|
||||
# ===================================================================
|
||||
# Tests — Mixtral (no shared expert → no-op)
|
||||
# ===================================================================
|
||||
|
||||
|
||||
def test_mixtral_no_shared_expert_no_match():
|
||||
"""Mixtral has no shared expert; the transform must produce zero matches."""
|
||||
hidden_dim = 128
|
||||
cuda_stream_manager.add_device(torch.cuda.current_device())
|
||||
|
||||
model = MockMixtralMoELayer(hidden_dim).eval().to("cuda")
|
||||
example = torch.randn(4, hidden_dim, device="cuda")
|
||||
gm = _build_gm(model, example)
|
||||
|
||||
gm, num = _execute_shared_expert_in_aux_stream(gm, _MOE_OPS)
|
||||
|
||||
assert num == 0, f"Expected 0 replacements for Mixtral (no shared expert), got {num}"
|
||||
|
||||
# Graph should NOT contain any stream-management nodes.
|
||||
targets = _stream_targets(gm)
|
||||
assert begin_aux_stream_passthrough not in targets
|
||||
assert end_aux_stream_passthrough not in targets
|
||||
assert wait_aux_stream_passthrough not in targets
|
||||
|
||||
# Numerical correctness should still hold (graph unchanged).
|
||||
_assert_numerical_correctness(gm, model, torch.randn(4, hidden_dim, device="cuda"))
|
||||
|
||||
|
||||
# ===================================================================
|
||||
# Tests — Nemotron-H without latent projections
|
||||
# ===================================================================
|
||||
|
||||
|
||||
def test_nemotron_h_pattern_and_correctness():
|
||||
"""Single-layer Nemotron-H (no latent): pattern + graph + correctness."""
|
||||
hidden_dim, intermediate_dim = 128, 256
|
||||
cuda_stream_manager.add_device(torch.cuda.current_device())
|
||||
|
||||
model = MockNemotronHMoELayer(hidden_dim, intermediate_dim).eval().to("cuda")
|
||||
example = torch.randn(4, hidden_dim, device="cuda")
|
||||
gm = _build_gm(model, example)
|
||||
|
||||
gm, num = _execute_shared_expert_in_aux_stream(gm, _MOE_OPS)
|
||||
|
||||
assert num == 1, f"Expected 1 replacement, got {num}"
|
||||
_assert_stream_nodes_present(gm)
|
||||
_assert_numerical_correctness(gm, model, torch.randn(4, hidden_dim, device="cuda"))
|
||||
|
||||
|
||||
def test_nemotron_h_cuda_graph():
|
||||
"""CUDA graph capture + replay for Nemotron-H (no latent) pattern."""
|
||||
hidden_dim, intermediate_dim = 128, 256
|
||||
cuda_stream_manager.add_device(torch.cuda.current_device())
|
||||
|
||||
model = MockNemotronHMoELayer(hidden_dim, intermediate_dim).eval().to("cuda")
|
||||
example = torch.randn(4, hidden_dim, device="cuda")
|
||||
gm = _build_gm(model, example)
|
||||
gm, num = _execute_shared_expert_in_aux_stream(gm, _MOE_OPS)
|
||||
assert num == 1
|
||||
|
||||
_assert_cuda_graph_correctness(gm, model, torch.randn(4, hidden_dim, device="cuda"))
|
||||
|
||||
|
||||
def test_nemotron_h_multi_layer():
|
||||
"""Two stacked Nemotron-H (no latent) layers — both should be transformed."""
|
||||
hidden_dim, intermediate_dim = 128, 256
|
||||
cuda_stream_manager.add_device(torch.cuda.current_device())
|
||||
|
||||
model = (
|
||||
nn.Sequential(
|
||||
MockNemotronHMoELayer(hidden_dim, intermediate_dim),
|
||||
MockNemotronHMoELayer(hidden_dim, intermediate_dim),
|
||||
)
|
||||
.eval()
|
||||
.to("cuda")
|
||||
)
|
||||
example = torch.randn(4, hidden_dim, device="cuda")
|
||||
gm = _build_gm(model, example)
|
||||
|
||||
gm, num = _execute_shared_expert_in_aux_stream(gm, _MOE_OPS)
|
||||
|
||||
assert num == 2, f"Expected 2 replacements, got {num}"
|
||||
_assert_numerical_correctness(gm, model, torch.randn(4, hidden_dim, device="cuda"))
|
||||
|
||||
|
||||
# ===================================================================
|
||||
# Tests — Nemotron-H with latent projections
|
||||
# ===================================================================
|
||||
|
||||
|
||||
def test_nemotron_h_latent_pattern_and_correctness():
|
||||
"""Single-layer Nemotron-H (with latent): pattern + graph + correctness."""
|
||||
hidden_dim, intermediate_dim, latent_dim = 128, 256, 64
|
||||
cuda_stream_manager.add_device(torch.cuda.current_device())
|
||||
|
||||
model = MockNemotronHLatentMoELayer(hidden_dim, intermediate_dim, latent_dim).eval().to("cuda")
|
||||
example = torch.randn(4, hidden_dim, device="cuda")
|
||||
gm = _build_gm(model, example)
|
||||
|
||||
gm, num = _execute_shared_expert_in_aux_stream(gm, _MOE_OPS)
|
||||
|
||||
assert num == 1, f"Expected 1 replacement, got {num}"
|
||||
_assert_stream_nodes_present(gm)
|
||||
_assert_numerical_correctness(gm, model, torch.randn(4, hidden_dim, device="cuda"))
|
||||
|
||||
|
||||
def test_nemotron_h_latent_cuda_graph():
|
||||
"""CUDA graph capture + replay for Nemotron-H (with latent) pattern."""
|
||||
hidden_dim, intermediate_dim, latent_dim = 128, 256, 64
|
||||
cuda_stream_manager.add_device(torch.cuda.current_device())
|
||||
|
||||
model = MockNemotronHLatentMoELayer(hidden_dim, intermediate_dim, latent_dim).eval().to("cuda")
|
||||
example = torch.randn(4, hidden_dim, device="cuda")
|
||||
gm = _build_gm(model, example)
|
||||
gm, num = _execute_shared_expert_in_aux_stream(gm, _MOE_OPS)
|
||||
assert num == 1
|
||||
|
||||
_assert_cuda_graph_correctness(gm, model, torch.randn(4, hidden_dim, device="cuda"))
|
||||
|
||||
|
||||
def test_nemotron_h_latent_multi_layer():
|
||||
"""Two stacked Nemotron-H (with latent) layers — both should be transformed."""
|
||||
hidden_dim, intermediate_dim, latent_dim = 128, 256, 64
|
||||
cuda_stream_manager.add_device(torch.cuda.current_device())
|
||||
|
||||
model = (
|
||||
nn.Sequential(
|
||||
MockNemotronHLatentMoELayer(hidden_dim, intermediate_dim, latent_dim),
|
||||
MockNemotronHLatentMoELayer(hidden_dim, intermediate_dim, latent_dim),
|
||||
)
|
||||
.eval()
|
||||
.to("cuda")
|
||||
)
|
||||
example = torch.randn(4, hidden_dim, device="cuda")
|
||||
gm = _build_gm(model, example)
|
||||
|
||||
gm, num = _execute_shared_expert_in_aux_stream(gm, _MOE_OPS)
|
||||
|
||||
assert num == 2, f"Expected 2 replacements, got {num}"
|
||||
_assert_numerical_correctness(gm, model, torch.randn(4, hidden_dim, device="cuda"))
|
||||
@ -0,0 +1,188 @@
|
||||
import pytest
|
||||
import torch
|
||||
from torch.export import Dim
|
||||
|
||||
from tensorrt_llm._torch.auto_deploy.custom_ops.linear.swiglu import * # noqa
|
||||
from tensorrt_llm._torch.auto_deploy.export import torch_export_to_gm
|
||||
from tensorrt_llm._torch.auto_deploy.transform.optimizer import InferenceOptimizer
|
||||
from tensorrt_llm._torch.auto_deploy.utils.node_utils import is_op
|
||||
|
||||
|
||||
class SwiGLUMLP(torch.nn.Module):
|
||||
"""SwiGLU MLP module: silu(x @ gate.T) * (x @ up.T) @ down.T"""
|
||||
|
||||
def __init__(self, hidden_size: int, intermediate_size: int):
|
||||
super().__init__()
|
||||
self.gate_proj = torch.nn.Linear(hidden_size, intermediate_size, bias=False)
|
||||
self.up_proj = torch.nn.Linear(hidden_size, intermediate_size, bias=False)
|
||||
self.down_proj = torch.nn.Linear(intermediate_size, hidden_size, bias=False)
|
||||
|
||||
def forward(self, x):
|
||||
return self.down_proj(torch.nn.functional.silu(self.gate_proj(x)) * self.up_proj(x))
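
# Functional restatement of the module above (illustrative; the helper name is an
# assumption). This is the computation the SwiGLU pattern matcher is meant to recognize:
# down(silu(x @ W_gate.T) * (x @ W_up.T)).
def _swiglu_reference(x, w_gate, w_up, w_down):
    gated = torch.nn.functional.silu(torch.nn.functional.linear(x, w_gate))
    return torch.nn.functional.linear(gated * torch.nn.functional.linear(x, w_up), w_down)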
|
||||
|
||||
|
||||
class SwiGLUMLPWithBias(torch.nn.Module):
|
||||
"""SwiGLU MLP module with biases."""
|
||||
|
||||
def __init__(self, hidden_size: int, intermediate_size: int):
|
||||
super().__init__()
|
||||
self.gate_proj = torch.nn.Linear(hidden_size, intermediate_size, bias=True)
|
||||
self.up_proj = torch.nn.Linear(hidden_size, intermediate_size, bias=True)
|
||||
self.down_proj = torch.nn.Linear(intermediate_size, hidden_size, bias=True)
|
||||
|
||||
def forward(self, x):
|
||||
return self.down_proj(torch.nn.functional.silu(self.gate_proj(x)) * self.up_proj(x))
|
||||
|
||||
|
||||
class SwiGLUTestModel(torch.nn.Module):
|
||||
"""Test model with SwiGLU MLP sandwiched between linear layers."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
hidden_size: int = 256,
|
||||
intermediate_size: int = 512,
|
||||
with_bias: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
self.linear1 = torch.nn.Linear(hidden_size, hidden_size, device="cuda", dtype=torch.float16)
|
||||
if with_bias:
|
||||
self.mlp = SwiGLUMLPWithBias(hidden_size, intermediate_size)
|
||||
else:
|
||||
self.mlp = SwiGLUMLP(hidden_size, intermediate_size)
|
||||
self.mlp = self.mlp.to(device="cuda", dtype=torch.float16)
|
||||
self.linear2 = torch.nn.Linear(hidden_size, hidden_size, device="cuda", dtype=torch.float16)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.linear1(x)
|
||||
x = self.mlp(x)
|
||||
x = self.linear2(x)
|
||||
return x
|
||||
|
||||
|
||||
class SwiGLUTestModelMultipleMLP(torch.nn.Module):
|
||||
"""Test model with multiple SwiGLU MLPs to test multiple pattern matches."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
hidden_size: int = 256,
|
||||
intermediate_size: int = 512,
|
||||
num_layers: int = 2,
|
||||
):
|
||||
super().__init__()
|
||||
self.layers = torch.nn.ModuleList()
|
||||
for _ in range(num_layers):
|
||||
self.layers.append(
|
||||
torch.nn.ModuleDict(
|
||||
{
|
||||
"linear": torch.nn.Linear(
|
||||
hidden_size, hidden_size, device="cuda", dtype=torch.float16
|
||||
),
|
||||
"mlp": SwiGLUMLP(hidden_size, intermediate_size).to(
|
||||
device="cuda", dtype=torch.float16
|
||||
),
|
||||
}
|
||||
)
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
for layer in self.layers:
|
||||
x = layer["linear"](x)
|
||||
x = layer["mlp"](x)
|
||||
return x
|
||||
|
||||
|
||||
def _run_fusion_test(model, expected_op, expected_num_matches=1):
|
||||
"""Run the SwiGLU fusion test.
|
||||
|
||||
Args:
|
||||
model: The test model to transform.
|
||||
expected_op: The expected fused op to find in the transformed graph.
|
||||
expected_num_matches: Expected number of fused ops.
|
||||
"""
|
||||
x = torch.randn(2, 256, device="cuda", dtype=torch.float16)
|
||||
dynamic_shapes = {0: Dim.DYNAMIC}
|
||||
gm = torch_export_to_gm(model, args=(x,), dynamic_shapes=(dynamic_shapes,), clone=True)
|
||||
|
||||
# Apply transforms
|
||||
gm_transformed = InferenceOptimizer(
|
||||
None,
|
||||
{
|
||||
"match_swiglu_pattern": {
|
||||
"stage": "pattern_matcher",
|
||||
},
|
||||
"fuse_swiglu": {
|
||||
"stage": "post_load_fusion",
|
||||
"enabled": True,
|
||||
},
|
||||
},
|
||||
)(None, gm)
|
||||
|
||||
# Move to CUDA if needed
|
||||
gm_transformed = gm_transformed.to("cuda")
|
||||
|
||||
# Check that the expected op is present
|
||||
count = sum(1 for n in gm_transformed.graph.nodes if is_op(n, expected_op))
|
||||
assert count == expected_num_matches, (
|
||||
f"Expected {expected_num_matches} {expected_op} ops, got {count}"
|
||||
)
|
||||
|
||||
# Verify numerical correctness
|
||||
y_transformed = gm_transformed(x)
|
||||
y_model = model(x)
|
||||
torch.testing.assert_close(y_transformed, y_model, atol=1e-2, rtol=1e-2)
|
||||
|
||||
# Test with a different batch size
|
||||
new_input = torch.randn(4, 256, device="cuda", dtype=torch.float16)
|
||||
y_transformed_2 = gm_transformed(new_input)
|
||||
y_model_2 = model(new_input)
|
||||
torch.testing.assert_close(y_transformed_2, y_model_2, atol=1e-2, rtol=1e-2)
|
||||
|
||||
|
||||
def test_swiglu_fusion_basic():
|
||||
"""Test basic SwiGLU fusion without biases."""
|
||||
model = SwiGLUTestModel(with_bias=False)
|
||||
_run_fusion_test(model, torch.ops.auto_deploy.fused_swiglu_mlp.default)
|
||||
|
||||
|
||||
def test_swiglu_fusion_with_bias():
|
||||
"""Test SwiGLU fusion with biases."""
|
||||
model = SwiGLUTestModel(with_bias=True)
|
||||
_run_fusion_test(model, torch.ops.auto_deploy.fused_swiglu_mlp.default)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("num_layers", [2, 3])
|
||||
def test_swiglu_fusion_multiple_layers(num_layers):
|
||||
"""Test that multiple SwiGLU patterns are fused correctly."""
|
||||
model = SwiGLUTestModelMultipleMLP(num_layers=num_layers)
|
||||
_run_fusion_test(
|
||||
model, torch.ops.auto_deploy.fused_swiglu_mlp.default, expected_num_matches=num_layers
|
||||
)
|
||||
|
||||
|
||||
def test_swiglu_pattern_match_only():
|
||||
"""Test pattern matching stage only (without fusion)."""
|
||||
model = SwiGLUTestModel()
|
||||
x = torch.randn(2, 256, device="cuda", dtype=torch.float16)
|
||||
dynamic_shapes = {0: Dim.DYNAMIC}
|
||||
gm = torch_export_to_gm(model, args=(x,), dynamic_shapes=(dynamic_shapes,), clone=True)
|
||||
|
||||
# Only run pattern matching, not fusion
|
||||
gm_matched = InferenceOptimizer(
|
||||
None,
|
||||
{
|
||||
"match_swiglu_pattern": {
|
||||
"stage": "pattern_matcher",
|
||||
},
|
||||
},
|
||||
)(None, gm)
|
||||
|
||||
# Check that the intermediate op is present
|
||||
has_swiglu_op = any(
|
||||
is_op(n, torch.ops.auto_deploy.torch_swiglu_mlp.default) for n in gm_matched.graph.nodes
|
||||
)
|
||||
assert has_swiglu_op, "Pattern matcher should produce torch_swiglu_mlp op"
|
||||
|
||||
# Verify numerical correctness
|
||||
y_matched = gm_matched(x)
|
||||
y_model = model(x)
|
||||
torch.testing.assert_close(y_matched, y_model, atol=1e-3, rtol=1e-3)
|
||||
@ -1,14 +1,26 @@
|
||||
import operator
|
||||
|
||||
import torch
|
||||
from torch.export import Dim
|
||||
|
||||
from tensorrt_llm._torch.auto_deploy.custom_ops.normalization.flashinfer_fused_add_rms_norm import * # noqa
|
||||
from tensorrt_llm._torch.auto_deploy.custom_ops.normalization.flashinfer_fused_add_rms_norm import ( # noqa
|
||||
flashinfer_fused_add_rms_norm,
|
||||
)
|
||||
from tensorrt_llm._torch.auto_deploy.custom_ops.normalization.rms_norm import * # noqa
|
||||
from tensorrt_llm._torch.auto_deploy.export import torch_export_to_gm
|
||||
from tensorrt_llm._torch.auto_deploy.transform.interface import TransformConfig
|
||||
from tensorrt_llm._torch.auto_deploy.transform.library.fused_add_rms_norm import FuseAddRMSNorm
|
||||
from tensorrt_llm._torch.auto_deploy.transform.optimizer import InferenceOptimizer
|
||||
from tensorrt_llm._torch.auto_deploy.utils.node_utils import is_op
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Models
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class AddCastNormModel(torch.nn.Module):
|
||||
"""Pattern 1: add + cast(to.dtype) + rms_norm."""
|
||||
|
||||
class TestModel(torch.nn.Module):
|
||||
def __init__(self, hidden_size=128, eps=1e-5):
|
||||
super().__init__()
|
||||
self.weight = torch.nn.Parameter(
|
||||
@ -23,55 +35,263 @@ class TestModel(torch.nn.Module):
|
||||
return norm, added
|
||||
|
||||
|
||||
def _run_test(model):
|
||||
# The replacement uses flashinfer_fused_add_rms_norm python wrapper which calls the inplace op
|
||||
# auto_deploy::flashinfer_fused_add_rms_norm_inplace
|
||||
op = torch.ops.auto_deploy.flashinfer_fused_add_rms_norm_inplace
|
||||
class AddNormModel(torch.nn.Module):
|
||||
"""Pattern 2: add + rms_norm (no intermediate cast)."""
|
||||
|
||||
def checker(gm):
|
||||
return any(is_op(n, op) for n in gm.graph.nodes)
|
||||
def __init__(self, hidden_size=128, eps=1e-5):
|
||||
super().__init__()
|
||||
self.weight = torch.nn.Parameter(
|
||||
torch.ones(hidden_size, device="cuda", dtype=torch.bfloat16)
|
||||
)
|
||||
self.eps = eps
|
||||
|
||||
bsz, seq_len, hidden = 2, 8, 128
|
||||
# Inputs should be bfloat16
|
||||
x = torch.randn(bsz, seq_len, hidden, device="cuda", dtype=torch.bfloat16)
|
||||
residual = torch.randn(bsz, seq_len, hidden, device="cuda", dtype=torch.bfloat16)
|
||||
def forward(self, x, residual):
|
||||
added = x + residual
|
||||
norm = torch.ops.auto_deploy.flashinfer_rms_norm(added, self.weight, self.eps)
|
||||
return norm, added
|
||||
|
||||
# Dynamic shapes
|
||||
dyn_batch_size = Dim.DYNAMIC
|
||||
ds_x = {0: dyn_batch_size}
|
||||
ds_res = {0: dyn_batch_size}
|
||||
|
||||
gm = torch_export_to_gm(model, args=(x, residual), dynamic_shapes=(ds_x, ds_res), clone=True)
|
||||
class MultiUserModel(torch.nn.Module):
|
||||
"""Both add and rms_norm outputs have multiple users (DeepSeek V3 MoE pattern).
|
||||
|
||||
gm_transformed = InferenceOptimizer(
|
||||
add_result has 2 users: rms_norm + next residual add
|
||||
norm_result has 2 users: linear1 + linear2
|
||||
"""
|
||||
|
||||
def __init__(self, hidden_size=128, eps=1e-5):
|
||||
super().__init__()
|
||||
self.weight = torch.nn.Parameter(
|
||||
torch.ones(hidden_size, device="cuda", dtype=torch.bfloat16)
|
||||
)
|
||||
self.linear1 = torch.nn.Linear(
|
||||
hidden_size, hidden_size, bias=False, device="cuda", dtype=torch.bfloat16
|
||||
)
|
||||
self.linear2 = torch.nn.Linear(
|
||||
hidden_size, hidden_size, bias=False, device="cuda", dtype=torch.bfloat16
|
||||
)
|
||||
self.eps = eps
|
||||
|
||||
def forward(self, residual, attn_output, moe_output):
|
||||
# add with 2 users (norm + next_add)
|
||||
add_result = residual + attn_output
|
||||
# rms_norm with 2 users (linear1, linear2)
|
||||
norm_result = torch.ops.auto_deploy.flashinfer_rms_norm(add_result, self.weight, self.eps)
|
||||
out1 = self.linear1(norm_result)
|
||||
out2 = self.linear2(norm_result)
|
||||
combined = out1 + out2
|
||||
# add_result also feeds into next residual add
|
||||
next_residual = add_result + moe_output
|
||||
return combined, next_residual
|
||||
|
||||
|
||||
class ChainedModel(torch.nn.Module):
|
||||
"""Two consecutive add+norm pairs sharing residual (like transformer layers).
|
||||
|
||||
Layer 1: add1 = embed + attn_out, norm1 = rms_norm(add1) -- add1 has 2 users
|
||||
Layer 2: add2 = add1 + mlp_out, norm2 = rms_norm(add2) -- add2 has 2 users
|
||||
"""
|
||||
|
||||
def __init__(self, hidden_size=128, eps=1e-5):
|
||||
super().__init__()
|
||||
self.weight1 = torch.nn.Parameter(
|
||||
torch.ones(hidden_size, device="cuda", dtype=torch.bfloat16)
|
||||
)
|
||||
self.weight2 = torch.nn.Parameter(
|
||||
torch.ones(hidden_size, device="cuda", dtype=torch.bfloat16)
|
||||
)
|
||||
self.linear = torch.nn.Linear(
|
||||
hidden_size, hidden_size, bias=False, device="cuda", dtype=torch.bfloat16
|
||||
)
|
||||
self.eps = eps
|
||||
|
||||
def forward(self, embed, attn_out, mlp_out):
|
||||
add1 = embed + attn_out
|
||||
norm1 = torch.ops.auto_deploy.flashinfer_rms_norm(add1, self.weight1, self.eps)
|
||||
branch1 = self.linear(norm1)
|
||||
|
||||
add2 = add1 + mlp_out
|
||||
norm2 = torch.ops.auto_deploy.flashinfer_rms_norm(add2, self.weight2, self.eps)
|
||||
branch2 = self.linear(norm2)
|
||||
|
||||
return branch1 + branch2, add2
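
# Eager restatement of the add + RMSNorm pair that each model above builds (illustrative;
# assumes ``flashinfer_rms_norm`` follows standard RMSNorm semantics). Both the residual
# sum and the normalized tensor are returned, which is why the fused replacement has to
# expose two outputs.
def _add_rms_norm_reference(x, residual, weight, eps):
    added = x + residual
    variance = added.to(torch.float32).pow(2).mean(-1, keepdim=True)
    normed = (added.to(torch.float32) * torch.rsqrt(variance + eps)).to(added.dtype)
    return normed * weight, added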
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _count_fused_ops(gm):
|
||||
"""Count flashinfer_fused_add_rms_norm wrapper calls in the graph."""
|
||||
return sum(
|
||||
1
|
||||
for n in gm.graph.nodes
|
||||
if n.op == "call_function" and n.target is flashinfer_fused_add_rms_norm
|
||||
)
|
||||
|
||||
|
||||
def _count_rms_norm_ops(gm):
|
||||
"""Count flashinfer_rms_norm calls in the graph."""
|
||||
return sum(1 for n in gm.graph.nodes if is_op(n, torch.ops.auto_deploy.flashinfer_rms_norm))
|
||||
|
||||
|
||||
def _count_add_ops(gm):
|
||||
"""Count aten.add.Tensor calls in the graph."""
|
||||
return sum(1 for n in gm.graph.nodes if is_op(n, torch.ops.aten.add.Tensor))
|
||||
|
||||
|
||||
def _export_model(model, *inputs, dynamic_dim0=True):
|
||||
"""Export a model to a GraphModule, optionally with a dynamic batch dimension."""
|
||||
if dynamic_dim0:
|
||||
dyn = Dim.DYNAMIC
|
||||
ds = tuple({0: dyn} for _ in inputs)
|
||||
else:
|
||||
ds = None
|
||||
return torch_export_to_gm(model, args=inputs, dynamic_shapes=ds, clone=True)
|
||||
|
||||
|
||||
def _apply_transform(gm):
|
||||
"""Apply fuse_add_rms_norm via InferenceOptimizer (integration-style)."""
|
||||
return InferenceOptimizer(
|
||||
None,
|
||||
{
|
||||
"fuse_add_rms_norm": {
|
||||
"stage": "post_load_fusion",
|
||||
},
|
||||
},
|
||||
{"fuse_add_rms_norm": {"stage": "post_load_fusion"}},
|
||||
)(None, gm)
|
||||
|
||||
# Check if transform happened
|
||||
if not checker(gm_transformed):
|
||||
raise AssertionError(
|
||||
"flashinfer_fused_add_rms_norm_inplace op not found in transformed graph"
|
||||
)
|
||||
|
||||
# Validation
|
||||
# Clone inputs because the fused op is inplace
|
||||
x_in = x.clone()
|
||||
res_in = residual.clone()
|
||||
|
||||
# The fused op is inplace, so inputs x_in and res_in will be modified.
|
||||
# gm_transformed returns (x_in, res_in) which are the modified tensors.
|
||||
y_transformed = gm_transformed(x_in, res_in)
|
||||
|
||||
y_model = model(x.clone(), residual.clone())
|
||||
torch.testing.assert_close(y_transformed[0], y_model[0], atol=1e-2, rtol=1e-2)
|
||||
torch.testing.assert_close(y_transformed[1], y_model[1], atol=1e-2, rtol=1e-2)
|
||||
def _apply_transform_direct(gm):
|
||||
"""Apply the transform directly (unit-test style)."""
|
||||
config = TransformConfig(stage="post_load_fusion")
|
||||
transform = FuseAddRMSNorm(config=config)
|
||||
gm, info = transform._apply(gm, None, None, None)
|
||||
return gm, info
|
||||
|
||||
|
||||
def test_fuse_add_rms_norm():
|
||||
model = TestModel()
|
||||
_run_test(model)
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_fuse_add_cast_rms_norm():
|
||||
"""Original test: add + cast(bf16) + rms_norm → fused op."""
|
||||
model = AddCastNormModel()
|
||||
bsz, seq_len, hidden = 2, 8, 128
|
||||
x = torch.randn(bsz, seq_len, hidden, device="cuda", dtype=torch.bfloat16)
|
||||
residual = torch.randn_like(x)
|
||||
|
||||
gm = _export_model(model, x, residual)
|
||||
gm_t = _apply_transform(gm)
|
||||
|
||||
# Structure check
|
||||
assert _count_fused_ops(gm_t) >= 1, "fused op not found in graph"
|
||||
assert _count_rms_norm_ops(gm_t) == 0, "unfused rms_norm still in graph"
|
||||
|
||||
# Numerical check
|
||||
y_fused = gm_t(x.clone(), residual.clone())
|
||||
y_ref = model(x.clone(), residual.clone())
|
||||
torch.testing.assert_close(y_fused[0], y_ref[0], atol=1e-2, rtol=1e-2)
|
||||
torch.testing.assert_close(y_fused[1], y_ref[1], atol=1e-2, rtol=1e-2)
|
||||
|
||||
|
||||
def test_fuse_add_rms_norm_no_cast():
|
||||
"""Pattern 2: add + rms_norm (no cast) → fused op."""
|
||||
model = AddNormModel()
|
||||
bsz, seq_len, hidden = 2, 8, 128
|
||||
x = torch.randn(bsz, seq_len, hidden, device="cuda", dtype=torch.bfloat16)
|
||||
residual = torch.randn_like(x)
|
||||
|
||||
gm = _export_model(model, x, residual)
|
||||
gm_t = _apply_transform(gm)
|
||||
|
||||
# Structure check
|
||||
assert _count_fused_ops(gm_t) >= 1, "fused op not found in graph"
|
||||
assert _count_rms_norm_ops(gm_t) == 0, "unfused rms_norm still in graph"
|
||||
|
||||
# Numerical check
|
||||
y_fused = gm_t(x.clone(), residual.clone())
|
||||
y_ref = model(x.clone(), residual.clone())
|
||||
torch.testing.assert_close(y_fused[0], y_ref[0], atol=1e-2, rtol=1e-2)
|
||||
torch.testing.assert_close(y_fused[1], y_ref[1], atol=1e-2, rtol=1e-2)
|
||||
|
||||
|
||||
def test_fuse_add_rms_norm_multi_user():
|
||||
"""Multi-user: both add (2 users) and rms_norm (2 users) → fused op.
|
||||
|
||||
This is the key pattern from the DeepSeek V3 / GLM4-MoE graph that failed
|
||||
with the old inductor-based pattern matcher due to num_users constraints.
|
||||
"""
|
||||
model = MultiUserModel()
|
||||
bsz, seq_len, hidden = 2, 8, 128
|
||||
residual = torch.randn(bsz, seq_len, hidden, device="cuda", dtype=torch.bfloat16)
|
||||
attn_out = torch.randn_like(residual)
|
||||
moe_out = torch.randn_like(residual)
|
||||
|
||||
gm = _export_model(model, residual, attn_out, moe_out)
|
||||
|
||||
# Before: 1 add+norm fusible pair, add has 2 users, norm has 2 users
|
||||
assert _count_rms_norm_ops(gm) == 1
|
||||
|
||||
gm_t, info = _apply_transform_direct(gm)
|
||||
|
||||
# Structure check
|
||||
assert info.num_matches == 1, f"Expected 1 match, got {info.num_matches}"
|
||||
assert _count_fused_ops(gm_t) == 1, "fused op not found in graph"
|
||||
assert _count_rms_norm_ops(gm_t) == 0, "unfused rms_norm still in graph"
|
||||
|
||||
# Verify getitem nodes for both outputs
|
||||
getitems = [
|
||||
n
|
||||
for n in gm_t.graph.nodes
|
||||
if n.op == "call_function"
|
||||
and n.target is operator.getitem
|
||||
and isinstance(n.args[0], torch.fx.Node)
|
||||
and n.args[0].target is flashinfer_fused_add_rms_norm
|
||||
]
|
||||
assert len(getitems) == 2, f"Expected 2 getitem nodes, got {len(getitems)}"
|
||||
|
||||
# Numerical check
|
||||
y_fused = gm_t(residual.clone(), attn_out.clone(), moe_out.clone())
|
||||
y_ref = model(residual.clone(), attn_out.clone(), moe_out.clone())
|
||||
torch.testing.assert_close(y_fused[0], y_ref[0], atol=1e-2, rtol=1e-2)
|
||||
torch.testing.assert_close(y_fused[1], y_ref[1], atol=1e-2, rtol=1e-2)
|
||||
|
||||
|
||||
def test_fuse_add_rms_norm_chained():
|
||||
"""Chained: two consecutive add+norm pairs across transformer layers."""
|
||||
model = ChainedModel()
|
||||
bsz, seq_len, hidden = 2, 8, 128
|
||||
embed = torch.randn(bsz, seq_len, hidden, device="cuda", dtype=torch.bfloat16)
|
||||
attn_out = torch.randn_like(embed)
|
||||
mlp_out = torch.randn_like(embed)
|
||||
|
||||
gm = _export_model(model, embed, attn_out, mlp_out)
|
||||
|
||||
# Before: 2 add+norm fusible pairs
|
||||
assert _count_rms_norm_ops(gm) == 2
|
||||
|
||||
gm_t, info = _apply_transform_direct(gm)
|
||||
|
||||
# Structure check
|
||||
assert info.num_matches == 2, f"Expected 2 matches, got {info.num_matches}"
|
||||
assert _count_fused_ops(gm_t) == 2, "Expected 2 fused ops"
|
||||
assert _count_rms_norm_ops(gm_t) == 0, "unfused rms_norm still in graph"
|
||||
|
||||
# Verify second fused op receives add_out from first fused op (residual chain)
|
||||
fused_nodes = [
|
||||
n
|
||||
for n in gm_t.graph.nodes
|
||||
if n.op == "call_function" and n.target is flashinfer_fused_add_rms_norm
|
||||
]
|
||||
assert len(fused_nodes) == 2
|
||||
# The second fused op's residual arg should be a getitem from the first fused op
|
||||
second_residual_arg = fused_nodes[1].args[1]
|
||||
assert (
|
||||
second_residual_arg.op == "call_function"
|
||||
and second_residual_arg.target is operator.getitem
|
||||
and second_residual_arg.args[0] is fused_nodes[0]
|
||||
), "Second fused op's residual should come from first fused op's add_out"
|
||||
|
||||
# Numerical check
|
||||
y_fused = gm_t(embed.clone(), attn_out.clone(), mlp_out.clone())
|
||||
y_ref = model(embed.clone(), attn_out.clone(), mlp_out.clone())
|
||||
torch.testing.assert_close(y_fused[0], y_ref[0], atol=1e-2, rtol=1e-2)
|
||||
torch.testing.assert_close(y_fused[1], y_ref[1], atol=1e-2, rtol=1e-2)
|
||||
|
||||
@ -0,0 +1,378 @@
|
||||
"""Tests for NVFP4 quantized SwiGLU pattern matching and fusion transforms.
|
||||
|
||||
Tests the parallel NVFP4 SwiGLU path:
|
||||
1. match_nvfp4_swiglu_pattern: Matches torch_fake_quant_nvfp4_linear SwiGLU -> torch_nvfp4_swiglu_mlp
|
||||
2. fuse_nvfp4_swiglu: Fuses gate+up FP4 weights -> fused_nvfp4_swiglu_mlp
|
||||
"""
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from _torch_test_utils import fp4_compatible, trtllm_ops_available
|
||||
from torch.export import Dim
|
||||
|
||||
import tensorrt_llm._torch.auto_deploy.custom_ops # noqa: F401
|
||||
from tensorrt_llm._torch.auto_deploy.export import torch_export_to_gm
|
||||
from tensorrt_llm._torch.auto_deploy.transform.optimizer import InferenceOptimizer
|
||||
from tensorrt_llm._torch.auto_deploy.utils.node_utils import is_op
|
||||
from tensorrt_llm._torch.auto_deploy.utils.quantization_utils import fp4_global_scale
|
||||
|
||||
_skip_reason = "Requires NVFP4 (Blackwell+) and TRT-LLM ops"
|
||||
_skip_condition = not (fp4_compatible() and trtllm_ops_available())
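
# The two-stage pipeline from the module docstring, written out as the InferenceOptimizer
# configuration that the tests below pass inline (shown here for reference only; the
# per-test dicts are authoritative and this constant is not used by them).
_NVFP4_SWIGLU_PIPELINE = {
    "match_nvfp4_swiglu_pattern": {"stage": "pattern_matcher"},
    "fuse_nvfp4_swiglu": {"stage": "post_load_fusion"},
}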
|
||||
|
||||
|
||||
class NVFP4SwiGLUMLP(nn.Module):
|
||||
"""SwiGLU MLP using NVFP4 quantized linear ops.
|
||||
|
||||
Mimics the graph structure produced by quantize_nvfp4_linear_from_config
|
||||
applied to a standard SwiGLU MLP: silu(gate(x)) * up(x) -> down(hidden).
|
||||
"""
|
||||
|
||||
def __init__(self, hidden_size: int = 128, intermediate_size: int = 128):
|
||||
super().__init__()
|
||||
device = torch.device("cuda")
|
||||
scaling_vector_size = 16
|
||||
|
||||
# Create random weights and quantize them to FP4
|
||||
gate_weight = (
|
||||
torch.randn(intermediate_size, hidden_size, dtype=torch.half, device=device) * 0.05
|
||||
)
|
||||
up_weight = (
|
||||
torch.randn(intermediate_size, hidden_size, dtype=torch.half, device=device) * 0.05
|
||||
)
|
||||
down_weight = (
|
||||
torch.randn(hidden_size, intermediate_size, dtype=torch.half, device=device) * 0.05
|
||||
)
|
||||
|
||||
# Quantize gate projection
|
||||
s_w_gate = fp4_global_scale(gate_weight)
|
||||
gate_fp4, gate_cutlass = torch.ops.trtllm.fp4_quantize(
|
||||
gate_weight, s_w_gate, scaling_vector_size, False
|
||||
)
|
||||
# Use a shared input scale for gate and up (same input x)
|
||||
s_in = fp4_global_scale(torch.randn(1, hidden_size, dtype=torch.half, device=device))
|
||||
gate_alpha = (1.0 / (s_in * s_w_gate)).to(torch.float32)
|
||||
|
||||
self.register_buffer("gate_weight", gate_fp4)
|
||||
self.register_buffer("gate_input_scale", s_in.to(torch.float32))
|
||||
self.register_buffer("gate_weight_scale", gate_cutlass)
|
||||
self.register_buffer("gate_alpha", gate_alpha)
|
||||
|
||||
# Quantize up projection (same input scale as gate)
|
||||
s_w_up = fp4_global_scale(up_weight)
|
||||
up_fp4, up_cutlass = torch.ops.trtllm.fp4_quantize(
|
||||
up_weight, s_w_up, scaling_vector_size, False
|
||||
)
|
||||
up_alpha = (1.0 / (s_in * s_w_up)).to(torch.float32)
|
||||
|
||||
self.register_buffer("up_weight", up_fp4)
|
||||
self.register_buffer("up_input_scale", s_in.to(torch.float32))
|
||||
self.register_buffer("up_weight_scale", up_cutlass)
|
||||
self.register_buffer("up_alpha", up_alpha)
|
||||
|
||||
# Quantize down projection (different input: the hidden state)
|
||||
s_in_down = fp4_global_scale(
|
||||
torch.randn(1, intermediate_size, dtype=torch.half, device=device)
|
||||
)
|
||||
s_w_down = fp4_global_scale(down_weight)
|
||||
down_fp4, down_cutlass = torch.ops.trtllm.fp4_quantize(
|
||||
down_weight, s_w_down, scaling_vector_size, False
|
||||
)
|
||||
down_alpha = (1.0 / (s_in_down * s_w_down)).to(torch.float32)
|
||||
|
||||
self.register_buffer("down_weight", down_fp4)
|
||||
self.register_buffer("down_input_scale", s_in_down.to(torch.float32))
|
||||
self.register_buffer("down_weight_scale", down_cutlass)
|
||||
self.register_buffer("down_alpha", down_alpha)
|
||||
|
||||
def forward(self, x):
|
||||
gate_out = torch.ops.auto_deploy.torch_fake_quant_nvfp4_linear(
|
||||
x,
|
||||
self.gate_weight,
|
||||
None,
|
||||
[self.gate_input_scale],
|
||||
[self.gate_weight_scale, self.gate_alpha],
|
||||
[],
|
||||
[],
|
||||
)
|
||||
up_out = torch.ops.auto_deploy.torch_fake_quant_nvfp4_linear(
|
||||
x,
|
||||
self.up_weight,
|
||||
None,
|
||||
[self.up_input_scale],
|
||||
[self.up_weight_scale, self.up_alpha],
|
||||
[],
|
||||
[],
|
||||
)
|
||||
hidden = torch.nn.functional.silu(gate_out) * up_out
|
||||
return torch.ops.auto_deploy.torch_fake_quant_nvfp4_linear(
|
||||
hidden,
|
||||
self.down_weight,
|
||||
None,
|
||||
[self.down_input_scale],
|
||||
[self.down_weight_scale, self.down_alpha],
|
||||
[],
|
||||
[],
|
||||
)
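
# Condensed sketch of the per-projection quantization recipe used by ``NVFP4SwiGLUMLP``
# above (illustrative; the helper name is an assumption and the buffers registered in the
# class are authoritative). It returns the packed FP4 weight, its cutlass block scales,
# the fp32 input scale, and alpha = 1 / (input_scale * weight_scale).
def _nvfp4_quantize_projection(weight, sample_input, scaling_vector_size: int = 16):
    s_in = fp4_global_scale(sample_input)
    s_w = fp4_global_scale(weight)
    w_fp4, w_cutlass = torch.ops.trtllm.fp4_quantize(weight, s_w, scaling_vector_size, False)
    alpha = (1.0 / (s_in * s_w)).to(torch.float32)
    return w_fp4, w_cutlass, s_in.to(torch.float32), alpha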
|
||||
|
||||
|
||||
class NVFP4SwiGLUTestModel(nn.Module):
|
||||
"""Test model wrapping NVFP4 SwiGLU MLP between linear layers."""
|
||||
|
||||
def __init__(self, hidden_size: int = 128, intermediate_size: int = 128):
|
||||
super().__init__()
|
||||
device = torch.device("cuda")
|
||||
self.linear_in = nn.Linear(hidden_size, hidden_size, device=device, dtype=torch.float16)
|
||||
self.mlp = NVFP4SwiGLUMLP(hidden_size, intermediate_size)
|
||||
self.linear_out = nn.Linear(hidden_size, hidden_size, device=device, dtype=torch.float16)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.linear_in(x)
|
||||
x = self.mlp(x)
|
||||
x = self.linear_out(x)
|
||||
return x
|
||||
|
||||
|
||||
class NVFP4SwiGLUMultiLayerModel(nn.Module):
|
||||
"""Test model with multiple NVFP4 SwiGLU MLP layers."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
hidden_size: int = 128,
|
||||
intermediate_size: int = 128,
|
||||
num_layers: int = 2,
|
||||
):
|
||||
super().__init__()
|
||||
device = torch.device("cuda")
|
||||
self.layers = nn.ModuleList()
|
||||
for _ in range(num_layers):
|
||||
self.layers.append(
|
||||
nn.ModuleDict(
|
||||
{
|
||||
"linear": nn.Linear(
|
||||
hidden_size,
|
||||
hidden_size,
|
||||
device=device,
|
||||
dtype=torch.float16,
|
||||
),
|
||||
"mlp": NVFP4SwiGLUMLP(hidden_size, intermediate_size),
|
||||
}
|
||||
)
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
for layer in self.layers:
|
||||
x = layer["linear"](x)
|
||||
x = layer["mlp"](x)
|
||||
return x
|
||||
|
||||
|
||||
# -- Test helpers --------------------------------------------------------------
|
||||
|
||||
|
||||
def _count_ops(gm, op):
|
||||
"""Count how many nodes in the graph match the given op."""
|
||||
return sum(1 for n in gm.graph.nodes if is_op(n, op))
|
||||
|
||||
|
||||
def _has_no_fake_quant_nvfp4(gm):
|
||||
"""Verify no torch_fake_quant_nvfp4_linear ops remain."""
|
||||
return _count_ops(gm, torch.ops.auto_deploy.torch_fake_quant_nvfp4_linear) == 0
|
||||
|
||||
|
||||
# -- Tests ---------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.skipif(_skip_condition, reason=_skip_reason)
|
||||
def test_nvfp4_swiglu_pattern_match_only():
|
||||
"""Test that match_nvfp4_swiglu_pattern produces torch_nvfp4_swiglu_mlp op."""
|
||||
torch.manual_seed(0)
|
||||
model = NVFP4SwiGLUMLP().to("cuda")
|
||||
x = torch.randn(2, 128, device="cuda", dtype=torch.float16)
|
||||
|
||||
gm = torch_export_to_gm(model, args=(x,), clone=True)
|
||||
|
||||
# Verify the graph has torch_fake_quant_nvfp4_linear ops before transform
|
||||
assert _count_ops(gm, torch.ops.auto_deploy.torch_fake_quant_nvfp4_linear) == 3, (
|
||||
"Expected 3 torch_fake_quant_nvfp4_linear ops (gate, up, down) before transform"
|
||||
)
|
||||
|
||||
# Apply only pattern matching
|
||||
gm_matched = InferenceOptimizer(
|
||||
None,
|
||||
{
|
||||
"match_nvfp4_swiglu_pattern": {
|
||||
"stage": "pattern_matcher",
|
||||
},
|
||||
},
|
||||
)(None, gm)
|
||||
|
||||
# Check the intermediate op is present
|
||||
nvfp4_swiglu_count = _count_ops(
|
||||
gm_matched, torch.ops.auto_deploy.torch_nvfp4_swiglu_mlp.default
|
||||
)
|
||||
assert nvfp4_swiglu_count == 1, (
|
||||
f"Expected 1 torch_nvfp4_swiglu_mlp op, got {nvfp4_swiglu_count}"
|
||||
)
|
||||
|
||||
# All 3 fake_quant_nvfp4 ops should be consumed
|
||||
assert _has_no_fake_quant_nvfp4(gm_matched), (
|
||||
"torch_fake_quant_nvfp4_linear ops should be consumed by pattern matcher"
|
||||
)
|
||||
|
||||
# Verify numerical correctness
|
||||
gm_matched = gm_matched.to("cuda")
|
||||
y_matched = gm_matched(x)
|
||||
y_model = model(x)
|
||||
torch.testing.assert_close(y_matched, y_model, atol=1e-3, rtol=1e-3)
|
||||
|
||||
|
||||
@pytest.mark.skipif(_skip_condition, reason=_skip_reason)
|
||||
def test_nvfp4_swiglu_full_fusion():
|
||||
"""Test full pipeline: pattern match -> fuse -> fused_nvfp4_swiglu_mlp."""
|
||||
torch.manual_seed(0)
|
||||
model = NVFP4SwiGLUTestModel().to("cuda")
|
||||
x = torch.randn(2, 128, device="cuda", dtype=torch.float16)
|
||||
|
||||
gm = torch_export_to_gm(model, args=(x,), clone=True, dynamic_shapes=({0: Dim.DYNAMIC},))
|
||||
|
||||
# Apply pattern matching + fusion
|
||||
gm_fused = InferenceOptimizer(
|
||||
None,
|
||||
{
|
||||
"match_nvfp4_swiglu_pattern": {
|
||||
"stage": "pattern_matcher",
|
||||
},
|
||||
"fuse_nvfp4_swiglu": {
|
||||
"stage": "post_load_fusion",
|
||||
},
|
||||
},
|
||||
)(None, gm)
|
||||
|
||||
gm_fused = gm_fused.to("cuda")
|
||||
|
||||
# Check the fused op is present
|
||||
fused_count = _count_ops(gm_fused, torch.ops.auto_deploy.fused_nvfp4_swiglu_mlp.default)
|
||||
assert fused_count == 1, f"Expected 1 fused_nvfp4_swiglu_mlp op, got {fused_count}"
|
||||
|
||||
# No intermediate or unfused ops should remain
|
||||
assert _count_ops(gm_fused, torch.ops.auto_deploy.torch_nvfp4_swiglu_mlp.default) == 0, (
|
||||
"Intermediate torch_nvfp4_swiglu_mlp should be replaced by fused version"
|
||||
)
|
||||
assert _has_no_fake_quant_nvfp4(gm_fused), (
|
||||
"No torch_fake_quant_nvfp4_linear ops should remain after fusion"
|
||||
)
|
||||
|
||||
# Verify numerical correctness (fused uses TRT-LLM kernel, allow wider tolerance)
|
||||
y_fused = gm_fused(x)
|
||||
y_model = model(x)
|
||||
torch.testing.assert_close(y_fused, y_model, atol=0.15, rtol=0.05)
|
||||
|
||||
# Test with a different batch size to verify dynamic shapes work
|
||||
x2 = torch.randn(4, 128, device="cuda", dtype=torch.float16)
|
||||
y_fused_2 = gm_fused(x2)
|
||||
y_model_2 = model(x2)
|
||||
torch.testing.assert_close(y_fused_2, y_model_2, atol=0.15, rtol=0.05)
|
||||
|
||||
|
||||
@pytest.mark.skipif(_skip_condition, reason=_skip_reason)
|
||||
@pytest.mark.parametrize("num_layers", [2, 3])
|
||||
def test_nvfp4_swiglu_fusion_multiple_layers(num_layers):
|
||||
"""Test that multiple NVFP4 SwiGLU patterns are fused correctly."""
|
||||
torch.manual_seed(0)
|
||||
model = NVFP4SwiGLUMultiLayerModel(num_layers=num_layers).to("cuda")
|
||||
x = torch.randn(2, 128, device="cuda", dtype=torch.float16)
|
||||
|
||||
gm = torch_export_to_gm(model, args=(x,), clone=True)
|
||||
|
||||
# Apply pattern matching + fusion
|
||||
gm_fused = InferenceOptimizer(
|
||||
None,
|
||||
{
|
||||
"match_nvfp4_swiglu_pattern": {
|
||||
"stage": "pattern_matcher",
|
||||
},
|
||||
"fuse_nvfp4_swiglu": {
|
||||
"stage": "post_load_fusion",
|
||||
},
|
||||
},
|
||||
)(None, gm)
|
||||
|
||||
gm_fused = gm_fused.to("cuda")
|
||||
|
||||
# Check that all layers are fused
|
||||
fused_count = _count_ops(gm_fused, torch.ops.auto_deploy.fused_nvfp4_swiglu_mlp.default)
|
||||
assert fused_count == num_layers, (
|
||||
f"Expected {num_layers} fused_nvfp4_swiglu_mlp ops, got {fused_count}"
|
||||
)
|
||||
|
||||
# Verify numerical correctness
|
||||
y_fused = gm_fused(x)
|
||||
y_model = model(x)
|
||||
torch.testing.assert_close(y_fused, y_model, atol=0.2, rtol=0.1)
|
||||
|
||||
|
||||
@pytest.mark.skipif(_skip_condition, reason=_skip_reason)
|
||||
def test_nvfp4_swiglu_does_not_match_non_swiglu():
|
||||
"""Test that the NVFP4 SwiGLU matcher does not match non-SwiGLU NVFP4 linears."""
|
||||
torch.manual_seed(0)
|
||||
device = torch.device("cuda")
|
||||
hidden_size = 128
|
||||
|
||||
# Model with two sequential NVFP4 linears + relu (NOT a SwiGLU pattern)
|
||||
class NonSwiGLUModel(nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
w1 = torch.randn(hidden_size, hidden_size, dtype=torch.half, device=device) * 0.05
|
||||
w2 = torch.randn(hidden_size, hidden_size, dtype=torch.half, device=device) * 0.05
|
||||
|
||||
s_in = fp4_global_scale(torch.randn(1, hidden_size, dtype=torch.half, device=device))
|
||||
s_w1 = fp4_global_scale(w1)
|
||||
s_w2 = fp4_global_scale(w2)
|
||||
|
||||
w1_fp4, w1_cutlass = torch.ops.trtllm.fp4_quantize(w1, s_w1, 16, False)
|
||||
w2_fp4, w2_cutlass = torch.ops.trtllm.fp4_quantize(w2, s_w2, 16, False)
|
||||
|
||||
self.register_buffer("w1", w1_fp4)
|
||||
self.register_buffer("w1_is", s_in.to(torch.float32))
|
||||
self.register_buffer("w1_ws", w1_cutlass)
|
||||
self.register_buffer("w1_a", (1.0 / (s_in * s_w1)).to(torch.float32))
|
||||
|
||||
self.register_buffer("w2", w2_fp4)
|
||||
self.register_buffer("w2_is", s_in.to(torch.float32))
|
||||
self.register_buffer("w2_ws", w2_cutlass)
|
||||
self.register_buffer("w2_a", (1.0 / (s_in * s_w2)).to(torch.float32))
|
||||
|
||||
def forward(self, x):
|
||||
# Sequential linears without SwiGLU pattern
|
||||
y = torch.ops.auto_deploy.torch_fake_quant_nvfp4_linear(
|
||||
x, self.w1, None, [self.w1_is], [self.w1_ws, self.w1_a], [], []
|
||||
)
|
||||
y = torch.nn.functional.relu(y)
|
||||
return torch.ops.auto_deploy.torch_fake_quant_nvfp4_linear(
|
||||
y, self.w2, None, [self.w2_is], [self.w2_ws, self.w2_a], [], []
|
||||
)
|
||||
|
||||
model = NonSwiGLUModel().to("cuda")
|
||||
x = torch.randn(2, hidden_size, device="cuda", dtype=torch.float16)
|
||||
|
||||
gm = torch_export_to_gm(model, args=(x,), clone=True)
|
||||
|
||||
gm_result = InferenceOptimizer(
|
||||
None,
|
||||
{
|
||||
"match_nvfp4_swiglu_pattern": {
|
||||
"stage": "pattern_matcher",
|
||||
},
|
||||
},
|
||||
)(None, gm)
|
||||
|
||||
# No SwiGLU ops should be found
|
||||
assert _count_ops(gm_result, torch.ops.auto_deploy.torch_nvfp4_swiglu_mlp.default) == 0, (
|
||||
"Non-SwiGLU NVFP4 pattern should not match"
|
||||
)
|
||||
|
||||
# Original NVFP4 linear ops should still be present
|
||||
assert _count_ops(gm_result, torch.ops.auto_deploy.torch_fake_quant_nvfp4_linear) == 2, (
|
||||
"Original NVFP4 linear ops should be unchanged"
|
||||
)
|
||||
@ -356,106 +356,98 @@ def test_moe_export_with_reduced_experts(
|
||||
# Real-model MOE export: GLM4 MoE Lite
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
try:
|
||||
from tensorrt_llm._torch.auto_deploy.models.custom.modeling_glm4_moe_lite import (
|
||||
Glm4MoeLiteConfig,
|
||||
Glm4MoeLiteForCausalLM,
|
||||
from tensorrt_llm._torch.auto_deploy.models.custom.modeling_glm4_moe_lite import ( # noqa: E402
|
||||
Glm4MoeLiteConfig,
|
||||
Glm4MoeLiteForCausalLM,
|
||||
)
|
||||
|
||||
|
||||
def _make_tiny_glm4_config(n_routed_experts: int = 8) -> Glm4MoeLiteConfig:
|
||||
"""Create a minimal ``Glm4MoeLiteConfig`` suitable for unit tests."""
|
||||
return Glm4MoeLiteConfig(
|
||||
vocab_size=256,
|
||||
hidden_size=64,
|
||||
intermediate_size=128,
|
||||
num_hidden_layers=2,
|
||||
num_attention_heads=4,
|
||||
num_key_value_heads=4,
|
||||
q_lora_rank=32,
|
||||
kv_lora_rank=32,
|
||||
qk_nope_head_dim=12,
|
||||
qk_rope_head_dim=4,
|
            v_head_dim=16,
            n_routed_experts=n_routed_experts,
            n_shared_experts=1,
            num_experts_per_tok=2,
            moe_intermediate_size=64,
            n_group=1,
            topk_group=1,
            routed_scaling_factor=1.0,
            norm_topk_prob=True,
            first_k_dense_replace=1,  # layer 0 = dense MLP, layer 1 = MoE
            max_position_embeddings=128,
            rope_scaling=None,
            pad_token_id=0,
        )

    _HAS_GLM4 = True
except ImportError:
    _HAS_GLM4 = False


def _count_moe_experts_in_graph(gm: GraphModule) -> int:
    """Return the number of experts in the first ``torch_moe`` call in *gm*."""
    for node in gm.graph.nodes:
        if node.op == "call_function" and "torch_moe" in str(node.target):
            return len(node.args[3])  # w1_weight list length
    return 0


@pytest.mark.skipif(not _HAS_GLM4, reason="GLM4 MoE Lite model not available on this branch")
@pytest.mark.skipif(
    not torch.cuda.is_available(), reason="GLM4 MoE Lite requires CUDA (uses noaux_tc_op)"
)
@pytest.mark.parametrize("n_routed_experts", [8, 16])
@pytest.mark.parametrize("num_moe_experts_for_export", [2])
def test_glm4_moe_lite_export_with_reduced_experts(n_routed_experts, num_moe_experts_for_export):
    """Export a tiny ``Glm4MoeLiteForCausalLM`` with reduced experts and verify
    that the expanded graph has the correct structure and accepts the original
    state dict.
    """
    # GLM4 MoE Lite uses noaux_tc_op, which is CUDA-only, so the test must run on a CUDA device.
    device = "cuda"
    config = _make_tiny_glm4_config(n_routed_experts=n_routed_experts)
    model = Glm4MoeLiteForCausalLM(config).to(device)
    model.eval()

    input_ids = torch.randint(0, config.vocab_size, (1, 8), device=device)
    position_ids = torch.arange(8, device=device).unsqueeze(0)
    sample_kwargs = {"input_ids": input_ids, "position_ids": position_ids}

    # --- full export (baseline) ---
    gm_full = torch_export_to_gm(model, kwargs=sample_kwargs)

    # --- export with reduced experts ---
    gm_reduced = torch_export_to_gm(
        model,
        kwargs=sample_kwargs,
        num_moe_experts_for_export=num_moe_experts_for_export,
    )

    # Structural: both graphs must expose all experts
    assert _count_moe_experts_in_graph(gm_full) == n_routed_experts
    assert _count_moe_experts_in_graph(gm_reduced) == n_routed_experts

    # State-dict keys must match between full and reduced exports
    full_keys = set(gm_full.state_dict().keys())
    reduced_keys = set(gm_reduced.state_dict().keys())
    assert full_keys == reduced_keys, (
        f"State-dict key mismatch.\n"
        f"  Only in full:    {full_keys - reduced_keys}\n"
        f"  Only in reduced: {reduced_keys - full_keys}"
    )

    # Load the original model weights into the reduced export graph
    gm_reduced.load_state_dict(model.state_dict(), strict=False)

    # Source model must be fully restored
    for name, mod in model.named_modules():
        if hasattr(mod, "experts") and isinstance(mod.experts, nn.ModuleList):
            assert len(mod.experts) == n_routed_experts, (
                f"Expert list in '{name}' was not restored to {n_routed_experts}"
            )
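For orientation, a minimal sketch of the workflow the test above exercises: trace the export with only a couple of experts to keep tracing cheap, while the produced graph is expanded back to the full expert count and still accepts the original weights. Only names that appear in the test (`_make_tiny_glm4_config`, `_count_moe_experts_in_graph`, `torch_export_to_gm`, `num_moe_experts_for_export`) are used; the sketch itself is illustrative and not part of this diff.

# Hedged sketch (not part of the diff): reduced-expert export of the tiny GLM4 MoE Lite model.
model = Glm4MoeLiteForCausalLM(_make_tiny_glm4_config(n_routed_experts=16)).to("cuda").eval()
kwargs = {
    "input_ids": torch.randint(0, 256, (1, 8), device="cuda"),
    "position_ids": torch.arange(8, device="cuda").unsqueeze(0),
}
gm = torch_export_to_gm(model, kwargs=kwargs, num_moe_experts_for_export=2)
assert _count_moe_experts_in_graph(gm) == 16  # graph is expanded back to all experts
gm.load_state_dict(model.state_dict(), strict=False)  # original checkpoint loads cleanly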
@ -0,0 +1,163 @@
"""Tests for ``create_derived_custom_op`` in ``_graph.py``."""

import torch
from torch._subclasses import FakeTensorMode

from tensorrt_llm._torch.auto_deploy.utils._graph import create_derived_custom_op

# ---------------------------------------------------------------------------
# Helpers – tiny custom ops used as base ops for the tests
# ---------------------------------------------------------------------------


@torch.library.custom_op("ad_test_derived::double", mutates_args=())
def _double(x: torch.Tensor) -> torch.Tensor:
    return x * 2


@_double.register_fake
def _double_fake(x: torch.Tensor) -> torch.Tensor:
    return torch.empty_like(x)


@torch.library.custom_op("ad_test_derived::weighted_add", mutates_args=())
def _weighted_add(x: torch.Tensor, y: torch.Tensor, alpha: float = 1.0) -> torch.Tensor:
    return x + alpha * y


@_weighted_add.register_fake
def _weighted_add_fake(x: torch.Tensor, y: torch.Tensor, alpha: float = 1.0) -> torch.Tensor:
    return torch.empty_like(x)


# ---------------------------------------------------------------------------
# Tests
# ---------------------------------------------------------------------------


class TestCreateDerivedCustomOp:
    """Tests for the ``create_derived_custom_op`` utility."""

    def test_basic_derived_op(self):
        """A derived op should be callable and produce correct results."""

        def make_impl(base_overload):
            # Wrapper that calls the base op then negates the result.
            def impl(*args, **kwargs):
                return -base_overload(*args, **kwargs)

            return impl

        base_op = torch.ops.ad_test_derived.double
        derived = create_derived_custom_op(base_op, "_neg", make_impl)

        x = torch.tensor([1.0, 2.0, 3.0])
        result = derived(x)
        expected = -(x * 2)
        torch.testing.assert_close(result, expected)

    def test_derived_op_is_registered(self):
        """The derived op must be accessible via ``torch.ops``."""

        def make_impl(base_overload):
            return lambda *a, **kw: base_overload(*a, **kw)

        create_derived_custom_op(torch.ops.ad_test_derived.double, "_registered", make_impl)
        assert hasattr(torch.ops.ad_test_derived, "double_registered")

    def test_caching(self):
        """Repeated calls with the same base_op + suffix must return the same object."""

        def make_impl(base_overload):
            return lambda *a, **kw: base_overload(*a, **kw)

        op1 = create_derived_custom_op(torch.ops.ad_test_derived.double, "_cached", make_impl)
        op2 = create_derived_custom_op(torch.ops.ad_test_derived.double, "_cached", make_impl)
        assert op1 is op2

    def test_different_suffix_produces_different_op(self):
        """Different suffixes must create distinct ops."""

        def make_impl(base_overload):
            return lambda *a, **kw: base_overload(*a, **kw)

        op_a = create_derived_custom_op(torch.ops.ad_test_derived.double, "_sfx_a", make_impl)
        op_b = create_derived_custom_op(torch.ops.ad_test_derived.double, "_sfx_b", make_impl)
        assert op_a is not op_b

    def test_default_fake_implementation(self):
        """When *make_fake* is None the default (empty_like) must be used."""

        def make_impl(base_overload):
            return lambda *a, **kw: base_overload(*a, **kw)

        derived = create_derived_custom_op(
            torch.ops.ad_test_derived.double, "_dflt_fake", make_impl
        )
        # Calling the Meta implementation via FakeTensorMode
        with FakeTensorMode():
            x = torch.empty(4)
            out = derived(x)
            assert out.shape == x.shape

    def test_custom_fake_implementation(self):
        """A user-supplied *make_fake* must override the default."""

        def make_impl(base_overload):
            return lambda *a, **kw: base_overload(*a, **kw)

        # Fake that always returns shape (1,) regardless of input shape.
        def make_fake(base_overload):
            def fake(*args, **kwargs):
                return args[0].new_empty(1)

            return fake

        derived = create_derived_custom_op(
            torch.ops.ad_test_derived.double,
            "_custom_fake",
            make_impl,
            make_fake=make_fake,
        )

        with FakeTensorMode():
            x = torch.empty(10)
            out = derived(x)
            assert out.shape == (1,)

    def test_preserves_schema_with_defaults(self):
        """Derived op must preserve the base op's argument defaults."""

        def make_impl(base_overload):
            def impl(*args, **kwargs):
                return base_overload(*args, **kwargs) * 10

            return impl

        base_op = torch.ops.ad_test_derived.weighted_add
        derived = create_derived_custom_op(base_op, "_x10", make_impl)

        x = torch.ones(3)
        y = torch.ones(3) * 2.0

        # With default alpha=1.0 → (x + 1.0*y) * 10 = 30
        result_default = derived(x, y)
        torch.testing.assert_close(result_default, torch.full((3,), 30.0))

        # With explicit alpha=0.5 → (x + 0.5*y) * 10 = 20
        result_alpha = derived(x, y, alpha=0.5)
        torch.testing.assert_close(result_alpha, torch.full((3,), 20.0))

    def test_accepts_op_overload(self):
        """The function should accept an OpOverload (e.g. ``.default``) as well."""

        def make_impl(base_overload):
            return lambda *a, **kw: base_overload(*a, **kw) + 1

        derived = create_derived_custom_op(
            torch.ops.ad_test_derived.double.default, "_from_overload", make_impl
        )

        x = torch.tensor([5.0])
        # double → 10, +1 → 11
        torch.testing.assert_close(derived(x), torch.tensor([11.0]))
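As a companion to the tests above, a minimal usage sketch of the pattern they exercise: wrap an existing custom op in a derived variant registered under the same namespace with a suffix. Only `create_derived_custom_op` and its `(base_op, suffix, make_impl, make_fake=None)` calling convention come from the tests; the `myns::scale` base op and the clamp behavior are invented for illustration.

import torch
from tensorrt_llm._torch.auto_deploy.utils._graph import create_derived_custom_op


# Illustrative base op; any torch.library custom op with a fake impl would work the same way.
@torch.library.custom_op("myns::scale", mutates_args=())
def _scale(x: torch.Tensor, s: float = 2.0) -> torch.Tensor:
    return x * s


@_scale.register_fake
def _scale_fake(x: torch.Tensor, s: float = 2.0) -> torch.Tensor:
    return torch.empty_like(x)


def make_impl(base_overload):
    # Derived variant: run the base op, then clamp negatives to zero.
    def impl(*args, **kwargs):
        return base_overload(*args, **kwargs).clamp_min(0)

    return impl


scale_relu = create_derived_custom_op(torch.ops.myns.scale, "_relu", make_impl)
print(scale_relu(torch.tensor([-1.0, 3.0])))           # tensor([0., 6.])
print(torch.ops.myns.scale_relu(torch.tensor([4.0])))  # also reachable via torch.ops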