Skip to content

Commit 79937a0

Browse files
committed
retain quality labels
1 parent 6081908 commit 79937a0

6 files changed

Lines changed: 99 additions & 114 deletions

File tree

spikeinterface_gui/basescatterview.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ def get_unit_data(self, unit_id, segment_index=0):
5656
return spike_times, spike_data, np.array([1]), np.array([ymin, ymax]), ymin, ymax, inds
5757

5858
# avoid clear outliers in the plot and histogram by using percentiles
59-
ymin, ymax = np.percentile(spike_data, [self.settings['display_low_percentiles'], self.settings['display_high_percentiles']])
59+
ymin, ymax = np.percentile(spike_data[~np.isnan(spike_data)], [self.settings['display_low_percentiles'], self.settings['display_high_percentiles']])
6060
min_bin_size = np.min(np.diff(np.unique(spike_data)))
6161
bins = np.linspace(ymin, ymax, self.settings['num_bins'])
6262
# if bins are too small, adjust the number of bins to ensure a minimum bin size and avoid jumps in the histogram
@@ -329,8 +329,8 @@ def _qt_refresh(self, set_scatter_range=False):
329329
# set x range to time range of the current segment for scatter, and max count for histogram
330330
# set y range to min and max of visible spike amplitudes
331331
if len(ymins) > 0 and (set_scatter_range or not self._first_refresh_done):
332-
ymin = np.min(ymins)
333-
ymax = np.max(ymaxs)
332+
ymin = np.nanmin(ymins)
333+
ymax = np.nanmax(ymaxs)
334334
t_start, t_stop = self.controller.get_t_start_t_stop()
335335
self.viewBox.setXRange(t_start, t_stop, padding = 0.0)
336336
self.viewBox.setYRange(ymin, ymax, padding = 0.0)

spikeinterface_gui/controller.py

Lines changed: 92 additions & 102 deletions
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,6 @@ def __init__(
8383
self.signal_handler = SignalHandler(self, parent=parent)
8484

8585
self.with_traces = with_traces
86-
self.main_settings = _default_main_settings.copy()
8786
self.save_on_compute = save_on_compute
8887
self.verbose = verbose
8988

@@ -95,6 +94,8 @@ def __init__(
9594

9695
self.set_analyzer_info(analyzer)
9796
self.units_table = make_units_table_from_analyzer(self.analyzer, extra_properties=extra_unit_properties)
97+
98+
self.set_curation_info(curation, curation_data, label_definitions, curation_callback, curation_callback_kwargs)
9899

99100
# parse events
100101
self.events = None
@@ -114,81 +115,7 @@ def __init__(
114115
# spikeinterface handle colors in matplotlib style tuple values in range (0,1)
115116
self.refresh_colors()
116117

117-
self.curation = curation
118-
self.curation_callback = curation_callback
119-
self.curation_callback_kwargs = curation_callback_kwargs
120-
121-
self._potential_merges = None
122-
self.curation = curation
123-
# TODO: Reload the dictionary if it already exists
124-
if self.curation:
125-
# rules:
126-
# * if user sends curation_data, then it is used
127-
# * otherwise, if curation_data already exists in folder it is used
128-
# * otherwise create an empty one
129-
130-
if curation_data is not None:
131-
# validate the curation data
132-
format_version = curation_data.get("format_version", None)
133-
# assume version 2 if not present
134-
if format_version is None:
135-
raise ValueError("Curation data format version is missing and is required in the curation data.")
136-
try:
137-
validate_curation_dict(curation_data)
138-
except Exception as e:
139-
raise ValueError(f"Invalid curation data.\nError: {e}")
140-
141-
if curation_data.get("merges") is None:
142-
curation_data["merges"] = []
143-
else:
144-
# here we reset the merges for better formatting (str)
145-
existing_merges = curation_data["merges"]
146-
new_merges = []
147-
for m in existing_merges:
148-
if "unit_ids" not in m:
149-
continue
150-
if len(m["unit_ids"]) < 2:
151-
continue
152-
new_merges = add_merge(new_merges, m["unit_ids"])
153-
curation_data["merges"] = new_merges
154-
if curation_data.get("splits") is None:
155-
curation_data["splits"] = []
156-
if curation_data.get("removed") is None:
157-
curation_data["removed"] = []
158-
159-
elif self.analyzer.format == "binary_folder":
160-
json_file = self.analyzer.folder / "spikeinterface_gui" / "curation_data.json"
161-
if json_file.exists():
162-
with open(json_file, "r") as f:
163-
curation_data = json.load(f)
164-
165-
elif self.analyzer.format == "zarr":
166-
import zarr
167-
zarr_root = zarr.open(self.analyzer.folder, mode='r')
168-
if "spikeinterface_gui" in zarr_root.keys() and "curation_data" in zarr_root["spikeinterface_gui"].attrs.keys():
169-
curation_data = zarr_root["spikeinterface_gui"].attrs["curation_data"]
170-
171-
if curation_data is None:
172-
curation_data = deepcopy(empty_curation_data)
173-
curation_data["label_definitions"] = default_label_definitions.copy()
174-
175-
if curation_data.get("discard_spikes") is None:
176-
curation_data["discard_spikes"] = []
177-
178-
self.curation_data = curation_data
179-
180-
if "label_definitions" not in self.curation_data:
181-
if label_definitions is not None:
182-
self.curation_data["label_definitions"] = label_definitions
183-
184-
self.has_default_quality_labels = False
185-
if "quality" in self.curation_data["label_definitions"]:
186-
curation_dict_quality_labels = self.curation_data["label_definitions"]["quality"]["label_options"]
187-
default_quality_labels = default_label_definitions["quality"]["label_options"]
188-
if set(curation_dict_quality_labels) == set(default_quality_labels):
189-
if self.verbose:
190-
print('Curation quality labels are the default ones')
191-
self.has_default_quality_labels = True
118+
192119

193120
def check_is_view_possible(self, view_name):
194121
from .viewlist import get_all_possible_views
@@ -473,30 +400,83 @@ def set_analyzer_info(self, analyzer):
473400
self.update_time_info()
474401

475402

476-
def check_is_view_possible(self, view_name):
477-
from .viewlist import get_all_possible_views
478-
possible_class_views = get_all_possible_views()
479-
view_class = possible_class_views[view_name]
480-
if view_class._depend_on is not None:
481-
depencies_ok = all(self.has_extension(k) for k in view_class._depend_on)
482-
if not depencies_ok:
483-
if self.verbose:
484-
print(view_name, 'does not have all dependencies', view_class._depend_on)
485-
return False
486-
return True
403+
def set_curation_info(self, curation, curation_data, label_definitions, curation_callback, curation_callback_kwargs):
487404

488-
def declare_a_view(self, new_view):
489-
assert new_view not in self.views, 'view already declared {}'.format(self)
490-
self.views.append(new_view)
491-
self.signal_handler.connect_view(new_view)
492-
493-
@property
494-
def channel_ids(self):
495-
return self.analyzer.channel_ids
405+
self.curation = curation
406+
self.curation_callback = curation_callback
407+
self.curation_callback_kwargs = curation_callback_kwargs
496408

497-
@property
498-
def unit_ids(self):
499-
return self.analyzer.unit_ids
409+
self._potential_merges = None
410+
self.curation = curation
411+
# TODO: Reload the dictionary if it already exists
412+
if self.curation:
413+
# rules:
414+
# * if user sends curation_data, then it is used
415+
# * otherwise, if curation_data already exists in folder it is used
416+
# * otherwise create an empty one
417+
418+
if curation_data is not None:
419+
# validate the curation data
420+
format_version = curation_data.get("format_version", None)
421+
# assume version 2 if not present
422+
if format_version is None:
423+
raise ValueError("Curation data format version is missing and is required in the curation data.")
424+
try:
425+
validate_curation_dict(curation_data)
426+
except Exception as e:
427+
raise ValueError(f"Invalid curation data.\nError: {e}")
428+
429+
if curation_data.get("merges") is None:
430+
curation_data["merges"] = []
431+
else:
432+
# here we reset the merges for better formatting (str)
433+
existing_merges = curation_data["merges"]
434+
new_merges = []
435+
for m in existing_merges:
436+
if "unit_ids" not in m:
437+
continue
438+
if len(m["unit_ids"]) < 2:
439+
continue
440+
new_merges = add_merge(new_merges, m["unit_ids"])
441+
curation_data["merges"] = new_merges
442+
if curation_data.get("splits") is None:
443+
curation_data["splits"] = []
444+
if curation_data.get("removed") is None:
445+
curation_data["removed"] = []
446+
447+
elif self.analyzer.format == "binary_folder":
448+
json_file = self.analyzer.folder / "spikeinterface_gui" / "curation_data.json"
449+
if json_file.exists():
450+
with open(json_file, "r") as f:
451+
curation_data = json.load(f)
452+
453+
elif self.analyzer.format == "zarr":
454+
import zarr
455+
zarr_root = zarr.open(self.analyzer.folder, mode='r')
456+
if "spikeinterface_gui" in zarr_root.keys() and "curation_data" in zarr_root["spikeinterface_gui"].attrs.keys():
457+
curation_data = zarr_root["spikeinterface_gui"].attrs["curation_data"]
458+
459+
if curation_data is None:
460+
curation_data = deepcopy(empty_curation_data)
461+
curation_data["label_definitions"] = default_label_definitions.copy()
462+
463+
if curation_data.get("discard_spikes") is None:
464+
curation_data["discard_spikes"] = []
465+
466+
self.curation_data = curation_data
467+
468+
if "label_definitions" not in self.curation_data:
469+
if label_definitions is not None:
470+
self.curation_data["label_definitions"] = label_definitions
471+
472+
self.has_default_quality_labels = False
473+
if "quality" in self.curation_data["label_definitions"]:
474+
curation_dict_quality_labels = self.curation_data["label_definitions"]["quality"]["label_options"]
475+
default_quality_labels = default_label_definitions["quality"]["label_options"]
476+
if set(curation_dict_quality_labels) == set(default_quality_labels):
477+
if self.verbose:
478+
print('Curation quality labels are the default ones')
479+
self.has_default_quality_labels = True
500480

501481
def get_time(self):
502482
"""
@@ -953,7 +933,6 @@ def construct_final_curation(self, with_explicit_new_unit_ids=False):
953933
d["format_version"] = "2"
954934
d["unit_ids"] = self.unit_ids.tolist()
955935
d.update(self.curation_data.copy())
956-
957936
if with_explicit_new_unit_ids:
958937
split_new_id_strategy = self.main_settings.get('split_new_id_strategy')
959938
merge_new_id_strategy = self.main_settings.get('merge_new_id_strategy')
@@ -972,7 +951,7 @@ def apply_curation(self):
972951
curated_analyzer = apply_curation(self.analyzer, curation)
973952

974953
self.applied_curations.append(curation)
975-
self.remove_curation()
954+
self.remove_curation(curated_analyzer)
976955

977956
self.set_analyzer_info(curated_analyzer)
978957

@@ -984,10 +963,21 @@ def apply_curation(self):
984963
for view in self.views:
985964
view.reinitialize()
986965

987-
def remove_curation(self):
988-
label_definitioins = self.curation_data.get("label_definitions", None)
966+
def remove_curation(self, curated_analyzer):
967+
"""Removes curation from the controller, retaining quality labels."""
968+
989969
curation_data = deepcopy(empty_curation_data)
970+
# retain label definitions and 'quality' label
971+
label_definitioins = self.curation_data.get("label_definitions", None)
990972
curation_data["label_definitions"] = label_definitioins
973+
974+
if (quality_labels := curated_analyzer.get_sorting_property('quality')) is not None:
975+
manual_labels = []
976+
for unit_id, quality_label in zip(curated_analyzer.unit_ids, quality_labels):
977+
manual_labels.append({'unit_id': unit_id, 'labels': {'quality': [quality_label]}})
978+
979+
curation_data['manual_labels'] = manual_labels
980+
991981
self.curation_data = curation_data
992982

993983
def set_curation_data(self, curation_data):

spikeinterface_gui/curationview.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
from .view_base import ViewBase
55

66
from spikeinterface.core.core_tools import check_json
7-
7+
from spikeinterface.curation.curation_model import SequentialCuration
88

99
class CurationView(ViewBase):
1010
id = "curation"
@@ -301,9 +301,6 @@ def _qt_export_json(self):
301301
f.write(curation_model.model_dump_json(indent=4))
302302
self.controller.current_curation_saved = True
303303
else:
304-
# Keep this here until `SeqentialCuration` in release of spikeinterface
305-
from spikeinterface.curation.curation_model import SequentialCuration
306-
307304
current_curation_model = self.controller.construct_final_curation()
308305
applied_curations = self.controller.applied_curations
309306
current_and_applied_curations = applied_curations + [current_curation_model]

spikeinterface_gui/unitlistview.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -130,7 +130,7 @@ def _qt_set_up_visible_columns(self):
130130

131131
def _qt_reinitialize(self):
132132

133-
self._qt_set_up_visible_columns()
133+
#self._qt_set_up_visible_columns()
134134
self._qt_full_table_refresh()
135135
self._qt_refresh()
136136

@@ -227,7 +227,6 @@ def _qt_full_table_refresh(self):
227227

228228
self.table.clear()
229229

230-
231230
internal_column_names = ['unit_id', 'visible', 'channel_id']
232231

233232
# internal labels

spikeinterface_gui/utils_global.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -68,9 +68,9 @@ def add_new_unit_ids_to_curation_dict(curation_dict, sorting, split_new_id_strat
6868
"""
6969

7070
from spikeinterface.core.sorting_tools import generate_unit_ids_for_split, generate_unit_ids_for_merge_group
71-
from spikeinterface.curation.curation_model import CurationModel
71+
from spikeinterface.curation.curation_model import Curation
7272

73-
curation_model = CurationModel(**curation_dict)
73+
curation_model = Curation(**curation_dict)
7474
old_unit_ids = copy(curation_model.unit_ids)
7575

7676
if len(curation_model.splits) > 0:

spikeinterface_gui/view_base.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,6 @@ def __init__(self, controller=None, parent=None, backend="qt"):
4040
create_settings(self)
4141
self.notifier = SignalNotifier(view=self)
4242
self.busy = pn.indicators.LoadingSpinner(value=True, size=20, name='busy...')
43-
self.layout = None
4443
make_layout()
4544
if self._settings is not None:
4645
listen_setting_changes(self)

0 commit comments

Comments
 (0)