1- # Copyright (c) QuantCo 2024-2024
1+ # Copyright (c) QuantCo 2024-2025
22# SPDX-License-Identifier: BSD-3-Clause
33
44import functools
55import logging
66from functools import cached_property
7- from typing import cast
87
98import sqlalchemy as sa
109import sqlalchemy .sql .functions as func
1110from sqlalchemy .engine import Engine
12- from sqlalchemy .sql import elements , expression , false , schema , select , selectable , true
11+ from sqlalchemy .sql import elements , expression , false , select , selectable , true
1312
1413from sqlcompyre .report import Report
1514from sqlcompyre .results import ColumnMatches , Counts , Names , RowMatches
@@ -26,8 +25,8 @@ class TableComparison:
2625 def __init__ (
2726 self ,
2827 engine : Engine ,
29- left_table : schema . Table ,
30- right_table : schema . Table ,
28+ left_table : sa . FromClause ,
29+ right_table : sa . FromClause ,
3130 join_columns : list [str ] | None ,
3231 column_name_mapping : dict [str , str ] | None ,
3332 ignore_columns : list [str ] | None ,
@@ -54,8 +53,8 @@ def __init__(
5453 infer_primary_keys: Whether to infer primary keys if none are available.
5554 """
5655 self .engine = engine
57- self .left_table = cast ( expression . Alias , left_table .alias ("left" ) )
58- self .right_table = cast ( expression . Alias , right_table .alias ("right" ) )
56+ self .left_table = left_table .alias ("left" )
57+ self .right_table = right_table .alias ("right" )
5958 self .column_name_mapping = _identity_column_mapping_if_needed (
6059 left_table ,
6160 right_table ,
@@ -79,8 +78,8 @@ def join_columns(self) -> list[str]:
7978 """The columns used for joining the two tables."""
8079 pks = _join_columns_from_pk_if_needed (
8180 self .engine ,
82- cast ( sa . Table , self .left_table . element ) ,
83- cast ( sa . Table , self .right_table . element ) ,
81+ self .left_table ,
82+ self .right_table ,
8483 self ._user_join_columns ,
8584 ignore_casing = self .ignore_casing ,
8685 column_name_mapping = self .column_name_mapping ,
@@ -323,9 +322,6 @@ def summary_report(self) -> Report:
323322 Returns:
324323 A report summarizing the comparison of the two tables.
325324 """
326- left_name = str (self .left_table .original )
327- right_name = str (self .right_table .original )
328-
329325 description = None
330326 sections = {
331327 "Column Names" : self .column_names ,
@@ -349,17 +345,35 @@ def summary_report(self) -> Report:
349345 logging .warning (
350346 "'%s' and '%s' cannot be matched (%s): dropping row and column matches "
351347 "from the report" ,
352- left_name ,
353- right_name ,
348+ self . _left_table_name ,
349+ self . _right_table_name ,
354350 exc ,
355351 )
356352
357- return Report ("tables" , left_name , right_name , description , sections )
353+ return Report (
354+ "tables" ,
355+ self ._left_table_name ,
356+ self ._right_table_name ,
357+ description ,
358+ sections ,
359+ )
358360
359361 # ---------------------------------------------------------------------------------------------
360362 # UTILITY METHODS
361363 # ---------------------------------------------------------------------------------------------
362364
365+ @property
366+ def _left_table_name (self ) -> str :
367+ if isinstance (self .left_table , sa .Alias ):
368+ return str (self .left_table .element )
369+ return "<left query>"
370+
371+ @property
372+ def _right_table_name (self ) -> str :
373+ if isinstance (self .right_table , sa .Alias ):
374+ return str (self .right_table .element )
375+ return "<right query>"
376+
363377 def _is_equal (
364378 self , left_column : str , right_column : str
365379 ) -> elements .ColumnElement [bool ]:
@@ -477,8 +491,8 @@ def __repr__(self):
477491 def __str__ (self ):
478492 return (
479493 f"{ self .__class__ .__name__ } ("
480- f'left_table="{ self .left_table . original } ", '
481- f'right_table="{ self .right_table . original } ")'
494+ f'left_table="{ self ._left_table_name } ", '
495+ f'right_table="{ self ._right_table_name } ")'
482496 )
483497
484498
@@ -489,8 +503,8 @@ def __str__(self):
489503
490504def _join_columns_from_pk_if_needed (
491505 engine : Engine ,
492- left : sa .Table ,
493- right : sa .Table ,
506+ left : sa .FromClause ,
507+ right : sa .FromClause ,
494508 join_columns : list [str ],
495509 ignore_casing : bool ,
496510 column_name_mapping : dict [str , str ],
@@ -503,8 +517,8 @@ def _join_columns_from_pk_if_needed(
503517 join_columns = [lowercase_map [c .lower ()] for c in join_columns ]
504518
505519 if not join_columns :
506- left_pks = {pk .name for pk in sa . inspect ( left ) .primary_key }
507- right_pks = {pk .name for pk in sa . inspect ( right ) .primary_key }
520+ left_pks = {col .name for col in left . columns if col .primary_key }
521+ right_pks = {col .name for col in right . columns if col .primary_key }
508522 reverse_mapping = {v : k for k , v in column_name_mapping .items ()}
509523 if not (left_pks - set (column_name_mapping ) | right_pks - set (reverse_mapping )):
510524 # All primary keys can be matched
@@ -551,8 +565,8 @@ def _join_columns_from_pk_if_needed(
551565
552566def _is_valid_primary_key_column (
553567 engine : Engine ,
554- left_table : sa .Table ,
555- right_table : sa .Table ,
568+ left_table : sa .FromClause ,
569+ right_table : sa .FromClause ,
556570 left_column : str ,
557571 right_column : str ,
558572) -> bool :
@@ -578,7 +592,9 @@ def _is_valid_primary_key_column(
578592 return left_nulls == 0 and right_nulls == 0
579593
580594
581- def _is_valid_primary_key (engine : Engine , table : sa .Table , columns : list [str ]) -> bool :
595+ def _is_valid_primary_key (
596+ engine : Engine , table : sa .FromClause , columns : list [str ]
597+ ) -> bool :
582598 with engine .connect () as conn :
583599 result = conn .execute (
584600 sa .select (* [table .c [c ] for c in columns ])
@@ -590,8 +606,8 @@ def _is_valid_primary_key(engine: Engine, table: sa.Table, columns: list[str]) -
590606
591607
592608def _identity_column_mapping_if_needed (
593- left : sa .schema . Table ,
594- right : sa .schema . Table ,
609+ left : sa .FromClause ,
610+ right : sa .FromClause ,
595611 mapping : dict [str , str ],
596612 ignore_columns : list [str ],
597613 ignore_casing : bool ,
0 commit comments