Skip to content

Commit 76fd0a3

Browse files
unix_socket_path in RedisSettings (#336)
* from rebase * update: testing requirements * Update tests/conftest.py Co-authored-by: Samuel Colvin <samcolvin@gmail.com> * update: testing requirements Co-authored-by: Samuel Colvin <samcolvin@gmail.com>
1 parent be6c763 commit 76fd0a3

6 files changed

Lines changed: 58 additions & 5 deletions

File tree

arq/connections.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from datetime import datetime, timedelta
66
from operator import attrgetter
77
from typing import TYPE_CHECKING, Any, Callable, List, Optional, Tuple, Union
8-
from urllib.parse import urlparse
8+
from urllib.parse import parse_qs, urlparse
99
from uuid import uuid4
1010

1111
from redis.asyncio import ConnectionPool, Redis
@@ -29,6 +29,7 @@ class RedisSettings:
2929

3030
host: Union[str, List[Tuple[str, int]]] = 'localhost'
3131
port: int = 6379
32+
unix_socket_path: Optional[str] = None
3233
database: int = 0
3334
username: Optional[str] = None
3435
password: Optional[str] = None
@@ -49,14 +50,21 @@ class RedisSettings:
4950
@classmethod
5051
def from_dsn(cls, dsn: str) -> 'RedisSettings':
5152
conf = urlparse(dsn)
52-
assert conf.scheme in {'redis', 'rediss'}, 'invalid DSN scheme'
53+
assert conf.scheme in {'redis', 'rediss', 'unix'}, 'invalid DSN scheme'
54+
query_db = parse_qs(conf.query).get('db')
55+
if query_db:
56+
# e.g. redis://localhost:6379?db=1
57+
database = int(query_db[0])
58+
else:
59+
database = int(conf.path.lstrip('/')) if conf.path else 0
5360
return RedisSettings(
5461
host=conf.hostname or 'localhost',
5562
port=conf.port or 6379,
5663
ssl=conf.scheme == 'rediss',
5764
username=conf.username,
5865
password=conf.password,
59-
database=int((conf.path or '0').strip('/')),
66+
database=database,
67+
unix_socket_path=conf.path if conf.scheme == 'unix' else None,
6068
)
6169

6270
def __repr__(self) -> str:
@@ -230,6 +238,7 @@ def pool_factory(*args: Any, **kwargs: Any) -> ArqRedis:
230238
ArqRedis,
231239
host=settings.host,
232240
port=settings.port,
241+
unix_socket_path=settings.unix_socket_path,
233242
socket_connect_timeout=settings.conn_timeout,
234243
ssl=settings.ssl,
235244
ssl_keyfile=settings.ssl_keyfile,

requirements/testing.in

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,3 +8,4 @@ pytest-mock>=3,<4
88
pytest-sugar>=0.9,<1
99
pytest-timeout>=2,<3
1010
pytz
11+
redislite

requirements/testing.txt

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,14 @@
44
#
55
# pip-compile --output-file=requirements/testing.txt requirements/testing.in
66
#
7+
async-timeout==4.0.2
8+
# via redis
79
attrs==22.1.0
810
# via pytest
911
coverage[toml]==6.4.4
1012
# via -r requirements/testing.in
13+
deprecated==1.2.13
14+
# via redis
1115
dirty-equals==0.4
1216
# via -r requirements/testing.in
1317
iniconfig==1.1.1
@@ -18,8 +22,11 @@ packaging==21.3
1822
# via
1923
# pytest
2024
# pytest-sugar
25+
# redis
2126
pluggy==1.0.0
2227
# via pytest
28+
psutil==5.9.1
29+
# via redislite
2330
py==1.11.0
2431
# via pytest
2532
pydantic==1.9.2
@@ -45,6 +52,10 @@ pytz==2022.2.1
4552
# via
4653
# -r requirements/testing.in
4754
# dirty-equals
55+
redis==4.2.2
56+
# via redislite
57+
redislite==6.2.805324
58+
# via -r requirements/testing.in
4859
termcolor==1.1.0
4960
# via pytest-sugar
5061
tomli==2.0.1
@@ -53,3 +64,8 @@ tomli==2.0.1
5364
# pytest
5465
typing-extensions==4.3.0
5566
# via pydantic
67+
wrapt==1.14.1
68+
# via deprecated
69+
70+
# The following packages are considered to be unsafe in a requirements file:
71+
# setuptools

tests/conftest.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55

66
import msgpack
77
import pytest
8+
from redislite import Redis
89

910
from arq.connections import ArqRedis, create_pool
1011
from arq.worker import Worker
@@ -30,6 +31,13 @@ async def arq_redis(loop):
3031
await redis_.close(close_connection_pool=True)
3132

3233

34+
@pytest.fixture
35+
async def unix_socket_path(loop, tmp_path):
36+
rdb = Redis(str(tmp_path / 'redis_test.db'))
37+
yield rdb.socket_file
38+
rdb.close()
39+
40+
3341
@pytest.fixture
3442
async def arq_redis_msgpack(loop):
3543
redis_ = ArqRedis(

tests/test_jobs.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,17 @@ async def foobar(ctx, *args, **kwargs):
7676
]
7777

7878

79+
async def test_enqueue_job_with_unix_socket(worker, unix_socket_path):
80+
"""Test initializing arq_redis using a unix socket connection, and the worker using it."""
81+
settings = RedisSettings(unix_socket_path=unix_socket_path)
82+
arq_redis = await create_pool(settings, default_queue_name='socket_queue')
83+
await test_enqueue_job(
84+
arq_redis,
85+
lambda functions, **_: worker(functions=functions, arq_redis=arq_redis, queue_name=None),
86+
queue_name=None,
87+
)
88+
89+
7990
async def test_enqueue_job_alt_queue(arq_redis: ArqRedis, worker):
8091
await test_enqueue_job(arq_redis, worker, queue_name='custom_queue')
8192

tests/test_utils.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,8 @@ def test_settings_changed():
1818
settings = RedisSettings(port=123)
1919
assert settings.port == 123
2020
assert (
21-
"RedisSettings(host='localhost', port=123, database=0, username=None, password=None, ssl=False, "
22-
"ssl_keyfile=None, ssl_certfile=None, ssl_cert_reqs='required', ssl_ca_certs=None, "
21+
"RedisSettings(host='localhost', port=123, unix_socket_path=None, database=0, username=None, password=None, "
22+
"ssl=False, ssl_keyfile=None, ssl_certfile=None, ssl_cert_reqs='required', ssl_ca_certs=None, "
2323
'ssl_ca_data=None, ssl_check_hostname=False, conn_timeout=1, conn_retries=5, conn_retry_delay=1, '
2424
"sentinel=False, sentinel_master='mymaster')"
2525
) == str(settings)
@@ -141,6 +141,14 @@ def parse_redis_settings(cls, v):
141141
assert s4.redis_settings.username == 'user'
142142
assert s4.redis_settings.password == 'pass'
143143

144+
s5 = Settings(redis_settings={'unix_socket_path': '/tmp/redis.sock'})
145+
assert s5.redis_settings.unix_socket_path == '/tmp/redis.sock'
146+
assert s5.redis_settings.database == 0
147+
148+
s6 = Settings(redis_settings='unix:///tmp/redis.socket?db=6')
149+
assert s6.redis_settings.unix_socket_path == '/tmp/redis.socket'
150+
assert s6.redis_settings.database == 6
151+
144152

145153
def test_ms_to_datetime_tz(env: SetEnv):
146154
arq.utils.get_tz.cache_clear()

0 commit comments

Comments
 (0)