|
5 | 5 | from databricks.sdk import WorkspaceClient |
6 | 6 | from typing import Annotated, Any, Dict, List, cast, IO, Optional |
7 | 7 | from pydantic import BaseModel, Field |
8 | | -from fastapi import APIRouter, Depends, HTTPException, status, Response |
| 8 | +from fastapi import APIRouter, Depends, HTTPException, status, Response, Query |
9 | 9 | from fastapi.responses import FileResponse |
10 | 10 | from sqlalchemy import and_, or_ |
11 | 11 | from sqlalchemy.orm import Session |
@@ -1283,7 +1283,7 @@ def get_inference_top_features( |
1283 | 1283 | current_user: Annotated[BaseUser, Depends(get_current_active_user)], |
1284 | 1284 | sql_session: Annotated[Session, Depends(get_session)], |
1285 | 1285 | ) -> List[dict[str, Any]]: |
1286 | | - """Returns a signed URL for uploading data to a specific institution.""" |
| 1286 | + """Returns data for a specific institution.""" |
1287 | 1287 | # raise error at this level instead bc otherwise it's getting wrapped as a 200 |
1288 | 1288 | has_access_to_inst_or_err(inst_id, current_user) |
1289 | 1289 | local_session.set(sql_session) |
@@ -1318,6 +1318,79 @@ def get_inference_top_features( |
1318 | 1318 | raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(ve)) |
1319 | 1319 |
|
1320 | 1320 |
|
| 1321 | +# Get Box plot values |
| 1322 | +@router.get("/{inst_id}/inference/features-boxplot-stat/{run_id}") |
| 1323 | +def get_inference_feature_boxstats( |
| 1324 | + inst_id: str, |
| 1325 | + run_id: str, |
| 1326 | + current_user: Annotated[BaseUser, Depends(get_current_active_user)], |
| 1327 | + sql_session: Annotated[Session, Depends(get_session)], |
| 1328 | + feature_name: Optional[str] = Query( |
| 1329 | + None, description="If provided, filter by this feature name" |
| 1330 | + ), |
| 1331 | +) -> List[dict[str, Any]]: |
| 1332 | + """Returns box-plot stats for an institution/run. If `feature_name` is supplied, |
| 1333 | + only rows for that feature are returned.""" |
| 1334 | + # raise error at this level instead bc otherwise it's getting wrapped as a 200 |
| 1335 | + has_access_to_inst_or_err(inst_id, current_user) |
| 1336 | + local_session.set(sql_session) |
| 1337 | + query_result = ( |
| 1338 | + local_session.get() |
| 1339 | + .execute(select(InstTable).where(InstTable.id == str_to_uuid(inst_id))) |
| 1340 | + .all() |
| 1341 | + ) |
| 1342 | + if not query_result or len(query_result) == 0: |
| 1343 | + raise HTTPException( |
| 1344 | + status_code=status.HTTP_404_NOT_FOUND, |
| 1345 | + detail="Institution not found.", |
| 1346 | + ) |
| 1347 | + if len(query_result) > 1: |
| 1348 | + raise HTTPException( |
| 1349 | + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, |
| 1350 | + detail="Institution duplicates found.", |
| 1351 | + ) |
| 1352 | + |
| 1353 | + try: |
| 1354 | + dbc = DatabricksControl() |
| 1355 | + rows = dbc.fetch_table_data( |
| 1356 | + catalog_name=env_vars["CATALOG_NAME"], # type: ignore |
| 1357 | + inst_name=f"{query_result[0][0].name}", |
| 1358 | + table_name=f"inference_{run_id}_box_plot_table", |
| 1359 | + warehouse_id=env_vars["SQL_WAREHOUSE_ID"], # type: ignore |
| 1360 | + ) |
| 1361 | + if not feature_name: |
| 1362 | + return rows |
| 1363 | + |
| 1364 | + # Helper: extract feature_name from various shapes (top-level or JSON column) |
| 1365 | + def row_feature_name(row: dict[str, Any]) -> Optional[str]: |
| 1366 | + # common case: it's a top-level column |
| 1367 | + if "feature_name" in row and row["feature_name"] is not None: |
| 1368 | + return str(row["feature_name"]) |
| 1369 | + # fallback: search any dict-valued column for a 'feature_name' key |
| 1370 | + for v in row.values(): |
| 1371 | + if ( |
| 1372 | + isinstance(v, dict) |
| 1373 | + and "feature_name" in v |
| 1374 | + and v["feature_name"] is not None |
| 1375 | + ): |
| 1376 | + return str(v["feature_name"]) |
| 1377 | + return None |
| 1378 | + |
| 1379 | + filtered = [r for r in rows if row_feature_name(r) == feature_name] |
| 1380 | + |
| 1381 | + if not filtered: |
| 1382 | + raise HTTPException( |
| 1383 | + status_code=status.HTTP_404_NOT_FOUND, |
| 1384 | + detail=f"Feature '{feature_name}' not found for run_id '{run_id}'.", |
| 1385 | + ) |
| 1386 | + |
| 1387 | + return filtered |
| 1388 | + |
| 1389 | + except ValueError as ve: |
| 1390 | + # Return a 400 error with the specific message from ValueError |
| 1391 | + raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(ve)) |
| 1392 | + |
| 1393 | + |
1321 | 1394 | # Get SHAP Values for Inference |
1322 | 1395 | @router.get("/{inst_id}/inference/support-overview/{run_id}") |
1323 | 1396 | def get_inference_support_overview( |
@@ -1576,8 +1649,8 @@ def get_training_support_overview( |
1576 | 1649 |
|
1577 | 1650 | @router.get("/{inst_id}/training/model-cards/{model_name}") |
1578 | 1651 | def get_model_cards( |
1579 | | - model_name: str, |
1580 | 1652 | inst_id: str, |
| 1653 | + model_name: str, |
1581 | 1654 | current_user: Annotated[BaseUser, Depends(get_current_active_user)], |
1582 | 1655 | sql_session: Annotated[Session, Depends(get_session)], |
1583 | 1656 | ) -> FileResponse: |
|
0 commit comments