implement message level rollout with linear trajectories #6250

Open
AmineDiro wants to merge 5 commits into main from linear-trajectory

@AmineDiro AmineDiro commented Jul 2, 2026

AsyncGRPO: message-mode rollouts

Adds an opt-in way to build training rows from a multi-turn conversation.

Message mode keeps the conversation as messages and re-tokenizes the whole thing each turn, then checks whether the fresh tokens still start with the tokens held so far. If yes → append the new part (same as token mode). If no → a rewrite happened → close the row and open a new one that matches what the model actually read.

config

AsyncGRPOConfig(
    rollout_protocol="message",   # "token" (default) | "message"
    fork_threshold_tokens=1024,   # message mode only
)

How rows are built

Each turn is a TurnRecord(prompt_ids, output_ids, output_log_probs). At the end, _chain_to_sequences walks the turns and classifies each turn's drift against the tokens held so far:

  • CLEAN — new prompt starts with held tokens → append (prompt/tool = context, generated = trained).
  • REALIGN — only the last answer's tail wobbled and the new turn < fork_threshold_tokens → overwrite that tail as context, same row.
  • FORK — a real rewrite → start a new row.
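The three cases above can be sketched as a prefix comparison. This is a minimal illustration, not the PR's code: `Drift`, `common_prefix_len`, `classify_turn` and the `last_answer_start` / `turn_len` parameters are hypothetical names, and reading "the new turn < fork_threshold_tokens" as the new turn's token count is an assumption.

```python
from enum import Enum


class Drift(Enum):
    CLEAN = "clean"
    REALIGN = "realign"
    FORK = "fork"


def common_prefix_len(a: list[int], b: list[int]) -> int:
    """Length of the longest shared prefix of two token lists."""
    n = 0
    for x, y in zip(a, b):
        if x != y:
            break
        n += 1
    return n


def classify_turn(held, last_answer_start, new_prompt, turn_len,
                  fork_threshold_tokens=1024):
    """Hypothetical classifier; the PR's exact rules may differ.

    held:              tokens accumulated on the current row
    last_answer_start: index in `held` where the last generated answer begins
    new_prompt:        freshly re-tokenized prompt for this turn
    turn_len:          token count of the new turn (assumed meaning)
    """
    k = common_prefix_len(held, new_prompt)
    if k == len(held):
        return Drift.CLEAN  # new prompt starts with everything held so far
    if k >= last_answer_start and turn_len < fork_threshold_tokens:
        return Drift.REALIGN  # divergence sits inside the last answer only
    return Drift.FORK  # earlier history was rewritten


held = [1, 2, 3, 4]  # the last answer occupies indices 2..3
print(classify_turn(held, 2, [1, 2, 3, 4, 5], turn_len=10))  # Drift.CLEAN
print(classify_turn(held, 2, [1, 2, 3, 9, 5], turn_len=10))  # Drift.REALIGN
print(classify_turn(held, 2, [1, 9, 3, 4, 5], turn_len=10))  # Drift.FORK
```

In this sketch the divergence point alone separates CLEAN from the other two cases; the threshold only decides between REALIGN and FORK.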

Advantage: one per conversation, stamped on every row it produced, no split. Under the token-mean loss a fork is invisible (each generated token trained once, same advantage, same denominator).
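A small worked example of why a fork does not change the token-mean loss. It is a sketch only: `token_mean_loss` is a hypothetical helper rather than the trainer's loss function, and the per-token terms are held fixed across the two layouts purely to isolate the aggregation.

```python
import math


def token_mean_loss(rows):
    """Token-mean over a batch.

    rows: list of (per_token_terms, completion_mask, advantage).
    Hypothetical helper for illustration only.
    """
    numerator, trained_tokens = 0.0, 0
    for terms, mask, advantage in rows:
        for term, m in zip(terms, mask):
            numerator += advantage * term * m
            trained_tokens += m
    return numerator / trained_tokens


ADV = 2.0  # one advantage per conversation, stamped on every row

# One row: context, answer 1, context, answer 2.
single = [([0.0, 0.5, 0.0, 0.25], [0, 1, 0, 1], ADV)]

# Same conversation after a FORK: answer 1 is trained on row A; row B
# reads the rewritten history as context and trains only answer 2.
forked = [
    ([0.0, 0.5], [0, 1], ADV),
    ([0.0, 0.0, 0.25], [0, 0, 1], ADV),
]

assert math.isclose(token_mean_loss(single), token_mean_loss(forked))
```

Both layouts sum the same two generated tokens with the same advantage and divide by the same count of trained tokens, so the result is identical.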

Next work: Tree trajectories

The design already covers the two hard parts:

  • Scoring is branch-agnostic. Every TrainingSequence carries a rollout_id and _score_group groups by it, not by list position. N rows per conversation already works today; a tree that yields several rows per conversation needs no scoring change.
  • The reconciler is tree-agnostic. _common_prefix_len / _SampleBuilder / _chain_to_sequences reconcile a single linear chain of turns. A root→leaf path in a tree is such a chain.

So the only change a tree adds is on the rollout side: replace the flat turns: list[TurnRecord] (+ one _chain_to_sequences call) with a tree of turns, then run the reconciler once per root→leaf path, with all rows sharing the conversation's rollout_id. Recording would carry the turn's message context so each turn can be placed under its parent node (this is what lets shared prefixes branch and stay trained once). Nothing downstream (collator, loss, scoring) changes.
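As a rough sketch of the "once per root→leaf path" idea, the snippet below enumerates the paths of a small tree. `TurnNode` and `root_to_leaf_paths` are hypothetical names introduced for illustration; the PR does not define a tree type yet.

```python
from dataclasses import dataclass, field


@dataclass
class TurnNode:
    """Hypothetical tree node wrapping one turn; not a type from the PR."""
    name: str
    children: list["TurnNode"] = field(default_factory=list)


def root_to_leaf_paths(root):
    """Yield every root->leaf path as a list of nodes (one linear chain each)."""
    stack = [(root, [root])]
    while stack:
        node, path = stack.pop()
        if not node.children:
            yield path
            continue
        # Push in reverse so children are visited in their original order.
        for child in reversed(node.children):
            stack.append((child, path + [child]))


# root -> a -> c and root -> b
tree = TurnNode("root", [TurnNode("a", [TurnNode("c")]), TurnNode("b")])
paths = [[n.name for n in p] for p in root_to_leaf_paths(tree)]
print(paths)  # [['root', 'a', 'c'], ['root', 'b']]
```

Each yielded path would then be handed to the existing linear reconciler, with every resulting row carrying the conversation's rollout_id. How a shared prefix is trained only once across paths is not shown here.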


Note

Medium Risk
Touches multi-turn rollout → training-sample mapping and advantage stamping for GRPO; default rollout_protocol="token" limits blast radius, but message mode can change which tokens are trained when chat templates drift.

Overview
Adds an opt-in message rollout path for AsyncGRPO alongside the default token buffer mode, controlled by rollout_protocol and fork_threshold_tokens on AsyncGRPOConfig.

In message mode, MessageRolloutLoop re-tokenizes the full conversation each turn and runs _chain_to_sequences to convert the turn records into one or more TrainingSequence rows: CLEAN appends on one row, REALIGN treats small last-answer wobble as context, FORK starts a new row when history rewrites. Token mode is unchanged in behavior but now emits a single TrainingSequence per conversation.

Rollout groups carry completions_sequences instead of flat logprob/mask lists; _score_group expands each conversation into multiple RolloutSamples while stamping the same conversation-level advantage on every forked row (metrics get a per-row copy). The trainer picks MessageRolloutLoop vs AsyncRolloutLoop and passes loop_cls into the spawned worker. Tests cover the reconciler, message loop, and scoring.

Reviewed by Cursor Bugbot for commit 7a354d6. Bugbot is set up for automated code reviews on this repo. Configure here.

@qgallouedec qgallouedec left a comment

cool, thanks! discussed internally. I'm sharing the figure I made

Image

Comment on lines +59 to +67
rollout_protocol (`str`, *optional*, defaults to `"token"`):
    How a multi-turn conversation is turned into training rows. `"token"` grows a token buffer, appending each
    turn's generated tokens and tokenized tool results (fast; cannot represent a conversation rewrite).
    `"message"` re-tokenizes the whole conversation every turn and reconciles the result against the tokens held
    so far: a clean append stays one row, a rewrite (dropped reasoning, summarized history) forks a new row.
fork_threshold_tokens (`int`, *optional*, defaults to `1024`):
    Message mode only. When a turn's re-tokenized prompt drifts inside the last generated answer, a drift with a
    generated turn shorter than this many tokens is treated as a re-tokenization wobble (realigned as context)
    rather than a rewrite (a new row). Ignored when `rollout_protocol="token"`.

Member

I'm going to advocate for simplicity here: because it might become impossible to maintain if we continue to support every possible configuration. What do you think about supporting only the message protocol? This trainer is still experimental; we don't need to be backward compatible, so it's a good time to make bold and radical decisions.

Member Author

I was thinking the same thing. The only thing that bothers me is the performance penalty of message mode :/ I need to measure it to be sure 🫡.

But I agree with the idea 1000%

@qgallouedec qgallouedec Jul 2, 2026

Yes, I understand the concern. With tokens, you get G sequences per prompt. With messages, you end up with up to G * max_num_turns in the worst case.

@bot-ci-comment

bot-ci-comment Bot commented Jul 2, 2026

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

Member

to follow the repo structure, we should have only one test_async_grpo_trainer.py

@cursor cursor Bot left a comment

Cursor Bugbot has reviewed your changes using default effort and found 2 potential issues.

Reviewed by Cursor Bugbot for commit 9dc7d79. Configure here.

tool_failure_count += n_failures
completion.extend(tool_messages)
messages.extend(tool_messages) # tool result goes back as a MESSAGE, re-tokenized next turn
iteration_num += 1

Empty tool calls loop forever

High Severity

In MessageRolloutLoop, the exit check only treats missing tool_calls as terminal (tool_calls is None). An assistant message with tool_calls set to an empty list is treated as a tool turn: no messages are appended, iteration_num advances, and the loop re-tokenizes and generates again with identical input. With no iteration cap, this can spin indefinitely in the rollout worker.

Member

valid defensive concern, real hang path, but a genuine edge case.
Worth a one-line if not tool_calls: in both loops?
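A minimal sketch of the suggested guard, assuming a loop shaped roughly like the one in the diff. `run_turns`, `generate`, `run_tools` and the `max_iterations` cap are hypothetical stand-ins, not the PR's API; the cap is an extra safety net that goes beyond the one-line fix.

```python
def run_turns(generate, run_tools, messages, max_iterations=8):
    """Hypothetical loop skeleton illustrating the `if not tool_calls` exit."""
    for _ in range(max_iterations):  # assumed cap, not part of the PR
        assistant = generate(messages)
        messages.append(assistant)
        tool_calls = assistant.get("tool_calls")
        if not tool_calls:  # None and [] both end the conversation
            break
        messages.extend(run_tools(tool_calls))
    return messages


calls = []


def fake_generate(msgs):
    """Stand-in model that always returns an empty tool_calls list."""
    calls.append(len(msgs))
    return {"role": "assistant", "content": "done", "tool_calls": []}


out = run_turns(fake_generate, lambda tc: [], [{"role": "user", "content": "q"}])
print(len(calls), len(out))  # 1 2
```

With an `is None` check the empty list would count as a tool turn; with `not tool_calls` the loop exits after a single generation.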

        builders.append(builder)
    else:
        builders[-1].append_turn(turn, kind)
return [b.to_training_sequence(rollout_id) for b in builders if b.has_trained_token()]

Filtered rows skew GRPO advantages

Medium Severity

_chain_to_sequences drops reconciled builders whose completion_mask contains no ones (no trained tokens), so message-mode rollouts can yield zero TrainingSequences for a conversation (e.g. an empty generation turn). _score_group still computes that conversation’s reward and folds it into the group mean/std for advantages, but emits no RolloutSamples, unlike token mode, which always enqueues one row per generation.
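One way to see the effect, as a sketch under stated assumptions: `group_advantages` is a hypothetical helper, and the population standard deviation and `eps` value are illustrative choices, not the trainer's exact normalization.

```python
from statistics import mean, pstdev


def group_advantages(rewards, eps=1e-4):
    """Group-normalized advantages; std flavour and eps are assumptions."""
    mu, sigma = mean(rewards), pstdev(rewards)
    return [(r - mu) / (sigma + eps) for r in rewards]


rewards = [1.0, 1.0, 0.0, 0.0]  # one reward per conversation in the group
rows_emitted = [1, 2, 0, 1]     # third conversation produced no trainable row

advantages = group_advantages(rewards)

# Advantages that actually reach training: one copy per emitted row.
trained = [a for a, n in zip(advantages, rows_emitted) for _ in range(n)]

full_group_mean = mean(advantages)  # centred over all four conversations
trained_mean = mean(trained)        # no longer centred once one is dropped
print(round(full_group_mean, 6), round(trained_mean, 3))
```

The dropped conversation still shapes the baseline, but its negative advantage never reaches the loss, so the advantages that are trained on lean positive in this example.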
