diff --git a/autoparallel/_testing/models/dsv3.py b/autoparallel/_testing/models/dsv3.py
index efb4352..e985b2c 100644
--- a/autoparallel/_testing/models/dsv3.py
+++ b/autoparallel/_testing/models/dsv3.py
@@ -702,8 +702,8 @@ def local_mapped_region(
     out: torch.Tensor,
     top_k: int,
     num_experts: int,
+    axis_name: str,
 ) -> tuple[torch.Tensor, torch.Tensor]:
-    axis_name = "ep"
     # assert False, f"{x.shape}, {selected_experts_indices.shape}, {top_scores.shape}, {out.shape}"
     dim = x.shape[-1]
 
@@ -878,9 +878,10 @@ def _moe_forward(
     shared_w1: torch.Tensor,
     shared_w3: torch.Tensor,
     shared_w2: torch.Tensor,
-    router: TokenChoiceTopKRouter,  # None
-    reorderer: TokenReorderer,  # None
-    mesh: Optional[DeviceMesh],  # None
+    router: TokenChoiceTopKRouter,
+    reorderer: TokenReorderer,
+    mesh: Optional[DeviceMesh],
+    axis_name: str,
 ):
     # x: 64, 2048, 256
     bs, slen, dim = x.shape
@@ -947,6 +948,7 @@ def _moe_forward(
             (Shard(0), Shard(0)),
             None,
             None,
+            None,
         )
 
         # assert False, f"{x.shape}, {selected_experts_indices.shape}, {top_scores.shape}, {out.shape}"
@@ -968,6 +970,7 @@ def _moe_forward(
             out,
             router.top_k,
             router.num_experts,
+            axis_name,
         )
 
         # assert False, f"there: {out.shape}, {num_tokens_per_expert.shape}"
@@ -1013,6 +1016,7 @@ def __init__(self, moe_args: MoEArgs, dim: int, hidden_dim: int):
         num_experts = moe_args.num_experts
 
         self.mesh = moe_args.mesh
+        self.axis_name = "ep"
         self.experts = GroupedExperts(
             dim=dim,
             hidden_dim=hidden_dim,
@@ -1075,9 +1079,10 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
             shared_w1,
             shared_w3,
             shared_w2,
-            self.router,  # None
-            self.reorderer,  # None
-            self.mesh,  # None
+            self.router,
+            self.reorderer,
+            self.mesh,
+            self.axis_name,
         )
 
         # HOPs don't support buffer mutations, keep this outside
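
The change above threads the expert-parallel mesh-axis name ("ep") from the MoE module's constructor down through _moe_forward into the local-mapped region, instead of hardcoding it inside that function. As a rough illustration of the same plumbing pattern only, here is a minimal sketch with no distributed setup; ToyMoE and _forward_impl are hypothetical stand-ins, not the real dsv3.py classes or functions.

# Minimal sketch (assumption: illustrative names, plain torch, no DeviceMesh):
# the axis name is configured once on the module and passed down explicitly,
# mirroring how forward() now forwards self.axis_name into _moe_forward.
import torch
import torch.nn as nn


def _forward_impl(x: torch.Tensor, w: torch.Tensor, axis_name: str) -> torch.Tensor:
    # In the real code the axis name selects which mesh dimension the
    # expert-parallel collectives run over; here we only show that it
    # arrives as an argument rather than a constant baked into the body.
    assert axis_name, "expected a non-empty mesh-axis name"
    return x @ w


class ToyMoE(nn.Module):
    def __init__(self, dim: int, axis_name: str = "ep") -> None:
        super().__init__()
        self.axis_name = axis_name  # set once in __init__, like self.axis_name = "ep" above
        self.w = nn.Parameter(torch.randn(dim, dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Threaded down as data, so the mapped region stays mesh-layout agnostic.
        return _forward_impl(x, self.w, self.axis_name)


if __name__ == "__main__":
    moe = ToyMoE(dim=8)
    print(moe(torch.randn(2, 8)).shape)  # torch.Size([2, 8])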