@@ -88,9 +88,21 @@ def __init__(
8888 self .verbose = verbose
8989
9090 self .original_analyzer = None
91+
92+ self .main_settings = _default_main_settings .copy ()
93+ if user_main_settings is not None :
94+ self .main_settings .update (user_main_settings )
95+
9196 self .set_analyzer_info (analyzer )
9297 self .units_table = make_units_table_from_analyzer (self .analyzer , extra_properties = extra_unit_properties )
9398
99+ # parse events
100+ self .events = None
101+ if events is not None :
102+ self .events = parse_events (events , self , verbose = verbose )
103+ if len (self .events ) == 0 :
104+ self .events = None
105+
94106 if displayed_unit_properties is None :
95107 displayed_unit_properties = list (_default_displayed_unit_properties )
96108 if extra_unit_properties is not None :
@@ -102,6 +114,10 @@ def __init__(
102114 # spikeinterface handle colors in matplotlib style tuple values in range (0,1)
103115 self .refresh_colors ()
104116
117+ self .curation = curation
118+ self .curation_callback = curation_callback
119+ self .curation_callback_kwargs = curation_callback_kwargs
120+
105121 self ._potential_merges = None
106122 self .curation = curation
107123 # TODO: Reload the dictionary if it already exists
@@ -207,10 +223,6 @@ def set_analyzer_info(self, analyzer):
207223 self .return_in_uV = self .analyzer .return_in_uV
208224 t0 = time .perf_counter ()
209225
210- self .main_settings = _default_main_settings .copy ()
211- if user_main_settings is not None :
212- self .main_settings .update (user_main_settings )
213-
214226 self .num_channels = self .analyzer .get_num_channels ()
215227 # this now private and should be access using function
216228 self ._visible_unit_ids = [self .unit_ids [0 ]]
@@ -282,20 +294,20 @@ def set_analyzer_info(self, analyzer):
282294 else :
283295 self .spike_amplitudes = None
284296
285- if "amplitude_scalings" in skip_extensions :
297+ if "amplitude_scalings" in self . skip_extensions :
286298 if self .verbose :
287299 print ('\t Skipping amplitude_scalings' )
288300 self .amplitude_scalings = None
289301 else :
290- if verbose :
302+ if self . verbose :
291303 print ('\t Loading amplitude_scalings' )
292304 sa_ext = analyzer .get_extension ('amplitude_scalings' )
293305 if sa_ext is not None :
294306 self .amplitude_scalings = sa_ext .get_data ()
295307 else :
296308 self .amplitude_scalings = None
297309
298- if "spike_locations" in skip_extensions :
310+ if "spike_locations" in self . skip_extensions :
299311 if self .verbose :
300312 print ('\t Skipping spike_locations' )
301313 self .spike_depths = None
@@ -388,13 +400,6 @@ def set_analyzer_info(self, analyzer):
388400 self .num_segments = self .analyzer .get_num_segments ()
389401 self .sampling_frequency = self .analyzer .sampling_frequency
390402
391- # parse events
392- self .events = None
393- if events is not None :
394- self .events = parse_events (events , self , verbose = verbose )
395- if len (self .events ) == 0 :
396- self .events = None
397-
398403 t1 = time .perf_counter ()
399404 if self .verbose :
400405 print ('Loading extensions took' , t1 - t0 )
@@ -464,74 +469,9 @@ def set_analyzer_info(self, analyzer):
464469
465470 self ._traces_cached = {}
466471
467- self .units_table = make_units_table_from_analyzer (analyzer , extra_properties = extra_unit_properties )
468-
469- if displayed_unit_properties is None :
470- displayed_unit_properties = list (_default_displayed_unit_properties )
471- if extra_unit_properties is not None :
472- displayed_unit_properties += list (extra_unit_properties .keys ())
473- displayed_unit_properties = [v for v in displayed_unit_properties if v in self .units_table .columns ]
474- self .displayed_unit_properties = displayed_unit_properties
475-
476472 # set default time info
477473 self .update_time_info ()
478474
479- self .curation = curation
480- self .curation_callback = curation_callback
481- self .curation_callback_kwargs = curation_callback_kwargs
482-
483- if self .curation :
484- # rules:
485- # * if user sends curation_data, then it is used
486- # * otherwise, if curation_data already exists in folder it is used
487- # * otherwise create an empty one
488-
489- if curation_data is not None :
490- # validate the curation data
491- curation_data = deepcopy (curation_data )
492- format_version = curation_data .get ("format_version" , None )
493- # assume version 2 if not present
494- if format_version is None :
495- raise ValueError ("Curation data format version is missing and is required in the curation data." )
496- try :
497- validate_curation_dict (curation_data )
498- except Exception as e :
499- raise ValueError (f"Invalid curation data.\n Error: { e } " )
500-
501- elif self .analyzer .format == "binary_folder" :
502- json_file = self .analyzer .folder / "spikeinterface_gui" / "curation_data.json"
503- if json_file .exists ():
504- with open (json_file , "r" ) as f :
505- curation_data = json .load (f )
506-
507- elif self .analyzer .format == "zarr" :
508- import zarr
509- zarr_root = zarr .open (self .analyzer .folder , mode = 'r' )
510- if "spikeinterface_gui" in zarr_root .keys () and "curation_data" in zarr_root ["spikeinterface_gui" ].attrs .keys ():
511- curation_data = zarr_root ["spikeinterface_gui" ].attrs ["curation_data" ]
512-
513- if curation_data is None :
514- curation_data = deepcopy (empty_curation_data )
515- curation_data ["unit_ids" ] = self .unit_ids .tolist ()
516-
517- if "label_definitions" not in curation_data :
518- if label_definitions is not None :
519- curation_data ["label_definitions" ] = label_definitions
520- else :
521- curation_data ["label_definitions" ] = default_label_definitions .copy ()
522-
523- # This will enable the default shortcuts if has default quality labels
524- self .has_default_quality_labels = False
525- if "quality" in curation_data ["label_definitions" ]:
526- curation_dict_quality_labels = curation_data ["label_definitions" ]["quality" ]["label_options" ]
527- default_quality_labels = default_label_definitions ["quality" ]["label_options" ]
528- if set (curation_dict_quality_labels ) == set (default_quality_labels ):
529- if self .verbose :
530- print ('Curation quality labels are the default ones' )
531- self .has_default_quality_labels = True
532-
533- curation_data = Curation (** curation_data ).model_dump ()
534- self .curation_data = curation_data
535475
536476 def check_is_view_possible (self , view_name ):
537477 from .viewlist import get_all_possible_views
@@ -1014,9 +954,42 @@ def construct_final_curation(self, with_explicit_new_unit_ids=False):
1014954 d ["unit_ids" ] = self .unit_ids .tolist ()
1015955 d .update (self .curation_data .copy ())
1016956
957+ if with_explicit_new_unit_ids :
958+ split_new_id_strategy = self .main_settings .get ('split_new_id_strategy' )
959+ merge_new_id_strategy = self .main_settings .get ('merge_new_id_strategy' )
960+ d = add_new_unit_ids_to_curation_dict (d , self .analyzer .sorting , split_new_id_strategy = split_new_id_strategy , merge_new_id_strategy = merge_new_id_strategy )
961+
1017962 model = Curation (** d )
1018963 return model
1019964
965+ def apply_curation (self ):
966+
967+ if self .original_analyzer is None :
968+ self .original_analyzer = deepcopy (self .analyzer )
969+ self .original_analyzer .extensions = {}
970+
971+ curation = self .construct_final_curation (with_explicit_new_unit_ids = True )
972+ curated_analyzer = apply_curation (self .analyzer , curation )
973+
974+ self .applied_curations .append (curation )
975+ self .remove_curation ()
976+
977+ self .set_analyzer_info (curated_analyzer )
978+
979+ # for now, don't show externally provided properties after curation
980+ self .displayed_unit_properties = [displayed_property for displayed_property in self .displayed_unit_properties if displayed_property not in self .extra_unit_properties_names ]
981+ self .units_table = make_units_table_from_analyzer (self .analyzer )
982+ self .refresh_colors (existing_colors = self .colors )
983+
984+ for view in self .views :
985+ view .reinitialize ()
986+
987+ def remove_curation (self ):
988+ label_definitioins = self .curation_data .get ("label_definitions" , None )
989+ curation_data = deepcopy (empty_curation_data )
990+ curation_data ["label_definitions" ] = label_definitioins
991+ self .curation_data = curation_data
992+
1020993 def set_curation_data (self , curation_data ):
1021994 print ("Setting curation data" )
1022995 new_curation_data = empty_curation_data .copy ()
0 commit comments