diff --git a/server_api/main.py b/server_api/main.py index a97153f..29305f1 100644 --- a/server_api/main.py +++ b/server_api/main.py @@ -21,6 +21,7 @@ from server_api.auth.database import get_db from server_api.auth.router import get_current_user from server_api.ehtool import router as ehtool_router +from server_api.workflow import router as workflow_router from fastapi.staticfiles import StaticFiles import os @@ -77,6 +78,7 @@ def _ensure_chatbot(): app.include_router(auth_router.router) app.include_router(ehtool_router.router, prefix="/eh", tags=["ehtool"]) +app.include_router(workflow_router) app.add_middleware( CORSMiddleware, diff --git a/server_api/workflow/__init__.py b/server_api/workflow/__init__.py new file mode 100644 index 0000000..5bc0c2e --- /dev/null +++ b/server_api/workflow/__init__.py @@ -0,0 +1,3 @@ +from .router import router + +__all__ = ["router"] diff --git a/server_api/workflow/models.py b/server_api/workflow/models.py new file mode 100644 index 0000000..bc4f89f --- /dev/null +++ b/server_api/workflow/models.py @@ -0,0 +1,35 @@ +from datetime import datetime +from typing import Any, Dict, List, Optional + +from pydantic import BaseModel + + +class WorkflowEvent(BaseModel): + event_type: str + event_time: Optional[str] = None + payload: Dict[str, Any] + + +class ArtifactReference(BaseModel): + path: str + exists: bool + + +class WorkflowSessionSnapshot(BaseModel): + id: int + user_id: int + project_name: str + workflow_type: str + dataset_path: str + mask_path: Optional[str] = None + total_layers: int + created_at: Optional[str] = None + updated_at: Optional[str] = None + + +class WorkflowExportBundle(BaseModel): + schema_version: str + exported_at: str + workflow_session: WorkflowSessionSnapshot + events: List[WorkflowEvent] + artifacts: List[ArtifactReference] diff --git a/server_api/workflow/router.py b/server_api/workflow/router.py new file mode 100644 index 0000000..1ccb357 --- /dev/null +++ b/server_api/workflow/router.py @@ -0,0 +1,131 @@ +from datetime import datetime, timezone +import os +from typing import List + +from fastapi import APIRouter, Depends, HTTPException, status +from sqlalchemy.orm import Session + +from server_api.auth.database import get_db +from server_api.auth.models import User +from server_api.auth.router import get_current_user +from server_api.ehtool.db_models import EHToolLayer, EHToolSession + +from .models import ( + ArtifactReference, + WorkflowEvent, + WorkflowExportBundle, + WorkflowSessionSnapshot, +) + +router = APIRouter(prefix="/api/workflows", tags=["workflow"]) + +SCHEMA_VERSION = "1.0" + + +def _to_iso(ts): + if ts is None: + return None + if ts.tzinfo is None: + ts = ts.replace(tzinfo=timezone.utc) + return ts.astimezone(timezone.utc).isoformat() + + +@router.post("/{workflow_id}/export-bundle", response_model=WorkflowExportBundle) +async def export_workflow_bundle( + workflow_id: int, + current_user: User = Depends(get_current_user), + db: Session = Depends(get_db), +): + workflow = ( + db.query(EHToolSession) + .filter(EHToolSession.id == workflow_id, EHToolSession.user_id == current_user.id) + .first() + ) + if not workflow: + raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Workflow not found") + + layers = ( + db.query(EHToolLayer) + .filter(EHToolLayer.session_id == workflow.id) + .order_by(EHToolLayer.layer_index.asc(), EHToolLayer.id.asc()) + .all() + ) + + events: List[WorkflowEvent] = [ + WorkflowEvent( + event_type="workflow.created", + event_time=_to_iso(workflow.created_at), + payload={ + "workflow_id": workflow.id, + "project_name": workflow.project_name, + "workflow_type": workflow.workflow_type, + }, + ) + ] + + for layer in layers: + events.append( + WorkflowEvent( + event_type="layer.indexed", + event_time=_to_iso(layer.created_at), + payload={ + "layer_id": layer.id, + "layer_index": layer.layer_index, + "layer_name": layer.layer_name, + "classification": layer.classification, + }, + ) + ) + if layer.updated_at is not None: + events.append( + WorkflowEvent( + event_type="layer.updated", + event_time=_to_iso(layer.updated_at), + payload={ + "layer_id": layer.id, + "classification": layer.classification, + }, + ) + ) + + events.sort( + key=lambda item: ( + item.event_time is None, + item.event_time or "", + item.event_type, + item.payload.get("layer_id", -1), + item.payload.get("layer_index", -1), + ) + ) + + artifact_paths = [] + for path in [workflow.dataset_path, workflow.mask_path]: + if path: + artifact_paths.append(path) + for layer in layers: + for path in [layer.image_path, layer.mask_path]: + if path: + artifact_paths.append(path) + + deduped_paths = sorted(set(artifact_paths)) + artifacts = [ + ArtifactReference(path=path, exists=os.path.exists(path)) for path in deduped_paths + ] + + return WorkflowExportBundle( + schema_version=SCHEMA_VERSION, + exported_at=datetime.now(timezone.utc).isoformat(), + workflow_session=WorkflowSessionSnapshot( + id=workflow.id, + user_id=workflow.user_id, + project_name=workflow.project_name, + workflow_type=workflow.workflow_type, + dataset_path=workflow.dataset_path, + mask_path=workflow.mask_path, + total_layers=workflow.total_layers, + created_at=_to_iso(workflow.created_at), + updated_at=_to_iso(workflow.updated_at), + ), + events=events, + artifacts=artifacts, + ) diff --git a/tests/test_workflow_export_bundle.py b/tests/test_workflow_export_bundle.py new file mode 100644 index 0000000..c3d973c --- /dev/null +++ b/tests/test_workflow_export_bundle.py @@ -0,0 +1,151 @@ +import pathlib +import tempfile +import unittest +from datetime import datetime, timedelta, timezone + +from fastapi.testclient import TestClient +from sqlalchemy import create_engine +from sqlalchemy.orm import sessionmaker + +from server_api.auth import database as auth_database +from server_api.auth import models as auth_models +from server_api.auth.router import get_current_user +from server_api.ehtool.db_models import EHToolLayer, EHToolSession +from fastapi import FastAPI + +from server_api.workflow.router import router as workflow_router + + +class WorkflowExportBundleTests(unittest.TestCase): + def setUp(self): + self.temp_dir = tempfile.TemporaryDirectory() + self.db_path = pathlib.Path(self.temp_dir.name) / "workflow-export.db" + self.engine = create_engine( + f"sqlite:///{self.db_path}", connect_args={"check_same_thread": False} + ) + self.SessionLocal = sessionmaker( + autocommit=False, autoflush=False, bind=self.engine + ) + + auth_models.Base.metadata.create_all(bind=self.engine) + + def override_get_db(): + db = self.SessionLocal() + try: + yield db + finally: + db.close() + + self.app = FastAPI() + self.app.include_router(workflow_router) + self.app.dependency_overrides[auth_database.get_db] = override_get_db + + with self.SessionLocal() as db: + user = auth_models.User( + username="bundle_user", + email="bundle@example.com", + hashed_password="hashed", + ) + db.add(user) + db.commit() + db.refresh(user) + self.user = user + + self.app.dependency_overrides[get_current_user] = lambda: self.user + self.client = TestClient(self.app) + + def tearDown(self): + self.app.dependency_overrides.clear() + self.engine.dispose() + self.temp_dir.cleanup() + + def _create_session_with_layers(self): + dataset_path = pathlib.Path(self.temp_dir.name) / "dataset.tif" + mask_path = pathlib.Path(self.temp_dir.name) / "mask.tif" + image_path = pathlib.Path(self.temp_dir.name) / "layer_0.png" + + dataset_path.write_text("dataset", encoding="utf-8") + mask_path.write_text("mask", encoding="utf-8") + image_path.write_text("image", encoding="utf-8") + + t0 = datetime(2026, 1, 1, 10, 0, 0, tzinfo=timezone.utc) + t1 = t0 + timedelta(minutes=1) + t2 = t0 + timedelta(minutes=2) + + with self.SessionLocal() as db: + session = EHToolSession( + user_id=self.user.id, + project_name="Bundle Project", + workflow_type="detection", + dataset_path=str(dataset_path), + mask_path=str(mask_path), + total_layers=2, + created_at=t0, + updated_at=t2, + ) + db.add(session) + db.commit() + db.refresh(session) + + layer0 = EHToolLayer( + session_id=session.id, + layer_index=0, + layer_name="layer-0", + classification="correct", + image_path=str(image_path), + mask_path=str(mask_path), + created_at=t1, + updated_at=t2, + ) + layer1 = EHToolLayer( + session_id=session.id, + layer_index=1, + layer_name="layer-1", + classification="error", + image_path=str(pathlib.Path(self.temp_dir.name) / "missing-layer.png"), + mask_path=None, + created_at=t2, + updated_at=t2, + ) + db.add_all([layer0, layer1]) + db.commit() + return session.id, str(pathlib.Path(self.temp_dir.name) / "missing-layer.png") + + def test_export_bundle_happy_path_returns_deterministic_structure(self): + session_id, _ = self._create_session_with_layers() + + response = self.client.post(f"/api/workflows/{session_id}/export-bundle") + + self.assertEqual(response.status_code, 200) + payload = response.json() + self.assertEqual(payload["schema_version"], "1.0") + self.assertIn("exported_at", payload) + + snapshot = payload["workflow_session"] + self.assertEqual(snapshot["id"], session_id) + self.assertEqual(snapshot["project_name"], "Bundle Project") + self.assertEqual(snapshot["workflow_type"], "detection") + self.assertEqual(snapshot["total_layers"], 2) + + event_times = [event["event_time"] for event in payload["events"] if event["event_time"]] + self.assertEqual(event_times, sorted(event_times)) + + self.assertEqual( + sorted(payload.keys()), + ["artifacts", "events", "exported_at", "schema_version", "workflow_session"], + ) + + def test_export_bundle_marks_missing_artifact_paths_without_failing(self): + session_id, missing_path = self._create_session_with_layers() + + response = self.client.post(f"/api/workflows/{session_id}/export-bundle") + + self.assertEqual(response.status_code, 200) + artifacts = {artifact["path"]: artifact["exists"] for artifact in response.json()["artifacts"]} + + self.assertIn(missing_path, artifacts) + self.assertFalse(artifacts[missing_path]) + + +if __name__ == "__main__": + unittest.main()