@@ -7,20 +7,40 @@ namespace py = pybind11;
77using namespace correction ;
88
99namespace {
10+ template <typename T> // Correction or CompoundCorrection
11+ void check_length (const T& c, py::args args) {
12+ // Error message should be the same as in Correction::evaluate
13+ if ( py::len (args) != c.inputs ().size () ) {
14+ throw std::invalid_argument (" Incorrect number of inputs (got " + std::to_string (py::len (args))
15+ + " , expected " + std::to_string (c.inputs ().size ()) + " )" );
16+ }
17+ }
18+
19+ template <typename T> // Correction or CompoundCorrection
20+ std::vector<Variable::Type> validate_pyargs (const T& c, py::args args) {
21+ // Ensure length before checking integer overflow
22+ check_length (c, args);
23+
24+ // Check for integer overflow (py::cast converts to the double alternative)
25+ for (size_t i=0 ; i < c.inputs ().size (); ++i) {
26+ if ( c.inputs ()[i].type () == Variable::VarType::integer && py::isinstance<py::int_>(args[i]) ) {
27+ py::cast<int64_t >(args[i]); // throws if out of range
28+ }
29+ }
30+
31+ return py::cast<std::vector<Variable::Type>>(args);
32+ }
1033
11- template <typename T>
34+ template <typename T> // Correction or CompoundCorrection
1235 py::array_t <double > evalv (T& c, py::args args) {
1336 std::vector<Variable::Type> inputs;
1437 inputs.reserve (py::len (args));
1538 std::vector<std::pair<size_t , py::buffer_info>> vargs;
16- if ( py::len (args) != c.inputs ().size () ) {
17- throw std::invalid_argument (" Incorrect number of inputs (got " + std::to_string (py::len (args))
18- + " , expected " + std::to_string (c.inputs ().size ()) + " )" );
19- }
39+ check_length (c, args);
2040 for (size_t i=0 ; i < py::len (args); ++i) {
2141 if ( py::isinstance<py::array>(args[i]) ) {
2242 if ( c.inputs ()[i].type () == Variable::VarType::integer ) {
23- vargs.emplace_back (i, py::cast<py::array_t <int , py::array::c_style | py::array::forcecast>>(args[i]).request ());
43+ vargs.emplace_back (i, py::cast<py::array_t <int64_t , py::array::c_style | py::array::forcecast>>(args[i]).request ());
2444 inputs.emplace_back (0 );
2545 }
2646 else if ( c.inputs ()[i].type () == Variable::VarType::real ) {
@@ -52,8 +72,8 @@ namespace {
5272 py::gil_scoped_release release;
5373 for (long i=0 ; i < outbuffer.shape [0 ]; ++i) {
5474 for (const auto & varg : vargs) {
55- if ( std::holds_alternative<int >(inputs[varg.first ]) ) {
56- inputs[varg.first ] = static_cast <int *>(varg.second .ptr )[i];
75+ if ( std::holds_alternative<int64_t >(inputs[varg.first ]) ) {
76+ inputs[varg.first ] = static_cast <int64_t *>(varg.second .ptr )[i];
5777 }
5878 else if ( std::holds_alternative<double >(inputs[varg.first ]) ) {
5979 inputs[varg.first ] = static_cast <double *>(varg.second .ptr )[i];
@@ -65,7 +85,6 @@ namespace {
6585 return output;
6686 }
6787}
68-
6988PYBIND11_MODULE (_core, m) {
7089 m.doc () = " python binding for corrections evaluator" ;
7190
@@ -82,7 +101,7 @@ PYBIND11_MODULE(_core, m) {
82101 .def_property_readonly (" inputs" , &Correction::inputs)
83102 .def_property_readonly (" output" , &Correction::output)
84103 .def (" evaluate" , [](Correction& c, py::args args) {
85- return c.evaluate (py::cast<std::vector<Variable::Type>>( args));
104+ return c.evaluate (validate_pyargs (c, args));
86105 })
87106 .def (" evalv" , evalv<Correction>);
88107
@@ -92,7 +111,7 @@ PYBIND11_MODULE(_core, m) {
92111 .def_property_readonly (" inputs" , &CompoundCorrection::inputs)
93112 .def_property_readonly (" output" , &CompoundCorrection::output)
94113 .def (" evaluate" , [](CompoundCorrection& c, py::args args) {
95- return c.evaluate (py::cast<std::vector<Variable::Type>>( args));
114+ return c.evaluate (validate_pyargs (c, args));
96115 })
97116 .def (" evalv" , evalv<CompoundCorrection>);
98117
0 commit comments