diff --git a/element_miniscope/miniscope.py b/element_miniscope/miniscope.py index 98a611c..97f16ce 100644 --- a/element_miniscope/miniscope.py +++ b/element_miniscope/miniscope.py @@ -1,11 +1,13 @@ import csv import copy import cv2 +import gc import importlib import inspect import json import os import pathlib +import time from datetime import datetime, timezone from typing import Union @@ -75,6 +77,17 @@ def activate( add_objects=_linking_module.__dict__, ) +# ─── Configuration ──────────────────────────────────────────────────────── + +# Docker image for minian processing. +# Override with env var MINIAN_DOCKER_IMAGE if using a custom registry. +MINIAN_DOCKER_IMAGE = os.getenv( + "MINIAN_DOCKER_IMAGE", "datajoint/minian-py38:latest" +) + +# Timeout for the docker run command (seconds). Default: 24 hours. +# Large datasets (70k+ frames) can take many hours. +MINIAN_DOCKER_TIMEOUT = int(os.getenv("MINIAN_DOCKER_TIMEOUT", "86400")) # Functions required by the element-miniscope ----------------------------------------- @@ -382,7 +395,6 @@ def make(self, key): nframes = total_frames nchannels = 1 # Assumes a single channel - elif acq_software == "Inscopix": inscopix_metadata = next(recording_path.glob("session.json")) @@ -472,7 +484,10 @@ class ProcessingMethod(dj.Lookup): processing_method_desc: varchar(1000) """ - contents = [("caiman", "caiman analysis suite")] + contents = [ + ("caiman", "caiman analysis suite"), + ("minian", "minian analysis suite"), + ] @schema @@ -717,6 +732,7 @@ def make_fetch(self, key): avi_files = (RecordingInfo.File & key).fetch("file_path") processing_params = (ProcessingParamSet & key).fetch1("params") sampling_rate = (RecordingInfo & key).fetch1("fps") + px_height, px_width = (RecordingInfo & key).fetch1("px_height", "px_width") return ( task_mode, @@ -725,6 +741,8 @@ def make_fetch(self, key): avi_files, processing_params, sampling_rate, + px_height, + px_width, ) def make_compute( @@ -736,13 +754,15 @@ def make_compute( avi_files, processing_params, sampling_rate, + px_height, + px_width, ): """ Execute the miniscope analysis defined by the ProcessingTask. - task_mode: 'load', confirm that the results are already computed. - task_mode: 'trigger' runs the analysis. 
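+ - task_mode: 'trigger' with method "minian" runs the pipeline inside a Docker container (see _run_minian_in_container below); Docker must be available on the host.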
""" - if method != "caiman": + if method not in ["caiman", "minian"]: raise NotImplementedError(f"Method {method} is not supported") params = copy.deepcopy(processing_params) @@ -766,184 +786,240 @@ def make_compute( if method == "caiman": loaded_caiman = loaded_result key = {**key, "processing_time": loaded_caiman.creation_time} + elif method == "minian": + loaded_minian = loaded_result + key = {**key, "processing_time": loaded_minian.creation_time} else: raise NotImplementedError( f"Loading of {method} data is not yet supported" ) + file_entries = [ + { + **key, + "file_name": f.relative_to( + get_processed_root_data_dir() + ).as_posix(), + "file": f.as_posix(), + } + for f in output_dir.rglob("*") + if f.is_file() + ] elif task_mode == "trigger": - import multiprocessing - import caiman as cm - from caiman.motion_correction import MotionCorrect - from caiman.source_extraction.cnmf.cnmf import CNMF - from caiman.source_extraction.cnmf.params import CNMFParams - from element_interface.run_caiman import _save_mc - - extra_params = params.pop("extra_dj_params", {}) - avi_files = [ find_full_path(get_miniscope_root_data_dir(), avi_file).as_posix() for avi_file in avi_files ] - params["fnames"] = avi_files - params["fr"] = sampling_rate - params["is3D"] = False - if "indices" in params: - params["motion"] = { - "indices": ( - slice(*params.get("indices")[0]), - slice(*params.get("indices")[1]), - ) - } - else: - params["motion"] = {"indices": (slice(None), slice(None))} + if method == "caiman": + import multiprocessing + import caiman as cm + from caiman.motion_correction import MotionCorrect + from caiman.source_extraction.cnmf.cnmf import CNMF + from caiman.source_extraction.cnmf.params import CNMFParams + from element_interface.run_caiman import _save_mc + + extra_params = params.pop("extra_dj_params", {}) + + params["fnames"] = avi_files + params["fr"] = sampling_rate + params["is3D"] = False + if "indices" in params: + params["motion"] = { + "indices": ( + slice(*params.get("indices")[0]), + slice(*params.get("indices")[1]), + ) + } + else: + params["motion"] = {"indices": (slice(None), slice(None))} - @memoized_result( - uniqueness_dict=params, - output_directory=output_dir, - ) - def _run_processing(): - mc_indices = params["motion"].get("indices") - caiman_temp = os.environ.get("CAIMAN_TEMP") - os.environ["CAIMAN_TEMP"] = str(output_dir) - n_processes = np.floor(multiprocessing.cpu_count() * 0.6) - n_processes = int(os.getenv("CAIMAN_MC_N_PROCESSES", n_processes)) - _, dview, n_processes = cm.cluster.setup_cluster( - backend="multiprocessing", - n_processes=n_processes, - maxtasksperchild=1, + @memoized_result( + uniqueness_dict=params, + output_directory=output_dir, ) - try: - opts = CNMFParams(params_dict=params) - cnm = CNMF(n_processes, params=opts, dview=dview) - fnames = cnm.params.get("data", "fnames") - mc = MotionCorrect(fnames, dview=cnm.dview, **cnm.params.motion) - mc_base_attrs = list(mc.__dict__) - logger.info("Starting motion correction (CaImAn)...") - mc.motion_correct(save_movie=mc_indices is None) - mc_results = { - k: v for k, v in mc.__dict__.items() if k not in mc_base_attrs - } - if cnm.params.get("motion", "pw_rigid"): - mc_results["b0"] = np.ceil( - np.max(np.abs(mc.shifts_rig)) - ).astype(int) - cnm.estimates.shifts = mc.shifts_rig - if cnm.params.get("motion", "is3D"): - cnm.estimates.shifts = [ - mc.x_shifts_els, - mc.y_shifts_els, - mc.z_shifts_els, - ] - else: - cnm.estimates.shifts = [mc.x_shifts_els, mc.y_shifts_els] - else: - mc_results["b0"] = np.ceil( 
- np.max(np.abs(mc.shifts_rig)) - ).astype(int) - cnm.estimates.shifts = mc.shifts_rig - - base_name = pathlib.Path(fnames[0]).stem - fname_mc = ( - mc.fname_tot_els - if cnm.params.motion["pw_rigid"] - else mc.fname_tot_rig - ) - if all(fname_mc): - logger.info("Generating C-order memmap file...") - border_to_0 = 0 if mc.border_nan == "copy" else mc.border_to_0 - fname_new = cm.mmapping.save_memmap( - fname_mc, - base_name=base_name + "_mc", - order="C", - var_name_hdf5=cnm.params.get("data", "var_name_hdf5"), - border_to_0=border_to_0, - ) - else: - logger.info( - "Applying shifts, then generating C-order memmap file..." - ) - fname_new = mc.apply_shifts_movie( - fnames, - save_memmap=True, - save_base_name=base_name + "_mc", - order="C", - ) - mc.mmap_file = [fname_new] - Yr, dims, T = cm.mmapping.load_memmap(fname_new) - images = np.reshape(Yr.T, [T] + list(dims), order="F") - cnm.mmap_file = fname_new - # terminate the previous cluster and setup a new one with fewer - # processes for CNMF because it is memory intensive - dview.terminate() - n_processes = np.floor(multiprocessing.cpu_count() * 0.2) - n_processes = int(os.getenv("CAIMAN_CNMF_N_PROCESSES", n_processes)) + def _run_processing(): + mc_indices = params["motion"].get("indices") + caiman_temp = os.environ.get("CAIMAN_TEMP") + os.environ["CAIMAN_TEMP"] = str(output_dir) + n_processes = np.floor(multiprocessing.cpu_count() * 0.6) + n_processes = int(os.getenv("CAIMAN_MC_N_PROCESSES", n_processes)) _, dview, n_processes = cm.cluster.setup_cluster( backend="multiprocessing", n_processes=n_processes, maxtasksperchild=1, ) - cnm.dview = dview - logger.info(f"Starting CNMF analysis with {n_processes} processes...") + try: + opts = CNMFParams(params_dict=params) + cnm = CNMF(n_processes, params=opts, dview=dview) + fnames = cnm.params.get("data", "fnames") + mc = MotionCorrect(fnames, dview=cnm.dview, **cnm.params.motion) + mc_base_attrs = list(mc.__dict__) + logger.info("Starting motion correction (CaImAn)...") + mc.motion_correct(save_movie=mc_indices is None) + mc_results = { + k: v + for k, v in mc.__dict__.items() + if k not in mc_base_attrs + } + if cnm.params.get("motion", "pw_rigid"): + mc_results["b0"] = np.ceil( + np.max(np.abs(mc.shifts_rig)) + ).astype(int) + cnm.estimates.shifts = mc.shifts_rig + if cnm.params.get("motion", "is3D"): + cnm.estimates.shifts = [ + mc.x_shifts_els, + mc.y_shifts_els, + mc.z_shifts_els, + ] + else: + cnm.estimates.shifts = [ + mc.x_shifts_els, + mc.y_shifts_els, + ] + else: + mc_results["b0"] = np.ceil( + np.max(np.abs(mc.shifts_rig)) + ).astype(int) + cnm.estimates.shifts = mc.shifts_rig + + base_name = pathlib.Path(fnames[0]).stem + fname_mc = ( + mc.fname_tot_els + if cnm.params.motion["pw_rigid"] + else mc.fname_tot_rig + ) + if all(fname_mc): + logger.info("Generating C-order memmap file...") + border_to_0 = ( + 0 if mc.border_nan == "copy" else mc.border_to_0 + ) + fname_new = cm.mmapping.save_memmap( + fname_mc, + base_name=base_name + "_mc", + order="C", + var_name_hdf5=cnm.params.get("data", "var_name_hdf5"), + border_to_0=border_to_0, + ) + else: + logger.info( + "Applying shifts, then generating C-order memmap file..." 
+ ) + fname_new = mc.apply_shifts_movie( + fnames, + save_memmap=True, + save_base_name=base_name + "_mc", + order="C", + ) + mc.mmap_file = [fname_new] + Yr, dims, T = cm.mmapping.load_memmap(fname_new) + images = np.reshape(Yr.T, [T] + list(dims), order="F") + cnm.mmap_file = fname_new + # terminate the previous cluster and setup a new one with fewer + # processes for CNMF because it is memory intensive + dview.terminate() + n_processes = np.floor(multiprocessing.cpu_count() * 0.2) + n_processes = int( + os.getenv("CAIMAN_CNMF_N_PROCESSES", n_processes) + ) + _, dview, n_processes = cm.cluster.setup_cluster( + backend="multiprocessing", + n_processes=n_processes, + maxtasksperchild=1, + ) + cnm.dview = dview + logger.info( + f"Starting CNMF analysis with {n_processes} processes..." + ) - cnm.fit(images, indices=(slice(None), slice(None))) - cnm.estimates.evaluate_components( - images, cnm.params, dview=cnm.dview - ) - cnm.estimates.detrend_df_f(quantileMin=8, frames_window=250) - logger.info("Computing summary images...") - correlation_image, _ = cm.summary_images.correlation_pnr( - images[:: max(T // 1000, 1)], - gSig=cnm.params.init["gSig"][0], - swap_dim=False, - ) - correlation_image[np.isnan(correlation_image)] = 0 - cnm.estimates.Cn = correlation_image - fname_hdf5 = cnm.mmap_file[:-4] + "hdf5" - cnm.save(fname_hdf5) - cnmf_output_file = pathlib.Path(fname_hdf5) - summary_images = { - "average_image": np.mean(images[:: max(T // 1000, 1)], axis=0), - "max_image": np.max(images[:: max(T // 1000, 1)], axis=0), - "correlation_image": correlation_image, - } - _save_mc( - mc, - cnmf_output_file.as_posix(), - params["is3D"], - summary_images=summary_images, - ) - except Exception as e: - dview.terminate() - raise e - else: - cm.stop_server(dview=dview) - logger.info("CNMF analysis complete. Resulted saved.") - caiman_temp = os.environ.get("CAIMAN_TEMP") - if caiman_temp is not None: - os.environ["CAIMAN_TEMP"] = caiman_temp + cnm.fit(images, indices=(slice(None), slice(None))) + cnm.estimates.evaluate_components( + images, cnm.params, dview=cnm.dview + ) + cnm.estimates.detrend_df_f(quantileMin=8, frames_window=250) + logger.info("Computing summary images...") + correlation_image, _ = cm.summary_images.correlation_pnr( + images[:: max(T // 1000, 1)], + gSig=cnm.params.init["gSig"][0], + swap_dim=False, + ) + correlation_image[np.isnan(correlation_image)] = 0 + cnm.estimates.Cn = correlation_image + fname_hdf5 = cnm.mmap_file[:-4] + "hdf5" + cnm.save(fname_hdf5) + cnmf_output_file = pathlib.Path(fname_hdf5) + summary_images = { + "average_image": np.mean( + images[:: max(T // 1000, 1)], axis=0 + ), + "max_image": np.max(images[:: max(T // 1000, 1)], axis=0), + "correlation_image": correlation_image, + } + _save_mc( + mc, + cnmf_output_file.as_posix(), + params["is3D"], + summary_images=summary_images, + ) + except Exception as e: + dview.terminate() + raise e else: - del os.environ["CAIMAN_TEMP"] + cm.stop_server(dview=dview) + logger.info("CNMF analysis complete. 
Resulted saved.") + caiman_temp = os.environ.get("CAIMAN_TEMP") + if caiman_temp is not None: + os.environ["CAIMAN_TEMP"] = caiman_temp + else: + del os.environ["CAIMAN_TEMP"] - _run_processing() - _, imaging_dataset = get_loader_result( - key, ProcessingTask, full_output_dir=output_dir - ) - caiman_dataset = imaging_dataset - key["processing_time"] = caiman_dataset.creation_time - key["package_version"] = cm.__version__ - file_entries = [ - { - **key, - "file_name": f.relative_to( - get_processed_root_data_dir() - ).as_posix(), - "file": f.as_posix(), - } - for f in output_dir.rglob("*") - if f.is_file() - ] - else: - raise ValueError(f"Unknown task mode: {task_mode}") + _run_processing() + _, imaging_dataset = get_loader_result( + key, ProcessingTask, full_output_dir=output_dir + ) + caiman_dataset = imaging_dataset + key["processing_time"] = caiman_dataset.creation_time + key["package_version"] = cm.__version__ + file_entries = [ + { + **key, + "file_name": f.relative_to( + get_processed_root_data_dir() + ).as_posix(), + "file": f.as_posix(), + } + for f in output_dir.rglob("*") + if f.is_file() + ] + elif method == "minian": + logger.info("Running minian via Docker container...") + + # Run the minian pipeline in a Docker container + status = _run_minian_in_container( + avi_files=avi_files, + output_dir=output_dir, + params=params, + ) + + # Set processing time from container output + key["processing_time"] = status.get( + "timestamp", time.strftime("%Y-%m-%dT%H:%M:%S") + ) + + # Get minian version (from status.json or default) + key["package_version"] = status.get("minian_version", "") + + # Collect output files for insertion into DataJoint + file_entries = [ + { + **key, + "file_name": f.name, + "file": f.as_posix(), + } + for f in output_dir.rglob("*") + if f.is_file() + ] + else: + raise ValueError(f"Unknown task mode: {task_mode}") return (file_entries, output_dir) def make_insert(self, key, file_entries, output_dir): @@ -956,9 +1032,14 @@ def make_insert(self, key, file_entries, output_dir): ).as_posix(), } ) - self.insert1(dict(**key, processing_time=datetime.now(timezone.utc))) - # for file in file_entries: - # self.File.insert1(file, ignore_extra_fields=True) + self.insert1( + dict( + **key, + processing_time=key.get( + "processing_time", datetime.now(timezone.utc) + ), + ) + ) # Motion Correction -------------------------------------------------------------------- @@ -1131,6 +1212,38 @@ def make(self, key): } self.Summary.insert1(summary_images) + elif method == "minian": + minian_dataset = loaded_result + + self.insert1( + {**key, "motion_correct_channel": minian_dataset.alignment_channel} + ) + + # Minian uses rigid motion correction + rigid_correction = minian_dataset.extract_rigid_mc() + if rigid_correction is not None: + rigid_correction.update(**key) + self.RigidMotionCorrection.insert1(rigid_correction) + + # -- summary images -- + ref_image = minian_dataset.ref_image + mean_image = minian_dataset.mean_image + max_proj_image = minian_dataset.max_proj_image + correlation_image = minian_dataset.correlation_map + + summary_images = { + **key, + "ref_image": ( + ref_image if ref_image is not None else np.zeros((1, 1, 1)) + ), + "average_image": ( + mean_image if mean_image is not None else np.zeros((1, 1, 1)) + ), + "correlation_image": correlation_image, + "max_proj_image": max_proj_image, + } + self.Summary.insert1(summary_images) + else: raise NotImplementedError("Unknown/unimplemented method: {}".format(method)) @@ -1238,6 +1351,57 @@ def make(self, key): cells, 
ignore_extra_fields=True, allow_direct_insert=True ) + elif method == "minian": + minian_dataset = loaded_result + + # infer "segmentation_channel" - from params if available, else from minian loader + params = (ProcessingParamSet * ProcessingTask & key).fetch1("params") + segmentation_channel = params.get( + "segmentation_channel", minian_dataset.segmentation_channel + ) + + masks, cells = [], [] + for mask in minian_dataset.masks: + masks.append( + { + **key, + "segmentation_channel": segmentation_channel, + "mask": mask["mask_id"], + "mask_npix": mask["mask_npix"], + "mask_center_x": mask["mask_center_x"], + "mask_center_y": mask["mask_center_y"], + "mask_center_z": mask["mask_center_z"], + "mask_xpix": mask["mask_xpix"], + "mask_ypix": mask["mask_ypix"], + "mask_zpix": mask["mask_zpix"], + "mask_weights": mask["mask_weights"], + } + ) + if mask["accepted"]: + cells.append( + { + **key, + "mask_classification_method": "minian_default_classifier", + "mask": mask["mask_id"], + "mask_type": "soma", + } + ) + + self.insert1(key) + self.Mask.insert(masks, ignore_extra_fields=True) + + if cells: + MaskClassification.insert1( + { + **key, + "mask_classification_method": "minian_default_classifier", + }, + allow_direct_insert=True, + ) + MaskClassification.MaskType.insert( + cells, ignore_extra_fields=True, allow_direct_insert=True + ) + else: raise NotImplementedError(f"Unknown/unimplemented method: {method}") @@ -1255,7 +1419,7 @@ class MaskClassificationMethod(dj.Lookup): mask_classification_method: varchar(48) """ - contents = zip(["caiman_default_classifier"]) + contents = zip(["caiman_default_classifier", "minian_default_classifier"]) @schema @@ -1360,6 +1524,29 @@ def make(self, key): self.insert1(key) self.Trace.insert(fluo_traces) + elif method == "minian": + minian_dataset = loaded_result + + # infer "segmentation_channel" - from params if available, else from minian loader + params = (ProcessingParamSet * ProcessingTask & key).fetch1("params") + segmentation_channel = params.get( + "segmentation_channel", minian_dataset.segmentation_channel + ) + + fluo_traces = [] + for mask in minian_dataset.masks: + fluo_traces.append( + { + **key, + "mask": mask["mask_id"], + "fluorescence_channel": segmentation_channel, + "fluorescence": mask["inferred_trace"], + } + ) + + self.insert1(key) + self.Trace.insert(fluo_traces) + else: raise NotImplementedError("Unknown/unimplemented method: {}".format(method)) @@ -1369,14 +1556,14 @@ class ActivityExtractionMethod(dj.Lookup): """Lookup table for activity extraction methods. Attributes: - extraction_method (foreign key, varchar(32) ): Extraction method from CaImAn. + extraction_method (foreign key, varchar(32) ): Extraction method from CaImAn or Minian. 
""" definition = """ extraction_method: varchar(32) """ - contents = zip(["caiman_deconvolution", "caiman_dff"]) + contents = zip(["caiman_deconvolution", "caiman_dff", "minian_deconvolution"]) @schema @@ -1421,7 +1608,15 @@ def key_source(self): & 'extraction_method LIKE "caiman%"' ) - return caiman_key_source.proj() + minian_key_source = ( + Fluorescence + * ActivityExtractionMethod + * ProcessingParamSet.proj("processing_method") + & 'processing_method = "minian"' + & 'extraction_method LIKE "minian%"' + ) + + return caiman_key_source.proj() + minian_key_source.proj() def make(self, key): """Populates table with activity trace data.""" @@ -1456,6 +1651,28 @@ def make(self, key): for mask in caiman_dataset.masks ) + elif method == "minian": + minian_dataset = loaded_result + + if key["extraction_method"] == "minian_deconvolution": + # infer "segmentation_channel" - from params if available, else from minian loader + params = (ProcessingParamSet * ProcessingTask & key).fetch1("params") + segmentation_channel = params.get( + "segmentation_channel", minian_dataset.segmentation_channel + ) + + self.insert1(key) + self.Trace.insert( + dict( + key, + mask=mask["mask_id"], + fluorescence_channel=segmentation_channel, + activity_trace=mask["spikes"], + ) + for mask in minian_dataset.masks + if "spikes" in mask + ) + else: raise NotImplementedError("Unknown/unimplemented method: {}".format(method)) @@ -1526,6 +1743,191 @@ def make(self, key): # Helper Functions --------------------------------------------------------------------- +class MinianLoader: + """Loader class for Minian analysis results. + + Provides a consistent interface for accessing Minian outputs similar to CaImAn loader. + """ + + def __init__(self, output_dir): + """Initialize the MinianLoader. + + Args: + output_dir: Path to the directory containing Minian zarr outputs. 
+ """ + import xarray as xr + + self.output_dir = pathlib.Path(output_dir) + + # Load each zarr variable into a combined Dataset + # (replicates what minian's open_minian does without requiring + # the minian package to be installed) + ds = xr.Dataset() + for zarr_path in sorted(self.output_dir.glob("*.zarr")): + try: + var_ds = xr.open_zarr(str(zarr_path)) + for name, da in var_ds.data_vars.items(): + ds[name] = da + except Exception: + logger.debug(f"Skipping non-zarr path: {zarr_path}") + self._minian_ds = ds + + # Load core arrays + self._A = self._minian_ds.get( + "A" + ) # Spatial footprints (unit_id, height, width) + self._C = self._minian_ds.get("C") # Temporal traces (unit_id, frame) + self._S = self._minian_ds.get( + "S" + ) # Deconvolved activity/spikes (unit_id, frame) + self._b = self._minian_ds.get("b") # Background spatial (height, width) + self._f = self._minian_ds.get("f") # Background temporal (frame) + self._b0 = self._minian_ds.get("b0") # Baseline (unit_id, frame) + self._c0 = self._minian_ds.get("c0") # Initial calcium (unit_id, frame) + + # Try to load motion correction data + self._motion = self._minian_ds.get("motion") + + # Summary images + self._max_proj = self._minian_ds.get("max_proj") + self._mean_proj = self._minian_ds.get("mean_proj") + self._ref_image = self._minian_ds.get("ref_image") + + @property + def minian_dataset(self): + """Return the raw Minian xarray Dataset.""" + return self._minian_ds + + @property + def creation_time(self): + """Get the creation time of the Minian output.""" + # Use the modification time of the output directory + return datetime.fromtimestamp(self.output_dir.stat().st_mtime, tz=timezone.utc) + + @property + def alignment_channel(self): + """Channel used for motion correction (default 0 for miniscope).""" + return 0 + + @property + def segmentation_channel(self): + """Channel used for segmentation (default 0 for miniscope).""" + return 0 + + @property + def is_pw_rigid(self): + """Minian uses rigid motion correction by default.""" + return False + + @property + def motion_shifts(self): + """Return motion correction shifts as dict with 'x' and 'y' keys.""" + if self._motion is not None: + motion_data = self._motion.compute() + return { + "x": motion_data.sel(shift_dim="width").values, + "y": motion_data.sel(shift_dim="height").values, + } + return None + + def extract_rigid_mc(self): + """Extract rigid motion correction data in format compatible with MotionCorrection table.""" + shifts = self.motion_shifts + if shifts is None: + return None + + return { + "x_shifts": shifts["x"], + "y_shifts": shifts["y"], + "x_std": np.std(shifts["x"]), + "y_std": np.std(shifts["y"]), + } + + @property + def ref_image(self): + """Return reference image used for motion correction.""" + if self._ref_image is not None: + return self._ref_image.compute().values[np.newaxis, :, :] + return None + + @property + def mean_image(self): + """Return mean image (average across frames).""" + if self._mean_proj is not None: + return self._mean_proj.compute().values[np.newaxis, :, :] + return None + + @property + def max_proj_image(self): + """Return maximum projection image.""" + if self._max_proj is not None: + return self._max_proj.compute().values[np.newaxis, :, :] + return None + + @property + def correlation_map(self): + """Return correlation image (computed during initialization).""" + # Minian doesn't store correlation image by default + # Return None or compute if needed + return None + + @property + def masks(self): + """Extract mask information in 
format compatible with Segmentation table. + + Returns a list of dicts, one per unit, with mask properties. + """ + if self._A is None: + return [] + + A_data = self._A.compute() + C_data = self._C.compute() if self._C is not None else None + S_data = self._S.compute() if self._S is not None else None + + masks = [] + for unit_id in A_data.coords["unit_id"].values: + footprint = A_data.sel(unit_id=unit_id).values + + # Find non-zero pixels + mask_indices = np.where(footprint > 0) + if len(mask_indices[0]) == 0: + continue + + y_pix = mask_indices[0] + x_pix = mask_indices[1] + weights = footprint[y_pix, x_pix] + + mask_dict = { + "mask_id": int(unit_id), + "mask_npix": len(x_pix), + "mask_center_x": int(np.mean(x_pix)), + "mask_center_y": int(np.mean(y_pix)), + "mask_center_z": None, + "mask_xpix": x_pix, + "mask_ypix": y_pix, + "mask_zpix": None, + "mask_weights": weights, + "accepted": True, # All units accepted (no manual curation) + } + + # Add trace data if available + if C_data is not None: + mask_dict["inferred_trace"] = C_data.sel(unit_id=unit_id).values + if S_data is not None: + mask_dict["spikes"] = S_data.sel(unit_id=unit_id).values + + masks.append(mask_dict) + + return masks + + @property + def num_units(self): + """Return number of detected units.""" + if self._A is not None: + return len(self._A.coords["unit_id"]) + return 0 + + def get_loader_result(key, table, full_output_dir=None) -> tuple: """Retrieve the loaded processed imaging results from the loader (e.g. caiman, etc.) @@ -1548,7 +1950,214 @@ def get_loader_result(key, table, full_output_dir=None) -> tuple: from element_interface import caiman_loader loaded_output = caiman_loader.CaImAn(output_dir) + elif method == "minian": + loaded_output = MinianLoader(output_dir) else: raise NotImplementedError("Unknown/unimplemented method: {}".format(method)) return method, loaded_output + + +def _run_minian_in_container( + avi_files: list, + output_dir: pathlib.Path, + params: dict, + docker_image: str = MINIAN_DOCKER_IMAGE, + timeout: int = MINIAN_DOCKER_TIMEOUT, +): + """Run the minian pipeline in a Docker container; return the parsed + status.json the container writes into output_dir.""" + import subprocess # json is already imported at module level + + # ── Resolve input directory ────────────────────────────────────────── + input_dir = pathlib.Path(avi_files[0]).parent.resolve() + + for avi in avi_files: + if pathlib.Path(avi).parent.resolve() != input_dir: + raise ValueError( + f"All AVI files must be in the same directory. 
" + f"Found files in {input_dir} and {pathlib.Path(avi).parent}" + ) + + # ── Prepare output directory ───────────────────────────────────────── + output_dir = pathlib.Path(output_dir).resolve() + os.makedirs(output_dir, exist_ok=True) + + # ── Prepare params (under output_dir so it's on a shared volume) ───── + params_dir = output_dir / ".params" + os.makedirs(params_dir, exist_ok=True) + params_file = params_dir / "params.json" + + def _serialize_params(obj): + if isinstance(obj, tuple): + return list(obj) + if isinstance(obj, dict): + return {k: _serialize_params(v) for k, v in obj.items()} + if isinstance(obj, list): + return [_serialize_params(v) for v in obj] + return obj + + serializable_params = _serialize_params(params) + + config = { + "input_dir": "/data/input", + "output_dir": "/data/output", + "params": serializable_params, + } + + with open(params_file, "w") as f: + json.dump(config, f, indent=2) + + logger.info(f"Params written to {params_file}") + + # ── Translate container paths → host paths ─────────────────────────── + # The host Docker daemon resolves bind-mount paths on the HOST + # filesystem, not inside this worker container. We use env vars + # set by docker-compose to map container paths to host paths. + host_s3_root = os.getenv("HOST_S3_ROOT") + host_outbox = os.getenv("HOST_OUTBOX") + + def _to_host_path(container_path, container_prefix, host_prefix): + s = str(container_path) + if host_prefix and s.startswith(container_prefix): + return host_prefix + s[len(container_prefix):] + return s # fallback: use as-is (works if running directly on host) + + host_input_dir = _to_host_path(input_dir, "/home/jovyan/s3", host_s3_root) + host_output_dir = _to_host_path(output_dir, "/home/jovyan/efs/outbox", host_outbox) + host_params_dir = _to_host_path(params_dir, "/home/jovyan/efs/outbox", host_outbox) + + # ── Build docker run command ───────────────────────────────────────── + docker_cmd = [ + "docker", "run", "--rm", + "-v", f"{host_input_dir}:/data/input:ro", + "-v", f"{host_output_dir}:/data/output", + "-v", f"{host_params_dir}:/data/params:ro", + ] + + # Pass through resource configuration and Dask env vars + _passthrough_env_vars = [ + "MINIAN_NWORKERS", + "MINIAN_MEMORY_LIMIT", + # Dask distributed config (set via docker-compose with defaults) + "DASK_DISTRIBUTED__SCHEDULER__WORKER_TTL", + "DASK_DISTRIBUTED__COMM__TIMEOUTS__TCP", + "DASK_DISTRIBUTED__COMM__TIMEOUTS__CONNECT", + "DASK_DISTRIBUTED__SCHEDULER__WORK_STEALING", + "DASK_DISTRIBUTED__WORKER__MEMORY__TARGET", + "DASK_DISTRIBUTED__WORKER__MEMORY__SPILL", + "DASK_DISTRIBUTED__WORKER__MEMORY__PAUSE", + "DASK_DISTRIBUTED__WORKER__MEMORY__TERMINATE", + ] + for env_var in _passthrough_env_vars: + val = os.getenv(env_var) + if val: + docker_cmd.extend(["-e", f"{env_var}={val}"]) + + # Memory allocator tuning (reduces heap fragmentation under heavy load) + docker_cmd.extend(["-e", "MALLOC_TRIM_THRESHOLD_=131072"]) + docker_cmd.extend(["-e", "MALLOC_MMAP_THRESHOLD_=131072"]) + + # Memory limit for the container itself (if set) + container_memory = os.getenv("MINIAN_CONTAINER_MEM_LIMIT") + if container_memory: + docker_cmd.extend(["--memory", container_memory]) + + # CPU limit (if set) + container_cpus = os.getenv("MINIAN_CONTAINER_CPUS") + if container_cpus: + docker_cmd.extend(["--cpus", container_cpus]) + + # SHM size — Dask workers need shared memory for inter-process comms + docker_cmd.extend(["--shm-size", os.getenv("MINIAN_SHM_SIZE", "8g")]) + + # The image + docker_cmd.append(docker_image) + + 
logger.info(f"Docker command: {' '.join(docker_cmd)}") + + # ── Pull image if not present ──────────────────────────────────────── + try: + subprocess.run( + ["docker", "image", "inspect", docker_image], + capture_output=True, + check=True, + ) + logger.info(f"Docker image {docker_image} found locally") + except subprocess.CalledProcessError: + logger.info(f"Pulling Docker image {docker_image}...") + subprocess.run( + ["docker", "pull", docker_image], + check=True, + timeout=600, # 10 min pull timeout + ) + + # ── Run the container ──────────────────────────────────────────────── + logger.info("Starting minian container...") + start_time = time.time() + + process = subprocess.Popen( + docker_cmd, + stdout=subprocess.PIPE, + stderr=subprocess.STDOUT, + text=True, + bufsize=1, # Line-buffered + ) + + # Stream container logs to the pipeline logger in real time + try: + for line in process.stdout: + line = line.rstrip() + if line: + logger.info(f"[minian-container] {line}") + + process.wait(timeout=timeout) + except subprocess.TimeoutExpired: + logger.error( + f"Minian container timed out after {timeout}s. Killing..." + ) + process.kill() + process.wait() + raise RuntimeError( + f"Minian container exceeded timeout of {timeout} seconds" + ) + + elapsed = time.time() - start_time + logger.info( + f"Container exited with code {process.returncode} " + f"after {elapsed:.1f}s ({elapsed/3600:.1f}h)" + ) + + # ── Check results ──────────────────────────────────────────────────── + status_file = output_dir / "status.json" + + if process.returncode != 0: + error_msg = f"Minian container exited with code {process.returncode}" + if status_file.exists(): + with open(status_file) as f: + status = json.load(f) + error_msg += f": {status.get('message', 'unknown error')}" + raise RuntimeError(error_msg) + + if not status_file.exists(): + raise RuntimeError( + "Minian container exited successfully but no status.json found" + ) + + with open(status_file) as f: + status = json.load(f) + + if status.get("status") != "success": + raise RuntimeError( + f"Minian container reported failure: {status.get('message', 'unknown')}" + ) + + logger.info( + f"Minian processing complete: {status.get('n_units', '?')} units detected" + ) + + # ── Cleanup temp params dir ────────────────────────────────────────── + import shutil + + shutil.rmtree(params_dir, ignore_errors=True) + + return status diff --git a/notebooks/test_minian_pipeline.py b/notebooks/test_minian_pipeline.py new file mode 100644 index 0000000..28d4e6f --- /dev/null +++ b/notebooks/test_minian_pipeline.py @@ -0,0 +1,396 @@ +""" +Test script for Minian integration in element-miniscope. + +This script sets up the database, populates required tables, and runs +the Minian processing pipeline. 
+ +Prerequisites: +- DataJoint database connection configured (dj.config) +- Minian package installed +- Test miniscope video files available + +Usage: + python test_minian_pipeline.py --data-dir /path/to/miniscope/videos +""" + +import os +import sys +import argparse +import datajoint as dj +import numpy as np +from pathlib import Path +from datetime import datetime + +# ============================================================================ +# Configuration - Modify these paths for your environment +# ============================================================================ + +# Root directory containing raw miniscope data +# This should be the parent directory containing session folders with .avi files +DEFAULT_DATA_DIR = "./miniscope_test/" + +# Root directory for processed outputs +DEFAULT_PROCESSED_DIR = "./processed/" + +# Database schema prefix (schemas will be named: {prefix}_miniscope, etc.) +SCHEMA_PREFIX = "test" + + +# ============================================================================ +# Linking Module - Required functions for element-miniscope +# ============================================================================ + +def get_miniscope_root_data_dir(): + """Return the root directory for raw miniscope data.""" + return [os.environ.get("MINISCOPE_ROOT_DATA_DIR", DEFAULT_DATA_DIR)] + + +def get_processed_root_data_dir(): + """Return the root directory for processed data.""" + processed_dir = os.environ.get("MINISCOPE_PROCESSED_DIR", DEFAULT_PROCESSED_DIR) + os.makedirs(processed_dir, exist_ok=True) + return processed_dir + + +def get_session_directory(session_key): + """Return the session directory for a given session key. + + For testing, we assume the directory structure is: + {root_data_dir}/{subject}/{session_date}/ + """ + from element_miniscope import miniscope + + # For a simple test setup, use the session key directly + subject = session_key.get("subject", "test_subject") + session_datetime = session_key.get("session_datetime", datetime.now()) + + if isinstance(session_datetime, datetime): + session_date = session_datetime.strftime("%Y-%m-%d") + else: + session_date = str(session_datetime).split()[0] + + return f"{subject}/{session_date}" + + +# Create a module-like object for the linking module +class LinkingModule: + get_miniscope_root_data_dir = staticmethod(get_miniscope_root_data_dir) + get_processed_root_data_dir = staticmethod(get_processed_root_data_dir) + get_session_directory = staticmethod(get_session_directory) + + +# ============================================================================ +# Schema Setup +# ============================================================================ + +def setup_schemas(): + """Activate the miniscope schemas.""" + from element_miniscope import miniscope + from element_miniscope import miniscope_report + + # Create minimal upstream tables for testing + schema = dj.schema(f"{SCHEMA_PREFIX}_lab") + + @schema + class Subject(dj.Manual): + definition = """ + subject: varchar(32) + """ + + @schema + class Session(dj.Manual): + definition = """ + -> Subject + session_datetime: datetime + """ + + @schema + class Device(dj.Lookup): + definition = """ + device: varchar(32) + """ + contents = [("Miniscope_V4",)] + + @schema + class AnatomicalLocation(dj.Lookup): + definition = """ + location: varchar(32) + """ + contents = [("CA1",), ("mPFC",)] + + # Add required attributes to linking module + LinkingModule.Subject = Subject + LinkingModule.Session = Session + LinkingModule.Device = Device + 
LinkingModule.AnatomicalLocation = AnatomicalLocation + + # Activate miniscope schema + miniscope.activate( + f"{SCHEMA_PREFIX}_miniscope", + linking_module=LinkingModule, + create_schema=True, + create_tables=True, + ) + + # Activate report schema + miniscope_report.activate( + f"{SCHEMA_PREFIX}_miniscope_report", + create_schema=True, + create_tables=True, + ) + + print(f"Schemas activated: {SCHEMA_PREFIX}_lab, {SCHEMA_PREFIX}_miniscope, {SCHEMA_PREFIX}_miniscope_report") + + return Subject, Session, Device, miniscope + + +def populate_test_data(Subject, Session, miniscope, data_dir): + """Populate the database with test data entries.""" + + # Insert test subject + subject_key = {"subject": "test_mouse"} + Subject.insert1(subject_key, skip_duplicates=True) + print(f"Inserted subject: {subject_key}") + + # Insert test session + session_key = { + **subject_key, + "session_datetime": datetime.now().replace(microsecond=0), + } + Session.insert1(session_key, skip_duplicates=True) + print(f"Inserted session: {session_key}") + + # Insert recording + recording_key = { + **session_key, + "recording_id": 0, + "acq_software": "Miniscope-DAQ-V4", + } + miniscope.Recording.insert1(recording_key, skip_duplicates=True) + print(f"Inserted recording: {recording_key}") + + # Populate RecordingInfo (this reads metadata from the video files) + miniscope.RecordingInfo.populate(display_progress=True) + print("Populated RecordingInfo") + + return session_key, recording_key + + +def create_minian_paramset(miniscope): + """Create a ProcessingParamSet for Minian analysis.""" + + minian_params = { + # Video loading + "load_videos": { + "pattern": r".*\.avi$", + "dtype": "uint8", + "downsample": {"frame": 1, "height": 1, "width": 1}, + "downsample_strategy": "subset", + }, + # Preprocessing + "denoise": {"method": "median", "ksize": 7}, + "background_removal": {"method": "tophat", "wnd": 10}, + # Motion correction + "estimate_motion": {"dim": "frame"}, + # Seeds initialization + "seeds_init": { + "wnd_size": 1000, + "method": "rolling", + "stp_size": 500, + "max_wnd": 15, + "diff_thres": 6.5, + }, + "pnr_refine": {"noise_freq": 0.05, "thres": 1}, + "ks_refine": {"sig": 0.05}, + "seeds_merge": {"thres_dist": 10, "thres_corr": 0.8, "noise_freq": 0.06}, + "initialize": {"thres_corr": 0.8, "wnd": 10, "noise_freq": 0.06}, + "init_merge": {"thres_corr": 0.8}, + # CNMF + "get_noise": {"noise_range": (0.06, 0.5)}, + "first_spatial": { + "dl_wnd": 5, + "sparse_penal": 0.001, + "size_thres": (20, None), + }, + "first_temporal": { + "noise_freq": 0.06, + "sparse_penal": 0.001, + "p": 1, + "add_lag": 10, + "jac_thres": 0.1, + }, + "first_merge": {"thres_corr": 0.6}, + "second_spatial": { + "dl_wnd": 10, + "sparse_penal": 0.001, + "size_thres": (20, None), + }, + "second_temporal": { + "noise_freq": 0.06, + "sparse_penal": 0.01, + "p": 1, + "add_lag": 10, + "jac_thres": 0.2, + "zero_thres": 1e-10, + }, + } + + paramset_idx = 0 + + miniscope.ProcessingParamSet.insert_new_params( + processing_method="minian", + paramset_idx=paramset_idx, + paramset_desc="Default Minian parameters for miniscope analysis", + params=minian_params, + ) + + print(f"Created Minian ProcessingParamSet with idx={paramset_idx}") + return paramset_idx + + +def create_processing_task(miniscope, recording_key, paramset_idx): + """Create a ProcessingTask entry.""" + + task_key = { + **recording_key, + "paramset_idx": paramset_idx, + } + + # Infer output directory + output_dir = miniscope.ProcessingTask.infer_output_dir( + task_key, relative=True, mkdir=True + 
) + + miniscope.ProcessingTask.insert1( + { + **task_key, + "processing_output_dir": str(output_dir), + "task_mode": "trigger", + }, + skip_duplicates=True, + ) + + print(f"Created ProcessingTask: {task_key}") + print(f"Output directory: {output_dir}") + + return task_key + + +def run_processing(miniscope): + """Run the processing pipeline.""" + print("\n" + "=" * 60) + print("Starting Minian Processing...") + print("=" * 60 + "\n") + + miniscope.Processing.populate(display_progress=True) + + print("\nProcessing complete!") + print(f"Entries in Processing table: {len(miniscope.Processing())}") + + +def populate_downstream_tables(miniscope, miniscope_report): + """Populate downstream analysis tables.""" + print("\n" + "=" * 60) + print("Populating downstream tables...") + print("=" * 60 + "\n") + + print("Populating MotionCorrection...") + miniscope.MotionCorrection.populate(display_progress=True) + + print("Populating Segmentation...") + miniscope.Segmentation.populate(display_progress=True) + + print("Populating Fluorescence...") + miniscope.Fluorescence.populate(display_progress=True) + + print("Populating Activity...") + miniscope.Activity.populate(display_progress=True) + + print("Populating visualizations...") + miniscope_report.MinianProcessingVisualization.populate(display_progress=True) + + print("\nDownstream tables populated!") + + +def verify_results(miniscope): + """Print summary of results.""" + print("\n" + "=" * 60) + print("Results Summary") + print("=" * 60 + "\n") + + print(f"Processing entries: {len(miniscope.Processing())}") + print(f"MotionCorrection entries: {len(miniscope.MotionCorrection())}") + print(f"Segmentation entries: {len(miniscope.Segmentation())}") + print(f"Segmentation.Mask entries: {len(miniscope.Segmentation.Mask())}") + print(f"Fluorescence entries: {len(miniscope.Fluorescence())}") + print(f"Fluorescence.Trace entries: {len(miniscope.Fluorescence.Trace())}") + print(f"Activity entries: {len(miniscope.Activity())}") + + # Show some mask statistics if available + if len(miniscope.Segmentation.Mask()) > 0: + masks = miniscope.Segmentation.Mask.fetch() + print(f"\nDetected {len(masks)} ROIs/cells") + + +def main(): + parser = argparse.ArgumentParser(description="Test Minian integration") + parser.add_argument( + "--data-dir", + type=str, + default=DEFAULT_DATA_DIR, + help="Path to directory containing miniscope video files", + ) + parser.add_argument( + "--processed-dir", + type=str, + default=DEFAULT_PROCESSED_DIR, + help="Path to directory for processed outputs", + ) + parser.add_argument( + "--skip-processing", + action="store_true", + help="Skip processing and just verify existing results", + ) + args = parser.parse_args() + + # Set environment variables + os.environ["MINISCOPE_ROOT_DATA_DIR"] = os.path.abspath(args.data_dir) + os.environ["MINISCOPE_PROCESSED_DIR"] = os.path.abspath(args.processed_dir) + + print(f"Data directory: {os.environ['MINISCOPE_ROOT_DATA_DIR']}") + print(f"Processed directory: {os.environ['MINISCOPE_PROCESSED_DIR']}") + + # Import after setting paths + from element_miniscope import miniscope_report + + # Setup schemas + Subject, Session, Device, miniscope = setup_schemas() + + if not args.skip_processing: + # Populate test data + session_key, recording_key = populate_test_data( + Subject, Session, miniscope, args.data_dir + ) + + # Create parameter set + paramset_idx = create_minian_paramset(miniscope) + + # Create processing task + task_key = create_processing_task(miniscope, recording_key, paramset_idx) + + # Run processing 
+ run_processing(miniscope) + + # Populate downstream tables + populate_downstream_tables(miniscope, miniscope_report) + + # Verify results + verify_results(miniscope) + + print("\n" + "=" * 60) + print("Test complete!") + print("=" * 60) + + +if __name__ == "__main__": + main() diff --git a/setup.py b/setup.py index ddfc8bf..a036b3b 100644 --- a/setup.py +++ b/setup.py @@ -41,5 +41,12 @@ "element-session @ git+https://github.com/datajoint/element-session.git", ], "tests": ["pytest", "pytest-cov", "shutils"], + "minian": [ + "dask==2022.5.0", + "xarray==2022.3.0", + "pandas==1.5.3", + "distributed==2022.5.0", + "minian @ git+https://github.com/kushalbakshi/minian.git", + ] }, ) \ No newline at end of file
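
A note on the worker/container contract added in `_run_minian_in_container`: the container reads its configuration from `/data/params/params.json` and is expected to write `status.json` into `/data/output` before exiting. The sketch below is a hypothetical conforming writer; the key names match the reads in `_run_minian_in_container`, but the real writer lives in the image entrypoint, so the payload shape and the example values are assumptions:

```python
# Hypothetical sketch of the status.json contract consumed by
# _run_minian_in_container (runs inside the container, where
# /data/output is the mounted output directory).
import json
import time

status = {
    "status": "success",        # anything other than "success" raises RuntimeError
    "message": "",              # surfaced in the RuntimeError on failure
    "timestamp": time.strftime("%Y-%m-%dT%H:%M:%S"),  # becomes processing_time
    "minian_version": "1.2.1",  # becomes package_version (example value)
    "n_units": 42,              # logged on success (example value)
}

with open("/data/output/status.json", "w") as f:
    json.dump(status, f, indent=2)
```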