diff --git a/skyrl/backends/skyrl_train/inference_engines/base.py b/skyrl/backends/skyrl_train/inference_engines/base.py
index c3040fc1a2..973458d2fe 100644
--- a/skyrl/backends/skyrl_train/inference_engines/base.py
+++ b/skyrl/backends/skyrl_train/inference_engines/base.py
@@ -118,6 +118,11 @@ async def completion(self, request_payload: Dict[str, Any]) -> Dict[str, Any]:
"""
raise NotImplementedError
    @abstractmethod
    async def anthropic_messages(self, request_payload: Dict[str, Any]) -> Dict[str, Any]:
        """Handle an Anthropic Messages API (`/v1/messages`) request.

        Args:
            request_payload: {"json": <Anthropic request body dict>, "headers": <HTTP headers dict>}.

        Returns:
            The Anthropic Messages response as a dict, or a dict carrying an
            Anthropic-style ``error`` entry on failure.
        """
        raise NotImplementedError
+
@abstractmethod
async def wake_up(self, *args: Any, **kwargs: Any):
raise NotImplementedError
diff --git a/skyrl/backends/skyrl_train/inference_engines/inference_engine_client.py b/skyrl/backends/skyrl_train/inference_engines/inference_engine_client.py
index a443142a78..362bfb8927 100644
--- a/skyrl/backends/skyrl_train/inference_engines/inference_engine_client.py
+++ b/skyrl/backends/skyrl_train/inference_engines/inference_engine_client.py
@@ -214,6 +214,17 @@ async def chat_completion(self, request_payload: Dict[str, Any]) -> Dict[str, An
return await self.engines[engine_idx].chat_completion(request_payload)
+ async def anthropic_messages(self, request_payload: Dict[str, Any]) -> Dict[str, Any]:
+ """Route Anthropic Messages API requests to engines.
+ Similar to chat_completion, routes based on session_id for sticky routing.
+ """
+ session_id = request_payload["json"].pop("session_id", None)
+ if session_id is not None:
+ assert isinstance(session_id, (str, int)), "Session ID must be an integer or string for `/v1/messages`"
+ engine_idx = self._select_engine_idx(session_id)
+ logger.info(f"[InferenceEngineClient] Routing /v1/messages to engine {engine_idx}/{len(self.engines)}")
+ return await self.engines[engine_idx].anthropic_messages(request_payload)
+
async def completion(self, request_payload: Dict[str, Any]) -> Dict[str, Any]:
"""
Handles an OpenAI /completions request.
diff --git a/skyrl/backends/skyrl_train/inference_engines/inference_engine_client_http_endpoint.py b/skyrl/backends/skyrl_train/inference_engines/inference_engine_client_http_endpoint.py
index ca1259ef98..85812d72e8 100644
--- a/skyrl/backends/skyrl_train/inference_engines/inference_engine_client_http_endpoint.py
+++ b/skyrl/backends/skyrl_train/inference_engines/inference_engine_client_http_endpoint.py
@@ -300,6 +300,66 @@ async def completions(raw_request: Request):
"""
return await handle_openai_request(raw_request, endpoint="/completions")
@app.post("/v1/messages")
async def anthropic_messages(raw_request: Request):
    """Anthropic-compatible Messages API endpoint.

    Parses the request body, validates the required `model` and `messages`
    fields, forwards the payload to the global inference engine client, and
    maps any error payload returned by the engine onto an HTTP status code.
    """
    try:
        request_json = await raw_request.json()

        # Guard against requests arriving before the engine client is wired up.
        if _global_inference_engine_client is None:
            return JSONResponse(
                content={"error": {"message": "Inference engine client not initialized", "type": "internal_error"}},
                status_code=HTTPStatus.INTERNAL_SERVER_ERROR.value,
            )
        # Minimal validation of the Anthropic request contract.
        if "model" not in request_json:
            return JSONResponse(
                content={"error": {"message": "The field `model` is required", "type": "invalid_request_error"}},
                status_code=HTTPStatus.BAD_REQUEST.value,
            )
        messages = request_json.get("messages")
        if not isinstance(messages, list) or not messages:
            return JSONResponse(
                content={"error": {"message": "The field `messages` is required, must be a non-empty list", "type": "invalid_request_error"}},
                status_code=HTTPStatus.BAD_REQUEST.value,
            )

        # Engines receive the raw JSON body plus the original request headers.
        payload = {
            "json": request_json,
            "headers": dict(raw_request.headers) if hasattr(raw_request, "headers") else {},
        }
        anthropic_response = await _global_inference_engine_client.anthropic_messages(payload)

        # Errors can come back Anthropic-style ({"error": ...}) or
        # OpenAI-style ({"object": "error", ...}); handle both shapes.
        if "error" in anthropic_response or anthropic_response.get("object", "") == "error":
            if "error" in anthropic_response:
                error = anthropic_response["error"]
                error_code = error.get("code")
                error_type = error.get("type", "internal_error")
            else:
                error_code = anthropic_response.get("code")
                error_type = anthropic_response.get("type", "internal_error")
            # Prefer numeric error code if available, fall back to type-based mapping
            if isinstance(error_code, int):
                status_code = error_code
            elif isinstance(error_code, str) and error_code.isdigit():
                status_code = int(error_code)
            else:
                status_code = HTTPStatus.BAD_REQUEST.value if error_type == "invalid_request_error" else HTTPStatus.INTERNAL_SERVER_ERROR.value
            return JSONResponse(content=anthropic_response, status_code=status_code)

        return JSONResponse(content=anthropic_response)

    except json.JSONDecodeError as e:
        # Malformed body: surface as a client error rather than a 500.
        return JSONResponse(
            content={"error": {"message": f"Invalid JSON: {str(e)}", "type": "invalid_request_error"}},
            status_code=HTTPStatus.BAD_REQUEST.value,
        )
    except Exception as e:
        # Last-resort boundary: the endpoint never propagates an exception.
        logger.error(f"Error in /v1/messages: {e}\n{traceback.format_exc()}")
        return JSONResponse(
            content={"error": {"message": f"Internal error: {str(e)}", "type": "internal_error"}},
            status_code=HTTPStatus.INTERNAL_SERVER_ERROR.value,
        )
+
# Health check endpoint
# All inference engine replicas are initialized before creating `InferenceEngineClient`, and thus
# we can start receiving requests as soon as the FastAPI server starts
diff --git a/skyrl/backends/skyrl_train/inference_engines/ray_wrapped_inference_engine.py b/skyrl/backends/skyrl_train/inference_engines/ray_wrapped_inference_engine.py
index 3276612a01..66f5cff071 100644
--- a/skyrl/backends/skyrl_train/inference_engines/ray_wrapped_inference_engine.py
+++ b/skyrl/backends/skyrl_train/inference_engines/ray_wrapped_inference_engine.py
@@ -76,6 +76,9 @@ async def chat_completion(self, request_payload: Dict[str, Any]) -> Dict[str, An
async def completion(self, request_payload: Dict[str, Any]) -> Dict[str, Any]:
return await self.inference_engine_actor.completion.remote(request_payload)
+ async def anthropic_messages(self, request_payload: Dict[str, Any]) -> Dict[str, Any]:
+ return await self.inference_engine_actor.anthropic_messages.remote(request_payload)
+
async def pause_generation(self) -> None:
return await self.inference_engine_actor.pause_generation.remote()
diff --git a/skyrl/backends/skyrl_train/inference_engines/remote_inference_engine.py b/skyrl/backends/skyrl_train/inference_engines/remote_inference_engine.py
index b1e61ce2be..8cf8a66c99 100644
--- a/skyrl/backends/skyrl_train/inference_engines/remote_inference_engine.py
+++ b/skyrl/backends/skyrl_train/inference_engines/remote_inference_engine.py
@@ -228,6 +228,17 @@ async def completion(self, request_payload: Dict[str, Any]) -> Dict[str, Any]:
return response
+ async def anthropic_messages(self, request_payload: Dict[str, Any]) -> Dict[str, Any]:
+ """Call Anthropic Messages API endpoint (/v1/messages)."""
+ body = request_payload.get("json", {})
+ headers = {"Content-Type": "application/json"}
+ async with aiohttp.ClientSession(timeout=aiohttp.ClientTimeout(total=None)) as session:
+ request_url = f"{self.url}/v1/messages"
+ async with session.post(request_url, json=body, headers=headers) as resp:
+ response = await resp.json()
+
+ return response
+
async def wake_up(self, *args: Any, **kwargs: Any):
async with aiohttp.ClientSession() as session:
resp = await session.post(f"{self.url}/wake_up", json={"tags": kwargs.get("tags", 1)})
diff --git a/skyrl/backends/skyrl_train/inference_engines/vllm/vllm_engine.py b/skyrl/backends/skyrl_train/inference_engines/vllm/vllm_engine.py
index cfef56ace2..10c10f05d5 100644
--- a/skyrl/backends/skyrl_train/inference_engines/vllm/vllm_engine.py
+++ b/skyrl/backends/skyrl_train/inference_engines/vllm/vllm_engine.py
@@ -275,6 +275,10 @@ async def completion(self, request_payload: Dict[str, Any]) -> Dict[str, Any]:
"""Only supported in AsyncVLLMInferenceEngine."""
raise NotImplementedError("`completion` is only supported in AsyncVLLMInferenceEngine.")
    async def anthropic_messages(self, request_payload: Dict[str, Any]) -> Dict[str, Any]:
        """Only supported in AsyncVLLMInferenceEngine.

        Same restriction as `completion` on this class: the synchronous engine
        does not serve these request-style APIs.
        """
        raise NotImplementedError("`anthropic_messages` is only supported in AsyncVLLMInferenceEngine.")
+
async def wake_up(self, *args: Any, **kwargs: Any):
await asyncio.to_thread(self.llm.wake_up, tags=kwargs.get("tags", None))
@@ -642,6 +646,108 @@ async def completion(self, request_payload: Dict[str, Any]) -> Dict[str, Any]:
"""
return await self._handle_openai_request(request_payload, endpoint="/completions")
+ async def anthropic_messages(self, request_payload: Dict[str, Any]) -> Dict[str, Any]:
+ """Convert Anthropic Messages format to OpenAI chat completions and back."""
+ request_json = request_payload.get("json", {})
+ headers = request_payload.get("headers", {})
+
+ if "model" not in request_json:
+ return {"error": {"message": "The field `model` is required", "type": "invalid_request_error"}}
+ messages = request_json.get("messages")
+ if not isinstance(messages, list) or not messages:
+ return {"error": {"message": "The field `messages` is required, must be a non-empty list", "type": "invalid_request_error"}}
+
+ try:
+ openai_request = {
+ "model": request_json["model"],
+ "messages": [],
+ "stream": False,
+ }
+
+ if "system" in request_json and request_json["system"]:
+ system_content = request_json["system"]
+ if isinstance(system_content, list):
+ text_parts = []
+ for block in system_content:
+ if block.get("type") == "text":
+ text_parts.append(block["text"])
+ system_content = "\n".join(text_parts)
+ openai_request["messages"].append({"role": "system", "content": system_content})
+
+ for msg in request_json["messages"]:
+ content = msg["content"]
+ if isinstance(content, list):
+ text_parts = []
+ for block in content:
+ if block.get("type") == "text":
+ text_parts.append(block["text"])
+ content = "\n".join(text_parts)
+ openai_msg = {"role": msg["role"], "content": content}
+ openai_request["messages"].append(openai_msg)
+
+ if "max_tokens" in request_json:
+ openai_request["max_tokens"] = request_json["max_tokens"]
+ if "temperature" in request_json:
+ openai_request["temperature"] = request_json["temperature"]
+ if "top_p" in request_json:
+ openai_request["top_p"] = request_json["top_p"]
+ if "stop_sequences" in request_json:
+ openai_request["stop"] = request_json["stop_sequences"]
+
+ payload = {
+ "json": openai_request,
+ "headers": headers,
+ }
+ openai_response = await self.chat_completion(payload)
+
+ if "error" in openai_response or openai_response.get("object") == "error":
+ # Normalize to Anthropic error schema for consistent HTTP handler detection
+ if "error" not in openai_response:
+ openai_response = {"error": {
+ "message": openai_response.get("message", "Unknown upstream error"),
+ "type": openai_response.get("type", "internal_error"),
+ "code": openai_response.get("code"),
+ }}
+ return openai_response
+
+ finish_reason = openai_response["choices"][0].get("finish_reason", "stop")
+ stop_reason_map = {
+ "stop": "end_turn",
+ "length": "max_tokens",
+ "content_filter": "end_turn",
+ "tool_calls": "tool_use",
+ "function_call": "tool_use",
+ }
+ stop_reason = stop_reason_map.get(finish_reason, "end_turn")
+
+ message_content = openai_response["choices"][0]["message"]["content"]
+ if message_content is None:
+ message_content = ""
+
+ anthropic_response = {
+ "id": openai_response.get("id") or f"msg-{uuid4()}",
+ "type": "message",
+ "role": "assistant",
+ "content": [{"type": "text", "text": message_content}],
+ "model": openai_response.get("model", request_json["model"]),
+ "stop_reason": stop_reason,
+ "usage": {
+ "input_tokens": openai_response.get("usage", {}).get("prompt_tokens", 0),
+ "output_tokens": openai_response.get("usage", {}).get("completion_tokens", 0),
+ },
+ }
+
+ return anthropic_response
+
+ except Exception as e:
+ logger.error(f"anthropic_messages error: {e}")
+ return {
+ "error": {
+ "message": f"Error converting response: {str(e)}",
+ "type": "internal_error",
+ }
+ }
+
async def pause_generation(self, clear_cache: bool = False) -> None:
"""Pause generation using vLLM's native keep mode, freezing in-flight requests."""
engine = self._get_engine()