
Commit d54a6d4

temporarily extend hacky optimizer stuff to make dsv3 ap 1d run again
1 parent: 6cc8caa

File tree

1 file changed: +22 -22 lines

torchtitan/components/optimizer.py

Lines changed: 22 additions & 22 deletions
@@ -340,10 +340,30 @@ def build_optimizers_with_moe_load_balancing(
         ft_manager=ft_manager,
     )
 
+    # AP friendly methods
+    def is_moe_block(block):
+        moe_enabled = getattr(block, "moe_enabled", False)
+        has_moe_submod = hasattr(block, "moe")  # AP
+        return moe_enabled or has_moe_submod
+
+    def should_manual_allreduce(tokens_per_expert_by_layer):
+        return not isinstance(
+            tokens_per_expert_by_layer, torch.distributed.tensor.DTensor
+        )
+
+    def get_transformer_blocks(model_part):
+        if isinstance(model_part.layers, nn.ModuleDict):
+            # regular torchtitan
+            blocks = model_part.layers.values()
+        else:
+            # TODO: fix autoparallel to preserve the module dict
+            blocks = model_part.layers.children()
+        return blocks
+
     def _should_register_moe_balancing_hook(model_parts: list[nn.Module]) -> bool:
         for model_part in model_parts:
-            for transformer_block in model_part.layers.values():
-                if transformer_block.moe_enabled:
+            for transformer_block in get_transformer_blocks(model_part):
+                if is_moe_block(transformer_block):
                     # Assumption: load_balance_coeff is set universally on all moe blocks.
                     return bool(transformer_block.moe.load_balance_coeff)
         return False
@@ -360,26 +380,6 @@ def _update_expert_bias(
             parallel_dims.world_mesh["dp_cp"] if parallel_dims.dp_cp_enabled else None
         )
 
-        # AP friendly methods
-        def is_moe_block(block):
-            moe_enabled = getattr(block, "moe_enabled", False)
-            has_moe_submod = hasattr(block, "moe")  # AP
-            return moe_enabled or has_moe_submod
-
-        def get_transformer_blocks(model_part):
-            if isinstance(model_part.layers, nn.ModuleDict):
-                # regular torchtitan
-                blocks = model_part.layers.values()
-            else:
-                # TODO: fix autoparallel to preserve the module dict
-                blocks = model_part.layers.children()
-            return blocks
-
-        def should_manual_allreduce(tokens_per_expert_by_layer):
-            return not isinstance(
-                tokens_per_expert_by_layer, torch.distributed.tensor.DTensor
-            )
-
         # TODO: Currently this sync is blocking (thus exposed) and happens on the
         # default compute stream. Need to assess if this is OK performance-wise.
         tokens_per_expert_list = []
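
For context, here is a minimal, self-contained sketch (not part of the commit) of the duck-typing the hoisted helpers rely on: regular torchtitan keeps transformer blocks in an nn.ModuleDict and flags MoE blocks with moe_enabled, while the autoparallel (AP) transform may replace that container and drop the flag, leaving only the moe submodule. The TinyMoE, TitanBlock, and APBlock classes below are hypothetical stand-ins for illustration, not torchtitan code; the two helpers mirror the ones in the diff above.

# Illustrative only: TinyMoE, TitanBlock and APBlock are made-up stand-ins.
import torch.nn as nn


class TinyMoE(nn.Module):
    """Placeholder MoE submodule carrying a load-balance coefficient."""

    load_balance_coeff = 1e-3


class TitanBlock(nn.Module):
    """Regular torchtitan-style block: advertises MoE via an explicit flag."""

    def __init__(self, moe: bool):
        super().__init__()
        self.moe_enabled = moe
        if moe:
            self.moe = TinyMoE()


class APBlock(nn.Module):
    """AP-transformed block: no flag survives, only the `moe` submodule."""

    def __init__(self, moe: bool):
        super().__init__()
        if moe:
            self.moe = TinyMoE()


def is_moe_block(block):
    # Works for both layouts: use the flag if present, otherwise duck-type
    # on the presence of the `moe` submodule.
    return getattr(block, "moe_enabled", False) or hasattr(block, "moe")


def get_transformer_blocks(model_part):
    if isinstance(model_part.layers, nn.ModuleDict):
        return model_part.layers.values()  # regular torchtitan container
    return model_part.layers.children()  # AP did not preserve the ModuleDict


# Regular torchtitan layout: blocks live in an nn.ModuleDict keyed by layer id.
titan_model = nn.Module()
titan_model.layers = nn.ModuleDict({"0": TitanBlock(False), "1": TitanBlock(True)})

# AP-style layout: the ModuleDict was replaced by a generic container.
ap_model = nn.Module()
ap_model.layers = nn.Sequential(APBlock(False), APBlock(True))

for model in (titan_model, ap_model):
    print([is_moe_block(block) for block in get_transformer_blocks(model)])
# Both layouts print: [False, True]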
