diff --git a/tensorrt_llm/_torch/auto_deploy/transform/library/quantization.py b/tensorrt_llm/_torch/auto_deploy/transform/library/quantization.py index 295d17aedc..a9254b1294 100644 --- a/tensorrt_llm/_torch/auto_deploy/transform/library/quantization.py +++ b/tensorrt_llm/_torch/auto_deploy/transform/library/quantization.py @@ -17,8 +17,8 @@ from ...custom_ops.quant import ( from ...models.factory import ModelFactory from ...shim.interface import CachedSequenceInterface from ...utils.node_utils import ( - extract_weight_nodes, get_quantization_params_from_linear_node, + get_weight_info, is_bmm_op, is_linear_op, ) @@ -141,9 +141,10 @@ class Quantization(BaseTransform): The state_dict is also updated to contain the sharded weights. """ - weight_nodes = extract_weight_nodes(node) - assert len(weight_nodes.weights) == 1, "Expected exactly one weight node" - lin_weight = weight_nodes.weights[0] + lin_weight = get_weight_info(node) + if lin_weight is None: + raise ValueError(f"Linear node {node.name} has no weight") + new_param = nn.Parameter(self.quantize_weight(lin_weight.tensor), requires_grad=False) modname, _, attrname = lin_weight.node_key.rpartition(".") diff --git a/tensorrt_llm/_torch/auto_deploy/transform/library/sharding.py b/tensorrt_llm/_torch/auto_deploy/transform/library/sharding.py index 1ecf85aeda..f93b506201 100644 --- a/tensorrt_llm/_torch/auto_deploy/transform/library/sharding.py +++ b/tensorrt_llm/_torch/auto_deploy/transform/library/sharding.py @@ -40,11 +40,10 @@ from ...utils.node_utils import ( LayerType, bfs, extract_weight_name, - extract_weight_nodes, filtered_nodes, get_all_layer_subgraphs, + get_all_weight_infos, get_all_weights_in_subgraph, - get_layer_after_linear_node, is_any_attention_op, is_any_lin_op, is_any_moe_op, @@ -1296,6 +1295,11 @@ def _shard_parameter_node( rank, world_size = config.rank, config.world_size allreduce_strategy = config.allreduce_strategy.name + if "sharded" in node.meta and node.meta["sharded"]: + # Node was already sharded, skip + return + node.meta["sharded"] = True + num_users = num_users_of_weight_node(node) if num_users > 1 or num_users == 0: ad_logger.warning( @@ -1304,12 +1308,17 @@ def _shard_parameter_node( return # Shard weight using the unified function (also updates the parameter) - weight_nodes = extract_weight_nodes(node) - for weight_node in weight_nodes.weights: + all_weight_infos = get_all_weight_infos(node) + # Parametrized nodes must have at least one weight (for debugging) + assert len(all_weight_infos.weights) > 0, ( + f"Node {node.name} has no weights - weight mapping may be incorrect" + ) + + for weight_info in all_weight_infos.weights: _, weight_new_shape = shard_weight_tensor( gm=gm, - weight_tensor=weight_node.tensor, - param_key=weight_node.node_key, + weight_tensor=weight_info.tensor, + param_key=weight_info.node_key, dim=dim, rank=rank, world_size=world_size, @@ -1319,29 +1328,29 @@ def _shard_parameter_node( if quantization_cb is not None: quantization_cb( gm=gm, - submod=weight_node.submod, + submod=weight_info.submod, node=node, - weight_key=weight_node.node_key, + weight_key=weight_info.node_key, weight_new_shape=weight_new_shape, dim=dim, rank=rank, world_size=world_size, ) - for bias_node in weight_nodes.biases: + for bias_info in all_weight_infos.biases: if dim == 0: # update bias for dim 0 --> we can handle it like the weight shard_weight_tensor( gm=gm, - weight_tensor=bias_node.tensor, - param_key=bias_node.node_key, + weight_tensor=bias_info.tensor, + param_key=bias_info.node_key, dim=dim, 
rank=rank, world_size=world_size, min_local_shape=min_local_shape, fused_weight_dims=fused_weight_dims, ) - elif bias_node is not None and rank != world_size - 1: + elif rank != world_size - 1: # update the bias for dim 1 --> in this case only the last rank gets the bias to avoid # double counting it. For all other we will delete the bias. args = list(node.args) @@ -1349,10 +1358,10 @@ def _shard_parameter_node( args[2] = None node.args = tuple(args) gm.graph.erase_node(node_bias) - bias_param_name = bias_node.node_key.rpartition(".")[-1] - setattr(bias_node.submod, bias_param_name, None) + bias_param_name = bias_info.node_key.rpartition(".")[-1] + setattr(bias_info.submod, bias_param_name, None) gm._register_load_state_dict_pre_hook( - partial(_load_hook_remove, param_key=bias_node.node_key) + partial(_load_hook_remove, param_key=bias_info.node_key) ) # # # column shard with no gather: the output is sharded @@ -2295,47 +2304,37 @@ def detect_sharding_from_config( raise ValueError(f"Unsupported sharding source: {source}") tp_plan = config["tp_plan"] - # If the node is inside the attention module, we need to set min_local_shape to the - # head_dim - otherwise, we would risk splitting the heads into smaller shards. - # TODO: is there a better way to check if we are in attention module? - attn_names = [ - "attention", - "Attention", - "attn", - "Attn", - "q_proj", - "k_proj", - "v_proj", - "o_proj", - ] - num_shards = 0 num_simple_shards = 0 num_row_col_shards = 0 num_attention_shards = 0 num_ssm_shards = 0 - head_dim = -1 linear_nodes = list(filtered_nodes(gm.graph.nodes, is_any_lin_op)) + # use layer_subgraphs to determine the layer_type + # and check the validity of the sharding transform + layer_subgraphs, unprocessed_linear_nodes = get_all_layer_subgraphs(gm) + for lin_node in linear_nodes: # use node's weight name to get the module name weight_name = extract_weight_name(lin_node) - - if any(attn_name in weight_name for attn_name in attn_names): - # find the next attention node and infer the head_dim - next_attention_node, _ = bfs( - lin_node, is_any_attention_op, attr_next="users", include_root=False - ) - if next_attention_node is None: - # this is the last attention node in the graph. Take the previously found head_dim - assert head_dim != -1, "Head dim not found for the last attention node" - else: - head_dim = shape(next_attention_node)[-1] - min_local_shape = head_dim - layer_type = LayerType.ATTENTION + # get the parent layer_subgraph + layer_subgraph = [ + layer + for layer in layer_subgraphs + if lin_node in layer.opening_nodes or lin_node == layer.terminating_node + ] + if len(layer_subgraph) == 1: + layer_subgraph = layer_subgraph[0] + layer_type = layer_subgraph.layer_type else: - min_local_shape = 1 - layer_type = LayerType.MLP + if lin_node in unprocessed_linear_nodes: + layer_type = LayerType.UNKNOWN + else: + ad_logger.warning( + f"Failed to find the parent layer_subgraph for linear node {lin_node}. " + f"May result in incorrect sharding." + ) # use regex to find if module_name matches any of the keys in sharding_config for key in tp_plan.keys(): @@ -2349,11 +2348,6 @@ def detect_sharding_from_config( # we have a match. 
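A minimal standalone sketch of the per-node layer lookup that `detect_sharding_from_config` now performs: each linear node is attributed to the single layer subgraph that lists it among its opening nodes or as its terminating node, and any node that cannot be attributed is treated as `LayerType.UNKNOWN` and sharded conservatively. The `LayerType`/`LayerSubgraph` classes below are simplified stand-ins rather than the repo's classes, and the `UNKNOWN` fallback in the warning branch is an assumption of this sketch, not something the patch spells out.

```python
from dataclasses import dataclass
from enum import Enum, auto
from typing import List, Optional, Set


class LayerType(Enum):
    ATTENTION = auto()
    MLP = auto()
    UNKNOWN = auto()


@dataclass
class LayerSubgraph:
    opening_nodes: List[str]
    terminating_node: Optional[str]
    layer_type: LayerType


def layer_type_for(lin_node: str, subgraphs: List[LayerSubgraph], unprocessed: Set[str]) -> LayerType:
    """Attribute a linear node to exactly one layer subgraph; otherwise fall back to UNKNOWN."""
    matches = [
        sg
        for sg in subgraphs
        if lin_node in sg.opening_nodes or lin_node == sg.terminating_node
    ]
    if len(matches) == 1:
        return matches[0].layer_type
    if lin_node not in unprocessed:
        # stand-in for ad_logger.warning: the node was expected to have a parent subgraph
        print(f"warning: no unique parent layer subgraph for {lin_node}")
    return LayerType.UNKNOWN  # assumed fallback; UNKNOWN layers are simple-sharded


attn = LayerSubgraph(["q_proj", "k_proj", "v_proj"], "o_proj", LayerType.ATTENTION)
print(layer_type_for("q_proj", [attn], unprocessed=set()))         # LayerType.ATTENTION
print(layer_type_for("lm_head", [attn], unprocessed={"lm_head"}))  # LayerType.UNKNOWN
```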
Get the config for this layer config = tp_plan[key] - if config in ["colwise", "mamba"]: - cur_node_index = linear_nodes.index(lin_node) - layer_subgraph = get_layer_after_linear_node( - linear_nodes, [cur_node_index - 1], enforce_strict_linear_history=False - ) if config == "colwise": _process_column_sharding( layer_subgraph=layer_subgraph, @@ -2366,7 +2360,6 @@ def detect_sharding_from_config( split_dim=SplitDimension.ROW, config=transform_container.config, dist_op="all_reduce", - min_local_shape=min_local_shape, layer_type=layer_type, ) ): @@ -2393,7 +2386,6 @@ def detect_sharding_from_config( split_dim=SplitDimension.COLUMN, config=transform_container.config, dist_op=None, - min_local_shape=min_local_shape, layer_type=layer_type, ) ) @@ -2404,7 +2396,6 @@ def detect_sharding_from_config( split_dim=SplitDimension.ROW, config=transform_container.config, dist_op="all_reduce", - min_local_shape=min_local_shape, layer_type=layer_type, ) ): @@ -2423,7 +2414,6 @@ def detect_sharding_from_config( split_dim=SplitDimension.COLUMN, config=transform_container.config, dist_op="all_gather", - min_local_shape=1, layer_type=layer_type, ) ): @@ -2536,7 +2526,7 @@ def detect_column_row_shard( attention_nodes = list(filtered_nodes(layer_subgraph, is_any_attention_op)) min_local_shape = 1 - if config.simple_shard_only: + if config.simple_shard_only or layer.layer_type == LayerType.UNKNOWN: ad_logger.debug( f"Forcing Simple Shard on nodes: {nodes_linear} with layer type: {layer.layer_type}" ) diff --git a/tensorrt_llm/_torch/auto_deploy/utils/node_utils.py b/tensorrt_llm/_torch/auto_deploy/utils/node_utils.py index 78755d353d..efba02306f 100644 --- a/tensorrt_llm/_torch/auto_deploy/utils/node_utils.py +++ b/tensorrt_llm/_torch/auto_deploy/utils/node_utils.py @@ -3,7 +3,6 @@ import operator from dataclasses import dataclass from enum import Enum -from functools import partial from typing import Callable, Iterable, List, Optional, Tuple, Union import torch @@ -153,10 +152,19 @@ def get_all_weights_in_subgraph( def extract_weight_name(node: Node) -> Union[str, bool]: - weight_nodes = extract_weight_nodes(node) - if len(weight_nodes.weights) == 0: + """ + Extract the weight parameter name for a compute node. + + Args: + node: Compute node (linear, MoE, SSM, etc.) + + Returns: + Weight parameter name (str), or False if no weight exists. + """ + weight_node = get_weight_node(node) + if weight_node is None: return False - return weight_nodes.weights[0].node_key + return weight_node.target def get_param_or_buffer(tensor_name: str, gm: GraphModule) -> torch.Tensor: @@ -248,8 +256,16 @@ def extract_weight_nodes(node: Node) -> WeightNodes: def num_users_of_weight_node(node: Node) -> int: - """Returns the number of users of the weight node of the given parametrized node.""" - weight_node = extract_weight_nodes(node).weights[0].node + """ + Get the number of users of the weight node. + + Args: + node: Compute node (linear, MoE, SSM, etc.) + + Returns: + Number of users of the primary weight node, or 0 if no weight exists. 
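As background for the `get_weight_node`-based rewrites of `extract_weight_name` and `num_users_of_weight_node`, here is a plain `torch.fx` illustration (not the repo's helpers) of the relationship they rely on: a parameter accessed functionally shows up as a `get_attr` node whose `target` is the parameter name and whose user set leads to the consuming compute node.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import fx


class Tiny(nn.Module):
    """Functional linear so the weight/bias appear as get_attr nodes in the traced graph."""

    def __init__(self) -> None:
        super().__init__()
        self.weight = nn.Parameter(torch.randn(16, 8))
        self.bias = nn.Parameter(torch.zeros(16))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return F.linear(x, self.weight, self.bias)


gm = fx.symbolic_trace(Tiny())
for node in gm.graph.nodes:
    if node.op == "get_attr":
        consumer = next(iter(node.users))  # the linear call_function node
        # node.target is what extract_weight_name-style helpers return (e.g. "weight");
        # len(node.users) is what num_users_of_weight_node counts.
        print(f"{node.target} -> {consumer.name} ({len(node.users)} user(s))")
```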
+ """ + weight_node = get_weight_node(node) return len(weight_node.users) if weight_node is not None else 0 @@ -373,6 +389,13 @@ def is_any_moe_op(node: Node) -> bool: ) +def is_residual_add(node: Node) -> bool: + if is_op(node, torch.ops.aten.add): + if len(list(filtered_nodes(node.args, is_any_lin_op))) == 1: + return True + return False + + def is_any_ssm_op(node: Node) -> bool: return is_op( node, @@ -446,6 +469,167 @@ def is_weight_node(node: Node) -> bool: return node.op == "get_attr" and node.target and has_shape(node) and len(shape(node)) > 0 +# Auxiliary ops that may appear between a weight node and its consumer compute node +_WEIGHT_AUX_OPS = frozenset( + { + torch.ops.aten.to.dtype, + torch.ops.aten.view.default, + } +) + + +def precompute_weight_node_mapping(gm: GraphModule) -> None: + """ + Pre-compute weight-to-consumer mapping for all weight nodes in the graph. + + For each weight node (get_attr), finds the consumer compute node by traversing + through auxiliary ops (to.dtype, view.default). Stores the mapping in consumer + node's metadata: + - node.meta["weight_nodes"]: list of weight nodes (non-bias) + - node.meta["bias_nodes"]: list of bias nodes + + This enables O(1) weight node lookup instead of O(depth) backward traversal. + Called automatically on first weight lookup via lazy initialization. + + GUARANTEES (verified by assertions for debugging): + - Called exactly once per GraphModule + - No duplicate weight/bias nodes in any consumer's lists + - Each weight node mapped to exactly one consumer + """ + # Early return if already computed + if "_weight_mapping_computed" in gm.meta: + return + gm.meta["_weight_mapping_computed"] = True + + for node in gm.graph.nodes: + if not is_weight_node(node): + continue + + is_bias = node.target.endswith("bias") + + # Find the consumer compute node by traversing through auxiliary ops + current = node + visited = {current} + + while True: + # Get users of current node + users = list(current.users.keys()) + if not users: + break + + # Check if any user is a compute node (not an auxiliary op) + consumer_found = None + aux_node = None + + for user in users: + if is_bias: + if "bias_nodes" not in user.meta: + user.meta["bias_nodes"] = [] + # ASSERTION: Each weight node should be mapped exactly once + assert node not in user.meta["bias_nodes"], ( + f"Duplicate bias node {node.name} found for consumer {user.name}" + ) + user.meta["bias_nodes"].append(node) + else: + if "weight_nodes" not in user.meta: + user.meta["weight_nodes"] = [] + # ASSERTION: Each weight node should be mapped exactly once + assert node not in user.meta["weight_nodes"], ( + f"Duplicate weight node {node.name} found for consumer {user.name}" + ) + user.meta["weight_nodes"].append(node) + if user.target in _WEIGHT_AUX_OPS: + # This is an auxiliary op, continue traversing + aux_node = user + else: + # This is a potential consumer compute node + consumer_found = user + break + + if consumer_found is not None: + # Found the consumer, return + break + elif aux_node is not None and aux_node not in visited: + # Continue through auxiliary op + current = aux_node + visited.add(current) + else: + # No more nodes to traverse + break + + +def _ensure_weight_mapping(node: Node) -> None: + """Ensure weight node mapping is computed. 
Lazily calls precompute if needed.""" + gm = node.graph.owning_module + if "_weight_mapping_computed" not in gm.meta or not gm.meta["_weight_mapping_computed"]: + precompute_weight_node_mapping(gm) + + +def get_weight_node(node: Node) -> Optional[Node]: + """Get the primary weight node for a compute node""" + _ensure_weight_mapping(node) + weight_nodes = node.meta.get("weight_nodes", []) + return weight_nodes[0] if weight_nodes else None + + +def get_weight_nodes(node: Node) -> List[Node]: + """Get all weight nodes for a compute node""" + _ensure_weight_mapping(node) + return node.meta.get("weight_nodes", []) + + +def get_bias_nodes(node: Node) -> List[Node]: + """Get all bias nodes for a compute node""" + _ensure_weight_mapping(node) + return node.meta.get("bias_nodes", []) + + +@dataclass +class WeightInfo: + """Lightweight weight info extracted from a weight node.""" + + node: Node + node_key: str + tensor: torch.Tensor + submod: nn.Module + + +def _weight_node_to_info(weight_node: Node, gm: GraphModule) -> WeightInfo: + """Convert a weight node to WeightInfo.""" + node_key = weight_node.target + tensor = get_param_or_buffer(node_key, gm) + submod = gm.get_submodule(node_key.rpartition(".")[0]) + return WeightInfo(node=weight_node, node_key=node_key, tensor=tensor, submod=submod) + + +def get_weight_info(node: Node) -> Optional[WeightInfo]: + """Extract weight info for the primary weight of a compute node.""" + weight_node = get_weight_node(node) + if weight_node is None: + return None + return _weight_node_to_info(weight_node, node.graph.owning_module) + + +@dataclass +class AllWeightInfos: + """Container for all weight and bias infos of a compute node.""" + + weights: List[WeightInfo] + biases: List[WeightInfo] + + +def get_all_weight_infos(node: Node) -> AllWeightInfos: + """Extract all weight and bias infos for a compute node.""" + gm = node.graph.owning_module + weight_nodes = get_weight_nodes(node) + bias_nodes = get_bias_nodes(node) + + return AllWeightInfos( + weights=[_weight_node_to_info(wn, gm) for wn in weight_nodes], + biases=[_weight_node_to_info(bn, gm) for bn in bias_nodes], + ) + + def get_user_if_pattern_match(node, ops, numusers, user_idx: int = 0): """Get a user from a node if the node matches a given op set and num of users.""" if node is None: @@ -515,9 +699,12 @@ def identify_regions_between_residuals(gm: GraphModule) -> List[Node]: return boundary_nodes -def get_all_layer_subgraphs(gm: GraphModule) -> List[List[Node]]: +def get_all_layer_subgraphs(gm: GraphModule) -> tuple[List[LayerSubgraph], set[Node]]: """ - Get subgraphs corresponding to all consecutive layers (attention, MLP, SSM, MoE) in the graph. + Get subgraphs for all consecutive layers (attention, MLP, SSM, MoE) in the graph. + + Pre-computes weight mappings and caches weight shapes for all linear nodes. + Each layer is contained between opening linear layers and a single closing linear layer. Assumptions: 1. 
each layer (each subgraph) is contained between a list of opening @@ -546,18 +733,32 @@ def get_all_layer_subgraphs(gm: GraphModule) -> List[List[Node]]: assert gm.graph.nodes, "Graph is empty" layer_subgraphs = [] linear_nodes = list(filtered_nodes(gm.graph.nodes, is_any_lin_op)) + + # Pre-compute weight-to-consumer mapping for O(1) weight node lookup + precompute_weight_node_mapping(gm) + + # Cache weight shapes for all linear nodes + for lin_node in linear_nodes: + if "lin_node_shape" not in lin_node.meta: + shape = get_weight_shape(lin_node) + if shape is not None: + lin_node.meta["lin_node_shape"] = shape + + # Find the embedding size from the first linear node + embd = get_weight_shape(linear_nodes[0], dim=-1) + if embd is None: + raise ValueError("Failed to extract embedding size from first linear node") + unprocessed_linear_nodes = set(linear_nodes) assert len(linear_nodes) > 0, "Could not find any linear nodes in the graph" terminating_indices = [-1] last_lin_index = terminating_indices[-1] + 1 - # for each linear node, find its layer subgraph defined as regions between consecutive linear nodes + # For each linear node, find its layer subgraph defined as regions between consecutive linear nodes. while last_lin_index < len(linear_nodes): - # opening is the list of linear nodes - # layer_subgraph is the list of nodes between the opening and closing linear nodes - # closing is the last linear node in the layer - layer_subgraph = get_layer_after_linear_node(linear_nodes, terminating_indices) + layer_subgraph = get_layer_after_linear_node(linear_nodes, terminating_indices, embd=embd) + if layer_subgraph.opening_nodes is not None and len(layer_subgraph.opening_nodes) > 0: unprocessed_linear_nodes -= ( set(layer_subgraph.opening_nodes) @@ -567,7 +768,7 @@ def get_all_layer_subgraphs(gm: GraphModule) -> List[List[Node]]: layer_subgraphs.append(layer_subgraph) last_lin_index = terminating_indices[-1] + 1 - # unprocessed linear nodes can be "simple sharded". + # Unprocessed linear nodes can be "simple sharded". return layer_subgraphs, unprocessed_linear_nodes @@ -805,12 +1006,16 @@ def subgraph( def get_weight_shape(node: Node, dim: Optional[int] = None) -> Optional[Union[int, List[int]]]: - """Get the shape of the weight node.""" + """Get weight shape for a linear operation node. Returns None if no weight.""" if not is_any_lin_op(node): return None - s = list(shape(extract_weight_nodes(node).weights[0].node)) - if len(s) == 0: + + weight_node = get_weight_node(node) + if weight_node is None: return None + + s = list(shape(weight_node)) + if is_fp4_op(node): # FP4 weights are packed as uint8 type with 2 FP4 values per element s[-1] *= 2 @@ -823,6 +1028,7 @@ def get_weight_shape(node: Node, dim: Optional[int] = None) -> Optional[Union[in def get_layer_after_linear_node( linear_nodes: List[Node], terminating_indices: List[int], + embd: int, match_on_shapes: bool = True, enforce_strict_linear_history: bool = True, ) -> LayerSubgraph: @@ -856,37 +1062,42 @@ def get_layer_after_linear_node( Args: linear_nodes: List of linear nodes in the graph. terminating_indices: List of indices of terminating linear nodes. - match_on_shapes: If True, the layer is matched on shapes of the nodes. - If False, the layer is matched on the nodes themselves. + embd: Embedding size for shape matching. + match_on_shapes: If True, match layers on embedding shapes. + enforce_strict_linear_history: If True, enforce strict ordering constraints. + Returns: - LayerSubgraph: The layer subgraph. 
+        LayerSubgraph containing opening nodes, subgraph nodes, and terminating node.
     """
 
-    def boundary_condition(
-        node: Node, embd: Optional[int] = None, dim: Optional[int] = None
-    ) -> bool:
-        if embd is not None and dim is not None:
+    def boundary_condition(node: Node, dim: int) -> bool:
+        if match_on_shapes:
+            if is_any_lin_op(node):
+                return node.meta["lin_node_shape"][dim] == embd
             return (
-                # match on embedding size
-                (is_any_lin_op(node) and get_weight_shape(node, dim=dim) == embd)
-                or is_any_moe_op(node)
+                is_any_moe_op(node)
                 or is_op(node, ops=[torch.ops.aten.sym_size, torch.ops.aten.bmm])
+                or is_residual_add(node)
             )
         else:
             return (
                 is_any_lin_op(node)
                 or is_any_moe_op(node)
                 or is_op(node, ops=[torch.ops.aten.sym_size, torch.ops.aten.bmm])
+                or is_residual_add(node)
             )
 
-    def filter_condition(node: Node, embd: Optional[int] = None, dim: Optional[int] = None) -> bool:
-        if embd is not None and dim is not None:
-            return is_any_lin_op(node) and get_weight_shape(node, dim=dim) == embd
+    def filter_condition(node: Node, dim: int) -> bool:
+        if match_on_shapes:
+            if is_any_lin_op(node):
+                return node.meta["lin_node_shape"][dim] == embd
+            return False
         else:
             return is_any_lin_op(node)
 
     lin_nodes_in_subgraph = []
     start_lin_index = terminating_indices[-1] + 1
+
     while len(lin_nodes_in_subgraph) != 1:
         if start_lin_index >= len(linear_nodes):
             terminating_indices.append(len(linear_nodes))
@@ -896,30 +1107,39 @@
                 terminating_node=None,
                 layer_type=LayerType.UNKNOWN,
             )
-        if match_on_shapes:
-            # get embedding size of the opening linear node
-            embd = get_weight_shape(linear_nodes[start_lin_index], dim=-1)
-            # partial init boundary_condition and filter_condition
-            boundary_condition = partial(boundary_condition, embd=embd, dim=0)
-            filter_condition = partial(filter_condition, embd=embd, dim=0)
         forward_subgraph = subgraph(
-            sources=[linear_nodes[start_lin_index]], boundary_condition=boundary_condition
+            sources=[linear_nodes[start_lin_index]],
+            boundary_condition=lambda n: boundary_condition(n, dim=0),
         )
-        lin_nodes_in_subgraph = list(filtered_nodes(forward_subgraph, filter_condition))
+        lin_nodes_in_subgraph = list(
+            filtered_nodes(forward_subgraph, lambda n: filter_condition(n, dim=0))
+        )
+        if len(lin_nodes_in_subgraph) > 1:
+            # More than one terminating candidate likely means we crossed the layer boundary.
+            # This can happen e.g. with MoLE (latent MoE), where the subgraph spanned by the
+            # closing latent fc2 projection spills over into consecutive layers.
+            # In that case, wrap this single linear node in a LayerType.UNKNOWN subgraph and return.
+            terminating_indices.append(start_lin_index)
+            return LayerSubgraph(
+                opening_nodes=[linear_nodes[start_lin_index]],
+                subgraph_nodes=[],
+                terminating_node=linear_nodes[start_lin_index],
+                layer_type=LayerType.UNKNOWN,
+            )
         start_lin_index += 1
     start_lin_index -= 1
     terminating_linear_node = lin_nodes_in_subgraph[0]
 
-    # for backward pass, match embedding on the dim=0
-    if match_on_shapes:
-        boundary_condition = partial(boundary_condition, embd=embd, dim=-1)
-        filter_condition = partial(filter_condition, embd=embd, dim=-1)
+    # For backward pass, match embedding on dim=-1
     backward_subgraph = subgraph(
-        sinks=[terminating_linear_node], boundary_condition=boundary_condition
+        sinks=[terminating_linear_node], boundary_condition=lambda n: boundary_condition(n, dim=-1)
+    )
+
+    # Get all opening linear nodes
+    opening_linear_nodes = list(
+        filtered_nodes(backward_subgraph, lambda n: filter_condition(n, dim=-1))
     )
-    # get all opening linear nodes
-    opening_linear_nodes = list(filtered_nodes(backward_subgraph, filter_condition))
 
     if enforce_strict_linear_history:
         # opening nodes must succeed last terminating node
@@ -939,32 +1159,65 @@
     ssm_nodes = list(filtered_nodes(interior_nodes, is_any_ssm_op))
     attention_nodes = list(filtered_nodes(interior_nodes, is_any_attention_op))
     intermediate_lin_nodes = list(filtered_nodes(interior_nodes, is_any_lin_op))
-
-    layer_type = LayerType.MLP
-    min_local_shape = 1
-    if len(ssm_nodes) > 0:
-        assert len(ssm_nodes) == 1, "SSM layer must have exactly one SSM node"
-        layer_type = LayerType.SSM
-        # determine head size
-        min_local_shape = shape(ssm_nodes[0])[-1]
-    if len(attention_nodes) > 0:
-        assert len(attention_nodes) == 1, "Attention layer must have exactly one attention node"
-        layer_type = LayerType.ATTENTION
-        # determine head size
-        min_local_shape = shape(attention_nodes[0])[-1]
-    if len(intermediate_lin_nodes) > 0:
-        assert len(intermediate_lin_nodes) == 2, (
-            "MLA layer must have exactly two intermediate linear nodes"
+    intermediate_weight_nodes = list(
+        filtered_nodes(
+            interior_nodes, lambda n: is_weight_node(n) and not is_any_lin_op(list(n.users)[0])
         )
-        )
-        assert len(attention_nodes) == 1, "MLA layer must have exactly one attention node"
-        layer_type = LayerType.MLA
+    )
+
+    ####################################################
+    ########## LAYER TYPE CLASSIFICATION ###############
+    ####################################################
+
+    def classify_layer_type() -> Tuple[LayerType, int]:
+        if len(ssm_nodes) + len(attention_nodes) > 1:
+            return LayerType.UNKNOWN, 1
+
+        if len(attention_nodes) == 1:
+            head_size = shape(attention_nodes[0])[-1]
+            # Check whether this is MLA: the two intermediate linear nodes are the
+            # latent q and kv projections.
+            if len(intermediate_lin_nodes) == 2:
+                # MLA has an RMS norm inside, so it should have one (or two, counting the bias)
+                # intermediate weight nodes.
+                if len(intermediate_weight_nodes) not in [1, 2]:
+                    return LayerType.UNKNOWN, 1
+                return LayerType.MLA, head_size
+            else:
+                if len(intermediate_lin_nodes) != 0:
+                    return LayerType.UNKNOWN, 1
+                return LayerType.ATTENTION, head_size
+
+        if len(ssm_nodes) == 1:
+            head_size = shape(ssm_nodes[0])[-1]
+            # Mamba layers should not have any intermediate linear nodes.
+            if len(intermediate_lin_nodes) > 0:
+                return LayerType.UNKNOWN, 1
+            # A Mamba layer should have 3 to 6 intermediate weight nodes:
+            # - conv1d weight
+            # - A (A_log)
+            # - D
+            # - conv1d bias [optional]
+            # - dt_bias [optional]
+            # - RMS norm [optional]
+            if len(intermediate_weight_nodes) not in list(range(3, 7)):
+                return LayerType.UNKNOWN, 1
+            return LayerType.SSM, head_size
+
+        # If we reach here, the layer is an MLP.
+        # An MLP should not have any intermediate linear or weight nodes.
+        if len(intermediate_lin_nodes) > 0 or len(intermediate_weight_nodes) > 0:
+            return LayerType.UNKNOWN, 1
+        return LayerType.MLP, 1
+
+    layer_type, head_size = classify_layer_type()
     layer_subgraph = LayerSubgraph(
         opening_nodes=opening_linear_nodes,
         subgraph_nodes=interior_nodes,
         terminating_node=terminating_linear_node,
         layer_type=layer_type,
-        min_local_shape=min_local_shape,
+        min_local_shape=head_size,
     )
     assert linear_nodes[start_lin_index] in opening_linear_nodes, (
        f"Linear node not found in opening linear nodes - "
@@ -986,7 +1239,7 @@
         "ill-formed layer subgraph"
     )
     terminating_indices.append(terminating_index)
-    # otherwise, we are done. We processed the last linear node.
+
     return layer_subgraph
@@ -1001,9 +1254,13 @@ def shape(node: Node) -> Tuple[int, ...]:
 
 
 def get_weight_tensor(node: Node) -> torch.Tensor:
-    """Extract the weight tensor from a node within a GraphModule."""
-    weight_nodes = extract_weight_nodes(node)
-    return weight_nodes.weights[0].tensor
+    """Extract the weight tensor from a compute node."""
+    weight_node = get_weight_node(node)
+    if weight_node is None:
+        raise ValueError(f"Node {node.name} has no weight")
+
+    gm = node.graph.owning_module
+    return get_param_or_buffer(weight_node.target, gm)
 
 
 def draw_graph(gm: GraphModule, filename: str):
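To summarize the heuristic introduced in `classify_layer_type`, here is a condensed standalone restatement over node counts (head size omitted). The enum mirrors `LayerType`, but the function itself is an illustration of the checks above, not the shipped code.

```python
from enum import Enum, auto


class LayerType(Enum):
    MLP = auto()
    ATTENTION = auto()
    MLA = auto()
    SSM = auto()
    UNKNOWN = auto()


def classify(n_attn: int, n_ssm: int, n_inner_lin: int, n_inner_weights: int) -> LayerType:
    """Mirror of the count-based checks above; UNKNOWN layers end up simple-sharded."""
    if n_attn + n_ssm > 1:
        return LayerType.UNKNOWN
    if n_attn == 1:
        if n_inner_lin == 2:  # latent q/kv projections -> MLA (plus its internal RMS norm weight)
            return LayerType.MLA if n_inner_weights in (1, 2) else LayerType.UNKNOWN
        return LayerType.ATTENTION if n_inner_lin == 0 else LayerType.UNKNOWN
    if n_ssm == 1:  # conv1d weight/bias, A_log, D, dt_bias, optional norm
        return LayerType.SSM if n_inner_lin == 0 and 3 <= n_inner_weights <= 6 else LayerType.UNKNOWN
    return LayerType.MLP if n_inner_lin == 0 and n_inner_weights == 0 else LayerType.UNKNOWN


assert classify(1, 0, 0, 0) is LayerType.ATTENTION
assert classify(1, 0, 2, 1) is LayerType.MLA
assert classify(0, 1, 0, 4) is LayerType.SSM
assert classify(0, 0, 0, 0) is LayerType.MLP
assert classify(0, 0, 1, 0) is LayerType.UNKNOWN
```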