Skip to content

Commit 06d5f49

Browse files
Enable DualPipeV by adding a multiplexed graph (#258)
1 parent d17cb4c commit 06d5f49

File tree

5 files changed

+613
-278
lines changed

5 files changed

+613
-278
lines changed

autoparallel/_passes/graph_multiplex.py

Lines changed: 17 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -39,10 +39,8 @@ def multiplex_fw_bw_graph(
3939
multiplexed_gm = copy.deepcopy(fw_gm)
4040

4141
# Collect all placeholder nodes from the backward graph
42-
bw_placeholders = []
43-
for n in bw_gm.graph.nodes:
44-
if n.op == "placeholder":
45-
bw_placeholders.append(n)
42+
bw_placeholders = bw_gm.graph.find_nodes(op="placeholder")
43+
fw_placeholders = fw_gm.graph.find_nodes(op="placeholder")
4644

4745
# Insert backward placeholders at the beginning of the multiplexed graph
4846
# Reversed order ensures correct execution sequence
@@ -54,21 +52,21 @@ def multiplex_fw_bw_graph(
5452
old_node_to_new_node[n] = new_placeholder
5553

5654
# Find the last placeholder and the output node in the multiplexed graph
57-
insert_point = None
58-
multiplexed_graph_op_node = None
59-
for n in multiplexed_gm.graph.nodes:
60-
if n.op == "placeholder":
61-
insert_point = n
62-
if n.op == "output":
63-
multiplexed_graph_op_node = n
55+
multiplxed_gm_placeholders = multiplexed_gm.graph.find_nodes(op="placeholder")
56+
assert len(multiplxed_gm_placeholders) == (
57+
len(fw_placeholders) + len(bw_placeholders)
58+
)
59+
insert_point = multiplxed_gm_placeholders[-1]
6460

6561
# Copy all computation nodes from backward graph into multiplexed graph
66-
bw_graph_op_node = None
62+
fw_outputs = fw_gm.graph.find_nodes(op="output")
63+
bw_outputs = bw_gm.graph.find_nodes(op="output")
64+
assert len(bw_outputs) == 1 and len(fw_outputs) == 1
65+
bw_graph_op_node = bw_outputs[0]
6766
for n in bw_gm.graph.nodes:
6867
if n.op == "placeholder":
6968
continue
7069
if n.op == "output":
71-
bw_graph_op_node = n
7270
continue
7371
with multiplexed_gm.graph.inserting_after(insert_point):
7472
# Copy node and remap its arguments using the node mapping
@@ -79,25 +77,20 @@ def multiplex_fw_bw_graph(
7977
old_node_to_new_node[n] = new_node
8078
insert_point = new_node
8179

82-
assert bw_graph_op_node is not None
83-
assert multiplexed_graph_op_node is not None
84-
8580
# Collect output arguments from backward graph, remapping to new nodes
8681
bw_op_node_args = [
8782
old_node_to_new_node[n] if n is not None else None
8883
for n in bw_graph_op_node.args[0]
8984
]
9085

91-
# Collect output arguments from forward graph
86+
# Collect output arguments from multiplexed graph (will contain only fwd_outs)
87+
multiplexed_graph_outputs = multiplexed_gm.graph.find_nodes(op="output")
88+
assert len(multiplexed_graph_outputs) == 1
89+
multiplexed_graph_op_node = multiplexed_graph_outputs[0]
9290
fw_op_node_args = list(multiplexed_graph_op_node.args[0])
9391

94-
# Remove the old output node and create new combined output
95-
insert_point = multiplexed_graph_op_node.prev
96-
multiplexed_gm.graph.erase_node(multiplexed_graph_op_node)
97-
98-
# Create combined output with backward outputs first, then forward outputs
99-
with multiplexed_gm.graph.inserting_after(insert_point):
100-
multiplexed_gm.graph.output(bw_op_node_args + fw_op_node_args)
92+
# Update output node args to prepend backward outputs before forward outputs
93+
multiplexed_graph_op_node.args = (tuple(bw_op_node_args + fw_op_node_args),)
10194

10295
multiplexed_gm.graph.eliminate_dead_code()
10396
multiplexed_gm.graph.lint()

autoparallel/api.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -667,10 +667,14 @@ def apply_placement_pp(
667667
assert num_params_buffers == (
668668
num_params + num_buffers
669669
), f"num_params_buffers: {num_params_buffers}, num_params: {num_params}, num_buffers: {num_buffers}"
670+
num_input_grads = (
671+
len(bw_module.graph.find_nodes(op="output")[0].args[0]) - num_params_buffers
672+
)
670673
print(
671674
f"num_params_buffers: {num_params_buffers}\n"
672675
f"num_user_outputs: {num_user_outputs}\n"
673676
f"num_mutate_inputs: {num_mutate_inputs}\n"
677+
f"num_input_grads: {num_input_grads}\n"
674678
f"num_fw_outs_saved_for_bw: {num_fw_outs_saved_for_bw}\n"
675679
f"num_symints_saved_for_bw: {num_symints_saved_for_bw}"
676680
)
@@ -753,7 +757,6 @@ def apply_placement_pp(
753757

754758
bw_dI_module: Optional[torch.fx.GraphModule] = None
755759
bw_dW_module: Optional[torch.fx.GraphModule] = None
756-
num_input_grads = 0
757760
if "split_dI_dW" in graph_passes:
758761
from autoparallel._passes.split_di_dw_graph import split_di_dw_graph
759762

0 commit comments

Comments (0)