22import os
33import asyncio
44import logging
5+ import sqlite3
56from fastapi import FastAPI , Depends , HTTPException , Response , Request
67from fastapi .responses import RedirectResponse , JSONResponse
78from fastapi .middleware .cors import CORSMiddleware
4041 logging .basicConfig (level = logging .INFO , format = "%(asctime)s %(levelname)s %(message)s" )
4142
4243# ---- Session persistence ----
43- SESSIONS_FILE = "sessions.json"
44+ SESSIONS_FILE = "sessions.json" # legacy migration source
45+ SESSIONS_DB_PATH = os .getenv ("SESSIONS_DB_PATH" , "sessions.db" )
4446
4547# ---- GitHub API URL ----
4648GITHUB_API_URL = "https://api.github.com/users"
@@ -59,21 +61,81 @@ class ChatResponse(BaseModel):
5961 meta : dict = {}
6062
6163
62- def load_sessions ():
63- if os .path .exists (SESSIONS_FILE ):
64- with open (SESSIONS_FILE , "r" ) as f :
65- try :
66- return json .load (f )
67- except json .JSONDecodeError :
68- return {}
69- return {}
64+ def _db_connect ():
65+ conn = sqlite3 .connect (SESSIONS_DB_PATH )
66+ conn .row_factory = sqlite3 .Row
67+ return conn
68+
69+
70+ def init_session_store ():
71+ with _db_connect () as conn :
72+ conn .execute (
73+ """
74+ CREATE TABLE IF NOT EXISTS sessions (
75+ session_id TEXT PRIMARY KEY,
76+ data TEXT NOT NULL,
77+ expires REAL NOT NULL
78+ )
79+ """
80+ )
81+ conn .execute ("CREATE INDEX IF NOT EXISTS idx_sessions_expires ON sessions(expires)" )
82+
83+
84+ def session_store_set (session_id : str , data : dict ):
85+ expires = float (data .get ("expires" , 0 ))
86+ with _db_connect () as conn :
87+ conn .execute (
88+ "INSERT OR REPLACE INTO sessions (session_id, data, expires) VALUES (?, ?, ?)" ,
89+ (session_id , json .dumps (data ), expires ),
90+ )
91+
92+
93+ def session_store_get (session_id : str ) -> Optional [dict ]:
94+ with _db_connect () as conn :
95+ row = conn .execute (
96+ "SELECT data, expires FROM sessions WHERE session_id = ?" ,
97+ (session_id ,),
98+ ).fetchone ()
99+ if not row :
100+ return None
101+ if float (row ["expires" ]) < time .time ():
102+ conn .execute ("DELETE FROM sessions WHERE session_id = ?" , (session_id ,))
103+ return None
104+ try :
105+ return json .loads (row ["data" ])
106+ except Exception :
107+ conn .execute ("DELETE FROM sessions WHERE session_id = ?" , (session_id ,))
108+ return None
70109
71- def save_sessions ():
72- with open (SESSIONS_FILE , "w" ) as f :
73- json .dump (sessions , f )
74110
75- sessions = load_sessions ()
111+ def session_store_delete (session_id : str ):
112+ with _db_connect () as conn :
113+ conn .execute ("DELETE FROM sessions WHERE session_id = ?" , (session_id ,))
76114
115+
116+ def session_store_cleanup_expired ():
117+ with _db_connect () as conn :
118+ conn .execute ("DELETE FROM sessions WHERE expires < ?" , (time .time (),))
119+
120+
121+ def migrate_legacy_sessions ():
122+ if not os .path .exists (SESSIONS_FILE ):
123+ return
124+ try :
125+ with open (SESSIONS_FILE , "r" ) as f :
126+ legacy = json .load (f )
127+ except Exception :
128+ return
129+ if not isinstance (legacy , dict ):
130+ return
131+ for sid , data in legacy .items ():
132+ if isinstance (data , dict ):
133+ session_store_set (sid , data )
134+
135+
136+ init_session_store ()
137+ migrate_legacy_sessions ()
138+ session_store_cleanup_expired ()
77139# ---- Cookie signing ----
78140SECRET_KEY = os .getenv ("SECRET_KEY" )
79141if not SECRET_KEY :
@@ -207,7 +269,7 @@ def get_current_user(request: Request):
207269 raise HTTPException (status_code = 401 , detail = "Not authenticated" )
208270 try :
209271 session_id = serializer .loads (cookie )["session_id" ]
210- session = sessions . get (session_id )
272+ session = session_store_get (session_id )
211273 if not session or session ["expires" ] < time .time ():
212274 raise HTTPException (status_code = 401 , detail = "Session expired" )
213275 return session
@@ -251,9 +313,6 @@ async def github_login():
251313async def github_callback (request : Request , code : str , state : Optional [str ] = None ):
252314 ensure_github_oauth_config ()
253315 validate_oauth_state (request , state )
254- ensure_github_oauth_config ()
255- if not state :
256- raise HTTPException (status_code = 400 , detail = "Missing state parameter" )
257316
258317 async with httpx .AsyncClient () as client :
259318 token_response = await client .post (
@@ -272,13 +331,13 @@ async def github_callback(request: Request, code: str, state: Optional[str] = No
272331 user_data = await get_github_user (token_data ["access_token" ])
273332
274333 session_id = str (uuid4 ())
275- sessions [ session_id ] = {
334+ session_data = {
276335 "access_token" : token_data ["access_token" ],
277336 "user" : user_data ["login" ],
278337 "user_id" : user_data ["id" ],
279338 "expires" : time .time () + SESSION_TTL_SECONDS
280339 }
281- save_sessions ( )
340+ session_store_set ( session_id , session_data )
282341
283342 signed_cookie = serializer .dumps ({"session_id" : session_id })
284343 response = RedirectResponse (url = FRONTEND_URL )
@@ -908,10 +967,8 @@ async def logout(request: Request, user=Depends(get_current_user)):
908967 except Exception :
909968 pass
910969
911- # Remove from sessions.json if exists
912- if session_id and session_id in sessions :
913- sessions .pop (session_id , None )
914- save_sessions ()
970+ if session_id :
971+ session_store_delete (session_id )
915972
916973 # Clear cookie
917974 response = JSONResponse ({"message" : "Logged out" })
@@ -1053,3 +1110,5 @@ async def get_task_status(task_id: str):
10531110
10541111
10551112
1113+
1114+
0 commit comments