Skip to content

Commit cbe5b63

Browse files
authored
Merge pull request #57 from RobGeada/MetricsAndNonBlockingGuardrails
Feat: Custom user metrics and non-blocking guardrails
2 parents d3a34fd + dc2f360 commit cbe5b63

File tree

12 files changed

+307
-78
lines changed

12 files changed

+307
-78
lines changed

detectors/Dockerfile.builtIn

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,12 +21,14 @@ ARG CACHEBUST=1
2121
RUN echo "$CACHEBUST"
2222
COPY ./common /app/detectors/common
2323
COPY ./built_in/ /app
24+
ENV PROMETHEUS_MULTIPROC_DIR="/tmp/prometheus_multiproc_dir"
25+
RUN mkdir -p $PROMETHEUS_MULTIPROC_DIR && chmod 777 $PROMETHEUS_MULTIPROC_DIR
2426

2527
EXPOSE 8080
2628

2729
# for backwards compatibility with existing k8s deployment configs
2830
RUN mkdir /app/bin &&\
29-
echo '#!/bin/bash' > /app/bin/regex-detector &&\
31+
echo '#!/bin/bash' > /app/bin/regex-detector &&\
3032
echo "uvicorn app:app --workers 4 --host 0.0.0.0 --port 8080 --log-config /app/detectors/common/log_conf.yaml" >> /app/bin/regex-detector &&\
3133
chmod +x /app/bin/regex-detector
3234
CMD ["/app/bin/regex-detector"]

detectors/built_in/app.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,8 @@
77
from custom_detectors_wrapper import CustomDetectorRegistry
88
from file_type_detectors import FileTypeDetectorRegistry
99

10-
from prometheus_fastapi_instrumentator import Instrumentator
11-
from prometheus_client import Gauge
10+
from prometheus_client import generate_latest, CONTENT_TYPE_LATEST, CollectorRegistry, multiprocess
11+
from starlette.responses import Response
1212
from detectors.common.scheme import ContentAnalysisHttpRequest, ContentsAnalysisResponse
1313
from detectors.common.app import DetectorBaseAPI as FastAPI
1414

@@ -21,17 +21,23 @@ async def lifespan(app: FastAPI):
2121
CustomDetectorRegistry()
2222
]:
2323
app.set_detector(detector_registry, detector_registry.registry_name)
24-
detector_registry.add_instruments(app.state.instruments)
24+
detector_registry.set_instruments(app.state.instruments)
2525
yield
2626
app.cleanup_detector()
2727

2828

2929
app = FastAPI(lifespan=lifespan)
30-
Instrumentator().instrument(app).expose(app)
3130
logging.basicConfig(level=logging.INFO)
3231
logger = logging.getLogger(__name__)
3332

3433

34+
@app.get("/metrics")
35+
def metrics():
36+
registry = CollectorRegistry()
37+
multiprocess.MultiProcessCollector(registry)
38+
data = generate_latest(registry)
39+
return Response(data, media_type=CONTENT_TYPE_LATEST)
40+
3541
@app.post("/api/v1/text/contents", response_model=ContentsAnalysisResponse)
3642
def detect_content(request: ContentAnalysisHttpRequest, raw_request: Request):
3743
logger.info(f"Request for {request.detector_params}")
Lines changed: 31 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -1,41 +1,36 @@
11
"""
2-
This is an example custom_detectors.py file. Here, you can define any arbitrary Python code as a
3-
Guardrail detector.
2+
This is an example custom_detectors.py file. Overwrite this file to define custom guardrailing
3+
logic!
44
5-
The following rules apply:
6-
1) Each function defined in this file (except for those starting with "_") will be registered as a detector
7-
2) Functions that accept a parameter "headers" will receive the inbound request headers as a parameter
8-
3) Functions may either return a boolean or a dict:
9-
3a) Return values that evaluate to false (e.g., {}, "", None, etc) are treated as non-detections
10-
3b) Boolean responses of "true" are considered a detection
11-
3c) Dict response must be parseable as a ContentAnalysisResponse object (see example below)
12-
4) This code may not import "os", "subprocess", "sys", or "shutil" for security reasons
13-
5) This code may not call "eval", "exec", "open", "compile", or "input" for security reasons
5+
See [docs/custom_detectors.md](../../docs/custom_detectors.md) for more details.
146
"""
157

16-
# example boolean-returning function
17-
def over_100_characters(text: str) -> bool:
18-
return len(text)>100
19-
20-
# example dict-returning function
21-
def contains_word(text: str) -> dict:
22-
detection = "apple" in text.lower()
23-
if detection:
24-
detection_position = text.find("apple")
25-
return {
26-
"start":detection_position, # start position of detection in text
27-
"end": detection_position+5, # end position of detection in text
28-
"text": text, # "the flagged text, or some arbitrary message to return to the user"
29-
"detection_type": "content_check", #detection_type -> use these fields to define your detector taxonomy as you see fit
30-
"detection": "forbidden_word: apple", ##detection -> use these fields to define your detector taxonomy as you see fit
31-
"score": 1.0 # the score/severity/probability of the detection
32-
}
33-
else:
34-
return {}
35-
36-
def _this_function_will_not_be_exposed():
37-
pass
38-
39-
def function_that_needs_headers(text: str, headers: dict) -> bool:
40-
return headers['magic-key'] != "123"
8+
import time
9+
def slow_func(text: str) -> bool:
10+
time.sleep(.25)
11+
return False
12+
13+
from prometheus_client import Counter
4114

15+
prompt_rejection_counter = Counter(
16+
"trustyai_guardrails_system_prompt_rejections",
17+
"Number of rejections by the system prompt",
18+
)
19+
20+
@use_instruments(instruments=[prompt_rejection_counter])
21+
def has_metrics(text: str) -> bool:
22+
if "sorry" in text:
23+
prompt_rejection_counter.inc()
24+
return False
25+
26+
background_metric = Counter(
27+
"trustyai_guardrails_background_metric",
28+
"Runs some logic in the background without blocking the /detections call"
29+
)
30+
@use_instruments(instruments=[background_metric])
31+
@non_blocking(return_value=False)
32+
def background_function(text: str) -> bool:
33+
time.sleep(.25)
34+
if "sorry" in text:
35+
background_metric.inc()
36+
return False

detectors/built_in/custom_detectors_wrapper.py

Lines changed: 84 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,71 @@
11
import ast
2+
import logging
3+
import importlib.util
4+
import inspect
5+
import functools
26
import os
3-
import traceback
7+
import sys
48

9+
from concurrent.futures import ThreadPoolExecutor
510
from fastapi import HTTPException
6-
import inspect
7-
import logging
811
from typing import List, Optional, Callable
912

10-
1113
from base_detector_registry import BaseDetectorRegistry
14+
from detectors.common.app import METRIC_PREFIX
1215
from detectors.common.scheme import ContentAnalysisResponse
1316

1417
logger = logging.getLogger(__name__)
1518

19+
def use_instruments(instruments: List):
    """Decorator factory that attaches Prometheus instruments to a custom detector.

    The instruments are stored in the ``prometheus_instruments`` attribute of the
    innermost (undecorated) function, so they can be found regardless of the
    order in which decorators are stacked.

    Args:
        instruments: Prometheus instruments used by the decorated function.
            ``None`` or an empty list attaches nothing new.

    Returns:
        A decorator that returns a transparent pass-through wrapper around the
        decorated function.
    """
    # Snapshot the caller's list so that mutating it after decoration cannot
    # silently change which instruments are attached.
    requested = list(instruments) if instruments else []

    def inner_layer_1(func):
        @functools.wraps(func)
        def inner_layer_2(*args, **kwargs):
            return func(*args, **kwargs)

        # Follow the ``__wrapped__`` chain so the attribute always lands on the
        # original function, even if ``func`` is already decorated.
        target = inspect.unwrap(func)

        # Merge with anything attached by an earlier ``use_instruments`` on the
        # same function instead of overwriting it. Instruments are compared by
        # identity because they are not assumed to be hashable or comparable.
        merged = list(getattr(target, "prometheus_instruments", None) or [])
        for instrument in requested:
            if not any(instrument is seen for seen in merged):
                merged.append(instrument)

        setattr(target, "prometheus_instruments", merged)
        return inner_layer_2
    return inner_layer_1
31+
32+
def non_blocking(return_value):
    """Decorator factory that runs a guardrail in a background thread.

    The decorated function returns ``return_value`` immediately to the caller of
    /api/v1/text/contents, while the wrapped logic is submitted to a thread pool
    shared by all ``non_blocking`` functions. The wrapped function's own return
    value is discarded.

    Args:
        return_value: The value handed back to the caller on every invocation.

    Returns:
        A decorator producing the fire-and-forget wrapper.
    """
    def inner_layer_1(func):
        # Create the shared pool once at decoration time rather than on first
        # call, where concurrent first calls could each build their own pool.
        # Constructing the executor does not start any threads.
        if getattr(non_blocking, "_executor", None) is None:
            non_blocking._executor = ThreadPoolExecutor()

        @functools.wraps(func)
        def inner_layer_2(*args, **kwargs):
            executor = getattr(non_blocking, "_executor", None)
            if executor is None:
                # Fallback in case the shared pool was cleared after decoration.
                executor = ThreadPoolExecutor()
                non_blocking._executor = executor

            def runner():
                try:
                    func(*args, **kwargs)
                except Exception as e:
                    # Nothing awaits the future, so an unlogged exception would
                    # vanish; include the traceback to keep failures debuggable.
                    logging.error(
                        f"Exception in non-blocking guardrail {func.__name__}: {e}",
                        exc_info=True,
                    )

            executor.submit(runner)
            # ``functools.wraps`` keeps ``__wrapped__`` pointing at ``func``, so
            # any ``prometheus_instruments`` on the underlying function remain
            # discoverable without extra per-call bookkeeping here.
            return return_value
        return inner_layer_2
    return inner_layer_1
60+
61+
forbidden_names = [use_instruments.__name__, non_blocking.__name__]
62+
63+
def get_underlying_function(func):
    """Return the innermost callable beneath any ``functools.wraps``-style wrappers.

    Follows the ``__wrapped__`` chain until it reaches an object without that
    attribute. An undecorated function is returned unchanged.

    Args:
        func: A possibly decorated callable.

    Returns:
        The original, undecorated callable.

    Raises:
        ValueError: If the ``__wrapped__`` chain contains a cycle.
    """
    # ``inspect.unwrap`` is the standard-library form of this walk; it is
    # iterative and detects cycles instead of recursing without bound.
    return inspect.unwrap(func)
67+
68+
1669
def custom_func_wrapper(func: Callable, func_name: str, s: str, headers: dict) -> Optional[ContentAnalysisResponse]:
1770
"""Convert a some f(text)->bool into a Detector response"""
1871
sig = inspect.signature(func)
@@ -92,17 +145,42 @@ class CustomDetectorRegistry(BaseDetectorRegistry):
92145
def __init__(self):
93146
super().__init__("custom")
94147

148+
# check the imported code for potential security issues
95149
issues = static_code_analysis(module_path = os.path.join(os.path.dirname(__file__), "custom_detectors", "custom_detectors.py"))
96150
if issues:
97151
logging.error(f"Detected {len(issues)} potential security issues inside the custom_detectors file: {issues}")
98152
raise ImportError(f"Unsafe code detected in custom_detectors:\n" + "\n".join(issues))
99153

100-
import custom_detectors.custom_detectors as custom_detectors
154+
# grab custom detectors module
155+
module_path = os.path.join(os.path.dirname(__file__), "custom_detectors", "custom_detectors.py")
156+
spec = importlib.util.spec_from_file_location("custom_detectors.custom_detectors", module_path)
157+
custom_detectors = importlib.util.module_from_spec(spec)
158+
159+
# inject any user utility functions into the code automatically
160+
inject_imports = {
161+
"use_instruments": use_instruments,
162+
"non_blocking": non_blocking,
163+
}
164+
for name, mod in inject_imports.items():
165+
setattr(custom_detectors, name, mod)
166+
167+
# load the module
168+
sys.modules["custom_detectors.custom_detectors"] = custom_detectors
169+
spec.loader.exec_module(custom_detectors)
101170

102171
self.registry = {name: obj for name, obj
103172
in inspect.getmembers(custom_detectors, inspect.isfunction)
104-
if not name.startswith("_")}
173+
if not name.startswith("_") and name not in forbidden_names}
105174
self.function_needs_headers = {name: "headers" in inspect.signature(obj).parameters for name, obj in self.registry.items() }
175+
176+
# check if functions have requested user prometheus metrics
177+
for name, func in self.registry.items():
178+
target = get_underlying_function(func)
179+
if getattr(target, "prometheus_instruments", False):
180+
instruments = target.prometheus_instruments
181+
for instrument in instruments:
182+
super().add_instrument(instrument)
183+
106184
logger.info(f"Registered the following custom detectors: {self.registry.keys()}")
107185

108186

detectors/common/app.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
import yaml
88
from fastapi.exceptions import RequestValidationError
99
from fastapi.responses import JSONResponse
10-
from prometheus_client import Counter
10+
from prometheus_client import Counter, CollectorRegistry
1111

1212
import logging
1313

@@ -28,29 +28,30 @@
2828
dependencies=[],
2929
)
3030

31+
METRIC_PREFIX = "trustyai_guardrails"
3132

3233
class DetectorBaseAPI(FastAPI):
3334
def __init__(self, *args, **kwargs):
3435
super().__init__(*args, **kwargs)
3536
self.state.detectors = {}
3637
self.state.instruments = {
3738
"detections": Counter(
38-
"trustyai_guardrails_detections",
39+
f"{METRIC_PREFIX}_detections",
3940
"Number of detections per detector function",
4041
["detector_kind", "detector_name"]
4142
),
4243
"requests": Counter(
43-
"trustyai_guardrails_requests",
44+
f"{METRIC_PREFIX}_requests",
4445
"Number of requests per detector function",
4546
["detector_kind", "detector_name"]
4647
),
4748
"errors": Counter(
48-
"trustyai_guardrails_errors",
49+
f"{METRIC_PREFIX}_errors",
4950
"Number of errors per detector function",
5051
["detector_kind", "detector_name"]
5152
),
5253
"runtime": Counter(
53-
"trustyai_guardrails_runtime",
54+
f"{METRIC_PREFIX}_runtime",
5455
"Total runtime of a detector function- this is the induced latency of this guardrail",
5556
["detector_kind", "detector_name"]
5657
)
@@ -62,7 +63,6 @@ def __init__(self, *args, **kwargs):
6263
self.add_exception_handler(StarletteHTTPException, self.http_exception_handler)
6364
self.add_api_route("/health", health, description="Check if server is alive")
6465

65-
6666
async def validation_exception_handler(self, request, exc):
6767
errors = exc.errors()
6868
if len(errors) > 0 and errors[0]["type"] == "missing":

detectors/common/instrumented_detector.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,11 @@ def instrument_runtime(self, function_name: str):
1717
finally:
1818
pass
1919

20-
def add_instruments(self, gauges):
21-
self.instruments = gauges
20+
def set_instruments(self, instruments):
21+
self.instruments = instruments
22+
23+
def add_instrument(self, instrument):
24+
self.instruments[instrument._name] = instrument
2225

2326
def increment_detector_instruments(self, function_name: str, is_detection: bool):
2427
"""Increment the detection and request counters, automatically update rates"""

detectors/huggingface/app.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
async def lifespan(app: FastAPI):
1717
detector = Detector()
1818
app.set_detector(detector, detector.model_name)
19-
detector.add_instruments(app.state.instruments)
19+
detector.set_instruments(app.state.instruments)
2020
yield
2121
# Clean up the ML models and release the resources
2222
detector: Detector = app.get_detector()

docs/custom_detectors.md

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
# Custom Detectors
2+
You can overwrite [custom_detectors.py](../detectors/built_in/custom_detectors/custom_detectors.py) to create
3+
custom detectors in the built-in server based off of arbitrary Python code. This lets you quickly and flexibly
4+
create your own detection logic!
5+
6+
The following rules apply:
7+
1) Each function defined in the `custom_detectors.py` file (except for those starting with "_") will be registered as a detector
8+
2) Functions that accept a parameter `headers` will receive the inbound request headers as a parameter
9+
* see the `function_that_needs_headers` example in [custom_detectors.py](../detectors/built_in/custom_detectors/custom_detectors.py) for usage
10+
3) Functions that are intended to be used as detectors must either return a `bool` or a `dict`:
11+
1) Return values that evaluate to false (e.g., `{}`, `""`, `None`, etc) are treated as non-detections
12+
2) Boolean responses of `True` are considered a detection
13+
* see the `over_100_characters` example in [custom_detectors.py](../detectors/built_in/custom_detectors/custom_detectors.py) for usage
14+
3) Dict responses that are parseable as a `ContentAnalysisResponse` object are considered a detection
15+
* see the `contains_word` example in [custom_detectors.py](../detectors/built_in/custom_detectors/custom_detectors.py) for usage
16+
4) This code may not import `os`, `subprocess`, `sys`, or `shutil` for security reasons
17+
5) This code may not call `eval`, `exec`, `open`, `compile`, or `input` for security reasons
18+
19+
20+
## Utility Decorators
21+
The following decorators are also available, and are automatically imported into the custom_detectors.py file:
22+
23+
### `@use_instruments(instruments=[$INSTRUMENT_1, ..., $INSTRUMENT_N])`
24+
Use this decorator to register your own Prometheus instruments with the server's main
25+
`/metrics` registry. See the `has_metrics` example
26+
in [custom_detectors.py](../detectors/built_in/custom_detectors/custom_detectors.py) for usage.
27+
28+
### `@non_blocking(return_value=$RETURN_VALUE)`
29+
Use this decorator to indicate that the logic inside this function should run in a non-blocking
30+
background thread. The guardrail function will immediately return $RETURN_VALUE while launching
31+
your function logic into a background thread.
32+
33+
This enables a number of use-cases, such as:
34+
* Producing some background analysis metric over the input/output, without adding latency to the system
35+
* Performing "silent" guardrailing, e.g., adding information to a server or alerting admins
36+
37+
See the `background_function` example in
38+
[custom_detectors.py](../detectors/built_in/custom_detectors/custom_detectors.py) for usage.
39+
40+
## More Examples
41+
For a "real-world" example, check out the [TrustyAI custom detectors demo](https://github.com/trustyai-explainability/trustyai-llm-demo/blob/main/custom-detectors/custom_detectors.py)!

tests/conftest.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
import os
2+
import shutil
23
import sys
34
import pytest
5+
import tempfile
46

57

68
@pytest.fixture(autouse=True)
@@ -24,3 +26,22 @@ def setup_imports():
2426
sys.path.insert(0, path)
2527
print(f"Added to sys.path: {path}")
2628

29+
@pytest.fixture(scope="session", autouse=True)
30+
def prometheus_multiproc_dir():
31+
"""
32+
Create a temporary directory for PROMETHEUS_MULTIPROC_DIR and set the environment variable.
33+
"""
34+
tmpdir = tempfile.mkdtemp(prefix="prometheus_multiproc_")
35+
os.environ["PROMETHEUS_MULTIPROC_DIR"] = tmpdir
36+
yield tmpdir
37+
# Cleanup will be handled by the next fixture
38+
39+
@pytest.fixture(scope="session", autouse=True)
40+
def cleanup_prometheus_multiproc_dir(request, prometheus_multiproc_dir):
41+
"""
42+
Cleanup the PROMETHEUS_MULTIPROC_DIR after the test session.
43+
"""
44+
yield
45+
shutil.rmtree(prometheus_multiproc_dir, ignore_errors=True)
46+
if "PROMETHEUS_MULTIPROC_DIR" in os.environ:
47+
del os.environ["PROMETHEUS_MULTIPROC_DIR"]

0 commit comments

Comments
 (0)