forked from datajoint/element-array-ephys
-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathephys_no_curation.py
More file actions
1075 lines (882 loc) · 41.4 KB
/
Copy pathephys_no_curation.py
File metadata and controls
1075 lines (882 loc) · 41.4 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
import importlib
import inspect
import pathlib
from datetime import timedelta, datetime, timezone
import datajoint as dj
import numpy as np
import pandas as pd
from element_interface.utils import dict_to_uuid, find_full_path, find_root_directory
from scipy import signal
import intanrhdreader
from . import ephys_report, probe
from .readers import kilosort, openephys, spikeglx
# Module-level logger (DataJoint's logger), used for progress messages in the `make` methods.
logger = dj.logger
# Schema object created without a name; it is bound to a database schema in `activate()`.
schema = dj.schema()
# Module providing upstream tables and path hooks; assigned by `activate()`.
_linking_module = None
def activate(
    ephys_schema_name: str,
    probe_schema_name: str = None,
    *,
    create_schema: bool = True,
    create_tables: bool = True,
    linking_module: str = None,
):
    """Activates the `ephys` and `probe` schemas.

    Args:
        ephys_schema_name (str): A string containing the name of the ephys schema.
        probe_schema_name (str): A string containing the name of the probe schema.
        create_schema (bool): If True, schema will be created in the database.
        create_tables (bool): If True, tables related to the schema will be created in the database.
        linking_module (str): A string containing the module name or module containing the required dependencies to activate the schema.

    Raises:
        AssertionError: If `linking_module` is neither a module nor the name of an importable module.

    Dependencies:
    Upstream tables:
        culture.Experiment: A parent table to EphysSession.
    Functions:
        get_ephys_root_data_dir(): Returns absolute path for root data director(y/ies) with all electrophysiological recording sessions, as a list of string(s).
        get_organoid_directory(session_key: dict): Returns path to electrophysiology data for a particular session.
        get_processed_root_data_dir(): Optional. Returns absolute path for processed data. Defaults to the first root directory.
    """
    if isinstance(linking_module, str):
        linking_module = importlib.import_module(linking_module)
    # The message names the actual parameter so the error points at the right argument.
    assert inspect.ismodule(
        linking_module
    ), "The argument 'linking_module' must be a module's name or a module"

    global _linking_module
    _linking_module = linking_module

    # activate
    probe.activate(
        probe_schema_name, create_schema=create_schema, create_tables=create_tables
    )
    schema.activate(
        ephys_schema_name,
        create_schema=create_schema,
        create_tables=create_tables,
        # Makes the linking module's names (e.g. upstream tables) resolvable in table definitions.
        add_objects=_linking_module.__dict__,
    )
    ephys_report.activate(f"{ephys_schema_name}_report", ephys_schema_name)
# -------------- Functions required by the elements-ephys ---------------
def get_ephys_root_data_dir() -> list:
    """Fetches absolute data path to ephys data directories.

    The absolute path here is used as a reference for all downstream relative paths used in DataJoint.
    If the linking module also defines `get_processed_root_data_dir`, that directory is
    appended to the result.

    Returns:
        A new list of the absolute path(s) to ephys data directories.
    """
    root_directories = _linking_module.get_ephys_root_data_dir()
    if isinstance(root_directories, (str, pathlib.Path)):
        root_directories = [root_directories]
    else:
        # Copy so the append below never mutates a list owned by the linking module
        # (otherwise a shared list grows on every call); also accepts tuples/iterables.
        root_directories = list(root_directories)

    if hasattr(_linking_module, "get_processed_root_data_dir"):
        root_directories.append(_linking_module.get_processed_root_data_dir())

    return root_directories
def get_organoid_directory(session_key: dict) -> str:
    """Look up the directory that holds the raw ephys data of one session.

    The lookup is delegated unchanged to the linking module's
    `get_organoid_directory` hook.

    Args:
        session_key (dict): Key identifying the session; passed through to the hook as-is.

    Returns:
        The session directory path, as provided by the linking module.
    """
    directory_hook = _linking_module.get_organoid_directory
    return directory_hook(session_key)
def get_processed_root_data_dir() -> str:
    """Return the root directory under which processed data is stored.

    Uses the linking module's `get_processed_root_data_dir` hook when it exists;
    otherwise falls back to the first directory from `get_ephys_root_data_dir()`.

    Returns:
        The full path to the processed-data root directory.
    """
    if not hasattr(_linking_module, "get_processed_root_data_dir"):
        # No dedicated processed-data location: reuse the first raw-data root.
        return get_ephys_root_data_dir()[0]
    return _linking_module.get_processed_root_data_dir()
# ----------------------------- Table declarations ----------------------
@schema
class AcquisitionSoftware(dj.Lookup):
    """Name of software used for recording electrophysiological data.

    Attributes:
        acq_software ( varchar(24) ): Acquisition software, e.g., SpikeGLX, Open Ephys, Intan.
    """

    definition = """ # Name of software used for recording of neuropixels probes - SpikeGLX or Open Ephys
    acq_software: varchar(24)
    """
    # A list of 1-tuples rather than a `zip` object: a zip iterator is exhausted
    # after one pass, whereas a list can be iterated any number of times.
    contents = [("SpikeGLX",), ("Open Ephys",), ("Intan",)]
@schema
class Port(dj.Lookup):
    """Port of the Intan acquisition system a probe can be connected to.

    Attributes:
        port_id ( char(2) ): Port ID, one of "A", "B", "C", "D".
    """

    definition = """ # Port ID of the Intan acquisition system
    port_id : char(2)
    """
    # A list of 1-tuples rather than a one-shot `zip` iterator, so it can be re-iterated.
    contents = [("A",), ("B",), ("C",), ("D",)]
@schema
class EphysRawFile(dj.Manual):
    """Catalog of raw ephys files.

    Sessions select their files from this table by `file_time`: every file whose
    `file_time` lies between a session's `start_time` and `end_time` is used.

    Attributes:
        file_path ( varchar(512) ): Path to the file, relative to the root data directory.
        AcquisitionSoftware (foreign key): AcquisitionSoftware primary key.
        file_time (datetime): Date and time of the file acquisition.
        parent_folder ( varchar(128) ): Parent folder containing the file.
        filename_prefix ( varchar(64) ): Filename prefix, if any, excluding the datetime information.
    """

    definition = """ # Catalog of all raw ephys files
    file_path : varchar(512) # path to the file relative to the root directory
    ---
    -> AcquisitionSoftware
    file_time : datetime # date and time of the file acquisition
    parent_folder : varchar(128) # parent folder containing the file
    filename_prefix : varchar(64) # filename prefix, if any, excluding the datetime information
    """
@schema
class EphysSession(dj.Manual):
    """User defined ephys session for downstream analysis.

    A session is a time window: raw files (EphysRawFile) whose `file_time` falls
    between `start_time` and `end_time` belong to it. `session_type` decides which
    downstream steps apply: "lfp"/"both" for LFP, "spike_sorting"/"both" for clustering.

    Attributes:
        culture.Experiment (foreign key): culture.Experiment primary key.
        insertion_number (tinyint unsigned): Insertion number (presumably of the probe within the experiment — confirm).
        start_time (datetime): Start of the session window.
        end_time (datetime): End of the session window.
        session_type (enum): One of "lfp", "spike_sorting", "both", "test".
    """

    definition = """ # User defined ephys session for downstream analysis.
    -> culture.Experiment
    insertion_number : tinyint unsigned
    start_time : datetime
    end_time : datetime
    ---
    session_type : enum("lfp", "spike_sorting", "both", "test")
    """
@schema
class EphysSessionProbe(dj.Manual):
    """User defined probe for each ephys session.

    Attributes:
        EphysSession (foreign key): EphysSession primary key.
        probe.Probe (foreign key): probe.Probe primary key.
        Port (foreign key): Port primary key; the Intan port the probe was connected to.
        used_electrodes (longblob, nullable): Electrode IDs used in this session; null means all electrodes were used.
    """

    definition = """
    -> EphysSession
    ---
    -> probe.Probe
    -> Port # port ID where the probe was connected to.
    used_electrodes=null : longblob # list of electrode IDs used in this session (if null, all electrodes are used)
    """
@schema
class EphysSessionInfo(dj.Imported):
    """Header information read from the first raw file of an ephys session.

    Attributes:
        EphysSession (foreign key): EphysSession primary key.
        session_info (longblob): Intan .rhd header of the session's earliest file,
            without the "spike_triggers" and "aux_input_channels" entries.
    """

    definition = """ # Store header information from the first session file.
    -> EphysSession
    ---
    session_info: longblob # Session header info from intan .rhd file. Get this from the first session file.
    """

    def make(self, key):
        """Read the header of the earliest raw file in the session window and store it.

        Args:
            key (dict): EphysSession primary key.

        Raises:
            FileNotFoundError: If no EphysRawFile falls within the session window.
            OSError: If the header of the first file cannot be read.
        """
        query = (
            EphysRawFile
            & f"file_time BETWEEN '{key['start_time']}' AND '{key['end_time']}'"
        )
        if not query:
            raise FileNotFoundError(
                f"No EphysRawFile found BETWEEN '{key['start_time']}' AND '{key['end_time']}'"
            )
        first_file = query.fetch("file_path", order_by="file_time", limit=1)[0]
        first_file = find_full_path(get_ephys_root_data_dir(), first_file)

        # Read file header
        with open(first_file, "rb") as f:
            try:
                header = intanrhdreader.read_header(f)
            except OSError as err:
                # Chain the original error so the root cause stays in the traceback.
                raise OSError(f"Error occurred when reading file {first_file}") from err

        # These entries are not stored with the session info; tolerate their absence.
        header.pop("spike_triggers", None)
        header.pop("aux_input_channels", None)

        logger.info(f"Populating ephys.EphysSessionInfo for <{key}>")
        self.insert1({**key, "session_info": header})
@schema
class LFP(dj.Imported):
    """Notch-filtered, down-sampled LFP traces for an ephys session.

    Attributes:
        EphysSession (foreign key): EphysSession primary key.
        lfp_sampling_rate (float): Down-sampled sampling rate (Hz).
        execution_duration (float): Duration of computation + insertion, in hours.
    """

    definition = """ # Store pre-processed LFP traces per electrode. Only the LFPs collected from a pre-defined recording session.
    -> EphysSession
    ---
    lfp_sampling_rate : float # Down-sampled sampling rate (Hz).
    execution_duration : float # execution duration in hours
    """

    class Trace(dj.Part):
        """LFP trace of one electrode.

        Attributes:
            LFP (foreign key): LFP primary key.
            probe.ElectrodeConfig.Electrode (foreign key): probe.ElectrodeConfig.Electrode primary key.
            lfp (blob): LFP trace in uV, kept in the `datajoint-blob` store.
        """

        definition = """
        -> master
        -> probe.ElectrodeConfig.Electrode
        ---
        lfp : blob@datajoint-blob # uV
        """

    @property
    def key_source(self):
        """EphysSessions with EphysSessionInfo and EphysSessionProbe entries and session_type "lfp" or "both"."""
        return (
            EphysSession
            & EphysSessionInfo
            & EphysSessionProbe
            & 'session_type IN ("lfp", "both")'
        )

    TARGET_SAMPLING_RATE = 2500  # Hz
    POWERLINE_NOISE_FREQ = 60  # Hz; used when the header's notch_filter_frequency is falsy
    MAX_DURATION_MINUTES = 30  # Minutes

    @staticmethod
    def _session_duration_minutes(key):
        """Return the length of the session window (end_time - start_time) in minutes."""
        return (key["end_time"] - key["start_time"]).total_seconds() / 60

    def make_fetch(self, key):
        """Fetch the inputs needed to compute the LFP of one session.

        NOTE: DataJoint's tripartite populate calls this method a second time
        inside the insert transaction and compares both results, so it must
        return only deterministic, database-derived values (no timestamps).

        Args:
            key (dict): EphysSession primary key.

        Returns:
            Tuple ``(file_paths, lfp_indices, probe_info, electrode_df)``. When no raw
            file falls in the session window, ``file_paths`` is empty and the rest None.

        Raises:
            ValueError: If the session window is longer than MAX_DURATION_MINUTES.
        """
        # Check if the trace duration is within the expected range
        duration = self._session_duration_minutes(key)  # minutes
        if duration > self.MAX_DURATION_MINUTES:
            raise ValueError(
                f"LFP session duration {duration} min > max session duration {self.MAX_DURATION_MINUTES} min"
            )

        # Fetch the raw data files for the given ephys session
        query = (
            EphysRawFile
            & f"file_time BETWEEN '{key['start_time']}' AND '{key['end_time']}'"
        )
        if not query:
            logger.info(f"No raw data file found. Skipping LFP for <{key}>")
            # Keep the tuple shape so make_compute can unpack it; empty paths mean "skip".
            return [], None, None, None

        logger.info(f"Populating ephys.LFP for <{key}>")

        # Fetch the probe information for the given ephys session
        probe_info = (EphysSessionProbe & key).fetch1()
        probe_type = (probe.Probe & {"probe": probe_info["probe"]}).fetch1("probe_type")

        # Fetch the electrode configuration for the given probe
        electrode_query = probe.ElectrodeConfig.Electrode & (
            probe.ElectrodeConfig & {"probe_type": probe_type}
        )

        # Filter for used electrodes; None (or empty) means all electrodes were used.
        # Build the SQL list explicitly: `tuple([5])` renders as "(5,)" (invalid SQL)
        # and numpy scalars may render as "np.int64(5)".
        used_electrodes = probe_info["used_electrodes"]
        if used_electrodes is not None and len(used_electrodes) > 0:
            electrode_ids = ", ".join(str(int(e)) for e in used_electrodes)
            electrode_query &= f"electrode IN ({electrode_ids})"

        lfp_indices = np.array(electrode_query.fetch("channel_idx"), dtype=int)
        electrode_df = electrode_query.fetch(format="frame").reset_index()
        file_paths = query.fetch("file_path", order_by="file_time")
        return file_paths, lfp_indices, probe_info, electrode_df

    def make_compute(self, key, file_paths, lfp_indices, probe_info, electrode_df):
        """Compute notch-filtered, down-sampled LFP signals for each selected electrode.

        Logic:
            - Load every raw file of the session (in file_time order) and concatenate the
              selected amplifier rows.
            - Check that the loaded trace duration matches the session window.
            - Remove powerline noise with a zero-phase notch filter.
            - Downsample with `decimate` (anti-aliasing FIR filter).

        Args:
            key (dict): EphysSession primary key.
            file_paths: Raw file paths relative to a root data directory, ordered by
                file_time. Empty when make_fetch found nothing to process.
            lfp_indices (np.ndarray): `channel_idx` of the selected electrodes, used as
                positions among the amplifier channels of the probe's port.
            probe_info (dict): EphysSessionProbe entry of the session.
            electrode_df (pd.DataFrame): Selected probe.ElectrodeConfig.Electrode entries.

        Returns:
            Tuple ``(all_lfps, channels, electrode_df, channel_to_electrode_map, execution_time)``
            consumed by `make_insert`.

        Raises:
            OSError: If a raw file cannot be loaded.
            ValueError: If the probe's port has no amplifier channels, the sampling rate is
                not (within 1%) an integer multiple of TARGET_SAMPLING_RATE, or the loaded
                trace duration differs from the session window by more than 0.5 min.
        """
        # Timestamp taken here (not in make_fetch) so make_fetch stays deterministic.
        execution_time = datetime.now(timezone.utc)
        if len(file_paths) == 0:
            # make_fetch found no raw files; electrode_df=None tells make_insert to skip.
            return [], [], None, {}, execution_time

        duration = self._session_duration_minutes(key)  # expected duration in min
        row_indices = None  # rows of amplifier_data to keep; set from the first file
        lfp_concat = []

        # Iterate over the raw data files for the given ephys session to load the data
        for file_relpath in file_paths:
            file = find_full_path(get_ephys_root_data_dir(), file_relpath)
            try:
                data = intanrhdreader.load_file(file)
            except OSError as err:
                raise OSError(f"OS error occurred when loading file {file.name}") from err

            if row_indices is None:
                # The first file defines sampling rate, notch frequency and channel layout.
                header = data.pop("header")
                raw_sampling_rate = header["sample_rate"]
                powerline_noise_freq = (
                    header["notch_filter_frequency"] or self.POWERLINE_NOISE_FREQ
                )  # in Hz

                # Calculate downsampling factor
                true_ratio = raw_sampling_rate / self.TARGET_SAMPLING_RATE
                downsample_factor = int(np.round(true_ratio))
                # Check if the ratio is within 1% of an integer (1% tolerance)
                if not np.isclose(true_ratio, downsample_factor, rtol=0.01, atol=1e-8):
                    raise ValueError(
                        f"Downsampling factor {true_ratio} is too far from an integer. Check LFP sampling rates."
                    )

                amplifier_channels = data["amplifier_channels"]
                # Row indices (into amplifier_data) of the channels on the probe's port
                port_indices = np.array(
                    [
                        ind
                        for ind, ch in enumerate(amplifier_channels)
                        if ch["port_prefix"] == probe_info["port_id"]
                    ],
                    dtype=int,
                )
                if port_indices.size == 0:
                    raise ValueError(
                        f"No amplifier channels found on port {probe_info['port_id']} in {file.name}"
                    )
                # Map port-relative channel_idx values to amplifier_data rows. Kept in a
                # separate variable so `lfp_indices` is never remapped twice.
                row_indices = np.sort(port_indices[lfp_indices])

                # Native channel names of the kept rows. Built over *all* amplifier
                # channels so positions line up with the global `row_indices`; the names
                # are expected to match the "<port>-<NNN>" keys built below.
                channels = np.array(
                    [ch["native_channel_name"] for ch in amplifier_channels]
                )[row_indices]

                # Get channel to electrode mapping, keyed like the native channel names
                channel_to_electrode_map = {
                    f'{probe_info["port_id"]}-{int(channel):03d}': electrode
                    for channel, electrode in zip(
                        electrode_df["channel_idx"], electrode_df["electrode"]
                    )
                }

            lfp_concat.append(data.pop("amplifier_data")[row_indices])

        full_lfp = np.hstack(lfp_concat)

        # Check if the trace duration is within the expected range
        trace_duration = full_lfp.shape[1] / raw_sampling_rate / 60  # in min
        if abs(trace_duration - duration) > 0.5:
            raise ValueError(
                f"Trace duration mismatch: expected {duration}, got {trace_duration} min"
            )

        # Design notch filter
        notch_b, notch_a = signal.iirnotch(
            w0=powerline_noise_freq, Q=30, fs=raw_sampling_rate
        )

        all_lfps = []
        for raw_lfp in full_lfp:
            # Apply notch filter (zero-phase, forward-backward)
            lfp = signal.filtfilt(notch_b, notch_a, raw_lfp)
            # Downsample with an anti-aliasing FIR filter. A factor of 1 needs no
            # decimation (and `decimate` cannot design its filter for q=1).
            if downsample_factor > 1:
                lfp = signal.decimate(
                    lfp, downsample_factor, ftype="fir", zero_phase=True
                )
            all_lfps.append(lfp)

        return all_lfps, channels, electrode_df, channel_to_electrode_map, execution_time

    def make_insert(
        self, key, all_lfps, channels, electrode_df, channel_to_electrode_map, execution_time
    ):
        """Insert the LFP master entry and one Trace per processed channel.

        Args:
            key (dict): EphysSession primary key.
            all_lfps (list): Processed LFP trace per channel, aligned with `channels`.
            channels: Native channel names of the processed traces.
            electrode_df (pd.DataFrame | None): Selected electrodes; None means make_fetch
                found no raw files, in which case nothing is inserted.
            channel_to_electrode_map (dict): Native channel name -> electrode ID.
            execution_time (datetime): UTC time at which computation started.
        """
        if electrode_df is None:
            # No raw files were found for this session; nothing to insert.
            return

        self.insert1(
            {
                **key,
                "lfp_sampling_rate": self.TARGET_SAMPLING_RATE,
                "execution_duration": (
                    datetime.now(timezone.utc) - execution_time
                ).total_seconds()
                / 3600,
            }
        )
        if len(channels) == 0:
            return

        # NOTE(review): assumes all selected electrodes share a single electrode config
        # (taken from the first row) — the electrode query may span several configs of
        # the probe type; confirm.
        electrode_config_hash = electrode_df["electrode_config_hash"][0]
        probe_type = electrode_df["probe_type"][0]
        self.Trace.insert(
            [
                {
                    **key,
                    "electrode_config_hash": electrode_config_hash,
                    "probe_type": probe_type,
                    "electrode": channel_to_electrode_map[ch_idx],
                    "lfp": lfp,
                }
                for ch_idx, lfp in zip(channels, all_lfps)
            ]
        )
# ------------ Clustering --------------
@schema
class ClusteringMethod(dj.Lookup):
    """Kilosort clustering method.

    Attributes:
        clustering_method ( varchar(20) ): Kilosort clustering method (primary key).
        clustering_method_desc ( varchar(1000) ): Additional description of the clustering method.
    """

    definition = """
    # Method for clustering
    clustering_method: varchar(20)
    ---
    clustering_method_desc: varchar(1000)
    """

    # Rows as (clustering_method, clustering_method_desc) tuples.
    contents = [
        ("kilosort2", "kilosort2 clustering method"),
        ("kilosort2.5", "kilosort2.5 clustering method"),
        ("kilosort3", "kilosort3 clustering method"),
    ]
@schema
class ClusteringParamSet(dj.Lookup):
    """Parameter sets for the spike-sorting clustering procedure.

    Attributes:
        paramset_idx (smallint): Unique ID for the clustering parameter set (primary key).
        ClusteringMethod (foreign key): ClusteringMethod primary key.
        paramset_desc ( varchar(128) ): Description of the clustering parameter set.
        param_set_hash (uuid): Unique hash of the parameters plus the clustering method.
        params (longblob): Dictionary of clustering parameters.
    """

    definition = """
    # Parameter set to be used in a clustering procedure
    paramset_idx: smallint
    ---
    -> ClusteringMethod
    paramset_desc: varchar(128)
    param_set_hash: uuid
    unique index (param_set_hash)
    params: longblob # dictionary of all applicable parameters
    """

    @classmethod
    def insert_new_params(
        cls,
        clustering_method: str,
        paramset_desc: str,
        params: dict,
        paramset_idx: int = None,
    ):
        """Register a clustering parameter set, de-duplicated by its content hash.

        Args:
            clustering_method (str): Name of the clustering method.
            paramset_desc (str): Description of the parameter set.
            params (dict): Clustering parameters.
            paramset_idx (int, optional): Unique parameter set ID. When None, one past
                the current maximum is used (1 for an empty table).

        Raises:
            dj.DataJointError: If identical parameters already exist under a different
                paramset_idx, or if paramset_idx is already used by other parameters.
        """
        if paramset_idx is None:
            current_max = dj.U().aggr(cls, n="max(paramset_idx)").fetch1("n")
            paramset_idx = (current_max or 0) + 1

        param_set_hash = dict_to_uuid({**params, "clustering_method": clustering_method})
        existing = cls & {"param_set_hash": param_set_hash}

        if existing:
            existing_paramset_idx = existing.fetch1("paramset_idx")
            if existing_paramset_idx != paramset_idx:
                # Same parameters, different index: refuse to store a duplicate.
                raise dj.DataJointError(
                    f"The specified param-set already exists"
                    f" - with paramset_idx: {existing_paramset_idx}"
                )
            # Identical parameters already stored under this index: nothing to do.
            return

        if {"paramset_idx": paramset_idx} in cls.proj():
            raise dj.DataJointError(
                f"The specified paramset_idx {paramset_idx} already exists,"
                f" please pick a different one."
            )

        cls.insert1(
            {
                "clustering_method": clustering_method,
                "paramset_idx": paramset_idx,
                "paramset_desc": paramset_desc,
                "params": params,
                "param_set_hash": param_set_hash,
            }
        )
@schema
class ClusterQualityLabel(dj.Lookup):
    """Quality label for each spike sorted cluster.

    Attributes:
        cluster_quality_label ( varchar(100) ): Cluster quality type (primary key).
        cluster_quality_description ( varchar(4000) ): Description of the cluster quality type.
    """

    definition = """
    # Quality
    cluster_quality_label: varchar(100) # cluster quality type - e.g. 'good', 'MUA', 'noise', etc.
    ---
    cluster_quality_description: varchar(4000)
    """
    # Rows as (cluster_quality_label, cluster_quality_description) tuples; "n.a." is
    # the label CuratedClustering assigns when the sorting has no "KSLabel" property.
    contents = [
        ("good", "single unit"),
        ("ok", "probably a single unit, but could be contaminated"),
        ("mua", "multi-unit activity"),
        ("noise", "bad unit"),
        ("n.a.", "not available"),
    ]
@schema
class ClusteringTask(dj.Manual):
    """A clustering task to spike sort electrophysiology datasets.

    Attributes:
        EphysSession (foreign key): EphysSession primary key.
        ClusteringParamSet (foreign key): ClusteringParamSet primary key.
        clustering_output_dir ( varchar(255) ): Clustering output directory; `auto_generate_entries`
            stores it relative to the processed root data directory.
    """

    definition = """
    # Manual table for defining a clustering task ready to be run
    -> EphysSession
    -> ClusteringParamSet
    ---
    clustering_output_dir='': varchar(255) # clustering output directory relative to the clustering root data directory
    """

    # NOTE(review): key_source is normally consulted only by auto-populated tables;
    # on this Manual table it appears unused — confirm before relying on it.
    @property
    def key_source(self):
        return EphysSession & 'session_type IN ("spike_sorting", "both")'

    @classmethod
    def infer_output_dir(cls, key, relative=False, mkdir=False) -> pathlib.Path:
        """Infer output directory if it is not provided.

        Args:
            key (dict): ClusteringTask primary key (must include start_time, end_time,
                organoid_id and paramset_idx).
            relative (bool): If True, return the path relative to the processed root directory.
            mkdir (bool): If True, create the directory (with parents) when missing.

        Returns:
            Expected clustering_output_dir following the convention
                processed_dir / <experiment dir relative to its root dir>
                    / <start>_<end> / <organoid_id> / <clustering_method>_<paramset_idx>
            where <start> and <end> use "%Y%m%d%H%M" and "." in the method name becomes "-",
            e.g.: <exp>/202301011200_202301011230/O09/kilosort2-5_1
        """
        processed_dir = pathlib.Path(get_processed_root_data_dir())
        exp_dir = find_full_path(get_ephys_root_data_dir(), get_organoid_directory(key))
        session_time = "_".join(
            [
                key["start_time"].strftime("%Y%m%d%H%M"),
                key["end_time"].strftime("%Y%m%d%H%M"),
            ]
        )
        session_dir = exp_dir / session_time / key["organoid_id"]
        root_dir = find_root_directory(get_ephys_root_data_dir(), exp_dir)

        method = (
            (ClusteringParamSet * ClusteringMethod & key)
            .fetch1("clustering_method")
            .replace(".", "-")
        )

        output_dir = (
            processed_dir
            / session_dir.relative_to(root_dir)
            / f'{method}_{key["paramset_idx"]}'
        )

        if mkdir:
            already_existed = output_dir.exists()
            output_dir.mkdir(parents=True, exist_ok=True)
            # Only report creation when the directory was actually created.
            if not already_existed:
                logger.info(f"{output_dir} created!")

        return output_dir.relative_to(processed_dir) if relative else output_dir

    @classmethod
    def auto_generate_entries(cls, ephys_recording_key: dict, paramset_idx: int = 0):
        """Autogenerate entries based on a particular ephys recording.

        Creates the inferred output directory and inserts a task whose
        clustering_output_dir is relative to the processed root data directory.

        Args:
            ephys_recording_key (dict): EphysSession primary key.
            paramset_idx (int, optional): Parameter index to use for clustering task. Defaults to 0.
        """
        key = {**ephys_recording_key, "paramset_idx": paramset_idx}
        output_dir = cls.infer_output_dir(key, relative=True, mkdir=True)
        cls.insert1({**key, "clustering_output_dir": output_dir.as_posix()})
@schema
class Clustering(dj.Imported):
    """A processing table to handle each clustering task.

    Attributes:
        ClusteringTask (foreign key): ClusteringTask primary key.
        clustering_time (datetime): Time when clustering results are generated.
        package_version ( varchar(16) ): Package version used for a clustering analysis.
    """

    definition = """
    # Clustering Procedure
    -> ClusteringTask
    ---
    clustering_time: datetime # time of generation of this set of clustering results
    package_version='': varchar(16)
    """

    def make(self, key):
        """This will be implemented via `ephys_sorter` schema with `si_spike_sorting` tables."""
        # NOTE(review): a no-op here, so populating this table inserts nothing; entries
        # are presumably inserted by the external sorter pipeline — confirm.
        pass
@schema
class CuratedClustering(dj.Imported):
    """Clustering results of the spike sorting step (no separate curation in this module).

    Attributes:
        Clustering (foreign key): Clustering primary key.
    """

    definition = """
    # Clustering results of the spike sorting step.
    -> Clustering
    """

    class Unit(dj.Part):
        """Single unit properties after clustering and curation.

        Attributes:
            CuratedClustering (foreign key): CuratedClustering primary key.
            unit (int): Unique integer identifying a single unit.
            probe.ElectrodeConfig.Electrode (foreign key): Electrode with the highest waveform amplitude for this unit.
            ClusterQualityLabel (foreign key): ClusterQualityLabel primary key.
            spike_count (int): Number of spikes in this recording for this unit.
            spike_times (longblob): Spike times (s) of this unit, relative to the start of the recording.
            spike_sites (longblob): Array of electrode associated with each spike.
            spike_depths (longblob): Array of depths (um) associated with each spike, relative to the (0, 0) of the probe.
        """

        definition = """
        # Properties of a given unit from a round of clustering (and curation)
        -> master
        unit: int
        ---
        -> probe.ElectrodeConfig.Electrode  # electrode with highest waveform amplitude for this unit
        -> ClusterQualityLabel
        spike_count: int         # how many spikes in this recording for this unit
        spike_times: longblob    # (s) spike times of this unit, relative to the start of the EphysRecording
        spike_sites : longblob   # array of electrode associated with each spike
        spike_depths=null : longblob  # (um) array of depths associated with each spike, relative to the (0, 0) of the probe
        """

    def make(self, key):
        """Automated population of Unit information.

        Reads a spikeinterface sorting analyzer from
        <clustering_output_dir>/<clustering_method with "." -> "_">/sorting_analyzer
        and inserts one Unit per sorted unit.

        Raises:
            NotImplementedError: If no spikeinterface sorting analyzer folder exists
                (kilosort outputs are not supported).
        """
        clustering_method, output_dir = (
            ClusteringTask * ClusteringParamSet & key
        ).fetch1("clustering_method", "clustering_output_dir")
        output_dir = find_full_path(get_ephys_root_data_dir(), output_dir)

        # Get electrode & channel info
        probe_info = (probe.Probe * EphysSessionProbe & key).fetch1()
        electrode_config_key = (probe.ElectrodeConfig & probe_info).fetch1("KEY")
        electrode_query = (
            probe.ProbeType.Electrode * probe.ElectrodeConfig.Electrode
            & electrode_config_key
        )
        # channel_idx -> remaining electrode attributes
        channel2electrode_map: dict[int, dict] = {
            chn.pop("channel_idx"): chn for chn in electrode_query.fetch(as_dict=True)
        }  # e.g., {0: {'organoid_id': 'O09', ...}, ...}

        # Get sorter method and create output directory.
        sorter_name = clustering_method.replace(".", "_")
        si_sorting_analyzer_dir = output_dir / sorter_name / "sorting_analyzer"

        if si_sorting_analyzer_dir.exists():  # Read from spikeinterface outputs
            import spikeinterface as si

            sorting_analyzer = si.load_sorting_analyzer(folder=si_sorting_analyzer_dir)
            si_sorting = sorting_analyzer.sorting

            # Find representative channel for each unit
            unit_peak_channel: dict[int, np.ndarray] = (
                si.ChannelSparsity.from_best_channels(
                    sorting_analyzer, 1, peak_sign="both"
                ).unit_id_to_channel_indices
            )
            # Keep only the single best channel index per unit.
            unit_peak_channel: dict[int, int] = {
                u: chn[0] for u, chn in unit_peak_channel.items()
            }

            spike_count_dict: dict[int, int] = si_sorting.count_num_spikes_per_unit()
            # {unit: spike_count}

            # update channel2electrode_map to match with probe's channel index
            # (re-keyed by position in the analyzer probe's contact_ids).
            channel2electrode_map = {
                idx: channel2electrode_map[int(chn_idx)]
                for idx, chn_idx in enumerate(sorting_analyzer.get_probe().contact_ids)
            }

            # Get unit id to quality label mapping; "n.a." when no "KSLabel" property exists.
            cluster_quality_label_map = {
                int(unit_id): (
                    si_sorting.get_unit_property(unit_id, "KSLabel")
                    if "KSLabel" in si_sorting.get_property_keys()
                    else "n.a."
                )
                for unit_id in si_sorting.unit_ids
            }

            spike_locations = sorting_analyzer.get_extension("spike_locations")
            extremum_channel_inds = si.template_tools.get_template_extremum_channel(
                sorting_analyzer, outputs="index"
            )
            # One row per spike; the columns used below are unit_index and channel_index.
            spikes_df = pd.DataFrame(
                sorting_analyzer.sorting.to_spike_vector(
                    extremum_channel_inds=extremum_channel_inds
                )
            )

            units = []
            for unit_idx, unit_id in enumerate(si_sorting.unit_ids):
                unit_id = int(unit_id)
                unit_spikes_df = spikes_df[spikes_df.unit_index == unit_idx]
                spike_sites = np.array(
                    [
                        channel2electrode_map[chn_idx]["electrode"]
                        for chn_idx in unit_spikes_df.channel_index
                    ]
                )
                unit_spikes_loc = spike_locations.get_data()[unit_spikes_df.index]
                # NOTE(review): assumes each location record has exactly two fields (x, y)
                # and the unit has at least one spike; otherwise this unpacking fails — confirm.
                _, spike_depths = zip(*unit_spikes_loc)  # x-coordinates, y-coordinates
                spike_times = si_sorting.get_unit_spike_train(
                    unit_id, return_times=True
                )

                assert len(spike_times) == len(spike_sites) == len(spike_depths)

                units.append(
                    {
                        **key,
                        **channel2electrode_map[unit_peak_channel[unit_id]],
                        "unit": unit_id,
                        "cluster_quality_label": cluster_quality_label_map[unit_id],
                        "spike_times": spike_times,
                        "spike_count": spike_count_dict[unit_id],
                        "spike_sites": spike_sites,
                        "spike_depths": spike_depths,
                    }
                )
        else:  # read from kilosort outputs
            raise NotImplementedError

        self.insert1(key)
        self.Unit.insert(units, ignore_extra_fields=True)
@schema
class WaveformSet(dj.Imported):
    """A set of spike waveforms for units out of a given CuratedClustering.

    Attributes:
        CuratedClustering (foreign key): CuratedClustering primary key.
    """

    definition = """
    # A set of spike waveforms for units out of a given CuratedClustering
    -> CuratedClustering
    """

    class PeakWaveform(dj.Part):
        """Mean waveform across spikes for a given unit.

        Attributes:
            WaveformSet (foreign key): WaveformSet primary key.
            CuratedClustering.Unit (foreign key): CuratedClustering.Unit primary key.
            peak_electrode_waveform (longblob): Mean waveform for a given unit at its representative electrode.
        """

        definition = """
        # Mean waveform across spikes for a given unit at its representative electrode
        -> master
        -> CuratedClustering.Unit
        ---
        peak_electrode_waveform: longblob  # (uV) mean waveform for a given unit at its representative electrode
        """

    class Waveform(dj.Part):
        """Spike waveforms for a given unit.

        Attributes:
            WaveformSet (foreign key): WaveformSet primary key.
            CuratedClustering.Unit (foreign key): CuratedClustering.Unit primary key.
            probe.ElectrodeConfig.Electrode (foreign key): probe.ElectrodeConfig.Electrode primary key.
            waveform_mean (longblob): mean waveform across spikes of the unit in microvolts.
            waveforms (longblob): waveforms of a sampling of spikes at the given electrode and unit.
        """

        definition = """
        # Spike waveforms and their mean across spikes for the given unit
        -> master
        -> CuratedClustering.Unit
        -> probe.ElectrodeConfig.Electrode
        ---
        waveform_mean: longblob   # (uV) mean waveform across spikes of the given unit
        waveforms=null: longblob  # (uV) (spike x sample) waveforms of a sampling of spikes at the given electrode for the given unit
        """

    def make(self, key):
        """Populates waveform tables.

        Reads average templates from the spikeinterface sorting analyzer and inserts,
        per unit, the waveform at its peak channel (PeakWaveform) and the mean waveform
        on every channel (Waveform; `waveforms` is left null).

        Raises:
            NotImplementedError: If no spikeinterface sorting analyzer folder exists.
        """
        clustering_method, output_dir = (
            ClusteringTask * ClusteringParamSet & key
        ).fetch1("clustering_method", "clustering_output_dir")
        output_dir = find_full_path(get_ephys_root_data_dir(), output_dir)
        sorter_name = clustering_method.replace(".", "_")

        # Get electrode & channel info
        probe_info = (probe.Probe * EphysSessionProbe & key).fetch1()
        electrode_config_key = (probe.ElectrodeConfig & probe_info).fetch1("KEY")
        electrode_query = (
            probe.ProbeType.Electrode * probe.ElectrodeConfig.Electrode
            & electrode_config_key
        )
        # channel_idx -> remaining electrode attributes
        channel2electrode_map: dict[int, dict] = {
            chn.pop("channel_idx"): chn for chn in electrode_query.fetch(as_dict=True)
        }  # e.g., {0: {'organoid_id': 'O09', ...}, ...}

        si_sorting_analyzer_dir = output_dir / sorter_name / "sorting_analyzer"
        if si_sorting_analyzer_dir.exists():  # read from spikeinterface outputs
            import spikeinterface as si

            sorting_analyzer = si.load_sorting_analyzer(folder=si_sorting_analyzer_dir)

            # Find representative channel for each unit
            unit_peak_channel: dict[int, np.ndarray] = (
                si.ChannelSparsity.from_best_channels(
                    sorting_analyzer, 1, peak_sign="both"
                ).unit_id_to_channel_indices
            )  # {unit: peak_channel_index}
            unit_peak_channel = {u: chn[0] for u, chn in unit_peak_channel.items()}

            # update channel2electrode_map to match with probe's channel index
            # (re-keyed by position in the analyzer probe's contact_ids).
            channel2electrode_map = {
                idx: channel2electrode_map[int(chn_idx)]
                for idx, chn_idx in enumerate(sorting_analyzer.get_probe().contact_ids)
            }

            templates = sorting_analyzer.get_extension("templates")

            def yield_unit_waveforms():
                # Yields (peak-waveform entry, list of per-electrode entries) per unit.
                for unit in (CuratedClustering.Unit & key).fetch(
                    "KEY", order_by="unit"
                ):
                    # Get mean waveform for this unit from all channels - (sample x channel)
                    unit_waveforms = templates.get_unit_template(
                        unit_id=unit["unit"], operator="average"
                    )
                    unit_peak_waveform = {
                        **unit,
                        "peak_electrode_waveform": unit_waveforms[
                            :, unit_peak_channel[unit["unit"]]
                        ],
                    }

                    unit_electrode_waveforms = [
                        {
                            **unit,
                            **channel2electrode_map[chn_idx],
                            "waveform_mean": unit_waveforms[:, chn_idx],
                        }
                        for chn_idx in channel2electrode_map
                    ]

                    yield unit_peak_waveform, unit_electrode_waveforms

        else:  # read from kilosort outputs (ecephys pipeline)
            raise NotImplementedError

        # insert waveform on a per-unit basis to mitigate potential memory issue
        self.insert1(key)
        for unit_peak_waveform, unit_electrode_waveforms in yield_unit_waveforms():
            if unit_peak_waveform:
                self.PeakWaveform.insert1(unit_peak_waveform, ignore_extra_fields=True)
            if unit_electrode_waveforms:
                self.Waveform.insert(unit_electrode_waveforms, ignore_extra_fields=True)
@schema
class QualityMetrics(dj.Imported):
"""Clustering and waveform quality metrics.
Attributes:
CuratedClustering (foreign key): CuratedClustering primary key.
"""
definition = """
# Clusters and waveforms metrics
-> CuratedClustering
"""
class Cluster(dj.Part):
"""Cluster metrics for a unit.
Attributes:
QualityMetrics (foreign key): QualityMetrics primary key.
CuratedClustering.Unit (foreign key): CuratedClustering.Unit primary key.
firing_rate (float): Firing rate of the unit.
snr (float): Signal-to-noise ratio for a unit.
presence_ratio (float): Fraction of time where spikes are present.
isi_violation (float): rate of ISI violation as a fraction of overall rate.
number_violation (int): Total ISI violations.
amplitude_cutoff (float): Estimate of miss rate based on amplitude histogram.
isolation_distance (float): Distance to nearest cluster.
l_ratio (float): Amount of empty space between a cluster and other spikes in dataset.
d_prime (float): Classification accuracy based on LDA.
nn_hit_rate (float): Fraction of neighbors for target cluster that are also in target cluster.
nn_miss_rate (float): Fraction of neighbors outside target cluster that are in the target cluster.
silhouette_core (float): Maximum change in spike depth throughout recording.
cumulative_drift (float): Cumulative change in spike depth throughout recording.
contamination_rate (float): Frequency of spikes in the refractory period.
"""
definition = """
# Cluster metrics for a particular unit
-> master
-> CuratedClustering.Unit
---
firing_rate=null: float # (Hz) firing rate for a unit
snr=null: float # signal-to-noise ratio for a unit
presence_ratio=null: float # fraction of time in which spikes are present
isi_violation=null: float # rate of ISI violation as a fraction of overall rate
number_violation=null: int # total number of ISI violations
amplitude_cutoff=null: float # estimate of miss rate based on amplitude histogram
isolation_distance=null: float # distance to nearest cluster in Mahalanobis space
l_ratio=null: float #
d_prime=null: float # Classification accuracy based on LDA
nn_hit_rate=null: float # Fraction of neighbors for target cluster that are also in target cluster
nn_miss_rate=null: float # Fraction of neighbors outside target cluster that are in target cluster
silhouette_score=null: float # Standard metric for cluster overlap
max_drift=null: float # Maximum change in spike depth throughout recording
cumulative_drift=null: float # Cumulative change in spike depth throughout recording
contamination_rate=null: float #
"""
class Waveform(dj.Part):
"""Waveform metrics for a particular unit.
Attributes:
QualityMetrics (foreign key): QualityMetrics primary key.