[fix][train] Prompt-based mini-batching for step-wise training#1483
CharlieFRuan wants to merge 1 commit into `main`.
Conversation
Step-wise training decomposes each trajectory into one sequence per turn, producing a variable number of sequences per prompt. The old code computed the mini-batch size as `policy_mini_batch_size * n_samples_per_prompt` (a fixed number of sequences), which broke in two ways:

1. The total sequence count was not divisible by this fixed mini-batch size, causing an assertion failure in `stage_chunks`.
2. More sequences meant more mini-batches and more optimizer steps, making the LR schedule advance faster than intended.

This PR makes mini-batching operate in prompt units. Each mini-batch contains the sequences for exactly `policy_mini_batch_size` prompts, regardless of how many sequences that is. The number of optimizer steps is always `train_batch_size / policy_mini_batch_size * update_epochs_per_batch`, independent of turn counts.

Key changes:

- `compute_prompt_end_indices` + `compute_prompt_mini_batch_boundaries`: two small helpers that derive prompt boundaries from the uid list and group them into mini-batches. Computed once in `convert_to_training_input` and stored in metadata.
- `MeshDispatch.stage_chunks` now accepts `(start, end)` boundary pairs instead of a fixed `mini_batch_size`. Each mini-batch is individually padded to `dp_size` with `loss_mask=0` dummy entries.
- `_normalize_advantages` iterates over boundary pairs instead of fixed-size slices.
- `validate_batch_sizes` skips per-GPU sequence-level divisibility checks when `step_wise_trajectories=True` (variable sizes can't be validated statically; padding handles it at runtime).

For non-step-wise training, every prompt has exactly `n_samples_per_prompt` sequences, so the boundaries are uniform and identical to the old fixed-size slicing. This is verified by a parametrized unit test.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
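The bodies of the two boundary helpers aren't shown in this thread. A minimal pure-Python sketch of the contract they describe (helper names come from the PR; the implementations below are an assumption):

```python
def compute_prompt_end_indices(uids):
    """Exclusive end index of each prompt's contiguous run of sequences.

    Assumes sequences belonging to the same prompt uid are adjacent,
    as produced by step-wise trajectory decomposition.
    """
    ends = [i for i in range(1, len(uids)) if uids[i] != uids[i - 1]]
    if uids:
        ends.append(len(uids))
    return ends


def compute_prompt_mini_batch_boundaries(prompt_ends, prompts_per_mini_batch):
    """Group prompt end indices into (start, end) sequence ranges that
    each cover exactly prompts_per_mini_batch prompts."""
    boundaries, start = [], 0
    for i in range(prompts_per_mini_batch - 1, len(prompt_ends), prompts_per_mini_batch):
        boundaries.append((start, prompt_ends[i]))
        start = prompt_ends[i]
    return boundaries
```

For example, four prompts with 2, 3, 1, and 2 turns and two prompts per mini-batch yield end indices `[2, 5, 6, 8]` and boundaries `[(0, 5), (5, 8)]`: the slice widths vary, but each optimizer step still covers exactly two prompts.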
Code Review
This pull request implements prompt-aligned mini-batching for step-wise training, ensuring that mini-batch boundaries are aligned with prompt boundaries and trajectories are never split. It introduces logic to compute these boundaries at runtime based on prompt UIDs and adds padding to MeshDispatch to maintain compatibility with data-parallel requirements. Feedback identifies a bug in the padding logic where slicing the original tensor for padding can fail if the required padding exceeds the mini-batch size, and notes that assigning global metadata to mini-batches may cause length mismatches in worker processes.
```python
elif key == "loss_mask":
    padding = torch.zeros(pad_size, *additional_dims, dtype=tensor.dtype, device=tensor.device)
else:
    padding = tensor[:pad_size].clone()
```
The padding logic `tensor[:pad_size].clone()` is fragile and will fail if the mini-batch size (`mb_size`) is smaller than the required `pad_size`. For example, if `dp_size=8` and `mb_size=1`, `pad_size` would be 7, but `tensor[:7]` would only return 1 element, resulting in a total mini-batch size of 2 instead of 8. This will cause the subsequent chunk operation to produce incorrect sizes or fail. Since these are dummy entries with `loss_mask=0`, it is safer to pad with zeros of the correct shape.
Suggested change:

```diff
-padding = tensor[:pad_size].clone()
+padding = torch.zeros(pad_size, *additional_dims, dtype=tensor.dtype, device=tensor.device)
```
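The clamping behavior behind this suggestion can be illustrated with plain Python lists, which slice with the same semantics as tensors:

```python
# Slicing past the end clamps to the sequence length instead of raising,
# so tensor[:pad_size] silently yields too few rows when pad_size > mb_size.
mini_batch = [0.5]                 # mb_size = 1
pad_size = 7                       # dp_size = 8 needs 7 more entries
clamped = mini_batch[:pad_size]
assert len(clamped) == 1           # only 1 entry, not 7

# Zero padding always produces exactly pad_size entries.
zeros = [0.0] * pad_size
assert len(mini_batch + zeros) == 8  # divisible by dp_size
```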
```python
padding = tensor[:pad_size].clone()
new_tensors[key] = torch.cat([tensor, padding], dim=0)
mini_batch = TrainingInputBatch(new_tensors)
mini_batch.metadata = data.metadata
```
Assigning the full batch metadata (`data.metadata`) to a mini-batch is problematic. Metadata often contains fields like `uids` or `trajectory_ids` that are expected to be aligned with the number of sequences in the batch. By assigning the global metadata here, you are providing the worker with metadata for the entire training batch rather than just the current mini-batch. Additionally, unlike the `pad_batch` implementation in `trainer.py`, this logic does not pad the metadata fields, which will lead to length mismatches if the worker code attempts to use them.
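One way to address both points, sketched here purely as an assumption (the PR itself does not contain such a fix, and `slice_metadata` is a hypothetical helper), is to slice per-sequence metadata fields to the boundary range and pad them alongside the tensors:

```python
def slice_metadata(metadata, start, end, pad_size):
    """Restrict sequence-aligned metadata to [start, end) and pad it.

    List-valued fields are assumed to be per-sequence (e.g. uids) and get
    dummy None entries for the padded rows; everything else is treated as
    a batch-global scalar and passed through unchanged.
    """
    out = {}
    for key, value in metadata.items():
        if isinstance(value, list):
            out[key] = value[start:end] + [None] * pad_size
        else:
            out[key] = value
    return out
```

This keeps list lengths consistent with the padded tensor batch dimension, at the cost of assuming that every list-valued metadata field is sequence-aligned.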
```python
elif key == "loss_mask":
    padding = torch.zeros(pad_size, *additional_dims, dtype=tensor.dtype, device=tensor.device)
else:
    padding = tensor[:pad_size].clone()
```
🔴 Padding in `stage_chunks` produces insufficient entries when the mini-batch is smaller than `dp_size`

When a step-wise mini-batch has fewer sequences than `dp_size`, the padding logic at `dispatch.py:208` uses `tensor[:pad_size].clone()` for general tensors (e.g., `sequences`, `rewards`, `advantages`), but Python/PyTorch slicing clamps to the tensor length: if `pad_size > mb_size`, only `mb_size` entries are produced instead of `pad_size`. Meanwhile, `is_last_step` (line 204) and `loss_mask` (line 206) correctly create exactly `pad_size` entries via `torch.ones`/`torch.zeros`. This leaves the tensors in `new_tensors` with inconsistent batch dimensions, crashing at `TrainingInputBatch(new_tensors)` (line 210) with a batch-size-mismatch `ValueError`.

The bug triggers when `policy_mini_batch_size * n_samples_per_prompt < dp_size` (e.g., when all trajectories have 1 turn), which is specifically enabled by step-wise training since the per-GPU divisibility validations are intentionally skipped (`skyrl/train/utils/utils.py:107`).
Suggested change:

```diff
-padding = tensor[:pad_size].clone()
+padding = tensor[torch.arange(pad_size) % len(tensor)].clone()
```
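The suggested modular indexing cycles through the mini-batch, so the padding always has exactly `pad_size` entries even when `pad_size > mb_size`. The effect can be shown with plain Python in place of `torch.arange`:

```python
mini_batch = [10]                  # mb_size = 1, dp_size = 8 -> pad_size = 7
pad_size = 7

# Equivalent of tensor[torch.arange(pad_size) % len(tensor)]:
# indices wrap around the mini-batch instead of clamping.
padding = [mini_batch[i % len(mini_batch)] for i in range(pad_size)]

assert len(padding) == 7                 # exactly pad_size entries
assert len(mini_batch + padding) == 8    # now divisible by dp_size
```

The duplicated rows are inert during training because the `loss_mask` branch still pads those positions with zeros.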
Summary
Step-wise training decomposes each multi-turn trajectory into one training sequence per LLM turn. This produces a variable number of sequences per prompt, which broke the old fixed-size mini-batching (`policy_mini_batch_size * n_samples_per_prompt`):

- `AssertionError: data batch size must be divisible by mini_batch_size, got 3509 and 640`: the total sequence count isn't divisible by the fixed mini-batch size.

This PR makes mini-batching operate in prompt units. Each mini-batch contains the sequences for exactly `policy_mini_batch_size` prompts, regardless of how many turns those prompts produced. The number of optimizer steps is always `train_batch_size / policy_mini_batch_size * update_epochs_per_batch`, independent of turn counts.

Approach

- Two helpers (`compute_prompt_end_indices`, `compute_prompt_mini_batch_boundaries`) derive prompt boundaries from the uid list and group them into `(start, end)` mini-batch boundaries. Computed once in `convert_to_training_input` and stored in metadata.
- `MeshDispatch.stage_chunks` now accepts boundary pairs instead of a fixed `mini_batch_size`. Each mini-batch is individually padded to `dp_size` with `loss_mask=0` dummy entries.
- `_normalize_advantages` iterates over boundary pairs instead of fixed-size slices.
- `validate_batch_sizes` skips per-GPU sequence-level divisibility checks when `step_wise_trajectories=True` (variable sizes can't be validated statically; per-mini-batch padding handles it at runtime).

Non-step-wise backward compatibility

For non-step-wise training, every prompt has exactly `n_samples_per_prompt` sequences, so the boundaries are uniform and identical to the old fixed-size slicing. This is verified by a parametrized unit test across multiple batch/sample configs. There is one unified code path for both step-wise and non-step-wise training.
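The backward-compatibility claim can be checked with a small self-contained sketch (the boundary logic is inlined here; shapes and parameter names are assumptions, not the PR's actual test):

```python
# When every prompt has exactly n_samples sequences, prompt-aligned
# mini-batch boundaries collapse to the old fixed-size slicing.
n_prompts, n_samples, mb_prompts = 8, 4, 2
uids = [f"p{i}" for i in range(n_prompts) for _ in range(n_samples)]

# Prompt end indices (exclusive) derived from the uid list.
ends = [i for i in range(1, len(uids)) if uids[i] != uids[i - 1]] + [len(uids)]

# Group ends into mini-batches of mb_prompts prompts each.
boundaries, start = [], 0
for i in range(mb_prompts - 1, len(ends), mb_prompts):
    boundaries.append((start, ends[i]))
    start = ends[i]

# Old behavior: fixed slices of mb_prompts * n_samples sequences.
fixed = mb_prompts * n_samples
old_slices = [(s, s + fixed) for s in range(0, len(uids), fixed)]
assert boundaries == old_slices
```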
Test plan
- Unit tests for the boundary helpers (`tests/train/test_prompt_mini_batch.py`)
- Trainer tests (`tests/train/test_trainer.py`)
- `bash examples/train/turn_level_rewards/run_gsm8k_multi_turn.sh` with `step_wise_trajectories=true, max_turns=5`: previously crashed, now completes training steps successfully. Timing: `convert_to_training_input` = 1.92s (negligible overhead from the uid scan), `policy_train` = 50.22s.

🤖 Generated with Claude Code