Skip to content

Commit 8546ee4

Browse files
JenniferWangfacebook-github-bot
authored andcommitted
Bug fix for dropping episodes in the GRPO
Summary: ## Bug Description: meta-pytorch#580 had incorrect indentation cuasing the input_ids, episodes varibles to be deleted inside the episodes building loop, causing program to hang. Next diff shall make background thread crashes to be surfaced to the main thread so that we know what thread crashed for what reason. Reviewed By: daniellepintz Differential Revision: D87554570
1 parent d8f420b commit 8546ee4

File tree

1 file changed

+17
-17
lines changed

1 file changed

+17
-17
lines changed

apps/grpo/main.py

Lines changed: 17 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -424,23 +424,23 @@ async def continuous_rollouts():
424424
input_ids[i, :max_req_tokens] = episode.request_tensor
425425
input_ids[i, max_req_tokens:] = episode.response_tensor
426426

427-
# drop episodes if
428-
# 1> reward std-dev is very small (including all 0s and all 1s)
429-
# 2> response is potentially truncated (response_len >= max_res_tokens)
430-
rewards = [e.reward for e in episodes]
431-
rewards_std = torch.std(torch.tensor(rewards))
432-
max_response_len = max(
433-
e.completion.token_ids.shape[0] for e in episodes
434-
)
435-
drop = rewards_std < 1e-3 or max_response_len >= max_res_tokens
436-
record_metric(
437-
"main/continuous_rollouts/dropped_episodes",
438-
1 if drop else 0,
439-
Reduce.SUM,
440-
)
441-
if drop:
442-
del input_ids, episodes
443-
continue
427+
# drop episodes if
428+
# 1> reward std-dev is very small (including all 0s and all 1s)
429+
# 2> response is potentially truncated (response_len >= max_res_tokens)
430+
rewards = [e.reward for e in episodes]
431+
rewards_std = torch.std(torch.tensor(rewards))
432+
max_response_len = max(
433+
e.completion.token_ids.shape[0] for e in episodes
434+
)
435+
drop = rewards_std < 1e-3 or max_response_len >= max_res_tokens
436+
record_metric(
437+
"main/continuous_rollouts/dropped_episodes",
438+
1 if drop else 0,
439+
Reduce.SUM,
440+
)
441+
if drop:
442+
del input_ids, episodes
443+
continue
444444

445445
t.step("reward_evaluation")
446446

0 commit comments

Comments
 (0)