[tinker] Support PPO loss with Tinker and add critic model in SkyRLTrainBackend#1389
tamoghnokandar wants to merge 6 commits into NovaSky-AI:main from
Conversation
Code Review
This pull request significantly refactors the SkyRL-Train backend to support both policy and critic models, primarily for PPO. It introduces a model_role concept, updating create_model methods, data structures (_model_ids, _model_metadata), and core functionalities like forward_backward, forward, optim_step, and checkpointing to handle distinct roles. The ppo_critic loss function is added, along with corresponding input fields (values, returns) in API types and batch preparation. The Jax backend is updated to enforce that it only supports 'policy' models and raises errors for 'ppo_critic' loss. Review comments suggest extracting duplicated validation logic for model_role and loss_fn into a helper method and refactoring the loss_fn_outputs construction for better efficiency and clarity.
I am having trouble creating individual review comments, so my feedback is inlined below.
skyrl/backends/skyrl_train_backend.py (470-473)
This validation logic for role and loss_fn is duplicated in the forward method (lines 549-552). To improve maintainability and avoid code duplication, consider extracting this logic into a private helper method. For example:
```python
def _validate_batch_role_and_loss(self, role: str, loss_fn: str):
    if role == "critic" and loss_fn != "ppo_critic":
        raise ValueError(f"Critic batches must use loss_fn='ppo_critic', got {loss_fn!r}")
    if role != "critic" and loss_fn == "ppo_critic":
        raise ValueError("loss_fn='ppo_critic' is only valid for critic models")
```
You could then call `self._validate_batch_role_and_loss(role, loss_fn)` in both the `forward_backward` and `forward` methods.
skyrl/backends/skyrl_train_backend.py (517-531)
The current implementation for constructing loss_fn_outputs is a bit inefficient and could be clearer. It initializes loss_fn_outputs on line 517, and then potentially re-initializes it as an empty list on line 519 if "loss_fn_outputs" is in data.
This can be refactored to be more direct and avoid the unnecessary list creation.
```python
if "loss_fn_outputs" in data:
    loss_fn_outputs = []
    for i in range(start_idx, end_idx):
        raw_output = data["loss_fn_outputs"][i]
        formatted_output = {}
        for key in ("elementwise_loss", "logprobs", "values"):
            values = list(raw_output.get(key, []))
            if values or key in raw_output:
                formatted_output[key] = {
                    "data": values,
                    "dtype": "float32",
                    "shape": [len(values)],
                }
        loss_fn_outputs.append(formatted_output)
else:
    loss_fn_outputs = [{} for _ in range(end_idx - start_idx)]
```
```python
ray.get(critic_model.async_run_ray_method("pass_through", "_set_pad_token_id", self._tokenizer.pad_token_id))
if colocate_all:
    critic_model.offload_to_cpu()
    self._dispatch.mark_all_offloaded()
```
Instead of a `mark_all_offloaded` call here, we should just have a `mark_as_offloaded` method that marks a specific model_id as offloaded.
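A minimal sketch of the suggested change, assuming the dispatch tracks offload state in a plain dict keyed by model_id (the class and method names below are illustrative, not the actual `WorkerDispatch` API):

```python
# Illustrative sketch only: per-model offload tracking instead of a
# blanket mark_all_offloaded(). Names are hypothetical stand-ins for
# the real WorkerDispatch implementation.
class WorkerDispatchSketch:
    def __init__(self) -> None:
        self._offloaded: dict[str, bool] = {}

    def register(self, model_id: str) -> None:
        # Newly registered models start on-device
        self._offloaded[model_id] = False

    def mark_as_offloaded(self, model_id: str) -> None:
        # Mark only the given model as offloaded to CPU
        if model_id not in self._offloaded:
            raise KeyError(f"Unknown model_id: {model_id!r}")
        self._offloaded[model_id] = True

    def is_offloaded(self, model_id: str) -> bool:
        return self._offloaded[model_id]


dispatch = WorkerDispatchSketch()
dispatch.register("policy-0")
dispatch.register("critic-0")
dispatch.mark_as_offloaded("critic-0")  # only the critic is marked
```

With this shape, offloading the critic after creation no longer clobbers the policy's offload state.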
```python
self.config = config
self._model_id: str | None = None
self._model_metadata: types.ModelMetadata | None = None
self._model_ids: dict[str, str] = {}
```
This should be renamed to `_model_ids_to_role`.
```python
if len(roles) != 1:
    raise ValueError(f"Mixed model roles in one batch are not supported: {sorted(roles)}")
if len(set(model_ids)) != 1:
    raise ValueError(f"Mixed model_ids in one batch are not supported: {sorted(set(model_ids))}")
```
Aren't these checks redundant? If `len(set(model_ids)) == 1`, then `len(roles)` will also be 1.
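A small illustration of the redundancy claim, assuming each datum's role is looked up from a model_id → role mapping (the dict name below is hypothetical): a single distinct model_id necessarily yields a single distinct role.

```python
# Hypothetical model_id -> role mapping, as kept by the backend
model_ids_to_role = {"model-a": "policy", "model-b": "critic"}

# A batch whose datums all reference the same model_id
batch_model_ids = ["model-a", "model-a", "model-a"]
roles = [model_ids_to_role[mid] for mid in batch_model_ids]

# If the model_id check passes, the role check can never fail,
# because equal keys map to equal values
assert len(set(batch_model_ids)) == 1
assert len(set(roles)) == 1
```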
```python
start = max(output_logprobs.shape[1] - valid_len, 0)
logprobs = output_logprobs[i, start:].tolist()
start = max(output_tensor.shape[1] - valid_len, 0)
values = output_tensor[i, start:].tolist()
```
Nit: this should just be `outputs`? It can be either values or logprobs.
```python
def ppo_critic_loss(
    _target_logprobs: jax.Array,
    _loss_mask: jax.Array,
    _sampling_logprobs: jax.Array,
    _advantages: jax.Array,
    _loss_fn_config: LossFnConfig,
) -> jax.Array:
    return jnp.zeros_like(_loss_mask)
```
Yes, this is not implemented. Should I add critic model support for the JAX backend too?
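For reference, the objective this stub would eventually need is commonly PPO's clipped value loss. A hedged sketch in plain Python follows; the argument names (`values`, `old_values`, `returns`, `clip_eps`) are illustrative, and a real implementation would operate on `jax.Array` inputs and respect the loss mask:

```python
def ppo_critic_loss_sketch(values, old_values, returns, clip_eps=0.2):
    """Clipped PPO value loss, sketched with plain Python floats."""
    losses = []
    for v, v_old, ret in zip(values, old_values, returns):
        # Clip the new value prediction to stay within clip_eps of the old one
        v_clipped = v_old + max(min(v - v_old, clip_eps), -clip_eps)
        # Pessimistic max over clipped and unclipped squared errors
        losses.append(max((v - ret) ** 2, (v_clipped - ret) ** 2))
    return sum(losses) / len(losses)


# When the prediction moves beyond clip_eps of the old value,
# the clipped branch dominates the loss
assert ppo_critic_loss_sketch([1.0], [0.5], [2.0], clip_eps=0.25) == 1.5625
```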
Fixes the first issue of #1380.
Summary

- Extend `LossFnInputs` with optional `values` and `returns` fields, and add the `ppo_critic` loss type
- Refactor `SkyrlTrainBackend` from a single-model to a multi-model registry (model_id → role), supporting both policy and critic actor groups
- Support `create_model` with `model_role="critic"`, sharing the policy's base model with independent training
- Add `register_actor_group()` and `set_algorithm_config()` to `WorkerDispatch` for dynamic critic registration

Test plan

- Test the new `LossFnInputs` fields (`values`, `returns`) and `Datum.to_types()` conversion
- Test `prepare_model_pass_batch` with the `ppo_critic` loss type
- Test the `ppo_critic` loss