1+ import base64
12import requests
23import logging
34import json
78
89from fastapi import HTTPException , status
910from authlib .integrations .starlette_client import OAuth
10- from httpx import Timeout
11+ from httpx import AsyncClient , Timeout
1112
1213from . import settings
1314
3233)
3334
3435
36+ def _decode_jwt_payload (token_str : str ) -> dict :
37+ try :
38+ payload_b64 = token_str .split ("." )[1 ]
39+ payload_b64 += "=" * (4 - len (payload_b64 ) % 4 )
40+ return json .loads (base64 .urlsafe_b64decode (payload_b64 ))
41+ except Exception as err :
42+ raise HTTPException (
43+ status_code = status .HTTP_401_UNAUTHORIZED ,
44+ detail = f"Could not decode token: { err } "
45+ )
46+
47+
3548def get_kg_client_for_service_account ():
3649 global kg_client_for_service_account
3750 if kg_client_for_service_account is None :
@@ -71,41 +84,41 @@ def __init__(self, token, allow_anonymous=False):
7184 detail = "You need to provide a bearer token to access this resource"
7285 )
7386 self .token = token
74- self ._user_info = None
87+ self ._identity = None
88+ self ._teams = None
7589 self ._collab_info = {}
7690 self ._connection_error = False
7791
7892 @property
7993 def is_anonymous (self ):
8094 return self .token is None or self .token .credentials == "undefined"
8195
82- async def get_user_info (self ):
83- if self ._user_info is None :
84- user_info = await oauth .ebrains .userinfo (
85- token = {"access_token" : self .token .credentials , "token_type" : "bearer" }
86- )
87- if "error" in user_info :
88- raise HTTPException (
89- status_code = status .HTTP_401_UNAUTHORIZED , detail = user_info ["error_description" ]
90- )
91- elif user_info .get ("statusCode" , None ) == 401 :
92- raise HTTPException (
93- status_code = status .HTTP_401_UNAUTHORIZED , detail = user_info ["message" ]
94- )
95- elif user_info .get ("statusCode" , None ) == 500 :
96- raise HTTPException (
97- status_code = status .HTTP_500_INTERNAL_SERVER_ERROR ,
98- detail = f'Problem getting user_info: { user_info ["message" ]} '
99- )
100- logger .debug (user_info )
101- try :
102- # make this compatible with the v1 json
103- user_info ["id" ] = user_info ["sub" ]
104- user_info ["username" ] = user_info .get ("preferred_username" , "unknown" )
105- except KeyError :
106- raise Exception (user_info )
107- self ._user_info = user_info
108- return self ._user_info
96+ async def get_identity (self ):
97+ if self ._identity is None :
98+ payload = _decode_jwt_payload (self .token .credentials )
99+ username = payload .get ("preferred_username" , "unknown" )
100+ self ._identity = {
101+ "sub" : payload ["sub" ],
102+ "id" : payload ["sub" ],
103+ "preferred_username" : username ,
104+ "username" : username ,
105+ "given_name" : payload .get ("given_name" , "" ),
106+ "family_name" : payload .get ("family_name" , "" ),
107+ }
108+ return self ._identity
109+
110+ async def get_teams (self ):
111+ if self ._teams is None :
112+ identity = await self .get_identity ()
113+ url = f"{ settings .EBRAINS_IDM_API_URL } /teams"
114+ headers = {"Authorization" : f"Bearer { self .token .credentials } " }
115+ params = {"username" : identity ["username" ]}
116+ async with AsyncClient () as client :
117+ res = await client .get (url , headers = headers , params = params ,
118+ timeout = settings .AUTHENTICATION_TIMEOUT )
119+ res .raise_for_status ()
120+ self ._teams = [t ["name" ] for t in res .json () if isinstance (t , dict ) and "name" in t ]
121+ return self ._teams
109122
110123 async def get_collab_info (self , collab_id ):
111124 if collab_id not in self ._collab_info :
@@ -122,9 +135,9 @@ async def get_collab_info(self, collab_id):
122135 return self ._collab_info [collab_id ]
123136
124137 async def get_person (self , kg_client ):
125- user_info = await self .get_user_info ()
126- family_name = user_info ["family_name" ]
127- given_name = user_info ["given_name" ]
138+ identity = await self .get_identity ()
139+ family_name = identity ["family_name" ]
140+ given_name = identity ["given_name" ]
128141 person = omcore .Person .list (kg_client , family_name = family_name , given_name = given_name , scope = "any" )
129142 if person :
130143 if isinstance (person , list ):
@@ -139,14 +152,14 @@ async def get_person(self, kg_client):
139152 return None
140153
141154 async def get_collab_permissions (self , collab_id ):
142- user_info = await self .get_user_info ()
155+ teams = await self .get_teams ()
143156
144157 target_team_names = {role : f"collab-{ collab_id } -{ role } "
145158 for role in ("viewer" , "editor" , "administrator" )}
146159
147160 highest_collab_role = None
148161 for role , team_name in target_team_names .items ():
149- if team_name in user_info [ "roles" ][ "team" ] :
162+ if team_name in teams :
150163 highest_collab_role = role
151164 if highest_collab_role == "viewer" :
152165 permissions = {"VIEW" : True , "UPDATE" : False }
@@ -184,9 +197,9 @@ async def is_admin(self):
184197 # todo: replace this check with a group membership check
185198
186199 async def get_editable_collabs (self ):
187- user_info = await self .get_user_info ()
200+ teams = await self .get_teams ()
188201 editable_collab_ids = set ()
189- for team_name in user_info [ "roles" ][ "team" ] :
202+ for team_name in teams :
190203 if team_name .endswith ("-editor" ) or team_name .endswith ("-administrator" ):
191204 collab_id = "-" .join (team_name .split ("-" )[1 :- 1 ])
192205 editable_collab_ids .add (collab_id )
0 commit comments