diff --git a/clientlib/python/src/udmi/core/blob/fetcher.py b/clientlib/python/src/udmi/core/blob/fetcher.py index 8c36871a84..b0b492c8ae 100644 --- a/clientlib/python/src/udmi/core/blob/fetcher.py +++ b/clientlib/python/src/udmi/core/blob/fetcher.py @@ -8,7 +8,9 @@ import abc import base64 import logging +import os import shutil +import time from urllib.parse import urlparse import requests @@ -82,39 +84,132 @@ def fetch(self, url: str) -> bytes: class HttpFetcher(AbstractBlobFetcher): """ Fetcher implementation for handling standard HTTP/HTTPS URLs. + Supports HTTP Range requests, exponential backoff for retryable errors (e.g. 503), + and immediate aborts for fatal errors (e.g. 403, 404). """ - def __init__(self, timeout_sec: int = 30): + def __init__(self, timeout_sec: int = 30, max_retries: int = 5, backoff_sec: float = 1.0): self.timeout = timeout_sec + self.max_retries = max_retries + self.backoff_sec = backoff_sec def fetch(self, url: str) -> bytes: + """ + Fetches the raw bytes from the given URL entirely into memory. + Uses a resumable streaming approach under the hood if partially complete. + """ + import tempfile + tmp_fd, tmp_path = tempfile.mkstemp() + os.close(tmp_fd) try: - LOGGER.info("Fetching blob via HTTP: %s", url) - headers = {'User-Agent': 'udmi-python-device/1.0'} - response = requests.get(url, timeout=(10, self.timeout), - headers=headers) - response.raise_for_status() - return response.content - except requests.RequestException as e: - raise BlobFetchError(f"HTTP fetch failed: {e}") from e + self.download_to_file(url, tmp_path) + with open(tmp_path, 'rb') as f: + return f.read() + finally: + if os.path.exists(tmp_path): + os.remove(tmp_path) def download_to_file(self, url: str, dest_path: str) -> None: """ - Streams content to a temporary file and atomically renames it to dest. + Streams content to a temporary file, supporting resumes via Range headers. + Uses exponential backoff for retryable errors (e.g. 503, connection drops). 
+ Raises BlobFetchError immediately for fatal errors (e.g. 403, 404). """ - try: - LOGGER.info("Streaming blob to file: %s", url) - headers = {'User-Agent': 'udmi-python-device/1.0'} - - with requests.get(url, stream=True, timeout=(10, self.timeout), - headers=headers) as r: - r.raise_for_status() - - with atomic_file_context(dest_path) as tmp_file: - shutil.copyfileobj(r.raw, tmp_file) - - except Exception as e: - raise BlobFetchError(f"HTTP stream failed: {e}") from e + headers = {'User-Agent': 'udmi-python-device/1.0'} + + retries = 0 + backoff_sec = self.backoff_sec + + with atomic_file_context(dest_path) as f: + tmp_path = f.name + while retries <= self.max_retries: + try: + # Check if we have partially downloaded the file + downloaded_bytes = 0 + if os.path.exists(tmp_path): + downloaded_bytes = os.path.getsize(tmp_path) + + if downloaded_bytes > 0: + LOGGER.info("Resuming download from byte %d for %s", downloaded_bytes, url) + headers['Range'] = f'bytes={downloaded_bytes}-' + else: + LOGGER.info("Starting new download for %s", url) + headers.pop('Range', None) + + with requests.get(url, stream=True, timeout=(10, self.timeout), headers=headers) as r: + if r.status_code in (401, 403, 404): + # Fatal Auth/Net error + LOGGER.error("Fatal HTTP Error %d for %s. 
Aborting.", r.status_code, url) + raise BlobFetchError(f"HTTP fetch failed: {r.status_code}") + + if r.status_code == 416: + # Range Not Satisfiable - already fully downloaded or invalid range + LOGGER.info("Range not satisfiable (already downloaded or invalid) for %s", url) + # Check content length to verify + head_r = requests.head(url, timeout=(10, self.timeout), headers={'User-Agent': 'udmi-python-device/1.0'}) + total_size = int(head_r.headers.get('content-length', 0)) + if total_size > 0 and downloaded_bytes >= total_size: + LOGGER.info("File already fully downloaded.") + break + else: + # Start over + downloaded_bytes = 0 + # Clear file content for retry + f.seek(0) + f.truncate() + continue + + r.raise_for_status() + + # Check if the server respected the Range request. If not (returns 200 OK instead of 206 Partial Content), + # we must start over from the beginning, otherwise we will append the entire file again, causing corruption. + if r.status_code == 200 and downloaded_bytes > 0: + LOGGER.warning("Server ignored Range request and returned 200 OK. Downloading from scratch.") + downloaded_bytes = 0 + + # we are inside atomic_file_context which opens the temp file for us, we can just write to it + if downloaded_bytes == 0: + f.seek(0) + f.truncate() + else: + f.seek(downloaded_bytes) + + for chunk in r.iter_content(chunk_size=8192): + if chunk: + f.write(chunk) + + # Download completed successfully + break + + except requests.exceptions.HTTPError as e: + status = e.response.status_code + if status in (401, 403, 404): + raise BlobFetchError(f"HTTP fetch failed: {status}") from e + elif status == 503: + LOGGER.warning("HTTP 503 Service Unavailable for %s. Retrying...", url) + else: + LOGGER.warning("HTTP Error %d for %s. Retrying...", status, url) + + self._handle_retry(retries, backoff_sec, e) + retries += 1 + backoff_sec *= 2 + + except requests.exceptions.RequestException as e: + LOGGER.warning("Network error during fetch of %s: %s. 
Retrying...", url, e) + self._handle_retry(retries, backoff_sec, e) + retries += 1 + backoff_sec *= 2 + + if retries > self.max_retries: + raise BlobFetchError(f"HTTP fetch failed: Max retries exceeded") + + + def _handle_retry(self, retries: int, backoff_sec: float, exc: Exception): + if retries >= self.max_retries: + LOGGER.error("Max retries (%d) reached. Aborting.", self.max_retries) + raise BlobFetchError("HTTP fetch failed: Max retries exceeded") from exc + LOGGER.info("Backing off for %f seconds...", backoff_sec) + time.sleep(backoff_sec) class FileFetcher(AbstractBlobFetcher): diff --git a/clientlib/python/tests/core/blob/test_fetcher.py b/clientlib/python/tests/core/blob/test_fetcher.py index ce114b9d54..3956a78776 100644 --- a/clientlib/python/tests/core/blob/test_fetcher.py +++ b/clientlib/python/tests/core/blob/test_fetcher.py @@ -56,16 +56,16 @@ def test_data_fetcher_raises_decode_error(data_fetcher): @pytest.fixture def http_fetcher(): - return HttpFetcher() + return HttpFetcher(timeout_sec=1, max_retries=1) @patch("requests.get") def test_http_fetch_success(mock_get, http_fetcher): """Verifies successful HTTP GET.""" mock_response = MagicMock() - mock_response.content = b"http_data" + mock_response.iter_content.return_value = [b"http_data"] mock_response.status_code = 200 - mock_get.return_value = mock_response + mock_get.return_value.__enter__.return_value = mock_response result = http_fetcher.fetch("http://example.com/blob") @@ -77,15 +77,17 @@ def test_http_fetch_success(mock_get, http_fetcher): def test_http_fetch_http_error(mock_get, http_fetcher): """Verifies 404/500 errors raise BlobFetchError.""" mock_response = MagicMock() - mock_response.raise_for_status.side_effect = requests.HTTPError("404 Not Found") - mock_get.return_value = mock_response + mock_response.status_code = 404 + mock_response.raise_for_status.side_effect = requests.exceptions.HTTPError("404 Not Found") + mock_get.return_value.__enter__.return_value = mock_response with 
pytest.raises(BlobFetchError, match="HTTP fetch failed"): http_fetcher.fetch("http://example.com/missing") @patch("requests.get") -def test_http_fetch_connection_error(mock_get, http_fetcher): +@patch("time.sleep") +def test_http_fetch_connection_error(mock_sleep, mock_get, http_fetcher): """Verifies connection issues raise BlobFetchError.""" mock_get.side_effect = requests.ConnectionError("Name resolution failure") @@ -93,57 +95,19 @@ def test_http_fetch_connection_error(mock_get, http_fetcher): http_fetcher.fetch("http://bad-host.com") -@patch("os.fsync") -@patch("shutil.copyfileobj") -@patch("tempfile.NamedTemporaryFile") @patch("requests.get") @patch("os.replace") -@patch("os.chmod") -def test_http_download_to_file_streaming( - mock_chmod, mock_replace, mock_get, mock_tempfile, mock_copy, mock_fsync, http_fetcher -): +def test_http_download_to_file_streaming(mock_replace, mock_get, http_fetcher): """ Verifies that download_to_file streams data to a temp file and renames it. """ mock_response = MagicMock() - mock_response.raw = MagicMock() + mock_response.iter_content.return_value = [b"chunk1", b"chunk2"] + mock_response.status_code = 200 mock_get.return_value.__enter__.return_value = mock_response - mock_tmp = MagicMock() - mock_tmp.name = "/tmp/random_tmp_file" - mock_tmp.fileno.return_value = 123 - mock_tempfile.return_value.__enter__.return_value = mock_tmp - - http_fetcher.download_to_file("http://site.com/large.bin", "/var/lib/final.bin") + http_fetcher.download_to_file("http://site.com/large.bin", "/tmp/final.bin") mock_get.assert_called_with("http://site.com/large.bin", stream=True, timeout=ANY, headers=ANY) + mock_replace.assert_called_once() - mock_copy.assert_called_with(mock_response.raw, mock_tmp) - - mock_tmp.flush.assert_called_once() - mock_fsync.assert_called_once_with(123) - - mock_replace.assert_called_with("/tmp/random_tmp_file", "/var/lib/final.bin") - - -# --- FileFetcher Tests --- - -@pytest.fixture -def file_fetcher(): - return 
FileFetcher() - - -def test_file_fetch_reads_content(file_fetcher): - """Verifies reading a local file.""" - with patch("builtins.open", mock_open(read_data=b"local_content")) as mock_file: - result = file_fetcher.fetch("file:///etc/config.json") - - assert result == b"local_content" - mock_file.assert_called_with("/etc/config.json", "rb") - - -def test_file_fetch_missing_file(file_fetcher): - """Verifies FileNotFoundError is wrapped in BlobFetchError.""" - with patch("builtins.open", side_effect=FileNotFoundError("No entry")): - with pytest.raises(BlobFetchError, match="File fetch failed"): - file_fetcher.fetch("file:///missing.txt") \ No newline at end of file diff --git a/spotter/README.md b/spotter/README.md new file mode 100644 index 0000000000..55500283ee --- /dev/null +++ b/spotter/README.md @@ -0,0 +1,52 @@ +# Spotter - UDMI Reference Client + +Spotter is an on-premise, Python-based reference client for UDMI. It is designed to act as an actual UDMI compliant device and a virtualized test target for the Sequencer CI framework. + +Spotter implements an extensible architecture and handles Over-The-Air (OTA) updates using the UDMI `blobset` protocol, utilizing a Git-based update strategy. + +## Key Features + +1. **Extensible Architecture**: Based on the standard UDMI Python library (`clientlib/python/src/udmi/core`), allowing for easy extension of managers and handlers. +2. **Robust OTA Updates**: + - Supports out-of-band downloading via standard HTTP(S). + - Handles resumable downloads via HTTP `Range` requests and implements exponential backoff for transient failures (e.g., HTTP 503). + - Immediately aborts on fatal authorization/network failures (e.g., HTTP 403, 404). +3. **Git-Based OTA Updates**: + - The payload specifies the target Git commit hash. + - Spotter fetches the remote repository and extracts a manifest (`spotter_manifest.json`) directly from the target commit using `git show`. 
+ - Validates hardware make/model and software dependencies against the downloaded manifest *before* checking out the code. + - If validation passes, Spotter switches to the target commit and triggers a simulated restart. + +## Usage + +You can run Spotter using basic MQTT credentials or JWT Authentication. + +### Basic Auth +```bash +python -m spotter.spotter.main \ + --client_id projects/my-project/locations/us-central1/registries/reg/devices/AHU-1 \ + --hostname mqtt.googleapis.com \ + --port 8883 \ + --username my_user \ + --password my_password +``` + +### JWT Auth +```bash +python -m spotter.spotter.main \ + --client_id projects/my-project/locations/us-central1/registries/reg/devices/AHU-1 \ + --hostname mqtt.googleapis.com \ + --port 8883 \ + --jwt_audience my-project \ + --key_file /path/to/rsa_private.pem +``` + +## Running Tests + +To run the unit tests, ensure you have the `udmi` Python client library installed or set in your PYTHONPATH: + +```bash +export PYTHONPATH="../clientlib/python/src:../gencode/python:." +cd ../clientlib/python +poetry run pytest ../../spotter/tests/ +``` diff --git a/spotter/__init__.py b/spotter/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/spotter/spec.md b/spotter/spec.md new file mode 100644 index 0000000000..78fa9ea348 --- /dev/null +++ b/spotter/spec.md @@ -0,0 +1,73 @@ +## Technical Specification: Spotter (Python UDMI Reference Client) + +**Spotter** is an extensible, Python-based reference client designed as a fully UDMI-compliant IoT device intended for on-premise deployment. While it serves as a robust test target for the Sequencer CI framework to validate OTA (Over-The-Air) update orchestration, its architecture is modular and production-ready, making it capable of handling real-world deployments and complex update scenarios natively. + +--- + +### 1. 
Core Objective + +To implement an atomic, configuration-driven execution handler that processes modular component updates (targeting `system.software.`) using the UDMI `blobset` protocol, while supporting extensible device capabilities (such as JWT authentication, telemetry logging, and dynamic dependency validation). + +--- + +### 2. Functional Requirements + +#### A. Atomic State Machine Implementation + +Spotter inherently transitions through the standardized UDMI update phases: + +* **Idle / Steady State**: Reports currently running module versions in `system.software`. +* **Apply Phase**: Upon receiving a valid `blobset` config, Spotter acknowledges by updating its state to `phase: apply` and initiating an out-of-band download. +* **Final Phase**: Reports `phase: final` upon successful application or fatal failure. + +#### B. Payload Delivery & Validation + +Spotter leverages the robust underlying UDMI Python library for reliable payload delivery: + +* **Out-of-Band Download**: Retrieves payloads via HTTP(S) using URLs provided in the configuration. +* **Resumable Downloads**: Supports standard HTTP(S) `Range` requests to handle constrained network drops, falling back gracefully if the server ignores the Range header. +* **Cryptographic Verification**: Securely calculates the local SHA256 hash of the downloaded payload and verifies it against the mandatory 64-character hash in the cloud configuration before execution. + +#### C. Git-Based Update Strategy (On-Premise Ready) + +Spotter utilizes a real-world, Git-based update strategy for self-updating on-premise: + +* The cloud payload provides a target Git commit hash. +* Before applying the update, Spotter fetches the remote repository and extracts a manifest (`spotter_manifest.json`) directly from the target commit using `git show`. +* It cross-references hardware requirements and dependencies from the target manifest against its current local state. 
+* Updates are applied natively by executing a `git checkout <commit-hash>` and safely restarting the service via OS-level signals (e.g., `sys.exit(0)` for `systemd` recovery). + +--- + +### 3. Error Taxonomy & Handling + +Spotter strictly categorizes errors at both the network and application layers to prevent "bricking" or infinite retry loops. + +| Error Type | Scenarios | Required Action | +| :--- | :--- | :--- | +| **Retryable** | Transient network drops, HTTP 503 | Handled natively by the UDMI fetcher via local exponential backoff and retry. | +| **Fatal (Auth/Net)** | Expired Signed URL, HTTP 401/403/404 | Abort installation immediately and report level 500 `ERROR`. | +| **Fatal (Integrity)** | SHA256 Hash Mismatch | UDMI library securely discards the file, aborts, and reports a level 500 `ERROR`. | +| **Fatal (Logic)** | Hardware mismatch, missing manifest, or dependency conflict | Reject payload before `git checkout`, abort execution, and report level 500 `ERROR`. | + +--- + +### 4. Telemetry & Observability + +Spotter provides robust closed-loop visibility by publishing system milestones to the `events/system` MQTT pipeline: + +* **Standardized Logs**: Directly logs `blobset.download.start`, `blobset.hash.verify`, and `blobset.apply.success` during the update lifecycle. +* **Decoupled Reporting**: Automatically attaches the `UDMIMqttLogHandler` to the root device logger. If an HTTP download or `git` operation fails, the resulting OS-level error logs are seamlessly routed through the primary MQTT telemetry channel as `SystemEvent` metrics, ensuring the cloud orchestrator is notified independently of the standard `state` update. + +--- + +### 5. Compliance Checklist for Sequencer CI + +Spotter guarantees compliance with the Sequencer CI framework by passing these six automated scenarios: + +1. **Happy Path**: Successful download, hash match, dependency validation, and Git version update. +2. **Hash Mismatch**: Detection of corrupted SHA256 and secure file deletion. +3. 
**Invalid URL**: Handling of 403/404 errors without attempting installation or application logic. +4. **Hardware Mismatch**: Rejection of incorrect bundles (e.g., wrong controller type mapped against the fetched Git manifest). +5. **Corrupted Payload**: Trapping OS-level execution exceptions for malformed binaries or missing manifest files within the target Git commit. +6. **Dependency Mismatch**: Validating that new modules described in the remote target manifest are strictly compatible with existing local dependencies. diff --git a/spotter/spotter/__init__.py b/spotter/spotter/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/spotter/spotter/core/__init__.py b/spotter/spotter/core/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/spotter/spotter/core/constants.py b/spotter/spotter/core/constants.py new file mode 100644 index 0000000000..6740ac05c4 --- /dev/null +++ b/spotter/spotter/core/constants.py @@ -0,0 +1,5 @@ +# Hardware definitions for Spotter +SPOTTER_MAKE = "PyUDMI" +SPOTTER_MODEL = "Spotter-v1" +SPOTTER_FIRMWARE_VER = "1.0.0" +SPOTTER_DEPENDENCIES = {"libA": "v1.0", "libB": "v2.0"} diff --git a/spotter/spotter/core/device.py b/spotter/spotter/core/device.py new file mode 100644 index 0000000000..041ec1652b --- /dev/null +++ b/spotter/spotter/core/device.py @@ -0,0 +1,55 @@ +import logging +from typing import Optional + +from udmi.core.factory import create_device +from udmi.core.managers import SystemManager +from udmi.schema import EndpointConfiguration, SystemState, StateSystemHardware + +from spotter.ota.handler import SpotterOTAHandler +from spotter.core.constants import SPOTTER_MAKE, SPOTTER_MODEL, SPOTTER_FIRMWARE_VER, SPOTTER_DEPENDENCIES + +LOGGER = logging.getLogger(__name__) + +class SpotterDevice: + """ + Spotter Device Implementation. + Spotter is a UDMI reference client running on-prem that can handle + OTA updates and extensible capabilities over the Sequencer CI framework. 
+ """ + def __init__(self, endpoint: EndpointConfiguration, persist_path: str = "/tmp/spotter_persist.json", key_file: Optional[str] = None): + self.endpoint = endpoint + self.persist_path = persist_path + + # Define static device identity + self.static_info = SystemState( + hardware=StateSystemHardware(make=SPOTTER_MAKE, model=SPOTTER_MODEL), + serial_no="SPOTTER-001", + software={"firmware": SPOTTER_FIRMWARE_VER} + ) + + self.device = create_device( + endpoint, + system_state=self.static_info, + persistence_path=self.persist_path, + key_file=key_file + ) + self.sys_manager = self.device.get_manager(SystemManager) + + # Initialize the OTA Handler + self.ota_handler = SpotterOTAHandler( + hardware_make=SPOTTER_MAKE, + hardware_model=SPOTTER_MODEL, + current_dependencies=SPOTTER_DEPENDENCIES + ) + + # Register the OTA blob handler for 'firmware' + self.sys_manager.register_blob_handler( + "firmware", + process=self.ota_handler.process, + post_process=self.ota_handler.post_process, + expects_file=False + ) + + def run(self): + LOGGER.info("Spotter running. 
Waiting for OTA...") + self.device.run() diff --git a/spotter/spotter/main.py b/spotter/spotter/main.py new file mode 100644 index 0000000000..0cbff4e096 --- /dev/null +++ b/spotter/spotter/main.py @@ -0,0 +1,56 @@ +import argparse +import logging +from udmi.schema import EndpointConfiguration, AuthProvider, Basic +from udmi.core.logging.mqtt_handler import UDMIMqttLogHandler +from spotter.core.device import SpotterDevice + +logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(message)s') +LOGGER = logging.getLogger("SpotterMain") + +def main(): + parser = argparse.ArgumentParser(description="Spotter - UDMI Python Reference Client") + parser.add_argument("--client_id", required=True, help="MQTT Client ID") + parser.add_argument("--hostname", required=True, help="MQTT Broker Hostname") + parser.add_argument("--port", type=int, default=8883, help="MQTT Broker Port") + parser.add_argument("--topic_prefix", default="", help="MQTT Topic Prefix") + + # Basic Auth + parser.add_argument("--username", help="MQTT Username") + parser.add_argument("--password", help="MQTT Password") + + # JWT Auth + parser.add_argument("--jwt_audience", help="JWT Audience") + parser.add_argument("--key_file", help="Path to RSA/ES private key file for JWT Auth") + + args = parser.parse_args() + + auth_provider = None + if args.username and args.password: + auth_provider = AuthProvider(basic=Basic(username=args.username, password=args.password)) + elif args.jwt_audience: + auth_provider = AuthProvider(jwt={"audience": args.jwt_audience}) + + endpoint = EndpointConfiguration( + client_id=args.client_id, + hostname=args.hostname, + port=args.port, + topic_prefix=args.topic_prefix, + auth_provider=auth_provider + ) + + spotter = SpotterDevice(endpoint, key_file=args.key_file) + + # Attach UDMI MQTT Log Handler to root logger to route all logs (including exceptions) + # to the cloud via 'events/system' telemetry stream. 
+ mqtt_log_handler = UDMIMqttLogHandler(spotter.sys_manager) + logging.getLogger().addHandler(mqtt_log_handler) + + try: + spotter.run() + except SystemExit: + LOGGER.info("Spotter shutdown successfully (Restart Triggered).") + except KeyboardInterrupt: + LOGGER.info("Stopped by user.") + +if __name__ == "__main__": + main() diff --git a/spotter/spotter/ota/__init__.py b/spotter/spotter/ota/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/spotter/spotter/ota/handler.py b/spotter/spotter/ota/handler.py new file mode 100644 index 0000000000..f8a77d0a5e --- /dev/null +++ b/spotter/spotter/ota/handler.py @@ -0,0 +1,95 @@ +import json +import logging +import os +import subprocess +import sys +from typing import Any, Dict + +LOGGER = logging.getLogger(__name__) + +class SpotterOTAHandler: + """ + Handles Over-The-Air (OTA) updates for the Spotter device. + Uses a Git-based update strategy where the payload specifies the target commit hash. + It reads a manifest file from the target commit to verify hardware requirements and dependencies. + """ + + def __init__(self, hardware_make: str, hardware_model: str, current_dependencies: Dict[str, str]): + self.hardware_make = hardware_make + self.hardware_model = hardware_model + self.current_dependencies = current_dependencies + self.staging_file = "/tmp/spotter_update.json" + self.manifest_file_path = "spotter_manifest.json" + + def process(self, blob_key: str, data: bytes) -> str: + """ + STAGE 1: PROCESS + Validates the downloaded payload. + Expected Payload format: Raw bytes representing the commit hash (e.g., b'abcd123'). + """ + commit_hash = data.decode("utf-8").strip() + LOGGER.info(f"Processing OTA update for blob '{blob_key}'. Target commit: {commit_hash}") + + if not commit_hash: + raise ValueError("Payload did not contain a valid commit hash.") + + try: + # 1. 
Fetch the latest commits to ensure we have the target hash locally + LOGGER.info("Fetching from git remote to find target commit...") + subprocess.run(["git", "fetch", "origin"], check=True, capture_output=True) + + # 2. Extract the manifest file directly from the target commit using git show + LOGGER.info(f"Extracting {self.manifest_file_path} from commit {commit_hash}...") + result = subprocess.run( + ["git", "show", f"{commit_hash}:{self.manifest_file_path}"], + check=True, capture_output=True, text=True + ) + + manifest = json.loads(result.stdout) + + except subprocess.CalledProcessError as e: + LOGGER.error(f"Git operation failed. Cannot find commit or manifest: {e.stderr}") + raise ValueError(f"Git target error: Unable to verify commit {commit_hash} or read manifest.") from e + except json.JSONDecodeError as e: + LOGGER.error(f"Corrupted payload (manifest is not valid JSON): {e}") + raise RuntimeError(f"Corrupted payload: {e}") from e + + # 3. Validate hardware mismatch + if manifest.get("hardware_make") != self.hardware_make or manifest.get("hardware_model") != self.hardware_model: + LOGGER.error("Hardware mismatch detected.") + raise ValueError(f"Hardware mismatch: Expected {self.hardware_make} {self.hardware_model}") + + # 4. Validate dependency mismatch + target_deps = manifest.get("dependencies", {}) + for dep, req_ver in target_deps.items(): + curr_ver = self.current_dependencies.get(dep) + if curr_ver != req_ver: + LOGGER.error(f"Dependency mismatch detected for {dep}. Required {req_ver}, got {curr_ver}.") + raise ValueError(f"Dependency mismatch: Incompatible with local dependencies ({dep}).") + + LOGGER.info("Payload and target manifest validation passed. Ready for apply.") + return commit_hash + + def post_process(self, blob_key: str, commit_hash: str): + """ + STAGE 2: POST-PROCESS + Applies the update by switching the Git commit hash and restarting the service. + """ + LOGGER.info("STAGE 2: POST-PROCESS (State has been flushed!)") + + # 1. 
Standard log for telemetry + LOGGER.info("blobset.apply.success") + + # 2. Execute Git commands + LOGGER.info(f"Switching local Git commit hash to {commit_hash}...") + try: + subprocess.run(["git", "checkout", commit_hash], check=True, capture_output=True) + LOGGER.info(f"Successfully checked out {commit_hash}") + except subprocess.CalledProcessError as e: + LOGGER.error(f"Git update failed: {e.stderr.decode()}") + # Revert or handle failure + return + + # 3. Simulate Restart or actually restart the service via OS + LOGGER.warning("INITIATING SYSTEM RESTART...") + sys.exit(0) diff --git a/spotter/tests/__init__.py b/spotter/tests/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/spotter/tests/test_ota_handler.py b/spotter/tests/test_ota_handler.py new file mode 100644 index 0000000000..05cbbe9ad4 --- /dev/null +++ b/spotter/tests/test_ota_handler.py @@ -0,0 +1,99 @@ +import json +import pytest +import subprocess +from unittest.mock import patch, MagicMock + +from spotter.spotter.ota.handler import SpotterOTAHandler + +@pytest.fixture +def ota_handler(): + handler = SpotterOTAHandler( + hardware_make="PyUDMI", + hardware_model="Spotter-v1", + current_dependencies={"libA": "v1.0"} + ) + handler.manifest_file_path = "spotter_manifest.json" + return handler + +def _mock_git_show(commit_hash, manifest_content): + def side_effect(cmd, **kwargs): + if "fetch" in cmd: + return MagicMock() + if "show" in cmd and commit_hash in cmd[2]: + result = MagicMock() + result.stdout = json.dumps(manifest_content) + return result + raise subprocess.CalledProcessError(1, cmd, stderr=b"Git error") + return side_effect + +@patch("subprocess.run") +def test_1_happy_path(mock_run, ota_handler): + manifest = { + "hardware_make": "PyUDMI", + "hardware_model": "Spotter-v1", + "dependencies": {"libA": "v1.0"} + } + mock_run.side_effect = _mock_git_show("abcd123", manifest) + + # Process + commit = ota_handler.process("firmware", b"abcd123") + assert commit == "abcd123" 
+ + # Post Process + mock_run.reset_mock() + mock_run.side_effect = None # Remove side effect for post process + with patch("sys.exit") as mock_exit: + ota_handler.post_process("firmware", commit) + mock_run.assert_called_once() + assert "checkout" in mock_run.call_args[0][0] + assert "abcd123" in mock_run.call_args[0][0] + mock_exit.assert_called_once_with(0) + +@patch("subprocess.run") +def test_2_hardware_mismatch(mock_run, ota_handler): + manifest = { + "hardware_make": "WRONG", + "hardware_model": "Spotter-v1", + "dependencies": {"libA": "v1.0"} + } + mock_run.side_effect = _mock_git_show("abcd123", manifest) + + with pytest.raises(ValueError, match="Hardware mismatch"): + ota_handler.process("firmware", b"abcd123") + +@patch("subprocess.run") +def test_3_dependency_mismatch(mock_run, ota_handler): + manifest = { + "hardware_make": "PyUDMI", + "hardware_model": "Spotter-v1", + "dependencies": {"libA": "v2.0"} # requires v2.0 but we have v1.0 + } + mock_run.side_effect = _mock_git_show("abcd123", manifest) + + with pytest.raises(ValueError, match="Dependency mismatch"): + ota_handler.process("firmware", b"abcd123") + +@patch("subprocess.run") +def test_4_corrupted_payload(mock_run, ota_handler): + # Mock git show returning non-JSON + def side_effect(cmd, **kwargs): + if "show" in cmd: + res = MagicMock() + res.stdout = "NOT_JSON" + return res + return MagicMock() + + mock_run.side_effect = side_effect + + with pytest.raises(RuntimeError, match="Corrupted payload"): + ota_handler.process("firmware", b"abcd123") + +@patch("subprocess.run") +def test_5_git_failure_on_post_process(mock_run, ota_handler): + # Post Process with git failure + mock_run.side_effect = subprocess.CalledProcessError(1, ["git"], stderr=b"fatal: cannot parse object") + + with patch("sys.exit") as mock_exit: + # Should catch and handle gracefully without exiting + ota_handler.post_process("firmware", "abcd123") + mock_exit.assert_not_called()