Skip to content

Commit 5dd4efa

Browse files
committed
Fix Redis E2E on GHA
1 parent 4c6e13d commit 5dd4efa

2 files changed

Lines changed: 104 additions & 10 deletions

File tree

offwork/worker/backends/redis.py

Lines changed: 43 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,15 @@
11
"""Redis-backed transport using ``RPUSH``/``BLPOP`` for tasks and results."""
22

3+
import math
34
import time
45
import asyncio
56
from typing import Any
7+
from urllib.parse import parse_qs, urlparse
68
from collections.abc import AsyncIterator
79

810
try:
911
import redis.asyncio as _redis
12+
from redis.exceptions import TimeoutError as RedisTimeoutError
1013
except ImportError:
1114
raise ImportError(
1215
"redis package is required for RedisBackend. "
@@ -50,7 +53,11 @@ def __init__(
5053
queue_key: str | None = None,
5154
result_ttl: int | None = None,
5255
) -> None:
53-
self._redis: Any = _redis.Redis.from_url(url)
56+
query = parse_qs(urlparse(url).query)
57+
connect_kwargs: dict[str, Any] = {}
58+
if "socket_timeout" not in query:
59+
connect_kwargs["socket_timeout"] = None
60+
self._redis: Any = _redis.Redis.from_url(url, **connect_kwargs)
5461
self._queue_key = queue_key or self.DEFAULT_QUEUE_KEY
5562
self._result_ttl = result_ttl or self.DEFAULT_RESULT_TTL
5663

@@ -60,7 +67,13 @@ async def submit(self, task_json: str) -> None:
6067
async def listen(self) -> AsyncIterator[str]:
6168
"""Block on ``BLPOP`` and yield task JSON strings as they arrive."""
6269
while True:
63-
result = await self._redis.blpop(self._queue_key)
70+
try:
71+
result = await self._redis.blpop(self._queue_key)
72+
except RedisTimeoutError:
73+
task = asyncio.current_task()
74+
if task is not None and task.cancelling():
75+
raise asyncio.CancelledError() from None
76+
continue
6477
if result is None:
6578
continue
6679
_, raw = result
@@ -73,14 +86,34 @@ async def send_result(self, task_id: str, result_json: str) -> None:
7386

7487
async def get_result(self, task_id: str, timeout: float | None = None) -> str:
7588
key = f"{self.RESULT_PREFIX}{task_id}"
76-
t = int(timeout) if timeout else 0
77-
result = await self._redis.blpop(key, timeout=t)
78-
if result is None:
79-
raise TimeoutError(
80-
f"Timed out waiting for result of task {task_id}"
81-
)
82-
_, raw = result
83-
return raw.decode() if isinstance(raw, bytes) else raw
89+
deadline = None if timeout is None else time.monotonic() + max(0.0, timeout)
90+
while True:
91+
if deadline is None:
92+
block_seconds = 0
93+
else:
94+
remaining = deadline - time.monotonic()
95+
if remaining <= 0:
96+
raise TimeoutError(
97+
f"Timed out waiting for result of task {task_id}"
98+
)
99+
block_seconds = max(1, math.ceil(remaining))
100+
try:
101+
result = await self._redis.blpop(key, timeout=block_seconds)
102+
except RedisTimeoutError:
103+
task = asyncio.current_task()
104+
if task is not None and task.cancelling():
105+
raise asyncio.CancelledError() from None
106+
if deadline is not None and time.monotonic() >= deadline:
107+
raise TimeoutError(
108+
f"Timed out waiting for result of task {task_id}"
109+
) from None
110+
continue
111+
if result is None:
112+
raise TimeoutError(
113+
f"Timed out waiting for result of task {task_id}"
114+
)
115+
_, raw = result
116+
return raw.decode() if isinstance(raw, bytes) else raw
84117

85118
async def try_get_result(self, task_id: str) -> str | None:
86119
"""Non-blocking ``LPOP``; returns ``None`` if not yet available."""

tests/test_redis_backend.py

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
import asyncio
2+
3+
import pytest
4+
from redis.exceptions import TimeoutError as RedisTimeoutError
5+
6+
from offwork.worker.backends.redis import RedisBackend
7+
8+
9+
class _FakeRedis:
10+
def __init__(self) -> None:
11+
self.listen_calls = 0
12+
self.result_calls = 0
13+
14+
async def blpop(self, key: str, timeout: int | None = None) -> tuple[str, str] | None:
15+
if key == "offwork:tasks":
16+
self.listen_calls += 1
17+
if self.listen_calls == 1:
18+
raise RedisTimeoutError("transient read timeout")
19+
return (key, "task-json")
20+
self.result_calls += 1
21+
raise RedisTimeoutError("read timeout")
22+
23+
async def aclose(self) -> None:
24+
return
25+
26+
27+
@pytest.mark.asyncio
28+
async def test_listen_retries_after_redis_timeout() -> None:
29+
backend = RedisBackend()
30+
backend._redis = _FakeRedis()
31+
32+
stream = backend.listen()
33+
item = await anext(stream)
34+
35+
assert item == "task-json"
36+
37+
38+
@pytest.mark.asyncio
39+
async def test_get_result_translates_redis_timeout() -> None:
40+
backend = RedisBackend()
41+
backend._redis = _FakeRedis()
42+
43+
with pytest.raises(TimeoutError, match="Timed out waiting for result"):
44+
await backend.get_result("task-1", timeout=0.01)
45+
46+
47+
@pytest.mark.asyncio
48+
async def test_get_result_returns_payload() -> None:
49+
backend = RedisBackend()
50+
51+
class _ResultRedis:
52+
async def blpop(self, key: str, timeout: int | None = None) -> tuple[str, bytes]:
53+
return (key, b"ok")
54+
55+
async def aclose(self) -> None:
56+
return
57+
58+
backend._redis = _ResultRedis()
59+
60+
result = await backend.get_result("task-2", timeout=1)
61+
assert result == "ok"

0 commit comments

Comments
 (0)