From ea380ff45cdffd890cdcb2f80a8db1d66b84ebc2 Mon Sep 17 00:00:00 2001 From: Grzegorz Kwasniewski <213329731+greg-kwasniewski1@users.noreply.github.com> Date: Mon, 5 Jan 2026 17:42:06 +0100 Subject: [PATCH] [TRTLLM-9767][feat] Fixed recursive node traversals (#10379) Signed-off-by: greg-kwasniewski1 <213329731+greg-kwasniewski1@users.noreply.github.com> --- .../_torch/auto_deploy/utils/node_utils.py | 25 +++++++++---------- 1 file changed, 12 insertions(+), 13 deletions(-) diff --git a/tensorrt_llm/_torch/auto_deploy/utils/node_utils.py b/tensorrt_llm/_torch/auto_deploy/utils/node_utils.py index ded746fe78..795bfa4f2d 100644 --- a/tensorrt_llm/_torch/auto_deploy/utils/node_utils.py +++ b/tensorrt_llm/_torch/auto_deploy/utils/node_utils.py @@ -617,17 +617,16 @@ def predecessors( """ preds = [] seen = set() - for arg in node.args: - if isinstance(arg, Node): - if ((not include) or (include and include(arg))) and (not exclude or not exclude(arg)): - if arg not in seen: - preds.append(arg) - seen.add(arg) - if depth > 1: - for p in predecessors(arg, depth - 1, include, exclude): - if p not in seen: - preds.append(p) - seen.add(p) + for arg in node.all_input_nodes: + if ((not include) or (include and include(arg))) and (not exclude or not exclude(arg)): + if arg not in seen: + preds.append(arg) + seen.add(arg) + if depth > 1: + for p in predecessors(arg, depth - 1, include, exclude): + if p not in seen: + preds.append(p) + seen.add(p) return preds @@ -698,7 +697,7 @@ def subgraph( def boundary_condition(n): return n in sources_set - attr_next = "args" + attr_next = "all_input_nodes" elif sources is not None: # case 2 assert boundary_condition is not None, "boundary_condition must be provided for case 2" @@ -712,7 +711,7 @@ def subgraph( # Initialize queue with sinks and mark them as seen queue = list(sinks) start_nodes = set(sinks) - attr_next = "args" + attr_next = "all_input_nodes" else: raise ValueError("Either sinks or sources must be provided")