4040from google .cloud .spanner_admin_database_v1 import RestoreDatabaseRequest
4141from google .cloud .spanner_admin_database_v1 import UpdateDatabaseDdlRequest
4242from google .cloud .spanner_admin_database_v1 .types import DatabaseDialect
43+ from google .cloud .spanner_v1 .database_sessions_manager import DatabaseSessionsManager
4344from google .cloud .spanner_v1 .transaction import BatchTransactionId
4445from google .cloud .spanner_v1 import ExecuteSqlRequest
4546from google .cloud .spanner_v1 import Type
5960from google .cloud .spanner_v1 .keyset import KeySet
6061from google .cloud .spanner_v1 .merged_result_set import MergedResultSet
6162from google .cloud .spanner_v1 .pool import BurstyPool
62- from google .cloud .spanner_v1 .pool import SessionCheckout
6363from google .cloud .spanner_v1 .session import Session
6464from google .cloud .spanner_v1 .snapshot import _restart_on_unavailable
6565from google .cloud .spanner_v1 .snapshot import Snapshot
7070from google .cloud .spanner_v1 .table import Table
7171from google .cloud .spanner_v1 ._opentelemetry_tracing import (
7272 add_span_event ,
73- get_current_span ,
7473 trace_call ,
7574)
7675from google .cloud .spanner_v1 .metrics .metrics_capture import MetricsCapture
@@ -191,10 +190,10 @@ def __init__(
191190
192191 if pool is None :
193192 pool = BurstyPool (database_role = database_role )
194-
195- self ._pool = pool
196193 pool .bind (self )
197194
195+ self ._session_manager = DatabaseSessionsManager (database = self , pool = pool )
196+
198197 @classmethod
199198 def from_pb (cls , database_pb , instance , pool = None ):
200199 """Creates an instance of this class from a protobuf.
@@ -708,7 +707,7 @@ def execute_pdml():
708707 "CloudSpanner.Database.execute_partitioned_pdml" ,
709708 observability_options = self .observability_options ,
710709 ) as span , MetricsCapture ():
711- with SessionCheckout (self . _pool ) as session :
710+ with SessionCheckout (self ) as session :
712711 add_span_event (span , "Starting BeginTransaction" )
713712 txn = api .begin_transaction (
714713 session = session .name , options = txn_options , metadata = metadata
@@ -923,7 +922,7 @@ def run_in_transaction(self, func, *args, **kw):
923922 # Check out a session and run the function in a transaction; once
924923 # done, flip the sanity check bit back.
925924 try :
926- with SessionCheckout (self . _pool ) as session :
925+ with SessionCheckout (self ) as session :
927926 return session .run_in_transaction (func , * args , ** kw )
928927 finally :
929928 self ._local .transaction_running = False
@@ -1160,6 +1159,35 @@ def observability_options(self):
11601159 return opts
11611160
11621161
1162+ class SessionCheckout (object ):
1163+ """Context manager for using a session from a database.
1164+
1165+ :type database: :class:`~google.cloud.spanner_v1.database.Database`
1166+ :param database: database to use the session from
1167+ """
1168+
1169+ _session = None # Not checked out until '__enter__'.
1170+
1171+ def __init__ (self , database ):
1172+ if not isinstance (database , Database ):
1173+ raise TypeError (
1174+ "{class_name} must receive an instance of {expected_class_name}. Received: {actual_class_name}" .format (
1175+ class_name = self .__class__ .__name__ ,
1176+ expected_class_name = Database .__name__ ,
1177+ actual_class_name = database .__class__ .__name__ ,
1178+ )
1179+ )
1180+
1181+ self ._database = database
1182+
1183+ def __enter__ (self ):
1184+ self ._session = self ._database ._session_manager .get_session_for_read_write ()
1185+ return self ._session
1186+
1187+ def __exit__ (self , * ignored ):
1188+ self ._database ._session_manager .put_session (self ._session )
1189+
1190+
11631191class BatchCheckout (object ):
11641192 """Context manager for using a batch from a database.
11651193
@@ -1194,6 +1222,15 @@ def __init__(
11941222 isolation_level = TransactionOptions .IsolationLevel .ISOLATION_LEVEL_UNSPECIFIED ,
11951223 ** kw ,
11961224 ):
1225+ if not isinstance (database , Database ):
1226+ raise TypeError (
1227+ "{class_name} must receive an instance of {expected_class_name}. Received: {actual_class_name}" .format (
1228+ class_name = self .__class__ .__name__ ,
1229+ expected_class_name = Database .__name__ ,
1230+ actual_class_name = database .__class__ .__name__ ,
1231+ )
1232+ )
1233+
11971234 self ._database = database
11981235 self ._session = self ._batch = None
11991236 if request_options is None :
@@ -1209,10 +1246,8 @@ def __init__(
12091246
12101247 def __enter__ (self ):
12111248 """Begin ``with`` block."""
1212- current_span = get_current_span ()
1213- session = self ._session = self ._database ._pool .get ()
1214- add_span_event (current_span , "Using session" , {"id" : session .session_id })
1215- batch = self ._batch = Batch (session )
1249+ self ._session = self ._database ._session_manager .get_session_for_read_only ()
1250+ batch = self ._batch = Batch (self ._session )
12161251 if self ._request_options .transaction_tag :
12171252 batch .transaction_tag = self ._request_options .transaction_tag
12181253 return batch
@@ -1235,13 +1270,7 @@ def __exit__(self, exc_type, exc_val, exc_tb):
12351270 "CommitStats: {}" .format (self ._batch .commit_stats ),
12361271 extra = {"commit_stats" : self ._batch .commit_stats },
12371272 )
1238- self ._database ._pool .put (self ._session )
1239- current_span = get_current_span ()
1240- add_span_event (
1241- current_span ,
1242- "Returned session to pool" ,
1243- {"id" : self ._session .session_id },
1244- )
1273+ self ._database ._session_manager .put_session (self ._session )
12451274
12461275
12471276class MutationGroupsCheckout (object ):
@@ -1258,23 +1287,26 @@ class MutationGroupsCheckout(object):
12581287 """
12591288
12601289 def __init__ (self , database ):
1290+ if not isinstance (database , Database ):
1291+ raise TypeError (
1292+ "{class_name} must receive an instance of {expected_class_name}. Received: {actual_class_name}" .format (
1293+ class_name = self .__class__ .__name__ ,
1294+ expected_class_name = Database .__name__ ,
1295+ actual_class_name = database .__class__ .__name__ ,
1296+ )
1297+ )
1298+
12611299 self ._database = database
12621300 self ._session = None
12631301
12641302 def __enter__ (self ):
12651303 """Begin ``with`` block."""
1266- session = self ._session = self ._database ._pool . get ()
1267- return MutationGroups (session )
1304+ self ._session = self ._database ._session_manager . get_session_for_read_write ()
1305+ return MutationGroups (self . _session )
12681306
12691307 def __exit__ (self , exc_type , exc_val , exc_tb ):
12701308 """End ``with`` block."""
1271- if isinstance (exc_val , NotFound ):
1272- # If NotFound exception occurs inside the with block
1273- # then we validate if the session still exists.
1274- if not self ._session .exists ():
1275- self ._session = self ._database ._pool ._new_session ()
1276- self ._session .create ()
1277- self ._database ._pool .put (self ._session )
1309+ self ._database ._session_manager .put_session (self ._session )
12781310
12791311
12801312class SnapshotCheckout (object ):
@@ -1296,24 +1328,27 @@ class SnapshotCheckout(object):
12961328 """
12971329
12981330 def __init__ (self , database , ** kw ):
1331+ if not isinstance (database , Database ):
1332+ raise TypeError (
1333+ "{class_name} must receive an instance of {expected_class_name}. Received: {actual_class_name}" .format (
1334+ class_name = self .__class__ .__name__ ,
1335+ expected_class_name = Database .__name__ ,
1336+ actual_class_name = database .__class__ .__name__ ,
1337+ )
1338+ )
1339+
12991340 self ._database = database
13001341 self ._session = None
13011342 self ._kw = kw
13021343
13031344 def __enter__ (self ):
13041345 """Begin ``with`` block."""
1305- session = self ._session = self ._database ._pool . get ()
1306- return Snapshot (session , ** self ._kw )
1346+ self ._session = self ._database ._session_manager . get_session_for_read_only ()
1347+ return Snapshot (self . _session , ** self ._kw )
13071348
13081349 def __exit__ (self , exc_type , exc_val , exc_tb ):
13091350 """End ``with`` block."""
1310- if isinstance (exc_val , NotFound ):
1311- # If NotFound exception occurs inside the with block
1312- # then we validate if the session still exists.
1313- if not self ._session .exists ():
1314- self ._session = self ._database ._pool ._new_session ()
1315- self ._session .create ()
1316- self ._database ._pool .put (self ._session )
1351+ self ._database ._session_manager .put_session (self ._session )
13171352
13181353
13191354class BatchSnapshot (object ):
0 commit comments