Skip to content

Commit 8c32375

Browse files
authored
feat: Allow comparing queries in the compare_tables function (#16)
1 parent 412c1b7 commit 8c32375

3 files changed

Lines changed: 84 additions & 33 deletions

File tree

sqlcompyre/analysis/table_comparison.py

Lines changed: 42 additions & 26 deletions
Original file line number | Diff line number | Diff line change
@@ -1,15 +1,14 @@
1-
# Copyright (c) QuantCo 2024-2024
1+
# Copyright (c) QuantCo 2024-2025
22
# SPDX-License-Identifier: BSD-3-Clause
33

44
import functools
55
import logging
66
from functools import cached_property
7-
from typing import cast
87

98
import sqlalchemy as sa
109
import sqlalchemy.sql.functions as func
1110
from sqlalchemy.engine import Engine
12-
from sqlalchemy.sql import elements, expression, false, schema, select, selectable, true
11+
from sqlalchemy.sql import elements, expression, false, select, selectable, true
1312

1413
from sqlcompyre.report import Report
1514
from sqlcompyre.results import ColumnMatches, Counts, Names, RowMatches
@@ -26,8 +25,8 @@ class TableComparison:
2625
def __init__(
2726
self,
2827
engine: Engine,
29-
left_table: schema.Table,
30-
right_table: schema.Table,
28+
left_table: sa.FromClause,
29+
right_table: sa.FromClause,
3130
join_columns: list[str] | None,
3231
column_name_mapping: dict[str, str] | None,
3332
ignore_columns: list[str] | None,
@@ -54,8 +53,8 @@ def __init__(
5453
infer_primary_keys: Whether to infer primary keys if none are available.
5554
"""
5655
self.engine = engine
57-
self.left_table = cast(expression.Alias, left_table.alias("left"))
58-
self.right_table = cast(expression.Alias, right_table.alias("right"))
56+
self.left_table = left_table.alias("left")
57+
self.right_table = right_table.alias("right")
5958
self.column_name_mapping = _identity_column_mapping_if_needed(
6059
left_table,
6160
right_table,
@@ -79,8 +78,8 @@ def join_columns(self) -> list[str]:
7978
"""The columns used for joining the two tables."""
8079
pks = _join_columns_from_pk_if_needed(
8180
self.engine,
82-
cast(sa.Table, self.left_table.element),
83-
cast(sa.Table, self.right_table.element),
81+
self.left_table,
82+
self.right_table,
8483
self._user_join_columns,
8584
ignore_casing=self.ignore_casing,
8685
column_name_mapping=self.column_name_mapping,
@@ -323,9 +322,6 @@ def summary_report(self) -> Report:
323322
Returns:
324323
A report summarizing the comparison of the two tables.
325324
"""
326-
left_name = str(self.left_table.original)
327-
right_name = str(self.right_table.original)
328-
329325
description = None
330326
sections = {
331327
"Column Names": self.column_names,
@@ -349,17 +345,35 @@ def summary_report(self) -> Report:
349345
logging.warning(
350346
"'%s' and '%s' cannot be matched (%s): dropping row and column matches "
351347
"from the report",
352-
left_name,
353-
right_name,
348+
self._left_table_name,
349+
self._right_table_name,
354350
exc,
355351
)
356352

357-
return Report("tables", left_name, right_name, description, sections)
353+
return Report(
354+
"tables",
355+
self._left_table_name,
356+
self._right_table_name,
357+
description,
358+
sections,
359+
)
358360

359361
# ---------------------------------------------------------------------------------------------
360362
# UTILITY METHODS
361363
# ---------------------------------------------------------------------------------------------
362364

365+
@property
366+
def _left_table_name(self) -> str:
367+
if isinstance(self.left_table, sa.Alias):
368+
return str(self.left_table.element)
369+
return "<left query>"
370+
371+
@property
372+
def _right_table_name(self) -> str:
373+
if isinstance(self.right_table, sa.Alias):
374+
return str(self.right_table.element)
375+
return "<right query>"
376+
363377
def _is_equal(
364378
self, left_column: str, right_column: str
365379
) -> elements.ColumnElement[bool]:
@@ -477,8 +491,8 @@ def __repr__(self):
477491
def __str__(self):
478492
return (
479493
f"{self.__class__.__name__}("
480-
f'left_table="{self.left_table.original}", '
481-
f'right_table="{self.right_table.original}")'
494+
f'left_table="{self._left_table_name}", '
495+
f'right_table="{self._right_table_name}")'
482496
)
483497

484498

@@ -489,8 +503,8 @@ def __str__(self):
489503

490504
def _join_columns_from_pk_if_needed(
491505
engine: Engine,
492-
left: sa.Table,
493-
right: sa.Table,
506+
left: sa.FromClause,
507+
right: sa.FromClause,
494508
join_columns: list[str],
495509
ignore_casing: bool,
496510
column_name_mapping: dict[str, str],
@@ -503,8 +517,8 @@ def _join_columns_from_pk_if_needed(
503517
join_columns = [lowercase_map[c.lower()] for c in join_columns]
504518

505519
if not join_columns:
506-
left_pks = {pk.name for pk in sa.inspect(left).primary_key}
507-
right_pks = {pk.name for pk in sa.inspect(right).primary_key}
520+
left_pks = {col.name for col in left.columns if col.primary_key}
521+
right_pks = {col.name for col in right.columns if col.primary_key}
508522
reverse_mapping = {v: k for k, v in column_name_mapping.items()}
509523
if not (left_pks - set(column_name_mapping) | right_pks - set(reverse_mapping)):
510524
# All primary keys can be matched
@@ -551,8 +565,8 @@ def _join_columns_from_pk_if_needed(
551565

552566
def _is_valid_primary_key_column(
553567
engine: Engine,
554-
left_table: sa.Table,
555-
right_table: sa.Table,
568+
left_table: sa.FromClause,
569+
right_table: sa.FromClause,
556570
left_column: str,
557571
right_column: str,
558572
) -> bool:
@@ -578,7 +592,9 @@ def _is_valid_primary_key_column(
578592
return left_nulls == 0 and right_nulls == 0
579593

580594

581-
def _is_valid_primary_key(engine: Engine, table: sa.Table, columns: list[str]) -> bool:
595+
def _is_valid_primary_key(
596+
engine: Engine, table: sa.FromClause, columns: list[str]
597+
) -> bool:
582598
with engine.connect() as conn:
583599
result = conn.execute(
584600
sa.select(*[table.c[c] for c in columns])
@@ -590,8 +606,8 @@ def _is_valid_primary_key(engine: Engine, table: sa.Table, columns: list[str]) -
590606

591607

592608
def _identity_column_mapping_if_needed(
593-
left: sa.schema.Table,
594-
right: sa.schema.Table,
609+
left: sa.FromClause,
610+
right: sa.FromClause,
595611
mapping: dict[str, str],
596612
ignore_columns: list[str],
597613
ignore_casing: bool,

sqlcompyre/api.py

Lines changed: 7 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -1,4 +1,4 @@
1-
# Copyright (c) QuantCo 2024-2024
1+
# Copyright (c) QuantCo 2024-2025
22
# SPDX-License-Identifier: BSD-3-Clause
33

44
import sys
@@ -74,8 +74,8 @@ def inspect_table(engine: sa.Engine, table: sa.Table | str) -> QueryInspection:
7474

7575
def compare_tables(
7676
engine: sa.Engine,
77-
left: sa.Table | str,
78-
right: sa.Table | str,
77+
left: sa.Select | sa.FromClause | str,
78+
right: sa.Select | sa.FromClause | str,
7979
join_columns: list[str] | None = None,
8080
ignore_columns: list[str] | None = None,
8181
column_name_mapping: dict[str, str] | None = None,
@@ -118,8 +118,8 @@ def compare_tables(
118118
A table comparison object that can be used to explore the differences in the tables.
119119
"""
120120
# Get the SQLAlchemy representation of the tables in the database
121-
left_table: sa.Table
122-
right_table: sa.Table
121+
left_table: sa.FromClause
122+
right_table: sa.FromClause
123123
if isinstance(left, str) or isinstance(right, str):
124124
meta = sa.MetaData()
125125

@@ -134,9 +134,9 @@ def compare_tables(
134134
right_table = meta.tables[right]
135135

136136
if not isinstance(left, str):
137-
left_table = left
137+
left_table = left.subquery() if isinstance(left, sa.Select) else left
138138
if not isinstance(right, str):
139-
right_table = right
139+
right_table = right.subquery() if isinstance(right, sa.Select) else right
140140

141141
# Create a table comparison object
142142
return TableComparison(
Lines changed: 35 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,35 @@
1+
# Copyright (c) QuantCo 2024-2025
2+
# SPDX-License-Identifier: BSD-3-Clause
3+
4+
import sqlalchemy as sa
5+
6+
import sqlcompyre as sc
7+
8+
9+
def test_compare_queries_join_columns_inferred(
10+
engine: sa.Engine, table_students: sa.Table
11+
):
12+
comparison = sc.compare_tables(
13+
engine, sa.select(table_students), sa.select(table_students)
14+
)
15+
assert comparison.join_columns == ["id"]
16+
17+
18+
def test_compare_queries_select(engine: sa.Engine, table_students: sa.Table):
19+
comparison = sc.compare_tables(
20+
engine,
21+
sa.select(table_students).where(table_students.c["age"] >= 30),
22+
sa.select(table_students).where(table_students.c["age"] >= 20),
23+
)
24+
assert comparison.row_counts.diff == 2
25+
assert comparison.row_matches.n_joined_total == 2
26+
27+
28+
def test_compare_queries_subquery(engine: sa.Engine, table_students: sa.Table):
29+
comparison = sc.compare_tables(
30+
engine,
31+
sa.select(table_students).where(table_students.c["age"] >= 30).subquery(),
32+
sa.select(table_students).where(table_students.c["age"] >= 20).subquery(),
33+
)
34+
assert comparison.row_counts.diff == 2
35+
assert comparison.row_matches.n_joined_total == 2

0 commit comments

Comments
 (0)