Skip to content

Commit c637908

Browse files
authored
Feat: expose a macro var to check if we're evaluating snapshots (#1562)
* Feat: expose a macro var to check if we're evaluating snapshots
* PR feedback
* Fixup
* PR comments
* Make evaluating an optional in _cache_key again
* Refactor to use enum, address PR feedback, add unit test
* Rename unit test
1 parent d753b59 commit c637908

4 files changed

Lines changed: 67 additions & 10 deletions

File tree

sqlmesh/core/macros.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from __future__ import annotations
22

33
import typing as t
4+
from enum import Enum
45
from functools import reduce
56
from string import Template
67

@@ -29,6 +30,12 @@
2930
SQLMESH_MOCKED_STAR = "__SQLMESH_MOCKED_STAR__"
3031

3132

33+
class RuntimeStage(Enum):
34+
LOADING = "loading"
35+
CREATING = "creating"
36+
EVALUATING = "evaluating"
37+
38+
3239
class MacroStrTemplate(Template):
3340
delimiter = SQLMESH_MACRO_PREFIX
3441

@@ -103,10 +110,11 @@ def __init__(
103110
python_env: t.Optional[t.Dict[str, Executable]] = None,
104111
jinja_env: t.Optional[Environment] = None,
105112
schema: t.Optional[t.Dict[str, t.Any]] = None,
113+
runtime_stage: RuntimeStage = RuntimeStage.LOADING,
106114
):
107115
self.dialect = dialect
108116
self.generator = MacroDialect().generator()
109-
self.locals: t.Dict[str, t.Any] = {}
117+
self.locals: t.Dict[str, t.Any] = {"runtime_stage": runtime_stage}
110118
self.env = {**ENV, "self": self}
111119
self.python_env = python_env or {}
112120
self._jinja_env: t.Optional[Environment] = jinja_env

sqlmesh/core/renderer.py

Lines changed: 17 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818

1919
from sqlmesh.core import constants as c
2020
from sqlmesh.core import dialect as d
21-
from sqlmesh.core.macros import MacroEvaluator
21+
from sqlmesh.core.macros import MacroEvaluator, RuntimeStage
2222
from sqlmesh.utils.date import TimeLike, date_dict, make_inclusive, to_datetime
2323
from sqlmesh.utils.errors import (
2424
ConfigError,
@@ -34,6 +34,8 @@
3434

3535
from sqlmesh.core.snapshot import Snapshot
3636

37+
CacheKey = t.Tuple[datetime, datetime, datetime, RuntimeStage]
38+
3739

3840
logger = logging.getLogger(__name__)
3941

@@ -59,7 +61,7 @@ def __init__(
5961
self._only_execution_time = only_execution_time
6062
self.schema = {} if schema is None else schema
6163

62-
self._cache: t.Dict[t.Tuple[datetime, datetime, datetime], t.List[exp.Expression]] = {}
64+
self._cache: t.Dict[CacheKey, t.List[exp.Expression]] = {}
6365

6466
def _render(
6567
self,
@@ -69,6 +71,7 @@ def _render(
6971
snapshots: t.Optional[t.Dict[str, Snapshot]] = None,
7072
table_mapping: t.Optional[t.Dict[str, str]] = None,
7173
is_dev: bool = False,
74+
runtime_stage: t.Optional[RuntimeStage] = None,
7275
**kwargs: t.Any,
7376
) -> t.List[exp.Expression]:
7477
"""Renders an expression, expanding macros with provided kwargs
@@ -77,18 +80,19 @@ def _render(
7780
start: The start datetime to render. Defaults to epoch start.
7881
end: The end datetime to render. Defaults to epoch start.
7982
execution_time: The date/time reference to use for execution time.
80-
kwargs: Additional kwargs to pass to the renderer.
8183
snapshots: All upstream snapshots (by model name) to use for expansion and mapping of physical locations.
8284
table_mapping: Table mapping of physical locations. Takes precedence over snapshot mappings.
8385
is_dev: Indicates whether the rendering happens in the development mode and temporary
8486
tables / table clones should be used where applicable.
87+
runtime_stage: Indicates the current runtime stage, for example if we're still loading the project, etc.
88+
kwargs: Additional kwargs to pass to the renderer.
8589
8690
Returns:
8791
The rendered expressions.
8892
"""
8993

90-
cache_key = self._cache_key(start, end, execution_time)
91-
start_dt, end_dt, execution_dt = cache_key
94+
cache_key = self._cache_key(start, end, execution_time, runtime_stage)
95+
start_dt, end_dt, execution_dt, runtime_stage = cache_key
9296
if cache_key not in self._cache:
9397
expressions = [self._expression]
9498

@@ -132,6 +136,7 @@ def _render(
132136
python_env=self._python_env,
133137
jinja_env=jinja_env,
134138
schema=self.schema,
139+
runtime_stage=runtime_stage,
135140
)
136141

137142
for definition in self._macro_definitions:
@@ -203,10 +208,12 @@ def _cache_key(
203208
start: t.Optional[TimeLike] = None,
204209
end: t.Optional[TimeLike] = None,
205210
execution_time: t.Optional[TimeLike] = None,
206-
) -> t.Tuple[datetime, datetime, datetime]:
211+
runtime_stage: t.Optional[RuntimeStage] = None,
212+
) -> CacheKey:
207213
return (
208214
*make_inclusive(start or c.EPOCH, end or c.EPOCH),
209215
to_datetime(execution_time or c.EPOCH),
216+
runtime_stage or RuntimeStage.LOADING,
210217
)
211218

212219

@@ -273,7 +280,7 @@ def __init__(
273280

274281
self._model_name = model_name
275282

276-
self._optimized_cache: t.Dict[t.Tuple[datetime, datetime, datetime], exp.Expression] = {}
283+
self._optimized_cache: t.Dict[CacheKey, exp.Expression] = {}
277284

278285
def render(
279286
self,
@@ -285,6 +292,7 @@ def render(
285292
is_dev: bool = False,
286293
expand: t.Iterable[str] = tuple(),
287294
optimize: bool = True,
295+
runtime_stage: t.Optional[RuntimeStage] = None,
288296
**kwargs: t.Any,
289297
) -> t.Optional[exp.Subqueryable]:
290298
"""Renders a query, expanding macros with provided kwargs, and optionally expanding referenced models.
@@ -302,12 +310,13 @@ def render(
302310
that depend on materialized tables. Model definitions are inlined and can thus be run end to
303311
end on the fly.
304312
optimize: Whether to optimize the query.
313+
runtime_stage: Indicates the current runtime stage, for example if we're still loading the project, etc.
305314
kwargs: Additional kwargs to pass to the renderer.
306315
307316
Returns:
308317
The rendered expression.
309318
"""
310-
cache_key = self._cache_key(start, end, execution_time)
319+
cache_key = self._cache_key(start, end, execution_time, runtime_stage)
311320

312321
if not optimize or cache_key not in self._optimized_cache:
313322
try:

sqlmesh/core/snapshot/evaluator.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@
3535
from sqlmesh.core.dialect import schema_
3636
from sqlmesh.core.engine_adapter import EngineAdapter
3737
from sqlmesh.core.engine_adapter.base import InsertOverwriteStrategy
38+
from sqlmesh.core.macros import RuntimeStage
3839
from sqlmesh.core.model import IncrementalUnmanagedKind, Model, SCDType2Kind, ViewKind
3940
from sqlmesh.core.snapshot import (
4041
QualifiedViewName,
@@ -158,6 +159,7 @@ def apply(query_or_df: QueryOrDF, index: int = 0) -> None:
158159
engine_adapter=self.adapter,
159160
snapshots=snapshots,
160161
is_dev=is_dev,
162+
runtime_stage=RuntimeStage.EVALUATING,
161163
**common_render_kwargs,
162164
)
163165

@@ -411,6 +413,7 @@ def _create_snapshot(
411413
engine_adapter=self.adapter,
412414
snapshots=parent_snapshots_by_name,
413415
is_dev=is_dev,
416+
runtime_stage=RuntimeStage.CREATING,
414417
)
415418

416419
evaluation_strategy = _evaluation_strategy(snapshot, self.adapter)

tests/core/test_snapshot_evaluator.py

Lines changed: 38 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
from sqlmesh.core.engine_adapter import EngineAdapter, create_engine_adapter
1212
from sqlmesh.core.engine_adapter.base import InsertOverwriteStrategy
1313
from sqlmesh.core.environment import EnvironmentNamingInfo
14-
from sqlmesh.core.macros import macro
14+
from sqlmesh.core.macros import RuntimeStage, macro
1515
from sqlmesh.core.model import (
1616
FullKind,
1717
IncrementalByTimeRangeKind,
@@ -148,6 +148,43 @@ def x(evaluator, y=None) -> None:
148148
)
149149

150150

151+
def test_runtime_stages(capsys, mocker, adapter_mock, make_snapshot):
152+
evaluator = SnapshotEvaluator(adapter_mock)
153+
154+
@macro()
155+
def increment_stage_counter(evaluator) -> None:
156+
# Hack which allows us to intercept the different runtime stage values
157+
print(f"RuntimeStage value: {evaluator.locals['runtime_stage'].value}")
158+
159+
model = load_sql_based_model(
160+
parse( # type: ignore
161+
"""
162+
MODEL (
163+
name test_schema.test_model,
164+
kind FULL,
165+
);
166+
167+
@increment_stage_counter();
168+
169+
SELECT 1 AS a;
170+
"""
171+
),
172+
macros=macro.get_registry(),
173+
)
174+
175+
capsys.readouterr()
176+
177+
snapshot = make_snapshot(model)
178+
snapshot.categorize_as(SnapshotChangeCategory.BREAKING)
179+
assert f"RuntimeStage value: {RuntimeStage.LOADING.value}" in capsys.readouterr().out
180+
181+
evaluator.create([snapshot], {})
182+
assert f"RuntimeStage value: {RuntimeStage.CREATING.value}" in capsys.readouterr().out
183+
184+
evaluator.evaluate(snapshot, "2020-01-01", "2020-01-02", "2020-01-02", snapshots={})
185+
assert f"RuntimeStage value: {RuntimeStage.EVALUATING.value}" in capsys.readouterr().out
186+
187+
151188
def test_evaluate_paused_forward_only_upstream(mocker: MockerFixture, make_snapshot):
152189
model = SqlModel(name="test_schema.test_model", query=parse_one("SELECT a, ds"))
153190
snapshot = make_snapshot(model)

0 commit comments

Comments
 (0)