From 082f5554eaf8ea1668992e00450598561d6de981 Mon Sep 17 00:00:00 2001 From: spatel033 <177532680+spatel033@users.noreply.github.com> Date: Thu, 4 Dec 2025 12:33:38 +0000 Subject: [PATCH 1/2] enhance QueryRecorder to track which database bind is used for each query Replace flask_sqlalchemy's get_recorded_queries with SQLAlchemy event listeners that capture the bind_key for each engine. Add QueryInfo dataclass with statement, parameters, and bind_key fields. This enables tests to verify that queries are being routed to the correct database when using db.session_bulk. This is an alternative to the existing PR https://github.com/pallets-eco/flask-sqlalchemy/pull/1403 for flask-sqlalchemy. --- tests/utils.py | 42 +++++++++++++++++++++++++++++++++++++----- 1 file changed, 37 insertions(+), 5 deletions(-) diff --git a/tests/utils.py b/tests/utils.py index 914f274001..14c4de8e54 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -1,14 +1,46 @@ -import flask_sqlalchemy +from dataclasses import dataclass + +from sqlalchemy import event + +from app import db + + +@dataclass +class QueryInfo: + statement: str + parameters: tuple | dict | None + bind_key: str | None class QueryRecorder: def __init__(self): - self.queries = [] - self._count_on_enter = None + self.queries: list[QueryInfo] = [] + self._listeners = [] def __enter__(self): - self._count_on_enter = len(flask_sqlalchemy.record_queries.get_recorded_queries()) + # Register listeners for all engines to capture bind_key + for bind_key, engine in db.engines.items(): + listener = self._listener(bind_key) + event.listen(engine, "before_cursor_execute", listener) + self._listeners.append((engine, listener)) return self def __exit__(self, exc_type, exc_val, exc_tb): - self.queries = flask_sqlalchemy.record_queries.get_recorded_queries()[self._count_on_enter :] + # Remove all listeners + for engine, listener in self._listeners: + event.remove(engine, "before_cursor_execute", listener) + self._listeners.clear() + + def _listener(self, bind_key): + """Create a listener function that captures the bind_key in its closure.""" + + def listener(conn, cursor, statement, parameters, context, executemany): + self.queries.append( + QueryInfo( + statement=statement, + parameters=parameters, + bind_key=bind_key, + ) + ) + + return listener From af6d5e8386cbe1558b94f21ede089c272c79d9db Mon Sep 17 00:00:00 2001 From: spatel033 <177532680+spatel033@users.noreply.github.com> Date: Fri, 5 Dec 2025 14:28:48 +0000 Subject: [PATCH 2/2] update test_fact_billing_dao to assert bind_key --- tests/app/dao/test_fact_billing_dao.py | 227 +++++++++++++++---------- 1 file changed, 139 insertions(+), 88 deletions(-) diff --git a/tests/app/dao/test_fact_billing_dao.py b/tests/app/dao/test_fact_billing_dao.py index af444394a8..515ba33fe3 100644 --- a/tests/app/dao/test_fact_billing_dao.py +++ b/tests/app/dao/test_fact_billing_dao.py @@ -42,6 +42,7 @@ create_template, set_up_usage_data, ) +from tests.utils import QueryRecorder @pytest.fixture @@ -100,36 +101,42 @@ def sample_service_billing_fy_2018_variable_rates(sample_service): @pytest.mark.parametrize( - "session", + "session,expected_bind_key", ( - db.session, - db.session_bulk, + (db.session, None), + (db.session_bulk, "bulk"), ), ids=("default", "bulk"), ) -def test_fetch_billing_data_for_today_includes_data_with_the_right_key_type(notify_db_session, session): +def test_fetch_billing_data_for_today_includes_data_with_the_right_key_type( + notify_db_session, session, expected_bind_key +): service = create_service() template = create_template(service=service, template_type="email") for key_type in ["normal", "test", "team"]: create_notification(template=template, status="delivered", key_type=key_type) today = convert_utc_to_bst(datetime.utcnow()) - results = fetch_billing_data_for_day(today.date(), session=session) + with QueryRecorder() as query_recorder: + results = fetch_billing_data_for_day(today.date(), session=session) + assert {query_info.bind_key for query_info in query_recorder.queries} == {expected_bind_key} assert len(results) == 1 assert results[0].notifications_sent == 2 @pytest.mark.parametrize( - "session", + "session,expected_bind_key", ( - db.session, - db.session_bulk, + (db.session, None), + (db.session_bulk, "bulk"), ), ids=("default", "bulk"), ) @pytest.mark.parametrize("notification_type", ["email", "sms", "letter"]) -def test_fetch_billing_data_for_day_only_calls_query_for_permission_type(notify_db_session, notification_type, session): +def test_fetch_billing_data_for_day_only_calls_query_for_permission_type( + notify_db_session, notification_type, session, expected_bind_key +): service = create_service(service_permissions=[notification_type]) email_template = create_template(service=service, template_type="email") sms_template = create_template(service=service, template_type="sms") @@ -138,21 +145,25 @@ def test_fetch_billing_data_for_day_only_calls_query_for_permission_type(notify_ create_notification(template=sms_template, status="delivered") create_notification(template=letter_template, status="delivered") today = convert_utc_to_bst(datetime.utcnow()) - results = fetch_billing_data_for_day(process_day=today.date(), check_permissions=True, session=session) + with QueryRecorder() as query_recorder: + results = fetch_billing_data_for_day(process_day=today.date(), check_permissions=True, session=session) + assert {query_info.bind_key for query_info in query_recorder.queries} == {expected_bind_key} assert len(results) == 1 @pytest.mark.parametrize( - "session", + "session,expected_bind_key", ( - db.session, - db.session_bulk, + (db.session, None), + (db.session_bulk, "bulk"), ), ids=("default", "bulk"), ) @pytest.mark.parametrize("notification_type", ["email", "sms", "letter"]) -def test_fetch_billing_data_for_day_only_calls_query_for_all_channels(notify_db_session, notification_type, session): +def test_fetch_billing_data_for_day_only_calls_query_for_all_channels( + notify_db_session, notification_type, session, expected_bind_key +): service = create_service(service_permissions=[notification_type]) email_template = create_template(service=service, template_type="email") sms_template = create_template(service=service, template_type="sms") @@ -161,21 +172,23 @@ def test_fetch_billing_data_for_day_only_calls_query_for_all_channels(notify_db_ create_notification(template=sms_template, status="delivered") create_notification(template=letter_template, status="delivered") today = convert_utc_to_bst(datetime.utcnow()) - results = fetch_billing_data_for_day(process_day=today.date(), check_permissions=False, session=session) + with QueryRecorder() as query_recorder: + results = fetch_billing_data_for_day(process_day=today.date(), check_permissions=False, session=session) + assert {query_info.bind_key for query_info in query_recorder.queries} == {expected_bind_key} assert len(results) == 3 @pytest.mark.parametrize( - "session", + "session,expected_bind_key", ( - db.session, - db.session_bulk, + (db.session, None), + (db.session_bulk, "bulk"), ), ids=("default", "bulk"), ) @freeze_time("2018-04-02 01:20:00") -def test_fetch_billing_data_for_today_includes_data_with_the_right_date(notify_db_session, session): +def test_fetch_billing_data_for_today_includes_data_with_the_right_date(notify_db_session, session, expected_bind_key): process_day = datetime(2018, 4, 1, 13, 30, 0) service = create_service() template = create_template(service=service, template_type="email") @@ -187,20 +200,25 @@ def test_fetch_billing_data_for_today_includes_data_with_the_right_date(notify_d day_under_test = convert_utc_to_bst(process_day) results = fetch_billing_data_for_day(day_under_test.date(), session=session) + with QueryRecorder() as query_recorder: + results = fetch_billing_data_for_day(day_under_test.date(), session=session) + assert {query_info.bind_key for query_info in query_recorder.queries} == {expected_bind_key} assert len(results) == 1 assert results[0].notifications_sent == 2 @pytest.mark.parametrize( - "session", + "session,expected_bind_key", ( - db.session, - db.session_bulk, + (db.session, None), + (db.session_bulk, "bulk"), ), ids=("default", "bulk"), ) -def test_fetch_billing_data_for_day_is_grouped_by_template_and_notification_type(notify_db_session, session): +def test_fetch_billing_data_for_day_is_grouped_by_template_and_notification_type( + notify_db_session, session, expected_bind_key +): service = create_service() email_template = create_template(service=service, template_type="email") sms_template = create_template(service=service, template_type="sms") @@ -208,22 +226,24 @@ def test_fetch_billing_data_for_day_is_grouped_by_template_and_notification_type create_notification(template=sms_template, status="delivered") today = convert_utc_to_bst(datetime.utcnow()) - results = fetch_billing_data_for_day(today.date(), session=session) + with QueryRecorder() as query_recorder: + results = fetch_billing_data_for_day(today.date(), session=session) + assert {query_info.bind_key for query_info in query_recorder.queries} == {expected_bind_key} assert len(results) == 2 assert results[0].notifications_sent == 1 assert results[1].notifications_sent == 1 @pytest.mark.parametrize( - "session", + "session,expected_bind_key", ( - db.session, - db.session_bulk, + (db.session, None), + (db.session_bulk, "bulk"), ), ids=("default", "bulk"), ) -def test_fetch_billing_data_for_day_is_grouped_by_service(notify_db_session, session): +def test_fetch_billing_data_for_day_is_grouped_by_service(notify_db_session, session, expected_bind_key): service_1 = create_service() service_2 = create_service(service_name="Service 2") email_template = create_template(service=service_1) @@ -233,65 +253,72 @@ def test_fetch_billing_data_for_day_is_grouped_by_service(notify_db_session, ses today = convert_utc_to_bst(datetime.utcnow()) results = fetch_billing_data_for_day(today.date(), session=session) + with QueryRecorder() as query_recorder: + results = fetch_billing_data_for_day(today.date(), session=session) + assert {query_info.bind_key for query_info in query_recorder.queries} == {expected_bind_key} assert len(results) == 2 assert results[0].notifications_sent == 1 assert results[1].notifications_sent == 1 @pytest.mark.parametrize( - "session", + "session,expected_bind_key", ( - db.session, - db.session_bulk, + (db.session, None), + (db.session_bulk, "bulk"), ), ids=("default", "bulk"), ) -def test_fetch_billing_data_for_day_is_grouped_by_provider(notify_db_session, session): +def test_fetch_billing_data_for_day_is_grouped_by_provider(notify_db_session, session, expected_bind_key): service = create_service() template = create_template(service=service) create_notification(template=template, status="delivered", sent_by="mmg") create_notification(template=template, status="delivered", sent_by="firetext") today = convert_utc_to_bst(datetime.utcnow()) - results = fetch_billing_data_for_day(today.date(), session=session) + with QueryRecorder() as query_recorder: + results = fetch_billing_data_for_day(today.date(), session=session) + assert {query_info.bind_key for query_info in query_recorder.queries} == {expected_bind_key} assert len(results) == 2 assert results[0].notifications_sent == 1 assert results[1].notifications_sent == 1 @pytest.mark.parametrize( - "session", + "session,expected_bind_key", ( - db.session, - db.session_bulk, + (db.session, None), + (db.session_bulk, "bulk"), ), ids=("default", "bulk"), ) -def test_fetch_billing_data_for_day_is_grouped_by_rate_mulitplier(notify_db_session, session): +def test_fetch_billing_data_for_day_is_grouped_by_rate_mulitplier(notify_db_session, session, expected_bind_key): service = create_service() template = create_template(service=service) create_notification(template=template, status="delivered", rate_multiplier=1) create_notification(template=template, status="delivered", rate_multiplier=2) today = convert_utc_to_bst(datetime.utcnow()) - results = fetch_billing_data_for_day(today.date(), session=session) + with QueryRecorder() as query_recorder: + results = fetch_billing_data_for_day(today.date(), session=session) + assert {query_info.bind_key for query_info in query_recorder.queries} == {expected_bind_key} assert len(results) == 2 assert results[0].notifications_sent == 1 assert results[1].notifications_sent == 1 @pytest.mark.parametrize( - "session", + "session,expected_bind_key", ( - db.session, - db.session_bulk, + (db.session, None), + (db.session_bulk, "bulk"), ), ids=("default", "bulk"), ) -def test_fetch_billing_data_for_day_is_grouped_by_international(notify_db_session, session): +def test_fetch_billing_data_for_day_is_grouped_by_international(notify_db_session, session, expected_bind_key): service = create_service() sms_template = create_template(service=service) letter_template = create_template(template_type="letter", service=service) @@ -301,21 +328,23 @@ def test_fetch_billing_data_for_day_is_grouped_by_international(notify_db_sessio create_notification(template=letter_template, status="delivered", international=False) today = convert_utc_to_bst(datetime.utcnow()) - results = fetch_billing_data_for_day(today.date(), session=session) + with QueryRecorder() as query_recorder: + results = fetch_billing_data_for_day(today.date(), session=session) + assert {query_info.bind_key for query_info in query_recorder.queries} == {expected_bind_key} assert len(results) == 4 assert all(result.notifications_sent == 1 for result in results) @pytest.mark.parametrize( - "session", + "session,expected_bind_key", ( - db.session, - db.session_bulk, + (db.session, None), + (db.session_bulk, "bulk"), ), ids=("default", "bulk"), ) -def test_fetch_billing_data_for_day_is_grouped_by_notification_type(notify_db_session, session): +def test_fetch_billing_data_for_day_is_grouped_by_notification_type(notify_db_session, session, expected_bind_key): service = create_service() sms_template = create_template(service=service, template_type="sms") email_template = create_template(service=service, template_type="email") @@ -328,22 +357,24 @@ def test_fetch_billing_data_for_day_is_grouped_by_notification_type(notify_db_se create_notification(template=letter_template, status="delivered") today = convert_utc_to_bst(datetime.utcnow()) - results = fetch_billing_data_for_day(today.date(), session=session) + with QueryRecorder() as query_recorder: + results = fetch_billing_data_for_day(today.date(), session=session) + assert {query_info.bind_key for query_info in query_recorder.queries} == {expected_bind_key} assert len(results) == 3 notification_types = [x.notification_type for x in results] assert len(notification_types) == 3 @pytest.mark.parametrize( - "session", + "session,expected_bind_key", ( - db.session, - db.session_bulk, + (db.session, None), + (db.session_bulk, "bulk"), ), ids=("default", "bulk"), ) -def test_fetch_billing_data_for_day_groups_by_postage(notify_db_session, session): +def test_fetch_billing_data_for_day_groups_by_postage(notify_db_session, session, expected_bind_key): service = create_service() letter_template = create_template(service=service, template_type="letter") email_template = create_template(service=service, template_type="email") @@ -356,20 +387,22 @@ def test_fetch_billing_data_for_day_groups_by_postage(notify_db_session, session create_notification(template=email_template, status="delivered") today = convert_utc_to_bst(datetime.utcnow()) - results = fetch_billing_data_for_day(today.date(), session=session) + with QueryRecorder() as query_recorder: + results = fetch_billing_data_for_day(today.date(), session=session) + assert {query_info.bind_key for query_info in query_recorder.queries} == {expected_bind_key} assert len(results) == 6 @pytest.mark.parametrize( - "session", + "session,expected_bind_key", ( - db.session, - db.session_bulk, + (db.session, None), + (db.session_bulk, "bulk"), ), ids=("default", "bulk"), ) -def test_fetch_billing_data_for_day_groups_by_sent_by(notify_db_session, session): +def test_fetch_billing_data_for_day_groups_by_sent_by(notify_db_session, session, expected_bind_key): service = create_service() letter_template = create_template(service=service, template_type="letter") email_template = create_template(service=service, template_type="email") @@ -379,20 +412,22 @@ def test_fetch_billing_data_for_day_groups_by_sent_by(notify_db_session, session create_notification(template=email_template, status="delivered") today = convert_utc_to_bst(datetime.utcnow()) - results = fetch_billing_data_for_day(today.date(), session=session) + with QueryRecorder() as query_recorder: + results = fetch_billing_data_for_day(today.date(), session=session) + assert {query_info.bind_key for query_info in query_recorder.queries} == {expected_bind_key} assert len(results) == 2 @pytest.mark.parametrize( - "session", + "session,expected_bind_key", ( - db.session, - db.session_bulk, + (db.session, None), + (db.session_bulk, "bulk"), ), ids=("default", "bulk"), ) -def test_fetch_billing_data_for_day_groups_by_page_count(notify_db_session, session): +def test_fetch_billing_data_for_day_groups_by_page_count(notify_db_session, session, expected_bind_key): service = create_service() letter_template = create_template(service=service, template_type="letter") email_template = create_template(service=service, template_type="email") @@ -402,20 +437,24 @@ def test_fetch_billing_data_for_day_groups_by_page_count(notify_db_session, sess create_notification(template=email_template, status="delivered") today = convert_utc_to_bst(datetime.utcnow()) - results = fetch_billing_data_for_day(today.date(), session=session) + with QueryRecorder() as query_recorder: + results = fetch_billing_data_for_day(today.date(), session=session) + assert {query_info.bind_key for query_info in query_recorder.queries} == {expected_bind_key} assert len(results) == 3 @pytest.mark.parametrize( - "session", + "session,expected_bind_key", ( - db.session, - db.session_bulk, + (db.session, None), + (db.session_bulk, "bulk"), ), ids=("default", "bulk"), ) -def test_fetch_billing_data_for_day_sets_postage_for_emails_and_sms_to_none(notify_db_session, session): +def test_fetch_billing_data_for_day_sets_postage_for_emails_and_sms_to_none( + notify_db_session, session, expected_bind_key +): service = create_service() sms_template = create_template(service=service, template_type="sms") email_template = create_template(service=service, template_type="email") @@ -423,37 +462,41 @@ def test_fetch_billing_data_for_day_sets_postage_for_emails_and_sms_to_none(noti create_notification(template=email_template, status="delivered") today = convert_utc_to_bst(datetime.utcnow()) - results = fetch_billing_data_for_day(today.date(), session=session) + with QueryRecorder() as query_recorder: + results = fetch_billing_data_for_day(today.date(), session=session) + assert {query_info.bind_key for query_info in query_recorder.queries} == {expected_bind_key} assert len(results) == 2 assert results[0].postage == "none" assert results[1].postage == "none" @pytest.mark.parametrize( - "session", + "session,expected_bind_key", ( - db.session, - db.session_bulk, + (db.session, None), + (db.session_bulk, "bulk"), ), ids=("default", "bulk"), ) -def test_fetch_billing_data_for_day_returns_empty_list(notify_db_session, session): +def test_fetch_billing_data_for_day_returns_empty_list(notify_db_session, session, expected_bind_key): today = convert_utc_to_bst(datetime.utcnow()) - results = fetch_billing_data_for_day(today.date(), session=session) + with QueryRecorder() as query_recorder: + results = fetch_billing_data_for_day(today.date(), session=session) + assert {query_info.bind_key for query_info in query_recorder.queries} == {expected_bind_key} assert results == [] @pytest.mark.parametrize( - "session", + "session,expected_bind_key", ( - db.session, - db.session_bulk, + (db.session, None), + (db.session_bulk, "bulk"), ), ids=("default", "bulk"), ) -def test_fetch_billing_data_for_day_uses_correct_table(notify_db_session, session): +def test_fetch_billing_data_for_day_uses_correct_table(notify_db_session, session, expected_bind_key): service = create_service() create_service_data_retention(service, notification_type="email", days_of_retention=3) sms_template = create_template(service=service, template_type="sms") @@ -464,8 +507,12 @@ def test_fetch_billing_data_for_day_uses_correct_table(notify_db_session, sessio create_notification_history(template=email_template, status="delivered", created_at=five_days_ago) service_id = service.id - results = fetch_billing_data_for_day(process_day=five_days_ago.date(), service_ids=[service_id], session=session) + with QueryRecorder() as query_recorder: + results = fetch_billing_data_for_day( + process_day=five_days_ago.date(), service_ids=[service_id], session=session + ) + assert {query_info.bind_key for query_info in query_recorder.queries} == {expected_bind_key} assert len(results) == 2 assert results[0].notification_type == "sms" assert results[0].notifications_sent == 1 @@ -474,14 +521,14 @@ def test_fetch_billing_data_for_day_uses_correct_table(notify_db_session, sessio @pytest.mark.parametrize( - "session", + "session,expected_bind_key", ( - db.session, - db.session_bulk, + (db.session, None), + (db.session_bulk, "bulk"), ), ids=("default", "bulk"), ) -def test_fetch_billing_data_for_day_returns_list_for_given_service(notify_db_session, session): +def test_fetch_billing_data_for_day_returns_list_for_given_service(notify_db_session, session, expected_bind_key): service = create_service() service_2 = create_service(service_name="Service 2") template = create_template(service=service) @@ -491,21 +538,23 @@ def test_fetch_billing_data_for_day_returns_list_for_given_service(notify_db_ses service_id = service.id today = convert_utc_to_bst(datetime.utcnow()) - results = fetch_billing_data_for_day(process_day=today.date(), service_ids=[service_id], session=session) + with QueryRecorder() as query_recorder: + results = fetch_billing_data_for_day(process_day=today.date(), service_ids=[service_id], session=session) + assert {query_info.bind_key for query_info in query_recorder.queries} == {expected_bind_key} assert len(results) == 1 assert results[0].service_id == service_id @pytest.mark.parametrize( - "session", + "session,expected_bind_key", ( - db.session, - db.session_bulk, + (db.session, None), + (db.session_bulk, "bulk"), ), ids=("default", "bulk"), ) -def test_fetch_billing_data_for_day_bills_correctly_for_status(notify_db_session, session): +def test_fetch_billing_data_for_day_bills_correctly_for_status(notify_db_session, session, expected_bind_key): service = create_service() sms_template = create_template(service=service, template_type="sms") email_template = create_template(service=service, template_type="email") @@ -518,8 +567,10 @@ def test_fetch_billing_data_for_day_bills_correctly_for_status(notify_db_session service_id = service.id today = convert_utc_to_bst(datetime.utcnow()) - results = fetch_billing_data_for_day(process_day=today.date(), service_ids=[service_id], session=session) + with QueryRecorder() as query_recorder: + results = fetch_billing_data_for_day(process_day=today.date(), service_ids=[service_id], session=session) + assert {query_info.bind_key for query_info in query_recorder.queries} == {expected_bind_key} sms_results = [x for x in results if x.notification_type == "sms"] email_results = [x for x in results if x.notification_type == "email"] letter_results = [x for x in results if x.notification_type == "letter"]