mirror of
https://github.com/NVIDIA/TensorRT-LLM.git
synced 2026-01-13 22:18:36 +08:00
[TRTLLM-9767][feat] Fixed recursive node traversals (#10379)
Signed-off-by: greg-kwasniewski1 <213329731+greg-kwasniewski1@users.noreply.github.com>
This commit is contained in:
parent
db2614ef10
commit
ea380ff45c
@ -617,17 +617,16 @@ def predecessors(
|
||||
"""
|
||||
preds = []
|
||||
seen = set()
|
||||
for arg in node.args:
|
||||
if isinstance(arg, Node):
|
||||
if ((not include) or (include and include(arg))) and (not exclude or not exclude(arg)):
|
||||
if arg not in seen:
|
||||
preds.append(arg)
|
||||
seen.add(arg)
|
||||
if depth > 1:
|
||||
for p in predecessors(arg, depth - 1, include, exclude):
|
||||
if p not in seen:
|
||||
preds.append(p)
|
||||
seen.add(p)
|
||||
for arg in node.all_input_nodes:
|
||||
if ((not include) or (include and include(arg))) and (not exclude or not exclude(arg)):
|
||||
if arg not in seen:
|
||||
preds.append(arg)
|
||||
seen.add(arg)
|
||||
if depth > 1:
|
||||
for p in predecessors(arg, depth - 1, include, exclude):
|
||||
if p not in seen:
|
||||
preds.append(p)
|
||||
seen.add(p)
|
||||
return preds
|
||||
|
||||
|
||||
@ -698,7 +697,7 @@ def subgraph(
|
||||
def boundary_condition(n):
|
||||
return n in sources_set
|
||||
|
||||
attr_next = "args"
|
||||
attr_next = "all_input_nodes"
|
||||
elif sources is not None:
|
||||
# case 2
|
||||
assert boundary_condition is not None, "boundary_condition must be provided for case 2"
|
||||
@ -712,7 +711,7 @@ def subgraph(
|
||||
# Initialize queue with sinks and mark them as seen
|
||||
queue = list(sinks)
|
||||
start_nodes = set(sinks)
|
||||
attr_next = "args"
|
||||
attr_next = "all_input_nodes"
|
||||
else:
|
||||
raise ValueError("Either sinks or sources must be provided")
|
||||
|
||||
|
||||
Loading…
Reference in New Issue
Block a user