Skip to content
Merged
3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,8 @@ dependencies = [
"types-six",
"fuzzywuzzy",
"databricks-sql-connector",
"pandera~=0.13"
"pandera~=0.13",
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

eventually we may need to specify versions for the rest like you've done here for pandera

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, to be honest, I wasn’t too sure what the other versions should be but I totally agree.

"mlflow"
]

[project.urls]
Expand Down
77 changes: 75 additions & 2 deletions src/webapp/routers/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,17 +2,21 @@

import uuid
from datetime import datetime, date

from databricks.sdk import WorkspaceClient
from typing import Annotated, Any, Dict, List
from pydantic import BaseModel
from fastapi import APIRouter, Depends, HTTPException, status, Response
from fastapi.responses import FileResponse
from sqlalchemy import and_, or_
from sqlalchemy.orm import Session
from sqlalchemy.future import select
import os
import logging
from sqlalchemy.exc import IntegrityError
from ..config import env_vars
from ..config import databricks_vars, env_vars, gcs_vars
import mlflow
from mlflow.exceptions import MlflowException
import tempfile

from ..utilities import (
has_access_to_inst_or_err,
Expand Down Expand Up @@ -50,6 +54,8 @@
tags=["data"],
)

LOGGER = logging.getLogger(__name__)


class BatchCreationRequest(BaseModel):
"""The Batch creation request."""
Expand Down Expand Up @@ -1322,3 +1328,70 @@ def get_training_support_overview(
except ValueError as ve:
# Return a 400 error with the specific message from ValueError
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(ve))


@router.get("/{inst_id}/training/model-cards/{run_id}/{model_name}")
def get_model_cards(
run_id: str,
model_name: str,
inst_id: str,
current_user: Annotated[BaseUser, Depends(get_current_active_user)],
sql_session: Annotated[Session, Depends(get_session)],
artifact_path: str,
) -> FileResponse:
has_access_to_inst_or_err(inst_id, current_user)
local_session.set(sql_session)
query_result = (
local_session.get()
.execute(select(InstTable).where(InstTable.id == str_to_uuid(inst_id)))
.all()
)
if not query_result or len(query_result) == 0:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="Institution not found.",
)
if len(query_result) > 1:
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Institution duplicates found.",
)

try:
w = WorkspaceClient(
host=databricks_vars["DATABRICKS_HOST_URL"],
google_service_account=gcs_vars["GCP_SERVICE_ACCOUNT_EMAIL"],
)
os.environ["DATABRICKS_HOST"] = w._config.host
os.environ["DATABRICKS_TOKEN"] = w._config.token
LOGGER.info("Successfully created Databricks WorkspaceClient.")
except Exception as e:
LOGGER.exception(
"Failed to create Databricks WorkspaceClient with host: %s and service account: %s",
databricks_vars["DATABRICKS_HOST_URL"],
gcs_vars["GCP_SERVICE_ACCOUNT_EMAIL"],
)
raise ValueError(
f"get_model_cards(): Workspace client initialization failed: {e}"
)

try:
mlflow.set_tracking_uri("databricks")
with tempfile.TemporaryDirectory() as tmpdir:
artifact_path = f"model_card/model-card-{model_name}.pdf"
artifact_uri = f"runs:/{run_id}/{artifact_path}"
local_path = mlflow.artifacts.download_artifacts(
artifact_uri=artifact_uri, dst_path=tmpdir
)

LOGGER.debug("Artifact provisioned successfully")
return FileResponse(
path=local_path,
filename=os.path.basename(local_path),
media_type="application/pdf",
)

except MlflowException as e:
# 6. Handle errors gracefully
LOGGER.debug(f"Artifact download failed: {e}")
raise HTTPException(status_code=500, detail=f"Artifact download failed: {e}")
2 changes: 1 addition & 1 deletion src/webapp/validation_schemas/base_schema.json
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
"checks": []
},
"first_generation_student": {
"dtype": "string",
"dtype": "category",
"coerce": true,
"nullable": true,
"required": false,
Expand Down
6 changes: 3 additions & 3 deletions src/webapp/validation_schemas/pdp_schema_extension.json
Original file line number Diff line number Diff line change
Expand Up @@ -515,23 +515,23 @@
"checks": []
},
"years_to_associates_or_certificate_at_cohort_inst": {
"dtype": "category",
"dtype": "string",
"coerce": true,
"nullable": true,
"required": true,
"aliases": [],
"checks": []
},
"years_to_associates_or_certificate_at_other_inst": {
"dtype": "category",
"dtype": "string",
"coerce": true,
"nullable": true,
"required": true,
"aliases": [],
"checks": []
},
"years_to_bachelor_at_other_inst": {
"dtype": "category",
"dtype": "string",
"coerce": true,
"nullable": true,
"required": true,
Expand Down
Loading
Loading