Skip to content

[tinker] Support PPO loss with Tinker and add critic model in SkyRLTrainBackend#1389

Open
tamoghnokandar wants to merge 6 commits intoNovaSky-AI:mainfrom
tamoghnokandar:add_logprobs
Open

[tinker] Support PPO loss with Tinker and add critic model in SkyRLTrainBackend#1389
tamoghnokandar wants to merge 6 commits intoNovaSky-AI:mainfrom
tamoghnokandar:add_logprobs

Conversation

@tamoghnokandar
Copy link
Copy Markdown
Contributor

@tamoghnokandar tamoghnokandar commented Mar 25, 2026

Fixes the first issue of #1380.

Summary

  • Add critic model support to the Tinker backend, enabling actor-critic PPO training through the Tinker API
  • Extend LossFnInputs with optional values and returns fields, and add ppo_critic loss type
  • Refactor SkyrlTrainBackend from single-model to multi-model registry (model_id → role), supporting both policy
    and critic actor groups
  • Critic models are created via create_model with model_role="critic", sharing the policy's base model with
    independent training
  • Add register_actor_group() and set_algorithm_config() to WorkerDispatch for dynamic critic registration

Test plan

  • Unit tests for new LossFnInputs fields (values, returns) and Datum.to_types() conversion
  • Unit tests for prepare_model_pass_batch with ppo_critic loss type
  • Unit tests verifying JAX backend rejects critic role and ppo_critic loss
  • Verify existing policy-only workflows remain unaffected (values/returns default to empty)

Open with Devin

Copy link
Copy Markdown
Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request significantly refactors the SkyRL-Train backend to support both policy and critic models, primarily for PPO. It introduces a model_role concept, updating create_model methods, data structures (_model_ids, _model_metadata), and core functionalities like forward_backward, forward, optim_step, and checkpointing to handle distinct roles. The ppo_critic loss function is added, along with corresponding input fields (values, returns) in API types and batch preparation. The Jax backend is updated to enforce that it only supports 'policy' models and raises errors for 'ppo_critic' loss. Review comments suggest extracting duplicated validation logic for model_role and loss_fn into a helper method and refactoring the loss_fn_outputs construction for better efficiency and clarity.

I am having trouble creating individual review comments. Click here to see my feedback.

skyrl/backends/skyrl_train_backend.py (470-473)

medium

This validation logic for role and loss_fn is duplicated in the forward method (lines 549-552). To improve maintainability and avoid code duplication, consider extracting this logic into a private helper method. For example:

def _validate_batch_role_and_loss(self, role: str, loss_fn: str):
    if role == "critic" and loss_fn != "ppo_critic":
        raise ValueError(f"Critic batches must use loss_fn='ppo_critic', got {loss_fn!r}")
    if role != "critic" and loss_fn == "ppo_critic":
        raise ValueError("loss_fn='ppo_critic' is only valid for critic models")

You could then call self._validate_batch_role_and_loss(role, loss_fn) in both forward_backward and forward methods.

skyrl/backends/skyrl_train_backend.py (517-531)

medium

The current implementation for constructing loss_fn_outputs is a bit inefficient and could be clearer. It initializes loss_fn_outputs on line 517, and then potentially re-initializes it as an empty list on line 519 if "loss_fn_outputs" is in data.

This can be refactored to be more direct and avoid the unnecessary list creation.

            if "loss_fn_outputs" in data:
                loss_fn_outputs = []
                for i in range(start_idx, end_idx):
                    raw_output = data["loss_fn_outputs"][i]
                    formatted_output = {}
                    for key in ("elementwise_loss", "logprobs", "values"):
                        values = list(raw_output.get(key, []))
                        if values or key in raw_output:
                            formatted_output[key] = {
                                "data": values,
                                "dtype": "float32",
                                "shape": [len(values)],
                            }
                    loss_fn_outputs.append(formatted_output)
            else:
                loss_fn_outputs = [{} for _ in range(end_idx - start_idx)]

devin-ai-integration[bot]

This comment was marked as resolved.

devin-ai-integration[bot]

This comment was marked as resolved.

@SumanthRH SumanthRH self-assigned this Apr 3, 2026
ray.get(critic_model.async_run_ray_method("pass_through", "_set_pad_token_id", self._tokenizer.pad_token_id))
if colocate_all:
critic_model.offload_to_cpu()
self._dispatch.mark_all_offloaded()
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

instead of a mark_all_offloaded call here we should just have a mark_as_offloaded method that marks a specific model_id as offloaded.

self.config = config
self._model_id: str | None = None
self._model_metadata: types.ModelMetadata | None = None
self._model_ids: dict[str, str] = {}
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This should be renamed to _model_ids_to_role

Comment on lines +138 to +141
if len(roles) != 1:
raise ValueError(f"Mixed model roles in one batch are not supported: {sorted(roles)}")
if len(set(model_ids)) != 1:
raise ValueError(f"Mixed model_ids in one batch are not supported: {sorted(set(model_ids))}")
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Aren't these checks redundant ? If len(set(model_ids) == 1, then len(roles) will also be 1.

start = max(output_logprobs.shape[1] - valid_len, 0)
logprobs = output_logprobs[i, start:].tolist()
start = max(output_tensor.shape[1] - valid_len, 0)
values = output_tensor[i, start:].tolist()
Copy link
Copy Markdown
Member

@SumanthRH SumanthRH Apr 4, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit: This should just be outputs ? can be values or logprobs

Comment on lines +81 to +90
def ppo_critic_loss(
_target_logprobs: jax.Array,
_loss_mask: jax.Array,
_sampling_logprobs: jax.Array,
_advantages: jax.Array,
_loss_fn_config: LossFnConfig,
) -> jax.Array:
return jnp.zeros_like(_loss_mask)


Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is not implemented?

Copy link
Copy Markdown
Contributor Author

@tamoghnokandar tamoghnokandar Apr 10, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, this is not implemented. Should I add critic model support for JAX backend too?

devin-ai-integration[bot]

This comment was marked as resolved.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants