@@ -699,8 +699,8 @@ def local_mapped_region(
     out: torch.Tensor,
     top_k: int,
     num_experts: int,
+    axis_name: str,
 ) -> tuple[torch.Tensor, torch.Tensor]:
-    axis_name = "ep"
     # assert False, f"{x.shape}, {selected_experts_indices.shape}, {top_scores.shape}, {out.shape}"
 
     dim = x.shape[-1]
@@ -875,9 +875,10 @@ def _moe_forward(
     shared_w1: torch.Tensor,
     shared_w3: torch.Tensor,
     shared_w2: torch.Tensor,
-    router: TokenChoiceTopKRouter,  # None
-    reorderer: TokenReorderer,  # None
-    mesh: Optional[DeviceMesh],  # None
+    router: TokenChoiceTopKRouter,
+    reorderer: TokenReorderer,
+    mesh: Optional[DeviceMesh],
+    axis_name: str,
 ):
     # x: 64, 2048, 256
     bs, slen, dim = x.shape
@@ -944,6 +945,7 @@ def _moe_forward(
         (Shard(0), Shard(0)),
         None,
         None,
+        None,
     )
 
     # assert False, f"{x.shape}, {selected_experts_indices.shape}, {top_scores.shape}, {out.shape}"
@@ -965,6 +967,7 @@ def _moe_forward(
         out,
         router.top_k,
         router.num_experts,
+        axis_name,
     )
     # assert False, f"there: {out.shape}, {num_tokens_per_expert.shape}"
 
@@ -1010,6 +1013,7 @@ def __init__(self, moe_args: MoEArgs, dim: int, hidden_dim: int):
 
         num_experts = moe_args.num_experts
         self.mesh = moe_args.mesh
+        self.axis_name = "ep"
         self.experts = GroupedExperts(
             dim=dim,
             hidden_dim=hidden_dim,
@@ -1072,9 +1076,10 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
             shared_w1,
             shared_w3,
             shared_w2,
-            self.router,  # None
-            self.reorderer,  # None
-            self.mesh,  # None
+            self.router,
+            self.reorderer,
+            self.mesh,
+            self.axis_name,
         )
 
         # HOPs don't support buffer mutations, keep this outside
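
For context, the shape of this change: the mesh-axis name used for expert parallelism is no longer hardcoded inside the local-mapped region; the caller threads it from the module (self.axis_name = "ep") through _moe_forward and into the region as a plain, non-tensor argument, with an extra None added to the corresponding placement tuple. Below is a minimal, self-contained sketch of that pattern, not this repo's code: it assumes a recent PyTorch with torch.distributed.tensor.experimental.local_map (the None slot added to the placement tuple in the diff suggests this API), and the region function, the single-rank gloo setup, and the toy tensor are made up for illustration.

# Hypothetical sketch of the pattern in this diff (not the actual repo code):
# the mesh-axis name reaches the mapped region as a regular Python argument,
# and its slot in in_placements is None because it is not a tensor.
import os

import torch
import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import Shard, distribute_tensor
from torch.distributed.tensor.experimental import local_map

os.environ.setdefault("MASTER_ADDR", "localhost")
os.environ.setdefault("MASTER_PORT", "29501")
dist.init_process_group("gloo", rank=0, world_size=1)
mesh = init_device_mesh("cpu", (1,), mesh_dim_names=("ep",))


def region(x: torch.Tensor, axis_name: str) -> torch.Tensor:
    # Inside the region, x is the local shard; axis_name is an ordinary string
    # (e.g. used to look up a process group or mesh dimension by name).
    return x * 2


mapped = local_map(
    region,
    out_placements=[Shard(0)],
    in_placements=([Shard(0)], None),  # None marks the non-tensor axis_name slot
)

x = distribute_tensor(torch.randn(8, 4), mesh, [Shard(0)])
y = mapped(x, "ep")  # the axis name is supplied by the caller, not hardcoded
print(y.full_tensor())
dist.destroy_process_group()

Keeping the axis name out of the mapped region and passing it in from the owning module means the region carries no parallelism-specific constants, so the same code can run under a differently named mesh dimension.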