@@ -702,8 +702,8 @@ def local_mapped_region(
     out: torch.Tensor,
     top_k: int,
     num_experts: int,
+    axis_name: str,
 ) -> tuple[torch.Tensor, torch.Tensor]:
-    axis_name = "ep"
     # assert False, f"{x.shape}, {selected_experts_indices.shape}, {top_scores.shape}, {out.shape}"

     dim = x.shape[-1]
@@ -878,9 +878,10 @@ def _moe_forward(
     shared_w1: torch.Tensor,
     shared_w3: torch.Tensor,
     shared_w2: torch.Tensor,
-    router: TokenChoiceTopKRouter,  # None
-    reorderer: TokenReorderer,  # None
-    mesh: Optional[DeviceMesh],  # None
+    router: TokenChoiceTopKRouter,
+    reorderer: TokenReorderer,
+    mesh: Optional[DeviceMesh],
+    axis_name: str,
 ):
     # x: 64, 2048, 256
     bs, slen, dim = x.shape
@@ -947,6 +948,7 @@ def _moe_forward(
         (Shard(0), Shard(0)),
         None,
         None,
+        None,
     )

     # assert False, f"{x.shape}, {selected_experts_indices.shape}, {top_scores.shape}, {out.shape}"
@@ -968,6 +970,7 @@ def _moe_forward(
         out,
         router.top_k,
         router.num_experts,
+        axis_name,
     )
     # assert False, f"there: {out.shape}, {num_tokens_per_expert.shape}"

@@ -1013,6 +1016,7 @@ def __init__(self, moe_args: MoEArgs, dim: int, hidden_dim: int):

         num_experts = moe_args.num_experts
         self.mesh = moe_args.mesh
+        self.axis_name = "ep"
         self.experts = GroupedExperts(
             dim=dim,
             hidden_dim=hidden_dim,
@@ -1075,9 +1079,10 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
             shared_w1,
             shared_w3,
             shared_w2,
-            self.router,  # None
-            self.reorderer,  # None
-            self.mesh,  # None
+            self.router,
+            self.reorderer,
+            self.mesh,
+            self.axis_name,
         )

         # HOPs don't support buffer mutations, keep this outside
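For readers skimming the hunks above: the change stops hardcoding the expert-parallel axis name `"ep"` inside `local_mapped_region` and instead threads an `axis_name` argument from the MoE module's `__init__` through `_moe_forward` into the locally mapped region (with a matching extra `None` placement spec for the new argument). A minimal sketch of that threading pattern, using illustrative names rather than the real module, might look like this:

```python
import torch
import torch.nn as nn


def mapped_region(x: torch.Tensor, axis_name: str) -> torch.Tensor:
    # The collective axis is now chosen by the caller instead of being
    # hardcoded to "ep" inside the mapped function.
    return x  # placeholder for the expert dispatch/combine logic


def moe_forward(x: torch.Tensor, axis_name: str) -> torch.Tensor:
    # The intermediate layer simply forwards the axis name it was given.
    return mapped_region(x, axis_name)


class ToyMoE(nn.Module):
    def __init__(self, axis_name: str = "ep") -> None:
        super().__init__()
        # Store the axis name once on the module ...
        self.axis_name = axis_name

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # ... and thread it through every call that needs it.
        return moe_forward(x, self.axis_name)
```

Keeping the name on the module mirrors how `mesh` is already stored from `moe_args`, so a different axis name can be supplied without touching the mapped region itself.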