Skip to content

Commit 814fc01

Browse files
authored
Fix: Support insert overwrite with dynamic partitions in the BigQuery adapter (#1188)
1 parent 151210b commit 814fc01

6 files changed

Lines changed: 118 additions & 55 deletions

File tree

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@
4545
"requests",
4646
"rich",
4747
"ruamel.yaml",
48-
"sqlglot~=17.4.0",
48+
"sqlglot~=17.6.0",
4949
"fsspec",
5050
],
5151
extras_require={

sqlmesh/core/engine_adapter/base.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -592,10 +592,11 @@ def _insert_append_pandas_df(
592592
):
593593
self.execute(exp.insert(expression, table_name, columns=column_names))
594594

595-
def insert_overwrite(
595+
def insert_overwrite_by_partition(
596596
self,
597597
table_name: TableName,
598598
query_or_df: QueryOrDF,
599+
partitioned_by: t.List[exp.Expression],
599600
columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None,
600601
) -> None:
601602
self._insert_overwrite_by_condition(
@@ -905,7 +906,7 @@ def execute(
905906
self.cursor.execute(sql, **kwargs)
906907

907908
@contextlib.contextmanager
908-
def temp_table(self, query_or_df: QueryOrDF, name: str = "diff") -> t.Iterator[exp.Table]:
909+
def temp_table(self, query_or_df: QueryOrDF, name: TableName = "diff") -> t.Iterator[exp.Table]:
909910
"""A context manager for working a temp table.
910911
911912
The table will be created with a random guid and cleaned up after the block.

sqlmesh/core/engine_adapter/bigquery.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -354,6 +354,46 @@ def merge(
354354
) as source_table:
355355
return super().merge(target_table, source_table, columns_to_types, unique_key)
356356

357+
def insert_overwrite_by_partition(
358+
self,
359+
table_name: TableName,
360+
query_or_df: QueryOrDF,
361+
partitioned_by: t.List[exp.Expression],
362+
columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None,
363+
) -> None:
364+
if len(partitioned_by) != 1:
365+
raise SQLMeshError(
366+
f"Bigquery only supports partitioning by one column, {len(partitioned_by)} were provided."
367+
)
368+
369+
partition_exp = partitioned_by[0]
370+
partition_sql = partition_exp.sql(dialect=self.dialect)
371+
partition_column = partition_exp.find(exp.Column)
372+
373+
if not partition_column:
374+
raise SQLMeshError(
375+
f"The partition expression '{partition_sql}' doesn't contain a column."
376+
)
377+
378+
with self.session(), self.temp_table(query_or_df, name=table_name) as temp_table_name:
379+
if columns_to_types is None:
380+
columns_to_types = self.columns(temp_table_name)
381+
382+
partition_type_sql = columns_to_types[partition_column.name].sql(dialect=self.dialect)
383+
temp_table_name_sql = temp_table_name.sql(dialect=self.dialect)
384+
self.execute(
385+
f"DECLARE target_partitions ARRAY<{partition_type_sql}> DEFAULT (SELECT ARRAY_AGG(DISTINCT {partition_sql}) FROM {temp_table_name_sql});"
386+
)
387+
388+
where = t.cast(exp.Condition, partition_exp).isin(unnest="target_partitions")
389+
390+
self._insert_overwrite_by_condition(
391+
table_name,
392+
exp.select("*").from_(temp_table_name),
393+
where=where,
394+
columns_to_types=columns_to_types,
395+
)
396+
357397
def _insert_overwrite_by_condition(
358398
self,
359399
table_name: TableName,

sqlmesh/core/snapshot/evaluator.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -800,8 +800,8 @@ def insert(
800800
**kwargs: t.Any,
801801
) -> None:
802802
if isinstance(model.kind, IncrementalUnmanagedKind) and model.kind.insert_overwrite:
803-
self.adapter.insert_overwrite(
804-
name, query_or_df, columns_to_types=model.columns_to_types
803+
self.adapter.insert_overwrite_by_partition(
804+
name, query_or_df, model.partitioned_by, columns_to_types=model.columns_to_types
805805
)
806806
else:
807807
self.append(model, name, query_or_df, snapshots, is_dev, **kwargs)

tests/core/engine_adapter/test_bigquery.py

Lines changed: 58 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
# type: ignore
22
import sys
3+
import typing as t
4+
import uuid
35

46
import pandas as pd
57
import pytest
@@ -35,18 +37,48 @@ def test_insert_overwrite_by_time_partition_query(mocker: MockerFixture):
3537
"ds": exp.DataType.build("string"),
3638
},
3739
)
38-
sql_calls = [
39-
# Python 3.7 support
40-
call[0][0].sql(dialect="bigquery", identify=True)
41-
if isinstance(call[0], tuple)
42-
else call[0].sql(dialect="bigquery", identify=True)
43-
for call in execute_mock.call_args_list
44-
]
40+
sql_calls = _to_sql_calls(execute_mock)
4541
assert sql_calls == [
4642
"MERGE INTO `test_table` AS `__MERGE_TARGET__` USING (SELECT * FROM (SELECT `a`, `ds` FROM `tbl`) AS `_subquery` WHERE `ds` BETWEEN '2022-01-01' AND '2022-01-05') AS `__MERGE_SOURCE__` ON FALSE WHEN NOT MATCHED BY SOURCE AND `ds` BETWEEN '2022-01-01' AND '2022-01-05' THEN DELETE WHEN NOT MATCHED THEN INSERT (`a`, `ds`) VALUES (`a`, `ds`)"
4743
]
4844

4945

46+
def test_insert_overwrite_by_partition_query(mocker: MockerFixture):
47+
connection_mock = mocker.NonCallableMock()
48+
cursor_mock = mocker.Mock()
49+
connection_mock.cursor.return_value = cursor_mock
50+
51+
adapter = BigQueryEngineAdapter(lambda: connection_mock)
52+
execute_mock = mocker.patch(
53+
"sqlmesh.core.engine_adapter.bigquery.BigQueryEngineAdapter.execute"
54+
)
55+
56+
temp_table_uuid = uuid.uuid4()
57+
uuid4_mock = mocker.patch("uuid.uuid4")
58+
uuid4_mock.return_value = temp_table_uuid
59+
60+
adapter.insert_overwrite_by_partition(
61+
"test_schema.test_table",
62+
parse_one("SELECT a, ds FROM tbl"),
63+
partitioned_by=[
64+
d.parse_one("DATETIME_TRUNC(ds, MONTH)"),
65+
],
66+
columns_to_types={
67+
"a": exp.DataType.build("int"),
68+
"ds": exp.DataType.build("DATETIME"),
69+
},
70+
)
71+
72+
sql_calls = _to_sql_calls(execute_mock)
73+
assert sql_calls == [
74+
"CREATE SCHEMA IF NOT EXISTS `test_schema`",
75+
f"CREATE TABLE IF NOT EXISTS `test_schema`.`__temp_test_table_{temp_table_uuid.hex}` AS SELECT `a`, `ds` FROM `tbl`",
76+
f"DECLARE target_partitions ARRAY<DATETIME> DEFAULT (SELECT ARRAY_AGG(DISTINCT DATETIME_TRUNC(ds, MONTH)) FROM test_schema.__temp_test_table_{temp_table_uuid.hex});",
77+
f"MERGE INTO `test_schema`.`test_table` AS `__MERGE_TARGET__` USING (SELECT * FROM (SELECT * FROM `test_schema`.`__temp_test_table_{temp_table_uuid.hex}`) AS `_subquery` WHERE DATETIME_TRUNC(`ds`, MONTH) IN UNNEST(`target_partitions`)) AS `__MERGE_SOURCE__` ON FALSE WHEN NOT MATCHED BY SOURCE AND DATETIME_TRUNC(`ds`, MONTH) IN UNNEST(`target_partitions`) THEN DELETE WHEN NOT MATCHED THEN INSERT (`a`, `ds`) VALUES (`a`, `ds`)",
78+
f"DROP TABLE IF EXISTS `test_schema`.`__temp_test_table_{temp_table_uuid.hex}`",
79+
]
80+
81+
5082
def test_insert_overwrite_by_time_partition_pandas(mocker: MockerFixture):
5183
connection_mock = mocker.NonCallableMock()
5284
cursor_mock = mocker.Mock()
@@ -130,13 +162,7 @@ def test_replace_query(mocker: MockerFixture):
130162
)
131163
adapter.replace_query("test_table", parse_one("SELECT a FROM tbl"), {"a": "int"})
132164

133-
sql_calls = [
134-
# Python 3.7 support
135-
call[0][0].sql(dialect="bigquery", identify=True)
136-
if isinstance(call[0], tuple)
137-
else call[0].sql(dialect="bigquery", identify=True)
138-
for call in execute_mock.call_args_list
139-
]
165+
sql_calls = _to_sql_calls(execute_mock)
140166
assert sql_calls == ["CREATE OR REPLACE TABLE `test_table` AS SELECT `a` FROM `tbl`"]
141167

142168

@@ -221,13 +247,7 @@ def test_create_table_date_partition(
221247
clustered_by=["b"],
222248
)
223249

224-
sql_calls = [
225-
# Python 3.7 support
226-
call[0][0].sql(dialect="bigquery", identify=True)
227-
if isinstance(call[0], tuple)
228-
else call[0].sql(dialect="bigquery", identify=True)
229-
for call in execute_mock.call_args_list
230-
]
250+
sql_calls = _to_sql_calls(execute_mock)
231251
assert sql_calls == [
232252
f"CREATE TABLE IF NOT EXISTS `test_table` (`a` int, `b` int) PARTITION BY {partition_by_statement} CLUSTER BY `b`"
233253
]
@@ -261,13 +281,7 @@ def test_create_table_time_partition(
261281
partition_interval_unit=IntervalUnit.HOUR,
262282
)
263283

264-
sql_calls = [
265-
# Python 3.7 support
266-
call[0][0].sql(dialect="bigquery", identify=True)
267-
if isinstance(call[0], tuple)
268-
else call[0].sql(dialect="bigquery", identify=True)
269-
for call in execute_mock.call_args_list
270-
]
284+
sql_calls = _to_sql_calls(execute_mock)
271285
assert sql_calls == [
272286
f"CREATE TABLE IF NOT EXISTS `test_table` (`a` int, `b` int) PARTITION BY {partition_by_statement}"
273287
]
@@ -292,13 +306,7 @@ def test_merge(mocker: MockerFixture):
292306
},
293307
unique_key=["id"],
294308
)
295-
sql_calls = [
296-
# Python 3.7 support
297-
call[0][0].sql(dialect="bigquery")
298-
if isinstance(call[0], tuple)
299-
else call[0].sql(dialect="bigquery")
300-
for call in execute_mock.call_args_list
301-
]
309+
sql_calls = _to_sql_calls(execute_mock, identify=False)
302310
assert sql_calls == [
303311
"MERGE INTO target AS __MERGE_TARGET__ USING (SELECT id, ts, val FROM source) AS __MERGE_SOURCE__ ON __MERGE_TARGET__.id = __MERGE_SOURCE__.id "
304312
"WHEN MATCHED THEN UPDATE SET __MERGE_TARGET__.id = __MERGE_SOURCE__.id, __MERGE_TARGET__.ts = __MERGE_SOURCE__.ts, __MERGE_TARGET__.val = __MERGE_SOURCE__.val "
@@ -335,13 +343,7 @@ def test_merge(mocker: MockerFixture):
335343
unique_key=["id"],
336344
)
337345

338-
sql_calls = [
339-
# Python 3.7 support
340-
call[0][0].sql(dialect="bigquery")
341-
if isinstance(call[0], tuple)
342-
else call[0].sql(dialect="bigquery")
343-
for call in execute_mock.call_args_list
344-
]
346+
sql_calls = _to_sql_calls(execute_mock, identify=False)
345347
assert sql_calls == [
346348
"MERGE INTO target AS __MERGE_TARGET__ USING (SELECT id, ts, val FROM project.dataset.temp_table) AS __MERGE_SOURCE__ ON __MERGE_TARGET__.id = __MERGE_SOURCE__.id "
347349
"WHEN MATCHED THEN UPDATE SET __MERGE_TARGET__.id = __MERGE_SOURCE__.id, __MERGE_TARGET__.ts = __MERGE_SOURCE__.ts, __MERGE_TARGET__.val = __MERGE_SOURCE__.val "
@@ -399,3 +401,17 @@ def test_begin_end_session(mocker: MockerFixture):
399401
execute_b_call = connection_mock._client.query.call_args_list[2]
400402
assert execute_b_call[1]["query"] == "SELECT 3;"
401403
assert not execute_b_call[1]["job_config"].connection_properties
404+
405+
406+
def _to_sql_calls(execute_mock: t.Any, identify: bool = True) -> t.List[str]:
407+
output = []
408+
for call in execute_mock.call_args_list:
409+
# Python 3.7 support
410+
value = call[0][0] if isinstance(call[0], tuple) else call[0]
411+
sql = (
412+
value.sql(dialect="bigquery", identify=identify)
413+
if isinstance(value, exp.Expression)
414+
else str(value)
415+
)
416+
output.append(sql)
417+
return output

tests/core/test_snapshot_evaluator.py

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -277,6 +277,7 @@ def test_evaluate_incremental_unmanaged(
277277
name="test_schema.test_model",
278278
query=parse_one("SELECT 1, ds FROM tbl_a"),
279279
kind=IncrementalUnmanagedKind(insert_overwrite=insert_overwrite),
280+
partitioned_by=["ds"],
280281
)
281282
snapshot = make_snapshot(model)
282283
snapshot.categorize_as(SnapshotChangeCategory.BREAKING)
@@ -290,14 +291,19 @@ def test_evaluate_incremental_unmanaged(
290291
snapshots={},
291292
)
292293

293-
expected_call = (
294-
adapter_mock.insert_overwrite if insert_overwrite else adapter_mock.insert_append
295-
)
296-
expected_call.assert_called_once_with(
297-
snapshot.table_name(),
298-
model.render_query(),
299-
columns_to_types=model.columns_to_types,
300-
)
294+
if insert_overwrite:
295+
adapter_mock.insert_overwrite_by_partition.assert_called_once_with(
296+
snapshot.table_name(),
297+
model.render_query(),
298+
[exp.to_column("ds")],
299+
columns_to_types=model.columns_to_types,
300+
)
301+
else:
302+
adapter_mock.insert_append.assert_called_once_with(
303+
snapshot.table_name(),
304+
model.render_query(),
305+
columns_to_types=model.columns_to_types,
306+
)
301307

302308

303309
def test_create_materialized_view(mocker: MockerFixture, adapter_mock, make_snapshot):

0 commit comments

Comments
 (0)