Skip to content

Commit 3e0cb4b

Browse files
authored
Merge pull request #181 from datakind/ModelDeletionEndpoint
feat: added model deletion endpoint
2 parents dbb00ff + b0f69a9 commit 3e0cb4b

4 files changed

Lines changed: 1467 additions & 1401 deletions

File tree

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ dependencies = [
3030
"pandera~=0.13",
3131
"mlflow~=2.15.0",
3232
"cachetools",
33+
"types-cachetools",
3334
]
3435

3536
[project.urls]

src/webapp/databricks.py

Lines changed: 35 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -90,8 +90,8 @@ def _sha256_json(obj: Any) -> str:
9090

9191
L1_RESP_CACHE_TTL = int("600") # seconds
9292
L1_VER_CACHE_TTL = int("3600") # seconds
93-
L1_RESP_CACHE = TTLCache(maxsize=128, ttl=L1_RESP_CACHE_TTL)
94-
L1_VER_CACHE = TTLCache(maxsize=256, ttl=L1_VER_CACHE_TTL)
93+
L1_RESP_CACHE: Any = TTLCache(maxsize=128, ttl=L1_RESP_CACHE_TTL)
94+
L1_VER_CACHE: Any = TTLCache(maxsize=256, ttl=L1_VER_CACHE_TTL)
9595
_L1_LOCK = threading.RLock()
9696

9797

@@ -251,7 +251,6 @@ def run_pdp_inference(
251251
], # is this value the same PER environ? dev/staging/prod
252252
"gcp_bucket_name": req.gcp_external_bucket_name,
253253
"model_name": req.model_name,
254-
"model_type": req.model_type,
255254
"notification_email": req.email,
256255
},
257256
)
@@ -333,7 +332,7 @@ def fetch_table_data(
333332
inst_name: str,
334333
table_name: str,
335334
warehouse_id: str,
336-
) -> List[Dict[str, Any]]:
335+
) -> Any:
337336
"""
338337
Execute SELECT * via Databricks SQL Statement Execution API using EXTERNAL_LINKS.
339338
Blocks server-side for up to 30s; if not SUCCEEDED, raises. Downloads presigned
@@ -366,9 +365,9 @@ def fetch_table_data(
366365

367366
if not ver_resp.status or ver_resp.status.state != StatementState.SUCCEEDED:
368367
raise TimeoutError("DESCRIBE HISTORY did not finish within 30s")
369-
cols = [c.name for c in ver_resp.manifest.schema.columns]
368+
cols = [c.name for c in ver_resp.manifest.schema.columns] # type: ignore
370369
idx = {n: i for i, n in enumerate(cols)}
371-
rows = ver_resp.result.data_array or []
370+
rows = ver_resp.result.data_array or [] # type: ignore
372371
if not rows or "version" not in idx:
373372
raise ValueError("DESCRIBE HISTORY returned no version")
374373
table_version = str(rows[0][idx["version"]])
@@ -432,13 +431,13 @@ def fetch_table_data(
432431
resp.manifest and resp.manifest.schema and resp.manifest.schema.columns
433432
):
434433
raise ValueError("Schema/columns missing (EXTERNAL_LINKS).")
435-
cols: List[str] = []
434+
cols: List[str] = [] # type: ignore
436435
for c in resp.manifest.schema.columns:
437436
if c.name is None:
438437
raise ValueError("Encountered a column without a name.")
439438
cols.append(c.name)
440439

441-
records: List[Dict[str, Any]] = []
440+
records: Any = []
442441

443442
# Helper: consume one chunk-like object (first result or subsequent chunk)
444443
def _consume_chunk(chunk_obj: Any) -> int | None:
@@ -504,7 +503,9 @@ def _consume_chunk(chunk_obj: Any) -> int | None:
504503
pass
505504
return records
506505

507-
def fetch_model_version(self, catalog_name: str, inst_name: str, model_name: str):
506+
def fetch_model_version(
507+
self, catalog_name: str, inst_name: str, model_name: str
508+
) -> Any:
508509
schema = databricksify_inst_name(inst_name)
509510
model_name_path = f"{catalog_name}.{schema}_gold.{model_name}"
510511

@@ -521,7 +522,7 @@ def fetch_model_version(self, catalog_name: str, inst_name: str, model_name: str
521522
)
522523
raise ValueError(f"setup_new_inst(): Workspace client creation failed: {e}")
523524

524-
model_versions = list(
525+
model_versions: Any = list(
525526
w.model_versions.list(
526527
full_name=model_name_path,
527528
)
@@ -534,6 +535,30 @@ def fetch_model_version(self, catalog_name: str, inst_name: str, model_name: str
534535

535536
return latest_version
536537

538+
def delete_model(self, catalog_name: str, inst_name: str, model_name: str) -> None:
    """Delete a registered model from Databricks.

    The fully qualified model name is built as
    ``{catalog_name}.{schema}_gold.{model_name}``, where ``schema`` is the
    value returned by ``databricksify_inst_name(inst_name)``.

    Args:
        catalog_name: Catalog that holds the institution's schemas.
        inst_name: Institution name; converted to a schema name via
            ``databricksify_inst_name``.
        model_name: Name of the registered model to delete.

    Raises:
        ValueError: If the Databricks ``WorkspaceClient`` cannot be created.
        Exception: Any error raised by ``registered_models.delete`` is logged
            and re-raised unchanged.
    """
    schema = databricksify_inst_name(inst_name)
    model_name_path = f"{catalog_name}.{schema}_gold.{model_name}"

    try:
        w = WorkspaceClient(
            host=databricks_vars["DATABRICKS_HOST_URL"],
            google_service_account=gcs_vars["GCP_SERVICE_ACCOUNT_EMAIL"],
        )
    except Exception as e:
        LOGGER.exception(
            "Failed to create Databricks WorkspaceClient with host: %s and service account: %s",
            databricks_vars["DATABRICKS_HOST_URL"],
            gcs_vars["GCP_SERVICE_ACCOUNT_EMAIL"],
        )
        # Name this method in the message (it previously said
        # "setup_new_inst()") and chain the original cause for debugging.
        raise ValueError(
            f"delete_model(): Workspace client creation failed: {e}"
        ) from e

    try:
        w.registered_models.delete(full_name=model_name_path)
    except Exception:
        LOGGER.exception("Failed to delete registered model: %s", model_name_path)
        raise
    else:
        # Only reached when the delete call returned without raising.
        LOGGER.info("Deleted registration model: %s", model_name_path)
561+
537562
def get_key_for_file(
538563
self, mapping: Dict[str, Any], file_name: str
539564
) -> Optional[str]:

src/webapp/routers/models.py

Lines changed: 49 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -311,6 +311,54 @@ def read_inst_model(
311311
}
312312

313313

314+
@router.delete("/{inst_id}/models/{model_name}")
def delete_model(
    inst_id: str,
    model_name: str,
    delete_from_databricks: bool,
    current_user: Annotated[BaseUser, Depends(get_current_active_user)],
    sql_session: Annotated[Session, Depends(get_session)],
    databricks_control: Annotated[DatabricksControl, Depends(DatabricksControl)],
) -> Any:
    """Delete a model record for an institution, optionally from Databricks too.

    Args:
        inst_id: Institution id from the URL path (parsed with ``str_to_uuid``).
        model_name: URL-encoded model name from the path; it is decoded with
            ``decode_url_piece`` and stripped before use.
        delete_from_databricks: When true, the registered model is also deleted
            through ``databricks_control.delete_model`` before the database row
            is removed.
        current_user: Authenticated user, checked for institution access and
            model-owner (or higher) permission.
        sql_session: Database session for this request.
        databricks_control: Databricks helper used for the optional deletion.

    Returns:
        A dict with ``inst_id``, the decoded ``model_name`` and
        ``deleted_from_databricks``.

    Raises:
        HTTPException: 404 if the institution or the model does not exist.
            The access helpers are expected to raise on insufficient
            permissions (their behavior is defined outside this block).
    """
    transformed_model_name = str(decode_url_piece(model_name)).strip()
    has_access_to_inst_or_err(inst_id, current_user)
    # Label the action being authorized as a model deletion (was "modify batch").
    model_owner_and_higher_or_err(current_user, "delete model")

    local_session.set(sql_session)
    sess = local_session.get()
    inst_uuid = str_to_uuid(inst_id)

    query_result = sess.execute(
        select(InstTable).where(InstTable.id == inst_uuid)
    ).all()
    # Guard the query_result[0][0] access below: an unknown institution must be
    # a 404, not an IndexError surfacing as a 500.
    if not query_result:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND, detail="Institution not found."
        )

    # Match on the decoded model name; the name is a plain string, not a UUID.
    model_row = sess.execute(
        select(ModelTable).where(
            ModelTable.name == transformed_model_name,
            ModelTable.inst_id == inst_uuid,
        )
    ).scalar_one_or_none()
    if model_row is None:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND, detail="Model not found."
        )

    if delete_from_databricks:
        # Optionally delete the registered model from Databricks itself.
        # This runs before the database delete, so a Databricks failure
        # propagates and leaves the database row in place.
        databricks_control.delete_model(
            catalog_name=str(env_vars["CATALOG_NAME"]),
            inst_name=f"{query_result[0][0].name}",
            model_name=transformed_model_name,
        )

    sess.delete(model_row)
    sess.commit()
    return {
        "inst_id": inst_id,
        "model_name": transformed_model_name,
        "deleted_from_databricks": delete_from_databricks,
    }
360+
361+
314362
@router.get("/{inst_id}/models/{model_name}/runs", response_model=list[RunInfo])
315363
def read_inst_model_outputs(
316364
inst_id: str,
@@ -710,7 +758,7 @@ def backfill_model_runs(
710758
.values(model_run_id=mv_run_id, model_version=mv_version)
711759
)
712760
result = local_session.get().execute(stmt)
713-
updated_count = result.rowcount or 0
761+
updated_count = result.rowcount or 0 # type: ignore
714762
local_session.get().commit()
715763

716764
return {

0 commit comments

Comments
 (0)