[tinker] Support prompt_logprobs in SkyRLTrainBackend sample() path#1461

Open
pbokc wants to merge 7 commits into NovaSky-AI:main from pbokc:tinker_skyrl_support_prompt_logprobs

Conversation

@pbokc
Contributor

@pbokc pbokc commented Apr 6, 2026

Summary

Support prompt_logprobs in the SkyRLTrainBackend sample() path

Details:

  • When a sample request includes prompt_logprobs=True, the backend returns per-prompt-token log-probabilities under the current model
  • Threads prompt_logprobs through the full inference engine stack: InferenceEngineInterface → RayWrappedInferenceEngine → InferenceEngineClient → SkyRLTrainBackend
  • Extracts prompt logprobs from vLLM's RequestOutput in _postprocess_outputs(), skipping position 0 (no prior context) to match JAX backend's prompt_len - 1 length
  • Includes examples/tinker/tis_example.py: a full RL training loop on GSM8K demonstrating truncated importance sampling with real prompt logprobs (based on the tinker-cookbook rl_loop recipe)
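The position-0 skip described above can be sketched as follows. This is an illustrative helper, not the PR's actual `_postprocess_outputs()` code, and it uses plain floats where vLLM returns `Logprob` objects:

```python
def extract_prompt_logprobs(prompt_token_ids, prompt_logprobs):
    """Return one logprob per prompt token, skipping position 0.

    vLLM reports entry 0 of RequestOutput.prompt_logprobs as None (the
    first token has no prior context), so the result has length
    prompt_len - 1, matching the JAX backend.
    """
    if prompt_logprobs is None:
        return None
    result = []
    for token_id, candidates in zip(prompt_token_ids[1:], prompt_logprobs[1:]):
        # `candidates` maps token ids to logprobs at this position; keep
        # the logprob of the token that actually appeared in the prompt.
        result.append(candidates[token_id])
    return result

# Toy example: a 3-token prompt yields 2 logprobs.
prompt = [101, 42, 7]
per_position = [None, {42: -0.5, 9: -2.0}, {7: -1.25}]
print(extract_prompt_logprobs(prompt, per_position))  # → [-0.5, -1.25]
```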

Testing

Ran rl_loop.py from tinker-cookbook and confirmed prompt_logprobs were coming through.
wandb run: https://wandb.ai/pranb01-pranavgb/skyrl-prompt-logprobs

🤖 Generated with Claude Code


@pbokc pbokc marked this pull request as draft April 6, 2026 01:26

@pbokc
Contributor Author

pbokc commented Apr 6, 2026

Addresses the "Support logprobs with the sample() API." item in #1380

@pbokc pbokc marked this pull request as ready for review April 8, 2026 03:05

Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Code Review

This pull request implements support for importance sampling and KL-regularized training within the Tinker API integration. Key changes include the addition of a new example script for GRPO-style RL on GSM8K, the implementation of prompt_logprobs extraction across various inference engine backends (including vLLM), and the integration of reference model support for KL loss calculation in the training backend. I have no feedback to provide.

@pbokc pbokc marked this pull request as draft April 9, 2026 01:52
pbokc and others added 4 commits April 11, 2026 23:50
Add prompt_logprobs support to the SkyRLTrainBackend, achieving parity with
the JAX backend. When a sample request includes prompt_logprobs=True, the
backend now returns per-prompt-token log-probabilities under the current model.

Changes:
- InferenceEngineOutput: add prompt_logprobs field
- InferenceEngineInterface.sample(): accept prompt_logprobs param, inject into
  vLLM SamplingParams
- Thread prompt_logprobs through RayWrappedInferenceEngine, InferenceEngineClient,
  and SkyRLTrainBackend
- vllm_engine._postprocess_outputs(): extract prompt logprobs from vLLM
  RequestOutput, skipping position 0 (no prior context) to match JAX backend
  length of prompt_len - 1
- RemoteInferenceEngine: returns prompt_logprobs=None (unsupported over HTTP)
- Add TODO for training/rollout generate() path

Includes examples/tinker/tis_example.py demonstrating importance sampling
with real prompt logprobs on GSM8K.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Co-authored-by: devin-ai-integration[bot] <158243242+devin-ai-integration[bot]@users.noreply.github.com>
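The truncated importance sampling demonstrated in tis_example.py can be sketched per token. The function name and cap value below are illustrative assumptions, not taken from the example script:

```python
import math

def tis_weights(trainer_logprobs, sampler_logprobs, cap=2.0):
    """Per-token TIS weights: min(exp(trainer_lp - sampler_lp), cap).

    Truncating the ratio from above bounds the variance introduced when
    the trainer and sampler policies disagree on a token.
    """
    return [min(math.exp(t - s), cap)
            for t, s in zip(trainer_logprobs, sampler_logprobs)]

# The second token's raw ratio exp(2.5) ≈ 12.2 is truncated to the cap.
print(tis_weights([-1.0, -0.5], [-1.0, -3.0], cap=2.0))  # → [1.0, 2.0]
```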
@pbokc pbokc force-pushed the tinker_skyrl_support_prompt_logprobs branch from 2c4e23d to 54a7893 on April 12, 2026 03:54
@pbokc pbokc marked this pull request as ready for review April 12, 2026 03:56
@pbokc
Contributor Author

pbokc commented Apr 12, 2026

/gemini review

Contributor

@devin-ai-integration devin-ai-integration bot left a comment

Devin Review found 1 new potential issue.

⚠️ 1 issue in files not directly in the diff

⚠️ Incomplete transformation: prompt_logprobs field missing from RemoteInferenceClient.generate() output (skyrl/backends/skyrl_train/inference_servers/remote_inference_client.py:370-376)

The PR adds prompt_logprobs to the InferenceEngineOutput TypedDict (base.py:34) and updates four of the five construction sites to include prompt_logprobs=None, but misses the one in remote_inference_client.py:370-376. This generate() method is called in the _SKYRL_USE_NEW_INFERENCE=True pathway (e.g., from skyrl/train/generators/skyrl_gym_generator.py:338 and :671). While all current consumers use .get() (so no KeyError at runtime), the dict is structurally incomplete relative to its TypedDict contract, and any future code that accesses output["prompt_logprobs"] directly would crash.
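The hazard this finding describes can be reproduced with a minimal TypedDict; the names below are simplified stand-ins for the PR's actual types:

```python
from typing import List, Optional, TypedDict

class InferenceEngineOutput(TypedDict):
    responses: List[str]
    prompt_logprobs: Optional[List[float]]

# Updated construction sites include the new key explicitly:
complete: InferenceEngineOutput = {"responses": ["ok"], "prompt_logprobs": None}

# The missed site omits it. A static type checker flags this, but at
# runtime the structurally incomplete dict is accepted silently:
incomplete = {"responses": ["ok"]}

print(incomplete.get("prompt_logprobs"))  # .get() consumers safely see None
try:
    incomplete["prompt_logprobs"]  # direct indexing crashes
except KeyError:
    print("KeyError on direct access")
```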

View 8 additional findings in Devin Review.

pbokc and others added 2 commits April 14, 2026 21:12
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>