diff --git a/tests/app/dao/test_fact_billing_dao.py b/tests/app/dao/test_fact_billing_dao.py index af444394a8..515ba33fe3 100644 --- a/tests/app/dao/test_fact_billing_dao.py +++ b/tests/app/dao/test_fact_billing_dao.py @@ -42,6 +42,7 @@ create_template, set_up_usage_data, ) +from tests.utils import QueryRecorder @pytest.fixture @@ -100,36 +101,42 @@ def sample_service_billing_fy_2018_variable_rates(sample_service): @pytest.mark.parametrize( - "session", + "session,expected_bind_key", ( - db.session, - db.session_bulk, + (db.session, None), + (db.session_bulk, "bulk"), ), ids=("default", "bulk"), ) -def test_fetch_billing_data_for_today_includes_data_with_the_right_key_type(notify_db_session, session): +def test_fetch_billing_data_for_today_includes_data_with_the_right_key_type( + notify_db_session, session, expected_bind_key +): service = create_service() template = create_template(service=service, template_type="email") for key_type in ["normal", "test", "team"]: create_notification(template=template, status="delivered", key_type=key_type) today = convert_utc_to_bst(datetime.utcnow()) - results = fetch_billing_data_for_day(today.date(), session=session) + with QueryRecorder() as query_recorder: + results = fetch_billing_data_for_day(today.date(), session=session) + assert {query_info.bind_key for query_info in query_recorder.queries} == {expected_bind_key} assert len(results) == 1 assert results[0].notifications_sent == 2 @pytest.mark.parametrize( - "session", + "session,expected_bind_key", ( - db.session, - db.session_bulk, + (db.session, None), + (db.session_bulk, "bulk"), ), ids=("default", "bulk"), ) @pytest.mark.parametrize("notification_type", ["email", "sms", "letter"]) -def test_fetch_billing_data_for_day_only_calls_query_for_permission_type(notify_db_session, notification_type, session): +def test_fetch_billing_data_for_day_only_calls_query_for_permission_type( + notify_db_session, notification_type, session, expected_bind_key +): service = create_service(service_permissions=[notification_type]) email_template = create_template(service=service, template_type="email") sms_template = create_template(service=service, template_type="sms") @@ -138,21 +145,25 @@ def test_fetch_billing_data_for_day_only_calls_query_for_permission_type(notify_ create_notification(template=sms_template, status="delivered") create_notification(template=letter_template, status="delivered") today = convert_utc_to_bst(datetime.utcnow()) - results = fetch_billing_data_for_day(process_day=today.date(), check_permissions=True, session=session) + with QueryRecorder() as query_recorder: + results = fetch_billing_data_for_day(process_day=today.date(), check_permissions=True, session=session) + assert {query_info.bind_key for query_info in query_recorder.queries} == {expected_bind_key} assert len(results) == 1 @pytest.mark.parametrize( - "session", + "session,expected_bind_key", ( - db.session, - db.session_bulk, + (db.session, None), + (db.session_bulk, "bulk"), ), ids=("default", "bulk"), ) @pytest.mark.parametrize("notification_type", ["email", "sms", "letter"]) -def test_fetch_billing_data_for_day_only_calls_query_for_all_channels(notify_db_session, notification_type, session): +def test_fetch_billing_data_for_day_only_calls_query_for_all_channels( + notify_db_session, notification_type, session, expected_bind_key +): service = create_service(service_permissions=[notification_type]) email_template = create_template(service=service, template_type="email") sms_template = create_template(service=service, template_type="sms") @@ -161,21 +172,23 @@ def test_fetch_billing_data_for_day_only_calls_query_for_all_channels(notify_db_ create_notification(template=sms_template, status="delivered") create_notification(template=letter_template, status="delivered") today = convert_utc_to_bst(datetime.utcnow()) - results = fetch_billing_data_for_day(process_day=today.date(), check_permissions=False, session=session) + with QueryRecorder() as query_recorder: + results = fetch_billing_data_for_day(process_day=today.date(), check_permissions=False, session=session) + assert {query_info.bind_key for query_info in query_recorder.queries} == {expected_bind_key} assert len(results) == 3 @pytest.mark.parametrize( - "session", + "session,expected_bind_key", ( - db.session, - db.session_bulk, + (db.session, None), + (db.session_bulk, "bulk"), ), ids=("default", "bulk"), ) @freeze_time("2018-04-02 01:20:00") -def test_fetch_billing_data_for_today_includes_data_with_the_right_date(notify_db_session, session): +def test_fetch_billing_data_for_today_includes_data_with_the_right_date(notify_db_session, session, expected_bind_key): process_day = datetime(2018, 4, 1, 13, 30, 0) service = create_service() template = create_template(service=service, template_type="email") @@ -187,20 +200,25 @@ def test_fetch_billing_data_for_today_includes_data_with_the_right_date(notify_d day_under_test = convert_utc_to_bst(process_day) results = fetch_billing_data_for_day(day_under_test.date(), session=session) + with QueryRecorder() as query_recorder: + results = fetch_billing_data_for_day(day_under_test.date(), session=session) + assert {query_info.bind_key for query_info in query_recorder.queries} == {expected_bind_key} assert len(results) == 1 assert results[0].notifications_sent == 2 @pytest.mark.parametrize( - "session", + "session,expected_bind_key", ( - db.session, - db.session_bulk, + (db.session, None), + (db.session_bulk, "bulk"), ), ids=("default", "bulk"), ) -def test_fetch_billing_data_for_day_is_grouped_by_template_and_notification_type(notify_db_session, session): +def test_fetch_billing_data_for_day_is_grouped_by_template_and_notification_type( + notify_db_session, session, expected_bind_key +): service = create_service() email_template = create_template(service=service, template_type="email") sms_template = create_template(service=service, template_type="sms") @@ -208,22 +226,24 @@ def test_fetch_billing_data_for_day_is_grouped_by_template_and_notification_type create_notification(template=sms_template, status="delivered") today = convert_utc_to_bst(datetime.utcnow()) - results = fetch_billing_data_for_day(today.date(), session=session) + with QueryRecorder() as query_recorder: + results = fetch_billing_data_for_day(today.date(), session=session) + assert {query_info.bind_key for query_info in query_recorder.queries} == {expected_bind_key} assert len(results) == 2 assert results[0].notifications_sent == 1 assert results[1].notifications_sent == 1 @pytest.mark.parametrize( - "session", + "session,expected_bind_key", ( - db.session, - db.session_bulk, + (db.session, None), + (db.session_bulk, "bulk"), ), ids=("default", "bulk"), ) -def test_fetch_billing_data_for_day_is_grouped_by_service(notify_db_session, session): +def test_fetch_billing_data_for_day_is_grouped_by_service(notify_db_session, session, expected_bind_key): service_1 = create_service() service_2 = create_service(service_name="Service 2") email_template = create_template(service=service_1) @@ -233,65 +253,72 @@ def test_fetch_billing_data_for_day_is_grouped_by_service(notify_db_session, ses today = convert_utc_to_bst(datetime.utcnow()) results = fetch_billing_data_for_day(today.date(), session=session) + with QueryRecorder() as query_recorder: + results = fetch_billing_data_for_day(today.date(), session=session) + assert {query_info.bind_key for query_info in query_recorder.queries} == {expected_bind_key} assert len(results) == 2 assert results[0].notifications_sent == 1 assert results[1].notifications_sent == 1 @pytest.mark.parametrize( - "session", + "session,expected_bind_key", ( - db.session, - db.session_bulk, + (db.session, None), + (db.session_bulk, "bulk"), ), ids=("default", "bulk"), ) -def test_fetch_billing_data_for_day_is_grouped_by_provider(notify_db_session, session): +def test_fetch_billing_data_for_day_is_grouped_by_provider(notify_db_session, session, expected_bind_key): service = create_service() template = create_template(service=service) create_notification(template=template, status="delivered", sent_by="mmg") create_notification(template=template, status="delivered", sent_by="firetext") today = convert_utc_to_bst(datetime.utcnow()) - results = fetch_billing_data_for_day(today.date(), session=session) + with QueryRecorder() as query_recorder: + results = fetch_billing_data_for_day(today.date(), session=session) + assert {query_info.bind_key for query_info in query_recorder.queries} == {expected_bind_key} assert len(results) == 2 assert results[0].notifications_sent == 1 assert results[1].notifications_sent == 1 @pytest.mark.parametrize( - "session", + "session,expected_bind_key", ( - db.session, - db.session_bulk, + (db.session, None), + (db.session_bulk, "bulk"), ), ids=("default", "bulk"), ) -def test_fetch_billing_data_for_day_is_grouped_by_rate_mulitplier(notify_db_session, session): +def test_fetch_billing_data_for_day_is_grouped_by_rate_mulitplier(notify_db_session, session, expected_bind_key): service = create_service() template = create_template(service=service) create_notification(template=template, status="delivered", rate_multiplier=1) create_notification(template=template, status="delivered", rate_multiplier=2) today = convert_utc_to_bst(datetime.utcnow()) - results = fetch_billing_data_for_day(today.date(), session=session) + with QueryRecorder() as query_recorder: + results = fetch_billing_data_for_day(today.date(), session=session) + assert {query_info.bind_key for query_info in query_recorder.queries} == {expected_bind_key} assert len(results) == 2 assert results[0].notifications_sent == 1 assert results[1].notifications_sent == 1 @pytest.mark.parametrize( - "session", + "session,expected_bind_key", ( - db.session, - db.session_bulk, + (db.session, None), + (db.session_bulk, "bulk"), ), ids=("default", "bulk"), ) -def test_fetch_billing_data_for_day_is_grouped_by_international(notify_db_session, session): +def test_fetch_billing_data_for_day_is_grouped_by_international(notify_db_session, session, expected_bind_key): service = create_service() sms_template = create_template(service=service) letter_template = create_template(template_type="letter", service=service) @@ -301,21 +328,23 @@ def test_fetch_billing_data_for_day_is_grouped_by_international(notify_db_sessio create_notification(template=letter_template, status="delivered", international=False) today = convert_utc_to_bst(datetime.utcnow()) - results = fetch_billing_data_for_day(today.date(), session=session) + with QueryRecorder() as query_recorder: + results = fetch_billing_data_for_day(today.date(), session=session) + assert {query_info.bind_key for query_info in query_recorder.queries} == {expected_bind_key} assert len(results) == 4 assert all(result.notifications_sent == 1 for result in results) @pytest.mark.parametrize( - "session", + "session,expected_bind_key", ( - db.session, - db.session_bulk, + (db.session, None), + (db.session_bulk, "bulk"), ), ids=("default", "bulk"), ) -def test_fetch_billing_data_for_day_is_grouped_by_notification_type(notify_db_session, session): +def test_fetch_billing_data_for_day_is_grouped_by_notification_type(notify_db_session, session, expected_bind_key): service = create_service() sms_template = create_template(service=service, template_type="sms") email_template = create_template(service=service, template_type="email") @@ -328,22 +357,24 @@ def test_fetch_billing_data_for_day_is_grouped_by_notification_type(notify_db_se create_notification(template=letter_template, status="delivered") today = convert_utc_to_bst(datetime.utcnow()) - results = fetch_billing_data_for_day(today.date(), session=session) + with QueryRecorder() as query_recorder: + results = fetch_billing_data_for_day(today.date(), session=session) + assert {query_info.bind_key for query_info in query_recorder.queries} == {expected_bind_key} assert len(results) == 3 notification_types = [x.notification_type for x in results] assert len(notification_types) == 3 @pytest.mark.parametrize( - "session", + "session,expected_bind_key", ( - db.session, - db.session_bulk, + (db.session, None), + (db.session_bulk, "bulk"), ), ids=("default", "bulk"), ) -def test_fetch_billing_data_for_day_groups_by_postage(notify_db_session, session): +def test_fetch_billing_data_for_day_groups_by_postage(notify_db_session, session, expected_bind_key): service = create_service() letter_template = create_template(service=service, template_type="letter") email_template = create_template(service=service, template_type="email") @@ -356,20 +387,22 @@ def test_fetch_billing_data_for_day_groups_by_postage(notify_db_session, session create_notification(template=email_template, status="delivered") today = convert_utc_to_bst(datetime.utcnow()) - results = fetch_billing_data_for_day(today.date(), session=session) + with QueryRecorder() as query_recorder: + results = fetch_billing_data_for_day(today.date(), session=session) + assert {query_info.bind_key for query_info in query_recorder.queries} == {expected_bind_key} assert len(results) == 6 @pytest.mark.parametrize( - "session", + "session,expected_bind_key", ( - db.session, - db.session_bulk, + (db.session, None), + (db.session_bulk, "bulk"), ), ids=("default", "bulk"), ) -def test_fetch_billing_data_for_day_groups_by_sent_by(notify_db_session, session): +def test_fetch_billing_data_for_day_groups_by_sent_by(notify_db_session, session, expected_bind_key): service = create_service() letter_template = create_template(service=service, template_type="letter") email_template = create_template(service=service, template_type="email") @@ -379,20 +412,22 @@ def test_fetch_billing_data_for_day_groups_by_sent_by(notify_db_session, session create_notification(template=email_template, status="delivered") today = convert_utc_to_bst(datetime.utcnow()) - results = fetch_billing_data_for_day(today.date(), session=session) + with QueryRecorder() as query_recorder: + results = fetch_billing_data_for_day(today.date(), session=session) + assert {query_info.bind_key for query_info in query_recorder.queries} == {expected_bind_key} assert len(results) == 2 @pytest.mark.parametrize( - "session", + "session,expected_bind_key", ( - db.session, - db.session_bulk, + (db.session, None), + (db.session_bulk, "bulk"), ), ids=("default", "bulk"), ) -def test_fetch_billing_data_for_day_groups_by_page_count(notify_db_session, session): +def test_fetch_billing_data_for_day_groups_by_page_count(notify_db_session, session, expected_bind_key): service = create_service() letter_template = create_template(service=service, template_type="letter") email_template = create_template(service=service, template_type="email") @@ -402,20 +437,24 @@ def test_fetch_billing_data_for_day_groups_by_page_count(notify_db_session, sess create_notification(template=email_template, status="delivered") today = convert_utc_to_bst(datetime.utcnow()) - results = fetch_billing_data_for_day(today.date(), session=session) + with QueryRecorder() as query_recorder: + results = fetch_billing_data_for_day(today.date(), session=session) + assert {query_info.bind_key for query_info in query_recorder.queries} == {expected_bind_key} assert len(results) == 3 @pytest.mark.parametrize( - "session", + "session,expected_bind_key", ( - db.session, - db.session_bulk, + (db.session, None), + (db.session_bulk, "bulk"), ), ids=("default", "bulk"), ) -def test_fetch_billing_data_for_day_sets_postage_for_emails_and_sms_to_none(notify_db_session, session): +def test_fetch_billing_data_for_day_sets_postage_for_emails_and_sms_to_none( + notify_db_session, session, expected_bind_key +): service = create_service() sms_template = create_template(service=service, template_type="sms") email_template = create_template(service=service, template_type="email") @@ -423,37 +462,41 @@ def test_fetch_billing_data_for_day_sets_postage_for_emails_and_sms_to_none(noti create_notification(template=email_template, status="delivered") today = convert_utc_to_bst(datetime.utcnow()) - results = fetch_billing_data_for_day(today.date(), session=session) + with QueryRecorder() as query_recorder: + results = fetch_billing_data_for_day(today.date(), session=session) + assert {query_info.bind_key for query_info in query_recorder.queries} == {expected_bind_key} assert len(results) == 2 assert results[0].postage == "none" assert results[1].postage == "none" @pytest.mark.parametrize( - "session", + "session,expected_bind_key", ( - db.session, - db.session_bulk, + (db.session, None), + (db.session_bulk, "bulk"), ), ids=("default", "bulk"), ) -def test_fetch_billing_data_for_day_returns_empty_list(notify_db_session, session): +def test_fetch_billing_data_for_day_returns_empty_list(notify_db_session, session, expected_bind_key): today = convert_utc_to_bst(datetime.utcnow()) - results = fetch_billing_data_for_day(today.date(), session=session) + with QueryRecorder() as query_recorder: + results = fetch_billing_data_for_day(today.date(), session=session) + assert {query_info.bind_key for query_info in query_recorder.queries} == {expected_bind_key} assert results == [] @pytest.mark.parametrize( - "session", + "session,expected_bind_key", ( - db.session, - db.session_bulk, + (db.session, None), + (db.session_bulk, "bulk"), ), ids=("default", "bulk"), ) -def test_fetch_billing_data_for_day_uses_correct_table(notify_db_session, session): +def test_fetch_billing_data_for_day_uses_correct_table(notify_db_session, session, expected_bind_key): service = create_service() create_service_data_retention(service, notification_type="email", days_of_retention=3) sms_template = create_template(service=service, template_type="sms") @@ -464,8 +507,12 @@ def test_fetch_billing_data_for_day_uses_correct_table(notify_db_session, sessio create_notification_history(template=email_template, status="delivered", created_at=five_days_ago) service_id = service.id - results = fetch_billing_data_for_day(process_day=five_days_ago.date(), service_ids=[service_id], session=session) + with QueryRecorder() as query_recorder: + results = fetch_billing_data_for_day( + process_day=five_days_ago.date(), service_ids=[service_id], session=session + ) + assert {query_info.bind_key for query_info in query_recorder.queries} == {expected_bind_key} assert len(results) == 2 assert results[0].notification_type == "sms" assert results[0].notifications_sent == 1 @@ -474,14 +521,14 @@ def test_fetch_billing_data_for_day_uses_correct_table(notify_db_session, sessio @pytest.mark.parametrize( - "session", + "session,expected_bind_key", ( - db.session, - db.session_bulk, + (db.session, None), + (db.session_bulk, "bulk"), ), ids=("default", "bulk"), ) -def test_fetch_billing_data_for_day_returns_list_for_given_service(notify_db_session, session): +def test_fetch_billing_data_for_day_returns_list_for_given_service(notify_db_session, session, expected_bind_key): service = create_service() service_2 = create_service(service_name="Service 2") template = create_template(service=service) @@ -491,21 +538,23 @@ def test_fetch_billing_data_for_day_returns_list_for_given_service(notify_db_ses service_id = service.id today = convert_utc_to_bst(datetime.utcnow()) - results = fetch_billing_data_for_day(process_day=today.date(), service_ids=[service_id], session=session) + with QueryRecorder() as query_recorder: + results = fetch_billing_data_for_day(process_day=today.date(), service_ids=[service_id], session=session) + assert {query_info.bind_key for query_info in query_recorder.queries} == {expected_bind_key} assert len(results) == 1 assert results[0].service_id == service_id @pytest.mark.parametrize( - "session", + "session,expected_bind_key", ( - db.session, - db.session_bulk, + (db.session, None), + (db.session_bulk, "bulk"), ), ids=("default", "bulk"), ) -def test_fetch_billing_data_for_day_bills_correctly_for_status(notify_db_session, session): +def test_fetch_billing_data_for_day_bills_correctly_for_status(notify_db_session, session, expected_bind_key): service = create_service() sms_template = create_template(service=service, template_type="sms") email_template = create_template(service=service, template_type="email") @@ -518,8 +567,10 @@ def test_fetch_billing_data_for_day_bills_correctly_for_status(notify_db_session service_id = service.id today = convert_utc_to_bst(datetime.utcnow()) - results = fetch_billing_data_for_day(process_day=today.date(), service_ids=[service_id], session=session) + with QueryRecorder() as query_recorder: + results = fetch_billing_data_for_day(process_day=today.date(), service_ids=[service_id], session=session) + assert {query_info.bind_key for query_info in query_recorder.queries} == {expected_bind_key} sms_results = [x for x in results if x.notification_type == "sms"] email_results = [x for x in results if x.notification_type == "email"] letter_results = [x for x in results if x.notification_type == "letter"] diff --git a/tests/utils.py b/tests/utils.py index 914f274001..14c4de8e54 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -1,14 +1,46 @@ -import flask_sqlalchemy +from dataclasses import dataclass + +from sqlalchemy import event + +from app import db + + +@dataclass +class QueryInfo: + statement: str + parameters: tuple | dict | None + bind_key: str | None class QueryRecorder: def __init__(self): - self.queries = [] - self._count_on_enter = None + self.queries: list[QueryInfo] = [] + self._listeners = [] def __enter__(self): - self._count_on_enter = len(flask_sqlalchemy.record_queries.get_recorded_queries()) + # Register listeners for all engines to capture bind_key + for bind_key, engine in db.engines.items(): + listener = self._listener(bind_key) + event.listen(engine, "before_cursor_execute", listener) + self._listeners.append((engine, listener)) return self def __exit__(self, exc_type, exc_val, exc_tb): - self.queries = flask_sqlalchemy.record_queries.get_recorded_queries()[self._count_on_enter :] + # Remove all listeners + for engine, listener in self._listeners: + event.remove(engine, "before_cursor_execute", listener) + self._listeners.clear() + + def _listener(self, bind_key): + """Create a listener function that captures the bind_key in its closure.""" + + def listener(conn, cursor, statement, parameters, context, executemany): + self.queries.append( + QueryInfo( + statement=statement, + parameters=parameters, + bind_key=bind_key, + ) + ) + + return listener