|
| 1 | +""" |
| 2 | +Thread-safe singleton interface for the zen-internals Rust shared library. |
| 3 | +Provides a ZenInternal class whose instances are all the same object (singleton), |
| 4 | +so the DLL is searched for and dlopen'd exactly once. |
| 5 | +""" |
| 6 | + |
| 7 | +import ctypes |
| 8 | +import json |
| 9 | +import threading |
| 10 | +from .get_lib_path import get_binary_path |
| 11 | +from .map_dialect_to_rust_int import DIALECTS |
| 12 | +from aikido_zen.helpers.encode_safely import encode_safely |
| 13 | + |
| 14 | + |
| 15 | +class _Singleton(type): |
| 16 | + _instances = {} |
| 17 | + _lock = threading.Lock() |
| 18 | + |
| 19 | + def __call__(cls, *args, **kwargs): |
| 20 | + with cls._lock: |
| 21 | + if cls not in cls._instances: |
| 22 | + cls._instances[cls] = super().__call__(*args, **kwargs) |
| 23 | + return cls._instances[cls] |
| 24 | + |
| 25 | + |
| 26 | +class ZenInternal(metaclass=_Singleton): |
| 27 | + """ |
| 28 | + Thread-safe singleton wrapping the zen-internals Rust shared library. |
| 29 | + Sets up all FFI function signatures once on first instantiation. |
| 30 | + """ |
| 31 | + |
| 32 | + def __init__(self): |
| 33 | + lib = ctypes.CDLL(get_binary_path()) |
| 34 | + |
| 35 | + lib.detect_sql_injection.argtypes = [ |
| 36 | + ctypes.POINTER(ctypes.c_uint8), |
| 37 | + ctypes.c_size_t, |
| 38 | + ctypes.POINTER(ctypes.c_uint8), |
| 39 | + ctypes.c_size_t, |
| 40 | + ctypes.c_int, |
| 41 | + ] |
| 42 | + lib.detect_sql_injection.restype = ctypes.c_int |
| 43 | + |
| 44 | + lib.idor_analyze_sql_ffi.argtypes = [ |
| 45 | + ctypes.POINTER(ctypes.c_uint8), |
| 46 | + ctypes.c_size_t, |
| 47 | + ctypes.c_int, |
| 48 | + ] |
| 49 | + lib.idor_analyze_sql_ffi.restype = ctypes.c_void_p |
| 50 | + |
| 51 | + lib.free_string.argtypes = [ctypes.c_void_p] |
| 52 | + lib.free_string.restype = None |
| 53 | + |
| 54 | + self._lib = lib |
| 55 | + |
| 56 | + def detect_sql_injection(self, query, user_input, dialect): |
| 57 | + """Returns 1 (injection), 2 (error), 3 (tokenize fail), or 0 (clean).""" |
| 58 | + query_bytes = encode_safely(query) |
| 59 | + userinput_bytes = encode_safely(user_input) |
| 60 | + query_buffer = (ctypes.c_uint8 * len(query_bytes)).from_buffer_copy(query_bytes) |
| 61 | + userinput_buffer = (ctypes.c_uint8 * len(userinput_bytes)).from_buffer_copy( |
| 62 | + userinput_bytes |
| 63 | + ) |
| 64 | + dialect_int = DIALECTS[dialect] |
| 65 | + return self._lib.detect_sql_injection( |
| 66 | + query_buffer, |
| 67 | + len(query_bytes), |
| 68 | + userinput_buffer, |
| 69 | + len(userinput_bytes), |
| 70 | + dialect_int, |
| 71 | + ) |
| 72 | + |
| 73 | + def idor_analyze_sql(self, query, dialect): |
| 74 | + """ |
| 75 | + Parses a SQL query and returns a list of statement dicts, |
| 76 | + an error dict, or None if the pointer was null. |
| 77 | + """ |
| 78 | + query_bytes = encode_safely(query) |
| 79 | + query_buffer = (ctypes.c_uint8 * len(query_bytes)).from_buffer_copy(query_bytes) |
| 80 | + dialect_int = DIALECTS[dialect] |
| 81 | + |
| 82 | + result_ptr = self._lib.idor_analyze_sql_ffi( |
| 83 | + query_buffer, |
| 84 | + len(query_bytes), |
| 85 | + dialect_int, |
| 86 | + ) |
| 87 | + |
| 88 | + if not result_ptr: |
| 89 | + return None |
| 90 | + |
| 91 | + result_str = ctypes.string_at(result_ptr).decode("utf-8") |
| 92 | + self._lib.free_string(result_ptr) |
| 93 | + return json.loads(result_str) |
0 commit comments