Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,7 @@ CQRS = {
'port': RABBITMQ_PORT,
'user': RABBITMQ_USERNAME,
'password': RABBITMQ_PASSWORD,
'virtual_host': RABBITMQ_VIRTUAL_HOST,
}

```
Expand Down Expand Up @@ -119,6 +120,7 @@ CQRS = {
'port': RABBITMQ_PORT,
'user': RABBITMQ_USERNAME,
'password': RABBITMQ_PASSWORD,
'virtual_host': RABBITMQ_VIRTUAL_HOST,
}
```
* Apply migrations on both services
Expand Down
7 changes: 4 additions & 3 deletions dj_cqrs/management/commands/cqrs_dead_letters.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,8 @@ def get_common_settings(cls):
return cls._get_common_settings()

@classmethod
def create_connection(cls, host, port, creds, exchange):
return cls._create_connection(host, port, creds, exchange)
def create_connection(cls, host, port, creds, virtual_host, exchange):
return cls._create_connection(host, port, creds, virtual_host, exchange)

@classmethod
def declare_queue(cls, channel, queue_name):
Expand Down Expand Up @@ -75,11 +75,12 @@ def check_transport(self):
raise CommandError('Dead letters commands available only for RabbitMQTransport.')

def init_broker(self):
host, port, creds, exchange = RabbitMQTransportService.get_common_settings()
host, port, creds, virtual_host, exchange = RabbitMQTransportService.get_common_settings()
connection, channel = RabbitMQTransportService.create_connection(
host,
port,
creds,
virtual_host,
exchange,
)

Expand Down
20 changes: 14 additions & 6 deletions dj_cqrs/transport/rabbit_mq.py
Original file line number Diff line number Diff line change
Expand Up @@ -287,14 +287,15 @@ def _get_consumer_rmq_objects(
host,
port,
creds,
virtual_host,
exchange,
queue_name,
dead_letter_queue_name,
prefetch_count,
cqrs_ids=None,
):
connection = BlockingConnection(
ConnectionParameters(host=host, port=port, credentials=creds),
ConnectionParameters(host=host, port=port, credentials=creds, virtual_host=virtual_host),
)
channel = connection.channel()
channel.basic_qos(prefetch_count=prefetch_count)
Expand Down Expand Up @@ -333,29 +334,30 @@ def _get_consumer_rmq_objects(
return connection, channel, consumer_generator

@classmethod
def _get_producer_rmq_objects(cls, host, port, creds, exchange, signal_type=None):
def _get_producer_rmq_objects(cls, host, port, creds, virtual_host, exchange, signal_type=None):
"""
Use shared connection in case of sync mode, otherwise create new connection for each
message
"""
if signal_type == SignalType.SYNC:
if cls._producer_connection is None:
connection, channel = cls._create_connection(host, port, creds, exchange)
connection, channel = cls._create_connection(host, port, creds, virtual_host, exchange)

cls._producer_connection = connection
cls._producer_channel = channel

return cls._producer_connection, cls._producer_channel
else:
return cls._create_connection(host, port, creds, exchange)
return cls._create_connection(host, port, creds, virtual_host, exchange)

@classmethod
def _create_connection(cls, host, port, creds, exchange):
def _create_connection(cls, host, port, creds, virtual_host, exchange):
connection = BlockingConnection(
ConnectionParameters(
host=host,
port=port,
credentials=creds,
virtual_host=virtual_host,
blocked_connection_timeout=10,
),
)
Expand Down Expand Up @@ -386,22 +388,28 @@ def _parse_url(url):
parts.port or ConnectionParameters.DEFAULT_PORT,
unquote(parts.username or '') or ConnectionParameters.DEFAULT_USERNAME,
unquote(parts.password or '') or ConnectionParameters.DEFAULT_PASSWORD,
unquote(parts.path.lstrip('/')) or ConnectionParameters.DEFAULT_VIRTUAL_HOST,
)

@classmethod
def _get_common_settings(cls):
if 'url' in settings.CQRS:
host, port, user, password = cls._parse_url(settings.CQRS.get('url'))
host, port, user, password, virtual_host = cls._parse_url(settings.CQRS.get('url'))
else:
host = settings.CQRS.get('host', ConnectionParameters.DEFAULT_HOST)
port = settings.CQRS.get('port', ConnectionParameters.DEFAULT_PORT)
user = settings.CQRS.get('user', ConnectionParameters.DEFAULT_USERNAME)
password = settings.CQRS.get('password', ConnectionParameters.DEFAULT_PASSWORD)
virtual_host = settings.CQRS.get(
'virtual_host',
ConnectionParameters.DEFAULT_VIRTUAL_HOST,
)
exchange = settings.CQRS.get('exchange', 'cqrs')
return (
host,
port,
credentials.PlainCredentials(user, password, erase_on_connect=True),
virtual_host,
exchange,
)

Expand Down
23 changes: 23 additions & 0 deletions docs/transports.md
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,29 @@ CQRS = {
}
```

The virtual host is optional and defaults to `/` (root vhost). To specify a custom virtual host,
include it in the URL path:

``` py3
# Using root virtual host (default)
CQRS = {
'transport': 'dj_cqrs.transport.RabbitMQTransport',
'url': 'amqp://guest:guest@rabbit:5672/'
}

# Using a custom virtual host
CQRS = {
'transport': 'dj_cqrs.transport.RabbitMQTransport',
'url': 'amqp://guest:guest@rabbit:5672/my_vhost'
}

# Using nested virtual host paths
CQRS = {
'transport': 'dj_cqrs.transport.RabbitMQTransport',
'url': 'amqp://guest:guest@rabbit:5672/production/app1'
}
```

!!! warning

Previous versions of the `RabbitMQTransport` use the attributes `host`,
Expand Down
2 changes: 1 addition & 1 deletion tests/test_commands/test_dead_letters.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ def test_dump(capsys, mocker):
mocker.patch.object(
RabbitMQTransport,
'_get_common_settings',
return_value=('host', 'port', mocker.MagicMock(), 'exchange'),
return_value=('host', 'port', mocker.MagicMock(), '/', 'exchange'),
)

queue = mocker.MagicMock()
Expand Down
140 changes: 135 additions & 5 deletions tests/test_transport/test_rabbit_mq.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,8 @@ def test_default_settings():
assert s[0] == 'localhost'
assert s[1] == 5672
assert s[2].username == 'guest' and s[2].password == 'guest'
assert s[3] == 'cqrs'
assert s[3] == '/'
assert s[4] == 'cqrs'


def test_non_default_settings(settings, caplog):
Expand All @@ -78,14 +79,16 @@ def test_non_default_settings(settings, caplog):
'port': 8000,
'user': 'usr',
'password': 'pswd',
'virtual_host': 'test',
'exchange': 'exchange',
}

s = PublicRabbitMQTransport.get_common_settings()
assert s[0] == 'rabbit'
assert s[1] == 8000
assert s[2].username == 'usr' and s[2].password == 'pswd'
assert s[3] == 'exchange'
assert s[3] == 'test'
assert s[4] == 'exchange'


def test_default_url_settings(settings):
Expand All @@ -97,20 +100,50 @@ def test_default_url_settings(settings):
assert s[0] == 'localhost'
assert s[1] == 5672
assert s[2].username == 'guest' and s[2].password == 'guest'
assert s[3] == 'cqrs'
assert s[3] == '/'
assert s[4] == 'cqrs'


def test_non_default_url_settings(settings):
settings.CQRS = {
'transport': 'dj_cqrs.transport.rabbit_mq.RabbitMQTransport',
'url': 'amqp://usr:pswd@rabbit:8000',
'url': 'amqp://usr:pswd@rabbit:8000/test',
'exchange': 'exchange',
}
s = PublicRabbitMQTransport.get_common_settings()
assert s[0] == 'rabbit'
assert s[1] == 8000
assert s[2].username == 'usr' and s[2].password == 'pswd'
assert s[3] == 'exchange'
assert s[3] == 'test'
assert s[4] == 'exchange'


def test_root_virtual_host_url_settings(settings):
settings.CQRS = {
'transport': 'dj_cqrs.transport.rabbit_mq.RabbitMQTransport',
'url': 'amqp://usr:pswd@rabbit:8000/',
'exchange': 'exchange',
}
s = PublicRabbitMQTransport.get_common_settings()
assert s[0] == 'rabbit'
assert s[1] == 8000
assert s[2].username == 'usr' and s[2].password == 'pswd'
assert s[3] == '/'
assert s[4] == 'exchange'


def test_nested_virtual_host_url_settings(settings):
settings.CQRS = {
'transport': 'dj_cqrs.transport.rabbit_mq.RabbitMQTransport',
'url': 'amqp://usr:pswd@rabbit:8000/foo/bar',
'exchange': 'exchange',
}
s = PublicRabbitMQTransport.get_common_settings()
assert s[0] == 'rabbit'
assert s[1] == 8000
assert s[2].username == 'usr' and s[2].password == 'pswd'
assert s[3] == 'foo/bar'
assert s[4] == 'exchange'


def test_invalid_url_settings(settings):
Expand Down Expand Up @@ -151,6 +184,103 @@ def test_consumer_non_default_settings(settings, caplog):
assert "The 'consumer_prefetch_count' setting is ignored for RabbitMQTransport." in caplog.text


def test_get_consumer_rmq_objects_passes_virtual_host(mocker):
connection = mocker.MagicMock()
channel = mocker.MagicMock()
consumer_generator = iter(())

connection.channel.return_value = channel
channel.consume.return_value = consumer_generator

connection_parameters = mocker.sentinel.connection_parameters
connection_parameters_cls = mocker.patch(
'dj_cqrs.transport.rabbit_mq.ConnectionParameters',
return_value=connection_parameters,
)
blocking_connection = mocker.patch(
'dj_cqrs.transport.rabbit_mq.BlockingConnection',
return_value=connection,
)

creds = mocker.MagicMock()
result = RabbitMQTransport._get_consumer_rmq_objects(
host='rabbit',
port=5672,
creds=creds,
virtual_host='custom-vhost',
exchange='cqrs',
queue_name='replica',
dead_letter_queue_name='dead_letter_replica',
prefetch_count=10,
)

connection_parameters_cls.assert_called_once_with(
host='rabbit',
port=5672,
credentials=creds,
virtual_host='custom-vhost',
)
blocking_connection.assert_called_once_with(connection_parameters)
assert result == (connection, channel, consumer_generator)


def test_get_producer_rmq_objects_passes_virtual_host_for_async(mocker):
creds = mocker.MagicMock()
create_connection = mocker.patch.object(
RabbitMQTransport,
'_create_connection',
return_value=(mocker.MagicMock(), mocker.MagicMock()),
)

RabbitMQTransport._get_producer_rmq_objects(
host='rabbit',
port=5672,
creds=creds,
virtual_host='custom-vhost',
exchange='cqrs',
)

create_connection.assert_called_once_with('rabbit', 5672, creds, 'custom-vhost', 'cqrs')


def test_get_producer_rmq_objects_passes_virtual_host_for_sync(mocker):
RabbitMQTransport._producer_connection = None
RabbitMQTransport._producer_channel = None

connection = mocker.MagicMock()
channel = mocker.MagicMock()
creds = mocker.MagicMock()
create_connection = mocker.patch.object(
RabbitMQTransport,
'_create_connection',
return_value=(connection, channel),
)

first_result = RabbitMQTransport._get_producer_rmq_objects(
host='rabbit',
port=5672,
creds=creds,
virtual_host='custom-vhost',
exchange='cqrs',
signal_type=SignalType.SYNC,
)
second_result = RabbitMQTransport._get_producer_rmq_objects(
host='rabbit',
port=5672,
creds=creds,
virtual_host='another-vhost',
exchange='cqrs',
signal_type=SignalType.SYNC,
)

create_connection.assert_called_once_with('rabbit', 5672, creds, 'custom-vhost', 'cqrs')
assert first_result == (connection, channel)
assert second_result == (connection, channel)

RabbitMQTransport._producer_connection = None
RabbitMQTransport._producer_channel = None


@pytest.fixture
def rabbit_transport(settings):
settings.CQRS = {
Expand Down