4141from google .cloud .spanner_admin_database_v1 import UpdateDatabaseDdlRequest
4242from google .cloud .spanner_admin_database_v1 .types import DatabaseDialect
4343from google .cloud .spanner_v1 .database_sessions_manager import DatabaseSessionsManager
44+ from google .cloud .spanner_v1 .session import Session
45+ from google .cloud .spanner_v1 .session_options import SessionOptions , TransactionType
4446from google .cloud .spanner_v1 .transaction import BatchTransactionId
4547from google .cloud .spanner_v1 import ExecuteSqlRequest
4648from google .cloud .spanner_v1 import Type
6062from google .cloud .spanner_v1 .keyset import KeySet
6163from google .cloud .spanner_v1 .merged_result_set import MergedResultSet
6264from google .cloud .spanner_v1 .pool import BurstyPool
63- from google .cloud .spanner_v1 .session import Session
6465from google .cloud .spanner_v1 .snapshot import _restart_on_unavailable
6566from google .cloud .spanner_v1 .snapshot import Snapshot
6667from google .cloud .spanner_v1 .streamed import StreamedResultSet
7475)
7576from google .cloud .spanner_v1 .metrics .metrics_capture import MetricsCapture
7677
77-
7878SPANNER_DATA_SCOPE = "https://www.googleapis.com/auth/spanner.data"
7979
8080
@@ -192,9 +192,7 @@ def __init__(
192192 pool = BurstyPool (database_role = database_role )
193193 pool .bind (self )
194194
195- self ._session_manager = DatabaseSessionsManager (
196- database = self , pool = pool , logger = self .logger
197- )
195+ self ._session_manager = DatabaseSessionsManager (database = self , pool = pool )
198196
199197 @classmethod
200198 def from_pb (cls , database_pb , instance , pool = None ):
@@ -449,6 +447,15 @@ def spanner_api(self):
449447 )
450448 return self ._spanner_api
451449
450+ @property
451+ def session_options (self ) -> SessionOptions :
452+ """Session options for the database.
453+
454+ :rtype: :class:`~google.cloud.spanner_v1.session_options.SessionOptions`
455+ :returns: the session options
456+ """
457+ return self ._instance ._client .session_options
458+
452459 def __eq__ (self , other ):
453460 if not isinstance (other , self .__class__ ):
454461 return NotImplemented
@@ -709,11 +716,27 @@ def execute_pdml():
709716 "CloudSpanner.Database.execute_partitioned_pdml" ,
710717 observability_options = self .observability_options ,
711718 ) as span , MetricsCapture ():
712- with SessionCheckout (self ) as session :
719+ transaction_type = TransactionType .PARTITIONED
720+ with SessionCheckout (self , transaction_type ) as session :
713721 add_span_event (span , "Starting BeginTransaction" )
714- txn = api .begin_transaction (
715- session = session .name , options = txn_options , metadata = metadata
716- )
722+
723+ try :
724+ txn = api .begin_transaction (
725+ session = session .name , options = txn_options , metadata = metadata
726+ )
727+
728+ # If partitioned DML is not supported with multiplexed sessions,
729+ # disable multiplexed sessions for partitioned transactions before
730+ # re-raising the error.
731+ except NotImplementedError as exc :
732+ if (
733+ "Transaction type partitioned_dml not supported with multiplexed sessions"
734+ in str (exc )
735+ ):
736+ self .session_options .disable_multiplexed (
737+ self .logger , transaction_type
738+ )
739+ raise exc
717740
718741 txn_selector = TransactionSelector (id = txn .id )
719742
@@ -732,8 +755,9 @@ def execute_pdml():
732755
733756 iterator = _restart_on_unavailable (
734757 method = method ,
735- trace_name = "CloudSpanner.ExecuteStreamingSql" ,
736758 request = request ,
759+ session = session ,
760+ trace_name = "CloudSpanner.ExecuteStreamingSql" ,
737761 metadata = metadata ,
738762 transaction_selector = txn_selector ,
739763 observability_options = self .observability_options ,
@@ -746,23 +770,6 @@ def execute_pdml():
746770
747771 return _retry_on_aborted (execute_pdml , DEFAULT_RETRY_BACKOFF )()
748772
749- def session (self , labels = None , database_role = None ):
750- """Factory to create a session for this database.
751-
752- :type labels: dict (str -> str) or None
753- :param labels: (Optional) user-assigned labels for the session.
754-
755- :type database_role: str
756- :param database_role: (Optional) user-assigned database_role for the session.
757-
758- :rtype: :class:`~google.cloud.spanner_v1.session.Session`
759- :returns: a session bound to this database.
760- """
761- # If role is specified in param, then that role is used
762- # instead.
763- role = database_role or self ._database_role
764- return Session (self , labels = labels , database_role = role )
765-
766773 def snapshot (self , ** kw ):
767774 """Return an object which wraps a snapshot.
768775
@@ -1170,7 +1177,11 @@ class SessionCheckout(object):
11701177
11711178 _session = None # Not checked out until '__enter__'.
11721179
1173- def __init__ (self , database ):
1180+ def __init__ (
1181+ self ,
1182+ database : Database ,
1183+ transaction_type : TransactionType = TransactionType .READ_WRITE ,
1184+ ):
11741185 if not isinstance (database , Database ):
11751186 raise TypeError (
11761187 "{class_name} must receive an instance of {expected_class_name}. Received: {actual_class_name}" .format (
@@ -1180,10 +1191,21 @@ def __init__(self, database):
11801191 )
11811192 )
11821193
1194+ if not isinstance (transaction_type , TransactionType ):
1195+ raise TypeError (
1196+ "{class_name} must receive an instance of {expected_class_name}. Received: {actual_class_name}" .format (
1197+ class_name = self .__class__ .__name__ ,
1198+ expected_class_name = TransactionType .__name__ ,
1199+ actual_class_name = transaction_type .__class__ .__name__ ,
1200+ )
1201+ )
1202+
11831203 self ._database = database
1204+ self ._transaction_type = transaction_type
11841205
11851206 def __enter__ (self ):
1186- self ._session = self ._database ._session_manager .get_session_for_read_write ()
1207+ session_manager = self ._database ._session_manager
1208+ self ._session = session_manager .get_session (self ._transaction_type )
11871209 return self ._session
11881210
11891211 def __exit__ (self , * ignored ):
@@ -1248,7 +1270,13 @@ def __init__(
12481270
12491271 def __enter__ (self ):
12501272 """Begin ``with`` block."""
1251- self ._session = self ._database ._session_manager .get_session_for_read_only ()
1273+
1274+ # Batch transactions are performed as blind writes,
1275+ # which are treated as read-only transactions.
1276+ self ._session = self ._database ._session_manager .get_session (
1277+ TransactionType .READ_ONLY
1278+ )
1279+
12521280 batch = self ._batch = Batch (self ._session )
12531281 if self ._request_options .transaction_tag :
12541282 batch .transaction_tag = self ._request_options .transaction_tag
@@ -1303,7 +1331,9 @@ def __init__(self, database):
13031331
13041332 def __enter__ (self ):
13051333 """Begin ``with`` block."""
1306- self ._session = self ._database ._session_manager .get_session_for_read_write ()
1334+ self ._session = self ._database ._session_manager .get_session (
1335+ TransactionType .READ_WRITE
1336+ )
13071337 return MutationGroups (self ._session )
13081338
13091339 def __exit__ (self , exc_type , exc_val , exc_tb ):
@@ -1345,7 +1375,9 @@ def __init__(self, database, **kw):
13451375
13461376 def __enter__ (self ):
13471377 """Begin ``with`` block."""
1348- self ._session = self ._database ._session_manager .get_session_for_read_only ()
1378+ self ._session = self ._database ._session_manager .get_session (
1379+ TransactionType .READ_ONLY
1380+ )
13491381 return Snapshot (self ._session , ** self ._kw )
13501382
13511383 def __exit__ (self , exc_type , exc_val , exc_tb ):
@@ -1395,11 +1427,15 @@ def from_dict(cls, database, mapping):
13951427
13961428 :rtype: :class:`BatchSnapshot`
13971429 """
1430+
13981431 instance = cls (database )
1399- session = instance ._session = database .session ()
1400- session ._session_id = mapping ["session_id" ]
1432+
1433+ session = instance ._session = Session (database = database )
1434+ instance ._session_id = session ._session_id = mapping ["session_id" ]
1435+
14011436 snapshot = instance ._snapshot = session .snapshot ()
1402- snapshot ._transaction_id = mapping ["transaction_id" ]
1437+ instance ._transaction_id = snapshot ._transaction_id = mapping ["transaction_id" ]
1438+
14031439 return instance
14041440
14051441 def to_dict (self ):
@@ -1408,10 +1444,15 @@ def to_dict(self):
14081444 Result can be used to serialize the instance and reconstitute
14091445 it later using :meth:`from_dict`.
14101446
1447+ When called, the underlying session is cleaned up, so
1448+ the batch snapshot is no longer valid.
1449+
14111450 :rtype: dict
14121451 """
1452+
14131453 session = self ._get_session ()
14141454 snapshot = self ._get_snapshot ()
1455+
14151456 return {
14161457 "session_id" : session ._session_id ,
14171458 "transaction_id" : snapshot ._transaction_id ,
@@ -1429,25 +1470,48 @@ def _get_session(self):
14291470 Caller is responsible for cleaning up the session after
14301471 all partitions have been processed.
14311472 """
1473+
14321474 if self ._session is None :
1433- session = self ._session = self ._database .session ()
1475+ database = self ._database
1476+
1477+ # If the session ID is not specified, check out a new session from
1478+ # the database session manager; otherwise, the session has already
1479+ # been checked out, so just create a session object to represent it.
14341480 if self ._session_id is None :
1435- session .create ()
1481+ transaction_type = TransactionType .READ_ONLY
1482+ session = database ._session_manager .get_session (transaction_type )
1483+ self ._session_id = session .session_id
1484+
14361485 else :
1486+ session = Session (database = database )
14371487 session ._session_id = self ._session_id
1488+
1489+ self ._session = session
1490+
14381491 return self ._session
14391492
14401493 def _get_snapshot (self ):
14411494 """Create snapshot if needed."""
1495+
14421496 if self ._snapshot is None :
1443- self ._snapshot = self ._get_session ().snapshot (
1444- read_timestamp = self ._read_timestamp ,
1445- exact_staleness = self ._exact_staleness ,
1446- multi_use = True ,
1447- transaction_id = self ._transaction_id ,
1448- )
1497+ snapshot_args = {
1498+ "session" : self ._get_session (),
1499+ "read_timestamp" : self ._read_timestamp ,
1500+ "exact_staleness" : self ._exact_staleness ,
1501+ "multi_use" : True ,
1502+ }
1503+
1504+ # If the transaction ID is not specified, create a new snapshot
1505+ # and begin a transaction; otherwise, the transaction is already
1506+ # in progress, so just create a snapshot object to represent it.
14491507 if self ._transaction_id is None :
1450- self ._snapshot .begin ()
1508+ self ._snapshot = Snapshot (** snapshot_args )
1509+ self ._transaction_id = self ._snapshot .begin ()
1510+
1511+ else :
1512+ snapshot_args ["transaction_id" ] = self ._transaction_id
1513+ self ._snapshot = Snapshot (** snapshot_args )
1514+
14511515 return self ._snapshot
14521516
14531517 def get_batch_transaction_id (self ):
@@ -1844,7 +1908,7 @@ def close(self):
18441908 from all the partitions.
18451909 """
18461910 if self ._session is not None :
1847- self ._session . delete ( )
1911+ self ._database . _session_manager . put_session ( self . _session )
18481912
18491913
18501914def _check_ddl_statements (value ):
0 commit comments