Skip to content

Commit 1746a29

Browse files
authored
Fix(macros): incrementally resolve macro var definitions (#1253)
1 parent 8cc8a33 commit 1746a29

2 files changed

Lines changed: 13 additions & 8 deletions

File tree

sqlmesh/core/macros.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -124,7 +124,9 @@ def send(
124124
print_exception(e, self.python_env)
125125
raise MacroEvalError(f"Error trying to eval macro.") from e
126126

127-
def transform(self, query: exp.Expression) -> exp.Expression | t.List[exp.Expression] | None:
127+
def transform(
128+
self, expression: exp.Expression
129+
) -> exp.Expression | t.List[exp.Expression] | None:
128130
changed = False
129131

130132
def _transform_node(node: exp.Expression) -> exp.Expression:
@@ -141,7 +143,7 @@ def _transform_node(node: exp.Expression) -> exp.Expression:
141143
return node
142144
return node
143145

144-
query = query.transform(_transform_node)
146+
expression = expression.transform(_transform_node)
145147

146148
def evaluate_macros(
147149
node: exp.Expression,
@@ -156,7 +158,7 @@ def evaluate_macros(
156158
return self.evaluate(node)
157159
return node
158160

159-
transformed = evaluate_macros(query)
161+
transformed = evaluate_macros(expression)
160162

161163
if changed:
162164
# the transformations could have corrupted the ast, turning this into sql and reparsing ensures
@@ -182,7 +184,7 @@ def template(self, text: t.Any, local_variables: t.Dict[str, t.Any]) -> str:
182184

183185
def evaluate(self, node: MacroFunc) -> exp.Expression | t.List[exp.Expression] | None:
184186
if isinstance(node, MacroDef):
185-
self.locals[node.name] = node.expression
187+
self.locals[node.name] = self.transform(node.expression)
186188
return node
187189

188190
if isinstance(node, (MacroSQL, MacroStrReplace)):

tests/core/test_model.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,7 @@ def test_load(assert_exp_eq):
5555
);
5656
5757
@DEF(x, 1);
58+
@DEF(y, @x + 1);
5859
CACHE TABLE x AS SELECT 1;
5960
ADD JAR 's3://my_jar.jar';
6061
@@ -65,6 +66,7 @@ def test_load(assert_exp_eq):
6566
1::int AS d, -- d
6667
CAST(2 AS double) AS e, --e
6768
f::bool, --f
69+
@y::int AS g,
6870
FROM
6971
db.other_table t1
7072
LEFT JOIN
@@ -90,13 +92,13 @@ def test_load(assert_exp_eq):
9092
"d": exp.DataType.build("int"),
9193
"e": exp.DataType.build("double"),
9294
"f": exp.DataType.build("boolean"),
95+
"g": exp.DataType.build("int"),
9396
}
9497
assert model.view_name == "table"
95-
assert model.macro_definitions == [
96-
d.parse_one("@DEF(x, 1)"),
97-
]
98+
assert model.macro_definitions == [d.parse_one("@DEF(x, 1)"), d.parse_one("@DEF(y, @x + 1)")]
9899
assert list(model.pre_statements) == [
99100
d.parse_one("@DEF(x, 1)"),
101+
d.parse_one("@DEF(y, @x + 1)"),
100102
d.parse_one("CACHE TABLE x AS SELECT 1"),
101103
d.parse_one("ADD JAR 's3://my_jar.jar'", dialect="spark"),
102104
]
@@ -113,7 +115,8 @@ def test_load(assert_exp_eq):
113115
TRY_CAST("c" AS BOOLEAN) AS "c",
114116
TRY_CAST(1 AS INT) AS "d", /* d */
115117
TRY_CAST(2 AS DOUBLE) AS "e", /* e */
116-
TRY_CAST("f" AS BOOLEAN) /* f */ AS "f"
118+
TRY_CAST("f" AS BOOLEAN) AS "f", /* f */
119+
TRY_CAST(1 + 1 AS INT) AS "g",
117120
FROM "db"."other_table" AS "t1"
118121
LEFT JOIN "db"."table" AS "t2"
119122
ON "t1"."a" = "t2"."a"

0 commit comments

Comments
 (0)