Skip to content

Commit 14c388f

Browse files
committed
Assert specific return values in SQL injection tests
Instead of assertNotEquals(1, result), assert the exact expected return value: 0 for safe queries, 1 for injections, 3 for tokenize errors. Unterminated string cases now use isSqlTokenizeError.
1 parent 89b1adf commit 14c388f

1 file changed

Lines changed: 23 additions & 20 deletions

File tree

agent_api/src/test/java/vulnerabilities/SqlInjectionTest.java

Lines changed: 23 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -7,32 +7,39 @@
77
import static dev.aikido.agent_api.vulnerabilities.sql_injection.SqlDetector.detectSqlInjection;
88
import static org.junit.jupiter.api.Assertions.assertEquals;
99
import static org.junit.jupiter.api.Assertions.assertFalse;
10-
import static org.junit.jupiter.api.Assertions.assertNotEquals;
1110
import static org.junit.jupiter.api.Assertions.assertTrue;
1211

1312
public class SqlInjectionTest {
1413
private void isNotSqlInjection(String sql, String input, String dialect) {
15-
int result;
1614
if ("mysql".equals(dialect) || "all".equals(dialect)) {
17-
result = detectSqlInjection(sql, input, new Dialect("mysql"));
18-
assertNotEquals(1, result, String.format("Expected no SQL injection for SQL: %s and input: %s", sql, input));
15+
int result = detectSqlInjection(sql, input, new Dialect("mysql"));
16+
assertEquals(0, result, String.format("Expected no SQL injection for SQL: %s and input: %s", sql, input));
1917
}
2018
if ("postgresql".equals(dialect) || "all".equals(dialect)) {
21-
result = detectSqlInjection(sql, input, new Dialect("postgresql"));
22-
assertNotEquals(1, result, String.format("Expected no SQL injection for SQL: %s and input: %s", sql, input));
19+
int result = detectSqlInjection(sql, input, new Dialect("postgresql"));
20+
assertEquals(0, result, String.format("Expected no SQL injection for SQL: %s and input: %s", sql, input));
2321
}
2422
}
2523
private void isSqlInjection(String sql, String input, String dialect) {
26-
int result;
2724
if ("mysql".equals(dialect) || "all".equals(dialect)) {
28-
result = detectSqlInjection(sql, input, new Dialect("mysql"));
25+
int result = detectSqlInjection(sql, input, new Dialect("mysql"));
2926
assertEquals(1, result, String.format("Expected SQL injection for SQL: %s and input: %s", sql, input));
3027
}
3128
if ("postgresql".equals(dialect) || "all".equals(dialect)) {
32-
result = detectSqlInjection(sql, input, new Dialect("postgresql"));
29+
int result = detectSqlInjection(sql, input, new Dialect("postgresql"));
3330
assertEquals(1, result, String.format("Expected SQL injection for SQL: %s and input: %s", sql, input));
3431
}
3532
}
33+
private void isSqlTokenizeError(String sql, String input, String dialect) {
34+
if ("mysql".equals(dialect) || "all".equals(dialect)) {
35+
int result = detectSqlInjection(sql, input, new Dialect("mysql"));
36+
assertEquals(3, result, String.format("Expected SQL tokenize error for SQL: %s and input: %s", sql, input));
37+
}
38+
if ("postgresql".equals(dialect) || "all".equals(dialect)) {
39+
int result = detectSqlInjection(sql, input, new Dialect("postgresql"));
40+
assertEquals(3, result, String.format("Expected SQL tokenize error for SQL: %s and input: %s", sql, input));
41+
}
42+
}
3643

3744

3845
/**
@@ -166,17 +173,8 @@ public void testShouldReturnEarly() {
166173
assertFalse(SqlDetector.shouldReturnEarly("SELECT * FROM users; DROP TABLE", "users; DROP TABLE"));
167174
}
168175

169-
/**
170-
* Moved :
171-
* is_sql_injection("SELECT * FROM users WHERE id = 'users\\'", "users\\")
172-
* is_sql_injection("SELECT * FROM users WHERE id = 'users\\\\'", "users\\\\")
173-
* to is_not_sql_injection. Reason : Invalid SQL.
174-
*/
175176
@Test
176177
public void testAllowEscapeSequences() {
177-
// Invalid queries:
178-
isNotSqlInjection("SELECT * FROM users WHERE id = 'users\\'", "users\\", "all");
179-
isNotSqlInjection("SELECT * FROM users WHERE id = 'users\\\\'", "users\\\\", "all");
180178
isNotSqlInjection("SELECT * FROM users WHERE id = '\nusers'", "\nusers", "all");
181179
isNotSqlInjection("SELECT * FROM users WHERE id = '\rusers'", "\rusers", "all");
182180
isNotSqlInjection("SELECT * FROM users WHERE id = '\tusers'", "\tusers", "all");
@@ -220,8 +218,7 @@ public void testCheckStringSafelyEscaped() {
220218
"SELECT * FROM comments WHERE comment = \"I\\`m writing you\"", "I`m writing you", "all"
221219
);
222220

223-
// Invalid query (strings don't terminate)
224-
isNotSqlInjection(
221+
isSqlTokenizeError(
225222
"SELECT * FROM comments WHERE comment = 'I'm writing you'", "I'm writing you", "all"
226223
);
227224

@@ -375,6 +372,12 @@ public void testLowercasedInputSqlInjection() {
375372
isSqlInjection(sql, expectedSqlInjection, "all");
376373
}
377374

375+
@Test
376+
public void testUnterminatedStrings() {
377+
isSqlTokenizeError("SELECT * FROM users WHERE id = 'users\\'", "users\\", "all");
378+
isSqlTokenizeError("SELECT * FROM users WHERE id = 'users\\\\'", "users\\\\", "all");
379+
}
380+
378381
/**
379382
* Marked the following as SQL injection since this would result in 2 or more tokens becoming one :
380383
* is_not_sql_injection("foobar)", "foobar)")

0 commit comments

Comments
 (0)