Skip to content

Commit b7275d0

Browse files
authored
Fix: Initialize dev intervals with prod intervals for new forward-only snapshots (#1561)
1 parent 19c5033 commit b7275d0

16 files changed

Lines changed: 203 additions & 93 deletions

File tree

sqlmesh/core/console.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -609,9 +609,6 @@ def _prompt_effective_from(self, plan: Plan, auto_apply: bool) -> None:
609609
if effective_from:
610610
plan.effective_from = effective_from
611611

612-
if plan.is_dev and plan.effective_from:
613-
plan.set_start(plan.effective_from)
614-
615612
def _prompt_backfill(self, plan: Plan, auto_apply: bool) -> None:
616613
is_forward_only_dev = plan.is_dev and plan.forward_only
617614
backfill_or_preview = "preview" if is_forward_only_dev else "backfill"

sqlmesh/core/plan/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,4 +10,5 @@
1010
MWAAPlanEvaluator,
1111
PlanEvaluator,
1212
can_evaluate_before_promote,
13+
update_intervals_for_new_snapshots,
1314
)

sqlmesh/core/plan/definition.py

Lines changed: 20 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -112,7 +112,7 @@ def __init__(
112112
self.__dag: t.Optional[DAG[str]] = None
113113

114114
self._start = start
115-
if not self._start and is_dev and (forward_only or self._has_paused_forward_only()):
115+
if not self._start and is_dev and forward_only:
116116
self._start = default_start or yesterday_ds()
117117

118118
self._end = end if end or not is_dev else (default_end or now())
@@ -167,11 +167,8 @@ def start(self) -> TimeLike:
167167
@start.setter
168168
def start(self, new_start: TimeLike) -> None:
169169
self._ensure_valid_date_range(new_start, self._end)
170-
self.set_start(new_start)
171-
self.override_start = True
172-
173-
def set_start(self, new_start: TimeLike) -> None:
174170
self._start = new_start
171+
self.override_start = True
175172
self.__missing_intervals = None
176173
self._refresh_dag_and_ignored_snapshots()
177174

@@ -371,6 +368,7 @@ def set_choice(self, snapshot: Snapshot, choice: SnapshotChangeCategory) -> None
371368
# Invalidate caches.
372369
self._categorized = None
373370
self._uncategorized = None
371+
self.__missing_intervals = None
374372

375373
@property
376374
def effective_from(self) -> t.Optional[TimeLike]:
@@ -393,6 +391,9 @@ def effective_from(self, effective_from: t.Optional[TimeLike]) -> None:
393391
effective_from: The effective date to set.
394392
"""
395393
self._set_effective_from(effective_from)
394+
if effective_from and self.is_dev and not self.override_start:
395+
self._start = effective_from
396+
self._refresh_dag_and_ignored_snapshots()
396397

397398
def _set_effective_from(self, effective_from: t.Optional[TimeLike]) -> None:
398399
if not self.forward_only:
@@ -410,32 +411,30 @@ def _set_effective_from(self, effective_from: t.Optional[TimeLike]) -> None:
410411
@property
411412
def _missing_intervals(self) -> t.Dict[t.Tuple[str, str], Intervals]:
412413
if self.__missing_intervals is None:
413-
# we need previous snapshots because this method is cached and users have the option
414-
# to choose non-breaking / forward only. this will change the version of the snapshot on the fly
415-
# thus changing the missing intervals. additionally we replace any snapshots with the old copies
416-
# because they have intervals and the ephemeral ones don't
417-
snapshots = {
418-
(snapshot.name, snapshot.version_get_or_generate()): snapshot
419-
for snapshot in self.snapshots
414+
old_snapshots = {
415+
(old.name, old.version_get_or_generate()): old
416+
for _, old in self.context_diff.modified_snapshots.values()
420417
}
421418

422-
for new, old in self.context_diff.modified_snapshots.values():
423-
# Never override forward-only snapshots to preserve the effect
424-
# of the effective_from setting. Instead re-merge the intervals.
425-
if not new.is_forward_only:
426-
snapshots[(old.name, old.version_get_or_generate())] = old
427-
else:
428-
new.intervals = []
429-
new.merge_intervals(old)
419+
for new in self.context_diff.new_snapshots.values():
420+
new.intervals = []
421+
new.dev_intervals = []
422+
old = old_snapshots.get((new.name, new.version_get_or_generate()))
423+
if not old:
424+
continue
425+
new.merge_intervals(old)
426+
if new.is_forward_only:
427+
new.dev_intervals = new.intervals.copy()
430428

431429
self.__missing_intervals = {
432430
(snapshot.name, snapshot.version_get_or_generate()): missing
433431
for snapshot, missing in missing_intervals(
434-
snapshots.values(),
432+
self.snapshots,
435433
start=self._start,
436434
end=self._end,
437435
execution_time=self._execution_time,
438436
restatements=self.restatements,
437+
is_dev=self.is_dev,
439438
ignore_cron=True,
440439
).items()
441440
}
@@ -728,12 +727,6 @@ def _build_snapshots_and_dag(
728727
ignored_snapshot_names,
729728
)
730729

731-
def _has_paused_forward_only(self) -> bool:
732-
for name, snapshot in self._snapshot_mapping.items():
733-
if snapshot.is_paused_forward_only or self._is_forward_only_model(name):
734-
return True
735-
return False
736-
737730
def _is_forward_only_model(self, model_name: str) -> bool:
738731
def _is_forward_only_expected(snapshot: Snapshot) -> bool:
739732
# Returns True if the snapshot is not categorized yet but is expected

sqlmesh/core/plan/evaluator.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,8 @@ def evaluate(self, plan: Plan) -> None:
8686
}
8787
after_promote_snapshots = all_names - before_promote_snapshots
8888

89+
update_intervals_for_new_snapshots(plan.new_snapshots, self.state_sync)
90+
8991
self._push(plan)
9092
self._restate(plan)
9193
self._backfill(plan, before_promote_snapshots)
@@ -377,3 +379,13 @@ def can_evaluate_before_promote(
377379
return not snapshot.is_paused_forward_only and not any(
378380
snapshots[p_id].is_paused_forward_only for p_id in snapshot.parents
379381
)
382+
383+
384+
def update_intervals_for_new_snapshots(
385+
snapshots: t.Collection[Snapshot], state_sync: StateSync
386+
) -> None:
387+
for snapshot in state_sync.refresh_snapshot_intervals(snapshots):
388+
if snapshot.is_forward_only:
389+
snapshot.dev_intervals = snapshot.intervals.copy()
390+
for start, end in snapshot.dev_intervals:
391+
state_sync.add_interval(snapshot, start, end, is_dev=True)

sqlmesh/core/scheduler.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@
1515
from sqlmesh.core.snapshot import (
1616
Snapshot,
1717
SnapshotEvaluator,
18-
SnapshotIdLike,
1918
earliest_start_date,
2019
missing_intervals,
2120
)
@@ -44,7 +43,7 @@ class Scheduler:
4443
The scheduler comes equipped with a simple ThreadPoolExecutor based evaluation engine.
4544
4645
Args:
47-
snapshots: A collection of snapshots/ids.
46+
snapshots: A collection of snapshots.
4847
snapshot_evaluator: The snapshot evaluator to execute queries.
4948
state_sync: The state sync to pull saved snapshots.
5049
max_workers: The maximum number of parallel queries to run.
@@ -53,15 +52,15 @@ class Scheduler:
5352

5453
def __init__(
5554
self,
56-
snapshots: t.Iterable[SnapshotIdLike],
55+
snapshots: t.Iterable[Snapshot],
5756
snapshot_evaluator: SnapshotEvaluator,
5857
state_sync: StateSync,
5958
max_workers: int = 1,
6059
console: t.Optional[Console] = None,
6160
notification_target_manager: t.Optional[NotificationTargetManager] = None,
6261
):
6362
self.state_sync = state_sync
64-
self.snapshots = self.state_sync.get_snapshots(snapshots)
63+
self.snapshots = {s.snapshot_id: s for s in snapshots}
6564
self.snapshot_per_version = _resolve_one_snapshot_per_version(self.snapshots.values())
6665
self.snapshot_evaluator = snapshot_evaluator
6766
self.max_workers = max_workers
@@ -107,6 +106,8 @@ def batches(
107106
if selected_snapshots is not None:
108107
snapshots = [s for s in snapshots if s.name in selected_snapshots]
109108

109+
self.state_sync.refresh_snapshot_intervals(snapshots)
110+
110111
return compute_interval_params(
111112
snapshots,
112113
start=start or earliest_start_date(snapshots),

sqlmesh/core/snapshot/definition.py

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -691,11 +691,7 @@ def missing_intervals(
691691
Returns:
692692
A list of all the missing intervals as epoch timestamps.
693693
"""
694-
intervals = (
695-
self.dev_intervals
696-
if is_dev and self.is_forward_only and self.is_paused
697-
else self.intervals
698-
)
694+
intervals = self.dev_intervals if is_dev and self.is_paused_forward_only else self.intervals
699695

700696
if self.is_symbolic or (self.is_seed and intervals):
701697
return []
@@ -775,7 +771,6 @@ def categorize_as(self, category: SnapshotChangeCategory) -> None:
775771
Args:
776772
category: The change category to assign to this snapshot.
777773
"""
778-
assert not self.intervals # a snapshot that has been processed should not be recategorized
779774
is_forward_only = category in (
780775
SnapshotChangeCategory.FORWARD_ONLY,
781776
SnapshotChangeCategory.INDIRECT_NON_BREAKING,

sqlmesh/core/state_sync/base.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -275,6 +275,17 @@ def remove_interval(
275275
all_snapshots: All snapshots can be passed in to skip fetching matching snapshot versions.
276276
"""
277277

278+
@abc.abstractmethod
279+
def refresh_snapshot_intervals(self, snapshots: t.Collection[Snapshot]) -> t.List[Snapshot]:
280+
"""Updates given snapshots with latest intervals from the state.
281+
282+
Args:
283+
snapshots: The snapshots to refresh.
284+
285+
Returns:
286+
The updated snapshots.
287+
"""
288+
278289
@abc.abstractmethod
279290
def promote(self, environment: Environment, no_gaps: bool = False) -> PromotionResult:
280291
"""Update the environment to reflect the current state.

sqlmesh/core/state_sync/engine_adapter.py

Lines changed: 16 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -537,11 +537,12 @@ def remove_interval(
537537
snapshot_ids = ", ".join(str(s.snapshot_id) for s, _ in snapshot_intervals)
538538
logger.info("Removing interval for snapshots: %s", snapshot_ids)
539539

540-
self.engine_adapter.insert_append(
541-
self.intervals_table,
542-
_intervals_to_df(snapshot_intervals, is_dev=False, is_removed=True),
543-
columns_to_types=self._interval_columns_to_types,
544-
)
540+
for is_dev in (True, False):
541+
self.engine_adapter.insert_append(
542+
self.intervals_table,
543+
_intervals_to_df(snapshot_intervals, is_dev=is_dev, is_removed=True),
544+
columns_to_types=self._interval_columns_to_types,
545+
)
545546

546547
@transactional()
547548
def compact_intervals(self) -> None:
@@ -558,6 +559,16 @@ def compact_intervals(self) -> None:
558559
self.intervals_table, exp.column("id").isin(*interval_ids)
559560
)
560561

562+
def refresh_snapshot_intervals(self, snapshots: t.Collection[Snapshot]) -> t.List[Snapshot]:
563+
if not snapshots:
564+
return []
565+
566+
_, intervals = self._get_snapshot_intervals(snapshots)
567+
for s in snapshots:
568+
s.intervals = []
569+
s.dev_intervals = []
570+
return Snapshot.hydrate_with_intervals_by_version(snapshots, intervals)
571+
561572
def max_interval_end_for_environment(self, environment: str) -> t.Optional[int]:
562573
env = self._get_environment(environment)
563574
if not env:

sqlmesh/schedulers/airflow/plan.py

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,10 @@
88
from sqlmesh.core import scheduler
99
from sqlmesh.core.engine_adapter import EngineAdapter
1010
from sqlmesh.core.environment import Environment
11-
from sqlmesh.core.plan import can_evaluate_before_promote
11+
from sqlmesh.core.plan import (
12+
can_evaluate_before_promote,
13+
update_intervals_for_new_snapshots,
14+
)
1215
from sqlmesh.core.snapshot import SnapshotTableInfo
1316
from sqlmesh.core.state_sync import EngineAdapterStateSync, StateSync
1417
from sqlmesh.core.state_sync.base import DelegatingStateSync
@@ -94,6 +97,8 @@ def create_plan_dag_spec(
9497
"Make sure your code base is up to date and try re-creating the plan"
9598
)
9699

100+
update_intervals_for_new_snapshots(new_snapshots.values(), state_sync)
101+
97102
if request.environment.end_at:
98103
end = request.environment.end_at
99104
unpaused_dt = None
@@ -104,14 +109,17 @@ def create_plan_dag_spec(
104109
unpaused_dt = end
105110

106111
if request.restatements:
112+
intervals_to_remove = [
113+
(s, request.restatements[s.name])
114+
for s in all_snapshots.values()
115+
if s.name in request.restatements and s.snapshot_id not in new_snapshots
116+
]
107117
state_sync.remove_interval(
108-
[
109-
(s, request.restatements[s.name])
110-
for s in all_snapshots.values()
111-
if s.name in request.restatements and s.snapshot_id not in new_snapshots
112-
],
118+
intervals_to_remove,
113119
remove_shared_versions=not request.is_dev,
114120
)
121+
for s, interval in intervals_to_remove:
122+
all_snapshots[s.snapshot_id].remove_interval(interval)
115123

116124
if not request.skip_backfill:
117125
backfill_batches = scheduler.compute_interval_params(

sqlmesh/schedulers/airflow/state_sync.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -208,6 +208,19 @@ def remove_interval(
208208
"""
209209
raise NotImplementedError("Removing intervals is not supported by the Airflow state sync.")
210210

211+
def refresh_snapshot_intervals(self, snapshots: t.Collection[Snapshot]) -> t.List[Snapshot]:
212+
"""Updates given snapshots with latest intervals from the state.
213+
214+
Args:
215+
snapshots: The snapshots to refresh.
216+
217+
Returns:
218+
The updated snapshots.
219+
"""
220+
raise NotImplementedError(
221+
"Refreshing snapshot intervals is not supported by the Airflow state sync."
222+
)
223+
211224
def promote(self, environment: Environment, no_gaps: bool = False) -> PromotionResult:
212225
"""Update the environment to reflect the current state.
213226

0 commit comments

Comments
 (0)