Skip to content
This repository was archived by the owner on Mar 31, 2026. It is now read-only.

Commit b5b65d8

Browse files
committed
feat: Add support for multiplexed sessions - part 3 (partitioned transactions)
- Enable use of multiplexed sessions for partitioned transactions. - Add handling for partitioned transaction exceptions. - Add transaction type enumeration. Signed-off-by: currantw <taylor.curran@improving.com>
1 parent 9c91bbe commit b5b65d8

18 files changed

+1101
-909
lines changed

google/cloud/spanner_dbapi/connection.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
from google.cloud.spanner_dbapi.transaction_helper import TransactionRetryHelper
2727
from google.cloud.spanner_dbapi.cursor import Cursor
2828
from google.cloud.spanner_v1 import RequestOptions, TransactionOptions
29+
from google.cloud.spanner_v1.session_options import TransactionType
2930
from google.cloud.spanner_v1.snapshot import Snapshot
3031

3132
from google.cloud.spanner_dbapi.exceptions import (
@@ -356,11 +357,12 @@ def _session_checkout(self):
356357
raise ValueError("Database needs to be passed for this operation")
357358

358359
if not self._session:
359-
self._session = (
360-
self.database._session_manager.get_session_for_read_only()
360+
transaction_type = (
361+
TransactionType.READ_ONLY
361362
if self.read_only
362-
else self.database._session_manager.get_session_for_read_write()
363+
else TransactionType.READ_WRITE
363364
)
365+
self._session = self.database._session_manager.get_session(transaction_type)
364366

365367
return self._session
366368

@@ -628,7 +630,6 @@ def partition_query(
628630
self._partitioned_query_validation(partitioned_query, statement)
629631

630632
batch_snapshot = self._database.batch_snapshot()
631-
partition_ids = []
632633
partitions = list(
633634
batch_snapshot.generate_query_batches(
634635
partitioned_query,
@@ -639,6 +640,8 @@ def partition_query(
639640
)
640641

641642
batch_transaction_id = batch_snapshot.get_batch_transaction_id()
643+
644+
partition_ids = []
642645
for partition in partitions:
643646
partition_ids.append(
644647
partition_helper.encode_to_string(batch_transaction_id, partition)

google/cloud/spanner_v1/database.py

Lines changed: 109 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,8 @@
4141
from google.cloud.spanner_admin_database_v1 import UpdateDatabaseDdlRequest
4242
from google.cloud.spanner_admin_database_v1.types import DatabaseDialect
4343
from google.cloud.spanner_v1.database_sessions_manager import DatabaseSessionsManager
44+
from google.cloud.spanner_v1.session import Session
45+
from google.cloud.spanner_v1.session_options import SessionOptions, TransactionType
4446
from google.cloud.spanner_v1.transaction import BatchTransactionId
4547
from google.cloud.spanner_v1 import ExecuteSqlRequest
4648
from google.cloud.spanner_v1 import Type
@@ -60,7 +62,6 @@
6062
from google.cloud.spanner_v1.keyset import KeySet
6163
from google.cloud.spanner_v1.merged_result_set import MergedResultSet
6264
from google.cloud.spanner_v1.pool import BurstyPool
63-
from google.cloud.spanner_v1.session import Session
6465
from google.cloud.spanner_v1.snapshot import _restart_on_unavailable
6566
from google.cloud.spanner_v1.snapshot import Snapshot
6667
from google.cloud.spanner_v1.streamed import StreamedResultSet
@@ -74,7 +75,6 @@
7475
)
7576
from google.cloud.spanner_v1.metrics.metrics_capture import MetricsCapture
7677

77-
7878
SPANNER_DATA_SCOPE = "https://www.googleapis.com/auth/spanner.data"
7979

8080

@@ -192,9 +192,7 @@ def __init__(
192192
pool = BurstyPool(database_role=database_role)
193193
pool.bind(self)
194194

195-
self._session_manager = DatabaseSessionsManager(
196-
database=self, pool=pool, logger=self.logger
197-
)
195+
self._session_manager = DatabaseSessionsManager(database=self, pool=pool)
198196

199197
@classmethod
200198
def from_pb(cls, database_pb, instance, pool=None):
@@ -449,6 +447,15 @@ def spanner_api(self):
449447
)
450448
return self._spanner_api
451449

450+
@property
451+
def session_options(self) -> SessionOptions:
452+
"""Session options for the database.
453+
454+
:rtype: :class:`~google.cloud.spanner_v1.session_options.SessionOptions`
455+
:returns: the session options
456+
"""
457+
return self._instance._client.session_options
458+
452459
def __eq__(self, other):
453460
if not isinstance(other, self.__class__):
454461
return NotImplemented
@@ -709,11 +716,27 @@ def execute_pdml():
709716
"CloudSpanner.Database.execute_partitioned_pdml",
710717
observability_options=self.observability_options,
711718
) as span, MetricsCapture():
712-
with SessionCheckout(self) as session:
719+
transaction_type = TransactionType.PARTITIONED
720+
with SessionCheckout(self, transaction_type) as session:
713721
add_span_event(span, "Starting BeginTransaction")
714-
txn = api.begin_transaction(
715-
session=session.name, options=txn_options, metadata=metadata
716-
)
722+
723+
try:
724+
txn = api.begin_transaction(
725+
session=session.name, options=txn_options, metadata=metadata
726+
)
727+
728+
# If partitioned DML is not supported with multiplexed sessions,
729+
# disable multiplexed sessions for partitioned transactions before
730+
# re-raising the error.
731+
except NotImplementedError as exc:
732+
if (
733+
"Transaction type partitioned_dml not supported with multiplexed sessions"
734+
in str(exc)
735+
):
736+
self.session_options.disable_multiplexed(
737+
self.logger, transaction_type
738+
)
739+
raise exc
717740

718741
txn_selector = TransactionSelector(id=txn.id)
719742

@@ -732,8 +755,9 @@ def execute_pdml():
732755

733756
iterator = _restart_on_unavailable(
734757
method=method,
735-
trace_name="CloudSpanner.ExecuteStreamingSql",
736758
request=request,
759+
session=session,
760+
trace_name="CloudSpanner.ExecuteStreamingSql",
737761
metadata=metadata,
738762
transaction_selector=txn_selector,
739763
observability_options=self.observability_options,
@@ -746,23 +770,6 @@ def execute_pdml():
746770

747771
return _retry_on_aborted(execute_pdml, DEFAULT_RETRY_BACKOFF)()
748772

749-
def session(self, labels=None, database_role=None):
750-
"""Factory to create a session for this database.
751-
752-
:type labels: dict (str -> str) or None
753-
:param labels: (Optional) user-assigned labels for the session.
754-
755-
:type database_role: str
756-
:param database_role: (Optional) user-assigned database_role for the session.
757-
758-
:rtype: :class:`~google.cloud.spanner_v1.session.Session`
759-
:returns: a session bound to this database.
760-
"""
761-
# If role is specified in param, then that role is used
762-
# instead.
763-
role = database_role or self._database_role
764-
return Session(self, labels=labels, database_role=role)
765-
766773
def snapshot(self, **kw):
767774
"""Return an object which wraps a snapshot.
768775
@@ -1170,7 +1177,11 @@ class SessionCheckout(object):
11701177

11711178
_session = None # Not checked out until '__enter__'.
11721179

1173-
def __init__(self, database):
1180+
def __init__(
1181+
self,
1182+
database: Database,
1183+
transaction_type: TransactionType = TransactionType.READ_WRITE,
1184+
):
11741185
if not isinstance(database, Database):
11751186
raise TypeError(
11761187
"{class_name} must receive an instance of {expected_class_name}. Received: {actual_class_name}".format(
@@ -1180,10 +1191,21 @@ def __init__(self, database):
11801191
)
11811192
)
11821193

1194+
if not isinstance(transaction_type, TransactionType):
1195+
raise TypeError(
1196+
"{class_name} must receive an instance of {expected_class_name}. Received: {actual_class_name}".format(
1197+
class_name=self.__class__.__name__,
1198+
expected_class_name=TransactionType.__name__,
1199+
actual_class_name=transaction_type.__class__.__name__,
1200+
)
1201+
)
1202+
11831203
self._database = database
1204+
self._transaction_type = transaction_type
11841205

11851206
def __enter__(self):
1186-
self._session = self._database._session_manager.get_session_for_read_write()
1207+
session_manager = self._database._session_manager
1208+
self._session = session_manager.get_session(self._transaction_type)
11871209
return self._session
11881210

11891211
def __exit__(self, *ignored):
@@ -1248,7 +1270,13 @@ def __init__(
12481270

12491271
def __enter__(self):
12501272
"""Begin ``with`` block."""
1251-
self._session = self._database._session_manager.get_session_for_read_only()
1273+
1274+
# Batch transactions are performed as blind writes,
1275+
# which are treated as read-only transactions.
1276+
self._session = self._database._session_manager.get_session(
1277+
TransactionType.READ_ONLY
1278+
)
1279+
12521280
batch = self._batch = Batch(self._session)
12531281
if self._request_options.transaction_tag:
12541282
batch.transaction_tag = self._request_options.transaction_tag
@@ -1303,7 +1331,9 @@ def __init__(self, database):
13031331

13041332
def __enter__(self):
13051333
"""Begin ``with`` block."""
1306-
self._session = self._database._session_manager.get_session_for_read_write()
1334+
self._session = self._database._session_manager.get_session(
1335+
TransactionType.READ_WRITE
1336+
)
13071337
return MutationGroups(self._session)
13081338

13091339
def __exit__(self, exc_type, exc_val, exc_tb):
@@ -1345,7 +1375,9 @@ def __init__(self, database, **kw):
13451375

13461376
def __enter__(self):
13471377
"""Begin ``with`` block."""
1348-
self._session = self._database._session_manager.get_session_for_read_only()
1378+
self._session = self._database._session_manager.get_session(
1379+
TransactionType.READ_ONLY
1380+
)
13491381
return Snapshot(self._session, **self._kw)
13501382

13511383
def __exit__(self, exc_type, exc_val, exc_tb):
@@ -1395,11 +1427,15 @@ def from_dict(cls, database, mapping):
13951427
13961428
:rtype: :class:`BatchSnapshot`
13971429
"""
1430+
13981431
instance = cls(database)
1399-
session = instance._session = database.session()
1400-
session._session_id = mapping["session_id"]
1432+
1433+
session = instance._session = Session(database=database)
1434+
instance._session_id = session._session_id = mapping["session_id"]
1435+
14011436
snapshot = instance._snapshot = session.snapshot()
1402-
snapshot._transaction_id = mapping["transaction_id"]
1437+
instance._transaction_id = snapshot._transaction_id = mapping["transaction_id"]
1438+
14031439
return instance
14041440

14051441
def to_dict(self):
@@ -1408,10 +1444,15 @@ def to_dict(self):
14081444
Result can be used to serialize the instance and reconstitute
14091445
it later using :meth:`from_dict`.
14101446
1447+
When called, the underlying session is cleaned up, so
1448+
the batch snapshot is no longer valid.
1449+
14111450
:rtype: dict
14121451
"""
1452+
14131453
session = self._get_session()
14141454
snapshot = self._get_snapshot()
1455+
14151456
return {
14161457
"session_id": session._session_id,
14171458
"transaction_id": snapshot._transaction_id,
@@ -1429,25 +1470,48 @@ def _get_session(self):
14291470
Caller is responsible for cleaning up the session after
14301471
all partitions have been processed.
14311472
"""
1473+
14321474
if self._session is None:
1433-
session = self._session = self._database.session()
1475+
database = self._database
1476+
1477+
# If the session ID is not specified, check out a new session from
1478+
# the database session manager; otherwise, the session has already
1479+
# been checked out, so just create a session object to represent it.
14341480
if self._session_id is None:
1435-
session.create()
1481+
transaction_type = TransactionType.READ_ONLY
1482+
session = database._session_manager.get_session(transaction_type)
1483+
self._session_id = session.session_id
1484+
14361485
else:
1486+
session = Session(database=database)
14371487
session._session_id = self._session_id
1488+
1489+
self._session = session
1490+
14381491
return self._session
14391492

14401493
def _get_snapshot(self):
14411494
"""Create snapshot if needed."""
1495+
14421496
if self._snapshot is None:
1443-
self._snapshot = self._get_session().snapshot(
1444-
read_timestamp=self._read_timestamp,
1445-
exact_staleness=self._exact_staleness,
1446-
multi_use=True,
1447-
transaction_id=self._transaction_id,
1448-
)
1497+
snapshot_args = {
1498+
"session": self._get_session(),
1499+
"read_timestamp": self._read_timestamp,
1500+
"exact_staleness": self._exact_staleness,
1501+
"multi_use": True,
1502+
}
1503+
1504+
# If the transaction ID is not specified, create a new snapshot
1505+
# and begin a transaction; otherwise, the transaction is already
1506+
# in progress, so just create a snapshot object to represent it.
14491507
if self._transaction_id is None:
1450-
self._snapshot.begin()
1508+
self._snapshot = Snapshot(**snapshot_args)
1509+
self._transaction_id = self._snapshot.begin()
1510+
1511+
else:
1512+
snapshot_args["transaction_id"] = self._transaction_id
1513+
self._snapshot = Snapshot(**snapshot_args)
1514+
14511515
return self._snapshot
14521516

14531517
def get_batch_transaction_id(self):
@@ -1844,7 +1908,7 @@ def close(self):
18441908
from all the partitions.
18451909
"""
18461910
if self._session is not None:
1847-
self._session.delete()
1911+
self._database._session_manager.put_session(self._session)
18481912

18491913

18501914
def _check_ddl_statements(value):

0 commit comments

Comments
 (0)