Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 12 additions & 2 deletions functions-python/tasks_executor/src/tasks/dataset_files/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,8 @@ The function accepts the following payload:
{
"dry_run": true, // [optional] If true, do not upload or modify the database (default: true)
"after_date": "YYYY-MM-DD", // [optional] Only include datasets downloaded after this ISO date
"latest_only": true // [optional] If true, only process the latest version of each dataset (default: true)
"latest_only": true, // [optional] If true, only process the latest version of each dataset (default: true)
"dataset_id": id // [optional] If provided, only process the specified dataset. It will supersede the after_date and latest_only parameters.
}
```

Expand All @@ -32,7 +33,13 @@ The function accepts the following payload:
"latest_only": true
}
```

or
```json
{
"dry_run": false,
"dataset_id": "mdb-1147-202407031702"
}
```
---

## What It Does
Expand All @@ -52,6 +59,9 @@ For each GTFS dataset with missing file information (missing zipped/unzipped siz
7. Computes SHA256 hashes for each file
8. Stores metadata in the `Gtfsfile` table for later use

If the `dataset_id` parameter is provided, the process is a bit simplified. It does not download the dataset as it is
assumed the dataset is already present in the bucket. The rest of the processing is the same.

---

## GCP Environment Variables
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,13 @@ def rebuild_missing_dataset_files_handler(payload) -> dict:
dry_run = payload.get("dry_run", True)
after_date = payload.get("after_date", None)
latest_only = payload.get("latest_only", True)
dataset_id = payload.get("dataset_id", None)

return rebuild_missing_dataset_files(
dry_run=dry_run, after_date=after_date, latest_only=latest_only
dry_run=dry_run,
after_date=after_date,
latest_only=latest_only,
dataset_id=dataset_id,
)


Expand Down Expand Up @@ -67,6 +71,7 @@ def rebuild_missing_dataset_files(
dry_run: bool = True,
after_date: str = None,
latest_only: bool = True,
dataset_id: str = None,
) -> dict:
"""
Processes GTFS datasets missing extracted files and updates database.
Expand All @@ -76,13 +81,22 @@ def rebuild_missing_dataset_files(
dry_run (bool): If True, only logs how many would be processed.
after_date (str): Only consider datasets downloaded after this ISO date.
latest_only (bool): Whether to include only latest datasets.
dataset_id (str | None): If provided, only process the dataset with this stable id.

Returns:
dict: Result summary.
"""
datasets = get_datasets_with_missing_files_query(
db_session, after_date=after_date, latest_only=latest_only
)

if dataset_id:
datasets = (
db_session.query(Gtfsdataset)
.filter(Gtfsdataset.stable_id == dataset_id)
.options(joinedload(Gtfsdataset.feed))
)
else:
datasets = get_datasets_with_missing_files_query(
db_session, after_date=after_date, latest_only=latest_only
)

if dry_run:
total = datasets.count()
Expand All @@ -102,6 +116,9 @@ def rebuild_missing_dataset_files(
logging.info("Starting to process datasets with missing files...")
execution_id = f"task-executor-uuid-{uuid.uuid4()}"
messages = []
all_datasets_count = datasets.count()
topic = (os.getenv("DATASET_PROCESSING_TOPIC_NAME"),)

for dataset in datasets.all():
try:
message = {
Expand All @@ -124,16 +141,13 @@ def rebuild_missing_dataset_files(
count += 1
total_processed += 1

if count % batch_count == 0:
publish_messages(
messages,
os.getenv("PROJECT_ID"),
os.getenv("DATASET_PROCESSING_TOPIC_NAME"),
)
if count % batch_count == 0 or all_datasets_count == count:
publish_messages(messages, os.getenv("PROJECT_ID"), topic)
messages = []
logging.info(
"Published message for %d datasets. Total processed: %d",
batch_count,
"Published message to topic %s for %d datasets. Total processed: %d",
topic,
batch_count if count % batch_count == 0 else all_datasets_count - count,
total_processed,
)

Expand All @@ -147,6 +161,7 @@ def rebuild_missing_dataset_files(
"after_date": after_date,
"latest_only": latest_only,
"datasets_bucket_name": os.environ.get("DATASETS_BUCKET_NAME"),
"dataset_id": dataset_id,
},
}
logging.info("Task summary: %s", result)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,8 @@
import os
import unittest
from datetime import datetime
from unittest.mock import patch
from types import SimpleNamespace
from unittest.mock import patch, MagicMock

from sqlalchemy.orm import Session

Expand Down Expand Up @@ -46,7 +47,7 @@ def test_handler_calls_main_function(self, mock_rebuild_func):

self.assertEqual(response["message"], "test")
mock_rebuild_func.assert_called_once_with(
dry_run=True, after_date="2024-01-01", latest_only=False
dry_run=True, after_date="2024-01-01", latest_only=False, dataset_id=None
)

@with_db_session(db_url=default_db_url)
Expand Down Expand Up @@ -93,3 +94,101 @@ def test_rebuild_missing_dataset_files_processing(
self.assertIn("completed", response["message"])
self.assertGreaterEqual(response["total_processed"], 0)
self.assertTrue(publish_mock.called or response["total_processed"] == 0)


class TestRebuildSpecificDatasetFiles(unittest.TestCase):
@patch(
"tasks.dataset_files.rebuild_missing_dataset_files.rebuild_missing_dataset_files"
)
def test_handler_calls_main_function(self, mock_rebuild_func):
mock_rebuild_func.return_value = {"message": "test", "total_processed": 0}
payload = {"dry_run": True, "after_date": "2024-01-01", "latest_only": False}

response = rebuild_missing_dataset_files_handler(payload)

self.assertEqual(response["message"], "test")
mock_rebuild_func.assert_called_once_with(
dry_run=True, after_date="2024-01-01", latest_only=False, dataset_id=None
)

@patch(
"tasks.dataset_files.rebuild_missing_dataset_files.rebuild_missing_dataset_files"
)
def test_handler_forwards_dataset_id(self, mock_rebuild_func):
payload = {
"dry_run": False,
"after_date": None,
"latest_only": True,
"dataset_id": "ds-123",
}

rebuild_missing_dataset_files_handler(payload)

mock_rebuild_func.assert_called_once_with(
dry_run=False, after_date=None, latest_only=True, dataset_id="ds-123"
)

def test_rebuild_with_specific_dataset_id_publishes_one_message(self):
dataset_stable_id = "ds-123"
fake_feed = SimpleNamespace(
producer_url="https://example.com",
stable_id="feed-stable",
id=42,
authentication_type=None,
authentication_info_url=None,
api_key_parameter_name=None,
)
fake_dataset = SimpleNamespace(
stable_id=dataset_stable_id, hash="abc123", feed=fake_feed
)

# Mock the chained SQLAlchemy calls:
# db_session.query(Gtfsdataset).filter(...).options(...).count()/all()
db_session = MagicMock()
query_mock = MagicMock()
filter_mock = MagicMock()
options_mock = MagicMock()

db_session.query.return_value = query_mock
query_mock.filter.return_value = filter_mock
filter_mock.options.return_value = options_mock

options_mock.count.return_value = 1
options_mock.all.return_value = [fake_dataset]

with patch.dict(
os.environ,
{"PROJECT_ID": "test-project", "DATASET_PROCESSING_TOPIC_NAME": "topic"},
clear=False,
), patch(
"tasks.dataset_files.rebuild_missing_dataset_files.get_datasets_with_missing_files_query"
) as get_query_mock, patch(
"tasks.dataset_files.rebuild_missing_dataset_files.publish_messages"
) as mock_publish:
from tasks.dataset_files.rebuild_missing_dataset_files import (
rebuild_missing_dataset_files,
Gtfsdataset,
)

result = rebuild_missing_dataset_files(
db_session=db_session,
dry_run=False,
after_date=None,
latest_only=True, # ignored when dataset_id is provided
dataset_id=dataset_stable_id,
)

# Asserts
get_query_mock.assert_not_called() # bypasses generic query when dataset_id is set
db_session.query.assert_called_once_with(Gtfsdataset)
query_mock.filter.assert_called_once() # filtered by stable_id
options_mock.count.assert_called_once()
options_mock.all.assert_called_once()

self.assertEqual(result["total_processed"], 1)
mock_publish.assert_called_once()

messages_arg, project_id_arg, _topic_arg = mock_publish.call_args[0]
self.assertEqual(project_id_arg, "test-project")
self.assertEqual(len(messages_arg), 1)
self.assertEqual(messages_arg[0]["dataset_stable_id"], dataset_stable_id)