|
2 | 2 |
|
3 | 3 | import uuid |
4 | 4 | from datetime import datetime, date |
5 | | - |
| 5 | +from databricks.sdk import WorkspaceClient |
6 | 6 | from typing import Annotated, Any, Dict, List |
7 | 7 | from pydantic import BaseModel |
8 | 8 | from fastapi import APIRouter, Depends, HTTPException, status, Response |
| 9 | +from fastapi.responses import FileResponse |
9 | 10 | from sqlalchemy import and_, or_ |
10 | 11 | from sqlalchemy.orm import Session |
11 | 12 | from sqlalchemy.future import select |
12 | 13 | import os |
13 | 14 | import logging |
14 | 15 | from sqlalchemy.exc import IntegrityError |
15 | | -from ..config import env_vars |
| 16 | +from ..config import databricks_vars, env_vars, gcs_vars |
| 17 | +import mlflow |
| 18 | +from mlflow.exceptions import MlflowException |
| 19 | +import tempfile |
16 | 20 |
|
17 | 21 | from ..utilities import ( |
18 | 22 | has_access_to_inst_or_err, |
|
50 | 54 | tags=["data"], |
51 | 55 | ) |
52 | 56 |
|
| 57 | +LOGGER = logging.getLogger(__name__) |
| 58 | + |
53 | 59 |
|
54 | 60 | class BatchCreationRequest(BaseModel): |
55 | 61 | """The Batch creation request.""" |
@@ -1322,3 +1328,70 @@ def get_training_support_overview( |
1322 | 1328 | except ValueError as ve: |
1323 | 1329 | # Return a 400 error with the specific message from ValueError |
1324 | 1330 | raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(ve)) |
| 1331 | + |
| 1332 | + |
| 1333 | +@router.get("/{inst_id}/training/model-cards/{run_id}/{model_name}") |
| 1334 | +def get_model_cards( |
| 1335 | + run_id: str, |
| 1336 | + model_name: str, |
| 1337 | + inst_id: str, |
| 1338 | + current_user: Annotated[BaseUser, Depends(get_current_active_user)], |
| 1339 | + sql_session: Annotated[Session, Depends(get_session)], |
| 1340 | + artifact_path: str, |
| 1341 | +) -> FileResponse: |
| 1342 | + has_access_to_inst_or_err(inst_id, current_user) |
| 1343 | + local_session.set(sql_session) |
| 1344 | + query_result = ( |
| 1345 | + local_session.get() |
| 1346 | + .execute(select(InstTable).where(InstTable.id == str_to_uuid(inst_id))) |
| 1347 | + .all() |
| 1348 | + ) |
| 1349 | + if not query_result or len(query_result) == 0: |
| 1350 | + raise HTTPException( |
| 1351 | + status_code=status.HTTP_404_NOT_FOUND, |
| 1352 | + detail="Institution not found.", |
| 1353 | + ) |
| 1354 | + if len(query_result) > 1: |
| 1355 | + raise HTTPException( |
| 1356 | + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, |
| 1357 | + detail="Institution duplicates found.", |
| 1358 | + ) |
| 1359 | + |
| 1360 | + try: |
| 1361 | + w = WorkspaceClient( |
| 1362 | + host=databricks_vars["DATABRICKS_HOST_URL"], |
| 1363 | + google_service_account=gcs_vars["GCP_SERVICE_ACCOUNT_EMAIL"], |
| 1364 | + ) |
| 1365 | + os.environ["DATABRICKS_HOST"] = w._config.host |
| 1366 | + os.environ["DATABRICKS_TOKEN"] = w._config.token |
| 1367 | + LOGGER.info("Successfully created Databricks WorkspaceClient.") |
| 1368 | + except Exception as e: |
| 1369 | + LOGGER.exception( |
| 1370 | + "Failed to create Databricks WorkspaceClient with host: %s and service account: %s", |
| 1371 | + databricks_vars["DATABRICKS_HOST_URL"], |
| 1372 | + gcs_vars["GCP_SERVICE_ACCOUNT_EMAIL"], |
| 1373 | + ) |
| 1374 | + raise ValueError( |
| 1375 | + f"get_model_cards(): Workspace client initialization failed: {e}" |
| 1376 | + ) |
| 1377 | + |
| 1378 | + try: |
| 1379 | + mlflow.set_tracking_uri("databricks") |
| 1380 | + with tempfile.TemporaryDirectory() as tmpdir: |
| 1381 | + artifact_path = f"model_card/model-card-{model_name}.pdf" |
| 1382 | + artifact_uri = f"runs:/{run_id}/{artifact_path}" |
| 1383 | + local_path = mlflow.artifacts.download_artifacts( |
| 1384 | + artifact_uri=artifact_uri, dst_path=tmpdir |
| 1385 | + ) |
| 1386 | + |
| 1387 | + LOGGER.debug("Artifact provisioned successfully") |
| 1388 | + return FileResponse( |
| 1389 | + path=local_path, |
| 1390 | + filename=os.path.basename(local_path), |
| 1391 | + media_type="application/pdf", |
| 1392 | + ) |
| 1393 | + |
| 1394 | + except MlflowException as e: |
| 1395 | + # 6. Handle errors gracefully |
| 1396 | + LOGGER.debug(f"Artifact download failed: {e}") |
| 1397 | + raise HTTPException(status_code=500, detail=f"Artifact download failed: {e}") |
0 commit comments