@@ -39,10 +39,8 @@ def multiplex_fw_bw_graph(
3939 multiplexed_gm = copy .deepcopy (fw_gm )
4040
4141 # Collect all placeholder nodes from the forward and backward graphs
42- bw_placeholders = []
43- for n in bw_gm .graph .nodes :
44- if n .op == "placeholder" :
45- bw_placeholders .append (n )
42+ bw_placeholders = bw_gm .graph .find_nodes (op = "placeholder" )
43+ fw_placeholders = fw_gm .graph .find_nodes (op = "placeholder" )
4644
4745 # Insert backward placeholders at the beginning of the multiplexed graph
4846 # Reversed order ensures correct execution sequence
@@ -54,21 +52,21 @@ def multiplex_fw_bw_graph(
5452 old_node_to_new_node [n ] = new_placeholder
5553
5654 # Find the last placeholder in the multiplexed graph; backward nodes are inserted after it
57- insert_point = None
58- multiplexed_graph_op_node = None
59- for n in multiplexed_gm .graph .nodes :
60- if n .op == "placeholder" :
61- insert_point = n
62- if n .op == "output" :
63- multiplexed_graph_op_node = n
55+ multiplxed_gm_placeholders = multiplexed_gm .graph .find_nodes (op = "placeholder" )
56+ assert len (multiplxed_gm_placeholders ) == (
57+ len (fw_placeholders ) + len (bw_placeholders )
58+ )
59+ insert_point = multiplxed_gm_placeholders [- 1 ]
6460
6561 # Copy all computation nodes from backward graph into multiplexed graph
66- bw_graph_op_node = None
62+ fw_outputs = fw_gm .graph .find_nodes (op = "output" )
63+ bw_outputs = bw_gm .graph .find_nodes (op = "output" )
64+ assert len (bw_outputs ) == 1 and len (fw_outputs ) == 1
65+ bw_graph_op_node = bw_outputs [0 ]
6766 for n in bw_gm .graph .nodes :
6867 if n .op == "placeholder" :
6968 continue
7069 if n .op == "output" :
71- bw_graph_op_node = n
7270 continue
7371 with multiplexed_gm .graph .inserting_after (insert_point ):
7472 # Copy node and remap its arguments using the node mapping
@@ -79,25 +77,20 @@ def multiplex_fw_bw_graph(
7977 old_node_to_new_node [n ] = new_node
8078 insert_point = new_node
8179
82- assert bw_graph_op_node is not None
83- assert multiplexed_graph_op_node is not None
84-
8580 # Collect output arguments from backward graph, remapping to new nodes
8681 bw_op_node_args = [
8782 old_node_to_new_node [n ] if n is not None else None
8883 for n in bw_graph_op_node .args [0 ]
8984 ]
9085
91- # Collect output arguments from forward graph
86+ # Collect output arguments from multiplexed graph (will contain only fwd_outs)
87+ multiplexed_graph_outputs = multiplexed_gm .graph .find_nodes (op = "output" )
88+ assert len (multiplexed_graph_outputs ) == 1
89+ multiplexed_graph_op_node = multiplexed_graph_outputs [0 ]
9290 fw_op_node_args = list (multiplexed_graph_op_node .args [0 ])
9391
94- # Remove the old output node and create new combined output
95- insert_point = multiplexed_graph_op_node .prev
96- multiplexed_gm .graph .erase_node (multiplexed_graph_op_node )
97-
98- # Create combined output with backward outputs first, then forward outputs
99- with multiplexed_gm .graph .inserting_after (insert_point ):
100- multiplexed_gm .graph .output (bw_op_node_args + fw_op_node_args )
92+ # Update output node args to prepend backward outputs before forward outputs
93+ multiplexed_graph_op_node .args = (tuple (bw_op_node_args + fw_op_node_args ),)
10194
10295 multiplexed_gm .graph .eliminate_dead_code ()
10396 multiplexed_gm .graph .lint ()
0 commit comments