22from copy import deepcopy
33import math
44import logging
5+ import sys
56from collections import OrderedDict
67from glob import glob
7- from typing import Union , List , Dict
8+ from typing import Union , List
89from time import sleep , time
9- from numpy .core .fromnumeric import std
1010
1111import pandas as pd
1212import numpy as np
1616from mne .channels import make_standard_montage
1717from mne .filter import create_filter
1818from matplotlib import pyplot as plt
19+ from scipy import stats
1920from scipy .signal import lfilter , lfilter_zi
2021
2122from eegnb import _get_recording_dir
2223from eegnb .devices .eeg import EEG
2324from eegnb .devices .utils import EEG_INDICES , SAMPLE_FREQS
2425
25-
2626
2727# this should probably not be done here
2828sns .set_context ("talk" )
3232logger = logging .getLogger (__name__ )
3333
3434
35+ def _bootstrap (data , n_boot : int , ci : float ):
36+ """From: https://stackoverflow.com/a/47582329/965332"""
37+ boot_dist = []
38+ for i in range (int (n_boot )):
39+ resampler = np .random .randint (0 , data .shape [0 ], data .shape [0 ])
40+ sample = data .take (resampler , axis = 0 )
41+ boot_dist .append (np .mean (sample , axis = 0 ))
42+ b = np .array (boot_dist )
43+ s1 = np .apply_along_axis (stats .scoreatpercentile , 0 , b , 50 - ci / 2 )
44+ s2 = np .apply_along_axis (stats .scoreatpercentile , 0 , b , 50 + ci / 2 )
45+ return (s1 , s2 )
46+
47+
48+ def _tsplotboot (ax , data , time : list , n_boot : int , ci : float , color ):
49+ """From: https://stackoverflow.com/a/47582329/965332"""
50+ # Time forms the xaxis of the plot
51+ if time is None :
52+ x = np .arange (data .shape [1 ])
53+ else :
54+ x = np .asarray (time )
55+ est = np .mean (data , axis = 0 )
56+ cis = _bootstrap (data , n_boot , ci )
57+ ax .fill_between (x , cis [0 ], cis [1 ], alpha = 0.2 , color = color )
58+ ax .plot (x , est , color = color )
59+ ax .margins (x = 0 )
60+
61+
3562def load_csv_as_raw (
3663 fnames : List [str ],
3764 sfreq : float ,
@@ -152,7 +179,9 @@ def load_data(
152179 site = "*"
153180
154181 data_path = (
155- _get_recording_dir (device_name , experiment , subject_str , session_str , site , data_dir )
182+ _get_recording_dir (
183+ device_name , experiment , subject_str , session_str , site , data_dir
184+ )
156185 / "*.csv"
157186 )
158187 fnames = glob (str (data_path ))
@@ -193,7 +222,8 @@ def plot_conditions(
193222 ylim = (- 6 , 6 ),
194223 diff_waveform = (1 , 2 ),
195224 channel_count = 4 ,
196- channel_order = None ):
225+ channel_order = None ,
226+ ):
197227 """Plot ERP conditions.
198228 Args:
199229 epochs (mne.epochs): EEG epochs
@@ -219,10 +249,9 @@ def plot_conditions(
219249 """
220250
221251 if channel_order :
222- channel_order = np .array (channel_order )
252+ channel_order = np .array (channel_order )
223253 else :
224- channel_order = np .array (range (channel_count ))
225-
254+ channel_order = np .array (range (channel_count ))
226255
227256 if isinstance (conditions , dict ):
228257 conditions = OrderedDict (conditions )
@@ -232,7 +261,7 @@ def plot_conditions(
232261
233262 X = epochs .get_data () * 1e6
234263
235- X = X [:,channel_order ]
264+ X = X [:, channel_order ]
236265
237266 times = epochs .times
238267 y = pd .Series (epochs .events [:, - 1 ])
@@ -249,13 +278,15 @@ def plot_conditions(
249278
250279 for ch in range (channel_count ):
251280 for cond , color in zip (conditions .values (), palette ):
252- sns .tsplot (
253- X [y .isin (cond ), ch ],
281+ y_cond = y .isin (cond )
282+ X_cond = X [y_cond , ch ]
283+ _tsplotboot (
284+ ax = axes [ch ],
285+ data = X_cond ,
254286 time = times ,
255287 color = color ,
256288 n_boot = n_boot ,
257289 ci = ci ,
258- ax = axes [ch ],
259290 )
260291
261292 if diff_waveform :
0 commit comments