Skip to content

Commit 45900ce

Browse files
authored
Merge pull request #149 from datakind/develop
Features Box Plot Stat
2 parents 8156997 + d6329ab commit 45900ce

1 file changed

Lines changed: 76 additions & 3 deletions

File tree

src/webapp/routers/data.py

Lines changed: 76 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from databricks.sdk import WorkspaceClient
66
from typing import Annotated, Any, Dict, List, cast, IO, Optional
77
from pydantic import BaseModel, Field
8-
from fastapi import APIRouter, Depends, HTTPException, status, Response
8+
from fastapi import APIRouter, Depends, HTTPException, status, Response, Query
99
from fastapi.responses import FileResponse
1010
from sqlalchemy import and_, or_
1111
from sqlalchemy.orm import Session
@@ -1283,7 +1283,7 @@ def get_inference_top_features(
12831283
current_user: Annotated[BaseUser, Depends(get_current_active_user)],
12841284
sql_session: Annotated[Session, Depends(get_session)],
12851285
) -> List[dict[str, Any]]:
1286-
"""Returns a signed URL for uploading data to a specific institution."""
1286+
"""Returns data for a specific institution."""
12871287
# raise error at this level instead bc otherwise it's getting wrapped as a 200
12881288
has_access_to_inst_or_err(inst_id, current_user)
12891289
local_session.set(sql_session)
@@ -1318,6 +1318,79 @@ def get_inference_top_features(
13181318
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(ve))
13191319

13201320

1321+
def _row_feature_name(row: dict[str, Any]) -> Optional[str]:
    """Return the feature name carried by a box-plot row, or None if absent.

    A non-None top-level ``feature_name`` value takes precedence. Otherwise
    the first dict-valued column that holds a non-None ``feature_name`` key
    is used. The value is stringified so it can be compared directly with
    the ``feature_name`` query parameter.
    """
    top_level = row.get("feature_name")
    if top_level is not None:
        return str(top_level)
    # Fallback: search any dict-valued column for a 'feature_name' key.
    for value in row.values():
        if isinstance(value, dict) and value.get("feature_name") is not None:
            return str(value["feature_name"])
    return None


# Get Box plot values
@router.get("/{inst_id}/inference/features-boxplot-stat/{run_id}")
def get_inference_feature_boxstats(
    inst_id: str,
    run_id: str,
    current_user: Annotated[BaseUser, Depends(get_current_active_user)],
    sql_session: Annotated[Session, Depends(get_session)],
    feature_name: Optional[str] = Query(
        None, description="If provided, filter by this feature name"
    ),
) -> List[dict[str, Any]]:
    """Return box-plot stats for an institution/run.

    Rows are read from the ``inference_{run_id}_box_plot_table`` table of the
    institution identified by ``inst_id``. If ``feature_name`` is supplied
    (and non-empty), only rows for that feature are returned.

    Args:
        inst_id: Institution id; converted with ``str_to_uuid`` for the lookup.
        run_id: Inference run id, used to build the table name.
        current_user: The authenticated user (injected dependency).
        sql_session: The SQLAlchemy session (injected dependency).
        feature_name: Optional feature name to filter the rows by.

    Returns:
        All rows of the box-plot table, or only those matching
        ``feature_name`` when it is given.

    Raises:
        HTTPException: 404 if the institution does not exist, or if
            ``feature_name`` is given and no row matches it; 500 if more than
            one institution matches ``inst_id``; 400 if the Databricks fetch
            raises ``ValueError``.
    """
    # raise error at this level instead bc otherwise it's getting wrapped as a 200
    has_access_to_inst_or_err(inst_id, current_user)
    local_session.set(sql_session)
    query_result = (
        local_session.get()
        .execute(select(InstTable).where(InstTable.id == str_to_uuid(inst_id)))
        .all()
    )
    if not query_result:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="Institution not found.",
        )
    if len(query_result) > 1:
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Institution duplicates found.",
        )

    # Only the Databricks calls are guarded: the filtering below is pure
    # in-memory work and should not have its errors translated into a 400.
    try:
        dbc = DatabricksControl()
        # NOTE(review): run_id comes straight from the URL path and is
        # interpolated into the table name -- confirm fetch_table_data
        # validates or quotes identifiers.
        rows = dbc.fetch_table_data(
            catalog_name=env_vars["CATALOG_NAME"],  # type: ignore
            inst_name=f"{query_result[0][0].name}",
            table_name=f"inference_{run_id}_box_plot_table",
            warehouse_id=env_vars["SQL_WAREHOUSE_ID"],  # type: ignore
        )
    except ValueError as ve:
        # Return a 400 error with the specific message from ValueError,
        # keeping the original exception chained as the cause.
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST, detail=str(ve)
        ) from ve

    if not feature_name:
        return rows

    filtered = [row for row in rows if _row_feature_name(row) == feature_name]
    if not filtered:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail=f"Feature '{feature_name}' not found for run_id '{run_id}'.",
        )
    return filtered
1392+
1393+
13211394
# Get SHAP Values for Inference
13221395
@router.get("/{inst_id}/inference/support-overview/{run_id}")
13231396
def get_inference_support_overview(
@@ -1576,8 +1649,8 @@ def get_training_support_overview(
15761649

15771650
@router.get("/{inst_id}/training/model-cards/{model_name}")
15781651
def get_model_cards(
1579-
model_name: str,
15801652
inst_id: str,
1653+
model_name: str,
15811654
current_user: Annotated[BaseUser, Depends(get_current_active_user)],
15821655
sql_session: Annotated[Session, Depends(get_session)],
15831656
) -> FileResponse:

0 commit comments

Comments
 (0)