Skip to content

Commit e88307b

Browse files
committed
Add generic float/complex to numpy_scalar<> tests
1 parent 65dc777 commit e88307b

2 files changed

Lines changed: 95 additions & 60 deletions

File tree

tests/test_numpy_scalars.cpp

Lines changed: 31 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
BSD-style license that can be found in the LICENSE file.
88
*/
99

10+
#include <complex>
1011
#include <cstdint>
1112
#include <string>
1213
#include <utility>
@@ -16,29 +17,44 @@
1617

1718
namespace py = pybind11;
1819

19-
template<typename T>
20-
void register_test(py::module& m, const char *name) {
20+
template<typename T, typename F>
21+
void register_test(py::module& m, const char *name, F&& func) {
2122
m.def("test_numpy_scalars", [=](py::numpy_scalar<T> v) {
22-
return std::make_tuple(name, py::make_scalar(static_cast<T>(v.value + 1)));
23+
return std::make_tuple(name, py::make_scalar(static_cast<T>(func(v.value))));
2324
}, py::arg("x"));
2425
m.def((std::string("test_") + name).c_str(), [=](py::numpy_scalar<T> v) {
25-
return std::make_tuple(name, py::make_scalar(static_cast<T>(v.value + 1)));
26+
return std::make_tuple(name, py::make_scalar(static_cast<T>(func(v.value))));
2627
}, py::arg("x"));
2728
}
2829

30+
template<typename T>
31+
struct add {
32+
T x;
33+
add(T x) : x(x) {}
34+
T operator()(T y) const { return static_cast<T>(x + y); }
35+
};
36+
2937
TEST_SUBMODULE(numpy_scalars, m) {
3038
try { py::module::import("numpy"); }
3139
catch (...) { return; }
3240

33-
register_test<bool>(m, "bool");
34-
register_test<int8_t>(m, "int8");
35-
register_test<int16_t>(m, "int16");
36-
register_test<int32_t>(m, "int32");
37-
register_test<int64_t>(m, "int64");
38-
register_test<uint8_t>(m, "uint8");
39-
register_test<uint16_t>(m, "uint16");
40-
register_test<uint32_t>(m, "uint32");
41-
register_test<uint64_t>(m, "uint64");
42-
register_test<float>(m, "float32");
43-
register_test<double>(m, "float64");
41+
using cfloat = std::complex<float>;
42+
using cdouble = std::complex<double>;
43+
using clongdouble = std::complex<long double>;
44+
45+
register_test<bool>(m, "bool", [](bool x) { return !x; });
46+
register_test<int8_t>(m, "int8", add<int8_t>(-8));
47+
register_test<int16_t>(m, "int16", add<int16_t>(-16));
48+
register_test<int32_t>(m, "int32", add<int32_t>(-32));
49+
register_test<int64_t>(m, "int64", add<int64_t>(-64));
50+
register_test<uint8_t>(m, "uint8", add<uint8_t>(8));
51+
register_test<uint16_t>(m, "uint16", add<uint16_t>(16));
52+
register_test<uint32_t>(m, "uint32", add<uint32_t>(32));
53+
register_test<uint64_t>(m, "uint64", add<uint64_t>(64));
54+
register_test<float>(m, "float32", add<float>(0.125f));
55+
register_test<double>(m, "float64", add<double>(0.25f));
56+
register_test<long double>(m, "longdouble", add<long double>(0.5L));
57+
register_test<cfloat>(m, "complex64", add<cfloat>({0, -0.125f}));
58+
register_test<cdouble>(m, "complex128", add<cdouble>({0, -0.25}));
59+
register_test<clongdouble>(m, "longcomplex", add<clongdouble>({0, -0.5L}));
4460
}

tests/test_numpy_scalars.py

Lines changed: 64 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -4,66 +4,85 @@
44

55
pytestmark = pytest.requires_numpy
66

7+
SCALAR_TYPES = {}
8+
79
with pytest.suppress(ImportError):
810
import numpy as np
911

12+
SCALAR_TYPES = dict([
13+
(np.bool_, False),
14+
(np.int8, -7),
15+
(np.int16, -15),
16+
(np.int32, -31),
17+
(np.int64, -63),
18+
(np.uint8, 9),
19+
(np.uint16, 17),
20+
(np.uint32, 33),
21+
(np.uint64, 65),
22+
(np.single, 1.125),
23+
(np.double, 1.25),
24+
(np.longdouble, 1.5),
25+
(np.csingle, 1 - 0.125j),
26+
(np.cdouble, 1 - 0.25j),
27+
(np.clongdouble, 1 - 0.5j),
28+
])
1029

11-
@pytest.fixture(scope='module')
12-
def scalar_types():
13-
return [
14-
np.bool_,
15-
np.int8,
16-
np.int16,
17-
np.int32,
18-
np.int64,
19-
np.uint8,
20-
np.uint16,
21-
np.uint32,
22-
np.uint64,
23-
np.float32,
24-
np.float64,
25-
]
30+
ALL_TYPES = [int, bool, float, bytes, str, type(None)] + list(SCALAR_TYPES)
2631

2732

28-
@pytest.fixture(scope='module')
29-
def other_types():
30-
return [int, bool, float, bytes, str, np.complex64, type(None)]
33+
def type_name(tp):
34+
try:
35+
if tp is np.longdouble:
36+
return 'longdouble'
37+
elif issubclass(tp, np.floating):
38+
return 'float' + str(8 * tp().itemsize)
39+
elif tp is np.clongdouble:
40+
return 'longcomplex'
41+
elif issubclass(tp, np.complexfloating):
42+
return 'complex' + str(8 * tp().itemsize)
43+
return tp.__name__.rstrip('_')
44+
except BaseException:
45+
# no numpy
46+
return str(tp)
3147

3248

33-
def signature(tp):
34-
name = tp.__name__.rstrip('_')
35-
return 'test_numpy_scalars(x: {tp}) -> Tuple[str, {tp}]'.format(tp=name)
49+
@pytest.fixture(scope='module', params=list(SCALAR_TYPES), ids=type_name)
50+
def scalar_type(request):
51+
return request.param
3652

3753

38-
def test_numpy_scalars_single(scalar_types, other_types):
39-
s_tp = 'str' if sys.version_info[0] >= 3 else 'unicode'
40-
for scalar_type in scalar_types:
41-
name = scalar_type.__name__.rstrip('_')
42-
func = getattr(m, 'test_' + name)
43-
sig = 'test_{tp}(x: {tp}) -> Tuple[{s_tp}, {tp}]\n'.format(tp=name, s_tp=s_tp)
44-
assert func.__doc__ == sig
45-
for tp in (scalar_types + other_types):
46-
value = None if isinstance(None, tp) else tp()
47-
if tp is scalar_type:
48-
result = func(value)
49-
assert result[0] == name
50-
assert isinstance(result[1], tp)
51-
assert result[1] == tp(1)
52-
else:
53-
with pytest.raises(TypeError):
54-
func(value)
54+
def expected_signature(tp):
55+
s = 'str' if sys.version_info[0] >= 3 else 'unicode'
56+
t = type_name(tp)
57+
return 'test_{t}(x: {t}) -> Tuple[{s}, {t}]\n'.format(s=s, t=t)
5558

5659

57-
def test_numpy_scalars_overload(scalar_types, other_types):
58-
func = m.test_numpy_scalars
59-
for tp in (scalar_types + other_types):
60-
value = None if isinstance(None, tp) else tp()
61-
if tp in scalar_types:
62-
name = tp.__name__.rstrip('_')
60+
def test_numpy_scalars_single(scalar_type):
61+
expected = SCALAR_TYPES[scalar_type]
62+
name = type_name(scalar_type)
63+
func = getattr(m, 'test_' + name)
64+
assert func.__doc__ == expected_signature(scalar_type)
65+
for tp in ALL_TYPES:
66+
value = None if isinstance(None, tp) else tp(1)
67+
if tp is scalar_type:
6368
result = func(value)
6469
assert result[0] == name
6570
assert isinstance(result[1], tp)
66-
assert result[1] == tp(1)
71+
assert result[1] == tp(expected)
72+
else:
73+
with pytest.raises(TypeError):
74+
func(value)
75+
76+
77+
def test_numpy_scalars_overload():
78+
func = m.test_numpy_scalars
79+
for tp in ALL_TYPES:
80+
value = None if isinstance(None, tp) else tp(1)
81+
if tp in SCALAR_TYPES:
82+
result = func(value)
83+
assert result[0] == type_name(tp)
84+
assert isinstance(result[1], tp)
85+
assert result[1] == tp(SCALAR_TYPES[tp])
6786
else:
6887
with pytest.raises(TypeError):
6988
func(value)

0 commit comments

Comments
 (0)