diff --git a/api/src/feeds/impl/feeds_api_impl.py b/api/src/feeds/impl/feeds_api_impl.py index 67ce62098..52bcfe732 100644 --- a/api/src/feeds/impl/feeds_api_impl.py +++ b/api/src/feeds/impl/feeds_api_impl.py @@ -2,13 +2,11 @@ from typing import List, Union, TypeVar, Optional from sqlalchemy import or_ -from sqlalchemy import select -from sqlalchemy.orm import joinedload, contains_eager, selectinload, Session +from sqlalchemy.orm import contains_eager, selectinload, Session from sqlalchemy.orm.query import Query from feeds.impl.datasets_api_impl import DatasetsApiImpl from feeds.impl.error_handling import raise_http_error, raise_http_validation_error, convert_exception -from shared.db_models.entity_type_enum import EntityType from shared.db_models.feed_impl import FeedImpl from shared.db_models.gbfs_feed_impl import GbfsFeedImpl from shared.db_models.gtfs_feed_impl import GtfsFeedImpl @@ -23,7 +21,7 @@ from shared.common.db_utils import ( get_gtfs_feeds_query, get_gtfs_rt_feeds_query, - get_joinedload_options, + get_selectinload_options, add_official_filter, get_gbfs_feeds_query, ) @@ -41,13 +39,10 @@ Gtfsdataset, Gtfsfeed, Gtfsrealtimefeed, - Location, - Entitytype, ) from shared.feed_filters.feed_filter import FeedFilter from shared.feed_filters.gtfs_dataset_filter import GtfsDatasetFilter -from shared.feed_filters.gtfs_feed_filter import LocationFilter -from shared.feed_filters.gtfs_rt_feed_filter import GtfsRtFeedFilter, EntityTypeFilter +from shared.feed_filters.gtfs_rt_feed_filter import GtfsRtFeedFilter from utils.date_utils import valid_iso_date from utils.logger import get_logger @@ -120,7 +115,7 @@ def get_feeds( # Results are sorted by provider feed_query = feed_query.order_by(FeedOrm.provider, FeedOrm.stable_id) # Ensure license relationship is available to the model conversion without extra queries - feed_query = feed_query.options(*get_joinedload_options(), selectinload(FeedOrm.license)) + feed_query = feed_query.options(*get_selectinload_options(), selectinload(FeedOrm.license)) if limit is not None: feed_query = feed_query.limit(limit) if offset is not None: @@ -251,11 +246,10 @@ def get_gtfs_rt_feed(self, id: str, db_session: Session) -> GtfsRTFeed: not is_user_email_restricted(), # Allow all feeds to be returned if the user is not restricted ) ) - .outerjoin(Location, Gtfsrealtimefeed.locations) .options( - joinedload(Gtfsrealtimefeed.entitytypes), - joinedload(Gtfsrealtimefeed.gtfs_feeds), - *get_joinedload_options(), + selectinload(Gtfsrealtimefeed.entitytypes), + selectinload(Gtfsrealtimefeed.gtfs_feeds), + *get_selectinload_options(), ) ).all() @@ -299,61 +293,11 @@ def get_gtfs_rt_feeds( return self._get_response(feed_query, GtfsRTFeedImpl) - entity_types_list = entity_types.split(",") if entity_types else None - - # Validate entity types using the EntityType enum - if entity_types_list: - try: - entity_types_list = [EntityType(et.strip()).value for et in entity_types_list] - except ValueError: - raise_http_validation_error( - "Entity types must be the value 'vp,' 'sa,' or 'tu,'. " - "When provided a list values must be separated by commas." - ) - - gtfs_rt_feed_filter = GtfsRtFeedFilter( - stable_id=None, - provider__ilike=provider, - producer_url__ilike=producer_url, - entity_types=EntityTypeFilter(name__in=entity_types_list), - location=LocationFilter( - country_code=country_code, - subdivision_name__ilike=subdivision_name, - municipality__ilike=municipality, - ), - ) - subquery = gtfs_rt_feed_filter.filter( - select(Gtfsrealtimefeed.id) - .join(Location, Gtfsrealtimefeed.locations) - .join(Entitytype, Gtfsrealtimefeed.entitytypes) - ).subquery() - feed_query = ( - db_session.query(Gtfsrealtimefeed) - .filter(Gtfsrealtimefeed.id.in_(subquery)) - .filter( - or_( - Gtfsrealtimefeed.operational_status == "published", - not is_user_email_restricted(), # Allow all feeds to be returned if the user is not restricted - ) - ) - .options( - joinedload(Gtfsrealtimefeed.entitytypes), - joinedload(Gtfsrealtimefeed.gtfs_feeds), - *get_joinedload_options(), - ) - .order_by(Gtfsrealtimefeed.provider, Gtfsrealtimefeed.stable_id) - ) - feed_query = add_official_filter(feed_query, is_official) - - feed_query = feed_query.limit(limit).offset(offset) - return self._get_response(feed_query, GtfsRTFeedImpl) - @staticmethod def _get_response(feed_query: Query, impl_cls: type[T]) -> List[T]: """Get the response for the feed query.""" results = feed_query.all() - response = [impl_cls.from_orm(feed) for feed in results] - return list({feed.id: feed for feed in response}.values()) + return [impl_cls.from_orm(feed) for feed in results] @with_db_session def get_gtfs_feed_gtfs_rt_feeds(self, id: str, db_session: Session) -> List[GtfsRTFeed]: diff --git a/api/src/shared/common/db_utils.py b/api/src/shared/common/db_utils.py index bb7d3c61f..07bb58151 100644 --- a/api/src/shared/common/db_utils.py +++ b/api/src/shared/common/db_utils.py @@ -81,11 +81,15 @@ def get_gtfs_feeds_query( if include_options_for_joinedload: feed_query = feed_query.options( - joinedload(Gtfsfeed.latest_dataset) - .joinedload(Gtfsdataset.validation_reports) - .joinedload(Validationreport.features), - joinedload(Gtfsfeed.visualization_dataset), - *get_joinedload_options(), + # Use selectinload for all collection relationships to avoid a cartesian-product row + # explosion when multiple one-to-many associations are loaded simultaneously. + # joinedload on collections multiplies rows (N feeds × M locations × F features …); + # selectinload issues a separate IN-query per relationship, keeping rows at N per query. + selectinload(Gtfsfeed.latest_dataset) + .selectinload(Gtfsdataset.validation_reports) + .selectinload(Validationreport.features), + joinedload(Gtfsfeed.visualization_dataset), # scalar (many-to-one) — joinedload is safe + *get_selectinload_options(), ).order_by(Gtfsfeed.provider, Gtfsfeed.stable_id) feed_query = feed_query.limit(limit).offset(offset) @@ -274,9 +278,9 @@ def get_gtfs_rt_feeds_query( feed_query = feed_query.filter(Gtfsrealtimefeed.operational_status == "published") feed_query = feed_query.options( - joinedload(Gtfsrealtimefeed.entitytypes), - joinedload(Gtfsrealtimefeed.gtfs_feeds), - *get_joinedload_options(), + selectinload(Gtfsrealtimefeed.entitytypes), + selectinload(Gtfsrealtimefeed.gtfs_feeds), + *get_selectinload_options(), ) feed_query = add_official_filter(feed_query, is_official)