2626from google .cloud .spanner_v1 .request_id_header import REQ_RAND_PROCESS_ID
2727
2828from google .cloud .spanner_v1 ._opentelemetry_tracing import trace_call
29+ from tests ._builders import build_database
2930from tests ._helpers import (
3031 OpenTelemetryBase ,
3132 LIB_VERSION ,
@@ -94,38 +95,35 @@ def test_clear_abstract(self):
9495
9596 def test__new_session_wo_labels (self ):
9697 pool = self ._make_one ()
97- database = pool ._database = _make_database ("name" )
98- session = _make_session ()
99- database .session .return_value = session
98+ database = pool ._database = build_database ()
10099
101100 new_session = pool ._new_session ()
102101
103- self .assertIs (new_session , session )
104- database .session .assert_called_once_with (labels = {}, database_role = None )
102+ self .assertEqual (new_session ._database , database )
103+ self .assertEqual (new_session .labels , {})
104+ self .assertIsNone (new_session .database_role )
105105
106106 def test__new_session_w_labels (self ):
107107 labels = {"foo" : "bar" }
108108 pool = self ._make_one (labels = labels )
109- database = pool ._database = _make_database ("name" )
110- session = _make_session ()
111- database .session .return_value = session
109+ database = pool ._database = build_database ()
112110
113111 new_session = pool ._new_session ()
114112
115- self .assertIs (new_session , session )
116- database .session .assert_called_once_with (labels = labels , database_role = None )
113+ self .assertEqual (new_session ._database , database )
114+ self .assertEqual (new_session .labels , labels )
115+ self .assertIsNone (new_session .database_role )
117116
118117 def test__new_session_w_database_role (self ):
119118 database_role = "dummy-role"
120119 pool = self ._make_one (database_role = database_role )
121- database = pool ._database = _make_database ("name" )
122- session = _make_session ()
123- database .session .return_value = session
120+ database = pool ._database = build_database ()
124121
125122 new_session = pool ._new_session ()
126123
127- self .assertIs (new_session , session )
128- database .session .assert_called_once_with (labels = {}, database_role = database_role )
124+ self .assertEqual (new_session ._database , database )
125+ self .assertEqual (new_session .labels , {})
126+ self .assertEqual (new_session .database_role , database_role )
129127
130128 def test_session_wo_kwargs (self ):
131129 from google .cloud .spanner_v1 .pool import SessionCheckout
@@ -215,7 +213,7 @@ def test_get_active(self):
215213 pool = self ._make_one (size = 4 )
216214 database = _Database ("name" )
217215 SESSIONS = sorted ([_Session (database ) for i in range (0 , 4 )])
218- database . _sessions . extend ( SESSIONS )
216+ pool . _new_session = mock . Mock ( side_effect = SESSIONS )
219217 pool .bind (database )
220218
221219 # check if sessions returned in LIFO order
@@ -232,7 +230,7 @@ def test_get_non_expired(self):
232230 SESSIONS = sorted (
233231 [_Session (database , last_use_time = last_use_time ) for i in range (0 , 4 )]
234232 )
235- database . _sessions . extend ( SESSIONS )
233+ pool . _new_session = mock . Mock ( side_effect = SESSIONS )
236234 pool .bind (database )
237235
238236 # check if sessions returned in LIFO order
@@ -339,8 +337,7 @@ def test_spans_pool_bind(self):
339337 # you have an empty pool.
340338 pool = self ._make_one (size = 1 )
341339 database = _Database ("name" )
342- SESSIONS = []
343- database ._sessions .extend (SESSIONS )
340+ pool ._new_session = mock .Mock (side_effect = Exception ("test" ))
344341 fauxSession = mock .Mock ()
345342 setattr (fauxSession , "_database" , database )
346343 try :
@@ -386,8 +383,8 @@ def test_spans_pool_bind(self):
386383 (
387384 "exception" ,
388385 {
389- "exception.type" : "IndexError " ,
390- "exception.message" : "pop from empty list " ,
386+ "exception.type" : "Exception " ,
387+ "exception.message" : "test " ,
391388 "exception.stacktrace" : "EPHEMERAL" ,
392389 "exception.escaped" : "False" ,
393390 },
@@ -397,8 +394,8 @@ def test_spans_pool_bind(self):
397394 (
398395 "exception" ,
399396 {
400- "exception.type" : "IndexError " ,
401- "exception.message" : "pop from empty list " ,
397+ "exception.type" : "Exception " ,
398+ "exception.message" : "test " ,
402399 "exception.stacktrace" : "EPHEMERAL" ,
403400 "exception.escaped" : "False" ,
404401 },
@@ -412,7 +409,7 @@ def test_get_expired(self):
412409 last_use_time = datetime .utcnow () - timedelta (minutes = 65 )
413410 SESSIONS = [_Session (database , last_use_time = last_use_time )] * 5
414411 SESSIONS [0 ]._exists = False
415- database . _sessions . extend ( SESSIONS )
412+ pool . _new_session = mock . Mock ( side_effect = SESSIONS )
416413 pool .bind (database )
417414
418415 session = pool .get ()
@@ -475,7 +472,7 @@ def test_clear(self):
475472 pool = self ._make_one ()
476473 database = _Database ("name" )
477474 SESSIONS = [_Session (database )] * 10
478- database . _sessions . extend ( SESSIONS )
475+ pool . _new_session = mock . Mock ( side_effect = SESSIONS )
479476 pool .bind (database )
480477 self .assertTrue (pool ._sessions .full ())
481478
@@ -539,7 +536,7 @@ def test_ctor_explicit_w_database_role_in_db(self):
539536 def test_get_empty (self ):
540537 pool = self ._make_one ()
541538 database = _Database ("name" )
542- database . _sessions . append ( _Session (database ))
539+ pool . _new_session = mock . Mock ( return_value = _Session (database ))
543540 pool .bind (database )
544541
545542 session = pool .get ()
@@ -559,7 +556,7 @@ def test_spans_get_empty_pool(self):
559556 pool = self ._make_one ()
560557 database = _Database ("name" )
561558 session1 = _Session (database )
562- database . _sessions . append ( session1 )
559+ pool . _new_session = mock . Mock ( return_value = session1 )
563560 pool .bind (database )
564561
565562 with trace_call ("pool.Get" , session1 ):
@@ -630,7 +627,7 @@ def test_get_non_empty_session_expired(self):
630627 database = _Database ("name" )
631628 previous = _Session (database , exists = False )
632629 newborn = _Session (database )
633- database . _sessions . append ( newborn )
630+ pool . _new_session = mock . Mock ( return_value = newborn )
634631 pool .bind (database )
635632 pool .put (previous )
636633
@@ -811,7 +808,7 @@ def test_get_hit_no_ping(self):
811808 pool = self ._make_one (size = 4 )
812809 database = _Database ("name" )
813810 SESSIONS = [_Session (database )] * 4
814- database . _sessions . extend ( SESSIONS )
811+ pool . _new_session = mock . Mock ( side_effect = SESSIONS )
815812 pool .bind (database )
816813 self .reset ()
817814
@@ -830,7 +827,7 @@ def test_get_hit_w_ping(self):
830827 pool = self ._make_one (size = 4 )
831828 database = _Database ("name" )
832829 SESSIONS = [_Session (database )] * 4
833- database . _sessions . extend ( SESSIONS )
830+ pool . _new_session = mock . Mock ( side_effect = SESSIONS )
834831
835832 sessions_created = datetime .datetime .utcnow () - datetime .timedelta (seconds = 4000 )
836833
@@ -855,7 +852,7 @@ def test_get_hit_w_ping_expired(self):
855852 database = _Database ("name" )
856853 SESSIONS = [_Session (database )] * 5
857854 SESSIONS [0 ]._exists = False
858- database . _sessions . extend ( SESSIONS )
855+ pool . _new_session = mock . Mock ( side_effect = SESSIONS )
859856
860857 sessions_created = datetime .datetime .utcnow () - datetime .timedelta (seconds = 4000 )
861858
@@ -974,7 +971,7 @@ def test_clear(self):
974971 pool = self ._make_one ()
975972 database = _Database ("name" )
976973 SESSIONS = [_Session (database )] * 10
977- database . _sessions . extend ( SESSIONS )
974+ pool . _new_session = mock . Mock ( side_effect = SESSIONS )
978975 pool .bind (database )
979976 self .reset ()
980977 self .assertTrue (pool ._sessions .full ())
@@ -1016,7 +1013,7 @@ def test_ping_oldest_stale_but_exists(self):
10161013 pool = self ._make_one (size = 1 )
10171014 database = _Database ("name" )
10181015 SESSIONS = [_Session (database )] * 1
1019- database . _sessions . extend ( SESSIONS )
1016+ pool . _new_session = mock . Mock ( side_effect = SESSIONS )
10201017 pool .bind (database )
10211018
10221019 later = datetime .datetime .utcnow () + datetime .timedelta (seconds = 4000 )
@@ -1034,7 +1031,7 @@ def test_ping_oldest_stale_and_not_exists(self):
10341031 database = _Database ("name" )
10351032 SESSIONS = [_Session (database )] * 2
10361033 SESSIONS [0 ]._exists = False
1037- database . _sessions . extend ( SESSIONS )
1034+ pool . _new_session = mock . Mock ( side_effect = SESSIONS )
10381035 pool .bind (database )
10391036 self .reset ()
10401037
@@ -1055,7 +1052,7 @@ def test_spans_get_and_leave_empty_pool(self):
10551052 pool = self ._make_one ()
10561053 database = _Database ("name" )
10571054 session1 = _Session (database )
1058- database . _sessions . append ( session1 )
1055+ pool . _new_session = mock . Mock ( side_effect = [ session1 , Exception ] )
10591056 try :
10601057 pool .bind (database )
10611058 except Exception :
0 commit comments