Skip to content

Commit 56e47e9

Browse files
authored
Fix: don't treat the PARAMETER token as a MacroVar start if it's not @ (#1505)
* Fix: don't treat the PARAMETER token as a MacroVar start if it's not @ * Factor out '@' into a constant
1 parent 8a8b135 commit 56e47e9

3 files changed

Lines changed: 14 additions & 5 deletions

File tree

sqlmesh/core/dialect.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,8 @@
1919
from sqlmesh.utils.errors import SQLMeshError
2020
from sqlmesh.utils.pandas import columns_to_types_from_df
2121

22+
SQLMESH_MACRO_PREFIX = "@"
23+
2224
JSON_TYPE = exp.DataType.build("json")
2325

2426

@@ -147,6 +149,9 @@ def _parse_lambda(self: Parser, alias: bool = False) -> t.Optional[exp.Expressio
147149

148150

149151
def _parse_macro(self: Parser, keyword_macro: str = "") -> t.Optional[exp.Expression]:
152+
if self._prev.text != SQLMESH_MACRO_PREFIX:
153+
return self._parse_parameter()
154+
150155
index = self._index
151156
field = self._parse_primary() or self._parse_function(functions={}) or self._parse_id_var()
152157

@@ -330,7 +335,7 @@ def _parse_table_parts(self: Parser, schema: bool = False) -> exp.Table:
330335
table = self.__parse_table_parts(schema=schema) # type: ignore
331336
table_arg = table.this
332337

333-
if isinstance(table_arg, exp.Var) and table_arg.name.startswith("@"):
338+
if isinstance(table_arg, exp.Var) and table_arg.name.startswith(SQLMESH_MACRO_PREFIX):
334339
return StagedFilePath(this=MacroVar(this=table_arg.name[1:]))
335340

336341
return table
@@ -640,7 +645,7 @@ def extend_sqlglot() -> None:
640645
generators.add(dialect.Generator)
641646

642647
for tokenizer in tokenizers:
643-
tokenizer.VAR_SINGLE_TOKENS.update("@")
648+
tokenizer.VAR_SINGLE_TOKENS.update(SQLMESH_MACRO_PREFIX)
644649

645650
for parser in parsers:
646651
parser.FUNCTIONS.update({"JINJA": Jinja.from_arg_list, "METRIC": MetricAgg.from_arg_list})

sqlmesh/core/macros.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
from sqlglot.helper import csv, ensure_collection
1313

1414
from sqlmesh.core.dialect import (
15+
SQLMESH_MACRO_PREFIX,
1516
MacroDef,
1617
MacroFunc,
1718
MacroSQL,
@@ -26,7 +27,7 @@
2627

2728

2829
class MacroStrTemplate(Template):
29-
delimiter = "@"
30+
delimiter = SQLMESH_MACRO_PREFIX
3031

3132

3233
EXPRESSIONS_NAME_MAP = {}
@@ -313,7 +314,7 @@ def substitute(
313314
if isinstance(node, (exp.Identifier, exp.Var)):
314315
if node.name in args and not isinstance(node.parent, exp.Column):
315316
return args[node.name].copy()
316-
if "@" in node.name:
317+
if SQLMESH_MACRO_PREFIX in node.name:
317318
return node.__class__(
318319
this=evaluator.template(node.name, {k: v.name for k, v in args.items()})
319320
)

tests/core/test_macros.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import pytest
66
from sqlglot import exp, parse_one
77

8-
from sqlmesh.core.dialect import StagedFilePath
8+
from sqlmesh.core.dialect import MacroVar, StagedFilePath
99
from sqlmesh.core.macros import MacroEvaluator, macro
1010
from sqlmesh.utils.errors import SQLMeshError
1111
from sqlmesh.utils.metaprogramming import Executable, ExecutableKind
@@ -93,6 +93,9 @@ def test_macro_var(macro_evaluator):
9393

9494
assert "Macro variable 'y' is undefined." in str(ex.value)
9595

96+
# Parsing a "parameter" like Snowflake's $1 should not produce a MacroVar expression
97+
assert parse_one("select $1", read="snowflake").find(MacroVar) is None
98+
9699

97100
def test_macro_str_replace(macro_evaluator):
98101
expression = parse_one("""@'@val1, @val2'""")

0 commit comments

Comments
 (0)