Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
122 changes: 85 additions & 37 deletions src/webapp/routers/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,7 @@ def get_all_files(
inst_id: str,
sst_generated_value: bool | None,
sess: Session,
storage_control,
storage_control: Any,
) -> list[DataInfo]:
"""Retrieve all files."""
# Update from bucket
Expand Down Expand Up @@ -191,7 +191,7 @@ def get_all_files(
"uploaded_date": elem.created_at,
}
)
return result_files
return result_files # type: ignore


def get_all_batches(
Expand Down Expand Up @@ -225,19 +225,19 @@ def get_all_batches(
"updated_at": elem.updated_at,
}
)
return result_batches
return result_batches # type: ignore


def uuids_to_strs(files) -> set[str]:
def uuids_to_strs(files: Any) -> set[str]:
"""Convert a set of uuids to strings.
The input is of type sqlalchemy.orm.collections.InstrumentedSet.
"""
return [uuid_to_str(x.id) for x in files]
return [uuid_to_str(x.id) for x in files] # type: ignore


def strs_to_uuids(files) -> set[uuid.UUID]:
def strs_to_uuids(files: Any) -> set[uuid.UUID]:
"""Convert a set of strs to uuids."""
return [str_to_uuid(x) for x in files]
return [str_to_uuid(x) for x in files] # type: ignore


@router.get("/{inst_id}/input", response_model=DataOverview)
Expand Down Expand Up @@ -463,7 +463,7 @@ def create_batch(
batch = BatchTable(
name=req.name,
inst_id=str_to_uuid(inst_id),
created_by=str_to_uuid(current_user.user_id),
created_by=str_to_uuid(current_user.user_id), # type: ignore
)
f_names = [] if not req.file_names else req.file_names
f_ids = [] if not req.file_ids else strs_to_uuids(req.file_ids)
Expand Down Expand Up @@ -647,7 +647,7 @@ def update_batch(
existing_batch.name = update_data_req["name"]
if "completed" in update_data_req:
existing_batch.completed = update_data_req["completed"]
existing_batch.updated_by = str_to_uuid(current_user.user_id)
existing_batch.updated_by = str_to_uuid(current_user.user_id) # type: ignore
local_session.get().commit()
res = (
local_session.get()
Expand Down Expand Up @@ -931,7 +931,7 @@ def validation_helper(
new_file_record = FileTable(
name=file_name,
inst_id=str_to_uuid(inst_id),
uploader=str_to_uuid(current_user.user_id),
uploader=str_to_uuid(current_user.user_id), # type: ignore
source=source_str,
sst_generated=False,
schemas=list(allowed_schemas),
Expand Down Expand Up @@ -974,7 +974,7 @@ def validate_file_sftp(
) -> Any:
"""Validate a given file pulled from SFTP. The file_name should be url encoded."""
file_name = decode_url_piece(file_name)
if not current_user.is_datakinder:
if not current_user.is_datakinder: # type: ignore
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="SFTP validation needs to be done by a datakinder.",
Expand Down Expand Up @@ -1022,9 +1022,12 @@ def get_upload_url(
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(ve))


## FE Inference Tables


# Get SHAP Values for Inference
@router.get("/{inst_id}/inference/top-features/{run_id}")
def get_top_features(
def get_inference_top_features(
inst_id: str,
run_id: str,
current_user: Annotated[BaseUser, Depends(get_current_active_user)],
Expand Down Expand Up @@ -1053,10 +1056,10 @@ def get_top_features(
try:
dbc = DatabricksControl()
rows = dbc.fetch_table_data(
catalog_name=env_vars["CATALOG_NAME"],
catalog_name=env_vars["CATALOG_NAME"], # type: ignore
inst_name=f"{query_result[0][0].name}",
table_name=f"sample_inference_{run_id}_features_with_most_impact",
warehouse_id=env_vars["SQL_WAREHOUSE_ID"],
table_name=f"inference_{run_id}_features_with_most_impact",
warehouse_id=env_vars["SQL_WAREHOUSE_ID"], # type: ignore
)

return rows
Expand All @@ -1067,7 +1070,7 @@ def get_top_features(

# Get SHAP Values for Inference
@router.get("/{inst_id}/inference/support-overview/{run_id}")
def get_support_overview(
def get_inference_support_overview(
inst_id: str,
run_id: str,
current_user: Annotated[BaseUser, Depends(get_current_active_user)],
Expand Down Expand Up @@ -1096,10 +1099,10 @@ def get_support_overview(
try:
dbc = DatabricksControl()
rows = dbc.fetch_table_data(
catalog_name=env_vars["CATALOG_NAME"],
catalog_name=env_vars["CATALOG_NAME"], # type: ignore
inst_name=f"{query_result[0][0].name}",
table_name=f"sample_inference_{run_id}_support_overview",
warehouse_id=env_vars["SQL_WAREHOUSE_ID"],
table_name=f"inference_{run_id}_support_overview",
warehouse_id=env_vars["SQL_WAREHOUSE_ID"], # type: ignore
)

return rows
Expand All @@ -1108,8 +1111,8 @@ def get_support_overview(
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(ve))


@router.get("/{inst_id}/training/support-overview/{run_id}")
def get_training_support_overview(
@router.get("/{inst_id}/inference/feature_importance/{run_id}")
def get_inference_feature_importance(
inst_id: str,
run_id: str,
current_user: Annotated[BaseUser, Depends(get_current_active_user)],
Expand Down Expand Up @@ -1138,10 +1141,10 @@ def get_training_support_overview(
try:
dbc = DatabricksControl()
rows = dbc.fetch_table_data(
catalog_name=env_vars["CATALOG_NAME"],
catalog_name=env_vars["CATALOG_NAME"], # type: ignore
inst_name=f"{query_result[0][0].name}",
table_name=f"sample_training_{run_id}_support_overview",
warehouse_id=env_vars["SQL_WAREHOUSE_ID"],
table_name=f"inference_{run_id}_shap_feature_importance",
warehouse_id=env_vars["SQL_WAREHOUSE_ID"], # type: ignore
)

return rows
Expand All @@ -1150,8 +1153,11 @@ def get_training_support_overview(
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(ve))


@router.get("/{inst_id}/inference/feature_value/{run_id}")
def get_feature_value(
## FE Training Tables


@router.get("/{inst_id}/training/feature_importance/{run_id}")
def get_training_feature_importance(
inst_id: str,
run_id: str,
current_user: Annotated[BaseUser, Depends(get_current_active_user)],
Expand Down Expand Up @@ -1180,10 +1186,10 @@ def get_feature_value(
try:
dbc = DatabricksControl()
rows = dbc.fetch_table_data(
catalog_name=env_vars["CATALOG_NAME"],
catalog_name=env_vars["CATALOG_NAME"], # type: ignore
inst_name=f"{query_result[0][0].name}",
table_name=f"sample_inference_{run_id}_shap_feature_importance",
warehouse_id=env_vars["SQL_WAREHOUSE_ID"],
table_name=f"training_{run_id}_shap_feature_importance",
warehouse_id=env_vars["SQL_WAREHOUSE_ID"], # type: ignore
)

return rows
Expand All @@ -1193,7 +1199,7 @@ def get_feature_value(


@router.get("/{inst_id}/training/confusion_matrix/{run_id}")
def get_confusion_matrix(
def get_training_confusion_matrix(
inst_id: str,
run_id: str,
current_user: Annotated[BaseUser, Depends(get_current_active_user)],
Expand Down Expand Up @@ -1222,10 +1228,10 @@ def get_confusion_matrix(
try:
dbc = DatabricksControl()
rows = dbc.fetch_table_data(
catalog_name=env_vars["CATALOG_NAME"],
catalog_name=env_vars["CATALOG_NAME"], # type: ignore
inst_name=f"{query_result[0][0].name}",
table_name=f"sample_training_{run_id}_confusion_matrix",
warehouse_id=env_vars["SQL_WAREHOUSE_ID"],
table_name=f"training_{run_id}_confusion_matrix",
warehouse_id=env_vars["SQL_WAREHOUSE_ID"], # type: ignore
)

return rows
Expand All @@ -1235,7 +1241,49 @@ def get_confusion_matrix(


@router.get("/{inst_id}/training/roc_curve/{run_id}")
def get_roc_curve(
def get_training_roc_curve(
inst_id: str,
run_id: str,
current_user: Annotated[BaseUser, Depends(get_current_active_user)],
sql_session: Annotated[Session, Depends(get_session)],
) -> List[dict[str, Any]]:
"""Returns a signed URL for uploading data to a specific institution."""
# raise error at this level instead bc otherwise it's getting wrapped as a 200
has_access_to_inst_or_err(inst_id, current_user)
local_session.set(sql_session)
query_result = (
local_session.get()
.execute(select(InstTable).where(InstTable.id == str_to_uuid(inst_id)))
.all()
)
if not query_result or len(query_result) == 0:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="Institution not found.",
)
if len(query_result) > 1:
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Institution duplicates found.",
)

try:
dbc = DatabricksControl()
rows = dbc.fetch_table_data(
catalog_name=env_vars["CATALOG_NAME"], # type: ignore
inst_name=f"{query_result[0][0].name}",
table_name=f"training_{run_id}_roc_curve",
warehouse_id=env_vars["SQL_WAREHOUSE_ID"], # type: ignore
)

return rows
except ValueError as ve:
# Return a 400 error with the specific message from ValueError
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(ve))


@router.get("/{inst_id}/training/support-overview/{run_id}")
def get_training_support_overview(
inst_id: str,
run_id: str,
current_user: Annotated[BaseUser, Depends(get_current_active_user)],
Expand Down Expand Up @@ -1264,10 +1312,10 @@ def get_roc_curve(
try:
dbc = DatabricksControl()
rows = dbc.fetch_table_data(
catalog_name=env_vars["CATALOG_NAME"],
catalog_name=env_vars["CATALOG_NAME"], # type: ignore
inst_name=f"{query_result[0][0].name}",
table_name=f"sample_training_{run_id}_roc_curve",
warehouse_id=env_vars["SQL_WAREHOUSE_ID"],
table_name=f"training_{run_id}_support_overview",
warehouse_id=env_vars["SQL_WAREHOUSE_ID"], # type: ignore
)

return rows
Expand Down
Loading