-
Notifications
You must be signed in to change notification settings - Fork 6
Expand file tree
/
Copy pathdatabase.py
More file actions
260 lines (221 loc) · 10.1 KB
/
database.py
File metadata and controls
260 lines (221 loc) · 10.1 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
import functools
import itertools
import logging
import os
import threading
import uuid
from contextlib import contextmanager
from typing import Callable, Type

from dotenv import load_dotenv
from sqlalchemy import create_engine, event, text
from sqlalchemy.orm import Query, Session, class_mapper, load_only, mapper, sessionmaker

from shared.common.logging_utils import get_env_logging_level
from shared.database_gen.sqlacodegen_models import (
    Base,
    Feed,
    Gtfsfeed,
    Gtfsrealtimefeed,
    Gbfsversion,
    Gbfsfeed,
    Gbfsvalidationreport,
)
def generate_unique_id() -> str:
    """
    Generate a new random identifier.
    :return: a 36-character string rendering of a version-4 UUID
    """
    return f"{uuid.uuid4()}"
def configure_polymorphic_mappers():
    """
    Configure the polymorphic mappers allowing polymorphic values on relationships.

    Feed.data_type is used as the discriminator column; each concrete subclass
    gets its lowercased table name as its polymorphic identity.
    """
    base_mapper = class_mapper(Feed)
    base_mapper.polymorphic_on = Feed.data_type
    base_mapper.polymorphic_identity = Feed.__tablename__.lower()

    # Each subclass mapper inherits from the Feed mapper and is identified by
    # its own table name.
    for feed_subclass in (Gtfsfeed, Gtfsrealtimefeed, Gbfsfeed):
        subclass_mapper = class_mapper(feed_subclass)
        subclass_mapper.inherits = base_mapper
        subclass_mapper.polymorphic_identity = feed_subclass.__tablename__.lower()
# Relationship collections that should cascade "all, delete-orphan" from their
# parent entity; consumed by set_cascade() when each mapper is configured.
cascade_entities = {
    Gtfsfeed: [Gtfsfeed.redirectingids, Gtfsfeed.redirectingids_, Gtfsfeed.externalids],
    Gbfsversion: [Gbfsversion.gbfsendpoints, Gbfsversion.gbfsvalidationreports],
    Gbfsfeed: [Gbfsfeed.gbfsversions],
    Gbfsvalidationreport: [Gbfsvalidationreport.gbfsnotices],
    Feed: [Feed.feedosmlocationgroups],
}
def set_cascade(mapper, class_):
    """
    Apply delete-orphan cascade to the relationships listed in cascade_entities
    for the given mapped class.
    This allows the related rows to be deleted/added when their respective
    relationship collection changes.
    """
    # Disable confirm_deleted_rows to avoid warnings in logs with delete-orphan
    mapper.confirm_deleted_rows = False
    configured = cascade_entities.get(class_)
    if configured is None:
        return
    cascade_keys = {attr.prop.key for attr in configured}
    for relationship in class_.__mapper__.relationships:
        if relationship.key in cascade_keys:
            relationship.cascade = "all, delete-orphan"
            relationship.passive_deletes = True
def mapper_configure_listener(mapper, class_):
    """
    Hook run once per mapped class when its mapper is configured: applies the
    cascade rules for that class, then (re)configures the polymorphic Feed
    hierarchy.
    """
    set_cascade(mapper, class_)
    configure_polymorphic_mappers()
# Register mapper_configure_listener on the global "mapper_configured" event so
# it fires for every mapped class as SQLAlchemy configures its mapper.
event.listen(mapper, "mapper_configured", mapper_configure_listener)
def refresh_materialized_view(session: "Session", view_name: str) -> bool:
    """
    Refresh Materialized view by name.
    :param session: session used to execute the REFRESH statement
    :param view_name: name of the materialized view; interpolated directly into
        the SQL, so it must come from trusted code, not user input
    @return: True if the view was refreshed successfully, False otherwise
    """
    statement = text(f"REFRESH MATERIALIZED VIEW CONCURRENTLY {view_name}")
    try:
        session.execute(statement)
    except Exception as error:
        # Best-effort: log and report failure instead of propagating.
        logging.error("Error raised while refreshing view: %s", error)
        return False
    return True
def with_db_session(func=None, db_url: str | None = None):
"""
Decorator to handle the session management for the decorated function.
This decorator ensures that a database session is properly created, committed, rolled back in case of an exception,
and closed. It uses the @contextmanager decorator to manage the lifecycle of the session, providing a clean and
efficient way to handle database interactions.
How it works:
- The decorator checks if a 'db_session' keyword argument is provided to the decorated function.
- If 'db_session' is not provided, it creates a new Database instance and starts a new session using the
start_db_session context manager.
- The context manager ensures that the session is properly committed if no exceptions occur, rolled back if an
exception occurs, and closed in either case.
- The session is then passed to the decorated function as the 'db_session' keyword argument.
- If 'db_session' is already provided, it simply calls the decorated function with the existing session.
- The echoed SQL queries will be logged if the environment variable LOGGING_LEVEL is set to DEBUG.
"""
if func is None:
return lambda f: with_db_session(f, db_url=db_url)
def wrapper(*args, **kwargs):
db_session = kwargs.get("db_session")
if db_session is None:
db = Database(echo_sql=get_env_logging_level() == logging.getLevelName("DEBUG"), feeds_database_url=db_url)
with db.start_db_session() as session:
kwargs["db_session"] = session
return func(*args, **kwargs)
else:
return func(*args, **kwargs)
return wrapper
class Database:
    """
    This class represents a database instance.
    It is a thread-safe singleton: the first construction builds the engine and
    session factory; subsequent calls return the same initialized instance.
    """

    instance = None
    initialized = False
    lock = threading.Lock()

    def __new__(cls, *args, **kwargs):
        # Double-checked locking: skip the lock on the fast path while still
        # guaranteeing a single instance under concurrent construction.
        if not isinstance(cls.instance, cls):
            with cls.lock:
                if not isinstance(cls.instance, cls):
                    cls.instance = object.__new__(cls)
        return cls.instance

    def __init__(self, echo_sql=False, feeds_database_url: str | None = None):
        """
        Initializes the database instance
        :param echo_sql: whether to echo the SQL queries or not.
            False reduces the amount of information and noise going to the logs.
            In case of errors, the exceptions will still contain relevant information about the failing queries.
        :param feeds_database_url: The URL of the target database.
            If it's None the URL will be assigned from the environment variable FEEDS_DATABASE_URL.
        :raises Exception: when no database URL is provided and FEEDS_DATABASE_URL is unset.
        """
        # This init function is called each time we call Database(), but in the case of a singleton, we only want to
        # initialize once, so we need to use a lock and a flag
        with Database.lock:
            if Database.initialized:
                return
            Database.initialized = True
            load_dotenv()
            self.logger = logging.getLogger(__name__)
            self.connection_attempts = 0
            database_url = feeds_database_url if feeds_database_url else os.getenv("FEEDS_DATABASE_URL")
            if database_url is None:
                raise Exception("Database URL not provided.")
            self.pool_size = int(os.getenv("DB_POOL_SIZE", 10))
            # max_overflow=0 caps connections at exactly pool_size.
            self.engine = create_engine(database_url, echo=echo_sql, pool_size=self.pool_size, max_overflow=0)
            # Session factory; autoflush disabled so reads don't implicitly
            # flush pending changes.
            self.Session = sessionmaker(bind=self.engine, autoflush=False)

    def is_connected(self):
        """
        Checks the connection status
        :return: True if the database engine has been created, False otherwise
        """
        # Bug fix: the previous check read `self.session`, an attribute that is
        # never set anywhere in this class (the factory is `self.Session`), so
        # it raised AttributeError whenever `self.engine` was None. Checking the
        # engine alone preserves the result on every previously-working path.
        return self.engine is not None

    @contextmanager
    def start_db_session(self):
        """
        Context manager to start a database session with optional echo.
        This method manages the lifecycle of a database session, ensuring that the session is properly created,
        committed, rolled back in case of an exception, and closed. The @contextmanager decorator simplifies
        resource management by handling the setup and cleanup logic within a single function.
        """
        session = self.Session()
        try:
            yield session
            session.commit()
        except Exception:
            session.rollback()
            raise
        finally:
            session.close()

    def select(
        self,
        session: "Session",
        model: Type[Base] = None,
        query: Query = None,
        conditions: list = None,
        attributes: list = None,
        limit: int = None,
        offset: int = None,
        group_by: Callable = None,
    ):
        """
        Executes a query on the database.
        :param session: The SQLAlchemy session object used to interact with the database.
        :param model: The SQLAlchemy model to query. If not provided, the query parameter must be given.
        :param query: The SQLAlchemy ORM query to execute. If not provided, a query will be created using the model.
        :param conditions: A list of conditions (filters) to apply to the query. Each condition should be a SQLAlchemy
        expression.
        :param attributes: A list of model's attribute names to fetch. If not provided, all attributes will be fetched.
        :param limit: An optional integer to limit the number of rows returned by the query.
        :param offset: An optional integer to offset the number of rows returned by the query.
        :param group_by: An optional function to group the query results by the return value of the function. The query
        needs to order the return values by the key being grouped by.
        :return: None if the query fails, otherwise the results of the query.
        """
        try:
            if query is None:
                query = session.query(model)
            if conditions:
                for condition in conditions:
                    query = query.filter(condition)
            if attributes is not None:
                query = query.options(load_only(*attributes))
            if limit is not None:
                query = query.limit(limit)
            if offset is not None:
                query = query.offset(offset)
            results = session.execute(query).all()
            if group_by:
                # itertools.groupby only merges consecutive rows, hence the
                # requirement that the query be ordered by the grouping key.
                return [list(group) for _, group in itertools.groupby(results, group_by)]
            return results
        except Exception as e:
            # Lazy %-formatting: the message is only built if the log is emitted.
            self.logger.error("SELECT query failed with exception: \n%s", e)
            return None

    def get_query_model(self, session: Session, model: Type[Base]) -> Query:
        """
        Builds a plain ORM query for the given model.
        :param session: the SQLAlchemy session used to create the query
        :param model: the sqlalchemy model to query
        :return: the query model
        """
        return session.query(model)