|
1 | 1 | import unittest |
2 | 2 | import copy |
| 3 | +from unittest.mock import MagicMock |
3 | 4 |
|
4 | 5 | from shared.database_gen.sqlacodegen_models import ( |
5 | 6 | Gtfsrealtimefeed, |
|
8 | 9 | Location, |
9 | 10 | Redirectingid, |
10 | 11 | Feed, |
| 12 | + Gtfsfeed, |
11 | 13 | ) |
12 | 14 | from feeds_gen.models.source_info import SourceInfo |
13 | 15 | from shared.db_models.gtfs_rt_feed_impl import GtfsRTFeedImpl |
@@ -100,6 +102,47 @@ def test_from_orm_all_fields(self): |
100 | 102 | result = GtfsRTFeedImpl.from_orm(gtfs_rt_feed_orm) |
101 | 103 | assert result == expected_gtfs_rt_feed_result |
102 | 104 |
|
| 105 | + def test_from_orm_feed_references_location_filter(self): |
| 106 | + """ |
| 107 | + Test that feed_references are correctly filtered based on shared locations. |
| 108 | + """ |
| 109 | + # Define locations |
| 110 | + location_de = Location(id="loc_de", country_code="US", subdivision_name="Delaware") |
| 111 | + location_md = Location(id="loc_md", country_code="US", subdivision_name="Maryland") |
| 112 | + location_ia = Location(id="loc_ia", country_code="US", subdivision_name="Iowa") |
| 113 | + |
| 114 | + # Define the GTFS-RT feed (e.g., mdb-1771) |
| 115 | + rt_feed = Gtfsrealtimefeed( |
| 116 | + stable_id="mdb-1771", |
| 117 | + provider="DART", |
| 118 | + locations=[location_de, location_md], |
| 119 | + entitytypes=[], |
| 120 | + ) |
| 121 | + |
| 122 | + # Define a correct related schedule feed (e.g., mdb-1235) |
| 123 | + correct_schedule_feed = Gtfsfeed(stable_id="mdb-1235", provider="DART", locations=[location_de, location_md]) |
| 124 | + |
| 125 | + # Define an incorrect schedule feed with a different location (e.g., mdb-193) |
| 126 | + incorrect_schedule_feed = Gtfsfeed(stable_id="mdb-193", provider="DART", locations=[location_ia]) |
| 127 | + |
| 128 | + # Mock the database session and its query |
| 129 | + mock_session = MagicMock() |
| 130 | + mock_query = MagicMock() |
| 131 | + mock_session.query.return_value = mock_query |
| 132 | + # The query inside from_orm should return both schedule feeds before filtering |
| 133 | + mock_query.filter.return_value.options.return_value.all.return_value = [ |
| 134 | + correct_schedule_feed, |
| 135 | + incorrect_schedule_feed, |
| 136 | + ] |
| 137 | + |
| 138 | + # Execute the method |
| 139 | + result = GtfsRTFeedImpl.from_orm(rt_feed, db_session=mock_session) |
| 140 | + |
| 141 | + # Assert that only the correct feed reference is included |
| 142 | + self.assertIn("mdb-1235", result.feed_references) |
| 143 | + self.assertNotIn("mdb-193", result.feed_references) |
| 144 | + self.assertEqual(len(result.feed_references), 1) |
| 145 | + |
103 | 146 | def test_from_orm_empty_fields(self): |
104 | 147 | """Test the `from_orm` method with not provided fields.""" |
105 | 148 | # Test with empty fields and None values |
|
0 commit comments