Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions tinker_cookbook/recipes/tool_use/search/embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from os import environ
from typing import Any

from google import genai
import google.genai as genai
from google.genai import types

logger = getLogger(__name__)
Expand All @@ -25,7 +25,7 @@ def get_gemini_client(
http_options: types.HttpOptions | None = None,
**kwargs: Any,
) -> genai.Client:
from google import genai
import google.genai as genai
from google.genai.types import HttpOptions

project = project or environ.get("GCP_VERTEXAI_PROJECT_NUMBER")
Expand Down
2 changes: 1 addition & 1 deletion tinker_cookbook/recipes/tool_use/search/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from chromadb.api import AsyncClientAPI
from chromadb.api.types import QueryResult
from chromadb.config import Settings
from google import genai
import google.genai as genai
from tinker_cookbook.recipes.tool_use.search.embedding import (
get_gemini_client,
get_gemini_embedding,
Expand Down
16 changes: 8 additions & 8 deletions tinker_cookbook/recipes/verifiers_rl/evaluate.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,25 +30,25 @@ def log_results(
print(f"Examples: {num_examples}")
print(f"Rollouts per example: {rollouts_per_example}")
print("--- Example ---")
printable_prompts = [messages_to_printable(p) for p in results.prompt]
printable_completions = [messages_to_printable(c) for c in results.completion]
printable_prompts = [messages_to_printable(p) for p in results["prompt"]]
printable_completions = [messages_to_printable(c) for c in results["completion"]]
vf.print_prompt_completions_sample(
printable_prompts, printable_completions, results.reward, step=0
printable_prompts, printable_completions, results["reward"], step=0
)
print("--- All ---")
print("Rewards:")
print(
f"reward: avg - {sum(results.reward) / len(results.reward):.3f}, std - {np.std(results.reward):.3f}"
f"reward: avg - {sum(results['reward']) / len(results['reward']):.3f}, std - {np.std(results['reward']):.3f}"
)
r = rollouts_per_example
n = len(results.reward) // r
n = len(results["reward"]) // r
for i in range(r):
# rounded to 3 decimal places
trials = [round(results.reward[(i * n) + j], 3) for j in range(n)]
trials = [round(results["reward"][(i * n) + j], 3) for j in range(n)]
out = f"r{i + 1}: {trials}"
print(out)
for k in results.metrics:
v = results.metrics[k]
for k in results["metrics"]:
v = results["metrics"][k]
print(f"{k}: avg - {sum(v) / len(v):.3f}, std - {np.std(v):.3f}")
for i in range(r):
# rounded to 3 decimal places
Expand Down
27 changes: 15 additions & 12 deletions tinker_cookbook/recipes/verifiers_rl/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import json
import tinker
import verifiers as vf
from verifiers.utils.async_utils import maybe_semaphore
from tinker_cookbook import cli_utils, model_info, renderers
from tinker_cookbook.completers import TokensWithLogprobs, TokenCompleter, TinkerTokenCompleter
from tinker_cookbook.recipes.verifiers_rl.tinker_openai import TinkerAsyncOpenAIClient
Expand Down Expand Up @@ -105,24 +106,26 @@ def hook(messages, model_input, tokens, logprobs):
)
local_client.set_generation_hook(hook)

completion, state = await builder.vf_env.rollout(
rollout_input: vf.RolloutInput = {
"prompt": builder.prompt,
"answer": builder.answer,
"task": builder.task,
"info": builder.info,
"example_id": 0,
}
state = await builder.vf_env.rollout(
input=rollout_input,
client=local_client,
model="tinker",
prompt=builder.prompt,
answer=builder.answer,
task=builder.task,
info=builder.info,
sampling_args={},
)

rs = await builder.vf_env.rubric.score_rollout(
prompt=builder.prompt,
completion=completion,
answer=builder.answer,
score_sem = await maybe_semaphore(None)
await builder.vf_env.rubric.score_rollout(
state=state,
task=builder.task,
info=builder.info,
score_sem=score_sem,
)
rs: vf.RolloutScore = {"reward": state["reward"], "metrics": state.get("metrics", {})}

transitions: List[Transition] = []
for _msgs, model_input, tokens, logprobs in recorded:
Expand All @@ -144,7 +147,7 @@ def hook(messages, model_input, tokens, logprobs):
metrics=transitions[-1].metrics,
)
traj = Trajectory(transitions=transitions, final_ob=tinker.ModelInput.empty())
return traj, float(rs.reward), dict(rs.metrics)
return traj, float(rs["reward"]), dict(rs["metrics"])

results = await asyncio.gather(*[run_one_rollout() for _ in range(cli_config.group_size)])
trajectories_G = [t for (t, _r, _m) in results]
Expand Down