Skip to content

Fix vLLM server-mode generation in OnlineDPOTrainer#6228

Merged
qgallouedec merged 5 commits into
mainfrom
fix-server-online-dpo
Jul 2, 2026
Merged

Fix vLLM server-mode generation in OnlineDPOTrainer#6228
qgallouedec merged 5 commits into
mainfrom
fix-server-online-dpo

Conversation

@qgallouedec

@qgallouedec qgallouedec commented Jun 30, 2026

Copy link
Copy Markdown
Member

Fixes #5514

Background

OnlineDPOTrainer generates 2 completions per prompt to form a preference pair, and everything downstream (reward computation, rewards.split(batch_size), the chosen/rejected split) assumes a block layout: the first half of the batch is one completion per prompt, the second half the other (prompt i at rows i and i+N). The colocate and transformers generation paths both produce exactly this layout.

Unlike GRPO, Online DPO does not duplicate prompts in the batch (no RepeatSampler).

The vLLM server path (_generate_vllm_server) was copy-pasted from GRPO and never adapted to these semantics. As a result it has been broken since the feature was introduced in #3783: it ran without crashing but trained on meaningless data.

The bugs

Three compounding issues in _generate_vllm_server:

  1. Double-flatten (OnlineDPOTrainer._generate_vllm_server() flattens vllm-serve completion_ids twice #5514). VLLMClient.generate(...)["completion_ids"] already returns list[list[int]] (one token-id list per completion), but the trainer re-flattened it, turning every token into its own single-token completion.

    server returns:  [[11, 12, 13], [21, 22]]          2 completions (3 and 2 tokens)
    old code made:   [[11],[12],[13],[21],[22]]         5 "completions" of 1 token each   ✗
    
  2. Wrong prompt subsampling. all_prompts[::self.num_generations] de-duplicates prompts (correct for GRPO, its RepeatSampler duplicates them), but it silently drops half the batch in Online DPO, where prompts are unique.

  3. Interleaved vs block ordering. The server returns completions interleaved (grouped by prompt), and the trainer built prompt_ids interleaved to match. But the loss splits the batch in half (rewards.split(batch_size)), which assumes block order. Preference pairs were therefore formed across different prompts.

    2 prompts A, B → 2 completions each.   "A0" = prompt A, completion 0.
    
    server returns (grouped by prompt):    [ A0  A1  B0  B1 ]
    loss/reward code expects (block):      [ A0  B0 | A1  B1 ]
                                             └── rewards.split(batch_size) pairs row i with row i+N ──┘
    
    old path kept interleaved order  →  split pairs:   A0 vs B0  ✗   A1 vs B1  ✗   (different prompts!)
    fixed to block order             →  split pairs:   A0 vs A1  ✓   B0 vs B1  ✓   (same prompt)
    

Bug 1 masked bug 2: flattening produced enough single-token entries to fill the per-process slice, so the count happened to line up and nothing crashed. Removing only the flatten (the minimal fix proposed in the issue) instead surfaces a crash at torch.cat((prompt_ids, completion_ids)).

Verification

Real OnlineDPOTrainer run in vLLM server mode on 2 GPUs, Qwen/Qwen2-0.5B-Instruct, a length reward, batch size 2:

from datasets import Dataset
from transformers import AutoTokenizer

from trl.experimental.online_dpo.online_dpo_config import OnlineDPOConfig
from trl.experimental.online_dpo.online_dpo_trainer import OnlineDPOTrainer

MODEL = "Qwen/Qwen2-0.5B-Instruct"


def reward_len(completions, **kwargs):
    # Trivial reward: prefer shorter completions. Return one float per completion.
    return [-abs(len(c) - 20) for c in completions]


prompts = [
    "What is the capital of France?",
    "Write a short greeting.",
    "Name a primary color.",
    "Say hello.",
]
dataset = Dataset.from_dict({"prompt": prompts})

config = OnlineDPOConfig(
    output_dir="/tmp/claude-150110/online_dpo_5514",
    per_device_train_batch_size=2,
    max_steps=2,
    max_new_tokens=32,
    use_vllm=True,
    vllm_mode="server",
    logging_steps=1,
    report_to=[],
)

tokenizer = AutoTokenizer.from_pretrained(MODEL)

trainer = OnlineDPOTrainer(
    model=MODEL,
    reward_funcs=reward_len,
    args=config,
    train_dataset=dataset,
    processing_class=tokenizer,
)

trainer.train()
print("DONE")

Note

The existing test_train_with_vllm_server never caught this: it is @pytest.mark.slow, requires an external running server, and only asserts "train_loss" in log_history (which passes even on garbage data). A stronger server-mode test that checks completion shape and pairing would be a good follow-up.


Note

Medium Risk
Touches the experimental Online DPO training loop and reward pairing semantics; incorrect layout would silently train on wrong preferences, but the change aligns server mode with existing colocate/Transformers behavior.

Overview
Fixes vLLM server-mode generation in OnlineDPOTrainer so preference pairs match the same block batch layout used by colocate and Transformers paths (rewards.split(batch_size) pairs row i with row i+N).

_generate_vllm_server no longer de-duplicates gathered prompts with all_prompts[::num_generations] (GRPO-style; Online DPO keeps unique prompts). It calls VLLMClient.generate with the full all_prompts list and drops the extra flatten that turned each completion token into a fake one-token completion.

After the per-process slice, completions are reordered from server interleaved order to block order via completion_ids[0::2] + completion_ids[1::2], and prompt_ids are built in the same layout (duplicate the tokenized prompt list once instead of interleaving per row).

Reviewed by Cursor Bugbot for commit 96d543d. Bugbot is set up for automated code reviews on this repo. Configure here.

@bot-ci-comment

Copy link
Copy Markdown

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@kashif

kashif commented Jul 2, 2026

Copy link
Copy Markdown
Collaborator

thanks! checking

@kashif

kashif commented Jul 2, 2026

Copy link
Copy Markdown
Collaborator

great catch!

@qgallouedec qgallouedec merged commit f444f85 into main Jul 2, 2026
5 checks passed
@qgallouedec qgallouedec deleted the fix-server-online-dpo branch July 2, 2026 20:40
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

OnlineDPOTrainer._generate_vllm_server() flattens vllm-serve completion_ids twice

2 participants