4040from google .cloud .spanner_admin_database_v1 import RestoreDatabaseRequest
4141from google .cloud .spanner_admin_database_v1 import UpdateDatabaseDdlRequest
4242from google .cloud .spanner_admin_database_v1 .types import DatabaseDialect
43+ from google .cloud .spanner_v1 .database_sessions_manager import DatabaseSessionsManager
4344from google .cloud .spanner_v1 .transaction import BatchTransactionId
4445from google .cloud .spanner_v1 import ExecuteSqlRequest
4546from google .cloud .spanner_v1 import Type
5960from google .cloud .spanner_v1 .keyset import KeySet
6061from google .cloud .spanner_v1 .merged_result_set import MergedResultSet
6162from google .cloud .spanner_v1 .pool import BurstyPool
62- from google .cloud .spanner_v1 .pool import SessionCheckout
6363from google .cloud .spanner_v1 .session import Session
6464from google .cloud .spanner_v1 .snapshot import _restart_on_unavailable
6565from google .cloud .spanner_v1 .snapshot import Snapshot
7070from google .cloud .spanner_v1 .table import Table
7171from google .cloud .spanner_v1 ._opentelemetry_tracing import (
7272 add_span_event ,
73- get_current_span ,
7473 trace_call ,
7574)
7675from google .cloud .spanner_v1 .metrics .metrics_capture import MetricsCapture
@@ -191,10 +190,10 @@ def __init__(
191190
192191 if pool is None :
193192 pool = BurstyPool (database_role = database_role )
194-
195- self ._pool = pool
196193 pool .bind (self )
197194
195+ self ._session_manager = DatabaseSessionsManager (database = self , pool = pool )
196+
198197 @classmethod
199198 def from_pb (cls , database_pb , instance , pool = None ):
200199 """Creates an instance of this class from a protobuf.
@@ -708,7 +707,7 @@ def execute_pdml():
708707 "CloudSpanner.Database.execute_partitioned_pdml" ,
709708 observability_options = self .observability_options ,
710709 ) as span , MetricsCapture ():
711- with SessionCheckout (self . _pool ) as session :
710+ with SessionCheckout (self ) as session :
712711 add_span_event (span , "Starting BeginTransaction" )
713712 txn = api .begin_transaction (
714713 session = session .name , options = txn_options , metadata = metadata
@@ -759,6 +758,7 @@ def session(self, labels=None, database_role=None):
759758 """
760759 # If role is specified in param, then that role is used
761760 # instead.
761+
762762 role = database_role or self ._database_role
763763 return Session (self , labels = labels , database_role = role )
764764
@@ -923,7 +923,7 @@ def run_in_transaction(self, func, *args, **kw):
923923 # Check out a session and run the function in a transaction; once
924924 # done, flip the sanity check bit back.
925925 try :
926- with SessionCheckout (self . _pool ) as session :
926+ with SessionCheckout (self ) as session :
927927 return session .run_in_transaction (func , * args , ** kw )
928928 finally :
929929 self ._local .transaction_running = False
@@ -1160,6 +1160,29 @@ def observability_options(self):
11601160 return opts
11611161
11621162
1163+ class SessionCheckout (object ):
1164+ """Context manager for using a session from a database.
1165+
1166+ :type database: :class:`~google.cloud.spanner_v1.database.Database`
1167+ :param database: database to use the session from
1168+ """
1169+
1170+ _session = None # Not checked out until '__enter__'.
1171+
1172+ def __init__ (self , database ):
1173+ if not isinstance (database , Database ):
1174+ raise TypeError (f"database must be an instance of { Database .__name__ } " )
1175+
1176+ self ._database = database
1177+
1178+ def __enter__ (self ):
1179+ self ._session = self ._database ._session_manager .get_session_for_read_write ()
1180+ return self ._session
1181+
1182+ def __exit__ (self , * ignored ):
1183+ self ._database ._session_manager .put_session (self ._session )
1184+
1185+
11631186class BatchCheckout (object ):
11641187 """Context manager for using a batch from a database.
11651188
@@ -1194,6 +1217,9 @@ def __init__(
11941217 isolation_level = TransactionOptions .IsolationLevel .ISOLATION_LEVEL_UNSPECIFIED ,
11951218 ** kw ,
11961219 ):
1220+ if not isinstance (database , Database ):
1221+ raise TypeError (f"database must be an instance of { Database .__name__ } " )
1222+
11971223 self ._database = database
11981224 self ._session = self ._batch = None
11991225 if request_options is None :
@@ -1209,9 +1235,9 @@ def __init__(
12091235
12101236 def __enter__ (self ):
12111237 """Begin ``with`` block."""
1212- current_span = get_current_span ()
1213- session = self . _session = self ._database . _pool . get ()
1214- add_span_event ( current_span , "Using session" , { "id" : session . session_id } )
1238+ session = (
1239+ self ._session
1240+ ) = self . _database . _session_manager . get_session_for_read_only ( )
12151241 batch = self ._batch = Batch (session )
12161242 if self ._request_options .transaction_tag :
12171243 batch .transaction_tag = self ._request_options .transaction_tag
@@ -1235,13 +1261,7 @@ def __exit__(self, exc_type, exc_val, exc_tb):
12351261 "CommitStats: {}" .format (self ._batch .commit_stats ),
12361262 extra = {"commit_stats" : self ._batch .commit_stats },
12371263 )
1238- self ._database ._pool .put (self ._session )
1239- current_span = get_current_span ()
1240- add_span_event (
1241- current_span ,
1242- "Returned session to pool" ,
1243- {"id" : self ._session .session_id },
1244- )
1264+ self ._database ._session_manager .put_session (self ._session )
12451265
12461266
12471267class MutationGroupsCheckout (object ):
@@ -1258,23 +1278,22 @@ class MutationGroupsCheckout(object):
12581278 """
12591279
12601280 def __init__ (self , database ):
1281+ if not isinstance (database , Database ):
1282+ raise TypeError (f"database must be an instance of { Database .__name__ } " )
1283+
12611284 self ._database = database
12621285 self ._session = None
12631286
12641287 def __enter__ (self ):
12651288 """Begin ``with`` block."""
1266- session = self ._session = self ._database ._pool .get ()
1289+ session = (
1290+ self ._session
1291+ ) = self ._database ._session_manager .get_session_for_read_write ()
12671292 return MutationGroups (session )
12681293
12691294 def __exit__ (self , exc_type , exc_val , exc_tb ):
12701295 """End ``with`` block."""
1271- if isinstance (exc_val , NotFound ):
1272- # If NotFound exception occurs inside the with block
1273- # then we validate if the session still exists.
1274- if not self ._session .exists ():
1275- self ._session = self ._database ._pool ._new_session ()
1276- self ._session .create ()
1277- self ._database ._pool .put (self ._session )
1296+ self ._database ._session_manager .put_session (self ._session )
12781297
12791298
12801299class SnapshotCheckout (object ):
@@ -1296,24 +1315,23 @@ class SnapshotCheckout(object):
12961315 """
12971316
12981317 def __init__ (self , database , ** kw ):
1318+ if not isinstance (database , Database ):
1319+ raise TypeError (f"database must be an instance of { Database .__name__ } " )
1320+
12991321 self ._database = database
13001322 self ._session = None
13011323 self ._kw = kw
13021324
13031325 def __enter__ (self ):
13041326 """Begin ``with`` block."""
1305- session = self ._session = self ._database ._pool .get ()
1327+ session = (
1328+ self ._session
1329+ ) = self ._database ._session_manager .get_session_for_read_only ()
13061330 return Snapshot (session , ** self ._kw )
13071331
13081332 def __exit__ (self , exc_type , exc_val , exc_tb ):
13091333 """End ``with`` block."""
1310- if isinstance (exc_val , NotFound ):
1311- # If NotFound exception occurs inside the with block
1312- # then we validate if the session still exists.
1313- if not self ._session .exists ():
1314- self ._session = self ._database ._pool ._new_session ()
1315- self ._session .create ()
1316- self ._database ._pool .put (self ._session )
1334+ self ._database ._session_manager .put_session (self ._session )
13171335
13181336
13191337class BatchSnapshot (object ):
@@ -1392,6 +1410,7 @@ def _get_session(self):
13921410 Caller is responsible for cleaning up the session after
13931411 all partitions have been processed.
13941412 """
1413+
13951414 if self ._session is None :
13961415 session = self ._session = self ._database .session ()
13971416 if self ._session_id is None :
0 commit comments