Skip to content

Commit 4bbc8a1

Browse files
z3z1ma and tobymao authored
Feat: Greatly simplify table diff CLI invocation (#977)
* feat: greatly simplify table diff cli api
* fix: use metavar
* feat: add padding to cli and improve diff output
* fix: ensure change is exclusive to cli interface and arbitrary comp is supported
* fix: ensure model name is not required
* fix: address cli on cond usage to use multiple and actually return a seq
* chore: use row diff verbiage for brevity and semantic correctness
* ci: run linter
* ci: fix test
* chore: make aliases optional, intent more obvious
* nit: make shorthand on flag lowercase

---------

Co-authored-by: Toby Mao <toby.mao@gmail.com>
1 parent e47bee2 commit 4bbc8a1

File tree

4 files changed

+54
-45
lines changed

4 files changed

+54
-45
lines changed

sqlmesh/cli/main.py

Lines changed: 17 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,6 @@
66
import typing as t
77

88
import click
9-
from sqlglot import exp
10-
from sqlglot.helper import ensure_list
119

1210
from sqlmesh import debug_mode_enabled, enable_logging
1311
from sqlmesh.cli import error_handler
@@ -388,35 +386,14 @@ def create_external_models(obj: Context) -> None:
388386

389387

390388
@cli.command("table_diff")
389+
@click.argument("source_to_target", required=True, metavar="SOURCE:TARGET")
390+
@click.argument("model", required=False)
391391
@click.option(
392-
"--source",
393-
"-s",
394-
type=str,
395-
required=True,
396-
help="The source environment or table.",
397-
)
398-
@click.option(
399-
"--target",
400-
"-t",
401-
type=str,
402-
required=True,
403-
help="The target environment or table.",
404-
)
405-
@click.option(
406-
"--grain",
407-
type=str,
408-
multiple=True,
409-
help="The list of columns to use as keys.",
410-
)
411-
@click.option(
392+
"-o",
412393
"--on",
413394
type=str,
414-
help='The SQL join condition or list of columns to use as keys. Table aliases must be "s" and "t" for source and target.',
415-
)
416-
@click.option(
417-
"--model",
418-
type=str,
419-
help="The model to diff against when source and target are environments and not tables.",
395+
multiple=True,
396+
help="The column to join on. Can be specified multiple times. The model grain will be used if not specified.",
420397
)
421398
@click.option(
422399
"--where",
@@ -426,20 +403,22 @@ def create_external_models(obj: Context) -> None:
426403
@click.option(
427404
"--limit",
428405
type=int,
406+
default=20,
429407
help="The limit of the sample dataframe.",
430408
)
431409
@click.pass_obj
432410
@error_handler
433-
def table_diff(obj: Context, **kwargs: t.Any) -> None:
434-
"""Show the diff between two tables.
435-
436-
Can either be two tables or two environments and a model.
437-
"""
438-
kwargs["model_or_snapshot"] = kwargs.pop("model", None)
439-
on = kwargs.pop("on", None)
440-
grain = ensure_list(kwargs.pop("grain", None))
441-
kwargs["on"] = exp.condition(on) if on else grain
442-
obj.table_diff(**kwargs)
411+
def table_diff(
412+
obj: Context, source_to_target: str, model: t.Optional[str], **kwargs: t.Any
413+
) -> None:
414+
"""Show the diff between two tables."""
415+
# Parse "SOURCE:TARGET" defensively: a missing colon, an empty side, or extra
# colons should surface as a CLI usage error rather than a bare ValueError
# from tuple unpacking.
source, separator, target = source_to_target.partition(":")
if not (separator and source and target) or ":" in target:
    raise click.BadParameter(
        f"Expected exactly one ':' separating source and target, got '{source_to_target}'.",
        param_hint="SOURCE:TARGET",
    )
416+
obj.table_diff(
417+
source=source,
418+
target=target,
419+
model_or_snapshot=model,
420+
**kwargs,
421+
)
443422

444423

445424
@cli.command("prompt")

sqlmesh/core/console.py

Lines changed: 16 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -459,7 +459,9 @@ def loading_stop(self, id: uuid.UUID) -> None:
459459
del self.loading_status[id]
460460

461461
def show_schema_diff(self, schema_diff: SchemaDiff) -> None:
462-
tree = Tree(f"[bold]Schema Diff Between '{schema_diff.source}' and '{schema_diff.target}':")
462+
tree = Tree(
463+
f"\n[b]Schema Diff Between '[yellow]{schema_diff.source}[/yellow]' and '[green]{schema_diff.target}[/green]':"
464+
)
463465

464466
if schema_diff.added:
465467
added = Tree("[green]Added Columns:")
@@ -482,10 +484,19 @@ def show_schema_diff(self, schema_diff: SchemaDiff) -> None:
482484
self.console.print(tree)
483485

484486
def show_row_diff(self, row_diff: RowDiff) -> None:
485-
self.console.print(
486-
f"[bold]Row Count:[/bold] {row_diff.source}: {row_diff.source_count}, {row_diff.target}: {row_diff.target_count} -- {row_diff.count_pct_change}%"
487-
)
488-
self.console.print(row_diff.sample.to_string(index=False))
487+
source_name = row_diff.source
488+
if row_diff.source_alias:
489+
source_name = row_diff.source_alias.upper()
490+
target_name = row_diff.target
491+
if row_diff.target_alias:
492+
target_name = row_diff.target_alias.upper()
493+
494+
self.console.print("\n[b]Row Count:[/b]")
495+
self.console.print(f" [yellow]{source_name}[/yellow]: {row_diff.source_count} rows")
496+
self.console.print(f" [green]{target_name}[/green]: {row_diff.target_count} rows")
497+
# Close the bold tag with [/b]; a second [b] would leave the percentage bold too.
self.console.print(f"\n[b]Row Diff[/b]: {row_diff.count_pct_change:.1f}%")
498+
self.console.print("\n[b]Sample Rows:[/b]")
499+
self.console.print(row_diff.sample.to_string(index=False), end="\n\n")
489500

490501
def _get_snapshot_change_category(
491502
self, snapshot: Snapshot, plan: Plan, auto_apply: bool

sqlmesh/core/context.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -777,6 +777,7 @@ def table_diff(
777777
Returns:
778778
The TableDiff object containing schema and summary differences.
779779
"""
780+
source_alias, target_alias = source, target
780781
if model_or_snapshot:
781782
model = self.get_model(model_or_snapshot, raise_if_missing=True)
782783
source_env = self.state_reader.get_environment(source)
@@ -793,6 +794,8 @@ def table_diff(
793794
target = next(
794795
snapshot for snapshot in target_env.snapshots if snapshot.name == model.name
795796
).table_name()
797+
source_alias = source_env.name
798+
target_alias = target_env.name
796799

797800
if not on and model.grain:
798801
on = model.grain
@@ -801,7 +804,14 @@ def table_diff(
801804
raise SQLMeshError("Missing join condition 'on'")
802805

803806
table_diff = TableDiff(
804-
adapter=self._engine_adapter, source=source, target=target, on=on, where=where
807+
adapter=self._engine_adapter,
808+
source=source,
809+
target=target,
810+
on=on,
811+
where=where,
812+
source_alias=source_alias,
813+
target_alias=target_alias,
814+
limit=limit,
805815
)
806816
if show:
807817
self.console.show_schema_diff(table_diff.schema_diff())

sqlmesh/core/table_diff.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,8 @@ class RowDiff(PydanticModel, frozen=True):
5252
target: str
5353
stats: t.Dict[str, float]
5454
sample: pd.DataFrame
55+
source_alias: t.Optional[str] = None
56+
target_alias: t.Optional[str] = None
5557

5658
@property
5759
def source_count(self) -> int:
@@ -83,14 +85,19 @@ def __init__(
8385
where: t.Optional[str | exp.Condition] = None,
8486
dialect: DialectType = None,
8587
limit: int = 20,
88+
source_alias: t.Optional[str] = None,
89+
target_alias: t.Optional[str] = None,
8690
):
8791
self.adapter = adapter
8892
self.source = source
8993
self.target = target
9094
self.where = exp.condition(where, dialect=dialect) if where else None
9195
self.limit = limit
96+
# Support environment aliases for diff output improvement in certain cases
97+
self.source_alias = source_alias
98+
self.target_alias = target_alias
9299

93-
if isinstance(on, list):
100+
if isinstance(on, (list, tuple)):
94101
self.on: exp.Condition = exp.and_(
95102
*(
96103
exp.column(c, "s").eq(exp.column(c, "t"))
@@ -194,5 +201,7 @@ def row_diff(self) -> RowDiff:
194201
target=self.target,
195202
stats=self.adapter.fetchdf(summary_query).iloc[0].to_dict(),
196203
sample=self.adapter.fetchdf(sample_query),
204+
source_alias=self.source_alias,
205+
target_alias=self.target_alias,
197206
)
198207
return self._row_diff

0 commit comments

Comments
 (0)