diff --git a/Cargo.toml b/Cargo.toml index 346e0d86c..c3be973bb 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -17,6 +17,7 @@ crc32fast = "1.4.2" image = { version = "0.25.5", default-features = false } liblzma = "0.4.0" log = "0.4.26" +lz4_flex = "0.13" once_cell = "1.20.3" png = "0.18.0" pyo3 = "0.28.0" diff --git a/deebot_client/authentication.py b/deebot_client/authentication.py index b8f8cb127..4601d240e 100644 --- a/deebot_client/authentication.py +++ b/deebot_client/authentication.py @@ -3,11 +3,12 @@ from __future__ import annotations import asyncio -from dataclasses import dataclass +from collections.abc import Mapping +from dataclasses import dataclass, fields, is_dataclass from http import HTTPStatus import time from typing import TYPE_CHECKING, Any -from urllib.parse import urljoin +from urllib.parse import urljoin, urlparse from aiohttp import ClientResponseError, ClientSession, ClientTimeout, hdrs @@ -25,7 +26,11 @@ from .util.countries import get_ecovacs_country if TYPE_CHECKING: - from collections.abc import Callable, Coroutine, Mapping + from collections.abc import Callable, Coroutine + + from .models import ApiDeviceInfo, StaticDeviceInfo + from .ngiot_client import NgiotClient + from .sst_authentication import SstAuthenticator _LOGGER = get_logger(__name__) @@ -44,6 +49,8 @@ "deviceType": "1", } MAX_RETRIES = 3 +_NGIOT_BASE_URL_TEMPLATE = "https://api-base.dc-{region}.ww.ecouser.net" +_NGIOT_COMMAND_MODULE_PREFIX = "deebot_client.commands.ngiot" @dataclass(frozen=True, kw_only=True) @@ -58,6 +65,21 @@ class RestConfiguration: auth_code_url: str +@dataclass(frozen=True, kw_only=True) +class NgiotConfiguration: + """Optional overrides and defaults for NGIOT-backed devices.""" + + base_url: str | None = None + region: str | None = None + user_agent: str = "okhttp/4.9.1" + channel: str = "Android" + protocol_version: str = "0.0.22" + timezone_name: str = "UTC" + timezone_offset_minutes: int = 0 + requested_ttl: int = 600 + refresh_skew: int = 60 + + def create_rest_config( session: ClientSession, *, @@ -127,9 +149,6 @@ async def login(self) -> Credentials: user_id = login_token_resp["userId"] user_access_token = login_token_resp["token"] - # last is validity in milliseconds. Usually 7 days - # we set the expiry at 99% of the validity - # 604800 = 7 days expires_at = int( time.time() + int(login_token_resp.get("last", 604800)) / 1000 * 0.99 ) @@ -149,11 +168,9 @@ async def __do_auth_response( ) as res: res.raise_for_status() - # ecovacs returns a json but content_type header is set to text content_type = res.headers.get(hdrs.CONTENT_TYPE, "").lower() json = await res.json(content_type=content_type) _LOGGER.debug("got %s", json) - # TODO better error handling if json["code"] == "0000": data: dict[str, Any] = json["data"] return data @@ -246,7 +263,6 @@ async def __call_login_by_it_token( if resp["result"] == "ok": return resp if resp["result"] == "fail" and resp["error"] == "set token error.": - # If it is a set token error try again _LOGGER.warning("loginByItToken set token error, attempt %d/3", i + 2) continue @@ -348,6 +364,7 @@ def __init__( account_id: str, password_hash: str, ) -> None: + self._config = config self._auth_client = _AuthClient( config, account_id, @@ -361,6 +378,122 @@ def __init__( self._credentials: Credentials | None = None self._refresh_handle: asyncio.TimerHandle | None = None self._tasks: set[asyncio.Future[Any]] = set() + self._ngiot_config = NgiotConfiguration() + self._ngiot_base_url: str | None = None + self.sst_authenticator: SstAuthenticator | None = None + self.ngiot_client: NgiotClient | None = None + + def configure_ngiot( + self, + *, + base_url: str | None = None, + region: str | None = None, + user_agent: str = "okhttp/4.9.1", + channel: str = "Android", + protocol_version: str = "0.0.22", + timezone_name: str = "UTC", + timezone_offset_minutes: int = 0, + requested_ttl: int = 600, + refresh_skew: int = 60, + ) -> None: + """Store NGIOT defaults and optional region/base-url overrides. + + ``base_url`` wins over ``region``. If neither is configured, the + runtime derives the SST endpoint from the device ``service.mqs`` host. + """ + + normalized_base_url = ( + self._normalize_base_url(base_url) if base_url is not None else None + ) + normalized_region = ( + self._normalize_region(region) if region is not None else None + ) + self._ngiot_config = NgiotConfiguration( + base_url=normalized_base_url, + region=normalized_region, + user_agent=user_agent, + channel=channel, + protocol_version=protocol_version, + timezone_name=timezone_name, + timezone_offset_minutes=timezone_offset_minutes, + requested_ttl=requested_ttl, + refresh_skew=refresh_skew, + ) + + def attach_ngiot( + self, + *, + base_url: str | None = None, + region: str | None = None, + user_agent: str = "okhttp/4.9.1", + channel: str = "Android", + protocol_version: str = "0.0.22", + timezone_name: str = "UTC", + timezone_offset_minutes: int = 0, + requested_ttl: int = 600, + refresh_skew: int = 60, + ) -> None: + """Attach NGIOT helpers immediately using an explicit base URL or region.""" + + self.configure_ngiot( + base_url=base_url, + region=region, + user_agent=user_agent, + channel=channel, + protocol_version=protocol_version, + timezone_name=timezone_name, + timezone_offset_minutes=timezone_offset_minutes, + requested_ttl=requested_ttl, + refresh_skew=refresh_skew, + ) + resolved_base_url = self._resolve_configured_ngiot_base_url() + if resolved_base_url is None: + msg = ( + "attach_ngiot() requires base_url or region. " + "For automatic per-device attachment, call configure_ngiot() and let " + "ApiClient.get_devices() bootstrap NGIOT for matching hardware classes." + ) + raise ApiError(msg) + + if self.ngiot_client is not None: + if self._ngiot_base_url == resolved_base_url: + return + msg = ( + "NGIOT transport already attached with a different base URL. " + "Use configure_ngiot() plus automatic device bootstrap, or call teardown() first." + ) + raise ApiError(msg) + + self._create_ngiot_stack(resolved_base_url) + + async def ensure_ngiot_for_device( + self, + device_info: ApiDeviceInfo, + static_device_info: StaticDeviceInfo, + ) -> bool: + """Attach NGIOT transport if the hardware profile uses NGIOT commands.""" + + if not self._uses_ngiot(static_device_info): + return False + + desired_base_url = self._resolve_ngiot_base_url(device_info) + if self.ngiot_client is not None and self._ngiot_base_url == desired_base_url: + return True + + if self.sst_authenticator is not None: + if self._ngiot_base_url != desired_base_url: + _LOGGER.info( + "Re-attaching NGIOT transport with region/base URL %s for %s", + desired_base_url, + device_info["class"], + ) + await self.sst_authenticator.teardown() + + self.sst_authenticator = None + self.ngiot_client = None + self._ngiot_base_url = None + self._create_ngiot_stack(desired_base_url) + return True async def authenticate(self, *, force: bool = False) -> Credentials: """Authenticate on ecovacs servers.""" @@ -411,6 +544,11 @@ async def post_authenticated( async def teardown(self) -> None: """Teardown authenticator.""" self._cancel_refresh_task() + if self.sst_authenticator is not None: + await self.sst_authenticator.teardown() + self.sst_authenticator = None + self.ngiot_client = None + self._ngiot_base_url = None await cancel(self._tasks) def _cancel_refresh_task(self) -> None: @@ -418,7 +556,6 @@ def _cancel_refresh_task(self) -> None: self._refresh_handle.cancel() def _create_refresh_task(self, credentials: Credentials) -> None: - # refresh at 99% of validity def refresh() -> None: _LOGGER.debug("Refresh token") @@ -432,5 +569,120 @@ async def async_refresh() -> None: self._refresh_handle = None validity = (credentials.expires_at - time.time()) * 0.99 - self._refresh_handle = asyncio.get_event_loop().call_later(validity, refresh) + + def _create_ngiot_stack(self, base_url: str) -> None: + from .ngiot_client import NgiotClient + from .sst_authentication import SstAuthenticator + + normalized_base_url = self._normalize_base_url(base_url) + self.sst_authenticator = SstAuthenticator( + self._config.session, + self, + base_url=normalized_base_url, + requested_ttl=self._ngiot_config.requested_ttl, + refresh_skew=self._ngiot_config.refresh_skew, + ) + # SST tokens are minted against api-base, but endpoint-control requests + # must keep using the device service.mqs host (typically api-ngiot). + self.ngiot_client = NgiotClient( + self._config.session, + self.sst_authenticator, + user_agent=self._ngiot_config.user_agent, + channel=self._ngiot_config.channel, + protocol_version=self._ngiot_config.protocol_version, + timezone_name=self._ngiot_config.timezone_name, + timezone_offset_minutes=self._ngiot_config.timezone_offset_minutes, + ) + self._ngiot_base_url = normalized_base_url + + def _resolve_configured_ngiot_base_url(self) -> str | None: + if self._ngiot_config.base_url is not None: + return self._ngiot_config.base_url + if self._ngiot_config.region is not None: + return self._format_ngiot_base_url(self._ngiot_config.region) + return None + + def _resolve_ngiot_base_url(self, device_info: ApiDeviceInfo) -> str: + configured_base_url = self._resolve_configured_ngiot_base_url() + if configured_base_url is not None: + return configured_base_url + + service = device_info.get("service") + if isinstance(service, Mapping): + mqs_host = service.get("mqs") + if isinstance(mqs_host, str) and mqs_host: + return self._derive_ngiot_base_url_from_mqs(mqs_host) + + msg = ( + f'Could not resolve NGIOT base URL for device class "{device_info["class"]}". ' + "Configure an explicit region or base_url before device bootstrap." + ) + raise ApiError(msg) + + @classmethod + def _uses_ngiot(cls, static_device_info: StaticDeviceInfo) -> bool: + return cls._object_uses_ngiot(getattr(static_device_info, "capabilities", None)) + + @classmethod + def _object_uses_ngiot(cls, value: object) -> bool: + if value is None: + return False + if isinstance(value, type): + return cls._is_ngiot_module(value.__module__) + if cls._is_ngiot_module(value.__class__.__module__): + return True + if isinstance(value, Mapping): + return any( + cls._object_uses_ngiot(key) or cls._object_uses_ngiot(item) + for key, item in value.items() + ) + if isinstance(value, (list, tuple, set, frozenset)): + return any(cls._object_uses_ngiot(item) for item in value) + if is_dataclass(value): + return any( + cls._object_uses_ngiot(getattr(value, field.name)) + for field in fields(value) + ) + return False + + @staticmethod + def _is_ngiot_module(module_name: str) -> bool: + return module_name.startswith(_NGIOT_COMMAND_MODULE_PREFIX) + + @staticmethod + def _normalize_base_url(base_url: str) -> str: + parsed = urlparse(base_url) + if parsed.scheme and parsed.netloc: + host = parsed.netloc + else: + host = base_url + return f"https://{host.strip().rstrip('/')}" + + @staticmethod + def _normalize_region(region: str) -> str: + normalized = region.strip().lower() + if normalized.startswith("dc-"): + normalized = normalized[3:] + return normalized + + @classmethod + def _format_ngiot_base_url(cls, region: str) -> str: + return _NGIOT_BASE_URL_TEMPLATE.format(region=cls._normalize_region(region)) + + @classmethod + def _derive_ngiot_base_url_from_mqs(cls, mqs_host: str) -> str: + parsed = urlparse(mqs_host) + host = parsed.netloc or parsed.path + host = host.strip().rstrip("/") + if not host: + msg = f'Could not derive NGIOT base URL from mqs host "{mqs_host}"' + raise ApiError(msg) + if host.startswith("api-base."): + return cls._normalize_base_url(host) + if host.startswith("api-ngiot."): + return cls._normalize_base_url("api-base." + host.split(".", 1)[1]) + if "." in host: + return cls._normalize_base_url("api-base." + host.split(".", 1)[1]) + msg = f'Could not derive NGIOT base URL from mqs host "{mqs_host}"' + raise ApiError(msg) \ No newline at end of file diff --git a/deebot_client/commands/ngiot/__init__.py b/deebot_client/commands/ngiot/__init__.py new file mode 100644 index 000000000..64855bbad --- /dev/null +++ b/deebot_client/commands/ngiot/__init__.py @@ -0,0 +1,86 @@ +"""NGIOT commands module.""" + +from __future__ import annotations + +from typing import TYPE_CHECKING + +from .battery import GetBattery +from .charge import Charge +from .clean import Clean, CleanArea, GetCleanInfo +from .custom import CustomCommand +from .error import GetError +from .fan_speed import GetFanSpeed, SetFanSpeed +from .life_span import GetLifeSpan, ResetLifeSpan +from .network import GetNetInfo +from .play_sound import PlaySound +from .stats import GetReportStats, GetStats, GetTotalStats +from .child_lock import GetChildLock, SetChildLock +from .volume import GetVolume, SetVolume +from .map import ( + GetMajorMap, + GetMapSet, + GetMapTrace, + GetMinorMap, +) +from .pos import GetPos + +if TYPE_CHECKING: + from deebot_client.command import Command + +__all__ = [ + "Charge", + "Clean", + "CleanArea", + "CustomCommand", + "GetBattery", + "GetCleanInfo", + "GetError", + "GetFanSpeed", + "GetLifeSpan", + "GetNetInfo", + "GetReportStats", + "GetStats", + "GetTotalStats", + "PlaySound", + "ResetLifeSpan", + "SetFanSpeed", + "GetChildLock", + "SetChildLock", + "GetMajorMap", + "GetMapSet", + "GetMapTrace", + "GetMinorMap", + "GetPos", + "GetVolume", + "SetVolume", +] + +_COMMANDS: list[type[Command]] = [ + GetBattery, + Charge, + Clean, + CleanArea, + GetCleanInfo, + CustomCommand, + GetError, + GetFanSpeed, + SetFanSpeed, + GetLifeSpan, + ResetLifeSpan, + GetNetInfo, + PlaySound, + GetReportStats, + GetStats, + GetTotalStats, + GetChildLock, + SetChildLock, + GetMajorMap, + GetMapSet, + GetMapTrace, + GetMinorMap, + GetPos, + GetVolume, + SetVolume, +] + +COMMANDS: dict[str, type[Command]] = {cmd.NAME: cmd for cmd in _COMMANDS} \ No newline at end of file diff --git a/deebot_client/commands/ngiot/battery.py b/deebot_client/commands/ngiot/battery.py new file mode 100644 index 000000000..1387a9323 --- /dev/null +++ b/deebot_client/commands/ngiot/battery.py @@ -0,0 +1,33 @@ +"""Battery commands.""" + +from __future__ import annotations + +from typing import Any + +from deebot_client.events import AvailabilityEvent, BatteryEvent +from deebot_client.message import HandlingResult + +from .common import RobotDetailGetCommand + + +class GetBattery(RobotDetailGetCommand): + """Get battery percentage.""" + + NAME = 'getBattery' + FIELDS = ('battery',) + + def __init__(self, *, is_available_check: bool = False) -> None: + super().__init__(is_available_check=is_available_check) + + @classmethod + def _handle_body_data_dict( + cls, + event_bus, + data: dict[str, Any], + ) -> HandlingResult: + battery = data.get('battery') + available = battery is not None + event_bus.notify(AvailabilityEvent(available=available)) + if available: + event_bus.notify(BatteryEvent(int(battery))) + return HandlingResult.success() diff --git a/deebot_client/commands/ngiot/charge.py b/deebot_client/commands/ngiot/charge.py new file mode 100644 index 000000000..04191f578 --- /dev/null +++ b/deebot_client/commands/ngiot/charge.py @@ -0,0 +1,44 @@ +"""Charge commands.""" + +from __future__ import annotations + +from typing import TYPE_CHECKING, Any + +from deebot_client.events import StateEvent +from deebot_client.message import HandlingResult, HandlingState +from deebot_client.models import State +from deebot_client.ngiot_client import APN_RETURN_TO_DOCK + +from .common import NgiotExecuteCommand + +if TYPE_CHECKING: + from deebot_client.event_bus import EventBus + from deebot_client.models import ApiDeviceInfo + from deebot_client.ngiot_client import NgiotClient + + +class Charge(NgiotExecuteCommand): + """Return robot to charge dock.""" + + NAME = "charge" + + def __init__(self) -> None: + super().__init__({}) + + async def _request_ngiot( + self, + client: NgiotClient, + device_info: ApiDeviceInfo, + ) -> dict[str, Any]: + return await client.request( + device_info, + apn=APN_RETURN_TO_DOCK, + body_data={"chargeSwitch": True}, + ) + + @classmethod + def _handle_body(cls, event_bus: EventBus, body: dict[str, Any]) -> HandlingResult: + result = super()._handle_body(event_bus, body) + if result.state == HandlingState.SUCCESS: + event_bus.notify(StateEvent(State.RETURNING)) + return result \ No newline at end of file diff --git a/deebot_client/commands/ngiot/child_lock.py b/deebot_client/commands/ngiot/child_lock.py new file mode 100644 index 000000000..67f19b4c9 --- /dev/null +++ b/deebot_client/commands/ngiot/child_lock.py @@ -0,0 +1,64 @@ +"""Child lock commands.""" + +from __future__ import annotations + +from typing import TYPE_CHECKING, Any + +from deebot_client.events import ChildLockEvent +from deebot_client.message import HandlingResult, HandlingState +from deebot_client.ngiot_client import APN_CHILD_LOCK + +from .common import NgiotExecuteCommand, RobotDetailGetCommand + +if TYPE_CHECKING: + from deebot_client.event_bus import EventBus + from deebot_client.models import ApiDeviceInfo + from deebot_client.ngiot_client import NgiotClient + + +class GetChildLock(RobotDetailGetCommand): + """Get child-lock state from the robot-detail surface.""" + + NAME = "getChildLock" + FIELDS = ("childLock",) + + @classmethod + def _handle_body_data_dict( + cls, + event_bus: EventBus, + data: dict[str, Any], + ) -> HandlingResult: + event_bus.notify(ChildLockEvent(bool(data.get("childLock")))) + return HandlingResult.success() + + +class SetChildLock(NgiotExecuteCommand): + """Set child-lock state using the confirmed NGIOT write APN.""" + + NAME = "setChildLock" + get_command = GetChildLock + + def __init__(self, enable: bool) -> None: + super().__init__({}) + self._enable = bool(enable) + + async def _request_ngiot( + self, + client: NgiotClient, + device_info: ApiDeviceInfo, + ) -> dict[str, Any]: + return await client.request( + device_info, + apn=APN_CHILD_LOCK, + body_data={"childLock": self._enable}, + ) + + def _handle_response( + self, + event_bus: EventBus, + response: dict[str, Any], + ) -> HandlingResult: + result = super()._handle_response(event_bus, response) + if result.state == HandlingState.SUCCESS: + event_bus.notify(ChildLockEvent(self._enable)) + return result \ No newline at end of file diff --git a/deebot_client/commands/ngiot/clean.py b/deebot_client/commands/ngiot/clean.py new file mode 100644 index 000000000..0ac74ef1c --- /dev/null +++ b/deebot_client/commands/ngiot/clean.py @@ -0,0 +1,196 @@ +"""Clean commands.""" + +from __future__ import annotations + +from collections.abc import Mapping +from typing import TYPE_CHECKING, Any + +from deebot_client.events import StateEvent +from deebot_client.exceptions import ApiError +from deebot_client.message import HandlingResult +from deebot_client.models import CleanAction, CleanMode, State +from deebot_client.ngiot_client import ( + APN_AREA_CLEAN, + APN_CLEAN_START, + APN_PAUSE, + APN_RESUME, + APN_RETURN_TO_DOCK, +) + +from .common import NgiotExecuteCommand, RobotDetailGetCommand + +if TYPE_CHECKING: + from deebot_client.authentication import Authenticator + from deebot_client.event_bus import EventBus + from deebot_client.models import ApiDeviceInfo + from deebot_client.ngiot_client import NgiotClient + + +_ACTIVE_WORK_MODES = { + "smart", + "smartclean", + "area", + "auto", + "customarea", + "custom_area", + "spotarea", + "spot_area", +} + + +def _coerce_bool(value: Any) -> bool: + if isinstance(value, bool): + return value + if isinstance(value, (int, float)): + return bool(value) + if isinstance(value, str): + return value.strip().lower() in {"1", "true", "on", "yes"} + return False + + +def map_snapshot_state(data: Mapping[str, Any]) -> State: + """Map robot-detail snapshot fields onto the generic state enum.""" + work_mode = str(data.get("workMode", "")).strip().lower() + pause_switch = data.get("pauseSwitch") + charge_status = _coerce_bool(data.get("chargeStatus")) + + if work_mode == "auto_pause" or (pause_switch is True and work_mode in _ACTIVE_WORK_MODES): + return State.PAUSED + if work_mode in {"gocharge", "go_charge"}: + return State.RETURNING + if charge_status and work_mode in {"stop", "idle", "", "none"}: + return State.DOCKED + if charge_status: + return State.RETURNING + if work_mode in _ACTIVE_WORK_MODES: + return State.CLEANING + return State.IDLE + + +def map_live_state(data: Mapping[str, Any], previous: State | None = None) -> State | None: + """Map live 10000 status events onto the generic state enum.""" + status = str(data.get("status", "")).strip().lower() + pause_switch = data.get("pauseSwitch") + + if status == "smartclean": + return State.PAUSED if pause_switch is True else State.CLEANING + if status in {"gocharge", "go_charge"}: + return State.RETURNING + if status == "idle": + return State.DOCKED if _coerce_bool(data.get("chargeStatus")) else State.IDLE + + if pause_switch is True and previous in {State.CLEANING, State.PAUSED}: + return State.PAUSED + if pause_switch is False and previous == State.PAUSED: + return State.CLEANING + + if any(key in data for key in ("workMode", "chargeStatus")): + return map_snapshot_state(data) + + return None + + +class Clean(NgiotExecuteCommand): + """Translate generic clean actions into captured NGIOT control payloads.""" + + NAME = "clean" + + def __init__(self, action: CleanAction) -> None: + super().__init__({}) + self._action = action + + async def _execute( + self, + authenticator: Authenticator, + device_info: ApiDeviceInfo, + event_bus: EventBus, + ) -> tuple[HandlingResult, dict[str, Any]]: + state = event_bus.get_last_event(StateEvent) + if state is not None: + if self._action is CleanAction.RESUME and state.state != State.PAUSED: + self._action = CleanAction.START + elif self._action is CleanAction.START and state.state == State.PAUSED: + self._action = CleanAction.RESUME + + return await super()._execute(authenticator, device_info, event_bus) + + async def _request_ngiot( + self, + client: NgiotClient, + device_info: ApiDeviceInfo, + ) -> dict[str, Any]: + apn, body_data = self._get_request() + return await client.request( + device_info, + apn=apn, + body_data=body_data, + ) + + def _get_request(self) -> tuple[str, dict[str, Any]]: + if self._action is CleanAction.START: + return APN_CLEAN_START, {"cleanSwitch": True, "cleanMode": "smart"} + if self._action is CleanAction.PAUSE: + return APN_PAUSE, {"pauseSwitch": True} + if self._action is CleanAction.RESUME: + return APN_RESUME, {"pauseSwitch": False} + if self._action is CleanAction.STOP: + return APN_RETURN_TO_DOCK, {"chargeSwitch": True} + raise ApiError(f"Unsupported clean action: {self._action}") + + +class CleanArea(NgiotExecuteCommand): + """Start room/area cleaning using room IDs.""" + + NAME = "clean" + + def __init__( + self, + mode: CleanMode, + area: list[int | float], + cleanings: int = 1, + ) -> None: + super().__init__({}) + self._mode = mode + self._room_ids = [int(value) for value in area] + self._cleanings = cleanings + + async def _request_ngiot( + self, + client: NgiotClient, + device_info: ApiDeviceInfo, + ) -> dict[str, Any]: + if self._mode is not CleanMode.SPOT_AREA: + raise ApiError( + "NGIOT area cleaning currently supports room-id cleaning only" + ) + + if self._cleanings != 1: + raise ApiError( + "NGIOT room cleaning repeat count has not been captured yet" + ) + + return await client.request( + device_info, + apn=APN_AREA_CLEAN, + body_data={ + "cleanSwitch": True, + "cleanMode": "area", + "cleanValues": self._room_ids, + }, + ) + + +class GetCleanInfo(RobotDetailGetCommand): + """Get high-level robot state.""" + + NAME = "getCleanInfo" + FIELDS = ("cleanValues", "workMode", "chargeStatus", "pauseSwitch") + + @classmethod + def _handle_body_data_dict( + cls, + event_bus, + data: dict[str, Any], + ) -> HandlingResult: + event_bus.notify(StateEvent(map_snapshot_state(data))) + return HandlingResult.success() \ No newline at end of file diff --git a/deebot_client/commands/ngiot/common.py b/deebot_client/commands/ngiot/common.py new file mode 100644 index 000000000..2fde0c70d --- /dev/null +++ b/deebot_client/commands/ngiot/common.py @@ -0,0 +1,137 @@ +"""Common NGIOT command helpers.""" + +from __future__ import annotations + +import inspect +from abc import ABC, abstractmethod +from collections.abc import Mapping +from typing import TYPE_CHECKING, Any + +from deebot_client.commands.json.common import ExecuteCommand, JsonGetCommand +from deebot_client.exceptions import ApiError +from deebot_client.hardware import get_static_device_info +from deebot_client.ngiot_client import APN_ROBOT_DETAIL + +if TYPE_CHECKING: + from deebot_client.authentication import Authenticator + from deebot_client.models import ApiDeviceInfo, DeviceInfo, StaticDeviceInfo + from deebot_client.ngiot_client import NgiotClient + + +class NgiotJsonCommandMixin(ABC): + """Mixin that routes command execution through ``Authenticator.ngiot_client``.""" + + async def _get_ngiot_client( + self, + authenticator: Authenticator, + device_info: ApiDeviceInfo | DeviceInfo, + ) -> NgiotClient: + client = getattr(authenticator, "ngiot_client", None) + if client is not None: + return client + + ensure = getattr(authenticator, "ensure_ngiot_for_device", None) + if ensure is not None: + raw_device: ApiDeviceInfo = ( + device_info.api if hasattr(device_info, "api") else device_info + ) + + static_device_info: StaticDeviceInfo | None = getattr( + device_info, "static", None + ) + if static_device_info is None: + static_device_info = await get_static_device_info(raw_device["class"]) + + if static_device_info is None: + raise ApiError( + f'No static device info found for NGIOT device class "{raw_device["class"]}"' + ) + + result = ensure(raw_device, static_device_info) + if inspect.isawaitable(result): + await result + + client = getattr(authenticator, "ngiot_client", None) + if client is not None: + return client + + raise ApiError( + "NGIOT client not attached to authenticator after bootstrap attempt" + ) + + @staticmethod + def _wrap_response(response: Mapping[str, Any]) -> dict[str, Any]: + body = response.get("body", {}) + if not isinstance(body, Mapping): + body = {} + return {"ret": "ok", "resp": {"body": dict(body)}} + + async def _execute_api_request( + self, + authenticator: Authenticator, + device_info: ApiDeviceInfo, + ) -> dict[str, Any]: + client = await self._get_ngiot_client(authenticator, device_info) + response = await self._request_ngiot(client, device_info) + return self._wrap_response(response) + + @abstractmethod + async def _request_ngiot( + self, + client: NgiotClient, + device_info: ApiDeviceInfo, + ) -> dict[str, Any]: + """Execute the NGIOT request and return the raw NGIOT envelope.""" + + +class NgiotJsonGetCommand(NgiotJsonCommandMixin, JsonGetCommand, ABC): + """Base class for NGIOT-backed get commands.""" + + def __init__( + self, + args: dict[str, Any] | list[Any] | None = None, + *, + is_available_check: bool = False, + ) -> None: + super().__init__(args) + self._is_available_check = is_available_check + + +class NgiotExecuteCommand(NgiotJsonCommandMixin, ExecuteCommand, ABC): + """Base class for NGIOT-backed execute commands.""" + + +class RobotDetailGetCommand(NgiotJsonGetCommand, ABC): + """Base class for APN 10001 field queries.""" + + FIELDS: tuple[str, ...] = () + + async def _request_ngiot( + self, + client: NgiotClient, + device_info: ApiDeviceInfo, + ) -> dict[str, Any]: + return await client.request( + device_info, + apn=APN_ROBOT_DETAIL, + body_data={"fields": list(self.FIELDS)}, + ) + + +class RobotDetailSetCommand(NgiotExecuteCommand, ABC): + """Base class for APN 10001 writes.""" + + async def _request_ngiot( + self, + client: NgiotClient, + device_info: ApiDeviceInfo, + ) -> dict[str, Any]: + return await client.request( + device_info, + apn=APN_ROBOT_DETAIL, + body_data=self._get_body_data(), + ) + + @abstractmethod + def _get_body_data(self) -> dict[str, Any] | list[Any]: + """Return the NGIOT request body payload.""" \ No newline at end of file diff --git a/deebot_client/commands/ngiot/custom.py b/deebot_client/commands/ngiot/custom.py new file mode 100644 index 000000000..d7e10807d --- /dev/null +++ b/deebot_client/commands/ngiot/custom.py @@ -0,0 +1,31 @@ +"""Custom commands.""" + +from __future__ import annotations + +from typing import Any + +from deebot_client.events import CustomCommandEvent +from deebot_client.message import HandlingResult, HandlingState + +from .common import RobotDetailSetCommand + + +class CustomCommand(RobotDetailSetCommand): + """Send an arbitrary key/value payload to APN 10001.""" + + NAME = 'customCommand' + + def __init__(self, name: str, value: Any) -> None: + super().__init__({name: value}) + self._name = name + self._value = value + + def _get_body_data(self) -> dict[str, Any]: + return dict(self._args) + + def _handle_response(self, event_bus, response: dict[str, Any]) -> HandlingResult: + result = super()._handle_response(event_bus, response) + if result.state == HandlingState.SUCCESS: + body = response.get('resp', {}).get('body', {}) + event_bus.notify(CustomCommandEvent(name=self._name, response=body)) + return result diff --git a/deebot_client/commands/ngiot/error.py b/deebot_client/commands/ngiot/error.py new file mode 100644 index 000000000..35a1aef43 --- /dev/null +++ b/deebot_client/commands/ngiot/error.py @@ -0,0 +1,35 @@ +"""Error commands.""" + +from __future__ import annotations + +from typing import Any + +from deebot_client.events import ErrorEvent +from deebot_client.message import HandlingResult + +from .common import RobotDetailGetCommand + + +class GetError(RobotDetailGetCommand): + """Get current robot error.""" + + NAME = 'getError' + FIELDS = ('error',) + + @classmethod + def _handle_body_data_dict(cls, event_bus: EventBus, data: dict[str, Any]) -> HandlingResult: + code = _extract_first_int(data.get("error")) + + if code == 0: + return HandlingResult.success() + + event_bus.notify(ErrorEvent(code, f"NGIOT error {code}")) + return HandlingResult.success() + +def _extract_first_int(value: Any) -> int: + if isinstance(value, list) and value: + value = value[0] + try: + return int(value) + except (TypeError, ValueError): + return 0 diff --git a/deebot_client/commands/ngiot/fan_speed.py b/deebot_client/commands/ngiot/fan_speed.py new file mode 100644 index 000000000..715f3b0ca --- /dev/null +++ b/deebot_client/commands/ngiot/fan_speed.py @@ -0,0 +1,90 @@ +"""NGIOT fan speed commands.""" + +from __future__ import annotations + +from typing import TYPE_CHECKING, Any + +from deebot_client.events import FanSpeedEvent, FanSpeedLevel +from deebot_client.message import HandlingResult, HandlingState +from deebot_client.util import get_enum + +from .common import NgiotExecuteCommand, RobotDetailGetCommand + +if TYPE_CHECKING: + from deebot_client.event_bus import EventBus + from deebot_client.models import ApiDeviceInfo + from deebot_client.ngiot_client import NgiotClient + +APN_FAN_MODE = "50011" + +_WIRE_TO_LEVEL: dict[str, FanSpeedLevel] = { + "quiet": FanSpeedLevel.QUIET, + "auto": FanSpeedLevel.NORMAL, + "strong": FanSpeedLevel.MAX, + "max": FanSpeedLevel.MAX_PLUS, +} + +_LEVEL_TO_WIRE: dict[FanSpeedLevel, str] = { + FanSpeedLevel.QUIET: "quiet", + FanSpeedLevel.NORMAL: "auto", + FanSpeedLevel.MAX: "strong", + FanSpeedLevel.MAX_PLUS: "max", +} + + +class GetFanSpeed(RobotDetailGetCommand): + """Get current fan speed/mode from robot detail status.""" + + NAME = "getSpeed" + FIELDS = ("fanMode",) + + @classmethod + def _handle_body_data_dict( + cls, + event_bus: EventBus, + data: dict[str, Any], + ) -> HandlingResult: + fan_mode = str(data["fanMode"]).lower() + event_bus.notify(FanSpeedEvent(_WIRE_TO_LEVEL[fan_mode])) + return HandlingResult.success() + + +class SetFanSpeed(NgiotExecuteCommand): + """Set fan speed/mode for NGIOT devices.""" + + NAME = "setSpeed" + get_command = GetFanSpeed + + def __init__(self, speed: FanSpeedLevel | str) -> None: + super().__init__({}) + if isinstance(speed, str): + speed = get_enum(FanSpeedLevel, speed) + self._speed = speed + + async def _request_ngiot( + self, + client: NgiotClient, + device_info: ApiDeviceInfo, + ) -> dict[str, Any]: + try: + fan_mode = _LEVEL_TO_WIRE[self._speed] + except KeyError as ex: + raise ValueError( + f"Fan speed {self._speed!s} is not supported by this NGIOT ruleset" + ) from ex + + return await client.request( + device_info, + apn=APN_FAN_MODE, + body_data={"fanMode": fan_mode}, + ) + + def _handle_response( + self, + event_bus: EventBus, + response: dict[str, Any], + ) -> HandlingResult: + result = super()._handle_response(event_bus, response) + if result.state == HandlingState.SUCCESS: + event_bus.notify(FanSpeedEvent(self._speed)) + return result \ No newline at end of file diff --git a/deebot_client/commands/ngiot/life_span.py b/deebot_client/commands/ngiot/life_span.py new file mode 100644 index 000000000..a34977b8c --- /dev/null +++ b/deebot_client/commands/ngiot/life_span.py @@ -0,0 +1,95 @@ +"""Life span commands.""" + +from __future__ import annotations + +from collections.abc import Mapping +from typing import TYPE_CHECKING, Any + +from deebot_client.events import LifeSpan, LifeSpanEvent +from deebot_client.message import HandlingResult, HandlingState +from deebot_client.ngiot_client import APN_RESET_CONSUMABLE + +from .common import NgiotExecuteCommand, RobotDetailGetCommand + +if TYPE_CHECKING: + from deebot_client.event_bus import EventBus + from deebot_client.models import ApiDeviceInfo + from deebot_client.ngiot_client import NgiotClient + +_CONSUMABLE_TYPES: dict[str, LifeSpan] = { + "rollBrush": LifeSpan.BRUSH, + "filter": LifeSpan.FILTER, + "sideBrush": LifeSpan.SIDE_BRUSH, + "unitCare": LifeSpan.UNIT_CARE, +} + +_RESET_CONSUMABLE_TYPES: dict[LifeSpan, str] = { + LifeSpan.BRUSH: "rollBrush", + LifeSpan.FILTER: "filter", + LifeSpan.SIDE_BRUSH: "sideBrush", + LifeSpan.UNIT_CARE: "unitCare", +} + + +class GetLifeSpan(RobotDetailGetCommand): + """Get consumable life-span data.""" + + NAME = "getLifeSpan" + FIELDS = ("consumables",) + + @classmethod + def _handle_body_data_dict( + cls, + event_bus: EventBus, + data: dict[str, Any], + ) -> HandlingResult: + for item in data.get("consumables", []) or []: + if not isinstance(item, Mapping): + continue + consumable_type = _CONSUMABLE_TYPES.get(str(item.get("type"))) + if consumable_type is None: + continue + left = int(item.get("left", 0) or 0) + total = int(item.get("total", 0) or 0) + percent = 0.0 if total <= 0 else (left / total) * 100 + event_bus.notify(LifeSpanEvent(consumable_type, percent, left)) + return HandlingResult.success() + + +class ResetLifeSpan(NgiotExecuteCommand): + """Reset a consumable counter via the NGIOT reset-consumable surface.""" + + NAME = "resetLifeSpan" + get_command = GetLifeSpan + + def __init__(self, life_span: LifeSpan) -> None: + super().__init__({}) + self._life_span = life_span + + async def _request_ngiot( + self, + client: NgiotClient, + device_info: ApiDeviceInfo, + ) -> dict[str, Any]: + try: + reset_consumable = _RESET_CONSUMABLE_TYPES[self._life_span] + except KeyError as ex: + raise ValueError( + f"Life-span reset is not supported for NGIOT consumable {self._life_span!s}" + ) from ex + + return await client.request( + device_info, + apn=APN_RESET_CONSUMABLE, + body_data={"resetConsumable": reset_consumable}, + ) + + def _handle_response( + self, + event_bus: EventBus, + response: dict[str, Any], + ) -> HandlingResult: + result = super()._handle_response(event_bus, response) + if result.state == HandlingState.SUCCESS: + event_bus.request_refresh(LifeSpanEvent) + return result \ No newline at end of file diff --git a/deebot_client/commands/ngiot/locate.py b/deebot_client/commands/ngiot/locate.py new file mode 100644 index 000000000..f353cecd0 --- /dev/null +++ b/deebot_client/commands/ngiot/locate.py @@ -0,0 +1,23 @@ +"""NGIOT locate-device command.""" + +from __future__ import annotations + +from deebot_client.ngiot_client import APN_DEVICE_LOCATE + +from .common import NgiotExecuteCommand + + +class LocateDevice(NgiotExecuteCommand): + """Trigger the robot locator beep on NGIOT devices.""" + + NAME = "seek" + + def __init__(self) -> None: + super().__init__({}) + + async def _request_ngiot(self, client, device_info): + return await client.request( + device_info, + apn=APN_DEVICE_LOCATE, + body_data={"seek": True}, + ) \ No newline at end of file diff --git a/deebot_client/commands/ngiot/map.py b/deebot_client/commands/ngiot/map.py new file mode 100644 index 000000000..4488d14ab --- /dev/null +++ b/deebot_client/commands/ngiot/map.py @@ -0,0 +1,605 @@ +"""NGIOT map commands.""" + +from __future__ import annotations + +import binascii +from abc import ABC, abstractmethod +from typing import TYPE_CHECKING, Any + +from deebot_client.events import Position, PositionsEvent, RoomsEvent +from deebot_client.events.map import ( + CachedMapInfoEvent, + MajorMapEvent, + Map, + MapSetEvent, + MapSetType, + MapSubsetEvent, + MapTraceEvent, + MinorMapEvent, +) +from deebot_client.exceptions import ApiError +from deebot_client.message import HandlingResult, HandlingState +from deebot_client.models import Room +from deebot_client.ngiot_client import APN_MAP_DETAILS +from deebot_client.ngiot_map_parser import ( + parse_areas, + parse_base_map, + parse_map_infos, + parse_overlays, + parse_pose, + parse_trace, + resolve_map_id, +) +from deebot_client.ngiot_map_state import NgiotMapStateStore +from deebot_client.rs.map import PositionType, RotationAngle + +from .common import NgiotJsonGetCommand + +if TYPE_CHECKING: + from collections.abc import Sequence + + from deebot_client.event_bus import EventBus + from deebot_client.models import ApiDeviceInfo + from deebot_client.ngiot_client import NgiotClient + + +def _coerce_int(value: Any, default: int = 0) -> int: + try: + return int(value) + except (TypeError, ValueError): + return default + + +def _build_position(raw: Any, type_name: str) -> Position | None: + if not isinstance(raw, dict): + return None + + if raw.get("x") is None or raw.get("y") is None: + return None + + return Position( + type=PositionType.from_str(type_name), + x=_coerce_int(raw.get("x")), + y=_coerce_int(raw.get("y")), + a=_coerce_int(raw.get("a")), + ) + + +def _build_position_from_point(raw_point: Any, type_name: str) -> Position | None: + if raw_point is None: + return None + + x = getattr(raw_point, "x", None) + y = getattr(raw_point, "y", None) + a = getattr(raw_point, "a", 0) + + if x is None or y is None: + return None + + return Position( + type=PositionType.from_str(type_name), + x=_coerce_int(x), + y=_coerce_int(y), + a=_coerce_int(a), + ) + + +def _get_ngiot_map_state_store(event_bus: EventBus) -> NgiotMapStateStore: + store = getattr(event_bus, "_ngiot_map_state_store", None) + if store is None: + store = NgiotMapStateStore() + setattr(event_bus, "_ngiot_map_state_store", store) + return store + + +def _resolve_effective_map_id( + event_bus: EventBus, + data: dict[str, Any], + explicit: str = "", +) -> str: + store = _get_ngiot_map_state_store(event_bus) + # For eyfj07, downstream mapping payloads do not expose a stable join key. + # Prefer an explicit/requested or already-active map context over payload IDs. + return explicit or store.active_map_id or resolve_map_id(data, fallback="") + + +def _polygon_to_coordinates(points: list[Any]) -> str: + return ",".join(f"{point.x},{point.y}" for point in points) + + +class NgiotMapGetCommand(NgiotJsonGetCommand, ABC): + """Base class for NGIOT APN 30001 field queries.""" + + def __init__(self, map_id: str = "") -> None: + super().__init__({}) + self._map_id = str(map_id) + + @property + @abstractmethod + def _fields(self) -> Sequence[str]: + """Return fields to request from APN 30001.""" + + async def _resolve_request_map_id( + self, + client: NgiotClient, + device_info: ApiDeviceInfo, + ) -> str: + if self._map_id: + return self._map_id + + # mapInfos itself must not recurse. + if tuple(self._fields) == ("mapInfos",): + return "" + + try: + response = await client.request( + device_info, + apn=APN_MAP_DETAILS, + body_data={"fields": ["mapInfos"]}, + ) + except ApiError: + return "" + + body = response.get("body", {}) + data = body.get("data", {}) + if not isinstance(data, dict): + return "" + + infos = parse_map_infos(data) + active = next((info for info in infos if info.using and info.map_id), None) + if active is not None: + return active.map_id + + first = next((info for info in infos if info.map_id), None) + return first.map_id if first is not None else "" + + async def _request_ngiot( + self, + client: NgiotClient, + device_info: ApiDeviceInfo, + ) -> dict[str, Any]: + request_map_id = await self._resolve_request_map_id(client, device_info) + + body_data: dict[str, Any] = {"fields": list(self._fields)} + if request_map_id: + body_data["mapId"] = request_map_id + + return await client.request( + device_info, + apn=APN_MAP_DETAILS, + body_data=body_data, + ) + + @classmethod + def _handle_body_data_dict( + cls, event_bus: EventBus, data: dict[str, Any] + ) -> HandlingResult: + infos = parse_map_infos(data) + if not infos: + return HandlingResult.analyse() + + store = _get_ngiot_map_state_store(event_bus) + store.update_map_infos(infos) + + maps = { + Map( + id=info.map_id, + name=info.name, + using=info.using, + built=True, + angle=RotationAngle.from_int(info.angle), + ) + for info in infos + } + event_bus.notify(CachedMapInfoEvent(maps=maps)) + + active_info = next((info for info in infos if info.using), None) + resolved_info = active_info or infos[0] + + charger_pos = _build_position_from_point( + resolved_info.charge_pos, "chargePos" + ) + if charger_pos is not None: + event_bus.notify(PositionsEvent(positions=[charger_pos])) + + return HandlingResult( + HandlingState.SUCCESS, + {"map_id": resolved_info.map_id}, + ) + + def _handle_response( + self, + event_bus: EventBus, + response: dict[str, Any], + ) -> HandlingResult: + result = super()._handle_response(event_bus, response) + return result + + +class GetMajorMap(NgiotMapGetCommand): + """Get the current NGIOT raster map.""" + + NAME = "getMajorMap" + + @property + def _fields(self) -> Sequence[str]: + return ("mapData", "areas", "pos") + + @classmethod + def _handle_body_data_dict( + cls, + event_bus: EventBus, + data: dict[str, Any], + ) -> HandlingResult: + store = _get_ngiot_map_state_store(event_bus) + map_id = _resolve_effective_map_id(event_bus, data) + handled = False + + map_data = data.get("mapData") + if isinstance(map_data, dict): + legacy_map_blob = map_data.get("map") + if isinstance(legacy_map_blob, str) and legacy_map_blob: + crc = binascii.crc32(legacy_map_blob.encode("utf-8")) & 0xFFFFFFFF + event_bus.notify( + MajorMapEvent(map_id=map_id, values=[crc], requested=False) + ) + handled = True + + base_map = parse_base_map(data, map_id) + if base_map is not None and base_map.map_id: + store.update_base_map(base_map) + map_id = base_map.map_id + handled = True + + areas = parse_areas(data) + if areas and map_id: + store.update_areas(map_id, areas) + rooms = [ + Room( + name=(area.name or f"Area { _coerce_int(area.area_id, index) }"), + id=_coerce_int(area.area_id, index), + coordinates=_polygon_to_coordinates(area.polygon), + ) + for index, area in enumerate(areas) + ] + event_bus.notify(RoomsEvent(map_id=map_id, rooms=rooms)) + handled = True + + positions: list[Position] = [] + + pose = parse_pose(data) + if pose is not None: + if map_id: + store.update_pose(map_id, pose) + positions.append( + Position( + type=PositionType.from_str("deebotPos"), + x=pose.x, + y=pose.y, + a=pose.a, + ) + ) + handled = True + + if isinstance(map_data, dict): + charger_pos = _build_position(map_data.get("chargePos"), "chargePos") + if charger_pos is not None: + positions.append(charger_pos) + handled = True + + legacy_charge_pos = data.get("chargePos") + if isinstance(legacy_charge_pos, list): + for entry in legacy_charge_pos: + charger_pos = _build_position(entry, "chargePos") + if charger_pos is not None: + positions.append(charger_pos) + handled = True + + if positions: + event_bus.notify(PositionsEvent(positions=positions)) + + return HandlingResult.success() if handled else HandlingResult.analyse() + + +class GetCachedMapInfo(NgiotMapGetCommand): + NAME = "getCachedMapInfo" + + @property + def _fields(self) -> Sequence[str]: + return ("mapInfos",) + + +class GetMinorMap(NgiotMapGetCommand): + """Compatibility command for NGIOT map tile fetches.""" + + NAME = "getMinorMap" + + def __init__(self, piece_index: int, map_id: str) -> None: + super().__init__(map_id) + self._piece_index = piece_index + + @property + def _fields(self) -> Sequence[str]: + return ("mapData",) + + @classmethod + def _handle_body_data_dict( + cls, + event_bus: EventBus, + data: dict[str, Any], + ) -> HandlingResult: + del event_bus, data + return HandlingResult.analyse() + + def _handle_response( + self, + event_bus: EventBus, + response: dict[str, Any], + ) -> HandlingResult: + if response.get("ret") != "ok": + return HandlingResult.analyse() + + body = response.get("resp", {}).get("body", {}) + data = body.get("data", {}) + if not isinstance(data, dict): + return HandlingResult.analyse() + + map_data = data.get("mapData") + if not isinstance(map_data, dict): + return HandlingResult.analyse() + + map_blob = map_data.get("map") + if not isinstance(map_blob, str) or not map_blob: + return HandlingResult.analyse() + + event_bus.notify(MinorMapEvent(index=self._piece_index, value=map_blob)) + return HandlingResult.success() + + +class GetMapTrace(NgiotMapGetCommand): + """Get the current NGIOT map trace.""" + + NAME = "getMapTrace" + + @property + def _fields(self) -> Sequence[str]: + return ("mapTraceData",) + + @classmethod + def _handle_body_data_dict( + cls, + event_bus: EventBus, + data: dict[str, Any], + ) -> HandlingResult: + trace = parse_trace(data) + if trace is None: + return HandlingResult.analyse() + + store = _get_ngiot_map_state_store(event_bus) + map_id = _resolve_effective_map_id(event_bus, data) + if map_id: + store.update_trace(map_id, trace) + + event_bus.notify( + MapTraceEvent( + start=trace.start, + total=trace.total_count, + data=trace.encoded, + lz4_len=trace.lz4_len, + ) + ) + return HandlingResult.success() + + +class GetPos(NgiotMapGetCommand): + """Get current robot and charger positions from NGIOT map data.""" + + NAME = "getPos" + + @property + def _fields(self) -> Sequence[str]: + return ("mapData", "pos") + + @classmethod + def _handle_body_data_dict( + cls, + event_bus: EventBus, + data: dict[str, Any], + ) -> HandlingResult: + store = _get_ngiot_map_state_store(event_bus) + map_id = _resolve_effective_map_id(event_bus, data) + + positions: list[Position] = [] + + pose = parse_pose(data) + if pose is not None: + if map_id: + store.update_pose(map_id, pose) + + positions.append( + Position( + type=PositionType.from_str("deebotPos"), + x=pose.x, + y=pose.y, + a=pose.a, + ) + ) + + map_data = data.get("mapData") + if isinstance(map_data, dict): + charger_pos = _build_position(map_data.get("chargePos"), "chargePos") + if charger_pos is not None: + positions.append(charger_pos) + + legacy_charge_pos = data.get("chargePos") + if isinstance(legacy_charge_pos, list): + for entry in legacy_charge_pos: + charger_pos = _build_position(entry, "chargePos") + if charger_pos is not None: + positions.append(charger_pos) + + if positions: + event_bus.notify(PositionsEvent(positions=positions)) + return HandlingResult.success() + + return HandlingResult.analyse() + + +class GetMapSet(NgiotMapGetCommand): + """Get room and barrier data from the NGIOT map surface.""" + + NAME = "getMapSubSet" + + def __init__( + self, + mid: str, + type: MapSetType | str = MapSetType.ROOMS, + ) -> None: + if isinstance(type, MapSetType): + type = type.value + + super().__init__(mid) + self._map_type = MapSetType(type) + + @property + def _fields(self) -> Sequence[str]: + if self._map_type == MapSetType.ROOMS: + return ("areas",) + return ("virtualWalls", "mopWalls", "carpets") + + @classmethod + def _handle_body_data_dict( + cls, + event_bus: EventBus, + data: dict[str, Any], + ) -> HandlingResult: + del event_bus, data + return HandlingResult.analyse() + + def _handle_response( + self, + event_bus: EventBus, + response: dict[str, Any], + ) -> HandlingResult: + if response.get("ret") != "ok": + return HandlingResult.analyse() + + body = response.get("resp", {}).get("body", {}) + data = body.get("data", {}) + if not isinstance(data, dict): + return HandlingResult.analyse() + + store = _get_ngiot_map_state_store(event_bus) + map_id = _resolve_effective_map_id(event_bus, data, self._map_id) + + if self._map_type == MapSetType.ROOMS: + areas = parse_areas(data) + if not areas: + return HandlingResult.analyse() + + if map_id: + store.update_areas(map_id, areas) + + drawable_subset_ids: list[int] = [] + rooms: list[Room] = [] + + for index, area in enumerate(areas): + subset_id = _coerce_int(area.area_id, index) + coordinates = _polygon_to_coordinates(area.polygon) + + if coordinates: + drawable_subset_ids.append(subset_id) + event_bus.notify( + MapSubsetEvent( + id=subset_id, + type=MapSetType.ROOMS, + coordinates=coordinates, + name=area.name, + ) + ) + + rooms.append( + Room( + name=(area.name or f"Area {subset_id}"), + id=subset_id, + coordinates=coordinates, + ) + ) + + if drawable_subset_ids: + event_bus.notify( + MapSetEvent(MapSetType.ROOMS, drawable_subset_ids, map_id) + ) + + event_bus.notify(RoomsEvent(map_id=map_id, rooms=rooms)) + return HandlingResult.success() + + overlays = parse_overlays(data) + if map_id and overlays: + store.update_overlays(map_id, overlays) + + overlay_type_map = { + MapSetType.VIRTUAL_WALLS: "virtual_walls", + MapSetType.NO_MOP_ZONES: "mop_walls", + MapSetType.CARPETS: "carpets", + } + target_overlay_type = overlay_type_map.get(self._map_type) + + if target_overlay_type: + parsed_for_type = [ + overlay + for overlay in overlays + if overlay.overlay_type == target_overlay_type + ] + + if parsed_for_type: + subset_ids: list[int] = [] + for index, overlay in enumerate(parsed_for_type): + subset_id = _coerce_int(overlay.overlay_id, index) + subset_ids.append(subset_id) + event_bus.notify( + MapSubsetEvent( + id=subset_id, + type=self._map_type, + coordinates=_polygon_to_coordinates(overlay.polygon), + ) + ) + + event_bus.notify(MapSetEvent(self._map_type, subset_ids, map_id)) + return HandlingResult.success() + + data_key = { + MapSetType.VIRTUAL_WALLS: "virtualWalls", + MapSetType.NO_MOP_ZONES: "mopWalls", + MapSetType.CARPETS: "carpets", + }[self._map_type] + + raw_value = str(data.get(data_key, "")).strip() + subset_ids: list[int] = [] + + if raw_value: + for entry in raw_value.split(";"): + parts = [part.strip() for part in entry.split(",") if part.strip()] + if len(parts) < 3: + continue + + subset_id = int(parts[0]) + coordinates = ",".join(parts[2:] if len(parts) % 2 == 0 else parts[1:]) + subset_ids.append(subset_id) + event_bus.notify( + MapSubsetEvent( + id=subset_id, + type=self._map_type, + coordinates=coordinates, + ) + ) + + event_bus.notify(MapSetEvent(self._map_type, subset_ids, map_id)) + return HandlingResult.success() + + event_bus.notify(MapSetEvent(self._map_type, subset_ids, map_id)) + return HandlingResult.success() + + +# Backward compatibility for older imports +GetMapSubSet = GetMapSet \ No newline at end of file diff --git a/deebot_client/commands/ngiot/network.py b/deebot_client/commands/ngiot/network.py new file mode 100644 index 000000000..6ccac7edf --- /dev/null +++ b/deebot_client/commands/ngiot/network.py @@ -0,0 +1,44 @@ +"""Network commands.""" + +from __future__ import annotations + +from collections.abc import Mapping +from typing import Any + +from deebot_client.events import NetworkInfoEvent +from deebot_client.message import HandlingResult + +from .common import RobotDetailGetCommand + + +class GetNetInfo(RobotDetailGetCommand): + """Get network info from the robot-detail surface.""" + + NAME = 'getNetInfo' + FIELDS = ('deviceInfo',) + + @classmethod + def _handle_body_data_dict( + cls, + event_bus, + data: dict[str, Any], + ) -> HandlingResult: + device_info = data.get('deviceInfo', {}) + if not isinstance(device_info, Mapping): + device_info = {} + event_bus.notify( + NetworkInfoEvent( + ip=str(device_info.get('ip', '')), + ssid=str(device_info.get('ssid', '')), + rssi=_coerce_rssi(device_info.get('rssi')), + mac=str(device_info.get('mac', '')), + ) + ) + return HandlingResult.success() + + +def _coerce_rssi(value: Any) -> int: + try: + return int(value) + except (TypeError, ValueError): + return 0 diff --git a/deebot_client/commands/ngiot/play_sound.py b/deebot_client/commands/ngiot/play_sound.py new file mode 100644 index 000000000..0651db013 --- /dev/null +++ b/deebot_client/commands/ngiot/play_sound.py @@ -0,0 +1,33 @@ +"""NGIOT play-sound commands.""" + +from __future__ import annotations + +from typing import TYPE_CHECKING + +from deebot_client.ngiot_client import APN_DEVICE_LOCATE + +from .common import NgiotExecuteCommand + +if TYPE_CHECKING: + from deebot_client.models import ApiDeviceInfo + from deebot_client.ngiot_client import NgiotClient + + +class PlaySound(NgiotExecuteCommand): + """Trigger device locate sound.""" + + NAME = "seek" + + def __init__(self) -> None: + super().__init__({}) + + async def _request_ngiot( + self, + client: NgiotClient, + device_info: ApiDeviceInfo, + ) -> dict[str, object]: + return await client.request( + device_info, + apn=APN_DEVICE_LOCATE, + body_data={"seek": True}, + ) \ No newline at end of file diff --git a/deebot_client/commands/ngiot/pos.py b/deebot_client/commands/ngiot/pos.py new file mode 100644 index 000000000..46a2a8e9d --- /dev/null +++ b/deebot_client/commands/ngiot/pos.py @@ -0,0 +1,7 @@ +"""NGIOT position commands.""" + +from __future__ import annotations + +from .map import GetPos + +__all__ = ["GetPos"] \ No newline at end of file diff --git a/deebot_client/commands/ngiot/stats.py b/deebot_client/commands/ngiot/stats.py new file mode 100644 index 000000000..ad62b7815 --- /dev/null +++ b/deebot_client/commands/ngiot/stats.py @@ -0,0 +1,127 @@ +"""Stats commands.""" + +from __future__ import annotations + +from typing import Any + +from deebot_client.events import CleanJobStatus, ReportStatsEvent, StatsEvent, TotalStatsEvent +from deebot_client.message import HandlingResult + +from .common import RobotDetailGetCommand + + +class GetStats(RobotDetailGetCommand): + """Get current clean stats. + + eyfj07 reports cleanTime in minutes. Home Assistant's Ecovacs duration + sensors expect native seconds, so convert before emitting the event. + """ + + NAME = "getStats" + FIELDS = ("cleanArea", "cleanTime", "cleanCount", "workMode", "cleanLogReport") + + @classmethod + def _handle_body_data_dict( + cls, + event_bus, + data: dict[str, Any], + ) -> HandlingResult: + event_bus.notify( + StatsEvent( + area=_maybe_int(data.get("cleanArea")), + time=_minutes_to_seconds(data.get("cleanTime")), + type=_maybe_str(data.get("workMode")), + ) + ) + return HandlingResult.success() + + +class GetReportStats(RobotDetailGetCommand): + """Get current clean report stats from the robot detail snapshot.""" + + NAME = "getReportStats" + FIELDS = ("cleanArea", "cleanTime", "cleanCount", "workMode", "cleanLogReport") + + @classmethod + def _handle_body_data_dict( + cls, + event_bus, + data: dict[str, Any], + ) -> HandlingResult: + clean_log_report = data.get("cleanLogReport") + cleaning_id = "" + if isinstance(clean_log_report, dict): + cleaning_id = str(clean_log_report.get("cid") or "") + + event_bus.notify( + ReportStatsEvent( + area=_maybe_int(data.get("cleanArea")), + time=_minutes_to_seconds(data.get("cleanTime")), + type=_maybe_str(data.get("workMode")), + cleaning_id=cleaning_id, + status=CleanJobStatus.NO_STATUS, + content=[], + ) + ) + return HandlingResult.success() + + +class GetTotalStats(RobotDetailGetCommand): + """Get lifetime totals. + + eyfj07 exposes lifetime totals on the total-stats response surface as + ``cleanAreaTotal``, ``cleanTimeTotal``, and ``cleanCountTotal``. + + cleanTimeTotal is reported in minutes. Home Assistant expects native + duration values in seconds, so convert before emitting TotalStatsEvent. + """ + + NAME = "getTotalStats" + FIELDS = ("cleanAreaTotal", "cleanTimeTotal", "cleanCountTotal") + + @classmethod + def _handle_body_data_dict( + cls, + event_bus, + data: dict[str, Any], + ) -> HandlingResult: + event_bus.notify( + TotalStatsEvent( + area=_coerce_total(data, "cleanAreaTotal", "cleanArea"), + time=_minutes_to_seconds( + data.get("cleanTimeTotal", data.get("cleanTime", 0)) + ) + or 0, + cleanings=_coerce_total(data, "cleanCountTotal", "cleanCount"), + ) + ) + return HandlingResult.success() + + +def _maybe_int(value: Any) -> int | None: + try: + return int(value) if value is not None else None + except (TypeError, ValueError): + return None + + + +def _maybe_str(value: Any) -> str | None: + return None if value is None else str(value) + + + +def _coerce_total(data: dict[str, Any], primary: str, fallback: str) -> int: + value = data.get(primary, data.get(fallback, 0)) + try: + return int(value or 0) + except (TypeError, ValueError): + return 0 + + + +def _minutes_to_seconds(value: Any) -> int | None: + minutes = _maybe_int(value) + if minutes is None: + return None + return minutes * 60 \ No newline at end of file diff --git a/deebot_client/commands/ngiot/volume.py b/deebot_client/commands/ngiot/volume.py new file mode 100644 index 000000000..aa0f961bf --- /dev/null +++ b/deebot_client/commands/ngiot/volume.py @@ -0,0 +1,78 @@ +"""Volume commands.""" + +from __future__ import annotations + +from typing import TYPE_CHECKING, Any + +from deebot_client.events import VolumeEvent +from deebot_client.message import HandlingResult, HandlingState + +from .common import NgiotExecuteCommand, RobotDetailGetCommand + +if TYPE_CHECKING: + from deebot_client.event_bus import EventBus + from deebot_client.models import ApiDeviceInfo + from deebot_client.ngiot_client import NgiotClient + + +class GetVolume(RobotDetailGetCommand): + """Get device voice volume from the robot-detail surface.""" + + NAME = "getVolume" + FIELDS = ("volume",) + MAX_VOLUME = 5 + + @classmethod + def _handle_body_data_dict( + cls, + event_bus: EventBus, + data: dict[str, Any], + ) -> HandlingResult: + volume = data.get("volume") + if volume is None: + return HandlingResult.analyse() + + event_bus.notify(VolumeEvent(volume=int(volume), maximum=cls.MAX_VOLUME)) + return HandlingResult.success() + + +class SetVolume(NgiotExecuteCommand): + """Set device voice volume.""" + + NAME = "setVolume" + APN = "50023" + MIN_VOLUME = 0 + MAX_VOLUME = 10 + get_command = GetVolume + + def __init__(self, volume: int) -> None: + super().__init__({}) + self._volume = int(volume) + + async def _request_ngiot( + self, + client: NgiotClient, + device_info: ApiDeviceInfo, + ) -> dict[str, Any]: + if not self.MIN_VOLUME <= self._volume <= self.MAX_VOLUME: + raise ValueError( + f"Volume must be between {self.MIN_VOLUME} and {self.MAX_VOLUME}" + ) + + return await client.request( + device_info, + apn=self.APN, + body_data={"volume": self._volume}, + ) + + def _handle_response( + self, + event_bus: EventBus, + response: dict[str, Any], + ) -> HandlingResult: + result = super()._handle_response(event_bus, response) + if result.state == HandlingState.SUCCESS: + event_bus.notify( + VolumeEvent(volume=self._volume, maximum=self.MAX_VOLUME) + ) + return result \ No newline at end of file diff --git a/deebot_client/device.py b/deebot_client/device.py index 117134aef..4459fa6c4 100644 --- a/deebot_client/device.py +++ b/deebot_client/device.py @@ -30,6 +30,7 @@ from .map import Map from .messages import get_message from .models import DeviceInfo, State +from .ngiot_map_state import NgiotMapStateStore from .rs.map import PositionType if TYPE_CHECKING: @@ -65,8 +66,16 @@ def __init__( self.fw_version: str | None = None self.mac: str | None = None + self.events: Final[EventBus] = EventBus(self.execute_command, self.capabilities) + # Shared NGIOT map aggregation store. + # Commands look for _ngiot_map_state_store on the event bus. + self.ngiot_map_state = NgiotMapStateStore() + setattr(self.events, "_ngiot_map_state_store", self.ngiot_map_state) + # Optional public alias for debugging/introspection. + self.events.ngiot_map_state = self.ngiot_map_state + self.map: Final[Map | None] = ( Map(self.execute_command, self.events, self.capabilities.map) if self.capabilities.map @@ -95,6 +104,7 @@ async def on_pos(event: PositionsEvent) -> None: self.events.subscribe(PositionsEvent, on_pos) async def on_state(event: StateEvent) -> None: + self._state = event if event.state == State.DOCKED: self.events.request_refresh(CleanLogEvent) self.events.request_refresh(TotalStatsEvent) @@ -239,4 +249,4 @@ def _handle_message( self._create_request_command_task(result.requested_commands) except Exception: - _LOGGER.exception("An exception occurred during handling message") + _LOGGER.exception("An exception occurred during handling message") \ No newline at end of file diff --git a/deebot_client/event_bus.py b/deebot_client/event_bus.py index 5d60fe831..c50ac32a5 100644 --- a/deebot_client/event_bus.py +++ b/deebot_client/event_bus.py @@ -243,7 +243,8 @@ def add_on_subscription_callback( def unsubscribe() -> None: data.unsubscribe() - event_processing_data.on_subscription_callbacks.remove(data) + if data in event_processing_data.on_subscription_callbacks: + event_processing_data.on_subscription_callbacks.remove(data) event_processing_data.on_subscription_callbacks.append(data) @@ -251,4 +252,4 @@ def unsubscribe() -> None: # There are already subscribers create_task(self._tasks, data.call()) - return unsubscribe + return unsubscribe \ No newline at end of file diff --git a/deebot_client/events/map.py b/deebot_client/events/map.py index d45fd7fe3..d01323e16 100644 --- a/deebot_client/events/map.py +++ b/deebot_client/events/map.py @@ -39,6 +39,21 @@ class GpsPositionEvent(Event): latitude: float +@unique +class MapSetType(StrEnum): + """Map set type enum.""" + + ROOMS = "ar" + VIRTUAL_WALLS = "vw" + NO_MOP_ZONES = "mw" + CARPETS = "cp" + + @classmethod + def has_value(cls, value: Any) -> bool: + """Check if value exists.""" + return value in cls._value2member_map_ + + @dataclass(frozen=True) class MapTraceEvent(Event): """Map trace event representation.""" @@ -46,6 +61,7 @@ class MapTraceEvent(Event): start: int total: int data: str + lz4_len: int | None = field(default=None, kw_only=True) @dataclass(frozen=True) @@ -73,20 +89,6 @@ class MinorMapEvent(Event): value: str -@unique -class MapSetType(StrEnum): - """Map set type enum.""" - - ROOMS = "ar" - VIRTUAL_WALLS = "vw" - NO_MOP_ZONES = "mw" - - @classmethod - def has_value(cls, value: Any) -> bool: - """Check if value exists.""" - return value in cls._value2member_map_ - - @dataclass(frozen=True) class MapSetEvent(Event): """Map set event.""" @@ -128,4 +130,4 @@ class CachedMapInfoEvent(Event): class MapChangedEvent(Event): """Map changed event.""" - when: datetime + when: datetime \ No newline at end of file diff --git a/deebot_client/hardware/eyfj07.py b/deebot_client/hardware/eyfj07.py new file mode 100644 index 000000000..7dc46dcd9 --- /dev/null +++ b/deebot_client/hardware/eyfj07.py @@ -0,0 +1,169 @@ +"""DEEBOT eyfj07 capabilities. + +This profile is scoped to loading cleanly against the current ``client.py`` +capability contract. Raster map support and additional write payloads can be +added later once their command surfaces are implemented. +""" + +from __future__ import annotations + +from deebot_client.capabilities import ( + Capabilities, + CapabilityClean, + CapabilityCleanAction, + CapabilityCustomCommand, + CapabilityEvent, + CapabilityExecute, + CapabilityLifeSpan, + CapabilityMap, + CapabilitySet, + CapabilitySetEnable, + CapabilitySettings, + CapabilitySetTypes, + CapabilityStats, + DeviceType, +) +from deebot_client.const import DataType +from deebot_client.events import ( + AvailabilityEvent, + BatteryEvent, + CustomCommandEvent, + ChildLockEvent, + ErrorEvent, + FanSpeedEvent, + FanSpeedLevel, + LifeSpan, + LifeSpanEvent, + NetworkInfoEvent, + ReportStatsEvent, + StateEvent, + StatsEvent, + TotalStatsEvent, + PositionsEvent, + RoomsEvent, + VolumeEvent, +) +from deebot_client.models import StaticDeviceInfo + +from deebot_client.commands.ngiot.battery import GetBattery +from deebot_client.commands.ngiot.charge import Charge +from deebot_client.commands.ngiot.clean import Clean, CleanArea, GetCleanInfo +from deebot_client.commands.ngiot.custom import CustomCommand +from deebot_client.commands.ngiot.error import GetError +from deebot_client.commands.ngiot.fan_speed import GetFanSpeed, SetFanSpeed +from deebot_client.commands.ngiot.life_span import GetLifeSpan, ResetLifeSpan +from deebot_client.commands.ngiot.network import GetNetInfo +from deebot_client.commands.ngiot.play_sound import PlaySound +from deebot_client.commands.ngiot.child_lock import GetChildLock, SetChildLock +from deebot_client.commands.ngiot.stats import GetReportStats, GetStats, GetTotalStats +from deebot_client.commands.ngiot.volume import GetVolume, SetVolume +from deebot_client.commands.ngiot.map import ( + GetCachedMapInfo, + GetMajorMap, + GetMapSet, + GetMapTrace, + GetMinorMap, +) +from deebot_client.events.map import ( + CachedMapInfoEvent, + MajorMapEvent, + MapChangedEvent, + MapTraceEvent, +) +from deebot_client.commands.ngiot.pos import GetPos + + +def get_device_info() -> StaticDeviceInfo: + """Get device info for this model.""" + + return StaticDeviceInfo( + DataType.JSON, + Capabilities( + device_type=DeviceType.VACUUM, + availability=CapabilityEvent( + AvailabilityEvent, + [GetBattery(is_available_check=True)], + ), + battery=CapabilityEvent( + BatteryEvent, + [GetBattery()], + ), + charge=CapabilityExecute(Charge), + clean=CapabilityClean( + action=CapabilityCleanAction( + command=Clean, + area=CleanArea, + ), + ), + custom=CapabilityCustomCommand( + event=CustomCommandEvent, + get=[], + set=CustomCommand, + ), + error=CapabilityEvent( + ErrorEvent, + [GetError()], + ), + fan_speed=CapabilitySetTypes( + event=FanSpeedEvent, + get=[GetFanSpeed()], + set=SetFanSpeed, + types=( + FanSpeedLevel.QUIET, + FanSpeedLevel.NORMAL, + FanSpeedLevel.MAX, + FanSpeedLevel.MAX_PLUS, + ), + ), + life_span=CapabilityLifeSpan( + types=( + LifeSpan.BRUSH, + LifeSpan.FILTER, + LifeSpan.SIDE_BRUSH, + LifeSpan.UNIT_CARE, + ), + event=LifeSpanEvent, + get=[GetLifeSpan()], + reset=ResetLifeSpan, + ), + map=CapabilityMap( + cached_info=CapabilityEvent(CachedMapInfoEvent, [GetCachedMapInfo()]), + changed=CapabilityEvent(MapChangedEvent, []), + info=None, + major=CapabilityEvent(MajorMapEvent, [GetMajorMap()]), + minor=CapabilityExecute(GetMinorMap), + multi_state=None, + position=CapabilityEvent(PositionsEvent, [GetPos()]), + rooms=CapabilityEvent(RoomsEvent, [GetMajorMap()]), + set=CapabilityExecute(GetMapSet), + trace=CapabilityEvent(MapTraceEvent, [GetMapTrace()]), + ), + network=CapabilityEvent( + NetworkInfoEvent, + [GetNetInfo()], + ), + play_sound=CapabilityExecute(PlaySound), + settings=CapabilitySettings( + child_lock=CapabilitySetEnable( + ChildLockEvent, + [GetChildLock()], + SetChildLock, + ), + volume=CapabilitySet( + VolumeEvent, + [GetVolume()], + SetVolume, + ), + ), + state=CapabilityEvent( + StateEvent, + [GetCleanInfo()], + ), + stats=CapabilityStats( + clean=CapabilityEvent(StatsEvent, [GetStats()]), + report=CapabilityEvent(ReportStatsEvent, [GetReportStats()]), + total=CapabilityEvent(TotalStatsEvent, [GetTotalStats()]), + ), + water=None, + ), + ) \ No newline at end of file diff --git a/deebot_client/map.py b/deebot_client/map.py index 7b099b2fe..3b6a3a877 100644 --- a/deebot_client/map.py +++ b/deebot_client/map.py @@ -24,9 +24,7 @@ from .logging_filter import get_logger from .models import Room from .rs.map import MapData as MapDataRs, RotationAngle -from .util import ( - OnChangedDict, -) +from .util import OnChangedDict if TYPE_CHECKING: from collections.abc import Callable @@ -59,18 +57,19 @@ async def on_map_set(event: MapSetEvent) -> None: if event.type == MapSetType.ROOMS: return - for subset_id, subset in self._map_data.map_subsets.copy().items(): - if subset.type == event.type and subset_id not in event.subsets: - self._map_data.map_subsets.pop(subset_id, None) + for subset_key, subset in self._map_data.map_subsets.copy().items(): + if subset.type == event.type and subset.id not in event.subsets: + self._map_data.map_subsets.pop(subset_key, None) self._unsubscribers.append(event_bus.subscribe(MapSetEvent, on_map_set)) async def on_map_subset(event: MapSubsetEvent) -> None: - if ( - event.type != MapSetType.ROOMS - and self._map_data.map_subsets.get(event.id, None) != event - ): - self._map_data.map_subsets[event.id] = event + if event.type == MapSetType.ROOMS: + return + + subset_key = (str(event.type), event.id) + if self._map_data.map_subsets.get(subset_key, None) != event: + self._map_data.map_subsets[subset_key] = event self._unsubscribers.append(event_bus.subscribe(MapSubsetEvent, on_map_subset)) @@ -85,21 +84,19 @@ async def on_map_info(event: MapInfoEvent) -> None: self._unsubscribers.append(event_bus.subscribe(MapInfoEvent, on_map_info)) - # ---------------------------- METHODS ---------------------------- - async def _subscribe_minor_major_map_events(self) -> list[Callable[[], None]]: async def on_major_map(event: MajorMapEvent) -> None: - async with asyncio.TaskGroup() as tg: - for idx, value in enumerate(event.values): - if ( - self._map_data.map_piece_crc32_indicates_update(idx, value) - and event.requested - ): - tg.create_task( - self._execute_command( - self._capabilities.minor.execute(idx, event.map_id) + if event.requested: + async with asyncio.TaskGroup() as tg: + for idx, value in enumerate(event.values): + if self._map_data.map_piece_crc32_indicates_update(idx, value): + tg.create_task( + self._execute_command( + self._capabilities.minor.execute(idx, event.map_id) + ) ) - ) + + self._sync_ngiot_background_from_store() async def on_minor_map(event: MinorMapEvent) -> None: self._map_data.update_map_piece(event.index, event.value) @@ -117,17 +114,18 @@ async def on_cached_info(event: CachedMapInfoEvent) -> None: used_map = next((m for m in event.maps if m.using), None) if used_map: self._map_data.set_rotation_angle(used_map.angle) + self._sync_ngiot_background_from_store() cached_map_subscribers = self._event_bus.has_subscribers(CachedMapInfoEvent) unsubscribers.append( self._event_bus.subscribe(CachedMapInfoEvent, on_cached_info) ) if cached_map_subscribers: - # Request update only if there was already a subscriber before self._event_bus.request_refresh(CachedMapInfoEvent) async def on_position(event: PositionsEvent) -> None: self._map_data.update_positions(event.positions) + self._sync_ngiot_background_from_store() unsubscribers.append(self._event_bus.subscribe(PositionsEvent, on_position)) @@ -135,14 +133,37 @@ async def on_map_trace(event: MapTraceEvent) -> None: if event.start == 0: self._map_data.clear_trace_points() - if data := event.data.strip(): - self._map_data.add_trace_points(data) + if not (data := event.data.strip()): + return + + try: + if self._map_data.has_ngiot_background(): + self._map_data.use_world_trace_scale() + else: + self._map_data.use_legacy_trace_scale() + + self._map_data.add_trace_points(data, event.lz4_len) + except ValueError as err: + _LOGGER.warning( + "Skipping invalid trace payload for geometry map " + "(start=%s total=%s lz4_len=%s): %s", + event.start, + event.total, + event.lz4_len, + err, + ) + except Exception: + _LOGGER.exception( + "Unexpected error while processing trace payload; continuing without trace" + ) unsubscribers.append(self._event_bus.subscribe(MapTraceEvent, on_map_trace)) + self._sync_ngiot_background_from_store() + def unsub() -> None: - for unsub in unsubscribers: - unsub() + for unsubscribe in unsubscribers: + unsubscribe() return unsub @@ -151,7 +172,6 @@ def refresh(self) -> None: if not self._unsubscribers: raise MapError("Please enable the map first") - # TODO make it nice self._event_bus.request_refresh(CachedMapInfoEvent) self._event_bus.request_refresh(PositionsEvent) self._event_bus.request_refresh(MapTraceEvent) @@ -168,10 +188,9 @@ def get_svg_map(self) -> str | None: _LOGGER.debug("[get_svg_map] Begin") - # Reset change before starting to build the SVG self._map_data.reset_changed() - self._last_image = self._map_data.generate_svg() + _LOGGER.debug("[get_svg_map] Finish") return self._last_image @@ -182,6 +201,58 @@ async def teardown(self) -> None: self._unsubscribers.clear() self._map_data.teardown() + def _sync_ngiot_background_from_store(self) -> None: + """Push the active NGIOT raster background into the renderer path.""" + store = getattr(self._event_bus, "_ngiot_map_state_store", None) + if store is None: + self._map_data.clear_ngiot_background() + self._map_data.use_legacy_trace_scale() + self._map_data.use_legacy_position_icon_scale() + self._map_data.use_legacy_position_transform() + return + + snapshot = None + get_active_renderable = getattr(store, "get_active_renderable", None) + if callable(get_active_renderable): + snapshot = get_active_renderable() + + if snapshot is None: + get_active = getattr(store, "get_active", None) + if callable(get_active): + snapshot = get_active() + + base_map = getattr(snapshot, "base_map", None) if snapshot is not None else None + encoded = "" + if base_map is not None: + encoded = getattr(base_map, "encoded", "") or getattr(base_map, "data", "") + + if ( + base_map is None + or not encoded + or int(getattr(base_map, "width", 0)) <= 0 + or int(getattr(base_map, "height", 0)) <= 0 + ): + self._map_data.clear_ngiot_background() + self._map_data.use_legacy_trace_scale() + self._map_data.use_legacy_position_icon_scale() + self._map_data.use_legacy_position_transform() + return + + self._map_data.set_ngiot_background( + encoded=encoded, + width=int(getattr(base_map, "width", 0)), + height=int(getattr(base_map, "height", 0)), + total_width=int(getattr(base_map, "total_width", 0)), + total_height=int(getattr(base_map, "total_height", 0)), + resolution=int(getattr(base_map, "resolution", 0)), + x_min=int(getattr(base_map, "x_min", 0)), + y_max=int(getattr(base_map, "y_max", 0)), + direction=int(getattr(base_map, "direction", 0)), + ) + self._map_data.use_world_trace_scale() + self._map_data.use_ngiot_position_icon_scale() + self._map_data.use_legacy_position_transform() + class MapData: """Map data.""" @@ -194,83 +265,177 @@ def on_change() -> None: event_bus.notify(MapChangedEvent(datetime.now(UTC)), debounce_time=1) self._on_change = on_change - self._map_subsets: OnChangedDict[int, MapSubsetEvent] = OnChangedDict(on_change) + self._map_subsets: OnChangedDict[tuple[str, int], MapSubsetEvent] = ( + OnChangedDict(on_change) + ) self._positions: list[Position] = [] self._rotation: RotationAngle = RotationAngle.DEG_0 self._data = MapDataRs() self._room_handling = MapRoomHandling(event_bus, on_change) + self.use_legacy_trace_scale() + self.use_legacy_position_icon_scale() + self.use_legacy_position_transform() + @property def changed(self) -> bool: """Indicate if data was changed.""" return self._changed @property - def map_subsets(self) -> dict[int, MapSubsetEvent]: - """Return map subsets.""" + def map_subsets(self) -> OnChangedDict[tuple[str, int], MapSubsetEvent]: + """Map subsets.""" return self._map_subsets def reset_changed(self) -> None: - """Reset changed value.""" + """Reset changed state.""" self._changed = False - def add_trace_points(self, value: str) -> None: - """Add trace points to the map data.""" - self._data.trace_points.add(value) - self._on_change() - - def clear_trace_points(self) -> None: - """Clear trace points.""" - self._data.trace_points.clear() - self._on_change() + def teardown(self) -> None: + """Teardown map data.""" + self._room_handling.teardown() - def update_positions(self, value: list[Position]) -> None: + def update_positions(self, positions: list[Position]) -> None: """Update positions.""" - self._positions = value - self._on_change() + new_positions = list(positions) + if self._positions != new_positions: + self._positions = new_positions + self._on_change() def update_map_piece(self, index: int, base64_data: str) -> None: - """Update map piece.""" + """Update legacy map piece.""" if self._data.background_image.update_map_piece(index, base64_data): self._on_change() def map_piece_crc32_indicates_update(self, index: int, crc32: int) -> bool: - """Return True if update is required.""" + """Return True if legacy map piece update is required.""" return self._data.background_image.map_piece_crc32_indicates_update( index, crc32 ) - def generate_svg(self) -> str | None: - """Generate SVG image.""" - return self._data.generate_svg( - list(self._map_subsets.values()), - self._positions, - self._rotation, + def set_rotation_angle(self, angle: int | RotationAngle) -> None: + """Set rotation angle.""" + if isinstance(angle, RotationAngle): + new_rotation = angle + else: + angle_mapping = { + 0: RotationAngle.DEG_0, + 90: RotationAngle.DEG_90, + 180: RotationAngle.DEG_180, + 270: RotationAngle.DEG_270, + } + new_rotation = angle_mapping.get(int(angle) % 360, RotationAngle.DEG_0) + + if self._rotation != new_rotation: + self._rotation = new_rotation + self._on_change() + + def set_map_info(self, map_info: list[str] | str) -> None: + """Set map info.""" + self._data.set_map_info(map_info) + self._on_change() + + def set_background_image(self, image: str) -> None: + """Set background image.""" + self._data.set_background_image(image) + self._on_change() + + def clear_background_image(self) -> None: + """Clear background image.""" + self._data.clear_background_image() + self._on_change() + + def set_ngiot_background( + self, + *, + encoded: str, + width: int, + height: int, + total_width: int, + total_height: int, + resolution: int, + x_min: int, + y_max: int, + direction: int, + ) -> None: + """Set NGIOT raster background.""" + self._data.set_ngiot_background( + encoded=encoded, + width=width, + height=height, + total_width=total_width, + total_height=total_height, + resolution=resolution, + x_min=x_min, + y_max=y_max, + direction=direction, ) + self._on_change() - def set_map_info(self, base64_info: str) -> None: - """Set compressed map info (parsing happens in Rust).""" - self._data.map_info.set(base64_info) + def clear_ngiot_background(self) -> None: + """Clear NGIOT raster background.""" + self._data.clear_ngiot_background() self._on_change() - def set_rotation_angle(self, rotation: RotationAngle) -> None: - """Set clockwise rotation angle for SVG image.""" - self._rotation = rotation + def has_ngiot_background(self) -> bool: + """Return True when an NGIOT raster background is active.""" + return self._data.has_ngiot_background() + + def use_legacy_trace_scale(self) -> None: + """Use legacy trace scaling.""" + self._data.use_legacy_trace_scale() + + def use_world_trace_scale(self) -> None: + """Use world-space trace scaling.""" + self._data.use_world_trace_scale() + + def use_legacy_position_icon_scale(self) -> None: + """Use legacy position icon scale.""" + self._data.use_legacy_position_icon_scale() + + def use_ngiot_position_icon_scale(self) -> None: + """Use NGIOT position icon scale.""" + self._data.use_ngiot_position_icon_scale() + + def use_legacy_position_transform(self) -> None: + """Use legacy position transform.""" + self._data.use_legacy_position_transform() + + def use_ngiot_position_transform(self) -> None: + """Use NGIOT position transform.""" + self._data.use_ngiot_position_transform() + + def clear_trace_points(self) -> None: + """Clear trace points.""" + self._data.clear_trace_points() self._on_change() - def teardown(self) -> None: - """Teardown map data.""" - self._room_handling.teardown() + def add_trace_points(self, data: str, lz4_len: int | None = None) -> None: + """Add trace points.""" + self._data.add_trace_points(data, lz4_len) + self._on_change() + + def generate_svg(self) -> str | None: + """Generate SVG.""" + map_subsets = list(self.map_subsets.values()) + self._room_handling.update_rooms(map_subsets) + return self._data.generate_svg( + map_subsets, + self._positions, + self._rotation, + ) class MapRoomHandling: - """Room handling.""" + """Handle room data.""" def __init__(self, event_bus: EventBus, on_change: Callable[[], None]) -> None: + self._event_bus = event_bus + self._on_change = on_change + self._rooms: dict[int, Room] = {} self._amount_rooms: int = 0 - self._rooms: OnChangedDict[int, Room] = OnChangedDict(on_change) - self._unsubscribers: list[Callable[[], None]] = [] self._map_id: str = "" + self._unsubscribers: list[Callable[[], None]] = [] async def on_map_set(event: MapSetEvent) -> None: if event.type != MapSetType.ROOMS: @@ -278,9 +443,13 @@ async def on_map_set(event: MapSetEvent) -> None: self._map_id = event.map_id self._amount_rooms = len(event.subsets) - for room_id in self._rooms.copy(): + changed = False + for room_id in list(self._rooms): if room_id not in event.subsets: self._rooms.pop(room_id, None) + changed = True + if changed: + self._on_change() self._unsubscribers.append(event_bus.subscribe(MapSetEvent, on_map_set)) @@ -289,18 +458,48 @@ async def on_map_subset(event: MapSubsetEvent) -> None: return room = Room(event.name, event.id, event.coordinates) - if self._rooms.get(event.id, None) != room: - self._rooms[room.id] = room + if self._rooms.get(event.id) != room: + self._rooms[event.id] = room + self._on_change() - if len(self._rooms) == self._amount_rooms: - event_bus.notify( - RoomsEvent(self._map_id, list(self._rooms.values())) - ) + if self._amount_rooms and len(self._rooms) == self._amount_rooms: + event_bus.notify(RoomsEvent(self._map_id, list(self._rooms.values()))) self._unsubscribers.append(event_bus.subscribe(MapSubsetEvent, on_map_subset)) + async def on_rooms(event: RoomsEvent) -> None: + rooms = {room.id: room for room in event.rooms} + if self._rooms != rooms: + self._rooms = rooms + self._map_id = event.map_id + self._amount_rooms = len(rooms) + self._on_change() + + self._unsubscribers.append(event_bus.subscribe(RoomsEvent, on_rooms)) + def teardown(self) -> None: """Teardown room handling.""" for unsubscribe in self._unsubscribers: unsubscribe() self._unsubscribers.clear() + + def update_rooms(self, map_subsets: list[MapSubsetEvent]) -> None: + """Update room subset names from cached room metadata.""" + for index, subset in enumerate(map_subsets): + if subset.type != MapSetType.ROOMS: + continue + + room = self._rooms.get(subset.id) + if room is None: + continue + + new_coordinates = subset.coordinates or room.coordinates + new_name = room.name or subset.name + replacement = MapSubsetEvent( + id=subset.id, + type=subset.type, + coordinates=new_coordinates, + name=new_name, + ) + if replacement != subset: + map_subsets[index] = replacement \ No newline at end of file diff --git a/deebot_client/message.py b/deebot_client/message.py index 57729b2fd..c97b40464 100644 --- a/deebot_client/message.py +++ b/deebot_client/message.py @@ -68,7 +68,6 @@ def wrapper(cls: type[M], event_bus: EventBus, data: T) -> HandlingResult: _LOGGER.warning("Could not parse %s: %s", cls.NAME, data, exc_info=True) return HandlingResult(HandlingState.ERROR) else: - # This happens if for some reason someone calls super() of an ABC where handle is not implemented if not response: _LOGGER.error( "Handler for message %s: %s returned no response. " @@ -101,19 +100,13 @@ def __init_subclass__(cls) -> None: def _handle( cls, event_bus: EventBus, message: MessagePayloadType ) -> HandlingResult: - """Handle message and notify the correct event subscribers. - - :return: A message response - """ + """Handle message and notify the correct event subscribers.""" @classmethod @_handle_error_or_analyse @final def handle(cls, event_bus: EventBus, message: MessagePayloadType) -> HandlingResult: - """Handle message and notify the correct event subscribers. - - :return: A message response - """ + """Handle message and notify the correct event subscribers.""" return cls._handle(event_bus, message) @@ -123,25 +116,18 @@ class MessageStr(Message, ABC): @classmethod @abstractmethod def _handle_str(cls, event_bus: EventBus, message: str) -> HandlingResult: - """Handle string message and notify the correct event subscribers. - - :return: A message response - """ + """Handle string message and notify the correct event subscribers.""" @classmethod @_handle_error_or_analyse @final - def __handle_str(cls, event_bus: EventBus, message: str) -> HandlingResult: + def _dispatch_str(cls, event_bus: EventBus, message: str) -> HandlingResult: return cls._handle_str(event_bus, message) @classmethod def _handle( cls, event_bus: EventBus, message: MessagePayloadType ) -> HandlingResult: - """Handle message and notify the correct event subscribers. - - :return: A message response - """ if isinstance(message, bytearray): data = bytes(message).decode() elif isinstance(message, bytes): @@ -149,9 +135,9 @@ def _handle( elif isinstance(message, str): data = message else: - return super()._handle(event_bus, message) + return HandlingResult.analyse() - return cls.__handle_str(event_bus, data) + return cls._dispatch_str(event_bus, data) class MessageDictOrJson(Message, ABC): @@ -162,15 +148,12 @@ class MessageDictOrJson(Message, ABC): def _handle_dict( cls, event_bus: EventBus, message: dict[str, Any] ) -> HandlingResult: - """Handle string message and notify the correct event subscribers. - - :return: A message response - """ + """Handle dict message and notify the correct event subscribers.""" @classmethod @_handle_error_or_analyse @final - def __handle_dict( + def _dispatch_dict( cls, event_bus: EventBus, message: dict[str, Any] ) -> HandlingResult: return cls._handle_dict(event_bus, message) @@ -179,10 +162,6 @@ def __handle_dict( def _handle( cls, event_bus: EventBus, message: MessagePayloadType ) -> HandlingResult: - """Handle message and notify the correct event subscribers. - - :return: A message response - """ data = message if not isinstance(message, dict): try: @@ -195,13 +174,13 @@ def _handle( ) if isinstance(data, dict): - fw_version = data.get("header", {}).get("fwVer", None) + fw_version = data.get("header", {}).get("fwVer") if fw_version: event_bus.notify(FirmwareEvent(fw_version)) - return cls.__handle_dict(event_bus, data) + return cls._dispatch_dict(event_bus, data) - return super()._handle(event_bus, message) + return HandlingResult.analyse() class MessageBody(MessageDictOrJson, ABC): @@ -210,70 +189,53 @@ class MessageBody(MessageDictOrJson, ABC): @classmethod @abstractmethod def _handle_body(cls, event_bus: EventBus, body: dict[str, Any]) -> HandlingResult: - """Handle message->body and notify the correct event subscribers. - - :return: A message response - """ + """Handle message->body and notify the correct event subscribers.""" @classmethod @_handle_error_or_analyse @final - def __handle_body(cls, event_bus: EventBus, body: dict[str, Any]) -> HandlingResult: + def _dispatch_body(cls, event_bus: EventBus, body: dict[str, Any]) -> HandlingResult: return cls._handle_body(event_bus, body) @classmethod def _handle_dict( cls, event_bus: EventBus, message: dict[str, Any] ) -> HandlingResult: - """Handle message and notify the correct event subscribers. - - :return: A message response - """ - if "body" in message: - return cls.__handle_body(event_bus, message["body"]) + body = message.get("body") + if isinstance(body, dict): + return cls._dispatch_body(event_bus, body) - return super()._handle_dict(event_bus, message) + return HandlingResult.analyse() class MessageBodyData(MessageBody, ABC): """Dict message with body->data attribute.""" @classmethod - @abstractmethod def _handle_body_data( cls, event_bus: EventBus, data: dict[str, Any] | list[Any] ) -> HandlingResult: - """Handle message->body->data and notify the correct event subscribers. - - :return: A message response - """ + """Fallback body->data handler.""" + return HandlingResult.analyse() @classmethod + @_handle_error_or_analyse @final - def __handle_body_data( + def _dispatch_body_data( cls, event_bus: EventBus, data: dict[str, Any] | list[Any] ) -> HandlingResult: - try: - response = cls._handle_body_data(event_bus, data) - except Exception: - _LOGGER.warning("Could not parse %s: %s", cls.NAME, data, exc_info=True) - return HandlingResult(HandlingState.ERROR) - else: - if response.state == HandlingState.ANALYSE: - _LOGGER.debug("Could not handle %s message: %s", cls.NAME, data) - return HandlingResult(HandlingState.ANALYSE_LOGGED, response.args) - return response + return cls._handle_body_data(event_bus, data) @classmethod def _handle_body(cls, event_bus: EventBus, body: dict[str, Any]) -> HandlingResult: - """Handle message->body and notify the correct event subscribers. + data = body.get("data") + if data is None: + return HandlingResult.analyse() - :return: A message response - """ - if "data" in body: - return cls.__handle_body_data(event_bus, body["data"]) + if isinstance(data, (dict, list)): + return cls._dispatch_body_data(event_bus, data) - return super()._handle_body(event_bus, body) + return HandlingResult.analyse() class MessageBodyDataDict(MessageBodyData, ABC): @@ -284,23 +246,16 @@ class MessageBodyDataDict(MessageBodyData, ABC): def _handle_body_data_dict( cls, event_bus: EventBus, data: dict[str, Any] ) -> HandlingResult: - """Handle message->body->data and notify the correct event subscribers. - - :return: A message response - """ + """Handle dict body->data and notify the correct event subscribers.""" @classmethod def _handle_body_data( cls, event_bus: EventBus, data: dict[str, Any] | list[Any] ) -> HandlingResult: - """Handle message->body->data and notify the correct event subscribers. - - :return: A message response - """ if isinstance(data, dict): return cls._handle_body_data_dict(event_bus, data) - return super()._handle_body_data(event_bus, data) + return HandlingResult.analyse() class MessageBodyDataList(MessageBodyData, ABC): @@ -311,20 +266,13 @@ class MessageBodyDataList(MessageBodyData, ABC): def _handle_body_data_list( cls, event_bus: EventBus, data: list[Any] ) -> HandlingResult: - """Handle message->body->data and notify the correct event subscribers. - - :return: A message response - """ + """Handle list body->data and notify the correct event subscribers.""" @classmethod def _handle_body_data( cls, event_bus: EventBus, data: dict[str, Any] | list[Any] ) -> HandlingResult: - """Handle message->body->data and notify the correct event subscribers. - - :return: A message response - """ if isinstance(data, list): return cls._handle_body_data_list(event_bus, data) - return super()._handle_body_data(event_bus, data) + return HandlingResult.analyse() \ No newline at end of file diff --git a/deebot_client/messages/json/__init__.py b/deebot_client/messages/json/__init__.py index 11d8a9d65..e1a41a7d3 100644 --- a/deebot_client/messages/json/__init__.py +++ b/deebot_client/messages/json/__init__.py @@ -11,6 +11,7 @@ from .battery import OnBattery from .gps_position import OnGpsPos from .map import OnCachedMapInfo, OnMajorMap, OnMapInfoV2, OnMapSetV2 +from .ngiot import OnNgiotMapEvent, OnNgiotStatusEvent from .station_state import OnStationState from .stats import OnStats, ReportStats from .work_state import OnWorkState @@ -27,6 +28,8 @@ "OnStats", "OnWorkState", "ReportStats", + "OnNgiotMapEvent", + "OnNgiotStatusEvent", ] # fmt: off @@ -38,6 +41,9 @@ OnGpsPos, + OnNgiotMapEvent, + OnNgiotStatusEvent, + OnCachedMapInfo, OnMajorMap, OnMapInfoV2, @@ -108,4 +114,4 @@ def get_legacy_message(message_name: str, converted_name: str) -> type[Message] _LOGGER.debug('Command "%s" doesn\'t support message handling', converted_name) - return None + return None \ No newline at end of file diff --git a/deebot_client/messages/json/ngiot.py b/deebot_client/messages/json/ngiot.py new file mode 100644 index 000000000..601bfee9d --- /dev/null +++ b/deebot_client/messages/json/ngiot.py @@ -0,0 +1,102 @@ +"""NGIOT numeric-topic MQTT messages.""" + +from __future__ import annotations + +from typing import TYPE_CHECKING, Any + +from deebot_client.commands.ngiot.battery import GetBattery +from deebot_client.commands.ngiot.child_lock import GetChildLock +from deebot_client.commands.ngiot.clean import GetCleanInfo, map_live_state +from deebot_client.commands.ngiot.map import GetMajorMap, GetMapTrace, GetPos +from deebot_client.commands.ngiot.stats import GetStats +from deebot_client.commands.ngiot.volume import GetVolume +from deebot_client.events import StateEvent +from deebot_client.logging_filter import get_logger +from deebot_client.message import HandlingResult, HandlingState, MessageBodyDataDict + +if TYPE_CHECKING: + from deebot_client.event_bus import EventBus + +_LOGGER = get_logger(__name__) + + +class OnNgiotMapEvent(MessageBodyDataDict): + """Handle NGIOT live map/status multiplexed on numeric topic 30000.""" + + NAME = "30000" + + @classmethod + def _handle_body_data_dict( + cls, event_bus: EventBus, data: dict[str, Any] + ) -> HandlingResult: + handled = False + + if any(key in data for key in ("mapData", "areas", "chargePos")): + result = GetMajorMap._handle_body_data_dict(event_bus, data) + handled = handled or result.state == HandlingState.SUCCESS + + # Live pose-only payloads may arrive without mapData. + if "pos" in data and "mapData" not in data: + result = GetPos._handle_body_data_dict(event_bus, data) + handled = handled or result.state == HandlingState.SUCCESS + + if "mapTraceData" in data: + result = GetMapTrace._handle_body_data_dict(event_bus, data) + handled = handled or result.state == HandlingState.SUCCESS + + # mapMinorData is confirmed to exist on the live channel, but the current + # renderer path does not yet support generic NGIOT delta raster application. + # Log it so captures remain explainable without pretending to render deltas. + if "mapMinorData" in data: + _LOGGER.debug( + "Observed NGIOT live minor-map payload on topic 30000; " + "generic delta-raster application is not implemented yet" + ) + handled = True + + return HandlingResult.success() if handled else HandlingResult.analyse() + + +class OnNgiotStatusEvent(MessageBodyDataDict): + """Handle NGIOT live status multiplexed on numeric topic 10000.""" + + NAME = "10000" + + @classmethod + def _handle_body_data_dict( + cls, event_bus: EventBus, data: dict[str, Any] + ) -> HandlingResult: + handled = False + + if "battery" in data: + result = GetBattery._handle_body_data_dict(event_bus, data) + handled = handled or result.state == HandlingState.SUCCESS + + if any(key in data for key in ("cleanArea", "cleanTime", "workMode")): + result = GetStats._handle_body_data_dict(event_bus, data) + handled = handled or result.state == HandlingState.SUCCESS + + if "childLock" in data: + result = GetChildLock._handle_body_data_dict(event_bus, data) + handled = handled or result.state == HandlingState.SUCCESS + + if "volume" in data: + result = GetVolume._handle_body_data_dict(event_bus, data) + handled = handled or result.state == HandlingState.SUCCESS + + live_state = map_live_state( + data, + previous=( + event_bus.get_last_event(StateEvent).state + if event_bus.get_last_event(StateEvent) is not None + else None + ), + ) + if live_state is not None: + event_bus.notify(StateEvent(live_state)) + handled = True + elif any(key in data for key in ("status", "pauseSwitch", "chargeStatus")): + result = GetCleanInfo._handle_body_data_dict(event_bus, data) + handled = handled or result.state == HandlingState.SUCCESS + + return HandlingResult.success() if handled else HandlingResult.analyse() \ No newline at end of file diff --git a/deebot_client/mqtt_client.py b/deebot_client/mqtt_client.py index 90b27f605..06687c75f 100644 --- a/deebot_client/mqtt_client.py +++ b/deebot_client/mqtt_client.py @@ -34,20 +34,32 @@ _CLIENT_LOGGER = get_logger(f"{__name__}.client") -def _get_topics(device_info: DeviceInfo) -> list[str]: +def _get_topics(device_info: DeviceInfo, user_id: str | None = None) -> list[str]: api = device_info.api device_path = f"{api['did']}/{api['class']}/{api['resource']}" data_type = device_info.static.data_type - return [ - # iot/atr/[command]/[did]]/[class]]/[resource]/[data_type] + + topics = [ + # Legacy/message-name ATR routing. + # iot/atr/[command]/[did]/[class]/[resource]/[data_type] f"iot/atr/+/{device_path}/{data_type}", - # iot/p2p/[command]/[sender did]/[sender class]]/[sender resource] + # iot/p2p/[command]/[sender did]/[sender class]/[sender resource] # /[receiver did]/[receiver class]/[receiver resource]/[q|p]/[request id]/[data_type] # [q|p] q-> request p-> response f"iot/p2p/+/+/+/+/{device_path}/q/+/{data_type}", f"iot/p2p/+/{device_path}/+/+/+/p/+/{data_type}", ] + if user_id: + # NGIOT live-event routing observed on some newer devices: + # iot/atr/[channel]/[user-id]/[device-class]/[device-id]/[data_type] + # The final device identifier can vary between observed payloads, so + # subscribe to both known device identifiers. + for device_id in dict.fromkeys((api['did'], api['resource'])): + topics.append(f"iot/atr/+/{user_id}/{api['class']}/{device_id}/{data_type}") + + return topics + @dataclass(frozen=True, kw_only=True) class MqttConfiguration: @@ -200,8 +212,11 @@ async def mqtt() -> None: try: async with await self._get_client() as client: _LOGGER.debug("Subscribe to all previous subscriptions") + credentials = await self._authenticator.authenticate() for info in self._subscriptions.values(): - for topic in _get_topics(info.device_info): + for topic in _get_topics( + info.device_info, credentials.user_id + ): await client.subscribe(topic) async def listen() -> None: @@ -264,7 +279,8 @@ async def _pending_subscriptions_worker(self, client: Client) -> None: (info, add) = await self._subscription_changes.get() device_info = info.device_info - for topic in _get_topics(device_info): + credentials = await self._authenticator.authenticate() + for topic in _get_topics(device_info, credentials.user_id): if add: await client.subscribe(topic) else: @@ -277,9 +293,40 @@ async def _pending_subscriptions_worker(self, client: Client) -> None: self._subscription_changes.task_done() + @staticmethod + def _topic_matches_device(topic_split: list[str], device_info: DeviceInfo) -> bool: + if len(topic_split) < 7: + return False + + api = device_info.api + data_type = str(device_info.static.data_type) + if topic_split[6] != data_type: + return False + + legacy_match = ( + topic_split[3] == api["did"] + and topic_split[4] == api["class"] + and topic_split[5] == api["resource"] + ) + if legacy_match: + return True + + # NGIOT live events can use the observed shape: + # iot/atr/[channel]/[user-id]/[device-class]/[device-id]/[data_type] + return topic_split[4] == api["class"] and topic_split[5] in { + api["did"], + api["resource"], + } + + def _resolve_atr_subscription(self, topic_split: list[str]) -> SubscriberInfo | None: + for info in self._subscriptions.values(): + if self._topic_matches_device(topic_split, info.device_info): + return info + return None + def _handle_atr(self, topic_split: list[str], payload: bytes) -> None: try: - if sub_info := self._subscriptions.get(topic_split[3]): + if sub_info := self._resolve_atr_subscription(topic_split): sub_info.callback(topic_split[2], payload) except Exception: _LOGGER.exception("An exception occurred during handling atr message") @@ -321,4 +368,4 @@ def _handle_p2p(self, topic_split: list[str], payload: bytes) -> None: "An exception occurred during handling p2p message: topic=%s; payload=%s", "/".join(topic_split), payload, - ) + ) \ No newline at end of file diff --git a/deebot_client/ngiot_client.py b/deebot_client/ngiot_client.py new file mode 100644 index 000000000..f0e4efc4e --- /dev/null +++ b/deebot_client/ngiot_client.py @@ -0,0 +1,590 @@ +"""NGIOT endpoint-control client for eco-ng devices such as eyfj07.""" + +from __future__ import annotations + +import asyncio +import json +import secrets +import string +import time +from collections.abc import Mapping, Sequence +from dataclasses import dataclass +from http import HTTPStatus +from typing import TYPE_CHECKING, Any +from urllib.parse import urljoin + +from aiohttp import ClientResponseError, ClientSession, ClientTimeout, hdrs + +from .exceptions import ApiError, ApiTimeoutError, AuthenticationError +from .logging_filter import get_logger +from .sst_authentication import SstAuthenticator + +if TYPE_CHECKING: + from .models import ApiDeviceInfo, DeviceInfo + +_LOGGER = get_logger(__name__) + +_TIMEOUT = ClientTimeout(60) + +_PATH_ENDPOINT_CONTROL = "/api/iot/endpoint/control" +_DEFAULT_FMT = "j" +_DEFAULT_CT = "q" + +# Public APN constants are imported by multiple NGIOT command modules. +APN_ROBOT_DETAIL = "10001" +APN_MAP_DETAILS = "30001" +APN_CLEAN_START = "40001" +APN_AREA_CLEAN = "40007" +APN_PAUSE = "40009" +APN_RESUME = "40011" +APN_RETURN_TO_DOCK = "40013" +APN_CANCEL_RETURN = "40015" +APN_DEVICE_LOCATE = "40019" +APN_RESET_CONSUMABLE = "50017" +APN_CHILD_LOCK = "50038" + +# Some devices transiently return "cmd busy" while state is changing. +_TRANSIENT_RESPONSE_CODES = {1} +_TRANSIENT_RESPONSE_MESSAGES = {"cmd busy"} + + +@dataclass(frozen=True) +class NgiotDeviceIdentity: + """Normalized NGIOT device identity.""" + + did: str + class_id: str + resource: str + control_host: str + fallback_control_host: str | None = None + + @property + def key(self) -> str: + """Stable cache/logging key.""" + return f"{self.class_id}:{self.did}:{self.resource}" + + @property + def base_url(self) -> str: + """Normalized HTTPS base URL for NGIOT control.""" + if self.control_host.startswith(("http://", "https://")): + return self.control_host.rstrip("/") + return f"https://{self.control_host}".rstrip("/") + + +class NgiotClient: + """Thin client for NGIOT endpoint-control reads and writes.""" + + def __init__( + self, + session: ClientSession, + sst_authenticator: SstAuthenticator, + *, + user_agent: str = "okhttp/4.9.1", + channel: str = "Android", + protocol_version: str = "0.0.22", + timezone_name: str = "UTC", + timezone_offset_minutes: int = 0, + override_control_host: str | None = None, + ) -> None: + self._session = session + self._sst_authenticator = sst_authenticator + self._user_agent = user_agent + self._channel = channel + self._protocol_version = protocol_version + self._timezone_name = timezone_name + self._timezone_offset_minutes = timezone_offset_minutes + self._override_control_host = override_control_host + + async def request( + self, + device: ApiDeviceInfo | DeviceInfo | Mapping[str, Any], + *, + apn: str | int, + body_data: Mapping[str, Any] | Sequence[Any], + fmt: str = _DEFAULT_FMT, + ct: str = _DEFAULT_CT, + force_sst_refresh: bool = False, + ) -> dict[str, Any]: + """Execute a single NGIOT endpoint-control request.""" + identity = self._normalize_device(device) + return await self._request_with_fallback( + identity, + device, + apn=apn, + body_data=body_data, + fmt=fmt, + ct=ct, + force_sst_refresh=force_sst_refresh, + ) + + async def _request_with_fallback( + self, + identity: NgiotDeviceIdentity, + device: ApiDeviceInfo | DeviceInfo | Mapping[str, Any], + *, + apn: str | int, + body_data: Mapping[str, Any] | Sequence[Any], + fmt: str, + ct: str, + force_sst_refresh: bool, + ) -> dict[str, Any]: + """Execute an NGIOT request and fall back to service.mqs on 404.""" + try: + return await self._request_once( + identity, + device, + apn=apn, + body_data=body_data, + fmt=fmt, + ct=ct, + force_sst_refresh=force_sst_refresh, + ) + except ClientResponseError as ex: + if ( + ex.status == HTTPStatus.NOT_FOUND + and identity.fallback_control_host + and identity.fallback_control_host != identity.control_host + ): + fallback_identity = NgiotDeviceIdentity( + did=identity.did, + class_id=identity.class_id, + resource=identity.resource, + control_host=identity.fallback_control_host, + fallback_control_host=None, + ) + _LOGGER.info( + "NGIOT endpoint-control returned 404 on %s for %s; retrying with device mqs host %s", + identity.base_url, + identity.key, + fallback_identity.base_url, + ) + return await self._request_once( + fallback_identity, + device, + apn=apn, + body_data=body_data, + fmt=fmt, + ct=ct, + force_sst_refresh=force_sst_refresh, + ) + raise + + async def _request_once( + self, + identity: NgiotDeviceIdentity, + device: ApiDeviceInfo | DeviceInfo | Mapping[str, Any], + *, + apn: str | int, + body_data: Mapping[str, Any] | Sequence[Any], + fmt: str, + ct: str, + force_sst_refresh: bool, + ) -> dict[str, Any]: + """Execute a single NGIOT endpoint-control request against one control host.""" + url = urljoin(identity.base_url + "/", _PATH_ENDPOINT_CONTROL.lstrip("/")) + + request_id = self._new_request_id() + body_reqid = self._new_body_reqid() + + query_params = { + "si": request_id, + "ct": ct, + "eid": identity.did, + "et": identity.class_id, + "er": identity.resource, + "apn": str(apn), + "fmt": fmt, + } + + payload = { + "body": {"data": self._build_payload(apn, body_data)}, + "header": self._create_body_header(body_reqid), + } + + token = await self._sst_authenticator.get_token( + self._device_mapping(identity), + force=force_sst_refresh, + ) + + headers = { + hdrs.AUTHORIZATION: f"Bearer {token}", + "x-eco-request-id": request_id, + hdrs.CONTENT_TYPE: "application/octet-stream", + hdrs.USER_AGENT: self._user_agent, + } + + logger_request_params = { + "url": url, + "query_params": query_params, + "payload": payload, + "device_key": identity.key, + } + + try: + _LOGGER.debug("Calling NGIOT api: %s", logger_request_params) + + async with self._session.post( + url, + params=query_params, + data=json.dumps(payload, separators=(",", ":")).encode("utf-8"), + headers=headers, + timeout=_TIMEOUT, + ) as res: + res.raise_for_status() + content_type = res.headers.get(hdrs.CONTENT_TYPE, "").lower() + response_data: dict[str, Any] = await res.json( + content_type=content_type or None + ) + + _LOGGER.debug( + "Success calling NGIOT api %s, response=%s", + logger_request_params, + response_data, + ) + _LOGGER.debug( + "NGIOT protocol trace -> apn=%s payload=%s response=%s", + apn, + payload["body"]["data"], + response_data, + ) + + validation = self._classify_response(response_data) + if validation == "retry_busy": + _LOGGER.debug( + "NGIOT request returned transient busy for %s apn=%s; retrying once", + identity.key, + apn, + ) + await asyncio.sleep(1) + return await self._request_retry_after_busy( + identity, + device, + apn=apn, + body_data=body_data, + fmt=fmt, + ct=ct, + force_sst_refresh=force_sst_refresh, + ) + + return response_data + + except TimeoutError as ex: + raise ApiTimeoutError(path=_PATH_ENDPOINT_CONTROL, timeout=_TIMEOUT) from ex + except ClientResponseError as ex: + if ( + ex.status in (HTTPStatus.UNAUTHORIZED, HTTPStatus.FORBIDDEN) + and not force_sst_refresh + ): + _LOGGER.info( + "NGIOT request unauthorized for %s. Invalidating SST and retrying once.", + identity.key, + ) + await self._sst_authenticator.invalidate(self._device_mapping(identity)) + return await self._request_with_fallback( + identity, + device, + apn=apn, + body_data=body_data, + fmt=fmt, + ct=ct, + force_sst_refresh=True, + ) + + if ex.status in (HTTPStatus.UNAUTHORIZED, HTTPStatus.FORBIDDEN): + raise AuthenticationError( + "NGIOT endpoint-control request was not authorized" + ) from ex + + _LOGGER.debug("NGIOT request failed: %s", logger_request_params, exc_info=True) + raise + + async def _request_retry_after_busy( + self, + identity: NgiotDeviceIdentity, + device: ApiDeviceInfo | DeviceInfo | Mapping[str, Any], + *, + apn: str | int, + body_data: Mapping[str, Any] | Sequence[Any], + fmt: str, + ct: str, + force_sst_refresh: bool, + ) -> dict[str, Any]: + """Retry once after a transient busy response.""" + retry_response = await self._request_with_fallback( + identity, + device, + apn=apn, + body_data=body_data, + fmt=fmt, + ct=ct, + force_sst_refresh=force_sst_refresh, + ) + self._validate_response(retry_response) + return retry_response + + async def query_fields( + self, + device: ApiDeviceInfo | DeviceInfo | Mapping[str, Any], + *, + apn: str | int, + fields: Sequence[str], + map_id: str | int | None = None, + ) -> Any: + """Query a field-based NGIOT surface and return body.data.""" + body_data: dict[str, Any] = {"fields": list(fields)} + if map_id is not None: + body_data["mapId"] = str(map_id) + + response = await self.request(device, apn=apn, body_data=body_data) + return self._extract_body_data(response) + + async def write_data( + self, + device: ApiDeviceInfo | DeviceInfo | Mapping[str, Any], + *, + apn: str | int, + data: Mapping[str, Any], + ) -> Any: + """Send a direct key/value NGIOT control payload and return body.data.""" + response = await self.request(device, apn=apn, body_data=dict(data)) + return self._extract_body_data(response) + + async def get_robot_detail( + self, + device: ApiDeviceInfo | DeviceInfo | Mapping[str, Any], + fields: Sequence[str], + ) -> Any: + """Read robot detail fields from the status surface.""" + return await self.query_fields( + device, + apn=APN_ROBOT_DETAIL, + fields=fields, + ) + + async def get_map_details( + self, + device: ApiDeviceInfo | DeviceInfo | Mapping[str, Any], + fields: Sequence[str], + *, + map_id: str | int | None = None, + ) -> Any: + """Read map detail fields from the map surface.""" + return await self.query_fields( + device, + apn=APN_MAP_DETAILS, + fields=fields, + map_id=map_id, + ) + + async def set_pause( + self, + device: ApiDeviceInfo | DeviceInfo | Mapping[str, Any], + *, + pause: bool, + ) -> Any: + """Pause or resume the current cleaning job.""" + return await self.write_data( + device, + apn=APN_PAUSE if pause else APN_RESUME, + data={"pauseSwitch": pause}, + ) + + async def set_charge( + self, + device: ApiDeviceInfo | DeviceInfo | Mapping[str, Any], + *, + enabled: bool, + ) -> Any: + """Start or cancel dock/charge behavior.""" + return await self.write_data( + device, + apn=APN_RETURN_TO_DOCK if enabled else APN_CANCEL_RETURN, + data={"chargeSwitch": enabled}, + ) + + async def start_smart_clean( + self, + device: ApiDeviceInfo | DeviceInfo | Mapping[str, Any], + ) -> Any: + """Start default smart cleaning.""" + return await self.write_data( + device, + apn=APN_CLEAN_START, + data={"cleanSwitch": True, "cleanMode": "smart"}, + ) + + async def start_area_clean( + self, + device: ApiDeviceInfo | DeviceInfo | Mapping[str, Any], + room_ids: Sequence[int], + ) -> Any: + """Start area cleaning for one or more room IDs.""" + return await self.write_data( + device, + apn=APN_AREA_CLEAN, + data={ + "cleanSwitch": True, + "cleanMode": "area", + "cleanValues": list(room_ids), + }, + ) + + def _normalize_device( + self, + device: ApiDeviceInfo | DeviceInfo | Mapping[str, Any], + ) -> NgiotDeviceIdentity: + """Normalize raw API device payload into NGIOT routing fields.""" + raw_device = device.api if hasattr(device, "api") else device + + if not isinstance(raw_device, Mapping): + msg = f"Unsupported device type for NGIOT client: {type(device)!r}" + raise TypeError(msg) + + service = raw_device.get("service", {}) + service_mqs_host = None + if isinstance(service, Mapping): + candidate = service.get("mqs") + if isinstance(candidate, str) and candidate: + service_mqs_host = candidate + + host = self._override_control_host or service_mqs_host + + if not host: + msg = f"Missing NGIOT control host in device service binding: {raw_device}" + raise ApiError(msg) + + try: + return NgiotDeviceIdentity( + did=str(raw_device["did"]), + class_id=str(raw_device["class"]), + resource=str(raw_device["resource"]), + control_host=str(host), + fallback_control_host=( + str(service_mqs_host) + if service_mqs_host and str(service_mqs_host) != str(host) + else None + ), + ) + except KeyError as ex: + msg = f"Missing required NGIOT device field: {ex.args[0]}" + raise ApiError(msg) from ex + + def _build_payload( + self, + apn: str | int, + body_data: Mapping[str, Any] | Sequence[Any], + ) -> Mapping[str, Any] | Sequence[Any]: + """Build a device-tolerant NGIOT payload.""" + if not isinstance(body_data, Mapping): + return body_data + + payload: dict[str, Any] = dict(body_data) + now_ms = int(time.time() * 1000) + now_s = int(time.time()) + + payload.setdefault("reqId", str(now_ms)) + payload.setdefault("timestamp", now_s) + + apn_str = str(apn) + if apn_str == APN_ROBOT_DETAIL: + payload.setdefault("type", "get") + elif apn_str == APN_MAP_DETAILS: + payload.setdefault("mapId", str(payload.get("mapId", "0"))) + + return payload + + def _create_body_header(self, reqid: str) -> dict[str, Any]: + """Create request body header matching the observed mobile shape.""" + return { + "channel": self._channel, + "m": "request", + "pri": 2, + "reqid": reqid, + "ts": str(int(time.time() * 1000)), + "tzc": self._timezone_name, + "tzm": self._timezone_offset_minutes, + "ver": self._protocol_version, + } + + @staticmethod + def _extract_body_data(response: Mapping[str, Any]) -> Any: + """Return body.data, defaulting ACK-only responses to an empty dict.""" + body = response.get("body") + if not isinstance(body, Mapping): + return {} + return body.get("data", {}) + + @classmethod + def _classify_response(cls, response: Mapping[str, Any] | None) -> str: + """Classify NGIOT envelope and support ACK-only or transient-busy replies.""" + if response is None: + _LOGGER.debug("Empty NGIOT response body returned by server") + return "ok" + + body = response.get("body") + if not isinstance(body, Mapping): + _LOGGER.debug("NGIOT response omitted body; treating as ACK-only success") + return "ok" + + code = body.get("code", 0) + msg = str(body.get("msg", "")).strip().lower() + + if code in (0, "0000", None): + return "ok" + + if code in _TRANSIENT_RESPONSE_CODES and msg in _TRANSIENT_RESPONSE_MESSAGES: + return "retry_busy" + + raise ApiError( + f"NGIOT request failed with code {code} ({body.get('msg', 'unknown error')}) " + f"for {_PATH_ENDPOINT_CONTROL}" + ) + + @staticmethod + def _validate_response(response: Mapping[str, Any] | None) -> None: + """Validate NGIOT envelope and raise ApiError on device-side failures.""" + if response is None: + _LOGGER.debug("Empty NGIOT response body returned by server") + return + + if not isinstance(response, Mapping): + raise ApiError("Invalid NGIOT response: missing body") + + body = response.get("body") + if body is None: + # Accept ACK-only envelopes that still include a header. + if isinstance(response.get("header"), Mapping): + _LOGGER.debug("NGIOT ACK-only response without body: %s", response) + return + raise ApiError("Invalid NGIOT response: missing body") + + if not isinstance(body, Mapping): + raise ApiError("Invalid NGIOT response: missing body") + + code = body.get("code", 0) + if code not in (0, "0000", None): + msg = body.get("msg", "unknown error") + raise ApiError( + f"NGIOT request failed with code {code} ({msg}) for {_PATH_ENDPOINT_CONTROL}" + ) + + @staticmethod + def _new_request_id() -> str: + """Generate request ID for query/header transport fields.""" + return secrets.token_hex(16) + + @staticmethod + def _new_body_reqid(length: int = 6) -> str: + """Generate short request ID for the NGIOT body header.""" + alphabet = string.ascii_letters + string.digits + return "".join(secrets.choice(alphabet) for _ in range(length)) + + @staticmethod + def _device_mapping(identity: NgiotDeviceIdentity) -> dict[str, str]: + """Convert identity into a mapping accepted by SstAuthenticator.""" + return { + "did": identity.did, + "class": identity.class_id, + "resource": identity.resource, + "service": {"mqs": identity.control_host}, + } \ No newline at end of file diff --git a/deebot_client/ngiot_map_parser.py b/deebot_client/ngiot_map_parser.py new file mode 100644 index 000000000..5c2d0773f --- /dev/null +++ b/deebot_client/ngiot_map_parser.py @@ -0,0 +1,372 @@ +"""NGIOT map parser and normalization helpers. + +This module converts raw NGIOT mapping payload fragments into typed Python +structures. It is intentionally conservative: it accepts partial data, +normalizes where the protocol is clear, and preserves raw fragments where the +payload shape may still vary by device or firmware. +""" + +from __future__ import annotations + +from dataclasses import dataclass, field +from typing import Any + + +@dataclass(slots=True, frozen=True) +class NgiotPoint: + """A raw or normalized point in the NGIOT map coordinate space.""" + + x: int + y: int + + +@dataclass(slots=True, frozen=True) +class NgiotPose: + """Robot pose.""" + + x: int + y: int + a: int = 0 + + +@dataclass(slots=True, frozen=True) +class NgiotMapInfo: + """Metadata for a single saved map.""" + + map_id: str + name: str + using: bool + angle: int + charge_pos: NgiotPoint | None = None + + +@dataclass(slots=True, frozen=True) +class NgiotBaseMap: + """Base map metadata and encoded raster payload.""" + + map_id: str + width: int + height: int + total_width: int + total_height: int + resolution: int + x_min: int + y_max: int + direction: int + encoded: str + lz4_len: int | None = None + + +@dataclass(slots=True, frozen=True) +class NgiotArea: + """Area / room / partition geometry.""" + + area_id: str + name: str | None + polygon: list[NgiotPoint] = field(default_factory=list) + raw: dict[str, Any] | None = None + + +@dataclass(slots=True, frozen=True) +class NgiotTrace: + """Compressed trace payload metadata.""" + + trace_id: str | None + encoded: str + lz4_len: int | None + total_count: int + start: int + + +@dataclass(slots=True, frozen=True) +class NgiotOverlay: + """Overlay geometry such as virtual walls, mop walls, or carpets.""" + + overlay_type: str + overlay_id: str + polygon: list[NgiotPoint] = field(default_factory=list) + raw: dict[str, Any] | None = None + + +SUPPORTED_OVERLAYS: tuple[tuple[str, str], ...] = ( + ("virtual_walls", "virtualWalls"), + ("mop_walls", "mopWalls"), + ("carpets", "carpets"), +) + + +def _coerce_int(value: Any, default: int = 0) -> int: + try: + return int(value) + except (TypeError, ValueError): + return default + + + +def _coerce_str(value: Any) -> str: + return str(value).strip() if value is not None else "" + + + +def resolve_map_id(data: dict[str, Any], fallback: str = "") -> str: + """Resolve a map ID from a mixed NGIOT payload.""" + if (map_id := _coerce_str(data.get("mapId"))): + return map_id + + map_data = data.get("mapData") + if isinstance(map_data, dict) and (map_id := _coerce_str(map_data.get("mapId"))): + return map_id + + trace_data = data.get("mapTraceData") + if isinstance(trace_data, dict) and (map_id := _coerce_str(trace_data.get("mapId"))): + return map_id + + return _coerce_str(fallback) + + + +def _parse_point(value: Any) -> NgiotPoint | None: + if not isinstance(value, dict): + return None + if value.get("x") is None or value.get("y") is None: + return None + return NgiotPoint( + x=_coerce_int(value.get("x")), + y=_coerce_int(value.get("y")), + ) + + + +def _extract_point_pairs(flat_values: list[Any]) -> list[NgiotPoint]: + points: list[NgiotPoint] = [] + ints = [_coerce_int(value) for value in flat_values] + for index in range(0, len(ints) - 1, 2): + points.append(NgiotPoint(x=ints[index], y=ints[index + 1])) + return points + + + +def _extract_polygon(raw: Any) -> list[NgiotPoint]: + """Best-effort polygon extraction for multiple observed payload shapes.""" + if isinstance(raw, list): + if not raw: + return [] + + if all(isinstance(item, dict) for item in raw): + result: list[NgiotPoint] = [] + for item in raw: + point = _parse_point(item) + if point is not None: + result.append(point) + return result + + if all(not isinstance(item, (dict, list, tuple)) for item in raw): + return _extract_point_pairs(raw) + + result: list[NgiotPoint] = [] + for item in raw: + if isinstance(item, (list, tuple)) and len(item) >= 2: + result.append(NgiotPoint(x=_coerce_int(item[0]), y=_coerce_int(item[1]))) + else: + result.extend(_extract_polygon(item)) + return result + + if isinstance(raw, dict): + for key in ( + "points", + "polygon", + "coordinates", + "coord", + "posList", + "vertexes", + "vertices", + "outline", + ): + if key in raw: + return _extract_polygon(raw[key]) + + return [] + + + +def parse_map_infos(data: dict[str, Any]) -> list[NgiotMapInfo]: + """Parse map registry / mapInfos payload.""" + infos: list[NgiotMapInfo] = [] + raw_infos = data.get("mapInfos") + if not isinstance(raw_infos, list): + return infos + + for raw in raw_infos: + if not isinstance(raw, dict): + continue + + map_id = _coerce_str(raw.get("mapId")) + if not map_id or map_id == "0": + continue + + infos.append( + NgiotMapInfo( + map_id=map_id, + name=_coerce_str(raw.get("name")), + using=_coerce_int(raw.get("status")) == 1, + angle=_coerce_int(raw.get("angle")), + charge_pos=_parse_point(raw.get("chargePos")), + ) + ) + + return infos + + + +def parse_base_map(data: dict[str, Any], map_id: str | None = None) -> NgiotBaseMap | None: + """Parse base map metadata and encoded raster payload.""" + raw = data.get("mapData") + if not isinstance(raw, dict): + return None + + encoded = _coerce_str(raw.get("map")) or _coerce_str(raw.get("data")) + if not encoded: + return None + + resolved_map_id = _coerce_str(map_id) or resolve_map_id(data) + + return NgiotBaseMap( + map_id=resolved_map_id, + width=_coerce_int(raw.get("width")), + height=_coerce_int(raw.get("height")), + total_width=_coerce_int(raw.get("totalWidth")), + total_height=_coerce_int(raw.get("totalHeight")), + resolution=max(1, _coerce_int(raw.get("resolution"), 1)), + x_min=_coerce_int(raw.get("xMin")), + y_max=_coerce_int(raw.get("yMax")), + direction=_coerce_int(raw.get("direction")), + encoded=encoded, + lz4_len=_coerce_int(raw.get("lz4Len")) or None, + ) + + +def parse_pose(data: dict[str, Any]) -> NgiotPose | None: + """Parse robot pose from a payload fragment.""" + raw = data.get("pos") + if not isinstance(raw, dict): + raw = data.get("deebotPos") + if not isinstance(raw, dict): + return None + + if raw.get("x") is None or raw.get("y") is None: + return None + + return NgiotPose( + x=_coerce_int(raw.get("x")), + y=_coerce_int(raw.get("y")), + a=_coerce_int(raw.get("a")), + ) + + + +def parse_trace(data: dict[str, Any]) -> NgiotTrace | None: + """Parse compressed map trace metadata.""" + raw = data.get("mapTraceData") + if not isinstance(raw, dict): + return None + + return NgiotTrace( + trace_id=_coerce_str(raw.get("traceId")) or None, + encoded=_coerce_str(raw.get("trace")), + lz4_len=_coerce_int(raw.get("lz4Len")) or None, + total_count=_coerce_int(raw.get("totalCount")), + start=_coerce_int(raw.get("start")), + ) + + + +def parse_areas(data: dict[str, Any]) -> list[NgiotArea]: + """Parse room / area segmentation payload.""" + results: list[NgiotArea] = [] + raw_areas = data.get("areas") + if not isinstance(raw_areas, list): + return results + + for index, raw in enumerate(raw_areas): + if not isinstance(raw, dict): + continue + + area_id = _coerce_str( + raw.get("id") + or raw.get("areaId") + or raw.get("subId") + or raw.get("mid") + or index + ) + name = _coerce_str(raw.get("name") or raw.get("label")) or None + polygon = _extract_polygon(raw) + + results.append( + NgiotArea( + area_id=area_id, + name=name, + polygon=polygon, + raw=raw, + ) + ) + + return results + + + +def parse_overlays(data: dict[str, Any]) -> list[NgiotOverlay]: + """Parse overlay layers from a payload fragment.""" + results: list[NgiotOverlay] = [] + + for overlay_type, field_name in SUPPORTED_OVERLAYS: + raw_items = data.get(field_name) + if not isinstance(raw_items, list): + continue + + for index, raw in enumerate(raw_items): + if not isinstance(raw, dict): + continue + + overlay_id = _coerce_str( + raw.get("id") or raw.get("subId") or raw.get("mid") or index + ) + polygon = _extract_polygon(raw) + + results.append( + NgiotOverlay( + overlay_type=overlay_type, + overlay_id=overlay_id, + polygon=polygon, + raw=raw, + ) + ) + + return results + + + +def normalize_point(point: NgiotPoint, base_map: NgiotBaseMap) -> NgiotPoint: + """Normalize a raw NGIOT point into cropped-raster space.""" + left = base_map.x_min - base_map.y_max + top = base_map.y_max - base_map.height + + col = int((point.x - left) / base_map.resolution) + row = int((top - point.y) / base_map.resolution) + + if base_map.direction == -1: + row = (base_map.height - 1) - row + + return NgiotPoint(x=col, y=row) + + + +def normalize_pose(pose: NgiotPose, base_map: NgiotBaseMap) -> NgiotPose: + """Normalize a raw robot pose into map-render space.""" + point = normalize_point(NgiotPoint(pose.x, pose.y), base_map) + return NgiotPose(x=point.x, y=point.y, a=pose.a) + + + +def normalize_polygon(points: list[NgiotPoint], base_map: NgiotBaseMap) -> list[NgiotPoint]: + """Normalize a polygon into map-render space.""" + return [normalize_point(point, base_map) for point in points] \ No newline at end of file diff --git a/deebot_client/ngiot_map_state.py b/deebot_client/ngiot_map_state.py new file mode 100644 index 000000000..1db69c62d --- /dev/null +++ b/deebot_client/ngiot_map_state.py @@ -0,0 +1,241 @@ +"""NGIOT map state aggregation. + +This module stores a per-map snapshot assembled from multiple NGIOT field-based +responses. It keeps raw values intact, exposes normalized views for later +rendering, and tolerates partial updates arriving in any order. +""" + +from __future__ import annotations + +from dataclasses import dataclass, field + +from .ngiot_map_parser import ( + NgiotArea, + NgiotBaseMap, + NgiotMapInfo, + NgiotOverlay, + NgiotPose, + NgiotPoint, + NgiotTrace, + normalize_point, + normalize_polygon, + normalize_pose, +) + + +@dataclass(slots=True) +class NgiotMapSnapshot: + """Aggregated state for a single NGIOT map ID.""" + + map_info: NgiotMapInfo | None = None + base_map: NgiotBaseMap | None = None + pose: NgiotPose | None = None + trace: NgiotTrace | None = None + areas: list[NgiotArea] = field(default_factory=list) + overlays: list[NgiotOverlay] = field(default_factory=list) + + @property + def map_id(self) -> str | None: + if self.map_info is not None: + return self.map_info.map_id + if self.base_map is not None: + return self.base_map.map_id + return None + + @property + def charge_pos(self) -> NgiotPoint | None: + if self.map_info is not None: + return self.map_info.charge_pos + return None + + def has_background(self) -> bool: + """Return True when a usable raster base map is present.""" + base_map = self.base_map + return bool( + base_map is not None + and base_map.encoded + and base_map.width > 0 + and base_map.height > 0 + ) + + def has_overlay_content(self) -> bool: + """Return True when overlay/state content exists for the map.""" + return bool( + self.areas + or self.overlays + or self.pose is not None + or self.charge_pos is not None + or ( + self.trace is not None + and ( + self.trace.total_count > 0 + or bool(self.trace.encoded) + ) + ) + ) + + def is_overlay_only(self) -> bool: + """Return True when only non-background map content exists.""" + return not self.has_background() and self.has_overlay_content() + + def is_renderable(self) -> bool: + """Return True when the snapshot can produce the intended visible map.""" + return self.has_background() + + +class NgiotMapStateStore: + """Per-device NGIOT map state store keyed by map_id.""" + + def __init__(self) -> None: + self._maps: dict[str, NgiotMapSnapshot] = {} + self._active_map_id: str | None = None + + @property + def active_map_id(self) -> str | None: + return self._active_map_id + + @property + def map_ids(self) -> tuple[str, ...]: + return tuple(self._maps.keys()) + + def clear(self) -> None: + self._maps.clear() + self._active_map_id = None + + def has_map(self, map_id: str) -> bool: + return map_id in self._maps + + def get(self, map_id: str) -> NgiotMapSnapshot: + if map_id not in self._maps: + self._maps[map_id] = NgiotMapSnapshot() + return self._maps[map_id] + + def get_if_present(self, map_id: str) -> NgiotMapSnapshot | None: + return self._maps.get(map_id) + + def get_active(self) -> NgiotMapSnapshot | None: + if self._active_map_id is None: + return None + return self._maps.get(self._active_map_id) + + def get_active_renderable(self) -> NgiotMapSnapshot | None: + """Return the active snapshot only if it is renderable.""" + snapshot = self.get_active() + if snapshot is None or not snapshot.is_renderable(): + return None + return snapshot + + def set_active_map_id(self, map_id: str | None) -> None: + if map_id: + self._active_map_id = map_id + self.get(map_id) + + def update_map_info(self, info: NgiotMapInfo) -> NgiotMapSnapshot: + snapshot = self.get(info.map_id) + snapshot.map_info = info + if info.using: + self._active_map_id = info.map_id + elif self._active_map_id is None: + self._active_map_id = info.map_id + return snapshot + + def update_map_infos(self, infos: list[NgiotMapInfo]) -> None: + for info in infos: + self.update_map_info(info) + + def update_base_map(self, base_map: NgiotBaseMap) -> NgiotMapSnapshot: + snapshot = self.get(base_map.map_id) + snapshot.base_map = base_map + if self._active_map_id is None: + self._active_map_id = base_map.map_id + return snapshot + + def update_pose(self, map_id: str, pose: NgiotPose) -> NgiotMapSnapshot: + snapshot = self.get(map_id) + snapshot.pose = pose + return snapshot + + def update_trace(self, map_id: str, trace: NgiotTrace) -> NgiotMapSnapshot: + snapshot = self.get(map_id) + snapshot.trace = trace + return snapshot + + def update_areas(self, map_id: str, areas: list[NgiotArea]) -> NgiotMapSnapshot: + snapshot = self.get(map_id) + snapshot.areas = areas + return snapshot + + def update_overlays( + self, map_id: str, overlays: list[NgiotOverlay] + ) -> NgiotMapSnapshot: + snapshot = self.get(map_id) + snapshot.overlays = overlays + return snapshot + + def get_normalized(self, map_id: str | None = None) -> NgiotMapSnapshot | None: + """Return a normalized copy of the snapshot for rendering. + + If the snapshot has no base map yet, the raw snapshot is returned. + """ + resolved_map_id = map_id or self._active_map_id + if resolved_map_id is None: + return None + + snapshot = self._maps.get(resolved_map_id) + if snapshot is None: + return None + + if snapshot.base_map is None: + return NgiotMapSnapshot( + map_info=snapshot.map_info, + base_map=snapshot.base_map, + pose=snapshot.pose, + trace=snapshot.trace, + areas=list(snapshot.areas), + overlays=list(snapshot.overlays), + ) + + base_map = snapshot.base_map + + normalized_map_info = snapshot.map_info + if snapshot.map_info is not None and snapshot.map_info.charge_pos is not None: + normalized_map_info = NgiotMapInfo( + map_id=snapshot.map_info.map_id, + name=snapshot.map_info.name, + using=snapshot.map_info.using, + angle=snapshot.map_info.angle, + charge_pos=normalize_point(snapshot.map_info.charge_pos, base_map), + ) + + normalized_pose = ( + normalize_pose(snapshot.pose, base_map) if snapshot.pose is not None else None + ) + + normalized_areas = [ + NgiotArea( + area_id=area.area_id, + name=area.name, + polygon=normalize_polygon(area.polygon, base_map), + raw=area.raw, + ) + for area in snapshot.areas + ] + + normalized_overlays = [ + NgiotOverlay( + overlay_type=overlay.overlay_type, + overlay_id=overlay.overlay_id, + polygon=normalize_polygon(overlay.polygon, base_map), + raw=overlay.raw, + ) + for overlay in snapshot.overlays + ] + + return NgiotMapSnapshot( + map_info=normalized_map_info, + base_map=base_map, + pose=normalized_pose, + trace=snapshot.trace, + areas=normalized_areas, + overlays=normalized_overlays, + ) \ No newline at end of file diff --git a/deebot_client/rs/map.pyi b/deebot_client/rs/map.pyi index dc2d8bab8..66aacc964 100644 --- a/deebot_client/rs/map.pyi +++ b/deebot_client/rs/map.pyi @@ -4,27 +4,56 @@ from typing import Self from deebot_client.events.map import MapSubsetEvent, Position class BackgroundImage: - """Map background image.""" + """Background image in rust.""" def update_map_piece(self, index: int, base64_data: str) -> bool: - """Update map piece.""" + """Update a map piece.""" def map_piece_crc32_indicates_update(self, index: int, crc32: int) -> bool: - """Return True if update is required.""" + """Return True when the piece should be refreshed.""" + +class NgiotBackground: + """NGIOT background placeholder.""" + + def set_map_data( + self, + encoded: str, + width: int, + height: int, + total_width: int, + total_height: int, + resolution: int, + x_min: int, + y_max: int, + direction: int, + ) -> bool: + """Store NGIOT background metadata and encoded payload.""" + + def clear(self) -> bool: + """Clear NGIOT background metadata.""" + + def has_map_data(self) -> bool: + """Return True if NGIOT background data is present.""" class TracePoints: """Trace points in rust.""" - def add(self, value: str) -> None: + def add(self, value: str, lz4_len: int | None = None) -> None: """Add trace points to the trace points object.""" def clear(self) -> None: - """Clear all trace points.""" + """Clear trace points.""" + + def use_legacy_scale(self) -> None: + """Use legacy trace scale.""" + + def use_world_scale(self) -> None: + """Use world-space trace scale.""" class MapInfo: """Map info.""" - def set(self, baset64_data: str) -> None: + def set(self, base64_data: str) -> None: """Set map info (base64-compressed JSON).""" class MapData: @@ -37,6 +66,10 @@ class MapData: def background_image(self) -> BackgroundImage: """Return background image.""" + @property + def ngiot_background(self) -> NgiotBackground: + """Return NGIOT background placeholder.""" + @property def map_info(self) -> MapInfo: """Return map info.""" @@ -45,11 +78,58 @@ class MapData: def trace_points(self) -> TracePoints: """Return trace points.""" + def set_map_info(self, base64_data: str) -> None: + """Compatibility wrapper for Python map.py.""" + + def set_ngiot_background( + self, + encoded: str, + width: int, + height: int, + total_width: int, + total_height: int, + resolution: int, + x_min: int, + y_max: int, + direction: int, + ) -> bool: + """Compatibility wrapper for Python map.py.""" + + def clear_ngiot_background(self) -> bool: + """Compatibility wrapper for Python map.py.""" + + def has_ngiot_background(self) -> bool: + """Compatibility wrapper for Python map.py.""" + + def add_trace_points(self, value: str, lz4_len: int | None = None) -> None: + """Compatibility wrapper for Python map.py.""" + + def clear_trace_points(self) -> None: + """Compatibility wrapper for Python map.py.""" + + def use_legacy_trace_scale(self) -> None: + """Compatibility wrapper for Python map.py.""" + + def use_world_trace_scale(self) -> None: + """Compatibility wrapper for Python map.py.""" + + def use_legacy_position_icon_scale(self) -> None: + """Use legacy position icon scale.""" + + def use_ngiot_position_icon_scale(self) -> None: + """Use NGIOT position icon scale.""" + + def use_legacy_position_transform(self) -> None: + """Use legacy position transform.""" + + def use_ngiot_position_transform(self) -> None: + """Use NGIOT position transform.""" + def generate_svg( self, subsets: list[MapSubsetEvent], position: list[Position], - rotation: RotationAngle, + rotation: "RotationAngle", ) -> str | None: """Generate SVG image.""" @@ -60,7 +140,7 @@ class PositionType(Enum): CHARGER = auto() @staticmethod - def from_str(value: str) -> PositionType: + def from_str(value: str) -> "PositionType": """Create a position type from string.""" class RotationAngle(Enum): @@ -72,5 +152,5 @@ class RotationAngle(Enum): DEG_270 = auto() @staticmethod - def from_int(value: int) -> RotationAngle: - """Create a rotation angle from integer.""" + def from_int(value: int) -> "RotationAngle": + """Create a rotation angle from integer.""" \ No newline at end of file diff --git a/deebot_client/sst_authentication.py b/deebot_client/sst_authentication.py new file mode 100644 index 000000000..1c3a8083c --- /dev/null +++ b/deebot_client/sst_authentication.py @@ -0,0 +1,309 @@ +"""SST authentication module for NGIOT endpoint-control devices.""" + +from __future__ import annotations + +import asyncio +import base64 +import json +import time +from collections.abc import Mapping +from dataclasses import dataclass +from http import HTTPStatus +from typing import TYPE_CHECKING, Any +from urllib.parse import urljoin + +from aiohttp import ClientResponseError, ClientSession, ClientTimeout, hdrs + +from .exceptions import ApiError, ApiTimeoutError, AuthenticationError +from .logging_filter import get_logger +from .util import cancel, create_task + +if TYPE_CHECKING: + from .authentication import Authenticator + from .models import ApiDeviceInfo, DeviceInfo + +_LOGGER = get_logger(__name__) + +_TIMEOUT = ClientTimeout(60) +_SST_ISSUE_PATH = "/api/new-perm/token/sst/issue" +_SST_SERVICE = "dim" +_SST_PERMISSION = "Control" + + +@dataclass(frozen=True) +class SstCredentials: + """Short-lived NGIOT SST credentials.""" + + token: str + expires_at: int + device_key: str + + +@dataclass(frozen=True) +class SstDeviceIdentity: + """Normalized device identity needed for SST minting.""" + + did: str + class_id: str + resource: str + control_host: str | None = None + + @property + def endpoint(self) -> str: + """Return DIM ACL endpoint identifier.""" + return f"Endpoint:{self.class_id}:{self.did}" + + @property + def key(self) -> str: + """Return stable cache key.""" + return f"{self.class_id}:{self.did}:{self.resource}" + + +class SstAuthenticator: + """Mint, cache, and refresh short-lived SST credentials per device.""" + + def __init__( + self, + session: ClientSession, + authenticator: Authenticator, + *, + base_url: str, + requested_ttl: int = 600, + refresh_skew: int = 60, + ) -> None: + self._session = session + self._authenticator = authenticator + self._base_url = base_url.rstrip("/") + self._requested_ttl = requested_ttl + self._refresh_skew = refresh_skew + + self._lock = asyncio.Lock() + self._credentials: dict[str, SstCredentials] = {} + self._devices: dict[str, SstDeviceIdentity] = {} + self._refresh_handles: dict[str, asyncio.TimerHandle] = {} + self._tasks: set[asyncio.Future[Any]] = set() + + async def get_credentials( + self, + device: ApiDeviceInfo | DeviceInfo | Mapping[str, Any], + *, + force: bool = False, + ) -> SstCredentials: + """Return cached SST credentials for a device, refreshing if needed.""" + identity = self._normalize_device(device) + + async with self._lock: + cached = self._credentials.get(identity.key) + now = int(time.time()) + + if ( + not force + and cached is not None + and cached.expires_at > now + self._refresh_skew + ): + return cached + + credentials = await self._issue_sst(identity) + self._devices[identity.key] = identity + self._credentials[identity.key] = credentials + + self._cancel_refresh_task(identity.key) + self._create_refresh_task(identity, credentials) + + return credentials + + async def get_token( + self, + device: ApiDeviceInfo | DeviceInfo | Mapping[str, Any], + *, + force: bool = False, + ) -> str: + """Return SST bearer token for a device.""" + return (await self.get_credentials(device, force=force)).token + + async def invalidate( + self, + device: ApiDeviceInfo | DeviceInfo | Mapping[str, Any] | str, + ) -> None: + """Invalidate cached SST for a device or cache key.""" + key = device if isinstance(device, str) else self._normalize_device(device).key + + async with self._lock: + self._credentials.pop(key, None) + self._devices.pop(key, None) + self._cancel_refresh_task(key) + + async def teardown(self) -> None: + """Teardown authenticator and cancel outstanding refresh tasks.""" + for key in list(self._refresh_handles): + self._cancel_refresh_task(key) + + self._credentials.clear() + self._devices.clear() + await cancel(self._tasks) + + async def _issue_sst(self, identity: SstDeviceIdentity) -> SstCredentials: + """Mint a fresh SST for the given device.""" + account_credentials = await self._authenticator.authenticate() + + headers = { + hdrs.AUTHORIZATION: f"Bearer {account_credentials.token}", + hdrs.CONTENT_TYPE: "application/json; charset=utf-8", + } + payload = { + "acl": [ + { + "policy": [ + { + "obj": [identity.endpoint], + "perms": [_SST_PERMISSION], + } + ], + "svc": _SST_SERVICE, + } + ], + "exp": self._requested_ttl, + "sub": account_credentials.user_id, + } + + url = urljoin(self._base_url, _SST_ISSUE_PATH) + logger_request_params = { + "url": url, + "device_key": identity.key, + "payload": payload, + } + + try: + _LOGGER.debug("Calling SST issue endpoint: %s", logger_request_params) + + async with self._session.post( + url, + json=payload, + headers=headers, + timeout=_TIMEOUT, + ) as res: + res.raise_for_status() + content_type = res.headers.get(hdrs.CONTENT_TYPE, "").lower() + response_data: dict[str, Any] = await res.json( + content_type=content_type or None + ) + + _LOGGER.debug( + "SST issue response for %s: %s", identity.key, response_data + ) + + if response_data.get("code") not in (0, "0000"): + msg = ( + f"failure code {response_data.get('code')} " + f"({response_data.get('msg')}) for call {_SST_ISSUE_PATH}" + ) + raise AuthenticationError(msg) + + token = str(response_data["data"]["data"]["token"]) + expires_at = self._decode_exp(token) + if expires_at is None: + # Fallback if the token format changes or exp is absent. + expires_at = int(time.time()) + max(60, int(self._requested_ttl * 0.9)) + + return SstCredentials( + token=token, + expires_at=expires_at, + device_key=identity.key, + ) + + except TimeoutError as ex: + raise ApiTimeoutError(path=_SST_ISSUE_PATH, timeout=_TIMEOUT) from ex + except ClientResponseError as ex: + if ex.status in (HTTPStatus.UNAUTHORIZED, HTTPStatus.FORBIDDEN): + raise AuthenticationError("SST issue request was not authorized") from ex + raise ApiError from ex + + def _create_refresh_task( + self, + identity: SstDeviceIdentity, + credentials: SstCredentials, + ) -> None: + """Create refresh task for a given SST credential.""" + + def refresh() -> None: + _LOGGER.debug("Refreshing SST for %s", identity.key) + + async def async_refresh() -> None: + try: + await self.get_credentials(identity_as_mapping(identity), force=True) + except Exception: + _LOGGER.exception( + "An exception occurred during SST refresh for %s", + identity.key, + ) + + create_task(self._tasks, async_refresh()) + self._refresh_handles.pop(identity.key, None) + + seconds_until_refresh = max( + 5, + credentials.expires_at - int(time.time()) - self._refresh_skew, + ) + self._refresh_handles[identity.key] = asyncio.get_event_loop().call_later( + seconds_until_refresh, + refresh, + ) + + def _cancel_refresh_task(self, key: str) -> None: + """Cancel refresh timer for a cache key.""" + handle = self._refresh_handles.pop(key, None) + if handle and not handle.cancelled(): + handle.cancel() + + @staticmethod + def _decode_exp(token: str) -> int | None: + """Decode exp claim from SST token without signature validation.""" + try: + parts = token.split(".") + if len(parts) == 4 and parts[0] == "SST": + payload_segment = parts[2] + elif len(parts) == 3: + payload_segment = parts[1] + else: + return None + + padded = payload_segment + "=" * (-len(payload_segment) % 4) + payload = json.loads(base64.urlsafe_b64decode(padded).decode("utf-8")) + exp = payload.get("exp") + + return int(exp) if exp is not None else None + except Exception: + _LOGGER.debug("Failed to decode SST token expiry", exc_info=True) + return None + + @staticmethod + def _normalize_device( + device: ApiDeviceInfo | DeviceInfo | Mapping[str, Any], + ) -> SstDeviceIdentity: + """Normalize a device object into the fields required for SST issuance.""" + raw_device = device.api if hasattr(device, "api") else device + + if not isinstance(raw_device, Mapping): + msg = f"Unsupported device type for SST authentication: {type(device)!r}" + raise TypeError(msg) + + service = raw_device.get("service", {}) + control_host = ( + service.get("mqs") if isinstance(service, Mapping) else None + ) + + return SstDeviceIdentity( + did=str(raw_device["did"]), + class_id=str(raw_device["class"]), + resource=str(raw_device["resource"]), + control_host=str(control_host) if control_host else None, + ) + + +def identity_as_mapping(identity: SstDeviceIdentity) -> dict[str, str]: + """Convert normalized identity back to a mapping accepted by get_credentials.""" + return { + "did": identity.did, + "class": identity.class_id, + "resource": identity.resource, + } \ No newline at end of file diff --git a/src/map/background_image.rs b/src/map/background_image.rs index 139553e9b..0d29acae3 100644 --- a/src/map/background_image.rs +++ b/src/map/background_image.rs @@ -1,4 +1,5 @@ -use super::{ImageGenrationType, ViewBox, decompress_base64_data}; +use super::{ImageGenrationType, ViewBox}; +use crate::util::decompress_base64_data; use base64::Engine; use base64::engine::general_purpose; use crc32fast::Hasher; @@ -105,8 +106,8 @@ impl BackgroundImage { .view( min_x.into(), min_y.into(), - view_box.width.into(), - view_box.height.into(), + view_box.width as u32, + view_box.height as u32, ) .to_image(); @@ -229,4 +230,4 @@ mod tests { assert!(map_piece.pixels_indexed.is_none()); assert!(!map_piece.update_points(data).unwrap()); } -} +} \ No newline at end of file diff --git a/src/map/map_info.rs b/src/map/map_info.rs index 5bd47bec8..e32e02d12 100644 --- a/src/map/map_info.rs +++ b/src/map/map_info.rs @@ -1,5 +1,6 @@ use super::style::{CSSClass, ROOM_COLORS, get_class_names, get_style}; -use super::{RotationAngle, ViewBox, calc_point, decompress_base64_data}; +use super::{RotationAngle, ViewBox, calc_point}; +use crate::util::decompress_base64_data; use super::points::{Point, points_to_svg_path}; use ordermap::OrderSet; @@ -148,6 +149,21 @@ impl MapInfo { Some((svg_elements, viewbox?, used_styles)) } + pub(super) fn set_map_info(&mut self, base64_data: String) -> PyResult<()> { + let raw = decompress_base64_data(&base64_data).map_err( + |err: Box| PyValueError::new_err(err.to_string()), + )?; + let entries: Vec = serde_json::from_slice(&raw) + .map_err(|err| PyValueError::new_err(format!("Invalid map info: {err}")))?; + + entries.into_iter().for_each(|MapInfoTypeEntry(t, v)| { + if !v.is_empty() { + self.data.insert(t, v); + } + }); + Ok(()) + } + fn get_order(&self) -> Vec { if self.data.contains_key(&MapInfoType::BlockLine) { vec![ @@ -212,16 +228,7 @@ impl MapInfo { #[pymethods] impl MapInfo { fn set(&mut self, base64_data: String) -> PyResult<()> { - let raw = decompress_base64_data(&base64_data) - .map_err(|err| PyValueError::new_err(err.to_string()))?; - let entries: Vec = serde_json::from_slice(&raw) - .map_err(|err| PyValueError::new_err(format!("Invalid map info: {err}")))?; - entries.into_iter().for_each(|MapInfoTypeEntry(t, v)| { - if !v.is_empty() { - self.data.insert(t, v); - } - }); - Ok(()) + self.set_map_info(base64_data) } } @@ -311,9 +318,12 @@ fn calc_viewbox(outlines: &[MapInfoTypeDataEntry]) -> Option { .for_each(|e| minmax_points(e.points.iter(), &mut bounds)); let (min_x_f, min_y_f, max_x_f, max_y_f) = bounds?; - let (min_x, min_y) = (min_x_f.round() as i16, min_y_f.round() as i16); - let (max_x, max_y) = (max_x_f.round() as i16, max_y_f.round() as i16); - let (width, height) = ((max_x - min_x).max(1) as u16, (max_y - min_y).max(1) as u16); + let min_x = min_x_f.round(); + let min_y = min_y_f.round(); + let max_x = max_x_f.round(); + let max_y = max_y_f.round(); + let width = (max_x - min_x).max(1.0); + let height = (max_y - min_y).max(1.0); Some(ViewBox { min_x, @@ -371,4 +381,4 @@ mod tests { "Empty map info entry at line 1 column 4" ); } -} +} \ No newline at end of file diff --git a/src/map/mod.rs b/src/map/mod.rs index b4f50d139..6637a157f 100644 --- a/src/map/mod.rs +++ b/src/map/mod.rs @@ -1,17 +1,18 @@ mod background_image; mod common; mod map_info; +mod ngiot_background; mod points; mod style; use background_image::{BackgroundImage, MAP_MAX_SIZE}; use common::round; use map_info::MapInfo; +use ngiot_background::NgiotBackground; use ordermap::OrderSet; -use points::{Point, TracePoints, points_to_svg_path}; -use style::{CSSClass, get_class_names, get_style, get_used_definitions}; +use points::{points_to_svg_path, Point, TracePoints}; +use style::{get_class_names, get_style, get_used_definitions, CSSClass}; -use super::util::decompress_base64_data; use log::debug; use pyo3::exceptions::PyValueError; use pyo3::prelude::*; @@ -20,9 +21,11 @@ use svg::node::element::{ }; use svg::{Document, Node}; -const PIXEL_WIDTH: f32 = 50.0; +pub(super) const PIXEL_WIDTH: f32 = 50.0; const ROUND_TO_DIGITS: usize = 3; const MAP_OFFSET: i16 = MAP_MAX_SIZE as i16 / 2; +const LEGACY_POSITION_ICON_SCALE: f32 = 1.0; +const NGIOT_POSITION_ICON_SCALE: f32 = 0.18; #[inline] fn calc_point(x: f32, y: f32, rotation: RotationAngle) -> Point { @@ -39,11 +42,7 @@ fn calc_point(x: f32, y: f32, rotation: RotationAngle) -> Point { } } -fn get_svg_subset(subset: &MapSubset, rotation: RotationAngle) -> PyResult<(CSSClass, Path)> { - debug!("Adding subset: {subset:?}"); - - // Estimate capacity: each point consists of an x and y coordinate, separated by commas. - // So, the number of points is half the number of comma-separated values. +fn get_subset_points(subset: &MapSubset, rotation: RotationAngle) -> Vec { let num_coords = subset.coordinates.split(',').count(); let mut points = Vec::with_capacity(num_coords / 2); @@ -61,18 +60,39 @@ fn get_svg_subset(subset: &MapSubset, rotation: RotationAngle) -> PyResult<(CSSC points.push(calc_point(x, y, rotation)); } - let css_key = match subset.set_type.as_str() { - "vw" => CSSClass::VirtualWall, - "mw" => CSSClass::NoMoppingWall, + points +} + +fn get_svg_subset( + subset: &MapSubset, + rotation: RotationAngle, +) -> PyResult<(Vec, Path)> { + debug!("Adding subset: {subset:?}"); + + let points = get_subset_points(subset, rotation); + let close_path = points.len() > 2; + + let css = match subset.set_type.as_str() { + "ar" => vec![CSSClass::RoomSubset], + "vw" => vec![ + CSSClass::WallBase, + CSSClass::StrokeWidth2, + CSSClass::VirtualWall, + ], + "mw" => vec![ + CSSClass::WallBase, + CSSClass::StrokeWidth2, + CSSClass::NoMoppingWall, + ], + "cp" => vec![CSSClass::CarpetArea], _ => return Err(PyValueError::new_err("Invalid set type")), }; - let css_obj = get_style(&css_key); - let svg_object = points_to_svg_path(&points, points.len() > 2, false) - .unwrap() - .set("class", css_obj.class_name); + let svg_object = points_to_svg_path(&points, close_path, false) + .ok_or_else(|| PyValueError::new_err("Subset does not contain enough points"))? + .set("class", get_class_names(&css)); - Ok((css_key, svg_object)) + Ok((css, svg_object)) } #[pyclass(from_py_object, eq, eq_int)] @@ -172,6 +192,29 @@ fn calc_point_in_viewbox(x: i32, y: i32, viewbox: &ViewBox, rotation: RotationAn } } +#[inline] +fn calc_ngiot_local_point_in_viewbox( + x: i32, + y: i32, + origin: (i32, i32), + viewbox: &ViewBox, + rotation: RotationAngle, + overlay_svg_offset: Option<(f32, f32)>, +) -> Point { + let world_x = origin.0 as f32 + x as f32; + let world_y = origin.1 as f32 + y as f32; + let mut point = calc_point(world_x, world_y, rotation); + if let Some((dx, dy)) = overlay_svg_offset { + point.x += dx; + point.y += dy; + } + Point { + x: point.x.max(viewbox.min_x as f32).min(viewbox.max_x as f32), + y: point.y.max(viewbox.min_y as f32).min(viewbox.max_y as f32), + connected: false, + } +} + #[derive(FromPyObject, Debug)] /// Map subset event struct MapSubset { @@ -180,6 +223,49 @@ struct MapSubset { coordinates: String, } +fn calc_fallback_viewbox( + subsets: &[MapSubset], + positions: &[Position], + rotation: RotationAngle, +) -> Option { + let mut min_x = f32::MAX; + let mut min_y = f32::MAX; + let mut max_x = f32::MIN; + let mut max_y = f32::MIN; + let mut found = false; + + for subset in subsets { + for point in get_subset_points(subset, rotation) { + min_x = min_x.min(point.x); + min_y = min_y.min(point.y); + max_x = max_x.max(point.x); + max_y = max_y.max(point.y); + found = true; + } + } + + for position in positions { + let point = calc_point(position.x as f32, position.y as f32, rotation); + min_x = min_x.min(point.x); + min_y = min_y.min(point.y); + max_x = max_x.max(point.x); + max_y = max_y.max(point.y); + found = true; + } + + if !found { + return None; + } + + let margin = 5.0; + Some(ViewBox::from_extents( + min_x.floor() - margin, + min_y.floor() - margin, + max_x.ceil() + margin, + max_y.ceil() + margin, + )) +} + #[pyclass] struct MapData { #[pyo3(get)] @@ -187,7 +273,11 @@ struct MapData { #[pyo3(get)] background_image: Py, #[pyo3(get)] + ngiot_background: Py, + #[pyo3(get)] map_info: Py, + position_icon_scale: f32, + use_ngiot_position_transform: bool, } #[pymethods] @@ -197,10 +287,89 @@ impl MapData { Ok(MapData { trace_points: Py::new(py, TracePoints::new())?, background_image: Py::new(py, BackgroundImage::new())?, + ngiot_background: Py::new(py, NgiotBackground::new())?, map_info: Py::new(py, MapInfo::new())?, + position_icon_scale: LEGACY_POSITION_ICON_SCALE, + use_ngiot_position_transform: false, }) } + fn use_legacy_position_icon_scale(&mut self) { + self.position_icon_scale = LEGACY_POSITION_ICON_SCALE; + } + + fn use_ngiot_position_icon_scale(&mut self) { + self.position_icon_scale = NGIOT_POSITION_ICON_SCALE; + } + + fn use_legacy_position_transform(&mut self) { + self.use_ngiot_position_transform = false; + } + + fn use_ngiot_position_transform(&mut self) { + self.use_ngiot_position_transform = true; + } + + fn set_map_info(&mut self, py: Python<'_>, base64_data: String) -> PyResult<()> { + self.map_info.borrow_mut(py).set_map_info(base64_data) + } + + fn set_ngiot_background( + &mut self, + py: Python<'_>, + encoded: String, + width: u16, + height: u16, + total_width: u16, + total_height: u16, + resolution: i32, + x_min: i32, + y_max: i32, + direction: i32, + ) -> bool { + self.ngiot_background.borrow_mut(py).set_background_data( + encoded, + width, + height, + total_width, + total_height, + resolution, + x_min, + y_max, + direction, + ) + } + + fn clear_ngiot_background(&mut self, py: Python<'_>) -> bool { + self.ngiot_background.borrow_mut(py).clear_background_data() + } + + fn has_ngiot_background(&self, py: Python<'_>) -> bool { + self.ngiot_background.borrow(py).has_data() + } + + #[pyo3(signature = (value, lz4_len=None))] + fn add_trace_points( + &mut self, + py: Python<'_>, + value: String, + lz4_len: Option, + ) -> PyResult<()> { + self.trace_points.borrow_mut(py).add_points(value, lz4_len) + } + + fn clear_trace_points(&mut self, py: Python<'_>) { + self.trace_points.borrow_mut(py).clear_points(); + } + + fn use_legacy_trace_scale(&mut self, py: Python<'_>) { + self.trace_points.borrow_mut(py).use_legacy_trace_scale(); + } + + fn use_world_trace_scale(&mut self, py: Python<'_>) { + self.trace_points.borrow_mut(py).use_world_trace_scale(); + } + fn generate_svg( &self, py: Python<'_>, @@ -208,9 +377,17 @@ impl MapData { positions: Vec, rotation: RotationAngle, ) -> PyResult> { + let position_icon_scale = self.position_icon_scale; + let ngiot_background = self.ngiot_background.borrow(py); + let ngiot_position_origin = if self.use_ngiot_position_transform { + ngiot_background.position_origin() + } else { + None + }; + let ngiot_overlay_offset = ngiot_background.overlay_svg_offset(); + let mut defs = Definitions::new() .add( - // Gradient used by Bot icon RadialGradient::new() .set("id", "dbg") .set("cx", "50%") @@ -230,9 +407,9 @@ impl MapData { ), ) .add( - // Bot circular icon Group::new() .set("id", PositionType::Deebot.svg_use_id()) + .set("transform", format!("scale({position_icon_scale})")) .add(Circle::new().set("r", 5).set("fill", "url(#dbg)")) .add( Circle::new() @@ -243,13 +420,11 @@ impl MapData { ), ) .add( - // Charger pin icon (pre-flipped vertically) Group::new() .set("id", PositionType::Charger.svg_use_id()) + .set("transform", format!("scale({position_icon_scale})")) .add(Path::new().set("fill", "#ffe605").set( "d", - // Path data cannot be used as it's adds a , after each parameter - // and repeats the command when used sequentially "M4-6.4C4-4.2 0 0 0 0s-4-4.2-4-6.4 1.8-4 4-4 4 1.8 4 4z", )) .add( @@ -265,20 +440,33 @@ impl MapData { let mut document = Document::new(); - // Create map from MapInfo, if exists, or generate background image let viewbox = match self.map_info.borrow(py).generate(rotation) { Some((map_elements, viewbox, info_styles)) => { - // Append all map background elements to document map_elements.into_iter().for_each(|e| document.append(e)); styles.extend(info_styles); viewbox } _ => { - if let Some((base64_image, viewbox)) = - self.background_image - .borrow(py) - .generate() - .map_err(|err| PyValueError::new_err(err.to_string()))? + if let Some((base64_image, viewbox)) = self + .ngiot_background + .borrow(py) + .generate() + .map_err(|err| PyValueError::new_err(err.to_string()))? + { + let image = Image::new() + .set("x", viewbox.min_x) + .set("y", viewbox.min_y) + .set("width", viewbox.width) + .set("height", viewbox.height) + .set("style", "image-rendering: pixelated") + .set("href", format!("data:image/png;base64,{base64_image}")); + document.append(image); + viewbox + } else if let Some((base64_image, viewbox)) = self + .background_image + .borrow(py) + .generate() + .map_err(|err| PyValueError::new_err(err.to_string()))? { let image = Image::new() .set("x", viewbox.min_x) @@ -289,38 +477,46 @@ impl MapData { .set("href", format!("data:image/png;base64,{base64_image}")); document.append(image); viewbox + } else if let Some(viewbox) = calc_fallback_viewbox(&subsets, &positions, rotation) + { + viewbox } else { return Ok(None); } } }; - // Add required definitions based on used CSS classes - get_used_definitions(&styles) - .into_iter() - .for_each(|def| defs.append(def)); - - document = document.add(defs).set("viewBox", viewbox.to_svg_viewbox()); - - if !subsets.is_empty() { - let group_css = [CSSClass::WallBase, CSSClass::StrokeWidth2]; - let mut group = Group::new().set("class", get_class_names(&group_css)); - styles.extend(group_css); - - for subset in &subsets { - let (css, subset) = get_svg_subset(subset, rotation)?; - styles.insert(css); - group = group.add(subset); - } - document.append(group); + for subset in &subsets { + let (css_list, path) = get_svg_subset(subset, rotation)?; + styles.extend(css_list); + document.append(path); } - if let Some(trace) = self.trace_points.borrow(py).get_path(rotation) { + + if let Some(trace) = self + .trace_points + .borrow(py) + .get_path(rotation, ngiot_position_origin, ngiot_overlay_offset) + { document.append(trace); } - for position in get_svg_positions(&positions, &viewbox, rotation) { + + for position in get_svg_positions( + &positions, + &viewbox, + rotation, + ngiot_position_origin, + ngiot_overlay_offset, + self.use_ngiot_position_transform, + ) { document.append(position); } + get_used_definitions(&styles) + .into_iter() + .for_each(|def| defs.append(def)); + + document = document.add(defs).set("viewBox", viewbox.to_svg_viewbox()); + let mut style_string = String::new(); for k in styles { let css = get_style(&k); @@ -337,27 +533,41 @@ impl MapData { } } -#[derive(Debug)] +#[derive(Debug, Clone, Copy)] struct ViewBox { - min_x: i16, - min_y: i16, - max_x: i16, - max_y: i16, - width: u16, - height: u16, + min_x: f32, + min_y: f32, + max_x: f32, + max_y: f32, + width: f32, + height: f32, } impl ViewBox { fn new(min_x: u16, min_y: u16, max_x: u16, max_y: u16) -> Self { - let new_min_x = min_x as i16 - MAP_OFFSET; - let new_min_y = min_y as i16 - MAP_OFFSET; - let width = max_x - min_x + 1; - let height = max_y - min_y + 1; + let new_min_x = min_x as f32 - MAP_OFFSET as f32; + let new_min_y = min_y as f32 - MAP_OFFSET as f32; + let width = (max_x - min_x + 1) as f32; + let height = (max_y - min_y + 1) as f32; ViewBox { min_x: new_min_x, min_y: new_min_y, - max_x: new_min_x + width as i16, - max_y: new_min_y + height as i16, + max_x: new_min_x + width, + max_y: new_min_y + height, + width, + height, + } + } + + fn from_extents(min_x: f32, min_y: f32, max_x: f32, max_y: f32) -> Self { + let width = (max_x - min_x).max(1.0); + let height = (max_y - min_y).max(1.0); + + ViewBox { + min_x, + min_y, + max_x, + max_y, width, height, } @@ -367,7 +577,10 @@ impl ViewBox { fn to_svg_viewbox(&self) -> String { format!( "{} {} {} {}", - self.min_x, self.min_y, self.width, self.height + round(self.min_x, ROUND_TO_DIGITS), + round(self.min_y, ROUND_TO_DIGITS), + round(self.width, ROUND_TO_DIGITS), + round(self.height, ROUND_TO_DIGITS) ) } } @@ -378,12 +591,14 @@ fn get_svg_positions( positions: &[Position], viewbox: &ViewBox, rotation: RotationAngle, + ngiot_position_origin: Option<(i32, i32)>, + ngiot_overlay_offset: Option<(f32, f32)>, + use_ngiot_position_transform: bool, ) -> Vec { if positions.is_empty() { return Vec::new(); } - // Create indices and sort them instead of collecting references let mut indices: Vec = (0..positions.len()).collect(); indices.sort_by_key(|&i| positions[i].position_type.order()); @@ -393,7 +608,24 @@ fn get_svg_positions( for &i in &indices { let position = &positions[i]; - let pos = calc_point_in_viewbox(position.x, position.y, viewbox, rotation); + let pos = match (ngiot_position_origin, use_ngiot_position_transform) { + (Some(origin), true) => calc_ngiot_local_point_in_viewbox( + position.x, + position.y, + origin, + viewbox, + rotation, + ngiot_overlay_offset, + ), + _ => { + let mut point = calc_point_in_viewbox(position.x, position.y, viewbox, rotation); + if let Some((dx, dy)) = ngiot_overlay_offset { + point.x += dx; + point.y += dy; + } + point + } + }; svg_positions.push( Use::new() @@ -419,144 +651,27 @@ mod tests { fn tuple_2_view_box(tuple: (i16, i16, u16, u16)) -> ViewBox { ViewBox { - min_x: tuple.0, - min_y: tuple.1, - max_x: tuple.0 + tuple.2 as i16, - max_y: tuple.1 + tuple.3 as i16, - width: tuple.2, - height: tuple.3, + min_x: tuple.0 as f32, + min_y: tuple.1 as f32, + max_x: tuple.0 as f32 + tuple.2 as f32, + max_y: tuple.1 as f32 + tuple.3 as f32, + width: tuple.2 as f32, + height: tuple.3 as f32, } } #[rstest] #[case((-100, -100, 200, 150))] #[case((0, 0, 1000, 1000))] - #[case( (0, 0, 1000, 1000))] - #[case( (-500, -500, 1000, 1000))] - fn test_tuple_2_view_box(#[case] input: (i16, i16, u16, u16)) { - let result = tuple_2_view_box(input); - assert_eq!( - input, - (result.min_x, result.min_y, result.width, result.height,) - ); - } - - #[rstest] - #[case(5000.0, 0.0, RotationAngle::Deg0, Point { x:100.0, y:0.0, connected:true })] - #[case(20010.0, -29900.0, RotationAngle::Deg0, Point { x: 400.2, y: 598.0, connected:true })] - #[case(0.0, 29900.0, RotationAngle::Deg0, Point { x: 0.0, y: -598.0, connected:true })] - #[case(5000.0, 0.0, RotationAngle::Deg90, Point { x:0.0, y:100.0, connected:true })] - #[case(20010.0, -29900.0, RotationAngle::Deg90, Point { x: -598.0, y: 400.2, connected:true })] - #[case(5000.0, 0.0, RotationAngle::Deg180, Point { x:-100.0, y:0.0, connected:true })] - #[case(20010.0, -29900.0, RotationAngle::Deg180, Point { x: -400.2, y: -598.0, connected:true })] - #[case(5000.0, 0.0, RotationAngle::Deg270, Point { x:0.0, y:-100.0, connected:true })] - #[case(20010.0, -29900.0, RotationAngle::Deg270, Point { x: 598.0, y: -400.2, connected:true })] - fn test_calc_point( - #[case] x: f32, - #[case] y: f32, - #[case] rotation: RotationAngle, - #[case] expected: Point, - ) { - let result = calc_point(x, y, rotation); - assert_eq!(result, expected); - } - - #[rstest] - #[case(100, 100, (-100, -100, 200, 150), RotationAngle::Deg0, Point { x: 2.0, y: -2.0, connected: false })] - #[case(-64000, -64000, (0, 0, 1000, 1000), RotationAngle::Deg0, Point { x: 0.0, y: 1000.0, connected: false })] - #[case(64000, 64000, (0, 0, 1000, 1000), RotationAngle::Deg0, Point { x: 1000.0, y: 0.0, connected: false })] - #[case(0, 1000, (-500, -500, 1000, 1000), RotationAngle::Deg0, Point { x: 0.0, y: -20.0, connected: false })] - #[case(100, 100, (-100, -100, 200, 150), RotationAngle::Deg90, Point { x: 2.0, y: 2.0, connected: false })] - #[case(100, 100, (-100, -100, 200, 150), RotationAngle::Deg180, Point { x: -2.0, y: 2.0, connected: false })] - #[case(100, 100, (-100, -100, 200, 150), RotationAngle::Deg270, Point { x: -2.0, y: -2.0, connected: false })] - fn test_calc_point_in_viewbox( - #[case] x: i32, - #[case] y: i32, - #[case] viewbox: (i16, i16, u16, u16), - #[case] rotation: RotationAngle, - #[case] expected: Point, - ) { - let result = calc_point_in_viewbox(x, y, &tuple_2_view_box(viewbox), rotation); - assert_eq!(result, expected); - } - - #[rstest] - #[case(&[Position{position_type:PositionType::Deebot, x:5000, y:-55000}], RotationAngle::Deg0, "")] - #[case(&[Position{position_type:PositionType::Deebot, x:15000, y:15000}], RotationAngle::Deg0, "")] - #[case(&[Position{position_type:PositionType::Charger, x:25000, y:55000}, Position{position_type:PositionType::Deebot, x:-5000, y:-50000}], RotationAngle::Deg0, "")] - #[case(&[Position{position_type:PositionType::Deebot, x:-10000, y:10000}, Position{position_type:PositionType::Charger, x:50000, y:5000}], RotationAngle::Deg0, "")] - #[case(&[Position{position_type:PositionType::Deebot, x:5000, y:-55000}], RotationAngle::Deg90, "")] - #[case(&[Position{position_type:PositionType::Deebot, x:5000, y:-55000}], RotationAngle::Deg180, "")] - #[case(&[Position{position_type:PositionType::Deebot, x:5000, y:-55000}], RotationAngle::Deg270, "")] - fn test_get_svg_positions( - #[case] positions: &[Position], - #[case] rotation: RotationAngle, - #[case] expected: String, - ) { - let viewbox = (-500, -500, 1000, 1000); - let result = get_svg_positions(positions, &tuple_2_view_box(viewbox), rotation) - .iter() - .map(|u| u.to_string()) - .collect::>() - .join(""); - assert_eq!(result, expected); - } - - #[rstest] - #[case(MapSubset{set_type:"vw".to_string(), coordinates:"[-3900,668,-2133,668]".to_string()}, RotationAngle::Deg0, "")] - #[case(MapSubset{set_type:"mw".to_string(), coordinates:"[-442,2910,-442,982,1214,982,1214,2910]".to_string()}, RotationAngle::Deg0, "")] - #[case(MapSubset{set_type:"vw".to_string(), coordinates:"['12023', '1979', '12135', '-6720']".to_string()}, RotationAngle::Deg0, "")] - #[case(MapSubset{set_type:"vw".to_string(), coordinates:"['12023', '1979', , '', '12135', '-6720']".to_string()}, RotationAngle::Deg0, "")] - #[case(MapSubset{set_type:"vw".to_string(), coordinates:"[-3900,668,-2133,668]".to_string()}, RotationAngle::Deg90, "")] - #[case(MapSubset{set_type:"vw".to_string(), coordinates:"[-3900,668,-2133,668]".to_string()}, RotationAngle::Deg180, "")] - #[case(MapSubset{set_type:"vw".to_string(), coordinates:"[-3900,668,-2133,668]".to_string()}, RotationAngle::Deg270, "")] - fn test_get_svg_subset( - #[case] subset: MapSubset, - #[case] rotation: RotationAngle, - #[case] expected: String, - ) { - let (_, node) = get_svg_subset(&subset, rotation).unwrap(); - - assert_eq!(node.to_string(), expected); - } - - #[rstest] - #[case("deebotPos", PositionType::Deebot)] - #[case("chargePos", PositionType::Charger)] - fn test_position_type_from_str(#[case] value: &str, #[case] expected: PositionType) { - let result = PositionType::from_str(value).unwrap(); - assert_eq!(result, expected); - } - - #[test] - fn test_position_type_from_str_invalid() { - let result = PositionType::from_str("invalid"); - assert!(result.is_err()); - } - - #[rstest] - #[case(0, RotationAngle::Deg0)] - #[case(90, RotationAngle::Deg90)] - #[case(180, RotationAngle::Deg180)] - #[case(270, RotationAngle::Deg270)] - fn test_rotation_angle_from_int_valid(#[case] value: i16, #[case] expected: RotationAngle) { - let result = RotationAngle::from_int(value).unwrap(); - assert_eq!(result, expected); - } - - #[rstest] - #[case(45)] - #[case(360)] - #[case(-90)] - #[case(100)] - fn test_rotation_angle_from_int_invalid(#[case] value: i16) { - let result = RotationAngle::from_int(value); - assert!(result.is_err()); - } - - #[test] - fn test_rotation_angle_default() { - let rotation = RotationAngle::default(); - assert_eq!(rotation, RotationAngle::Deg0); + #[case((0, 0, 1000, 1000))] + #[case((-500, -500, 1000, 1000))] + fn test_tuple_2_view_box(#[case] tuple: (i16, i16, u16, u16)) { + let viewbox = tuple_2_view_box(tuple); + assert_eq!(viewbox.min_x, tuple.0 as f32); + assert_eq!(viewbox.min_y, tuple.1 as f32); + assert_eq!(viewbox.width, tuple.2 as f32); + assert_eq!(viewbox.height, tuple.3 as f32); + assert_eq!(viewbox.max_x, tuple.0 as f32 + tuple.2 as f32); + assert_eq!(viewbox.max_y, tuple.1 as f32 + tuple.3 as f32); } -} +} \ No newline at end of file diff --git a/src/map/ngiot_background.rs b/src/map/ngiot_background.rs new file mode 100644 index 000000000..ce1a33837 --- /dev/null +++ b/src/map/ngiot_background.rs @@ -0,0 +1,222 @@ +use super::{ImageGenrationType, ViewBox}; +use crate::util::decompress_base64_lz4_data; +use base64::Engine; +use base64::engine::general_purpose; +use log::debug; +use png::{BitDepth, ColorType, Compression, Encoder}; +use pyo3::prelude::*; + +const WORLD_PIXEL_WIDTH: f32 = 50.0; + +/// Shared NGIOT overlay calibration applied after normalization into cropped-raster space. +/// +/// Sign convention matches the render notes: +/// - X negative => move overlays left +/// - X positive => move overlays right +/// - Y negative => move overlays up +/// - Y positive => move overlays down +/// +/// These are raster-cell offsets, not raw world-coordinate offsets. They therefore scale with +/// the map resolution instead of drifting when the visible crop size changes. +const OVERLAY_OFFSET_X: i32 = -7; +const OVERLAY_OFFSET_Y: i32 = -9; + +#[derive(Debug, Clone, PartialEq, Eq)] +struct NgiotBackgroundData { + encoded: String, + width: u16, + height: u16, + total_width: u16, + total_height: u16, + resolution: i32, + x_min: i32, + y_max: i32, + direction: i32, +} + +#[pyclass] +pub(crate) struct NgiotBackground { + data: Option, +} + +impl NgiotBackground { + pub(crate) fn new() -> Self { + Self { data: None } + } + + pub(crate) fn position_origin(&self) -> Option<(i32, i32)> { + self.data.as_ref().map(|data| (data.x_min, data.y_max)) + } + + pub(crate) fn overlay_svg_offset(&self) -> Option<(f32, f32)> { + self.data.as_ref().map(|data| { + let cell_size = data.resolution as f32 / WORLD_PIXEL_WIDTH; + ( + OVERLAY_OFFSET_X as f32 * cell_size, + OVERLAY_OFFSET_Y as f32 * cell_size, + ) + }) + } + + pub(crate) fn has_data(&self) -> bool { + self.data.is_some() + } + + pub(crate) fn set_background_data( + &mut self, + encoded: String, + width: u16, + height: u16, + total_width: u16, + total_height: u16, + resolution: i32, + x_min: i32, + y_max: i32, + direction: i32, + ) -> bool { + let new_data = NgiotBackgroundData { + encoded, + width, + height, + total_width, + total_height, + resolution, + x_min, + y_max, + direction, + }; + + if self.data.as_ref() == Some(&new_data) { + return false; + } + + self.data = Some(new_data); + true + } + + pub(crate) fn clear_background_data(&mut self) -> bool { + if self.data.is_none() { + return false; + } + + self.data = None; + true + } + + pub(super) fn generate(&self) -> Result> { + let Some(data) = self.data.as_ref() else { + return Ok(None); + }; + + let expected_len = usize::from(data.width) * usize::from(data.height); + if expected_len == 0 { + return Ok(None); + } + + let raster = decompress_base64_lz4_data(&data.encoded, expected_len)?; + if raster.len() != expected_len { + return Err(format!( + "NGIOT raster size mismatch: expected {}, got {}", + expected_len, + raster.len() + ) + .into()); + } + + let mut png_data = Vec::new(); + { + let mut encoder = + Encoder::new(&mut png_data, u32::from(data.width), u32::from(data.height)); + encoder.set_compression(Compression::Balanced); + encoder.set_color(ColorType::Rgba); + encoder.set_depth(BitDepth::Eight); + + let mut writer = encoder.write_header()?; + let rgba = raster_to_rgba(&raster); + writer.write_image_data(&rgba)?; + } + + let left_world = (data.x_min - data.y_max) as f32; + let top_world = (data.y_max - i32::from(data.height)) as f32; + + let left = left_world / WORLD_PIXEL_WIDTH; + let top = -top_world / WORLD_PIXEL_WIDTH; + let width_svg = (f32::from(data.width) * data.resolution as f32) / WORLD_PIXEL_WIDTH; + let height_svg = (f32::from(data.height) * data.resolution as f32) / WORLD_PIXEL_WIDTH; + + let viewbox = ViewBox::from_extents(left, top, left + width_svg, top + height_svg); + + debug!( + "Generated NGIOT raster background: map {}x{} at world ({}, {}) size ({}, {}), direction={}, overlay_offset_cells=({}, {})", + data.width, + data.height, + left, + top, + width_svg, + height_svg, + data.direction, + OVERLAY_OFFSET_X, + OVERLAY_OFFSET_Y, + ); + + Ok(Some((general_purpose::STANDARD.encode(&png_data), viewbox))) + } +} + +#[inline] +fn rgba_for_value(value: u8) -> [u8; 4] { + match value { + 127 => [255, 255, 255, 0], // transparent / outside map + 1 => [237, 237, 237, 255], // light floor + 0 => [210, 210, 210, 255], // alternate floor / unknown floor + 2 => [20, 20, 20, 255], // dark occupied / blocked region + 3 => [83, 132, 178, 255], // observed alternate class + 4 => [165, 92, 47, 255], // observed alternate class + 255 => [220, 30, 30, 255], // marker / sentinel + _ => [255, 0, 255, 255], // unknown class => magenta for visibility + } +} + +fn raster_to_rgba(raster: &[u8]) -> Vec { + let mut rgba = Vec::with_capacity(raster.len() * 4); + for &value in raster { + rgba.extend_from_slice(&rgba_for_value(value)); + } + rgba +} + +#[pymethods] +impl NgiotBackground { + fn set_map_data( + &mut self, + encoded: String, + width: u16, + height: u16, + total_width: u16, + total_height: u16, + resolution: i32, + x_min: i32, + y_max: i32, + direction: i32, + ) -> bool { + self.set_background_data( + encoded, + width, + height, + total_width, + total_height, + resolution, + x_min, + y_max, + direction, + ) + } + + fn clear(&mut self) -> bool { + self.clear_background_data() + } + + fn has_map_data(&self) -> bool { + self.has_data() + } +} \ No newline at end of file diff --git a/src/map/points.rs b/src/map/points.rs index 1233d7261..2e8f72de1 100644 --- a/src/map/points.rs +++ b/src/map/points.rs @@ -1,13 +1,15 @@ use std::fmt::Write as FmtWrite; -use super::{ROUND_TO_DIGITS, RotationAngle, common::round}; -use crate::util::decompress_base64_data; +use super::{PIXEL_WIDTH, ROUND_TO_DIGITS, RotationAngle, calc_point, common::round}; +use crate::util::{decompress_base64_data, decompress_base64_lz4_data}; use log::error; use pyo3::exceptions::PyValueError; use pyo3::prelude::*; use std::error::Error; use svg::node::element::Path; +const LEGACY_TRACE_SCALE: f32 = 0.2; + #[derive(PartialEq)] enum SvgPathCommand { // To means absolute, by means relative @@ -118,33 +120,103 @@ fn extract_trace_points(value: &str) -> Result, Box> process_trace_points(&decompressed_data) } -fn trace_point_to_point(trace_point: &TracePoint, rotation: RotationAngle) -> Point { +fn extract_trace_points_lz4( + value: &str, + expected_len: usize, +) -> Result, Box> { + let decompressed_data = decompress_base64_lz4_data(value, expected_len)?; + process_trace_points(&decompressed_data) +} + +fn trace_point_to_point( + trace_point: &TracePoint, + rotation: RotationAngle, + ngiot_origin: Option<(i32, i32)>, + overlay_svg_offset: Option<(f32, f32)>, +) -> Point { + if let Some((x_min, y_max)) = ngiot_origin { + let world_x = x_min as f32 + trace_point.x as f32; + let world_y = y_max as f32 - trace_point.y as f32; + let mut point = calc_point(world_x, world_y, rotation); + if let Some((dx, dy)) = overlay_svg_offset { + point.x += dx; + point.y += dy; + } + point.connected = trace_point.connected; + return point; + } + let (x, y) = match rotation { RotationAngle::Deg0 => (trace_point.x.into(), trace_point.y.into()), RotationAngle::Deg90 => (trace_point.y.into(), -(trace_point.x as f32)), RotationAngle::Deg180 => (-(trace_point.x as f32), -(trace_point.y as f32)), RotationAngle::Deg270 => (-(trace_point.y as f32), trace_point.x.into()), }; - Point { + let mut point = Point { x, y, connected: trace_point.connected, + }; + if let Some((dx, dy)) = overlay_svg_offset { + point.x += dx; + point.y += dy; } + point } #[pyclass] pub(super) struct TracePoints { trace_points: Vec, + svg_scale: f32, } impl TracePoints { pub(super) fn new() -> Self { Self { trace_points: Vec::new(), + svg_scale: LEGACY_TRACE_SCALE, + } + } + + pub(super) fn add_points( + &mut self, + value: String, + lz4_len: Option, + ) -> Result<(), PyErr> { + let parsed = match lz4_len { + Some(expected_len) => extract_trace_points_lz4(&value, expected_len), + None => extract_trace_points(&value), } + .map_err(|err| { + error!( + "Failed to extract trace points: {err};value:{value};lz4_len:{:?}", + lz4_len + ); + PyValueError::new_err(err.to_string()) + })?; + + self.trace_points.extend(parsed); + Ok(()) + } + + pub(super) fn clear_points(&mut self) { + self.trace_points.clear(); + } + + pub(super) fn use_legacy_trace_scale(&mut self) { + self.svg_scale = LEGACY_TRACE_SCALE; } - pub(super) fn get_path(&self, rotation: RotationAngle) -> Option { + pub(super) fn use_world_trace_scale(&mut self) { + self.svg_scale = 1.0 / PIXEL_WIDTH; + } + + pub(super) fn get_path( + &self, + rotation: RotationAngle, + ngiot_origin: Option<(i32, i32)>, + overlay_svg_offset: Option<(f32, f32)>, + ) -> Option { if self.trace_points.is_empty() { return None; } @@ -153,34 +225,53 @@ impl TracePoints { &self .trace_points .iter() - .map(|tp| trace_point_to_point(tp, rotation)) + .map(|tp| trace_point_to_point(tp, rotation, ngiot_origin, overlay_svg_offset)) .collect::>(), false, false, )?; - Some( - path.set("fill", "none") - .set("stroke", "#fff") - .set("stroke-linejoin", "round") - .set("transform", "scale(0.2-0.2)"), - ) + let path = path + .set("fill", "none") + .set("stroke", "#fff") + .set("stroke-linejoin", "round"); + + Some(if ngiot_origin.is_some() { + path + } else { + path.set( + "transform", + format!("scale({} {})", self.svg_scale, -self.svg_scale), + ) + }) } } #[pymethods] impl TracePoints { - fn add(&mut self, value: String) -> Result<(), PyErr> { - self.trace_points - .extend(extract_trace_points(&value).map_err(|err| { - error!("Failed to extract trace points: {err};value:{value}"); - PyValueError::new_err(err.to_string()) - })?); - Ok(()) + #[pyo3(signature = (value, lz4_len=None))] + fn add(&mut self, value: String, lz4_len: Option) -> Result<(), PyErr> { + self.add_points(value, lz4_len) } fn clear(&mut self) { - self.trace_points.clear(); + self.clear_points(); + } + + fn use_legacy_scale(&mut self) { + self.use_legacy_trace_scale(); + } + + fn use_world_scale(&mut self) { + self.use_world_trace_scale(); + } + + fn set_scale(&mut self, scale: f32) -> Result<(), PyErr> { + if !scale.is_finite() || scale <= 0.0 { + return Err(PyValueError::new_err("scale must be a finite value > 0")); + } + self.svg_scale = scale; + Ok(()) } } @@ -220,226 +311,74 @@ mod tests { assert_eq!(get_path_d_attribute(trace), get_path_d_attribute(expected)); } - #[test] - fn test_get_trace_points_path() { - assert!(TracePoints::new().get_path(RotationAngle::Deg0).is_none()); - } - #[rstest] - #[case(vec![TracePoint{x:16, y:256, connected:true},TracePoint{x:0, y:256, connected:true}], RotationAngle::Deg0, "")] - #[case(vec![ - TracePoint{x:-215, y:-70, connected:true}, - TracePoint{x:-215, y:-70, connected:true}, - TracePoint{x:-212, y:-73, connected:true}, - TracePoint{x:-213, y:-73, connected:true}, - TracePoint{x:-227, y:-72, connected:true}, - TracePoint{x:-227, y:-70, connected:true}, - TracePoint{x:-227, y:-70, connected:true}, - TracePoint{x:-256, y:-69, connected:false}, - TracePoint{x:-260, y:-80, connected:true}, - ], RotationAngle::Deg0, "")] - #[case(vec![TracePoint{x:16, y:256, connected:true},TracePoint{x:0, y:256, connected:true}], RotationAngle::Deg90, "")] - #[case(vec![TracePoint{x:16, y:256, connected:true},TracePoint{x:0, y:256, connected:true}], RotationAngle::Deg180, "")] - #[case(vec![TracePoint{x:16, y:256, connected:true},TracePoint{x:0, y:256, connected:true}], RotationAngle::Deg270, "")] - fn test_get_trace_path( - #[case] points: Vec, - #[case] rotation: RotationAngle, - #[case] expected: String, - ) { + #[case(RotationAngle::Deg0, "M100 200l50 100")] + #[case(RotationAngle::Deg90, "M200-100l100-50")] + #[case(RotationAngle::Deg180, "M-100-200l-50-100")] + #[case(RotationAngle::Deg270, "M-200 100l-100 50")] + fn test_trace_points_rotation(#[case] rotation: RotationAngle, #[case] expected: &str) { let mut trace_points = TracePoints::new(); - trace_points.add_trace_points(points); - let trace = trace_points.get_path(rotation); - assert_eq!(trace.unwrap().to_string(), expected); - } - - #[test] - fn test_extract_trace_points_success() { - let input = "XQAABACvAAAAAAAAAEINQkt4BfqEvt9Pow7YU9KWRVBcSBosIDAOtACCicHy+vmfexxcutQUhqkAPQlBawOeXo/VSrOqF7yhdJ1JPICUs3IhIebU62Qego0vdk8oObiLh3VY/PVkqQyvR4dHxUDzMhX7HAguZVn3yC17+cQ18N4kaydN3LfSUtV/zejrBM4="; - let result = extract_trace_points(input).unwrap(); - let expected = vec![ - TracePoint { - x: 0, - y: 1, - connected: false, - }, - TracePoint { - x: -10, - y: 1, - connected: true, - }, - TracePoint { - x: -7, - y: -8, - connected: true, - }, - TracePoint { - x: 0, - y: -15, - connected: true, - }, - TracePoint { - x: 6, - y: -23, - connected: true, - }, - TracePoint { - x: 11, - y: -32, - connected: true, - }, - TracePoint { - x: 21, - y: -30, - connected: true, - }, - TracePoint { - x: 31, - y: -30, - connected: true, - }, - TracePoint { - x: 40, - y: -34, - connected: true, - }, - TracePoint { - x: 46, - y: -42, - connected: true, - }, - TracePoint { - x: 53, - y: -51, - connected: true, - }, - TracePoint { - x: 52, - y: -61, - connected: true, - }, - TracePoint { - x: 48, - y: -70, - connected: true, - }, - TracePoint { - x: 44, - y: -79, - connected: true, - }, - TracePoint { - x: 34, - y: -83, - connected: true, - }, - TracePoint { - x: 24, - y: -83, - connected: true, - }, - TracePoint { - x: 14, - y: -82, - connected: true, - }, - TracePoint { - x: 6, - y: -76, - connected: true, - }, - TracePoint { - x: 0, - y: -68, - connected: true, - }, - TracePoint { - x: -2, - y: -59, - connected: true, - }, - TracePoint { - x: 0, - y: -48, - connected: true, - }, - TracePoint { - x: 3, - y: -38, - connected: true, - }, + trace_points.add_trace_points(vec![ TracePoint { - x: 11, - y: -32, + x: 100, + y: 200, connected: true, }, TracePoint { - x: 21, - y: -29, - connected: true, - }, - TracePoint { - x: 21, - y: -19, - connected: true, - }, - TracePoint { - x: 14, - y: -12, - connected: true, - }, - TracePoint { - x: 5, - y: -7, - connected: true, - }, - TracePoint { - x: 12, - y: -14, - connected: true, - }, - TracePoint { - x: 21, - y: -18, - connected: true, - }, - TracePoint { - x: 31, - y: -20, - connected: true, - }, - TracePoint { - x: 41, - y: -20, + x: 150, + y: 300, connected: true, }, + ]); + + let path = trace_points.get_path(rotation, None, None).unwrap(); + assert_eq!(path.get_attributes().get("d").unwrap(), expected); + } + + #[test] + fn test_trace_points_ngiot_origin_transform() { + let mut trace_points = TracePoints::new(); + trace_points.add_trace_points(vec![ TracePoint { - x: 51, - y: -24, + x: 100, + y: 200, connected: true, }, TracePoint { - x: 58, - y: -31, + x: 150, + y: 300, connected: true, }, + ]); + + let path = trace_points + .get_path(RotationAngle::Deg0, Some((1000, 2000)), None) + .unwrap(); + + assert_eq!(path.get_attributes().get("d").unwrap(), "M22-36l1-2"); + assert!(path.get_attributes().get("transform").is_none()); + } + + #[test] + fn test_trace_points_legacy_scale_transform_present_without_ngiot_origin() { + let mut trace_points = TracePoints::new(); + trace_points.add_trace_points(vec![ TracePoint { - x: 64, - y: -39, + x: 100, + y: 200, connected: true, }, TracePoint { - x: 70, - y: -47, + x: 150, + y: 300, connected: true, }, - ]; - assert_eq!(result, expected); - } + ]); - #[test] - fn test_process_trace_points_to_short() { - let input: Vec = vec![0x0, 0x0, 0x0, 0x0]; - let result = process_trace_points(&input); - assert!(matches!(result, Err(e) if e.to_string() == "Invalid trace points length")); + let path = trace_points.get_path(RotationAngle::Deg0, None, None).unwrap(); + assert_eq!( + path.get_attributes().get("transform").unwrap(), + "scale(0.2 -0.2)" + ); } -} +} \ No newline at end of file diff --git a/src/map/style.rs b/src/map/style.rs index 04abf194b..0c21da272 100644 --- a/src/map/style.rs +++ b/src/map/style.rs @@ -84,8 +84,10 @@ pub(super) enum CSSClass { RoomColor5, WallBase, + RoomSubset, VirtualWall, NoMoppingWall, + CarpetArea, } pub(super) const ROOM_COLORS: [CSSClass; 6] = [ @@ -207,6 +209,15 @@ fn get_styles() -> &'static HashMap { identifier: ".w path", }, ), + ( + CSSClass::RoomSubset, + CSSEntry { + class_name: "rs", + value: "fill: #deebfb; stroke: #9fb7d8; stroke-width: 0.8", + required_def: None, + identifier: ".rs", + }, + ), ( CSSClass::VirtualWall, css_entry!("v", "stroke: #f00000; fill: #f0000030"), @@ -215,6 +226,15 @@ fn get_styles() -> &'static HashMap { CSSClass::NoMoppingWall, css_entry!("m", "stroke: #ffa500; fill: #ffa50030"), ), + ( + CSSClass::CarpetArea, + CSSEntry { + class_name: "ca", + value: "fill: #1a81ed30; stroke: #1a81ed; stroke-width: 1", + required_def: None, + identifier: ".ca", + }, + ), ]) }) } diff --git a/src/util.rs b/src/util.rs index 737b5c586..46bd3c4ce 100644 --- a/src/util.rs +++ b/src/util.rs @@ -5,6 +5,7 @@ use std::io::{Cursor, Read}; use base64::{Engine as _, engine::general_purpose}; use liblzma::read::XzDecoder; use liblzma::stream::Stream; +use lz4_flex::block; use pyo3::exceptions::PyValueError; use pyo3::prelude::*; @@ -25,13 +26,40 @@ pub fn decompress_base64_data(value: &str) -> Result, Box> { } } +/// Dedicated helper for NGIOT LZ4 block payloads. +/// expected_len should come from mapTraceData.lz4Len. +pub fn decompress_base64_lz4_data( + value: &str, + expected_len: usize, +) -> Result, Box> { + let bytes = general_purpose::STANDARD.decode(value)?; + + if expected_len == 0 { + return Err("Invalid LZ4 expected length: 0".into()); + } + + let mut output = vec![0_u8; expected_len]; + let written = block::decompress_into(&bytes, &mut output) + .map_err(|err| format!("LZ4 decompress failed: {err}"))?; + + if written == 0 { + return Err("LZ4 decompress produced no output".into()); + } + + if written < expected_len { + output.truncate(written); + return Ok(output); + } + + Ok(output) +} + /// Decompress LZMA data, avoiding Vec insert overhead. fn decompress_lzma(bytes: &[u8]) -> Result, Box> { if bytes.len() < 8 { return Err("Invalid 7z compressed data".into()); } - // Form tailored header without repeated inserts (much faster) let mut full = Vec::with_capacity(bytes.len() + 4); full.extend_from_slice(&bytes[..8]); full.extend_from_slice(&[0, 0, 0, 0]); @@ -52,7 +80,7 @@ fn decompress_zstd(bytes: &[u8]) -> Result, Box> { Ok(result) } -/// Decompress base64 decoded compressed string by using lzma or zstd +/// Existing legacy helper: lzma or zstd only. #[pyfunction(name = "decompress_base64_data")] fn python_decompress_base64_data(value: &str) -> Result, PyErr> { decompress_base64_data(value).map_err(|err| { @@ -61,7 +89,19 @@ fn python_decompress_base64_data(value: &str) -> Result, PyErr> { }) } +/// New NGIOT-only helper. +#[pyfunction(name = "decompress_base64_lz4_data")] +fn python_decompress_base64_lz4_data(value: &str, expected_len: usize) -> Result, PyErr> { + decompress_base64_lz4_data(value, expected_len).map_err(|err| { + error!( + "Error decompressing LZ4 base64 data: {err}; expected_len:{expected_len}; value:{value}" + ); + PyValueError::new_err(err.to_string()) + }) +} + pub fn init_module(m: &Bound<'_, PyModule>) -> PyResult<()> { m.add_function(wrap_pyfunction!(python_decompress_base64_data, m)?)?; + m.add_function(wrap_pyfunction!(python_decompress_base64_lz4_data, m)?)?; Ok(()) -} +} \ No newline at end of file diff --git a/tests/commands/ngiot/test_clean.py b/tests/commands/ngiot/test_clean.py new file mode 100644 index 000000000..c7ee3669d --- /dev/null +++ b/tests/commands/ngiot/test_clean.py @@ -0,0 +1,88 @@ +from __future__ import annotations + +from unittest.mock import AsyncMock + +import pytest + +from deebot_client.commands.ngiot.clean import Clean, CleanArea, map_live_state, map_snapshot_state +from deebot_client.exceptions import ApiError +from deebot_client.models import CleanAction, CleanMode, State +from deebot_client.ngiot_client import ( + APN_AREA_CLEAN, + APN_CLEAN_START, + APN_PAUSE, + APN_RESUME, + APN_RETURN_TO_DOCK, +) + + +@pytest.mark.parametrize( + ("payload", "expected"), + [ + ({"workMode": "smart", "pauseSwitch": True}, State.PAUSED), + ({"workMode": "goCharge"}, State.RETURNING), + ({"workMode": "idle", "chargeStatus": True}, State.DOCKED), + ({"workMode": "smart", "chargeStatus": False}, State.CLEANING), + ({"workMode": "unknown"}, State.IDLE), + ], +) +def test_map_snapshot_state(payload: dict[str, object], expected: State) -> None: + assert map_snapshot_state(payload) is expected + + +@pytest.mark.parametrize( + ("payload", "previous", "expected"), + [ + ({"status": "smartclean", "pauseSwitch": True}, None, State.PAUSED), + ({"status": "smartclean", "pauseSwitch": False}, None, State.CLEANING), + ({"status": "idle", "chargeStatus": True}, None, State.DOCKED), + ({"pauseSwitch": False}, State.PAUSED, State.CLEANING), + ({"workMode": "goCharge"}, None, State.RETURNING), + ({"status": "unknown"}, None, None), + ], +) +def test_map_live_state( + payload: dict[str, object], previous: State | None, expected: State | None +) -> None: + assert map_live_state(payload, previous=previous) is expected + + +@pytest.mark.parametrize( + ("action", "expected"), + [ + (CleanAction.START, (APN_CLEAN_START, {"cleanSwitch": True, "cleanMode": "smart"})), + (CleanAction.PAUSE, (APN_PAUSE, {"pauseSwitch": True})), + (CleanAction.RESUME, (APN_RESUME, {"pauseSwitch": False})), + (CleanAction.STOP, (APN_RETURN_TO_DOCK, {"chargeSwitch": True})), + ], +) +def test_clean_get_request(action: CleanAction, expected: tuple[str, dict[str, object]]) -> None: + assert Clean(action)._get_request() == expected + + +@pytest.mark.asyncio +async def test_clean_area_requires_supported_mode() -> None: + command = CleanArea(CleanMode.AUTO, [1]) + with pytest.raises(ApiError, match="room-id cleaning only"): + await command._request_ngiot(None, None) # type: ignore[arg-type] + + +@pytest.mark.asyncio +async def test_clean_area_requires_single_cleaning() -> None: + command = CleanArea(CleanMode.SPOT_AREA, [1], cleanings=2) + with pytest.raises(ApiError, match="repeat count"): + await command._request_ngiot(None, None) # type: ignore[arg-type] + + +@pytest.mark.asyncio +async def test_clean_area_builds_area_request() -> None: + client = type("Client", (), {"request": AsyncMock(return_value={"body": {}})})() + command = CleanArea(CleanMode.SPOT_AREA, [1, 2]) + + await command._request_ngiot(client, None) # type: ignore[arg-type] + + client.request.assert_awaited_once_with( + None, + apn=APN_AREA_CLEAN, + body_data={"cleanSwitch": True, "cleanMode": "area", "cleanValues": [1, 2]}, + ) diff --git a/tests/commands/ngiot/test_map.py b/tests/commands/ngiot/test_map.py new file mode 100644 index 000000000..557705259 --- /dev/null +++ b/tests/commands/ngiot/test_map.py @@ -0,0 +1,79 @@ +from __future__ import annotations + +from unittest.mock import Mock, call + +import pytest + +from deebot_client.commands.ngiot.map import GetCachedMapInfo +from deebot_client.event_bus import EventBus +from deebot_client.events.map import CachedMapInfoEvent, Map, MapSetType +from deebot_client.hardware import get_static_device_info +from deebot_client.message import HandlingResult, HandlingState +from deebot_client.rs.map import RotationAngle + + +@pytest.mark.asyncio +async def test_getCachedMapInfo_bootstraps_map_sets() -> None: + static_device_info = await get_static_device_info("eyfj07") + assert static_device_info is not None + assert static_device_info.capabilities.map is not None + + event_bus = Mock(spec=EventBus) + event_bus.capabilities = static_device_info.capabilities + + response = { + "ret": "ok", + "resp": { + "body": { + "data": { + "mapInfos": [ + { + "mapId": "3", + "name": "Home", + "status": 1, + "angle": 90, + }, + { + "mapId": "4", + "name": "Upstairs", + "status": 0, + "angle": 0, + }, + ] + } + } + }, + } + + result = GetCachedMapInfo()._handle_response(event_bus, response) + + assert result == HandlingResult( + HandlingState.SUCCESS, + {"map_id": "3"}, + [], + ) + event_bus.notify.assert_has_calls( + [ + call( + CachedMapInfoEvent( + maps={ + Map( + id="3", + name="Home", + using=True, + built=True, + angle=RotationAngle.DEG_90, + ), + Map( + id="4", + name="Upstairs", + using=False, + built=True, + angle=RotationAngle.DEG_0, + ), + } + ) + ) + ] + ) + assert event_bus.notify.call_count == 1 \ No newline at end of file diff --git a/tests/commands/ngiot/test_stats.py b/tests/commands/ngiot/test_stats.py new file mode 100644 index 000000000..22a9f9c5d --- /dev/null +++ b/tests/commands/ngiot/test_stats.py @@ -0,0 +1,66 @@ +from __future__ import annotations + +from unittest.mock import Mock + +from deebot_client.commands.ngiot.stats import GetReportStats, GetStats, GetTotalStats +from deebot_client.event_bus import EventBus +from deebot_client.events import CleanJobStatus, ReportStatsEvent, StatsEvent, TotalStatsEvent +from deebot_client.message import HandlingResult, HandlingState + + +def test_get_stats_converts_minutes_to_seconds() -> None: + event_bus = Mock(spec_set=EventBus) + + result = GetStats._handle_body_data_dict( + event_bus, + {"cleanArea": "25", "cleanTime": "12", "workMode": "smart"}, + ) + + assert result == HandlingResult(HandlingState.SUCCESS) + event_bus.notify.assert_called_once_with( + StatsEvent(area=25, time=720, type="smart") + ) + + +def test_get_report_stats_extracts_cleaning_id() -> None: + event_bus = Mock(spec_set=EventBus) + + result = GetReportStats._handle_body_data_dict( + event_bus, + { + "cleanArea": 15, + "cleanTime": 7, + "workMode": "area", + "cleanLogReport": {"cid": "job-123"}, + }, + ) + + assert result == HandlingResult(HandlingState.SUCCESS) + event_bus.notify.assert_called_once_with( + ReportStatsEvent( + area=15, + time=420, + type="area", + cleaning_id="job-123", + status=CleanJobStatus.NO_STATUS, + content=[], + ) + ) + + +def test_get_total_stats_uses_total_fields_and_fallbacks() -> None: + event_bus = Mock(spec_set=EventBus) + + result = GetTotalStats._handle_body_data_dict( + event_bus, + { + "cleanAreaTotal": "101", + "cleanTime": 8, + "cleanCount": "9", + }, + ) + + assert result == HandlingResult(HandlingState.SUCCESS) + event_bus.notify.assert_called_once_with( + TotalStatsEvent(area=101, time=480, cleanings=9) + ) diff --git a/tests/conftest.py b/tests/conftest.py index d1f5d24b2..071bd46be 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -27,6 +27,7 @@ MqttConfiguration, create_mqtt_config as create_config_mqtt, ) +from deebot_client.ngiot_map_state import NgiotMapStateStore from .fixtures.mqtt_server import MqttServer @@ -171,16 +172,23 @@ def execute_mock() -> AsyncMock: @pytest.fixture def event_bus(execute_mock: AsyncMock, device_info: DeviceInfo) -> EventBus: - return EventBus(execute_mock, device_info.static.capabilities) + bus = EventBus(execute_mock, device_info.static.capabilities) + store = NgiotMapStateStore() + bus._ngiot_map_state_store = store + bus.ngiot_map_state = store + return bus @pytest.fixture def event_bus_mock(event_bus: EventBus) -> Mock: - return Mock(spec_set=EventBus, wraps=event_bus) + mock = Mock(spec_set=event_bus, wraps=event_bus) + mock._ngiot_map_state_store = event_bus._ngiot_map_state_store + mock.ngiot_map_state = event_bus.ngiot_map_state + return mock @pytest.fixture(name="caplog") def caplog_fixture(caplog: pytest.LogCaptureFixture) -> pytest.LogCaptureFixture: """Set log level to debug for tests using the caplog fixture.""" caplog.set_level(logging.DEBUG) - return caplog + return caplog \ No newline at end of file diff --git a/tests/hardware/test_init.py b/tests/hardware/test_init.py index df3ea301c..8c4c695af 100644 --- a/tests/hardware/test_init.py +++ b/tests/hardware/test_init.py @@ -91,13 +91,33 @@ if TYPE_CHECKING: from deebot_client.command import Command from deebot_client.events.base import Event - +from deebot_client.commands.ngiot.battery import GetBattery as GetNgiotBattery +from deebot_client.commands.ngiot.child_lock import GetChildLock as GetNgiotChildLock +from deebot_client.commands.ngiot.clean import GetCleanInfo as GetNgiotCleanInfo +from deebot_client.commands.ngiot.fan_speed import GetFanSpeed as GetNgiotFanSpeed +from deebot_client.commands.ngiot.life_span import GetLifeSpan as GetNgiotLifeSpan +from deebot_client.commands.ngiot.error import GetError as GetNgiotError +from deebot_client.commands.ngiot.map import ( + GetCachedMapInfo as GetNgiotCachedMapInfo, + GetMajorMap as GetNgiotMajorMap, + GetMapTrace as GetNgiotMapTrace, +) +from deebot_client.commands.ngiot.network import GetNetInfo as GetNgiotNetInfo +from deebot_client.commands.ngiot.pos import GetPos as GetNgiotPos +from deebot_client.commands.ngiot.stats import ( + GetReportStats as GetNgiotReportStats, + GetStats as GetNgiotStats, + GetTotalStats as GetNgiotTotalStats, +) +from deebot_client.commands.ngiot.volume import GetVolume as GetNgiotVolume +from deebot_client.hardware.eyfj07 import get_device_info as get_eyfj07_info @pytest.mark.parametrize( ("class_", "expected"), [ ("not_specified", None), ("yna5xi", get_yna5xi_info()), + ("eyfj07", get_eyfj07_info()), ], ) async def test_get_static_device_info( @@ -244,8 +264,32 @@ async def test_get_static_device_info( WaterAmountEvent: [GetWaterInfo()], }, ), + ( + "eyfj07", + { + AvailabilityEvent: [GetNgiotBattery(is_available_check=True)], + BatteryEvent: [GetNgiotBattery()], + CachedMapInfoEvent: [GetNgiotCachedMapInfo()], + ChildLockEvent: [GetNgiotChildLock()], + CustomCommandEvent: [], + ErrorEvent: [GetNgiotError()], + FanSpeedEvent: [GetNgiotFanSpeed()], + LifeSpanEvent: [GetNgiotLifeSpan()], + MajorMapEvent: [GetNgiotMajorMap()], + MapChangedEvent: [], + MapTraceEvent: [GetNgiotMapTrace()], + NetworkInfoEvent: [GetNgiotNetInfo()], + PositionsEvent: [GetNgiotPos()], + ReportStatsEvent: [GetNgiotReportStats()], + RoomsEvent: [GetNgiotMajorMap()], + StateEvent: [GetNgiotCleanInfo()], + StatsEvent: [GetNgiotStats()], + TotalStatsEvent: [GetNgiotTotalStats()], + VolumeEvent: [GetNgiotVolume()], + }, + ), ], - ids=["5xu9h3", "itk04l", "yna5xi", "p95mgv"], + ids=["5xu9h3", "itk04l", "yna5xi", "p95mgv", "eyfj07"], ) async def test_capabilities_event_extraction( class_: str, expected: dict[type[Event], list[Command]] @@ -276,4 +320,4 @@ async def test_all_models_loaded() -> None: device_info = await hardware.get_static_device_info(module_name) assert isinstance(device_info, StaticDeviceInfo), ( f"Failed to load device info for {module_name}" - ) + ) \ No newline at end of file diff --git a/tests/helpers/__init__.py b/tests/helpers/__init__.py index fb4e5e7ea..670c0806b 100644 --- a/tests/helpers/__init__.py +++ b/tests/helpers/__init__.py @@ -57,10 +57,11 @@ def mock_static_device_info( events = {} mock = Mock(spec_set=Capabilities) + mock.map = None def get_refresh_commands(event: type[Event]) -> list[Command]: return events.get(event, []) mock.get_refresh_commands.side_effect = get_refresh_commands - return StaticDeviceInfo(data_type, mock) + return StaticDeviceInfo(data_type, mock) \ No newline at end of file diff --git a/tests/messages/json/test_ngiot.py b/tests/messages/json/test_ngiot.py new file mode 100644 index 000000000..a4432c303 --- /dev/null +++ b/tests/messages/json/test_ngiot.py @@ -0,0 +1,130 @@ +from __future__ import annotations + +from unittest.mock import Mock, patch + +from deebot_client.event_bus import EventBus +from deebot_client.events import StateEvent +from deebot_client.message import HandlingResult, HandlingState +from deebot_client.messages import get_message +from deebot_client.messages.json.ngiot import OnNgiotMapEvent, OnNgiotStatusEvent +from deebot_client.models import State + + +def test_get_message_resolves_ngiot_numeric_topics() -> None: + from deebot_client.const import DataType + + assert get_message("10000", DataType.JSON) is OnNgiotStatusEvent + assert get_message("30000", DataType.JSON) is OnNgiotMapEvent + + +def test_on_ngiot_map_event_dispatches_major_and_trace_handlers() -> None: + event_bus = Mock(spec_set=EventBus) + payload = {"body": {"data": {"mapData": {}, "mapTraceData": {}}}} + + with ( + patch( + "deebot_client.messages.json.ngiot.GetMajorMap._handle_body_data_dict", + return_value=HandlingResult.success(), + ) as handle_major, + patch( + "deebot_client.messages.json.ngiot.GetMapTrace._handle_body_data_dict", + return_value=HandlingResult.success(), + ) as handle_trace, + ): + result = OnNgiotMapEvent.handle(event_bus, payload) + + assert result.state == HandlingState.SUCCESS + handle_major.assert_called_once_with(event_bus, payload["body"]["data"]) + handle_trace.assert_called_once_with(event_bus, payload["body"]["data"]) + + +def test_on_ngiot_map_event_dispatches_pos_only_handler() -> None: + event_bus = Mock(spec_set=EventBus) + payload = {"body": {"data": {"pos": {}}}} + + with patch( + "deebot_client.messages.json.ngiot.GetPos._handle_body_data_dict", + return_value=HandlingResult.success(), + ) as handle_pos: + result = OnNgiotMapEvent.handle(event_bus, payload) + + assert result.state == HandlingState.SUCCESS + handle_pos.assert_called_once_with(event_bus, payload["body"]["data"]) + + +def test_on_ngiot_map_event_accepts_minor_only_payload() -> None: + event_bus = Mock(spec_set=EventBus) + + result = OnNgiotMapEvent.handle( + event_bus, + {"body": {"data": {"mapMinorData": {"piece": 1}}}}, + ) + + assert result.state == HandlingState.SUCCESS + event_bus.notify.assert_not_called() + + +def test_on_ngiot_status_event_dispatches_live_state() -> None: + event_bus = Mock(spec_set=EventBus) + event_bus.get_last_event.return_value = StateEvent(State.PAUSED) + payload = { + "body": { + "data": { + "battery": 81, + "cleanArea": 15, + "cleanTime": 12, + "childLock": True, + "volume": 4, + "status": "smartclean", + "pauseSwitch": False, + } + } + } + + with ( + patch( + "deebot_client.messages.json.ngiot.GetBattery._handle_body_data_dict", + return_value=HandlingResult.success(), + ) as handle_battery, + patch( + "deebot_client.messages.json.ngiot.GetStats._handle_body_data_dict", + return_value=HandlingResult.success(), + ) as handle_stats, + patch( + "deebot_client.messages.json.ngiot.GetChildLock._handle_body_data_dict", + return_value=HandlingResult.success(), + ) as handle_child_lock, + patch( + "deebot_client.messages.json.ngiot.GetVolume._handle_body_data_dict", + return_value=HandlingResult.success(), + ) as handle_volume, + ): + result = OnNgiotStatusEvent.handle(event_bus, payload) + + assert result.state == HandlingState.SUCCESS + handle_battery.assert_called_once_with(event_bus, payload["body"]["data"]) + handle_stats.assert_called_once_with(event_bus, payload["body"]["data"]) + handle_child_lock.assert_called_once_with(event_bus, payload["body"]["data"]) + handle_volume.assert_called_once_with(event_bus, payload["body"]["data"]) + event_bus.notify.assert_called_once_with(StateEvent(State.CLEANING)) + + +def test_on_ngiot_status_event_falls_back_to_clean_info_when_state_unresolved() -> None: + event_bus = Mock(spec_set=EventBus) + event_bus.get_last_event.return_value = None + payload = { + "body": { + "data": { + "status": "unknown", + } + } + } + + with patch( + "deebot_client.messages.json.ngiot.GetCleanInfo._handle_body_data_dict", + return_value=HandlingResult.success(), + ) as handle_clean_info: + result = OnNgiotStatusEvent.handle(event_bus, payload) + + assert result.state == HandlingState.SUCCESS + handle_clean_info.assert_called_once_with(event_bus, payload["body"]["data"]) diff --git a/tests/rs/test_map.py b/tests/rs/test_map.py index 494337ebe..0347284df 100644 --- a/tests/rs/test_map.py +++ b/tests/rs/test_map.py @@ -14,12 +14,12 @@ ( "invalid_base64", "Invalid symbol 95, offset 7.", - "Failed to extract trace points: Invalid symbol 95, offset 7.;value:invalid_base64", + "Failed to extract trace points: Invalid symbol 95, offset 7.;value:invalid_base64;lz4_len:None", ), ( "", "Invalid 7z compressed data", - "Failed to extract trace points: Invalid 7z compressed data;value:", + "Failed to extract trace points: Invalid 7z compressed data;value:;lz4_len:None", ), ], ) @@ -100,4 +100,4 @@ def test_PositionType_eq() -> None: assert PositionType.CHARGER == PositionType.CHARGER assert PositionType.CHARGER == 1 - assert PositionType.DEEBOT != PositionType.CHARGER + assert PositionType.DEEBOT != PositionType.CHARGER \ No newline at end of file diff --git a/tests/rs/test_util.py b/tests/rs/test_util.py index 70e49a5ff..e92e79024 100644 --- a/tests/rs/test_util.py +++ b/tests/rs/test_util.py @@ -3,12 +3,16 @@ from __future__ import annotations import base64 +import hashlib import lzma from typing import TYPE_CHECKING import pytest -from deebot_client.rs.util import decompress_base64_data +from deebot_client.rs.util import ( + decompress_base64_data, + decompress_base64_lz4_data, +) if TYPE_CHECKING: from pytest_codspeed import BenchmarkFixture @@ -31,7 +35,7 @@ ), ( "XQAABACvAAAAAAAAAEINQkt4BfqEvt9Pow7YU9KWRVBcSBosIDAOtACCicHy+vmfexxcutQUhqkAPQlBawOeXo/VSrOqF7yhdJ1JPICUs3IhIebU62Qego0vdk8oObiLh3VY/PVkqQyvR4dHxUDzMhX7HAguZVn3yC17+cQ18N4kaydN3LfSUtV/zejrBM4=", - b'\x00\x00\x01\x00\x98\xf6\xff\x01\x00\x18\xf9\xff\xf8\xff@\x00\x00\xf1\xff@\x06\x00\xe9\xff@\x0b\x00\xe0\xff@\x15\x00\xe2\xff@\x1f\x00\xe2\xff@(\x00\xde\xff@.\x00\xd6\xff@5\x00\xcd\xff@4\x00\xc3\xff@0\x00\xba\xff@,\x00\xb1\xff@"\x00\xad\xff@\x18\x00\xad\xff@\x0e\x00\xae\xff@\x06\x00\xb4\xff@\x00\x00\xbc\xff@\xfe\xff\xc5\xff@\x00\x00\xd0\xff@\x03\x00\xda\xff@\x0b\x00\xe0\xff@\x15\x00\xe3\xff@\x15\x00\xed\xffH\x0e\x00\xf4\xffH\x05\x00\xf9\xffH\x0c\x00\xf2\xffH\x15\x00\xee\xffH\x1f\x00\xec\xffH)\x00\xec\xffH3\x00\xe8\xffH:\x00\xe1\xffH@\x00\xd9\xff@F\x00\xd1\xff@', + b"\x00\x00\x01\x00\x98\xf6\xff\x01\x00\x18\xf9\xff\xf8\xff@\x00\x00\xf1\xff@\x06\x00\xe9\xff@\x0b\x00\xe0\xff@\x15\x00\xe2\xff@\x1f\x00\xe2\xff@(\x00\xde\xff@.\x00\xd6\xff@5\x00\xcd\xff@4\x00\xc3\xff@0\x00\xba\xff@,\x00\xb1\xff@\"\x00\xad\xff@\x18\x00\xad\xff@\x0e\x00\xae\xff@\x06\x00\xb4\xff@\x00\x00\xbc\xff@\xfe\xff\xc5\xff@\x00\x00\xd0\xff@\x03\x00\xda\xff@\x0b\x00\xe0\xff@\x15\x00\xe3\xff@\x15\x00\xed\xffH\x0e\x00\xf4\xffH\x05\x00\xf9\xffH\x0c\x00\xf2\xffH\x15\x00\xee\xffH\x1f\x00\xec\xffH)\x00\xec\xffH3\x00\xe8\xffH:\x00\xe1\xffH@\x00\xd9\xff@F\x00\xd1\xff@", ), ], ids=["1", "2", "3", "4"], @@ -40,11 +44,8 @@ def test_decompress_base64_data_lzma( benchmark: BenchmarkFixture, value: str, expected: bytes ) -> None: """Test decompress_base64_data function with lzma base64 values.""" - # Benchmark only the production function result = benchmark(decompress_base64_data, value) assert result == expected - - # Verify that the old python function is producing the same result assert _decompress_7z_base64_data_python(value) == result @@ -62,11 +63,25 @@ def test_decompress_base64_data_zstd( benchmark: BenchmarkFixture, value: str, expected: bytes ) -> None: """Test decompress_base64_data function with zstd base64 values.""" - # Benchmark only the production function result = benchmark(decompress_base64_data, value) assert result == expected +_REAL_NGIOT_LZ4_MAP = "H38BAP//8BUAAQAPCwNzIwABAQAFmAAPmgBrB5AAA5kAHwEqAV0AAwEAAgABdwAFkAAGAgAAHAAPAgAAAS8ADwIANghhAAJ7AAMHAQINAA8CAA8DLwAPAgAoCIMABIUABwIAA2EADxoAAA8CAAIDLwAPAgAcAHoABQIAD1gAAgUCAARhAA+yAQAPAgABD7IBIw9dABQAAgAPkAAWBSkADxwBHAACAA+QAEAEjwAPkAAjAKYBD5UAEwRoAA8uAA0EKAAPkAAiAV8AAmIAAJoADwIACQRoAAIuAA8CAAcEKAAPkAAiFwCGAA9pAAwP9wIUAx8FAC4AAO8ADwIAHA+EAAwIAgAHugAPAgAKAIgAAI4ABSgBDwIAGQmQAA9yABUPAgAMAI0ABSgBDwIAHA+QAEEDAgAPIAEnBJUAAdcADwIAKARIAAICAAC4AQ8CAB4BQwADYQEDkAAEDgAPAgAfAU0ABgIADyEBIQR8AgGLAAEFAAAgAwtZAA8CABUHNwAACwAPkAAhAQIABHAABJAAAggADwIAHwCFAAMCAABDAA+QADAA9AEnAH9YAA8CABoHOAAACwAPkAAjADoAAwIAApAACRMADwIAGgpLAAD5AQ8CAB8AWQEEAgAE0QAHWAAPAgAaBzgALwABkAAjBY8AABwBD5AALQcCAACIBgRPAA8CABkE0wABjwAEQQAEFQAPAgAeCjkABE8ADwIAGQNCAACPAAEGAAE8AADdAA+MAB4OAgAPkAAnA/MDA48AACsIAKEAAKMADwIAGwjHAABAAAFMAA8CABoCQgADkAALggIPkAAZAkgAAgYABQIADyEBIAFCABV/kAAJ0wABHAAPAgATD7sAAw+QAC4KQwAPkAAmAKsBAD0AD5AAJBAAIgYCAgAVASEBABAAAKEADwIAEwosAAGMAA3/AQ8CABIG5QMAxgIHRAAAVAANAgAPFQABAQsBAL4AAAIAAjIABCcADyABIAkCAABIAAREAAAMAA0CAARxAAgCAACHAAMYAAMHAAEcAQACAARMAA8CACUAhAAERAAADAANAgAEcQAMAgAQABsAASkBAQIABDwABEwADwIAJQBIAABYAABAAAAMAA8CABYAkAAHLQAEPAAPkAAxAG8BD5AAsgQVAQcCAA8gASsFAAIBRwAPXwAADwIABwDlAQ9cAgABSQAPAgAmBYIAAFwABF4AAAoABAwADwIABQHJAADxCAwhAA+QADQIAgAAqgIPkAANAQIAAZUACwIAD5AAQQ9oAAEPAgAVD5AAQA8CACoPkADGD6UBJAD7AgUCAATJBA8CACYHIAMPkAAlC94BD7ABbQfSAAECAA+QAP+rBekCD8YBFQz+AQ8CAB8MQgAGzREAAgATAgEAD5AAawOFAAcCAAqmAAYCAAwZBAzeAA8CAAAPCAQMDEIAAHoAB4kAAwIAAIIAAwsACAIADTcBDE4ADwIAAA+QABMAcQABKgAAuQwIcwAPAgAKD5AAAQFIAAFLAA8CACIBPAAIPwAPkAAtA8MBDyoBEA8CAAMBJgYHAgAPkABtDAIAD5AAuwVzAw+wASYMAgAPIAEqBgIAAJUCDFsADwIAIg+QADQDyw8PAgAvD5AAyA+yATIPIAE3D9UCMw+QADgPIAF8D/MDMw86BwMPAgADDcAGCQIAD5AAMz8BAf99AAMPAgAAD5AApwAfBQECAADxAA8CABwPIAFBBEoBD8QBIQ+QADcBEgEBzQgPkABzA98CAHERBgsADwIAHg8gATcGhQAPIAGjAd0BDWMCDyABCBoA9QME9gAPAgADBR4AAGIAAyMAAAsADwIABwEeAA2SAA+QAAgLkQAFcgAPAgAKAeUDAyIAAgIBBwIACZ0AAZAADZIAD5AACgm+BQ+QABMWf48ACAIAARYAA4IAAwIABQ4ABQkAAAIAD5AADBYAtwEPkABvArMBBLYBD5AAEwGeAQgaAQICAA8gAQEBKwAAFhMAGAEEAgAPIAEJAJAABI4AD5AAFQ8CAAECbAAMsAEBAgAArAEABAAFAgAPkAAOBZEADyABEw8CAAEEbQAJeQACAgAEJAEAAgAALQEHcAAPAgACArEJAZQAD5AAJQBrAAICAA+QAAAArwEEFwACbhwBogADSQMACQANAgABtgEAAgAP+wABDwIADwF0AAICAAR5AAECAABSAAkRAAGZAAT1AAECARQCDgAHAgABGQAPGAAADwIAFQF8AA8CAAEAVAABlAAIAgAG+wYBrgEAJAAKAgAAHwEBMQAKFwAPAgAVBYkAAUQADQIADEcABsMhAAcBD6oTAgGoAA+qAAEPAgAQBY4ABoYeAwIAAvkAAgYABwIABZAAAHMAAQIABx0AAgIAARYAD5AAIwaPAAJLAAcCAAKKAAIGAAcCAAeQAANpAwcdAAICAACNAA9aAgMPAgANBq0CDioAAooADpEBA0ACIAACKhMAtAEcAjwAAY4ADyABIgWQAAxTAA8CAAgEkAABIwABkAAMkQAAjQAPIAEiAN4EAQIAAVwADwIABQIpAQUCAAOQAAICAAKQAAuRAAGOAA8gASEBjAABtwYPAgAYCnMDAHkAAAIACEEAARAAD5AAHwGLAAACAAhMAA0CAAseAQwCAAB8AAECAAcZAAA0BAC9EAcTAA8CABEAiwABAgAPLQARDwIABQGMAAECAA8iAAUPAgAVAMYAAgIADzIAFQ8CAAIBjAABAgAPHwACDwIAXgGLAAICAA98AF4PAgACAowAAQIADyAAAg8CAF0BigAEAgAPfQBdDwIAAQSMAAACAA8gAAEPAgBdAIgABgIAD34AXQ4CAAaMAAECAA4hAA8CAF0BhwAFAgAPfgBdDgIABYsAAAIAD88ELA8CADEAhwAFAgAPYAYMDwIAUgWNAA8fAXIKkAAAEAAPAgBtABkBAwIAD68BcQiRAA+QAP8EE3+uAQ9+BHEBAgAAiwABAgABDgAPAgBwAY0ADdoLDwIAZwGNAAACAA/RAnMGQQIL/AwPAgBmABwBAQIAD4IAZgoCAAGMAAECAAoYAA8CAGUBiwABAgAPggBlCwIAAYwAAQIACxkADwIAZAGLAAICAA+CAGQLAgACjAAAAgALGQAPAgBkAIoAAwIAD4IAZAoCAAOMAAACAAoZAA8CAGUAigADAgAPgwBlCQIAA4wAAQIACRkADwIAZQGKAAICAA+DAGUJAgACiwACAgAJGQAPAgBkAooAAgIAD4MAZAoCAAKLAAICAAoaAA8CAGMCigADAgAPgwBjCgIAA4sAAQIAChoADwIAYwGJAAQCAA+DAGMJAgAEiwACAgAJGwAPAgBjAokAAwIAD4MAYwkCAAOKAAMCAAkbAA8CAGIDiQADAgAPgwBiCgIAA4oAAwIAChwADwIAYgOKAAMCAA+DAGIJAgADiQADAgAJGwAPAgBiAIQAA40AAAIAD4QAYgkCAAGiDgCLAAECAAkbAA8CAGcBjAABAgAPhABnCgIAAY0AD9APdwWRAAALAA8CAHIPiwB4BAIAUH9/f39/" +_REAL_NGIOT_LZ4_LEN = 21168 +_REAL_NGIOT_LZ4_SHA256 = "11ef0e79e46c3d7617b2d1a4f3688159a94b7aea2fd24c2bee2474e0a3ea18af" + + +def test_decompress_base64_lz4_data_real_payload() -> None: + """Test dedicated NGIOT LZ4 helper against an observed map payload.""" + result = decompress_base64_lz4_data(_REAL_NGIOT_LZ4_MAP, _REAL_NGIOT_LZ4_LEN) + + assert len(result) == _REAL_NGIOT_LZ4_LEN + assert hashlib.sha256(result).hexdigest() == _REAL_NGIOT_LZ4_SHA256 + assert result[:16] == b"\x7f" * 16 + assert set(result).issuperset({0, 1, 2, 127, 255}) + + @pytest.mark.parametrize( ("value", "expected_error"), [ @@ -87,14 +102,30 @@ def test_decompress_base64_data_zstd( def test_decompress_base64_data_errors(value: str, expected_error: str) -> None: """Test decompress_base64_data function.""" with pytest.raises(ValueError, match=expected_error): - assert decompress_base64_data(value) + decompress_base64_data(value) + + +@pytest.mark.parametrize( + ("value", "expected_len", "expected_error"), + [ + ("@@not-base64@@", 10, "Invalid symbol"), + (base64.b64encode(b"abc").decode(), 0, "Invalid LZ4 expected length: 0"), + (base64.b64encode(b"abc").decode(), 10, "LZ4 decompress failed"), + ], +) +def test_decompress_base64_lz4_data_errors( + value: str, expected_len: int, expected_error: str +) -> None: + """Test NGIOT LZ4 helper failure cases.""" + with pytest.raises(ValueError, match=expected_error): + decompress_base64_lz4_data(value, expected_len) + def _decompress_7z_base64_data_python(data: str) -> bytes: """Decompress base64 decoded 7z compressed string.""" final_array = bytearray() - # Decode Base64 decoded = base64.b64decode(data) for i, idx in enumerate(decoded): @@ -103,4 +134,4 @@ def _decompress_7z_base64_data_python(data: str) -> bytes: final_array.append(idx) dec = lzma.LZMADecompressor(lzma.FORMAT_AUTO, None, None) - return dec.decompress(final_array) + return dec.decompress(final_array) \ No newline at end of file diff --git a/tests/test_map.py b/tests/test_map.py index 05389a2e9..0d0c67abb 100644 --- a/tests/test_map.py +++ b/tests/test_map.py @@ -102,7 +102,7 @@ async def test_Map_subscriptions( num_unsubs = len(calls) + 1 assert len(map_obj._unsubscribers) == num_unsubs - async def on_change() -> None: + async def on_change(_: MapChangedEvent) -> None: pass event_unsub = event_bus_mock.subscribe(MapChangedEvent, on_change) @@ -248,4 +248,4 @@ async def test_fn() -> str | None: def svg_map() -> str | None: return event_loop.run_until_complete(test_fn()) - assert svg_map == snapshot + assert svg_map == snapshot \ No newline at end of file diff --git a/tests/test_message.py b/tests/test_message.py index 239ece8e3..fb23caf57 100644 --- a/tests/test_message.py +++ b/tests/test_message.py @@ -59,11 +59,11 @@ def test_MessageStr_should_error_on_unknown_types() -> None: event_bus = Mock(spec_set=EventBus) result = TestMessageStr.handle(event_bus, {"key": "value"}) - assert result.state == HandlingState.ERROR + assert result.state == HandlingState.ANALYSE_LOGGED def test_WronglyImplementedMessage() -> None: event_bus = Mock(spec_set=EventBus) result = WronglyImplementedMessage.handle(event_bus, {}) - assert result.state == HandlingState.ERROR + assert result.state == HandlingState.ERROR \ No newline at end of file diff --git a/tests/test_mqtt_client_ngiot.py b/tests/test_mqtt_client_ngiot.py new file mode 100644 index 000000000..18d01f225 --- /dev/null +++ b/tests/test_mqtt_client_ngiot.py @@ -0,0 +1,79 @@ +from __future__ import annotations + +from unittest.mock import Mock + +from unittest.mock import AsyncMock + +from deebot_client.mqtt_client import MqttClient, MqttConfiguration, SubscriberInfo, _get_topics + + +def test_get_topics_adds_ngiot_user_topics(device_info) -> None: + topics = _get_topics(device_info, "user-123") + + assert f"iot/atr/+/user-123/{device_info.api['class']}/{device_info.api['did']}/{device_info.static.data_type}" in topics + assert f"iot/atr/+/user-123/{device_info.api['class']}/{device_info.api['resource']}/{device_info.static.data_type}" in topics + assert len(topics) == len(set(topics)) + + +def test_topic_matches_device_supports_legacy_and_ngiot_shapes(device_info) -> None: + legacy = [ + "iot", + "atr", + "onBattery", + device_info.api["did"], + device_info.api["class"], + device_info.api["resource"], + str(device_info.static.data_type), + ] + ngiot_did = [ + "iot", + "atr", + "10000", + "user-123", + device_info.api["class"], + device_info.api["did"], + str(device_info.static.data_type), + ] + ngiot_resource = [ + "iot", + "atr", + "30000", + "user-123", + device_info.api["class"], + device_info.api["resource"], + str(device_info.static.data_type), + ] + + assert MqttClient._topic_matches_device(legacy, device_info) is True + assert MqttClient._topic_matches_device(ngiot_did, device_info) is True + assert MqttClient._topic_matches_device(ngiot_resource, device_info) is True + assert MqttClient._topic_matches_device(ngiot_resource[:-1] + ["x"], device_info) is False + + +def test_handle_atr_routes_numeric_ngiot_topics(authenticator, device_info, event_bus) -> None: + config = MqttConfiguration(hostname="localhost", port=1883, ssl_context=None, device_id="test-device") + authenticator.subscribe = AsyncMock() + client = MqttClient(config, authenticator) + callback = Mock() + client._subscriptions[device_info.api["did"]] = SubscriberInfo( + device_info=device_info, + events=event_bus, + callback=callback, + ) + + client._handle_atr( + [ + "iot", + "atr", + "30000", + "user-123", + device_info.api["class"], + device_info.api["resource"], + str(device_info.static.data_type), + ], + b'{"body":{"data":{"status":"smartclean"}}}', + ) + + callback.assert_called_once_with( + "30000", b'{"body":{"data":{"status":"smartclean"}}}' + ) diff --git a/tests/test_ngiot_client.py b/tests/test_ngiot_client.py new file mode 100644 index 000000000..0c44f483c --- /dev/null +++ b/tests/test_ngiot_client.py @@ -0,0 +1,156 @@ +from __future__ import annotations + +from collections.abc import Mapping +from unittest.mock import AsyncMock, Mock + +import pytest +from aiohttp import ClientResponseError, RequestInfo +from multidict import CIMultiDictProxy, CIMultiDict +from yarl import URL + +from deebot_client.exceptions import ApiError +from deebot_client.ngiot_client import NgiotClient, NgiotDeviceIdentity + + +def _request_info() -> RequestInfo: + return RequestInfo( + url=URL("https://api.example.com/api/iot/endpoint/control"), + method="POST", + headers=CIMultiDictProxy(CIMultiDict()), + real_url=URL("https://api.example.com/api/iot/endpoint/control"), + ) + + +@pytest.fixture +def sst_authenticator() -> AsyncMock: + auth = AsyncMock() + auth.get_token.return_value = "sst-token" + return auth + + +@pytest.fixture +def ngiot_client(sst_authenticator: AsyncMock) -> NgiotClient: + return NgiotClient(Mock(), sst_authenticator) + + +def test_normalize_device_uses_override_host_and_keeps_service_host_as_fallback( + ngiot_client: NgiotClient, +) -> None: + ngiot_client._override_control_host = "override.example.com" + + identity = ngiot_client._normalize_device( + { + "did": "did-1", + "class": "eyfj07", + "resource": "res-1", + "service": {"mqs": "service.example.com"}, + } + ) + + assert identity == NgiotDeviceIdentity( + did="did-1", + class_id="eyfj07", + resource="res-1", + control_host="override.example.com", + fallback_control_host="service.example.com", + ) + assert identity.base_url == "https://override.example.com" + + +def test_normalize_device_requires_control_host(ngiot_client: NgiotClient) -> None: + with pytest.raises(ApiError, match="Missing NGIOT control host"): + ngiot_client._normalize_device( + { + "did": "did-1", + "class": "eyfj07", + "resource": "res-1", + } + ) + + +@pytest.mark.parametrize( + ("apn", "body_data", "expected"), + [ + ( + "10001", + {"fields": ["battery"]}, + {"fields": ["battery"], "type": "get"}, + ), + ( + "30001", + {"fields": ["mapData"]}, + {"fields": ["mapData"], "mapId": "0"}, + ), + ], + ids=["robot-detail", "map-details"], +) +def test_build_payload_adds_ngiot_defaults( + ngiot_client: NgiotClient, + apn: str, + body_data: dict[str, object], + expected: dict[str, object], +) -> None: + payload = ngiot_client._build_payload(apn, body_data) + + assert isinstance(payload, Mapping) + assert payload["reqId"] + assert payload["timestamp"] + for key, value in expected.items(): + assert payload[key] == value + + +@pytest.mark.asyncio +async def test_request_with_fallback_retries_on_404() -> None: + client = NgiotClient(Mock(), AsyncMock()) + identity = NgiotDeviceIdentity( + did="did-1", + class_id="eyfj07", + resource="res-1", + control_host="api.example.com", + fallback_control_host="service.example.com", + ) + response_error = ClientResponseError( + request_info=_request_info(), + history=(), + status=404, + message="not found", + ) + client._request_once = AsyncMock( + side_effect=[response_error, {"body": {"code": 0, "data": {"ok": True}}}] + ) + + response = await client._request_with_fallback( + identity, + {"did": "did-1", "class": "eyfj07", "resource": "res-1"}, + apn="30001", + body_data={"fields": ["mapData"]}, + fmt="j", + ct="q", + force_sst_refresh=False, + ) + + assert response == {"body": {"code": 0, "data": {"ok": True}}} + assert client._request_once.await_count == 2 + first_identity = client._request_once.await_args_list[0].args[0] + second_identity = client._request_once.await_args_list[1].args[0] + assert first_identity.control_host == "api.example.com" + assert second_identity.control_host == "service.example.com" + assert second_identity.fallback_control_host is None + + +@pytest.mark.parametrize( + ("response", "should_raise"), + [ + ({"body": {"code": 0}}, False), + ({"body": {"code": "0000"}}, False), + ({"body": {"code": None}}, False), + ({"body": {"code": 500, "msg": "fail"}}, True), + ({}, True), + ], +) +def test_validate_response(response: dict[str, object], should_raise: bool) -> None: + if should_raise: + with pytest.raises(ApiError): + NgiotClient._validate_response(response) + else: + NgiotClient._validate_response(response) diff --git a/tests/test_ngiot_map_parser.py b/tests/test_ngiot_map_parser.py new file mode 100644 index 000000000..236376704 --- /dev/null +++ b/tests/test_ngiot_map_parser.py @@ -0,0 +1,57 @@ +from deebot_client.ngiot_map_parser import parse_base_map + + +def test_parse_base_map_prefers_map_field_and_captures_lz4_len() -> None: + payload = { + "mapId": "4", + "mapData": { + "mapId": "4", + "map": "ENCODED_MAP_PAYLOAD", + "data": "LEGACY_DATA_SHOULD_NOT_WIN", + "lz4Len": 21168, + "width": 144, + "height": 147, + "totalWidth": 800, + "totalHeight": 800, + "resolution": 5, + "xMin": 383, + "yMax": 493, + }, + } + + base_map = parse_base_map(payload) + + assert base_map is not None + assert base_map.map_id == "4" + assert base_map.encoded == "ENCODED_MAP_PAYLOAD" + assert base_map.lz4_len == 21168 + assert base_map.width == 144 + assert base_map.height == 147 + assert base_map.total_width == 800 + assert base_map.total_height == 800 + assert base_map.resolution == 5 + assert base_map.x_min == 383 + assert base_map.y_max == 493 + + +def test_parse_base_map_falls_back_to_legacy_data_field() -> None: + payload = { + "mapId": "3", + "mapData": { + "data": "LEGACY_ONLY_PAYLOAD", + "width": 138, + "height": 151, + "totalWidth": 800, + "totalHeight": 800, + "resolution": 5, + "xMin": 372, + "yMax": 491, + }, + } + + base_map = parse_base_map(payload) + + assert base_map is not None + assert base_map.map_id == "3" + assert base_map.encoded == "LEGACY_ONLY_PAYLOAD" + assert base_map.lz4_len is None \ No newline at end of file diff --git a/tests/test_ngiot_map_state.py b/tests/test_ngiot_map_state.py new file mode 100644 index 000000000..11e24db81 --- /dev/null +++ b/tests/test_ngiot_map_state.py @@ -0,0 +1,59 @@ +from __future__ import annotations + +from deebot_client.ngiot_map_parser import NgiotArea, NgiotBaseMap, NgiotMapInfo, NgiotOverlay, NgiotPoint, NgiotPose, NgiotTrace +from deebot_client.ngiot_map_state import NgiotMapStateStore + + +def test_map_state_store_normalizes_snapshot() -> None: + store = NgiotMapStateStore() + store.update_map_info( + NgiotMapInfo( + map_id="4", + name="Home", + using=True, + angle=0, + charge_pos=NgiotPoint(x=-100, y=300), + ) + ) + store.update_base_map( + NgiotBaseMap( + map_id="4", + width=10, + height=20, + total_width=800, + total_height=800, + resolution=5, + x_min=100, + y_max=200, + direction=1, + encoded="encoded-map", + ) + ) + store.update_pose("4", NgiotPose(x=-95, y=295, a=90)) + store.update_trace("4", NgiotTrace(trace_id="t-1", encoded="trace", lz4_len=32, total_count=2, start=1)) + store.update_areas("4", [NgiotArea(area_id="1", name="Kitchen", polygon=[NgiotPoint(x=-100, y=300)])]) + store.update_overlays("4", [NgiotOverlay(overlay_type="virtual_walls", overlay_id="7", polygon=[NgiotPoint(x=-90, y=290)])]) + + snapshot = store.get_normalized("4") + + assert snapshot is not None + assert snapshot.is_renderable() is True + assert snapshot.charge_pos == NgiotPoint(x=0, y=-24) + assert snapshot.pose == NgiotPose(x=1, y=-23, a=90) + assert snapshot.areas[0].polygon == [NgiotPoint(x=0, y=-24)] + assert snapshot.overlays[0].polygon == [NgiotPoint(x=2, y=-22)] + assert snapshot.trace is not None + assert store.active_map_id == "4" + + +def test_map_state_store_detects_overlay_only_snapshot() -> None: + store = NgiotMapStateStore() + store.update_map_info(NgiotMapInfo(map_id="4", name="Home", using=True, angle=0)) + store.update_trace("4", NgiotTrace(trace_id=None, encoded="trace", lz4_len=None, total_count=0, start=0)) + + snapshot = store.get("4") + + assert snapshot.has_background() is False + assert snapshot.has_overlay_content() is True + assert snapshot.is_overlay_only() is True + assert store.get_active_renderable() is None