|
62 | 62 | UNSHARD, |
63 | 63 | ) |
64 | 64 | from torch.distributed.pipelining.stage import InputInfo, PipelineStage |
| 65 | +from torch.distributed.tensor import DTensor |
65 | 66 | from torch.distributed.tensor.placement_types import Shard |
66 | 67 | from torch.export._unlift import _assign_attr |
67 | 68 | from torch.export.unflatten import _AttrKind |
@@ -116,13 +117,16 @@ def multi_isend(tensor, dst=None, group=None, tag=0, group_src=None): |
116 | 117 |     assert group_src is not None, "Expected group rank"
117 | 118 |     peer = get_pp_peer(pp_rank, group_src)
118 | 119 |     print(f"PP peer {group_src} {ctx} multi_isend {peer=}")
 | 120 | +    if not isinstance(tensor, LocalTensor):
 | 121 | +        tensor = maybe_make_tensor_local(tensor)
119 | 122 |     works = local_p2p_op(peer, tensor, dist.isend)
120 | 123 |     return FakeWork()
121 | 124 |
|
122 | 125 | def multi_irecv(tensor, src=None, group=None, tag=0, group_src=None): |
123 | 126 |     assert group_src is not None, "Expected group rank"
124 | 127 |     peer = get_pp_peer(pp_rank, group_src)
125 | 128 |     print(f"PP peer {group_src} {ctx} multi_irecv {peer=}")
 | 129 | +    assert isinstance(tensor, LocalTensor), "Expected LocalTensor"
126 | 130 |     works = local_p2p_op(peer, tensor, dist.irecv)
127 | 131 |     return combine_works(works, f"PP peer {group_src} {ctx} multi_irecv {peer=}")
128 | 132 |
|
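The two p2p wrappers now enforce an asymmetric contract: the send side silently wraps a plain tensor into a LocalTensor via `maybe_make_tensor_local`, while the receive side requires the caller to already hand in a LocalTensor buffer. A minimal, self-contained sketch of that contract, using a hypothetical `LocalTensorish` stand-in rather than the real LocalTensor class:

```python
import torch


class LocalTensorish:
    """Hypothetical stand-in for LocalTensor: one copy of the data per simulated rank."""

    def __init__(self, per_rank):
        self.per_rank = per_rank  # dict: rank -> torch.Tensor


def send_side(tensor, ranks):
    # Mirrors multi_isend: a plain tensor is wrapped so every simulated rank
    # holds its own copy before the local p2p op runs.
    if not isinstance(tensor, LocalTensorish):
        tensor = LocalTensorish({r: tensor.clone().detach() for r in ranks})
    return tensor


def recv_side(buffer):
    # Mirrors multi_irecv: the destination buffer must already be per-rank.
    assert isinstance(buffer, LocalTensorish), "Expected LocalTensor"
    return buffer


t = torch.randn(4)
lt = send_side(t, ranks=range(2))
recv_side(lt)    # fine: already per-rank
# recv_side(t)   # would fail the assert, as in multi_irecv
```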
@@ -421,8 +425,6 @@ def shape_inference_output_fn_last_stage(): |
421 | 425 | if run_local: |
422 | 426 | global _pp_groups |
423 | 427 | _pp_groups = enumerate_pp_groups(world_mesh["pp"]) |
424 | | - # for pp_group_ranks in pp_groups: |
425 | | - # _pp_groups.append(default_pg.split_group(pp_group_ranks)) |
426 | 428 |
|
427 | 429 | def run_pp_rank(pp_rank: int): |
428 | 430 | maybe_local_context = ( |
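For context, `_pp_groups` is now populated only from `enumerate_pp_groups(world_mesh["pp"])`, which presumably yields one group of global ranks per pipeline; the commented-out alternative of materializing process groups with `default_pg.split_group` is dropped. Below is a hypothetical reconstruction of that enumeration under an assumed (pp, dp) rank layout with pp as the outer mesh dimension; the real `enumerate_pp_groups` is defined elsewhere in this file and may differ:

```python
def enumerate_pp_groups_sketch(world_size: int, pp_size: int):
    dp_size = world_size // pp_size
    # One pipeline per dp index: the ranks that share a dp coordinate but
    # sit at different pp stages.
    return [list(range(dp, world_size, dp_size)) for dp in range(dp_size)]


# world_size=8, pp_size=4 -> [[0, 2, 4, 6], [1, 3, 5, 7]]
print(enumerate_pp_groups_sketch(8, 4))
```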
@@ -511,6 +513,7 @@ def run_pp_rank(pp_rank: int): |
511 | 513 | if debug_numerics: |
512 | 514 | print_rank_by_rank("\n".join(numerics_logs)) |
513 | 515 |
|
| 516 | + # breakpoint() |
514 | 517 | if run_local: |
515 | 518 | with LocalRunnerMode( |
516 | 519 | world_size, |
@@ -550,6 +553,13 @@ def maybe_make_tensor_local( |
550 | 553 |     if ltm is None:
551 | 554 |         return tensor
552 | 555 |
|
 | 556 | +    if isinstance(tensor, LocalTensor):
 | 557 | +        return tensor
 | 558 | +
 | 559 | +    if isinstance(tensor, DTensor):
 | 560 | +        tensor._local_tensor = maybe_make_tensor_local(tensor._local_tensor, ltm)
 | 561 | +        return tensor
 | 562 | +
553 | 563 |     local_tensor = ltm.rank_map(lambda r: tensor.clone().detach())
554 | 564 |     local_tensor.requires_grad = tensor.requires_grad
555 | 565 |     return local_tensor
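The new early-exit branches give `maybe_make_tensor_local` a clear dispatch order: LocalTensors pass through untouched, DTensors have their inner `_local_tensor` converted in place, and only a plain tensor is fanned out per rank via `ltm.rank_map`. A self-contained sketch of that dispatch with hypothetical stand-ins (`FakeLocal`, `FakeDTensor`, `FakeLTM` are illustrations, not the real classes):

```python
import torch


class FakeLocal:
    """Hypothetical stand-in for LocalTensor: per-rank copies of the data."""

    def __init__(self, per_rank):
        self.per_rank = per_rank


class FakeDTensor:
    """Hypothetical stand-in for DTensor; only the _local_tensor slot matters here."""

    def __init__(self, local):
        self._local_tensor = local


class FakeLTM:
    """Hypothetical stand-in for the local-tensor-mode object providing rank_map."""

    def __init__(self, ranks):
        self.ranks = list(ranks)

    def rank_map(self, fn):
        # One independent tensor per simulated rank, like ltm.rank_map above.
        return FakeLocal({r: fn(r) for r in self.ranks})


def make_local_sketch(tensor, ltm):
    if ltm is None:
        return tensor
    if isinstance(tensor, FakeLocal):    # already per-rank: leave untouched
        return tensor
    if isinstance(tensor, FakeDTensor):  # convert the wrapped local shard in place
        tensor._local_tensor = make_local_sketch(tensor._local_tensor, ltm)
        return tensor
    return ltm.rank_map(lambda r: tensor.clone().detach())


ltm = FakeLTM(range(2))
assert isinstance(make_local_sketch(torch.randn(3), ltm), FakeLocal)
dt = FakeDTensor(torch.randn(3))
assert isinstance(make_local_sketch(dt, ltm)._local_tensor, FakeLocal)
```

Note that the DTensor branch mutates the wrapper in place and returns the same object, so callers holding a reference to the original DTensor see the converted shard as well.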