Skip to content

Commit 8d09830

Browse files
committed
feat: improved in-memory task queue as celery alternative
1 parent 1b012eb commit 8d09830

4 files changed

Lines changed: 164 additions & 61 deletions

File tree

src/pypsa_app/backend/api/utils/task_utils.py

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,27 @@
1-
"""Shared utilities for Celery task status responses"""
1+
"""Shared utilities for task status responses"""
22

33
import logging
4-
from celery.result import AsyncResult
54

6-
from pypsa_app.backend.celery_app import celery_app
5+
from pypsa_app.backend.task_queue import task_app
76
from pypsa_app.backend.settings import settings
87

98
logger = logging.getLogger(__name__)
109

1110

1211
def get_task_status_response(task_id: str) -> dict:
13-
"""Get standardized Celery task status response"""
14-
task = AsyncResult(task_id, app=celery_app)
12+
"""Get standardized task status response"""
13+
# Check if using Celery or in-memory queue
14+
try:
15+
# Try using Celery's AsyncResult if Celery is the backend
16+
from celery.result import AsyncResult
17+
18+
task = AsyncResult(task_id, app=task_app)
19+
except (ImportError, AttributeError):
20+
# Fall back to in-memory AsyncResult
21+
from pypsa_app.backend.task_queue import InMemoryAsyncResult
22+
23+
task = InMemoryAsyncResult(task_id)
24+
1525
response = {"task_id": task_id, "state": task.state}
1626

1727
match task.state:

src/pypsa_app/backend/celery_app.py

Lines changed: 0 additions & 47 deletions
This file was deleted.
Lines changed: 140 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,140 @@
1+
"""Celery configuration and in-memory task queue fallback"""
2+
3+
import logging
4+
import threading
5+
import uuid
6+
from concurrent.futures import ThreadPoolExecutor
7+
from datetime import datetime, timedelta, timezone
8+
9+
from pypsa_app.backend.settings import settings
10+
11+
logger = logging.getLogger(__name__)
12+
13+
# Module-level task registry shared by InMemoryAsyncResult and
# InMemoryTaskQueue. Maps task_id -> {"state", "created_at", and on
# completion "result" / "exception" / "meta"}. All reads and writes must
# hold _lock, because tasks run on the worker thread in _pool.
_tasks = {}
_lock = threading.Lock()
# Single worker thread: tasks execute one at a time, in submission order.
_pool = ThreadPoolExecutor(max_workers=1, thread_name_prefix="task")
16+
17+
18+
class InMemoryAsyncResult:
    """Result handle mirroring the subset of Celery's ``AsyncResult`` API
    used by this project (``id``, ``state``, ``result``, ``info``).

    State is read from the module-level ``_tasks`` registry under ``_lock``;
    an unknown task id reports the "PENDING" state.
    """

    def __init__(self, task_id):
        # Celery exposes the task id as ``.id``; keep the same attribute name.
        self.id = task_id

    def _entry(self):
        """Return a shallow snapshot of this task's registry entry
        (empty dict when the id is not registered)."""
        with _lock:
            return dict(_tasks.get(self.id, {}))

    @property
    def state(self):
        """Current task state; "PENDING" when the id is unknown."""
        return self._entry().get("state", "PENDING")

    @property
    def result(self):
        """The task's return value once it succeeded, otherwise ``None``."""
        entry = self._entry()
        if entry.get("state") == "SUCCESS":
            return entry.get("result")
        return None

    @property
    def info(self):
        """Failure text on FAILURE, progress metadata on PROGRESS,
        otherwise ``None``."""
        entry = self._entry()
        current = entry.get("state")
        if current == "FAILURE":
            return entry.get("exception")
        if current == "PROGRESS":
            return entry.get("meta", {})
        return None
42+
43+
44+
class InMemoryTaskQueue:
    """Minimal in-process stand-in for a Celery app.

    Implements just the subset of the Celery API this project uses: the
    ``task(bind=..., name=...)`` decorator, ``func.apply_async`` /
    ``func.delay``, and for bound tasks ``self.update_state`` and
    ``self.request.id``. Tasks execute sequentially on the single worker
    thread in ``_pool``; bookkeeping lives in the module-level ``_tasks``
    dict guarded by ``_lock``.
    """

    # Finished entries older than this are purged on each submission,
    # mirroring Celery's result_expires=86400 (24h).
    RETENTION = timedelta(hours=24)

    def task(self, *args, **kwargs):
        """Return a decorator mimicking ``celery_app.task``.

        Supported options: ``bind`` (pass a task object as the first
        argument) and ``name`` (exposed as ``func.name``); any other
        options are accepted and ignored.
        """
        bind = kwargs.get("bind", False)

        def decorator(func):
            def apply_async(args=(), kwargs=None, **options):
                tid = str(uuid.uuid4())
                now = datetime.now(timezone.utc)

                with _lock:
                    _tasks[tid] = {"state": "PENDING", "created_at": now}
                    # Opportunistic cleanup. Only purge *terminal* entries:
                    # deleting a PENDING/PROGRESS entry would strip a queued
                    # or still-running task of its bookkeeping slot (the
                    # previous unconditional purge could make run() raise
                    # KeyError and lose the task's outcome).
                    cutoff = now - self.RETENTION
                    for key in list(_tasks.keys()):
                        entry = _tasks[key]
                        if (
                            entry.get("state") in ("SUCCESS", "FAILURE")
                            and entry.get("created_at", cutoff) < cutoff
                        ):
                            del _tasks[key]

                def _record(**fields):
                    # Defensive write: recreate the entry if it vanished so a
                    # finished task is always observable via its AsyncResult.
                    with _lock:
                        _tasks.setdefault(tid, {"created_at": now}).update(fields)

                class Task:
                    # Mimics the bound Celery task instance ("self" in tasks).
                    request = type("Request", (), {"id": tid})()

                    @staticmethod
                    def update_state(state=None, meta=None):
                        with _lock:
                            entry = _tasks.get(tid)
                            if entry is None:
                                return
                            if state:
                                entry["state"] = state
                            if meta:
                                entry["meta"] = meta

                def run():
                    try:
                        if bind:
                            res = func(Task(), *args, **(kwargs or {}))
                        else:
                            res = func(*args, **(kwargs or {}))
                    except Exception as e:
                        _record(state="FAILURE", exception=str(e))
                        logger.error(
                            "Task failed",
                            extra={"task_id": tid, "error": str(e)},
                            exc_info=True,
                        )
                    else:
                        _record(state="SUCCESS", result=res)

                _pool.submit(run)
                return InMemoryAsyncResult(tid)

            def delay(*d_args, **d_kwargs):
                # Celery parity: func.delay(a, b) == func.apply_async(args=(a, b)).
                return apply_async(args=d_args, kwargs=d_kwargs)

            func.apply_async = apply_async
            func.delay = delay
            func.name = kwargs.get("name", func.__name__)
            return func

        return decorator
99+
100+
101+
# Select the task backend at import time: real Celery when it is installed
# and a Redis URL is configured, otherwise the in-memory fallback above.
# Either way the module exports a ``task_app`` with a ``task`` decorator.
try:
    from celery import Celery

    # Only use real Celery if Redis URL is configured
    if not settings.redis_url:
        logger.warning(
            "Redis URL not configured - using in-memory task queue",
            extra={"backend": "in-memory", "background_tasks_enabled": True},
        )
        task_app = InMemoryTaskQueue()
    else:
        task_app = Celery(
            "pypsa_app",
            broker=settings.redis_url,
            backend=settings.redis_url,
            include=["pypsa_app.backend.tasks"],
        )

        task_app.conf.update(
            accept_content=["json"],
            result_expires=86400,  # keep results for 24h, matching the fallback
            worker_prefetch_multiplier=1,
            worker_max_tasks_per_child=10,  # recycle workers to bound memory growth
            task_soft_time_limit=3600,
            task_time_limit=7200,
            task_acks_late=True,
        )

        # Security: redis URLs may embed credentials ("redis://:pass@host"),
        # so log only host:port, never the full URL.
        from urllib.parse import urlsplit

        _parts = urlsplit(settings.redis_url)
        _host = _parts.hostname or ""
        if _parts.port:
            _host = f"{_host}:{_parts.port}"
        logger.info(
            "Initialized Celery with Redis backend",
            extra={"redis_host": _host, "backend": "celery"},
        )

except ImportError:
    logger.warning(
        "Celery not installed - using in-memory task queue",
        extra={"backend": "in-memory", "background_tasks_enabled": True},
    )
    task_app = InMemoryTaskQueue()

src/pypsa_app/backend/tasks.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from typing import Any, Callable, Dict
66

77
from pypsa_app.backend.cache import cache
8-
from pypsa_app.backend.celery_app import celery_app
8+
from pypsa_app.backend.task_queue import task_app
99
from pypsa_app.backend.schemas.task import TaskResult
1010
from pypsa_app.backend.services.map import extract_geographic_layer
1111
from pypsa_app.backend.services.network import scan_networks
@@ -46,28 +46,28 @@ def _execute_task(self, name: str, func: Callable, **kwargs) -> Dict[str, Any]:
4646
).model_dump()
4747

4848

49-
@celery_app.task(bind=True, name="tasks.get_statistics")
49+
@task_app.task(bind=True, name="tasks.get_statistics")
def get_statistics_task(self, **kwargs):
    """Background task: compute network statistics (result is cached)."""
    caching = cache("statistics", ttl=settings.plot_cache_ttl)
    return _execute_task(
        self, "Statistics generation", caching(get_statistics_service), **kwargs
    )
5454

5555

56-
@celery_app.task(bind=True, name="tasks.get_plot")
56+
@task_app.task(bind=True, name="tasks.get_plot")
def get_plot_task(self, **kwargs):
    """Background task: render a plot (result is cached)."""
    caching = cache("plot", ttl=settings.plot_cache_ttl)
    return _execute_task(self, "Plot generation", caching(get_plot_service), **kwargs)
6161

6262

63-
@celery_app.task(bind=True, name="tasks.extract_geographic_layer")
63+
@task_app.task(bind=True, name="tasks.extract_geographic_layer")
def extract_geographic_layer_task(self, **kwargs):
    """Background task: extract a geographic map layer (result is cached)."""
    caching = cache("map_data", ttl=settings.map_cache_ttl)
    return _execute_task(
        self, "Geographic layer extraction", caching(extract_geographic_layer), **kwargs
    )
6868

6969

70-
@celery_app.task(bind=True, name="tasks.scan_networks")
70+
@task_app.task(bind=True, name="tasks.scan_networks")
def scan_networks_task(self, **kwargs):
    """Background task: rescan available networks (intentionally uncached)."""
    return _execute_task(self, "Network scan", scan_networks, **kwargs)

0 commit comments

Comments
 (0)