Skip to content

Commit d31c826

Browse files
authored
Merge pull request #130 from datakind/develop
model cards optimization
2 parents f71e9aa + 8f2f06a commit d31c826

1 file changed

Lines changed: 27 additions & 27 deletions

File tree

src/webapp/routers/data.py

Lines changed: 27 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@
1414
import logging
1515
from sqlalchemy.exc import IntegrityError
1616
from ..config import databricks_vars, env_vars, gcs_vars
17-
from mlflow.exceptions import MlflowException
1817
import tempfile
1918
import pathlib
2019

@@ -1373,33 +1372,34 @@ def get_model_cards(
13731372
f"get_model_cards(): Workspace client initialization failed: {e}"
13741373
)
13751374

1376-
try:
1377-
run_resp = w.experiments.get_run(run_id=run_id)
1378-
1379-
assert run_resp.run is not None, "Expected non-None Run object"
1380-
assert run_resp.run.info is not None, "Expected non-None RunInfo object"
1375+
host = w.config.host # e.g. "https://12345.gcp.databricks.com"
13811376

1382-
experiment_id = run_resp.run.info.experiment_id
1377+
# 2. Build the MLflow REST endpoint URL and params
1378+
download_endpoint = f"{host}/api/2.0/mlflow/artifacts/download"
1379+
artifact_path = f"model_card/model-card-{model_name}.pdf"
1380+
params = {"run_id": run_id, "path": artifact_path}
13831381

1384-
dbfs_path = (
1385-
f"/databricks/mlflow-tracking/{experiment_id}/{run_id}/artifacts/"
1386-
f"model_card/model-card-{model_name}.pdf"
1382+
# 3. Let WorkspaceClient’s ApiClient perform the authenticated GET
1383+
try:
1384+
# perform_query will attach the same OAuth creds that WorkspaceClient uses
1385+
resp = w.api_client.perform_query( # type: ignore[attr-defined]
1386+
method="GET",
1387+
path=download_endpoint,
1388+
query_params=params,
1389+
) # type: ignore[attr-defined]
1390+
# resp here is the raw bytes of the PDF
1391+
except Exception as e:
1392+
raise HTTPException(
1393+
status_code=500,
1394+
detail=f"Could not download model card via MLflow REST API: {e}",
13871395
)
1388-
with tempfile.TemporaryDirectory() as tmpdir:
1389-
local_file = pathlib.Path(tmpdir) / f"model-card-{model_name}.pdf"
1390-
with w.dbfs.download(
1391-
f"dbfs:{dbfs_path}"
1392-
) as stream: # DBFS API download() returns a bytes stream
1393-
local_file.write_bytes(stream.read())
1394-
1395-
LOGGER.debug("Artifact provisioned successfully")
1396-
return FileResponse(
1397-
path=str(local_file),
1398-
filename=local_file.name,
1399-
media_type="application/pdf",
1400-
)
14011396

1402-
except MlflowException as e:
1403-
# 6. Handle errors gracefully
1404-
LOGGER.debug(f"Artifact download failed: {e}")
1405-
raise HTTPException(status_code=500, detail=f"Artifact download failed: {e}")
1397+
# 4. Write to a temp file and return it
1398+
with tempfile.TemporaryDirectory() as td:
1399+
out_path = pathlib.Path(td) / f"model-card-{model_name}.pdf"
1400+
out_path.write_bytes(resp)
1401+
return FileResponse(
1402+
path=str(out_path),
1403+
filename=out_path.name,
1404+
media_type="application/pdf",
1405+
)

0 commit comments

Comments
 (0)