From acf7023b97565fb071b226f6c1c13d4411fbb606 Mon Sep 17 00:00:00 2001
From: Major Hayden
Date: Tue, 17 Feb 2026 10:06:50 -0600
Subject: [PATCH] fix(rlsapi): improve exception handling and prevent
 sensitive data leakage
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Add missing exception handlers for RuntimeError (context_length → 413)
and OpenAIAPIStatusError. Use handle_known_apistatus_errors() for
smarter status code mapping instead of generic 500s. Replace raw str(e)
in client-facing cause fields with safe generic messages while
preserving full details in server-side logs. Extract common error
bookkeeping into _record_inference_failure() helper to reduce
duplication.

Signed-off-by: Major Hayden
---
 src/app/endpoints/rlsapi_v1.py             | 112 +++++++++++++--------
 tests/unit/app/endpoints/test_rlsapi_v1.py |  55 ++++++++++
 2 files changed, 126 insertions(+), 41 deletions(-)

diff --git a/src/app/endpoints/rlsapi_v1.py b/src/app/endpoints/rlsapi_v1.py
index 903ee1ac1..6ef545847 100644
--- a/src/app/endpoints/rlsapi_v1.py
+++ b/src/app/endpoints/rlsapi_v1.py
@@ -10,6 +10,7 @@
 from fastapi import APIRouter, BackgroundTasks, Depends, HTTPException, Request
 from llama_stack_api.openai_responses import OpenAIResponseObject
 from llama_stack_client import APIConnectionError, APIStatusError, RateLimitError
+from openai._exceptions import APIStatusError as OpenAIAPIStatusError
 
 import constants
 import metrics
@@ -23,6 +24,7 @@
 from models.responses import (
     ForbiddenResponse,
     InternalServerErrorResponse,
+    PromptTooLongResponse,
     QuotaExceededResponse,
     ServiceUnavailableResponse,
     UnauthorizedResponse,
@@ -31,6 +33,7 @@
 from models.rlsapi.requests import RlsapiV1InferRequest, RlsapiV1SystemInfo
 from models.rlsapi.responses import RlsapiV1InferData, RlsapiV1InferResponse
 from observability import InferenceEventData, build_inference_event, send_splunk_event
+from utils.query import handle_known_apistatus_errors
 from utils.responses import extract_text_from_response_output_item, get_mcp_tools
 from utils.suid import get_suid
 from log import get_logger
@@ -73,6 +76,7 @@ def _get_rh_identity_context(request: Request) -> tuple[str, str]:
         examples=["missing header", "missing token"]
     ),
     403: ForbiddenResponse.openapi_response(examples=["endpoint"]),
+    413: PromptTooLongResponse.openapi_response(),
     422: UnprocessableEntityResponse.openapi_response(),
     429: QuotaExceededResponse.openapi_response(),
     500: InternalServerErrorResponse.openapi_response(examples=["generic"]),
@@ -229,6 +233,41 @@ def _queue_splunk_event(  # pylint: disable=too-many-arguments,too-many-position
     background_tasks.add_task(send_splunk_event, event, sourcetype)
 
 
+def _record_inference_failure(  # pylint: disable=too-many-arguments,too-many-positional-arguments
+    background_tasks: BackgroundTasks,
+    infer_request: RlsapiV1InferRequest,
+    request: Request,
+    request_id: str,
+    error: Exception,
+    start_time: float,
+) -> float:
+    """Record metrics and queue Splunk event for an inference failure.
+
+    Args:
+        background_tasks: FastAPI background tasks for async event sending.
+        infer_request: The original inference request.
+        request: The FastAPI request object.
+        request_id: Unique identifier for the request.
+        error: The exception that caused the failure.
+        start_time: Monotonic clock time when inference started.
+
+    Returns:
+        The total inference time in seconds.
+    """
+    inference_time = time.monotonic() - start_time
+    metrics.llm_calls_failures_total.inc()
+    _queue_splunk_event(
+        background_tasks,
+        infer_request,
+        request,
+        request_id,
+        str(error),
+        inference_time,
+        "infer_error",
+    )
+    return inference_time
+
+
 @router.post("/infer", responses=infer_responses)
 @authorize(Action.RLSAPI_V1_INFER)
 async def infer_endpoint(
@@ -265,6 +304,7 @@ async def infer_endpoint(
     input_source = infer_request.get_input_source()
     instructions = _build_instructions(infer_request.context.systeminfo)
+    model_id = _get_default_model_id()
     mcp_tools = get_mcp_tools(configuration.mcp_servers)
     logger.debug(
         "Request %s: Combined input source length: %d", request_id, len(input_source)
     )
@@ -276,58 +316,48 @@ async def infer_endpoint(
             input_source, instructions, tools=mcp_tools
         )
         inference_time = time.monotonic() - start_time
+    except RuntimeError as e:
+        if "context_length" in str(e).lower():
+            _record_inference_failure(
+                background_tasks, infer_request, request, request_id, e, start_time
+            )
+            logger.error("Prompt too long for request %s: %s", request_id, e)
+            error_response = PromptTooLongResponse(model=model_id)
+            raise HTTPException(**error_response.model_dump()) from e
+        _record_inference_failure(
+            background_tasks, infer_request, request, request_id, e, start_time
+        )
+        logger.error("Unexpected RuntimeError for request %s: %s", request_id, e)
+        raise
     except APIConnectionError as e:
-        inference_time = time.monotonic() - start_time
-        metrics.llm_calls_failures_total.inc()
+        _record_inference_failure(
+            background_tasks, infer_request, request, request_id, e, start_time
+        )
         logger.error(
             "Unable to connect to Llama Stack for request %s: %s", request_id, e
         )
-        _queue_splunk_event(
-            background_tasks,
-            infer_request,
-            request,
-            request_id,
-            str(e),
-            inference_time,
-            "infer_error",
-        )
-        response = ServiceUnavailableResponse(
+        error_response = ServiceUnavailableResponse(
             backend_name="Llama Stack",
-            cause=str(e),
+            cause="Unable to connect to the inference backend",
         )
-        raise HTTPException(**response.model_dump()) from e
+        raise HTTPException(**error_response.model_dump()) from e
     except RateLimitError as e:
-        inference_time = time.monotonic() - start_time
-        metrics.llm_calls_failures_total.inc()
+        _record_inference_failure(
+            background_tasks, infer_request, request, request_id, e, start_time
+        )
         logger.error("Rate limit exceeded for request %s: %s", request_id, e)
-        _queue_splunk_event(
-            background_tasks,
-            infer_request,
-            request,
-            request_id,
-            str(e),
-            inference_time,
-            "infer_error",
+        error_response = QuotaExceededResponse(
+            response="The quota has been exceeded",
+            cause="Rate limit exceeded, please try again later",
         )
-        response = QuotaExceededResponse(
-            response="The quota has been exceeded", cause=str(e)
+        raise HTTPException(**error_response.model_dump()) from e
-        raise HTTPException(**response.model_dump()) from e
-    except APIStatusError as e:
-        inference_time = time.monotonic() - start_time
-        metrics.llm_calls_failures_total.inc()
+    except (APIStatusError, OpenAIAPIStatusError) as e:
+        _record_inference_failure(
+            background_tasks, infer_request, request, request_id, e, start_time
         )
         logger.exception("API error for request %s: %s", request_id, e)
-        _queue_splunk_event(
-            background_tasks,
-            infer_request,
-            request,
-            request_id,
-            str(e),
-            inference_time,
-            "infer_error",
-        )
-        response = InternalServerErrorResponse.generic()
-        raise HTTPException(**response.model_dump()) from e
+        error_response = handle_known_apistatus_errors(e, model_id)
+        raise HTTPException(**error_response.model_dump()) from e
 
     if not response_text:
         logger.warning("Empty response from LLM for request %s", request_id)
diff --git a/tests/unit/app/endpoints/test_rlsapi_v1.py b/tests/unit/app/endpoints/test_rlsapi_v1.py
index 3a633a32a..1f23f7284 100644
--- a/tests/unit/app/endpoints/test_rlsapi_v1.py
+++ b/tests/unit/app/endpoints/test_rlsapi_v1.py
@@ -126,6 +126,15 @@ def mock_api_connection_error_fixture(mocker: MockerFixture) -> None:
     )
 
 
+@pytest.fixture(name="mock_generic_runtime_error")
+def mock_generic_runtime_error_fixture(mocker: MockerFixture) -> None:
+    """Mock responses.create() to raise a non-context-length RuntimeError."""
+    _setup_responses_mock(
+        mocker,
+        mocker.AsyncMock(side_effect=RuntimeError("something went wrong")),
+    )
+
+
 # --- Test _build_instructions ---
 
 
@@ -656,3 +665,49 @@ async def test_infer_endpoint_calls_get_mcp_tools(
         )
 
     mock_get_mcp_tools.assert_called_once_with(mock_configuration.mcp_servers)
+
+
+@pytest.mark.asyncio
+async def test_infer_generic_runtime_error_reraises(
+    mocker: MockerFixture,
+    mock_configuration: AppConfig,
+    mock_generic_runtime_error: None,
+    mock_auth_resolvers: None,
+) -> None:
+    """Test /infer endpoint re-raises non-context-length RuntimeErrors."""
+    infer_request = RlsapiV1InferRequest(question="Test question")
+    mock_request = _create_mock_request(mocker)
+    mock_background_tasks = _create_mock_background_tasks(mocker)
+
+    with pytest.raises(RuntimeError, match="something went wrong"):
+        await infer_endpoint(
+            infer_request=infer_request,
+            request=mock_request,
+            background_tasks=mock_background_tasks,
+            auth=MOCK_AUTH,
+        )
+
+
+@pytest.mark.asyncio
+async def test_infer_generic_runtime_error_records_failure(
+    mocker: MockerFixture,
+    mock_configuration: AppConfig,
+    mock_generic_runtime_error: None,
+    mock_auth_resolvers: None,
+) -> None:
+    """Test that non-context-length RuntimeErrors record inference failure metrics."""
+    infer_request = RlsapiV1InferRequest(question="Test question")
+    mock_request = _create_mock_request(mocker)
+    mock_background_tasks = _create_mock_background_tasks(mocker)
+
+    with pytest.raises(RuntimeError):
+        await infer_endpoint(
+            infer_request=infer_request,
+            request=mock_request,
+            background_tasks=mock_background_tasks,
+            auth=MOCK_AUTH,
+        )
+
+    mock_background_tasks.add_task.assert_called_once()
+    call_args = mock_background_tasks.add_task.call_args
+    assert call_args[0][2] == "infer_error"
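
Note: handle_known_apistatus_errors() lives in utils/query.py and is not part
of this patch, so its exact behavior is not shown here. A minimal, illustrative
sketch of the kind of mapping the commit message describes -- keying off the
upstream exception's status_code to pick a client-safe response, and falling
back to a generic 500 with no str(e) leakage -- could look like the following.
Every name in this sketch other than the idea of handle_known_apistatus_errors
is a hypothetical stand-in, not the project's real models or helper:

    import logging

    logger = logging.getLogger(__name__)


    class SafeErrorResponse:
        """Hypothetical stand-in for the response models used by the endpoint."""

        def __init__(self, status_code: int, cause: str) -> None:
            self.status_code = status_code
            self.cause = cause

        def model_dump(self) -> dict:
            # Same shape the endpoint unpacks into HTTPException(**...).
            return {"status_code": self.status_code, "detail": {"cause": self.cause}}


    def sketch_handle_known_apistatus_errors(error: Exception, model_id: str) -> SafeErrorResponse:
        """Map a known upstream status code to a client-safe response (sketch only)."""
        # Full upstream details stay in server-side logs, never in the response body.
        logger.error("Upstream API error for model %s: %s", model_id, error)
        status_code = getattr(error, "status_code", None)
        if status_code == 413:
            return SafeErrorResponse(413, f"Prompt is too long for model {model_id}")
        if status_code == 429:
            return SafeErrorResponse(429, "Rate limit exceeded, please try again later")
        # Anything unrecognized still becomes a generic 500 without echoing str(e).
        return SafeErrorResponse(500, "Internal server error")

In this sketch, as in the patched handlers above, only the status code drives
the client-facing cause text, so raw exception details cannot leak to callers.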