Skip to content

Commit cbe8c46

Browse files
authored
Fix: Inference of plan start and end (#1519)
1 parent 5c4ce6c commit cbe8c46

4 files changed

Lines changed: 31 additions & 70 deletions

File tree

sqlmesh/core/plan/definition.py

Lines changed: 8 additions & 34 deletions
Original file line number | Diff line number | Diff line change
@@ -31,9 +31,7 @@
3131
from sqlmesh.utils.dag import DAG
3232
from sqlmesh.utils.date import (
3333
TimeLike,
34-
make_inclusive_end,
3534
now,
36-
to_date,
3735
to_datetime,
3836
to_timestamp,
3937
validate_date_range,
@@ -144,7 +142,13 @@ def uncategorized(self) -> t.List[Snapshot]:
144142
def start(self) -> TimeLike:
145143
"""Returns the start of the plan or the earliest date of all snapshots."""
146144
if not self.override_start and not self._missing_intervals:
147-
return earliest_start_date(self.snapshots)
145+
earliest_start = earliest_start_date(self.snapshots)
146+
earliest_interval_starts = [s.intervals[0][0] for s in self.snapshots if s.intervals]
147+
return (
148+
min(earliest_start, to_datetime(min(earliest_interval_starts)))
149+
if earliest_interval_starts
150+
else earliest_start
151+
)
148152
return self._start or (
149153
min(
150154
start
@@ -166,40 +170,10 @@ def set_start(self, new_start: TimeLike) -> None:
166170
self.__missing_intervals = None
167171
self._refresh_dag_and_ignored_snapshots()
168172

169-
@classmethod
170-
def _get_end_date(cls, end_and_units: t.List[t.Tuple[int, IntervalUnit]]) -> TimeLike:
171-
if end_and_units:
172-
end, unit = max(end_and_units)
173-
174-
if unit.is_date_granularity:
175-
return to_date(make_inclusive_end(end))
176-
return end
177-
return now()
178-
179173
@property
180174
def end(self) -> TimeLike:
181175
"""Returns the end of the plan or now."""
182-
if not self._end or not self.override_end:
183-
if self._missing_intervals:
184-
return self._get_end_date(
185-
[
186-
(end, snapshot.node.interval_unit)
187-
for snapshot in self.snapshots
188-
if (snapshot.name, snapshot.version_get_or_generate())
189-
in self._missing_intervals
190-
for _, end in self._missing_intervals[
191-
(snapshot.name, snapshot.version_get_or_generate())
192-
]
193-
]
194-
)
195-
return self._get_end_date(
196-
[
197-
(snapshot.intervals[-1][1], snapshot.node.interval_unit)
198-
for snapshot in self.snapshots
199-
if snapshot.intervals
200-
]
201-
)
202-
return self._end
176+
return self._end or now()
203177

204178
@end.setter
205179
def end(self, new_end: TimeLike) -> None:

sqlmesh/core/snapshot/definition.py

Lines changed: 11 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -1208,8 +1208,14 @@ def missing_intervals(
12081208
"""Returns all missing intervals given a collection of snapshots."""
12091209
missing = {}
12101210
cache: t.Dict[str, datetime] = {}
1211-
start_dt = to_datetime(start or earliest_start_date(snapshots, cache))
12121211
end_date = end or now()
1212+
start_dt = (
1213+
to_datetime(start)
1214+
if start
1215+
else earliest_start_date(
1216+
snapshots, cache, default_value=to_date(end_date) - timedelta(days=1)
1217+
)
1218+
)
12131219
restatements = restatements or {}
12141220

12151221
for snapshot in snapshots:
@@ -1239,7 +1245,9 @@ def missing_intervals(
12391245

12401246

12411247
def earliest_start_date(
1242-
snapshots: t.Iterable[Snapshot], cache: t.Optional[t.Dict[str, datetime]] = None
1248+
snapshots: t.Iterable[Snapshot],
1249+
cache: t.Optional[t.Dict[str, datetime]] = None,
1250+
default_value: t.Optional[TimeLike] = None,
12431251
) -> datetime:
12441252
"""Get the earliest start date from a collection of snapshots.
12451253
@@ -1251,7 +1259,7 @@ def earliest_start_date(
12511259
"""
12521260
cache = {} if cache is None else cache
12531261
snapshots = list(snapshots)
1254-
earliest = to_datetime(yesterday().date())
1262+
earliest = to_datetime(default_value or yesterday().date())
12551263
if snapshots:
12561264
return min(start_date(snapshot, snapshots, cache) or earliest for snapshot in snapshots)
12571265
return earliest

tests/core/test_context.py

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -20,7 +20,7 @@
2020
from sqlmesh.core.environment import Environment
2121
from sqlmesh.core.model import load_sql_based_model
2222
from sqlmesh.core.plan import BuiltInPlanEvaluator, Plan
23-
from sqlmesh.utils.date import now, to_date, yesterday_ds
23+
from sqlmesh.utils.date import make_inclusive_end, now, to_date, yesterday_ds
2424
from sqlmesh.utils.errors import ConfigError
2525
from tests.utils.test_filesystem import create_temp_file
2626

@@ -438,12 +438,12 @@ def test_plan_default_end(sushi_context_pre_scheduling: Context):
438438
dev_plan = sushi_context_pre_scheduling.plan(
439439
"test_env", no_prompts=True, include_unmodified=True, skip_backfill=True, auto_apply=True
440440
)
441-
assert dev_plan.end == plan_end
441+
assert to_date(make_inclusive_end(dev_plan.end)) == plan_end
442442

443443
forward_only_dev_plan = sushi_context_pre_scheduling.plan(
444444
"test_env_forward_only", no_prompts=True, include_unmodified=True, forward_only=True
445445
)
446-
assert forward_only_dev_plan.end == plan_end
446+
assert to_date(make_inclusive_end(forward_only_dev_plan.end)) == plan_end
447447
assert forward_only_dev_plan._start == plan_end
448448

449449

tests/core/test_plan.py

Lines changed: 9 additions & 30 deletions
Original file line number | Diff line number | Diff line change
@@ -15,7 +15,7 @@
1515
SnapshotFingerprint,
1616
)
1717
from sqlmesh.utils.dag import DAG
18-
from sqlmesh.utils.date import now, to_date, to_ds, to_timestamp
18+
from sqlmesh.utils.date import now, to_date, to_datetime, to_timestamp
1919
from sqlmesh.utils.errors import PlanError
2020

2121

@@ -63,8 +63,8 @@ def test_forward_only_dev(make_snapshot, mocker: MockerFixture):
6363
)
6464

6565
expected_start = to_date("2022-01-01")
66-
expected_end = to_date("2022-01-02")
67-
expected_interval_end = to_timestamp(to_date("2022-01-03"))
66+
expected_end = to_date("2022-01-03")
67+
expected_interval_end = to_timestamp(to_date("2022-01-04"))
6868

6969
context_diff_mock = mocker.Mock()
7070
context_diff_mock.snapshots = {"a": snapshot_a}
@@ -342,6 +342,12 @@ def test_start_inference(make_snapshot, mocker: MockerFixture):
342342
assert plan.missing_intervals[0].snapshot_name == snapshot_a.name
343343
assert plan.start == to_timestamp("2022-01-01")
344344

345+
# Test inference from existing intervals
346+
context_diff_mock.snapshots = {"b": snapshot_b}
347+
plan = Plan(context_diff_mock)
348+
assert not plan.missing_intervals
349+
assert plan.start == to_datetime("2022-01-01")
350+
345351

346352
def test_auto_categorization(make_snapshot, mocker: MockerFixture):
347353
snapshot = make_snapshot(SqlModel(name="a", query=parse_one("select 1, ds")))
@@ -394,33 +400,6 @@ def test_auto_categorization_missing_schema_downstream(make_snapshot, mocker: Mo
394400
assert updated_snapshot.change_category == SnapshotChangeCategory.BREAKING
395401

396402

397-
def test_end_from_missing_instead_of_now(make_snapshot, mocker: MockerFixture):
398-
snapshot_a = make_snapshot(
399-
SqlModel(
400-
name="a",
401-
query=parse_one("select 1, ds"),
402-
kind=IncrementalByTimeRangeKind(time_column="ds"),
403-
)
404-
)
405-
406-
context_diff_mock = mocker.Mock()
407-
context_diff_mock.snapshots = {"a": snapshot_a}
408-
context_diff_mock.added = set()
409-
context_diff_mock.removed_snapshots = set()
410-
context_diff_mock.modified_snapshots = {}
411-
context_diff_mock.new_snapshots = {snapshot_a.snapshot_id: snapshot_a}
412-
413-
start_mock = mocker.patch("sqlmesh.core.snapshot.definition.earliest_start_date")
414-
start_mock.return_value = to_ds("2022-01-01")
415-
now_mock = mocker.patch("sqlmesh.core.plan.definition.now")
416-
now_mock.return_value = to_ds("2022-01-30")
417-
snapshot_a.add_interval("2022-01-01", "2022-01-05")
418-
419-
plan = Plan(context_diff_mock, is_dev=True)
420-
assert plan.start == to_timestamp("2022-01-06")
421-
assert plan.end == to_date("2022-01-29")
422-
423-
424403
def test_broken_references(make_snapshot, mocker: MockerFixture):
425404
snapshot_b = make_snapshot(SqlModel(name="b", query=parse_one("select 2, ds FROM a")))
426405
snapshot_b.categorize_as(SnapshotChangeCategory.BREAKING)

0 commit comments

Comments
 (0)