
Commit 72b4857

Merge branch 'main' into use-loss-fn-type
2 parents 62b6b01 + 6e6dbfe

File tree: 21 files changed (+2108 −39 lines)

AGENTS.md

Lines changed: 2 additions & 2 deletions
@@ -19,7 +19,7 @@ Working notes for future agents hacking on `tinker-cookbook`. Additional docs ca
 - Launch scripts define a CLI-facing `CLIConfig` (parsed by `chz`) that instantiates the richer training `Config`. This gives every recipe a consistent `python -m ... key=value` interface.
 - Env builders compose like `RLDatasetBuilder → EnvGroupBuilder → Env`. Groups let us share metadata (tags, pairwise comparisons) and center rewards across related rollouts.
 - **Completers:** algorithms interact with the `TokenCompleter` interface. `TinkerTokenCompleter` (wrapping a `SamplingClient`) is the default implementation, but evaluators may accept any `TokenCompleter` or `MessageCompleter`.
-- **Renderers & tokenizer utils:** pick the renderer that matches your tokenizer/model pair (e.g., `role_colon`, `llama3`, `qwen3`). `TrainOnWhat` controls which tokens get weight=1 in SFT. Tokenizers are cached via `tokenizer_utils.get_tokenizer`, with Llama-3 names remapped to `baseten/Meta-Llama-3-tokenizer` to bypass HF gating.
+- **Renderers & tokenizer utils:** pick the renderer that matches your tokenizer/model pair (e.g., `role_colon`, `llama3`, `qwen3`). `TrainOnWhat` controls which tokens get weight=1 in SFT. Tokenizers are cached via `tokenizer_utils.get_tokenizer`, with Llama-3 names remapped to `thinkingmachineslabinc/meta-llama-3-tokenizer` to bypass HF gating.
 - **Loss plumbing:** every `tinker.Datum` bundles a `model_input` plus `loss_fn_inputs` (`TensorData`). Use helpers such as `conversation_to_datum`, `datum_from_tokens_weights`, and `_remove_mask` instead of constructing dicts manually. Built-in losses: `cross_entropy`, `importance_sampling`, `ppo`; `forward_backward_custom` covers bespoke differentiable objectives.
 
 ## Conventions & Notation (from CONTRIBUTING)

@@ -59,7 +59,7 @@ Working notes for future agents hacking on `tinker-cookbook`. Additional docs ca
 
 ### Evaluations & Sampling
 - Inline evaluators implement either `TrainingClientEvaluator` or `SamplingClientEvaluator`. Training loops accept builder lists (`evaluator_builders`, `infrequent_evaluator_builders`). Inspect AI integration is in `eval/inspect_evaluators.py` and `eval/run_inspect_evals.py`.
-- Sampling clients come from `training_client.save_weights_and_get_sampling_client(name=...)`. To export weights, use `RestClient.download_checkpoint_archive_from_tinker_path`.
+- Sampling clients come from `training_client.save_weights_and_get_sampling_client(name=...)`. To export weights, use `RestClient.get_checkpoint_archive_url_from_tinker_path`.
 
 ## Async & Performance
 - Worker pools advance in ~10s clock cycles. Submit `forward_backward_async` and `optim_step_async` back-to-back, then await both futures to keep them on the same cycle.
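The async note in that last bullet describes a specific submission pattern; here is a minimal sketch of it. The method names come from the repo's docs and this commit, while the surrounding function, the batch contents, and the `loss_fn` value are illustrative assumptions:

```python
# Sketch: submit forward/backward and the optimizer step back-to-back,
# then await both results, so both requests land on the same ~10s worker cycle.
# `batch` is assumed to be a list of tinker.Datum; "cross_entropy" is one of
# the built-in losses named above.
async def train_step(training_client, batch, adam_params):
    fb_future = await training_client.forward_backward_async(batch, loss_fn="cross_entropy")
    optim_future = await training_client.optim_step_async(adam_params)
    fb_result = await fb_future.result_async()
    _ = await optim_future.result_async()
    return fb_result
```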

README.md

Lines changed: 1 addition & 1 deletion
@@ -40,7 +40,7 @@ See [tinker_cookbook/recipes/sl_loop.py](tinker_cookbook/recipes/sl_loop.py) and
 To download the weights of any model:
 ```python
 rest_client = service_client.create_rest_client()
-future = rest_client.download_checkpoint_archive_from_tinker_path(sampling_client.model_path)
+future = rest_client.get_checkpoint_archive_url_from_tinker_path(sampling_client.model_path)
 with open(f"model-checkpoint.tar.gz", "wb") as f:
     f.write(future.result())
 ```
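The snippet assumes a `service_client` and `sampling_client` already in scope; a plausible setup (the model name is a placeholder assumption, not part of this commit) looks like:

```python
import tinker

service_client = tinker.ServiceClient()
training_client = service_client.create_lora_training_client(
    base_model="meta-llama/Llama-3.1-8B"  # placeholder model name
)
sampling_client = training_client.save_weights_and_get_sampling_client(name="final")
```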

llms-full.txt

Lines changed: 44 additions & 9 deletions
@@ -607,11 +607,12 @@ We'll start with a couple of general pages that'll be relevant to almost all of
 
 # Saving and loading weights and optimizer state
 
-During training, you'll need to save checkpoints for two main purposes: *sampling* (to test your model) and *resuming training* (to continue from where you left off). The `TrainingClient` provides three methods to handle these cases:
+During training, you'll need to save checkpoints for two main purposes: *sampling* (to test your model) and *resuming training* (to continue from where you left off). The `TrainingClient` provides these methods to handle these cases:
 
 1. `save_weights_for_sampler()`: saves a copy of the model weights that can be used for sampling.
 2. `save_state()`: saves the weights and the optimizer state. You can fully resume training from this checkpoint.
-3. `load_state()`: load the weights and the optimizer state. You can fully resume training from this checkpoint.
+3. `load_state()`: load the model weights only (without optimizer state). Use this when you want to start fresh training from a checkpoint, e.g., starting DPO training from an SFT checkpoint.
+4. `load_state_with_optimizer()`: load the model weights and optimizer state. Use this when resuming interrupted training, as it restores the full training state including optimizer momentum.
 
 Note that (1) is faster and requires less storage space than (2).
 
@@ -644,24 +645,58 @@ sampling_client = training_client.save_weights_and_get_sampling_client(name="000
 
 ### Example: Saving to resume training
 
-Use `save_state()` and `load_state()` when you need to pause and continue training with full optimizer state preferred:
+Use `save_state()` and `load_state_with_optimizer()` when you need to pause and continue training with full optimizer state:
 
 ```python
 # Save a checkpoint that you can resume from
 resume_path = training_client.save_state(name="0010").result().path
 
-# Load that checkpoint
-training_client.load_state(resume_path)
+# Load that checkpoint with optimizer state (for resuming training)
+training_client.load_state_with_optimizer(resume_path)
 ```
 
-### When to use `save_state()` and `load_state()`:
+Async versions are also available: `load_state_with_optimizer_async()`.
 
+### Example: Starting fresh from a checkpoint
 
-- Multi-step training pipelines (e.g. supervised learning followed by reinforcement learning)
-- Adjusting hyperparameters or data mid-run
-- Recovery from interruptions or failures
+Use `load_state()` when you want to start a new training phase from saved weights (e.g., starting DPO from an SFT checkpoint):
+
+```python
+# Load weights only, starting with fresh optimizer state
+training_client.load_state(sft_checkpoint_path)
+```
+
+### When to use `load_state_with_optimizer()`:
+
+- Recovery from interruptions or failures (resume training exactly where you left off)
 - Any scenario where you need to preserve exact optimizer state (momentum, learning rate schedules, etc.)
 
+### When to use `load_state()`:
+
+- Multi-step training pipelines (e.g., starting DPO training from an SFT checkpoint)
+- Starting fresh training from pretrained weights with a new optimizer
+
+### ServiceClient methods for loading checkpoints
+
+The `ServiceClient` also provides methods to create a new `TrainingClient` directly from a saved checkpoint:
+
+- `create_training_client_from_state(path)`: Creates a `TrainingClient` with weights loaded from the checkpoint (no optimizer state). Use this when starting a new training phase from saved weights.
+- `create_training_client_from_state_with_optimizer(path)`: Creates a `TrainingClient` with both weights and optimizer state loaded. Use this when resuming interrupted training.
+
+```python
+# Resume training with optimizer state
+training_client = service_client.create_training_client_from_state_with_optimizer(
+    "tinker://run-id/weights/checkpoint-001"
+)
+
+# Start fresh training from a checkpoint (weights only)
+training_client = service_client.create_training_client_from_state(
+    "tinker://run-id/weights/checkpoint-001"
+)
+```
+
+Async versions are also available: `create_training_client_from_state_async()` and `create_training_client_from_state_with_optimizer_async()`.
 
 ---

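The docs and code in this commit converge on a single dispatch rule: resume interrupted runs with optimizer state, warm-start new phases from weights only. A condensed sketch of that rule (the helper itself is hypothetical; every client call in it appears in the diffs below):

```python
def make_training_client(service_client, resume_info, load_checkpoint_path, model_name, lora_rank):
    # Hypothetical helper condensing the pattern used in the recipe changes below.
    if resume_info:
        # Resuming an interrupted run: restore weights AND optimizer state.
        return service_client.create_training_client_from_state_with_optimizer(
            resume_info["state_path"]
        )
    if load_checkpoint_path:
        # Warm-starting a new phase (e.g., DPO from SFT): weights only, fresh optimizer.
        return service_client.create_training_client_from_state(load_checkpoint_path)
    # Otherwise start from the base model with a fresh LoRA.
    return service_client.create_lora_training_client(base_model=model_name, rank=lora_rank)
```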
pyproject.toml

Lines changed: 2 additions & 1 deletion
@@ -9,11 +9,12 @@ authors = [
 requires-python = ">=3.11"
 dependencies = [
     "chz",
+    "cloudpickle",
     "datasets",
     "numpy",
     "rich",
     "termcolor",
-    "tinker>=0.5.1",
+    "tinker>=0.6.1",
     "torch",
     "transformers",
     "blobfile",

tinker_cookbook/distillation/train_on_policy.py

Lines changed: 1 addition & 1 deletion
@@ -386,7 +386,7 @@ async def main(
         resume_info["state_path"] if resume_info else cfg.load_checkpoint_path
     )
     if load_state_path:
-        future = await training_client.load_state_async(load_state_path)
+        future = await training_client.load_state_with_optimizer_async(load_state_path)
         _ = await future.result_async()
         logger.info(f"Loaded state from {load_state_path}")

tinker_cookbook/preference/train_dpo.py

Lines changed: 9 additions & 8 deletions
@@ -91,14 +91,15 @@ def create_dpo_clients(
         base_model=config.model_name, rank=config.lora_rank
     )
 
-    # Load state first to get the SFT checkpoint path for the reference client
-    load_state_path: str | None = (
-        resume_info["state_path"] if resume_info else config.load_checkpoint_path
-    )
-    if load_state_path:
-        # Load state into the training client
-        training_client.load_state(load_state_path).result()
-        logger.info(f"Loaded weights from {load_state_path}")
+    # Load state - differentiate between resuming DPO training vs starting fresh from SFT
+    if resume_info:
+        # Resuming interrupted DPO training - load optimizer state for proper continuation
+        training_client.load_state_with_optimizer(resume_info["state_path"]).result()
+        logger.info(f"Resumed DPO training from {resume_info['state_path']}")
+    elif config.load_checkpoint_path:
+        # Starting fresh DPO from SFT checkpoint - load weights only (fresh optimizer)
+        training_client.load_state(config.load_checkpoint_path).result()
+        logger.info(f"Loaded weights from {config.load_checkpoint_path}")
     # Create a sampling client for the reference model from the training client
     reference_client = training_client.save_weights_and_get_sampling_client("reference")
     return training_client, reference_client
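Note the ordering preserved by this hunk: the reference sampling client is snapshotted only after the checkpoint weights are loaded, so the frozen reference policy that DPO compares against matches the run's starting weights rather than a fresh LoRA initialization.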

tinker_cookbook/recipes/rl_loop.py

Lines changed: 1 addition & 1 deletion
@@ -83,7 +83,7 @@ def main(config: Config):
 
     resume_info = checkpoint_utils.get_last_checkpoint(config.log_path)
     if resume_info:
-        training_client = service_client.create_training_client_from_state(
+        training_client = service_client.create_training_client_from_state_with_optimizer(
             resume_info["state_path"]
         )
         start_batch = resume_info["batch"]

tinker_cookbook/recipes/sl_loop.py

Lines changed: 1 addition & 1 deletion
@@ -63,7 +63,7 @@ def main(config: Config):
     # Check for resuming
     resume_info = checkpoint_utils.get_last_checkpoint(config.log_path)
    if resume_info:
-        training_client = service_client.create_training_client_from_state(
+        training_client = service_client.create_training_client_from_state_with_optimizer(
            resume_info["state_path"]
        )
        start_batch = resume_info["batch"]

tinker_cookbook/renderers.py

Lines changed: 7 additions & 2 deletions
@@ -625,13 +625,18 @@ def _render_message(self, idx: int, message: Message) -> tuple[list[int], list[i
 class DeepSeekV3Renderer(Renderer):
     """
     Format like this (no newlines between messages):
-    <|begin_of_sentence|><|User|>What can you help me with?<|Assistant|><think>Thinking...</think>I can help you with...<|end_of_centence|>
+    <|begin_of_sentence|><|User|>What can you help me with?<|Assistant|><think>Thinking...</think>I can help you with...<|end_of_sentence|>
     For no-think, just use <|Assistant|></think>
+    Deepseek renderer does not support the system role out of the box. You can set system_role_as_user to True to automatically convert the system role to the user role.
     """
 
+    def __init__(self, tokenizer: Tokenizer, system_role_as_user: bool = False):
+        super().__init__(tokenizer)
+        self.system_role_as_user = system_role_as_user
+
     def _render_message(self, message: Message) -> tuple[list[int], list[int], list[int]]:
         assert message.get("thinking") is None, "TODO: support CoT in DsV3 renderer"
-        if message["role"] == "user":
+        if message["role"] == "user" or (self.system_role_as_user and message["role"] == "system"):
             role_token = self._get_special_token("User")
         elif message["role"] == "assistant":
             role_token = self._get_special_token("Assistant")
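A quick usage sketch of the new flag; the import paths and tokenizer name are assumptions for illustration, not part of the diff:

```python
from tinker_cookbook import renderers
from tinker_cookbook.tokenizer_utils import get_tokenizer  # assumed import path

tokenizer = get_tokenizer("deepseek-ai/DeepSeek-V3")  # assumed tokenizer name
renderer = renderers.DeepSeekV3Renderer(tokenizer, system_role_as_user=True)
# A system message is now rendered with the <|User|> role token instead of
# falling through the role dispatch.
```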

tinker_cookbook/rl/train.py

Lines changed: 12 additions & 6 deletions
@@ -1058,14 +1058,20 @@ async def main(
     start_batch = 0
 
     service_client = tinker.ServiceClient(base_url=cfg.base_url)
-    load_state_path: str | None = (
-        resume_info["state_path"] if resume_info else cfg.load_checkpoint_path
-    )
-    if load_state_path:
+    if resume_info:
+        # Resuming interrupted training - load optimizer state for proper continuation
+        training_client = (
+            await service_client.create_training_client_from_state_with_optimizer_async(
+                resume_info["state_path"]
+            )
+        )
+        logger.info(f"Resumed training from {resume_info['state_path']}")
+    elif cfg.load_checkpoint_path:
+        # Starting fresh from a checkpoint - load weights only (fresh optimizer)
         training_client = await service_client.create_training_client_from_state_async(
-            load_state_path
+            cfg.load_checkpoint_path
         )
-        logger.info(f"Loaded state from {load_state_path}")
+        logger.info(f"Loaded weights from {cfg.load_checkpoint_path}")
     else:
         training_client = await service_client.create_lora_training_client_async(
             cfg.model_name, rank=cfg.lora_rank
