Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
169 changes: 163 additions & 6 deletions deebot_client/authentication.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,12 @@
from __future__ import annotations

import asyncio
from collections.abc import Mapping
from dataclasses import dataclass
from http import HTTPStatus
import time
from typing import TYPE_CHECKING, Any
from urllib.parse import urljoin
from urllib.parse import urljoin, urlparse

from aiohttp import ClientResponseError, ClientSession, ClientTimeout, hdrs

Expand All @@ -20,12 +21,16 @@
)
from .logging_filter import get_logger
from .models import Credentials
from .ngiot_client import NgiotClient, NgiotClientConfiguration
from .sst_authentication import SstAuthenticator
from .util import cancel, create_task, md5
from .util.continents import get_continent_url_postfix
from .util.countries import get_ecovacs_country

if TYPE_CHECKING:
from collections.abc import Callable, Coroutine, Mapping
from collections.abc import Callable, Coroutine

from .models import ApiDeviceInfo


_LOGGER = get_logger(__name__)
Expand All @@ -44,6 +49,7 @@
"deviceType": "1",
}
MAX_RETRIES = 3
_NGIOT_BASE_URL_TEMPLATE = "https://api-base.dc-{region}.ww.ecouser.net"


@dataclass(frozen=True, kw_only=True)
Expand All @@ -58,6 +64,21 @@ class RestConfiguration:
auth_code_url: str


@dataclass(frozen=True, kw_only=True)
class NgiotConfiguration:
"""Optional overrides and defaults for NGIOT-backed devices."""

base_url: str | None = None
region: str | None = None
user_agent: str = "okhttp/4.9.1"
channel: str = "Android"
protocol_version: str = "0.0.22"
timezone_name: str = "UTC"
timezone_offset_minutes: int = 0
requested_ttl: int = 600
refresh_skew: int = 60


def create_rest_config(
session: ClientSession,
*,
Expand Down Expand Up @@ -148,12 +169,11 @@ async def __do_auth_response(
url, params=params, timeout=_TIMEOUT
) as res:
res.raise_for_status()

# ecovacs returns a json but content_type header is set to text
# ecovacs returns a json but content_type header is set to text
content_type = res.headers.get(hdrs.CONTENT_TYPE, "").lower()
json = await res.json(content_type=content_type)
_LOGGER.debug("got %s", json)
# TODO better error handling
# TODO better error handling
if json["code"] == "0000":
data: dict[str, Any] = json["data"]
return data
Expand Down Expand Up @@ -348,6 +368,7 @@ def __init__(
account_id: str,
password_hash: str,
) -> None:
self._config = config
self._auth_client = _AuthClient(
config,
account_id,
Expand All @@ -361,6 +382,37 @@ def __init__(
self._credentials: Credentials | None = None
self._refresh_handle: asyncio.TimerHandle | None = None
self._tasks: set[asyncio.Future[Any]] = set()
self._ngiot_config = NgiotConfiguration()
self._ngiot_base_url: str | None = None
self.sst_authenticator: SstAuthenticator | None = None
self.ngiot_client: NgiotClient | None = None

def configure_ngiot(self, config: NgiotConfiguration | None = None) -> None:
"""Store NGIOT defaults and optional region/base-url overrides."""
self._ngiot_config = self._normalize_ngiot_configuration(
config or NgiotConfiguration()
)

def attach_ngiot(self, config: NgiotConfiguration | None = None) -> None:
"""Attach NGIOT helpers using an explicit base URL or region."""
if config is not None:
self.configure_ngiot(config)

resolved_base_url = self._resolve_configured_ngiot_base_url()
if resolved_base_url is None:
msg = "attach_ngiot() requires a configured NGIOT base_url or region."
raise ApiError(msg)

if self.ngiot_client is not None:
if self._ngiot_base_url == resolved_base_url:
return
msg = (
"NGIOT transport already attached with a different base URL. "
"Call teardown() before attaching a different NGIOT base URL."
)
raise ApiError(msg)

self._create_ngiot_stack(resolved_base_url)

async def authenticate(self, *, force: bool = False) -> Credentials:
"""Authenticate on ecovacs servers."""
Expand Down Expand Up @@ -411,6 +463,11 @@ async def post_authenticated(
async def teardown(self) -> None:
"""Teardown authenticator."""
self._cancel_refresh_task()
if self.sst_authenticator is not None:
await self.sst_authenticator.teardown()
self.sst_authenticator = None
self.ngiot_client = None
self._ngiot_base_url = None
await cancel(self._tasks)

def _cancel_refresh_task(self) -> None:
Expand All @@ -432,5 +489,105 @@ async def async_refresh() -> None:
self._refresh_handle = None

validity = (credentials.expires_at - time.time()) * 0.99
self._refresh_handle = asyncio.get_running_loop().call_later(validity, refresh)

def _create_ngiot_stack(self, base_url: str) -> None:
normalized_base_url = self._normalize_base_url(base_url)
self.sst_authenticator = SstAuthenticator(
self._config.session,
self,
base_url=normalized_base_url,
requested_ttl=self._ngiot_config.requested_ttl,
refresh_skew=self._ngiot_config.refresh_skew,
)
self.ngiot_client = NgiotClient(
self._config.session,
self.sst_authenticator,
config=NgiotClientConfiguration(
user_agent=self._ngiot_config.user_agent,
channel=self._ngiot_config.channel,
protocol_version=self._ngiot_config.protocol_version,
timezone_name=self._ngiot_config.timezone_name,
timezone_offset_minutes=self._ngiot_config.timezone_offset_minutes,
),
)
self._ngiot_base_url = normalized_base_url

@classmethod
def _normalize_ngiot_configuration(
cls, config: NgiotConfiguration
) -> NgiotConfiguration:
return NgiotConfiguration(
base_url=(
cls._normalize_base_url(config.base_url)
if config.base_url is not None
else None
),
region=(
cls._normalize_region(config.region)
if config.region is not None
else None
),
user_agent=config.user_agent,
channel=config.channel,
protocol_version=config.protocol_version,
timezone_name=config.timezone_name,
timezone_offset_minutes=config.timezone_offset_minutes,
requested_ttl=config.requested_ttl,
refresh_skew=config.refresh_skew,
)

self._refresh_handle = asyncio.get_event_loop().call_later(validity, refresh)
def _resolve_configured_ngiot_base_url(self) -> str | None:
if self._ngiot_config.base_url is not None:
return self._ngiot_config.base_url
if self._ngiot_config.region is not None:
return self._format_ngiot_base_url(self._ngiot_config.region)
return None

def _resolve_ngiot_base_url(self, device_info: ApiDeviceInfo) -> str:
configured_base_url = self._resolve_configured_ngiot_base_url()
if configured_base_url is not None:
return configured_base_url

service = device_info.get("service")
if isinstance(service, Mapping):
mqs_host = service.get("mqs")
if isinstance(mqs_host, str) and mqs_host:
return self._derive_ngiot_base_url_from_mqs(mqs_host)

msg = (
f'Could not resolve NGIOT base URL for device class "{device_info["class"]}". '
"Configure an explicit region or base_url before device bootstrap."
)
raise ApiError(msg)

@staticmethod
def _normalize_base_url(base_url: str) -> str:
parsed = urlparse(base_url)
host = parsed.netloc if parsed.scheme and parsed.netloc else base_url
return f"https://{host.strip().rstrip('/')}"

@staticmethod
def _normalize_region(region: str) -> str:
return region.strip().lower().removeprefix("dc-")

@classmethod
def _format_ngiot_base_url(cls, region: str) -> str:
return _NGIOT_BASE_URL_TEMPLATE.format(region=cls._normalize_region(region))

@classmethod
def _derive_ngiot_base_url_from_mqs(cls, mqs_host: str) -> str:
parsed = urlparse(mqs_host)
host = parsed.netloc or parsed.path
host = host.strip().rstrip("/")
if not host:
msg = f'Could not derive NGIOT base URL from mqs host "{mqs_host}"'
raise ApiError(msg)
if host.startswith("api-base."):
return cls._normalize_base_url(host)
if host.startswith("api-ngiot."):
return cls._normalize_base_url("api-base." + host.split(".", 1)[1])
if "." in host:
return cls._normalize_base_url("api-base." + host.split(".", 1)[1])
msg = f'Could not derive NGIOT base URL from mqs host "{mqs_host}"'
raise ApiError(msg)
Loading