diff --git a/README.md b/README.md index 373b225..61b73cb 100644 --- a/README.md +++ b/README.md @@ -80,6 +80,7 @@ CQRS = { 'port': RABBITMQ_PORT, 'user': RABBITMQ_USERNAME, 'password': RABBITMQ_PASSWORD, + 'virtual_host': RABBITMQ_VIRTUAL_HOST, } ``` @@ -119,6 +120,7 @@ CQRS = { 'port': RABBITMQ_PORT, 'user': RABBITMQ_USERNAME, 'password': RABBITMQ_PASSWORD, + 'virtual_host': RABBITMQ_VIRTUAL_HOST, } ``` * Apply migrations on both services diff --git a/dj_cqrs/management/commands/cqrs_dead_letters.py b/dj_cqrs/management/commands/cqrs_dead_letters.py index aa3f974..67a9846 100644 --- a/dj_cqrs/management/commands/cqrs_dead_letters.py +++ b/dj_cqrs/management/commands/cqrs_dead_letters.py @@ -21,8 +21,8 @@ def get_common_settings(cls): return cls._get_common_settings() @classmethod - def create_connection(cls, host, port, creds, exchange): - return cls._create_connection(host, port, creds, exchange) + def create_connection(cls, host, port, creds, virtual_host, exchange): + return cls._create_connection(host, port, creds, virtual_host, exchange) @classmethod def declare_queue(cls, channel, queue_name): @@ -75,11 +75,12 @@ def check_transport(self): raise CommandError('Dead letters commands available only for RabbitMQTransport.') def init_broker(self): - host, port, creds, exchange = RabbitMQTransportService.get_common_settings() + host, port, creds, virtual_host, exchange = RabbitMQTransportService.get_common_settings() connection, channel = RabbitMQTransportService.create_connection( host, port, creds, + virtual_host, exchange, ) diff --git a/dj_cqrs/transport/rabbit_mq.py b/dj_cqrs/transport/rabbit_mq.py index 8808fc3..a825dbb 100644 --- a/dj_cqrs/transport/rabbit_mq.py +++ b/dj_cqrs/transport/rabbit_mq.py @@ -287,6 +287,7 @@ def _get_consumer_rmq_objects( host, port, creds, + virtual_host, exchange, queue_name, dead_letter_queue_name, @@ -294,7 +295,7 @@ def _get_consumer_rmq_objects( cqrs_ids=None, ): connection = BlockingConnection( - ConnectionParameters(host=host, port=port, credentials=creds), + ConnectionParameters(host=host, port=port, credentials=creds, virtual_host=virtual_host), ) channel = connection.channel() channel.basic_qos(prefetch_count=prefetch_count) @@ -333,29 +334,30 @@ def _get_consumer_rmq_objects( return connection, channel, consumer_generator @classmethod - def _get_producer_rmq_objects(cls, host, port, creds, exchange, signal_type=None): + def _get_producer_rmq_objects(cls, host, port, creds, virtual_host, exchange, signal_type=None): """ Use shared connection in case of sync mode, otherwise create new connection for each message """ if signal_type == SignalType.SYNC: if cls._producer_connection is None: - connection, channel = cls._create_connection(host, port, creds, exchange) + connection, channel = cls._create_connection(host, port, creds, virtual_host, exchange) cls._producer_connection = connection cls._producer_channel = channel return cls._producer_connection, cls._producer_channel else: - return cls._create_connection(host, port, creds, exchange) + return cls._create_connection(host, port, creds, virtual_host, exchange) @classmethod - def _create_connection(cls, host, port, creds, exchange): + def _create_connection(cls, host, port, creds, virtual_host, exchange): connection = BlockingConnection( ConnectionParameters( host=host, port=port, credentials=creds, + virtual_host=virtual_host, blocked_connection_timeout=10, ), ) @@ -386,22 +388,28 @@ def _parse_url(url): parts.port or ConnectionParameters.DEFAULT_PORT, unquote(parts.username or '') or ConnectionParameters.DEFAULT_USERNAME, unquote(parts.password or '') or ConnectionParameters.DEFAULT_PASSWORD, + unquote(parts.path.lstrip('/')) or ConnectionParameters.DEFAULT_VIRTUAL_HOST, ) @classmethod def _get_common_settings(cls): if 'url' in settings.CQRS: - host, port, user, password = cls._parse_url(settings.CQRS.get('url')) + host, port, user, password, virtual_host = cls._parse_url(settings.CQRS.get('url')) else: host = settings.CQRS.get('host', ConnectionParameters.DEFAULT_HOST) port = settings.CQRS.get('port', ConnectionParameters.DEFAULT_PORT) user = settings.CQRS.get('user', ConnectionParameters.DEFAULT_USERNAME) password = settings.CQRS.get('password', ConnectionParameters.DEFAULT_PASSWORD) + virtual_host = settings.CQRS.get( + 'virtual_host', + ConnectionParameters.DEFAULT_VIRTUAL_HOST, + ) exchange = settings.CQRS.get('exchange', 'cqrs') return ( host, port, credentials.PlainCredentials(user, password, erase_on_connect=True), + virtual_host, exchange, ) diff --git a/docs/transports.md b/docs/transports.md index 41494c7..0e20362 100644 --- a/docs/transports.md +++ b/docs/transports.md @@ -19,6 +19,29 @@ CQRS = { } ``` +The virtual host is optional and defaults to `/` (root vhost). To specify a custom virtual host, +include it in the URL path: + +``` py3 +# Using root virtual host (default) +CQRS = { + 'transport': 'dj_cqrs.transport.RabbitMQTransport', + 'url': 'amqp://guest:guest@rabbit:5672/' +} + +# Using a custom virtual host +CQRS = { + 'transport': 'dj_cqrs.transport.RabbitMQTransport', + 'url': 'amqp://guest:guest@rabbit:5672/my_vhost' +} + +# Using nested virtual host paths +CQRS = { + 'transport': 'dj_cqrs.transport.RabbitMQTransport', + 'url': 'amqp://guest:guest@rabbit:5672/production/app1' +} +``` + !!! warning Previous versions of the `RabbitMQTransport` use the attributes `host`, diff --git a/tests/test_commands/test_dead_letters.py b/tests/test_commands/test_dead_letters.py index ff52fb6..9ac569c 100644 --- a/tests/test_commands/test_dead_letters.py +++ b/tests/test_commands/test_dead_letters.py @@ -23,7 +23,7 @@ def test_dump(capsys, mocker): mocker.patch.object( RabbitMQTransport, '_get_common_settings', - return_value=('host', 'port', mocker.MagicMock(), 'exchange'), + return_value=('host', 'port', mocker.MagicMock(), '/', 'exchange'), ) queue = mocker.MagicMock() diff --git a/tests/test_transport/test_rabbit_mq.py b/tests/test_transport/test_rabbit_mq.py index 3bbd81c..5363b77 100644 --- a/tests/test_transport/test_rabbit_mq.py +++ b/tests/test_transport/test_rabbit_mq.py @@ -68,7 +68,8 @@ def test_default_settings(): assert s[0] == 'localhost' assert s[1] == 5672 assert s[2].username == 'guest' and s[2].password == 'guest' - assert s[3] == 'cqrs' + assert s[3] == '/' + assert s[4] == 'cqrs' def test_non_default_settings(settings, caplog): @@ -78,6 +79,7 @@ def test_non_default_settings(settings, caplog): 'port': 8000, 'user': 'usr', 'password': 'pswd', + 'virtual_host': 'test', 'exchange': 'exchange', } @@ -85,7 +87,8 @@ def test_non_default_settings(settings, caplog): assert s[0] == 'rabbit' assert s[1] == 8000 assert s[2].username == 'usr' and s[2].password == 'pswd' - assert s[3] == 'exchange' + assert s[3] == 'test' + assert s[4] == 'exchange' def test_default_url_settings(settings): @@ -97,20 +100,50 @@ def test_default_url_settings(settings): assert s[0] == 'localhost' assert s[1] == 5672 assert s[2].username == 'guest' and s[2].password == 'guest' - assert s[3] == 'cqrs' + assert s[3] == '/' + assert s[4] == 'cqrs' def test_non_default_url_settings(settings): settings.CQRS = { 'transport': 'dj_cqrs.transport.rabbit_mq.RabbitMQTransport', - 'url': 'amqp://usr:pswd@rabbit:8000', + 'url': 'amqp://usr:pswd@rabbit:8000/test', 'exchange': 'exchange', } s = PublicRabbitMQTransport.get_common_settings() assert s[0] == 'rabbit' assert s[1] == 8000 assert s[2].username == 'usr' and s[2].password == 'pswd' - assert s[3] == 'exchange' + assert s[3] == 'test' + assert s[4] == 'exchange' + + +def test_root_virtual_host_url_settings(settings): + settings.CQRS = { + 'transport': 'dj_cqrs.transport.rabbit_mq.RabbitMQTransport', + 'url': 'amqp://usr:pswd@rabbit:8000/', + 'exchange': 'exchange', + } + s = PublicRabbitMQTransport.get_common_settings() + assert s[0] == 'rabbit' + assert s[1] == 8000 + assert s[2].username == 'usr' and s[2].password == 'pswd' + assert s[3] == '/' + assert s[4] == 'exchange' + + +def test_nested_virtual_host_url_settings(settings): + settings.CQRS = { + 'transport': 'dj_cqrs.transport.rabbit_mq.RabbitMQTransport', + 'url': 'amqp://usr:pswd@rabbit:8000/foo/bar', + 'exchange': 'exchange', + } + s = PublicRabbitMQTransport.get_common_settings() + assert s[0] == 'rabbit' + assert s[1] == 8000 + assert s[2].username == 'usr' and s[2].password == 'pswd' + assert s[3] == 'foo/bar' + assert s[4] == 'exchange' def test_invalid_url_settings(settings): @@ -151,6 +184,103 @@ def test_consumer_non_default_settings(settings, caplog): assert "The 'consumer_prefetch_count' setting is ignored for RabbitMQTransport." in caplog.text +def test_get_consumer_rmq_objects_passes_virtual_host(mocker): + connection = mocker.MagicMock() + channel = mocker.MagicMock() + consumer_generator = iter(()) + + connection.channel.return_value = channel + channel.consume.return_value = consumer_generator + + connection_parameters = mocker.sentinel.connection_parameters + connection_parameters_cls = mocker.patch( + 'dj_cqrs.transport.rabbit_mq.ConnectionParameters', + return_value=connection_parameters, + ) + blocking_connection = mocker.patch( + 'dj_cqrs.transport.rabbit_mq.BlockingConnection', + return_value=connection, + ) + + creds = mocker.MagicMock() + result = RabbitMQTransport._get_consumer_rmq_objects( + host='rabbit', + port=5672, + creds=creds, + virtual_host='custom-vhost', + exchange='cqrs', + queue_name='replica', + dead_letter_queue_name='dead_letter_replica', + prefetch_count=10, + ) + + connection_parameters_cls.assert_called_once_with( + host='rabbit', + port=5672, + credentials=creds, + virtual_host='custom-vhost', + ) + blocking_connection.assert_called_once_with(connection_parameters) + assert result == (connection, channel, consumer_generator) + + +def test_get_producer_rmq_objects_passes_virtual_host_for_async(mocker): + creds = mocker.MagicMock() + create_connection = mocker.patch.object( + RabbitMQTransport, + '_create_connection', + return_value=(mocker.MagicMock(), mocker.MagicMock()), + ) + + RabbitMQTransport._get_producer_rmq_objects( + host='rabbit', + port=5672, + creds=creds, + virtual_host='custom-vhost', + exchange='cqrs', + ) + + create_connection.assert_called_once_with('rabbit', 5672, creds, 'custom-vhost', 'cqrs') + + +def test_get_producer_rmq_objects_passes_virtual_host_for_sync(mocker): + RabbitMQTransport._producer_connection = None + RabbitMQTransport._producer_channel = None + + connection = mocker.MagicMock() + channel = mocker.MagicMock() + creds = mocker.MagicMock() + create_connection = mocker.patch.object( + RabbitMQTransport, + '_create_connection', + return_value=(connection, channel), + ) + + first_result = RabbitMQTransport._get_producer_rmq_objects( + host='rabbit', + port=5672, + creds=creds, + virtual_host='custom-vhost', + exchange='cqrs', + signal_type=SignalType.SYNC, + ) + second_result = RabbitMQTransport._get_producer_rmq_objects( + host='rabbit', + port=5672, + creds=creds, + virtual_host='another-vhost', + exchange='cqrs', + signal_type=SignalType.SYNC, + ) + + create_connection.assert_called_once_with('rabbit', 5672, creds, 'custom-vhost', 'cqrs') + assert first_result == (connection, channel) + assert second_result == (connection, channel) + + RabbitMQTransport._producer_connection = None + RabbitMQTransport._producer_channel = None + + @pytest.fixture def rabbit_transport(settings): settings.CQRS = {