Skip to content

Commit 310bdf7

Browse files
committed
Add option to return color scale information.
1 parent 32e613f commit 310bdf7

2 files changed

Lines changed: 22 additions & 2 deletions

File tree

src/brainplotlib/__init__.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import os
22
import numpy as np
3-
from matplotlib import cm
3+
from matplotlib import cm, colors
44

55

66
DATA_DIR = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'data')
@@ -100,7 +100,7 @@ def prepare_data(*values):
100100
return new_values
101101

102102

103-
def brain_plot(*values, vmax, vmin, cmap=None, medial_wall_color=[0.8, 0.8, 0.8, 1.0], background_color=[1.0, 1.0, 1.0, 0.0]):
103+
def brain_plot(*values, vmax, vmin, cmap=None, medial_wall_color=[0.8, 0.8, 0.8, 1.0], background_color=[1.0, 1.0, 1.0, 0.0], return_scale=False):
104104
values = prepare_data(*values)
105105
nan_mask = np.isnan(values)
106106
r = (vmax - values) / (vmax - vmin)
@@ -110,4 +110,8 @@ def brain_plot(*values, vmax, vmin, cmap=None, medial_wall_color=[0.8, 0.8, 0.8,
110110
c[nan_mask] = medial_wall_color
111111
c = np.concatenate([c, [_[:c.shape[1]] for _ in [medial_wall_color, background_color]]], axis=0)
112112
img = c[PLOT_MAPPING]
113+
if return_scale:
114+
norm = colors.Normalize(vmax=vmax, vmin=vmin, clip=True)
115+
scale = cm.ScalarMappable(norm=norm, cmap=cmap)
116+
return img, scale
113117
return img

tests/test_plotting.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -122,3 +122,19 @@ def test_jet_cmap(self, tmp_path):
122122
if importlib.util.find_spec('cv2'):
123123
import cv2
124124
cv2.imwrite(os.path.join(tmp_path, 'test_jet_cmap.png'), np.round(img * 255).astype(np.uint8)[:, :, [2, 1, 0, 3]])
125+
126+
127+
class TestScale:
128+
def test_color_scale(self, tmp_path):
129+
rng = np.random.default_rng()
130+
values = rng.random((588, )), rng.random((587, ))
131+
img, scale = brain_plot(*values, vmax=1, vmin=0, cmap='viridis', return_scale=True)
132+
from matplotlib import cm
133+
assert isinstance(scale, cm.ScalarMappable)
134+
assert img.shape in [(1560, 1728, 4), (1560, 1728, 3)]
135+
assert img.dtype == np.float64
136+
assert np.all(img <= 1)
137+
assert np.all(img >= 0)
138+
if importlib.util.find_spec('cv2'):
139+
import cv2
140+
cv2.imwrite(os.path.join(tmp_path, 'test_colorscale.png'), np.round(img * 255).astype(np.uint8)[:, :, [2, 1, 0, 3]])

0 commit comments

Comments
 (0)