From 1f2c6804c7f1f284de67a44d81287cca6b3ba77c Mon Sep 17 00:00:00 2001 From: Karishma Kumar Date: Wed, 11 Mar 2026 16:00:29 +0530 Subject: [PATCH 1/7] [feature] implement store annotated data framework Provides a thread-safe, non-blocking ``DataCollector.record()`` call that any Odemis module can invoke to capture a labelled data sample. CLI interface is implemented to download the data from the cloud storage. --- debian/control | 2 + scripts/odemis-dc-fetch.py | 26 + src/odemis/gui/cont/menu.py | 23 +- src/odemis/gui/main.py | 28 +- src/odemis/gui/win/consent.py | 59 ++ src/odemis/util/datacollector.py | 724 +++++++++++++++++++++ src/odemis/util/dc_fetch.py | 299 +++++++++ src/odemis/util/test/datacollector_test.py | 623 ++++++++++++++++++ src/odemis/util/test/dc_fetch_test.py | 195 ++++++ 9 files changed, 1977 insertions(+), 2 deletions(-) create mode 100755 scripts/odemis-dc-fetch.py create mode 100644 src/odemis/gui/win/consent.py create mode 100644 src/odemis/util/datacollector.py create mode 100644 src/odemis/util/dc_fetch.py create mode 100644 src/odemis/util/test/datacollector_test.py create mode 100644 src/odemis/util/test/dc_fetch_test.py diff --git a/debian/control b/debian/control index 164bb886f9..ff920b5ab1 100644 --- a/debian/control +++ b/debian/control @@ -75,6 +75,8 @@ Depends: ${shlibs:Depends}, python3-libusb1, # Needed for acq.fastem python3-shapely, +# Needed for communicating with AWS S3 + python3-boto3, Suggests: imagej Description: Open Delmic Microscope Software Odemis is the acquisition software for the Delmic microscopes. In particular, diff --git a/scripts/odemis-dc-fetch.py b/scripts/odemis-dc-fetch.py new file mode 100755 index 0000000000..cfb2fc293c --- /dev/null +++ b/scripts/odemis-dc-fetch.py @@ -0,0 +1,26 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +"""Fetch DataCollector ZIP samples from S3. + +Examples: + ./scripts/odemis-dc-fetch.py + ./scripts/odemis-dc-fetch.py --output ./downloads + ./scripts/odemis-dc-fetch.py --event z_stack_acquired + ./scripts/odemis-dc-fetch.py --since 2026-03-01 + ./scripts/odemis-dc-fetch.py --host meteor-5099 + ./scripts/odemis-dc-fetch.py --host meteor-5099,atlas-001,secom-22 + ./scripts/odemis-dc-fetch.py --bucket delmic-odemis-collect-test --region eu-west-1 + ./scripts/odemis-dc-fetch.py --since 2026-03-01T12:30:00 --event z_stack_acquired --output ./dc_samples +""" + +import logging +import sys + +from odemis.util.dc_fetch import main + + +if __name__ == "__main__": + logging.basicConfig(level=logging.INFO) + rc = main(sys.argv[1:]) + logging.shutdown() + sys.exit(rc) diff --git a/src/odemis/gui/cont/menu.py b/src/odemis/gui/cont/menu.py index 8d8da61956..18bb85af75 100644 --- a/src/odemis/gui/cont/menu.py +++ b/src/odemis/gui/cont/menu.py @@ -28,6 +28,7 @@ from odemis.gui.model import CHAMBER_VACUUM, CHAMBER_UNKNOWN from odemis.gui.model.dye import DyeDatabase from odemis.gui.util import call_in_wx_main +from odemis.util.datacollector import DataCollector from odemis.util import driver import os import subprocess @@ -42,7 +43,7 @@ class MenuController(object): tab controller. """ - def __init__(self, main_data, main_frame): + def __init__(self, main_data, main_frame, data_collector: DataCollector): """ Binds the menu actions. main_data (MainGUIData): the representation of the microscope GUI @@ -50,6 +51,7 @@ def __init__(self, main_data, main_frame): """ self._main_data = main_data self._main_frame = main_frame + self._data_collector = data_collector # /File # /File/Open... @@ -153,6 +155,9 @@ def __init__(self, main_data, main_frame): # /Help/About main_frame.Bind(wx.EVT_MENU, self._on_about, id=main_frame.menu_item_about.GetId()) + self._consent_menu_item = self._append_data_sharing_menu_item(main_frame) + if self._consent_menu_item is not None: + main_frame.Bind(wx.EVT_MENU, self._on_toggle_data_sharing, id=self._consent_menu_item.GetId()) # add a toggle for correlation tab in viewer mode if main_data.is_viewer: @@ -163,6 +168,22 @@ def __init__(self, main_data, main_frame): menu.Remove(main_frame.menu_item_show_correlation) main_frame.menu_item_show_correlation.Destroy() + def _append_data_sharing_menu_item(self, main_frame): + """Append and initialize Help menu checkbox for data sharing consent.""" + help_menu = main_frame.menu_item_about.GetMenu() + if help_menu is None: + return None + help_menu.AppendSeparator() + item = help_menu.AppendCheckItem(wx.ID_ANY, "Share data with Delmic") + item.Check(self._data_collector.get_consent() is True) + return item + + def _on_toggle_data_sharing(self, evt): + """Toggle data sharing consent from Help menu check item.""" + enabled = self._consent_menu_item.IsChecked() + self._data_collector.set_consent(enabled) + evt.Skip() + def _on_update(self, evt): import odemis.gui.util.updater as updater u = updater.WindowsUpdater() diff --git a/src/odemis/gui/main.py b/src/odemis/gui/main.py index cb5bf09418..2e2e2daca9 100755 --- a/src/odemis/gui/main.py +++ b/src/odemis/gui/main.py @@ -30,8 +30,10 @@ from odemis.gui.cont import acquisition from odemis.gui.cont.menu import MenuController from odemis.gui.cont.temperature import TemperatureController +from odemis.gui.win.consent import ConsentDialog from odemis.gui.util import call_in_wx_main from odemis.gui.xmlh import odemis_get_resources +from odemis.util.datacollector import DataCollector import sys import threading import traceback @@ -73,6 +75,7 @@ def __init__(self, standalone=False, file_name=None): self._snapshot_controller = None self._temperature_controller = None self._menu_controller = None + self._data_collector = DataCollector() self.plugins = [] # List of instances of plugin.Plugins # User input devices @@ -344,7 +347,7 @@ def toggle_log_panel(_): self.main_data.level.subscribe(self.on_level_va, init=True) log.create_gui_logger(self.main_frame.txt_log, self.main_data.debug, self.main_data.level) - self._menu_controller = MenuController(self.main_data, self.main_frame) + self._menu_controller = MenuController(self.main_data, self.main_frame, self._data_collector) # Menu events self.main_frame.Bind(wx.EVT_MENU, self.on_close_window, id=self.main_frame.menu_item_quit.GetId()) @@ -375,6 +378,8 @@ def toggle_log_panel(_): # Due to a bug in wxPython, sometimes the .Maximize() at the beginning of the function # has no effect. So we call it after the Show() to be sure it works. wx.CallAfter(self.main_frame.Maximize) + wx.CallAfter(self._maybe_show_consent_dialog) + except Exception: self.excepthook(*sys.exc_info()) # Re-raise the exception, so the program will exit. If this is not @@ -416,6 +421,27 @@ def on_level_va(self, log_level): if hasattr(tab.panel, 'btn_log'): tab.panel.btn_log.set_face_colour(colour) + @call_in_wx_main + def _maybe_show_consent_dialog(self) -> None: + """Show consent dialog when consent is undecided and prompt is due.""" + try: + if not self._data_collector.should_prompt_for_consent(): + return + + remind_days = self._data_collector.get_consent_remind_days() + dlg = ConsentDialog(parent=self.main_frame, remind_days=remind_days) + result = dlg.ShowModal() + dlg.Destroy() + + if result == ConsentDialog.RESULT_OPT_IN: + self._data_collector.set_consent(True) + elif result == ConsentDialog.RESULT_OPT_OUT: + self._data_collector.set_consent(False) + else: + self._data_collector.postpone_consent() + except Exception: + logging.exception("Failed to run data-collection consent prompt.") + def on_close_window(self, evt=None): """ This method cleans up and closes the Odemis GUI. """ logging.info("Exiting Odemis") diff --git a/src/odemis/gui/win/consent.py b/src/odemis/gui/win/consent.py new file mode 100644 index 0000000000..0266089fec --- /dev/null +++ b/src/odemis/gui/win/consent.py @@ -0,0 +1,59 @@ +# -*- coding: utf-8 -*- +""" +Consent dialog for Odemis data collection. +""" + +import wx + + +class ConsentDialog(wx.Dialog): + """Dialog asking user consent for data sharing.""" + + RESULT_OPT_IN = int(wx.NewIdRef()) + RESULT_OPT_OUT = int(wx.NewIdRef()) + RESULT_REMIND_LATER = int(wx.NewIdRef()) + + def __init__(self, parent: wx.Window, remind_days: int) -> None: + title = "Share data with Delmic" + super().__init__(parent, wx.ID_ANY, title=title, size=(560, -1)) + + sizer = wx.BoxSizer(wx.VERTICAL) + message = ( + "Help improve Odemis by sharing anonymized diagnostic and measurement data.\n\n" + "You can change this choice later from Help > Share data with Delmic." + ) + label = wx.StaticText(self, wx.ID_ANY, message) + label.Wrap(520) + sizer.Add(label, 0, wx.ALL | wx.EXPAND, 12) + + button_sizer = wx.BoxSizer(wx.HORIZONTAL) + btn_opt_out = wx.Button(self, wx.ID_ANY, "Opt out") + btn_remind_later = wx.Button(self, wx.ID_ANY, f"Remind me in {remind_days} days") + btn_opt_in = wx.Button(self, wx.ID_ANY, "Opt in") + btn_opt_out.SetDefault() + + btn_opt_in.Bind(wx.EVT_BUTTON, self._on_opt_in) + btn_opt_out.Bind(wx.EVT_BUTTON, self._on_opt_out) + btn_remind_later.Bind(wx.EVT_BUTTON, self._on_remind_later) + self.Bind(wx.EVT_CLOSE, self._on_close) + + button_sizer.Add(btn_opt_out, 0, wx.RIGHT, 8) + button_sizer.Add(btn_remind_later, 0, wx.RIGHT, 8) + button_sizer.Add(btn_opt_in, 0) + sizer.Add(button_sizer, 0, wx.ALL | wx.ALIGN_RIGHT, 12) + + self.SetSizer(sizer) + sizer.Fit(self) + self.CentreOnParent() + + def _on_opt_in(self, _evt: wx.CommandEvent) -> None: + self.EndModal(self.RESULT_OPT_IN) + + def _on_opt_out(self, _evt: wx.CommandEvent) -> None: + self.EndModal(self.RESULT_OPT_OUT) + + def _on_remind_later(self, _evt: wx.CommandEvent) -> None: + self.EndModal(self.RESULT_REMIND_LATER) + + def _on_close(self, _evt: wx.CloseEvent) -> None: + self.EndModal(self.RESULT_REMIND_LATER) diff --git a/src/odemis/util/datacollector.py b/src/odemis/util/datacollector.py new file mode 100644 index 0000000000..a3cec680b8 --- /dev/null +++ b/src/odemis/util/datacollector.py @@ -0,0 +1,724 @@ +# -*- coding: utf-8 -*- +""" +Created on 11 March 2026 + +@author: Karishma Kumar + +Copyright © 2026 Karishma Kumar, Delmic + +This file is part of Odemis. + +Odemis is free software: you can redistribute it and/or modify it under the +terms of the GNU General Public License version 2 as published by the Free +Software Foundation. + +Odemis is distributed in the hope that it will be useful, but WITHOUT ANY +WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR +A PARTICULAR PURPOSE. See the GNU General Public License for more details. + +You should have received a copy of the GNU General Public License along with +Odemis. If not, see http://www.gnu.org/licenses/. + +Odemis Annotated Data Collection Framework. + +Provides a thread-safe, non-blocking ``DataCollector.record()`` call that any +Odemis module can invoke to capture a labelled data sample. Serialisation +happens asynchronously in a background daemon thread; the caller returns +immediately. +""" + +import configparser +import json +import logging +import os +import queue +import shutil +import socket +import tempfile +import threading +import time +import uuid +import zipfile +from dataclasses import dataclass, field +from datetime import datetime, timedelta, timezone +from pathlib import Path +from typing import Optional + +try: + import boto3 +except ImportError: + logging.error("boto3 is required for S3 upload functionality; install with 'sudo apt install python3-boto3'") + raise +import numpy + +import odemis +from odemis import model +from odemis.dataio import hdf5 +from odemis.dataio import tiff + + +# S3 bucket name — shared production bucket, created once by the dev team. +S3_BUCKET = "delmic-odemis-collect" + +# S3 bucket used for automated tests (not the production bucket). +S3_TEST_BUCKET = "delmic-odemis-collect-test" + +# S3 endpoint URL — None means let boto3 resolve the regional endpoint automatically. +# Set explicitly only for custom S3-compatible storage. +S3_ENDPOINT_URL = None +S3_REGION = "eu-west-1" + +# Path to the S3 credentials key file (JSON with access_key / secret_key). +_CREDENTIALS_PATH = "/usr/share/odemis/datacollector.key" + +# Default paths +_CONF_DIR = os.path.join(os.path.expanduser("~"), ".config", "odemis") +_DEFAULT_QUEUE_DIR = Path("/var/log/odemis/dc_queue") + +_VALID_IMAGE_FORMATS = ("TIFF", "HDF5") +_INITIAL_RETRY_DELAY_SECONDS = 30.0 +_MAX_RETRY_DELAY_SECONDS = 3600.0 +# Default delay before re-prompting for consent after postpone. +_DEFAULT_CONSENT_REMIND_DELTA = timedelta(days=30) +REMINDER_DATE_KEY = "reminder_date" + + +def _search_credentials() -> dict: + """ + Load S3 credentials from the standard key-file location. + The key file is a JSON file containing ``access_key`` and ``secret_key``. + :returns: Dict with ``access_key`` and ``secret_key``. + :raises LookupError: If the key file is not found at the expected location. + """ + if not os.path.isfile(_CREDENTIALS_PATH): + raise LookupError( + f"S3 credentials key file not found at {_CREDENTIALS_PATH}" + ) + with open(_CREDENTIALS_PATH, "r") as fh: + data = json.load(fh) + return { + "access_key": data["access_key"], + "secret_key": data["secret_key"], + } + + +class DataCollectorConfig: + """Persistent configuration for the data-collection framework. + + Backed by a ``configparser`` INI file at + ``~/.config/odemis/datacollector.config``. + + Sections + -------- + ``[general]`` + ``consent`` — ``true`` / ``false``/``none`` (not yet decided). + ``reminder_date`` — Date (``YYYY-MM-DD``) after which to re-prompt. + Commented out when not applicable. + + The file is written in a human-readable format with inline comments so it + can be inspected and manually edited by a support engineer. Example:: + + [general] + # Data sharing consent: true, false, or commented-out (not yet decided). + consent = none + # + # Date after which the consent dialog will be shown again (YYYY-MM-DD). + # reminder_date = 2026-05-07 + """ + + file_name: str = "datacollector.config" + + def __init__(self) -> None: + self.file_path = Path(_CONF_DIR) / self.file_name + self._cp = configparser.ConfigParser(interpolation=None) + self._lock = threading.Lock() + self._read() + + def _read(self) -> None: + """Read the config file if it exists; otherwise leave defaults.""" + if self.file_path.exists(): + self._cp.read(str(self.file_path)) + else: + logging.info("No datacollector config found; using defaults.") + + def _write(self) -> None: + """Write the config file with human-readable comments. + + ``consent`` is always present as ``none`` / ``true`` / ``false``. + ``reminder_date`` is commented out until explicitly set, then written + in ``YYYY-MM-DD`` format. + """ + self.file_path.parent.mkdir(parents=True, exist_ok=True) + + consent_val = self.consent + remind_val = self.remind_date + + if consent_val is True: + consent_line = "consent = true" + elif consent_val is False: + consent_line = "consent = false" + else: + consent_line = "consent = none" + + if remind_val is not None: + remind_line = f"reminder_date = {remind_val.strftime('%Y-%m-%d')}" + else: + remind_line = "# reminder_date = " + + content = ( + "[general]\n" + "# Data sharing consent (none / true / false).\n" + f"{consent_line}\n" + "#\n" + "# Date after which the consent dialog will be shown again (YYYY-MM-DD).\n" + f"{remind_line}\n" + ) + + with open(str(self.file_path), "w") as fh: + fh.write(content) + os.chmod(str(self.file_path), 0o600) + + def _ensure_section(self, section: str) -> None: + if not self._cp.has_section(section): + self._cp.add_section(section) + + @property + def consent(self) -> Optional[bool]: + """Return the consent state, or ``None`` if not yet set.""" + try: + return self._cp.getboolean("general", "consent") + except (configparser.NoSectionError, configparser.NoOptionError): + return None + + @consent.setter + def consent(self, value: bool) -> None: + with self._lock: + self._ensure_section("general") + self._cp.set("general", "consent", "true" if value else "false") + self._cp.remove_option("general", REMINDER_DATE_KEY) + self._write() + + def clear_consent(self) -> None: + """Unset consent so it becomes undecided again.""" + with self._lock: + self._ensure_section("general") + self._cp.remove_option("general", "consent") + self._write() + + @property + def remind_date(self) -> Optional[datetime]: + """Return next reminder date as a UTC-aware datetime, or ``None`` when unset.""" + try: + value = self._cp.get("general", REMINDER_DATE_KEY) + except (configparser.NoSectionError, configparser.NoOptionError): + return None + value = value.strip() + if not value: + return None + # Accept simple YYYY-MM-DD as well as full ISO 8601. + for fmt in ("%Y-%m-%d", None): + try: + if fmt: + parsed = datetime.strptime(value, fmt) + else: + parsed = datetime.fromisoformat(value) + if parsed.tzinfo is None: + return parsed.replace(tzinfo=timezone.utc) + return parsed.astimezone(timezone.utc) + except ValueError: + continue + return None + + @remind_date.setter + def remind_date(self, value: Optional[datetime]) -> None: + """Persist next reminder UTC timestamp, or unset when ``None``.""" + with self._lock: + self._ensure_section("general") + if value is None: + self._cp.remove_option("general", REMINDER_DATE_KEY) + else: + if value.tzinfo is None: + value = value.replace(tzinfo=timezone.utc) + value_utc = value.astimezone(timezone.utc) + self._cp.set("general", REMINDER_DATE_KEY, value_utc.isoformat()) + self._write() + + def postpone_consent(self, remind_date: Optional[datetime] = None) -> None: + """Postpone consent prompt and clear current consent choice. + + :param remind_date: UTC datetime after which the consent prompt should + be shown again. Defaults to ``_DEFAULT_CONSENT_REMIND_DELTA`` + from now when not specified. + """ + if self.consent is False: + return + if remind_date is None: + remind_date = datetime.now(timezone.utc) + _DEFAULT_CONSENT_REMIND_DELTA + elif remind_date.tzinfo is None: + remind_date = remind_date.replace(tzinfo=timezone.utc) + with self._lock: + self._ensure_section("general") + self._cp.remove_option("general", "consent") + self._cp.set("general", REMINDER_DATE_KEY, remind_date.astimezone(timezone.utc).isoformat()) + self._write() + + def should_prompt_for_consent(self) -> bool: + """Return whether the consent dialog should be shown now.""" + if self.consent is not None: + return False + remind_after = self.remind_date + if remind_after is None: + return True + return datetime.now(timezone.utc) >= remind_after + + def get_upload_backend(self) -> "S3UploadBackend": + """Return the configured upload backend instance.""" + credentials = _search_credentials() + return S3UploadBackend( + access_key=credentials["access_key"], + secret_key=credentials["secret_key"], + endpoint_url=S3_ENDPOINT_URL, + region=S3_REGION, + bucket=S3_BUCKET, + ) + + +@dataclass +class _WorkItem: + """A single data-collection event to be serialised and uploaded.""" + + event_name: str + schema_version: str + payload: dict + image_format: str = "TIFF" + submitted_at: float = field(default_factory=time.monotonic) + + +def _serialize(item: _WorkItem, queue_dir: Path) -> Path: + """Serialise *item* into a ZIP archive and place it in *queue_dir*. + + :param item: The work item to serialise. + :param queue_dir: Directory where the finished ZIP is placed. + :returns: Path to the created ZIP file inside *queue_dir*. + :raises OSError: On disk errors (caller must handle). + """ + queue_dir.mkdir(parents=True, exist_ok=True) + + sample_uuid = str(uuid.uuid4()) + uuid8 = sample_uuid.split("-")[0] + timestamp_utc = datetime.now(timezone.utc) + timestamp_str = timestamp_utc.strftime("%Y%m%dT%H%M%S") + # Limit event_name length so the filename stays within filesystem limits. + safe_event = item.event_name[:64] if item.event_name else "event" + zip_name = f"{safe_event}-{timestamp_str}-{uuid8}.zip" + + tmp_dir = Path(tempfile.mkdtemp(prefix="dc_")) + try: + payload_meta: dict = {} + extra_files: list = [] # list of (arcname, abs_path) + + for key, value in item.payload.items(): + if value is None or isinstance(value, (str, int, float, bool)): + payload_meta[key] = value + + elif isinstance(value, numpy.ndarray): + if item.image_format.upper() == "HDF5": + arc_name = f"{key}.h5" + abs_path = tmp_dir / arc_name + try: + da = value if isinstance(value, model.DataArray) else model.DataArray(value) + hdf5.export(str(abs_path), da) + except Exception: + logging.exception("Failed to export DataArray to HDF5 at %s", abs_path) + abs_path = None + else: + arc_name = f"{key}.ome.tiff" + abs_path = tmp_dir / arc_name + try: + da = value if isinstance(value, model.DataArray) else model.DataArray(value) + tiff.export(str(abs_path), da) + except Exception: + logging.exception("Failed to export DataArray to TIFF at %s", abs_path) + abs_path = None + + if abs_path is not None and abs_path.exists(): + extra_files.append((arc_name, abs_path)) + payload_meta[key] = arc_name + else: + payload_meta[key] = None + payload_meta["export_error"] = True + + elif isinstance(value, (dict, list)): + arc_name = f"extra_{key}.json" + abs_path = tmp_dir / arc_name + abs_path.write_text(json.dumps(value, default=str), encoding="utf-8") + extra_files.append((arc_name, abs_path)) + payload_meta[key] = arc_name + + else: + # Fallback: store string representation, guarding against + # __repr__/__str__ implementations that raise. + try: + payload_meta[key] = str(value) + except Exception: + # TODO add logging + payload_meta[key] = "" + + metadata = { + "sample_uuid": sample_uuid, + "timestamp_utc": timestamp_utc.strftime("%Y-%m-%dT%H:%M:%S.%f")[:-3] + "Z", + "system_id": socket.gethostname(), + "odemis_version": odemis.__version__, + "event_name": item.event_name, + "schema_version": item.schema_version, + "payload": payload_meta, + } + + meta_path = tmp_dir / "metadata.json" + meta_path.write_text(json.dumps(metadata, indent=2), encoding="utf-8") + + # Build ZIP in temp dir, then rename atomically into queue_dir. + tmp_zip = queue_dir / f"{uuid8}.tmp" + final_zip = queue_dir / zip_name + with zipfile.ZipFile(str(tmp_zip), "w", compression=zipfile.ZIP_DEFLATED) as zf: + zf.write(str(meta_path), "metadata.json") + for arc_name, abs_path in extra_files: + zf.write(str(abs_path), arc_name) + + os.replace(str(tmp_zip), str(final_zip)) + return final_zip + + finally: + shutil.rmtree(str(tmp_dir), ignore_errors=True) + + +def _enforce_queue_limit(queue_dir: Path) -> None: + """Delete the oldest ZIP files if the queue exceeds 10% of partition space. + + :param queue_dir: The staging directory to inspect. + """ + if not queue_dir.exists(): + return + + zips = sorted(queue_dir.glob("*.zip"), key=lambda p: p.stat().st_mtime) + if not zips: + return + + try: + usage = shutil.disk_usage(str(queue_dir)) + except OSError: + logging.warning("Cannot read disk usage for %s", queue_dir) + return + + limit = usage.total * 0.10 # 10 % of partition + total_size = sum(p.stat().st_size for p in zips) + + while total_size > limit and zips: + oldest = zips.pop(0) + try: + size = oldest.stat().st_size + oldest.unlink() + total_size -= size + logging.warning("Queue limit exceeded: removed oldest sample %s", oldest.name) + except OSError: + logging.warning("Could not remove queue file %s", oldest) + + +class S3UploadBackend: + """S3 upload backend implemented with boto3.""" + + def __init__( + self, + access_key: str, + secret_key: str, + endpoint_url: Optional[str] = S3_ENDPOINT_URL, + region: str = S3_REGION, + bucket: str = S3_BUCKET, + ) -> None: + self._access_key = access_key + self._secret_key = secret_key + self._endpoint_url = endpoint_url + self._region = region + self._bucket = bucket + self._client = None + + def _get_client(self): + """Return a cached boto3 S3 client.""" + if self._client is None: + self._client = boto3.client( + "s3", + endpoint_url=self._endpoint_url, + region_name = self._region, + aws_access_key_id=self._access_key, + aws_secret_access_key=self._secret_key, + ) + return self._client + + def upload(self, local_path: Path, remote_key: str) -> None: + """Upload *local_path* to *remote_key* in the configured bucket.""" + client = self._get_client() + client.upload_file(str(local_path), self._bucket, remote_key) + + +def _upload(zip_path: Path, backend: S3UploadBackend) -> None: + """Upload *zip_path* with *backend* using the standard remote key.""" + remote_key = f"{socket.gethostname()}/{zip_path.name}" + backend.upload(zip_path, remote_key) + + +class _BackgroundWorker: + """Daemon thread that consumes :class:`_WorkItem` objects from a queue. + + For each item it calls :func:`_enforce_queue_limit`, :func:`_serialize`, + and :func:`_upload` in sequence. After a successful upload the local ZIP + is deleted. Exceptions are caught and logged so that the worker never + crashes the host application. + + The thread is started lazily and restarted automatically if it dies. + """ + + def __init__(self, config: DataCollectorConfig, queue_dir: Path = _DEFAULT_QUEUE_DIR) -> None: + self._config = config + self._queue_dir = queue_dir + self._queue: queue.Queue = queue.Queue() + self._thread: Optional[threading.Thread] = None + self._lock = threading.Lock() + self._upload_backend: Optional[S3UploadBackend] = None + self._next_retry_at: float = 0.0 + self._retry_delay: float = _INITIAL_RETRY_DELAY_SECONDS + + def enqueue(self, item: _WorkItem) -> None: + """Add *item* to the processing queue and ensure the thread is alive. + + :param item: The work item to enqueue. + """ + self._ensure_thread() + self._queue.put_nowait(item) + + def _ensure_thread(self) -> None: + with self._lock: + if self._thread is None or not self._thread.is_alive(): + self._thread = threading.Thread( + target=self._run, + name="DataCollectorWorker", + daemon=True, + ) + self._thread.start() + logging.debug("DataCollector background thread started.") + + def _get_upload_backend(self) -> S3UploadBackend: + """Return a cached upload backend.""" + if self._upload_backend is None: + self._upload_backend = self._config.get_upload_backend() + return self._upload_backend + + def _schedule_retry(self) -> None: + """Schedule the next retry using exponential backoff.""" + delay = self._retry_delay + self._next_retry_at = time.monotonic() + delay + self._retry_delay = min(self._retry_delay * 2.0, _MAX_RETRY_DELAY_SECONDS) + logging.warning("DataCollector upload failed; retrying in %.0f s", delay) + + def _reset_retry(self) -> None: + """Reset retry state after a successful upload.""" + self._next_retry_at = 0.0 + self._retry_delay = _INITIAL_RETRY_DELAY_SECONDS + + def _pending_zip_paths(self, queue_dir: Path) -> list[Path]: + """Return pending ZIP files ordered oldest-first.""" + if not queue_dir.exists(): + return [] + return sorted(queue_dir.glob("*.zip"), key=lambda p: p.stat().st_mtime) + + def _process_pending_zips(self, queue_dir: Path) -> bool: + """Upload pending ZIP files from *queue_dir*. + + Returns ``True`` when pending work existed (including backoff wait), + otherwise ``False``. + """ + pending = self._pending_zip_paths(queue_dir) + if not pending: + return False + + now = time.monotonic() + if now < self._next_retry_at: + time.sleep(min(1.0, self._next_retry_at - now)) + return True + + try: + backend = self._get_upload_backend() + except Exception: + logging.exception("DataCollector failed to initialize upload backend") + self._schedule_retry() + return True + for zip_path in pending: + try: + _upload(zip_path, backend) + zip_path.unlink(missing_ok=True) + self._reset_retry() + except Exception: + logging.exception("DataCollector upload failed for %s", zip_path.name) + self._schedule_retry() + return True + return True + + def _process_work_item(self, item: _WorkItem) -> None: + """Serialize one work item and trigger upload processing.""" + _enforce_queue_limit(self._queue_dir) + _serialize(item, self._queue_dir) + self._process_pending_zips(self._queue_dir) + + def _run(self) -> None: + """Main loop: process items until the thread is stopped.""" + while True: + try: + if self._process_pending_zips(self._queue_dir): + continue + except Exception: + logging.exception("DataCollector error while processing pending uploads") + self._schedule_retry() + continue + try: + item = self._queue.get(timeout=1.0) + except queue.Empty: + continue + + try: + self._process_work_item(item) + except Exception: + logging.exception( + "DataCollector error processing event '%s'", item.event_name + ) + + +class DataCollector: + """Thread-safe recorder for annotated data samples.""" + + def __init__(self) -> None: + self._config: Optional[DataCollectorConfig] = None + self._worker: Optional[_BackgroundWorker] = None + self._init_ok: bool = False + self._init_lock = threading.Lock() + + def _lazy_init(self) -> None: + """Initialise configuration and worker on first use.""" + if self._init_ok: + return + with self._init_lock: + if self._init_ok: + return + try: + self._config = DataCollectorConfig() + self._worker = _BackgroundWorker(self._config) + self._init_ok = True + logging.debug("DataCollector initialised.") + except Exception: + logging.exception( + "DataCollector failed to initialise; all record() calls will be no-ops." + ) + + def get_consent(self) -> Optional[bool]: + """Return current consent state from configuration.""" + self._lazy_init() + if not self._init_ok: + return None + return self._config.consent # type: ignore[union-attr] + + def should_prompt_for_consent(self) -> bool: + """Return whether a consent dialog should be shown to the user.""" + self._lazy_init() + if not self._init_ok: + return False + return self._config.should_prompt_for_consent() # type: ignore[union-attr] + + def set_consent(self, value: bool) -> None: + """Persist explicit user consent choice.""" + if not isinstance(value, bool): + raise ValueError("value must be a bool") + self._lazy_init() + if not self._init_ok: + return + self._config.consent = value # type: ignore[union-attr] + + def postpone_consent(self, remind_date: Optional[datetime] = None) -> None: + """Postpone consent prompt and clear current consent state. + + :param remind_date: UTC datetime after which the consent prompt should + be shown again. Defaults to ``_DEFAULT_CONSENT_REMIND_DELTA`` from + now when not specified. + """ + self._lazy_init() + if not self._init_ok: + return + self._config.postpone_consent(remind_date=remind_date) # type: ignore[union-attr] + + def get_consent_remind_days(self) -> int: + """Return the number of days used for the consent remind-later interval. + + :returns: Integer number of days in the default remind delta. + """ + return int(_DEFAULT_CONSENT_REMIND_DELTA.days) + + def record( + self, + event_name: str, + schema_version: str, + payload: dict, + image_format: str = "TIFF", + ) -> None: + """Capture an annotated data sample at a software event. + + Returns immediately (non-blocking). Serialisation and upload happen + asynchronously in a background thread. If consent has not been + granted, this is a no-op. This function never raises (beyond the + input validation below); all errors are logged and suppressed. + + :param event_name: Human-readable event identifier, e.g. + ``"z_stack_acquired"``. Must be a non-empty string. + :param schema_version: Payload schema version string, e.g. ``"1.0"``. + Must be a non-empty string. + :param payload: Dict of arbitrary values. Must be a dict. Supported + value types: + - Python primitives (str, int, float, bool, None) — inlined in + ``metadata.json`` + - :class:`odemis.model.DataArray` / :class:`numpy.ndarray` — + exported as TIFF or HDF5 side-car + - dict / list — written as ``extra_.json`` side-car + :param image_format: Format for DataArray export. ``"TIFF"`` + (default) or ``"HDF5"``. + :raises ValueError: If any input parameter is invalid. + """ + if not isinstance(event_name, str) or not event_name: + raise ValueError("event_name must be a non-empty string") + if not isinstance(schema_version, str) or not schema_version: + raise ValueError("schema_version must be a non-empty string") + if not isinstance(payload, dict): + raise ValueError("payload must be a dict") + if image_format.upper() not in _VALID_IMAGE_FORMATS: + raise ValueError( + f"image_format must be one of {_VALID_IMAGE_FORMATS}, got {image_format!r}" + ) + + try: + self._lazy_init() + if not self._init_ok: + return + + consent = self._config.consent # type: ignore[union-attr] + if not consent: + logging.debug( + "DataCollector: consent=%s, skipping event '%s'.", consent, event_name + ) + return + + item = _WorkItem( + event_name=event_name, + schema_version=schema_version, + payload=payload, + image_format=image_format, + ) + self._worker.enqueue(item) # type: ignore[union-attr] + except Exception: + logging.exception( + "Unexpected error in DataCollector.record(); event '%s' dropped.", event_name + ) diff --git a/src/odemis/util/dc_fetch.py b/src/odemis/util/dc_fetch.py new file mode 100644 index 0000000000..02bbfb9430 --- /dev/null +++ b/src/odemis/util/dc_fetch.py @@ -0,0 +1,299 @@ +# -*- coding: utf-8 -*- +""" +Created on 11 March 2026 + +@author: Karishma Kumar + +Copyright © 2026 Karishma Kumar, Delmic + +This file is part of Odemis. + +Odemis is free software: you can redistribute it and/or modify it under the +terms of the GNU General Public License version 2 as published by the Free +Software Foundation. + +Odemis is distributed in the hope that it will be useful, but WITHOUT ANY +WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR +A PARTICULAR PURPOSE. See the GNU General Public License for more details. + +You should have received a copy of the GNU General Public License along with +Odemis. If not, see http://www.gnu.org/licenses/. + +Retrieval helpers for downloading DataCollector ZIP samples from S3. +""" + + +import argparse +import logging +from datetime import datetime, timezone +from pathlib import Path +from typing import Any, Dict, Iterator, List, Optional + +from odemis.util.datacollector import DataCollectorConfig, S3UploadBackend + + +def parse_since_utc(value: str) -> datetime: + """Parse a date/datetime string to UTC-aware datetime. + + Accepts ISO-8601 date (`YYYY-MM-DD`) and datetime (`YYYY-MM-DDTHH:MM:SS` with optional timezone). + """ + text = value.strip() + if len(text) == 10: + parsed = datetime.strptime(text, "%Y-%m-%d") + return parsed.replace(tzinfo=timezone.utc) + parsed = datetime.fromisoformat(text) + if parsed.tzinfo is None: + return parsed.replace(tzinfo=timezone.utc) + return parsed.astimezone(timezone.utc) + + +def parse_key_timestamp_utc(key: str) -> Optional[datetime]: + """Parse `--.zip` timestamp from key basename.""" + name = Path(key).name + if not name.endswith(".zip"): + return None + stem = name[:-4] + parts = stem.rsplit("-", 2) + if len(parts) != 3: + return None + ts = parts[1] + try: + parsed = datetime.strptime(ts, "%Y%m%dT%H%M%S") + except ValueError: + return None + return parsed.replace(tzinfo=timezone.utc) + + +def parse_key_event_name(key: str) -> Optional[str]: + """Parse event name from `--.zip` key basename.""" + name = Path(key).name + if not name.endswith(".zip"): + return None + stem = name[:-4] + parts = stem.rsplit("-", 2) + if len(parts) != 3: + return None + return parts[0] or None + + +def build_argument_parser() -> argparse.ArgumentParser: + """Build CLI argument parser for `odemis-dc-fetch`.""" + examples = ( + "Examples:\n" + " odemis-dc-fetch\n" + " odemis-dc-fetch --output ./downloads\n" + " odemis-dc-fetch --event z_stack_acquired\n" + " odemis-dc-fetch --since 2026-03-01\n" + " odemis-dc-fetch --host meteor-5099\n" + " odemis-dc-fetch --host meteor-5099,atlas-001,secom-22\n" + " odemis-dc-fetch --bucket delmic-odemis-collect-test --region eu-west-1\n" + " odemis-dc-fetch --since 2026-03-01T12:30:00 --event z_stack_acquired --output ./dc_samples" + ) + parser = argparse.ArgumentParser( + description="Fetch data-collection ZIP samples from S3.", + epilog=examples, + formatter_class=argparse.RawDescriptionHelpFormatter, + ) + parser.add_argument( + "--event", + dest="event", + help="Only fetch samples matching event name.", + ) + parser.add_argument( + "--since", + dest="since", + help="Only fetch samples since UTC date/datetime (e.g. 2026-03-01 or 2026-03-01T12:30:00).", + ) + parser.add_argument( + "--output", + dest="output", + default="./dc_samples", + help="Output directory for downloaded ZIPs (default: ./dc_samples).", + ) + parser.add_argument( + "--host", + dest="host", + help="Optional host/system-id filter; use comma-separated IDs for multiple hosts. " + "By default, fetch across all hosts.", + ) + parser.add_argument( + "--bucket", + dest="bucket", + help="Optional S3 bucket override (default from datacollector backend config).", + ) + parser.add_argument( + "--endpoint-url", + dest="endpoint_url", + help="Optional S3 endpoint URL override (default from datacollector backend config).", + ) + parser.add_argument( + "--region", + dest="region", + help="Optional AWS region name for S3 client creation.", + ) + return parser + + +def iter_s3_objects(s3_client: Any, bucket: str, prefix: str) -> Iterator[Dict[str, Any]]: + """Iterate S3 objects under `prefix` using `list_objects_v2` pagination.""" + token: Optional[str] = None + while True: + kwargs: Dict[str, Any] = {"Bucket": bucket, "Prefix": prefix} + if token: + kwargs["ContinuationToken"] = token + response = s3_client.list_objects_v2(**kwargs) + for item in response.get("Contents", []): + yield item + if not response.get("IsTruncated"): + break + token = response.get("NextContinuationToken") + + +def should_download_key(key: str, event_filter: Optional[str], since_utc: Optional[datetime]) -> bool: + """Return whether an S3 key should be downloaded by filters.""" + if not key.endswith(".zip"): + return False + if event_filter: + event_name = parse_key_event_name(key) + if event_name != event_filter: + return False + if since_utc: + key_ts = parse_key_timestamp_utc(key) + if key_ts is None: + return False + if key_ts < since_utc: + return False + return True + + +def create_s3_client_from_config(config: DataCollectorConfig) -> tuple[Any, str]: + """Create S3 client and return `(client, bucket)`.""" + backend = config.get_upload_backend() + if not isinstance(backend, S3UploadBackend): + raise RuntimeError("Only S3 backend is supported for retrieval.") + # Accessing protected members intentionally to reuse existing backend setup. + client = backend._get_client() # pylint: disable=protected-access + bucket = backend._bucket # pylint: disable=protected-access + return client, bucket + + +def parse_host_filters(value: Optional[str]) -> List[str]: + """Parse comma-separated host filters into normalized host IDs.""" + if not value: + return [] + hosts = [part.strip().strip("/") for part in value.split(",")] + return [host for host in hosts if host] + + +def build_s3_client_from_config( + config: DataCollectorConfig, + bucket_override: Optional[str] = None, + endpoint_override: Optional[str] = None, + region_override: Optional[str] = None, +) -> tuple[Any, str]: + """Build an S3 client using datacollector credentials with optional endpoint/bucket overrides.""" + backend = config.get_upload_backend() + if not isinstance(backend, S3UploadBackend): + raise RuntimeError("Only S3 backend is supported for retrieval.") + endpoint_url = endpoint_override or backend._endpoint_url # pylint: disable=protected-access + bucket = bucket_override or backend._bucket # pylint: disable=protected-access + import boto3 + client_kwargs: Dict[str, Any] = { + "endpoint_url": endpoint_url, + "aws_access_key_id": backend._access_key, # pylint: disable=protected-access + "aws_secret_access_key": backend._secret_key, # pylint: disable=protected-access + } + if region_override: + client_kwargs["region_name"] = region_override + else: + client_kwargs["region_name"] = backend._region # pylint: disable=protected-access + client = boto3.client("s3", **client_kwargs) + return client, bucket + + +def fetch_samples( + event_filter: Optional[str], + since_utc: Optional[datetime], + output_dir: Path, + host_filter: Optional[str] = None, + bucket_override: Optional[str] = None, + endpoint_override: Optional[str] = None, + region_override: Optional[str] = None, +) -> Dict[str, int]: + """Fetch matching samples from S3 into output directory.""" + cfg = DataCollectorConfig() + s3_client, bucket = build_s3_client_from_config( + cfg, + bucket_override=bucket_override, + endpoint_override=endpoint_override, + region_override=region_override, + ) + host_filters = parse_host_filters(host_filter) + prefixes = [f"{host}/" for host in host_filters] if host_filters else [""] + + output_dir.mkdir(parents=True, exist_ok=True) + + listed = 0 + matched = 0 + downloaded = 0 + skipped_existing = 0 + failed = 0 + + for prefix in prefixes: + for item in iter_s3_objects(s3_client, bucket=bucket, prefix=prefix): + listed += 1 + key = item.get("Key") + if not key or not should_download_key(key, event_filter, since_utc): + continue + matched += 1 + destination = output_dir / Path(key).name + if destination.exists(): + skipped_existing += 1 + continue + try: + s3_client.download_file(bucket, key, str(destination)) + downloaded += 1 + except Exception: + failed += 1 + logging.exception("Failed to download key %s", key) + + return { + "listed": listed, + "matched": matched, + "downloaded": downloaded, + "skipped_existing": skipped_existing, + "failed": failed, + } + + +def main(argv: Optional[List[str]] = None) -> int: + """CLI entrypoint for retrieval flow.""" + parser = build_argument_parser() + args = parser.parse_args(argv) + + try: + since_utc = parse_since_utc(args.since) if args.since else None + except ValueError: + logging.error("Invalid --since value: %s", args.since) + return 2 + + output_dir = Path(args.output) + try: + result = fetch_samples( + event_filter=args.event, + since_utc=since_utc, + output_dir=output_dir, + host_filter=args.host, + bucket_override=args.bucket, + endpoint_override=args.endpoint_url, + region_override=args.region, + ) + except Exception: + logging.exception("Failed to fetch samples from S3") + return 1 + + print( + "listed={listed} matched={matched} downloaded={downloaded} " + "skipped_existing={skipped_existing} failed={failed}".format(**result) + ) + return 0 if result["failed"] == 0 else 1 diff --git a/src/odemis/util/test/datacollector_test.py b/src/odemis/util/test/datacollector_test.py new file mode 100644 index 0000000000..65e7118f28 --- /dev/null +++ b/src/odemis/util/test/datacollector_test.py @@ -0,0 +1,623 @@ +# -*- coding: utf-8 -*- +""" +Created on 11 March 2026 + +Copyright © 2026 Delmic + +This file is part of Odemis. + +Odemis is free software: you can redistribute it and/or modify it under the +terms of the GNU General Public License version 2 as published by the Free +Software Foundation. + +Odemis is distributed in the hope that it will be useful, but WITHOUT ANY +WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A +PARTICULAR PURPOSE. See the GNU General Public License for more details. + +You should have received a copy of the GNU General Public License along with +Odemis. If not, see http://www.gnu.org/licenses/. +""" + +import configparser +import json +import logging +import os +import shutil +import socket +import stat +import tempfile +import time +import unittest +import uuid +import zipfile +from datetime import datetime, timedelta, timezone +from pathlib import Path +from unittest.mock import patch + +import numpy + +from odemis.util.datacollector import ( + DataCollector, + DataCollectorConfig, + S3UploadBackend, + S3_REGION, + S3_TEST_BUCKET, + _CREDENTIALS_PATH, + _BackgroundWorker, + _WorkItem, + _enforce_queue_limit, + _serialize, +) + +logging.basicConfig(level=logging.DEBUG) + + +class TestDataCollectorConfig(unittest.TestCase): + """Tests for DataCollectorConfig read/write behaviour.""" + + def setUp(self) -> None: + self._tmp_conf_dir = tempfile.mkdtemp(prefix="dc_conf_") + + def tearDown(self) -> None: + shutil.rmtree(self._tmp_conf_dir, ignore_errors=True) + + def _make_config(self) -> DataCollectorConfig: + """Return a DataCollectorConfig pointed at the temp directory.""" + cfg = DataCollectorConfig.__new__(DataCollectorConfig) + cfg.file_path = Path(self._tmp_conf_dir) / "datacollector.config" + import threading + cfg._cp = configparser.ConfigParser(interpolation=None) + cfg._lock = threading.Lock() + cfg._read() + return cfg + + def test_consent_initially_none(self) -> None: + """Consent should be None when no config file exists.""" + cfg = self._make_config() + self.assertIsNone(cfg.consent) + + def test_consent_round_trip(self) -> None: + """Setting consent to True/False persists to disk and re-reads correctly.""" + cfg = self._make_config() + cfg.consent = True + cfg2 = self._make_config() + self.assertTrue(cfg2.consent) + cfg.consent = False + cfg3 = self._make_config() + self.assertFalse(cfg3.consent) + + def test_config_file_permissions(self) -> None: + """Config file should be written with mode 0o600 (security requirement).""" + cfg = self._make_config() + cfg.consent = True + mode = stat.S_IMODE(os.stat(str(cfg.file_path)).st_mode) + self.assertEqual(mode, 0o600) + + def test_postpone_consent_sets_due_and_clears_consent(self) -> None: + """Postponing should clear consent and schedule next reminder.""" + cfg = self._make_config() + cfg.consent = True + cfg.postpone_consent() + self.assertIsNone(cfg.consent) + remind_after = cfg.remind_date + self.assertIsNotNone(remind_after) + self.assertGreater(remind_after, datetime.now(timezone.utc)) + + def test_postpone_does_not_override_explicit_opt_out(self) -> None: + """Postpone should not schedule reminders when consent is explicitly False.""" + cfg = self._make_config() + cfg.consent = False + cfg.postpone_consent() + self.assertFalse(cfg.consent) + self.assertIsNone(cfg.remind_date) + self.assertFalse(cfg.should_prompt_for_consent()) + + def test_should_prompt_for_consent_logic(self) -> None: + """Prompt logic follows consent and remind-after semantics.""" + cfg = self._make_config() + self.assertTrue(cfg.should_prompt_for_consent()) + + cfg.consent = True + self.assertFalse(cfg.should_prompt_for_consent()) + + cfg.clear_consent() + cfg.remind_date = datetime.now(timezone.utc) + timedelta(days=1) + self.assertFalse(cfg.should_prompt_for_consent()) + + cfg.remind_date = datetime.now(timezone.utc) - timedelta(seconds=1) + self.assertTrue(cfg.should_prompt_for_consent()) + + +class TestSerialize(unittest.TestCase): + """Tests for _serialize() — ZIP structure and metadata correctness.""" + + def setUp(self) -> None: + self._tmp_queue = Path(tempfile.mkdtemp(prefix="dc_queue_")) + + def tearDown(self) -> None: + shutil.rmtree(str(self._tmp_queue), ignore_errors=True) + + def _make_item(self, payload: dict, image_format: str = "TIFF") -> _WorkItem: + return _WorkItem( + event_name="test_event", + schema_version="1.0", + payload=payload, + image_format=image_format, + ) + + def test_zip_created(self) -> None: + """A ZIP file is created in queue_dir after serialisation.""" + item = self._make_item({"score": 0.9}) + zip_path = _serialize(item, self._tmp_queue) + self.assertTrue(zip_path.exists(), "ZIP file not created") + self.assertTrue(zip_path.suffix == ".zip") + + def test_zip_filename_format(self) -> None: + """ZIP filename follows --.zip convention.""" + item = self._make_item({"x": 1}) + zip_path = _serialize(item, self._tmp_queue) + name = zip_path.name + parts = name[:-4].split("-") # strip .zip + self.assertEqual(parts[0], "test_event") + self.assertEqual(len(parts[2]), 8, "UUID8 part should be 8 hex characters") + + def test_metadata_json_envelope_fields(self) -> None: + """metadata.json must contain all standard envelope fields.""" + item = self._make_item({"score": 0.5}) + zip_path = _serialize(item, self._tmp_queue) + with zipfile.ZipFile(str(zip_path)) as zf: + meta = json.loads(zf.read("metadata.json")) + required = {"sample_uuid", "timestamp_utc", "system_id", "odemis_version", + "event_name", "schema_version", "payload"} + self.assertEqual(required, required & meta.keys()) + self.assertEqual(meta["event_name"], "test_event") + self.assertEqual(meta["schema_version"], "1.0") + + def test_primitive_payload_inlined(self) -> None: + """Primitive payload values are inlined in metadata.json.""" + item = self._make_item({"score": 0.87, "n": 12, "name": "foo", "flag": True}) + zip_path = _serialize(item, self._tmp_queue) + with zipfile.ZipFile(str(zip_path)) as zf: + meta = json.loads(zf.read("metadata.json")) + self.assertAlmostEqual(meta["payload"]["score"], 0.87) + self.assertEqual(meta["payload"]["n"], 12) + self.assertEqual(meta["payload"]["name"], "foo") + self.assertTrue(meta["payload"]["flag"]) + + def test_numpy_array_exported_as_tiff(self) -> None: + """numpy.ndarray values are exported as .ome.tiff formats.""" + arr = numpy.zeros((64, 64), dtype=numpy.uint16) + item = self._make_item({"image": arr}) + zip_path = _serialize(item, self._tmp_queue) + with zipfile.ZipFile(str(zip_path)) as zf: + names = zf.namelist() + meta = json.loads(zf.read("metadata.json")) + output_formats = meta["payload"]["image"] + self.assertIn(output_formats, names, "TIFF sidecar not in ZIP") + self.assertTrue(output_formats.endswith(".ome.tiff")) + + def test_dict_payload_written_as_extra_json(self) -> None: + """dict payload values are written as extra_*.json.""" + item = self._make_item({"params": {"a": 1, "b": 2}}) + zip_path = _serialize(item, self._tmp_queue) + with zipfile.ZipFile(str(zip_path)) as zf: + names = zf.namelist() + meta = json.loads(zf.read("metadata.json")) + self.assertIn("extra_params.json", names) + self.assertEqual(meta["payload"]["params"], "extra_params.json") + + def test_list_payload_written_as_extra_json(self) -> None: + """list payload values are written as extra_*.json sidecars.""" + item = self._make_item({"items": [1, 2, 3]}) + zip_path = _serialize(item, self._tmp_queue) + with zipfile.ZipFile(str(zip_path)) as zf: + names = zf.namelist() + self.assertIn("extra_items.json", names) + + def test_atomic_write(self) -> None: + """No .tmp files remain after successful serialisation.""" + item = self._make_item({"x": 1}) + _serialize(item, self._tmp_queue) + tmps = list(self._tmp_queue.glob("*.tmp")) + self.assertEqual(tmps, [], "Leftover .tmp files found") + + def test_hdf5_image_format(self) -> None: + """When image_format=HDF5, DataArray is exported as an .h5.""" + arr = numpy.zeros((32, 32), dtype=numpy.float32) + item = self._make_item({"data": arr}, image_format="HDF5") + zip_path = _serialize(item, self._tmp_queue) + with zipfile.ZipFile(str(zip_path)) as zf: + names = zf.namelist() + meta = json.loads(zf.read("metadata.json")) + output_formats = meta["payload"]["data"] + self.assertIn(output_formats, names) + self.assertTrue(output_formats.endswith(".h5")) + + +class TestEnforceQueueLimit(unittest.TestCase): + """Tests for _enforce_queue_limit().""" + + def setUp(self) -> None: + self._tmp_queue = Path(tempfile.mkdtemp(prefix="dc_qlimit_")) + + def tearDown(self) -> None: + shutil.rmtree(str(self._tmp_queue), ignore_errors=True) + + def _write_zip(self, name: str, size_bytes: int, mtime: float) -> Path: + """Create a dummy ZIP file of the given size and modification time.""" + p = self._tmp_queue / name + p.write_bytes(b"\x00" * size_bytes) + os.utime(str(p), (mtime, mtime)) + return p + + def test_no_deletion_when_under_limit(self) -> None: + """Files are NOT deleted when total queue size is within the 10% limit.""" + self._write_zip("small.zip", 1024, time.time()) + _enforce_queue_limit(self._tmp_queue) + self.assertTrue((self._tmp_queue / "small.zip").exists()) + + def test_oldest_deleted_when_over_limit(self) -> None: + """Oldest ZIPs are deleted when the queue exceeds 10% of partition space.""" + import collections + FakeDiskUsage = collections.namedtuple("usage", ["total", "used", "free"]) + fake_usage = FakeDiskUsage(total=300, used=0, free=300) + + now = time.time() + # 3 files × 20 bytes = 60 bytes > 10% of 300 (= 30 bytes). + old = self._write_zip("old.zip", 20, now - 100) + self._write_zip("newer.zip", 20, now - 50) + self._write_zip("newest.zip", 20, now) + + with patch("odemis.util.datacollector.shutil.disk_usage", return_value=fake_usage): + _enforce_queue_limit(self._tmp_queue) + + self.assertFalse(old.exists(), "Oldest ZIP should have been deleted") + + def test_empty_dir_no_error(self) -> None: + """_enforce_queue_limit on an empty directory must not raise.""" + try: + _enforce_queue_limit(self._tmp_queue) + except Exception as exc: + self.fail(f"_enforce_queue_limit raised unexpectedly: {exc}") + + def test_nonexistent_dir_no_error(self) -> None: + """_enforce_queue_limit on a non-existent directory must not raise.""" + try: + _enforce_queue_limit(Path("/nonexistent/path/dc_queue")) + except Exception as exc: + self.fail(f"_enforce_queue_limit raised unexpectedly: {exc}") + + +class TestUploadAndRetry(unittest.TestCase): + """Tests upload backend and retry behavior.""" + + def setUp(self) -> None: + self._tmp_dir = Path(tempfile.mkdtemp(prefix="dc_upload_")) + self._queue_dir = self._tmp_dir / "queue" + self._queue_dir.mkdir(parents=True, exist_ok=True) + + cfg = DataCollectorConfig.__new__(DataCollectorConfig) + cfg.file_path = self._tmp_dir / "datacollector.config" + import threading + cfg._cp = configparser.ConfigParser(interpolation=None) + cfg._lock = threading.Lock() + cfg._read() + cfg.consent = True + self._cfg = cfg + self._worker = _BackgroundWorker(cfg, queue_dir=self._queue_dir) + + def tearDown(self) -> None: + shutil.rmtree(str(self._tmp_dir), ignore_errors=True) + + def _write_zip(self, name: str, mtime: float) -> Path: + """Create a pending ZIP in queue_dir with a stable mtime.""" + p = self._queue_dir / name + p.write_bytes(b"zip") + os.utime(str(p), (mtime, mtime)) + return p + + def test_upload_called_after_serialization(self) -> None: + """Worker should upload a serialized ZIP right after it is created.""" + order = [] + item = _WorkItem(event_name="upload_call_test", schema_version="1.0", payload={"x": 1}) + target_zip = self._queue_dir / "serialized.zip" + + class _Backend: + def upload(self, local_path: Path, remote_key: str) -> None: + order.append(("upload", local_path.name, remote_key)) + + def _fake_serialize(work_item: _WorkItem, queue_dir: Path) -> Path: + del work_item, queue_dir + target_zip.write_bytes(b"zip") + order.append(("serialize", target_zip.name)) + return target_zip + + with patch("odemis.util.datacollector._serialize", side_effect=_fake_serialize), \ + patch.object(self._worker, "_get_upload_backend", return_value=_Backend()): + self._worker._process_work_item(item) + + self.assertEqual(order[0], ("serialize", "serialized.zip")) + self.assertEqual(order[1][0], "upload") + self.assertEqual(order[1][1], "serialized.zip") + self.assertFalse(target_zip.exists(), "ZIP should be deleted after successful upload") + + def test_retry_on_failure_then_success(self) -> None: + """Failed upload is retried and eventually clears pending ZIPs.""" + now = time.time() + older = self._write_zip("older.zip", now - 60) + newer = self._write_zip("newer.zip", now - 30) + uploaded_names = [] + failures = {"count": 0} + + class _Backend: + def upload(self, local_path: Path, remote_key: str) -> None: + uploaded_names.append(local_path.name) + if failures["count"] == 0: + failures["count"] += 1 + raise ConnectionError("temporary network issue") + + with patch.object(self._worker, "_get_upload_backend", return_value=_Backend()): + had_pending = self._worker._process_pending_zips(self._queue_dir) + self.assertTrue(had_pending) + self.assertTrue(older.exists(), "Failed ZIP should remain for retry") + self.assertGreater(self._worker._next_retry_at, 0.0) + + with patch("odemis.util.datacollector.time.monotonic", return_value=self._worker._next_retry_at + 1.0): + had_pending = self._worker._process_pending_zips(self._queue_dir) + self.assertTrue(had_pending) + + self.assertFalse(older.exists()) + self.assertFalse(newer.exists()) + self.assertEqual(uploaded_names[:3], ["older.zip", "older.zip", "newer.zip"]) + + def test_pending_flush_oldest_first(self) -> None: + """Recovery flush should process queued ZIP files oldest-first.""" + now = time.time() + self._write_zip("oldest.zip", now - 120) + self._write_zip("middle.zip", now - 60) + self._write_zip("newest.zip", now - 10) + uploaded = [] + + class _Backend: + def upload(self, local_path: Path, remote_key: str) -> None: + uploaded.append(local_path.name) + + with patch.object(self._worker, "_get_upload_backend", return_value=_Backend()): + had_pending = self._worker._process_pending_zips(self._queue_dir) + self.assertTrue(had_pending) + self.assertEqual(uploaded, ["oldest.zip", "middle.zip", "newest.zip"]) + self.assertEqual(list(self._queue_dir.glob("*.zip")), []) + + +class TestRealS3Integration(unittest.TestCase): + """Real S3 integration tests for Phase 2 upload workflow. + + Tests use the credentials from ``_CREDENTIALS_PATH`` and upload to + ``S3_TEST_BUCKET`` (not the production bucket). The test class is + skipped automatically when the key file is absent. + """ + + @classmethod + def setUpClass(cls) -> None: + try: + import boto3 + except ImportError as exc: + raise unittest.SkipTest(f"boto3 is required for real S3 integration tests: {exc}") + + import json as _json + if not os.path.isfile(_CREDENTIALS_PATH): + raise unittest.SkipTest( + f"S3 key file not found at {_CREDENTIALS_PATH}; skipping real S3 integration tests." + ) + with open(_CREDENTIALS_PATH, "r") as fh: + creds = _json.load(fh) + + cls._access_key = creds["access_key"] + cls._secret_key = creds["secret_key"] + cls._bucket = S3_TEST_BUCKET + cls._region = S3_REGION + + cls._s3_client = boto3.client( + "s3", + aws_access_key_id=cls._access_key, + aws_secret_access_key=cls._secret_key, + region_name=cls._region, + ) + + def setUp(self) -> None: + self._tmp_dir = Path(tempfile.mkdtemp(prefix="dc_reals3_")) + self._queue_dir = self._tmp_dir / "queue" + self._queue_dir.mkdir(parents=True, exist_ok=True) + self._created_keys: list[str] = [] + + def tearDown(self) -> None: + for key in self._created_keys: + try: + self._s3_client.delete_object(Bucket=self._bucket, Key=key) + except Exception as exc: + logging.warning("Could not delete test object %s: %s", key, exc) + shutil.rmtree(str(self._tmp_dir), ignore_errors=True) + + def _new_remote_key(self, suffix: str = ".zip") -> str: + """Create a unique remote key for test uploads.""" + return f"odemis-integration-tests/{socket.gethostname()}/{uuid.uuid4().hex}{suffix}" + + def test_s3_upload_backend_uploads_file(self) -> None: + """S3UploadBackend.upload should place an object in the configured bucket.""" + local_path = self._tmp_dir / "sample.zip" + local_path.write_bytes(b"integration-test-payload") + remote_key = self._new_remote_key() + + backend = S3UploadBackend( + access_key=self._access_key, + secret_key=self._secret_key, + region=self._region, + bucket=self._bucket, + ) + backend.upload(local_path, remote_key) + self._created_keys.append(remote_key) + + response = self._s3_client.head_object(Bucket=self._bucket, Key=remote_key) + self.assertGreater(response["ContentLength"], 0) + + def test_worker_uploads_pending_and_deletes_local(self) -> None: + """Background worker should upload pending ZIPs and delete them locally.""" + old_zip = self._queue_dir / "old.zip" + new_zip = self._queue_dir / "new.zip" + old_zip.write_bytes(b"old") + new_zip.write_bytes(b"new") + now = time.time() + os.utime(str(old_zip), (now - 60, now - 60)) + os.utime(str(new_zip), (now - 30, now - 30)) + + cfg = DataCollectorConfig.__new__(DataCollectorConfig) + cfg.file_path = self._tmp_dir / "datacollector.config" + import threading + cfg._cp = configparser.ConfigParser(interpolation=None) + cfg._lock = threading.Lock() + cfg._read() + cfg.consent = True + worker = _BackgroundWorker(cfg, queue_dir=self._queue_dir) + + uploaded_keys: list[str] = [] + backend = S3UploadBackend( + access_key=self._access_key, + secret_key=self._secret_key, + region=self._region, + bucket=self._bucket, + ) + + def _capture_upload(local_path: Path, backend_obj: S3UploadBackend) -> None: + remote_key = f"{socket.gethostname()}/{local_path.name}" + backend_obj.upload(local_path, remote_key) + uploaded_keys.append(remote_key) + + with patch.object(worker, "_get_upload_backend", return_value=backend), \ + patch("odemis.util.datacollector._upload", side_effect=_capture_upload): + had_pending = worker._process_pending_zips(self._queue_dir) + + self.assertTrue(had_pending) + self.assertFalse(old_zip.exists(), "Old pending ZIP should be removed locally") + self.assertFalse(new_zip.exists(), "New pending ZIP should be removed locally") + self.assertEqual(len(uploaded_keys), 2) + + self._created_keys.extend(uploaded_keys) + for key in uploaded_keys: + self.assertIsNotNone(self._s3_client.head_object(Bucket=self._bucket, Key=key)) + + + +class DataCollectorTest(unittest.TestCase): + """Integration-level tests for DataCollector.record().""" + + def setUp(self) -> None: + self._tmp_dir = Path(tempfile.mkdtemp(prefix="dc_test_")) + self._queue_dir = self._tmp_dir / "queue" + + cfg = DataCollectorConfig.__new__(DataCollectorConfig) + cfg.file_path = self._tmp_dir / "datacollector.config" + import threading + cfg._cp = configparser.ConfigParser(interpolation=None) + cfg._lock = threading.Lock() + cfg._read() + cfg.consent = True + self._cfg = cfg + + worker = _BackgroundWorker(cfg, queue_dir=self._queue_dir) + + self._collector = DataCollector() + self._collector._config = cfg + self._collector._worker = worker + self._collector._init_ok = True + + def tearDown(self) -> None: + shutil.rmtree(str(self._tmp_dir), ignore_errors=True) + + def test_record_returns_fast(self) -> None: + """record() must return to the caller in under 10 ms.""" + arr = numpy.zeros((256, 256), dtype=numpy.uint16) + t0 = time.monotonic() + self._collector.record("perf_test", "1.0", {"image": arr}) + elapsed_ms = (time.monotonic() - t0) * 1000 + self.assertLess(elapsed_ms, 10.0, f"record() took {elapsed_ms:.1f} ms (limit: 10 ms)") + + def test_serialize_creates_zip_with_metadata(self) -> None: + """_serialize() produces a ZIP with valid metadata.json.""" + item = _WorkItem(event_name="zip_test", schema_version="1.0", payload={"score": 0.5}) + self._queue_dir.mkdir(parents=True, exist_ok=True) + zip_path = _serialize(item, self._queue_dir) + self.assertTrue(zip_path.exists()) + with zipfile.ZipFile(str(zip_path)) as zf: + meta = json.loads(zf.read("metadata.json")) + self.assertEqual(meta["event_name"], "zip_test") + self.assertIn("sample_uuid", meta) + self.assertIn("timestamp_utc", meta) + self.assertIn("odemis_version", meta) + + def test_noop_when_consent_false(self) -> None: + """record() does not enqueue work when consent is False.""" + self._cfg.consent = False + q = self._collector._worker._queue + size_before = q.qsize() + self._collector.record("no_consent_test", "1.0", {"x": 1}) + self.assertEqual(q.qsize(), size_before, "No item should be enqueued when consent is False") + + def test_noop_when_consent_none(self) -> None: + """record() does not enqueue work when consent has not been set.""" + self._cfg._cp.remove_option("general", "consent") + q = self._collector._worker._queue + size_before = q.qsize() + self._collector.record("no_consent_none_test", "1.0", {"x": 1}) + self.assertEqual(q.qsize(), size_before, "No item should be enqueued when consent is None") + + def test_no_exception_on_bad_payload_value(self) -> None: + """record() must not raise for unserializable payload values.""" + class _Unserializable: + def __repr__(self): + raise RuntimeError("boom") + + try: + self._collector.record("bad_payload", "1.0", {"bad": _Unserializable()}) + except Exception as exc: + self.fail(f"record() raised unexpectedly: {exc}") + + def test_raises_for_empty_event_name(self) -> None: + """record() raises ValueError for an empty event_name.""" + with self.assertRaises(ValueError): + self._collector.record("", "1.0", {}) + + def test_raises_for_non_string_event_name(self) -> None: + """record() raises ValueError when event_name is not a string.""" + with self.assertRaises(ValueError): + self._collector.record(123, "1.0", {}) # type: ignore[arg-type] + + def test_raises_for_empty_schema_version(self) -> None: + """record() raises ValueError for an empty schema_version.""" + with self.assertRaises(ValueError): + self._collector.record("event", "", {}) + + def test_raises_for_non_dict_payload(self) -> None: + """record() raises ValueError when payload is not a dict.""" + with self.assertRaises(ValueError): + self._collector.record("event", "1.0", [1, 2, 3]) # type: ignore[arg-type] + + def test_raises_for_invalid_image_format(self) -> None: + """record() raises ValueError for an unknown image_format.""" + with self.assertRaises(ValueError): + self._collector.record("event", "1.0", {}, image_format="PNG") + + def test_validation_raises_even_when_consent_false(self) -> None: + """Input validation fires before the consent gate.""" + self._cfg.consent = False + with self.assertRaises(ValueError): + self._collector.record("", "1.0", {}) + + def test_valid_hdf5_format_accepted(self) -> None: + """record() accepts 'HDF5' as image_format without raising.""" + try: + self._collector.record("event", "1.0", {}, image_format="HDF5") + except ValueError as exc: + self.fail(f"record() raised ValueError for valid HDF5 format: {exc}") + +if __name__ == "__main__": + unittest.main() diff --git a/src/odemis/util/test/dc_fetch_test.py b/src/odemis/util/test/dc_fetch_test.py new file mode 100644 index 0000000000..a67fe14ae3 --- /dev/null +++ b/src/odemis/util/test/dc_fetch_test.py @@ -0,0 +1,195 @@ +#!/usr/bin/env python3 +""" +Created on 11 March 2026 + +@author: Karishma Kumar + +Copyright © 2026 Karishma Kumar, Delmic + +This file is part of Odemis. + +Odemis is free software: you can redistribute it and/or modify it under the +terms of the GNU General Public License version 2 as published by the Free +Software Foundation. + +Odemis is distributed in the hope that it will be useful, but WITHOUT ANY +WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR +A PARTICULAR PURPOSE. See the GNU General Public License for more details. + +You should have received a copy of the GNU General Public License along with +Odemis. If not, see http://www.gnu.org/licenses/. +""" + +import tempfile +import unittest +from datetime import datetime, timezone +from pathlib import Path +from unittest.mock import Mock, patch + +from odemis.util import dc_fetch + + +class DCFetchTest(unittest.TestCase): + """Unit tests for S3 retrieval helpers.""" + + def test_parse_since_utc_date(self) -> None: + """Date input should parse as UTC midnight.""" + parsed = dc_fetch.parse_since_utc("2026-03-22") + self.assertEqual(parsed, datetime(2026, 3, 22, 0, 0, 0, tzinfo=timezone.utc)) + + def test_parse_key_timestamp(self) -> None: + """Timestamp should be parsed from key basename.""" + parsed = dc_fetch.parse_key_timestamp_utc("host/z_stack_acquired-20260322T104530-a1b2c3d4.zip") + self.assertEqual(parsed, datetime(2026, 3, 22, 10, 45, 30, tzinfo=timezone.utc)) + + def test_should_download_key_filters(self) -> None: + """Event and since filters should both be enforced.""" + key = "host/z_stack_acquired-20260322T104530-a1b2c3d4.zip" + since_before = datetime(2026, 3, 22, 10, 0, 0, tzinfo=timezone.utc) + since_after = datetime(2026, 3, 22, 11, 0, 0, tzinfo=timezone.utc) + self.assertTrue(dc_fetch.should_download_key(key, "z_stack_acquired", since_before)) + self.assertFalse(dc_fetch.should_download_key(key, "other_event", since_before)) + self.assertFalse(dc_fetch.should_download_key(key, "z_stack_acquired", since_after)) + + def test_iter_s3_objects_paginates(self) -> None: + """S3 iterator should follow continuation tokens.""" + client = Mock() + client.list_objects_v2.side_effect = [ + { + "Contents": [{"Key": "host/a.zip"}], + "IsTruncated": True, + "NextContinuationToken": "token-1", + }, + { + "Contents": [{"Key": "host/b.zip"}], + "IsTruncated": False, + }, + ] + keys = [item["Key"] for item in dc_fetch.iter_s3_objects(client, "bucket", "host/")] + self.assertEqual(keys, ["host/a.zip", "host/b.zip"]) + self.assertEqual(client.list_objects_v2.call_count, 2) + + def test_parse_host_filters_comma_list(self) -> None: + """Host parser should accept comma-separated values and normalize them.""" + hosts = dc_fetch.parse_host_filters("meteor-5099, atlas-001 ,/secom-22/") + self.assertEqual(hosts, ["meteor-5099", "atlas-001", "secom-22"]) + + def test_fetch_samples_downloads_matching_keys(self) -> None: + """Fetch flow should download matching keys and report counters.""" + with tempfile.TemporaryDirectory(prefix="dc_fetch_") as tmp_dir: + output_dir = Path(tmp_dir) + client = Mock() + client.list_objects_v2.return_value = { + "Contents": [ + {"Key": "host/evt-20260322T100000-aaaa1111.zip"}, + {"Key": "host/other-20260322T100000-bbbb2222.zip"}, + ], + "IsTruncated": False, + } + + def _download_file(_bucket: str, _key: str, filename: str) -> None: + Path(filename).write_bytes(b"zip") + + client.download_file.side_effect = _download_file + + with patch("odemis.util.dc_fetch.build_s3_client_from_config", return_value=(client, "bucket")): + result = dc_fetch.fetch_samples( + event_filter="evt", + since_utc=datetime(2026, 3, 22, 9, 0, 0, tzinfo=timezone.utc), + output_dir=output_dir, + ) + + self.assertEqual(result["listed"], 2) + self.assertEqual(result["matched"], 1) + self.assertEqual(result["downloaded"], 1) + self.assertEqual(result["failed"], 0) + self.assertTrue((output_dir / "evt-20260322T100000-aaaa1111.zip").exists()) + + def test_fetch_samples_applies_host_filter_prefix(self) -> None: + """Host filter should become the S3 list prefix.""" + with tempfile.TemporaryDirectory(prefix="dc_fetch_") as tmp_dir: + output_dir = Path(tmp_dir) + client = Mock() + client.list_objects_v2.return_value = {"Contents": [], "IsTruncated": False} + + with patch("odemis.util.dc_fetch.build_s3_client_from_config", return_value=(client, "bucket")): + dc_fetch.fetch_samples( + event_filter=None, + since_utc=None, + output_dir=output_dir, + host_filter="meteor-5099", + ) + + call_kwargs = client.list_objects_v2.call_args.kwargs + self.assertEqual(call_kwargs["Bucket"], "bucket") + self.assertEqual(call_kwargs["Prefix"], "meteor-5099/") + + def test_fetch_samples_applies_multiple_host_prefixes(self) -> None: + """Comma-separated hosts should trigger one listing call per host prefix.""" + with tempfile.TemporaryDirectory(prefix="dc_fetch_") as tmp_dir: + output_dir = Path(tmp_dir) + client = Mock() + client.list_objects_v2.return_value = {"Contents": [], "IsTruncated": False} + + with patch("odemis.util.dc_fetch.build_s3_client_from_config", return_value=(client, "bucket")): + dc_fetch.fetch_samples( + event_filter=None, + since_utc=None, + output_dir=output_dir, + host_filter="meteor-5099,atlas-001", + ) + + self.assertEqual(client.list_objects_v2.call_count, 2) + first_prefix = client.list_objects_v2.call_args_list[0].kwargs["Prefix"] + second_prefix = client.list_objects_v2.call_args_list[1].kwargs["Prefix"] + self.assertEqual(first_prefix, "meteor-5099/") + self.assertEqual(second_prefix, "atlas-001/") + + def test_fetch_samples_passes_bucket_endpoint_region_overrides(self) -> None: + """Overrides should be forwarded to the S3 client builder.""" + with tempfile.TemporaryDirectory(prefix="dc_fetch_") as tmp_dir: + output_dir = Path(tmp_dir) + client = Mock() + client.list_objects_v2.return_value = {"Contents": [], "IsTruncated": False} + + with patch("odemis.util.dc_fetch.build_s3_client_from_config", return_value=(client, "bucket")) as builder: + dc_fetch.fetch_samples( + event_filter=None, + since_utc=None, + output_dir=output_dir, + host_filter=None, + bucket_override="other-bucket", + endpoint_override="https://s3.eu-west-1.amazonaws.com", + region_override="eu-west-1", + ) + + kwargs = builder.call_args.kwargs + self.assertEqual(kwargs["bucket_override"], "other-bucket") + self.assertEqual(kwargs["endpoint_override"], "https://s3.eu-west-1.amazonaws.com") + self.assertEqual(kwargs["region_override"], "eu-west-1") + + def test_build_s3_client_uses_backend_region_as_default(self) -> None: + """build_s3_client_from_config should use backend._region when no override is given.""" + from odemis.util.datacollector import S3UploadBackend, S3_REGION + + backend = S3UploadBackend( + access_key="key", + secret_key="secret", + region=S3_REGION, + bucket="test-bucket", + ) + mock_config = Mock() + mock_config.get_upload_backend.return_value = backend + + with patch("boto3.client") as mock_boto3_client: + mock_boto3_client.return_value = Mock() + dc_fetch.build_s3_client_from_config(mock_config) + + call_kwargs = mock_boto3_client.call_args.kwargs + self.assertEqual(call_kwargs.get("region_name"), S3_REGION) + # endpoint_url must be None so boto3 resolves the regional endpoint automatically + self.assertIsNone(call_kwargs.get("endpoint_url")) + + +if __name__ == "__main__": + unittest.main() From 895588f5480440195eb3cfc87257ae54642b8c6c Mon Sep 17 00:00:00 2001 From: Karishma Kumar Date: Mon, 20 Apr 2026 14:15:35 +0200 Subject: [PATCH 2/7] store annoted data on Meteor --- src/odemis/acq/feature.py | 172 ++++++++++++++++++++++- src/odemis/acq/test/feature_test.py | 185 ++++++++++++++++++++++++- src/odemis/acq/test/test-features.json | 2 +- src/odemis/gui/cont/features.py | 110 +++++++++++++++ 4 files changed, 463 insertions(+), 6 deletions(-) diff --git a/src/odemis/acq/feature.py b/src/odemis/acq/feature.py index 22c58fd0d0..2e2b83328c 100644 --- a/src/odemis/acq/feature.py +++ b/src/odemis/acq/feature.py @@ -25,6 +25,7 @@ import json import logging import os +import random import threading import time from concurrent import futures @@ -51,10 +52,11 @@ MicroscopePostureManager, ) from odemis.acq.stitching._tiledacq import SAFE_REL_RANGE_DEFAULT -from odemis.acq.stream import Stream, StaticFluoStream +from odemis.acq.stream import Stream, StaticFluoStream, StaticSEMStream, StaticFIBStream from odemis.dataio import find_fittest_converter, get_available_formats from odemis.util import dataio, executeAsyncTask from odemis.util.comp import generate_zlevels +from odemis.util.datacollector import DataCollector from odemis.util.dataio import data_to_static_streams, open_acquisition, splitext from odemis.util.driver import estimate_stage_movement_time from odemis.util.filename import create_filename @@ -70,6 +72,9 @@ REFERENCE_IMAGE_FILENAME = "Reference-Alignment-FIB.ome.tiff" +# Probability that a newly created feature is marked as eligible for data collection. +FEATURE_COLLECT_PROBABILITY = 0.2 + USER_MILLING_TASKS_PATH = os.path.expanduser("~/.config/odemis/milling_tasks.yaml") @@ -173,7 +178,9 @@ def __init__(self, name: str, stage_position: Dict[str, float], fm_focus_position: Dict[str, float], streams: Optional[List[Stream]] = None, - milling_tasks: Optional[Dict[str, MillingTaskSettings]] = None, correlation_data=None): + milling_tasks: Optional[Dict[str, MillingTaskSettings]] = None, + correlation_data=None, + collect: Optional[bool] = None): """ :param name: (string) the feature name :param stage_position: (dict) the stage position of the feature (stage-bare) @@ -181,6 +188,8 @@ def __init__(self, name: str, :param streams: (List of StaticStream) list of acquired streams on this feature :param correlation_data: (Dict[str,FIBFMCorrelationData]) Dictionary mapping the feature status to FIBFMCorrelationData, where feature status like Active, Rough Milled or polished is the key. + :param collect: (bool or None) Whether this feature is eligible for data collection. + If None (default), the flag is initialised randomly with FEATURE_COLLECT_PROBABILITY. """ self.name = model.StringVA(name) # FIXME: The 'position' parameter should eventually contain the SampleStage coordinates and not stage bare from the stage_position! @@ -211,6 +220,13 @@ def __init__(self, name: str, correlation_data = {} self.correlation_data = correlation_data + # Whether this feature is eligible for data collection. + # Randomly assigned on creation with FEATURE_COLLECT_PROBABILITY; persisted in features.json. + if collect is None: + self.collect: bool = random.random() < FEATURE_COLLECT_PROBABILITY + else: + self.collect = collect + # attributes for automated milling self.path: str = None # TODO:support path creation here, rather than on milling data save self.reference_image: model.DataArray = None @@ -318,6 +334,7 @@ def get_features_dict(features: List[CryoFeature]) -> Dict[str, str]: for feature in features: feature_item = {'name': feature.name.value, 'status': feature.status.value, + 'collect': feature.collect, 'stage_position': feature.stage_position.value, 'fm_focus_position': feature.fm_focus_position.value, 'posture_positions': feature.posture_positions, @@ -351,7 +368,8 @@ def object_hook(self, obj): milling_task_json = obj.get('milling_tasks', {}) feature = CryoFeature(name=obj['name'], stage_position=stage_position, - fm_focus_position=fm_focus_position + fm_focus_position=fm_focus_position, + collect=obj.get('collect', None) ) feature.correlation_data = FIBFMCorrelationData.from_dict(correlation_data) if correlation_data else None feature.status.value = obj['status'] @@ -471,6 +489,154 @@ def _create_fibsem_filename(filename: str, acq_type: str) -> str: return create_filename(path, ptn, ext, count="001") + +def _is_zstack_stream(stream: "Stream") -> bool: + """Return True if the stream contains z-stack (3-D ZYX) data.""" + return hasattr(stream, "zIndex") + + +def _stream_overlaps_position(stream: "Stream", x: float, y: float) -> bool: + """Return True if the stage position (x, y) falls within the stream's field of view. + + :param stream: A static stream with a getBoundingBox() method. + :param x: Stage x position in metres. + :param y: Stage y position in metres. + :returns: True when the position is inside the bounding box, False otherwise. + """ + try: + bbox = stream.getBoundingBox() # (left, top, right, bottom) in metres + except Exception: + return False + left, top, right, bottom = bbox + return left <= x <= right and bottom <= y <= top + + +def collect_feature_data( + feature: "CryoFeature", + overview_streams: Optional[List["Stream"]] = None, + project_dir: Optional[str] = None, +) -> None: + """Collect anonymized data for a feature and submit it to the data collector. + + Skips immediately if feature.collect is False or if data collection consent + has not been granted. Never raises — all errors are logged and suppressed. + + The payload contains: + - First acquired z-stack per FM channel (or first FM image if no z-stack). + - FM and SEM overview images that spatially overlap the feature's position. + - Feature status, stage position, and FM focus position. + + Privacy rules enforced: + - Feature name is never included. + - Image payload keys are generic (channel_0, overview_fm_0, etc.). + - Original filenames are not included in the payload. + + After collection feature.collect is set to False to prevent re-collection. + + :param feature: The feature to collect data for. + :param overview_streams: Optional list of overview static streams. + Used to find FM / SEM overviews that overlap the feature position. + :param project_dir: Optional project directory path. When provided and + feature.streams is empty, streams are loaded from disk first. + """ + if not feature.collect: + return + + try: + _dc = DataCollector() + if not _dc.get_consent(): + return + except Exception: + logging.exception("collect_feature_data: failed to access DataCollector; skipping.") + return + + try: + + # Load feature streams from disk when not yet in memory. + if not feature.streams.value and project_dir: + try: + load_feature_streams_from_disk(feature, project_dir) + except Exception: + logging.exception( + "collect_feature_data: failed to load streams; skipping.") + return + + feature_streams = list(feature.streams.value) + + # Collect first z-stack per FM channel; fall back to first FM image per channel. + fm_zstacks: List = [] + fm_images: List = [] + for s in feature_streams: + if isinstance(s, StaticFluoStream): + if _is_zstack_stream(s): + fm_zstacks.append(s) + else: + fm_images.append(s) + + # Per channel: prefer z-stack, then plain FM image. + # Channels are delineated by MD_OUT_WL; use index as fallback. + selected_fm: List = [] + seen_channels: set = set() + for s in fm_zstacks + fm_images: + try: + channel_key = s.raw[0].metadata.get(model.MD_OUT_WL) + except (IndexError, AttributeError): + channel_key = None + if channel_key not in seen_channels: + seen_channels.add(channel_key) + selected_fm.append(s) + + # Collect spatially overlapping overview streams. + stage_pos = feature.stage_position.value + feat_x = stage_pos.get("x", 0.0) + feat_y = stage_pos.get("y", 0.0) + + overview_fm: List = [] + overview_sem: List = [] + for s in (overview_streams or []): + if not _stream_overlaps_position(s, feat_x, feat_y): + continue + if isinstance(s, StaticFluoStream): + overview_fm.append(s) + elif isinstance(s, (StaticSEMStream, StaticFIBStream)): + overview_sem.append(s) + + # Build privacy-preserving payload — generic keys, no names or filenames. + payload: dict = { + "status": feature.status.value, + "stage_position": dict(stage_pos), + "fm_focus_position": dict(feature.fm_focus_position.value), + } + + def _get_raw(stream: "Stream") -> Optional["model.DataArray"]: + try: + return stream.raw[0] if stream.raw else None + except Exception: + return None + + for idx, s in enumerate(selected_fm): + da = _get_raw(s) + if da is not None: + payload[f"channel_{idx}"] = da + + for idx, s in enumerate(overview_fm): + da = _get_raw(s) + if da is not None: + payload[f"overview_fm_{idx}"] = da + + for idx, s in enumerate(overview_sem): + da = _get_raw(s) + if da is not None: + payload[f"overview_sem_{idx}"] = da + + _dc.record("feature_collected", "1.0", payload) + + feature.collect = False + logging.debug("collect_feature_data: submitted feature data for collection.") + + except Exception: + logging.exception("collect_feature_data: unexpected error; feature data not collected.") + # To handle the timeout error when the stage is not able to move to the desired position # It logs the message and raises the MoveError exception class MoveError(Exception): diff --git a/src/odemis/acq/test/feature_test.py b/src/odemis/acq/test/feature_test.py index 8443f85090..7e3c9646ef 100644 --- a/src/odemis/acq/test/feature_test.py +++ b/src/odemis/acq/test/feature_test.py @@ -22,14 +22,21 @@ import logging import os import random +import shutil +import tempfile +import time import unittest +from unittest.mock import patch, MagicMock import numpy from odemis import model from odemis.acq.feature import ( CryoFeature, + FEATURE_COLLECT_PROBABILITY, FeaturesDecoder, + _is_zstack_stream, + collect_feature_data, get_features_dict, read_features, save_features, @@ -61,8 +68,8 @@ def tearDown(self): os.rmdir(self.path) def test_feature_encoder(self): - feature1 = CryoFeature("Feature-1", stage_position={"x": 0, "y": 0, "z": 0}, fm_focus_position={"z": 0}) - feature2 = CryoFeature("Feature-2", stage_position={"x": 1e-3, "y": 1e-3, "z": 1e-3}, fm_focus_position={"z": 2e-3}) + feature1 = CryoFeature("Feature-1", stage_position={"x": 0, "y": 0, "z": 0}, fm_focus_position={"z": 0}, collect=False) + feature2 = CryoFeature("Feature-2", stage_position={"x": 1e-3, "y": 1e-3, "z": 1e-3}, fm_focus_position={"z": 2e-3}, collect=False) feature1.milling_tasks = {} feature2.milling_tasks = {} features = [feature1, feature2] @@ -122,5 +129,179 @@ def test_feature_milling_tasks(self): filename = os.path.join(feature.path, f"{feature.name.value}-{REFERENCE_IMAGE_FILENAME}") self.assertTrue(os.path.exists(filename)) + +class TestCollectFlag(unittest.TestCase): + """Tests for the CryoFeature.collect flag and its persistence.""" + + def test_collect_flag_is_bool(self): + """CryoFeature.collect must be a bool when not explicitly provided.""" + f = CryoFeature("F", {"x": 0, "y": 0, "z": 0}, {"z": 0}) + self.assertIsInstance(f.collect, bool) + + def test_collect_flag_explicit_true(self): + """Passing collect=True must set the attribute to True.""" + f = CryoFeature("F", {"x": 0, "y": 0, "z": 0}, {"z": 0}, collect=True) + self.assertTrue(f.collect) + + def test_collect_flag_explicit_false(self): + """Passing collect=False must set the attribute to False.""" + f = CryoFeature("F", {"x": 0, "y": 0, "z": 0}, {"z": 0}, collect=False) + self.assertFalse(f.collect) + + def test_collect_flag_random_probability(self): + """With enough samples, roughly FEATURE_COLLECT_PROBABILITY fraction should be True.""" + n = 500 + trues = sum( + CryoFeature(f"F{i}", {"x": 0, "y": 0, "z": 0}, {"z": 0}).collect + for i in range(n) + ) + ratio = trues / n + # Allow ±10 percentage points tolerance. + self.assertAlmostEqual(ratio, FEATURE_COLLECT_PROBABILITY, delta=0.10) + + def test_collect_flag_persisted_in_dict(self): + """get_features_dict must include 'collect' in each feature entry.""" + f = CryoFeature("F", {"x": 0, "y": 0, "z": 0}, {"z": 0}, collect=True) + d = get_features_dict([f]) + self.assertIn("collect", d["feature_list"][0]) + self.assertTrue(d["feature_list"][0]["collect"]) + + def test_collect_flag_round_trip_json(self): + """collect flag must survive JSON serialise / deserialise round-trip.""" + for value in (True, False): + f = CryoFeature("F", {"x": 0, "y": 0, "z": 0}, {"z": 0}, collect=value) + j = json.dumps(get_features_dict([f])) + loaded = json.loads(j, cls=FeaturesDecoder) + self.assertEqual(loaded[0].collect, value) + + def test_collect_flag_missing_in_json_defaults_to_random(self): + """When collect key is absent in loaded JSON, the flag is randomly assigned.""" + f = CryoFeature("F", {"x": 0, "y": 0, "z": 0}, {"z": 0}, collect=True) + d = get_features_dict([f]) + del d["feature_list"][0]["collect"] + j = json.dumps(d) + loaded = json.loads(j, cls=FeaturesDecoder) + self.assertIsInstance(loaded[0].collect, bool) + + +class TestCollectFeatureData(unittest.TestCase): + """Tests for collect_feature_data().""" + + def _make_feature(self, collect: bool = True, pos=None) -> CryoFeature: + if pos is None: + pos = {"x": 0.0, "y": 0.0, "z": 0.0} + return CryoFeature("TestFeature", pos, {"z": 0.0}, collect=collect) + + def test_skips_when_collect_false(self): + """collect_feature_data must not call record() when feature.collect is False.""" + f = self._make_feature(collect=False) + with patch("odemis.acq.feature.DataCollector") as MockDC: + collect_feature_data(f) + MockDC.return_value.get_consent.assert_not_called() + + def test_skips_when_no_consent(self): + """collect_feature_data must not call record() when consent is not granted.""" + f = self._make_feature(collect=True) + with patch("odemis.acq.feature.DataCollector") as MockDC: + MockDC.return_value.get_consent.return_value = False + collect_feature_data(f) + MockDC.return_value.record.assert_not_called() + + def test_calls_record_when_consent_given(self): + """collect_feature_data must call record() once when consent is True.""" + f = self._make_feature(collect=True) + with patch("odemis.acq.feature.DataCollector") as MockDC: + mock_instance = MockDC.return_value + mock_instance.get_consent.return_value = True + collect_feature_data(f) + mock_instance.record.assert_called_once() + + def test_sets_collect_false_after_collection(self): + """feature.collect must be False after collect_feature_data is called.""" + f = self._make_feature(collect=True) + with patch("odemis.acq.feature.DataCollector") as MockDC: + MockDC.return_value.get_consent.return_value = True + collect_feature_data(f) + self.assertFalse(f.collect) + + def test_collect_false_not_changed_when_skipped(self): + """feature.collect remains True when collection is skipped due to no consent.""" + f = self._make_feature(collect=True) + with patch("odemis.acq.feature.DataCollector") as MockDC: + MockDC.return_value.get_consent.return_value = False + collect_feature_data(f) + self.assertTrue(f.collect) + + def test_payload_contains_status_and_positions(self): + """Payload must contain status, stage_position and fm_focus_position.""" + f = self._make_feature(collect=True) + f.status.value = "Active" + captured = {} + + def fake_record(event_name, schema_version, payload, **kwargs): + captured.update(payload) + + with patch("odemis.acq.feature.DataCollector") as MockDC: + MockDC.return_value.get_consent.return_value = True + MockDC.return_value.record.side_effect = fake_record + collect_feature_data(f) + + self.assertIn("status", captured) + self.assertIn("stage_position", captured) + self.assertIn("fm_focus_position", captured) + + def test_payload_has_no_feature_name(self): + """Payload must not contain the feature name string as a key or value.""" + f = self._make_feature(collect=True) + f.name.value = "my_secret_feature_name" + captured = {} + + def fake_record(event_name, schema_version, payload, **kwargs): + captured.update(payload) + + with patch("odemis.acq.feature.DataCollector") as MockDC: + MockDC.return_value.get_consent.return_value = True + MockDC.return_value.record.side_effect = fake_record + collect_feature_data(f) + + self.assertNotIn("my_secret_feature_name", captured) + self.assertNotIn("my_secret_feature_name", str(captured.keys())) + + def test_payload_channel_keys_are_generic(self): + """Image payload keys must be generic (channel_N) not derived from feature name.""" + from odemis.acq.stream import StaticFluoStream + arr = numpy.zeros((64, 64), dtype=numpy.uint16) + da = model.DataArray(arr, metadata={ + model.MD_POS: (0.0, 0.0), + model.MD_PIXEL_SIZE: (1e-6, 1e-6), + }) + stream = StaticFluoStream("test_stream", da) + f = self._make_feature(collect=True) + f.streams.value.append(stream) + captured = {} + + def fake_record(event_name, schema_version, payload, **kwargs): + captured.update(payload) + + with patch("odemis.acq.feature.DataCollector") as MockDC: + MockDC.return_value.get_consent.return_value = True + MockDC.return_value.record.side_effect = fake_record + collect_feature_data(f) + + image_keys = [k for k in captured if k.startswith("channel_")] + self.assertTrue(len(image_keys) >= 1, "Expected at least one channel_N key in payload") + for k in image_keys: + self.assertRegex(k, r"^channel_\d+$") + + def test_never_raises(self): + """collect_feature_data must never raise an exception.""" + f = self._make_feature(collect=True) + with patch("odemis.acq.feature.DataCollector", side_effect=RuntimeError("boom")): + try: + collect_feature_data(f) + except Exception as exc: + self.fail(f"collect_feature_data raised unexpectedly: {exc}") + + if __name__ == "__main__": unittest.main() diff --git a/src/odemis/acq/test/test-features.json b/src/odemis/acq/test/test-features.json index 19bcaad5d8..91290c1810 100644 --- a/src/odemis/acq/test/test-features.json +++ b/src/odemis/acq/test/test-features.json @@ -1 +1 @@ -{"feature_list": [{"name": "Feature-1", "status": "Active", "stage_position": {"x": 0, "y": 0, "z": 0}, "fm_focus_position": {"z": 0}, "posture_positions": {}, "milling_tasks": {}, "correlation_data": {}, "superz_stream_name": null, "superz_focused": null}, {"name": "Feature-2", "status": "Active", "stage_position": {"x": 0.001, "y": 0.001, "z": 0.001}, "fm_focus_position": {"z": 0.002}, "posture_positions": {}, "milling_tasks": {}, "correlation_data": {}, "superz_stream_name": null, "superz_focused": null}]} \ No newline at end of file +{"feature_list": [{"name": "Feature-1", "status": "Active", "collect": false, "stage_position": {"x": 0, "y": 0, "z": 0}, "fm_focus_position": {"z": 0}, "posture_positions": {}, "milling_tasks": {}, "correlation_data": {}, "superz_stream_name": null, "superz_focused": null}, {"name": "Feature-2", "status": "Active", "collect": false, "stage_position": {"x": 0.001, "y": 0.001, "z": 0.001}, "fm_focus_position": {"z": 0.002}, "posture_positions": {}, "milling_tasks": {}, "correlation_data": {}, "superz_stream_name": null, "superz_focused": null}]} \ No newline at end of file diff --git a/src/odemis/gui/cont/features.py b/src/odemis/gui/cont/features.py index 001dc89ed6..1575e2f43e 100644 --- a/src/odemis/gui/cont/features.py +++ b/src/odemis/gui/cont/features.py @@ -23,7 +23,9 @@ import copy import itertools import logging +import math import os +import threading import wx from typing import Dict, List @@ -36,6 +38,7 @@ FEATURE_READY_TO_MILL, FEATURE_ROUGH_MILLED, CryoFeature, + collect_feature_data, get_feature_position_at_posture, save_features, FIBFMCorrelationData, @@ -58,6 +61,10 @@ SUPPORTED_POSTURES = [SEM_IMAGING, FM_IMAGING, MILLING, FIB_IMAGING] +# Maximum distance (in metres) within which another feature is considered "nearby" +# for the feature-deletion data-collection trigger. +_NEARBY_FEATURE_DISTANCE_M = 100e-6 + class CryoFeatureController(object): """ controller to handle the cryo feature panel elements It requires features list VA & currentFeature VA on the tab data to function properly @@ -113,6 +120,9 @@ def __init__(self, tab_data, panel, tab, mode: guimod.AcquiMode): self._panel.btn_feature_save_position.Show(LICENCE_MILLING_ENABLED) self.pm.current_posture.subscribe(self._on_posture_change) + # Track previous posture so we can detect FM → SEM/FIB transitions. + self._prev_posture = self.pm.getCurrentPostureLabel() if self.pm else None + def _on_btn_create_move_feature(self, _): # As this button is identical to clicking the feature tool, # directly change the tool to feature tool @@ -130,6 +140,7 @@ def _on_btn_delete_feature(self, _): style=wx.YES_NO | wx.ICON_QUESTION | wx.CENTER) ans = box.ShowModal() if ans == wx.ID_YES: + self._maybe_collect_on_delete(current_feature) self._tab_data_model.main.features.value.remove(current_feature) self._tab_data_model.main.currentFeature.value = None if self.acqui_mode is guimod.AcquiMode.FIBSEM: @@ -251,6 +262,11 @@ def _on_posture_change(self, posture: int): return self._enable_feature_ctrls(True) + prev = self._prev_posture + self._prev_posture = posture + if prev == FM_IMAGING and posture in (SEM_IMAGING, FIB_IMAGING): + self._collect_eligible_features_in_thread() + def _enable_feature_ctrls(self, enable: bool): """ Enables/disables the feature controls @@ -457,3 +473,97 @@ def _on_ctrl_feature_z_change(self): zpos = self._panel.ctrl_feature_z.GetValue() return {"z": zpos} + + # ── Data-collection helpers ────────────────────────────────────────────── + + def _get_overview_streams(self) -> list: + """Return overview streams from the tab data model, or an empty list.""" + try: + return list(self._tab_data_model.overviewStreams.value) + except AttributeError: + return [] + + def _get_project_dir(self) -> str: + """Return the current project directory, or an empty string.""" + try: + return self._tab.conf.pj_last_path or "" + except AttributeError: + return "" + + def _collect_feature_in_thread(self, feature: CryoFeature) -> None: + """Launch collect_feature_data for a single feature in a background thread. + + :param feature: The feature to collect data for. + """ + overview_streams = self._get_overview_streams() + project_dir = self._get_project_dir() + + def _run(): + collect_feature_data(feature, overview_streams=overview_streams, project_dir=project_dir) + + t = threading.Thread(target=_run, name="FeatureDataCollection", daemon=True) + t.start() + + def _collect_eligible_features_in_thread(self) -> None: + """Launch collect_feature_data for all features with collect=True in a background thread.""" + features = list(self._tab_data_model.main.features.value) + overview_streams = self._get_overview_streams() + project_dir = self._get_project_dir() + + def _run(): + for feature in features: + if feature.collect: + collect_feature_data( + feature, + overview_streams=overview_streams, + project_dir=project_dir, + ) + + t = threading.Thread(target=_run, name="FeatureDataCollectionBulk", daemon=True) + t.start() + + def _has_zstack_stream(self, feature: CryoFeature) -> bool: + """Return True if the feature has at least one z-stack stream. + + :param feature: The feature to check. + :returns: True when a z-stack stream is present, False otherwise. + """ + from odemis.acq.feature import _is_zstack_stream + return any(_is_zstack_stream(s) for s in feature.streams.value) + + def _has_nearby_feature(self, feature: CryoFeature, distance_m: float = _NEARBY_FEATURE_DISTANCE_M) -> bool: + """Return True if any other feature is within distance_m of the given feature. + + :param feature: The feature to check proximity for. + :param distance_m: Maximum distance in metres to be considered nearby. + :returns: True when another feature is within the threshold distance. + """ + pos = feature.stage_position.value + fx, fy = pos.get("x", 0.0), pos.get("y", 0.0) + for other in self._tab_data_model.main.features.value: + if other is feature: + continue + other_pos = other.stage_position.value + ox, oy = other_pos.get("x", 0.0), other_pos.get("y", 0.0) + dist = math.sqrt((fx - ox) ** 2 + (fy - oy) ** 2) + if dist <= distance_m: + return True + return False + + def _maybe_collect_on_delete(self, feature: CryoFeature) -> None: + """Trigger data collection before a feature is deleted if eligible. + + Collection is triggered when all three conditions are met: + - feature.collect is True + - The feature has at least one z-stack stream + - No other feature is within 100 µm + + :param feature: The feature about to be deleted. + """ + if not feature.collect: + return + if not self._has_zstack_stream(feature): + return + if self._has_nearby_feature(feature): + return + self._collect_feature_in_thread(feature) From c94530d26dfd37ecd914d6bc3da28c2a6b7b276a Mon Sep 17 00:00:00 2001 From: Karishma Kumar Date: Mon, 20 Apr 2026 14:24:51 +0200 Subject: [PATCH 3/7] make images required to upload --- src/odemis/acq/feature.py | 7 +++ src/odemis/acq/test/feature_test.py | 71 +++++++++++++++++++++-------- 2 files changed, 58 insertions(+), 20 deletions(-) diff --git a/src/odemis/acq/feature.py b/src/odemis/acq/feature.py index 2e2b83328c..16d133a9fc 100644 --- a/src/odemis/acq/feature.py +++ b/src/odemis/acq/feature.py @@ -629,6 +629,13 @@ def _get_raw(stream: "Stream") -> Optional["model.DataArray"]: if da is not None: payload[f"overview_sem_{idx}"] = da + image_keys = [k for k in payload if k.startswith(("channel_", "overview_fm_", "overview_sem_"))] + if not image_keys: + logging.debug( + "collect_feature_data: no images found for feature; skipping upload." + ) + return + _dc.record("feature_collected", "1.0", payload) feature.collect = False diff --git a/src/odemis/acq/test/feature_test.py b/src/odemis/acq/test/feature_test.py index 7e3c9646ef..0a23c3590b 100644 --- a/src/odemis/acq/test/feature_test.py +++ b/src/odemis/acq/test/feature_test.py @@ -192,6 +192,22 @@ def _make_feature(self, collect: bool = True, pos=None) -> CryoFeature: pos = {"x": 0.0, "y": 0.0, "z": 0.0} return CryoFeature("TestFeature", pos, {"z": 0.0}, collect=collect) + def _make_fluo_stream(self): + """Return a minimal StaticFluoStream with a 2-D DataArray.""" + from odemis.acq.stream import StaticFluoStream + arr = numpy.zeros((64, 64), dtype=numpy.uint16) + da = model.DataArray(arr, metadata={ + model.MD_POS: (0.0, 0.0), + model.MD_PIXEL_SIZE: (1e-6, 1e-6), + }) + return StaticFluoStream("ch0", da) + + def _make_feature_with_stream(self, collect: bool = True) -> CryoFeature: + """Return a feature with one FM stream attached.""" + f = self._make_feature(collect=collect) + f.streams.value.append(self._make_fluo_stream()) + return f + def test_skips_when_collect_false(self): """collect_feature_data must not call record() when feature.collect is False.""" f = self._make_feature(collect=False) @@ -201,15 +217,31 @@ def test_skips_when_collect_false(self): def test_skips_when_no_consent(self): """collect_feature_data must not call record() when consent is not granted.""" - f = self._make_feature(collect=True) + f = self._make_feature_with_stream(collect=True) with patch("odemis.acq.feature.DataCollector") as MockDC: MockDC.return_value.get_consent.return_value = False collect_feature_data(f) MockDC.return_value.record.assert_not_called() + def test_no_record_without_images(self): + """record() must NOT be called when the feature has no image streams.""" + f = self._make_feature(collect=True) # no streams attached + with patch("odemis.acq.feature.DataCollector") as MockDC: + MockDC.return_value.get_consent.return_value = True + collect_feature_data(f) + MockDC.return_value.record.assert_not_called() + + def test_collect_flag_unchanged_without_images(self): + """feature.collect must stay True when skipped due to no images.""" + f = self._make_feature(collect=True) # no streams + with patch("odemis.acq.feature.DataCollector") as MockDC: + MockDC.return_value.get_consent.return_value = True + collect_feature_data(f) + self.assertTrue(f.collect) + def test_calls_record_when_consent_given(self): - """collect_feature_data must call record() once when consent is True.""" - f = self._make_feature(collect=True) + """collect_feature_data must call record() once when consent is True and images exist.""" + f = self._make_feature_with_stream(collect=True) with patch("odemis.acq.feature.DataCollector") as MockDC: mock_instance = MockDC.return_value mock_instance.get_consent.return_value = True @@ -217,8 +249,8 @@ def test_calls_record_when_consent_given(self): mock_instance.record.assert_called_once() def test_sets_collect_false_after_collection(self): - """feature.collect must be False after collect_feature_data is called.""" - f = self._make_feature(collect=True) + """feature.collect must be False after successful collection with images.""" + f = self._make_feature_with_stream(collect=True) with patch("odemis.acq.feature.DataCollector") as MockDC: MockDC.return_value.get_consent.return_value = True collect_feature_data(f) @@ -226,15 +258,15 @@ def test_sets_collect_false_after_collection(self): def test_collect_false_not_changed_when_skipped(self): """feature.collect remains True when collection is skipped due to no consent.""" - f = self._make_feature(collect=True) + f = self._make_feature_with_stream(collect=True) with patch("odemis.acq.feature.DataCollector") as MockDC: MockDC.return_value.get_consent.return_value = False collect_feature_data(f) self.assertTrue(f.collect) - def test_payload_contains_status_and_positions(self): - """Payload must contain status, stage_position and fm_focus_position.""" - f = self._make_feature(collect=True) + def test_payload_contains_status_positions_and_image(self): + """Payload must contain status, stage_position, fm_focus_position, and at least one image.""" + f = self._make_feature_with_stream(collect=True) f.status.value = "Active" captured = {} @@ -249,10 +281,12 @@ def fake_record(event_name, schema_version, payload, **kwargs): self.assertIn("status", captured) self.assertIn("stage_position", captured) self.assertIn("fm_focus_position", captured) + image_keys = [k for k in captured if k.startswith(("channel_", "overview_fm_", "overview_sem_"))] + self.assertTrue(len(image_keys) >= 1, "Payload must contain at least one image") def test_payload_has_no_feature_name(self): """Payload must not contain the feature name string as a key or value.""" - f = self._make_feature(collect=True) + f = self._make_feature_with_stream(collect=True) f.name.value = "my_secret_feature_name" captured = {} @@ -268,16 +302,13 @@ def fake_record(event_name, schema_version, payload, **kwargs): self.assertNotIn("my_secret_feature_name", str(captured.keys())) def test_payload_channel_keys_are_generic(self): - """Image payload keys must be generic (channel_N) not derived from feature name.""" - from odemis.acq.stream import StaticFluoStream - arr = numpy.zeros((64, 64), dtype=numpy.uint16) - da = model.DataArray(arr, metadata={ - model.MD_POS: (0.0, 0.0), - model.MD_PIXEL_SIZE: (1e-6, 1e-6), - }) - stream = StaticFluoStream("test_stream", da) - f = self._make_feature(collect=True) - f.streams.value.append(stream) + """Image payload keys must be generic (channel_N), not derived from feature or stream name. + + A StaticFluoStream named 'test_stream' is attached to the feature. + After collection the payload key for the image must be 'channel_0', + not 'test_stream' or the feature name — ensuring data privacy. + """ + f = self._make_feature_with_stream(collect=True) captured = {} def fake_record(event_name, schema_version, payload, **kwargs): From 547d2c962db0be60a51982a994ce7e7e8db30a27 Mon Sep 17 00:00:00 2001 From: Karishma Kumar Date: Tue, 21 Apr 2026 11:27:35 +0200 Subject: [PATCH 4/7] trigger on feature change --- src/odemis/acq/feature.py | 2 +- src/odemis/acq/test/feature_test.py | 103 ++++++++++++++++++++++++++-- src/odemis/gui/cont/features.py | 25 +++++++ 3 files changed, 125 insertions(+), 5 deletions(-) diff --git a/src/odemis/acq/feature.py b/src/odemis/acq/feature.py index 16d133a9fc..4d24d5dd50 100644 --- a/src/odemis/acq/feature.py +++ b/src/odemis/acq/feature.py @@ -508,7 +508,7 @@ def _stream_overlaps_position(stream: "Stream", x: float, y: float) -> bool: except Exception: return False left, top, right, bottom = bbox - return left <= x <= right and bottom <= y <= top + return left <= x <= right and top <= y <= bottom def collect_feature_data( diff --git a/src/odemis/acq/test/feature_test.py b/src/odemis/acq/test/feature_test.py index 0a23c3590b..7f179152e0 100644 --- a/src/odemis/acq/test/feature_test.py +++ b/src/odemis/acq/test/feature_test.py @@ -34,16 +34,17 @@ from odemis.acq.feature import ( CryoFeature, FEATURE_COLLECT_PROBABILITY, + FEATURE_READY_TO_MILL, FeaturesDecoder, + MILLING, + REFERENCE_IMAGE_FILENAME, _is_zstack_stream, + _stream_overlaps_position, collect_feature_data, get_features_dict, + load_milling_tasks, read_features, save_features, - load_milling_tasks, - FEATURE_READY_TO_MILL, - MILLING, - REFERENCE_IMAGE_FILENAME, ) from odemis.acq.milling import DEFAULT_MILLING_TASKS_PATH @@ -333,6 +334,100 @@ def test_never_raises(self): except Exception as exc: self.fail(f"collect_feature_data raised unexpectedly: {exc}") + def test_collects_on_status_change(self): + """Subscribing to feature.status and calling collect_feature_data on change must call record(). + + This simulates the controller's _on_status_for_collection subscriber: when + the feature status VA changes, collect_feature_data is invoked and record() + is called exactly once (consent granted, images present, collect=True). + """ + f = self._make_feature_with_stream(collect=True) + record_calls = [] + + def fake_record(event_name, schema_version, payload, **kwargs): + record_calls.append((event_name, schema_version)) + + def _on_status_changed(_status): + if f.collect: + collect_feature_data(f) + + f.status.subscribe(_on_status_changed, init=False) + try: + with patch("odemis.acq.feature.DataCollector") as MockDC: + MockDC.return_value.get_consent.return_value = True + MockDC.return_value.record.side_effect = fake_record + f.status.value = FEATURE_READY_TO_MILL + finally: + f.status.unsubscribe(_on_status_changed) + + self.assertEqual(len(record_calls), 1) + self.assertEqual(record_calls[0][0], "feature_collected") + + def test_no_collect_on_status_change_when_disabled(self): + """record() must not be called on status change when feature.collect is False. + + Simulates the controller's _on_status_for_collection guard on feature.collect. + """ + f = self._make_feature_with_stream(collect=False) + + def _on_status_changed(_status): + if f.collect: + collect_feature_data(f) + + f.status.subscribe(_on_status_changed, init=False) + try: + with patch("odemis.acq.feature.DataCollector") as MockDC: + MockDC.return_value.get_consent.return_value = True + f.status.value = FEATURE_READY_TO_MILL + MockDC.return_value.record.assert_not_called() + finally: + f.status.unsubscribe(_on_status_changed) + + +class TestStreamHelpers(unittest.TestCase): + """Tests for _is_zstack_stream and _stream_overlaps_position.""" + + def _make_static_fluo_stream(self, shape=(64, 64), pos=(0.0, 0.0), pixel_size=(1e-6, 1e-6)): + """Return a minimal StaticFluoStream.""" + from odemis.acq.stream import StaticFluoStream + arr = numpy.zeros(shape, dtype=numpy.uint16) + da = model.DataArray(arr, metadata={ + model.MD_POS: pos, + model.MD_PIXEL_SIZE: pixel_size, + }) + return StaticFluoStream("test_stream", da) + + def _make_zstack_stream(self, pos=(0.0, 0.0), pixel_size=(1e-6, 1e-6)): + """Return a minimal StaticFluoStream that looks like a z-stack (has zIndex).""" + s = self._make_static_fluo_stream(pos=pos, pixel_size=pixel_size) + s.zIndex = model.IntContinuous(0, (0, 3)) + return s + + def test_overlaps_centre(self): + """Position at the stream centre must overlap.""" + # 64 x 64 pixels at 1 µm/pixel centred at (0, 0) → bbox ±32 µm. + s = self._make_static_fluo_stream() + self.assertTrue(_stream_overlaps_position(s, 0.0, 0.0)) + + def test_overlaps_edge(self): + """Position exactly on the bounding-box edge must still overlap.""" + s = self._make_static_fluo_stream(pos=(0.0, 0.0), pixel_size=(2e-6, 2e-6)) + # half-width = 64/2 * 2e-6 = 64e-6 m → right edge at +64e-6 + self.assertTrue(_stream_overlaps_position(s, 64e-6, 0.0)) + + def test_no_overlap_outside(self): + """Position clearly outside the bounding box must not overlap.""" + s = self._make_static_fluo_stream() + # bbox is ±32 µm; 100 µm is well outside. + self.assertFalse(_stream_overlaps_position(s, 100e-6, 0.0)) + + def test_no_overlap_bad_stream(self): + """_stream_overlaps_position returns False when getBoundingBox() raises.""" + from unittest.mock import MagicMock + bad_stream = MagicMock() + bad_stream.getBoundingBox.side_effect = AttributeError("no bbox") + self.assertFalse(_stream_overlaps_position(bad_stream, 0.0, 0.0)) + if __name__ == "__main__": unittest.main() diff --git a/src/odemis/gui/cont/features.py b/src/odemis/gui/cont/features.py index 1575e2f43e..f46e4f8b04 100644 --- a/src/odemis/gui/cont/features.py +++ b/src/odemis/gui/cont/features.py @@ -94,6 +94,9 @@ def __init__(self, tab_data, panel, tab, mode: guimod.AcquiMode): self._feature_status_va_connector = None self._feature_z_va_connector = None + # Feature whose status VA we are subscribed to for data-collection triggering. + self._status_collect_feature: Optional[CryoFeature] = None + self._tab_data_model.main.features.subscribe(self._on_features_changes, init=True) self._tab_data_model.main.currentFeature.subscribe(self._on_current_feature_changes, init=True) @@ -349,6 +352,11 @@ def _on_current_feature_changes(self, feature): if self._feature_z_va_connector: self._feature_z_va_connector.disconnect() + # Unsubscribe status-change data-collection trigger from the previous feature. + if self._status_collect_feature is not None: + self._status_collect_feature.status.unsubscribe(self._on_status_for_collection) + self._status_collect_feature = None + self._update_feature_cmb_list() if feature is None: @@ -415,6 +423,10 @@ def _on_current_feature_changes(self, feature): ctrl_2_va=self._on_ctrl_feature_z_change, va_2_ctrl=self._on_feature_focus_pos) + # Subscribe to status changes to trigger data collection (init=False: skip current value). + feature.status.subscribe(self._on_status_for_collection, init=False) + self._status_collect_feature = feature + def _on_feature_focus_pos(self, fm_focus_position: dict): # Set the feature Z ctrl with the focus position self._panel.ctrl_feature_z.SetValue(fm_focus_position["z"]) @@ -490,6 +502,19 @@ def _get_project_dir(self) -> str: except AttributeError: return "" + def _on_status_for_collection(self, _status: str) -> None: + """Trigger data collection when the current feature's status changes. + + Called by the status VA subscriber (init=False) so it fires only on + actual changes, never on initial subscription. + + :param _status: The new feature status value (unused; feature is read + from the stored reference to avoid a race with currentFeature). + """ + feature = self._status_collect_feature + if feature is not None and feature.collect: + self._collect_feature_in_thread(feature) + def _collect_feature_in_thread(self, feature: CryoFeature) -> None: """Launch collect_feature_data for a single feature in a background thread. From 8c7f40927e106544a8be98c9b0f003640f607410 Mon Sep 17 00:00:00 2001 From: Karishma Kumar Date: Wed, 22 Apr 2026 11:00:21 +0200 Subject: [PATCH 5/7] use env variable to use test bucket --- src/odemis/util/datacollector.py | 21 +++++++++++-- src/odemis/util/test/datacollector_test.py | 34 ++++++++++++++++++++++ 2 files changed, 53 insertions(+), 2 deletions(-) diff --git a/src/odemis/util/datacollector.py b/src/odemis/util/datacollector.py index a3cec680b8..5cfa551e9c 100644 --- a/src/odemis/util/datacollector.py +++ b/src/odemis/util/datacollector.py @@ -61,8 +61,12 @@ S3_BUCKET = "delmic-odemis-collect" # S3 bucket used for automated tests (not the production bucket). +# Selected automatically when the environment variable TEST_DATACOLLECTION=1 is set. S3_TEST_BUCKET = "delmic-odemis-collect-test" +# Environment variable name that switches the framework to the test bucket. +_TEST_DATACOLLECTION_ENV = "TEST_DATACOLLECTION" + # S3 endpoint URL — None means let boto3 resolve the regional endpoint automatically. # Set explicitly only for custom S3-compatible storage. S3_ENDPOINT_URL = None @@ -272,14 +276,27 @@ def should_prompt_for_consent(self) -> bool: return datetime.now(timezone.utc) >= remind_after def get_upload_backend(self) -> "S3UploadBackend": - """Return the configured upload backend instance.""" + """Return the configured upload backend instance. + + When the environment variable ``TEST_DATACOLLECTION=1`` is set the test + bucket (``S3_TEST_BUCKET``) is used instead of the production bucket so + that developer machines never contaminate real data. + """ credentials = _search_credentials() + if os.environ.get(_TEST_DATACOLLECTION_ENV) == "1": + bucket = S3_TEST_BUCKET + logging.info( + "DataCollector: %s=1 — using test bucket '%s'", + _TEST_DATACOLLECTION_ENV, bucket, + ) + else: + bucket = S3_BUCKET return S3UploadBackend( access_key=credentials["access_key"], secret_key=credentials["secret_key"], endpoint_url=S3_ENDPOINT_URL, region=S3_REGION, - bucket=S3_BUCKET, + bucket=bucket, ) diff --git a/src/odemis/util/test/datacollector_test.py b/src/odemis/util/test/datacollector_test.py index 65e7118f28..171a9e119c 100644 --- a/src/odemis/util/test/datacollector_test.py +++ b/src/odemis/util/test/datacollector_test.py @@ -40,9 +40,11 @@ DataCollector, DataCollectorConfig, S3UploadBackend, + S3_BUCKET, S3_REGION, S3_TEST_BUCKET, _CREDENTIALS_PATH, + _TEST_DATACOLLECTION_ENV, _BackgroundWorker, _WorkItem, _enforce_queue_limit, @@ -127,6 +129,38 @@ def test_should_prompt_for_consent_logic(self) -> None: cfg.remind_date = datetime.now(timezone.utc) - timedelta(seconds=1) self.assertTrue(cfg.should_prompt_for_consent()) + def test_get_upload_backend_uses_production_bucket_by_default(self) -> None: + """get_upload_backend uses the production bucket when TEST_DATACOLLECTION is unset.""" + cfg = self._make_config() + fake_creds = {"access_key": "AKID", "secret_key": "SECRET"} + with patch("odemis.util.datacollector._search_credentials", return_value=fake_creds), \ + patch.dict(os.environ, {}, clear=False) as env: + env.pop(_TEST_DATACOLLECTION_ENV, None) + backend = cfg.get_upload_backend() + self.assertIsInstance(backend, S3UploadBackend) + self.assertEqual(backend._bucket, S3_BUCKET) + + def test_get_upload_backend_uses_test_bucket_when_env_set(self) -> None: + """get_upload_backend selects S3_TEST_BUCKET when TEST_DATACOLLECTION=1.""" + cfg = self._make_config() + fake_creds = {"access_key": "AKID", "secret_key": "SECRET"} + with patch("odemis.util.datacollector._search_credentials", return_value=fake_creds), \ + patch.dict(os.environ, {_TEST_DATACOLLECTION_ENV: "1"}): + backend = cfg.get_upload_backend() + self.assertIsInstance(backend, S3UploadBackend) + self.assertEqual(backend._bucket, S3_TEST_BUCKET) + + def test_get_upload_backend_ignores_non_one_env_value(self) -> None: + """get_upload_backend falls back to production when TEST_DATACOLLECTION != '1'.""" + cfg = self._make_config() + fake_creds = {"access_key": "AKID", "secret_key": "SECRET"} + for value in ("0", "true", "yes", ""): + with self.subTest(value=value), \ + patch("odemis.util.datacollector._search_credentials", return_value=fake_creds), \ + patch.dict(os.environ, {_TEST_DATACOLLECTION_ENV: value}): + backend = cfg.get_upload_backend() + self.assertEqual(backend._bucket, S3_BUCKET, f"Expected production bucket for env={value!r}") + class TestSerialize(unittest.TestCase): """Tests for _serialize() — ZIP structure and metadata correctness.""" From 279f1b461a599e177df45580dac62c51077f183b Mon Sep 17 00:00:00 2001 From: Karishma Kumar Date: Wed, 22 Apr 2026 11:59:24 +0200 Subject: [PATCH 6/7] change the collection per project per session --- src/odemis/acq/feature.py | 17 ++-- src/odemis/acq/test/feature_test.py | 93 ++++++++++++++++---- src/odemis/gui/cont/tabs/cryo_chamber_tab.py | 18 +++- src/odemis/gui/model/tab_gui_data.py | 3 +- 4 files changed, 103 insertions(+), 28 deletions(-) diff --git a/src/odemis/acq/feature.py b/src/odemis/acq/feature.py index 4d24d5dd50..90d592d578 100644 --- a/src/odemis/acq/feature.py +++ b/src/odemis/acq/feature.py @@ -25,7 +25,6 @@ import json import logging import os -import random import threading import time from concurrent import futures @@ -180,7 +179,7 @@ def __init__(self, name: str, streams: Optional[List[Stream]] = None, milling_tasks: Optional[Dict[str, MillingTaskSettings]] = None, correlation_data=None, - collect: Optional[bool] = None): + collect: bool = False): """ :param name: (string) the feature name :param stage_position: (dict) the stage position of the feature (stage-bare) @@ -188,8 +187,9 @@ def __init__(self, name: str, :param streams: (List of StaticStream) list of acquired streams on this feature :param correlation_data: (Dict[str,FIBFMCorrelationData]) Dictionary mapping the feature status to FIBFMCorrelationData, where feature status like Active, Rough Milled or polished is the key. - :param collect: (bool or None) Whether this feature is eligible for data collection. - If None (default), the flag is initialised randomly with FEATURE_COLLECT_PROBABILITY. + :param collect: (bool) Whether this feature is eligible for data collection. + Defaults to False. The GUI sets this based on the per-project sampling + decision made when a project is opened or created. """ self.name = model.StringVA(name) # FIXME: The 'position' parameter should eventually contain the SampleStage coordinates and not stage bare from the stage_position! @@ -221,11 +221,8 @@ def __init__(self, name: str, self.correlation_data = correlation_data # Whether this feature is eligible for data collection. - # Randomly assigned on creation with FEATURE_COLLECT_PROBABILITY; persisted in features.json. - if collect is None: - self.collect: bool = random.random() < FEATURE_COLLECT_PROBABILITY - else: - self.collect = collect + # Set by the GUI from the per-project sampling decision; persisted in features.json. + self.collect: bool = collect # attributes for automated milling self.path: str = None # TODO:support path creation here, rather than on milling data save @@ -369,7 +366,7 @@ def object_hook(self, obj): feature = CryoFeature(name=obj['name'], stage_position=stage_position, fm_focus_position=fm_focus_position, - collect=obj.get('collect', None) + collect=obj.get('collect', False) ) feature.correlation_data = FIBFMCorrelationData.from_dict(correlation_data) if correlation_data else None feature.status.value = obj['status'] diff --git a/src/odemis/acq/test/feature_test.py b/src/odemis/acq/test/feature_test.py index 7f179152e0..6254f650eb 100644 --- a/src/odemis/acq/test/feature_test.py +++ b/src/odemis/acq/test/feature_test.py @@ -135,9 +135,9 @@ class TestCollectFlag(unittest.TestCase): """Tests for the CryoFeature.collect flag and its persistence.""" def test_collect_flag_is_bool(self): - """CryoFeature.collect must be a bool when not explicitly provided.""" + """CryoFeature.collect must be False by default.""" f = CryoFeature("F", {"x": 0, "y": 0, "z": 0}, {"z": 0}) - self.assertIsInstance(f.collect, bool) + self.assertFalse(f.collect) def test_collect_flag_explicit_true(self): """Passing collect=True must set the attribute to True.""" @@ -149,17 +149,6 @@ def test_collect_flag_explicit_false(self): f = CryoFeature("F", {"x": 0, "y": 0, "z": 0}, {"z": 0}, collect=False) self.assertFalse(f.collect) - def test_collect_flag_random_probability(self): - """With enough samples, roughly FEATURE_COLLECT_PROBABILITY fraction should be True.""" - n = 500 - trues = sum( - CryoFeature(f"F{i}", {"x": 0, "y": 0, "z": 0}, {"z": 0}).collect - for i in range(n) - ) - ratio = trues / n - # Allow ±10 percentage points tolerance. - self.assertAlmostEqual(ratio, FEATURE_COLLECT_PROBABILITY, delta=0.10) - def test_collect_flag_persisted_in_dict(self): """get_features_dict must include 'collect' in each feature entry.""" f = CryoFeature("F", {"x": 0, "y": 0, "z": 0}, {"z": 0}, collect=True) @@ -175,14 +164,86 @@ def test_collect_flag_round_trip_json(self): loaded = json.loads(j, cls=FeaturesDecoder) self.assertEqual(loaded[0].collect, value) - def test_collect_flag_missing_in_json_defaults_to_random(self): - """When collect key is absent in loaded JSON, the flag is randomly assigned.""" + def test_collect_flag_missing_in_json_defaults_to_false(self): + """When collect key is absent in loaded JSON, the flag defaults to False.""" f = CryoFeature("F", {"x": 0, "y": 0, "z": 0}, {"z": 0}, collect=True) d = get_features_dict([f]) del d["feature_list"][0]["collect"] j = json.dumps(d) loaded = json.loads(j, cls=FeaturesDecoder) - self.assertIsInstance(loaded[0].collect, bool) + self.assertFalse(loaded[0].collect) + + +class TestPerProjectSampling(unittest.TestCase): + """Tests for the per-project data-collection sampling decision. + + The GUI model stores features_collectable which is set once per project + (randomly with FEATURE_COLLECT_PROBABILITY) and passed explicitly to + CryoFeature on creation, so all features in a project share the same + collect value. + """ + + def _make_main(self, features_collectable: bool) -> object: + """Return a lightweight mock of the main GUI data model.""" + main = MagicMock() + main.features_collectable = features_collectable + return main + + def test_features_collectable_true_propagates_to_feature(self): + """When features_collectable=True all new features get collect=True.""" + main = self._make_main(True) + collect = getattr(main, "features_collectable", False) + f = CryoFeature("F", {"x": 0, "y": 0, "z": 0}, {"z": 0}, collect=collect) + self.assertTrue(f.collect) + + def test_features_collectable_false_propagates_to_feature(self): + """When features_collectable=False all new features get collect=False.""" + main = self._make_main(False) + collect = getattr(main, "features_collectable", False) + f = CryoFeature("F", {"x": 0, "y": 0, "z": 0}, {"z": 0}, collect=collect) + self.assertFalse(f.collect) + + def test_features_collectable_absent_defaults_false(self): + """getattr fallback returns False when features_collectable is not set.""" + main = MagicMock(spec=[]) # no attributes at all + collect = getattr(main, "features_collectable", False) + self.assertFalse(collect) + + def test_per_project_decision_uniform_within_project(self): + """All features created in one project session share the same collect flag.""" + main = self._make_main(True) + collect = getattr(main, "features_collectable", False) + features = [ + CryoFeature(f"F{i}", {"x": i * 1e-6, "y": 0, "z": 0}, {"z": 0}, collect=collect) + for i in range(20) + ] + self.assertTrue(all(f.collect for f in features)) + + def test_loaded_features_are_reset_to_not_collectable(self): + """Features loaded from disk must be immediately marked collect=False. + + This mirrors the behaviour of _load_project_data in cryo_chamber_tab: + loaded features were either already collected or not selected, so they + must never be re-collected in the new session. + """ + loaded_features = [ + CryoFeature(f"F{i}", {"x": 0, "y": 0, "z": 0}, {"z": 0}, collect=True) + for i in range(5) + ] + # Simulate the reset performed by _load_project_data. + for f in loaded_features: + f.collect = False + self.assertTrue(all(not f.collect for f in loaded_features)) + + def test_project_sampling_probability(self): + """Over many projects, roughly FEATURE_COLLECT_PROBABILITY fraction are collectable.""" + n_projects = 500 + collectable_count = sum( + 1 for _ in range(n_projects) + if random.random() < FEATURE_COLLECT_PROBABILITY + ) + ratio = collectable_count / n_projects + self.assertAlmostEqual(ratio, FEATURE_COLLECT_PROBABILITY, delta=0.10) class TestCollectFeatureData(unittest.TestCase): diff --git a/src/odemis/gui/cont/tabs/cryo_chamber_tab.py b/src/odemis/gui/cont/tabs/cryo_chamber_tab.py index 5f9ea25db8..fec6119c64 100644 --- a/src/odemis/gui/cont/tabs/cryo_chamber_tab.py +++ b/src/odemis/gui/cont/tabs/cryo_chamber_tab.py @@ -382,6 +382,17 @@ def _change_project_conf(self, new_dir): self.conf.pj_ptn, self.conf.pj_count = guess_pattern(new_dir) self.txt_projectpath.Value = os.path.basename(self.conf.pj_last_path) self.tab_data_model.main.project_path.value = new_dir + # Decide once per project whether features created during this session are + # eligible for data collection. Stored as a dynamic attribute — not part + # of the formal model — and read by add_new_feature via getattr. + self.tab_data_model.main.features_collectable = ( + random.random() < FEATURE_COLLECT_PROBABILITY + ) + logging.debug( + "Project '%s': features_collectable=%s", + os.path.basename(new_dir), + self.tab_data_model.main.features_collectable, + ) logging.debug("Generated project folder name pattern '%s'", self.conf.pj_ptn) def _create_new_dir(self): @@ -516,7 +527,12 @@ def _load_project_data(self, evt: wx.Event) -> bool: if len(streams) > 0: correlation_tab.correlation_controller.add_streams(streams) - # load features + # load features — immediately mark all as not collectable. + # Loaded features were either already collected in a previous session or + # were never selected; the new per-project sampling decision (set in + # _change_project_conf above) applies only to features created after this. + for f in proj_data["features"]: + f.collect = False self.tab_data_model.main.features.value = proj_data["features"] # log project data diff --git a/src/odemis/gui/model/tab_gui_data.py b/src/odemis/gui/model/tab_gui_data.py index f8c2e20058..dbeb6bdebe 100644 --- a/src/odemis/gui/model/tab_gui_data.py +++ b/src/odemis/gui/model/tab_gui_data.py @@ -343,7 +343,8 @@ def add_new_feature(self, stage_position: Dict[str, float], else: md = self.main.focus.getMetadata() fm_focus_position = md[model.MD_FAV_POS_ACTIVE] - feature = CryoFeature(f_name, stage_position, fm_focus_position) + features_collectable = getattr(self.main, "features_collectable", False) + feature = CryoFeature(f_name, stage_position, fm_focus_position, collect=features_collectable) for p in pm.postures: # calculate the position at all postures get_feature_position_at_posture(pm, feature, p) From 510e3abc1fc5516c3b5bcf6058d4a4cd48d04247 Mon Sep 17 00:00:00 2001 From: Karishma Kumar Date: Wed, 22 Apr 2026 17:10:03 +0200 Subject: [PATCH 7/7] fix import issues in cryo_chamber_tab --- src/odemis/gui/cont/tabs/cryo_chamber_tab.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/odemis/gui/cont/tabs/cryo_chamber_tab.py b/src/odemis/gui/cont/tabs/cryo_chamber_tab.py index fec6119c64..917eeb1d09 100644 --- a/src/odemis/gui/cont/tabs/cryo_chamber_tab.py +++ b/src/odemis/gui/cont/tabs/cryo_chamber_tab.py @@ -29,13 +29,14 @@ import math import os.path from concurrent.futures import CancelledError +import random import wx import odemis.gui.cont.views as viewcont import odemis.gui.model as guimod from odemis import model -from odemis.acq.feature import load_project_data +from odemis.acq.feature import FEATURE_COLLECT_PROBABILITY, load_project_data from odemis.acq.move import ( ALIGNMENT, COATING,