Skip to content

Commit 3af60ad

Browse files
Vectorize Scale.val_in_range
1 parent a94275a commit 3af60ad

2 files changed

Lines changed: 59 additions & 21 deletions

File tree

lib/matplotlib/scale.py

Lines changed: 35 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,6 @@
3030
""" # noqa: E501
3131

3232
import inspect
33-
import math
3433
import textwrap
3534
from functools import wraps
3635

@@ -119,17 +118,20 @@ def val_in_range(self, val):
119118
"""
120119
Return whether the value(s) are within the valid range for this scale.
121120
122-
This method is a generic implementation. Subclasses may implement more
123-
efficient solutions for their domain.
124-
"""
125-
try:
126-
if not math.isfinite(val):
127-
return False
121+
Accepts a scalar or array-like ``val``. For a scalar, returns a
122+
Python ``bool``. For an array, returns a bool ndarray of the same
123+
shape. This is a generic implementation, and subclasses may implement
124+
more efficient solutions for their domain.
125+
"""
126+
arr = np.asarray(val)
127+
with np.errstate(invalid='ignore'):
128+
try:
129+
vmin, vmax = self.limit_range_for_scale(arr, arr, minpos=1e-300)
130+
except (TypeError, ValueError):
131+
result = np.zeros(arr.shape, dtype=bool)
128132
else:
129-
vmin, vmax = self.limit_range_for_scale(val, val, minpos=1e-300)
130-
return vmin == val and vmax == val
131-
except (TypeError, ValueError):
132-
return False
133+
result = np.isfinite(arr) & (vmin == arr) & (vmax == arr)
134+
return bool(result) if arr.ndim == 0 else result
133135

134136

135137
def _make_axis_parameter_optional(init_func):
@@ -219,11 +221,13 @@ def get_transform(self):
219221

220222
def val_in_range(self, val):
221223
"""
222-
Return whether the value is within the valid range for this scale.
224+
Return whether the value(s) are within the valid range for this scale.
223225
224226
This is True for all values, except +-inf and NaN.
225227
"""
226-
return math.isfinite(val)
228+
arr = np.asarray(val)
229+
result = np.isfinite(arr)
230+
return bool(result) if arr.ndim == 0 else result
227231

228232

229233
class FuncTransform(Transform):
@@ -431,11 +435,14 @@ def limit_range_for_scale(self, vmin, vmax, minpos):
431435

432436
def val_in_range(self, val):
433437
"""
434-
Return whether the value is within the valid range for this scale.
438+
Return whether the value(s) are within the valid range for this scale.
435439
436440
This is True for value(s) > 0 except +inf and NaN.
437441
"""
438-
return math.isfinite(val) and val > 0
442+
arr = np.asarray(val)
443+
with np.errstate(invalid='ignore'):
444+
result = np.isfinite(arr) & (arr > 0)
445+
return bool(result) if arr.ndim == 0 else result
439446

440447

441448
class FuncScaleLog(LogScale):
@@ -625,11 +632,13 @@ def get_transform(self):
625632

626633
def val_in_range(self, val):
627634
"""
628-
Return whether the value is within the valid range for this scale.
635+
Return whether the value(s) are within the valid range for this scale.
629636
630637
This is True for all values, except +-inf and NaN.
631638
"""
632-
return math.isfinite(val)
639+
arr = np.asarray(val)
640+
result = np.isfinite(arr)
641+
return bool(result) if arr.ndim == 0 else result
633642

634643

635644
class AsinhTransform(Transform):
@@ -759,11 +768,13 @@ def set_default_locators_and_formatters(self, axis):
759768

760769
def val_in_range(self, val):
761770
"""
762-
Return whether the value is within the valid range for this scale.
771+
Return whether the value(s) are within the valid range for this scale.
763772
764773
This is True for all values, except +-inf and NaN.
765774
"""
766-
return math.isfinite(val)
775+
arr = np.asarray(val)
776+
result = np.isfinite(arr)
777+
return bool(result) if arr.ndim == 0 else result
767778

768779

769780
class LogitTransform(Transform):
@@ -880,11 +891,14 @@ def limit_range_for_scale(self, vmin, vmax, minpos):
880891

881892
def val_in_range(self, val):
882893
"""
883-
Return whether the value is within the valid range for this scale.
894+
Return whether the value(s) are within the valid range for this scale.
884895
885896
This is True for value(s) which are between 0 and 1 (excluded).
886897
"""
887-
return 0 < val < 1
898+
arr = np.asarray(val)
899+
with np.errstate(invalid='ignore'):
900+
result = (0 < arr) & (arr < 1)
901+
return bool(result) if arr.ndim == 0 else result
888902

889903

890904
_scale_mapping = {

lib/matplotlib/tests/test_scale.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -477,3 +477,27 @@ def test_val_in_range_base_fallback():
477477
assert s.val_in_range(np.nan) is False
478478
assert s.val_in_range(np.inf) is False
479479
assert s.val_in_range(-np.inf) is False
480+
481+
482+
def test_val_in_range_array():
483+
# Vectorized: scalar in -> scalar bool, array in -> bool array.
484+
arr = np.array([1.0, -1.0, 0.0, np.nan, np.inf, 5.0])
485+
cases = {
486+
'linear': [True, True, True, False, False, True],
487+
'log': [True, False, False, False, False, True],
488+
'symlog': [True, True, True, False, False, True],
489+
'asinh': [True, True, True, False, False, True],
490+
}
491+
for name, expected in cases.items():
492+
s = mscale._scale_mapping[name](axis=None)
493+
np.testing.assert_array_equal(s.val_in_range(arr), expected)
494+
495+
s = mscale._scale_mapping['logit'](axis=None)
496+
np.testing.assert_array_equal(
497+
s.val_in_range(np.array([0.1, 0.5, 0.0, 1.0, -0.1, 1.1])),
498+
[True, True, False, False, False, False])
499+
500+
# 2D shape is preserved.
501+
out = mscale._scale_mapping['log'](axis=None).val_in_range(
502+
np.array([[1.0, -1.0], [0.5, np.nan]]))
503+
np.testing.assert_array_equal(out, [[True, False], [True, False]])

0 commit comments

Comments
 (0)