11"""High-level correctionlib objects"""
22
3+ from __future__ import annotations
4+
35import json
46from collections .abc import Iterator , Mapping
57from numbers import Integral
6- from typing import TYPE_CHECKING , Any , Callable , Union
8+ from typing import TYPE_CHECKING , Any , Callable
79
810import numpy
911from packaging import version
@@ -101,7 +103,7 @@ def _call_as_numpy(
101103
102104def _wrap_awkward (
103105 func : Callable [..., Any ],
104- * args : Union [ " awkward.Array" , " numpy.ndarray[Any, Any]" , str , int , float ] ,
106+ * args : awkward .Array | numpy .ndarray [Any , Any ] | str | int | float ,
105107) -> Any :
106108 from functools import partial
107109
@@ -136,14 +138,14 @@ def _wrap_awkward(
136138
137139def _call_dask_correction (
138140 correction : Any ,
139- * args : Union [ " numpy.ndarray[Any, Any]" , str , int , float ] ,
141+ * args : numpy .ndarray [Any , Any ] | str | int | float ,
140142):
141143 return _wrap_awkward (correction ._base .evalv , * args )
142144
143145
144146def _wrap_dask_awkward (
145147 correction : Any ,
146- * args : Union [ " numpy.ndarray[Any, Any]" , str , int , float ] ,
148+ * args : numpy .ndarray [Any , Any ] | str | int | float ,
147149) -> Any :
148150 import dask .delayed
149151 import dask_awkward
@@ -177,14 +179,58 @@ def _wrap_dask_awkward(
177179 )
178180
179181
182+ def _isinstance (arg : Any , clsprefix : str ) -> bool :
183+ """Return True if arg is an instance of a class with the given prefix
184+
185+ Avoids importing modules
186+ """
187+ return str (type (arg )).startswith (f"<class '{ clsprefix } ." )
188+
189+
190+ def _evaluate (
191+ corr : Correction | CompoundCorrection ,
192+ * args : awkward .Array | numpy .ndarray [Any , Any ] | str | int | float ,
193+ ) -> float | awkward .Array | numpy .ndarray [Any , numpy .dtype [numpy .float64 ]]:
194+ # TODO: create a ufunc with numpy.vectorize in constructor?
195+ if any (_isinstance (arg , "dask.array" ) for arg in args ):
196+ raise TypeError (
197+ "Correctionlib does not yet handle dask.array collections. "
198+ "If you require this functionality (i.e. you cannot or do "
199+ "not want to use dask_awkward/awkward arrays) please open an "
200+ "issue at https://github.com/cms-nanoAOD/correctionlib/issues."
201+ )
202+ if any (_isinstance (arg , "dask_awkward" ) for arg in args ):
203+ return _wrap_dask_awkward (corr , * args ) # type: ignore
204+ if any (_isinstance (arg , "awkward" ) for arg in args ):
205+ return _wrap_awkward (corr ._base .evalv , * args ) # type: ignore
206+ if all (isinstance (arg , (str , int , float )) for arg in args ):
207+ return corr ._base .evaluate (* args ) # type: ignore
208+
209+ # everything else: convert to numpy and broadcast
210+ vargs = [
211+ numpy .asarray (arg ) for arg in args if not isinstance (arg , (str , int , float ))
212+ ]
213+ assert vargs , "should have caught all-scalar case above"
214+ bargs = numpy .broadcast_arrays (* vargs )
215+ oshape = bargs [0 ].shape
216+ fargs = (arg .flatten () for arg in bargs )
217+ out = corr ._base .evalv (
218+ * (
219+ next (fargs ) if not isinstance (arg , (str , int , float )) else arg
220+ for arg in args
221+ )
222+ )
223+ return out .reshape (oshape )
224+
225+
180226class Correction :
181227 """High-level correction evaluator object
182228
183229 This class is typically instantiated by accessing a named correction from
184230 a CorrectionSet object, rather than directly by construction.
185231 """
186232
187- def __init__ (self , base : correctionlib ._core .Correction , context : " CorrectionSet" ):
233+ def __init__ (self , base : correctionlib ._core .Correction , context : CorrectionSet ):
188234 self ._base = base
189235 self ._name = base .name
190236 self ._context = context
@@ -218,43 +264,9 @@ def output(self) -> correctionlib._core.Variable:
218264 return self ._base .output
219265
220266 def evaluate (
221- self , * args : Union ["numpy.ndarray[Any, Any]" , str , int , float ]
222- ) -> Union [float , "numpy.ndarray[Any, numpy.dtype[numpy.float64]]" ]:
223- # TODO: create a ufunc with numpy.vectorize in constructor?
224- if any (str (type (arg )).startswith ("<class 'dask.array." ) for arg in args ):
225- raise TypeError (
226- "Correctionlib does not yet handle dask.array collections. "
227- "If you require this functionality (i.e. you cannot or do "
228- "not want to use dask_awkward/awkward arrays) please open an "
229- "issue at https://github.com/cms-nanoAOD/correctionlib/issues."
230- )
231- try :
232- vargs = [
233- numpy .asarray (arg )
234- for arg in args
235- if not isinstance (arg , (str , int , float ))
236- ]
237- except NotImplementedError :
238- if any (str (type (arg )).startswith ("<class 'dask_awkward." ) for arg in args ):
239- return _wrap_dask_awkward (self , * args ) # type: ignore
240- raise
241- except (ValueError , TypeError ):
242- if any (str (type (arg )).startswith ("<class 'awkward." ) for arg in args ):
243- return _wrap_awkward (self ._base .evalv , * args ) # type: ignore
244- raise
245-
246- if vargs :
247- bargs = numpy .broadcast_arrays (* vargs )
248- oshape = bargs [0 ].shape
249- fargs = (arg .flatten () for arg in bargs )
250- out = self ._base .evalv (
251- * (
252- next (fargs ) if not isinstance (arg , (str , int , float )) else arg
253- for arg in args
254- )
255- )
256- return out .reshape (oshape )
257- return self ._base .evaluate (* args ) # type: ignore
267+ self , * args : awkward .Array | numpy .ndarray [Any , Any ] | str | int | float
268+ ) -> float | awkward .Array | numpy .ndarray [Any , numpy .dtype [numpy .float64 ]]:
269+ return _evaluate (self , * args )
258270
259271
260272class CompoundCorrection :
@@ -265,7 +277,7 @@ class CompoundCorrection:
265277 """
266278
267279 def __init__ (
268- self , base : correctionlib ._core .CompoundCorrection , context : " CorrectionSet"
280+ self , base : correctionlib ._core .CompoundCorrection , context : CorrectionSet
269281 ):
270282 self ._base = base
271283 self ._name = base .name
@@ -296,50 +308,16 @@ def output(self) -> correctionlib._core.Variable:
296308 return self ._base .output
297309
298310 def evaluate (
299- self , * args : Union ["numpy.ndarray[Any, Any]" , str , int , float ]
300- ) -> Union [float , "numpy.ndarray[Any, numpy.dtype[numpy.float64]]" ]:
301- # TODO: create a ufunc with numpy.vectorize in constructor?
302- if any (str (type (arg )).startswith ("<class 'dask.array." ) for arg in args ):
303- raise TypeError (
304- "Correctionlib does not yet handle dask.array collections. "
305- "if you require this functionality (i.e. you cannot or do "
306- "not want to use dask_awkward/awkward arrays) please open an "
307- "issue at https://github.com/cms-nanoAOD/correctionlib/issues."
308- )
309- try :
310- vargs = [
311- numpy .asarray (arg )
312- for arg in args
313- if not isinstance (arg , (str , int , float ))
314- ]
315- except NotImplementedError :
316- if any (str (type (arg )).startswith ("<class 'dask_awkward." ) for arg in args ):
317- return _wrap_dask_awkward (self , * args ) # type: ignore
318- raise
319- except (ValueError , TypeError ):
320- if any (str (type (arg )).startswith ("<class 'awkward." ) for arg in args ):
321- return _wrap_awkward (self ._base .evalv , * args ) # type: ignore
322- raise
323-
324- if vargs :
325- bargs = numpy .broadcast_arrays (* vargs )
326- oshape = bargs [0 ].shape
327- fargs = (arg .flatten () for arg in bargs )
328- out = self ._base .evalv (
329- * (
330- next (fargs ) if not isinstance (arg , (str , int , float )) else arg
331- for arg in args
332- )
333- )
334- return out .reshape (oshape )
335- return self ._base .evaluate (* args ) # type: ignore
311+ self , * args : awkward .Array | numpy .ndarray [Any , Any ] | str | int | float
312+ ) -> float | awkward .Array | numpy .ndarray [Any , numpy .dtype [numpy .float64 ]]:
313+ return _evaluate (self , * args )
336314
337315
338316class _CompoundMap (Mapping [str , CompoundCorrection ]):
339317 def __init__ (
340318 self ,
341319 base : Mapping [str , correctionlib ._core .CompoundCorrection ],
342- context : " CorrectionSet" ,
320+ context : CorrectionSet ,
343321 ):
344322 self ._base = base
345323 self ._context = context
@@ -372,11 +350,11 @@ def __init__(self, data: Any):
372350 self ._base = correctionlib ._core .CorrectionSet .from_string (self ._data )
373351
374352 @classmethod
375- def from_file (cls , filename : str ) -> " CorrectionSet" :
353+ def from_file (cls , filename : str ) -> CorrectionSet :
376354 return cls (open_auto (filename ))
377355
378356 @classmethod
379- def from_string (cls , data : str ) -> " CorrectionSet" :
357+ def from_string (cls , data : str ) -> CorrectionSet :
380358 return cls (data )
381359
382360 def __getstate__ (self ) -> dict [str , Any ]:
0 commit comments