Skip to content

Commit 776c869

Browse files
authored
Fix: Allow SELECT * in models that reference external tables (#1007)
1 parent 4703f1e commit 776c869

File tree

5 files changed

+144
-49
lines changed

5 files changed

+144
-49
lines changed

sqlmesh/core/dialect.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
from sqlglot import Dialect, Generator, Parser, Tokenizer, TokenType, exp
1111
from sqlglot.dialects.dialect import DialectType
1212
from sqlglot.optimizer.normalize_identifiers import normalize_identifiers
13+
from sqlglot.optimizer.scope import traverse_scope
1314
from sqlglot.tokens import Token
1415

1516
from sqlmesh.core.constants import MAX_MODEL_DEFINITION_SIZE
@@ -645,3 +646,21 @@ def extract_columns_to_types(query: exp.Subqueryable) -> t.Dict[str, exp.DataTyp
645646
expression.output_name: expression.type or exp.DataType.build("unknown")
646647
for expression in query.selects
647648
}
649+
650+
651+
def find_tables(expression: exp.Expression, dialect: DialectType = None) -> t.Set[str]:
    """Collect the names of the tables referenced anywhere in a query.

    Every scope of the query is visited. A table reference is kept only when
    its `this` node is an identifier and its name is not one of the CTE
    sources of the scope it appears in.

    Args:
        expression: The query to search for table references.
        dialect: The dialect to use for normalization of table names.

    Returns:
        A set with the normalized name of every table that was kept.
    """
    table_names: t.Set[str] = set()

    for scope in traverse_scope(expression):
        for table in scope.tables:
            if not isinstance(table.this, exp.Identifier):
                continue
            # References that resolve to a CTE of this scope are not tables.
            if exp.table_name(table) in scope.cte_sources:
                continue
            table_names.add(normalize_model_name(table, dialect=dialect))

    return table_names

sqlmesh/core/loader.py

Lines changed: 7 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -54,18 +54,14 @@ def update_model_schemas(
5454
continue
5555

5656
model.update_schema(schema)
57+
5758
cache_hit = optimized_query_cache.with_optimized_query(model)
58-
schema.add_table(name, model.columns_to_types, dialect=model.dialect)
59-
60-
if any(dep not in models for dep in model.depends_on):
61-
if "*" in model.columns_to_types:
62-
raise ConfigError(
63-
f"Can't expand SELECT * expression for model '{name}' at '{model._path}'."
64-
" Either specify external source projections expliticly or"
65-
' add source tables as "external models" using the command'
66-
" 'sqlmesh create_external_models'."
67-
)
68-
elif isinstance(model, SqlModel) and model.mapping_schema and not cache_hit:
59+
60+
columns_to_types = model.columns_to_types
61+
if columns_to_types is not None:
62+
schema.add_table(name, columns_to_types, dialect=model.dialect)
63+
64+
if isinstance(model, SqlModel) and model.mapping_schema and not cache_hit:
6965
query = model.render_query()
7066
if query is not None:
7167
try:

sqlmesh/core/model/definition.py

Lines changed: 30 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@
1616
from sqlglot import diff, exp
1717
from sqlglot.diff import Insert, Keep
1818
from sqlglot.helper import ensure_list
19-
from sqlglot.optimizer.scope import traverse_scope
2019
from sqlglot.schema import MappingSchema, nested_set
2120
from sqlglot.time import format_time
2221

@@ -460,6 +459,18 @@ def update_schema(self, schema: MappingSchema) -> None:
460459
tuple(str(part) for part in table.parts),
461460
{k: str(v) for k, v in mapping_schema.items()},
462461
)
462+
else:
463+
# Reset the entire mapping if at least one upstream dependency is missing from the mapping
464+
# to prevent partial mappings from being used.
465+
logger.warning(
466+
"Missing schema for model '%s' referenced in model '%s'. Run `sqlmesh create_external_models` "
467+
"and / or make sure that the model '%s' can be rendered at parse time",
468+
dep,
469+
self.name,
470+
dep,
471+
)
472+
self.mapping_schema.clear()
473+
return
463474

464475
@property
465476
def depends_on(self) -> t.Set[str]:
@@ -468,16 +479,7 @@ def depends_on(self) -> t.Set[str]:
468479
Returns:
469480
A list of all the upstream table names.
470481
"""
471-
if self.depends_on_ is not None:
472-
return self.depends_on_ - {self.name}
473-
474-
if self._depends_on is None:
475-
query = self.render_query(optimize=False)
476-
if query is None:
477-
self._depends_on = set()
478-
else:
479-
self._depends_on = _find_tables(query) - {self.name}
480-
return self._depends_on
482+
return self.depends_on_ or set()
481483

482484
@property
483485
def columns_to_types(self) -> t.Optional[t.Dict[str, exp.DataType]]:
@@ -537,7 +539,7 @@ def depends_on_past(self) -> bool:
537539
query = self.render_query(optimize=False)
538540
if query is None:
539541
return False
540-
return self.name in _find_tables(query)
542+
return self.name in d.find_tables(query, dialect=self.dialect)
541543

542544
return self._depends_on_past
543545

@@ -559,7 +561,7 @@ def validate_definition(self) -> None:
559561
)
560562

561563
columns_to_types = self.columns_to_types
562-
if columns_to_types is not None and "*" not in columns_to_types:
564+
if columns_to_types is not None:
563565
column_names = {c.lower() for c in columns_to_types}
564566
missing_keys = unique_partition_keys - column_names
565567
if missing_keys:
@@ -742,6 +744,18 @@ def render_definition(self, include_python: bool = True) -> t.List[exp.Expressio
742744
def is_sql(self) -> bool:
743745
return True
744746

747+
@property
def depends_on(self) -> t.Set[str]:
    """The names of the upstream tables this model depends on.

    The result is computed on first access and cached in ``_depends_on``. It
    is the union of the explicitly provided ``depends_on_`` and the tables
    found in the unoptimized rendered query (when rendering returns a query),
    with the model's own name removed.

    Returns:
        A set of upstream table names.
    """
    if self._depends_on is None:
        # Copy rather than alias: the in-place updates below must not leak
        # into the user-provided depends_on_ set.
        self._depends_on = set(self.depends_on_ or ())

        query = self.render_query(optimize=False)
        if query is not None:
            self._depends_on |= d.find_tables(query, dialect=self.dialect)

        # A model never depends on itself.
        self._depends_on.discard(self.name)
    return self._depends_on
758+
745759
@property
746760
def columns_to_types(self) -> t.Optional[t.Dict[str, exp.DataType]]:
747761
if self.columns_to_types_ is not None:
@@ -753,6 +767,9 @@ def columns_to_types(self) -> t.Optional[t.Dict[str, exp.DataType]]:
753767
return None
754768
self._columns_to_types = d.extract_columns_to_types(query)
755769

770+
if "*" in self._columns_to_types:
771+
return None
772+
756773
return self._columns_to_types
757774

758775
@property
@@ -1511,23 +1528,6 @@ def _validate_model_fields(klass: t.Type[_Model], provided_fields: t.Set[str], p
15111528
raise_config_error(f"Invalid extra fields {extra_fields} in the model definition", path)
15121529

15131530

1514-
def _find_tables(expression: exp.Expression) -> t.Set[str]:
1515-
"""Find all tables referenced in a query.
1516-
1517-
Args:
1518-
expressions: The list of expressions to find tables for.
1519-
1520-
Returns:
1521-
A Set of all the table names.
1522-
"""
1523-
return {
1524-
exp.table_name(table)
1525-
for scope in traverse_scope(expression)
1526-
for table in scope.tables
1527-
if isinstance(table.this, exp.Identifier) and exp.table_name(table) not in scope.cte_sources
1528-
}
1529-
1530-
15311531
def _python_env(
15321532
expressions: t.Union[exp.Expression, t.List[exp.Expression]],
15331533
jinja_macro_references: t.Optional[t.Set[MacroReference]],

sqlmesh/core/renderer.py

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
from sqlglot.optimizer.qualify_columns import quote_identifiers
1515
from sqlglot.optimizer.qualify_tables import qualify_tables
1616
from sqlglot.optimizer.simplify import simplify
17-
from sqlglot.schema import ensure_schema
17+
from sqlglot.schema import MappingSchema
1818

1919
from sqlmesh.core import constants as c
2020
from sqlmesh.core import dialect as d
@@ -263,7 +263,6 @@ def render(
263263
),
264264
)
265265
except ParsetimeAdapterCallError:
266-
logger.debug("Failed to render query at parse time:\n%s", self._expression)
267266
return None
268267

269268
if not query:
@@ -325,10 +324,21 @@ def time_column_filter(self, start: TimeLike, end: TimeLike) -> exp.Between:
325324

326325
def _optimize_query(self, query: exp.Subqueryable) -> exp.Subqueryable:
327326
# We don't want to normalize names in the schema because that's handled by the optimizer
328-
schema = ensure_schema(self.schema, dialect=self._dialect, normalize=False)
327+
schema = MappingSchema(self.schema, dialect=self._dialect, normalize=False)
329328
original = query
330329
failure = False
331330

331+
if not schema.empty:
332+
for dependency in d.find_tables(query, dialect=self._dialect):
333+
if schema.find(exp.to_table(dependency, dialect=self._dialect)) is None:
334+
logger.warning(
335+
"Query cannot be optimized due to missing schema for model '%s'. "
336+
"Make sure that the model query can be rendered at parse time",
337+
dependency,
338+
)
339+
schema = MappingSchema(None, dialect=self._dialect, normalize=False)
340+
break
341+
332342
try:
333343
if not schema.empty:
334344
query = query.copy()

tests/core/test_model.py

Lines changed: 75 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,11 @@
11
from datetime import date
22
from pathlib import Path
3+
from unittest.mock import patch
34

45
import pytest
56
from pytest_mock.plugin import MockerFixture
67
from sqlglot import exp, parse, parse_one
8+
from sqlglot.schema import MappingSchema
79

810
import sqlmesh.core.dialect as d
911
from sqlmesh.core.config import Config
@@ -1227,7 +1229,7 @@ def test_case_sensitivity(assert_exp_eq):
12271229
"""
12281230
MODEL (name example.source, kind EMBEDDED);
12291231
1230-
SELECT "id", "name", "payload" FROM db.schema."table"
1232+
SELECT 'id' AS "id", 'name' AS "name", 'payload' AS "payload"
12311233
"""
12321234
),
12331235
dialect="snowflake",
@@ -1258,10 +1260,9 @@ def test_case_sensitivity(assert_exp_eq):
12581260
"SOURCE"."payload" AS "payload"
12591261
FROM (
12601262
SELECT
1261-
"id" AS "id",
1262-
"name" AS "name",
1263-
"payload" AS "payload"
1264-
FROM "DB"."SCHEMA"."table" AS "table"
1263+
'id' AS "id",
1264+
'name' AS "name",
1265+
'payload' AS "payload"
12651266
) AS "SOURCE"
12661267
""",
12671268
)
@@ -1386,3 +1387,72 @@ def runtime_macro(**kwargs) -> None:
13861387
match=r"Dependencies must be provided explicitly for models that can be rendered only at runtime at.*",
13871388
):
13881389
load_model(expressions)
1390+
1391+
1392+
def test_update_schema():
    """Check that update_schema applies a mapping only once all upstream tables are known."""
    definition = d.parse(
        """
        MODEL (name db.table);

        SELECT a, b FROM table_a JOIN table_b
        """
    )
    model = load_model(definition)

    upstream_schema = MappingSchema(normalize=False)

    # Only one of the two upstream tables is known at this point, so the
    # partial schema must not be applied.
    upstream_schema.add_table("table_a", {"a": exp.DataType.build("int")})
    model.update_schema(upstream_schema)
    assert not model.mapping_schema

    # With both upstream tables known the full mapping is expected.
    upstream_schema.add_table("table_b", {"b": exp.DataType.build("int")})
    model.update_schema(upstream_schema)

    expected_mapping = {
        "table_a": {"a": "INT"},
        "table_b": {"b": "INT"},
    }
    assert model.mapping_schema == expected_mapping
1417+
1418+
1419+
def test_user_provided_depends_on():
    """Check that dependencies declared in MODEL are combined with those found in the query."""
    definition = d.parse(
        """
        MODEL (name db.table, depends_on [table_b]);

        SELECT a FROM table_a
        """
    )

    dependencies = load_model(definition).depends_on

    # table_b comes from the MODEL block, table_a from the query itself.
    assert dependencies == {"table_a", "table_b"}
1431+
1432+
1433+
def test_check_schema_mapping_when_rendering_at_runtime(assert_exp_eq):
    """Check schema mapping and rendering for a query that cannot be rendered at parse time."""
    definition = d.parse(
        """
        MODEL (name db.table, depends_on [table_b]);

        SELECT * FROM table_a JOIN table_b
        """
    )
    model = load_model(definition)

    # Simulate a query that cannot be rendered at parse time.
    with patch.object(SqlModel, "render_query", return_value=None) as mocked_render:
        upstream_schema = MappingSchema(normalize=False)
        upstream_schema.add_table("table_b", {"b": exp.DataType.build("int")})
        model.update_schema(upstream_schema)

    assert "table_b" in model.mapping_schema
    assert model.depends_on == {"table_b"}

    mocked_render.assert_called_once()

    # Simulate rendering at runtime.
    rendered = model.render_query()
    assert_exp_eq(
        rendered, """SELECT * FROM "table_a" AS "table_a", "table_b" AS "table_b" """
    )

0 commit comments

Comments
 (0)