Skip to content

Commit f4dec46

Browse files
authored
Fix!: Make sure internal models are not accidentally converted to external ones (#926)
1 parent 3e41902 commit f4dec46

File tree

14 files changed

+206
-46
lines changed

14 files changed

+206
-46
lines changed

sqlmesh/core/context.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -853,6 +853,7 @@ def create_external_models(self) -> None:
853853
if self.config_for_model(model) is config
854854
},
855855
adapter=self._engine_adapter,
856+
state_reader=self.state_reader,
856857
dialect=config.model_defaults.dialect,
857858
max_workers=self.concurrent_tasks,
858859
)

sqlmesh/core/model/definition.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -348,8 +348,12 @@ def text_diff(self, other: Model) -> str:
348348
Returns:
349349
A unified text diff showing additions and deletions.
350350
"""
351-
meta_a, *statements_a, query_a = self.render_definition()
352-
meta_b, *statements_b, query_b = other.render_definition()
351+
meta_a, *statements_a = self.render_definition()
352+
meta_b, *statements_b = other.render_definition()
353+
354+
query_a = statements_a.pop() if statements_a else None
355+
query_b = statements_b.pop() if statements_b else None
356+
353357
return "\n".join(
354358
(
355359
d.text_diff(meta_a, meta_b, self.dialect),

sqlmesh/core/plan/evaluator.py

Lines changed: 14 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -159,17 +159,20 @@ def _restate(self, plan: Plan) -> None:
159159
if not plan.restatements:
160160
return
161161

162-
all_snapshots = (
163-
[s for s in plan.snapshots if s.name in plan.restatements]
164-
if plan.is_dev
165-
else self.state_sync.get_snapshots_by_models(*plan.restatements)
166-
)
167-
self.state_sync.remove_interval(
168-
[],
169-
start=plan.start,
170-
end=plan.end,
171-
all_snapshots=all_snapshots,
172-
)
162+
target_snapshots = [s for s in plan.snapshots if s.name in plan.restatements]
163+
if plan.is_dev:
164+
self.state_sync.remove_interval(
165+
[],
166+
start=plan.start,
167+
end=plan.end,
168+
all_snapshots=target_snapshots,
169+
)
170+
else:
171+
self.state_sync.remove_interval(
172+
target_snapshots,
173+
start=plan.start,
174+
end=plan.end,
175+
)
173176

174177

175178
class AirflowPlanEvaluator(PlanEvaluator):

sqlmesh/core/schema_loader.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from __future__ import annotations
22

3+
import logging
34
import typing as t
45
from concurrent.futures import ThreadPoolExecutor
56
from pathlib import Path
@@ -9,12 +10,16 @@
910

1011
from sqlmesh.core.engine_adapter import EngineAdapter
1112
from sqlmesh.core.model import Model
13+
from sqlmesh.core.state_sync import StateReader
14+
15+
logger = logging.getLogger(__name__)
1216

1317

1418
def create_schema_file(
1519
path: Path,
1620
models: t.Dict[str, Model],
1721
adapter: EngineAdapter,
22+
state_reader: StateReader,
1823
dialect: DialectType,
1924
max_workers: int = 1,
2025
) -> None:
@@ -24,13 +29,24 @@ def create_schema_file(
2429
path: The path to store the YAML file.
2530
models: A dictionary of models to fetch columns from the db.
2631
adapter: The engine adapter.
32+
state_reader: The state reader.
2733
dialect: The dialect to serialize the schema as.
2834
max_workers: The max concurrent workers to fetch columns.
2935
"""
3036
external_tables = {
3137
dep for model in models.values() for dep in model.depends_on if dep not in models
3238
}
3339

40+
# Make sure we don't convert internal models into external ones.
41+
existing_models = state_reader.models_exist(external_tables, exclude_external=True)
42+
if existing_models:
43+
logger.warning(
44+
"The following models already exist and can't be converted to external: %s."
45+
"Perhaps these models have been removed, while downstream models that reference them weren't updated accordingly",
46+
", ".join(existing_models),
47+
)
48+
external_tables -= existing_models
49+
3450
with ThreadPoolExecutor(max_workers=max_workers) as pool:
3551
schemas = [
3652
{

sqlmesh/core/state_sync/base.py

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,18 @@ def snapshots_exist(self, snapshot_ids: t.Iterable[SnapshotIdLike]) -> t.Set[Sna
7676
A set of all the existing snapshot ids.
7777
"""
7878

79+
@abc.abstractmethod
def models_exist(self, names: t.Iterable[str], exclude_external: bool = False) -> t.Set[str]:
    """Reports which of the given model names are present in the state sync.

    Args:
        names: The model names to look up.
        exclude_external: When True, external models are left out of the result.

    Returns:
        The subset of the given names that exist.
    """
90+
7991
@abc.abstractmethod
8092
def get_environment(self, environment: str) -> t.Optional[Environment]:
8193
"""Fetches the environment if it exists.
@@ -95,14 +107,6 @@ def get_environments(self) -> t.List[Environment]:
95107
A list of all environments.
96108
"""
97109

98-
@abc.abstractmethod
99-
def get_snapshots_by_models(self, *names: str) -> t.List[Snapshot]:
100-
"""Get all snapshots by model name.
101-
102-
Returns:
103-
The list of snapshots.
104-
"""
105-
106110
@abc.abstractmethod
107111
def get_snapshot_intervals(
108112
self, snapshots: t.Optional[t.Iterable[SnapshotNameVersionLike]]

sqlmesh/core/state_sync/common.py

Lines changed: 0 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -49,21 +49,6 @@ def get_snapshots(
4949
def get_environment(self, environment: str) -> t.Optional[Environment]:
5050
return self._get_environment(environment)
5151

52-
def get_snapshots_by_models(
53-
self, *names: str, lock_for_update: bool = False
54-
) -> t.List[Snapshot]:
55-
"""
56-
Get all snapshots by model name.
57-
58-
Returns:
59-
The list of snapshots.
60-
"""
61-
return [
62-
snapshot
63-
for snapshot in self._get_snapshots(lock_for_update=lock_for_update).values()
64-
if snapshot.name in names
65-
]
66-
6752
@transactional()
6853
def promote(
6954
self, environment: Environment, no_gaps: bool = False

sqlmesh/core/state_sync/engine_adapter.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@
3030
from sqlmesh.core.audit import Audit
3131
from sqlmesh.core.engine_adapter import EngineAdapter, TransactionType
3232
from sqlmesh.core.environment import Environment
33-
from sqlmesh.core.model import Model, SeedModel
33+
from sqlmesh.core.model import Model, ModelKindName, SeedModel
3434
from sqlmesh.core.snapshot import (
3535
Intervals,
3636
Snapshot,
@@ -87,6 +87,7 @@ def __init__(
8787
"identifier": exp.DataType.build("text"),
8888
"version": exp.DataType.build("text"),
8989
"snapshot": exp.DataType.build("text"),
90+
"kind_name": exp.DataType.build("text"),
9091
}
9192

9293
self._environment_columns_to_types = {
@@ -239,6 +240,17 @@ def snapshots_exist(self, snapshot_ids: t.Iterable[SnapshotIdLike]) -> t.Set[Sna
239240
)
240241
}
241242

243+
def models_exist(self, names: t.Iterable[str], exclude_external: bool = False) -> t.Set[str]:
    """Returns the given model names that have at least one row in the snapshots table.

    Args:
        names: Model names to look up.
        exclude_external: When True, rows whose ``kind_name`` equals the external
            model kind are not counted as existing.

    Returns:
        The distinct model names found. Empty when ``names`` is empty.
    """
    # Materialize once: ``names`` may be a one-shot iterator, and it is needed both
    # for the emptiness check and for building the IN list below.
    name_list = list(names)
    if not name_list:
        # With nothing to look up the filter would be built from an empty IN list,
        # which is not valid SQL on many engines. Nothing could match anyway.
        return set()

    query = (
        exp.select("name")
        .from_(self.snapshots_table)
        .where(exp.column("name").isin(*name_list))
        .distinct()
    )
    if exclude_external:
        query = query.where(exp.column("kind_name").neq(ModelKindName.EXTERNAL.value))
    # Each fetched row is a 1-tuple holding the model name.
    return {name for name, in self.engine_adapter.fetchall(query)}
253+
242254
def reset(self) -> None:
243255
"""Resets the state store to the state when it was first initialized."""
244256
self.engine_adapter.drop_table(self.snapshots_table)
@@ -821,6 +833,7 @@ def _snapshots_to_df(snapshots: t.Iterable[Snapshot]) -> pd.DataFrame:
821833
"identifier": snapshot.identifier,
822834
"version": snapshot.version,
823835
"snapshot": snapshot.json(exclude={"intervals", "dev_intervals"}),
836+
"kind_name": snapshot.model_kind_name.value,
824837
}
825838
for snapshot in snapshots
826839
]
Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
"""Add the kind_name column to the snapshots table."""
2+
import json
3+
4+
import pandas as pd
5+
from sqlglot import exp
6+
7+
8+
def migrate(state_sync):  # type: ignore
    """Adds a ``kind_name`` text column to the snapshots table and backfills it.

    The value for every existing row is taken from the ``model.kind.name`` field
    of that row's serialized snapshot JSON.

    Args:
        state_sync: Object exposing ``engine_adapter`` and ``schema``.
    """
    adapter = state_sync.engine_adapter
    table_name = f"{state_sync.schema}._snapshots"

    # Step 1: extend the table with the new column.
    adapter.execute(
        exp.AlterTable(
            this=exp.to_table(table_name),
            actions=[
                exp.ColumnDef(
                    this=exp.to_column("kind_name"),
                    kind=exp.DataType.build("text"),
                )
            ],
        )
    )

    # Step 2: read every stored row and derive its kind name from the JSON payload.
    select_all = exp.select("name", "identifier", "version", "snapshot").from_(table_name)
    backfilled_rows = [
        {
            "name": name,
            "identifier": identifier,
            "version": version,
            "snapshot": snapshot,
            "kind_name": json.loads(snapshot)["model"]["kind"]["name"],
        }
        for name, identifier, version, snapshot in adapter.fetchall(select_all)
    ]

    # Nothing stored means nothing to rewrite.
    if not backfilled_rows:
        return

    # Step 3: swap the table contents for the backfilled rows.
    adapter.delete_from(table_name, "TRUE")
    column_names = ("name", "identifier", "version", "snapshot", "kind_name")
    adapter.insert_append(
        table_name,
        pd.DataFrame(backfilled_rows),
        # A separate type object is built for every column.
        columns_to_types={column: exp.DataType.build("text") for column in column_names},
        contains_json=True,
    )

sqlmesh/schedulers/airflow/api.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -101,6 +101,17 @@ def get_snapshot_intervals() -> Response:
101101
)
102102

103103

104+
@sqlmesh_api_v1.get("/models")
@csrf.exempt
@check_authentication
def models_exist() -> Response:
    """Handles ``GET /models``: reports which of the requested model names exist.

    Query arguments:
        names: Comma-separated model names to check.
        exclude_external: Flag; its mere presence in the query string enables
            the exclusion of external models.
    """
    with util.scoped_state_sync() as state_sync:
        requested_names = _csv_arg("names")
        # Only the presence of the flag matters, not its value.
        skip_external = "exclude_external" in request.args
        found = state_sync.models_exist(requested_names, exclude_external=skip_external)
        payload = common.ExistingModelsResponse(names=list(found))
        return _success(payload)
113+
114+
104115
@sqlmesh_api_v1.get("/versions")
105116
@csrf.exempt
106117
@check_authentication
@@ -138,3 +149,9 @@ def _snapshot_name_versions_from_request() -> t.Optional[t.List[SnapshotNameVers
138149

139150
raw_versions = json.loads(request.args["versions"])
140151
return [SnapshotNameVersion.parse_obj(v) for v in raw_versions]
152+
153+
154+
def _csv_arg(arg: str) -> t.List[str]:
    """Splits a comma-separated query argument of the current request into trimmed values.

    Args:
        arg: Name of the query-string argument.

    Returns:
        The whitespace-stripped pieces of the argument's value, or an empty list
        when the argument is absent. An argument that is present but empty
        yields ``[""]``.
    """
    if arg in request.args:
        return [piece.strip() for piece in request.args[arg].split(",")]
    return []

sqlmesh/schedulers/airflow/client.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@
3535
SNAPSHOTS_PATH = f"{common.SQLMESH_API_BASE_PATH}/snapshots"
3636
SEEDS_PATH = f"{common.SQLMESH_API_BASE_PATH}/seeds"
3737
INTERVALS_PATH = f"{common.SQLMESH_API_BASE_PATH}/intervals"
38+
MODELS_PATH = f"{common.SQLMESH_API_BASE_PATH}/models"
3839
VERSIONS_PATH = f"{common.SQLMESH_API_BASE_PATH}/versions"
3940

4041

@@ -105,6 +106,14 @@ def snapshots_exist(self, snapshot_ids: t.List[SnapshotId]) -> t.Set[SnapshotId]
105106
).snapshot_ids
106107
)
107108

109+
def models_exist(self, names: t.Iterable[str], exclude_external: bool = False) -> t.Set[str]:
    """Asks the models endpoint which of the given model names exist.

    Args:
        names: Model names to check; sent as one comma-separated ``names`` value.
        exclude_external: When True, the ``exclude_external`` flag is added to
            the request.

    Returns:
        The names listed in the parsed response.
    """
    flags = ("exclude_external",) if exclude_external else ()
    raw_response = self._get(MODELS_PATH, *flags, names=",".join(names))
    parsed = common.ExistingModelsResponse.parse_obj(raw_response)
    return set(parsed.names)
116+
108117
def get_snapshot_intervals(
109118
self, snapshot_name_versions: t.List[SnapshotNameVersion]
110119
) -> t.List[SnapshotIntervals]:

0 commit comments

Comments
 (0)