@@ -1028,12 +1028,12 @@ def get_upload_url(
10281028def get_top_features (
10291029 inst_id : str ,
10301030 run_id : str ,
1031- # current_user: Annotated[BaseUser, Depends(get_current_active_user)],
1031+ current_user : Annotated [BaseUser , Depends (get_current_active_user )],
10321032 sql_session : Annotated [Session , Depends (get_session )],
10331033) -> List [dict [str , Any ]]:
10341034 """Returns a signed URL for uploading data to a specific institution."""
10351035 # raise error at this level instead bc otherwise it's getting wrapped as a 200
1036- # has_access_to_inst_or_err(inst_id, current_user)
1036+ has_access_to_inst_or_err (inst_id , current_user )
10371037 local_session .set (sql_session )
10381038 query_result = (
10391039 local_session .get ()
@@ -1072,12 +1072,12 @@ def get_top_features(
10721072def get_support_overview (
10731073 inst_id : str ,
10741074 run_id : str ,
1075- # current_user: Annotated[BaseUser, Depends(get_current_active_user)],
1075+ current_user : Annotated [BaseUser , Depends (get_current_active_user )],
10761076 sql_session : Annotated [Session , Depends (get_session )],
10771077) -> List [dict [str , Any ]]:
10781078 """Returns a signed URL for uploading data to a specific institution."""
10791079 # raise error at this level instead bc otherwise it's getting wrapped as a 200
1080- # has_access_to_inst_or_err(inst_id, current_user)
1080+ has_access_to_inst_or_err (inst_id , current_user )
10811081 local_session .set (sql_session )
10821082 query_result = (
10831083 local_session .get ()
@@ -1111,16 +1111,59 @@ def get_support_overview(
11111111 raise HTTPException (status_code = status .HTTP_400_BAD_REQUEST , detail = str (ve ))
11121112
11131113
1114+ @router .get ("/{inst_id}/training/support-overview/{run_id}" )
1115+ def get_training_support_overview (
1116+ inst_id : str ,
1117+ run_id : str ,
1118+ current_user : Annotated [BaseUser , Depends (get_current_active_user )],
1119+ sql_session : Annotated [Session , Depends (get_session )],
1120+ ) -> List [dict [str , Any ]]:
1121+ """Returns a signed URL for uploading data to a specific institution."""
1122+ # raise error at this level instead bc otherwise it's getting wrapped as a 200
1123+ has_access_to_inst_or_err (inst_id , current_user )
1124+ local_session .set (sql_session )
1125+ query_result = (
1126+ local_session .get ()
1127+ .execute (select (InstTable ).where (InstTable .id == str_to_uuid (inst_id )))
1128+ .all ()
1129+ )
1130+ if not query_result or len (query_result ) == 0 :
1131+ raise HTTPException (
1132+ status_code = status .HTTP_404_NOT_FOUND ,
1133+ detail = "Institution not found." ,
1134+ )
1135+ if len (query_result ) > 1 :
1136+ raise HTTPException (
1137+ status_code = status .HTTP_500_INTERNAL_SERVER_ERROR ,
1138+ detail = "Institution duplicates found." ,
1139+ )
1140+
1141+ try :
1142+ dbc = DatabricksControl ()
1143+ rows = dbc .fetch_table_data (
1144+ catalog_name = env_vars ["CATALOG_NAME" ],
1145+ inst_name = f"{ query_result [0 ][0 ].name } " ,
1146+ table_name = f"sample_training_{ run_id } _support_overview" ,
1147+ warehouse_id = env_vars ["SQL_WAREHOUSE_ID" ],
1148+ limit = 500 ,
1149+ )
1150+
1151+ return rows
1152+ except ValueError as ve :
1153+ # Return a 400 error with the specific message from ValueError
1154+ raise HTTPException (status_code = status .HTTP_400_BAD_REQUEST , detail = str (ve ))
1155+
1156+
11141157@router .get ("/{inst_id}/inference/feature_value/{run_id}" )
11151158def get_feature_value (
11161159 inst_id : str ,
11171160 run_id : str ,
1118- # current_user: Annotated[BaseUser, Depends(get_current_active_user)],
1161+ current_user : Annotated [BaseUser , Depends (get_current_active_user )],
11191162 sql_session : Annotated [Session , Depends (get_session )],
11201163) -> List [dict [str , Any ]]:
11211164 """Returns a signed URL for uploading data to a specific institution."""
11221165 # raise error at this level instead bc otherwise it's getting wrapped as a 200
1123- # has_access_to_inst_or_err(inst_id, current_user)
1166+ has_access_to_inst_or_err (inst_id , current_user )
11241167 local_session .set (sql_session )
11251168 query_result = (
11261169 local_session .get ()
@@ -1158,12 +1201,12 @@ def get_feature_value(
11581201def get_confusion_matrix (
11591202 inst_id : str ,
11601203 run_id : str ,
1161- ## current_user: Annotated[BaseUser, Depends(get_current_active_user)],
1204+ current_user : Annotated [BaseUser , Depends (get_current_active_user )],
11621205 sql_session : Annotated [Session , Depends (get_session )],
11631206) -> List [dict [str , Any ]]:
11641207 """Returns a signed URL for uploading data to a specific institution."""
11651208 # raise error at this level instead bc otherwise it's getting wrapped as a 200
1166- # has_access_to_inst_or_err(inst_id, current_user)
1209+ has_access_to_inst_or_err (inst_id , current_user )
11671210 local_session .set (sql_session )
11681211 query_result = (
11691212 local_session .get ()
@@ -1201,12 +1244,12 @@ def get_confusion_matrix(
12011244def get_roc_curve (
12021245 inst_id : str ,
12031246 run_id : str ,
1204- # current_user: Annotated[BaseUser, Depends(get_current_active_user)],
1247+ current_user : Annotated [BaseUser , Depends (get_current_active_user )],
12051248 sql_session : Annotated [Session , Depends (get_session )],
12061249) -> List [dict [str , Any ]]:
12071250 """Returns a signed URL for uploading data to a specific institution."""
12081251 # raise error at this level instead bc otherwise it's getting wrapped as a 200
1209- # has_access_to_inst_or_err(inst_id, current_user)
1252+ has_access_to_inst_or_err (inst_id , current_user )
12101253 local_session .set (sql_session )
12111254 query_result = (
12121255 local_session .get ()
0 commit comments