Skip to content

Commit e9f7e93

Browse files
authored
Merge pull request #3 from TaskarCenterAtUW/stabilize-proxy
Stabilize the OSM reverse proxy
2 parents a6f4c8f + 3d0c056 commit e9f7e93

5 files changed

Lines changed: 287 additions & 84 deletions

File tree

api/core/config.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,17 @@
11
from pydantic_settings import BaseSettings, SettingsConfigDict
2+
3+
24
class Settings(BaseSettings):
35
"""Application settings."""
46

57
PROJECT_NAME: str = "Workspaces API"
68

9+
# JSON array of allowed CORS origins. For example:
10+
#
11+
# ["https://workspaces.example.com", "https://leaderboard.example.com"]
12+
#
13+
CORS_ORIGINS: list[str] = []
14+
715
TASK_DATABASE_URL: str = "postgresql+asyncpg://user:pass@localhost:5432/tasking_manager"
816
OSM_DATABASE_URL: str = "postgresql+asyncpg://user:pass@localhost:5432/tasking_manager"
917

@@ -18,8 +26,8 @@ class Settings(BaseSettings):
1826
"https://raw.githubusercontent.com/TaskarCenterAtUW/asr-quests/refs/heads/main/schema/schema.json"
1927
)
2028

21-
# proxy destination--"osm-rails" is a virtual docker network endpoint
22-
WS_OSM_HOST: str = "http://osm-rails:3000"
29+
# proxy destination--"osm-web" is a virtual docker network endpoint
30+
WS_OSM_HOST: str = "http://osm-web"
2331
#WS_OSM_HOST: str = "https://osm.workspaces-dev.sidewalks.washington.edu"
2432

2533
SENTRY_DSN: str = ""

api/core/jwt.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
import jwt
2+
3+
from api.core.config import settings
4+
5+
# Singleton JWKS client reused to take advantage of internal cert/key caching:
6+
_jwks_client: jwt.PyJWKClient | None = None
7+
8+
9+
def _get_jwks_client() -> jwt.PyJWKClient:
10+
global _jwks_client
11+
12+
if _jwks_client is None:
13+
_jwks_client = jwt.PyJWKClient(
14+
f"{settings.TDEI_OIDC_URL.rstrip("/")}/realms/"
15+
f"{settings.TDEI_OIDC_REALM}/protocol/openid-connect/certs"
16+
)
17+
18+
return _jwks_client
19+
20+
21+
def validate_and_decode_token(token: str) -> dict:
22+
# TODO: use an async client like pyjwt-key-fetcher
23+
signing_key = _get_jwks_client().get_signing_key_from_jwt(token)
24+
25+
decoded = jwt.decode_complete(
26+
token,
27+
key=signing_key.key,
28+
algorithms=["RS256"],
29+
# OIDC server does not currently differentiate tokens by audience
30+
options={"verify_aud": False},
31+
)
32+
33+
return decoded.get("payload", {})

api/core/security.py

Lines changed: 73 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,16 @@
1-
import json
21
from enum import StrEnum
32
from uuid import UUID
43

54
import cachetools
6-
import jwt
7-
import requests
5+
import httpx
86
from fastapi import Depends, HTTPException, status
97
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
108
from sqlalchemy import text
119
from sqlmodel.ext.asyncio.session import AsyncSession
1210

1311
from api.core.config import settings
1412
from api.core.database import get_osm_session, get_task_session
13+
from api.core.jwt import validate_and_decode_token
1514
from api.core.logging import get_logger
1615
from api.src.workspaces.schemas import WorkspaceUserRoleType
1716

@@ -23,6 +22,25 @@
2322
maxsize=1000, ttl=60 * 60
2423
)
2524

25+
# Shared HTTP client for TDEI backend calls. Initialized by main.py lifespan.
26+
_tdei_client: httpx.AsyncClient | None = None
27+
28+
29+
def init_tdei_client() -> None:
30+
global _tdei_client
31+
_tdei_client = httpx.AsyncClient(
32+
base_url=settings.TDEI_BACKEND_URL,
33+
timeout=httpx.Timeout(connect=10, read=30, write=30, pool=10),
34+
)
35+
36+
37+
async def close_tdei_client() -> None:
38+
global _tdei_client
39+
if _tdei_client is not None:
40+
await _tdei_client.aclose()
41+
_tdei_client = None
42+
43+
2644
security = HTTPBearer()
2745

2846

@@ -84,7 +102,9 @@ def isWorkspaceLead(self, workspaceId: int) -> bool:
84102

85103
for pg in self.projectGroups:
86104
if TdeiProjectGroupRole.POINT_OF_CONTACT in pg.tdeiRoles:
87-
if workspaceId in self.accessibleWorkspaceIds[pg.project_group_id]:
105+
if workspaceId in self.accessibleWorkspaceIds.get(
106+
pg.project_group_id, []
107+
):
88108
return True
89109

90110
return False
@@ -118,6 +138,7 @@ def get_task_db_session(
118138
) -> AsyncSession:
119139
return session
120140

141+
121142
async def validate_token(
122143
credentials: HTTPAuthorizationCredentials = Depends(security),
123144
osm_db_session: AsyncSession = Depends(get_osm_db_session),
@@ -129,19 +150,39 @@ async def validate_token(
129150
"""
130151
token = credentials.credentials
131152

153+
credentials_exception = HTTPException(
154+
status_code=status.HTTP_401_UNAUTHORIZED,
155+
detail="Invalid authentication credentials",
156+
headers={"WWW-Authenticate": "Bearer"},
157+
)
158+
159+
try:
160+
payload = validate_and_decode_token(token)
161+
except Exception:
162+
raise credentials_exception
163+
164+
user_id: str | None = payload.get("sub")
165+
if user_id is None:
166+
raise credentials_exception
167+
132168
# Check cache first
133169
if token in _token_cache:
134170
logger.info("Token validation cache hit")
135171
return _token_cache[token]
136172

137173
# Cache miss - perform full validation
138-
user_info = await _validate_token_uncached(token, osm_db_session, task_db_session)
174+
user_info = await _validate_token_uncached(
175+
token, user_id, payload, osm_db_session, task_db_session
176+
)
139177
_token_cache[token] = user_info
178+
140179
return user_info
141180

142181

143182
async def _validate_token_uncached(
144183
token: str,
184+
user_id: str,
185+
payload: dict,
145186
osm_db_session: AsyncSession,
146187
task_db_session: AsyncSession,
147188
) -> UserInfo:
@@ -153,66 +194,54 @@ async def _validate_token_uncached(
153194
headers={"WWW-Authenticate": "Bearer"},
154195
)
155196

156-
jwks_client = jwt.PyJWKClient(
157-
f"{settings.TDEI_OIDC_URL}realms/{settings.TDEI_OIDC_REALM}/protocol/openid-connect/certs"
158-
)
159-
160-
signing_key = jwks_client.get_signing_key_from_jwt(token)
161-
162-
jwtDecoded = jwt.decode_complete(
163-
token,
164-
key=signing_key.key,
165-
algorithms=["RS256"],
166-
# OIDC server does not currently differentiate tokens by audience
167-
options={"verify_aud": False}
168-
)
169-
payload = jwtDecoded.get("payload", {})
170-
171-
user_id: str | None = payload.get("sub")
172-
if user_id is None:
173-
raise credentials_exception
174-
175197
headers = {
176198
"Authorization": "Bearer " + token,
177199
"Content-Type": "application/json",
178200
}
179201

202+
r = UserInfo()
203+
204+
try:
205+
r.user_uuid = UUID(user_id)
206+
except ValueError:
207+
raise credentials_exception from None
208+
209+
r.credentials = token
210+
r.user_name = payload.get("preferred_username", "unknown")
211+
180212
# get user's project groups and roles from TDEI
181-
# TODO: fix if user has > 50 PGs
182-
authorizationUrl = (
183-
settings.TDEI_BACKEND_URL
184-
+ "/project-group-roles/"
185-
+ user_id
186-
+ "?page_no=1&page_size=50"
187-
)
213+
pgs = []
188214

189-
response = requests.get(authorizationUrl, headers=headers)
215+
try:
216+
response = await _tdei_client.get(
217+
f"project-group-roles/{user_id}",
218+
headers=headers,
219+
params={"page_no": 1, "page_size": 1000},
220+
)
221+
except httpx.RequestError:
222+
raise HTTPException(
223+
status_code=status.HTTP_502_BAD_GATEWAY,
224+
detail="Could not reach TDEI backend",
225+
) from None
190226

191227
# token is not valid or server unavailable
192228
if response.status_code != 200:
193229
raise credentials_exception
194230

195231
try:
196-
content = response.text
197-
j = json.loads(content)
198-
except json.JSONDecodeError:
232+
pg_data = response.json()
233+
except Exception:
199234
raise credentials_exception
200235

201-
r = UserInfo()
202-
r.credentials = token
203-
r.user_uuid = UUID(payload.get("sub", "unknown"))
204-
r.user_name = payload.get("preferred_username", "unknown")
205-
206-
# project groups and roles from TDEI KeyCloak
207-
pgs = []
208-
for i in j:
236+
for i in pg_data:
209237
pgs.append(
210238
UserInfoPGMembership(
211239
project_group_id=i["tdei_project_group_id"],
212240
project_group_name=i["project_group_name"],
213241
tdeiRoles=i["roles"],
214242
)
215243
)
244+
216245
r.projectGroups = pgs
217246

218247
# workspaces within our set of PGs from tasking manager DB
@@ -226,7 +255,7 @@ async def _validate_token_uncached(
226255
accessibleWorkspaces = list(result.mappings().all())
227256
r.accessibleWorkspaceIds = {}
228257
for i in accessibleWorkspaces:
229-
pgid = i["tdeiProjectGroupId"]
258+
pgid = str(i["tdeiProjectGroupId"]) # SQLAlchemy outputs UUID
230259
wsid = i["id"]
231260
if pgid not in r.accessibleWorkspaceIds:
232261
r.accessibleWorkspaceIds[pgid] = []

0 commit comments

Comments
 (0)