Skip to content

Commit 34782b8

Browse files
authored
Feat: Support sessions in the BigQuery adapter (#1185)
1 parent d45e070 commit 34782b8

6 files changed

Lines changed: 164 additions & 10 deletions

File tree

sqlmesh/core/engine_adapter/base.py

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -840,7 +840,7 @@ def fetch_pyspark_df(
840840
@contextlib.contextmanager
841841
def transaction(
842842
self, transaction_type: TransactionType = TransactionType.DML
843-
) -> t.Generator[None, None, None]:
843+
) -> t.Iterator[None]:
844844
"""A transaction context manager."""
845845
if self._connection_pool.is_transaction_active or not self.supports_transactions(
846846
transaction_type
@@ -860,6 +860,29 @@ def supports_transactions(self, transaction_type: TransactionType) -> bool:
860860
"""Whether or not the engine adapter supports transactions for the given transaction type."""
861861
return True
862862

863+
@contextlib.contextmanager
def session(self) -> t.Iterator[None]:
    """Run the enclosed block inside an engine session.

    If a session is already active, the block simply joins it and nothing is
    begun or ended here. Otherwise a session is begun before the block runs
    and ended once the block exits, including when the block raises.
    """
    # Only the call that actually begins the session is responsible for ending it.
    owns_session = not self._is_session_active()
    if owns_session:
        self._begin_session()
    try:
        yield
    finally:
        if owns_session:
            self._end_session()
875+
876+
def _begin_session(self) -> None:
877+
"""Begin a new session."""
878+
879+
def _end_session(self) -> None:
880+
"""End the existing session."""
881+
882+
def _is_session_active(self) -> bool:
883+
"""Indicates whether or not a session is active."""
884+
return False
885+
863886
def execute(
864887
self,
865888
expressions: t.Union[str, exp.Expression, t.Sequence[exp.Expression]],

sqlmesh/core/engine_adapter/bigquery.py

Lines changed: 59 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,21 @@ def _job_params(self) -> t.Dict[str, t.Any]:
8787
params["maximum_bytes_billed"] = self._extra_config.get("maximum_bytes_billed")
8888
return params
8989

90+
def _begin_session(self) -> None:
    """Create a BigQuery session and remember its ID for subsequent queries.

    Submits a trivial query with ``create_session=True`` and stores the
    session ID reported by the job (or ``None`` when the job carries no
    session info) in ``_session_id``.

    The ID is recorded only after the job has finished, so a failure while
    creating the session does not leave a session ID behind.

    Raises:
        Whatever ``client.query`` or ``job.result`` raise when the job cannot
        be submitted or fails.
    """
    from google.cloud.bigquery import QueryJobConfig

    job = self.client.query("SELECT 1;", job_config=QueryJobConfig(create_session=True))
    # Wait for the job before recording anything: `session()` begins the session
    # outside of its try/finally, so an ID stored ahead of a failure here would
    # never be cleared by `_end_session`.
    job.result()
    session_info = job.session_info
    self._session_id = session_info.session_id if session_info else None
98+
99+
def _end_session(self) -> None:
100+
self._session_id = None
101+
102+
def _is_session_active(self) -> bool:
103+
return self._session_id is not None
104+
90105
def create_schema(self, schema_name: str, ignore_if_exists: bool = True) -> None:
91106
"""Create a schema from a name or qualified table name."""
92107
from google.api_core.exceptions import Conflict
@@ -153,7 +168,7 @@ def fetchone(
153168
quote_identifiers=quote_identifiers,
154169
)
155170
try:
156-
return next(self.cursor._query_data)
171+
return next(self._query_data)
157172
except StopIteration:
158173
return ()
159174

@@ -172,7 +187,7 @@ def fetchall(
172187
ignore_unsupported_errors=ignore_unsupported_errors,
173188
quote_identifiers=quote_identifiers,
174189
)
175-
return list(self.cursor._query_data)
190+
return list(self._query_data)
176191

177192
def _create_table_from_df(
178193
self,
@@ -410,7 +425,7 @@ def _fetch_native_df(
410425
self, query: t.Union[exp.Expression, str], quote_identifiers: bool = False
411426
) -> DF:
412427
self.execute(query, quote_identifiers=quote_identifiers)
413-
return self.cursor._query_job.to_dataframe()
428+
return self._query_job.to_dataframe()
414429

415430
def _create_table_properties(
416431
self,
@@ -487,6 +502,7 @@ def execute(
487502
) -> None:
488503
"""Execute a sql query."""
489504
from google.cloud.bigquery import QueryJobConfig
505+
from google.cloud.bigquery.query import ConnectionProperty
490506

491507
to_sql_kwargs = (
492508
{"unsupported_level": ErrorLevel.IGNORE} if ignore_unsupported_errors else {}
@@ -503,19 +519,30 @@ def execute(
503519
# BigQuery's Python DB API implementation does not support retries, so we have to implement them ourselves.
504520
# So we update the cursor's query job and query data with the results of the new query job. This makes sure
505521
# that other cursor based operations execute correctly.
506-
job_config = QueryJobConfig(**self._job_params)
507-
self.cursor._query_job = self._db_call(
522+
session_id = self._session_id
523+
connection_properties = (
524+
[
525+
ConnectionProperty(key="session_id", value=session_id),
526+
]
527+
if session_id
528+
else []
529+
)
530+
531+
job_config = QueryJobConfig(
532+
**self._job_params, connection_properties=connection_properties
533+
)
534+
self._query_job = self._db_call(
508535
self.client.query,
509536
query=sql,
510537
job_config=job_config,
511538
timeout=self._extra_config.get("job_creation_timeout_seconds"),
512539
)
513540
results = self._db_call(
514-
self.cursor._query_job.result,
541+
self._query_job.result,
515542
timeout=self._extra_config.get("job_execution_timeout_seconds"), # type: ignore
516543
)
517-
self.cursor._query_data = iter(results) if results.total_rows else iter([])
518-
query_results = self.cursor._query_job._query_results
544+
self._query_data = iter(results) if results.total_rows else iter([])
545+
query_results = self._query_job._query_results
519546
self.cursor._set_rowcount(query_results)
520547
self.cursor._set_description(query_results.schema)
521548

@@ -541,6 +568,30 @@ def _get_data_objects(
541568
for table in all_tables
542569
]
543570

571+
@property
572+
def _query_data(self) -> t.Any:
573+
return self._connection_pool.get_attribute("query_data")
574+
575+
@_query_data.setter
576+
def _query_data(self, value: t.Any) -> None:
577+
return self._connection_pool.set_attribute("query_data", value)
578+
579+
@property
580+
def _query_job(self) -> t.Any:
581+
return self._connection_pool.get_attribute("query_job")
582+
583+
@_query_job.setter
584+
def _query_job(self, value: t.Any) -> None:
585+
return self._connection_pool.set_attribute("query_job", value)
586+
587+
@property
588+
def _session_id(self) -> t.Any:
589+
return self._connection_pool.get_attribute("session_id")
590+
591+
@_session_id.setter
592+
def _session_id(self, value: t.Any) -> None:
593+
return self._connection_pool.set_attribute("session_id", value)
594+
544595

545596
class _ErrorCounter:
546597
"""

sqlmesh/core/snapshot/evaluator.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -145,7 +145,7 @@ def apply(query_or_df: QueryOrDF, index: int = 0) -> None:
145145
transaction_type=TransactionType.DDL
146146
if model.kind.is_view or model.kind.is_full
147147
else TransactionType.DML
148-
):
148+
), self.adapter.session():
149149
if not limit:
150150
self.adapter.execute(model.render_pre_statements(**render_statements_kwargs))
151151

sqlmesh/utils/connection_pool.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import abc
22
import logging
33
import typing as t
4+
from collections import defaultdict
45
from threading import Lock, get_ident
56

67
logger = logging.getLogger(__name__)
@@ -27,6 +28,26 @@ def get(self) -> t.Any:
2728
A connection instance.
2829
"""
2930

31+
@abc.abstractmethod
def get_attribute(self, key: str) -> t.Optional[t.Any]:
    """Looks up an attribute associated with the connection.

    Args:
        key: Name of the attribute to look up.

    Returns:
        The stored value, or None when no such attribute exists.
    """
41+
42+
@abc.abstractmethod
def set_attribute(self, key: str, value: t.Any) -> None:
    """Stores an attribute associated with the connection.

    Args:
        key: Name of the attribute to store.
        value: Value to associate with the key.
    """
50+
3051
@abc.abstractmethod
3152
def begin(self) -> None:
3253
"""Starts a new transaction."""
@@ -96,6 +117,7 @@ def __init__(self, connection_factory: t.Callable[[], t.Any]):
96117
self._thread_connections: t.Dict[t.Hashable, t.Any] = {}
97118
self._thread_cursors: t.Dict[t.Hashable, t.Any] = {}
98119
self._thread_transactions: t.Set[t.Hashable] = set()
120+
self._thread_attributes: t.Dict[t.Hashable, t.Dict[str, t.Any]] = defaultdict(dict)
99121
self._thread_connections_lock = Lock()
100122
self._thread_cursors_lock = Lock()
101123
self._thread_transactions_lock = Lock()
@@ -114,6 +136,14 @@ def get(self) -> t.Any:
114136
self._thread_connections[thread_id] = self._connection_factory()
115137
return self._thread_connections[thread_id]
116138

139+
def get_attribute(self, key: str) -> t.Optional[t.Any]:
    """Returns an attribute stored for the calling thread.

    Args:
        key: Attribute key.

    Returns:
        Attribute value or None if not found.
    """
    # Use .get() rather than indexing: `_thread_attributes` is a defaultdict, so
    # indexing would insert an empty dict for the thread on every lookup miss.
    attributes = self._thread_attributes.get(get_ident())
    if attributes is None:
        return None
    return attributes.get(key)
142+
143+
def set_attribute(self, key: str, value: t.Any) -> None:
    """Stores an attribute for the calling thread.

    Args:
        key: Attribute key.
        value: Attribute value.
    """
    # Attributes are grouped by the identifier of the calling thread.
    self._thread_attributes[get_ident()][key] = value
146+
117147
def begin(self) -> None:
118148
self._do_begin()
119149
with self._thread_transactions_lock:
@@ -147,6 +177,7 @@ def close(self) -> None:
147177
self._thread_connections.pop(thread_id)
148178
self._thread_cursors.pop(thread_id, None)
149179
self._discard_transaction(thread_id)
180+
self._thread_attributes.pop(thread_id, None)
150181

151182
def close_all(self, exclude_calling_thread: bool = False) -> None:
152183
calling_thread_id = get_ident()
@@ -158,6 +189,7 @@ def close_all(self, exclude_calling_thread: bool = False) -> None:
158189
self._thread_connections.pop(thread_id)
159190
self._thread_cursors.pop(thread_id, None)
160191
self._discard_transaction(thread_id)
192+
self._thread_attributes.pop(thread_id, None)
161193

162194
def _discard_transaction(self, thread_id: t.Hashable) -> None:
163195
with self._thread_transactions_lock:
@@ -169,6 +201,7 @@ def __init__(self, connection_factory: t.Callable[[], t.Any]):
169201
self._connection_factory = connection_factory
170202
self._connection: t.Optional[t.Any] = None
171203
self._cursor: t.Optional[t.Any] = None
204+
self._attributes: t.Dict[str, t.Any] = {}
172205
self._is_transaction_active: bool = False
173206

174207
def get_cursor(self) -> t.Any:
@@ -181,6 +214,12 @@ def get(self) -> t.Any:
181214
self._connection = self._connection_factory()
182215
return self._connection
183216

217+
def get_attribute(self, key: str) -> t.Optional[t.Any]:
    """Returns the attribute stored under the key, or None if not found."""
    attributes = self._attributes
    return attributes.get(key)
219+
220+
def set_attribute(self, key: str, value: t.Any) -> None:
    """Stores the value under the key, replacing any previous value."""
    attributes = self._attributes
    attributes[key] = value
222+
184223
def begin(self) -> None:
185224
self._do_begin()
186225
self._is_transaction_active = True
@@ -206,6 +245,7 @@ def close(self) -> None:
206245
self._connection = None
207246
self._cursor = None
208247
self._is_transaction_active = False
248+
self._attributes.clear()
209249

210250
def close_all(self, exclude_calling_thread: bool = False) -> None:
211251
if not exclude_calling_thread:

tests/core/engine_adapter/test_bigquery.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -364,3 +364,38 @@ def test_merge(mocker: MockerFixture):
364364
"job_config",
365365
"table",
366366
]
367+
368+
369+
def test_begin_end_session(mocker: MockerFixture):
    """Queries carry the session ID only while ``adapter.session()`` is active."""
    # A connection whose cursor points back at it.
    connection = mocker.NonCallableMock()
    cursor = mocker.Mock()
    cursor.connection = connection
    connection.cursor.return_value = cursor

    # Every submitted query yields a job whose result reports no rows.
    empty_result = mocker.Mock()
    empty_result.total_rows = 0
    job = mocker.Mock()
    job.result.return_value = empty_result
    query_mock = connection._client.query
    query_mock.return_value = job

    adapter = BigQueryEngineAdapter(lambda: connection, job_retries=0)
    pool = adapter._connection_pool

    with adapter.session():
        assert pool.get_attribute("session_id") is not None
        adapter.execute("SELECT 2;")

    assert pool.get_attribute("session_id") is None
    adapter.execute("SELECT 3;")

    recorded_calls = query_mock.call_args_list

    # First call: the query that begins the session, passed positionally.
    begin_call = recorded_calls[0]
    assert begin_call[0][0] == "SELECT 1;"

    # Second call: executed inside the session, so it carries the session ID.
    in_session_kwargs = recorded_calls[1][1]
    assert in_session_kwargs["query"] == "SELECT 2;"
    assert len(in_session_kwargs["job_config"].connection_properties) == 1
    session_property = in_session_kwargs["job_config"].connection_properties[0]
    assert session_property.key == "session_id"
    assert session_property.value

    # Third call: executed after the session ended, so no connection properties.
    after_session_kwargs = recorded_calls[2][1]
    assert after_session_kwargs["query"] == "SELECT 3;"
    assert not after_session_kwargs["job_config"].connection_properties

tests/core/test_snapshot_evaluator.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,8 +62,13 @@ def adapter_mock(mocker: MockerFixture):
6262
transaction_mock.__enter__ = mocker.Mock()
6363
transaction_mock.__exit__ = mocker.Mock()
6464

65+
session_mock = mocker.Mock()
66+
session_mock.__enter__ = mocker.Mock()
67+
session_mock.__exit__ = mocker.Mock()
68+
6569
adapter_mock = mocker.Mock()
6670
adapter_mock.transaction.return_value = transaction_mock
71+
adapter_mock.session.return_value = session_mock
6772
adapter_mock.dialect = "duckdb"
6873
return adapter_mock
6974

0 commit comments

Comments
 (0)