
Commit e220475

Fix!: quote all identifiers in execute except those coming from fetchdf (#1128)
* Fix!: normalize exprs in execute except those coming from fetchdf
* Fix tests
* Set default to false
* Simplify snowflake _insert_append_pandas_df
* Redshift fixup
* Cleanup
* Fix test
* Only quote identifiers in _to_sql
* Snowflake fixups
* Formatting
* Quote identifiers in _to_sql
* PR feedback
* Comment fixup
* Add empty dialect to seed model as well
* Remove default dialects in model level from wursthall project
* Add comment
* Fixup
1 parent bbaae27 commit e220475

36 files changed, with 365 additions and 259 deletions

examples/wursthall/models/db/customer_d.sql

Lines changed: 0 additions & 1 deletion
@@ -4,7 +4,6 @@ MODEL (
     time_column (register_ds, '%Y-%m-%d'),
     batch_size 200,
   ),
-  dialect "",
   cron '@daily',
   owner jen,
   start '2022-06-01 00:00:00+00:00',

examples/wursthall/models/db/item_d.sql

Lines changed: 0 additions & 1 deletion
@@ -1,7 +1,6 @@
 MODEL (
   name db.item_d,
   kind VIEW,
-  dialect "",
   cron '@daily',
   owner jen,
   start '2022-06-01 00:00:00+00:00',

examples/wursthall/models/db/order_f.py

Lines changed: 20 additions & 17 deletions
@@ -5,6 +5,7 @@
 import numpy as np
 import pandas as pd
 from models.src.shared import DATA_START_DATE_STR, set_seed  # type: ignore
+from sqlglot import parse_one

 from sqlmesh import ExecutionContext, model
 from sqlmesh.core.model import IncrementalByTimeRangeKind, TimeColumn
@@ -43,27 +44,29 @@ def execute(
     item_d_table_name = context.table("db.item_d")
     order_item_f_table_name = context.table("db.order_item_f")

+    # We use parse_one here instead of a raw string because this is a multi-dialect
+    # project and we want to ensure that the resulting query is properly quoted in
+    # the target dialect before executing it
     df_item_d = context.fetchdf(
-        f"""
-        SELECT
-          item_id,
-          item_price
-        FROM {item_d_table_name}
-        """
+        parse_one(f"SELECT item_id, item_price FROM {item_d_table_name}"),
+        quote_identifiers=True,
     )

     df_order_item_f = context.fetchdf(
-        f"""
-        SELECT
-          order_id,
-          customer_id,
-          item_id,
-          quantity,
-          order_ds
-        FROM {order_item_f_table_name}
-        WHERE
-          order_ds BETWEEN '{to_ds(start)}' AND '{to_ds(end)}'
-        """
+        parse_one(
+            f"""
+            SELECT
+              order_id,
+              customer_id,
+              item_id,
+              quantity,
+              order_ds
+            FROM {order_item_f_table_name}
+            WHERE
+              order_ds BETWEEN '{to_ds(start)}' AND '{to_ds(end)}'
+            """
+        ),
+        quote_identifiers=True,
     )

     df_order_item_f = df_order_item_f.merge(df_item_d, how="inner", on="item_id")
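To make the comment above concrete, here is a minimal sketch of what dialect-aware quoting does to a query like the one passed to fetchdf. The physical table name is made up for illustration (the real value comes from context.table("db.item_d")), and identify=True is sqlglot's generator-level switch to quote every identifier, roughly what opting into quote_identifiers=True asks the engine adapter to do downstream.

    from sqlglot import parse_one

    # Hypothetical physical table name, for illustration only.
    item_d_table_name = "sqlmesh__db.db__item_d__1234"

    expr = parse_one(f"SELECT item_id, item_price FROM {item_d_table_name}")

    # identify=True quotes every identifier for the chosen dialect.
    print(expr.sql(dialect="snowflake", identify=True))
    # e.g. SELECT "item_id", "item_price" FROM "sqlmesh__db"."db__item_d__1234"
    print(expr.sql(dialect="bigquery", identify=True))
    # e.g. SELECT `item_id`, `item_price` FROM `sqlmesh__db`.`db__item_d__1234`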

examples/wursthall/models/db/order_item_f.sql

Lines changed: 0 additions & 1 deletion
@@ -4,7 +4,6 @@ MODEL (
     time_column (order_ds, '%Y-%m-%d'),
     batch_size 200,
   ),
-  dialect "",
   cron '@daily',
   owner jen,
   start '2022-06-01 00:00:00+00:00',

examples/wursthall/models/src/menu_item_details.sql

Lines changed: 1 addition & 1 deletion
@@ -4,4 +4,4 @@ MODEL (
     path '../../seeds/src/menu_item_details.csv',
   ),
   owner jen
-)
+)

examples/wursthall/models/src/order_item_details.py

Lines changed: 17 additions & 13 deletions
@@ -7,6 +7,7 @@
 import pandas as pd
 from faker import Faker
 from models.src.shared import DATA_START_DATE_STR, iter_dates, set_seed  # type: ignore
+from sqlglot import parse_one

 from sqlmesh import ExecutionContext, model
 from sqlmesh.core.model import IncrementalByTimeRangeKind, TimeColumn
@@ -49,23 +50,26 @@ def execute(
     customer_details_table_name = context.table("src.customer_details")
     menu_item_details_table_name = context.table("src.menu_item_details")

+    # We use parse_one here instead of a raw string because this is a multi-dialect
+    # project and we want to ensure that the resulting query is properly quoted in
+    # the target dialect before executing it
     df_customers = context.fetchdf(
-        f"""
-        SELECT
-          id AS customer_id,
-          register_ds
-        FROM {customer_details_table_name}
-        WHERE
-          register_ds <= '{to_ds(end)}'
-        """
+        parse_one(
+            f"""
+            SELECT
+              id AS customer_id,
+              register_ds
+            FROM {customer_details_table_name}
+            WHERE
+              register_ds <= '{to_ds(end)}'
+            """
+        ),
+        quote_identifiers=True,
     )

     df_menu_items = context.fetchdf(
-        f"""
-        SELECT
-          id AS item_id
-        FROM {menu_item_details_table_name}
-        """
+        parse_one(f"SELECT id AS item_id FROM {menu_item_details_table_name}"),
+        quote_identifiers=True,
     )

     num_menu_items = len(df_menu_items.index)
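Per the commit title, fetchdf keeps quoting opt-in (its default stays False, see sqlmesh/core/context.py below), so only callers that pass a parsed expression are affected. A minimal sketch of the contrast, reusing the table-name variable from the diff above and assuming an ExecutionContext named context:

    # Raw SQL strings are forwarded to the engine untouched, so hand-written,
    # dialect-specific SQL keeps its exact text.
    df_raw = context.fetchdf(f"SELECT id FROM {menu_item_details_table_name}")

    # Parsed expressions can opt in to quoting so identifiers survive case-sensitive
    # or keyword-colliding names in the target dialect.
    df_quoted = context.fetchdf(
        parse_one(f"SELECT id FROM {menu_item_details_table_name}"),
        quote_identifiers=True,
    )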

setup.py

Lines changed: 1 addition & 0 deletions
@@ -115,6 +115,7 @@
     ],
     "snowflake": [
         "snowflake-connector-python[pandas,secure-local-storage]",
+        "pyarrow>=10.0.1,<10.1.0",
     ],
     "web": [
         "fastapi==0.100.0",

sqlmesh/core/context.py

Lines changed: 10 additions & 4 deletions
@@ -130,27 +130,33 @@ def table(self, model_name: str) -> str:
         """
         return self._model_tables[model_name]

-    def fetchdf(self, query: t.Union[exp.Expression, str]) -> pd.DataFrame:
+    def fetchdf(
+        self, query: t.Union[exp.Expression, str], quote_identifiers: bool = False
+    ) -> pd.DataFrame:
         """Fetches a dataframe given a sql string or sqlglot expression.

         Args:
             query: SQL string or sqlglot expression.
+            quote_identifiers: Whether to quote all identifiers in the query.

         Returns:
             The default dataframe is Pandas, but for Spark a PySpark dataframe is returned.
         """
-        return self.engine_adapter.fetchdf(query)
+        return self.engine_adapter.fetchdf(query, quote_identifiers=quote_identifiers)

-    def fetch_pyspark_df(self, query: t.Union[exp.Expression, str]) -> PySparkDataFrame:
+    def fetch_pyspark_df(
+        self, query: t.Union[exp.Expression, str], quote_identifiers: bool = False
+    ) -> PySparkDataFrame:
         """Fetches a PySpark dataframe given a sql string or sqlglot expression.

         Args:
             query: SQL string or sqlglot expression.
+            quote_identifiers: Whether to quote all identifiers in the query.

         Returns:
             A PySpark dataframe.
         """
-        return self.engine_adapter.fetch_pyspark_df(query)
+        return self.engine_adapter.fetch_pyspark_df(query, quote_identifiers=quote_identifiers)


 class ExecutionContext(BaseContext):
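Both wrappers simply forward the new flag to the engine adapter. From a Python model, a call might look like the sketch below; the column names are illustrative, and fetch_pyspark_df still requires a Spark-based engine:

    expr = parse_one(f"SELECT ds, revenue FROM {context.table('db.order_f')}")

    pdf = context.fetchdf(expr, quote_identifiers=True)            # pandas DataFrame
    sdf = context.fetch_pyspark_df(expr, quote_identifiers=True)   # PySpark DataFrame, Spark engines only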

sqlmesh/core/engine_adapter/base.py

Lines changed: 43 additions & 11 deletions
@@ -19,6 +19,7 @@
 from sqlglot import Dialect, exp
 from sqlglot.errors import ErrorLevel
 from sqlglot.helper import ensure_list
+from sqlglot.optimizer.qualify_columns import quote_identifiers

 from sqlmesh.core.dialect import pandas_to_sql
 from sqlmesh.core.engine_adapter.shared import DataObject, TransactionType
@@ -698,7 +699,9 @@ def _merge(
         match_expressions: t.List[exp.When],
     ) -> None:
         this = exp.alias_(exp.to_table(target_table), alias=MERGE_TARGET_ALIAS, table=True)
-        using = exp.Subquery(this=source_table, alias=MERGE_SOURCE_ALIAS)
+        using = exp.alias_(
+            exp.Subquery(this=source_table), alias=MERGE_SOURCE_ALIAS, copy=False, table=True
+        )
         self.execute(
             exp.Merge(
                 this=this,
@@ -778,33 +781,49 @@ def fetchone(
         self,
         query: t.Union[exp.Expression, str],
         ignore_unsupported_errors: bool = False,
+        quote_identifiers: bool = False,
     ) -> t.Tuple:
-        self.execute(query, ignore_unsupported_errors=ignore_unsupported_errors)
+        self.execute(
+            query,
+            ignore_unsupported_errors=ignore_unsupported_errors,
+            quote_identifiers=quote_identifiers,
+        )
         return self.cursor.fetchone()

     def fetchall(
         self,
         query: t.Union[exp.Expression, str],
         ignore_unsupported_errors: bool = False,
+        quote_identifiers: bool = False,
     ) -> t.List[t.Tuple]:
-        self.execute(query, ignore_unsupported_errors=ignore_unsupported_errors)
+        self.execute(
+            query,
+            ignore_unsupported_errors=ignore_unsupported_errors,
+            quote_identifiers=quote_identifiers,
+        )
         return self.cursor.fetchall()

-    def _fetch_native_df(self, query: t.Union[exp.Expression, str]) -> DF:
+    def _fetch_native_df(
+        self, query: t.Union[exp.Expression, str], quote_identifiers: bool = False
+    ) -> DF:
         """Fetches a DataFrame that can be either Pandas or PySpark from the cursor"""
-        self.execute(query)
+        self.execute(query, quote_identifiers=quote_identifiers)
         return self.cursor.fetchdf()

-    def fetchdf(self, query: t.Union[exp.Expression, str]) -> pd.DataFrame:
+    def fetchdf(
+        self, query: t.Union[exp.Expression, str], quote_identifiers: bool = False
+    ) -> pd.DataFrame:
         """Fetches a Pandas DataFrame from the cursor"""
-        df = self._fetch_native_df(query)
+        df = self._fetch_native_df(query, quote_identifiers=quote_identifiers)
        if not isinstance(df, pd.DataFrame):
            raise NotImplementedError(
                "The cursor's `fetch_native_df` method is not returning a pandas DataFrame. Need to update `fetchdf` so a Pandas DataFrame is returned"
            )
        return df

-    def fetch_pyspark_df(self, query: t.Union[exp.Expression, str]) -> PySparkDataFrame:
+    def fetch_pyspark_df(
+        self, query: t.Union[exp.Expression, str], quote_identifiers: bool = False
+    ) -> PySparkDataFrame:
         """Fetches a PySpark DataFrame from the cursor"""
         raise NotImplementedError(f"Engine does not support PySpark DataFrames: {type(self)}")

@@ -835,6 +854,7 @@ def execute(
         self,
         expressions: t.Union[str, exp.Expression, t.Sequence[exp.Expression]],
         ignore_unsupported_errors: bool = False,
+        quote_identifiers: bool = True,
         **kwargs: t.Any,
     ) -> None:
         """Execute a sql query."""
@@ -843,7 +863,11 @@ def execute(
         )

         for e in ensure_list(expressions):
-            sql = self._to_sql(e, **to_sql_kwargs) if isinstance(e, exp.Expression) else e
+            sql = (
+                self._to_sql(e, quote=quote_identifiers, **to_sql_kwargs)
+                if isinstance(e, exp.Expression)
+                else e
+            )
             logger.debug(f"Executing SQL:\n{sql}")
             self.cursor.execute(sql, **kwargs)

@@ -882,7 +906,7 @@ def _create_table_properties(
         """Creates a SQLGlot table properties expression for ddl."""
         return None

-    def _to_sql(self, e: exp.Expression, **kwargs: t.Any) -> str:
+    def _to_sql(self, expression: exp.Expression, quote: bool = True, **kwargs: t.Any) -> str:
         """
         Converts an expression to a SQL string. Has a set of default kwargs to apply, and then default
         kwargs defined for the given dialect, and then kwargs provided by the user when defining the engine
@@ -896,7 +920,11 @@ def _to_sql(self, e: exp.Expression, **kwargs: t.Any) -> str:
             **self.sql_gen_kwargs,
             **kwargs,
         }
-        return e.sql(**sql_gen_kwargs)  # type: ignore
+
+        if quote:
+            quote_identifiers(expression)
+
+        return expression.sql(**sql_gen_kwargs)  # type: ignore

     def _get_data_objects(
         self, schema_name: str, catalog_name: t.Optional[str] = None
@@ -917,9 +945,11 @@ def _get_temp_table(
         """
         table = t.cast(exp.Table, exp.to_table(table).copy())
         table.set("this", exp.to_identifier(f"__temp_{table.name}_{uuid.uuid4().hex}"))
+
         if table_only:
             table.set("db", None)
             table.set("catalog", None)
+
         return table

     def _add_where_to_query(self, query: Query, where: t.Optional[exp.Expression]) -> Query:
@@ -933,8 +963,10 @@ def _add_where_to_query(self, query: Query, where: t.Optional[exp.Expression]) -> Query:
             .from_(query.subquery("_subquery", copy=False), copy=False)
             .where(where, copy=False)
         )
+
         if with_:
             query.set("with", with_)
+
         return query
sqlmesh/core/engine_adapter/base_postgres.py

Lines changed: 7 additions & 1 deletion
@@ -100,6 +100,7 @@ def execute(
         self,
         expressions: t.Union[str, exp.Expression, t.Sequence[exp.Expression]],
         ignore_unsupported_errors: bool = False,
+        quote_identifiers: bool = True,
         **kwargs: t.Any,
     ) -> None:
         """
@@ -108,7 +109,12 @@ def execute(
 
         Reference: https://www.psycopg.org/psycopg3/docs/basic/transactions.html
         """
-        super().execute(expressions, ignore_unsupported_errors=ignore_unsupported_errors, **kwargs)
+        super().execute(
+            expressions,
+            ignore_unsupported_errors=ignore_unsupported_errors,
+            quote_identifiers=quote_identifiers,
+            **kwargs,
+        )
         if not self._connection_pool.is_transaction_active:
             self._connection_pool.commit()
