Skip to content

Commit 11b90f7

Browse files
authored
Merge pull request #5445 from FBumann/main
feat: Add facet_row support to px.imshow
2 parents db6aa61 + f1380c3 commit 11b90f7

File tree

3 files changed

+167
-27
lines changed

3 files changed

+167
-27
lines changed

CHANGELOG.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,9 @@ This project adheres to [Semantic Versioning](http://semver.org/).
44

55
## Unreleased
66

7+
### Added
8+
- Add `facet_row` support to `px.imshow` for creating subplots along an additional dimension [[#5445](https://github.com/plotly/plotly.py/pull/5445)]
9+
710
### Fixed
811
- Update `numpy.percentile` syntax to stop using deprecated alias [[5483](https://github.com/plotly/plotly.py/pull/5483)], with thanks to @Mr-Neutr0n for the contribution!
912
- `numpy` with a version less than 1.22 is no longer supported.

plotly/express/_imshow.py

Lines changed: 78 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,7 @@ def imshow(
6363
y=None,
6464
animation_frame=None,
6565
facet_col=None,
66+
facet_row=None,
6667
facet_col_wrap=None,
6768
facet_col_spacing=None,
6869
facet_row_spacing=None,
@@ -128,10 +129,15 @@ def imshow(
128129
axis number along which the image array is sliced to create a facetted plot.
129130
If `img` is an xarray, `facet_col` can be the name of one the dimensions.
130131
132+
facet_row: int or str, optional (default None)
133+
axis number along which the image array is sliced to create a vertically
134+
facetted plot. If `img` is an xarray, `facet_row` can be the name of one
135+
the dimensions.
136+
131137
facet_col_wrap: int
132138
Maximum number of facet columns. Wraps the column variable at this width,
133139
so that the column facets span multiple rows.
134-
Ignored if `facet_col` is None.
140+
Ignored if `facet_col` is None or if `facet_row` is set.
135141
136142
facet_col_spacing: float between 0 and 1
137143
Spacing between facet columns, in paper units. Default is 0.02.
@@ -235,30 +241,46 @@ def imshow(
235241
args = locals()
236242
apply_default_cascade(args, constructor=None)
237243
labels = labels.copy()
238-
nslices_facet = 1
244+
nslices_facet_col = 1
245+
nslices_facet_row = 1
246+
facet_col_slices = None
247+
facet_row_slices = None
239248
if facet_col is not None:
240249
if isinstance(facet_col, str):
241250
facet_col = img.dims.index(facet_col)
242-
nslices_facet = img.shape[facet_col]
243-
facet_slices = range(nslices_facet)
244-
ncols = int(facet_col_wrap) if facet_col_wrap is not None else nslices_facet
251+
nslices_facet_col = img.shape[facet_col]
252+
facet_col_slices = range(nslices_facet_col)
253+
if facet_row is not None:
254+
if isinstance(facet_row, str):
255+
facet_row = img.dims.index(facet_row)
256+
nslices_facet_row = img.shape[facet_row]
257+
facet_row_slices = range(nslices_facet_row)
258+
# ignore facet_col_wrap when facet_row is set
259+
if facet_row is not None:
260+
facet_col_wrap = None
261+
262+
if facet_col_wrap is None:
263+
ncols = nslices_facet_col
264+
nrows = nslices_facet_row
265+
else:
266+
ncols = min(int(facet_col_wrap), nslices_facet_col)
245267
nrows = (
246-
nslices_facet // ncols + 1
247-
if nslices_facet % ncols
248-
else nslices_facet // ncols
268+
nslices_facet_col // ncols + 1
269+
if nslices_facet_col % ncols
270+
else nslices_facet_col // ncols
249271
)
250-
else:
251-
nrows = 1
252-
ncols = 1
253272
if animation_frame is not None:
254273
if isinstance(animation_frame, str):
255274
animation_frame = img.dims.index(animation_frame)
256275
nslices_animation = img.shape[animation_frame]
257276
animation_slices = range(nslices_animation)
258-
slice_dimensions = (facet_col is not None) + (
259-
animation_frame is not None
260-
) # 0, 1, or 2
261-
facet_label = None
277+
slice_dimensions = (
278+
(facet_col is not None)
279+
+ (facet_row is not None)
280+
+ (animation_frame is not None)
281+
) # 0, 1, 2, or 3
282+
facet_col_label = None
283+
facet_row_label = None
262284
animation_label = None
263285
img_is_xarray = False
264286
# ----- Define x and y, set labels if img is an xarray -------------------
@@ -267,9 +289,13 @@ def imshow(
267289
img_is_xarray = True
268290
pop_indexes = []
269291
if facet_col is not None:
270-
facet_slices = img.coords[img.dims[facet_col]].values
292+
facet_col_slices = img.coords[img.dims[facet_col]].values
271293
pop_indexes.append(facet_col)
272-
facet_label = img.dims[facet_col]
294+
facet_col_label = img.dims[facet_col]
295+
if facet_row is not None:
296+
facet_row_slices = img.coords[img.dims[facet_row]].values
297+
pop_indexes.append(facet_row)
298+
facet_row_label = img.dims[facet_row]
273299
if animation_frame is not None:
274300
animation_slices = img.coords[img.dims[animation_frame]].values
275301
pop_indexes.append(animation_frame)
@@ -295,7 +321,9 @@ def imshow(
295321
if labels.get("animation_frame", None) is None:
296322
labels["animation_frame"] = animation_label
297323
if labels.get("facet_col", None) is None:
298-
labels["facet_col"] = facet_label
324+
labels["facet_col"] = facet_col_label
325+
if labels.get("facet_row", None) is None:
326+
labels["facet_row"] = facet_row_label
299327
if labels.get("color", None) is None:
300328
labels["color"] = xarray.plot.utils.label_from_attrs(img)
301329
labels["color"] = labels["color"].replace("\n", "<br>")
@@ -331,12 +359,20 @@ def imshow(
331359

332360
# --------------- Starting from here img is always a numpy array --------
333361
img = np.asanyarray(img)
334-
# Reshape array so that animation dimension comes first, then facets, then images
362+
# Reshape array so that animation dimension comes first, then facet_row, then facet_col, then images
363+
# We move axes to front in reverse order so each axis ends up at position 0 in the final order
335364
if facet_col is not None:
336365
img = np.moveaxis(img, facet_col, 0)
337366
if animation_frame is not None and animation_frame < facet_col:
338367
animation_frame += 1
368+
if facet_row is not None and facet_row < facet_col:
369+
facet_row += 1
339370
facet_col = True
371+
if facet_row is not None:
372+
img = np.moveaxis(img, facet_row, 0)
373+
if animation_frame is not None and animation_frame < facet_row:
374+
animation_frame += 1
375+
facet_row = True
340376
if animation_frame is not None:
341377
img = np.moveaxis(img, animation_frame, 0)
342378
animation_frame = True
@@ -348,8 +384,10 @@ def imshow(
348384
iterables = ()
349385
if animation_frame is not None:
350386
iterables += (range(nslices_animation),)
387+
if facet_row is not None:
388+
iterables += (range(nslices_facet_row),)
351389
if facet_col is not None:
352-
iterables += (range(nslices_facet),)
390+
iterables += (range(nslices_facet_col),)
353391

354392
# Default behaviour of binary_string: True for RGB images, False for 2D
355393
if binary_string is None:
@@ -535,19 +573,25 @@ def imshow(
535573
raise ValueError(
536574
"px.imshow only accepts 2D single-channel, RGB or RGBA images. "
537575
"An image of shape %s was provided. "
538-
"Alternatively, 3- or 4-D single or multichannel datasets can be "
539-
"visualized using the `facet_col` or/and `animation_frame` arguments."
576+
"Alternatively, 3-, 4-, or 5-D single or multichannel datasets can be "
577+
"visualized using the `facet_col`, `facet_row`, and/or `animation_frame` arguments."
540578
% str(img.shape)
541579
)
542580

543581
# Now build figure
544582
col_labels = []
583+
row_labels = []
545584
if facet_col is not None:
546585
slice_label = (
547586
"facet_col" if labels.get("facet_col") is None else labels["facet_col"]
548587
)
549-
col_labels = [f"{slice_label}={i}" for i in facet_slices]
550-
fig = init_figure(args, "xy", [], nrows, ncols, col_labels, [])
588+
col_labels = [f"{slice_label}={i}" for i in facet_col_slices]
589+
if facet_row is not None:
590+
slice_label = (
591+
"facet_row" if labels.get("facet_row") is None else labels["facet_row"]
592+
)
593+
row_labels = [f"{slice_label}={i}" for i in facet_row_slices]
594+
fig = init_figure(args, "xy", [], nrows, ncols, col_labels, row_labels)
551595
for attr_name in ["height", "width"]:
552596
if args[attr_name]:
553597
layout[attr_name] = args[attr_name]
@@ -556,15 +600,22 @@ def imshow(
556600
elif args["template"].layout.margin.t is None:
557601
layout["margin"] = {"t": 60}
558602

603+
nslices_facets = nslices_facet_row * nslices_facet_col
559604
frame_list = []
560605
for index, trace in enumerate(traces):
561-
if (facet_col and index < nrows * ncols) or index == 0:
562-
fig.add_trace(trace, row=nrows - index // ncols, col=index % ncols + 1)
606+
if ((facet_col or facet_row) and index < nrows * ncols) or index == 0:
607+
# Calculate row and col position
608+
# index is ordered by (facet_row, facet_col) from itertools.product
609+
# When facet_col_wrap is used (and facet_row is None), traces are laid out
610+
# across wrapped columns, so we use ncols for the calculation
611+
row_idx = index // ncols
612+
col_idx = index % ncols
613+
fig.add_trace(trace, row=nrows - row_idx, col=col_idx + 1)
563614
if animation_frame is not None:
564615
for i, index in zip(range(nslices_animation), animation_slices):
565616
frame_list.append(
566617
dict(
567-
data=traces[nslices_facet * i : nslices_facet * (i + 1)],
618+
data=traces[nslices_facets * i : nslices_facets * (i + 1)],
568619
layout=layout,
569620
name=str(index),
570621
)

tests/test_optional/test_px/test_imshow.py

Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -450,3 +450,89 @@ def test_animation_and_facet(binary_string):
450450
nslices = img.shape[0]
451451
assert len(fig.frames) == nslices
452452
assert len(fig.data) == img.shape[1]
453+
454+
455+
@pytest.mark.parametrize("facet_row", [0, 1, 2, -1])
456+
@pytest.mark.parametrize("binary_string", [False, True])
457+
def test_facet_row(facet_row, binary_string):
458+
img = np.random.randint(255, size=(10, 9, 8))
459+
fig = px.imshow(
460+
img,
461+
facet_row=facet_row,
462+
binary_string=binary_string,
463+
)
464+
nslices = img.shape[facet_row]
465+
nrows = nslices
466+
ncols = 1
467+
nmax = ncols * nrows
468+
assert "yaxis%d" % nmax in fig.layout
469+
assert "yaxis%d" % (nmax + 1) not in fig.layout
470+
assert len(fig.data) == nslices
471+
472+
473+
@pytest.mark.parametrize("binary_string", [False, True])
474+
def test_facet_row_and_col(binary_string):
475+
img = np.random.randint(255, size=(4, 3, 9, 8))
476+
fig = px.imshow(
477+
img,
478+
facet_row=0,
479+
facet_col=1,
480+
binary_string=binary_string,
481+
)
482+
nrows = img.shape[0]
483+
ncols = img.shape[1]
484+
nmax = ncols * nrows
485+
assert "yaxis%d" % nmax in fig.layout
486+
assert "yaxis%d" % (nmax + 1) not in fig.layout
487+
assert len(fig.data) == nrows * ncols
488+
489+
490+
@pytest.mark.parametrize("binary_string", [False, True])
491+
def test_animation_facet_row_and_col(binary_string):
492+
img = np.random.randint(255, size=(5, 4, 3, 9, 8)).astype(np.uint8)
493+
fig = px.imshow(
494+
img,
495+
animation_frame=0,
496+
facet_row=1,
497+
facet_col=2,
498+
binary_string=binary_string,
499+
)
500+
nslices_animation = img.shape[0]
501+
nrows = img.shape[1]
502+
ncols = img.shape[2]
503+
assert len(fig.frames) == nslices_animation
504+
assert len(fig.data) == nrows * ncols
505+
506+
507+
def test_imshow_xarray_facet_row():
508+
img = np.random.random((3, 4, 5))
509+
da = xr.DataArray(
510+
img, dims=["row_dim", "dim_1", "dim_2"], coords={"row_dim": ["A", "B", "C"]}
511+
)
512+
fig = px.imshow(da, facet_row="row_dim")
513+
# Dimensions are used for axis labels and coordinates
514+
assert fig.layout.xaxis.title.text == "dim_2"
515+
assert fig.layout.yaxis.title.text == "dim_1"
516+
assert np.all(np.array(fig.data[0].x) == np.array(da.coords["dim_2"]))
517+
assert len(fig.data) == 3
518+
# Check row labels are present
519+
annotations = [a.text for a in fig.layout.annotations]
520+
assert any("row_dim=A" in a for a in annotations)
521+
522+
523+
def test_imshow_xarray_facet_row_and_col():
524+
img = np.random.random((3, 4, 5, 6))
525+
da = xr.DataArray(
526+
img,
527+
dims=["row_dim", "col_dim", "dim_y", "dim_x"],
528+
coords={"row_dim": ["R1", "R2", "R3"], "col_dim": ["C1", "C2", "C3", "C4"]},
529+
)
530+
fig = px.imshow(da, facet_row="row_dim", facet_col="col_dim")
531+
# Dimensions are used for axis labels and coordinates
532+
assert fig.layout.xaxis.title.text == "dim_x"
533+
assert fig.layout.yaxis.title.text == "dim_y"
534+
assert len(fig.data) == 3 * 4
535+
# Check labels are present
536+
annotations = [a.text for a in fig.layout.annotations]
537+
assert any("row_dim=R1" in a for a in annotations)
538+
assert any("col_dim=C1" in a for a in annotations)

0 commit comments

Comments
 (0)