|
18 | 18 |
|
19 | 19 | class TestRuleRegistry: |
20 | 20 | def test_all_rules_loaded(self) -> None: |
21 | | - assert len(ALL_RULES) == 39 |
| 21 | + assert len(ALL_RULES) == 40 |
22 | 22 |
|
23 | | - def test_10_errors(self) -> None: |
24 | | - # 8 E-series + 2 T-series (T002 xp-cmdshell, T004 deprecated-outer-join). |
| 23 | + def test_11_errors(self) -> None: |
| 24 | + # 9 E-series + 2 T-series (T002 xp-cmdshell, T004 deprecated-outer-join). |
25 | 25 | errors = [r for r in ALL_RULES if r.severity == "error"] |
26 | | - assert len(errors) == 10 |
| 26 | + assert len(errors) == 11 |
27 | 27 |
|
28 | 28 | def test_29_warnings(self) -> None: |
29 | 29 | # 23 W-series + 3 S-series + 3 T-series (T001 with-nolock, |
@@ -91,6 +91,82 @@ def test_e006_update_with_where_ok(self, tmp_path) -> None: |
91 | 91 | e006 = [f for f in result.findings if f.rule_id == "E006"] |
92 | 92 | assert not e006 |
93 | 93 |
|
| 94 | + def test_e009_update_from_implicit_join(self) -> None: |
| 95 | + findings = check([str(FIXTURES / "errors.sql")]) |
| 96 | + e009 = [f for f in findings.findings if f.rule_id == "E009"] |
| 97 | + assert len(e009) >= 1 |
| 98 | + assert "UPDATE" in e009[0].message |
| 99 | + assert ( |
| 100 | + "cartesian" in e009[0].message.lower() or "comma-separated" in e009[0].message.lower() |
| 101 | + ) |
| 102 | + |
| 103 | + def test_e009_explicit_join_ok(self, tmp_path) -> None: |
| 104 | + # The recommended fix from the rule message must NOT trigger E009. |
| 105 | + sql = tmp_path / "safe_update_from.sql" |
| 106 | + sql.write_text( |
| 107 | + "UPDATE c SET c.status = o.status " |
| 108 | + "FROM customers c INNER JOIN orders o " |
| 109 | + "ON c.customer_id = o.customer_id;\n" |
| 110 | + ) |
| 111 | + result = check([str(sql)]) |
| 112 | + e009 = [f for f in result.findings if f.rule_id == "E009"] |
| 113 | + assert not e009 |
| 114 | + |
| 115 | + def test_e009_postgres_single_from_table_ok(self, tmp_path) -> None: |
| 116 | + # Postgres UPDATE ... FROM with a single table is the canonical |
| 117 | + # form and must not flag. |
| 118 | + sql = tmp_path / "postgres_update.sql" |
| 119 | + sql.write_text( |
| 120 | + "UPDATE customers SET status = o.status " |
| 121 | + "FROM orders o WHERE customers.id = o.customer_id;\n" |
| 122 | + ) |
| 123 | + result = check([str(sql)]) |
| 124 | + e009 = [f for f in result.findings if f.rule_id == "E009"] |
| 125 | + assert not e009 |
| 126 | + |
| 127 | + def test_e009_lateral_after_comma_ok(self, tmp_path) -> None: |
| 128 | + # `, LATERAL ...` is a real Snowflake / Postgres lateral join. |
| 129 | + sql = tmp_path / "lateral_update.sql" |
| 130 | + sql.write_text( |
| 131 | + "UPDATE c SET tag = sub.tag FROM customers c, " |
| 132 | + "LATERAL (SELECT tag FROM tags WHERE customer_id = c.id LIMIT 1) sub;\n" |
| 133 | + ) |
| 134 | + result = check([str(sql)]) |
| 135 | + e009 = [f for f in result.findings if f.rule_id == "E009"] |
| 136 | + assert not e009 |
| 137 | + |
| 138 | + def test_e009_update_without_from_ok(self, tmp_path) -> None: |
| 139 | + # No FROM clause at all is a plain single-table UPDATE. |
| 140 | + sql = tmp_path / "plain_update.sql" |
| 141 | + sql.write_text("UPDATE customers SET status = 'active' WHERE id = 1;\n") |
| 142 | + result = check([str(sql)]) |
| 143 | + e009 = [f for f in result.findings if f.rule_id == "E009"] |
| 144 | + assert not e009 |
| 145 | + |
| 146 | + def test_e009_three_table_comma_join_flagged(self, tmp_path) -> None: |
| 147 | + sql = tmp_path / "three_table.sql" |
| 148 | + sql.write_text( |
| 149 | + "UPDATE c SET c.label = p.label " |
| 150 | + "FROM customers c, orders o, products p " |
| 151 | + "WHERE c.id = o.customer_id AND o.product_id = p.id;\n" |
| 152 | + ) |
| 153 | + result = check([str(sql)]) |
| 154 | + e009 = [f for f in result.findings if f.rule_id == "E009"] |
| 155 | + assert len(e009) == 1 |
| 156 | + |
| 157 | + def test_e009_multiline_comma_join_flagged(self, tmp_path) -> None: |
| 158 | + sql = tmp_path / "multiline.sql" |
| 159 | + sql.write_text( |
| 160 | + "UPDATE customers\n" |
| 161 | + "SET status = o.status\n" |
| 162 | + "FROM customers c,\n" |
| 163 | + " orders o\n" |
| 164 | + "WHERE c.customer_id = o.customer_id;\n" |
| 165 | + ) |
| 166 | + result = check([str(sql)]) |
| 167 | + e009 = [f for f in result.findings if f.rule_id == "E009"] |
| 168 | + assert len(e009) == 1 |
| 169 | + |
94 | 170 |
|
95 | 171 | # --------------------------------------------------------------------------- |
96 | 172 | # Warning rules |
|
0 commit comments