Skip to content
This repository was archived by the owner on Mar 31, 2026. It is now read-only.
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 7 additions & 4 deletions google/cloud/spanner_dbapi/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from google.cloud.spanner_dbapi.transaction_helper import TransactionRetryHelper
from google.cloud.spanner_dbapi.cursor import Cursor
from google.cloud.spanner_v1 import RequestOptions, TransactionOptions
from google.cloud.spanner_v1.session_options import TransactionType
from google.cloud.spanner_v1.snapshot import Snapshot

from google.cloud.spanner_dbapi.exceptions import (
Expand Down Expand Up @@ -356,11 +357,12 @@ def _session_checkout(self):
raise ValueError("Database needs to be passed for this operation")

if not self._session:
self._session = (
self.database._session_manager.get_session_for_read_only()
transaction_type = (
TransactionType.READ_ONLY
if self.read_only
else self.database._session_manager.get_session_for_read_write()
else TransactionType.READ_WRITE
)
self._session = self.database._session_manager.get_session(transaction_type)

return self._session

Expand Down Expand Up @@ -628,7 +630,6 @@ def partition_query(
self._partitioned_query_validation(partitioned_query, statement)

batch_snapshot = self._database.batch_snapshot()
partition_ids = []
partitions = list(
batch_snapshot.generate_query_batches(
partitioned_query,
Expand All @@ -639,6 +640,8 @@ def partition_query(
)

batch_transaction_id = batch_snapshot.get_batch_transaction_id()

partition_ids = []
for partition in partitions:
partition_ids.append(
partition_helper.encode_to_string(batch_transaction_id, partition)
Expand Down
154 changes: 109 additions & 45 deletions google/cloud/spanner_v1/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,8 @@
from google.cloud.spanner_admin_database_v1 import UpdateDatabaseDdlRequest
from google.cloud.spanner_admin_database_v1.types import DatabaseDialect
from google.cloud.spanner_v1.database_sessions_manager import DatabaseSessionsManager
from google.cloud.spanner_v1.session import Session
from google.cloud.spanner_v1.session_options import SessionOptions, TransactionType
from google.cloud.spanner_v1.transaction import BatchTransactionId
from google.cloud.spanner_v1 import ExecuteSqlRequest
from google.cloud.spanner_v1 import Type
Expand All @@ -60,7 +62,6 @@
from google.cloud.spanner_v1.keyset import KeySet
from google.cloud.spanner_v1.merged_result_set import MergedResultSet
from google.cloud.spanner_v1.pool import BurstyPool
from google.cloud.spanner_v1.session import Session
from google.cloud.spanner_v1.snapshot import _restart_on_unavailable
from google.cloud.spanner_v1.snapshot import Snapshot
from google.cloud.spanner_v1.streamed import StreamedResultSet
Expand All @@ -74,7 +75,6 @@
)
from google.cloud.spanner_v1.metrics.metrics_capture import MetricsCapture


SPANNER_DATA_SCOPE = "https://www.googleapis.com/auth/spanner.data"


Expand Down Expand Up @@ -192,9 +192,7 @@ def __init__(
pool = BurstyPool(database_role=database_role)
pool.bind(self)

self._session_manager = DatabaseSessionsManager(
database=self, pool=pool, logger=self.logger
)
self._session_manager = DatabaseSessionsManager(database=self, pool=pool)

@classmethod
def from_pb(cls, database_pb, instance, pool=None):
Expand Down Expand Up @@ -449,6 +447,15 @@ def spanner_api(self):
)
return self._spanner_api

@property
def session_options(self) -> SessionOptions:
"""Session options for the database.

:rtype: :class:`~google.cloud.spanner_v1.session_options.SessionOptions`
:returns: the session options
"""
return self._instance._client.session_options

def __eq__(self, other):
if not isinstance(other, self.__class__):
return NotImplemented
Expand Down Expand Up @@ -709,11 +716,27 @@ def execute_pdml():
"CloudSpanner.Database.execute_partitioned_pdml",
observability_options=self.observability_options,
) as span, MetricsCapture():
with SessionCheckout(self) as session:
transaction_type = TransactionType.PARTITIONED
with SessionCheckout(self, transaction_type) as session:
add_span_event(span, "Starting BeginTransaction")
txn = api.begin_transaction(
session=session.name, options=txn_options, metadata=metadata
)

try:
txn = api.begin_transaction(
session=session.name, options=txn_options, metadata=metadata
)

# If partitioned DML is not supported with multiplexed sessions,
# disable multiplexed sessions for partitioned transactions before
# re-raising the error.
except NotImplementedError as exc:
if (
"Transaction type partitioned_dml not supported with multiplexed sessions"
in str(exc)
):
self.session_options.disable_multiplexed(
self.logger, transaction_type
)
raise exc

txn_selector = TransactionSelector(id=txn.id)

Expand All @@ -732,8 +755,9 @@ def execute_pdml():

iterator = _restart_on_unavailable(
method=method,
trace_name="CloudSpanner.ExecuteStreamingSql",
request=request,
session=session,
trace_name="CloudSpanner.ExecuteStreamingSql",
metadata=metadata,
transaction_selector=txn_selector,
observability_options=self.observability_options,
Expand All @@ -746,23 +770,6 @@ def execute_pdml():

return _retry_on_aborted(execute_pdml, DEFAULT_RETRY_BACKOFF)()

def session(self, labels=None, database_role=None):
"""Factory to create a session for this database.

:type labels: dict (str -> str) or None
:param labels: (Optional) user-assigned labels for the session.

:type database_role: str
:param database_role: (Optional) user-assigned database_role for the session.

:rtype: :class:`~google.cloud.spanner_v1.session.Session`
:returns: a session bound to this database.
"""
# If role is specified in param, then that role is used
# instead.
role = database_role or self._database_role
return Session(self, labels=labels, database_role=role)

def snapshot(self, **kw):
"""Return an object which wraps a snapshot.

Expand Down Expand Up @@ -1170,7 +1177,11 @@ class SessionCheckout(object):

_session = None # Not checked out until '__enter__'.

def __init__(self, database):
def __init__(
self,
database: Database,
transaction_type: TransactionType = TransactionType.READ_WRITE,
):
if not isinstance(database, Database):
raise TypeError(
"{class_name} must receive an instance of {expected_class_name}. Received: {actual_class_name}".format(
Expand All @@ -1180,10 +1191,21 @@ def __init__(self, database):
)
)

if not isinstance(transaction_type, TransactionType):
raise TypeError(
"{class_name} must receive an instance of {expected_class_name}. Received: {actual_class_name}".format(
class_name=self.__class__.__name__,
expected_class_name=TransactionType.__name__,
actual_class_name=transaction_type.__class__.__name__,
)
)

self._database = database
self._transaction_type = transaction_type

def __enter__(self):
self._session = self._database._session_manager.get_session_for_read_write()
session_manager = self._database._session_manager
self._session = session_manager.get_session(self._transaction_type)
return self._session

def __exit__(self, *ignored):
Expand Down Expand Up @@ -1248,7 +1270,13 @@ def __init__(

def __enter__(self):
"""Begin ``with`` block."""
self._session = self._database._session_manager.get_session_for_read_only()

# Batch transactions are performed as blind writes,
# which are treated as read-only transactions.
self._session = self._database._session_manager.get_session(
TransactionType.READ_ONLY
)

batch = self._batch = Batch(self._session)
if self._request_options.transaction_tag:
batch.transaction_tag = self._request_options.transaction_tag
Expand Down Expand Up @@ -1303,7 +1331,9 @@ def __init__(self, database):

def __enter__(self):
"""Begin ``with`` block."""
self._session = self._database._session_manager.get_session_for_read_write()
self._session = self._database._session_manager.get_session(
TransactionType.READ_WRITE
)
return MutationGroups(self._session)

def __exit__(self, exc_type, exc_val, exc_tb):
Expand Down Expand Up @@ -1345,7 +1375,9 @@ def __init__(self, database, **kw):

def __enter__(self):
"""Begin ``with`` block."""
self._session = self._database._session_manager.get_session_for_read_only()
self._session = self._database._session_manager.get_session(
TransactionType.READ_ONLY
)
return Snapshot(self._session, **self._kw)

def __exit__(self, exc_type, exc_val, exc_tb):
Expand Down Expand Up @@ -1395,11 +1427,15 @@ def from_dict(cls, database, mapping):

:rtype: :class:`BatchSnapshot`
"""

instance = cls(database)
session = instance._session = database.session()
session._session_id = mapping["session_id"]

session = instance._session = Session(database=database)
instance._session_id = session._session_id = mapping["session_id"]

snapshot = instance._snapshot = session.snapshot()
snapshot._transaction_id = mapping["transaction_id"]
instance._transaction_id = snapshot._transaction_id = mapping["transaction_id"]

return instance

def to_dict(self):
Expand All @@ -1408,10 +1444,15 @@ def to_dict(self):
Result can be used to serialize the instance and reconstitute
it later using :meth:`from_dict`.

When called, the underlying session is cleaned up, so
the batch snapshot is no longer valid.

:rtype: dict
"""

session = self._get_session()
snapshot = self._get_snapshot()

return {
"session_id": session._session_id,
"transaction_id": snapshot._transaction_id,
Expand All @@ -1429,25 +1470,48 @@ def _get_session(self):
Caller is responsible for cleaning up the session after
all partitions have been processed.
"""

if self._session is None:
session = self._session = self._database.session()
database = self._database

# If the session ID is not specified, check out a new session from
# the database session manager; otherwise, the session has already
# been checked out, so just create a session object to represent it.
if self._session_id is None:
session.create()
transaction_type = TransactionType.READ_ONLY
session = database._session_manager.get_session(transaction_type)
self._session_id = session.session_id

else:
session = Session(database=database)
session._session_id = self._session_id

self._session = session

return self._session

def _get_snapshot(self):
"""Create snapshot if needed."""

if self._snapshot is None:
self._snapshot = self._get_session().snapshot(
read_timestamp=self._read_timestamp,
exact_staleness=self._exact_staleness,
multi_use=True,
transaction_id=self._transaction_id,
)
snapshot_args = {
"session": self._get_session(),
"read_timestamp": self._read_timestamp,
"exact_staleness": self._exact_staleness,
"multi_use": True,
}

# If the transaction ID is not specified, create a new snapshot
# and begin a transaction; otherwise, the transaction is already
# in progress, so just create a snapshot object to represent it.
if self._transaction_id is None:
self._snapshot.begin()
self._snapshot = Snapshot(**snapshot_args)
self._transaction_id = self._snapshot.begin()

else:
snapshot_args["transaction_id"] = self._transaction_id
self._snapshot = Snapshot(**snapshot_args)

return self._snapshot

def get_batch_transaction_id(self):
Expand Down Expand Up @@ -1844,7 +1908,7 @@ def close(self):
from all the partitions.
"""
if self._session is not None:
self._session.delete()
self._database._session_manager.put_session(self._session)


def _check_ddl_statements(value):
Expand Down
Loading