Skip to content

Commit 8ceb40f

Browse files
authored
Feat!: Extend SQL parser with designated blocks for jinja code (#963)
1 parent b114e1c commit 8ceb40f

14 files changed

Lines changed: 324 additions & 79 deletions

File tree

docs/concepts/macros/jinja_macros.md

Lines changed: 47 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
# Jinja macros
22

3-
SQLMesh supports macros from the [Jinja](https://jinja.palletsprojects.com/en/3.1.x/) templating system.
3+
SQLMesh supports macros from the [Jinja](https://jinja.palletsprojects.com/en/3.1.x/) templating system.
44

5-
Jinja's macro approach is pure string substitution. Unlike SQLMesh macros, they assemble SQL query text without building a semantic representation.
5+
Jinja's macro approach is pure string substitution. Unlike SQLMesh macros, they assemble SQL query text without building a semantic representation.
66

77
**NOTE:** SQLMesh projects support the standard Jinja function library only - they do **not** support dbt-specific jinja functions like `{{ ref() }}`. dbt-specific functions are allowed in dbt projects being run with the [SQLMesh adapter](../../integrations/dbt.md).
88

@@ -16,6 +16,41 @@ The three curly brace symbols are:
1616
- `{%...%}` creates Jinja statements. Statements give instructions to Jinja, such as setting variable values, control flow with `if`, `for` loops, and defining macro functions.
1717
- `{#...#}` creates Jinja comments. These comments will not be included in the rendered SQL query.
1818

19+
Since Jinja strings are not syntactically valid SQL expressions and cannot be parsed as such, the model query must be wrapped in a special `JINJA_QUERY_BEGIN; ...; JINJA_END;` block in order for SQLMesh to detect it:
20+
21+
```sql linenums="1" hl_lines="5 9"
22+
MODEL (
23+
name sqlmesh_example.full_model
24+
);
25+
26+
JINJA_QUERY_BEGIN;
27+
28+
SELECT {{ 1 + 1 }};
29+
30+
JINJA_END;
31+
```
32+
33+
Similarly, to use Jinja expressions as part of statements that should be evaluated before or after the model query, the `JINJA_STATEMENT_BEGIN; ...; JINJA_END;` block should be used:
34+
35+
```sql linenums="1"
36+
MODEL (
37+
name sqlmesh_example.full_model
38+
);
39+
40+
JINJA_STATEMENT_BEGIN;
41+
{{ pre_hook() }}
42+
JINJA_END;
43+
44+
JINJA_QUERY_BEGIN;
45+
SELECT {{ 1 + 1 }};
46+
JINJA_END;
47+
48+
JINJA_STATEMENT_BEGIN;
49+
{{ post_hook() }}
50+
JINJA_END;
51+
```
52+
53+
1954
## User-defined variables
2055

2156
Define your own variables with the Jinja statement `{% set ... %}`. For example, we could specify the name of the `num_orders` column in the `sqlmesh_example.full_model` like this:
@@ -28,6 +63,8 @@ MODEL (
2863
audits [assert_positive_order_ids],
2964
);
3065

66+
JINJA_QUERY_BEGIN;
67+
3168
{% set my_col = 'num_orders' %} -- Jinja definition of variable `my_col`
3269

3370
SELECT
@@ -36,6 +73,8 @@ SELECT
3673
FROM
3774
sqlmesh_example.incremental_model
3875
GROUP BY item_id
76+
77+
JINJA_END;
3978
```
4079

4180
Note that the Jinja set statement is written after the `MODEL` statement and before the SQL query.
@@ -48,7 +87,7 @@ Jinja variables can be string, integer, or float data types. They can also be an
4887

4988
#### for loops
5089

51-
For loops let you iterate over a collection of items to condense repetitive code and easily change the values used by the code.
90+
For loops let you iterate over a collection of items to condense repetitive code and easily change the values used by the code.
5291

5392
Jinja for loops begin with `{% for ... %}` and end with `{% endfor %}`. This example demonstrates creating indicator variables with `CASE WHEN` using a Jinja for loop:
5493

@@ -88,9 +127,9 @@ FROM table
88127

89128
The rendered query would be the same as before.
90129

91-
#### if
130+
#### if
92131

93-
if statements allow you to take an action (or not) based on some condition.
132+
if statements allow you to take an action (or not) based on some condition.
94133

95134
Jinja if statements begin with `{% if ... %}` and end with `{% endif %}`. The starting `if` statement must contain code that evaluates to `True` or `False`. For example, all of `True`, `1 + 1 == 2`, and `'a' in ['a', 'b']` evaluate to `True`.
96135

@@ -118,11 +157,11 @@ FROM table
118157

119158
## User-defined macro functions
120159

121-
User-defined macro functions allow the same macro code to be used in multiple models.
160+
User-defined macro functions allow the same macro code to be used in multiple models.
122161

123162
Jinja macro functions should be placed in `.sql` files in the SQLMesh project's `macros` directory. Multiple functions can be defined in one `.sql` file, or they can be distributed across multiple files.
124163

125-
Jinja macro functions are defined with the `{% macro %}` and `{% endmacro %}` statements. The macro function name and arguments are specified in the `{% macro %}` statement.
164+
Jinja macro functions are defined with the `{% macro %}` and `{% endmacro %}` statements. The macro function name and arguments are specified in the `{% macro %}` statement.
126165

127166
For example, a macro function named `print_text` that takes no arguments could be defined with:
128167

@@ -186,6 +225,6 @@ Some SQL dialects interpret double and single quotes differently. We could repla
186225

187226
## Mixing macro systems
188227

189-
SQLMesh supports both the Jinja and [SQLMesh](./sqlmesh_macros.md) macro systems. We strongly recommend using only one system in a single model - if both are present, they may fail or behave in unintuitive ways.
228+
SQLMesh supports both the Jinja and [SQLMesh](./sqlmesh_macros.md) macro systems. We strongly recommend using only one system in a single model - if both are present, they may fail or behave in unintuitive ways.
190229

191230
[Predefined SQLMesh macro variables](./macro_variables.md) can be used in a query containing user-defined Jinja variables and functions. However, predefined variables passed as arguments to a user-defined Jinja macro function must use the Jinja curly brace syntax `{{ start_ds }}` instead of the SQLMesh macro `@` prefix syntax `@start_ds`. Note that curly brace syntax may require quoting to generate the equivalent of the `@` syntax.

examples/sushi/models/waiter_as_customer_by_day.sql

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@ MODEL (
1212
)
1313
);
1414

15+
JINJA_QUERY_BEGIN;
16+
1517
{% set x = 1 %}
1618

1719
SELECT
@@ -21,4 +23,6 @@ SELECT
2123
{{ alias(identity(x), 'flag') }}
2224
FROM sushi.waiters AS w
2325
JOIN sushi.customers as c ON w.waiter_id = c.customer_id
24-
JOIN sushi.waiter_names as wn ON w.waiter_id = wn.id
26+
JOIN sushi.waiter_names as wn ON w.waiter_id = wn.id;
27+
28+
JINJA_END;

sqlmesh/core/audit/definition.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ class Audit(AuditMeta, frozen=True):
5353
An audit is a SQL query that returns bad records.
5454
"""
5555

56-
query: t.Union[exp.Subqueryable, d.Jinja]
56+
query: t.Union[exp.Subqueryable, d.JinjaQuery]
5757
expressions_: t.Optional[t.List[exp.Expression]] = Field(default=None, alias="expressions")
5858
jinja_macros: JinjaMacroRegistry = JinjaMacroRegistry()
5959

sqlmesh/core/dialect.py

Lines changed: 85 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -4,12 +4,14 @@
44
import re
55
import typing as t
66
from difflib import unified_diff
7+
from enum import Enum, auto
78

89
import pandas as pd
910
from jinja2.meta import find_undeclared_variables
1011
from sqlglot import Dialect, Generator, Parser, TokenType, exp
1112
from sqlglot.dialects.dialect import DialectType
1213
from sqlglot.optimizer.normalize_identifiers import normalize_identifiers
14+
from sqlglot.tokens import Token
1315

1416
from sqlmesh.core.constants import MAX_MODEL_DEFINITION_SIZE
1517
from sqlmesh.utils.jinja import ENVIRONMENT
@@ -29,6 +31,14 @@ class Jinja(exp.Func):
2931
is_var_len_args = True
3032

3133

34+
class JinjaQuery(Jinja):
35+
pass
36+
37+
38+
class JinjaStatement(Jinja):
39+
pass
40+
41+
3242
class ModelKind(exp.Expression):
3343
arg_types = {"this": True, "expressions": False}
3444

@@ -387,6 +397,47 @@ def text_diff(
387397
)
388398

389399

400+
def _is_command_statement(command: str, tokens: t.List[Token], pos: int) -> bool:
401+
try:
402+
return (
403+
tokens[pos].text.upper() == command.upper()
404+
and tokens[pos + 1].token_type == TokenType.SEMICOLON
405+
)
406+
except IndexError:
407+
return False
408+
409+
410+
JINJA_QUERY_BEGIN = "JINJA_QUERY_BEGIN"
411+
JINJA_STATEMENT_BEGIN = "JINJA_STATEMENT_BEGIN"
412+
JINJA_END = "JINJA_END"
413+
414+
415+
def _is_jinja_statement_begin(tokens: t.List[Token], pos: int) -> bool:
416+
return _is_command_statement(JINJA_STATEMENT_BEGIN, tokens, pos)
417+
418+
419+
def _is_jinja_query_begin(tokens: t.List[Token], pos: int) -> bool:
420+
return _is_command_statement(JINJA_QUERY_BEGIN, tokens, pos)
421+
422+
423+
def _is_jinja_end(tokens: t.List[Token], pos: int) -> bool:
424+
return _is_command_statement(JINJA_END, tokens, pos)
425+
426+
427+
def jinja_query(query: str) -> JinjaQuery:
428+
return JinjaQuery(this=exp.Literal.string(query))
429+
430+
431+
def jinja_statement(statement: str) -> JinjaStatement:
432+
return JinjaStatement(this=exp.Literal.string(statement))
433+
434+
435+
class ChunkType(Enum):
436+
JINJA_QUERY = auto()
437+
JINJA_STATEMENT = auto()
438+
SQL = auto()
439+
440+
390441
def parse(sql: str, default_dialect: t.Optional[str] = None) -> t.List[exp.Expression]:
391442
"""Parse a sql string.
392443
@@ -404,37 +455,51 @@ def parse(sql: str, default_dialect: t.Optional[str] = None) -> t.List[exp.Expre
404455
dialect = Dialect.get_or_raise(match.group(2) if match else default_dialect)()
405456

406457
tokens = dialect.tokenizer.tokenize(sql)
407-
chunks: t.List[t.Tuple[t.List, bool]] = [([], False)]
458+
chunks: t.List[t.Tuple[t.List[Token], ChunkType]] = [([], ChunkType.SQL)]
408459
total = len(tokens)
409460

410-
for i, token in enumerate(tokens):
411-
if token.token_type == TokenType.SEMICOLON:
412-
if i < total - 1:
413-
chunks.append(([], False))
461+
pos = 0
462+
while pos < total:
463+
token = tokens[pos]
464+
if _is_jinja_end(tokens, pos) or (
465+
chunks[-1][1] == ChunkType.SQL
466+
and token.token_type == TokenType.SEMICOLON
467+
and pos < total - 1
468+
):
469+
chunks.append(([], ChunkType.SQL))
470+
if token.token_type == TokenType.SEMICOLON:
471+
pos += 1
472+
else:
473+
# Jinja end statement
474+
pos += 2
475+
elif _is_jinja_query_begin(tokens, pos):
476+
chunks.append(([], ChunkType.JINJA_QUERY))
477+
pos += 2
478+
elif _is_jinja_statement_begin(tokens, pos):
479+
chunks.append(([], ChunkType.JINJA_STATEMENT))
480+
pos += 2
414481
else:
415-
if token.token_type == TokenType.BLOCK_START or (
416-
i < total - 1
417-
and token.token_type == TokenType.L_BRACE
418-
and tokens[i + 1].token_type == TokenType.L_BRACE
419-
):
420-
chunks[-1] = (chunks[-1][0], True)
421482
chunks[-1][0].append(token)
483+
pos += 1
422484

423485
expressions: t.List[exp.Expression] = []
424486

425-
for chunk, is_jinja in chunks:
426-
if is_jinja:
487+
for chunk, chunk_type in chunks:
488+
if chunk_type == ChunkType.SQL:
489+
for expression in dialect.parser().parse(chunk, sql):
490+
if expression:
491+
expressions.append(expression)
492+
else:
427493
start, *_, end = chunk
428494
segment = sql[start.start : end.end + 2]
429495
variables = [
430496
exp.Literal.string(var)
431497
for var in find_undeclared_variables(ENVIRONMENT.parse(segment))
432498
]
433-
expressions.append(Jinja(this=exp.Literal.string(segment), expressions=variables))
434-
else:
435-
for expression in dialect.parser().parse(chunk, sql):
436-
if expression:
437-
expressions.append(expression)
499+
klass = JinjaQuery if chunk_type == ChunkType.JINJA_QUERY else JinjaStatement
500+
expressions.append(
501+
klass(this=exp.Literal.string(segment.strip()), expressions=variables)
502+
)
438503

439504
return expressions
440505

@@ -462,6 +527,8 @@ def extend_sqlglot() -> None:
462527
MacroVar: lambda self, e: f"@{e.name}",
463528
Model: _model_sql,
464529
Jinja: lambda self, e: e.name,
530+
JinjaQuery: lambda self, e: f"{JINJA_QUERY_BEGIN};\n{e.name}\n{JINJA_END};",
531+
JinjaStatement: lambda self, e: f"{JINJA_STATEMENT_BEGIN};\n{e.name.strip()}\n{JINJA_END};",
465532
ModelKind: _model_kind_sql,
466533
PythonCode: lambda self, e: self.expressions(e, sep="\n", indent=False),
467534
}

sqlmesh/core/model/definition.py

Lines changed: 15 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212

1313
from astor import to_source
1414
from pydantic import Field
15-
from sqlglot import diff, exp, parse_one
15+
from sqlglot import diff, exp
1616
from sqlglot.diff import Insert, Keep
1717
from sqlglot.helper import ensure_list
1818
from sqlglot.optimizer.scope import traverse_scope
@@ -26,7 +26,7 @@
2626
from sqlmesh.core.model.meta import ModelMeta
2727
from sqlmesh.core.model.seed import Seed, create_seed
2828
from sqlmesh.core.renderer import ExpressionRenderer, QueryRenderer
29-
from sqlmesh.utils.date import TimeLike, date_dict, make_inclusive, to_datetime
29+
from sqlmesh.utils.date import TimeLike, make_inclusive, to_datetime
3030
from sqlmesh.utils.errors import ConfigError, SQLMeshError, raise_config_error
3131
from sqlmesh.utils.jinja import JinjaMacroRegistry, extract_macro_references
3232
from sqlmesh.utils.metaprogramming import (
@@ -630,7 +630,7 @@ class SqlModel(_SqlBasedModel):
630630
post_statements: The list of SQL statements that follow after the model's query.
631631
"""
632632

633-
query: t.Union[exp.Subqueryable, d.Jinja]
633+
query: t.Union[exp.Subqueryable, d.JinjaQuery]
634634
source_type: Literal["sql"] = "sql"
635635

636636
_columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None
@@ -1042,7 +1042,7 @@ def load_model(
10421042

10431043
# Extract the query and any pre/post statements
10441044
query_or_seed_insert, pre_statements, post_statements = _split_sql_model_statements(
1045-
expressions[1:], jinja_macros or JinjaMacroRegistry(), dialect, path
1045+
expressions[1:], path
10461046
)
10471047

10481048
meta_fields: t.Dict[str, t.Any] = {
@@ -1067,7 +1067,9 @@ def load_model(
10671067
for r in references
10681068
}
10691069

1070-
if query_or_seed_insert is not None and isinstance(query_or_seed_insert, exp.Subqueryable):
1070+
if query_or_seed_insert is not None and isinstance(
1071+
query_or_seed_insert, (exp.Subqueryable, d.JinjaQuery)
1072+
):
10711073
macro_references.update(extract_macro_references(query_or_seed_insert.sql(dialect=dialect)))
10721074
return create_sql_model(
10731075
name,
@@ -1103,7 +1105,7 @@ def load_model(
11031105
)
11041106
except Exception:
11051107
raise_config_error(
1106-
"The model definition must either have a SELECT query or a valid Seed kind",
1108+
"The model definition must either have a SELECT query, a JINJA_QUERY block, or a valid Seed kind",
11071109
path,
11081110
)
11091111
raise
@@ -1142,9 +1144,9 @@ def create_sql_model(
11421144
dialect: The default dialect if no model dialect is configured.
11431145
The format must adhere to Python's strftime codes.
11441146
"""
1145-
if not isinstance(query, (exp.Subqueryable, d.Jinja)):
1147+
if not isinstance(query, (exp.Subqueryable, d.JinjaQuery)):
11461148
raise_config_error(
1147-
"A query is required and must be a SELECT or UNION statement",
1149+
"A query is required and must be a SELECT statement, a UNION statement, or a JINJA_QUERY block",
11481150
path,
11491151
)
11501152

@@ -1337,7 +1339,7 @@ def _create_model(
13371339

13381340

13391341
def _split_sql_model_statements(
1340-
expressions: t.List[exp.Expression], jinja_macros: JinjaMacroRegistry, dialect: str, path: Path
1342+
expressions: t.List[exp.Expression], path: Path
13411343
) -> t.Tuple[t.Optional[exp.Expression], t.List[exp.Expression], t.List[exp.Expression]]:
13421344
"""Extracts the SELECT query from a sequence of expressions.
13431345
@@ -1353,15 +1355,10 @@ def _split_sql_model_statements(
13531355
"""
13541356
query_positions = []
13551357
for idx, expression in enumerate(expressions):
1356-
# Render the expression using the macro registry if the expression is jinja.
1357-
if isinstance(expression, d.Jinja):
1358-
jinja_env = jinja_macros.build_environment(**date_dict(c.EPOCH, c.EPOCH, c.EPOCH))
1359-
rendered_expression = jinja_env.from_string(expression.name).render()
1360-
if rendered_expression is None:
1361-
continue
1362-
expression = parse_one(rendered_expression, read=dialect)
1363-
1364-
if isinstance(expression, exp.Subqueryable) or expression == INSERT_SEED_MACRO_CALL:
1358+
if (
1359+
isinstance(expression, (exp.Subqueryable, d.JinjaQuery))
1360+
or expression == INSERT_SEED_MACRO_CALL
1361+
):
13651362
query_positions.append((expression, idx))
13661363

13671364
if not query_positions:

0 commit comments

Comments
 (0)