Avoid nn.Linear decomposition by replacing view -> mm -> view with einsum (#26)
* [WIP] Replace view -> mm -> view with matmul
This tries to support CP-style sharding by overcoming a limitation of DTensor. It doesn't work yet because _mm_strategy is failing.
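A minimal sketch of the pattern this PR targets: the view -> mm -> view chain flattens the batch dimensions, which loses the dimension a CP-style sharding lives on, while a single matmul/einsum keeps those dimensions intact. The helper names below are illustrative, not the PR's code.

```python
import torch

def linear_via_view_mm_view(x: torch.Tensor, w: torch.Tensor) -> torch.Tensor:
    # x: (*batch, in_features), w: (in_features, out_features)
    batch_shape = x.shape[:-1]
    out = x.reshape(-1, x.shape[-1]).mm(w)           # view -> mm
    return out.reshape(*batch_shape, w.shape[-1])    # -> view

def linear_via_matmul(x: torch.Tensor, w: torch.Tensor) -> torch.Tensor:
    # A single op that keeps the batch/sequence dims, so a sharding on them
    # is not flattened away by the reshape round-trip.
    return torch.matmul(x, w)

x = torch.randn(2, 8, 16)   # (batch, seq, in_features)
w = torch.randn(16, 32)
assert torch.allclose(linear_via_view_mm_view(x, w), linear_via_matmul(x, w), atol=1e-6)
```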
* Fix matmul propagation rule
Some things are starting to work, but we are not there yet.
* Move function to graph_utils.py
* Pull improvements from #29
* Fix equation for einsum
* Cleanup code now that PyTorch has fixed _gen_einsum_strategies
Requires pytorch/pytorch#157593
* Generalize to more than 3d
* Generalize backward pass as well and make everything call into einsum
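A hedged sketch of what "make everything call into einsum" can look like for an N-d input: the equation is built from the input's rank, and the backward gradients are themselves einsums, so the same rule covers forward and backward. The function names are assumptions for illustration, not the PR's helpers.

```python
import string
import torch

def linear_einsum(x: torch.Tensor, w: torch.Tensor) -> torch.Tensor:
    # Build e.g. "aby,yz->abz" for a 3-d input, "abcy,yz->abcz" for 4-d, etc.
    batch = string.ascii_lowercase[: x.ndim - 1]
    return torch.einsum(f"{batch}y,yz->{batch}z", x, w)

def linear_einsum_backward(grad_out: torch.Tensor, x: torch.Tensor, w: torch.Tensor):
    # Both gradients are einsums with the same index convention as the forward.
    batch = string.ascii_lowercase[: x.ndim - 1]
    grad_x = torch.einsum(f"{batch}z,yz->{batch}y", grad_out, w)
    grad_w = torch.einsum(f"{batch}y,{batch}z->yz", x, grad_out)
    return grad_x, grad_w
```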
* Add note about future work
* Add einsum flops and generalize creation of sharded tensors
Before this, if we had a list of tensors, the tensors inside the list would not be sharded.
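For the FLOP accounting, a common rule of thumb for a two-operand einsum is 2 x the product of the sizes of all distinct indices in the equation. The sketch below follows that rule (no ellipsis or repeated-index handling) and is an assumption about the bookkeeping, not the PR's implementation.

```python
def einsum_flops(eq: str, *shapes) -> int:
    # 2 * prod(size of every distinct index): one multiply + one add per
    # element of the full iteration space.
    lhs, _ = eq.split("->")
    sizes = {}
    for term, shape in zip(lhs.split(","), shapes):
        for letter, dim in zip(term, shape):
            sizes[letter] = dim
    total = 1
    for dim in sizes.values():
        total *= dim
    return 2 * total

# e.g. "aby,yz->abz" with shapes (2, 8, 16) and (16, 32)
assert einsum_flops("aby,yz->abz", (2, 8, 16), (16, 32)) == 2 * 2 * 8 * 16 * 32
```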
* Disable erroneous sdpa rule from backward
* Account for compute cost in collectives as well
This removes a long-standing hack that told the solver that S(1) -> R is more expensive than S(0) -> R because of the additional data movement it requires.
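Illustrative shape of the idea: charge each redistribution its communication time plus the local data movement it triggers, so the solver sees S(1) -> R as costlier than S(0) -> R without a special case. The names and the cost formula below are assumptions, not the PR's cost model.

```python
def allgather_cost(bytes_per_rank: float, world_size: int,
                   net_bw: float, mem_bw: float, shard_dim: int) -> float:
    # Communication term: ring all-gather moves (world_size - 1) chunks.
    comm = bytes_per_rank * (world_size - 1) / net_bw
    # Gathering Shard(0) just stacks contiguous chunks; gathering a non-zero
    # shard dim needs an extra pass over the output to re-interleave them.
    compute = bytes_per_rank * world_size / mem_bw if shard_dim != 0 else 0.0
    return comm + compute
```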
* Support getitem as well
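One plausible reading of this commit, sketched below: when a node in the traced graph returns a tuple or list, per-element information has to be read off the operator.getitem users of that node. This helper is hypothetical, not the PR's code.

```python
import operator

def getitem_users(node):
    # Map tuple-output index -> the getitem node that consumes that element,
    # so per-element decisions (sharding, cost) can be attached to it.
    return {
        user.args[1]: user
        for user in node.users
        if user.op == "call_function" and user.target is operator.getitem
    }
```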
* Improve comments and assume 80% efficiency
* Assume 70% efficiency for comms
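These two commits amount to derating peak throughput in the cost model (80% at first, then 70% for comms). A minimal sketch of such a derating; only the constant comes from the commit messages, everything else is illustrative.

```python
# Effective bandwidth = peak * efficiency factor, per these commits.
COMM_EFFICIENCY = 0.7

def comm_time(bytes_moved: float, peak_bandwidth: float) -> float:
    return bytes_moved / (peak_bandwidth * COMM_EFFICIENCY)
```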
* Add comment and set it to false by default
* Revert changes from another PR
* Add spaces back