[TRTLLM-10673][feat] Improved layer classification for sharding (#10718)

Signed-off-by: greg-kwasniewski1 <213329731+greg-kwasniewski1@users.noreply.github.com>
Grzegorz Kwasniewski 2026-02-04 18:06:10 +01:00 committed by GitHub
parent 925d911fc0
commit d90a8e5700
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 375 additions and 127 deletions


@ -17,8 +17,8 @@ from ...custom_ops.quant import (
from ...models.factory import ModelFactory
from ...shim.interface import CachedSequenceInterface
from ...utils.node_utils import (
extract_weight_nodes,
get_quantization_params_from_linear_node,
get_weight_info,
is_bmm_op,
is_linear_op,
)
@ -141,9 +141,10 @@ class Quantization(BaseTransform):
The state_dict is also updated to contain the sharded weights.
"""
weight_nodes = extract_weight_nodes(node)
assert len(weight_nodes.weights) == 1, "Expected exactly one weight node"
lin_weight = weight_nodes.weights[0]
lin_weight = get_weight_info(node)
if lin_weight is None:
raise ValueError(f"Linear node {node.name} has no weight")
new_param = nn.Parameter(self.quantize_weight(lin_weight.tensor), requires_grad=False)
modname, _, attrname = lin_weight.node_key.rpartition(".")
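
For illustration, a minimal sketch of the new WeightInfo-based access pattern, assuming gm is a traced GraphModule and that get_weight_info, filtered_nodes, and is_linear_op from node_utils are in scope (import path omitted here):

# Sketch only: inspect the primary weight of each linear node via the new helper.
for lin_node in filtered_nodes(gm.graph.nodes, is_linear_op):
    info = get_weight_info(lin_node)
    if info is None:
        continue  # node has no registered weight parameter
    # WeightInfo bundles the get_attr node, its fully-qualified name, the tensor, and the owning submodule
    print(info.node_key, tuple(info.tensor.shape))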


@ -40,11 +40,10 @@ from ...utils.node_utils import (
LayerType,
bfs,
extract_weight_name,
extract_weight_nodes,
filtered_nodes,
get_all_layer_subgraphs,
get_all_weight_infos,
get_all_weights_in_subgraph,
get_layer_after_linear_node,
is_any_attention_op,
is_any_lin_op,
is_any_moe_op,
@ -1296,6 +1295,11 @@ def _shard_parameter_node(
rank, world_size = config.rank, config.world_size
allreduce_strategy = config.allreduce_strategy.name
if "sharded" in node.meta and node.meta["sharded"]:
# Node was already sharded, skip
return
node.meta["sharded"] = True
num_users = num_users_of_weight_node(node)
if num_users > 1 or num_users == 0:
ad_logger.warning(
@ -1304,12 +1308,17 @@ def _shard_parameter_node(
return
# Shard weight using the unified function (also updates the parameter)
weight_nodes = extract_weight_nodes(node)
for weight_node in weight_nodes.weights:
all_weight_infos = get_all_weight_infos(node)
# Parametrized nodes must have at least one weight (for debugging)
assert len(all_weight_infos.weights) > 0, (
f"Node {node.name} has no weights - weight mapping may be incorrect"
)
for weight_info in all_weight_infos.weights:
_, weight_new_shape = shard_weight_tensor(
gm=gm,
weight_tensor=weight_node.tensor,
param_key=weight_node.node_key,
weight_tensor=weight_info.tensor,
param_key=weight_info.node_key,
dim=dim,
rank=rank,
world_size=world_size,
@ -1319,29 +1328,29 @@ def _shard_parameter_node(
if quantization_cb is not None:
quantization_cb(
gm=gm,
submod=weight_node.submod,
submod=weight_info.submod,
node=node,
weight_key=weight_node.node_key,
weight_key=weight_info.node_key,
weight_new_shape=weight_new_shape,
dim=dim,
rank=rank,
world_size=world_size,
)
for bias_node in weight_nodes.biases:
for bias_info in all_weight_infos.biases:
if dim == 0:
# update bias for dim 0 --> we can handle it like the weight
shard_weight_tensor(
gm=gm,
weight_tensor=bias_node.tensor,
param_key=bias_node.node_key,
weight_tensor=bias_info.tensor,
param_key=bias_info.node_key,
dim=dim,
rank=rank,
world_size=world_size,
min_local_shape=min_local_shape,
fused_weight_dims=fused_weight_dims,
)
elif bias_node is not None and rank != world_size - 1:
elif rank != world_size - 1:
# update the bias for dim 1 --> in this case only the last rank keeps the bias to avoid
# double counting it. For all other ranks we delete the bias.
args = list(node.args)
@ -1349,10 +1358,10 @@ def _shard_parameter_node(
args[2] = None
node.args = tuple(args)
gm.graph.erase_node(node_bias)
bias_param_name = bias_node.node_key.rpartition(".")[-1]
setattr(bias_node.submod, bias_param_name, None)
bias_param_name = bias_info.node_key.rpartition(".")[-1]
setattr(bias_info.submod, bias_param_name, None)
gm._register_load_state_dict_pre_hook(
partial(_load_hook_remove, param_key=bias_node.node_key)
partial(_load_hook_remove, param_key=bias_info.node_key)
)
# # # column shard with no gather: the output is sharded
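
For intuition, a simplified sketch of what sharding a weight along one dimension across ranks amounts to; this is a conceptual illustration only, not the actual shard_weight_tensor implementation, and it ignores min_local_shape rounding and fused_weight_dims handling:

import torch

def _naive_shard(weight: torch.Tensor, dim: int, rank: int, world_size: int) -> torch.Tensor:
    # Conceptual column/row shard: rank r keeps the r-th contiguous slice along `dim`.
    assert weight.shape[dim] % world_size == 0  # the real code handles uneven/fused splits
    local = weight.shape[dim] // world_size
    return weight.narrow(dim, rank * local, local).contiguous()

w = torch.randn(4096, 11008)
print(_naive_shard(w, dim=0, rank=1, world_size=4).shape)  # torch.Size([1024, 11008])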
@ -2295,47 +2304,37 @@ def detect_sharding_from_config(
raise ValueError(f"Unsupported sharding source: {source}")
tp_plan = config["tp_plan"]
# If the node is inside the attention module, we need to set min_local_shape to the
# head_dim - otherwise, we would risk splitting the heads into smaller shards.
# TODO: is there a better way to check if we are in attention module?
attn_names = [
"attention",
"Attention",
"attn",
"Attn",
"q_proj",
"k_proj",
"v_proj",
"o_proj",
]
num_shards = 0
num_simple_shards = 0
num_row_col_shards = 0
num_attention_shards = 0
num_ssm_shards = 0
head_dim = -1
linear_nodes = list(filtered_nodes(gm.graph.nodes, is_any_lin_op))
# use layer_subgraphs to determine the layer_type
# and check the validity of the sharding transform
layer_subgraphs, unprocessed_linear_nodes = get_all_layer_subgraphs(gm)
for lin_node in linear_nodes:
# use node's weight name to get the module name
weight_name = extract_weight_name(lin_node)
if any(attn_name in weight_name for attn_name in attn_names):
# find the next attention node and infer the head_dim
next_attention_node, _ = bfs(
lin_node, is_any_attention_op, attr_next="users", include_root=False
)
if next_attention_node is None:
# this is the last attention node in the graph. Take the previously found head_dim
assert head_dim != -1, "Head dim not found for the last attention node"
else:
head_dim = shape(next_attention_node)[-1]
min_local_shape = head_dim
layer_type = LayerType.ATTENTION
# get the parent layer_subgraph
layer_subgraph = [
layer
for layer in layer_subgraphs
if lin_node in layer.opening_nodes or lin_node == layer.terminating_node
]
if len(layer_subgraph) == 1:
layer_subgraph = layer_subgraph[0]
layer_type = layer_subgraph.layer_type
else:
min_local_shape = 1
layer_type = LayerType.MLP
if lin_node in unprocessed_linear_nodes:
layer_type = LayerType.UNKNOWN
else:
ad_logger.warning(
f"Failed to find the parent layer_subgraph for linear node {lin_node}. "
f"May result in incorrect sharding."
)
# use regex to find if module_name matches any of the keys in sharding_config
for key in tp_plan.keys():
@ -2349,11 +2348,6 @@ def detect_sharding_from_config(
# we have a match. Get the config for this layer
config = tp_plan[key]
if config in ["colwise", "mamba"]:
cur_node_index = linear_nodes.index(lin_node)
layer_subgraph = get_layer_after_linear_node(
linear_nodes, [cur_node_index - 1], enforce_strict_linear_history=False
)
if config == "colwise":
_process_column_sharding(
layer_subgraph=layer_subgraph,
@ -2366,7 +2360,6 @@ def detect_sharding_from_config(
split_dim=SplitDimension.ROW,
config=transform_container.config,
dist_op="all_reduce",
min_local_shape=min_local_shape,
layer_type=layer_type,
)
):
@ -2393,7 +2386,6 @@ def detect_sharding_from_config(
split_dim=SplitDimension.COLUMN,
config=transform_container.config,
dist_op=None,
min_local_shape=min_local_shape,
layer_type=layer_type,
)
)
@ -2404,7 +2396,6 @@ def detect_sharding_from_config(
split_dim=SplitDimension.ROW,
config=transform_container.config,
dist_op="all_reduce",
min_local_shape=min_local_shape,
layer_type=layer_type,
)
):
@ -2423,7 +2414,6 @@ def detect_sharding_from_config(
split_dim=SplitDimension.COLUMN,
config=transform_container.config,
dist_op="all_gather",
min_local_shape=1,
layer_type=layer_type,
)
):
@ -2536,7 +2526,7 @@ def detect_column_row_shard(
attention_nodes = list(filtered_nodes(layer_subgraph, is_any_attention_op))
min_local_shape = 1
if config.simple_shard_only:
if config.simple_shard_only or layer.layer_type == LayerType.UNKNOWN:
ad_logger.debug(
f"Forcing Simple Shard on nodes: {nodes_linear} with layer type: {layer.layer_type}"
)
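
For reference, a hedged sketch of the kind of tp_plan mapping that detect_sharding_from_config consumes: module-name patterns mapped to sharding styles, matched against each linear node's weight name. The keys and values below are illustrative examples, not an exhaustive or verified list of supported entries:

import re

tp_plan = {
    "layers.*.self_attn.q_proj": "colwise",
    "layers.*.self_attn.o_proj": "rowwise",
    "layers.*.mlp.gate_proj": "colwise",
    "layers.*.mlp.down_proj": "rowwise",
}

def match_plan(weight_name: str):
    # Return the first plan entry whose key (treated as a wildcard pattern) matches the weight name.
    for key, style in tp_plan.items():
        if re.search(key.replace("*", r"\d+"), weight_name):
            return style
    return None

print(match_plan("model.layers.3.self_attn.q_proj.weight"))  # "colwise"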


@ -3,7 +3,6 @@
import operator
from dataclasses import dataclass
from enum import Enum
from functools import partial
from typing import Callable, Iterable, List, Optional, Tuple, Union
import torch
@ -153,10 +152,19 @@ def get_all_weights_in_subgraph(
def extract_weight_name(node: Node) -> Union[str, bool]:
weight_nodes = extract_weight_nodes(node)
if len(weight_nodes.weights) == 0:
"""
Extract the weight parameter name for a compute node.
Args:
node: Compute node (linear, MoE, SSM, etc.)
Returns:
Weight parameter name (str), or False if no weight exists.
"""
weight_node = get_weight_node(node)
if weight_node is None:
return False
return weight_nodes.weights[0].node_key
return weight_node.target
def get_param_or_buffer(tensor_name: str, gm: GraphModule) -> torch.Tensor:
@ -248,8 +256,16 @@ def extract_weight_nodes(node: Node) -> WeightNodes:
def num_users_of_weight_node(node: Node) -> int:
"""Returns the number of users of the weight node of the given parametrized node."""
weight_node = extract_weight_nodes(node).weights[0].node
"""
Get the number of users of the weight node.
Args:
node: Compute node (linear, MoE, SSM, etc.)
Returns:
Number of users of the primary weight node, or 0 if no weight exists.
"""
weight_node = get_weight_node(node)
return len(weight_node.users) if weight_node is not None else 0
@ -373,6 +389,13 @@ def is_any_moe_op(node: Node) -> bool:
)
def is_residual_add(node: Node) -> bool:
if is_op(node, torch.ops.aten.add):
if len(list(filtered_nodes(node.args, is_any_lin_op))) == 1:
return True
return False
def is_any_ssm_op(node: Node) -> bool:
return is_op(
node,
@ -446,6 +469,167 @@ def is_weight_node(node: Node) -> bool:
return node.op == "get_attr" and node.target and has_shape(node) and len(shape(node)) > 0
# Auxiliary ops that may appear between a weight node and its consumer compute node
_WEIGHT_AUX_OPS = frozenset(
{
torch.ops.aten.to.dtype,
torch.ops.aten.view.default,
}
)
def precompute_weight_node_mapping(gm: GraphModule) -> None:
"""
Pre-compute weight-to-consumer mapping for all weight nodes in the graph.
For each weight node (get_attr), finds the consumer compute node by traversing
through auxiliary ops (to.dtype, view.default). Stores the mapping in the consumer
node's metadata:
- node.meta["weight_nodes"]: list of weight nodes (non-bias)
- node.meta["bias_nodes"]: list of bias nodes
This enables O(1) weight node lookup instead of O(depth) backward traversal.
Called automatically on first weight lookup via lazy initialization.
GUARANTEES (verified by assertions for debugging):
- Called exactly once per GraphModule
- No duplicate weight/bias nodes in any consumer's lists
- Each weight node mapped to exactly one consumer
"""
# Early return if already computed
if "_weight_mapping_computed" in gm.meta:
return
gm.meta["_weight_mapping_computed"] = True
for node in gm.graph.nodes:
if not is_weight_node(node):
continue
is_bias = node.target.endswith("bias")
# Find the consumer compute node by traversing through auxiliary ops
current = node
visited = {current}
while True:
# Get users of current node
users = list(current.users.keys())
if not users:
break
# Check if any user is a compute node (not an auxiliary op)
consumer_found = None
aux_node = None
for user in users:
if is_bias:
if "bias_nodes" not in user.meta:
user.meta["bias_nodes"] = []
# ASSERTION: Each weight node should be mapped exactly once
assert node not in user.meta["bias_nodes"], (
f"Duplicate bias node {node.name} found for consumer {user.name}"
)
user.meta["bias_nodes"].append(node)
else:
if "weight_nodes" not in user.meta:
user.meta["weight_nodes"] = []
# ASSERTION: Each weight node should be mapped exactly once
assert node not in user.meta["weight_nodes"], (
f"Duplicate weight node {node.name} found for consumer {user.name}"
)
user.meta["weight_nodes"].append(node)
if user.target in _WEIGHT_AUX_OPS:
# This is an auxiliary op, continue traversing
aux_node = user
else:
# This is a potential consumer compute node
consumer_found = user
break
if consumer_found is not None:
# Found the consumer; stop traversing
break
elif aux_node is not None and aux_node not in visited:
# Continue through auxiliary op
current = aux_node
visited.add(current)
else:
# No more nodes to traverse
break
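
A short usage sketch of the pre-computed mapping, assuming gm is a traced GraphModule: after one pass over the graph, each consumer node carries its weight and bias get_attr nodes in node.meta, so later lookups are dictionary reads instead of backward traversals:

precompute_weight_node_mapping(gm)  # idempotent; guarded by gm.meta["_weight_mapping_computed"]

for node in gm.graph.nodes:
    weights = node.meta.get("weight_nodes", [])  # parameter get_attr nodes whose target does not end in "bias"
    biases = node.meta.get("bias_nodes", [])     # parameter get_attr nodes whose target ends in "bias"
    if weights:
        print(node.name, [w.target for w in weights], [b.target for b in biases])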
def _ensure_weight_mapping(node: Node) -> None:
"""Ensure weight node mapping is computed. Lazily calls precompute if needed."""
gm = node.graph.owning_module
if "_weight_mapping_computed" not in gm.meta or not gm.meta["_weight_mapping_computed"]:
precompute_weight_node_mapping(gm)
def get_weight_node(node: Node) -> Optional[Node]:
"""Get the primary weight node for a compute node"""
_ensure_weight_mapping(node)
weight_nodes = node.meta.get("weight_nodes", [])
return weight_nodes[0] if weight_nodes else None
def get_weight_nodes(node: Node) -> List[Node]:
"""Get all weight nodes for a compute node"""
_ensure_weight_mapping(node)
return node.meta.get("weight_nodes", [])
def get_bias_nodes(node: Node) -> List[Node]:
"""Get all bias nodes for a compute node"""
_ensure_weight_mapping(node)
return node.meta.get("bias_nodes", [])
@dataclass
class WeightInfo:
"""Lightweight weight info extracted from a weight node."""
node: Node
node_key: str
tensor: torch.Tensor
submod: nn.Module
def _weight_node_to_info(weight_node: Node, gm: GraphModule) -> WeightInfo:
"""Convert a weight node to WeightInfo."""
node_key = weight_node.target
tensor = get_param_or_buffer(node_key, gm)
submod = gm.get_submodule(node_key.rpartition(".")[0])
return WeightInfo(node=weight_node, node_key=node_key, tensor=tensor, submod=submod)
def get_weight_info(node: Node) -> Optional[WeightInfo]:
"""Extract weight info for the primary weight of a compute node."""
weight_node = get_weight_node(node)
if weight_node is None:
return None
return _weight_node_to_info(weight_node, node.graph.owning_module)
@dataclass
class AllWeightInfos:
"""Container for all weight and bias infos of a compute node."""
weights: List[WeightInfo]
biases: List[WeightInfo]
def get_all_weight_infos(node: Node) -> AllWeightInfos:
"""Extract all weight and bias infos for a compute node."""
gm = node.graph.owning_module
weight_nodes = get_weight_nodes(node)
bias_nodes = get_bias_nodes(node)
return AllWeightInfos(
weights=[_weight_node_to_info(wn, gm) for wn in weight_nodes],
biases=[_weight_node_to_info(bn, gm) for bn in bias_nodes],
)
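
A brief example of the dataclass-based accessors defined above, assuming lin_node is a linear compute node in a traced GraphModule (the weight mapping is computed lazily on first lookup):

info = get_weight_info(lin_node)  # Optional[WeightInfo] for the primary weight
if info is not None:
    print(info.node_key, tuple(info.tensor.shape))

all_infos = get_all_weight_infos(lin_node)  # AllWeightInfos(weights=[...], biases=[...])
for w in all_infos.weights + all_infos.biases:
    # Each WeightInfo bundles the graph node, its fully-qualified name, the tensor, and the owning submodule.
    print(type(w.submod).__name__, w.node_key, w.tensor.dtype)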
def get_user_if_pattern_match(node, ops, numusers, user_idx: int = 0):
"""Get a user from a node if the node matches a given op set and num of users."""
if node is None:
@ -515,9 +699,12 @@ def identify_regions_between_residuals(gm: GraphModule) -> List[Node]:
return boundary_nodes
def get_all_layer_subgraphs(gm: GraphModule) -> List[List[Node]]:
def get_all_layer_subgraphs(gm: GraphModule) -> tuple[List[LayerSubgraph], set[Node]]:
"""
Get subgraphs corresponding to all consecutive layers (attention, MLP, SSM, MoE) in the graph.
Get subgraphs for all consecutive layers (attention, MLP, SSM, MoE) in the graph.
Pre-computes weight mappings and caches weight shapes for all linear nodes.
Each layer is contained between opening linear layers and a single closing linear layer.
Assumptions:
1. each layer (each subgraph) is contained between a list of opening
@ -546,18 +733,32 @@ def get_all_layer_subgraphs(gm: GraphModule) -> List[List[Node]]:
assert gm.graph.nodes, "Graph is empty"
layer_subgraphs = []
linear_nodes = list(filtered_nodes(gm.graph.nodes, is_any_lin_op))
# Pre-compute weight-to-consumer mapping for O(1) weight node lookup
precompute_weight_node_mapping(gm)
# Cache weight shapes for all linear nodes
for lin_node in linear_nodes:
if "lin_node_shape" not in lin_node.meta:
shape = get_weight_shape(lin_node)
if shape is not None:
lin_node.meta["lin_node_shape"] = shape
# Find the embedding size from the first linear node
embd = get_weight_shape(linear_nodes[0], dim=-1)
if embd is None:
raise ValueError("Failed to extract embedding size from first linear node")
unprocessed_linear_nodes = set(linear_nodes)
assert len(linear_nodes) > 0, "Could not find any linear nodes in the graph"
terminating_indices = [-1]
last_lin_index = terminating_indices[-1] + 1
# for each linear node, find its layer subgraph defined as regions between consecutive linear nodes
# For each linear node, find its layer subgraph defined as regions between consecutive linear nodes.
while last_lin_index < len(linear_nodes):
# opening is the list of linear nodes
# layer_subgraph is the list of nodes between the opening and closing linear nodes
# closing is the last linear node in the layer
layer_subgraph = get_layer_after_linear_node(linear_nodes, terminating_indices)
layer_subgraph = get_layer_after_linear_node(linear_nodes, terminating_indices, embd=embd)
if layer_subgraph.opening_nodes is not None and len(layer_subgraph.opening_nodes) > 0:
unprocessed_linear_nodes -= (
set(layer_subgraph.opening_nodes)
@ -567,7 +768,7 @@ def get_all_layer_subgraphs(gm: GraphModule) -> List[List[Node]]:
layer_subgraphs.append(layer_subgraph)
last_lin_index = terminating_indices[-1] + 1
# unprocessed linear nodes can be "simple sharded".
# Unprocessed linear nodes can be "simple sharded".
return layer_subgraphs, unprocessed_linear_nodes
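
For context, a sketch of how the return value might be consumed downstream; field names follow the LayerSubgraph usage in this diff, while the downstream handling shown is only indicative:

layer_subgraphs, unprocessed_linear_nodes = get_all_layer_subgraphs(gm)

for layer in layer_subgraphs:
    # Opening linear nodes are candidates for column sharding, the terminating node for row sharding;
    # min_local_shape keeps attention/SSM heads from being split across ranks.
    term = layer.terminating_node.name if layer.terminating_node is not None else "<none>"
    print(layer.layer_type, len(layer.opening_nodes), term, layer.min_local_shape)

# Linear nodes that were not assigned to any layer can still be "simple sharded".
print(f"{len(unprocessed_linear_nodes)} linear nodes left for simple sharding")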
@ -805,12 +1006,16 @@ def subgraph(
def get_weight_shape(node: Node, dim: Optional[int] = None) -> Optional[Union[int, List[int]]]:
"""Get the shape of the weight node."""
"""Get weight shape for a linear operation node. Returns None if no weight."""
if not is_any_lin_op(node):
return None
s = list(shape(extract_weight_nodes(node).weights[0].node))
if len(s) == 0:
weight_node = get_weight_node(node)
if weight_node is None:
return None
s = list(shape(weight_node))
if is_fp4_op(node):
# FP4 weights are packed as uint8 type with 2 FP4 values per element
s[-1] *= 2
@ -823,6 +1028,7 @@ def get_weight_shape(node: Node, dim: Optional[int] = None) -> Optional[Union[in
def get_layer_after_linear_node(
linear_nodes: List[Node],
terminating_indices: List[int],
embd: int,
match_on_shapes: bool = True,
enforce_strict_linear_history: bool = True,
) -> LayerSubgraph:
@ -856,37 +1062,42 @@ def get_layer_after_linear_node(
Args:
linear_nodes: List of linear nodes in the graph.
terminating_indices: List of indices of terminating linear nodes.
match_on_shapes: If True, the layer is matched on shapes of the nodes.
If False, the layer is matched on the nodes themselves.
embd: Embedding size for shape matching.
match_on_shapes: If True, match layers on embedding shapes.
enforce_strict_linear_history: If True, enforce strict ordering constraints.
Returns:
LayerSubgraph: The layer subgraph.
LayerSubgraph containing opening nodes, subgraph nodes, and terminating node.
"""
def boundary_condition(
node: Node, embd: Optional[int] = None, dim: Optional[int] = None
) -> bool:
if embd is not None and dim is not None:
def boundary_condition(node: Node, dim: int) -> bool:
if match_on_shapes:
if is_any_lin_op(node):
return node.meta["lin_node_shape"][dim] == embd
return (
# match on embedding size
(is_any_lin_op(node) and get_weight_shape(node, dim=dim) == embd)
or is_any_moe_op(node)
is_any_moe_op(node)
or is_op(node, ops=[torch.ops.aten.sym_size, torch.ops.aten.bmm])
or is_residual_add(node)
)
else:
return (
is_any_lin_op(node)
or is_any_moe_op(node)
or is_op(node, ops=[torch.ops.aten.sym_size, torch.ops.aten.bmm])
or is_residual_add(node)
)
def filter_condition(node: Node, embd: Optional[int] = None, dim: Optional[int] = None) -> bool:
if embd is not None and dim is not None:
return is_any_lin_op(node) and get_weight_shape(node, dim=dim) == embd
def filter_condition(node: Node, dim: int) -> bool:
if match_on_shapes:
if is_any_lin_op(node):
return node.meta["lin_node_shape"][dim] == embd
return False
else:
return is_any_lin_op(node)
lin_nodes_in_subgraph = []
start_lin_index = terminating_indices[-1] + 1
while len(lin_nodes_in_subgraph) != 1:
if start_lin_index >= len(linear_nodes):
terminating_indices.append(len(linear_nodes))
@ -896,30 +1107,39 @@ def get_layer_after_linear_node(
terminating_node=None,
layer_type=LayerType.UNKNOWN,
)
if match_on_shapes:
# get embedding size of the opening linear node
embd = get_weight_shape(linear_nodes[start_lin_index], dim=-1)
# partial init boundary_condition and filter_condition
boundary_condition = partial(boundary_condition, embd=embd, dim=0)
filter_condition = partial(filter_condition, embd=embd, dim=0)
forward_subgraph = subgraph(
sources=[linear_nodes[start_lin_index]], boundary_condition=boundary_condition
sources=[linear_nodes[start_lin_index]],
boundary_condition=lambda n: boundary_condition(n, dim=0),
)
lin_nodes_in_subgraph = list(filtered_nodes(forward_subgraph, filter_condition))
lin_nodes_in_subgraph = list(
filtered_nodes(forward_subgraph, lambda n: filter_condition(n, dim=0))
)
if len(lin_nodes_in_subgraph) > 1:
# We likely crossed the layer boundary. This can happen e.g. with MoLE (latent MoE),
# where the subgraph spanning the closing latent fc2 projection "spills" over into
# consecutive layers.
# In that case, wrap this single linear node in LayerType.UNKNOWN and return.
terminating_indices.append(start_lin_index)
return LayerSubgraph(
opening_nodes=[linear_nodes[start_lin_index]],
subgraph_nodes=[],
terminating_node=linear_nodes[start_lin_index],
layer_type=LayerType.UNKNOWN,
)
start_lin_index += 1
start_lin_index -= 1
terminating_linear_node = lin_nodes_in_subgraph[0]
# for backward pass, match embedding on the dim=0
if match_on_shapes:
boundary_condition = partial(boundary_condition, embd=embd, dim=-1)
filter_condition = partial(filter_condition, embd=embd, dim=-1)
# For backward pass, match embedding on dim=-1
backward_subgraph = subgraph(
sinks=[terminating_linear_node], boundary_condition=boundary_condition
sinks=[terminating_linear_node], boundary_condition=lambda n: boundary_condition(n, dim=-1)
)
# Get all opening linear nodes
opening_linear_nodes = list(
filtered_nodes(backward_subgraph, lambda n: filter_condition(n, dim=-1))
)
# get all opening linear nodes
opening_linear_nodes = list(filtered_nodes(backward_subgraph, filter_condition))
if enforce_strict_linear_history:
# opening nodes must come after the last terminating node
@ -939,32 +1159,65 @@ def get_layer_after_linear_node(
ssm_nodes = list(filtered_nodes(interior_nodes, is_any_ssm_op))
attention_nodes = list(filtered_nodes(interior_nodes, is_any_attention_op))
intermediate_lin_nodes = list(filtered_nodes(interior_nodes, is_any_lin_op))
layer_type = LayerType.MLP
min_local_shape = 1
if len(ssm_nodes) > 0:
assert len(ssm_nodes) == 1, "SSM layer must have exactly one SSM node"
layer_type = LayerType.SSM
# determine head size
min_local_shape = shape(ssm_nodes[0])[-1]
if len(attention_nodes) > 0:
assert len(attention_nodes) == 1, "Attention layer must have exactly one attention node"
layer_type = LayerType.ATTENTION
# determine head size
min_local_shape = shape(attention_nodes[0])[-1]
if len(intermediate_lin_nodes) > 0:
assert len(intermediate_lin_nodes) == 2, (
"MLA layer must have exactly two intermediate linear nodes"
intermediate_weight_nodes = list(
filtered_nodes(
interior_nodes, lambda n: is_weight_node(n) and not is_any_lin_op(list(n.users)[0])
)
assert len(attention_nodes) == 1, "MLA layer must have exactly one attention node"
layer_type = LayerType.MLA
)
####################################################
########## LAYER TYPE CLASSIFICATION ###############
####################################################
def classify_layer_type() -> Tuple[LayerType, int]:
if len(ssm_nodes) + len(attention_nodes) > 1:
return LayerType.UNKNOWN, 1
if len(attention_nodes) == 1:
head_size = shape(attention_nodes[0])[-1]
# check if this is MLA:
# these two intermediate linear nodes are the latent q and kv projections.
if len(intermediate_lin_nodes) == 2:
# MLA has an RMS norm inside, so it should have one (or two, counting the bias)
# intermediate weight nodes
if len(intermediate_weight_nodes) not in [1, 2]:
return LayerType.UNKNOWN, 1
return LayerType.MLA, head_size
else:
if len(intermediate_lin_nodes) != 0:
return LayerType.UNKNOWN, 1
return LayerType.ATTENTION, head_size
if len(ssm_nodes) == 1:
head_size = shape(ssm_nodes[0])[-1]
# Mamba layers should not have any intermediate linear nodes.
if len(intermediate_lin_nodes) > 0:
return LayerType.UNKNOWN, 1
# Mamba layer should have 3 to 6 intermediate weight nodes:
# - conv1d weight
# - A (A_log)
# - D
# - conv1d bias [optional]
# - dt_bias [optional]
# - RMS norm [optional]
if len(intermediate_weight_nodes) not in range(3, 7):
return LayerType.UNKNOWN, 1
return LayerType.SSM, head_size
# If we reach here, the layer is an MLP.
# MLP should not have any intermediate linear or weight nodes.
if len(intermediate_lin_nodes) > 0 or len(intermediate_weight_nodes) > 0:
return LayerType.UNKNOWN, 1
return LayerType.MLP, 1
layer_type, head_size = classify_layer_type()
layer_subgraph = LayerSubgraph(
opening_nodes=opening_linear_nodes,
subgraph_nodes=interior_nodes,
terminating_node=terminating_linear_node,
layer_type=layer_type,
min_local_shape=min_local_shape,
min_local_shape=head_size,
)
assert linear_nodes[start_lin_index] in opening_linear_nodes, (
f"Linear node not found in opening linear nodes - "
@ -986,7 +1239,7 @@ def get_layer_after_linear_node(
"ill-formed layer subgraph"
)
terminating_indices.append(terminating_index)
# otherwise, we are done. We processed the last linear node.
return layer_subgraph
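
To make the shape-matching heuristic concrete, a worked example with assumed Llama-like dimensions (hidden size 4096, FFN size 11008): the opening projections of an MLP layer have weight shape [ffn, hidden] and so match the embedding size on dim=-1, while the terminating down projection has shape [hidden, ffn] and matches it on dim=0.

# Illustrative shapes only; not read from a real model.
embd = 4096
up_proj_shape = (11008, 4096)    # opening node: matches embd on dim=-1
gate_proj_shape = (11008, 4096)  # opening node: matches embd on dim=-1
down_proj_shape = (4096, 11008)  # terminating node: matches embd on dim=0

assert up_proj_shape[-1] == embd and gate_proj_shape[-1] == embd
assert down_proj_shape[0] == embd
# The forward pass from an opening node therefore stops at the down projection (boundary on dim=0),
# and the backward pass from the down projection collects the opening nodes (boundary on dim=-1).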
@ -1001,9 +1254,13 @@ def shape(node: Node) -> Tuple[int, ...]:
def get_weight_tensor(node: Node) -> torch.Tensor:
"""Extract the weight tensor from a node within a GraphModule."""
weight_nodes = extract_weight_nodes(node)
return weight_nodes.weights[0].tensor
"""Extract the weight tensor from a compute node."""
weight_node = get_weight_node(node)
if weight_node is None:
raise ValueError(f"Node {node.name} has no weight")
gm = node.graph.owning_module
return get_param_or_buffer(weight_node.target, gm)
def draw_graph(gm: GraphModule, filename: str):