Skip to content

Commit 5f4e0a1

Browse files
author
The TensorFlow Datasets Authors
committed
Fix PIL image conversion for uint16 multi-channel arrays.
PiperOrigin-RevId: 911926026
1 parent b4c81e2 commit 5f4e0a1

2 files changed

Lines changed: 34 additions & 3 deletions

File tree

tensorflow_datasets/core/utils/image_utils.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -177,16 +177,25 @@ def create_thumbnail(
177177
if use_colormap: # Apply the colormap first as it modify the shape/dtype
178178
ex = apply_colormap(ex)
179179

180-
_, _, c = ex.shape
180+
c = ex.shape[-1] if ex.ndim == 3 else 1
181181
postprocess = _postprocess_noop
182182
if c == 1:
183183
ex = ex.squeeze(axis=-1)
184-
mode = 'L'
184+
if ex.dtype == np.uint16:
185+
mode = 'I;16'
186+
postprocess = _postprocess_convert_rgb
187+
else:
188+
mode = 'L'
185189
elif ex.dtype == np.uint16:
186-
mode = 'I;16'
190+
# PIL.Image.fromarray doesn't support uint16 for >1 channels.
191+
# https://github.com/python-pillow/Pillow/blob/11.0.0/src/PIL/Image.py#L3225
192+
# Scale to 8-bit for visualization.
193+
ex = (ex / 257).astype(np.uint8)
194+
mode = None
187195
postprocess = _postprocess_convert_rgb
188196
else:
189197
mode = None
198+
postprocess = _postprocess_convert_rgb
190199
img = PIL_Image.fromarray(ex, mode=mode)
191200
img = postprocess(img)
192201
if default_dimensions:

tensorflow_datasets/core/utils/image_utils_test.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
import os
1919

2020
import numpy as np
21+
import pytest
2122
import tensorflow as tf
2223
from tensorflow_datasets import testing
2324
from tensorflow_datasets.core.utils import image_utils
@@ -94,5 +95,26 @@ def test_apply_colormap(self):
9495
)
9596

9697

98+
@pytest.mark.parametrize(
99+
'dtype,fill_value,expected_mode,channels,expected_shape',
100+
[
101+
(np.uint8, 255, 'L', 1, (32, 32)),
102+
(np.uint8, 255, 'RGB', 3, (32, 32, 3)),
103+
(np.uint16, 65535, 'RGB', 1, (32, 32, 3)),
104+
(np.uint16, 65535, 'RGB', 3, (32, 32, 3)),
105+
],
106+
)
107+
def test_create_thumbnail_conversion(
108+
dtype, fill_value, expected_mode, channels, expected_shape
109+
):
110+
img = np.full((32, 32, channels), fill_value=fill_value, dtype=dtype)
111+
thumbnail = image_utils.create_thumbnail(
112+
img, use_colormap=False, default_dimensions=True
113+
)
114+
assert thumbnail.mode == expected_mode
115+
expected_array = np.full(expected_shape, fill_value=255, dtype=np.uint8)
116+
np.testing.assert_array_equal(np.asarray(thumbnail), expected_array)
117+
118+
97119
if __name__ == '__main__':
98120
testing.test_main()

0 commit comments

Comments
 (0)