diff --git a/README.md b/README.md index d8451c0..ab94490 100644 --- a/README.md +++ b/README.md @@ -49,6 +49,8 @@ project_names = unique_project_names() | `source_data` | Mapping from derived asset names to their source raw asset names | `s3://allen-data-views/data-asset-cache/zs_source_data.pqt` | metadata | False | `name`, `source_data`, `pipeline_name`, `processing_time` | | `quality_control` | Quality control table with one row per QC metric | `s3://allen-data-views/data-asset-cache/zs_qc/` | asset | True (by `subject_id`) | `name`, `stage`, `modality`, `value`, `status`, `asset_name` | | `assets_smartspim` | SmartSPIM assets with processing status and neuroglancer links | `s3://allen-data-views/data-asset-cache/zs_assets_smartspim.pqt` | metadata | False | `subject_id`, `genotype`, `institution`, `acquisition_start_time`, `processing_end_time`, `stitched_link`, `processed`, `name`, `channel_1`, `segmentation_link_1`, `quantification_link_1`, `channel_2`, `segmentation_link_2`, `quantification_link_2`, `channel_3`, `segmentation_link_3`, `quantification_link_3` | +| `procedures` | Subject procedures summary, one row per procedure per surgery | `s3://allen-data-views/data-asset-cache/zs_procedures.pqt` | asset | False | `procedure_key`, `subject_id`, `surgery_start_date`, `procedure_type` | +| `brain_injections` | Detailed Injection and BrainInjection data, one row per injection | `s3://allen-data-views/data-asset-cache/zs_brain_injections.pqt` | asset | False | `procedure_key`, `subject_id`, `surgery_start_date`, `procedure_type`, `targeted_structure_name`, `targeted_structure_acronym`, `relative_position`, `coordinate_system_name`, ``, `injection_materials`, `injection_profile`, `injection_volume`, `injection_volume_unit`, `protocol_id` | The `raw_to_derived` function is not a table stored in S3, instead it is used by passing an asset_name (or list of asset names) and a modality. The function returns the latest derived asset matching the requested pattern. diff --git a/pyproject.toml b/pyproject.toml index a284539..c3d0ea4 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -21,7 +21,7 @@ dependencies = [ 'pyarrow', 'boto3', 'pandas>=2.2.0', - 'aind-data-access-api[docdb]', + 'aind-data-access-api[docdb]>=1.10.0,<2', ] [dependency-groups] diff --git a/scripts/hide_procedures_acorn.py b/scripts/hide_procedures_acorn.py new file mode 100644 index 0000000..9ebb9a4 --- /dev/null +++ b/scripts/hide_procedures_acorn.py @@ -0,0 +1,14 @@ +"""Run the procedures and brain_injections hide_acorn for all subjects in one pass.""" + +from zombie_squirrel.acorns import ACORN_REGISTRY, NAMES + + +def main(): + """Hide procedures and brain_injections acorns for all subjects.""" + print("Fetching procedures data for all subjects...") + ACORN_REGISTRY[NAMES["procedures"]](force_update=True) + print("Procedures cache update complete.") + + +if __name__ == "__main__": + main() diff --git a/scripts/test_procedures_integration.py b/scripts/test_procedures_integration.py new file mode 100644 index 0000000..9a3d326 --- /dev/null +++ b/scripts/test_procedures_integration.py @@ -0,0 +1,118 @@ +"""Integration tests for procedures and brain_injections acorns against S3.""" + +import unittest + +import boto3 +import pandas as pd + +from zombie_squirrel.acorns import NAMES + +BUCKET = "allen-data-views" +PROCEDURES_KEY = f"data-asset-cache/zs_{NAMES['procedures']}.pqt" +INJECTIONS_KEY = f"data-asset-cache/zs_{NAMES['injections']}.pqt" +TEST_SUBJECT = "813992" + + +def _s3_key_exists(key: str) -> bool: + s3 = boto3.client("s3") + try: + s3.head_object(Bucket=BUCKET, Key=key) + return True + except s3.exceptions.ClientError: + return False + + +class TestProceduresS3(unittest.TestCase): + """Integration tests for the procedures acorn on S3.""" + + def test_file_exists(self): + self.assertTrue(_s3_key_exists(PROCEDURES_KEY), f"No procedures file found at s3://{BUCKET}/{PROCEDURES_KEY}") + + def test_has_expected_columns(self): + import os + + os.environ["FOREST_TYPE"] = "s3" + from zombie_squirrel.acorns import ACORN_REGISTRY + + df = ACORN_REGISTRY[NAMES["procedures"]](force_update=False) + self.assertIsInstance(df, pd.DataFrame) + self.assertFalse(df.empty) + for col in ("procedure_key", "subject_id", "surgery_start_date", "procedure_type"): + self.assertIn(col, df.columns, f"Missing column: {col}") + + def test_procedure_keys_are_unique(self): + import os + + os.environ["FOREST_TYPE"] = "s3" + from zombie_squirrel.acorns import ACORN_REGISTRY + + df = ACORN_REGISTRY[NAMES["procedures"]](force_update=False) + self.assertEqual(df["procedure_key"].nunique(), len(df), "procedure_key values are not unique") + + def test_contains_test_subject(self): + import os + + os.environ["FOREST_TYPE"] = "s3" + from zombie_squirrel.acorns import ACORN_REGISTRY + + df = ACORN_REGISTRY[NAMES["procedures"]](force_update=False) + self.assertIn(TEST_SUBJECT, df["subject_id"].values, f"Subject {TEST_SUBJECT} not found in procedures table") + + +class TestBrainInjectionsS3(unittest.TestCase): + """Integration tests for the brain_injections acorn on S3.""" + + def test_file_exists(self): + self.assertTrue( + _s3_key_exists(INJECTIONS_KEY), f"No brain_injections file found at s3://{BUCKET}/{INJECTIONS_KEY}" + ) + + def test_has_expected_columns(self): + import os + + os.environ["FOREST_TYPE"] = "s3" + from zombie_squirrel.acorns import ACORN_REGISTRY + + df = ACORN_REGISTRY[NAMES["injections"]](force_update=False) + self.assertIsInstance(df, pd.DataFrame) + self.assertFalse(df.empty) + for col in ( + "procedure_key", + "subject_id", + "surgery_start_date", + "procedure_type", + "targeted_structure_acronym", + "injection_profile", + "injection_volume", + "injection_volume_unit", + ): + self.assertIn(col, df.columns, f"Missing column: {col}") + + def test_procedure_keys_join_to_procedures_table(self): + """Every procedure_key in brain_injections must appear in procedures.""" + import os + + os.environ["FOREST_TYPE"] = "s3" + from zombie_squirrel.acorns import ACORN_REGISTRY + + proc_df = ACORN_REGISTRY[NAMES["procedures"]](force_update=False) + inj_df = ACORN_REGISTRY[NAMES["injections"]](force_update=False) + + orphans = set(inj_df["procedure_key"]) - set(proc_df["procedure_key"]) + self.assertEqual(orphans, set(), f"brain_injections has procedure_keys not in procedures: {orphans}") + + def test_contains_brain_injections(self): + import os + + os.environ["FOREST_TYPE"] = "s3" + from zombie_squirrel.acorns import ACORN_REGISTRY + + df = ACORN_REGISTRY[NAMES["injections"]](force_update=False) + self.assertTrue( + (df["procedure_type"] == "Brain injection").any(), + "No 'Brain injection' rows found in brain_injections table", + ) + + +if __name__ == "__main__": + unittest.main() diff --git a/src/zombie_squirrel/__init__.py b/src/zombie_squirrel/__init__.py index b906a75..a8d34f9 100644 --- a/src/zombie_squirrel/__init__.py +++ b/src/zombie_squirrel/__init__.py @@ -11,6 +11,7 @@ from zombie_squirrel.acorn_helpers.asset_basics import asset_basics # noqa: F401 from zombie_squirrel.acorn_helpers.assets_smartspim import assets_smartspim # noqa: F401 from zombie_squirrel.acorn_helpers.custom import custom # noqa: F401 +from zombie_squirrel.acorn_helpers.procedures import brain_injections, procedures # noqa: F401 from zombie_squirrel.acorn_helpers.qc import qc, qc_columns # noqa: F401 from zombie_squirrel.acorn_helpers.raw_to_derived import raw_to_derived # noqa: F401 from zombie_squirrel.acorn_helpers.source_data import source_data # noqa: F401 diff --git a/src/zombie_squirrel/acorn_helpers/__init__.py b/src/zombie_squirrel/acorn_helpers/__init__.py index 48d23de..c0618e4 100644 --- a/src/zombie_squirrel/acorn_helpers/__init__.py +++ b/src/zombie_squirrel/acorn_helpers/__init__.py @@ -3,6 +3,7 @@ from zombie_squirrel.acorn_helpers import ( # noqa: F401 asset_basics, custom, + procedures, qc, raw_to_derived, source_data, diff --git a/src/zombie_squirrel/acorn_helpers/assets_smartspim.py b/src/zombie_squirrel/acorn_helpers/assets_smartspim.py index e549660..56f8dbb 100644 --- a/src/zombie_squirrel/acorn_helpers/assets_smartspim.py +++ b/src/zombie_squirrel/acorn_helpers/assets_smartspim.py @@ -119,7 +119,9 @@ def _build_rows(raw_to_stitched: dict[str, str | None], metadata: dict[str, dict channel = channels[i - 1] if i <= len(channels) else None row[f"channel_{i}"] = channel row[f"segmentation_link_{i}"] = _segmentation_link(location, channel) if (processed and channel) else None - row[f"quantification_link_{i}"] = _quantification_link(location, channel) if (processed and channel) else None + row[f"quantification_link_{i}"] = ( + _quantification_link(location, channel) if (processed and channel) else None + ) rows.append(row) return rows @@ -157,9 +159,7 @@ def assets_smartspim(force_update: bool = False) -> pd.DataFrame: ) basics = asset_basics() - raw_spim = basics[ - (basics["data_level"] == "raw") & (basics["modalities"].str.contains("SPIM", na=False)) - ] + raw_spim = basics[(basics["data_level"] == "raw") & (basics["modalities"].str.contains("SPIM", na=False))] raw_spim_names = list(raw_spim["name"].dropna()) sd = source_data() diff --git a/src/zombie_squirrel/acorn_helpers/procedures.py b/src/zombie_squirrel/acorn_helpers/procedures.py new file mode 100644 index 0000000..9f12389 --- /dev/null +++ b/src/zombie_squirrel/acorn_helpers/procedures.py @@ -0,0 +1,324 @@ +"""Procedures acorn - subject procedures and brain injections tables.""" + +import logging + +import pandas as pd +from aind_data_access_api.document_db import MetadataDbClient + +import zombie_squirrel.acorns as acorns +from zombie_squirrel.squirrel import Column +from zombie_squirrel.utils import ( + SquirrelMessage, + setup_logging, +) + + +def _to_float(value) -> float | None: + if value is None: + return None + try: + return float(value) + except (TypeError, ValueError): + return None + + +def _serialize_materials(materials: list) -> str: + """Serialize a list of injection materials to a semicolon-separated string of names.""" + if not materials: + return "" + return "; ".join(m.get("name") or "" for m in materials) + + +def _axis_names_from_coord_sys(coord_sys: dict) -> list[str]: + """Return ordered axis names from a coordinate system dict.""" + axes = coord_sys.get("axes") or [] + return [a.get("name", f"axis_{i}") for i, a in enumerate(axes)] + + +def _coord_systems_from_procedures(proc_block: dict, surgery: dict) -> dict[str, list[str]]: + """Build a name->axis_names mapping from top-level and surgery coordinate systems.""" + result = {} + for cs in (proc_block.get("coordinate_system"), surgery.get("coordinate_system")): + if cs and isinstance(cs, dict): + name = cs.get("name", "") + if name: + result[name] = _axis_names_from_coord_sys(cs) + return result + + +def _extract_translation_by_axes(coordinates: list, axis_names: list[str]) -> dict: + """Map the first Translation and Rotation transform values to axis-named columns.""" + result = {name: None for name in axis_names} + result.update({f"{name}_rotation": None for name in axis_names}) + for site in coordinates: + if not isinstance(site, list): + continue + for transform in site: + obj_type = transform.get("object_type") + if obj_type == "Translation": + vals = transform.get("translation") or [] + result.update({axis_names[i]: vals[i] if i < len(vals) else None for i in range(len(axis_names))}) + elif obj_type == "Rotation": + vals = transform.get("rotation") or [] + result.update( + {f"{axis_names[i]}_rotation": vals[i] if i < len(vals) else None for i in range(len(axis_names))} + ) + return result + + +def _extract_surgery_fields(surgery: dict) -> dict: + """Extract flat surgery-level metadata fields (excluding coordinate_system and measured_coordinates).""" + anaesthesia = surgery.get("anaesthesia") or {} + experimenters = surgery.get("experimenters") or [] + return { + "surgery_protocol_id": surgery.get("protocol_id"), + "experimenters": "; ".join(str(e) for e in experimenters), + "ethics_review_id": surgery.get("ethics_review_id"), + "animal_weight_prior": _to_float(surgery.get("animal_weight_prior")), + "animal_weight_post": _to_float(surgery.get("animal_weight_post")), + "weight_unit": surgery.get("weight_unit"), + "anaesthetic_type": anaesthesia.get("anaesthetic_type"), + "anaesthesia_duration": anaesthesia.get("duration"), + "anaesthesia_duration_unit": anaesthesia.get("duration_unit"), + "anaesthesia_level": anaesthesia.get("level"), + "workstation_id": surgery.get("workstation_id"), + "surgery_notes": surgery.get("notes"), + } + + +def _extract_first_dynamics(dynamics: list) -> dict: + """Extract profile, volume, and volume_unit from the first dynamics entry.""" + if not dynamics: + return {"injection_profile": None, "injection_volume": None, "injection_volume_unit": None} + d = dynamics[0] + return { + "injection_profile": d.get("profile"), + "injection_volume": d.get("volume"), + "injection_volume_unit": d.get("volume_unit"), + } + + +@acorns.register_acorn(acorns.NAMES["procedures"]) +def procedures(force_update: bool = False) -> pd.DataFrame: + """Fetch subject procedures summary with one row per procedure per surgery across all subjects. + + Returns a DataFrame with columns: procedure_key, subject_id, surgery_start_date, + and procedure_type. + + Args: + force_update: If True, bypass cache and fetch fresh data from database. + + Returns: + DataFrame with one row per procedure per surgery. + + """ + df = acorns.TREE.scurry(acorns.NAMES["procedures"]) + + if df.empty and not force_update: + raise ValueError("Cache is empty. Use force_update=True to fetch data from database.") + + if df.empty or force_update: + proc_df, _ = _fetch_all_procedures() + df = proc_df + + return df + + +@acorns.register_acorn(acorns.NAMES["injections"]) +def brain_injections(force_update: bool = False) -> pd.DataFrame: + """Fetch detailed Injection and BrainInjection data across all subjects. + + Returns a DataFrame with one row per injection procedure, including + coordinates, materials, dynamics, and targeted structure. + + Args: + force_update: If True, bypass cache and fetch fresh data from database. + + Returns: + DataFrame with detailed injection procedure data. + + """ + df = acorns.TREE.scurry(acorns.NAMES["injections"]) + + if df.empty and not force_update: + raise ValueError("Cache is empty. Use force_update=True to fetch data from database.") + + if df.empty or force_update: + _, inj_df = _fetch_all_procedures() + df = inj_df + + return df + + +def _fetch_all_procedures() -> tuple[pd.DataFrame, pd.DataFrame]: + """Fetch all procedures records from the database and cache both tables.""" + setup_logging() + + logging.info( + SquirrelMessage( + tree=acorns.TREE.__class__.__name__, + acorn=acorns.NAMES["procedures"], + message="Updating cache", + ).to_json() + ) + + client = MetadataDbClient( + host=acorns.API_GATEWAY_HOST, + version="v2", + ) + + all_records = client.retrieve_docdb_records( + filter_query={}, + projection={"_id": 1}, + limit=0, + ) + all_ids = {r["_id"] for r in all_records} + + # Batch retrieve 50 at a time + records = [] + batch_size = 50 + for i, batch_start in enumerate(range(0, len(all_ids), batch_size)): + batch_ids = list(all_ids)[batch_start : batch_start + batch_size] + batch_records = client.retrieve_docdb_records( + filter_query={"_id": {"$in": batch_ids}}, + projection={"procedures": 1, "subject": 1}, + limit=0, + ) + records.extend(batch_records) + logging.info( + SquirrelMessage( + tree=acorns.TREE.__class__.__name__, + acorn=acorns.NAMES["procedures"], + message=f"Fetched batch {i + 1}/{(len(all_ids) + batch_size - 1) // batch_size} ({len(records)}/{len(all_ids)} records)", + ).to_json() + ) + + proc_rows = [] + inj_rows = [] + seen_subject_ids = set() + total = len(records) + + for i, record in enumerate(records): + proc_block = record.get("procedures", {}) or {} + subject_block = record.get("subject", {}) or {} + sid = subject_block.get("subject_id", "") + + if not sid or sid in seen_subject_ids: + continue + seen_subject_ids.add(sid) + logging.info(f"[{i + 1}/{total}] Processing subject {sid}") + + subject_procedures = proc_block.get("subject_procedures", []) or [] + for surgery_idx, surgery in enumerate(subject_procedures): + if surgery.get("object_type") != "Surgery": + continue + surgery_start_date = surgery.get("start_date", "") + coord_sys_map = _coord_systems_from_procedures(proc_block, surgery) + surgery_fields = _extract_surgery_fields(surgery) + inner_procedures = surgery.get("procedures", []) or [] + for proc_idx, proc in enumerate(inner_procedures): + proc_type = proc.get("object_type", "") + procedure_key = f"{sid}_{surgery_idx}_{proc_idx}" + proc_rows.append( + { + "procedure_key": procedure_key, + "subject_id": sid, + "surgery_start_date": surgery_start_date, + "procedure_type": proc_type, + } + ) + if proc_type in ("Brain injection", "Injection"): + inj_rows.append( + _extract_injection_row( + procedure_key, sid, surgery_start_date, proc, coord_sys_map, surgery_fields + ) + ) + + proc_df = pd.DataFrame(proc_rows) if proc_rows else pd.DataFrame() + inj_df = pd.DataFrame(inj_rows) if inj_rows else pd.DataFrame() + + acorns.TREE.hide(acorns.NAMES["procedures"], proc_df) + acorns.TREE.hide(acorns.NAMES["injections"], inj_df) + + return proc_df, inj_df + + +def _extract_injection_row( + procedure_key: str, + subject_id: str, + surgery_start_date: str, + proc: dict, + coord_sys_map: dict[str, list[str]], + surgery_fields: dict, +) -> dict: + """Extract a flat row dict from an Injection or BrainInjection procedure dict.""" + targeted = proc.get("targeted_structure") or {} + targeted_name = targeted.get("name", "") if isinstance(targeted, dict) else "" + targeted_acronym = targeted.get("acronym", "") if isinstance(targeted, dict) else "" + + relative_position = proc.get("relative_position") or [] + if isinstance(relative_position, list): + relative_position = "; ".join(str(p) for p in relative_position) + + cs_name = proc.get("coordinate_system_name", "") + axis_names = coord_sys_map.get(cs_name, []) + coord_cols = _extract_translation_by_axes(proc.get("coordinates") or [], axis_names) + + row = { + "procedure_key": procedure_key, + "subject_id": subject_id, + "surgery_start_date": surgery_start_date, + "procedure_type": proc.get("object_type", ""), + "targeted_structure_name": targeted_name, + "targeted_structure_acronym": targeted_acronym, + "relative_position": relative_position, + "coordinate_system_name": cs_name, + "injection_materials": _serialize_materials(proc.get("injection_materials") or []), + "protocol_id": proc.get("protocol_id", ""), + } + row.update(surgery_fields) + row.update(coord_cols) + row.update(_extract_first_dynamics(proc.get("dynamics") or [])) + return row + + +def procedures_columns() -> list[Column]: + """Return procedures acorn column definitions.""" + return [ + Column(name="procedure_key", description="Unique key for this procedure, joins to brain_injections table"), + Column(name="subject_id", description="Subject ID"), + Column(name="surgery_start_date", description="Start date of the surgery"), + Column(name="procedure_type", description="Type of procedure (e.g. Brain injection, Headframe)"), + ] + + +def brain_injections_columns() -> list[Column]: + """Return brain injections acorn column definitions.""" + return [ + Column(name="procedure_key", description="Unique key for this procedure, joins to procedures table"), + Column(name="subject_id", description="Subject ID"), + Column(name="surgery_start_date", description="Start date of the surgery"), + Column(name="procedure_type", description="Injection type (Brain injection or Injection)"), + Column(name="targeted_structure_name", description="Full name of targeted brain structure"), + Column(name="targeted_structure_acronym", description="Acronym of targeted brain structure"), + Column(name="relative_position", description="Relative anatomical position (e.g. Left; Right)"), + Column(name="coordinate_system_name", description="Name of the coordinate system used"), + Column(name="", description="One column per axis in the coordinate system (e.g. AP, ML, SI, Depth)"), + Column(name="injection_materials", description="Semicolon-separated injection material names"), + Column(name="injection_profile", description="Injection profile (e.g. Bolus, Continuous)"), + Column(name="injection_volume", description="Injection volume"), + Column(name="injection_volume_unit", description="Injection volume unit (e.g. nanoliter)"), + Column(name="protocol_id", description="Protocol ID (DOI)"), + Column(name="surgery_protocol_id", description="Surgery protocol ID"), + Column(name="experimenters", description="Semicolon-separated list of experimenters"), + Column(name="ethics_review_id", description="Ethics review ID"), + Column(name="animal_weight_prior", description="Animal weight before surgery"), + Column(name="animal_weight_post", description="Animal weight after surgery"), + Column(name="weight_unit", description="Unit for animal weight measurements"), + Column(name="anaesthetic_type", description="Type of anaesthetic used"), + Column(name="anaesthesia_duration", description="Duration of anaesthesia"), + Column(name="anaesthesia_duration_unit", description="Unit for anaesthesia duration"), + Column(name="anaesthesia_level", description="Level of anaesthesia"), + Column(name="workstation_id", description="Workstation ID used for surgery"), + Column(name="surgery_notes", description="Free-text notes about the surgery"), + ] diff --git a/src/zombie_squirrel/acorns.py b/src/zombie_squirrel/acorns.py index 7b4b819..838b2f7 100644 --- a/src/zombie_squirrel/acorns.py +++ b/src/zombie_squirrel/acorns.py @@ -40,6 +40,8 @@ "r2d": "raw_to_derived", "qc": "quality_control", "smartspim": "assets_smartspim", + "procedures": "procedures", + "injections": "brain_injections", } ACORN_REGISTRY: dict[str, Callable[[], Any]] = {} diff --git a/tests/acorn_helpers/test_asset_basics.py b/tests/acorn_helpers/test_asset_basics.py index a381127..e77bce4 100644 --- a/tests/acorn_helpers/test_asset_basics.py +++ b/tests/acorn_helpers/test_asset_basics.py @@ -307,6 +307,29 @@ def test_acquisition_type_stored(self, mock_tree, mock_client_class): self.assertEqual(result.iloc[0]["acquisition_type"], "multiplane-2photon") + @patch("zombie_squirrel.acorn_helpers.asset_basics.MetadataDbClient") + @patch("zombie_squirrel.acorn_helpers.asset_basics.acorns.TREE") + def test_age_none_when_unparseable_date(self, mock_tree, mock_client_class): + """Test age is None when date parsing raises an exception.""" + mock_tree.scurry.return_value = pd.DataFrame() + mock_client_instance = MagicMock() + mock_client_class.return_value = mock_client_instance + mock_client_instance.retrieve_docdb_records.return_value = [ + { + "_id": "id1", + "_last_modified": "2023-01-01", + "data_description": {}, + "acquisition": { + "acquisition_start_time": "not-a-date", + "subject_details": {"date_of_birth": "also-not-a-date"}, + }, + } + ] + + result = asset_basics(force_update=True) + + self.assertIsNone(result.iloc[0]["age"]) + if __name__ == "__main__": unittest.main() diff --git a/tests/acorn_helpers/test_assets_smartspim.py b/tests/acorn_helpers/test_assets_smartspim.py index 43510b8..b4750c4 100644 --- a/tests/acorn_helpers/test_assets_smartspim.py +++ b/tests/acorn_helpers/test_assets_smartspim.py @@ -66,8 +66,12 @@ def test_returns_channel_names(self, mock_boto_client): mock_boto_client.return_value = mock_s3 mock_s3.list_objects_v2.return_value = { "CommonPrefixes": [ - {"Prefix": "SmartSPIM_123_2026-01-01_00-00-00_stitched_2026-01-02_00-00-00/image_cell_segmentation/Ex_488_Em_525/"}, - {"Prefix": "SmartSPIM_123_2026-01-01_00-00-00_stitched_2026-01-02_00-00-00/image_cell_segmentation/Ex_561_Em_600/"}, + { + "Prefix": "SmartSPIM_123_2026-01-01_00-00-00_stitched_2026-01-02_00-00-00/image_cell_segmentation/Ex_488_Em_525/" + }, + { + "Prefix": "SmartSPIM_123_2026-01-01_00-00-00_stitched_2026-01-02_00-00-00/image_cell_segmentation/Ex_561_Em_600/" + }, ] } @@ -209,10 +213,15 @@ def test_uses_last_data_process_end_time(self, mock_list_channels): mock_list_channels.return_value = [] raw_name = "SmartSPIM_123_2026-01-01_00-00-00" stitched_name = "SmartSPIM_123_2026-01-01_00-00-00_stitched_2026-01-02_00-00-00" - record = {**EXAMPLE_RECORD, "processing": {"data_processes": [ - {"end_date_time": "2026-01-01T10:00:00"}, - {"end_date_time": "2026-01-02T12:00:00"}, - ]}} + record = { + **EXAMPLE_RECORD, + "processing": { + "data_processes": [ + {"end_date_time": "2026-01-01T10:00:00"}, + {"end_date_time": "2026-01-02T12:00:00"}, + ] + }, + } rows = _build_rows({raw_name: stitched_name}, {stitched_name: record}) @@ -276,7 +285,9 @@ def test_force_update_builds_and_caches( } ) mock_fetch_meta.return_value = {} - mock_build_rows.return_value = [{"name": "SmartSPIM_raw_2026-01-01_00-00-00_stitched_2026-01-02_00-00-00", "processed": True}] + mock_build_rows.return_value = [ + {"name": "SmartSPIM_raw_2026-01-01_00-00-00_stitched_2026-01-02_00-00-00", "processed": True} + ] result = assets_smartspim(force_update=True) @@ -301,8 +312,11 @@ def test_unprocessed_assets_included_with_processed_false(self, mock_tree, mock_ ) with patch("zombie_squirrel.acorn_helpers.assets_smartspim._fetch_asset_metadata", return_value={}): - with patch("zombie_squirrel.acorn_helpers.assets_smartspim._build_rows", return_value=[{"name": "SmartSPIM_raw_2026-01-01_00-00-00", "processed": False}]) as mock_build: - result = assets_smartspim(force_update=True) + with patch( + "zombie_squirrel.acorn_helpers.assets_smartspim._build_rows", + return_value=[{"name": "SmartSPIM_raw_2026-01-01_00-00-00", "processed": False}], + ) as mock_build: + assets_smartspim(force_update=True) raw_to_stitched_arg = mock_build.call_args[0][0] self.assertIn("SmartSPIM_raw_2026-01-01_00-00-00", raw_to_stitched_arg) diff --git a/tests/acorn_helpers/test_procedures.py b/tests/acorn_helpers/test_procedures.py new file mode 100644 index 0000000..c1e0724 --- /dev/null +++ b/tests/acorn_helpers/test_procedures.py @@ -0,0 +1,420 @@ +"""Unit tests for procedures acorn.""" + +import unittest +from unittest.mock import MagicMock, patch + +import pandas as pd + +import zombie_squirrel.acorns as acorns +from zombie_squirrel.acorn_helpers.procedures import ( + _axis_names_from_coord_sys, + _coord_systems_from_procedures, + _extract_first_dynamics, + _extract_injection_row, + _extract_translation_by_axes, + _serialize_materials, + _to_float, + brain_injections, + brain_injections_columns, + procedures, + procedures_columns, +) +from zombie_squirrel.forest import MemoryTree + +SAMPLE_RECORD = { + "_id": "abc123", + "subject": {"subject_id": "813992"}, + "procedures": { + "object_type": "Procedures", + "subject_id": "813992", + "coordinate_system": { + "object_type": "Coordinate system", + "name": "BREGMA_ARID", + "origin": "Bregma", + "axes": [ + {"object_type": "Axis", "name": "AP", "direction": "Posterior_to_anterior"}, + {"object_type": "Axis", "name": "ML", "direction": "Left_to_right"}, + {"object_type": "Axis", "name": "SI", "direction": "Superior_to_inferior"}, + {"object_type": "Axis", "name": "Depth", "direction": "Up_to_down"}, + ], + "axis_unit": "millimeter", + }, + "subject_procedures": [ + { + "object_type": "Surgery", + "start_date": "2025-09-24", + "procedures": [ + { + "object_type": "Headframe", + "headframe_type": "AI Straight bar", + }, + { + "object_type": "Brain injection", + "injection_materials": [ + { + "object_type": "Viral material", + "name": "AAV-test", + "titer": 1e13, + "titer_unit": "gc/mL", + } + ], + "targeted_structure": { + "atlas": "CCFv3", + "name": "Nucleus accumbens", + "acronym": "ACB", + "id": "56", + }, + "relative_position": ["Left"], + "dynamics": [ + { + "object_type": "Injection dynamics", + "profile": "Bolus", + "volume": 300, + "volume_unit": "nanoliter", + "duration": None, + "duration_unit": "minute", + } + ], + "protocol_id": "dx.doi.org/10.17504/protocols.io.test", + "coordinate_system_name": "BREGMA_ARID", + "coordinates": [ + [ + { + "object_type": "Translation", + "translation": [1.3, -1.8, 0, 4.4], + } + ] + ], + }, + ], + } + ], + }, +} + + +class TestAxisHelpers(unittest.TestCase): + """Tests for coordinate axis helper functions.""" + + def test_axis_names_from_coord_sys(self): + cs = {"axes": [{"name": "AP"}, {"name": "ML"}, {"name": "SI"}, {"name": "Depth"}]} + self.assertEqual(_axis_names_from_coord_sys(cs), ["AP", "ML", "SI", "Depth"]) + + def test_axis_names_empty(self): + self.assertEqual(_axis_names_from_coord_sys({}), []) + + def test_coord_systems_from_procedures_top_level(self): + proc_block = { + "coordinate_system": { + "name": "BREGMA_ARID", + "axes": [{"name": "AP"}, {"name": "ML"}, {"name": "SI"}, {"name": "Depth"}], + } + } + result = _coord_systems_from_procedures(proc_block, {}) + self.assertIn("BREGMA_ARID", result) + self.assertEqual(result["BREGMA_ARID"], ["AP", "ML", "SI", "Depth"]) + + def test_coord_systems_surgery_overrides(self): + proc_block = {} + surgery = { + "coordinate_system": { + "name": "CUSTOM", + "axes": [{"name": "X"}, {"name": "Y"}], + } + } + result = _coord_systems_from_procedures(proc_block, surgery) + self.assertIn("CUSTOM", result) + + def test_extract_translation_by_axes(self): + coords = [[{"object_type": "Translation", "translation": [1.3, -1.8, 0.0, 4.4]}]] + result = _extract_translation_by_axes(coords, ["AP", "ML", "SI", "Depth"]) + self.assertEqual( + result, + { + "AP": 1.3, + "ML": -1.8, + "SI": 0.0, + "Depth": 4.4, + "AP_rotation": None, + "ML_rotation": None, + "SI_rotation": None, + "Depth_rotation": None, + }, + ) + + def test_extract_translation_no_translation(self): + coords = [[{"object_type": "Rotation", "rotation": [0, 0, 1, 45]}]] + result = _extract_translation_by_axes(coords, ["AP", "ML"]) + self.assertEqual(result, {"AP": None, "ML": None, "AP_rotation": 0, "ML_rotation": 0}) + + def test_extract_translation_empty_coords(self): + result = _extract_translation_by_axes([], ["AP", "ML"]) + self.assertEqual(result, {"AP": None, "ML": None, "AP_rotation": None, "ML_rotation": None}) + + def test_extract_first_dynamics(self): + d = [{"profile": "Bolus", "volume": 300, "volume_unit": "nanoliter", "duration": None}] + result = _extract_first_dynamics(d) + self.assertEqual(result["injection_profile"], "Bolus") + self.assertEqual(result["injection_volume"], 300) + self.assertEqual(result["injection_volume_unit"], "nanoliter") + + def test_extract_first_dynamics_empty(self): + result = _extract_first_dynamics([]) + self.assertIsNone(result["injection_profile"]) + self.assertIsNone(result["injection_volume"]) + self.assertIsNone(result["injection_volume_unit"]) + + +class TestSerializeMaterials(unittest.TestCase): + """Tests for _serialize_materials helper.""" + + def test_empty(self): + self.assertEqual(_serialize_materials([]), "") + + def test_single_material(self): + m = [{"name": "AAV-test", "titer": 1e13}] + self.assertEqual(_serialize_materials(m), "AAV-test") + + def test_multiple_materials(self): + m = [{"name": "AAV-GCaMP"}, {"name": "AAV-ChRmine"}] + self.assertEqual(_serialize_materials(m), "AAV-GCaMP; AAV-ChRmine") + + +class TestExtractInjectionRow(unittest.TestCase): + """Tests for _extract_injection_row helper.""" + + def setUp(self): + self.coord_sys_map = {"BREGMA_ARID": ["AP", "ML", "SI", "Depth"]} + + def test_brain_injection_row(self): + proc = SAMPLE_RECORD["procedures"]["subject_procedures"][0]["procedures"][1] + row = _extract_injection_row("813992_0_1", "813992", "2025-09-24", proc, self.coord_sys_map, {}) + self.assertEqual(row["procedure_key"], "813992_0_1") + self.assertEqual(row["subject_id"], "813992") + self.assertEqual(row["surgery_start_date"], "2025-09-24") + self.assertEqual(row["procedure_type"], "Brain injection") + self.assertEqual(row["targeted_structure_acronym"], "ACB") + self.assertEqual(row["targeted_structure_name"], "Nucleus accumbens") + self.assertEqual(row["relative_position"], "Left") + self.assertEqual(row["coordinate_system_name"], "BREGMA_ARID") + self.assertEqual(row["AP"], 1.3) + self.assertEqual(row["ML"], -1.8) + self.assertEqual(row["SI"], 0) + self.assertEqual(row["Depth"], 4.4) + self.assertIn("AAV-test", row["injection_materials"]) + self.assertEqual(row["injection_profile"], "Bolus") + self.assertEqual(row["injection_volume"], 300) + self.assertEqual(row["injection_volume_unit"], "nanoliter") + + def test_missing_targeted_structure(self): + proc = {"object_type": "Brain injection", "targeted_structure": None, "dynamics": [], "injection_materials": []} + row = _extract_injection_row("sub1_0_0", "sub1", "2025-01-01", proc, {}, {}) + self.assertEqual(row["targeted_structure_name"], "") + self.assertEqual(row["targeted_structure_acronym"], "") + + def test_unknown_coord_sys_gives_no_axis_columns(self): + proc = { + "object_type": "Brain injection", + "coordinate_system_name": "UNKNOWN", + "coordinates": [[{"object_type": "Translation", "translation": [1, 2, 3, 4]}]], + "dynamics": [], + "injection_materials": [], + } + row = _extract_injection_row("sub1_0_0", "sub1", "2025-01-01", proc, {}, {}) + self.assertNotIn("AP", row) + + +class TestProceduresAcorn(unittest.TestCase): + """Tests for procedures() acorn function.""" + + def setUp(self): + acorns.TREE = MemoryTree() + + @patch("zombie_squirrel.acorn_helpers.procedures.MetadataDbClient") + def test_procedures_force_update(self, mock_client_class): + mock_client_instance = MagicMock() + mock_client_class.return_value = mock_client_instance + mock_client_instance.retrieve_docdb_records.return_value = [SAMPLE_RECORD] + + df = procedures(force_update=True) + + self.assertFalse(df.empty) + self.assertIn("procedure_key", df.columns) + self.assertIn("procedure_type", df.columns) + self.assertIn("Headframe", df["procedure_type"].values) + self.assertIn("Brain injection", df["procedure_type"].values) + self.assertEqual(df.iloc[0]["surgery_start_date"], "2025-09-24") + + @patch("zombie_squirrel.acorn_helpers.procedures.MetadataDbClient") + def test_procedures_cache_hit(self, mock_client_class): + cached_df = pd.DataFrame({"subject_id": ["813992"], "procedure_type": ["Headframe"]}) + acorns.TREE.hide("procedures", cached_df) + + df = procedures(force_update=False) + + self.assertEqual(len(df), 1) + mock_client_class.assert_not_called() + + def test_procedures_empty_cache_raises(self): + with self.assertRaises(ValueError): + procedures(force_update=False) + + @patch("zombie_squirrel.acorn_helpers.procedures.MetadataDbClient") + def test_procedures_no_records_found(self, mock_client_class): + mock_client_instance = MagicMock() + mock_client_class.return_value = mock_client_instance + mock_client_instance.retrieve_docdb_records.return_value = [] + + df = procedures(force_update=True) + + self.assertTrue(df.empty) + + @patch("zombie_squirrel.acorn_helpers.procedures.MetadataDbClient") + def test_procedures_deduplicates_by_subject_id(self, mock_client_class): + """Multiple records for the same subject should only yield one pass through surgery data.""" + mock_client_instance = MagicMock() + mock_client_class.return_value = mock_client_instance + duplicate = dict(SAMPLE_RECORD) + duplicate["_id"] = "other-id" + mock_client_instance.retrieve_docdb_records.return_value = [SAMPLE_RECORD, duplicate] + + df = procedures(force_update=True) + + self.assertEqual(len(df), 2) + + +class TestBrainInjectionsAcorn(unittest.TestCase): + """Tests for brain_injections() acorn function.""" + + def setUp(self): + acorns.TREE = MemoryTree() + + @patch("zombie_squirrel.acorn_helpers.procedures.MetadataDbClient") + def test_brain_injections_force_update(self, mock_client_class): + mock_client_instance = MagicMock() + mock_client_class.return_value = mock_client_instance + mock_client_instance.retrieve_docdb_records.return_value = [SAMPLE_RECORD] + + df = brain_injections(force_update=True) + + self.assertFalse(df.empty) + self.assertIn("targeted_structure_acronym", df.columns) + self.assertIn("procedure_key", df.columns) + self.assertEqual(df.iloc[0]["targeted_structure_acronym"], "ACB") + self.assertEqual(df.iloc[0]["AP"], 1.3) + self.assertEqual(df.iloc[0]["injection_profile"], "Bolus") + self.assertEqual(df.iloc[0]["injection_volume"], 300) + + @patch("zombie_squirrel.acorn_helpers.procedures.MetadataDbClient") + def test_brain_injections_cache_hit(self, mock_client_class): + cached_df = pd.DataFrame({"subject_id": ["813992"], "targeted_structure_acronym": ["ACB"]}) + acorns.TREE.hide("brain_injections", cached_df) + + df = brain_injections(force_update=False) + + self.assertEqual(len(df), 1) + mock_client_class.assert_not_called() + + def test_brain_injections_empty_cache_raises(self): + with self.assertRaises(ValueError): + brain_injections(force_update=False) + + @patch("zombie_squirrel.acorn_helpers.procedures.MetadataDbClient") + def test_no_injections_returns_empty(self, mock_client_class): + """Record with only non-injection procedures yields empty injections table.""" + mock_client_instance = MagicMock() + mock_client_class.return_value = mock_client_instance + record_no_injections = { + "_id": "xyz", + "subject": {"subject_id": "000001"}, + "procedures": { + "subject_procedures": [ + { + "object_type": "Surgery", + "start_date": "2025-01-01", + "procedures": [{"object_type": "Headframe"}], + } + ] + }, + } + mock_client_instance.retrieve_docdb_records.return_value = [record_no_injections] + + df = brain_injections(force_update=True) + + self.assertTrue(df.empty) + + +class TestToFloat(unittest.TestCase): + """Tests for _to_float helper.""" + + def test_valid_float(self): + self.assertEqual(_to_float(1.5), 1.5) + + def test_invalid_value(self): + self.assertIsNone(_to_float("not-a-number")) + + def test_none(self): + self.assertIsNone(_to_float(None)) + + +class TestExtractTranslationNonListSite(unittest.TestCase): + """Tests for _extract_translation_by_axes with non-list site.""" + + def test_non_list_site_is_skipped(self): + coords = [ + {"object_type": "Translation", "translation": [1, 2]}, + [{"object_type": "Translation", "translation": [3, 4]}], + ] + result = _extract_translation_by_axes(coords, ["AP", "ML"]) + self.assertEqual(result["AP"], 3) + self.assertEqual(result["ML"], 4) + + +class TestNonSurgeryProcedures(unittest.TestCase): + """Tests that non-Surgery entries in subject_procedures are skipped.""" + + def setUp(self): + acorns.TREE = MemoryTree() + + @patch("zombie_squirrel.acorn_helpers.procedures.MetadataDbClient") + def test_non_surgery_object_skipped(self, mock_client_class): + mock_client_instance = MagicMock() + mock_client_class.return_value = mock_client_instance + record = { + "_id": "abc", + "subject": {"subject_id": "sub1"}, + "procedures": { + "subject_procedures": [ + {"object_type": "NotASurgery"}, + { + "object_type": "Surgery", + "start_date": "2025-01-01", + "procedures": [{"object_type": "Headframe"}], + }, + ] + }, + } + mock_client_instance.retrieve_docdb_records.return_value = [record] + + df = procedures(force_update=True) + + self.assertEqual(len(df), 1) + self.assertEqual(df.iloc[0]["procedure_type"], "Headframe") + + +class TestColumnFunctions(unittest.TestCase): + """Tests for procedures_columns and brain_injections_columns.""" + + def test_procedures_columns_returns_list(self): + cols = procedures_columns() + self.assertIsInstance(cols, list) + self.assertTrue(any(c.name == "procedure_key" for c in cols)) + + def test_brain_injections_columns_returns_list(self): + cols = brain_injections_columns() + self.assertIsInstance(cols, list) + self.assertTrue(any(c.name == "procedure_key" for c in cols)) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_forest_coverage.py b/tests/test_forest_coverage.py index d3ede4b..e0f2696 100644 --- a/tests/test_forest_coverage.py +++ b/tests/test_forest_coverage.py @@ -52,5 +52,14 @@ def test_get_location_not_partitioned(self, mock_boto3): self.assertNotIn("zs_my_table/", result) +class TestMemoryTreeFetchMissingKey(unittest.TestCase): + """Tests for MemoryTree.fetch when key is not present.""" + + def test_fetch_missing_key_returns_empty_json(self): + tree = MemoryTree() + result = tree.fetch("nonexistent.json") + self.assertEqual(result, "{}") + + if __name__ == "__main__": unittest.main() diff --git a/tests/test_utils.py b/tests/test_utils.py index a9af344..597abaf 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -5,7 +5,10 @@ import unittest -from zombie_squirrel.utils import get_s3_cache_path, prefix_table_name +import zombie_squirrel.acorns as acorns +from zombie_squirrel.forest import MemoryTree +from zombie_squirrel.squirrel import Squirrel +from zombie_squirrel.utils import get_s3_cache_path, get_squirrel_info, prefix_table_name class TestPrefixTableName(unittest.TestCase): @@ -51,5 +54,20 @@ def test_get_s3_cache_path_various_names(self): self.assertEqual(result, "data-asset-cache/zs_my_data.pqt") +class TestGetSquirrelInfo(unittest.TestCase): + """Tests for get_squirrel_info function.""" + + def test_get_squirrel_info(self): + tree = MemoryTree() + squirrel = Squirrel(acorns=[]) + tree.plant("squirrel.json", squirrel.model_dump_json()) + acorns.TREE = tree + + result = get_squirrel_info() + + self.assertIsInstance(result, Squirrel) + self.assertEqual(result.acorns, []) + + if __name__ == "__main__": unittest.main() diff --git a/uv.lock b/uv.lock index e1c0ff0..15c84a5 100644 --- a/uv.lock +++ b/uv.lock @@ -34,7 +34,7 @@ wheels = [ [[package]] name = "aind-data-access-api" -version = "1.9.2" +version = "1.10.0" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "boto3" }, @@ -42,9 +42,9 @@ dependencies = [ { name = "pydantic-settings" }, { name = "requests" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/ea/71/b4255d81da567ac3e4db9ea594c8fbb6494cb3fd9f80fb52e28a5216d701/aind_data_access_api-1.9.2.tar.gz", hash = "sha256:8a2afe4268079f950e35cbd5c1880abe1608b11b8e081bee9308b1507d141c6b", size = 71746, upload-time = "2025-11-18T00:02:14.32Z" } +sdist = { url = "https://files.pythonhosted.org/packages/0b/7c/210906ebf642a19cd5eca57f84a32e7b6d24ff97b7b8b7f315a819b2fad9/aind_data_access_api-1.10.0.tar.gz", hash = "sha256:78dfdcdf8052af95b716e897099b9ed2f912fb3c03741ed08bb65a42163df8d4", size = 72394, upload-time = "2026-04-21T23:30:02.986Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/fa/3e/3a312f564cfbb2fee2f6a3d24e3325260b96e8bea4bf0abbd15d4b5a1d46/aind_data_access_api-1.9.2-py3-none-any.whl", hash = "sha256:588afebdf1647be8071131ead1f9c2d2761d29468a128caa181746bad204d2a6", size = 18646, upload-time = "2025-11-18T00:02:13.284Z" }, + { url = "https://files.pythonhosted.org/packages/43/d6/8fd44feb29df9eaa2d65418d41b7e3303cfc65e66eeccc6159882d5ac673/aind_data_access_api-1.10.0-py3-none-any.whl", hash = "sha256:a463917e13674cfd4e88518beaf1b81a6d077aa80a1fc54ee7d682f462e4ac24", size = 18962, upload-time = "2026-04-21T23:30:01.64Z" }, ] [package.optional-dependencies] @@ -1900,7 +1900,7 @@ dev = [ [package.metadata] requires-dist = [ - { name = "aind-data-access-api", extras = ["docdb"] }, + { name = "aind-data-access-api", extras = ["docdb"], specifier = ">=1.10.0,<2" }, { name = "boto3" }, { name = "duckdb" }, { name = "pandas", specifier = ">=2.2.0" },