Skip to content

Commit 53aaf8d

Browse files
kennyyu, joschu, and claude
authored
[tinker-cookbook] set_scope_context: add helper (#106)
Co-authored-by: John Schulman <[email protected]> Co-authored-by: Claude <[email protected]>
1 parent 37370ca commit 53aaf8d

File tree

6 files changed

+63
-51
lines changed

6 files changed

+63
-51
lines changed

tinker_cookbook/checkpoint_utils.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
import tinker
88

99
from tinker_cookbook.utils.file_utils import read_jsonl
10-
from tinker_cookbook.utils.trace import get_scope_context, scope
10+
from tinker_cookbook.utils.trace import scope, update_scope_context
1111

1212
CHECKPOINTS_BASE_NAME = "checkpoints.jsonl"
1313

@@ -22,8 +22,7 @@ def load_checkpoints_file(log_dir: str) -> list[dict[str, Any]]:
2222
return []
2323

2424
logger.info(f"Reading checkpoints from {checkpoint_path}")
25-
context = get_scope_context()
26-
context.attributes["checkpoint_path"] = checkpoint_path
25+
update_scope_context({"checkpoint_path": checkpoint_path})
2726
return read_jsonl(checkpoint_path)
2827

2928

@@ -78,8 +77,7 @@ async def save_checkpoint_async(
7877

7978
results = {k: await v.result_async() for k, v in futures.items()}
8079
paths = {k + "_path": v.path for k, v in results.items()}
81-
context = get_scope_context()
82-
context.attributes.update(paths)
80+
update_scope_context(paths)
8381
logger.info(f"Saved checkpoints: {paths}")
8482
full_dict = {"name": name, **loop_state, **paths}
8583
with open(os.path.join(log_path, "checkpoints.jsonl"), "a") as f:

tinker_cookbook/distillation/train_on_policy.py

Lines changed: 14 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -7,42 +7,39 @@
77
import logging
88
import os
99
import time
10-
from typing import Any, List, Literal, Sequence, Dict, cast
10+
from typing import Any, Dict, List, Literal, Sequence, cast
1111

1212
import chz
1313
import tinker
1414
import torch
15+
1516
from tinker_cookbook import checkpoint_utils
1617
from tinker_cookbook.display import colorize_example
18+
from tinker_cookbook.distillation.datasets import (
19+
CompositeDataset,
20+
DistillationDatasetConfig,
21+
)
1722
from tinker_cookbook.eval.evaluators import SamplingClientEvaluator, SamplingClientEvaluatorBuilder
1823
from tinker_cookbook.rl.data_processing import (
1924
assemble_training_data,
2025
compute_advantages,
2126
)
2227
from tinker_cookbook.rl.metric_util import RLTestSetEvaluator, compute_trajectory_metrics
2328
from tinker_cookbook.rl.metrics import discounted_future_sum_vectorized
29+
from tinker_cookbook.rl.train import (
30+
compute_full_batch_metrics_and_get_sampling_client,
31+
do_group_rollout_and_filter_constant_reward,
32+
save_checkpoint_and_get_sampling_client,
33+
train_step,
34+
)
2435
from tinker_cookbook.rl.types import (
2536
EnvGroupBuilder,
2637
TrajectoryGroup,
2738
)
2839
from tinker_cookbook.tokenizer_utils import Tokenizer
2940
from tinker_cookbook.utils import ml_log
3041
from tinker_cookbook.utils.misc_utils import safezip, timed
31-
from tinker_cookbook.utils.trace import scope, get_scope_context, trace_init
32-
33-
# Dataset configuration classes
34-
from tinker_cookbook.distillation.datasets import (
35-
CompositeDataset,
36-
DistillationDatasetConfig,
37-
)
38-
39-
# We re-use these methods from the RL training recipe
40-
from tinker_cookbook.rl.train import (
41-
save_checkpoint_and_get_sampling_client,
42-
train_step,
43-
compute_full_batch_metrics_and_get_sampling_client,
44-
do_group_rollout_and_filter_constant_reward,
45-
)
42+
from tinker_cookbook.utils.trace import scope, update_scope_context, trace_init
4643

4744
logger = logging.getLogger(__name__)
4845

@@ -228,8 +225,7 @@ async def do_train_step_and_get_sampling_client(
228225
dataset_indices_P: List[int],
229226
teacher_clients: List[tinker.SamplingClient],
230227
) -> tuple[tinker.SamplingClient, dict[str, Any]]:
231-
context = get_scope_context()
232-
context.attributes["step"] = i_batch
228+
update_scope_context({"step": i_batch})
233229

234230
metrics = {}
235231
data_D, prepare_minibatch_metrics = await prepare_minibatch(

tinker_cookbook/rl/train.py

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -7,12 +7,14 @@
77
import logging
88
import os
99
import time
10-
from typing import Any, Callable, List, Literal, Sequence, Iterator
10+
from contextlib import contextmanager
11+
from typing import Any, Callable, Iterator, List, Literal, Sequence
1112

1213
import chz
1314
import numpy as np
1415
import tinker
1516
import torch
17+
1618
from tinker_cookbook import checkpoint_utils
1719
from tinker_cookbook.completers import TinkerTokenCompleter
1820
from tinker_cookbook.display import colorize_example
@@ -39,9 +41,7 @@
3941
from tinker_cookbook.tokenizer_utils import Tokenizer
4042
from tinker_cookbook.utils import logtree, ml_log
4143
from tinker_cookbook.utils.misc_utils import safezip, split_list, timed
42-
from tinker_cookbook.utils.trace import scope, trace_init, get_scope_context
43-
from contextlib import contextmanager
44-
44+
from tinker_cookbook.utils.trace import scope, update_scope_context, trace_init
4545

4646
logger = logging.getLogger(__name__)
4747

@@ -809,8 +809,7 @@ async def do_train_step_streaming_and_get_sampling_client(
809809
# Number of groups per minibatch in each optimizer substep
810810
groups_per_minibatch = groups_per_substep // cfg.stream_minibatch_config.num_minibatches
811811

812-
context = get_scope_context()
813-
context.attributes["step"] = i_batch
812+
update_scope_context({"step": i_batch})
814813

815814
metrics = {}
816815

@@ -904,8 +903,7 @@ async def do_train_step_and_get_sampling_client(
904903
env_group_builders_P: Sequence[EnvGroupBuilder],
905904
trajectory_groups_P: list[TrajectoryGroup],
906905
) -> tuple[tinker.SamplingClient, dict[str, Any]]:
907-
context = get_scope_context()
908-
context.attributes["step"] = i_batch
906+
update_scope_context({"step": i_batch})
909907

910908
metrics = {}
911909
data_D, prepare_minibatch_metrics = await prepare_minibatch(

tinker_cookbook/supervised/train.py

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@
3232
from tinker_cookbook.utils import ml_log
3333
from tinker_cookbook.utils.lr_scheduling import compute_schedule_lr_multiplier
3434
from tinker_cookbook.utils.misc_utils import timed
35-
from tinker_cookbook.utils.trace import get_scope_context, scope, trace_init
35+
from tinker_cookbook.utils.trace import scope, update_scope_context, trace_init
3636

3737
logger = logging.getLogger(__name__)
3838

@@ -107,22 +107,24 @@ async def run_evals(
107107
checkpoint. Returned metrics are prefixed with ``test/`` so they can be logged next
108108
to the same-step training metrics.
109109
"""
110-
context = get_scope_context()
111-
context.attributes["step"] = step
110+
update_scope_context({"step": step})
112111

113112
metrics = {}
114113
sampling_client = None
115114

116115
@scope
117116
async def run_evaluator(evaluator: Evaluator) -> dict[str, float]:
118-
context = get_scope_context()
119-
context.attributes["step"] = step
120-
context.attributes["evaluator_name"] = type(evaluator).__name__
117+
update_scope_context(
118+
{
119+
"step": step,
120+
"evaluator_name": type(evaluator).__name__,
121+
}
122+
)
121123
if isinstance(evaluator, TrainingClientEvaluator):
122-
context.attributes["evaluator_type"] = "TrainingClientEvaluator"
124+
update_scope_context({"evaluator_type": "TrainingClientEvaluator"})
123125
return await evaluator(training_client)
124126
elif isinstance(evaluator, SamplingClientEvaluator):
125-
context.attributes["evaluator_type"] = "SamplingClientEvaluator"
127+
update_scope_context({"evaluator_type": "SamplingClientEvaluator"})
126128
# Create sampling client lazily, only when needed
127129
nonlocal sampling_client
128130
if sampling_client is None:
@@ -225,8 +227,7 @@ async def main(config: Config):
225227
@scope
226228
async def submit_batch(epoch_idx: int, batch_idx: int) -> SubmittedBatch:
227229
step = epoch_idx * n_batches + batch_idx
228-
context = get_scope_context()
229-
context.attributes["step"] = step
230+
update_scope_context({"step": step})
230231

231232
batch_start_time = time.time()
232233
metrics: dict[str, int | float | str] = {"epoch": epoch_idx}
@@ -286,8 +287,7 @@ async def submit_batch(epoch_idx: int, batch_idx: int) -> SubmittedBatch:
286287

287288
@scope
288289
async def finish_batch(submitted: SubmittedBatch):
289-
context = get_scope_context()
290-
context.attributes["step"] = submitted.step
290+
update_scope_context({"step": submitted.step})
291291

292292
metrics = submitted.metrics
293293
metrics["progress"] = min((submitted.step + 1) / progress_denominator, 1.0)

tinker_cookbook/tests/test_trace.py

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,15 @@
1-
import json
2-
from tinker_cookbook.utils.trace import scope, trace_init, trace_shutdown, get_scope_context
31
import asyncio
4-
import threading
2+
import json
53
import tempfile
4+
import threading
5+
6+
from tinker_cookbook.utils.trace import (
7+
get_scope_context,
8+
scope,
9+
update_scope_context,
10+
trace_init,
11+
trace_shutdown,
12+
)
613

714

815
@scope
@@ -30,8 +37,7 @@ def ced():
3037
@scope
3138
async def baz():
3239
await asyncio.sleep(0.02)
33-
context = get_scope_context()
34-
context.attributes["baz"] = "baz"
40+
update_scope_context({"baz": "baz"})
3541
ced()
3642

3743

tinker_cookbook/utils/trace.py

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,17 @@
1+
import argparse
12
import asyncio
3+
import atexit
24
import functools
35
import inspect
46
import json
57
import queue
68
import threading
79
import time
810
from contextvars import ContextVar
9-
from typing import Any, Callable
1011
from dataclasses import dataclass, field
1112
from enum import Enum
12-
import argparse
1313
from io import TextIOWrapper
14-
import atexit
14+
from typing import Any, Callable
1515

1616

1717
class EventType(str, Enum):
@@ -407,6 +407,20 @@ async def foo():
407407
return result
408408

409409

410+
def update_scope_context(values: dict[str, Any]) -> None:
    """Merge ``values`` into the attributes of the current scope's context.

    The current context is looked up through ``trace_context`` and its
    ``attributes`` mapping is updated in place with ``values``.

    Example usage:

        @scope
        async def foo(step: int):
            update_scope_context({"step": step})
            await bar()

    Args:
        values: Attribute names mapped to the values to record on the scope.

    NOTE(review): the lookup passes a freshly built ``ScopeContext()`` as the
    fallback, so when no context has been set the update presumably lands on
    that temporary object and is discarded, and the assertion below would
    only fire if the stored context were explicitly ``None`` -- confirm that
    this silent no-op outside a scope is intended.
    """
    # The fallback instance is constructed on every call, before the lookup.
    current_context = trace_context.get(ScopeContext())
    assert current_context is not None, "Trace context is not set"
    current_context.attributes.update(values)
422+
423+
410424
def convert_jsonl_to_json_main():
411425
"""Helper script to convert the trace events format into a visualizable format"""
412426
parser = argparse.ArgumentParser(

0 commit comments

Comments
 (0)