Skip to content

Commit 23b293e

Browse files
refactor
1 parent 74332e8 commit 23b293e

1 file changed

Lines changed: 41 additions & 30 deletions

File tree

diffly/summary.py

Lines changed: 41 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
import dataclasses
77
import io
88
import json
9-
from dataclasses import dataclass
9+
from dataclasses import dataclass, field
1010
from datetime import date, datetime, timedelta
1111
from decimal import Decimal
1212
from typing import TYPE_CHECKING, Any, Literal, cast
@@ -43,6 +43,10 @@ class SummaryDataSchemas:
4343
left_only: list[tuple[str, str]]
4444
in_common: list[tuple[str, str, str]]
4545
right_only: list[tuple[str, str]]
46+
_equal: bool = field(default=False, repr=False)
47+
_mismatching_dtypes: list[tuple[str, str, str]] = field(
48+
default_factory=list, repr=False
49+
)
4650

4751

4852
@dataclass
@@ -53,6 +57,8 @@ class SummaryDataRows:
5357
n_joined_equal: int | None
5458
n_joined_unequal: int | None
5559
n_right_only: int | None
60+
_equal_rows: bool = field(default=False, repr=False)
61+
_equal_num_rows: bool = field(default=False, repr=False)
5662

5763

5864
@dataclass
@@ -85,11 +91,13 @@ class SummaryData:
8591
columns: list[SummaryDataColumn] | None
8692
sample_rows_left_only: list[tuple[Any, ...]] | None
8793
sample_rows_right_only: list[tuple[Any, ...]] | None
94+
_truncated_left_name: str = field(default="", repr=False)
95+
_truncated_right_name: str = field(default="", repr=False)
8896

8997
def to_dict(self) -> dict[str, Any]:
9098
def _convert(obj: Any) -> Any:
9199
if isinstance(obj, dict):
92-
return {k: _convert(v) for k, v in obj.items()}
100+
return {k: _convert(v) for k, v in obj.items() if not k.startswith("_")}
93101
if isinstance(obj, (list, tuple)):
94102
return type(obj)(_convert(v) for v in obj)
95103
return _to_python(obj)
@@ -166,6 +174,9 @@ def _validate_primary_key_hidden_columns() -> None:
166174
is_equal = comp.equal()
167175
n_rows_left = comp.num_rows_left()
168176

177+
truncated_left = _truncate_name(left_name)
178+
truncated_right = _truncate_name(right_name)
179+
169180
if is_equal:
170181
return SummaryData(
171182
equal=True,
@@ -180,6 +191,8 @@ def _validate_primary_key_hidden_columns() -> None:
180191
columns=None,
181192
sample_rows_left_only=None,
182193
sample_rows_right_only=None,
194+
_truncated_left_name=truncated_left,
195+
_truncated_right_name=truncated_right,
183196
)
184197

185198
# --- Schemas ---
@@ -190,13 +203,19 @@ def _validate_primary_key_hidden_columns() -> None:
190203
left_only_cols = sorted(schemas_obj.left_only().items())
191204
right_only_cols = sorted(schemas_obj.right_only().items())
192205
in_common = sorted(schemas_obj.in_common().items())
206+
mismatching = sorted(schemas_obj.in_common().mismatching_dtypes().items())
193207
schemas = SummaryDataSchemas(
194208
left_only=[(name, str(dtype)) for name, dtype in left_only_cols],
195209
in_common=[
196210
(name, str(left_dtype), str(right_dtype))
197211
for name, (left_dtype, right_dtype) in in_common
198212
],
199213
right_only=[(name, str(dtype)) for name, dtype in right_only_cols],
214+
_equal=schemas_equal,
215+
_mismatching_dtypes=[
216+
(name, str(left_dtype), str(right_dtype))
217+
for name, (left_dtype, right_dtype) in mismatching
218+
],
200219
)
201220

202221
# --- Rows ---
@@ -215,6 +234,8 @@ def _validate_primary_key_hidden_columns() -> None:
215234
n_joined_equal=comp.num_rows_joined_equal(),
216235
n_joined_unequal=comp.num_rows_joined_unequal(),
217236
n_right_only=comp.num_rows_right_only(),
237+
_equal_rows=comp._equal_rows(),
238+
_equal_num_rows=comp.equal_num_rows(),
218239
)
219240
else:
220241
rows = SummaryDataRows(
@@ -224,6 +245,8 @@ def _validate_primary_key_hidden_columns() -> None:
224245
n_joined_equal=None,
225246
n_joined_unequal=None,
226247
n_right_only=None,
248+
_equal_rows=False,
249+
_equal_num_rows=comp.equal_num_rows(),
227250
)
228251

229252
# --- Columns ---
@@ -306,6 +329,8 @@ def _validate_primary_key_hidden_columns() -> None:
306329
columns=columns,
307330
sample_rows_left_only=sample_rows_left_only,
308331
sample_rows_right_only=sample_rows_right_only,
332+
_truncated_left_name=truncated_left,
333+
_truncated_right_name=truncated_right,
309334
)
310335

311336

@@ -421,8 +446,7 @@ def _print_diff(self, console: Console) -> None:
421446
# --------------------------------- PRIMARY KEY ---------------------------------- #
422447

423448
def _print_primary_key(self, console: Console) -> None:
424-
primary_key = self._data.primary_key
425-
if primary_key is not None:
449+
if (primary_key := self._data.primary_key) is not None:
426450
content = self._section_primary_key(primary_key)
427451
else:
428452
content = Text(
@@ -449,14 +473,8 @@ def _print_schemas(self, console: Console) -> None:
449473
return
450474

451475
schemas = self._data.schemas
452-
schemas_equal = (
453-
not schemas.left_only
454-
and not schemas.right_only
455-
and all(left == right for _, left, right in schemas.in_common)
456-
)
457-
458476
content: RenderableType
459-
if schemas_equal:
477+
if schemas._equal:
460478
num_cols = len(schemas.in_common)
461479
content = Text(
462480
f"Schemas match exactly (column count: {num_cols:,}).", style="italic"
@@ -488,7 +506,7 @@ def _print_num_columns(n: int) -> str:
488506

489507
# Left only
490508
if len(left_only_names) > 0:
491-
left_only_header = f"{capitalize_first(_truncate_name(self._data.left_name))} only \n{_print_num_columns(len(left_only_names))}"
509+
left_only_header = f"{capitalize_first(self._data._truncated_left_name)} only \n{_print_num_columns(len(left_only_names))}"
492510
table.add_column(
493511
left_only_header,
494512
header_style="red",
@@ -512,11 +530,7 @@ def _print_num_columns(n: int) -> str:
512530
)
513531
num_in_common = len(schemas.in_common)
514532
table_data[in_common_header] = []
515-
mismatching = [
516-
(name, left, right)
517-
for name, left, right in schemas.in_common
518-
if left != right
519-
]
533+
mismatching = schemas._mismatching_dtypes
520534
if len(mismatching) == 0:
521535
table_data[in_common_header] = ["..."]
522536
max_column_width = max(
@@ -542,7 +556,7 @@ def _print_num_columns(n: int) -> str:
542556

543557
# Right only
544558
if len(right_only_names) > 0:
545-
right_only_header = f"{capitalize_first(_truncate_name(self._data.right_name))} only\n{_print_num_columns(len(right_only_names))}"
559+
right_only_header = f"{capitalize_first(self._data._truncated_right_name)} only\n{_print_num_columns(len(right_only_names))}"
546560
table.add_column(
547561
right_only_header,
548562
header_style="green",
@@ -582,7 +596,7 @@ def _print_rows(self, console: Console) -> None:
582596

583597
def _render_rows_without_primary_key(self, rows: SummaryDataRows) -> RenderableType:
584598
content: RenderableType
585-
if rows.n_left == rows.n_right:
599+
if rows._equal_num_rows:
586600
content = Text(
587601
f"The number of rows matches exactly (row count: {rows.n_left:,}).",
588602
style="italic",
@@ -598,16 +612,15 @@ def _render_rows_with_primary_key(self, rows: SummaryDataRows) -> RenderableType
598612
assert rows.n_right_only is not None
599613

600614
content: RenderableType
601-
equal_rows = rows.n_joined_equal == rows.n_left == rows.n_right
602-
if equal_rows:
615+
if rows._equal_rows:
603616
content = Text(
604617
f"All rows match exactly (row count: {rows.n_left:,}).",
605618
style="italic",
606619
)
607620
else:
608621
# NOTE: In slim mode, we omit the row counts section and only show the
609622
# row matches section.
610-
if (rows.n_left == rows.n_right) and self._data.slim:
623+
if rows._equal_num_rows and self._data.slim:
611624
content = Group(self._section_row_matches(rows))
612625
else:
613626
content = Group(
@@ -632,10 +645,8 @@ def _section_row_counts(self, rows: SummaryDataRows) -> RenderableType:
632645
count_rows: list[RenderableType] = []
633646

634647
count_grid = Table(padding=0, box=None)
635-
left_header = f"{capitalize_first(_truncate_name(self._data.left_name))} count"
636-
right_header = (
637-
f"{capitalize_first(_truncate_name(self._data.right_name))} count"
638-
)
648+
left_header = f"{capitalize_first(self._data._truncated_left_name)} count"
649+
right_header = f"{capitalize_first(self._data._truncated_right_name)} count"
639650
count_grid.add_column(left_header, justify="center")
640651
count_grid.add_column("", justify="center")
641652
count_grid.add_column(right_header, justify="center")
@@ -741,7 +752,7 @@ def _section_row_matches(self, rows: SummaryDataRows) -> RenderableType:
741752
fraction_left_only = rows.n_left_only / rows.n_left
742753
grid.add_row(
743754
f"{rows.n_left_only:,}",
744-
f"{_truncate_name(self._data.left_name)} only",
755+
f"{self._data._truncated_left_name} only",
745756
f"({_format_fraction_as_percentage(fraction_left_only)})",
746757
)
747758
grid.add_section()
@@ -765,7 +776,7 @@ def _section_row_matches(self, rows: SummaryDataRows) -> RenderableType:
765776
fraction_right_only = rows.n_right_only / rows.n_right
766777
grid.add_row(
767778
f"{rows.n_right_only:,}",
768-
f"{_truncate_name(self._data.right_name)} only",
779+
f"{self._data._truncated_right_name} only",
769780
f"({_format_fraction_as_percentage(fraction_right_only)})",
770781
)
771782
columns.append(grid)
@@ -880,10 +891,10 @@ def _section_columns(self) -> RenderableType:
880891
def _print_sample_rows_only_one_side(self, console: Console, side: Side) -> None:
881892
if side == Side.LEFT:
882893
sample_rows = self._data.sample_rows_left_only
883-
name = _truncate_name(self._data.left_name)
894+
name = self._data._truncated_left_name
884895
else:
885896
sample_rows = self._data.sample_rows_right_only
886-
name = _truncate_name(self._data.right_name)
897+
name = self._data._truncated_right_name
887898

888899
primary_key = self._data.primary_key
889900
if primary_key is not None and sample_rows is not None and len(sample_rows) > 0:

0 commit comments

Comments
 (0)