Skip to content

Commit e0801c5

Browse files
committed
Make MoE axis_name patchable
stack-info: PR: #275, branch: xmfan/stack/26
1 parent 397b7a6 commit e0801c5

File tree

1 file changed

+12
-7
lines changed
  • autoparallel/_testing/models

1 file changed

+12
-7
lines changed

autoparallel/_testing/models/dsv3.py

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -699,8 +699,8 @@ def local_mapped_region(
699699
out: torch.Tensor,
700700
top_k: int,
701701
num_experts: int,
702+
axis_name: str,
702703
) -> tuple[torch.Tensor, torch.Tensor]:
703-
axis_name = "ep"
704704
# assert False, f"{x.shape}, {selected_experts_indices.shape}, {top_scores.shape}, {out.shape}"
705705

706706
dim = x.shape[-1]
@@ -875,9 +875,10 @@ def _moe_forward(
875875
shared_w1: torch.Tensor,
876876
shared_w3: torch.Tensor,
877877
shared_w2: torch.Tensor,
878-
router: TokenChoiceTopKRouter, # None
879-
reorderer: TokenReorderer, # None
880-
mesh: Optional[DeviceMesh], # None
878+
router: TokenChoiceTopKRouter,
879+
reorderer: TokenReorderer,
880+
mesh: Optional[DeviceMesh],
881+
axis_name: str,
881882
):
882883
# x: 64, 2048, 256
883884
bs, slen, dim = x.shape
@@ -944,6 +945,7 @@ def _moe_forward(
944945
(Shard(0), Shard(0)),
945946
None,
946947
None,
948+
None,
947949
)
948950

949951
# assert False, f"{x.shape}, {selected_experts_indices.shape}, {top_scores.shape}, {out.shape}"
@@ -965,6 +967,7 @@ def _moe_forward(
965967
out,
966968
router.top_k,
967969
router.num_experts,
970+
axis_name,
968971
)
969972
# assert False, f"there: {out.shape}, {num_tokens_per_expert.shape}"
970973

@@ -1010,6 +1013,7 @@ def __init__(self, moe_args: MoEArgs, dim: int, hidden_dim: int):
10101013

10111014
num_experts = moe_args.num_experts
10121015
self.mesh = moe_args.mesh
1016+
self.axis_name = "ep"
10131017
self.experts = GroupedExperts(
10141018
dim=dim,
10151019
hidden_dim=hidden_dim,
@@ -1072,9 +1076,10 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
10721076
shared_w1,
10731077
shared_w3,
10741078
shared_w2,
1075-
self.router, # None
1076-
self.reorderer, # None
1077-
self.mesh, # None
1079+
self.router,
1080+
self.reorderer,
1081+
self.mesh,
1082+
self.axis_name,
10781083
)
10791084

10801085
# HOPs don't support buffer mutations, keep this outside

0 commit comments

Comments
 (0)