diff --git a/src/webapp/databricks.py b/src/webapp/databricks.py index aa4b6b02..75f417c7 100644 --- a/src/webapp/databricks.py +++ b/src/webapp/databricks.py @@ -2,7 +2,7 @@ import os import logging -from pydantic import BaseModel +from pydantic import BaseModel, field_validator from databricks.sdk import WorkspaceClient from databricks.sdk.service import catalog from databricks.sdk.service.sql import ( @@ -31,12 +31,79 @@ # List of data medallion levels MEDALLION_LEVELS = ["silver", "gold", "bronze"] -# The name of the deployed pipeline in Databricks. Must match directly. +# The name of the deployed pipeline in Databricks. Must match the job's `name` in that workspace. +# Override with LEGACY_INFERENCE_JOB_NAME (and PDP_INFERENCE_JOB_NAME) when dev/staging deploy +# uses a different bundle target or a stub job that matches the same parameters. PDP_INFERENCE_JOB_NAME = "edvise_github_sourced_pdp_inference_pipeline" +LEGACY_INFERENCE_JOB_NAME = "edvise_github_sourced_legacy_inference_pipeline" -class DatabricksInferenceRunRequest(BaseModel): - """Databricks parameters for an inference run.""" +def _pdp_inference_job_name() -> str: + name = os.environ.get("PDP_INFERENCE_JOB_NAME", "").strip() + return name or PDP_INFERENCE_JOB_NAME + + +def _legacy_inference_job_name() -> str: + name = os.environ.get("LEGACY_INFERENCE_JOB_NAME", "").strip() + return name or LEGACY_INFERENCE_JOB_NAME + + +def _resolve_pipeline_job(w: Any, pipeline_type: str, caller_label: str) -> Any: + """Find a job by exact name, else by unique substring match on display name. + + Development bundles often prefix job names (e.g. ``[dev vishakh] edvise_...``) while + the API passes the canonical base name. Optional env ``PDP_INFERENCE_JOB_NAME`` / + ``LEGACY_INFERENCE_JOB_NAME`` still wins when set to the full exact name. + """ + job = next(w.jobs.list(name=pipeline_type), None) + if job is not None and getattr(job, "job_id", None) is not None: + LOGGER.info( + "%s: resolved job by exact name %r (job_id=%s)", + caller_label, + pipeline_type, + job.job_id, + ) + return job + + matches: list[tuple[str, Any]] = [] + for j in w.jobs.list(): + settings = getattr(j, "settings", None) + name = getattr(settings, "name", None) if settings is not None else None + if not name: + continue + if pipeline_type in name: + matches.append((name, j)) + + if len(matches) == 1: + picked_name, picked = matches[0] + if getattr(picked, "job_id", None) is None: + raise ValueError( + f"{caller_label}: Job name {picked_name!r} matched substring {pipeline_type!r} but has no job_id." + ) + LOGGER.info( + "%s: resolved job by substring %r -> display name %r (job_id=%s)", + caller_label, + pipeline_type, + picked_name, + picked.job_id, + ) + return picked + + if len(matches) > 1: + names = [n for n, _ in matches] + raise ValueError( + f"{caller_label}: Multiple jobs match substring {pipeline_type!r}: {names}. " + "Set PDP_INFERENCE_JOB_NAME or LEGACY_INFERENCE_JOB_NAME to the full job name." + ) + + raise ValueError( + f"{caller_label}: Job {pipeline_type!r} was not found (exact name or unique substring of settings.name) " + f"for '{gcs_vars['GCP_SERVICE_ACCOUNT_EMAIL']}' and '{databricks_vars['DATABRICKS_HOST_URL']}'." + ) + + +class DatabricksPDPInferenceRunRequest(BaseModel): + """Databricks parameters for a PDP inference run.""" inst_name: str # Note that the following should be the filepath. @@ -47,6 +114,24 @@ class DatabricksInferenceRunRequest(BaseModel): gcp_external_bucket_name: str +class DatabricksLegacyInferenceRunRequest(BaseModel): + """Databricks parameters for a legacy schools inference run.""" + + inst_name: str + model_name: str + config_file_name: str = "" + features_table_name: str = "" + # The email where notifications will get sent. + email: str = "" + gcp_external_bucket_name: str + + @field_validator("config_file_name", "features_table_name", "email", mode="before") + @classmethod + def _none_to_empty_str(cls, v: object) -> object: + """Allow callers to omit or pass null; Databricks job treats empty like YAML defaults.""" + return "" if v is None else v + + class DatabricksInferenceRunResponse(BaseModel): """Databricks parameters for an inference run.""" @@ -183,7 +268,7 @@ def setup_new_inst(self, inst_name: str) -> None: # E.g. there is one PDP inference pipeline, so one PDP inference function here. def run_pdp_inference( - self, req: DatabricksInferenceRunRequest + self, req: DatabricksPDPInferenceRunRequest ) -> DatabricksInferenceRunResponse: """Triggers PDP inference Databricks run.""" LOGGER.info(f"Running PDP inference for institution: {req.inst_name}") @@ -213,14 +298,10 @@ def run_pdp_inference( ) db_inst_name = databricksify_inst_name(req.inst_name) - pipeline_type = PDP_INFERENCE_JOB_NAME + pipeline_type = _pdp_inference_job_name() try: - job = next(w.jobs.list(name=pipeline_type), None) - if not job or job.job_id is None: - raise ValueError( - f"run_pdp_inference(): Job '{pipeline_type}' was not found or has no job_id for '{gcs_vars['GCP_SERVICE_ACCOUNT_EMAIL']}' and '{databricks_vars['DATABRICKS_HOST_URL']}'." - ) + job = _resolve_pipeline_job(w, pipeline_type, "run_pdp_inference") job_id = job.job_id LOGGER.info(f"Resolved job ID for '{pipeline_type}': {job_id}") except Exception as e: @@ -261,6 +342,69 @@ def run_pdp_inference( return DatabricksInferenceRunResponse(job_run_id=run_id) + def run_legacy_inference( + self, req: DatabricksLegacyInferenceRunRequest + ) -> DatabricksInferenceRunResponse: + """Triggers legacy schools inference Databricks run.""" + LOGGER.info(f"Running legacy inference for institution: {req.inst_name}") + try: + w = WorkspaceClient( + host=databricks_vars["DATABRICKS_HOST_URL"], + google_service_account=gcs_vars["GCP_SERVICE_ACCOUNT_EMAIL"], + ) + LOGGER.info("Successfully created Databricks WorkspaceClient.") + except Exception as e: + LOGGER.exception( + "Failed to create Databricks WorkspaceClient with host: %s and service account: %s", + databricks_vars["DATABRICKS_HOST_URL"], + gcs_vars["GCP_SERVICE_ACCOUNT_EMAIL"], + ) + raise ValueError( + f"run_legacy_inference(): Workspace client initialization failed: {e}" + ) + + db_inst_name = databricksify_inst_name(req.inst_name) + pipeline_type = _legacy_inference_job_name() + + try: + job = _resolve_pipeline_job(w, pipeline_type, "run_legacy_inference") + job_id = job.job_id + LOGGER.info(f"Resolved job ID for '{pipeline_type}': {job_id}") + except Exception as e: + LOGGER.exception(f"Job lookup failed for '{pipeline_type}'.") + raise ValueError(f"run_legacy_inference(): Failed to find job: {e}") + + try: + run_job: Any = w.jobs.run_now( + job_id, + job_parameters={ + "databricks_institution_name": db_inst_name, + "DB_workspace": databricks_vars["DATABRICKS_WORKSPACE"], + "model_name": req.model_name, + "config_file_name": req.config_file_name, + "features_table_name": req.features_table_name, + "gcp_bucket_name": req.gcp_external_bucket_name, + "datakind_notification_email": req.email, + "DK_CC_EMAIL": req.email, + }, + ) + LOGGER.info( + f"Successfully triggered job run. Run ID: {run_job.response.run_id}" + ) + except Exception as e: + LOGGER.exception("Failed to run the legacy inference job.") + raise ValueError(f"run_legacy_inference(): Job could not be run: {e}") + + if not run_job.response or run_job.response.run_id is None: + raise ValueError( + "run_legacy_inference(): Job did not return a valid run_id." + ) + + run_id = run_job.response.run_id + LOGGER.info(f"Successfully triggered job run. Run ID: {run_id}") + + return DatabricksInferenceRunResponse(job_run_id=run_id) + def delete_inst(self, inst_name: str) -> None: """Cleanup tasks required on the Databricks side to delete an institution.""" db_inst_name = databricksify_inst_name(inst_name) diff --git a/src/webapp/databricks_test.py b/src/webapp/databricks_test.py index 37fa2855..025e741e 100644 --- a/src/webapp/databricks_test.py +++ b/src/webapp/databricks_test.py @@ -1,6 +1,7 @@ import pytest +from unittest.mock import MagicMock -from .databricks import DatabricksControl +from .databricks import DatabricksControl, _resolve_pipeline_job @pytest.fixture @@ -51,3 +52,60 @@ def test_invalid_regex_is_ignored(ctrl): def test_returns_none_when_no_match(ctrl): mapping = {"student": "student.csv"} assert ctrl.get_key_for_file(mapping, "unknown.csv") is None + + +def _job_named(full_name: str, job_id: int = 42) -> MagicMock: + j = MagicMock() + j.job_id = job_id + j.settings = MagicMock() + j.settings.name = full_name + return j + + +def test_resolve_pipeline_job_exact_match_skips_scan(): + canonical = "edvise_github_sourced_pdp_inference_pipeline" + hit = _job_named(canonical, job_id=7) + + def list_jobs(name=None): + if name == canonical: + return iter([hit]) + return iter([]) + + w = MagicMock() + w.jobs.list.side_effect = list_jobs + + assert _resolve_pipeline_job(w, canonical, "test").job_id == 7 + w.jobs.list.assert_called_once() + + +def test_resolve_pipeline_job_substring_dev_prefix(): + canonical = "edvise_github_sourced_pdp_inference_pipeline" + hit = _job_named(f"[dev vishakh] {canonical}", job_id=11) + + def list_jobs(name=None): + if name is not None: + return iter([]) + return iter([hit]) + + w = MagicMock() + w.jobs.list.side_effect = list_jobs + + assert _resolve_pipeline_job(w, canonical, "test").job_id == 11 + assert w.jobs.list.call_count == 2 + + +def test_resolve_pipeline_job_ambiguous_substring_raises(): + canonical = "edvise_github_sourced_pdp_inference_pipeline" + a = _job_named(f"[dev a] {canonical}", job_id=1) + b = _job_named(f"[dev b] {canonical}", job_id=2) + + def list_jobs(name=None): + if name is not None: + return iter([]) + return iter([a, b]) + + w = MagicMock() + w.jobs.list.side_effect = list_jobs + + with pytest.raises(ValueError, match="Multiple jobs match substring"): + _resolve_pipeline_job(w, canonical, "test") diff --git a/src/webapp/routers/data.py b/src/webapp/routers/data.py index 88092ee2..2fc841b9 100644 --- a/src/webapp/routers/data.py +++ b/src/webapp/routers/data.py @@ -5,7 +5,7 @@ from typing import Annotated, Any, Dict, List, Optional, Tuple, Union, cast from pydantic import BaseModel, Field from fastapi import APIRouter, Depends, HTTPException, status, Response -from sqlalchemy import and_, or_ +from sqlalchemy import and_, false, or_ from sqlalchemy.orm import Session from sqlalchemy.future import select import os @@ -29,6 +29,8 @@ DataSource, get_external_bucket_name, decode_url_piece, + expand_batch_file_name_lookups, + file_name_variants_for_lookup, ) from ..database import ( @@ -749,9 +751,16 @@ def create_batch( inst_id=str_to_uuid(inst_id), created_by=str_to_uuid(current_user.user_id), # type: ignore ) - f_names = [] if not req.file_names else req.file_names + f_names = [] if not req.file_names else list(req.file_names) f_ids = [] if not req.file_ids else strs_to_uuids(req.file_ids) - print(f"File names: {f_names}, File Ids: {f_ids}") + file_match_parts: List[Any] = [] + if f_ids: + file_match_parts.append(FileTable.id.in_(f_ids)) + if f_names: + file_match_parts.append( + FileTable.name.in_(expand_batch_file_name_lookups(f_names)) + ) + file_clause = or_(*file_match_parts) if file_match_parts else false() # Check that the files requested for this batch exists. # Only valid non-sst generated files can be added to a batch at creation time. query_result_files = ( @@ -759,10 +768,7 @@ def create_batch( .execute( select(FileTable).where( and_( - or_( - FileTable.id.in_(f_ids), - FileTable.name.in_(f_names), - ), + file_clause, FileTable.inst_id == str_to_uuid(inst_id), FileTable.valid == True, FileTable.sst_generated == False, @@ -904,12 +910,15 @@ def update_batch( if "file_names" in update_data_req: for f in update_data_req["file_names"]: # Check that the files requested for this batch exists + name_variants = list(file_name_variants_for_lookup(f)) query_result_file = ( local_session.get() .execute( select(FileTable).where( and_( - FileTable.name == f, + FileTable.name.in_(name_variants) + if name_variants + else false(), FileTable.inst_id == str_to_uuid(inst_id), ) ) diff --git a/src/webapp/routers/models.py b/src/webapp/routers/models.py index 02be74ae..06eb0f78 100644 --- a/src/webapp/routers/models.py +++ b/src/webapp/routers/models.py @@ -8,7 +8,11 @@ from sqlalchemy import and_, update, or_ from sqlalchemy.orm import Session from sqlalchemy.future import select -from ..databricks import DatabricksControl, DatabricksInferenceRunRequest +from ..databricks import ( + DatabricksControl, + DatabricksPDPInferenceRunRequest, + DatabricksLegacyInferenceRunRequest, +) from ..utilities import ( has_access_to_inst_or_err, has_full_data_access_or_err, @@ -138,6 +142,9 @@ class InferenceRunRequest(BaseModel): # Note: is_pdp is kept for backward compatibility but is ignored. # PDP status is derived from the institution's pdp_id field. is_pdp: bool = False + # Legacy schools inference parameters (optional passthrough; ignored for PDP) + config_file_name: str | None = None + features_table_name: str | None = None # Model related operations. Or model specific data. @@ -524,11 +531,95 @@ def trigger_inference_run( + str(len(inst_result)), ) inst = inst_result[0][0] - # Check PDP status from institution's pdp_id (ignore req.is_pdp for backward compat) - if not inst.pdp_id: + # Determine institution type: PDP, Edvise, or Legacy + # There are only three options: PDP (pdp_id), Edvise (edvise_id), or Legacy (legacy_id or none) + # Follows the same pattern as validation_helper in data.py + pdp_id = getattr(inst, "pdp_id", None) + edvise_id = getattr(inst, "edvise_id", None) + legacy_id = getattr(inst, "legacy_id", None) + # Defensive check: ensure mutual exclusivity (should not happen if validation works correctly) + if sum(bool(x) for x in (pdp_id, edvise_id, legacy_id)) > 1: + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail="Institution configuration error: cannot have more than one of pdp_id, edvise_id, or legacy_id set", + ) + is_pdp = bool(pdp_id) + is_legacy = bool(legacy_id) + + # Legacy schools inference + if is_legacy: + legacy_model_result = ( + local_session.get() + .execute( + select(ModelTable).where( + and_( + ModelTable.name == model_name, + ModelTable.inst_id == str_to_uuid(inst_id), + ) + ) + ) + .all() + ) + if len(legacy_model_result) != 1: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="Unexpected number of models found: Expected 1, got " + + str(len(legacy_model_result)), + ) + # For legacy schools, we don't need batch validation (config and features table are used instead) + # Omitting names is allowed; pass empty strings so Pydantic accepts the request and the + # Edvise legacy_inference_inputs job can resolve artifacts under silver_volume (same as YAML defaults). + db_req = DatabricksLegacyInferenceRunRequest( + inst_name=inst_result[0][0].name, + model_name=model_name, + config_file_name=req.config_file_name or "", + features_table_name=req.features_table_name or "", + gcp_external_bucket_name=get_external_bucket_name(inst_id), + email=current_user.email or "", + ) + try: + res = databricks_control.run_legacy_inference(db_req) + except Exception as e: + tb = traceback.format_exc() + logging.error(f"Databricks run failure:\n{tb}") + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail=f"Databricks run_legacy_inference error. Error = {str(e)}", + ) from e + triggered_timestamp = datetime.now() + latest_model_version = databricks_control.fetch_model_version( + catalog_name=str(env_vars["CATALOG_NAME"]), + inst_name=inst_result[0][0].name, + model_name=model_name, + ) + job = JobTable( + id=res.job_run_id, + triggered_at=triggered_timestamp, + created_by=str_to_uuid(current_user.user_id), + batch_name=f"{model_name}_{triggered_timestamp}", # Legacy schools don't use batches + model_id=legacy_model_result[0][0].id, + output_valid=False, + model_version=latest_model_version.version, + model_run_id=latest_model_version.run_id, + ) + local_session.get().add(job) + return { + "inst_id": inst_id, + "m_name": model_name, + "run_id": res.job_run_id, + "created_by": current_user.user_id, + "triggered_at": triggered_timestamp, + "batch_name": f"{model_name}_{triggered_timestamp}", + "output_valid": False, + "model_version": latest_model_version.version, + "model_run_id": latest_model_version.run_id, + } + + # PDP inference (existing logic) + if not is_pdp: raise HTTPException( status_code=status.HTTP_501_NOT_IMPLEMENTED, - detail="Currently, only PDP inference is supported.", + detail="Currently, only PDP and Legacy schools inference are supported.", ) query_result = ( local_session.get() @@ -589,7 +680,7 @@ def trigger_inference_run( detail=f"The files in this batch don't conform to the schema configs allowed by this model. For debugging reference - file_schema={inst_file_schemas} and model_schema={schema_configs}", ) # Note to Datakind: In the long-term, this is where you would have a case block or something that would call different types of pipelines. - db_req = DatabricksInferenceRunRequest( + pdp_db_req = DatabricksPDPInferenceRunRequest( inst_name=inst_result[0][0].name, filepath_to_type=convert_files_to_dict(batch_result[0][0].files), model_name=model_name, @@ -598,7 +689,7 @@ def trigger_inference_run( email=cast(str, current_user.email), ) try: - res = databricks_control.run_pdp_inference(db_req) + res = databricks_control.run_pdp_inference(pdp_db_req) except Exception as e: tb = traceback.format_exc() logging.error(f"Databricks run failure:\n{tb}") diff --git a/src/webapp/utilities.py b/src/webapp/utilities.py index 76dcff74..54cd2e14 100644 --- a/src/webapp/utilities.py +++ b/src/webapp/utilities.py @@ -3,7 +3,7 @@ import uuid import re from typing import Annotated, Final, Any, Optional, Tuple, Union -from urllib.parse import unquote +from urllib.parse import unquote_plus from strenum import StrEnum # needed for python pre 3.11 import jwt from fastapi import HTTPException, status, Depends @@ -23,8 +23,33 @@ def decode_url_piece(src: str) -> str: - """Decode encoded URL.""" - return unquote(src) + """Decode a URL path segment the way most clients encode names. + + Uses :func:`urllib.parse.unquote_plus` so ``+`` is treated as a space (common + when clients apply form-style encoding to paths). A literal ``+`` in a name + must be sent as ``%2B``. + """ + return unquote_plus(src) + + +def file_name_variants_for_lookup(name: str) -> set[str]: + """Return spellings to try when matching a stored ``file.name``. + + Accepts common mismatches between spaces and ``+`` (and ``decode_url_piece``-style + decoding) so batch endpoints stay usable even if the client and DB disagree. + """ + n = unquote_plus(name.strip()) + if not n: + return set() + return {n, n.replace("+", " "), n.replace(" ", "+")} + + +def expand_batch_file_name_lookups(names: list[str]) -> list[str]: + """Flatten :func:`file_name_variants_for_lookup` for a SQL ``IN`` clause.""" + expanded: set[str] = set() + for name in names: + expanded |= file_name_variants_for_lookup(name) + return list(expanded) class AccessType(StrEnum): diff --git a/src/webapp/utilities_test.py b/src/webapp/utilities_test.py index 29617c9d..6d7dba64 100644 --- a/src/webapp/utilities_test.py +++ b/src/webapp/utilities_test.py @@ -4,6 +4,9 @@ from fastapi import HTTPException from .utilities import ( + decode_url_piece, + expand_batch_file_name_lookups, + file_name_variants_for_lookup, has_access_to_inst_or_err, has_full_data_access_or_err, has_at_most_one_school_type, @@ -91,3 +94,23 @@ def test_databricksify_inst_name(): with pytest.raises(ValueError) as err: databricksify_inst_name("Northwest (invalid)") assert str(err.value) == "Unexpected character found in Databricks compatible name." + + +def test_decode_url_piece_treats_plus_as_space() -> None: + """Form-style + encoding in paths should decode to spaces.""" + assert decode_url_piece("a+b.csv") == "a b.csv" + assert decode_url_piece("foo%20bar.csv") == "foo bar.csv" + assert decode_url_piece("x%2By.csv") == "x+y.csv" + + +def test_file_name_variants_for_lookup() -> None: + """Batch lookups accept spaces vs. plus spellings.""" + v = file_name_variants_for_lookup("a b.csv") + assert v == {"a b.csv", "a+b.csv"} + assert file_name_variants_for_lookup(" a+b.csv ") == {"a b.csv", "a+b.csv"} + + +def test_expand_batch_file_name_lookups() -> None: + out = set(expand_batch_file_name_lookups(["x y.csv", "p+q.csv"])) + assert "x y.csv" in out and "x+y.csv" in out + assert "p q.csv" in out and "p+q.csv" in out