Skip to content

Commit a86f66a

Browse files
committed
Add custom detectors to the built in detector, with code scanning
1 parent 5258196 commit a86f66a

File tree

7 files changed

+240
-9
lines changed

7 files changed

+240
-9
lines changed

detectors/Dockerfile.builtIn

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ WORKDIR /app
2020
ARG CACHEBUST=1
2121
RUN echo "$CACHEBUST"
2222
COPY ./common /app/detectors/common
23-
COPY ./built_in/* /app
23+
COPY ./built_in/ /app
2424

2525
EXPOSE 8080
2626

detectors/built_in/app.py

Lines changed: 14 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,40 +1,47 @@
1+
import logging
2+
13
from fastapi import HTTPException
24
from contextlib import asynccontextmanager
35
from base_detector_registry import BaseDetectorRegistry
46
from regex_detectors import RegexDetectorRegistry
7+
from custom_detectors_wrapper import CustomDetectorRegistry
58
from file_type_detectors import FileTypeDetectorRegistry
69

710
from prometheus_fastapi_instrumentator import Instrumentator
811
from detectors.common.scheme import ContentAnalysisHttpRequest, ContentsAnalysisResponse
912
from detectors.common.app import DetectorBaseAPI as FastAPI
1013

14+
1115
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Register the built-in detector registries for the lifetime of the app.

    Before yielding, the regex, file-type and custom registries are
    constructed and attached to ``app`` under the keys ``"regex"``,
    ``"file_type"`` and ``"custom"``. After the application is done,
    ``app.cleanup_detector()`` is called.

    Args:
        app: The application instance the detectors are attached to.
    """
    app.set_detector(RegexDetectorRegistry(), "regex")
    app.set_detector(FileTypeDetectorRegistry(), "file_type")
    app.set_detector(CustomDetectorRegistry(), "custom")
    try:
        yield
    finally:
        # Run cleanup even when an exception (or task cancellation) is thrown
        # into the context manager at the yield point; without the finally
        # clause the detectors would never be cleaned up in that case.
        app.cleanup_detector()
1823

1924

2025
app = FastAPI(lifespan=lifespan)
2126
Instrumentator().instrument(app).expose(app)
27+
logging.basicConfig(level=logging.INFO)
28+
logger = logging.getLogger(__name__)
2229

2330

24-
# registry : dict[str, BaseDetectorRegistry] = {
25-
# "regex": RegexDetectorRegistry(),
26-
# "file_type": FileTypeDetectorRegistry(),
27-
# }
28-
2931
@app.post("/api/v1/text/contents", response_model=ContentsAnalysisResponse)
3032
def detect_content(request: ContentAnalysisHttpRequest):
33+
logger.info(f"Request for {request.detector_params}")
34+
3135
detections = []
3236
for content in request.contents:
3337
message_detections = []
34-
for detector_kind, detector_registry in app.get_all_detectors().items():
38+
for detector_kind in request.detector_params:
39+
detector_registry = app.get_all_detectors().get(detector_kind)
40+
if detector_registry is None:
41+
raise HTTPException(status_code=400, detail=f"Detector {detector_kind} not found")
3542
if not isinstance(detector_registry, BaseDetectorRegistry):
3643
raise TypeError(f"Detector {detector_kind} is not a valid BaseDetectorRegistry")
37-
if detector_kind in request.detector_params:
44+
else:
3845
try:
3946
message_detections += detector_registry.handle_request(content, request.detector_params)
4047
except HTTPException as e:

detectors/built_in/custom_detectors/__init__.py

Whitespace-only changes.
Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
2+
def over_100_characters(text: str) -> bool:
    """Return True when *text* contains more than 100 characters."""
    max_length = 100
    return len(text) > max_length
4+
5+
def contains_word(text: str, word: str = "apple") -> bool:
    """Return True when *word* occurs in *text*, ignoring case.

    The match is a plain substring test on the lower-cased inputs, so
    ``"pineapple"`` matches the default word ``"apple"``.

    Args:
        text: The text to inspect.
        word: The substring to look for. Defaults to ``"apple"``, which keeps
            the single-argument call ``contains_word(text)`` working as before.

    Returns:
        True if the lower-cased *word* is a substring of the lower-cased *text*.
    """
    return word.lower() in text.lower()
Lines changed: 118 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,118 @@
1+
import ast
2+
import os
3+
4+
from fastapi import HTTPException
5+
import inspect
6+
import logging
7+
from typing import List, Optional, Callable
8+
9+
10+
from base_detector_registry import BaseDetectorRegistry
11+
from detectors.common.scheme import ContentAnalysisResponse
12+
13+
logger = logging.getLogger(__name__)
14+
15+
def custom_func_wrapper(func: Callable, func_name: str, s: str) -> Optional[ContentAnalysisResponse]:
    """Convert the result of a custom detector ``func(text)`` into a detector response.

    Args:
        func: The custom detector function; it is called with ``s`` as its
            only argument.
        func_name: Name of the function, used as the detection label and in
            log/error messages.
        s: The text to analyse.

    Returns:
        ``None`` when ``func`` returns a falsy value (no detection). Otherwise:

        * ``True`` produces a response spanning the whole text, labelled with
          ``func_name`` and scored ``1.0``;
        * a ``ContentAnalysisResponse`` is returned unchanged;
        * a non-empty ``dict`` is expanded into ``ContentAnalysisResponse(**result)``.

    Raises:
        TypeError: If ``func`` returns a truthy value of any other type.
        Exception: Anything raised by ``func`` itself, or by building the
            response from a returned dict, is logged and re-raised unchanged.
    """
    try:
        result = func(s)
    except Exception as e:
        logger.error(f"Error when computing custom detector function {func_name}: {e}")
        raise

    # Any falsy result (False, None, empty dict, ...) means "nothing detected".
    if not result:
        return None

    if isinstance(result, bool):
        return ContentAnalysisResponse(
            start=0,
            end=len(s),
            text=s,
            detection_type=func_name,
            detection=func_name,
            score=1.0,
        )

    # A ready-made response needs no conversion; previously this was rejected
    # even though the error message below advertised it as supported.
    if isinstance(result, ContentAnalysisResponse):
        return result

    if isinstance(result, dict):
        try:
            return ContentAnalysisResponse(**result)
        except Exception as e:
            logger.error(f"Error when trying to build ContentAnalysisResponse from {func_name} response: {e}")
            raise

    msg = (
        f"Unsupported result type for custom detector function {func_name}, "
        f"must be bool, dict or ContentAnalysisResponse, got: {type(result)}"
    )
    logger.error(msg)
    raise TypeError(msg)
43+
44+
45+
def _dotted_call_name(func):
    """Return the dotted name of a call target, e.g. ``os.path.join``.

    The leading component is empty when the chain does not start at a plain
    name (``"x".lower()`` gives ``".lower"``), and the whole result is ``""``
    when the target is neither a name nor an attribute access.
    """
    parts = []
    while isinstance(func, ast.Attribute):
        parts.append(func.attr)
        func = func.value
    parts.append(func.id if isinstance(func, ast.Name) else "")
    return ".".join(reversed(parts))


def static_code_analysis(module_path, forbidden_imports=None, forbidden_calls=None):
    """
    Perform static code analysis on a Python module to check for forbidden imports and function calls.

    This is a best-effort syntactic check, not a sandbox: it only sees imports
    and calls that are spelled out literally in the source.
    NOTE(review): with the default sets, dynamic imports such as
    ``__import__("os")`` or ``importlib.import_module("os")`` and aliased
    builtins (``e = eval; e(...)``) are not reported.

    Args:
        module_path: Path of the Python source file to inspect.
        forbidden_imports: Top-level module names that may not be imported.
            Defaults to ``{"os", "subprocess", "sys", "shutil"}``.
        forbidden_calls: Names of calls that may not appear. Entries may be
            bare names (``"eval"``) or dotted names (``"os.path.exists"``).
            Defaults to ``{"eval", "exec", "open", "compile", "input"}``.

    Returns:
        A list of human-readable issue strings; empty when nothing was found.
        If the file cannot be parsed, the list contains a single
        ``"Failed to parse ..."`` entry.
    """
    if forbidden_imports is None:
        forbidden_imports = {"os", "subprocess", "sys", "shutil"}
    if forbidden_calls is None:
        forbidden_calls = {"eval", "exec", "open", "compile", "input"}

    # Read raw bytes and let the parser decode them: this honours a PEP 263
    # encoding cookie and turns undecodable input into a reported parse
    # failure instead of a locale-dependent UnicodeDecodeError.
    with open(module_path, "rb") as f:
        source = f.read()
    try:
        tree = ast.parse(source, filename=module_path)
    except Exception as e:
        # Deliberately broad: whatever stops us from analysing the file must
        # surface as an issue so the caller refuses to load it.
        return [f"Failed to parse {module_path}: {e}"]

    issues = []
    builtins_prefix = "builtins."
    for node in ast.walk(tree):
        if isinstance(node, ast.Import):
            for alias in node.names:
                if alias.name.split(".")[0] in forbidden_imports:
                    issues.append(f"- Forbidden import: {alias.name} (line {node.lineno})")
        elif isinstance(node, ast.ImportFrom):
            # node.module is None for "from . import x"
            if node.module and node.module.split(".")[0] in forbidden_imports:
                issues.append(f"- Forbidden import: {node.module} (line {node.lineno})")
        elif isinstance(node, ast.Call):
            func_name = _dotted_call_name(node.func)
            # builtins.eval(...) is the same call as eval(...), so check the
            # bare name as well.
            bare_name = func_name[len(builtins_prefix):] if func_name.startswith(builtins_prefix) else func_name
            if func_name in forbidden_calls or bare_name in forbidden_calls:
                issues.append(f"- Forbidden function call: {func_name} (line {node.lineno})")
    return issues
83+
84+
85+
class CustomDetectorRegistry(BaseDetectorRegistry):
    """Registry exposing the functions of ``custom_detectors/custom_detectors.py`` as detectors.

    On construction the module's source is statically analysed; it is only
    imported when no issues are reported. Every function found in the imported
    module whose name does not start with an underscore is registered under
    its own name.
    """

    def __init__(self):
        """Scan, import and register the custom detector functions.

        Raises:
            ImportError: If the static analysis reports any issue for the
                custom detectors file.
        """
        super().__init__()

        module_path = os.path.join(os.path.dirname(__file__), "custom_detectors", "custom_detectors.py")
        issues = static_code_analysis(module_path=module_path)
        if issues:
            logger.error(f"Detected {len(issues)} potential security issues inside the custom_detectors file: {issues}")
            raise ImportError("Unsafe code detected in custom_detectors:\n" + "\n".join(issues))

        # Imported here rather than at module level so that the module's code
        # only runs after it has passed the static analysis above.
        import custom_detectors.custom_detectors as custom_detectors

        self.registry = {
            name: obj
            for name, obj in inspect.getmembers(custom_detectors, inspect.isfunction)
            if not name.startswith("_")
        }
        logger.info(f"Registered the following custom detectors: {self.registry.keys()}")

    def handle_request(self, content: str, detector_params: dict) -> List[ContentAnalysisResponse]:
        """Run the requested custom detector functions against ``content``.

        Args:
            content: The text to analyse.
            detector_params: Request parameters; the ``"custom"`` entry names
                the function(s) to run, either as one string or as a list.
                Any other value type (or a missing entry) yields no detections.

        Returns:
            The non-``None`` responses produced by the requested functions, in
            request order.

        Raises:
            HTTPException: 400 when a requested function is not registered, or
                when running a function (or converting its result) fails.
        """
        requested = detector_params.get("custom")
        if isinstance(requested, str):
            requested = [requested]
        elif not isinstance(requested, list):
            return []

        detections = []
        for custom_function in requested:
            # Only strings can name a registered function. Checking the type
            # first also keeps unhashable entries (e.g. a nested list) from
            # raising TypeError in the dict lookup and surfacing as a 500.
            func = self.registry.get(custom_function) if isinstance(custom_function, str) else None
            if func is None:
                raise HTTPException(status_code=400, detail=f"Unrecognized custom function: {custom_function}")
            try:
                result = custom_func_wrapper(func, custom_function, content)
            except Exception as e:
                logger.error(e)
                raise HTTPException(status_code=400, detail="Detection error, check detector logs") from e
            if result is not None:
                detections.append(result)
        return detections
118+
Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
11
markdown==3.8.2
22
jsonschema==4.24.0
3-
xmlschema==4.1.0
3+
xmlschema==4.1.0
4+
requests==2.32.5
Lines changed: 99 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,99 @@
1+
import importlib
2+
import sys
3+
from http.client import HTTPException
4+
5+
import pytest
6+
import os
7+
from fastapi.testclient import TestClient
8+
9+
10+
CUSTOM_DETECTORS_PATH = os.path.join(
11+
os.path.dirname(__file__),
12+
"../../../detectors/built_in/custom_detectors/custom_detectors.py"
13+
)
14+
15+
SAFE_CODE = """
16+
def over_100_characters(text: str) -> bool:
17+
return len(text)>100
18+
19+
def contains_word(text: str) -> bool:
20+
return "apple" in text.lower()
21+
"""
22+
23+
UNSAFE_CODE = '''
24+
import os
25+
def evil(text: str) -> bool:
26+
os.system("echo haha gottem")
27+
return True
28+
'''
29+
30+
31+
def write_code_to_custom_detectors(code: str):
    """Overwrite the custom detectors module at ``CUSTOM_DETECTORS_PATH`` with *code*.

    Args:
        code: Python source text that becomes the new content of the file.
    """
    # Explicit UTF-8 so the written source does not depend on the locale of
    # the machine running the tests.
    with open(CUSTOM_DETECTORS_PATH, "w", encoding="utf-8") as f:
        f.write(code)
34+
35+
def restore_safe_code():
    """Reset the custom detectors module to the ``SAFE_CODE`` baseline source."""
    write_code_to_custom_detectors(code=SAFE_CODE)
37+
38+
39+
class TestCustomDetectors:
    """Tests for the "custom" detector kind served by the built-in detector app."""

    @staticmethod
    def _analyze(client, contents, detector_params):
        """POST an analysis request for *contents* and return the raw response."""
        body = {
            "contents": contents,
            "detector_params": detector_params,
        }
        return client.post("/api/v1/text/contents", json=body)

    @pytest.fixture
    def client(self):
        """Return a test client for the app with a freshly built custom registry attached."""
        from detectors.built_in.app import app
        from detectors.built_in.custom_detectors_wrapper import CustomDetectorRegistry

        registry = CustomDetectorRegistry()
        app.set_detector(registry, "custom")
        return TestClient(app)

    @pytest.fixture(autouse=True)
    def cleanup_custom_detectors(self):
        """Put the safe detector source back once each test has finished."""
        yield
        restore_safe_code()

    def test_missing_detector_type(self, client):
        """An unknown detector kind is rejected with a 400."""
        resp = self._analyze(client, ["What is an apple?"], {"custom1": ["contains_word"]})
        assert resp.status_code == 400
        assert "Detector custom1 not found" in resp.text

    def test_custom_detectors(self, client):
        """A matching custom function reports the analysed text."""
        resp = self._analyze(client, ["What is an apple?"], {"custom": ["contains_word"]})
        assert resp.status_code == 200
        first_content = resp.json()[0]
        texts = [detection["text"] for detection in first_content]
        assert "What is an apple?" in texts

    def test_custom_detectors_not_match(self, client):
        """A non-matching custom function reports nothing for the text."""
        msg = "What is an banana?"
        resp = self._analyze(client, [msg], {"custom": ["contains_word"]})
        assert resp.status_code == 200
        first_content = resp.json()[0]
        texts = [detection["text"] for detection in first_content]
        assert msg not in texts

    def test_unsafe_code(self, client):
        """Building the registry from unsafe source raises ImportError."""
        write_code_to_custom_detectors(UNSAFE_CODE)
        from detectors.built_in.custom_detectors_wrapper import CustomDetectorRegistry

        with pytest.raises(ImportError) as excinfo:
            CustomDetectorRegistry()

        message = str(excinfo.value)
        assert "Unsafe code detected" in message
        assert "Forbidden import: os" in message or "os.system" in message

    def test_custom_detectors_func_doesnt_exist(self, client):
        """Requesting an unregistered custom function is rejected with a 400."""
        resp = self._analyze(client, ["What is an apple?"], {"custom": ["abc"]})
        assert resp.status_code == 400
        assert "Unrecognized custom function: abc" in resp.text
99+

0 commit comments

Comments
 (0)