|
7 | 7 | import logging |
8 | 8 | import os |
9 | 9 | import time |
10 | | -from typing import Any, List, Literal, Sequence, Dict, cast |
| 10 | +from typing import Any, Dict, List, Literal, Sequence, cast |
11 | 11 |
|
12 | 12 | import chz |
13 | 13 | import tinker |
14 | 14 | import torch |
| 15 | + |
15 | 16 | from tinker_cookbook import checkpoint_utils |
16 | 17 | from tinker_cookbook.display import colorize_example |
| 18 | +from tinker_cookbook.distillation.datasets import ( |
| 19 | + CompositeDataset, |
| 20 | + DistillationDatasetConfig, |
| 21 | +) |
17 | 22 | from tinker_cookbook.eval.evaluators import SamplingClientEvaluator, SamplingClientEvaluatorBuilder |
18 | 23 | from tinker_cookbook.rl.data_processing import ( |
19 | 24 | assemble_training_data, |
20 | 25 | compute_advantages, |
21 | 26 | ) |
22 | 27 | from tinker_cookbook.rl.metric_util import RLTestSetEvaluator, compute_trajectory_metrics |
23 | 28 | from tinker_cookbook.rl.metrics import discounted_future_sum_vectorized |
| 29 | +from tinker_cookbook.rl.train import ( |
| 30 | + compute_full_batch_metrics_and_get_sampling_client, |
| 31 | + do_group_rollout_and_filter_constant_reward, |
| 32 | + save_checkpoint_and_get_sampling_client, |
| 33 | + train_step, |
| 34 | +) |
24 | 35 | from tinker_cookbook.rl.types import ( |
25 | 36 | EnvGroupBuilder, |
26 | 37 | TrajectoryGroup, |
27 | 38 | ) |
28 | 39 | from tinker_cookbook.tokenizer_utils import Tokenizer |
29 | 40 | from tinker_cookbook.utils import ml_log |
30 | 41 | from tinker_cookbook.utils.misc_utils import safezip, timed |
31 | | -from tinker_cookbook.utils.trace import scope, get_scope_context, trace_init |
32 | | - |
33 | | -# Dataset configuration classes |
34 | | -from tinker_cookbook.distillation.datasets import ( |
35 | | - CompositeDataset, |
36 | | - DistillationDatasetConfig, |
37 | | -) |
38 | | - |
39 | | -# We re-use these methods from the RL training recipe |
40 | | -from tinker_cookbook.rl.train import ( |
41 | | - save_checkpoint_and_get_sampling_client, |
42 | | - train_step, |
43 | | - compute_full_batch_metrics_and_get_sampling_client, |
44 | | - do_group_rollout_and_filter_constant_reward, |
45 | | -) |
| 42 | +from tinker_cookbook.utils.trace import scope, update_scope_context, trace_init |
46 | 43 |
|
47 | 44 | logger = logging.getLogger(__name__) |
48 | 45 |
|
@@ -228,8 +225,7 @@ async def do_train_step_and_get_sampling_client( |
228 | 225 | dataset_indices_P: List[int], |
229 | 226 | teacher_clients: List[tinker.SamplingClient], |
230 | 227 | ) -> tuple[tinker.SamplingClient, dict[str, Any]]: |
231 | | - context = get_scope_context() |
232 | | - context.attributes["step"] = i_batch |
| 228 | + update_scope_context({"step": i_batch}) |
233 | 229 |
|
234 | 230 | metrics = {} |
235 | 231 | data_D, prepare_minibatch_metrics = await prepare_minibatch( |
|
0 commit comments