Skip to content

Commit c35ad68

Browse files
committed
Make sure the @on_import is always at the end of the file
1 parent 8dacd34 commit c35ad68

5 files changed

Lines changed: 87 additions & 87 deletions

File tree

aikido_zen/sinks/asyncpg.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,14 @@
77
from aikido_zen.sinks import patch_function, before, on_import
88

99

10+
@before
11+
def _execute(func, instance, args, kwargs):
12+
query = get_argument(args, kwargs, 0, "query")
13+
14+
op = f"asyncpg.connection.Connection.{func.__name__}"
15+
vulns.run_vulnerability_scan(kind="sql_injection", op=op, args=(query, "postgres"))
16+
17+
1018
@on_import("asyncpg.connection", "asyncpg", version_requirement="0.27.0")
1119
def patch(m):
1220
"""
@@ -19,11 +27,3 @@ def patch(m):
1927
patch_function(m, "Connection.execute", _execute)
2028
patch_function(m, "Connection.executemany", _execute)
2129
patch_function(m, "Connection._execute", _execute)
22-
23-
24-
@before
25-
def _execute(func, instance, args, kwargs):
26-
query = get_argument(args, kwargs, 0, "query")
27-
28-
op = f"asyncpg.connection.Connection.{func.__name__}"
29-
vulns.run_vulnerability_scan(kind="sql_injection", op=op, args=(query, "postgres"))

aikido_zen/sinks/mysqlclient.py

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -7,17 +7,6 @@
77
from aikido_zen.sinks import patch_function, on_import, before
88

99

10-
@on_import("MySQLdb.cursors", "mysqlclient", version_requirement="1.5.0")
11-
def patch(m):
12-
"""
13-
patching MySQLdb.cursors (mysqlclient)
14-
- patches Cursor.execute(query, ...)
15-
- patches Cursor.executemany(query, ...)
16-
"""
17-
patch_function(m, "Cursor.execute", _execute)
18-
patch_function(m, "Cursor.executemany", _executemany)
19-
20-
2110
@before
2211
def _execute(func, instance, args, kwargs):
2312
query = get_argument(args, kwargs, 0, "query")
@@ -37,3 +26,14 @@ def _executemany(func, instance, args, kwargs):
3726
vulns.run_vulnerability_scan(
3827
kind="sql_injection", op="MySQLdb.Cursor.executemany", args=(query, "mysql")
3928
)
29+
30+
31+
@on_import("MySQLdb.cursors", "mysqlclient", version_requirement="1.5.0")
32+
def patch(m):
33+
"""
34+
patching MySQLdb.cursors (mysqlclient)
35+
- patches Cursor.execute(query, ...)
36+
- patches Cursor.executemany(query, ...)
37+
"""
38+
patch_function(m, "Cursor.execute", _execute)
39+
patch_function(m, "Cursor.executemany", _executemany)

aikido_zen/sinks/psycopg.py

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -7,19 +7,6 @@
77
from aikido_zen.sinks import patch_function, on_import, before
88

99

10-
@on_import("psycopg.cursor", "psycopg", version_requirement="3.1.0")
11-
def patch(m):
12-
"""
13-
patching module psycopg.cursor
14-
- patches Cursor.copy
15-
- patches Cursor.execute
16-
- patches Cursor.executemany
17-
"""
18-
patch_function(m, "Cursor.copy", _copy)
19-
patch_function(m, "Cursor.execute", _execute)
20-
patch_function(m, "Cursor.executemany", _execute)
21-
22-
2310
@before
2411
def _copy(func, instance, args, kwargs):
2512
statement = get_argument(args, kwargs, 0, "statement")
@@ -33,3 +20,16 @@ def _execute(func, instance, args, kwargs):
3320
query = get_argument(args, kwargs, 0, "query")
3421
op = f"psycopg.Cursor.{func.__name__}"
3522
vulns.run_vulnerability_scan(kind="sql_injection", op=op, args=(query, "postgres"))
23+
24+
25+
@on_import("psycopg.cursor", "psycopg", version_requirement="3.1.0")
26+
def patch(m):
27+
"""
28+
patching module psycopg.cursor
29+
- patches Cursor.copy
30+
- patches Cursor.execute
31+
- patches Cursor.executemany
32+
"""
33+
patch_function(m, "Cursor.copy", _copy)
34+
patch_function(m, "Cursor.execute", _execute)
35+
patch_function(m, "Cursor.executemany", _execute)

aikido_zen/sinks/psycopg2.py

Lines changed: 19 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -8,25 +8,6 @@
88
from aikido_zen.sinks import on_import, before, patch_function, after
99

1010

11-
@on_import("psycopg2")
12-
def patch(m):
13-
"""
14-
patching module psycopg2
15-
- patches psycopg2.connect
16-
cannot set 'execute' attribute of immutable type 'psycopg2.extensions.cursor',
17-
so we create our own cursor factory to bypass this limitation.
18-
"""
19-
compatible = is_package_compatible(
20-
required_version="2.9.2", packages=["psycopg2", "psycopg2-binary"]
21-
)
22-
if not compatible:
23-
# Users can install either psycopg2 or psycopg2-binary, we need to check if at least
24-
# one is installed and if they meet version requirements
25-
return
26-
27-
patch_function(m, "connect", _connect)
28-
29-
3011
@after
3112
def _connect(func, instance, _args, _kwargs, rv):
3213
"""
@@ -56,3 +37,22 @@ def psycopg2_patch(func, instance, args, kwargs):
5637
query = get_argument(args, kwargs, 0, "query")
5738
op = f"psycopg2.Connection.Cursor.{func.__name__}"
5839
vulns.run_vulnerability_scan(kind="sql_injection", op=op, args=(query, "postgres"))
40+
41+
42+
@on_import("psycopg2")
43+
def patch(m):
44+
"""
45+
patching module psycopg2
46+
- patches psycopg2.connect
47+
cannot set 'execute' attribute of immutable type 'psycopg2.extensions.cursor',
48+
so we create our own cursor factory to bypass this limitation.
49+
"""
50+
compatible = is_package_compatible(
51+
required_version="2.9.2", packages=["psycopg2", "psycopg2-binary"]
52+
)
53+
if not compatible:
54+
# Users can install either psycopg2 or psycopg2-binary, we need to check if at least
55+
# one is installed and if they meet version requirements
56+
return
57+
58+
patch_function(m, "connect", _connect)

aikido_zen/sinks/pymongo.py

Lines changed: 36 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -7,42 +7,6 @@
77
from . import patch_function, on_import, before
88

99

10-
@on_import("pymongo.collection", "pymongo", version_requirement="3.10.0")
11-
def patch(m):
12-
"""
13-
patching pymongo.collection
14-
- patches Collection.*(filter, ...)
15-
- patches Collection.*(..., filter, ...)
16-
- patches Collection.*(pipeline, ...)
17-
- patches Collection.bulk_write
18-
src: https://github.com/mongodb/mongo-python-driver/blob/98658cfd1fea42680a178373333bf27f41153759/pymongo/synchronous/collection.py#L136
19-
"""
20-
# func(filter, ...)
21-
patch_function(m, "Collection.replace_one", _func_filter_first)
22-
patch_function(m, "Collection.update_one", _func_filter_first)
23-
patch_function(m, "Collection.update_many", _func_filter_first)
24-
patch_function(m, "Collection.delete_one", _func_filter_first)
25-
patch_function(m, "Collection.delete_many", _func_filter_first)
26-
patch_function(m, "Collection.count_documents", _func_filter_first)
27-
patch_function(m, "Collection.find_one_and_delete", _func_filter_first)
28-
patch_function(m, "Collection.find_one_and_replace", _func_filter_first)
29-
patch_function(m, "Collection.find_one_and_update", _func_filter_first)
30-
patch_function(m, "Collection.find", _func_filter_first)
31-
patch_function(m, "Collection.find_raw_batches", _func_filter_first)
32-
# find_one not present in list since find_one calls find function.
33-
34-
# func(..., filter, ...)
35-
patch_function(m, "Collection.distinct", _func_filter_second)
36-
37-
# func(pipeline, ...)
38-
patch_function(m, "Collection.watch", _func_pipeline)
39-
patch_function(m, "Collection.aggregate", _func_pipeline)
40-
patch_function(m, "Collection.aggregate_raw_batches", _func_pipeline)
41-
42-
# bulk_write
43-
patch_function(m, "Collection.bulk_write", _bulk_write)
44-
45-
4610
@before
4711
def _func_filter_first(func, instance, args, kwargs):
4812
"""Collection.func(filter, ...)"""
@@ -97,3 +61,39 @@ def _bulk_write(func, instance, args, kwargs):
9761
op="pymongo.collection.Collection.bulk_write",
9862
args=(request._filter,),
9963
)
64+
65+
66+
@on_import("pymongo.collection", "pymongo", version_requirement="3.10.0")
67+
def patch(m):
68+
"""
69+
patching pymongo.collection
70+
- patches Collection.*(filter, ...)
71+
- patches Collection.*(..., filter, ...)
72+
- patches Collection.*(pipeline, ...)
73+
- patches Collection.bulk_write
74+
src: https://github.com/mongodb/mongo-python-driver/blob/98658cfd1fea42680a178373333bf27f41153759/pymongo/synchronous/collection.py#L136
75+
"""
76+
# func(filter, ...)
77+
patch_function(m, "Collection.replace_one", _func_filter_first)
78+
patch_function(m, "Collection.update_one", _func_filter_first)
79+
patch_function(m, "Collection.update_many", _func_filter_first)
80+
patch_function(m, "Collection.delete_one", _func_filter_first)
81+
patch_function(m, "Collection.delete_many", _func_filter_first)
82+
patch_function(m, "Collection.count_documents", _func_filter_first)
83+
patch_function(m, "Collection.find_one_and_delete", _func_filter_first)
84+
patch_function(m, "Collection.find_one_and_replace", _func_filter_first)
85+
patch_function(m, "Collection.find_one_and_update", _func_filter_first)
86+
patch_function(m, "Collection.find", _func_filter_first)
87+
patch_function(m, "Collection.find_raw_batches", _func_filter_first)
88+
# find_one not present in list since find_one calls find function.
89+
90+
# func(..., filter, ...)
91+
patch_function(m, "Collection.distinct", _func_filter_second)
92+
93+
# func(pipeline, ...)
94+
patch_function(m, "Collection.watch", _func_pipeline)
95+
patch_function(m, "Collection.aggregate", _func_pipeline)
96+
patch_function(m, "Collection.aggregate_raw_batches", _func_pipeline)
97+
98+
# bulk_write
99+
patch_function(m, "Collection.bulk_write", _bulk_write)

0 commit comments

Comments
 (0)