Skip to content

Commit 7b9c85d

Browse files
committed
refactor: extract run sync service and use better list filters structure
1 parent 788d451 commit 7b9c85d

7 files changed

Lines changed: 282 additions & 112 deletions

File tree

src/pypsa_app/backend/api/routes/networks.py

Lines changed: 34 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55

66
from fastapi import APIRouter, Depends, HTTPException, Query, UploadFile
77
from fastapi import Path as PathParam
8+
from pydantic import BaseModel
89
from sqlalchemy import ColumnElement, or_
910
from sqlalchemy.orm import Session, joinedload
1011

@@ -87,14 +88,20 @@ def upload_network(
8788
tmp.unlink(missing_ok=True)
8889

8990

90-
@router.get("/", response_model=NetworkListResponse)
91-
def list_networks(
92-
skip: int = 0,
93-
limit: int = 100,
91+
class NetworkListFilters(BaseModel):
92+
"""Query parameters for filtering the networks list."""
93+
94+
skip: int = 0
95+
limit: int = 100
9496
owners: list[str] | None = Query(
9597
None,
9698
description="Filter by owner IDs. Use 'me' for current user.",
97-
),
99+
)
100+
101+
102+
@router.get("/", response_model=NetworkListResponse)
103+
def list_networks(
104+
filters: NetworkListFilters = Depends(),
98105
db: Session = Depends(get_db),
99106
user: User = Depends(require_permission(Permission.NETWORKS_VIEW)),
100107
) -> NetworkListResponse:
@@ -110,35 +117,43 @@ def list_networks(
110117
)
111118
query = query.filter(visibility_filter)
112119

113-
# Apply owner filter if specified
114-
if owners:
120+
# Apply owner filter if specified
121+
if filters.owners:
115122

116-
def owner_to_condition(owner_id: str) -> ColumnElement[bool]:
117-
if owner_id == "me":
118-
return Network.user_id == user.id
119-
return Network.user_id == owner_id
123+
def owner_to_condition(owner_id: str) -> ColumnElement[bool]:
124+
if owner_id == "me":
125+
return Network.user_id == user.id
126+
return Network.user_id == owner_id
120127

121-
query = query.filter(or_(*[owner_to_condition(oid) for oid in owners]))
128+
conditions = [owner_to_condition(oid) for oid in filters.owners]
129+
query = query.filter(or_(*conditions))
122130

123131
total = query.count()
124-
networks = query.order_by(Network.created_at.desc()).offset(skip).limit(limit).all()
132+
networks = (
133+
query.order_by(Network.created_at.desc())
134+
.offset(filters.skip)
135+
.limit(filters.limit)
136+
.all()
137+
)
125138

126139
# Get all unique owners for filter dropdown
127140
all_owners = []
141+
owners_query = db.query(Network.user_id)
128142
if not has_permission(user, Permission.NETWORKS_MANAGE_ALL):
129-
owners_query = db.query(Network.user_id)
130143
if visibility_filter is not None:
131144
owners_query = owners_query.filter(visibility_filter)
132-
owner_ids = [oid[0] for oid in owners_query.distinct().all()]
133-
if owner_ids:
134-
all_owners = db.query(User).filter(User.id.in_(owner_ids)).all()
145+
else:
146+
owners_query = owners_query.filter(Network.user_id == user.id)
147+
owner_ids = [oid[0] for oid in owners_query.distinct().all()]
148+
if owner_ids:
149+
all_owners = db.query(User).filter(User.id.in_(owner_ids)).all()
135150

136151
return NetworkListResponse(
137152
data=networks,
138153
meta={
139154
"total": total,
140-
"skip": skip,
141-
"limit": limit,
155+
"skip": filters.skip,
156+
"limit": filters.limit,
142157
"count": len(networks),
143158
"owners": all_owners,
144159
},

src/pypsa_app/backend/api/routes/runs.py

Lines changed: 96 additions & 92 deletions
Original file line numberDiff line numberDiff line change
@@ -4,12 +4,13 @@
44
import re
55
import urllib.parse
66
import uuid
7-
from collections import defaultdict
87
from pathlib import PurePosixPath
8+
from typing import Any
99

1010
from fastapi import APIRouter, Depends, HTTPException, Query
1111
from fastapi import Path as PathParam
1212
from fastapi.responses import StreamingResponse
13+
from pydantic import BaseModel
1314
from sqlalchemy.orm import Session, joinedload
1415

1516
from pypsa_app.backend.api.deps import get_db, require_permission
@@ -33,21 +34,12 @@
3334
)
3435
from pypsa_app.backend.services.backend_registry import backend_registry
3536
from pypsa_app.backend.services.run import SnakedispatchClient, SnakedispatchError
37+
from pypsa_app.backend.services.sync import SYNCED_STATUSES, sync_run_from_job
3638
from pypsa_app.backend.settings import settings
37-
from pypsa_app.backend.tasks import import_run_outputs_task
3839

3940
router = APIRouter()
4041
logger = logging.getLogger(__name__)
4142

42-
# Statuses where the remote executor is done — no need to sync from Snakedispatch.
43-
SYNCED_STATUSES = {
44-
RunStatus.UPLOADING,
45-
RunStatus.COMPLETED,
46-
RunStatus.FAILED,
47-
RunStatus.ERROR,
48-
RunStatus.CANCELLED,
49-
}
50-
5143

5244
def _get_client_for_run(run: Run) -> SnakedispatchClient:
5345
"""Resolve a SnakedispatchClient from the run's backend_id."""
@@ -95,57 +87,6 @@ def _check_run(run_id: uuid.UUID, db: Session, user: User) -> Run:
9587
return run
9688

9789

98-
# Statuses that should not be resynced or trigger import again
99-
_IMPORT_DONE_STATUSES = {RunStatus.UPLOADING, RunStatus.COMPLETED, RunStatus.ERROR}
100-
101-
102-
def _sync_run_from_job(run: Run, job: dict, db: Session) -> None:
103-
"""Update a Run record from a Snakedispatch response dict."""
104-
changed = False
105-
field_map = {
106-
"workflow": "workflow",
107-
"configfile": "configfile",
108-
"git_ref": "git_ref",
109-
"git_sha": "git_sha",
110-
"exit_code": "exit_code",
111-
"started_at": "started_at",
112-
"completed_at": "completed_at",
113-
}
114-
for db_field, job_key in field_map.items():
115-
new_val = job.get(job_key)
116-
if new_val is not None and getattr(run, db_field) != new_val:
117-
setattr(run, db_field, new_val)
118-
changed = True
119-
120-
# Status needs enum conversion
121-
raw_status = job.get("status")
122-
if raw_status:
123-
try:
124-
new_status = RunStatus(raw_status)
125-
except ValueError:
126-
new_status = None
127-
128-
completed_with_import_pending = (
129-
new_status == RunStatus.COMPLETED
130-
and run.status not in _IMPORT_DONE_STATUSES
131-
)
132-
if completed_with_import_pending and run.import_networks:
133-
run.status = RunStatus.UPLOADING
134-
changed = True
135-
db.flush()
136-
import_run_outputs_task.apply_async(args=(str(run.job_id),))
137-
elif completed_with_import_pending:
138-
# Nothing to import
139-
run.status = RunStatus.COMPLETED
140-
changed = True
141-
elif new_status and run.status != new_status:
142-
run.status = new_status
143-
changed = True
144-
145-
if changed:
146-
db.flush()
147-
148-
14990
@router.get("/backends", response_model=list[BackendPublicResponse])
15091
def list_user_backends(
15192
db: Session = Depends(get_db),
@@ -213,46 +154,109 @@ def create_run(
213154
return RunResponse.model_validate(run)
214155

215156

157+
class RunListFilters(BaseModel):
158+
"""Query parameters for filtering the runs list."""
159+
160+
skip: int = 0
161+
limit: int = 100
162+
statuses: list[str] | None = Query(None, description="Filter by status values")
163+
workflows: list[str] | None = Query(None, description="Filter by workflow names")
164+
owners: list[str] | None = Query(
165+
None, description="Filter by owner IDs. Use 'me' for current user."
166+
)
167+
git_refs: list[str] | None = Query(None, description="Filter by git ref")
168+
configfiles: list[str] | None = Query(None, description="Filter by configfile")
169+
backends: list[str] | None = Query(
170+
None, description="Filter by backend IDs (UUIDs)"
171+
)
172+
173+
216174
@router.get("/", response_model=RunListResponse)
217175
def list_runs(
218-
skip: int = 0,
219-
limit: int = 100,
176+
filters: RunListFilters = Depends(),
220177
db: Session = Depends(get_db),
221178
user: User = Depends(require_permission(Permission.RUNS_VIEW)),
222179
) -> RunListResponse:
223180
"""List runs visible to the current user."""
181+
is_admin = has_permission(user, Permission.RUNS_MANAGE_ALL)
182+
user_filter = Run.user_id == user.id if not is_admin else None
183+
184+
# Collect distinct values per column for filter dropdowns
185+
def _distinct_vals(col: Any) -> list:
186+
q = db.query(col).filter(col.isnot(None))
187+
if user_filter is not None:
188+
q = q.filter(user_filter)
189+
return sorted(r[0] for r in q.distinct().all())
190+
191+
all_statuses = list(RunStatus)
192+
present_statuses = set(_distinct_vals(Run.status))
193+
filter_options: dict[str, Any] = {
194+
"statuses": [s for s in all_statuses if s in present_statuses],
195+
"workflows": _distinct_vals(Run.workflow),
196+
"git_refs": _distinct_vals(Run.git_ref),
197+
"configfiles": _distinct_vals(Run.configfile),
198+
}
199+
200+
backend_ids = _distinct_vals(Run.backend_id)
201+
filter_options["backends"] = (
202+
db.query(SnakedispatchBackend)
203+
.filter(SnakedispatchBackend.id.in_(backend_ids))
204+
.order_by(SnakedispatchBackend.name)
205+
.all()
206+
if backend_ids
207+
else None
208+
)
209+
210+
if is_admin:
211+
owner_ids = _distinct_vals(Run.user_id)
212+
filter_options["owners"] = (
213+
db.query(User).filter(User.id.in_(owner_ids)).all() if owner_ids else None
214+
)
215+
else:
216+
filter_options["owners"] = None
217+
224218
query = db.query(Run).options(joinedload(Run.owner), joinedload(Run.backend))
225-
if not has_permission(user, Permission.RUNS_MANAGE_ALL):
226-
query = query.filter(Run.user_id == user.id)
219+
if user_filter is not None:
220+
query = query.filter(user_filter)
221+
if filters.statuses:
222+
try:
223+
parsed_statuses = [RunStatus(s) for s in filters.statuses]
224+
except ValueError as e:
225+
raise HTTPException(422, f"Invalid status filter: {e}") from None
226+
query = query.filter(Run.status.in_(parsed_statuses))
227+
if filters.workflows:
228+
query = query.filter(Run.workflow.in_(filters.workflows))
229+
if filters.owners:
230+
resolved_ids = [user.id if o == "me" else o for o in filters.owners]
231+
query = query.filter(Run.user_id.in_(resolved_ids))
232+
if filters.git_refs:
233+
query = query.filter(Run.git_ref.in_(filters.git_refs))
234+
if filters.configfiles:
235+
query = query.filter(Run.configfile.in_(filters.configfiles))
236+
if filters.backends:
237+
try:
238+
parsed_backends = [uuid.UUID(b) for b in filters.backends]
239+
except ValueError as e:
240+
raise HTTPException(422, f"Invalid backend ID: {e}") from None
241+
query = query.filter(Run.backend_id.in_(parsed_backends))
227242

228243
total = query.count()
229-
runs = query.order_by(Run.created_at.desc()).offset(skip).limit(limit).all()
230-
231-
# One sync call per backend to avoid redundant API requests
232-
non_terminal = [r for r in runs if r.status not in SYNCED_STATUSES]
233-
if non_terminal:
234-
by_backend: dict[uuid.UUID, list[Run]] = defaultdict(list)
235-
for r in non_terminal:
236-
by_backend[r.backend_id].append(r)
237-
238-
for backend_id, backend_runs in by_backend.items():
239-
client = backend_registry.get_client(backend_id)
240-
if client is None:
241-
continue
242-
try:
243-
all_jobs = client.list_jobs()
244-
jobs_by_id = {j["job_id"]: j for j in all_jobs}
245-
for run in backend_runs:
246-
job = jobs_by_id.get(str(run.job_id))
247-
if job:
248-
_sync_run_from_job(run, job, db)
249-
db.commit()
250-
except SnakedispatchError:
251-
pass
244+
runs = (
245+
query.order_by(Run.created_at.desc())
246+
.offset(filters.skip)
247+
.limit(filters.limit)
248+
.all()
249+
)
252250

253251
return RunListResponse(
254252
data=[RunSummary.model_validate(r) for r in runs],
255-
meta={"total": total, "skip": skip, "limit": limit, "count": len(runs)},
253+
meta={
254+
"total": total,
255+
"skip": filters.skip,
256+
"limit": filters.limit,
257+
"count": len(runs),
258+
**filter_options,
259+
},
256260
)
257261

258262

@@ -271,7 +275,7 @@ def get_run(
271275
if client:
272276
try:
273277
job = client.get_job(str(run_id))
274-
_sync_run_from_job(run, job, db)
278+
sync_run_from_job(run, job, db)
275279
db.commit()
276280
except SnakedispatchError:
277281
pass
@@ -368,7 +372,7 @@ def cancel_run(
368372
sd_client = _get_client_for_run(run)
369373
try:
370374
result = sd_client.cancel_job(str(run_id))
371-
_sync_run_from_job(run, result, db)
375+
sync_run_from_job(run, result, db)
372376
db.commit()
373377
except SnakedispatchError as e:
374378
if e.status_code in (404, 409):

src/pypsa_app/backend/main.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
import asyncio
2+
import contextlib
13
import logging
24
from collections.abc import AsyncIterator
35
from contextlib import asynccontextmanager
@@ -28,6 +30,7 @@
2830
from pypsa_app.backend.models import SnakedispatchBackend, User, UserRole
2931
from pypsa_app.backend.services.backend_registry import backend_registry
3032
from pypsa_app.backend.services.run import SnakedispatchError
33+
from pypsa_app.backend.services.sync import run_sync_loop
3134
from pypsa_app.backend.settings import API_V1_PREFIX, settings
3235

3336
logging.basicConfig(
@@ -183,10 +186,25 @@ async def lifespan(app: FastAPI) -> AsyncIterator[None]:
183186
"No Snakedispatch backends configured (SNAKEDISPATCH_BACKENDS not set)"
184187
)
185188

189+
sync_task = None
190+
if settings.resolved_backends:
191+
sync_task = asyncio.create_task(
192+
run_sync_loop(interval=settings.snakedispatch_sync_interval)
193+
)
194+
logger.info(
195+
"Background run sync started",
196+
extra={"interval": settings.snakedispatch_sync_interval},
197+
)
198+
186199
yield
187200

188201
# Shutdown
189202
logger.info("Shutting down PyPSA Web App API")
203+
if sync_task is not None:
204+
sync_task.cancel()
205+
with contextlib.suppress(asyncio.CancelledError):
206+
await sync_task
207+
logger.info("Background run sync stopped")
190208
engine.dispose()
191209
logger.info("Shutdown complete")
192210

src/pypsa_app/backend/models.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -327,6 +327,8 @@ class Run(Base):
327327
started_at = Column(TIMESTAMP, nullable=True)
328328
completed_at = Column(TIMESTAMP, nullable=True)
329329
import_networks = Column(JSON, nullable=True)
330+
total_job_count = Column(Integer, nullable=True)
331+
jobs_finished = Column(Integer, nullable=True)
330332

331333
networks = relationship(
332334
"Network", foreign_keys="Network.source_run_id", viewonly=True

0 commit comments

Comments
 (0)