22import logging
33import os
44import urllib .parse
5+ from typing import Literal
56
67import requests # TODO: make this async
78import xmltodict
8- from fastapi import APIRouter , BackgroundTasks , HTTPException , Request
9+ from fastapi import APIRouter , BackgroundTasks , Body , HTTPException , Request , Response
910from fastapi .responses import JSONResponse , PlainTextResponse , RedirectResponse
1011
1112import database
1213from auth import crud
13- from constants import FRONTEND_ROOT_URL
14+ from auth .models import LoginBodyModel
15+ from constants import IS_PROD
16+ from utils .shared_models import DetailModel
1417
1518_logger = logging .getLogger (__name__ )
1619
@@ -32,27 +35,34 @@ def generate_session_id_b64(num_bytes: int) -> str:
3235)
3336
3437
35- # NOTE: logging in a second time invaldiates the last session_id
36- @router .get (
38+ # NOTE: logging in a second time invalidates the last session_id
39+ @router .post (
3740 "/login" ,
38- description = "Login to the sfucsss.org. Must redirect to this endpoint from SFU's cas authentication service for correct parameters" ,
41+ description = "Create a login session." ,
42+ response_description = "Successfully validated with SFU's CAS" ,
43+ response_model = None ,
44+ responses = {
45+ 307 : { "description" : "Successful validation, with redirect" },
46+ 400 : { "description" : "Origin is missing." , "model" : DetailModel },
47+ 401 : { "description" : "Failed to validate ticket with SFU's CAS" , "model" : DetailModel }
48+ },
49+ operation_id = "login" ,
3950)
4051async def login_user (
41- redirect_path : str ,
42- redirect_fragment : str ,
43- ticket : str ,
52+ request : Request ,
4453 db_session : database .DBSession ,
4554 background_tasks : BackgroundTasks ,
55+ body : LoginBodyModel
4656):
4757 # verify the ticket is valid
48- service = urllib .parse .quote (f"{ FRONTEND_ROOT_URL } /api/auth/login?redirect_path={ redirect_path } &redirect_fragment={ redirect_fragment } " )
49- service_validate_url = f"https://cas.sfu.ca/cas/serviceValidate?service={ service } &ticket={ ticket } "
58+ service_url = body .service
59+ service = urllib .parse .quote (service_url )
60+ service_validate_url = f"https://cas.sfu.ca/cas/serviceValidate?service={ service } &ticket={ body .ticket } "
5061 cas_response = xmltodict .parse (requests .get (service_validate_url ).text )
5162
5263 if "cas:authenticationFailure" in cas_response ["cas:serviceResponse" ]:
5364 _logger .info (f"User failed to login, with response { cas_response } " )
54- raise HTTPException (status_code = 401 , detail = "authentication error, ticket likely invalid" )
55-
65+ raise HTTPException (status_code = 401 , detail = "authentication error" )
5666 else :
5767 session_id = generate_session_id_b64 (256 )
5868 computing_id = cas_response ["cas:serviceResponse" ]["cas:authenticationSuccess" ]["cas:user" ]
@@ -63,15 +73,29 @@ async def login_user(
6373 # clean old sessions after sending the response
6474 background_tasks .add_task (crud .task_clean_expired_user_sessions , db_session )
6575
66- response = RedirectResponse (FRONTEND_ROOT_URL + redirect_path + "#" + redirect_fragment )
76+ if body .redirect_url :
77+ origin = request .headers .get ("origin" )
78+ if origin :
79+ response = RedirectResponse (origin + body .redirect_url )
80+ else :
81+ raise HTTPException (status_code = 400 , detail = "bad origin" )
82+ else :
83+ response = Response ()
84+
6785 response .set_cookie (
68- key = "session_id" , value = session_id
86+ key = "session_id" ,
87+ value = session_id ,
88+ secure = IS_PROD ,
89+ httponly = True ,
90+ samesite = None if IS_PROD else "lax" ,
91+ domain = ".sfucsss.org" if IS_PROD else None
6992 ) # this overwrites any past, possibly invalid, session_id
7093 return response
7194
7295
7396@router .get (
7497 "/logout" ,
98+ operation_id = "logout" ,
7599 description = "Logs out the current user by invalidating the session_id cookie" ,
76100)
77101async def logout_user (
@@ -94,6 +118,7 @@ async def logout_user(
94118
95119@router .get (
96120 "/user" ,
121+ operation_id = "get_user" ,
97122 description = "Get info about the current user. Only accessible by that user" ,
98123)
99124async def get_user (
@@ -116,6 +141,7 @@ async def get_user(
116141
117142@router .patch (
118143 "/user" ,
144+ operation_id = "update_user" ,
119145 description = "Update information for the currently logged in user. Only accessible by that user" ,
120146)
121147async def update_user (
0 commit comments