Skip to content

Commit 2ec1bcb

Browse files
Xiuyu-Li and joschu authored
Add configurable temperature parameter for RL rollout sampling (#86)
Co-authored-by: John Schulman <[email protected]>
1 parent 20e26a6 commit 2ec1bcb

File tree

4 files changed

+382
-2
lines changed

4 files changed

+382
-2
lines changed

tinker_cookbook/completers.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,7 @@ class TinkerTokenCompleter(TokenCompleter):
5454

5555
sampling_client: tinker.SamplingClient
5656
max_tokens: int
57+
temperature: float = 1.0
5758

5859
async def __call__(
5960
self, model_input: tinker.ModelInput, stop: StopCondition
@@ -63,7 +64,11 @@ async def __call__(
6364
sample_result = await self.sampling_client.sample_async(
6465
prompt=model_input,
6566
num_samples=1,
66-
sampling_params=tinker.SamplingParams(stop=stop, max_tokens=self.max_tokens),
67+
sampling_params=tinker.SamplingParams(
68+
stop=stop,
69+
max_tokens=self.max_tokens,
70+
temperature=self.temperature,
71+
),
6772
)
6873

6974
# Extract tokens and logprobs from the first (and only) sample

tinker_cookbook/recipes/math_rl/train.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@ class CLIConfig:
3434
groups_per_batch: int = 100
3535
learning_rate: float = 1e-5
3636
max_tokens: int = 5
37+
temperature: float = 1.0
3738
kl_penalty_coef: float = 0.0
3839

3940
# Number of optimizer steps per training iteration.
@@ -124,6 +125,7 @@ async def cli_main(cli_config: CLIConfig):
124125
model_name=cli_config.model_name,
125126
lora_rank=cli_config.lora_rank,
126127
max_tokens=cli_config.max_tokens,
128+
temperature=cli_config.temperature,
127129
wandb_project=cli_config.wandb_project,
128130
wandb_name=wandb_name,
129131
log_path=log_path,

tinker_cookbook/rl/train.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -229,6 +229,7 @@ class Config:
229229
dataset_builder: RLDatasetBuilder # also determines batch size
230230
model_name: str
231231
max_tokens: int
232+
temperature: float = 1.0 # Changing sampling temperature is not generally recommended; does not currently play well with KL penalty
232233
compute_post_kl: bool = False
233234
evaluator_builders: list[SamplingClientEvaluatorBuilder] = chz.field(default_factory=list)
234235
lora_rank: int = 32
@@ -366,6 +367,7 @@ async def trajectory_group_worker_task(
366367
sampling_client,
367368
builder,
368369
max_tokens=cfg.max_tokens,
370+
temperature=cfg.temperature,
369371
do_remove_constant_reward_groups=cfg.remove_constant_reward_groups,
370372
enable_logging=enable_logging,
371373
)
@@ -501,6 +503,7 @@ async def trajectory_group_worker_loop():
501503
sampling_client,
502504
env_group_builder,
503505
max_tokens=cfg.max_tokens,
506+
temperature=cfg.temperature,
504507
do_remove_constant_reward_groups=cfg.remove_constant_reward_groups,
505508
)
506509
if trajectory_group is None:
@@ -659,10 +662,11 @@ async def do_group_rollout_and_filter_constant_reward(
659662
sampling_client: tinker.SamplingClient,
660663
env_group_builder: EnvGroupBuilder,
661664
max_tokens: int,
665+
temperature: float,
662666
do_remove_constant_reward_groups: bool,
663667
enable_logging: bool = True,
664668
) -> TrajectoryGroup | None:
665-
policy = TinkerTokenCompleter(sampling_client, max_tokens=max_tokens)
669+
policy = TinkerTokenCompleter(sampling_client, max_tokens=max_tokens, temperature=temperature)
666670

667671
with logtree.optional_enable_logging(enable_logging):
668672
trajectory_group = await do_group_rollout(env_group_builder, policy)
@@ -991,6 +995,7 @@ async def do_sync_training(
991995
sampling_client,
992996
builder,
993997
max_tokens=cfg.max_tokens,
998+
temperature=cfg.temperature,
994999
do_remove_constant_reward_groups=cfg.remove_constant_reward_groups,
9951000
enable_logging=i < cfg.num_groups_to_log,
9961001
),

0 commit comments

Comments (0)