# --- src/webapp/databricks.py ---

# The name of the deployed pipeline in Databricks. Must match directly.
PDP_INFERENCE_JOB_NAME = "edvise_github_sourced_pdp_inference_pipeline"

# File names accepted as bronze datasets (case-insensitive):
#   <alnum>pdp_<alnum>_[course_level_]ar_<anything>.csv
# Always apply it with ``fullmatch``: with ``match`` the trailing ``$`` would
# also accept a name that ends in a newline.
VALID_BRONZE_FILE_RE = re.compile(
    r"^[a-z0-9]+pdp_[a-z0-9]+_(course_level_)?ar_.*\.csv$",
    re.IGNORECASE,
)


def _bronze_workspace_client() -> Any:
    """Build a Databricks ``WorkspaceClient`` from the module-level configuration.

    Raises:
        ValueError: If the Databricks host URL, the catalog name or the GCP
            service account email is unset, or if constructing the client
            fails. The "not configured" wording is matched by the API layer
            and must stay as is.
    """
    host_url = databricks_vars.get("DATABRICKS_HOST_URL")
    if not host_url or not databricks_vars.get("CATALOG_NAME"):
        raise ValueError("Databricks integration not configured.")
    service_account = gcs_vars.get("GCP_SERVICE_ACCOUNT_EMAIL")
    if not service_account:
        raise ValueError("GCP service account email not configured.")

    try:
        return WorkspaceClient(
            host=host_url,
            google_service_account=service_account,
        )
    except Exception as e:
        LOGGER.exception(
            "Failed to create Databricks WorkspaceClient with host: %s and service account: %s",
            host_url,
            service_account,
        )
        raise ValueError(f"Workspace client creation failed: {e}") from e


def _bronze_volume_root(inst_name: str) -> str:
    """Return the bronze volume root path for the institution named *inst_name*."""
    db_inst_name = databricksify_inst_name(inst_name)
    return (
        f"/Volumes/{databricks_vars['CATALOG_NAME']}/"
        f"{db_inst_name}_bronze/bronze_volume"
    )


# The two functions below take ``self`` because they are methods of the
# Databricks control class; the class header is outside this excerpt.
def list_bronze_volume_csvs(self, inst_name: str) -> list[str]:
    """List `.csv` files directly under the institution's bronze volume root.

    Directories and files whose base name does not fully match
    ``VALID_BRONZE_FILE_RE`` are skipped.

    Args:
        inst_name: Institution name; converted with ``databricksify_inst_name``.

    Returns:
        The matching base names, sorted.

    Raises:
        ValueError: If the integration is not configured, the workspace client
            cannot be created, or listing the volume fails.
    """
    w = _bronze_workspace_client()
    volume_root = _bronze_volume_root(inst_name)

    try:
        # Materialise inside the ``try`` so errors raised while iterating the
        # listing are reported the same way as errors from the call itself.
        entries = list(w.dbfs.list(f"dbfs:{volume_root}") or [])
    except Exception as e:
        LOGGER.exception("Failed to list bronze volume directory: %s", volume_root)
        raise ValueError(f"Failed to list bronze volume directory: {e}") from e

    csvs: list[str] = []
    for entry in entries:
        entry_path = getattr(entry, "path", None)
        if not entry_path or getattr(entry, "is_dir", False):
            continue
        basename = os.path.basename(str(entry_path))
        if VALID_BRONZE_FILE_RE.fullmatch(basename):
            csvs.append(basename)
    return sorted(csvs)


def download_bronze_volume_file(self, inst_name: str, file_name: str) -> Any:
    """Download a file from the institution's bronze volume root and return a byte stream.

    Args:
        inst_name: Institution name; converted with ``databricksify_inst_name``.
        file_name: Base name of the file. It must not contain '/' and must
            fully match ``VALID_BRONZE_FILE_RE``.

    Returns:
        The ``contents`` attribute of the Databricks download response.

    Raises:
        ValueError: If *file_name* is rejected, the integration is not
            configured, the workspace client cannot be created, the download
            fails, or the response carries no contents.
    """
    # Validate the name before touching any configuration or remote service.
    if "/" in file_name:
        raise ValueError("file_name must not contain '/'.")
    if not VALID_BRONZE_FILE_RE.fullmatch(file_name):
        raise ValueError("Invalid bronze dataset filename.")

    w = _bronze_workspace_client()
    volume_path = f"{_bronze_volume_root(inst_name)}/{file_name}"

    try:
        response = w.files.download(volume_path)
    except Exception as e:
        LOGGER.exception("Failed to download from %s", volume_path)
        raise ValueError(f"Failed to download bronze dataset: {e}") from e

    stream = getattr(response, "contents", None)
    if stream is None:
        raise ValueError("Databricks download returned no contents.")
    return stream


# Note that for each unique PIPELINE, we'll need a new function, this is by nature of how unique pipelines
# may have unique parameters and would have a unique name
# (i.e. the name field specified in w.jobs.list()). But any run of a given pipeline (even across institutions) can use the same function.
# E.g. there is one PDP inference pipeline, so one PDP inference function here.

# --- src/webapp/gcsutil.py ---
from typing import Any, Dict, List, Optional, IO


# Instance method (takes ``self``); the enclosing class header is outside this
# excerpt.
def upload_unvalidated_csv_from_file(
    self, bucket_name: str, file_name: str, file_obj: IO[bytes]
) -> None:
    """Upload a CSV into unvalidated/ while enforcing no-overwrite semantics.

    Args:
        bucket_name: Name of the target storage bucket.
        file_name: Object name without any prefix; must be non-blank and must
            not contain '/'.
        file_obj: Binary file object whose contents are uploaded as
            ``text/csv``.

    Raises:
        ValueError: If *file_name* is blank or contains '/', the bucket does
            not exist, or an object with that name already exists under
            ``unvalidated/`` or ``validated/``.
    """
    if not file_name or not file_name.strip():
        raise ValueError("file_name is required and must be non-empty.")
    if "/" in file_name:
        raise ValueError("file_name must not contain '/'.")

    bucket = storage.Client().bucket(bucket_name)
    if not bucket.exists():
        raise ValueError("Storage bucket not found.")

    if any(
        bucket.blob(prefix + file_name).exists()
        for prefix in ("unvalidated/", "validated/")
    ):
        raise ValueError("File already exists.")

    target = bucket.blob("unvalidated/" + file_name)
    try:
        # ``if_generation_match=0`` makes the write itself conditional on the
        # object not existing, closing the window between the exists() checks
        # above and the upload in which a concurrent writer could create it.
        target.upload_from_file(
            file_obj, content_type="text/csv", if_generation_match=0
        )
    except Exception as exc:
        # A failed generation precondition is reported as HTTP 412; surface it
        # as the same "already exists" error the explicit checks raise.
        if getattr(exc, "code", None) == 412:
            raise ValueError("File already exists.") from exc
        raise
# --- src/webapp/routers/data.py ---
import logging

import requests


class BronzeImportRequest(BaseModel):
    """Request to import a dataset from the institution's bronze volume into GCS."""

    # File name of the dataset to import. The route strips surrounding
    # whitespace, rejects names containing '/', and matches the name
    # case-insensitively against the bronze volume listing.
    name: str


class BronzeImportResponse(BaseModel):
    """Response for bronze import request."""

    # Name of the imported file as it appears in the bronze volume listing.
    file_name: str
    # Human-readable status message.
    message: str


def _upload_file_bytes_to_signed_url(file_bytes: bytes, upload_signed_url: str) -> None:
    """Upload file bytes to a signed GCS URL using the same request shape as the worker path.

    Raises:
        requests.RequestException: If the request fails or the response status
            is anything other than 200; the message carries the status code
            and the response body.
    """
    upload_response = requests.put(
        upload_signed_url,
        data=file_bytes,
        headers={"Content-Type": "text/csv"},
        timeout=600,
    )
    if upload_response.status_code != 200:
        raise requests.RequestException(
            f"{upload_response.status_code} {upload_response.text}"
        )


def _get_inst_or_404(inst_id: str) -> Any:
    """Return the institution row for *inst_id*, or raise HTTP 404.

    Reads through ``local_session``, so the caller must have called
    ``local_session.set(...)`` first.
    """
    inst = (
        local_session.get()
        .execute(select(InstTable).where(InstTable.id == str_to_uuid(inst_id)))
        .scalar_one_or_none()
    )
    if inst is None:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="Institution not found.",
        )
    return inst


def _list_bronze_csvs_or_http_error(
    databricks_control: DatabricksControl, inst_name: str
) -> list[str]:
    """List bronze CSV names, translating ``ValueError`` into an HTTP error.

    A message containing "not configured" maps to 501; every other
    ``ValueError`` maps to 500.
    """
    try:
        return databricks_control.list_bronze_volume_csvs(inst_name)
    except ValueError as ve:
        msg = str(ve)
        if "not configured" in msg.lower():
            raise HTTPException(
                status_code=status.HTTP_501_NOT_IMPLEMENTED, detail=msg
            ) from ve
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=msg
        ) from ve


@router.get("/{inst_id}/input/bronze-datasets", response_model=list[str])
def list_bronze_datasets(
    inst_id: str,
    current_user: Annotated[BaseUser, Depends(get_current_active_user)],
    sql_session: Annotated[Session, Depends(get_session)],
    databricks_control: Annotated[DatabricksControl, Depends(DatabricksControl)],
) -> Any:
    """List `.csv` files directly under the institution's Databricks bronze volume root.

    Raises:
        HTTPException: 404 if the institution does not exist, 501 if the
            Databricks integration is not configured, 500 if listing fails.
            ``has_access_to_inst_or_err`` performs the access check first.
    """
    has_access_to_inst_or_err(inst_id, current_user)
    local_session.set(sql_session)

    inst = _get_inst_or_404(inst_id)
    return _list_bronze_csvs_or_http_error(databricks_control, inst.name)


@router.post(
    "/{inst_id}/input/upload-from-volume-to-gcs-bucket",
    response_model=BronzeImportResponse,
)
def upload_from_volume_to_gcs_bucket(
    inst_id: str,
    req: BronzeImportRequest,
    current_user: Annotated[BaseUser, Depends(get_current_active_user)],
    sql_session: Annotated[Session, Depends(get_session)],
    storage_control: Annotated[StorageControl, Depends(StorageControl)],
    databricks_control: Annotated[DatabricksControl, Depends(DatabricksControl)],
) -> Any:
    """Import a selected dataset from the institution's bronze volume into GCS unvalidated/.

    The requested name is resolved case-insensitively against the bronze
    volume listing, the file is downloaded fully into memory, and the bytes
    are PUT to a signed upload URL for the institution's external bucket.

    Raises:
        HTTPException: 404 if the institution or dataset is not found, 422 if
            the name is empty or contains '/', 501 if Databricks is not
            configured, 400 if download or URL generation raises
            ``ValueError``, and 500 for listing, upload or unexpected
            failures.
    """
    has_access_to_inst_or_err(inst_id, current_user)
    local_session.set(sql_session)

    inst = _get_inst_or_404(inst_id)

    requested_name = (req.name or "").strip()
    if not requested_name:
        raise HTTPException(
            status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
            detail="Dataset name is required.",
        )
    if "/" in requested_name:
        raise HTTPException(
            status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
            detail="Dataset name can't contain '/'.",
        )

    # Ensure this is actually present in the bronze root (and matches naming rules).
    available = _list_bronze_csvs_or_http_error(databricks_control, inst.name)
    available_map = {x.lower(): x for x in available}
    file_name = available_map.get(requested_name.lower())
    if not file_name:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="Bronze dataset not found.",
        )

    stream = None
    try:
        stream = databricks_control.download_bronze_volume_file(inst.name, file_name)
        file_bytes = stream.read()
        upload_url = storage_control.generate_upload_signed_url(
            get_external_bucket_name(inst_id), file_name
        )
        _upload_file_bytes_to_signed_url(file_bytes, upload_url)
    except ValueError as ve:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST, detail=str(ve)
        ) from ve
    except requests.RequestException as rexc:
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Failed to upload dataset to GCS: {rexc}",
        ) from rexc
    except Exception as e:
        # Catch-all boundary: keep the traceback in the logs, since the
        # response only carries the stringified error.
        logging.getLogger(__name__).exception(
            "Unexpected error importing bronze dataset %s", file_name
        )
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Unexpected error importing dataset: {e}",
        ) from e
    finally:
        close = getattr(stream, "close", None)
        if callable(close):
            try:
                close()
            except Exception:
                # Best-effort cleanup: a failing close() must not mask the
                # outcome of the import itself.
                logging.getLogger(__name__).debug(
                    "Ignoring error while closing bronze download stream.",
                    exc_info=True,
                )

    return {"file_name": file_name, "message": "Upload successful."}
# --- src/webapp/routers/data_test.py ---
import io
import uuid
from unittest import mock

from ..utilities import (
    uuid_to_str,
    get_current_active_user,
    SchemaType,
    get_external_bucket_name,
)
from fastapi import HTTPException
from ..gcsutil import StorageControl
from ..databricks import DatabricksControl

MOCK_STORAGE = mock.Mock()
MOCK_DATABRICKS = mock.Mock()
UUID_2 = uuid.UUID("9bcbc782-2e71-4441-afa2-7a311024a5ec")
FILE_UUID_1 = uuid.UUID("f0bb3a20-6d92-4254-afed-6a72f43c562a")

# NOTE: the `client` fixture (only partly visible in this excerpt) must
# register a dependency override so the routes receive MOCK_DATABRICKS:
#     app.dependency_overrides[DatabricksControl] = databricks_control_override
# where `databricks_control_override` returns MOCK_DATABRICKS.


def _bronze_url(inst_uuid: uuid.UUID, suffix: str) -> str:
    """Build the request path for a bronze endpoint of the given institution."""
    return "/institutions/" + uuid_to_str(inst_uuid) + suffix


def test_list_bronze_datasets(client: TestClient) -> Any:
    """Test GET /institutions//input/bronze-datasets."""
    MOCK_DATABRICKS.reset_mock()
    MOCK_DATABRICKS.list_bronze_volume_csvs.return_value = ["a.csv", "b.csv"]

    # A user without access is rejected before Databricks is contacted.
    response = client.get(_bronze_url(UUID_INVALID, "/input/bronze-datasets"))
    assert response.status_code == 401
    MOCK_DATABRICKS.list_bronze_volume_csvs.assert_not_called()

    response = client.get(_bronze_url(USER_VALID_INST_UUID, "/input/bronze-datasets"))
    assert response.status_code == 200
    assert response.json() == ["a.csv", "b.csv"]
    MOCK_DATABRICKS.list_bronze_volume_csvs.assert_called_once_with("school_1")


def test_upload_from_volume_to_gcs_bucket(client: TestClient) -> Any:
    """Test POST /institutions//input/upload-from-volume-to-gcs-bucket."""
    MOCK_DATABRICKS.reset_mock()
    MOCK_STORAGE.reset_mock()
    suffix = "/input/upload-from-volume-to-gcs-bucket"

    # A user without access is rejected before any backend is contacted.
    response = client.post(
        _bronze_url(UUID_INVALID, suffix),
        json={"name": "file.csv"},
    )
    assert response.status_code == 401
    MOCK_DATABRICKS.list_bronze_volume_csvs.assert_not_called()

    MOCK_DATABRICKS.list_bronze_volume_csvs.return_value = ["file.csv"]
    MOCK_DATABRICKS.download_bronze_volume_file.return_value = io.BytesIO(
        b"col1,col2\n1,2\n"
    )
    MOCK_STORAGE.generate_upload_signed_url.return_value = "https://signed.example"

    with mock.patch("src.webapp.routers.data.requests.put") as mock_put:
        mock_put.return_value.status_code = 200
        mock_put.return_value.text = ""
        response = client.post(
            _bronze_url(USER_VALID_INST_UUID, suffix),
            json={"name": "file.csv"},
        )
    assert response.status_code == 200
    assert response.json() == {
        "file_name": "file.csv",
        "message": "Upload successful.",
    }

    MOCK_DATABRICKS.list_bronze_volume_csvs.assert_called_once_with("school_1")
    MOCK_DATABRICKS.download_bronze_volume_file.assert_called_once_with(
        "school_1", "file.csv"
    )
    MOCK_STORAGE.generate_upload_signed_url.assert_called_once_with(
        get_external_bucket_name(uuid_to_str(USER_VALID_INST_UUID)), "file.csv"
    )
    mock_put.assert_called_once_with(
        "https://signed.example",
        data=b"col1,col2\n1,2\n",
        headers={"Content-Type": "text/csv"},
        timeout=600,
    )


def test_upload_from_volume_to_gcs_bucket_rejects_bad_names(
    client: TestClient,
) -> Any:
    """Nested names are rejected with 422 and unknown datasets with 404."""
    MOCK_DATABRICKS.reset_mock()
    MOCK_STORAGE.reset_mock()
    MOCK_DATABRICKS.list_bronze_volume_csvs.return_value = ["file.csv"]
    url = _bronze_url(USER_VALID_INST_UUID, "/input/upload-from-volume-to-gcs-bucket")

    # Names containing '/' are refused before the volume is listed.
    response = client.post(url, json={"name": "nested/file.csv"})
    assert response.status_code == 422
    MOCK_DATABRICKS.list_bronze_volume_csvs.assert_not_called()

    # A name absent from the listing yields 404 and triggers no download/upload.
    response = client.post(url, json={"name": "missing.csv"})
    assert response.status_code == 404
    MOCK_DATABRICKS.list_bronze_volume_csvs.assert_called_once_with("school_1")
    MOCK_DATABRICKS.download_bronze_volume_file.assert_not_called()
    MOCK_STORAGE.generate_upload_signed_url.assert_not_called()