From 7f40dc45043d919077afd90b0d2bb728ba8aac64 Mon Sep 17 00:00:00 2001 From: Mesh Date: Thu, 5 Jun 2025 15:24:01 -0500 Subject: [PATCH 1/7] feat: added option training support dist --- src/webapp/routers/data.py | 62 ++++++++++++++++++++++++++++++++------ 1 file changed, 52 insertions(+), 10 deletions(-) diff --git a/src/webapp/routers/data.py b/src/webapp/routers/data.py index 0094aebb..bb37d5ab 100644 --- a/src/webapp/routers/data.py +++ b/src/webapp/routers/data.py @@ -1028,12 +1028,12 @@ def get_upload_url( def get_top_features( inst_id: str, run_id: str, - # current_user: Annotated[BaseUser, Depends(get_current_active_user)], + current_user: Annotated[BaseUser, Depends(get_current_active_user)], sql_session: Annotated[Session, Depends(get_session)], ) -> List[dict[str, Any]]: """Returns a signed URL for uploading data to a specific institution.""" # raise error at this level instead bc otherwise it's getting wrapped as a 200 - # has_access_to_inst_or_err(inst_id, current_user) + has_access_to_inst_or_err(inst_id, current_user) local_session.set(sql_session) query_result = ( local_session.get() @@ -1072,12 +1072,12 @@ def get_top_features( def get_support_overview( inst_id: str, run_id: str, - # current_user: Annotated[BaseUser, Depends(get_current_active_user)], + current_user: Annotated[BaseUser, Depends(get_current_active_user)], sql_session: Annotated[Session, Depends(get_session)], ) -> List[dict[str, Any]]: """Returns a signed URL for uploading data to a specific institution.""" # raise error at this level instead bc otherwise it's getting wrapped as a 200 - # has_access_to_inst_or_err(inst_id, current_user) + has_access_to_inst_or_err(inst_id, current_user) local_session.set(sql_session) query_result = ( local_session.get() @@ -1110,17 +1110,59 @@ def get_support_overview( # Return a 400 error with the specific message from ValueError raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(ve)) +@router.get("/{inst_id}/training/support-overview/{run_id}") +def get_training_support_overview( + inst_id: str, + run_id: str, + current_user: Annotated[BaseUser, Depends(get_current_active_user)], + sql_session: Annotated[Session, Depends(get_session)], +) -> List[dict[str, Any]]: + """Returns a signed URL for uploading data to a specific institution.""" + # raise error at this level instead bc otherwise it's getting wrapped as a 200 + has_access_to_inst_or_err(inst_id, current_user) + local_session.set(sql_session) + query_result = ( + local_session.get() + .execute(select(InstTable).where(InstTable.id == str_to_uuid(inst_id))) + .all() + ) + if not query_result or len(query_result) == 0: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail="Institution not found.", + ) + if len(query_result) > 1: + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail="Institution duplicates found.", + ) + + try: + dbc = DatabricksControl() + rows = dbc.fetch_table_data( + catalog_name=env_vars["CATALOG_NAME"], + inst_name=f"{query_result[0][0].name}", + table_name=f"sample_training_{run_id}_support_overview", + warehouse_id=env_vars["SQL_WAREHOUSE_ID"], + limit=500, + ) + + return rows + except ValueError as ve: + # Return a 400 error with the specific message from ValueError + raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(ve)) + @router.get("/{inst_id}/inference/feature_value/{run_id}") def get_feature_value( inst_id: str, run_id: str, - # current_user: Annotated[BaseUser, Depends(get_current_active_user)], + current_user: Annotated[BaseUser, Depends(get_current_active_user)], sql_session: Annotated[Session, Depends(get_session)], ) -> List[dict[str, Any]]: """Returns a signed URL for uploading data to a specific institution.""" # raise error at this level instead bc otherwise it's getting wrapped as a 200 - # has_access_to_inst_or_err(inst_id, current_user) + has_access_to_inst_or_err(inst_id, current_user) local_session.set(sql_session) query_result = ( local_session.get() @@ -1158,12 +1200,12 @@ def get_feature_value( def get_confusion_matrix( inst_id: str, run_id: str, - ##current_user: Annotated[BaseUser, Depends(get_current_active_user)], + current_user: Annotated[BaseUser, Depends(get_current_active_user)], sql_session: Annotated[Session, Depends(get_session)], ) -> List[dict[str, Any]]: """Returns a signed URL for uploading data to a specific institution.""" # raise error at this level instead bc otherwise it's getting wrapped as a 200 - # has_access_to_inst_or_err(inst_id, current_user) + has_access_to_inst_or_err(inst_id, current_user) local_session.set(sql_session) query_result = ( local_session.get() @@ -1201,12 +1243,12 @@ def get_confusion_matrix( def get_roc_curve( inst_id: str, run_id: str, - # current_user: Annotated[BaseUser, Depends(get_current_active_user)], + current_user: Annotated[BaseUser, Depends(get_current_active_user)], sql_session: Annotated[Session, Depends(get_session)], ) -> List[dict[str, Any]]: """Returns a signed URL for uploading data to a specific institution.""" # raise error at this level instead bc otherwise it's getting wrapped as a 200 - # has_access_to_inst_or_err(inst_id, current_user) + has_access_to_inst_or_err(inst_id, current_user) local_session.set(sql_session) query_result = ( local_session.get() From 35c81548e807a5b4a8cd449f12ac4bfc08268e0b Mon Sep 17 00:00:00 2001 From: Mesh Date: Thu, 5 Jun 2025 15:24:25 -0500 Subject: [PATCH 2/7] feat: added option for api auth --- src/webapp/routers/data.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/webapp/routers/data.py b/src/webapp/routers/data.py index bb37d5ab..e415b1b5 100644 --- a/src/webapp/routers/data.py +++ b/src/webapp/routers/data.py @@ -1110,6 +1110,7 @@ def get_support_overview( # Return a 400 error with the specific message from ValueError raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(ve)) + @router.get("/{inst_id}/training/support-overview/{run_id}") def get_training_support_overview( inst_id: str, From 57a256f4ebb437aee26f4991b916a271ea2361ce Mon Sep 17 00:00:00 2001 From: Mesh Date: Mon, 9 Jun 2025 09:25:32 -0500 Subject: [PATCH 3/7] adjusted api authentication --- src/webapp/authn.py | 12 +++++++----- src/webapp/utilities.py | 4 +++- 2 files changed, 10 insertions(+), 6 deletions(-) diff --git a/src/webapp/authn.py b/src/webapp/authn.py index 1f60c4ed..c45d6811 100644 --- a/src/webapp/authn.py +++ b/src/webapp/authn.py @@ -6,8 +6,9 @@ import jwt from fastapi import Security, HTTPException, status from fastapi.security import ( - OAuth2PasswordBearer, + #OAuth2PasswordBearer, APIKeyHeader, + HTTPBearer ) from passlib.context import CryptContext from pydantic import BaseModel @@ -16,10 +17,11 @@ pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto") -oauth2_apikey_scheme = OAuth2PasswordBearer( - scheme_name="api_key_scheme", - tokenUrl="token-from-api-key", -) +#oauth2_apikey_scheme = OAuth2PasswordBearer( + # scheme_name="api_key_scheme", + #tokenUrl="token-from-api-key", +#) +oauth2_apikey_scheme = HTTPBearer(auto_error=True) api_key_header = APIKeyHeader(name="X-API-KEY", scheme_name="api-key", auto_error=False) # The INST value may be empty for Datakinder or cross-institution access. diff --git a/src/webapp/utilities.py b/src/webapp/utilities.py index afc88fa7..eb126c7c 100644 --- a/src/webapp/utilities.py +++ b/src/webapp/utilities.py @@ -12,6 +12,7 @@ from sqlalchemy.orm import Session from sqlalchemy.future import select from sqlalchemy import and_ +from fastapi.security import HTTPAuthorizationCredentials from .authn import ( verify_api_key, @@ -312,7 +313,7 @@ def authenticate_api_key(api_key_enduser_tuple: str, sess: Session) -> BaseUser: async def get_current_user( sess: Annotated[Session, Depends(get_session)], - token_from_key: Annotated[str, Depends(oauth2_apikey_scheme)], + token_from_key: Annotated[HTTPAuthorizationCredentials, Depends(oauth2_apikey_scheme)], ) -> BaseUser: """Get the user from a given token.""" credentials_exception = HTTPException( @@ -321,6 +322,7 @@ async def get_current_user( headers={"WWW-Authenticate": "Bearer"}, ) usrname = None + token_from_key = token_from_key.credentials try: if not token_from_key: raise credentials_exception From 9caee1417ca85872c4ddf69b03c538aaac0d7c36 Mon Sep 17 00:00:00 2001 From: Mesh Date: Mon, 9 Jun 2025 09:51:26 -0500 Subject: [PATCH 4/7] adjusted api authentication --- src/webapp/utilities.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/webapp/utilities.py b/src/webapp/utilities.py index eb126c7c..a5a4bb7e 100644 --- a/src/webapp/utilities.py +++ b/src/webapp/utilities.py @@ -322,6 +322,7 @@ async def get_current_user( headers={"WWW-Authenticate": "Bearer"}, ) usrname = None + print(token_from_key) token_from_key = token_from_key.credentials try: if not token_from_key: From 05bcd6d73faad1a376e1ee69e36210b07fd8edcf Mon Sep 17 00:00:00 2001 From: Mesh Date: Mon, 9 Jun 2025 10:25:10 -0500 Subject: [PATCH 5/7] adjusted api authentication --- src/webapp/authn.py | 16 +++++++++------- src/webapp/utilities.py | 5 +++-- 2 files changed, 12 insertions(+), 9 deletions(-) diff --git a/src/webapp/authn.py b/src/webapp/authn.py index c45d6811..96301867 100644 --- a/src/webapp/authn.py +++ b/src/webapp/authn.py @@ -6,9 +6,9 @@ import jwt from fastapi import Security, HTTPException, status from fastapi.security import ( - #OAuth2PasswordBearer, + # OAuth2PasswordBearer, APIKeyHeader, - HTTPBearer + HTTPBearer, ) from passlib.context import CryptContext from pydantic import BaseModel @@ -17,11 +17,13 @@ pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto") -#oauth2_apikey_scheme = OAuth2PasswordBearer( - # scheme_name="api_key_scheme", - #tokenUrl="token-from-api-key", -#) -oauth2_apikey_scheme = HTTPBearer(auto_error=True) +# oauth2_apikey_scheme = OAuth2PasswordBearer( +# scheme_name="api_key_scheme", +# tokenUrl="token-from-api-key", +# ) +oauth2_apikey_scheme = HTTPBearer( + auto_error=True, scheme_name="Bearer token (get from /token-from-api-key)" +) api_key_header = APIKeyHeader(name="X-API-KEY", scheme_name="api-key", auto_error=False) # The INST value may be empty for Datakinder or cross-institution access. diff --git a/src/webapp/utilities.py b/src/webapp/utilities.py index a5a4bb7e..5d6f189f 100644 --- a/src/webapp/utilities.py +++ b/src/webapp/utilities.py @@ -313,7 +313,9 @@ def authenticate_api_key(api_key_enduser_tuple: str, sess: Session) -> BaseUser: async def get_current_user( sess: Annotated[Session, Depends(get_session)], - token_from_key: Annotated[HTTPAuthorizationCredentials, Depends(oauth2_apikey_scheme)], + token_from_key: Annotated[ + HTTPAuthorizationCredentials, Depends(oauth2_apikey_scheme) + ], ) -> BaseUser: """Get the user from a given token.""" credentials_exception = HTTPException( @@ -322,7 +324,6 @@ async def get_current_user( headers={"WWW-Authenticate": "Bearer"}, ) usrname = None - print(token_from_key) token_from_key = token_from_key.credentials try: if not token_from_key: From 6be550991fb1dc7ded4463de645127cfcd9d820e Mon Sep 17 00:00:00 2001 From: Mesh Date: Mon, 9 Jun 2025 10:33:25 -0500 Subject: [PATCH 6/7] adjusted api authentication --- src/webapp/authn.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/src/webapp/authn.py b/src/webapp/authn.py index 96301867..0de63d08 100644 --- a/src/webapp/authn.py +++ b/src/webapp/authn.py @@ -13,7 +13,7 @@ from passlib.context import CryptContext from pydantic import BaseModel from .config import env_vars - +from typing import Any pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto") @@ -61,24 +61,24 @@ def check_creds(username: str, password: str) -> bool: return username == env_vars["USERNAME"] and password == env_vars["PASSWORD"] -def verify_password(plain_password: str, hashed_password: str) -> bool: +def verify_password(plain_password: str, hashed_password: str) -> Any: """Verify a plain password against a hash. Includes a 2y/2b replacement since Laravel Generates hashes that start with 2y. The hashing scheme recognizes both.""" revert_hash = hashed_password.replace("$2y", "$2b", 1) return pwd_context.verify(plain_password, revert_hash) -def verify_api_key(plain_api_key: str, hashed_key: str) -> bool: +def verify_api_key(plain_api_key: str, hashed_key: str) -> Any: """Verify a plain API Key against a hash.""" return pwd_context.verify(plain_api_key, hashed_key) -def get_api_key_hash(api_key: str): +def get_api_key_hash(api_key: str) -> Any: """Hash a given api key.""" return pwd_context.hash(api_key) -def get_password_hash(password: str): +def get_password_hash(password: str) -> Any: """Hash a password. To align with the password hashing used by Laravel, we have to replace the 2b generated by pwd_context with 2y and that should be the version we store. They should be functionally the same: https://stackoverflow.com/a/36225192/28478909 @@ -94,10 +94,10 @@ def create_access_token(data: dict, expires_delta: timedelta | None = None): expire = datetime.now(timezone.utc) + expires_delta else: expire = datetime.now(timezone.utc) + timedelta( - minutes=env_vars["ACCESS_TOKEN_EXPIRE_MINUTES"] + minutes=int(env_vars["ACCESS_TOKEN_EXPIRE_MINUTES"]) ) to_encode.update({"exp": expire}) encoded_jwt = jwt.encode( - to_encode, env_vars["SECRET_KEY"], algorithm=env_vars["ALGORITHM"] + to_encode, str(env_vars["SECRET_KEY"]), algorithm=str(env_vars["ALGORITHM"]) ) return encoded_jwt From f12709fee1a57194098c03777d1e14747901c43e Mon Sep 17 00:00:00 2001 From: Mesh Date: Mon, 9 Jun 2025 10:44:41 -0500 Subject: [PATCH 7/7] adjusted api authentication --- src/webapp/authn.py | 6 +++--- src/webapp/utilities.py | 16 ++++++++-------- 2 files changed, 11 insertions(+), 11 deletions(-) diff --git a/src/webapp/authn.py b/src/webapp/authn.py index 0de63d08..09af3c9f 100644 --- a/src/webapp/authn.py +++ b/src/webapp/authn.py @@ -87,17 +87,17 @@ def get_password_hash(password: str) -> Any: return initial_hash.replace("$2b", "$2y", 1) -def create_access_token(data: dict, expires_delta: timedelta | None = None): +def create_access_token(data: dict, expires_delta: timedelta | None = None) -> str: """Create a JWT.""" to_encode = data.copy() if expires_delta: expire = datetime.now(timezone.utc) + expires_delta else: expire = datetime.now(timezone.utc) + timedelta( - minutes=int(env_vars["ACCESS_TOKEN_EXPIRE_MINUTES"]) + minutes=env_vars["ACCESS_TOKEN_EXPIRE_MINUTES"] ) to_encode.update({"exp": expire}) encoded_jwt = jwt.encode( - to_encode, str(env_vars["SECRET_KEY"]), algorithm=str(env_vars["ALGORITHM"]) + to_encode, env_vars["SECRET_KEY"], algorithm=env_vars["ALGORITHM"] ) return encoded_jwt diff --git a/src/webapp/utilities.py b/src/webapp/utilities.py index 5d6f189f..8d859301 100644 --- a/src/webapp/utilities.py +++ b/src/webapp/utilities.py @@ -2,7 +2,7 @@ import uuid import re -from typing import Annotated, Final +from typing import Annotated, Final, Any from urllib.parse import unquote from strenum import StrEnum # needed for python pre 3.11 import jwt @@ -22,7 +22,7 @@ from .config import env_vars -def decode_url_piece(src: str): +def decode_url_piece(src: str) -> str: """Decode encoded URL.""" return unquote(src) @@ -165,29 +165,29 @@ class BaseUser(BaseModel): def __init__(self, usr: str | None, inst: str, access: str, email: str) -> None: super().__init__(user_id=usr, institution=inst, access_type=access, email=email) - def is_datakinder(self) -> bool: + def is_datakinder(self) -> Any: """Whether a given user is a Datakinder.""" return self.access_type and self.access_type == AccessType.DATAKINDER - def is_model_owner(self) -> bool: + def is_model_owner(self) -> Any: """Whether a given user is a model owner.""" return self.access_type and self.access_type == AccessType.MODEL_OWNER - def is_data_owner(self) -> bool: + def is_data_owner(self) -> Any: """Whether a given user is a data owner.""" return self.access_type and self.access_type == AccessType.DATA_OWNER - def is_viewer(self) -> bool: + def is_viewer(self) -> Any: """Whether a given user is a viewer.""" return self.access_type and self.access_type == AccessType.VIEWER - def has_access_to_inst(self, inst: str) -> bool: + def has_access_to_inst(self, inst: str) -> Any: """Whether a given user has access to a given institution.""" return self.access_type and ( self.access_type == AccessType.DATAKINDER or self.institution == inst ) - def has_full_data_access(self) -> bool: + def has_full_data_access(self) -> Any: """Datakinders, model_owners, data_owners, all have full data access.""" return self.access_type and self.access_type in ( AccessType.DATAKINDER,