Skip to content
This repository was archived by the owner on Mar 31, 2026. It is now read-only.

Commit a7ff9f9

Browse files
committed
feat: Multiplexed sessions - Update pools so they don't use deprecated database.session()
Signed-off-by: Taylor Curran <taylor.curran@improving.com>
1 parent 37553c9 commit a7ff9f9

File tree

2 files changed

+40
-42
lines changed

2 files changed

+40
-42
lines changed

google/cloud/spanner_v1/pool.py

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,8 @@
2020

2121
from google.cloud.exceptions import NotFound
2222
from google.cloud.spanner_v1 import BatchCreateSessionsRequest
23-
from google.cloud.spanner_v1 import Session
23+
from google.cloud.spanner_v1 import Session as SessionProto
24+
from google.cloud.spanner_v1.session import Session
2425
from google.cloud.spanner_v1._helpers import (
2526
_metadata_with_prefix,
2627
_metadata_with_leader_aware_routing,
@@ -130,9 +131,9 @@ def _new_session(self):
130131
:rtype: :class:`~google.cloud.spanner_v1.session.Session`
131132
:returns: new session instance.
132133
"""
133-
return self._database.session(
134-
labels=self.labels, database_role=self.database_role
135-
)
134+
135+
role = self.database_role or self._database.database_role
136+
return Session(database=self._database, labels=self.labels, database_role=role)
136137

137138
def session(self, **kwargs):
138139
"""Check out a session from the pool.
@@ -240,7 +241,7 @@ def bind(self, database):
240241
request = BatchCreateSessionsRequest(
241242
database=database.name,
242243
session_count=requested_session_count,
243-
session_template=Session(creator_role=self.database_role),
244+
session_template=SessionProto(creator_role=self.database_role),
244245
)
245246

246247
observability_options = getattr(self._database, "observability_options", None)
@@ -322,7 +323,7 @@ def get(self, timeout=None):
322323
"Session is not valid, recreating it",
323324
span_event_attributes,
324325
)
325-
session = self._database.session()
326+
session = self._new_session()
326327
session.create()
327328
# Replacing with the updated session.id.
328329
span_event_attributes["session.id"] = session._session_id
@@ -540,7 +541,7 @@ def bind(self, database):
540541
request = BatchCreateSessionsRequest(
541542
database=database.name,
542543
session_count=self.size,
543-
session_template=Session(creator_role=self.database_role),
544+
session_template=SessionProto(creator_role=self.database_role),
544545
)
545546

546547
span_event_attributes = {"kind": type(self).__name__}

tests/unit/test_pool.py

Lines changed: 32 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
from google.cloud.spanner_v1.request_id_header import REQ_RAND_PROCESS_ID
2727

2828
from google.cloud.spanner_v1._opentelemetry_tracing import trace_call
29+
from tests._builders import build_database
2930
from tests._helpers import (
3031
OpenTelemetryBase,
3132
LIB_VERSION,
@@ -94,38 +95,35 @@ def test_clear_abstract(self):
9495

9596
def test__new_session_wo_labels(self):
9697
pool = self._make_one()
97-
database = pool._database = _make_database("name")
98-
session = _make_session()
99-
database.session.return_value = session
98+
database = pool._database = build_database()
10099

101100
new_session = pool._new_session()
102101

103-
self.assertIs(new_session, session)
104-
database.session.assert_called_once_with(labels={}, database_role=None)
102+
self.assertEqual(new_session._database, database)
103+
self.assertEqual(new_session.labels, {})
104+
self.assertIsNone(new_session.database_role)
105105

106106
def test__new_session_w_labels(self):
107107
labels = {"foo": "bar"}
108108
pool = self._make_one(labels=labels)
109-
database = pool._database = _make_database("name")
110-
session = _make_session()
111-
database.session.return_value = session
109+
database = pool._database = build_database()
112110

113111
new_session = pool._new_session()
114112

115-
self.assertIs(new_session, session)
116-
database.session.assert_called_once_with(labels=labels, database_role=None)
113+
self.assertEqual(new_session._database, database)
114+
self.assertEqual(new_session.labels, labels)
115+
self.assertIsNone(new_session.database_role)
117116

118117
def test__new_session_w_database_role(self):
119118
database_role = "dummy-role"
120119
pool = self._make_one(database_role=database_role)
121-
database = pool._database = _make_database("name")
122-
session = _make_session()
123-
database.session.return_value = session
120+
database = pool._database = build_database()
124121

125122
new_session = pool._new_session()
126123

127-
self.assertIs(new_session, session)
128-
database.session.assert_called_once_with(labels={}, database_role=database_role)
124+
self.assertEqual(new_session._database, database)
125+
self.assertEqual(new_session.labels, {})
126+
self.assertEqual(new_session.database_role, database_role)
129127

130128
def test_session_wo_kwargs(self):
131129
from google.cloud.spanner_v1.pool import SessionCheckout
@@ -215,7 +213,7 @@ def test_get_active(self):
215213
pool = self._make_one(size=4)
216214
database = _Database("name")
217215
SESSIONS = sorted([_Session(database) for i in range(0, 4)])
218-
database._sessions.extend(SESSIONS)
216+
pool._new_session = mock.Mock(side_effect=SESSIONS)
219217
pool.bind(database)
220218

221219
# check if sessions returned in LIFO order
@@ -232,7 +230,7 @@ def test_get_non_expired(self):
232230
SESSIONS = sorted(
233231
[_Session(database, last_use_time=last_use_time) for i in range(0, 4)]
234232
)
235-
database._sessions.extend(SESSIONS)
233+
pool._new_session = mock.Mock(side_effect=SESSIONS)
236234
pool.bind(database)
237235

238236
# check if sessions returned in LIFO order
@@ -339,8 +337,7 @@ def test_spans_pool_bind(self):
339337
# you have an empty pool.
340338
pool = self._make_one(size=1)
341339
database = _Database("name")
342-
SESSIONS = []
343-
database._sessions.extend(SESSIONS)
340+
pool._new_session = mock.Mock(side_effect=Exception("test"))
344341
fauxSession = mock.Mock()
345342
setattr(fauxSession, "_database", database)
346343
try:
@@ -386,8 +383,8 @@ def test_spans_pool_bind(self):
386383
(
387384
"exception",
388385
{
389-
"exception.type": "IndexError",
390-
"exception.message": "pop from empty list",
386+
"exception.type": "Exception",
387+
"exception.message": "test",
391388
"exception.stacktrace": "EPHEMERAL",
392389
"exception.escaped": "False",
393390
},
@@ -397,8 +394,8 @@ def test_spans_pool_bind(self):
397394
(
398395
"exception",
399396
{
400-
"exception.type": "IndexError",
401-
"exception.message": "pop from empty list",
397+
"exception.type": "Exception",
398+
"exception.message": "test",
402399
"exception.stacktrace": "EPHEMERAL",
403400
"exception.escaped": "False",
404401
},
@@ -412,7 +409,7 @@ def test_get_expired(self):
412409
last_use_time = datetime.utcnow() - timedelta(minutes=65)
413410
SESSIONS = [_Session(database, last_use_time=last_use_time)] * 5
414411
SESSIONS[0]._exists = False
415-
database._sessions.extend(SESSIONS)
412+
pool._new_session = mock.Mock(side_effect=SESSIONS)
416413
pool.bind(database)
417414

418415
session = pool.get()
@@ -475,7 +472,7 @@ def test_clear(self):
475472
pool = self._make_one()
476473
database = _Database("name")
477474
SESSIONS = [_Session(database)] * 10
478-
database._sessions.extend(SESSIONS)
475+
pool._new_session = mock.Mock(side_effect=SESSIONS)
479476
pool.bind(database)
480477
self.assertTrue(pool._sessions.full())
481478

@@ -539,7 +536,7 @@ def test_ctor_explicit_w_database_role_in_db(self):
539536
def test_get_empty(self):
540537
pool = self._make_one()
541538
database = _Database("name")
542-
database._sessions.append(_Session(database))
539+
pool._new_session = mock.Mock(return_value=_Session(database))
543540
pool.bind(database)
544541

545542
session = pool.get()
@@ -559,7 +556,7 @@ def test_spans_get_empty_pool(self):
559556
pool = self._make_one()
560557
database = _Database("name")
561558
session1 = _Session(database)
562-
database._sessions.append(session1)
559+
pool._new_session = mock.Mock(return_value=session1)
563560
pool.bind(database)
564561

565562
with trace_call("pool.Get", session1):
@@ -630,7 +627,7 @@ def test_get_non_empty_session_expired(self):
630627
database = _Database("name")
631628
previous = _Session(database, exists=False)
632629
newborn = _Session(database)
633-
database._sessions.append(newborn)
630+
pool._new_session = mock.Mock(return_value=newborn)
634631
pool.bind(database)
635632
pool.put(previous)
636633

@@ -811,7 +808,7 @@ def test_get_hit_no_ping(self):
811808
pool = self._make_one(size=4)
812809
database = _Database("name")
813810
SESSIONS = [_Session(database)] * 4
814-
database._sessions.extend(SESSIONS)
811+
pool._new_session = mock.Mock(side_effect=SESSIONS)
815812
pool.bind(database)
816813
self.reset()
817814

@@ -830,7 +827,7 @@ def test_get_hit_w_ping(self):
830827
pool = self._make_one(size=4)
831828
database = _Database("name")
832829
SESSIONS = [_Session(database)] * 4
833-
database._sessions.extend(SESSIONS)
830+
pool._new_session = mock.Mock(side_effect=SESSIONS)
834831

835832
sessions_created = datetime.datetime.utcnow() - datetime.timedelta(seconds=4000)
836833

@@ -855,7 +852,7 @@ def test_get_hit_w_ping_expired(self):
855852
database = _Database("name")
856853
SESSIONS = [_Session(database)] * 5
857854
SESSIONS[0]._exists = False
858-
database._sessions.extend(SESSIONS)
855+
pool._new_session = mock.Mock(side_effect=SESSIONS)
859856

860857
sessions_created = datetime.datetime.utcnow() - datetime.timedelta(seconds=4000)
861858

@@ -974,7 +971,7 @@ def test_clear(self):
974971
pool = self._make_one()
975972
database = _Database("name")
976973
SESSIONS = [_Session(database)] * 10
977-
database._sessions.extend(SESSIONS)
974+
pool._new_session = mock.Mock(side_effect=SESSIONS)
978975
pool.bind(database)
979976
self.reset()
980977
self.assertTrue(pool._sessions.full())
@@ -1016,7 +1013,7 @@ def test_ping_oldest_stale_but_exists(self):
10161013
pool = self._make_one(size=1)
10171014
database = _Database("name")
10181015
SESSIONS = [_Session(database)] * 1
1019-
database._sessions.extend(SESSIONS)
1016+
pool._new_session = mock.Mock(side_effect=SESSIONS)
10201017
pool.bind(database)
10211018

10221019
later = datetime.datetime.utcnow() + datetime.timedelta(seconds=4000)
@@ -1034,7 +1031,7 @@ def test_ping_oldest_stale_and_not_exists(self):
10341031
database = _Database("name")
10351032
SESSIONS = [_Session(database)] * 2
10361033
SESSIONS[0]._exists = False
1037-
database._sessions.extend(SESSIONS)
1034+
pool._new_session = mock.Mock(side_effect=SESSIONS)
10381035
pool.bind(database)
10391036
self.reset()
10401037

@@ -1055,7 +1052,7 @@ def test_spans_get_and_leave_empty_pool(self):
10551052
pool = self._make_one()
10561053
database = _Database("name")
10571054
session1 = _Session(database)
1058-
database._sessions.append(session1)
1055+
pool._new_session = mock.Mock(side_effect=[session1, Exception])
10591056
try:
10601057
pool.bind(database)
10611058
except Exception:

0 commit comments

Comments
 (0)