[tinker] Support prompt_logprobs in SkyRLTrainBackend sample() path #1461
pbokc wants to merge 7 commits into NovaSky-AI:main
Conversation
Addresses the "Support logprobs with the
Code Review
This pull request implements support for importance sampling and KL-regularized training within the Tinker API integration. Key changes include the addition of a new example script for GRPO-style RL on GSM8K, the implementation of prompt_logprobs extraction across various inference engine backends (including vLLM), and the integration of reference model support for KL loss calculation in the training backend. I have no feedback to provide.
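The review above mentions reference-model support for KL loss calculation. The idea is to penalize the policy for drifting from a frozen reference model, token by token, using the two models' logprobs. A minimal sketch, assuming the simple k1 estimator log(pi/pi_ref); the estimator choice, function name, and `beta` coefficient here are illustrative assumptions, not SkyRL's actual implementation:

```python
def kl_penalty(policy_logprobs, ref_logprobs, beta=0.1):
    """Per-token KL penalty beta * (log pi - log pi_ref).

    This is the simple k1 estimator; the real backend may use a
    lower-variance estimator, and beta is an arbitrary coefficient.
    """
    return [beta * (lp - rlp) for lp, rlp in zip(policy_logprobs, ref_logprobs)]

# A token the policy now assigns more probability than the reference
# did gets a positive penalty; matching logprobs give zero:
print(kl_penalty([-0.5, -1.0], [-0.5, -3.0], beta=0.1))  # [0.0, 0.2]
```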
Add prompt_logprobs support to the SkyRLTrainBackend, achieving parity with the JAX backend. When a sample request includes prompt_logprobs=True, the backend now returns per-prompt-token log-probabilities under the current model.

Changes:
- InferenceEngineOutput: add prompt_logprobs field
- InferenceEngineInterface.sample(): accept a prompt_logprobs param, inject it into vLLM SamplingParams
- Thread prompt_logprobs through RayWrappedInferenceEngine, InferenceEngineClient, and SkyRLTrainBackend
- vllm_engine._postprocess_outputs(): extract prompt logprobs from the vLLM RequestOutput, skipping position 0 (no prior context) to match the JAX backend length of prompt_len - 1
- RemoteInferenceEngine: returns prompt_logprobs=None (unsupported over HTTP)
- Add TODO for the training/rollout generate() path

Includes examples/tinker/tis_example.py demonstrating importance sampling with real prompt logprobs on GSM8K.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Co-authored-by: devin-ai-integration[bot] <158243242+devin-ai-integration[bot]@users.noreply.github.com>
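The position-0 skip described in the commit message can be sketched as follows. vLLM reports prompt logprobs as a list aligned with the prompt token ids, with None at position 0 (the first token has no preceding context); the helper name below and the use of plain floats in place of vLLM's Logprob objects are simplifying assumptions, not SkyRL's actual code:

```python
def extract_prompt_logprobs(prompt_token_ids, prompt_logprobs):
    """prompt_logprobs: list aligned with prompt_token_ids; entry i maps
    {token_id: logprob} for position i, with None at position 0."""
    if prompt_logprobs is None:
        return None
    out = []
    # Skip position 0: no prior context, so vLLM reports None there.
    # The result has length prompt_len - 1, matching the JAX backend.
    for tok_id, lp_dict in zip(prompt_token_ids[1:], prompt_logprobs[1:]):
        out.append(lp_dict[tok_id])
    return out

# Tiny worked example with made-up logprobs:
toks = [101, 42, 7]
lps = [None, {42: -0.5, 9: -2.0}, {7: -1.25}]
print(extract_prompt_logprobs(toks, lps))  # [-0.5, -1.25]
```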
Force-pushed from 2c4e23d to 54a7893
/gemini review
Devin Review found 1 new potential issue.
⚠️ 1 issue in files not directly in the diff
⚠️ Incomplete transformation: prompt_logprobs field missing from RemoteInferenceClient.generate() output (skyrl/backends/skyrl_train/inference_servers/remote_inference_client.py:370-376)
The PR adds prompt_logprobs to the InferenceEngineOutput TypedDict (base.py:34) and updates four of the five construction sites to include prompt_logprobs=None, but misses the one in remote_inference_client.py:370-376. This generate() method is called in the _SKYRL_USE_NEW_INFERENCE=True pathway (e.g., from skyrl/train/generators/skyrl_gym_generator.py:338 and :671). While all current consumers use .get() (so no KeyError at runtime), the dict is structurally incomplete relative to its TypedDict contract, and any future code that accesses output["prompt_logprobs"] directly would crash.
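The distinction Devin draws here — `.get()` tolerating the missing key while direct subscripting crashes — follows from TypedDict being a static-only contract. A minimal sketch; `InferenceEngineOutput` below is a stand-in with illustrative fields, not the real class from base.py:

```python
from typing import Optional, TypedDict

class InferenceEngineOutput(TypedDict):
    responses: list
    prompt_logprobs: Optional[list]  # the field this PR adds

# TypedDict is not enforced at runtime, so a construction site that
# omits the new key still produces a plain dict that runs fine:
incomplete: InferenceEngineOutput = {"responses": []}  # type: ignore[typeddict-item]

print(incomplete.get("prompt_logprobs"))  # None -- .get() masks the gap
try:
    incomplete["prompt_logprobs"]         # direct access raises
except KeyError as e:
    print("KeyError:", e)
```

This is why the issue is latent rather than immediately breaking: every current consumer happens to use `.get()`, but the dict still violates its declared TypedDict shape.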
View 8 additional findings in Devin Review.
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Summary
Support prompt_logprobs in the SkyRLTrainBackend sample() path
Details:
- When a sample request includes prompt_logprobs=True, the backend returns per-prompt-token log-probabilities under the current model
- Thread prompt_logprobs through the full inference engine stack: InferenceEngineInterface → RayWrappedInferenceEngine → InferenceEngineClient → SkyRLTrainBackend
- Extract prompt logprobs from the vLLM RequestOutput in _postprocess_outputs(), skipping position 0 (no prior context) to match the JAX backend's prompt_len - 1 length
- examples/tinker/tis_example.py: a full RL training loop on GSM8K demonstrating truncated importance sampling with real prompt logprobs (based on the tinker-cookbook rl_loop recipe)

Testing
Ran rl_loop.py from tinker-cookbook and confirmed prompt_logprobs were coming through.

wandb run: https://wandb.ai/pranb01-pranavgb/skyrl-prompt-logprobs
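The truncated importance sampling that tis_example.py demonstrates boils down to reweighting each token by the ratio of trainer to sampler probability, clipped at a cap to bound variance. A hedged sketch; the cap value and function name are illustrative, not taken from the example script:

```python
import math

def tis_weights(train_logprobs, sampler_logprobs, cap=2.0):
    """Per-token importance weights pi_train / pi_sampler, truncated at
    `cap` to bound the variance introduced by off-policy samples."""
    return [
        min(math.exp(lt - ls), cap)
        for lt, ls in zip(train_logprobs, sampler_logprobs)
    ]

# A token where the trainer assigns much higher probability than the
# sampler gets its ratio clipped:
w = tis_weights([-0.1, -2.0], [-0.1, -4.0], cap=2.0)
print(w)  # [1.0, 2.0]  (exp(2.0) ≈ 7.39, truncated to 2.0)
```

These weights then multiply the per-token advantages in the policy-gradient loss, correcting for the mismatch between the sampling-time and training-time model.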
🤖 Generated with Claude Code