Skip to content

Commit fa4597d

Browse files
authored
feat: GTFS validator task sync (#1650)
1 parent a488c78 commit fa4597d

File tree

29 files changed

+2234
-762
lines changed

29 files changed

+2234
-762
lines changed

functions-python/update_validation_report/src/__init__.py renamed to functions-python/helpers/task_execution/__init__.py

File renamed without changes.

functions-python/helpers/task_execution/task_execution_tracker.py

Lines changed: 458 additions & 0 deletions
Large diffs are not rendered by default.
Lines changed: 203 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,203 @@
1+
#
2+
# MobilityData 2026
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
#
16+
17+
import unittest
18+
import uuid
19+
from datetime import datetime, timezone
20+
from unittest.mock import MagicMock
21+
22+
from task_execution.task_execution_tracker import (
23+
TaskExecutionTracker,
24+
STATUS_IN_PROGRESS,
25+
STATUS_TRIGGERED,
26+
STATUS_COMPLETED,
27+
STATUS_FAILED,
28+
)
29+
30+
31+
def _make_tracker(task_name="test_task", run_id="v1.0"):
    """Build a TaskExecutionTracker wired to a MagicMock DB session.

    Returns a ``(tracker, session)`` pair so tests can both drive the
    tracker and inspect the calls recorded on the mocked session.
    """
    mock_session = MagicMock()
    new_tracker = TaskExecutionTracker(
        task_name=task_name,
        run_id=run_id,
        db_session=mock_session,
    )
    return new_tracker, mock_session
38+
39+
40+
class TestTaskExecutionTrackerStartRun(unittest.TestCase):
    """Tests for TaskExecutionTracker.start_run (task_run upsert)."""

    @staticmethod
    def _stub_scalar_one(session):
        """Make session.execute() yield a result whose scalar_one() is a fresh UUID."""
        new_id = uuid.uuid4()
        fake_result = MagicMock()
        fake_result.scalar_one.return_value = new_id
        session.execute.return_value = fake_result
        return new_id

    def test_start_run_upserts_task_run(self):
        tracker, session = _make_tracker()
        new_id = self._stub_scalar_one(session)

        returned = tracker.start_run(total_count=100, params={"env": "staging"})

        # The upserted row id is both returned and cached on the tracker.
        self.assertEqual(returned, new_id)
        self.assertEqual(tracker.task_run_id, new_id)
        session.execute.assert_called_once()
        session.flush.assert_called_once()

    def test_start_run_caches_task_run_id(self):
        tracker, session = _make_tracker()
        new_id = self._stub_scalar_one(session)

        tracker.start_run(total_count=10)
        tracker.start_run(total_count=20)  # second call

        self.assertEqual(tracker.task_run_id, new_id)

    def test_start_run_resets_status_to_in_progress_on_rerun(self):
        """Re-running the same task_name/run_id must reset status and completed_at on conflict."""
        tracker, session = _make_tracker()
        self._stub_scalar_one(session)

        tracker.start_run(total_count=5)

        compiled_stmt = str(session.execute.call_args[0][0])
        # The ON CONFLICT DO UPDATE clause must include status and completed_at
        for fragment in ("DO UPDATE SET", "status", "completed_at"):
            self.assertIn(fragment, compiled_stmt)
82+
83+
84+
class TestTaskExecutionTrackerIsTriggered(unittest.TestCase):
    """Tests for TaskExecutionTracker.is_triggered (existence lookup)."""

    @staticmethod
    def _stub_first_result(session, row):
        """Route the query().filter().filter().first() chain to *row*."""
        chain = session.query.return_value.filter.return_value.filter.return_value
        chain.first.return_value = row

    def test_returns_true_when_triggered_row_exists(self):
        tracker, session = _make_tracker()
        self._stub_first_result(session, MagicMock())

        self.assertTrue(tracker.is_triggered("ds-123"))

    def test_returns_false_when_no_row(self):
        tracker, session = _make_tracker()
        self._stub_first_result(session, None)

        self.assertFalse(tracker.is_triggered("ds-999"))

    def test_handles_none_entity_id(self):
        # A None entity id must not raise; it simply finds no row.
        tracker, session = _make_tracker()
        self._stub_first_result(session, None)

        self.assertFalse(tracker.is_triggered(None))
112+
113+
114+
class TestTaskExecutionTrackerMarkTriggered(unittest.TestCase):
    """Tests for TaskExecutionTracker.mark_triggered (execution-log insert)."""

    def setUp(self):
        self.tracker, self.session = _make_tracker()
        # mark_triggered needs an existing run id on the tracker.
        self.tracker.task_run_id = uuid.uuid4()

    def test_mark_triggered_inserts_execution_log(self):
        self.tracker.mark_triggered(
            "ds-1", execution_ref="projects/x/executions/abc"
        )

        self.session.execute.assert_called_once()
        self.session.flush.assert_called_once()

    def test_mark_triggered_with_metadata(self):
        self.tracker.mark_triggered("ds-1", metadata={"feed_id": "f-1"})

        self.session.execute.assert_called_once()
131+
132+
133+
class TestTaskExecutionTrackerMarkCompleted(unittest.TestCase):
    """Tests for TaskExecutionTracker.mark_completed (status update)."""

    def test_mark_completed_updates_status(self):
        tracker, session = _make_tracker()
        filtered_query = MagicMock()
        session.query.return_value.filter.return_value.filter.return_value = (
            filtered_query
        )

        tracker.mark_completed("ds-1")

        filtered_query.update.assert_called_once()
        payload = filtered_query.update.call_args[0][0]
        # Status flips to completed and a completion timestamp is recorded.
        self.assertEqual(payload["status"], STATUS_COMPLETED)
        self.assertIn("completed_at", payload)
145+
146+
147+
class TestTaskExecutionTrackerMarkFailed(unittest.TestCase):
    """Tests for TaskExecutionTracker.mark_failed (status + error message)."""

    def test_mark_failed_sets_error_message(self):
        tracker, session = _make_tracker()
        filtered_query = MagicMock()
        session.query.return_value.filter.return_value.filter.return_value = (
            filtered_query
        )

        tracker.mark_failed("ds-1", error_message="Workflow timed out")

        filtered_query.update.assert_called_once()
        payload = filtered_query.update.call_args[0][0]
        # Failure must persist both the failed status and the error detail.
        self.assertEqual(payload["status"], STATUS_FAILED)
        self.assertEqual(payload["error_message"], "Workflow timed out")
159+
160+
161+
class TestTaskExecutionTrackerGetSummary(unittest.TestCase):
    """Tests for TaskExecutionTracker.get_summary (per-status counts)."""

    def _make_task_run(self, status=STATUS_IN_PROGRESS, total_count=10):
        """Return a mock task_run row with the given status and total_count."""
        fake_run = MagicMock()
        fake_run.status = status
        fake_run.total_count = total_count
        fake_run.created_at = datetime.now(timezone.utc)
        return fake_run

    def test_returns_none_summary_when_no_run(self):
        tracker, session = _make_tracker()
        session.query.return_value.filter.return_value.first.return_value = None
        session.query.return_value.filter.return_value.all.return_value = []

        summary = tracker.get_summary()

        # No run row -> null status and zeroed counters.
        self.assertIsNone(summary["run_status"])
        self.assertEqual(summary["triggered"], 0)
        self.assertEqual(summary["completed"], 0)

    def test_counts_by_status(self):
        tracker, session = _make_tracker()
        active_run = self._make_task_run(total_count=5)

        log_rows = [
            MagicMock(status=status)
            for status in (
                STATUS_TRIGGERED,
                STATUS_TRIGGERED,
                STATUS_COMPLETED,
                STATUS_FAILED,
            )
        ]

        def fresh_query(*_args):
            # Each session.query() call gets its own mock so both the
            # task_run lookup and the log-row listing are satisfied.
            query = MagicMock()
            query.filter.return_value.first.return_value = active_run
            query.filter.return_value.all.return_value = log_rows
            return query

        session.query.side_effect = fresh_query

        summary = tracker.get_summary()
        self.assertEqual(summary["triggered"], 2)
        self.assertEqual(summary["completed"], 1)
        self.assertEqual(summary["failed"], 1)
        self.assertEqual(summary["pending"], 1)  # 5 total - 4 processed

functions-python/helpers/validation_report/validation_report_update.py

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -47,13 +47,18 @@ def execute_workflows(
4747
validator_endpoint=None,
4848
bypass_db_update=False,
4949
reports_bucket_name=None,
50+
tracker=None,
5051
):
5152
"""
52-
Execute the workflow for the latest datasets that need their validation report to be updated
53+
Execute the workflow for the latest datasets that need their validation report to be updated.
54+
5355
:param latest_datasets: List of tuples containing the feed stable id and dataset stable id
5456
:param validator_endpoint: The URL of the validator
5557
:param bypass_db_update: Whether to bypass the database update
5658
:param reports_bucket_name: The name of the bucket where the reports are stored
59+
:param tracker: Optional TaskExecutionTracker for idempotent execution tracking.
60+
When provided, datasets already in triggered/completed state are skipped
61+
and newly triggered datasets are recorded.
5762
:return: List of dataset stable ids for which the workflow was executed
5863
"""
5964
project_id = f"mobility-feeds-{env}"
@@ -64,6 +69,9 @@ def execute_workflows(
6469
count = 0
6570
logging.info(f"Executing workflow for {len(latest_datasets)} datasets")
6671
for feed_id, dataset_id in latest_datasets:
72+
if tracker and tracker.is_triggered(dataset_id):
73+
logging.info(f"Skipping already triggered dataset {feed_id}/{dataset_id}")
74+
continue
6775
try:
6876
input_data = {
6977
"data": {
@@ -83,12 +91,20 @@ def execute_workflows(
8391
if reports_bucket_name:
8492
input_data["data"]["reports_bucket_name"] = reports_bucket_name
8593
logging.info(f"Executing workflow for {feed_id}/{dataset_id}")
86-
execute_workflow(project_id, input_data=input_data)
94+
execution = execute_workflow(project_id, input_data=input_data)
8795
execution_triggered_datasets.append(dataset_id)
96+
if tracker:
97+
tracker.mark_triggered(
98+
entity_id=dataset_id,
99+
execution_ref=execution.name,
100+
metadata={"feed_id": feed_id},
101+
)
88102
except Exception as e:
89103
logging.error(
90104
f"Error while executing workflow for {feed_id}/{dataset_id}: {e}"
91105
)
106+
if tracker:
107+
tracker.mark_failed(entity_id=dataset_id, error_message=str(e))
92108
count += 1
93109
logging.info(f"Triggered workflow execution for {count} datasets")
94110
if count % batch_size == 0:

functions-python/process_validation_report/src/main.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@
3535
from shared.helpers.logger import init_logger
3636
from shared.helpers.transform import get_nested_value
3737
from shared.helpers.feed_status import update_feed_statuses_query
38+
from shared.helpers.task_execution.task_execution_tracker import TaskExecutionTracker
3839
from shared.common.gcp_utils import create_web_revalidation_task
3940

4041
init_logger()
@@ -288,6 +289,20 @@ def create_validation_report_entities(
288289

289290
update_feed_statuses_query(db_session, [feed_stable_id])
290291

292+
# Update execution tracker regardless of bypass_db_update, so monitoring
293+
# works for both pre-release and post-release validation runs.
294+
try:
295+
tracker = TaskExecutionTracker(
296+
task_name="gtfs_validation",
297+
run_id=version,
298+
db_session=db_session,
299+
)
300+
tracker.mark_completed(dataset_stable_id)
301+
db_session.commit()
302+
except Exception as tracker_error:
303+
logging.warning(
304+
"Could not update task execution tracker: %s", tracker_error
305+
)
291306
# Trigger web app cache revalidation for the feed
292307
try:
293308
create_web_revalidation_task([feed_stable_id])

functions-python/tasks_executor/README.md

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,12 +20,26 @@ Examples:
2020
"task": "rebuild_missing_validation_reports",
2121
"payload": {
2222
"dry_run": true,
23-
"filter_after_in_days": 14,
23+
"bypass_db_update": true,
24+
"filter_after_in_days": null,
25+
"force_update": false,
26+
"validator_endpoint": "https://stg-gtfs-validator-web-mbzoxaljzq-ue.a.run.app",
27+
"limit": 1,
2428
"filter_statuses": ["active", "inactive", "future"]
2529
}
2630
}
2731
```
2832

33+
```json
34+
{
35+
"task": "get_validation_run_status",
36+
"payload": {
37+
"task_name": "gtfs_validation",
38+
"run_id": "7.1.1-SNAPSHOT"
39+
}
40+
}
41+
```
42+
2943
```json
3044
{
3145
"task": "rebuild_missing_bounding_boxes",

functions-python/tasks_executor/src/main.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
import functions_framework
2222

2323
from shared.helpers.logger import init_logger
24+
from shared.helpers.task_execution.task_execution_tracker import TaskInProgressError
2425
from tasks.data_import.transportdatagouv.import_tdg_feeds import import_tdg_handler
2526
from tasks.data_import.transportdatagouv.update_tdg_redirects import (
2627
update_tdg_redirects_handler,
@@ -38,6 +39,12 @@
3839
from tasks.validation_reports.rebuild_missing_validation_reports import (
3940
rebuild_missing_validation_reports_handler,
4041
)
42+
from tasks.sync_task_run_status import (
43+
sync_task_run_status_handler,
44+
)
45+
from tasks.get_task_run_status import (
46+
get_task_run_status_handler,
47+
)
4148
from tasks.visualization_files.rebuild_missing_visualization_files import (
4249
rebuild_missing_visualization_files_handler,
4350
)
@@ -71,6 +78,25 @@
7178
"description": "Rebuilds missing validation reports for GTFS datasets.",
7279
"handler": rebuild_missing_validation_reports_handler,
7380
},
81+
"get_task_run_status": {
82+
"description": (
83+
"Read-only snapshot of a task_run tracked by TaskExecutionTracker. "
84+
"Returns current DB state (triggered/completed/failed/pending counts) "
85+
"without triggering any GCP Workflows polling or status transitions. "
86+
"Required: task_name, run_id."
87+
),
88+
"handler": get_task_run_status_handler,
89+
},
90+
"sync_task_run_status": {
91+
"description": (
92+
"Generic self-scheduling monitor for any task_run. "
93+
"Polls GCP Workflows for triggered entries, updates statuses, "
94+
"marks the task_run completed when all done, and re-schedules "
95+
"itself every 10 minutes until complete. "
96+
"Required: task_name, run_id."
97+
),
98+
"handler": sync_task_run_status_handler,
99+
},
74100
"rebuild_missing_bounding_boxes": {
75101
"description": "Rebuilds missing bounding boxes for GTFS datasets that contain valid stops.txt files.",
76102
"handler": rebuild_missing_bounding_boxes_handler,
@@ -195,5 +221,10 @@ def tasks_executor(request: flask.Request) -> flask.Response:
195221

196222
# Default JSON response
197223
return flask.make_response(flask.jsonify(result), 200)
224+
except TaskInProgressError as error:
225+
# Signal Cloud Tasks to retry — the run is not yet complete
226+
return flask.make_response(
227+
flask.jsonify({"status": "in_progress", "detail": str(error)}), 503
228+
)
198229
except Exception as error:
199230
return flask.make_response(flask.jsonify({"error": str(error)}), 500)

0 commit comments

Comments (0)