1- import json
21from enum import StrEnum
32from uuid import UUID
43
54import cachetools
6- import jwt
7- import requests
5+ import httpx
86from fastapi import Depends , HTTPException , status
97from fastapi .security import HTTPAuthorizationCredentials , HTTPBearer
108from sqlalchemy import text
119from sqlmodel .ext .asyncio .session import AsyncSession
1210
1311from api .core .config import settings
1412from api .core .database import get_osm_session , get_task_session
13+ from api .core .jwt import validate_and_decode_token
1514from api .core .logging import get_logger
1615from api .src .workspaces .schemas import WorkspaceUserRoleType
1716
2322 maxsize = 1000 , ttl = 60 * 60
2423)
2524
25+ # Shared HTTP client for TDEI backend calls. Initialized by main.py lifespan.
26+ _tdei_client : httpx .AsyncClient | None = None
27+
28+
29+ def init_tdei_client () -> None :
30+ global _tdei_client
31+ _tdei_client = httpx .AsyncClient (
32+ base_url = settings .TDEI_BACKEND_URL ,
33+ timeout = httpx .Timeout (connect = 10 , read = 30 , write = 30 , pool = 10 ),
34+ )
35+
36+
37+ async def close_tdei_client () -> None :
38+ global _tdei_client
39+ if _tdei_client is not None :
40+ await _tdei_client .aclose ()
41+ _tdei_client = None
42+
43+
2644security = HTTPBearer ()
2745
2846
@@ -84,7 +102,9 @@ def isWorkspaceLead(self, workspaceId: int) -> bool:
84102
85103 for pg in self .projectGroups :
86104 if TdeiProjectGroupRole .POINT_OF_CONTACT in pg .tdeiRoles :
87- if workspaceId in self .accessibleWorkspaceIds [pg .project_group_id ]:
105+ if workspaceId in self .accessibleWorkspaceIds .get (
106+ pg .project_group_id , []
107+ ):
88108 return True
89109
90110 return False
@@ -118,6 +138,7 @@ def get_task_db_session(
118138) -> AsyncSession :
119139 return session
120140
141+
121142async def validate_token (
122143 credentials : HTTPAuthorizationCredentials = Depends (security ),
123144 osm_db_session : AsyncSession = Depends (get_osm_db_session ),
@@ -129,19 +150,39 @@ async def validate_token(
129150 """
130151 token = credentials .credentials
131152
153+ credentials_exception = HTTPException (
154+ status_code = status .HTTP_401_UNAUTHORIZED ,
155+ detail = "Invalid authentication credentials" ,
156+ headers = {"WWW-Authenticate" : "Bearer" },
157+ )
158+
159+ try :
160+ payload = validate_and_decode_token (token )
161+ except Exception :
162+ raise credentials_exception
163+
164+ user_id : str | None = payload .get ("sub" )
165+ if user_id is None :
166+ raise credentials_exception
167+
132168 # Check cache first
133169 if token in _token_cache :
134170 logger .info ("Token validation cache hit" )
135171 return _token_cache [token ]
136172
137173 # Cache miss - perform full validation
138- user_info = await _validate_token_uncached (token , osm_db_session , task_db_session )
174+ user_info = await _validate_token_uncached (
175+ token , user_id , payload , osm_db_session , task_db_session
176+ )
139177 _token_cache [token ] = user_info
178+
140179 return user_info
141180
142181
143182async def _validate_token_uncached (
144183 token : str ,
184+ user_id : str ,
185+ payload : dict ,
145186 osm_db_session : AsyncSession ,
146187 task_db_session : AsyncSession ,
147188) -> UserInfo :
@@ -153,66 +194,54 @@ async def _validate_token_uncached(
153194 headers = {"WWW-Authenticate" : "Bearer" },
154195 )
155196
156- jwks_client = jwt .PyJWKClient (
157- f"{ settings .TDEI_OIDC_URL } realms/{ settings .TDEI_OIDC_REALM } /protocol/openid-connect/certs"
158- )
159-
160- signing_key = jwks_client .get_signing_key_from_jwt (token )
161-
162- jwtDecoded = jwt .decode_complete (
163- token ,
164- key = signing_key .key ,
165- algorithms = ["RS256" ],
166- # OIDC server does not currently differentiate tokens by audience
167- options = {"verify_aud" : False }
168- )
169- payload = jwtDecoded .get ("payload" , {})
170-
171- user_id : str | None = payload .get ("sub" )
172- if user_id is None :
173- raise credentials_exception
174-
175197 headers = {
176198 "Authorization" : "Bearer " + token ,
177199 "Content-Type" : "application/json" ,
178200 }
179201
202+ r = UserInfo ()
203+
204+ try :
205+ r .user_uuid = UUID (user_id )
206+ except ValueError :
207+ raise credentials_exception from None
208+
209+ r .credentials = token
210+ r .user_name = payload .get ("preferred_username" , "unknown" )
211+
180212 # get user's project groups and roles from TDEI
181- # TODO: fix if user has > 50 PGs
182- authorizationUrl = (
183- settings .TDEI_BACKEND_URL
184- + "/project-group-roles/"
185- + user_id
186- + "?page_no=1&page_size=50"
187- )
213+ pgs = []
188214
189- response = requests .get (authorizationUrl , headers = headers )
215+ try :
216+ response = await _tdei_client .get (
217+ f"project-group-roles/{ user_id } " ,
218+ headers = headers ,
219+ params = {"page_no" : 1 , "page_size" : 1000 },
220+ )
221+ except httpx .RequestError :
222+ raise HTTPException (
223+ status_code = status .HTTP_502_BAD_GATEWAY ,
224+ detail = "Could not reach TDEI backend" ,
225+ ) from None
190226
191227 # token is not valid or server unavailable
192228 if response .status_code != 200 :
193229 raise credentials_exception
194230
195231 try :
196- content = response .text
197- j = json .loads (content )
198- except json .JSONDecodeError :
232+ pg_data = response .json ()
233+ except Exception :
199234 raise credentials_exception
200235
201- r = UserInfo ()
202- r .credentials = token
203- r .user_uuid = UUID (payload .get ("sub" , "unknown" ))
204- r .user_name = payload .get ("preferred_username" , "unknown" )
205-
206- # project groups and roles from TDEI KeyCloak
207- pgs = []
208- for i in j :
236+ for i in pg_data :
209237 pgs .append (
210238 UserInfoPGMembership (
211239 project_group_id = i ["tdei_project_group_id" ],
212240 project_group_name = i ["project_group_name" ],
213241 tdeiRoles = i ["roles" ],
214242 )
215243 )
244+
216245 r .projectGroups = pgs
217246
218247 # workspaces within our set of PGs from tasking manager DB
@@ -226,7 +255,7 @@ async def _validate_token_uncached(
226255 accessibleWorkspaces = list (result .mappings ().all ())
227256 r .accessibleWorkspaceIds = {}
228257 for i in accessibleWorkspaces :
229- pgid = i ["tdeiProjectGroupId" ]
258+ pgid = str ( i ["tdeiProjectGroupId" ]) # SQLAlchemy outputs UUID
230259 wsid = i ["id" ]
231260 if pgid not in r .accessibleWorkspaceIds :
232261 r .accessibleWorkspaceIds [pgid ] = []
0 commit comments