Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,9 @@ verifiers = [
 "verifiers",
 "openai",
 ]
+all = [
+"tinker_cookbook[dev,vector-search,wandb,neptune-scale,trackio,verifiers]",
+]

 [build-system]
 requires = ["hatchling"]
Expand Down
16 changes: 8 additions & 8 deletions tinker_cookbook/recipes/verifiers_rl/evaluate.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,25 +30,25 @@ def log_results(
 print(f"Examples: {num_examples}")
 print(f"Rollouts per example: {rollouts_per_example}")
 print("--- Example ---")
-printable_prompts = [messages_to_printable(p) for p in results.prompt]
-printable_completions = [messages_to_printable(c) for c in results.completion]
+printable_prompts = [messages_to_printable(p) for p in results["prompt"]]
+printable_completions = [messages_to_printable(c) for c in results["completion"]]
 vf.print_prompt_completions_sample(
-printable_prompts, printable_completions, results.reward, step=0
+printable_prompts, printable_completions, results["reward"], step=0
 )
 print("--- All ---")
 print("Rewards:")
 print(
-f"reward: avg - {sum(results.reward) / len(results.reward):.3f}, std - {np.std(results.reward):.3f}"
+f"reward: avg - {sum(results['reward']) / len(results['reward']):.3f}, std - {np.std(results['reward']):.3f}"
 )
 r = rollouts_per_example
-n = len(results.reward) // r
+n = len(results["reward"]) // r
 for i in range(r):
 # rounded to 3 decimal places
-trials = [round(results.reward[(i * n) + j], 3) for j in range(n)]
+trials = [round(results["reward"][(i * n) + j], 3) for j in range(n)]
 out = f"r{i + 1}: {trials}"
 print(out)
-for k in results.metrics:
-v = results.metrics[k]
+for k in results["metrics"]:
+v = results["metrics"][k]
 print(f"{k}: avg - {sum(v) / len(v):.3f}, std - {np.std(v):.3f}")
 for i in range(r):
 # rounded to 3 decimal places
Expand Down
27 changes: 15 additions & 12 deletions tinker_cookbook/recipes/verifiers_rl/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
 import json
 import tinker
 import verifiers as vf
+from verifiers.utils.async_utils import maybe_semaphore
 from tinker_cookbook import cli_utils, model_info, renderers
 from tinker_cookbook.completers import TokensWithLogprobs, TokenCompleter, TinkerTokenCompleter
 from tinker_cookbook.recipes.verifiers_rl.tinker_openai import TinkerAsyncOpenAIClient
Expand Down Expand Up @@ -105,24 +106,26 @@ def hook(messages, model_input, tokens, logprobs):
 )
 local_client.set_generation_hook(hook)

-completion, state = await builder.vf_env.rollout(
+rollout_input: vf.RolloutInput = {
+"prompt": builder.prompt,
+"answer": builder.answer,
+"task": builder.task,
+"info": builder.info,
+"example_id": 0,
+}
+state = await builder.vf_env.rollout(
+input=rollout_input,
 client=local_client,
 model="tinker",
-prompt=builder.prompt,
-answer=builder.answer,
-task=builder.task,
-info=builder.info,
 sampling_args={},
 )

-rs = await builder.vf_env.rubric.score_rollout(
-prompt=builder.prompt,
-completion=completion,
-answer=builder.answer,
+score_sem = await maybe_semaphore(None)
+await builder.vf_env.rubric.score_rollout(
 state=state,
-task=builder.task,
-info=builder.info,
+score_sem=score_sem,
 )
+rs: vf.RolloutScore = {"reward": state["reward"], "metrics": state.get("metrics", {})}

 transitions: List[Transition] = []
 for _msgs, model_input, tokens, logprobs in recorded:
Expand All @@ -144,7 +147,7 @@ def hook(messages, model_input, tokens, logprobs):
 metrics=transitions[-1].metrics,
 )
 traj = Trajectory(transitions=transitions, final_ob=tinker.ModelInput.empty())
-return traj, float(rs.reward), dict(rs.metrics)
+return traj, float(rs["reward"]), dict(rs["metrics"])

 results = await asyncio.gather(*[run_one_rollout() for _ in range(cli_config.group_size)])
 trajectories_G = [t for (t, _r, _m) in results]
Expand Down