Skip to content

Commit 74bc701

Browse files
committed
fix: type errors in officers crud
1 parent f571eb3 commit 74bc701

1 file changed

Lines changed: 39 additions & 50 deletions

File tree

src/officers/crud.py

Lines changed: 39 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
from collections.abc import Sequence
22
from datetime import date
33

4-
import sqlalchemy
54
from fastapi import HTTPException
5+
from sqlalchemy import Row, delete, select, update
66
from sqlalchemy.ext.asyncio import AsyncSession
77

88
import auth.crud
@@ -12,7 +12,7 @@
1212
from data import semesters
1313
from officers.constants import OfficerPosition
1414
from officers.models import OfficerInfoResponse
15-
from officers.tables import OfficerInfo, OfficerTerm
15+
from officers.tables import OfficerInfo, OfficerTermDB
1616

1717
# NOTE: this module should not do any data validation; that should be done in the urls.py or higher layer
1818

@@ -27,13 +27,13 @@ async def current_officers(
2727
"""
2828
curr_time = date.today()
2929
query = (
30-
sqlalchemy.select(OfficerTerm, OfficerInfo)
31-
.join(OfficerInfo, OfficerTerm.computing_id == OfficerInfo.computing_id)
32-
.where((OfficerTerm.start_date <= curr_time) & (OfficerTerm.end_date >= curr_time))
33-
.order_by(OfficerTerm.start_date.desc())
30+
select(OfficerTermDB, OfficerInfo)
31+
.join(OfficerInfo, OfficerTermDB.computing_id == OfficerInfo.computing_id)
32+
.where((OfficerTermDB.start_date <= curr_time) & (OfficerTermDB.end_date >= curr_time))
33+
.order_by(OfficerTermDB.start_date.desc())
3434
)
3535

36-
result: Sequence[sqlalchemy.Row[tuple[OfficerTerm, OfficerInfo]]] = (await db_session.execute(query)).all()
36+
result: Sequence[Row[tuple[OfficerTermDB, OfficerInfo]]] = (await db_session.execute(query)).all()
3737
officer_list = []
3838
for term, officer in result:
3939
officer_list.append(
@@ -59,27 +59,23 @@ async def current_officers(
5959
return officer_list
6060

6161

62-
async def get_current_terms_by_position(db_session: database.DBSession, position: str) -> list[OfficerInfoResponse]:
62+
async def get_current_terms_by_position(db_session: database.DBSession, position: str) -> list[OfficerTermDB]:
6363
"""
6464
Get current officer that holds a position
6565
"""
6666
curr_time = date.today()
6767
query = (
68-
sqlalchemy.select(OfficerTerm)
69-
.join(OfficerInfo, OfficerTerm.computing_id)
68+
select(OfficerTermDB)
7069
.where(
71-
(OfficerTerm.start_date <= curr_time) & (OfficerTerm.end_date >= curr_time) & OfficerTerm.position
70+
(OfficerTermDB.start_date <= curr_time) & (OfficerTermDB.end_date >= curr_time) & OfficerTermDB.position
7271
== position
7372
)
74-
.order_by(OfficerTerm.start_date.desc())
73+
.order_by(OfficerTermDB.start_date.desc())
7574
)
7675

77-
result = (await db_session.execute(query)).all()
78-
officer_list = []
79-
for term in result:
80-
officer_list.append(OfficerTerm(*term))
76+
result = list((await db_session.scalars(query)).all())
8177

82-
return officer_list
78+
return result
8379

8480

8581
async def all_officers(db_session: AsyncSession, include_future_terms: bool) -> list[OfficerInfoResponse]:
@@ -88,14 +84,14 @@ async def all_officers(db_session: AsyncSession, include_future_terms: bool) ->
8884
"""
8985
# NOTE: paginate data if needed
9086
query = (
91-
sqlalchemy.select(OfficerTerm, OfficerInfo)
92-
.join(OfficerInfo, OfficerTerm.computing_id == OfficerInfo.computing_id)
93-
.order_by(OfficerTerm.start_date.desc())
87+
select(OfficerTermDB, OfficerInfo)
88+
.join(OfficerInfo, OfficerTermDB.computing_id == OfficerInfo.computing_id)
89+
.order_by(OfficerTermDB.start_date.desc())
9490
)
9591

9692
if not include_future_terms:
9793
query = utils.has_started_term(query)
98-
result: Sequence[sqlalchemy.Row[tuple[OfficerTerm, OfficerInfo]]] = (await db_session.execute(query)).all()
94+
result: Sequence[Row[tuple[OfficerTermDB, OfficerInfo]]] = (await db_session.execute(query)).all()
9995
officer_list = []
10096
for term, officer in result:
10197
officer_list.append(
@@ -122,9 +118,7 @@ async def all_officers(db_session: AsyncSession, include_future_terms: bool) ->
122118

123119

124120
async def get_officer_info_or_raise(db_session: database.DBSession, computing_id: str) -> OfficerInfo:
125-
officer_term = await db_session.scalar(
126-
sqlalchemy.select(OfficerInfo).where(OfficerInfo.computing_id == computing_id)
127-
)
121+
officer_term = await db_session.scalar(select(OfficerInfo).where(OfficerInfo.computing_id == computing_id))
128122
if officer_term is None:
129123
raise HTTPException(status_code=404, detail=f"officer_info for computing_id={computing_id} does not exist yet")
130124
return officer_term
@@ -134,9 +128,7 @@ async def get_new_officer_info_or_raise(db_session: database.DBSession, computin
134128
"""
135129
This check is for after a create/update
136130
"""
137-
officer_term = await db_session.scalar(
138-
sqlalchemy.select(OfficerInfo).where(OfficerInfo.computing_id == computing_id)
139-
)
131+
officer_term = await db_session.scalar(select(OfficerInfo).where(OfficerInfo.computing_id == computing_id))
140132
if officer_term is None:
141133
raise HTTPException(status_code=500, detail=f"failed to fetch {computing_id} after update")
142134
return officer_term
@@ -146,34 +138,33 @@ async def get_officer_terms(
146138
db_session: database.DBSession,
147139
computing_id: str,
148140
include_future_terms: bool,
149-
) -> list[OfficerTerm]:
141+
) -> list[OfficerTermDB]:
150142
query = (
151-
sqlalchemy.select(OfficerTerm)
152-
.where(OfficerTerm.computing_id == computing_id)
143+
select(OfficerTermDB)
144+
.where(OfficerTermDB.computing_id == computing_id)
153145
# In order of most recent start date first
154-
.order_by(OfficerTerm.start_date.desc())
146+
.order_by(OfficerTermDB.start_date.desc())
155147
)
156148
if not include_future_terms:
157149
query = utils.has_started_term(query)
158150

159-
return (await db_session.scalars(query)).all()
151+
return list((await db_session.scalars(query)).all())
160152

161153

162-
async def get_active_officer_terms(db_session: database.DBSession, computing_id: str) -> list[OfficerTerm]:
154+
async def get_active_officer_terms(db_session: database.DBSession, computing_id: str) -> list[OfficerTermDB]:
163155
"""
164156
Returns the list of active officer terms for a user. Returns [] if the user is not currently an officer.
165157
An officer can have multiple positions at once, such as Webmaster, Frosh chair, and DoEE.
166158
"""
167159
query = (
168-
sqlalchemy.select(OfficerTerm)
169-
.where(OfficerTerm.computing_id == computing_id)
160+
select(OfficerTermDB)
161+
.where(OfficerTermDB.computing_id == computing_id)
170162
# In order of most recent start date first
171-
.order_by(OfficerTerm.start_date.desc())
163+
.order_by(OfficerTermDB.start_date.desc())
172164
)
173165
query = utils.is_active_officer(query)
174166

175-
officer_term_list = (await db_session.scalars(query)).all()
176-
return officer_term_list
167+
return list((await db_session.scalars(query)).all())
177168

178169

179170
async def current_officer_positions(db_session: database.DBSession, computing_id: str) -> list[str]:
@@ -186,8 +177,8 @@ async def current_officer_positions(db_session: database.DBSession, computing_id
186177

187178
async def get_officer_term_by_id_or_raise(
188179
db_session: database.DBSession, term_id: int, is_new: bool = False
189-
) -> OfficerTerm:
190-
officer_term = await db_session.scalar(sqlalchemy.select(OfficerTerm).where(OfficerTerm.id == term_id))
180+
) -> OfficerTermDB:
181+
officer_term = await db_session.scalar(select(OfficerTermDB).where(OfficerTermDB.id == term_id))
191182
if officer_term is None:
192183
if is_new:
193184
raise HTTPException(status_code=500, detail=f"could not find new officer_term with id={term_id}")
@@ -205,7 +196,7 @@ async def create_new_officer_info(db_session: database.DBSession, new_officer_in
205196
)
206197

207198
existing_officer_info = await db_session.scalar(
208-
sqlalchemy.select(OfficerInfo).where(OfficerInfo.computing_id == new_officer_info.computing_id)
199+
select(OfficerInfo).where(OfficerInfo.computing_id == new_officer_info.computing_id)
209200
)
210201
if existing_officer_info is not None:
211202
return False
@@ -214,7 +205,7 @@ async def create_new_officer_info(db_session: database.DBSession, new_officer_in
214205
return True
215206

216207

217-
async def create_new_officer_term(db_session: database.DBSession, new_officer_term: OfficerTerm):
208+
async def create_new_officer_term(db_session: database.DBSession, new_officer_term: OfficerTermDB):
218209
position_length = OfficerPosition.length_in_semesters(new_officer_term.position)
219210
if position_length is not None:
220211
# when creating a new position, assign a default end date if one exists
@@ -230,15 +221,15 @@ async def update_officer_info(db_session: database.DBSession, new_officer_info:
230221
Return False if the officer doesn't exist yet
231222
"""
232223
officer_info = await db_session.scalar(
233-
sqlalchemy.select(OfficerInfo).where(OfficerInfo.computing_id == new_officer_info.computing_id)
224+
select(OfficerInfo).where(OfficerInfo.computing_id == new_officer_info.computing_id)
234225
)
235226
if officer_info is None:
236227
return False
237228

238229
# NOTE: if there's ever an insert entry error, it will raise SQLAlchemyError
239230
# see: https://stackoverflow.com/questions/2136739/how-to-check-and-handle-errors-in-sqlalchemy
240231
await db_session.execute(
241-
sqlalchemy.update(OfficerInfo)
232+
update(OfficerInfo)
242233
.where(OfficerInfo.computing_id == officer_info.computing_id)
243234
.values(new_officer_info.to_update_dict())
244235
)
@@ -247,23 +238,21 @@ async def update_officer_info(db_session: database.DBSession, new_officer_info:
247238

248239
async def update_officer_term(
249240
db_session: database.DBSession,
250-
new_officer_term: OfficerTerm,
241+
new_officer_term: OfficerTermDB,
251242
) -> bool:
252243
"""
253244
Update all officer term data in `new_officer_term` based on the term id.
254245
Returns false if the above entry does not exist.
255246
"""
256-
officer_term = await db_session.scalar(sqlalchemy.select(OfficerTerm).where(OfficerTerm.id == new_officer_term.id))
247+
officer_term = await db_session.scalar(select(OfficerTermDB).where(OfficerTermDB.id == new_officer_term.id))
257248
if officer_term is None:
258249
return False
259250

260251
await db_session.execute(
261-
sqlalchemy.update(OfficerTerm)
262-
.where(OfficerTerm.id == new_officer_term.id)
263-
.values(new_officer_term.to_update_dict())
252+
update(OfficerTermDB).where(OfficerTermDB.id == new_officer_term.id).values(new_officer_term.to_update_dict())
264253
)
265254
return True
266255

267256

268257
async def delete_officer_term_by_id(db_session: database.DBSession, term_id: int):
269-
await db_session.execute(sqlalchemy.delete(OfficerTerm).where(OfficerTerm.id == term_id))
258+
await db_session.execute(delete(OfficerTermDB).where(OfficerTermDB.id == term_id))

0 commit comments

Comments
 (0)