6 files changed: +48 −15 lines changed

@@ -676,6 +676,27 @@ training_client.load_state(sft_checkpoint_path)
 - Multi-step training pipelines (e.g., starting DPO training from an SFT checkpoint)
 - Starting fresh training from pretrained weights with a new optimizer
 
+### ServiceClient methods for loading checkpoints
+
+The `ServiceClient` also provides methods to create a new `TrainingClient` directly from a saved checkpoint:
+
+- `create_training_client_from_state(path)`: Creates a `TrainingClient` with weights loaded from the checkpoint (no optimizer state). Use this when starting a new training phase from saved weights.
+- `create_training_client_from_state_with_optimizer(path)`: Creates a `TrainingClient` with both weights and optimizer state loaded. Use this when resuming interrupted training.
+
+```python
+# Resume training with optimizer state
+training_client = service_client.create_training_client_from_state_with_optimizer(
+    "tinker://run-id/weights/checkpoint-001"
+)
+
+# Start fresh training from a checkpoint (weights only)
+training_client = service_client.create_training_client_from_state(
+    "tinker://run-id/weights/checkpoint-001"
+)
+```
+
+Async versions are also available: `create_training_client_from_state_async()` and `create_training_client_from_state_with_optimizer_async()`.
+
 
 ---
 
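For the async variants mentioned in the added docs, a minimal usage sketch. Assumptions: credentials are picked up from the environment so `ServiceClient()` can be constructed without arguments (the scripts below pass `base_url` explicitly), and the checkpoint URI is a placeholder.

```python
import asyncio

import tinker


async def resume_run() -> None:
    # Assumes credentials come from the environment; pass base_url if needed.
    service_client = tinker.ServiceClient()
    # Resume an interrupted run: restores weights *and* optimizer state.
    training_client = (
        await service_client.create_training_client_from_state_with_optimizer_async(
            "tinker://run-id/weights/checkpoint-001"  # placeholder URI
        )
    )
    # ... continue training with training_client ...


asyncio.run(resume_run())
```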
@@ -13,7 +13,7 @@ dependencies = [
     "numpy",
     "rich",
     "termcolor",
-    "tinker>=0.6.0",
+    "tinker>=0.6.1",
     "torch",
     "transformers",
     "blobfile",
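The floor moves from 0.6.0 to 0.6.1, presumably the release that introduces the `_with_optimizer` variants used below (an inference from this diff, not confirmed). A sketch of a runtime guard mirroring the constraint, using the standard library plus the common third-party `packaging` distribution:

```python
from importlib.metadata import version

from packaging.version import Version  # third-party 'packaging' package

# Fail fast if an older tinker is installed (mirrors the pyproject floor).
if Version(version("tinker")) < Version("0.6.1"):
    raise RuntimeError(
        "tinker>=0.6.1 is required for the *_with_optimizer checkpoint methods"
    )
```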
@@ -83,7 +83,7 @@ def main(config: Config):
 
     resume_info = checkpoint_utils.get_last_checkpoint(config.log_path)
     if resume_info:
-        training_client = service_client.create_training_client_from_state(
+        training_client = service_client.create_training_client_from_state_with_optimizer(
             resume_info["state_path"]
         )
         start_batch = resume_info["batch"]
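This hunk and the next apply the same resume pattern. A condensed sketch of that pattern, under the assumption (suggested by the surrounding code) that `checkpoint_utils.get_last_checkpoint` returns a dict with `state_path` and `batch` keys, or `None` when no checkpoint exists; `service_client` and `config` are the scripts' own objects:

```python
resume_info = checkpoint_utils.get_last_checkpoint(config.log_path)
if resume_info:
    # Restore weights *and* optimizer state (step count, Adam moments) so the
    # run continues exactly where it was interrupted.
    training_client = service_client.create_training_client_from_state_with_optimizer(
        resume_info["state_path"]
    )
    start_batch = resume_info["batch"]  # skip batches already processed
else:
    # No checkpoint found: start fresh (hypothetical sync counterpart of the
    # async LoRA constructor used later in this diff).
    training_client = service_client.create_lora_training_client(config.model_name)
    start_batch = 0
```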
@@ -63,7 +63,7 @@ def main(config: Config):
     # Check for resuming
     resume_info = checkpoint_utils.get_last_checkpoint(config.log_path)
     if resume_info:
-        training_client = service_client.create_training_client_from_state(
+        training_client = service_client.create_training_client_from_state_with_optimizer(
             resume_info["state_path"]
         )
         start_batch = resume_info["batch"]
@@ -1058,14 +1058,20 @@ async def main(
         start_batch = 0
 
     service_client = tinker.ServiceClient(base_url=cfg.base_url)
-    load_state_path: str | None = (
-        resume_info["state_path"] if resume_info else cfg.load_checkpoint_path
-    )
-    if load_state_path:
+    if resume_info:
+        # Resuming interrupted training - load optimizer state for proper continuation
+        training_client = (
+            await service_client.create_training_client_from_state_with_optimizer_async(
+                resume_info["state_path"]
+            )
+        )
+        logger.info(f"Resumed training from {resume_info['state_path']}")
+    elif cfg.load_checkpoint_path:
+        # Starting fresh from a checkpoint - load weights only (fresh optimizer)
         training_client = await service_client.create_training_client_from_state_async(
-            load_state_path
+            cfg.load_checkpoint_path
         )
-        logger.info(f"Loaded state from {load_state_path}")
+        logger.info(f"Loaded weights from {cfg.load_checkpoint_path}")
     else:
         training_client = await service_client.create_lora_training_client_async(
             cfg.model_name, rank=cfg.lora_rank
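De-diffed for readability, the resulting control flow is a three-way branch (reproduced from the hunk above; `cfg`, `resume_info`, `service_client`, and `logger` come from the surrounding script):

```python
if resume_info:
    # 1) Interrupted run: weights + optimizer state (step count, Adam moments).
    training_client = (
        await service_client.create_training_client_from_state_with_optimizer_async(
            resume_info["state_path"]
        )
    )
elif cfg.load_checkpoint_path:
    # 2) New phase seeded from saved weights: optimizer starts fresh.
    training_client = await service_client.create_training_client_from_state_async(
        cfg.load_checkpoint_path
    )
else:
    # 3) Nothing to load: fresh LoRA client on the base model.
    training_client = await service_client.create_lora_training_client_async(
        cfg.model_name, rank=cfg.lora_rank
    )
```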
@@ -189,19 +189,25 @@ async def main(config: Config):
     trace_init(output_file=os.path.join(config.log_path, "trace_events.jsonl"))
 
     service_client = tinker.ServiceClient(base_url=config.base_url)
-    load_state_path: str | None = (
-        resume_info["state_path"] if resume_info else config.load_checkpoint_path
-    )
 
     user_metadata: dict[str, str] = {}
     if wandb_link := ml_logger.get_logger_url():
         user_metadata["wandb_link"] = wandb_link
 
-    if load_state_path:
+    if resume_info:
+        # Resuming interrupted training - load optimizer state for proper continuation
+        training_client = (
+            await service_client.create_training_client_from_state_with_optimizer_async(
+                resume_info["state_path"], user_metadata
+            )
+        )
+        logger.info(f"Resumed training from {resume_info['state_path']}")
+    elif config.load_checkpoint_path:
+        # Starting fresh from a checkpoint - load weights only (fresh optimizer)
         training_client = await service_client.create_training_client_from_state_async(
-            load_state_path, user_metadata
+            config.load_checkpoint_path, user_metadata
         )
-        logger.info(f"Loaded weights from {load_state_path}")
+        logger.info(f"Loaded weights from {config.load_checkpoint_path}")
     else:
         training_client = await service_client.create_lora_training_client_async(
             base_model=config.model_name,
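This final hunk applies the same branch but additionally threads a `user_metadata` dict through both loading paths. A minimal sketch of that piece, assuming `ml_logger.get_logger_url()` returns a tracking URL or `None` (as the walrus pattern implies):

```python
user_metadata: dict[str, str] = {}
if wandb_link := ml_logger.get_logger_url():
    # Attach the W&B run URL so it can be surfaced alongside the training run.
    user_metadata["wandb_link"] = wandb_link

# Passed as the second argument to either loading method, e.g.:
training_client = await service_client.create_training_client_from_state_async(
    config.load_checkpoint_path, user_metadata
)
```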