-
Notifications
You must be signed in to change notification settings - Fork 6
Expand file tree
/
Copy pathfeeds_operations_impl.py
More file actions
409 lines (376 loc) · 15.6 KB
/
feeds_operations_impl.py
File metadata and controls
409 lines (376 loc) · 15.6 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
#
# MobilityData 2024
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from fastapi.encoders import jsonable_encoder
from fastapi.responses import JSONResponse
import logging
from typing import Annotated, Optional
from deepdiff import DeepDiff
from fastapi import HTTPException
from pydantic import Field, StrictStr
from sqlalchemy import text
from sqlalchemy.orm import Session
from starlette.responses import Response
from feeds_gen.models.data_type import DataType
from feeds_gen.models.get_feeds200_response import GetFeeds200Response
from feeds_gen.models.operation_create_request_gtfs_feed import (
OperationCreateRequestGtfsFeed,
)
from feeds_gen.models.operation_create_request_gtfs_rt_feed import (
OperationCreateRequestGtfsRtFeed,
)
from feeds_gen.models.operation_gtfs_feed import OperationGtfsFeed
from feeds_gen.models.operation_gtfs_rt_feed import OperationGtfsRtFeed
from feeds_operations.impl.models.update_request_gtfs_feed_impl import (
UpdateRequestGtfsFeedImpl,
)
from feeds_operations.impl.models.update_request_gtfs_rt_feed_impl import (
UpdateRequestGtfsRtFeedImpl,
)
from feeds_gen.apis.operations_api_base import BaseOperationsApi
from feeds_gen.models.update_request_gtfs_feed import UpdateRequestGtfsFeed
from feeds_gen.models.update_request_gtfs_rt_feed import (
UpdateRequestGtfsRtFeed,
)
from middleware.request_context_oauth2 import get_request_context
from shared.database.database import with_db_session, refresh_materialized_view
from shared.database_gen.sqlacodegen_models import (
Gtfsfeed,
t_feedsearch,
Feed,
Gtfsrealtimefeed,
)
from shared.helpers.pub_sub import get_execution_id, trigger_dataset_download
from shared.helpers.query_helper import (
query_feed_by_stable_id,
get_feeds_query,
get_feed_by_normalized_url,
)
from .models.operation_create_request_gtfs_feed import (
OperationCreateRequestGtfsFeedImpl,
)
from .models.operation_create_request_gtfs_rt_feed import (
OperationCreateRequestGtfsRtFeedImpl,
)
from .models.operation_feed_impl import OperationFeedImpl
from .models.operation_gtfs_feed_impl import OperationGtfsFeedImpl
from .models.operation_gtfs_rt_feed_impl import OperationGtfsRtFeedImpl
from .request_validator import validate_request
class OperationsApiImpl(BaseOperationsApi):
    """Implementation of the operations API.

    Provides list/read, create and update endpoints for GTFS and GTFS-RT
    feeds. Methods decorated with ``with_db_session`` take a SQLAlchemy
    session through their ``db_session`` argument (presumably injected by
    the decorator when the caller does not supply one — TODO confirm).
    """

    # ``operational_status_action`` values that set ``operational_status``
    # to the same (lower-cased) string.
    _OPERATIONAL_STATUS_ACTIONS = frozenset({"wip", "published", "unpublished"})

    @staticmethod
    def _is_status_change_requested(action: Optional[str]) -> bool:
        """Return True when *action* asks for an operational status change.

        ``None`` and ``"no_change"`` (compared case-insensitively) both mean
        "leave the operational status untouched".
        """
        return action is not None and action.lower() != "no_change"

    @staticmethod
    def _refresh_feed_search(db_session: Session) -> None:
        """Refresh the ``t_feedsearch`` materialized view and log the result."""
        refreshed = refresh_materialized_view(db_session, t_feedsearch.name)
        logging.info(
            "Materialized view %s refreshed: %s", t_feedsearch.name, refreshed
        )

    @staticmethod
    def assign_feed_id(new_feed: Gtfsfeed | Gtfsrealtimefeed):
        """Default ``new_feed.id`` to its ``stable_id`` when no id was provided.

        An id that is missing or falsy is treated as "not provided".
        """
        if not getattr(new_feed, "id", None):
            new_feed.id = new_feed.stable_id

    @staticmethod
    def assign_stable_id(new_feed: Gtfsfeed | Gtfsrealtimefeed, db_session: Session):
        """Generate ``md-<n>`` as ``stable_id`` when the client did not provide one.

        ``<n>`` is the next value of the ``md_sequence`` database sequence.
        """
        if not getattr(new_feed, "stable_id", None):
            next_val = db_session.execute(
                text("SELECT nextval('md_sequence')")
            ).scalar_one()
            new_feed.stable_id = f"md-{next_val}"

    @staticmethod
    def assert_no_existing_feed_url(producer_url: str, db_session: Session):
        """Reject creation when a feed with the same normalized URL exists.

        Raises:
            HTTPException: 400 if ``get_feed_by_normalized_url`` finds a feed.
        """
        existing_feed = get_feed_by_normalized_url(producer_url, db_session)
        if existing_feed:
            # Note the trailing space after "exists." so sentences don't fuse.
            message = (
                f"A published feed with url {producer_url} already exists. "
                f"Existing feed ID: {existing_feed.stable_id}, "
                f"URL: {existing_feed.producer_url}"
            )
            logging.error(message)
            raise HTTPException(
                status_code=400,
                detail=message,
            )

    @with_db_session
    async def get_feeds(
        self,
        operation_status: Optional[str] = None,
        data_type: Optional[str] = None,
        offset: str = "0",
        limit: str = "50",
        db_session: Session = None,
    ) -> GetFeeds200Response:
        """Get a list of feeds with optional filtering and pagination.

        Args:
            operation_status: optional operational status filter.
            data_type: optional data type filter.
            offset: number of rows to skip; empty/None means 0.
            limit: maximum number of rows to return; empty/None means 50.
            db_session: SQLAlchemy session.

        Raises:
            HTTPException: 400 for non-integer or negative pagination values,
                500 for any other failure while querying.
        """
        try:
            limit_int = int(limit) if limit else 50
            offset_int = int(offset) if offset else 0
        except (TypeError, ValueError) as e:
            logging.error(
                "Invalid pagination parameters: offset=%s, limit=%s", offset, limit
            )
            raise HTTPException(
                status_code=400, detail=f"Invalid pagination parameters: {e}"
            ) from e
        if limit_int < 0 or offset_int < 0:
            raise HTTPException(
                status_code=400, detail="offset and limit must be non-negative"
            )
        try:
            query = get_feeds_query(
                db_session=db_session,
                operation_status=operation_status,
                data_type=data_type,
                limit=limit_int,
                offset=offset_int,
                model=Feed,
            )
            logging.info("Executing query with data_type: %s", data_type)
            # NOTE(review): if get_feeds_query already applies limit/offset,
            # count() reports the size of this page rather than all matching
            # feeds — confirm against get_feeds_query.
            total = query.count()
            feeds = query.all()
            logging.info("Retrieved %d feeds from database", len(feeds))
            feed_list = [OperationFeedImpl.from_orm(feed) for feed in feeds]
            response = GetFeeds200Response(
                total=total, offset=offset_int, limit=limit_int, feeds=feed_list
            )
        except HTTPException:
            # Already a proper HTTP error; don't mask it as a 500.
            raise
        except Exception as e:
            logging.error("Failed to get feeds. Error: %s", str(e))
            raise HTTPException(
                status_code=500, detail=f"Internal server error: {str(e)}"
            ) from e
        logging.info("Returning response with %d feeds", len(feed_list))
        return response

    @with_db_session
    async def get_gtfs_feed(
        self,
        id: Annotated[
            StrictStr, Field(description="The feed ID of the requested feed.")
        ],
        db_session: Session = None,
    ) -> OperationGtfsFeed:
        """Get the specified GTFS feed from the Mobility Database.

        Raises:
            HTTPException: 404 when no GTFS feed has this stable id.
        """
        gtfs_feed = (
            db_session.query(Gtfsfeed).filter(Gtfsfeed.stable_id == id).one_or_none()
        )
        if gtfs_feed is None:
            raise HTTPException(status_code=404, detail="GTFS feed not found")
        return OperationGtfsFeedImpl.from_orm(gtfs_feed)

    @with_db_session
    async def get_gtfs_rt_feed(
        self,
        id: Annotated[
            StrictStr, Field(description="The feed ID of the requested feed.")
        ],
        db_session: Session = None,
    ) -> OperationGtfsRtFeed:
        """Get the specified GTFS-RT feed from the Mobility Database.

        Raises:
            HTTPException: 404 when no GTFS-RT feed has this stable id.
        """
        gtfs_rt_feed = (
            db_session.query(Gtfsrealtimefeed)
            .filter(Gtfsrealtimefeed.stable_id == id)
            .one_or_none()
        )
        if gtfs_rt_feed is None:
            raise HTTPException(status_code=404, detail="GTFS-RT feed not found")
        return OperationGtfsRtFeedImpl.from_orm(gtfs_rt_feed)

    @staticmethod
    def detect_changes(
        feed: Gtfsfeed | Gtfsrealtimefeed,
        update_request_feed: UpdateRequestGtfsFeed | UpdateRequestGtfsRtFeed,
        impl_class: type[UpdateRequestGtfsFeedImpl] | type[UpdateRequestGtfsRtFeedImpl],
    ) -> DeepDiff:
        """Detect changes between the stored feed and the update request.

        The stored feed is converted with ``impl_class.from_orm``. Its
        ``operational_status_action`` is copied from the request so that field
        never shows up as a difference. The two ``model_dump()`` dicts are then
        compared with DeepDiff, ignoring order.
        """
        copy_feed = impl_class.from_orm(feed)
        copy_feed.operational_status_action = (
            update_request_feed.operational_status_action
        )
        diff = DeepDiff(
            copy_feed.model_dump(),
            update_request_feed.model_dump(),
            ignore_order=True,
        )
        if diff.affected_paths:
            logging.info(
                "Detect update changes: affected paths: %s", diff.affected_paths
            )
        else:
            logging.info("Detect update changes: no changes detected")
        return diff

    @validate_request(UpdateRequestGtfsFeed, "update_request_gtfs_feed")
    async def update_gtfs_feed(
        self,
        update_request_gtfs_feed: Annotated[
            UpdateRequestGtfsFeed,
            Field(description="Payload to update the specified feed."),
        ],
    ) -> Response:
        """Update the specified feed in the Mobility Database.

        returns:
        - 200: Feed updated successfully.
        - 204: No changes detected.
        - 400: Feed ID not found.
        - 500: Internal server error.
        """
        return await self._update_feed(update_request_gtfs_feed, DataType.GTFS)

    @validate_request(UpdateRequestGtfsRtFeed, "update_request_gtfs_rt_feed")
    async def update_gtfs_rt_feed(
        self,
        update_request_gtfs_rt_feed: Annotated[
            UpdateRequestGtfsRtFeed,
            Field(description="Payload to update the specified GTFS-RT feed."),
        ],
    ) -> Response:
        """Update the specified GTFS-RT feed in the Mobility Database.

        returns:
        - 200: Feed updated successfully.
        - 204: No changes detected.
        - 400: Feed ID not found.
        - 500: Internal server error.
        """
        return await self._update_feed(update_request_gtfs_rt_feed, DataType.GTFS_RT)

    @with_db_session
    async def _update_feed(
        self,
        update_request_feed: UpdateRequestGtfsFeed | UpdateRequestGtfsRtFeed,
        data_type: DataType,
        db_session: Session,
    ) -> Response:
        """Update the specified feed in the Mobility Database.

        Applies the request when DeepDiff reports changes or when an
        operational status change is requested. It then flushes, refreshes
        the feed search view and commits.

        Returns:
            200 on update, 204 when nothing changed.

        Raises:
            HTTPException: re-raised as-is (e.g. 400 from ``fetch_feed``);
                any other error becomes a 500.
        """
        try:
            feed_from_db = await OperationsApiImpl.fetch_feed(
                data_type, db_session, update_request_feed
            )
            logging.info(
                "Feed ID: %s attempting to update with the following request: %s",
                update_request_feed.id,
                update_request_feed,
            )
            impl_class = (
                UpdateRequestGtfsFeedImpl
                if data_type == DataType.GTFS
                else UpdateRequestGtfsRtFeedImpl
            )
            diff = self.detect_changes(feed_from_db, update_request_feed, impl_class)
            status_change = OperationsApiImpl._is_status_change_requested(
                update_request_feed.operational_status_action
            )
            if not diff.affected_paths and not status_change:
                logging.info(
                    "No changes detected for feed ID: %s", update_request_feed.id
                )
                return Response(status_code=204)

            await OperationsApiImpl._populate_feed_values(
                feed_from_db, impl_class, db_session, update_request_feed
            )
            db_session.flush()
            OperationsApiImpl._refresh_feed_search(db_session)
            db_session.commit()
            logging.info(
                "Feed ID: %s updated successfully with the following changes: %s",
                update_request_feed.id,
                diff.values(),
            )
            return Response(status_code=200)
        except HTTPException as e:
            logging.error(
                "Failed to update feed ID: %s. Error: %s", update_request_feed.id, e
            )
            raise
        except Exception as e:
            logging.error(
                "Failed to update feed ID: %s. Error: %s", update_request_feed.id, e
            )
            raise HTTPException(
                status_code=500, detail=f"Internal server error: {e}"
            ) from e

    @staticmethod
    async def _populate_feed_values(feed, impl_class, session, update_request_feed):
        """Copy the request onto the ORM ``feed`` and apply any status action."""
        impl_class.to_orm(update_request_feed, feed, session)
        action = update_request_feed.operational_status_action
        # This is a temporary solution as the operational_status is not visible in the diff
        if OperationsApiImpl._is_status_change_requested(action):
            normalized_action = action.lower()
            if normalized_action in OperationsApiImpl._OPERATIONAL_STATUS_ACTIONS:
                feed.operational_status = normalized_action
        session.add(feed)

    @staticmethod
    async def fetch_feed(data_type, session, update_request_feed):
        """Fetch a feed by its stable ID.

        Args:
            data_type: The feed data type (gtfs or gtfs_rt)
            session: SQLAlchemy session
            update_request_feed: The update request containing the feed ID
        Returns:
            The feed object returned by ``query_feed_by_stable_id``
        Raises:
            HTTPException: 400 if the feed is not found
        """
        feed = query_feed_by_stable_id(session, update_request_feed.id, data_type.value)
        if feed is None:
            raise HTTPException(
                status_code=400,
                detail=f"Feed ID not found: {update_request_feed.id}",
            )
        return feed

    @staticmethod
    def _insert_new_feed(new_feed, data_type: DataType, model, db_session: Session):
        """Stamp data type and ids on *new_feed*, commit it, and reload it.

        Returns:
            The ``model`` instance fetched back by ``new_feed.id``.
        """
        new_feed.data_type = data_type.value
        # stable_id must be assigned first: assign_feed_id falls back to it.
        OperationsApiImpl.assign_stable_id(new_feed, db_session)
        OperationsApiImpl.assign_feed_id(new_feed)
        db_session.add(new_feed)
        db_session.commit()
        return db_session.get(model, new_feed.id)

    @with_db_session
    async def create_gtfs_feed(
        self,
        operation_create_request_gtfs_feed: Annotated[
            OperationCreateRequestGtfsFeed,
            Field(description="Payload to create the specified GTFS feed."),
        ],
        db_session: Session = None,
    ) -> OperationGtfsFeed:
        """Create a GTFS feed in the Mobility Database.

        Rejects duplicate producer URLs (400), persists the feed, triggers a
        dataset download and refreshes the feed search view.

        Returns:
            A 201 ``JSONResponse`` carrying the created feed.
        """
        # Check if the producer_url already exists in an active feed
        OperationsApiImpl.assert_no_existing_feed_url(
            operation_create_request_gtfs_feed.source_info.producer_url,
            db_session,
        )
        new_feed = OperationCreateRequestGtfsFeedImpl.to_orm(
            operation_create_request_gtfs_feed
        )
        created_feed = OperationsApiImpl._insert_new_feed(
            new_feed, DataType.GTFS, Gtfsfeed, db_session
        )
        trigger_dataset_download(
            created_feed,
            get_execution_id(get_request_context(), "feed-created-process"),
        )
        logging.info("Created new GTFS feed with ID: %s", new_feed.stable_id)
        OperationsApiImpl._refresh_feed_search(db_session)
        payload = OperationGtfsFeedImpl.from_orm(created_feed).model_dump()
        return JSONResponse(status_code=201, content=jsonable_encoder(payload))

    @with_db_session
    async def create_gtfs_rt_feed(
        self,
        operation_create_request_gtfs_rt_feed: Annotated[
            OperationCreateRequestGtfsRtFeed,
            Field(description="Payload to create the specified GTFS-RT feed."),
        ],
        db_session: Session = None,
    ) -> OperationGtfsRtFeed:
        """Create a GTFS-RT feed in the Mobility Database.

        Rejects duplicate producer URLs (400), persists the feed and refreshes
        the feed search view. Unlike the GTFS path, no dataset download is
        triggered here.

        Returns:
            A 201 ``JSONResponse`` carrying the created feed.
        """
        OperationsApiImpl.assert_no_existing_feed_url(
            operation_create_request_gtfs_rt_feed.source_info.producer_url,
            db_session,
        )
        new_feed = OperationCreateRequestGtfsRtFeedImpl.to_orm(
            operation_create_request_gtfs_rt_feed
        )
        created_feed = OperationsApiImpl._insert_new_feed(
            new_feed, DataType.GTFS_RT, Gtfsrealtimefeed, db_session
        )
        logging.info("Created new GTFS-RT feed with ID: %s", new_feed.stable_id)
        OperationsApiImpl._refresh_feed_search(db_session)
        payload = OperationGtfsRtFeedImpl.from_orm(created_feed).model_dump()
        return JSONResponse(status_code=201, content=jsonable_encoder(payload))