Skip to content
This repository was archived by the owner on Mar 31, 2026. It is now read-only.
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 8 additions & 3 deletions google/cloud/spanner_dbapi/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -330,8 +330,13 @@ def _session_checkout(self):
"""
if self.database is None:
raise ValueError("Database needs to be passed for this operation")

if not self._session:
self._session = self.database._pool.get()
self._session = (
self.database._session_manager.get_session_for_read_only()
if self.read_only
else self.database._session_manager.get_session_for_read_write()
)

return self._session

Expand All @@ -344,7 +349,7 @@ def _release_session(self):
return
if self.database is None:
raise ValueError("Database needs to be passed for this operation")
self.database._pool.put(self._session)
self.database._session_manager.put_session(self._session)
self._session = None

def transaction_checkout(self):
Expand Down Expand Up @@ -400,7 +405,7 @@ def close(self):
self._transaction.rollback()

if self._own_pool and self.database:
self.database._pool.clear()
self.database._session_manager._pool.clear()

self.is_closed = True

Expand Down
16 changes: 16 additions & 0 deletions google/cloud/spanner_v1/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@
except ImportError: # pragma: NO COVER
HAS_GOOGLE_CLOUD_MONITORING_INSTALLED = False

from google.cloud.spanner_v1.session_options import SessionOptions

_CLIENT_INFO = client_info.ClientInfo(client_library_version=__version__)
EMULATOR_ENV_VAR = "SPANNER_EMULATOR_HOST"
Expand Down Expand Up @@ -171,6 +172,9 @@ class Client(ClientWithProject):
or :class:`dict`
:param default_transaction_options: (Optional) Default options to use for all transactions.

:type session_options: :class:`~google.cloud.spanner_v1.SessionOptions`
:param session_options: (Optional) Options for client sessions.

:raises: :class:`ValueError <exceptions.ValueError>` if both ``read_only``
and ``admin`` are :data:`True`
"""
Expand All @@ -193,6 +197,7 @@ def __init__(
directed_read_options=None,
observability_options=None,
default_transaction_options: Optional[DefaultTransactionOptions] = None,
session_options=None,
):
self._emulator_host = _get_spanner_emulator_host()

Expand Down Expand Up @@ -262,6 +267,8 @@ def __init__(
)
self._default_transaction_options = default_transaction_options

self._session_options = session_options or SessionOptions()

@property
def credentials(self):
"""Getter for client's credentials.
Expand Down Expand Up @@ -525,3 +532,12 @@ def default_transaction_options(
)

self._default_transaction_options = default_transaction_options

@property
def session_options(self):
"""Returns the session options for the client.

:rtype: :class:`~google.cloud.spanner_v1.SessionOptions`
:returns: The session options for the client.
"""
return self._session_options
105 changes: 70 additions & 35 deletions google/cloud/spanner_v1/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
from google.cloud.spanner_admin_database_v1 import RestoreDatabaseRequest
from google.cloud.spanner_admin_database_v1 import UpdateDatabaseDdlRequest
from google.cloud.spanner_admin_database_v1.types import DatabaseDialect
from google.cloud.spanner_v1.database_sessions_manager import DatabaseSessionsManager
from google.cloud.spanner_v1.transaction import BatchTransactionId
from google.cloud.spanner_v1 import ExecuteSqlRequest
from google.cloud.spanner_v1 import Type
Expand All @@ -59,7 +60,6 @@
from google.cloud.spanner_v1.keyset import KeySet
from google.cloud.spanner_v1.merged_result_set import MergedResultSet
from google.cloud.spanner_v1.pool import BurstyPool
from google.cloud.spanner_v1.pool import SessionCheckout
from google.cloud.spanner_v1.session import Session
from google.cloud.spanner_v1.snapshot import _restart_on_unavailable
from google.cloud.spanner_v1.snapshot import Snapshot
Expand All @@ -70,7 +70,6 @@
from google.cloud.spanner_v1.table import Table
from google.cloud.spanner_v1._opentelemetry_tracing import (
add_span_event,
get_current_span,
trace_call,
)
from google.cloud.spanner_v1.metrics.metrics_capture import MetricsCapture
Expand Down Expand Up @@ -191,10 +190,10 @@ def __init__(

if pool is None:
pool = BurstyPool(database_role=database_role)

self._pool = pool
pool.bind(self)

self._session_manager = DatabaseSessionsManager(database=self, pool=pool)

@classmethod
def from_pb(cls, database_pb, instance, pool=None):
"""Creates an instance of this class from a protobuf.
Expand Down Expand Up @@ -708,7 +707,7 @@ def execute_pdml():
"CloudSpanner.Database.execute_partitioned_pdml",
observability_options=self.observability_options,
) as span, MetricsCapture():
with SessionCheckout(self._pool) as session:
with SessionCheckout(self) as session:
add_span_event(span, "Starting BeginTransaction")
txn = api.begin_transaction(
session=session.name, options=txn_options, metadata=metadata
Expand Down Expand Up @@ -923,7 +922,7 @@ def run_in_transaction(self, func, *args, **kw):
# Check out a session and run the function in a transaction; once
# done, flip the sanity check bit back.
try:
with SessionCheckout(self._pool) as session:
with SessionCheckout(self) as session:
return session.run_in_transaction(func, *args, **kw)
finally:
self._local.transaction_running = False
Expand Down Expand Up @@ -1160,6 +1159,35 @@ def observability_options(self):
return opts


class SessionCheckout(object):
"""Context manager for using a session from a database.

:type database: :class:`~google.cloud.spanner_v1.database.Database`
:param database: database to use the session from
"""

_session = None # Not checked out until '__enter__'.

def __init__(self, database):
if not isinstance(database, Database):
raise TypeError(
"{class_name} must receive an instance of {expected_class_name}. Received: {actual_class_name}".format(
class_name=self.__class__.__name__,
expected_class_name=Database.__name__,
actual_class_name=database.__class__.__name__,
)
)

self._database = database

def __enter__(self):
self._session = self._database._session_manager.get_session_for_read_write()
return self._session

def __exit__(self, *ignored):
self._database._session_manager.put_session(self._session)
Comment on lines +1162 to +1188
Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

SessionCheckout has been moved from pool.py to database.py, and updated so that the session is checked out from the database rather than from the pool.

Furthermore, SessionCheckout's constructor previous accepted arbitrary keyword arguments, which were passed to the pool's get method. While AbstractSessionPool.get does not describe any keyword arguments, both FixedSizePool.get and PingingPool.get accept an optional timeout keyword argument, which can be used to override the default timeout.

Because DatabaseSessionManager now manages which session is returned, and it not guaranteed to return a session from the pool, we are suggesting removing the ability to provide arbitrary keyword arguments to the pool via SessionCheckout.__init__.

No keyword arguments are supported by the default pool implementation (BurstyPool), and users are already able to specify their own default timeout for FixedSizePool and PingingPool. Moreover, it is not clear to us where this functionality is even used: remaining uses of SessionCheckout do not support these keyword arguments.

Please let us know:

  • Is it acceptable to remove this functionality?
  • Would you also like to remove support for the timeout arguments to FixedSizePool and PingingPool, for consistency?



class BatchCheckout(object):
"""Context manager for using a batch from a database.

Expand Down Expand Up @@ -1194,6 +1222,15 @@ def __init__(
isolation_level=TransactionOptions.IsolationLevel.ISOLATION_LEVEL_UNSPECIFIED,
**kw,
):
if not isinstance(database, Database):
raise TypeError(
"{class_name} must receive an instance of {expected_class_name}. Received: {actual_class_name}".format(
class_name=self.__class__.__name__,
expected_class_name=Database.__name__,
actual_class_name=database.__class__.__name__,
)
)
Comment on lines +1225 to +1232
Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For better error handling, added a type check to all *Checkout classes, along with corresponding unit tests.


self._database = database
self._session = self._batch = None
if request_options is None:
Expand All @@ -1209,10 +1246,8 @@ def __init__(

def __enter__(self):
"""Begin ``with`` block."""
current_span = get_current_span()
session = self._session = self._database._pool.get()
add_span_event(current_span, "Using session", {"id": session.session_id})
batch = self._batch = Batch(session)
self._session = self._database._session_manager.get_session_for_read_only()
batch = self._batch = Batch(self._session)
if self._request_options.transaction_tag:
batch.transaction_tag = self._request_options.transaction_tag
return batch
Expand All @@ -1235,13 +1270,7 @@ def __exit__(self, exc_type, exc_val, exc_tb):
"CommitStats: {}".format(self._batch.commit_stats),
extra={"commit_stats": self._batch.commit_stats},
)
self._database._pool.put(self._session)
current_span = get_current_span()
add_span_event(
current_span,
"Returned session to pool",
{"id": self._session.session_id},
)
self._database._session_manager.put_session(self._session)


class MutationGroupsCheckout(object):
Expand All @@ -1258,23 +1287,26 @@ class MutationGroupsCheckout(object):
"""

def __init__(self, database):
if not isinstance(database, Database):
raise TypeError(
"{class_name} must receive an instance of {expected_class_name}. Received: {actual_class_name}".format(
class_name=self.__class__.__name__,
expected_class_name=Database.__name__,
actual_class_name=database.__class__.__name__,
)
)

self._database = database
self._session = None

def __enter__(self):
"""Begin ``with`` block."""
session = self._session = self._database._pool.get()
return MutationGroups(session)
self._session = self._database._session_manager.get_session_for_read_write()
return MutationGroups(self._session)

def __exit__(self, exc_type, exc_val, exc_tb):
"""End ``with`` block."""
if isinstance(exc_val, NotFound):
# If NotFound exception occurs inside the with block
# then we validate if the session still exists.
if not self._session.exists():
self._session = self._database._pool._new_session()
self._session.create()
Comment on lines -1271 to -1276
Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This check has been removed from MutationGroupsCheckout and SnapshotCheckout because:

  • exists only checks whether the session exists on the server (not whether the local session object exists)
  • This checkout wasn't for all context managers (not in BatchCheckout or SessionCheckout)
  • This seems to only be done as a slight optimization to potentially reduce overhead of session retrieval by adding it to the tail of a request instead of the head of another.
  • In reality, the optimization only plays a part if the session expires while a call was happening, and is unlikely to make a big difference.
  • exists checks are still performed by the session pool (for non-multiplexed session) and the database session manager (for multiplexed session).

We recommend removing this code, and instead depend on DatabaseSessionManager to take care of checking session existence and re-creating them if necessary, rather than (sometimes) handling this in the context manager.

self._database._pool.put(self._session)
self._database._session_manager.put_session(self._session)


class SnapshotCheckout(object):
Expand All @@ -1296,24 +1328,27 @@ class SnapshotCheckout(object):
"""

def __init__(self, database, **kw):
if not isinstance(database, Database):
raise TypeError(
"{class_name} must receive an instance of {expected_class_name}. Received: {actual_class_name}".format(
class_name=self.__class__.__name__,
expected_class_name=Database.__name__,
actual_class_name=database.__class__.__name__,
)
)

self._database = database
self._session = None
self._kw = kw

def __enter__(self):
"""Begin ``with`` block."""
session = self._session = self._database._pool.get()
return Snapshot(session, **self._kw)
self._session = self._database._session_manager.get_session_for_read_only()
return Snapshot(self._session, **self._kw)

def __exit__(self, exc_type, exc_val, exc_tb):
"""End ``with`` block."""
if isinstance(exc_val, NotFound):
# If NotFound exception occurs inside the with block
# then we validate if the session still exists.
if not self._session.exists():
self._session = self._database._pool._new_session()
self._session.create()
self._database._pool.put(self._session)
self._database._session_manager.put_session(self._session)


class BatchSnapshot(object):
Expand Down
Loading
Loading