Commit 7728f5a
Load optimizer state when appropriate (#140)
load_state() does not actually load optimizer state. The API has been fixed to make loading optimizer state explicit. This commit migrates load_state() callers to use load_state_with_optimizer() when appropriate. Next we'll fix callers of create_training_client_from_state().

Signed-off-by: Daniel Xu <[email protected]>
1 parent 53c6c38 · commit 7728f5a
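
In short, the migration pattern looks like this (a minimal sketch based on the diffs below; `training_client` and the checkpoint paths stand in for whatever your code already has):

```python
# Before this commit: callers used load_state() for resuming,
# but it did not actually load optimizer state.
training_client.load_state(resume_path).result()

# After: the choice is explicit.
# Resume interrupted training (restores optimizer momentum, LR schedule, etc.):
training_client.load_state_with_optimizer(resume_path).result()
# Start a new training phase from saved weights (fresh optimizer):
training_client.load_state(sft_checkpoint_path).result()
```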

File tree

4 files changed (+34, -19):

- llms-full.txt
- pyproject.toml
- tinker_cookbook/distillation/train_on_policy.py
- tinker_cookbook/preference/train_dpo.py

llms-full.txt

Lines changed: 23 additions & 9 deletions

````diff
@@ -607,11 +607,12 @@ We'll start with a couple of general pages that'll be relevant to almost all of
 
 # Saving and loading weights and optimizer state
 
-During training, you'll need to save checkpoints for two main purposes: *sampling* (to test your model) and *resuming training* (to continue from where you left off). The `TrainingClient` provides three methods to handle these cases:
+During training, you'll need to save checkpoints for two main purposes: *sampling* (to test your model) and *resuming training* (to continue from where you left off). The `TrainingClient` provides the following methods to handle these cases:
 
 1. `save_weights_for_sampler()`: saves a copy of the model weights that can be used for sampling.
 2. `save_state()`: saves the weights and the optimizer state. You can fully resume training from this checkpoint.
-3. `load_state()`: load the weights and the optimizer state. You can fully resume training from this checkpoint.
+3. `load_state()`: loads the model weights only (without optimizer state). Use this when you want to start fresh training from a checkpoint, e.g., starting DPO training from an SFT checkpoint.
+4. `load_state_with_optimizer()`: loads the model weights and optimizer state. Use this when resuming interrupted training, as it restores the full training state, including optimizer momentum.
 
 Note that (1) is faster and requires less storage space than (2).
 
@@ -644,24 +645,37 @@ sampling_client = training_client.save_weights_and_get_sampling_client(name="000
 
 ### Example: Saving to resume training
 
-Use `save_state()` and `load_state()` when you need to pause and continue training with full optimizer state preferred:
+Use `save_state()` and `load_state_with_optimizer()` when you need to pause and continue training with full optimizer state:
 
 ```python
 # Save a checkpoint that you can resume from
 resume_path = training_client.save_state(name="0010").result().path
 
-# Load that checkpoint
-training_client.load_state(resume_path)
+# Load that checkpoint with optimizer state (for resuming training)
+training_client.load_state_with_optimizer(resume_path)
 ```
 
-### When to use `save_state()` and `load_state()`:
+Async versions are also available, e.g. `load_state_with_optimizer_async()`.
 
+### Example: Starting fresh from a checkpoint
 
-- Multi-step training pipelines (e.g. supervised learning followed by reinforcement learning)
-- Adjusting hyperparameters or data mid-run
-- Recovery from interruptions or failures
+Use `load_state()` when you want to start a new training phase from saved weights (e.g., starting DPO from an SFT checkpoint):
+
+```python
+# Load weights only, starting with fresh optimizer state
+training_client.load_state(sft_checkpoint_path)
+```
+
+### When to use `load_state_with_optimizer()`:
+
+- Recovery from interruptions or failures (resume training exactly where you left off)
 - Any scenario where you need to preserve exact optimizer state (momentum, learning rate schedules, etc.)
 
+### When to use `load_state()`:
+
+- Multi-step training pipelines (e.g., starting DPO training from an SFT checkpoint)
+- Starting fresh training from pretrained weights with a new optimizer
+
 
 ---
 
````
pyproject.toml

Lines changed: 1 addition & 1 deletion

```diff
@@ -13,7 +13,7 @@ dependencies = [
     "numpy",
     "rich",
     "termcolor",
-    "tinker>=0.3.0",
+    "tinker>=0.6.0",
     "torch",
     "transformers",
     "blobfile",
```

tinker_cookbook/distillation/train_on_policy.py

Lines changed: 1 addition & 1 deletion

```diff
@@ -384,7 +384,7 @@ async def main(
         resume_info["state_path"] if resume_info else cfg.load_checkpoint_path
     )
     if load_state_path:
-        future = await training_client.load_state_async(load_state_path)
+        future = await training_client.load_state_with_optimizer_async(load_state_path)
         _ = await future.result_async()
         logger.info(f"Loaded state from {load_state_path}")
 
```
tinker_cookbook/preference/train_dpo.py

Lines changed: 9 additions & 8 deletions

```diff
@@ -91,14 +91,15 @@ def create_dpo_clients(
         base_model=config.model_name, rank=config.lora_rank
     )
 
-    # Load state first to get the SFT checkpoint path for the reference client
-    load_state_path: str | None = (
-        resume_info["state_path"] if resume_info else config.load_checkpoint_path
-    )
-    if load_state_path:
-        # Load state into the training client
-        training_client.load_state(load_state_path).result()
-        logger.info(f"Loaded weights from {load_state_path}")
+    # Load state - differentiate between resuming DPO training vs starting fresh from SFT
+    if resume_info:
+        # Resuming interrupted DPO training - load optimizer state for proper continuation
+        training_client.load_state_with_optimizer(resume_info["state_path"]).result()
+        logger.info(f"Resumed DPO training from {resume_info['state_path']}")
+    elif config.load_checkpoint_path:
+        # Starting fresh DPO from SFT checkpoint - load weights only (fresh optimizer)
+        training_client.load_state(config.load_checkpoint_path).result()
+        logger.info(f"Loaded weights from {config.load_checkpoint_path}")
     # Create a sampling client for the reference model from the training client
     reference_client = training_client.save_weights_and_get_sampling_client("reference")
     return training_client, reference_client
```
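The branching here is the substance of the fix: resuming an interrupted DPO run goes through `load_state_with_optimizer()` so optimizer momentum carries over, while starting DPO from an SFT checkpoint deliberately calls `load_state()` so the new phase begins with a fresh optimizer rather than inheriting the SFT run's optimizer state.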
