From a9b3d022a13807ec660702abb27ff98567be36de Mon Sep 17 00:00:00 2001 From: currantw Date: Thu, 27 Mar 2025 11:14:56 -0700 Subject: [PATCH] feat: Add support for multiplexed sessions - part 2 (read-only transactions) - Enabled use of multiplexed sessions for read-only transactions - Add maintenance threads for refreshing multiplexed sessions Signed-off-by: currantw --- google/cloud/spanner_v1/database.py | 4 +- .../spanner_v1/database_sessions_manager.py | 238 ++++++++++++++++-- google/cloud/spanner_v1/session_options.py | 36 ++- tests/unit/test_database_session_manager.py | 166 +++++++++++- tests/unit/test_session_options.py | 2 +- 5 files changed, 411 insertions(+), 35 deletions(-) diff --git a/google/cloud/spanner_v1/database.py b/google/cloud/spanner_v1/database.py index a7caaa8a29..098bdc0730 100644 --- a/google/cloud/spanner_v1/database.py +++ b/google/cloud/spanner_v1/database.py @@ -192,7 +192,9 @@ def __init__( pool = BurstyPool(database_role=database_role) pool.bind(self) - self._session_manager = DatabaseSessionsManager(database=self, pool=pool) + self._session_manager = DatabaseSessionsManager( + database=self, pool=pool, logger=self.logger + ) @classmethod def from_pb(cls, database_pb, instance, pool=None): diff --git a/google/cloud/spanner_v1/database_sessions_manager.py b/google/cloud/spanner_v1/database_sessions_manager.py index 722b4df769..293192304a 100644 --- a/google/cloud/spanner_v1/database_sessions_manager.py +++ b/google/cloud/spanner_v1/database_sessions_manager.py @@ -11,6 +11,13 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import datetime +import threading +import time +import weakref + +from google.api_core.exceptions import MethodNotImplemented + from google.cloud.spanner_v1._opentelemetry_tracing import ( get_current_span, add_span_event, @@ -34,12 +41,35 @@ class DatabaseSessionsManager(object): :type pool: :class:`~google.cloud.spanner_v1.pool.AbstractSessionPool` :param pool: The pool to get non-multiplexed sessions from. + + :type logger: :class:`logging.Logger` + :param logger: Logger for the database session manager. """ - def __init__(self, database, pool): + # Intervals for the maintenance thread to check and refresh the multiplexed session. + _MAINTENANCE_THREAD_POLLING_INTERVAL = datetime.timedelta(hours=1) + _MAINTENANCE_THREAD_REFRESH_INTERVAL = datetime.timedelta(days=7) + + def __init__(self, database, pool, logger): self._database = database + self._logger = logger + + # The session pool manages non-multiplexed sessions, and + # will only be used if multiplexed sessions are not enabled. self._pool = pool + # Declare multiplexed session attributes. When a multiplexed session for the + # database session manager is created, a maintenance thread is initialized to + # periodically delete and recreate the multiplexed session so that it remains + # valid. Because of this concurrency, we need to use a lock whenever we access + # the multiplexed session to avoid any race conditions. We also create an event + # so that the thread can terminate if the use of multiplexed session has been + # disabled for all transactions. + self._multiplexed_session = None + self._multiplexed_session_maintenance_thread = None + self._multiplexed_session_lock = threading.Lock() + self._is_multiplexed_sessions_disabled_event = threading.Event() + def get_session_for_read_only(self) -> Session: """Returns a session for read-only transactions from the database session manager. @@ -47,14 +77,9 @@ def get_session_for_read_only(self) -> Session: :returns: a session for read-only transactions. """ - if ( - self._database._instance._client.session_options.use_multiplexed_for_read_only() - ): - raise NotImplementedError( - "Multiplexed sessions are not yet supported for read-only transactions." - ) - - return self._get_pooled_session() + return self._get_session( + use_multiplexed=self._database._instance._client.session_options.use_multiplexed_for_read_only() + ) def get_session_for_partitioned(self) -> Session: """Returns a session for partitioned transactions from the database session manager. @@ -70,7 +95,7 @@ def get_session_for_partitioned(self) -> Session: "Multiplexed sessions are not yet supported for partitioned transactions." ) - return self._get_pooled_session() + return self._get_session(use_multiplexed=False) def get_session_for_read_write(self) -> Session: """Returns a session for read/write transactions from the database session manager. @@ -86,15 +111,20 @@ def get_session_for_read_write(self) -> Session: "Multiplexed sessions are not yet supported for read/write transactions." ) - return self._get_pooled_session() + return self._get_session(use_multiplexed=False) def put_session(self, session: Session) -> None: - """Returns the session to the database session manager.""" + """Returns the session to the database session manager. - if session.is_multiplexed: - raise NotImplementedError("Multiplexed sessions are not yet supported.") + :type session: :class:`~google.cloud.spanner_v1.session.Session` + :param session: The session to return to the database session manager. + """ - self._pool.put(session) + # No action is needed for multiplexed sessions: the session + # pool is only used for managing non-multiplexed sessions, + # since they can only process one transaction at a time. + if not session.is_multiplexed: + self._pool.put(session) current_span = get_current_span() add_span_event( @@ -103,10 +133,34 @@ def put_session(self, session: Session) -> None: {"id": session.session_id, "multiplexed": session.is_multiplexed}, ) - def _get_pooled_session(self): - """Returns a non-multiplexed session from the session pool.""" + def _get_session(self, use_multiplexed: bool) -> Session: + """Returns a session from the database session manager. + + If use_multiplexed is True, returns a multiplexed session if + multiplexed sessions are supported. If multiplexed sessions are + not supported or if use_multiplexed is False, returns a non- + multiplexed session from the session pool. + + :type use_multiplexed: bool + :param use_multiplexed: Whether to try to get a multiplexed session. + + :rtype: :class:`~google.cloud.spanner_v1.session.Session` + :returns: a session for the database session manager. + """ + + if use_multiplexed: + try: + session = self._get_multiplexed_session() + + # If multiplexed sessions are not supported, disable + # them for all transactions and return a non-multiplexed session. + except MethodNotImplemented: + self._disable_multiplexed_sessions() + session = self._pool.get() + + else: + session = self._pool.get() - session = self._pool.get() add_span_event( get_current_span(), "Using session", @@ -114,3 +168,151 @@ def _get_pooled_session(self): ) return session + + def _get_multiplexed_session(self) -> Session: + """Returns a multiplexed session from the database session manager. + + If the multiplexed session is not defined, creates a new multiplexed + session and starts a maintenance thread to periodically delete and + recreate it so that it remains valid. Otherwise, simply returns the + current multiplexed session. + + :raises MethodNotImplemented: + if multiplexed sessions are not supported. + + :rtype: :class:`~google.cloud.spanner_v1.session.Session` + :returns: a multiplexed session. + """ + + with self._multiplexed_session_lock: + if self._multiplexed_session is None: + self._multiplexed_session = self._build_multiplexed_session() + + # Build and start a thread to maintain the multiplexed session. + self._multiplexed_session_maintenance_thread = ( + self._build_maintenance_thread() + ) + self._multiplexed_session_maintenance_thread.start() + + add_span_event( + get_current_span(), + "Using session", + {"id": self._multiplexed_session.session_id, "multiplexed": True}, + ) + + return self._multiplexed_session + + def _build_multiplexed_session(self) -> Session: + """Builds and returns a new multiplexed session for the database session manager. + + :raises MethodNotImplemented: + if multiplexed sessions are not supported. + + :rtype: :class:`~google.cloud.spanner_v1.session.Session` + :returns: a new multiplexed session. + """ + + session = Session( + database=self._database, + database_role=self._database.database_role, + is_multiplexed=True, + ) + + session.create() + + self._logger.info("Created multiplexed session.") + + return session + + def _disable_multiplexed_sessions(self) -> None: + """Disables multiplexed sessions for all transactions.""" + + self._logger.warning( + "Multiplexed session creation failed. Disabling multiplexed sessions." + ) + + session_options = self._database._instance._client.session_options + session_options.disable_multiplexed_for_read_only() + session_options.disable_multiplexed_for_partitioned() + session_options.disable_multiplexed_for_read_write() + + self._multiplexed_session = None + self._is_multiplexed_sessions_disabled_event.set() + + def _build_maintenance_thread(self) -> threading.Thread: + """Builds and returns a multiplexed session maintenance thread for + the database session manager. This thread will periodically delete + and recreate the multiplexed session to ensure that it is always valid. + + :rtype: :class:`threading.Thread` + :returns: a multiplexed session maintenance thread. + """ + + # Use a weak reference to the database session manager to avoid + # creating a circular reference that would prevent the database + # session manager from being garbage collected. + session_manager_ref = weakref.ref(self) + + return threading.Thread( + target=self._maintain_multiplexed_session, + name=f"maintenance-multiplexed-session-{self._multiplexed_session.name}", + args=[session_manager_ref], + daemon=True, + ) + + @staticmethod + def _maintain_multiplexed_session(session_manager_ref) -> None: + """Maintains the multiplexed session for the database session manager. + + This method will delete and recreate the referenced database session manager's + multiplexed session to ensure that it is always valid. The method will run until + the database session manager is deleted, the multiplexed session is deleted, or + building a multiplexed session fails. + + :type session_manager_ref: :class:`_weakref.ReferenceType` + :param session_manager_ref: A weak reference to the database session manager. + """ + + session_manager = session_manager_ref() + if session_manager is None: + return + + polling_interval_seconds = ( + session_manager._MAINTENANCE_THREAD_POLLING_INTERVAL.total_seconds() + ) + refresh_interval_seconds = ( + session_manager._MAINTENANCE_THREAD_REFRESH_INTERVAL.total_seconds() + ) + + session_created_time = time.time() + + while True: + # Terminate the thread is the database session manager has been deleted. + session_manager = session_manager_ref() + if session_manager is None: + return + + # Terminate the thread if the use of multiplexed sessions has been disabled. + if session_manager._is_multiplexed_sessions_disabled_event.is_set(): + return + + # Wait for until the refresh interval has elapsed. + if time.time() - session_created_time < refresh_interval_seconds: + time.sleep(polling_interval_seconds) + continue + + with session_manager._multiplexed_session_lock: + session_manager._multiplexed_session.delete() + + try: + session_manager._multiplexed_session = ( + session_manager._build_multiplexed_session() + ) + + # Disable multiplexed sessions for all transactions and terminate + # the thread if building a multiplexed session fails. + except MethodNotImplemented: + session_manager._disable_multiplexed_sessions() + return + + session_created_time = time.time() diff --git a/google/cloud/spanner_v1/session_options.py b/google/cloud/spanner_v1/session_options.py index 6e35712a0b..9675ca90db 100644 --- a/google/cloud/spanner_v1/session_options.py +++ b/google/cloud/spanner_v1/session_options.py @@ -11,9 +11,13 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import datetime import os +from google.cloud.spanner_v1._opentelemetry_tracing import ( + get_current_span, + add_span_event, +) + class SessionOptions(object): """Represents the session options for the Cloud Spanner Python client. @@ -22,9 +26,12 @@ class SessionOptions(object): * read-only transactions (:meth:`use_multiplexed_for_read_only`) * partitioned transactions (:meth:`use_multiplexed_for_partitioned`) * read/write transactions (:meth:`use_multiplexed_for_read_write`). - """ - MULTIPLEXED_SESSIONS_REFRESH_INTERVAL = datetime.timedelta(days=7) + The use of multiplexed session can be disabled for corresponding transaction types by calling: + * :meth:`disable_multiplexed_for_read_only` + * :meth:`disable_multiplexed_for_partitioned` + * :meth:`disable_multiplexed_for_read_write`. + """ # Environment variables for multiplexed sessions ENV_VAR_ENABLE_MULTIPLEXED = "GOOGLE_CLOUD_SPANNER_MULTIPLEXED_SESSIONS" @@ -61,6 +68,13 @@ def use_multiplexed_for_read_only(self) -> bool: def disable_multiplexed_for_read_only(self) -> None: """Disables the use of multiplexed sessions for read-only transactions.""" + + current_span = get_current_span() + add_span_event( + current_span, + "Disabling use of multiplexed session for read-only transactions", + ) + self._is_multiplexed_enabled_for_read_only = False def use_multiplexed_for_partitioned(self) -> bool: @@ -81,6 +95,13 @@ def use_multiplexed_for_partitioned(self) -> bool: def disable_multiplexed_for_partitioned(self) -> None: """Disables the use of multiplexed sessions for read-only transactions.""" + + current_span = get_current_span() + add_span_event( + current_span, + "Disabling use of multiplexed session for partitioned transactions", + ) + self._is_multiplexed_enabled_for_partitioned = False def use_multiplexed_for_read_write(self) -> bool: @@ -101,6 +122,13 @@ def use_multiplexed_for_read_write(self) -> bool: def disable_multiplexed_for_read_write(self) -> None: """Disables the use of multiplexed sessions for read/write transactions.""" + + current_span = get_current_span() + add_span_event( + current_span, + "Disabling use of multiplexed session for read/write transactions", + ) + self._is_multiplexed_enabled_for_read_write = False @staticmethod @@ -108,5 +136,5 @@ def _getenv(name: str) -> bool: """Returns the value of the given environment variable as a boolean. True values are '1' and 'true' (case-insensitive); all other values are considered false. """ - env_var = os.getenv(name, "").lower() + env_var = os.getenv(name, "").lower().strip() return env_var in ["1", "true"] diff --git a/tests/unit/test_database_session_manager.py b/tests/unit/test_database_session_manager.py index a6a7721614..c1346da163 100644 --- a/tests/unit/test_database_session_manager.py +++ b/tests/unit/test_database_session_manager.py @@ -12,14 +12,22 @@ # See the License for the specific language governing permissions and # limitations under the License. import os +import datetime +import time from unittest import TestCase -from unittest.mock import Mock, patch +from unittest.mock import Mock, patch, DEFAULT, PropertyMock + +from google.api_core.exceptions import ( + MethodNotImplemented, + BadRequest, + FailedPrecondition, +) from google.cloud.spanner_v1 import SpannerClient from google.cloud.spanner_v1.client import Client from google.cloud.spanner_v1.database import Database +from google.cloud.spanner_v1.database_sessions_manager import DatabaseSessionsManager from google.cloud.spanner_v1.instance import Instance -from google.cloud.spanner_v1.session import Session from google.cloud.spanner_v1.session_options import SessionOptions @@ -29,8 +37,20 @@ def setUp(self): self._mocks = { "create_session": patch.object(SpannerClient, "create_session").start(), - "get_session": patch.object(SpannerClient, "get_session").start(), "delete_session": patch.object(SpannerClient, "delete_session").start(), + # Mock faster polling and refresh intervals for tests. + "polling_interval": patch.object( + DatabaseSessionsManager, + "_MAINTENANCE_THREAD_POLLING_INTERVAL", + new_callable=PropertyMock, + return_value=datetime.timedelta(seconds=1), + ).start(), + "refresh_interval": patch.object( + DatabaseSessionsManager, + "_MAINTENANCE_THREAD_REFRESH_INTERVAL", + new_callable=PropertyMock, + return_value=datetime.timedelta(seconds=2), + ).start(), } def tearDown(self): @@ -56,10 +76,21 @@ def test_read_only_multiplexed(self): self._enable_multiplexed_env_vars() session_manager = self._build_database_session_manager() - with self.assertRaises(NotImplementedError): - session_manager.get_session_for_read_only() + # Session is created. + session_1 = session_manager.get_session_for_read_only() + self.assertTrue(session_1.is_multiplexed) + session_manager.put_session(session_1) + + # Session is re-used. + session_2 = session_manager.get_session_for_read_only() + self.assertEqual(session_1, session_2) + session_manager.put_session(session_2) - def test_partitioned_non_multiplexed(self): + # Verify that pool was not used. + session_manager._pool.get.assert_not_called() + session_manager._pool.put.assert_not_called() + + def test_partitioned_pooled(self): self._disable_multiplexed_env_vars() session_manager = self._build_database_session_manager() @@ -79,7 +110,7 @@ def test_partitioned_multiplexed(self): with self.assertRaises(NotImplementedError): session_manager.get_session_for_partitioned() - def test_read_write_non_multiplexed(self): + def test_read_write_pooled(self): self._disable_multiplexed_env_vars() session_manager = self._build_database_session_manager() @@ -99,13 +130,113 @@ def test_read_write_multiplexed(self): with self.assertRaises(NotImplementedError): session_manager.get_session_for_read_write() - def test_put_multiplexed(self): + def test_multiplexed_maintenance(self): + self._enable_multiplexed_env_vars() session_manager = self._build_database_session_manager() - with self.assertRaises(NotImplementedError): - session_manager.put_session( - Session(database=session_manager._database, is_multiplexed=True) + # Maintenance thread is started. + session_1 = session_manager.get_session_for_read_only() + self.assertTrue(session_1.is_multiplexed) + + # Wait for maintenance thread to execute. + def create_session_condition(): + return self._mocks["create_session"].call_count > 1 + + self.assert_true_with_timeout(create_session_condition) + + # Verify that maintenance thread created new multiplexed session. + session_2 = session_manager.get_session_for_read_only() + self.assertTrue(session_2.is_multiplexed) + self.assertNotEqual(session_1, session_2) + + def test_multiplexed_maintenance_terminates_not_implemented(self): + self._enable_multiplexed_env_vars() + session_manager = self._build_database_session_manager() + + # Maintenance thread is started. + session_1 = session_manager.get_session_for_read_only() + self.assertTrue(session_1.is_multiplexed) + + # Multiplexed sessions not implemented. + create_session_mock = self._mocks["create_session"] + create_session_mock.side_effect = MethodNotImplemented( + "Multiplexed sessions not implemented" + ) + + # Wait for maintenance thread to terminate. + thread = session_manager._multiplexed_session_maintenance_thread + + def thread_terminated_condition(): + return not thread.is_alive() + + self.assert_true_with_timeout(thread_terminated_condition) + + # Verify that multiplexed sessions are disabled. + session_options = session_manager._database._instance._client.session_options + self.assertFalse(session_options.use_multiplexed_for_read_only()) + self.assertFalse(session_options.use_multiplexed_for_partitioned()) + self.assertFalse(session_options.use_multiplexed_for_read_write()) + + def test_multiplexed_maintenance_terminates_disabled(self): + self._enable_multiplexed_env_vars() + session_manager = self._build_database_session_manager() + + # Maintenance thread is started. + session_1 = session_manager.get_session_for_read_only() + self.assertTrue(session_1.is_multiplexed) + + session_manager._is_multiplexed_sessions_disabled_event.set() + + # Wait for maintenance thread to terminate. + thread = session_manager._multiplexed_session_maintenance_thread + + def thread_terminated_condition(): + return not thread.is_alive() + + self.assert_true_with_timeout(thread_terminated_condition) + + def test_multiplexed_exception_method_not_implemented(self): + self._enable_multiplexed_env_vars() + session_manager = self._build_database_session_manager() + + # Multiplexed sessions not implemented. + self._mocks["create_session"].side_effect = [ + MethodNotImplemented("Test MethodNotImplemented"), + DEFAULT, + ] + + # Get session from pool. + session = session_manager.get_session_for_read_only() + self.assertFalse(session.is_multiplexed) + session_manager._pool.get.assert_called_once() + + # Return session to pool. + session_manager.put_session(session) + session_manager._pool.put.assert_called_once_with(session) + + # Verify that multiplexed session are disabled. + session_options = session_manager._database._instance._client.session_options + self.assertFalse(session_options.use_multiplexed_for_read_only()) + self.assertFalse(session_options.use_multiplexed_for_partitioned()) + self.assertFalse(session_options.use_multiplexed_for_read_write()) + + def test_exception_bad_request(self): + session_manager = self._build_database_session_manager() + + # Verify that BadRequest is not caught. + with self.assertRaises(BadRequest): + self._mocks["create_session"].side_effect = BadRequest("Test BadRequest") + session_manager.get_session_for_read_only() + + def test_exception_failed_precondition(self): + session_manager = self._build_database_session_manager() + + # Verify that FailedPrecondition is not caught. + with self.assertRaises(FailedPrecondition): + self._mocks["create_session"].side_effect = FailedPrecondition( + "Test FailedPrecondition" ) + session_manager.get_session_for_read_only() @staticmethod def _build_database_session_manager(): @@ -138,3 +269,16 @@ def _disable_multiplexed_env_vars(): """Sets environment variables to disable multiplexed sessions.""" os.environ[SessionOptions.ENV_VAR_FORCE_DISABLE_MULTIPLEXED] = "true" + + @staticmethod + def assert_true_with_timeout(condition): + """Asserts that the given condition is met within a timeout period.""" + + sleep_seconds = 0.1 + timeout_seconds = 10 + + start_time = time.time() + while not condition() and time.time() - start_time < timeout_seconds: + time.sleep(sleep_seconds) + + assert condition() diff --git a/tests/unit/test_session_options.py b/tests/unit/test_session_options.py index 7eac21f3cc..324cb3cd43 100644 --- a/tests/unit/test_session_options.py +++ b/tests/unit/test_session_options.py @@ -87,7 +87,7 @@ def test_supported_env_var_values(self): os.environ[SessionOptions.ENV_VAR_FORCE_DISABLE_MULTIPLEXED] = "false" - true_values = ["1", "true", "True", "TRUE"] + true_values = ["1", " 1", " 1", "true", "True", "TRUE", " true "] for value in true_values: os.environ[SessionOptions.ENV_VAR_ENABLE_MULTIPLEXED] = value self.assertTrue(session_options.use_multiplexed_for_read_only())