diff --git a/CHANGELOG.md b/CHANGELOG.md index ac0f5718..cfbf6b7e 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -9,6 +9,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Added +* Added `retain` option to `publish()` method for transports that allow retaining messages for new subscribers. +* Added support for MQTT over secure WebSockets using `transport="websockets"` and `tls_set()`. + ### Changed ### Removed @@ -236,4 +239,3 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Changed ### Removed - diff --git a/pyproject.toml b/pyproject.toml index 4b1a2c87..f658d2c6 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -65,7 +65,6 @@ select = ["E", "F", "I"] [tool.ruff.lint.per-file-ignores] "__init__.py" = ["I001"] -"tests/*" = ["I001"] "tasks.py" = ["I001"] [tool.pytest.ini_options] diff --git a/src/compas_eve/core.py b/src/compas_eve/core.py index 47d6445f..20982015 100644 --- a/src/compas_eve/core.py +++ b/src/compas_eve/core.py @@ -5,7 +5,6 @@ from typing import Type from typing import Union - DEFAULT_TRANSPORT = None @@ -59,7 +58,7 @@ def id_counter(self) -> int: self._id_counter += 1 return self._id_counter - def publish(self, topic: "Topic", message: Union["Message", dict]) -> None: + def publish(self, topic: "Topic", message: Union["Message", dict], **options: Any) -> None: pass def subscribe(self, topic: "Topic", callback: Callable) -> Optional[str]: @@ -173,19 +172,22 @@ def message_published(self, message: Union[Message, dict]) -> None: """Handler called when a message has been published.""" pass - def publish(self, message: Union[Message, dict]) -> None: + def publish(self, message: Union[Message, dict], **options: Any) -> None: """Publish a message to the topic. Parameters ---------- message The message to publish. + **options + Transport-specific options passed through to the underlying transport. + For example, ``retain=True`` on MQTT and InMemory transports. """ # TODO: check if message type matches self.topic.message_type declared if not self.is_advertised: self.advertise() - self.transport.publish(self.topic, message) + self.transport.publish(self.topic, message, **options) self.message_published(message) def advertise(self) -> None: diff --git a/src/compas_eve/memory/__init__.py b/src/compas_eve/memory/__init__.py index 58b070c4..381c2e36 100644 --- a/src/compas_eve/memory/__init__.py +++ b/src/compas_eve/memory/__init__.py @@ -24,12 +24,13 @@ class InMemoryTransport(Transport, EventEmitterMixin): def __init__(self, codec: Optional[MessageCodec] = None, *args, **kwargs): super(InMemoryTransport, self).__init__(codec=codec, *args, **kwargs) self._local_callbacks = {} + self._retained = {} def on_ready(self, callback: Callable): """In-memory transport is always ready, it will immediately trigger the callback.""" callback() - def publish(self, topic: Topic, message: Message): + def publish(self, topic: Topic, message: Message, **options): """Publish a message to a topic. Parameters @@ -38,12 +39,20 @@ def publish(self, topic: Topic, message: Message): Instance of the topic to publish to. message Instance of the message to publish. + retain : bool, optional + If True, the last message on this topic is stored and delivered + immediately to any new subscriber. Defaults to False. """ + retain = options.pop("retain", False) + if options: + raise TypeError("publish() got unexpected options for InMemoryTransport: {}".format(", ".join(options))) event_key = "event:{}".format(topic.name) def _callback(**kwargs): encoded_message = self.codec.encode(message) encoded_message_bytes = encoded_message if isinstance(encoded_message, bytes) else encoded_message.encode("utf-8") + if retain: + self._retained[topic.name] = encoded_message_bytes self.emit(event_key, encoded_message_bytes) self.on_ready(_callback) @@ -75,6 +84,8 @@ def _local_callback(msg): def _callback(**kwargs): self.on(event_key, _local_callback) + if topic.name in self._retained: + _local_callback(self._retained[topic.name]) self._local_callbacks[subscribe_id] = _local_callback diff --git a/src/compas_eve/mqtt/mqtt_paho.py b/src/compas_eve/mqtt/mqtt_paho.py index 1aac3292..5bcbe708 100644 --- a/src/compas_eve/mqtt/mqtt_paho.py +++ b/src/compas_eve/mqtt/mqtt_paho.py @@ -1,5 +1,7 @@ import uuid +from typing import Any from typing import Callable +from typing import Dict from typing import Optional import paho.mqtt.client as mqtt @@ -30,12 +32,30 @@ class MqttTransport(Transport, EventEmitterMixin): MQTT broker port, defaults to `1883`. client_id Client ID for the MQTT connection. If not provided, a unique ID will be generated. + transport + Paho MQTT transport to use. Defaults to `"tcp"`. Use `"websockets"` for MQTT over WebSockets. + tls + If True, enables TLS by calling `client.tls_set()` before connecting. + tls_options + Optional keyword arguments for `client.tls_set()`, e.g. `ca_certs`, `certfile`, + `keyfile`, `cert_reqs`, `tls_version`, or `ciphers`. Providing this also enables TLS. codec The codec to use for encoding and decoding messages. If not provided, defaults to [JsonMessageCodec][compas_eve.codecs.JsonMessageCodec]. """ - def __init__(self, host: str, port: int = 1883, client_id: Optional[str] = None, codec: Optional[MessageCodec] = None, *args, **kwargs): + def __init__( + self, + host: str, + port: int = 1883, + client_id: Optional[str] = None, + codec: Optional[MessageCodec] = None, + transport: str = "tcp", + tls: bool = False, + tls_options: Optional[Dict[str, Any]] = None, + *args, + **kwargs, + ): super(MqttTransport, self).__init__(codec=codec, *args, **kwargs) self.host = host self.port = port @@ -45,10 +65,12 @@ def __init__(self, host: str, port: int = 1883, client_id: Optional[str] = None, if client_id is None: client_id = "compas_eve_{}".format(uuid.uuid4().hex[:8]) if PAHO_MQTT_V2_AVAILABLE: - self.client = mqtt.Client(client_id=client_id, callback_api_version=CallbackAPIVersion.VERSION1) + self.client = mqtt.Client(client_id=client_id, callback_api_version=CallbackAPIVersion.VERSION1, transport=transport) else: - self.client = mqtt.Client(client_id=client_id) + self.client = mqtt.Client(client_id=client_id, transport=transport) self.client.on_connect = self._on_connect + if tls or tls_options is not None: + self.client.tls_set(**(tls_options or {})) self.client.connect(self.host, self.port) self.client.loop_start() @@ -73,7 +95,7 @@ def on_ready(self, callback: Callable): else: self.once("ready", callback) - def publish(self, topic: Topic, message: Message): + def publish(self, topic: Topic, message: Message, **options): """Publish a message to a topic. Parameters @@ -82,11 +104,17 @@ def publish(self, topic: Topic, message: Message): Instance of the topic to publish to. message Instance of the message to publish. + retain : bool, optional + If True, the broker retains the last message on this topic and + delivers it immediately to any new subscriber. Defaults to False. """ + retain = options.pop("retain", False) + if options: + raise TypeError("publish() got unexpected options for MqttTransport: {}".format(", ".join(options))) def _callback(**kwargs): encoded_message = self.codec.encode(message) - self.client.publish(topic.name, encoded_message) + self.client.publish(topic.name, encoded_message, retain=retain) self.on_ready(_callback) diff --git a/src/compas_eve/zenoh/zenoh_transport.py b/src/compas_eve/zenoh/zenoh_transport.py index 288856e8..2de1d0ed 100644 --- a/src/compas_eve/zenoh/zenoh_transport.py +++ b/src/compas_eve/zenoh/zenoh_transport.py @@ -64,7 +64,7 @@ def on_ready(self, callback: Callable) -> None: else: self.once("ready", callback) - def publish(self, topic: Topic, message: Message) -> None: + def publish(self, topic: Topic, message: Message, **options: Any) -> None: """Publish a message to a topic. Parameters @@ -74,13 +74,16 @@ def publish(self, topic: Topic, message: Message) -> None: message Instance of the message to publish. """ + if options: + raise TypeError("publish() got unexpected options for ZenohTransport: {}".format(", ".join(options))) def _callback(**kwargs: Any) -> None: - if self._get_topic_name(topic) not in self._publishers: - self._publishers[self._get_topic_name(topic)] = self.session.declare_publisher(self._get_topic_name(topic)) + topic_name = self._get_topic_name(topic) + if topic_name not in self._publishers: + self._publishers[topic_name] = self.session.declare_publisher(topic_name) encoded_message = self.codec.encode(message) - self._publishers[self._get_topic_name(topic)].put(encoded_message) + self._publishers[topic_name].put(encoded_message) self.on_ready(_callback) @@ -111,8 +114,9 @@ def _zenoh_handler(sample: Any) -> None: self.emit(event_key, message_obj) def _subscribe_callback(**kwargs: Any) -> None: - if self._get_topic_name(topic) not in self._subscribers: - self._subscribers[self._get_topic_name(topic)] = self.session.declare_subscriber(self._get_topic_name(topic), _zenoh_handler) + topic_name = self._get_topic_name(topic) + if topic_name not in self._subscribers: + self._subscribers[topic_name] = self.session.declare_subscriber(topic_name, _zenoh_handler) self.on(event_key, _local_callback) diff --git a/tests/integration/test_transports.py b/tests/integration/test_transports.py index 4f574069..1b5fb0c7 100644 --- a/tests/integration/test_transports.py +++ b/tests/integration/test_transports.py @@ -20,6 +20,13 @@ HOST = "localhost" +@pytest.fixture +def mqtt_tx(): + tx = MqttTransport(HOST) + yield tx + tx.close() + + @pytest.fixture(params=["mqtt", "zenoh"]) def tx(request): if request.param == "mqtt": @@ -217,3 +224,32 @@ def callback(msg): assert received, "Message not received" assert result["value"].name == "Jazz" assert result["value"]["name"] == "Jazz", "Messages should be accessible as dict" + + +def test_mqtt_retain_delivers_to_late_subscriber(mqtt_tx): + topic = Topic("/messages_compas_eve_test/test_retain/", Message) + + pub = Publisher(topic, transport=mqtt_tx) + pub.publish(Message(value=42), retain=True) + time.sleep(0.2) + + result = dict(value=None, event=Event()) + + def callback(msg): + result["value"] = msg.value + result["event"].set() + + Subscriber(topic, callback, transport=mqtt_tx).subscribe() + + received = result["event"].wait(timeout=3) + assert received, "Retained message not delivered to late subscriber" + assert result["value"] == 42 + + # Clean up: publish empty retained message to clear broker state + pub.publish(Message(), retain=True) + + +def test_mqtt_unknown_option_raises(mqtt_tx): + topic = Topic("/messages_compas_eve_test/test_bad_option/", Message) + with pytest.raises(TypeError): + Publisher(topic, transport=mqtt_tx).publish(Message(value=1), unknown_flag=True) diff --git a/tests/unit/test_codecs.py b/tests/unit/test_codecs.py index 0faa8c8e..7c570f08 100644 --- a/tests/unit/test_codecs.py +++ b/tests/unit/test_codecs.py @@ -1,4 +1,5 @@ from compas.geometry import Frame + from compas_eve import Message from compas_eve.codecs import JsonMessageCodec from compas_eve.codecs import ProtobufMessageCodec diff --git a/tests/unit/test_core.py b/tests/unit/test_core.py index df87139c..d361a8ad 100644 --- a/tests/unit/test_core.py +++ b/tests/unit/test_core.py @@ -1,5 +1,7 @@ from threading import Event +import pytest + from compas_eve import InMemoryTransport from compas_eve import Message from compas_eve import Publisher @@ -104,3 +106,64 @@ def callback(msg): def test_message_str(): msg = Message(a=3) assert str(msg) == "{'a': 3}" + + +def test_retain_delivers_to_late_subscriber(): + tx = InMemoryTransport() + topic = Topic("/messages_compas_eve_test/retain/", Message) + + Publisher(topic, transport=tx).publish(Message(value=42), retain=True) + + result = dict(value=None, event=Event()) + + def callback(msg): + result["value"] = msg.value + result["event"].set() + + Subscriber(topic, callback, transport=tx).subscribe() + + received = result["event"].wait(timeout=1) + assert received, "Retained message not delivered to late subscriber" + assert result["value"] == 42 + + +def test_retain_last_message_wins(): + tx = InMemoryTransport() + topic = Topic("/messages_compas_eve_test/retain_last/", Message) + pub = Publisher(topic, transport=tx) + + pub.publish(Message(value=1), retain=True) + pub.publish(Message(value=2), retain=True) + + result = dict(value=None, event=Event()) + + def callback(msg): + result["value"] = msg.value + result["event"].set() + + Subscriber(topic, callback, transport=tx).subscribe() + + received = result["event"].wait(timeout=1) + assert received, "Retained message not delivered" + assert result["value"] == 2 + + +def test_no_retain_does_not_deliver_to_late_subscriber(): + tx = InMemoryTransport() + topic = Topic("/messages_compas_eve_test/no_retain/", Message) + + Publisher(topic, transport=tx).publish(Message(value=42)) + + event = Event() + Subscriber(topic, lambda m: event.set(), transport=tx).subscribe() + + received = event.wait(timeout=0.2) + assert not received, "Non-retained message should not be delivered to late subscriber" + + +def test_unknown_option_raises(): + tx = InMemoryTransport() + topic = Topic("/messages_compas_eve_test/bad_option/", Message) + + with pytest.raises(TypeError): + Publisher(topic, transport=tx).publish(Message(value=1), unknown_flag=True) diff --git a/tests/unit/test_mqtt_paho_compatibility.py b/tests/unit/test_mqtt_paho_compatibility.py index a6c8103e..e0591b69 100644 --- a/tests/unit/test_mqtt_paho_compatibility.py +++ b/tests/unit/test_mqtt_paho_compatibility.py @@ -1,6 +1,11 @@ +from unittest.mock import Mock +from unittest.mock import call +from unittest.mock import patch + import pytest -from unittest.mock import Mock, patch -from compas_eve.mqtt.mqtt_paho import MqttTransport, PAHO_MQTT_V2_AVAILABLE + +from compas_eve.mqtt import MqttTransport +from compas_eve.mqtt.mqtt_paho import PAHO_MQTT_V2_AVAILABLE def test_paho_mqtt_v1_compatibility(): @@ -11,11 +16,12 @@ def test_paho_mqtt_v1_compatibility(): # This should work as if paho-mqtt 1.x is installed transport = MqttTransport("localhost") - # Should have called mqtt.Client() with client_id parameter only (no callback_api_version) + # Should have called mqtt.Client() without callback_api_version mock_client_class.assert_called_once() call_args = mock_client_class.call_args assert "client_id" in call_args.kwargs assert call_args.kwargs["client_id"].startswith("compas_eve_") + assert call_args.kwargs["transport"] == "tcp" assert "callback_api_version" not in call_args.kwargs assert transport.client == mock_client @@ -38,6 +44,41 @@ def test_paho_mqtt_v2_compatibility(): call_args = mock_client_class.call_args assert "client_id" in call_args.kwargs assert call_args.kwargs["client_id"].startswith("compas_eve_") + assert call_args.kwargs["transport"] == "tcp" assert "callback_api_version" in call_args.kwargs assert call_args.kwargs["callback_api_version"] == CallbackAPIVersion.VERSION1 assert transport.client == mock_client + + +def test_mqtt_websockets_transport(): + with patch("compas_eve.mqtt.mqtt_paho.PAHO_MQTT_V2_AVAILABLE", False), patch("paho.mqtt.client.Client") as mock_client_class: + mock_client = Mock() + mock_client_class.return_value = mock_client + + MqttTransport("localhost", port=443, transport="websockets") + + call_args = mock_client_class.call_args + assert call_args.kwargs["transport"] == "websockets" + mock_client.tls_set.assert_not_called() + mock_client.connect.assert_called_once_with("localhost", 443) + + +def test_mqtt_tls_enabled_before_connect(): + with patch("compas_eve.mqtt.mqtt_paho.PAHO_MQTT_V2_AVAILABLE", False), patch("paho.mqtt.client.Client") as mock_client_class: + mock_client = Mock() + mock_client_class.return_value = mock_client + + MqttTransport("localhost", port=443, transport="websockets", tls=True) + + mock_client.tls_set.assert_called_once_with() + assert mock_client.method_calls[:2] == [call.tls_set(), call.connect("localhost", 443)] + + +def test_mqtt_tls_options_enable_tls(): + with patch("compas_eve.mqtt.mqtt_paho.PAHO_MQTT_V2_AVAILABLE", False), patch("paho.mqtt.client.Client") as mock_client_class: + mock_client = Mock() + mock_client_class.return_value = mock_client + + MqttTransport("localhost", tls_options={"ca_certs": "/tmp/ca.pem"}) + + mock_client.tls_set.assert_called_once_with(ca_certs="/tmp/ca.pem")