[#11203][feat] AutoDeploy: Refactor node caching and improve engine build time (#11250)

Signed-off-by: Taylor Yeonbok Lee <249374542+taylor-yb-lee@users.noreply.github.com>
Taylor Yeonbok Lee 2026-02-10 13:35:44 -08:00 committed by GitHub
parent f320bc8a9c
commit 860054c859
3 changed files with 249 additions and 293 deletions

View File

@@ -17,8 +17,9 @@ from ...custom_ops.quantization.quant import (
from ...models.factory import ModelFactory
from ...shim.interface import CachedSequenceInterface
from ...utils.node_utils import (
WeightBiasInfoCache,
extract_weight_nodes,
get_quantization_params_from_linear_node,
get_weight_info,
is_bmm_op,
is_linear_op,
)
@@ -109,27 +110,28 @@ class Quantization(BaseTransform):
factory: ModelFactory,
shared_config: SharedConfig,
) -> Tuple[GraphModule, TransformInfo]:
qcfg = factory.get_quant_config()
if not qcfg or (
qcfg.get("quant_algo", "").upper() != self.algo_name
and qcfg.get("quant_method", "").upper() != self.algo_name
):
return gm, TransformInfo(
skipped=True, num_matches=0, is_clean=True, has_valid_shapes=True
)
excluded = qcfg.get("exclude_modules", [])
cnt = 0
for n in gm.graph.nodes:
if not is_linear_op(n):
continue
if should_skip_quantization(n, excluded):
continue
self._insert_quantized_linear(gm, n, is_quantized_graph=False)
cnt += 1
with WeightBiasInfoCache():
qcfg = factory.get_quant_config()
if not qcfg or (
qcfg.get("quant_algo", "").upper() != self.algo_name
and qcfg.get("quant_method", "").upper() != self.algo_name
):
return gm, TransformInfo(
skipped=True, num_matches=0, is_clean=True, has_valid_shapes=True
)
excluded = qcfg.get("exclude_modules", [])
cnt = 0
for n in gm.graph.nodes:
if not is_linear_op(n):
continue
if should_skip_quantization(n, excluded):
continue
self._insert_quantized_linear(gm, n, is_quantized_graph=False)
cnt += 1
return gm, TransformInfo(
skipped=False, num_matches=cnt, is_clean=False, has_valid_shapes=(cnt == 0)
)
return gm, TransformInfo(
skipped=False, num_matches=cnt, is_clean=False, has_valid_shapes=(cnt == 0)
)
def _insert_quantized_linear(
self,
@@ -141,9 +143,10 @@ class Quantization(BaseTransform):
The state_dict is also updated to contain the sharded weights.
"""
lin_weight = get_weight_info(node)
if lin_weight is None:
weight_nodes = extract_weight_nodes(node)
if len(weight_nodes.weights) == 0:
raise ValueError(f"Linear node {node.name} has no weight")
lin_weight = weight_nodes.weights[0]
new_param = nn.Parameter(self.quantize_weight(lin_weight.tensor), requires_grad=False)
modname, _, attrname = lin_weight.node_key.rpartition(".")
@@ -564,13 +567,15 @@ class FP8BMMQuantizationFromConfig(Quantization):
excluded = qcfg.get("exclude_modules", [])
cnt = 0
for n in gm.graph.nodes:
if not is_bmm_op(n):
continue
if should_skip_quantization(n, excluded):
continue
if self._insert_quantized_bmm(gm, n, is_quantized_graph=False):
cnt += 1
with WeightBiasInfoCache():
for n in gm.graph.nodes:
if not is_bmm_op(n):
continue
if should_skip_quantization(n, excluded):
continue
if self._insert_quantized_bmm(gm, n, is_quantized_graph=False):
cnt += 1
return gm, TransformInfo(
skipped=False, num_matches=cnt, is_clean=cnt == 0, has_valid_shapes=True
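
The pattern in the two quantization hunks above is to open one WeightBiasInfoCache scope around the whole node loop, so every weight lookup inside the loop reuses a single snapshot of the module's parameters and buffers instead of re-walking named_parameters() per node. A minimal sketch of that wrapping pattern, with illustrative helper names and contextlib.nullcontext standing in for the real cache class added in node_utils later in this diff:

# Illustrative sketch only; the actual transforms use WeightBiasInfoCache
# imported from ...utils.node_utils (third file in this diff).
from contextlib import nullcontext
from typing import Callable
from torch.fx import GraphModule, Node

WeightBiasInfoCache = nullcontext  # stand-in for the caching context manager

def apply_node_transform(
    gm: GraphModule,
    is_match: Callable[[Node], bool],              # e.g. is_linear_op / is_bmm_op
    rewrite: Callable[[GraphModule, Node], None],  # e.g. _insert_quantized_linear
) -> int:
    cnt = 0
    with WeightBiasInfoCache():  # one cache scope for the whole pass
        for n in gm.graph.nodes:
            if not is_match(n):
                continue
            rewrite(gm, n)
            cnt += 1
    return cnt  # cache contents are cleared when the scope exits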

View File

@@ -38,11 +38,12 @@ from ...utils.logger import ad_logger
from ...utils.node_utils import (
LayerSubgraph,
LayerType,
WeightBiasInfoCache,
bfs,
extract_weight_name,
extract_weight_nodes,
filtered_nodes,
get_all_layer_subgraphs,
get_all_weight_infos,
get_all_weights_in_subgraph,
is_any_attention_op,
is_any_lin_op,
@@ -821,35 +822,44 @@ class Sharding(BaseTransform):
)
info = TransformInfo(skipped=True, num_matches=0, is_clean=True, has_valid_shapes=True)
for source in config.sharding_source:
if source == ShardingSource.FACTORY:
if len(config.factory_config) == 0:
ad_logger.debug(
"No factory config found. Skipping sharding from factory config"
with WeightBiasInfoCache():
for source in config.sharding_source:
if source == ShardingSource.FACTORY:
if len(config.factory_config) == 0:
ad_logger.debug(
"No factory config found. Skipping sharding from factory config"
)
continue
ad_logger.info("Applying sharding from factory config")
info += detect_sharding_from_config(
gm, transform_container, ShardingSource.FACTORY
)
elif source == ShardingSource.MANUAL:
if len(config.manual_config) == 0:
ad_logger.debug(
"No manual config found. Skipping sharding from manual config"
)
continue
ad_logger.info("Applying sharding from manual config")
info += detect_sharding_from_config(
gm, transform_container, ShardingSource.MANUAL
)
continue
ad_logger.info("Applying sharding from factory config")
info += detect_sharding_from_config(gm, transform_container, ShardingSource.FACTORY)
elif source == ShardingSource.MANUAL:
if len(config.manual_config) == 0:
ad_logger.debug("No manual config found. Skipping sharding from manual config")
continue
ad_logger.info("Applying sharding from manual config")
info += detect_sharding_from_config(gm, transform_container, ShardingSource.MANUAL)
elif source == ShardingSource.HEURISTIC:
ad_logger.info(f"Running autodeploy sharding heuristics: {config.sharding_dims}")
# run TP sharding across ranks
if ShardingDim.TP in config.sharding_dims:
info += detect_column_row_shard(gm, transform_container)
elif source == ShardingSource.HEURISTIC:
ad_logger.info(
f"Running autodeploy sharding heuristics: {config.sharding_dims}"
)
# run TP sharding across ranks
if ShardingDim.TP in config.sharding_dims:
info += detect_column_row_shard(gm, transform_container)
# run EP sharding across ranks
if ShardingDim.EP in config.sharding_dims:
info += detect_ep_shard(gm, transform_container)
# run EP sharding across ranks
if ShardingDim.EP in config.sharding_dims:
info += detect_ep_shard(gm, transform_container)
# run BMM sharding across ranks
if ShardingDim.BMM in config.sharding_dims:
info += detect_dp_bmm_shard(gm, transform_container)
# run BMM sharding across ranks
if ShardingDim.BMM in config.sharding_dims:
info += detect_dp_bmm_shard(gm, transform_container)
return gm, info
@@ -889,18 +899,19 @@ class ShardingTransformExecutor(BaseTransform):
num_matches = 0
transforms = shared_config.sharding_transform_container
for tp_transform in transforms.weight_sharding_transforms:
if check_and_apply(tp_transform):
num_matches += 1
for bmm_transform in transforms.bmm_transforms:
if check_and_apply(bmm_transform):
num_matches += 1
for ep_transform in transforms.ep_transforms:
if check_and_apply(ep_transform):
num_matches += 1
for rmsnorm_transform in transforms.rmsnorm_transforms:
if check_and_apply(rmsnorm_transform):
num_matches += 1
with WeightBiasInfoCache():
for tp_transform in transforms.weight_sharding_transforms:
if check_and_apply(tp_transform):
num_matches += 1
for bmm_transform in transforms.bmm_transforms:
if check_and_apply(bmm_transform):
num_matches += 1
for ep_transform in transforms.ep_transforms:
if check_and_apply(ep_transform):
num_matches += 1
for rmsnorm_transform in transforms.rmsnorm_transforms:
if check_and_apply(rmsnorm_transform):
num_matches += 1
# post-sharding cleanup transformations
for update_transform in transforms.parameter_update_transforms:
@@ -1308,17 +1319,17 @@ def _shard_parameter_node(
return
# Shard weight using the unified function (also updates the parameter)
all_weight_infos = get_all_weight_infos(node)
weight_nodes = extract_weight_nodes(node)
# Parametrized nodes must have at least one weight (for debugging)
assert len(all_weight_infos.weights) > 0, (
assert len(weight_nodes.weights) > 0, (
f"Node {node.name} has no weights - weight mapping may be incorrect"
)
for weight_info in all_weight_infos.weights:
for weight_node in weight_nodes.weights:
_, weight_new_shape = shard_weight_tensor(
gm=gm,
weight_tensor=weight_info.tensor,
param_key=weight_info.node_key,
weight_tensor=weight_node.tensor,
param_key=weight_node.node_key,
dim=dim,
rank=rank,
world_size=world_size,
@@ -1328,22 +1339,22 @@ def _shard_parameter_node(
if quantization_cb is not None:
quantization_cb(
gm=gm,
submod=weight_info.submod,
submod=weight_node.submod,
node=node,
weight_key=weight_info.node_key,
weight_key=weight_node.node_key,
weight_new_shape=weight_new_shape,
dim=dim,
rank=rank,
world_size=world_size,
)
for bias_info in all_weight_infos.biases:
for bias_node in weight_nodes.biases:
if dim == 0:
# update bias for dim 0 --> we can handle it like the weight
shard_weight_tensor(
gm=gm,
weight_tensor=bias_info.tensor,
param_key=bias_info.node_key,
weight_tensor=bias_node.tensor,
param_key=bias_node.node_key,
dim=dim,
rank=rank,
world_size=world_size,
@@ -1358,10 +1369,10 @@ def _shard_parameter_node(
args[2] = None
node.args = tuple(args)
gm.graph.erase_node(node_bias)
bias_param_name = bias_info.node_key.rpartition(".")[-1]
setattr(bias_info.submod, bias_param_name, None)
bias_param_name = bias_node.node_key.rpartition(".")[-1]
setattr(bias_node.submod, bias_param_name, None)
gm._register_load_state_dict_pre_hook(
partial(_load_hook_remove, param_key=bias_info.node_key)
partial(_load_hook_remove, param_key=bias_node.node_key)
)
# column shard with no gather: the output is sharded

View File

@@ -152,36 +152,133 @@ def get_all_weights_in_subgraph(
def extract_weight_name(node: Node) -> Union[str, bool]:
"""
Extract the weight parameter name for a compute node.
Args:
node: Compute node (linear, MoE, SSM, etc.)
Returns:
Weight parameter name (str), or False if no weight exists.
"""
weight_node = get_weight_node(node)
if weight_node is None:
try:
weight_nodes = extract_weight_nodes(node)
except Exception:
return False
return weight_node.target
if len(weight_nodes.weights) == 0:
return False
return weight_nodes.weights[0].node_key
def get_param_or_buffer(tensor_name: str, gm: GraphModule) -> torch.Tensor:
if tensor_name in dict(gm.named_parameters()):
return gm.get_parameter(tensor_name)
elif tensor_name in dict(gm.named_buffers()):
return gm.get_buffer(tensor_name)
else:
raise KeyError(f"Tensor {tensor_name} not found in the graph")
param_dict = WeightBiasInfoCache.get_param_dict(gm)
if tensor_name in param_dict:
return param_dict[tensor_name]
buffer_dict = WeightBiasInfoCache.get_buffer_dict(gm)
if tensor_name in buffer_dict:
return buffer_dict[tensor_name]
raise KeyError(f"Tensor {tensor_name} not found in the graph")
class WeightBiasInfoCache:
"""Cache for weight and bias information to avoid repeated expensive operations.
This class manages caches for parameter names and weight shapes that are used
during graph transformation operations. Use it as a context manager to scope
the cache lifetime.
Example:
with WeightBiasInfoCache() as cache:
# All calls to get_weight_shape and extract_weight_nodes
# within this block use caching
layer_subgraphs, _ = get_all_layer_subgraphs(gm)
# Caches are cleared here
"""
# Class-level reference to the currently active cache instance
_active_instance: "WeightBiasInfoCache" = None
def __init__(self):
# Cache for param/buffer dicts to avoid repeated expensive named_parameters/named_buffers calls
self._param_dict_cache = {}
self._buffer_dict_cache = {}
# Cache for get_weight_shape to avoid repeated expensive extract_weight_nodes calls
self._weight_shape_cache = {}
# Activate this cache instance
WeightBiasInfoCache._active_instance = self
def __enter__(self):
return self
def __exit__(self, exc_type, exc_val, exc_tb):
self.close()
return False
def close(self):
"""Explicitly deactivate and clear the cache."""
if WeightBiasInfoCache._active_instance is self:
WeightBiasInfoCache._active_instance = None
self._param_dict_cache.clear()
self._buffer_dict_cache.clear()
self._weight_shape_cache.clear()
def __del__(self):
"""Cleanup when the cache is garbage collected."""
self.close()
@classmethod
def is_active(cls) -> bool:
"""Check if caching is currently enabled."""
return cls._active_instance is not None
@classmethod
def get_param_dict(cls, gm: GraphModule) -> dict:
"""Get cached parameters dict for a GraphModule, or compute and cache it."""
if cls._active_instance is None:
return dict(gm.named_parameters())
cache = cls._active_instance._param_dict_cache
if gm not in cache:
cache[gm] = dict(gm.named_parameters())
return cache[gm]
@classmethod
def get_buffer_dict(cls, gm: GraphModule) -> dict:
"""Get cached buffers dict for a GraphModule, or compute and cache it."""
if cls._active_instance is None:
return dict(gm.named_buffers())
cache = cls._active_instance._buffer_dict_cache
if gm not in cache:
cache[gm] = dict(gm.named_buffers())
return cache[gm]
@classmethod
def get_param_names(cls, gm: GraphModule) -> set:
"""Get cached parameter and buffer names for a GraphModule."""
param_dict = cls.get_param_dict(gm)
buffer_dict = cls.get_buffer_dict(gm)
return set(param_dict.keys()).union(buffer_dict.keys())
@classmethod
def get_weight_shape(cls, node: Node) -> Tuple[bool, Optional[List[int]]]:
"""Get cached weight shape for a node.
Returns:
Tuple of (found, value). If found is False, value should be ignored.
"""
if cls._active_instance is None:
return False, None
cache = cls._active_instance._weight_shape_cache
if node in cache:
return True, cache[node]
return False, None
@classmethod
def set_weight_shape(cls, node: Node, shape: Optional[List[int]]):
"""Store weight shape in cache."""
if cls._active_instance is not None:
cls._active_instance._weight_shape_cache[node] = shape
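
To make the scoping behaviour concrete, below is a small self-contained sketch of the class-level active-instance pattern; it is not the TensorRT-LLM implementation (which also caches buffers and weight shapes), only a demonstration that get_param_dict snapshots named_parameters() once per GraphModule while a scope is open and falls back to a fresh dict otherwise.

# Minimal sketch of the active-instance caching pattern (illustrative class name).
import torch.nn as nn
from torch.fx import GraphModule, symbolic_trace

class _ParamCache:
    _active = None  # class-level pointer to the currently open cache scope

    def __init__(self):
        self._params = {}

    def __enter__(self):
        _ParamCache._active = self
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        _ParamCache._active = None
        self._params.clear()  # drop snapshots so later graph edits cannot read stale data
        return False

    @classmethod
    def get_param_dict(cls, gm: GraphModule) -> dict:
        if cls._active is None:            # no scope open: compute a fresh dict
            return dict(gm.named_parameters())
        if gm not in cls._active._params:  # first lookup in this scope: snapshot once
            cls._active._params[gm] = dict(gm.named_parameters())
        return cls._active._params[gm]

gm = symbolic_trace(nn.Linear(4, 4))
with _ParamCache():
    first = _ParamCache.get_param_dict(gm)
    second = _ParamCache.get_param_dict(gm)
    assert first is second                          # one snapshot per scope
assert _ParamCache.get_param_dict(gm) is not first  # scope closed, cache cleared

The cache is keyed by the GraphModule instance, so one scope can serve several graphs; keeping the scope short is what prevents the snapshot from going stale once parameters are swapped out, which is why each transform in this diff opens its own with block rather than caching globally.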
def extract_weight_nodes(node: Node) -> WeightNodes:
"""Extracts the list of weight node and optional bias node from the given parametrized node"""
gm = node.graph.owning_module
param_names = {name for name, _ in gm.named_parameters()}.union(
{name for name, _ in gm.named_buffers()}
)
# Use cached param_names to avoid repeated expensive named_parameters/named_buffers calls
param_names = WeightBiasInfoCache.get_param_names(gm)
def find_get_attr_node(weight_node: Node) -> Node:
"""Recursively traverse inputs of allowed nodes to find a node with 'get_attr' op."""
@@ -255,18 +352,21 @@ def extract_weight_nodes(node: Node) -> WeightNodes:
return WeightNodes(weights=weight_nodes, biases=bias_nodes)
def get_weight_node(node: Node) -> Node:
"""Get the primary weight node for a compute node."""
weight_nodes = extract_weight_nodes(node)
if len(weight_nodes.weights) == 0:
raise ValueError(f"Node {node.name} has no weight")
return weight_nodes.weights[0].node
def num_users_of_weight_node(node: Node) -> int:
"""
Get the number of users of the weight node.
Args:
node: Compute node (linear, MoE, SSM, etc.)
Returns:
Number of users of the primary weight node, or 0 if no weight exists.
"""
weight_node = get_weight_node(node)
return len(weight_node.users) if weight_node is not None else 0
"""Returns the number of users of the weight node of the given parametrized node."""
try:
weight_node = get_weight_node(node)
except ValueError:
return 0
return len(weight_node.users)
def get_op_overload_packet(node: Union[OpOverloadPacket, OpOverload]) -> OpOverloadPacket:
@@ -469,167 +569,6 @@ def is_weight_node(node: Node) -> bool:
return node.op == "get_attr" and node.target and has_shape(node) and len(shape(node)) > 0
# Auxiliary ops that may appear between a weight node and its consumer compute node
_WEIGHT_AUX_OPS = frozenset(
{
torch.ops.aten.to.dtype,
torch.ops.aten.view.default,
}
)
def precompute_weight_node_mapping(gm: GraphModule) -> None:
"""
Pre-compute weight-to-consumer mapping for all weight nodes in the graph.
For each weight node (get_attr), finds the consumer compute node by traversing
through auxiliary ops (to.dtype, view.default). Stores the mapping in consumer
node's metadata:
- node.meta["weight_nodes"]: list of weight nodes (non-bias)
- node.meta["bias_nodes"]: list of bias nodes
This enables O(1) weight node lookup instead of O(depth) backward traversal.
Called automatically on first weight lookup via lazy initialization.
GUARANTEES (verified by assertions for debugging):
- Called exactly once per GraphModule
- No duplicate weight/bias nodes in any consumer's lists
- Each weight node mapped to exactly one consumer
"""
# Early return if already computed
if "_weight_mapping_computed" in gm.meta:
return
gm.meta["_weight_mapping_computed"] = True
for node in gm.graph.nodes:
if not is_weight_node(node):
continue
is_bias = node.target.endswith("bias")
# Find the consumer compute node by traversing through auxiliary ops
current = node
visited = {current}
while True:
# Get users of current node
users = list(current.users.keys())
if not users:
break
# Check if any user is a compute node (not an auxiliary op)
consumer_found = None
aux_node = None
for user in users:
if is_bias:
if "bias_nodes" not in user.meta:
user.meta["bias_nodes"] = []
# ASSERTION: Each weight node should be mapped exactly once
assert node not in user.meta["bias_nodes"], (
f"Duplicate bias node {node.name} found for consumer {user.name}"
)
user.meta["bias_nodes"].append(node)
else:
if "weight_nodes" not in user.meta:
user.meta["weight_nodes"] = []
# ASSERTION: Each weight node should be mapped exactly once
assert node not in user.meta["weight_nodes"], (
f"Duplicate weight node {node.name} found for consumer {user.name}"
)
user.meta["weight_nodes"].append(node)
if user.target in _WEIGHT_AUX_OPS:
# This is an auxiliary op, continue traversing
aux_node = user
else:
# This is a potential consumer compute node
consumer_found = user
break
if consumer_found is not None:
# Found the consumer, return
break
elif aux_node is not None and aux_node not in visited:
# Continue through auxiliary op
current = aux_node
visited.add(current)
else:
# No more nodes to traverse
break
def _ensure_weight_mapping(node: Node) -> None:
"""Ensure weight node mapping is computed. Lazily calls precompute if needed."""
gm = node.graph.owning_module
if "_weight_mapping_computed" not in gm.meta or not gm.meta["_weight_mapping_computed"]:
precompute_weight_node_mapping(gm)
def get_weight_node(node: Node) -> Optional[Node]:
"""Get the primary weight node for a compute node"""
_ensure_weight_mapping(node)
weight_nodes = node.meta.get("weight_nodes", [])
return weight_nodes[0] if weight_nodes else None
def get_weight_nodes(node: Node) -> List[Node]:
"""Get all weight nodes for a compute node"""
_ensure_weight_mapping(node)
return node.meta.get("weight_nodes", [])
def get_bias_nodes(node: Node) -> List[Node]:
"""Get all bias nodes for a compute node"""
_ensure_weight_mapping(node)
return node.meta.get("bias_nodes", [])
@dataclass
class WeightInfo:
"""Lightweight weight info extracted from a weight node."""
node: Node
node_key: str
tensor: torch.Tensor
submod: nn.Module
def _weight_node_to_info(weight_node: Node, gm: GraphModule) -> WeightInfo:
"""Convert a weight node to WeightInfo."""
node_key = weight_node.target
tensor = get_param_or_buffer(node_key, gm)
submod = gm.get_submodule(node_key.rpartition(".")[0])
return WeightInfo(node=weight_node, node_key=node_key, tensor=tensor, submod=submod)
def get_weight_info(node: Node) -> Optional[WeightInfo]:
"""Extract weight info for the primary weight of a compute node."""
weight_node = get_weight_node(node)
if weight_node is None:
return None
return _weight_node_to_info(weight_node, node.graph.owning_module)
@dataclass
class AllWeightInfos:
"""Container for all weight and bias infos of a compute node."""
weights: List[WeightInfo]
biases: List[WeightInfo]
def get_all_weight_infos(node: Node) -> AllWeightInfos:
"""Extract all weight and bias infos for a compute node."""
gm = node.graph.owning_module
weight_nodes = get_weight_nodes(node)
bias_nodes = get_bias_nodes(node)
return AllWeightInfos(
weights=[_weight_node_to_info(wn, gm) for wn in weight_nodes],
biases=[_weight_node_to_info(bn, gm) for bn in bias_nodes],
)
def get_user_if_pattern_match(node, ops, numusers, user_idx: int = 0):
"""Get a user from a node if the node matches a given op set and num of users."""
if node is None:
@@ -703,7 +642,7 @@ def get_all_layer_subgraphs(gm: GraphModule) -> tuple[List[LayerSubgraph], set[N
"""
Get subgraphs for all consecutive layers (attention, MLP, SSM, MoE) in the graph.
Pre-computes weight mappings and caches weight shapes for all linear nodes.
Caches weight shapes for all linear nodes using WeightBiasInfoCache.
Each layer is contained between opening linear layers and a single closing linear layer.
Assumptions:
@@ -734,9 +673,6 @@ def get_all_layer_subgraphs(gm: GraphModule) -> tuple[List[LayerSubgraph], set[N
layer_subgraphs = []
linear_nodes = list(filtered_nodes(gm.graph.nodes, is_any_lin_op))
# Pre-compute weight-to-consumer mapping for O(1) weight node lookup
precompute_weight_node_mapping(gm)
# Cache weight shapes for all linear nodes
for lin_node in linear_nodes:
if "lin_node_shape" not in lin_node.meta:
@@ -1006,19 +942,25 @@ def subgraph(
def get_weight_shape(node: Node, dim: Optional[int] = None) -> Optional[Union[int, List[int]]]:
"""Get weight shape for a linear operation node. Returns None if no weight."""
"""Get the shape of the weight node."""
if not is_any_lin_op(node):
return None
weight_node = get_weight_node(node)
if weight_node is None:
# Try to get from cache first
found, s = WeightBiasInfoCache.get_weight_shape(node)
if not found:
# Not in cache or caching not enabled - compute the shape
s = list(shape(extract_weight_nodes(node).weights[0].node))
if len(s) == 0:
s = None
elif is_fp4_op(node):
# FP4 weights are packed as uint8 type with 2 FP4 values per element
s[-1] *= 2
# Store in cache if caching is enabled
WeightBiasInfoCache.set_weight_shape(node, s)
if s is None:
return None
s = list(shape(weight_node))
if is_fp4_op(node):
# FP4 weights are packed as uint8 type with 2 FP4 values per element
s[-1] *= 2
if dim is None:
return s
else:
@@ -1254,13 +1196,11 @@ def shape(node: Node) -> Tuple[int, ...]:
def get_weight_tensor(node: Node) -> torch.Tensor:
"""Extract the weight tensor from a compute node."""
weight_node = get_weight_node(node)
if weight_node is None:
"""Extract the weight tensor from a node within a GraphModule."""
weight_nodes = extract_weight_nodes(node)
if len(weight_nodes.weights) == 0:
raise ValueError(f"Node {node.name} has no weight")
gm = node.graph.owning_module
return get_param_or_buffer(weight_node.target, gm)
return weight_nodes.weights[0].tensor
def draw_graph(gm: GraphModule, filename: str):