
Commit 95b82c8

Intaik authored and facebook-github-bot committed
drop episodes with 0 advantages or truncated (#580)
Summary: Episodes where all rewards are 0 or all are 1 do not help learning, since the group-normalized advantage would be 0. Also, episodes whose generations are truncated at max_res_tokens mostly get a reward of 0 unnecessarily, as most answers appear at the end of the response. Dropping these episodes gives the trainer better batches to learn from (at the cost of sampling efficiency).

Reviewed By: casteryh

Differential Revision: D87243621
1 parent 5daec1b commit 95b82c8
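
Why a zero-variance reward group yields zero advantage: GRPO normalizes rewards within each group of rollouts for the same prompt, so if every reward is identical the centered values, and hence all advantages, are exactly 0 and the group contributes no gradient. A minimal sketch of that arithmetic (the `group_advantages` helper is illustrative, assuming the usual GRPO group normalization; the exact formula and epsilon in apps/grpo/main.py may differ):

```python
import torch

def group_advantages(rewards: torch.Tensor, eps: float = 1e-4) -> torch.Tensor:
    # Center and scale rewards within one group of rollouts for a prompt.
    return (rewards - rewards.mean()) / (rewards.std() + eps)

print(group_advantages(torch.tensor([0.0, 1.0, 1.0, 0.0])))  # nonzero: usable signal
print(group_advantages(torch.ones(4)))   # all zeros: nothing to learn
print(group_advantages(torch.zeros(4)))  # all zeros: nothing to learn
```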

File tree

1 file changed: +23 −1 lines

apps/grpo/main.py

Lines changed: 23 additions & 1 deletion
```diff
@@ -31,7 +31,6 @@
 from forge.observability.metric_actors import get_or_create_metric_logger
 from forge.observability.metrics import record_metric, Reduce
 from forge.observability.perf_tracker import Tracer
-
 from forge.types import LauncherConfig, ProvisionerConfig
 from forge.util.config import parse
 from forge.util.ops import compute_logprobs
```
```diff
@@ -250,6 +249,11 @@ async def sample(self) -> dict[str, str] | None:
             len(sample["request"]),
             Reduce.MEAN,
         )
+        record_metric(
+            "dataset/sample/max_sample_len",
+            len(sample["request"]),
+            Reduce.MAX,
+        )
         record_metric("dataset/sample/current_epoch", self._epoch, Reduce.MAX)

         return sample
```
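
The new MAX-reduced metric complements the existing MEAN reduction over the same value: a mean hides the outlier prompt lengths that are most likely to collide with the token budget. A plain-Python illustration of the two aggregations (the Reduce.MEAN / Reduce.MAX semantics are assumed from their names; the forge metrics backend may aggregate over a window or across ranks differently, and the sample lengths below are made up):

```python
# Hypothetical window of prompt lengths, aggregated two ways.
sample_lens = [312, 480, 295, 1021, 366]

mean_len = sum(sample_lens) / len(sample_lens)  # what the MEAN metric reports
max_len = max(sample_lens)                      # what the new MAX metric reports

print(f"mean={mean_len:.1f} max={max_len}")  # mean=494.8 max=1021
```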
```diff
@@ -396,6 +400,24 @@ async def continuous_rollouts():
             input_ids[i, :max_req_tokens] = episode.request_tensor
             input_ids[i, max_req_tokens:] = episode.response_tensor

+        # drop episodes if
+        # 1> reward std-dev is very small (including all 0s and all 1s)
+        # 2> response is potentially truncated (response_len >= max_res_tokens)
+        rewards = [e.reward for e in episodes]
+        rewards_std = torch.std(torch.tensor(rewards))
+        max_response_len = max(
+            e.completion.token_ids.shape[0] for e in episodes
+        )
+        drop = rewards_std < 1e-3 or max_response_len >= max_res_tokens
+        record_metric(
+            "main/continuous_rollouts/dropped_episodes",
+            1 if drop else 0,
+            Reduce.SUM,
+        )
+        if drop:
+            del input_ids, episodes
+            continue
+
         t.step("reward_evaluation")

         ref_logprobs = await ref_model.forward.route(
```
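
Read as a pure predicate over one rollout group, the added logic reduces to the sketch below (`should_drop` and its argument names are illustrative, not code from the repo; the 1e-3 threshold and the `>= max_res_tokens` check mirror the diff):

```python
import torch

def should_drop(
    rewards: list[float],
    response_lens: list[int],
    max_res_tokens: int,
    std_eps: float = 1e-3,
) -> bool:
    # 1) near-zero reward std-dev (covers all-0 and all-1 groups):
    #    group-normalized advantages would all be ~0, so there is no signal
    degenerate = torch.std(torch.tensor(rewards)) < std_eps
    # 2) any response at the token budget is potentially truncated, and a
    #    cut-off answer tends to score 0 through no fault of the policy
    truncated = max(response_lens) >= max_res_tokens
    return bool(degenerate or truncated)

print(should_drop([1.0, 1.0, 1.0, 1.0], [120, 98, 300, 210], 512))  # True: no variance
print(should_drop([0.0, 1.0, 0.0, 1.0], [120, 98, 512, 210], 512))  # True: truncation
print(should_drop([0.0, 1.0, 0.0, 1.0], [120, 98, 300, 210], 512))  # False: kept
```

Note that the diff records `1 if drop else 0` once per rollout iteration with Reduce.SUM, so `dropped_episodes` effectively counts dropped groups rather than individual episodes.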

0 commit comments