Skip to content

Commit 35ac68c

Browse files
authored
Merge pull request #343 from apdavison/refactor-userinfo
Refactor `get_user_info()` into two methods: `get_identity()` and `get_teams()`
2 parents cc450d7 + 06e4685 commit 35ac68c

6 files changed

Lines changed: 59 additions & 46 deletions

File tree

validation_service_api/validation_service/auth.py

Lines changed: 49 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import base64
12
import requests
23
import logging
34
import json
@@ -7,7 +8,7 @@
78

89
from fastapi import HTTPException, status
910
from authlib.integrations.starlette_client import OAuth
10-
from httpx import Timeout
11+
from httpx import AsyncClient, Timeout
1112

1213
from . import settings
1314

@@ -32,6 +33,18 @@
3233
)
3334

3435

36+
def _decode_jwt_payload(token_str: str) -> dict:
37+
try:
38+
payload_b64 = token_str.split(".")[1]
39+
payload_b64 += "=" * (4 - len(payload_b64) % 4)
40+
return json.loads(base64.urlsafe_b64decode(payload_b64))
41+
except Exception as err:
42+
raise HTTPException(
43+
status_code=status.HTTP_401_UNAUTHORIZED,
44+
detail=f"Could not decode token: {err}"
45+
)
46+
47+
3548
def get_kg_client_for_service_account():
3649
global kg_client_for_service_account
3750
if kg_client_for_service_account is None:
@@ -71,41 +84,41 @@ def __init__(self, token, allow_anonymous=False):
7184
detail="You need to provide a bearer token to access this resource"
7285
)
7386
self.token = token
74-
self._user_info = None
87+
self._identity = None
88+
self._teams = None
7589
self._collab_info = {}
7690
self._connection_error = False
7791

7892
@property
7993
def is_anonymous(self):
8094
return self.token is None or self.token.credentials == "undefined"
8195

82-
async def get_user_info(self):
83-
if self._user_info is None:
84-
user_info = await oauth.ebrains.userinfo(
85-
token={"access_token": self.token.credentials, "token_type": "bearer"}
86-
)
87-
if "error" in user_info:
88-
raise HTTPException(
89-
status_code=status.HTTP_401_UNAUTHORIZED, detail=user_info["error_description"]
90-
)
91-
elif user_info.get("statusCode", None) == 401:
92-
raise HTTPException(
93-
status_code=status.HTTP_401_UNAUTHORIZED, detail=user_info["message"]
94-
)
95-
elif user_info.get("statusCode", None) == 500:
96-
raise HTTPException(
97-
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
98-
detail=f'Problem getting user_info: {user_info["message"]}'
99-
)
100-
logger.debug(user_info)
101-
try:
102-
# make this compatible with the v1 json
103-
user_info["id"] = user_info["sub"]
104-
user_info["username"] = user_info.get("preferred_username", "unknown")
105-
except KeyError:
106-
raise Exception(user_info)
107-
self._user_info = user_info
108-
return self._user_info
96+
async def get_identity(self):
97+
if self._identity is None:
98+
payload = _decode_jwt_payload(self.token.credentials)
99+
username = payload.get("preferred_username", "unknown")
100+
self._identity = {
101+
"sub": payload["sub"],
102+
"id": payload["sub"],
103+
"preferred_username": username,
104+
"username": username,
105+
"given_name": payload.get("given_name", ""),
106+
"family_name": payload.get("family_name", ""),
107+
}
108+
return self._identity
109+
110+
async def get_teams(self):
111+
if self._teams is None:
112+
identity = await self.get_identity()
113+
url = f"{settings.EBRAINS_IDM_API_URL}/teams"
114+
headers = {"Authorization": f"Bearer {self.token.credentials}"}
115+
params = {"username": identity["username"]}
116+
async with AsyncClient() as client:
117+
res = await client.get(url, headers=headers, params=params,
118+
timeout=settings.AUTHENTICATION_TIMEOUT)
119+
res.raise_for_status()
120+
self._teams = [t["name"] for t in res.json() if isinstance(t, dict) and "name" in t]
121+
return self._teams
109122

110123
async def get_collab_info(self, collab_id):
111124
if collab_id not in self._collab_info:
@@ -122,9 +135,9 @@ async def get_collab_info(self, collab_id):
122135
return self._collab_info[collab_id]
123136

124137
async def get_person(self, kg_client):
125-
user_info = await self.get_user_info()
126-
family_name = user_info["family_name"]
127-
given_name = user_info["given_name"]
138+
identity = await self.get_identity()
139+
family_name = identity["family_name"]
140+
given_name = identity["given_name"]
128141
person = omcore.Person.list(kg_client, family_name=family_name, given_name=given_name, scope="any")
129142
if person:
130143
if isinstance(person, list):
@@ -139,14 +152,14 @@ async def get_person(self, kg_client):
139152
return None
140153

141154
async def get_collab_permissions(self, collab_id):
142-
user_info = await self.get_user_info()
155+
teams = await self.get_teams()
143156

144157
target_team_names = {role: f"collab-{collab_id}-{role}"
145158
for role in ("viewer", "editor", "administrator")}
146159

147160
highest_collab_role = None
148161
for role, team_name in target_team_names.items():
149-
if team_name in user_info["roles"]["team"]:
162+
if team_name in teams:
150163
highest_collab_role = role
151164
if highest_collab_role == "viewer":
152165
permissions = {"VIEW": True, "UPDATE": False}
@@ -184,9 +197,9 @@ async def is_admin(self):
184197
# todo: replace this check with a group membership check
185198

186199
async def get_editable_collabs(self):
187-
user_info = await self.get_user_info()
200+
teams = await self.get_teams()
188201
editable_collab_ids = set()
189-
for team_name in user_info["roles"]["team"]:
202+
for team_name in teams:
190203
if team_name.endswith("-editor") or team_name.endswith("-administrator"):
191204
collab_id = "-".join(team_name.split("-")[1:-1])
192205
editable_collab_ids.add(collab_id)

validation_service_api/validation_service/resources/auth.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ async def list_projects(
5959
return []
6060
else:
6161
try:
62-
user_info = await user.get_user_info()
62+
teams = await user.get_teams()
6363
except HTTPStatusError as err:
6464
if "401" in str(err):
6565
raise HTTPException(
@@ -68,9 +68,8 @@ async def list_projects(
6868
)
6969
else:
7070
raise
71-
roles = user_info.get("roles", {}).get("team", [])
7271
projects = {}
73-
for role in roles:
72+
for role in teams:
7473
if role.startswith("collab-"):
7574
project_id = "-".join(role.split("-")[1:-1])
7675
if project_id not in projects:

validation_service_api/validation_service/resources/models.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -52,8 +52,8 @@ async def api_status(token: HTTPAuthorizationCredentials = Depends(auth)):
5252
}
5353
}
5454
if token:
55-
user_info = await User(token).get_user_info()
56-
info["user"] = user_info["preferred_username"]
55+
identity = await User(token).get_identity()
56+
info["user"] = identity["preferred_username"]
5757
service_status = getattr(settings, "SERVICE_STATUS", "ok")
5858
return info
5959

validation_service_api/validation_service/resources/tests.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -263,10 +263,10 @@ async def create_test(test: ValidationTest, token: HTTPAuthorizationCredentials
263263
try:
264264
test_definition.save(kg_user_client, recursive=True, space=kg_space, ignore_duplicates=True)
265265
except AuthenticationError as err:
266-
user_info = await user.get_user_info()
266+
identity = await user.get_identity()
267267
raise HTTPException(
268268
status_code=status.HTTP_403_FORBIDDEN,
269-
detail=f"User {user_info['username']} cannot access space {kg_space}. Error message: {err}"
269+
detail=f"User {identity['username']} cannot access space {kg_space}. Error message: {err}"
270270
)
271271
return ValidationTest.from_kg_object(test_definition, kg_user_client)
272272

validation_service_api/validation_service/settings.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
SERVICE_STATUS = os.environ.get("VF_SERVICE_STATUS", "ok")
1818
# e.g. SERVICE_STATUS = "The site is undergoing maintenance, and is currently in read-only mode."
1919
AUTHENTICATION_TIMEOUT = 20
20+
EBRAINS_IDM_API_URL = os.environ.get("EBRAINS_IDM_API_URL", "https://idm.ebrains.eu")
2021

2122
this_dir = os.path.dirname(__file__)
2223
build_info_path = os.path.join(this_dir, "build_info.json")

validation_service_api/validation_service/tests/test_user.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -25,10 +25,10 @@ async def test_user__is_admin():
2525

2626

2727
@pytest.mark.asyncio
28-
async def test_user_info():
28+
async def test_user_teams():
2929
user = User(token, allow_anonymous=False)
30-
user_info = await user.get_user_info()
31-
assert "collab-model-validation-administrator" not in user_info["roles"]["team"]
30+
teams = await user.get_teams()
31+
assert "collab-model-validation-administrator" not in teams
3232

3333

3434
@pytest.mark.asyncio

0 commit comments

Comments
 (0)