
Commit 4873cc3

pthombre authored and akoumpa committed
Integrate with StepScheduler and checkpointing code
Signed-off-by: Pranav Prashant Thombre <[email protected]>
1 parent e515e5c commit 4873cc3

File tree

3 files changed: +130 −238 lines changed


examples/diffusion/finetune/wan2_1_t2v_flow.yaml

Lines changed: 9 additions & 3 deletions
@@ -22,7 +22,7 @@ data:
   num_nodes: 1

 batch:
-  batch_size_per_node: 1
+  batch_size_per_node: 8

 training:
   num_epochs: 20
@@ -43,13 +43,19 @@ flow_matching:

 fsdp:
   cpu_offload: true
+  tp_size: 1
+  cp_size: 1
+  pp_size: 1

 logging:
   save_every: 50
   log_every: 2

 checkpoint:
-  output_dir: /opt/Automodel/wan_t2v_flow_outputs_updated/
-  resume: null
+  enabled: true
+  checkpoint_dir: /opt/Automodel/wan_t2v_flow_outputs_base_recipe_checkpoint_NEW_new/
+  model_save_format: torch_save
+  save_consolidated: false
+  restore_from: null

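For reference, the full checkpoint section the patch produces looks as follows (values copied from the diff; the per-key comments are my reading of the option names, not confirmed against the recipe docs):

```yaml
checkpoint:
  enabled: true
  checkpoint_dir: /opt/Automodel/wan_t2v_flow_outputs_base_recipe_checkpoint_NEW_new/
  model_save_format: torch_save   # assumed: save with torch-native serialization
  save_consolidated: false        # assumed: skip writing a single consolidated_model.bin
  restore_from: null              # assumed: set to a checkpoint path to resume training
```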

nemo_automodel/components/_diffusers/utils/validate_t2v.py

Lines changed: 3 additions & 3 deletions
@@ -139,7 +139,7 @@ def main():
     # Try EMA checkpoint first (best quality)
     ema_path = os.path.join(args.checkpoint, "ema_shadow.pt")
     consolidated_path = os.path.join(args.checkpoint, "consolidated_model.bin")
-    sharded_dir = os.path.join(args.checkpoint, "transformer_model")
+    sharded_dir = os.path.join(args.checkpoint, "model")

     if os.path.exists(ema_path):
         print(f"[INFO] Loading EMA checkpoint (best quality)...")
@@ -183,9 +183,9 @@ def main():
     )

     # Load shards into the FSDP-wrapped model
-    model_state = {"model": fsdp_transformer.state_dict()}
+    model_state = fsdp_transformer.state_dict()
     dist_load(state_dict=model_state, storage_reader=FileSystemReader(sharded_dir))
-    fsdp_transformer.load_state_dict(model_state["model"])
+    fsdp_transformer.load_state_dict(model_state)

     # Unwrap back to the original module for inference
     pipe.transformer = fsdp_transformer.module
