28 changes: 28 additions & 0 deletions llama3_8b.py
@@ -0,0 +1,28 @@
import nemo_run as run

from nemo.collections.llm.recipes import llama3_8b

if __name__ == "__main__":
    pretrain = llama3_8b.pretrain_recipe(num_nodes=1, num_gpus_per_node=8, performance_mode=True)

    pretrain.trainer.strategy.context_parallel_size = 1
    pretrain.trainer.log_every_n_steps = 1
    pretrain.data.global_batch_size = 16
    pretrain.data.seq_length = 64
    pretrain.trainer.max_steps = 10

    pretrain.trainer.strategy.fsdp = 'megatron'
    pretrain.trainer.strategy.ddp.average_in_collective = False
    pretrain.trainer.strategy.ddp.use_megatron_fsdp = True
    pretrain.trainer.strategy.save_ckpt_format = 'fsdp_dtensor'
    # pretrain.trainer.strategy.gradient_accumulation_fusion=False

    # Included in performance mode but not in normal mode

    pretrain.trainer.strategy.ddp.grad_reduce_in_fp32 = False
    pretrain.trainer.plugins.grad_reduce_in_fp32 = False
    pretrain.optim.config.use_precision_aware_optimizer = False
    pretrain.optim.config.use_megatron_fsdp = True
    # pretrain.data.seq_length = 4096

    run.run(pretrain)
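As written, the script relies on run.run's default local execution. For comparison, a minimal sketch of driving the same recipe through an explicit executor; it assumes nemo_run's LocalExecutor with the torchrun launcher (as used in the NeMo 2.0 quickstart) and is not part of this change:

# Sketch only: same recipe as above, launched via an explicit LocalExecutor.
import nemo_run as run
from nemo.collections.llm.recipes import llama3_8b

pretrain = llama3_8b.pretrain_recipe(num_nodes=1, num_gpus_per_node=8, performance_mode=True)
executor = run.LocalExecutor(ntasks_per_node=8, launcher="torchrun")  # one rank per GPU
run.run(pretrain, executor=executor)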
2 changes: 1 addition & 1 deletion nemo/collections/llm/recipes/llama3_8b.py
@@ -54,7 +54,7 @@ def model() -> run.Config[pl.LightningModule]:
        >>> model_config = model()
        >>> print(model_config)
    """
-    return run.Config(LlamaModel, config=run.Config(Llama3Config8B))
+    return run.Config(LlamaModel, config=run.Config(Llama3Config8B, gradient_accumulation_fusion=False))


def trainer(
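The one-line change above bakes gradient_accumulation_fusion=False into the default Llama3Config8B returned by model(). A minimal sketch of flipping it back per run, assuming pretrain_recipe() wires model() into pretrain.model (an assumption about the recipe's structure, not something this diff shows):

from nemo.collections.llm.recipes import llama3_8b

pretrain = llama3_8b.pretrain_recipe(num_nodes=1, num_gpus_per_node=8)
# Hypothetical override path; re-enables the fused kernel set off by the recipe change above.
pretrain.model.config.gradient_accumulation_fusion = True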
9 changes: 4 additions & 5 deletions nemo/lightning/pytorch/strategies/megatron_strategy.py
@@ -373,10 +373,8 @@ def __init__(
"Setting FSDP option to megatron"
)
fsdp = 'megatron'
if use_megatron_fsdp and self.save_ckpt_format != "fsdp_dtensor":
raise NotImplementedError(
f"Megatron-FSDP checkpointing is not supported with {self.save_ckpt_format}."
)
if use_megatron_fsdp and self.save_ckpt_format != "fsdp_dtensor":
raise NotImplementedError(f"Megatron-FSDP checkpointing is not supported with {self.save_ckpt_format}.")

if fsdp == "pytorch":
raise NotImplementedError("PyTorch FSDP2 is not supported with MegatronParallel.")
@@ -1052,7 +1050,8 @@ def should_restore_optimizer_states(self, selective_restore: bool = False) -> bool:
    def _save_fsdp_dtensor_common_state(self, state_dict, ckpt_dir):
        state_dict = state_dict.copy()
        del state_dict["model"]
-        del state_dict["optimizer_states"]
+        if "optimizer_states" in state_dict:
+            del state_dict["optimizer_states"]
        torch.save(state_dict, os.path.join(ckpt_dir, "common.pt"))

    def _load_fsdp_dtensor_common_state(self, ckpt_dir):
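For reference, the guarded delete added to _save_fsdp_dtensor_common_state above could also be written with dict.pop, which is a no-op when the key is missing. A sketch of that equivalent form, not the code the patch actually adds:

def strip_sharded_entries(state_dict):
    # Mirrors the patched method body: "model" is always expected and removed outright,
    # while "optimizer_states" may be absent, so it is dropped only if present.
    common_state = state_dict.copy()
    del common_state["model"]
    common_state.pop("optimizer_states", None)  # same effect as the new `if ... in ...: del ...` guard
    return common_state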