Skip to content

Commit 13000d6

Browse files
authored
Fix: Introduce environment finalization (#764)
1 parent 5b1ce1b commit 13000d6

11 files changed

Lines changed: 127 additions & 42 deletions

File tree

sqlmesh/core/context.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -892,7 +892,7 @@ def _model_tables(self) -> t.Dict[str, str]:
892892

893893
def _context_diff(
894894
self,
895-
environment: str | Environment,
895+
environment: str,
896896
snapshots: t.Optional[t.Dict[str, Snapshot]] = None,
897897
create_from: t.Optional[str] = None,
898898
) -> ContextDiff:

sqlmesh/core/context_diff.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@
1313

1414
import typing as t
1515

16-
from sqlmesh.core.environment import Environment
1716
from sqlmesh.core.snapshot import (
1817
Snapshot,
1918
SnapshotChangeCategory,
@@ -38,6 +37,8 @@ class ContextDiff(PydanticModel):
3837
"""The environment to diff."""
3938
is_new_environment: bool
4039
"""Whether the target environment is new."""
40+
is_unfinalized_environment: bool
41+
"""Whether the currently stored environment record is in unfinalized state."""
4142
create_from: str
4243
"""The name of the environment the target environment will be created from if new."""
4344
added: t.Set[str]
@@ -56,7 +57,7 @@ class ContextDiff(PydanticModel):
5657
@classmethod
5758
def create(
5859
cls,
59-
environment: str | Environment,
60+
environment: str,
6061
snapshots: t.Dict[str, Snapshot],
6162
create_from: str,
6263
state_reader: StateReader,
@@ -73,12 +74,8 @@ def create(
7374
Returns:
7475
The ContextDiff object.
7576
"""
76-
if isinstance(environment, str):
77-
environment = environment.lower()
78-
env = state_reader.get_environment(environment)
79-
else:
80-
env = environment
81-
environment = env.name.lower()
77+
environment = environment.lower()
78+
env = state_reader.get_environment(environment)
8279

8380
if env is None:
8481
env = state_reader.get_environment(create_from.lower())
@@ -151,6 +148,7 @@ def create(
151148
return ContextDiff(
152149
environment=environment,
153150
is_new_environment=is_new_environment,
151+
is_unfinalized_environment=bool(env and not env.finalized_ts),
154152
create_from=create_from,
155153
added=added,
156154
removed=removed,
@@ -162,7 +160,9 @@ def create(
162160

163161
@property
164162
def has_changes(self) -> bool:
165-
return self.has_snapshot_changes or self.is_new_environment
163+
return (
164+
self.has_snapshot_changes or self.is_new_environment or self.is_unfinalized_environment
165+
)
166166

167167
@property
168168
def has_snapshot_changes(self) -> bool:

sqlmesh/core/environment.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ class Environment(PydanticModel):
2424
plan_id: str
2525
previous_plan_id: t.Optional[str]
2626
expiration_ts: t.Optional[int]
27+
finalized_ts: t.Optional[int]
2728

2829
@validator("snapshots", pre=True)
2930
@classmethod

sqlmesh/core/plan/evaluator.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -137,6 +137,7 @@ def on_complete(snapshot: SnapshotInfoLike) -> None:
137137
environment=environment.name,
138138
on_complete=on_complete,
139139
)
140+
self.state_sync.finalize(environment)
140141
completed = True
141142
finally:
142143
self.console.stop_promotion_progress(success=completed)

sqlmesh/core/state_sync/base.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -337,6 +337,15 @@ def promote(
337337
A tuple of (added snapshot table infos, removed snapshot table infos)
338338
"""
339339

340+
@abc.abstractmethod
341+
def finalize(self, environment: Environment) -> None:
342+
"""Finalize the target environment, indicating that this environment has been
343+
fully promoted and is ready for use.
344+
345+
Args:
346+
environment: The target environment to finalize.
347+
"""
348+
340349
@abc.abstractmethod
341350
def delete_expired_environments(self) -> t.List[Environment]:
342351
"""Removes expired environments.

sqlmesh/core/state_sync/common.py

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
SnapshotTableInfo,
1818
)
1919
from sqlmesh.core.state_sync.base import StateSync
20-
from sqlmesh.utils.date import TimeLike, now, to_datetime
20+
from sqlmesh.utils.date import TimeLike, now, now_timestamp, to_datetime
2121
from sqlmesh.utils.errors import SQLMeshError
2222

2323
logger = logging.getLogger(__name__)
@@ -129,6 +129,26 @@ def promote(
129129
self._update_environment(environment)
130130
return table_infos, [existing_table_infos[name] for name in missing_models]
131131

132+
@transactional()
133+
def finalize(self, environment: Environment) -> None:
134+
"""Finalize the target environment, indicating that this environment has been
135+
fully promoted and is ready for use.
136+
137+
Args:
138+
environment: The target environment to finalize.
139+
"""
140+
logger.info("Finalizing environment '%s'", environment)
141+
142+
stored_environment = self._get_environment(environment.name, lock_for_update=True)
143+
if stored_environment and stored_environment.plan_id != environment.plan_id:
144+
raise SQLMeshError(
145+
f"Plan '{environment.plan_id}' is no longer valid for the target environment '{environment.name}'. "
146+
f"Stored plan ID: '{stored_environment.plan_id}'. Please recreate the plan and try again"
147+
)
148+
149+
environment.finalized_ts = now_timestamp()
150+
self._update_environment(environment)
151+
132152
@transactional()
133153
def delete_expired_snapshots(self) -> t.List[Snapshot]:
134154
current_time = now()

sqlmesh/core/state_sync/engine_adapter.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,7 @@ def environment_columns_to_types(self) -> t.Dict[str, exp.DataType]:
9090
"plan_id": exp.DataType.build("text"),
9191
"previous_plan_id": exp.DataType.build("text"),
9292
"expiration_ts": exp.DataType.build("bigint"),
93+
"finalized_ts": exp.DataType.build("bigint"),
9394
}
9495

9596
@property
@@ -236,6 +237,7 @@ def _update_environment(self, environment: Environment) -> None:
236237
environment.plan_id,
237238
environment.previous_plan_id,
238239
environment.expiration_ts,
240+
environment.finalized_ts,
239241
)
240242
],
241243
columns_to_types=self.environment_columns_to_types,
Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
"""Add support for environment finalization."""
2+
from sqlglot import exp
3+
4+
5+
def migrate(state_sync): # type: ignore
6+
engine_adapter = state_sync.engine_adapter
7+
environments_table = state_sync.environments_table
8+
9+
alter_table_exp = exp.AlterTable(
10+
this=exp.to_table(environments_table),
11+
actions=[
12+
exp.ColumnDef(
13+
this=exp.to_column("finalized_ts"),
14+
kind=exp.DataType.build("bigint"),
15+
)
16+
],
17+
)
18+
19+
engine_adapter.execute(alter_table_exp)

sqlmesh/schedulers/airflow/dag_generator.py

Lines changed: 29 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -242,23 +242,34 @@ def _create_promotion_demotion_tasks(
242242
start_task = EmptyOperator(task_id="snapshot_promotion_start")
243243
end_task = EmptyOperator(task_id="snapshot_promotion_end")
244244

245+
environment = Environment(
246+
name=request.environment_name,
247+
snapshots=request.promoted_snapshots,
248+
start_at=request.start,
249+
end_at=request.end,
250+
plan_id=request.plan_id,
251+
previous_plan_id=request.previous_plan_id,
252+
expiration_ts=request.environment_expiration_ts,
253+
)
254+
245255
update_state_task = PythonOperator(
246256
task_id="snapshot_promotion__update_state",
247257
python_callable=promotion_update_state_task,
248258
op_kwargs={
249-
"snapshots": request.promoted_snapshots,
250-
"environment_name": request.environment_name,
251-
"start": request.start,
252-
"end": request.end,
259+
"environment": environment,
253260
"unpaused_dt": request.unpaused_dt,
254261
"no_gaps": request.no_gaps,
255-
"plan_id": request.plan_id,
256-
"previous_plan_id": request.previous_plan_id,
257-
"environment_expiration_ts": request.environment_expiration_ts,
258262
},
259263
)
260264

265+
finalize_task = PythonOperator(
266+
task_id="snapshot_promotion__finalize",
267+
python_callable=promotion_finalize_task,
268+
op_kwargs={"environment": environment},
269+
)
270+
261271
start_task >> update_state_task
272+
finalize_task >> end_task
262273

263274
if request.promoted_snapshots:
264275
create_views_task = self._create_snapshot_promotion_operator(
@@ -268,7 +279,7 @@ def _create_promotion_demotion_tasks(
268279
request.is_dev,
269280
"snapshot_promotion__create_views",
270281
)
271-
create_views_task >> end_task
282+
create_views_task >> finalize_task
272283

273284
if not request.is_dev and request.unpaused_dt:
274285
migrate_tables_task = self._create_snapshot_migrate_tables_operator(
@@ -289,10 +300,10 @@ def _create_promotion_demotion_tasks(
289300
"snapshot_promotion__delete_views",
290301
)
291302
update_state_task >> delete_views_task
292-
delete_views_task >> end_task
303+
delete_views_task >> finalize_task
293304

294305
if not request.promoted_snapshots and not request.demoted_snapshots:
295-
update_state_task >> end_task
306+
update_state_task >> finalize_task
296307

297308
return (start_task, end_task)
298309

@@ -479,26 +490,16 @@ def creation_update_state_task(new_snapshots: t.Iterable[Snapshot]) -> None:
479490

480491

481492
def promotion_update_state_task(
482-
snapshots: t.List[SnapshotTableInfo],
483-
environment_name: str,
484-
start: TimeLike,
485-
end: t.Optional[TimeLike],
493+
environment: Environment,
486494
unpaused_dt: t.Optional[TimeLike],
487495
no_gaps: bool,
488-
plan_id: str,
489-
previous_plan_id: t.Optional[str],
490-
environment_expiration_ts: t.Optional[int],
491496
) -> None:
492-
environment = Environment(
493-
name=environment_name,
494-
snapshots=snapshots,
495-
start_at=start,
496-
end_at=end,
497-
plan_id=plan_id,
498-
previous_plan_id=previous_plan_id,
499-
expiration_ts=environment_expiration_ts,
500-
)
501497
with util.scoped_state_sync() as state_sync:
502498
state_sync.promote(environment, no_gaps=no_gaps)
503-
if snapshots and not end and unpaused_dt:
504-
state_sync.unpause_snapshots(snapshots, unpaused_dt)
499+
if environment.snapshots and not environment.end_at and unpaused_dt:
500+
state_sync.unpause_snapshots(environment.snapshots, unpaused_dt)
501+
502+
503+
def promotion_finalize_task(environment: Environment) -> None:
504+
with util.scoped_state_sync() as state_sync:
505+
state_sync.finalize(environment)

tests/core/test_state_sync.py

Lines changed: 34 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import typing as t
22

3+
import duckdb
34
import pandas as pd
45
import pytest
56
from pytest_mock.plugin import MockerFixture
@@ -445,6 +446,34 @@ def test_promote_snapshots_no_gaps(state_sync: EngineAdapterStateSync, make_snap
445446
promote_snapshots(state_sync, [new_snapshot_same_interval], "prod", no_gaps=True)
446447

447448

449+
def test_finalize(state_sync: EngineAdapterStateSync, make_snapshot: t.Callable):
450+
snapshot_a = make_snapshot(
451+
SqlModel(
452+
name="a",
453+
query=parse_one("select 1, ds"),
454+
),
455+
)
456+
snapshot_a.categorize_as(SnapshotChangeCategory.BREAKING)
457+
458+
state_sync.push_snapshots([snapshot_a])
459+
promote_snapshots(state_sync, [snapshot_a], "prod")
460+
461+
env = state_sync.get_environment("prod")
462+
assert env
463+
state_sync.finalize(env)
464+
465+
env = state_sync.get_environment("prod")
466+
assert env
467+
assert env.finalized_ts is not None
468+
469+
env.plan_id = "different_plan_id"
470+
with pytest.raises(
471+
SQLMeshError,
472+
match=r"Plan 'different_plan_id' is no longer valid for the target environment 'prod'.*",
473+
):
474+
state_sync.finalize(env)
475+
476+
448477
def test_start_date_gap(state_sync: EngineAdapterStateSync, make_snapshot: t.Callable):
449478
model = SqlModel(
450479
name="a",
@@ -569,8 +598,8 @@ def test_get_version(state_sync: EngineAdapterStateSync) -> None:
569598
schema_version=SCHEMA_VERSION, sqlglot_version=SQLGLOT_VERSION
570599
)
571600

572-
# old install does not have this table / row
573-
delete_versions(state_sync)
601+
# Start with a clean slate.
602+
state_sync = EngineAdapterStateSync(create_engine_adapter(duckdb.connect, "duckdb"))
574603

575604
with pytest.raises(
576605
SQLMeshError,
@@ -626,7 +655,9 @@ def test_migrate(state_sync: EngineAdapterStateSync, mocker: MockerFixture) -> N
626655
state_sync.migrate()
627656
mock.assert_not_called()
628657

629-
delete_versions(state_sync)
658+
# Start with a clean slate.
659+
state_sync = EngineAdapterStateSync(create_engine_adapter(duckdb.connect, "duckdb"))
660+
630661
state_sync.migrate()
631662
mock.assert_called_once()
632663
assert state_sync.get_versions() == Versions(

0 commit comments

Comments
 (0)