Skip to content

Commit 8ecb31c

Browse files
Merge pull request #67 from AllenNeuralDynamics/session_scroll_improve
Session scroll improve
2 parents 37a752b + 9e9f8b7 commit 8ecb31c

1 file changed

Lines changed: 42 additions & 37 deletions

File tree

src/aind_dynamic_foraging_basic_analysis/plot/plot_session_scroller.py

Lines changed: 42 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,13 @@ def plot_session_scroller( # noqa: C901 pragma: no cover
3838
nwb, an nwb like object that contains attributes: df_events, session_id
3939
and optionally contains attributes fip_df, df_licks
4040
41-
ax is a pyplot figure axis. If None, a new figure is created
41+
ax is a list of pyplot figure axis. The list must be the correct length of
42+
1 + len(metrics) + len(fip). If provided, fig must also be provided.
43+
If None, a new figure is created.
44+
45+
46+
fig is a pyplot figure container. If provided, ax must also be provided.
47+
If None, a new figure is created.
4248
4349
metrics, list of metrics to plot. Each metric must be a column of
4450
nwb.df_trials
@@ -112,13 +118,12 @@ def plot_session_scroller( # noqa: C901 pragma: no cover
112118
fig.subplots_adjust(hspace=0)
113119
if num_plots == 1:
114120
ax = [ax]
115-
ax = np.flip(ax)
116121

117122
xmin = df_events.iloc[0]["timestamps"]
118123
x_first = xmin
119124
x_last = df_events.iloc[-1]["timestamps"]
120125
xmax = xmin + 20
121-
ax[0].set_xlim(xmin, xmax)
126+
ax[-1].set_xlim(xmin, xmax)
122127

123128
params = {
124129
"left_lick_bottom": 0,
@@ -159,7 +164,7 @@ def plot_session_scroller( # noqa: C901 pragma: no cover
159164
if ("bouts" not in plot_list) or (df_licks is None):
160165
left_licks = df_events.query('event == "left_lick_time"')
161166
left_times = left_licks.timestamps.values
162-
ax[0].vlines(
167+
ax[-1].vlines(
163168
left_times,
164169
params["left_lick_bottom"],
165170
params["left_lick_top"],
@@ -170,7 +175,7 @@ def plot_session_scroller( # noqa: C901 pragma: no cover
170175

171176
right_licks = df_events.query('event == "right_lick_time"')
172177
right_times = right_licks.timestamps.values
173-
ax[0].vlines(
178+
ax[-1].vlines(
174179
right_times,
175180
params["right_lick_bottom"],
176181
params["right_lick_top"],
@@ -188,15 +193,15 @@ def plot_session_scroller( # noqa: C901 pragma: no cover
188193
bout_right_licks = df_licks.query(
189194
'(bout_number == @b)&(event=="right_lick_time")'
190195
).timestamps.values
191-
ax[0].vlines(
196+
ax[-1].vlines(
192197
bout_left_licks,
193198
params["left_lick_bottom"],
194199
params["left_lick_top"],
195200
alpha=1,
196201
linewidth=2,
197202
color=cmap(np.mod(b, 20)),
198203
)
199-
ax[0].vlines(
204+
ax[-1].vlines(
200205
bout_right_licks,
201206
params["right_lick_bottom"],
202207
params["right_lick_top"],
@@ -210,7 +215,7 @@ def plot_session_scroller( # noqa: C901 pragma: no cover
210215
left_rewarded_licks = df_licks.query(
211216
'(event == "left_lick_time")&(rewarded)'
212217
).timestamps.values
213-
ax[0].plot(
218+
ax[-1].plot(
214219
left_rewarded_licks,
215220
[params["left_lick_top"]] * len(left_rewarded_licks),
216221
"ro",
@@ -221,7 +226,7 @@ def plot_session_scroller( # noqa: C901 pragma: no cover
221226
right_rewarded_licks = df_licks.query(
222227
'(event == "right_lick_time")&(rewarded)'
223228
).timestamps.values
224-
ax[0].plot(
229+
ax[-1].plot(
225230
right_rewarded_licks, [params["right_lick_bottom"]] * len(right_rewarded_licks), "ro"
226231
)
227232

@@ -230,7 +235,7 @@ def plot_session_scroller( # noqa: C901 pragma: no cover
230235
left_cue_licks = df_licks.query(
231236
'(event == "left_lick_time")&(cue_response)'
232237
).timestamps.values
233-
ax[0].plot(
238+
ax[-1].plot(
234239
left_cue_licks,
235240
[
236241
params["left_lick_bottom"]
@@ -243,7 +248,7 @@ def plot_session_scroller( # noqa: C901 pragma: no cover
243248
right_cue_licks = df_licks.query(
244249
'(event == "right_lick_time")&(cue_response)'
245250
).timestamps.values
246-
ax[0].plot(
251+
ax[-1].plot(
247252
right_cue_licks,
248253
[
249254
params["right_lick_bottom"]
@@ -257,10 +262,10 @@ def plot_session_scroller( # noqa: C901 pragma: no cover
257262
# Plot baiting
258263
bait_right = df_trials.query("bait_right")["goCue_start_time_in_session"].values
259264
bait_left = df_trials.query("bait_left")["goCue_start_time_in_session"].values
260-
ax[0].plot(
265+
ax[-1].plot(
261266
bait_right, [params["right_lick_top"] - 0.05] * len(bait_right), "ms", label="baited"
262267
)
263-
ax[0].plot(bait_left, [params["left_lick_bottom"] + 0.05] * len(bait_left), "ms")
268+
ax[-1].plot(bait_left, [params["left_lick_bottom"] + 0.05] * len(bait_left), "ms")
264269

265270
if "lick artifacts" in plot_list:
266271
artifacts_right = df_licks.query('likely_artifact and (event=="right_lick_time")')[
@@ -269,14 +274,14 @@ def plot_session_scroller( # noqa: C901 pragma: no cover
269274
artifacts_left = df_licks.query('likely_artifact and (event=="left_lick_time")')[
270275
"timestamps"
271276
].values
272-
ax[0].plot(
277+
ax[-1].plot(
273278
artifacts_right,
274279
[params["right_lick_top"]] * len(artifacts_right),
275280
"d",
276281
color="darkorange",
277282
label="lick artifact",
278283
)
279-
ax[0].plot(
284+
ax[-1].plot(
280285
artifacts_left,
281286
[params["left_lick_bottom"]] * len(artifacts_left),
282287
"d",
@@ -285,7 +290,7 @@ def plot_session_scroller( # noqa: C901 pragma: no cover
285290

286291
left_reward_deliverys = df_events.query('event == "left_reward_delivery_time"')
287292
left_times = left_reward_deliverys.timestamps.values
288-
ax[0].vlines(
293+
ax[-1].vlines(
289294
left_times,
290295
params["left_reward_bottom"],
291296
params["left_reward_top"],
@@ -296,7 +301,7 @@ def plot_session_scroller( # noqa: C901 pragma: no cover
296301

297302
right_reward_deliverys = df_events.query('event == "right_reward_delivery_time"')
298303
right_times = right_reward_deliverys.timestamps.values
299-
ax[0].vlines(
304+
ax[-1].vlines(
300305
right_times,
301306
params["right_reward_bottom"],
302307
params["right_reward_top"],
@@ -307,7 +312,7 @@ def plot_session_scroller( # noqa: C901 pragma: no cover
307312

308313
if "manual rewards" in plot_list:
309314
manual_left_times = left_reward_deliverys.query('data == "manual"').timestamps.values
310-
ax[0].vlines(
315+
ax[-1].vlines(
311316
manual_left_times,
312317
params["left_reward_bottom"],
313318
params["left_reward_top"],
@@ -317,7 +322,7 @@ def plot_session_scroller( # noqa: C901 pragma: no cover
317322
label="manual reward",
318323
)
319324
manual_right_times = right_reward_deliverys.query('data == "manual"').timestamps.values
320-
ax[0].vlines(
325+
ax[-1].vlines(
321326
manual_right_times,
322327
params["right_reward_bottom"],
323328
params["right_reward_top"],
@@ -327,7 +332,7 @@ def plot_session_scroller( # noqa: C901 pragma: no cover
327332
)
328333
if "auto rewards" in plot_list:
329334
auto_left_times = left_reward_deliverys.query('data == "auto"').timestamps.values
330-
ax[0].vlines(
335+
ax[-1].vlines(
331336
auto_left_times,
332337
params["left_reward_bottom"],
333338
params["left_reward_top"],
@@ -337,7 +342,7 @@ def plot_session_scroller( # noqa: C901 pragma: no cover
337342
label="auto reward",
338343
)
339344
auto_right_times = right_reward_deliverys.query('data == "auto"').timestamps.values
340-
ax[0].vlines(
345+
ax[-1].vlines(
341346
auto_right_times,
342347
params["right_reward_bottom"],
343348
params["right_reward_top"],
@@ -349,7 +354,7 @@ def plot_session_scroller( # noqa: C901 pragma: no cover
349354
go_cues = df_events.query('event == "goCue_start_time"')
350355
go_cue_times = go_cues.timestamps.values
351356
if "go cue" in plot_list:
352-
ax[0].vlines(
357+
ax[-1].vlines(
353358
go_cue_times,
354359
params["left_lick_bottom"],
355360
params["left_reward_top"],
@@ -358,7 +363,7 @@ def plot_session_scroller( # noqa: C901 pragma: no cover
358363
color="b",
359364
label="go cue",
360365
)
361-
ax[0].vlines(
366+
ax[-1].vlines(
362367
go_cue_times,
363368
params["right_reward_bottom"],
364369
params["right_lick_top"],
@@ -368,41 +373,41 @@ def plot_session_scroller( # noqa: C901 pragma: no cover
368373
)
369374

370375
# plot metrics
371-
ax[0].axhline(params["right_lick_top"], color="k", linewidth=0.5, alpha=0.25)
376+
ax[-1].axhline(params["right_lick_top"], color="k", linewidth=0.5, alpha=0.25)
372377
go_cue_times_doubled = np.repeat(go_cue_times, 2)[1:]
373378

374379
pR = params["probs_bottom"] + df_trials["reward_probabilityR"] / 4
375380
pR = np.repeat(pR, 2)[:-1]
376-
ax[0].fill_between(go_cue_times_doubled, params["probs_bottom"], pR, color="r", alpha=0.4)
381+
ax[-1].fill_between(go_cue_times_doubled, params["probs_bottom"], pR, color="r", alpha=0.4)
377382

378383
pL = params["probs_bottom"] - df_trials["reward_probabilityL"] / 4
379384
pL = np.repeat(pL, 2)[:-1]
380385

381-
ax[0].fill_between(go_cue_times_doubled, pL, params["probs_bottom"], color="b", alpha=0.4)
386+
ax[-1].fill_between(go_cue_times_doubled, pL, params["probs_bottom"], color="b", alpha=0.4)
382387

383388
# plot metrics if they are available
384389
for index, metric in enumerate(metrics):
385-
plot_metric(df_trials, go_cue_times, metric, ax[index + 1])
390+
plot_metric(df_trials, go_cue_times, metric, ax[len(fip) + index])
386391

387392
# plot fip if they are available:
388393
for index, f in enumerate(fip):
389-
plot_fip(fip_df, f, ax[index + 1 + len(metrics)])
394+
plot_fip(fip_df, f, ax[index])
390395

391396
# Clean up plot
392397
if len(plot_list) > 0:
393-
ax[0].legend(framealpha=1, loc="lower left", reverse=True)
394-
ax[0].set_xlabel("time (s)", fontsize=STYLE["axis_fontsize"])
395-
ax[0].set_ylim(0, 1.5)
396-
ax[0].set_yticks(yticks)
397-
ax[0].set_yticklabels(ylabels, fontsize=STYLE["axis_ticks_fontsize"])
398-
for tick, color in zip(ax[0].get_yticklabels(), ycolors):
398+
ax[-1].legend(framealpha=1, loc="lower left", reverse=True)
399+
ax[-1].set_xlabel("time (s)", fontsize=STYLE["axis_fontsize"])
400+
ax[-1].set_ylim(0, 1.5)
401+
ax[-1].set_yticks(yticks)
402+
ax[-1].set_yticklabels(ylabels, fontsize=STYLE["axis_ticks_fontsize"])
403+
for tick, color in zip(ax[-1].get_yticklabels(), ycolors):
399404
tick.set_color(color)
400405

401406
for my_ax in ax:
402407
my_ax.spines["top"].set_visible(False)
403408
my_ax.spines["right"].set_visible(False)
404409

405-
ax[-1].set_title(nwb.session_id)
410+
ax[0].set_title(nwb.session_id)
406411

407412
if num_plots == 1:
408413
plt.tight_layout()
@@ -411,7 +416,7 @@ def on_key_press(event):
411416
"""
412417
Define interaction resonsivity
413418
"""
414-
x = ax[0].get_xlim()
419+
x = ax[-1].get_xlim()
415420
xmin = x[0]
416421
xmax = x[1]
417422
xStep = (xmax - xmin) / 4
@@ -430,7 +435,7 @@ def on_key_press(event):
430435
elif event.key == "h":
431436
xmin = x_first
432437
xmax = x_last
433-
ax[0].set_xlim(xmin, xmax)
438+
ax[-1].set_xlim(xmin, xmax)
434439
plt.draw()
435440

436441
kpid = fig.canvas.mpl_connect("key_press_event", on_key_press) # noqa: F841

0 commit comments

Comments
 (0)