Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion functions-python/helpers/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -28,4 +28,5 @@ google-cloud-bigquery

# Additional package
pycountry
shapely
shapely
pandas
96 changes: 95 additions & 1 deletion functions-python/helpers/tests/test_transform.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,14 @@
from transform import to_boolean, get_nested_value
import unittest

import pandas as pd

from transform import (
to_boolean,
get_nested_value,
to_float,
get_safe_value,
get_safe_float,
)


def test_to_boolean():
Expand Down Expand Up @@ -60,3 +70,87 @@ def test_get_nested_value():
# Test case 9: Non-dictionary data
assert get_nested_value("not a dict", ["a", "b", "c"]) is None
assert get_nested_value("not a dict", ["a", "b", "c"], []) == []


class TestToFloat(unittest.TestCase):
    """Unit tests for to_float()."""

    def test_valid_float(self):
        # Numeric strings and numbers convert to the expected float value.
        cases = [("3.14", 3.14), (2.5, 2.5), ("0", 0.0), (0, 0.0)]
        for value, expected in cases:
            self.assertEqual(to_float(value), expected)

    def test_invalid_float(self):
        # Unconvertible inputs yield None when no default is supplied.
        for value in ("abc", None, ""):
            self.assertIsNone(to_float(value))

    def test_default_value(self):
        # Unconvertible inputs fall back to the caller-supplied default.
        self.assertEqual(to_float("abc", default_value=1.23), 1.23)
        self.assertEqual(to_float(None, default_value=4.56), 4.56)
        self.assertEqual(to_float("", default_value=7.89), 7.89)


class TestGetSafeValue(unittest.TestCase):
    """Unit tests for get_safe_value()."""

    def test_valid_value(self):
        # Surrounding whitespace is stripped from the stored value.
        self.assertEqual(get_safe_value({"name": " Alice "}, "name"), "Alice")

    def test_missing_column(self):
        # Absent columns yield None by default.
        self.assertIsNone(get_safe_value({"age": 30}, "name"))

    def test_empty_string(self):
        # Whitespace-only strings are treated as missing.
        self.assertIsNone(get_safe_value({"name": " "}, "name"))

    def test_nan_value(self):
        # Both pandas.NA and float NaN count as missing values.
        self.assertIsNone(get_safe_value({"name": pd.NA}, "name"))
        self.assertIsNone(get_safe_value({"name": float("nan")}, "name"))

    def test_default_value(self):
        # An empty string falls back to the supplied default.
        self.assertEqual(
            get_safe_value({"name": ""}, "name", default_value="default"), "default"
        )


class TestGetSafeFloat(unittest.TestCase):
    """Unit tests for get_safe_float()."""

    def test_valid_float(self):
        # Convertible cell values come back as floats.
        cases = [("3.14", 3.14), (2.5, 2.5), ("0", 0.0), (0, 0.0)]
        for value, expected in cases:
            self.assertEqual(get_safe_float({"value": value}, "value"), expected)

    def test_missing_column(self):
        # Absent columns yield None by default.
        self.assertIsNone(get_safe_float({"other": 1.23}, "value"))

    def test_empty_string(self):
        # Whitespace-only strings are treated as missing.
        self.assertIsNone(get_safe_float({"value": " "}, "value"))

    def test_nan_value(self):
        # Both pandas.NA and float NaN count as missing values.
        for value in (pd.NA, float("nan")):
            self.assertIsNone(get_safe_float({"value": value}, "value"))

    def test_invalid_float(self):
        # Unconvertible or None values yield None when no default is supplied.
        for value in ("abc", None):
            self.assertIsNone(get_safe_float({"value": value}, "value"))

    def test_default_value(self):
        # Unconvertible inputs fall back to the caller-supplied default.
        self.assertEqual(get_safe_float({"value": ""}, "value", default_value=1.23), 1.23)
        self.assertEqual(get_safe_float({"value": "abc"}, "value", default_value=4.56), 4.56)
        self.assertEqual(get_safe_float({"value": None}, "value", default_value=7.89), 7.89)
37 changes: 37 additions & 0 deletions functions-python/helpers/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,3 +77,40 @@ def to_enum(value, enum_class=None, default_value=None):
except (ValueError, TypeError) as e:
logging.warning("Failed to convert value to enum member: %s", e)
return default_value


def to_float(value, default_value: Optional[float] = None) -> Optional[float]:
    """
    Coerce *value* to a float.

    Non-numeric strings, None, and other unconvertible inputs yield
    *default_value* instead of raising.
    """
    try:
        result = float(value)
    except (TypeError, ValueError):
        return default_value
    return result


def get_safe_value(row, column_name, default_value=None) -> Optional[str]:
    """
    Fetch row[column_name] as a stripped string.

    A missing key, None, a pandas NA/NaN, or a whitespace-only string all
    map to *default_value*; anything else is stringified and stripped.
    """
    # Imported lazily so importing this module does not require pandas.
    import pandas

    raw = row.get(column_name, None)
    if raw is None:
        return default_value
    # NOTE(review): pandas.isna on a list/array returns an array — this
    # assumes scalar cell values; confirm against callers.
    if pandas.isna(raw):
        return default_value
    if isinstance(raw, str) and raw.strip() == "":
        return default_value
    return f"{raw}".strip()


def get_safe_float(row, column_name, default_value=None) -> Optional[float]:
    """
    Get a safe float value from the row. If the value is missing, empty, NaN,
    or cannot be converted to float, return the default value.
    """
    # get_safe_value already maps missing/NaN/blank cells to None;
    # float(None) then raises TypeError and falls through to the default.
    safe_value = get_safe_value(row, column_name)
    try:
        return float(safe_value)
    except (ValueError, TypeError):
        return default_value
144 changes: 144 additions & 0 deletions functions-python/helpers/verifier_common.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,144 @@
import logging
import os
import socket
import subprocess
from typing import Dict
import uuid
from io import BytesIO

import requests

from shared.database.database import with_db_session
from shared.database_gen.sqlacodegen_models import Gtfsfeed, Gbfsfeed
from shared.helpers.runtime_metrics import track_metrics

from google.cloud import storage
from sqlalchemy.orm import Session


EMULATOR_STORAGE_BUCKET_NAME = "verifier"
EMULATOR_HOST = "localhost"
EMULATOR_STORAGE_PORT = 9023


@track_metrics(metrics=("time", "memory", "cpu"))
def download_to_local(
    feed_stable_id: str, url: str, filename: str, force_download: bool = False
):
    """
    Download a file from a URL and upload it to the Google Cloud Storage emulator.
    If the file already exists, it will not be downloaded again.
    Args:
        feed_stable_id (str): Stable id of the feed; used as the blob "folder".
        url (str): The URL to download the file from.
        filename (str): The name of the file to save in the emulator.
        force_download (bool): Re-download even if the blob already exists.
    """
    if not url:
        return
    # Fix: build the blob path from the requested filename — the parameter
    # was previously unused and the path contained a literal placeholder.
    blob_path = f"{feed_stable_id}/{filename}"
    client = storage.Client()
    bucket = client.bucket(EMULATOR_STORAGE_BUCKET_NAME)
    blob = bucket.blob(blob_path)

    # Check if the blob already exists in the emulator
    if not blob.exists() or force_download:
        logging.info(f"Downloading and uploading: {blob_path}")
        with requests.get(url, stream=True) as response:
            response.raise_for_status()
            blob.content_type = "application/json"
            # The file is downloaded into memory before uploading to ensure it's seekable.
            # Be careful with large files.
            data = BytesIO(response.content)
            blob.upload_from_file(data, rewind=True)
    else:
        logging.info(
            f"Blob already exists: gs://{EMULATOR_STORAGE_BUCKET_NAME}/{blob_path}"
        )


@with_db_session
def create_test_data(feed_stable_id: str, feed_dict: Dict, db_session: Session = None):
    """
    Create test data in the database if it does not exist.
    This function is used to ensure that the reverse geolocation process has the necessary data to work with.

    Args:
        feed_stable_id: Stable id under which the feed is looked up / created.
        feed_dict: Feed attributes; only "data_type" is read here.
        db_session: Injected by the @with_db_session decorator — presumably
            not meant to be passed by callers (TODO confirm).
    """
    # Here you would typically interact with your database to create the necessary test data
    # For this example, we will just log the action
    logging.info(f"Creating test data for {feed_stable_id} with data: {feed_dict}")
    # "gtfs" selects the GTFS model; any other data_type falls back to GBFS.
    model = Gtfsfeed if feed_dict["data_type"] == "gtfs" else Gbfsfeed
    local_feed = (
        db_session.query(model).filter(model.stable_id == feed_stable_id).one_or_none()
    )
    if not local_feed:
        local_feed = model(
            # NOTE(review): uuid.uuid4() is a UUID object — confirm the id
            # column accepts non-string ids.
            id=uuid.uuid4(),
            stable_id=feed_stable_id,
            data_type=feed_dict["data_type"],
            feed_name="Test Feed",
            note="This is a test feed created for reverse geolocation verification.",
            producer_url="https://files.mobilitydatabase.org/mdb-2014/mdb-2014-202508120303/mdb-2014-202508120303.zip",
            authentication_type="0",
            status="active",
        )
        db_session.add(local_feed)
        db_session.commit()


def setup_local_storage_emulator():
    """
    Start a local Google Cloud Storage emulator and point the relevant
    environment variables at it.

    Returns:
        The running emulator server; pass it to shutdown_local_storage_emulator.
    """
    # Imported lazily so the emulator package is only needed when used.
    from gcp_storage_emulator.server import create_server

    endpoint = f"http://{EMULATOR_HOST}:{EMULATOR_STORAGE_PORT}"
    os.environ["STORAGE_EMULATOR_HOST"] = endpoint
    os.environ["DATASETS_BUCKET_NAME_GBFS"] = EMULATOR_STORAGE_BUCKET_NAME
    os.environ["DATASETS_BUCKET_NAME_GTFS"] = EMULATOR_STORAGE_BUCKET_NAME
    os.environ["DATASTORE_EMULATOR_HOST"] = "localhost:8081"

    server = create_server(
        host=EMULATOR_HOST,
        port=EMULATOR_STORAGE_PORT,
        in_memory=False,
        default_bucket=EMULATOR_STORAGE_BUCKET_NAME,
    )
    server.start()
    return server


def shutdown_local_storage_emulator(server):
    """Stop a storage emulator previously returned by setup_local_storage_emulator."""
    server.stop()


def is_datastore_emulator_running(host=EMULATOR_HOST, port=8081):
    """Return True if something accepts TCP connections on the Datastore emulator port."""
    try:
        # A successful connect (closed immediately) means the emulator is up.
        connection = socket.create_connection((host, port), timeout=2)
    except OSError:
        return False
    with connection:
        return True


def start_datastore_emulator(project_id="test-project"):
    """
    Launch the Google Cloud Datastore emulator via gcloud unless one is
    already listening.

    Returns:
        The subprocess handle for the new emulator, or None if it was
        already running.
    """
    if is_datastore_emulator_running():
        return None  # Already running
    command = [
        "gcloud",
        "beta",
        "emulators",
        "datastore",
        "start",
        "--project={}".format(project_id),
        "--host-port=localhost:8081",
    ]
    return subprocess.Popen(command)


def shutdown_datastore_emulator(process):
    """Terminate the emulator process and wait for it; no-op for a falsy handle."""
    if not process:
        return
    process.terminate()
    process.wait()
1 change: 1 addition & 0 deletions functions-python/pmtiles_builder/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -25,4 +25,5 @@ google-cloud-storage
python-dotenv==1.0.0
tippecanoe
psutil
pandas

28 changes: 23 additions & 5 deletions functions-python/pmtiles_builder/src/csv_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,10 @@
#
import csv
import os
from shared.helpers.logger import get_logger

from gtfs import stop_txt_is_lat_log_required
from shared.helpers.logger import get_logger
from shared.helpers.transform import get_safe_value, get_safe_float

STOP_TIMES_FILE = "stop_times.txt"
SHAPES_FILE = "shapes.txt"
Expand Down Expand Up @@ -127,10 +129,26 @@ def get_stops_from_trip(self, trip_id):

def get_coordinates_for_stop(self, stop_id) -> tuple[float, float] | None:
    # Lazily build the stop_id -> (lon, lat) map on first use,
    # then answer lookups from the cache.
    if self.stop_to_coordinates is None:
        # NOTE(review): the next four lines look like the pre-change
        # implementation retained by the diff view; as written they are
        # immediately overwritten below — confirm against the real file.
        self.stop_to_coordinates = {
            s["stop_id"]: (float(s["stop_lon"]), float(s["stop_lat"]))
            for s in self.get_file(STOPS_FILE)
        }
        self.stop_to_coordinates = {}
        for s in self.get_file(STOPS_FILE):
            # NOTE(review): this .get() discards its result — apparent
            # no-op leftover; looks safe to remove.
            self.stop_to_coordinates.get(stop_id, [])
            row_stop_id = get_safe_value(s, "stop_id")
            row_stop_lon = get_safe_float(s, "stop_lon")
            row_stop_lat = get_safe_float(s, "stop_lat")
            # Rows without a stop_id cannot be indexed; warn and skip.
            if row_stop_id is None:
                self.logger.warning("Missing stop id: %s", s)
                continue
            # Missing coordinates: warn only when GTFS requires them for
            # this row's location_type; otherwise log at debug level.
            if row_stop_lon is None or row_stop_lat is None:
                if stop_txt_is_lat_log_required(s):
                    self.logger.warning(
                        "Missing stop latitude and longitude : %s", s
                    )
                else:
                    self.logger.debug(
                        "Missing optional stop latitude and longitude : %s", s
                    )
                continue
            self.stop_to_coordinates[row_stop_id] = (row_stop_lon, row_stop_lat)
    return self.stop_to_coordinates.get(stop_id, None)

def set_workdir(self, workdir):
Expand Down
20 changes: 20 additions & 0 deletions functions-python/pmtiles_builder/src/gtfs.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
from shared.helpers.transform import get_safe_value

# TODO: Move this file to a shared folder


def stop_txt_is_lat_log_required(stop_row):
    """
    Return whether stop_lat/stop_lon are required for this stops.txt row.

    Conditionally Required:
    - Required for locations which are stops (location_type=0), stations (location_type=1)
      or entrances/exits (location_type=2).
    - Optional for locations which are generic nodes (location_type=3) or boarding areas (location_type=4).

    Args:
        stop_row (dict): The stops.txt data row to check.

    Returns:
        bool: True if both latitude and longitude are required, False otherwise.
    """
    # An absent/blank location_type defaults to "0" (a regular stop).
    location_type = get_safe_value(stop_row, "location_type", "0")
    return location_type in ("0", "1", "2")
11 changes: 8 additions & 3 deletions functions-python/pmtiles_builder/src/gtfs_stops_to_geojson.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,9 @@
from collections import defaultdict

from csv_cache import CsvCache, ROUTES_FILE, TRIPS_FILE, STOP_TIMES_FILE, STOPS_FILE
from gtfs import stop_txt_is_lat_log_required
from shared.helpers.runtime_metrics import track_metrics
from shared.helpers.transform import get_safe_float

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -60,10 +62,13 @@ def convert_stops_to_geojson(csv_cache: CsvCache, output_file):
if (
"stop_lat" not in row
or "stop_lon" not in row
or not row["stop_lat"]
or not row["stop_lon"]
or get_safe_float(row, "stop_lat") is None
or get_safe_float(row, "stop_lon") is None
):
logger.warning(f"Missing coordinates for stop_id {stop_id}, skipping.")
if stop_txt_is_lat_log_required(row):
logger.warning(
"Missing coordinates for stop_id {%s}, skipping.", stop_id
)
continue

# Routes serving this stop
Expand Down
Loading
Loading