Skip to content

Commit 41f93a8

Browse files
Andrewzh112joschu
andauthored
fix: on-policy distillation missing temperature parameter fix (#103)
Co-authored-by: John Schulman <[email protected]>
1 parent f996a59 commit 41f93a8

File tree

3 files changed

+4
-0
lines changed

3 files changed

+4
-0
lines changed

tinker_cookbook/distillation/train_on_policy.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -132,6 +132,7 @@ class Config:
132132
dataset_configs: List[DistillationDatasetConfig]
133133
model_name: str
134134
max_tokens: int
135+
temperature: float = 1.0
135136
compute_post_kl: bool = False
136137
evaluator_builders: list[SamplingClientEvaluatorBuilder] = chz.field(default_factory=list)
137138
lora_rank: int = 32
@@ -308,6 +309,7 @@ async def do_sync_training(
308309
do_group_rollout_and_filter_constant_reward(
309310
sampling_client,
310311
builder,
312+
temperature=cfg.temperature,
311313
max_tokens=cfg.max_tokens,
312314
do_remove_constant_reward_groups=False,
313315
),

tinker_cookbook/recipes/distillation/on_policy_distillation.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,7 @@ class CLIConfig:
6464
groups_per_batch: int = 1024
6565
learning_rate: float = 1e-4
6666
max_tokens: int = 4096
67+
temperature: float = 1.0
6768
kl_penalty_coef: float = 1.0
6869
kl_discount_factor: float = 0.0
6970

tinker_cookbook/recipes/distillation/on_policy_multi_teacher.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,7 @@ class CLIConfig:
5858
group_size: int = 4 # Number of rollouts per prompt
5959
learning_rate: float = 1e-4
6060
max_tokens: int = 4096
61+
temperature: float = 1.0
6162
kl_penalty_coef: float = 1.0
6263
kl_discount_factor: float = 0.0
6364

0 commit comments

Comments
 (0)