
[fix][train] Prompt-based mini-batching for step-wise training#1483

Open
CharlieFRuan wants to merge 1 commit into main from prompt-based-mini-batching

Conversation


@CharlieFRuan CharlieFRuan commented Apr 9, 2026

Summary

Step-wise training decomposes each multi-turn trajectory into one training sequence per LLM turn. This produces a variable number of sequences per prompt, which broke the old fixed-size mini-batching (policy_mini_batch_size * n_samples_per_prompt):

  1. Crash: AssertionError: data batch size must be divisible by mini_batch_size, got 3509 and 640 — the total sequence count isn't divisible by the fixed mini-batch size.
  2. Variable optimizer steps: more turns → more mini-batches → more optimizer steps per training batch, causing the LR schedule to advance faster than intended.

This PR makes mini-batching operate in prompt units. Each mini-batch contains the sequences for exactly policy_mini_batch_size prompts, regardless of how many turns those prompts produced. The number of optimizer steps is always train_batch_size / policy_mini_batch_size * update_epochs_per_batch, independent of turn counts.
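The boundary derivation can be sketched as follows. This is an illustrative reimplementation in plain Python, not the PR's actual code; the helper names mirror those introduced by the PR, and it assumes sequences belonging to the same prompt are stored contiguously with one uid per sequence.

```python
# Sketch of prompt-based mini-batch boundary computation (illustrative only).

def compute_prompt_end_indices(uids):
    """Return the exclusive end index of each prompt's contiguous run of sequences."""
    ends = []
    for i in range(1, len(uids)):
        if uids[i] != uids[i - 1]:
            ends.append(i)
    if uids:
        ends.append(len(uids))
    return ends

def compute_prompt_mini_batch_boundaries(prompt_ends, prompts_per_mini_batch):
    """Group prompt end indices into (start, end) sequence-index pairs,
    one pair per mini-batch of `prompts_per_mini_batch` prompts."""
    boundaries = []
    start = 0
    for i in range(prompts_per_mini_batch - 1, len(prompt_ends), prompts_per_mini_batch):
        boundaries.append((start, prompt_ends[i]))
        start = prompt_ends[i]
    return boundaries

# Three prompts with 2, 3, and 1 turns; mini-batches of 1 prompt each.
uids = ["a", "a", "b", "b", "b", "c"]
ends = compute_prompt_end_indices(uids)               # [2, 5, 6]
print(compute_prompt_mini_batch_boundaries(ends, 1))  # [(0, 2), (2, 5), (5, 6)]
```

Note how the three mini-batches contain 2, 3, and 1 sequences respectively: the mini-batch count depends only on the number of prompts, which is what keeps the optimizer step count fixed.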

Approach

  • Two helpers (compute_prompt_end_indices, compute_prompt_mini_batch_boundaries) derive prompt boundaries from the uid list and group them into (start, end) mini-batch boundaries. Computed once in convert_to_training_input and stored in metadata.
  • MeshDispatch.stage_chunks now accepts boundary pairs instead of a fixed mini_batch_size. Each mini-batch is individually padded to dp_size with loss_mask=0 dummy entries.
  • _normalize_advantages iterates over boundary pairs instead of fixed-size slices.
  • validate_batch_sizes skips per-GPU sequence-level divisibility checks when step_wise_trajectories=True (variable sizes can't be validated statically; per-mini-batch padding handles it at runtime).
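The per-mini-batch padding step above can be sketched with plain lists standing in for tensors (illustrative names and row layout, not the actual `stage_chunks` implementation):

```python
# Sketch: pad a variable-size mini-batch up to a multiple of dp_size with
# dummy rows carrying loss_mask=0, so they contribute nothing to the loss.

def pad_mini_batch(rows, dp_size, zero_row):
    """Append dummy entries until len(rows) is divisible by dp_size.

    Padding with explicit zero rows (rather than slices of real rows) stays
    correct even when the mini-batch is smaller than dp_size.
    """
    pad_size = (-len(rows)) % dp_size
    return rows + [dict(zero_row, loss_mask=0) for _ in range(pad_size)]

batch = [{"seq": 1, "loss_mask": 1}, {"seq": 2, "loss_mask": 1}, {"seq": 3, "loss_mask": 1}]
padded = pad_mini_batch(batch, dp_size=4, zero_row={"seq": 0})
print(len(padded))   # 4
print(padded[-1])    # {'seq': 0, 'loss_mask': 0}
```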

Non-step-wise backward compatibility

For non-step-wise training, every prompt has exactly n_samples_per_prompt sequences, so the boundaries are uniform — identical to the old fixed-size slicing. This is verified by a parametrized unit test across multiple batch/sample configs.

A single unified code path handles both step-wise and non-step-wise training.
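The reduction to the old behavior can be checked with a small script (illustrative, inlining the boundary logic rather than importing the PR's helpers):

```python
# Check that with a fixed number of samples per prompt, prompt-based
# boundaries reduce to the old fixed-size slices.

n_samples_per_prompt = 4
policy_mini_batch_size = 2
num_prompts = 6

# One uid per sequence, exactly n_samples_per_prompt sequences per prompt.
uids = [f"p{i}" for i in range(num_prompts) for _ in range(n_samples_per_prompt)]

# Prompt end indices fall every n_samples_per_prompt sequences.
ends = [i for i in range(1, len(uids)) if uids[i] != uids[i - 1]] + [len(uids)]

# Group into mini-batches of policy_mini_batch_size prompts.
boundaries = []
start = 0
for i in range(policy_mini_batch_size - 1, len(ends), policy_mini_batch_size):
    boundaries.append((start, ends[i]))
    start = ends[i]

# Old behaviour: fixed slices of policy_mini_batch_size * n_samples_per_prompt.
step = policy_mini_batch_size * n_samples_per_prompt
fixed = [(s, s + step) for s in range(0, len(uids), step)]
print(boundaries == fixed)  # True
```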

Test plan

  • 27 unit tests covering boundary computation, variable DP staging/padding, optimizer step count invariance, and non-step-wise uniformity (tests/train/test_prompt_mini_batch.py)
  • All 6 existing trainer tests pass unchanged (tests/train/test_trainer.py)
  • Integration test: bash examples/train/turn_level_rewards/run_gsm8k_multi_turn.sh with step_wise_trajectories=true, max_turns=5 — previously crashed, now completes training steps successfully. Timing: convert_to_training_input = 1.92s (negligible overhead from uid scan), policy_train = 50.22s.

🤖 Generated with Claude Code



Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request implements prompt-aligned mini-batching for step-wise training, ensuring that mini-batch boundaries are aligned with prompt boundaries and trajectories are never split. It introduces logic to compute these boundaries at runtime based on prompt UIDs and adds padding to MeshDispatch to maintain compatibility with data-parallel requirements. Feedback identifies a bug in the padding logic where slicing the original tensor for padding can fail if the required padding exceeds the mini-batch size, and notes that assigning global metadata to mini-batches may cause length mismatches in worker processes.

elif key == "loss_mask":
padding = torch.zeros(pad_size, *additional_dims, dtype=tensor.dtype, device=tensor.device)
else:
padding = tensor[:pad_size].clone()

high

The padding logic tensor[:pad_size].clone() is fragile and will fail if the mini-batch size (mb_size) is smaller than the required pad_size. For example, if dp_size=8 and mb_size=1, pad_size would be 7, but tensor[:7] would only return 1 element, resulting in a total mini-batch size of 2 instead of 8. This will cause the subsequent chunk operation to produce incorrect sizes or fail. Since these are dummy entries with loss_mask=0, it is safer to pad with zeros of the correct shape.

Suggested change
padding = tensor[:pad_size].clone()
padding = torch.zeros(pad_size, *additional_dims, dtype=tensor.dtype, device=tensor.device)

padding = tensor[:pad_size].clone()
new_tensors[key] = torch.cat([tensor, padding], dim=0)
mini_batch = TrainingInputBatch(new_tensors)
mini_batch.metadata = data.metadata

medium

Assigning the full batch metadata (data.metadata) to a mini-batch is problematic. Metadata often contains fields like uids or trajectory_ids that are expected to be aligned with the number of sequences in the batch. By assigning the global metadata here, you are providing the worker with metadata for the entire training batch rather than just the current mini-batch. Additionally, unlike the pad_batch implementation in trainer.py, this logic does not pad the metadata fields, which will lead to length mismatches if the worker code attempts to use them.
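The mismatch this comment describes can be shown with a toy example (illustrative values, not the actual metadata schema):

```python
# Toy illustration of the metadata length mismatch: global metadata attached
# unchanged to a mini-batch no longer matches the mini-batch's row count.
uids = ["a", "a", "b", "b", "b", "c"]   # metadata for the full batch (6 sequences)
mini_batch_rows = uids[0:2]             # a mini-batch holding only 2 sequences
metadata = {"uids": uids}               # global metadata assigned as-is
print(len(mini_batch_rows) == len(metadata["uids"]))  # False: 2 vs 6
```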


@devin-ai-integration devin-ai-integration bot left a comment


Devin Review found 1 potential issue.

View 5 additional findings in Devin Review.


elif key == "loss_mask":
padding = torch.zeros(pad_size, *additional_dims, dtype=tensor.dtype, device=tensor.device)
else:
padding = tensor[:pad_size].clone()


🔴 Padding in stage_chunks produces insufficient entries when mini-batch is smaller than dp_size

When a step-wise mini-batch has fewer sequences than dp_size, the padding logic at dispatch.py:208 uses tensor[:pad_size].clone() for general tensors (e.g., sequences, rewards, advantages), but Python/PyTorch slicing clamps to the tensor length — so if pad_size > mb_size, only mb_size entries are produced instead of pad_size. Meanwhile, is_last_step (line 204) and loss_mask (line 206) correctly create exactly pad_size entries via torch.ones/torch.zeros. This causes the tensors in new_tensors to have inconsistent batch dimensions, crashing at TrainingInputBatch(new_tensors) (line 210) with a batch-size-mismatch ValueError.

The bug triggers when policy_mini_batch_size * n_samples_per_prompt < dp_size (all trajectories having 1 turn), which is specifically enabled by step-wise training since the per-GPU divisibility validations are intentionally skipped (skyrl/train/utils/utils.py:107).

Suggested change
padding = tensor[:pad_size].clone()
padding = tensor[torch.arange(pad_size) % len(tensor)].clone()
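The clamping behavior and the cycling-index fix can be illustrated with plain Python lists (tensor slicing behaves analogously):

```python
# Slicing past the end of a sequence silently truncates instead of raising.
rows = [10]          # mini-batch with a single entry (mb_size = 1)
pad_size = 7         # entries needed to reach dp_size = 8
padding = rows[:pad_size]           # clamps: yields 1 element, not 7
print(len(rows + padding))          # 2, not the required 8

# Cycling indices produces exactly pad_size entries even when pad_size > mb_size.
cycled = [rows[i % len(rows)] for i in range(pad_size)]
print(len(rows + cycled))           # 8
```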

