1414from decouple import config
1515from fastapi import APIRouter , Depends , FastAPI , Form , HTTPException , Request , status
1616from fastapi .middleware .cors import CORSMiddleware
17- from fastapi .responses import HTMLResponse , RedirectResponse
17+ from fastapi .responses import HTMLResponse , JSONResponse , RedirectResponse
1818from fastapi .security import OAuth2PasswordBearer , OAuth2PasswordRequestForm
1919from fastapi .templating import Jinja2Templates
2020from icecream import ic
2727from schedule import check_and_revert_snooze , get_current_schedule_time , get_schedule , snooze_schedule
2828from sign_jwt import main as gen_token
2929from slackbot import *
30- from typing import List , Union
3130
3231# verbose icecream
3332ic .configureOutput (includeContext = True )
4645bypass_schedule = config ("OVERRIDE" , default = False , cast = bool )
4746DEV = config ("DEV" , default = False , cast = bool )
4847
49- # time
50- current_time_local = arrow .now (tz )
51- current_time_utc = arrow .utcnow ()
52- current_day = current_time_local .format ("dddd" ) # Monday, Tuesday, etc.
5348time .tzset ()
5449
5550# pandas don't truncate output
6661PORT = config ("PORT" , default = 3000 , cast = int )
6762SECRET_KEY = config ("SECRET_KEY" )
6863ALGORITHM = config ("ALGORITHM" , default = "HS256" )
69- TOKEN_EXPIRE = config ("TOKEN_EXPIRE" , default = 30 , cast = int )
64+ TOKEN_EXPIRE = config ("TOKEN_EXPIRE" , default = 480 , cast = int )
7065
7166try :
7267 DB_USER = config ("DB_USER" )
8782DISABLE_IP_WHITELIST = config ("DISABLE_IP_WHITELIST" , default = False , cast = bool )
8883
8984
85+ def _parse_public_ips () -> list [str ]:
86+ raw = config ("PUBLIC_IPS" , default = "" )
87+ return [ip .strip () for ip in raw .split ("," ) if ip .strip ()]
88+
89+
9090class IPConfig (BaseModel ):
9191 whitelist : list [str ] = ["localhost" , "127.0.0.1" ]
9292 public_ips : list [str ] = []
9393
9494
95- ip_config = IPConfig ()
95+ ip_config = IPConfig (public_ips = _parse_public_ips () )
9696
9797
9898def is_ip_allowed (request : Request ):
@@ -120,7 +120,12 @@ async def lifespan(app: FastAPI) -> AsyncIterator[None]:
120120
121121
122122# main web app
123- app = FastAPI (title = "meetup_bot API" , openapi_url = "/meetup_bot.json" , lifespan = lifespan )
123+ app = FastAPI (
124+ title = "meetup_bot API" ,
125+ openapi_url = "/meetup_bot.json" ,
126+ lifespan = lifespan ,
127+ swagger_ui_parameters = {"persistAuthorization" : True },
128+ )
124129
125130# add `/api` route in front of all other endpoints
126131api_router = APIRouter (prefix = "/api" )
@@ -202,18 +207,17 @@ def authenticate_user(username: str, password: str):
202207def create_access_token (data : dict , expires_delta : timedelta | None = None ):
203208 """Create access token"""
204209 to_encode = data .copy ()
205- if expires_delta :
206- expire = datetime .utcnow () + expires_delta
207- else :
208- expire = datetime .utcnow () + timedelta (minutes = TOKEN_EXPIRE )
210+ expire = datetime .utcnow () + expires_delta if expires_delta else datetime .utcnow () + timedelta (minutes = TOKEN_EXPIRE )
209211 to_encode .update ({"exp" : expire })
210212 encoded_jwt = jwt .encode (to_encode , SECRET_KEY , algorithm = ALGORITHM )
211213
212214 return encoded_jwt
213215
214216
215- async def get_current_user (token : str | None = Depends (oauth2_scheme )):
216- """Get current user"""
217+ async def get_current_user (request : Request , token : str | None = Depends (oauth2_scheme )):
218+ """Get current user from Bearer token or session_token cookie."""
219+ if token is None :
220+ token = request .cookies .get ("session_token" )
217221 if token is None :
218222 return None
219223 credentials_exception = HTTPException (
@@ -227,8 +231,8 @@ async def get_current_user(token: str | None = Depends(oauth2_scheme)):
227231 if username is None :
228232 raise credentials_exception
229233 token_data = TokenData (username = username )
230- except JWTError :
231- raise credentials_exception
234+ except JWTError as err :
235+ raise credentials_exception from err
232236 user = get_user (username = token_data .username )
233237 if user is None :
234238 raise credentials_exception
@@ -284,7 +288,16 @@ async def login_for_oauth_token(form_data: OAuth2PasswordRequestForm = Depends()
284288 oauth_token_expires = timedelta (minutes = TOKEN_EXPIRE )
285289 oauth_token = create_access_token (data = {"sub" : user .username }, expires_delta = oauth_token_expires )
286290
287- return {"access_token" : oauth_token , "token_type" : "bearer" }
291+ response = JSONResponse (content = {"access_token" : oauth_token , "token_type" : "bearer" })
292+ response .set_cookie (
293+ key = "session_token" ,
294+ value = oauth_token ,
295+ httponly = True ,
296+ secure = not DEV ,
297+ samesite = "lax" ,
298+ max_age = TOKEN_EXPIRE * 60 ,
299+ )
300+ return response
288301
289302
290303"""
@@ -322,7 +335,18 @@ def index(request: Request):
322335def login (request : Request , username : str = Form (...), password : str = Form (...)):
323336 """Redirect to "/docs" from index page if user successfully logs in with HTML form"""
324337 if load_user (username ) and verify_password (password , load_user (username ).hashed_password ):
325- return RedirectResponse (url = "/docs" , status_code = 303 )
338+ oauth_token_expires = timedelta (minutes = TOKEN_EXPIRE )
339+ oauth_token = create_access_token (data = {"sub" : username }, expires_delta = oauth_token_expires )
340+ response = RedirectResponse (url = "/docs" , status_code = 303 )
341+ response .set_cookie (
342+ key = "session_token" ,
343+ value = oauth_token ,
344+ httponly = True ,
345+ secure = not DEV ,
346+ samesite = "lax" ,
347+ max_age = TOKEN_EXPIRE * 60 ,
348+ )
349+ return response
326350
327351
328352@api_router .get ("/token" )
@@ -347,7 +371,7 @@ def generate_token(current_user: User = Depends(get_current_active_user)):
347371 refresh_token = tokens ["refresh_token" ]
348372 except KeyError as e :
349373 print (f"{ Fore .RED } { error :<10} { Fore .RESET } KeyError: { e } " )
350- raise HTTPException (status_code = 500 , detail = "Internal Server Error" )
374+ raise HTTPException (status_code = 500 , detail = "Internal Server Error" ) from e
351375
352376 return access_token , refresh_token
353377
@@ -384,28 +408,24 @@ def get_events(
384408 exclusion_list = exclusion_list + exclusions
385409
386410 response = send_request (access_token , query , vars )
387-
388- export_to_file (response , format , exclusions = exclusion_list )
411+ frames = [format_response (response , exclusions = exclusion_list )]
389412
390413 # third-party query (batched)
391414 responses = send_batched_group_request (access_token , url_vars )
392- output = []
393415 for i , response in enumerate (responses ):
394- if len (format_response (response , exclusions = exclusion_list )) > 0 :
395- output .append (response )
416+ df = format_response (response , exclusions = exclusion_list )
417+ if len (df ) > 0 :
418+ frames .append (df )
396419 else :
397420 print (f"{ Fore .GREEN } { info :<10} { Fore .RESET } No upcoming events for { url_vars [i ]} found" )
398- for resp in output :
399- export_to_file (resp , format )
400421
401- # cleanup output file
402- sort_json ( json_fn )
422+ combined = pd . concat ( frames , ignore_index = True )
423+ events = prepare_events ( combined )
403424
404- # check if file exists after sorting
405- if not os .path .exists (json_fn ) or os .stat (json_fn ).st_size == 0 :
425+ if not events :
406426 return {"message" : "No events found" , "events" : []}
407427
408- return pd . read_json ( json_fn ). to_dict ( 'records' )
428+ return events
409429
410430
411431@api_router .get ("/check-schedule" )
@@ -414,6 +434,9 @@ def should_post_to_slack(auth: dict = Depends(ip_whitelist_or_auth), request: Re
414434 Check if it's time to post to Slack based on the schedule
415435 """
416436
437+ current_time_local = arrow .now (tz )
438+ current_day = current_time_local .format ("dddd" )
439+
417440 with db_session :
418441 check_and_revert_snooze () # Check and revert any expired snoozes
419442 schedule = get_schedule (current_day )
@@ -465,19 +488,18 @@ def post_slack(
465488
466489 check_auth (auth )
467490
468- get_events (auth = auth , location = location , exclusions = exclusions )
491+ events = get_events (auth = auth , location = location , exclusions = exclusions )
492+
493+ # handle "no events found" response
494+ if isinstance (events , dict ):
495+ events = events .get ("events" , [])
469496
470- # open json file and convert to list of strings
471- msg = fmt_json (json_fn )
497+ msg = fmt_events (events )
472498
473- # if channel_name is not None, post to channel as one concatenated string
474499 if channel_name is not None :
475- # get channel id chan_dict key value pair
476500 channel_id = chan_dict [channel_name ]
477- # post to single channel
478501 send_message ("\n " .join (msg ), channel_id )
479502 else :
480- # post to all channels
481503 for name , id in channels .items ():
482504 send_message ("\n " .join (msg ), id )
483505
@@ -503,7 +525,7 @@ def snooze_slack_post(
503525 snooze_schedule (duration )
504526 return {"message" : f"Slack post snoozed for { duration } " }
505527 except ValueError as e :
506- raise HTTPException (status_code = 400 , detail = str (e ))
528+ raise HTTPException (status_code = 400 , detail = str (e )) from e
507529
508530
509531@api_router .get ("/schedule" )
0 commit comments