[TRTLLM-9767][feat] Fixed recursive node traversals (#10379)

Signed-off-by: greg-kwasniewski1 <213329731+greg-kwasniewski1@users.noreply.github.com>
This commit is contained in:
Grzegorz Kwasniewski 2026-01-05 17:42:06 +01:00 committed by GitHub
parent db2614ef10
commit ea380ff45c
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -617,17 +617,16 @@ def predecessors(
"""
preds = []
seen = set()
for arg in node.args:
if isinstance(arg, Node):
if ((not include) or (include and include(arg))) and (not exclude or not exclude(arg)):
if arg not in seen:
preds.append(arg)
seen.add(arg)
if depth > 1:
for p in predecessors(arg, depth - 1, include, exclude):
if p not in seen:
preds.append(p)
seen.add(p)
for arg in node.all_input_nodes:
if ((not include) or (include and include(arg))) and (not exclude or not exclude(arg)):
if arg not in seen:
preds.append(arg)
seen.add(arg)
if depth > 1:
for p in predecessors(arg, depth - 1, include, exclude):
if p not in seen:
preds.append(p)
seen.add(p)
return preds
@ -698,7 +697,7 @@ def subgraph(
def boundary_condition(n):
return n in sources_set
attr_next = "args"
attr_next = "all_input_nodes"
elif sources is not None:
# case 2
assert boundary_condition is not None, "boundary_condition must be provided for case 2"
@ -712,7 +711,7 @@ def subgraph(
# Initialize queue with sinks and mark them as seen
queue = list(sinks)
start_nodes = set(sinks)
attr_next = "args"
attr_next = "all_input_nodes"
else:
raise ValueError("Either sinks or sources must be provided")