Skip to content

Commit 1398794

Browse files
committed
fix: fixed plot_conditions
1 parent 1182473 commit 1398794

1 file changed

Lines changed: 17 additions & 13 deletions

File tree

eegnb/analysis/utils.py

Lines changed: 17 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -45,13 +45,17 @@ def _bootstrap(data, n_boot: int, ci: float):
4545
return (s1, s2)
4646

4747

48-
def _tsplotboot(ax, data, n_boot: int, ci: float, **kw):
48+
def _tsplotboot(ax, data, time: list, n_boot: int, ci: float, color):
4949
"""From: https://stackoverflow.com/a/47582329/965332"""
50-
x = np.arange(data.shape[1])
50+
# Time forms the xaxis of the plot
51+
if time is None:
52+
x = np.arange(data.shape[1])
53+
else:
54+
x = np.asarray(time)
5155
est = np.mean(data, axis=0)
5256
cis = _bootstrap(data, n_boot, ci)
53-
ax.fill_between(x, cis[0], cis[1], alpha=0.2, **kw)
54-
ax.plot(x, est, **kw)
57+
ax.fill_between(x, cis[0], cis[1], alpha=0.2, color=color)
58+
ax.plot(x, est, color=color)
5559
ax.margins(x=0)
5660

5761

@@ -175,7 +179,9 @@ def load_data(
175179
site = "*"
176180

177181
data_path = (
178-
_get_recording_dir(device_name, experiment, subject_str, session_str, site, data_dir)
182+
_get_recording_dir(
183+
device_name, experiment, subject_str, session_str, site, data_dir
184+
)
179185
/ "*.csv"
180186
)
181187
fnames = glob(str(data_path))
@@ -216,7 +222,8 @@ def plot_conditions(
216222
ylim=(-6, 6),
217223
diff_waveform=(1, 2),
218224
channel_count=4,
219-
channel_order=None):
225+
channel_order=None,
226+
):
220227
"""Plot ERP conditions.
221228
Args:
222229
epochs (mne.epochs): EEG epochs
@@ -242,10 +249,9 @@ def plot_conditions(
242249
"""
243250

244251
if channel_order:
245-
channel_order = np.array(channel_order)
252+
channel_order = np.array(channel_order)
246253
else:
247-
channel_order = np.array(range(channel_count))
248-
254+
channel_order = np.array(range(channel_count))
249255

250256
if isinstance(conditions, dict):
251257
conditions = OrderedDict(conditions)
@@ -255,7 +261,7 @@ def plot_conditions(
255261

256262
X = epochs.get_data() * 1e6
257263

258-
X = X[:,channel_order]
264+
X = X[:, channel_order]
259265

260266
times = epochs.times
261267
y = pd.Series(epochs.events[:, -1])
@@ -270,16 +276,14 @@ def plot_conditions(
270276
plot_axes.append(axes[axis_x, axis_y])
271277
axes = plot_axes
272278

273-
print("\n\n\n\n\n\n")
274-
275279
for ch in range(channel_count):
276280
for cond, color in zip(conditions.values(), palette):
277281
y_cond = y.isin(cond)
278-
# make X[y_cond, ch] one-dimensional
279282
X_cond = X[y_cond, ch]
280283
_tsplotboot(
281284
ax=axes[ch],
282285
data=X_cond,
286+
time=times,
283287
color=color,
284288
n_boot=n_boot,
285289
ci=ci,

0 commit comments

Comments
 (0)