diff --git a/skyrl/backends/skyrl_train/inference_engines/base.py b/skyrl/backends/skyrl_train/inference_engines/base.py
index c3040fc1a2..973458d2fe 100644
--- a/skyrl/backends/skyrl_train/inference_engines/base.py
+++ b/skyrl/backends/skyrl_train/inference_engines/base.py
@@ -118,6 +118,11 @@ async def completion(self, request_payload: Dict[str, Any]) -> Dict[str, Any]:
         """
         raise NotImplementedError
 
+    @abstractmethod
+    async def anthropic_messages(self, request_payload: Dict[str, Any]) -> Dict[str, Any]:
+        """Handle Anthropic Messages API (/v1/messages). Payload: {"json": <request body>, "headers": <request headers>}."""
+        raise NotImplementedError
+
     @abstractmethod
     async def wake_up(self, *args: Any, **kwargs: Any):
         raise NotImplementedError
diff --git a/skyrl/backends/skyrl_train/inference_engines/inference_engine_client.py b/skyrl/backends/skyrl_train/inference_engines/inference_engine_client.py
index a443142a78..362bfb8927 100644
--- a/skyrl/backends/skyrl_train/inference_engines/inference_engine_client.py
+++ b/skyrl/backends/skyrl_train/inference_engines/inference_engine_client.py
@@ -214,6 +214,21 @@ async def chat_completion(self, request_payload: Dict[str, Any]) -> Dict[str, An
 
         return await self.engines[engine_idx].chat_completion(request_payload)
 
+    async def anthropic_messages(self, request_payload: Dict[str, Any]) -> Dict[str, Any]:
+        """Route Anthropic Messages API requests to engines.
+        Similar to chat_completion, routes based on session_id for sticky routing.
+        """
+        session_id = request_payload["json"].pop("session_id", None)
+        if session_id is not None:
+            assert isinstance(session_id, (str, int)), "Session ID must be an integer or string for `/v1/messages`"
+            engine_idx = self._select_engine_idx(session_id)
+        else:
+            # No sticky session requested; pick a random engine (engine_idx was previously unbound here -> NameError).
+            import random
+            engine_idx = random.randrange(len(self.engines))
+        logger.info(f"[InferenceEngineClient] Routing /v1/messages to engine {engine_idx}/{len(self.engines)}")
+        return await self.engines[engine_idx].anthropic_messages(request_payload)
+
     async def completion(self, request_payload: Dict[str, Any]) -> Dict[str, Any]:
         """
         Handles an OpenAI /completions request.
diff --git a/skyrl/backends/skyrl_train/inference_engines/inference_engine_client_http_endpoint.py b/skyrl/backends/skyrl_train/inference_engines/inference_engine_client_http_endpoint.py
index ca1259ef98..85812d72e8 100644
--- a/skyrl/backends/skyrl_train/inference_engines/inference_engine_client_http_endpoint.py
+++ b/skyrl/backends/skyrl_train/inference_engines/inference_engine_client_http_endpoint.py
@@ -300,6 +300,66 @@ async def completions(raw_request: Request):
         """
         return await handle_openai_request(raw_request, endpoint="/completions")
 
+    @app.post("/v1/messages")
+    async def anthropic_messages(raw_request: Request):
+        """Anthropic-compatible Messages API endpoint."""
+        try:
+            request_json = await raw_request.json()
+
+            if _global_inference_engine_client is None:
+                return JSONResponse(
+                    content={"error": {"message": "Inference engine client not initialized", "type": "internal_error"}},
+                    status_code=HTTPStatus.INTERNAL_SERVER_ERROR.value,
+                )
+            if "model" not in request_json:
+                return JSONResponse(
+                    content={"error": {"message": "The field `model` is required", "type": "invalid_request_error"}},
+                    status_code=HTTPStatus.BAD_REQUEST.value,
+                )
+            messages = request_json.get("messages")
+            if not isinstance(messages, list) or not messages:
+                return JSONResponse(
+                    content={"error": {"message": "The field `messages` is required, must be a non-empty list", "type": "invalid_request_error"}},
+                    status_code=HTTPStatus.BAD_REQUEST.value,
+                )
+
+            payload = {
+                "json": request_json,
+                "headers": dict(raw_request.headers) if hasattr(raw_request, "headers") else {},
+            }
+            anthropic_response = await _global_inference_engine_client.anthropic_messages(payload)
+
+            if "error" in anthropic_response or anthropic_response.get("object", "") == "error":
+                if "error" in anthropic_response:
+                    error = anthropic_response["error"]
+                    error_code = error.get("code")
+                    error_type = error.get("type", "internal_error")
+                else:
+                    error_code = anthropic_response.get("code")
+                    error_type = anthropic_response.get("type", "internal_error")
+                # Prefer numeric error code, but only if it is a valid HTTP error status (4xx/5xx);
+                # arbitrary upstream codes would make JSONResponse emit an invalid status line.
+                if isinstance(error_code, str) and error_code.isdigit():
+                    error_code = int(error_code)
+                if isinstance(error_code, int) and 400 <= error_code < 600:
+                    status_code = error_code
+                else:
+                    status_code = HTTPStatus.BAD_REQUEST.value if error_type == "invalid_request_error" else HTTPStatus.INTERNAL_SERVER_ERROR.value
+                return JSONResponse(content=anthropic_response, status_code=status_code)
+
+            return JSONResponse(content=anthropic_response)
+
+        except json.JSONDecodeError as e:
+            return JSONResponse(
+                content={"error": {"message": f"Invalid JSON: {str(e)}", "type": "invalid_request_error"}},
+                status_code=HTTPStatus.BAD_REQUEST.value,
+            )
+        except Exception as e:
+            logger.error(f"Error in /v1/messages: {e}\n{traceback.format_exc()}")
+            return JSONResponse(
+                content={"error": {"message": f"Internal error: {str(e)}", "type": "internal_error"}},
+                status_code=HTTPStatus.INTERNAL_SERVER_ERROR.value,
+            )
+
     # Health check endpoint
     # All inference engine replicas are initialized before creating `InferenceEngineClient`, and thus
     # we can start receiving requests as soon as the FastAPI server starts
diff --git a/skyrl/backends/skyrl_train/inference_engines/ray_wrapped_inference_engine.py b/skyrl/backends/skyrl_train/inference_engines/ray_wrapped_inference_engine.py
index 3276612a01..66f5cff071 100644
--- a/skyrl/backends/skyrl_train/inference_engines/ray_wrapped_inference_engine.py
+++ b/skyrl/backends/skyrl_train/inference_engines/ray_wrapped_inference_engine.py
@@ -76,6 +76,9 @@ async def chat_completion(self, request_payload: Dict[str, Any]) -> Dict[str, An
     async def completion(self, request_payload: Dict[str, Any]) -> Dict[str, Any]:
         return await self.inference_engine_actor.completion.remote(request_payload)
 
+    async def anthropic_messages(self, request_payload: Dict[str, Any]) -> Dict[str, Any]:
+        return await self.inference_engine_actor.anthropic_messages.remote(request_payload)
+
     async def pause_generation(self) -> None:
         return await self.inference_engine_actor.pause_generation.remote()
 
diff --git a/skyrl/backends/skyrl_train/inference_engines/remote_inference_engine.py b/skyrl/backends/skyrl_train/inference_engines/remote_inference_engine.py
index b1e61ce2be..8cf8a66c99 100644
--- a/skyrl/backends/skyrl_train/inference_engines/remote_inference_engine.py
+++ b/skyrl/backends/skyrl_train/inference_engines/remote_inference_engine.py
@@ -228,6 +228,17 @@ async def completion(self, request_payload: Dict[str, Any]) -> Dict[str, Any]:
 
         return response
 
+    async def anthropic_messages(self, request_payload: Dict[str, Any]) -> Dict[str, Any]:
+        """Call Anthropic Messages API endpoint (/v1/messages)."""
+        body = request_payload.get("json", {})
+        headers = {"Content-Type": "application/json"}
+        async with aiohttp.ClientSession(timeout=aiohttp.ClientTimeout(total=None)) as session:
+            request_url = f"{self.url}/v1/messages"
+            async with session.post(request_url, json=body, headers=headers) as resp:
+                response = await resp.json()
+
+        return response
+
     async def wake_up(self, *args: Any, **kwargs: Any):
         async with aiohttp.ClientSession() as session:
             resp = await session.post(f"{self.url}/wake_up", json={"tags": kwargs.get("tags", 1)})
diff --git a/skyrl/backends/skyrl_train/inference_engines/vllm/vllm_engine.py b/skyrl/backends/skyrl_train/inference_engines/vllm/vllm_engine.py
index cfef56ace2..10c10f05d5 100644
--- a/skyrl/backends/skyrl_train/inference_engines/vllm/vllm_engine.py
+++ b/skyrl/backends/skyrl_train/inference_engines/vllm/vllm_engine.py
@@ -275,6 +275,10 @@ async def completion(self, request_payload: Dict[str, Any]) -> Dict[str, Any]:
         """Only supported in AsyncVLLMInferenceEngine."""
         raise NotImplementedError("`completion` is only supported in AsyncVLLMInferenceEngine.")
 
+    async def anthropic_messages(self, request_payload: Dict[str, Any]) -> Dict[str, Any]:
+        """Only supported in AsyncVLLMInferenceEngine."""
+        raise NotImplementedError("`anthropic_messages` is only supported in AsyncVLLMInferenceEngine.")
+
     async def wake_up(self, *args: Any, **kwargs: Any):
         await asyncio.to_thread(self.llm.wake_up, tags=kwargs.get("tags", None))
 
@@ -642,6 +646,109 @@ async def completion(self, request_payload: Dict[str, Any]) -> Dict[str, Any]:
         """
         return await self._handle_openai_request(request_payload, endpoint="/completions")
 
+    async def anthropic_messages(self, request_payload: Dict[str, Any]) -> Dict[str, Any]:
+        """Convert Anthropic Messages format to OpenAI chat completions and back."""
+        request_json = request_payload.get("json", {})
+        headers = request_payload.get("headers", {})
+
+        if "model" not in request_json:
+            return {"error": {"message": "The field `model` is required", "type": "invalid_request_error"}}
+        messages = request_json.get("messages")
+        if not isinstance(messages, list) or not messages:
+            return {"error": {"message": "The field `messages` is required, must be a non-empty list", "type": "invalid_request_error"}}
+
+        try:
+            openai_request = {
+                "model": request_json["model"],
+                "messages": [],
+                "stream": False,
+            }
+
+            if "system" in request_json and request_json["system"]:
+                system_content = request_json["system"]
+                if isinstance(system_content, list):
+                    text_parts = []
+                    for block in system_content:
+                        if block.get("type") == "text":
+                            text_parts.append(block["text"])
+                    system_content = "\n".join(text_parts)
+                openai_request["messages"].append({"role": "system", "content": system_content})
+
+            for msg in request_json["messages"]:
+                content = msg["content"]
+                if isinstance(content, list):
+                    text_parts = []
+                    for block in content:
+                        if block.get("type") == "text":
+                            text_parts.append(block["text"])
+                    content = "\n".join(text_parts)
+                openai_msg = {"role": msg["role"], "content": content}
+                openai_request["messages"].append(openai_msg)
+
+            if "max_tokens" in request_json:
+                openai_request["max_tokens"] = request_json["max_tokens"]
+            if "temperature" in request_json:
+                openai_request["temperature"] = request_json["temperature"]
+            if "top_p" in request_json:
+                openai_request["top_p"] = request_json["top_p"]
+            if "stop_sequences" in request_json:
+                openai_request["stop"] = request_json["stop_sequences"]
+
+            payload = {
+                "json": openai_request,
+                "headers": headers,
+            }
+            openai_response = await self.chat_completion(payload)
+
+            if "error" in openai_response or openai_response.get("object") == "error":
+                # Normalize to Anthropic error schema for consistent HTTP handler detection
+                if "error" not in openai_response:
+                    openai_response = {"error": {
+                        "message": openai_response.get("message", "Unknown upstream error"),
+                        "type": openai_response.get("type", "internal_error"),
+                        "code": openai_response.get("code"),
+                    }}
+                return openai_response
+
+            finish_reason = openai_response["choices"][0].get("finish_reason", "stop")
+            stop_reason_map = {
+                "stop": "end_turn",
+                "length": "max_tokens",
+                "content_filter": "end_turn",
+                "tool_calls": "tool_use",
+                "function_call": "tool_use",
+            }
+            stop_reason = stop_reason_map.get(finish_reason, "end_turn")
+
+            message_content = openai_response["choices"][0]["message"]["content"]
+            if message_content is None:
+                message_content = ""
+
+            anthropic_response = {
+                "id": openai_response.get("id") or f"msg-{uuid4()}",
+                "type": "message",
+                "role": "assistant",
+                "content": [{"type": "text", "text": message_content}],
+                "model": openai_response.get("model", request_json["model"]),
+                "stop_reason": stop_reason,
+                "stop_sequence": None,
+                "usage": {
+                    "input_tokens": openai_response.get("usage", {}).get("prompt_tokens", 0),
+                    "output_tokens": openai_response.get("usage", {}).get("completion_tokens", 0),
+                },
+            }
+
+            return anthropic_response
+
+        except Exception as e:
+            logger.error(f"anthropic_messages error: {e}")
+            return {
+                "error": {
+                    "message": f"Error converting response: {str(e)}",
+                    "type": "internal_error",
+                }
+            }
+
     async def pause_generation(self, clear_cache: bool = False) -> None:
         """Pause generation using vLLM's native keep mode, freezing in-flight requests."""
         engine = self._get_engine()