diff --git a/functions-python/tasks_executor/src/tasks/dataset_files/README.md b/functions-python/tasks_executor/src/tasks/dataset_files/README.md index 95c0aafc9..9402b8e52 100644 --- a/functions-python/tasks_executor/src/tasks/dataset_files/README.md +++ b/functions-python/tasks_executor/src/tasks/dataset_files/README.md @@ -19,7 +19,8 @@ The function accepts the following payload: { "dry_run": true, // [optional] If true, do not upload or modify the database (default: true) "after_date": "YYYY-MM-DD", // [optional] Only include datasets downloaded after this ISO date - "latest_only": true // [optional] If true, only process the latest version of each dataset (default: true) + "latest_only": true, // [optional] If true, only process the latest version of each dataset (default: true) + "dataset_id": id // [optional] If provided, only process the specified dataset. It will supersede the after_date and latest_only parameters. } ``` @@ -32,7 +33,13 @@ The function accepts the following payload: "latest_only": true } ``` - +or +```json +{ + "dry_run": false, + "dataset_id": "mdb-1147-202407031702" +} +``` --- ## What It Does @@ -52,6 +59,9 @@ For each GTFS dataset with missing file information (missing zipped/unzipped siz 7. Computes SHA256 hashes for each file 8. Stores metadata in the `Gtfsfile` table for later use +If the `dataset_id` parameter is provided, the process is a bit simplified. It does not download the dataset as it is +assumed the dataset is already present in the bucket. The rest of the processing is the same. + --- ## GCP Environment Variables diff --git a/functions-python/tasks_executor/src/tasks/dataset_files/rebuild_missing_dataset_files.py b/functions-python/tasks_executor/src/tasks/dataset_files/rebuild_missing_dataset_files.py index 00c1b07db..33bc656ee 100644 --- a/functions-python/tasks_executor/src/tasks/dataset_files/rebuild_missing_dataset_files.py +++ b/functions-python/tasks_executor/src/tasks/dataset_files/rebuild_missing_dataset_files.py @@ -23,9 +23,13 @@ def rebuild_missing_dataset_files_handler(payload) -> dict: dry_run = payload.get("dry_run", True) after_date = payload.get("after_date", None) latest_only = payload.get("latest_only", True) + dataset_id = payload.get("dataset_id", None) return rebuild_missing_dataset_files( - dry_run=dry_run, after_date=after_date, latest_only=latest_only + dry_run=dry_run, + after_date=after_date, + latest_only=latest_only, + dataset_id=dataset_id, ) @@ -67,6 +71,7 @@ def rebuild_missing_dataset_files( dry_run: bool = True, after_date: str = None, latest_only: bool = True, + dataset_id: str = None, ) -> dict: """ Processes GTFS datasets missing extracted files and updates database. @@ -76,13 +81,22 @@ def rebuild_missing_dataset_files( dry_run (bool): If True, only logs how many would be processed. after_date (str): Only consider datasets downloaded after this ISO date. latest_only (bool): Whether to include only latest datasets. + dataset_id (str | None): If provided, only process the dataset with this stable id. Returns: dict: Result summary. """ - datasets = get_datasets_with_missing_files_query( - db_session, after_date=after_date, latest_only=latest_only - ) + + if dataset_id: + datasets = ( + db_session.query(Gtfsdataset) + .filter(Gtfsdataset.stable_id == dataset_id) + .options(joinedload(Gtfsdataset.feed)) + ) + else: + datasets = get_datasets_with_missing_files_query( + db_session, after_date=after_date, latest_only=latest_only + ) if dry_run: total = datasets.count() @@ -102,6 +116,9 @@ def rebuild_missing_dataset_files( logging.info("Starting to process datasets with missing files...") execution_id = f"task-executor-uuid-{uuid.uuid4()}" messages = [] + all_datasets_count = datasets.count() + topic = (os.getenv("DATASET_PROCESSING_TOPIC_NAME"),) + for dataset in datasets.all(): try: message = { @@ -124,16 +141,13 @@ def rebuild_missing_dataset_files( count += 1 total_processed += 1 - if count % batch_count == 0: - publish_messages( - messages, - os.getenv("PROJECT_ID"), - os.getenv("DATASET_PROCESSING_TOPIC_NAME"), - ) + if count % batch_count == 0 or all_datasets_count == count: + publish_messages(messages, os.getenv("PROJECT_ID"), topic) messages = [] logging.info( - "Published message for %d datasets. Total processed: %d", - batch_count, + "Published message to topic %s for %d datasets. Total processed: %d", + topic, + batch_count if count % batch_count == 0 else all_datasets_count - count, total_processed, ) @@ -147,6 +161,7 @@ def rebuild_missing_dataset_files( "after_date": after_date, "latest_only": latest_only, "datasets_bucket_name": os.environ.get("DATASETS_BUCKET_NAME"), + "dataset_id": dataset_id, }, } logging.info("Task summary: %s", result) diff --git a/functions-python/tasks_executor/tests/tasks/dataset_files/test_rebuild_missing_dataset_files.py b/functions-python/tasks_executor/tests/tasks/dataset_files/test_rebuild_missing_dataset_files.py index b02694e44..b386d02dc 100644 --- a/functions-python/tasks_executor/tests/tasks/dataset_files/test_rebuild_missing_dataset_files.py +++ b/functions-python/tasks_executor/tests/tasks/dataset_files/test_rebuild_missing_dataset_files.py @@ -18,7 +18,8 @@ import os import unittest from datetime import datetime -from unittest.mock import patch +from types import SimpleNamespace +from unittest.mock import patch, MagicMock from sqlalchemy.orm import Session @@ -46,7 +47,7 @@ def test_handler_calls_main_function(self, mock_rebuild_func): self.assertEqual(response["message"], "test") mock_rebuild_func.assert_called_once_with( - dry_run=True, after_date="2024-01-01", latest_only=False + dry_run=True, after_date="2024-01-01", latest_only=False, dataset_id=None ) @with_db_session(db_url=default_db_url) @@ -93,3 +94,101 @@ def test_rebuild_missing_dataset_files_processing( self.assertIn("completed", response["message"]) self.assertGreaterEqual(response["total_processed"], 0) self.assertTrue(publish_mock.called or response["total_processed"] == 0) + + +class TestRebuildSpecificDatasetFiles(unittest.TestCase): + @patch( + "tasks.dataset_files.rebuild_missing_dataset_files.rebuild_missing_dataset_files" + ) + def test_handler_calls_main_function(self, mock_rebuild_func): + mock_rebuild_func.return_value = {"message": "test", "total_processed": 0} + payload = {"dry_run": True, "after_date": "2024-01-01", "latest_only": False} + + response = rebuild_missing_dataset_files_handler(payload) + + self.assertEqual(response["message"], "test") + mock_rebuild_func.assert_called_once_with( + dry_run=True, after_date="2024-01-01", latest_only=False, dataset_id=None + ) + + @patch( + "tasks.dataset_files.rebuild_missing_dataset_files.rebuild_missing_dataset_files" + ) + def test_handler_forwards_dataset_id(self, mock_rebuild_func): + payload = { + "dry_run": False, + "after_date": None, + "latest_only": True, + "dataset_id": "ds-123", + } + + rebuild_missing_dataset_files_handler(payload) + + mock_rebuild_func.assert_called_once_with( + dry_run=False, after_date=None, latest_only=True, dataset_id="ds-123" + ) + + def test_rebuild_with_specific_dataset_id_publishes_one_message(self): + dataset_stable_id = "ds-123" + fake_feed = SimpleNamespace( + producer_url="https://example.com", + stable_id="feed-stable", + id=42, + authentication_type=None, + authentication_info_url=None, + api_key_parameter_name=None, + ) + fake_dataset = SimpleNamespace( + stable_id=dataset_stable_id, hash="abc123", feed=fake_feed + ) + + # Mock the chained SQLAlchemy calls: + # db_session.query(Gtfsdataset).filter(...).options(...).count()/all() + db_session = MagicMock() + query_mock = MagicMock() + filter_mock = MagicMock() + options_mock = MagicMock() + + db_session.query.return_value = query_mock + query_mock.filter.return_value = filter_mock + filter_mock.options.return_value = options_mock + + options_mock.count.return_value = 1 + options_mock.all.return_value = [fake_dataset] + + with patch.dict( + os.environ, + {"PROJECT_ID": "test-project", "DATASET_PROCESSING_TOPIC_NAME": "topic"}, + clear=False, + ), patch( + "tasks.dataset_files.rebuild_missing_dataset_files.get_datasets_with_missing_files_query" + ) as get_query_mock, patch( + "tasks.dataset_files.rebuild_missing_dataset_files.publish_messages" + ) as mock_publish: + from tasks.dataset_files.rebuild_missing_dataset_files import ( + rebuild_missing_dataset_files, + Gtfsdataset, + ) + + result = rebuild_missing_dataset_files( + db_session=db_session, + dry_run=False, + after_date=None, + latest_only=True, # ignored when dataset_id is provided + dataset_id=dataset_stable_id, + ) + + # Asserts + get_query_mock.assert_not_called() # bypasses generic query when dataset_id is set + db_session.query.assert_called_once_with(Gtfsdataset) + query_mock.filter.assert_called_once() # filtered by stable_id + options_mock.count.assert_called_once() + options_mock.all.assert_called_once() + + self.assertEqual(result["total_processed"], 1) + mock_publish.assert_called_once() + + messages_arg, project_id_arg, _topic_arg = mock_publish.call_args[0] + self.assertEqual(project_id_arg, "test-project") + self.assertEqual(len(messages_arg), 1) + self.assertEqual(messages_arg[0]["dataset_stable_id"], dataset_stable_id)