Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions server_api/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from server_api.auth.database import get_db
from server_api.auth.router import get_current_user
from server_api.ehtool import router as ehtool_router
from server_api.workflow import router as workflow_router

from fastapi.staticfiles import StaticFiles
import os
Expand Down Expand Up @@ -77,6 +78,7 @@ def _ensure_chatbot():

app.include_router(auth_router.router)
app.include_router(ehtool_router.router, prefix="/eh", tags=["ehtool"])
app.include_router(workflow_router)

app.add_middleware(
CORSMiddleware,
Expand Down
3 changes: 3 additions & 0 deletions server_api/workflow/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from .router import router

__all__ = ["router"]
35 changes: 35 additions & 0 deletions server_api/workflow/models.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
from datetime import datetime
from typing import Any, Dict, List, Optional

from pydantic import BaseModel


class WorkflowEvent(BaseModel):
event_type: str
event_time: Optional[str] = None
payload: Dict[str, Any]


class ArtifactReference(BaseModel):
path: str
exists: bool


class WorkflowSessionSnapshot(BaseModel):
id: int
user_id: int
project_name: str
workflow_type: str
dataset_path: str
mask_path: Optional[str] = None
total_layers: int
created_at: Optional[str] = None
updated_at: Optional[str] = None


class WorkflowExportBundle(BaseModel):
schema_version: str
exported_at: str
workflow_session: WorkflowSessionSnapshot
events: List[WorkflowEvent]
artifacts: List[ArtifactReference]
131 changes: 131 additions & 0 deletions server_api/workflow/router.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,131 @@
from datetime import datetime, timezone
import os
from typing import List

from fastapi import APIRouter, Depends, HTTPException, status
from sqlalchemy.orm import Session

from server_api.auth.database import get_db
from server_api.auth.models import User
from server_api.auth.router import get_current_user
from server_api.ehtool.db_models import EHToolLayer, EHToolSession

from .models import (
ArtifactReference,
WorkflowEvent,
WorkflowExportBundle,
WorkflowSessionSnapshot,
)

router = APIRouter(prefix="/api/workflows", tags=["workflow"])

SCHEMA_VERSION = "1.0"


def _to_iso(ts):
if ts is None:
return None
if ts.tzinfo is None:
ts = ts.replace(tzinfo=timezone.utc)
return ts.astimezone(timezone.utc).isoformat()


@router.post("/{workflow_id}/export-bundle", response_model=WorkflowExportBundle)
async def export_workflow_bundle(
workflow_id: int,
current_user: User = Depends(get_current_user),
db: Session = Depends(get_db),
):
workflow = (
db.query(EHToolSession)
.filter(EHToolSession.id == workflow_id, EHToolSession.user_id == current_user.id)
.first()
)
if not workflow:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Workflow not found")

layers = (
db.query(EHToolLayer)
.filter(EHToolLayer.session_id == workflow.id)
.order_by(EHToolLayer.layer_index.asc(), EHToolLayer.id.asc())
.all()
)

events: List[WorkflowEvent] = [
WorkflowEvent(
event_type="workflow.created",
event_time=_to_iso(workflow.created_at),
payload={
"workflow_id": workflow.id,
"project_name": workflow.project_name,
"workflow_type": workflow.workflow_type,
},
)
]

for layer in layers:
events.append(
WorkflowEvent(
event_type="layer.indexed",
event_time=_to_iso(layer.created_at),
payload={
"layer_id": layer.id,
"layer_index": layer.layer_index,
"layer_name": layer.layer_name,
"classification": layer.classification,
},
)
)
if layer.updated_at is not None:
events.append(
WorkflowEvent(
event_type="layer.updated",
event_time=_to_iso(layer.updated_at),
payload={
"layer_id": layer.id,
"classification": layer.classification,
},
)
)

events.sort(
key=lambda item: (
item.event_time is None,
item.event_time or "",
item.event_type,
item.payload.get("layer_id", -1),
item.payload.get("layer_index", -1),
)
)

artifact_paths = []
for path in [workflow.dataset_path, workflow.mask_path]:
if path:
artifact_paths.append(path)
for layer in layers:
for path in [layer.image_path, layer.mask_path]:
if path:
artifact_paths.append(path)

deduped_paths = sorted(set(artifact_paths))
artifacts = [
ArtifactReference(path=path, exists=os.path.exists(path)) for path in deduped_paths
]

return WorkflowExportBundle(
schema_version=SCHEMA_VERSION,
exported_at=datetime.now(timezone.utc).isoformat(),
workflow_session=WorkflowSessionSnapshot(
id=workflow.id,
user_id=workflow.user_id,
project_name=workflow.project_name,
workflow_type=workflow.workflow_type,
dataset_path=workflow.dataset_path,
mask_path=workflow.mask_path,
total_layers=workflow.total_layers,
created_at=_to_iso(workflow.created_at),
updated_at=_to_iso(workflow.updated_at),
),
events=events,
artifacts=artifacts,
)
151 changes: 151 additions & 0 deletions tests/test_workflow_export_bundle.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,151 @@
import pathlib
import tempfile
import unittest
from datetime import datetime, timedelta, timezone

from fastapi.testclient import TestClient
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker

from server_api.auth import database as auth_database
from server_api.auth import models as auth_models
from server_api.auth.router import get_current_user
from server_api.ehtool.db_models import EHToolLayer, EHToolSession
from fastapi import FastAPI

from server_api.workflow.router import router as workflow_router


class WorkflowExportBundleTests(unittest.TestCase):
def setUp(self):
self.temp_dir = tempfile.TemporaryDirectory()
self.db_path = pathlib.Path(self.temp_dir.name) / "workflow-export.db"
self.engine = create_engine(
f"sqlite:///{self.db_path}", connect_args={"check_same_thread": False}
)
self.SessionLocal = sessionmaker(
autocommit=False, autoflush=False, bind=self.engine
)

auth_models.Base.metadata.create_all(bind=self.engine)

def override_get_db():
db = self.SessionLocal()
try:
yield db
finally:
db.close()

self.app = FastAPI()
self.app.include_router(workflow_router)
self.app.dependency_overrides[auth_database.get_db] = override_get_db

with self.SessionLocal() as db:
user = auth_models.User(
username="bundle_user",
email="bundle@example.com",
hashed_password="hashed",
)
db.add(user)
db.commit()
db.refresh(user)
self.user = user

self.app.dependency_overrides[get_current_user] = lambda: self.user
self.client = TestClient(self.app)

def tearDown(self):
self.app.dependency_overrides.clear()
self.engine.dispose()
self.temp_dir.cleanup()

def _create_session_with_layers(self):
dataset_path = pathlib.Path(self.temp_dir.name) / "dataset.tif"
mask_path = pathlib.Path(self.temp_dir.name) / "mask.tif"
image_path = pathlib.Path(self.temp_dir.name) / "layer_0.png"

dataset_path.write_text("dataset", encoding="utf-8")
mask_path.write_text("mask", encoding="utf-8")
image_path.write_text("image", encoding="utf-8")

t0 = datetime(2026, 1, 1, 10, 0, 0, tzinfo=timezone.utc)
t1 = t0 + timedelta(minutes=1)
t2 = t0 + timedelta(minutes=2)

with self.SessionLocal() as db:
session = EHToolSession(
user_id=self.user.id,
project_name="Bundle Project",
workflow_type="detection",
dataset_path=str(dataset_path),
mask_path=str(mask_path),
total_layers=2,
created_at=t0,
updated_at=t2,
)
db.add(session)
db.commit()
db.refresh(session)

layer0 = EHToolLayer(
session_id=session.id,
layer_index=0,
layer_name="layer-0",
classification="correct",
image_path=str(image_path),
mask_path=str(mask_path),
created_at=t1,
updated_at=t2,
)
layer1 = EHToolLayer(
session_id=session.id,
layer_index=1,
layer_name="layer-1",
classification="error",
image_path=str(pathlib.Path(self.temp_dir.name) / "missing-layer.png"),
mask_path=None,
created_at=t2,
updated_at=t2,
)
db.add_all([layer0, layer1])
db.commit()
return session.id, str(pathlib.Path(self.temp_dir.name) / "missing-layer.png")

def test_export_bundle_happy_path_returns_deterministic_structure(self):
session_id, _ = self._create_session_with_layers()

response = self.client.post(f"/api/workflows/{session_id}/export-bundle")

self.assertEqual(response.status_code, 200)
payload = response.json()
self.assertEqual(payload["schema_version"], "1.0")
self.assertIn("exported_at", payload)

snapshot = payload["workflow_session"]
self.assertEqual(snapshot["id"], session_id)
self.assertEqual(snapshot["project_name"], "Bundle Project")
self.assertEqual(snapshot["workflow_type"], "detection")
self.assertEqual(snapshot["total_layers"], 2)

event_times = [event["event_time"] for event in payload["events"] if event["event_time"]]
self.assertEqual(event_times, sorted(event_times))

self.assertEqual(
sorted(payload.keys()),
["artifacts", "events", "exported_at", "schema_version", "workflow_session"],
)

def test_export_bundle_marks_missing_artifact_paths_without_failing(self):
session_id, missing_path = self._create_session_with_layers()

response = self.client.post(f"/api/workflows/{session_id}/export-bundle")

self.assertEqual(response.status_code, 200)
artifacts = {artifact["path"]: artifact["exists"] for artifact in response.json()["artifacts"]}

self.assertIn(missing_path, artifacts)
self.assertFalse(artifacts[missing_path])


if __name__ == "__main__":
unittest.main()
Loading