Skip to content

Commit d9f8f55

Browse files
Feat: Allow macros in python model properties (#3740)
1 parent 65ac801 commit d9f8f55

4 files changed

Lines changed: 309 additions & 29 deletions

File tree

sqlmesh/core/model/decorator.py

Lines changed: 19 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
create_python_model,
1818
create_sql_model,
1919
get_model_name,
20+
render_meta_fields,
2021
)
2122
from sqlmesh.core.model.kind import ModelKindName, _ModelKind
2223
from sqlmesh.utils import registry_decorator
@@ -118,6 +119,21 @@ def model(
118119

119120
build_env(self.func, env=env, name=entrypoint, path=module_path)
120121

122+
rendered_fields = render_meta_fields(
123+
fields={"name": self.name, **self.kwargs},
124+
module_path=module_path,
125+
macros=macros,
126+
jinja_macros=jinja_macros,
127+
variables=variables,
128+
path=path,
129+
dialect=dialect,
130+
default_catalog=default_catalog,
131+
)
132+
133+
rendered_name = rendered_fields["name"]
134+
if isinstance(rendered_name, exp.Expression):
135+
rendered_fields["name"] = rendered_name.sql(dialect=dialect)
136+
121137
common_kwargs = {
122138
"defaults": defaults,
123139
"path": path,
@@ -133,7 +149,7 @@ def model(
133149
"macros": macros,
134150
"jinja_macros": jinja_macros,
135151
"audit_definitions": audit_definitions,
136-
**self.kwargs,
152+
**rendered_fields,
137153
}
138154

139155
for key in ("pre_statements", "post_statements", "on_virtual_update"):
@@ -146,5 +162,5 @@ def model(
146162

147163
if self.is_sql:
148164
query = MacroFunc(this=exp.Anonymous(this=entrypoint))
149-
return create_sql_model(self.name, query, **common_kwargs)
150-
return create_python_model(self.name, entrypoint, **common_kwargs)
165+
return create_sql_model(query=query, **common_kwargs)
166+
return create_python_model(entrypoint=entrypoint, **common_kwargs)

sqlmesh/core/model/definition.py

Lines changed: 76 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,15 @@
6262

6363
logger = logging.getLogger(__name__)
6464

65+
RUNTIME_RENDERED_MODEL_FIELDS = {
66+
"audits",
67+
"signals",
68+
"description",
69+
"cron",
70+
"physical_properties",
71+
"merge_filter",
72+
}
73+
6574

6675
class _Model(ModelMeta, frozen=True):
6776
"""Model is the core abstraction for user defined datasets.
@@ -1823,7 +1832,7 @@ def load_sql_based_model(
18231832
if kind_prop.name.lower() == "merge_filter":
18241833
unrendered_merge_filter = kind_prop
18251834

1826-
meta_renderer = _meta_renderer(
1835+
rendered_meta_exprs = render_expression(
18271836
expression=meta,
18281837
module_path=module_path,
18291838
macros=macros,
@@ -1834,7 +1843,6 @@ def load_sql_based_model(
18341843
default_catalog=default_catalog,
18351844
)
18361845

1837-
rendered_meta_exprs = meta_renderer.render()
18381846
if rendered_meta_exprs is None or len(rendered_meta_exprs) != 1:
18391847
raise_config_error(
18401848
f"Invalid MODEL statement:\n{meta.sql(dialect=dialect, pretty=True)}",
@@ -2024,21 +2032,6 @@ def create_python_model(
20242032
# Also remove self-references that are found
20252033

20262034
dialect = kwargs.get("dialect")
2027-
renderer_kwargs = {
2028-
"module_path": module_path,
2029-
"macros": macros,
2030-
"jinja_macros": jinja_macros,
2031-
"variables": variables,
2032-
"path": path,
2033-
"dialect": dialect,
2034-
"default_catalog": kwargs.get("default_catalog"),
2035-
}
2036-
2037-
name_renderer = _meta_renderer(
2038-
expression=d.parse_one(name, dialect=dialect),
2039-
**renderer_kwargs, # type: ignore
2040-
)
2041-
name = t.cast(t.List[exp.Expression], name_renderer.render())[0].sql(dialect=dialect)
20422035

20432036
dependencies_unspecified = depends_on is None
20442037

@@ -2050,15 +2043,21 @@ def create_python_model(
20502043
if dependencies_unspecified:
20512044
depends_on = parsed_depends_on - {name}
20522045
else:
2053-
depends_on_renderer = _meta_renderer(
2046+
depends_on_rendered = render_expression(
20542047
expression=exp.Array(
20552048
expressions=[d.parse_one(dep, dialect=dialect) for dep in depends_on or []]
20562049
),
2057-
**renderer_kwargs, # type: ignore
2050+
module_path=module_path,
2051+
macros=macros,
2052+
jinja_macros=jinja_macros,
2053+
variables=variables,
2054+
path=path,
2055+
dialect=dialect,
2056+
default_catalog=kwargs.get("default_catalog"),
20582057
)
20592058
depends_on = {
20602059
dep.sql(dialect=dialect)
2061-
for dep in t.cast(t.List[exp.Expression], depends_on_renderer.render())[0].expressions
2060+
for dep in t.cast(t.List[exp.Expression], depends_on_rendered)[0].expressions
20622061
}
20632062

20642063
variables = {k: v for k, v in (variables or {}).items() if k in referenced_variables}
@@ -2382,7 +2381,61 @@ def _refs_to_sql(values: t.Any) -> exp.Expression:
23822381
return exp.Tuple(expressions=values)
23832382

23842383

2385-
def _meta_renderer(
2384+
def render_meta_fields(
2385+
fields: t.Dict[str, t.Any],
2386+
module_path: Path,
2387+
path: Path,
2388+
jinja_macros: t.Optional[JinjaMacroRegistry],
2389+
macros: t.Optional[MacroRegistry],
2390+
dialect: DialectType,
2391+
variables: t.Optional[t.Dict[str, t.Any]],
2392+
default_catalog: t.Optional[str],
2393+
) -> t.Dict[str, t.Any]:
2394+
def render_field_value(value: t.Any) -> t.Any:
2395+
if isinstance(value, exp.Expression) or (
2396+
isinstance(value, str) and d.SQLMESH_MACRO_PREFIX in value
2397+
):
2398+
expression = exp.maybe_parse(value, dialect=dialect)
2399+
rendered_expr = render_expression(
2400+
expression=expression,
2401+
module_path=module_path,
2402+
macros=macros,
2403+
jinja_macros=jinja_macros,
2404+
variables=variables,
2405+
path=path,
2406+
dialect=dialect,
2407+
default_catalog=default_catalog,
2408+
)
2409+
if rendered_expr is None:
2410+
raise SQLMeshError(
2411+
f"Failed to render model attribute `{fields['name']}` at `{path}`\n"
2412+
f"'{expression.sql(dialect=dialect)}' must return an expression"
2413+
)
2414+
if len(rendered_expr) != 1:
2415+
raise SQLMeshError(
2416+
f"Failed to render model attribute `{fields['name']}` at `{path}`.\n"
2417+
f"`{expression.sql(dialect=dialect)}` must return one result, but got {len(rendered_expr)}"
2418+
)
2419+
return rendered_expr[0]
2420+
2421+
return value
2422+
2423+
for field_name, field_info in ModelMeta.all_field_infos().items():
2424+
field = field_info.alias or field_name
2425+
if field not in RUNTIME_RENDERED_MODEL_FIELDS and (field_value := fields.get(field)):
2426+
if isinstance(field_value, dict):
2427+
for key in list(field_value.keys()):
2428+
if key not in RUNTIME_RENDERED_MODEL_FIELDS:
2429+
fields[field][key] = render_field_value(field_value[key])
2430+
elif isinstance(field_value, list):
2431+
fields[field] = [render_field_value(value) for value in field_value]
2432+
else:
2433+
fields[field] = render_field_value(field_value)
2434+
2435+
return fields
2436+
2437+
2438+
def render_expression(
23862439
expression: exp.Expression,
23872440
module_path: Path,
23882441
path: Path,
@@ -2391,7 +2444,7 @@ def _meta_renderer(
23912444
dialect: DialectType = None,
23922445
variables: t.Optional[t.Dict[str, t.Any]] = None,
23932446
default_catalog: t.Optional[str] = None,
2394-
) -> ExpressionRenderer:
2447+
) -> t.Optional[t.List[exp.Expression]]:
23952448
meta_python_env = make_python_env(
23962449
expressions=expression,
23972450
jinja_macro_references=None,
@@ -2410,7 +2463,7 @@ def _meta_renderer(
24102463
default_catalog=default_catalog,
24112464
quote_identifiers=False,
24122465
normalize_identifiers=False,
2413-
)
2466+
).render()
24142467

24152468

24162469
META_FIELD_CONVERTER: t.Dict[str, t.Callable] = {

tests/core/test_context.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1054,6 +1054,27 @@ def test_disabled_model(copy_to_temp_path):
10541054
assert not context.get_model("sushi.disabled_py")
10551055

10561056

1057+
def test_disabled_model_python_macro(sushi_context):
1058+
@model(
1059+
"memory.sushi.disabled_model_2",
1060+
columns={"col": "int"},
1061+
enabled="@IF(@gateway = 'dev', True, False)",
1062+
)
1063+
def entrypoint(context, **kwargs):
1064+
yield pd.DataFrame({"col": []})
1065+
1066+
test_model = model.get_registry()["memory.sushi.disabled_model_2"].model(
1067+
module_path=Path("."), path=Path("."), variables={"gateway": "prod"}
1068+
)
1069+
assert not test_model.enabled
1070+
1071+
with pytest.raises(
1072+
SQLMeshError,
1073+
match="The disabled model 'memory.sushi.disabled_model_2' cannot be upserted",
1074+
):
1075+
sushi_context.upsert_model(test_model)
1076+
1077+
10571078
def test_get_model_mixed_dialects(copy_to_temp_path):
10581079
path = copy_to_temp_path("examples/sushi")
10591080

0 commit comments

Comments
 (0)