Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions skyrl/backends/skyrl_train/inference_engines/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,11 @@ async def completion(self, request_payload: Dict[str, Any]) -> Dict[str, Any]:
"""
raise NotImplementedError

@abstractmethod
async def anthropic_messages(self, request_payload: Dict[str, Any]) -> Dict[str, Any]:
    """Handle an Anthropic Messages API (`/v1/messages`) request.

    Args:
        request_payload: {"json": <request body>, "headers": <request headers>}.

    Returns:
        The backend's response dict (Anthropic Messages schema, or a dict with
        an "error" key on failure).
    """
    raise NotImplementedError

@abstractmethod
async def wake_up(self, *args: Any, **kwargs: Any):
raise NotImplementedError
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -214,6 +214,17 @@ async def chat_completion(self, request_payload: Dict[str, Any]) -> Dict[str, An

return await self.engines[engine_idx].chat_completion(request_payload)

async def anthropic_messages(self, request_payload: Dict[str, Any]) -> Dict[str, Any]:
    """Route an Anthropic Messages API (`/v1/messages`) request to an engine.

    Like `chat_completion`, an optional `session_id` in the request body is used
    for sticky routing so requests from the same session hit the same engine.
    The `session_id` is popped so it is not forwarded to the engine.

    Args:
        request_payload: {"json": <request body>, "headers": <request headers>}.

    Returns:
        The selected engine's response dict, or an Anthropic-style error dict
        when `session_id` has an invalid type.
    """
    session_id = request_payload["json"].pop("session_id", None)
    if session_id is not None and not isinstance(session_id, (str, int)):
        # Explicit validation instead of `assert`: assertions are stripped under
        # `python -O`, and invalid client input should surface as a request error.
        return {
            "error": {
                "message": "Session ID must be an integer or string for `/v1/messages`",
                "type": "invalid_request_error",
            }
        }
    engine_idx = self._select_engine_idx(session_id)
    logger.info(f"[InferenceEngineClient] Routing /v1/messages to engine {engine_idx}/{len(self.engines)}")
    return await self.engines[engine_idx].anthropic_messages(request_payload)

async def completion(self, request_payload: Dict[str, Any]) -> Dict[str, Any]:
"""
Handles an OpenAI /completions request.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -300,6 +300,66 @@ async def completions(raw_request: Request):
"""
return await handle_openai_request(raw_request, endpoint="/completions")

@app.post("/v1/messages")
async def anthropic_messages(raw_request: Request):
"""Anthropic-compatible Messages API endpoint."""
Comment on lines +303 to +305
Copy link

Copilot AI Mar 23, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

New /v1/messages behavior isn’t covered by tests. There are already tests for the HTTP endpoint and request routing; please add at least a basic /v1/messages request/response test (including error status mapping) to prevent regressions.

Copilot uses AI. Check for mistakes.
try:
request_json = await raw_request.json()

if _global_inference_engine_client is None:
return JSONResponse(
content={"error": {"message": "Inference engine client not initialized", "type": "internal_error"}},
status_code=HTTPStatus.INTERNAL_SERVER_ERROR.value,
)
if "model" not in request_json:
return JSONResponse(
content={"error": {"message": "The field `model` is required", "type": "invalid_request_error"}},
status_code=HTTPStatus.BAD_REQUEST.value,
)
messages = request_json.get("messages")
if not isinstance(messages, list) or not messages:
return JSONResponse(
content={"error": {"message": "The field `messages` is required, must be a non-empty list", "type": "invalid_request_error"}},
status_code=HTTPStatus.BAD_REQUEST.value,
)

payload = {
"json": request_json,
"headers": dict(raw_request.headers) if hasattr(raw_request, "headers") else {},
}
anthropic_response = await _global_inference_engine_client.anthropic_messages(payload)

if "error" in anthropic_response or anthropic_response.get("object", "") == "error":
if "error" in anthropic_response:
error = anthropic_response["error"]
error_code = error.get("code")
error_type = error.get("type", "internal_error")
else:
error_code = anthropic_response.get("code")
error_type = anthropic_response.get("type", "internal_error")
# Prefer numeric error code if available, fall back to type-based mapping
if isinstance(error_code, int):
status_code = error_code
elif isinstance(error_code, str) and error_code.isdigit():
status_code = int(error_code)
else:
status_code = HTTPStatus.BAD_REQUEST.value if error_type == "invalid_request_error" else HTTPStatus.INTERNAL_SERVER_ERROR.value
return JSONResponse(content=anthropic_response, status_code=status_code)

return JSONResponse(content=anthropic_response)

except json.JSONDecodeError as e:
return JSONResponse(
content={"error": {"message": f"Invalid JSON: {str(e)}", "type": "invalid_request_error"}},
status_code=HTTPStatus.BAD_REQUEST.value,
)
except Exception as e:
logger.error(f"Error in /v1/messages: {e}\n{traceback.format_exc()}")
return JSONResponse(
content={"error": {"message": f"Internal error: {str(e)}", "type": "internal_error"}},
status_code=HTTPStatus.INTERNAL_SERVER_ERROR.value,
)

# Health check endpoint
# All inference engine replicas are initialized before creating `InferenceEngineClient`, and thus
# we can start receiving requests as soon as the FastAPI server starts
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,9 @@ async def chat_completion(self, request_payload: Dict[str, Any]) -> Dict[str, An
async def completion(self, request_payload: Dict[str, Any]) -> Dict[str, Any]:
return await self.inference_engine_actor.completion.remote(request_payload)

async def anthropic_messages(self, request_payload: Dict[str, Any]) -> Dict[str, Any]:
    """Forward an Anthropic Messages (`/v1/messages`) request to the Ray engine actor.

    Args:
        request_payload: {"json": <request body>, "headers": <request headers>}.

    Returns:
        The actor's response dict, awaited from the remote call.
    """
    return await self.inference_engine_actor.anthropic_messages.remote(request_payload)

async def pause_generation(self) -> None:
return await self.inference_engine_actor.pause_generation.remote()

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -228,6 +228,17 @@ async def completion(self, request_payload: Dict[str, Any]) -> Dict[str, Any]:

return response

async def anthropic_messages(self, request_payload: Dict[str, Any]) -> Dict[str, Any]:
    """Call the remote Anthropic Messages API endpoint (`/v1/messages`).

    Args:
        request_payload: {"json": <request body>, "headers": <request headers>}.
            Only the "json" body is forwarded; headers are replaced with a
            plain JSON content type.

    Returns:
        The decoded JSON response from the remote server.
    """
    body = request_payload.get("json", {})
    headers = {"Content-Type": "application/json"}
    # Use a long but finite timeout: `total=None` could hang forever on an
    # unresponsive server and exhaust client resources. Generation can be
    # slow, so allow up to 30 minutes per request.
    timeout = aiohttp.ClientTimeout(total=1800)
    async with aiohttp.ClientSession(timeout=timeout) as session:
        request_url = f"{self.url}/v1/messages"
        async with session.post(request_url, json=body, headers=headers) as resp:
            response = await resp.json()

    return response

async def wake_up(self, *args: Any, **kwargs: Any):
async with aiohttp.ClientSession() as session:
resp = await session.post(f"{self.url}/wake_up", json={"tags": kwargs.get("tags", 1)})
Expand Down
106 changes: 106 additions & 0 deletions skyrl/backends/skyrl_train/inference_engines/vllm/vllm_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -275,6 +275,10 @@ async def completion(self, request_payload: Dict[str, Any]) -> Dict[str, Any]:
"""Only supported in AsyncVLLMInferenceEngine."""
raise NotImplementedError("`completion` is only supported in AsyncVLLMInferenceEngine.")

async def anthropic_messages(self, request_payload: Dict[str, Any]) -> Dict[str, Any]:
    """Only supported in AsyncVLLMInferenceEngine."""
    # The sync engine has no request-serving path for the Messages API; the
    # async engine implements the Anthropic <-> OpenAI translation.
    raise NotImplementedError("`anthropic_messages` is only supported in AsyncVLLMInferenceEngine.")

async def wake_up(self, *args: Any, **kwargs: Any):
await asyncio.to_thread(self.llm.wake_up, tags=kwargs.get("tags", None))

Expand Down Expand Up @@ -642,6 +646,108 @@ async def completion(self, request_payload: Dict[str, Any]) -> Dict[str, Any]:
"""
return await self._handle_openai_request(request_payload, endpoint="/completions")

async def anthropic_messages(self, request_payload: Dict[str, Any]) -> Dict[str, Any]:
    """Serve an Anthropic Messages request by translating to/from OpenAI chat completions.

    Converts the Anthropic Messages request body into an OpenAI chat-completions
    request, delegates to `self.chat_completion`, and converts the result back
    into the Anthropic Messages response schema.

    Args:
        request_payload: {"json": <Anthropic Messages request body>,
            "headers": <request headers>}.

    Returns:
        An Anthropic Messages-style response dict on success, or a dict with an
        "error" key ({"message": ..., "type": ...}) on failure.
    """

    def _flatten_text(content):
        # Join the text of `text` content blocks; pass plain strings through.
        # NOTE(review): non-text blocks (tool_use, tool_result, images) are
        # silently dropped here — confirm whether they should be translated or
        # rejected with an invalid_request_error instead.
        if isinstance(content, list):
            return "\n".join(block["text"] for block in content if block.get("type") == "text")
        return content

    request_json = request_payload.get("json", {})
    headers = request_payload.get("headers", {})

    if "model" not in request_json:
        return {"error": {"message": "The field `model` is required", "type": "invalid_request_error"}}
    messages = request_json.get("messages")
    if not isinstance(messages, list) or not messages:
        return {"error": {"message": "The field `messages` is required, must be a non-empty list", "type": "invalid_request_error"}}

    try:
        openai_request: Dict[str, Any] = {
            "model": request_json["model"],
            "messages": [],
            "stream": False,
        }

        # Anthropic carries the system prompt out-of-band; OpenAI expects it
        # as a leading system message.
        if request_json.get("system"):
            openai_request["messages"].append(
                {"role": "system", "content": _flatten_text(request_json["system"])}
            )

        for msg in request_json["messages"]:
            openai_request["messages"].append(
                {"role": msg["role"], "content": _flatten_text(msg["content"])}
            )

        # Map shared sampling parameters; Anthropic's `stop_sequences` is
        # OpenAI's `stop`.
        for anthropic_key, openai_key in (
            ("max_tokens", "max_tokens"),
            ("temperature", "temperature"),
            ("top_p", "top_p"),
            ("stop_sequences", "stop"),
        ):
            if anthropic_key in request_json:
                openai_request[openai_key] = request_json[anthropic_key]

        openai_response = await self.chat_completion({"json": openai_request, "headers": headers})

        if "error" in openai_response or openai_response.get("object") == "error":
            # Normalize to Anthropic error schema for consistent HTTP handler detection
            if "error" not in openai_response:
                openai_response = {"error": {
                    "message": openai_response.get("message", "Unknown upstream error"),
                    "type": openai_response.get("type", "internal_error"),
                    "code": openai_response.get("code"),
                }}
            return openai_response

        choice = openai_response["choices"][0]
        stop_reason_map = {
            "stop": "end_turn",
            "length": "max_tokens",
            "content_filter": "end_turn",
            "tool_calls": "tool_use",
            "function_call": "tool_use",
        }
        stop_reason = stop_reason_map.get(choice.get("finish_reason", "stop"), "end_turn")

        message_content = choice["message"]["content"]
        if message_content is None:
            message_content = ""

        usage = openai_response.get("usage", {})
        return {
            "id": openai_response.get("id") or f"msg-{uuid4()}",
            "type": "message",
            "role": "assistant",
            "content": [{"type": "text", "text": message_content}],
            "model": openai_response.get("model", request_json["model"]),
            "stop_reason": stop_reason,
            "usage": {
                "input_tokens": usage.get("prompt_tokens", 0),
                "output_tokens": usage.get("completion_tokens", 0),
            },
        }

    except Exception as e:
        logger.error(f"anthropic_messages error: {e}")
        return {
            "error": {
                "message": f"Error converting response: {str(e)}",
                "type": "internal_error",
            }
        }

async def pause_generation(self, clear_cache: bool = False) -> None:
"""Pause generation using vLLM's native keep mode, freezing in-flight requests."""
engine = self._get_engine()
Expand Down
Loading