Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
79 changes: 76 additions & 3 deletions src/webapp/routers/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from databricks.sdk import WorkspaceClient
from typing import Annotated, Any, Dict, List, cast, IO, Optional
from pydantic import BaseModel, Field
from fastapi import APIRouter, Depends, HTTPException, status, Response
from fastapi import APIRouter, Depends, HTTPException, status, Response, Query
from fastapi.responses import FileResponse
from sqlalchemy import and_, or_
from sqlalchemy.orm import Session
Expand Down Expand Up @@ -1283,7 +1283,7 @@ def get_inference_top_features(
current_user: Annotated[BaseUser, Depends(get_current_active_user)],
sql_session: Annotated[Session, Depends(get_session)],
) -> List[dict[str, Any]]:
"""Returns a signed URL for uploading data to a specific institution."""
"""Returns data for a specific institution."""
# raise error at this level instead bc otherwise it's getting wrapped as a 200
has_access_to_inst_or_err(inst_id, current_user)
local_session.set(sql_session)
Expand Down Expand Up @@ -1318,6 +1318,79 @@ def get_inference_top_features(
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(ve))


def _boxplot_row_feature_name(row: dict[str, Any]) -> Optional[str]:
    """Return the feature name carried by a box-plot row, or None if absent.

    Looks for a non-None top-level ``feature_name`` key first; otherwise
    falls back to the first dict-valued column that holds a non-None
    ``feature_name`` key. The value is always returned as ``str``.
    """
    top_level = row.get("feature_name")
    if top_level is not None:
        return str(top_level)
    # Fallback: the name may be nested inside a dict-valued (JSON-like) column.
    for value in row.values():
        if isinstance(value, dict) and value.get("feature_name") is not None:
            return str(value["feature_name"])
    return None


# Get Box plot values
@router.get("/{inst_id}/inference/features-boxplot-stat/{run_id}")
def get_inference_feature_boxstats(
    inst_id: str,
    run_id: str,
    current_user: Annotated[BaseUser, Depends(get_current_active_user)],
    sql_session: Annotated[Session, Depends(get_session)],
    feature_name: Annotated[
        Optional[str],
        Query(description="If provided, filter by this feature name"),
    ] = None,
) -> List[dict[str, Any]]:
    """Returns box-plot stats for an institution/run.

    Rows are read from the ``inference_{run_id}_box_plot_table`` table of the
    institution identified by ``inst_id``. If `feature_name` is supplied (and
    non-empty), only rows for that feature are returned; otherwise every row
    is returned.

    Raises:
        HTTPException: 404 if the institution does not exist, or if
            `feature_name` was supplied and no row matches it; 500 if more
            than one institution matches ``inst_id``; 400 if fetching the
            table raises ``ValueError``.
    """
    # raise error at this level instead bc otherwise it's getting wrapped as a 200
    has_access_to_inst_or_err(inst_id, current_user)
    local_session.set(sql_session)
    query_result = (
        local_session.get()
        .execute(select(InstTable).where(InstTable.id == str_to_uuid(inst_id)))
        .all()
    )
    if not query_result:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="Institution not found.",
        )
    if len(query_result) > 1:
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Institution duplicates found.",
        )

    # NOTE(review): run_id comes straight from the URL path and is interpolated
    # into the table name -- confirm fetch_table_data validates/quotes it.
    try:
        dbc = DatabricksControl()
        rows = dbc.fetch_table_data(
            catalog_name=env_vars["CATALOG_NAME"],  # type: ignore
            inst_name=f"{query_result[0][0].name}",
            table_name=f"inference_{run_id}_box_plot_table",
            warehouse_id=env_vars["SQL_WAREHOUSE_ID"],  # type: ignore
        )
    except ValueError as ve:
        # Return a 400 error with the specific message from ValueError
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST, detail=str(ve)
        ) from ve

    # No (or empty) filter: hand back the whole table.
    if not feature_name:
        return rows

    filtered = [r for r in rows if _boxplot_row_feature_name(r) == feature_name]
    if not filtered:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail=f"Feature '{feature_name}' not found for run_id '{run_id}'.",
        )
    return filtered


# Get SHAP Values for Inference
@router.get("/{inst_id}/inference/support-overview/{run_id}")
def get_inference_support_overview(
Expand Down Expand Up @@ -1576,8 +1649,8 @@ def get_training_support_overview(

@router.get("/{inst_id}/training/model-cards/{model_name}")
def get_model_cards(
model_name: str,
inst_id: str,
model_name: str,
current_user: Annotated[BaseUser, Depends(get_current_active_user)],
sql_session: Annotated[Session, Depends(get_session)],
) -> FileResponse:
Expand Down
Loading