Skip to content

Commit f5d4d45

Browse files
authored
Fix!: normalize identifiers and model names (#930)
* Fix!: normalize identifiers and model names during optimization
* Revert query to query_ rename
* Typo
* Test fixups
* PR feedback
* Fix fingerprints
* Quote identifiers at the very end
* Ensure we use get_model, get_snapshot consistently, do some cleanups
* Make depends_on validator more robust by using the dialect
* Update test to reflect depends_on validator changes
* PR feedback
* Bump sqlglot
* Fix table diff test
1 parent eecd772 commit f5d4d45

File tree

14 files changed

+139
-60
lines changed

14 files changed

+139
-60
lines changed

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@
4444
"requests",
4545
"rich",
4646
"ruamel.yaml",
47-
"sqlglot~=14.1.0",
47+
"sqlglot~=15.0.0",
4848
"fsspec",
4949
],
5050
extras_require={

sqlmesh/cli/main.py

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -109,13 +109,8 @@ def render(
109109
dialect: t.Optional[str] = None,
110110
) -> None:
111111
"""Renders a model's query, optionally expanding referenced models."""
112-
snapshot = ctx.obj.snapshots.get(model)
113-
114-
if not snapshot:
115-
raise click.ClickException(f"Model `{model}` not found.")
116-
117112
rendered = ctx.obj.render(
118-
snapshot,
113+
model,
119114
start=start,
120115
end=end,
121116
latest=latest,

sqlmesh/core/context.py

Lines changed: 18 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,12 @@
4949
from sqlmesh.core.config import Config, load_config_from_paths
5050
from sqlmesh.core.console import Console, get_console
5151
from sqlmesh.core.context_diff import ContextDiff
52-
from sqlmesh.core.dialect import format_model_expressions, pandas_to_sql, parse
52+
from sqlmesh.core.dialect import (
53+
format_model_expressions,
54+
normalize_model_name,
55+
pandas_to_sql,
56+
parse,
57+
)
5358
from sqlmesh.core.engine_adapter import EngineAdapter
5459
from sqlmesh.core.environment import Environment
5560
from sqlmesh.core.loader import Loader, SqlMeshLoader, update_model_schemas
@@ -277,13 +282,13 @@ def upsert_model(self, model: t.Union[str, Model], **kwargs: t.Any) -> Model:
277282
Returns:
278283
A new instance of the updated or inserted model.
279284
"""
280-
if isinstance(model, str):
281-
model = self._models[model]
285+
model = self.get_model(model, raise_if_missing=True)
286+
path = model._path
282287

283-
path = model._path # type: ignore
284288
# model.copy() can't be used here due to a cached state that can be a part of a model instance.
285289
model = t.cast(Model, type(model)(**{**t.cast(Model, model).dict(), **kwargs}))
286290
model._path = path
291+
287292
self._models.update({model.name: model})
288293

289294
self._add_model_to_dag(model)
@@ -409,14 +414,16 @@ def get_model(
409414
The expected model.
410415
"""
411416
if isinstance(model_or_snapshot, str):
412-
model = self._models.get(model_or_snapshot)
417+
normalized_name = normalize_model_name(model_or_snapshot, dialect=self.config.dialect)
418+
model = self._models.get(normalized_name)
413419
elif isinstance(model_or_snapshot, Snapshot):
414420
model = model_or_snapshot.model
415421
else:
416422
model = model_or_snapshot
417423

418424
if raise_if_missing and not model:
419425
raise SQLMeshError(f"Cannot find model for '{model_or_snapshot}'")
426+
420427
return model
421428

422429
@t.overload
@@ -444,7 +451,8 @@ def get_snapshot(
444451
The expected snapshot.
445452
"""
446453
if isinstance(model_or_snapshot, str):
447-
snapshot = self.snapshots.get(model_or_snapshot)
454+
normalized_name = normalize_model_name(model_or_snapshot, dialect=self.config.dialect)
455+
snapshot = self.snapshots.get(normalized_name)
448456
elif isinstance(model_or_snapshot, Snapshot):
449457
snapshot = model_or_snapshot
450458
else:
@@ -908,7 +916,9 @@ def audit(
908916
"""
909917

910918
snapshots = (
911-
[self.snapshots[model] for model in models] if models else self.snapshots.values()
919+
[self.get_snapshot(model, raise_if_missing=True) for model in models]
920+
if models
921+
else self.snapshots.values()
912922
)
913923

914924
num_audits = sum(len(snapshot.model.audits) for snapshot in snapshots)
@@ -942,6 +952,7 @@ def audit(
942952
)
943953
self.console.log_status_update(f"Got {error.count} results, expected 0.")
944954
self.console.show_sql(f"{error.query}")
955+
945956
self.console.log_status_update("Done.")
946957

947958
def migrate(self) -> None:

sqlmesh/core/dialect.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@
88
import pandas as pd
99
from jinja2.meta import find_undeclared_variables
1010
from sqlglot import Dialect, Generator, Parser, TokenType, exp
11+
from sqlglot.dialects.dialect import DialectType
12+
from sqlglot.optimizer.normalize_identifiers import normalize_identifiers
1113

1214
from sqlmesh.core.constants import MAX_MODEL_DEFINITION_SIZE
1315
from sqlmesh.utils.jinja import ENVIRONMENT
@@ -544,3 +546,9 @@ def pandas_to_sql(
544546
batch_size=batch_size,
545547
alias=alias,
546548
)
549+
550+
551+
def normalize_model_name(table: str | exp.Table, dialect: DialectType = None) -> str:
552+
return exp.table_name(
553+
normalize_identifiers(exp.to_table(table, dialect=dialect), dialect=dialect)
554+
)

sqlmesh/core/loader.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
from pathlib import Path
1313

1414
from ruamel.yaml import YAML
15+
from sqlglot import exp
1516
from sqlglot.errors import SqlglotError
1617
from sqlglot.optimizer.qualify_columns import validate_qualify_columns
1718
from sqlglot.schema import MappingSchema, nested_set
@@ -39,7 +40,7 @@
3940

4041

4142
def update_model_schemas(dag: DAG[str], models: UniqueKeyDict[str, Model]) -> None:
42-
schema = MappingSchema()
43+
schema = MappingSchema(normalize=False)
4344

4445
for name in dag.sorted():
4546
model = models.get(name)
@@ -52,9 +53,7 @@ def update_model_schemas(dag: DAG[str], models: UniqueKeyDict[str, Model]) -> No
5253

5354
for dep in model.depends_on:
5455
external = external or dep not in models
55-
table = schema._normalize_table(
56-
schema._ensure_table(dep, dialect=model.dialect), dialect=model.dialect
57-
)
56+
table = exp.to_table(dep, dialect=model.dialect)
5857
mapping_schema = schema.find(table)
5958

6059
if mapping_schema:

sqlmesh/core/model/definition.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -654,7 +654,7 @@ def columns_to_types(self) -> t.Dict[str, exp.DataType]:
654654
if self._columns_to_types is None:
655655
self._columns_to_types = {
656656
expression.output_name: expression.type or exp.DataType.build("unknown")
657-
for expression in self._query_renderer.render().expressions
657+
for expression in self._query_renderer.render().selects
658658
}
659659

660660
return self._columns_to_types
@@ -667,7 +667,7 @@ def column_descriptions(self) -> t.Dict[str, str]:
667667
if self._column_descriptions is None:
668668
self._column_descriptions = {
669669
select.alias: "\n".join(comment.strip() for comment in select.comments)
670-
for select in self.render_query().expressions
670+
for select in self.render_query().selects
671671
if select.comments
672672
}
673673
return self._column_descriptions
@@ -678,9 +678,7 @@ def validate_definition(self) -> None:
678678
if not isinstance(query, exp.Subqueryable):
679679
raise_config_error("Missing SELECT query in the model definition", self._path)
680680

681-
projection_list = (
682-
query.expressions if not isinstance(query, exp.Union) else query.this.expressions
683-
)
681+
projection_list = query.selects
684682
if not projection_list:
685683
raise_config_error("Query missing select statements", self._path)
686684

sqlmesh/core/model/meta.py

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -41,9 +41,9 @@ class IntervalUnit(str, Enum):
4141
class ModelMeta(PydanticModel):
4242
"""Metadata for models which can be defined in SQL."""
4343

44+
dialect: str = ""
4445
name: str
4546
kind: ModelKind = ModelKind(name=ModelKindName.VIEW)
46-
dialect: str = ""
4747
cron: str = "@daily"
4848
owner: t.Optional[str]
4949
description: t.Optional[str]
@@ -65,6 +65,10 @@ class ModelMeta(PydanticModel):
6565

6666
_model_kind_validator = ModelKind.field_validator()
6767

68+
@validator("name", pre=True)
69+
def _name_validator(cls, v: t.Any, values: t.Dict[str, t.Any]) -> str:
70+
return d.normalize_model_name(v, dialect=values.get("dialect"))
71+
6872
@validator("audits", pre=True)
6973
def _audits_validator(cls, v: t.Any) -> t.Any:
7074
def extract(v: exp.Expression) -> t.Tuple[str, t.Dict[str, str]]:
@@ -144,14 +148,19 @@ def _columns_validator(
144148
return v
145149

146150
@validator("depends_on_", pre=True)
147-
def _depends_on_validator(cls, v: t.Any) -> t.Optional[t.Set[str]]:
151+
def _depends_on_validator(cls, v: t.Any, values: t.Dict[str, t.Any]) -> t.Optional[t.Set[str]]:
152+
dialect = values.get("dialect")
153+
148154
if isinstance(v, (exp.Array, exp.Tuple)):
149155
return {
150-
exp.table_name(table.name if table.is_string else table.sql())
156+
d.normalize_model_name(
157+
table.name if table.is_string else table.sql(dialect=dialect), dialect=dialect
158+
)
151159
for table in v.expressions
152160
}
153161
if isinstance(v, exp.Expression):
154-
return {exp.table_name(v.sql())}
162+
return {d.normalize_model_name(v.sql(dialect=dialect), dialect=dialect)}
163+
155164
return v
156165

157166
@validator("start", pre=True)

sqlmesh/core/renderer.py

Lines changed: 15 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,8 @@
77
from sqlglot import exp, parse_one
88
from sqlglot.errors import SqlglotError
99
from sqlglot.optimizer.annotate_types import annotate_types
10-
from sqlglot.optimizer.lower_identities import lower_identities
11-
from sqlglot.optimizer.qualify_columns import qualify_columns
12-
from sqlglot.optimizer.qualify_tables import qualify_tables
10+
from sqlglot.optimizer.qualify import qualify
11+
from sqlglot.optimizer.qualify_columns import quote_identifiers
1312
from sqlglot.optimizer.simplify import simplify
1413
from sqlglot.schema import ensure_schema
1514

@@ -259,7 +258,7 @@ def render(
259258
if not isinstance(query, exp.Subqueryable):
260259
raise_config_error(f"Query needs to be a SELECT or a UNION {query}.", self._path)
261260

262-
return t.cast(exp.Subqueryable, query)
261+
return t.cast(exp.Subqueryable, quote_identifiers(query, dialect=self._dialect))
263262

264263
def time_column_filter(self, start: TimeLike, end: TimeLike) -> exp.Between:
265264
"""Returns a between statement with the properly formatted time column."""
@@ -273,19 +272,25 @@ def time_column_filter(self, start: TimeLike, end: TimeLike) -> exp.Between:
273272
)
274273

275274
def _optimize_query(self, query: exp.Expression) -> exp.Expression:
276-
schema = ensure_schema(self.schema, dialect=self._dialect)
277-
275+
# We don't want to normalize names in the schema because that's handled by the optimizer
276+
schema = ensure_schema(self.schema, dialect=self._dialect, normalize=False)
278277
query = t.cast(exp.Subqueryable, query.copy())
279-
lower_identities(query)
280-
qualify_tables(query)
281278

282279
try:
280+
qualify(
281+
query,
282+
dialect=self._dialect,
283+
schema=schema,
284+
infer_schema=False,
285+
validate_qualify_columns=False,
286+
qualify_columns=not schema.empty,
287+
quote_identifiers=False,
288+
)
289+
283290
if schema.empty:
284291
for select in query.selects:
285292
if not isinstance(select, exp.Alias) and select.output_name not in ("*", ""):
286293
select.replace(exp.alias_(select, select.output_name))
287-
else:
288-
qualify_columns(query, schema=schema, infer_schema=False)
289294
except SqlglotError as ex:
290295
raise_config_error(
291296
f"Error qualifying columns, the column may not exist or is ambiguous. {ex}",

sqlmesh/magics.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
from sqlmesh.core.dialect import format_model_expressions, parse
1515
from sqlmesh.core.model import load_model
1616
from sqlmesh.core.test import ModelTestMetadata, get_all_model_tests
17-
from sqlmesh.utils.errors import MagicError, MissingContextException, SQLMeshError
17+
from sqlmesh.utils.errors import MagicError, MissingContextException
1818
from sqlmesh.utils.yaml import dumps as yaml_dumps
1919
from sqlmesh.utils.yaml import load as yaml_load
2020

@@ -66,10 +66,7 @@ def context(self, line: str) -> None:
6666
def model(self, line: str, sql: t.Optional[str] = None) -> None:
6767
"""Renders the model and automatically fills in an editable cell with the model definition."""
6868
args = parse_argstring(self.model, line)
69-
model = self._context.get_model(t.cast(str, args.model))
70-
71-
if not model:
72-
raise SQLMeshError(f"Cannot find {model}")
69+
model = self._context.get_model(args.model, raise_if_missing=True)
7370

7471
if sql:
7572
config = self._context.config_for_model(model)
@@ -142,13 +139,16 @@ def test(self, line: str, test_def_raw: t.Optional[str] = None) -> None:
142139
f"Test found that does not have `model` defined: {model_test_metadata.path}"
143140
)
144141
tests[model][model_test_metadata.test_name] = model_test_metadata
142+
143+
model = self._context.get_model(args.model, raise_if_missing=True)
144+
145145
if args.ls:
146146
# TODO: Provide better UI for displaying tests
147-
for test_name in tests[args.model]:
147+
for test_name in tests[model.name]:
148148
self._context.console.log_status_update(test_name)
149149
return
150150

151-
test = tests[args.model][args.test_name]
151+
test = tests[model.name][args.test_name]
152152
test_def = yaml_load(test_def_raw) if test_def_raw else test.body
153153
test_def_output = yaml_dumps(test_def)
154154

tests/core/test_audit.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -144,7 +144,7 @@ def test_no_query():
144144

145145

146146
def test_macro(model: Model):
147-
expected_query = "SELECT * FROM (SELECT * FROM db.test_model AS test_model WHERE ds <= '1970-01-01' AND ds >= '1970-01-01') AS _q_0 WHERE a IS NULL"
147+
expected_query = """SELECT * FROM (SELECT * FROM "db"."test_model" AS "test_model" WHERE "ds" <= '1970-01-01' AND "ds" >= '1970-01-01') AS "_q_0" WHERE "a" IS NULL"""
148148

149149
audit = Audit(
150150
name="test_audit",
@@ -167,7 +167,7 @@ def test_not_null_audit(model: Model):
167167
)
168168
assert (
169169
rendered_query_a.sql()
170-
== "SELECT * FROM (SELECT * FROM db.test_model AS test_model WHERE ds <= '1970-01-01' AND ds >= '1970-01-01') AS _q_0 WHERE a IS NULL"
170+
== """SELECT * FROM (SELECT * FROM "db"."test_model" AS "test_model" WHERE "ds" <= '1970-01-01' AND "ds" >= '1970-01-01') AS "_q_0" WHERE "a" IS NULL"""
171171
)
172172

173173
rendered_query_a_and_b = builtin.not_null_audit.render_query(
@@ -176,23 +176,23 @@ def test_not_null_audit(model: Model):
176176
)
177177
assert (
178178
rendered_query_a_and_b.sql()
179-
== "SELECT * FROM (SELECT * FROM db.test_model AS test_model WHERE ds <= '1970-01-01' AND ds >= '1970-01-01') AS _q_0 WHERE a IS NULL OR b IS NULL"
179+
== """SELECT * FROM (SELECT * FROM "db"."test_model" AS "test_model" WHERE "ds" <= '1970-01-01' AND "ds" >= '1970-01-01') AS "_q_0" WHERE "a" IS NULL OR "b" IS NULL"""
180180
)
181181

182182

183183
def test_unique_values_audit(model: Model):
184184
rendered_query_a = builtin.unique_values_audit.render_query(model, columns=[exp.to_column("a")])
185185
assert (
186186
rendered_query_a.sql()
187-
== "SELECT * FROM (SELECT ROW_NUMBER() OVER (PARTITION BY a ORDER BY 1) AS a_rank FROM (SELECT * FROM db.test_model AS test_model WHERE ds <= '1970-01-01' AND ds >= '1970-01-01') AS _q_0) AS _q_1 WHERE a_rank > 1"
187+
== """SELECT * FROM (SELECT ROW_NUMBER() OVER (PARTITION BY "a" ORDER BY 1) AS "a_rank" FROM (SELECT * FROM "db"."test_model" AS "test_model" WHERE "ds" <= '1970-01-01' AND "ds" >= '1970-01-01') AS "_q_0") AS "_q_1" WHERE "a_rank" > 1"""
188188
)
189189

190190
rendered_query_a_and_b = builtin.unique_values_audit.render_query(
191191
model, columns=[exp.to_column("a"), exp.to_column("b")]
192192
)
193193
assert (
194194
rendered_query_a_and_b.sql()
195-
== "SELECT * FROM (SELECT ROW_NUMBER() OVER (PARTITION BY a ORDER BY 1) AS a_rank, ROW_NUMBER() OVER (PARTITION BY b ORDER BY 1) AS b_rank FROM (SELECT * FROM db.test_model AS test_model WHERE ds <= '1970-01-01' AND ds >= '1970-01-01') AS _q_0) AS _q_1 WHERE a_rank > 1 OR b_rank > 1"
195+
== """SELECT * FROM (SELECT ROW_NUMBER() OVER (PARTITION BY "a" ORDER BY 1) AS "a_rank", ROW_NUMBER() OVER (PARTITION BY "b" ORDER BY 1) AS "b_rank" FROM (SELECT * FROM "db"."test_model" AS "test_model" WHERE "ds" <= '1970-01-01' AND "ds" >= '1970-01-01') AS "_q_0") AS "_q_1" WHERE "a_rank" > 1 OR "b_rank" > 1"""
196196
)
197197

198198

@@ -204,7 +204,7 @@ def test_accepted_values_audit(model: Model):
204204
)
205205
assert (
206206
rendered_query.sql()
207-
== "SELECT * FROM (SELECT * FROM db.test_model AS test_model WHERE ds <= '1970-01-01' AND ds >= '1970-01-01') AS _q_0 WHERE NOT a IN ('value_a', 'value_b')"
207+
== """SELECT * FROM (SELECT * FROM "db"."test_model" AS "test_model" WHERE "ds" <= '1970-01-01' AND "ds" >= '1970-01-01') AS "_q_0" WHERE NOT "a" IN ('value_a', 'value_b')"""
208208
)
209209

210210

@@ -215,7 +215,7 @@ def test_number_of_rows_audit(model: Model):
215215
)
216216
assert (
217217
rendered_query.sql()
218-
== """SELECT 1 AS "1" FROM (SELECT * FROM db.test_model AS test_model WHERE ds <= '1970-01-01' AND ds >= '1970-01-01') AS _q_0 HAVING COUNT(*) <= 0 LIMIT 1"""
218+
== """SELECT 1 AS "1" FROM (SELECT * FROM "db"."test_model" AS "test_model" WHERE "ds" <= '1970-01-01' AND "ds" >= '1970-01-01') AS "_q_0" HAVING COUNT(*) <= 0 LIMIT 1"""
219219
)
220220

221221

@@ -226,7 +226,7 @@ def test_forall_audit(model: Model):
226226
)
227227
assert (
228228
rendered_query_a.sql()
229-
== "SELECT * FROM (SELECT * FROM db.test_model AS test_model WHERE ds <= '1970-01-01' AND ds >= '1970-01-01') AS _q_0 WHERE NOT a >= b"
229+
== '''SELECT * FROM (SELECT * FROM "db"."test_model" AS "test_model" WHERE "ds" <= '1970-01-01' AND "ds" >= '1970-01-01') AS "_q_0" WHERE NOT "a" >= "b"'''
230230
)
231231

232232
rendered_query_a = builtin.forall_audit.render_query(
@@ -235,5 +235,5 @@ def test_forall_audit(model: Model):
235235
)
236236
assert (
237237
rendered_query_a.sql()
238-
== "SELECT * FROM (SELECT * FROM db.test_model AS test_model WHERE ds <= '1970-01-01' AND ds >= '1970-01-01') AS _q_0 WHERE NOT a >= b OR NOT c + d - e < 1.0"
238+
== """SELECT * FROM (SELECT * FROM "db"."test_model" AS "test_model" WHERE "ds" <= '1970-01-01' AND "ds" >= '1970-01-01') AS "_q_0" WHERE NOT "a" >= "b" OR NOT "c" + "d" - "e" < 1.0"""
239239
)

0 commit comments

Comments
 (0)