|
14 | 14 | import logging |
15 | 15 | from sqlalchemy.exc import IntegrityError |
16 | 16 | from ..config import databricks_vars, env_vars, gcs_vars |
17 | | -from mlflow.exceptions import MlflowException |
18 | 17 | import tempfile |
19 | 18 | import pathlib |
20 | 19 |
|
@@ -1373,33 +1372,34 @@ def get_model_cards( |
1373 | 1372 | f"get_model_cards(): Workspace client initialization failed: {e}" |
1374 | 1373 | ) |
1375 | 1374 |
|
1376 | | - try: |
1377 | | - run_resp = w.experiments.get_run(run_id=run_id) |
1378 | | - |
1379 | | - assert run_resp.run is not None, "Expected non-None Run object" |
1380 | | - assert run_resp.run.info is not None, "Expected non-None RunInfo object" |
| 1375 | + host = w.config.host # e.g. "https://12345.gcp.databricks.com" |
1381 | 1376 |
|
1382 | | - experiment_id = run_resp.run.info.experiment_id |
| 1377 | + # 2. Build the MLflow REST endpoint URL and params |
| 1378 | + download_endpoint = f"{host}/api/2.0/mlflow/artifacts/download" |
| 1379 | + artifact_path = f"model_card/model-card-{model_name}.pdf" |
| 1380 | + params = {"run_id": run_id, "path": artifact_path} |
1383 | 1381 |
|
1384 | | - dbfs_path = ( |
1385 | | - f"/databricks/mlflow-tracking/{experiment_id}/{run_id}/artifacts/" |
1386 | | - f"model_card/model-card-{model_name}.pdf" |
| 1382 | + # 3. Let WorkspaceClient’s ApiClient perform the authenticated GET |
| 1383 | + try: |
| 1384 | + # perform_query will attach the same OAuth creds that WorkspaceClient uses |
| 1385 | + resp = w.api_client.perform_query( # type: ignore[attr-defined] |
| 1386 | + method="GET", |
| 1387 | + path=download_endpoint, |
| 1388 | + query_params=params, |
| 1389 | + ) # type: ignore[attr-defined] |
| 1390 | + # resp here is the raw bytes of the PDF |
| 1391 | + except Exception as e: |
| 1392 | + raise HTTPException( |
| 1393 | + status_code=500, |
| 1394 | + detail=f"Could not download model card via MLflow REST API: {e}", |
1387 | 1395 | ) |
1388 | | - with tempfile.TemporaryDirectory() as tmpdir: |
1389 | | - local_file = pathlib.Path(tmpdir) / f"model-card-{model_name}.pdf" |
1390 | | - with w.dbfs.download( |
1391 | | - f"dbfs:{dbfs_path}" |
1392 | | - ) as stream: # DBFS API download() returns a bytes stream |
1393 | | - local_file.write_bytes(stream.read()) |
1394 | | - |
1395 | | - LOGGER.debug("Artifact provisioned successfully") |
1396 | | - return FileResponse( |
1397 | | - path=str(local_file), |
1398 | | - filename=local_file.name, |
1399 | | - media_type="application/pdf", |
1400 | | - ) |
1401 | 1396 |
|
1402 | | - except MlflowException as e: |
1403 | | - # 6. Handle errors gracefully |
1404 | | - LOGGER.debug(f"Artifact download failed: {e}") |
1405 | | - raise HTTPException(status_code=500, detail=f"Artifact download failed: {e}") |
| 1397 | + # 4. Write to a temp file and return it |
| 1398 | + with tempfile.TemporaryDirectory() as td: |
| 1399 | + out_path = pathlib.Path(td) / f"model-card-{model_name}.pdf" |
| 1400 | + out_path.write_bytes(resp) |
| 1401 | + return FileResponse( |
| 1402 | + path=str(out_path), |
| 1403 | + filename=out_path.name, |
| 1404 | + media_type="application/pdf", |
| 1405 | + ) |
0 commit comments