Commit c680107

Avoid nn.Linear decomposition by replacing view -> mm -> view with einsum (#26)
* [WIP] Replace view -> mm -> view with matmul. This tries to support CP-style sharding by overcoming a limitation of DTensor. Doesn't yet work, as _mm_strategy is failing
* Fix matmul propagation rule. Some things are starting to work, but we are not yet there
* Move function to graph_utils.py
* Pull improvements from #29
* Fix equation for einsum
* Clean up code now that PyTorch has fixed _gen_einsum_strategies. Requires pytorch/pytorch#157593
* Generalize to more than 3d
* Generalize the backward pass as well and make everything call into einsum
* Add note about future work
* Add einsum flops and generalize creation of sharded tensors. Before this, if we had a list of tensors we wouldn't shard the tensors inside the list
* Disable erroneous sdpa rule from backward
* Account for compute cost in collectives as well. This removes a long-standing hack to tell the solver that S(1) -> R is more expensive than S(0) -> R because of an additional data movement
* Support getitem as well
* Improve comments and assume 80% efficiency
* Assume 70% efficiency for comms
* Add comment and set it to false by default
* Revert changes from another PR
* Add spaces back
1 parent 8737c5a commit c680107
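
For context on the rewrite this commit performs, here is a minimal standalone sketch (illustrative only, not part of the diff; the shapes B, S, K, N are made up) of the decomposed pattern and the einsum that replaces it:

# Illustrative sketch: the pattern this commit targets.
# A 3d-input nn.Linear / matmul is decomposed by PyTorch into view -> mm -> view,
# which flattens the leading (batch/sequence) dims and loses any sharding on them.
# Rewriting it as a single einsum keeps the leading dims, so DTensor can keep
# them sharded (e.g. for CP-style sharding).
import torch

B, S, K, N = 2, 8, 16, 32   # illustrative shapes
x = torch.randn(B, S, K)    # activations
w_t = torch.randn(K, N)     # weight, already transposed as mm expects

decomposed = x.view(B * S, K).mm(w_t).view(B, S, N)   # what the traced graph contains today
fused = torch.einsum("abk,kn->abn", x, w_t)           # what the pass rewrites it into
assert torch.allclose(decomposed, fused, atol=1e-5)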

File tree

4 files changed: +133 -1 lines changed


autoparallel/api.py

Lines changed: 5 additions & 0 deletions
@@ -29,6 +29,7 @@
 from .cast_parametrization import apply_dtype_cast, canonicalize_mp, set_dtype_cast
 from .graph_utils import (
     _add_alias,
+    _replace_view_mm_view_with_einsum,
     assert_has_no_collectives,
     cleanup_graph,
     update_joint_with_descriptors,
@@ -37,6 +38,8 @@
 from .optimize_sharding import ShardingOptimizer
 from .utils import _get_device_from_mesh

+_APPLY_VIEW_MM_VIEW_PATTERN = False
+

 def try_convert_fake_to_real(tensors):
     out = {}
@@ -230,6 +233,8 @@ def build_model_graph(self):
         assert_has_no_collectives(gm)

         cleanup_graph(gm)
+        if _APPLY_VIEW_MM_VIEW_PATTERN:
+            _replace_view_mm_view_with_einsum(gm)
         # now add aliases nodes to the graph to
         # give more room for optimizations
         _add_alias(gm)
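
The rewrite is opt-in and disabled by default (see the bullet "Add comment and set it to false by default" above). A hedged usage sketch, assuming the module is importable as autoparallel.api:

# Sketch (not part of the diff): opting in to the view -> mm -> view rewrite.
# The flag name comes from the diff above; the import path is an assumption.
import autoparallel.api as api

api._APPLY_VIEW_MM_VIEW_PATTERN = True  # build_model_graph() then calls _replace_view_mm_view_with_einsum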

autoparallel/compute_estimation.py

Lines changed: 28 additions & 1 deletion
@@ -8,7 +8,34 @@

 import torch
 from torch.utils._pytree import tree_flatten, tree_map_only
-from torch.utils.flop_counter import FlopCounterMode
+from torch.utils.flop_counter import FlopCounterMode, register_flop_formula
+
+
+@register_flop_formula(torch.ops.aten.einsum, get_raw=True)
+def einsum_flop(equation, tensors, out=None, **kwargs) -> int:
+    # from torch.distributed.tensor._ops._einsum_strategy import EinsumDims
+    assert len(tensors) == 2
+    a_shape, b_shape = [x.shape for x in tensors]
+
+    # parse einop equation and extract dims
+    # TODO: generalize
+    # input_dims, output_dim = EinsumDims.parse_equation(equation)
+    # edims = EinsumDims.parse_dims(input_dims, output_dim)
+
+    if len(a_shape) == 3 and len(b_shape) == 3:
+        b, m, k = a_shape
+        b1, n, k2 = b_shape
+        assert b == b1
+        assert m == n
+        flop = (b * m) * k * k2 * 2
+    elif len(a_shape) == 3 and len(b_shape) == 2:
+        b, m, k = a_shape
+        k2, n = b_shape
+        assert k == k2
+        flop = b * m * n * k * 2
+    else:
+        raise NotImplementedError(f"Unsupported einsum shapes: {a_shape} {b_shape}")
+    return flop


 @dataclass
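
As a sanity check of the registered formula (not part of the diff), the 3d x 2d branch reproduces the usual 2·M·N·K count of the flattened matmul it replaces; the shapes below are illustrative:

# Worked example for einsum_flop's 3d x 2d branch ("abk,kn->abn").
B, S, K, N = 2, 8, 16, 32
flops_einsum = B * S * N * K * 2     # b * m * n * k * 2, as in the formula above
flops_flat_mm = (B * S) * N * K * 2  # 2 * M * N * K for the flattened (B*S, K) @ (K, N) mm
assert flops_einsum == flops_flat_mm == 16384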

autoparallel/graph_utils.py

Lines changed: 78 additions & 0 deletions
@@ -153,3 +153,81 @@ def assert_has_no_collectives(gm: torch.fx.GraphModule):
             f"autoparallel.local_map_hop.apply_local_map, see "
             "examples/example_local_map.py for more information."
         )
+
+
+# NOTE: [nn.Linear decomposition]
+# PyTorch currently decomposes any 3d-input nn.Linear (and matmul) into a
+# sequence of view -> mm -> view operations.
+# As a consequence, this breaks any type of sharding on both the batch and
+# the sequence dimension, because the flattening that happens doesn't allow
+# this sharding to be preserved.
+# While we wait for PyTorch to avoid decomposing nn.Linear, we instead take
+# the route of pattern-matching the nn.Linear-specific occurrences, and we
+# replace them with an einsum operator.
+# We perform this pattern-matching replacement for both the forward and the
+# backward pass.
+# TODO: use graph_patterns to simplify writing this
+def _replace_view_mm_view_with_einsum(gm):
+    mm_nodes = gm.graph.find_nodes(op="call_function", target=torch.ops.aten.mm.default)
+    for node in mm_nodes:
+        first_input, second_input = node.all_input_nodes
+        if first_input.target == torch.ops.aten.view.default:
+            view_input = first_input.all_input_nodes[0]
+            users = list(node.users)
+            if (
+                len(users) == 1
+                and users[0].target == torch.ops.aten.view.default
+                and view_input.meta["val"].shape[:-1] == users[0].meta["val"].shape[:-1]
+                and second_input.meta["val"].ndim == 2
+            ):
+                print(
+                    f"Found matmul node {node}, {view_input.meta['val'].shape, second_input.meta['val'].shape}"
+                )
+                ndim = view_input.meta["val"].ndim
+                assert 1 < ndim <= 10, "Only support up to 10D for now"
+
+                # generate the leading dimensions as a, b, c, etc
+                dims = "".join([chr(97 + i) for i in range(ndim - 1)])
+                mm_equation = f"{dims}k,kn->{dims}n"
+                with gm.graph.inserting_before(node):
+                    new_node = gm.graph.call_function(
+                        torch.ops.aten.einsum.default,
+                        args=(mm_equation, [view_input, second_input]),
+                    )
+                new_node.meta.update(users[0].meta)
+                users[0].replace_all_uses_with(new_node)
+
+        elif second_input.target == torch.ops.aten.view.default:
+            if first_input.target != torch.ops.aten.permute.default:
+                continue
+            if first_input.all_input_nodes[0].target != torch.ops.aten.view.default:
+                continue
+            orig_first = first_input.all_input_nodes[0].all_input_nodes[0]
+            orig_second = second_input.all_input_nodes[0]
+            users = list(node.users)
+            if (
+                len(users) == 1
+                and users[0].target == torch.ops.aten.permute.default
+                and orig_first.meta["val"].shape[:-1]
+                == orig_second.meta["val"].shape[:-1]
+                and node.meta["val"].ndim == 2
+            ):
+                print(
+                    f"Found matmul node {node} {orig_first.meta['val'].shape, orig_second.meta['val'].shape}"
+                )
+
+                ndim = orig_first.meta["val"].ndim
+                assert 1 < ndim <= 10, "Only support up to 10D for now"
+
+                # generate the leading dimensions as a, b, c, etc
+                dims = "".join([chr(97 + i) for i in range(ndim - 1)])
+                mm_equation = f"{dims}n,{dims}k->kn"
+                with gm.graph.inserting_before(node):
+                    new_node = gm.graph.call_function(
+                        torch.ops.aten.einsum.default,
+                        args=(mm_equation, [orig_first, orig_second]),
+                    )
+                new_node.meta.update(users[0].meta)
+                users[0].replace_all_uses_with(new_node)
+    gm.graph.eliminate_dead_code()
+    gm.recompile()
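
For reference, a small sketch (not part of the diff; the helper names are hypothetical) of the einsum equations this pass emits: the forward branch names the leading dims a, b, c, ... and contracts over k, while the backward branch contracts away the leading dims to produce a 2d result, as in a weight gradient:

# Sketch of the equation strings generated above; _forward_equation and
# _backward_weight_equation are hypothetical helpers for illustration.
def _forward_equation(ndim: int) -> str:
    dims = "".join(chr(97 + i) for i in range(ndim - 1))
    return f"{dims}k,kn->{dims}n"

def _backward_weight_equation(ndim: int) -> str:
    dims = "".join(chr(97 + i) for i in range(ndim - 1))
    return f"{dims}n,{dims}k->kn"

assert _forward_equation(3) == "abk,kn->abn"      # 3d input: (a, b, k) x (k, n)
assert _forward_equation(4) == "abck,kn->abcn"    # generalizes past 3d
assert _backward_weight_equation(3) == "abn,abk->kn"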

autoparallel/propagation_rules.py

Lines changed: 22 additions & 0 deletions
@@ -761,3 +761,25 @@ def expand_rule(mesh, op_schema_):
     for remov in to_remove:
         ss.redistribute_cost[0].insert(remov, math.inf)
     return out_strat
+
+
+@register_opschema_rule(torch.ops.aten.einsum.default)
+def einsum_rule(mesh, op_schema):
+    from torch.distributed.tensor._op_schema import TupleStrategy
+    from torch.distributed.tensor._ops._matrix_ops import _mm_like_strategy
+
+    mm_equation, mat_strategy = op_schema.args_schema
+    assert isinstance(mm_equation, str)
+    assert isinstance(mat_strategy, TupleStrategy)
+
+    assert len(mat_strategy.children) == 2, "Only two args to einsum supported for now"
+
+    self_strategy, mat2_strategy = mat_strategy.children
+
+    # dispatch to mm_like_strategy
+    new_op_schema = OpSchema(
+        torch.ops.aten.einsum.default,
+        args_schema=(self_strategy, mat2_strategy),
+        kwargs_schema={},
+    )
+    return _mm_like_strategy(mm_equation, mesh, new_op_schema)
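
To illustrate why an mm-like strategy fits these einsum equations, a standalone sketch in plain torch (not DTensor, not part of the diff; shapes are illustrative): sharding a leading dim of the first operand composes without communication, while sharding the contraction dim leaves partial results that would need an all-reduce, which is the kind of cost the sharding optimizer has to account for:

# Sketch: sharding behaviour of "abk,kn->abn", emulated with chunked tensors.
import torch

B, S, K, N = 4, 8, 16, 32
x = torch.randn(B, S, K)
w = torch.randn(K, N)
full = torch.einsum("abk,kn->abn", x, w)

# Shard(0) on x, Replicate on w -> Shard(0) on the output, no communication needed.
shards = [torch.einsum("abk,kn->abn", xs, w) for xs in x.chunk(2, dim=0)]
assert torch.allclose(torch.cat(shards, dim=0), full, atol=1e-5)

# Sharding the contraction dim k on both operands -> partial outputs that must be
# summed (an all-reduce in the distributed setting).
partials = [
    torch.einsum("abk,kn->abn", xk, wk)
    for xk, wk in zip(x.chunk(2, dim=2), w.chunk(2, dim=0))
]
assert torch.allclose(sum(partials), full, atol=1e-4)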
