Skip to content

Commit 2cee272

Browse files
committed
added automated backend auth/session tests and verified they pass.
1 parent 667a3e1 commit 2cee272

5 files changed

Lines changed: 207 additions & 25 deletions

File tree

.env.example

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ APP_ENV=development
22
SECRET_KEY=replace-with-a-long-random-string
33
FRONTEND_URL=http://localhost:5173
44
SESSION_TTL_SECONDS=3600
5+
SESSIONS_DB_PATH=sessions.db
56
GITHUB_CLIENT_ID=your_github_client_id
67
GITHUB_CLIENT_SECRET=your_github_client_secret
78
GITHUB_TOKEN=your_github_token

.gitignore

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,3 +14,6 @@ venv/
1414
# Logs
1515
*.log
1616
sessions.json
17+
sessions.db
18+
sessions.db-shm
19+
sessions.db-wal

PRODUCTION_READINESS.md

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,9 @@
1111
## 2. Auth & Sessions
1212
- [x] Session cookie uses environment-based `secure` and `samesite`.
1313
- [x] Session TTL configurable (`SESSION_TTL_SECONDS`).
14-
- [ ] Replace `sessions.json` with Redis/DB-backed session store.
14+
- [x] Replace `sessions.json` with SQLite-backed session store (`sessions.db`).
1515
- [ ] Add CSRF protections for state-changing endpoints.
16+
- [ ] Optional: move sessions from SQLite to Redis for multi-instance horizontal scaling.
1617

1718
## 3. Reliability
1819
- [x] AI model fallback and status endpoint (`/api/ai-status`) added.
@@ -28,7 +29,8 @@
2829

2930
## 5. Quality Gates
3031
- [x] Static checks currently passing (`py_compile`, frontend lint).
31-
- [ ] Add backend API tests (auth, files, chat).
32+
- [x] Add backend API tests for auth/session lifecycle (`tests/test_auth_sessions.py`).
33+
- [ ] Extend backend API tests to files/chat endpoints.
3234
- [ ] Add frontend integration tests (repo tree + file preview + chat).
3335
- [ ] CI pipeline enforcing lint/tests before merge.
3436

main.py

Lines changed: 82 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import os
33
import asyncio
44
import logging
5+
import sqlite3
56
from fastapi import FastAPI, Depends, HTTPException, Response, Request
67
from fastapi.responses import RedirectResponse, JSONResponse
78
from fastapi.middleware.cors import CORSMiddleware
@@ -40,7 +41,8 @@
4041
logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s")
4142

4243
# ---- Session persistence ----
43-
SESSIONS_FILE = "sessions.json"
44+
SESSIONS_FILE = "sessions.json" # legacy migration source
45+
SESSIONS_DB_PATH = os.getenv("SESSIONS_DB_PATH", "sessions.db")
4446

4547
# ---- GitHub API URL ----
4648
GITHUB_API_URL = "https://api.github.com/users"
@@ -59,21 +61,81 @@ class ChatResponse(BaseModel):
5961
meta: dict = {}
6062

6163

62-
def load_sessions():
63-
if os.path.exists(SESSIONS_FILE):
64-
with open(SESSIONS_FILE, "r") as f:
65-
try:
66-
return json.load(f)
67-
except json.JSONDecodeError:
68-
return {}
69-
return {}
64+
def _db_connect():
65+
conn = sqlite3.connect(SESSIONS_DB_PATH)
66+
conn.row_factory = sqlite3.Row
67+
return conn
68+
69+
70+
def init_session_store():
71+
with _db_connect() as conn:
72+
conn.execute(
73+
"""
74+
CREATE TABLE IF NOT EXISTS sessions (
75+
session_id TEXT PRIMARY KEY,
76+
data TEXT NOT NULL,
77+
expires REAL NOT NULL
78+
)
79+
"""
80+
)
81+
conn.execute("CREATE INDEX IF NOT EXISTS idx_sessions_expires ON sessions(expires)")
82+
83+
84+
def session_store_set(session_id: str, data: dict):
85+
expires = float(data.get("expires", 0))
86+
with _db_connect() as conn:
87+
conn.execute(
88+
"INSERT OR REPLACE INTO sessions (session_id, data, expires) VALUES (?, ?, ?)",
89+
(session_id, json.dumps(data), expires),
90+
)
91+
92+
93+
def session_store_get(session_id: str) -> Optional[dict]:
94+
with _db_connect() as conn:
95+
row = conn.execute(
96+
"SELECT data, expires FROM sessions WHERE session_id = ?",
97+
(session_id,),
98+
).fetchone()
99+
if not row:
100+
return None
101+
if float(row["expires"]) < time.time():
102+
conn.execute("DELETE FROM sessions WHERE session_id = ?", (session_id,))
103+
return None
104+
try:
105+
return json.loads(row["data"])
106+
except Exception:
107+
conn.execute("DELETE FROM sessions WHERE session_id = ?", (session_id,))
108+
return None
70109

71-
def save_sessions():
72-
with open(SESSIONS_FILE, "w") as f:
73-
json.dump(sessions, f)
74110

75-
sessions = load_sessions()
111+
def session_store_delete(session_id: str):
112+
with _db_connect() as conn:
113+
conn.execute("DELETE FROM sessions WHERE session_id = ?", (session_id,))
76114

115+
116+
def session_store_cleanup_expired():
117+
with _db_connect() as conn:
118+
conn.execute("DELETE FROM sessions WHERE expires < ?", (time.time(),))
119+
120+
121+
def migrate_legacy_sessions():
122+
if not os.path.exists(SESSIONS_FILE):
123+
return
124+
try:
125+
with open(SESSIONS_FILE, "r") as f:
126+
legacy = json.load(f)
127+
except Exception:
128+
return
129+
if not isinstance(legacy, dict):
130+
return
131+
for sid, data in legacy.items():
132+
if isinstance(data, dict):
133+
session_store_set(sid, data)
134+
135+
136+
init_session_store()
137+
migrate_legacy_sessions()
138+
session_store_cleanup_expired()
77139
# ---- Cookie signing ----
78140
SECRET_KEY = os.getenv("SECRET_KEY")
79141
if not SECRET_KEY:
@@ -207,7 +269,7 @@ def get_current_user(request: Request):
207269
raise HTTPException(status_code=401, detail="Not authenticated")
208270
try:
209271
session_id = serializer.loads(cookie)["session_id"]
210-
session = sessions.get(session_id)
272+
session = session_store_get(session_id)
211273
if not session or session["expires"] < time.time():
212274
raise HTTPException(status_code=401, detail="Session expired")
213275
return session
@@ -251,9 +313,6 @@ async def github_login():
251313
async def github_callback(request: Request, code: str, state: Optional[str] = None):
252314
ensure_github_oauth_config()
253315
validate_oauth_state(request, state)
254-
ensure_github_oauth_config()
255-
if not state:
256-
raise HTTPException(status_code=400, detail="Missing state parameter")
257316

258317
async with httpx.AsyncClient() as client:
259318
token_response = await client.post(
@@ -272,13 +331,13 @@ async def github_callback(request: Request, code: str, state: Optional[str] = No
272331
user_data = await get_github_user(token_data["access_token"])
273332

274333
session_id = str(uuid4())
275-
sessions[session_id] = {
334+
session_data = {
276335
"access_token": token_data["access_token"],
277336
"user": user_data["login"],
278337
"user_id": user_data["id"],
279338
"expires": time.time() + SESSION_TTL_SECONDS
280339
}
281-
save_sessions()
340+
session_store_set(session_id, session_data)
282341

283342
signed_cookie = serializer.dumps({"session_id": session_id})
284343
response = RedirectResponse(url=FRONTEND_URL)
@@ -908,10 +967,8 @@ async def logout(request: Request, user=Depends(get_current_user)):
908967
except Exception:
909968
pass
910969

911-
# Remove from sessions.json if exists
912-
if session_id and session_id in sessions:
913-
sessions.pop(session_id, None)
914-
save_sessions()
970+
if session_id:
971+
session_store_delete(session_id)
915972

916973
# Clear cookie
917974
response = JSONResponse({"message": "Logged out"})
@@ -1053,3 +1110,5 @@ async def get_task_status(task_id: str):
10531110

10541111

10551112

1113+
1114+

tests/test_auth_sessions.py

Lines changed: 117 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,117 @@
1+
import os
2+
import tempfile
3+
import unittest
4+
from unittest.mock import patch
5+
6+
from fastapi.testclient import TestClient
7+
8+
import main
9+
10+
11+
class _MockHTTPResponse:
12+
def __init__(self, status_code=200, json_data=None):
13+
self.status_code = status_code
14+
self._json_data = json_data or {}
15+
self.headers = {}
16+
self.text = ""
17+
18+
def json(self):
19+
return self._json_data
20+
21+
def raise_for_status(self):
22+
if self.status_code >= 400:
23+
raise RuntimeError(f"HTTP {self.status_code}")
24+
25+
26+
class _MockAsyncClient:
27+
async def __aenter__(self):
28+
return self
29+
30+
async def __aexit__(self, exc_type, exc, tb):
31+
return False
32+
33+
async def post(self, url, params=None, headers=None):
34+
if "login/oauth/access_token" in url:
35+
return _MockHTTPResponse(
36+
200,
37+
{"access_token": "mock_access_token"},
38+
)
39+
return _MockHTTPResponse(404, {"message": "not found"})
40+
41+
async def get(self, url, headers=None):
42+
if url == "https://api.github.com/user":
43+
return _MockHTTPResponse(
44+
200,
45+
{"login": "mock-user", "id": 12345},
46+
)
47+
return _MockHTTPResponse(404, {"message": "not found"})
48+
49+
50+
class AuthSessionTests(unittest.TestCase):
51+
@classmethod
52+
def setUpClass(cls):
53+
cls._orig_db = main.SESSIONS_DB_PATH
54+
cls._tmp_dir = tempfile.TemporaryDirectory()
55+
main.SESSIONS_DB_PATH = os.path.join(cls._tmp_dir.name, "test_sessions.db")
56+
main.init_session_store()
57+
58+
@classmethod
59+
def tearDownClass(cls):
60+
main.SESSIONS_DB_PATH = cls._orig_db
61+
try:
62+
cls._tmp_dir.cleanup()
63+
except PermissionError:
64+
# On Windows, sqlite may release file handles slightly later.
65+
pass
66+
67+
def setUp(self):
68+
main.session_store_cleanup_expired()
69+
self.client = TestClient(main.app)
70+
71+
def tearDown(self):
72+
self.client.close()
73+
74+
def test_login_sets_oauth_state_cookie(self):
75+
res = self.client.get("/login/github", follow_redirects=False)
76+
self.assertEqual(res.status_code, 307)
77+
self.assertIn("oauth_state=", res.headers.get("set-cookie", ""))
78+
self.assertIn("github.com/login/oauth/authorize", res.headers.get("location", ""))
79+
80+
def test_callback_rejects_missing_state_cookie(self):
81+
res = self.client.get("/auth/github/callback?code=abc&state=s1", follow_redirects=False)
82+
self.assertEqual(res.status_code, 400)
83+
self.assertIn("OAuth state cookie", res.json().get("detail", ""))
84+
85+
def test_callback_rejects_state_mismatch(self):
86+
self.client.get("/login/github", follow_redirects=False)
87+
res = self.client.get("/auth/github/callback?code=abc&state=wrong", follow_redirects=False)
88+
self.assertEqual(res.status_code, 400)
89+
self.assertIn("state mismatch", res.json().get("detail", ""))
90+
91+
def test_callback_creates_session_and_auth_works_then_logout(self):
92+
login_res = self.client.get("/login/github", follow_redirects=False)
93+
self.assertEqual(login_res.status_code, 307)
94+
location = login_res.headers["location"]
95+
state = location.split("state=")[1].split("&")[0]
96+
97+
with patch("main.httpx.AsyncClient", _MockAsyncClient):
98+
callback_res = self.client.get(
99+
f"/auth/github/callback?code=abc123&state={state}",
100+
follow_redirects=False,
101+
)
102+
self.assertEqual(callback_res.status_code, 307)
103+
self.assertIn("session_id=", callback_res.headers.get("set-cookie", ""))
104+
105+
authed = self.client.get("/test-auth")
106+
self.assertEqual(authed.status_code, 200)
107+
self.assertIn("mock-user", authed.json().get("message", ""))
108+
109+
logout = self.client.post("/logout")
110+
self.assertEqual(logout.status_code, 200)
111+
112+
authed_after = self.client.get("/test-auth")
113+
self.assertEqual(authed_after.status_code, 401)
114+
115+
116+
if __name__ == "__main__":
117+
unittest.main()

0 commit comments

Comments
 (0)