Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 14 additions & 10 deletions src/webapp/authn.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,19 +6,23 @@
import jwt
from fastapi import Security, HTTPException, status
from fastapi.security import (
OAuth2PasswordBearer,
# OAuth2PasswordBearer,
APIKeyHeader,
HTTPBearer,
)
from passlib.context import CryptContext
from pydantic import BaseModel
from .config import env_vars

from typing import Any

pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")

oauth2_apikey_scheme = OAuth2PasswordBearer(
scheme_name="api_key_scheme",
tokenUrl="token-from-api-key",
# oauth2_apikey_scheme = OAuth2PasswordBearer(
# scheme_name="api_key_scheme",
# tokenUrl="token-from-api-key",
# )
oauth2_apikey_scheme = HTTPBearer(
auto_error=True, scheme_name="Bearer token (get from /token-from-api-key)"
)

api_key_header = APIKeyHeader(name="X-API-KEY", scheme_name="api-key", auto_error=False)
Expand Down Expand Up @@ -57,24 +61,24 @@ def check_creds(username: str, password: str) -> bool:
return username == env_vars["USERNAME"] and password == env_vars["PASSWORD"]


def verify_password(plain_password: str, hashed_password: str) -> bool:
def verify_password(plain_password: str, hashed_password: str) -> Any:
"""Verify a plain password against a hash. Includes a 2y/2b replacement since Laravel
Generates hashes that start with 2y. The hashing scheme recognizes both."""
revert_hash = hashed_password.replace("$2y", "$2b", 1)
return pwd_context.verify(plain_password, revert_hash)


def verify_api_key(plain_api_key: str, hashed_key: str) -> bool:
def verify_api_key(plain_api_key: str, hashed_key: str) -> Any:
"""Verify a plain API Key against a hash."""
return pwd_context.verify(plain_api_key, hashed_key)


def get_api_key_hash(api_key: str):
def get_api_key_hash(api_key: str) -> Any:
"""Hash a given api key."""
return pwd_context.hash(api_key)


def get_password_hash(password: str):
def get_password_hash(password: str) -> Any:
"""Hash a password. To align with the password hashing used by Laravel, we have to replace the 2b
generated by pwd_context with 2y and that should be the version we store.
They should be functionally the same: https://stackoverflow.com/a/36225192/28478909
Expand All @@ -83,7 +87,7 @@ def get_password_hash(password: str):
return initial_hash.replace("$2b", "$2y", 1)


def create_access_token(data: dict, expires_delta: timedelta | None = None):
def create_access_token(data: dict, expires_delta: timedelta | None = None) -> str:
"""Create a JWT."""
to_encode = data.copy()
if expires_delta:
Expand Down
63 changes: 53 additions & 10 deletions src/webapp/routers/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -1028,12 +1028,12 @@ def get_upload_url(
def get_top_features(
inst_id: str,
run_id: str,
# current_user: Annotated[BaseUser, Depends(get_current_active_user)],
current_user: Annotated[BaseUser, Depends(get_current_active_user)],
sql_session: Annotated[Session, Depends(get_session)],
) -> List[dict[str, Any]]:
"""Returns a signed URL for uploading data to a specific institution."""
# raise error at this level instead bc otherwise it's getting wrapped as a 200
# has_access_to_inst_or_err(inst_id, current_user)
has_access_to_inst_or_err(inst_id, current_user)
local_session.set(sql_session)
query_result = (
local_session.get()
Expand Down Expand Up @@ -1072,12 +1072,12 @@ def get_top_features(
def get_support_overview(
inst_id: str,
run_id: str,
# current_user: Annotated[BaseUser, Depends(get_current_active_user)],
current_user: Annotated[BaseUser, Depends(get_current_active_user)],
sql_session: Annotated[Session, Depends(get_session)],
) -> List[dict[str, Any]]:
"""Returns a signed URL for uploading data to a specific institution."""
# raise error at this level instead bc otherwise it's getting wrapped as a 200
# has_access_to_inst_or_err(inst_id, current_user)
has_access_to_inst_or_err(inst_id, current_user)
local_session.set(sql_session)
query_result = (
local_session.get()
Expand Down Expand Up @@ -1111,16 +1111,59 @@ def get_support_overview(
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(ve))


@router.get("/{inst_id}/training/support-overview/{run_id}")
def get_training_support_overview(
inst_id: str,
run_id: str,
current_user: Annotated[BaseUser, Depends(get_current_active_user)],
sql_session: Annotated[Session, Depends(get_session)],
) -> List[dict[str, Any]]:
"""Returns a signed URL for uploading data to a specific institution."""
# raise error at this level instead bc otherwise it's getting wrapped as a 200
has_access_to_inst_or_err(inst_id, current_user)
local_session.set(sql_session)
query_result = (
local_session.get()
.execute(select(InstTable).where(InstTable.id == str_to_uuid(inst_id)))
.all()
)
if not query_result or len(query_result) == 0:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="Institution not found.",
)
if len(query_result) > 1:
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Institution duplicates found.",
)

try:
dbc = DatabricksControl()
rows = dbc.fetch_table_data(
catalog_name=env_vars["CATALOG_NAME"],
inst_name=f"{query_result[0][0].name}",
table_name=f"sample_training_{run_id}_support_overview",
warehouse_id=env_vars["SQL_WAREHOUSE_ID"],
limit=500,
)

return rows
except ValueError as ve:
# Return a 400 error with the specific message from ValueError
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(ve))


@router.get("/{inst_id}/inference/feature_value/{run_id}")
def get_feature_value(
inst_id: str,
run_id: str,
# current_user: Annotated[BaseUser, Depends(get_current_active_user)],
current_user: Annotated[BaseUser, Depends(get_current_active_user)],
sql_session: Annotated[Session, Depends(get_session)],
) -> List[dict[str, Any]]:
"""Returns a signed URL for uploading data to a specific institution."""
# raise error at this level instead bc otherwise it's getting wrapped as a 200
# has_access_to_inst_or_err(inst_id, current_user)
has_access_to_inst_or_err(inst_id, current_user)
local_session.set(sql_session)
query_result = (
local_session.get()
Expand Down Expand Up @@ -1158,12 +1201,12 @@ def get_feature_value(
def get_confusion_matrix(
inst_id: str,
run_id: str,
##current_user: Annotated[BaseUser, Depends(get_current_active_user)],
current_user: Annotated[BaseUser, Depends(get_current_active_user)],
sql_session: Annotated[Session, Depends(get_session)],
) -> List[dict[str, Any]]:
"""Returns a signed URL for uploading data to a specific institution."""
# raise error at this level instead bc otherwise it's getting wrapped as a 200
# has_access_to_inst_or_err(inst_id, current_user)
has_access_to_inst_or_err(inst_id, current_user)
local_session.set(sql_session)
query_result = (
local_session.get()
Expand Down Expand Up @@ -1201,12 +1244,12 @@ def get_confusion_matrix(
def get_roc_curve(
inst_id: str,
run_id: str,
# current_user: Annotated[BaseUser, Depends(get_current_active_user)],
current_user: Annotated[BaseUser, Depends(get_current_active_user)],
sql_session: Annotated[Session, Depends(get_session)],
) -> List[dict[str, Any]]:
"""Returns a signed URL for uploading data to a specific institution."""
# raise error at this level instead bc otherwise it's getting wrapped as a 200
# has_access_to_inst_or_err(inst_id, current_user)
has_access_to_inst_or_err(inst_id, current_user)
local_session.set(sql_session)
query_result = (
local_session.get()
Expand Down
22 changes: 13 additions & 9 deletions src/webapp/utilities.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import uuid
import re
from typing import Annotated, Final
from typing import Annotated, Final, Any
from urllib.parse import unquote
from strenum import StrEnum # needed for python pre 3.11
import jwt
Expand All @@ -12,6 +12,7 @@
from sqlalchemy.orm import Session
from sqlalchemy.future import select
from sqlalchemy import and_
from fastapi.security import HTTPAuthorizationCredentials

from .authn import (
verify_api_key,
Expand All @@ -21,7 +22,7 @@
from .config import env_vars


def decode_url_piece(src: str):
def decode_url_piece(src: str) -> str:
"""Decode encoded URL."""
return unquote(src)

Expand Down Expand Up @@ -164,29 +165,29 @@ class BaseUser(BaseModel):
def __init__(self, usr: str | None, inst: str, access: str, email: str) -> None:
super().__init__(user_id=usr, institution=inst, access_type=access, email=email)

def is_datakinder(self) -> bool:
def is_datakinder(self) -> Any:
"""Whether a given user is a Datakinder."""
return self.access_type and self.access_type == AccessType.DATAKINDER

def is_model_owner(self) -> bool:
def is_model_owner(self) -> Any:
"""Whether a given user is a model owner."""
return self.access_type and self.access_type == AccessType.MODEL_OWNER

def is_data_owner(self) -> bool:
def is_data_owner(self) -> Any:
"""Whether a given user is a data owner."""
return self.access_type and self.access_type == AccessType.DATA_OWNER

def is_viewer(self) -> bool:
def is_viewer(self) -> Any:
"""Whether a given user is a viewer."""
return self.access_type and self.access_type == AccessType.VIEWER

def has_access_to_inst(self, inst: str) -> bool:
def has_access_to_inst(self, inst: str) -> Any:
"""Whether a given user has access to a given institution."""
return self.access_type and (
self.access_type == AccessType.DATAKINDER or self.institution == inst
)

def has_full_data_access(self) -> bool:
def has_full_data_access(self) -> Any:
"""Datakinders, model_owners, data_owners, all have full data access."""
return self.access_type and self.access_type in (
AccessType.DATAKINDER,
Expand Down Expand Up @@ -312,7 +313,9 @@ def authenticate_api_key(api_key_enduser_tuple: str, sess: Session) -> BaseUser:

async def get_current_user(
sess: Annotated[Session, Depends(get_session)],
token_from_key: Annotated[str, Depends(oauth2_apikey_scheme)],
token_from_key: Annotated[
HTTPAuthorizationCredentials, Depends(oauth2_apikey_scheme)
],
) -> BaseUser:
"""Get the user from a given token."""
credentials_exception = HTTPException(
Expand All @@ -321,6 +324,7 @@ async def get_current_user(
headers={"WWW-Authenticate": "Bearer"},
)
usrname = None
token_from_key = token_from_key.credentials
try:
if not token_from_key:
raise credentials_exception
Expand Down
Loading