Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
166 changes: 155 additions & 11 deletions src/webapp/databricks.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import os
import logging
from pydantic import BaseModel
from pydantic import BaseModel, field_validator
from databricks.sdk import WorkspaceClient
from databricks.sdk.service import catalog
from databricks.sdk.service.sql import (
Expand Down Expand Up @@ -31,12 +31,79 @@
# List of data medallion levels
MEDALLION_LEVELS = ["silver", "gold", "bronze"]

# The name of the deployed pipeline in Databricks. Must match directly.
# The name of the deployed pipeline in Databricks. Must match the job's `name` in that workspace.
# Override via the PDP_INFERENCE_JOB_NAME / LEGACY_INFERENCE_JOB_NAME environment variables when a
# dev/staging deploy uses a different bundle target or a stub job that takes the same parameters.
PDP_INFERENCE_JOB_NAME = "edvise_github_sourced_pdp_inference_pipeline"
LEGACY_INFERENCE_JOB_NAME = "edvise_github_sourced_legacy_inference_pipeline"


class DatabricksInferenceRunRequest(BaseModel):
"""Databricks parameters for an inference run."""
def _pdp_inference_job_name() -> str:
    """Return the Databricks job name to use for PDP inference.

    A non-blank ``PDP_INFERENCE_JOB_NAME`` environment variable (with
    surrounding whitespace stripped) takes precedence; otherwise the
    module-level ``PDP_INFERENCE_JOB_NAME`` constant is returned.
    """
    override = os.getenv("PDP_INFERENCE_JOB_NAME", "").strip()
    if override:
        return override
    return PDP_INFERENCE_JOB_NAME


def _legacy_inference_job_name() -> str:
    """Return the Databricks job name to use for legacy inference.

    A non-blank ``LEGACY_INFERENCE_JOB_NAME`` environment variable (with
    surrounding whitespace stripped) takes precedence; otherwise the
    module-level ``LEGACY_INFERENCE_JOB_NAME`` constant is returned.
    """
    override = os.getenv("LEGACY_INFERENCE_JOB_NAME", "").strip()
    if override:
        return override
    return LEGACY_INFERENCE_JOB_NAME


def _resolve_pipeline_job(w: Any, pipeline_type: str, caller_label: str) -> Any:
    """Look up the Databricks job to run for ``pipeline_type``.

    Resolution order:

    1. The first job returned by ``w.jobs.list(name=pipeline_type)``, provided
       it carries a ``job_id``.
    2. Otherwise, the single job among ``w.jobs.list()`` whose
       ``settings.name`` contains ``pipeline_type`` as a substring. This covers
       development bundles that prefix job names
       (e.g. ``[dev vishakh] edvise_...``) while callers pass the base name.

    Args:
        w: Workspace client exposing ``jobs.list``.
        pipeline_type: Job name (or name fragment) to look for.
        caller_label: Prefix identifying the caller in log and error messages.

    Returns:
        The matching job object.

    Raises:
        ValueError: If nothing matches, if several jobs match the substring,
            or if the single substring match has no ``job_id``.
    """
    # Step 1: name-filtered listing; take the first hit if it is usable.
    exact = next(w.jobs.list(name=pipeline_type), None)
    if getattr(exact, "job_id", None) is not None:
        LOGGER.info(
            "%s: resolved job by exact name %r (job_id=%s)",
            caller_label,
            pipeline_type,
            exact.job_id,
        )
        return exact

    # Step 2: scan every job and keep those whose display name contains the
    # requested name. Jobs without settings or without a name are ignored.
    candidates: list[tuple[str, Any]] = []
    for listed in w.jobs.list():
        display_name = getattr(getattr(listed, "settings", None), "name", None)
        if display_name and pipeline_type in display_name:
            candidates.append((display_name, listed))

    if not candidates:
        raise ValueError(
            f"{caller_label}: Job {pipeline_type!r} was not found (exact name or unique substring of settings.name) "
            f"for '{gcs_vars['GCP_SERVICE_ACCOUNT_EMAIL']}' and '{databricks_vars['DATABRICKS_HOST_URL']}'."
        )

    if len(candidates) > 1:
        # Ambiguity is an error: never guess which of several jobs to run.
        names = [display for display, _ in candidates]
        raise ValueError(
            f"{caller_label}: Multiple jobs match substring {pipeline_type!r}: {names}. "
            "Set PDP_INFERENCE_JOB_NAME or LEGACY_INFERENCE_JOB_NAME to the full job name."
        )

    picked_name, picked = candidates[0]
    if getattr(picked, "job_id", None) is None:
        raise ValueError(
            f"{caller_label}: Job name {picked_name!r} matched substring {pipeline_type!r} but has no job_id."
        )

    LOGGER.info(
        "%s: resolved job by substring %r -> display name %r (job_id=%s)",
        caller_label,
        pipeline_type,
        picked_name,
        picked.job_id,
    )
    return picked


class DatabricksPDPInferenceRunRequest(BaseModel):
"""Databricks parameters for a PDP inference run."""

inst_name: str
# Note that the following should be the filepath.
Expand All @@ -47,6 +114,24 @@ class DatabricksInferenceRunRequest(BaseModel):
gcp_external_bucket_name: str


class DatabricksLegacyInferenceRunRequest(BaseModel):
    """Databricks parameters for a legacy schools inference run."""

    # Required: institution and model to run inference for.
    inst_name: str
    model_name: str
    # Optional job parameters; an empty string is sent when not provided.
    config_file_name: str = ""
    features_table_name: str = ""
    # Address that notifications are sent to; optional.
    email: str = ""
    # Required: name of the external GCP bucket passed to the job.
    gcp_external_bucket_name: str

    @field_validator("config_file_name", "features_table_name", "email", mode="before")
    @classmethod
    def _none_to_empty_str(cls, value: object) -> object:
        """Turn an explicit ``None`` into ``""`` before field validation.

        Lets callers pass null for the optional string fields; any other value
        is handed to the normal validation unchanged.
        """
        if value is None:
            return ""
        return value


class DatabricksInferenceRunResponse(BaseModel):
"""Databricks parameters for an inference run."""

Expand Down Expand Up @@ -183,7 +268,7 @@ def setup_new_inst(self, inst_name: str) -> None:
# E.g. there is one PDP inference pipeline, so one PDP inference function here.

def run_pdp_inference(
self, req: DatabricksInferenceRunRequest
self, req: DatabricksPDPInferenceRunRequest
) -> DatabricksInferenceRunResponse:
"""Triggers PDP inference Databricks run."""
LOGGER.info(f"Running PDP inference for institution: {req.inst_name}")
Expand Down Expand Up @@ -213,14 +298,10 @@ def run_pdp_inference(
)

db_inst_name = databricksify_inst_name(req.inst_name)
pipeline_type = PDP_INFERENCE_JOB_NAME
pipeline_type = _pdp_inference_job_name()

try:
job = next(w.jobs.list(name=pipeline_type), None)
if not job or job.job_id is None:
raise ValueError(
f"run_pdp_inference(): Job '{pipeline_type}' was not found or has no job_id for '{gcs_vars['GCP_SERVICE_ACCOUNT_EMAIL']}' and '{databricks_vars['DATABRICKS_HOST_URL']}'."
)
job = _resolve_pipeline_job(w, pipeline_type, "run_pdp_inference")
job_id = job.job_id
LOGGER.info(f"Resolved job ID for '{pipeline_type}': {job_id}")
except Exception as e:
Expand Down Expand Up @@ -261,6 +342,69 @@ def run_pdp_inference(

return DatabricksInferenceRunResponse(job_run_id=run_id)

def run_legacy_inference(
    self, req: DatabricksLegacyInferenceRunRequest
) -> DatabricksInferenceRunResponse:
    """Triggers legacy schools inference Databricks run.

    Steps: create a ``WorkspaceClient``, resolve the legacy inference job
    (name from ``_legacy_inference_job_name()``, lookup via
    ``_resolve_pipeline_job``), then start it with ``jobs.run_now``.

    Args:
        req: Institution, model and optional job parameters for the run.

    Returns:
        A ``DatabricksInferenceRunResponse`` carrying the triggered run's id.

    Raises:
        ValueError: If the workspace client cannot be created, the job cannot
            be found, the run cannot be started, or the run response has no
            ``run_id``. The underlying exception is chained as ``__cause__``.
    """
    LOGGER.info(f"Running legacy inference for institution: {req.inst_name}")
    try:
        w = WorkspaceClient(
            host=databricks_vars["DATABRICKS_HOST_URL"],
            google_service_account=gcs_vars["GCP_SERVICE_ACCOUNT_EMAIL"],
        )
        LOGGER.info("Successfully created Databricks WorkspaceClient.")
    except Exception as e:
        LOGGER.exception(
            "Failed to create Databricks WorkspaceClient with host: %s and service account: %s",
            databricks_vars["DATABRICKS_HOST_URL"],
            gcs_vars["GCP_SERVICE_ACCOUNT_EMAIL"],
        )
        raise ValueError(
            f"run_legacy_inference(): Workspace client initialization failed: {e}"
        ) from e

    db_inst_name = databricksify_inst_name(req.inst_name)
    pipeline_type = _legacy_inference_job_name()

    try:
        job = _resolve_pipeline_job(w, pipeline_type, "run_legacy_inference")
        job_id = job.job_id
        LOGGER.info(f"Resolved job ID for '{pipeline_type}': {job_id}")
    except Exception as e:
        LOGGER.exception("Job lookup failed for '%s'.", pipeline_type)
        raise ValueError(f"run_legacy_inference(): Failed to find job: {e}") from e

    try:
        # Only the API call lives in this try block. The response is inspected
        # afterwards, so a missing response is reported as "no valid run_id"
        # and not misreported as a failure to start the job.
        run_job: Any = w.jobs.run_now(
            job_id,
            job_parameters={
                "databricks_institution_name": db_inst_name,
                "DB_workspace": databricks_vars["DATABRICKS_WORKSPACE"],
                "model_name": req.model_name,
                "config_file_name": req.config_file_name,
                "features_table_name": req.features_table_name,
                "gcp_bucket_name": req.gcp_external_bucket_name,
                # The same address is passed for both notification parameters.
                "datakind_notification_email": req.email,
                "DK_CC_EMAIL": req.email,
            },
        )
    except Exception as e:
        LOGGER.exception("Failed to run the legacy inference job.")
        raise ValueError(
            f"run_legacy_inference(): Job could not be run: {e}"
        ) from e

    if not run_job.response or run_job.response.run_id is None:
        raise ValueError(
            "run_legacy_inference(): Job did not return a valid run_id."
        )

    run_id = run_job.response.run_id
    # Logged once, after the run id has been validated.
    LOGGER.info(f"Successfully triggered job run. Run ID: {run_id}")

    return DatabricksInferenceRunResponse(job_run_id=run_id)

def delete_inst(self, inst_name: str) -> None:
"""Cleanup tasks required on the Databricks side to delete an institution."""
db_inst_name = databricksify_inst_name(inst_name)
Expand Down
60 changes: 59 additions & 1 deletion src/webapp/databricks_test.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import pytest
from unittest.mock import MagicMock

from .databricks import DatabricksControl
from .databricks import DatabricksControl, _resolve_pipeline_job


@pytest.fixture
Expand Down Expand Up @@ -51,3 +52,60 @@ def test_invalid_regex_is_ignored(ctrl):
def test_returns_none_when_no_match(ctrl):
mapping = {"student": "student.csv"}
assert ctrl.get_key_for_file(mapping, "unknown.csv") is None


def _job_named(full_name: str, job_id: int = 42) -> MagicMock:
    """Build a mock job exposing ``job_id`` and ``settings.name``."""
    fake_job = MagicMock()
    fake_settings = MagicMock()
    # ``name`` is assigned after construction: passing it to the MagicMock
    # constructor would name the mock itself and not set the attribute.
    fake_settings.name = full_name
    fake_job.settings = fake_settings
    fake_job.job_id = job_id
    return fake_job


def test_resolve_pipeline_job_exact_match_skips_scan():
    """An exact-name hit is returned after a single ``jobs.list`` call."""
    canonical = "edvise_github_sourced_pdp_inference_pipeline"
    exact_job = _job_named(canonical, job_id=7)

    w = MagicMock()
    # Only the listing filtered by the canonical name yields the job.
    w.jobs.list.side_effect = lambda name=None: iter(
        [exact_job] if name == canonical else []
    )

    resolved = _resolve_pipeline_job(w, canonical, "test")

    assert resolved.job_id == 7
    w.jobs.list.assert_called_once()


def test_resolve_pipeline_job_substring_dev_prefix():
    """A single dev-prefixed job is found through the substring fallback."""
    canonical = "edvise_github_sourced_pdp_inference_pipeline"
    prefixed_job = _job_named(f"[dev vishakh] {canonical}", job_id=11)

    w = MagicMock()
    # The name-filtered listing is empty; the unfiltered one has the job.
    w.jobs.list.side_effect = lambda name=None: iter(
        [] if name is not None else [prefixed_job]
    )

    resolved = _resolve_pipeline_job(w, canonical, "test")

    assert resolved.job_id == 11
    # One filtered call plus one full scan.
    assert w.jobs.list.call_count == 2


def test_resolve_pipeline_job_ambiguous_substring_raises():
    """Two jobs containing the canonical name make the lookup fail."""
    canonical = "edvise_github_sourced_pdp_inference_pipeline"
    dev_jobs = [
        _job_named(f"[dev a] {canonical}", job_id=1),
        _job_named(f"[dev b] {canonical}", job_id=2),
    ]

    w = MagicMock()
    # The name-filtered listing is empty; the unfiltered one has both jobs.
    w.jobs.list.side_effect = lambda name=None: iter(
        [] if name is not None else dev_jobs
    )

    with pytest.raises(ValueError, match="Multiple jobs match substring"):
        _resolve_pipeline_job(w, canonical, "test")
25 changes: 17 additions & 8 deletions src/webapp/routers/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from typing import Annotated, Any, Dict, List, Optional, Tuple, Union, cast
from pydantic import BaseModel, Field
from fastapi import APIRouter, Depends, HTTPException, status, Response
from sqlalchemy import and_, or_
from sqlalchemy import and_, false, or_
from sqlalchemy.orm import Session
from sqlalchemy.future import select
import os
Expand All @@ -29,6 +29,8 @@
DataSource,
get_external_bucket_name,
decode_url_piece,
expand_batch_file_name_lookups,
file_name_variants_for_lookup,
)

from ..database import (
Expand Down Expand Up @@ -749,20 +751,24 @@ def create_batch(
inst_id=str_to_uuid(inst_id),
created_by=str_to_uuid(current_user.user_id), # type: ignore
)
f_names = [] if not req.file_names else req.file_names
f_names = [] if not req.file_names else list(req.file_names)
f_ids = [] if not req.file_ids else strs_to_uuids(req.file_ids)
print(f"File names: {f_names}, File Ids: {f_ids}")
file_match_parts: List[Any] = []
if f_ids:
file_match_parts.append(FileTable.id.in_(f_ids))
if f_names:
file_match_parts.append(
FileTable.name.in_(expand_batch_file_name_lookups(f_names))
)
file_clause = or_(*file_match_parts) if file_match_parts else false()
# Check that the files requested for this batch exists.
# Only valid non-sst generated files can be added to a batch at creation time.
query_result_files = (
local_session.get()
.execute(
select(FileTable).where(
and_(
or_(
FileTable.id.in_(f_ids),
FileTable.name.in_(f_names),
),
file_clause,
FileTable.inst_id == str_to_uuid(inst_id),
FileTable.valid == True,
FileTable.sst_generated == False,
Expand Down Expand Up @@ -904,12 +910,15 @@ def update_batch(
if "file_names" in update_data_req:
for f in update_data_req["file_names"]:
# Check that the files requested for this batch exists
name_variants = list(file_name_variants_for_lookup(f))
query_result_file = (
local_session.get()
.execute(
select(FileTable).where(
and_(
FileTable.name == f,
FileTable.name.in_(name_variants)
if name_variants
else false(),
FileTable.inst_id == str_to_uuid(inst_id),
)
)
Expand Down
Loading
Loading