mirror of https://github.com/NVIDIA/TensorRT-LLM.git
Signed-off-by: Balamurugan Marimuthu <246387390+bmarimuthu-nv@users.noreply.github.com>
parent 744a955cbb
commit 393c3d259e
@@ -333,9 +333,9 @@ def flashinfer_mha_with_cache(
     q_shape_og = q.shape
     b, s = q_shape_og[:2]

-    q = q.reshape(b * s, -1, head_dim)
-    k = k.reshape(b * s, -1, head_dim)
-    v = v.reshape(b * s, -1, head_dim)
+    q = q.reshape(b * s, -1, head_dim).contiguous()
+    k = k.reshape(b * s, -1, head_dim).contiguous()
+    v = v.reshape(b * s, -1, head_dim).contiguous()

     # convert to flashinfer-style metadata
     num_prefill, num_prefill_tokens, num_decode = batch_info_host.tolist()
@@ -2,6 +2,7 @@

import flashinfer
import torch
import torch.distributed as dist
import torch.nn.functional as F
from einops import rearrange

@@ -252,3 +253,65 @@ def gated_rms_norm_ref(
    if z is not None and norm_before_gate:
        out *= F.silu(z)
    return out.to(dtype)


# =============================================================================
# Sharded RMSNorm (for sharded activations)
# =============================================================================


@torch.library.custom_op("auto_deploy::sharded_rmsnorm", mutates_args=())
def sharded_rmsnorm(
    input: torch.Tensor, weight: torch.Tensor, eps: float, world_size: int
) -> torch.Tensor:
    """RMSNorm for sharded activations that need global reduction.

    When activations are sharded (split along the last dimension across devices),
    standard RMSNorm computes an incorrect local mean. This op uses all_reduce to compute
    the global mean of squared values across all shards, ensuring correct normalization.

    The computation is:
    1. Compute local sum of squares: sum(input^2) over local features
    2. All-reduce to get global sum of squares across all shards
    3. Compute global mean: global_sum / (local_dim * world_size)
    4. Normalize: input * rsqrt(global_mean + eps)
    5. Scale with local weight (weight is also column-sharded)

    Args:
        input: Input tensor, shape [..., local_hidden_size] where local_hidden_size
            is the shard of the full hidden dimension on this device.
        weight: Scaling weights, shape [local_hidden_size] (column-sharded).
        eps: Small constant for numerical stability.
        world_size: Number of devices across which the activation is sharded.

    Returns:
        Normalized and scaled tensor with same shape as input.
    """
    local_dim = input.shape[-1]

    # Cast to float32 for precision
    input_fp32 = input.to(torch.float32)

    # Compute local sum of squares (NOT mean - we need sum for all_reduce)
    local_sum_sq = input_fp32.pow(2).sum(-1, keepdim=True)

    # All-reduce to get global sum of squares
    global_sum_sq = local_sum_sq.clone()
    dist.all_reduce(global_sum_sq, op=dist.ReduceOp.SUM)

    # Compute global mean: global_sum / total_elements
    global_count = local_dim * world_size
    global_mean_sq = global_sum_sq / global_count

    # Normalize
    input_normalized = input_fp32 * torch.rsqrt(global_mean_sq + eps)

    # Apply weight (local weight since it's also column-sharded)
    out = weight * input_normalized.to(input.dtype)
    return out


@sharded_rmsnorm.register_fake
def _(input: torch.Tensor, weight: torch.Tensor, eps: float, world_size: int) -> torch.Tensor:
    """Fake implementation for tracing."""
    return torch.empty_like(input)
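A quick sanity check on the math behind sharded_rmsnorm: the global mean of squares is just the all-reduced sum of the per-shard sums divided by local_dim * world_size, so the op matches a non-sharded RMSNorm. A minimal single-process sketch of that equivalence (the helper name is made up and the all_reduce is emulated here by summing per-shard partial sums):

import torch

def emulated_sharded_rmsnorm(shards, weight_shards, eps):
    # each "rank" computes a local sum of squares over its slice of the hidden dim
    local_sums = [s.pow(2).sum(-1, keepdim=True) for s in shards]
    # stands in for dist.all_reduce(..., op=ReduceOp.SUM) across ranks
    global_sum = torch.stack(local_sums).sum(0)
    global_mean = global_sum / (shards[0].shape[-1] * len(shards))
    return [w * (s * torch.rsqrt(global_mean + eps)) for s, w in zip(shards, weight_shards)]

x = torch.randn(2, 4, 64)
w = torch.randn(64)
eps = 1e-6
ref = w * (x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps))  # plain global RMSNorm
out = emulated_sharded_rmsnorm(list(x.chunk(2, dim=-1)), list(w.chunk(2)), eps)
torch.testing.assert_close(torch.cat(out, dim=-1), ref)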
78  tensorrt_llm/_torch/auto_deploy/models/patches/minimax_m2.py  Normal file
@@ -0,0 +1,78 @@
"""A patch for MiniMax-M2 MoE to make it compatible with torch.export.

MiniMax-M2 is loaded from HuggingFace hub (trust_remote_code), so we cannot
import MiniMaxM2SparseMoeBlock directly. Instead, we use the same pattern as
DeepSeek: patching AutoModelForCausalLM.from_config to iterate over modules
and patch by class name.
"""

import types
from typing import Dict

import torch
from transformers import AutoModelForCausalLM


def minimax_m2_moe(self, hidden_states: torch.Tensor):
    """MiniMaxM2SparseMoeBlock forward function rewritten to enable torch.export.

    Key differences from Mixtral:
    - Uses sigmoid activation for routing (not softmax)
    - Has e_score_correction_bias added for expert selection only
    - Gathers original sigmoid weights after topk selection
    """
    batch_size, sequence_length, hidden_dim = hidden_states.shape
    if self.training and self.jitter_noise > 0:
        hidden_states *= torch.empty_like(hidden_states).uniform_(
            1.0 - self.jitter_noise, 1.0 + self.jitter_noise
        )
    hidden_states = hidden_states.view(-1, hidden_dim)
    # router_logits: (batch * sequence_length, n_experts)
    router_logits = self.gate(hidden_states)

    # MiniMax-M2 specific routing:
    # Step 1: Sigmoid activation (not softmax like Mixtral)
    routing_weights = torch.sigmoid(router_logits.float())

    # Step 2: Add bias for expert selection only
    scores_for_choice = routing_weights + self.e_score_correction_bias

    # Step 3: Select top-k experts based on biased scores
    _, selected_experts = torch.topk(scores_for_choice, self.top_k, dim=-1, sorted=False)

    # Step 4: Gather ORIGINAL sigmoid weights (not biased scores)
    top_k_weights = routing_weights.gather(1, selected_experts)

    # Step 5: Normalize so weights sum to 1
    top_k_weights = top_k_weights / top_k_weights.sum(dim=-1, keepdim=True)
    top_k_weights = top_k_weights.to(hidden_states.dtype)

    final_hidden_states = torch.ops.auto_deploy.torch_moe(
        hidden_states,
        selected_experts,
        top_k_weights,
        w1_weight=[expert.w1.weight for expert in self.experts],  # gate projection
        w2_weight=[expert.w2.weight for expert in self.experts],  # down projection
        w3_weight=[expert.w3.weight for expert in self.experts],  # up projection
    )
    final_hidden_states = final_hidden_states.reshape(batch_size, sequence_length, hidden_dim)
    return final_hidden_states, router_logits


_from_config_previous = AutoModelForCausalLM.from_config

CUSTOM_MODULE_PATCHES: Dict[str, callable] = {"MiniMaxM2SparseMoeBlock": minimax_m2_moe}


def get_model_from_config_patched(config, **kwargs):
    model = _from_config_previous(config, **kwargs)
    # Patch modules by class name
    for _, module in model.named_modules():
        if type(module).__name__ in CUSTOM_MODULE_PATCHES:
            module.forward = types.MethodType(CUSTOM_MODULE_PATCHES[type(module).__name__], module)

    return model


# Patch AutoModelForCausalLM.from_config
AutoModelForCausalLM.from_config = get_model_from_config_patched
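A toy, single-token illustration of the routing above (the logits and bias values are made up, not taken from the model): the correction bias only influences which experts are picked, while the mixing weights come from the raw sigmoid scores.

import torch

router_logits = torch.tensor([[0.2, -1.0, 1.5, 0.0]])          # 1 token, 4 experts
e_score_correction_bias = torch.tensor([0.0, 2.5, 0.0, 0.0])   # assumed bias values

routing_weights = torch.sigmoid(router_logits.float())          # step 1: sigmoid, not softmax
scores_for_choice = routing_weights + e_score_correction_bias   # step 2: bias for selection only
_, selected = torch.topk(scores_for_choice, k=2, dim=-1, sorted=False)  # step 3
top_k = routing_weights.gather(1, selected)                      # step 4: original sigmoid weights
top_k = top_k / top_k.sum(dim=-1, keepdim=True)                  # step 5: normalize to sum to 1
# expert 1 is selected only because of the bias, but it is mixed using its raw sigmoid weight
print(selected, top_k)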
@@ -691,6 +691,77 @@ def _resolve_ep_cls_from_node(node: Node) -> type[EPShardingInfo]:
    return EPShardingInfo


class RMSNormShardingInfo(ShardingTransformInfo):
    """Configuration for replacing RMSNorm with sharded version.

    When RMSNorm (torch_rmsnorm) operates on sharded activations with
    weight shape [num_heads * head_dim] (full hidden size), it needs to be
    replaced with sharded_rmsnorm which uses all_reduce to compute
    the correct global mean across shards.

    The detection of whether an RMSNorm needs this treatment is done in
    _shard_qk_norm based on weight shape matching q/k projection
    output dimensions.
    """

    world_size: int

    @classmethod
    def from_node(cls, node: Node, **kwargs) -> "RMSNormShardingInfo":
        """Create a RMSNormShardingInfo from a node."""
        return cls(target_node=node.name, **kwargs)

    def validate(self, gm: GraphModule = None, node: Node = None) -> bool:
        """Validate that the node is a torch_rmsnorm op."""
        if not is_op(node, torch.ops.auto_deploy.torch_rmsnorm):
            ad_logger.debug(
                f"RMSNormShardingInfo only applies to torch_rmsnorm ops. "
                f"Got {node.target}. Skipping."
            )
            return False
        return True

    def apply(self, gm: GraphModule, node: Node) -> None:
        """Replace torch_rmsnorm with sharded_rmsnorm and shard the weight.

        This handles both:
        1. Weight sharding: Shard the RMSNorm weight across ranks
        2. Op replacement: Replace with sharded_rmsnorm which uses all_reduce
           for computing the correct global mean across shards.
        """
        # Get original arguments: (input, weight, eps)
        input_node = node.args[0]
        weight_node = node.args[1]
        eps = node.args[2]

        # Shard the weight parameter (column-wise for 1D RMSNorm weight)
        weight_key = weight_node.target
        weight_tensor = gm.get_parameter(weight_key)
        shard_weight_tensor(
            gm=gm,
            weight_tensor=weight_tensor,
            param_key=weight_key,
            dim=0,  # Column shard for 1D weight
            rank=self.config.rank,
            world_size=self.world_size,
        )
        ad_logger.debug(f"Sharded RMSNorm weight: {weight_key}")

        # Insert the new node with world_size parameter
        with gm.graph.inserting_after(node):
            new_node = gm.graph.call_function(
                torch.ops.auto_deploy.sharded_rmsnorm.default,
                args=(input_node, weight_node, eps, self.world_size),
            )

        # Replace all uses with the new node and remove the old node
        node.replace_all_uses_with(new_node)
        gm.graph.erase_node(node)
        ad_logger.debug(
            f"Replaced torch_rmsnorm with sharded_rmsnorm (world_size={self.world_size})"
        )


########################################################
# Transform API classes
########################################################
@@ -828,6 +899,9 @@ class ShardingTransformExecutor(BaseTransform):
        for ep_transform in transforms.ep_transforms:
            if check_and_apply(ep_transform):
                num_matches += 1
        for rmsnorm_transform in transforms.rmsnorm_transforms:
            if check_and_apply(rmsnorm_transform):
                num_matches += 1

        # post-sharding cleanup transformations
        for update_transform in transforms.parameter_update_transforms:
@@ -851,6 +925,7 @@ class ShardingTransformContainer(BaseModel):
    parameter_update_transforms: List[ParameterUpdateInfo] = Field(default_factory=list)
    bmm_transforms: List[BMMShardingInfo] = Field(default_factory=list)
    ep_transforms: List[EPShardingInfo] = Field(default_factory=list)
    rmsnorm_transforms: List[RMSNormShardingInfo] = Field(default_factory=list)

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
@@ -859,6 +934,7 @@ class ShardingTransformContainer(BaseModel):
            BMMShardingInfo: self.bmm_transforms,
            EPShardingInfo: self.ep_transforms,
            ParameterUpdateInfo: self.parameter_update_transforms,
            RMSNormShardingInfo: self.rmsnorm_transforms,
        }

    def add(self, transform: ShardingTransformInfo) -> bool:
@@ -1299,14 +1375,30 @@ def _shard_parameter_node(


 def _update_node_args(node: Node, args: tuple) -> None:
-    """Update the node's arguments with the new sharded arguments."""
+    """Update the node's arguments with the new sharded arguments.
+
+    For Node args: preserve the current value (may have been updated by other transforms).
+    For non-Node args (shapes, sizes, indices): use the stored sharded value.
+
+    This prevents ParameterUpdateInfo from reverting Node references that were
+    intentionally updated by other transforms.
+    """
     if "sharded" in node.meta and node.meta["sharded"]:
         return
-    node.args = args
+
+    # Build new args: preserve current Node refs, apply stored non-Node values
+    new_args = []
+    for i, stored_arg in enumerate(args):
+        if isinstance(stored_arg, Node):
+            # Node args: preserve current value (may have been updated by other transforms)
+            new_args.append(node.args[i])
+        else:
+            # Non-Node args (shapes, sizes, indices): use stored sharded value
+            new_args.append(stored_arg)
+
+    node.args = tuple(new_args)
     node.meta["sharded"] = True
-    ad_logger.debug(
-        f"Updated node {node}: replaced original arguments {node.args} with sharded arguments {args}."
-    )
+    ad_logger.debug(f"Updated node {node}: sharded arguments are now {node.args}.")


 def _insert_sharded_moe(
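A toy illustration of the merge rule described in the new docstring (the Node here is a namedtuple stand-in for torch.fx.Node, and the argument values are made up): Node-typed arguments keep the value currently on the node, while non-Node arguments take the stored sharded value.

from collections import namedtuple

Node = namedtuple("Node", ["name"])                    # stand-in for torch.fx.Node
current_args = (Node("weight_sharded"), (8192, 64))    # weight ref already rewritten by another transform
stored_args = (Node("weight_original"), (4096, 64))    # sharded shape recorded earlier

merged = tuple(
    cur if isinstance(stored, Node) else stored
    for cur, stored in zip(current_args, stored_args)
)
assert merged == (Node("weight_sharded"), (4096, 64))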
@@ -1838,43 +1930,242 @@ def _determine_fused_weight_dims(
        return None
    linear_node = linear_nodes[0]
    fused_weight_dims = None
    # check if there are split nodes in the subgraph. They may indicate fused weights (e.g., QKV)
    split_nodes = list(filtered_nodes(linear_node.users, ops=[torch.ops.aten.split_with_sizes]))
    if len(split_nodes) > 0:
        assert len(linear_nodes) == 1
    if len(linear_nodes) == 1:
        linear_node = linear_nodes[0]
        assert len(split_nodes) == 1, "Expecting exactly one split node for fused weights"
        fused_weight_dims = split_nodes[0].args[1]
        # check if there are split nodes in the subgraph. They may indicate fused weights (e.g., QKV)
        linear_split_users = list(
            filtered_nodes(linear_node.users, ops=torch.ops.aten.split_with_sizes)
        )
        linear_slice_users = list(filtered_nodes(linear_node.users, ops=torch.ops.aten.slice))
        linear_chunk_users = list(filtered_nodes(linear_node.users, ops=torch.ops.aten.chunk))
        if len(linear_split_users) > 0:
            assert len(linear_split_users) == 1, (
                "Expecting exactly one split node for fused weights"
            )
            fused_weight_dims = linear_split_users[0].args[1]

    slice_nodes = list(filtered_nodes(linear_node.users, ops=[torch.ops.aten.slice]))
    if len(slice_nodes) > 0:
        # we are probably in fused QKV case with single linear node and 3 slice nodes
        assert len(linear_nodes) == 1
        linear_node = linear_nodes[0]
        assert all(
            s.args[1] == 2 for s in filtered_nodes(linear_node.users, ops=torch.ops.aten.slice)
        ), "Expecting slice nodes to slice tensor over dim=2"
        fused_weight_dims = [s.args[3] - s.args[2] for s in linear_node.users]
        weight_dim = shape(linear_node)[2]
        if sum(fused_weight_dims) != weight_dim:
            if fused_weight_dims[-1] > weight_dim:
                fused_weight_dims[-1] = weight_dim - sum(fused_weight_dims[:-1])
            else:
                ad_logger.warning(
                    f"Fused weight dims {fused_weight_dims} do not sum to weight dim {weight_dim}. Skipping."
                )
                return None
    chunk_nodes = list(filtered_nodes(linear_node.users, ops=torch.ops.aten.chunk))
    if len(chunk_nodes) > 0:
        assert len(linear_nodes) == 1
        linear_node = linear_nodes[0]
        assert len(chunk_nodes) == 1, "Expecting exactly one chunk node for fused weights"
        num_chunks = chunk_nodes[0].args[1]
        weight_dim = shape(linear_node)[2]
        fused_weight_dims = [weight_dim // num_chunks] * num_chunks
    if fused_weight_dims is not None:
        fused_weight_dims = tuple(fused_weight_dims)
    return fused_weight_dims
    elif len(linear_slice_users) > 0:
        # we are probably in fused QKV case with single linear node and 3 slice nodes
        assert all(s.args[1] == 2 for s in linear_slice_users), (
            "Expecting slice nodes to slice tensor over dim=2"
        )
        fused_weight_dims = [s.args[3] - s.args[2] for s in linear_slice_users]
        assert fused_weight_dims, "fused weight dims cannot be empty"
        weight_dim = linear_node.meta["val"].shape[2]
        if sum(fused_weight_dims) != weight_dim:
            if fused_weight_dims[-1] > weight_dim:
                fused_weight_dims[-1] = weight_dim - sum(fused_weight_dims[:-1])
            else:
                ad_logger.warning(
                    f"Fused weight dims {fused_weight_dims} do not sum to weight dim {weight_dim}. Skipping."
                )
                return

    elif len(linear_chunk_users) > 0:
        assert len(linear_chunk_users) == 1, (
            "Expecting exactly one chunk node for fused weights"
        )
        num_chunks = linear_chunk_users[0].args[1]
        weight_dim = linear_node.meta["val"].shape[2]
        fused_weight_dims = [weight_dim // num_chunks] * num_chunks
def _find_upstream_qk_proj(node: Node, gm: GraphModule) -> Optional[str]:
    """
    Find the upstream q/k projection linear node from an RMSNorm input.

    Traverses backwards through pass-through tensor operations (view, reshape,
    type conversions, quantize/dequantize, etc.) to find the immediate producer.
    If that producer is a q_proj or k_proj linear, returns the weight name.

    Args:
        node: The input node to start traversing from (typically RMSNorm's activation input)
        gm: The graph module containing the nodes

    Returns:
        The weight name if a q_proj or k_proj is found as immediate upstream, None otherwise.

    TODO: Is there a more efficient way to do this?
    """
    # Pass-through ops that we traverse through (these don't change the semantic meaning)
    passthrough_ops = [
        torch.ops.aten.view,
        torch.ops.aten.reshape,
        torch.ops.aten.contiguous,
        torch.ops.aten.clone,
        torch.ops.aten.to,
        torch.ops.aten._to_copy,
        torch.ops.aten.slice,
        torch.ops.aten.transpose,
        torch.ops.aten.permute,
    ]

    visited = set()
    current = node

    # Traverse through pass-through ops to find the actual producer
    while current is not None and current not in visited:
        visited.add(current)

        # Check if this is a linear operation (the producer we're looking for)
        if is_any_lin_op(current):
            try:
                weight_name = extract_weight_name(current)
                if weight_name and ("q_proj" in weight_name or "k_proj" in weight_name):
                    return weight_name
            except (AttributeError, AssertionError):
                pass
            # Found a linear but not q/k proj - this is not a QK norm
            return None

        # If this is a pass-through op, continue to its first input
        if is_op(current, passthrough_ops):
            # Get the first tensor input (skip non-tensor args like dims)
            tensor_inputs = [arg for arg in current.all_input_nodes if isinstance(arg, Node)]
            if tensor_inputs:
                current = tensor_inputs[0]
                continue

        # Hit a non-passthrough, non-linear op - stop searching
        break

    return None
def _shard_qk_norm(
    layer_subgraph: LayerSubgraph,
    linear_nodes: List[Node],
    transform_container: ShardingTransformContainer,
) -> int:
    """
    Shard RMSNorm ops in attention layers that operate on full hidden size.

    This function detects torch_rmsnorm ops that are true QK norms - i.e., norms that
    operate on the output of q_proj or k_proj. These global QK norms operate on
    flattened Q/K output [batch, seq, hidden_size] before reshape and need special
    handling for tensor parallelism:
    1. Weight sharding: The norm weight is column-sharded across ranks
    2. Global mean: Replaced with sharded_rmsnorm which uses all_reduce

    Detection criteria:
    1. The RMSNorm's input must trace back to a q_proj or k_proj linear operation
    2. The weight shape must match q/k projection output dimensions

    Example1: - Global Norm on all heads directly on flattened Q/K output (e.g. MiniMax):
        self.q_norm = MiniMaxM2RMSNorm(self.head_dim * config.num_attention_heads, eps=config.rms_norm_eps)
        weight shape: [num_heads * head_dim] -> matches valid_sizes -> will be sharded
        Status: Needs sharding (weight matches q projection output dim)

    Example2: - Norm per head after reshaping to [batch, seq, num_heads, head_dim] (e.g. GLM 4.7):
        self.q_norm = Glm4MoeRMSNorm(self.head_dim, eps=config.rms_norm_eps)
        weight shape: [head_dim] -> does NOT match valid_sizes -> skipped
        Status: No sharding needed (weight doesn't match q/k output dims)

    Args:
        layer_subgraph: The attention layer subgraph
        linear_nodes: The linear nodes (q/k/v projections)
        transform_container: Container for sharding transformations

    Returns:
        Number of nodes added for sharding
    """
    if layer_subgraph.layer_type != LayerType.ATTENTION or layer_subgraph.terminating_node is None:
        return 0

    config = transform_container.config
    world_size = config.world_size
    gm = linear_nodes[0].graph.owning_module
    added_nodes = 0

    # Collect valid output dimensions from q/k/v projections
    # These are the only valid sizes for intermediate weights (e.g., q_norm, k_norm)
    valid_sizes = set()
    for lin_node in linear_nodes:
        try:
            wkey = extract_weight_name(lin_node)
            w = gm.get_parameter(wkey)
            valid_sizes.add(w.shape[0])  # q: num_heads*head_dim, k/v: num_kv_heads*head_dim
        except (AttributeError, AssertionError):
            pass

    # Find all intermediate weight nodes between q/k/v projections and o_proj.
    intermediate_weight_nodes = subgraph(
        sources=linear_nodes,
        sinks=[layer_subgraph.terminating_node],
        include=lambda n: n.op == "get_attr",
    )

    for weight_node in intermediate_weight_nodes:
        weight_key = weight_node.target

        # First check: is this weight consumed by a torch_rmsnorm op
        if len(list(weight_node.users)) == 0:
            continue

        user_node = list(weight_node.users)[0]
        if not is_op(user_node, torch.ops.auto_deploy.torch_rmsnorm):
            continue

        # Verify this is a true QK norm by checking its input traces back to q_proj or k_proj
        # This filters out input_layernorm which feeds INTO q/k/v projections (not after them)
        rmsnorm_input = user_node.args[0]  # activation input (not the weight)
        upstream_proj = _find_upstream_qk_proj(rmsnorm_input, gm)
        if upstream_proj is None:
            ad_logger.debug(
                f"Skipping {user_node.name} - input does not trace back to q_proj or k_proj "
                f"(likely input_layernorm or other non-QK norm)"
            )
            continue

        ad_logger.debug(f"Found QK norm {user_node.name} with upstream projection: {upstream_proj}")

        # Try to get the parameter
        try:
            param = gm.get_parameter(weight_key)
        except AttributeError:
            ad_logger.debug(f"Could not get parameter for {weight_key}, skipping")
            continue

        # Only shard 1D weights (RMSNorm weights are always 1D)
        if param.dim() != 1:
            ad_logger.debug(
                f"Skipping intermediate weight {weight_key} with dim={param.dim()} (not 1D)"
            )
            continue

        # Check if the weight size is divisible by world_size
        if param.shape[0] % world_size != 0:
            ad_logger.debug(
                f"Skipping intermediate weight {weight_key} with shape {param.shape} "
                f"(not divisible by world_size={world_size})"
            )
            continue

        # Check if weight size matches one of the q/k/v projection output dimensions
        # This filters out per-head norms and unrelated weights like inv_freq
        if valid_sizes and param.shape[0] not in valid_sizes:
            ad_logger.debug(
                f"Skipping intermediate weight {weight_key} with shape {param.shape} "
                f"(not in valid_sizes={valid_sizes}, likely per-head norm)"
            )
            continue

        # Add RMSNormShardingInfo to replace with sharded_rmsnorm
        # This handles both weight sharding and op replacement in one transform
        if transform_container.add(
            RMSNormShardingInfo.from_node(
                user_node,
                config=config,
                world_size=world_size,
            )
        ):
            ad_logger.debug(
                f"Added RMSNormShardingInfo for {user_node.name} "
                f"(will replace with sharded_rmsnorm for global mean)"
            )
            added_nodes += 1

    return added_nodes


def _process_column_sharding(
@@ -1955,6 +2246,10 @@ def _process_column_sharding(
            )
        )
    # chunk nodes do not need to be updated

    # Shard intermediate weights (e.g. q/k/v -> q_norm, k_norm ... -> o_proj) for attention layers
    added_nodes += _shard_qk_norm(layer_subgraph, linear_nodes, transform_container)

    return added_nodes
@@ -967,7 +967,9 @@ def get_layer_after_linear_node(
         min_local_shape=min_local_shape,
     )
     assert linear_nodes[start_lin_index] in opening_linear_nodes, (
-        "Linear node not found in opening linear nodes"
+        f"Linear node not found in opening linear nodes - "
+        f"terminating_linear_node:{terminating_linear_node.name}, "
+        f"opening_linear_nodes: {[n.name for n in opening_linear_nodes]}"
     )

     # return the index of the terminating linear node
@@ -0,0 +1,171 @@
"""Functional tests for sharded_rmsnorm custom op.

This module tests that the sharded_rmsnorm op produces correct
numerical results when run across multiple GPUs with column-sharded inputs.

The op computes RMSNorm on column-sharded activations by using all_reduce to
compute the global mean of squared values across all shards, ensuring correct
normalization equivalent to non-sharded global RMSNorm.
"""

import pickle
import sys
import traceback

import cloudpickle
import pytest
import torch
import torch.distributed as dist
from mpi4py import MPI

from tensorrt_llm._torch.auto_deploy.distributed.common import initialize, is_initialized

# Register this module for cloudpickle serialization for MPI workers
cloudpickle.register_pickle_by_value(sys.modules[__name__])
MPI.pickle.__init__(
    cloudpickle.dumps,
    cloudpickle.loads,
    pickle.HIGHEST_PROTOCOL,
)

# needed since we reuse the mpi executor pool, first test running will leak a thread
pytestmark = pytest.mark.threadleak(enabled=False)


def _reference_rmsnorm(input: torch.Tensor, weight: torch.Tensor, eps: float) -> torch.Tensor:
    """Reference RMSNorm implementation for comparison."""
    input_fp32 = input.to(torch.float32)
    variance = input_fp32.pow(2).mean(-1, keepdim=True)
    input_normalized = input_fp32 * torch.rsqrt(variance + eps)
    return weight * input_normalized.to(input.dtype)


def _run_sharded_rmsnorm_test(
    tensor_parallel_size: int,
    batch_size: int = 2,
    seq_len: int = 8,
    hidden_size: int = 64,
    eps: float = 1e-6,
    dtype_str: str = "float16",
):
    """Test that sharded_rmsnorm matches non-sharded global RMSNorm.

    Each rank:
    1. Creates identical full input and weight tensors (same seed)
    2. Computes reference result using non-sharded RMSNorm
    3. Column-shards input and weight for this rank
    4. Calls sharded_rmsnorm on local shard
    5. All-gathers results and compares with reference
    """
    # Import inside worker to avoid cloudpickle serialization issues with torch.ops
    import tensorrt_llm
    import tensorrt_llm._torch.auto_deploy.custom_ops  # noqa: F401

    rank = tensorrt_llm.mpi_rank()
    world_size = tensor_parallel_size
    torch.cuda.set_device(rank)

    if not is_initialized():
        initialize(rank, port=29500)

    # Map string to torch.dtype inside worker to avoid cloudpickle serialization issues
    dtype_map = {"float16": torch.float16, "bfloat16": torch.bfloat16}
    dtype = dtype_map[dtype_str]

    try:
        torch.manual_seed(42)  # Same seed for reproducibility across ranks

        # Full tensors (same on all ranks due to seed)
        full_input = torch.randn(batch_size, seq_len, hidden_size, device="cuda", dtype=dtype)
        full_weight = torch.randn(hidden_size, device="cuda", dtype=dtype)

        # Reference: non-sharded global RMSNorm
        reference_output = _reference_rmsnorm(full_input, full_weight, eps)

        # Column shard: split hidden_size across ranks
        local_hidden_size = hidden_size // world_size
        start_idx = rank * local_hidden_size
        end_idx = start_idx + local_hidden_size

        local_input = full_input[..., start_idx:end_idx].contiguous()
        local_weight = full_weight[start_idx:end_idx].contiguous()

        # Call sharded_rmsnorm using dynamic getattr to avoid cloudpickle capturing torch.ops
        sharded_rmsnorm_op = getattr(getattr(torch, "ops"), "auto_deploy").sharded_rmsnorm
        local_output = sharded_rmsnorm_op(local_input, local_weight, eps, world_size)

        # All-gather to reconstruct full output
        gathered_outputs = [torch.zeros_like(local_output) for _ in range(world_size)]
        dist.all_gather(gathered_outputs, local_output)
        reconstructed_output = torch.cat(gathered_outputs, dim=-1)

        # Compare with reference
        torch.testing.assert_close(
            reconstructed_output,
            reference_output,
            atol=1e-2,
            rtol=1e-2,
            msg=f"sharded_rmsnorm result doesn't match reference (rank={rank})",
        )
    except Exception:
        traceback.print_exc()
        raise

    return True


@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="Requires at least 2 GPUs for this test")
@pytest.mark.parametrize("mpi_pool_executor", [2], indirect=True)
def test_sharded_rmsnorm_functional(mpi_pool_executor):
    """Functional test: verify sharded_rmsnorm produces correct numerical results."""
    torch.manual_seed(0)
    tensor_parallel_size = mpi_pool_executor.num_workers

    results = mpi_pool_executor.map(
        _run_sharded_rmsnorm_test,
        *zip(*[(tensor_parallel_size,)] * tensor_parallel_size),
    )
    for r in results:
        assert r is True


def _run_sharded_rmsnorm_hidden_size_test(tensor_parallel_size: int, hidden_size: int):
    """Worker function for hidden size parametrized test."""
    return _run_sharded_rmsnorm_test(tensor_parallel_size, hidden_size=hidden_size)


@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="Requires at least 2 GPUs for this test")
@pytest.mark.parametrize("mpi_pool_executor", [2], indirect=True)
@pytest.mark.parametrize("hidden_size", [64, 128, 256])
def test_sharded_rmsnorm_different_hidden_sizes(mpi_pool_executor, hidden_size):
    """Test sharded_rmsnorm with different hidden sizes."""
    torch.manual_seed(0)
    tensor_parallel_size = mpi_pool_executor.num_workers

    results = mpi_pool_executor.map(
        _run_sharded_rmsnorm_hidden_size_test,
        *zip(*[(tensor_parallel_size, hidden_size)] * tensor_parallel_size),
    )
    for r in results:
        assert r is True


def _run_sharded_rmsnorm_dtype_test(tensor_parallel_size: int, dtype_str: str):
    """Worker function for dtype parametrized test."""
    return _run_sharded_rmsnorm_test(tensor_parallel_size, dtype_str=dtype_str)


@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="Requires at least 2 GPUs for this test")
@pytest.mark.parametrize("mpi_pool_executor", [2], indirect=True)
@pytest.mark.parametrize("dtype_str", ["float16", "bfloat16"])
def test_sharded_rmsnorm_different_dtypes(mpi_pool_executor, dtype_str):
    """Test sharded_rmsnorm with different dtypes."""
    torch.manual_seed(0)
    tensor_parallel_size = mpi_pool_executor.num_workers

    results = mpi_pool_executor.map(
        _run_sharded_rmsnorm_dtype_test,
        *zip(*[(tensor_parallel_size, dtype_str)] * tensor_parallel_size),
    )
    for r in results:
        assert r is True
@@ -0,0 +1,325 @@
"""Tests for RMSNorm sharding detection and transformation.

This module tests that the sharding transform correctly detects and transforms
torch_rmsnorm ops based on weight shape and position in the graph:

1. Full hidden norm (weight shape = [num_heads * head_dim], AFTER q/k projection):
   - Detected as QK norm needing sharding → replaced with sharded_rmsnorm
   - Weight is sharded

2. Per-head norm (weight shape = [head_dim], like GLM):
   - NOT detected as needing sharding → stays as local torch_rmsnorm
   - No transformation needed

3. Input layernorm (feeds INTO q/k/v projections, not after):
   - NOT detected as QK norm → stays as local torch_rmsnorm
   - Even though weight shape matches, it's not a QK norm

These are graph-level unit tests that verify the transform logic.
"""

import torch
import torch.nn as nn

# Ensure custom ops are registered
from tensorrt_llm._torch.auto_deploy.custom_ops import rms_norm  # noqa: F401
from tensorrt_llm._torch.auto_deploy.export import torch_export_to_gm
from tensorrt_llm._torch.auto_deploy.transform.optimizer import InferenceOptimizer
from tensorrt_llm._torch.auto_deploy.utils.node_utils import is_op


class AttentionWithInputLayernorm(nn.Module):
    """Attention module with input_layernorm BEFORE q/k/v projections.

    This simulates the standard Llama pattern where input_layernorm feeds
    INTO the attention projections. This should NOT be detected as a QK norm.
    """

    def __init__(self, hidden_size: int = 64, num_heads: int = 4):
        super().__init__()
        self.hidden_size = hidden_size
        self.num_heads = num_heads
        self.head_dim = hidden_size // num_heads

        # Input layernorm - BEFORE q/k/v projections (like Llama)
        self.input_layernorm_weight = nn.Parameter(torch.ones(hidden_size))
        self.eps = 1e-6

        # Q/K/V/O projections
        self.q_proj = nn.Linear(hidden_size, hidden_size, bias=False)
        self.k_proj = nn.Linear(hidden_size, hidden_size, bias=False)
        self.v_proj = nn.Linear(hidden_size, hidden_size, bias=False)
        self.o_proj = nn.Linear(hidden_size, hidden_size, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        b, s, _ = x.shape

        # Input layernorm BEFORE projections (Llama pattern)
        x_normed = torch.ops.auto_deploy.torch_rmsnorm(x, self.input_layernorm_weight, self.eps)

        q = self.q_proj(x_normed)
        k = self.k_proj(x_normed)
        v = self.v_proj(x_normed)

        # Reshape for attention
        q = q.view(b, s, self.num_heads, self.head_dim)
        k = k.view(b, s, self.num_heads, self.head_dim)
        v = v.view(b, s, self.num_heads, self.head_dim)

        y = torch.ops.auto_deploy.torch_attention(q, k, v, is_causal=True, layout="bsnd")
        y = y.contiguous().view(b, s, -1)

        return self.o_proj(y)


class SimpleAttentionWithQKNorm(nn.Module):
    """Attention module with configurable QK normalization.

    Args:
        hidden_size: Total hidden dimension
        num_heads: Number of attention heads
        use_full_hidden_norm: If True, use full hidden norm (like MiniMax)
            - Weight shape = [hidden_size], norm before reshape
            - Should be transformed to sharded_rmsnorm
            If False, use per-head norm (like GLM)
            - Weight shape = [head_dim], norm after reshape
            - Should NOT be transformed
    """

    def __init__(
        self, hidden_size: int = 64, num_heads: int = 4, use_full_hidden_norm: bool = True
    ):
        super().__init__()
        self.hidden_size = hidden_size
        self.num_heads = num_heads
        self.head_dim = hidden_size // num_heads
        self.use_full_hidden_norm = use_full_hidden_norm

        # Q/K/V/O projections
        self.q_proj = nn.Linear(hidden_size, hidden_size, bias=False)
        self.k_proj = nn.Linear(hidden_size, hidden_size, bias=False)
        self.v_proj = nn.Linear(hidden_size, hidden_size, bias=False)
        self.o_proj = nn.Linear(hidden_size, hidden_size, bias=False)

        # QK norm weights: shape depends on norm type
        norm_dim = hidden_size if use_full_hidden_norm else self.head_dim
        self.q_norm_weight = nn.Parameter(torch.ones(norm_dim))
        self.k_norm_weight = nn.Parameter(torch.ones(norm_dim))
        self.eps = 1e-6

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        b, s, _ = x.shape

        q = self.q_proj(x)
        k = self.k_proj(x)
        v = self.v_proj(x)

        if self.use_full_hidden_norm:
            # Full hidden norm: apply before reshape
            q = torch.ops.auto_deploy.torch_rmsnorm(q, self.q_norm_weight, self.eps)
            k = torch.ops.auto_deploy.torch_rmsnorm(k, self.k_norm_weight, self.eps)
            # Reshape for attention
            q = q.view(b, s, self.num_heads, self.head_dim)
            k = k.view(b, s, self.num_heads, self.head_dim)
            v = v.view(b, s, self.num_heads, self.head_dim)
        else:
            # Reshape first for per-head norm
            q = q.view(b, s, self.num_heads, self.head_dim)
            k = k.view(b, s, self.num_heads, self.head_dim)
            v = v.view(b, s, self.num_heads, self.head_dim)
            # Per-head norm: apply after reshape (broadcasts over heads)
            q = torch.ops.auto_deploy.torch_rmsnorm(q, self.q_norm_weight, self.eps)
            k = torch.ops.auto_deploy.torch_rmsnorm(k, self.k_norm_weight, self.eps)

        y = torch.ops.auto_deploy.torch_attention(q, k, v, is_causal=True, layout="bsnd")
        y = y.contiguous().view(b, s, -1)

        return self.o_proj(y)


def count_ops(gm, op) -> int:
    """Count the number of nodes with a specific op in the graph."""
    count = 0
    for node in gm.graph.nodes:
        if node.op == "call_function" and is_op(node, op):
            count += 1
    return count


# =============================================================================
# Tests
# =============================================================================


class TestRMSNormShardingTransform:
    """Tests for the sharding transform on RMSNorm ops."""

    def test_full_hidden_norm_transformed_to_sharded(self):
        """Test that full hidden norm RMSNorm ops are replaced with sharded_rmsnorm.

        When weight shape = [hidden_size] (matches q_proj output dim):
        - torch_rmsnorm should be replaced with sharded_rmsnorm
        - Weight should be sharded
        """
        model = SimpleAttentionWithQKNorm(use_full_hidden_norm=True).to("cuda", dtype=torch.float16)
        x = torch.randn(1, 8, 64, device="cuda", dtype=torch.float16)

        # Export to graph
        gm = torch_export_to_gm(model, args=(x,), clone=True)

        # Check before transform
        before_rmsnorm = count_ops(gm, torch.ops.auto_deploy.torch_rmsnorm.default)
        before_sharded = count_ops(gm, torch.ops.auto_deploy.sharded_rmsnorm.default)
        assert before_rmsnorm == 2, f"Expected 2 torch_rmsnorm before, got {before_rmsnorm}"
        assert before_sharded == 0, f"Expected 0 sharded_rmsnorm before, got {before_sharded}"

        # Apply sharding transform with world_size=2
        optimizer = InferenceOptimizer(
            None,
            {
                "detect_sharding": {
                    "stage": "sharding",
                    "simple_shard_only": False,
                    "sharding_source": ["manual", "factory", "heuristic"],
                    "support_partial_config": True,
                    "sharding_dims": ["tp", "ep", "bmm"],
                    "shard_all_unprocessed": True,
                    "allreduce_strategy": "NCCL",
                    "dist_backend": "auto",
                    "requires_shape_prop": True,
                },
                "sharding_transform_executor": {
                    "stage": "sharding",
                    "run_shape_prop": True,
                },
            },
        )
        optimizer.shared_config.local_rank = 0
        optimizer.shared_config.world_size = 2
        gm_transformed = optimizer(None, gm)

        # Check after transform
        after_rmsnorm = count_ops(gm_transformed, torch.ops.auto_deploy.torch_rmsnorm.default)
        after_sharded = count_ops(gm_transformed, torch.ops.auto_deploy.sharded_rmsnorm.default)

        # The QK norms (weight matching q/k output dims) should be transformed
        assert after_sharded == 2, (
            f"Expected 2 sharded_rmsnorm after transform, got {after_sharded}. "
            f"Remaining torch_rmsnorm: {after_rmsnorm}"
        )

    def test_per_head_norm_not_transformed(self):
        """Test that per-head norm RMSNorm ops are NOT replaced.

        When weight shape = [head_dim] (doesn't match q_proj output dim):
        - torch_rmsnorm should stay as torch_rmsnorm
        - No sharded_rmsnorm should be added
        """
        model = SimpleAttentionWithQKNorm(use_full_hidden_norm=False).to(
            "cuda", dtype=torch.float16
        )
        x = torch.randn(1, 8, 64, device="cuda", dtype=torch.float16)

        # Export to graph
        gm = torch_export_to_gm(model, args=(x,), clone=True)

        # Check before transform
        before_rmsnorm = count_ops(gm, torch.ops.auto_deploy.torch_rmsnorm.default)
        assert before_rmsnorm == 2, f"Expected 2 torch_rmsnorm before, got {before_rmsnorm}"

        # Apply sharding transform with world_size=2
        # Using the same config as default.yaml for detect_sharding and sharding_transform_executor
        optimizer = InferenceOptimizer(
            None,
            {
                "detect_sharding": {
                    "stage": "sharding",
                    "simple_shard_only": False,
                    "sharding_source": ["manual", "factory", "heuristic"],
                    "support_partial_config": True,
                    "sharding_dims": ["tp", "ep", "bmm"],
                    "shard_all_unprocessed": True,
                    "allreduce_strategy": "NCCL",
                    "dist_backend": "auto",
                    "requires_shape_prop": True,
                },
                "sharding_transform_executor": {
                    "stage": "sharding",
                    "run_shape_prop": True,
                },
            },
        )
        # Set world_size and rank on shared_config
        optimizer.shared_config.local_rank = 0
        optimizer.shared_config.world_size = 2
        gm_transformed = optimizer(None, gm)

        # Check after transform
        after_rmsnorm = count_ops(gm_transformed, torch.ops.auto_deploy.torch_rmsnorm.default)
        after_sharded = count_ops(gm_transformed, torch.ops.auto_deploy.sharded_rmsnorm.default)

        # Per-head norms should NOT be transformed to sharded
        assert after_sharded == 0, (
            f"Expected 0 sharded_rmsnorm for per-head norm, got {after_sharded}"
        )
        # The original rmsnorm ops should still be present (or fewer if some were removed)
        # Note: Some rmsnorm ops may be removed/transformed for other reasons, but none should become sharded
        print(f"After transform: {after_rmsnorm} torch_rmsnorm, {after_sharded} sharded_rmsnorm")

    def test_input_layernorm_not_transformed(self):
        """Test that input_layernorm (before q/k/v projections) is NOT replaced.

        When RMSNorm feeds INTO q/k/v projections (like Llama's input_layernorm):
        - torch_rmsnorm should stay as torch_rmsnorm
        - No sharded_rmsnorm should be added
        - Even though weight shape matches [hidden_size], it's not a QK norm

        This tests the fix for the bug where input_layernorm was incorrectly
        detected as a QK norm because its weight shape matched q_proj output dim.
        """
        model = AttentionWithInputLayernorm().to("cuda", dtype=torch.float16)
        x = torch.randn(1, 8, 64, device="cuda", dtype=torch.float16)

        # Export to graph
        gm = torch_export_to_gm(model, args=(x,), clone=True)

        # Check before transform
        before_rmsnorm = count_ops(gm, torch.ops.auto_deploy.torch_rmsnorm.default)
        assert before_rmsnorm == 1, f"Expected 1 torch_rmsnorm before, got {before_rmsnorm}"

        # Apply sharding transform with world_size=4 (like the failing Llama test)
        optimizer = InferenceOptimizer(
            None,
            {
                "detect_sharding": {
                    "stage": "sharding",
                    "simple_shard_only": False,
                    "sharding_source": ["manual", "factory", "heuristic"],
                    "support_partial_config": True,
                    "sharding_dims": ["tp", "ep", "bmm"],
                    "shard_all_unprocessed": True,
                    "allreduce_strategy": "NCCL",
                    "dist_backend": "auto",
                    "requires_shape_prop": True,
                },
                "sharding_transform_executor": {
                    "stage": "sharding",
                    "run_shape_prop": True,
                },
            },
        )
        optimizer.shared_config.local_rank = 0
        optimizer.shared_config.world_size = 4
        gm_transformed = optimizer(None, gm)

        # Check after transform
        after_rmsnorm = count_ops(gm_transformed, torch.ops.auto_deploy.torch_rmsnorm.default)
        after_sharded = count_ops(gm_transformed, torch.ops.auto_deploy.sharded_rmsnorm.default)

        # Input layernorm should NOT be transformed to sharded_rmsnorm
        # because it feeds INTO q/k/v projections, not after them
        assert after_sharded == 0, (
            f"Expected 0 sharded_rmsnorm for input_layernorm, got {after_sharded}. "
            f"input_layernorm should not be detected as a QK norm."
        )
        print(f"After transform: {after_rmsnorm} torch_rmsnorm, {after_sharded} sharded_rmsnorm")
@@ -0,0 +1,138 @@
"""Testing module patches that enable export of MiniMax-M2 model.

This test verifies that the patched MiniMaxM2SparseMoeBlock forward function
produces identical outputs to the original HuggingFace implementation.
"""

import types

import pytest
import torch
from test_common.llm_data import hf_id_to_local_model_dir
from transformers import AutoConfig, AutoModelForCausalLM

# Import custom_ops to register torch.ops.auto_deploy.torch_moe
import tensorrt_llm._torch.auto_deploy.custom_ops  # noqa: F401
from tensorrt_llm._torch.auto_deploy.models.patches.minimax_m2 import minimax_m2_moe


def _load_minimax_m2_moe_layer(model_name_or_path):
    """
    Loads the MoE layer from MiniMax-M2 model with a minimal configuration.

    We create a small model to keep tests fast while still exercising the
    MoE routing and computation logic.

    Parameters:
        model_name_or_path (str): Path or name of the pretrained model.

    Returns:
        module: The MiniMaxM2SparseMoeBlock layer.
    """
    try:
        # Load only the model configuration
        config = AutoConfig.from_pretrained(model_name_or_path, trust_remote_code=True)

        # Configure minimal model for fast testing
        config.num_hidden_layers = 1
        config.use_cache = False
        config.hidden_size = 16  # Small hidden size
        config.intermediate_size = 8  # For MLP within experts
        config.mlp_intermediate_size = 32
        config.num_local_experts = 4  # Small number of experts
        config.num_experts_per_tok = 2  # Top-k experts
        config.num_attention_heads = 2
        config.num_key_value_heads = 2
        config.router_jitter_noise = 0.0  # Disable jitter for deterministic tests

        # Build the model architecture (no weights loaded)
        # Note: Importing minimax_m2 module auto-patches from_config, so the
        # instance's forward is already patched. But the CLASS method is still
        # the original HF implementation, which we use as reference.
        model = AutoModelForCausalLM.from_config(config, trust_remote_code=True)
        model.eval()

        # Access the MoE layer
        layer_name = "model.layers.0.block_sparse_moe"
        module = dict(model.named_modules()).get(layer_name)
        if module is None:
            print(f"Layer '{layer_name}' not found in the model.")
        else:
            print(f"Successfully extracted layer '{layer_name}'.")
        return module
    except Exception as e:
        print(f"Error extracting layer: {e}")
        return None


@pytest.mark.parametrize(
    "model_name",
    [
        pytest.param(
            hf_id_to_local_model_dir("MiniMaxAI/MiniMax-M2"),
        ),
    ],
)
def test_minimax_m2_moe_patch(model_name):
    """
    Test that the patched MiniMaxM2SparseMoeBlock forward produces the same
    output as the original HuggingFace implementation.

    The patch rewrites the forward to use torch.ops.auto_deploy.torch_moe
    for torch.export compatibility while maintaining numerical equivalence.

    Since importing minimax_m2.py auto-patches module instances, we use the
    CLASS method (type(module).forward) as the original HF reference.
    """
    # Set seed for reproducibility
    torch.manual_seed(42)

    # Get MoE module (instance is already patched by import side-effect)
    module = _load_minimax_m2_moe_layer(model_name)
    assert module is not None, "Failed to load MiniMax-M2 MoE layer"

    # Convert module to bfloat16 to match input dtype
    module = module.to(torch.bfloat16)

    # Create test input - same input will be used for both original and patched
    # hidden_size=16 matches the config in _load_minimax_m2_moe_layer
    hidden_size = 16
    inputs = torch.randn(2, 6, hidden_size, dtype=torch.bfloat16)

    # The CLASS method is still the original HuggingFace implementation
    # (the auto-patch only patches instance methods, not the class)
    original_class_forward = type(module).forward

    # Generate reference output using original HF class method
    # Uses: same module weights, same input tensor
    with torch.no_grad():
        ref_output, ref_router_logits = original_class_forward(module, inputs)

    # The instance forward is already patched by the import side-effect,
    # but let's be explicit and apply our patch function directly
    module.forward = types.MethodType(minimax_m2_moe, module)

    # Generate test output using patched implementation
    # Uses: same module weights, same input tensor
    with torch.no_grad():
        test_output, test_router_logits = module(inputs)

    # Verify outputs match
    # Router logits should be identical (same computation path)
    torch.testing.assert_close(
        ref_router_logits,
        test_router_logits,
        atol=1e-5,
        rtol=1e-5,
        msg="Router logits mismatch between original and patched MoE",
    )

    # Final hidden states should be very close
    # (small tolerance for different computation order in torch_moe)
    torch.testing.assert_close(
        ref_output,
        test_output,
        atol=1e-3,
        rtol=1e-3,
        msg="Output mismatch between original and patched MoE",
    )