@@ -87,6 +87,21 @@ def _job_params(self) -> t.Dict[str, t.Any]:
8787 params ["maximum_bytes_billed" ] = self ._extra_config .get ("maximum_bytes_billed" )
8888 return params
8989
90+ def _begin_session (self ) -> None :
91+ from google .cloud .bigquery import QueryJobConfig
92+
93+ job = self .client .query ("SELECT 1;" , job_config = QueryJobConfig (create_session = True ))
94+ session_info = job .session_info
95+ session_id = session_info .session_id if session_info else None
96+ self ._session_id = session_id
97+ job .result ()
98+
99+ def _end_session (self ) -> None :
100+ self ._session_id = None
101+
102+ def _is_session_active (self ) -> bool :
103+ return self ._session_id is not None
104+
90105 def create_schema (self , schema_name : str , ignore_if_exists : bool = True ) -> None :
91106 """Create a schema from a name or qualified table name."""
92107 from google .api_core .exceptions import Conflict
@@ -153,7 +168,7 @@ def fetchone(
153168 quote_identifiers = quote_identifiers ,
154169 )
155170 try :
156- return next (self .cursor . _query_data )
171+ return next (self ._query_data )
157172 except StopIteration :
158173 return ()
159174
@@ -172,7 +187,7 @@ def fetchall(
172187 ignore_unsupported_errors = ignore_unsupported_errors ,
173188 quote_identifiers = quote_identifiers ,
174189 )
175- return list (self .cursor . _query_data )
190+ return list (self ._query_data )
176191
177192 def _create_table_from_df (
178193 self ,
@@ -410,7 +425,7 @@ def _fetch_native_df(
410425 self , query : t .Union [exp .Expression , str ], quote_identifiers : bool = False
411426 ) -> DF :
412427 self .execute (query , quote_identifiers = quote_identifiers )
413- return self .cursor . _query_job .to_dataframe ()
428+ return self ._query_job .to_dataframe ()
414429
415430 def _create_table_properties (
416431 self ,
@@ -487,6 +502,7 @@ def execute(
487502 ) -> None :
488503 """Execute a sql query."""
489504 from google .cloud .bigquery import QueryJobConfig
505+ from google .cloud .bigquery .query import ConnectionProperty
490506
491507 to_sql_kwargs = (
492508 {"unsupported_level" : ErrorLevel .IGNORE } if ignore_unsupported_errors else {}
@@ -503,19 +519,30 @@ def execute(
503519 # BigQuery's Python DB API implementation does not support retries, so we have to implement them ourselves.
504520 # So we update the cursor's query job and query data with the results of the new query job. This makes sure
505521 # that other cursor based operations execute correctly.
506- job_config = QueryJobConfig (** self ._job_params )
507- self .cursor ._query_job = self ._db_call (
522+ session_id = self ._session_id
523+ connection_properties = (
524+ [
525+ ConnectionProperty (key = "session_id" , value = session_id ),
526+ ]
527+ if session_id
528+ else []
529+ )
530+
531+ job_config = QueryJobConfig (
532+ ** self ._job_params , connection_properties = connection_properties
533+ )
534+ self ._query_job = self ._db_call (
508535 self .client .query ,
509536 query = sql ,
510537 job_config = job_config ,
511538 timeout = self ._extra_config .get ("job_creation_timeout_seconds" ),
512539 )
513540 results = self ._db_call (
514- self .cursor . _query_job .result ,
541+ self ._query_job .result ,
515542 timeout = self ._extra_config .get ("job_execution_timeout_seconds" ), # type: ignore
516543 )
517- self .cursor . _query_data = iter (results ) if results .total_rows else iter ([])
518- query_results = self .cursor . _query_job ._query_results
544+ self ._query_data = iter (results ) if results .total_rows else iter ([])
545+ query_results = self ._query_job ._query_results
519546 self .cursor ._set_rowcount (query_results )
520547 self .cursor ._set_description (query_results .schema )
521548
@@ -541,6 +568,30 @@ def _get_data_objects(
541568 for table in all_tables
542569 ]
543570
571+ @property
572+ def _query_data (self ) -> t .Any :
573+ return self ._connection_pool .get_attribute ("query_data" )
574+
575+ @_query_data .setter
576+ def _query_data (self , value : t .Any ) -> None :
577+ return self ._connection_pool .set_attribute ("query_data" , value )
578+
579+ @property
580+ def _query_job (self ) -> t .Any :
581+ return self ._connection_pool .get_attribute ("query_job" )
582+
583+ @_query_job .setter
584+ def _query_job (self , value : t .Any ) -> None :
585+ return self ._connection_pool .set_attribute ("query_job" , value )
586+
587+ @property
588+ def _session_id (self ) -> t .Any :
589+ return self ._connection_pool .get_attribute ("session_id" )
590+
591+ @_session_id .setter
592+ def _session_id (self , value : t .Any ) -> None :
593+ return self ._connection_pool .set_attribute ("session_id" , value )
594+
544595
545596class _ErrorCounter :
546597 """
0 commit comments