commit 51a2b8729e
parent bd9ba97d89
Signed-off-by: Frida Hou <201670829+Fridah-nv@users.noreply.github.com>
@@ -13,7 +13,7 @@ from torch.fx import GraphModule
 from ..models.factory import ModelFactory
 from ..shim.interface import CachedSequenceInterface
-from ..transformations._graph import canonicalize_graph, lift_to_meta
+from ..transformations._graph import canonicalize_graph, lift_to_meta, run_shape_prop
 from ..utils.logger import ad_logger
 from ..utils.sharding_utils import ShardingConfig
@@ -328,7 +328,7 @@ class BaseTransform(ABC):
         if self.config.requires_shape_prop and not has_valid_shapes:
             canonicalize_graph(gm)
             with lift_to_meta(gm):
-                canonicalize_graph(gm, shape_prop=True)
+                run_shape_prop(gm)
             is_clean = True
             has_valid_shapes = True
         elif self.config.requires_clean_graph and not is_clean:
@@ -354,7 +354,7 @@ class BaseTransform(ABC):
         if self.config.run_shape_prop and not (info.is_clean and info.has_valid_shapes):
             canonicalize_graph(gm)
             with lift_to_meta(gm):
-                canonicalize_graph(gm, shape_prop=True)
+                run_shape_prop(gm)
         elif self.config.run_graph_cleanup and not info.is_clean:
             canonicalize_graph(gm)

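Note: the call-site change is the same in both hunks above. canonicalize_graph no longer runs shape propagation itself, so the combined call under lift_to_meta is replaced by the new dedicated pass while the plain cleanup call stays. A minimal sketch of the pattern (method bodies abridged):

# before: one call did cleanup and, with shape_prop=True, shape propagation as well
canonicalize_graph(gm)
with lift_to_meta(gm):
    canonicalize_graph(gm, shape_prop=True)

# after: canonicalize_graph only cleans the graph; shape propagation is its own pass
canonicalize_graph(gm)
with lift_to_meta(gm):
    run_shape_prop(gm)
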
@@ -129,7 +129,7 @@ def _move_single_gm_to_device(gm: GraphModule, device: torch.device) -> None:
     gm.recompile()


-def move_to_device(gm: fx.GraphModule, device: DeviceLikeType) -> fx.GraphModule:
+def move_to_device(gm: fx.GraphModule, device: DeviceLikeType) -> None:
     """Move the entire graph module and all sub-GraphModules to the specified device."""
     # get device
     device = torch.device(device)
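Note: the only change in this hunk is the return annotation; move_to_device now signals that it moves the module in place and returns None. A sketch of the call-site implication (the device string is just an example):

move_to_device(gm, "cuda")            # gm is moved in place; the call no longer returns the module
# new_gm = move_to_device(gm, "cuda")  # would now bind new_gm to None
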
@@ -154,9 +154,7 @@ def _is_impure_node(node: Node) -> bool:
        node.target._nondeterministic_seeded = True


-def _canonicalize_single_gm(
-    gm: GraphModule, shape_prop: bool = False, args_static: Optional[Tuple[Any, ...]] = None
-) -> GraphModule:
+def _canonicalize_single_gm(gm: GraphModule) -> None:
     # clean up graph (needs to be done repeatedly until no more dead code)
     gm.graph.eliminate_dead_code(is_impure_node=_is_impure_node)
@@ -166,59 +164,81 @@ def _canonicalize_single_gm(
     # clean up graph module
     gm.delete_all_unused_submodules()

-    # NOTE: shape_prop can be a littly finicky & slow, so we only run it optionally...
-    if shape_prop:
-        fake_mode: Optional[FakeTensorMode] = _detect_fake_mode_from_gm(gm)
-
-        # get fake tensors from placeholder nodes
-        inps = [node.meta.get("val") for node in gm.graph.nodes if node.op == "placeholder"]
-
-        # check if we need to use args to create fake tensors
-        if any(inp is None for inp in inps):
-            if args_static is not None and fake_mode is not None and len(args_static) == len(inps):
-                inps = [
-                    fake_t if fake_t is not None else fake_mode.from_tensor(arg, static_shapes=True)
-                    for fake_t, arg in zip(inps, args_static)
-                ]
-
-        # run shape propagation if we have all the fake tensors
-        if all(inp is not None for inp in inps):
-            FakeTensorProp(gm, fake_mode).propagate(*inps)
-        else:
-            ad_logger.warning("No fake tensors and no args available for shape propagation")
-
     # lint the graph
     gm.graph.lint()


-def canonicalize_graph(
-    gm: GraphModule, shape_prop: bool = False, args_static: Optional[Tuple[Any, ...]] = None
-) -> None:
+def canonicalize_graph(gm: GraphModule) -> None:
     """Canonicalize the graph of the given GraphModule.

     Args:
         gm: The GraphModule to canonicalize.
-        shape_prop: Whether to run shape propagation. Shape propagation tends to be finicky and
-            slow, so we only run it optionally.
-        args_static: A tuple of static arguments to use for shape propagation. Shape propagation
-            requires all inputs to the graph ("placeholder" nodes) to have metadata with an
-            appropriate FakeTensor argument (``node.meta["val"]``). ``args_static`` can be used to
-            infer static FakeTensor information if some placeholder nodes do not have metadata.
-            When ``meta["val"]`` is available, it will take precedence over ``args_static``.
-
-    Returns:
-        The canonicalized (cleaned-up) GraphModule.
     """
     ad_logger.debug(f"Before canonicalizing: {gm}")

     for _, subgm in reversed(list(named_graphmodules(gm))):
-        _canonicalize_single_gm(
-            subgm, shape_prop=shape_prop, args_static=args_static if subgm is gm else None
-        )
+        _canonicalize_single_gm(subgm)

     ad_logger.debug(f"After canonicalizing: {gm}")


+def _run_shape_prop_single_gm(
+    gm: GraphModule,
+    args_static: Optional[Tuple[Any, ...]] = None,
+) -> None:
+    fake_mode: Optional[FakeTensorMode] = _detect_fake_mode_from_gm(gm)
+
+    # get fake tensors from placeholder nodes
+    inps = [node.meta.get("val") for node in gm.graph.nodes if node.op == "placeholder"]
+
+    # check if we need to use args to create fake tensors
+    if any(inp is None for inp in inps):
+        if args_static is not None and fake_mode is not None and len(args_static) == len(inps):
+            inps = [
+                fake_t if fake_t is not None else fake_mode.from_tensor(arg, static_shapes=True)
+                for fake_t, arg in zip(inps, args_static)
+            ]
+
+    # run shape propagation if we have all the fake tensors
+    if all(inp is not None for inp in inps):
+        FakeTensorProp(gm, fake_mode).propagate(*inps)
+    else:
+        ad_logger.warning("No fake tensors and no args available for shape propagation")
+
+    # lint the graph
+    gm.graph.lint()
+
+
+def run_shape_prop(
+    gm: GraphModule,
+    args_static: Optional[Tuple[Any, ...]] = None,
+) -> None:
+    """Run FakeTensor-based shape propagation on the given GraphModule and its submodules.
+
+    This pass attempts to populate shape/type metadata for all nodes by propagating
+    FakeTensor inputs through the graph. If a placeholder node already has a
+    ``node.meta["val"]`` entry, that FakeTensor will be used. Otherwise, if
+    ``args_static`` is provided and a FakeTensorMode is detected, new FakeTensors
+    are synthesized from the static arguments.
+
+    Args:
+        gm: The top-level GraphModule on which to run shape propagation. All nested
+            GraphModules are processed in reverse topological order.
+        args_static: Optional tuple of concrete tensors used to create FakeTensors
+            when placeholder metadata is missing. Only applied to the top-level
+            GraphModule; submodules reuse their existing placeholder metadata.
+
+    """
+    ad_logger.debug(f"Before running shape propagation: {gm}")
+
+    for _, subgm in reversed(list(named_graphmodules(gm))):
+        _run_shape_prop_single_gm(subgm, args_static=args_static if subgm is gm else None)
+
+    ad_logger.debug(f"After running shape propagation: {gm}")
+
+
 def add_graph_input(
     gm: GraphModule, name: str, val: Optional[torch.Tensor] = None, dynamic_shape=None
 ) -> Node:
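Note: the new run_shape_prop/_run_shape_prop_single_gm pair is a thin wrapper around torch's FakeTensorProp. Below is a self-contained sketch of that underlying machinery using plain torch APIs only; this is not TensorRT-LLM code, and TinyModel plus the example input are made up. It mirrors the fallback path above: placeholders produced by symbolic_trace carry no meta["val"], so FakeTensors are synthesized from a static example input, which is the role args_static plays in the new helper.

import torch
from torch.fx import symbolic_trace
from torch.fx.passes.fake_tensor_prop import FakeTensorProp
from torch._subclasses.fake_tensor import FakeTensorMode


class TinyModel(torch.nn.Module):
    def forward(self, x):
        return torch.relu(x) + 1


gm = symbolic_trace(TinyModel())

# symbolic_trace does not populate node.meta["val"], so synthesize FakeTensors
# from a static example input -- the same fallback that args_static provides above.
fake_mode = FakeTensorMode()
example_input = torch.randn(2, 8)
FakeTensorProp(gm, fake_mode).propagate(example_input)

# every node now carries a FakeTensor in meta["val"] with the propagated shape
for node in gm.graph.nodes:
    val = node.meta.get("val")
    print(node.op, node.target, getattr(val, "shape", None))

In the diff itself, run_shape_prop applies this per sub-GraphModule in reverse order and reuses existing meta["val"] entries when they are already present.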