Skip to content

Commit 5f52024

Browse files
committed
Update on "Prototype to run AutoParallel PP with Local Tensor"
[ghstack-poisoned]
1 parent 24ddcdd commit 5f52024

File tree

1 file changed

+12
-2
lines changed

1 file changed

+12
-2
lines changed

examples/example_ds3_pp_local_tensor.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,7 @@
6262
UNSHARD,
6363
)
6464
from torch.distributed.pipelining.stage import InputInfo, PipelineStage
65+
from torch.distributed.tensor import DTensor
6566
from torch.distributed.tensor.placement_types import Shard
6667
from torch.export._unlift import _assign_attr
6768
from torch.export.unflatten import _AttrKind
def multi_isend(tensor, dst=None, group=None, tag=0, group_src=None):
    """Post a non-blocking send of *tensor* to this rank's pipeline peer.

    Signature mirrors ``dist.isend`` so it can be swapped in transparently;
    ``dst``/``group``/``tag`` are accepted only for compatibility — the
    actual peer is derived from ``group_src`` and the current ``pp_rank``.
    Returns a ``FakeWork`` placeholder rather than the real work handles.
    """
    assert group_src is not None, "Expected group rank"
    peer = get_pp_peer(pp_rank, group_src)
    print(f"PP peer {group_src} {ctx} multi_isend {peer=}")
    # Plain tensors are promoted to LocalTensor so the per-rank p2p op can
    # fan the send out across simulated ranks; LocalTensors pass through.
    if not isinstance(tensor, LocalTensor):
        tensor = maybe_make_tensor_local(tensor)
    # The per-rank work objects are intentionally discarded: the caller only
    # sees the FakeWork stand-in for the send side.
    local_p2p_op(peer, tensor, dist.isend)
    return FakeWork()
121124

122125
def multi_irecv(tensor, src=None, group=None, tag=0, group_src=None):
    """Post a non-blocking receive into *tensor* from this rank's pipeline peer.

    Signature mirrors ``dist.irecv`` for drop-in use; ``src``/``group``/``tag``
    are accepted only for compatibility — the peer comes from ``group_src``
    and the current ``pp_rank``. Unlike the send path, the per-rank work
    handles are combined and returned so completion can be awaited.
    """
    assert group_src is not None, "Expected group rank"
    peer = get_pp_peer(pp_rank, group_src)
    print(f"PP peer {group_src} {ctx} multi_irecv {peer=}")
    # Receives must target a LocalTensor so each simulated rank has a buffer.
    assert isinstance(tensor, LocalTensor), "Expected LocalTensor"
    return combine_works(
        local_p2p_op(peer, tensor, dist.irecv),
        f"PP peer {group_src} {ctx} multi_irecv {peer=}",
    )
128132

@@ -421,8 +425,6 @@ def shape_inference_output_fn_last_stage():
421425
if run_local:
422426
global _pp_groups
423427
_pp_groups = enumerate_pp_groups(world_mesh["pp"])
424-
# for pp_group_ranks in pp_groups:
425-
# _pp_groups.append(default_pg.split_group(pp_group_ranks))
426428

427429
def run_pp_rank(pp_rank: int):
428430
maybe_local_context = (
@@ -511,6 +513,7 @@ def run_pp_rank(pp_rank: int):
511513
if debug_numerics:
512514
print_rank_by_rank("\n".join(numerics_logs))
513515

516+
# breakpoint()
514517
if run_local:
515518
with LocalRunnerMode(
516519
world_size,
@@ -550,6 +553,13 @@ def maybe_make_tensor_local(
550553
if ltm is None:
551554
return tensor
552555

556+
if isinstance(tensor, LocalTensor):
557+
return tensor
558+
559+
if isinstance(tensor, DTensor):
560+
tensor._local_tensor = maybe_make_tensor_local(tensor._local_tensor, ltm)
561+
return tensor
562+
553563
local_tensor = ltm.rank_map(lambda r: tensor.clone().detach())
554564
local_tensor.requires_grad = tensor.requires_grad
555565
return local_tensor

0 commit comments

Comments
 (0)