Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions deebot_client/commands/json/network.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from deebot_client.message import (
HandlingResult,
MessageBodyDataDict,
MessageDict,
MessageDictOrJson,
)

from .common import JsonCommand, JsonCommandWithMessageHandling
Expand Down Expand Up @@ -42,7 +42,7 @@ def _handle_body_data_dict(
return HandlingResult.success()


class GetNetInfoLegacy(JsonCommand, CommandWithMessageHandling, MessageDict):
class GetNetInfoLegacy(JsonCommand, CommandWithMessageHandling, MessageDictOrJson):
"""Get network info command."""

NAME = "GetNetInfo"
Expand Down
21 changes: 9 additions & 12 deletions deebot_client/device.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
from collections.abc import Callable, Coroutine
from contextlib import suppress
from datetime import datetime
import json
from typing import TYPE_CHECKING, Any, Final

from deebot_client.events.network import NetworkInfoEvent
Expand All @@ -19,6 +18,7 @@
AvailabilityEvent,
CleanLogEvent,
CustomCommandEvent,
FirmwareEvent,
LifeSpanEvent,
PositionsEvent,
StateEvent,
Expand All @@ -34,6 +34,7 @@
if TYPE_CHECKING:
from .authentication import Authenticator
from .command import DeviceCommandResult
from .message import MessagePayloadType

_LOGGER = get_logger(__name__)
_AVAILABLE_CHECK_INTERVAL = 60
Expand Down Expand Up @@ -113,6 +114,11 @@ async def on_network(event: NetworkInfoEvent) -> None:

self.events.subscribe(NetworkInfoEvent, on_network)

async def on_firmware(event: FirmwareEvent) -> None:
self.fw_version = event.version

self.events.subscribe(FirmwareEvent, on_firmware)

async def execute_command(self, command: Command) -> dict[str, Any]:
"""Execute given command.

Expand Down Expand Up @@ -191,7 +197,7 @@ def _set_available(self, *, available: bool) -> None:
self.events.notify(AvailabilityEvent(available=available))

def _handle_message(
self, message_name: str, message_data: str | bytes | bytearray | dict[str, Any]
self, message_name: str, message_data: MessagePayloadType
) -> None:
"""Handle the given message.

Expand All @@ -205,15 +211,6 @@ def _handle_message(
_LOGGER.debug("Try to handle message %s: %s", message_name, message_data)

if message := get_message(message_name, self._device_info.static.data_type):
if isinstance(message_data, dict):
data = message_data
else:
data = json.loads(message_data)

fw_version = data.get("header", {}).get("fwVer", None)
if fw_version:
self.fw_version = fw_version

message.handle(self.events, data)
message.handle(self.events, message_data)
except Exception: # pylint: disable=broad-except
_LOGGER.exception("An exception occurred during handling message")
8 changes: 8 additions & 0 deletions deebot_client/events/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@
"Event",
"FanSpeedEvent",
"FanSpeedLevel",
"FirmwareEvent",
"MajorMapEvent",
"MapChangedEvent",
"MapSetEvent",
Expand Down Expand Up @@ -300,3 +301,10 @@ class CutDirectionEvent(Event):
"""Cut direction event representation."""

angle: int


@dataclass(frozen=True)
class FirmwareEvent(Event):
"""Firmware event."""

version: str
63 changes: 47 additions & 16 deletions deebot_client/message.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,10 @@
from dataclasses import dataclass
from enum import IntEnum, auto
import functools
import json
from typing import TYPE_CHECKING, Any, TypeVar, final

from deebot_client.events import FirmwareEvent
from deebot_client.util import verify_required_class_variables_exists

from .logging_filter import get_logger
Expand All @@ -19,6 +21,8 @@

_LOGGER = get_logger(__name__)

MessagePayloadType = str | bytes | bytearray | dict[str, Any]


class HandlingState(IntEnum):
"""Handling state enum."""
Expand Down Expand Up @@ -63,6 +67,15 @@ def wrapper(
) -> HandlingResult:
try:
response = func(cls, event_bus, data)
# This happens if for some reason someone calls super() of an ABC where handle is not implemented
if not response:
_LOGGER.error(
"Handler for message %s: %s returned no response. "
"This is a bug should not happen. Please report it.",
cls.NAME,
data,
)
return HandlingResult(HandlingState.ERROR)
if response.state == HandlingState.ANALYSE:
_LOGGER.debug("Could not handle %s message: %s", cls.NAME, data)
return HandlingResult(HandlingState.ANALYSE_LOGGED, response.args)
Expand All @@ -88,7 +101,7 @@ def __init_subclass__(cls) -> None:
@classmethod
@abstractmethod
def _handle(
cls, event_bus: EventBus, message: dict[str, Any] | str
cls, event_bus: EventBus, message: MessagePayloadType
) -> HandlingResult:
"""Handle message and notify the correct event subscribers.

Expand All @@ -98,9 +111,7 @@ def _handle(
@classmethod
@_handle_error_or_analyse
@final
def handle(
cls, event_bus: EventBus, message: dict[str, Any] | str
) -> HandlingResult:
def handle(cls, event_bus: EventBus, message: MessagePayloadType) -> HandlingResult:
"""Handle message and notify the correct event subscribers.

:return: A message response
Expand All @@ -120,28 +131,33 @@ def _handle_str(cls, event_bus: EventBus, message: str) -> HandlingResult:
"""

@classmethod
# @_handle_error_or_analyse @edenhaus will make the decorator to work again
@_handle_error_or_analyse
@final
def __handle_str(cls, event_bus: EventBus, message: str) -> HandlingResult:
return cls._handle_str(event_bus, message)

@classmethod
def _handle(
cls, event_bus: EventBus, message: dict[str, Any] | str
cls, event_bus: EventBus, message: MessagePayloadType
) -> HandlingResult:
"""Handle message and notify the correct event subscribers.

:return: A message response
"""
# This basically means an XML message
if isinstance(message, str):
return cls.__handle_str(event_bus, message)
if isinstance(message, bytearray):
data = bytes(message).decode()
elif isinstance(message, bytes):
data = message.decode()
elif isinstance(message, str):
data = message
else:
return super()._handle(event_bus, message)

return super()._handle(event_bus, message)
return cls.__handle_str(event_bus, data)


class MessageDict(Message, ABC):
"""Dict message."""
class MessageDictOrJson(Message, ABC):
"""Dict or json message."""

@classmethod
@abstractmethod
Expand All @@ -163,19 +179,34 @@ def __handle_dict(

@classmethod
def _handle(
cls, event_bus: EventBus, message: dict[str, Any] | str
cls, event_bus: EventBus, message: MessagePayloadType
) -> HandlingResult:
"""Handle message and notify the correct event subscribers.

:return: A message response
"""
if isinstance(message, dict):
return cls.__handle_dict(event_bus, message)
data = message
if not isinstance(message, dict):
try:
data = json.loads(message)
except Exception: # pylint: disable=broad-except
_LOGGER.debug(
"Could not decode message %s payload %s as JSON",
cls.NAME,
message,
)

if isinstance(data, dict):
fw_version = data.get("header", {}).get("fwVer", None)
if fw_version:
event_bus.notify(FirmwareEvent(fw_version))

return cls.__handle_dict(event_bus, data)

return super()._handle(event_bus, message)


class MessageBody(MessageDict, ABC):
class MessageBody(MessageDictOrJson, ABC):
"""Dict message with body attribute."""

@classmethod
Expand Down
31 changes: 18 additions & 13 deletions tests/commands/json/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,15 +36,18 @@ async def assert_execute_command(
assert command._args == args

# success
json = get_request_json(get_success_body())
await assert_command(command, json, None)
json, firmware_event = get_request_json(get_success_body())
await assert_command(command, json, firmware_event)

# failed
with LogCapture() as log:
body = {"code": 500, "msg": "fail"}
json = get_request_json(body)
json, firmware_event = get_request_json(body)
await assert_command(
command, json, None, command_result=CommandResult(HandlingState.FAILED)
command,
json,
firmware_event,
command_result=CommandResult(HandlingState.FAILED),
)

log.check_present(
Expand All @@ -66,24 +69,26 @@ async def assert_set_command(
event_bus = Mock(spec_set=EventBus)

# Failed to set
json_data = get_message_json(
json_data, firmware_event = get_message_json(
{
"code": 500,
"msg": "fail",
}
)
command.handle_mqtt_p2p(event_bus, json.dumps(json_data))
event_bus.notify.assert_not_called()
event_bus.notify.assert_called_once_with(firmware_event)

event_bus.reset_mock()
# Success
command.handle_mqtt_p2p(event_bus, json.dumps(get_message_json(get_success_body())))
if isinstance(expected_get_command_events, Sequence):
event_bus.notify.assert_has_calls(
[call(x) for x in expected_get_command_events]
)
assert event_bus.notify.call_count == len(expected_get_command_events)
data, firmware_event = get_message_json(get_success_body())
command.handle_mqtt_p2p(event_bus, json.dumps(data))
if not isinstance(expected_get_command_events, Sequence):
expected_events = [firmware_event, expected_get_command_events]
else:
event_bus.notify.assert_called_once_with(expected_get_command_events)
expected_events = [firmware_event, *expected_get_command_events]

event_bus.notify.assert_has_calls([call(x) for x in expected_events])
assert event_bus.notify.call_count == len(expected_events)

payload = json.dumps({"body": {"data": args}})
mqtt_command = command.create_from_mqtt(payload)
Expand Down
8 changes: 6 additions & 2 deletions tests/commands/json/test_advanced_mode.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,12 @@

@pytest.mark.parametrize("value", [False, True])
async def test_GetAdvancedMode(*, value: bool) -> None:
json = get_request_json(get_success_body({"enable": 1 if value else 0}))
await assert_command(GetAdvancedMode(), json, AdvancedModeEvent(value))
json, firmware_event = get_request_json(
get_success_body({"enable": 1 if value else 0})
)
await assert_command(
GetAdvancedMode(), json, (firmware_event, AdvancedModeEvent(value))
)


@pytest.mark.parametrize("value", [False, True])
Expand Down
4 changes: 2 additions & 2 deletions tests/commands/json/test_auto_empty.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,8 +47,8 @@
)
async def test_GetAutoEmpty(json: dict[str, Any], expected: AutoEmptyEvent) -> None:
"""Test GetAutoEmpty."""
json = get_request_json(get_success_body(json))
await assert_command(GetAutoEmpty(), json, expected)
json, firmware_event = get_request_json(get_success_body(json))
await assert_command(GetAutoEmpty(), json, (firmware_event, expected))


@pytest.mark.parametrize(
Expand Down
4 changes: 2 additions & 2 deletions tests/commands/json/test_battery.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@

@pytest.mark.parametrize("percentage", [0, 49, 100])
async def test_GetBattery(percentage: int) -> None:
json = get_request_json(
json, firmware_event = get_request_json(
get_success_body({"value": percentage, "isLow": 1 if percentage < 20 else 0})
)
await assert_command(GetBattery(), json, BatteryEvent(percentage))
await assert_command(GetBattery(), json, (firmware_event, BatteryEvent(percentage)))
8 changes: 6 additions & 2 deletions tests/commands/json/test_border_switch.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,12 @@
@pytest.mark.parametrize("value", [False, True])
async def test_GetBorderSwitch(*, value: bool) -> None:
"""Testing get border switch."""
json = get_request_json(get_success_body({"enable": 1 if value else 0}))
await assert_command(GetBorderSwitch(), json, BorderSwitchEvent(value))
json, firmware_event = get_request_json(
get_success_body({"enable": 1 if value else 0})
)
await assert_command(
GetBorderSwitch(), json, (firmware_event, BorderSwitchEvent(value))
)


@pytest.mark.parametrize("value", [False, True])
Expand Down
8 changes: 6 additions & 2 deletions tests/commands/json/test_carpet.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,12 @@

@pytest.mark.parametrize("value", [False, True])
async def test_GetCarpetAutoFanBoost(*, value: bool) -> None:
json = get_request_json(get_success_body({"enable": 1 if value else 0}))
await assert_command(GetCarpetAutoFanBoost(), json, CarpetAutoFanBoostEvent(value))
json, firmware_event = get_request_json(
get_success_body({"enable": 1 if value else 0})
)
await assert_command(
GetCarpetAutoFanBoost(), json, (firmware_event, CarpetAutoFanBoostEvent(value))
)


@pytest.mark.parametrize("value", [False, True])
Expand Down
Loading