Skip to content

Commit bd44916

Browse files
committed
fix: formatting style
1 parent 167112a commit bd44916

3 files changed

Lines changed: 14 additions & 8 deletions

File tree

src/webapp/database.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -551,8 +551,8 @@ class JobTable(Base):
551551
String(VAR_CHAR_STANDARD_LENGTH), nullable=True
552552
)
553553
completed: Mapped[bool] = mapped_column(nullable=True)
554-
framework: Mapped[str | None] = mapped_column(
555-
String(VAR_CHAR_STANDARD_LENGTH), nullable=False, default="sklearn"
554+
model_run_id: Mapped[str | None] = mapped_column(
555+
String(VAR_CHAR_STANDARD_LENGTH), nullable=True
556556
)
557557

558558

src/webapp/databricks.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -523,9 +523,11 @@ def fetch_model_version(self, catalog_name: str, inst_name: str, model_name: str
523523
)
524524
raise ValueError(f"setup_new_inst(): Workspace client creation failed: {e}")
525525

526-
model_versions = list(w.model_versions.list(
527-
full_name=model_name_path,
528-
))
526+
model_versions = list(
527+
w.model_versions.list(
528+
full_name=model_name_path,
529+
)
530+
)
529531

530532
if not model_versions:
531533
raise ValueError(f"No versions found for model: {model_name_path}")

src/webapp/routers/models.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -556,7 +556,6 @@ def trigger_inference_run(
556556
gcp_external_bucket_name=get_external_bucket_name(inst_id),
557557
# The institution email to which pipeline success/failure notifications will get sent.
558558
email=cast(str, current_user.email),
559-
model_type=query_result[0][0].framework,
560559
)
561560
try:
562561
res = databricks_control.run_pdp_inference(db_req)
@@ -568,14 +567,19 @@ def trigger_inference_run(
568567
detail=f"Databricks run_pdp_inference error. Error = {str(e)}",
569568
) from e
570569
triggered_timestamp = datetime.now()
570+
latest_model_version = databricks_control.fetch_model_version(
571+
catalog_name=env_vars["CATALOG_NAME"],
572+
inst_name=inst_result[0][0].name,
573+
model_name=model_name,
574+
)
571575
job = JobTable(
572576
id=res.job_run_id,
573577
triggered_at=triggered_timestamp,
574578
created_by=str_to_uuid(current_user.user_id),
575579
batch_name=req.batch_name,
576580
model_id=query_result[0][0].id,
577581
output_valid=False,
578-
framework=query_result[0][0].framework,
582+
model_run_id=latest_model_version.run_id,
579583
)
580584
local_session.get().add(job)
581585
return {
@@ -586,7 +590,7 @@ def trigger_inference_run(
586590
"triggered_at": triggered_timestamp,
587591
"batch_name": req.batch_name,
588592
"output_valid": False,
589-
"framework": query_result[0][0].framework,
593+
"model_run_id": latest_model_version.run_id,
590594
}
591595

592596

0 commit comments

Comments
 (0)