From 0feb6e289ed7502d77515e8b7bb430d80aede267 Mon Sep 17 00:00:00 2001 From: Mesh Date: Mon, 23 Jun 2025 16:36:18 -0500 Subject: [PATCH 1/6] Added endpoint for front end table --- src/webapp/routers/data.py | 68 ++++++++++++++++++++++++++++++++------ 1 file changed, 58 insertions(+), 10 deletions(-) diff --git a/src/webapp/routers/data.py b/src/webapp/routers/data.py index 95c90997..415c023a 100644 --- a/src/webapp/routers/data.py +++ b/src/webapp/routers/data.py @@ -1022,6 +1022,9 @@ def get_upload_url( raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(ve)) +## FE Inference Tables + + # Get SHAP Values for Inference @router.get("/{inst_id}/inference/top-features/{run_id}") def get_top_features( @@ -1055,7 +1058,7 @@ def get_top_features( rows = dbc.fetch_table_data( catalog_name=env_vars["CATALOG_NAME"], inst_name=f"{query_result[0][0].name}", - table_name=f"sample_inference_{run_id}_features_with_most_impact", + table_name=f"inference_{run_id}_features_with_most_impact", warehouse_id=env_vars["SQL_WAREHOUSE_ID"], ) @@ -1098,7 +1101,7 @@ def get_support_overview( rows = dbc.fetch_table_data( catalog_name=env_vars["CATALOG_NAME"], inst_name=f"{query_result[0][0].name}", - table_name=f"sample_inference_{run_id}_support_overview", + table_name=f"inference_{run_id}_support_overview", warehouse_id=env_vars["SQL_WAREHOUSE_ID"], ) @@ -1108,8 +1111,8 @@ def get_support_overview( raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(ve)) -@router.get("/{inst_id}/training/support-overview/{run_id}") -def get_training_support_overview( +@router.get("/{inst_id}/inference/feature_importance/{run_id}") +def get_feature_importance( inst_id: str, run_id: str, current_user: Annotated[BaseUser, Depends(get_current_active_user)], @@ -1140,7 +1143,7 @@ def get_training_support_overview( rows = dbc.fetch_table_data( catalog_name=env_vars["CATALOG_NAME"], inst_name=f"{query_result[0][0].name}", - table_name=f"sample_training_{run_id}_support_overview", + table_name=f"inference_{run_id}_shap_feature_importance", warehouse_id=env_vars["SQL_WAREHOUSE_ID"], ) @@ -1150,8 +1153,11 @@ def get_training_support_overview( raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(ve)) -@router.get("/{inst_id}/inference/feature_value/{run_id}") -def get_feature_value( +## FE Training Tables + + +@router.get("/{inst_id}/training/feature_importance/{run_id}") +def get_feature_importance( inst_id: str, run_id: str, current_user: Annotated[BaseUser, Depends(get_current_active_user)], @@ -1182,7 +1188,7 @@ def get_feature_value( rows = dbc.fetch_table_data( catalog_name=env_vars["CATALOG_NAME"], inst_name=f"{query_result[0][0].name}", - table_name=f"sample_inference_{run_id}_shap_feature_importance", + table_name=f"training_{run_id}_shap_feature_importance", warehouse_id=env_vars["SQL_WAREHOUSE_ID"], ) @@ -1224,7 +1230,7 @@ def get_confusion_matrix( rows = dbc.fetch_table_data( catalog_name=env_vars["CATALOG_NAME"], inst_name=f"{query_result[0][0].name}", - table_name=f"sample_training_{run_id}_confusion_matrix", + table_name=f"training_{run_id}_confusion_matrix", warehouse_id=env_vars["SQL_WAREHOUSE_ID"], ) @@ -1266,7 +1272,49 @@ def get_roc_curve( rows = dbc.fetch_table_data( catalog_name=env_vars["CATALOG_NAME"], inst_name=f"{query_result[0][0].name}", - table_name=f"sample_training_{run_id}_roc_curve", + table_name=f"training_{run_id}_roc_curve", + warehouse_id=env_vars["SQL_WAREHOUSE_ID"], + ) + + return rows + except ValueError as ve: + # Return a 400 error with the specific message from ValueError + raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(ve)) + + +@router.get("/{inst_id}/training/support-overview/{run_id}") +def get_training_support_overview( + inst_id: str, + run_id: str, + current_user: Annotated[BaseUser, Depends(get_current_active_user)], + sql_session: Annotated[Session, Depends(get_session)], +) -> List[dict[str, Any]]: + """Returns a signed URL for uploading data to a specific institution.""" + # raise error at this level instead bc otherwise it's getting wrapped as a 200 + has_access_to_inst_or_err(inst_id, current_user) + local_session.set(sql_session) + query_result = ( + local_session.get() + .execute(select(InstTable).where(InstTable.id == str_to_uuid(inst_id))) + .all() + ) + if not query_result or len(query_result) == 0: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail="Institution not found.", + ) + if len(query_result) > 1: + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail="Institution duplicates found.", + ) + + try: + dbc = DatabricksControl() + rows = dbc.fetch_table_data( + catalog_name=env_vars["CATALOG_NAME"], + inst_name=f"{query_result[0][0].name}", + table_name=f"training_{run_id}_support_overview", warehouse_id=env_vars["SQL_WAREHOUSE_ID"], ) From 3fc0a1d53d81e6302fa5066e51af284018cb5c8e Mon Sep 17 00:00:00 2001 From: Mesh Date: Mon, 23 Jun 2025 16:40:30 -0500 Subject: [PATCH 2/6] Added endpoint for front end table --- src/webapp/routers/data.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/src/webapp/routers/data.py b/src/webapp/routers/data.py index 415c023a..b847ced1 100644 --- a/src/webapp/routers/data.py +++ b/src/webapp/routers/data.py @@ -1027,7 +1027,7 @@ def get_upload_url( # Get SHAP Values for Inference @router.get("/{inst_id}/inference/top-features/{run_id}") -def get_top_features( +def get_inference_top_features( inst_id: str, run_id: str, current_user: Annotated[BaseUser, Depends(get_current_active_user)], @@ -1070,7 +1070,7 @@ def get_top_features( # Get SHAP Values for Inference @router.get("/{inst_id}/inference/support-overview/{run_id}") -def get_support_overview( +def get_inference_support_overview( inst_id: str, run_id: str, current_user: Annotated[BaseUser, Depends(get_current_active_user)], @@ -1112,7 +1112,7 @@ def get_support_overview( @router.get("/{inst_id}/inference/feature_importance/{run_id}") -def get_feature_importance( +def get_inference_feature_importance( inst_id: str, run_id: str, current_user: Annotated[BaseUser, Depends(get_current_active_user)], @@ -1157,7 +1157,7 @@ def get_feature_importance( @router.get("/{inst_id}/training/feature_importance/{run_id}") -def get_feature_importance( +def get_training_feature_importance( inst_id: str, run_id: str, current_user: Annotated[BaseUser, Depends(get_current_active_user)], @@ -1199,7 +1199,7 @@ def get_feature_importance( @router.get("/{inst_id}/training/confusion_matrix/{run_id}") -def get_confusion_matrix( +def get_training_confusion_matrix( inst_id: str, run_id: str, current_user: Annotated[BaseUser, Depends(get_current_active_user)], @@ -1241,7 +1241,7 @@ def get_confusion_matrix( @router.get("/{inst_id}/training/roc_curve/{run_id}") -def get_roc_curve( +def get_training_roc_curve( inst_id: str, run_id: str, current_user: Annotated[BaseUser, Depends(get_current_active_user)], From 326bd4002c845c66656977d148e1c9a4868029ad Mon Sep 17 00:00:00 2001 From: Mesh Date: Mon, 23 Jun 2025 16:48:54 -0500 Subject: [PATCH 3/6] Added endpoint for front end table --- src/webapp/routers/data.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/src/webapp/routers/data.py b/src/webapp/routers/data.py index b847ced1..93d4b03d 100644 --- a/src/webapp/routers/data.py +++ b/src/webapp/routers/data.py @@ -191,7 +191,7 @@ def get_all_files( "uploaded_date": elem.created_at, } ) - return result_files + return result_files # type: ignore def get_all_batches( @@ -225,19 +225,19 @@ def get_all_batches( "updated_at": elem.updated_at, } ) - return result_batches + return result_batches # type: ignore -def uuids_to_strs(files) -> set[str]: +def uuids_to_strs(files: Any) -> set[str]: """Convert a set of uuids to strings. The input is of type sqlalchemy.orm.collections.InstrumentedSet. """ - return [uuid_to_str(x.id) for x in files] + return [uuid_to_str(x.id) for x in files] # type: ignore -def strs_to_uuids(files) -> set[uuid.UUID]: +def strs_to_uuids(files: Any) -> set[uuid.UUID]: """Convert a set of strs to uuids.""" - return [str_to_uuid(x) for x in files] + return [str_to_uuid(x) for x in files] # type: ignore @router.get("/{inst_id}/input", response_model=DataOverview) From fb913be6b8d61b524b2a250c34a7fa54514b6b69 Mon Sep 17 00:00:00 2001 From: Mesh Date: Mon, 23 Jun 2025 16:49:06 -0500 Subject: [PATCH 4/6] Added endpoint for front end table --- src/webapp/routers/data.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/src/webapp/routers/data.py b/src/webapp/routers/data.py index 93d4b03d..a329a853 100644 --- a/src/webapp/routers/data.py +++ b/src/webapp/routers/data.py @@ -191,7 +191,7 @@ def get_all_files( "uploaded_date": elem.created_at, } ) - return result_files # type: ignore + return result_files # type: ignore def get_all_batches( @@ -225,19 +225,19 @@ def get_all_batches( "updated_at": elem.updated_at, } ) - return result_batches # type: ignore + return result_batches # type: ignore -def uuids_to_strs(files: Any) -> set[str]: +def uuids_to_strs(files: Any) -> set[str]: """Convert a set of uuids to strings. The input is of type sqlalchemy.orm.collections.InstrumentedSet. """ - return [uuid_to_str(x.id) for x in files] # type: ignore + return [uuid_to_str(x.id) for x in files] # type: ignore def strs_to_uuids(files: Any) -> set[uuid.UUID]: """Convert a set of strs to uuids.""" - return [str_to_uuid(x) for x in files] # type: ignore + return [str_to_uuid(x) for x in files] # type: ignore @router.get("/{inst_id}/input", response_model=DataOverview) From 2ce596b667ee6f6b24420a1874195e9666a5e088 Mon Sep 17 00:00:00 2001 From: Mesh Date: Mon, 23 Jun 2025 16:53:30 -0500 Subject: [PATCH 5/6] Added endpoint for front end table --- src/webapp/routers/data.py | 36 ++++++++++++++++++------------------ 1 file changed, 18 insertions(+), 18 deletions(-) diff --git a/src/webapp/routers/data.py b/src/webapp/routers/data.py index a329a853..194d114f 100644 --- a/src/webapp/routers/data.py +++ b/src/webapp/routers/data.py @@ -463,7 +463,7 @@ def create_batch( batch = BatchTable( name=req.name, inst_id=str_to_uuid(inst_id), - created_by=str_to_uuid(current_user.user_id), + created_by=str_to_uuid(current_user.user_id), # type: ignore ) f_names = [] if not req.file_names else req.file_names f_ids = [] if not req.file_ids else strs_to_uuids(req.file_ids) @@ -647,7 +647,7 @@ def update_batch( existing_batch.name = update_data_req["name"] if "completed" in update_data_req: existing_batch.completed = update_data_req["completed"] - existing_batch.updated_by = str_to_uuid(current_user.user_id) + existing_batch.updated_by = str_to_uuid(current_user.user_id) # type: ignore local_session.get().commit() res = ( local_session.get() @@ -931,7 +931,7 @@ def validation_helper( new_file_record = FileTable( name=file_name, inst_id=str_to_uuid(inst_id), - uploader=str_to_uuid(current_user.user_id), + uploader=str_to_uuid(current_user.user_id), # type: ignore source=source_str, sst_generated=False, schemas=list(allowed_schemas), @@ -974,7 +974,7 @@ def validate_file_sftp( ) -> Any: """Validate a given file pulled from SFTP. The file_name should be url encoded.""" file_name = decode_url_piece(file_name) - if not current_user.is_datakinder: + if not current_user.is_datakinder: # type: ignore raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, detail="SFTP validation needs to be done by a datakinder.", @@ -1056,10 +1056,10 @@ def get_inference_top_features( try: dbc = DatabricksControl() rows = dbc.fetch_table_data( - catalog_name=env_vars["CATALOG_NAME"], + catalog_name=env_vars["CATALOG_NAME"], # type: ignore inst_name=f"{query_result[0][0].name}", table_name=f"inference_{run_id}_features_with_most_impact", - warehouse_id=env_vars["SQL_WAREHOUSE_ID"], + warehouse_id=env_vars["SQL_WAREHOUSE_ID"], # type: ignore ) return rows @@ -1099,10 +1099,10 @@ def get_inference_support_overview( try: dbc = DatabricksControl() rows = dbc.fetch_table_data( - catalog_name=env_vars["CATALOG_NAME"], + catalog_name=env_vars["CATALOG_NAME"], # type: ignore inst_name=f"{query_result[0][0].name}", table_name=f"inference_{run_id}_support_overview", - warehouse_id=env_vars["SQL_WAREHOUSE_ID"], + warehouse_id=env_vars["SQL_WAREHOUSE_ID"], # type: ignore ) return rows @@ -1141,10 +1141,10 @@ def get_inference_feature_importance( try: dbc = DatabricksControl() rows = dbc.fetch_table_data( - catalog_name=env_vars["CATALOG_NAME"], + catalog_name=env_vars["CATALOG_NAME"], # type: ignore inst_name=f"{query_result[0][0].name}", table_name=f"inference_{run_id}_shap_feature_importance", - warehouse_id=env_vars["SQL_WAREHOUSE_ID"], + warehouse_id=env_vars["SQL_WAREHOUSE_ID"], # type: ignore ) return rows @@ -1186,10 +1186,10 @@ def get_training_feature_importance( try: dbc = DatabricksControl() rows = dbc.fetch_table_data( - catalog_name=env_vars["CATALOG_NAME"], + catalog_name=env_vars["CATALOG_NAME"], # type: ignore inst_name=f"{query_result[0][0].name}", table_name=f"training_{run_id}_shap_feature_importance", - warehouse_id=env_vars["SQL_WAREHOUSE_ID"], + warehouse_id=env_vars["SQL_WAREHOUSE_ID"], # type: ignore ) return rows @@ -1228,10 +1228,10 @@ def get_training_confusion_matrix( try: dbc = DatabricksControl() rows = dbc.fetch_table_data( - catalog_name=env_vars["CATALOG_NAME"], + catalog_name=env_vars["CATALOG_NAME"], # type: ignore inst_name=f"{query_result[0][0].name}", table_name=f"training_{run_id}_confusion_matrix", - warehouse_id=env_vars["SQL_WAREHOUSE_ID"], + warehouse_id=env_vars["SQL_WAREHOUSE_ID"], # type: ignore ) return rows @@ -1270,10 +1270,10 @@ def get_training_roc_curve( try: dbc = DatabricksControl() rows = dbc.fetch_table_data( - catalog_name=env_vars["CATALOG_NAME"], + catalog_name=env_vars["CATALOG_NAME"], # type: ignore inst_name=f"{query_result[0][0].name}", table_name=f"training_{run_id}_roc_curve", - warehouse_id=env_vars["SQL_WAREHOUSE_ID"], + warehouse_id=env_vars["SQL_WAREHOUSE_ID"], # type: ignore ) return rows @@ -1312,10 +1312,10 @@ def get_training_support_overview( try: dbc = DatabricksControl() rows = dbc.fetch_table_data( - catalog_name=env_vars["CATALOG_NAME"], + catalog_name=env_vars["CATALOG_NAME"], # type: ignore inst_name=f"{query_result[0][0].name}", table_name=f"training_{run_id}_support_overview", - warehouse_id=env_vars["SQL_WAREHOUSE_ID"], + warehouse_id=env_vars["SQL_WAREHOUSE_ID"], # type: ignore ) return rows From fd40d9ed51a92e95ffbc3ada88c6bbf3c048b11d Mon Sep 17 00:00:00 2001 From: Mesh Date: Mon, 23 Jun 2025 16:55:26 -0500 Subject: [PATCH 6/6] Added endpoint for front end table --- src/webapp/routers/data.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/webapp/routers/data.py b/src/webapp/routers/data.py index 194d114f..72390267 100644 --- a/src/webapp/routers/data.py +++ b/src/webapp/routers/data.py @@ -150,7 +150,7 @@ def get_all_files( inst_id: str, sst_generated_value: bool | None, sess: Session, - storage_control, + storage_control: Any, ) -> list[DataInfo]: """Retrieve all files.""" # Update from bucket