diff --git a/elementary/messages/messaging_integrations/base_messaging_integration.py b/elementary/messages/messaging_integrations/base_messaging_integration.py index 43409a3bb..0bc118fd0 100644 --- a/elementary/messages/messaging_integrations/base_messaging_integration.py +++ b/elementary/messages/messaging_integrations/base_messaging_integration.py @@ -18,11 +18,12 @@ class MessageSendResult(BaseModel, Generic[T]): timestamp: datetime + message_format: str message_context: Optional[T] = None DestinationType = TypeVar("DestinationType") -MessageContextType = TypeVar("MessageContextType") +MessageContextType = TypeVar("MessageContextType", bound=BaseModel) class BaseMessagingIntegration(ABC, Generic[DestinationType, MessageContextType]): diff --git a/elementary/messages/messaging_integrations/empty_message_context.py b/elementary/messages/messaging_integrations/empty_message_context.py new file mode 100644 index 000000000..278615592 --- /dev/null +++ b/elementary/messages/messaging_integrations/empty_message_context.py @@ -0,0 +1,5 @@ +from pydantic import BaseModel + + +class EmptyMessageContext(BaseModel): + pass diff --git a/elementary/messages/messaging_integrations/mapped.py b/elementary/messages/messaging_integrations/mapped.py new file mode 100644 index 000000000..22e3604b7 --- /dev/null +++ b/elementary/messages/messaging_integrations/mapped.py @@ -0,0 +1,42 @@ +from typing import Dict + +from elementary.messages.message_body import MessageBody +from elementary.messages.messaging_integrations.base_messaging_integration import ( + BaseMessagingIntegration, + MessageContextType, + MessageSendResult, +) +from elementary.messages.messaging_integrations.exceptions import ( + MessagingIntegrationError, +) + + +class MappedMessagingIntegration(BaseMessagingIntegration[str, MessageContextType]): + def __init__( + self, mapping: Dict[str, BaseMessagingIntegration[None, MessageContextType]] + ): + self._mapping = mapping + + def send_message( + self, destination: str, body: MessageBody + ) -> MessageSendResult[MessageContextType]: + if destination not in self._mapping: + raise MessagingIntegrationError(f"Invalid destination: {destination}") + return self._mapping[destination].send_message(None, body) + + def supports_reply(self) -> bool: + return all( + integration.supports_reply() for integration in self._mapping.values() + ) + + def supports_actions(self) -> bool: + return all( + integration.supports_actions() for integration in self._mapping.values() + ) + + def reply_to_message( + self, destination: str, message_context: MessageContextType, body: MessageBody + ) -> MessageSendResult[MessageContextType]: + if destination not in self._mapping: + raise MessagingIntegrationError(f"Invalid destination: {destination}") + return self._mapping[destination].reply_to_message(None, message_context, body) diff --git a/elementary/messages/messaging_integrations/slack_web.py b/elementary/messages/messaging_integrations/slack_web.py index b3496e3bb..721b43c2c 100644 --- a/elementary/messages/messaging_integrations/slack_web.py +++ b/elementary/messages/messaging_integrations/slack_web.py @@ -100,6 +100,7 @@ def _send_message( id=response["ts"], channel=response["channel"] ), timestamp=response["ts"], + message_format="block_kit", ) def _handle_send_err(self, err: SlackApiError, channel_name: str): diff --git a/elementary/messages/messaging_integrations/slack_webhook.py b/elementary/messages/messaging_integrations/slack_webhook.py index 4cc6bad43..fdcbd59b9 100644 --- a/elementary/messages/messaging_integrations/slack_webhook.py +++ b/elementary/messages/messaging_integrations/slack_webhook.py @@ -15,6 +15,9 @@ BaseMessagingIntegration, MessageSendResult, ) +from elementary.messages.messaging_integrations.empty_message_context import ( + EmptyMessageContext, +) from elementary.messages.messaging_integrations.exceptions import ( MessagingIntegrationError, ) @@ -23,7 +26,9 @@ ONE_SECOND = 1 -class SlackWebhookMessagingIntegration(BaseMessagingIntegration[None, None]): +class SlackWebhookMessagingIntegration( + BaseMessagingIntegration[None, EmptyMessageContext] +): def __init__( self, client: WebhookClient, tracking: Optional[Tracking] = None ) -> None: @@ -52,12 +57,13 @@ def _send_message(self, formatted_message: FormattedBlockKitMessage) -> None: def send_message( self, destination: None, body: MessageBody - ) -> MessageSendResult[None]: + ) -> MessageSendResult[EmptyMessageContext]: formatted_message = format_block_kit(body) self._send_message(formatted_message) return MessageSendResult( - message_context=destination, + message_context=EmptyMessageContext(), timestamp=datetime.utcnow(), + message_format="block_kit", ) def supports_reply(self) -> bool: diff --git a/elementary/messages/messaging_integrations/teams_webhook.py b/elementary/messages/messaging_integrations/teams_webhook.py index b15bdd7af..5a644c667 100644 --- a/elementary/messages/messaging_integrations/teams_webhook.py +++ b/elementary/messages/messaging_integrations/teams_webhook.py @@ -10,6 +10,9 @@ BaseMessagingIntegration, MessageSendResult, ) +from elementary.messages.messaging_integrations.empty_message_context import ( + EmptyMessageContext, +) from elementary.messages.messaging_integrations.exceptions import ( MessagingIntegrationError, ) @@ -44,21 +47,24 @@ def send_adaptive_card(webhook_url: str, card: dict) -> requests.Response: return response -class TeamsWebhookMessagingIntegration(BaseMessagingIntegration[Channel, Channel]): +class TeamsWebhookMessagingIntegration( + BaseMessagingIntegration[None, EmptyMessageContext] +): def __init__(self, url: str) -> None: self.url = url def send_message( self, - destination: Channel, + destination: None, body: MessageBody, - ) -> MessageSendResult[Channel]: + ) -> MessageSendResult[EmptyMessageContext]: card = format_adaptive_card(body) try: send_adaptive_card(self.url, card) return MessageSendResult( - message_context=destination, + message_context=EmptyMessageContext(), timestamp=datetime.utcnow(), + message_format="adaptive_cards", ) except requests.RequestException as e: raise MessagingIntegrationError( diff --git a/tests/unit/messages/messaging_integrations/test_mapped.py b/tests/unit/messages/messaging_integrations/test_mapped.py new file mode 100644 index 000000000..a4375aa21 --- /dev/null +++ b/tests/unit/messages/messaging_integrations/test_mapped.py @@ -0,0 +1,180 @@ +from datetime import datetime +from typing import Dict, List +from unittest.mock import MagicMock + +import pytest +from pydantic import BaseModel + +from elementary.messages.blocks import HeaderBlock +from elementary.messages.message_body import MessageBody +from elementary.messages.messaging_integrations.base_messaging_integration import ( + BaseMessagingIntegration, + MessageSendResult, +) +from elementary.messages.messaging_integrations.exceptions import ( + MessagingIntegrationError, +) +from elementary.messages.messaging_integrations.mapped import MappedMessagingIntegration + + +class MockMessageContext(BaseModel): + id: str + + +class MockMessagingIntegration(BaseMessagingIntegration[None, MockMessageContext]): + def __init__(self, supports_reply: bool = True, supports_actions: bool = False): + self.supports_reply_value = supports_reply + self.supports_actions_value = supports_actions + self.send_message_mock = MagicMock() + self.send_message_mock.return_value = MessageSendResult( + timestamp=datetime.now(), + message_format="test_format", + message_context=MockMessageContext(id="test_id"), + ) + self.reply_to_message_mock = MagicMock() + self.reply_to_message_mock.return_value = MessageSendResult( + timestamp=datetime.now(), + message_format="test_format", + message_context=MockMessageContext(id="test_id"), + ) + + def send_message( + self, destination: None, body: MessageBody + ) -> MessageSendResult[MockMessageContext]: + return self.send_message_mock(destination, body) + + def supports_reply(self) -> bool: + return self.supports_reply_value + + def supports_actions(self) -> bool: + return self.supports_actions_value + + def reply_to_message( + self, + destination: None, + message_context: MockMessageContext, + body: MessageBody, + ) -> MessageSendResult[MockMessageContext]: + return self.reply_to_message_mock(destination, message_context, body) + + +@pytest.fixture +def mock_integration() -> MockMessagingIntegration: + return MockMessagingIntegration() + + +@pytest.fixture +def mapped_integration( + mock_integration: MockMessagingIntegration, +) -> MappedMessagingIntegration: + return MappedMessagingIntegration({"test_destination": mock_integration}) + + +def test_send_message_success( + mapped_integration: MappedMessagingIntegration, + mock_integration: MockMessagingIntegration, +) -> None: + destination = "test_destination" + body = MessageBody(blocks=[HeaderBlock(text="test message")]) + expected_result: MessageSendResult[MockMessageContext] = MessageSendResult( + timestamp=datetime.now(), + message_format="test_format", + message_context=None, + ) + mock_integration.send_message_mock.return_value = expected_result + + result = mapped_integration.send_message(destination, body) + + assert result == expected_result + mock_integration.send_message_mock.assert_called_once_with(None, body) + + +def test_send_message_invalid_destination( + mapped_integration: MappedMessagingIntegration, +) -> None: + destination = "invalid_destination" + body = MessageBody(blocks=[HeaderBlock(text="test message")]) + + with pytest.raises(MessagingIntegrationError) as exc_info: + mapped_integration.send_message(destination, body) + assert str(exc_info.value) == "Invalid destination: invalid_destination" + + +@pytest.mark.parametrize( + "integrations_support_reply,expected_support", + [ + ([True, True], True), + ([True, False], False), + ([False, True], False), + ([False, False], False), + ], +) +def test_supports_reply( + integrations_support_reply: List[bool], expected_support: bool +) -> None: + integrations: Dict[str, BaseMessagingIntegration[None, MockMessageContext]] = { + f"dest_{i}": MockMessagingIntegration(supports_reply=supports_reply) + for i, supports_reply in enumerate(integrations_support_reply) + } + mapped_integration = MappedMessagingIntegration(integrations) + + result = mapped_integration.supports_reply() + + assert result == expected_support + + +@pytest.mark.parametrize( + "integrations_support_actions,expected_support", + [ + ([True, True], True), + ([True, False], False), + ([False, True], False), + ([False, False], False), + ], +) +def test_supports_actions( + integrations_support_actions: List[bool], expected_support: bool +) -> None: + integrations: Dict[str, BaseMessagingIntegration[None, MockMessageContext]] = { + f"dest_{i}": MockMessagingIntegration(supports_actions=supports_actions) + for i, supports_actions in enumerate(integrations_support_actions) + } + mapped_integration = MappedMessagingIntegration(integrations) + + result = mapped_integration.supports_actions() + + assert result == expected_support + + +def test_reply_to_message_success( + mapped_integration: MappedMessagingIntegration, + mock_integration: MockMessagingIntegration, +) -> None: + destination = "test_destination" + message_context = MagicMock() + body = MessageBody(blocks=[HeaderBlock(text="test reply")]) + expected_result: MessageSendResult[MockMessageContext] = MessageSendResult( + timestamp=datetime.now(), + message_format="test_format", + message_context=message_context, + ) + mock_integration.reply_to_message_mock.return_value = expected_result + + result = mapped_integration.reply_to_message(destination, message_context, body) + + assert result == expected_result + mock_integration.reply_to_message_mock.assert_called_once_with( + None, message_context, body + ) + + +def test_reply_to_message_invalid_destination( + mapped_integration: MappedMessagingIntegration, +) -> None: + destination = "invalid_destination" + message_context = MagicMock() + body = MessageBody(blocks=[HeaderBlock(text="test reply")]) + + with pytest.raises(MessagingIntegrationError) as exc_info: + mapped_integration.reply_to_message(destination, message_context, body) + assert str(exc_info.value) == "Invalid destination: invalid_destination"