|
| 1 | +import ast |
| 2 | +import os |
| 3 | + |
| 4 | +from fastapi import HTTPException |
| 5 | +import inspect |
| 6 | +import logging |
| 7 | +from typing import List, Optional, Callable |
| 8 | + |
| 9 | + |
| 10 | +from base_detector_registry import BaseDetectorRegistry |
| 11 | +from detectors.common.scheme import ContentAnalysisResponse |
| 12 | + |
| 13 | +logger = logging.getLogger(__name__) |
| 14 | + |
def custom_func_wrapper(func: Callable, func_name: str, s: str) -> Optional[ContentAnalysisResponse]:
    """Run a custom detector function and convert its result into a detector response.

    Args:
        func: Custom detector callable, invoked as ``func(s)``.
        func_name: Name of the detector. Used as ``detection_type`` and
            ``detection`` for boolean results and in log/error messages.
        s: The text to analyse.

    Returns:
        ``None`` when ``func`` returns a falsy value (no detection).
        For ``True``, a response spanning the whole of ``s`` with score 1.0.
        For a non-empty dict, ``ContentAnalysisResponse(**result)``.
        A ``ContentAnalysisResponse`` instance is returned unchanged.

    Raises:
        TypeError: If ``func`` returns a truthy value of any other type.
        Exception: Anything raised by ``func`` itself, or while building the
            response from a dict, is logged and re-raised.
    """
    try:
        result = func(s)
    except Exception as e:
        logger.error(f"Error when computing custom detector function {func_name}: {e}")
        raise

    # Any falsy result (False, None, empty dict, ...) means "nothing detected".
    if not result:
        return None

    if isinstance(result, bool):
        # A bare True flags the entire input text.
        return ContentAnalysisResponse(
            start=0,
            end=len(s),
            text=s,
            detection_type=func_name,
            detection=func_name,
            score=1.0,
        )

    if isinstance(result, dict):
        try:
            return ContentAnalysisResponse(**result)
        except Exception as e:
            logger.error(f"Error when trying to build ContentAnalysisResponse from {func_name} response: {e}")
            raise

    if isinstance(result, ContentAnalysisResponse):
        # Already in the target shape; nothing to convert.
        return result

    msg = (
        f"Unsupported result type for custom detector function {func_name}, "
        f"must be bool, dict or ContentAnalysisResponse, got: {type(result)}"
    )
    logger.error(msg)
    raise TypeError(msg)
| 43 | + |
| 44 | + |
def _dotted_call_name(func_node):
    """Return the dotted name of a call target, e.g. ``pkg.mod.fn`` for ``pkg.mod.fn()``."""
    parts = []
    while isinstance(func_node, ast.Attribute):
        parts.append(func_node.attr)
        func_node = func_node.value
    # A root that is not a plain name (call result, subscript, ...) contributes
    # an empty prefix, so ``foo().bar()`` yields ".bar" and ``f()()`` yields "".
    parts.append(func_node.id if isinstance(func_node, ast.Name) else "")
    return ".".join(reversed(parts))


def static_code_analysis(module_path, forbidden_imports=None, forbidden_calls=None):
    """Statically scan a Python module for forbidden imports and function calls.

    The file is parsed with ``ast`` and never executed.

    NOTE: this is a best-effort check, not a sandbox. It only matches literal
    import statements and call names, so dynamic constructs (for example
    ``__import__("os")`` or ``getattr``-based lookups) are not detected by the
    default rules.

    Args:
        module_path: Path of the Python source file to inspect.
        forbidden_imports: Top-level module names that must not be imported.
            Defaults to ``{"os", "subprocess", "sys", "shutil"}``.
        forbidden_calls: Call names that must not appear, either bare
            (``"eval"``) or dotted (``"pkg.mod.fn"``). Defaults to
            ``{"eval", "exec", "open", "compile", "input"}``.

    Returns:
        A list of human-readable issue strings; empty when nothing was found.
        If the file cannot be parsed, the list holds a single
        "Failed to parse" entry.
    """
    if forbidden_imports is None:
        forbidden_imports = {"os", "subprocess", "sys", "shutil"}
    if forbidden_calls is None:
        forbidden_calls = {"eval", "exec", "open", "compile", "input"}

    issues = []
    # Read raw bytes and let the parser decode them: this honours a PEP 263
    # coding cookie / BOM and does not depend on the process locale.
    with open(module_path, "rb") as f:
        source = f.read()
    try:
        tree = ast.parse(source, filename=module_path)
    except Exception as e:
        # Deliberately broad: any failure to parse is reported as an issue so
        # the caller does not treat an unparseable file as clean.
        issues.append(f"Failed to parse {module_path}: {e}")
        return issues

    for node in ast.walk(tree):
        if isinstance(node, ast.Import):
            for alias in node.names:
                # Only the top-level package matters: "os.path" -> "os".
                if alias.name.split(".")[0] in forbidden_imports:
                    issues.append(f"- Forbidden import: {alias.name} (line {node.lineno})")
        elif isinstance(node, ast.ImportFrom):
            # node.module is None for "from . import x"; nothing to check then.
            if node.module and node.module.split(".")[0] in forbidden_imports:
                issues.append(f"- Forbidden import: {node.module} (line {node.lineno})")
        elif isinstance(node, ast.Call):
            func_name = _dotted_call_name(node.func)
            if func_name in forbidden_calls:
                issues.append(f"- Forbidden function call: {func_name} (line {node.lineno})")
    return issues
| 83 | + |
| 84 | + |
class CustomDetectorRegistry(BaseDetectorRegistry):
    """Registry that exposes the public functions of ``custom_detectors.custom_detectors`` as detectors."""

    def __init__(self):
        """Vet the custom detectors module statically, import it, and register its functions.

        Raises:
            ImportError: If the static analysis reports any issue in the
                custom detectors file.
        """
        super().__init__()

        # The check must run before the import below, because importing the
        # module executes its top-level code.
        module_path = os.path.join(os.path.dirname(__file__), "custom_detectors", "custom_detectors.py")
        issues = static_code_analysis(module_path=module_path)
        if issues:
            logger.error(f"Detected {len(issues)} potential security issues inside the custom_detectors file: {issues}")
            raise ImportError("Unsafe code detected in custom_detectors:\n" + "\n".join(issues))

        import custom_detectors.custom_detectors as custom_detectors

        # Every function visible in the module whose name does not start with
        # an underscore becomes a detector, keyed by its name.
        self.registry = {
            name: obj
            for name, obj in inspect.getmembers(custom_detectors, inspect.isfunction)
            if not name.startswith("_")
        }
        logger.info(f"Registered the following custom detectors: {self.registry.keys()}")

    def handle_request(self, content: str, detector_params: dict) -> List[ContentAnalysisResponse]:
        """Run the requested custom detectors against ``content``.

        Args:
            content: The text to analyse.
            detector_params: Request parameters. The ``"custom"`` entry names
                the detector(s) to run, as a single string or a list. When it
                is missing or has another type, no detector is run.

        Returns:
            The non-``None`` responses of the requested detectors, in request order.

        Raises:
            HTTPException: With status 400 when a requested name is not
                registered, or when a detector fails.
        """
        detections = []
        requested = detector_params.get("custom")
        if not isinstance(requested, (list, str)):
            return detections

        names = [requested] if isinstance(requested, str) else requested
        for name in names:
            # Registry keys are strings. Non-string entries can be unhashable
            # (e.g. a dict), which would make the lookup itself raise, so they
            # are rejected as unrecognised instead.
            detector = self.registry.get(name) if isinstance(name, str) else None
            if detector is None:
                raise HTTPException(status_code=400, detail=f"Unrecognized custom function: {name}")
            try:
                result = custom_func_wrapper(detector, name, content)
            except Exception as e:
                logger.error(e)
                raise HTTPException(status_code=400, detail="Detection error, check detector logs") from e
            if result is not None:
                detections.append(result)
        return detections
| 118 | + |
0 commit comments