Skip to content

Commit e32ec02

Browse files
authored
Merge pull request #108 from datakind/develop
feat: adjusted UI api auth
2 parents 3f40549 + 22f7fd1 commit e32ec02

3 files changed

Lines changed: 80 additions & 29 deletions

File tree

src/webapp/authn.py

Lines changed: 14 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -6,19 +6,23 @@
66
import jwt
77
from fastapi import Security, HTTPException, status
88
from fastapi.security import (
9-
OAuth2PasswordBearer,
9+
# OAuth2PasswordBearer,
1010
APIKeyHeader,
11+
HTTPBearer,
1112
)
1213
from passlib.context import CryptContext
1314
from pydantic import BaseModel
1415
from .config import env_vars
15-
16+
from typing import Any
1617

1718
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
1819

19-
oauth2_apikey_scheme = OAuth2PasswordBearer(
20-
scheme_name="api_key_scheme",
21-
tokenUrl="token-from-api-key",
20+
# oauth2_apikey_scheme = OAuth2PasswordBearer(
21+
# scheme_name="api_key_scheme",
22+
# tokenUrl="token-from-api-key",
23+
# )
24+
oauth2_apikey_scheme = HTTPBearer(
25+
auto_error=True, scheme_name="Bearer token (get from /token-from-api-key)"
2226
)
2327

2428
api_key_header = APIKeyHeader(name="X-API-KEY", scheme_name="api-key", auto_error=False)
@@ -57,24 +61,24 @@ def check_creds(username: str, password: str) -> bool:
5761
return username == env_vars["USERNAME"] and password == env_vars["PASSWORD"]
5862

5963

60-
def verify_password(plain_password: str, hashed_password: str) -> bool:
64+
def verify_password(plain_password: str, hashed_password: str) -> Any:
6165
"""Verify a plain password against a hash. Includes a 2y/2b replacement since Laravel
6266
generates hashes that start with 2y. The hashing scheme recognizes both."""
6367
revert_hash = hashed_password.replace("$2y", "$2b", 1)
6468
return pwd_context.verify(plain_password, revert_hash)
6569

6670

67-
def verify_api_key(plain_api_key: str, hashed_key: str) -> bool:
71+
def verify_api_key(plain_api_key: str, hashed_key: str) -> Any:
6872
"""Verify a plain API Key against a hash."""
6973
return pwd_context.verify(plain_api_key, hashed_key)
7074

7175

72-
def get_api_key_hash(api_key: str):
76+
def get_api_key_hash(api_key: str) -> Any:
7377
"""Hash a given api key."""
7478
return pwd_context.hash(api_key)
7579

7680

77-
def get_password_hash(password: str):
81+
def get_password_hash(password: str) -> Any:
7882
"""Hash a password. To align with the password hashing used by Laravel, we have to replace the 2b
7983
generated by pwd_context with 2y and that should be the version we store.
8084
They should be functionally the same: https://stackoverflow.com/a/36225192/28478909
@@ -83,7 +87,7 @@ def get_password_hash(password: str):
8387
return initial_hash.replace("$2b", "$2y", 1)
8488

8589

86-
def create_access_token(data: dict, expires_delta: timedelta | None = None):
90+
def create_access_token(data: dict, expires_delta: timedelta | None = None) -> str:
8791
"""Create a JWT."""
8892
to_encode = data.copy()
8993
if expires_delta:

src/webapp/routers/data.py

Lines changed: 53 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1028,12 +1028,12 @@ def get_upload_url(
10281028
def get_top_features(
10291029
inst_id: str,
10301030
run_id: str,
1031-
# current_user: Annotated[BaseUser, Depends(get_current_active_user)],
1031+
current_user: Annotated[BaseUser, Depends(get_current_active_user)],
10321032
sql_session: Annotated[Session, Depends(get_session)],
10331033
) -> List[dict[str, Any]]:
10341034
"""Returns a signed URL for uploading data to a specific institution."""
10351035
# raise error at this level instead bc otherwise it's getting wrapped as a 200
1036-
# has_access_to_inst_or_err(inst_id, current_user)
1036+
has_access_to_inst_or_err(inst_id, current_user)
10371037
local_session.set(sql_session)
10381038
query_result = (
10391039
local_session.get()
@@ -1072,12 +1072,12 @@ def get_top_features(
10721072
def get_support_overview(
10731073
inst_id: str,
10741074
run_id: str,
1075-
# current_user: Annotated[BaseUser, Depends(get_current_active_user)],
1075+
current_user: Annotated[BaseUser, Depends(get_current_active_user)],
10761076
sql_session: Annotated[Session, Depends(get_session)],
10771077
) -> List[dict[str, Any]]:
10781078
"""Returns a signed URL for uploading data to a specific institution."""
10791079
# raise error at this level instead bc otherwise it's getting wrapped as a 200
1080-
# has_access_to_inst_or_err(inst_id, current_user)
1080+
has_access_to_inst_or_err(inst_id, current_user)
10811081
local_session.set(sql_session)
10821082
query_result = (
10831083
local_session.get()
@@ -1111,16 +1111,59 @@ def get_support_overview(
11111111
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(ve))
11121112

11131113

1114+
@router.get("/{inst_id}/training/support-overview/{run_id}")
1115+
def get_training_support_overview(
1116+
inst_id: str,
1117+
run_id: str,
1118+
current_user: Annotated[BaseUser, Depends(get_current_active_user)],
1119+
sql_session: Annotated[Session, Depends(get_session)],
1120+
) -> List[dict[str, Any]]:
1121+
"""Returns rows from the `sample_training_{run_id}_support_overview` table for a specific institution (fetched with limit=500)."""
1122+
# raise error at this level instead bc otherwise it's getting wrapped as a 200
1123+
has_access_to_inst_or_err(inst_id, current_user)
1124+
local_session.set(sql_session)
1125+
query_result = (
1126+
local_session.get()
1127+
.execute(select(InstTable).where(InstTable.id == str_to_uuid(inst_id)))
1128+
.all()
1129+
)
1130+
if not query_result or len(query_result) == 0:
1131+
raise HTTPException(
1132+
status_code=status.HTTP_404_NOT_FOUND,
1133+
detail="Institution not found.",
1134+
)
1135+
if len(query_result) > 1:
1136+
raise HTTPException(
1137+
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
1138+
detail="Institution duplicates found.",
1139+
)
1140+
1141+
try:
1142+
dbc = DatabricksControl()
1143+
rows = dbc.fetch_table_data(
1144+
catalog_name=env_vars["CATALOG_NAME"],
1145+
inst_name=f"{query_result[0][0].name}",
1146+
table_name=f"sample_training_{run_id}_support_overview",
1147+
warehouse_id=env_vars["SQL_WAREHOUSE_ID"],
1148+
limit=500,
1149+
)
1150+
1151+
return rows
1152+
except ValueError as ve:
1153+
# Return a 400 error with the specific message from ValueError
1154+
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(ve))
1155+
1156+
11141157
@router.get("/{inst_id}/inference/feature_value/{run_id}")
11151158
def get_feature_value(
11161159
inst_id: str,
11171160
run_id: str,
1118-
# current_user: Annotated[BaseUser, Depends(get_current_active_user)],
1161+
current_user: Annotated[BaseUser, Depends(get_current_active_user)],
11191162
sql_session: Annotated[Session, Depends(get_session)],
11201163
) -> List[dict[str, Any]]:
11211164
"""Returns a signed URL for uploading data to a specific institution."""
11221165
# raise error at this level instead bc otherwise it's getting wrapped as a 200
1123-
# has_access_to_inst_or_err(inst_id, current_user)
1166+
has_access_to_inst_or_err(inst_id, current_user)
11241167
local_session.set(sql_session)
11251168
query_result = (
11261169
local_session.get()
@@ -1158,12 +1201,12 @@ def get_feature_value(
11581201
def get_confusion_matrix(
11591202
inst_id: str,
11601203
run_id: str,
1161-
##current_user: Annotated[BaseUser, Depends(get_current_active_user)],
1204+
current_user: Annotated[BaseUser, Depends(get_current_active_user)],
11621205
sql_session: Annotated[Session, Depends(get_session)],
11631206
) -> List[dict[str, Any]]:
11641207
"""Returns a signed URL for uploading data to a specific institution."""
11651208
# raise error at this level instead bc otherwise it's getting wrapped as a 200
1166-
# has_access_to_inst_or_err(inst_id, current_user)
1209+
has_access_to_inst_or_err(inst_id, current_user)
11671210
local_session.set(sql_session)
11681211
query_result = (
11691212
local_session.get()
@@ -1201,12 +1244,12 @@ def get_confusion_matrix(
12011244
def get_roc_curve(
12021245
inst_id: str,
12031246
run_id: str,
1204-
# current_user: Annotated[BaseUser, Depends(get_current_active_user)],
1247+
current_user: Annotated[BaseUser, Depends(get_current_active_user)],
12051248
sql_session: Annotated[Session, Depends(get_session)],
12061249
) -> List[dict[str, Any]]:
12071250
"""Returns a signed URL for uploading data to a specific institution."""
12081251
# raise error at this level instead bc otherwise it's getting wrapped as a 200
1209-
# has_access_to_inst_or_err(inst_id, current_user)
1252+
has_access_to_inst_or_err(inst_id, current_user)
12101253
local_session.set(sql_session)
12111254
query_result = (
12121255
local_session.get()

src/webapp/utilities.py

Lines changed: 13 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
import uuid
44
import re
5-
from typing import Annotated, Final
5+
from typing import Annotated, Final, Any
66
from urllib.parse import unquote
77
from strenum import StrEnum # needed for python pre 3.11
88
import jwt
@@ -12,6 +12,7 @@
1212
from sqlalchemy.orm import Session
1313
from sqlalchemy.future import select
1414
from sqlalchemy import and_
15+
from fastapi.security import HTTPAuthorizationCredentials
1516

1617
from .authn import (
1718
verify_api_key,
@@ -21,7 +22,7 @@
2122
from .config import env_vars
2223

2324

24-
def decode_url_piece(src: str):
25+
def decode_url_piece(src: str) -> str:
2526
"""Decode encoded URL."""
2627
return unquote(src)
2728

@@ -164,29 +165,29 @@ class BaseUser(BaseModel):
164165
def __init__(self, usr: str | None, inst: str, access: str, email: str) -> None:
165166
super().__init__(user_id=usr, institution=inst, access_type=access, email=email)
166167

167-
def is_datakinder(self) -> bool:
168+
def is_datakinder(self) -> Any:
168169
"""Whether a given user is a Datakinder."""
169170
return self.access_type and self.access_type == AccessType.DATAKINDER
170171

171-
def is_model_owner(self) -> bool:
172+
def is_model_owner(self) -> Any:
172173
"""Whether a given user is a model owner."""
173174
return self.access_type and self.access_type == AccessType.MODEL_OWNER
174175

175-
def is_data_owner(self) -> bool:
176+
def is_data_owner(self) -> Any:
176177
"""Whether a given user is a data owner."""
177178
return self.access_type and self.access_type == AccessType.DATA_OWNER
178179

179-
def is_viewer(self) -> bool:
180+
def is_viewer(self) -> Any:
180181
"""Whether a given user is a viewer."""
181182
return self.access_type and self.access_type == AccessType.VIEWER
182183

183-
def has_access_to_inst(self, inst: str) -> bool:
184+
def has_access_to_inst(self, inst: str) -> Any:
184185
"""Whether a given user has access to a given institution."""
185186
return self.access_type and (
186187
self.access_type == AccessType.DATAKINDER or self.institution == inst
187188
)
188189

189-
def has_full_data_access(self) -> bool:
190+
def has_full_data_access(self) -> Any:
190191
"""Datakinders, model_owners, data_owners, all have full data access."""
191192
return self.access_type and self.access_type in (
192193
AccessType.DATAKINDER,
@@ -312,7 +313,9 @@ def authenticate_api_key(api_key_enduser_tuple: str, sess: Session) -> BaseUser:
312313

313314
async def get_current_user(
314315
sess: Annotated[Session, Depends(get_session)],
315-
token_from_key: Annotated[str, Depends(oauth2_apikey_scheme)],
316+
token_from_key: Annotated[
317+
HTTPAuthorizationCredentials, Depends(oauth2_apikey_scheme)
318+
],
316319
) -> BaseUser:
317320
"""Get the user from a given token."""
318321
credentials_exception = HTTPException(
@@ -321,6 +324,7 @@ async def get_current_user(
321324
headers={"WWW-Authenticate": "Bearer"},
322325
)
323326
usrname = None
327+
token_from_key = token_from_key.credentials
324328
try:
325329
if not token_from_key:
326330
raise credentials_exception

0 commit comments

Comments
 (0)