diff --git a/api/src/feeds/impl/search_api_impl.py b/api/src/feeds/impl/search_api_impl.py index b01642c22..84c074620 100644 --- a/api/src/feeds/impl/search_api_impl.py +++ b/api/src/feeds/impl/search_api_impl.py @@ -32,7 +32,17 @@ def get_parsed_search_tsquery(search_query: str) -> str: @staticmethod def add_search_query_filters( - query, search_query, data_type, feed_id, status, is_official, features, version, license_ids, license_is_spdx + query, + search_query, + data_type, + feed_id, + status, + is_official, + features, + version, + license_ids, + license_is_spdx, + license_tags, ) -> Query: """ Add filters to the search query. @@ -81,6 +91,17 @@ def add_search_query_filters( or_(t_feedsearch.c.license_is_spdx.is_(False), t_feedsearch.c.license_is_spdx.is_(None)) ) + if license_tags: + tag_ids_list = [tid.strip() for tid in license_tags.split(",") if len(tid.strip()) > 0] + if len(tag_ids_list) > 0: + # license_tags is a text[] column – use the @> (contains) operator + # so that ALL requested tags must be present (AND semantics). + query = query.where( + t_feedsearch.c.license_tags.op("@>")( + func.cast(array(tag_ids_list), t_feedsearch.c.license_tags.type) + ) + ) + # Add feature filter with OR logic if features: features_list = [s.strip() for s in features[0].split(",") if s] @@ -101,6 +122,7 @@ def create_count_search_query( search_query: str, license_ids: str, license_is_spdx: bool, + license_tags: str, ) -> Query: """ Create a search query for the database. @@ -117,6 +139,7 @@ def create_count_search_query( version, license_ids, license_is_spdx, + license_tags, ) @staticmethod @@ -130,6 +153,7 @@ def create_search_query( version: str, license_ids: str, license_is_spdx: bool, + license_tags: str, ) -> Query: """ Create a search query for the database. @@ -153,6 +177,7 @@ def create_search_query( version, license_ids, license_is_spdx, + license_tags, ) # If search query is provided, use it as secondary sort after timestamp if search_query and len(search_query.strip()) > 0: @@ -177,11 +202,21 @@ def search_feeds( feature: List[str], license_ids: str, license_is_spdx: bool, + license_tags: str, db_session: "Session", ) -> SearchFeeds200Response: """Search feeds using full-text search on feed, location and provider's information.""" query = self.create_search_query( - status, feed_id, data_type, is_official, search_query, feature, version, license_ids, license_is_spdx + status, + feed_id, + data_type, + is_official, + search_query, + feature, + version, + license_ids, + license_is_spdx, + license_tags, ) feed_rows = Database().select( session=db_session, @@ -192,7 +227,16 @@ def search_feeds( feed_total_count = Database().select( session=db_session, query=self.create_count_search_query( - status, feed_id, data_type, is_official, feature, version, search_query, license_ids, license_is_spdx + status, + feed_id, + data_type, + is_official, + feature, + version, + search_query, + license_ids, + license_is_spdx, + license_tags, ), ) if feed_rows is None or feed_total_count is None: diff --git a/api/src/shared/db_models/search_feed_item_result_impl.py b/api/src/shared/db_models/search_feed_item_result_impl.py index 5906a143c..2b1ba1148 100644 --- a/api/src/shared/db_models/search_feed_item_result_impl.py +++ b/api/src/shared/db_models/search_feed_item_result_impl.py @@ -42,6 +42,7 @@ def from_orm_gtfs(cls, feed_search_row: t_feedsearch): license_url=feed_search_row.license_url, license_id=feed_search_row.license_id, license_is_spdx=feed_search_row.license_is_spdx, + license_tags=(sorted(feed_search_row.license_tags) if feed_search_row.license_tags else []), ), redirects=feed_search_row.redirect_ids, locations=cls.resolve_locations(feed_search_row.locations), @@ -95,6 +96,7 @@ def from_orm_gbfs(cls, feed_search_row): license_url=feed_search_row.license_url, license_id=feed_search_row.license_id, license_is_spdx=feed_search_row.license_is_spdx, + license_tags=(sorted(feed_search_row.license_tags) if feed_search_row.license_tags else []), ), redirects=feed_search_row.redirect_ids, locations=cls.resolve_locations(feed_search_row.locations), @@ -124,6 +126,7 @@ def from_orm_gtfs_rt(cls, feed_search_row): license_url=feed_search_row.license_url, license_id=feed_search_row.license_id, license_is_spdx=feed_search_row.license_is_spdx, + license_tags=(sorted(feed_search_row.license_tags) if feed_search_row.license_tags else []), ), redirects=feed_search_row.redirect_ids, locations=cls.resolve_locations(feed_search_row.locations), diff --git a/api/tests/integration/test_search_api.py b/api/tests/integration/test_search_api.py index 9765ed5e7..fb52699ad 100644 --- a/api/tests/integration/test_search_api.py +++ b/api/tests/integration/test_search_api.py @@ -564,3 +564,72 @@ def test_search_filter_by_feature(client: TestClient, values: dict): assert requested_features.intersection(features), ( f"Feed {result.id} with features {features} does not match " f"requested features {requested_features}" ) + + +@pytest.mark.parametrize( + "values", + [ + {"license_tags": "family:ODC", "expected_count": 1}, + {"license_tags": "license:open-data-commons", "expected_count": 1}, + # AND semantics: feed must contain ALL requested tags + {"license_tags": "family:ODC,license:open-data-commons", "expected_count": 1}, + {"license_tags": "nonexistent:tag", "expected_count": 0}, + {"license_tags": "license:open-data-commons,nonexistent:tag", "expected_count": 0}, + {"license_tags": "", "expected_count": 16}, + ], + ids=[ + "Filter by single tag family:ODC", + "Filter by single tag license:open-data-commons", + "Filter by multiple tags (AND semantics)", + "No feed matches nonexistent tag", + "Mixed existing and nonexistent tag (AND semantics)", + "No filter returns all feeds", + ], +) +def test_search_filter_by_license_tags(client: TestClient, values: dict): + """Retrieve feeds that have licenses associated with specific license tag IDs. + + The ``license_tags`` parameter accepts a comma-separated list of tag IDs. + The filter uses AND semantics: the feed's ``license_tags`` array must + contain **all** of the requested tags for the feed to be returned. + """ + + params = None + if values["license_tags"]: + params = [("license_tags", values["license_tags"])] + + headers = {"Authentication": "special-key"} + response = client.request("GET", "/v1/search", headers=headers, params=params) + + assert response.status_code == 200 + + response_body = SearchFeeds200Response.model_validate(response.json()) + expected_count = values["expected_count"] + assert ( + response_body.total == expected_count + ), f"There should be {expected_count} feeds for license_tags={values['license_tags']}" + + +def test_search_result_contains_license_tags(client: TestClient): + """ + Verify that the search results include license_tags for feeds with license tags. + """ + params = [ + ("feed_id", "mdb-70"), + ] + headers = { + "Authentication": "special-key", + } + response = client.request( + "GET", + "/v1/search", + headers=headers, + params=params, + ) + assert response.status_code == 200 + response_body = SearchFeeds200Response.parse_obj(response.json()) + assert response_body.total == 1 + result = response_body.results[0] + assert result.source_info.license_tags is not None + assert "family:ODC" in result.source_info.license_tags[0] + assert "license:open-data-commons" in result.source_info.license_tags[1] diff --git a/api/tests/unittest/models/test_search_feed_item_result_impl.py b/api/tests/unittest/models/test_search_feed_item_result_impl.py index 033167124..1d5630056 100644 --- a/api/tests/unittest/models/test_search_feed_item_result_impl.py +++ b/api/tests/unittest/models/test_search_feed_item_result_impl.py @@ -62,6 +62,7 @@ def __init__(self, **kwargs): municipality_translations=[], license_id=None, license_is_spdx=None, + license_tags=None, ) @@ -86,6 +87,7 @@ def test_from_orm_gtfs(self): authentication_info_url=item.authentication_info_url, api_key_parameter_name=item.api_key_parameter_name, license_url=item.license_url, + license_tags=[], ), redirects=item.redirect_ids, locations=item.locations, @@ -131,6 +133,7 @@ def test_from_orm_gtfs_rt(self): authentication_info_url=item.authentication_info_url, api_key_parameter_name=item.api_key_parameter_name, license_url=item.license_url, + license_tags=[], ), redirects=item.redirect_ids, locations=item.locations, @@ -218,3 +221,27 @@ def test_from_orm_locations_country_invalid_code(self): assert result.locations == [ Location(country_code="XY", country="", subdivision_name="subdivision_name", municipality="municipality") ] + + def test_from_orm_license_tags(self): + """Test that license_tags are correctly populated from the feed_search_row.""" + item = copy.deepcopy(search_item) + item.data_type = "gtfs" + item.license_tags = ["family:ODC", "license:open-data-commons"] + result = SearchFeedItemResultImpl.from_orm(item) + assert result.source_info.license_tags == ["family:ODC", "license:open-data-commons"] + + def test_from_orm_license_tags_none(self): + """Test that license_tags defaults to empty list when None.""" + item = copy.deepcopy(search_item) + item.data_type = "gtfs" + item.license_tags = None + result = SearchFeedItemResultImpl.from_orm(item) + assert result.source_info.license_tags == [] + + def test_from_orm_gtfs_rt_license_tags(self): + """Test that license_tags are correctly populated for GTFS-RT feeds.""" + item = copy.deepcopy(search_item) + item.data_type = "gtfs_rt" + item.license_tags = ["family:ODC"] + result = SearchFeedItemResultImpl.from_orm(item) + assert result.source_info.license_tags == ["family:ODC"] diff --git a/docs/DatabaseCatalogAPI.yaml b/docs/DatabaseCatalogAPI.yaml index 2c1a8943c..8fae3c9e5 100644 --- a/docs/DatabaseCatalogAPI.yaml +++ b/docs/DatabaseCatalogAPI.yaml @@ -338,6 +338,7 @@ paths: - $ref: "#/components/parameters/feature" - $ref: "#/components/parameters/license_ids" - $ref: "#/components/parameters/license_is_spdx" + - $ref: "#/components/parameters/license_tags" security: - Authentication: [] responses: @@ -1516,6 +1517,14 @@ components: required: false schema: type: boolean + license_tags: + name: license_tags + in: query + description: Comma separated list of tags to filter feeds by their license tags. + required: false + schema: + type: string + example: family:ODC,license:open-data-commons provider: name: provider in: query diff --git a/liquibase/changelog.xml b/liquibase/changelog.xml index cb639fbda..69d47cc33 100644 --- a/liquibase/changelog.xml +++ b/liquibase/changelog.xml @@ -95,15 +95,15 @@ - - - + + + diff --git a/liquibase/materialized_views/feed_search.sql b/liquibase/materialized_views/feed_search.sql index 30e2d1055..f59df5d18 100644 --- a/liquibase/materialized_views/feed_search.sql +++ b/liquibase/materialized_views/feed_search.sql @@ -28,6 +28,7 @@ SELECT Feed.license_id AS license_id, License.is_spdx AS license_is_spdx, License.name AS license_name, + LicenseTagsJoin.license_tags AS license_tags, -- latest_dataset Latest_dataset.stable_id AS latest_dataset_id, @@ -91,6 +92,15 @@ FROM Feed -- license join LEFT JOIN License ON License.id = Feed.license_id +-- license tags join +LEFT JOIN ( + SELECT + llt.license_id, + array_agg(llt.tag_id ORDER BY llt.tag_id) AS license_tags + FROM license_license_tags llt + GROUP BY llt.license_id +) AS LicenseTagsJoin ON LicenseTagsJoin.license_id = Feed.license_id + -- Latest dataset LEFT JOIN gtfsfeed gtf ON gtf.id = Feed.id AND Feed.data_type = 'gtfs' LEFT JOIN gtfsdataset Latest_dataset ON Latest_dataset.id = gtf.latest_dataset_id @@ -270,4 +280,5 @@ CREATE INDEX feedsearch_data_type ON FeedSearch(data_type); CREATE INDEX feedsearch_status ON FeedSearch(status); CREATE INDEX feedsearch_license_id ON FeedSearch(license_id); CREATE INDEX feedsearch_license_is_spdx ON FeedSearch(license_is_spdx); +CREATE INDEX feedsearch_license_tags ON FeedSearch USING GIN(license_tags); diff --git a/web-app/src/app/services/feeds/types.ts b/web-app/src/app/services/feeds/types.ts index 634ff1441..b4cfb2fb1 100644 --- a/web-app/src/app/services/feeds/types.ts +++ b/web-app/src/app/services/feeds/types.ts @@ -974,6 +974,8 @@ export interface components { license_ids?: string; /** @description Filter feeds by whether their license is an SPDX license. */ license_is_spdx?: boolean; + /** @description Comma separated list of tags to filter feeds by their license tags. */ + license_tags?: string; /** @description List only feeds with the specified value. Can be a partial match. Case insensitive. */ provider?: string; /** @description List only feeds with the specified value. Can be a partial match. Case insensitive. */ @@ -1334,6 +1336,7 @@ export interface operations { feature?: components['parameters']['feature']; license_ids?: components['parameters']['license_ids']; license_is_spdx?: components['parameters']['license_is_spdx']; + license_tags?: components['parameters']['license_tags']; }; }; responses: {