diff --git a/tinker_cookbook/checkpoint_utils.py b/tinker_cookbook/checkpoint_utils.py index 091e8f7b..e0cc7d8a 100644 --- a/tinker_cookbook/checkpoint_utils.py +++ b/tinker_cookbook/checkpoint_utils.py @@ -7,7 +7,7 @@ import tinker from tinker_cookbook.utils.file_utils import read_jsonl -from tinker_cookbook.utils.trace import get_scope_context, scope +from tinker_cookbook.utils.trace import scope, update_scope_context CHECKPOINTS_BASE_NAME = "checkpoints.jsonl" @@ -22,8 +22,7 @@ def load_checkpoints_file(log_dir: str) -> list[dict[str, Any]]: return [] logger.info(f"Reading checkpoints from {checkpoint_path}") - context = get_scope_context() - context.attributes["checkpoint_path"] = checkpoint_path + update_scope_context({"checkpoint_path": checkpoint_path}) return read_jsonl(checkpoint_path) @@ -78,8 +77,7 @@ async def save_checkpoint_async( results = {k: await v.result_async() for k, v in futures.items()} paths = {k + "_path": v.path for k, v in results.items()} - context = get_scope_context() - context.attributes.update(paths) + update_scope_context(paths) logger.info(f"Saved checkpoints: {paths}") full_dict = {"name": name, **loop_state, **paths} with open(os.path.join(log_path, "checkpoints.jsonl"), "a") as f: diff --git a/tinker_cookbook/distillation/train_on_policy.py b/tinker_cookbook/distillation/train_on_policy.py index 8e9943f6..4ae7d254 100644 --- a/tinker_cookbook/distillation/train_on_policy.py +++ b/tinker_cookbook/distillation/train_on_policy.py @@ -7,13 +7,18 @@ import logging import os import time -from typing import Any, List, Literal, Sequence, Dict, cast +from typing import Any, Dict, List, Literal, Sequence, cast import chz import tinker import torch + from tinker_cookbook import checkpoint_utils from tinker_cookbook.display import colorize_example +from tinker_cookbook.distillation.datasets import ( + CompositeDataset, + DistillationDatasetConfig, +) from tinker_cookbook.eval.evaluators import SamplingClientEvaluator, SamplingClientEvaluatorBuilder from tinker_cookbook.rl.data_processing import ( assemble_training_data, @@ -21,6 +26,12 @@ ) from tinker_cookbook.rl.metric_util import RLTestSetEvaluator, compute_trajectory_metrics from tinker_cookbook.rl.metrics import discounted_future_sum_vectorized +from tinker_cookbook.rl.train import ( + compute_full_batch_metrics_and_get_sampling_client, + do_group_rollout_and_filter_constant_reward, + save_checkpoint_and_get_sampling_client, + train_step, +) from tinker_cookbook.rl.types import ( EnvGroupBuilder, TrajectoryGroup, @@ -28,21 +39,7 @@ from tinker_cookbook.tokenizer_utils import Tokenizer from tinker_cookbook.utils import ml_log from tinker_cookbook.utils.misc_utils import safezip, timed -from tinker_cookbook.utils.trace import scope, get_scope_context, trace_init - -# Dataset configuration classes -from tinker_cookbook.distillation.datasets import ( - CompositeDataset, - DistillationDatasetConfig, -) - -# We re-use these methods from the RL training recipe -from tinker_cookbook.rl.train import ( - save_checkpoint_and_get_sampling_client, - train_step, - compute_full_batch_metrics_and_get_sampling_client, - do_group_rollout_and_filter_constant_reward, -) +from tinker_cookbook.utils.trace import scope, update_scope_context, trace_init logger = logging.getLogger(__name__) @@ -228,8 +225,7 @@ async def do_train_step_and_get_sampling_client( dataset_indices_P: List[int], teacher_clients: List[tinker.SamplingClient], ) -> tuple[tinker.SamplingClient, dict[str, Any]]: - context = get_scope_context() - context.attributes["step"] = i_batch + update_scope_context({"step": i_batch}) metrics = {} data_D, prepare_minibatch_metrics = await prepare_minibatch( diff --git a/tinker_cookbook/rl/train.py b/tinker_cookbook/rl/train.py index 4e01fddc..d4e5163d 100644 --- a/tinker_cookbook/rl/train.py +++ b/tinker_cookbook/rl/train.py @@ -7,12 +7,14 @@ import logging import os import time -from typing import Any, Callable, List, Literal, Sequence, Iterator +from contextlib import contextmanager +from typing import Any, Callable, Iterator, List, Literal, Sequence import chz import numpy as np import tinker import torch + from tinker_cookbook import checkpoint_utils from tinker_cookbook.completers import TinkerTokenCompleter from tinker_cookbook.display import colorize_example @@ -39,9 +41,7 @@ from tinker_cookbook.tokenizer_utils import Tokenizer from tinker_cookbook.utils import logtree, ml_log from tinker_cookbook.utils.misc_utils import safezip, split_list, timed -from tinker_cookbook.utils.trace import scope, trace_init, get_scope_context -from contextlib import contextmanager - +from tinker_cookbook.utils.trace import scope, update_scope_context, trace_init logger = logging.getLogger(__name__) @@ -809,8 +809,7 @@ async def do_train_step_streaming_and_get_sampling_client( # Number of groups per minibatch in each optimizer substep groups_per_minibatch = groups_per_substep // cfg.stream_minibatch_config.num_minibatches - context = get_scope_context() - context.attributes["step"] = i_batch + update_scope_context({"step": i_batch}) metrics = {} @@ -904,8 +903,7 @@ async def do_train_step_and_get_sampling_client( env_group_builders_P: Sequence[EnvGroupBuilder], trajectory_groups_P: list[TrajectoryGroup], ) -> tuple[tinker.SamplingClient, dict[str, Any]]: - context = get_scope_context() - context.attributes["step"] = i_batch + update_scope_context({"step": i_batch}) metrics = {} data_D, prepare_minibatch_metrics = await prepare_minibatch( diff --git a/tinker_cookbook/supervised/train.py b/tinker_cookbook/supervised/train.py index f501ce4c..4dd758d3 100644 --- a/tinker_cookbook/supervised/train.py +++ b/tinker_cookbook/supervised/train.py @@ -32,7 +32,7 @@ from tinker_cookbook.utils import ml_log from tinker_cookbook.utils.lr_scheduling import compute_schedule_lr_multiplier from tinker_cookbook.utils.misc_utils import timed -from tinker_cookbook.utils.trace import get_scope_context, scope, trace_init +from tinker_cookbook.utils.trace import scope, update_scope_context, trace_init logger = logging.getLogger(__name__) @@ -107,22 +107,24 @@ async def run_evals( checkpoint. Returned metrics are prefixed with ``test/`` so they can be logged next to the same-step training metrics. """ - context = get_scope_context() - context.attributes["step"] = step + update_scope_context({"step": step}) metrics = {} sampling_client = None @scope async def run_evaluator(evaluator: Evaluator) -> dict[str, float]: - context = get_scope_context() - context.attributes["step"] = step - context.attributes["evaluator_name"] = type(evaluator).__name__ + update_scope_context( + { + "step": step, + "evaluator_name": type(evaluator).__name__, + } + ) if isinstance(evaluator, TrainingClientEvaluator): - context.attributes["evaluator_type"] = "TrainingClientEvaluator" + update_scope_context({"evaluator_type": "TrainingClientEvaluator"}) return await evaluator(training_client) elif isinstance(evaluator, SamplingClientEvaluator): - context.attributes["evaluator_type"] = "SamplingClientEvaluator" + update_scope_context({"evaluator_type": "SamplingClientEvaluator"}) # Create sampling client lazily, only when needed nonlocal sampling_client if sampling_client is None: @@ -225,8 +227,7 @@ async def main(config: Config): @scope async def submit_batch(epoch_idx: int, batch_idx: int) -> SubmittedBatch: step = epoch_idx * n_batches + batch_idx - context = get_scope_context() - context.attributes["step"] = step + update_scope_context({"step": step}) batch_start_time = time.time() metrics: dict[str, int | float | str] = {"epoch": epoch_idx} @@ -286,8 +287,7 @@ async def submit_batch(epoch_idx: int, batch_idx: int) -> SubmittedBatch: @scope async def finish_batch(submitted: SubmittedBatch): - context = get_scope_context() - context.attributes["step"] = submitted.step + update_scope_context({"step": submitted.step}) metrics = submitted.metrics metrics["progress"] = min((submitted.step + 1) / progress_denominator, 1.0) diff --git a/tinker_cookbook/tests/test_trace.py b/tinker_cookbook/tests/test_trace.py index 50afde4a..a0dd2f89 100644 --- a/tinker_cookbook/tests/test_trace.py +++ b/tinker_cookbook/tests/test_trace.py @@ -1,8 +1,15 @@ -import json -from tinker_cookbook.utils.trace import scope, trace_init, trace_shutdown, get_scope_context import asyncio -import threading +import json import tempfile +import threading + +from tinker_cookbook.utils.trace import ( + get_scope_context, + scope, + update_scope_context, + trace_init, + trace_shutdown, +) @scope @@ -30,8 +37,7 @@ def ced(): @scope async def baz(): await asyncio.sleep(0.02) - context = get_scope_context() - context.attributes["baz"] = "baz" + update_scope_context({"baz": "baz"}) ced() diff --git a/tinker_cookbook/utils/trace.py b/tinker_cookbook/utils/trace.py index ee9c896b..6e735e5b 100644 --- a/tinker_cookbook/utils/trace.py +++ b/tinker_cookbook/utils/trace.py @@ -1,4 +1,6 @@ +import argparse import asyncio +import atexit import functools import inspect import json @@ -6,12 +8,10 @@ import threading import time from contextvars import ContextVar -from typing import Any, Callable from dataclasses import dataclass, field from enum import Enum -import argparse from io import TextIOWrapper -import atexit +from typing import Any, Callable class EventType(str, Enum): @@ -407,6 +407,20 @@ async def foo(): return result +def update_scope_context(values: dict[str, Any]) -> None: + """Update the current scope's context. Example usage: + + @scope + async def foo(step: int): + update_scope_context({"step": step}) + await bar() + + """ + result = trace_context.get(ScopeContext()) + assert result is not None, "Trace context is not set" + result.attributes.update(values) + + def convert_jsonl_to_json_main(): """Helper script to convert the trace events format into a visualizable format""" parser = argparse.ArgumentParser(