Skip to content

add iw-opd distillation#6191

Open
kashif wants to merge 2 commits into
huggingface:mainfrom
kashif:iw-opd-distillation
Open

add iw-opd distillation#6191
kashif wants to merge 2 commits into
huggingface:mainfrom
kashif:iw-opd-distillation

Conversation

@kashif

@kashif kashif commented Jun 27, 2026

Copy link
Copy Markdown
Collaborator

Summary

Adds IW-OPD as an optional objective for the experimental distillation trainer.

The new path uses sampled-token teacher and rollout logprobs to build the detached IW-OPD advantage, including cached vLLM rollout logprobs when use_vllm=True.

Tests

  • make precommit
  • pytest tests/experimental/test_distillation_trainer.py -q

Fixes # (issue)

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline, Pull Request section?
  • Was this discussed/approved via a GitHub issue? Please add a link to it if that's the case.
  • Did you make sure to update the documentation with your changes?
  • Did you write any new necessary tests?

AI writing disclosure

We welcome the use of AI tools to help with contributions. For transparency and to help us improve our review process, please indicate the level of AI involvement in this PR.

  • No AI usage: the PR was written entirely by a human.
  • AI-assisted: some parts were suggested or improved by AI, but the PR was written and reviewed by a human.
  • AI-generated: the PR was mostly or fully generated by an AI tool.

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag members/contributors who may be interested in your PR.


Note

Medium Risk
Changes on-policy loss computation and vLLM rollout buffering in experimental distillation; misconfiguration is mostly blocked by config checks, but wrong rollout log-probs could skew training.

Overview
Adds distillation_objective="iw_opd" as an alternative to the default JSD/KL path in experimental DistillationTrainer, implementing Importance-Weighted On-Policy Distillation from the position-bias paper.

DistillationConfig gains distillation_objective (default "jsd"), iw_opd_gamma, and iw_opd_epsilon, plus validation: IW-OPD requires lmbda=1.0, reverse_kl_top_1_mode="sampled", no Liger kernel, and vllm_sync_frequency=1 when use_vllm=True.

The trainer introduces _compute_iw_opd_loss: detached teacher-minus-rollout advantages on sampled tokens, prefix importance weights from accumulated teacher–student drift, and IW-OPD metrics. compute_loss branches to this path for local and teacher-server teachers. With vLLM, generation requests logprobs=0, extracts sampled-token rollout log-probs, and stores them as rollout_logprobs on buffered batches.

Docs cover usage in distillation_trainer.md and a paper-index entry; tests cover config guards, loss math, cached rollouts, training, and the server path.

Reviewed by Cursor Bugbot for commit d8c4b98. Bugbot is set up for automated code reviews on this repo. Configure here.

@kashif kashif requested a review from cmpatino June 27, 2026 21:00
@bot-ci-comment

Copy link
Copy Markdown

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

Copilot AI left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull request overview

Adds Importance-Weighted On-Policy Distillation (IW-OPD) as an optional objective for the experimental DistillationTrainer, including support for using cached vLLM rollout logprobs to form a detached rollout-policy baseline for the IW-OPD advantage.

Changes:

  • Introduces distillation_objective="iw_opd" with new hyperparameters (iw_opd_gamma, iw_opd_epsilon) and config-level validation/guards for compatible settings.
  • Implements _compute_iw_opd_loss and wires it into both local-teacher and teacher-server loss paths.
  • Extends the vLLM on-policy buffering path to request sampled-token logprobs and cache them as rollout_logprobs for IW-OPD.

Reviewed changes

Copilot reviewed 5 out of 5 changed files in this pull request and generated no comments.

Show a summary per file
File Description
trl/experimental/distillation/distillation_trainer.py Adds IW-OPD loss implementation and integrates cached rollout logprobs from vLLM generation into the distillation loss path.
trl/experimental/distillation/distillation_config.py Adds new IW-OPD configuration fields and enforces constraints required for on-policy IW-OPD usage.
tests/experimental/test_distillation_trainer.py Adds unit and integration tests covering IW-OPD config guards, weight/loss math, rollout-logprob usage, and a basic training run.
docs/source/paper_index.md Adds a paper index entry describing IW-OPD and mapping key paper settings to DistillationConfig.
docs/source/distillation_trainer.md Documents the new distillation_objective option and IW-OPD usage/constraints.

💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.

@cursor cursor Bot left a comment

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Cursor Bugbot has reviewed your changes using default effort and found 2 potential issues.

Fix All in Cursor

❌ Bugbot Autofix is OFF. To automatically fix reported issues with cloud agents, enable autofix in the Cursor dashboard.

Want higher recall? High effort reviews run extra passes and find more bugs. A team admin can switch effort levels in the Cursor dashboard.

Reviewed by Cursor Bugbot for commit d8c4b98. Configure here.

completion_tokens=completion_tokens,
labels=trimmed_labels,
teacher_actual_logprobs=teacher_result["actual_logprobs"],
rollout_logprobs=rollout_logprobs,

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Rollout logprobs column misalignment

Medium Severity

Cached vLLM rollout_logprobs are stored with column 0 as the first completion token, but IW-OPD combines them with teacher and student logprobs in the slice that starts at _compute_prompt_length. When prompt lengths differ within a batch, each row’s completion begins at a different column in that slice, so element-wise pairing uses rollout logprobs from the wrong token positions while teacher values (including the teacher-server offset layout) stay aligned.

Additional Locations (2)
Fix in Cursor Fix in Web

Reviewed by Cursor Bugbot for commit d8c4b98. Configure here.

):
sampled_idx = step_token_ids.index(sampled_token_id)
sampled_logprobs.append(step_lps[sampled_idx])
slice_rollout_logprobs[slice_idx].append(sampled_logprobs)

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Missing logprob token ids guard

Medium Severity

When vLLM returns logprobs, _store_completions_in_buffer always indexes logprob_token_ids[i] without checking that logprob_token_ids is non-null. Server mode can supply logprobs while logprob_token_ids is missing from the payload, which raises a runtime error during IW-OPD on-policy generation instead of a clear validation failure.

Fix in Cursor Fix in Web

Reviewed by Cursor Bugbot for commit d8c4b98. Configure here.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants