[#7222][autodeploy] Separate run_shape_prop as another graph utility (#7313)

Signed-off-by: Frida Hou <201670829+Fridah-nv@users.noreply.github.com>
Frida Hou authored on 2025-09-03 16:32:50 -07:00, committed by GitHub
parent bd9ba97d89
commit 51a2b8729e
2 changed files with 62 additions and 42 deletions
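
At a high level, this change splits the optional shape-propagation step out of canonicalize_graph and exposes it as a standalone run_shape_prop utility in the graph-utility module. A minimal before/after sketch of the call pattern, inferred from the signatures in the diff below (the example_args value is illustrative):

    # before: shape propagation was folded into canonicalization
    canonicalize_graph(gm, shape_prop=True, args_static=example_args)

    # after: two separate graph utilities
    canonicalize_graph(gm)  # graph cleanup + lint only
    run_shape_prop(gm, args_static=example_args)  # FakeTensor-based shape propagation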

Changed file 1 of 2

@@ -13,7 +13,7 @@ from torch.fx import GraphModule
 from ..models.factory import ModelFactory
 from ..shim.interface import CachedSequenceInterface
-from ..transformations._graph import canonicalize_graph, lift_to_meta
+from ..transformations._graph import canonicalize_graph, lift_to_meta, run_shape_prop
 from ..utils.logger import ad_logger
 from ..utils.sharding_utils import ShardingConfig
@@ -328,7 +328,7 @@ class BaseTransform(ABC):
         if self.config.requires_shape_prop and not has_valid_shapes:
             canonicalize_graph(gm)
             with lift_to_meta(gm):
-                canonicalize_graph(gm, shape_prop=True)
+                run_shape_prop(gm)
             is_clean = True
             has_valid_shapes = True
         elif self.config.requires_clean_graph and not is_clean:
@@ -354,7 +354,7 @@ class BaseTransform(ABC):
         if self.config.run_shape_prop and not (info.is_clean and info.has_valid_shapes):
             canonicalize_graph(gm)
             with lift_to_meta(gm):
-                canonicalize_graph(gm, shape_prop=True)
+                run_shape_prop(gm)
         elif self.config.run_graph_cleanup and not info.is_clean:
             canonicalize_graph(gm)
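
Both call sites in the transform interface now follow the same split: clean the graph first, then run shape propagation on the meta-lifted module. A condensed sketch of the pattern (the flag names needs_shape_prop / needs_cleanup_only are stand-ins for the self.config flags shown above):

    if needs_shape_prop:
        canonicalize_graph(gm)  # cleanup only
        with lift_to_meta(gm):
            run_shape_prop(gm)  # shape propagation as its own pass
    elif needs_cleanup_only:
        canonicalize_graph(gm)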

Changed file 2 of 2

@@ -129,7 +129,7 @@ def _move_single_gm_to_device(gm: GraphModule, device: torch.device) -> None:
     gm.recompile()
 
 
-def move_to_device(gm: fx.GraphModule, device: DeviceLikeType) -> fx.GraphModule:
+def move_to_device(gm: fx.GraphModule, device: DeviceLikeType) -> None:
     """Move the entire graph module and all sub-GraphModules to the specified device."""
     # get device
     device = torch.device(device)
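
The return-annotation change on move_to_device (fx.GraphModule to None) signals that the function mutates the module in place rather than returning it. A hedged sketch of the implied call-site adjustment (the "cuda" argument is illustrative):

    # before, a caller may have chained on the return value:
    #     gm = move_to_device(gm, "cuda")
    # after, the module is modified in place and nothing is returned:
    move_to_device(gm, "cuda")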
@@ -154,9 +154,7 @@ def _is_impure_node(node: Node) -> bool:
         node.target._nondeterministic_seeded = True
 
 
-def _canonicalize_single_gm(
-    gm: GraphModule, shape_prop: bool = False, args_static: Optional[Tuple[Any, ...]] = None
-) -> GraphModule:
+def _canonicalize_single_gm(gm: GraphModule) -> None:
     # clean up graph (needs to be done repeatedly until no more dead code)
     gm.graph.eliminate_dead_code(is_impure_node=_is_impure_node)
@@ -166,59 +164,81 @@ def _canonicalize_single_gm(
     # clean up graph module
     gm.delete_all_unused_submodules()
 
-    # NOTE: shape_prop can be a littly finicky & slow, so we only run it optionally...
-    if shape_prop:
-        fake_mode: Optional[FakeTensorMode] = _detect_fake_mode_from_gm(gm)
-
-        # get fake tensors from placeholder nodes
-        inps = [node.meta.get("val") for node in gm.graph.nodes if node.op == "placeholder"]
-
-        # check if we need to use args to create fake tensors
-        if any(inp is None for inp in inps):
-            if args_static is not None and fake_mode is not None and len(args_static) == len(inps):
-                inps = [
-                    fake_t if fake_t is not None else fake_mode.from_tensor(arg, static_shapes=True)
-                    for fake_t, arg in zip(inps, args_static)
-                ]
-
-        # run shape propagation if we have all the fake tensors
-        if all(inp is not None for inp in inps):
-            FakeTensorProp(gm, fake_mode).propagate(*inps)
-        else:
-            ad_logger.warning("No fake tensors and no args available for shape propagation")
-
     # lint the graph
     gm.graph.lint()
 
 
-def canonicalize_graph(
-    gm: GraphModule, shape_prop: bool = False, args_static: Optional[Tuple[Any, ...]] = None
-) -> None:
+def canonicalize_graph(gm: GraphModule) -> None:
     """Canonicalize the graph of the given GraphModule.
 
     Args:
         gm: The GraphModule to canonicalize.
-        shape_prop: Whether to run shape propagation. Shape propagation tends to be finicky and
-            slow, so we only run it optionally.
-        args_static: A tuple of static arguments to use for shape propagation. Shape propagation
-            requires all inputs to the graph ("placeholder" nodes) to have metadata with an
-            appropriate FakeTensor argument (``node.meta["val"]``). ``args_static`` can be used to
-            infer static FakeTensor information if some placeholder nodes do not have metadata.
-            When ``meta["val"]`` is available, it will take precedence over ``args_static``.
 
     Returns:
         The canonicalized (cleaned-up) GraphModule.
     """
     ad_logger.debug(f"Before canonicalizing: {gm}")
     for _, subgm in reversed(list(named_graphmodules(gm))):
-        _canonicalize_single_gm(
-            subgm, shape_prop=shape_prop, args_static=args_static if subgm is gm else None
-        )
+        _canonicalize_single_gm(subgm)
     ad_logger.debug(f"After canonicalizing: {gm}")
 
 
+def _run_shape_prop_single_gm(
+    gm: GraphModule,
+    args_static: Optional[Tuple[Any, ...]] = None,
+) -> None:
+    fake_mode: Optional[FakeTensorMode] = _detect_fake_mode_from_gm(gm)
+
+    # get fake tensors from placeholder nodes
+    inps = [node.meta.get("val") for node in gm.graph.nodes if node.op == "placeholder"]
+
+    # check if we need to use args to create fake tensors
+    if any(inp is None for inp in inps):
+        if args_static is not None and fake_mode is not None and len(args_static) == len(inps):
+            inps = [
+                fake_t if fake_t is not None else fake_mode.from_tensor(arg, static_shapes=True)
+                for fake_t, arg in zip(inps, args_static)
+            ]
+
+    # run shape propagation if we have all the fake tensors
+    if all(inp is not None for inp in inps):
+        FakeTensorProp(gm, fake_mode).propagate(*inps)
+    else:
+        ad_logger.warning("No fake tensors and no args available for shape propagation")
+
+    # lint the graph
+    gm.graph.lint()
+
+
+def run_shape_prop(
+    gm: GraphModule,
+    args_static: Optional[Tuple[Any, ...]] = None,
+) -> None:
+    """Run FakeTensor-based shape propagation on the given GraphModule and its submodules.
+
+    This pass attempts to populate shape/type metadata for all nodes by propagating
+    FakeTensor inputs through the graph. If a placeholder node already has a
+    ``node.meta["val"]`` entry, that FakeTensor will be used. Otherwise, if
+    ``args_static`` is provided and a FakeTensorMode is detected, new FakeTensors
+    are synthesized from the static arguments.
+
+    Args:
+        gm: The top-level GraphModule on which to run shape propagation. All nested
+            GraphModules are processed in reverse topological order.
+        args_static: Optional tuple of concrete tensors used to create FakeTensors
+            when placeholder metadata is missing. Only applied to the top-level
+            GraphModule; submodules reuse their existing placeholder metadata.
+    """
+    ad_logger.debug(f"Before running shape propagation: {gm}")
+    for _, subgm in reversed(list(named_graphmodules(gm))):
+        _run_shape_prop_single_gm(subgm, args_static=args_static if subgm is gm else None)
+    ad_logger.debug(f"After running shape propagation: {gm}")
+
+
 def add_graph_input(
     gm: GraphModule, name: str, val: Optional[torch.Tensor] = None, dynamic_shape=None
 ) -> Node:
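
For reference, a minimal end-to-end usage sketch of the two utilities as now separated. The absolute import path is inferred from the relative import in the first file, and the Toy module is a stand-in; the sketch assumes the graph was produced by torch.export, which records FakeTensor metadata in node.meta["val"] so that no args_static is needed:

    import torch

    # inferred absolute path; the diff itself only shows the relative import
    from tensorrt_llm._torch.auto_deploy.transformations._graph import (
        canonicalize_graph,
        run_shape_prop,
    )


    class Toy(torch.nn.Module):
        def forward(self, x):
            return torch.relu(x) + 1


    # torch.export attaches FakeTensor metadata to the graph's placeholder nodes
    gm = torch.export.export(Toy(), (torch.randn(4, 8),)).module()

    canonicalize_graph(gm)  # dead-code elimination, submodule cleanup, lint
    run_shape_prop(gm)      # propagate FakeTensors and refresh node.meta["val"]

    for node in gm.graph.nodes:
        print(node.name, node.meta.get("val"))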