112 changes: 71 additions & 41 deletions src/app/endpoints/rlsapi_v1.py
@@ -10,6 +10,7 @@
 from fastapi import APIRouter, BackgroundTasks, Depends, HTTPException, Request
 from llama_stack_api.openai_responses import OpenAIResponseObject
 from llama_stack_client import APIConnectionError, APIStatusError, RateLimitError
+from openai._exceptions import APIStatusError as OpenAIAPIStatusError
 
 import constants
 import metrics
@@ -23,6 +24,7 @@
 from models.responses import (
     ForbiddenResponse,
     InternalServerErrorResponse,
+    PromptTooLongResponse,
     QuotaExceededResponse,
     ServiceUnavailableResponse,
     UnauthorizedResponse,
@@ -31,6 +33,7 @@
 from models.rlsapi.requests import RlsapiV1InferRequest, RlsapiV1SystemInfo
 from models.rlsapi.responses import RlsapiV1InferData, RlsapiV1InferResponse
 from observability import InferenceEventData, build_inference_event, send_splunk_event
+from utils.query import handle_known_apistatus_errors
 from utils.responses import extract_text_from_response_output_item, get_mcp_tools
 from utils.suid import get_suid
 from log import get_logger
@@ -73,6 +76,7 @@ def _get_rh_identity_context(request: Request) -> tuple[str, str]:
         examples=["missing header", "missing token"]
     ),
     403: ForbiddenResponse.openapi_response(examples=["endpoint"]),
+    413: PromptTooLongResponse.openapi_response(),
     422: UnprocessableEntityResponse.openapi_response(),
     429: QuotaExceededResponse.openapi_response(),
     500: InternalServerErrorResponse.openapi_response(examples=["generic"]),
@@ -229,6 +233,41 @@ def _queue_splunk_event(  # pylint: disable=too-many-arguments,too-many-positional-arguments
     background_tasks.add_task(send_splunk_event, event, sourcetype)
 
 
+def _record_inference_failure(  # pylint: disable=too-many-arguments,too-many-positional-arguments
+    background_tasks: BackgroundTasks,
+    infer_request: RlsapiV1InferRequest,
+    request: Request,
+    request_id: str,
+    error: Exception,
+    start_time: float,
+) -> float:
+    """Record metrics and queue Splunk event for an inference failure.
+
+    Args:
+        background_tasks: FastAPI background tasks for async event sending.
+        infer_request: The original inference request.
+        request: The FastAPI request object.
+        request_id: Unique identifier for the request.
+        error: The exception that caused the failure.
+        start_time: Monotonic clock time when inference started.
+
+    Returns:
+        The total inference time in seconds.
+    """
+    inference_time = time.monotonic() - start_time
+    metrics.llm_calls_failures_total.inc()
+    _queue_splunk_event(
+        background_tasks,
+        infer_request,
+        request,
+        request_id,
+        str(error),
+        inference_time,
+        "infer_error",
+    )
+    return inference_time
+
+
 @router.post("/infer", responses=infer_responses)
 @authorize(Action.RLSAPI_V1_INFER)
 async def infer_endpoint(
@@ -265,6 +304,7 @@ async def infer_endpoint(
 
     input_source = infer_request.get_input_source()
     instructions = _build_instructions(infer_request.context.systeminfo)
+    model_id = _get_default_model_id()
     mcp_tools = get_mcp_tools(configuration.mcp_servers)
     logger.debug(
         "Request %s: Combined input source length: %d", request_id, len(input_source)
@@ -276,58 +316,48 @@
             input_source, instructions, tools=mcp_tools
         )
         inference_time = time.monotonic() - start_time
+    except RuntimeError as e:
+        if "context_length" in str(e).lower():
+            _record_inference_failure(
+                background_tasks, infer_request, request, request_id, e, start_time
+            )
+            logger.error("Prompt too long for request %s: %s", request_id, e)
+            error_response = PromptTooLongResponse(model=model_id)
+            raise HTTPException(**error_response.model_dump()) from e
+        _record_inference_failure(
+            background_tasks, infer_request, request, request_id, e, start_time
+        )
+        logger.error("Unexpected RuntimeError for request %s: %s", request_id, e)
+        raise
     except APIConnectionError as e:
-        inference_time = time.monotonic() - start_time
-        metrics.llm_calls_failures_total.inc()
+        _record_inference_failure(
+            background_tasks, infer_request, request, request_id, e, start_time
+        )
         logger.error(
             "Unable to connect to Llama Stack for request %s: %s", request_id, e
         )
-        _queue_splunk_event(
-            background_tasks,
-            infer_request,
-            request,
-            request_id,
-            str(e),
-            inference_time,
-            "infer_error",
-        )
-        response = ServiceUnavailableResponse(
+        error_response = ServiceUnavailableResponse(
             backend_name="Llama Stack",
-            cause=str(e),
+            cause="Unable to connect to the inference backend",
         )
-        raise HTTPException(**response.model_dump()) from e
+        raise HTTPException(**error_response.model_dump()) from e
     except RateLimitError as e:
-        inference_time = time.monotonic() - start_time
-        metrics.llm_calls_failures_total.inc()
+        _record_inference_failure(
+            background_tasks, infer_request, request, request_id, e, start_time
+        )
         logger.error("Rate limit exceeded for request %s: %s", request_id, e)
-        _queue_splunk_event(
-            background_tasks,
-            infer_request,
-            request,
-            request_id,
-            str(e),
-            inference_time,
-            "infer_error",
+        error_response = QuotaExceededResponse(
+            response="The quota has been exceeded",
+            cause="Rate limit exceeded, please try again later",
         )
-        response = QuotaExceededResponse(
-            response="The quota has been exceeded", cause=str(e)
+        raise HTTPException(**error_response.model_dump()) from e
+    except (APIStatusError, OpenAIAPIStatusError) as e:
+        _record_inference_failure(
+            background_tasks, infer_request, request, request_id, e, start_time
        )
-        raise HTTPException(**response.model_dump()) from e
-    except APIStatusError as e:
-        inference_time = time.monotonic() - start_time
-        metrics.llm_calls_failures_total.inc()
         logger.exception("API error for request %s: %s", request_id, e)
-        _queue_splunk_event(
-            background_tasks,
-            infer_request,
-            request,
-            request_id,
-            str(e),
-            inference_time,
-            "infer_error",
-        )
-        response = InternalServerErrorResponse.generic()
-        raise HTTPException(**response.model_dump()) from e
+        error_response = handle_known_apistatus_errors(e, model_id)
+        raise HTTPException(**error_response.model_dump()) from e
 
     if not response_text:
         logger.warning("Empty response from LLM for request %s", request_id)
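Note: the handle_known_apistatus_errors helper called in the new except (APIStatusError, OpenAIAPIStatusError) branch lives in utils/query.py and is not shown in this diff. As a rough sketch of the contract the endpoint relies on — the error inspection and fallback below are assumptions, not the actual implementation — it might look like:

# Hypothetical sketch only; the real utils/query.py helper is not in this diff.
from models.responses import InternalServerErrorResponse, PromptTooLongResponse


def handle_known_apistatus_errors(
    error: Exception, model_id: str
) -> PromptTooLongResponse | InternalServerErrorResponse:
    """Map a known upstream API status error to an HTTP response model (sketch)."""
    # Assumption: a context-length overflow is identifiable from the error body;
    # anything unrecognized falls back to a generic 500 response.
    if "context_length" in str(error).lower():
        return PromptTooLongResponse(model=model_id)
    return InternalServerErrorResponse.generic()

Whatever the real logic is, the endpoint only depends on the returned model exposing model_dump() with the fields HTTPException expects (status_code and detail).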
55 changes: 55 additions & 0 deletions tests/unit/app/endpoints/test_rlsapi_v1.py
@@ -126,6 +126,15 @@ def mock_api_connection_error_fixture(mocker: MockerFixture) -> None:
     )
 
 
+@pytest.fixture(name="mock_generic_runtime_error")
+def mock_generic_runtime_error_fixture(mocker: MockerFixture) -> None:
+    """Mock responses.create() to raise a non-context-length RuntimeError."""
+    _setup_responses_mock(
+        mocker,
+        mocker.AsyncMock(side_effect=RuntimeError("something went wrong")),
+    )
+
+
 # --- Test _build_instructions ---
 
 
@@ -656,3 +665,49 @@ async def test_infer_endpoint_calls_get_mcp_tools(
     )
 
     mock_get_mcp_tools.assert_called_once_with(mock_configuration.mcp_servers)
+
+
+@pytest.mark.asyncio
+async def test_infer_generic_runtime_error_reraises(
+    mocker: MockerFixture,
+    mock_configuration: AppConfig,
+    mock_generic_runtime_error: None,
+    mock_auth_resolvers: None,
+) -> None:
+    """Test /infer endpoint re-raises non-context-length RuntimeErrors."""
+    infer_request = RlsapiV1InferRequest(question="Test question")
+    mock_request = _create_mock_request(mocker)
+    mock_background_tasks = _create_mock_background_tasks(mocker)
+
+    with pytest.raises(RuntimeError, match="something went wrong"):
+        await infer_endpoint(
+            infer_request=infer_request,
+            request=mock_request,
+            background_tasks=mock_background_tasks,
+            auth=MOCK_AUTH,
+        )
+
+
+@pytest.mark.asyncio
+async def test_infer_generic_runtime_error_records_failure(
+    mocker: MockerFixture,
+    mock_configuration: AppConfig,
+    mock_generic_runtime_error: None,
+    mock_auth_resolvers: None,
+) -> None:
+    """Test that non-context-length RuntimeErrors record inference failure metrics."""
+    infer_request = RlsapiV1InferRequest(question="Test question")
+    mock_request = _create_mock_request(mocker)
+    mock_background_tasks = _create_mock_background_tasks(mocker)
+
+    with pytest.raises(RuntimeError):
+        await infer_endpoint(
+            infer_request=infer_request,
+            request=mock_request,
+            background_tasks=mock_background_tasks,
+            auth=MOCK_AUTH,
+        )
+
+    mock_background_tasks.add_task.assert_called_once()
+    call_args = mock_background_tasks.add_task.call_args
+    assert call_args[0][2] == "infer_error"
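The two tests above pin down the generic RuntimeError path; the context-length branch itself is not exercised by the tests shown here. A companion test along these lines — hypothetical, assuming the same fixture helpers and a from fastapi import HTTPException in the test module — could cover it directly:

# Hypothetical companion test, not part of this diff.
@pytest.mark.asyncio
async def test_infer_context_length_error_returns_413(
    mocker: MockerFixture,
    mock_configuration: AppConfig,
    mock_auth_resolvers: None,
) -> None:
    """Sketch: a context-length RuntimeError should surface as HTTP 413."""
    _setup_responses_mock(
        mocker,
        mocker.AsyncMock(side_effect=RuntimeError("context_length exceeded")),
    )
    infer_request = RlsapiV1InferRequest(question="Test question")
    mock_request = _create_mock_request(mocker)
    mock_background_tasks = _create_mock_background_tasks(mocker)

    with pytest.raises(HTTPException) as exc_info:
        await infer_endpoint(
            infer_request=infer_request,
            request=mock_request,
            background_tasks=mock_background_tasks,
            auth=MOCK_AUTH,
        )

    # Assumes PromptTooLongResponse.model_dump() carries status_code=413,
    # which the HTTPException(**...) construction in the endpoint implies.
    assert exc_info.value.status_code == 413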