@@ -38,7 +38,13 @@ def plot_session_scroller( # noqa: C901 pragma: no cover
3838 nwb, an nwb like object that contains attributes: df_events, session_id
3939 and optionally contains attributes fip_df, df_licks
4040
41- ax is a pyplot figure axis. If None, a new figure is created
41+ ax is a list of pyplot figure axis. The list must be the correct length of
42+ 1 + len(metrics) + len(fip). If provided, fig must also be provided.
43+ If None, a new figure is created.
44+
45+
46+ fig is a pyplot figure container. If provided, ax must also be provided.
47+ If None, a new figure is created.
4248
4349 metrics, list of metrics to plot. Each metric must be a column of
4450 nwb.df_trials
@@ -112,13 +118,12 @@ def plot_session_scroller( # noqa: C901 pragma: no cover
112118 fig .subplots_adjust (hspace = 0 )
113119 if num_plots == 1 :
114120 ax = [ax ]
115- ax = np .flip (ax )
116121
117122 xmin = df_events .iloc [0 ]["timestamps" ]
118123 x_first = xmin
119124 x_last = df_events .iloc [- 1 ]["timestamps" ]
120125 xmax = xmin + 20
121- ax [0 ].set_xlim (xmin , xmax )
126+ ax [- 1 ].set_xlim (xmin , xmax )
122127
123128 params = {
124129 "left_lick_bottom" : 0 ,
@@ -159,7 +164,7 @@ def plot_session_scroller( # noqa: C901 pragma: no cover
159164 if ("bouts" not in plot_list ) or (df_licks is None ):
160165 left_licks = df_events .query ('event == "left_lick_time"' )
161166 left_times = left_licks .timestamps .values
162- ax [0 ].vlines (
167+ ax [- 1 ].vlines (
163168 left_times ,
164169 params ["left_lick_bottom" ],
165170 params ["left_lick_top" ],
@@ -170,7 +175,7 @@ def plot_session_scroller( # noqa: C901 pragma: no cover
170175
171176 right_licks = df_events .query ('event == "right_lick_time"' )
172177 right_times = right_licks .timestamps .values
173- ax [0 ].vlines (
178+ ax [- 1 ].vlines (
174179 right_times ,
175180 params ["right_lick_bottom" ],
176181 params ["right_lick_top" ],
@@ -188,15 +193,15 @@ def plot_session_scroller( # noqa: C901 pragma: no cover
188193 bout_right_licks = df_licks .query (
189194 '(bout_number == @b)&(event=="right_lick_time")'
190195 ).timestamps .values
191- ax [0 ].vlines (
196+ ax [- 1 ].vlines (
192197 bout_left_licks ,
193198 params ["left_lick_bottom" ],
194199 params ["left_lick_top" ],
195200 alpha = 1 ,
196201 linewidth = 2 ,
197202 color = cmap (np .mod (b , 20 )),
198203 )
199- ax [0 ].vlines (
204+ ax [- 1 ].vlines (
200205 bout_right_licks ,
201206 params ["right_lick_bottom" ],
202207 params ["right_lick_top" ],
@@ -210,7 +215,7 @@ def plot_session_scroller( # noqa: C901 pragma: no cover
210215 left_rewarded_licks = df_licks .query (
211216 '(event == "left_lick_time")&(rewarded)'
212217 ).timestamps .values
213- ax [0 ].plot (
218+ ax [- 1 ].plot (
214219 left_rewarded_licks ,
215220 [params ["left_lick_top" ]] * len (left_rewarded_licks ),
216221 "ro" ,
@@ -221,7 +226,7 @@ def plot_session_scroller( # noqa: C901 pragma: no cover
221226 right_rewarded_licks = df_licks .query (
222227 '(event == "right_lick_time")&(rewarded)'
223228 ).timestamps .values
224- ax [0 ].plot (
229+ ax [- 1 ].plot (
225230 right_rewarded_licks , [params ["right_lick_bottom" ]] * len (right_rewarded_licks ), "ro"
226231 )
227232
@@ -230,7 +235,7 @@ def plot_session_scroller( # noqa: C901 pragma: no cover
230235 left_cue_licks = df_licks .query (
231236 '(event == "left_lick_time")&(cue_response)'
232237 ).timestamps .values
233- ax [0 ].plot (
238+ ax [- 1 ].plot (
234239 left_cue_licks ,
235240 [
236241 params ["left_lick_bottom" ]
@@ -243,7 +248,7 @@ def plot_session_scroller( # noqa: C901 pragma: no cover
243248 right_cue_licks = df_licks .query (
244249 '(event == "right_lick_time")&(cue_response)'
245250 ).timestamps .values
246- ax [0 ].plot (
251+ ax [- 1 ].plot (
247252 right_cue_licks ,
248253 [
249254 params ["right_lick_bottom" ]
@@ -257,10 +262,10 @@ def plot_session_scroller( # noqa: C901 pragma: no cover
257262 # Plot baiting
258263 bait_right = df_trials .query ("bait_right" )["goCue_start_time_in_session" ].values
259264 bait_left = df_trials .query ("bait_left" )["goCue_start_time_in_session" ].values
260- ax [0 ].plot (
265+ ax [- 1 ].plot (
261266 bait_right , [params ["right_lick_top" ] - 0.05 ] * len (bait_right ), "ms" , label = "baited"
262267 )
263- ax [0 ].plot (bait_left , [params ["left_lick_bottom" ] + 0.05 ] * len (bait_left ), "ms" )
268+ ax [- 1 ].plot (bait_left , [params ["left_lick_bottom" ] + 0.05 ] * len (bait_left ), "ms" )
264269
265270 if "lick artifacts" in plot_list :
266271 artifacts_right = df_licks .query ('likely_artifact and (event=="right_lick_time")' )[
@@ -269,14 +274,14 @@ def plot_session_scroller( # noqa: C901 pragma: no cover
269274 artifacts_left = df_licks .query ('likely_artifact and (event=="left_lick_time")' )[
270275 "timestamps"
271276 ].values
272- ax [0 ].plot (
277+ ax [- 1 ].plot (
273278 artifacts_right ,
274279 [params ["right_lick_top" ]] * len (artifacts_right ),
275280 "d" ,
276281 color = "darkorange" ,
277282 label = "lick artifact" ,
278283 )
279- ax [0 ].plot (
284+ ax [- 1 ].plot (
280285 artifacts_left ,
281286 [params ["left_lick_bottom" ]] * len (artifacts_left ),
282287 "d" ,
@@ -285,7 +290,7 @@ def plot_session_scroller( # noqa: C901 pragma: no cover
285290
286291 left_reward_deliverys = df_events .query ('event == "left_reward_delivery_time"' )
287292 left_times = left_reward_deliverys .timestamps .values
288- ax [0 ].vlines (
293+ ax [- 1 ].vlines (
289294 left_times ,
290295 params ["left_reward_bottom" ],
291296 params ["left_reward_top" ],
@@ -296,7 +301,7 @@ def plot_session_scroller( # noqa: C901 pragma: no cover
296301
297302 right_reward_deliverys = df_events .query ('event == "right_reward_delivery_time"' )
298303 right_times = right_reward_deliverys .timestamps .values
299- ax [0 ].vlines (
304+ ax [- 1 ].vlines (
300305 right_times ,
301306 params ["right_reward_bottom" ],
302307 params ["right_reward_top" ],
@@ -307,7 +312,7 @@ def plot_session_scroller( # noqa: C901 pragma: no cover
307312
308313 if "manual rewards" in plot_list :
309314 manual_left_times = left_reward_deliverys .query ('data == "manual"' ).timestamps .values
310- ax [0 ].vlines (
315+ ax [- 1 ].vlines (
311316 manual_left_times ,
312317 params ["left_reward_bottom" ],
313318 params ["left_reward_top" ],
@@ -317,7 +322,7 @@ def plot_session_scroller( # noqa: C901 pragma: no cover
317322 label = "manual reward" ,
318323 )
319324 manual_right_times = right_reward_deliverys .query ('data == "manual"' ).timestamps .values
320- ax [0 ].vlines (
325+ ax [- 1 ].vlines (
321326 manual_right_times ,
322327 params ["right_reward_bottom" ],
323328 params ["right_reward_top" ],
@@ -327,7 +332,7 @@ def plot_session_scroller( # noqa: C901 pragma: no cover
327332 )
328333 if "auto rewards" in plot_list :
329334 auto_left_times = left_reward_deliverys .query ('data == "auto"' ).timestamps .values
330- ax [0 ].vlines (
335+ ax [- 1 ].vlines (
331336 auto_left_times ,
332337 params ["left_reward_bottom" ],
333338 params ["left_reward_top" ],
@@ -337,7 +342,7 @@ def plot_session_scroller( # noqa: C901 pragma: no cover
337342 label = "auto reward" ,
338343 )
339344 auto_right_times = right_reward_deliverys .query ('data == "auto"' ).timestamps .values
340- ax [0 ].vlines (
345+ ax [- 1 ].vlines (
341346 auto_right_times ,
342347 params ["right_reward_bottom" ],
343348 params ["right_reward_top" ],
@@ -349,7 +354,7 @@ def plot_session_scroller( # noqa: C901 pragma: no cover
349354 go_cues = df_events .query ('event == "goCue_start_time"' )
350355 go_cue_times = go_cues .timestamps .values
351356 if "go cue" in plot_list :
352- ax [0 ].vlines (
357+ ax [- 1 ].vlines (
353358 go_cue_times ,
354359 params ["left_lick_bottom" ],
355360 params ["left_reward_top" ],
@@ -358,7 +363,7 @@ def plot_session_scroller( # noqa: C901 pragma: no cover
358363 color = "b" ,
359364 label = "go cue" ,
360365 )
361- ax [0 ].vlines (
366+ ax [- 1 ].vlines (
362367 go_cue_times ,
363368 params ["right_reward_bottom" ],
364369 params ["right_lick_top" ],
@@ -368,41 +373,41 @@ def plot_session_scroller( # noqa: C901 pragma: no cover
368373 )
369374
370375 # plot metrics
371- ax [0 ].axhline (params ["right_lick_top" ], color = "k" , linewidth = 0.5 , alpha = 0.25 )
376+ ax [- 1 ].axhline (params ["right_lick_top" ], color = "k" , linewidth = 0.5 , alpha = 0.25 )
372377 go_cue_times_doubled = np .repeat (go_cue_times , 2 )[1 :]
373378
374379 pR = params ["probs_bottom" ] + df_trials ["reward_probabilityR" ] / 4
375380 pR = np .repeat (pR , 2 )[:- 1 ]
376- ax [0 ].fill_between (go_cue_times_doubled , params ["probs_bottom" ], pR , color = "r" , alpha = 0.4 )
381+ ax [- 1 ].fill_between (go_cue_times_doubled , params ["probs_bottom" ], pR , color = "r" , alpha = 0.4 )
377382
378383 pL = params ["probs_bottom" ] - df_trials ["reward_probabilityL" ] / 4
379384 pL = np .repeat (pL , 2 )[:- 1 ]
380385
381- ax [0 ].fill_between (go_cue_times_doubled , pL , params ["probs_bottom" ], color = "b" , alpha = 0.4 )
386+ ax [- 1 ].fill_between (go_cue_times_doubled , pL , params ["probs_bottom" ], color = "b" , alpha = 0.4 )
382387
383388 # plot metrics if they are available
384389 for index , metric in enumerate (metrics ):
385- plot_metric (df_trials , go_cue_times , metric , ax [index + 1 ])
390+ plot_metric (df_trials , go_cue_times , metric , ax [len ( fip ) + index ])
386391
387392 # plot fip if they are available:
388393 for index , f in enumerate (fip ):
389- plot_fip (fip_df , f , ax [index + 1 + len ( metrics ) ])
394+ plot_fip (fip_df , f , ax [index ])
390395
391396 # Clean up plot
392397 if len (plot_list ) > 0 :
393- ax [0 ].legend (framealpha = 1 , loc = "lower left" , reverse = True )
394- ax [0 ].set_xlabel ("time (s)" , fontsize = STYLE ["axis_fontsize" ])
395- ax [0 ].set_ylim (0 , 1.5 )
396- ax [0 ].set_yticks (yticks )
397- ax [0 ].set_yticklabels (ylabels , fontsize = STYLE ["axis_ticks_fontsize" ])
398- for tick , color in zip (ax [0 ].get_yticklabels (), ycolors ):
398+ ax [- 1 ].legend (framealpha = 1 , loc = "lower left" , reverse = True )
399+ ax [- 1 ].set_xlabel ("time (s)" , fontsize = STYLE ["axis_fontsize" ])
400+ ax [- 1 ].set_ylim (0 , 1.5 )
401+ ax [- 1 ].set_yticks (yticks )
402+ ax [- 1 ].set_yticklabels (ylabels , fontsize = STYLE ["axis_ticks_fontsize" ])
403+ for tick , color in zip (ax [- 1 ].get_yticklabels (), ycolors ):
399404 tick .set_color (color )
400405
401406 for my_ax in ax :
402407 my_ax .spines ["top" ].set_visible (False )
403408 my_ax .spines ["right" ].set_visible (False )
404409
405- ax [- 1 ].set_title (nwb .session_id )
410+ ax [0 ].set_title (nwb .session_id )
406411
407412 if num_plots == 1 :
408413 plt .tight_layout ()
@@ -411,7 +416,7 @@ def on_key_press(event):
411416 """
412417 Define interaction resonsivity
413418 """
414- x = ax [0 ].get_xlim ()
419+ x = ax [- 1 ].get_xlim ()
415420 xmin = x [0 ]
416421 xmax = x [1 ]
417422 xStep = (xmax - xmin ) / 4
@@ -430,7 +435,7 @@ def on_key_press(event):
430435 elif event .key == "h" :
431436 xmin = x_first
432437 xmax = x_last
433- ax [0 ].set_xlim (xmin , xmax )
438+ ax [- 1 ].set_xlim (xmin , xmax )
434439 plt .draw ()
435440
436441 kpid = fig .canvas .mpl_connect ("key_press_event" , on_key_press ) # noqa: F841
0 commit comments