Skip to content

Commit b31c7fc

Browse files
committed
added test
1 parent d1ed5af commit b31c7fc

1 file changed

Lines changed: 43 additions & 0 deletions

File tree

api/tests/unittest/models/test_gtfs_rt_feed_impl.py

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import unittest
22
import copy
3+
from unittest.mock import MagicMock
34

45
from shared.database_gen.sqlacodegen_models import (
56
Gtfsrealtimefeed,
@@ -8,6 +9,7 @@
89
Location,
910
Redirectingid,
1011
Feed,
12+
Gtfsfeed,
1113
)
1214
from feeds_gen.models.source_info import SourceInfo
1315
from shared.db_models.gtfs_rt_feed_impl import GtfsRTFeedImpl
@@ -100,6 +102,47 @@ def test_from_orm_all_fields(self):
100102
result = GtfsRTFeedImpl.from_orm(gtfs_rt_feed_orm)
101103
assert result == expected_gtfs_rt_feed_result
102104

105+
def test_from_orm_feed_references_location_filter(self):
106+
"""
107+
Test that feed_references are correctly filtered based on shared locations.
108+
"""
109+
# Define locations
110+
location_de = Location(id="loc_de", country_code="US", subdivision_name="Delaware")
111+
location_md = Location(id="loc_md", country_code="US", subdivision_name="Maryland")
112+
location_ia = Location(id="loc_ia", country_code="US", subdivision_name="Iowa")
113+
114+
# Define the GTFS-RT feed (e.g., mdb-1771)
115+
rt_feed = Gtfsrealtimefeed(
116+
stable_id="mdb-1771",
117+
provider="DART",
118+
locations=[location_de, location_md],
119+
entitytypes=[],
120+
)
121+
122+
# Define a correct related schedule feed (e.g., mdb-1235)
123+
correct_schedule_feed = Gtfsfeed(stable_id="mdb-1235", provider="DART", locations=[location_de, location_md])
124+
125+
# Define an incorrect schedule feed with a different location (e.g., mdb-193)
126+
incorrect_schedule_feed = Gtfsfeed(stable_id="mdb-193", provider="DART", locations=[location_ia])
127+
128+
# Mock the database session and its query
129+
mock_session = MagicMock()
130+
mock_query = MagicMock()
131+
mock_session.query.return_value = mock_query
132+
# The query inside from_orm should return both schedule feeds before filtering
133+
mock_query.filter.return_value.options.return_value.all.return_value = [
134+
correct_schedule_feed,
135+
incorrect_schedule_feed,
136+
]
137+
138+
# Execute the method
139+
result = GtfsRTFeedImpl.from_orm(rt_feed, db_session=mock_session)
140+
141+
# Assert that only the correct feed reference is included
142+
self.assertIn("mdb-1235", result.feed_references)
143+
self.assertNotIn("mdb-193", result.feed_references)
144+
self.assertEqual(len(result.feed_references), 1)
145+
103146
def test_from_orm_empty_fields(self):
104147
"""Test the `from_orm` method with not provided fields."""
105148
# Test with empty fields and None values

0 commit comments

Comments
 (0)