Skip to content

Commit e0b09b7

Browse files
authored
Merge pull request #127 from datakind/ModelCardEndpoint
fixed endpoint issues
2 parents 355bce2 + 18b34a6 commit e0b09b7

1 file changed

Lines changed: 20 additions & 12 deletions

File tree

src/webapp/routers/data.py

Lines changed: 20 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -14,9 +14,9 @@
1414
import logging
1515
from sqlalchemy.exc import IntegrityError
1616
from ..config import databricks_vars, env_vars, gcs_vars
17-
import mlflow
1817
from mlflow.exceptions import MlflowException
1918
import tempfile
19+
import pathlib
2020

2121
from ..utilities import (
2222
has_access_to_inst_or_err,
@@ -1337,7 +1337,6 @@ def get_model_cards(
13371337
inst_id: str,
13381338
current_user: Annotated[BaseUser, Depends(get_current_active_user)],
13391339
sql_session: Annotated[Session, Depends(get_session)],
1340-
artifact_path: str,
13411340
) -> FileResponse:
13421341
has_access_to_inst_or_err(inst_id, current_user)
13431342
local_session.set(sql_session)
@@ -1362,8 +1361,7 @@ def get_model_cards(
13621361
host=databricks_vars["DATABRICKS_HOST_URL"],
13631362
google_service_account=gcs_vars["GCP_SERVICE_ACCOUNT_EMAIL"],
13641363
)
1365-
os.environ["DATABRICKS_HOST"] = w._config.host
1366-
os.environ["DATABRICKS_TOKEN"] = w._config.token
1364+
13671365
LOGGER.info("Successfully created Databricks WorkspaceClient.")
13681366
except Exception as e:
13691367
LOGGER.exception(
@@ -1376,18 +1374,28 @@ def get_model_cards(
13761374
)
13771375

13781376
try:
1379-
mlflow.set_tracking_uri("databricks")
1377+
run_resp = w.experiments.get_run(run_id=run_id)
1378+
1379+
assert run_resp.run is not None, "Expected non-None Run object"
1380+
assert run_resp.run.info is not None, "Expected non-None RunInfo object"
1381+
1382+
experiment_id = run_resp.run.info.experiment_id
1383+
1384+
dbfs_path = (
1385+
f"/databricks/mlflow-tracking/{experiment_id}/{run_id}/artifacts/"
1386+
f"model_card/model-card-{model_name}.pdf"
1387+
)
13801388
with tempfile.TemporaryDirectory() as tmpdir:
1381-
artifact_path = f"model_card/model-card-{model_name}.pdf"
1382-
artifact_uri = f"runs:/{run_id}/{artifact_path}"
1383-
local_path = mlflow.artifacts.download_artifacts(
1384-
artifact_uri=artifact_uri, dst_path=tmpdir
1385-
)
1389+
local_file = pathlib.Path(tmpdir) / f"model-card-{model_name}.pdf"
1390+
with w.dbfs.download(
1391+
f"dbfs:{dbfs_path}"
1392+
) as stream: # DBFS API download() returns a bytes stream
1393+
local_file.write_bytes(stream.read())
13861394

13871395
LOGGER.debug("Artifact provisioned successfully")
13881396
return FileResponse(
1389-
path=local_path,
1390-
filename=os.path.basename(local_path),
1397+
path=str(local_file),
1398+
filename=local_file.name,
13911399
media_type="application/pdf",
13921400
)
13931401

0 commit comments

Comments
 (0)