forked from Pawansingh3889/sql-guard
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy path test_rules.py
More file actions
303 lines (233 loc) · 11.9 KB
/
test_rules.py
File metadata and controls
303 lines (233 loc) · 11.9 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
"""Tests for all rules in the sql-sop rule registry."""
from __future__ import annotations
from pathlib import Path
from sql_guard.checker import check
from sql_guard.rules import ALL_RULES, get_rules
# Directory of .sql fixture files (errors.sql, warnings.sql, clean.sql)
# located alongside this test module.
FIXTURES = Path(__file__).parent.joinpath("fixtures")
# ---------------------------------------------------------------------------
# Rule registry
# ---------------------------------------------------------------------------
class TestRuleRegistry:
    """Sanity checks on the size, severities and IDs of the rule registry."""

    def test_all_rules_loaded(self) -> None:
        # Pinned total: adding or removing a rule must be a deliberate change
        # that also updates the per-severity counts below.
        assert len(ALL_RULES) == 37

    def test_10_errors(self) -> None:
        # 8 E-series + 2 T-series (T002 xp-cmdshell, T004 deprecated-outer-join).
        errors = [r for r in ALL_RULES if r.severity == "error"]
        assert len(errors) == 10

    def test_27_warnings(self) -> None:
        # 21 W-series + 3 S-series + 3 T-series (T001 with-nolock,
        # T003 cursor-declaration, T005 create-index-without-online).
        warnings = [r for r in ALL_RULES if r.severity == "warning"]
        assert len(warnings) == 27

    def test_unique_ids(self) -> None:
        from collections import Counter

        # Name the offending IDs on failure instead of only reporting that
        # two lengths differ.
        counts = Counter(r.id for r in ALL_RULES)
        duplicates = sorted(rule_id for rule_id, n in counts.items() if n > 1)
        assert not duplicates, f"duplicate rule ids: {duplicates}"

    def test_disable_rules(self) -> None:
        disabled = {"E001", "W001"}
        ids = {r.id for r in get_rules(disabled_ids=disabled)}
        # The requested rules are gone...
        assert not ids & disabled
        # ...but disabling is targeted: unrelated rules remain enabled. E002
        # and W003 are rules this module expects to fire on the fixtures.
        assert "E002" in ids
        assert "W003" in ids
# ---------------------------------------------------------------------------
# Error rules
# ---------------------------------------------------------------------------
class TestErrorRules:
    """Each E-series rule must fire on the errors.sql fixture."""

    @staticmethod
    def _hits(rule_id: str) -> list:
        """Check errors.sql and return only the findings reported by *rule_id*."""
        outcome = check([str(FIXTURES / "errors.sql")])
        return [hit for hit in outcome.findings if hit.rule_id == rule_id]

    def test_e001_delete_without_where(self) -> None:
        hits = self._hits("E001")
        assert hits
        assert "DELETE" in hits[0].message

    def test_e002_drop_without_if_exists(self) -> None:
        assert self._hits("E002")

    def test_e003_grant_revoke(self) -> None:
        assert self._hits("E003")

    def test_e004_string_concat(self) -> None:
        hits = self._hits("E004")
        assert hits
        assert "injection" in hits[0].message.lower()

    def test_e005_insert_without_columns(self) -> None:
        assert self._hits("E005")

    def test_e006_update_without_where(self) -> None:
        hits = self._hits("E006")
        assert hits
        first_message = hits[0].message
        assert "UPDATE" in first_message
        lowered = first_message.lower()
        assert "overwrite" in lowered or "every row" in lowered

    def test_e006_update_with_where_ok(self, tmp_path) -> None:
        # A guarded UPDATE ... WHERE must stay silent. This pins the rule so a
        # later regex tweak cannot start over-triggering unnoticed.
        target = tmp_path / "safe_update.sql"
        target.write_text("UPDATE orders SET status = 'shipped' WHERE id = 42;\n")
        flagged = [f for f in check([str(target)]).findings if f.rule_id == "E006"]
        assert flagged == []
# ---------------------------------------------------------------------------
# Warning rules
# ---------------------------------------------------------------------------
class TestWarningRules:
    """W-series rules: each must fire on warnings.sql, and targeted negative
    cases must stay silent."""

    @staticmethod
    def _fixture_findings(rule_id: str) -> list:
        """Check warnings.sql and return only the findings reported by *rule_id*."""
        result = check([str(FIXTURES / "warnings.sql")])
        return [f for f in result.findings if f.rule_id == rule_id]

    def test_w001_select_star(self) -> None:
        assert len(self._fixture_findings("W001")) >= 1

    def test_w003_function_on_column(self) -> None:
        assert len(self._fixture_findings("W003")) >= 1

    def test_w007_hardcoded_values(self) -> None:
        assert len(self._fixture_findings("W007")) >= 1

    def test_w010_commented_out_code(self) -> None:
        assert len(self._fixture_findings("W010")) >= 1

    def test_w011_union_without_all(self) -> None:
        assert len(self._fixture_findings("W011")) >= 1

    def test_w011_passes_on_union_all(self) -> None:
        from sql_guard.rules.warnings import UnionWithoutAll

        rule = UnionWithoutAll()
        statement = "SELECT id FROM orders_2024\nUNION ALL\nSELECT id FROM orders_2025;"
        assert rule.check_statement(statement, 1, "test.sql") is None

    def test_w012_group_by_ordinal(self) -> None:
        assert len(self._fixture_findings("W012")) >= 1

    def test_w012_catches_single_ordinal(self) -> None:
        from sql_guard.rules.warnings import GroupByOrdinal

        rule = GroupByOrdinal()
        statement = "SELECT region, COUNT(*) FROM orders GROUP BY 1;"
        assert rule.check_statement(statement, 1, "test.sql") is not None

    def test_w012_catches_multiple_ordinals(self) -> None:
        from sql_guard.rules.warnings import GroupByOrdinal

        rule = GroupByOrdinal()
        statement = "SELECT a, b, c, COUNT(*) FROM t GROUP BY 1, 2, 3;"
        assert rule.check_statement(statement, 1, "test.sql") is not None

    def test_w012_passes_on_explicit_columns(self) -> None:
        from sql_guard.rules.warnings import GroupByOrdinal

        rule = GroupByOrdinal()
        statement = "SELECT region, status, COUNT(*) FROM orders GROUP BY region, status;"
        assert rule.check_statement(statement, 1, "test.sql") is None

    def test_w012_passes_on_digit_prefixed_column_name(self) -> None:
        from sql_guard.rules.warnings import GroupByOrdinal

        # Some dialects (e.g. MySQL) accept *unquoted* identifiers that begin
        # with a digit, such as 1st_quarter, as long as the name is not made
        # up entirely of digits. The rule must flag only pure integer tokens,
        # so this GROUP BY must not be reported.
        rule = GroupByOrdinal()
        statement = "SELECT 1st_quarter, COUNT(*) FROM sales GROUP BY 1st_quarter;"
        assert rule.check_statement(statement, 1, "test.sql") is None

    def test_w013_window_missing_partition(self) -> None:
        assert len(self._fixture_findings("W013")) >= 1

    def test_w013_passes_on_valid_over_clause(self) -> None:
        from sql_guard.rules.warnings import WindowMissingPartition

        rule = WindowMissingPartition()
        statement = "SELECT ROW_NUMBER() OVER (PARTITION BY department_id ORDER BY id) AS rn FROM users;"
        assert rule.check_statement(statement, 1, "test.sql") is None

    def test_w016_not_in_with_subquery(self) -> None:
        w016 = self._fixture_findings("W016")
        assert len(w016) >= 1
        assert "NOT IN" in w016[0].message

    def test_w016_not_in_value_list_ok(self, tmp_path) -> None:
        # NOT IN (1, 2, 3) value list must NOT trigger -- only subqueries.
        sql = tmp_path / "value_list.sql"
        sql.write_text("SELECT id FROM users WHERE status_id NOT IN (1, 2, 3);\n")
        result = check([str(sql)])
        w016 = [f for f in result.findings if f.rule_id == "W016"]
        assert not w016

    def test_w014_case_without_else(self) -> None:
        w014 = self._fixture_findings("W014")
        assert len(w014) >= 1
        assert "CASE" in w014[0].message

    def test_w014_case_with_else_ok(self, tmp_path) -> None:
        sql = tmp_path / "case_with_else.sql"
        sql.write_text(
            "SELECT CASE\n"
            " WHEN status = 'paid' THEN 1\n"
            " WHEN status = 'pending' THEN 0\n"
            " ELSE NULL\n"
            "END AS paid_flag\n"
            "FROM orders;\n"
        )
        result = check([str(sql)])
        w014 = [f for f in result.findings if f.rule_id == "W014"]
        assert not w014

    def test_w014_outer_case_without_else_fires_when_inner_has_else(
        self, tmp_path
    ) -> None:
        # Issue #4 specifically called out the nested case: an outer
        # CASE with no ELSE must still fire even when an inner CASE
        # does have one.
        from sql_guard.rules.warnings import CaseWithoutElse

        rule = CaseWithoutElse()
        nested = (
            "SELECT CASE\n"
            " WHEN x THEN CASE WHEN y THEN 1 ELSE 2 END\n"
            " WHEN z THEN 3\n"
            "END FROM t;"
        )
        finding = rule.check_statement(nested, 1, "test.sql")
        assert finding is not None
        assert finding.rule_id == "W014"

    def test_w014_does_not_fire_on_begin_end_block(self) -> None:
        # T-SQL BEGIN/END blocks should not trip the rule on their own.
        from sql_guard.rules.warnings import CaseWithoutElse

        rule = CaseWithoutElse()
        proc = "BEGIN\n SELECT 1;\nEND;"
        assert rule.check_statement(proc, 1, "test.sql") is None
# ---------------------------------------------------------------------------
# Clean file
# ---------------------------------------------------------------------------
class TestCleanFile:
    """clean.sql is the known-good fixture: it must never yield errors."""

    def test_no_errors_on_clean(self) -> None:
        outcome = check([str(FIXTURES / "clean.sql")], severity="error")
        assert outcome.error_count == 0

    def test_no_findings_on_clean(self) -> None:
        # Despite the name, only error-severity findings are forbidden here;
        # clean.sql is allowed to produce zero or near-zero warnings. Unlike
        # the test above, this runs unfiltered and inspects the raw list.
        outcome = check([str(FIXTURES / "clean.sql")])
        assert not any(item.severity == "error" for item in outcome.findings)
# ---------------------------------------------------------------------------
# Checker behavior
# ---------------------------------------------------------------------------
class TestChecker:
    """Behaviour of ``check()`` itself: discovery, timing, filters, fail-fast."""

    def test_files_checked_count(self) -> None:
        result = check([str(FIXTURES)])
        assert result.files_checked == 3  # errors.sql, warnings.sql, clean.sql

    def test_duration_tracked(self) -> None:
        result = check([str(FIXTURES)])
        # The measured time can come out as 0.0 when the run is shorter than
        # the timer's resolution, so require a non-negative float rather than
        # a strictly positive one.
        assert isinstance(result.duration_seconds, float)
        assert result.duration_seconds >= 0

    def test_severity_filter(self) -> None:
        target = [str(FIXTURES / "errors.sql")]
        all_findings = check(target)
        error_only = check(target, severity="error")
        # The filter removes everything below error severity...
        assert error_only.warning_count == 0
        assert all(f.severity == "error" for f in error_only.findings)
        # ...but must not drop any of the errors themselves.
        assert error_only.error_count >= 1
        assert error_only.error_count == all_findings.error_count

    def test_fail_fast_stops_early(self) -> None:
        target = [str(FIXTURES / "errors.sql")]
        full = check(target)
        result = check(target, fail_fast=True)
        # Fail-fast must still surface at least one error, and it can only
        # ever report a subset of what an exhaustive run finds.
        assert result.error_count >= 1
        assert len(result.findings) <= len(full.findings)

    def test_nonexistent_path(self, tmp_path) -> None:
        # Build the missing path under tmp_path so the outcome cannot depend
        # on whatever happens to exist in the current working directory.
        missing = tmp_path / "nonexistent_dir"
        result = check([str(missing)])
        assert result.files_checked == 0

    def test_w015_join_function_on_column(self) -> None:
        # NOTE(review): this is a W-series rule test and arguably belongs in
        # TestWarningRules; it is left here to keep its pytest node id stable.
        from sql_guard.rules.warnings import JoinFunctionOnColumn

        # Confirm registration (get_rules comes from the module-level import).
        assert any(isinstance(r, JoinFunctionOnColumn) for r in get_rules())
        result = check([str(FIXTURES / "warnings.sql")])
        w015 = [f for f in result.findings if f.rule_id == "W015"]
        assert len(w015) >= 1