Skip to content
This repository was archived by the owner on Mar 31, 2026. It is now read-only.

Commit b792e80

Browse files
committed
fix(sessions): resolve async deadlock in multiplexed session manager
1 parent 0aa1eda commit b792e80

File tree

11 files changed

+89
-44
lines changed

11 files changed

+89
-44
lines changed

google/cloud/spanner_v1/_async/database_sessions_manager.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,6 @@
1919
from datetime import timedelta
2020
from enum import Enum
2121
from os import getenv
22-
import threading
2322
from threading import Thread
2423
from typing import Optional
2524
from weakref import ref
@@ -71,8 +70,7 @@ def __init__(self, database, pool):
7170
self._pool = pool
7271
self._multiplexed_session: Optional[Session] = None
7372
self._multiplexed_session_thread: Optional[CrossSync.Task] = None
74-
# Use threading.Lock because this is accessed in a synchronous maintenance thread
75-
self._multiplexed_session_lock: threading.Lock = threading.Lock()
73+
self._multiplexed_session_lock: CrossSync.Lock = CrossSync.Lock()
7674
self._multiplexed_session_terminate_event: CrossSync.Event = CrossSync.Event()
7775

7876
@CrossSync.convert
@@ -119,7 +117,7 @@ async def _get_multiplexed_session(self) -> Session:
119117
120118
:rtype: :class:`~google.cloud.spanner_v1.session.Session`
121119
:returns: a multiplexed session."""
122-
with CrossSync.rm_aio(self._multiplexed_session_lock):
120+
async with self._multiplexed_session_lock:
123121
if self._multiplexed_session is None:
124122
self._multiplexed_session = await self._build_multiplexed_session()
125123
self._multiplexed_session_thread = self._build_maintenance_thread()
@@ -193,7 +191,7 @@ async def _maintain_multiplexed_session(session_manager_ref) -> None:
193191
if time() - session_created_time < refresh_interval_seconds:
194192
await CrossSync.sleep(polling_interval_seconds)
195193
continue
196-
with manager._multiplexed_session_lock:
194+
async with manager._multiplexed_session_lock:
197195
await CrossSync.run_if_async(manager._multiplexed_session.delete)
198196
manager._multiplexed_session = (
199197
await manager._build_multiplexed_session()

google/cloud/spanner_v1/batch.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -243,7 +243,7 @@ def wrapped_method():
243243
max_commit_delay=max_commit_delay,
244244
request_options=request_options,
245245
)
246-
(call_metadata, error_augmenter) = database.with_error_augmentation(
246+
call_metadata, error_augmenter = database.with_error_augmentation(
247247
getattr(database, "_next_nth_request", 0), 1, metadata, span
248248
)
249249
commit_method = functools.partial(

google/cloud/spanner_v1/database.py

Lines changed: 5 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,6 @@
8282
trace_call,
8383
)
8484
from google.cloud.spanner_v1.metrics.metrics_capture import MetricsCapture
85-
8685
from google.cloud.spanner_v1.table import Table
8786

8887
SPANNER_DATA_SCOPE = "https://www.googleapis.com/auth/spanner.data"
@@ -211,11 +210,9 @@ def __init__(
211210
def _resource_info(self):
212211
"""Resource information for metrics labels."""
213212
return {
214-
"project": (
215-
self._instance._client.project
216-
if self._instance and self._instance._client
217-
else None
218-
),
213+
"project": self._instance._client.project
214+
if self._instance and self._instance._client
215+
else None,
219216
"instance": self._instance.instance_id if self._instance else None,
220217
"database": self.database_id,
221218
}
@@ -533,7 +530,7 @@ def with_error_augmentation(
533530
tuple: (metadata_list, context_manager)"""
534531
if span is None:
535532
span = get_current_span()
536-
(metadata, request_id) = _metadata_with_request_id_and_req_id(
533+
metadata, request_id = _metadata_with_request_id_and_req_id(
537534
self._nth_client_id,
538535
self._channel_id,
539536
nth_request,
@@ -810,7 +807,7 @@ def execute_pdml():
810807
session = self._sessions_manager.get_session(transaction_type)
811808
try:
812809
add_span_event(span, "Starting BeginTransaction")
813-
(call_metadata, error_augmenter) = self.with_error_augmentation(
810+
call_metadata, error_augmenter = self.with_error_augmentation(
814811
self._next_nth_request, 1, metadata, span
815812
)
816813
with error_augmenter:

google/cloud/spanner_v1/database_sessions_manager.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,6 @@
1919
from datetime import timedelta
2020
from enum import Enum
2121
from os import getenv
22-
import threading
2322
from threading import Thread
2423
from typing import Optional
2524
from weakref import ref
@@ -69,7 +68,9 @@ def __init__(self, database, pool):
6968
self._pool = pool
7069
self._multiplexed_session: Optional[Session] = None
7170
self._multiplexed_session_thread: Optional[CrossSync._Sync_Impl.Task] = None
72-
self._multiplexed_session_lock: threading.Lock = threading.Lock()
71+
self._multiplexed_session_lock: CrossSync._Sync_Impl.Lock = (
72+
CrossSync._Sync_Impl.Lock()
73+
)
7374
self._multiplexed_session_terminate_event: CrossSync._Sync_Impl.Event = (
7475
CrossSync._Sync_Impl.Event()
7576
)

google/cloud/spanner_v1/instance.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -479,7 +479,9 @@ def database(
479479
database_role=database_role,
480480
enable_drop_protection=enable_drop_protection,
481481
)
482-
db._pool.bind(db)
482+
res = db._pool.bind(db)
483+
if res is not None:
484+
res
483485
return db
484486

485487
def list_databases(self, page_size=None):

google/cloud/spanner_v1/pool.py

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -304,7 +304,7 @@ def _fill_pool(self):
304304
f"Creating {request.session_count} sessions",
305305
span_event_attributes,
306306
)
307-
(call_metadata, error_augmenter) = database.with_error_augmentation(
307+
call_metadata, error_augmenter = database.with_error_augmentation(
308308
database._next_nth_request, 1, metadata, span
309309
)
310310
with error_augmenter:
@@ -612,7 +612,7 @@ def bind(self, database):
612612
) as span, MetricsCapture(self._resource_info):
613613
returned_session_count = 0
614614
while returned_session_count < self.size:
615-
(call_metadata, error_augmenter) = database.with_error_augmentation(
615+
call_metadata, error_augmenter = database.with_error_augmentation(
616616
database._next_nth_request, 1, metadata, span
617617
)
618618
with error_augmenter:
@@ -654,7 +654,7 @@ def get(self, timeout=None):
654654
ping_after = None
655655
session = None
656656
try:
657-
(ping_after, session) = CrossSync._Sync_Impl.queue_get(
657+
ping_after, session = CrossSync._Sync_Impl.queue_get(
658658
self._sessions, block=True, timeout=timeout
659659
)
660660
except CrossSync._Sync_Impl.QueueEmpty as e:
@@ -698,9 +698,7 @@ def clear(self):
698698
"""Delete all sessions in the pool."""
699699
while True:
700700
try:
701-
(_, session) = CrossSync._Sync_Impl.queue_get(
702-
self._sessions, block=False
703-
)
701+
_, session = CrossSync._Sync_Impl.queue_get(self._sessions, block=False)
704702
except CrossSync._Sync_Impl.QueueEmpty:
705703
break
706704
else:
@@ -713,7 +711,7 @@ def ping(self):
713711
or during the "idle" phase of an event loop."""
714712
while True:
715713
try:
716-
(ping_after, session) = CrossSync._Sync_Impl.queue_get(
714+
ping_after, session = CrossSync._Sync_Impl.queue_get(
717715
self._sessions, block=False
718716
)
719717
except CrossSync._Sync_Impl.QueueEmpty:

google/cloud/spanner_v1/session.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -188,7 +188,7 @@ def create(self):
188188
observability_options=observability_options,
189189
metadata=metadata,
190190
) as span, MetricsCapture(self._resource_info):
191-
(call_metadata, error_augmenter) = database.with_error_augmentation(
191+
call_metadata, error_augmenter = database.with_error_augmentation(
192192
nth_request, 1, metadata, span
193193
)
194194
with error_augmenter:
@@ -232,7 +232,7 @@ def exists(self):
232232
observability_options=observability_options,
233233
metadata=metadata,
234234
) as span, MetricsCapture(self._resource_info):
235-
(call_metadata, error_augmenter) = database.with_error_augmentation(
235+
call_metadata, error_augmenter = database.with_error_augmentation(
236236
nth_request, 1, metadata, span
237237
)
238238
with error_augmenter:
@@ -283,7 +283,7 @@ def delete(self):
283283
observability_options=observability_options,
284284
metadata=metadata,
285285
) as span, MetricsCapture(self._resource_info):
286-
(call_metadata, error_augmenter) = database.with_error_augmentation(
286+
call_metadata, error_augmenter = database.with_error_augmentation(
287287
nth_request, 1, metadata, span
288288
)
289289
with error_augmenter:
@@ -300,7 +300,7 @@ def ping(self):
300300
metadata = _metadata_with_prefix(database.name)
301301
nth_request = database._next_nth_request
302302
with trace_call("CloudSpanner.Session.ping", self) as span:
303-
(call_metadata, error_augmenter) = database.with_error_augmentation(
303+
call_metadata, error_augmenter = database.with_error_augmentation(
304304
nth_request, 1, metadata, span
305305
)
306306
with error_augmenter:

google/cloud/spanner_v1/snapshot.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -322,7 +322,7 @@ def execute_sql(
322322
raise ValueError("Transaction has not begun.")
323323
if params is not None:
324324
params_pb = Struct(
325-
fields={key: _make_value_pb(value) for (key, value) in params.items()}
325+
fields={key: _make_value_pb(value) for key, value in params.items()}
326326
)
327327
else:
328328
params_pb = {}
@@ -513,7 +513,7 @@ def partition_query(
513513
raise ValueError("Cannot partition a single-use transaction.")
514514
if params is not None:
515515
params_pb = Struct(
516-
fields={key: _make_value_pb(value) for (key, value) in params.items()}
516+
fields={key: _make_value_pb(value) for key, value in params.items()}
517517
)
518518
else:
519519
params_pb = Struct()
@@ -614,7 +614,7 @@ def wrapped_method():
614614
begin_transaction_request = BeginTransactionRequest(
615615
**begin_request_kwargs
616616
)
617-
(call_metadata, error_augmenter) = database.with_error_augmentation(
617+
call_metadata, error_augmenter = database.with_error_augmentation(
618618
nth_request, attempt.increment(), metadata, span
619619
)
620620
begin_transaction_method = functools.partial(

google/cloud/spanner_v1/streamed.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -147,7 +147,7 @@ def _consume_next(self):
147147

148148
def __iter__(self):
149149
while True:
150-
(iter_rows, self._rows[:]) = (self._rows[:], ())
150+
iter_rows, self._rows[:] = (self._rows[:], ())
151151
while iter_rows:
152152
yield iter_rows.pop(0)
153153
if self._done:
@@ -230,7 +230,7 @@ def to_dict_list(self):
230230
rows.append(
231231
{
232232
column: value
233-
for (column, value) in zip(
233+
for column, value in zip(
234234
[column.name for column in self._metadata.row_type.fields], row
235235
)
236236
}
@@ -291,7 +291,7 @@ def _merge_array(lhs, rhs, type_):
291291
if element_type.code in _UNMERGEABLE_TYPES:
292292
lhs.list_value.values.extend(rhs.list_value.values)
293293
return lhs
294-
(lhs, rhs) = (list(lhs.list_value.values), list(rhs.list_value.values))
294+
lhs, rhs = (list(lhs.list_value.values), list(rhs.list_value.values))
295295
if not len(lhs) or not len(rhs):
296296
return Value(list_value=ListValue(values=lhs + rhs))
297297
first = rhs.pop(0)
@@ -316,7 +316,7 @@ def _merge_array(lhs, rhs, type_):
316316
def _merge_struct(lhs, rhs, type_):
317317
"""Helper for '_merge_by_type'."""
318318
fields = type_.struct_type.fields
319-
(lhs, rhs) = (list(lhs.list_value.values), list(rhs.list_value.values))
319+
lhs, rhs = (list(lhs.list_value.values), list(rhs.list_value.values))
320320
if not len(lhs) or not len(rhs):
321321
return Value(list_value=ListValue(values=lhs + rhs))
322322
candidate_type = fields[len(lhs) - 1].type_

google/cloud/spanner_v1/transaction.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -162,7 +162,7 @@ def rollback(self) -> None:
162162

163163
def wrapped_method(*args, **kwargs):
164164
attempt.increment()
165-
(call_metadata, error_augmenter) = database.with_error_augmentation(
165+
call_metadata, error_augmenter = database.with_error_augmentation(
166166
nth_request, attempt.value, metadata, span
167167
)
168168
rollback_method = functools.partial(
@@ -269,7 +269,7 @@ def wrapped_method(*args, **kwargs):
269269
is_multiplexed = getattr(self._session, "is_multiplexed", False)
270270
if is_multiplexed and self._precommit_token is not None:
271271
commit_request_args["precommit_token"] = self._precommit_token
272-
(call_metadata, error_augmenter) = database.with_error_augmentation(
272+
call_metadata, error_augmenter = database.with_error_augmentation(
273273
nth_request, attempt.value, metadata, span
274274
)
275275
commit_method = functools.partial(
@@ -300,7 +300,7 @@ def before_next_retry(nth_retry, delay_in_seconds):
300300
if commit_response_pb._pb.HasField("precommit_token"):
301301
add_span_event(span, commit_retry_event_name)
302302
nth_request = database._next_nth_request
303-
(call_metadata, error_augmenter) = database.with_error_augmentation(
303+
call_metadata, error_augmenter = database.with_error_augmentation(
304304
nth_request, 1, metadata, span
305305
)
306306
with error_augmenter:
@@ -338,7 +338,7 @@ def _make_params_pb(params, param_types):
338338
If ``params`` is None but ``param_types`` is not None."""
339339
if params:
340340
return Struct(
341-
fields={key: _make_value_pb(value) for (key, value) in params.items()}
341+
fields={key: _make_value_pb(value) for key, value in params.items()}
342342
)
343343
return {}
344344

@@ -417,7 +417,7 @@ def execute_update(
417417
metadata.append(
418418
_metadata_with_leader_aware_routing(database._route_to_leader_enabled)
419419
)
420-
(seqno, self._execute_sql_request_count) = (
420+
seqno, self._execute_sql_request_count = (
421421
self._execute_sql_request_count,
422422
self._execute_sql_request_count + 1,
423423
)
@@ -454,7 +454,7 @@ def execute_update(
454454

455455
def wrapped_method(*args, **kwargs):
456456
attempt.increment()
457-
(call_metadata, error_augmenter) = database.with_error_augmentation(
457+
call_metadata, error_augmenter = database.with_error_augmentation(
458458
nth_request, attempt.value, metadata
459459
)
460460
execute_sql_method = functools.partial(
@@ -544,7 +544,7 @@ def batch_update(
544544
if isinstance(statement, str):
545545
parsed.append(ExecuteBatchDmlRequest.Statement(sql=statement))
546546
else:
547-
(dml, params, param_types) = statement
547+
dml, params, param_types = statement
548548
params_pb = self._make_params_pb(params, param_types)
549549
parsed.append(
550550
ExecuteBatchDmlRequest.Statement(
@@ -556,7 +556,7 @@ def batch_update(
556556
metadata.append(
557557
_metadata_with_leader_aware_routing(database._route_to_leader_enabled)
558558
)
559-
(seqno, self._execute_sql_request_count) = (
559+
seqno, self._execute_sql_request_count = (
560560
self._execute_sql_request_count,
561561
self._execute_sql_request_count + 1,
562562
)
@@ -590,7 +590,7 @@ def batch_update(
590590

591591
def wrapped_method(*args, **kwargs):
592592
attempt.increment()
593-
(call_metadata, error_augmenter) = database.with_error_augmentation(
593+
call_metadata, error_augmenter = database.with_error_augmentation(
594594
nth_request, attempt.value, metadata
595595
)
596596
execute_batch_dml_method = functools.partial(

0 commit comments

Comments
 (0)