Skip to content

Commit 771d357

Browse files
authored
fix: re-wrap awkward arrays in highlevel evaluator (#306)
* fix: re-wrap awkward arrays in highlevel evaluator (closes #296)
* Refactor as proposed
1 parent 4cb68a3 commit 771d357

3 files changed

Lines changed: 101 additions & 97 deletions

File tree

src/correctionlib/convert.py

Lines changed: 16 additions & 14 deletions
Original file line number | Diff line number | Diff line change
@@ -1,8 +1,10 @@
11
"""Tools to convert other formats to correctionlib"""
22

3+
from __future__ import annotations
4+
35
from collections.abc import Iterable, Sequence
46
from numbers import Real
5-
from typing import TYPE_CHECKING, Any, Optional, Union, cast
7+
from typing import TYPE_CHECKING, Any, cast
68

79
import numpy
810

@@ -30,7 +32,7 @@
3032

3133
def from_uproot_THx(
3234
path: str,
33-
axis_names: Optional[list[str]] = None,
35+
axis_names: list[str] | None = None,
3436
flow: Literal["clamp", "error"] = "error",
3537
) -> Correction:
3638
"""Convert a ROOT histogram
@@ -51,17 +53,17 @@ def from_uproot_THx(
5153

5254

5355
def from_histogram(
54-
hist: "PlottableHistogram",
55-
axis_names: Optional[list[str]] = None,
56-
flow: Optional[Union[Content, Literal["clamp", "error"]]] = "error",
56+
hist: PlottableHistogram,
57+
axis_names: list[str] | None = None,
58+
flow: Content | Literal["clamp", "error"] | None = "error",
5759
) -> Correction:
5860
"""Read any object with PlottableHistogram interface protocol
5961
6062
Interface as defined in
6163
https://github.com/scikit-hep/uhi/blob/v0.1.1/src/uhi/typing/plottable.py
6264
"""
6365

64-
def read_axis(axis: "PlottableAxis", pos: int) -> Variable:
66+
def read_axis(axis: PlottableAxis, pos: int) -> Variable:
6567
axtype = "real"
6668
if len(axis) == 0:
6769
raise ValueError(f"Zero-length axis {axis}, what to do?")
@@ -83,7 +85,7 @@ def read_axis(axis: "PlottableAxis", pos: int) -> Variable:
8385
variables = [read_axis(ax, i) for i, ax in enumerate(hist.axes)]
8486
# Here we could try to optimize the ordering
8587

86-
def edges(axis: "PlottableAxis") -> list[float]:
88+
def edges(axis: PlottableAxis) -> list[float]:
8789
out = []
8890
for i, b in enumerate(axis):
8991
if isinstance(b, (str, int)):
@@ -96,16 +98,16 @@ def edges(axis: "PlottableAxis") -> list[float]:
9698
out.append(b[1])
9799
return out
98100

99-
def flatten_to(values: "ndarray[Any, Any]", depth: int) -> Iterable[Any]:
101+
def flatten_to(values: ndarray[Any, Any], depth: int) -> Iterable[Any]:
100102
for value in values:
101103
if depth > 0:
102104
yield from flatten_to(value, depth - 1)
103105
else:
104106
yield value
105107

106108
def build_data(
107-
values: "ndarray[Any, Any]",
108-
axes: Sequence["PlottableAxis"],
109+
values: ndarray[Any, Any],
110+
axes: Sequence[PlottableAxis],
109111
variables: list[Variable],
110112
) -> Content:
111113
vartype = variables[0].type
@@ -179,9 +181,9 @@ def build_data(
179181

180182

181183
def ndpolyfit(
182-
points: list["ndarray[Any, Any]"],
183-
values: "ndarray[Any, Any]",
184-
weights: "ndarray[Any, Any]",
184+
points: list[ndarray[Any, Any]],
185+
values: ndarray[Any, Any],
186+
weights: ndarray[Any, Any],
185187
varnames: list[str],
186188
degree: tuple[int],
187189
) -> tuple[Correction, Any]:
@@ -216,7 +218,7 @@ def ndpolyfit(
216218
raise NotImplementedError(
217219
"correctionlib Formula not available for more than 4 variables"
218220
)
219-
_degree: "ndarray[Any, Any]" = numpy.array(degree, dtype=int)
221+
_degree: ndarray[Any, Any] = numpy.array(degree, dtype=int)
220222
npoints = len(values)
221223
powergrid = numpy.ones(shape=(npoints, *(_degree + 1)))
222224
for i, (x, deg) in enumerate(zip(points, _degree)):

src/correctionlib/highlevel.py

Lines changed: 61 additions & 83 deletions
Original file line number | Diff line number | Diff line change
@@ -1,9 +1,11 @@
11
"""High-level correctionlib objects"""
22

3+
from __future__ import annotations
4+
35
import json
46
from collections.abc import Iterator, Mapping
57
from numbers import Integral
6-
from typing import TYPE_CHECKING, Any, Callable, Union
8+
from typing import TYPE_CHECKING, Any, Callable
79

810
import numpy
911
from packaging import version
@@ -101,7 +103,7 @@ def _call_as_numpy(
101103

102104
def _wrap_awkward(
103105
func: Callable[..., Any],
104-
*args: Union["awkward.Array", "numpy.ndarray[Any, Any]", str, int, float],
106+
*args: awkward.Array | numpy.ndarray[Any, Any] | str | int | float,
105107
) -> Any:
106108
from functools import partial
107109

@@ -136,14 +138,14 @@ def _wrap_awkward(
136138

137139
def _call_dask_correction(
138140
correction: Any,
139-
*args: Union["numpy.ndarray[Any, Any]", str, int, float],
141+
*args: numpy.ndarray[Any, Any] | str | int | float,
140142
):
141143
return _wrap_awkward(correction._base.evalv, *args)
142144

143145

144146
def _wrap_dask_awkward(
145147
correction: Any,
146-
*args: Union["numpy.ndarray[Any, Any]", str, int, float],
148+
*args: numpy.ndarray[Any, Any] | str | int | float,
147149
) -> Any:
148150
import dask.delayed
149151
import dask_awkward
@@ -177,14 +179,58 @@ def _wrap_dask_awkward(
177179
)
178180

179181

182+
def _isinstance(arg: Any, clsprefix: str) -> bool:
183+
"""Return True if arg is an instance of a class with the given prefix
184+
185+
Avoids importing modules
186+
"""
187+
return str(type(arg)).startswith(f"<class '{clsprefix}.")
188+
189+
190+
def _evaluate(
191+
corr: Correction | CompoundCorrection,
192+
*args: awkward.Array | numpy.ndarray[Any, Any] | str | int | float,
193+
) -> float | awkward.Array | numpy.ndarray[Any, numpy.dtype[numpy.float64]]:
194+
# TODO: create a ufunc with numpy.vectorize in constructor?
195+
if any(_isinstance(arg, "dask.array") for arg in args):
196+
raise TypeError(
197+
"Correctionlib does not yet handle dask.array collections. "
198+
"If you require this functionality (i.e. you cannot or do "
199+
"not want to use dask_awkward/awkward arrays) please open an "
200+
"issue at https://github.com/cms-nanoAOD/correctionlib/issues."
201+
)
202+
if any(_isinstance(arg, "dask_awkward") for arg in args):
203+
return _wrap_dask_awkward(corr, *args) # type: ignore
204+
if any(_isinstance(arg, "awkward") for arg in args):
205+
return _wrap_awkward(corr._base.evalv, *args) # type: ignore
206+
if all(isinstance(arg, (str, int, float)) for arg in args):
207+
return corr._base.evaluate(*args) # type: ignore
208+
209+
# everything else: convert to numpy and broadcast
210+
vargs = [
211+
numpy.asarray(arg) for arg in args if not isinstance(arg, (str, int, float))
212+
]
213+
assert vargs, "should have caught all-scalar case above"
214+
bargs = numpy.broadcast_arrays(*vargs)
215+
oshape = bargs[0].shape
216+
fargs = (arg.flatten() for arg in bargs)
217+
out = corr._base.evalv(
218+
*(
219+
next(fargs) if not isinstance(arg, (str, int, float)) else arg
220+
for arg in args
221+
)
222+
)
223+
return out.reshape(oshape)
224+
225+
180226
class Correction:
181227
"""High-level correction evaluator object
182228
183229
This class is typically instantiated by accessing a named correction from
184230
a CorrectionSet object, rather than directly by construction.
185231
"""
186232

187-
def __init__(self, base: correctionlib._core.Correction, context: "CorrectionSet"):
233+
def __init__(self, base: correctionlib._core.Correction, context: CorrectionSet):
188234
self._base = base
189235
self._name = base.name
190236
self._context = context
@@ -218,43 +264,9 @@ def output(self) -> correctionlib._core.Variable:
218264
return self._base.output
219265

220266
def evaluate(
221-
self, *args: Union["numpy.ndarray[Any, Any]", str, int, float]
222-
) -> Union[float, "numpy.ndarray[Any, numpy.dtype[numpy.float64]]"]:
223-
# TODO: create a ufunc with numpy.vectorize in constructor?
224-
if any(str(type(arg)).startswith("<class 'dask.array.") for arg in args):
225-
raise TypeError(
226-
"Correctionlib does not yet handle dask.array collections. "
227-
"If you require this functionality (i.e. you cannot or do "
228-
"not want to use dask_awkward/awkward arrays) please open an "
229-
"issue at https://github.com/cms-nanoAOD/correctionlib/issues."
230-
)
231-
try:
232-
vargs = [
233-
numpy.asarray(arg)
234-
for arg in args
235-
if not isinstance(arg, (str, int, float))
236-
]
237-
except NotImplementedError:
238-
if any(str(type(arg)).startswith("<class 'dask_awkward.") for arg in args):
239-
return _wrap_dask_awkward(self, *args) # type: ignore
240-
raise
241-
except (ValueError, TypeError):
242-
if any(str(type(arg)).startswith("<class 'awkward.") for arg in args):
243-
return _wrap_awkward(self._base.evalv, *args) # type: ignore
244-
raise
245-
246-
if vargs:
247-
bargs = numpy.broadcast_arrays(*vargs)
248-
oshape = bargs[0].shape
249-
fargs = (arg.flatten() for arg in bargs)
250-
out = self._base.evalv(
251-
*(
252-
next(fargs) if not isinstance(arg, (str, int, float)) else arg
253-
for arg in args
254-
)
255-
)
256-
return out.reshape(oshape)
257-
return self._base.evaluate(*args) # type: ignore
267+
self, *args: awkward.Array | numpy.ndarray[Any, Any] | str | int | float
268+
) -> float | awkward.Array | numpy.ndarray[Any, numpy.dtype[numpy.float64]]:
269+
return _evaluate(self, *args)
258270

259271

260272
class CompoundCorrection:
@@ -265,7 +277,7 @@ class CompoundCorrection:
265277
"""
266278

267279
def __init__(
268-
self, base: correctionlib._core.CompoundCorrection, context: "CorrectionSet"
280+
self, base: correctionlib._core.CompoundCorrection, context: CorrectionSet
269281
):
270282
self._base = base
271283
self._name = base.name
@@ -296,50 +308,16 @@ def output(self) -> correctionlib._core.Variable:
296308
return self._base.output
297309

298310
def evaluate(
299-
self, *args: Union["numpy.ndarray[Any, Any]", str, int, float]
300-
) -> Union[float, "numpy.ndarray[Any, numpy.dtype[numpy.float64]]"]:
301-
# TODO: create a ufunc with numpy.vectorize in constructor?
302-
if any(str(type(arg)).startswith("<class 'dask.array.") for arg in args):
303-
raise TypeError(
304-
"Correctionlib does not yet handle dask.array collections. "
305-
"if you require this functionality (i.e. you cannot or do "
306-
"not want to use dask_awkward/awkward arrays) please open an "
307-
"issue at https://github.com/cms-nanoAOD/correctionlib/issues."
308-
)
309-
try:
310-
vargs = [
311-
numpy.asarray(arg)
312-
for arg in args
313-
if not isinstance(arg, (str, int, float))
314-
]
315-
except NotImplementedError:
316-
if any(str(type(arg)).startswith("<class 'dask_awkward.") for arg in args):
317-
return _wrap_dask_awkward(self, *args) # type: ignore
318-
raise
319-
except (ValueError, TypeError):
320-
if any(str(type(arg)).startswith("<class 'awkward.") for arg in args):
321-
return _wrap_awkward(self._base.evalv, *args) # type: ignore
322-
raise
323-
324-
if vargs:
325-
bargs = numpy.broadcast_arrays(*vargs)
326-
oshape = bargs[0].shape
327-
fargs = (arg.flatten() for arg in bargs)
328-
out = self._base.evalv(
329-
*(
330-
next(fargs) if not isinstance(arg, (str, int, float)) else arg
331-
for arg in args
332-
)
333-
)
334-
return out.reshape(oshape)
335-
return self._base.evaluate(*args) # type: ignore
311+
self, *args: awkward.Array | numpy.ndarray[Any, Any] | str | int | float
312+
) -> float | awkward.Array | numpy.ndarray[Any, numpy.dtype[numpy.float64]]:
313+
return _evaluate(self, *args)
336314

337315

338316
class _CompoundMap(Mapping[str, CompoundCorrection]):
339317
def __init__(
340318
self,
341319
base: Mapping[str, correctionlib._core.CompoundCorrection],
342-
context: "CorrectionSet",
320+
context: CorrectionSet,
343321
):
344322
self._base = base
345323
self._context = context
@@ -372,11 +350,11 @@ def __init__(self, data: Any):
372350
self._base = correctionlib._core.CorrectionSet.from_string(self._data)
373351

374352
@classmethod
375-
def from_file(cls, filename: str) -> "CorrectionSet":
353+
def from_file(cls, filename: str) -> CorrectionSet:
376354
return cls(open_auto(filename))
377355

378356
@classmethod
379-
def from_string(cls, data: str) -> "CorrectionSet":
357+
def from_string(cls, data: str) -> CorrectionSet:
380358
return cls(data)
381359

382360
def __getstate__(self) -> dict[str, Any]:

tests/test_issue296.py

Lines changed: 24 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,24 @@
1+
import awkward as ak
2+
import numpy as np
3+
4+
import correctionlib.schemav2 as cs
5+
6+
7+
def test_issue_296():
8+
corr = cs.Correction(
9+
name="test",
10+
version=1,
11+
inputs=[
12+
cs.Variable(name="x", type="int"),
13+
],
14+
output=cs.Variable(name="out", type="real"),
15+
data=cs.HashPRNG(nodetype="hashprng", inputs=["x"], distribution="normal"),
16+
).to_evaluator()
17+
18+
x = ak.Array([1, 2, 3, 4, 5])
19+
result = corr.evaluate(x)
20+
assert isinstance(result, ak.Array)
21+
22+
x = np.array([1, 2, 3, 4, 5])
23+
result = corr.evaluate(x)
24+
assert isinstance(result, np.ndarray)

0 commit comments

Comments
 (0)