2 changes: 1 addition & 1 deletion docs/content/docs/tutorials/step-wise-training.mdx
@@ -26,7 +26,7 @@ This page will also guide you how to implement step-wise training for your custo

When step-wise is enabled, a batch of T trajectories with an average of M turns per trajectory produces T×M training samples (sequences). This means:

- **Each mini-batch contains the same number of sequences** (`policy_mini_batch_size * n_samples`), but those sequences are now step-samples rather than full trajectories. The effective number of trajectories per mini-batch is reduced. The number of mini-batches (and hence optimizer steps) per training batch increases by the average number of turns — so if you have `train_batch_size=mini_batch_size=32` with an average of 3 turns, you get 3 optimizer steps instead of 1 for each training step. It is also possible that a mini-batch boundary falls mid-trajectory.
- **Each mini-batch contains the sequences for exactly `policy_mini_batch_size` prompts**, regardless of how many turns those prompts produced. This means the number of mini-batches (and hence optimizer steps) per training batch is always `train_batch_size / policy_mini_batch_size`, independent of the number of turns. It also means the actual mini-batch size (number of sequences) trained in each mini-batch can vary. Each mini-batch always leads to a single optimizer step.
- **Advantages are computed on last steps only**, then broadcast to all steps of the same trajectory. This is mathematically equivalent to non-step-wise advantage computation for GRPO.
- **Training time grows as O(T²) vs O(T)**, since each trajectory of T turns becomes T sequences to forward (each with a growing prompt prefix), as opposed to 1 sequence. SkyRL will support prefix-aware merging of per-step sequences when the prefix matches (WIP), which brings the cost back to O(T) in the common case.
- **Metrics** like `generate/avg_num_tokens` and `generate/avg_response_length` are per-turn rather than per-trajectory, since each training sample is a single turn.
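The difference between the two mini-batching schemes above can be made concrete with the numbers from the first bullet (`train_batch_size = mini_batch_size = 32`, an average of 3 turns, and `n_samples_per_prompt = 1` assumed for simplicity); this is an illustrative sketch, not SkyRL code:

```python
# Hypothetical numbers matching the example in the bullets above.
train_batch_size = 32  # prompts per training batch
mini_batch_size = 32   # sequences (scheme 1) or prompts (scheme 2) per mini-batch
avg_turns = 3          # average turns per trajectory

# Scheme 1: each mini-batch holds a fixed number of *sequences* (step-samples),
# so the number of optimizer steps grows with the average turn count.
steps_fixed_sequences = (train_batch_size * avg_turns) // mini_batch_size  # 3 steps

# Scheme 2: each mini-batch holds a fixed number of *prompts*, so the number of
# optimizer steps is independent of turns; sequence count per mini-batch varies instead.
steps_fixed_prompts = train_batch_size // mini_batch_size  # 1 step
```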
44 changes: 22 additions & 22 deletions skyrl/backends/skyrl_train/distributed/dispatch.py
@@ -12,6 +12,7 @@
from skyrl.backends.skyrl_train.training_batch import (
TrainingInputBatch,
TrainingOutputBatch,
pad_training_input_batch,
)


@@ -170,37 +171,36 @@ def stage_chunks(
cls,
dp_size: int,
data: TrainingInputBatch,
mini_batch_size: int,
mini_batch_boundaries: List[Tuple[int, int]],
) -> List[List[ObjectRef]]:
"""
Pre-stage all mini-batch chunks into the object store.
"""Pre-stage mini-batch chunks into the object store.

Slices the full batch into mini-batches, chunks each by DP rank, and
``ray.put``s each chunk.
Each mini-batch is defined by a ``(start, end)`` index pair from mini_batch_boundaries.
Mini-batches are individually padded so that their size is divisible by dp_size, using dummy
entries with ``loss_mask=0`` that do not affect the loss.

Args:
dp_size: Number of data-parallel ranks
data: Full TrainingInputBatch to slice from
mini_batch_size: Size of each mini-batch (before DP chunking)
dp_size: Number of data-parallel ranks.
data: Full TrainingInputBatch to slice from.
mini_batch_boundaries: List of ``(start, end)`` index pairs. The i-th mini-batch is
data[mini_batch_boundaries[i][0]:mini_batch_boundaries[i][1]].

Returns:
List of per-mini-batch chunk ref lists. ``result[i][dp_rank]`` is
the ObjectRef for mini-batch *i*, DP rank *dp_rank*.
``result[i][dp_rank]`` - ObjectRef for mini-batch *i*, DP rank *dp_rank*.
"""
assert (
len(data) % mini_batch_size == 0
), f"data batch size must be divisible by mini_batch_size, got {len(data)} and {mini_batch_size}"
assert (
mini_batch_size % dp_size == 0
), f"mini_batch_size must be divisible by dp_size, got {mini_batch_size} and {dp_size}"
num_mini_batches = len(data) // mini_batch_size
chunk_size = mini_batch_size // dp_size

all_chunk_refs: List[List[ObjectRef]] = []
for step in range(num_mini_batches):
start = step * mini_batch_size
end = start + mini_batch_size
for start, end in mini_batch_boundaries:
mini_batch = data[start:end]
mb_size = end - start

# Pad to make divisible by dp_size. Will only be non-zero for step-wise training.
pad_size = (-mb_size) % dp_size
if pad_size > 0:
mini_batch = pad_training_input_batch(mini_batch, pad_size)

mini_batch_size = len(mini_batch)
assert mini_batch_size % dp_size == 0, f"mini_batch_size % dp_size != 0, got {mini_batch_size} and {dp_size}"
chunk_size = mini_batch_size // dp_size
chunks = mini_batch.chunk(chunk_size)
all_chunk_refs.append([ray.put(chunk) for chunk in chunks])
return all_chunk_refs
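The `(-mb_size) % dp_size` expression in the hunk above computes the smallest non-negative pad that makes a variable-size mini-batch divisible by the DP world size. A minimal standalone check (`pad_size_for` is an illustrative helper, not part of SkyRL):

```python
def pad_size_for(mb_size: int, dp_size: int) -> int:
    # Smallest non-negative pad such that (mb_size + pad) % dp_size == 0.
    # Python's % returns a non-negative result for a positive modulus,
    # so this works even when mb_size is already divisible.
    return (-mb_size) % dp_size

# Variable-size mini-batches from step-wise training:
assert pad_size_for(6, 4) == 2  # 6 sequences -> pad to 8, i.e. 2 per DP rank
assert pad_size_for(7, 4) == 1
assert pad_size_for(8, 4) == 0  # already divisible; no dummy entries added
```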
3 changes: 3 additions & 0 deletions skyrl/backends/skyrl_train/utils/ppo_utils.py
@@ -1007,6 +1007,7 @@ def apply_loss_reduction_to_advantages_minibatch(

Args:
advantages: Advantage tensor of shape (minibatch_size, seq_len).
For step-wise training, minibatch_size can be variable.
loss_mask: Mask of shape (minibatch_size, seq_len) indicating valid loss tokens.
loss_reduction: One of "token_mean", "token_mean_legacy", "sequence_mean", "seq_mean_token_sum_norm".
micro_batch_size: Number of sequences per micro-batch
@@ -1024,6 +1025,8 @@

# Option 1b: legacy token-mean that normalizes per-microbatch then averages across microbatches.
elif loss_reduction == "token_mean_legacy":
# FIXME(Charlie): For step-wise training, mini_batch_size can be variable, so the number of
# microbatches depends on how we pad.
num_micro_batches = batch_size // micro_batch_size
for i in range(num_micro_batches):
start_idx = i * micro_batch_size
24 changes: 14 additions & 10 deletions skyrl/backends/skyrl_train/workers/worker_dispatch.py
Expand Up @@ -9,7 +9,7 @@
"""

from dataclasses import dataclass
from typing import Any, Dict, List, Optional
from typing import Any, Dict, List, Optional, Tuple

import ray
from ray import ObjectRef
@@ -162,24 +162,28 @@ def forward(self, model: str, data: TrainingInputBatch) -> TrainingOutputBatch:
output = concatenate_outputs_after_mesh_dispatch(self._actor_groups[model].actor_infos, results)
return output

def stage_data(self, model: str, data: TrainingInputBatch, mini_batch_size: int) -> List[List[ObjectRef]]:
"""
Pre-stage all mini-batch chunks in the Ray object store.
def stage_data(
self,
model: str,
data: TrainingInputBatch,
mini_batch_boundaries: List[Tuple[int, int]],
) -> List[List[ObjectRef]]:
"""Pre-stage mini-batch chunks in the Ray object store.

Call this once before the training loop so that all serialization is
done upfront and GPUs stay saturated during training.

Args:
model: Model name (used to look up DP size)
data: Full training batch
mini_batch_size: Size of each mini-batch (before DP chunking)
model: Model name (used to look up DP size).
data: Full training batch.
mini_batch_boundaries: List of ``(start, end)`` index pairs.
The i-th mini-batch is data[mini_batch_boundaries[i][0]:mini_batch_boundaries[i][1]].

Returns:
``result[i][dp_rank]`` is the ObjectRef for mini-batch *i*,
DP rank *dp_rank*.
``result[i][dp_rank]`` - ObjectRef for mini-batch *i*, DP rank *dp_rank*.
"""
dp_size = self._actor_groups[model].actor_infos[0].rank.dp_size
return MeshDispatch.stage_chunks(dp_size, data, mini_batch_size)
return MeshDispatch.stage_chunks(dp_size, data, mini_batch_boundaries)

def forward_backward(
self,
75 changes: 75 additions & 0 deletions skyrl/train/dataset/preprocess.py
@@ -188,3 +188,78 @@ def convert_prompts_responses_to_batch_tensors(
logprobs_tensor,
rollout_expert_indices_tensor,
)


def compute_prompt_mini_batch_boundaries(
uids: List[str],
mini_batch_size: int,
train_batch_size: int,
is_stepwise: bool,
n_samples_per_prompt: int,
) -> List[Tuple[int, int]]:
"""Compute mini-batch ``(start, end)`` slices from a flat ``uids`` list.

Args:
uids: List of uids, representing which prompt each sequence belongs to.
mini_batch_size: Number of prompts to include in each mini-batch. Same as training config's
config.trainer.policy_mini_batch_size or config.trainer.critic_mini_batch_size.
train_batch_size: Number of prompts in a training batch. For sanity check.
is_stepwise: Whether the training is step-wise. For sanity check.
n_samples_per_prompt: how many samples per prompt. For sanity check.

Consecutive equal entries in ``uids`` belong to the same prompt. Each mini-batch spans exactly
``mini_batch_size`` prompts (in step-wise training, the last mini-batch may be smaller if the
total prompt count is not divisible). Works for both step-wise (variable sequences per prompt)
and non-step-wise (fixed ``n_samples_per_prompt`` sequences per prompt) training.

We assume uids are contiguous, i.e. all n_samples_per_prompt trajectories for a prompt, or all
per-step sequences for a trajectory, are contiguous.

Example A: normal non-step-wise training, with n_samples_per_prompt=2 and train_batch_size=4.
uids = ["p0", "p0", "p1", "p1", "p2", "p2", "p3", "p3"]
mini_batch_size = 2
prompt_end_indices = [2, 4, 6, 8]
boundaries = [(0, 4), (4, 8)] # because each mini batch spans exactly 2 prompts, hence 4 sequences

Example B: step-wise training with n_samples_per_prompt = 2, and each trajectory can have 1-2 turns.
uids = ["p0", "p0", "p0", "p0", "p1", "p1", "p2", "p2", "p2", "p3", "p3"]
mini_batch_size = 2
prompt_end_indices = [4, 6, 9, 11]
boundaries = [(0, 6), (6, 11)]
"""
# First compute the end indices of each prompt.
prompt_end_indices: List[int] = []
seen_uids: set[str] = set()
seen_uids.add(uids[0])
for i in range(1, len(uids)):
if uids[i] != uids[i - 1]:
assert uids[i] not in seen_uids, f"uid {uids[i]!r} appears in non-contiguous positions at index {i}. Full uids: {uids}"
seen_uids.add(uids[i])
prompt_end_indices.append(i)
prompt_end_indices.append(len(uids))

# The number of distinct uids should equal the number of prompts, which should equal `train_batch_size`.
num_prompts = len(prompt_end_indices)
assert num_prompts == train_batch_size and len(seen_uids) == train_batch_size
assert train_batch_size % mini_batch_size == 0

# Compute boundaries.
boundaries: List[Tuple[int, int]] = []
start_seq = 0
for i in range(0, num_prompts, mini_batch_size):
end_prompt_idx = i + mini_batch_size - 1 # i + mini_batch_size is next mini-batch's first prompt's end index
end_seq = prompt_end_indices[end_prompt_idx]
boundaries.append((start_seq, end_seq))
start_seq = end_seq

# Assert that the mini-batch boundaries are uniform for non-step-wise training.
if not is_stepwise:
expected_num_seq_in_mini_batch = n_samples_per_prompt * mini_batch_size
assert len(boundaries) == train_batch_size // mini_batch_size
for i, (start, end) in enumerate(boundaries):
assert start == i * expected_num_seq_in_mini_batch
assert end - start == expected_num_seq_in_mini_batch
else:
assert len(boundaries) >= train_batch_size // mini_batch_size

return boundaries
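The boundary computation added above can be condensed into a standalone sketch and checked against Examples A and B from the docstring (`prompt_boundaries` is an illustrative re-implementation, not the SkyRL function; it assumes the prompt count is divisible by `mini_batch_size`):

```python
from typing import List, Tuple

def prompt_boundaries(uids: List[str], mini_batch_size: int) -> List[Tuple[int, int]]:
    # End index (exclusive) of each contiguous run of equal uids, i.e. of each prompt.
    ends = [i for i in range(1, len(uids)) if uids[i] != uids[i - 1]] + [len(uids)]
    # The end of every mini_batch_size-th prompt closes a mini-batch
    # (assumes len(ends) is divisible by mini_batch_size).
    cuts = [0] + [ends[i + mini_batch_size - 1] for i in range(0, len(ends), mini_batch_size)]
    return list(zip(cuts[:-1], cuts[1:]))

# Example A: non-step-wise, n_samples_per_prompt=2, fixed 2 sequences per prompt.
uids_a = ["p0", "p0", "p1", "p1", "p2", "p2", "p3", "p3"]
assert prompt_boundaries(uids_a, mini_batch_size=2) == [(0, 4), (4, 8)]

# Example B: step-wise, variable sequences per prompt.
uids_b = ["p0"] * 4 + ["p1"] * 2 + ["p2"] * 3 + ["p3"] * 2
assert prompt_boundaries(uids_b, mini_batch_size=2) == [(0, 6), (6, 11)]
```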
58 changes: 40 additions & 18 deletions skyrl/train/trainer.py
@@ -48,6 +48,7 @@
from skyrl.train.config import SkyRLTrainConfig
from skyrl.train.dataset import PromptDataset
from skyrl.train.dataset.preprocess import (
compute_prompt_mini_batch_boundaries,
convert_prompts_responses_to_batch_tensors,
)
from skyrl.train.evaluate import evaluate, evaluate_step_wise
@@ -605,7 +606,17 @@ def sync_policy_weights_to_inference_engines(self) -> List[ObjectRef]:
)

def convert_to_training_input(self, generator_output: GeneratorOutput, uids: List[str]) -> TrainingInputBatch:
"""Converts lists to a padded batch of tensors for training"""
"""Converts lists to a padded batch of tensors for training

Args:
generator_output (GeneratorOutput): Generated rollouts and associated data.
uids (List[str]): List of prompt-unique identifiers for each generator output in the same
order as `generator_output`. Used to identify which prompt each generated rollout belongs to.
Returns:
training_input (TrainingInputBatch): Padded batch of tensors for training. It preserves the
order of `generator_output` and hence `uids`.
"""
# 1. Extract generator output fields.
prompt_ids: List[List[int]] = generator_output["prompt_token_ids"]
response_ids: List[List[int]] = generator_output["response_ids"]
rewards: List[List[float]] = generator_output["rewards"]
@@ -628,6 +639,8 @@ def convert_to_training_input(self, generator_output: GeneratorOutput, uids: Lis
pixel_values = TensorList(pixel_values)
image_grid_thw = TensorList(image_grid_thw)
# 2. Convert to tensors.
(
sequences_tensor,
attention_masks_tensor,
@@ -657,6 +670,7 @@ def convert_to_training_input(self, generator_output: GeneratorOutput, uids: Lis
), "expected non-null rollout logprobs tensor when off_policy_correction is enabled"
assert rollout_logprobs_tensor.shape == loss_masks_tensor.shape, "Logprobs should look like responses"

# 3. Create training input batch.
training_input = TrainingInputBatch(
{
"sequences": sequences_tensor, # Full trajectories (padded and concatenated prompts and responses)
Expand All @@ -674,7 +688,20 @@ def convert_to_training_input(self, generator_output: GeneratorOutput, uids: Lis
if generator_output.get("is_last_step", None) is not None:
training_input.metadata["is_last_step"] = generator_output["is_last_step"]

# padded response length
# 4. Compute mini-batch boundaries for train_critic_and_policy(). It excludes the ones
# we will add in pad_training_input_batch().
train_batch_size = self.cfg.trainer.train_batch_size
n_samples_per_prompt = self.cfg.generator.n_samples_per_prompt
is_stepwise = self.cfg.generator.step_wise_trajectories
training_input.metadata["policy_mini_batch_boundaries"] = compute_prompt_mini_batch_boundaries(
uids, self.cfg.trainer.policy_mini_batch_size, train_batch_size, is_stepwise, n_samples_per_prompt
)
if self.cfg.trainer.critic.model.path is not None:
training_input.metadata["critic_mini_batch_boundaries"] = compute_prompt_mini_batch_boundaries(
uids, self.cfg.trainer.critic_mini_batch_size, train_batch_size, is_stepwise, n_samples_per_prompt
)

# 5. Record metadata and metrics.
training_input.metadata["response_length"] = response_masks_tensor.shape[1]
batch_num_seq, batch_padded_seq_len = sequences_tensor.shape
logger.info(f"batch_num_seq: {batch_num_seq}, batch_padded_seq_len: {batch_padded_seq_len}")
@@ -684,14 +711,11 @@ def convert_to_training_input(self, generator_output: GeneratorOutput, uids: Lis
"generate/batch_padded_seq_len": batch_padded_seq_len,
}
)
if self.cfg.generator.step_wise_trajectories:
assert (
"trajectory_ids" in generator_output
), "Expected `trajectory_ids` in generator output for step wise training"
training_input.metadata["avg_response_length"] = sum(
len(sample_response_ids) for sample_response_ids in response_ids
) / len(response_ids)

# 6. Pad the batch, only needed for step-wise training's `fwd_logprobs_values_reward()`.
logger.info(f"Number of sequences before padding: {len(training_input['sequences'])}")
dp_size = self.dispatch.get_lcm_dp_size()
pad_size = math.ceil(training_input.batch_size / dp_size) * dp_size - training_input.batch_size
@@ -1052,7 +1076,11 @@ def apply_reward_kl_penalty(
return data

@torch.no_grad()
def _normalize_advantages(self, data: TrainingInputBatch, mini_batch_size: int) -> TrainingInputBatch:
def _normalize_advantages(
self,
data: TrainingInputBatch,
mini_batch_boundaries: List[Tuple[int, int]],
) -> TrainingInputBatch:
advantages = data["advantages"]
response_mask = data["response_mask"]

@@ -1065,11 +1093,8 @@ def _normalize_advantages(self, data: TrainingInputBatch, mini_batch_size: int)
data["advantages"] = (advantages - mean) * rstd

# Step 2: Loss reduction normalization per mini-batch
num_mini_batches = len(data) // mini_batch_size
normalized_advantages = torch.zeros_like(advantages)
for local_step in range(num_mini_batches):
start_idx = local_step * mini_batch_size
end_idx = (local_step + 1) * mini_batch_size
for start_idx, end_idx in mini_batch_boundaries:
mini_batch = data[start_idx:end_idx]
normalized_advantages[start_idx:end_idx] = apply_loss_reduction_to_advantages_minibatch(
advantages=mini_batch["advantages"],
@@ -1099,20 +1124,17 @@ def _execute_training_step(self, model: str, data: TrainingInputBatch) -> Dict[s
Returns:
Dict of reduced metrics from training
"""
# Compute mini batch size from config (algorithm-level concept)
n_samples = self.cfg.generator.n_samples_per_prompt
boundaries = data.metadata[f"{model}_mini_batch_boundaries"]

if model == "policy":
mini_batch_size = self.cfg.trainer.policy_mini_batch_size * n_samples
# Normalize advantages for policy training; critic training does not need this
data = self._normalize_advantages(data, mini_batch_size)
else:
mini_batch_size = self.cfg.trainer.critic_mini_batch_size * n_samples
data = self._normalize_advantages(data, boundaries)

all_metrics: Dict[str, List[float]] = defaultdict(list)

# Pre-stage all per-DP mini-batch chunks in the object store so that
# serialization is fully off the critical path during training.
all_chunk_refs = self.dispatch.stage_data(model, data, mini_batch_size)
all_chunk_refs = self.dispatch.stage_data(model, data, boundaries)

# Training loop over epochs and mini-batches
for _epoch in range(self.cfg.trainer.update_epochs_per_batch):
9 changes: 9 additions & 0 deletions skyrl/train/utils/utils.py
@@ -97,6 +97,9 @@ def validate_batch_sizes(cfg: SkyRLTrainConfig):
f"train_batch_size {cfg.trainer.train_batch_size} should be divisible by "
f"policy_mini_batch_size {cfg.trainer.policy_mini_batch_size}"
)

# TODO(Charlie): For step-wise training, the number of sequences per prompt is variable, and
# padded mini-batch may not be divisible by dp_size. Should check if we need these assertions.
policy_mini_batch_size_per_gpu = (
cfg.trainer.policy_mini_batch_size * cfg.generator.n_samples_per_prompt // policy_dp_size
)
@@ -294,6 +297,12 @@ def validate_cfg(cfg: SkyRLTrainConfig):
f"connection. Use an outcome-based estimator (grpo, rloo, maxrl) or disable "
f"step_wise_trajectories."
)
if cfg.generator.step_wise_trajectories and cfg.trainer.algorithm.loss_reduction == "token_mean_legacy":
# TODO(Charlie): this can be fixed, can revisit later.
raise ValueError(
"`token_mean_legacy` loss reduction is not supported with step-wise training. "
"Use `token_mean` instead."
)

assert cfg.trainer.algorithm.loss_reduction in (
"token_mean",