Skip to content

Commit b234778

Browse files
authored
Feat: expose columns to types of parent models in MacroEvaluator (#1488)
* Feat: expose explicit model dependencies in the macro context * Temporarily add debug statement * Add another temporary debug statement * Add another temporary debug statement * Set 'columns' instead of 'columns_to_types_' * Formatting * mocked star * formatting * Start refactoring - backtrack to previous codebase state * Refactor * Fixups * Simplify * Fix lineage test * Ensure sqlglot's diff doesn't fail when prev/new queries are identical * Don't actually need to call diff if queries are identical * Rename to excluded * Yet another refactor * Add render test for sushi.waiters * Remove unnecessary depends_on field
1 parent 0302d3c commit b234778

6 files changed

Lines changed: 67 additions & 20 deletions

File tree

examples/sushi/models/waiters.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,8 +30,14 @@ def entrypoint(evaluator: MacroEvaluator) -> exp.Select:
3030
FROM sushi.orders AS o
3131
WHERE @incremental_by_ds(ds)
3232
"""
33+
excluded = {"id", "customer_id", "start_ts", "end_ts"}
34+
projections = []
35+
for column, dtype in evaluator.columns_to_types("sushi.orders").items():
36+
if column not in excluded:
37+
projections.append(f"{column}::{dtype}")
38+
3339
return (
34-
exp.select("waiter_id::int as waiter_id", "ds::text as ds")
40+
exp.select(*projections)
3541
.from_("sushi.orders AS o")
3642
.where(incremental_by_ds(evaluator, exp.to_column("ds")))
3743
.distinct()

sqlmesh/core/loader.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@
4343

4444
logger = logging.getLogger(__name__)
4545

46+
4647
# TODO: consider moving this to context
4748
def update_model_schemas(
4849
dag: DAG[str],
@@ -300,6 +301,7 @@ def _load() -> Model:
300301
raise ConfigError(
301302
f"Failed to parse a model definition at '{path}': {ex}."
302303
)
304+
303305
return load_sql_based_model(
304306
expressions,
305307
defaults=config.model_defaults.dict(),

sqlmesh/core/macros.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
from sqlglot.executor.env import ENV
1111
from sqlglot.executor.python import Python
1212
from sqlglot.helper import csv, ensure_collection
13+
from sqlglot.schema import MappingSchema
1314

1415
from sqlmesh.core.dialect import (
1516
SQLMESH_MACRO_PREFIX,
@@ -25,6 +26,8 @@
2526
from sqlmesh.utils.jinja import JinjaMacroRegistry, has_jinja
2627
from sqlmesh.utils.metaprogramming import Executable, prepare_env, print_exception
2728

29+
SQLMESH_MOCKED_STAR = "__SQLMESH_MOCKED_STAR__"
30+
2831

2932
class MacroStrTemplate(Template):
3033
delimiter = SQLMESH_MACRO_PREFIX
@@ -99,6 +102,7 @@ def __init__(
99102
dialect: str = "",
100103
python_env: t.Optional[t.Dict[str, Executable]] = None,
101104
jinja_env: t.Optional[Environment] = None,
105+
schema: t.Optional[t.Dict[str, t.Any]] = None,
102106
):
103107
self.dialect = dialect
104108
self.generator = MacroDialect().generator()
@@ -107,6 +111,7 @@ def __init__(
107111
self.python_env = python_env or {}
108112
self._jinja_env: t.Optional[Environment] = jinja_env
109113
self.macros = {normalize_macro_name(k): v.func for k, v in macro.get_registry().items()}
114+
self._schema = MappingSchema(schema, dialect=dialect, normalize=False) if schema else {}
110115

111116
prepare_env(self.python_env, self.env)
112117
for k, v in self.python_env.items():
@@ -261,6 +266,17 @@ def jinja_env(self) -> Environment:
261266
self._jinja_env = JinjaMacroRegistry().build_environment(**jinja_env_methods)
262267
return self._jinja_env
263268

269+
def columns_to_types(self, model_name: str) -> t.Dict[str, exp.DataType]:
270+
"""Returns the columns-to-types mapping corresponding to the specified model."""
271+
if not isinstance(self._schema, MappingSchema):
272+
return {SQLMESH_MOCKED_STAR: exp.DataType.build("unknown")}
273+
274+
columns_to_types = self._schema.find(exp.to_table(model_name))
275+
if columns_to_types is None:
276+
raise SQLMeshError(f"Model '{model_name}' not found in the macro evaluator's context.")
277+
278+
return columns_to_types # type: ignore
279+
264280

265281
class macro(registry_decorator):
266282
"""Specifies that a function is a macro and registers it in the global MACROS registry.

sqlmesh/core/model/definition.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,12 +17,13 @@
1717
from sqlglot import diff, exp
1818
from sqlglot.diff import Insert, Keep
1919
from sqlglot.helper import ensure_list
20+
from sqlglot.optimizer.normalize_identifiers import normalize_identifiers
2021
from sqlglot.schema import MappingSchema, nested_set
2122
from sqlglot.time import format_time
2223

2324
from sqlmesh.core import constants as c
2425
from sqlmesh.core import dialect as d
25-
from sqlmesh.core.macros import MacroRegistry, macro
26+
from sqlmesh.core.macros import SQLMESH_MOCKED_STAR, MacroRegistry, macro
2627
from sqlmesh.core.model.common import expression_validator
2728
from sqlmesh.core.model.kind import (
2829
IncrementalByTimeRangeKind,
@@ -1026,6 +1027,14 @@ def update_schema(
10261027
super().update_schema(
10271028
schema, default_schema=default_schema, default_catalog=default_catalog
10281029
)
1030+
1031+
mocked_star = normalize_identifiers(SQLMESH_MOCKED_STAR, dialect=self.dialect)
1032+
if mocked_star.name in (self.columns_to_types or {}):
1033+
# We reset the unoptimized query cache here as well to allow the model's query
1034+
# to be re-rendered so that the MacroEvaluator can resolve columns_to_types calls
1035+
# and get rid of the mocked star column
1036+
self._query_renderer._cache = {}
1037+
10291038
self._columns_to_types = None
10301039
self._query_renderer._optimized_cache = {}
10311040

sqlmesh/core/renderer.py

Lines changed: 19 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -38,17 +38,6 @@
3838
logger = logging.getLogger(__name__)
3939

4040

41-
def _dates(
42-
start: t.Optional[TimeLike] = None,
43-
end: t.Optional[TimeLike] = None,
44-
execution_time: t.Optional[TimeLike] = None,
45-
) -> t.Tuple[datetime, datetime, datetime]:
46-
return (
47-
*make_inclusive(start or c.EPOCH, end or c.EPOCH),
48-
to_datetime(execution_time or c.EPOCH),
49-
)
50-
51-
5241
class BaseExpressionRenderer:
5342
def __init__(
5443
self,
@@ -59,6 +48,7 @@ def __init__(
5948
jinja_macro_registry: t.Optional[JinjaMacroRegistry] = None,
6049
python_env: t.Optional[t.Dict[str, Executable]] = None,
6150
only_execution_time: bool = False,
51+
schema: t.Optional[t.Dict[str, t.Any]] = None,
6252
):
6353
self._expression = expression
6454
self._dialect = dialect
@@ -67,6 +57,7 @@ def __init__(
6757
self._jinja_macro_registry = jinja_macro_registry or JinjaMacroRegistry()
6858
self._python_env = python_env or {}
6959
self._only_execution_time = only_execution_time
60+
self.schema = {} if schema is None else schema
7061

7162
self._cache: t.Dict[t.Tuple[datetime, datetime, datetime], t.List[exp.Expression]] = {}
7263

@@ -96,7 +87,7 @@ def _render(
9687
The rendered expressions.
9788
"""
9889

99-
cache_key = _dates(start, end, execution_time)
90+
cache_key = self._cache_key(start, end, execution_time)
10091
start_dt, end_dt, execution_dt = cache_key
10192
if cache_key not in self._cache:
10293
expressions = [self._expression]
@@ -140,6 +131,7 @@ def _render(
140131
self._dialect,
141132
python_env=self._python_env,
142133
jinja_env=jinja_env,
134+
schema=self.schema,
143135
)
144136

145137
for definition in self._macro_definitions:
@@ -174,7 +166,7 @@ def update_cache(
174166
execution_time: t.Optional[TimeLike] = None,
175167
**kwargs: t.Any,
176168
) -> None:
177-
self._cache[_dates(start, end, execution_time)] = [expression]
169+
self._cache[self._cache_key(start, end, execution_time)] = [expression]
178170

179171
def _resolve_tables(
180172
self,
@@ -206,6 +198,17 @@ def _resolve_tables(
206198
**kwargs,
207199
)
208200

201+
def _cache_key(
202+
self,
203+
start: t.Optional[TimeLike] = None,
204+
end: t.Optional[TimeLike] = None,
205+
execution_time: t.Optional[TimeLike] = None,
206+
) -> t.Tuple[datetime, datetime, datetime]:
207+
return (
208+
*make_inclusive(start or c.EPOCH, end or c.EPOCH),
209+
to_datetime(execution_time or c.EPOCH),
210+
)
211+
209212

210213
class ExpressionRenderer(BaseExpressionRenderer):
211214
def render(
@@ -265,14 +268,13 @@ def __init__(
265268
jinja_macro_registry=jinja_macro_registry,
266269
python_env=python_env,
267270
only_execution_time=only_execution_time,
271+
schema=schema,
268272
)
269273

270274
self._model_name = model_name
271275

272276
self._optimized_cache: t.Dict[t.Tuple[datetime, datetime, datetime], exp.Expression] = {}
273277

274-
self.schema = {} if schema is None else schema
275-
276278
def render(
277279
self,
278280
start: t.Optional[TimeLike] = None,
@@ -305,7 +307,7 @@ def render(
305307
Returns:
306308
The rendered expression.
307309
"""
308-
cache_key = _dates(start, end, execution_time)
310+
cache_key = self._cache_key(start, end, execution_time)
309311

310312
if not optimize or cache_key not in self._optimized_cache:
311313
try:
@@ -367,7 +369,7 @@ def update_cache(
367369
expression, start=start, end=end, execution_time=execution_time, **kwargs
368370
)
369371
else:
370-
self._optimized_cache[_dates(start, end, execution_time)] = expression
372+
self._optimized_cache[self._cache_key(start, end, execution_time)] = expression
371373

372374
def _optimize_query(self, query: exp.Subqueryable) -> exp.Subqueryable:
373375
# We don't want to normalize names in the schema because that's handled by the optimizer

tests/core/test_model.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -985,7 +985,7 @@ def test_lookback():
985985
assert to_timestamp(model.lookback_start("Jan 1 2020")) == to_timestamp("Jan 1 2018")
986986

987987

988-
def test_render_query(assert_exp_eq):
988+
def test_render_query(assert_exp_eq, sushi_context):
989989
model = SqlModel(
990990
name="test",
991991
cron="1 0 * * *",
@@ -1040,6 +1040,18 @@ def test_render_query(assert_exp_eq):
10401040
'SELECT COUNT(DISTINCT "a") FILTER (WHERE "b" > 0) AS "c" FROM "x" AS "x"',
10411041
)
10421042

1043+
assert_exp_eq(
1044+
sushi_context.models["sushi.waiters"].render_query().sql(),
1045+
"""
1046+
SELECT DISTINCT
1047+
CAST("o"."waiter_id" AS INT) AS "waiter_id",
1048+
CAST("o"."ds" AS TEXT) AS "ds"
1049+
FROM "sushi"."orders" AS "o"
1050+
WHERE
1051+
"o"."ds" <= '1970-01-01' AND "o"."ds" >= '1970-01-01'
1052+
""",
1053+
)
1054+
10431055

10441056
def test_time_column():
10451057
expressions = d.parse(

0 commit comments

Comments
 (0)