Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 12 additions & 3 deletions tensorflow_datasets/core/utils/image_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,16 +177,25 @@ def create_thumbnail(
if use_colormap: # Apply the colormap first as it modify the shape/dtype
ex = apply_colormap(ex)

_, _, c = ex.shape
c = ex.shape[-1] if ex.ndim == 3 else 1
postprocess = _postprocess_noop
if c == 1:
ex = ex.squeeze(axis=-1)
mode = 'L'
if ex.dtype == np.uint16:
mode = 'I;16'
postprocess = _postprocess_convert_rgb
else:
mode = 'L'
elif ex.dtype == np.uint16:
mode = 'I;16'
# PIL.Image.fromarray doesn't support uint16 for >1 channels.
# https://github.com/python-pillow/Pillow/blob/11.0.0/src/PIL/Image.py#L3225
# Scale to 8-bit for visualization.
ex = (ex / 257).astype(np.uint8)
mode = None
postprocess = _postprocess_convert_rgb
else:
mode = None
postprocess = _postprocess_convert_rgb
img = PIL_Image.fromarray(ex, mode=mode)
img = postprocess(img)
if default_dimensions:
Expand Down
22 changes: 22 additions & 0 deletions tensorflow_datasets/core/utils/image_utils_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import os

import numpy as np
import pytest
import tensorflow as tf
from tensorflow_datasets import testing
from tensorflow_datasets.core.utils import image_utils
Expand Down Expand Up @@ -94,5 +95,26 @@ def test_apply_colormap(self):
)


@pytest.mark.parametrize(
'dtype,fill_value,expected_mode,channels,expected_shape',
[
(np.uint8, 255, 'L', 1, (32, 32)),
(np.uint8, 255, 'RGB', 3, (32, 32, 3)),
(np.uint16, 65535, 'RGB', 1, (32, 32, 3)),
(np.uint16, 65535, 'RGB', 3, (32, 32, 3)),
],
)
def test_create_thumbnail_conversion(
dtype, fill_value, expected_mode, channels, expected_shape
):
img = np.full((32, 32, channels), fill_value=fill_value, dtype=dtype)
thumbnail = image_utils.create_thumbnail(
img, use_colormap=False, default_dimensions=True
)
assert thumbnail.mode == expected_mode
expected_array = np.full(expected_shape, fill_value=255, dtype=np.uint8)
np.testing.assert_array_equal(np.asarray(thumbnail), expected_array)


if __name__ == '__main__':
testing.test_main()
Loading