Skip to content

Commit 4873317

Browse files
committed
refactor(permissions): restructured most endpoints to use Depends
1 parent 74bc701 commit 4873317

23 files changed

Lines changed: 663 additions & 506 deletions

File tree

src/auth/crud.py

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import sqlalchemy
55
from sqlalchemy.ext.asyncio import AsyncSession
66

7-
from auth.tables import SiteUser, UserSession
7+
from auth.tables import SiteUserDB, UserSession
88

99
_logger = logging.getLogger(__name__)
1010

@@ -18,7 +18,9 @@ async def create_user_session(db_session: AsyncSession, session_id: str, computi
1818
existing_user_session = await db_session.scalar(
1919
sqlalchemy.select(UserSession).where(UserSession.computing_id == computing_id)
2020
)
21-
existing_user = await db_session.scalar(sqlalchemy.select(SiteUser).where(SiteUser.computing_id == computing_id))
21+
existing_user = await db_session.scalar(
22+
sqlalchemy.select(SiteUserDB).where(SiteUserDB.computing_id == computing_id)
23+
)
2224

2325
if existing_user is None:
2426
if existing_user_session is not None:
@@ -27,7 +29,7 @@ async def create_user_session(db_session: AsyncSession, session_id: str, computi
2729

2830
# add new user to User table if it's their first time logging in
2931
db_session.add(
30-
SiteUser(computing_id=computing_id, first_logged_in=datetime.now(), last_logged_in=datetime.now())
32+
SiteUserDB(computing_id=computing_id, first_logged_in=datetime.now(), last_logged_in=datetime.now())
3133
)
3234

3335
if existing_user_session is not None:
@@ -68,18 +70,18 @@ async def task_clean_expired_user_sessions(db_session: AsyncSession):
6870

6971

7072
# get the site user given a session ID; returns None when session is invalid
71-
async def get_site_user(db_session: AsyncSession, session_id: str) -> SiteUser | None:
73+
async def get_site_user(db_session: AsyncSession, session_id: str) -> SiteUserDB | None:
7274
query = sqlalchemy.select(UserSession).where(UserSession.session_id == session_id)
7375
user_session = await db_session.scalar(query)
7476
if user_session is None:
7577
return None
7678

77-
query = sqlalchemy.select(SiteUser).where(SiteUser.computing_id == user_session.computing_id)
79+
query = sqlalchemy.select(SiteUserDB).where(SiteUserDB.computing_id == user_session.computing_id)
7880
return await db_session.scalar(query)
7981

8082

8183
async def site_user_exists(db_session: AsyncSession, computing_id: str) -> bool:
82-
user = await db_session.scalar(sqlalchemy.select(SiteUser).where(SiteUser.computing_id == computing_id))
84+
user = await db_session.scalar(sqlalchemy.select(SiteUserDB).where(SiteUserDB.computing_id == computing_id))
8385
return user is not None
8486

8587

@@ -91,8 +93,8 @@ async def update_site_user(db_session: AsyncSession, session_id: str, profile_pi
9193
return False
9294

9395
query = (
94-
sqlalchemy.update(SiteUser)
95-
.where(SiteUser.computing_id == user_session.computing_id)
96+
sqlalchemy.update(SiteUserDB)
97+
.where(SiteUserDB.computing_id == user_session.computing_id)
9698
.values(profile_picture_url=profile_picture_url)
9799
)
98100
await db_session.execute(query)

src/auth/tables.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ class UserSession(Base):
2626
) # the space needed to store 256 bytes in base64
2727

2828

29-
class SiteUser(Base):
29+
class SiteUserDB(Base):
3030
# user is a reserved word in postgres
3131
# see: https://stackoverflow.com/questions/22256124/cannot-create-a-database-table-named-user-in-postgresql
3232
__tablename__ = "site_user"

src/cron/daily.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
import google_api
88
import utils
99
from database import get_db_session
10-
from officers.crud import all_officers, get_user_by_username
10+
from officers.crud import get_all_officers, get_user_by_username
1111

1212
_logger = logging.getLogger(__name__)
1313

@@ -19,7 +19,7 @@ async def update_google_permissions(db_session):
1919

2020
# TODO: for performance, only include officers with recent end-date (1 yr)
2121
# but measure performance first
22-
for term in await all_officers(db_session):
22+
for term in await get_all_officers(db_session):
2323
if utils.is_active(term):
2424
# TODO: if google drive permission is not active, update them
2525
pass
@@ -33,7 +33,7 @@ async def update_google_permissions(db_session):
3333
async def update_github_permissions(db_session):
3434
github_permissions, team_id_map = github.all_permissions()
3535

36-
for term in await all_officers(db_session):
36+
for term in await get_all_officers(db_session):
3737
new_teams = (
3838
# move all active officers to their respective teams
3939
github.officer_teams(term.position)

src/data/semesters.py

Lines changed: 20 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,10 @@
22
from enum import Enum
33
from typing import assert_never
44

5+
JANUARY = 1
6+
MAY = 5
7+
SEPTEMBER = 9
8+
59

610
class Semester(Enum):
711
"""semester numbers are assigned by their order in the year"""
@@ -18,7 +22,7 @@ def __str__(self):
1822
elif self.value == 2:
1923
return "fall"
2024
else:
21-
assert_never()
25+
assert_never(self.value)
2226

2327

2428
def step_semesters(semester_start_date: date, num_semesters: int) -> date:
@@ -32,30 +36,32 @@ def step_semesters(semester_start_date: date, num_semesters: int) -> date:
3236

3337

3438
def current_semester_start(the_date: date) -> date:
35-
if the_date.month >= 9:
36-
return date(year=the_date.year, month=9, day=1)
37-
elif the_date.month >= 5:
38-
return date(year=the_date.year, month=5, day=1)
39-
elif the_date.month >= 1:
40-
return date(year=the_date.year, month=1, day=1)
39+
if the_date.month >= SEPTEMBER:
40+
return date(year=the_date.year, month=SEPTEMBER, day=1)
41+
elif the_date.month >= MAY:
42+
return date(year=the_date.year, month=MAY, day=1)
43+
elif the_date.month >= JANUARY:
44+
return date(year=the_date.year, month=JANUARY, day=1)
45+
else:
46+
raise AssertionError("unreachable")
4147

4248

4349
def current_semester(the_date: date) -> Semester:
44-
if the_date.month >= 9:
50+
if the_date.month >= SEPTEMBER:
4551
return Semester.Fall
46-
elif the_date.month >= 5:
52+
elif the_date.month >= MAY:
4753
return Semester.Summer
48-
elif the_date.month >= 1:
54+
elif the_date.month >= JANUARY:
4955
return Semester.Spring
5056
else:
51-
assert_never()
57+
raise AssertionError("unreachable")
5258

5359

5460
def get_semester_start(year: int, semester: Semester):
5561
match semester:
5662
case Semester.Fall:
57-
return date(year, month=9, day=1)
63+
return date(year, month=SEPTEMBER, day=1)
5864
case Semester.Summer:
59-
return date(year, month=5, day=1)
65+
return date(year, month=MAY, day=1)
6066
case Semester.Spring:
61-
return date(year, month=1, day=1)
67+
return date(year, month=JANUARY, day=1)

src/database.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -106,7 +106,7 @@ def setup_database():
106106
# TODO: where is sys.stdout piped to? I want all these to go to a specific logs folder
107107
sessionmanager = DatabaseSessionManager(
108108
SQLALCHEMY_TEST_DATABASE_URL if os.environ.get("LOCAL") else SQLALCHEMY_DATABASE_URL,
109-
{"echo": False},
109+
{"echo": True},
110110
)
111111

112112

src/dependencies.py

Lines changed: 25 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -1,73 +1,53 @@
1-
from enum import Enum
21
from typing import Annotated
32

4-
from fastapi import Depends, HTTPException, Request, status
3+
from fastapi import Cookie, Depends, HTTPException, status
54

65
import auth
76
import database
8-
import officers
9-
from officers.constants import OfficerPositionEnum
10-
from permission.types import WEBSITE_ADMIN_POSITIONS
7+
from utils.permissions import is_user_election_officer, is_user_website_admin
118

129

13-
# Permissions are granted if the Enum value >= the level needed
14-
class AdminTypeEnum(Enum):
15-
Election = 1
16-
Full = 2
17-
18-
19-
async def is_user_website_admin(computing_id: str, db_session: database.DBSession) -> bool:
20-
for position in await officers.crud.current_officer_positions(db_session, computing_id):
21-
if position in WEBSITE_ADMIN_POSITIONS:
22-
return True
10+
async def user(db_session: database.DBSession, session_id: Annotated[str | None, Cookie()] = None) -> str | None:
11+
if session_id is None:
12+
return None
2313

24-
return False
14+
session_computing_id = await auth.crud.get_computing_id(db_session, session_id)
2515

16+
return session_computing_id
2617

27-
# TODO: Add an election admin version that checks the election attempting to be modified as well
28-
async def is_user_election_officer(computing_id: str, db_session: database.DBSession) -> bool:
29-
"""
30-
An current election officer has access to all election, prior election officers have no access.
31-
"""
32-
officer_terms = await officers.crud.get_current_terms_by_position(db_session, OfficerPositionEnum.ELECTIONS_OFFICER)
33-
for officer in officer_terms:
34-
if computing_id == officer.computing_id:
35-
return True
3618

37-
return False
19+
SessionUser = Annotated[str, Depends(user)]
3820

3921

40-
async def get_user(request: Request, db_session: database.DBSession) -> tuple[str, str]:
41-
"""gets the user's computing_id, or raises an exception if the current request is not logged in"""
42-
session_id = request.cookies.get("session_id", None)
22+
async def logged_in_user(db_session: database.DBSession, session_id: Annotated[str | None, Cookie()] = None) -> str:
4323
if session_id is None:
4424
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="no session id")
4525

4626
session_computing_id = await auth.crud.get_computing_id(db_session, session_id)
4727
if session_computing_id is None:
4828
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="no computing id")
4929

50-
return session_id, session_computing_id
30+
return session_computing_id
31+
32+
33+
LoggedInUser = Annotated[str, Depends(logged_in_user)]
34+
35+
36+
async def perm_election(db_session: database.DBSession, computing_id: LoggedInUser) -> str:
37+
if not is_user_website_admin(computing_id, db_session) or is_user_election_officer(computing_id, db_session):
38+
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="must be an election admin")
5139

40+
return computing_id
5241

53-
# Allows path functions to use this without having to add a bunch of checks
54-
SessionUser = Annotated[tuple[str, str], Depends(get_user)]
5542

43+
ElectionAdmin = Annotated[str, Depends(perm_election)]
5644

57-
async def get_admin(
58-
db_session: database.DBSession, session_user: SessionUser, admin_type: AdminTypeEnum
59-
) -> tuple[str, str]:
60-
session_id, computing_id = session_user
61-
# Website admins have full permissions
62-
if is_user_website_admin(computing_id, db_session):
63-
return (session_id, computing_id)
6445

65-
# Election officers have lower permissions
66-
if admin_type == AdminTypeEnum.Election and is_user_election_officer(computing_id, db_session):
67-
return (session_id, computing_id)
46+
async def perm_admin(db_session: database.DBSession, computing_id: LoggedInUser):
47+
if not is_user_website_admin(computing_id, db_session):
48+
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="must be an admin")
6849

69-
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="must be an admin")
50+
return computing_id
7051

7152

72-
# Allows path functions to use this without having to add a bunch of checks
73-
SessionAdmin = Annotated[tuple[str, str], Depends(get_admin)]
53+
SiteAdmin = Annotated[str, Depends(perm_admin)]

src/elections/models.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ class ElectionResponse(BaseModel):
2929
available_positions: list[OfficerPositionEnum]
3030
status: ElectionStatusEnum
3131

32+
# Private fields
3233
survey_link: str | None = Field(None, description="Only available to admins")
3334
candidates: list[RegistrationModel] | None = Field(None, description="Only available to admins")
3435

0 commit comments

Comments
 (0)