Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions skyrl/backends/skyrl_train/inference_engines/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ class InferenceEngineOutput(TypedDict):
response_ids: List[List[int]]
stop_reasons: List[str]
response_logprobs: Optional[List[List[float]]]
prompt_logprobs: Optional[List[List[float]]] # per-prompt-token logprobs under the current model
Comment thread
pbokc marked this conversation as resolved.
rollout_expert_indices: Optional[List[List[List[int]]]] # [seq_len, layer_num, topk]


Expand All @@ -45,6 +46,7 @@ async def sample(
prompt_token_ids: List[int],
num_samples: int,
sampling_params: Dict[str, Any],
prompt_logprobs: bool = False,
) -> InferenceEngineOutput:
"""Generate multiple independent samples from a single prompt.

Expand All @@ -54,18 +56,24 @@ async def sample(
prompt_token_ids: Token IDs for a single prompt.
num_samples: Number of independent samples to generate.
sampling_params: Sampling parameters.
prompt_logprobs: If True, return per-token logprobs over the prompt.

Returns:
InferenceEngineOutput containing num_samples results:
- response_ids: List of num_samples token ID lists
- responses: List of num_samples decoded strings
- stop_reasons: List of num_samples stop reasons
- response_logprobs: Optional list of num_samples logprob lists
- prompt_logprobs: Optional list of num_samples prompt logprob lists
"""
if prompt_logprobs:
sampling_params = {**sampling_params, "prompt_logprobs": 1}

all_response_ids = []
all_responses = []
all_stop_reasons = []
all_response_logprobs = []
all_prompt_logprobs = []
all_rollout_expert_indices = []

for _ in range(num_samples):
Expand All @@ -83,6 +91,8 @@ async def sample(
all_stop_reasons.append(output["stop_reasons"][0])
if output.get("response_logprobs") is not None:
all_response_logprobs.append(output["response_logprobs"][0])
if output.get("prompt_logprobs") is not None:
all_prompt_logprobs.append(output["prompt_logprobs"][0])
if output.get("rollout_expert_indices") is not None:
all_rollout_expert_indices.append(output["rollout_expert_indices"][0])

Expand All @@ -91,6 +101,7 @@ async def sample(
"responses": all_responses,
"stop_reasons": all_stop_reasons,
"response_logprobs": all_response_logprobs if all_response_logprobs else None,
"prompt_logprobs": all_prompt_logprobs if all_prompt_logprobs else None,
"rollout_expert_indices": all_rollout_expert_indices if all_rollout_expert_indices else None,
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -154,11 +154,14 @@ async def generate(self, input_batch: InferenceEngineInput) -> InferenceEngineOu
add_rollout_expert_indices = True
rollout_expert_indices[original_idx] = result["rollout_expert_indices"][local_idx]

# TODO: Should we support prompt_logprobs in the training/rollout generate() path?
# Currently only the sample() path supports prompt_logprobs.
return InferenceEngineOutput(
responses=responses,
stop_reasons=stop_reasons,
response_ids=response_ids,
response_logprobs=response_logprobs if add_resp_logprobs else None,
prompt_logprobs=None,
rollout_expert_indices=rollout_expert_indices if add_rollout_expert_indices else None,
)

Expand All @@ -183,6 +186,7 @@ async def sample(
num_samples: int,
sampling_params: Dict[str, Any],
session_id: Optional[Union[str, int]] = None,
prompt_logprobs: bool = False,
) -> InferenceEngineOutput:
"""Generate multiple independent samples from a single prompt.

Expand All @@ -196,6 +200,7 @@ async def sample(
session_id: Optional session ID for consistent engine routing (e.g., conversation ID).
If None, uses random load-balancing. Tinker API should pass None since
each sample() call is independent.
prompt_logprobs: If True, return per-token logprobs over the prompt.

Returns:
InferenceEngineOutput containing num_samples results.
Expand All @@ -208,6 +213,7 @@ async def sample(
prompt_token_ids=prompt_token_ids,
num_samples=num_samples,
sampling_params=sampling_params,
prompt_logprobs=prompt_logprobs,
)

async def chat_completion(self, request_payload: Dict[str, Any]) -> Dict[str, Any]:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,11 +45,13 @@ async def sample(
prompt_token_ids: List[int],
num_samples: int,
sampling_params: Dict[str, Any],
prompt_logprobs: bool = False,
) -> InferenceEngineOutput:
return await self.inference_engine_actor.sample.remote(
prompt_token_ids=prompt_token_ids,
num_samples=num_samples,
sampling_params=sampling_params,
prompt_logprobs=prompt_logprobs,
)

async def wake_up(self, *args: Any, **kwargs: Any):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -201,7 +201,11 @@ async def generate(self, input_batch: InferenceEngineInput) -> InferenceEngineOu
raise ValueError(f"Invalid engine backend: {self.engine_backend}")

return InferenceEngineOutput(
responses=outputs, stop_reasons=finish_reasons, response_ids=output_ids, response_logprobs=None
responses=outputs,
stop_reasons=finish_reasons,
response_ids=output_ids,
response_logprobs=None,
prompt_logprobs=None,
Comment thread
pbokc marked this conversation as resolved.
)

async def chat_completion(self, request_payload: Dict[str, Any]) -> Dict[str, Any]:
Expand Down
28 changes: 28 additions & 0 deletions skyrl/backends/skyrl_train/inference_engines/vllm/vllm_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,7 @@ def _postprocess_outputs(self, outputs):
stop_reasons: List[str] = []
response_ids: List[List[int]] = []
response_logprobs: Optional[List[List[float]]] = []
prompt_logprobs_list: List[Optional[List[float]]] = []
rollout_expert_indices: Optional[List[List[List[List[int]]]]] = []

for output in outputs:
Expand All @@ -174,6 +175,29 @@ def _postprocess_outputs(self, outputs):
del token_logprobs
response_logprobs.append(_logprobs)

# Extract per-prompt-token logprobs (from RequestOutput, not CompletionOutput).
# Returns the logprob of each prompt token given prior context, skipping position 0
# (which has no prior context). This matches the JAX backend which computes
# logits_to_logprobs(all_logits[:, :-1, :], input_ids[:, 1:]) → length prompt_len - 1.
_prompt_logprobs = None
if output.prompt_logprobs is not None:
_prompt_logprobs = []
for i, pos_logprobs in enumerate(output.prompt_logprobs):
if pos_logprobs is None:
# First position has no prior context; skip it (matching JAX backend).
# Only the first position can be None.
continue
else:
token_id = output.prompt_token_ids[i]
if token_id not in pos_logprobs:
raise RuntimeError(
f"vLLM prompt_logprobs missing actual token at position {i} "
f"(token_id={token_id}). This violates vLLM's contract that "
f"the actual prompt token is always returned regardless of rank."
)
_prompt_logprobs.append(pos_logprobs[token_id].logprob)
prompt_logprobs_list.append(_prompt_logprobs)
Comment thread
pbokc marked this conversation as resolved.

_routed_experts = None
if resp.routed_experts is not None:
if hasattr(resp.routed_experts, "tolist"):
Expand All @@ -185,6 +209,9 @@ def _postprocess_outputs(self, outputs):
if len(response_logprobs) and response_logprobs[0] is None:
response_logprobs = None # hack: assume uniform sampling params

if len(prompt_logprobs_list) and prompt_logprobs_list[0] is None:
prompt_logprobs_list = None # hack: assume uniform sampling params

if len(rollout_expert_indices) > 0 and rollout_expert_indices[0] is None:
rollout_expert_indices = None # hack: assume uniform sampling params

Expand All @@ -193,6 +220,7 @@ def _postprocess_outputs(self, outputs):
stop_reasons=stop_reasons,
response_ids=response_ids,
response_logprobs=response_logprobs,
prompt_logprobs=prompt_logprobs_list,
rollout_expert_indices=rollout_expert_indices,
)

Expand Down
18 changes: 13 additions & 5 deletions skyrl/backends/skyrl_train_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -662,6 +662,8 @@ def _sample_with_legacy_client(
"""Sample using legacy InferenceEngineClient with Ray-wrapped engines."""
all_input_ids = [r.prompt_ids for r in render_model_input(prepared_batch.all_model_inputs)]

needs_prompt_logprobs = prepared_batch.needs_prompt_logprobs

async def sample_all():
tasks = []
for i in range(len(all_input_ids)):
Expand All @@ -688,6 +690,7 @@ async def sample_all():
prompt_token_ids=prompt_token_ids,
num_samples=1, # Tinker batches multiple samples separately
sampling_params=params_dict,
prompt_logprobs=needs_prompt_logprobs,
)
)

Expand Down Expand Up @@ -748,7 +751,7 @@ def _extract_sequences(output):
)

results = {}
for request_id, model_id, start_idx, end_idx, needs_prompt_logprobs in prepared_batch.request_batch_slices:
for request_id, model_id, start_idx, end_idx, prompt_logprobs_requested in prepared_batch.request_batch_slices:
sequences = []
has_error = False
error_msg = None
Expand Down Expand Up @@ -792,13 +795,18 @@ def _extract_sequences(output):
status="error",
)
else:
# Note: prompt_logprobs not supported initially
if needs_prompt_logprobs:
logger.warning("Prompt logprobs requested but not yet supported")
# All samples for a request share the same prompt, so use the first sample's
# prompt logprobs (parity with JAX backend).
first_output = sample_outputs[start_idx]
prompt_logprobs = None
if prompt_logprobs_requested:
all_prompt_logprobs = first_output.get("prompt_logprobs")
if all_prompt_logprobs and len(all_prompt_logprobs) > 0:
prompt_logprobs = all_prompt_logprobs[0]

results[request_id] = types.SampleOutput(
sequences=sequences,
prompt_logprobs=None,
prompt_logprobs=prompt_logprobs,
)

return results
Expand Down
Loading