Mirror of https://github.com/NVIDIA/TensorRT-LLM.git
[TRTLLM-10673][feat] Improved layer classification for sharding (#10718)
Signed-off-by: greg-kwasniewski1 <213329731+greg-kwasniewski1@users.noreply.github.com>
This commit is contained in: parent 925d911fc0, commit d90a8e5700
@@ -17,8 +17,8 @@ from ...custom_ops.quant import (
from ...models.factory import ModelFactory
from ...shim.interface import CachedSequenceInterface
from ...utils.node_utils import (
extract_weight_nodes,
get_quantization_params_from_linear_node,
get_weight_info,
is_bmm_op,
is_linear_op,
)
@@ -141,9 +141,10 @@ class Quantization(BaseTransform):

The state_dict is also updated to contain the sharded weights.
"""
weight_nodes = extract_weight_nodes(node)
assert len(weight_nodes.weights) == 1, "Expected exactly one weight node"
lin_weight = weight_nodes.weights[0]
lin_weight = get_weight_info(node)
if lin_weight is None:
raise ValueError(f"Linear node {node.name} has no weight")

new_param = nn.Parameter(self.quantize_weight(lin_weight.tensor), requires_grad=False)
modname, _, attrname = lin_weight.node_key.rpartition(".")

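A minimal sketch of the node_key split used above, with a hypothetical parameter key (not part of the diff):

# Hypothetical example: WeightInfo.node_key is a dotted parameter path; rpartition
# separates the owning submodule name from the attribute to overwrite.
node_key = "model.layers.0.mlp.up_proj.weight"  # illustrative key, not from the diff
modname, _, attrname = node_key.rpartition(".")
assert (modname, attrname) == ("model.layers.0.mlp.up_proj", "weight")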
@@ -40,11 +40,10 @@ from ...utils.node_utils import (
LayerType,
bfs,
extract_weight_name,
extract_weight_nodes,
filtered_nodes,
get_all_layer_subgraphs,
get_all_weight_infos,
get_all_weights_in_subgraph,
get_layer_after_linear_node,
is_any_attention_op,
is_any_lin_op,
is_any_moe_op,
@@ -1296,6 +1295,11 @@ def _shard_parameter_node(
rank, world_size = config.rank, config.world_size
allreduce_strategy = config.allreduce_strategy.name

if "sharded" in node.meta and node.meta["sharded"]:
# Node was already sharded, skip
return
node.meta["sharded"] = True

num_users = num_users_of_weight_node(node)
if num_users > 1 or num_users == 0:
ad_logger.warning(
@@ -1304,12 +1308,17 @@ def _shard_parameter_node(
return

# Shard weight using the unified function (also updates the parameter)
weight_nodes = extract_weight_nodes(node)
for weight_node in weight_nodes.weights:
all_weight_infos = get_all_weight_infos(node)
# Parametrized nodes must have at least one weight (for debugging)
assert len(all_weight_infos.weights) > 0, (
f"Node {node.name} has no weights - weight mapping may be incorrect"
)

for weight_info in all_weight_infos.weights:
_, weight_new_shape = shard_weight_tensor(
gm=gm,
weight_tensor=weight_node.tensor,
param_key=weight_node.node_key,
weight_tensor=weight_info.tensor,
param_key=weight_info.node_key,
dim=dim,
rank=rank,
world_size=world_size,
@@ -1319,29 +1328,29 @@ def _shard_parameter_node(
if quantization_cb is not None:
quantization_cb(
gm=gm,
submod=weight_node.submod,
submod=weight_info.submod,
node=node,
weight_key=weight_node.node_key,
weight_key=weight_info.node_key,
weight_new_shape=weight_new_shape,
dim=dim,
rank=rank,
world_size=world_size,
)

for bias_node in weight_nodes.biases:
for bias_info in all_weight_infos.biases:
if dim == 0:
# update bias for dim 0 --> we can handle it like the weight
shard_weight_tensor(
gm=gm,
weight_tensor=bias_node.tensor,
param_key=bias_node.node_key,
weight_tensor=bias_info.tensor,
param_key=bias_info.node_key,
dim=dim,
rank=rank,
world_size=world_size,
min_local_shape=min_local_shape,
fused_weight_dims=fused_weight_dims,
)
elif bias_node is not None and rank != world_size - 1:
elif rank != world_size - 1:
# update the bias for dim 1 --> in this case only the last rank gets the bias to avoid
# double counting it. For all other we will delete the bias.
args = list(node.args)
@@ -1349,10 +1358,10 @@ def _shard_parameter_node(
args[2] = None
node.args = tuple(args)
gm.graph.erase_node(node_bias)
bias_param_name = bias_node.node_key.rpartition(".")[-1]
setattr(bias_node.submod, bias_param_name, None)
bias_param_name = bias_info.node_key.rpartition(".")[-1]
setattr(bias_info.submod, bias_param_name, None)
gm._register_load_state_dict_pre_hook(
partial(_load_hook_remove, param_key=bias_node.node_key)
partial(_load_hook_remove, param_key=bias_info.node_key)
)

# # # column shard with no gather: the output is sharded
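A minimal sketch of the dim=1 bias rule described above (illustrative shapes and ranks; shard_weight_tensor itself is not reproduced):

# Illustrative sketch of row-parallel (dim=1) sharding: every rank keeps a column
# slice of the weight, but only the last rank keeps the bias so it is added exactly
# once after the all_reduce. Shapes and ranks below are made up.
import torch

world_size, rank = 4, 2
w = torch.randn(16, 32)                            # [out_features, in_features]
b = torch.randn(16)
w_local = torch.chunk(w, world_size, dim=1)[rank]  # local column slice
b_local = b if rank == world_size - 1 else None    # bias deleted on all other ranks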
@@ -2295,47 +2304,37 @@ def detect_sharding_from_config(
raise ValueError(f"Unsupported sharding source: {source}")
tp_plan = config["tp_plan"]

# If the node is inside the attention module, we need to set min_local_shape to the
# head_dim - otherwise, we would risk splitting the heads into smaller shards.
# TODO: is there a better way to check if we are in attention module?
attn_names = [
"attention",
"Attention",
"attn",
"Attn",
"q_proj",
"k_proj",
"v_proj",
"o_proj",
]

num_shards = 0
num_simple_shards = 0
num_row_col_shards = 0
num_attention_shards = 0
num_ssm_shards = 0
head_dim = -1
linear_nodes = list(filtered_nodes(gm.graph.nodes, is_any_lin_op))

# use layer_subgraphs to determine the layer_type
# and check the validity of the sharding transform
layer_subgraphs, unprocessed_linear_nodes = get_all_layer_subgraphs(gm)

for lin_node in linear_nodes:
# use node's weight name to get the module name
weight_name = extract_weight_name(lin_node)

if any(attn_name in weight_name for attn_name in attn_names):
# find the next attention node and infer the head_dim
next_attention_node, _ = bfs(
lin_node, is_any_attention_op, attr_next="users", include_root=False
)
if next_attention_node is None:
# this is the last attention node in the graph. Take the previously found head_dim
assert head_dim != -1, "Head dim not found for the last attention node"
else:
head_dim = shape(next_attention_node)[-1]
min_local_shape = head_dim
layer_type = LayerType.ATTENTION
# get the parent layer_subgraph
layer_subgraph = [
layer
for layer in layer_subgraphs
if lin_node in layer.opening_nodes or lin_node == layer.terminating_node
]
if len(layer_subgraph) == 1:
layer_subgraph = layer_subgraph[0]
layer_type = layer_subgraph.layer_type
else:
min_local_shape = 1
layer_type = LayerType.MLP
if lin_node in unprocessed_linear_nodes:
layer_type = LayerType.UNKNOWN
else:
ad_logger.warning(
f"Failed to find the parent layer_subgraph for linear node {lin_node}. "
f"May result in incorrect sharding."
)

# use regex to find if module_name matches any of the keys in sharding_config
for key in tp_plan.keys():
@@ -2349,11 +2348,6 @@ def detect_sharding_from_config(
# we have a match. Get the config for this layer
config = tp_plan[key]

if config in ["colwise", "mamba"]:
cur_node_index = linear_nodes.index(lin_node)
layer_subgraph = get_layer_after_linear_node(
linear_nodes, [cur_node_index - 1], enforce_strict_linear_history=False
)
if config == "colwise":
_process_column_sharding(
layer_subgraph=layer_subgraph,
@@ -2366,7 +2360,6 @@ def detect_sharding_from_config(
split_dim=SplitDimension.ROW,
config=transform_container.config,
dist_op="all_reduce",
min_local_shape=min_local_shape,
layer_type=layer_type,
)
):
@@ -2393,7 +2386,6 @@ def detect_sharding_from_config(
split_dim=SplitDimension.COLUMN,
config=transform_container.config,
dist_op=None,
min_local_shape=min_local_shape,
layer_type=layer_type,
)
)
@@ -2404,7 +2396,6 @@ def detect_sharding_from_config(
split_dim=SplitDimension.ROW,
config=transform_container.config,
dist_op="all_reduce",
min_local_shape=min_local_shape,
layer_type=layer_type,
)
):
@@ -2423,7 +2414,6 @@ def detect_sharding_from_config(
split_dim=SplitDimension.COLUMN,
config=transform_container.config,
dist_op="all_gather",
min_local_shape=1,
layer_type=layer_type,
)
):
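For context, a hypothetical tp_plan fragment of the kind detect_sharding_from_config matches against (the module-name patterns are illustrative; only the "colwise" and "mamba" values appear in the hunks above):

# Hypothetical sharding config: keys are regex-matched against weight/module names,
# values select one of the sharding branches shown above.
config = {
    "tp_plan": {
        "layers.*.self_attn.q_proj": "colwise",  # column-shard an attention projection
        "layers.*.mixer.in_proj": "mamba",       # SSM/Mamba-specific handling
    }
}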
@@ -2536,7 +2526,7 @@ def detect_column_row_shard(
attention_nodes = list(filtered_nodes(layer_subgraph, is_any_attention_op))
min_local_shape = 1

if config.simple_shard_only:
if config.simple_shard_only or layer.layer_type == LayerType.UNKNOWN:
ad_logger.debug(
f"Forcing Simple Shard on nodes: {nodes_linear} with layer type: {layer.layer_type}"
)

@@ -3,7 +3,6 @@
import operator
from dataclasses import dataclass
from enum import Enum
from functools import partial
from typing import Callable, Iterable, List, Optional, Tuple, Union

import torch
@@ -153,10 +152,19 @@ def get_all_weights_in_subgraph(


def extract_weight_name(node: Node) -> Union[str, bool]:
weight_nodes = extract_weight_nodes(node)
if len(weight_nodes.weights) == 0:
"""
Extract the weight parameter name for a compute node.

Args:
node: Compute node (linear, MoE, SSM, etc.)

Returns:
Weight parameter name (str), or False if no weight exists.
"""
weight_node = get_weight_node(node)
if weight_node is None:
return False
return weight_nodes.weights[0].node_key
return weight_node.target


def get_param_or_buffer(tensor_name: str, gm: GraphModule) -> torch.Tensor:
@@ -248,8 +256,16 @@ def extract_weight_nodes(node: Node) -> WeightNodes:


def num_users_of_weight_node(node: Node) -> int:
"""Returns the number of users of the weight node of the given parametrized node."""
weight_node = extract_weight_nodes(node).weights[0].node
"""
Get the number of users of the weight node.

Args:
node: Compute node (linear, MoE, SSM, etc.)

Returns:
Number of users of the primary weight node, or 0 if no weight exists.
"""
weight_node = get_weight_node(node)
return len(weight_node.users) if weight_node is not None else 0

@@ -373,6 +389,13 @@ def is_any_moe_op(node: Node) -> bool:
)


def is_residual_add(node: Node) -> bool:
if is_op(node, torch.ops.aten.add):
if len(list(filtered_nodes(node.args, is_any_lin_op))) == 1:
return True
return False


def is_any_ssm_op(node: Node) -> bool:
return is_op(
node,
@@ -446,6 +469,167 @@ def is_weight_node(node: Node) -> bool:
return node.op == "get_attr" and node.target and has_shape(node) and len(shape(node)) > 0


# Auxiliary ops that may appear between a weight node and its consumer compute node
_WEIGHT_AUX_OPS = frozenset(
    {
        torch.ops.aten.to.dtype,
        torch.ops.aten.view.default,
    }
)

def precompute_weight_node_mapping(gm: GraphModule) -> None:
"""
Pre-compute weight-to-consumer mapping for all weight nodes in the graph.

For each weight node (get_attr), finds the consumer compute node by traversing
through auxiliary ops (to.dtype, view.default). Stores the mapping in consumer
node's metadata:
- node.meta["weight_nodes"]: list of weight nodes (non-bias)
- node.meta["bias_nodes"]: list of bias nodes

This enables O(1) weight node lookup instead of O(depth) backward traversal.
Called automatically on first weight lookup via lazy initialization.

GUARANTEES (verified by assertions for debugging):
- Called exactly once per GraphModule
- No duplicate weight/bias nodes in any consumer's lists
- Each weight node mapped to exactly one consumer
"""
# Early return if already computed
if "_weight_mapping_computed" in gm.meta:
return
gm.meta["_weight_mapping_computed"] = True

for node in gm.graph.nodes:
if not is_weight_node(node):
continue

is_bias = node.target.endswith("bias")

# Find the consumer compute node by traversing through auxiliary ops
current = node
visited = {current}

while True:
# Get users of current node
users = list(current.users.keys())
if not users:
break

# Check if any user is a compute node (not an auxiliary op)
consumer_found = None
aux_node = None

for user in users:
if is_bias:
if "bias_nodes" not in user.meta:
user.meta["bias_nodes"] = []
# ASSERTION: Each weight node should be mapped exactly once
assert node not in user.meta["bias_nodes"], (
f"Duplicate bias node {node.name} found for consumer {user.name}"
)
user.meta["bias_nodes"].append(node)
else:
if "weight_nodes" not in user.meta:
user.meta["weight_nodes"] = []
# ASSERTION: Each weight node should be mapped exactly once
assert node not in user.meta["weight_nodes"], (
f"Duplicate weight node {node.name} found for consumer {user.name}"
)
user.meta["weight_nodes"].append(node)
if user.target in _WEIGHT_AUX_OPS:
# This is an auxiliary op, continue traversing
aux_node = user
else:
# This is a potential consumer compute node
consumer_found = user
break

if consumer_found is not None:
# Found the consumer, return
break
elif aux_node is not None and aux_node not in visited:
# Continue through auxiliary op
current = aux_node
visited.add(current)
else:
# No more nodes to traverse
break

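A short sketch of the metadata layout this pass produces (assuming gm is a shape-annotated torch.fx GraphModule from this pipeline):

# After the precompute pass, each consumer compute node carries its parameters in
# node.meta; the gm.meta flag makes repeated calls no-ops.
precompute_weight_node_mapping(gm)
for n in gm.graph.nodes:
    if "weight_nodes" in n.meta:
        print(
            n.name,
            [w.target for w in n.meta["weight_nodes"]],
            [b.target for b in n.meta.get("bias_nodes", [])],
        )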
def _ensure_weight_mapping(node: Node) -> None:
"""Ensure weight node mapping is computed. Lazily calls precompute if needed."""
gm = node.graph.owning_module
if "_weight_mapping_computed" not in gm.meta or not gm.meta["_weight_mapping_computed"]:
precompute_weight_node_mapping(gm)


def get_weight_node(node: Node) -> Optional[Node]:
"""Get the primary weight node for a compute node"""
_ensure_weight_mapping(node)
weight_nodes = node.meta.get("weight_nodes", [])
return weight_nodes[0] if weight_nodes else None


def get_weight_nodes(node: Node) -> List[Node]:
"""Get all weight nodes for a compute node"""
_ensure_weight_mapping(node)
return node.meta.get("weight_nodes", [])


def get_bias_nodes(node: Node) -> List[Node]:
"""Get all bias nodes for a compute node"""
_ensure_weight_mapping(node)
return node.meta.get("bias_nodes", [])

@dataclass
class WeightInfo:
"""Lightweight weight info extracted from a weight node."""

node: Node
node_key: str
tensor: torch.Tensor
submod: nn.Module


def _weight_node_to_info(weight_node: Node, gm: GraphModule) -> WeightInfo:
"""Convert a weight node to WeightInfo."""
node_key = weight_node.target
tensor = get_param_or_buffer(node_key, gm)
submod = gm.get_submodule(node_key.rpartition(".")[0])
return WeightInfo(node=weight_node, node_key=node_key, tensor=tensor, submod=submod)


def get_weight_info(node: Node) -> Optional[WeightInfo]:
"""Extract weight info for the primary weight of a compute node."""
weight_node = get_weight_node(node)
if weight_node is None:
return None
return _weight_node_to_info(weight_node, node.graph.owning_module)


@dataclass
class AllWeightInfos:
"""Container for all weight and bias infos of a compute node."""

weights: List[WeightInfo]
biases: List[WeightInfo]


def get_all_weight_infos(node: Node) -> AllWeightInfos:
"""Extract all weight and bias infos for a compute node."""
gm = node.graph.owning_module
weight_nodes = get_weight_nodes(node)
bias_nodes = get_bias_nodes(node)

return AllWeightInfos(
weights=[_weight_node_to_info(wn, gm) for wn in weight_nodes],
biases=[_weight_node_to_info(bn, gm) for bn in bias_nodes],
)

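A hypothetical usage sketch of the new WeightInfo helpers (the import path and gm are assumptions; it requires a GraphModule whose parameter get_attr nodes carry shape metadata):

# Hypothetical sketch: enumerate the weights and biases of every linear node via
# the new helpers (the lazy precompute is triggered on the first lookup).
from tensorrt_llm._torch.auto_deploy.utils import node_utils  # assumed import path

def dump_linear_params(gm) -> None:
    for n in node_utils.filtered_nodes(gm.graph.nodes, node_utils.is_any_lin_op):
        infos = node_utils.get_all_weight_infos(n)
        for w in infos.weights:
            print(f"{n.name}: weight {w.node_key} {tuple(w.tensor.shape)}")
        for b in infos.biases:
            print(f"{n.name}: bias {b.node_key} {tuple(b.tensor.shape)}")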
def get_user_if_pattern_match(node, ops, numusers, user_idx: int = 0):
"""Get a user from a node if the node matches a given op set and num of users."""
if node is None:
@@ -515,9 +699,12 @@ def identify_regions_between_residuals(gm: GraphModule) -> List[Node]:
return boundary_nodes


def get_all_layer_subgraphs(gm: GraphModule) -> List[List[Node]]:
def get_all_layer_subgraphs(gm: GraphModule) -> tuple[List[LayerSubgraph], set[Node]]:
"""
Get subgraphs corresponding to all consecutive layers (attention, MLP, SSM, MoE) in the graph.
Get subgraphs for all consecutive layers (attention, MLP, SSM, MoE) in the graph.

Pre-computes weight mappings and caches weight shapes for all linear nodes.
Each layer is contained between opening linear layers and a single closing linear layer.

Assumptions:
1. each layer (each subgraph) is contained between a list of opening
@@ -546,18 +733,32 @@ def get_all_layer_subgraphs(gm: GraphModule) -> List[List[Node]]:
assert gm.graph.nodes, "Graph is empty"
layer_subgraphs = []
linear_nodes = list(filtered_nodes(gm.graph.nodes, is_any_lin_op))

# Pre-compute weight-to-consumer mapping for O(1) weight node lookup
precompute_weight_node_mapping(gm)

# Cache weight shapes for all linear nodes
for lin_node in linear_nodes:
if "lin_node_shape" not in lin_node.meta:
shape = get_weight_shape(lin_node)
if shape is not None:
lin_node.meta["lin_node_shape"] = shape

# Find the embedding size from the first linear node
embd = get_weight_shape(linear_nodes[0], dim=-1)
if embd is None:
raise ValueError("Failed to extract embedding size from first linear node")

unprocessed_linear_nodes = set(linear_nodes)
assert len(linear_nodes) > 0, "Could not find any linear nodes in the graph"

terminating_indices = [-1]
last_lin_index = terminating_indices[-1] + 1

# for each linear node, find its layer subgraph defined as regions between consecutive linear nodes
# For each linear node, find its layer subgraph defined as regions between consecutive linear nodes.
while last_lin_index < len(linear_nodes):
# opening is the list of linear nodes
# layer_subgraph is the list of nodes between the opening and closing linear nodes
# closing is the last linear node in the layer
layer_subgraph = get_layer_after_linear_node(linear_nodes, terminating_indices)
layer_subgraph = get_layer_after_linear_node(linear_nodes, terminating_indices, embd=embd)

if layer_subgraph.opening_nodes is not None and len(layer_subgraph.opening_nodes) > 0:
unprocessed_linear_nodes -= (
set(layer_subgraph.opening_nodes)
@@ -567,7 +768,7 @@ def get_all_layer_subgraphs(gm: GraphModule) -> List[List[Node]]:
layer_subgraphs.append(layer_subgraph)
last_lin_index = terminating_indices[-1] + 1

# unprocessed linear nodes can be "simple sharded".
# Unprocessed linear nodes can be "simple sharded".
return layer_subgraphs, unprocessed_linear_nodes

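A hypothetical driver for the new return signature, showing how a caller such as detect_sharding_from_config can consume the classification (gm is assumed to be a prepared GraphModule):

# Hypothetical sketch: inspect the classified layer subgraphs and the leftover
# linear nodes that can only be "simple sharded".
layer_subgraphs, unprocessed = get_all_layer_subgraphs(gm)
for layer in layer_subgraphs:
    print(
        layer.layer_type,
        layer.min_local_shape,
        len(layer.opening_nodes),
        layer.terminating_node.name,
    )
print(f"{len(unprocessed)} linear nodes left for simple sharding")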
@@ -805,12 +1006,16 @@ def subgraph(


def get_weight_shape(node: Node, dim: Optional[int] = None) -> Optional[Union[int, List[int]]]:
"""Get the shape of the weight node."""
"""Get weight shape for a linear operation node. Returns None if no weight."""
if not is_any_lin_op(node):
return None
s = list(shape(extract_weight_nodes(node).weights[0].node))
if len(s) == 0:

weight_node = get_weight_node(node)
if weight_node is None:
return None

s = list(shape(weight_node))

if is_fp4_op(node):
# FP4 weights are packed as uint8 type with 2 FP4 values per element
s[-1] *= 2
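A worked example of the FP4 unpacking above (illustrative sizes):

# Packed FP4 weights are stored as uint8 with two 4-bit values per byte, so the
# logical last dimension is twice the stored one. The sizes below are made up.
stored = [4096, 2048]                     # uint8 storage shape
logical = stored[:-1] + [stored[-1] * 2]  # reported logical shape
assert logical == [4096, 4096]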
@@ -823,6 +1028,7 @@ def get_weight_shape(node: Node, dim: Optional[int] = None) -> Optional[Union[in
def get_layer_after_linear_node(
linear_nodes: List[Node],
terminating_indices: List[int],
embd: int,
match_on_shapes: bool = True,
enforce_strict_linear_history: bool = True,
) -> LayerSubgraph:
@@ -856,37 +1062,42 @@ def get_layer_after_linear_node(
Args:
linear_nodes: List of linear nodes in the graph.
terminating_indices: List of indices of terminating linear nodes.
match_on_shapes: If True, the layer is matched on shapes of the nodes.
If False, the layer is matched on the nodes themselves.
embd: Embedding size for shape matching.
match_on_shapes: If True, match layers on embedding shapes.
enforce_strict_linear_history: If True, enforce strict ordering constraints.

Returns:
LayerSubgraph: The layer subgraph.
LayerSubgraph containing opening nodes, subgraph nodes, and terminating node.
"""

def boundary_condition(
node: Node, embd: Optional[int] = None, dim: Optional[int] = None
) -> bool:
if embd is not None and dim is not None:
def boundary_condition(node: Node, dim: int) -> bool:
if match_on_shapes:
if is_any_lin_op(node):
return node.meta["lin_node_shape"][dim] == embd
return (
# match on embedding size
(is_any_lin_op(node) and get_weight_shape(node, dim=dim) == embd)
or is_any_moe_op(node)
is_any_moe_op(node)
or is_op(node, ops=[torch.ops.aten.sym_size, torch.ops.aten.bmm])
or is_residual_add(node)
)
else:
return (
is_any_lin_op(node)
or is_any_moe_op(node)
or is_op(node, ops=[torch.ops.aten.sym_size, torch.ops.aten.bmm])
or is_residual_add(node)
)

def filter_condition(node: Node, embd: Optional[int] = None, dim: Optional[int] = None) -> bool:
if embd is not None and dim is not None:
return is_any_lin_op(node) and get_weight_shape(node, dim=dim) == embd
def filter_condition(node: Node, dim: int) -> bool:
if match_on_shapes:
if is_any_lin_op(node):
return node.meta["lin_node_shape"][dim] == embd
return False
else:
return is_any_lin_op(node)

lin_nodes_in_subgraph = []
start_lin_index = terminating_indices[-1] + 1

while len(lin_nodes_in_subgraph) != 1:
if start_lin_index >= len(linear_nodes):
terminating_indices.append(len(linear_nodes))
@@ -896,30 +1107,39 @@ def get_layer_after_linear_node(
terminating_node=None,
layer_type=LayerType.UNKNOWN,
)
if match_on_shapes:
# get embedding size of the opening linear node
embd = get_weight_shape(linear_nodes[start_lin_index], dim=-1)
# partial init boundary_condition and filter_condition
boundary_condition = partial(boundary_condition, embd=embd, dim=0)
filter_condition = partial(filter_condition, embd=embd, dim=0)

forward_subgraph = subgraph(
sources=[linear_nodes[start_lin_index]], boundary_condition=boundary_condition
sources=[linear_nodes[start_lin_index]],
boundary_condition=lambda n: boundary_condition(n, dim=0),
)
lin_nodes_in_subgraph = list(filtered_nodes(forward_subgraph, filter_condition))
lin_nodes_in_subgraph = list(
filtered_nodes(forward_subgraph, lambda n: filter_condition(n, dim=0))
)
if len(lin_nodes_in_subgraph) > 1:
# it means that probably we went over the boundary of the layer.
# It may happen e.g., with MoLE (latent MoE), with the closing latent fc2 projection,
# when the subgraph spanned over fc2 "spills" over consecutive layers.
# Then, wrap this single linear node in LayerType.UNKNOWN and return.
terminating_indices.append(start_lin_index)
return LayerSubgraph(
opening_nodes=[linear_nodes[start_lin_index]],
subgraph_nodes=[],
terminating_node=linear_nodes[start_lin_index],
layer_type=LayerType.UNKNOWN,
)
start_lin_index += 1
start_lin_index -= 1
terminating_linear_node = lin_nodes_in_subgraph[0]

# for backward pass, match embedding on the dim=0
if match_on_shapes:
boundary_condition = partial(boundary_condition, embd=embd, dim=-1)
filter_condition = partial(filter_condition, embd=embd, dim=-1)
# For backward pass, match embedding on dim=-1
backward_subgraph = subgraph(
sinks=[terminating_linear_node], boundary_condition=lambda n: boundary_condition(n, dim=-1)
)

# Get all opening linear nodes
opening_linear_nodes = list(
filtered_nodes(backward_subgraph, lambda n: filter_condition(n, dim=-1))
)
# get all opening linear nodes
opening_linear_nodes = list(filtered_nodes(backward_subgraph, filter_condition))

if enforce_strict_linear_history:
# opening nodes must succeed last terminating node
@@ -939,32 +1159,65 @@ def get_layer_after_linear_node(
ssm_nodes = list(filtered_nodes(interior_nodes, is_any_ssm_op))
attention_nodes = list(filtered_nodes(interior_nodes, is_any_attention_op))
intermediate_lin_nodes = list(filtered_nodes(interior_nodes, is_any_lin_op))

layer_type = LayerType.MLP
min_local_shape = 1
if len(ssm_nodes) > 0:
assert len(ssm_nodes) == 1, "SSM layer must have exactly one SSM node"
layer_type = LayerType.SSM
# determine head size
min_local_shape = shape(ssm_nodes[0])[-1]
if len(attention_nodes) > 0:
assert len(attention_nodes) == 1, "Attention layer must have exactly one attention node"
layer_type = LayerType.ATTENTION
# determine head size
min_local_shape = shape(attention_nodes[0])[-1]
if len(intermediate_lin_nodes) > 0:
assert len(intermediate_lin_nodes) == 2, (
"MLA layer must have exactly two intermediate linear nodes"
intermediate_weight_nodes = list(
filtered_nodes(
interior_nodes, lambda n: is_weight_node(n) and not is_any_lin_op(list(n.users)[0])
)
assert len(attention_nodes) == 1, "MLA layer must have exactly one attention node"
layer_type = LayerType.MLA
)

####################################################
########## LAYER TYPE CLASSIFICATION ###############
####################################################

def classify_layer_type() -> Tuple[LayerType, int]:
if len(ssm_nodes) + len(attention_nodes) > 1:
return LayerType.UNKNOWN, 1

if len(attention_nodes) == 1:
head_size = shape(attention_nodes[0])[-1]
# check if this is MLA:
# these two intermediate linear nodes are the latent q and kv projections.
if len(intermediate_lin_nodes) == 2:
# MLA has an RMS norm inside, so it should have one (or two, counting bias)
# intermediate weight nodes
if len(intermediate_weight_nodes) not in [1, 2]:
return LayerType.UNKNOWN, 1
return LayerType.MLA, head_size
else:
if len(intermediate_lin_nodes) != 0:
return LayerType.UNKNOWN, 1
return LayerType.ATTENTION, head_size

if len(ssm_nodes) == 1:
head_size = shape(ssm_nodes[0])[-1]
# Mamba layers should not have any intermediate linear nodes.
if len(intermediate_lin_nodes) > 0:
return LayerType.UNKNOWN, 1
# Mamba layer should have 3 to 6 intermediate weight nodes:
# - conv1d weight
# - A (A_log)
# - D
# - conv1d bias [optional]
# - dt_bias [optional]
# - RMS norm [optional]
if len(intermediate_weight_nodes) not in list(range(3, 7)):
return LayerType.UNKNOWN, 1
return LayerType.SSM, head_size

# if we reach here, it means the layer is an MLP.
# MLP should not have any intermediate linear or weight nodes.
if len(intermediate_lin_nodes) > 0 or len(intermediate_weight_nodes) > 0:
return LayerType.UNKNOWN, 1
return LayerType.MLP, 1

layer_type, head_size = classify_layer_type()

layer_subgraph = LayerSubgraph(
opening_nodes=opening_linear_nodes,
subgraph_nodes=interior_nodes,
terminating_node=terminating_linear_node,
layer_type=layer_type,
min_local_shape=min_local_shape,
min_local_shape=head_size,
)
assert linear_nodes[start_lin_index] in opening_linear_nodes, (
f"Linear node not found in opening linear nodes - "
@@ -986,7 +1239,7 @@ def get_layer_after_linear_node(
"ill-formed layer subgraph"
)
terminating_indices.append(terminating_index)
# otherwise, we are done. We processed the last linear node.

return layer_subgraph

@@ -1001,9 +1254,13 @@ def shape(node: Node) -> Tuple[int, ...]:


def get_weight_tensor(node: Node) -> torch.Tensor:
"""Extract the weight tensor from a node within a GraphModule."""
weight_nodes = extract_weight_nodes(node)
return weight_nodes.weights[0].tensor
"""Extract the weight tensor from a compute node."""
weight_node = get_weight_node(node)
if weight_node is None:
raise ValueError(f"Node {node.name} has no weight")

gm = node.graph.owning_module
return get_param_or_buffer(weight_node.target, gm)


def draw_graph(gm: GraphModule, filename: str):