Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion api/src/feeds/impl/feeds_api_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,7 +183,7 @@ def get_gtfs_feed_datasets(
).filter(DatasetsApiImpl.create_dataset_query().filter(FeedOrm.stable_id == gtfs_feed_id))

if latest:
query = query.filter(Gtfsdataset.latest)
query = query.join(Gtfsdataset.feed).filter(Gtfsdataset.id == FeedOrm.latest_dataset_id)

return DatasetsApiImpl.get_datasets_gtfs(query, session=db_session, limit=limit, offset=offset)

Expand Down
5 changes: 1 addition & 4 deletions api/src/feeds/impl/models/gtfs_feed_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,7 @@ def from_orm(cls, feed: GtfsfeedOrm | None) -> GtfsFeed | None:
if not gtfs_feed:
return None
gtfs_feed.locations = [LocationImpl.from_orm(item) for item in feed.locations]
latest_dataset = next(
(dataset for dataset in feed.gtfsdatasets if dataset is not None and dataset.latest), None
)
gtfs_feed.latest_dataset = LatestDatasetImpl.from_orm(latest_dataset)
gtfs_feed.latest_dataset = LatestDatasetImpl.from_orm(feed.latest_dataset)
gtfs_feed.bounding_box = BoundingBoxImpl.from_orm(feed.bounding_box)
gtfs_feed.visualization_dataset_id = (
feed.visualization_dataset.stable_id if feed.visualization_dataset else None
Expand Down
4 changes: 3 additions & 1 deletion api/src/scripts/populate_db_test_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,6 @@ def populate_test_datasets(self, filepath, db_session: "Session"):
id=dataset["id"],
feed_id=gtfsfeed[0].id,
stable_id=dataset["id"],
latest=dataset["latest"],
hosted_url=dataset["hosted_url"],
hash=dataset["hash"],
downloaded_at=dataset["downloaded_at"],
Expand All @@ -82,6 +81,9 @@ def populate_test_datasets(self, filepath, db_session: "Session"):
),
validation_reports=[],
)
if dataset["latest"]:
gtfsfeed[0].latest_dataset = gtfs_dataset

dataset_dict[dataset["id"]] = gtfs_dataset
db_session.add(gtfs_dataset)
db_session.commit()
Expand Down
22 changes: 6 additions & 16 deletions api/src/shared/common/db_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,12 +61,7 @@ def get_gtfs_feeds_query(
subquery = apply_bounding_filtering(
subquery, dataset_latitudes, dataset_longitudes, bounding_filter_method
).subquery()
feed_query = (
db_session.query(Gtfsfeed)
.outerjoin(Gtfsfeed.gtfsdatasets)
.filter(Gtfsfeed.id.in_(subquery))
.filter(or_(Gtfsdataset.latest, Gtfsdataset.id == None)) # noqa: E711
)
feed_query = db_session.query(Gtfsfeed).filter(Gtfsfeed.id.in_(subquery))

if country_code or subdivision_name or municipality:
location_filter = LocationFilter(
Expand All @@ -84,7 +79,7 @@ def get_gtfs_feeds_query(

if include_options_for_joinedload:
feed_query = feed_query.options(
contains_eager(Gtfsfeed.gtfsdatasets)
joinedload(Gtfsfeed.latest_dataset)
.joinedload(Gtfsdataset.validation_reports)
.joinedload(Validationreport.features),
joinedload(Gtfsfeed.visualization_dataset),
Expand Down Expand Up @@ -172,14 +167,10 @@ def get_all_gtfs_feeds(
for batch in batched(batch_query, batch_size):
stable_ids = (f.stable_id for f in batch)
if w_extracted_locations_only:
feed_query = apply_most_common_location_filter(
db_session.query(Gtfsfeed).outerjoin(Gtfsfeed.gtfsdatasets), db_session
)
feed_query = apply_most_common_location_filter(db_session.query(Gtfsfeed), db_session)
yield from (
feed_query.filter(Gtfsfeed.stable_id.in_(stable_ids))
.filter((Gtfsdataset.latest) | (Gtfsdataset.id == None)) # noqa: E711
.options(
contains_eager(Gtfsfeed.gtfsdatasets)
feed_query.filter(Gtfsfeed.stable_id.in_(stable_ids)).options(
joinedload(Gtfsfeed.latest_dataset)
.joinedload(Gtfsdataset.validation_reports)
.joinedload(Validationreport.features),
*get_joinedload_options(include_extracted_location_entities=True),
Expand All @@ -190,9 +181,8 @@ def get_all_gtfs_feeds(
db_session.query(Gtfsfeed)
.outerjoin(Gtfsfeed.gtfsdatasets)
.filter(Gtfsfeed.stable_id.in_(stable_ids))
.filter((Gtfsdataset.latest) | (Gtfsdataset.id == None)) # noqa: E711
.options(
contains_eager(Gtfsfeed.gtfsdatasets)
joinedload(Gtfsfeed.latest_dataset)
.joinedload(Gtfsdataset.validation_reports)
.joinedload(Validationreport.features),
*get_joinedload_options(include_extracted_location_entities=False),
Expand Down
141 changes: 81 additions & 60 deletions api/tests/unittest/models/test_gtfs_feed_impl.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import copy
import unittest
from datetime import datetime
from zoneinfo import ZoneInfo
Expand Down Expand Up @@ -49,13 +48,47 @@ def create_test_notice(notice_code: str, total_notices: int, severity: str):
)


# Dataset fixture reused by the feed fixture below (as both its latest dataset and its only dataset).
# It carries a single validation report with one feature and nine notices of mixed severity.
gtfs_dataset_orm = Gtfsdataset(
    id="id",
    stable_id="dataset_stable_id",
    feed_id="feed_id",
    hosted_url="hosted_url",
    note="note",
    downloaded_at=datetime(2022, 12, 31, 13, 45, 56),
    hash="hash",
    service_date_range_start=datetime(2024, 1, 1, tzinfo=ZoneInfo("Canada/Atlantic")),
    service_date_range_end=datetime(2025, 1, 1, tzinfo=ZoneInfo("Canada/Atlantic")),
    agency_timezone="Canada/Atlantic",
    bounding_box=WKTElement(POLYGON, srid=4326),
    validation_reports=[
        Validationreport(
            id="id",
            validator_version="validator_version",
            validated_at=datetime(2022, 12, 31, 13, 45, 56),
            html_report="html_report",
            json_report="json_report",
            features=[Feature(name="feature")],
            # Notice codes are numbered 1..9 in table order: (total_notices, severity).
            notices=[
                create_test_notice(f"notice_code{index}", total, severity)
                for index, (total, severity) in enumerate(
                    [
                        (1, "INFO"),
                        (3, "INFO"),
                        (7, "ERROR"),
                        (9, "ERROR"),
                        (11, "ERROR"),
                        (13, "WARNING"),
                        (15, "WARNING"),
                        (17, "WARNING"),
                        (19, "WARNING"),
                    ],
                    start=1,
                )
            ],
        )
    ],
)
gtfs_feed_orm = Gtfsfeed(
id="id",
data_type="gtfs",
feed_name="feed_name",
note="note",
producer_url="producer_url",
authentication_type=1,
authentication_type="1",
authentication_info_url="authentication_info_url",
api_key_parameter_name="api_key_parameter_name",
license_url="license_url",
Expand All @@ -79,43 +112,8 @@ def create_test_notice(notice_code: str, total_notices: int, severity: str):
source="source",
)
],
gtfsdatasets=[
Gtfsdataset(
id="id",
stable_id="dataset_stable_id",
feed_id="feed_id",
hosted_url="hosted_url",
note="note",
downloaded_at=datetime(year=2022, month=12, day=31, hour=13, minute=45, second=56),
hash="hash",
service_date_range_start=datetime(2024, 1, 1, 0, 0, 0, tzinfo=ZoneInfo("Canada/Atlantic")),
service_date_range_end=datetime(2025, 1, 1, 0, 0, 0, tzinfo=ZoneInfo("Canada/Atlantic")),
agency_timezone="Canada/Atlantic",
bounding_box=WKTElement(POLYGON, srid=4326),
latest=True,
validation_reports=[
Validationreport(
id="id",
validator_version="validator_version",
validated_at=datetime(year=2022, month=12, day=31, hour=13, minute=45, second=56),
html_report="html_report",
json_report="json_report",
features=[Feature(name="feature")],
notices=[
create_test_notice("notice_code1", 1, "INFO"),
create_test_notice("notice_code2", 3, "INFO"),
create_test_notice("notice_code3", 7, "ERROR"),
create_test_notice("notice_code4", 9, "ERROR"),
create_test_notice("notice_code5", 11, "ERROR"),
create_test_notice("notice_code6", 13, "WARNING"),
create_test_notice("notice_code7", 15, "WARNING"),
create_test_notice("notice_code8", 17, "WARNING"),
create_test_notice("notice_code9", 19, "WARNING"),
],
)
],
)
],
latest_dataset=gtfs_dataset_orm,
gtfsdatasets=[gtfs_dataset_orm],
redirectingids=[
Redirectingid(source_id="source_id", target_id="id1", redirect_comment="redirect_comment", target=targetFeed)
],
Expand Down Expand Up @@ -198,24 +196,47 @@ def test_from_orm_all_fields(self):

def test_from_orm_empty_fields(self):
"""Test the `from_orm` method with not provided fields."""
# Test with empty fields and None values
# No error should be raised
# Target is set to None as deep copy is failing for unknown reasons
# At the end of the test, the target is set back to the original value
gtfs_feed_orm.redirectingids[0].target = None
target_feed_orm = copy.deepcopy(gtfs_feed_orm)
target_feed_orm.feed_name = ""
target_feed_orm.provider = None
target_feed_orm.externalids = []
target_feed_orm.redirectingids = []

target_expected_gtfs_feed_result = copy.deepcopy(expected_gtfs_feed_result)
target_expected_gtfs_feed_result.feed_name = ""
target_expected_gtfs_feed_result.provider = None
target_expected_gtfs_feed_result.external_ids = []
target_expected_gtfs_feed_result.redirects = []

result = GtfsFeedImpl.from_orm(target_feed_orm)
assert result == target_expected_gtfs_feed_result
# Set the target back to the original value
gtfs_feed_orm.redirectingids[0].target = targetFeed
# Manually construct a minimal Gtfsfeed ORM object with empty/None fields
minimal_feed_orm = Gtfsfeed(
id="id",
data_type="gtfs",
feed_name="",
note=None,
producer_url=None,
authentication_type=None,
authentication_info_url=None,
api_key_parameter_name=None,
license_url=None,
stable_id="stable_id",
status=None,
feed_contact_email=None,
provider=None,
locations=[],
externalids=[],
latest_dataset=None,
gtfsdatasets=[],
redirectingids=[],
gtfs_rt_feeds=[],
)
minimal_expected_result = GtfsFeedImpl(
id="stable_id",
data_type="gtfs",
status=None,
external_ids=[],
provider=None,
feed_name="",
note=None,
feed_contact_email=None,
source_info=SourceInfo(
producer_url=None,
authentication_type=None,
authentication_info_url=None,
api_key_parameter_name=None,
license_url=None,
),
redirects=[],
locations=[],
latest_dataset=None,
)
result = GtfsFeedImpl.from_orm(minimal_feed_orm)
assert result == minimal_expected_result
13 changes: 6 additions & 7 deletions api/tests/unittest/test_feeds.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,7 @@ def test_gtfs_feeds_get_no_bounding_box(client: TestClient, mocker):
"""
mock_select = mocker.patch.object(Database(), "select")
mock_feed = Feed(stable_id="test_gtfs_id")
mock_latest_datasets = Gtfsdataset(stable_id="test_latest_dataset_id", hosted_url="test_hosted_url", latest=True)
mock_latest_datasets = Gtfsdataset(stable_id="test_latest_dataset_id", hosted_url="test_hosted_url")

mock_select.return_value = [
[
Expand Down Expand Up @@ -296,18 +296,17 @@ def assert_gtfs(gtfs_feed, response_gtfs_feed):
), f'Response feed municipality was {response_gtfs_feed["locations"][0]["municipality"]} \
instead of {gtfs_feed.locations[0].municipality}'
# It seems the results are not always in the same order, so find the latest instead of using a hardcoded index
latest_dataset = next((dataset for dataset in gtfs_feed.gtfsdatasets if dataset.latest), None)
if latest_dataset is not None:
# latest_dataset = next((dataset for dataset in gtfs_feed.gtfsdatasets if dataset.latest), None)
if gtfs_feed.latest_dataset is not None:
assert (
response_gtfs_feed["latest_dataset"]["id"] == latest_dataset.stable_id
response_gtfs_feed["latest_dataset"]["id"] == gtfs_feed.latest_dataset.stable_id
), f'Response feed latest dataset id was {response_gtfs_feed["latest_dataset"]["id"]} \
instead of {latest_dataset.stable_id}'
instead of {gtfs_feed.latest_dataset.stable_id}'
else:
raise Exception("No latest dataset found")

latest_dataset = next(filter(lambda x: x.latest, gtfs_feed.gtfsdatasets))
assert (
response_gtfs_feed["latest_dataset"]["hosted_url"] == latest_dataset.hosted_url
response_gtfs_feed["latest_dataset"]["hosted_url"] == gtfs_feed.latest_dataset.hosted_url
), f'Response feed hosted url was {response_gtfs_feed["latest_dataset"]["hosted_url"]} \
instead of test_hosted_url'
assert response_gtfs_feed["latest_dataset"]["bounding_box"] is not None, "Response feed bounding_box was None"
Expand Down
4 changes: 1 addition & 3 deletions functions-python/batch_datasets/src/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@
from google.cloud import pubsub_v1
from google.cloud.pubsub_v1 import PublisherClient
from google.cloud.pubsub_v1.futures import Future
from sqlalchemy import or_
from sqlalchemy.orm import Session

from shared.database_gen.sqlacodegen_models import Gtfsfeed, Gtfsdataset
Expand Down Expand Up @@ -87,9 +86,8 @@ def get_non_deprecated_feeds(
Gtfsdataset.hash.label("dataset_hash"),
)
.select_from(Gtfsfeed)
.outerjoin(Gtfsdataset, (Gtfsdataset.feed_id == Gtfsfeed.id))
.outerjoin(Gtfsdataset, (Gtfsfeed.latest_dataset_id == Gtfsdataset.id))
.filter(Gtfsfeed.status != "deprecated")
.filter(or_(Gtfsdataset.id.is_(None), Gtfsdataset.latest.is_(True)))
)
if feed_stable_ids:
# If feed_stable_ids are provided, filter the query by stable IDs
Expand Down
6 changes: 4 additions & 2 deletions functions-python/batch_datasets/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,10 +84,10 @@ def populate_database(db_session: Session | None = None):
# GTFS datasets leaving one active feed without a dataset
active_gtfs_feeds = db_session.query(Gtfsfeed).all()
for i in range(1, 9):
id = fake.uuid4()
gtfs_dataset = Gtfsdataset(
id=fake.uuid4(),
id=id,
feed_id=active_gtfs_feeds[i].id,
latest=True,
bounding_box="POLYGON((-180 -90, -180 90, 180 90, 180 -90, -180 -90))",
hosted_url=fake.url(),
note=fake.sentence(),
Expand All @@ -96,6 +96,8 @@ def populate_database(db_session: Session | None = None):
stable_id=fake.uuid4(),
)
db_session.add(gtfs_dataset)
db_session.flush()
active_gtfs_feeds[i].latest_dataset_id = id

db_session.flush()
# GTFS Realtime feeds
Expand Down
Loading
Loading