From 67b59d6f0cd2d642fa20ed4a49466e5ad7b70828 Mon Sep 17 00:00:00 2001 From: Vishakh Pillai Date: Wed, 4 Mar 2026 16:09:57 -0500 Subject: [PATCH 01/10] feat: custom school inference, but need to confirm if custom is the same as legacy --- src/webapp/databricks.py | 86 +++++++++++++++++++++++++++++++++-- src/webapp/routers/models.py | 88 ++++++++++++++++++++++++++++++++++-- 2 files changed, 166 insertions(+), 8 deletions(-) diff --git a/src/webapp/databricks.py b/src/webapp/databricks.py index 5820415e..4ba1fe16 100644 --- a/src/webapp/databricks.py +++ b/src/webapp/databricks.py @@ -36,10 +36,11 @@ # The name of the deployed pipeline in Databricks. Must match directly. PDP_INFERENCE_JOB_NAME = "edvise_github_sourced_pdp_inference_pipeline" +CUSTOM_INFERENCE_JOB_NAME = "edvise_github_sourced_custom_inference_pipeline" -class DatabricksInferenceRunRequest(BaseModel): - """Databricks parameters for an inference run.""" +class DatabricksPDPInferenceRunRequest(BaseModel): + """Databricks parameters for a PDP inference run.""" inst_name: str # Note that the following should be the filepath. @@ -50,6 +51,18 @@ class DatabricksInferenceRunRequest(BaseModel): gcp_external_bucket_name: str +class DatabricksCustomInferenceRunRequest(BaseModel): + """Databricks parameters for a custom schools inference run.""" + + inst_name: str + model_name: str + config_file_name: str + features_table_name: str + # The email where notifications will get sent. + email: str + gcp_external_bucket_name: str + + class DatabricksInferenceRunResponse(BaseModel): """Databricks parameters for an inference run.""" @@ -186,7 +199,7 @@ def setup_new_inst(self, inst_name: str) -> None: # E.g. there is one PDP inference pipeline, so one PDP inference function here. def run_pdp_inference( - self, req: DatabricksInferenceRunRequest + self, req: DatabricksPDPInferenceRunRequest ) -> DatabricksInferenceRunResponse: """Triggers PDP inference Databricks run.""" LOGGER.info(f"Running PDP inference for institution: {req.inst_name}") @@ -264,6 +277,73 @@ def run_pdp_inference( return DatabricksInferenceRunResponse(job_run_id=run_id) + def run_custom_inference( + self, req: DatabricksCustomInferenceRunRequest + ) -> DatabricksInferenceRunResponse: + """Triggers custom schools inference Databricks run.""" + LOGGER.info(f"Running custom inference for institution: {req.inst_name}") + try: + w = WorkspaceClient( + host=databricks_vars["DATABRICKS_HOST_URL"], + google_service_account=gcs_vars["GCP_SERVICE_ACCOUNT_EMAIL"], + ) + LOGGER.info("Successfully created Databricks WorkspaceClient.") + except Exception as e: + LOGGER.exception( + "Failed to create Databricks WorkspaceClient with host: %s and service account: %s", + databricks_vars["DATABRICKS_HOST_URL"], + gcs_vars["GCP_SERVICE_ACCOUNT_EMAIL"], + ) + raise ValueError( + f"run_custom_inference(): Workspace client initialization failed: {e}" + ) + + db_inst_name = databricksify_inst_name(req.inst_name) + pipeline_type = CUSTOM_INFERENCE_JOB_NAME + + try: + job = next(w.jobs.list(name=pipeline_type), None) + if not job or job.job_id is None: + raise ValueError( + f"run_custom_inference(): Job '{pipeline_type}' was not found or has no job_id for '{gcs_vars['GCP_SERVICE_ACCOUNT_EMAIL']}' and '{databricks_vars['DATABRICKS_HOST_URL']}'." + ) + job_id = job.job_id + LOGGER.info(f"Resolved job ID for '{pipeline_type}': {job_id}") + except Exception as e: + LOGGER.exception(f"Job lookup failed for '{pipeline_type}'.") + raise ValueError(f"run_custom_inference(): Failed to find job: {e}") + + try: + run_job: Any = w.jobs.run_now( + job_id, + job_parameters={ + "databricks_institution_name": db_inst_name, + "DB_workspace": databricks_vars[ + "DATABRICKS_WORKSPACE" + ], + "model_name": req.model_name, + "config_file_name": req.config_file_name, + "features_table_name": req.features_table_name, + "gcp_bucket_name": req.gcp_external_bucket_name, + "datakind_notification_email": req.email, + "DK_CC_EMAIL": req.email, + }, + ) + LOGGER.info( + f"Successfully triggered job run. Run ID: {run_job.response.run_id}" + ) + except Exception as e: + LOGGER.exception("Failed to run the custom inference job.") + raise ValueError(f"run_custom_inference(): Job could not be run: {e}") + + if not run_job.response or run_job.response.run_id is None: + raise ValueError("run_custom_inference(): Job did not return a valid run_id.") + + run_id = run_job.response.run_id + LOGGER.info(f"Successfully triggered job run. Run ID: {run_id}") + + return DatabricksInferenceRunResponse(job_run_id=run_id) + def delete_inst(self, inst_name: str) -> None: """Cleanup tasks required on the Databricks side to delete an institution.""" db_inst_name = databricksify_inst_name(inst_name) diff --git a/src/webapp/routers/models.py b/src/webapp/routers/models.py index 02be74ae..9aa8606f 100644 --- a/src/webapp/routers/models.py +++ b/src/webapp/routers/models.py @@ -8,7 +8,11 @@ from sqlalchemy import and_, update, or_ from sqlalchemy.orm import Session from sqlalchemy.future import select -from ..databricks import DatabricksControl, DatabricksInferenceRunRequest +from ..databricks import ( + DatabricksControl, + DatabricksPDPInferenceRunRequest, + DatabricksCustomInferenceRunRequest, +) from ..utilities import ( has_access_to_inst_or_err, has_full_data_access_or_err, @@ -138,6 +142,9 @@ class InferenceRunRequest(BaseModel): # Note: is_pdp is kept for backward compatibility but is ignored. # PDP status is derived from the institution's pdp_id field. is_pdp: bool = False + # Custom schools inference parameters (required for custom schools, ignored for PDP) + config_file_name: str | None = None + features_table_name: str | None = None # Model related operations. Or model specific data. @@ -524,11 +531,82 @@ def trigger_inference_run( + str(len(inst_result)), ) inst = inst_result[0][0] - # Check PDP status from institution's pdp_id (ignore req.is_pdp for backward compat) - if not inst.pdp_id: + # Determine institution type: PDP, Edvise, or Legacy/Custom + # There are only three options: PDP (pdp_id), Edvise (edvise_id), or Legacy/Custom (legacy_id or none) + # Follows the same pattern as validation_helper in data.py + pdp_id = getattr(inst, "pdp_id", None) + edvise_id = getattr(inst, "edvise_id", None) + legacy_id = getattr(inst, "legacy_id", None) + # Defensive check: ensure mutual exclusivity (should not happen if validation works correctly) + if sum(bool(x) for x in (pdp_id, edvise_id, legacy_id)) > 1: + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail="Institution configuration error: cannot have more than one of pdp_id, edvise_id, or legacy_id set", + ) + is_pdp = bool(pdp_id) + is_edvise = bool(edvise_id) + # Legacy and custom are the same thing - both use custom inference pipeline + is_legacy_or_custom = not is_pdp and not is_edvise + + # Legacy/Custom schools inference + if is_legacy_or_custom: + if not req.config_file_name or not req.features_table_name: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="Legacy/Custom schools inference requires config_file_name and features_table_name.", + ) + # For legacy/custom schools, we don't need batch validation (config and features table are used instead) + db_req = DatabricksCustomInferenceRunRequest( + inst_name=inst_result[0][0].name, + model_name=model_name, + config_file_name=req.config_file_name, + features_table_name=req.features_table_name, + gcp_external_bucket_name=get_external_bucket_name(inst_id), + email=cast(str, current_user.email), + ) + try: + res = databricks_control.run_custom_inference(db_req) + except Exception as e: + tb = traceback.format_exc() + logging.error(f"Databricks run failure:\n{tb}") + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail=f"Databricks run_custom_inference error. Error = {str(e)}", + ) from e + triggered_timestamp = datetime.now() + latest_model_version = databricks_control.fetch_model_version( + catalog_name=str(env_vars["CATALOG_NAME"]), + inst_name=inst_result[0][0].name, + model_name=model_name, + ) + job = JobTable( + id=res.job_run_id, + triggered_at=triggered_timestamp, + created_by=str_to_uuid(current_user.user_id), + batch_name=f"{model_name}_{triggered_timestamp}", # Custom schools don't use batches + model_id=query_result[0][0].id, + output_valid=False, + model_version=latest_model_version.version, + model_run_id=latest_model_version.run_id, + ) + local_session.get().add(job) + return { + "inst_id": inst_id, + "m_name": model_name, + "run_id": res.job_run_id, + "created_by": current_user.user_id, + "triggered_at": triggered_timestamp, + "batch_name": f"{model_name}_{triggered_timestamp}", + "output_valid": False, + "model_version": latest_model_version.version, + "model_run_id": latest_model_version.run_id, + } + + # PDP inference (existing logic) + if not is_pdp: raise HTTPException( status_code=status.HTTP_501_NOT_IMPLEMENTED, - detail="Currently, only PDP inference is supported.", + detail="Currently, only PDP and Legacy/Custom schools inference are supported.", ) query_result = ( local_session.get() @@ -589,7 +667,7 @@ def trigger_inference_run( detail=f"The files in this batch don't conform to the schema configs allowed by this model. For debugging reference - file_schema={inst_file_schemas} and model_schema={schema_configs}", ) # Note to Datakind: In the long-term, this is where you would have a case block or something that would call different types of pipelines. - db_req = DatabricksInferenceRunRequest( + db_req = DatabricksPDPInferenceRunRequest( inst_name=inst_result[0][0].name, filepath_to_type=convert_files_to_dict(batch_result[0][0].files), model_name=model_name, From f5f5c057c87e46467230dbc4bb5f4e0d24fea875 Mon Sep 17 00:00:00 2001 From: Vishakh Pillai Date: Tue, 7 Apr 2026 15:02:22 -0400 Subject: [PATCH 02/10] fix: transitioning from 'custom' to 'legacy' --- src/webapp/databricks.py | 28 ++++++++++----------- src/webapp/routers/models.py | 49 ++++++++++++++++++++++++------------ 2 files changed, 47 insertions(+), 30 deletions(-) diff --git a/src/webapp/databricks.py b/src/webapp/databricks.py index 4ba1fe16..350dd39e 100644 --- a/src/webapp/databricks.py +++ b/src/webapp/databricks.py @@ -36,7 +36,7 @@ # The name of the deployed pipeline in Databricks. Must match directly. PDP_INFERENCE_JOB_NAME = "edvise_github_sourced_pdp_inference_pipeline" -CUSTOM_INFERENCE_JOB_NAME = "edvise_github_sourced_custom_inference_pipeline" +LEGACY_INFERENCE_JOB_NAME = "edvise_github_sourced_legacy_inference_pipeline" class DatabricksPDPInferenceRunRequest(BaseModel): @@ -51,8 +51,8 @@ class DatabricksPDPInferenceRunRequest(BaseModel): gcp_external_bucket_name: str -class DatabricksCustomInferenceRunRequest(BaseModel): - """Databricks parameters for a custom schools inference run.""" +class DatabricksLegacyInferenceRunRequest(BaseModel): + """Databricks parameters for a legacy schools inference run.""" inst_name: str model_name: str @@ -277,11 +277,11 @@ def run_pdp_inference( return DatabricksInferenceRunResponse(job_run_id=run_id) - def run_custom_inference( - self, req: DatabricksCustomInferenceRunRequest + def run_legacy_inference( + self, req: DatabricksLegacyInferenceRunRequest ) -> DatabricksInferenceRunResponse: - """Triggers custom schools inference Databricks run.""" - LOGGER.info(f"Running custom inference for institution: {req.inst_name}") + """Triggers legacy schools inference Databricks run.""" + LOGGER.info(f"Running legacy inference for institution: {req.inst_name}") try: w = WorkspaceClient( host=databricks_vars["DATABRICKS_HOST_URL"], @@ -295,23 +295,23 @@ def run_custom_inference( gcs_vars["GCP_SERVICE_ACCOUNT_EMAIL"], ) raise ValueError( - f"run_custom_inference(): Workspace client initialization failed: {e}" + f"run_legacy_inference(): Workspace client initialization failed: {e}" ) db_inst_name = databricksify_inst_name(req.inst_name) - pipeline_type = CUSTOM_INFERENCE_JOB_NAME + pipeline_type = LEGACY_INFERENCE_JOB_NAME try: job = next(w.jobs.list(name=pipeline_type), None) if not job or job.job_id is None: raise ValueError( - f"run_custom_inference(): Job '{pipeline_type}' was not found or has no job_id for '{gcs_vars['GCP_SERVICE_ACCOUNT_EMAIL']}' and '{databricks_vars['DATABRICKS_HOST_URL']}'." + f"run_legacy_inference(): Job '{pipeline_type}' was not found or has no job_id for '{gcs_vars['GCP_SERVICE_ACCOUNT_EMAIL']}' and '{databricks_vars['DATABRICKS_HOST_URL']}'." ) job_id = job.job_id LOGGER.info(f"Resolved job ID for '{pipeline_type}': {job_id}") except Exception as e: LOGGER.exception(f"Job lookup failed for '{pipeline_type}'.") - raise ValueError(f"run_custom_inference(): Failed to find job: {e}") + raise ValueError(f"run_legacy_inference(): Failed to find job: {e}") try: run_job: Any = w.jobs.run_now( @@ -333,11 +333,11 @@ def run_custom_inference( f"Successfully triggered job run. Run ID: {run_job.response.run_id}" ) except Exception as e: - LOGGER.exception("Failed to run the custom inference job.") - raise ValueError(f"run_custom_inference(): Job could not be run: {e}") + LOGGER.exception("Failed to run the legacy inference job.") + raise ValueError(f"run_legacy_inference(): Job could not be run: {e}") if not run_job.response or run_job.response.run_id is None: - raise ValueError("run_custom_inference(): Job did not return a valid run_id.") + raise ValueError("run_legacy_inference(): Job did not return a valid run_id.") run_id = run_job.response.run_id LOGGER.info(f"Successfully triggered job run. Run ID: {run_id}") diff --git a/src/webapp/routers/models.py b/src/webapp/routers/models.py index 9aa8606f..3772d95d 100644 --- a/src/webapp/routers/models.py +++ b/src/webapp/routers/models.py @@ -11,7 +11,7 @@ from ..databricks import ( DatabricksControl, DatabricksPDPInferenceRunRequest, - DatabricksCustomInferenceRunRequest, + DatabricksLegacyInferenceRunRequest, ) from ..utilities import ( has_access_to_inst_or_err, @@ -142,7 +142,7 @@ class InferenceRunRequest(BaseModel): # Note: is_pdp is kept for backward compatibility but is ignored. # PDP status is derived from the institution's pdp_id field. is_pdp: bool = False - # Custom schools inference parameters (required for custom schools, ignored for PDP) + # Legacy schools inference parameters (required for legacy schools, ignored for PDP) config_file_name: str | None = None features_table_name: str | None = None @@ -531,8 +531,8 @@ def trigger_inference_run( + str(len(inst_result)), ) inst = inst_result[0][0] - # Determine institution type: PDP, Edvise, or Legacy/Custom - # There are only three options: PDP (pdp_id), Edvise (edvise_id), or Legacy/Custom (legacy_id or none) + # Determine institution type: PDP, Edvise, or Legacy + # There are only three options: PDP (pdp_id), Edvise (edvise_id), or Legacy (legacy_id or none) # Follows the same pattern as validation_helper in data.py pdp_id = getattr(inst, "pdp_id", None) edvise_id = getattr(inst, "edvise_id", None) @@ -545,18 +545,35 @@ def trigger_inference_run( ) is_pdp = bool(pdp_id) is_edvise = bool(edvise_id) - # Legacy and custom are the same thing - both use custom inference pipeline - is_legacy_or_custom = not is_pdp and not is_edvise + is_legacy = not is_pdp and not is_edvise - # Legacy/Custom schools inference - if is_legacy_or_custom: + # Legacy schools inference + if is_legacy: if not req.config_file_name or not req.features_table_name: raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, - detail="Legacy/Custom schools inference requires config_file_name and features_table_name.", + detail="Legacy schools inference requires config_file_name and features_table_name.", ) - # For legacy/custom schools, we don't need batch validation (config and features table are used instead) - db_req = DatabricksCustomInferenceRunRequest( + legacy_model_result = ( + local_session.get() + .execute( + select(ModelTable).where( + and_( + ModelTable.name == model_name, + ModelTable.inst_id == str_to_uuid(inst_id), + ) + ) + ) + .all() + ) + if len(legacy_model_result) != 1: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="Unexpected number of models found: Expected 1, got " + + str(len(legacy_model_result)), + ) + # For legacy schools, we don't need batch validation (config and features table are used instead) + db_req = DatabricksLegacyInferenceRunRequest( inst_name=inst_result[0][0].name, model_name=model_name, config_file_name=req.config_file_name, @@ -565,13 +582,13 @@ def trigger_inference_run( email=cast(str, current_user.email), ) try: - res = databricks_control.run_custom_inference(db_req) + res = databricks_control.run_legacy_inference(db_req) except Exception as e: tb = traceback.format_exc() logging.error(f"Databricks run failure:\n{tb}") raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail=f"Databricks run_custom_inference error. Error = {str(e)}", + detail=f"Databricks run_legacy_inference error. Error = {str(e)}", ) from e triggered_timestamp = datetime.now() latest_model_version = databricks_control.fetch_model_version( @@ -583,8 +600,8 @@ def trigger_inference_run( id=res.job_run_id, triggered_at=triggered_timestamp, created_by=str_to_uuid(current_user.user_id), - batch_name=f"{model_name}_{triggered_timestamp}", # Custom schools don't use batches - model_id=query_result[0][0].id, + batch_name=f"{model_name}_{triggered_timestamp}", # Legacy schools don't use batches + model_id=legacy_model_result[0][0].id, output_valid=False, model_version=latest_model_version.version, model_run_id=latest_model_version.run_id, @@ -606,7 +623,7 @@ def trigger_inference_run( if not is_pdp: raise HTTPException( status_code=status.HTTP_501_NOT_IMPLEMENTED, - detail="Currently, only PDP and Legacy/Custom schools inference are supported.", + detail="Currently, only PDP and Legacy schools inference are supported.", ) query_result = ( local_session.get() From 97065cc8fde3eddb581158043759628a823f5f5c Mon Sep 17 00:00:00 2001 From: Vishakh Pillai Date: Tue, 28 Apr 2026 12:04:47 -0400 Subject: [PATCH 03/10] fix: remove validation of job parameters, handled already through edvise --- src/webapp/routers/models.py | 9 ++------- 1 file changed, 2 insertions(+), 7 deletions(-) diff --git a/src/webapp/routers/models.py b/src/webapp/routers/models.py index 3772d95d..b13760d1 100644 --- a/src/webapp/routers/models.py +++ b/src/webapp/routers/models.py @@ -142,7 +142,7 @@ class InferenceRunRequest(BaseModel): # Note: is_pdp is kept for backward compatibility but is ignored. # PDP status is derived from the institution's pdp_id field. is_pdp: bool = False - # Legacy schools inference parameters (required for legacy schools, ignored for PDP) + # Legacy schools inference parameters (optional passthrough; ignored for PDP) config_file_name: str | None = None features_table_name: str | None = None @@ -545,15 +545,10 @@ def trigger_inference_run( ) is_pdp = bool(pdp_id) is_edvise = bool(edvise_id) - is_legacy = not is_pdp and not is_edvise + is_legacy = bool(legacy_id) # Legacy schools inference if is_legacy: - if not req.config_file_name or not req.features_table_name: - raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, - detail="Legacy schools inference requires config_file_name and features_table_name.", - ) legacy_model_result = ( local_session.get() .execute( From acc294b4644f2f43b526925643b87dc6d4158bde Mon Sep 17 00:00:00 2001 From: Vishakh Pillai Date: Tue, 28 Apr 2026 12:27:31 -0400 Subject: [PATCH 04/10] fix: run request still requires str values, defaulting to empty string --- src/webapp/routers/models.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/webapp/routers/models.py b/src/webapp/routers/models.py index b13760d1..1b1b149c 100644 --- a/src/webapp/routers/models.py +++ b/src/webapp/routers/models.py @@ -568,11 +568,13 @@ def trigger_inference_run( + str(len(legacy_model_result)), ) # For legacy schools, we don't need batch validation (config and features table are used instead) + # Omitting names is allowed; pass empty strings so Pydantic accepts the request and the + # Edvise legacy_inference_inputs job can resolve artifacts under silver_volume (same as YAML defaults). db_req = DatabricksLegacyInferenceRunRequest( inst_name=inst_result[0][0].name, model_name=model_name, - config_file_name=req.config_file_name, - features_table_name=req.features_table_name, + config_file_name=req.config_file_name or "", + features_table_name=req.features_table_name or "", gcp_external_bucket_name=get_external_bucket_name(inst_id), email=cast(str, current_user.email), ) From c023b62c09e086a0639a6da9ea86d728bbadd5fb Mon Sep 17 00:00:00 2001 From: Vishakh Pillai Date: Tue, 28 Apr 2026 12:38:08 -0400 Subject: [PATCH 05/10] fix: still getting pydantic error --- src/webapp/databricks.py | 14 ++++++++++---- src/webapp/routers/models.py | 2 +- 2 files changed, 11 insertions(+), 5 deletions(-) diff --git a/src/webapp/databricks.py b/src/webapp/databricks.py index e8ea110c..f5990d37 100644 --- a/src/webapp/databricks.py +++ b/src/webapp/databricks.py @@ -2,7 +2,7 @@ import os import logging -from pydantic import BaseModel +from pydantic import BaseModel, field_validator from databricks.sdk import WorkspaceClient from databricks.sdk.service import catalog from databricks.sdk.service.sql import ( @@ -53,12 +53,18 @@ class DatabricksLegacyInferenceRunRequest(BaseModel): inst_name: str model_name: str - config_file_name: str - features_table_name: str + config_file_name: str = "" + features_table_name: str = "" # The email where notifications will get sent. - email: str + email: str = "" gcp_external_bucket_name: str + @field_validator("config_file_name", "features_table_name", "email", mode="before") + @classmethod + def _none_to_empty_str(cls, v: object) -> object: + """Allow callers to omit or pass null; Databricks job treats empty like YAML defaults.""" + return "" if v is None else v + class DatabricksInferenceRunResponse(BaseModel): """Databricks parameters for an inference run.""" diff --git a/src/webapp/routers/models.py b/src/webapp/routers/models.py index 1b1b149c..b3e75c4f 100644 --- a/src/webapp/routers/models.py +++ b/src/webapp/routers/models.py @@ -576,7 +576,7 @@ def trigger_inference_run( config_file_name=req.config_file_name or "", features_table_name=req.features_table_name or "", gcp_external_bucket_name=get_external_bucket_name(inst_id), - email=cast(str, current_user.email), + email=current_user.email or "", ) try: res = databricks_control.run_legacy_inference(db_req) From c35dfa8581800da71f4d03178300ca75c2b0cda5 Mon Sep 17 00:00:00 2001 From: Vishakh Pillai Date: Tue, 28 Apr 2026 14:45:24 -0400 Subject: [PATCH 06/10] feat: using substring matching to find legacy job since i have it deployed under my name because of target==dev --- src/webapp/databricks.py | 79 +++++++++++++++++++++++++++++------ src/webapp/databricks_test.py | 60 +++++++++++++++++++++++++- 2 files changed, 125 insertions(+), 14 deletions(-) diff --git a/src/webapp/databricks.py b/src/webapp/databricks.py index f5990d37..85674975 100644 --- a/src/webapp/databricks.py +++ b/src/webapp/databricks.py @@ -31,11 +31,72 @@ # List of data medallion levels MEDALLION_LEVELS = ["silver", "gold", "bronze"] -# The name of the deployed pipeline in Databricks. Must match directly. +# The name of the deployed pipeline in Databricks. Must match the job's `name` in that workspace. +# Override with LEGACY_INFERENCE_JOB_NAME (and PDP_INFERENCE_JOB_NAME) when dev/staging deploy +# uses a different bundle target or a stub job that matches the same parameters. PDP_INFERENCE_JOB_NAME = "edvise_github_sourced_pdp_inference_pipeline" LEGACY_INFERENCE_JOB_NAME = "edvise_github_sourced_legacy_inference_pipeline" +def _pdp_inference_job_name() -> str: + name = os.environ.get("PDP_INFERENCE_JOB_NAME", "").strip() + return name or PDP_INFERENCE_JOB_NAME + + +def _legacy_inference_job_name() -> str: + name = os.environ.get("LEGACY_INFERENCE_JOB_NAME", "").strip() + return name or LEGACY_INFERENCE_JOB_NAME + + +def _resolve_pipeline_job(w: Any, pipeline_type: str, caller_label: str) -> Any: + """Find a job by exact name, else by unique substring match on display name. + + Development bundles often prefix job names (e.g. ``[dev vishakh] edvise_...``) while + the API passes the canonical base name. Optional env ``PDP_INFERENCE_JOB_NAME`` / + ``LEGACY_INFERENCE_JOB_NAME`` still wins when set to the full exact name. + """ + job = next(w.jobs.list(name=pipeline_type), None) + if job is not None and getattr(job, "job_id", None) is not None: + LOGGER.info("%s: resolved job by exact name %r (job_id=%s)", caller_label, pipeline_type, job.job_id) + return job + + matches: list[tuple[str, Any]] = [] + for j in w.jobs.list(): + settings = getattr(j, "settings", None) + name = getattr(settings, "name", None) if settings is not None else None + if not name: + continue + if pipeline_type in name: + matches.append((name, j)) + + if len(matches) == 1: + picked_name, picked = matches[0] + if getattr(picked, "job_id", None) is None: + raise ValueError( + f"{caller_label}: Job name {picked_name!r} matched substring {pipeline_type!r} but has no job_id." + ) + LOGGER.info( + "%s: resolved job by substring %r -> display name %r (job_id=%s)", + caller_label, + pipeline_type, + picked_name, + picked.job_id, + ) + return picked + + if len(matches) > 1: + names = [n for n, _ in matches] + raise ValueError( + f"{caller_label}: Multiple jobs match substring {pipeline_type!r}: {names}. " + "Set PDP_INFERENCE_JOB_NAME or LEGACY_INFERENCE_JOB_NAME to the full job name." + ) + + raise ValueError( + f"{caller_label}: Job {pipeline_type!r} was not found (exact name or unique substring of settings.name) " + f"for '{gcs_vars['GCP_SERVICE_ACCOUNT_EMAIL']}' and '{databricks_vars['DATABRICKS_HOST_URL']}'." + ) + + class DatabricksPDPInferenceRunRequest(BaseModel): """Databricks parameters for a PDP inference run.""" @@ -232,14 +293,10 @@ def run_pdp_inference( ) db_inst_name = databricksify_inst_name(req.inst_name) - pipeline_type = PDP_INFERENCE_JOB_NAME + pipeline_type = _pdp_inference_job_name() try: - job = next(w.jobs.list(name=pipeline_type), None) - if not job or job.job_id is None: - raise ValueError( - f"run_pdp_inference(): Job '{pipeline_type}' was not found or has no job_id for '{gcs_vars['GCP_SERVICE_ACCOUNT_EMAIL']}' and '{databricks_vars['DATABRICKS_HOST_URL']}'." - ) + job = _resolve_pipeline_job(w, pipeline_type, "run_pdp_inference") job_id = job.job_id LOGGER.info(f"Resolved job ID for '{pipeline_type}': {job_id}") except Exception as e: @@ -302,14 +359,10 @@ def run_legacy_inference( ) db_inst_name = databricksify_inst_name(req.inst_name) - pipeline_type = LEGACY_INFERENCE_JOB_NAME + pipeline_type = _legacy_inference_job_name() try: - job = next(w.jobs.list(name=pipeline_type), None) - if not job or job.job_id is None: - raise ValueError( - f"run_legacy_inference(): Job '{pipeline_type}' was not found or has no job_id for '{gcs_vars['GCP_SERVICE_ACCOUNT_EMAIL']}' and '{databricks_vars['DATABRICKS_HOST_URL']}'." - ) + job = _resolve_pipeline_job(w, pipeline_type, "run_legacy_inference") job_id = job.job_id LOGGER.info(f"Resolved job ID for '{pipeline_type}': {job_id}") except Exception as e: diff --git a/src/webapp/databricks_test.py b/src/webapp/databricks_test.py index 37fa2855..025e741e 100644 --- a/src/webapp/databricks_test.py +++ b/src/webapp/databricks_test.py @@ -1,6 +1,7 @@ import pytest +from unittest.mock import MagicMock -from .databricks import DatabricksControl +from .databricks import DatabricksControl, _resolve_pipeline_job @pytest.fixture @@ -51,3 +52,60 @@ def test_invalid_regex_is_ignored(ctrl): def test_returns_none_when_no_match(ctrl): mapping = {"student": "student.csv"} assert ctrl.get_key_for_file(mapping, "unknown.csv") is None + + +def _job_named(full_name: str, job_id: int = 42) -> MagicMock: + j = MagicMock() + j.job_id = job_id + j.settings = MagicMock() + j.settings.name = full_name + return j + + +def test_resolve_pipeline_job_exact_match_skips_scan(): + canonical = "edvise_github_sourced_pdp_inference_pipeline" + hit = _job_named(canonical, job_id=7) + + def list_jobs(name=None): + if name == canonical: + return iter([hit]) + return iter([]) + + w = MagicMock() + w.jobs.list.side_effect = list_jobs + + assert _resolve_pipeline_job(w, canonical, "test").job_id == 7 + w.jobs.list.assert_called_once() + + +def test_resolve_pipeline_job_substring_dev_prefix(): + canonical = "edvise_github_sourced_pdp_inference_pipeline" + hit = _job_named(f"[dev vishakh] {canonical}", job_id=11) + + def list_jobs(name=None): + if name is not None: + return iter([]) + return iter([hit]) + + w = MagicMock() + w.jobs.list.side_effect = list_jobs + + assert _resolve_pipeline_job(w, canonical, "test").job_id == 11 + assert w.jobs.list.call_count == 2 + + +def test_resolve_pipeline_job_ambiguous_substring_raises(): + canonical = "edvise_github_sourced_pdp_inference_pipeline" + a = _job_named(f"[dev a] {canonical}", job_id=1) + b = _job_named(f"[dev b] {canonical}", job_id=2) + + def list_jobs(name=None): + if name is not None: + return iter([]) + return iter([a, b]) + + w = MagicMock() + w.jobs.list.side_effect = list_jobs + + with pytest.raises(ValueError, match="Multiple jobs match substring"): + _resolve_pipeline_job(w, canonical, "test") From 7f05ad4fb87b0fa8b366920e158f957dfdb65fef Mon Sep 17 00:00:00 2001 From: Vishakh Pillai Date: Tue, 5 May 2026 13:43:24 -0400 Subject: [PATCH 07/10] fix: style --- src/webapp/routers/models.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/webapp/routers/models.py b/src/webapp/routers/models.py index b3e75c4f..52c2923c 100644 --- a/src/webapp/routers/models.py +++ b/src/webapp/routers/models.py @@ -681,7 +681,7 @@ def trigger_inference_run( detail=f"The files in this batch don't conform to the schema configs allowed by this model. For debugging reference - file_schema={inst_file_schemas} and model_schema={schema_configs}", ) # Note to Datakind: In the long-term, this is where you would have a case block or something that would call different types of pipelines. - db_req = DatabricksPDPInferenceRunRequest( + pdp_db_req = DatabricksPDPInferenceRunRequest( inst_name=inst_result[0][0].name, filepath_to_type=convert_files_to_dict(batch_result[0][0].files), model_name=model_name, @@ -690,7 +690,7 @@ def trigger_inference_run( email=cast(str, current_user.email), ) try: - res = databricks_control.run_pdp_inference(db_req) + res = databricks_control.run_pdp_inference(pdp_db_req) except Exception as e: tb = traceback.format_exc() logging.error(f"Databricks run failure:\n{tb}") From 0539497da6fdd9fdeef0025a8cced1a476d7ed94 Mon Sep 17 00:00:00 2001 From: Vishakh Pillai Date: Tue, 5 May 2026 13:47:36 -0400 Subject: [PATCH 08/10] fix: style --- src/webapp/routers/models.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/webapp/routers/models.py b/src/webapp/routers/models.py index 52c2923c..06eb0f78 100644 --- a/src/webapp/routers/models.py +++ b/src/webapp/routers/models.py @@ -544,7 +544,6 @@ def trigger_inference_run( detail="Institution configuration error: cannot have more than one of pdp_id, edvise_id, or legacy_id set", ) is_pdp = bool(pdp_id) - is_edvise = bool(edvise_id) is_legacy = bool(legacy_id) # Legacy schools inference From f64b6ddd45ec0cd97508b459c288e4dcfe6f1fd0 Mon Sep 17 00:00:00 2001 From: Vishakh Pillai Date: Tue, 5 May 2026 13:59:12 -0400 Subject: [PATCH 09/10] fix: style --- src/webapp/databricks.py | 15 ++++++++++----- 1 file changed, 10 insertions(+), 5 deletions(-) diff --git a/src/webapp/databricks.py b/src/webapp/databricks.py index 85674975..75f417c7 100644 --- a/src/webapp/databricks.py +++ b/src/webapp/databricks.py @@ -57,7 +57,12 @@ def _resolve_pipeline_job(w: Any, pipeline_type: str, caller_label: str) -> Any: """ job = next(w.jobs.list(name=pipeline_type), None) if job is not None and getattr(job, "job_id", None) is not None: - LOGGER.info("%s: resolved job by exact name %r (job_id=%s)", caller_label, pipeline_type, job.job_id) + LOGGER.info( + "%s: resolved job by exact name %r (job_id=%s)", + caller_label, + pipeline_type, + job.job_id, + ) return job matches: list[tuple[str, Any]] = [] @@ -374,9 +379,7 @@ def run_legacy_inference( job_id, job_parameters={ "databricks_institution_name": db_inst_name, - "DB_workspace": databricks_vars[ - "DATABRICKS_WORKSPACE" - ], + "DB_workspace": databricks_vars["DATABRICKS_WORKSPACE"], "model_name": req.model_name, "config_file_name": req.config_file_name, "features_table_name": req.features_table_name, @@ -393,7 +396,9 @@ def run_legacy_inference( raise ValueError(f"run_legacy_inference(): Job could not be run: {e}") if not run_job.response or run_job.response.run_id is None: - raise ValueError("run_legacy_inference(): Job did not return a valid run_id.") + raise ValueError( + "run_legacy_inference(): Job did not return a valid run_id." + ) run_id = run_job.response.run_id LOGGER.info(f"Successfully triggered job run. Run ID: {run_id}") From d2ee116c3f40b92bb267d74d76b8116c65d93505 Mon Sep 17 00:00:00 2001 From: Vishakh Pillai Date: Tue, 5 May 2026 15:30:56 -0400 Subject: [PATCH 10/10] fix: making batch file name more robust so we don't run into decoding issues --- src/webapp/routers/data.py | 25 +++++++++++++++++-------- src/webapp/utilities.py | 31 ++++++++++++++++++++++++++++--- src/webapp/utilities_test.py | 23 +++++++++++++++++++++++ 3 files changed, 68 insertions(+), 11 deletions(-) diff --git a/src/webapp/routers/data.py b/src/webapp/routers/data.py index 88092ee2..2fc841b9 100644 --- a/src/webapp/routers/data.py +++ b/src/webapp/routers/data.py @@ -5,7 +5,7 @@ from typing import Annotated, Any, Dict, List, Optional, Tuple, Union, cast from pydantic import BaseModel, Field from fastapi import APIRouter, Depends, HTTPException, status, Response -from sqlalchemy import and_, or_ +from sqlalchemy import and_, false, or_ from sqlalchemy.orm import Session from sqlalchemy.future import select import os @@ -29,6 +29,8 @@ DataSource, get_external_bucket_name, decode_url_piece, + expand_batch_file_name_lookups, + file_name_variants_for_lookup, ) from ..database import ( @@ -749,9 +751,16 @@ def create_batch( inst_id=str_to_uuid(inst_id), created_by=str_to_uuid(current_user.user_id), # type: ignore ) - f_names = [] if not req.file_names else req.file_names + f_names = [] if not req.file_names else list(req.file_names) f_ids = [] if not req.file_ids else strs_to_uuids(req.file_ids) - print(f"File names: {f_names}, File Ids: {f_ids}") + file_match_parts: List[Any] = [] + if f_ids: + file_match_parts.append(FileTable.id.in_(f_ids)) + if f_names: + file_match_parts.append( + FileTable.name.in_(expand_batch_file_name_lookups(f_names)) + ) + file_clause = or_(*file_match_parts) if file_match_parts else false() # Check that the files requested for this batch exists. # Only valid non-sst generated files can be added to a batch at creation time. query_result_files = ( @@ -759,10 +768,7 @@ def create_batch( .execute( select(FileTable).where( and_( - or_( - FileTable.id.in_(f_ids), - FileTable.name.in_(f_names), - ), + file_clause, FileTable.inst_id == str_to_uuid(inst_id), FileTable.valid == True, FileTable.sst_generated == False, @@ -904,12 +910,15 @@ def update_batch( if "file_names" in update_data_req: for f in update_data_req["file_names"]: # Check that the files requested for this batch exists + name_variants = list(file_name_variants_for_lookup(f)) query_result_file = ( local_session.get() .execute( select(FileTable).where( and_( - FileTable.name == f, + FileTable.name.in_(name_variants) + if name_variants + else false(), FileTable.inst_id == str_to_uuid(inst_id), ) ) diff --git a/src/webapp/utilities.py b/src/webapp/utilities.py index 76dcff74..54cd2e14 100644 --- a/src/webapp/utilities.py +++ b/src/webapp/utilities.py @@ -3,7 +3,7 @@ import uuid import re from typing import Annotated, Final, Any, Optional, Tuple, Union -from urllib.parse import unquote +from urllib.parse import unquote_plus from strenum import StrEnum # needed for python pre 3.11 import jwt from fastapi import HTTPException, status, Depends @@ -23,8 +23,33 @@ def decode_url_piece(src: str) -> str: - """Decode encoded URL.""" - return unquote(src) + """Decode a URL path segment the way most clients encode names. + + Uses :func:`urllib.parse.unquote_plus` so ``+`` is treated as a space (common + when clients apply form-style encoding to paths). A literal ``+`` in a name + must be sent as ``%2B``. + """ + return unquote_plus(src) + + +def file_name_variants_for_lookup(name: str) -> set[str]: + """Return spellings to try when matching a stored ``file.name``. + + Accepts common mismatches between spaces and ``+`` (and ``decode_url_piece``-style + decoding) so batch endpoints stay usable even if the client and DB disagree. + """ + n = unquote_plus(name.strip()) + if not n: + return set() + return {n, n.replace("+", " "), n.replace(" ", "+")} + + +def expand_batch_file_name_lookups(names: list[str]) -> list[str]: + """Flatten :func:`file_name_variants_for_lookup` for a SQL ``IN`` clause.""" + expanded: set[str] = set() + for name in names: + expanded |= file_name_variants_for_lookup(name) + return list(expanded) class AccessType(StrEnum): diff --git a/src/webapp/utilities_test.py b/src/webapp/utilities_test.py index 29617c9d..6d7dba64 100644 --- a/src/webapp/utilities_test.py +++ b/src/webapp/utilities_test.py @@ -4,6 +4,9 @@ from fastapi import HTTPException from .utilities import ( + decode_url_piece, + expand_batch_file_name_lookups, + file_name_variants_for_lookup, has_access_to_inst_or_err, has_full_data_access_or_err, has_at_most_one_school_type, @@ -91,3 +94,23 @@ def test_databricksify_inst_name(): with pytest.raises(ValueError) as err: databricksify_inst_name("Northwest (invalid)") assert str(err.value) == "Unexpected character found in Databricks compatible name." + + +def test_decode_url_piece_treats_plus_as_space() -> None: + """Form-style + encoding in paths should decode to spaces.""" + assert decode_url_piece("a+b.csv") == "a b.csv" + assert decode_url_piece("foo%20bar.csv") == "foo bar.csv" + assert decode_url_piece("x%2By.csv") == "x+y.csv" + + +def test_file_name_variants_for_lookup() -> None: + """Batch lookups accept spaces vs. plus spellings.""" + v = file_name_variants_for_lookup("a b.csv") + assert v == {"a b.csv", "a+b.csv"} + assert file_name_variants_for_lookup(" a+b.csv ") == {"a b.csv", "a+b.csv"} + + +def test_expand_batch_file_name_lookups() -> None: + out = set(expand_batch_file_name_lookups(["x y.csv", "p+q.csv"])) + assert "x y.csv" in out and "x+y.csv" in out + assert "p q.csv" in out and "p+q.csv" in out