@@ -340,10 +340,30 @@ def build_optimizers_with_moe_load_balancing(
         ft_manager=ft_manager,
     )

+    # AP friendly methods
+    def is_moe_block(block):
+        moe_enabled = getattr(block, "moe_enabled", False)
+        has_moe_submod = hasattr(block, "moe")  # AP
+        return moe_enabled or has_moe_submod
+
+    def should_manual_allreduce(tokens_per_expert_by_layer):
+        return not isinstance(
+            tokens_per_expert_by_layer, torch.distributed.tensor.DTensor
+        )
+
+    def get_transformer_blocks(model_part):
+        if isinstance(model_part.layers, nn.ModuleDict):
+            # regular torchtitan
+            blocks = model_part.layers.values()
+        else:
+            # TODO: fix autoparallel to preserve the module dict
+            blocks = model_part.layers.children()
+        return blocks
+
     def _should_register_moe_balancing_hook(model_parts: list[nn.Module]) -> bool:
         for model_part in model_parts:
-            for transformer_block in model_part.layers.values():
-                if transformer_block.moe_enabled:
+            for transformer_block in get_transformer_blocks(model_part):
+                if is_moe_block(transformer_block):
                     # Assumption: load_balance_coeff is set universally on all moe blocks.
                     return bool(transformer_block.moe.load_balance_coeff)
         return False
@@ -360,26 +380,6 @@ def _update_expert_bias(
             parallel_dims.world_mesh["dp_cp"] if parallel_dims.dp_cp_enabled else None
         )

-        # AP friendly methods
-        def is_moe_block(block):
-            moe_enabled = getattr(block, "moe_enabled", False)
-            has_moe_submod = hasattr(block, "moe")  # AP
-            return moe_enabled or has_moe_submod
-
-        def get_transformer_blocks(model_part):
-            if isinstance(model_part.layers, nn.ModuleDict):
-                # regular torchtitan
-                blocks = model_part.layers.values()
-            else:
-                # TODO: fix autoparallel to preserve the module dict
-                blocks = model_part.layers.children()
-            return blocks
-
-        def should_manual_allreduce(tokens_per_expert_by_layer):
-            return not isinstance(
-                tokens_per_expert_by_layer, torch.distributed.tensor.DTensor
-            )
-
         # TODO: Currently this sync is blocking (thus exposed) and happens on the
         # default compute stream. Need to assess if this is OK performance-wise.
         tokens_per_expert_list = []
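
For context on what these helpers are guarding against, below is a minimal standalone sketch (the `ToyMoE`, `RegularBlock`, `APBlock`, `RegularModel`, and `APModel` classes are made-up stand-ins, not torchtitan modules). It exercises the two model layouts the hook has to handle: regular torchtitan, where `layers` is an `nn.ModuleDict` and each block carries a `moe_enabled` flag, and an autoparallel-transformed model, where the dict is flattened into another container and only the `moe` submodule survives.

```python
import torch.nn as nn


def is_moe_block(block):
    # regular torchtitan blocks carry a moe_enabled flag; AP-transformed
    # blocks may lose that plain attribute but still hold a `moe` submodule
    moe_enabled = getattr(block, "moe_enabled", False)
    has_moe_submod = hasattr(block, "moe")
    return moe_enabled or has_moe_submod


def get_transformer_blocks(model_part):
    if isinstance(model_part.layers, nn.ModuleDict):
        # regular torchtitan keeps layers in a ModuleDict keyed by layer index
        return model_part.layers.values()
    # autoparallel currently does not preserve the ModuleDict,
    # so fall back to iterating whatever children the container holds
    return model_part.layers.children()


class ToyMoE(nn.Module):
    def __init__(self):
        super().__init__()
        self.load_balance_coeff = 1e-3


class RegularBlock(nn.Module):
    def __init__(self):
        super().__init__()
        self.moe = ToyMoE()
        self.moe_enabled = True  # flag present in the regular layout


class APBlock(nn.Module):
    def __init__(self):
        super().__init__()
        self.moe = ToyMoE()  # submodule survives, but no moe_enabled flag


class RegularModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.layers = nn.ModuleDict({"0": RegularBlock()})


class APModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.layers = nn.ModuleList([APBlock()])  # not a ModuleDict


for model in (RegularModel(), APModel()):
    blocks = get_transformer_blocks(model)
    print(type(model).__name__, [is_moe_block(b) for b in blocks])
# RegularModel [True]
# APModel [True]
```

`should_manual_allreduce` follows the same feature-detection pattern: when `tokens_per_expert_by_layer` arrives as a DTensor (the autoparallel case) it returns False and the caller skips the manual all-reduce, presumably because the DTensor path already accounts for the cross-rank reduction; a plain tensor keeps the existing manual all-reduce.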