
Commit cacb212

Load optimizer state when appropriate, part 2
This commit does the same as 7728f5a ("Load optimizer state when appropriate (#140)"), except this time for create_training_client_from_state().

Signed-off-by: Daniel Xu <[email protected]>

1 parent: 2e8aaa4

6 files changed, +44 -15 lines changed
llms-full.txt

Lines changed: 21 additions & 0 deletions
@@ -676,6 +676,27 @@ training_client.load_state(sft_checkpoint_path)
 - Multi-step training pipelines (e.g., starting DPO training from an SFT checkpoint)
 - Starting fresh training from pretrained weights with a new optimizer

+### ServiceClient methods for loading checkpoints
+
+The `ServiceClient` also provides methods to create a new `TrainingClient` directly from a saved checkpoint:
+
+- `create_training_client_from_state(path)`: Creates a `TrainingClient` with weights loaded from the checkpoint (no optimizer state). Use this when starting a new training phase from saved weights.
+- `create_training_client_from_state_with_optimizer(path)`: Creates a `TrainingClient` with both weights and optimizer state loaded. Use this when resuming interrupted training.
+
+```python
+# Resume training with optimizer state
+training_client = service_client.create_training_client_from_state_with_optimizer(
+    "tinker://run-id/weights/checkpoint-001"
+)
+
+# Start fresh training from a checkpoint (weights only)
+training_client = service_client.create_training_client_from_state(
+    "tinker://run-id/weights/checkpoint-001"
+)
+```
+
+Async versions are also available: `create_training_client_from_state_async()` and `create_training_client_from_state_with_optimizer_async()`.

 ---

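For the async variants mentioned at the end of the new docs section, usage looks roughly like this. A minimal sketch, not taken from the diff: the checkpoint path is a placeholder, and `base_url=None` is assumed to fall back to the default endpoint.

```python
import asyncio

import tinker


async def main() -> None:
    service_client = tinker.ServiceClient(base_url=None)  # default endpoint assumed

    # Resume interrupted training: restore weights and optimizer state.
    resumed = await service_client.create_training_client_from_state_with_optimizer_async(
        "tinker://run-id/weights/checkpoint-001"
    )

    # Start a new training phase: weights only, fresh optimizer.
    fresh = await service_client.create_training_client_from_state_async(
        "tinker://run-id/weights/checkpoint-001"
    )


asyncio.run(main())
```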
pyproject.toml

Lines changed: 1 addition & 1 deletion
@@ -13,7 +13,7 @@ dependencies = [
     "numpy",
     "rich",
     "termcolor",
-    "tinker>=0.6.0",
+    "tinker>=0.6.1",
     "torch",
     "transformers",
     "blobfile",

tinker_cookbook/recipes/rl_loop.py

Lines changed: 1 addition & 1 deletion
@@ -83,7 +83,7 @@ def main(config: Config):

     resume_info = checkpoint_utils.get_last_checkpoint(config.log_path)
     if resume_info:
-        training_client = service_client.create_training_client_from_state(
+        training_client = service_client.create_training_client_from_state_with_optimizer(
             resume_info["state_path"]
         )
         start_batch = resume_info["batch"]

tinker_cookbook/recipes/sl_loop.py

Lines changed: 1 addition & 1 deletion
@@ -63,7 +63,7 @@ def main(config: Config):
     # Check for resuming
     resume_info = checkpoint_utils.get_last_checkpoint(config.log_path)
     if resume_info:
-        training_client = service_client.create_training_client_from_state(
+        training_client = service_client.create_training_client_from_state_with_optimizer(
             resume_info["state_path"]
         )
         start_batch = resume_info["batch"]
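Both recipe loops now share the same resume pattern. A minimal sketch of it, under two assumptions read off the diffs: `get_last_checkpoint()` returns a mapping with `"state_path"` and `"batch"` keys (or something falsy when there is nothing to resume), and `checkpoint_utils` is importable from `tinker_cookbook` (the exact import path is assumed).

```python
import tinker
from tinker_cookbook import checkpoint_utils  # import path assumed

service_client = tinker.ServiceClient(base_url=None)  # None: default endpoint assumed

# Falsy when the log path holds no checkpoint yet (assumption from the diffs).
resume_info = checkpoint_utils.get_last_checkpoint("/tmp/example-run")

if resume_info:
    # Interrupted run: restore weights *and* optimizer state, then pick up
    # the batch counter where the previous run stopped.
    training_client = service_client.create_training_client_from_state_with_optimizer(
        resume_info["state_path"]
    )
    start_batch = resume_info["batch"]
else:
    start_batch = 0  # fresh run; client creation elided here
```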

tinker_cookbook/rl/train.py

Lines changed: 10 additions & 6 deletions
@@ -1058,14 +1058,18 @@ async def main(
     start_batch = 0

     service_client = tinker.ServiceClient(base_url=cfg.base_url)
-    load_state_path: str | None = (
-        resume_info["state_path"] if resume_info else cfg.load_checkpoint_path
-    )
-    if load_state_path:
+    if resume_info:
+        # Resuming interrupted training - load optimizer state for proper continuation
+        training_client = await service_client.create_training_client_from_state_with_optimizer_async(
+            resume_info["state_path"]
+        )
+        logger.info(f"Resumed training from {resume_info['state_path']}")
+    elif cfg.load_checkpoint_path:
+        # Starting fresh from a checkpoint - load weights only (fresh optimizer)
         training_client = await service_client.create_training_client_from_state_async(
-            load_state_path
+            cfg.load_checkpoint_path
         )
-        logger.info(f"Loaded state from {load_state_path}")
+        logger.info(f"Loaded weights from {cfg.load_checkpoint_path}")
     else:
         training_client = await service_client.create_lora_training_client_async(
             cfg.model_name, rank=cfg.lora_rank

tinker_cookbook/supervised/train.py

Lines changed: 10 additions & 6 deletions
@@ -189,19 +189,23 @@ async def main(config: Config):
     trace_init(output_file=os.path.join(config.log_path, "trace_events.jsonl"))

     service_client = tinker.ServiceClient(base_url=config.base_url)
-    load_state_path: str | None = (
-        resume_info["state_path"] if resume_info else config.load_checkpoint_path
-    )

     user_metadata: dict[str, str] = {}
     if wandb_link := ml_logger.get_logger_url():
         user_metadata["wandb_link"] = wandb_link

-    if load_state_path:
+    if resume_info:
+        # Resuming interrupted training - load optimizer state for proper continuation
+        training_client = await service_client.create_training_client_from_state_with_optimizer_async(
+            resume_info["state_path"], user_metadata
+        )
+        logger.info(f"Resumed training from {resume_info['state_path']}")
+    elif config.load_checkpoint_path:
+        # Starting fresh from a checkpoint - load weights only (fresh optimizer)
         training_client = await service_client.create_training_client_from_state_async(
-            load_state_path, user_metadata
+            config.load_checkpoint_path, user_metadata
         )
-        logger.info(f"Loaded weights from {load_state_path}")
+        logger.info(f"Loaded weights from {config.load_checkpoint_path}")
     else:
         training_client = await service_client.create_lora_training_client_async(
             base_model=config.model_name,
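The supervised path also forwards `user_metadata` (here carrying the W&B run link) when creating the client from state. A sketch of that call shape in an async context; the URL and checkpoint path are made-up placeholders:

```python
import asyncio

import tinker


async def resume_supervised() -> None:
    service_client = tinker.ServiceClient(base_url=None)  # None: default endpoint assumed

    # Metadata attached to the training client, mirroring supervised/train.py.
    user_metadata: dict[str, str] = {"wandb_link": "https://wandb.ai/example/run"}

    training_client = await service_client.create_training_client_from_state_with_optimizer_async(
        "tinker://run-id/weights/checkpoint-001", user_metadata
    )


asyncio.run(resume_supervised())
```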
