diff --git a/docs/content/docs/tutorials/step-wise-training.mdx b/docs/content/docs/tutorials/step-wise-training.mdx index 9459beff6b..071f222f86 100644 --- a/docs/content/docs/tutorials/step-wise-training.mdx +++ b/docs/content/docs/tutorials/step-wise-training.mdx @@ -26,7 +26,7 @@ This page will also guide you how to implement step-wise training for your custo When step-wise is enabled, a batch of T trajectories with an average of M turns per trajectory produces T×M training samples (sequences). This means: -- **Each mini-batch contains the same number of sequences** (`policy_mini_batch_size * n_samples`), but those sequences are now step-samples rather than full trajectories. The effective number of trajectories per mini-batch is reduced. The number of mini-batches (and hence optimizer steps) per training batch increases by the average number of turns — so if you have `train_batch_size=mini_batch_size=32` with an average of 3 turns, you get 3 optimizer steps instead of 1 for each training step. It is also possible that a mini-batch boundary falls mid-trajectory. +- **Each mini-batch contains the sequences for exactly `policy_mini_batch_size` prompts**, regardless of how many turns those prompts produced. This means the number of mini-batches (and hence optimizer steps) per training batch is always `train_batch_size / policy_mini_batch_size`, independent of the number of turns. This also means that the actual mini batch size (number of sequences) trained in each mini batch can vary. Each mini batch always leads to a single optimizer step. - **Advantages are computed on last steps only**, then broadcast to all steps of the same trajectory. This is mathematically equivalent to non-step-wise advantage computation for GRPO. - **Training time grows as O(T²) vs O(T)**, since each trajectory of T turns becomes T sequences to forward (each with a growing prompt prefix), as opposed to 1 sequence. 
SkyRL will support prefix-aware merging of per-step sequences when the prefix matches (WIP), which brings the cost back to O(T) in the common case. - **Metrics** like `generate/avg_num_tokens` and `generate/avg_response_length` are per-turn rather than per-trajectory, since each training sample is a single turn. diff --git a/skyrl/backends/skyrl_train/distributed/dispatch.py b/skyrl/backends/skyrl_train/distributed/dispatch.py index 18e56666d1..f7278a0ac0 100644 --- a/skyrl/backends/skyrl_train/distributed/dispatch.py +++ b/skyrl/backends/skyrl_train/distributed/dispatch.py @@ -12,6 +12,7 @@ from skyrl.backends.skyrl_train.training_batch import ( TrainingInputBatch, TrainingOutputBatch, + pad_training_input_batch, ) @@ -170,37 +171,36 @@ def stage_chunks( cls, dp_size: int, data: TrainingInputBatch, - mini_batch_size: int, + mini_batch_boundaries: List[Tuple[int, int]], ) -> List[List[ObjectRef]]: - """ - Pre-stage all mini-batch chunks into the object store. + """Pre-stage mini-batch chunks into the object store. - Slices the full batch into mini-batches, chunks each by DP rank, and - ``ray.put``s each chunk. + Each mini-batch is defined by a ``(start, end)`` index pair from mini_batch_boundaries. + Mini-batches are individually padded so that their size is divisible by dp_size, using dummy + entries with ``loss_mask=0`` that do not affect the loss. Args: - dp_size: Number of data-parallel ranks - data: Full TrainingInputBatch to slice from - mini_batch_size: Size of each mini-batch (before DP chunking) + dp_size: Number of data-parallel ranks. + data: Full TrainingInputBatch to slice from. + mini_batch_boundaries: List of ``(start, end)`` index pairs. The i-th mini-batch is + data[mini_batch_boundaries[i][0]:mini_batch_boundaries[i][1]]. Returns: - List of per-mini-batch chunk ref lists. ``result[i][dp_rank]`` is - the ObjectRef for mini-batch *i*, DP rank *dp_rank*. + ``result[i][dp_rank]`` - ObjectRef for mini-batch *i*, DP rank *dp_rank*. 
""" - assert ( - len(data) % mini_batch_size == 0 - ), f"data batch size must be divisible by mini_batch_size, got {len(data)} and {mini_batch_size}" - assert ( - mini_batch_size % dp_size == 0 - ), f"mini_batch_size must be divisible by dp_size, got {mini_batch_size} and {dp_size}" - num_mini_batches = len(data) // mini_batch_size - chunk_size = mini_batch_size // dp_size - all_chunk_refs: List[List[ObjectRef]] = [] - for step in range(num_mini_batches): - start = step * mini_batch_size - end = start + mini_batch_size + for start, end in mini_batch_boundaries: mini_batch = data[start:end] + mb_size = end - start + + # Pad to make divisible by dp_size. Will only be non-zero for step-wise training. + pad_size = (-mb_size) % dp_size + if pad_size > 0: + mini_batch = pad_training_input_batch(mini_batch, pad_size) + + mini_batch_size = len(mini_batch) + assert mini_batch_size % dp_size == 0, f"mini_batch_size % dp_size != 0, got {mini_batch_size} and {dp_size}" + chunk_size = mini_batch_size // dp_size chunks = mini_batch.chunk(chunk_size) all_chunk_refs.append([ray.put(chunk) for chunk in chunks]) return all_chunk_refs diff --git a/skyrl/backends/skyrl_train/utils/ppo_utils.py b/skyrl/backends/skyrl_train/utils/ppo_utils.py index 05b7b0a0e3..64606b26a7 100644 --- a/skyrl/backends/skyrl_train/utils/ppo_utils.py +++ b/skyrl/backends/skyrl_train/utils/ppo_utils.py @@ -1007,6 +1007,7 @@ def apply_loss_reduction_to_advantages_minibatch( Args: advantages: Advantage tensor of shape (minibatch_size, seq_len). + For step-wise training, minibatch_size can be variable. loss_mask: Mask of shape (minibatch_size, seq_len) indicating valid loss tokens. loss_reduction: One of "token_mean", "token_mean_legacy", "sequence_mean", "seq_mean_token_sum_norm". micro_batch_size: Number of sequences per micro-batch @@ -1024,6 +1025,8 @@ def apply_loss_reduction_to_advantages_minibatch( # Option 1b: legacy token-mean that normalizes per-microbatch then averages across microbatches. 
elif loss_reduction == "token_mean_legacy": + # FIXME(Charlie): For step-wise training, mini_batch_size can be variable, so the number of + # microbatches depends on how we pad. num_micro_batches = batch_size // micro_batch_size for i in range(num_micro_batches): start_idx = i * micro_batch_size diff --git a/skyrl/backends/skyrl_train/workers/worker_dispatch.py b/skyrl/backends/skyrl_train/workers/worker_dispatch.py index a72223807e..91e14a356f 100644 --- a/skyrl/backends/skyrl_train/workers/worker_dispatch.py +++ b/skyrl/backends/skyrl_train/workers/worker_dispatch.py @@ -9,7 +9,7 @@ """ from dataclasses import dataclass -from typing import Any, Dict, List, Optional +from typing import Any, Dict, List, Optional, Tuple import ray from ray import ObjectRef @@ -162,24 +162,28 @@ def forward(self, model: str, data: TrainingInputBatch) -> TrainingOutputBatch: output = concatenate_outputs_after_mesh_dispatch(self._actor_groups[model].actor_infos, results) return output - def stage_data(self, model: str, data: TrainingInputBatch, mini_batch_size: int) -> List[List[ObjectRef]]: - """ - Pre-stage all mini-batch chunks in the Ray object store. + def stage_data( + self, + model: str, + data: TrainingInputBatch, + mini_batch_boundaries: List[Tuple[int, int]], + ) -> List[List[ObjectRef]]: + """Pre-stage mini-batch chunks in the Ray object store. Call this once before the training loop so that all serialization is done upfront and GPUs stay saturated during training. Args: - model: Model name (used to look up DP size) - data: Full training batch - mini_batch_size: Size of each mini-batch (before DP chunking) + model: Model name (used to look up DP size). + data: Full training batch. + mini_batch_boundaries: List of ``(start, end)`` index pairs. + The i-th mini-batch is data[mini_batch_boundaries[i][0]:mini_batch_boundaries[i][1]]. Returns: - ``result[i][dp_rank]`` is the ObjectRef for mini-batch *i*, - DP rank *dp_rank*. 
def compute_prompt_mini_batch_boundaries(
    uids: List[str],
    mini_batch_size: int,
    train_batch_size: int,
    is_stepwise: bool,
    n_samples_per_prompt: int,
) -> List[Tuple[int, int]]:
    """Compute mini-batch ``(start, end)`` slices from a flat ``uids`` list.

    Consecutive equal entries in ``uids`` belong to the same prompt, and all sequences of a
    prompt (all ``n_samples_per_prompt`` trajectories, or all per-step sequences of a
    trajectory) are assumed to be contiguous. Each mini-batch spans exactly
    ``mini_batch_size`` prompts; since ``train_batch_size`` must be divisible by
    ``mini_batch_size``, every mini-batch covers the same number of prompts. In step-wise
    training the number of *sequences* per mini-batch may still vary. Works for both
    step-wise (variable sequences per prompt) and non-step-wise (fixed
    ``n_samples_per_prompt`` sequences per prompt) training.

    Args:
        uids: List of uids, one per sequence, identifying which prompt each sequence
            belongs to.
        mini_batch_size: Number of prompts to include in each mini-batch. Same as training
            config's config.trainer.policy_mini_batch_size or
            config.trainer.critic_mini_batch_size.
        train_batch_size: Number of prompts in a training batch. For sanity check.
        is_stepwise: Whether the training is step-wise. For sanity check.
        n_samples_per_prompt: How many samples per prompt. For sanity check.

    Returns:
        List of ``(start, end)`` sequence-index pairs, one per mini-batch. Always exactly
        ``train_batch_size // mini_batch_size`` entries.

    Raises:
        AssertionError: If ``uids`` is empty or non-contiguous, the number of distinct
            prompts differs from ``train_batch_size``, or ``train_batch_size`` is not
            divisible by ``mini_batch_size``.

    Example A: normal non-step-wise training, with n_samples_per_prompt=2 and train_batch_size=4.
        uids = ["p0", "p0", "p1", "p1", "p2", "p2", "p3", "p3"]
        mini_batch_size = 2
        prompt_end_indices = [2, 4, 6, 8]
        boundaries = [(0, 4), (4, 8)]  # each mini batch spans exactly 2 prompts, hence 4 sequences

    Example B: step-wise training with n_samples_per_prompt = 2, and each trajectory can have 1-2 turns.
        uids = ["p0", "p0", "p0", "p0", "p1", "p1", "p2", "p2", "p2", "p3", "p3"]
        mini_batch_size = 2
        prompt_end_indices = [4, 6, 9, 11]
        boundaries = [(0, 6), (6, 11)]
    """
    # Guard: an empty batch would otherwise fail below with a bare IndexError on uids[0].
    assert uids, "uids must be non-empty"

    # First compute the end index (exclusive) of each prompt's contiguous run of sequences.
    prompt_end_indices: List[int] = []
    seen_uids: set[str] = {uids[0]}
    for i in range(1, len(uids)):
        if uids[i] != uids[i - 1]:
            assert (
                uids[i] not in seen_uids
            ), f"uid {uids[i]!r} appears in non-contiguous positions at index {i}. Full uids: {uids}"
            seen_uids.add(uids[i])
            prompt_end_indices.append(i)
    prompt_end_indices.append(len(uids))

    # The number of distinct prompts must equal `train_batch_size`.
    num_prompts = len(prompt_end_indices)
    assert num_prompts == train_batch_size and len(seen_uids) == train_batch_size
    assert train_batch_size % mini_batch_size == 0

    # Compute boundaries: each mini-batch ends where its last prompt's sequences end.
    boundaries: List[Tuple[int, int]] = []
    start_seq = 0
    for i in range(0, num_prompts, mini_batch_size):
        end_prompt_idx = i + mini_batch_size - 1  # last prompt belonging to this mini-batch
        end_seq = prompt_end_indices[end_prompt_idx]
        boundaries.append((start_seq, end_seq))
        start_seq = end_seq

    # Every training batch yields exactly train_batch_size // mini_batch_size mini-batches
    # (and hence optimizer steps), step-wise or not — the divisibility assert above
    # guarantees no short tail mini-batch.
    assert len(boundaries) == train_batch_size // mini_batch_size

    # For non-step-wise training, boundaries must additionally be uniform.
    if not is_stepwise:
        expected_num_seq_in_mini_batch = n_samples_per_prompt * mini_batch_size
        for j, (start, end) in enumerate(boundaries):
            assert start == j * expected_num_seq_in_mini_batch
            assert end - start == expected_num_seq_in_mini_batch

    return boundaries
prompt_ids: List[List[int]] = generator_output["prompt_token_ids"] response_ids: List[List[int]] = generator_output["response_ids"] rewards: List[List[float]] = generator_output["rewards"] @@ -628,6 +639,8 @@ def convert_to_training_input(self, generator_output: GeneratorOutput, uids: Lis pixel_values = TensorList(pixel_values) image_grid_thw = TensorList(image_grid_thw) + + # 2. Convert to tensors. ( sequences_tensor, attention_masks_tensor, @@ -657,6 +670,7 @@ def convert_to_training_input(self, generator_output: GeneratorOutput, uids: Lis ), "expected non-null rollout logprobs tensor when off_policy_correction is enabled" assert rollout_logprobs_tensor.shape == loss_masks_tensor.shape, "Logprobs should look like responses" + # 3. Create training input batch. training_input = TrainingInputBatch( { "sequences": sequences_tensor, # Full trajectories (padded and concatenated prompts and responses) @@ -674,7 +688,20 @@ def convert_to_training_input(self, generator_output: GeneratorOutput, uids: Lis if generator_output.get("is_last_step", None) is not None: training_input.metadata["is_last_step"] = generator_output["is_last_step"] - # padded response length + # 4. Compute mini-batch boundaries for train_critic_and_policy(). It excludes the ones + # we will add in pad_training_input_batch(). + train_batch_size = self.cfg.trainer.train_batch_size + n_samples_per_prompt = self.cfg.generator.n_samples_per_prompt + is_stepwise = self.cfg.generator.step_wise_trajectories + training_input.metadata["policy_mini_batch_boundaries"] = compute_prompt_mini_batch_boundaries( + uids, self.cfg.trainer.policy_mini_batch_size, train_batch_size, is_stepwise, n_samples_per_prompt + ) + if self.cfg.trainer.critic.model.path is not None: + training_input.metadata["critic_mini_batch_boundaries"] = compute_prompt_mini_batch_boundaries( + uids, self.cfg.trainer.critic_mini_batch_size, train_batch_size, is_stepwise, n_samples_per_prompt + ) + + # 5. Record metadata and metrics. 
training_input.metadata["response_length"] = response_masks_tensor.shape[1] batch_num_seq, batch_padded_seq_len = sequences_tensor.shape logger.info(f"batch_num_seq: {batch_num_seq}, batch_padded_seq_len: {batch_padded_seq_len}") @@ -684,14 +711,11 @@ def convert_to_training_input(self, generator_output: GeneratorOutput, uids: Lis "generate/batch_padded_seq_len": batch_padded_seq_len, } ) - if self.cfg.generator.step_wise_trajectories: - assert ( - "trajectory_ids" in generator_output - ), "Expected `trajectory_ids` in generator output for step wise training" training_input.metadata["avg_response_length"] = sum( len(sample_response_ids) for sample_response_ids in response_ids ) / len(response_ids) + # 6. Pad the batch, only needed for step-wise training's `fwd_logprobs_values_reward()`. logger.info(f"Number of sequences before padding: {len(training_input['sequences'])}") dp_size = self.dispatch.get_lcm_dp_size() pad_size = math.ceil(training_input.batch_size / dp_size) * dp_size - training_input.batch_size @@ -1052,7 +1076,11 @@ def apply_reward_kl_penalty( return data @torch.no_grad() - def _normalize_advantages(self, data: TrainingInputBatch, mini_batch_size: int) -> TrainingInputBatch: + def _normalize_advantages( + self, + data: TrainingInputBatch, + mini_batch_boundaries: List[Tuple[int, int]], + ) -> TrainingInputBatch: advantages = data["advantages"] response_mask = data["response_mask"] @@ -1065,11 +1093,8 @@ def _normalize_advantages(self, data: TrainingInputBatch, mini_batch_size: int) data["advantages"] = (advantages - mean) * rstd # Step 2: Loss reduction normalization per mini-batch - num_mini_batches = len(data) // mini_batch_size normalized_advantages = torch.zeros_like(advantages) - for local_step in range(num_mini_batches): - start_idx = local_step * mini_batch_size - end_idx = (local_step + 1) * mini_batch_size + for start_idx, end_idx in mini_batch_boundaries: mini_batch = data[start_idx:end_idx] normalized_advantages[start_idx:end_idx] = 
apply_loss_reduction_to_advantages_minibatch( advantages=mini_batch["advantages"], @@ -1099,20 +1124,17 @@ def _execute_training_step(self, model: str, data: TrainingInputBatch) -> Dict[s Returns: Dict of reduced metrics from training """ - # Compute mini batch size from config (algorithm-level concept) - n_samples = self.cfg.generator.n_samples_per_prompt + boundaries = data.metadata[f"{model}_mini_batch_boundaries"] + if model == "policy": - mini_batch_size = self.cfg.trainer.policy_mini_batch_size * n_samples # Normalize advantages for policy training; critic training does not need this - data = self._normalize_advantages(data, mini_batch_size) - else: - mini_batch_size = self.cfg.trainer.critic_mini_batch_size * n_samples + data = self._normalize_advantages(data, boundaries) all_metrics: Dict[str, List[float]] = defaultdict(list) # Pre-stage all per-DP mini-batch chunks in the object store so that # serialization is fully off the critical path during training. - all_chunk_refs = self.dispatch.stage_data(model, data, mini_batch_size) + all_chunk_refs = self.dispatch.stage_data(model, data, boundaries) # Training loop over epochs and mini-batches for _epoch in range(self.cfg.trainer.update_epochs_per_batch): diff --git a/skyrl/train/utils/utils.py b/skyrl/train/utils/utils.py index 8344ec900f..014c97e697 100644 --- a/skyrl/train/utils/utils.py +++ b/skyrl/train/utils/utils.py @@ -97,6 +97,9 @@ def validate_batch_sizes(cfg: SkyRLTrainConfig): f"train_batch_size {cfg.trainer.train_batch_size} should be divisible by " f"policy_mini_batch_size {cfg.trainer.policy_mini_batch_size}" ) + + # TODO(Charlie): For step-wise training, the number of sequences per prompt is variable, and + # padded mini-batch may not be divisible by dp_size. Should check if we need these assertions. 
policy_mini_batch_size_per_gpu = ( cfg.trainer.policy_mini_batch_size * cfg.generator.n_samples_per_prompt // policy_dp_size ) @@ -294,6 +297,12 @@ def validate_cfg(cfg: SkyRLTrainConfig): f"connection. Use an outcome-based estimator (grpo, rloo, maxrl) or disable " f"step_wise_trajectories." ) + if cfg.generator.step_wise_trajectories and cfg.trainer.algorithm.loss_reduction == "token_mean_legacy": + # TODO(Charlie): this can be fixed, can revisit later. + raise ValueError( + "`token_mean_legacy` loss reduction is not supported with step-wise training. " + "Use `token_mean` instead." + ) assert cfg.trainer.algorithm.loss_reduction in ( "token_mean", diff --git a/tests/train/test_prompt_mini_batch.py b/tests/train/test_prompt_mini_batch.py new file mode 100644 index 0000000000..616a24e791 --- /dev/null +++ b/tests/train/test_prompt_mini_batch.py @@ -0,0 +1,323 @@ +""" +Tests for prompt-based mini-batching. + +uv run --isolated --extra dev pytest tests/train/test_prompt_mini_batch.py -v +""" + +from typing import List, Tuple +from unittest.mock import patch + +import pytest +import torch + +from skyrl.backends.skyrl_train.distributed.dispatch import MeshDispatch +from skyrl.backends.skyrl_train.training_batch import TrainingInputBatch +from skyrl.train.dataset.preprocess import compute_prompt_mini_batch_boundaries + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _make_uids_stepwise( + prompts: List[Tuple[str, int, List[int]]], +) -> List[str]: + """Build uid list for a step-wise batch. + + Args: + prompts: List of (instance_id, spp, turns_per_sample) tuples. + ``turns_per_sample`` is a list of length ``spp`` giving + the number of turns for each trajectory of that prompt. + + Returns: + Flat uid list — same uid for all sequences of the same prompt. 
+ """ + uids: List[str] = [] + for instance_id, _, turns_list in prompts: + for num_turns in turns_list: + for _ in range(num_turns): + uids.append(instance_id) + return uids + + +def _make_uids_fixed(train_batch_size: int, spp: int) -> List[str]: + """Build uid list for a non-step-wise batch (fixed spp per prompt). + + spp: samples per prompt. + + Example: + train_batch_size = 4, spp = 2 + uids = ["p0", "p0", "p1", "p1", "p2", "p2", "p3", "p3"] + """ + return [f"p{i}" for i in range(train_batch_size) for _ in range(spp)] + + +def _make_batch(num_sequences: int, seq_len: int = 8) -> TrainingInputBatch: + """Create a minimal TrainingInputBatch with the given number of sequences.""" + batch = TrainingInputBatch( + { + "sequences": torch.randint(0, 100, (num_sequences, seq_len)), + "attention_mask": torch.ones(num_sequences, seq_len, dtype=torch.long), + "response_mask": torch.ones(num_sequences, seq_len, dtype=torch.long), + "advantages": torch.randn(num_sequences, seq_len), + "loss_mask": torch.ones(num_sequences, seq_len, dtype=torch.float), + } + ) + batch.metadata = { + "is_last_step": [False] * num_sequences, + } + return batch + + +# --------------------------------------------------------------------------- +# Tests for compute_prompt_mini_batch_boundaries +# --------------------------------------------------------------------------- + + +class TestComputePromptMiniBatchBoundaries: + def test_nonstepwise_training(self): + """Test non-stepwise training with different mini batch sizes. + + For non-stepwise training, len(uids) % mini_batch_size should be 0. 
+ """ + train_batch_size = 4 + spp = 2 + is_stepwise = False + uids = ["p0", "p0", "p1", "p1", "p2", "p2", "p3", "p3"] + + for mini_batch_size, expected_boundaries in [ + (1, [(0, 2), (2, 4), (4, 6), (6, 8)]), + (2, [(0, 4), (4, 8)]), + (4, [(0, 8)]), + ]: + boundaries = compute_prompt_mini_batch_boundaries(uids, mini_batch_size, train_batch_size, is_stepwise, spp) + assert boundaries == expected_boundaries + + def test_noncontiguous_uids_raise(self): + """Non-contiguous uids should raise an assertion error.""" + train_batch_size = 4 + spp = 2 + is_stepwise = False + uids = ["p0", "p0", "p1", "p0", "p2", "p2", "p3"] + with pytest.raises(AssertionError, match="uid 'p0' appears in non-contiguous positions at index 3."): + compute_prompt_mini_batch_boundaries(uids, 2, train_batch_size, is_stepwise, spp) + + def test_train_batch_size_not_equal_unique_uids_raise(self): + """When the number of prompts is not equal to the train batch size, raise an assertion error.""" + is_stepwise = False + train_batch_size = 4 + spp = 2 + mini_batch_size = 2 + uids = ["p0", "p0", "p1", "p1", "p2", "p2"] + with pytest.raises(AssertionError): + compute_prompt_mini_batch_boundaries(uids, mini_batch_size, train_batch_size, is_stepwise, spp) + + def test_stepwise_training(self): + """Step-wise: prompts have variable numbers of turns.""" + # Test 1: Each trajectory can have 1-4 turns, train_batch_size = 4, spp = 2. 
+ mini_batch_size = 2 + train_batch_size = 4 + spp = 2 + is_stepwise = True + uids = _make_uids_stepwise( + [ + ("p0", 2, [3, 2]), # 5 seqs with a 3-turn trajectory and a 2-turn trajectory + ("p1", 2, [1, 4]), # 5 seqs + ("p2", 2, [2, 1]), # 3 seqs + ("p3", 2, [1, 1]), # 2 seqs + ] + ) + assert uids == ["p0", "p0", "p0", "p0", "p0", "p1", "p1", "p1", "p1", "p1", "p2", "p2", "p2", "p3", "p3"] + assert len(uids) == 15 + assert [(0, 10), (10, 15)] == compute_prompt_mini_batch_boundaries( + uids, mini_batch_size, train_batch_size, is_stepwise, spp + ) + + # Test 2: Each mini batch only has 1 prompt. + mini_batch_size = 1 + train_batch_size = 2 + spp = 3 + is_stepwise = True + uids = _make_uids_stepwise( + [ + ("p0", 3, [2, 1, 3]), # 6 seqs + ("p1", 3, [1, 1, 1]), # 3 seqs + ] + ) + assert [(0, 6), (6, 9)] == compute_prompt_mini_batch_boundaries( + uids, mini_batch_size, train_batch_size, is_stepwise, spp + ) + + @pytest.mark.parametrize( + "train_batch_size, spp, mini_batch_size", + [ + (8, 4, 4), + (256, 5, 128), + (16, 1, 4), + (32, 8, 32), + (128, 5, 64), + ], + ) + def test_non_stepwise_boundaries_are_uniform(self, train_batch_size, spp, mini_batch_size): + """ + For non-step-wise, every boundary must be [i*mb_size, (i+1)*mb_size). We run various + parametrization to make sure the assertion in `compute_prompt_mini_batch_boundaries()` passes. 
+ """ + is_stepwise = False + uids = _make_uids_fixed(train_batch_size, spp) + boundaries = compute_prompt_mini_batch_boundaries(uids, mini_batch_size, train_batch_size, is_stepwise, spp) + + +# --------------------------------------------------------------------------- +# Tests for MeshDispatch.stage_chunks +# --------------------------------------------------------------------------- + + +class TestStageChunksVariable: + def test_uniform_minibatches_dp1(self): + """All mini-batches same size, dp_size=1 => no padding needed.""" + batch = _make_batch(10) + batch.metadata["pad_size"] = 0 + boundaries = [(0, 5), (5, 10)] + + with patch("skyrl.backends.skyrl_train.distributed.dispatch.ray") as mock_ray: + mock_ray.put.side_effect = lambda x: x + result = MeshDispatch.stage_chunks(dp_size=1, data=batch, mini_batch_boundaries=boundaries) + + assert len(result) == 2 + assert len(result[0]) == 1 + assert len(result[1]) == 1 + + def test_variable_minibatches_dp2_padding(self): + """Variable sizes with dp_size=2 => odd-sized mini-batches get padded.""" + batch = _make_batch(7) + batch.metadata["pad_size"] = 0 + boundaries = [(0, 3), (3, 7)] + + with patch("skyrl.backends.skyrl_train.distributed.dispatch.ray") as mock_ray: + chunks_put = [] + mock_ray.put.side_effect = lambda x: (chunks_put.append(x), len(chunks_put) - 1)[1] + result = MeshDispatch.stage_chunks(dp_size=2, data=batch, mini_batch_boundaries=boundaries) + + assert len(result) == 2 + assert len(result[0]) == 2 # 3->4, split into 2 + assert len(result[1]) == 2 # 4, split into 2 + assert len(chunks_put[0]) + len(chunks_put[1]) == 4 # padded from 3 + + def test_dp_size_4_heavy_padding(self): + """dp_size=4, mini-batch of 5 => padded to 8.""" + batch = _make_batch(5) + batch.metadata["pad_size"] = 0 + boundaries = [(0, 5)] + + with patch("skyrl.backends.skyrl_train.distributed.dispatch.ray") as mock_ray: + chunks_put = [] + mock_ray.put.side_effect = lambda x: (chunks_put.append(x), len(chunks_put) - 1)[1] + result 
= MeshDispatch.stage_chunks(dp_size=4, data=batch, mini_batch_boundaries=boundaries) + + assert len(result) == 1 + assert len(result[0]) == 4 + for chunk in chunks_put: + assert len(chunk) == 2 + + def test_loss_mask_zero_for_padding(self): + """Padding entries should have loss_mask=0.""" + batch = _make_batch(3, seq_len=4) + batch.metadata = {"pad_size": 0} + boundaries = [(0, 3)] + + with patch("skyrl.backends.skyrl_train.distributed.dispatch.ray") as mock_ray: + chunks_put = [] + mock_ray.put.side_effect = lambda x: (chunks_put.append(x), len(chunks_put) - 1)[1] + MeshDispatch.stage_chunks(dp_size=2, data=batch, mini_batch_boundaries=boundaries) + + all_loss_masks = torch.cat([c["loss_mask"] for c in chunks_put], dim=0) + assert all_loss_masks[:3].sum() == 3 * 4 + assert all_loss_masks[3].sum() == 0 + + def test_is_last_step_metadata_true_for_padding(self): + """Padding entries should have is_last_step=True in metadata.""" + batch = _make_batch(3) + boundaries = [(0, 3)] + + with patch("skyrl.backends.skyrl_train.distributed.dispatch.ray") as mock_ray: + chunks_put = [] + mock_ray.put.side_effect = lambda x: (chunks_put.append(x), len(chunks_put) - 1)[1] + MeshDispatch.stage_chunks(dp_size=2, data=batch, mini_batch_boundaries=boundaries) + + # The padded mini-batch (3->4) has is_last_step in metadata + # chunks share the same metadata reference from the padded mini-batch + assert chunks_put[0].metadata["is_last_step"] == [False, False, False, True] + + +# --------------------------------------------------------------------------- +# Tests for optimizer step count invariance +# --------------------------------------------------------------------------- + + +class TestOptimizerStepCount: + def test_num_minibatches_equals_train_over_policy(self): + """Number of mini-batches = train_batch_size / policy_mini_batch_size.""" + uids = _make_uids_stepwise( + [ + ("p0", 2, [3, 2]), + ("p1", 2, [1, 4]), + ("p2", 2, [2, 1]), + ("p3", 2, [1, 1]), + ] + ) + boundaries = 
class TestOptimizerStepCount:
    """The optimizer-step count must not depend on step-wise turn counts."""

    def test_num_minibatches_equals_train_over_policy(self):
        """Number of mini-batches = train_batch_size / policy_mini_batch_size."""
        uids = _make_uids_stepwise(
            [("p0", 2, [3, 2]), ("p1", 2, [1, 4]), ("p2", 2, [2, 1]), ("p3", 2, [1, 1])]
        )
        boundaries = compute_prompt_mini_batch_boundaries(
            uids, mini_batch_size=2, train_batch_size=4, is_stepwise=True, n_samples_per_prompt=2
        )
        assert len(boundaries) == 4 // 2

    def test_step_count_with_epochs(self):
        """Total optimizer steps = num_mini_batches * update_epochs_per_batch."""
        uids = _make_uids_stepwise(
            [
                ("p0", 5, [3, 2, 1, 4, 2]),
                ("p1", 5, [1, 1, 1, 1, 1]),
                ("p2", 5, [2, 3, 1, 1, 2]),
                ("p3", 5, [1, 2, 3, 2, 1]),
            ]
        )
        boundaries = compute_prompt_mini_batch_boundaries(
            uids, mini_batch_size=2, train_batch_size=4, is_stepwise=True, n_samples_per_prompt=5
        )
        update_epochs = 3
        assert len(boundaries) * update_epochs == (4 // 2) * update_epochs

    def test_same_step_count_as_non_stepwise(self):
        """Step-wise and non-step-wise produce the same number of mini-batches."""
        train_batch_size = 256
        policy_mini_batch_size = 128
        n_samples = 5

        # Non-step-wise baseline.
        non_stepwise_bounds = compute_prompt_mini_batch_boundaries(
            _make_uids_fixed(train_batch_size, n_samples),
            mini_batch_size=policy_mini_batch_size,
            train_batch_size=train_batch_size,
            is_stepwise=False,
            n_samples_per_prompt=n_samples,
        )

        # Step-wise with variable turn counts per trajectory.
        prompts = [
            (f"p{i}", n_samples, [1 + (i * j) % 4 for j in range(n_samples)])
            for i in range(train_batch_size)
        ]
        stepwise_bounds = compute_prompt_mini_batch_boundaries(
            _make_uids_stepwise(prompts),
            mini_batch_size=policy_mini_batch_size,
            train_batch_size=train_batch_size,
            is_stepwise=True,
            n_samples_per_prompt=n_samples,
        )

        assert len(stepwise_bounds) == len(non_stepwise_bounds) == 2

        # Non-step-wise boundaries should be uniform: 128 prompts x 5 samples = 640 each.
        assert non_stepwise_bounds == [(0, 640), (640, 1280)]