Skip to content

Commit 8dd6737

Browse files
authored
feat: added license filtering to search endpoint (#1593)
1 parent a848501 commit 8dd6737

File tree

9 files changed

+191
-7
lines changed

9 files changed

+191
-7
lines changed

api/src/feeds/impl/search_api_impl.py

Lines changed: 44 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ def get_parsed_search_tsquery(search_query: str) -> str:
3232

3333
@staticmethod
3434
def add_search_query_filters(
35-
query, search_query, data_type, feed_id, status, is_official, features, version
35+
query, search_query, data_type, feed_id, status, is_official, features, version, license_ids, license_is_spdx
3636
) -> Query:
3737
"""
3838
Add filters to the search query.
@@ -68,6 +68,19 @@ def add_search_query_filters(
6868
query = query.filter(
6969
t_feedsearch.c.document.op("@@")(SearchApiImpl.get_parsed_search_tsquery(search_query))
7070
)
71+
if license_ids:
72+
license_ids_list = [lid.strip() for lid in license_ids.split(",") if len(lid.strip()) > 0]
73+
if len(license_ids_list) > 0:
74+
query = query.where(t_feedsearch.c.license_id.in_(license_ids_list))
75+
76+
if license_is_spdx is not None:
77+
if license_is_spdx:
78+
query = query.where(t_feedsearch.c.license_is_spdx.is_(True))
79+
else:
80+
query = query.where(
81+
or_(t_feedsearch.c.license_is_spdx.is_(False), t_feedsearch.c.license_is_spdx.is_(None))
82+
)
83+
7184
# Add feature filter with OR logic
7285
if features:
7386
features_list = [s.strip() for s in features[0].split(",") if s]
@@ -86,13 +99,24 @@ def create_count_search_query(
8699
features,
87100
version: str,
88101
search_query: str,
102+
license_ids: str,
103+
license_is_spdx: bool,
89104
) -> Query:
90105
"""
91106
Create a search query for the database.
92107
"""
93108
query = select(func.count(t_feedsearch.c.feed_id))
94109
return SearchApiImpl.add_search_query_filters(
95-
query, search_query, data_type, feed_id, status, is_official, features, version
110+
query,
111+
search_query,
112+
data_type,
113+
feed_id,
114+
status,
115+
is_official,
116+
features,
117+
version,
118+
license_ids,
119+
license_is_spdx,
96120
)
97121

98122
@staticmethod
@@ -104,6 +128,8 @@ def create_search_query(
104128
search_query: str,
105129
features: List[str],
106130
version: str,
131+
license_ids: str,
132+
license_is_spdx: bool,
107133
) -> Query:
108134
"""
109135
Create a search query for the database.
@@ -117,7 +143,16 @@ def create_search_query(
117143
*feed_search_columns,
118144
)
119145
query = SearchApiImpl.add_search_query_filters(
120-
query, search_query, data_type, feed_id, status, is_official, features, version
146+
query,
147+
search_query,
148+
data_type,
149+
feed_id,
150+
status,
151+
is_official,
152+
features,
153+
version,
154+
license_ids,
155+
license_is_spdx,
121156
)
122157
# If search query is provided, use it as secondary sort after timestamp
123158
if search_query and len(search_query.strip()) > 0:
@@ -140,10 +175,14 @@ def search_feeds(
140175
version: str,
141176
search_query: str,
142177
feature: List[str],
178+
license_ids: str,
179+
license_is_spdx: bool,
143180
db_session: "Session",
144181
) -> SearchFeeds200Response:
145182
"""Search feeds using full-text search on feed, location and provider's information."""
146-
query = self.create_search_query(status, feed_id, data_type, is_official, search_query, feature, version)
183+
query = self.create_search_query(
184+
status, feed_id, data_type, is_official, search_query, feature, version, license_ids, license_is_spdx
185+
)
147186
feed_rows = Database().select(
148187
session=db_session,
149188
query=query,
@@ -153,7 +192,7 @@ def search_feeds(
153192
feed_total_count = Database().select(
154193
session=db_session,
155194
query=self.create_count_search_query(
156-
status, feed_id, data_type, is_official, feature, version, search_query
195+
status, feed_id, data_type, is_official, feature, version, search_query, license_ids, license_is_spdx
157196
),
158197
)
159198
if feed_rows is None or feed_total_count is None:

api/src/scripts/populate_db_test_data.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
Gbfsendpoint,
2121
Gbfsfeed,
2222
Rule,
23+
Feed,
2324
)
2425
from scripts.populate_db import set_up_configs, DatabasePopulateHelper
2526
from typing import TYPE_CHECKING
@@ -83,6 +84,29 @@ def populate_test_datasets(self, filepath, db_session: "Session"):
8384
db_session.add(license_obj)
8485
db_session.commit()
8586

87+
# Link licenses to feeds if specified
88+
if "feed_licenses" in data:
89+
for lf in data["feed_licenses"]:
90+
license_id = lf.get("license_id")
91+
feed_stable_id = lf.get("feed_stable_id")
92+
if not license_id or not feed_stable_id:
93+
continue
94+
license_obj = db_session.get(License, license_id)
95+
if not license_obj:
96+
self.logger.error(
97+
f"No license found with id: {license_id}; skipping license_feed for feed " f"{feed_stable_id}"
98+
)
99+
continue
100+
feed_obj = db_session.query(Feed).filter(Feed.stable_id == feed_stable_id).one_or_none()
101+
if not feed_obj:
102+
self.logger.error(
103+
f"No feed found with stable_id: {feed_stable_id}; skipping license_feed for"
104+
f" license {license_id}"
105+
)
106+
continue
107+
feed_obj.license = license_obj
108+
db_session.commit()
109+
86110
# Rules (optional section to seed rule metadata used by license_rules)
87111
if "rules" in data:
88112
for rule in data["rules"]:

api/src/shared/db_models/search_feed_item_result_impl.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,8 @@ def from_orm_gtfs(cls, feed_search_row: t_feedsearch):
4040
authentication_info_url=feed_search_row.authentication_info_url,
4141
api_key_parameter_name=feed_search_row.api_key_parameter_name,
4242
license_url=feed_search_row.license_url,
43+
license_id=feed_search_row.license_id,
44+
license_is_spdx=feed_search_row.license_is_spdx,
4345
),
4446
redirects=feed_search_row.redirect_ids,
4547
locations=cls.resolve_locations(feed_search_row.locations),
@@ -91,6 +93,8 @@ def from_orm_gbfs(cls, feed_search_row):
9193
authentication_info_url=feed_search_row.authentication_info_url,
9294
api_key_parameter_name=feed_search_row.api_key_parameter_name,
9395
license_url=feed_search_row.license_url,
96+
license_id=feed_search_row.license_id,
97+
license_is_spdx=feed_search_row.license_is_spdx,
9498
),
9599
redirects=feed_search_row.redirect_ids,
96100
locations=cls.resolve_locations(feed_search_row.locations),
@@ -118,6 +122,8 @@ def from_orm_gtfs_rt(cls, feed_search_row):
118122
authentication_info_url=feed_search_row.authentication_info_url,
119123
api_key_parameter_name=feed_search_row.api_key_parameter_name,
120124
license_url=feed_search_row.license_url,
125+
license_id=feed_search_row.license_id,
126+
license_is_spdx=feed_search_row.license_is_spdx,
121127
),
122128
redirects=feed_search_row.redirect_ids,
123129
locations=cls.resolve_locations(feed_search_row.locations),

api/tests/integration/test_search_api.py

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -470,6 +470,49 @@ def test_search_filter_by_versions(client: TestClient, values: dict):
470470
), f"There should be {expected_count} feeds for versions={values['versions']}"
471471

472472

473+
@pytest.mark.parametrize(
474+
"values",
475+
[
476+
{"license_ids": "CC BY 4.0", "expected_count": 1},
477+
{"license_ids": "ODbL-1.0", "expected_count": 1},
478+
{"license_ids": "ODbL-1.0,CC BY 4.0", "expected_count": 2},
479+
{"license_ids": "", "expected_count": 16},
480+
],
481+
ids=[
482+
"License ID CC BY 4.0",
483+
"License ID ODbL-1.0",
484+
"License IDs ODbL-1.0 and CC BY 4.0",
485+
"No license IDs specified",
486+
],
487+
)
488+
def test_search_filter_by_license_ids(client: TestClient, values: dict):
489+
"""
490+
Retrieve feeds whose license ID matches one of the given license IDs.
491+
"""
492+
params = None
493+
if values["license_ids"] is not None:
494+
params = [
495+
("license_ids", values["license_ids"]),
496+
]
497+
headers = {
498+
"Authentication": "special-key",
499+
}
500+
response = client.request(
501+
"GET",
502+
"/v1/search",
503+
headers=headers,
504+
params=params,
505+
)
506+
# Assert the status code of the HTTP response
507+
assert response.status_code == 200
508+
# Parse the response body into a Python object
509+
response_body = SearchFeeds200Response.parse_obj(response.json())
510+
expected_count = values["expected_count"]
511+
assert (
512+
response_body.total == expected_count
513+
), f"There should be {expected_count} feeds for license_ids={values['license_ids']}"
514+
515+
473516
@pytest.mark.parametrize(
474517
"values",
475518
[

api/tests/test_data/extra_test_data.json

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -825,5 +825,29 @@
825825
}
826826
]
827827
}
828+
],
829+
"licenses": [
830+
{
831+
"id": "CC BY 4.0",
832+
"is_spdx": true,
833+
"name": "Creative Commons Attribution 4.0 International",
834+
"url": "https://creativecommons.org/licenses/by/4.0/"
835+
},
836+
{
837+
"id": "ODbL-1.0",
838+
"is_spdx": true,
839+
"name": "Open Data Commons Open Database License (ODbL) v1.0",
840+
"url": "https://opendatacommons.org/licenses/odbl/1.0/"
841+
}
842+
],
843+
"feed_licenses": [
844+
{
845+
"feed_stable_id": "mdb-1",
846+
"license_id": "CC BY 4.0"
847+
},
848+
{
849+
"feed_stable_id": "gbfs-system_id_1",
850+
"license_id": "ODbL-1.0"
851+
}
828852
]
829853
}

api/tests/unittest/models/test_search_feed_item_result_impl.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,8 @@ def __init__(self, **kwargs):
6060
country_translations=[],
6161
subdivision_name_translations=[],
6262
municipality_translations=[],
63+
license_id=None,
64+
license_is_spdx=None,
6365
)
6466

6567

docs/DatabaseCatalogAPI.yaml

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -338,6 +338,8 @@ paths:
338338
- $ref: "#/components/parameters/version_query_param"
339339
- $ref: "#/components/parameters/search_text_query_param"
340340
- $ref: "#/components/parameters/feature"
341+
- $ref: "#/components/parameters/license_ids"
342+
- $ref: "#/components/parameters/license_is_spdx"
341343
security:
342344
- Authentication: []
343345
responses:
@@ -1485,6 +1487,21 @@ components:
14851487
type: array
14861488
items:
14871489
type: string
1490+
license_ids:
1491+
name: license_ids
1492+
in: query
1493+
description: Comma-separated list of license IDs to filter feeds by their license.
1494+
required: false
1495+
schema:
1496+
type: string
1497+
example: CC-BY-4.0,ODbL-1.0
1498+
license_is_spdx:
1499+
name: license_is_spdx
1500+
in: query
1501+
description: Filter feeds by whether their license is an SPDX license. When false, feeds with no license or with an unknown SPDX status are also returned.
1502+
required: false
1503+
schema:
1504+
type: boolean
14881505
provider:
14891506
name: provider
14901507
in: query

liquibase/materialized_views/feed_search.sql

Lines changed: 25 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ SELECT
88
Feed.feed_name,
99
Feed.note,
1010
Feed.feed_contact_email,
11+
1112
-- source
1213
Feed.producer_url,
1314
Feed.authentication_info_url,
@@ -16,10 +17,18 @@ SELECT
1617
Feed.license_url,
1718
Feed.provider,
1819
Feed.operational_status,
20+
1921
-- official status
2022
Feed.official AS official,
23+
2124
-- created_at
2225
Feed.created_at AS created_at,
26+
27+
-- license fields
28+
Feed.license_id AS license_id,
29+
License.is_spdx AS license_is_spdx,
30+
License.name AS license_name,
31+
2332
-- latest_dataset
2433
Latest_dataset.stable_id AS latest_dataset_id,
2534
Latest_dataset.hosted_url AS latest_dataset_hosted_url,
@@ -29,28 +38,37 @@ SELECT
2938
Latest_dataset.agency_timezone AS latest_dataset_agency_timezone,
3039
Latest_dataset.service_date_range_start AS latest_dataset_service_date_range_start,
3140
Latest_dataset.service_date_range_end AS latest_dataset_service_date_range_end,
41+
3242
-- Latest dataset features
3343
LatestDatasetFeatures AS latest_dataset_features,
44+
3445
-- Latest dataset validation totals
3546
COALESCE(LatestDatasetValidationReportJoin.total_error, 0) as latest_total_error,
3647
COALESCE(LatestDatasetValidationReportJoin.total_warning, 0) as latest_total_warning,
3748
COALESCE(LatestDatasetValidationReportJoin.total_info, 0) as latest_total_info,
3849
COALESCE(LatestDatasetValidationReportJoin.unique_error_count, 0) as latest_unique_error_count,
3950
COALESCE(LatestDatasetValidationReportJoin.unique_warning_count, 0) as latest_unique_warning_count,
4051
COALESCE(LatestDatasetValidationReportJoin.unique_info_count, 0) as latest_unique_info_count,
52+
4153
-- external_ids
4254
ExternalIdJoin.external_ids,
55+
4356
-- redirect_ids
4457
RedirectingIdJoin.redirect_ids,
58+
4559
-- feed gtfs_rt references
4660
FeedReferenceJoin.feed_reference_ids,
61+
4762
-- feed gtfs_rt entities
4863
EntityTypeFeedJoin.entities,
64+
4965
-- locations
5066
FeedLocationJoin.locations,
67+
5168
-- osm locations grouped
5269
OsmLocationJoin.osm_locations,
53-
-- gbfs versions
70+
71+
-- gbfs versions
5472
COALESCE(GbfsVersionsJoin.versions, '[]'::jsonb) AS versions,
5573

5674
-- full-text searchable document
@@ -70,6 +88,9 @@ SELECT
7088
AS document
7189
FROM Feed
7290

91+
-- license join
92+
LEFT JOIN License ON License.id = Feed.license_id
93+
7394
-- Latest dataset
7495
LEFT JOIN gtfsfeed gtf ON gtf.id = Feed.id AND Feed.data_type = 'gtfs'
7596
LEFT JOIN gtfsdataset Latest_dataset ON Latest_dataset.id = gtf.latest_dataset_id
@@ -149,7 +170,6 @@ LEFT JOIN (
149170
GROUP BY gtfs_rt_feed_id
150171
) AS FeedReferenceJoin ON FeedReferenceJoin.gtfs_rt_feed_id = Feed.id AND Feed.data_type = 'gtfs_rt'
151172

152-
-- Redirect ids
153173
-- Redirect ids
154174
LEFT JOIN (
155175
SELECT
@@ -159,6 +179,7 @@ LEFT JOIN (
159179
JOIN Feed f ON r.target_id = f.id
160180
GROUP BY r.target_id
161181
) AS RedirectingIdJoin ON RedirectingIdJoin.target_id = Feed.id
182+
162183
-- Feed locations
163184
LEFT JOIN (
164185
SELECT
@@ -247,4 +268,6 @@ CREATE INDEX feedsearch_document_idx ON FeedSearch USING GIN(document);
247268
CREATE INDEX feedsearch_feed_stable_id ON FeedSearch(feed_stable_id);
248269
CREATE INDEX feedsearch_data_type ON FeedSearch(data_type);
249270
CREATE INDEX feedsearch_status ON FeedSearch(status);
271+
CREATE INDEX feedsearch_license_id ON FeedSearch(license_id);
272+
CREATE INDEX feedsearch_license_is_spdx ON FeedSearch(license_is_spdx);
250273

0 commit comments

Comments
 (0)