diff --git a/pyproject.toml b/pyproject.toml index 9a921ba86..18fd10dea 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -67,6 +67,8 @@ dependencies = [ "azure-core>=1.38.0", "azure-identity>=1.21.0", "pyasn1>=0.6.2", + # Used for system prompt template variable rendering + "jinja2>=3.1.0", ] diff --git a/src/app/endpoints/rlsapi_v1.py b/src/app/endpoints/rlsapi_v1.py index 0333e04f7..886d04c12 100644 --- a/src/app/endpoints/rlsapi_v1.py +++ b/src/app/endpoints/rlsapi_v1.py @@ -4,10 +4,13 @@ from the RHEL Lightspeed Command Line Assistant (CLA). """ +import functools import time from datetime import datetime from typing import Annotated, Any, Optional, cast +import jinja2 +from jinja2.sandbox import SandboxedEnvironment from fastapi import APIRouter, BackgroundTasks, Depends, HTTPException, Request from llama_stack_api.openai_responses import OpenAIResponseObject from llama_stack_client import APIConnectionError, APIStatusError, RateLimitError @@ -49,11 +52,17 @@ logger = get_logger(__name__) router = APIRouter(tags=["rlsapi-v1"]) + +class TemplateRenderError(Exception): + """Raised when the system prompt Jinja2 template cannot be compiled.""" + + # Default values when RH Identity auth is not configured AUTH_DISABLED = "auth_disabled" # Keep this tuple centralized so infer_endpoint can catch all expected backend # failures in one place while preserving a single telemetry/error-mapping path. _INFER_HANDLED_EXCEPTIONS = ( + TemplateRenderError, RuntimeError, APIConnectionError, RateLimitError, @@ -102,44 +111,70 @@ def _get_rh_identity_context(request: Request) -> tuple[str, str]: def _build_instructions(systeminfo: RlsapiV1SystemInfo) -> str: - """Build LLM instructions incorporating date and system context. + """Build LLM instructions by rendering the system prompt as a Jinja2 template. - Enhances the default system prompt with today's date and RHEL system - information to provide the LLM with relevant context about the user's - environment and current time. + The base prompt is rendered with the context variables ``date``, ``os``, + ``version``, and ``arch``. Prompts without template markers pass through + unchanged. The compiled template is cached after the first call. Args: systeminfo: System information from the client (OS, version, arch). Returns: - Instructions string for the LLM, with date and system context. + The rendered instructions string for the LLM. """ - base_prompt = _get_base_prompt() date_today = datetime.now().strftime("%B %d, %Y") - context_parts = [] - if systeminfo.os: - context_parts.append(f"OS: {systeminfo.os}") - if systeminfo.version: - context_parts.append(f"Version: {systeminfo.version}") - if systeminfo.arch: - context_parts.append(f"Architecture: {systeminfo.arch}") + return _get_prompt_template().render( + date=date_today, + os=systeminfo.os or "", + version=systeminfo.version or "", + arch=systeminfo.arch or "", + ) + + +@functools.lru_cache(maxsize=8) +def _compile_prompt_template(prompt: str) -> jinja2.Template: + """Compile a Jinja2 template string inside a SandboxedEnvironment. - if not context_parts: - return f"{base_prompt}\n\nToday's date: {date_today}" + Results are cached by prompt text so that a configuration reload with + a new system prompt produces a fresh compiled template. - system_context = ", ".join(context_parts) - return f"{base_prompt}\n\nToday's date: {date_today}\n\nUser's system: {system_context}" + Args: + prompt: The raw template source string. + + Returns: + The compiled Jinja2 Template. + Raises: + TemplateRenderError: If the template contains invalid Jinja2 syntax. + """ + env = SandboxedEnvironment() + try: + return env.from_string(prompt) + except jinja2.TemplateSyntaxError as exc: + raise TemplateRenderError( + f"System prompt contains invalid Jinja2 syntax: {exc}" + ) from exc -def _get_base_prompt() -> str: - """Get the base system prompt with configuration fallback.""" - if ( - configuration.customization is not None + +def _get_prompt_template() -> jinja2.Template: + """Resolve the system prompt from configuration and return the compiled template. + + Delegates to the cached ``_compile_prompt_template`` so that identical + prompt text is compiled only once, while configuration changes are + picked up automatically. + + Returns: + The compiled Jinja2 Template ready for rendering. + """ + prompt = ( + configuration.customization.system_prompt + if configuration.customization is not None and configuration.customization.system_prompt is not None - ): - return configuration.customization.system_prompt - return constants.DEFAULT_SYSTEM_PROMPT + else constants.DEFAULT_SYSTEM_PROMPT + ) + return _compile_prompt_template(prompt) async def _get_default_model_id() -> str: @@ -319,7 +354,7 @@ def _record_inference_failure( # pylint: disable=too-many-arguments,too-many-po return inference_time -def _map_inference_error_to_http_exception( +def _map_inference_error_to_http_exception( # pylint: disable=too-many-return-statements error: Exception, model_id: str, request_id: str ) -> Optional[HTTPException]: """Map known inference errors to HTTPException. @@ -328,6 +363,13 @@ def _map_inference_error_to_http_exception( so callers can preserve existing re-raise behavior for unknown runtime errors. """ + if isinstance(error, TemplateRenderError): + logger.error( + "Invalid system prompt template for request %s: %s", request_id, error + ) + error_response = InternalServerErrorResponse.generic() + return HTTPException(**error_response.model_dump()) + if isinstance(error, RuntimeError): error_message = str(error).lower() if "context_length" in error_message or "context length" in error_message: @@ -398,7 +440,6 @@ async def infer_endpoint( # pylint: disable=R0914 logger.info("Processing rlsapi v1 /infer request %s", request_id) input_source = infer_request.get_input_source() - instructions = _build_instructions(infer_request.context.systeminfo) model_id = await _get_default_model_id() provider, model = extract_provider_and_model_from_model_id(model_id) mcp_tools: list[Any] = await get_mcp_tools(request_headers=request.headers) @@ -408,6 +449,7 @@ async def infer_endpoint( # pylint: disable=R0914 start_time = time.monotonic() try: + instructions = _build_instructions(infer_request.context.systeminfo) response_text = await retrieve_simple_response( input_source, instructions, diff --git a/tests/unit/app/endpoints/test_rlsapi_v1.py b/tests/unit/app/endpoints/test_rlsapi_v1.py index 2b1cc22ef..43e9cb0f7 100644 --- a/tests/unit/app/endpoints/test_rlsapi_v1.py +++ b/tests/unit/app/endpoints/test_rlsapi_v1.py @@ -4,7 +4,8 @@ # pylint: disable=unused-argument import re -from typing import Any, Optional +from collections.abc import Callable +from typing import Any import pytest from fastapi import HTTPException, status @@ -15,7 +16,9 @@ import constants from app.endpoints.rlsapi_v1 import ( AUTH_DISABLED, + TemplateRenderError, _build_instructions, + _compile_prompt_template, _get_default_model_id, _get_rh_identity_context, infer_endpoint, @@ -39,6 +42,26 @@ MOCK_AUTH: AuthTuple = ("mock_user_id", "mock_username", False, "mock_token") +@pytest.fixture(autouse=True) +def _clear_prompt_template_cache() -> None: + """Clear the lru_cache on _compile_prompt_template between tests.""" + _compile_prompt_template.cache_clear() + + +@pytest.fixture(name="mock_custom_prompt") +def mock_custom_prompt_fixture(mocker: MockerFixture) -> Callable[[str], None]: + """Factory fixture that patches configuration with a custom system prompt.""" + + def _set(prompt: str) -> None: + mock_customization = mocker.Mock() + mock_customization.system_prompt = prompt + mock_config = mocker.Mock() + mock_config.customization = mock_customization + mocker.patch("app.endpoints.rlsapi_v1.configuration", mock_config) + + return _set + + def _create_mock_request(mocker: MockerFixture, rh_identity: Any = None) -> Any: """Create a mock FastAPI Request with optional RH Identity data.""" mock_request = mocker.Mock() @@ -140,71 +163,19 @@ def mock_generic_runtime_error_fixture(mocker: MockerFixture) -> None: # --- Test _build_instructions --- -@pytest.mark.parametrize( - ("systeminfo_kwargs", "expected_contains", "expected_not_contains"), - [ - pytest.param( - {"os": "RHEL", "version": "9.3", "arch": "x86_64"}, - ["OS: RHEL", "Version: 9.3", "Architecture: x86_64"], - [], - id="full_systeminfo", - ), - pytest.param( - {"os": "RHEL", "version": "", "arch": ""}, - ["OS: RHEL"], - ["Version:", "Architecture:"], - id="partial_systeminfo", - ), - pytest.param( - {}, - [constants.DEFAULT_SYSTEM_PROMPT], - ["OS:", "Version:", "Architecture:"], - id="empty_systeminfo", - ), - ], -) -def test_build_instructions( - systeminfo_kwargs: dict[str, str], - expected_contains: list[str], - expected_not_contains: list[str], -) -> None: - """Test _build_instructions includes date and system info.""" - systeminfo = RlsapiV1SystemInfo(**systeminfo_kwargs) +def test_build_instructions_default_prompt_passes_through() -> None: + """Test _build_instructions returns default prompt unchanged when no template vars.""" + systeminfo = RlsapiV1SystemInfo(os="RHEL", version="9.3", arch="x86_64") result = _build_instructions(systeminfo) - assert re.search(r"Today's date: \w+ \d{2}, \d{4}", result) - for expected in expected_contains: - assert expected in result - for not_expected in expected_not_contains: - assert not_expected not in result - - -# --- Test _build_instructions with customization.system_prompt --- + assert result == constants.DEFAULT_SYSTEM_PROMPT -@pytest.mark.parametrize( - ("custom_prompt", "expected_prompt"), - [ - pytest.param( - "You are a RHEL expert.", - "You are a RHEL expert.", - id="customization_system_prompt_set", - ), - pytest.param( - None, - constants.DEFAULT_SYSTEM_PROMPT, - id="customization_system_prompt_none", - ), - ], -) -def test_build_instructions_with_customization( - mocker: MockerFixture, - custom_prompt: Optional[str], - expected_prompt: str, -) -> None: - """Test _build_instructions uses customization.system_prompt when set.""" +def test_build_instructions_with_customization(mocker: MockerFixture) -> None: + """Test _build_instructions uses customization.system_prompt with template vars.""" + template = "Expert assistant.\n\nDate: {{ date }}\nOS: {{ os }}" mock_customization = mocker.Mock() - mock_customization.system_prompt = custom_prompt + mock_customization.system_prompt = template mock_config = mocker.Mock() mock_config.customization = mock_customization mocker.patch("app.endpoints.rlsapi_v1.configuration", mock_config) @@ -212,12 +183,13 @@ def test_build_instructions_with_customization( systeminfo = RlsapiV1SystemInfo(os="RHEL", version="9.3", arch="x86_64") result = _build_instructions(systeminfo) - assert expected_prompt in result + assert "Expert assistant." in result assert "OS: RHEL" in result + assert re.search(r"Date: \w+ \d{2}, \d{4}", result) def test_build_instructions_no_customization(mocker: MockerFixture) -> None: - """Test _build_instructions falls back when customization is None.""" + """Test _build_instructions falls back to DEFAULT_SYSTEM_PROMPT.""" mock_config = mocker.Mock() mock_config.customization = None mocker.patch("app.endpoints.rlsapi_v1.configuration", mock_config) @@ -225,8 +197,90 @@ def test_build_instructions_no_customization(mocker: MockerFixture) -> None: systeminfo = RlsapiV1SystemInfo() result = _build_instructions(systeminfo) - assert result.startswith(constants.DEFAULT_SYSTEM_PROMPT) - assert re.search(r"Today's date: \w+ \d{2}, \d{4}", result) + assert result == constants.DEFAULT_SYSTEM_PROMPT + + +# --- Test Jinja2 template rendering --- + + +def test_build_instructions_renders_jinja2_template( + mock_custom_prompt: Callable[[str], None], +) -> None: + """Test _build_instructions renders Jinja2 template variables instead of appending.""" + mock_custom_prompt( + "You are an assistant.\n\nDate: {{ date }}\nOS: {{ os }} {{ version }} ({{ arch }})" + ) + + systeminfo = RlsapiV1SystemInfo(os="RHEL", version="9.3", arch="x86_64") + result = _build_instructions(systeminfo) + + assert "OS: RHEL 9.3 (x86_64)" in result + assert re.search(r"Date: \w+ \d{2}, \d{4}", result) + assert "Today's date:" not in result + assert "User's system:" not in result + + +def test_build_instructions_jinja2_none_values_render_empty( + mock_custom_prompt: Callable[[str], None], +) -> None: + """Test that None system info values render as empty strings, not 'None'.""" + mock_custom_prompt("Assistant.\nOS={{ os }} VER={{ version }} ARCH={{ arch }}") + + systeminfo = RlsapiV1SystemInfo() + result = _build_instructions(systeminfo) + + assert "None" not in result + assert "OS= VER= ARCH=" in result + + +def test_build_instructions_jinja2_conditionals( + mock_custom_prompt: Callable[[str], None], +) -> None: + """Test that Jinja2 conditionals work in system prompt templates.""" + mock_custom_prompt( + "Assistant.{% if os %} OS: {{ os }}{% endif %}" + "{% if version %} VER: {{ version }}{% endif %}" + ) + + systeminfo = RlsapiV1SystemInfo(os="RHEL") + result = _build_instructions(systeminfo) + + assert "OS: RHEL" in result + assert "VER:" not in result + + +def test_build_instructions_plain_prompt_passes_through( + mock_custom_prompt: Callable[[str], None], +) -> None: + """Test that prompts without Jinja2 syntax pass through unchanged.""" + plain_prompt = "You are an expert RHEL assistant." + mock_custom_prompt(plain_prompt) + + systeminfo = RlsapiV1SystemInfo(os="RHEL", version="9.3", arch="x86_64") + result = _build_instructions(systeminfo) + + assert result == plain_prompt + + +@pytest.mark.parametrize( + "bad_template", + [ + pytest.param("Hello {{ unclosed", id="unclosed_variable"), + pytest.param("{% if %}", id="if_without_condition"), + pytest.param("{% endfor %}", id="endfor_without_for"), + ], +) +def test_build_instructions_malformed_template_raises_template_render_error( + mock_custom_prompt: Callable[[str], None], + bad_template: str, +) -> None: + """Test that invalid Jinja2 syntax in system prompt raises TemplateRenderError.""" + mock_custom_prompt(bad_template) + + systeminfo = RlsapiV1SystemInfo(os="RHEL", version="9.3", arch="x86_64") + + with pytest.raises(TemplateRenderError, match="invalid Jinja2 syntax"): + _build_instructions(systeminfo) # --- Test _get_default_model_id --- @@ -540,6 +594,31 @@ async def test_infer_api_connection_error_returns_503( assert exc_info.value.status_code == status.HTTP_503_SERVICE_UNAVAILABLE +async def test_infer_malformed_template_returns_500( + mocker: MockerFixture, + mock_configuration: AppConfig, + mock_custom_prompt: Callable[[str], None], + mock_llm_response: None, + mock_auth_resolvers: None, +) -> None: + """Test /infer endpoint returns 500 when system prompt has invalid Jinja2 syntax.""" + mock_custom_prompt("Hello {{ unclosed") + + infer_request = RlsapiV1InferRequest(question="Test question") + mock_request = _create_mock_request(mocker) + mock_background_tasks = _create_mock_background_tasks(mocker) + + with pytest.raises(HTTPException) as exc_info: + await infer_endpoint( + infer_request=infer_request, + request=mock_request, + background_tasks=mock_background_tasks, + auth=MOCK_AUTH, + ) + + assert exc_info.value.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR + + async def test_infer_empty_llm_response_returns_fallback( mocker: MockerFixture, mock_configuration: AppConfig, diff --git a/uv.lock b/uv.lock index c1a4a5d28..7292569e8 100644 --- a/uv.lock +++ b/uv.lock @@ -1519,6 +1519,7 @@ dependencies = [ { name = "einops" }, { name = "email-validator" }, { name = "fastapi" }, + { name = "jinja2" }, { name = "jsonpath-ng" }, { name = "kubernetes" }, { name = "litellm" }, @@ -1613,6 +1614,7 @@ requires-dist = [ { name = "einops", specifier = ">=0.8.1" }, { name = "email-validator", specifier = ">=2.2.0" }, { name = "fastapi", specifier = ">=0.115.12" }, + { name = "jinja2", specifier = ">=3.1.0" }, { name = "jsonpath-ng", specifier = ">=1.6.1" }, { name = "kubernetes", specifier = ">=30.1.0" }, { name = "litellm", specifier = ">=1.75.5.post1" },