@@ -150,7 +150,7 @@ def get_all_files(
150150 inst_id : str ,
151151 sst_generated_value : bool | None ,
152152 sess : Session ,
153- storage_control ,
153+ storage_control : Any ,
154154) -> list [DataInfo ]:
155155 """Retrieve all files."""
156156 # Update from bucket
@@ -191,7 +191,7 @@ def get_all_files(
191191 "uploaded_date" : elem .created_at ,
192192 }
193193 )
194- return result_files
194+ return result_files # type: ignore
195195
196196
197197def get_all_batches (
@@ -225,19 +225,19 @@ def get_all_batches(
225225 "updated_at" : elem .updated_at ,
226226 }
227227 )
228- return result_batches
228+ return result_batches # type: ignore
229229
230230
231- def uuids_to_strs (files ) -> set [str ]:
231+ def uuids_to_strs (files : Any ) -> set [str ]:
232232 """Convert a set of uuids to strings.
233233 The input is of type sqlalchemy.orm.collections.InstrumentedSet.
234234 """
235- return [uuid_to_str (x .id ) for x in files ]
235+ return [uuid_to_str (x .id ) for x in files ] # type: ignore
236236
237237
238- def strs_to_uuids (files ) -> set [uuid .UUID ]:
238+ def strs_to_uuids (files : Any ) -> set [uuid .UUID ]:
239239 """Convert a set of strs to uuids."""
240- return [str_to_uuid (x ) for x in files ]
240+ return [str_to_uuid (x ) for x in files ] # type: ignore
241241
242242
243243@router .get ("/{inst_id}/input" , response_model = DataOverview )
@@ -463,7 +463,7 @@ def create_batch(
463463 batch = BatchTable (
464464 name = req .name ,
465465 inst_id = str_to_uuid (inst_id ),
466- created_by = str_to_uuid (current_user .user_id ),
466+ created_by = str_to_uuid (current_user .user_id ), # type: ignore
467467 )
468468 f_names = [] if not req .file_names else req .file_names
469469 f_ids = [] if not req .file_ids else strs_to_uuids (req .file_ids )
@@ -647,7 +647,7 @@ def update_batch(
647647 existing_batch .name = update_data_req ["name" ]
648648 if "completed" in update_data_req :
649649 existing_batch .completed = update_data_req ["completed" ]
650- existing_batch .updated_by = str_to_uuid (current_user .user_id )
650+ existing_batch .updated_by = str_to_uuid (current_user .user_id ) # type: ignore
651651 local_session .get ().commit ()
652652 res = (
653653 local_session .get ()
@@ -931,7 +931,7 @@ def validation_helper(
931931 new_file_record = FileTable (
932932 name = file_name ,
933933 inst_id = str_to_uuid (inst_id ),
934- uploader = str_to_uuid (current_user .user_id ),
934+ uploader = str_to_uuid (current_user .user_id ), # type: ignore
935935 source = source_str ,
936936 sst_generated = False ,
937937 schemas = list (allowed_schemas ),
@@ -974,7 +974,7 @@ def validate_file_sftp(
974974) -> Any :
975975 """Validate a given file pulled from SFTP. The file_name should be url encoded."""
976976 file_name = decode_url_piece (file_name )
977- if not current_user .is_datakinder :
977+ if not current_user .is_datakinder : # type: ignore
978978 raise HTTPException (
979979 status_code = status .HTTP_401_UNAUTHORIZED ,
980980 detail = "SFTP validation needs to be done by a datakinder." ,
@@ -1022,9 +1022,12 @@ def get_upload_url(
10221022 raise HTTPException (status_code = status .HTTP_400_BAD_REQUEST , detail = str (ve ))
10231023
10241024
1025+ ## FE Inference Tables
1026+
1027+
10251028# Get SHAP Values for Inference
10261029@router .get ("/{inst_id}/inference/top-features/{run_id}" )
1027- def get_top_features (
1030+ def get_inference_top_features (
10281031 inst_id : str ,
10291032 run_id : str ,
10301033 current_user : Annotated [BaseUser , Depends (get_current_active_user )],
@@ -1053,10 +1056,10 @@ def get_top_features(
10531056 try :
10541057 dbc = DatabricksControl ()
10551058 rows = dbc .fetch_table_data (
1056- catalog_name = env_vars ["CATALOG_NAME" ],
1059+ catalog_name = env_vars ["CATALOG_NAME" ], # type: ignore
10571060 inst_name = f"{ query_result [0 ][0 ].name } " ,
1058- table_name = f"sample_inference_ { run_id } _features_with_most_impact" ,
1059- warehouse_id = env_vars ["SQL_WAREHOUSE_ID" ],
1061+ table_name = f"inference_ { run_id } _features_with_most_impact" ,
1062+ warehouse_id = env_vars ["SQL_WAREHOUSE_ID" ], # type: ignore
10601063 )
10611064
10621065 return rows
@@ -1067,7 +1070,7 @@ def get_top_features(
10671070
10681071# Get SHAP Values for Inference
10691072@router .get ("/{inst_id}/inference/support-overview/{run_id}" )
1070- def get_support_overview (
1073+ def get_inference_support_overview (
10711074 inst_id : str ,
10721075 run_id : str ,
10731076 current_user : Annotated [BaseUser , Depends (get_current_active_user )],
@@ -1096,10 +1099,10 @@ def get_support_overview(
10961099 try :
10971100 dbc = DatabricksControl ()
10981101 rows = dbc .fetch_table_data (
1099- catalog_name = env_vars ["CATALOG_NAME" ],
1102+ catalog_name = env_vars ["CATALOG_NAME" ], # type: ignore
11001103 inst_name = f"{ query_result [0 ][0 ].name } " ,
1101- table_name = f"sample_inference_ { run_id } _support_overview" ,
1102- warehouse_id = env_vars ["SQL_WAREHOUSE_ID" ],
1104+ table_name = f"inference_ { run_id } _support_overview" ,
1105+ warehouse_id = env_vars ["SQL_WAREHOUSE_ID" ], # type: ignore
11031106 )
11041107
11051108 return rows
@@ -1108,8 +1111,8 @@ def get_support_overview(
11081111 raise HTTPException (status_code = status .HTTP_400_BAD_REQUEST , detail = str (ve ))
11091112
11101113
1111- @router .get ("/{inst_id}/training/support-overview /{run_id}" )
1112- def get_training_support_overview (
1114+ @router .get ("/{inst_id}/inference/feature_importance /{run_id}" )
1115+ def get_inference_feature_importance (
11131116 inst_id : str ,
11141117 run_id : str ,
11151118 current_user : Annotated [BaseUser , Depends (get_current_active_user )],
@@ -1138,10 +1141,10 @@ def get_training_support_overview(
11381141 try :
11391142 dbc = DatabricksControl ()
11401143 rows = dbc .fetch_table_data (
1141- catalog_name = env_vars ["CATALOG_NAME" ],
1144+ catalog_name = env_vars ["CATALOG_NAME" ], # type: ignore
11421145 inst_name = f"{ query_result [0 ][0 ].name } " ,
1143- table_name = f"sample_training_ { run_id } _support_overview " ,
1144- warehouse_id = env_vars ["SQL_WAREHOUSE_ID" ],
1146+ table_name = f"inference_ { run_id } _shap_feature_importance " ,
1147+ warehouse_id = env_vars ["SQL_WAREHOUSE_ID" ], # type: ignore
11451148 )
11461149
11471150 return rows
@@ -1150,8 +1153,11 @@ def get_training_support_overview(
11501153 raise HTTPException (status_code = status .HTTP_400_BAD_REQUEST , detail = str (ve ))
11511154
11521155
1153- @router .get ("/{inst_id}/inference/feature_value/{run_id}" )
1154- def get_feature_value (
1156+ ## FE Training Tables
1157+
1158+
1159+ @router .get ("/{inst_id}/training/feature_importance/{run_id}" )
1160+ def get_training_feature_importance (
11551161 inst_id : str ,
11561162 run_id : str ,
11571163 current_user : Annotated [BaseUser , Depends (get_current_active_user )],
@@ -1180,10 +1186,10 @@ def get_feature_value(
11801186 try :
11811187 dbc = DatabricksControl ()
11821188 rows = dbc .fetch_table_data (
1183- catalog_name = env_vars ["CATALOG_NAME" ],
1189+ catalog_name = env_vars ["CATALOG_NAME" ], # type: ignore
11841190 inst_name = f"{ query_result [0 ][0 ].name } " ,
1185- table_name = f"sample_inference_ { run_id } _shap_feature_importance" ,
1186- warehouse_id = env_vars ["SQL_WAREHOUSE_ID" ],
1191+ table_name = f"training_ { run_id } _shap_feature_importance" ,
1192+ warehouse_id = env_vars ["SQL_WAREHOUSE_ID" ], # type: ignore
11871193 )
11881194
11891195 return rows
@@ -1193,7 +1199,7 @@ def get_feature_value(
11931199
11941200
11951201@router .get ("/{inst_id}/training/confusion_matrix/{run_id}" )
1196- def get_confusion_matrix (
1202+ def get_training_confusion_matrix (
11971203 inst_id : str ,
11981204 run_id : str ,
11991205 current_user : Annotated [BaseUser , Depends (get_current_active_user )],
@@ -1222,10 +1228,10 @@ def get_confusion_matrix(
12221228 try :
12231229 dbc = DatabricksControl ()
12241230 rows = dbc .fetch_table_data (
1225- catalog_name = env_vars ["CATALOG_NAME" ],
1231+ catalog_name = env_vars ["CATALOG_NAME" ], # type: ignore
12261232 inst_name = f"{ query_result [0 ][0 ].name } " ,
1227- table_name = f"sample_training_ { run_id } _confusion_matrix" ,
1228- warehouse_id = env_vars ["SQL_WAREHOUSE_ID" ],
1233+ table_name = f"training_ { run_id } _confusion_matrix" ,
1234+ warehouse_id = env_vars ["SQL_WAREHOUSE_ID" ], # type: ignore
12291235 )
12301236
12311237 return rows
@@ -1235,7 +1241,49 @@ def get_confusion_matrix(
12351241
12361242
12371243@router .get ("/{inst_id}/training/roc_curve/{run_id}" )
1238- def get_roc_curve (
1244+ def get_training_roc_curve (
1245+ inst_id : str ,
1246+ run_id : str ,
1247+ current_user : Annotated [BaseUser , Depends (get_current_active_user )],
1248+ sql_session : Annotated [Session , Depends (get_session )],
1249+ ) -> List [dict [str , Any ]]:
1250+ """Returns a signed URL for uploading data to a specific institution."""
1251+ # raise error at this level instead bc otherwise it's getting wrapped as a 200
1252+ has_access_to_inst_or_err (inst_id , current_user )
1253+ local_session .set (sql_session )
1254+ query_result = (
1255+ local_session .get ()
1256+ .execute (select (InstTable ).where (InstTable .id == str_to_uuid (inst_id )))
1257+ .all ()
1258+ )
1259+ if not query_result or len (query_result ) == 0 :
1260+ raise HTTPException (
1261+ status_code = status .HTTP_404_NOT_FOUND ,
1262+ detail = "Institution not found." ,
1263+ )
1264+ if len (query_result ) > 1 :
1265+ raise HTTPException (
1266+ status_code = status .HTTP_500_INTERNAL_SERVER_ERROR ,
1267+ detail = "Institution duplicates found." ,
1268+ )
1269+
1270+ try :
1271+ dbc = DatabricksControl ()
1272+ rows = dbc .fetch_table_data (
1273+ catalog_name = env_vars ["CATALOG_NAME" ], # type: ignore
1274+ inst_name = f"{ query_result [0 ][0 ].name } " ,
1275+ table_name = f"training_{ run_id } _roc_curve" ,
1276+ warehouse_id = env_vars ["SQL_WAREHOUSE_ID" ], # type: ignore
1277+ )
1278+
1279+ return rows
1280+ except ValueError as ve :
1281+ # Return a 400 error with the specific message from ValueError
1282+ raise HTTPException (status_code = status .HTTP_400_BAD_REQUEST , detail = str (ve ))
1283+
1284+
1285+ @router .get ("/{inst_id}/training/support-overview/{run_id}" )
1286+ def get_training_support_overview (
12391287 inst_id : str ,
12401288 run_id : str ,
12411289 current_user : Annotated [BaseUser , Depends (get_current_active_user )],
@@ -1264,10 +1312,10 @@ def get_roc_curve(
12641312 try :
12651313 dbc = DatabricksControl ()
12661314 rows = dbc .fetch_table_data (
1267- catalog_name = env_vars ["CATALOG_NAME" ],
1315+ catalog_name = env_vars ["CATALOG_NAME" ], # type: ignore
12681316 inst_name = f"{ query_result [0 ][0 ].name } " ,
1269- table_name = f"sample_training_ { run_id } _roc_curve " ,
1270- warehouse_id = env_vars ["SQL_WAREHOUSE_ID" ],
1317+ table_name = f"training_ { run_id } _support_overview " ,
1318+ warehouse_id = env_vars ["SQL_WAREHOUSE_ID" ], # type: ignore
12711319 )
12721320
12731321 return rows
0 commit comments