Skip to content

Commit 8846ff4

Browse files
authored
Refactor!: Replace snapshot model with snapshot node (#1137)
1 parent 7e6af83 commit 8846ff4

31 files changed

Lines changed: 581 additions & 418 deletions

sqlmesh/core/audit/definition.py

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -10,7 +10,7 @@
1010
from sqlmesh.core import constants as c
1111
from sqlmesh.core import dialect as d
1212
from sqlmesh.core.model.common import bool_validator, expression_validator
13-
from sqlmesh.core.model.definition import Model, _Model
13+
from sqlmesh.core.model.definition import _Model
1414
from sqlmesh.core.renderer import QueryRenderer
1515
from sqlmesh.utils.date import TimeLike
1616
from sqlmesh.utils.errors import AuditConfigError, SQLMeshError, raise_config_error
@@ -160,7 +160,7 @@ def load_multiple(
160160

161161
def render_query(
162162
self,
163-
snapshot_or_model: t.Union[Snapshot, Model],
163+
snapshot_or_model: t.Union[Snapshot, _Model],
164164
*,
165165
start: t.Optional[TimeLike] = None,
166166
end: t.Optional[TimeLike] = None,
@@ -238,7 +238,7 @@ def macro_definitions(self) -> t.List[d.MacroDef]:
238238
"""All macro definitions from the list of expressions."""
239239
return [s for s in self.expressions if isinstance(s, d.MacroDef)]
240240

241-
def _create_query_renderer(self, model: Model) -> QueryRenderer:
241+
def _create_query_renderer(self, model: _Model) -> QueryRenderer:
242242
return QueryRenderer(
243243
self.query,
244244
self.dialect or model.dialect,

sqlmesh/core/context.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -543,7 +543,7 @@ def snapshots(self) -> t.Dict[str, Snapshot]:
543543

544544
snapshot = Snapshot.from_model(
545545
model,
546-
models=models,
546+
nodes=models,
547547
audits=audits,
548548
cache=fingerprint_cache,
549549
ttl=ttl,
@@ -555,7 +555,7 @@ def snapshots(self) -> t.Dict[str, Snapshot]:
555555

556556
for snapshot in stored_snapshots.values():
557557
# Keep the original model instance to preserve the query cache.
558-
snapshot.model = snapshots[snapshot.name].model
558+
snapshot.node = snapshots[snapshot.name].node
559559

560560
return {name: stored_snapshots.get(s.snapshot_id, s) for name, s in snapshots.items()}
561561

sqlmesh/core/context_diff.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -127,7 +127,7 @@ def create(
127127

128128
if existing:
129129
# Keep the original model instance to preserve the query cache.
130-
existing.model = snapshot.model
130+
existing.node = snapshot.node
131131

132132
merged_snapshots[name] = existing.copy()
133133
if modified:

sqlmesh/core/engine_adapter/base.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -39,7 +39,7 @@
3939
Query,
4040
QueryOrDF,
4141
)
42-
from sqlmesh.core.model.meta import IntervalUnit
42+
from sqlmesh.core.node import IntervalUnit
4343

4444
logger = logging.getLogger(__name__)
4545

sqlmesh/core/engine_adapter/bigquery.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -16,7 +16,7 @@
1616
DataObjectType,
1717
TransactionType,
1818
)
19-
from sqlmesh.core.model.meta import IntervalUnit
19+
from sqlmesh.core.node import IntervalUnit
2020
from sqlmesh.core.schema_diff import SchemaDiffer
2121
from sqlmesh.utils.date import to_datetime
2222
from sqlmesh.utils.errors import SQLMeshError

sqlmesh/core/engine_adapter/spark.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -22,7 +22,7 @@
2222
PySparkSession,
2323
QueryOrDF,
2424
)
25-
from sqlmesh.core.model.meta import IntervalUnit
25+
from sqlmesh.core.node import IntervalUnit
2626

2727

2828
class SparkEngineAdapter(EngineAdapter):

sqlmesh/core/model/definition.py

Lines changed: 155 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -1,6 +1,7 @@
11
from __future__ import annotations
22

33
import ast
4+
import json
45
import logging
56
import sys
67
import types
@@ -23,13 +24,20 @@
2324
from sqlmesh.core import dialect as d
2425
from sqlmesh.core.macros import MacroRegistry, macro
2526
from sqlmesh.core.model.common import expression_validator
26-
from sqlmesh.core.model.kind import ModelKindName, SeedKind, _Incremental
27+
from sqlmesh.core.model.kind import (
28+
IncrementalByTimeRangeKind,
29+
IncrementalByUniqueKeyKind,
30+
ModelKindName,
31+
SeedKind,
32+
_Incremental,
33+
)
2734
from sqlmesh.core.model.meta import ModelMeta
2835
from sqlmesh.core.model.seed import Seed, create_seed
2936
from sqlmesh.core.renderer import ExpressionRenderer, QueryRenderer
3037
from sqlmesh.utils import str_to_bool
3138
from sqlmesh.utils.date import TimeLike, make_inclusive, to_datetime
3239
from sqlmesh.utils.errors import ConfigError, SQLMeshError, raise_config_error
40+
from sqlmesh.utils.hashing import hash_data
3341
from sqlmesh.utils.jinja import JinjaMacroRegistry, extract_macro_references
3442
from sqlmesh.utils.metaprogramming import (
3543
Executable,
@@ -48,9 +56,9 @@
4856
from sqlmesh.utils.jinja import MacroReference
4957

5058
if sys.version_info >= (3, 9):
51-
from typing import Annotated, Literal
59+
from typing import Literal
5260
else:
53-
from typing_extensions import Annotated, Literal
61+
from typing_extensions import Literal
5462

5563
logger = logging.getLogger(__name__)
5664

@@ -628,6 +636,103 @@ def is_breaking_change(self, previous: Model) -> t.Optional[bool]:
628636
"""
629637
raise NotImplementedError
630638

639+
@property
640+
def data_hash(self) -> str:
641+
"""
642+
Computes the data hash for the node.
643+
644+
Returns:
645+
The data hash for the node.
646+
"""
647+
return hash_data(self._data_hash_fields)
648+
649+
@property
650+
def _data_hash_fields(self) -> t.List[str]:
651+
data = [
652+
str(self.sorted_python_env),
653+
self.kind.name,
654+
self.cron,
655+
self.storage_format,
656+
str(self.lookback),
657+
*(expr.sql() for expr in (self.partitioned_by or [])),
658+
*(self.clustered_by or []),
659+
self.stamp,
660+
]
661+
662+
for column_name, column_type in (self.columns_to_types_ or {}).items():
663+
data.append(column_name)
664+
data.append(column_type.sql())
665+
666+
if isinstance(self.kind, IncrementalByTimeRangeKind):
667+
data.append(self.kind.time_column.column)
668+
data.append(self.kind.time_column.format)
669+
elif isinstance(self.kind, IncrementalByUniqueKeyKind):
670+
data.extend(self.kind.unique_key)
671+
672+
return data # type: ignore
673+
674+
def metadata_hash(self, audits: t.Dict[str, Audit]) -> str:
675+
"""
676+
Computes the metadata hash for the node.
677+
678+
Args:
679+
audits: Available audits by name.
680+
681+
Returns:
682+
The metadata hash for the node.
683+
"""
684+
from sqlmesh.core.audit import BUILT_IN_AUDITS
685+
686+
metadata = [
687+
self.dialect,
688+
self.owner,
689+
self.description,
690+
str(self.start) if self.start else None,
691+
str(self.retention) if self.retention else None,
692+
str(self.batch_size) if self.batch_size is not None else None,
693+
json.dumps(self.mapping_schema, sort_keys=True),
694+
*sorted(self.tags),
695+
*sorted(self.grain),
696+
str(self.forward_only),
697+
str(self.disable_restatement),
698+
]
699+
700+
for audit_name, audit_args in sorted(self.audits, key=lambda a: a[0]):
701+
metadata.append(audit_name)
702+
703+
if audit_name in BUILT_IN_AUDITS:
704+
for arg_name, arg_value in audit_args.items():
705+
metadata.append(arg_name)
706+
metadata.append(arg_value.sql(comments=True))
707+
elif audit_name in audits:
708+
audit = audits[audit_name]
709+
query = (
710+
audit.query
711+
if self.hash_raw_query
712+
else audit.render_query(self, **t.cast(t.Dict[str, t.Any], audit_args))
713+
or audit.query
714+
)
715+
metadata.extend(
716+
[
717+
query.sql(comments=True),
718+
audit.dialect,
719+
str(audit.skip),
720+
str(audit.blocking),
721+
]
722+
)
723+
else:
724+
raise SQLMeshError(f"Unexpected audit name '{audit_name}'.")
725+
726+
# Add comments from the query.
727+
if self.is_sql:
728+
rendered_query = self.render_query()
729+
if rendered_query:
730+
for e, _, _ in rendered_query.walk():
731+
if e.comments:
732+
metadata.extend(e.comments)
733+
734+
return hash_data(metadata)
735+
631736

632737
class _SqlBasedModel(_Model):
633738
pre_statements_: t.Optional[t.List[exp.Expression]] = Field(
@@ -728,6 +833,20 @@ def _statement_renderer(self, expression: exp.Expression) -> ExpressionRenderer:
728833
)
729834
return self.__statement_renderers[expression_key]
730835

836+
@property
837+
def _data_hash_fields(self) -> t.List[str]:
838+
pre_statements = (
839+
self.pre_statements if self.hash_raw_query else self.render_pre_statements()
840+
)
841+
post_statements = (
842+
self.post_statements if self.hash_raw_query else self.render_post_statements()
843+
)
844+
macro_defs = self.macro_definitions if self.hash_raw_query else []
845+
return [
846+
*super()._data_hash_fields,
847+
*[e.sql(comments=False) for e in (*pre_statements, *post_statements, *macro_defs)],
848+
]
849+
731850

732851
class SqlModel(_SqlBasedModel):
733852
"""The model definition which relies on a SQL query to fetch the data.
@@ -926,6 +1045,24 @@ def _query_renderer(self) -> QueryRenderer:
9261045
)
9271046
return self.__query_renderer
9281047

1048+
@property
1049+
def _data_hash_fields(self) -> t.List[str]:
1050+
data = super()._data_hash_fields
1051+
1052+
query = self.query if self.hash_raw_query else self.render_query() or self.query
1053+
data.append(query.sql(comments=False))
1054+
1055+
for macro_name, macro in sorted(self.jinja_macros.root_macros.items()):
1056+
data.append(macro_name)
1057+
data.append(macro.definition)
1058+
1059+
for _, package in sorted(self.jinja_macros.packages.items(), key=lambda x: x[0]):
1060+
for macro_name, macro in sorted(package.items(), key=lambda x: x[0]):
1061+
data.append(macro_name)
1062+
data.append(macro.definition)
1063+
1064+
return data
1065+
9291066
def __repr__(self) -> str:
9301067
return f"Model<name: {self.name}, query: {self.query.sql(dialect=self.dialect)[0:30]}>"
9311068

@@ -1087,6 +1224,14 @@ def _ensure_hydrated(self) -> None:
10871224
if not self.is_hydrated:
10881225
raise SQLMeshError(f"Seed model '{self.name}' is not hydrated.")
10891226

1227+
@property
1228+
def _data_hash_fields(self) -> t.List[str]:
1229+
data = super()._data_hash_fields
1230+
for column_name, column_hash in self.column_hashes.items():
1231+
data.append(column_name)
1232+
data.append(column_hash)
1233+
return data
1234+
10901235
def __repr__(self) -> str:
10911236
return f"Model<name: {self.name}, seed: {self.kind.path}>"
10921237

@@ -1139,6 +1284,12 @@ def is_python(self) -> bool:
11391284
def is_breaking_change(self, previous: Model) -> t.Optional[bool]:
11401285
return None
11411286

1287+
@property
1288+
def _data_hash_fields(self) -> t.List[str]:
1289+
data = super()._data_hash_fields
1290+
data.append(self.entrypoint)
1291+
return data
1292+
11421293
def __repr__(self) -> str:
11431294
return f"Model<name: {self.name}, entrypoint: {self.entrypoint}>"
11441295

@@ -1156,9 +1307,7 @@ def is_breaking_change(self, previous: Model) -> t.Optional[bool]:
11561307
return None
11571308

11581309

1159-
Model = Annotated[
1160-
t.Union[SqlModel, SeedModel, PythonModel, ExternalModel], Field(discriminator="source_type")
1161-
]
1310+
Model = t.Union[SqlModel, SeedModel, PythonModel, ExternalModel]
11621311

11631312

11641313
def load_model(

0 commit comments

Comments
 (0)