Skip to content

Commit a071b78

Browse files
authored
Migrate Variable integer type to 64-bit (#302)
* Migrate Variable integer type to 64-bit Also fix #298 by checking for integer overflow when python ints are passed, since they have arbitrary width. Also bump pyupgrade version * Consistently use std::runtime_error Probably we should be using std::invalid_argument however * Consistently use std::runtime_error Probably should be std::invalid_argument Also bump pre-commit as needed * Move to ValueError consistently for mismatched inputs Why not TypeError? Motivated by the fact that unpacking *args into the wrong number is a ValueError. * WINDOWS
1 parent f28d75e commit a071b78

11 files changed

Lines changed: 107 additions & 52 deletions

File tree

.pre-commit-config.yaml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
repos:
22
- repo: https://github.com/psf/black
3-
rev: 22.3.0
3+
rev: 25.9.0
44
hooks:
55
- id: black
66

@@ -24,7 +24,7 @@ repos:
2424
- id: isort
2525

2626
- repo: https://github.com/asottile/pyupgrade
27-
rev: v2.31.0
27+
rev: v3.21.0
2828
hooks:
2929
- id: pyupgrade
3030
args: ["--py39-plus"]

include/correction.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ class JSONObject; // internal wrapper around rapidjson
1717
class Variable {
1818
public:
1919
enum class VarType {string, integer, real};
20-
typedef std::variant<int, double, std::string> Type;
20+
typedef std::variant<int64_t, double, std::string> Type;
2121

2222
Variable(const JSONObject& json);
2323
std::string name() const { return name_; };
@@ -225,7 +225,7 @@ class Category {
225225
double evaluate(const std::vector<Variable::Type>& values) const;
226226

227227
private:
228-
typedef std::map<int, Content> IntMap;
228+
typedef std::map<int64_t, Content> IntMap;
229229
typedef std::map<std::string, Content> StrMap;
230230
std::variant<IntMap, StrMap> map_;
231231
std::unique_ptr<const Content> default_;

src/correction.cc

Lines changed: 12 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -154,7 +154,7 @@ namespace {
154154
{
155155
double value = std::visit([](auto&& arg) -> double {
156156
using T = std::decay_t<decltype(arg)>;
157-
if constexpr (std::is_same_v<T, int>) return static_cast<double>(arg);
157+
if constexpr (std::is_same_v<T, int64_t>) return static_cast<double>(arg);
158158
else if constexpr (std::is_same_v<T, double>) return arg;
159159
else throw std::logic_error("I should not have ever seen a string");
160160
}, value_variant);
@@ -286,7 +286,7 @@ void Variable::validate(const Type& t) const {
286286
throw std::runtime_error("Input " + name() + " has wrong type: got string expected " + typeStr());
287287
}
288288
}
289-
else if ( std::holds_alternative<int>(t) ) {
289+
else if ( std::holds_alternative<int64_t>(t) ) {
290290
if ( type_ != VarType::integer ) {
291291
throw std::runtime_error("Input " + name() + " has wrong type: got int expected " + typeStr());
292292
}
@@ -398,8 +398,8 @@ double Transform::evaluate(const std::vector<Variable::Type>& values) const {
398398
if ( std::holds_alternative<double>(v) ) {
399399
v = vnew;
400400
}
401-
else if ( std::holds_alternative<int>(v) ) {
402-
v = (int) std::round(vnew);
401+
else if ( std::holds_alternative<int64_t>(v) ) {
402+
v = (int64_t) std::round(vnew);
403403
}
404404
else {
405405
throw std::logic_error("I should not have ever seen a string");
@@ -432,7 +432,7 @@ double HashPRNG::evaluate(const std::vector<Variable::Type>& values) const {
432432
size_t nbytes = sizeof(uint64_t)*variablesIdx_.size();
433433
uint64_t* seedData = (uint64_t*) alloca(nbytes);
434434
for(size_t i=0; i<variablesIdx_.size(); ++i) {
435-
if ( auto v = std::get_if<int>(&values[variablesIdx_[i]]) ) {
435+
if ( auto v = std::get_if<int64_t>(&values[variablesIdx_[i]]) ) {
436436
seedData[i] = static_cast<uint64_t>(*v);
437437
}
438438
else if ( auto v = std::get_if<double>(&values[variablesIdx_[i]]) ) {
@@ -652,7 +652,7 @@ double Category::evaluate(const std::vector<Variable::Type>& values) const {
652652
}
653653
}
654654
}
655-
else if ( auto pval = std::get_if<int>(&values[variableIdx_]) ) {
655+
else if ( auto pval = std::get_if<int64_t>(&values[variableIdx_]) ) {
656656
try {
657657
child = &std::get<IntMap>(map_).at(*pval);
658658
} catch (std::out_of_range& ex) {
@@ -696,11 +696,9 @@ double Correction::evaluate(const std::vector<Variable::Type>& values) const {
696696
if ( ! initialized_ ) {
697697
throw std::logic_error("Not initialized");
698698
}
699-
if ( values.size() > inputs_.size() ) {
700-
throw std::runtime_error("Too many inputs");
701-
}
702-
else if ( values.size() < inputs_.size() ) {
703-
throw std::runtime_error("Insufficient inputs");
699+
if ( values.size() != inputs_.size() ) {
700+
throw std::invalid_argument("Incorrect number of inputs (got " + std::to_string(values.size())
701+
+ ", expected " + std::to_string(inputs_.size()) + ")");
704702
}
705703
for (size_t i=0; i < inputs_.size(); ++i) {
706704
inputs_[i].validate(values[i]);
@@ -767,11 +765,9 @@ size_t CompoundCorrection::input_index(const std::string_view name) const {
767765
}
768766

769767
double CompoundCorrection::evaluate(const std::vector<Variable::Type>& values) const {
770-
if ( values.size() > inputs_.size() ) {
771-
throw std::runtime_error("Too many inputs");
772-
}
773-
else if ( values.size() < inputs_.size() ) {
774-
throw std::runtime_error("Insufficient inputs");
768+
if ( values.size() != inputs_.size() ) {
769+
throw std::invalid_argument("Incorrect number of inputs (got " + std::to_string(values.size())
770+
+ ", expected " + std::to_string(inputs_.size()) + ")");
775771
}
776772
for (size_t i=0; i < inputs_.size(); ++i) {
777773
inputs_[i].validate(values[i]);

src/correctionlib/JSONEncoder.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
Adapted from:
99
https://stackoverflow.com/questions/16264515/json-dumps-custom-formatting
1010
"""
11+
1112
import gzip
1213
import json
1314
import math

src/correctionlib/_core/__init__.pyi

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,7 @@ class FormulaAst:
6969
class BinaryOp:
7070
name: str
7171
value: int
72+
7273
@property
7374
def nodetype(self) -> NodeType: ...
7475
@property

src/correctionlib/convert.py

Lines changed: 18 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
1-
"""Tools to convert other formats to correctionlib
1+
"""Tools to convert other formats to correctionlib"""
22

3-
"""
3+
from collections.abc import Iterable, Sequence
44
from numbers import Real
5-
from typing import TYPE_CHECKING, Any, Iterable, Optional, Sequence, Union, cast
5+
from typing import TYPE_CHECKING, Any, Optional, Union, cast
66

77
import numpy
88

@@ -117,9 +117,11 @@ def build_data(
117117
"content": [
118118
{
119119
"key": axes[0][i],
120-
"value": value
121-
if isinstance(value, Real)
122-
else build_data(value, axes[1:], variables[1:]),
120+
"value": (
121+
value
122+
if isinstance(value, Real)
123+
else build_data(value, axes[1:], variables[1:])
124+
),
123125
}
124126
for i, value in enumerate(values)
125127
],
@@ -138,9 +140,11 @@ def build_data(
138140
"edges": [edges(ax) for ax in axes[:i]],
139141
"inputs": [var.name for var in variables[:i]],
140142
"content": [
141-
value
142-
if isinstance(value, Real)
143-
else build_data(value, axes[i:], variables[i:])
143+
(
144+
value
145+
if isinstance(value, Real)
146+
else build_data(value, axes[i:], variables[i:])
147+
)
144148
for value in flatten_to(values, i - 1)
145149
],
146150
"flow": flow,
@@ -152,9 +156,11 @@ def build_data(
152156
"input": variables[0].name,
153157
"edges": edges(axes[0]),
154158
"content": [
155-
value
156-
if isinstance(value, Real)
157-
else build_data(value, axes[1:], variables[1:])
159+
(
160+
value
161+
if isinstance(value, Real)
162+
else build_data(value, axes[1:], variables[1:])
163+
)
158164
for value in values
159165
],
160166
"flow": flow,

src/correctionlib/highlevel.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
1-
"""High-level correctionlib objects
1+
"""High-level correctionlib objects"""
22

3-
"""
43
import json
4+
from collections.abc import Iterator, Mapping
55
from numbers import Integral
6-
from typing import Any, Callable, Iterator, Mapping, Union
6+
from typing import Any, Callable, Union
77

88
import numpy
99
from packaging import version

src/python.cc

Lines changed: 30 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -7,20 +7,40 @@ namespace py = pybind11;
77
using namespace correction;
88

99
namespace {
10+
template<typename T> // Correction or CompoundCorrection
11+
void check_length(const T& c, py::args args) {
12+
// Error message should be the same as in Correction::evaluate
13+
if ( py::len(args) != c.inputs().size() ) {
14+
throw std::invalid_argument("Incorrect number of inputs (got " + std::to_string(py::len(args))
15+
+ ", expected " + std::to_string(c.inputs().size()) + ")");
16+
}
17+
}
18+
19+
template<typename T> // Correction or CompoundCorrection
20+
std::vector<Variable::Type> validate_pyargs(const T& c, py::args args) {
21+
// Ensure length before checking integer overflow
22+
check_length(c, args);
23+
24+
// Check for integer overflow (py::cast converts to the double alternative)
25+
for (size_t i=0; i < c.inputs().size(); ++i) {
26+
if ( c.inputs()[i].type() == Variable::VarType::integer && py::isinstance<py::int_>(args[i]) ) {
27+
py::cast<int64_t>(args[i]); // throws if out of range
28+
}
29+
}
30+
31+
return py::cast<std::vector<Variable::Type>>(args);
32+
}
1033

11-
template<typename T>
34+
template<typename T> // Correction or CompoundCorrection
1235
py::array_t<double> evalv(T& c, py::args args) {
1336
std::vector<Variable::Type> inputs;
1437
inputs.reserve(py::len(args));
1538
std::vector<std::pair<size_t, py::buffer_info>> vargs;
16-
if ( py::len(args) != c.inputs().size() ) {
17-
throw std::invalid_argument("Incorrect number of inputs (got " + std::to_string(py::len(args))
18-
+ ", expected " + std::to_string(c.inputs().size()) + ")");
19-
}
39+
check_length(c, args);
2040
for (size_t i=0; i < py::len(args); ++i) {
2141
if ( py::isinstance<py::array>(args[i]) ) {
2242
if ( c.inputs()[i].type() == Variable::VarType::integer ) {
23-
vargs.emplace_back(i, py::cast<py::array_t<int, py::array::c_style | py::array::forcecast>>(args[i]).request());
43+
vargs.emplace_back(i, py::cast<py::array_t<int64_t, py::array::c_style | py::array::forcecast>>(args[i]).request());
2444
inputs.emplace_back(0);
2545
}
2646
else if ( c.inputs()[i].type() == Variable::VarType::real ) {
@@ -52,8 +72,8 @@ namespace {
5272
py::gil_scoped_release release;
5373
for (long i=0; i < outbuffer.shape[0]; ++i) {
5474
for (const auto& varg : vargs) {
55-
if ( std::holds_alternative<int>(inputs[varg.first]) ) {
56-
inputs[varg.first] = static_cast<int*>(varg.second.ptr)[i];
75+
if ( std::holds_alternative<int64_t>(inputs[varg.first]) ) {
76+
inputs[varg.first] = static_cast<int64_t*>(varg.second.ptr)[i];
5777
}
5878
else if ( std::holds_alternative<double>(inputs[varg.first]) ) {
5979
inputs[varg.first] = static_cast<double*>(varg.second.ptr)[i];
@@ -65,7 +85,6 @@ namespace {
6585
return output;
6686
}
6787
}
68-
6988
PYBIND11_MODULE(_core, m) {
7089
m.doc() = "python binding for corrections evaluator";
7190

@@ -82,7 +101,7 @@ PYBIND11_MODULE(_core, m) {
82101
.def_property_readonly("inputs", &Correction::inputs)
83102
.def_property_readonly("output", &Correction::output)
84103
.def("evaluate", [](Correction& c, py::args args) {
85-
return c.evaluate(py::cast<std::vector<Variable::Type>>(args));
104+
return c.evaluate(validate_pyargs(c, args));
86105
})
87106
.def("evalv", evalv<Correction>);
88107

@@ -92,7 +111,7 @@ PYBIND11_MODULE(_core, m) {
92111
.def_property_readonly("inputs", &CompoundCorrection::inputs)
93112
.def_property_readonly("output", &CompoundCorrection::output)
94113
.def("evaluate", [](CompoundCorrection& c, py::args args) {
95-
return c.evaluate(py::cast<std::vector<Variable::Type>>(args));
114+
return c.evaluate(validate_pyargs(c, args));
96115
})
97116
.def("evalv", evalv<CompoundCorrection>);
98117

tests/test_core.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ def test_evaluator():
4949
assert sf.version == 2
5050
assert sf.description == ""
5151

52-
with pytest.raises(RuntimeError):
52+
with pytest.raises(ValueError):
5353
sf.evaluate(0, 1.2, 35.0, 0.01)
5454

5555
assert sf.evaluate() == 1.234
@@ -110,21 +110,21 @@ def test_evaluator():
110110
assert sf.version == 2
111111
assert sf.description == ""
112112

113-
with pytest.raises(RuntimeError):
113+
with pytest.raises(ValueError):
114114
# too many inputs
115115
sf.evaluate(0, 1.2, 35.0, 0.01)
116116

117-
with pytest.raises(RuntimeError):
117+
with pytest.raises(ValueError):
118118
# not enough inputs
119119
sf.evaluate(1.2)
120120

121121
with pytest.raises(RuntimeError):
122122
# wrong type
123-
sf.evaluate(5)
123+
sf.evaluate(5, "blah")
124124

125125
with pytest.raises(RuntimeError):
126126
# wrong type
127-
sf.evaluate("asdf")
127+
sf.evaluate("asdf", "blah")
128128

129129
assert sf.evaluate(12.0, "blah") == 1.1
130130
# Do we need pytest.approx? Maybe not

tests/test_highlevel.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ def test_highlevel(cset):
3535
assert sf.version == 2
3636
assert sf.description == ""
3737

38-
with pytest.raises(RuntimeError):
38+
with pytest.raises(ValueError):
3939
sf.evaluate(0, 1.2, 35.0, 0.01)
4040

4141
assert sf.evaluate(1.0, 1.0) == 1.234

0 commit comments

Comments
 (0)