Skip to content

Commit a7e021f

Browse files
authored
Fix!: use dialect when generating types for mapping schema (#1531)
* Fix: use dialect when generating types for mapping schema * Add migration script * PR feedback
1 parent b5f8f0b commit a7e021f

3 files changed

Lines changed: 74 additions & 4 deletions

File tree

sqlmesh/core/model/definition.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -504,7 +504,7 @@ def update_schema(
504504
nested_set(
505505
self.mapping_schema,
506506
tuple(str(part) for part in table.parts),
507-
{k: str(v) for k, v in mapping_schema.items()},
507+
{k: v.sql(dialect=self.dialect) for k, v in mapping_schema.items()}, # type: ignore
508508
)
509509
else:
510510
# Reset the entire mapping if at least one upstream dependency is missing from the mapping
Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,68 @@
1+
"""Generate mapping schema data types using the corresponding model's dialect."""
2+
import json
3+
4+
import pandas as pd
5+
from sqlglot import exp, parse_one
6+
7+
from sqlmesh.utils.migration import index_text_type
8+
9+
10+
def migrate(state_sync):  # type: ignore
    """Rewrite stored snapshots so mapping schema types use each model's dialect.

    Reads every row from the state `_snapshots` table, converts the data types
    in each node's `mapping_schema` via the node's own dialect, then replaces
    the table contents with the migrated rows.
    """
    adapter = state_sync.engine_adapter
    schema = state_sync.schema

    # Qualify the snapshots table with the state schema when one is configured.
    table_name = "_snapshots"
    if schema:
        table_name = f"{schema}.{table_name}"

    query = exp.select("name", "identifier", "version", "snapshot", "kind_name").from_(table_name)

    migrated_rows = []
    for name, identifier, version, raw_snapshot, kind_name in adapter.fetchall(
        query,
        quote_identifiers=True,
    ):
        payload = json.loads(raw_snapshot)
        node = payload["node"]

        if node.get("mapping_schema"):
            # Re-render every column type using the model's dialect.
            node["mapping_schema"] = _convert_schema_types(node["mapping_schema"], node["dialect"])

        migrated_rows.append(
            {
                "name": name,
                "identifier": identifier,
                "version": version,
                "snapshot": json.dumps(payload),
                "kind_name": kind_name,
            }
        )

    if not migrated_rows:
        return

    # Replace the table contents wholesale with the migrated snapshots.
    adapter.delete_from(table_name, "TRUE")

    index_type = index_text_type(adapter.dialect)
    adapter.insert_append(
        table_name,
        pd.DataFrame(migrated_rows),
        columns_to_types={
            "name": exp.DataType.build(index_type),
            "identifier": exp.DataType.build(index_type),
            "version": exp.DataType.build(index_type),
            "snapshot": exp.DataType.build("text"),
            "kind_name": exp.DataType.build(index_type),
        },
        contains_json=True,
    )
57+
58+
def _convert_schema_types(schema, dialect):  # type: ignore
    """Recursively re-render the type strings of a mapping schema in *dialect*.

    Nested dicts (table/catalog levels) are descended into; leaf values are
    parsed as SQL type expressions and re-rendered. Mutates *schema* in place
    and returns it.
    """
    if not schema:
        return schema

    for key in schema:
        value = schema[key]
        schema[key] = (
            _convert_schema_types(value, dialect)
            if isinstance(value, dict)
            else parse_one(value).sql(dialect=dialect)
        )

    return schema

tests/core/test_context.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55

66
import pytest
77
from pytest_mock.plugin import MockerFixture
8-
from sqlglot import MappingSchema, parse_one
8+
from sqlglot import MappingSchema, exp, parse_one
99
from sqlglot.errors import SchemaError
1010

1111
import sqlmesh.core.constants
@@ -474,9 +474,11 @@ def test_default_schema_and_config(sushi_context_pre_scheduling) -> None:
474474
context.upsert_model(c)
475475

476476
c.update_schema(
477-
MappingSchema({"a": {"col": "int"}}), default_schema="schema", default_catalog="catalog"
477+
MappingSchema({"a": {"col": exp.DataType.build("int")}}),
478+
default_schema="schema",
479+
default_catalog="catalog",
478480
)
479-
assert c.mapping_schema == {"catalog": {"schema": {"a": {"col": "int"}}}}
481+
assert c.mapping_schema == {"catalog": {"schema": {"a": {"col": "INT"}}}}
480482

481483

482484
def test_gateway_macro(sushi_context: Context) -> None:

0 commit comments

Comments (0)