Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,8 @@ dependencies = [
"azure-core>=1.38.0",
"azure-identity>=1.21.0",
"pyasn1>=0.6.2",
# Used for system prompt template variable rendering
"jinja2>=3.1.0",
]


Expand Down
94 changes: 68 additions & 26 deletions src/app/endpoints/rlsapi_v1.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,13 @@
from the RHEL Lightspeed Command Line Assistant (CLA).
"""

import functools
import time
from datetime import datetime
from typing import Annotated, Any, Optional, cast

import jinja2
from jinja2.sandbox import SandboxedEnvironment
from fastapi import APIRouter, BackgroundTasks, Depends, HTTPException, Request
from llama_stack_api.openai_responses import OpenAIResponseObject
from llama_stack_client import APIConnectionError, APIStatusError, RateLimitError
Expand Down Expand Up @@ -49,11 +52,17 @@
logger = get_logger(__name__)
router = APIRouter(tags=["rlsapi-v1"])


class TemplateRenderError(Exception):
    """Signals that the configured system prompt is not a valid Jinja2 template."""


# Default values when RH Identity auth is not configured
AUTH_DISABLED = "auth_disabled"
# Keep this tuple centralized so infer_endpoint can catch all expected backend
# failures in one place while preserving a single telemetry/error-mapping path.
_INFER_HANDLED_EXCEPTIONS = (
TemplateRenderError,
RuntimeError,
APIConnectionError,
RateLimitError,
Expand Down Expand Up @@ -102,44 +111,70 @@ def _get_rh_identity_context(request: Request) -> tuple[str, str]:


def _build_instructions(systeminfo: RlsapiV1SystemInfo) -> str:
    """Build LLM instructions by rendering the system prompt as a Jinja2 template.

    The base prompt is rendered with the context variables ``date``, ``os``,
    ``version``, and ``arch``. Prompts without template markers pass through
    unchanged. The compiled template is cached after the first call.

    Args:
        systeminfo: System information from the client (OS, version, arch).

    Returns:
        The rendered instructions string for the LLM.
    """
    # Human-readable date, e.g. "January 01, 2025", exposed as the
    # ``date`` template variable.
    today = datetime.now().strftime("%B %d, %Y")
    template = _get_prompt_template()
    # Missing system fields render as empty strings rather than "None".
    return template.render(
        date=today,
        os=systeminfo.os or "",
        version=systeminfo.version or "",
        arch=systeminfo.arch or "",
    )


@functools.lru_cache(maxsize=8)
def _compile_prompt_template(prompt: str) -> jinja2.Template:
    """Compile a Jinja2 template string inside a SandboxedEnvironment.

    Results are cached by prompt text so that a configuration reload with
    a new system prompt produces a fresh compiled template.

    Args:
        prompt: The raw template source string.

    Returns:
        The compiled Jinja2 Template.

    Raises:
        TemplateRenderError: If the template contains invalid Jinja2 syntax.
    """
    # A sandboxed environment restricts what the template can reach at
    # render time, since the prompt text is operator-supplied configuration.
    sandbox = SandboxedEnvironment()
    try:
        return sandbox.from_string(prompt)
    except jinja2.TemplateSyntaxError as exc:
        message = f"System prompt contains invalid Jinja2 syntax: {exc}"
        raise TemplateRenderError(message) from exc

def _get_prompt_template() -> jinja2.Template:
    """Resolve the system prompt from configuration and return the compiled template.

    Delegates to the cached ``_compile_prompt_template`` so that identical
    prompt text is compiled only once, while configuration changes are
    picked up automatically.

    Returns:
        The compiled Jinja2 Template ready for rendering.
    """
    customization = configuration.customization
    if customization is not None and customization.system_prompt is not None:
        prompt = customization.system_prompt
    else:
        prompt = constants.DEFAULT_SYSTEM_PROMPT
    return _compile_prompt_template(prompt)


async def _get_default_model_id() -> str:
Expand Down Expand Up @@ -319,7 +354,7 @@ def _record_inference_failure( # pylint: disable=too-many-arguments,too-many-po
return inference_time


def _map_inference_error_to_http_exception(
def _map_inference_error_to_http_exception( # pylint: disable=too-many-return-statements
error: Exception, model_id: str, request_id: str
) -> Optional[HTTPException]:
"""Map known inference errors to HTTPException.
Expand All @@ -328,6 +363,13 @@ def _map_inference_error_to_http_exception(
so callers can preserve existing re-raise behavior for unknown runtime
errors.
"""
if isinstance(error, TemplateRenderError):
logger.error(
"Invalid system prompt template for request %s: %s", request_id, error
)
error_response = InternalServerErrorResponse.generic()
return HTTPException(**error_response.model_dump())

if isinstance(error, RuntimeError):
error_message = str(error).lower()
if "context_length" in error_message or "context length" in error_message:
Expand Down Expand Up @@ -398,7 +440,6 @@ async def infer_endpoint( # pylint: disable=R0914
logger.info("Processing rlsapi v1 /infer request %s", request_id)

input_source = infer_request.get_input_source()
instructions = _build_instructions(infer_request.context.systeminfo)
model_id = await _get_default_model_id()
provider, model = extract_provider_and_model_from_model_id(model_id)
mcp_tools: list[Any] = await get_mcp_tools(request_headers=request.headers)
Expand All @@ -408,6 +449,7 @@ async def infer_endpoint( # pylint: disable=R0914

start_time = time.monotonic()
try:
instructions = _build_instructions(infer_request.context.systeminfo)
response_text = await retrieve_simple_response(
input_source,
instructions,
Expand Down
Loading
Loading