Skip to content

Commit 460ea15

Browse files
authored
chore: Simplify type annotations for SQLAlchemy v2 (#17)
1 parent 8c32375 commit 460ea15

36 files changed

Lines changed: 270 additions & 294 deletions

sqlcompyre/analysis/dialects/_base.py

Lines changed: 2 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -1,11 +1,10 @@
1-
# Copyright (c) QuantCo 2024-2024
1+
# Copyright (c) QuantCo 2024-2025
22
# SPDX-License-Identifier: BSD-3-Clause
33

44
from datetime import datetime
55
from typing import Protocol
66

77
import sqlalchemy as sa
8-
from sqlalchemy.engine import Engine
98

109

1110
class DialectProtocol(Protocol):
@@ -40,7 +39,7 @@ class DialectProtocol(Protocol):
4039
views_support_notnull_columns: bool
4140

4241
def get_table_creation_timestamps(
43-
self, engine: Engine, tables: list[sa.Table]
42+
self, engine: sa.Engine, tables: list[sa.Table]
4443
) -> list[datetime]:
4544
"""Obtain the creation timestamps from a list of tables.
4645

sqlcompyre/analysis/dialects/mssql.py

Lines changed: 2 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -1,11 +1,10 @@
1-
# Copyright (c) QuantCo 2024-2024
1+
# Copyright (c) QuantCo 2024-2025
22
# SPDX-License-Identifier: BSD-3-Clause
33

44
from datetime import datetime
55

66
import sqlalchemy as sa
77
from sqlalchemy.dialects.mssql import dialect as SqlAlchemyMssqlDialect # noqa: N812
8-
from sqlalchemy.engine import Engine
98

109
from ._base import DialectProtocol
1110

@@ -22,7 +21,7 @@ class MssqlDialect(SqlAlchemyMssqlDialect, DialectProtocol): # type: ignore
2221
views_support_notnull_columns: bool = True
2322

2423
def get_table_creation_timestamps(
25-
self, engine: Engine, tables: list[sa.Table]
24+
self, engine: sa.Engine, tables: list[sa.Table]
2625
) -> list[datetime]:
2726
# Potentially, we need to get the database from the tables
2827
db: str | None = None

sqlcompyre/analysis/query_inspection.py

Lines changed: 2 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -1,11 +1,10 @@
1-
# Copyright (c) QuantCo 2024-2024
1+
# Copyright (c) QuantCo 2024-2025
22
# SPDX-License-Identifier: BSD-3-Clause
33

44

55
from functools import cached_property, lru_cache
66

77
import sqlalchemy as sa
8-
from sqlalchemy.engine import Engine
98

109

1110
class QueryInspection:
@@ -16,7 +15,7 @@ class QueryInspection:
1615
or :meth:`~sqlcompyre.api.inspect_table` functions instead.
1716
"""
1817

19-
def __init__(self, engine: Engine, selectable: sa.Select):
18+
def __init__(self, engine: sa.Engine, selectable: sa.Select):
2019
"""
2120
Args:
2221
engine: The engine to use for connecting to the database.

sqlcompyre/analysis/schema_comparison.py

Lines changed: 5 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -1,13 +1,12 @@
1-
# Copyright (c) QuantCo 2024-2024
1+
# Copyright (c) QuantCo 2024-2025
22
# SPDX-License-Identifier: BSD-3-Clause
33

44
import logging
55
import re
66
from functools import cached_property
77
from typing import Literal, cast
88

9-
from sqlalchemy import schema
10-
from sqlalchemy.engine import Engine
9+
import sqlalchemy as sa
1110
from tqdm.auto import tqdm
1211

1312
from sqlcompyre.report import Report
@@ -27,11 +26,11 @@ class SchemaComparison:
2726

2827
def __init__(
2928
self,
30-
engine: Engine,
29+
engine: sa.Engine,
3130
left_schema: str,
3231
right_schema: str,
33-
left_tables: dict[str, schema.Table],
34-
right_tables: dict[str, schema.Table],
32+
left_tables: dict[str, sa.Table],
33+
right_tables: dict[str, sa.Table],
3534
float_precision: float,
3635
collation: str | None,
3736
ignore_casing: bool,

sqlcompyre/analysis/table_comparison.py

Lines changed: 24 additions & 27 deletions
Original file line number | Diff line number | Diff line change
@@ -6,9 +6,6 @@
66
from functools import cached_property
77

88
import sqlalchemy as sa
9-
import sqlalchemy.sql.functions as func
10-
from sqlalchemy.engine import Engine
11-
from sqlalchemy.sql import elements, expression, false, select, selectable, true
129

1310
from sqlcompyre.report import Report
1411
from sqlcompyre.results import ColumnMatches, Counts, Names, RowMatches
@@ -24,7 +21,7 @@ class TableComparison:
2421

2522
def __init__(
2623
self,
27-
engine: Engine,
24+
engine: sa.Engine,
2825
left_table: sa.FromClause,
2926
right_table: sa.FromClause,
3027
join_columns: list[str] | None,
@@ -169,14 +166,14 @@ def row_matches(self) -> RowMatches:
169166
for colname_1, colname_2 in self.column_name_mapping.items()
170167
if colname_1 not in self.join_columns
171168
]
172-
inequality_conditions: list[elements.ColumnElement[bool]] = [
169+
inequality_conditions: list[sa.ColumnElement[bool]] = [
173170
sa.not_(c) for c in equality_conditions
174171
]
175172

176173
# If there are no conditions, equality is always true, inequality is always false
177174
if not equality_conditions:
178-
equality_conditions = [true()]
179-
inequality_conditions = [false()]
175+
equality_conditions = [sa.true()]
176+
inequality_conditions = [sa.false()]
180177

181178
# -- Create queries
182179
# Query for rows ONLY in left table
@@ -186,7 +183,7 @@ def row_matches(self) -> RowMatches:
186183
if c not in self.join_columns
187184
]
188185
unjoined_left = (
189-
select(*left_columns)
186+
sa.select(*left_columns)
190187
.select_from(self._outer_join(left=True))
191188
.where(
192189
self.right_table.c[self.column_name_mapping[self.join_columns[0]]].is_(
@@ -204,7 +201,7 @@ def row_matches(self) -> RowMatches:
204201
if k not in self.join_columns
205202
]
206203
unjoined_right = (
207-
select(*right_columns)
204+
sa.select(*right_columns)
208205
.select_from(self._outer_join(left=False))
209206
.where(self.left_table.c[self.join_columns[0]].is_(None))
210207
)
@@ -229,7 +226,7 @@ def row_matches(self) -> RowMatches:
229226
]
230227

231228
# The remaining queries
232-
joined_total = select(*columns).select_from(self._inner_join())
229+
joined_total = sa.select(*columns).select_from(self._inner_join())
233230
joined_unequal = joined_total.where(sa.or_(*inequality_conditions))
234231
joined_equal = joined_total.where(sa.and_(*equality_conditions))
235232
joined_row_count = self._count_rows(self._inner_join())
@@ -266,11 +263,11 @@ def column_matches(self) -> ColumnMatches:
266263
if len(cases) == 0:
267264
return ColumnMatches(fraction_same={}, mismatch_selects={})
268265

269-
case_stmt = select(*cases).select_from(inner_join).subquery()
266+
case_stmt = sa.select(*cases).select_from(inner_join).subquery()
270267

271268
# Compute fraction of matching values
272269
cols_to_avg = [col for col in case_stmt.c if f"_{MATCH_SUFFIX}" in col.name]
273-
avgs = select(
270+
avgs = sa.select(
274271
*[
275272
sa.func.avg(col).label(f"{col.name.replace(f'_{MATCH_SUFFIX}', '')}")
276273
for col in cols_to_avg
@@ -284,7 +281,7 @@ def column_matches(self) -> ColumnMatches:
284281

285282
# Find column mismatches
286283
mismatch_selects = {
287-
left_column: select(inner_join).where(
284+
left_column: sa.select(inner_join).where(
288285
sa.not_(self._is_equal(left_column, right_column))
289286
)
290287
for left_column, right_column in self.column_name_mapping.items()
@@ -374,9 +371,7 @@ def _right_table_name(self) -> str:
374371
return str(self.right_table.element)
375372
return "<right query>"
376373

377-
def _is_equal(
378-
self, left_column: str, right_column: str
379-
) -> elements.ColumnElement[bool]:
374+
def _is_equal(self, left_column: str, right_column: str) -> sa.ColumnElement[bool]:
380375
"""Forms a condition for comparing two columns.
381376
382377
Args:
@@ -403,13 +398,13 @@ def _is_equal(
403398
# and inverting this is still `unknown`). For more discussion, see
404399
# https://stackoverflow.com/questions/1075142/how-to-compare-values-which-may-both-be-null-in-t-sql
405400
# The following is a more robust formulation of `A = B OR (A IS NULL AND B IS NULL)`.
406-
return func.coalesce(
401+
return sa.func.coalesce(
407402
sa.case((condition, None), else_=lhs),
408403
sa.case((condition, None), else_=rhs),
409404
).is_(None)
410405

411406
@cached_property
412-
def _join_conditions(self) -> list[elements.ColumnElement[bool]]:
407+
def _join_conditions(self) -> list[sa.ColumnElement[bool]]:
413408
"""Forms a list of join conditions."""
414409
return [
415410
(
@@ -419,11 +414,11 @@ def _join_conditions(self) -> list[elements.ColumnElement[bool]]:
419414
for join_col in self.join_columns
420415
]
421416

422-
def _inner_join(self) -> expression.Join:
417+
def _inner_join(self) -> sa.Join:
423418
"""Specifies an inner join on the left and right tables."""
424419
return self.left_table.join(self.right_table, sa.and_(*self._join_conditions))
425420

426-
def _outer_join(self, left: bool) -> expression.Join:
421+
def _outer_join(self, left: bool) -> sa.Join:
427422
"""Specifies an outer join between the two tables.
428423
429424
Args:
@@ -438,7 +433,7 @@ def _outer_join(self, left: bool) -> expression.Join:
438433
return left_table.outerjoin(right_table, sa.and_(*self._join_conditions))
439434
return right_table.outerjoin(left_table, sa.and_(*self._join_conditions))
440435

441-
def _get_aggregate_changes(self, left_col_name: str) -> selectable.Select:
436+
def _get_aggregate_changes(self, left_col_name: str) -> sa.Select:
442437
"""Counts the number of different ways each column changes from one table to
443438
another.
444439
@@ -462,14 +457,14 @@ def _get_aggregate_changes(self, left_col_name: str) -> selectable.Select:
462457
)
463458

464459
return (
465-
select(change, sa.func.count())
460+
sa.select(change, sa.func.count())
466461
.select_from(self._inner_join())
467462
.where(sa.not_(self._is_equal(left_col_name, right_col_name)))
468463
.group_by(left_col, right_col)
469464
.order_by(sa.func.count().desc())
470465
)
471466

472-
def _count_rows(self, table: expression.FromClause) -> int:
467+
def _count_rows(self, table: sa.FromClause) -> int:
473468
"""Counts the number of rows in a table-like object.
474469
475470
Args:
@@ -479,7 +474,9 @@ def _count_rows(self, table: expression.FromClause) -> int:
479474
The number of rows.
480475
"""
481476
with self.engine.connect() as conn:
482-
return conn.execute(select(sa.func.count()).select_from(table)).scalar_one()
477+
return conn.execute(
478+
sa.select(sa.func.count()).select_from(table)
479+
).scalar_one()
483480

484481
# ---------------------------------------------------------------------------------------------
485482
# STRING REPRESENTATION
@@ -502,7 +499,7 @@ def __str__(self):
502499

503500

504501
def _join_columns_from_pk_if_needed(
505-
engine: Engine,
502+
engine: sa.Engine,
506503
left: sa.FromClause,
507504
right: sa.FromClause,
508505
join_columns: list[str],
@@ -564,7 +561,7 @@ def _join_columns_from_pk_if_needed(
564561

565562

566563
def _is_valid_primary_key_column(
567-
engine: Engine,
564+
engine: sa.Engine,
568565
left_table: sa.FromClause,
569566
right_table: sa.FromClause,
570567
left_column: str,
@@ -593,7 +590,7 @@ def _is_valid_primary_key_column(
593590

594591

595592
def _is_valid_primary_key(
596-
engine: Engine, table: sa.FromClause, columns: list[str]
593+
engine: sa.Engine, table: sa.FromClause, columns: list[str]
597594
) -> bool:
598595
with engine.connect() as conn:
599596
result = conn.execute(

sqlcompyre/api.py

Lines changed: 1 addition & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -4,7 +4,6 @@
44
import sys
55

66
import sqlalchemy as sa
7-
from sqlalchemy import schema
87

98
from .analysis import QueryInspection, SchemaComparison, TableComparison
109

@@ -248,7 +247,7 @@ def _get_tables_from_schema(
248247
schema: str,
249248
is_database: bool,
250249
include_views: bool,
251-
) -> list[schema.Table]:
250+
) -> list[sa.Table]:
252251
if is_database:
253252
engine = sa.create_engine(engine.url.set(database=schema))
254253
schemas = sa.inspect(engine).get_schema_names()

sqlcompyre/results/column_matches.py

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -1,9 +1,9 @@
1-
# Copyright (c) QuantCo 2024-2024
1+
# Copyright (c) QuantCo 2024-2025
22
# SPDX-License-Identifier: BSD-3-Clause
33

44
from dataclasses import dataclass
55

6-
from sqlalchemy.sql import selectable
6+
import sqlalchemy as sa
77

88

99
@dataclass
@@ -15,4 +15,4 @@ class ColumnMatches:
1515
fraction_same: dict[str, float]
1616
#: Dictionary mapping the name of the left-table column to a query of all joined rows for
1717
#: which the column does not have the same value in both tables.
18-
mismatch_selects: dict[str, selectable.Select]
18+
mismatch_selects: dict[str, sa.Select]

sqlcompyre/results/row_matches.py

Lines changed: 7 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -1,9 +1,9 @@
1-
# Copyright (c) QuantCo 2024-2024
1+
# Copyright (c) QuantCo 2024-2025
22
# SPDX-License-Identifier: BSD-3-Clause
33

44
from dataclasses import dataclass
55

6-
from sqlalchemy.sql import selectable
6+
import sqlalchemy as sa
77

88

99
@dataclass
@@ -21,12 +21,12 @@ class RowMatches:
2121
#: Number of rows that could be joined
2222
n_joined_total: int
2323
#: Query for obtaining all rows from the left table that could not be joined.
24-
unjoined_left: selectable.Select
24+
unjoined_left: sa.Select
2525
#: Query for obtaining all rows from the right table that could not be joined.
26-
unjoined_right: selectable.Select
26+
unjoined_right: sa.Select
2727
#: Query for obtaining all rows that could be joined and were identical across the two tables.
28-
joined_equal: selectable.Select
28+
joined_equal: sa.Select
2929
#: Query for obtaining all rows that could be joined but were not identical.
30-
joined_unequal: selectable.Select
30+
joined_unequal: sa.Select
3131
#: Query for obtaining all rows that were joined, regardless of equality.
32-
joined_total: selectable.Select
32+
joined_total: sa.Select

tests/_shared/dialects.py

Lines changed: 5 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -1,11 +1,11 @@
1-
# Copyright (c) QuantCo 2024-2024
1+
# Copyright (c) QuantCo 2024-2025
22
# SPDX-License-Identifier: BSD-3-Clause
33

44
from __future__ import annotations
55

66
import os
77

8-
from sqlalchemy.engine import url
8+
import sqlalchemy as sa
99

1010
from sqlcompyre.analysis.dialects import DialectProtocol
1111

@@ -17,15 +17,13 @@ def dialect_from_env() -> DialectProtocol:
1717
connection string. **Use this function with care!** Typically, it should only be
1818
used for ``pytest.mark.skipif`` statements.
1919
"""
20-
conn = url.make_url(os.environ["DB_CONNECTION_STRING"])
20+
conn = sa.make_url(os.environ["DB_CONNECTION_STRING"])
2121
return dialect_from_connection_url(conn)
2222

2323

24-
def dialect_from_connection_url(conn_url: url.URL) -> DialectProtocol:
24+
def dialect_from_connection_url(conn_url: sa.URL) -> DialectProtocol:
2525
"""Get dialect metadata from the connection URL."""
26-
from sqlalchemy import create_engine
27-
28-
dialect = create_engine(conn_url).dialect
26+
dialect = sa.create_engine(conn_url).dialect
2927
match dialect.name:
3028
case "mssql":
3129
from sqlcompyre.analysis.dialects import MssqlDialect

0 commit comments

Comments
 (0)