Skip to content

Commit 66a3760

Browse files
committed
RabbitMQ back-end
1 parent df6241f commit 66a3760

4 files changed

Lines changed: 556 additions & 1 deletion

File tree

pyfuse/worker/backends/rabbitmq.py

Lines changed: 316 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,316 @@
1+
"""RabbitMQ backend for multi-machine task distribution.
2+
3+
Uses ``aio-pika`` (async AMQP 0-9-1 client) for task dispatch, result
4+
routing, heartbeats, cancellation, and progress.
5+
6+
Tasks are dispatched via a durable queue. Per-task results use dedicated
7+
queues with message TTL. Heartbeats, cancellation flags, and progress
8+
data are stored in single-message queues (``x-max-length: 1``) that
9+
behave like key-value slots. Result notifications use a fanout exchange.
10+
11+
URL scheme: ``amqp://`` or ``amqps://`` (e.g. ``amqp://guest:guest@localhost/``)
12+
"""
13+
from __future__ import annotations
14+
15+
import asyncio
16+
import contextlib
17+
import time
18+
from collections.abc import AsyncIterator
19+
from typing import Any
20+
21+
from pyfuse.worker.backends.base import Backend
22+
23+
24+
class RabbitMQBackend(Backend):
    """RabbitMQ-backed transport using ``aio-pika``.

    Tasks travel through one durable queue. Each task gets its own result
    queue, and heartbeat / cancellation / progress values live in per-task
    queues limited to a single message so they act as key-value slots.
    Result notifications are broadcast through a fanout exchange.

    Parameters
    ----------
    url
        AMQP connection URL (e.g. ``amqp://guest:guest@localhost/``).
    task_queue
        Name of the durable task queue. ``None`` (or an empty string)
        selects :attr:`TASK_QUEUE`.
    result_ttl
        Seconds before result messages expire. ``None`` (or ``0``) selects
        :attr:`DEFAULT_RESULT_TTL`.

    Raises
    ------
    ImportError
        If ``aio-pika`` is not installed.
    ValueError
        If *result_ttl* is negative.
    """

    TASK_QUEUE = "pyfuse.tasks"
    RESULT_PREFIX = "pyfuse.result."
    HEARTBEAT_PREFIX = "pyfuse.hb."
    CANCEL_PREFIX = "pyfuse.cancel."
    PROGRESS_PREFIX = "pyfuse.progress."
    NOTIFY_EXCHANGE = "pyfuse.notify"

    # All TTLs are expressed in seconds; they are converted to milliseconds
    # when the queue arguments are built.
    DEFAULT_RESULT_TTL = 300  # seconds
    HEARTBEAT_TTL = 30
    CANCEL_TTL = 3600
    PROGRESS_TTL = 300

    def __init__(
        self,
        url: str = "amqp://localhost",
        *,
        task_queue: str | None = None,
        result_ttl: int | None = None,
    ) -> None:
        # Fail at construction time with an actionable message instead of
        # at first use deep inside a coroutine.
        try:
            import aio_pika as _  # type: ignore[import-not-found]  # noqa: F401
        except ImportError:
            raise ImportError(
                "aio-pika package is required for RabbitMQBackend. "
                "Install it with: pip install aio-pika"
            ) from None
        # A negative TTL would only surface later as a broker error when
        # the first result queue is declared; reject it up front.
        if result_ttl is not None and result_ttl < 0:
            raise ValueError(
                f"result_ttl must be non-negative, got {result_ttl!r}"
            )
        self._url = url
        self._task_queue_name = task_queue or self.TASK_QUEUE
        self._result_ttl = result_ttl or self.DEFAULT_RESULT_TTL
        self._connection: Any = None
        self._channel: Any = None
        # Serialises connection/channel creation and teardown.
        self._lock = asyncio.Lock()

    # -- connection management --------------------------------------------------

    async def _ensure_connection(self) -> Any:
        """Return a live connection, (re)connecting if necessary.

        The caller must already hold ``self._lock``. A fresh connection
        invalidates the cached shared channel.
        """
        if self._connection is None or self._connection.is_closed:
            import aio_pika

            self._connection = await aio_pika.connect_robust(self._url)
            self._channel = None
        return self._connection

    async def _get_channel(self) -> Any:
        """Return the shared channel, creating connection if needed."""
        async with self._lock:
            connection = await self._ensure_connection()
            if self._channel is None or self._channel.is_closed:
                self._channel = await connection.channel()
            return self._channel

    async def _new_channel(self) -> Any:
        """Create a dedicated channel for long-running operations.

        The caller owns the returned channel and is responsible for
        closing it.
        """
        async with self._lock:
            connection = await self._ensure_connection()
            return await connection.channel()

    # -- internal helpers -------------------------------------------------------

    @staticmethod
    def _kv_args(ttl_s: int) -> dict[str, int]:
        """Queue arguments for a single-message key-value slot.

        *ttl_s* is in seconds; the broker arguments are in milliseconds.
        The queue itself is set to expire after twice the message TTL.
        """
        ttl_ms = ttl_s * 1000
        return {
            "x-message-ttl": ttl_ms,
            "x-max-length": 1,
            "x-expires": ttl_ms * 2,
        }

    def _result_args(self) -> dict[str, int]:
        """Queue arguments for a per-task result queue."""
        ttl_ms = self._result_ttl * 1000
        return {
            "x-message-ttl": ttl_ms,
            "x-expires": ttl_ms * 2,
        }

    async def _kv_put(
        self, prefix: str, task_id: str, value: str, ttl_s: int,
    ) -> None:
        """Write to a per-task KV queue (``x-max-length: 1`` overwrites).

        The queue is declared with the same arguments ``_kv_get`` uses so
        that either side may be the first to create it.
        """
        import aio_pika

        channel = await self._get_channel()
        name = f"{prefix}{task_id}"
        await channel.declare_queue(name, arguments=self._kv_args(ttl_s))
        await channel.default_exchange.publish(
            aio_pika.Message(value.encode()),
            routing_key=name,
        )

    async def _kv_get(
        self, prefix: str, task_id: str, ttl_s: int, *, peek: bool = False,
    ) -> str | None:
        """Read from a per-task KV queue.

        When *peek* is ``True`` the message is nack'd back so future
        reads still see it (used for cancellation flags). Otherwise the
        message is consumed.

        Returns the decoded message body, or ``None`` when the slot is
        empty.
        """
        channel = await self._get_channel()
        name = f"{prefix}{task_id}"
        queue = await channel.declare_queue(
            name, arguments=self._kv_args(ttl_s),
        )
        # Peeking needs manual acknowledgement so the message can be handed
        # back; a consuming read lets the broker auto-ack on delivery.
        msg = await queue.get(fail=False, no_ack=not peek)
        if msg is None:
            return None
        if peek:
            await msg.nack(requeue=True)
        raw: str = msg.body.decode()
        return raw

    # -- Backend interface: tasks -----------------------------------------------

    async def submit(self, task_json: str) -> None:
        """Publish *task_json* to the durable task queue as a persistent message."""
        import aio_pika

        channel = await self._get_channel()
        await channel.declare_queue(self._task_queue_name, durable=True)
        await channel.default_exchange.publish(
            aio_pika.Message(
                task_json.encode(),
                delivery_mode=aio_pika.DeliveryMode.PERSISTENT,
            ),
            routing_key=self._task_queue_name,
        )

    async def listen(self) -> AsyncIterator[str]:
        """Yield task payloads from the task queue, one at a time.

        Runs on a dedicated channel with ``prefetch_count=1``. The channel
        is closed when the generator finishes or is closed.
        """
        channel = await self._new_channel()
        try:
            await channel.set_qos(prefetch_count=1)
            queue = await channel.declare_queue(
                self._task_queue_name, durable=True,
            )
            async with queue.iterator() as qi:
                async for message in qi:
                    # The ``process()`` context stays open while the caller
                    # handles the yielded payload and exits when the next
                    # item is requested.
                    # NOTE(review): with aio-pika's defaults, closing this
                    # generator mid-task appears to reject the message
                    # without requeue (task dropped) -- confirm whether
                    # ``process(requeue=True)`` is wanted here.
                    async with message.process():
                        yield message.body.decode()
        finally:
            with contextlib.suppress(Exception):
                await channel.close()

    # -- Backend interface: results ---------------------------------------------

    async def send_result(self, task_id: str, result_json: str) -> None:
        """Publish *result_json* to the per-task result queue."""
        import aio_pika

        channel = await self._get_channel()
        name = f"{self.RESULT_PREFIX}{task_id}"
        await channel.declare_queue(name, arguments=self._result_args())
        await channel.default_exchange.publish(
            aio_pika.Message(result_json.encode()),
            routing_key=name,
        )

    async def get_result(self, task_id: str, timeout: float | None = None) -> str:
        """Wait for and return the result of *task_id*.

        Parameters
        ----------
        task_id
            Identifier of the task whose result queue is consumed.
        timeout
            Maximum seconds to wait; ``None`` waits indefinitely.

        Raises
        ------
        TimeoutError
            If no result arrives within *timeout* seconds.
        """
        channel = await self._new_channel()
        try:
            name = f"{self.RESULT_PREFIX}{task_id}"
            queue = await channel.declare_queue(
                name, arguments=self._result_args(),
            )
            future: asyncio.Future[str] = asyncio.get_running_loop().create_future()

            async def _on_message(msg: Any) -> None:
                if future.done():
                    # The waiter already timed out or was cancelled. Hand
                    # the message back instead of acking it, so the result
                    # is still available to a later get/try_get call.
                    await msg.nack(requeue=True)
                    return
                await msg.ack()
                if not future.done():
                    future.set_result(msg.body.decode())

            tag = await queue.consume(_on_message)
            try:
                if timeout is None:
                    return await future
                try:
                    return await asyncio.wait_for(future, timeout=timeout)
                except asyncio.TimeoutError:
                    raise TimeoutError(
                        f"Timed out waiting for result of task {task_id}"
                    ) from None
            finally:
                with contextlib.suppress(Exception):
                    await queue.cancel(tag)
        finally:
            with contextlib.suppress(Exception):
                await channel.close()

    async def try_get_result(self, task_id: str) -> str | None:
        """Return the result of *task_id* if one is queued, else ``None``.

        A returned result is acknowledged and therefore removed from the
        queue.
        """
        channel = await self._get_channel()
        name = f"{self.RESULT_PREFIX}{task_id}"
        queue = await channel.declare_queue(
            name, arguments=self._result_args(),
        )
        msg = await queue.get(fail=False)
        if msg is None:
            return None
        await msg.ack()
        raw: str = msg.body.decode()
        return raw

    # -- Heartbeat -------------------------------------------------------------

    async def send_heartbeat(self, task_id: str) -> None:
        """Store the current ``time.time()`` in the task's heartbeat slot."""
        await self._kv_put(
            self.HEARTBEAT_PREFIX, task_id,
            str(time.time()), self.HEARTBEAT_TTL,
        )

    async def get_heartbeat(self, task_id: str) -> float | None:
        """Return the last heartbeat timestamp, or ``None`` if the slot is empty.

        NOTE(review): this is a consuming read, so a second call before the
        next ``send_heartbeat`` returns ``None`` -- confirm callers expect
        that.
        """
        raw = await self._kv_get(
            self.HEARTBEAT_PREFIX, task_id, self.HEARTBEAT_TTL,
        )
        return float(raw) if raw is not None else None

    # -- Cancellation ----------------------------------------------------------

    async def cancel_task(self, task_id: str) -> None:
        """Set the cancellation flag for *task_id*."""
        await self._kv_put(
            self.CANCEL_PREFIX, task_id, "1", self.CANCEL_TTL,
        )

    async def is_cancelled(self, task_id: str) -> bool:
        """Return ``True`` if a cancellation flag is present for *task_id*.

        The flag is peeked (nack'd back), so repeated checks keep seeing it.
        """
        raw = await self._kv_get(
            self.CANCEL_PREFIX, task_id, self.CANCEL_TTL, peek=True,
        )
        return raw is not None

    # -- Progress --------------------------------------------------------------

    async def send_progress(self, task_id: str, progress_json: str) -> None:
        """Store *progress_json* in the task's progress slot."""
        await self._kv_put(
            self.PROGRESS_PREFIX, task_id, progress_json, self.PROGRESS_TTL,
        )

    async def get_progress(self, task_id: str) -> str | None:
        """Return the stored progress payload, or ``None`` if the slot is empty.

        NOTE(review): this is a consuming read, so each progress update can
        be read only once -- confirm callers expect that.
        """
        return await self._kv_get(
            self.PROGRESS_PREFIX, task_id, self.PROGRESS_TTL,
        )

    # -- Result notifications --------------------------------------------------

    async def notify_result(self, task_id: str) -> None:
        """Broadcast *task_id* on the fanout notification exchange."""
        import aio_pika

        channel = await self._get_channel()
        exchange = await channel.declare_exchange(
            self.NOTIFY_EXCHANGE, aio_pika.ExchangeType.FANOUT,
        )
        await exchange.publish(
            aio_pika.Message(task_id.encode()),
            routing_key="",
        )

    async def subscribe_results(self) -> AsyncIterator[str]:
        """Yield task IDs as result notifications are broadcast.

        Uses a dedicated channel with an exclusive, server-named queue bound
        to the notification exchange. The channel is closed when the
        generator finishes or is closed.
        """
        import aio_pika

        channel = await self._new_channel()
        try:
            exchange = await channel.declare_exchange(
                self.NOTIFY_EXCHANGE, aio_pika.ExchangeType.FANOUT,
            )
            queue = await channel.declare_queue(exclusive=True)
            await queue.bind(exchange)
            async with queue.iterator() as qi:
                async for message in qi:
                    async with message.process():
                        yield message.body.decode()
        finally:
            with contextlib.suppress(Exception):
                await channel.close()

    # -- Lifecycle -------------------------------------------------------------

    async def close(self) -> None:
        """Close the shared channel and the connection (best effort).

        Runs under the connection lock so it cannot interleave with
        ``_get_channel`` / ``_new_channel`` opening a new connection.
        Errors raised while closing are suppressed.
        """
        async with self._lock:
            if self._channel is not None:
                with contextlib.suppress(Exception):
                    await self._channel.close()
                self._channel = None
            if self._connection is not None:
                with contextlib.suppress(Exception):
                    await self._connection.close()
                self._connection = None

pyfuse/worker/remote.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,9 +53,13 @@ def _create_backend(url: str, **kwargs: Any) -> Backend:
5353
from pyfuse.worker.backends.local import LocalBackend
5454

5555
return LocalBackend(url, **kwargs)
56+
if scheme in ("amqp", "amqps"):
57+
from pyfuse.worker.backends.rabbitmq import RabbitMQBackend
58+
59+
return RabbitMQBackend(url, **kwargs)
5660
raise ValueError(
5761
f"Unknown backend scheme: {scheme!r}. "
58-
f"Supported: redis://, rediss://, local://"
62+
f"Supported: redis://, rediss://, local://, amqp://, amqps://"
5963
)
6064

6165

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ pyfuse = "pyfuse.__main__:main"
2929

3030
[project.optional-dependencies]
3131
redis = ["redis>=5.0"]
32+
rabbitmq = ["aio-pika>=9.0"]
3233

3334
[dependency-groups]
3435
dev = [

0 commit comments

Comments
 (0)