diff --git a/postgres/changelog.d/23728.fixed b/postgres/changelog.d/23728.fixed new file mode 100644 index 0000000000000..aec07d0c61155 --- /dev/null +++ b/postgres/changelog.d/23728.fixed @@ -0,0 +1 @@ +Fix a crash caused by cancel closing database connections while the check is still running. \ No newline at end of file diff --git a/postgres/datadog_checks/postgres/postgres.py b/postgres/datadog_checks/postgres/postgres.py index dac424e3f9c5f..dbd4bd067bc7c 100644 --- a/postgres/datadog_checks/postgres/postgres.py +++ b/postgres/datadog_checks/postgres/postgres.py @@ -5,6 +5,7 @@ import copy import functools import os +import threading from string import Template from time import time @@ -17,7 +18,6 @@ from datadog_checks.base.utils.db.core import QueryManager from datadog_checks.base.utils.db.health import HealthEvent, HealthStatus from datadog_checks.base.utils.db.utils import ( - DBMAsyncJob, default_json_event_encoding, tracked_query, ) @@ -194,6 +194,10 @@ def __init__(self, name, init_config, instances): self.diagnosis.register(functools.partial(run_diagnostics, self)) + self._cancel_lock = threading.Lock() + self._is_running = False + self._cancelled = False + def database_monitoring_column_statistics(self, raw_event: str): self.event_platform_event(raw_event, "dbm-column-statistics") @@ -476,38 +480,87 @@ def dynamic_queries(self): return self._dynamic_queries - @staticmethod - def _cancel_async_job(job: DBMAsyncJob): - job.cancel() - if job._job_loop_future: - job._job_loop_future.result() - job._job_loop_future = None - job._shutdown() + def run(self): + # TODO: move this lock into the base class + with self._cancel_lock: + if self._cancelled: + self.log.debug("run() skipped, check already cancelled") + return '' + self._is_running = True + try: + return super().run() + finally: + needs_finalize = False + with self._cancel_lock: + self._is_running = False + if self._cancelled: + needs_finalize = True + if needs_finalize: + self.log.debug("Check cancel has been signaled, finalizing now that run() is complete") + self._finalize() def cancel(self): + """Signal that the check is being unscheduled. + + This method can be called while check() is running on another thread + (the GIL is released during psycopg I/O). It must not perform any + destructive operations — closing connections or nulling attributes that + check() depends on — because that causes a SIGSEGV in libpq when + check() resumes. + + Destructive cleanup is deferred to _finalize(), which is called either + here (if the check is idle) or by run()'s finally block (if the check + is in-flight). The Agent guarantees it will not call run() again after + cancel(). """ - Cancels and sends cancel signal to all threads. - """ + self.log.debug("Marking check as cancelled") + self._cancel_async_jobs() + needs_finalize = False + with self._cancel_lock: + self._cancelled = True + if not self._is_running: + needs_finalize = True + if needs_finalize: + self.log.debug("cancel() finalizing immediately, check is idle") + self._finalize() + else: + self.log.debug("cancel() deferred finalize, check is still running") + + @property + def _async_jobs(self): + """Return the async jobs active for this check's configuration.""" + jobs = [] if self._config.dbm: - self._cancel_async_job(self.statement_metrics) - self._cancel_async_job(self.statement_samples) - self._cancel_async_job(self.metadata_samples) + jobs.extend([self.statement_metrics, self.statement_samples, self.metadata_samples]) elif self._config.data_observability.enabled: - self._cancel_async_job(self.metadata_samples) + jobs.append(self.metadata_samples) if self._config.data_observability.enabled: - self._cancel_async_job(self.data_observability) + jobs.append(self.data_observability) + return jobs + + def _cancel_async_jobs(self): + """Signal async jobs to stop. Safe to call while check() is running.""" + for job in self._async_jobs: + job.cancel() + + def _finalize(self): + """Tear down check state. Must not run while check() is executing.""" + self.log.debug("Finalizing check: closing connections and clearing state") + for job in self._async_jobs: + if job._job_loop_future: + job._job_loop_future.result() + job._job_loop_future = None + job._shutdown() self._clean_state() - self._query_manager = None - self.health = None self.check_initializations.clear() # TODO: move diagnosis cleanup into AgentCheck.cancel() in the base class self._diagnosis = None + self.log.check = None + self._query_manager = None + self.health = None self._close_db() self._close_db_pool() - # CheckLoggingAdapter holds self.check until check_id is resolved via - # process(), which only happens after the agent scheduler calls run(). - # If cancel() is called before that, the back-reference is never cleared. - self.log.check = None + self.log.debug("Check cleanup complete") def _clean_state(self): self.log.debug("Cleaning state") @@ -1191,14 +1244,15 @@ def check(self, _): if not self._config.only_custom_queries: self._collect_stats(tags) - if self._config.dbm: - self.statement_metrics.run_job_loop(tags) - self.statement_samples.run_job_loop(tags) - self.metadata_samples.run_job_loop(tags) - elif self._config.data_observability.enabled: - self.metadata_samples.run_job_loop(tags) - if self._config.data_observability.enabled: - self.data_observability.run_job_loop(tags) + if not self._cancelled: + if self._config.dbm: + self.statement_metrics.run_job_loop(tags) + self.statement_samples.run_job_loop(tags) + self.metadata_samples.run_job_loop(tags) + elif self._config.data_observability.enabled: + self.metadata_samples.run_job_loop(tags) + if self._config.data_observability.enabled: + self.data_observability.run_job_loop(tags) if self._config.collect_wal_metrics is True: # collect wal metrics for pg < 10 only when explicitly enabled # (requires local filesystem access to the WAL directory) diff --git a/postgres/tests/test_unit.py b/postgres/tests/test_unit.py index 7efdb71479139..5505131e71ce1 100644 --- a/postgres/tests/test_unit.py +++ b/postgres/tests/test_unit.py @@ -443,6 +443,80 @@ def test_check_gc_after_cancel(pg_instance): gc.enable() +def test_cancel_during_running_check_defers_finalize(pg_instance): + """Verify that cancel() during an in-flight check() does not close connections. + + Destructive cleanup (_finalize) must be deferred until run() completes so + that check() never accesses a closed psycopg connection, which would cause + a SIGSEGV in libpq. + """ + import threading + + check = PostgreSql('postgres', {}, [pg_instance]) + conn = mock.MagicMock() + check._db = conn + + check_started = threading.Event() + cancel_done = threading.Event() + + def slow_run(self_arg): + check_started.set() + cancel_done.wait(timeout=5) + return '' + + run_result = [None] + + def run_check(): + with mock.patch.object(type(check).__mro__[1], 'run', slow_run): + run_result[0] = check.run() + + run_thread = threading.Thread(target=run_check) + run_thread.start() + + check_started.wait(timeout=5) + + check.cancel() + # cancel() should have signaled but NOT finalized since run() is in-flight + assert not conn.close.called, "_close_db() ran while check() was still executing" + assert check._cancelled is True + + cancel_done.set() + run_thread.join(timeout=5) + + # After run() completes, _finalize() should have been called + conn.close.assert_called_once() + assert check._db is None + assert check._query_manager is None + assert check.health is None + + +def test_cancel_on_idle_check_finalizes_immediately(pg_instance): + """Verify that cancel() on an idle check runs _finalize() inline.""" + check = PostgreSql('postgres', {}, [pg_instance]) + conn = mock.MagicMock() + check._db = conn + + assert not check._is_running + + check.cancel() + + conn.close.assert_called_once() + assert check._db is None + assert check._query_manager is None + assert check.health is None + + +def test_run_after_cancel_returns_immediately(pg_instance): + """Verify that run() returns '' without executing check() if already cancelled.""" + check = PostgreSql('postgres', {}, [pg_instance]) + check.cancel() + + with mock.patch.object(check, 'check', side_effect=AssertionError("check() should not be called")): + result = check.run() + + assert result == '' + + def test_collect_column_statistics_updates_timestamp_on_failure(pg_instance): pg_instance['dbm'] = True pg_instance['collect_column_statistics'] = {'enabled': True, 'collection_interval': 60}