Skip to content

Commit c0551e8

Browse files
committed
Create zen_internals_lib helper, which also has a Singleton
1 parent d2fc13d commit c0551e8

4 files changed

Lines changed: 95 additions & 30 deletions

File tree

Lines changed: 93 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,93 @@
1+
"""
2+
Thread-safe singleton interface for the zen-internals Rust shared library.
3+
Provides a ZenInternal class whose instances are all the same object (singleton),
4+
so the DLL is searched for and dlopen'd exactly once.
5+
"""
6+
7+
import ctypes
8+
import json
9+
import threading
10+
from .get_lib_path import get_binary_path
11+
from .map_dialect_to_rust_int import DIALECTS
12+
from aikido_zen.helpers.encode_safely import encode_safely
13+
14+
15+
class _Singleton(type):
16+
_instances = {}
17+
_lock = threading.Lock()
18+
19+
def __call__(cls, *args, **kwargs):
20+
with cls._lock:
21+
if cls not in cls._instances:
22+
cls._instances[cls] = super().__call__(*args, **kwargs)
23+
return cls._instances[cls]
24+
25+
26+
class ZenInternal(metaclass=_Singleton):
27+
"""
28+
Thread-safe singleton wrapping the zen-internals Rust shared library.
29+
Sets up all FFI function signatures once on first instantiation.
30+
"""
31+
32+
def __init__(self):
33+
lib = ctypes.CDLL(get_binary_path())
34+
35+
lib.detect_sql_injection.argtypes = [
36+
ctypes.POINTER(ctypes.c_uint8),
37+
ctypes.c_size_t,
38+
ctypes.POINTER(ctypes.c_uint8),
39+
ctypes.c_size_t,
40+
ctypes.c_int,
41+
]
42+
lib.detect_sql_injection.restype = ctypes.c_int
43+
44+
lib.idor_analyze_sql_ffi.argtypes = [
45+
ctypes.POINTER(ctypes.c_uint8),
46+
ctypes.c_size_t,
47+
ctypes.c_int,
48+
]
49+
lib.idor_analyze_sql_ffi.restype = ctypes.c_void_p
50+
51+
lib.free_string.argtypes = [ctypes.c_void_p]
52+
lib.free_string.restype = None
53+
54+
self._lib = lib
55+
56+
def detect_sql_injection(self, query, user_input, dialect):
57+
"""Returns 1 (injection), 2 (error), 3 (tokenize fail), or 0 (clean)."""
58+
query_bytes = encode_safely(query)
59+
userinput_bytes = encode_safely(user_input)
60+
query_buffer = (ctypes.c_uint8 * len(query_bytes)).from_buffer_copy(query_bytes)
61+
userinput_buffer = (ctypes.c_uint8 * len(userinput_bytes)).from_buffer_copy(
62+
userinput_bytes
63+
)
64+
dialect_int = DIALECTS[dialect]
65+
return self._lib.detect_sql_injection(
66+
query_buffer,
67+
len(query_bytes),
68+
userinput_buffer,
69+
len(userinput_bytes),
70+
dialect_int,
71+
)
72+
73+
def idor_analyze_sql(self, query, dialect):
74+
"""
75+
Parses a SQL query and returns a list of statement dicts,
76+
an error dict, or None if the pointer was null.
77+
"""
78+
query_bytes = encode_safely(query)
79+
query_buffer = (ctypes.c_uint8 * len(query_bytes)).from_buffer_copy(query_bytes)
80+
dialect_int = DIALECTS[dialect]
81+
82+
result_ptr = self._lib.idor_analyze_sql_ffi(
83+
query_buffer,
84+
len(query_bytes),
85+
dialect_int,
86+
)
87+
88+
if not result_ptr:
89+
return None
90+
91+
result_str = ctypes.string_at(result_ptr).decode("utf-8")
92+
self._lib.free_string(result_ptr)
93+
return json.loads(result_str)

aikido_zen/vulnerabilities/sql_injection/get_lib_path.py renamed to aikido_zen/helpers/zen_internals_lib/get_lib_path.py

File renamed without changes.

aikido_zen/vulnerabilities/sql_injection/map_dialect_to_rust_int.py renamed to aikido_zen/helpers/zen_internals_lib/map_dialect_to_rust_int.py

File renamed without changes.

aikido_zen/vulnerabilities/sql_injection/__init__.py

Lines changed: 2 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,8 @@
33
"""
44

55
import re
6-
import ctypes
76
from aikido_zen.helpers.logging import logger
8-
from .map_dialect_to_rust_int import map_dialect_to_rust_int
9-
from .get_lib_path import get_binary_path
10-
from ...helpers.encode_safely import encode_safely
7+
from aikido_zen.helpers.zen_internals_lib import ZenInternal
118

129

1310
def detect_sql_injection(query, user_input, dialect):
@@ -20,32 +17,7 @@ def detect_sql_injection(query, user_input, dialect):
2017
if should_return_early(query_l, userinput_l):
2118
return False
2219

23-
internals_lib = ctypes.CDLL(get_binary_path())
24-
internals_lib.detect_sql_injection.argtypes = [
25-
ctypes.POINTER(ctypes.c_uint8),
26-
ctypes.c_size_t,
27-
ctypes.POINTER(ctypes.c_uint8),
28-
ctypes.c_size_t,
29-
ctypes.c_int,
30-
]
31-
internals_lib.detect_sql_injection.restype = ctypes.c_int
32-
33-
# Parse input variables for rust function
34-
query_bytes = encode_safely(query_l)
35-
userinput_bytes = encode_safely(userinput_l)
36-
query_buffer = (ctypes.c_uint8 * len(query_bytes)).from_buffer_copy(query_bytes)
37-
userinput_buffer = (ctypes.c_uint8 * len(userinput_bytes)).from_buffer_copy(
38-
userinput_bytes
39-
)
40-
dialect_int = map_dialect_to_rust_int(dialect)
41-
42-
c_int_res = internals_lib.detect_sql_injection(
43-
query_buffer,
44-
len(query_bytes),
45-
userinput_buffer,
46-
len(userinput_bytes),
47-
dialect_int,
48-
)
20+
c_int_res = ZenInternal().detect_sql_injection(query_l, userinput_l, dialect)
4921

5022
# This means that an error occurred in the library
5123
if c_int_res == 2:

0 commit comments

Comments
 (0)