From a8093f442ecc9718713c107e3211722e6bdea75c Mon Sep 17 00:00:00 2001 From: currantw Date: Tue, 11 Mar 2025 16:26:12 -0700 Subject: [PATCH] feat: Add support for multiplexed sessions - part 1 (supporting classes and refactoring) - Adds SessionOptions and DatabaseSessionManager class support multiplexed session in future work. - Update Client, Instance, Database, Session and associated Checkout classes accordingly. - Add unit tests for new classes and update those for existing classes. Signed-off-by: currantw --- google/cloud/spanner_dbapi/connection.py | 11 +- google/cloud/spanner_v1/client.py | 16 + google/cloud/spanner_v1/database.py | 105 +++-- .../spanner_v1/database_sessions_manager.py | 116 ++++++ google/cloud/spanner_v1/pool.py | 49 +-- google/cloud/spanner_v1/session.py | 40 +- google/cloud/spanner_v1/session_options.py | 112 ++++++ tests/system/test_database_api.py | 4 +- tests/system/test_observability_options.py | 22 +- tests/unit/spanner_dbapi/test_connect.py | 4 +- tests/unit/spanner_dbapi/test_connection.py | 79 ++-- tests/unit/test_database.py | 359 +++++++++--------- tests/unit/test_database_session_manager.py | 140 +++++++ tests/unit/test_instance.py | 7 +- tests/unit/test_pool.py | 95 ----- tests/unit/test_session.py | 26 +- tests/unit/test_session_options.py | 101 +++++ 17 files changed, 876 insertions(+), 410 deletions(-) create mode 100644 google/cloud/spanner_v1/database_sessions_manager.py create mode 100644 google/cloud/spanner_v1/session_options.py create mode 100644 tests/unit/test_database_session_manager.py create mode 100644 tests/unit/test_session_options.py diff --git a/google/cloud/spanner_dbapi/connection.py b/google/cloud/spanner_dbapi/connection.py index c2aa385d2a..b0ccc8010e 100644 --- a/google/cloud/spanner_dbapi/connection.py +++ b/google/cloud/spanner_dbapi/connection.py @@ -330,8 +330,13 @@ def _session_checkout(self): """ if self.database is None: raise ValueError("Database needs to be passed for this operation") + if not self._session: - self._session = self.database._pool.get() + self._session = ( + self.database._session_manager.get_session_for_read_only() + if self.read_only + else self.database._session_manager.get_session_for_read_write() + ) return self._session @@ -344,7 +349,7 @@ def _release_session(self): return if self.database is None: raise ValueError("Database needs to be passed for this operation") - self.database._pool.put(self._session) + self.database._session_manager.put_session(self._session) self._session = None def transaction_checkout(self): @@ -400,7 +405,7 @@ def close(self): self._transaction.rollback() if self._own_pool and self.database: - self.database._pool.clear() + self.database._session_manager._pool.clear() self.is_closed = True diff --git a/google/cloud/spanner_v1/client.py b/google/cloud/spanner_v1/client.py index e201f93e9b..f84170c1f3 100644 --- a/google/cloud/spanner_v1/client.py +++ b/google/cloud/spanner_v1/client.py @@ -70,6 +70,7 @@ except ImportError: # pragma: NO COVER HAS_GOOGLE_CLOUD_MONITORING_INSTALLED = False +from google.cloud.spanner_v1.session_options import SessionOptions _CLIENT_INFO = client_info.ClientInfo(client_library_version=__version__) EMULATOR_ENV_VAR = "SPANNER_EMULATOR_HOST" @@ -171,6 +172,9 @@ class Client(ClientWithProject): or :class:`dict` :param default_transaction_options: (Optional) Default options to use for all transactions. + :type session_options: :class:`~google.cloud.spanner_v1.SessionOptions` + :param session_options: (Optional) Options for client sessions. + :raises: :class:`ValueError ` if both ``read_only`` and ``admin`` are :data:`True` """ @@ -193,6 +197,7 @@ def __init__( directed_read_options=None, observability_options=None, default_transaction_options: Optional[DefaultTransactionOptions] = None, + session_options=None, ): self._emulator_host = _get_spanner_emulator_host() @@ -262,6 +267,8 @@ def __init__( ) self._default_transaction_options = default_transaction_options + self._session_options = session_options or SessionOptions() + @property def credentials(self): """Getter for client's credentials. @@ -525,3 +532,12 @@ def default_transaction_options( ) self._default_transaction_options = default_transaction_options + + @property + def session_options(self): + """Returns the session options for the client. + + :rtype: :class:`~google.cloud.spanner_v1.SessionOptions` + :returns: The session options for the client. + """ + return self._session_options diff --git a/google/cloud/spanner_v1/database.py b/google/cloud/spanner_v1/database.py index 03c6e5119f..a7caaa8a29 100644 --- a/google/cloud/spanner_v1/database.py +++ b/google/cloud/spanner_v1/database.py @@ -40,6 +40,7 @@ from google.cloud.spanner_admin_database_v1 import RestoreDatabaseRequest from google.cloud.spanner_admin_database_v1 import UpdateDatabaseDdlRequest from google.cloud.spanner_admin_database_v1.types import DatabaseDialect +from google.cloud.spanner_v1.database_sessions_manager import DatabaseSessionsManager from google.cloud.spanner_v1.transaction import BatchTransactionId from google.cloud.spanner_v1 import ExecuteSqlRequest from google.cloud.spanner_v1 import Type @@ -59,7 +60,6 @@ from google.cloud.spanner_v1.keyset import KeySet from google.cloud.spanner_v1.merged_result_set import MergedResultSet from google.cloud.spanner_v1.pool import BurstyPool -from google.cloud.spanner_v1.pool import SessionCheckout from google.cloud.spanner_v1.session import Session from google.cloud.spanner_v1.snapshot import _restart_on_unavailable from google.cloud.spanner_v1.snapshot import Snapshot @@ -70,7 +70,6 @@ from google.cloud.spanner_v1.table import Table from google.cloud.spanner_v1._opentelemetry_tracing import ( add_span_event, - get_current_span, trace_call, ) from google.cloud.spanner_v1.metrics.metrics_capture import MetricsCapture @@ -191,10 +190,10 @@ def __init__( if pool is None: pool = BurstyPool(database_role=database_role) - - self._pool = pool pool.bind(self) + self._session_manager = DatabaseSessionsManager(database=self, pool=pool) + @classmethod def from_pb(cls, database_pb, instance, pool=None): """Creates an instance of this class from a protobuf. @@ -708,7 +707,7 @@ def execute_pdml(): "CloudSpanner.Database.execute_partitioned_pdml", observability_options=self.observability_options, ) as span, MetricsCapture(): - with SessionCheckout(self._pool) as session: + with SessionCheckout(self) as session: add_span_event(span, "Starting BeginTransaction") txn = api.begin_transaction( session=session.name, options=txn_options, metadata=metadata @@ -923,7 +922,7 @@ def run_in_transaction(self, func, *args, **kw): # Check out a session and run the function in a transaction; once # done, flip the sanity check bit back. try: - with SessionCheckout(self._pool) as session: + with SessionCheckout(self) as session: return session.run_in_transaction(func, *args, **kw) finally: self._local.transaction_running = False @@ -1160,6 +1159,35 @@ def observability_options(self): return opts +class SessionCheckout(object): + """Context manager for using a session from a database. + + :type database: :class:`~google.cloud.spanner_v1.database.Database` + :param database: database to use the session from + """ + + _session = None # Not checked out until '__enter__'. + + def __init__(self, database): + if not isinstance(database, Database): + raise TypeError( + "{class_name} must receive an instance of {expected_class_name}. Received: {actual_class_name}".format( + class_name=self.__class__.__name__, + expected_class_name=Database.__name__, + actual_class_name=database.__class__.__name__, + ) + ) + + self._database = database + + def __enter__(self): + self._session = self._database._session_manager.get_session_for_read_write() + return self._session + + def __exit__(self, *ignored): + self._database._session_manager.put_session(self._session) + + class BatchCheckout(object): """Context manager for using a batch from a database. @@ -1194,6 +1222,15 @@ def __init__( isolation_level=TransactionOptions.IsolationLevel.ISOLATION_LEVEL_UNSPECIFIED, **kw, ): + if not isinstance(database, Database): + raise TypeError( + "{class_name} must receive an instance of {expected_class_name}. Received: {actual_class_name}".format( + class_name=self.__class__.__name__, + expected_class_name=Database.__name__, + actual_class_name=database.__class__.__name__, + ) + ) + self._database = database self._session = self._batch = None if request_options is None: @@ -1209,10 +1246,8 @@ def __init__( def __enter__(self): """Begin ``with`` block.""" - current_span = get_current_span() - session = self._session = self._database._pool.get() - add_span_event(current_span, "Using session", {"id": session.session_id}) - batch = self._batch = Batch(session) + self._session = self._database._session_manager.get_session_for_read_only() + batch = self._batch = Batch(self._session) if self._request_options.transaction_tag: batch.transaction_tag = self._request_options.transaction_tag return batch @@ -1235,13 +1270,7 @@ def __exit__(self, exc_type, exc_val, exc_tb): "CommitStats: {}".format(self._batch.commit_stats), extra={"commit_stats": self._batch.commit_stats}, ) - self._database._pool.put(self._session) - current_span = get_current_span() - add_span_event( - current_span, - "Returned session to pool", - {"id": self._session.session_id}, - ) + self._database._session_manager.put_session(self._session) class MutationGroupsCheckout(object): @@ -1258,23 +1287,26 @@ class MutationGroupsCheckout(object): """ def __init__(self, database): + if not isinstance(database, Database): + raise TypeError( + "{class_name} must receive an instance of {expected_class_name}. Received: {actual_class_name}".format( + class_name=self.__class__.__name__, + expected_class_name=Database.__name__, + actual_class_name=database.__class__.__name__, + ) + ) + self._database = database self._session = None def __enter__(self): """Begin ``with`` block.""" - session = self._session = self._database._pool.get() - return MutationGroups(session) + self._session = self._database._session_manager.get_session_for_read_write() + return MutationGroups(self._session) def __exit__(self, exc_type, exc_val, exc_tb): """End ``with`` block.""" - if isinstance(exc_val, NotFound): - # If NotFound exception occurs inside the with block - # then we validate if the session still exists. - if not self._session.exists(): - self._session = self._database._pool._new_session() - self._session.create() - self._database._pool.put(self._session) + self._database._session_manager.put_session(self._session) class SnapshotCheckout(object): @@ -1296,24 +1328,27 @@ class SnapshotCheckout(object): """ def __init__(self, database, **kw): + if not isinstance(database, Database): + raise TypeError( + "{class_name} must receive an instance of {expected_class_name}. Received: {actual_class_name}".format( + class_name=self.__class__.__name__, + expected_class_name=Database.__name__, + actual_class_name=database.__class__.__name__, + ) + ) + self._database = database self._session = None self._kw = kw def __enter__(self): """Begin ``with`` block.""" - session = self._session = self._database._pool.get() - return Snapshot(session, **self._kw) + self._session = self._database._session_manager.get_session_for_read_only() + return Snapshot(self._session, **self._kw) def __exit__(self, exc_type, exc_val, exc_tb): """End ``with`` block.""" - if isinstance(exc_val, NotFound): - # If NotFound exception occurs inside the with block - # then we validate if the session still exists. - if not self._session.exists(): - self._session = self._database._pool._new_session() - self._session.create() - self._database._pool.put(self._session) + self._database._session_manager.put_session(self._session) class BatchSnapshot(object): diff --git a/google/cloud/spanner_v1/database_sessions_manager.py b/google/cloud/spanner_v1/database_sessions_manager.py new file mode 100644 index 0000000000..722b4df769 --- /dev/null +++ b/google/cloud/spanner_v1/database_sessions_manager.py @@ -0,0 +1,116 @@ +# Copyright 2025 Google LLC All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from google.cloud.spanner_v1._opentelemetry_tracing import ( + get_current_span, + add_span_event, +) +from google.cloud.spanner_v1.session import Session + + +class DatabaseSessionsManager(object): + """Manages sessions for a Cloud Spanner database. + + Sessions can be checked out from the database session manager using :meth:`get_session_for_read_only`, + :meth:`get_session_for_partitioned`, and :meth:`get_session_for_read_write`, and returned to + the session manager using :meth:`put_session`. + + The sessions returned by the session manager depend on the client's session options (see + :class:`~google.cloud.spanner_v1.session_options.SessionOptions`) and the provided session + pool (see :class:`~google.cloud.spanner_v1.pool.AbstractSessionPool`). + + :type database: :class:`~google.cloud.spanner_v1.database.Database` + :param database: The database to manage sessions for. + + :type pool: :class:`~google.cloud.spanner_v1.pool.AbstractSessionPool` + :param pool: The pool to get non-multiplexed sessions from. + """ + + def __init__(self, database, pool): + self._database = database + self._pool = pool + + def get_session_for_read_only(self) -> Session: + """Returns a session for read-only transactions from the database session manager. + + :rtype: :class:`~google.cloud.spanner_v1.session.Session` + :returns: a session for read-only transactions. + """ + + if ( + self._database._instance._client.session_options.use_multiplexed_for_read_only() + ): + raise NotImplementedError( + "Multiplexed sessions are not yet supported for read-only transactions." + ) + + return self._get_pooled_session() + + def get_session_for_partitioned(self) -> Session: + """Returns a session for partitioned transactions from the database session manager. + + :rtype: :class:`~google.cloud.spanner_v1.session.Session` + :returns: a session for partitioned transactions. + """ + + if ( + self._database._instance._client.session_options.use_multiplexed_for_partitioned() + ): + raise NotImplementedError( + "Multiplexed sessions are not yet supported for partitioned transactions." + ) + + return self._get_pooled_session() + + def get_session_for_read_write(self) -> Session: + """Returns a session for read/write transactions from the database session manager. + + :rtype: :class:`~google.cloud.spanner_v1.session.Session` + :returns: a session for read/write transactions. + """ + + if ( + self._database._instance._client.session_options.use_multiplexed_for_read_write() + ): + raise NotImplementedError( + "Multiplexed sessions are not yet supported for read/write transactions." + ) + + return self._get_pooled_session() + + def put_session(self, session: Session) -> None: + """Returns the session to the database session manager.""" + + if session.is_multiplexed: + raise NotImplementedError("Multiplexed sessions are not yet supported.") + + self._pool.put(session) + + current_span = get_current_span() + add_span_event( + current_span, + "Returned session", + {"id": session.session_id, "multiplexed": session.is_multiplexed}, + ) + + def _get_pooled_session(self): + """Returns a non-multiplexed session from the session pool.""" + + session = self._pool.get() + add_span_event( + get_current_span(), + "Using session", + {"id": session.session_id, "multiplexed": session.is_multiplexed}, + ) + + return session diff --git a/google/cloud/spanner_v1/pool.py b/google/cloud/spanner_v1/pool.py index 0c4dd5a63b..18c586e76e 100644 --- a/google/cloud/spanner_v1/pool.py +++ b/google/cloud/spanner_v1/pool.py @@ -134,18 +134,6 @@ def _new_session(self): labels=self.labels, database_role=self.database_role ) - def session(self, **kwargs): - """Check out a session from the pool. - - :param kwargs: (optional) keyword arguments, passed through to - the returned checkout. - - :rtype: :class:`~google.cloud.spanner_v1.session.SessionCheckout` - :returns: a checkout instance, to be used as a context manager for - accessing the session and returning it to the pool. - """ - return SessionCheckout(self, **kwargs) - class FixedSizePool(AbstractSessionPool): """Concrete session pool implementation: @@ -308,13 +296,12 @@ def get(self, timeout=None): age = _NOW() - session.last_use_time if age >= self._max_age and not session.exists(): - if not session.exists(): - add_span_event( - current_span, - "Session is not valid, recreating it", - span_event_attributes, - ) - session = self._database.session() + add_span_event( + current_span, + "Session is not valid, recreating it", + span_event_attributes, + ) + session = self._new_session() session.create() # Replacing with the updated session.id. span_event_attributes["session.id"] = session._session_id @@ -776,27 +763,3 @@ def begin_pending_transactions(self): while not self._pending_sessions.empty(): session = self._pending_sessions.get() super(TransactionPingingPool, self).put(session) - - -class SessionCheckout(object): - """Context manager: hold session checked out from a pool. - - :type pool: concrete subclass of - :class:`~google.cloud.spanner_v1.pool.AbstractSessionPool` - :param pool: Pool from which to check out a session. - - :param kwargs: extra keyword arguments to be passed to :meth:`pool.get`. - """ - - _session = None # Not checked out until '__enter__'. - - def __init__(self, pool, **kwargs): - self._pool = pool - self._kwargs = kwargs.copy() - - def __enter__(self): - self._session = self._pool.get(**self._kwargs) - return self._session - - def __exit__(self, *ignored): - self._pool.put(self._session) diff --git a/google/cloud/spanner_v1/session.py b/google/cloud/spanner_v1/session.py index f18ba57582..e3e9aa6f66 100644 --- a/google/cloud/spanner_v1/session.py +++ b/google/cloud/spanner_v1/session.py @@ -41,6 +41,7 @@ from google.cloud.spanner_v1.transaction import Transaction from google.cloud.spanner_v1.metrics.metrics_capture import MetricsCapture +from google.cloud.spanner_v1.types import Session as SessionPB DEFAULT_RETRY_TIMEOUT_SECS = 30 """Default timeout used by :meth:`Session.run_in_transaction`.""" @@ -69,12 +70,13 @@ class Session(object): _session_id = None _transaction = None - def __init__(self, database, labels=None, database_role=None): + def __init__(self, database, labels=None, database_role=None, is_multiplexed=False): self._database = database if labels is None: labels = {} self._labels = labels self._database_role = database_role + self._is_multiplexed = is_multiplexed self._last_use_time = datetime.utcnow() def __lt__(self, other): @@ -87,7 +89,7 @@ def session_id(self): @property def last_use_time(self): - """ "Approximate last use time of this session + """Approximate last use time of this session :rtype: datetime :returns: the approximate last use time of this session""" @@ -110,6 +112,15 @@ def labels(self): """ return self._labels + @property + def is_multiplexed(self): + """Whether this session is multiplexed. + + :rtype: bool + :returns: True if this session is multiplexed, False otherwise. + """ + return self._is_multiplexed + @property def name(self): """Session name used in requests. @@ -153,12 +164,14 @@ def create(self): ) ) - request = CreateSessionRequest(database=self._database.name) - if self._database.database_role is not None: - request.session.creator_role = self._database.database_role + session_pb = SessionPB(multiplexed=self.is_multiplexed) + if self._database.database_role: + session_pb.creator_role = self._database.database_role if self._labels: - request.session.labels = self._labels + session_pb.labels = self._labels + + request = CreateSessionRequest(database=self._database.name, session=session_pb) observability_options = getattr(self._database, "observability_options", None) with trace_call( @@ -287,6 +300,11 @@ def snapshot(self, **kw): if self._session_id is None: raise ValueError("Session has not been created.") + if self.is_multiplexed: + raise NotImplementedError( + "Multiplexed sessions do not yet support read-only transactions." + ) + return Snapshot(self, **kw) def read(self, table, columns, keyset, index="", limit=0, column_info=None): @@ -408,6 +426,11 @@ def batch(self): if self._session_id is None: raise ValueError("Session has not been created.") + if self.is_multiplexed: + raise NotImplementedError( + "Multiplexed sessions do not yet support read/write transactions." + ) + return Batch(self) def transaction(self): @@ -420,6 +443,11 @@ def transaction(self): if self._session_id is None: raise ValueError("Session has not been created.") + if self.is_multiplexed: + raise NotImplementedError( + "Multiplexed sessions do not yet support read/write transactions." + ) + if self._transaction is not None: self._transaction.rolled_back = True del self._transaction diff --git a/google/cloud/spanner_v1/session_options.py b/google/cloud/spanner_v1/session_options.py new file mode 100644 index 0000000000..6e35712a0b --- /dev/null +++ b/google/cloud/spanner_v1/session_options.py @@ -0,0 +1,112 @@ +# Copyright 2025 Google LLC All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import datetime +import os + + +class SessionOptions(object): + """Represents the session options for the Cloud Spanner Python client. + + We can use ::class::`SessionOptions` to determine whether multiplexed sessions should be used for: + * read-only transactions (:meth:`use_multiplexed_for_read_only`) + * partitioned transactions (:meth:`use_multiplexed_for_partitioned`) + * read/write transactions (:meth:`use_multiplexed_for_read_write`). + """ + + MULTIPLEXED_SESSIONS_REFRESH_INTERVAL = datetime.timedelta(days=7) + + # Environment variables for multiplexed sessions + ENV_VAR_ENABLE_MULTIPLEXED = "GOOGLE_CLOUD_SPANNER_MULTIPLEXED_SESSIONS" + ENV_VAR_ENABLE_MULTIPLEXED_FOR_PARTITIONED = ( + "GOOGLE_CLOUD_SPANNER_MULTIPLEXED_SESSIONS_PARTITIONED_OPS" + ) + ENV_VAR_ENABLE_MULTIPLEXED_FOR_READ_WRITE = ( + "GOOGLE_CLOUD_SPANNER_MULTIPLEXED_SESSIONS_FOR_RW" + ) + ENV_VAR_FORCE_DISABLE_MULTIPLEXED = ( + "GOOGLE_CLOUD_SPANNER_FORCE_DISABLE_MULTIPLEXED_SESSIONS" + ) + + def __init__(self): + # Internal overrides to disable the use of multiplexed + # sessions in case of runtime errors. + self._is_multiplexed_enabled_for_read_only = True + self._is_multiplexed_enabled_for_partitioned = True + self._is_multiplexed_enabled_for_read_write = True + + def use_multiplexed_for_read_only(self) -> bool: + """Returns whether to use multiplexed sessions for read-only transactions. + Multiplexed sessions are enabled for read-only transactions if: + * ENV_VAR_ENABLE_MULTIPLEXED is set to true; + * ENV_VAR_FORCE_DISABLE_MULTIPLEXED is not set to true; and + * multiplexed sessions have not been disabled for read-only transactions (see 'disable_multiplexed_for_read_only'). + """ + + return ( + self._is_multiplexed_enabled_for_read_only + and self._getenv(self.ENV_VAR_ENABLE_MULTIPLEXED) + and not self._getenv(self.ENV_VAR_FORCE_DISABLE_MULTIPLEXED) + ) + + def disable_multiplexed_for_read_only(self) -> None: + """Disables the use of multiplexed sessions for read-only transactions.""" + self._is_multiplexed_enabled_for_read_only = False + + def use_multiplexed_for_partitioned(self) -> bool: + """Returns whether to use multiplexed sessions for partitioned transactions. + Multiplexed sessions are enabled for partitioned transactions if: + * ENV_VAR_ENABLE_MULTIPLEXED is set to true; + * ENV_VAR_ENABLE_MULTIPLEXED_FOR_PARTITIONED is set to true; + * ENV_VAR_FORCE_DISABLE_MULTIPLEXED is not set to true; and + * multiplexed sessions have not been disabled for partitioned transactions (see 'disable_multiplexed_for_partitioned'). + """ + + return ( + self._is_multiplexed_enabled_for_partitioned + and self._getenv(self.ENV_VAR_ENABLE_MULTIPLEXED) + and self._getenv(self.ENV_VAR_ENABLE_MULTIPLEXED_FOR_PARTITIONED) + and not self._getenv(self.ENV_VAR_FORCE_DISABLE_MULTIPLEXED) + ) + + def disable_multiplexed_for_partitioned(self) -> None: + """Disables the use of multiplexed sessions for read-only transactions.""" + self._is_multiplexed_enabled_for_partitioned = False + + def use_multiplexed_for_read_write(self) -> bool: + """Returns whether to use multiplexed sessions for read/write transactions. + Multiplexed sessions are enabled for read/write transactions if: + * ENV_VAR_ENABLE_MULTIPLEXED is set to true; + * ENV_VAR_ENABLE_MULTIPLEXED_FOR_PARTITIONED is set to true; + * ENV_VAR_FORCE_DISABLE_MULTIPLEXED is not set to true; and + * multiplexed sessions have not been disabled for read/write transactions (see 'disable_multiplexed_for_read_write'). + """ + + return ( + self._is_multiplexed_enabled_for_read_write + and self._getenv(self.ENV_VAR_ENABLE_MULTIPLEXED) + and self._getenv(self.ENV_VAR_ENABLE_MULTIPLEXED_FOR_READ_WRITE) + and not self._getenv(self.ENV_VAR_FORCE_DISABLE_MULTIPLEXED) + ) + + def disable_multiplexed_for_read_write(self) -> None: + """Disables the use of multiplexed sessions for read/write transactions.""" + self._is_multiplexed_enabled_for_read_write = False + + @staticmethod + def _getenv(name: str) -> bool: + """Returns the value of the given environment variable as a boolean. + True values are '1' and 'true' (case-insensitive); all other values are considered false. + """ + env_var = os.getenv(name, "").lower() + return env_var in ["1", "true"] diff --git a/tests/system/test_database_api.py b/tests/system/test_database_api.py index 57ce49c8a2..4e44fb0b0c 100644 --- a/tests/system/test_database_api.py +++ b/tests/system/test_database_api.py @@ -121,7 +121,7 @@ def test_database_binding_of_fixed_size_pool( database_role="parent", ) database = shared_instance.database(temp_db_id, pool=pool) - assert database._pool.database_role == "parent" + assert database._session_manager._pool.database_role == "parent" def test_database_binding_of_pinging_pool( @@ -155,7 +155,7 @@ def test_database_binding_of_pinging_pool( database_role="parent", ) database = shared_instance.database(temp_db_id, pool=pool) - assert database._pool.database_role == "parent" + assert database._session_manager._pool.database_role == "parent" def test_create_database_pitr_invalid_retention_period( diff --git a/tests/system/test_observability_options.py b/tests/system/test_observability_options.py index d40b34f800..3f23cc8edc 100644 --- a/tests/system/test_observability_options.py +++ b/tests/system/test_observability_options.py @@ -11,9 +11,11 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from unittest.mock import patch, PropertyMock import pytest +from google.cloud.spanner_v1.session import Session from . import _helpers from google.cloud.spanner_v1 import Client from google.api_core.exceptions import Aborted @@ -195,9 +197,15 @@ def create_db_trace_exporter(): not HAS_OTEL_INSTALLED, reason="Tracing requires OpenTelemetry", ) -def test_transaction_abort_then_retry_spans(): +@patch.object(Session, "session_id", new_callable=PropertyMock) +@patch.object(Session, "is_multiplexed", new_callable=PropertyMock) +def test_transaction_abort_then_retry_spans(mock_session_multiplexed, mock_session_id): from opentelemetry.trace.status import StatusCode + # Mock session properties for testing. + mock_session_multiplexed.return_value = session_multiplexed = False + mock_session_id.return_value = session_id = "session-id" + db, trace_exporter = create_db_trace_exporter() counters = dict(aborted=0) @@ -224,6 +232,8 @@ def select_in_txn(txn): ("Waiting for a session to become available", {"kind": "BurstyPool"}), ("No sessions available in pool. Creating session", {"kind": "BurstyPool"}), ("Creating Session", {}), + ("Using session", {"id": session_id, "multiplexed": session_multiplexed}), + ("Returned session", {"id": session_id, "multiplexed": session_multiplexed}), ( "Transaction was aborted in user operation, retrying", {"delay_seconds": "EPHEMERAL", "cause": "EPHEMERAL", "attempt": 1}, @@ -391,9 +401,15 @@ def tx_update(txn): not HAS_OTEL_INSTALLED, reason="Tracing requires OpenTelemetry", ) -def test_database_partitioned_error(): +@patch.object(Session, "session_id", new_callable=PropertyMock) +@patch.object(Session, "is_multiplexed", new_callable=PropertyMock) +def test_database_partitioned_error(mock_session_multiplexed, mock_session_id): from opentelemetry.trace.status import StatusCode + # Mock session properties for testing. + mock_session_multiplexed.return_value = session_multiplexed = False + mock_session_id.return_value = session_id = "session-id" + db, trace_exporter = create_db_trace_exporter() try: @@ -408,7 +424,9 @@ def test_database_partitioned_error(): ("Waiting for a session to become available", {"kind": "BurstyPool"}), ("No sessions available in pool. Creating session", {"kind": "BurstyPool"}), ("Creating Session", {}), + ("Using session", {"id": session_id, "multiplexed": session_multiplexed}), ("Starting BeginTransaction", {}), + ("Returned session", {"id": session_id, "multiplexed": session_multiplexed}), ( "exception", { diff --git a/tests/unit/spanner_dbapi/test_connect.py b/tests/unit/spanner_dbapi/test_connect.py index 30ab3c7a8d..4021ee6083 100644 --- a/tests/unit/spanner_dbapi/test_connect.py +++ b/tests/unit/spanner_dbapi/test_connect.py @@ -60,8 +60,8 @@ def test_w_implicit(self, mock_client): self.assertIs(connection.database, database) instance.database.assert_called_once_with(DATABASE, pool=None) - # Datbase constructs its own pool - self.assertIsNotNone(connection.database._pool) + # Database constructs its own pool + self.assertIsNotNone(connection.database._session_manager._pool) self.assertTrue(connection.instance._client.route_to_leader_enabled) def test_w_explicit(self, mock_client): diff --git a/tests/unit/spanner_dbapi/test_connection.py b/tests/unit/spanner_dbapi/test_connection.py index 4bee9e93c7..fc41a156db 100644 --- a/tests/unit/spanner_dbapi/test_connection.py +++ b/tests/unit/spanner_dbapi/test_connection.py @@ -37,6 +37,11 @@ AutocommitDmlMode, ) +from google.cloud.spanner_v1.client import Client +from google.cloud.spanner_v1.instance import Instance +from google.cloud.spanner_v1.database import Database +from google.cloud.spanner_v1.session import Session + PROJECT = "test-project" INSTANCE = "test-instance" DATABASE = "test-database" @@ -64,9 +69,6 @@ def _get_client_info(self): def _make_connection( self, database_dialect=DatabaseDialect.DATABASE_DIALECT_UNSPECIFIED, **kwargs ): - from google.cloud.spanner_v1.instance import Instance - from google.cloud.spanner_v1.client import Client - # We don't need a real Client object to test the constructor client = Client() instance = Instance(INSTANCE, client=client) @@ -97,22 +99,16 @@ def test_autocommit_setter_transaction_started(self, mock_commit): self.assertTrue(connection._autocommit) def test_property_database(self): - from google.cloud.spanner_v1.database import Database - connection = self._make_connection() self.assertIsInstance(connection.database, Database) self.assertEqual(connection.database, connection._database) def test_property_instance(self): - from google.cloud.spanner_v1.instance import Instance - connection = self._make_connection() self.assertIsInstance(connection.instance, Instance) self.assertEqual(connection.instance, connection._instance) def test_property_current_schema_google_sql_dialect(self): - from google.cloud.spanner_v1.database import Database - connection = self._make_connection( database_dialect=DatabaseDialect.GOOGLE_STANDARD_SQL ) @@ -120,8 +116,6 @@ def test_property_current_schema_google_sql_dialect(self): self.assertEqual(connection.current_schema, "") def test_property_current_schema_postgres_sql_dialect(self): - from google.cloud.spanner_v1.database import Database - connection = self._make_connection(database_dialect=DatabaseDialect.POSTGRESQL) self.assertIsInstance(connection.database, Database) self.assertEqual(connection.current_schema, "public") @@ -146,25 +140,39 @@ def test_read_only_connection(self): connection.read_only = False self.assertFalse(connection.read_only) - @staticmethod - def _make_pool(): - from google.cloud.spanner_v1.pool import AbstractSessionPool + def test__session_checkout_read_only(self): + client = Client() + instance = Instance(instance_id="instance-id", client=client) + database = Database(database_id="database-id", instance=instance) + session_manager = database._session_manager - return mock.create_autospec(AbstractSessionPool) + session = Session(database=database) + session_manager.get_session_for_read_only = mock.Mock(return_value=session) - @mock.patch("google.cloud.spanner_v1.database.Database") - def test__session_checkout(self, mock_database): - pool = self._make_pool() - mock_database._pool = pool - connection = Connection(INSTANCE, mock_database) + read_only_connection = Connection( + instance="instance-id", database=database, read_only=True + ) + read_only_connection._session_checkout() - connection._session_checkout() - pool.get.assert_called_once_with() - self.assertEqual(connection._session, pool.get.return_value) + session_manager.get_session_for_read_only.assert_called_once_with() + self.assertEqual(read_only_connection._session, session) - connection._session = "db_session" - connection._session_checkout() - self.assertEqual(connection._session, "db_session") + def test__session_checkout_read_write(self): + client = Client() + instance = Instance(instance_id="instance-id", client=client) + database = Database(database_id="database-id", instance=instance) + session_manager = database._session_manager + + session = Session(database=database) + session_manager.get_session_for_read_write = mock.Mock(return_value=session) + + read_write_connection = Connection( + instance="instance-id", database=database, read_only=False + ) + read_write_connection._session_checkout() + + session_manager.get_session_for_read_write.assert_called_once_with() + self.assertEqual(read_write_connection._session, session) def test_session_checkout_database_error(self): connection = Connection(INSTANCE) @@ -172,15 +180,20 @@ def test_session_checkout_database_error(self): with pytest.raises(ValueError): connection._session_checkout() - @mock.patch("google.cloud.spanner_v1.database.Database") - def test__release_session(self, mock_database): - pool = self._make_pool() - mock_database._pool = pool - connection = Connection(INSTANCE, mock_database) - connection._session = "session" + def test__release_session(self): + client = Client() + instance = Instance(instance_id="instance-id", client=client) + database = Database(database_id="database-id", instance=instance) + connection = Connection(instance="instance-id", database=database) + + # Mock connection session and session manager. + session = Session(database=database) + connection._session = session + session_manager = database._session_manager + session_manager.put_session = mock.Mock() connection._release_session() - pool.put.assert_called_once_with("session") + session_manager.put_session.assert_called_once_with(session) self.assertIsNone(connection._session) def test_release_session_database_error(self): diff --git a/tests/unit/test_database.py b/tests/unit/test_database.py index 1afda7f850..7bea6fc24a 100644 --- a/tests/unit/test_database.py +++ b/tests/unit/test_database.py @@ -14,6 +14,7 @@ import unittest +from logging import Logger import mock from google.api_core import gapic_v1 @@ -21,14 +22,15 @@ Database as DatabasePB, DatabaseDialect, ) +from google.cloud.spanner_v1.database import Database from google.cloud.spanner_v1.param_types import INT64 from google.api_core.retry import Retry from google.protobuf.field_mask_pb2 import FieldMask from google.cloud.spanner_v1 import ( - RequestOptions, - DirectedReadOptions, DefaultTransactionOptions, + DirectedReadOptions, + RequestOptions, ) DML_WO_PARAM = """ @@ -101,8 +103,6 @@ def _make_duration(seconds=1, microseconds=0): class TestDatabase(_BaseTest): def _get_target_class(self): - from google.cloud.spanner_v1.database import Database - return Database @staticmethod @@ -127,11 +127,9 @@ def test_ctor_defaults(self): self.assertEqual(database.database_id, self.DATABASE_ID) self.assertIs(database._instance, instance) self.assertEqual(list(database.ddl_statements), []) - self.assertIsInstance(database._pool, BurstyPool) + self.assertIsInstance(database._session_manager._pool, BurstyPool) self.assertFalse(database.log_commit_stats) self.assertIsNone(database._logger) - # BurstyPool does not create sessions during 'bind()'. - self.assertTrue(database._pool._sessions.empty()) self.assertIsNone(database.database_role) self.assertTrue(database._route_to_leader_enabled, True) @@ -142,7 +140,7 @@ def test_ctor_w_explicit_pool(self): self.assertEqual(database.database_id, self.DATABASE_ID) self.assertIs(database._instance, instance) self.assertEqual(list(database.ddl_statements), []) - self.assertIs(database._pool, pool) + self.assertIs(database._session_manager._pool, pool) self.assertIs(pool._bound, database) def test_ctor_w_database_role(self): @@ -191,8 +189,6 @@ def test_ctor_w_ddl_statements_ok(self): self.assertEqual(list(database.ddl_statements), DDL_STATEMENTS) def test_ctor_w_explicit_logger(self): - from logging import Logger - instance = _Instance(self.INSTANCE_NAME) logger = mock.create_autospec(Logger, instance=True) database = self._make_one(self.DATABASE_ID, instance, logger=logger) @@ -279,7 +275,7 @@ def test_from_pb_success_w_explicit_pool(self): self.assertIsInstance(database, klass) self.assertEqual(database._instance, instance) self.assertEqual(database.database_id, self.DATABASE_ID) - self.assertIs(database._pool, pool) + self.assertIs(database._session_manager._pool, pool) def test_from_pb_success_w_hyphen_w_default_pool(self): from google.cloud.spanner_admin_database_v1 import Database @@ -297,9 +293,9 @@ def test_from_pb_success_w_hyphen_w_default_pool(self): self.assertIsInstance(database, klass) self.assertEqual(database._instance, instance) self.assertEqual(database.database_id, DATABASE_ID_HYPHEN) - self.assertIsInstance(database._pool, BurstyPool) + self.assertIsInstance(database._session_manager._pool, BurstyPool) # BurstyPool does not create sessions during 'bind()'. - self.assertTrue(database._pool._sessions.empty()) + self.assertTrue(database._session_manager._pool._sessions.empty()) def test_name_property(self): instance = _Instance(self.INSTANCE_NAME) @@ -1801,7 +1797,9 @@ def _make_spanner_client(): return mock.create_autospec(SpannerClient) def test_ctor(self): - database = _Database(self.DATABASE_NAME) + database = Database( + database_id=self.DATABASE_ID, instance=_Instance(self.INSTANCE_NAME) + ) checkout = self._make_one(database) self.assertIs(checkout._database, database) @@ -1817,22 +1815,27 @@ def test_context_mgr_success(self): now = datetime.datetime.utcnow().replace(tzinfo=UTC) now_pb = _datetime_to_pb_timestamp(now) response = CommitResponse(commit_timestamp=now_pb) - database = _Database(self.DATABASE_NAME) - api = database.spanner_api = self._make_spanner_client() + database = Database( + database_id=self.DATABASE_ID, instance=_Instance(self.INSTANCE_NAME) + ) + api = database._spanner_api = self._make_spanner_client() api.commit.return_value = response - pool = database._pool = _Pool() + session = _Session(database) - pool.put(session) + session_manager = database._session_manager + session_manager.get_session_for_read_only = mock.Mock(return_value=session) + session_manager.put_session = mock.Mock(return_value=None) + checkout = self._make_one( database, request_options={"transaction_tag": self.TRANSACTION_TAG} ) with checkout as batch: - self.assertIsNone(pool._session) + session_manager.get_session_for_read_only.assert_called_once() self.assertIsInstance(batch, Batch) self.assertIs(batch._session, session) - self.assertIs(pool._session, session) + session_manager.put_session.assert_called_once_with(session) self.assertEqual(batch.committed, now) self.assertEqual(batch.transaction_tag, self.TRANSACTION_TAG) @@ -1865,21 +1868,29 @@ def test_context_mgr_w_commit_stats_success(self): now_pb = _datetime_to_pb_timestamp(now) commit_stats = CommitResponse.CommitStats(mutation_count=4) response = CommitResponse(commit_timestamp=now_pb, commit_stats=commit_stats) - database = _Database(self.DATABASE_NAME) + logger = mock.create_autospec(Logger, instance=True) + database = Database( + database_id=self.DATABASE_ID, + instance=_Instance(self.INSTANCE_NAME), + logger=logger, + ) database.log_commit_stats = True - api = database.spanner_api = self._make_spanner_client() + api = database._spanner_api = self._make_spanner_client() api.commit.return_value = response - pool = database._pool = _Pool() + session = _Session(database) - pool.put(session) + session_manager = database._session_manager + session_manager.get_session_for_read_only = mock.Mock(return_value=session) + session_manager.put_session = mock.Mock(return_value=None) + checkout = self._make_one(database) with checkout as batch: - self.assertIsNone(pool._session) + session_manager.get_session_for_read_only.assert_called_once() self.assertIsInstance(batch, Batch) self.assertIs(batch._session, session) - self.assertIs(pool._session, session) + session_manager.put_session.assert_called_once_with(session) self.assertEqual(batch.committed, now) expected_txn_options = TransactionOptions(read_write={}) @@ -1899,32 +1910,45 @@ def test_context_mgr_w_commit_stats_success(self): ], ) - database.logger.info.assert_called_once_with( + database._logger.info.assert_called_once_with( "CommitStats: mutation_count: 4\n", extra={"commit_stats": commit_stats} ) + def test_type_error(self): + with self.assertRaises(TypeError): + with self._make_one(None) as _: + pass + def test_context_mgr_w_aborted_commit_status(self): from google.api_core.exceptions import Aborted from google.cloud.spanner_v1 import CommitRequest from google.cloud.spanner_v1 import TransactionOptions from google.cloud.spanner_v1.batch import Batch - database = _Database(self.DATABASE_NAME) + logger = mock.create_autospec(Logger, instance=True) + database = Database( + database_id=self.DATABASE_ID, + instance=_Instance(self.INSTANCE_NAME), + logger=logger, + ) database.log_commit_stats = True - api = database.spanner_api = self._make_spanner_client() + api = database._spanner_api = self._make_spanner_client() api.commit.side_effect = Aborted("aborted exception", errors=("Aborted error")) - pool = database._pool = _Pool() + session = _Session(database) - pool.put(session) + session_manager = database._session_manager + session_manager.get_session_for_read_only = mock.Mock(return_value=session) + session_manager.put_session = mock.Mock(return_value=None) + checkout = self._make_one(database) with self.assertRaises(Aborted): with checkout as batch: - self.assertIsNone(pool._session) + session_manager.get_session_for_read_only.assert_called_once() self.assertIsInstance(batch, Batch) self.assertIs(batch._session, session) - self.assertIs(pool._session, session) + session_manager.put_session.assert_called_once_with(session) expected_txn_options = TransactionOptions(read_write={}) @@ -1951,10 +1975,15 @@ def test_context_mgr_w_aborted_commit_status(self): def test_context_mgr_failure(self): from google.cloud.spanner_v1.batch import Batch - database = _Database(self.DATABASE_NAME) - pool = database._pool = _Pool() + database = Database( + database_id=self.DATABASE_ID, instance=_Instance(self.INSTANCE_NAME) + ) + session = _Session(database) - pool.put(session) + session_manager = database._session_manager + session_manager.get_session_for_read_only = mock.Mock(return_value=session) + session_manager.put_session = mock.Mock(return_value=None) + checkout = self._make_one(database) class Testing(Exception): @@ -1962,12 +1991,12 @@ class Testing(Exception): with self.assertRaises(Testing): with checkout as batch: - self.assertIsNone(pool._session) + session_manager.get_session_for_read_only.assert_called_once() self.assertIsInstance(batch, Batch) self.assertIs(batch._session, session) raise Testing() - self.assertIs(pool._session, session) + session_manager.put_session.assert_called_once_with(session) self.assertIsNone(batch.committed) @@ -1980,23 +2009,27 @@ def _get_target_class(self): def test_ctor_defaults(self): from google.cloud.spanner_v1.snapshot import Snapshot - database = _Database(self.DATABASE_NAME) + database = Database( + database_id=self.DATABASE_ID, instance=_Instance(self.INSTANCE_NAME) + ) + session = _Session(database) - pool = database._pool = _Pool() - pool.put(session) + session_manager = database._session_manager + session_manager.get_session_for_read_only = mock.Mock(return_value=session) + session_manager.put_session = mock.Mock(return_value=None) checkout = self._make_one(database) self.assertIs(checkout._database, database) self.assertEqual(checkout._kw, {}) with checkout as snapshot: - self.assertIsNone(pool._session) + session_manager.get_session_for_read_only.assert_called_once() self.assertIsInstance(snapshot, Snapshot) self.assertIs(snapshot._session, session) self.assertTrue(snapshot._strong) self.assertFalse(snapshot._multi_use) - self.assertIs(pool._session, session) + session_manager.put_session.assert_called_once_with(session) def test_ctor_w_read_timestamp_and_multi_use(self): import datetime @@ -2004,31 +2037,40 @@ def test_ctor_w_read_timestamp_and_multi_use(self): from google.cloud.spanner_v1.snapshot import Snapshot now = datetime.datetime.utcnow().replace(tzinfo=UTC) - database = _Database(self.DATABASE_NAME) + database = Database( + database_id=self.DATABASE_ID, instance=_Instance(self.INSTANCE_NAME) + ) + session = _Session(database) - pool = database._pool = _Pool() - pool.put(session) + session_manager = database._session_manager + session_manager.get_session_for_read_only = mock.Mock(return_value=session) + session_manager.put_session = mock.Mock(return_value=None) checkout = self._make_one(database, read_timestamp=now, multi_use=True) self.assertIs(checkout._database, database) self.assertEqual(checkout._kw, {"read_timestamp": now, "multi_use": True}) with checkout as snapshot: - self.assertIsNone(pool._session) + session_manager.get_session_for_read_only.assert_called_once() self.assertIsInstance(snapshot, Snapshot) self.assertIs(snapshot._session, session) self.assertEqual(snapshot._read_timestamp, now) self.assertTrue(snapshot._multi_use) - self.assertIs(pool._session, session) + session_manager.put_session.assert_called_once_with(session) def test_context_mgr_failure(self): from google.cloud.spanner_v1.snapshot import Snapshot - database = _Database(self.DATABASE_NAME) - pool = database._pool = _Pool() + database = Database( + database_id=self.DATABASE_ID, instance=_Instance(self.INSTANCE_NAME) + ) + session = _Session(database) - pool.put(session) + session_manager = database._session_manager + session_manager.get_session_for_read_only = mock.Mock(return_value=session) + session_manager.put_session = mock.Mock(return_value=None) + checkout = self._make_one(database) class Testing(Exception): @@ -2036,72 +2078,17 @@ class Testing(Exception): with self.assertRaises(Testing): with checkout as snapshot: - self.assertIsNone(pool._session) + session_manager.get_session_for_read_only.assert_called_once() self.assertIsInstance(snapshot, Snapshot) self.assertIs(snapshot._session, session) raise Testing() - self.assertIs(pool._session, session) - - def test_context_mgr_session_not_found_error(self): - from google.cloud.exceptions import NotFound - - database = _Database(self.DATABASE_NAME) - session = _Session(database, name="session-1") - session.exists = mock.MagicMock(return_value=False) - pool = database._pool = _Pool() - new_session = _Session(database, name="session-2") - new_session.create = mock.MagicMock(return_value=[]) - pool._new_session = mock.MagicMock(return_value=new_session) - - pool.put(session) - checkout = self._make_one(database) - - self.assertEqual(pool._session, session) - with self.assertRaises(NotFound): - with checkout as _: - raise NotFound("Session not found") - # Assert that session-1 was removed from pool and new session was added. - self.assertEqual(pool._session, new_session) - - def test_context_mgr_table_not_found_error(self): - from google.cloud.exceptions import NotFound - - database = _Database(self.DATABASE_NAME) - session = _Session(database, name="session-1") - session.exists = mock.MagicMock(return_value=True) - pool = database._pool = _Pool() - pool._new_session = mock.MagicMock(return_value=[]) + session_manager.put_session.assert_called_once_with(session) - pool.put(session) - checkout = self._make_one(database) - - self.assertEqual(pool._session, session) - with self.assertRaises(NotFound): - with checkout as _: - raise NotFound("Table not found") - # Assert that session-1 was not removed from pool. - self.assertEqual(pool._session, session) - pool._new_session.assert_not_called() - - def test_context_mgr_unknown_error(self): - database = _Database(self.DATABASE_NAME) - session = _Session(database) - pool = database._pool = _Pool() - pool._new_session = mock.MagicMock(return_value=[]) - pool.put(session) - checkout = self._make_one(database) - - class Testing(Exception): - pass - - self.assertEqual(pool._session, session) - with self.assertRaises(Testing): - with checkout as _: - raise Testing("Unknown error.") - # Assert that session-1 was not removed from pool. - self.assertEqual(pool._session, session) - pool._new_session.assert_not_called() + def test_type_error(self): + with self.assertRaises(TypeError): + with self._make_one(None) as _: + pass class TestBatchSnapshot(_BaseTest): @@ -2117,8 +2104,6 @@ def _get_target_class(self): @staticmethod def _make_database(**kwargs): - from google.cloud.spanner_v1.database import Database - return mock.create_autospec(Database, instance=True, **kwargs) @staticmethod @@ -2944,19 +2929,24 @@ def _make_spanner_client(): def test_ctor(self): from google.cloud.spanner_v1.batch import MutationGroups - database = _Database(self.DATABASE_NAME) - pool = database._pool = _Pool() + database = Database( + database_id=self.DATABASE_ID, instance=_Instance(self.INSTANCE_NAME) + ) + session = _Session(database) - pool.put(session) + session_manager = database._session_manager + session_manager.get_session_for_read_write = mock.Mock(return_value=session) + session_manager.put_session = mock.Mock(return_value=None) + checkout = self._make_one(database) self.assertIs(checkout._database, database) with checkout as groups: - self.assertIsNone(pool._session) + session_manager.get_session_for_read_write.assert_called_once() self.assertIsInstance(groups, MutationGroups) self.assertIs(groups._session, session) - self.assertIs(pool._session, session) + session_manager.put_session.assert_called_once_with(session) def test_context_mgr_success(self): import datetime @@ -2975,12 +2965,17 @@ def test_context_mgr_success(self): response = BatchWriteResponse( commit_timestamp=now_pb, indexes=[0], status=status_pb ) - database = _Database(self.DATABASE_NAME) - api = database.spanner_api = self._make_spanner_client() + database = Database( + database_id=self.DATABASE_ID, instance=_Instance(self.INSTANCE_NAME) + ) + api = database._spanner_api = self._make_spanner_client() api.batch_write.return_value = [response] - pool = database._pool = _Pool() + session = _Session(database) - pool.put(session) + session_manager = database._session_manager + session_manager.get_session_for_read_write = mock.Mock(return_value=session) + session_manager.put_session = mock.Mock(return_value=None) + checkout = self._make_one(database) request_options = RequestOptions(transaction_tag=self.TRANSACTION_TAG) @@ -3002,7 +2997,7 @@ def test_context_mgr_success(self): request_options=request_options, ) with checkout as groups: - self.assertIsNone(pool._session) + session_manager.get_session_for_read_write.assert_called_once() self.assertIsInstance(groups, MutationGroups) self.assertIs(groups._session, session) group = groups.group() @@ -3010,7 +3005,7 @@ def test_context_mgr_success(self): groups.batch_write(request_options) self.assertEqual(groups.committed, True) - self.assertIs(pool._session, session) + session_manager.put_session.assert_called_once_with(session) api.batch_write.assert_called_once_with( request=request, @@ -3020,13 +3015,23 @@ def test_context_mgr_success(self): ], ) + def test_type_error(self): + with self.assertRaises(TypeError): + with self._make_one(None) as _: + pass + def test_context_mgr_failure(self): from google.cloud.spanner_v1.batch import MutationGroups - database = _Database(self.DATABASE_NAME) - pool = database._pool = _Pool() + database = Database( + database_id=self.DATABASE_ID, instance=_Instance(self.INSTANCE_NAME) + ) + session = _Session(database) - pool.put(session) + session_manager = database._session_manager + session_manager.get_session_for_read_write = mock.Mock(return_value=session) + session_manager.put_session = mock.Mock(return_value=None) + checkout = self._make_one(database) class Testing(Exception): @@ -3034,72 +3039,74 @@ class Testing(Exception): with self.assertRaises(Testing): with checkout as groups: - self.assertIsNone(pool._session) self.assertIsInstance(groups, MutationGroups) + session_manager.get_session_for_read_write.assert_called_once() self.assertIs(groups._session, session) raise Testing() - self.assertIs(pool._session, session) + session_manager.put_session.assert_called_once_with(session) - def test_context_mgr_session_not_found_error(self): - from google.cloud.exceptions import NotFound - database = _Database(self.DATABASE_NAME) - session = _Session(database, name="session-1") - session.exists = mock.MagicMock(return_value=False) - pool = database._pool = _Pool() - new_session = _Session(database, name="session-2") - new_session.create = mock.MagicMock(return_value=[]) - pool._new_session = mock.MagicMock(return_value=new_session) +class TestSessionCheckout(_BaseTest): + def _get_target_class(self): + from google.cloud.spanner_v1.database import SessionCheckout - pool.put(session) - checkout = self._make_one(database) + return SessionCheckout - self.assertEqual(pool._session, session) - with self.assertRaises(NotFound): - with checkout as _: - raise NotFound("Session not found") - # Assert that session-1 was removed from pool and new session was added. - self.assertEqual(pool._session, new_session) + def test_ctor(self): + database = Database( + database_id=self.DATABASE_ID, instance=_Instance(self.INSTANCE_NAME) + ) - def test_context_mgr_table_not_found_error(self): - from google.cloud.exceptions import NotFound + checkout = self._make_one(database) + self.assertIs(checkout._database, database) + self.assertIsNone(checkout._session) - database = _Database(self.DATABASE_NAME) - session = _Session(database, name="session-1") - session.exists = mock.MagicMock(return_value=True) - pool = database._pool = _Pool() - pool._new_session = mock.MagicMock(return_value=[]) + def test_context_manager_success(self): + database = Database( + database_id=self.DATABASE_ID, instance=_Instance(self.INSTANCE_NAME) + ) + + session = _Session(database) + session_manager = database._session_manager + session_manager.get_session_for_read_write = mock.Mock(return_value=session) + session_manager.put_session = mock.Mock(return_value=None) - pool.put(session) checkout = self._make_one(database) - self.assertEqual(pool._session, session) - with self.assertRaises(NotFound): - with checkout as _: - raise NotFound("Table not found") - # Assert that session-1 was not removed from pool. - self.assertEqual(pool._session, session) - pool._new_session.assert_not_called() - - def test_context_mgr_unknown_error(self): - database = _Database(self.DATABASE_NAME) + with checkout as borrowed: + session_manager.get_session_for_read_write.assert_called_once() + self.assertIs(borrowed, session) + + session_manager.put_session.assert_called_once_with(session) + + def test_context_manager_failure(self): + database = Database( + database_id=self.DATABASE_ID, instance=_Instance(self.INSTANCE_NAME) + ) + session = _Session(database) - pool = database._pool = _Pool() - pool._new_session = mock.MagicMock(return_value=[]) - pool.put(session) + session_manager = database._session_manager + session_manager.get_session_for_read_write = mock.Mock(return_value=session) + session_manager.put_session = mock.Mock(return_value=None) + checkout = self._make_one(database) class Testing(Exception): pass - self.assertEqual(pool._session, session) with self.assertRaises(Testing): - with checkout as _: - raise Testing("Unknown error.") - # Assert that session-1 was not removed from pool. - self.assertEqual(pool._session, session) - pool._new_session.assert_not_called() + with checkout as borrowed: + session_manager.get_session_for_read_write.assert_called_once() + self.assertIs(borrowed, session) + raise Testing() + + session_manager.put_session.assert_called_once_with(session) + + def test_type_error(self): + with self.assertRaises(TypeError): + with self._make_one(None) as _: + pass def _make_instance_api(): @@ -3123,6 +3130,7 @@ def __init__( default_transaction_options=DefaultTransactionOptions(), ): from google.cloud.spanner_v1 import ExecuteSqlRequest + from google.cloud.spanner_v1.session_options import SessionOptions self.project = project self.project_name = "projects/" + self.project @@ -3135,6 +3143,7 @@ def __init__( self.route_to_leader_enabled = route_to_leader_enabled self.directed_read_options = directed_read_options self.default_transaction_options = default_transaction_options + self.session_options = SessionOptions() class _Instance(object): @@ -3158,7 +3167,6 @@ def __init__(self, name, instance=None): self.name = name self.database_id = name.rsplit("/", 1)[1] self._instance = instance - from logging import Logger self.logger = mock.create_autospec(Logger, instance=True) self._directed_read_options = None @@ -3191,6 +3199,7 @@ def __init__( self._database = database self.name = name self._run_transaction_function = run_transaction_function + self.is_multiplexed = False def run_in_transaction(self, func, *args, **kw): if self._run_transaction_function: diff --git a/tests/unit/test_database_session_manager.py b/tests/unit/test_database_session_manager.py new file mode 100644 index 0000000000..a6a7721614 --- /dev/null +++ b/tests/unit/test_database_session_manager.py @@ -0,0 +1,140 @@ +# Copyright 2025 Google LLC All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import os +from unittest import TestCase +from unittest.mock import Mock, patch + +from google.cloud.spanner_v1 import SpannerClient +from google.cloud.spanner_v1.client import Client +from google.cloud.spanner_v1.database import Database +from google.cloud.spanner_v1.instance import Instance +from google.cloud.spanner_v1.session import Session +from google.cloud.spanner_v1.session_options import SessionOptions + + +class TestDatabaseSessionManager(TestCase): + def setUp(self): + self._original_env = dict(os.environ) + + self._mocks = { + "create_session": patch.object(SpannerClient, "create_session").start(), + "get_session": patch.object(SpannerClient, "get_session").start(), + "delete_session": patch.object(SpannerClient, "delete_session").start(), + } + + def tearDown(self): + os.environ.clear() + os.environ.update(self._original_env) + + patch.stopall() + + def test_read_only_pooled(self): + self._disable_multiplexed_env_vars() + session_manager = self._build_database_session_manager() + + # Get session from pool. + session = session_manager.get_session_for_read_only() + self.assertFalse(session.is_multiplexed) + session_manager._pool.get.assert_called_once() + + # Return session to pool. + session_manager.put_session(session) + session_manager._pool.put.assert_called_once_with(session) + + def test_read_only_multiplexed(self): + self._enable_multiplexed_env_vars() + session_manager = self._build_database_session_manager() + + with self.assertRaises(NotImplementedError): + session_manager.get_session_for_read_only() + + def test_partitioned_non_multiplexed(self): + self._disable_multiplexed_env_vars() + session_manager = self._build_database_session_manager() + + # Get session from pool. + session = session_manager.get_session_for_partitioned() + self.assertFalse(session.is_multiplexed) + session_manager._pool.get.assert_called_once() + + # Return session to pool. + session_manager.put_session(session) + session_manager._pool.put.assert_called_once_with(session) + + def test_partitioned_multiplexed(self): + self._enable_multiplexed_env_vars() + session_manager = self._build_database_session_manager() + + with self.assertRaises(NotImplementedError): + session_manager.get_session_for_partitioned() + + def test_read_write_non_multiplexed(self): + self._disable_multiplexed_env_vars() + session_manager = self._build_database_session_manager() + + # Get session from pool. + session = session_manager.get_session_for_read_write() + self.assertFalse(session.is_multiplexed) + session_manager._pool.get.assert_called_once() + + # Return session to pool. + session_manager.put_session(session) + session_manager._pool.put.assert_called_once_with(session) + + def test_read_write_multiplexed(self): + self._enable_multiplexed_env_vars() + session_manager = self._build_database_session_manager() + + with self.assertRaises(NotImplementedError): + session_manager.get_session_for_read_write() + + def test_put_multiplexed(self): + session_manager = self._build_database_session_manager() + + with self.assertRaises(NotImplementedError): + session_manager.put_session( + Session(database=session_manager._database, is_multiplexed=True) + ) + + @staticmethod + def _build_database_session_manager(): + """Builds and returns a new database session manager for testing.""" + + client = Client(project="project-id") + instance = Instance(instance_id="instance-id", client=client) + + database = Database(database_id="database-id", instance=instance) + session_manager = database._session_manager + + # Mock the session pool. + pool = session_manager._pool + pool.get = Mock(wraps=pool.get) + pool.put = Mock(wraps=pool.put) + + return session_manager + + @staticmethod + def _enable_multiplexed_env_vars(): + """Sets environment variables to enable multiplexed sessions.""" + + os.environ[SessionOptions.ENV_VAR_ENABLE_MULTIPLEXED] = "true" + os.environ[SessionOptions.ENV_VAR_ENABLE_MULTIPLEXED_FOR_PARTITIONED] = "true" + os.environ[SessionOptions.ENV_VAR_ENABLE_MULTIPLEXED_FOR_READ_WRITE] = "true" + os.environ[SessionOptions.ENV_VAR_FORCE_DISABLE_MULTIPLEXED] = "false" + + @staticmethod + def _disable_multiplexed_env_vars(): + """Sets environment variables to disable multiplexed sessions.""" + + os.environ[SessionOptions.ENV_VAR_FORCE_DISABLE_MULTIPLEXED] = "true" diff --git a/tests/unit/test_instance.py b/tests/unit/test_instance.py index e7ad729438..9ece105a3d 100644 --- a/tests/unit/test_instance.py +++ b/tests/unit/test_instance.py @@ -539,10 +539,9 @@ def test_database_factory_defaults(self): self.assertEqual(database.database_id, DATABASE_ID) self.assertIs(database._instance, instance) self.assertEqual(list(database.ddl_statements), []) - self.assertIsInstance(database._pool, BurstyPool) + self.assertIsInstance(database._session_manager._pool, BurstyPool) self.assertIsNone(database._logger) - pool = database._pool - self.assertIs(pool._database, database) + self.assertIs(database._session_manager._pool._database, database) self.assertIsNone(database.database_role) def test_database_factory_explicit(self): @@ -573,7 +572,7 @@ def test_database_factory_explicit(self): self.assertEqual(database.database_id, DATABASE_ID) self.assertIs(database._instance, instance) self.assertEqual(list(database.ddl_statements), DDL_STATEMENTS) - self.assertIs(database._pool, pool) + self.assertIs(database._session_manager._pool, pool) self.assertIs(database._logger, logger) self.assertIs(pool._bound, database) self.assertIs(database._encryption_config, encryption_config) diff --git a/tests/unit/test_pool.py b/tests/unit/test_pool.py index a9593b3651..e6bc827f57 100644 --- a/tests/unit/test_pool.py +++ b/tests/unit/test_pool.py @@ -121,26 +121,6 @@ def test__new_session_w_database_role(self): self.assertIs(new_session, session) database.session.assert_called_once_with(labels={}, database_role=database_role) - def test_session_wo_kwargs(self): - from google.cloud.spanner_v1.pool import SessionCheckout - - pool = self._make_one() - checkout = pool.session() - self.assertIsInstance(checkout, SessionCheckout) - self.assertIs(checkout._pool, pool) - self.assertIsNone(checkout._session) - self.assertEqual(checkout._kwargs, {}) - - def test_session_w_kwargs(self): - from google.cloud.spanner_v1.pool import SessionCheckout - - pool = self._make_one() - checkout = pool.session(foo="bar") - self.assertIsInstance(checkout, SessionCheckout) - self.assertIs(checkout._pool, pool) - self.assertIsNone(checkout._session) - self.assertEqual(checkout._kwargs, {"foo": "bar"}) - class TestFixedSizePool(OpenTelemetryBase): BASE_ATTRIBUTES = { @@ -1073,73 +1053,6 @@ def test_spans_get_and_leave_empty_pool(self): self.assertSpanEvents("pool.Get", wantEventNames, span_list[-1]) -class TestSessionCheckout(unittest.TestCase): - def _getTargetClass(self): - from google.cloud.spanner_v1.pool import SessionCheckout - - return SessionCheckout - - def _make_one(self, *args, **kwargs): - return self._getTargetClass()(*args, **kwargs) - - def test_ctor_wo_kwargs(self): - pool = _Pool() - checkout = self._make_one(pool) - self.assertIs(checkout._pool, pool) - self.assertIsNone(checkout._session) - self.assertEqual(checkout._kwargs, {}) - - def test_ctor_w_kwargs(self): - pool = _Pool() - checkout = self._make_one(pool, foo="bar", database_role="dummy-role") - self.assertIs(checkout._pool, pool) - self.assertIsNone(checkout._session) - self.assertEqual( - checkout._kwargs, {"foo": "bar", "database_role": "dummy-role"} - ) - - def test_context_manager_wo_kwargs(self): - session = object() - pool = _Pool(session) - checkout = self._make_one(pool) - - self.assertEqual(len(pool._items), 1) - self.assertIs(pool._items[0], session) - - with checkout as borrowed: - self.assertIs(borrowed, session) - self.assertEqual(len(pool._items), 0) - - self.assertEqual(len(pool._items), 1) - self.assertIs(pool._items[0], session) - self.assertEqual(pool._got, {}) - - def test_context_manager_w_kwargs(self): - session = object() - pool = _Pool(session) - checkout = self._make_one(pool, foo="bar") - - self.assertEqual(len(pool._items), 1) - self.assertIs(pool._items[0], session) - - with checkout as borrowed: - self.assertIs(borrowed, session) - self.assertEqual(len(pool._items), 0) - - self.assertEqual(len(pool._items), 1) - self.assertIs(pool._items[0], session) - self.assertEqual(pool._got, {"foo": "bar"}) - - -def _make_transaction(*args, **kw): - from google.cloud.spanner_v1.transaction import Transaction - - txn = mock.create_autospec(Transaction)(*args, **kw) - txn.committed = None - txn.rolled_back = False - return txn - - @total_ordering class _Session(object): _transaction = None @@ -1183,14 +1096,6 @@ def delete(self): if not self._exists: raise NotFound("unknown session") - def transaction(self): - txn = self._transaction = _make_transaction(self) - return txn - - @property - def session_id(self): - return self._session_id - class _Database(object): def __init__(self, name): diff --git a/tests/unit/test_session.py b/tests/unit/test_session.py index 8f5f7039b9..a52abdac3d 100644 --- a/tests/unit/test_session.py +++ b/tests/unit/test_session.py @@ -99,8 +99,13 @@ def _make_database( return database @staticmethod - def _make_session_pb(name, labels=None, database_role=None): - return SessionRequestProto(name=name, labels=labels, creator_role=database_role) + def _make_session_pb(name=None, multiplexed=False, labels=None, database_role=None): + return SessionRequestProto( + name=name, + multiplexed=multiplexed, + labels=labels, + creator_role=database_role, + ) def _make_spanner_api(self): return mock.Mock(autospec=SpannerClient, instance=True) @@ -167,7 +172,7 @@ def test_create_w_session_id(self): def test_create_w_database_role(self): session_pb = self._make_session_pb( - self.SESSION_NAME, database_role=self.DATABASE_ROLE + name=self.SESSION_NAME, database_role=self.DATABASE_ROLE ) gax_api = self._make_spanner_api() gax_api.create_session.return_value = session_pb @@ -200,7 +205,7 @@ def test_create_w_database_role(self): def test_create_session_span_annotations(self): session_pb = self._make_session_pb( - self.SESSION_NAME, database_role=self.DATABASE_ROLE + name=self.SESSION_NAME, database_role=self.DATABASE_ROLE ) gax_api = self._make_spanner_api() @@ -233,7 +238,7 @@ def test_create_session_span_annotations(self): self.assertSpanEvents("TestSessionSpan", wantEventNames, span) def test_create_wo_database_role(self): - session_pb = self._make_session_pb(self.SESSION_NAME) + session_pb = self._make_session_pb(name=self.SESSION_NAME) gax_api = self._make_spanner_api() gax_api.create_session.return_value = session_pb database = self._make_database() @@ -245,7 +250,7 @@ def test_create_wo_database_role(self): self.assertIsNone(session.database_role) request = CreateSessionRequest( - database=database.name, + database=database.name, session=self._make_session_pb() ) gax_api.create_session.assert_called_once_with( @@ -261,7 +266,7 @@ def test_create_wo_database_role(self): ) def test_create_ok(self): - session_pb = self._make_session_pb(self.SESSION_NAME) + session_pb = self._make_session_pb(name=self.SESSION_NAME) gax_api = self._make_spanner_api() gax_api.create_session.return_value = session_pb database = self._make_database() @@ -274,6 +279,7 @@ def test_create_ok(self): request = CreateSessionRequest( database=database.name, + session=self._make_session_pb(), ) gax_api.create_session.assert_called_once_with( @@ -290,7 +296,7 @@ def test_create_ok(self): def test_create_w_labels(self): labels = {"foo": "bar"} - session_pb = self._make_session_pb(self.SESSION_NAME, labels=labels) + session_pb = self._make_session_pb(name=self.SESSION_NAME, labels=labels) gax_api = self._make_spanner_api() gax_api.create_session.return_value = session_pb database = self._make_database() @@ -343,7 +349,7 @@ def test_exists_wo_session_id(self): self.assertNoSpans() def test_exists_hit(self): - session_pb = self._make_session_pb(self.SESSION_NAME) + session_pb = self._make_session_pb(name=self.SESSION_NAME) gax_api = self._make_spanner_api() gax_api.get_session.return_value = session_pb database = self._make_database() @@ -371,7 +377,7 @@ def test_exists_hit(self): False, ) def test_exists_hit_wo_span(self): - session_pb = self._make_session_pb(self.SESSION_NAME) + session_pb = self._make_session_pb(name=self.SESSION_NAME) gax_api = self._make_spanner_api() gax_api.get_session.return_value = session_pb database = self._make_database() diff --git a/tests/unit/test_session_options.py b/tests/unit/test_session_options.py new file mode 100644 index 0000000000..7eac21f3cc --- /dev/null +++ b/tests/unit/test_session_options.py @@ -0,0 +1,101 @@ +# Copyright 2025 Google LLC All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import os +from unittest import TestCase +from google.cloud.spanner_v1.session_options import SessionOptions + + +class TestSessionOptions(TestCase): + @classmethod + def setUpClass(cls): + cls._original_env = dict(os.environ) + + @classmethod + def tearDownClass(cls): + os.environ.clear() + os.environ.update(cls._original_env) + + def test_use_multiplexed_for_read_only(self): + session_options = SessionOptions() + + os.environ[SessionOptions.ENV_VAR_ENABLE_MULTIPLEXED] = "false" + self.assertFalse(session_options.use_multiplexed_for_read_only()) + + os.environ[SessionOptions.ENV_VAR_ENABLE_MULTIPLEXED] = "true" + os.environ[SessionOptions.ENV_VAR_FORCE_DISABLE_MULTIPLEXED] = "true" + self.assertFalse(session_options.use_multiplexed_for_read_only()) + + os.environ[SessionOptions.ENV_VAR_FORCE_DISABLE_MULTIPLEXED] = "false" + self.assertTrue(session_options.use_multiplexed_for_read_only()) + + session_options.disable_multiplexed_for_read_only() + self.assertFalse(session_options.use_multiplexed_for_read_only()) + + def test_use_multiplexed_for_partitioned(self): + session_options = SessionOptions() + + os.environ[SessionOptions.ENV_VAR_ENABLE_MULTIPLEXED] = "false" + self.assertFalse(session_options.use_multiplexed_for_partitioned()) + + os.environ[SessionOptions.ENV_VAR_ENABLE_MULTIPLEXED] = "true" + os.environ[SessionOptions.ENV_VAR_ENABLE_MULTIPLEXED_FOR_PARTITIONED] = "false" + self.assertFalse(session_options.use_multiplexed_for_partitioned()) + + os.environ[SessionOptions.ENV_VAR_ENABLE_MULTIPLEXED_FOR_PARTITIONED] = "true" + os.environ[SessionOptions.ENV_VAR_FORCE_DISABLE_MULTIPLEXED] = "true" + self.assertFalse(session_options.use_multiplexed_for_partitioned()) + + os.environ[SessionOptions.ENV_VAR_FORCE_DISABLE_MULTIPLEXED] = "false" + self.assertTrue(session_options.use_multiplexed_for_partitioned()) + + session_options.disable_multiplexed_for_partitioned() + self.assertFalse(session_options.use_multiplexed_for_partitioned()) + + def test_use_multiplexed_for_read_write(self): + session_options = SessionOptions() + + os.environ[SessionOptions.ENV_VAR_ENABLE_MULTIPLEXED] = "false" + self.assertFalse(session_options.use_multiplexed_for_read_write()) + + os.environ[SessionOptions.ENV_VAR_ENABLE_MULTIPLEXED] = "true" + os.environ[SessionOptions.ENV_VAR_ENABLE_MULTIPLEXED_FOR_READ_WRITE] = "false" + self.assertFalse(session_options.use_multiplexed_for_read_write()) + + os.environ[SessionOptions.ENV_VAR_ENABLE_MULTIPLEXED_FOR_READ_WRITE] = "true" + os.environ[SessionOptions.ENV_VAR_FORCE_DISABLE_MULTIPLEXED] = "true" + self.assertFalse(session_options.use_multiplexed_for_read_write()) + + os.environ[SessionOptions.ENV_VAR_FORCE_DISABLE_MULTIPLEXED] = "false" + self.assertTrue(session_options.use_multiplexed_for_read_write()) + + session_options.disable_multiplexed_for_read_write() + self.assertFalse(session_options.use_multiplexed_for_read_write()) + + def test_supported_env_var_values(self): + session_options = SessionOptions() + + os.environ[SessionOptions.ENV_VAR_FORCE_DISABLE_MULTIPLEXED] = "false" + + true_values = ["1", "true", "True", "TRUE"] + for value in true_values: + os.environ[SessionOptions.ENV_VAR_ENABLE_MULTIPLEXED] = value + self.assertTrue(session_options.use_multiplexed_for_read_only()) + + false_values = ["", "0", "false", "False", "FALSE"] + for value in false_values: + os.environ[SessionOptions.ENV_VAR_ENABLE_MULTIPLEXED] = value + self.assertFalse(session_options.use_multiplexed_for_read_only()) + + del os.environ[SessionOptions.ENV_VAR_ENABLE_MULTIPLEXED] + self.assertFalse(session_options.use_multiplexed_for_read_only())