diff --git a/google/cloud/spanner_dbapi/connection.py b/google/cloud/spanner_dbapi/connection.py index a615a282b5..6cf225bbea 100644 --- a/google/cloud/spanner_dbapi/connection.py +++ b/google/cloud/spanner_dbapi/connection.py @@ -26,6 +26,7 @@ from google.cloud.spanner_dbapi.transaction_helper import TransactionRetryHelper from google.cloud.spanner_dbapi.cursor import Cursor from google.cloud.spanner_v1 import RequestOptions, TransactionOptions +from google.cloud.spanner_v1.session_options import TransactionType from google.cloud.spanner_v1.snapshot import Snapshot from google.cloud.spanner_dbapi.exceptions import ( @@ -354,8 +355,14 @@ def _session_checkout(self): """ if self.database is None: raise ValueError("Database needs to be passed for this operation") + if not self._session: - self._session = self.database._pool.get() + transaction_type = ( + TransactionType.READ_ONLY + if self.read_only + else TransactionType.READ_WRITE + ) + self._session = self.database._session_manager.get_session(transaction_type) return self._session @@ -368,7 +375,7 @@ def _release_session(self): return if self.database is None: raise ValueError("Database needs to be passed for this operation") - self.database._pool.put(self._session) + self.database._session_manager.put_session(self._session) self._session = None def transaction_checkout(self): @@ -430,7 +437,7 @@ def close(self): self._transaction.rollback() if self._own_pool and self.database: - self.database._pool.clear() + self.database._session_manager._pool.clear() self.is_closed = True @@ -623,7 +630,6 @@ def partition_query( self._partitioned_query_validation(partitioned_query, statement) batch_snapshot = self._database.batch_snapshot() - partition_ids = [] partitions = list( batch_snapshot.generate_query_batches( partitioned_query, @@ -634,6 +640,8 @@ def partition_query( ) batch_transaction_id = batch_snapshot.get_batch_transaction_id() + + partition_ids = [] for partition in partitions: partition_ids.append( 
partition_helper.encode_to_string(batch_transaction_id, partition) diff --git a/google/cloud/spanner_v1/client.py b/google/cloud/spanner_v1/client.py index e201f93e9b..f84170c1f3 100644 --- a/google/cloud/spanner_v1/client.py +++ b/google/cloud/spanner_v1/client.py @@ -70,6 +70,7 @@ except ImportError: # pragma: NO COVER HAS_GOOGLE_CLOUD_MONITORING_INSTALLED = False +from google.cloud.spanner_v1.session_options import SessionOptions _CLIENT_INFO = client_info.ClientInfo(client_library_version=__version__) EMULATOR_ENV_VAR = "SPANNER_EMULATOR_HOST" @@ -171,6 +172,9 @@ class Client(ClientWithProject): or :class:`dict` :param default_transaction_options: (Optional) Default options to use for all transactions. + :type session_options: :class:`~google.cloud.spanner_v1.SessionOptions` + :param session_options: (Optional) Options for client sessions. + :raises: :class:`ValueError ` if both ``read_only`` and ``admin`` are :data:`True` """ @@ -193,6 +197,7 @@ def __init__( directed_read_options=None, observability_options=None, default_transaction_options: Optional[DefaultTransactionOptions] = None, + session_options=None, ): self._emulator_host = _get_spanner_emulator_host() @@ -262,6 +267,8 @@ def __init__( ) self._default_transaction_options = default_transaction_options + self._session_options = session_options or SessionOptions() + @property def credentials(self): """Getter for client's credentials. @@ -525,3 +532,12 @@ def default_transaction_options( ) self._default_transaction_options = default_transaction_options + + @property + def session_options(self): + """Returns the session options for the client. + + :rtype: :class:`~google.cloud.spanner_v1.SessionOptions` + :returns: The session options for the client. 
+ """ + return self._session_options diff --git a/google/cloud/spanner_v1/database.py b/google/cloud/spanner_v1/database.py index 03c6e5119f..93d9c1a31c 100644 --- a/google/cloud/spanner_v1/database.py +++ b/google/cloud/spanner_v1/database.py @@ -40,6 +40,9 @@ from google.cloud.spanner_admin_database_v1 import RestoreDatabaseRequest from google.cloud.spanner_admin_database_v1 import UpdateDatabaseDdlRequest from google.cloud.spanner_admin_database_v1.types import DatabaseDialect +from google.cloud.spanner_v1.database_sessions_manager import DatabaseSessionsManager +from google.cloud.spanner_v1.session import Session +from google.cloud.spanner_v1.session_options import SessionOptions, TransactionType from google.cloud.spanner_v1.transaction import BatchTransactionId from google.cloud.spanner_v1 import ExecuteSqlRequest from google.cloud.spanner_v1 import Type @@ -59,8 +62,6 @@ from google.cloud.spanner_v1.keyset import KeySet from google.cloud.spanner_v1.merged_result_set import MergedResultSet from google.cloud.spanner_v1.pool import BurstyPool -from google.cloud.spanner_v1.pool import SessionCheckout -from google.cloud.spanner_v1.session import Session from google.cloud.spanner_v1.snapshot import _restart_on_unavailable from google.cloud.spanner_v1.snapshot import Snapshot from google.cloud.spanner_v1.streamed import StreamedResultSet @@ -70,12 +71,10 @@ from google.cloud.spanner_v1.table import Table from google.cloud.spanner_v1._opentelemetry_tracing import ( add_span_event, - get_current_span, trace_call, ) from google.cloud.spanner_v1.metrics.metrics_capture import MetricsCapture - SPANNER_DATA_SCOPE = "https://www.googleapis.com/auth/spanner.data" @@ -191,10 +190,10 @@ def __init__( if pool is None: pool = BurstyPool(database_role=database_role) - - self._pool = pool pool.bind(self) + self._session_manager = DatabaseSessionsManager(database=self, pool=pool) + @classmethod def from_pb(cls, database_pb, instance, pool=None): """Creates an instance of this 
class from a protobuf. @@ -448,6 +447,15 @@ def spanner_api(self): ) return self._spanner_api + @property + def session_options(self) -> SessionOptions: + """Session options for the database. + + :rtype: :class:`~google.cloud.spanner_v1.session_options.SessionOptions` + :returns: the session options + """ + return self._instance._client.session_options + def __eq__(self, other): if not isinstance(other, self.__class__): return NotImplemented @@ -708,11 +716,27 @@ def execute_pdml(): "CloudSpanner.Database.execute_partitioned_pdml", observability_options=self.observability_options, ) as span, MetricsCapture(): - with SessionCheckout(self._pool) as session: + transaction_type = TransactionType.PARTITIONED + with SessionCheckout(self, transaction_type) as session: add_span_event(span, "Starting BeginTransaction") - txn = api.begin_transaction( - session=session.name, options=txn_options, metadata=metadata - ) + + try: + txn = api.begin_transaction( + session=session.name, options=txn_options, metadata=metadata + ) + + # If partitioned DML is not supported with multiplexed sessions, + # disable multiplexed sessions for partitioned transactions before + # re-raising the error. 
+ except NotImplementedError as exc: + if ( + "Transaction type partitioned_dml not supported with multiplexed sessions" + in str(exc) + ): + self.session_options.disable_multiplexed( + self.logger, transaction_type + ) + raise exc txn_selector = TransactionSelector(id=txn.id) @@ -731,8 +755,9 @@ def execute_pdml(): iterator = _restart_on_unavailable( method=method, - trace_name="CloudSpanner.ExecuteStreamingSql", request=request, + session=session, + trace_name="CloudSpanner.ExecuteStreamingSql", metadata=metadata, transaction_selector=txn_selector, observability_options=self.observability_options, @@ -745,23 +770,6 @@ def execute_pdml(): return _retry_on_aborted(execute_pdml, DEFAULT_RETRY_BACKOFF)() - def session(self, labels=None, database_role=None): - """Factory to create a session for this database. - - :type labels: dict (str -> str) or None - :param labels: (Optional) user-assigned labels for the session. - - :type database_role: str - :param database_role: (Optional) user-assigned database_role for the session. - - :rtype: :class:`~google.cloud.spanner_v1.session.Session` - :returns: a session bound to this database. - """ - # If role is specified in param, then that role is used - # instead. - role = database_role or self._database_role - return Session(self, labels=labels, database_role=role) - def snapshot(self, **kw): """Return an object which wraps a snapshot. @@ -923,7 +931,7 @@ def run_in_transaction(self, func, *args, **kw): # Check out a session and run the function in a transaction; once # done, flip the sanity check bit back. try: - with SessionCheckout(self._pool) as session: + with SessionCheckout(self) as session: return session.run_in_transaction(func, *args, **kw) finally: self._local.transaction_running = False @@ -1160,6 +1168,50 @@ def observability_options(self): return opts +class SessionCheckout(object): + """Context manager for using a session from a database. 
+ + :type database: :class:`~google.cloud.spanner_v1.database.Database` + :param database: database to use the session from + """ + + _session = None # Not checked out until '__enter__'. + + def __init__( + self, + database: Database, + transaction_type: TransactionType = TransactionType.READ_WRITE, + ): + if not isinstance(database, Database): + raise TypeError( + "{class_name} must receive an instance of {expected_class_name}. Received: {actual_class_name}".format( + class_name=self.__class__.__name__, + expected_class_name=Database.__name__, + actual_class_name=database.__class__.__name__, + ) + ) + + if not isinstance(transaction_type, TransactionType): + raise TypeError( + "{class_name} must receive an instance of {expected_class_name}. Received: {actual_class_name}".format( + class_name=self.__class__.__name__, + expected_class_name=TransactionType.__name__, + actual_class_name=transaction_type.__class__.__name__, + ) + ) + + self._database = database + self._transaction_type = transaction_type + + def __enter__(self): + session_manager = self._database._session_manager + self._session = session_manager.get_session(self._transaction_type) + return self._session + + def __exit__(self, *ignored): + self._database._session_manager.put_session(self._session) + + class BatchCheckout(object): """Context manager for using a batch from a database. @@ -1194,6 +1246,15 @@ def __init__( isolation_level=TransactionOptions.IsolationLevel.ISOLATION_LEVEL_UNSPECIFIED, **kw, ): + if not isinstance(database, Database): + raise TypeError( + "{class_name} must receive an instance of {expected_class_name}. 
Received: {actual_class_name}".format( + class_name=self.__class__.__name__, + expected_class_name=Database.__name__, + actual_class_name=database.__class__.__name__, + ) + ) + self._database = database self._session = self._batch = None if request_options is None: @@ -1209,10 +1270,14 @@ def __init__( def __enter__(self): """Begin ``with`` block.""" - current_span = get_current_span() - session = self._session = self._database._pool.get() - add_span_event(current_span, "Using session", {"id": session.session_id}) - batch = self._batch = Batch(session) + + # Batch transactions are performed as blind writes, + # which are treated as read-only transactions. + self._session = self._database._session_manager.get_session( + TransactionType.READ_ONLY + ) + + batch = self._batch = Batch(self._session) if self._request_options.transaction_tag: batch.transaction_tag = self._request_options.transaction_tag return batch @@ -1235,13 +1300,7 @@ def __exit__(self, exc_type, exc_val, exc_tb): "CommitStats: {}".format(self._batch.commit_stats), extra={"commit_stats": self._batch.commit_stats}, ) - self._database._pool.put(self._session) - current_span = get_current_span() - add_span_event( - current_span, - "Returned session to pool", - {"id": self._session.session_id}, - ) + self._database._session_manager.put_session(self._session) class MutationGroupsCheckout(object): @@ -1258,23 +1317,28 @@ class MutationGroupsCheckout(object): """ def __init__(self, database): + if not isinstance(database, Database): + raise TypeError( + "{class_name} must receive an instance of {expected_class_name}. 
Received: {actual_class_name}".format( + class_name=self.__class__.__name__, + expected_class_name=Database.__name__, + actual_class_name=database.__class__.__name__, + ) + ) + self._database = database self._session = None def __enter__(self): """Begin ``with`` block.""" - session = self._session = self._database._pool.get() - return MutationGroups(session) + self._session = self._database._session_manager.get_session( + TransactionType.READ_WRITE + ) + return MutationGroups(self._session) def __exit__(self, exc_type, exc_val, exc_tb): """End ``with`` block.""" - if isinstance(exc_val, NotFound): - # If NotFound exception occurs inside the with block - # then we validate if the session still exists. - if not self._session.exists(): - self._session = self._database._pool._new_session() - self._session.create() - self._database._pool.put(self._session) + self._database._session_manager.put_session(self._session) class SnapshotCheckout(object): @@ -1296,24 +1360,29 @@ class SnapshotCheckout(object): """ def __init__(self, database, **kw): + if not isinstance(database, Database): + raise TypeError( + "{class_name} must receive an instance of {expected_class_name}. Received: {actual_class_name}".format( + class_name=self.__class__.__name__, + expected_class_name=Database.__name__, + actual_class_name=database.__class__.__name__, + ) + ) + self._database = database self._session = None self._kw = kw def __enter__(self): """Begin ``with`` block.""" - session = self._session = self._database._pool.get() - return Snapshot(session, **self._kw) + self._session = self._database._session_manager.get_session( + TransactionType.READ_ONLY + ) + return Snapshot(self._session, **self._kw) def __exit__(self, exc_type, exc_val, exc_tb): """End ``with`` block.""" - if isinstance(exc_val, NotFound): - # If NotFound exception occurs inside the with block - # then we validate if the session still exists. 
- if not self._session.exists(): - self._session = self._database._pool._new_session() - self._session.create() - self._database._pool.put(self._session) + self._database._session_manager.put_session(self._session) class BatchSnapshot(object): @@ -1358,11 +1427,15 @@ def from_dict(cls, database, mapping): :rtype: :class:`BatchSnapshot` """ + instance = cls(database) - session = instance._session = database.session() - session._session_id = mapping["session_id"] + + session = instance._session = Session(database=database) + instance._session_id = session._session_id = mapping["session_id"] + snapshot = instance._snapshot = session.snapshot() - snapshot._transaction_id = mapping["transaction_id"] + instance._transaction_id = snapshot._transaction_id = mapping["transaction_id"] + return instance def to_dict(self): @@ -1371,10 +1444,15 @@ def to_dict(self): Result can be used to serialize the instance and reconstitute it later using :meth:`from_dict`. + When called, the underlying session is cleaned up, so + the batch snapshot is no longer valid. + :rtype: dict """ + session = self._get_session() snapshot = self._get_snapshot() + return { "session_id": session._session_id, "transaction_id": snapshot._transaction_id, @@ -1392,25 +1470,48 @@ def _get_session(self): Caller is responsible for cleaning up the session after all partitions have been processed. """ + if self._session is None: - session = self._session = self._database.session() + database = self._database + + # If the session ID is not specified, check out a new session from + # the database session manager; otherwise, the session has already + # been checked out, so just create a session object to represent it. 
if self._session_id is None: - session.create() + transaction_type = TransactionType.READ_ONLY + session = database._session_manager.get_session(transaction_type) + self._session_id = session.session_id + else: + session = Session(database=database) session._session_id = self._session_id + + self._session = session + return self._session def _get_snapshot(self): """Create snapshot if needed.""" + if self._snapshot is None: - self._snapshot = self._get_session().snapshot( - read_timestamp=self._read_timestamp, - exact_staleness=self._exact_staleness, - multi_use=True, - transaction_id=self._transaction_id, - ) + snapshot_args = { + "session": self._get_session(), + "read_timestamp": self._read_timestamp, + "exact_staleness": self._exact_staleness, + "multi_use": True, + } + + # If the transaction ID is not specified, create a new snapshot + # and begin a transaction; otherwise, the transaction is already + # in progress, so just create a snapshot object to represent it. if self._transaction_id is None: - self._snapshot.begin() + self._snapshot = Snapshot(**snapshot_args) + self._transaction_id = self._snapshot.begin() + + else: + snapshot_args["transaction_id"] = self._transaction_id + self._snapshot = Snapshot(**snapshot_args) + return self._snapshot def get_batch_transaction_id(self): @@ -1807,7 +1908,7 @@ def close(self): from all the partitions. """ if self._session is not None: - self._session.delete() + self._database._session_manager.put_session(self._session) def _check_ddl_statements(value): diff --git a/google/cloud/spanner_v1/database_sessions_manager.py b/google/cloud/spanner_v1/database_sessions_manager.py new file mode 100644 index 0000000000..a9837700ef --- /dev/null +++ b/google/cloud/spanner_v1/database_sessions_manager.py @@ -0,0 +1,258 @@ +# Copyright 2025 Google LLC All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import datetime +import threading +import time +import weakref + +from google.api_core.exceptions import MethodNotImplemented + +from google.cloud.spanner_v1._opentelemetry_tracing import ( + get_current_span, + add_span_event, +) +from google.cloud.spanner_v1.session import Session +from google.cloud.spanner_v1.session_options import TransactionType + + +class DatabaseSessionsManager(object): + """Manages sessions for a Cloud Spanner database. + + Sessions can be checked out from the database session manager for a specific + transaction type using :meth:`get_session`, and returned to the session manager + using :meth:`put_session`. + + The sessions returned by the session manager depend on the client's session options (see + :class:`~google.cloud.spanner_v1.session_options.SessionOptions`) and the provided session + pool (see :class:`~google.cloud.spanner_v1.pool.AbstractSessionPool`). + + :type database: :class:`~google.cloud.spanner_v1.database.Database` + :param database: The database to manage sessions for. + + :type pool: :class:`~google.cloud.spanner_v1.pool.AbstractSessionPool` + :param pool: The pool to get non-multiplexed sessions from. + """ + + # Intervals for the maintenance thread to check and refresh the multiplexed session. 
+ _MAINTENANCE_THREAD_POLLING_INTERVAL = datetime.timedelta(minutes=10) + _MAINTENANCE_THREAD_REFRESH_INTERVAL = datetime.timedelta(days=7) + + def __init__(self, database, pool): + self._database = database + self._logger = database.logger + + # The session pool manages non-multiplexed sessions, and + # will only be used if multiplexed sessions are not enabled. + self._pool = pool + + # Declare multiplexed session attributes. When a multiplexed session for the + # database session manager is created, a maintenance thread is initialized to + # periodically delete and recreate the multiplexed session so that it remains + # valid. Because of this concurrency, we need to use a lock whenever we access + # the multiplexed session to avoid any race conditions. We also create an event + # so that the thread can terminate if the use of multiplexed session has been + # disabled for all transactions. + self._multiplexed_session = None + self._multiplexed_session_maintenance_thread = None + self._multiplexed_session_lock = threading.Lock() + self._is_multiplexed_sessions_disabled_event = threading.Event() + + def get_session(self, transaction_type: TransactionType) -> Session: + """Returns a session for the given transaction type from the database session manager. + + :rtype: :class:`~google.cloud.spanner_v1.session.Session` + :returns: a session for the given transaction type. + """ + + session_options = self._database.session_options + use_multiplexed = session_options.use_multiplexed(transaction_type) + + if use_multiplexed and transaction_type == TransactionType.READ_WRITE: + raise NotImplementedError( + f"Multiplexed sessions are not yet supported for {transaction_type} transactions." + ) + + if use_multiplexed: + try: + session = self._get_multiplexed_session() + + # If multiplexed sessions are not supported, disable + # them for all transactions and return a non-multiplexed session. 
+ except MethodNotImplemented: + self._disable_multiplexed_sessions() + session = self._pool.get() + + else: + session = self._pool.get() + + add_span_event( + get_current_span(), + "Using session", + {"id": session.session_id, "multiplexed": session.is_multiplexed}, + ) + + return session + + def put_session(self, session: Session) -> None: + """Returns the session to the database session manager. + + :type session: :class:`~google.cloud.spanner_v1.session.Session` + :param session: The session to return to the database session manager. + """ + + add_span_event( + get_current_span(), + "Returning session", + {"id": session.session_id, "multiplexed": session.is_multiplexed}, + ) + + # No action is needed for multiplexed sessions: the session + # pool is only used for managing non-multiplexed sessions, + # since they can only process one transaction at a time. + if not session.is_multiplexed: + self._pool.put(session) + + def _get_multiplexed_session(self) -> Session: + """Returns a multiplexed session from the database session manager. + + If the multiplexed session is not defined, creates a new multiplexed + session and starts a maintenance thread to periodically delete and + recreate it so that it remains valid. Otherwise, simply returns the + current multiplexed session. + + :raises MethodNotImplemented: + if multiplexed sessions are not supported. + + :rtype: :class:`~google.cloud.spanner_v1.session.Session` + :returns: a multiplexed session. + """ + + with self._multiplexed_session_lock: + if self._multiplexed_session is None: + self._multiplexed_session = self._build_multiplexed_session() + + # Build and start a thread to maintain the multiplexed session. 
+ self._multiplexed_session_maintenance_thread = ( + self._build_maintenance_thread() + ) + self._multiplexed_session_maintenance_thread.start() + + return self._multiplexed_session + + def _build_multiplexed_session(self) -> Session: + """Builds and returns a new multiplexed session for the database session manager. + + :raises MethodNotImplemented: + if multiplexed sessions are not supported. + + :rtype: :class:`~google.cloud.spanner_v1.session.Session` + :returns: a new multiplexed session. + """ + + session = Session( + database=self._database, + database_role=self._database.database_role, + is_multiplexed=True, + ) + + session.create() + + self._logger.info("Created multiplexed session.") + + return session + + def _disable_multiplexed_sessions(self) -> None: + """Disables multiplexed sessions for all transactions.""" + + self._multiplexed_session = None + self._is_multiplexed_sessions_disabled_event.set() + self._database.session_options.disable_multiplexed(self._logger) + + def _build_maintenance_thread(self) -> threading.Thread: + """Builds and returns a multiplexed session maintenance thread for + the database session manager. This thread will periodically delete + and recreate the multiplexed session to ensure that it is always valid. + + :rtype: :class:`threading.Thread` + :returns: a multiplexed session maintenance thread. + """ + + # Use a weak reference to the database session manager to avoid + # creating a circular reference that would prevent the database + # session manager from being garbage collected. + session_manager_ref = weakref.ref(self) + + return threading.Thread( + target=self._maintain_multiplexed_session, + name=f"maintenance-multiplexed-session-{self._multiplexed_session.name}", + args=[session_manager_ref], + daemon=True, + ) + + @staticmethod + def _maintain_multiplexed_session(session_manager_ref) -> None: + """Maintains the multiplexed session for the database session manager. 
+ + This method will delete and recreate the referenced database session manager's + multiplexed session to ensure that it is always valid. The method will run until + the database session manager is deleted, the multiplexed session is deleted, or + building a multiplexed session fails. + + :type session_manager_ref: :class:`_weakref.ReferenceType` + :param session_manager_ref: A weak reference to the database session manager. + """ + + session_manager = session_manager_ref() + if session_manager is None: + return + + polling_interval_seconds = ( + session_manager._MAINTENANCE_THREAD_POLLING_INTERVAL.total_seconds() + ) + refresh_interval_seconds = ( + session_manager._MAINTENANCE_THREAD_REFRESH_INTERVAL.total_seconds() + ) + + session_created_time = time.time() + + while True: + # Terminate the thread if the database session manager has been deleted. + session_manager = session_manager_ref() + if session_manager is None: + return + + # Terminate the thread if the use of multiplexed sessions has been disabled. + if session_manager._is_multiplexed_sessions_disabled_event.is_set(): + return + + # Wait until the refresh interval has elapsed. + if time.time() - session_created_time < refresh_interval_seconds: + time.sleep(polling_interval_seconds) + continue + + with session_manager._multiplexed_session_lock: + session_manager._multiplexed_session.delete() + + try: + session_manager._multiplexed_session = ( + session_manager._build_multiplexed_session() + ) + + # Disable multiplexed sessions for all transactions and terminate + # the thread if building a multiplexed session fails. 
+ except MethodNotImplemented: + session_manager._disable_multiplexed_sessions() + return + + session_created_time = time.time() diff --git a/google/cloud/spanner_v1/pool.py b/google/cloud/spanner_v1/pool.py index 0c4dd5a63b..65968b7649 100644 --- a/google/cloud/spanner_v1/pool.py +++ b/google/cloud/spanner_v1/pool.py @@ -20,7 +20,7 @@ from google.cloud.exceptions import NotFound from google.cloud.spanner_v1 import BatchCreateSessionsRequest -from google.cloud.spanner_v1 import Session +from google.cloud.spanner_v1 import Session as SessionPB from google.cloud.spanner_v1._helpers import ( _metadata_with_prefix, _metadata_with_leader_aware_routing, @@ -33,6 +33,7 @@ from warnings import warn from google.cloud.spanner_v1.metrics.metrics_capture import MetricsCapture +from google.cloud.spanner_v1.session import Session _NOW = datetime.datetime.utcnow # unit tests may replace @@ -130,22 +131,11 @@ def _new_session(self): :rtype: :class:`~google.cloud.spanner_v1.session.Session` :returns: new session instance. """ - return self._database.session( - labels=self.labels, database_role=self.database_role + database_role = self.database_role or self._database.database_role + return Session( + database=self._database, labels=self.labels, database_role=database_role ) - def session(self, **kwargs): - """Check out a session from the pool. - - :param kwargs: (optional) keyword arguments, passed through to - the returned checkout. - - :rtype: :class:`~google.cloud.spanner_v1.session.SessionCheckout` - :returns: a checkout instance, to be used as a context manager for - accessing the session and returning it to the pool. 
- """ - return SessionCheckout(self, **kwargs) - class FixedSizePool(AbstractSessionPool): """Concrete session pool implementation: @@ -237,7 +227,7 @@ def bind(self, database): request = BatchCreateSessionsRequest( database=database.name, session_count=requested_session_count, - session_template=Session(creator_role=self.database_role), + session_template=SessionPB(creator_role=self.database_role), ) observability_options = getattr(self._database, "observability_options", None) @@ -308,13 +298,12 @@ def get(self, timeout=None): age = _NOW() - session.last_use_time if age >= self._max_age and not session.exists(): - if not session.exists(): - add_span_event( - current_span, - "Session is not valid, recreating it", - span_event_attributes, - ) - session = self._database.session() + add_span_event( + current_span, + "Session is not valid, recreating it", + span_event_attributes, + ) + session = self._new_session() session.create() # Replacing with the updated session.id. span_event_attributes["session.id"] = session._session_id @@ -531,7 +520,7 @@ def bind(self, database): request = BatchCreateSessionsRequest( database=database.name, session_count=self.size, - session_template=Session(creator_role=self.database_role), + session_template=SessionPB(creator_role=self.database_role), ) span_event_attributes = {"kind": type(self).__name__} @@ -776,27 +765,3 @@ def begin_pending_transactions(self): while not self._pending_sessions.empty(): session = self._pending_sessions.get() super(TransactionPingingPool, self).put(session) - - -class SessionCheckout(object): - """Context manager: hold session checked out from a pool. - - :type pool: concrete subclass of - :class:`~google.cloud.spanner_v1.pool.AbstractSessionPool` - :param pool: Pool from which to check out a session. - - :param kwargs: extra keyword arguments to be passed to :meth:`pool.get`. - """ - - _session = None # Not checked out until '__enter__'. 
- - def __init__(self, pool, **kwargs): - self._pool = pool - self._kwargs = kwargs.copy() - - def __enter__(self): - self._session = self._pool.get(**self._kwargs) - return self._session - - def __exit__(self, *ignored): - self._pool.put(self._session) diff --git a/google/cloud/spanner_v1/session.py b/google/cloud/spanner_v1/session.py index f18ba57582..644b00ba9c 100644 --- a/google/cloud/spanner_v1/session.py +++ b/google/cloud/spanner_v1/session.py @@ -41,6 +41,7 @@ from google.cloud.spanner_v1.transaction import Transaction from google.cloud.spanner_v1.metrics.metrics_capture import MetricsCapture +from google.cloud.spanner_v1.types import Session as SessionPB DEFAULT_RETRY_TIMEOUT_SECS = 30 """Default timeout used by :meth:`Session.run_in_transaction`.""" @@ -69,12 +70,13 @@ class Session(object): _session_id = None _transaction = None - def __init__(self, database, labels=None, database_role=None): + def __init__(self, database, labels=None, database_role=None, is_multiplexed=False): self._database = database if labels is None: labels = {} self._labels = labels self._database_role = database_role + self._is_multiplexed = is_multiplexed self._last_use_time = datetime.utcnow() def __lt__(self, other): @@ -87,7 +89,7 @@ def session_id(self): @property def last_use_time(self): - """ "Approximate last use time of this session + """Approximate last use time of this session :rtype: datetime :returns: the approximate last use time of this session""" @@ -110,6 +112,15 @@ def labels(self): """ return self._labels + @property + def is_multiplexed(self): + """Whether this session is multiplexed. + + :rtype: bool + :returns: True if this session is multiplexed, False otherwise. + """ + return self._is_multiplexed + @property def name(self): """Session name used in requests. 
@@ -153,12 +164,14 @@ def create(self): ) ) - request = CreateSessionRequest(database=self._database.name) - if self._database.database_role is not None: - request.session.creator_role = self._database.database_role + session_pb = SessionPB(multiplexed=self.is_multiplexed) + if self._database.database_role: + session_pb.creator_role = self._database.database_role if self._labels: - request.session.labels = self._labels + session_pb.labels = self._labels + + request = CreateSessionRequest(database=self._database.name, session=session_pb) observability_options = getattr(self._database, "observability_options", None) with trace_call( @@ -408,6 +421,11 @@ def batch(self): if self._session_id is None: raise ValueError("Session has not been created.") + if self.is_multiplexed: + raise NotImplementedError( + "Multiplexed sessions do not yet support read/write transactions." + ) + return Batch(self) def transaction(self): @@ -420,6 +438,11 @@ def transaction(self): if self._session_id is None: raise ValueError("Session has not been created.") + if self.is_multiplexed: + raise NotImplementedError( + "Multiplexed sessions do not yet support read/write transactions." + ) + if self._transaction is not None: self._transaction.rolled_back = True del self._transaction diff --git a/google/cloud/spanner_v1/session_options.py b/google/cloud/spanner_v1/session_options.py new file mode 100644 index 0000000000..eab16dc6de --- /dev/null +++ b/google/cloud/spanner_v1/session_options.py @@ -0,0 +1,153 @@ +# Copyright 2025 Google LLC All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import os
+from enum import Enum
+from logging import Logger
+
+
+class TransactionType(Enum):
+    """Transaction types for session options."""
+
+    READ_ONLY = "read-only"
+    PARTITIONED = "partitioned"
+    READ_WRITE = "read/write"
+
+
+class SessionOptions(object):
+    """Represents the session options for the Cloud Spanner Python client.
+
+    We can use :class:`SessionOptions` to determine whether multiplexed sessions
+    should be used for a specific transaction type with :meth:`use_multiplexed`. The use
+    of multiplexed sessions can be disabled for a specific transaction type or for all
+    transaction types with :meth:`disable_multiplexed`.
+    """
+
+    # Environment variables for multiplexed sessions
+    ENV_VAR_ENABLE_MULTIPLEXED = "GOOGLE_CLOUD_SPANNER_MULTIPLEXED_SESSIONS"
+    ENV_VAR_ENABLE_MULTIPLEXED_FOR_PARTITIONED = (
+        "GOOGLE_CLOUD_SPANNER_MULTIPLEXED_SESSIONS_PARTITIONED_OPS"
+    )
+    ENV_VAR_ENABLE_MULTIPLEXED_FOR_READ_WRITE = (
+        "GOOGLE_CLOUD_SPANNER_MULTIPLEXED_SESSIONS_FOR_RW"
+    )
+    ENV_VAR_FORCE_DISABLE_MULTIPLEXED = (
+        "GOOGLE_CLOUD_SPANNER_FORCE_DISABLE_MULTIPLEXED_SESSIONS"
+    )
+
+    def __init__(self):
+        # Internal overrides to disable the use of multiplexed
+        # sessions in case of runtime errors.
+        self._is_multiplexed_enabled = {
+            TransactionType.READ_ONLY: True,
+            TransactionType.PARTITIONED: True,
+            TransactionType.READ_WRITE: True,
+        }
+
+    def use_multiplexed(self, transaction_type: TransactionType) -> bool:
+        """Returns whether to use multiplexed sessions for the given transaction type.
+
+        Multiplexed sessions are enabled for read-only transactions if:
+            * ENV_VAR_ENABLE_MULTIPLEXED is set to true;
+            * ENV_VAR_FORCE_DISABLE_MULTIPLEXED is not set to true; and
+            * multiplexed sessions have not been disabled for read-only transactions.
+
+        Multiplexed sessions are enabled for partitioned transactions if:
+            * ENV_VAR_ENABLE_MULTIPLEXED is set to true;
+            * ENV_VAR_ENABLE_MULTIPLEXED_FOR_PARTITIONED is set to true;
+            * ENV_VAR_FORCE_DISABLE_MULTIPLEXED is not set to true; and
+            * multiplexed sessions have not been disabled for partitioned transactions.
+
+        Multiplexed sessions are enabled for read/write transactions if:
+            * ENV_VAR_ENABLE_MULTIPLEXED is set to true;
+            * ENV_VAR_ENABLE_MULTIPLEXED_FOR_READ_WRITE is set to true;
+            * ENV_VAR_FORCE_DISABLE_MULTIPLEXED is not set to true; and
+            * multiplexed sessions have not been disabled for read/write transactions.
+
+        :type transaction_type: :class:`TransactionType`
+        :param transaction_type: the type of transaction to check whether
+            multiplexed sessions should be used.
+ """ + + if transaction_type is TransactionType.READ_ONLY: + return ( + self._is_multiplexed_enabled[transaction_type] + and self._getenv(self.ENV_VAR_ENABLE_MULTIPLEXED) + and not self._getenv(self.ENV_VAR_FORCE_DISABLE_MULTIPLEXED) + ) + + elif transaction_type is TransactionType.PARTITIONED: + return ( + self._is_multiplexed_enabled[transaction_type] + and self._getenv(self.ENV_VAR_ENABLE_MULTIPLEXED) + and self._getenv(self.ENV_VAR_ENABLE_MULTIPLEXED_FOR_PARTITIONED) + and not self._getenv(self.ENV_VAR_FORCE_DISABLE_MULTIPLEXED) + ) + + elif transaction_type is TransactionType.READ_WRITE: + return ( + self._is_multiplexed_enabled[transaction_type] + and self._getenv(self.ENV_VAR_ENABLE_MULTIPLEXED) + and self._getenv(self.ENV_VAR_ENABLE_MULTIPLEXED_FOR_READ_WRITE) + and not self._getenv(self.ENV_VAR_FORCE_DISABLE_MULTIPLEXED) + ) + + raise ValueError(f"Transaction type {transaction_type} is not supported.") + + def disable_multiplexed( + self, logger: Logger = None, transaction_type: TransactionType = None + ) -> None: + """Disables the use of multiplexed sessions for the given transaction type. + If no transaction type is specified, disables the use of multiplexed sessions + for all transaction types. + + :type logger: :class:`Logger` + :param logger: logger to use for logging the disabling the use of multiplexed + sessions. + + :type transaction_type: :class:`TransactionType` + :param transaction_type: (Optional) the type of transaction for which to disable + the use of multiplexed sessions. 
+ """ + + disable_multiplexed_log_msg_fstring = ( + "Disabling multiplexed sessions for {transaction_type_value} transactions" + ) + + if transaction_type is None: + logger.warning( + disable_multiplexed_log_msg_fstring.format(transaction_type_value="all") + ) + for transaction_type in TransactionType: + self._is_multiplexed_enabled[transaction_type] = False + return + + elif transaction_type in self._is_multiplexed_enabled.keys(): + logger.warning( + disable_multiplexed_log_msg_fstring.format( + transaction_type_value=transaction_type.value + ) + ) + self._is_multiplexed_enabled[transaction_type] = False + return + + raise ValueError(f"Transaction type '{transaction_type}' is not supported.") + + @staticmethod + def _getenv(name: str) -> bool: + """Returns the value of the given environment variable as a boolean. + True values are '1' and 'true' (case-insensitive); all other values are + considered false. + """ + env_var = os.getenv(name, "").lower().strip() + return env_var in ["1", "true"] diff --git a/google/cloud/spanner_v1/snapshot.py b/google/cloud/spanner_v1/snapshot.py index 3b18d2c855..22bbe0e103 100644 --- a/google/cloud/spanner_v1/snapshot.py +++ b/google/cloud/spanner_v1/snapshot.py @@ -14,7 +14,6 @@ """Model a set of read-only queries to a database as a snapshot.""" -from datetime import datetime import functools import threading from google.protobuf.struct_pb2 import Struct @@ -40,6 +39,7 @@ _SessionWrapper, ) from google.cloud.spanner_v1._opentelemetry_tracing import trace_call +from google.cloud.spanner_v1.session_options import TransactionType from google.cloud.spanner_v1.streamed import StreamedResultSet from google.cloud.spanner_v1 import RequestOptions @@ -54,9 +54,9 @@ def _restart_on_unavailable( method, request, + session, metadata=None, trace_name=None, - session=None, attributes=None, transaction=None, transaction_selector=None, @@ -70,6 +70,9 @@ def _restart_on_unavailable( :type request: proto :param request: request proto to call the 
method with + :type session: :class:`~google.cloud.spanner_v1.session.Session` + :param session: the session used to perform the operation + :type transaction: :class:`google.cloud.spanner_v1.snapshot._SnapshotBase` :param transaction: Snapshot or Transaction class object based on the type of transaction @@ -153,6 +156,9 @@ def _restart_on_unavailable( iterator = method(request=request) continue + except NotImplementedError as exc: + _handle_not_implemented_error(exc, session._database) + if len(item_buffer) == 0: break @@ -162,6 +168,30 @@ def _restart_on_unavailable( del item_buffer[:] +def _handle_not_implemented_error(exception, database) -> None: + """Handles NotImplementedError for the database. If the error is due to unsupported + partitioned operations with multiplexed sessions, disables multiplexed sessions for + read-only transactions and re-raises the error. Otherwise, re-raises the error + with no side effects. + + :type exception: :class:`NotImplementedError` + :param exception: The NotImplementedError to handle. + + :type database: :class:`~google.cloud.spanner_v1.database.Database` + :param database: The database instance associated with the error. + + :raises NotImplementedError: The original exception + """ + + if "Partitioned operations are not supported with multiplexed sessions" in str( + exception + ): + session_options = database.session_options + session_options.disable_multiplexed(database.logger, TransactionType.READ_ONLY) + + raise exception + + class _SnapshotBase(_SessionWrapper): """Base class for Snapshot. 
@@ -329,7 +359,8 @@ def read( data_boost_enabled=data_boost_enabled, directed_read_options=directed_read_options, ) - restart = functools.partial( + + streaming_read = functools.partial( api.streaming_read, request=request, metadata=metadata, @@ -340,54 +371,23 @@ def read( trace_attributes = {"table_id": table, "columns": columns} observability_options = getattr(database, "observability_options", None) + get_streamed_result_set_args = { + "method": streaming_read, + "request": request, + "metadata": metadata, + "trace_attributes": trace_attributes, + "column_info": column_info, + "observability_options": observability_options, + "lazy_decode": lazy_decode, + } + if self._transaction_id is None: # lock is added to handle the inline begin for first rpc with self._lock: - iterator = _restart_on_unavailable( - restart, - request, - metadata, - f"CloudSpanner.{type(self).__name__}.read", - self._session, - trace_attributes, - transaction=self, - observability_options=observability_options, - ) - self._read_request_count += 1 - if self._multi_use: - return StreamedResultSet( - iterator, - source=self, - column_info=column_info, - lazy_decode=lazy_decode, - ) - else: - return StreamedResultSet( - iterator, column_info=column_info, lazy_decode=lazy_decode - ) - else: - iterator = _restart_on_unavailable( - restart, - request, - metadata, - f"CloudSpanner.{type(self).__name__}.read", - self._session, - trace_attributes, - transaction=self, - observability_options=observability_options, - ) + return self._get_streamed_result_set(**get_streamed_result_set_args) - self._read_request_count += 1 - self._session._last_use_time = datetime.now() - - if self._multi_use: - return StreamedResultSet( - iterator, source=self, column_info=column_info, lazy_decode=lazy_decode - ) else: - return StreamedResultSet( - iterator, column_info=column_info, lazy_decode=lazy_decode - ) + return self._get_streamed_result_set(**get_streamed_result_set_args) def execute_sql( self, @@ -562,7 +562,7 
@@ def execute_sql( data_boost_enabled=data_boost_enabled, directed_read_options=directed_read_options, ) - restart = functools.partial( + execute_streaming_sql_method = functools.partial( api.execute_streaming_sql, request=request, metadata=metadata, @@ -573,60 +573,22 @@ def execute_sql( trace_attributes = {"db.statement": sql} observability_options = getattr(database, "observability_options", None) + get_streamed_result_set_args = { + "method": execute_streaming_sql_method, + "request": request, + "metadata": metadata, + "trace_attributes": trace_attributes, + "column_info": column_info, + "observability_options": observability_options, + "lazy_decode": lazy_decode, + } + if self._transaction_id is None: # lock is added to handle the inline begin for first rpc with self._lock: - return self._get_streamed_result_set( - restart, - request, - metadata, - trace_attributes, - column_info, - observability_options, - lazy_decode=lazy_decode, - ) - else: - return self._get_streamed_result_set( - restart, - request, - metadata, - trace_attributes, - column_info, - observability_options, - lazy_decode=lazy_decode, - ) - - def _get_streamed_result_set( - self, - restart, - request, - metadata, - trace_attributes, - column_info, - observability_options=None, - lazy_decode=False, - ): - iterator = _restart_on_unavailable( - restart, - request, - metadata, - f"CloudSpanner.{type(self).__name__}.execute_sql", - self._session, - trace_attributes, - transaction=self, - observability_options=observability_options, - ) - self._read_request_count += 1 - self._execute_sql_count += 1 - - if self._multi_use: - return StreamedResultSet( - iterator, source=self, column_info=column_info, lazy_decode=lazy_decode - ) + return self._get_streamed_result_set(**get_streamed_result_set_args) else: - return StreamedResultSet( - iterator, column_info=column_info, lazy_decode=lazy_decode - ) + return self._get_streamed_result_set(**get_streamed_result_set_args) def partition_read( self, @@ -725,10 
+687,15 @@ def partition_read( retry=retry, timeout=timeout, ) - response = _retry( - method, - allowed_exceptions={InternalServerError: _check_rst_stream_error}, - ) + + try: + response = _retry( + method, + allowed_exceptions={InternalServerError: _check_rst_stream_error}, + ) + + except NotImplementedError as exc: + _handle_not_implemented_error(exc, database) return [partition.partition_token for partition in response.partitions] @@ -829,13 +796,62 @@ def partition_query( retry=retry, timeout=timeout, ) - response = _retry( - method, - allowed_exceptions={InternalServerError: _check_rst_stream_error}, - ) + + try: + response = _retry( + method, + allowed_exceptions={InternalServerError: _check_rst_stream_error}, + ) + + except NotImplementedError as exc: + _handle_not_implemented_error(exc, database) return [partition.partition_token for partition in response.partitions] + def _get_streamed_result_set( + self, + method, + request, + metadata, + trace_attributes, + column_info, + observability_options, + lazy_decode, + ): + """Returns the streamed result set for a read or execute SQL request with the given arguments.""" + + is_execute_sql_request = isinstance(request, ExecuteSqlRequest) + + trace_request_name = "execute_sql" if is_execute_sql_request else "read" + trace_name = f"CloudSpanner.{type(self).__name__}.{trace_request_name}" + + iterator = _restart_on_unavailable( + method=method, + request=request, + session=self._session, + metadata=metadata, + trace_name=trace_name, + attributes=trace_attributes, + transaction=self, + observability_options=observability_options, + ) + + self._read_request_count += 1 + + if is_execute_sql_request: + self._execute_sql_count += 1 + + streamed_result_set_args = { + "response_iterator": iterator, + "column_info": column_info, + "lazy_decode": lazy_decode, + } + + if self._multi_use: + streamed_result_set_args["source"] = self + + return StreamedResultSet(**streamed_result_set_args) + class Snapshot(_SnapshotBase): 
"""Allow a set of reads / SQL statements with shared staleness. diff --git a/noxfile.py b/noxfile.py index cb683afd7e..e76461514b 100644 --- a/noxfile.py +++ b/noxfile.py @@ -34,7 +34,6 @@ DEFAULT_PYTHON_VERSION = "3.8" -DEFAULT_MOCK_SERVER_TESTS_PYTHON_VERSION = "3.12" UNIT_TEST_PYTHON_VERSIONS: List[str] = [ "3.7", "3.8", @@ -169,30 +168,6 @@ def install_unittest_dependencies(session, *constraints): else: session.install("-e", ".", *constraints) - # XXX Work around Kokoro image's older pip, which borks the OT install. - session.run("pip", "install", "--upgrade", "pip") - constraints_path = str( - CURRENT_DIRECTORY / "testing" / f"constraints-{session.python}.txt" - ) - session.install("-e", ".[tracing]", "-c", constraints_path) - # XXX: Dump installed versions to debug OT issue - session.run("pip", "list") - - # Run py.test against the unit tests with OpenTelemetry. - session.run( - "py.test", - "--quiet", - "--cov=google.cloud.spanner", - "--cov=google.cloud", - "--cov=tests.unit", - "--cov-append", - "--cov-config=.coveragerc", - "--cov-report=", - "--cov-fail-under=0", - os.path.join("tests", "unit"), - *session.posargs, - ) - @nox.session(python=UNIT_TEST_PYTHON_VERSIONS) @nox.parametrize( @@ -235,34 +210,6 @@ def unit(session, protobuf_implementation): ) -@nox.session(python=DEFAULT_MOCK_SERVER_TESTS_PYTHON_VERSION) -def mockserver(session): - # Install all test dependencies, then install this package in-place. - - constraints_path = str( - CURRENT_DIRECTORY / "testing" / f"constraints-{session.python}.txt" - ) - # install_unittest_dependencies(session, "-c", constraints_path) - standard_deps = UNIT_TEST_STANDARD_DEPENDENCIES + UNIT_TEST_DEPENDENCIES - session.install(*standard_deps, "-c", constraints_path) - session.install("-e", ".", "-c", constraints_path) - - # Run py.test against the mockserver tests. 
- session.run( - "py.test", - "--quiet", - f"--junitxml=unit_{session.python}_sponge_log.xml", - "--cov=google", - "--cov=tests/unit", - "--cov-append", - "--cov-config=.coveragerc", - "--cov-report=", - "--cov-fail-under=0", - os.path.join("tests", "mockserver_tests"), - *session.posargs, - ) - - def install_systemtest_dependencies(session, *constraints): # Use pre-release gRPC for system tests. # Exclude version 1.52.0rc1 which has a known issue. @@ -294,18 +241,7 @@ def install_systemtest_dependencies(session, *constraints): @nox.session(python=SYSTEM_TEST_PYTHON_VERSIONS) -@nox.parametrize( - "protobuf_implementation,database_dialect", - [ - ("python", "GOOGLE_STANDARD_SQL"), - ("python", "POSTGRESQL"), - ("upb", "GOOGLE_STANDARD_SQL"), - ("upb", "POSTGRESQL"), - ("cpp", "GOOGLE_STANDARD_SQL"), - ("cpp", "POSTGRESQL"), - ], -) -def system(session, protobuf_implementation, database_dialect): +def system(session): """Run the system test suite.""" constraints_path = str( CURRENT_DIRECTORY / "testing" / f"constraints-{session.python}.txt" @@ -316,17 +252,6 @@ def system(session, protobuf_implementation, database_dialect): # Check the value of `RUN_SYSTEM_TESTS` env var. It defaults to true. if os.environ.get("RUN_SYSTEM_TESTS", "true") == "false": session.skip("RUN_SYSTEM_TESTS is set to false, skipping") - # Sanity check: Only run tests if the environment variable is set. - if not os.environ.get("GOOGLE_APPLICATION_CREDENTIALS", "") and not os.environ.get( - "SPANNER_EMULATOR_HOST", "" - ): - session.skip( - "Credentials or emulator host must be set via environment variable" - ) - # If POSTGRESQL tests and Emulator, skip the tests - if os.environ.get("SPANNER_EMULATOR_HOST") and database_dialect == "POSTGRESQL": - session.skip("Postgresql is not supported by Emulator yet.") - # Install pyopenssl for mTLS testing. 
if os.environ.get("GOOGLE_API_USE_CLIENT_CERTIFICATE", "false") == "true": session.install("pyopenssl") @@ -339,12 +264,6 @@ def system(session, protobuf_implementation, database_dialect): install_systemtest_dependencies(session, "-c", constraints_path) - # TODO(https://github.com/googleapis/synthtool/issues/1976): - # Remove the 'cpp' implementation once support for Protobuf 3.x is dropped. - # The 'cpp' implementation requires Protobuf<4. - if protobuf_implementation == "cpp": - session.install("protobuf<4") - # Run py.test against the system tests. if system_test_exists: session.run( @@ -353,11 +272,6 @@ def system(session, protobuf_implementation, database_dialect): f"--junitxml=system_{session.python}_sponge_log.xml", system_test_path, *session.posargs, - env={ - "PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION": protobuf_implementation, - "SPANNER_DATABASE_DIALECT": database_dialect, - "SKIP_BACKUP_TESTS": "true", - }, ) if system_test_folder_exists: session.run( @@ -366,11 +280,6 @@ def system(session, protobuf_implementation, database_dialect): f"--junitxml=system_{session.python}_sponge_log.xml", system_test_folder_path, *session.posargs, - env={ - "PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION": protobuf_implementation, - "SPANNER_DATABASE_DIALECT": database_dialect, - "SKIP_BACKUP_TESTS": "true", - }, ) @@ -391,7 +300,7 @@ def cover(session): def docs(session): """Build the docs for this library.""" - session.install("-e", ".[tracing]") + session.install("-e", ".") session.install( # We need to pin to specific versions of the `sphinxcontrib-*` packages # which still support sphinx 4.x. @@ -426,7 +335,7 @@ def docs(session): def docfx(session): """Build the docfx yaml files for this library.""" - session.install("-e", ".[tracing]") + session.install("-e", ".") session.install( # We need to pin to specific versions of the `sphinxcontrib-*` packages # which still support sphinx 4.x. 
@@ -470,17 +379,10 @@ def docfx(session): @nox.session(python="3.13") @nox.parametrize( - "protobuf_implementation,database_dialect", - [ - ("python", "GOOGLE_STANDARD_SQL"), - ("python", "POSTGRESQL"), - ("upb", "GOOGLE_STANDARD_SQL"), - ("upb", "POSTGRESQL"), - ("cpp", "GOOGLE_STANDARD_SQL"), - ("cpp", "POSTGRESQL"), - ], + "protobuf_implementation", + ["python", "upb", "cpp"], ) -def prerelease_deps(session, protobuf_implementation, database_dialect): +def prerelease_deps(session, protobuf_implementation): """Run all tests with prerelease versions of dependencies installed.""" if protobuf_implementation == "cpp" and session.python in ("3.11", "3.12", "3.13"): @@ -553,8 +455,6 @@ def prerelease_deps(session, protobuf_implementation, database_dialect): "tests/unit", env={ "PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION": protobuf_implementation, - "SPANNER_DATABASE_DIALECT": database_dialect, - "SKIP_BACKUP_TESTS": "true", }, ) @@ -571,8 +471,6 @@ def prerelease_deps(session, protobuf_implementation, database_dialect): *session.posargs, env={ "PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION": protobuf_implementation, - "SPANNER_DATABASE_DIALECT": database_dialect, - "SKIP_BACKUP_TESTS": "true", }, ) if os.path.exists(system_test_folder_path): @@ -584,7 +482,5 @@ def prerelease_deps(session, protobuf_implementation, database_dialect): *session.posargs, env={ "PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION": protobuf_implementation, - "SPANNER_DATABASE_DIALECT": database_dialect, - "SKIP_BACKUP_TESTS": "true", }, ) diff --git a/tests/_builders.py b/tests/_builders.py new file mode 100644 index 0000000000..c07d003c19 --- /dev/null +++ b/tests/_builders.py @@ -0,0 +1,71 @@ +# Copyright 2025 Google LLC All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Default identifiers. +PROJECT_ID = "project-id" +INSTANCE_ID = "instance-id" +DATABASE_ID = "database-id" +SESSION_ID = "session-id" +TRANSACTION_ID = b"transaction-id" + +# Default names. +INSTANCE_NAME = "projects/" + PROJECT_ID + "/instances/" + INSTANCE_ID +DATABASE_NAME = INSTANCE_NAME + "/databases/" + DATABASE_ID +SESSION_NAME = DATABASE_NAME + "/sessions/" + SESSION_ID + + +def build_database(**database_kwargs): + """Builds and returns a database for testing using the given arguments. + If a required argument is not provided, a default value will be used.""" + from google.cloud.spanner_v1 import Session as SessionProto + from google.cloud.spanner_v1 import SpannerClient + from google.cloud.spanner_v1 import Transaction as TransactionProto + from google.cloud.spanner_v1.database import Database + from mock.mock import create_autospec + + instance = _build_instance() + database = Database(database_id=DATABASE_ID, instance=instance) + + # Mock API calls. Callers can override this to test specific behaviours. + api = database._spanner_api = create_autospec(SpannerClient, instance=True) + api.create_session.return_value = SessionProto(name=SESSION_NAME) + api.begin_transaction.return_value = TransactionProto(id=TRANSACTION_ID) + + return database + + +def build_session(**session_kwargs): + """Builds and returns a session for testing using the given arguments. 
+ If a required argument is not provided, a default value will be used.""" + from google.cloud.spanner_v1.session import Session + + if not session_kwargs.get("database"): + session_kwargs["database"] = build_database() + + return Session(**session_kwargs) + + +def _build_client(): + """Builds and returns a client for testing.""" + from google.cloud.spanner_v1 import Client + + return Client(project=PROJECT_ID) + + +def _build_instance(**instance_kwargs): + """Builds and returns an instance for testing.""" + from google.cloud.spanner_v1.instance import Instance + + client = _build_client() + return Instance(instance_id=INSTANCE_ID, client=client) diff --git a/tests/_helpers.py b/tests/_helpers.py index 667f9f8be1..f004e44494 100644 --- a/tests/_helpers.py +++ b/tests/_helpers.py @@ -1,7 +1,9 @@ +import os import unittest import mock from google.cloud.spanner_v1 import gapic_version +from google.cloud.spanner_v1.session_options import SessionOptions LIB_VERSION = gapic_version.__version__ @@ -149,3 +151,20 @@ def finished_spans_events_statuses(self): got_all_events.append((event.name, evt_attributes)) return got_all_events + + +def enable_multiplexed_sessions() -> None: + """Sets environment variables to enable multiplexed sessions for all transaction types. + The caller is responsible for resetting the environment variables after use.""" + + os.environ[SessionOptions.ENV_VAR_ENABLE_MULTIPLEXED] = "true" + os.environ[SessionOptions.ENV_VAR_ENABLE_MULTIPLEXED_FOR_PARTITIONED] = "true" + os.environ[SessionOptions.ENV_VAR_ENABLE_MULTIPLEXED_FOR_READ_WRITE] = "true" + os.environ[SessionOptions.ENV_VAR_FORCE_DISABLE_MULTIPLEXED] = "false" + + +def disable_multiplexed_sessions() -> None: + """Sets environment variables to disable multiplexed sessions for all transactions types. 
+ The caller is responsible for resetting the environment variables after use.""" + + os.environ[SessionOptions.ENV_VAR_FORCE_DISABLE_MULTIPLEXED] = "true" diff --git a/tests/system/test_database_api.py b/tests/system/test_database_api.py index 57ce49c8a2..4e44fb0b0c 100644 --- a/tests/system/test_database_api.py +++ b/tests/system/test_database_api.py @@ -121,7 +121,7 @@ def test_database_binding_of_fixed_size_pool( database_role="parent", ) database = shared_instance.database(temp_db_id, pool=pool) - assert database._pool.database_role == "parent" + assert database._session_manager._pool.database_role == "parent" def test_database_binding_of_pinging_pool( @@ -155,7 +155,7 @@ def test_database_binding_of_pinging_pool( database_role="parent", ) database = shared_instance.database(temp_db_id, pool=pool) - assert database._pool.database_role == "parent" + assert database._session_manager._pool.database_role == "parent" def test_create_database_pitr_invalid_retention_period( diff --git a/tests/system/test_observability_options.py b/tests/system/test_observability_options.py index d40b34f800..c247a6d6a9 100644 --- a/tests/system/test_observability_options.py +++ b/tests/system/test_observability_options.py @@ -11,9 +11,11 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from unittest.mock import patch, PropertyMock import pytest +from google.cloud.spanner_v1.session import Session from . 
import _helpers from google.cloud.spanner_v1 import Client from google.api_core.exceptions import Aborted @@ -195,9 +197,15 @@ def create_db_trace_exporter(): not HAS_OTEL_INSTALLED, reason="Tracing requires OpenTelemetry", ) -def test_transaction_abort_then_retry_spans(): +@patch.object(Session, "session_id", new_callable=PropertyMock) +@patch.object(Session, "is_multiplexed", new_callable=PropertyMock) +def test_transaction_abort_then_retry_spans(mock_session_multiplexed, mock_session_id): from opentelemetry.trace.status import StatusCode + # Mock session properties for testing. + mock_session_multiplexed.return_value = session_multiplexed = False + mock_session_id.return_value = session_id = "session-id" + db, trace_exporter = create_db_trace_exporter() counters = dict(aborted=0) @@ -224,6 +232,8 @@ def select_in_txn(txn): ("Waiting for a session to become available", {"kind": "BurstyPool"}), ("No sessions available in pool. Creating session", {"kind": "BurstyPool"}), ("Creating Session", {}), + ("Using session", {"id": session_id, "multiplexed": session_multiplexed}), + ("Returning session", {"id": session_id, "multiplexed": session_multiplexed}), ( "Transaction was aborted in user operation, retrying", {"delay_seconds": "EPHEMERAL", "cause": "EPHEMERAL", "attempt": 1}, @@ -391,9 +401,15 @@ def tx_update(txn): not HAS_OTEL_INSTALLED, reason="Tracing requires OpenTelemetry", ) -def test_database_partitioned_error(): +@patch.object(Session, "session_id", new_callable=PropertyMock) +@patch.object(Session, "is_multiplexed", new_callable=PropertyMock) +def test_database_partitioned_error(mock_session_multiplexed, mock_session_id): from opentelemetry.trace.status import StatusCode + # Mock session properties for testing. 
+ mock_session_multiplexed.return_value = session_multiplexed = False + mock_session_id.return_value = session_id = "session-id" + db, trace_exporter = create_db_trace_exporter() try: @@ -408,7 +424,9 @@ def test_database_partitioned_error(): ("Waiting for a session to become available", {"kind": "BurstyPool"}), ("No sessions available in pool. Creating session", {"kind": "BurstyPool"}), ("Creating Session", {}), + ("Using session", {"id": session_id, "multiplexed": session_multiplexed}), ("Starting BeginTransaction", {}), + ("Returning session", {"id": session_id, "multiplexed": session_multiplexed}), ( "exception", { diff --git a/tests/system/test_session_api.py b/tests/system/test_session_api.py index 73b55b035d..f72c4d12ee 100644 --- a/tests/system/test_session_api.py +++ b/tests/system/test_session_api.py @@ -33,7 +33,7 @@ from tests import _helpers as ot_helpers from . import _helpers from . import _sample_data - +from .._builders import build_session SOME_DATE = datetime.date(2011, 1, 17) SOME_TIME = datetime.datetime(1989, 1, 17, 17, 59, 12, 345612) @@ -415,7 +415,7 @@ def handle_abort(self, database): def test_session_crud(sessions_database): - session = sessions_database.session() + session = build_session(database=sessions_database) assert not session.exists() session.create() @@ -587,7 +587,7 @@ def test_transaction_read_and_insert_then_rollback( sd = _sample_data db_name = sessions_database.name - session = sessions_database.session() + session = build_session(database=sessions_database) session.create() sessions_to_delete.append(session) @@ -722,7 +722,7 @@ def test_transaction_read_and_insert_or_update_then_commit( # [START spanner_test_dml_read_your_writes] sd = _sample_data - session = sessions_database.session() + session = build_session(database=sessions_database) session.create() sessions_to_delete.append(session) @@ -780,7 +780,7 @@ def test_transaction_execute_sql_w_dml_read_rollback( # [START spanner_test_dml_rollback_txn_not_committed] sd = 
_sample_data - session = sessions_database.session() + session = build_session(database=sessions_database) session.create() sessions_to_delete.append(session) @@ -814,7 +814,7 @@ def test_transaction_execute_update_read_commit(sessions_database, sessions_to_d # [START spanner_test_dml_read_your_writes] sd = _sample_data - session = sessions_database.session() + session = build_session(database=sessions_database) session.create() sessions_to_delete.append(session) @@ -846,7 +846,7 @@ def test_transaction_execute_update_then_insert_commit( # [START spanner_test_dml_update] sd = _sample_data - session = sessions_database.session() + session = build_session(database=sessions_database) session.create() sessions_to_delete.append(session) @@ -879,7 +879,7 @@ def test_transaction_execute_sql_dml_returning( ): sd = _sample_data - session = sessions_database.session() + session = build_session(database=sessions_database) session.create() sessions_to_delete.append(session) @@ -910,7 +910,7 @@ def test_transaction_execute_update_dml_returning( ): sd = _sample_data - session = sessions_database.session() + session = build_session(database=sessions_database) session.create() sessions_to_delete.append(session) @@ -938,7 +938,7 @@ def test_transaction_batch_update_dml_returning( ): sd = _sample_data - session = sessions_database.session() + session = build_session(database=sessions_database) session.create() sessions_to_delete.append(session) @@ -970,7 +970,7 @@ def test_transaction_batch_update_success( sd = _sample_data param_types = spanner_v1.param_types - session = sessions_database.session() + session = build_session(database=sessions_database) session.create() sessions_to_delete.append(session) @@ -1027,7 +1027,7 @@ def test_transaction_batch_update_and_execute_dml( sd = _sample_data param_types = spanner_v1.param_types - session = sessions_database.session() + session = build_session(database=sessions_database) session.create() sessions_to_delete.append(session) @@ -1087,7 
+1087,7 @@ def test_transaction_batch_update_w_syntax_error( sd = _sample_data param_types = spanner_v1.param_types - session = sessions_database.session() + session = build_session(database=sessions_database) session.create() sessions_to_delete.append(session) @@ -1132,7 +1132,7 @@ def unit_of_work(transaction): def test_transaction_batch_update_wo_statements(sessions_database, sessions_to_delete): - session = sessions_database.session() + session = build_session(database=sessions_database) session.create() sessions_to_delete.append(session) @@ -1155,7 +1155,7 @@ def test_transaction_batch_update_w_parent_span( param_types = spanner_v1.param_types tracer = trace.get_tracer(__name__) - session = sessions_database.session() + session = build_session(database=sessions_database) session.create() sessions_to_delete.append(session) diff --git a/tests/unit/spanner_dbapi/test_connect.py b/tests/unit/spanner_dbapi/test_connect.py index 30ab3c7a8d..4021ee6083 100644 --- a/tests/unit/spanner_dbapi/test_connect.py +++ b/tests/unit/spanner_dbapi/test_connect.py @@ -60,8 +60,8 @@ def test_w_implicit(self, mock_client): self.assertIs(connection.database, database) instance.database.assert_called_once_with(DATABASE, pool=None) - # Datbase constructs its own pool - self.assertIsNotNone(connection.database._pool) + # Database constructs its own pool + self.assertIsNotNone(connection.database._session_manager._pool) self.assertTrue(connection.instance._client.route_to_leader_enabled) def test_w_explicit(self, mock_client): diff --git a/tests/unit/spanner_dbapi/test_connection.py b/tests/unit/spanner_dbapi/test_connection.py index 4bee9e93c7..038ab558d2 100644 --- a/tests/unit/spanner_dbapi/test_connection.py +++ b/tests/unit/spanner_dbapi/test_connection.py @@ -37,6 +37,12 @@ AutocommitDmlMode, ) +from google.cloud.spanner_v1.client import Client +from google.cloud.spanner_v1.instance import Instance +from google.cloud.spanner_v1.database import Database +from 
google.cloud.spanner_v1.session import Session +from google.cloud.spanner_v1.session_options import TransactionType + PROJECT = "test-project" INSTANCE = "test-instance" DATABASE = "test-database" @@ -64,9 +70,6 @@ def _get_client_info(self): def _make_connection( self, database_dialect=DatabaseDialect.DATABASE_DIALECT_UNSPECIFIED, **kwargs ): - from google.cloud.spanner_v1.instance import Instance - from google.cloud.spanner_v1.client import Client - # We don't need a real Client object to test the constructor client = Client() instance = Instance(INSTANCE, client=client) @@ -97,22 +100,16 @@ def test_autocommit_setter_transaction_started(self, mock_commit): self.assertTrue(connection._autocommit) def test_property_database(self): - from google.cloud.spanner_v1.database import Database - connection = self._make_connection() self.assertIsInstance(connection.database, Database) self.assertEqual(connection.database, connection._database) def test_property_instance(self): - from google.cloud.spanner_v1.instance import Instance - connection = self._make_connection() self.assertIsInstance(connection.instance, Instance) self.assertEqual(connection.instance, connection._instance) def test_property_current_schema_google_sql_dialect(self): - from google.cloud.spanner_v1.database import Database - connection = self._make_connection( database_dialect=DatabaseDialect.GOOGLE_STANDARD_SQL ) @@ -120,8 +117,6 @@ def test_property_current_schema_google_sql_dialect(self): self.assertEqual(connection.current_schema, "") def test_property_current_schema_postgres_sql_dialect(self): - from google.cloud.spanner_v1.database import Database - connection = self._make_connection(database_dialect=DatabaseDialect.POSTGRESQL) self.assertIsInstance(connection.database, Database) self.assertEqual(connection.current_schema, "public") @@ -146,25 +141,39 @@ def test_read_only_connection(self): connection.read_only = False self.assertFalse(connection.read_only) - @staticmethod - def _make_pool(): - 
from google.cloud.spanner_v1.pool import AbstractSessionPool + def test__session_checkout_read_only(self): + client = Client() + instance = Instance(instance_id="instance-id", client=client) + database = Database(database_id="database-id", instance=instance) + session_manager = database._session_manager - return mock.create_autospec(AbstractSessionPool) + session = Session(database=database) + session_manager.get_session = mock.Mock(return_value=session) - @mock.patch("google.cloud.spanner_v1.database.Database") - def test__session_checkout(self, mock_database): - pool = self._make_pool() - mock_database._pool = pool - connection = Connection(INSTANCE, mock_database) + read_only_connection = Connection( + instance="instance-id", database=database, read_only=True + ) + read_only_connection._session_checkout() - connection._session_checkout() - pool.get.assert_called_once_with() - self.assertEqual(connection._session, pool.get.return_value) + session_manager.get_session.assert_called_once_with(TransactionType.READ_ONLY) + self.assertEqual(read_only_connection._session, session) - connection._session = "db_session" - connection._session_checkout() - self.assertEqual(connection._session, "db_session") + def test__session_checkout_read_write(self): + client = Client() + instance = Instance(instance_id="instance-id", client=client) + database = Database(database_id="database-id", instance=instance) + session_manager = database._session_manager + + session = Session(database=database) + session_manager.get_session = mock.Mock(return_value=session) + + read_write_connection = Connection( + instance="instance-id", database=database, read_only=False + ) + read_write_connection._session_checkout() + + session_manager.get_session.assert_called_once_with(TransactionType.READ_WRITE) + self.assertEqual(read_write_connection._session, session) def test_session_checkout_database_error(self): connection = Connection(INSTANCE) @@ -172,15 +181,20 @@ def 
test_session_checkout_database_error(self): with pytest.raises(ValueError): connection._session_checkout() - @mock.patch("google.cloud.spanner_v1.database.Database") - def test__release_session(self, mock_database): - pool = self._make_pool() - mock_database._pool = pool - connection = Connection(INSTANCE, mock_database) - connection._session = "session" + def test__release_session(self): + client = Client() + instance = Instance(instance_id="instance-id", client=client) + database = Database(database_id="database-id", instance=instance) + connection = Connection(instance="instance-id", database=database) + + # Mock connection session and session manager. + session = Session(database=database) + connection._session = session + session_manager = database._session_manager + session_manager.put_session = mock.Mock() connection._release_session() - pool.put.assert_called_once_with("session") + session_manager.put_session.assert_called_once_with(session) self.assertIsNone(connection._session) def test_release_session_database_error(self): diff --git a/tests/unit/test_database.py b/tests/unit/test_database.py index 1afda7f850..56d6223b3e 100644 --- a/tests/unit/test_database.py +++ b/tests/unit/test_database.py @@ -11,9 +11,9 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
- - +import os import unittest +from logging import Logger import mock from google.api_core import gapic_v1 @@ -21,15 +21,20 @@ Database as DatabasePB, DatabaseDialect, ) +from google.cloud.spanner_v1.database import Database, BatchSnapshot, SessionCheckout from google.cloud.spanner_v1.param_types import INT64 from google.api_core.retry import Retry from google.protobuf.field_mask_pb2 import FieldMask from google.cloud.spanner_v1 import ( - RequestOptions, - DirectedReadOptions, DefaultTransactionOptions, + DirectedReadOptions, + RequestOptions, ) +from google.cloud.spanner_v1.session_options import TransactionType +from google.cloud.spanner_v1.snapshot import Snapshot +from tests._builders import build_database, build_session, TRANSACTION_ID, SESSION_ID +from tests._helpers import enable_multiplexed_sessions DML_WO_PARAM = """ DELETE FROM citizens @@ -75,13 +80,20 @@ class _BaseTest(unittest.TestCase): DATABASE_NAME = INSTANCE_NAME + "/databases/" + DATABASE_ID SESSION_ID = "session_id" SESSION_NAME = DATABASE_NAME + "/sessions/" + SESSION_ID - TRANSACTION_ID = b"transaction_id" - RETRY_TRANSACTION_ID = b"transaction_id_retry" BACKUP_ID = "backup_id" BACKUP_NAME = INSTANCE_NAME + "/backups/" + BACKUP_ID TRANSACTION_TAG = "transaction-tag" DATABASE_ROLE = "dummy-role" + def setUp(self): + # Save the original environment variables. + self._original_env = dict(os.environ) + + def tearDown(self): + # Restore environment variables. 
+ os.environ.clear() + os.environ.update(self._original_env) + def _make_one(self, *args, **kwargs): return self._get_target_class()(*args, **kwargs) @@ -90,7 +102,7 @@ def _make_timestamp(): import datetime from google.cloud._helpers import UTC - return datetime.datetime.utcnow().replace(tzinfo=UTC) + return datetime.datetime.now(tz=UTC) @staticmethod def _make_duration(seconds=1, microseconds=0): @@ -101,8 +113,6 @@ def _make_duration(seconds=1, microseconds=0): class TestDatabase(_BaseTest): def _get_target_class(self): - from google.cloud.spanner_v1.database import Database - return Database @staticmethod @@ -127,11 +137,8 @@ def test_ctor_defaults(self): self.assertEqual(database.database_id, self.DATABASE_ID) self.assertIs(database._instance, instance) self.assertEqual(list(database.ddl_statements), []) - self.assertIsInstance(database._pool, BurstyPool) + self.assertIsInstance(database._session_manager._pool, BurstyPool) self.assertFalse(database.log_commit_stats) - self.assertIsNone(database._logger) - # BurstyPool does not create sessions during 'bind()'. 
- self.assertTrue(database._pool._sessions.empty()) self.assertIsNone(database.database_role) self.assertTrue(database._route_to_leader_enabled, True) @@ -142,7 +149,7 @@ def test_ctor_w_explicit_pool(self): self.assertEqual(database.database_id, self.DATABASE_ID) self.assertIs(database._instance, instance) self.assertEqual(list(database.ddl_statements), []) - self.assertIs(database._pool, pool) + self.assertIs(database._session_manager._pool, pool) self.assertIs(pool._bound, database) def test_ctor_w_database_role(self): @@ -191,8 +198,6 @@ def test_ctor_w_ddl_statements_ok(self): self.assertEqual(list(database.ddl_statements), DDL_STATEMENTS) def test_ctor_w_explicit_logger(self): - from logging import Logger - instance = _Instance(self.INSTANCE_NAME) logger = mock.create_autospec(Logger, instance=True) database = self._make_one(self.DATABASE_ID, instance, logger=logger) @@ -279,7 +284,7 @@ def test_from_pb_success_w_explicit_pool(self): self.assertIsInstance(database, klass) self.assertEqual(database._instance, instance) self.assertEqual(database.database_id, self.DATABASE_ID) - self.assertIs(database._pool, pool) + self.assertIs(database._session_manager._pool, pool) def test_from_pb_success_w_hyphen_w_default_pool(self): from google.cloud.spanner_admin_database_v1 import Database @@ -297,9 +302,9 @@ def test_from_pb_success_w_hyphen_w_default_pool(self): self.assertIsInstance(database, klass) self.assertEqual(database._instance, instance) self.assertEqual(database.database_id, DATABASE_ID_HYPHEN) - self.assertIsInstance(database._pool, BurstyPool) + self.assertIsInstance(database._session_manager._pool, BurstyPool) # BurstyPool does not create sessions during 'bind()'. 
- self.assertTrue(database._pool._sessions.empty()) + self.assertTrue(database._session_manager._pool._sessions.empty()) def test_name_property(self): instance = _Instance(self.INSTANCE_NAME) @@ -1109,9 +1114,15 @@ def _execute_partitioned_dml_helper( import collections - MethodConfig = collections.namedtuple("MethodConfig", ["retry"]) + transaction_id = b"transaction-id" + transaction_pb = TransactionPB(id=transaction_id) + transaction_selector_pb = TransactionSelector(id=transaction_id) - transaction_pb = TransactionPB(id=self.TRANSACTION_ID) + retry_transaction_id = b"retry-transaction-id" + retry_transaction_pb = TransactionPB(id=retry_transaction_id) + retry_transaction_selector_pb = TransactionSelector(id=retry_transaction_id) + + MethodConfig = collections.namedtuple("MethodConfig", ["retry"]) stats_pb = ResultSetStats(row_count_lower_bound=2) result_sets = [PartialResultSet(stats=stats_pb)] @@ -1126,7 +1137,6 @@ def _execute_partitioned_dml_helper( api = database._spanner_api = self._make_spanner_api() api._method_configs = {"ExecuteStreamingSql": MethodConfig(retry=Retry())} if retried: - retry_transaction_pb = TransactionPB(id=self.RETRY_TRANSACTION_ID) api.begin_transaction.side_effect = [transaction_pb, retry_transaction_pb] api.execute_streaming_sql.side_effect = [Aborted("test"), iterator] else: @@ -1169,7 +1179,6 @@ def _execute_partitioned_dml_helper( else: expected_params = {} - expected_transaction = TransactionSelector(id=self.TRANSACTION_ID) expected_query_options = client._query_options if query_options: expected_query_options = _merge_query_options( @@ -1184,7 +1193,7 @@ def _execute_partitioned_dml_helper( expected_request = ExecuteSqlRequest( session=self.SESSION_NAME, sql=dml, - transaction=expected_transaction, + transaction=transaction_selector_pb, params=expected_params, param_types=param_types, query_options=expected_query_options, @@ -1199,13 +1208,10 @@ def _execute_partitioned_dml_helper( ], ) if retried: - expected_retry_transaction 
= TransactionSelector( - id=self.RETRY_TRANSACTION_ID - ) expected_request = ExecuteSqlRequest( session=self.SESSION_NAME, sql=dml, - transaction=expected_retry_transaction, + transaction=retry_transaction_selector_pb, params=expected_params, param_types=param_types, query_options=expected_query_options, @@ -1266,36 +1272,19 @@ def test_execute_partitioned_dml_w_exclude_txn_from_change_streams(self): dml=DML_WO_PARAM, exclude_txn_from_change_streams=True ) - def test_session_factory_defaults(self): - from google.cloud.spanner_v1.session import Session - - client = _Client() - instance = _Instance(self.INSTANCE_NAME, client=client) - pool = _Pool() - database = self._make_one(self.DATABASE_ID, instance, pool=pool) + def test_execute_partitioned_dml_not_implemented_error_multiplexed(self): + enable_multiplexed_sessions() - session = database.session() - - self.assertIsInstance(session, Session) - self.assertIs(session.session_id, None) - self.assertIs(session._database, database) - self.assertEqual(session.labels, {}) - - def test_session_factory_w_labels(self): - from google.cloud.spanner_v1.session import Session - - client = _Client() - instance = _Instance(self.INSTANCE_NAME, client=client) - pool = _Pool() - labels = {"foo": "bar"} - database = self._make_one(self.DATABASE_ID, instance, pool=pool) + database = build_database() + database.spanner_api.begin_transaction.side_effect = NotImplementedError( + "Transaction type partitioned_dml not supported with multiplexed sessions" + ) - session = database.session(labels=labels) + with self.assertRaises(NotImplementedError): + database.execute_partitioned_dml(dml=DML_WO_PARAM) - self.assertIsInstance(session, Session) - self.assertIs(session.session_id, None) - self.assertIs(session._database, database) - self.assertEqual(session.labels, labels) + session_options = database.session_options + self.assertFalse(session_options.use_multiplexed(TransactionType.PARTITIONED)) def test_snapshot_defaults(self): from 
google.cloud.spanner_v1.database import SnapshotCheckout @@ -1317,7 +1306,7 @@ def test_snapshot_w_read_timestamp_and_multi_use(self): from google.cloud._helpers import UTC from google.cloud.spanner_v1.database import SnapshotCheckout - now = datetime.datetime.utcnow().replace(tzinfo=UTC) + now = datetime.datetime.now(tz=UTC) client = _Client() instance = _Instance(self.INSTANCE_NAME, client=client) pool = _Pool() @@ -1801,7 +1790,9 @@ def _make_spanner_client(): return mock.create_autospec(SpannerClient) def test_ctor(self): - database = _Database(self.DATABASE_NAME) + database = Database( + database_id=self.DATABASE_ID, instance=_Instance(self.INSTANCE_NAME) + ) checkout = self._make_one(database) self.assertIs(checkout._database, database) @@ -1814,25 +1805,32 @@ def test_context_mgr_success(self): from google.cloud._helpers import _datetime_to_pb_timestamp from google.cloud.spanner_v1.batch import Batch - now = datetime.datetime.utcnow().replace(tzinfo=UTC) + now = datetime.datetime.now(tz=UTC) now_pb = _datetime_to_pb_timestamp(now) response = CommitResponse(commit_timestamp=now_pb) - database = _Database(self.DATABASE_NAME) - api = database.spanner_api = self._make_spanner_client() + database = Database( + database_id=self.DATABASE_ID, instance=_Instance(self.INSTANCE_NAME) + ) + api = database._spanner_api = self._make_spanner_client() api.commit.return_value = response - pool = database._pool = _Pool() + session = _Session(database) - pool.put(session) + session_manager = database._session_manager + session_manager.get_session = mock.Mock(return_value=session) + session_manager.put_session = mock.Mock(return_value=None) + checkout = self._make_one( database, request_options={"transaction_tag": self.TRANSACTION_TAG} ) with checkout as batch: - self.assertIsNone(pool._session) + session_manager.get_session.assert_called_once_with( + TransactionType.READ_ONLY + ) self.assertIsInstance(batch, Batch) self.assertIs(batch._session, session) - 
self.assertIs(pool._session, session) + session_manager.put_session.assert_called_once_with(session) self.assertEqual(batch.committed, now) self.assertEqual(batch.transaction_tag, self.TRANSACTION_TAG) @@ -1861,25 +1859,35 @@ def test_context_mgr_w_commit_stats_success(self): from google.cloud._helpers import _datetime_to_pb_timestamp from google.cloud.spanner_v1.batch import Batch - now = datetime.datetime.utcnow().replace(tzinfo=UTC) + now = datetime.datetime.now(tz=UTC) now_pb = _datetime_to_pb_timestamp(now) commit_stats = CommitResponse.CommitStats(mutation_count=4) response = CommitResponse(commit_timestamp=now_pb, commit_stats=commit_stats) - database = _Database(self.DATABASE_NAME) + logger = mock.create_autospec(Logger, instance=True) + database = Database( + database_id=self.DATABASE_ID, + instance=_Instance(self.INSTANCE_NAME), + logger=logger, + ) database.log_commit_stats = True - api = database.spanner_api = self._make_spanner_client() + api = database._spanner_api = self._make_spanner_client() api.commit.return_value = response - pool = database._pool = _Pool() + session = _Session(database) - pool.put(session) + session_manager = database._session_manager + session_manager.get_session = mock.Mock(return_value=session) + session_manager.put_session = mock.Mock(return_value=None) + checkout = self._make_one(database) with checkout as batch: - self.assertIsNone(pool._session) + session_manager.get_session.assert_called_once_with( + TransactionType.READ_ONLY + ) self.assertIsInstance(batch, Batch) self.assertIs(batch._session, session) - self.assertIs(pool._session, session) + session_manager.put_session.assert_called_once_with(session) self.assertEqual(batch.committed, now) expected_txn_options = TransactionOptions(read_write={}) @@ -1899,32 +1907,47 @@ def test_context_mgr_w_commit_stats_success(self): ], ) - database.logger.info.assert_called_once_with( + database._logger.info.assert_called_once_with( "CommitStats: mutation_count: 4\n", 
extra={"commit_stats": commit_stats} ) + def test_type_error(self): + with self.assertRaises(TypeError): + with self._make_one(None) as _: + pass + def test_context_mgr_w_aborted_commit_status(self): from google.api_core.exceptions import Aborted from google.cloud.spanner_v1 import CommitRequest from google.cloud.spanner_v1 import TransactionOptions from google.cloud.spanner_v1.batch import Batch - database = _Database(self.DATABASE_NAME) + logger = mock.create_autospec(Logger, instance=True) + database = Database( + database_id=self.DATABASE_ID, + instance=_Instance(self.INSTANCE_NAME), + logger=logger, + ) database.log_commit_stats = True - api = database.spanner_api = self._make_spanner_client() + api = database._spanner_api = self._make_spanner_client() api.commit.side_effect = Aborted("aborted exception", errors=("Aborted error")) - pool = database._pool = _Pool() + session = _Session(database) - pool.put(session) + session_manager = database._session_manager + session_manager.get_session = mock.Mock(return_value=session) + session_manager.put_session = mock.Mock(return_value=None) + checkout = self._make_one(database) with self.assertRaises(Aborted): with checkout as batch: - self.assertIsNone(pool._session) + session_manager.get_session.assert_called_once_with( + TransactionType.READ_ONLY + ) self.assertIsInstance(batch, Batch) self.assertIs(batch._session, session) - self.assertIs(pool._session, session) + session_manager.put_session.assert_called_once_with(session) expected_txn_options = TransactionOptions(read_write={}) @@ -1951,10 +1974,15 @@ def test_context_mgr_w_aborted_commit_status(self): def test_context_mgr_failure(self): from google.cloud.spanner_v1.batch import Batch - database = _Database(self.DATABASE_NAME) - pool = database._pool = _Pool() + database = Database( + database_id=self.DATABASE_ID, instance=_Instance(self.INSTANCE_NAME) + ) + session = _Session(database) - pool.put(session) + session_manager = database._session_manager + 
session_manager.get_session = mock.Mock(return_value=session) + session_manager.put_session = mock.Mock(return_value=None) + checkout = self._make_one(database) class Testing(Exception): @@ -1962,12 +1990,14 @@ class Testing(Exception): with self.assertRaises(Testing): with checkout as batch: - self.assertIsNone(pool._session) + session_manager.get_session.assert_called_once_with( + TransactionType.READ_ONLY + ) self.assertIsInstance(batch, Batch) self.assertIs(batch._session, session) raise Testing() - self.assertIs(pool._session, session) + session_manager.put_session.assert_called_once_with(session) self.assertIsNone(batch.committed) @@ -1978,57 +2008,69 @@ def _get_target_class(self): return SnapshotCheckout def test_ctor_defaults(self): - from google.cloud.spanner_v1.snapshot import Snapshot + database = Database( + database_id=self.DATABASE_ID, instance=_Instance(self.INSTANCE_NAME) + ) - database = _Database(self.DATABASE_NAME) session = _Session(database) - pool = database._pool = _Pool() - pool.put(session) + session_manager = database._session_manager + session_manager.get_session = mock.Mock(return_value=session) + session_manager.put_session = mock.Mock(return_value=None) checkout = self._make_one(database) self.assertIs(checkout._database, database) self.assertEqual(checkout._kw, {}) with checkout as snapshot: - self.assertIsNone(pool._session) + session_manager.get_session.assert_called_once_with( + TransactionType.READ_ONLY + ) self.assertIsInstance(snapshot, Snapshot) self.assertIs(snapshot._session, session) self.assertTrue(snapshot._strong) self.assertFalse(snapshot._multi_use) - self.assertIs(pool._session, session) + session_manager.put_session.assert_called_once_with(session) def test_ctor_w_read_timestamp_and_multi_use(self): import datetime from google.cloud._helpers import UTC - from google.cloud.spanner_v1.snapshot import Snapshot - now = datetime.datetime.utcnow().replace(tzinfo=UTC) - database = _Database(self.DATABASE_NAME) + now = 
datetime.datetime.now(tz=UTC) + database = Database( + database_id=self.DATABASE_ID, instance=_Instance(self.INSTANCE_NAME) + ) + session = _Session(database) - pool = database._pool = _Pool() - pool.put(session) + session_manager = database._session_manager + session_manager.get_session = mock.Mock(return_value=session) + session_manager.put_session = mock.Mock(return_value=None) checkout = self._make_one(database, read_timestamp=now, multi_use=True) self.assertIs(checkout._database, database) self.assertEqual(checkout._kw, {"read_timestamp": now, "multi_use": True}) with checkout as snapshot: - self.assertIsNone(pool._session) + session_manager.get_session.assert_called_once_with( + TransactionType.READ_ONLY + ) self.assertIsInstance(snapshot, Snapshot) self.assertIs(snapshot._session, session) self.assertEqual(snapshot._read_timestamp, now) self.assertTrue(snapshot._multi_use) - self.assertIs(pool._session, session) + session_manager.put_session.assert_called_once_with(session) def test_context_mgr_failure(self): - from google.cloud.spanner_v1.snapshot import Snapshot + database = Database( + database_id=self.DATABASE_ID, instance=_Instance(self.INSTANCE_NAME) + ) - database = _Database(self.DATABASE_NAME) - pool = database._pool = _Pool() session = _Session(database) - pool.put(session) + session_manager = database._session_manager + session_manager.get_session = mock.Mock(return_value=session) + session_manager.put_session = mock.Mock(return_value=None) + checkout = self._make_one(database) class Testing(Exception): @@ -2036,72 +2078,19 @@ class Testing(Exception): with self.assertRaises(Testing): with checkout as snapshot: - self.assertIsNone(pool._session) + session_manager.get_session.assert_called_once_with( + TransactionType.READ_ONLY + ) self.assertIsInstance(snapshot, Snapshot) self.assertIs(snapshot._session, session) raise Testing() - self.assertIs(pool._session, session) - - def test_context_mgr_session_not_found_error(self): - from 
google.cloud.exceptions import NotFound - - database = _Database(self.DATABASE_NAME) - session = _Session(database, name="session-1") - session.exists = mock.MagicMock(return_value=False) - pool = database._pool = _Pool() - new_session = _Session(database, name="session-2") - new_session.create = mock.MagicMock(return_value=[]) - pool._new_session = mock.MagicMock(return_value=new_session) - - pool.put(session) - checkout = self._make_one(database) - - self.assertEqual(pool._session, session) - with self.assertRaises(NotFound): - with checkout as _: - raise NotFound("Session not found") - # Assert that session-1 was removed from pool and new session was added. - self.assertEqual(pool._session, new_session) - - def test_context_mgr_table_not_found_error(self): - from google.cloud.exceptions import NotFound - - database = _Database(self.DATABASE_NAME) - session = _Session(database, name="session-1") - session.exists = mock.MagicMock(return_value=True) - pool = database._pool = _Pool() - pool._new_session = mock.MagicMock(return_value=[]) - - pool.put(session) - checkout = self._make_one(database) - - self.assertEqual(pool._session, session) - with self.assertRaises(NotFound): - with checkout as _: - raise NotFound("Table not found") - # Assert that session-1 was not removed from pool. - self.assertEqual(pool._session, session) - pool._new_session.assert_not_called() - - def test_context_mgr_unknown_error(self): - database = _Database(self.DATABASE_NAME) - session = _Session(database) - pool = database._pool = _Pool() - pool._new_session = mock.MagicMock(return_value=[]) - pool.put(session) - checkout = self._make_one(database) - - class Testing(Exception): - pass + session_manager.put_session.assert_called_once_with(session) - self.assertEqual(pool._session, session) - with self.assertRaises(Testing): - with checkout as _: - raise Testing("Unknown error.") - # Assert that session-1 was not removed from pool. 
- self.assertEqual(pool._session, session) - pool._new_session.assert_not_called() + def test_type_error(self): + with self.assertRaises(TypeError): + with self._make_one(None) as _: + pass class TestBatchSnapshot(_BaseTest): @@ -2110,27 +2099,8 @@ class TestBatchSnapshot(_BaseTest): TOKENS = [b"TOKEN1", b"TOKEN2"] INDEX = "index" - def _get_target_class(self): - from google.cloud.spanner_v1.database import BatchSnapshot - - return BatchSnapshot - - @staticmethod - def _make_database(**kwargs): - from google.cloud.spanner_v1.database import Database - - return mock.create_autospec(Database, instance=True, **kwargs) - - @staticmethod - def _make_session(**kwargs): - from google.cloud.spanner_v1.session import Session - - return mock.create_autospec(Session, instance=True, **kwargs) - @staticmethod def _make_snapshot(transaction_id=None, **kwargs): - from google.cloud.spanner_v1.snapshot import Snapshot - snapshot = mock.create_autospec(Snapshot, instance=True, **kwargs) if transaction_id is not None: snapshot._transaction_id = transaction_id @@ -2144,9 +2114,8 @@ def _make_keyset(): return KeySet(all_=True) def test_ctor_no_staleness(self): - database = self._make_database() - - batch_txn = self._make_one(database) + database = build_database() + batch_txn = BatchSnapshot(database) self.assertIs(batch_txn._database, database) self.assertIsNone(batch_txn._session) @@ -2155,10 +2124,9 @@ def test_ctor_no_staleness(self): self.assertIsNone(batch_txn._exact_staleness) def test_ctor_w_read_timestamp(self): - database = self._make_database() + database = build_database() timestamp = self._make_timestamp() - - batch_txn = self._make_one(database, read_timestamp=timestamp) + batch_txn = BatchSnapshot(database, read_timestamp=timestamp) self.assertIs(batch_txn._database, database) self.assertIsNone(batch_txn._session) @@ -2167,10 +2135,9 @@ def test_ctor_w_read_timestamp(self): self.assertIsNone(batch_txn._exact_staleness) def test_ctor_w_exact_staleness(self): - database = 
self._make_database() + database = build_database() duration = self._make_duration() - - batch_txn = self._make_one(database, exact_staleness=duration) + batch_txn = BatchSnapshot(database, exact_staleness=duration) self.assertIs(batch_txn._database, database) self.assertIsNone(batch_txn._session) @@ -2179,103 +2146,87 @@ def test_ctor_w_exact_staleness(self): self.assertEqual(batch_txn._exact_staleness, duration) def test_from_dict(self): - klass = self._get_target_class() - database = self._make_database() - session = database.session.return_value = self._make_session() - snapshot = session.snapshot.return_value = self._make_snapshot() - api_repr = { - "session_id": self.SESSION_ID, - "transaction_id": self.TRANSACTION_ID, - } + database = build_database() + + batch_txn = BatchSnapshot.from_dict( + database, + { + "transaction_id": TRANSACTION_ID, + "session_id": SESSION_ID, + }, + ) - batch_txn = klass.from_dict(database, api_repr) self.assertIs(batch_txn._database, database) - self.assertIs(batch_txn._session, session) - self.assertEqual(session._session_id, self.SESSION_ID) - self.assertEqual(snapshot._transaction_id, self.TRANSACTION_ID) - snapshot.begin.assert_not_called() - self.assertIs(batch_txn._snapshot, snapshot) + self.assertIs(batch_txn._transaction_id, TRANSACTION_ID) + self.assertIs(batch_txn._session_id, SESSION_ID) - def test_to_dict(self): - database = self._make_database() - batch_txn = self._make_one(database) - batch_txn._session = self._make_session(_session_id=self.SESSION_ID) - batch_txn._snapshot = self._make_snapshot(transaction_id=self.TRANSACTION_ID) - - expected = { - "session_id": self.SESSION_ID, - "transaction_id": self.TRANSACTION_ID, - } - self.assertEqual(batch_txn.to_dict(), expected) + session = batch_txn._get_session() + self.assertEqual(session._session_id, SESSION_ID) - def test__get_session_already(self): - database = self._make_database() - batch_txn = self._make_one(database) - already = batch_txn._session = object() - 
self.assertIs(batch_txn._get_session(), already) + snapshot = batch_txn._get_snapshot() + self.assertEqual(snapshot._transaction_id, TRANSACTION_ID) + database.spanner_api.begin_transaction.assert_not_called() - def test__get_session_new(self): - database = self._make_database() - session = database.session.return_value = self._make_session() - batch_txn = self._make_one(database) - self.assertIs(batch_txn._get_session(), session) - session.create.assert_called_once_with() + def test_to_dict(self): + database = build_database() + batch_txn = BatchSnapshot(database) + + self.assertEqual( + batch_txn.to_dict(), + { + "transaction_id": TRANSACTION_ID, + "session_id": SESSION_ID, + }, + ) def test__get_snapshot_already(self): - database = self._make_database() - batch_txn = self._make_one(database) - already = batch_txn._snapshot = self._make_snapshot() - self.assertIs(batch_txn._get_snapshot(), already) - already.begin.assert_not_called() + database = build_database() + batch_txn = BatchSnapshot(database) + snapshot_1 = batch_txn._get_snapshot() + + snapshot_2 = batch_txn._get_snapshot() + self.assertEqual(snapshot_1, snapshot_2) + database.spanner_api.begin_transaction.assert_called_once() def test__get_snapshot_new_wo_staleness(self): - database = self._make_database() - batch_txn = self._make_one(database) - session = batch_txn._session = self._make_session() - snapshot = session.snapshot.return_value = self._make_snapshot() - self.assertIs(batch_txn._get_snapshot(), snapshot) - session.snapshot.assert_called_once_with( - read_timestamp=None, - exact_staleness=None, - multi_use=True, - transaction_id=None, - ) - snapshot.begin.assert_called_once_with() + database = build_database() + batch_txn = BatchSnapshot(database) + snapshot = batch_txn._get_snapshot() + + self.assertIsNone(snapshot._read_timestamp) + self.assertIsNone(snapshot._exact_staleness) + self.assertTrue(snapshot._multi_use) + self.assertEqual(TRANSACTION_ID, snapshot._transaction_id) + 
database.spanner_api.begin_transaction.assert_called_once() def test__get_snapshot_w_read_timestamp(self): - database = self._make_database() + database = build_database() timestamp = self._make_timestamp() - batch_txn = self._make_one(database, read_timestamp=timestamp) - session = batch_txn._session = self._make_session() - snapshot = session.snapshot.return_value = self._make_snapshot() - self.assertIs(batch_txn._get_snapshot(), snapshot) - session.snapshot.assert_called_once_with( - read_timestamp=timestamp, - exact_staleness=None, - multi_use=True, - transaction_id=None, - ) - snapshot.begin.assert_called_once_with() + batch_txn = BatchSnapshot(database, read_timestamp=timestamp) + snapshot = batch_txn._get_snapshot() + + self.assertEqual(timestamp, snapshot._read_timestamp) + self.assertIsNone(snapshot._exact_staleness) + self.assertTrue(snapshot._multi_use) + self.assertEqual(TRANSACTION_ID, snapshot._transaction_id) + database.spanner_api.begin_transaction.assert_called_once() def test__get_snapshot_w_exact_staleness(self): - database = self._make_database() + database = build_database() duration = self._make_duration() - batch_txn = self._make_one(database, exact_staleness=duration) - session = batch_txn._session = self._make_session() - snapshot = session.snapshot.return_value = self._make_snapshot() - self.assertIs(batch_txn._get_snapshot(), snapshot) - session.snapshot.assert_called_once_with( - read_timestamp=None, - exact_staleness=duration, - multi_use=True, - transaction_id=None, - ) - snapshot.begin.assert_called_once_with() + batch_txn = BatchSnapshot(database, exact_staleness=duration) + snapshot = batch_txn._get_snapshot() + + self.assertIsNone(snapshot._read_timestamp) + self.assertEqual(duration, snapshot._exact_staleness) + self.assertTrue(snapshot._multi_use) + self.assertEqual(TRANSACTION_ID, snapshot._transaction_id) + database.spanner_api.begin_transaction.assert_called_once() def test_read(self): keyset = self._make_keyset() - database = 
self._make_database() - batch_txn = self._make_one(database) + database = build_database() + batch_txn = BatchSnapshot(database) snapshot = batch_txn._snapshot = self._make_snapshot() rows = batch_txn.read(self.TABLE, self.COLUMNS, keyset, self.INDEX) @@ -2291,8 +2242,8 @@ def test_execute_sql(self): ) params = {"max_age": 30} param_types = {"max_age": "INT64"} - database = self._make_database() - batch_txn = self._make_one(database) + database = build_database() + batch_txn = BatchSnapshot(database) snapshot = batch_txn._snapshot = self._make_snapshot() rows = batch_txn.execute_sql(sql, params, param_types) @@ -2303,8 +2254,8 @@ def test_execute_sql(self): def test_generate_read_batches_w_max_partitions(self): max_partitions = len(self.TOKENS) keyset = self._make_keyset() - database = self._make_database() - batch_txn = self._make_one(database) + database = build_database() + batch_txn = BatchSnapshot(database) snapshot = batch_txn._snapshot = self._make_snapshot() snapshot.partition_read.return_value = self.TOKENS @@ -2341,8 +2292,8 @@ def test_generate_read_batches_w_max_partitions(self): def test_generate_read_batches_w_retry_and_timeout_params(self): max_partitions = len(self.TOKENS) keyset = self._make_keyset() - database = self._make_database() - batch_txn = self._make_one(database) + database = build_database() + batch_txn = BatchSnapshot(database) snapshot = batch_txn._snapshot = self._make_snapshot() snapshot.partition_read.return_value = self.TOKENS retry = Retry(deadline=60) @@ -2384,8 +2335,8 @@ def test_generate_read_batches_w_retry_and_timeout_params(self): def test_generate_read_batches_w_index_w_partition_size_bytes(self): size = 1 << 20 keyset = self._make_keyset() - database = self._make_database() - batch_txn = self._make_one(database) + database = build_database() + batch_txn = BatchSnapshot(database) snapshot = batch_txn._snapshot = self._make_snapshot() snapshot.partition_read.return_value = self.TOKENS @@ -2426,8 +2377,8 @@ def 
test_generate_read_batches_w_index_w_partition_size_bytes(self): def test_generate_read_batches_w_data_boost_enabled(self): data_boost_enabled = True keyset = self._make_keyset() - database = self._make_database() - batch_txn = self._make_one(database) + database = build_database() + batch_txn = BatchSnapshot(database) snapshot = batch_txn._snapshot = self._make_snapshot() snapshot.partition_read.return_value = self.TOKENS @@ -2467,8 +2418,8 @@ def test_generate_read_batches_w_data_boost_enabled(self): def test_generate_read_batches_w_directed_read_options(self): keyset = self._make_keyset() - database = self._make_database() - batch_txn = self._make_one(database) + database = build_database() + batch_txn = BatchSnapshot(database) snapshot = batch_txn._snapshot = self._make_snapshot() snapshot.partition_read.return_value = self.TOKENS @@ -2518,8 +2469,8 @@ def test_process_read_batch(self): "index": self.INDEX, }, } - database = self._make_database() - batch_txn = self._make_one(database) + database = build_database() + batch_txn = BatchSnapshot(database) snapshot = batch_txn._snapshot = self._make_snapshot() expected = snapshot.read.return_value = object() @@ -2549,8 +2500,8 @@ def test_process_read_batch_w_retry_timeout(self): "index": self.INDEX, }, } - database = self._make_database() - batch_txn = self._make_one(database) + database = build_database() + batch_txn = BatchSnapshot(database) snapshot = batch_txn._snapshot = self._make_snapshot() expected = snapshot.read.return_value = object() retry = Retry(deadline=60) @@ -2574,7 +2525,7 @@ def test_generate_query_batches_w_max_partitions(self): client = _Client(self.PROJECT_ID) instance = _Instance(self.INSTANCE_NAME, client=client) database = _Database(self.DATABASE_NAME, instance=instance) - batch_txn = self._make_one(database) + batch_txn = BatchSnapshot(database) snapshot = batch_txn._snapshot = self._make_snapshot() snapshot.partition_query.return_value = self.TOKENS @@ -2613,7 +2564,7 @@ def 
test_generate_query_batches_w_params_w_partition_size_bytes(self): client = _Client(self.PROJECT_ID) instance = _Instance(self.INSTANCE_NAME, client=client) database = _Database(self.DATABASE_NAME, instance=instance) - batch_txn = self._make_one(database) + batch_txn = BatchSnapshot(database) snapshot = batch_txn._snapshot = self._make_snapshot() snapshot.partition_query.return_value = self.TOKENS @@ -2656,7 +2607,7 @@ def test_generate_query_batches_w_retry_and_timeout_params(self): client = _Client(self.PROJECT_ID) instance = _Instance(self.INSTANCE_NAME, client=client) database = _Database(self.DATABASE_NAME, instance=instance) - batch_txn = self._make_one(database) + batch_txn = BatchSnapshot(database) snapshot = batch_txn._snapshot = self._make_snapshot() snapshot.partition_query.return_value = self.TOKENS retry = Retry(deadline=60) @@ -2699,7 +2650,7 @@ def test_generate_query_batches_w_data_boost_enabled(self): client = _Client(self.PROJECT_ID) instance = _Instance(self.INSTANCE_NAME, client=client) database = _Database(self.DATABASE_NAME, instance=instance) - batch_txn = self._make_one(database) + batch_txn = BatchSnapshot(database) snapshot = batch_txn._snapshot = self._make_snapshot() snapshot.partition_query.return_value = self.TOKENS @@ -2731,7 +2682,7 @@ def test_generate_query_batches_w_directed_read_options(self): client = _Client(self.PROJECT_ID) instance = _Instance(self.INSTANCE_NAME, client=client) database = _Database(self.DATABASE_NAME, instance=instance) - batch_txn = self._make_one(database) + batch_txn = BatchSnapshot(database) snapshot = batch_txn._snapshot = self._make_snapshot() snapshot.partition_query.return_value = self.TOKENS @@ -2773,8 +2724,8 @@ def test_process_query_batch(self): "partition": token, "query": {"sql": sql, "params": params, "param_types": param_types}, } - database = self._make_database() - batch_txn = self._make_one(database) + database = build_database() + batch_txn = BatchSnapshot(database) snapshot = 
batch_txn._snapshot = self._make_snapshot() expected = snapshot.execute_sql.return_value = object() @@ -2802,8 +2753,8 @@ def test_process_query_batch_w_retry_timeout(self): "partition": token, "query": {"sql": sql, "params": params, "param_types": param_types}, } - database = self._make_database() - batch_txn = self._make_one(database) + database = build_database() + batch_txn = BatchSnapshot(database) snapshot = batch_txn._snapshot = self._make_snapshot() expected = snapshot.execute_sql.return_value = object() retry = Retry(deadline=60) @@ -2827,8 +2778,8 @@ def test_process_query_batch_w_directed_read_options(self): "partition": token, "query": {"sql": sql, "directed_read_options": DIRECTED_READ_OPTIONS}, } - database = self._make_database() - batch_txn = self._make_one(database) + database = build_database() + batch_txn = BatchSnapshot(database) snapshot = batch_txn._snapshot = self._make_snapshot() expected = snapshot.execute_sql.return_value = object() @@ -2845,25 +2796,26 @@ def test_process_query_batch_w_directed_read_options(self): ) def test_close_wo_session(self): - database = self._make_database() - batch_txn = self._make_one(database) + database = build_database() + batch_txn = BatchSnapshot(database) batch_txn.close() # no raise def test_close_w_session(self): - database = self._make_database() - batch_txn = self._make_one(database) - session = batch_txn._session = self._make_session() + database = build_database() + database._session_manager.put_session = mock.Mock() + batch_txn = BatchSnapshot(database) + session = batch_txn._get_session() batch_txn.close() - session.delete.assert_called_once_with() + database._session_manager.put_session.assert_called_once_with(session) def test_process_w_invalid_batch(self): token = b"TOKEN" batch = {"partition": token, "bogus": b"BOGUS"} - database = self._make_database() - batch_txn = self._make_one(database) + database = build_database() + batch_txn = BatchSnapshot(database) with self.assertRaises(ValueError): 
batch_txn.process(batch) @@ -2880,8 +2832,8 @@ def test_process_w_read_batch(self): "index": self.INDEX, }, } - database = self._make_database() - batch_txn = self._make_one(database) + database = build_database() + batch_txn = BatchSnapshot(database) snapshot = batch_txn._snapshot = self._make_snapshot() expected = snapshot.read.return_value = object() @@ -2910,8 +2862,8 @@ def test_process_w_query_batch(self): "partition": token, "query": {"sql": sql, "params": params, "param_types": param_types}, } - database = self._make_database() - batch_txn = self._make_one(database) + database = build_database() + batch_txn = BatchSnapshot(database) snapshot = batch_txn._snapshot = self._make_snapshot() expected = snapshot.execute_sql.return_value = object() @@ -2944,19 +2896,26 @@ def _make_spanner_client(): def test_ctor(self): from google.cloud.spanner_v1.batch import MutationGroups - database = _Database(self.DATABASE_NAME) - pool = database._pool = _Pool() + database = Database( + database_id=self.DATABASE_ID, instance=_Instance(self.INSTANCE_NAME) + ) + session = _Session(database) - pool.put(session) + session_manager = database._session_manager + session_manager.get_session = mock.Mock(return_value=session) + session_manager.put_session = mock.Mock(return_value=None) + checkout = self._make_one(database) self.assertIs(checkout._database, database) with checkout as groups: - self.assertIsNone(pool._session) + session_manager.get_session.assert_called_once_with( + TransactionType.READ_WRITE + ) self.assertIsInstance(groups, MutationGroups) self.assertIs(groups._session, session) - self.assertIs(pool._session, session) + session_manager.put_session.assert_called_once_with(session) def test_context_mgr_success(self): import datetime @@ -2969,18 +2928,23 @@ def test_context_mgr_success(self): from google.cloud.spanner_v1.batch import MutationGroups from google.rpc.status_pb2 import Status - now = datetime.datetime.utcnow().replace(tzinfo=UTC) + now = 
datetime.datetime.now(tz=UTC) now_pb = _datetime_to_pb_timestamp(now) status_pb = Status(code=200) response = BatchWriteResponse( commit_timestamp=now_pb, indexes=[0], status=status_pb ) - database = _Database(self.DATABASE_NAME) - api = database.spanner_api = self._make_spanner_client() + database = Database( + database_id=self.DATABASE_ID, instance=_Instance(self.INSTANCE_NAME) + ) + api = database._spanner_api = self._make_spanner_client() api.batch_write.return_value = [response] - pool = database._pool = _Pool() + session = _Session(database) - pool.put(session) + session_manager = database._session_manager + session_manager.get_session = mock.Mock(return_value=session) + session_manager.put_session = mock.Mock(return_value=None) + checkout = self._make_one(database) request_options = RequestOptions(transaction_tag=self.TRANSACTION_TAG) @@ -3002,7 +2966,9 @@ def test_context_mgr_success(self): request_options=request_options, ) with checkout as groups: - self.assertIsNone(pool._session) + session_manager.get_session.assert_called_once_with( + TransactionType.READ_WRITE + ) self.assertIsInstance(groups, MutationGroups) self.assertIs(groups._session, session) group = groups.group() @@ -3010,7 +2976,7 @@ def test_context_mgr_success(self): groups.batch_write(request_options) self.assertEqual(groups.committed, True) - self.assertIs(pool._session, session) + session_manager.put_session.assert_called_once_with(session) api.batch_write.assert_called_once_with( request=request, @@ -3020,13 +2986,23 @@ def test_context_mgr_success(self): ], ) + def test_type_error(self): + with self.assertRaises(TypeError): + with self._make_one(None) as _: + pass + def test_context_mgr_failure(self): from google.cloud.spanner_v1.batch import MutationGroups - database = _Database(self.DATABASE_NAME) - pool = database._pool = _Pool() + database = Database( + database_id=self.DATABASE_ID, instance=_Instance(self.INSTANCE_NAME) + ) + session = _Session(database) - pool.put(session) + 
session_manager = database._session_manager + session_manager.get_session = mock.Mock(return_value=session) + session_manager.put_session = mock.Mock(return_value=None) + checkout = self._make_one(database) class Testing(Exception): @@ -3034,72 +3010,81 @@ class Testing(Exception): with self.assertRaises(Testing): with checkout as groups: - self.assertIsNone(pool._session) self.assertIsInstance(groups, MutationGroups) + session_manager.get_session.assert_called_once_with( + TransactionType.READ_WRITE + ) self.assertIs(groups._session, session) raise Testing() - self.assertIs(pool._session, session) + session_manager.put_session.assert_called_once_with(session) - def test_context_mgr_session_not_found_error(self): - from google.cloud.exceptions import NotFound - database = _Database(self.DATABASE_NAME) - session = _Session(database, name="session-1") - session.exists = mock.MagicMock(return_value=False) - pool = database._pool = _Pool() - new_session = _Session(database, name="session-2") - new_session.create = mock.MagicMock(return_value=[]) - pool._new_session = mock.MagicMock(return_value=new_session) +class TestSessionCheckout(_BaseTest): + def test_ctor(self): + database = build_database() - pool.put(session) - checkout = self._make_one(database) + # Default transaction type. + checkout = SessionCheckout(database) + self.assertIs(checkout._database, database) + self.assertIs(checkout._transaction_type, TransactionType.READ_WRITE) + self.assertIsNone(checkout._session) - self.assertEqual(pool._session, session) - with self.assertRaises(NotFound): - with checkout as _: - raise NotFound("Session not found") - # Assert that session-1 was removed from pool and new session was added. - self.assertEqual(pool._session, new_session) + # Specified transaction type. 
+ transaction_type = TransactionType.READ_ONLY + checkout = SessionCheckout(database, transaction_type) + self.assertIs(checkout._database, database) + self.assertIs(checkout._transaction_type, transaction_type) + self.assertIsNone(checkout._session) - def test_context_mgr_table_not_found_error(self): - from google.cloud.exceptions import NotFound + def test_context_manager_success(self): + database = build_database() + transaction_type = TransactionType.READ_ONLY + checkout = SessionCheckout(database, transaction_type) - database = _Database(self.DATABASE_NAME) - session = _Session(database, name="session-1") - session.exists = mock.MagicMock(return_value=True) - pool = database._pool = _Pool() - pool._new_session = mock.MagicMock(return_value=[]) + session = build_session(database=database) + session_manager = session._database._session_manager + session_manager.get_session = mock.Mock(return_value=session) + session_manager.put_session = mock.Mock(return_value=None) - pool.put(session) - checkout = self._make_one(database) + with checkout as borrowed: + session_manager.get_session.assert_called_once_with(transaction_type) + self.assertIs(borrowed, session) - self.assertEqual(pool._session, session) - with self.assertRaises(NotFound): - with checkout as _: - raise NotFound("Table not found") - # Assert that session-1 was not removed from pool. 
- self.assertEqual(pool._session, session) - pool._new_session.assert_not_called() - - def test_context_mgr_unknown_error(self): - database = _Database(self.DATABASE_NAME) - session = _Session(database) - pool = database._pool = _Pool() - pool._new_session = mock.MagicMock(return_value=[]) - pool.put(session) - checkout = self._make_one(database) + session_manager.put_session.assert_called_once_with(session) + + def test_context_manager_failure(self): + database = build_database() + transaction_type = TransactionType.READ_ONLY + checkout = SessionCheckout(database, transaction_type) + + session = build_session(database=database) + session_manager = session._database._session_manager + session_manager.get_session = mock.Mock(return_value=session) + session_manager.put_session = mock.Mock(return_value=None) class Testing(Exception): pass - self.assertEqual(pool._session, session) with self.assertRaises(Testing): - with checkout as _: - raise Testing("Unknown error.") - # Assert that session-1 was not removed from pool. 
- self.assertEqual(pool._session, session) - pool._new_session.assert_not_called() + with checkout as borrowed: + session_manager.get_session.assert_called_once_with(transaction_type) + self.assertIs(borrowed, session) + raise Testing() + + session_manager.put_session.assert_called_once_with(session) + + def test_type_errors(self): + database = build_database() + transaction_type = TransactionType.READ_ONLY + + with self.assertRaises(TypeError): + with SessionCheckout(None, transaction_type) as _: + pass + + with self.assertRaises(TypeError): + with SessionCheckout(database, None) as _: + pass def _make_instance_api(): @@ -3123,18 +3108,21 @@ def __init__( default_transaction_options=DefaultTransactionOptions(), ): from google.cloud.spanner_v1 import ExecuteSqlRequest + from google.cloud.spanner_v1.session_options import SessionOptions self.project = project self.project_name = "projects/" + self.project self._endpoint_cache = {} self.database_admin_api = _make_database_admin_api() self.instance_admin_api = _make_instance_api() + self.credentials = mock.Mock() self._client_info = mock.Mock() self._client_options = mock.Mock() self._query_options = ExecuteSqlRequest.QueryOptions(optimizer_version="1") self.route_to_leader_enabled = route_to_leader_enabled self.directed_read_options = directed_read_options self.default_transaction_options = default_transaction_options + self.session_options = SessionOptions() class _Instance(object): @@ -3158,7 +3146,6 @@ def __init__(self, name, instance=None): self.name = name self.database_id = name.rsplit("/", 1)[1] self._instance = instance - from logging import Logger self.logger = mock.create_autospec(Logger, instance=True) self._directed_read_options = None @@ -3191,6 +3178,7 @@ def __init__( self._database = database self.name = name self._run_transaction_function = run_transaction_function + self.is_multiplexed = False def run_in_transaction(self, func, *args, **kw): if self._run_transaction_function: diff --git 
a/tests/unit/test_database_session_manager.py b/tests/unit/test_database_session_manager.py new file mode 100644 index 0000000000..6019dccb28 --- /dev/null +++ b/tests/unit/test_database_session_manager.py @@ -0,0 +1,277 @@ +# Copyright 2025 Google LLC All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import os +import datetime +import time +from threading import Thread +from unittest import TestCase +from unittest.mock import Mock, DEFAULT, patch + +from google.api_core.exceptions import ( + MethodNotImplemented, + BadRequest, + FailedPrecondition, +) + +from google.cloud.spanner_v1.database_sessions_manager import DatabaseSessionsManager +from google.cloud.spanner_v1.session_options import TransactionType +from tests._helpers import disable_multiplexed_sessions, enable_multiplexed_sessions + + +# Shorten polling and refresh intervals for testing. +@patch.multiple( + DatabaseSessionsManager, + _MAINTENANCE_THREAD_POLLING_INTERVAL=datetime.timedelta(seconds=1), + _MAINTENANCE_THREAD_REFRESH_INTERVAL=datetime.timedelta(seconds=2), +) +class TestDatabaseSessionManager(TestCase): + def setUp(self): + self._original_env = dict(os.environ) + self._build_session_manager() + + def tearDown(self): + self._cleanup_database_session_manager() + os.environ.clear() + os.environ.update(self._original_env) + + def test_read_only_pooled(self): + disable_multiplexed_sessions() + session_manager = self._session_manager + + # Get session from pool. 
+ session = session_manager.get_session(TransactionType.READ_ONLY) + self.assertFalse(session.is_multiplexed) + session_manager._pool.get.assert_called_once() + + # Return session to pool. + session_manager.put_session(session) + session_manager._pool.put.assert_called_once_with(session) + + def test_read_only_multiplexed(self): + enable_multiplexed_sessions() + session_manager = self._session_manager + + # Session is created. + session_1 = session_manager.get_session(TransactionType.READ_ONLY) + self.assertTrue(session_1.is_multiplexed) + session_manager.put_session(session_1) + + # Session is re-used. + session_2 = session_manager.get_session(TransactionType.READ_ONLY) + self.assertEqual(session_1, session_2) + session_manager.put_session(session_2) + + # Verify that pool was not used. + session_manager._pool.get.assert_not_called() + session_manager._pool.put.assert_not_called() + + def test_partitioned_pooled(self): + disable_multiplexed_sessions() + session_manager = self._session_manager + + # Get session from pool. + session = session_manager.get_session(TransactionType.PARTITIONED) + self.assertFalse(session.is_multiplexed) + session_manager._pool.get.assert_called_once() + + # Return session to pool. + session_manager.put_session(session) + session_manager._pool.put.assert_called_once_with(session) + + def test_partitioned_multiplexed(self): + enable_multiplexed_sessions() + session_manager = self._session_manager + + # Session is created. + session_1 = session_manager.get_session(TransactionType.PARTITIONED) + self.assertTrue(session_1.is_multiplexed) + session_manager.put_session(session_1) + + # Session is re-used. + session_2 = session_manager.get_session(TransactionType.PARTITIONED) + self.assertEqual(session_1, session_2) + session_manager.put_session(session_2) + + # Verify that pool was not used. 
+ session_manager._pool.get.assert_not_called() + session_manager._pool.put.assert_not_called() + + def test_read_write_pooled(self): + disable_multiplexed_sessions() + session_manager = self._session_manager + + # Get session from pool. + session = session_manager.get_session(TransactionType.READ_WRITE) + self.assertFalse(session.is_multiplexed) + session_manager._pool.get.assert_called_once() + + # Return session to pool. + session_manager.put_session(session) + session_manager._pool.put.assert_called_once_with(session) + + def test_read_write_multiplexed(self): + enable_multiplexed_sessions() + session_manager = self._session_manager + + with self.assertRaises(NotImplementedError): + session_manager.get_session(TransactionType.READ_WRITE) + + def test_multiplexed_maintenance(self, *_): + enable_multiplexed_sessions() + session_manager = self._session_manager + + # Maintenance thread is started. + session_1 = session_manager.get_session(TransactionType.READ_ONLY) + self.assertTrue(session_1.is_multiplexed) + + # Wait for maintenance thread to execute. + api = session_manager._database.spanner_api + + def create_session_condition(): + return api.create_session.call_count > 1 + + self._assert_true_with_timeout(create_session_condition) + + # Verify that maintenance thread created new multiplexed session. + session_2 = session_manager.get_session(TransactionType.READ_ONLY) + self.assertTrue(session_2.is_multiplexed) + self.assertNotEqual(session_1, session_2) + + def test_multiplexed_maintenance_terminates_not_implemented(self): + enable_multiplexed_sessions() + session_manager = self._session_manager + + # Maintenance thread is started. + session_1 = session_manager.get_session(TransactionType.READ_ONLY) + self.assertTrue(session_1.is_multiplexed) + + # Multiplexed sessions not implemented. + api = session_manager._database.spanner_api + api.create_session.side_effect = MethodNotImplemented("test") + + # Verify that maintenance thread is terminated. 
+ thread = session_manager._multiplexed_session_maintenance_thread + self._assert_thread_terminated(thread) + + # Verify that multiplexed session are disabled. + session_options = session_manager._database.session_options + self.assertFalse(session_options.use_multiplexed(TransactionType.READ_ONLY)) + self.assertFalse(session_options.use_multiplexed(TransactionType.PARTITIONED)) + self.assertFalse(session_options.use_multiplexed(TransactionType.READ_WRITE)) + + def test_multiplexed_maintenance_terminates_disabled(self): + enable_multiplexed_sessions() + session_manager = self._session_manager + + # Maintenance thread is started. + session_1 = session_manager.get_session(TransactionType.READ_ONLY) + self.assertTrue(session_1.is_multiplexed) + + session_manager._is_multiplexed_sessions_disabled_event.set() + + thread = session_manager._multiplexed_session_maintenance_thread + self._assert_thread_terminated(thread) + + def test_multiplexed_exception_method_not_implemented(self): + enable_multiplexed_sessions() + session_manager = self._session_manager + + # Multiplexed sessions not implemented. + api = session_manager._database.spanner_api + api.create_session.side_effect = [ + MethodNotImplemented("Test MethodNotImplemented"), + DEFAULT, + ] + + # Get session from pool. + session = session_manager.get_session(TransactionType.READ_ONLY) + self.assertFalse(session.is_multiplexed) + session_manager._pool.get.assert_called_once() + + # Return session to pool. + session_manager.put_session(session) + session_manager._pool.put.assert_called_once_with(session) + + # Verify that multiplexed session are disabled. 
+ session_options = session_manager._database.session_options + self.assertFalse(session_options.use_multiplexed(TransactionType.READ_ONLY)) + self.assertFalse(session_options.use_multiplexed(TransactionType.PARTITIONED)) + self.assertFalse(session_options.use_multiplexed(TransactionType.READ_WRITE)) + + def test_exception_bad_request(self): + session_manager = self._session_manager + + api = session_manager._database.spanner_api + api.create_session.side_effect = BadRequest("test") + + # Verify that BadRequest is not caught. + with self.assertRaises(BadRequest): + session_manager.get_session(TransactionType.READ_ONLY) + + def test_exception_failed_precondition(self): + session_manager = self._session_manager + + api = session_manager._database.spanner_api + api.create_session.side_effect = FailedPrecondition("test") + + # Verify that FailedPrecondition is not caught. + with self.assertRaises(FailedPrecondition): + session_manager.get_session(TransactionType.READ_ONLY) + + def _build_session_manager(self) -> DatabaseSessionsManager: + """Builds a new database session manager for testing.""" + from tests._builders import build_database + + database = build_database() + session_manager = database._session_manager + + # Mock the session pool. + pool = session_manager._pool + pool.get = Mock(wraps=pool.get) + pool.put = Mock(wraps=pool.put) + + self._session_manager = session_manager + + def _cleanup_database_session_manager(self) -> None: + """Cleans up the database session manager after testing.""" + + # If the maintenance thread is still alive, disable multiplexed sessions and + # wait for the thread to terminate. We need to do this to ensure that the + # thread is properly cleaned up and does not interfere with other tests. 
+ session_manager = self._session_manager + thread = session_manager._multiplexed_session_maintenance_thread + + if thread and thread.is_alive(): + session_manager._is_multiplexed_sessions_disabled_event.set() + self._assert_thread_terminated(thread) + + def _assert_true_with_timeout(self, condition): + """Asserts that the given condition is met within a timeout period.""" + + sleep_seconds = 0.1 + timeout_seconds = 10 + + start_time = time.time() + while not condition() and time.time() - start_time < timeout_seconds: + time.sleep(sleep_seconds) + + self.assertTrue(condition()) + + def _assert_thread_terminated(self, thread: Thread): + """Asserts that the maintenance thread is terminated.""" + + def _is_thread_terminated(): + return not thread.is_alive() + + self._assert_true_with_timeout(_is_thread_terminated) diff --git a/tests/unit/test_instance.py b/tests/unit/test_instance.py index e7ad729438..5810478c54 100644 --- a/tests/unit/test_instance.py +++ b/tests/unit/test_instance.py @@ -539,10 +539,8 @@ def test_database_factory_defaults(self): self.assertEqual(database.database_id, DATABASE_ID) self.assertIs(database._instance, instance) self.assertEqual(list(database.ddl_statements), []) - self.assertIsInstance(database._pool, BurstyPool) - self.assertIsNone(database._logger) - pool = database._pool - self.assertIs(pool._database, database) + self.assertIsInstance(database._session_manager._pool, BurstyPool) + self.assertIs(database._session_manager._pool._database, database) self.assertIsNone(database.database_role) def test_database_factory_explicit(self): @@ -573,7 +571,7 @@ def test_database_factory_explicit(self): self.assertEqual(database.database_id, DATABASE_ID) self.assertIs(database._instance, instance) self.assertEqual(list(database.ddl_statements), DDL_STATEMENTS) - self.assertIs(database._pool, pool) + self.assertIs(database._session_manager._pool, pool) self.assertIs(database._logger, logger) self.assertIs(pool._bound, database) 
self.assertIs(database._encryption_config, encryption_config) diff --git a/tests/unit/test_pool.py b/tests/unit/test_pool.py index a9593b3651..67c2b338f9 100644 --- a/tests/unit/test_pool.py +++ b/tests/unit/test_pool.py @@ -20,6 +20,7 @@ import mock from google.cloud.spanner_v1._opentelemetry_tracing import trace_call +from tests._builders import build_database from tests._helpers import ( OpenTelemetryBase, LIB_VERSION, @@ -29,18 +30,6 @@ ) -def _make_database(name="name"): - from google.cloud.spanner_v1.database import Database - - return mock.create_autospec(Database, instance=True) - - -def _make_session(): - from google.cloud.spanner_v1.database import Session - - return mock.create_autospec(Session, instance=True) - - class TestAbstractSessionPool(unittest.TestCase): def _getTargetClass(self): from google.cloud.spanner_v1.pool import AbstractSessionPool @@ -66,7 +55,7 @@ def test_ctor_explicit(self): def test_bind_abstract(self): pool = self._make_one() - database = _make_database("name") + database = build_database() with self.assertRaises(NotImplementedError): pool.bind(database) @@ -88,58 +77,32 @@ def test_clear_abstract(self): def test__new_session_wo_labels(self): pool = self._make_one() - database = pool._database = _make_database("name") - session = _make_session() - database.session.return_value = session + pool._database = build_database() new_session = pool._new_session() - self.assertIs(new_session, session) - database.session.assert_called_once_with(labels={}, database_role=None) + self.assertEqual(new_session.labels, {}) + self.assertIsNone(new_session.database_role) def test__new_session_w_labels(self): labels = {"foo": "bar"} pool = self._make_one(labels=labels) - database = pool._database = _make_database("name") - session = _make_session() - database.session.return_value = session + pool._database = build_database() new_session = pool._new_session() - self.assertIs(new_session, session) - 
database.session.assert_called_once_with(labels=labels, database_role=None) + self.assertEqual(new_session.labels, labels) + self.assertIsNone(new_session.database_role) def test__new_session_w_database_role(self): database_role = "dummy-role" pool = self._make_one(database_role=database_role) - database = pool._database = _make_database("name") - session = _make_session() - database.session.return_value = session + pool._database = build_database() new_session = pool._new_session() - self.assertIs(new_session, session) - database.session.assert_called_once_with(labels={}, database_role=database_role) - - def test_session_wo_kwargs(self): - from google.cloud.spanner_v1.pool import SessionCheckout - - pool = self._make_one() - checkout = pool.session() - self.assertIsInstance(checkout, SessionCheckout) - self.assertIs(checkout._pool, pool) - self.assertIsNone(checkout._session) - self.assertEqual(checkout._kwargs, {}) - - def test_session_w_kwargs(self): - from google.cloud.spanner_v1.pool import SessionCheckout - - pool = self._make_one() - checkout = pool.session(foo="bar") - self.assertIsInstance(checkout, SessionCheckout) - self.assertIs(checkout._pool, pool) - self.assertIsNone(checkout._session) - self.assertEqual(checkout._kwargs, {"foo": "bar"}) + self.assertEqual(new_session.labels, {}) + self.assertEqual(new_session.database_role, database_role) class TestFixedSizePool(OpenTelemetryBase): @@ -187,10 +150,10 @@ def test_ctor_explicit(self): def test_bind(self): database_role = "dummy-role" pool = self._make_one() - database = _Database("name") - SESSIONS = [_Session(database)] * 10 + database = _Database() database._database_role = database_role - database._sessions.extend(SESSIONS) + sessions = [_Session(database)] * 10 + pool._new_session = mock.Mock(side_effect=sessions) pool.bind(database) @@ -202,37 +165,35 @@ def test_bind(self): api = database.spanner_api self.assertEqual(api.batch_create_sessions.call_count, 5) - for session in SESSIONS: + for 
session in sessions: session.create.assert_not_called() def test_get_active(self): pool = self._make_one(size=4) - database = _Database("name") - SESSIONS = sorted([_Session(database) for i in range(0, 4)]) - database._sessions.extend(SESSIONS) + database = _Database() + sessions = sorted([_Session(database)] * 4) + pool._new_session = mock.Mock(side_effect=sessions) pool.bind(database) # check if sessions returned in LIFO order for i in (3, 2, 1, 0): session = pool.get() - self.assertIs(session, SESSIONS[i]) + self.assertIs(session, sessions[i]) self.assertFalse(session._exists_checked) self.assertFalse(pool._sessions.full()) def test_get_non_expired(self): pool = self._make_one(size=4) - database = _Database("name") + database = _Database() last_use_time = datetime.utcnow() - timedelta(minutes=56) - SESSIONS = sorted( - [_Session(database, last_use_time=last_use_time) for i in range(0, 4)] - ) - database._sessions.extend(SESSIONS) + sessions = sorted([_Session(database, last_use_time=last_use_time)] * 4) + pool._new_session = mock.Mock(side_effect=sessions) pool.bind(database) # check if sessions returned in LIFO order for i in (3, 2, 1, 0): session = pool.get() - self.assertIs(session, SESSIONS[i]) + self.assertIs(session, sessions[i]) self.assertTrue(session._exists_checked) self.assertFalse(pool._sessions.full()) @@ -242,12 +203,12 @@ def test_spans_bind_get(self): # This tests retrieving 1 out of 4 sessions from the session pool. 
pool = self._make_one(size=4) - database = _Database("name") - SESSIONS = sorted([_Session(database) for i in range(0, 4)]) - database._sessions.extend(SESSIONS) + database = _Database() + sessions = sorted([_Session(database) for i in range(0, 4)]) + pool._new_session = mock.Mock(side_effect=sessions) pool.bind(database) - with trace_call("pool.Get", SESSIONS[0]): + with trace_call("pool.Get", sessions[0]): pool.get() span_list = self.get_finished_spans() @@ -284,7 +245,7 @@ def test_spans_bind_get_empty_pool(self): # Tests trying to invoke pool.get() from an empty pool. pool = self._make_one(size=0) - database = _Database("name") + database = _Database() session1 = _Session(database) with trace_call("pool.Get", session1): try: @@ -329,9 +290,8 @@ def test_spans_pool_bind(self): # Tests the exception generated from invoking pool.bind when # you have an empty pool. pool = self._make_one(size=1) - database = _Database("name") - SESSIONS = [] - database._sessions.extend(SESSIONS) + database = _Database() + pool._new_session = mock.Mock(side_effect=Exception("Test")) fauxSession = mock.Mock() setattr(fauxSession, "_database", database) try: @@ -377,8 +337,8 @@ def test_spans_pool_bind(self): ( "exception", { - "exception.type": "IndexError", - "exception.message": "pop from empty list", + "exception.type": "Exception", + "exception.message": "Test", "exception.stacktrace": "EPHEMERAL", "exception.escaped": "False", }, @@ -388,8 +348,8 @@ def test_spans_pool_bind(self): ( "exception", { - "exception.type": "IndexError", - "exception.message": "pop from empty list", + "exception.type": "Exception", + "exception.message": "Test", "exception.stacktrace": "EPHEMERAL", "exception.escaped": "False", }, @@ -399,18 +359,18 @@ def test_spans_pool_bind(self): def test_get_expired(self): pool = self._make_one(size=4) - database = _Database("name") + database = _Database() last_use_time = datetime.utcnow() - timedelta(minutes=65) - SESSIONS = [_Session(database, 
last_use_time=last_use_time)] * 5 - SESSIONS[0]._exists = False - database._sessions.extend(SESSIONS) + sessions = [_Session(database, last_use_time=last_use_time)] * 5 + sessions[0]._exists = False + pool._new_session = mock.Mock(side_effect=sessions) pool.bind(database) session = pool.get() - self.assertIs(session, SESSIONS[4]) + self.assertIs(session, sessions[4]) session.create.assert_called() - self.assertTrue(SESSIONS[0]._exists_checked) + self.assertTrue(sessions[0]._exists_checked) self.assertFalse(pool._sessions.full()) def test_get_empty_default_timeout(self): @@ -439,9 +399,9 @@ def test_put_full(self): import queue pool = self._make_one(size=4) - database = _Database("name") - SESSIONS = [_Session(database)] * 4 - database._sessions.extend(SESSIONS) + database = _Database() + sessions = [_Session(database)] * 4 + database._sessions.extend(sessions) pool.bind(database) self.reset() @@ -452,9 +412,9 @@ def test_put_full(self): def test_put_non_full(self): pool = self._make_one(size=4) - database = _Database("name") - SESSIONS = [_Session(database)] * 4 - database._sessions.extend(SESSIONS) + database = _Database() + sessions = [_Session(database)] * 4 + database._sessions.extend(sessions) pool.bind(database) pool._sessions.get() @@ -464,20 +424,20 @@ def test_put_non_full(self): def test_clear(self): pool = self._make_one() - database = _Database("name") - SESSIONS = [_Session(database)] * 10 - database._sessions.extend(SESSIONS) + database = _Database() + sessions = [_Session(database)] * 10 + pool._new_session = mock.Mock(side_effect=sessions) pool.bind(database) self.assertTrue(pool._sessions.full()) api = database.spanner_api self.assertEqual(api.batch_create_sessions.call_count, 5) - for session in SESSIONS: + for session in sessions: session.create.assert_not_called() pool.clear() - for session in SESSIONS: + for session in sessions: self.assertTrue(session._deleted) @@ -522,15 +482,15 @@ def test_ctor_explicit(self): def 
test_ctor_explicit_w_database_role_in_db(self): database_role = "dummy-role" pool = self._make_one() - database = pool._database = _Database("name") + database = pool._database = _Database() database._database_role = database_role pool.bind(database) self.assertEqual(pool.database_role, database_role) def test_get_empty(self): pool = self._make_one() - database = _Database("name") - database._sessions.append(_Session(database)) + database = _Database() + pool._new_session = mock.Mock(side_effect=[_Session(database)]) pool.bind(database) session = pool.get() @@ -548,9 +508,9 @@ def test_spans_get_empty_pool(self): # and pool.get() acquires from a pool, waiting for a session # to become available. pool = self._make_one() - database = _Database("name") + database = _Database() session1 = _Session(database) - database._sessions.append(session1) + pool._new_session = mock.Mock(side_effect=[session1]) pool.bind(database) with trace_call("pool.Get", session1): @@ -580,7 +540,7 @@ def test_spans_get_empty_pool(self): def test_get_non_empty_session_exists(self): pool = self._make_one() - database = _Database("name") + database = _Database() previous = _Session(database) pool.bind(database) pool.put(previous) @@ -596,7 +556,7 @@ def test_spans_get_non_empty_session_exists(self): # Tests the spans produces when you invoke pool.bind # and then insert a session into the pool. 
pool = self._make_one() - database = _Database("name") + database = _Database() previous = _Session(database) pool.bind(database) with trace_call("pool.Get", previous): @@ -618,10 +578,10 @@ def test_spans_get_non_empty_session_exists(self): def test_get_non_empty_session_expired(self): pool = self._make_one() - database = _Database("name") + database = _Database() previous = _Session(database, exists=False) newborn = _Session(database) - database._sessions.append(newborn) + pool._new_session = mock.Mock(side_effect=[newborn]) pool.bind(database) pool.put(previous) @@ -635,7 +595,7 @@ def test_get_non_empty_session_expired(self): def test_put_empty(self): pool = self._make_one() - database = _Database("name") + database = _Database() pool.bind(database) session = _Session(database) @@ -646,7 +606,7 @@ def test_put_empty(self): def test_spans_put_empty(self): # Tests the spans produced when you put sessions into an empty pool. pool = self._make_one() - database = _Database("name") + database = _Database() pool.bind(database) session = _Session(database) @@ -661,7 +621,7 @@ def test_spans_put_empty(self): def test_put_full(self): pool = self._make_one(target_size=1) - database = _Database("name") + database = _Database() pool.bind(database) older = _Session(database) pool.put(older) @@ -677,7 +637,7 @@ def test_spans_put_full(self): # This scenario tests the spans produced from putting an older # session into a pool that is already full. 
pool = self._make_one(target_size=1) - database = _Database("name") + database = _Database() pool.bind(database) older = _Session(database) with trace_call("pool.put", older): @@ -697,7 +657,7 @@ def test_spans_put_full(self): def test_put_full_expired(self): pool = self._make_one(target_size=1) - database = _Database("name") + database = _Database() pool.bind(database) older = _Session(database) pool.put(older) @@ -711,7 +671,7 @@ def test_put_full_expired(self): def test_clear(self): pool = self._make_one() - database = _Database("name") + database = _Database() pool.bind(database) previous = _Session(database) pool.put(previous) @@ -773,18 +733,18 @@ def test_ctor_explicit(self): def test_ctor_explicit_w_database_role_in_db(self): database_role = "dummy-role" pool = self._make_one() - database = pool._database = _Database("name") - SESSIONS = [_Session(database)] * 10 - database._sessions.extend(SESSIONS) + database = pool._database = _Database() + sessions = [_Session(database)] * 10 + database._sessions.extend(sessions) database._database_role = database_role pool.bind(database) self.assertEqual(pool.database_role, database_role) def test_bind(self): pool = self._make_one() - database = _Database("name") - SESSIONS = [_Session(database)] * 10 - database._sessions.extend(SESSIONS) + database = _Database() + sessions = [_Session(database)] * 10 + database._sessions.extend(sessions) pool.bind(database) self.assertIs(pool._database, database) @@ -795,20 +755,20 @@ def test_bind(self): api = database.spanner_api self.assertEqual(api.batch_create_sessions.call_count, 5) - for session in SESSIONS: + for session in sessions: session.create.assert_not_called() def test_get_hit_no_ping(self): pool = self._make_one(size=4) - database = _Database("name") - SESSIONS = [_Session(database)] * 4 - database._sessions.extend(SESSIONS) + database = _Database() + sessions = [_Session(database)] * 4 + pool._new_session = mock.Mock(side_effect=sessions) pool.bind(database) 
self.reset() session = pool.get() - self.assertIs(session, SESSIONS[0]) + self.assertIs(session, sessions[0]) self.assertFalse(session._exists_checked) self.assertFalse(pool._sessions.full()) self.assertNoSpans() @@ -819,9 +779,9 @@ def test_get_hit_w_ping(self): from google.cloud.spanner_v1 import pool as MUT pool = self._make_one(size=4) - database = _Database("name") - SESSIONS = [_Session(database)] * 4 - database._sessions.extend(SESSIONS) + database = _Database() + sessions = [_Session(database)] * 4 + pool._new_session = mock.Mock(side_effect=sessions) sessions_created = datetime.datetime.utcnow() - datetime.timedelta(seconds=4000) @@ -832,7 +792,7 @@ def test_get_hit_w_ping(self): session = pool.get() - self.assertIs(session, SESSIONS[0]) + self.assertIs(session, sessions[0]) self.assertTrue(session._exists_checked) self.assertFalse(pool._sessions.full()) self.assertNoSpans() @@ -843,10 +803,10 @@ def test_get_hit_w_ping_expired(self): from google.cloud.spanner_v1 import pool as MUT pool = self._make_one(size=4) - database = _Database("name") - SESSIONS = [_Session(database)] * 5 - SESSIONS[0]._exists = False - database._sessions.extend(SESSIONS) + database = _Database() + sessions = [_Session(database)] * 5 + sessions[0]._exists = False + pool._new_session = mock.Mock(side_effect=sessions) sessions_created = datetime.datetime.utcnow() - datetime.timedelta(seconds=4000) @@ -856,9 +816,9 @@ def test_get_hit_w_ping_expired(self): session = pool.get() - self.assertIs(session, SESSIONS[4]) + self.assertIs(session, sessions[4]) session.create.assert_called() - self.assertTrue(SESSIONS[0]._exists_checked) + self.assertTrue(sessions[0]._exists_checked) self.assertFalse(pool._sessions.full()) self.assertNoSpans() @@ -890,9 +850,9 @@ def test_put_full(self): import queue pool = self._make_one(size=4) - database = _Database("name") - SESSIONS = [_Session(database)] * 4 - database._sessions.extend(SESSIONS) + database = _Database() + sessions = [_Session(database)] * 
4 + database._sessions.extend(sessions) pool.bind(database) with self.assertRaises(queue.Full): @@ -907,9 +867,9 @@ def test_spans_put_full(self): import queue pool = self._make_one(size=4) - database = _Database("name") - SESSIONS = [_Session(database)] * 4 - database._sessions.extend(SESSIONS) + database = _Database() + sessions = [_Session(database)] * 4 + database._sessions.extend(sessions) pool.bind(database) with self.assertRaises(queue.Full): @@ -946,7 +906,7 @@ def test_put_non_full(self): session_queue = pool._sessions = _Queue() now = datetime.datetime.utcnow() - database = _Database("name") + database = _Database() session = _Session(database) with _Monkey(MUT, _NOW=lambda: now): @@ -960,21 +920,21 @@ def test_put_non_full(self): def test_clear(self): pool = self._make_one() - database = _Database("name") - SESSIONS = [_Session(database)] * 10 - database._sessions.extend(SESSIONS) + database = _Database() + sessions = [_Session(database)] * 10 + pool._new_session = mock.Mock(side_effect=sessions) pool.bind(database) self.reset() self.assertTrue(pool._sessions.full()) api = database.spanner_api self.assertEqual(api.batch_create_sessions.call_count, 5) - for session in SESSIONS: + for session in sessions: session.create.assert_not_called() pool.clear() - for session in SESSIONS: + for session in sessions: self.assertTrue(session._deleted) self.assertNoSpans() @@ -985,15 +945,15 @@ def test_ping_empty(self): def test_ping_oldest_fresh(self): pool = self._make_one(size=1) - database = _Database("name") - SESSIONS = [_Session(database)] * 1 - database._sessions.extend(SESSIONS) + database = _Database() + sessions = [_Session(database)] * 1 + database._sessions.extend(sessions) pool.bind(database) self.reset() pool.ping() - self.assertFalse(SESSIONS[0]._pinged) + self.assertFalse(sessions[0]._pinged) self.assertNoSpans() def test_ping_oldest_stale_but_exists(self): @@ -1002,16 +962,16 @@ def test_ping_oldest_stale_but_exists(self): from google.cloud.spanner_v1 
import pool as MUT pool = self._make_one(size=1) - database = _Database("name") - SESSIONS = [_Session(database)] * 1 - database._sessions.extend(SESSIONS) + database = _Database() + sessions = [_Session(database)] * 1 + pool._new_session = mock.Mock(side_effect=sessions) pool.bind(database) later = datetime.datetime.utcnow() + datetime.timedelta(seconds=4000) with _Monkey(MUT, _NOW=lambda: later): pool.ping() - self.assertTrue(SESSIONS[0]._pinged) + self.assertTrue(sessions[0]._pinged) def test_ping_oldest_stale_and_not_exists(self): import datetime @@ -1019,10 +979,10 @@ def test_ping_oldest_stale_and_not_exists(self): from google.cloud.spanner_v1 import pool as MUT pool = self._make_one(size=1) - database = _Database("name") - SESSIONS = [_Session(database)] * 2 - SESSIONS[0]._exists = False - database._sessions.extend(SESSIONS) + database = _Database() + sessions = [_Session(database)] * 2 + sessions[0]._exists = False + pool._new_session = mock.Mock(side_effect=sessions) pool.bind(database) self.reset() @@ -1030,8 +990,8 @@ def test_ping_oldest_stale_and_not_exists(self): with _Monkey(MUT, _NOW=lambda: later): pool.ping() - self.assertTrue(SESSIONS[0]._pinged) - SESSIONS[1].create.assert_called() + self.assertTrue(sessions[0]._pinged) + sessions[1].create.assert_called() self.assertNoSpans() def test_spans_get_and_leave_empty_pool(self): @@ -1041,9 +1001,9 @@ def test_spans_get_and_leave_empty_pool(self): # This scenario tests the spans generated from pulling a span # out the pool and leaving it empty. 
pool = self._make_one() - database = _Database("name") + database = _Database() session1 = _Session(database) - database._sessions.append(session1) + pool._new_session = mock.Mock(side_effect=[session1]) try: pool.bind(database) except Exception: @@ -1073,73 +1033,6 @@ def test_spans_get_and_leave_empty_pool(self): self.assertSpanEvents("pool.Get", wantEventNames, span_list[-1]) -class TestSessionCheckout(unittest.TestCase): - def _getTargetClass(self): - from google.cloud.spanner_v1.pool import SessionCheckout - - return SessionCheckout - - def _make_one(self, *args, **kwargs): - return self._getTargetClass()(*args, **kwargs) - - def test_ctor_wo_kwargs(self): - pool = _Pool() - checkout = self._make_one(pool) - self.assertIs(checkout._pool, pool) - self.assertIsNone(checkout._session) - self.assertEqual(checkout._kwargs, {}) - - def test_ctor_w_kwargs(self): - pool = _Pool() - checkout = self._make_one(pool, foo="bar", database_role="dummy-role") - self.assertIs(checkout._pool, pool) - self.assertIsNone(checkout._session) - self.assertEqual( - checkout._kwargs, {"foo": "bar", "database_role": "dummy-role"} - ) - - def test_context_manager_wo_kwargs(self): - session = object() - pool = _Pool(session) - checkout = self._make_one(pool) - - self.assertEqual(len(pool._items), 1) - self.assertIs(pool._items[0], session) - - with checkout as borrowed: - self.assertIs(borrowed, session) - self.assertEqual(len(pool._items), 0) - - self.assertEqual(len(pool._items), 1) - self.assertIs(pool._items[0], session) - self.assertEqual(pool._got, {}) - - def test_context_manager_w_kwargs(self): - session = object() - pool = _Pool(session) - checkout = self._make_one(pool, foo="bar") - - self.assertEqual(len(pool._items), 1) - self.assertIs(pool._items[0], session) - - with checkout as borrowed: - self.assertIs(borrowed, session) - self.assertEqual(len(pool._items), 0) - - self.assertEqual(len(pool._items), 1) - self.assertIs(pool._items[0], session) - self.assertEqual(pool._got, 
{"foo": "bar"}) - - -def _make_transaction(*args, **kw): - from google.cloud.spanner_v1.transaction import Transaction - - txn = mock.create_autospec(Transaction)(*args, **kw) - txn.committed = None - txn.rolled_back = False - return txn - - @total_ordering class _Session(object): _transaction = None @@ -1183,21 +1076,13 @@ def delete(self): if not self._exists: raise NotFound("unknown session") - def transaction(self): - txn = self._transaction = _make_transaction(self) - return txn - - @property - def session_id(self): - return self._session_id - class _Database(object): - def __init__(self, name): - self.name = name + def __init__(self): + self.name = "name" self._sessions = [] self._database_role = None - self.database_id = name + self.database_id = self.name self._route_to_leader_enabled = True def mock_batch_create_sessions( @@ -1276,7 +1161,3 @@ def put(self, item, **kwargs): def put_nowait(self, item, **kwargs): self._put_nowait = kwargs self._items.append(item) - - -class _Pool(_Queue): - _database = None diff --git a/tests/unit/test_session.py b/tests/unit/test_session.py index 8f5f7039b9..a52abdac3d 100644 --- a/tests/unit/test_session.py +++ b/tests/unit/test_session.py @@ -99,8 +99,13 @@ def _make_database( return database @staticmethod - def _make_session_pb(name, labels=None, database_role=None): - return SessionRequestProto(name=name, labels=labels, creator_role=database_role) + def _make_session_pb(name=None, multiplexed=False, labels=None, database_role=None): + return SessionRequestProto( + name=name, + multiplexed=multiplexed, + labels=labels, + creator_role=database_role, + ) def _make_spanner_api(self): return mock.Mock(autospec=SpannerClient, instance=True) @@ -167,7 +172,7 @@ def test_create_w_session_id(self): def test_create_w_database_role(self): session_pb = self._make_session_pb( - self.SESSION_NAME, database_role=self.DATABASE_ROLE + name=self.SESSION_NAME, database_role=self.DATABASE_ROLE ) gax_api = self._make_spanner_api() 
gax_api.create_session.return_value = session_pb @@ -200,7 +205,7 @@ def test_create_w_database_role(self): def test_create_session_span_annotations(self): session_pb = self._make_session_pb( - self.SESSION_NAME, database_role=self.DATABASE_ROLE + name=self.SESSION_NAME, database_role=self.DATABASE_ROLE ) gax_api = self._make_spanner_api() @@ -233,7 +238,7 @@ def test_create_session_span_annotations(self): self.assertSpanEvents("TestSessionSpan", wantEventNames, span) def test_create_wo_database_role(self): - session_pb = self._make_session_pb(self.SESSION_NAME) + session_pb = self._make_session_pb(name=self.SESSION_NAME) gax_api = self._make_spanner_api() gax_api.create_session.return_value = session_pb database = self._make_database() @@ -245,7 +250,7 @@ def test_create_wo_database_role(self): self.assertIsNone(session.database_role) request = CreateSessionRequest( - database=database.name, + database=database.name, session=self._make_session_pb() ) gax_api.create_session.assert_called_once_with( @@ -261,7 +266,7 @@ def test_create_wo_database_role(self): ) def test_create_ok(self): - session_pb = self._make_session_pb(self.SESSION_NAME) + session_pb = self._make_session_pb(name=self.SESSION_NAME) gax_api = self._make_spanner_api() gax_api.create_session.return_value = session_pb database = self._make_database() @@ -274,6 +279,7 @@ def test_create_ok(self): request = CreateSessionRequest( database=database.name, + session=self._make_session_pb(), ) gax_api.create_session.assert_called_once_with( @@ -290,7 +296,7 @@ def test_create_ok(self): def test_create_w_labels(self): labels = {"foo": "bar"} - session_pb = self._make_session_pb(self.SESSION_NAME, labels=labels) + session_pb = self._make_session_pb(name=self.SESSION_NAME, labels=labels) gax_api = self._make_spanner_api() gax_api.create_session.return_value = session_pb database = self._make_database() @@ -343,7 +349,7 @@ def test_exists_wo_session_id(self): self.assertNoSpans() def test_exists_hit(self): - 
session_pb = self._make_session_pb(self.SESSION_NAME) + session_pb = self._make_session_pb(name=self.SESSION_NAME) gax_api = self._make_spanner_api() gax_api.get_session.return_value = session_pb database = self._make_database() @@ -371,7 +377,7 @@ def test_exists_hit(self): False, ) def test_exists_hit_wo_span(self): - session_pb = self._make_session_pb(self.SESSION_NAME) + session_pb = self._make_session_pb(name=self.SESSION_NAME) gax_api = self._make_spanner_api() gax_api.get_session.return_value = session_pb database = self._make_database() diff --git a/tests/unit/test_session_options.py b/tests/unit/test_session_options.py new file mode 100644 index 0000000000..4e478a816d --- /dev/null +++ b/tests/unit/test_session_options.py @@ -0,0 +1,134 @@ +# Copyright 2025 Google LLC All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import os +from logging import Logger +from unittest import TestCase + +from google.cloud.spanner_v1.session_options import SessionOptions, TransactionType + + +class TestSessionOptions(TestCase): + _logger = Logger("test_session_options_logger") + + @classmethod + def setUpClass(cls): + # Save the original environment variables. + cls._original_env = dict(os.environ) + + @classmethod + def tearDownClass(cls): + # Restore environment variables. 
+ os.environ.clear() + os.environ.update(cls._original_env) + + def test_use_multiplexed_for_read_only(self): + session_options = SessionOptions() + transaction_type = TransactionType.READ_ONLY + + os.environ[SessionOptions.ENV_VAR_ENABLE_MULTIPLEXED] = "false" + self.assertFalse(session_options.use_multiplexed(transaction_type)) + + os.environ[SessionOptions.ENV_VAR_ENABLE_MULTIPLEXED] = "true" + os.environ[SessionOptions.ENV_VAR_FORCE_DISABLE_MULTIPLEXED] = "true" + self.assertFalse(session_options.use_multiplexed(transaction_type)) + + os.environ[SessionOptions.ENV_VAR_FORCE_DISABLE_MULTIPLEXED] = "false" + self.assertTrue(session_options.use_multiplexed(transaction_type)) + + session_options.disable_multiplexed(self._logger, transaction_type) + self.assertFalse(session_options.use_multiplexed(transaction_type)) + + def test_use_multiplexed_for_partitioned(self): + session_options = SessionOptions() + transaction_type = TransactionType.PARTITIONED + + os.environ[SessionOptions.ENV_VAR_ENABLE_MULTIPLEXED] = "false" + self.assertFalse(session_options.use_multiplexed(transaction_type)) + + os.environ[SessionOptions.ENV_VAR_ENABLE_MULTIPLEXED] = "true" + os.environ[SessionOptions.ENV_VAR_ENABLE_MULTIPLEXED_FOR_PARTITIONED] = "false" + self.assertFalse(session_options.use_multiplexed(transaction_type)) + + os.environ[SessionOptions.ENV_VAR_ENABLE_MULTIPLEXED_FOR_PARTITIONED] = "true" + os.environ[SessionOptions.ENV_VAR_FORCE_DISABLE_MULTIPLEXED] = "true" + self.assertFalse(session_options.use_multiplexed(transaction_type)) + + os.environ[SessionOptions.ENV_VAR_FORCE_DISABLE_MULTIPLEXED] = "false" + self.assertTrue(session_options.use_multiplexed(transaction_type)) + + session_options.disable_multiplexed(self._logger, transaction_type) + self.assertFalse(session_options.use_multiplexed(transaction_type)) + + def test_use_multiplexed_for_read_write(self): + session_options = SessionOptions() + transaction_type = TransactionType.READ_WRITE + + 
os.environ[SessionOptions.ENV_VAR_ENABLE_MULTIPLEXED] = "false" + self.assertFalse(session_options.use_multiplexed(transaction_type)) + + os.environ[SessionOptions.ENV_VAR_ENABLE_MULTIPLEXED] = "true" + os.environ[SessionOptions.ENV_VAR_ENABLE_MULTIPLEXED_FOR_READ_WRITE] = "false" + self.assertFalse(session_options.use_multiplexed(transaction_type)) + + os.environ[SessionOptions.ENV_VAR_ENABLE_MULTIPLEXED_FOR_READ_WRITE] = "true" + os.environ[SessionOptions.ENV_VAR_FORCE_DISABLE_MULTIPLEXED] = "true" + self.assertFalse(session_options.use_multiplexed(transaction_type)) + + os.environ[SessionOptions.ENV_VAR_FORCE_DISABLE_MULTIPLEXED] = "false" + self.assertTrue(session_options.use_multiplexed(transaction_type)) + + session_options.disable_multiplexed(self._logger, transaction_type) + self.assertFalse(session_options.use_multiplexed(transaction_type)) + + def test_disable_multiplexed_all(self): + session_options = SessionOptions() + + os.environ[SessionOptions.ENV_VAR_ENABLE_MULTIPLEXED] = "true" + os.environ[SessionOptions.ENV_VAR_ENABLE_MULTIPLEXED_FOR_PARTITIONED] = "true" + os.environ[SessionOptions.ENV_VAR_ENABLE_MULTIPLEXED_FOR_READ_WRITE] = "true" + os.environ[SessionOptions.ENV_VAR_FORCE_DISABLE_MULTIPLEXED] = "false" + + session_options.disable_multiplexed(self._logger) + + self.assertFalse(session_options.use_multiplexed(TransactionType.READ_ONLY)) + self.assertFalse(session_options.use_multiplexed(TransactionType.PARTITIONED)) + self.assertFalse(session_options.use_multiplexed(TransactionType.READ_WRITE)) + + def test_unsupported_transaction_type(self): + session_options = SessionOptions() + unsupported_type = "UNSUPPORTED_TRANSACTION_TYPE" + + with self.assertRaises(ValueError): + session_options.use_multiplexed(unsupported_type) + + with self.assertRaises(ValueError): + session_options.disable_multiplexed(self._logger, unsupported_type) + + def test_env_var_values(self): + session_options = SessionOptions() + + 
os.environ[SessionOptions.ENV_VAR_FORCE_DISABLE_MULTIPLEXED] = "false"
+
+        true_values = ["1", " 1", "true", "True", "TRUE", " true "]
+        for value in true_values:
+            os.environ[SessionOptions.ENV_VAR_ENABLE_MULTIPLEXED] = value
+            self.assertTrue(session_options.use_multiplexed(TransactionType.READ_ONLY))
+
+        false_values = ["", "0", "false", "False", "FALSE"]
+        for value in false_values:
+            os.environ[SessionOptions.ENV_VAR_ENABLE_MULTIPLEXED] = value
+            self.assertFalse(session_options.use_multiplexed(TransactionType.READ_ONLY))
+
+        del os.environ[SessionOptions.ENV_VAR_ENABLE_MULTIPLEXED]
+        self.assertFalse(session_options.use_multiplexed(TransactionType.READ_ONLY))
diff --git a/tests/unit/test_snapshot.py b/tests/unit/test_snapshot.py
index 11fc0135d1..0e900acc37 100644
--- a/tests/unit/test_snapshot.py
+++ b/tests/unit/test_snapshot.py
@@ -11,21 +11,27 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
- - from google.api_core import gapic_v1 import mock -from google.cloud.spanner_v1 import RequestOptions, DirectedReadOptions +from google.cloud.spanner_v1 import ( + RequestOptions, + DirectedReadOptions, + SpannerClient, + KeySet, +) +from google.cloud.spanner_v1.session_options import TransactionType from tests._helpers import ( OpenTelemetryBase, LIB_VERSION, StatusCode, HAS_OPENTELEMETRY_INSTALLED, enrich_with_otel_scope, + enable_multiplexed_sessions, ) from google.cloud.spanner_v1.param_types import INT64 from google.api_core.retry import Retry +from tests._builders import build_session TABLE_NAME = "citizens" COLUMNS = ["email", "first_name", "last_name", "age"] @@ -79,7 +85,7 @@ def _makeTimestamp(): import datetime from google.cloud._helpers import UTC - return datetime.datetime.utcnow().replace(tzinfo=UTC) + return datetime.datetime.now(tz=UTC) class Test_restart_on_unavailable(OpenTelemetryBase): @@ -111,8 +117,6 @@ def _make_txn_selector(self): return _Derived(session) def _make_spanner_api(self): - from google.cloud.spanner_v1 import SpannerClient - return mock.create_autospec(SpannerClient, instance=True) def _call_fut( @@ -128,12 +132,12 @@ def _call_fut( from google.cloud.spanner_v1.snapshot import _restart_on_unavailable return _restart_on_unavailable( - restart, - request, - metadata, - span_name, - session, - attributes, + method=restart, + request=request, + session=session, + metadata=metadata, + trace_name=span_name, + attributes=attributes, transaction=derived, ) @@ -594,8 +598,6 @@ def _make_txn_selector(self): return _Derived(session) def _make_spanner_api(self): - from google.cloud.spanner_v1 import SpannerClient - return mock.create_autospec(SpannerClient, instance=True) def test_ctor(self): @@ -612,9 +614,25 @@ def test__make_txn_selector_virtual(self): with self.assertRaises(NotImplementedError): base._make_txn_selector() - def test_read_other_error(self): - from google.cloud.spanner_v1.keyset import KeySet + def 
test_read_partitioned_not_implemented_for_multiplexed(self): + enable_multiplexed_sessions() + + database = ( + self._build_database_with_partitioned_not_implemented_for_multiplexed() + ) + + session = build_session(database=database) + session.create() + derived = self._makeDerived(session) + + with self.assertRaises(NotImplementedError): + list(derived.read(TABLE_NAME, COLUMNS, KeySet(all_=True))) + + self.assertFalse( + database.session_options.use_multiplexed(TransactionType.READ_ONLY) + ) + def test_read_other_error(self): keyset = KeySet(all_=True) database = _Database() database.spanner_api = self._make_spanner_api() @@ -658,7 +676,6 @@ def _read_helper( from google.cloud.spanner_v1 import ReadRequest from google.cloud.spanner_v1 import Type, StructType from google.cloud.spanner_v1 import TypeCode - from google.cloud.spanner_v1.keyset import KeySet from google.cloud.spanner_v1._helpers import _make_value_pb VALUES = [["bharney", 31], ["phred", 32]] @@ -865,6 +882,24 @@ def test_read_w_directed_read_options_override(self): directed_read_options_at_client_level=DIRECTED_READ_OPTIONS_FOR_CLIENT, ) + def test_execute_sql_partitioned_not_implemented_for_multiplexed(self): + enable_multiplexed_sessions() + + database = ( + self._build_database_with_partitioned_not_implemented_for_multiplexed() + ) + + session = build_session(database=database) + session.create() + derived = self._makeDerived(session) + + with self.assertRaises(NotImplementedError): + list(derived.execute_sql(SQL_QUERY)) + + self.assertFalse( + database.session_options.use_multiplexed(TransactionType.READ_ONLY) + ) + def test_execute_sql_other_error(self): database = _Database() database.spanner_api = self._make_spanner_api() @@ -1137,7 +1172,6 @@ def _partition_read_helper( retry=gapic_v1.method.DEFAULT, timeout=gapic_v1.method.DEFAULT, ): - from google.cloud.spanner_v1.keyset import KeySet from google.cloud.spanner_v1 import Partition from google.cloud.spanner_v1 import PartitionOptions from 
google.cloud.spanner_v1 import PartitionReadRequest @@ -1225,9 +1259,27 @@ def test_partition_read_wo_existing_transaction_raises(self): with self.assertRaises(ValueError): self._partition_read_helper(multi_use=True, w_txn=False) - def test_partition_read_other_error(self): - from google.cloud.spanner_v1.keyset import KeySet + def test_partition_read_multiplexed_not_implemented_error(self): + enable_multiplexed_sessions() + database = ( + self._build_database_with_partitioned_not_implemented_for_multiplexed() + ) + + session = build_session(database=database) + session.create() + derived = self._makeDerived(session) + derived._multi_use = True + derived._transaction_id = TXN_ID + + with self.assertRaises(NotImplementedError): + list(derived.partition_read(TABLE_NAME, COLUMNS, KeySet(all_=True))) + + self.assertFalse( + database.session_options.use_multiplexed(TransactionType.READ_ONLY) + ) + + def test_partition_read_other_error(self): keyset = KeySet(all_=True) database = _Database() database.spanner_api = self._make_spanner_api() @@ -1249,7 +1301,6 @@ def test_partition_read_other_error(self): ) def test_partition_read_w_retry(self): - from google.cloud.spanner_v1.keyset import KeySet from google.api_core.exceptions import InternalServerError from google.cloud.spanner_v1 import Partition from google.cloud.spanner_v1 import PartitionResponse @@ -1389,6 +1440,26 @@ def _partition_query_helper( attributes=dict(BASE_ATTRIBUTES, **{"db.statement": SQL_QUERY_WITH_PARAM}), ) + def test_partition_query_partitioned_not_implemented_for_multiplexed(self): + enable_multiplexed_sessions() + + database = ( + self._build_database_with_partitioned_not_implemented_for_multiplexed() + ) + + session = build_session(database=database) + session.create() + derived = self._makeDerived(session) + derived._multi_use = True + derived._transaction_id = TXN_ID + + with self.assertRaises(NotImplementedError): + list(derived.partition_query(SQL_QUERY_WITH_PARAM, PARAMS, PARAM_TYPES)) + + 
self.assertFalse(
+            database.session_options.use_multiplexed(TransactionType.READ_ONLY)
+        )
+
     def test_partition_query_other_error(self):
         database = _Database()
         database.spanner_api = self._make_spanner_api()
@@ -1437,6 +1508,25 @@ def test_partition_query_ok_w_timeout_and_retry_params(self):
             multi_use=True, w_txn=True, retry=Retry(deadline=60), timeout=2.0
         )
 
+    @staticmethod
+    def _build_database_with_partitioned_not_implemented_for_multiplexed():
+        """Builds and returns a test database whose Spanner API mocks raise
+        NotImplementedError for partitioned operations with multiplexed sessions."""
+        from tests._builders import build_database
+
+        database = build_database()
+
+        error_msg = "Partitioned operations are not supported with multiplexed sessions"
+        not_implemented_error = NotImplementedError(error_msg)
+
+        api = database.spanner_api
+        api.streaming_read.side_effect = not_implemented_error
+        api.execute_streaming_sql.side_effect = not_implemented_error
+        api.partition_read.side_effect = not_implemented_error
+        api.partition_query.side_effect = not_implemented_error
+
+        return database
+
 
 class TestSnapshot(OpenTelemetryBase):
     PROJECT_ID = "project-id"
@@ -1456,8 +1546,6 @@ def _make_one(self, *args, **kwargs):
         return self._getTargetClass()(*args, **kwargs)
 
     def _make_spanner_api(self):
-        from google.cloud.spanner_v1 import SpannerClient
-
         return mock.create_autospec(SpannerClient, instance=True)
 
     def _makeDuration(self, seconds=1, microseconds=0):