Skip to content

Commit 6081908

Browse files
committed
get to usable form after 0.104 release
1 parent cdc784d commit 6081908

6 files changed

Lines changed: 63 additions & 85 deletions

File tree

spikeinterface_gui/controller.py

Lines changed: 52 additions & 79 deletions
Original file line numberDiff line numberDiff line change
@@ -88,9 +88,21 @@ def __init__(
8888
self.verbose = verbose
8989

9090
self.original_analyzer = None
91+
92+
self.main_settings = _default_main_settings.copy()
93+
if user_main_settings is not None:
94+
self.main_settings.update(user_main_settings)
95+
9196
self.set_analyzer_info(analyzer)
9297
self.units_table = make_units_table_from_analyzer(self.analyzer, extra_properties=extra_unit_properties)
9398

99+
# parse events
100+
self.events = None
101+
if events is not None:
102+
self.events = parse_events(events, self, verbose=verbose)
103+
if len(self.events) == 0:
104+
self.events = None
105+
94106
if displayed_unit_properties is None:
95107
displayed_unit_properties = list(_default_displayed_unit_properties)
96108
if extra_unit_properties is not None:
@@ -102,6 +114,10 @@ def __init__(
102114
# spikeinterface handle colors in matplotlib style tuple values in range (0,1)
103115
self.refresh_colors()
104116

117+
self.curation = curation
118+
self.curation_callback = curation_callback
119+
self.curation_callback_kwargs = curation_callback_kwargs
120+
105121
self._potential_merges = None
106122
self.curation = curation
107123
# TODO: Reload the dictionary if it already exists
@@ -207,10 +223,6 @@ def set_analyzer_info(self, analyzer):
207223
self.return_in_uV = self.analyzer.return_in_uV
208224
t0 = time.perf_counter()
209225

210-
self.main_settings = _default_main_settings.copy()
211-
if user_main_settings is not None:
212-
self.main_settings.update(user_main_settings)
213-
214226
self.num_channels = self.analyzer.get_num_channels()
215227
# this now private and should be access using function
216228
self._visible_unit_ids = [self.unit_ids[0]]
@@ -282,20 +294,20 @@ def set_analyzer_info(self, analyzer):
282294
else:
283295
self.spike_amplitudes = None
284296

285-
if "amplitude_scalings" in skip_extensions:
297+
if "amplitude_scalings" in self.skip_extensions:
286298
if self.verbose:
287299
print('\tSkipping amplitude_scalings')
288300
self.amplitude_scalings = None
289301
else:
290-
if verbose:
302+
if self.verbose:
291303
print('\tLoading amplitude_scalings')
292304
sa_ext = analyzer.get_extension('amplitude_scalings')
293305
if sa_ext is not None:
294306
self.amplitude_scalings = sa_ext.get_data()
295307
else:
296308
self.amplitude_scalings = None
297309

298-
if "spike_locations" in skip_extensions:
310+
if "spike_locations" in self.skip_extensions:
299311
if self.verbose:
300312
print('\tSkipping spike_locations')
301313
self.spike_depths = None
@@ -388,13 +400,6 @@ def set_analyzer_info(self, analyzer):
388400
self.num_segments = self.analyzer.get_num_segments()
389401
self.sampling_frequency = self.analyzer.sampling_frequency
390402

391-
# parse events
392-
self.events = None
393-
if events is not None:
394-
self.events = parse_events(events, self, verbose=verbose)
395-
if len(self.events) == 0:
396-
self.events = None
397-
398403
t1 = time.perf_counter()
399404
if self.verbose:
400405
print('Loading extensions took', t1 - t0)
@@ -464,74 +469,9 @@ def set_analyzer_info(self, analyzer):
464469

465470
self._traces_cached = {}
466471

467-
self.units_table = make_units_table_from_analyzer(analyzer, extra_properties=extra_unit_properties)
468-
469-
if displayed_unit_properties is None:
470-
displayed_unit_properties = list(_default_displayed_unit_properties)
471-
if extra_unit_properties is not None:
472-
displayed_unit_properties += list(extra_unit_properties.keys())
473-
displayed_unit_properties = [v for v in displayed_unit_properties if v in self.units_table.columns]
474-
self.displayed_unit_properties = displayed_unit_properties
475-
476472
# set default time info
477473
self.update_time_info()
478474

479-
self.curation = curation
480-
self.curation_callback = curation_callback
481-
self.curation_callback_kwargs = curation_callback_kwargs
482-
483-
if self.curation:
484-
# rules:
485-
# * if user sends curation_data, then it is used
486-
# * otherwise, if curation_data already exists in folder it is used
487-
# * otherwise create an empty one
488-
489-
if curation_data is not None:
490-
# validate the curation data
491-
curation_data = deepcopy(curation_data)
492-
format_version = curation_data.get("format_version", None)
493-
# assume version 2 if not present
494-
if format_version is None:
495-
raise ValueError("Curation data format version is missing and is required in the curation data.")
496-
try:
497-
validate_curation_dict(curation_data)
498-
except Exception as e:
499-
raise ValueError(f"Invalid curation data.\nError: {e}")
500-
501-
elif self.analyzer.format == "binary_folder":
502-
json_file = self.analyzer.folder / "spikeinterface_gui" / "curation_data.json"
503-
if json_file.exists():
504-
with open(json_file, "r") as f:
505-
curation_data = json.load(f)
506-
507-
elif self.analyzer.format == "zarr":
508-
import zarr
509-
zarr_root = zarr.open(self.analyzer.folder, mode='r')
510-
if "spikeinterface_gui" in zarr_root.keys() and "curation_data" in zarr_root["spikeinterface_gui"].attrs.keys():
511-
curation_data = zarr_root["spikeinterface_gui"].attrs["curation_data"]
512-
513-
if curation_data is None:
514-
curation_data = deepcopy(empty_curation_data)
515-
curation_data["unit_ids"] = self.unit_ids.tolist()
516-
517-
if "label_definitions" not in curation_data:
518-
if label_definitions is not None:
519-
curation_data["label_definitions"] = label_definitions
520-
else:
521-
curation_data["label_definitions"] = default_label_definitions.copy()
522-
523-
# This will enable the default shortcuts if has default quality labels
524-
self.has_default_quality_labels = False
525-
if "quality" in curation_data["label_definitions"]:
526-
curation_dict_quality_labels = curation_data["label_definitions"]["quality"]["label_options"]
527-
default_quality_labels = default_label_definitions["quality"]["label_options"]
528-
if set(curation_dict_quality_labels) == set(default_quality_labels):
529-
if self.verbose:
530-
print('Curation quality labels are the default ones')
531-
self.has_default_quality_labels = True
532-
533-
curation_data = Curation(**curation_data).model_dump()
534-
self.curation_data = curation_data
535475

536476
def check_is_view_possible(self, view_name):
537477
from .viewlist import get_all_possible_views
@@ -1014,9 +954,42 @@ def construct_final_curation(self, with_explicit_new_unit_ids=False):
1014954
d["unit_ids"] = self.unit_ids.tolist()
1015955
d.update(self.curation_data.copy())
1016956

957+
if with_explicit_new_unit_ids:
958+
split_new_id_strategy = self.main_settings.get('split_new_id_strategy')
959+
merge_new_id_strategy = self.main_settings.get('merge_new_id_strategy')
960+
d = add_new_unit_ids_to_curation_dict(d, self.analyzer.sorting, split_new_id_strategy=split_new_id_strategy, merge_new_id_strategy=merge_new_id_strategy)
961+
1017962
model = Curation(**d)
1018963
return model
1019964

965+
def apply_curation(self):
966+
967+
if self.original_analyzer is None:
968+
self.original_analyzer = deepcopy(self.analyzer)
969+
self.original_analyzer.extensions = {}
970+
971+
curation = self.construct_final_curation(with_explicit_new_unit_ids=True)
972+
curated_analyzer = apply_curation(self.analyzer, curation)
973+
974+
self.applied_curations.append(curation)
975+
self.remove_curation()
976+
977+
self.set_analyzer_info(curated_analyzer)
978+
979+
# for now, don't show externally provided properties after curation
980+
self.displayed_unit_properties = [displayed_property for displayed_property in self.displayed_unit_properties if displayed_property not in self.extra_unit_properties_names]
981+
self.units_table = make_units_table_from_analyzer(self.analyzer)
982+
self.refresh_colors(existing_colors=self.colors)
983+
984+
for view in self.views:
985+
view.reinitialize()
986+
987+
def remove_curation(self):
988+
label_definitioins = self.curation_data.get("label_definitions", None)
989+
curation_data = deepcopy(empty_curation_data)
990+
curation_data["label_definitions"] = label_definitioins
991+
self.curation_data = curation_data
992+
1020993
def set_curation_data(self, curation_data):
1021994
print("Setting curation data")
1022995
new_curation_data = empty_curation_data.copy()

spikeinterface_gui/curationview.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,7 @@ def _qt_make_layout(self):
7373
elif self.controller.curation_can_be_saved():
7474
but = QT.QPushButton("Save in analyzer")
7575
tb.addWidget(but)
76-
but.clicked.connect(self.save_curation_in_analyzer)
76+
but.clicked.connect(self.controller.save_curation_in_analyzer)
7777

7878
but_apply = QT.QPushButton("Apply curation")
7979
tb.addWidget(but_apply)

spikeinterface_gui/isiview.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,10 @@ def _on_settings_changed(self):
2525
self.isi_histograms, self.isi_bins = None, None
2626
self.refresh()
2727

28+
def _reinitialize(self):
29+
self.isi_histograms, self.isi_bins = self.controller.get_isi_histograms()
30+
self._refresh()
31+
2832
## QT ##
2933

3034
def _qt_make_layout(self):

spikeinterface_gui/mainsettingsview.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,8 @@
99
{'name': 'color_mode', 'type': 'list', 'value' : 'color_by_unit',
1010
'limits': ['color_by_unit', 'color_only_visible', 'color_by_visibility']},
1111
{'name': 'use_times', 'type': 'bool', 'value': False},
12-
{'name': 'merge_new_id_strategy', 'type': 'list', 'limits' : ['take_first', 'append', 'join']},
13-
{'name': 'split_new_id_strategy', 'type': 'list', 'limits' : ['append', 'split']},
12+
{'name': 'merge_new_id_strategy', 'type': 'list', 'limits' : ['take_first', 'append', 'join'], 'value': 'take_first'},
13+
{'name': 'split_new_id_strategy', 'type': 'list', 'limits' : ['append', 'split'], 'value': 'append'},
1414
]
1515

1616

spikeinterface_gui/maintemplateview.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -92,19 +92,19 @@ def _qt_refresh(self):
9292

9393
if peak_data is not None:
9494
# trough
95-
peak_inds = peak_data[['trough_index']].values
95+
peak_inds = peak_data[['trough_index']].values.astype(int)
9696
scatter = pg.ScatterPlotItem(x = times[peak_inds], y = template_high[peak_inds],
9797
size=10, pxMode = True, color="white", symbol="t")
9898
plot.addItem(scatter)
9999

100100
names = ('peak_before', 'peak_after')
101-
peak_inds = peak_data[[f'{k}_index' for k in names]].values
101+
peak_inds = peak_data[[f'{k}_index' for k in names]].values.astype(int)
102102
scatter = pg.ScatterPlotItem(x = times[peak_inds], y = template_high[peak_inds],
103103
size=10, pxMode = True, color="white", symbol="t1")
104104
plot.addItem(scatter)
105105

106106
all_names = ('trough', 'peak_before', 'peak_after')
107-
peak_inds = peak_data[[f'{k}_index' for k in all_names]].values
107+
peak_inds = peak_data[[f'{k}_index' for k in all_names]].values.astype(int)
108108
# Vertical dotted lines from peak to zero
109109
for ind in peak_inds:
110110
x = [times[ind], times[ind]]

spikeinterface_gui/mergeview.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -161,6 +161,7 @@ def accept_group_merge(self, group_ids):
161161
self.refresh()
162162

163163
def _reinitialize(self):
164+
self.proposed_merge_unit_groups_all = []
164165
self.proposed_merge_unit_groups = []
165166
self.merge_info = {}
166167
self._refresh()

0 commit comments

Comments
 (0)