Skip to content

Commit eb4d32e

Browse files
authored
Merge pull request #122 from datakind/AdjustedFETables
Added endpoint for front end table
2 parents d7d61dc + fd40d9e commit eb4d32e

1 file changed

Lines changed: 85 additions & 37 deletions

File tree

src/webapp/routers/data.py

Lines changed: 85 additions & 37 deletions
Original file line number | Diff line number | Diff line change
@@ -150,7 +150,7 @@ def get_all_files(
150150
inst_id: str,
151151
sst_generated_value: bool | None,
152152
sess: Session,
153-
storage_control,
153+
storage_control: Any,
154154
) -> list[DataInfo]:
155155
"""Retrieve all files."""
156156
# Update from bucket
@@ -191,7 +191,7 @@ def get_all_files(
191191
"uploaded_date": elem.created_at,
192192
}
193193
)
194-
return result_files
194+
return result_files # type: ignore
195195

196196

197197
def get_all_batches(
@@ -225,19 +225,19 @@ def get_all_batches(
225225
"updated_at": elem.updated_at,
226226
}
227227
)
228-
return result_batches
228+
return result_batches # type: ignore
229229

230230

231-
def uuids_to_strs(files) -> set[str]:
231+
def uuids_to_strs(files: Any) -> set[str]:
232232
"""Convert a set of uuids to strings.
233233
The input is of type sqlalchemy.orm.collections.InstrumentedSet.
234234
"""
235-
return [uuid_to_str(x.id) for x in files]
235+
return [uuid_to_str(x.id) for x in files] # type: ignore
236236

237237

238-
def strs_to_uuids(files) -> set[uuid.UUID]:
238+
def strs_to_uuids(files: Any) -> set[uuid.UUID]:
239239
"""Convert a set of strs to uuids."""
240-
return [str_to_uuid(x) for x in files]
240+
return [str_to_uuid(x) for x in files] # type: ignore
241241

242242

243243
@router.get("/{inst_id}/input", response_model=DataOverview)
@@ -463,7 +463,7 @@ def create_batch(
463463
batch = BatchTable(
464464
name=req.name,
465465
inst_id=str_to_uuid(inst_id),
466-
created_by=str_to_uuid(current_user.user_id),
466+
created_by=str_to_uuid(current_user.user_id), # type: ignore
467467
)
468468
f_names = [] if not req.file_names else req.file_names
469469
f_ids = [] if not req.file_ids else strs_to_uuids(req.file_ids)
@@ -647,7 +647,7 @@ def update_batch(
647647
existing_batch.name = update_data_req["name"]
648648
if "completed" in update_data_req:
649649
existing_batch.completed = update_data_req["completed"]
650-
existing_batch.updated_by = str_to_uuid(current_user.user_id)
650+
existing_batch.updated_by = str_to_uuid(current_user.user_id) # type: ignore
651651
local_session.get().commit()
652652
res = (
653653
local_session.get()
@@ -931,7 +931,7 @@ def validation_helper(
931931
new_file_record = FileTable(
932932
name=file_name,
933933
inst_id=str_to_uuid(inst_id),
934-
uploader=str_to_uuid(current_user.user_id),
934+
uploader=str_to_uuid(current_user.user_id), # type: ignore
935935
source=source_str,
936936
sst_generated=False,
937937
schemas=list(allowed_schemas),
@@ -974,7 +974,7 @@ def validate_file_sftp(
974974
) -> Any:
975975
"""Validate a given file pulled from SFTP. The file_name should be url encoded."""
976976
file_name = decode_url_piece(file_name)
977-
if not current_user.is_datakinder:
977+
if not current_user.is_datakinder: # type: ignore
978978
raise HTTPException(
979979
status_code=status.HTTP_401_UNAUTHORIZED,
980980
detail="SFTP validation needs to be done by a datakinder.",
@@ -1022,9 +1022,12 @@ def get_upload_url(
10221022
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(ve))
10231023

10241024

1025+
## FE Inference Tables
1026+
1027+
10251028
# Get SHAP Values for Inference
10261029
@router.get("/{inst_id}/inference/top-features/{run_id}")
1027-
def get_top_features(
1030+
def get_inference_top_features(
10281031
inst_id: str,
10291032
run_id: str,
10301033
current_user: Annotated[BaseUser, Depends(get_current_active_user)],
@@ -1053,10 +1056,10 @@ def get_top_features(
10531056
try:
10541057
dbc = DatabricksControl()
10551058
rows = dbc.fetch_table_data(
1056-
catalog_name=env_vars["CATALOG_NAME"],
1059+
catalog_name=env_vars["CATALOG_NAME"], # type: ignore
10571060
inst_name=f"{query_result[0][0].name}",
1058-
table_name=f"sample_inference_{run_id}_features_with_most_impact",
1059-
warehouse_id=env_vars["SQL_WAREHOUSE_ID"],
1061+
table_name=f"inference_{run_id}_features_with_most_impact",
1062+
warehouse_id=env_vars["SQL_WAREHOUSE_ID"], # type: ignore
10601063
)
10611064

10621065
return rows
@@ -1067,7 +1070,7 @@ def get_top_features(
10671070

10681071
# Get Support Overview for Inference
10691072
@router.get("/{inst_id}/inference/support-overview/{run_id}")
1070-
def get_support_overview(
1073+
def get_inference_support_overview(
10711074
inst_id: str,
10721075
run_id: str,
10731076
current_user: Annotated[BaseUser, Depends(get_current_active_user)],
@@ -1096,10 +1099,10 @@ def get_support_overview(
10961099
try:
10971100
dbc = DatabricksControl()
10981101
rows = dbc.fetch_table_data(
1099-
catalog_name=env_vars["CATALOG_NAME"],
1102+
catalog_name=env_vars["CATALOG_NAME"], # type: ignore
11001103
inst_name=f"{query_result[0][0].name}",
1101-
table_name=f"sample_inference_{run_id}_support_overview",
1102-
warehouse_id=env_vars["SQL_WAREHOUSE_ID"],
1104+
table_name=f"inference_{run_id}_support_overview",
1105+
warehouse_id=env_vars["SQL_WAREHOUSE_ID"], # type: ignore
11031106
)
11041107

11051108
return rows
@@ -1108,8 +1111,8 @@ def get_support_overview(
11081111
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(ve))
11091112

11101113

1111-
@router.get("/{inst_id}/training/support-overview/{run_id}")
1112-
def get_training_support_overview(
1114+
@router.get("/{inst_id}/inference/feature_importance/{run_id}")
1115+
def get_inference_feature_importance(
11131116
inst_id: str,
11141117
run_id: str,
11151118
current_user: Annotated[BaseUser, Depends(get_current_active_user)],
@@ -1138,10 +1141,10 @@ def get_training_support_overview(
11381141
try:
11391142
dbc = DatabricksControl()
11401143
rows = dbc.fetch_table_data(
1141-
catalog_name=env_vars["CATALOG_NAME"],
1144+
catalog_name=env_vars["CATALOG_NAME"], # type: ignore
11421145
inst_name=f"{query_result[0][0].name}",
1143-
table_name=f"sample_training_{run_id}_support_overview",
1144-
warehouse_id=env_vars["SQL_WAREHOUSE_ID"],
1146+
table_name=f"inference_{run_id}_shap_feature_importance",
1147+
warehouse_id=env_vars["SQL_WAREHOUSE_ID"], # type: ignore
11451148
)
11461149

11471150
return rows
@@ -1150,8 +1153,11 @@ def get_training_support_overview(
11501153
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(ve))
11511154

11521155

1153-
@router.get("/{inst_id}/inference/feature_value/{run_id}")
1154-
def get_feature_value(
1156+
## FE Training Tables
1157+
1158+
1159+
@router.get("/{inst_id}/training/feature_importance/{run_id}")
1160+
def get_training_feature_importance(
11551161
inst_id: str,
11561162
run_id: str,
11571163
current_user: Annotated[BaseUser, Depends(get_current_active_user)],
@@ -1180,10 +1186,10 @@ def get_feature_value(
11801186
try:
11811187
dbc = DatabricksControl()
11821188
rows = dbc.fetch_table_data(
1183-
catalog_name=env_vars["CATALOG_NAME"],
1189+
catalog_name=env_vars["CATALOG_NAME"], # type: ignore
11841190
inst_name=f"{query_result[0][0].name}",
1185-
table_name=f"sample_inference_{run_id}_shap_feature_importance",
1186-
warehouse_id=env_vars["SQL_WAREHOUSE_ID"],
1191+
table_name=f"training_{run_id}_shap_feature_importance",
1192+
warehouse_id=env_vars["SQL_WAREHOUSE_ID"], # type: ignore
11871193
)
11881194

11891195
return rows
@@ -1193,7 +1199,7 @@ def get_feature_value(
11931199

11941200

11951201
@router.get("/{inst_id}/training/confusion_matrix/{run_id}")
1196-
def get_confusion_matrix(
1202+
def get_training_confusion_matrix(
11971203
inst_id: str,
11981204
run_id: str,
11991205
current_user: Annotated[BaseUser, Depends(get_current_active_user)],
@@ -1222,10 +1228,10 @@ def get_confusion_matrix(
12221228
try:
12231229
dbc = DatabricksControl()
12241230
rows = dbc.fetch_table_data(
1225-
catalog_name=env_vars["CATALOG_NAME"],
1231+
catalog_name=env_vars["CATALOG_NAME"], # type: ignore
12261232
inst_name=f"{query_result[0][0].name}",
1227-
table_name=f"sample_training_{run_id}_confusion_matrix",
1228-
warehouse_id=env_vars["SQL_WAREHOUSE_ID"],
1233+
table_name=f"training_{run_id}_confusion_matrix",
1234+
warehouse_id=env_vars["SQL_WAREHOUSE_ID"], # type: ignore
12291235
)
12301236

12311237
return rows
@@ -1235,7 +1241,49 @@ def get_confusion_matrix(
12351241

12361242

12371243
@router.get("/{inst_id}/training/roc_curve/{run_id}")
1238-
def get_roc_curve(
1244+
def get_training_roc_curve(
1245+
inst_id: str,
1246+
run_id: str,
1247+
current_user: Annotated[BaseUser, Depends(get_current_active_user)],
1248+
sql_session: Annotated[Session, Depends(get_session)],
1249+
) -> List[dict[str, Any]]:
1250+
"""Returns the rows of the training ROC curve table for the given institution and run."""
1251+
# raise error at this level instead bc otherwise it's getting wrapped as a 200
1252+
has_access_to_inst_or_err(inst_id, current_user)
1253+
local_session.set(sql_session)
1254+
query_result = (
1255+
local_session.get()
1256+
.execute(select(InstTable).where(InstTable.id == str_to_uuid(inst_id)))
1257+
.all()
1258+
)
1259+
if not query_result or len(query_result) == 0:
1260+
raise HTTPException(
1261+
status_code=status.HTTP_404_NOT_FOUND,
1262+
detail="Institution not found.",
1263+
)
1264+
if len(query_result) > 1:
1265+
raise HTTPException(
1266+
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
1267+
detail="Institution duplicates found.",
1268+
)
1269+
1270+
try:
1271+
dbc = DatabricksControl()
1272+
rows = dbc.fetch_table_data(
1273+
catalog_name=env_vars["CATALOG_NAME"], # type: ignore
1274+
inst_name=f"{query_result[0][0].name}",
1275+
table_name=f"training_{run_id}_roc_curve",
1276+
warehouse_id=env_vars["SQL_WAREHOUSE_ID"], # type: ignore
1277+
)
1278+
1279+
return rows
1280+
except ValueError as ve:
1281+
# Return a 400 error with the specific message from ValueError
1282+
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(ve))
1283+
1284+
1285+
@router.get("/{inst_id}/training/support-overview/{run_id}")
1286+
def get_training_support_overview(
12391287
inst_id: str,
12401288
run_id: str,
12411289
current_user: Annotated[BaseUser, Depends(get_current_active_user)],
@@ -1264,10 +1312,10 @@ def get_roc_curve(
12641312
try:
12651313
dbc = DatabricksControl()
12661314
rows = dbc.fetch_table_data(
1267-
catalog_name=env_vars["CATALOG_NAME"],
1315+
catalog_name=env_vars["CATALOG_NAME"], # type: ignore
12681316
inst_name=f"{query_result[0][0].name}",
1269-
table_name=f"sample_training_{run_id}_roc_curve",
1270-
warehouse_id=env_vars["SQL_WAREHOUSE_ID"],
1317+
table_name=f"training_{run_id}_support_overview",
1318+
warehouse_id=env_vars["SQL_WAREHOUSE_ID"], # type: ignore
12711319
)
12721320

12731321
return rows

0 commit comments

Comments
 (0)