diff --git a/google/cloud/spanner_dbapi/connection.py b/google/cloud/spanner_dbapi/connection.py index 4796c2fc76..6cf225bbea 100644 --- a/google/cloud/spanner_dbapi/connection.py +++ b/google/cloud/spanner_dbapi/connection.py @@ -26,6 +26,7 @@ from google.cloud.spanner_dbapi.transaction_helper import TransactionRetryHelper from google.cloud.spanner_dbapi.cursor import Cursor from google.cloud.spanner_v1 import RequestOptions, TransactionOptions +from google.cloud.spanner_v1.session_options import TransactionType from google.cloud.spanner_v1.snapshot import Snapshot from google.cloud.spanner_dbapi.exceptions import ( @@ -356,11 +357,12 @@ def _session_checkout(self): raise ValueError("Database needs to be passed for this operation") if not self._session: - self._session = ( - self.database._session_manager.get_session_for_read_only() + transaction_type = ( + TransactionType.READ_ONLY if self.read_only - else self.database._session_manager.get_session_for_read_write() + else TransactionType.READ_WRITE ) + self._session = self.database._session_manager.get_session(transaction_type) return self._session @@ -628,7 +630,6 @@ def partition_query( self._partitioned_query_validation(partitioned_query, statement) batch_snapshot = self._database.batch_snapshot() - partition_ids = [] partitions = list( batch_snapshot.generate_query_batches( partitioned_query, @@ -639,6 +640,8 @@ def partition_query( ) batch_transaction_id = batch_snapshot.get_batch_transaction_id() + + partition_ids = [] for partition in partitions: partition_ids.append( partition_helper.encode_to_string(batch_transaction_id, partition) diff --git a/google/cloud/spanner_v1/database.py b/google/cloud/spanner_v1/database.py index 098bdc0730..93d9c1a31c 100644 --- a/google/cloud/spanner_v1/database.py +++ b/google/cloud/spanner_v1/database.py @@ -41,6 +41,8 @@ from google.cloud.spanner_admin_database_v1 import UpdateDatabaseDdlRequest from google.cloud.spanner_admin_database_v1.types import DatabaseDialect from 
google.cloud.spanner_v1.database_sessions_manager import DatabaseSessionsManager +from google.cloud.spanner_v1.session import Session +from google.cloud.spanner_v1.session_options import SessionOptions, TransactionType from google.cloud.spanner_v1.transaction import BatchTransactionId from google.cloud.spanner_v1 import ExecuteSqlRequest from google.cloud.spanner_v1 import Type @@ -60,7 +62,6 @@ from google.cloud.spanner_v1.keyset import KeySet from google.cloud.spanner_v1.merged_result_set import MergedResultSet from google.cloud.spanner_v1.pool import BurstyPool -from google.cloud.spanner_v1.session import Session from google.cloud.spanner_v1.snapshot import _restart_on_unavailable from google.cloud.spanner_v1.snapshot import Snapshot from google.cloud.spanner_v1.streamed import StreamedResultSet @@ -74,7 +75,6 @@ ) from google.cloud.spanner_v1.metrics.metrics_capture import MetricsCapture - SPANNER_DATA_SCOPE = "https://www.googleapis.com/auth/spanner.data" @@ -192,9 +192,7 @@ def __init__( pool = BurstyPool(database_role=database_role) pool.bind(self) - self._session_manager = DatabaseSessionsManager( - database=self, pool=pool, logger=self.logger - ) + self._session_manager = DatabaseSessionsManager(database=self, pool=pool) @classmethod def from_pb(cls, database_pb, instance, pool=None): @@ -449,6 +447,15 @@ def spanner_api(self): ) return self._spanner_api + @property + def session_options(self) -> SessionOptions: + """Session options for the database. 
+ + :rtype: :class:`~google.cloud.spanner_v1.session_options.SessionOptions` + :returns: the session options + """ + return self._instance._client.session_options + def __eq__(self, other): if not isinstance(other, self.__class__): return NotImplemented @@ -709,11 +716,27 @@ def execute_pdml(): "CloudSpanner.Database.execute_partitioned_pdml", observability_options=self.observability_options, ) as span, MetricsCapture(): - with SessionCheckout(self) as session: + transaction_type = TransactionType.PARTITIONED + with SessionCheckout(self, transaction_type) as session: add_span_event(span, "Starting BeginTransaction") - txn = api.begin_transaction( - session=session.name, options=txn_options, metadata=metadata - ) + + try: + txn = api.begin_transaction( + session=session.name, options=txn_options, metadata=metadata + ) + + # If partitioned DML is not supported with multiplexed sessions, + # disable multiplexed sessions for partitioned transactions before + # re-raising the error. + except NotImplementedError as exc: + if ( + "Transaction type partitioned_dml not supported with multiplexed sessions" + in str(exc) + ): + self.session_options.disable_multiplexed( + self.logger, transaction_type + ) + raise exc txn_selector = TransactionSelector(id=txn.id) @@ -732,8 +755,9 @@ def execute_pdml(): iterator = _restart_on_unavailable( method=method, - trace_name="CloudSpanner.ExecuteStreamingSql", request=request, + session=session, + trace_name="CloudSpanner.ExecuteStreamingSql", metadata=metadata, transaction_selector=txn_selector, observability_options=self.observability_options, @@ -746,23 +770,6 @@ def execute_pdml(): return _retry_on_aborted(execute_pdml, DEFAULT_RETRY_BACKOFF)() - def session(self, labels=None, database_role=None): - """Factory to create a session for this database. - - :type labels: dict (str -> str) or None - :param labels: (Optional) user-assigned labels for the session. 
- - :type database_role: str - :param database_role: (Optional) user-assigned database_role for the session. - - :rtype: :class:`~google.cloud.spanner_v1.session.Session` - :returns: a session bound to this database. - """ - # If role is specified in param, then that role is used - # instead. - role = database_role or self._database_role - return Session(self, labels=labels, database_role=role) - def snapshot(self, **kw): """Return an object which wraps a snapshot. @@ -1170,7 +1177,11 @@ class SessionCheckout(object): _session = None # Not checked out until '__enter__'. - def __init__(self, database): + def __init__( + self, + database: Database, + transaction_type: TransactionType = TransactionType.READ_WRITE, + ): if not isinstance(database, Database): raise TypeError( "{class_name} must receive an instance of {expected_class_name}. Received: {actual_class_name}".format( @@ -1180,10 +1191,21 @@ def __init__(self, database): ) ) + if not isinstance(transaction_type, TransactionType): + raise TypeError( + "{class_name} must receive an instance of {expected_class_name}. Received: {actual_class_name}".format( + class_name=self.__class__.__name__, + expected_class_name=TransactionType.__name__, + actual_class_name=transaction_type.__class__.__name__, + ) + ) + self._database = database + self._transaction_type = transaction_type def __enter__(self): - self._session = self._database._session_manager.get_session_for_read_write() + session_manager = self._database._session_manager + self._session = session_manager.get_session(self._transaction_type) return self._session def __exit__(self, *ignored): @@ -1248,7 +1270,13 @@ def __init__( def __enter__(self): """Begin ``with`` block.""" - self._session = self._database._session_manager.get_session_for_read_only() + + # Batch transactions are performed as blind writes, + # which are treated as read-only transactions. 
+ self._session = self._database._session_manager.get_session( + TransactionType.READ_ONLY + ) + batch = self._batch = Batch(self._session) if self._request_options.transaction_tag: batch.transaction_tag = self._request_options.transaction_tag @@ -1303,7 +1331,9 @@ def __init__(self, database): def __enter__(self): """Begin ``with`` block.""" - self._session = self._database._session_manager.get_session_for_read_write() + self._session = self._database._session_manager.get_session( + TransactionType.READ_WRITE + ) return MutationGroups(self._session) def __exit__(self, exc_type, exc_val, exc_tb): @@ -1345,7 +1375,9 @@ def __init__(self, database, **kw): def __enter__(self): """Begin ``with`` block.""" - self._session = self._database._session_manager.get_session_for_read_only() + self._session = self._database._session_manager.get_session( + TransactionType.READ_ONLY + ) return Snapshot(self._session, **self._kw) def __exit__(self, exc_type, exc_val, exc_tb): @@ -1395,11 +1427,15 @@ def from_dict(cls, database, mapping): :rtype: :class:`BatchSnapshot` """ + instance = cls(database) - session = instance._session = database.session() - session._session_id = mapping["session_id"] + + session = instance._session = Session(database=database) + instance._session_id = session._session_id = mapping["session_id"] + snapshot = instance._snapshot = session.snapshot() - snapshot._transaction_id = mapping["transaction_id"] + instance._transaction_id = snapshot._transaction_id = mapping["transaction_id"] + return instance def to_dict(self): @@ -1408,10 +1444,15 @@ def to_dict(self): Result can be used to serialize the instance and reconstitute it later using :meth:`from_dict`. + When called, the underlying session is cleaned up, so + the batch snapshot is no longer valid. 
+ :rtype: dict """ + session = self._get_session() snapshot = self._get_snapshot() + return { "session_id": session._session_id, "transaction_id": snapshot._transaction_id, @@ -1429,25 +1470,48 @@ def _get_session(self): Caller is responsible for cleaning up the session after all partitions have been processed. """ + if self._session is None: - session = self._session = self._database.session() + database = self._database + + # If the session ID is not specified, check out a new session from + # the database session manager; otherwise, the session has already + # been checked out, so just create a session object to represent it. if self._session_id is None: - session.create() + transaction_type = TransactionType.READ_ONLY + session = database._session_manager.get_session(transaction_type) + self._session_id = session.session_id + else: + session = Session(database=database) session._session_id = self._session_id + + self._session = session + return self._session def _get_snapshot(self): """Create snapshot if needed.""" + if self._snapshot is None: - self._snapshot = self._get_session().snapshot( - read_timestamp=self._read_timestamp, - exact_staleness=self._exact_staleness, - multi_use=True, - transaction_id=self._transaction_id, - ) + snapshot_args = { + "session": self._get_session(), + "read_timestamp": self._read_timestamp, + "exact_staleness": self._exact_staleness, + "multi_use": True, + } + + # If the transaction ID is not specified, create a new snapshot + # and begin a transaction; otherwise, the transaction is already + # in progress, so just create a snapshot object to represent it. 
if self._transaction_id is None: - self._snapshot.begin() + self._snapshot = Snapshot(**snapshot_args) + self._transaction_id = self._snapshot.begin() + + else: + snapshot_args["transaction_id"] = self._transaction_id + self._snapshot = Snapshot(**snapshot_args) + return self._snapshot def get_batch_transaction_id(self): @@ -1844,7 +1908,7 @@ def close(self): from all the partitions. """ if self._session is not None: - self._session.delete() + self._database._session_manager.put_session(self._session) def _check_ddl_statements(value): diff --git a/google/cloud/spanner_v1/database_sessions_manager.py b/google/cloud/spanner_v1/database_sessions_manager.py index 293192304a..a9837700ef 100644 --- a/google/cloud/spanner_v1/database_sessions_manager.py +++ b/google/cloud/spanner_v1/database_sessions_manager.py @@ -23,14 +23,15 @@ add_span_event, ) from google.cloud.spanner_v1.session import Session +from google.cloud.spanner_v1.session_options import TransactionType class DatabaseSessionsManager(object): """Manages sessions for a Cloud Spanner database. - Sessions can be checked out from the database session manager using :meth:`get_session_for_read_only`, - :meth:`get_session_for_partitioned`, and :meth:`get_session_for_read_write`, and returned to - the session manager using :meth:`put_session`. + Sessions can be checked out from the database session manager for a specific + transaction type using :meth:`get_session`, and returned to the session manager + using :meth:`put_session`. The sessions returned by the session manager depend on the client's session options (see :class:`~google.cloud.spanner_v1.session_options.SessionOptions`) and the provided session @@ -41,18 +42,15 @@ class DatabaseSessionsManager(object): :type pool: :class:`~google.cloud.spanner_v1.pool.AbstractSessionPool` :param pool: The pool to get non-multiplexed sessions from. - - :type logger: :class:`logging.Logger` - :param logger: Logger for the database session manager. 
""" # Intervals for the maintenance thread to check and refresh the multiplexed session. - _MAINTENANCE_THREAD_POLLING_INTERVAL = datetime.timedelta(hours=1) + _MAINTENANCE_THREAD_POLLING_INTERVAL = datetime.timedelta(minutes=10) _MAINTENANCE_THREAD_REFRESH_INTERVAL = datetime.timedelta(days=7) - def __init__(self, database, pool, logger): + def __init__(self, database, pool): self._database = database - self._logger = logger + self._logger = database.logger # The session pool manages non-multiplexed sessions, and # will only be used if multiplexed sessions are not enabled. @@ -70,84 +68,21 @@ def __init__(self, database, pool, logger): self._multiplexed_session_lock = threading.Lock() self._is_multiplexed_sessions_disabled_event = threading.Event() - def get_session_for_read_only(self) -> Session: - """Returns a session for read-only transactions from the database session manager. - - :rtype: :class:`~google.cloud.spanner_v1.session.Session` - :returns: a session for read-only transactions. - """ - - return self._get_session( - use_multiplexed=self._database._instance._client.session_options.use_multiplexed_for_read_only() - ) - - def get_session_for_partitioned(self) -> Session: - """Returns a session for partitioned transactions from the database session manager. + def get_session(self, transaction_type: TransactionType) -> Session: + """Returns a session for the given transaction type from the database session manager. :rtype: :class:`~google.cloud.spanner_v1.session.Session` - :returns: a session for partitioned transactions. + :returns: a session for the given transaction type. """ - if ( - self._database._instance._client.session_options.use_multiplexed_for_partitioned() - ): - raise NotImplementedError( - "Multiplexed sessions are not yet supported for partitioned transactions." 
- ) - - return self._get_session(use_multiplexed=False) - - def get_session_for_read_write(self) -> Session: - """Returns a session for read/write transactions from the database session manager. - - :rtype: :class:`~google.cloud.spanner_v1.session.Session` - :returns: a session for read/write transactions. - """ + session_options = self._database.session_options + use_multiplexed = session_options.use_multiplexed(transaction_type) - if ( - self._database._instance._client.session_options.use_multiplexed_for_read_write() - ): + if use_multiplexed and transaction_type == TransactionType.READ_WRITE: raise NotImplementedError( - "Multiplexed sessions are not yet supported for read/write transactions." + f"Multiplexed sessions are not yet supported for {transaction_type} transactions." ) - return self._get_session(use_multiplexed=False) - - def put_session(self, session: Session) -> None: - """Returns the session to the database session manager. - - :type session: :class:`~google.cloud.spanner_v1.session.Session` - :param session: The session to return to the database session manager. - """ - - # No action is needed for multiplexed sessions: the session - # pool is only used for managing non-multiplexed sessions, - # since they can only process one transaction at a time. - if not session.is_multiplexed: - self._pool.put(session) - - current_span = get_current_span() - add_span_event( - current_span, - "Returned session", - {"id": session.session_id, "multiplexed": session.is_multiplexed}, - ) - - def _get_session(self, use_multiplexed: bool) -> Session: - """Returns a session from the database session manager. - - If use_multiplexed is True, returns a multiplexed session if - multiplexed sessions are supported. If multiplexed sessions are - not supported or if use_multiplexed is False, returns a non- - multiplexed session from the session pool. - - :type use_multiplexed: bool - :param use_multiplexed: Whether to try to get a multiplexed session. 
- - :rtype: :class:`~google.cloud.spanner_v1.session.Session` - :returns: a session for the database session manager. - """ - if use_multiplexed: try: session = self._get_multiplexed_session() @@ -169,6 +104,25 @@ def _get_session(self, use_multiplexed: bool) -> Session: return session + def put_session(self, session: Session) -> None: + """Returns the session to the database session manager. + + :type session: :class:`~google.cloud.spanner_v1.session.Session` + :param session: The session to return to the database session manager. + """ + + add_span_event( + get_current_span(), + "Returning session", + {"id": session.session_id, "multiplexed": session.is_multiplexed}, + ) + + # No action is needed for multiplexed sessions: the session + # pool is only used for managing non-multiplexed sessions, + # since they can only process one transaction at a time. + if not session.is_multiplexed: + self._pool.put(session) + def _get_multiplexed_session(self) -> Session: """Returns a multiplexed session from the database session manager. @@ -194,12 +148,6 @@ def _get_multiplexed_session(self) -> Session: ) self._multiplexed_session_maintenance_thread.start() - add_span_event( - get_current_span(), - "Using session", - {"id": self._multiplexed_session.session_id, "multiplexed": True}, - ) - return self._multiplexed_session def _build_multiplexed_session(self) -> Session: @@ -227,17 +175,9 @@ def _build_multiplexed_session(self) -> Session: def _disable_multiplexed_sessions(self) -> None: """Disables multiplexed sessions for all transactions.""" - self._logger.warning( - "Multiplexed session creation failed. Disabling multiplexed sessions." 
- ) - - session_options = self._database._instance._client.session_options - session_options.disable_multiplexed_for_read_only() - session_options.disable_multiplexed_for_partitioned() - session_options.disable_multiplexed_for_read_write() - self._multiplexed_session = None self._is_multiplexed_sessions_disabled_event.set() + self._database.session_options.disable_multiplexed(self._logger) def _build_maintenance_thread(self) -> threading.Thread: """Builds and returns a multiplexed session maintenance thread for diff --git a/google/cloud/spanner_v1/pool.py b/google/cloud/spanner_v1/pool.py index 18c586e76e..65968b7649 100644 --- a/google/cloud/spanner_v1/pool.py +++ b/google/cloud/spanner_v1/pool.py @@ -20,7 +20,7 @@ from google.cloud.exceptions import NotFound from google.cloud.spanner_v1 import BatchCreateSessionsRequest -from google.cloud.spanner_v1 import Session +from google.cloud.spanner_v1 import Session as SessionPB from google.cloud.spanner_v1._helpers import ( _metadata_with_prefix, _metadata_with_leader_aware_routing, @@ -33,6 +33,7 @@ from warnings import warn from google.cloud.spanner_v1.metrics.metrics_capture import MetricsCapture +from google.cloud.spanner_v1.session import Session _NOW = datetime.datetime.utcnow # unit tests may replace @@ -130,8 +131,9 @@ def _new_session(self): :rtype: :class:`~google.cloud.spanner_v1.session.Session` :returns: new session instance. 
""" - return self._database.session( - labels=self.labels, database_role=self.database_role + database_role = self.database_role or self._database.database_role + return Session( + database=self._database, labels=self.labels, database_role=database_role ) @@ -225,7 +227,7 @@ def bind(self, database): request = BatchCreateSessionsRequest( database=database.name, session_count=requested_session_count, - session_template=Session(creator_role=self.database_role), + session_template=SessionPB(creator_role=self.database_role), ) observability_options = getattr(self._database, "observability_options", None) @@ -518,7 +520,7 @@ def bind(self, database): request = BatchCreateSessionsRequest( database=database.name, session_count=self.size, - session_template=Session(creator_role=self.database_role), + session_template=SessionPB(creator_role=self.database_role), ) span_event_attributes = {"kind": type(self).__name__} diff --git a/google/cloud/spanner_v1/session.py b/google/cloud/spanner_v1/session.py index e3e9aa6f66..644b00ba9c 100644 --- a/google/cloud/spanner_v1/session.py +++ b/google/cloud/spanner_v1/session.py @@ -300,11 +300,6 @@ def snapshot(self, **kw): if self._session_id is None: raise ValueError("Session has not been created.") - if self.is_multiplexed: - raise NotImplementedError( - "Multiplexed sessions do not yet support read-only transactions." - ) - return Snapshot(self, **kw) def read(self, table, columns, keyset, index="", limit=0, column_info=None): diff --git a/google/cloud/spanner_v1/session_options.py b/google/cloud/spanner_v1/session_options.py index 9675ca90db..eab16dc6de 100644 --- a/google/cloud/spanner_v1/session_options.py +++ b/google/cloud/spanner_v1/session_options.py @@ -12,25 +12,25 @@ # See the License for the specific language governing permissions and # limitations under the License. 
import os +from enum import Enum +from logging import Logger -from google.cloud.spanner_v1._opentelemetry_tracing import ( - get_current_span, - add_span_event, -) + +class TransactionType(Enum): + """Transaction types for session options.""" + + READ_ONLY = "read-only" + PARTITIONED = "partitioned" + READ_WRITE = "read/write" class SessionOptions(object): """Represents the session options for the Cloud Spanner Python client. - We can use ::class::`SessionOptions` to determine whether multiplexed sessions should be used for: - * read-only transactions (:meth:`use_multiplexed_for_read_only`) - * partitioned transactions (:meth:`use_multiplexed_for_partitioned`) - * read/write transactions (:meth:`use_multiplexed_for_read_write`). - - The use of multiplexed session can be disabled for corresponding transaction types by calling: - * :meth:`disable_multiplexed_for_read_only` - * :meth:`disable_multiplexed_for_partitioned` - * :meth:`disable_multiplexed_for_read_write`. + We can use ::class::`SessionOptions` to determine whether multiplexed sessions + should be used for a specific transaction type with :meth:`use_multiplexed`. The use + of multiplexed session can be disabled for a specific transaction type or for all + transaction types with :meth:`disable_multiplexed`. """ # Environment variables for multiplexed sessions @@ -48,93 +48,106 @@ class SessionOptions(object): def __init__(self): # Internal overrides to disable the use of multiplexed # sessions in case of runtime errors. - self._is_multiplexed_enabled_for_read_only = True - self._is_multiplexed_enabled_for_partitioned = True - self._is_multiplexed_enabled_for_read_write = True + self._is_multiplexed_enabled = { + TransactionType.READ_ONLY: True, + TransactionType.PARTITIONED: True, + TransactionType.READ_WRITE: True, + } + + def use_multiplexed(self, transaction_type: TransactionType) -> bool: + """Returns whether to use multiplexed sessions for the given transaction type. 
- def use_multiplexed_for_read_only(self) -> bool: - """Returns whether to use multiplexed sessions for read-only transactions. Multiplexed sessions are enabled for read-only transactions if: * ENV_VAR_ENABLE_MULTIPLEXED is set to true; * ENV_VAR_FORCE_DISABLE_MULTIPLEXED is not set to true; and - * multiplexed sessions have not been disabled for read-only transactions (see 'disable_multiplexed_for_read_only'). - """ - - return ( - self._is_multiplexed_enabled_for_read_only - and self._getenv(self.ENV_VAR_ENABLE_MULTIPLEXED) - and not self._getenv(self.ENV_VAR_FORCE_DISABLE_MULTIPLEXED) - ) - - def disable_multiplexed_for_read_only(self) -> None: - """Disables the use of multiplexed sessions for read-only transactions.""" - - current_span = get_current_span() - add_span_event( - current_span, - "Disabling use of multiplexed session for read-only transactions", - ) - - self._is_multiplexed_enabled_for_read_only = False + * multiplexed sessions have not been disabled for read-only transactions. - def use_multiplexed_for_partitioned(self) -> bool: - """Returns whether to use multiplexed sessions for partitioned transactions. Multiplexed sessions are enabled for partitioned transactions if: * ENV_VAR_ENABLE_MULTIPLEXED is set to true; * ENV_VAR_ENABLE_MULTIPLEXED_FOR_PARTITIONED is set to true; * ENV_VAR_FORCE_DISABLE_MULTIPLEXED is not set to true; and - * multiplexed sessions have not been disabled for partitioned transactions (see 'disable_multiplexed_for_partitioned'). 
- """ - - return ( - self._is_multiplexed_enabled_for_partitioned - and self._getenv(self.ENV_VAR_ENABLE_MULTIPLEXED) - and self._getenv(self.ENV_VAR_ENABLE_MULTIPLEXED_FOR_PARTITIONED) - and not self._getenv(self.ENV_VAR_FORCE_DISABLE_MULTIPLEXED) - ) - - def disable_multiplexed_for_partitioned(self) -> None: - """Disables the use of multiplexed sessions for read-only transactions.""" - - current_span = get_current_span() - add_span_event( - current_span, - "Disabling use of multiplexed session for partitioned transactions", - ) + * multiplexed sessions have not been disabled for partitioned transactions. - self._is_multiplexed_enabled_for_partitioned = False - - def use_multiplexed_for_read_write(self) -> bool: - """Returns whether to use multiplexed sessions for read/write transactions. Multiplexed sessions are enabled for read/write transactions if: * ENV_VAR_ENABLE_MULTIPLEXED is set to true; * ENV_VAR_ENABLE_MULTIPLEXED_FOR_PARTITIONED is set to true; * ENV_VAR_FORCE_DISABLE_MULTIPLEXED is not set to true; and - * multiplexed sessions have not been disabled for read/write transactions (see 'disable_multiplexed_for_read_write'). - """ + * multiplexed sessions have not been disabled for read/write transactions. - return ( - self._is_multiplexed_enabled_for_read_write - and self._getenv(self.ENV_VAR_ENABLE_MULTIPLEXED) - and self._getenv(self.ENV_VAR_ENABLE_MULTIPLEXED_FOR_READ_WRITE) - and not self._getenv(self.ENV_VAR_FORCE_DISABLE_MULTIPLEXED) - ) + :type transaction_type: :class:`TransactionType` + :param transaction_type: the type of transaction to check whether + multiplexed sessions should be used. 
+ """ - def disable_multiplexed_for_read_write(self) -> None: - """Disables the use of multiplexed sessions for read/write transactions.""" + if transaction_type is TransactionType.READ_ONLY: + return ( + self._is_multiplexed_enabled[transaction_type] + and self._getenv(self.ENV_VAR_ENABLE_MULTIPLEXED) + and not self._getenv(self.ENV_VAR_FORCE_DISABLE_MULTIPLEXED) + ) + + elif transaction_type is TransactionType.PARTITIONED: + return ( + self._is_multiplexed_enabled[transaction_type] + and self._getenv(self.ENV_VAR_ENABLE_MULTIPLEXED) + and self._getenv(self.ENV_VAR_ENABLE_MULTIPLEXED_FOR_PARTITIONED) + and not self._getenv(self.ENV_VAR_FORCE_DISABLE_MULTIPLEXED) + ) + + elif transaction_type is TransactionType.READ_WRITE: + return ( + self._is_multiplexed_enabled[transaction_type] + and self._getenv(self.ENV_VAR_ENABLE_MULTIPLEXED) + and self._getenv(self.ENV_VAR_ENABLE_MULTIPLEXED_FOR_READ_WRITE) + and not self._getenv(self.ENV_VAR_FORCE_DISABLE_MULTIPLEXED) + ) + + raise ValueError(f"Transaction type {transaction_type} is not supported.") + + def disable_multiplexed( + self, logger: Logger = None, transaction_type: TransactionType = None + ) -> None: + """Disables the use of multiplexed sessions for the given transaction type. + If no transaction type is specified, disables the use of multiplexed sessions + for all transaction types. + + :type logger: :class:`Logger` + :param logger: logger to use for logging the disabling the use of multiplexed + sessions. + + :type transaction_type: :class:`TransactionType` + :param transaction_type: (Optional) the type of transaction for which to disable + the use of multiplexed sessions. 
+ """ - current_span = get_current_span() - add_span_event( - current_span, - "Disabling use of multiplexed session for read/write transactions", + disable_multiplexed_log_msg_fstring = ( + "Disabling multiplexed sessions for {transaction_type_value} transactions" ) - self._is_multiplexed_enabled_for_read_write = False + if transaction_type is None: + logger.warning( + disable_multiplexed_log_msg_fstring.format(transaction_type_value="all") + ) + for transaction_type in TransactionType: + self._is_multiplexed_enabled[transaction_type] = False + return + + elif transaction_type in self._is_multiplexed_enabled.keys(): + logger.warning( + disable_multiplexed_log_msg_fstring.format( + transaction_type_value=transaction_type.value + ) + ) + self._is_multiplexed_enabled[transaction_type] = False + return + + raise ValueError(f"Transaction type '{transaction_type}' is not supported.") @staticmethod def _getenv(name: str) -> bool: """Returns the value of the given environment variable as a boolean. - True values are '1' and 'true' (case-insensitive); all other values are considered false. + True values are '1' and 'true' (case-insensitive); all other values are + considered false. 
""" env_var = os.getenv(name, "").lower().strip() return env_var in ["1", "true"] diff --git a/google/cloud/spanner_v1/snapshot.py b/google/cloud/spanner_v1/snapshot.py index 3b18d2c855..22bbe0e103 100644 --- a/google/cloud/spanner_v1/snapshot.py +++ b/google/cloud/spanner_v1/snapshot.py @@ -14,7 +14,6 @@ """Model a set of read-only queries to a database as a snapshot.""" -from datetime import datetime import functools import threading from google.protobuf.struct_pb2 import Struct @@ -40,6 +39,7 @@ _SessionWrapper, ) from google.cloud.spanner_v1._opentelemetry_tracing import trace_call +from google.cloud.spanner_v1.session_options import TransactionType from google.cloud.spanner_v1.streamed import StreamedResultSet from google.cloud.spanner_v1 import RequestOptions @@ -54,9 +54,9 @@ def _restart_on_unavailable( method, request, + session, metadata=None, trace_name=None, - session=None, attributes=None, transaction=None, transaction_selector=None, @@ -70,6 +70,9 @@ def _restart_on_unavailable( :type request: proto :param request: request proto to call the method with + :type session: :class:`~google.cloud.spanner_v1.session.Session` + :param session: the session used to perform the operation + :type transaction: :class:`google.cloud.spanner_v1.snapshot._SnapshotBase` :param transaction: Snapshot or Transaction class object based on the type of transaction @@ -153,6 +156,9 @@ def _restart_on_unavailable( iterator = method(request=request) continue + except NotImplementedError as exc: + _handle_not_implemented_error(exc, session._database) + if len(item_buffer) == 0: break @@ -162,6 +168,30 @@ def _restart_on_unavailable( del item_buffer[:] +def _handle_not_implemented_error(exception, database) -> None: + """Handles NotImplementedError for the database. If the error is due to unsupported + partitioned operations with multiplexed sessions, disables multiplexed sessions for + read-only transactions and re-raises the error. 
Otherwise, re-raises the error + with no side effects. + + :type exception: :class:`NotImplementedError` + :param exception: The NotImplementedError to handle. + + :type database: :class:`~google.cloud.spanner_v1.database.Database` + :param database: The database instance associated with the error. + + :raises NotImplementedError: The original exception + """ + + if "Partitioned operations are not supported with multiplexed sessions" in str( + exception + ): + session_options = database.session_options + session_options.disable_multiplexed(database.logger, TransactionType.READ_ONLY) + + raise exception + + class _SnapshotBase(_SessionWrapper): """Base class for Snapshot. @@ -329,7 +359,8 @@ def read( data_boost_enabled=data_boost_enabled, directed_read_options=directed_read_options, ) - restart = functools.partial( + + streaming_read = functools.partial( api.streaming_read, request=request, metadata=metadata, @@ -340,54 +371,23 @@ def read( trace_attributes = {"table_id": table, "columns": columns} observability_options = getattr(database, "observability_options", None) + get_streamed_result_set_args = { + "method": streaming_read, + "request": request, + "metadata": metadata, + "trace_attributes": trace_attributes, + "column_info": column_info, + "observability_options": observability_options, + "lazy_decode": lazy_decode, + } + if self._transaction_id is None: # lock is added to handle the inline begin for first rpc with self._lock: - iterator = _restart_on_unavailable( - restart, - request, - metadata, - f"CloudSpanner.{type(self).__name__}.read", - self._session, - trace_attributes, - transaction=self, - observability_options=observability_options, - ) - self._read_request_count += 1 - if self._multi_use: - return StreamedResultSet( - iterator, - source=self, - column_info=column_info, - lazy_decode=lazy_decode, - ) - else: - return StreamedResultSet( - iterator, column_info=column_info, lazy_decode=lazy_decode - ) - else: - iterator = _restart_on_unavailable( 
- restart, - request, - metadata, - f"CloudSpanner.{type(self).__name__}.read", - self._session, - trace_attributes, - transaction=self, - observability_options=observability_options, - ) + return self._get_streamed_result_set(**get_streamed_result_set_args) - self._read_request_count += 1 - self._session._last_use_time = datetime.now() - - if self._multi_use: - return StreamedResultSet( - iterator, source=self, column_info=column_info, lazy_decode=lazy_decode - ) else: - return StreamedResultSet( - iterator, column_info=column_info, lazy_decode=lazy_decode - ) + return self._get_streamed_result_set(**get_streamed_result_set_args) def execute_sql( self, @@ -562,7 +562,7 @@ def execute_sql( data_boost_enabled=data_boost_enabled, directed_read_options=directed_read_options, ) - restart = functools.partial( + execute_streaming_sql_method = functools.partial( api.execute_streaming_sql, request=request, metadata=metadata, @@ -573,60 +573,22 @@ def execute_sql( trace_attributes = {"db.statement": sql} observability_options = getattr(database, "observability_options", None) + get_streamed_result_set_args = { + "method": execute_streaming_sql_method, + "request": request, + "metadata": metadata, + "trace_attributes": trace_attributes, + "column_info": column_info, + "observability_options": observability_options, + "lazy_decode": lazy_decode, + } + if self._transaction_id is None: # lock is added to handle the inline begin for first rpc with self._lock: - return self._get_streamed_result_set( - restart, - request, - metadata, - trace_attributes, - column_info, - observability_options, - lazy_decode=lazy_decode, - ) - else: - return self._get_streamed_result_set( - restart, - request, - metadata, - trace_attributes, - column_info, - observability_options, - lazy_decode=lazy_decode, - ) - - def _get_streamed_result_set( - self, - restart, - request, - metadata, - trace_attributes, - column_info, - observability_options=None, - lazy_decode=False, - ): - iterator = 
_restart_on_unavailable( - restart, - request, - metadata, - f"CloudSpanner.{type(self).__name__}.execute_sql", - self._session, - trace_attributes, - transaction=self, - observability_options=observability_options, - ) - self._read_request_count += 1 - self._execute_sql_count += 1 - - if self._multi_use: - return StreamedResultSet( - iterator, source=self, column_info=column_info, lazy_decode=lazy_decode - ) + return self._get_streamed_result_set(**get_streamed_result_set_args) else: - return StreamedResultSet( - iterator, column_info=column_info, lazy_decode=lazy_decode - ) + return self._get_streamed_result_set(**get_streamed_result_set_args) def partition_read( self, @@ -725,10 +687,15 @@ def partition_read( retry=retry, timeout=timeout, ) - response = _retry( - method, - allowed_exceptions={InternalServerError: _check_rst_stream_error}, - ) + + try: + response = _retry( + method, + allowed_exceptions={InternalServerError: _check_rst_stream_error}, + ) + + except NotImplementedError as exc: + _handle_not_implemented_error(exc, database) return [partition.partition_token for partition in response.partitions] @@ -829,13 +796,62 @@ def partition_query( retry=retry, timeout=timeout, ) - response = _retry( - method, - allowed_exceptions={InternalServerError: _check_rst_stream_error}, - ) + + try: + response = _retry( + method, + allowed_exceptions={InternalServerError: _check_rst_stream_error}, + ) + + except NotImplementedError as exc: + _handle_not_implemented_error(exc, database) return [partition.partition_token for partition in response.partitions] + def _get_streamed_result_set( + self, + method, + request, + metadata, + trace_attributes, + column_info, + observability_options, + lazy_decode, + ): + """Returns the streamed result set for a read or execute SQL request with the given arguments.""" + + is_execute_sql_request = isinstance(request, ExecuteSqlRequest) + + trace_request_name = "execute_sql" if is_execute_sql_request else "read" + trace_name = 
f"CloudSpanner.{type(self).__name__}.{trace_request_name}" + + iterator = _restart_on_unavailable( + method=method, + request=request, + session=self._session, + metadata=metadata, + trace_name=trace_name, + attributes=trace_attributes, + transaction=self, + observability_options=observability_options, + ) + + self._read_request_count += 1 + + if is_execute_sql_request: + self._execute_sql_count += 1 + + streamed_result_set_args = { + "response_iterator": iterator, + "column_info": column_info, + "lazy_decode": lazy_decode, + } + + if self._multi_use: + streamed_result_set_args["source"] = self + + return StreamedResultSet(**streamed_result_set_args) + class Snapshot(_SnapshotBase): """Allow a set of reads / SQL statements with shared staleness. diff --git a/tests/_builders.py b/tests/_builders.py new file mode 100644 index 0000000000..c07d003c19 --- /dev/null +++ b/tests/_builders.py @@ -0,0 +1,71 @@ +# Copyright 2025 Google LLC All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Default identifiers. +PROJECT_ID = "project-id" +INSTANCE_ID = "instance-id" +DATABASE_ID = "database-id" +SESSION_ID = "session-id" +TRANSACTION_ID = b"transaction-id" + +# Default names. 
+INSTANCE_NAME = "projects/" + PROJECT_ID + "/instances/" + INSTANCE_ID +DATABASE_NAME = INSTANCE_NAME + "/databases/" + DATABASE_ID +SESSION_NAME = DATABASE_NAME + "/sessions/" + SESSION_ID + + +def build_database(**database_kwargs): + """Builds and returns a database for testing using the given arguments. + If a required argument is not provided, a default value will be used.""" + from google.cloud.spanner_v1 import Session as SessionProto + from google.cloud.spanner_v1 import SpannerClient + from google.cloud.spanner_v1 import Transaction as TransactionProto + from google.cloud.spanner_v1.database import Database + from mock.mock import create_autospec + + instance = _build_instance() + database = Database(database_id=DATABASE_ID, instance=instance) + + # Mock API calls. Callers can override this to test specific behaviours. + api = database._spanner_api = create_autospec(SpannerClient, instance=True) + api.create_session.return_value = SessionProto(name=SESSION_NAME) + api.begin_transaction.return_value = TransactionProto(id=TRANSACTION_ID) + + return database + + +def build_session(**session_kwargs): + """Builds and returns a session for testing using the given arguments. 
+ If a required argument is not provided, a default value will be used.""" + from google.cloud.spanner_v1.session import Session + + if not session_kwargs.get("database"): + session_kwargs["database"] = build_database() + + return Session(**session_kwargs) + + +def _build_client(): + """Builds and returns a client for testing.""" + from google.cloud.spanner_v1 import Client + + return Client(project=PROJECT_ID) + + +def _build_instance(**instance_kwargs): + """Builds and returns an instance for testing.""" + from google.cloud.spanner_v1.instance import Instance + + client = _build_client() + return Instance(instance_id=INSTANCE_ID, client=client) diff --git a/tests/_helpers.py b/tests/_helpers.py index 667f9f8be1..f004e44494 100644 --- a/tests/_helpers.py +++ b/tests/_helpers.py @@ -1,7 +1,9 @@ +import os import unittest import mock from google.cloud.spanner_v1 import gapic_version +from google.cloud.spanner_v1.session_options import SessionOptions LIB_VERSION = gapic_version.__version__ @@ -149,3 +151,20 @@ def finished_spans_events_statuses(self): got_all_events.append((event.name, evt_attributes)) return got_all_events + + +def enable_multiplexed_sessions() -> None: + """Sets environment variables to enable multiplexed sessions for all transaction types. + The caller is responsible for resetting the environment variables after use.""" + + os.environ[SessionOptions.ENV_VAR_ENABLE_MULTIPLEXED] = "true" + os.environ[SessionOptions.ENV_VAR_ENABLE_MULTIPLEXED_FOR_PARTITIONED] = "true" + os.environ[SessionOptions.ENV_VAR_ENABLE_MULTIPLEXED_FOR_READ_WRITE] = "true" + os.environ[SessionOptions.ENV_VAR_FORCE_DISABLE_MULTIPLEXED] = "false" + + +def disable_multiplexed_sessions() -> None: + """Sets environment variables to disable multiplexed sessions for all transaction types.
+ The caller is responsible for resetting the environment variables after use.""" + + os.environ[SessionOptions.ENV_VAR_FORCE_DISABLE_MULTIPLEXED] = "true" diff --git a/tests/system/test_observability_options.py b/tests/system/test_observability_options.py index 3f23cc8edc..c247a6d6a9 100644 --- a/tests/system/test_observability_options.py +++ b/tests/system/test_observability_options.py @@ -233,7 +233,7 @@ def select_in_txn(txn): ("No sessions available in pool. Creating session", {"kind": "BurstyPool"}), ("Creating Session", {}), ("Using session", {"id": session_id, "multiplexed": session_multiplexed}), - ("Returned session", {"id": session_id, "multiplexed": session_multiplexed}), + ("Returning session", {"id": session_id, "multiplexed": session_multiplexed}), ( "Transaction was aborted in user operation, retrying", {"delay_seconds": "EPHEMERAL", "cause": "EPHEMERAL", "attempt": 1}, @@ -426,7 +426,7 @@ def test_database_partitioned_error(mock_session_multiplexed, mock_session_id): ("Creating Session", {}), ("Using session", {"id": session_id, "multiplexed": session_multiplexed}), ("Starting BeginTransaction", {}), - ("Returned session", {"id": session_id, "multiplexed": session_multiplexed}), + ("Returning session", {"id": session_id, "multiplexed": session_multiplexed}), ( "exception", { diff --git a/tests/system/test_session_api.py b/tests/system/test_session_api.py index 4de0e681f6..be688bf43a 100644 --- a/tests/system/test_session_api.py +++ b/tests/system/test_session_api.py @@ -33,7 +33,7 @@ from tests import _helpers as ot_helpers from . import _helpers from . 
import _sample_data - +from .._builders import build_session SOME_DATE = datetime.date(2011, 1, 17) SOME_TIME = datetime.datetime(1989, 1, 17, 17, 59, 12, 345612) @@ -415,7 +415,7 @@ def handle_abort(self, database): def test_session_crud(sessions_database): - session = sessions_database.session() + session = build_session(database=sessions_database) assert not session.exists() session.create() @@ -587,7 +587,7 @@ def test_transaction_read_and_insert_then_rollback( sd = _sample_data db_name = sessions_database.name - session = sessions_database.session() + session = build_session(database=sessions_database) session.create() sessions_to_delete.append(session) @@ -722,7 +722,7 @@ def test_transaction_read_and_insert_or_update_then_commit( # [START spanner_test_dml_read_your_writes] sd = _sample_data - session = sessions_database.session() + session = build_session(database=sessions_database) session.create() sessions_to_delete.append(session) @@ -780,7 +780,7 @@ def test_transaction_execute_sql_w_dml_read_rollback( # [START spanner_test_dml_rollback_txn_not_committed] sd = _sample_data - session = sessions_database.session() + session = build_session(database=sessions_database) session.create() sessions_to_delete.append(session) @@ -814,7 +814,7 @@ def test_transaction_execute_update_read_commit(sessions_database, sessions_to_d # [START spanner_test_dml_read_your_writes] sd = _sample_data - session = sessions_database.session() + session = build_session(database=sessions_database) session.create() sessions_to_delete.append(session) @@ -846,7 +846,7 @@ def test_transaction_execute_update_then_insert_commit( # [START spanner_test_dml_update] sd = _sample_data - session = sessions_database.session() + session = build_session(database=sessions_database) session.create() sessions_to_delete.append(session) @@ -879,7 +879,7 @@ def test_transaction_execute_sql_dml_returning( ): sd = _sample_data - session = sessions_database.session() + session = 
build_session(database=sessions_database) session.create() sessions_to_delete.append(session) @@ -910,7 +910,7 @@ def test_transaction_execute_update_dml_returning( ): sd = _sample_data - session = sessions_database.session() + session = build_session(database=sessions_database) session.create() sessions_to_delete.append(session) @@ -938,7 +938,7 @@ def test_transaction_batch_update_dml_returning( ): sd = _sample_data - session = sessions_database.session() + session = build_session(database=sessions_database) session.create() sessions_to_delete.append(session) @@ -970,7 +970,7 @@ def test_transaction_batch_update_success( sd = _sample_data param_types = spanner_v1.param_types - session = sessions_database.session() + session = build_session(database=sessions_database) session.create() sessions_to_delete.append(session) @@ -1027,7 +1027,7 @@ def test_transaction_batch_update_and_execute_dml( sd = _sample_data param_types = spanner_v1.param_types - session = sessions_database.session() + session = build_session(database=sessions_database) session.create() sessions_to_delete.append(session) @@ -1087,7 +1087,7 @@ def test_transaction_batch_update_w_syntax_error( sd = _sample_data param_types = spanner_v1.param_types - session = sessions_database.session() + session = build_session(database=sessions_database) session.create() sessions_to_delete.append(session) @@ -1132,7 +1132,7 @@ def unit_of_work(transaction): def test_transaction_batch_update_wo_statements(sessions_database, sessions_to_delete): - session = sessions_database.session() + session = build_session(database=sessions_database) session.create() sessions_to_delete.append(session) @@ -1155,7 +1155,7 @@ def test_transaction_batch_update_w_parent_span( param_types = spanner_v1.param_types tracer = trace.get_tracer(__name__) - session = sessions_database.session() + session = build_session(database=sessions_database) session.create() sessions_to_delete.append(session) diff --git 
a/tests/unit/spanner_dbapi/test_connection.py b/tests/unit/spanner_dbapi/test_connection.py index fc41a156db..038ab558d2 100644 --- a/tests/unit/spanner_dbapi/test_connection.py +++ b/tests/unit/spanner_dbapi/test_connection.py @@ -41,6 +41,7 @@ from google.cloud.spanner_v1.instance import Instance from google.cloud.spanner_v1.database import Database from google.cloud.spanner_v1.session import Session +from google.cloud.spanner_v1.session_options import TransactionType PROJECT = "test-project" INSTANCE = "test-instance" @@ -147,14 +148,14 @@ def test__session_checkout_read_only(self): session_manager = database._session_manager session = Session(database=database) - session_manager.get_session_for_read_only = mock.Mock(return_value=session) + session_manager.get_session = mock.Mock(return_value=session) read_only_connection = Connection( instance="instance-id", database=database, read_only=True ) read_only_connection._session_checkout() - session_manager.get_session_for_read_only.assert_called_once_with() + session_manager.get_session.assert_called_once_with(TransactionType.READ_ONLY) self.assertEqual(read_only_connection._session, session) def test__session_checkout_read_write(self): @@ -164,14 +165,14 @@ def test__session_checkout_read_write(self): session_manager = database._session_manager session = Session(database=database) - session_manager.get_session_for_read_write = mock.Mock(return_value=session) + session_manager.get_session = mock.Mock(return_value=session) read_write_connection = Connection( instance="instance-id", database=database, read_only=False ) read_write_connection._session_checkout() - session_manager.get_session_for_read_write.assert_called_once_with() + session_manager.get_session.assert_called_once_with(TransactionType.READ_WRITE) self.assertEqual(read_write_connection._session, session) def test_session_checkout_database_error(self): diff --git a/tests/unit/test_database.py b/tests/unit/test_database.py index 7bea6fc24a..56d6223b3e 
100644 --- a/tests/unit/test_database.py +++ b/tests/unit/test_database.py @@ -11,8 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - - +import os import unittest from logging import Logger @@ -22,7 +21,7 @@ Database as DatabasePB, DatabaseDialect, ) -from google.cloud.spanner_v1.database import Database +from google.cloud.spanner_v1.database import Database, BatchSnapshot, SessionCheckout from google.cloud.spanner_v1.param_types import INT64 from google.api_core.retry import Retry from google.protobuf.field_mask_pb2 import FieldMask @@ -32,6 +31,10 @@ DirectedReadOptions, RequestOptions, ) +from google.cloud.spanner_v1.session_options import TransactionType +from google.cloud.spanner_v1.snapshot import Snapshot +from tests._builders import build_database, build_session, TRANSACTION_ID, SESSION_ID +from tests._helpers import enable_multiplexed_sessions DML_WO_PARAM = """ DELETE FROM citizens @@ -77,13 +80,20 @@ class _BaseTest(unittest.TestCase): DATABASE_NAME = INSTANCE_NAME + "/databases/" + DATABASE_ID SESSION_ID = "session_id" SESSION_NAME = DATABASE_NAME + "/sessions/" + SESSION_ID - TRANSACTION_ID = b"transaction_id" - RETRY_TRANSACTION_ID = b"transaction_id_retry" BACKUP_ID = "backup_id" BACKUP_NAME = INSTANCE_NAME + "/backups/" + BACKUP_ID TRANSACTION_TAG = "transaction-tag" DATABASE_ROLE = "dummy-role" + def setUp(self): + # Save the original environment variables. + self._original_env = dict(os.environ) + + def tearDown(self): + # Restore environment variables. 
+ os.environ.clear() + os.environ.update(self._original_env) + def _make_one(self, *args, **kwargs): return self._get_target_class()(*args, **kwargs) @@ -92,7 +102,7 @@ def _make_timestamp(): import datetime from google.cloud._helpers import UTC - return datetime.datetime.utcnow().replace(tzinfo=UTC) + return datetime.datetime.now(tz=UTC) @staticmethod def _make_duration(seconds=1, microseconds=0): @@ -129,7 +139,6 @@ def test_ctor_defaults(self): self.assertEqual(list(database.ddl_statements), []) self.assertIsInstance(database._session_manager._pool, BurstyPool) self.assertFalse(database.log_commit_stats) - self.assertIsNone(database._logger) self.assertIsNone(database.database_role) self.assertTrue(database._route_to_leader_enabled, True) @@ -1105,9 +1114,15 @@ def _execute_partitioned_dml_helper( import collections - MethodConfig = collections.namedtuple("MethodConfig", ["retry"]) + transaction_id = b"transaction-id" + transaction_pb = TransactionPB(id=transaction_id) + transaction_selector_pb = TransactionSelector(id=transaction_id) + + retry_transaction_id = b"retry-transaction-id" + retry_transaction_pb = TransactionPB(id=retry_transaction_id) + retry_transaction_selector_pb = TransactionSelector(id=retry_transaction_id) - transaction_pb = TransactionPB(id=self.TRANSACTION_ID) + MethodConfig = collections.namedtuple("MethodConfig", ["retry"]) stats_pb = ResultSetStats(row_count_lower_bound=2) result_sets = [PartialResultSet(stats=stats_pb)] @@ -1122,7 +1137,6 @@ def _execute_partitioned_dml_helper( api = database._spanner_api = self._make_spanner_api() api._method_configs = {"ExecuteStreamingSql": MethodConfig(retry=Retry())} if retried: - retry_transaction_pb = TransactionPB(id=self.RETRY_TRANSACTION_ID) api.begin_transaction.side_effect = [transaction_pb, retry_transaction_pb] api.execute_streaming_sql.side_effect = [Aborted("test"), iterator] else: @@ -1165,7 +1179,6 @@ def _execute_partitioned_dml_helper( else: expected_params = {} - expected_transaction 
= TransactionSelector(id=self.TRANSACTION_ID) expected_query_options = client._query_options if query_options: expected_query_options = _merge_query_options( @@ -1180,7 +1193,7 @@ def _execute_partitioned_dml_helper( expected_request = ExecuteSqlRequest( session=self.SESSION_NAME, sql=dml, - transaction=expected_transaction, + transaction=transaction_selector_pb, params=expected_params, param_types=param_types, query_options=expected_query_options, @@ -1195,13 +1208,10 @@ def _execute_partitioned_dml_helper( ], ) if retried: - expected_retry_transaction = TransactionSelector( - id=self.RETRY_TRANSACTION_ID - ) expected_request = ExecuteSqlRequest( session=self.SESSION_NAME, sql=dml, - transaction=expected_retry_transaction, + transaction=retry_transaction_selector_pb, params=expected_params, param_types=param_types, query_options=expected_query_options, @@ -1262,36 +1272,19 @@ def test_execute_partitioned_dml_w_exclude_txn_from_change_streams(self): dml=DML_WO_PARAM, exclude_txn_from_change_streams=True ) - def test_session_factory_defaults(self): - from google.cloud.spanner_v1.session import Session - - client = _Client() - instance = _Instance(self.INSTANCE_NAME, client=client) - pool = _Pool() - database = self._make_one(self.DATABASE_ID, instance, pool=pool) - - session = database.session() - - self.assertIsInstance(session, Session) - self.assertIs(session.session_id, None) - self.assertIs(session._database, database) - self.assertEqual(session.labels, {}) + def test_execute_partitioned_dml_not_implemented_error_multiplexed(self): + enable_multiplexed_sessions() - def test_session_factory_w_labels(self): - from google.cloud.spanner_v1.session import Session - - client = _Client() - instance = _Instance(self.INSTANCE_NAME, client=client) - pool = _Pool() - labels = {"foo": "bar"} - database = self._make_one(self.DATABASE_ID, instance, pool=pool) + database = build_database() + database.spanner_api.begin_transaction.side_effect = NotImplementedError( + 
"Transaction type partitioned_dml not supported with multiplexed sessions" + ) - session = database.session(labels=labels) + with self.assertRaises(NotImplementedError): + database.execute_partitioned_dml(dml=DML_WO_PARAM) - self.assertIsInstance(session, Session) - self.assertIs(session.session_id, None) - self.assertIs(session._database, database) - self.assertEqual(session.labels, labels) + session_options = database.session_options + self.assertFalse(session_options.use_multiplexed(TransactionType.PARTITIONED)) def test_snapshot_defaults(self): from google.cloud.spanner_v1.database import SnapshotCheckout @@ -1313,7 +1306,7 @@ def test_snapshot_w_read_timestamp_and_multi_use(self): from google.cloud._helpers import UTC from google.cloud.spanner_v1.database import SnapshotCheckout - now = datetime.datetime.utcnow().replace(tzinfo=UTC) + now = datetime.datetime.now(tz=UTC) client = _Client() instance = _Instance(self.INSTANCE_NAME, client=client) pool = _Pool() @@ -1812,7 +1805,7 @@ def test_context_mgr_success(self): from google.cloud._helpers import _datetime_to_pb_timestamp from google.cloud.spanner_v1.batch import Batch - now = datetime.datetime.utcnow().replace(tzinfo=UTC) + now = datetime.datetime.now(tz=UTC) now_pb = _datetime_to_pb_timestamp(now) response = CommitResponse(commit_timestamp=now_pb) database = Database( @@ -1823,7 +1816,7 @@ def test_context_mgr_success(self): session = _Session(database) session_manager = database._session_manager - session_manager.get_session_for_read_only = mock.Mock(return_value=session) + session_manager.get_session = mock.Mock(return_value=session) session_manager.put_session = mock.Mock(return_value=None) checkout = self._make_one( @@ -1831,7 +1824,9 @@ def test_context_mgr_success(self): ) with checkout as batch: - session_manager.get_session_for_read_only.assert_called_once() + session_manager.get_session.assert_called_once_with( + TransactionType.READ_ONLY + ) self.assertIsInstance(batch, Batch) 
self.assertIs(batch._session, session) @@ -1864,7 +1859,7 @@ def test_context_mgr_w_commit_stats_success(self): from google.cloud._helpers import _datetime_to_pb_timestamp from google.cloud.spanner_v1.batch import Batch - now = datetime.datetime.utcnow().replace(tzinfo=UTC) + now = datetime.datetime.now(tz=UTC) now_pb = _datetime_to_pb_timestamp(now) commit_stats = CommitResponse.CommitStats(mutation_count=4) response = CommitResponse(commit_timestamp=now_pb, commit_stats=commit_stats) @@ -1880,13 +1875,15 @@ def test_context_mgr_w_commit_stats_success(self): session = _Session(database) session_manager = database._session_manager - session_manager.get_session_for_read_only = mock.Mock(return_value=session) + session_manager.get_session = mock.Mock(return_value=session) session_manager.put_session = mock.Mock(return_value=None) checkout = self._make_one(database) with checkout as batch: - session_manager.get_session_for_read_only.assert_called_once() + session_manager.get_session.assert_called_once_with( + TransactionType.READ_ONLY + ) self.assertIsInstance(batch, Batch) self.assertIs(batch._session, session) @@ -1937,14 +1934,16 @@ def test_context_mgr_w_aborted_commit_status(self): session = _Session(database) session_manager = database._session_manager - session_manager.get_session_for_read_only = mock.Mock(return_value=session) + session_manager.get_session = mock.Mock(return_value=session) session_manager.put_session = mock.Mock(return_value=None) checkout = self._make_one(database) with self.assertRaises(Aborted): with checkout as batch: - session_manager.get_session_for_read_only.assert_called_once() + session_manager.get_session.assert_called_once_with( + TransactionType.READ_ONLY + ) self.assertIsInstance(batch, Batch) self.assertIs(batch._session, session) @@ -1981,7 +1980,7 @@ def test_context_mgr_failure(self): session = _Session(database) session_manager = database._session_manager - session_manager.get_session_for_read_only = 
mock.Mock(return_value=session) + session_manager.get_session = mock.Mock(return_value=session) session_manager.put_session = mock.Mock(return_value=None) checkout = self._make_one(database) @@ -1991,7 +1990,9 @@ class Testing(Exception): with self.assertRaises(Testing): with checkout as batch: - session_manager.get_session_for_read_only.assert_called_once() + session_manager.get_session.assert_called_once_with( + TransactionType.READ_ONLY + ) self.assertIsInstance(batch, Batch) self.assertIs(batch._session, session) raise Testing() @@ -2007,15 +2008,13 @@ def _get_target_class(self): return SnapshotCheckout def test_ctor_defaults(self): - from google.cloud.spanner_v1.snapshot import Snapshot - database = Database( database_id=self.DATABASE_ID, instance=_Instance(self.INSTANCE_NAME) ) session = _Session(database) session_manager = database._session_manager - session_manager.get_session_for_read_only = mock.Mock(return_value=session) + session_manager.get_session = mock.Mock(return_value=session) session_manager.put_session = mock.Mock(return_value=None) checkout = self._make_one(database) @@ -2023,7 +2022,9 @@ def test_ctor_defaults(self): self.assertEqual(checkout._kw, {}) with checkout as snapshot: - session_manager.get_session_for_read_only.assert_called_once() + session_manager.get_session.assert_called_once_with( + TransactionType.READ_ONLY + ) self.assertIsInstance(snapshot, Snapshot) self.assertIs(snapshot._session, session) self.assertTrue(snapshot._strong) @@ -2034,16 +2035,15 @@ def test_ctor_defaults(self): def test_ctor_w_read_timestamp_and_multi_use(self): import datetime from google.cloud._helpers import UTC - from google.cloud.spanner_v1.snapshot import Snapshot - now = datetime.datetime.utcnow().replace(tzinfo=UTC) + now = datetime.datetime.now(tz=UTC) database = Database( database_id=self.DATABASE_ID, instance=_Instance(self.INSTANCE_NAME) ) session = _Session(database) session_manager = database._session_manager - 
session_manager.get_session_for_read_only = mock.Mock(return_value=session) + session_manager.get_session = mock.Mock(return_value=session) session_manager.put_session = mock.Mock(return_value=None) checkout = self._make_one(database, read_timestamp=now, multi_use=True) @@ -2051,7 +2051,9 @@ def test_ctor_w_read_timestamp_and_multi_use(self): self.assertEqual(checkout._kw, {"read_timestamp": now, "multi_use": True}) with checkout as snapshot: - session_manager.get_session_for_read_only.assert_called_once() + session_manager.get_session.assert_called_once_with( + TransactionType.READ_ONLY + ) self.assertIsInstance(snapshot, Snapshot) self.assertIs(snapshot._session, session) self.assertEqual(snapshot._read_timestamp, now) @@ -2060,15 +2062,13 @@ def test_ctor_w_read_timestamp_and_multi_use(self): session_manager.put_session.assert_called_once_with(session) def test_context_mgr_failure(self): - from google.cloud.spanner_v1.snapshot import Snapshot - database = Database( database_id=self.DATABASE_ID, instance=_Instance(self.INSTANCE_NAME) ) session = _Session(database) session_manager = database._session_manager - session_manager.get_session_for_read_only = mock.Mock(return_value=session) + session_manager.get_session = mock.Mock(return_value=session) session_manager.put_session = mock.Mock(return_value=None) checkout = self._make_one(database) @@ -2078,7 +2078,9 @@ class Testing(Exception): with self.assertRaises(Testing): with checkout as snapshot: - session_manager.get_session_for_read_only.assert_called_once() + session_manager.get_session.assert_called_once_with( + TransactionType.READ_ONLY + ) self.assertIsInstance(snapshot, Snapshot) self.assertIs(snapshot._session, session) raise Testing() @@ -2097,25 +2099,8 @@ class TestBatchSnapshot(_BaseTest): TOKENS = [b"TOKEN1", b"TOKEN2"] INDEX = "index" - def _get_target_class(self): - from google.cloud.spanner_v1.database import BatchSnapshot - - return BatchSnapshot - - @staticmethod - def _make_database(**kwargs): - 
return mock.create_autospec(Database, instance=True, **kwargs) - - @staticmethod - def _make_session(**kwargs): - from google.cloud.spanner_v1.session import Session - - return mock.create_autospec(Session, instance=True, **kwargs) - @staticmethod def _make_snapshot(transaction_id=None, **kwargs): - from google.cloud.spanner_v1.snapshot import Snapshot - snapshot = mock.create_autospec(Snapshot, instance=True, **kwargs) if transaction_id is not None: snapshot._transaction_id = transaction_id @@ -2129,9 +2114,8 @@ def _make_keyset(): return KeySet(all_=True) def test_ctor_no_staleness(self): - database = self._make_database() - - batch_txn = self._make_one(database) + database = build_database() + batch_txn = BatchSnapshot(database) self.assertIs(batch_txn._database, database) self.assertIsNone(batch_txn._session) @@ -2140,10 +2124,9 @@ def test_ctor_no_staleness(self): self.assertIsNone(batch_txn._exact_staleness) def test_ctor_w_read_timestamp(self): - database = self._make_database() + database = build_database() timestamp = self._make_timestamp() - - batch_txn = self._make_one(database, read_timestamp=timestamp) + batch_txn = BatchSnapshot(database, read_timestamp=timestamp) self.assertIs(batch_txn._database, database) self.assertIsNone(batch_txn._session) @@ -2152,10 +2135,9 @@ def test_ctor_w_read_timestamp(self): self.assertIsNone(batch_txn._exact_staleness) def test_ctor_w_exact_staleness(self): - database = self._make_database() + database = build_database() duration = self._make_duration() - - batch_txn = self._make_one(database, exact_staleness=duration) + batch_txn = BatchSnapshot(database, exact_staleness=duration) self.assertIs(batch_txn._database, database) self.assertIsNone(batch_txn._session) @@ -2164,103 +2146,87 @@ def test_ctor_w_exact_staleness(self): self.assertEqual(batch_txn._exact_staleness, duration) def test_from_dict(self): - klass = self._get_target_class() - database = self._make_database() - session = database.session.return_value = 
self._make_session() - snapshot = session.snapshot.return_value = self._make_snapshot() - api_repr = { - "session_id": self.SESSION_ID, - "transaction_id": self.TRANSACTION_ID, - } + database = build_database() + + batch_txn = BatchSnapshot.from_dict( + database, + { + "transaction_id": TRANSACTION_ID, + "session_id": SESSION_ID, + }, + ) - batch_txn = klass.from_dict(database, api_repr) self.assertIs(batch_txn._database, database) - self.assertIs(batch_txn._session, session) - self.assertEqual(session._session_id, self.SESSION_ID) - self.assertEqual(snapshot._transaction_id, self.TRANSACTION_ID) - snapshot.begin.assert_not_called() - self.assertIs(batch_txn._snapshot, snapshot) + self.assertIs(batch_txn._transaction_id, TRANSACTION_ID) + self.assertIs(batch_txn._session_id, SESSION_ID) - def test_to_dict(self): - database = self._make_database() - batch_txn = self._make_one(database) - batch_txn._session = self._make_session(_session_id=self.SESSION_ID) - batch_txn._snapshot = self._make_snapshot(transaction_id=self.TRANSACTION_ID) - - expected = { - "session_id": self.SESSION_ID, - "transaction_id": self.TRANSACTION_ID, - } - self.assertEqual(batch_txn.to_dict(), expected) + session = batch_txn._get_session() + self.assertEqual(session._session_id, SESSION_ID) + + snapshot = batch_txn._get_snapshot() + self.assertEqual(snapshot._transaction_id, TRANSACTION_ID) + database.spanner_api.begin_transaction.assert_not_called() - def test__get_session_already(self): - database = self._make_database() - batch_txn = self._make_one(database) - already = batch_txn._session = object() - self.assertIs(batch_txn._get_session(), already) + def test_to_dict(self): + database = build_database() + batch_txn = BatchSnapshot(database) - def test__get_session_new(self): - database = self._make_database() - session = database.session.return_value = self._make_session() - batch_txn = self._make_one(database) - self.assertIs(batch_txn._get_session(), session) - 
session.create.assert_called_once_with() + self.assertEqual( + batch_txn.to_dict(), + { + "transaction_id": TRANSACTION_ID, + "session_id": SESSION_ID, + }, + ) def test__get_snapshot_already(self): - database = self._make_database() - batch_txn = self._make_one(database) - already = batch_txn._snapshot = self._make_snapshot() - self.assertIs(batch_txn._get_snapshot(), already) - already.begin.assert_not_called() + database = build_database() + batch_txn = BatchSnapshot(database) + snapshot_1 = batch_txn._get_snapshot() + + snapshot_2 = batch_txn._get_snapshot() + self.assertEqual(snapshot_1, snapshot_2) + database.spanner_api.begin_transaction.assert_called_once() def test__get_snapshot_new_wo_staleness(self): - database = self._make_database() - batch_txn = self._make_one(database) - session = batch_txn._session = self._make_session() - snapshot = session.snapshot.return_value = self._make_snapshot() - self.assertIs(batch_txn._get_snapshot(), snapshot) - session.snapshot.assert_called_once_with( - read_timestamp=None, - exact_staleness=None, - multi_use=True, - transaction_id=None, - ) - snapshot.begin.assert_called_once_with() + database = build_database() + batch_txn = BatchSnapshot(database) + snapshot = batch_txn._get_snapshot() + + self.assertIsNone(snapshot._read_timestamp) + self.assertIsNone(snapshot._exact_staleness) + self.assertTrue(snapshot._multi_use) + self.assertEqual(TRANSACTION_ID, snapshot._transaction_id) + database.spanner_api.begin_transaction.assert_called_once() def test__get_snapshot_w_read_timestamp(self): - database = self._make_database() + database = build_database() timestamp = self._make_timestamp() - batch_txn = self._make_one(database, read_timestamp=timestamp) - session = batch_txn._session = self._make_session() - snapshot = session.snapshot.return_value = self._make_snapshot() - self.assertIs(batch_txn._get_snapshot(), snapshot) - session.snapshot.assert_called_once_with( - read_timestamp=timestamp, - exact_staleness=None, - 
multi_use=True, - transaction_id=None, - ) - snapshot.begin.assert_called_once_with() + batch_txn = BatchSnapshot(database, read_timestamp=timestamp) + snapshot = batch_txn._get_snapshot() + + self.assertEqual(timestamp, snapshot._read_timestamp) + self.assertIsNone(snapshot._exact_staleness) + self.assertTrue(snapshot._multi_use) + self.assertEqual(TRANSACTION_ID, snapshot._transaction_id) + database.spanner_api.begin_transaction.assert_called_once() def test__get_snapshot_w_exact_staleness(self): - database = self._make_database() + database = build_database() duration = self._make_duration() - batch_txn = self._make_one(database, exact_staleness=duration) - session = batch_txn._session = self._make_session() - snapshot = session.snapshot.return_value = self._make_snapshot() - self.assertIs(batch_txn._get_snapshot(), snapshot) - session.snapshot.assert_called_once_with( - read_timestamp=None, - exact_staleness=duration, - multi_use=True, - transaction_id=None, - ) - snapshot.begin.assert_called_once_with() + batch_txn = BatchSnapshot(database, exact_staleness=duration) + snapshot = batch_txn._get_snapshot() + + self.assertIsNone(snapshot._read_timestamp) + self.assertEqual(duration, snapshot._exact_staleness) + self.assertTrue(snapshot._multi_use) + self.assertEqual(TRANSACTION_ID, snapshot._transaction_id) + database.spanner_api.begin_transaction.assert_called_once() def test_read(self): keyset = self._make_keyset() - database = self._make_database() - batch_txn = self._make_one(database) + database = build_database() + batch_txn = BatchSnapshot(database) snapshot = batch_txn._snapshot = self._make_snapshot() rows = batch_txn.read(self.TABLE, self.COLUMNS, keyset, self.INDEX) @@ -2276,8 +2242,8 @@ def test_execute_sql(self): ) params = {"max_age": 30} param_types = {"max_age": "INT64"} - database = self._make_database() - batch_txn = self._make_one(database) + database = build_database() + batch_txn = BatchSnapshot(database) snapshot = batch_txn._snapshot = 
self._make_snapshot() rows = batch_txn.execute_sql(sql, params, param_types) @@ -2288,8 +2254,8 @@ def test_execute_sql(self): def test_generate_read_batches_w_max_partitions(self): max_partitions = len(self.TOKENS) keyset = self._make_keyset() - database = self._make_database() - batch_txn = self._make_one(database) + database = build_database() + batch_txn = BatchSnapshot(database) snapshot = batch_txn._snapshot = self._make_snapshot() snapshot.partition_read.return_value = self.TOKENS @@ -2326,8 +2292,8 @@ def test_generate_read_batches_w_max_partitions(self): def test_generate_read_batches_w_retry_and_timeout_params(self): max_partitions = len(self.TOKENS) keyset = self._make_keyset() - database = self._make_database() - batch_txn = self._make_one(database) + database = build_database() + batch_txn = BatchSnapshot(database) snapshot = batch_txn._snapshot = self._make_snapshot() snapshot.partition_read.return_value = self.TOKENS retry = Retry(deadline=60) @@ -2369,8 +2335,8 @@ def test_generate_read_batches_w_retry_and_timeout_params(self): def test_generate_read_batches_w_index_w_partition_size_bytes(self): size = 1 << 20 keyset = self._make_keyset() - database = self._make_database() - batch_txn = self._make_one(database) + database = build_database() + batch_txn = BatchSnapshot(database) snapshot = batch_txn._snapshot = self._make_snapshot() snapshot.partition_read.return_value = self.TOKENS @@ -2411,8 +2377,8 @@ def test_generate_read_batches_w_index_w_partition_size_bytes(self): def test_generate_read_batches_w_data_boost_enabled(self): data_boost_enabled = True keyset = self._make_keyset() - database = self._make_database() - batch_txn = self._make_one(database) + database = build_database() + batch_txn = BatchSnapshot(database) snapshot = batch_txn._snapshot = self._make_snapshot() snapshot.partition_read.return_value = self.TOKENS @@ -2452,8 +2418,8 @@ def test_generate_read_batches_w_data_boost_enabled(self): def 
test_generate_read_batches_w_directed_read_options(self): keyset = self._make_keyset() - database = self._make_database() - batch_txn = self._make_one(database) + database = build_database() + batch_txn = BatchSnapshot(database) snapshot = batch_txn._snapshot = self._make_snapshot() snapshot.partition_read.return_value = self.TOKENS @@ -2503,8 +2469,8 @@ def test_process_read_batch(self): "index": self.INDEX, }, } - database = self._make_database() - batch_txn = self._make_one(database) + database = build_database() + batch_txn = BatchSnapshot(database) snapshot = batch_txn._snapshot = self._make_snapshot() expected = snapshot.read.return_value = object() @@ -2534,8 +2500,8 @@ def test_process_read_batch_w_retry_timeout(self): "index": self.INDEX, }, } - database = self._make_database() - batch_txn = self._make_one(database) + database = build_database() + batch_txn = BatchSnapshot(database) snapshot = batch_txn._snapshot = self._make_snapshot() expected = snapshot.read.return_value = object() retry = Retry(deadline=60) @@ -2559,7 +2525,7 @@ def test_generate_query_batches_w_max_partitions(self): client = _Client(self.PROJECT_ID) instance = _Instance(self.INSTANCE_NAME, client=client) database = _Database(self.DATABASE_NAME, instance=instance) - batch_txn = self._make_one(database) + batch_txn = BatchSnapshot(database) snapshot = batch_txn._snapshot = self._make_snapshot() snapshot.partition_query.return_value = self.TOKENS @@ -2598,7 +2564,7 @@ def test_generate_query_batches_w_params_w_partition_size_bytes(self): client = _Client(self.PROJECT_ID) instance = _Instance(self.INSTANCE_NAME, client=client) database = _Database(self.DATABASE_NAME, instance=instance) - batch_txn = self._make_one(database) + batch_txn = BatchSnapshot(database) snapshot = batch_txn._snapshot = self._make_snapshot() snapshot.partition_query.return_value = self.TOKENS @@ -2641,7 +2607,7 @@ def test_generate_query_batches_w_retry_and_timeout_params(self): client = _Client(self.PROJECT_ID) 
instance = _Instance(self.INSTANCE_NAME, client=client) database = _Database(self.DATABASE_NAME, instance=instance) - batch_txn = self._make_one(database) + batch_txn = BatchSnapshot(database) snapshot = batch_txn._snapshot = self._make_snapshot() snapshot.partition_query.return_value = self.TOKENS retry = Retry(deadline=60) @@ -2684,7 +2650,7 @@ def test_generate_query_batches_w_data_boost_enabled(self): client = _Client(self.PROJECT_ID) instance = _Instance(self.INSTANCE_NAME, client=client) database = _Database(self.DATABASE_NAME, instance=instance) - batch_txn = self._make_one(database) + batch_txn = BatchSnapshot(database) snapshot = batch_txn._snapshot = self._make_snapshot() snapshot.partition_query.return_value = self.TOKENS @@ -2716,7 +2682,7 @@ def test_generate_query_batches_w_directed_read_options(self): client = _Client(self.PROJECT_ID) instance = _Instance(self.INSTANCE_NAME, client=client) database = _Database(self.DATABASE_NAME, instance=instance) - batch_txn = self._make_one(database) + batch_txn = BatchSnapshot(database) snapshot = batch_txn._snapshot = self._make_snapshot() snapshot.partition_query.return_value = self.TOKENS @@ -2758,8 +2724,8 @@ def test_process_query_batch(self): "partition": token, "query": {"sql": sql, "params": params, "param_types": param_types}, } - database = self._make_database() - batch_txn = self._make_one(database) + database = build_database() + batch_txn = BatchSnapshot(database) snapshot = batch_txn._snapshot = self._make_snapshot() expected = snapshot.execute_sql.return_value = object() @@ -2787,8 +2753,8 @@ def test_process_query_batch_w_retry_timeout(self): "partition": token, "query": {"sql": sql, "params": params, "param_types": param_types}, } - database = self._make_database() - batch_txn = self._make_one(database) + database = build_database() + batch_txn = BatchSnapshot(database) snapshot = batch_txn._snapshot = self._make_snapshot() expected = snapshot.execute_sql.return_value = object() retry = 
Retry(deadline=60) @@ -2812,8 +2778,8 @@ def test_process_query_batch_w_directed_read_options(self): "partition": token, "query": {"sql": sql, "directed_read_options": DIRECTED_READ_OPTIONS}, } - database = self._make_database() - batch_txn = self._make_one(database) + database = build_database() + batch_txn = BatchSnapshot(database) snapshot = batch_txn._snapshot = self._make_snapshot() expected = snapshot.execute_sql.return_value = object() @@ -2830,25 +2796,26 @@ def test_process_query_batch_w_directed_read_options(self): ) def test_close_wo_session(self): - database = self._make_database() - batch_txn = self._make_one(database) + database = build_database() + batch_txn = BatchSnapshot(database) batch_txn.close() # no raise def test_close_w_session(self): - database = self._make_database() - batch_txn = self._make_one(database) - session = batch_txn._session = self._make_session() + database = build_database() + database._session_manager.put_session = mock.Mock() + batch_txn = BatchSnapshot(database) + session = batch_txn._get_session() batch_txn.close() - session.delete.assert_called_once_with() + database._session_manager.put_session.assert_called_once_with(session) def test_process_w_invalid_batch(self): token = b"TOKEN" batch = {"partition": token, "bogus": b"BOGUS"} - database = self._make_database() - batch_txn = self._make_one(database) + database = build_database() + batch_txn = BatchSnapshot(database) with self.assertRaises(ValueError): batch_txn.process(batch) @@ -2865,8 +2832,8 @@ def test_process_w_read_batch(self): "index": self.INDEX, }, } - database = self._make_database() - batch_txn = self._make_one(database) + database = build_database() + batch_txn = BatchSnapshot(database) snapshot = batch_txn._snapshot = self._make_snapshot() expected = snapshot.read.return_value = object() @@ -2895,8 +2862,8 @@ def test_process_w_query_batch(self): "partition": token, "query": {"sql": sql, "params": params, "param_types": param_types}, } - database = 
self._make_database() - batch_txn = self._make_one(database) + database = build_database() + batch_txn = BatchSnapshot(database) snapshot = batch_txn._snapshot = self._make_snapshot() expected = snapshot.execute_sql.return_value = object() @@ -2935,14 +2902,16 @@ def test_ctor(self): session = _Session(database) session_manager = database._session_manager - session_manager.get_session_for_read_write = mock.Mock(return_value=session) + session_manager.get_session = mock.Mock(return_value=session) session_manager.put_session = mock.Mock(return_value=None) checkout = self._make_one(database) self.assertIs(checkout._database, database) with checkout as groups: - session_manager.get_session_for_read_write.assert_called_once() + session_manager.get_session.assert_called_once_with( + TransactionType.READ_WRITE + ) self.assertIsInstance(groups, MutationGroups) self.assertIs(groups._session, session) @@ -2959,7 +2928,7 @@ def test_context_mgr_success(self): from google.cloud.spanner_v1.batch import MutationGroups from google.rpc.status_pb2 import Status - now = datetime.datetime.utcnow().replace(tzinfo=UTC) + now = datetime.datetime.now(tz=UTC) now_pb = _datetime_to_pb_timestamp(now) status_pb = Status(code=200) response = BatchWriteResponse( @@ -2973,7 +2942,7 @@ def test_context_mgr_success(self): session = _Session(database) session_manager = database._session_manager - session_manager.get_session_for_read_write = mock.Mock(return_value=session) + session_manager.get_session = mock.Mock(return_value=session) session_manager.put_session = mock.Mock(return_value=None) checkout = self._make_one(database) @@ -2997,7 +2966,9 @@ def test_context_mgr_success(self): request_options=request_options, ) with checkout as groups: - session_manager.get_session_for_read_write.assert_called_once() + session_manager.get_session.assert_called_once_with( + TransactionType.READ_WRITE + ) self.assertIsInstance(groups, MutationGroups) self.assertIs(groups._session, session) group = 
groups.group() @@ -3029,7 +3000,7 @@ def test_context_mgr_failure(self): session = _Session(database) session_manager = database._session_manager - session_manager.get_session_for_read_write = mock.Mock(return_value=session) + session_manager.get_session = mock.Mock(return_value=session) session_manager.put_session = mock.Mock(return_value=None) checkout = self._make_one(database) @@ -3040,7 +3011,9 @@ class Testing(Exception): with self.assertRaises(Testing): with checkout as groups: self.assertIsInstance(groups, MutationGroups) - session_manager.get_session_for_read_write.assert_called_once() + session_manager.get_session.assert_called_once_with( + TransactionType.READ_WRITE + ) self.assertIs(groups._session, session) raise Testing() @@ -3048,64 +3021,69 @@ class Testing(Exception): class TestSessionCheckout(_BaseTest): - def _get_target_class(self): - from google.cloud.spanner_v1.database import SessionCheckout - - return SessionCheckout - def test_ctor(self): - database = Database( - database_id=self.DATABASE_ID, instance=_Instance(self.INSTANCE_NAME) - ) + database = build_database() - checkout = self._make_one(database) + # Default transaction type. + checkout = SessionCheckout(database) self.assertIs(checkout._database, database) + self.assertIs(checkout._transaction_type, TransactionType.READ_WRITE) + self.assertIsNone(checkout._session) + + # Specified transaction type. 
+ transaction_type = TransactionType.READ_ONLY + checkout = SessionCheckout(database, transaction_type) + self.assertIs(checkout._database, database) + self.assertIs(checkout._transaction_type, transaction_type) self.assertIsNone(checkout._session) def test_context_manager_success(self): - database = Database( - database_id=self.DATABASE_ID, instance=_Instance(self.INSTANCE_NAME) - ) + database = build_database() + transaction_type = TransactionType.READ_ONLY + checkout = SessionCheckout(database, transaction_type) - session = _Session(database) - session_manager = database._session_manager - session_manager.get_session_for_read_write = mock.Mock(return_value=session) + session = build_session(database=database) + session_manager = session._database._session_manager + session_manager.get_session = mock.Mock(return_value=session) session_manager.put_session = mock.Mock(return_value=None) - checkout = self._make_one(database) - with checkout as borrowed: - session_manager.get_session_for_read_write.assert_called_once() + session_manager.get_session.assert_called_once_with(transaction_type) self.assertIs(borrowed, session) session_manager.put_session.assert_called_once_with(session) def test_context_manager_failure(self): - database = Database( - database_id=self.DATABASE_ID, instance=_Instance(self.INSTANCE_NAME) - ) + database = build_database() + transaction_type = TransactionType.READ_ONLY + checkout = SessionCheckout(database, transaction_type) - session = _Session(database) - session_manager = database._session_manager - session_manager.get_session_for_read_write = mock.Mock(return_value=session) + session = build_session(database=database) + session_manager = session._database._session_manager + session_manager.get_session = mock.Mock(return_value=session) session_manager.put_session = mock.Mock(return_value=None) - checkout = self._make_one(database) - class Testing(Exception): pass with self.assertRaises(Testing): with checkout as borrowed: - 
session_manager.get_session_for_read_write.assert_called_once() + session_manager.get_session.assert_called_once_with(transaction_type) self.assertIs(borrowed, session) raise Testing() session_manager.put_session.assert_called_once_with(session) - def test_type_error(self): + def test_type_errors(self): + database = build_database() + transaction_type = TransactionType.READ_ONLY + with self.assertRaises(TypeError): - with self._make_one(None) as _: + with SessionCheckout(None, transaction_type) as _: + pass + + with self.assertRaises(TypeError): + with SessionCheckout(database, None) as _: pass @@ -3137,6 +3115,7 @@ def __init__( self._endpoint_cache = {} self.database_admin_api = _make_database_admin_api() self.instance_admin_api = _make_instance_api() + self.credentials = mock.Mock() self._client_info = mock.Mock() self._client_options = mock.Mock() self._query_options = ExecuteSqlRequest.QueryOptions(optimizer_version="1") diff --git a/tests/unit/test_database_session_manager.py b/tests/unit/test_database_session_manager.py index c1346da163..6019dccb28 100644 --- a/tests/unit/test_database_session_manager.py +++ b/tests/unit/test_database_session_manager.py @@ -14,8 +14,9 @@ import os import datetime import time +from threading import Thread from unittest import TestCase -from unittest.mock import Mock, patch, DEFAULT, PropertyMock +from unittest.mock import Mock, DEFAULT, patch from google.api_core.exceptions import ( MethodNotImplemented, @@ -23,48 +24,33 @@ FailedPrecondition, ) -from google.cloud.spanner_v1 import SpannerClient -from google.cloud.spanner_v1.client import Client -from google.cloud.spanner_v1.database import Database from google.cloud.spanner_v1.database_sessions_manager import DatabaseSessionsManager -from google.cloud.spanner_v1.instance import Instance -from google.cloud.spanner_v1.session_options import SessionOptions +from google.cloud.spanner_v1.session_options import TransactionType +from tests._helpers import 
disable_multiplexed_sessions, enable_multiplexed_sessions +# Shorten polling and refresh intervals for testing. +@patch.multiple( + DatabaseSessionsManager, + _MAINTENANCE_THREAD_POLLING_INTERVAL=datetime.timedelta(seconds=1), + _MAINTENANCE_THREAD_REFRESH_INTERVAL=datetime.timedelta(seconds=2), +) class TestDatabaseSessionManager(TestCase): def setUp(self): self._original_env = dict(os.environ) - - self._mocks = { - "create_session": patch.object(SpannerClient, "create_session").start(), - "delete_session": patch.object(SpannerClient, "delete_session").start(), - # Mock faster polling and refresh intervals for tests. - "polling_interval": patch.object( - DatabaseSessionsManager, - "_MAINTENANCE_THREAD_POLLING_INTERVAL", - new_callable=PropertyMock, - return_value=datetime.timedelta(seconds=1), - ).start(), - "refresh_interval": patch.object( - DatabaseSessionsManager, - "_MAINTENANCE_THREAD_REFRESH_INTERVAL", - new_callable=PropertyMock, - return_value=datetime.timedelta(seconds=2), - ).start(), - } + self._build_session_manager() def tearDown(self): + self._cleanup_database_session_manager() os.environ.clear() os.environ.update(self._original_env) - patch.stopall() - def test_read_only_pooled(self): - self._disable_multiplexed_env_vars() - session_manager = self._build_database_session_manager() + disable_multiplexed_sessions() + session_manager = self._session_manager # Get session from pool. - session = session_manager.get_session_for_read_only() + session = session_manager.get_session(TransactionType.READ_ONLY) self.assertFalse(session.is_multiplexed) session_manager._pool.get.assert_called_once() @@ -73,16 +59,16 @@ def test_read_only_pooled(self): session_manager._pool.put.assert_called_once_with(session) def test_read_only_multiplexed(self): - self._enable_multiplexed_env_vars() - session_manager = self._build_database_session_manager() + enable_multiplexed_sessions() + session_manager = self._session_manager # Session is created. 
- session_1 = session_manager.get_session_for_read_only() + session_1 = session_manager.get_session(TransactionType.READ_ONLY) self.assertTrue(session_1.is_multiplexed) session_manager.put_session(session_1) # Session is re-used. - session_2 = session_manager.get_session_for_read_only() + session_2 = session_manager.get_session(TransactionType.READ_ONLY) self.assertEqual(session_1, session_2) session_manager.put_session(session_2) @@ -91,11 +77,11 @@ def test_read_only_multiplexed(self): session_manager._pool.put.assert_not_called() def test_partitioned_pooled(self): - self._disable_multiplexed_env_vars() - session_manager = self._build_database_session_manager() + disable_multiplexed_sessions() + session_manager = self._session_manager # Get session from pool. - session = session_manager.get_session_for_partitioned() + session = session_manager.get_session(TransactionType.PARTITIONED) self.assertFalse(session.is_multiplexed) session_manager._pool.get.assert_called_once() @@ -104,18 +90,29 @@ def test_partitioned_pooled(self): session_manager._pool.put.assert_called_once_with(session) def test_partitioned_multiplexed(self): - self._enable_multiplexed_env_vars() - session_manager = self._build_database_session_manager() + enable_multiplexed_sessions() + session_manager = self._session_manager - with self.assertRaises(NotImplementedError): - session_manager.get_session_for_partitioned() + # Session is created. + session_1 = session_manager.get_session(TransactionType.PARTITIONED) + self.assertTrue(session_1.is_multiplexed) + session_manager.put_session(session_1) + + # Session is re-used. + session_2 = session_manager.get_session(TransactionType.PARTITIONED) + self.assertEqual(session_1, session_2) + session_manager.put_session(session_2) + + # Verify that pool was not used. 
+ session_manager._pool.get.assert_not_called() + session_manager._pool.put.assert_not_called() def test_read_write_pooled(self): - self._disable_multiplexed_env_vars() - session_manager = self._build_database_session_manager() + disable_multiplexed_sessions() + session_manager = self._session_manager # Get session from pool. - session = session_manager.get_session_for_read_write() + session = session_manager.get_session(TransactionType.READ_WRITE) self.assertFalse(session.is_multiplexed) session_manager._pool.get.assert_called_once() @@ -124,89 +121,81 @@ def test_read_write_pooled(self): session_manager._pool.put.assert_called_once_with(session) def test_read_write_multiplexed(self): - self._enable_multiplexed_env_vars() - session_manager = self._build_database_session_manager() + enable_multiplexed_sessions() + session_manager = self._session_manager with self.assertRaises(NotImplementedError): - session_manager.get_session_for_read_write() + session_manager.get_session(TransactionType.READ_WRITE) - def test_multiplexed_maintenance(self): - self._enable_multiplexed_env_vars() - session_manager = self._build_database_session_manager() + def test_multiplexed_maintenance(self, *_): + enable_multiplexed_sessions() + session_manager = self._session_manager # Maintenance thread is started. - session_1 = session_manager.get_session_for_read_only() + session_1 = session_manager.get_session(TransactionType.READ_ONLY) self.assertTrue(session_1.is_multiplexed) # Wait for maintenance thread to execute. + api = session_manager._database.spanner_api + def create_session_condition(): - return self._mocks["create_session"].call_count > 1 + return api.create_session.call_count > 1 - self.assert_true_with_timeout(create_session_condition) + self._assert_true_with_timeout(create_session_condition) # Verify that maintenance thread created new multiplexed session. 
- session_2 = session_manager.get_session_for_read_only() + session_2 = session_manager.get_session(TransactionType.READ_ONLY) self.assertTrue(session_2.is_multiplexed) self.assertNotEqual(session_1, session_2) def test_multiplexed_maintenance_terminates_not_implemented(self): - self._enable_multiplexed_env_vars() - session_manager = self._build_database_session_manager() + enable_multiplexed_sessions() + session_manager = self._session_manager # Maintenance thread is started. - session_1 = session_manager.get_session_for_read_only() + session_1 = session_manager.get_session(TransactionType.READ_ONLY) self.assertTrue(session_1.is_multiplexed) # Multiplexed sessions not implemented. - create_session_mock = self._mocks["create_session"] - create_session_mock.side_effect = MethodNotImplemented( - "Multiplexed sessions not implemented" - ) + api = session_manager._database.spanner_api + api.create_session.side_effect = MethodNotImplemented("test") - # Wait for maintenance thread to terminate. + # Verify that maintenance thread is terminated. thread = session_manager._multiplexed_session_maintenance_thread + self._assert_thread_terminated(thread) - def thread_terminated_condition(): - return not thread.is_alive() - - self.assert_true_with_timeout(thread_terminated_condition) - - # Verify that multiplexed sessions are disabled. - session_options = session_manager._database._instance._client.session_options - self.assertFalse(session_options.use_multiplexed_for_read_only()) - self.assertFalse(session_options.use_multiplexed_for_partitioned()) - self.assertFalse(session_options.use_multiplexed_for_read_write()) + # Verify that multiplexed session are disabled. 
+ session_options = session_manager._database.session_options + self.assertFalse(session_options.use_multiplexed(TransactionType.READ_ONLY)) + self.assertFalse(session_options.use_multiplexed(TransactionType.PARTITIONED)) + self.assertFalse(session_options.use_multiplexed(TransactionType.READ_WRITE)) def test_multiplexed_maintenance_terminates_disabled(self): - self._enable_multiplexed_env_vars() - session_manager = self._build_database_session_manager() + enable_multiplexed_sessions() + session_manager = self._session_manager # Maintenance thread is started. - session_1 = session_manager.get_session_for_read_only() + session_1 = session_manager.get_session(TransactionType.READ_ONLY) self.assertTrue(session_1.is_multiplexed) session_manager._is_multiplexed_sessions_disabled_event.set() - # Wait for maintenance thread to terminate. thread = session_manager._multiplexed_session_maintenance_thread - - def thread_terminated_condition(): - return not thread.is_alive() - - self.assert_true_with_timeout(thread_terminated_condition) + self._assert_thread_terminated(thread) def test_multiplexed_exception_method_not_implemented(self): - self._enable_multiplexed_env_vars() - session_manager = self._build_database_session_manager() + enable_multiplexed_sessions() + session_manager = self._session_manager # Multiplexed sessions not implemented. - self._mocks["create_session"].side_effect = [ + api = session_manager._database.spanner_api + api.create_session.side_effect = [ MethodNotImplemented("Test MethodNotImplemented"), DEFAULT, ] # Get session from pool. - session = session_manager.get_session_for_read_only() + session = session_manager.get_session(TransactionType.READ_ONLY) self.assertFalse(session.is_multiplexed) session_manager._pool.get.assert_called_once() @@ -215,37 +204,36 @@ def test_multiplexed_exception_method_not_implemented(self): session_manager._pool.put.assert_called_once_with(session) # Verify that multiplexed session are disabled. 
- session_options = session_manager._database._instance._client.session_options - self.assertFalse(session_options.use_multiplexed_for_read_only()) - self.assertFalse(session_options.use_multiplexed_for_partitioned()) - self.assertFalse(session_options.use_multiplexed_for_read_write()) + session_options = session_manager._database.session_options + self.assertFalse(session_options.use_multiplexed(TransactionType.READ_ONLY)) + self.assertFalse(session_options.use_multiplexed(TransactionType.PARTITIONED)) + self.assertFalse(session_options.use_multiplexed(TransactionType.READ_WRITE)) def test_exception_bad_request(self): - session_manager = self._build_database_session_manager() + session_manager = self._session_manager + + api = session_manager._database.spanner_api + api.create_session.side_effect = BadRequest("test") # Verify that BadRequest is not caught. with self.assertRaises(BadRequest): - self._mocks["create_session"].side_effect = BadRequest("Test BadRequest") - session_manager.get_session_for_read_only() + session_manager.get_session(TransactionType.READ_ONLY) def test_exception_failed_precondition(self): - session_manager = self._build_database_session_manager() + session_manager = self._session_manager + + api = session_manager._database.spanner_api + api.create_session.side_effect = FailedPrecondition("test") # Verify that FailedPrecondition is not caught. 
with self.assertRaises(FailedPrecondition): - self._mocks["create_session"].side_effect = FailedPrecondition( - "Test FailedPrecondition" - ) - session_manager.get_session_for_read_only() - - @staticmethod - def _build_database_session_manager(): - """Builds and returns a new database session manager for testing.""" + session_manager.get_session(TransactionType.READ_ONLY) - client = Client(project="project-id") - instance = Instance(instance_id="instance-id", client=client) + def _build_session_manager(self) -> DatabaseSessionsManager: + """Builds a new database session manager for testing.""" + from tests._builders import build_database - database = Database(database_id="database-id", instance=instance) + database = build_database() session_manager = database._session_manager # Mock the session pool. @@ -253,25 +241,22 @@ def _build_database_session_manager(): pool.get = Mock(wraps=pool.get) pool.put = Mock(wraps=pool.put) - return session_manager + self._session_manager = session_manager - @staticmethod - def _enable_multiplexed_env_vars(): - """Sets environment variables to enable multiplexed sessions.""" + def _cleanup_database_session_manager(self) -> None: + """Cleans up the database session manager after testing.""" - os.environ[SessionOptions.ENV_VAR_ENABLE_MULTIPLEXED] = "true" - os.environ[SessionOptions.ENV_VAR_ENABLE_MULTIPLEXED_FOR_PARTITIONED] = "true" - os.environ[SessionOptions.ENV_VAR_ENABLE_MULTIPLEXED_FOR_READ_WRITE] = "true" - os.environ[SessionOptions.ENV_VAR_FORCE_DISABLE_MULTIPLEXED] = "false" - - @staticmethod - def _disable_multiplexed_env_vars(): - """Sets environment variables to disable multiplexed sessions.""" + # If the maintenance thread is still alive, disable multiplexed sessions and + # wait for the thread to terminate. We need to do this to ensure that the + # thread is properly cleaned up and does not interfere with other tests. 
+ session_manager = self._session_manager + thread = session_manager._multiplexed_session_maintenance_thread - os.environ[SessionOptions.ENV_VAR_FORCE_DISABLE_MULTIPLEXED] = "true" + if thread and thread.is_alive(): + session_manager._is_multiplexed_sessions_disabled_event.set() + self._assert_thread_terminated(thread) - @staticmethod - def assert_true_with_timeout(condition): + def _assert_true_with_timeout(self, condition): """Asserts that the given condition is met within a timeout period.""" sleep_seconds = 0.1 @@ -281,4 +266,12 @@ def assert_true_with_timeout(condition): while not condition() and time.time() - start_time < timeout_seconds: time.sleep(sleep_seconds) - assert condition() + self.assertTrue(condition()) + + def _assert_thread_terminated(self, thread: Thread): + """Asserts that the maintenance thread is terminated.""" + + def _is_thread_terminated(): + return not thread.is_alive() + + self._assert_true_with_timeout(_is_thread_terminated) diff --git a/tests/unit/test_instance.py b/tests/unit/test_instance.py index 9ece105a3d..5810478c54 100644 --- a/tests/unit/test_instance.py +++ b/tests/unit/test_instance.py @@ -540,7 +540,6 @@ def test_database_factory_defaults(self): self.assertIs(database._instance, instance) self.assertEqual(list(database.ddl_statements), []) self.assertIsInstance(database._session_manager._pool, BurstyPool) - self.assertIsNone(database._logger) self.assertIs(database._session_manager._pool._database, database) self.assertIsNone(database.database_role) diff --git a/tests/unit/test_pool.py b/tests/unit/test_pool.py index e6bc827f57..67c2b338f9 100644 --- a/tests/unit/test_pool.py +++ b/tests/unit/test_pool.py @@ -20,6 +20,7 @@ import mock from google.cloud.spanner_v1._opentelemetry_tracing import trace_call +from tests._builders import build_database from tests._helpers import ( OpenTelemetryBase, LIB_VERSION, @@ -29,18 +30,6 @@ ) -def _make_database(name="name"): - from google.cloud.spanner_v1.database import Database - - 
return mock.create_autospec(Database, instance=True) - - -def _make_session(): - from google.cloud.spanner_v1.database import Session - - return mock.create_autospec(Session, instance=True) - - class TestAbstractSessionPool(unittest.TestCase): def _getTargetClass(self): from google.cloud.spanner_v1.pool import AbstractSessionPool @@ -66,7 +55,7 @@ def test_ctor_explicit(self): def test_bind_abstract(self): pool = self._make_one() - database = _make_database("name") + database = build_database() with self.assertRaises(NotImplementedError): pool.bind(database) @@ -88,38 +77,32 @@ def test_clear_abstract(self): def test__new_session_wo_labels(self): pool = self._make_one() - database = pool._database = _make_database("name") - session = _make_session() - database.session.return_value = session + pool._database = build_database() new_session = pool._new_session() - self.assertIs(new_session, session) - database.session.assert_called_once_with(labels={}, database_role=None) + self.assertEqual(new_session.labels, {}) + self.assertIsNone(new_session.database_role) def test__new_session_w_labels(self): labels = {"foo": "bar"} pool = self._make_one(labels=labels) - database = pool._database = _make_database("name") - session = _make_session() - database.session.return_value = session + pool._database = build_database() new_session = pool._new_session() - self.assertIs(new_session, session) - database.session.assert_called_once_with(labels=labels, database_role=None) + self.assertEqual(new_session.labels, labels) + self.assertIsNone(new_session.database_role) def test__new_session_w_database_role(self): database_role = "dummy-role" pool = self._make_one(database_role=database_role) - database = pool._database = _make_database("name") - session = _make_session() - database.session.return_value = session + pool._database = build_database() new_session = pool._new_session() - self.assertIs(new_session, session) - database.session.assert_called_once_with(labels={}, 
database_role=database_role) + self.assertEqual(new_session.labels, {}) + self.assertEqual(new_session.database_role, database_role) class TestFixedSizePool(OpenTelemetryBase): @@ -167,10 +150,10 @@ def test_ctor_explicit(self): def test_bind(self): database_role = "dummy-role" pool = self._make_one() - database = _Database("name") - SESSIONS = [_Session(database)] * 10 + database = _Database() database._database_role = database_role - database._sessions.extend(SESSIONS) + sessions = [_Session(database)] * 10 + pool._new_session = mock.Mock(side_effect=sessions) pool.bind(database) @@ -182,37 +165,35 @@ def test_bind(self): api = database.spanner_api self.assertEqual(api.batch_create_sessions.call_count, 5) - for session in SESSIONS: + for session in sessions: session.create.assert_not_called() def test_get_active(self): pool = self._make_one(size=4) - database = _Database("name") - SESSIONS = sorted([_Session(database) for i in range(0, 4)]) - database._sessions.extend(SESSIONS) + database = _Database() + sessions = sorted([_Session(database)] * 4) + pool._new_session = mock.Mock(side_effect=sessions) pool.bind(database) # check if sessions returned in LIFO order for i in (3, 2, 1, 0): session = pool.get() - self.assertIs(session, SESSIONS[i]) + self.assertIs(session, sessions[i]) self.assertFalse(session._exists_checked) self.assertFalse(pool._sessions.full()) def test_get_non_expired(self): pool = self._make_one(size=4) - database = _Database("name") + database = _Database() last_use_time = datetime.utcnow() - timedelta(minutes=56) - SESSIONS = sorted( - [_Session(database, last_use_time=last_use_time) for i in range(0, 4)] - ) - database._sessions.extend(SESSIONS) + sessions = sorted([_Session(database, last_use_time=last_use_time)] * 4) + pool._new_session = mock.Mock(side_effect=sessions) pool.bind(database) # check if sessions returned in LIFO order for i in (3, 2, 1, 0): session = pool.get() - self.assertIs(session, SESSIONS[i]) + self.assertIs(session, 
sessions[i]) self.assertTrue(session._exists_checked) self.assertFalse(pool._sessions.full()) @@ -222,12 +203,12 @@ def test_spans_bind_get(self): # This tests retrieving 1 out of 4 sessions from the session pool. pool = self._make_one(size=4) - database = _Database("name") - SESSIONS = sorted([_Session(database) for i in range(0, 4)]) - database._sessions.extend(SESSIONS) + database = _Database() + sessions = sorted([_Session(database) for i in range(0, 4)]) + pool._new_session = mock.Mock(side_effect=sessions) pool.bind(database) - with trace_call("pool.Get", SESSIONS[0]): + with trace_call("pool.Get", sessions[0]): pool.get() span_list = self.get_finished_spans() @@ -264,7 +245,7 @@ def test_spans_bind_get_empty_pool(self): # Tests trying to invoke pool.get() from an empty pool. pool = self._make_one(size=0) - database = _Database("name") + database = _Database() session1 = _Session(database) with trace_call("pool.Get", session1): try: @@ -309,9 +290,8 @@ def test_spans_pool_bind(self): # Tests the exception generated from invoking pool.bind when # you have an empty pool. 
pool = self._make_one(size=1) - database = _Database("name") - SESSIONS = [] - database._sessions.extend(SESSIONS) + database = _Database() + pool._new_session = mock.Mock(side_effect=Exception("Test")) fauxSession = mock.Mock() setattr(fauxSession, "_database", database) try: @@ -357,8 +337,8 @@ def test_spans_pool_bind(self): ( "exception", { - "exception.type": "IndexError", - "exception.message": "pop from empty list", + "exception.type": "Exception", + "exception.message": "Test", "exception.stacktrace": "EPHEMERAL", "exception.escaped": "False", }, @@ -368,8 +348,8 @@ def test_spans_pool_bind(self): ( "exception", { - "exception.type": "IndexError", - "exception.message": "pop from empty list", + "exception.type": "Exception", + "exception.message": "Test", "exception.stacktrace": "EPHEMERAL", "exception.escaped": "False", }, @@ -379,18 +359,18 @@ def test_spans_pool_bind(self): def test_get_expired(self): pool = self._make_one(size=4) - database = _Database("name") + database = _Database() last_use_time = datetime.utcnow() - timedelta(minutes=65) - SESSIONS = [_Session(database, last_use_time=last_use_time)] * 5 - SESSIONS[0]._exists = False - database._sessions.extend(SESSIONS) + sessions = [_Session(database, last_use_time=last_use_time)] * 5 + sessions[0]._exists = False + pool._new_session = mock.Mock(side_effect=sessions) pool.bind(database) session = pool.get() - self.assertIs(session, SESSIONS[4]) + self.assertIs(session, sessions[4]) session.create.assert_called() - self.assertTrue(SESSIONS[0]._exists_checked) + self.assertTrue(sessions[0]._exists_checked) self.assertFalse(pool._sessions.full()) def test_get_empty_default_timeout(self): @@ -419,9 +399,9 @@ def test_put_full(self): import queue pool = self._make_one(size=4) - database = _Database("name") - SESSIONS = [_Session(database)] * 4 - database._sessions.extend(SESSIONS) + database = _Database() + sessions = [_Session(database)] * 4 + database._sessions.extend(sessions) pool.bind(database) 
self.reset() @@ -432,9 +412,9 @@ def test_put_full(self): def test_put_non_full(self): pool = self._make_one(size=4) - database = _Database("name") - SESSIONS = [_Session(database)] * 4 - database._sessions.extend(SESSIONS) + database = _Database() + sessions = [_Session(database)] * 4 + database._sessions.extend(sessions) pool.bind(database) pool._sessions.get() @@ -444,20 +424,20 @@ def test_put_non_full(self): def test_clear(self): pool = self._make_one() - database = _Database("name") - SESSIONS = [_Session(database)] * 10 - database._sessions.extend(SESSIONS) + database = _Database() + sessions = [_Session(database)] * 10 + pool._new_session = mock.Mock(side_effect=sessions) pool.bind(database) self.assertTrue(pool._sessions.full()) api = database.spanner_api self.assertEqual(api.batch_create_sessions.call_count, 5) - for session in SESSIONS: + for session in sessions: session.create.assert_not_called() pool.clear() - for session in SESSIONS: + for session in sessions: self.assertTrue(session._deleted) @@ -502,15 +482,15 @@ def test_ctor_explicit(self): def test_ctor_explicit_w_database_role_in_db(self): database_role = "dummy-role" pool = self._make_one() - database = pool._database = _Database("name") + database = pool._database = _Database() database._database_role = database_role pool.bind(database) self.assertEqual(pool.database_role, database_role) def test_get_empty(self): pool = self._make_one() - database = _Database("name") - database._sessions.append(_Session(database)) + database = _Database() + pool._new_session = mock.Mock(side_effect=[_Session(database)]) pool.bind(database) session = pool.get() @@ -528,9 +508,9 @@ def test_spans_get_empty_pool(self): # and pool.get() acquires from a pool, waiting for a session # to become available. 
pool = self._make_one() - database = _Database("name") + database = _Database() session1 = _Session(database) - database._sessions.append(session1) + pool._new_session = mock.Mock(side_effect=[session1]) pool.bind(database) with trace_call("pool.Get", session1): @@ -560,7 +540,7 @@ def test_spans_get_empty_pool(self): def test_get_non_empty_session_exists(self): pool = self._make_one() - database = _Database("name") + database = _Database() previous = _Session(database) pool.bind(database) pool.put(previous) @@ -576,7 +556,7 @@ def test_spans_get_non_empty_session_exists(self): # Tests the spans produces when you invoke pool.bind # and then insert a session into the pool. pool = self._make_one() - database = _Database("name") + database = _Database() previous = _Session(database) pool.bind(database) with trace_call("pool.Get", previous): @@ -598,10 +578,10 @@ def test_spans_get_non_empty_session_exists(self): def test_get_non_empty_session_expired(self): pool = self._make_one() - database = _Database("name") + database = _Database() previous = _Session(database, exists=False) newborn = _Session(database) - database._sessions.append(newborn) + pool._new_session = mock.Mock(side_effect=[newborn]) pool.bind(database) pool.put(previous) @@ -615,7 +595,7 @@ def test_get_non_empty_session_expired(self): def test_put_empty(self): pool = self._make_one() - database = _Database("name") + database = _Database() pool.bind(database) session = _Session(database) @@ -626,7 +606,7 @@ def test_put_empty(self): def test_spans_put_empty(self): # Tests the spans produced when you put sessions into an empty pool. 
pool = self._make_one() - database = _Database("name") + database = _Database() pool.bind(database) session = _Session(database) @@ -641,7 +621,7 @@ def test_spans_put_empty(self): def test_put_full(self): pool = self._make_one(target_size=1) - database = _Database("name") + database = _Database() pool.bind(database) older = _Session(database) pool.put(older) @@ -657,7 +637,7 @@ def test_spans_put_full(self): # This scenario tests the spans produced from putting an older # session into a pool that is already full. pool = self._make_one(target_size=1) - database = _Database("name") + database = _Database() pool.bind(database) older = _Session(database) with trace_call("pool.put", older): @@ -677,7 +657,7 @@ def test_spans_put_full(self): def test_put_full_expired(self): pool = self._make_one(target_size=1) - database = _Database("name") + database = _Database() pool.bind(database) older = _Session(database) pool.put(older) @@ -691,7 +671,7 @@ def test_put_full_expired(self): def test_clear(self): pool = self._make_one() - database = _Database("name") + database = _Database() pool.bind(database) previous = _Session(database) pool.put(previous) @@ -753,18 +733,18 @@ def test_ctor_explicit(self): def test_ctor_explicit_w_database_role_in_db(self): database_role = "dummy-role" pool = self._make_one() - database = pool._database = _Database("name") - SESSIONS = [_Session(database)] * 10 - database._sessions.extend(SESSIONS) + database = pool._database = _Database() + sessions = [_Session(database)] * 10 + database._sessions.extend(sessions) database._database_role = database_role pool.bind(database) self.assertEqual(pool.database_role, database_role) def test_bind(self): pool = self._make_one() - database = _Database("name") - SESSIONS = [_Session(database)] * 10 - database._sessions.extend(SESSIONS) + database = _Database() + sessions = [_Session(database)] * 10 + database._sessions.extend(sessions) pool.bind(database) self.assertIs(pool._database, database) @@ -775,20 
+755,20 @@ def test_bind(self): api = database.spanner_api self.assertEqual(api.batch_create_sessions.call_count, 5) - for session in SESSIONS: + for session in sessions: session.create.assert_not_called() def test_get_hit_no_ping(self): pool = self._make_one(size=4) - database = _Database("name") - SESSIONS = [_Session(database)] * 4 - database._sessions.extend(SESSIONS) + database = _Database() + sessions = [_Session(database)] * 4 + pool._new_session = mock.Mock(side_effect=sessions) pool.bind(database) self.reset() session = pool.get() - self.assertIs(session, SESSIONS[0]) + self.assertIs(session, sessions[0]) self.assertFalse(session._exists_checked) self.assertFalse(pool._sessions.full()) self.assertNoSpans() @@ -799,9 +779,9 @@ def test_get_hit_w_ping(self): from google.cloud.spanner_v1 import pool as MUT pool = self._make_one(size=4) - database = _Database("name") - SESSIONS = [_Session(database)] * 4 - database._sessions.extend(SESSIONS) + database = _Database() + sessions = [_Session(database)] * 4 + pool._new_session = mock.Mock(side_effect=sessions) sessions_created = datetime.datetime.utcnow() - datetime.timedelta(seconds=4000) @@ -812,7 +792,7 @@ def test_get_hit_w_ping(self): session = pool.get() - self.assertIs(session, SESSIONS[0]) + self.assertIs(session, sessions[0]) self.assertTrue(session._exists_checked) self.assertFalse(pool._sessions.full()) self.assertNoSpans() @@ -823,10 +803,10 @@ def test_get_hit_w_ping_expired(self): from google.cloud.spanner_v1 import pool as MUT pool = self._make_one(size=4) - database = _Database("name") - SESSIONS = [_Session(database)] * 5 - SESSIONS[0]._exists = False - database._sessions.extend(SESSIONS) + database = _Database() + sessions = [_Session(database)] * 5 + sessions[0]._exists = False + pool._new_session = mock.Mock(side_effect=sessions) sessions_created = datetime.datetime.utcnow() - datetime.timedelta(seconds=4000) @@ -836,9 +816,9 @@ def test_get_hit_w_ping_expired(self): session = pool.get() - 
self.assertIs(session, SESSIONS[4]) + self.assertIs(session, sessions[4]) session.create.assert_called() - self.assertTrue(SESSIONS[0]._exists_checked) + self.assertTrue(sessions[0]._exists_checked) self.assertFalse(pool._sessions.full()) self.assertNoSpans() @@ -870,9 +850,9 @@ def test_put_full(self): import queue pool = self._make_one(size=4) - database = _Database("name") - SESSIONS = [_Session(database)] * 4 - database._sessions.extend(SESSIONS) + database = _Database() + sessions = [_Session(database)] * 4 + database._sessions.extend(sessions) pool.bind(database) with self.assertRaises(queue.Full): @@ -887,9 +867,9 @@ def test_spans_put_full(self): import queue pool = self._make_one(size=4) - database = _Database("name") - SESSIONS = [_Session(database)] * 4 - database._sessions.extend(SESSIONS) + database = _Database() + sessions = [_Session(database)] * 4 + database._sessions.extend(sessions) pool.bind(database) with self.assertRaises(queue.Full): @@ -926,7 +906,7 @@ def test_put_non_full(self): session_queue = pool._sessions = _Queue() now = datetime.datetime.utcnow() - database = _Database("name") + database = _Database() session = _Session(database) with _Monkey(MUT, _NOW=lambda: now): @@ -940,21 +920,21 @@ def test_put_non_full(self): def test_clear(self): pool = self._make_one() - database = _Database("name") - SESSIONS = [_Session(database)] * 10 - database._sessions.extend(SESSIONS) + database = _Database() + sessions = [_Session(database)] * 10 + pool._new_session = mock.Mock(side_effect=sessions) pool.bind(database) self.reset() self.assertTrue(pool._sessions.full()) api = database.spanner_api self.assertEqual(api.batch_create_sessions.call_count, 5) - for session in SESSIONS: + for session in sessions: session.create.assert_not_called() pool.clear() - for session in SESSIONS: + for session in sessions: self.assertTrue(session._deleted) self.assertNoSpans() @@ -965,15 +945,15 @@ def test_ping_empty(self): def test_ping_oldest_fresh(self): pool = 
self._make_one(size=1) - database = _Database("name") - SESSIONS = [_Session(database)] * 1 - database._sessions.extend(SESSIONS) + database = _Database() + sessions = [_Session(database)] * 1 + database._sessions.extend(sessions) pool.bind(database) self.reset() pool.ping() - self.assertFalse(SESSIONS[0]._pinged) + self.assertFalse(sessions[0]._pinged) self.assertNoSpans() def test_ping_oldest_stale_but_exists(self): @@ -982,16 +962,16 @@ def test_ping_oldest_stale_but_exists(self): from google.cloud.spanner_v1 import pool as MUT pool = self._make_one(size=1) - database = _Database("name") - SESSIONS = [_Session(database)] * 1 - database._sessions.extend(SESSIONS) + database = _Database() + sessions = [_Session(database)] * 1 + pool._new_session = mock.Mock(side_effect=sessions) pool.bind(database) later = datetime.datetime.utcnow() + datetime.timedelta(seconds=4000) with _Monkey(MUT, _NOW=lambda: later): pool.ping() - self.assertTrue(SESSIONS[0]._pinged) + self.assertTrue(sessions[0]._pinged) def test_ping_oldest_stale_and_not_exists(self): import datetime @@ -999,10 +979,10 @@ def test_ping_oldest_stale_and_not_exists(self): from google.cloud.spanner_v1 import pool as MUT pool = self._make_one(size=1) - database = _Database("name") - SESSIONS = [_Session(database)] * 2 - SESSIONS[0]._exists = False - database._sessions.extend(SESSIONS) + database = _Database() + sessions = [_Session(database)] * 2 + sessions[0]._exists = False + pool._new_session = mock.Mock(side_effect=sessions) pool.bind(database) self.reset() @@ -1010,8 +990,8 @@ def test_ping_oldest_stale_and_not_exists(self): with _Monkey(MUT, _NOW=lambda: later): pool.ping() - self.assertTrue(SESSIONS[0]._pinged) - SESSIONS[1].create.assert_called() + self.assertTrue(sessions[0]._pinged) + sessions[1].create.assert_called() self.assertNoSpans() def test_spans_get_and_leave_empty_pool(self): @@ -1021,9 +1001,9 @@ def test_spans_get_and_leave_empty_pool(self): # This scenario tests the spans generated from 
pulling a span # out the pool and leaving it empty. pool = self._make_one() - database = _Database("name") + database = _Database() session1 = _Session(database) - database._sessions.append(session1) + pool._new_session = mock.Mock(side_effect=[session1]) try: pool.bind(database) except Exception: @@ -1098,11 +1078,11 @@ def delete(self): class _Database(object): - def __init__(self, name): - self.name = name + def __init__(self): + self.name = "name" self._sessions = [] self._database_role = None - self.database_id = name + self.database_id = self.name self._route_to_leader_enabled = True def mock_batch_create_sessions( @@ -1181,7 +1161,3 @@ def put(self, item, **kwargs): def put_nowait(self, item, **kwargs): self._put_nowait = kwargs self._items.append(item) - - -class _Pool(_Queue): - _database = None diff --git a/tests/unit/test_session_options.py b/tests/unit/test_session_options.py index 324cb3cd43..4e478a816d 100644 --- a/tests/unit/test_session_options.py +++ b/tests/unit/test_session_options.py @@ -12,77 +12,110 @@ # See the License for the specific language governing permissions and # limitations under the License. import os +from logging import Logger from unittest import TestCase -from google.cloud.spanner_v1.session_options import SessionOptions + +from google.cloud.spanner_v1.session_options import SessionOptions, TransactionType class TestSessionOptions(TestCase): + _logger = Logger("test_session_options_logger") + @classmethod def setUpClass(cls): + # Save the original environment variables. cls._original_env = dict(os.environ) @classmethod def tearDownClass(cls): + # Restore environment variables. 
os.environ.clear() os.environ.update(cls._original_env) def test_use_multiplexed_for_read_only(self): session_options = SessionOptions() + transaction_type = TransactionType.READ_ONLY os.environ[SessionOptions.ENV_VAR_ENABLE_MULTIPLEXED] = "false" - self.assertFalse(session_options.use_multiplexed_for_read_only()) + self.assertFalse(session_options.use_multiplexed(transaction_type)) os.environ[SessionOptions.ENV_VAR_ENABLE_MULTIPLEXED] = "true" os.environ[SessionOptions.ENV_VAR_FORCE_DISABLE_MULTIPLEXED] = "true" - self.assertFalse(session_options.use_multiplexed_for_read_only()) + self.assertFalse(session_options.use_multiplexed(transaction_type)) os.environ[SessionOptions.ENV_VAR_FORCE_DISABLE_MULTIPLEXED] = "false" - self.assertTrue(session_options.use_multiplexed_for_read_only()) + self.assertTrue(session_options.use_multiplexed(transaction_type)) - session_options.disable_multiplexed_for_read_only() - self.assertFalse(session_options.use_multiplexed_for_read_only()) + session_options.disable_multiplexed(self._logger, transaction_type) + self.assertFalse(session_options.use_multiplexed(transaction_type)) def test_use_multiplexed_for_partitioned(self): session_options = SessionOptions() + transaction_type = TransactionType.PARTITIONED os.environ[SessionOptions.ENV_VAR_ENABLE_MULTIPLEXED] = "false" - self.assertFalse(session_options.use_multiplexed_for_partitioned()) + self.assertFalse(session_options.use_multiplexed(transaction_type)) os.environ[SessionOptions.ENV_VAR_ENABLE_MULTIPLEXED] = "true" os.environ[SessionOptions.ENV_VAR_ENABLE_MULTIPLEXED_FOR_PARTITIONED] = "false" - self.assertFalse(session_options.use_multiplexed_for_partitioned()) + self.assertFalse(session_options.use_multiplexed(transaction_type)) os.environ[SessionOptions.ENV_VAR_ENABLE_MULTIPLEXED_FOR_PARTITIONED] = "true" os.environ[SessionOptions.ENV_VAR_FORCE_DISABLE_MULTIPLEXED] = "true" - self.assertFalse(session_options.use_multiplexed_for_partitioned()) + 
self.assertFalse(session_options.use_multiplexed(transaction_type)) os.environ[SessionOptions.ENV_VAR_FORCE_DISABLE_MULTIPLEXED] = "false" - self.assertTrue(session_options.use_multiplexed_for_partitioned()) + self.assertTrue(session_options.use_multiplexed(transaction_type)) - session_options.disable_multiplexed_for_partitioned() - self.assertFalse(session_options.use_multiplexed_for_partitioned()) + session_options.disable_multiplexed(self._logger, transaction_type) + self.assertFalse(session_options.use_multiplexed(transaction_type)) def test_use_multiplexed_for_read_write(self): session_options = SessionOptions() + transaction_type = TransactionType.READ_WRITE os.environ[SessionOptions.ENV_VAR_ENABLE_MULTIPLEXED] = "false" - self.assertFalse(session_options.use_multiplexed_for_read_write()) + self.assertFalse(session_options.use_multiplexed(transaction_type)) os.environ[SessionOptions.ENV_VAR_ENABLE_MULTIPLEXED] = "true" os.environ[SessionOptions.ENV_VAR_ENABLE_MULTIPLEXED_FOR_READ_WRITE] = "false" - self.assertFalse(session_options.use_multiplexed_for_read_write()) + self.assertFalse(session_options.use_multiplexed(transaction_type)) os.environ[SessionOptions.ENV_VAR_ENABLE_MULTIPLEXED_FOR_READ_WRITE] = "true" os.environ[SessionOptions.ENV_VAR_FORCE_DISABLE_MULTIPLEXED] = "true" - self.assertFalse(session_options.use_multiplexed_for_read_write()) + self.assertFalse(session_options.use_multiplexed(transaction_type)) os.environ[SessionOptions.ENV_VAR_FORCE_DISABLE_MULTIPLEXED] = "false" - self.assertTrue(session_options.use_multiplexed_for_read_write()) + self.assertTrue(session_options.use_multiplexed(transaction_type)) + + session_options.disable_multiplexed(self._logger, transaction_type) + self.assertFalse(session_options.use_multiplexed(transaction_type)) + + def test_disable_multiplexed_all(self): + session_options = SessionOptions() + + os.environ[SessionOptions.ENV_VAR_ENABLE_MULTIPLEXED] = "true" + 
os.environ[SessionOptions.ENV_VAR_ENABLE_MULTIPLEXED_FOR_PARTITIONED] = "true" + os.environ[SessionOptions.ENV_VAR_ENABLE_MULTIPLEXED_FOR_READ_WRITE] = "true" + os.environ[SessionOptions.ENV_VAR_FORCE_DISABLE_MULTIPLEXED] = "false" + + session_options.disable_multiplexed(self._logger) + + self.assertFalse(session_options.use_multiplexed(TransactionType.READ_ONLY)) + self.assertFalse(session_options.use_multiplexed(TransactionType.PARTITIONED)) + self.assertFalse(session_options.use_multiplexed(TransactionType.READ_WRITE)) + + def test_unsupported_transaction_type(self): + session_options = SessionOptions() + unsupported_type = "UNSUPPORTED_TRANSACTION_TYPE" + + with self.assertRaises(ValueError): + session_options.use_multiplexed(unsupported_type) - session_options.disable_multiplexed_for_read_write() - self.assertFalse(session_options.use_multiplexed_for_read_write()) + with self.assertRaises(ValueError): + session_options.disable_multiplexed(self._logger, unsupported_type) - def test_supported_env_var_values(self): + def test_env_var_values(self): session_options = SessionOptions() os.environ[SessionOptions.ENV_VAR_FORCE_DISABLE_MULTIPLEXED] = "false" @@ -90,12 +123,12 @@ def test_supported_env_var_values(self): true_values = ["1", " 1", " 1", "true", "True", "TRUE", " true "] for value in true_values: os.environ[SessionOptions.ENV_VAR_ENABLE_MULTIPLEXED] = value - self.assertTrue(session_options.use_multiplexed_for_read_only()) + self.assertTrue(session_options.use_multiplexed(TransactionType.READ_ONLY)) false_values = ["", "0", "false", "False", "FALSE"] for value in false_values: os.environ[SessionOptions.ENV_VAR_ENABLE_MULTIPLEXED] = value - self.assertFalse(session_options.use_multiplexed_for_read_only()) + self.assertFalse(session_options.use_multiplexed(TransactionType.READ_ONLY)) del os.environ[SessionOptions.ENV_VAR_ENABLE_MULTIPLEXED] - self.assertFalse(session_options.use_multiplexed_for_read_only()) + 
self.assertFalse(session_options.use_multiplexed(TransactionType.READ_ONLY)) diff --git a/tests/unit/test_snapshot.py b/tests/unit/test_snapshot.py index 11fc0135d1..0e900acc37 100644 --- a/tests/unit/test_snapshot.py +++ b/tests/unit/test_snapshot.py @@ -11,21 +11,27 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - - from google.api_core import gapic_v1 import mock -from google.cloud.spanner_v1 import RequestOptions, DirectedReadOptions +from google.cloud.spanner_v1 import ( + RequestOptions, + DirectedReadOptions, + SpannerClient, + KeySet, +) +from google.cloud.spanner_v1.session_options import TransactionType from tests._helpers import ( OpenTelemetryBase, LIB_VERSION, StatusCode, HAS_OPENTELEMETRY_INSTALLED, enrich_with_otel_scope, + enable_multiplexed_sessions, ) from google.cloud.spanner_v1.param_types import INT64 from google.api_core.retry import Retry +from tests._builders import build_session TABLE_NAME = "citizens" COLUMNS = ["email", "first_name", "last_name", "age"] @@ -79,7 +85,7 @@ def _makeTimestamp(): import datetime from google.cloud._helpers import UTC - return datetime.datetime.utcnow().replace(tzinfo=UTC) + return datetime.datetime.now(tz=UTC) class Test_restart_on_unavailable(OpenTelemetryBase): @@ -111,8 +117,6 @@ def _make_txn_selector(self): return _Derived(session) def _make_spanner_api(self): - from google.cloud.spanner_v1 import SpannerClient - return mock.create_autospec(SpannerClient, instance=True) def _call_fut( @@ -128,12 +132,12 @@ def _call_fut( from google.cloud.spanner_v1.snapshot import _restart_on_unavailable return _restart_on_unavailable( - restart, - request, - metadata, - span_name, - session, - attributes, + method=restart, + request=request, + session=session, + metadata=metadata, + trace_name=span_name, + attributes=attributes, transaction=derived, ) @@ -594,8 +598,6 @@ def 
_make_txn_selector(self): return _Derived(session) def _make_spanner_api(self): - from google.cloud.spanner_v1 import SpannerClient - return mock.create_autospec(SpannerClient, instance=True) def test_ctor(self): @@ -612,9 +614,25 @@ def test__make_txn_selector_virtual(self): with self.assertRaises(NotImplementedError): base._make_txn_selector() - def test_read_other_error(self): - from google.cloud.spanner_v1.keyset import KeySet + def test_read_partitioned_not_implemented_for_multiplexed(self): + enable_multiplexed_sessions() + + database = ( + self._build_database_with_partitioned_not_implemented_for_multiplexed() + ) + + session = build_session(database=database) + session.create() + derived = self._makeDerived(session) + + with self.assertRaises(NotImplementedError): + list(derived.read(TABLE_NAME, COLUMNS, KeySet(all_=True))) + + self.assertFalse( + database.session_options.use_multiplexed(TransactionType.READ_ONLY) + ) + def test_read_other_error(self): keyset = KeySet(all_=True) database = _Database() database.spanner_api = self._make_spanner_api() @@ -658,7 +676,6 @@ def _read_helper( from google.cloud.spanner_v1 import ReadRequest from google.cloud.spanner_v1 import Type, StructType from google.cloud.spanner_v1 import TypeCode - from google.cloud.spanner_v1.keyset import KeySet from google.cloud.spanner_v1._helpers import _make_value_pb VALUES = [["bharney", 31], ["phred", 32]] @@ -865,6 +882,24 @@ def test_read_w_directed_read_options_override(self): directed_read_options_at_client_level=DIRECTED_READ_OPTIONS_FOR_CLIENT, ) + def test_execute_sql_partitioned_not_implemented_for_multiplexed(self): + enable_multiplexed_sessions() + + database = ( + self._build_database_with_partitioned_not_implemented_for_multiplexed() + ) + + session = build_session(database=database) + session.create() + derived = self._makeDerived(session) + + with self.assertRaises(NotImplementedError): + list(derived.execute_sql(SQL_QUERY)) + + self.assertFalse( + 
database.session_options.use_multiplexed(TransactionType.READ_ONLY) + ) + def test_execute_sql_other_error(self): database = _Database() database.spanner_api = self._make_spanner_api() @@ -1137,7 +1172,6 @@ def _partition_read_helper( retry=gapic_v1.method.DEFAULT, timeout=gapic_v1.method.DEFAULT, ): - from google.cloud.spanner_v1.keyset import KeySet from google.cloud.spanner_v1 import Partition from google.cloud.spanner_v1 import PartitionOptions from google.cloud.spanner_v1 import PartitionReadRequest @@ -1225,9 +1259,27 @@ def test_partition_read_wo_existing_transaction_raises(self): with self.assertRaises(ValueError): self._partition_read_helper(multi_use=True, w_txn=False) - def test_partition_read_other_error(self): - from google.cloud.spanner_v1.keyset import KeySet + def test_partition_read_multiplexed_not_implemented_error(self): + enable_multiplexed_sessions() + database = ( + self._build_database_with_partitioned_not_implemented_for_multiplexed() + ) + + session = build_session(database=database) + session.create() + derived = self._makeDerived(session) + derived._multi_use = True + derived._transaction_id = TXN_ID + + with self.assertRaises(NotImplementedError): + list(derived.partition_read(TABLE_NAME, COLUMNS, KeySet(all_=True))) + + self.assertFalse( + database.session_options.use_multiplexed(TransactionType.READ_ONLY) + ) + + def test_partition_read_other_error(self): keyset = KeySet(all_=True) database = _Database() database.spanner_api = self._make_spanner_api() @@ -1249,7 +1301,6 @@ def test_partition_read_other_error(self): ) def test_partition_read_w_retry(self): - from google.cloud.spanner_v1.keyset import KeySet from google.api_core.exceptions import InternalServerError from google.cloud.spanner_v1 import Partition from google.cloud.spanner_v1 import PartitionResponse @@ -1389,6 +1440,26 @@ def _partition_query_helper( attributes=dict(BASE_ATTRIBUTES, **{"db.statement": SQL_QUERY_WITH_PARAM}), ) + def 
test_partition_query_partitioned_not_implemented_for_multiplexed(self): + enable_multiplexed_sessions() + + database = ( + self._build_database_with_partitioned_not_implemented_for_multiplexed() + ) + + session = build_session(database=database) + session.create() + derived = self._makeDerived(session) + derived._multi_use = True + derived._transaction_id = TXN_ID + + with self.assertRaises(NotImplementedError): + list(derived.partition_query(SQL_QUERY_WITH_PARAM, PARAMS, PARAM_TYPES)) + + self.assertFalse( + database.session_options.use_multiplexed(TransactionType.READ_ONLY) + ) + def test_partition_query_other_error(self): database = _Database() database.spanner_api = self._make_spanner_api() @@ -1437,6 +1508,25 @@ def test_partition_query_ok_w_timeout_and_retry_params(self): multi_use=True, w_txn=True, retry=Retry(deadline=60), timeout=2.0 ) + @staticmethod + def _build_database_with_partitioned_not_implemented_for_multiplexed(): + """Builds and returns a database for testing that raises errors for + partitioned operations not being supported for multiplexed sessions.""" + from tests._builders import build_database + + database = build_database() + + error_msg = "Partitioned operations are not supported with multiplexed sessions" + not_implemented_error = NotImplementedError(error_msg) + + api = database.spanner_api + api.streaming_read.side_effect = not_implemented_error + api.execute_streaming_sql.side_effect = not_implemented_error + api.partition_read.side_effect = not_implemented_error + api.partition_query.side_effect = not_implemented_error + + return database + class TestSnapshot(OpenTelemetryBase): PROJECT_ID = "project-id" @@ -1456,8 +1546,6 @@ def _make_one(self, *args, **kwargs): return self._getTargetClass()(*args, **kwargs) def _make_spanner_api(self): - from google.cloud.spanner_v1 import SpannerClient - return mock.create_autospec(SpannerClient, instance=True) def _makeDuration(self, seconds=1, microseconds=0):