add iw-opd distillation#6191
Conversation
|
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update. |
There was a problem hiding this comment.
Pull request overview
Adds Importance-Weighted On-Policy Distillation (IW-OPD) as an optional objective for the experimental DistillationTrainer, including support for using cached vLLM rollout logprobs to form a detached rollout-policy baseline for the IW-OPD advantage.
Changes:
- Introduces
distillation_objective="iw_opd"with new hyperparameters (iw_opd_gamma,iw_opd_epsilon) and config-level validation/guards for compatible settings. - Implements
_compute_iw_opd_lossand wires it into both local-teacher and teacher-server loss paths. - Extends the vLLM on-policy buffering path to request sampled-token logprobs and cache them as
rollout_logprobsfor IW-OPD.
Reviewed changes
Copilot reviewed 5 out of 5 changed files in this pull request and generated no comments.
Show a summary per file
| File | Description |
|---|---|
trl/experimental/distillation/distillation_trainer.py |
Adds IW-OPD loss implementation and integrates cached rollout logprobs from vLLM generation into the distillation loss path. |
trl/experimental/distillation/distillation_config.py |
Adds new IW-OPD configuration fields and enforces constraints required for on-policy IW-OPD usage. |
tests/experimental/test_distillation_trainer.py |
Adds unit and integration tests covering IW-OPD config guards, weight/loss math, rollout-logprob usage, and a basic training run. |
docs/source/paper_index.md |
Adds a paper index entry describing IW-OPD and mapping key paper settings to DistillationConfig. |
docs/source/distillation_trainer.md |
Documents the new distillation_objective option and IW-OPD usage/constraints. |
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
There was a problem hiding this comment.
Cursor Bugbot has reviewed your changes using default effort and found 2 potential issues.
❌ Bugbot Autofix is OFF. To automatically fix reported issues with cloud agents, enable autofix in the Cursor dashboard.
Want higher recall? High effort reviews run extra passes and find more bugs. A team admin can switch effort levels in the Cursor dashboard.
Reviewed by Cursor Bugbot for commit d8c4b98. Configure here.
| completion_tokens=completion_tokens, | ||
| labels=trimmed_labels, | ||
| teacher_actual_logprobs=teacher_result["actual_logprobs"], | ||
| rollout_logprobs=rollout_logprobs, |
There was a problem hiding this comment.
Rollout logprobs column misalignment
Medium Severity
Cached vLLM rollout_logprobs are stored with column 0 as the first completion token, but IW-OPD combines them with teacher and student logprobs in the slice that starts at _compute_prompt_length. When prompt lengths differ within a batch, each row’s completion begins at a different column in that slice, so element-wise pairing uses rollout logprobs from the wrong token positions while teacher values (including the teacher-server offset layout) stay aligned.
Additional Locations (2)
Reviewed by Cursor Bugbot for commit d8c4b98. Configure here.
| ): | ||
| sampled_idx = step_token_ids.index(sampled_token_id) | ||
| sampled_logprobs.append(step_lps[sampled_idx]) | ||
| slice_rollout_logprobs[slice_idx].append(sampled_logprobs) |
There was a problem hiding this comment.
Missing logprob token ids guard
Medium Severity
When vLLM returns logprobs, _store_completions_in_buffer always indexes logprob_token_ids[i] without checking that logprob_token_ids is non-null. Server mode can supply logprobs while logprob_token_ids is missing from the payload, which raises a runtime error during IW-OPD on-policy generation instead of a clear validation failure.
Reviewed by Cursor Bugbot for commit d8c4b98. Configure here.


Summary
Adds IW-OPD as an optional objective for the experimental distillation trainer.
The new path uses sampled-token teacher and rollout logprobs to build the detached IW-OPD advantage, including cached vLLM rollout logprobs when
use_vllm=True.Tests
make precommitpytest tests/experimental/test_distillation_trainer.py -qFixes # (issue)
Before submitting
AI writing disclosure
We welcome the use of AI tools to help with contributions. For transparency and to help us improve our review process, please indicate the level of AI involvement in this PR.
Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag members/contributors who may be interested in your PR.
Note
Medium Risk
Changes on-policy loss computation and vLLM rollout buffering in experimental distillation; misconfiguration is mostly blocked by config checks, but wrong rollout log-probs could skew training.
Overview
Adds
distillation_objective="iw_opd"as an alternative to the default JSD/KL path in experimentalDistillationTrainer, implementing Importance-Weighted On-Policy Distillation from the position-bias paper.DistillationConfiggainsdistillation_objective(default"jsd"),iw_opd_gamma, andiw_opd_epsilon, plus validation: IW-OPD requireslmbda=1.0,reverse_kl_top_1_mode="sampled", no Liger kernel, andvllm_sync_frequency=1whenuse_vllm=True.The trainer introduces
_compute_iw_opd_loss: detached teacher-minus-rollout advantages on sampled tokens, prefix importance weights from accumulated teacher–student drift, and IW-OPD metrics.compute_lossbranches to this path for local and teacher-server teachers. With vLLM, generation requestslogprobs=0, extracts sampled-token rollout log-probs, and stores them asrollout_logprobson buffered batches.Docs cover usage in
distillation_trainer.mdand a paper-index entry; tests cover config guards, loss math, cached rollouts, training, and the server path.Reviewed by Cursor Bugbot for commit d8c4b98. Bugbot is set up for automated code reviews on this repo. Configure here.