|
1 | 1 | import ast |
| 2 | +import logging |
| 3 | +import importlib.util |
| 4 | +import inspect |
| 5 | +import functools |
2 | 6 | import os |
3 | | -import traceback |
| 7 | +import sys |
4 | 8 |
|
| 9 | +from concurrent.futures import ThreadPoolExecutor |
5 | 10 | from fastapi import HTTPException |
6 | | -import inspect |
7 | | -import logging |
8 | 11 | from typing import List, Optional, Callable |
9 | 12 |
|
10 | | - |
11 | 13 | from base_detector_registry import BaseDetectorRegistry |
| 14 | +from detectors.common.app import METRIC_PREFIX |
12 | 15 | from detectors.common.scheme import ContentAnalysisResponse |
13 | 16 |
|
14 | 17 | logger = logging.getLogger(__name__) |
15 | 18 |
|
| 19 | +def use_instruments(instruments: List): |
| 20 | + """Use this decorator to register the provided Prometheus instruments with the main /metrics endpoint""" |
| 21 | + def inner_layer_1(func): |
| 22 | + @functools.wraps(func) |
| 23 | + def inner_layer_2(*args, **kwargs): |
| 24 | + return func(*args, **kwargs) |
| 25 | + |
| 26 | + # check to see if "func" is already decorated, and only add the prometheus instruments field into the original function |
| 27 | + target = get_underlying_function(func) |
| 28 | + setattr(target, "prometheus_instruments", instruments) |
| 29 | + return inner_layer_2 |
| 30 | + return inner_layer_1 |
| 31 | + |
| 32 | +def non_blocking(return_value): |
| 33 | + """ |
| 34 | + Use this decorator to run the guardrail as a non-blocking background thread. |
| 35 | +
|
| 36 | + The `return_value` is returned instantly to the caller of the /api/v1/text/contents, while |
| 37 | + the logic inside the function will run asynchronously in the background. |
| 38 | + """ |
| 39 | + def inner_layer_1(func): |
| 40 | + @functools.wraps(func) |
| 41 | + def inner_layer_2(*args, **kwargs): |
| 42 | + executor = getattr(non_blocking, "_executor", None) |
| 43 | + if executor is None: |
| 44 | + executor = ThreadPoolExecutor() |
| 45 | + non_blocking._executor = executor |
| 46 | + def runner(): |
| 47 | + try: |
| 48 | + func(*args, **kwargs) |
| 49 | + except Exception as e: |
| 50 | + logging.error(f"Exception in non-blocking guardrail {func.__name__}: {e}") |
| 51 | + executor.submit(runner) |
| 52 | + |
| 53 | + # check to see if "func" is already decorated by `use_instruments`, and grab the instruments if so |
| 54 | + target = get_underlying_function(func) |
| 55 | + if hasattr(target, "prometheus_instruments"): |
| 56 | + setattr(target, "prometheus_instruments", target.prometheus_instruments) |
| 57 | + return return_value |
| 58 | + return inner_layer_2 |
| 59 | + return inner_layer_1 |
| 60 | + |
| 61 | +forbidden_names = [use_instruments.__name__, non_blocking.__name__] |
| 62 | + |
| 63 | +def get_underlying_function(func): |
| 64 | + if hasattr(func, "__wrapped__"): |
| 65 | + return get_underlying_function(func.__wrapped__) |
| 66 | + return func |
| 67 | + |
| 68 | + |
16 | 69 | def custom_func_wrapper(func: Callable, func_name: str, s: str, headers: dict) -> Optional[ContentAnalysisResponse]: |
17 | 70 | """Convert a some f(text)->bool into a Detector response""" |
18 | 71 | sig = inspect.signature(func) |
@@ -92,17 +145,42 @@ class CustomDetectorRegistry(BaseDetectorRegistry): |
92 | 145 | def __init__(self): |
93 | 146 | super().__init__("custom") |
94 | 147 |
|
| 148 | + # check the imported code for potential security issues |
95 | 149 | issues = static_code_analysis(module_path = os.path.join(os.path.dirname(__file__), "custom_detectors", "custom_detectors.py")) |
96 | 150 | if issues: |
97 | 151 | logging.error(f"Detected {len(issues)} potential security issues inside the custom_detectors file: {issues}") |
98 | 152 | raise ImportError(f"Unsafe code detected in custom_detectors:\n" + "\n".join(issues)) |
99 | 153 |
|
100 | | - import custom_detectors.custom_detectors as custom_detectors |
| 154 | + # grab custom detectors module |
| 155 | + module_path = os.path.join(os.path.dirname(__file__), "custom_detectors", "custom_detectors.py") |
| 156 | + spec = importlib.util.spec_from_file_location("custom_detectors.custom_detectors", module_path) |
| 157 | + custom_detectors = importlib.util.module_from_spec(spec) |
| 158 | + |
| 159 | + # inject any user utility functions into the code automatically |
| 160 | + inject_imports = { |
| 161 | + "use_instruments": use_instruments, |
| 162 | + "non_blocking": non_blocking, |
| 163 | + } |
| 164 | + for name, mod in inject_imports.items(): |
| 165 | + setattr(custom_detectors, name, mod) |
| 166 | + |
| 167 | + # load the module |
| 168 | + sys.modules["custom_detectors.custom_detectors"] = custom_detectors |
| 169 | + spec.loader.exec_module(custom_detectors) |
101 | 170 |
|
102 | 171 | self.registry = {name: obj for name, obj |
103 | 172 | in inspect.getmembers(custom_detectors, inspect.isfunction) |
104 | | - if not name.startswith("_")} |
| 173 | + if not name.startswith("_") and name not in forbidden_names} |
105 | 174 | self.function_needs_headers = {name: "headers" in inspect.signature(obj).parameters for name, obj in self.registry.items() } |
| 175 | + |
| 176 | + # check if functions have requested user prometheus metrics |
| 177 | + for name, func in self.registry.items(): |
| 178 | + target = get_underlying_function(func) |
| 179 | + if getattr(target, "prometheus_instruments", False): |
| 180 | + instruments = target.prometheus_instruments |
| 181 | + for instrument in instruments: |
| 182 | + super().add_instrument(instrument) |
| 183 | + |
106 | 184 | logger.info(f"Registered the following custom detectors: {self.registry.keys()}") |
107 | 185 |
|
108 | 186 |
|
|
0 commit comments