diff --git a/src/webapp/routers/data.py b/src/webapp/routers/data.py index cb5dfcf1..2f230587 100644 --- a/src/webapp/routers/data.py +++ b/src/webapp/routers/data.py @@ -14,7 +14,6 @@ import logging from sqlalchemy.exc import IntegrityError from ..config import databricks_vars, env_vars, gcs_vars -from mlflow.exceptions import MlflowException import tempfile import pathlib @@ -1373,33 +1372,34 @@ def get_model_cards( f"get_model_cards(): Workspace client initialization failed: {e}" ) - try: - run_resp = w.experiments.get_run(run_id=run_id) - - assert run_resp.run is not None, "Expected non-None Run object" - assert run_resp.run.info is not None, "Expected non-None RunInfo object" + host = w.config.host # e.g. "https://12345.gcp.databricks.com" - experiment_id = run_resp.run.info.experiment_id + # 2. Build the MLflow REST endpoint URL and params + download_endpoint = f"{host}/api/2.0/mlflow/artifacts/download" + artifact_path = f"model_card/model-card-{model_name}.pdf" + params = {"run_id": run_id, "path": artifact_path} - dbfs_path = ( - f"/databricks/mlflow-tracking/{experiment_id}/{run_id}/artifacts/" - f"model_card/model-card-{model_name}.pdf" + # 3. Let WorkspaceClient’s ApiClient perform the authenticated GET + try: + # perform_query will attach the same OAuth creds that WorkspaceClient uses + resp = w.api_client.perform_query( # type: ignore[attr-defined] + method="GET", + path=download_endpoint, + query_params=params, + ) # type: ignore[attr-defined] + # resp here is the raw bytes of the PDF + except Exception as e: + raise HTTPException( + status_code=500, + detail=f"Could not download model card via MLflow REST API: {e}", ) - with tempfile.TemporaryDirectory() as tmpdir: - local_file = pathlib.Path(tmpdir) / f"model-card-{model_name}.pdf" - with w.dbfs.download( - f"dbfs:{dbfs_path}" - ) as stream: # DBFS API download() returns a bytes stream - local_file.write_bytes(stream.read()) - - LOGGER.debug("Artifact provisioned successfully") - return FileResponse( - path=str(local_file), - filename=local_file.name, - media_type="application/pdf", - ) - except MlflowException as e: - # 6. Handle errors gracefully - LOGGER.debug(f"Artifact download failed: {e}") - raise HTTPException(status_code=500, detail=f"Artifact download failed: {e}") + # 4. Write to a temp file and return it + with tempfile.TemporaryDirectory() as td: + out_path = pathlib.Path(td) / f"model-card-{model_name}.pdf" + out_path.write_bytes(resp) + return FileResponse( + path=str(out_path), + filename=out_path.name, + media_type="application/pdf", + )