66from functools import cached_property
77
88import sqlalchemy as sa
9- import sqlalchemy .sql .functions as func
10- from sqlalchemy .engine import Engine
11- from sqlalchemy .sql import elements , expression , false , select , selectable , true
129
1310from sqlcompyre .report import Report
1411from sqlcompyre .results import ColumnMatches , Counts , Names , RowMatches
@@ -24,7 +21,7 @@ class TableComparison:
2421
2522 def __init__ (
2623 self ,
27- engine : Engine ,
24+ engine : sa . Engine ,
2825 left_table : sa .FromClause ,
2926 right_table : sa .FromClause ,
3027 join_columns : list [str ] | None ,
@@ -169,14 +166,14 @@ def row_matches(self) -> RowMatches:
169166 for colname_1 , colname_2 in self .column_name_mapping .items ()
170167 if colname_1 not in self .join_columns
171168 ]
172- inequality_conditions : list [elements .ColumnElement [bool ]] = [
169+ inequality_conditions : list [sa .ColumnElement [bool ]] = [
173170 sa .not_ (c ) for c in equality_conditions
174171 ]
175172
176173 # If there are no conditions, equality is always true, inequality is always false
177174 if not equality_conditions :
178- equality_conditions = [true ()]
179- inequality_conditions = [false ()]
175+ equality_conditions = [sa . true ()]
176+ inequality_conditions = [sa . false ()]
180177
181178 # -- Create queries
182179 # Query for rows ONLY in left table
@@ -186,7 +183,7 @@ def row_matches(self) -> RowMatches:
186183 if c not in self .join_columns
187184 ]
188185 unjoined_left = (
189- select (* left_columns )
186+ sa . select (* left_columns )
190187 .select_from (self ._outer_join (left = True ))
191188 .where (
192189 self .right_table .c [self .column_name_mapping [self .join_columns [0 ]]].is_ (
@@ -204,7 +201,7 @@ def row_matches(self) -> RowMatches:
204201 if k not in self .join_columns
205202 ]
206203 unjoined_right = (
207- select (* right_columns )
204+ sa . select (* right_columns )
208205 .select_from (self ._outer_join (left = False ))
209206 .where (self .left_table .c [self .join_columns [0 ]].is_ (None ))
210207 )
@@ -229,7 +226,7 @@ def row_matches(self) -> RowMatches:
229226 ]
230227
231228 # The remaining queries
232- joined_total = select (* columns ).select_from (self ._inner_join ())
229+ joined_total = sa . select (* columns ).select_from (self ._inner_join ())
233230 joined_unequal = joined_total .where (sa .or_ (* inequality_conditions ))
234231 joined_equal = joined_total .where (sa .and_ (* equality_conditions ))
235232 joined_row_count = self ._count_rows (self ._inner_join ())
@@ -266,11 +263,11 @@ def column_matches(self) -> ColumnMatches:
266263 if len (cases ) == 0 :
267264 return ColumnMatches (fraction_same = {}, mismatch_selects = {})
268265
269- case_stmt = select (* cases ).select_from (inner_join ).subquery ()
266+ case_stmt = sa . select (* cases ).select_from (inner_join ).subquery ()
270267
271268 # Compute fraction of matching values
272269 cols_to_avg = [col for col in case_stmt .c if f"_{ MATCH_SUFFIX } " in col .name ]
273- avgs = select (
270+ avgs = sa . select (
274271 * [
275272 sa .func .avg (col ).label (f"{ col .name .replace (f'_{ MATCH_SUFFIX } ' , '' )} " )
276273 for col in cols_to_avg
@@ -284,7 +281,7 @@ def column_matches(self) -> ColumnMatches:
284281
285282 # Find column mismatches
286283 mismatch_selects = {
287- left_column : select (inner_join ).where (
284+ left_column : sa . select (inner_join ).where (
288285 sa .not_ (self ._is_equal (left_column , right_column ))
289286 )
290287 for left_column , right_column in self .column_name_mapping .items ()
@@ -374,9 +371,7 @@ def _right_table_name(self) -> str:
374371 return str (self .right_table .element )
375372 return "<right query>"
376373
377- def _is_equal (
378- self , left_column : str , right_column : str
379- ) -> elements .ColumnElement [bool ]:
374+ def _is_equal (self , left_column : str , right_column : str ) -> sa .ColumnElement [bool ]:
380375 """Forms a condition for comparing two columns.
381376
382377 Args:
@@ -403,13 +398,13 @@ def _is_equal(
403398 # and inverting this is still `unknown`). For more discussion, see
404399 # https://stackoverflow.com/questions/1075142/how-to-compare-values-which-may-both-be-null-in-t-sql
405400 # The following is a more robust formulation of `A = B OR (A IS NULL AND B IS NULL)`.
406- return func .coalesce (
401+ return sa . func .coalesce (
407402 sa .case ((condition , None ), else_ = lhs ),
408403 sa .case ((condition , None ), else_ = rhs ),
409404 ).is_ (None )
410405
411406 @cached_property
412- def _join_conditions (self ) -> list [elements .ColumnElement [bool ]]:
407+ def _join_conditions (self ) -> list [sa .ColumnElement [bool ]]:
413408 """Forms a list of join conditions."""
414409 return [
415410 (
@@ -419,11 +414,11 @@ def _join_conditions(self) -> list[elements.ColumnElement[bool]]:
419414 for join_col in self .join_columns
420415 ]
421416
422- def _inner_join (self ) -> expression .Join :
417+ def _inner_join (self ) -> sa .Join :
423418 """Specifies an inner join on the left and right tables."""
424419 return self .left_table .join (self .right_table , sa .and_ (* self ._join_conditions ))
425420
426- def _outer_join (self , left : bool ) -> expression .Join :
421+ def _outer_join (self , left : bool ) -> sa .Join :
427422 """Specifies an outer join between the two tables.
428423
429424 Args:
@@ -438,7 +433,7 @@ def _outer_join(self, left: bool) -> expression.Join:
438433 return left_table .outerjoin (right_table , sa .and_ (* self ._join_conditions ))
439434 return right_table .outerjoin (left_table , sa .and_ (* self ._join_conditions ))
440435
441- def _get_aggregate_changes (self , left_col_name : str ) -> selectable .Select :
436+ def _get_aggregate_changes (self , left_col_name : str ) -> sa .Select :
442437 """Counts the number of different ways each column changes from one table to
443438 another.
444439
@@ -462,14 +457,14 @@ def _get_aggregate_changes(self, left_col_name: str) -> selectable.Select:
462457 )
463458
464459 return (
465- select (change , sa .func .count ())
460+ sa . select (change , sa .func .count ())
466461 .select_from (self ._inner_join ())
467462 .where (sa .not_ (self ._is_equal (left_col_name , right_col_name )))
468463 .group_by (left_col , right_col )
469464 .order_by (sa .func .count ().desc ())
470465 )
471466
472- def _count_rows (self , table : expression .FromClause ) -> int :
467+ def _count_rows (self , table : sa .FromClause ) -> int :
473468 """Counts the number of rows in a table-like object.
474469
475470 Args:
@@ -479,7 +474,9 @@ def _count_rows(self, table: expression.FromClause) -> int:
479474 The number of rows.
480475 """
481476 with self .engine .connect () as conn :
482- return conn .execute (select (sa .func .count ()).select_from (table )).scalar_one ()
477+ return conn .execute (
478+ sa .select (sa .func .count ()).select_from (table )
479+ ).scalar_one ()
483480
484481 # ---------------------------------------------------------------------------------------------
485482 # STRING REPRESENTATION
@@ -502,7 +499,7 @@ def __str__(self):
502499
503500
504501def _join_columns_from_pk_if_needed (
505- engine : Engine ,
502+ engine : sa . Engine ,
506503 left : sa .FromClause ,
507504 right : sa .FromClause ,
508505 join_columns : list [str ],
@@ -564,7 +561,7 @@ def _join_columns_from_pk_if_needed(
564561
565562
566563def _is_valid_primary_key_column (
567- engine : Engine ,
564+ engine : sa . Engine ,
568565 left_table : sa .FromClause ,
569566 right_table : sa .FromClause ,
570567 left_column : str ,
@@ -593,7 +590,7 @@ def _is_valid_primary_key_column(
593590
594591
595592def _is_valid_primary_key (
596- engine : Engine , table : sa .FromClause , columns : list [str ]
593+ engine : sa . Engine , table : sa .FromClause , columns : list [str ]
597594) -> bool :
598595 with engine .connect () as conn :
599596 result = conn .execute (
0 commit comments