|
12 | 12 | import os |
13 | 13 | import logging |
14 | 14 | from sqlalchemy.exc import IntegrityError |
| 15 | +from ..config import env_vars |
15 | 16 |
|
16 | 17 | from ..utilities import ( |
17 | 18 | has_access_to_inst_or_err, |
|
31 | 32 | local_session, |
32 | 33 | BatchTable, |
33 | 34 | FileTable, |
| 35 | + InstTable, |
34 | 36 | ) |
35 | 37 |
|
| 38 | +from ..databricks import DatabricksControl |
36 | 39 | from ..gcsdbutils import update_db_from_bucket |
37 | 40 |
|
38 | 41 | from ..gcsutil import StorageControl |
@@ -1018,3 +1021,220 @@ def get_upload_url( |
1018 | 1021 | except ValueError as ve: |
1019 | 1022 | # Return a 400 error with the specific message from ValueError |
1020 | 1023 | raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(ve)) |
| 1024 | + |
| 1025 | + |
| 1026 | +# Get SHAP Values for Inference |
| 1027 | +@router.get("/{inst_id}/inference/top-features/{run_id}", response_model=str) |
| 1028 | +def get_top_features( |
| 1029 | + inst_id: str, |
| 1030 | + run_id: str, |
| 1031 | + # current_user: Annotated[BaseUser, Depends(get_current_active_user)], |
| 1032 | + sql_session: Annotated[Session, Depends(get_session)], |
| 1033 | +) -> List[dict[str, Any]]: |
| 1034 | + """Returns a signed URL for uploading data to a specific institution.""" |
| 1035 | + # raise error at this level instead bc otherwise it's getting wrapped as a 200 |
| 1036 | + # has_access_to_inst_or_err(inst_id, current_user) |
| 1037 | + local_session.set(sql_session) |
| 1038 | + query_result = ( |
| 1039 | + local_session.get() |
| 1040 | + .execute(select(InstTable).where(InstTable.id == str_to_uuid(inst_id))) |
| 1041 | + .all() |
| 1042 | + ) |
| 1043 | + if not query_result or len(query_result) == 0: |
| 1044 | + raise HTTPException( |
| 1045 | + status_code=status.HTTP_404_NOT_FOUND, |
| 1046 | + detail="Institution not found.", |
| 1047 | + ) |
| 1048 | + if len(query_result) > 1: |
| 1049 | + raise HTTPException( |
| 1050 | + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, |
| 1051 | + detail="Institution duplicates found.", |
| 1052 | + ) |
| 1053 | + |
| 1054 | + try: |
| 1055 | + dbc = DatabricksControl() |
| 1056 | + rows = dbc.fetch_table_data( |
| 1057 | + catalog_name=env_vars["CATALOG_NAME"], |
| 1058 | + schema_name=f"{query_result[0][0].name}_silver", |
| 1059 | + table_name=f"sample_inference_{run_id}_features_with_most_impact", |
| 1060 | + warehouse_id=env_vars["SQL_WAREHOUSE_ID"], |
| 1061 | + limit=500, |
| 1062 | + ) |
| 1063 | + |
| 1064 | + return rows |
| 1065 | + except ValueError as ve: |
| 1066 | + # Return a 400 error with the specific message from ValueError |
| 1067 | + raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(ve)) |
| 1068 | + |
| 1069 | + |
| 1070 | +# Get SHAP Values for Inference |
| 1071 | +@router.get("/{inst_id}/inference/support-overview/{run_id}", response_model=str) |
| 1072 | +def get_support_overview( |
| 1073 | + inst_id: str, |
| 1074 | + run_id: str, |
| 1075 | + # current_user: Annotated[BaseUser, Depends(get_current_active_user)], |
| 1076 | + sql_session: Annotated[Session, Depends(get_session)], |
| 1077 | +) -> List[dict[str, Any]]: |
| 1078 | + """Returns a signed URL for uploading data to a specific institution.""" |
| 1079 | + # raise error at this level instead bc otherwise it's getting wrapped as a 200 |
| 1080 | + # has_access_to_inst_or_err(inst_id, current_user) |
| 1081 | + local_session.set(sql_session) |
| 1082 | + query_result = ( |
| 1083 | + local_session.get() |
| 1084 | + .execute(select(InstTable).where(InstTable.id == str_to_uuid(inst_id))) |
| 1085 | + .all() |
| 1086 | + ) |
| 1087 | + if not query_result or len(query_result) == 0: |
| 1088 | + raise HTTPException( |
| 1089 | + status_code=status.HTTP_404_NOT_FOUND, |
| 1090 | + detail="Institution not found.", |
| 1091 | + ) |
| 1092 | + if len(query_result) > 1: |
| 1093 | + raise HTTPException( |
| 1094 | + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, |
| 1095 | + detail="Institution duplicates found.", |
| 1096 | + ) |
| 1097 | + |
| 1098 | + try: |
| 1099 | + dbc = DatabricksControl() |
| 1100 | + rows = dbc.fetch_table_data( |
| 1101 | + catalog_name=env_vars["CATALOG_NAME"], |
| 1102 | + schema_name=f"{query_result[0][0].name}_silver", |
| 1103 | + table_name=f"sample_inference_{run_id}_support_overview", |
| 1104 | + warehouse_id=env_vars["SQL_WAREHOUSE_ID"], |
| 1105 | + limit=500, |
| 1106 | + ) |
| 1107 | + |
| 1108 | + return rows |
| 1109 | + except ValueError as ve: |
| 1110 | + # Return a 400 error with the specific message from ValueError |
| 1111 | + raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(ve)) |
| 1112 | + |
| 1113 | + |
| 1114 | +@router.get("/{inst_id}/inference/feature_value/{run_id}", response_model=str) |
| 1115 | +def get_feature_value( |
| 1116 | + inst_id: str, |
| 1117 | + run_id: str, |
| 1118 | + # current_user: Annotated[BaseUser, Depends(get_current_active_user)], |
| 1119 | + sql_session: Annotated[Session, Depends(get_session)], |
| 1120 | +) -> List[dict[str, Any]]: |
| 1121 | + """Returns a signed URL for uploading data to a specific institution.""" |
| 1122 | + # raise error at this level instead bc otherwise it's getting wrapped as a 200 |
| 1123 | + # has_access_to_inst_or_err(inst_id, current_user) |
| 1124 | + local_session.set(sql_session) |
| 1125 | + query_result = ( |
| 1126 | + local_session.get() |
| 1127 | + .execute(select(InstTable).where(InstTable.id == str_to_uuid(inst_id))) |
| 1128 | + .all() |
| 1129 | + ) |
| 1130 | + if not query_result or len(query_result) == 0: |
| 1131 | + raise HTTPException( |
| 1132 | + status_code=status.HTTP_404_NOT_FOUND, |
| 1133 | + detail="Institution not found.", |
| 1134 | + ) |
| 1135 | + if len(query_result) > 1: |
| 1136 | + raise HTTPException( |
| 1137 | + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, |
| 1138 | + detail="Institution duplicates found.", |
| 1139 | + ) |
| 1140 | + |
| 1141 | + try: |
| 1142 | + dbc = DatabricksControl() |
| 1143 | + rows = dbc.fetch_table_data( |
| 1144 | + catalog_name=env_vars["CATALOG_NAME"], |
| 1145 | + schema_name=f"{query_result[0][0].name}_silver", |
| 1146 | + table_name=f"sample_inference_{run_id}_shap_feature_importance", |
| 1147 | + warehouse_id=env_vars["SQL_WAREHOUSE_ID"], |
| 1148 | + limit=500, |
| 1149 | + ) |
| 1150 | + |
| 1151 | + return rows |
| 1152 | + except ValueError as ve: |
| 1153 | + # Return a 400 error with the specific message from ValueError |
| 1154 | + raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(ve)) |
| 1155 | + |
| 1156 | + |
| 1157 | +@router.get("/{inst_id}/training/confusion_matrix/{run_id}", response_model=str) |
| 1158 | +def get_confusion_matrix( |
| 1159 | + inst_id: str, |
| 1160 | + run_id: str, |
| 1161 | + ##current_user: Annotated[BaseUser, Depends(get_current_active_user)], |
| 1162 | + sql_session: Annotated[Session, Depends(get_session)], |
| 1163 | +) -> List[dict[str, Any]]: |
| 1164 | + """Returns a signed URL for uploading data to a specific institution.""" |
| 1165 | + # raise error at this level instead bc otherwise it's getting wrapped as a 200 |
| 1166 | + # has_access_to_inst_or_err(inst_id, current_user) |
| 1167 | + local_session.set(sql_session) |
| 1168 | + query_result = ( |
| 1169 | + local_session.get() |
| 1170 | + .execute(select(InstTable).where(InstTable.id == str_to_uuid(inst_id))) |
| 1171 | + .all() |
| 1172 | + ) |
| 1173 | + if not query_result or len(query_result) == 0: |
| 1174 | + raise HTTPException( |
| 1175 | + status_code=status.HTTP_404_NOT_FOUND, |
| 1176 | + detail="Institution not found.", |
| 1177 | + ) |
| 1178 | + if len(query_result) > 1: |
| 1179 | + raise HTTPException( |
| 1180 | + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, |
| 1181 | + detail="Institution duplicates found.", |
| 1182 | + ) |
| 1183 | + |
| 1184 | + try: |
| 1185 | + dbc = DatabricksControl() |
| 1186 | + rows = dbc.fetch_table_data( |
| 1187 | + catalog_name=env_vars["CATALOG_NAME"], |
| 1188 | + schema_name=f"{query_result[0][0].name}_silver", |
| 1189 | + table_name=f"sample_training_{run_id}_confusion_matrix", |
| 1190 | + warehouse_id=env_vars["SQL_WAREHOUSE_ID"], |
| 1191 | + limit=500, |
| 1192 | + ) |
| 1193 | + |
| 1194 | + return rows |
| 1195 | + except ValueError as ve: |
| 1196 | + # Return a 400 error with the specific message from ValueError |
| 1197 | + raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(ve)) |
| 1198 | + |
| 1199 | + |
| 1200 | +@router.get("/{inst_id}/training/roc_curve/{run_id}", response_model=str) |
| 1201 | +def get_roc_curve( |
| 1202 | + inst_id: str, |
| 1203 | + run_id: str, |
| 1204 | + # current_user: Annotated[BaseUser, Depends(get_current_active_user)], |
| 1205 | + sql_session: Annotated[Session, Depends(get_session)], |
| 1206 | +) -> List[dict[str, Any]]: |
| 1207 | + """Returns a signed URL for uploading data to a specific institution.""" |
| 1208 | + # raise error at this level instead bc otherwise it's getting wrapped as a 200 |
| 1209 | + # has_access_to_inst_or_err(inst_id, current_user) |
| 1210 | + local_session.set(sql_session) |
| 1211 | + query_result = ( |
| 1212 | + local_session.get() |
| 1213 | + .execute(select(InstTable).where(InstTable.id == str_to_uuid(inst_id))) |
| 1214 | + .all() |
| 1215 | + ) |
| 1216 | + if not query_result or len(query_result) == 0: |
| 1217 | + raise HTTPException( |
| 1218 | + status_code=status.HTTP_404_NOT_FOUND, |
| 1219 | + detail="Institution not found.", |
| 1220 | + ) |
| 1221 | + if len(query_result) > 1: |
| 1222 | + raise HTTPException( |
| 1223 | + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, |
| 1224 | + detail="Institution duplicates found.", |
| 1225 | + ) |
| 1226 | + |
| 1227 | + try: |
| 1228 | + dbc = DatabricksControl() |
| 1229 | + rows = dbc.fetch_table_data( |
| 1230 | + catalog_name=env_vars["CATALOG_NAME"], |
| 1231 | + schema_name=f"{query_result[0][0].name}_silver", |
| 1232 | + table_name=f"sample_training_{run_id}_roc_curve", |
| 1233 | + warehouse_id=env_vars["SQL_WAREHOUSE_ID"], |
| 1234 | + limit=500, |
| 1235 | + ) |
| 1236 | + |
| 1237 | + return rows |
| 1238 | + except ValueError as ve: |
| 1239 | + # Return a 400 error with the specific message from ValueError |
| 1240 | + raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(ve)) |
0 commit comments