Skip to content

Commit 2790b15

Browse files
committed
Make MoE axis_name patcheable
stack-info: PR: #275, branch: xmfan/stack/26
1 parent 5b91ac1 commit 2790b15

File tree

1 file changed

+12
-7
lines changed
  • autoparallel/_testing/models

1 file changed

+12
-7
lines changed

autoparallel/_testing/models/dsv3.py

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -702,8 +702,8 @@ def local_mapped_region(
702702
out: torch.Tensor,
703703
top_k: int,
704704
num_experts: int,
705+
axis_name: str,
705706
) -> tuple[torch.Tensor, torch.Tensor]:
706-
axis_name = "ep"
707707
# assert False, f"{x.shape}, {selected_experts_indices.shape}, {top_scores.shape}, {out.shape}"
708708

709709
dim = x.shape[-1]
@@ -878,9 +878,10 @@ def _moe_forward(
878878
shared_w1: torch.Tensor,
879879
shared_w3: torch.Tensor,
880880
shared_w2: torch.Tensor,
881-
router: TokenChoiceTopKRouter, # None
882-
reorderer: TokenReorderer, # None
883-
mesh: Optional[DeviceMesh], # None
881+
router: TokenChoiceTopKRouter,
882+
reorderer: TokenReorderer,
883+
mesh: Optional[DeviceMesh],
884+
axis_name: str,
884885
):
885886
# x: 64, 2048, 256
886887
bs, slen, dim = x.shape
@@ -947,6 +948,7 @@ def _moe_forward(
947948
(Shard(0), Shard(0)),
948949
None,
949950
None,
951+
None,
950952
)
951953

952954
# assert False, f"{x.shape}, {selected_experts_indices.shape}, {top_scores.shape}, {out.shape}"
@@ -968,6 +970,7 @@ def _moe_forward(
968970
out,
969971
router.top_k,
970972
router.num_experts,
973+
axis_name,
971974
)
972975
# assert False, f"there: {out.shape}, {num_tokens_per_expert.shape}"
973976

@@ -1013,6 +1016,7 @@ def __init__(self, moe_args: MoEArgs, dim: int, hidden_dim: int):
10131016

10141017
num_experts = moe_args.num_experts
10151018
self.mesh = moe_args.mesh
1019+
self.axis_name = "ep"
10161020
self.experts = GroupedExperts(
10171021
dim=dim,
10181022
hidden_dim=hidden_dim,
@@ -1075,9 +1079,10 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
10751079
shared_w1,
10761080
shared_w3,
10771081
shared_w2,
1078-
self.router, # None
1079-
self.reorderer, # None
1080-
self.mesh, # None
1082+
self.router,
1083+
self.reorderer,
1084+
self.mesh,
1085+
self.axis_name,
10811086
)
10821087

10831088
# HOPs don't support buffer mutations, keep this outside

0 commit comments

Comments
 (0)