Mirror of https://github.com/NVIDIA/TensorRT-LLM.git (synced 2026-02-16 07:53:55 +08:00)

Signed-off-by: Taylor Yeonbok Lee <249374542+taylor-yb-lee@users.noreply.github.com>

Commit 860054c859 (parent f320bc8a9c)
@@ -17,8 +17,9 @@ from ...custom_ops.quantization.quant import (
 from ...models.factory import ModelFactory
 from ...shim.interface import CachedSequenceInterface
 from ...utils.node_utils import (
+    WeightBiasInfoCache,
+    extract_weight_nodes,
     get_quantization_params_from_linear_node,
-    get_weight_info,
     is_bmm_op,
     is_linear_op,
 )
@@ -109,27 +110,28 @@ class Quantization(BaseTransform):
         factory: ModelFactory,
         shared_config: SharedConfig,
     ) -> Tuple[GraphModule, TransformInfo]:
-        qcfg = factory.get_quant_config()
-        if not qcfg or (
-            qcfg.get("quant_algo", "").upper() != self.algo_name
-            and qcfg.get("quant_method", "").upper() != self.algo_name
-        ):
-            return gm, TransformInfo(
-                skipped=True, num_matches=0, is_clean=True, has_valid_shapes=True
-            )
-        excluded = qcfg.get("exclude_modules", [])
-        cnt = 0
-        for n in gm.graph.nodes:
-            if not is_linear_op(n):
-                continue
-            if should_skip_quantization(n, excluded):
-                continue
-            self._insert_quantized_linear(gm, n, is_quantized_graph=False)
-            cnt += 1
+        with WeightBiasInfoCache():
+            qcfg = factory.get_quant_config()
+            if not qcfg or (
+                qcfg.get("quant_algo", "").upper() != self.algo_name
+                and qcfg.get("quant_method", "").upper() != self.algo_name
+            ):
+                return gm, TransformInfo(
+                    skipped=True, num_matches=0, is_clean=True, has_valid_shapes=True
+                )
+            excluded = qcfg.get("exclude_modules", [])
+            cnt = 0
+            for n in gm.graph.nodes:
+                if not is_linear_op(n):
+                    continue
+                if should_skip_quantization(n, excluded):
+                    continue
+                self._insert_quantized_linear(gm, n, is_quantized_graph=False)
+                cnt += 1

-        return gm, TransformInfo(
-            skipped=False, num_matches=cnt, is_clean=False, has_valid_shapes=(cnt == 0)
-        )
+            return gm, TransformInfo(
+                skipped=False, num_matches=cnt, is_clean=False, has_valid_shapes=(cnt == 0)
+            )

     def _insert_quantized_linear(
         self,
@@ -141,9 +143,10 @@ class Quantization(BaseTransform):

         The state_dict is also updated to contain the sharded weights.
         """
-        lin_weight = get_weight_info(node)
-        if lin_weight is None:
+        weight_nodes = extract_weight_nodes(node)
+        if len(weight_nodes.weights) == 0:
             raise ValueError(f"Linear node {node.name} has no weight")
+        lin_weight = weight_nodes.weights[0]

         new_param = nn.Parameter(self.quantize_weight(lin_weight.tensor), requires_grad=False)
         modname, _, attrname = lin_weight.node_key.rpartition(".")
@@ -564,13 +567,15 @@ class FP8BMMQuantizationFromConfig(Quantization):

         excluded = qcfg.get("exclude_modules", [])
         cnt = 0
-        for n in gm.graph.nodes:
-            if not is_bmm_op(n):
-                continue
-            if should_skip_quantization(n, excluded):
-                continue
-            if self._insert_quantized_bmm(gm, n, is_quantized_graph=False):
-                cnt += 1
+
+        with WeightBiasInfoCache():
+            for n in gm.graph.nodes:
+                if not is_bmm_op(n):
+                    continue
+                if should_skip_quantization(n, excluded):
+                    continue
+                if self._insert_quantized_bmm(gm, n, is_quantized_graph=False):
+                    cnt += 1

         return gm, TransformInfo(
             skipped=False, num_matches=cnt, is_clean=cnt == 0, has_valid_shapes=True
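Note (illustrative aside, not part of the patch): both hunks above apply the same pattern — the whole node-scanning pass is wrapped in the new cache context so repeated weight lookups reuse one snapshot of the module's parameters and buffers, and the caches are dropped when the pass finishes. A minimal sketch of that pattern, assuming WeightBiasInfoCache, is_linear_op, and extract_weight_nodes are imported from the node_utils module touched later in this commit (run_pass itself is a hypothetical name):

    from torch.fx import GraphModule

    def run_pass(gm: GraphModule) -> int:
        # Hypothetical pass: count linear nodes while reusing cached param/buffer dicts.
        cnt = 0
        with WeightBiasInfoCache():  # activates the class-level cache
            for n in gm.graph.nodes:
                if not is_linear_op(n):
                    continue
                extract_weight_nodes(n)  # parameter/buffer names come from the cache
                cnt += 1
        # Leaving the with-block calls close(), which deactivates and clears the caches.
        return cnt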
@@ -38,11 +38,12 @@ from ...utils.logger import ad_logger
 from ...utils.node_utils import (
     LayerSubgraph,
     LayerType,
+    WeightBiasInfoCache,
     bfs,
     extract_weight_name,
+    extract_weight_nodes,
     filtered_nodes,
     get_all_layer_subgraphs,
-    get_all_weight_infos,
     get_all_weights_in_subgraph,
     is_any_attention_op,
     is_any_lin_op,
@@ -821,35 +822,44 @@ class Sharding(BaseTransform):
         )

         info = TransformInfo(skipped=True, num_matches=0, is_clean=True, has_valid_shapes=True)
-        for source in config.sharding_source:
-            if source == ShardingSource.FACTORY:
-                if len(config.factory_config) == 0:
-                    ad_logger.debug(
-                        "No factory config found. Skipping sharding from factory config"
-                    )
-                    continue
-                ad_logger.info("Applying sharding from factory config")
-                info += detect_sharding_from_config(gm, transform_container, ShardingSource.FACTORY)
-            elif source == ShardingSource.MANUAL:
-                if len(config.manual_config) == 0:
-                    ad_logger.debug("No manual config found. Skipping sharding from manual config")
-                    continue
-                ad_logger.info("Applying sharding from manual config")
-                info += detect_sharding_from_config(gm, transform_container, ShardingSource.MANUAL)
+        with WeightBiasInfoCache():
+            for source in config.sharding_source:
+                if source == ShardingSource.FACTORY:
+                    if len(config.factory_config) == 0:
+                        ad_logger.debug(
+                            "No factory config found. Skipping sharding from factory config"
+                        )
+                        continue
+                    ad_logger.info("Applying sharding from factory config")
+                    info += detect_sharding_from_config(
+                        gm, transform_container, ShardingSource.FACTORY
+                    )
+                elif source == ShardingSource.MANUAL:
+                    if len(config.manual_config) == 0:
+                        ad_logger.debug(
+                            "No manual config found. Skipping sharding from manual config"
+                        )
+                        continue
+                    ad_logger.info("Applying sharding from manual config")
+                    info += detect_sharding_from_config(
+                        gm, transform_container, ShardingSource.MANUAL
+                    )

-            elif source == ShardingSource.HEURISTIC:
-                ad_logger.info(f"Running autodeploy sharding heuristics: {config.sharding_dims}")
-                # run TP sharding across ranks
-                if ShardingDim.TP in config.sharding_dims:
-                    info += detect_column_row_shard(gm, transform_container)
+                elif source == ShardingSource.HEURISTIC:
+                    ad_logger.info(
+                        f"Running autodeploy sharding heuristics: {config.sharding_dims}"
+                    )
+                    # run TP sharding across ranks
+                    if ShardingDim.TP in config.sharding_dims:
+                        info += detect_column_row_shard(gm, transform_container)

-                # run EP sharding across ranks
-                if ShardingDim.EP in config.sharding_dims:
-                    info += detect_ep_shard(gm, transform_container)
+                    # run EP sharding across ranks
+                    if ShardingDim.EP in config.sharding_dims:
+                        info += detect_ep_shard(gm, transform_container)

-                # run BMM sharding across ranks
-                if ShardingDim.BMM in config.sharding_dims:
-                    info += detect_dp_bmm_shard(gm, transform_container)
+                    # run BMM sharding across ranks
+                    if ShardingDim.BMM in config.sharding_dims:
+                        info += detect_dp_bmm_shard(gm, transform_container)

         return gm, info

@@ -889,18 +899,19 @@ class ShardingTransformExecutor(BaseTransform):

         num_matches = 0
         transforms = shared_config.sharding_transform_container
-        for tp_transform in transforms.weight_sharding_transforms:
-            if check_and_apply(tp_transform):
-                num_matches += 1
-        for bmm_transform in transforms.bmm_transforms:
-            if check_and_apply(bmm_transform):
-                num_matches += 1
-        for ep_transform in transforms.ep_transforms:
-            if check_and_apply(ep_transform):
-                num_matches += 1
-        for rmsnorm_transform in transforms.rmsnorm_transforms:
-            if check_and_apply(rmsnorm_transform):
-                num_matches += 1
+        with WeightBiasInfoCache():
+            for tp_transform in transforms.weight_sharding_transforms:
+                if check_and_apply(tp_transform):
+                    num_matches += 1
+            for bmm_transform in transforms.bmm_transforms:
+                if check_and_apply(bmm_transform):
+                    num_matches += 1
+            for ep_transform in transforms.ep_transforms:
+                if check_and_apply(ep_transform):
+                    num_matches += 1
+            for rmsnorm_transform in transforms.rmsnorm_transforms:
+                if check_and_apply(rmsnorm_transform):
+                    num_matches += 1

         # post-sharding cleanup transformations
         for update_transform in transforms.parameter_update_transforms:
@@ -1308,17 +1319,17 @@ def _shard_parameter_node(
         return

     # Shard weight using the unified function (also updates the parameter)
-    all_weight_infos = get_all_weight_infos(node)
+    weight_nodes = extract_weight_nodes(node)
     # Parametrized nodes must have at least one weight (for debugging)
-    assert len(all_weight_infos.weights) > 0, (
+    assert len(weight_nodes.weights) > 0, (
         f"Node {node.name} has no weights - weight mapping may be incorrect"
     )

-    for weight_info in all_weight_infos.weights:
+    for weight_node in weight_nodes.weights:
         _, weight_new_shape = shard_weight_tensor(
             gm=gm,
-            weight_tensor=weight_info.tensor,
-            param_key=weight_info.node_key,
+            weight_tensor=weight_node.tensor,
+            param_key=weight_node.node_key,
             dim=dim,
             rank=rank,
             world_size=world_size,
@@ -1328,22 +1339,22 @@ def _shard_parameter_node(
         if quantization_cb is not None:
             quantization_cb(
                 gm=gm,
-                submod=weight_info.submod,
+                submod=weight_node.submod,
                 node=node,
-                weight_key=weight_info.node_key,
+                weight_key=weight_node.node_key,
                 weight_new_shape=weight_new_shape,
                 dim=dim,
                 rank=rank,
                 world_size=world_size,
             )

-    for bias_info in all_weight_infos.biases:
+    for bias_node in weight_nodes.biases:
         if dim == 0:
             # update bias for dim 0 --> we can handle it like the weight
             shard_weight_tensor(
                 gm=gm,
-                weight_tensor=bias_info.tensor,
-                param_key=bias_info.node_key,
+                weight_tensor=bias_node.tensor,
+                param_key=bias_node.node_key,
                 dim=dim,
                 rank=rank,
                 world_size=world_size,
@@ -1358,10 +1369,10 @@ def _shard_parameter_node(
                 args[2] = None
                 node.args = tuple(args)
                 gm.graph.erase_node(node_bias)
-            bias_param_name = bias_info.node_key.rpartition(".")[-1]
-            setattr(bias_info.submod, bias_param_name, None)
+            bias_param_name = bias_node.node_key.rpartition(".")[-1]
+            setattr(bias_node.submod, bias_param_name, None)
             gm._register_load_state_dict_pre_hook(
-                partial(_load_hook_remove, param_key=bias_info.node_key)
+                partial(_load_hook_remove, param_key=bias_node.node_key)
             )

     # # # column shard with no gather: the output is sharded
@@ -152,36 +152,133 @@ def get_all_weights_in_subgraph(


 def extract_weight_name(node: Node) -> Union[str, bool]:
     """
     Extract the weight parameter name for a compute node.

     Args:
         node: Compute node (linear, MoE, SSM, etc.)

     Returns:
         Weight parameter name (str), or False if no weight exists.
     """
-    weight_node = get_weight_node(node)
-    if weight_node is None:
+    try:
+        weight_nodes = extract_weight_nodes(node)
+    except Exception:
         return False
-    return weight_node.target
+    if len(weight_nodes.weights) == 0:
+        return False
+    return weight_nodes.weights[0].node_key


 def get_param_or_buffer(tensor_name: str, gm: GraphModule) -> torch.Tensor:
-    if tensor_name in dict(gm.named_parameters()):
-        return gm.get_parameter(tensor_name)
-    elif tensor_name in dict(gm.named_buffers()):
-        return gm.get_buffer(tensor_name)
-    else:
-        raise KeyError(f"Tensor {tensor_name} not found in the graph")
+    param_dict = WeightBiasInfoCache.get_param_dict(gm)
+    if tensor_name in param_dict:
+        return param_dict[tensor_name]
+    buffer_dict = WeightBiasInfoCache.get_buffer_dict(gm)
+    if tensor_name in buffer_dict:
+        return buffer_dict[tensor_name]
+    raise KeyError(f"Tensor {tensor_name} not found in the graph")


+class WeightBiasInfoCache:
+    """Cache for weight and bias information to avoid repeated expensive operations.
+
+    This class manages caches for parameter names and weight shapes that are used
+    during graph transformation operations. Use it as a context manager to scope
+    the cache lifetime.
+
+    Example:
+        with WeightBiasInfoCache() as cache:
+            # All calls to get_weight_shape and extract_weight_nodes
+            # within this block use caching
+            layer_subgraphs, _ = get_all_layer_subgraphs(gm)
+        # Caches are cleared here
+    """
+
+    # Class-level reference to the currently active cache instance
+    _active_instance: "WeightBiasInfoCache" = None
+
+    def __init__(self):
+        # Cache for param/buffer dicts to avoid repeated expensive named_parameters/named_buffers calls
+        self._param_dict_cache = {}
+        self._buffer_dict_cache = {}
+        # Cache for get_weight_shape to avoid repeated expensive extract_weight_nodes calls
+        self._weight_shape_cache = {}
+        # Activate this cache instance
+        WeightBiasInfoCache._active_instance = self
+
+    def __enter__(self):
+        return self
+
+    def __exit__(self, exc_type, exc_val, exc_tb):
+        self.close()
+        return False
+
+    def close(self):
+        """Explicitly deactivate and clear the cache."""
+        if WeightBiasInfoCache._active_instance is self:
+            WeightBiasInfoCache._active_instance = None
+        self._param_dict_cache.clear()
+        self._buffer_dict_cache.clear()
+        self._weight_shape_cache.clear()
+
+    def __del__(self):
+        """Cleanup when the cache is garbage collected."""
+        self.close()
+
+    @classmethod
+    def is_active(cls) -> bool:
+        """Check if caching is currently enabled."""
+        return cls._active_instance is not None
+
+    @classmethod
+    def get_param_dict(cls, gm: GraphModule) -> dict:
+        """Get cached parameters dict for a GraphModule, or compute and cache it."""
+        if cls._active_instance is None:
+            return dict(gm.named_parameters())
+
+        cache = cls._active_instance._param_dict_cache
+        if gm not in cache:
+            cache[gm] = dict(gm.named_parameters())
+        return cache[gm]
+
+    @classmethod
+    def get_buffer_dict(cls, gm: GraphModule) -> dict:
+        """Get cached buffers dict for a GraphModule, or compute and cache it."""
+        if cls._active_instance is None:
+            return dict(gm.named_buffers())
+
+        cache = cls._active_instance._buffer_dict_cache
+        if gm not in cache:
+            cache[gm] = dict(gm.named_buffers())
+        return cache[gm]
+
+    @classmethod
+    def get_param_names(cls, gm: GraphModule) -> set:
+        """Get cached parameter and buffer names for a GraphModule."""
+        param_dict = cls.get_param_dict(gm)
+        buffer_dict = cls.get_buffer_dict(gm)
+        return set(param_dict.keys()).union(buffer_dict.keys())
+
+    @classmethod
+    def get_weight_shape(cls, node: Node) -> Tuple[bool, Optional[List[int]]]:
+        """Get cached weight shape for a node.
+
+        Returns:
+            Tuple of (found, value). If found is False, value should be ignored.
+        """
+        if cls._active_instance is None:
+            return False, None
+
+        cache = cls._active_instance._weight_shape_cache
+        if node in cache:
+            return True, cache[node]
+        return False, None
+
+    @classmethod
+    def set_weight_shape(cls, node: Node, shape: Optional[List[int]]):
+        """Store weight shape in cache."""
+        if cls._active_instance is not None:
+            cls._active_instance._weight_shape_cache[node] = shape
+
+
 def extract_weight_nodes(node: Node) -> WeightNodes:
     """Extracts the list of weight node and optional bias node from the given parametrized node"""
     gm = node.graph.owning_module
-    param_names = {name for name, _ in gm.named_parameters()}.union(
-        {name for name, _ in gm.named_buffers()}
-    )
+
+    # Use cached param_names to avoid repeated expensive named_parameters/named_buffers calls
+    param_names = WeightBiasInfoCache.get_param_names(gm)

     def find_get_attr_node(weight_node: Node) -> Node:
         """Recursively traverse inputs of allowed nodes to find a node with 'get_attr' op."""
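Illustrative usage of the class added above (aside, not part of the patch): when no instance is active, the classmethods fall back to direct recomputation; an active instance memoizes per GraphModule and per node, and exiting the context deactivates and clears everything. A small sketch, assuming WeightBiasInfoCache is importable from this node_utils module:

    import torch.nn as nn
    from torch.fx import symbolic_trace

    gm = symbolic_trace(nn.Linear(4, 4))

    # Inactive: each call simply recomputes dict(gm.named_parameters()).
    assert not WeightBiasInfoCache.is_active()
    WeightBiasInfoCache.get_param_dict(gm)

    with WeightBiasInfoCache():
        assert WeightBiasInfoCache.is_active()
        d1 = WeightBiasInfoCache.get_param_dict(gm)  # populates the per-gm cache
        d2 = WeightBiasInfoCache.get_param_dict(gm)  # reuses the same dict object
        assert d1 is d2

    # __exit__ -> close(): the cache is deactivated and cleared again.
    assert not WeightBiasInfoCache.is_active()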
@@ -255,18 +352,21 @@ def extract_weight_nodes(node: Node) -> WeightNodes:
     return WeightNodes(weights=weight_nodes, biases=bias_nodes)


+def get_weight_node(node: Node) -> Node:
+    """Get the primary weight node for a compute node."""
+    weight_nodes = extract_weight_nodes(node)
+    if len(weight_nodes.weights) == 0:
+        raise ValueError(f"Node {node.name} has no weight")
+    return weight_nodes.weights[0].node
+
+
 def num_users_of_weight_node(node: Node) -> int:
-    """
-    Get the number of users of the weight node.
-
-    Args:
-        node: Compute node (linear, MoE, SSM, etc.)
-
-    Returns:
-        Number of users of the primary weight node, or 0 if no weight exists.
-    """
-    weight_node = get_weight_node(node)
-    return len(weight_node.users) if weight_node is not None else 0
+    """Returns the number of users of the weight node of the given parametrized node."""
+    try:
+        weight_node = get_weight_node(node)
+    except ValueError:
+        return 0
+    return len(weight_node.users)


 def get_op_overload_packet(node: Union[OpOverloadPacket, OpOverload]) -> OpOverloadPacket:
@@ -469,167 +569,6 @@ def is_weight_node(node: Node) -> bool:
     return node.op == "get_attr" and node.target and has_shape(node) and len(shape(node)) > 0


-# Auxiliary ops that may appear between a weight node and its consumer compute node
-_WEIGHT_AUX_OPS = frozenset(
-    {
-        torch.ops.aten.to.dtype,
-        torch.ops.aten.view.default,
-    }
-)
-
-
-def precompute_weight_node_mapping(gm: GraphModule) -> None:
-    """
-    Pre-compute weight-to-consumer mapping for all weight nodes in the graph.
-
-    For each weight node (get_attr), finds the consumer compute node by traversing
-    through auxiliary ops (to.dtype, view.default). Stores the mapping in consumer
-    node's metadata:
-    - node.meta["weight_nodes"]: list of weight nodes (non-bias)
-    - node.meta["bias_nodes"]: list of bias nodes
-
-    This enables O(1) weight node lookup instead of O(depth) backward traversal.
-    Called automatically on first weight lookup via lazy initialization.
-
-    GUARANTEES (verified by assertions for debugging):
-    - Called exactly once per GraphModule
-    - No duplicate weight/bias nodes in any consumer's lists
-    - Each weight node mapped to exactly one consumer
-    """
-    # Early return if already computed
-    if "_weight_mapping_computed" in gm.meta:
-        return
-    gm.meta["_weight_mapping_computed"] = True
-
-    for node in gm.graph.nodes:
-        if not is_weight_node(node):
-            continue
-
-        is_bias = node.target.endswith("bias")
-
-        # Find the consumer compute node by traversing through auxiliary ops
-        current = node
-        visited = {current}
-
-        while True:
-            # Get users of current node
-            users = list(current.users.keys())
-            if not users:
-                break
-
-            # Check if any user is a compute node (not an auxiliary op)
-            consumer_found = None
-            aux_node = None
-
-            for user in users:
-                if is_bias:
-                    if "bias_nodes" not in user.meta:
-                        user.meta["bias_nodes"] = []
-                    # ASSERTION: Each weight node should be mapped exactly once
-                    assert node not in user.meta["bias_nodes"], (
-                        f"Duplicate bias node {node.name} found for consumer {user.name}"
-                    )
-                    user.meta["bias_nodes"].append(node)
-                else:
-                    if "weight_nodes" not in user.meta:
-                        user.meta["weight_nodes"] = []
-                    # ASSERTION: Each weight node should be mapped exactly once
-                    assert node not in user.meta["weight_nodes"], (
-                        f"Duplicate weight node {node.name} found for consumer {user.name}"
-                    )
-                    user.meta["weight_nodes"].append(node)
-                if user.target in _WEIGHT_AUX_OPS:
-                    # This is an auxiliary op, continue traversing
-                    aux_node = user
-                else:
-                    # This is a potential consumer compute node
-                    consumer_found = user
-                    break
-
-            if consumer_found is not None:
-                # Found the consumer, return
-                break
-            elif aux_node is not None and aux_node not in visited:
-                # Continue through auxiliary op
-                current = aux_node
-                visited.add(current)
-            else:
-                # No more nodes to traverse
-                break
-
-
-def _ensure_weight_mapping(node: Node) -> None:
-    """Ensure weight node mapping is computed. Lazily calls precompute if needed."""
-    gm = node.graph.owning_module
-    if "_weight_mapping_computed" not in gm.meta or not gm.meta["_weight_mapping_computed"]:
-        precompute_weight_node_mapping(gm)
-
-
-def get_weight_node(node: Node) -> Optional[Node]:
-    """Get the primary weight node for a compute node"""
-    _ensure_weight_mapping(node)
-    weight_nodes = node.meta.get("weight_nodes", [])
-    return weight_nodes[0] if weight_nodes else None
-
-
-def get_weight_nodes(node: Node) -> List[Node]:
-    """Get all weight nodes for a compute node"""
-    _ensure_weight_mapping(node)
-    return node.meta.get("weight_nodes", [])
-
-
-def get_bias_nodes(node: Node) -> List[Node]:
-    """Get all bias nodes for a compute node"""
-    _ensure_weight_mapping(node)
-    return node.meta.get("bias_nodes", [])
-
-
-@dataclass
-class WeightInfo:
-    """Lightweight weight info extracted from a weight node."""
-
-    node: Node
-    node_key: str
-    tensor: torch.Tensor
-    submod: nn.Module
-
-
-def _weight_node_to_info(weight_node: Node, gm: GraphModule) -> WeightInfo:
-    """Convert a weight node to WeightInfo."""
-    node_key = weight_node.target
-    tensor = get_param_or_buffer(node_key, gm)
-    submod = gm.get_submodule(node_key.rpartition(".")[0])
-    return WeightInfo(node=weight_node, node_key=node_key, tensor=tensor, submod=submod)
-
-
-def get_weight_info(node: Node) -> Optional[WeightInfo]:
-    """Extract weight info for the primary weight of a compute node."""
-    weight_node = get_weight_node(node)
-    if weight_node is None:
-        return None
-    return _weight_node_to_info(weight_node, node.graph.owning_module)
-
-
-@dataclass
-class AllWeightInfos:
-    """Container for all weight and bias infos of a compute node."""
-
-    weights: List[WeightInfo]
-    biases: List[WeightInfo]
-
-
-def get_all_weight_infos(node: Node) -> AllWeightInfos:
-    """Extract all weight and bias infos for a compute node."""
-    gm = node.graph.owning_module
-    weight_nodes = get_weight_nodes(node)
-    bias_nodes = get_bias_nodes(node)
-
-    return AllWeightInfos(
-        weights=[_weight_node_to_info(wn, gm) for wn in weight_nodes],
-        biases=[_weight_node_to_info(bn, gm) for bn in bias_nodes],
-    )
-
-
 def get_user_if_pattern_match(node, ops, numusers, user_idx: int = 0):
     """Get a user from a node if the node matches a given op set and num of users."""
     if node is None:
@@ -703,7 +642,7 @@ def get_all_layer_subgraphs(gm: GraphModule) -> tuple[List[LayerSubgraph], set[N
     """
    Get subgraphs for all consecutive layers (attention, MLP, SSM, MoE) in the graph.

-    Pre-computes weight mappings and caches weight shapes for all linear nodes.
+    Caches weight shapes for all linear nodes using WeightBiasInfoCache.
     Each layer is contained between opening linear layers and a single closing linear layer.

     Assumptions:
@@ -734,9 +673,6 @@ def get_all_layer_subgraphs(gm: GraphModule) -> tuple[List[LayerSubgraph], set[N
     layer_subgraphs = []
     linear_nodes = list(filtered_nodes(gm.graph.nodes, is_any_lin_op))

-    # Pre-compute weight-to-consumer mapping for O(1) weight node lookup
-    precompute_weight_node_mapping(gm)
-
     # Cache weight shapes for all linear nodes
     for lin_node in linear_nodes:
         if "lin_node_shape" not in lin_node.meta:
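With the eager precompute call gone, memoization becomes opt-in: a caller that wants the cached behavior wraps the scan in the context manager, as the class docstring's example above suggests. An illustrative sketch (not part of the patch; gm is any traced GraphModule assumed in scope):

    with WeightBiasInfoCache():
        layer_subgraphs, _ = get_all_layer_subgraphs(gm)  # weight shapes cached per node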
@@ -1006,19 +942,25 @@ def subgraph(


 def get_weight_shape(node: Node, dim: Optional[int] = None) -> Optional[Union[int, List[int]]]:
-    """Get weight shape for a linear operation node. Returns None if no weight."""
+    """Get the shape of the weight node."""
     if not is_any_lin_op(node):
         return None

-    weight_node = get_weight_node(node)
-    if weight_node is None:
+    # Try to get from cache first
+    found, s = WeightBiasInfoCache.get_weight_shape(node)
+    if not found:
+        # Not in cache or caching not enabled - compute the shape
+        s = list(shape(extract_weight_nodes(node).weights[0].node))
+        if len(s) == 0:
+            s = None
+        elif is_fp4_op(node):
+            # FP4 weights are packed as uint8 type with 2 FP4 values per element
+            s[-1] *= 2
+        # Store in cache if caching is enabled
+        WeightBiasInfoCache.set_weight_shape(node, s)
+
+    if s is None:
         return None
-
-    s = list(shape(weight_node))
-
-    if is_fp4_op(node):
-        # FP4 weights are packed as uint8 type with 2 FP4 values per element
-        s[-1] *= 2
     if dim is None:
         return s
     else:
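One detail worth noting about the rewritten get_weight_shape (aside, not part of the patch): the cache lookup returns a (found, value) pair rather than treating None as a miss, because None is itself a legitimate cached result (a weight with an empty shape). A small illustrative sketch of that protocol, assuming WeightBiasInfoCache is importable from this module:

    import torch.nn as nn
    from torch.fx import symbolic_trace

    gm = symbolic_trace(nn.Linear(4, 4))
    node = next(iter(gm.graph.nodes))  # any fx Node serves as a cache key

    with WeightBiasInfoCache():
        found, s = WeightBiasInfoCache.get_weight_shape(node)  # miss -> (False, None)
        assert not found
        WeightBiasInfoCache.set_weight_shape(node, None)       # None is a valid value to cache
        found, s = WeightBiasInfoCache.get_weight_shape(node)  # hit -> (True, None)
        assert found and s is None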
@@ -1254,13 +1196,11 @@ def shape(node: Node) -> Tuple[int, ...]:


 def get_weight_tensor(node: Node) -> torch.Tensor:
-    """Extract the weight tensor from a compute node."""
-    weight_node = get_weight_node(node)
-    if weight_node is None:
+    """Extract the weight tensor from a node within a GraphModule."""
+    weight_nodes = extract_weight_nodes(node)
+    if len(weight_nodes.weights) == 0:
         raise ValueError(f"Node {node.name} has no weight")
-
-    gm = node.graph.owning_module
-    return get_param_or_buffer(weight_node.target, gm)
+    return weight_nodes.weights[0].tensor


 def draw_graph(gm: GraphModule, filename: str):