@@ -63,6 +63,7 @@ def imshow(
6363 y = None ,
6464 animation_frame = None ,
6565 facet_col = None ,
66+ facet_row = None ,
6667 facet_col_wrap = None ,
6768 facet_col_spacing = None ,
6869 facet_row_spacing = None ,
@@ -128,10 +129,15 @@ def imshow(
128129 axis number along which the image array is sliced to create a facetted plot.
129130 If `img` is an xarray, `facet_col` can be the name of one the dimensions.
130131
132+ facet_row: int or str, optional (default None)
133+ axis number along which the image array is sliced to create a vertically
134+ facetted plot. If `img` is an xarray, `facet_row` can be the name of one
135+ the dimensions.
136+
131137 facet_col_wrap: int
132138 Maximum number of facet columns. Wraps the column variable at this width,
133139 so that the column facets span multiple rows.
134- Ignored if `facet_col` is None.
140+ Ignored if `facet_col` is None or if `facet_row` is set .
135141
136142 facet_col_spacing: float between 0 and 1
137143 Spacing between facet columns, in paper units. Default is 0.02.
@@ -235,30 +241,46 @@ def imshow(
235241 args = locals ()
236242 apply_default_cascade (args , constructor = None )
237243 labels = labels .copy ()
238- nslices_facet = 1
244+ nslices_facet_col = 1
245+ nslices_facet_row = 1
246+ facet_col_slices = None
247+ facet_row_slices = None
239248 if facet_col is not None :
240249 if isinstance (facet_col , str ):
241250 facet_col = img .dims .index (facet_col )
242- nslices_facet = img .shape [facet_col ]
243- facet_slices = range (nslices_facet )
244- ncols = int (facet_col_wrap ) if facet_col_wrap is not None else nslices_facet
251+ nslices_facet_col = img .shape [facet_col ]
252+ facet_col_slices = range (nslices_facet_col )
253+ if facet_row is not None :
254+ if isinstance (facet_row , str ):
255+ facet_row = img .dims .index (facet_row )
256+ nslices_facet_row = img .shape [facet_row ]
257+ facet_row_slices = range (nslices_facet_row )
258+ # ignore facet_col_wrap when facet_row is set
259+ if facet_row is not None :
260+ facet_col_wrap = None
261+
262+ if facet_col_wrap is None :
263+ ncols = nslices_facet_col
264+ nrows = nslices_facet_row
265+ else :
266+ ncols = min (int (facet_col_wrap ), nslices_facet_col )
245267 nrows = (
246- nslices_facet // ncols + 1
247- if nslices_facet % ncols
248- else nslices_facet // ncols
268+ nslices_facet_col // ncols + 1
269+ if nslices_facet_col % ncols
270+ else nslices_facet_col // ncols
249271 )
250- else :
251- nrows = 1
252- ncols = 1
253272 if animation_frame is not None :
254273 if isinstance (animation_frame , str ):
255274 animation_frame = img .dims .index (animation_frame )
256275 nslices_animation = img .shape [animation_frame ]
257276 animation_slices = range (nslices_animation )
258- slice_dimensions = (facet_col is not None ) + (
259- animation_frame is not None
260- ) # 0, 1, or 2
261- facet_label = None
277+ slice_dimensions = (
278+ (facet_col is not None )
279+ + (facet_row is not None )
280+ + (animation_frame is not None )
281+ ) # 0, 1, 2, or 3
282+ facet_col_label = None
283+ facet_row_label = None
262284 animation_label = None
263285 img_is_xarray = False
264286 # ----- Define x and y, set labels if img is an xarray -------------------
@@ -267,9 +289,13 @@ def imshow(
267289 img_is_xarray = True
268290 pop_indexes = []
269291 if facet_col is not None :
270- facet_slices = img .coords [img .dims [facet_col ]].values
292+ facet_col_slices = img .coords [img .dims [facet_col ]].values
271293 pop_indexes .append (facet_col )
272- facet_label = img .dims [facet_col ]
294+ facet_col_label = img .dims [facet_col ]
295+ if facet_row is not None :
296+ facet_row_slices = img .coords [img .dims [facet_row ]].values
297+ pop_indexes .append (facet_row )
298+ facet_row_label = img .dims [facet_row ]
273299 if animation_frame is not None :
274300 animation_slices = img .coords [img .dims [animation_frame ]].values
275301 pop_indexes .append (animation_frame )
@@ -295,7 +321,9 @@ def imshow(
295321 if labels .get ("animation_frame" , None ) is None :
296322 labels ["animation_frame" ] = animation_label
297323 if labels .get ("facet_col" , None ) is None :
298- labels ["facet_col" ] = facet_label
324+ labels ["facet_col" ] = facet_col_label
325+ if labels .get ("facet_row" , None ) is None :
326+ labels ["facet_row" ] = facet_row_label
299327 if labels .get ("color" , None ) is None :
300328 labels ["color" ] = xarray .plot .utils .label_from_attrs (img )
301329 labels ["color" ] = labels ["color" ].replace ("\n " , "<br>" )
@@ -331,12 +359,20 @@ def imshow(
331359
332360 # --------------- Starting from here img is always a numpy array --------
333361 img = np .asanyarray (img )
334- # Reshape array so that animation dimension comes first, then facets, then images
362+ # Reshape array so that animation dimension comes first, then facet_row, then facet_col, then images
363+ # We move axes to front in reverse order so each axis ends up at position 0 in the final order
335364 if facet_col is not None :
336365 img = np .moveaxis (img , facet_col , 0 )
337366 if animation_frame is not None and animation_frame < facet_col :
338367 animation_frame += 1
368+ if facet_row is not None and facet_row < facet_col :
369+ facet_row += 1
339370 facet_col = True
371+ if facet_row is not None :
372+ img = np .moveaxis (img , facet_row , 0 )
373+ if animation_frame is not None and animation_frame < facet_row :
374+ animation_frame += 1
375+ facet_row = True
340376 if animation_frame is not None :
341377 img = np .moveaxis (img , animation_frame , 0 )
342378 animation_frame = True
@@ -348,8 +384,10 @@ def imshow(
348384 iterables = ()
349385 if animation_frame is not None :
350386 iterables += (range (nslices_animation ),)
387+ if facet_row is not None :
388+ iterables += (range (nslices_facet_row ),)
351389 if facet_col is not None :
352- iterables += (range (nslices_facet ),)
390+ iterables += (range (nslices_facet_col ),)
353391
354392 # Default behaviour of binary_string: True for RGB images, False for 2D
355393 if binary_string is None :
@@ -535,19 +573,25 @@ def imshow(
535573 raise ValueError (
536574 "px.imshow only accepts 2D single-channel, RGB or RGBA images. "
537575 "An image of shape %s was provided. "
538- "Alternatively, 3- or 4 -D single or multichannel datasets can be "
539- "visualized using the `facet_col` or/ and `animation_frame` arguments."
576+ "Alternatively, 3-, 4-, or 5 -D single or multichannel datasets can be "
577+ "visualized using the `facet_col`, `facet_row`, and/or `animation_frame` arguments."
540578 % str (img .shape )
541579 )
542580
543581 # Now build figure
544582 col_labels = []
583+ row_labels = []
545584 if facet_col is not None :
546585 slice_label = (
547586 "facet_col" if labels .get ("facet_col" ) is None else labels ["facet_col" ]
548587 )
549- col_labels = [f"{ slice_label } ={ i } " for i in facet_slices ]
550- fig = init_figure (args , "xy" , [], nrows , ncols , col_labels , [])
588+ col_labels = [f"{ slice_label } ={ i } " for i in facet_col_slices ]
589+ if facet_row is not None :
590+ slice_label = (
591+ "facet_row" if labels .get ("facet_row" ) is None else labels ["facet_row" ]
592+ )
593+ row_labels = [f"{ slice_label } ={ i } " for i in facet_row_slices ]
594+ fig = init_figure (args , "xy" , [], nrows , ncols , col_labels , row_labels )
551595 for attr_name in ["height" , "width" ]:
552596 if args [attr_name ]:
553597 layout [attr_name ] = args [attr_name ]
@@ -556,15 +600,22 @@ def imshow(
556600 elif args ["template" ].layout .margin .t is None :
557601 layout ["margin" ] = {"t" : 60 }
558602
603+ nslices_facets = nslices_facet_row * nslices_facet_col
559604 frame_list = []
560605 for index , trace in enumerate (traces ):
561- if (facet_col and index < nrows * ncols ) or index == 0 :
562- fig .add_trace (trace , row = nrows - index // ncols , col = index % ncols + 1 )
606+ if ((facet_col or facet_row ) and index < nrows * ncols ) or index == 0 :
607+ # Calculate row and col position
608+ # index is ordered by (facet_row, facet_col) from itertools.product
609+ # When facet_col_wrap is used (and facet_row is None), traces are laid out
610+ # across wrapped columns, so we use ncols for the calculation
611+ row_idx = index // ncols
612+ col_idx = index % ncols
613+ fig .add_trace (trace , row = nrows - row_idx , col = col_idx + 1 )
563614 if animation_frame is not None :
564615 for i , index in zip (range (nslices_animation ), animation_slices ):
565616 frame_list .append (
566617 dict (
567- data = traces [nslices_facet * i : nslices_facet * (i + 1 )],
618+ data = traces [nslices_facets * i : nslices_facets * (i + 1 )],
568619 layout = layout ,
569620 name = str (index ),
570621 )
0 commit comments