Skip to content

Commit 50b2ac7

Browse files
authored
Merge pull request #126 from datakind/develop
Model Cards Endpoint
2 parents 0410e60 + 355bce2 commit 50b2ac7

5 files changed

Lines changed: 2126 additions & 1368 deletions

File tree

pyproject.toml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,8 @@ dependencies = [
2929
"types-six",
3030
"fuzzywuzzy",
3131
"databricks-sql-connector",
32-
"pandera~=0.13"
32+
"pandera~=0.13",
33+
"mlflow"
3334
]
3435

3536
[project.urls]

src/webapp/routers/data.py

Lines changed: 75 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,17 +2,21 @@
22

33
import uuid
44
from datetime import datetime, date
5-
5+
from databricks.sdk import WorkspaceClient
66
from typing import Annotated, Any, Dict, List
77
from pydantic import BaseModel
88
from fastapi import APIRouter, Depends, HTTPException, status, Response
9+
from fastapi.responses import FileResponse
910
from sqlalchemy import and_, or_
1011
from sqlalchemy.orm import Session
1112
from sqlalchemy.future import select
1213
import os
1314
import logging
1415
from sqlalchemy.exc import IntegrityError
15-
from ..config import env_vars
16+
from ..config import databricks_vars, env_vars, gcs_vars
17+
import mlflow
18+
from mlflow.exceptions import MlflowException
19+
import tempfile
1620

1721
from ..utilities import (
1822
has_access_to_inst_or_err,
@@ -50,6 +54,8 @@
5054
tags=["data"],
5155
)
5256

57+
LOGGER = logging.getLogger(__name__)
58+
5359

5460
class BatchCreationRequest(BaseModel):
5561
"""The Batch creation request."""
@@ -1322,3 +1328,70 @@ def get_training_support_overview(
13221328
except ValueError as ve:
13231329
# Return a 400 error with the specific message from ValueError
13241330
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(ve))
1331+
1332+
1333+
@router.get("/{inst_id}/training/model-cards/{run_id}/{model_name}")
def get_model_cards(
    run_id: str,
    model_name: str,
    inst_id: str,
    current_user: Annotated[BaseUser, Depends(get_current_active_user)],
    sql_session: Annotated[Session, Depends(get_session)],
    artifact_path: str = "",
) -> FileResponse:
    """Download a model card PDF from an MLflow run and return it as a file.

    The artifact fetched is always
    ``model_card/model-card-{model_name}.pdf`` under ``runs:/{run_id}/``.

    Args:
        run_id: MLflow run id that owns the model card artifact.
        model_name: Model name used to build the artifact file name.
        inst_id: Institution id; must match exactly one row in ``InstTable``.
        current_user: Authenticated user, checked against ``inst_id``.
        sql_session: Database session used for the institution lookup.
        artifact_path: Ignored. The artifact location is always derived from
            ``model_name``; the parameter is kept (now optional) only so that
            existing clients which still send it keep working.

    Returns:
        A ``FileResponse`` streaming the PDF with media type
        ``application/pdf``.

    Raises:
        HTTPException: 404 if the institution is not found, 500 if duplicate
            institutions exist or if the MLflow download fails.
        ValueError: If the Databricks workspace client cannot be initialised.
    """
    # Function-scope imports so this block is self-contained.
    import shutil

    from starlette.background import BackgroundTask

    # NOTE(review): presumably raises if the user may not access inst_id.
    has_access_to_inst_or_err(inst_id, current_user)
    local_session.set(sql_session)
    query_result = (
        local_session.get()
        .execute(select(InstTable).where(InstTable.id == str_to_uuid(inst_id)))
        .all()
    )
    if not query_result:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="Institution not found.",
        )
    if len(query_result) > 1:
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Institution duplicates found.",
        )

    try:
        w = WorkspaceClient(
            host=databricks_vars["DATABRICKS_HOST_URL"],
            google_service_account=gcs_vars["GCP_SERVICE_ACCOUNT_EMAIL"],
        )
        # MLflow's "databricks" tracking URI reads credentials from these
        # process-wide environment variables.
        # NOTE(review): this relies on the client's private ``_config`` and
        # mutates global state on every request -- confirm this is intended.
        os.environ["DATABRICKS_HOST"] = w._config.host
        os.environ["DATABRICKS_TOKEN"] = w._config.token
        LOGGER.info("Successfully created Databricks WorkspaceClient.")
    except Exception as e:
        LOGGER.exception(
            "Failed to create Databricks WorkspaceClient with host: %s and service account: %s",
            databricks_vars["DATABRICKS_HOST_URL"],
            gcs_vars["GCP_SERVICE_ACCOUNT_EMAIL"],
        )
        raise ValueError(
            f"get_model_cards(): Workspace client initialization failed: {e}"
        ) from e

    # The response body is read from disk *after* this function returns, so
    # the download directory must outlive the function. A ``with
    # TemporaryDirectory()`` block would delete the file before it is sent.
    tmpdir = tempfile.mkdtemp()
    try:
        mlflow.set_tracking_uri("databricks")
        card_artifact_path = f"model_card/model-card-{model_name}.pdf"
        artifact_uri = f"runs:/{run_id}/{card_artifact_path}"
        local_path = mlflow.artifacts.download_artifacts(
            artifact_uri=artifact_uri, dst_path=tmpdir
        )
    except MlflowException as e:
        shutil.rmtree(tmpdir, ignore_errors=True)
        # Log at error level with traceback; debug would hide the failure.
        LOGGER.exception("Artifact download failed: %s", e)
        raise HTTPException(
            status_code=500, detail=f"Artifact download failed: {e}"
        ) from e
    except Exception:
        # Do not leak the temporary directory on unexpected failures.
        shutil.rmtree(tmpdir, ignore_errors=True)
        raise

    LOGGER.debug("Artifact provisioned successfully")
    return FileResponse(
        path=local_path,
        filename=os.path.basename(local_path),
        media_type="application/pdf",
        # Remove the temporary directory once the file has been sent.
        background=BackgroundTask(shutil.rmtree, tmpdir, ignore_errors=True),
    )

src/webapp/validation_schemas/base_schema.json

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@
3232
"checks": []
3333
},
3434
"first_generation_student": {
35-
"dtype": "string",
35+
"dtype": "category",
3636
"coerce": true,
3737
"nullable": true,
3838
"required": false,

src/webapp/validation_schemas/pdp_schema_extension.json

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -515,23 +515,23 @@
515515
"checks": []
516516
},
517517
"years_to_associates_or_certificate_at_cohort_inst": {
518-
"dtype": "category",
518+
"dtype": "string",
519519
"coerce": true,
520520
"nullable": true,
521521
"required": true,
522522
"aliases": [],
523523
"checks": []
524524
},
525525
"years_to_associates_or_certificate_at_other_inst": {
526-
"dtype": "category",
526+
"dtype": "string",
527527
"coerce": true,
528528
"nullable": true,
529529
"required": true,
530530
"aliases": [],
531531
"checks": []
532532
},
533533
"years_to_bachelor_at_other_inst": {
534-
"dtype": "category",
534+
"dtype": "string",
535535
"coerce": true,
536536
"nullable": true,
537537
"required": true,

0 commit comments

Comments
 (0)