Skip to content

Commit 19fcc0d

Browse files
authored
Fix: Redundant extraction of macro references in jinja expressions (#965)
1 parent 5fe0687 commit 19fcc0d

File tree

3 files changed

+71
-42
lines changed

3 files changed

+71
-42
lines changed

sqlmesh/core/dialect.py

Lines changed: 3 additions & 12 deletions
Original file line number | Diff line number | Diff line change
@@ -7,14 +7,12 @@
77
from enum import Enum, auto
88

99
import pandas as pd
10-
from jinja2.meta import find_undeclared_variables
1110
from sqlglot import Dialect, Generator, Parser, TokenType, exp
1211
from sqlglot.dialects.dialect import DialectType
1312
from sqlglot.optimizer.normalize_identifiers import normalize_identifiers
1413
from sqlglot.tokens import Token
1514

1615
from sqlmesh.core.constants import MAX_MODEL_DEFINITION_SIZE
17-
from sqlmesh.utils.jinja import ENVIRONMENT
1816
from sqlmesh.utils.pandas import columns_to_types_from_df
1917

2018

@@ -27,8 +25,7 @@ class Audit(exp.Expression):
2725

2826

2927
class Jinja(exp.Func):
30-
arg_types = {"this": True, "expressions": False}
31-
is_var_len_args = True
28+
arg_types = {"this": True}
3229

3330

3431
class JinjaQuery(Jinja):
@@ -492,14 +489,8 @@ def parse(sql: str, default_dialect: t.Optional[str] = None) -> t.List[exp.Expre
492489
else:
493490
start, *_, end = chunk
494491
segment = sql[start.start : end.end + 2]
495-
variables = [
496-
exp.Literal.string(var)
497-
for var in find_undeclared_variables(ENVIRONMENT.parse(segment))
498-
]
499-
klass = JinjaQuery if chunk_type == ChunkType.JINJA_QUERY else JinjaStatement
500-
expressions.append(
501-
klass(this=exp.Literal.string(segment.strip()), expressions=variables)
502-
)
492+
factory = jinja_query if chunk_type == ChunkType.JINJA_QUERY else jinja_statement
493+
expressions.append(factory(segment.strip()))
503494

504495
return expressions
505496

sqlmesh/core/model/definition.py

Lines changed: 44 additions & 30 deletions
Original file line number | Diff line number | Diff line change
@@ -1058,7 +1058,7 @@ def load_model(
10581058
if not name:
10591059
raise_config_error("Model must have a name", path)
10601060

1061-
macro_references: t.Set[MacroReference] = {
1061+
jinja_macro_references: t.Set[MacroReference] = {
10621062
r
10631063
for references in [
10641064
*[extract_macro_references(e.sql(dialect=dialect)) for e in pre_statements],
@@ -1067,41 +1067,40 @@ def load_model(
10671067
for r in references
10681068
}
10691069

1070+
common_kwargs = dict(
1071+
pre_statements=pre_statements,
1072+
post_statements=post_statements,
1073+
defaults=defaults,
1074+
path=path,
1075+
module_path=module_path,
1076+
macros=macros,
1077+
python_env=python_env,
1078+
jinja_macros=jinja_macros,
1079+
jinja_macro_references=jinja_macro_references,
1080+
**meta_fields,
1081+
)
1082+
10701083
if query_or_seed_insert is not None and isinstance(
10711084
query_or_seed_insert, (exp.Subqueryable, d.JinjaQuery)
10721085
):
1073-
macro_references.update(extract_macro_references(query_or_seed_insert.sql(dialect=dialect)))
1086+
jinja_macro_references.update(
1087+
extract_macro_references(query_or_seed_insert.sql(dialect=dialect))
1088+
)
10741089
return create_sql_model(
10751090
name,
10761091
query_or_seed_insert,
1077-
pre_statements=pre_statements,
1078-
post_statements=post_statements,
1079-
defaults=defaults,
1080-
path=path,
1081-
module_path=module_path,
10821092
time_column_format=time_column_format,
1083-
macros=macros,
1084-
jinja_macros=(jinja_macros or JinjaMacroRegistry()).trim(macro_references),
1085-
python_env=python_env,
1086-
**meta_fields,
1093+
**common_kwargs,
10871094
)
10881095
else:
10891096
try:
10901097
seed_properties = {
1091-
p.name.lower(): p.args.get("value") for p in meta_fields.pop("kind").expressions
1098+
p.name.lower(): p.args.get("value") for p in common_kwargs.pop("kind").expressions
10921099
}
10931100
return create_seed_model(
10941101
name,
10951102
SeedKind(**seed_properties),
1096-
pre_statements=pre_statements,
1097-
post_statements=post_statements,
1098-
defaults=defaults,
1099-
path=path,
1100-
module_path=module_path,
1101-
macros=macros,
1102-
jinja_macros=(jinja_macros or JinjaMacroRegistry()).trim(macro_references),
1103-
python_env=python_env,
1104-
**meta_fields,
1103+
**common_kwargs,
11051104
)
11061105
except Exception:
11071106
raise_config_error(
@@ -1123,6 +1122,8 @@ def create_sql_model(
11231122
time_column_format: str = c.DEFAULT_TIME_COLUMN_FORMAT,
11241123
macros: t.Optional[MacroRegistry] = None,
11251124
python_env: t.Optional[t.Dict[str, Executable]] = None,
1125+
jinja_macros: t.Optional[JinjaMacroRegistry] = None,
1126+
jinja_macro_references: t.Optional[t.Set[MacroReference]] = None,
11261127
dialect: t.Optional[str] = None,
11271128
**kwargs: t.Any,
11281129
) -> Model:
@@ -1156,6 +1157,7 @@ def create_sql_model(
11561157
if not python_env:
11571158
python_env = _python_env(
11581159
[*pre_statements, query, *post_statements],
1160+
jinja_macro_references,
11591161
module_path,
11601162
macros or macro.get_registry(),
11611163
)
@@ -1167,6 +1169,8 @@ def create_sql_model(
11671169
path=path,
11681170
time_column_format=time_column_format,
11691171
python_env=python_env,
1172+
jinja_macros=jinja_macros,
1173+
jinja_macro_references=jinja_macro_references,
11701174
dialect=dialect,
11711175
query=query,
11721176
pre_statements=pre_statements,
@@ -1186,6 +1190,8 @@ def create_seed_model(
11861190
module_path: Path = Path(),
11871191
macros: t.Optional[MacroRegistry] = None,
11881192
python_env: t.Optional[t.Dict[str, Executable]] = None,
1193+
jinja_macros: t.Optional[JinjaMacroRegistry] = None,
1194+
jinja_macro_references: t.Optional[t.Set[MacroReference]] = None,
11891195
**kwargs: t.Any,
11901196
) -> Model:
11911197
"""Creates a Seed model.
@@ -1213,6 +1219,7 @@ def create_seed_model(
12131219
if not python_env:
12141220
python_env = _python_env(
12151221
[*pre_statements, *post_statements],
1222+
jinja_macro_references,
12161223
module_path,
12171224
macros or macro.get_registry(),
12181225
)
@@ -1226,6 +1233,8 @@ def create_seed_model(
12261233
kind=seed_kind,
12271234
depends_on=kwargs.pop("depends_on", set()),
12281235
python_env=python_env,
1236+
jinja_macros=jinja_macros,
1237+
jinja_macro_references=jinja_macro_references,
12291238
pre_statements=pre_statements,
12301239
post_statements=post_statements,
12311240
**kwargs,
@@ -1306,6 +1315,8 @@ def _create_model(
13061315
defaults: t.Optional[t.Dict[str, t.Any]] = None,
13071316
path: Path = Path(),
13081317
time_column_format: str = c.DEFAULT_TIME_COLUMN_FORMAT,
1318+
jinja_macros: t.Optional[JinjaMacroRegistry] = None,
1319+
jinja_macro_references: t.Optional[t.Set[MacroReference]] = None,
13091320
depends_on: t.Optional[t.Set[str]] = None,
13101321
dialect: t.Optional[str] = None,
13111322
**kwargs: t.Any,
@@ -1314,11 +1325,16 @@ def _create_model(
13141325

13151326
dialect = dialect or ""
13161327

1328+
jinja_macros = jinja_macros or JinjaMacroRegistry()
1329+
if jinja_macro_references is not None:
1330+
jinja_macros = jinja_macros.trim(jinja_macro_references)
1331+
13171332
try:
13181333
model = klass(
13191334
name=name,
13201335
**{
13211336
**(defaults or {}),
1337+
"jinja_macros": jinja_macros,
13221338
"dialect": dialect,
13231339
"depends_on": depends_on,
13241340
**kwargs,
@@ -1403,27 +1419,25 @@ def _find_tables(expressions: t.List[exp.Expression]) -> t.Set[str]:
14031419

14041420
def _python_env(
14051421
expressions: t.Union[exp.Expression, t.List[exp.Expression]],
1422+
jinja_macro_references: t.Optional[t.Set[MacroReference]],
14061423
module_path: Path,
14071424
macros: MacroRegistry,
14081425
) -> t.Dict[str, Executable]:
14091426
python_env: t.Dict[str, Executable] = {}
14101427

14111428
used_macros = {}
14121429

1413-
def _capture_expression_macros(expression: exp.Expression) -> None:
1414-
if isinstance(expression, d.Jinja):
1415-
for var in expression.expressions:
1416-
if var in macros:
1417-
used_macros[var] = macros[var]
1418-
else:
1430+
expressions = ensure_list(expressions)
1431+
for expression in expressions:
1432+
if not isinstance(expression, d.Jinja):
14191433
for macro_func in expression.find_all(d.MacroFunc):
14201434
if macro_func.__class__ is d.MacroFunc:
14211435
name = macro_func.this.name.lower()
14221436
used_macros[name] = macros[name]
14231437

1424-
expressions = ensure_list(expressions)
1425-
for expression in expressions:
1426-
_capture_expression_macros(expression)
1438+
for macro_ref in jinja_macro_references or set():
1439+
if macro_ref.package is None and macro_ref.name in macros:
1440+
used_macros[macro_ref.name] = macros[macro_ref.name]
14271441

14281442
for name, macro in used_macros.items():
14291443
if not macro.func.__module__.startswith("sqlmesh."):

tests/core/test_model.py

Lines changed: 24 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -333,6 +333,30 @@ def test_column_descriptions(sushi_context, assert_exp_eq):
333333
)
334334

335335

336+
def test_model_jinja_macro_reference_extraction():
337+
@macro()
338+
def test_macro(**kwargs) -> None:
339+
pass
340+
341+
expressions = d.parse(
342+
"""
343+
MODEL (
344+
name db.table,
345+
dialect spark,
346+
owner owner_name,
347+
);
348+
349+
JINJA_STATEMENT_BEGIN;
350+
{{ test_macro() }}
351+
JINJA_END;
352+
353+
SELECT 1 AS x;
354+
"""
355+
)
356+
model = load_model(expressions)
357+
assert "test_macro" in model.python_env
358+
359+
336360
def test_model_pre_post_statements():
337361
@macro()
338362
def foo(**kwargs) -> None:

0 commit comments

Comments
 (0)