Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 47 additions & 3 deletions api/src/feeds/impl/search_api_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,17 @@ def get_parsed_search_tsquery(search_query: str) -> str:

@staticmethod
def add_search_query_filters(
query, search_query, data_type, feed_id, status, is_official, features, version, license_ids, license_is_spdx
query,
search_query,
data_type,
feed_id,
status,
is_official,
features,
version,
license_ids,
license_is_spdx,
license_tags,
) -> Query:
"""
Add filters to the search query.
Expand Down Expand Up @@ -81,6 +91,17 @@ def add_search_query_filters(
or_(t_feedsearch.c.license_is_spdx.is_(False), t_feedsearch.c.license_is_spdx.is_(None))
)

if license_tags:
tag_ids_list = [tid.strip() for tid in license_tags.split(",") if len(tid.strip()) > 0]
if len(tag_ids_list) > 0:
# license_tags is a text[] column – use the @> (contains) operator
# so that ALL requested tags must be present (AND semantics).
query = query.where(
t_feedsearch.c.license_tags.op("@>")(
func.cast(array(tag_ids_list), t_feedsearch.c.license_tags.type)
)
)

# Add feature filter with OR logic
if features:
features_list = [s.strip() for s in features[0].split(",") if s]
Expand All @@ -101,6 +122,7 @@ def create_count_search_query(
search_query: str,
license_ids: str,
license_is_spdx: bool,
license_tags: str,
) -> Query:
"""
Create a search query for the database.
Expand All @@ -117,6 +139,7 @@ def create_count_search_query(
version,
license_ids,
license_is_spdx,
license_tags,
)

@staticmethod
Expand All @@ -130,6 +153,7 @@ def create_search_query(
version: str,
license_ids: str,
license_is_spdx: bool,
license_tags: str,
) -> Query:
"""
Create a search query for the database.
Expand All @@ -153,6 +177,7 @@ def create_search_query(
version,
license_ids,
license_is_spdx,
license_tags,
)
# If search query is provided, use it as secondary sort after timestamp
if search_query and len(search_query.strip()) > 0:
Expand All @@ -177,11 +202,21 @@ def search_feeds(
feature: List[str],
license_ids: str,
license_is_spdx: bool,
license_tags: str,
db_session: "Session",
) -> SearchFeeds200Response:
"""Search feeds using full-text search on feed, location and provider's information."""
query = self.create_search_query(
status, feed_id, data_type, is_official, search_query, feature, version, license_ids, license_is_spdx
status,
feed_id,
data_type,
is_official,
search_query,
feature,
version,
license_ids,
license_is_spdx,
license_tags,
)
feed_rows = Database().select(
session=db_session,
Expand All @@ -192,7 +227,16 @@ def search_feeds(
feed_total_count = Database().select(
session=db_session,
query=self.create_count_search_query(
status, feed_id, data_type, is_official, feature, version, search_query, license_ids, license_is_spdx
status,
feed_id,
data_type,
is_official,
feature,
version,
search_query,
license_ids,
license_is_spdx,
license_tags,
),
)
if feed_rows is None or feed_total_count is None:
Expand Down
3 changes: 3 additions & 0 deletions api/src/shared/db_models/search_feed_item_result_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ def from_orm_gtfs(cls, feed_search_row: t_feedsearch):
license_url=feed_search_row.license_url,
license_id=feed_search_row.license_id,
license_is_spdx=feed_search_row.license_is_spdx,
license_tags=(sorted(feed_search_row.license_tags) if feed_search_row.license_tags else []),
),
redirects=feed_search_row.redirect_ids,
locations=cls.resolve_locations(feed_search_row.locations),
Expand Down Expand Up @@ -95,6 +96,7 @@ def from_orm_gbfs(cls, feed_search_row):
license_url=feed_search_row.license_url,
license_id=feed_search_row.license_id,
license_is_spdx=feed_search_row.license_is_spdx,
license_tags=(sorted(feed_search_row.license_tags) if feed_search_row.license_tags else []),
),
redirects=feed_search_row.redirect_ids,
locations=cls.resolve_locations(feed_search_row.locations),
Expand Down Expand Up @@ -124,6 +126,7 @@ def from_orm_gtfs_rt(cls, feed_search_row):
license_url=feed_search_row.license_url,
license_id=feed_search_row.license_id,
license_is_spdx=feed_search_row.license_is_spdx,
license_tags=(sorted(feed_search_row.license_tags) if feed_search_row.license_tags else []),
),
redirects=feed_search_row.redirect_ids,
locations=cls.resolve_locations(feed_search_row.locations),
Expand Down
69 changes: 69 additions & 0 deletions api/tests/integration/test_search_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -564,3 +564,72 @@ def test_search_filter_by_feature(client: TestClient, values: dict):
assert requested_features.intersection(features), (
f"Feed {result.id} with features {features} does not match " f"requested features {requested_features}"
)


@pytest.mark.parametrize(
    "values",
    [
        {"license_tags": "family:ODC", "expected_count": 1},
        {"license_tags": "license:open-data-commons", "expected_count": 1},
        # AND semantics: feed must contain ALL requested tags
        {"license_tags": "family:ODC,license:open-data-commons", "expected_count": 1},
        {"license_tags": "nonexistent:tag", "expected_count": 0},
        {"license_tags": "license:open-data-commons,nonexistent:tag", "expected_count": 0},
        {"license_tags": "", "expected_count": 16},
    ],
    ids=[
        "Filter by single tag family:ODC",
        "Filter by single tag license:open-data-commons",
        "Filter by multiple tags (AND semantics)",
        "No feed matches nonexistent tag",
        "Mixed existing and nonexistent tag (AND semantics)",
        "No filter returns all feeds",
    ],
)
def test_search_filter_by_license_tags(client: TestClient, values: dict):
    """Retrieve feeds that have licenses associated with specific license tag IDs.

    The ``license_tags`` parameter accepts a comma-separated list of tag IDs.
    The filter uses AND semantics: the feed's ``license_tags`` array must
    contain **all** of the requested tags for the feed to be returned.

    Both the total count and the tags carried by every returned feed are
    verified, so a filter that returns the wrong feed with the right count
    is still detected.
    """
    raw_tags = values["license_tags"]
    # An empty value means "no filter": omit the query parameter entirely.
    params = [("license_tags", raw_tags)] if raw_tags else None

    headers = {"Authentication": "special-key"}
    response = client.request("GET", "/v1/search", headers=headers, params=params)

    assert response.status_code == 200

    response_body = SearchFeeds200Response.model_validate(response.json())
    expected_count = values["expected_count"]
    assert (
        response_body.total == expected_count
    ), f"There should be {expected_count} feeds for license_tags={raw_tags}"

    # Mirror the server-side parsing: split on commas, trim, drop empty entries.
    requested_tags = {tag.strip() for tag in raw_tags.split(",") if tag.strip()}
    if requested_tags:
        for result in response_body.results or []:
            source_info = result.source_info
            result_tags = set(source_info.license_tags or []) if source_info else set()
            # AND semantics: every requested tag must be present on the feed.
            assert requested_tags <= result_tags, (
                f"Feed {result.id} with license tags {sorted(result_tags)} does not contain "
                f"all requested tags {sorted(requested_tags)}"
            )


def test_search_result_contains_license_tags(client: TestClient):
    """
    Verify that the search results include license_tags for feeds with license tags.

    The tags are compared as exact list elements (not substrings), and the
    returned list is expected to be in sorted order.
    """
    params = [
        ("feed_id", "mdb-70"),
    ]
    headers = {
        "Authentication": "special-key",
    }
    response = client.request(
        "GET",
        "/v1/search",
        headers=headers,
        params=params,
    )
    assert response.status_code == 200
    # model_validate is the Pydantic v2 API; parse_obj is deprecated.
    response_body = SearchFeeds200Response.model_validate(response.json())
    assert response_body.total == 1
    license_tags = response_body.results[0].source_info.license_tags
    assert license_tags is not None
    # Membership in the list requires an exact tag match; the previous
    # `"x" in license_tags[i]` form was a substring test on a single string.
    assert "family:ODC" in license_tags
    assert "license:open-data-commons" in license_tags
    # Tags are expected to be returned in a stable, sorted order.
    assert license_tags == sorted(license_tags)
27 changes: 27 additions & 0 deletions api/tests/unittest/models/test_search_feed_item_result_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ def __init__(self, **kwargs):
municipality_translations=[],
license_id=None,
license_is_spdx=None,
license_tags=None,
)


Expand All @@ -86,6 +87,7 @@ def test_from_orm_gtfs(self):
authentication_info_url=item.authentication_info_url,
api_key_parameter_name=item.api_key_parameter_name,
license_url=item.license_url,
license_tags=[],
),
redirects=item.redirect_ids,
locations=item.locations,
Expand Down Expand Up @@ -131,6 +133,7 @@ def test_from_orm_gtfs_rt(self):
authentication_info_url=item.authentication_info_url,
api_key_parameter_name=item.api_key_parameter_name,
license_url=item.license_url,
license_tags=[],
),
redirects=item.redirect_ids,
locations=item.locations,
Expand Down Expand Up @@ -218,3 +221,27 @@ def test_from_orm_locations_country_invalid_code(self):
assert result.locations == [
Location(country_code="XY", country="", subdivision_name="subdivision_name", municipality="municipality")
]

def test_from_orm_license_tags(self):
    """Test that license_tags are populated from the feed_search_row in sorted order."""
    item = copy.deepcopy(search_item)
    item.data_type = "gtfs"
    # Deliberately unsorted input, so the ordering applied while building the
    # result is actually exercised rather than satisfied by chance.
    item.license_tags = ["license:open-data-commons", "family:ODC"]
    result = SearchFeedItemResultImpl.from_orm(item)
    assert result.source_info.license_tags == ["family:ODC", "license:open-data-commons"]

def test_from_orm_license_tags_none(self):
    """A GTFS row whose license_tags is None yields an empty list in source_info."""
    row = copy.deepcopy(search_item)
    row.license_tags = None
    row.data_type = "gtfs"
    source_info = SearchFeedItemResultImpl.from_orm(row).source_info
    assert source_info.license_tags == []

def test_from_orm_gtfs_rt_license_tags(self):
    """A GTFS-RT row carries its license_tags through to source_info."""
    expected_tags = ["family:ODC"]
    row = copy.deepcopy(search_item)
    row.license_tags = ["family:ODC"]
    row.data_type = "gtfs_rt"
    source_info = SearchFeedItemResultImpl.from_orm(row).source_info
    assert source_info.license_tags == expected_tags
9 changes: 9 additions & 0 deletions docs/DatabaseCatalogAPI.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -338,6 +338,7 @@ paths:
- $ref: "#/components/parameters/feature"
- $ref: "#/components/parameters/license_ids"
- $ref: "#/components/parameters/license_is_spdx"
- $ref: "#/components/parameters/license_tags"
security:
- Authentication: []
responses:
Expand Down Expand Up @@ -1516,6 +1517,14 @@ components:
required: false
schema:
type: boolean
license_tags:
name: license_tags
in: query
description: Comma-separated list of license tag IDs. Only feeds whose license has all of the listed tags are returned.
required: false
schema:
type: string
example: family:ODC,license:open-data-commons
provider:
name: provider
in: query
Expand Down
10 changes: 5 additions & 5 deletions liquibase/changelog.xml
Original file line number Diff line number Diff line change
Expand Up @@ -95,15 +95,15 @@
<include file="changes/fix_operation_status_constraint.sql" relativeToChangelogFile="true"/>
<!-- Adding alt_name to the geopolygon table -->
<include file="changes/feat_1479.sql" relativeToChangelogFile="true"/>
<!-- Centralized materialized view definitions.
Views are rebuilt from source SQL files using runOnChange. -->
<!-- Keep this at the end to ensure all table and schema changes
are applied before materialized views are rebuilt. -->
<include file="materialized_views/materialized_views.xml" relativeToChangelogFile="true"/>
<!-- Remove filename as the primary key on databasechangelog table and add composite(id, author, filename) -->
<include file="changes/update_liquibase_changelog.sql" relativeToChangelogFile="true"/>
<!-- Change gtfsfile table hosted_url column type to text to accommodate longer URLs. -->
<include file="changes/feat_1542.sql" relativeToChangelogFile="true"/>
<!-- Add license_tag table and license_license_tags join table for tag classification of licenses. -->
<include file="changes/feat_1565.sql" relativeToChangelogFile="true"/>
<!-- Centralized materialized view definitions.
Views are rebuilt from source SQL files using runOnChange. -->
<!-- Keep this at the very end to ensure all table and schema changes
are applied before materialized views are (re)created. -->
<include file="materialized_views/materialized_views.xml" relativeToChangelogFile="true"/>
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

👍

</databaseChangeLog>
11 changes: 11 additions & 0 deletions liquibase/materialized_views/feed_search.sql
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ SELECT
Feed.license_id AS license_id,
License.is_spdx AS license_is_spdx,
License.name AS license_name,
LicenseTagsJoin.license_tags AS license_tags,

-- latest_dataset
Latest_dataset.stable_id AS latest_dataset_id,
Expand Down Expand Up @@ -91,6 +92,15 @@ FROM Feed
-- license join
LEFT JOIN License ON License.id = Feed.license_id

-- license tags join
LEFT JOIN (
SELECT
llt.license_id,
array_agg(llt.tag_id ORDER BY llt.tag_id) AS license_tags
FROM license_license_tags llt
GROUP BY llt.license_id
) AS LicenseTagsJoin ON LicenseTagsJoin.license_id = Feed.license_id

-- Latest dataset
LEFT JOIN gtfsfeed gtf ON gtf.id = Feed.id AND Feed.data_type = 'gtfs'
LEFT JOIN gtfsdataset Latest_dataset ON Latest_dataset.id = gtf.latest_dataset_id
Expand Down Expand Up @@ -270,4 +280,5 @@ CREATE INDEX feedsearch_data_type ON FeedSearch(data_type);
CREATE INDEX feedsearch_status ON FeedSearch(status);
CREATE INDEX feedsearch_license_id ON FeedSearch(license_id);
CREATE INDEX feedsearch_license_is_spdx ON FeedSearch(license_is_spdx);
CREATE INDEX feedsearch_license_tags ON FeedSearch USING GIN(license_tags);

3 changes: 3 additions & 0 deletions web-app/src/app/services/feeds/types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -974,6 +974,8 @@ export interface components {
license_ids?: string;
/** @description Filter feeds by whether their license is an SPDX license. */
license_is_spdx?: boolean;
/** @description Comma separated list of tags to filter feeds by their license tags. */
license_tags?: string;
/** @description List only feeds with the specified value. Can be a partial match. Case insensitive. */
provider?: string;
/** @description List only feeds with the specified value. Can be a partial match. Case insensitive. */
Expand Down Expand Up @@ -1334,6 +1336,7 @@ export interface operations {
feature?: components['parameters']['feature'];
license_ids?: components['parameters']['license_ids'];
license_is_spdx?: components['parameters']['license_is_spdx'];
license_tags?: components['parameters']['license_tags'];
};
};
responses: {
Expand Down
Loading