Skip to content

Commit 7d54fe4

Browse files
committed
Finetune handling of disable_overloaded_equal
1 parent 128954c commit 7d54fe4

1 file changed

Lines changed: 66 additions & 65 deletions

File tree

src/blosc2/lazyexpr.py

Lines changed: 66 additions & 65 deletions
Original file line numberDiff line numberDiff line change
@@ -2269,76 +2269,77 @@ def get_chunk(self, nchunk):
22692269
return out.schunk.get_chunk(nchunk)
22702270

22712271
def update_expr(self, new_op): # noqa: C901
2272-
original_disable = blosc2._disable_overloaded_equal
2272+
prev_flag = getattr(blosc2, "_disable_overloaded_equal", False)
22732273
# We use a lot of the original NDArray.__eq__ as 'is', so deactivate the overloaded one
22742274
blosc2._disable_overloaded_equal = True
22752275
# One of the two operands are LazyExpr instances
2276-
value1, op, value2 = new_op
2277-
# The new expression and operands
2278-
expression = None
2279-
new_operands = {}
2280-
# where() handling requires evaluating the expression prior to merge.
2281-
# This is different from reductions, where the expression is evaluated
2282-
# and returned a NumPy array (for usability convenience).
2283-
# We do things like this to enable the fusion of operations like
2284-
# `a.where(0, 1).sum()`.
2285-
# Another possibility would have been to always evaluate where() and produce
2286-
# an NDArray, but that would have been less efficient for the case above.
2287-
if hasattr(value1, "_where_args"):
2288-
value1 = value1.compute()
2289-
if hasattr(value2, "_where_args"):
2290-
value2 = value2.compute()
2291-
2292-
if not isinstance(value1, LazyExpr) and not isinstance(value2, LazyExpr):
2293-
# We converted some of the operands to NDArray (where() handling above)
2294-
new_operands = {"o0": value1, "o1": value2}
2295-
expression = f"(o0 {op} o1)"
2296-
blosc2._disable_overloaded_equal = original_disable
2297-
return self._new_expr(expression, new_operands, guess=False, out=None, where=None)
2298-
elif isinstance(value1, LazyExpr) and isinstance(value2, LazyExpr):
2299-
# Expression fusion
2300-
# Fuse operands in expressions and detect duplicates
2301-
new_operands, dup_op = fuse_operands(value1.operands, value2.operands)
2302-
# Take expression 2 and rebase the operands while removing duplicates
2303-
new_expr = fuse_expressions(value2.expression, len(value1.operands), dup_op)
2304-
expression = f"({self.expression} {op} {new_expr})"
2305-
elif isinstance(value1, LazyExpr):
2306-
if op == "~":
2307-
expression = f"({op}{self.expression})"
2308-
elif np.isscalar(value2):
2309-
expression = f"({self.expression} {op} {value2})"
2310-
elif hasattr(value2, "shape") and value2.shape == ():
2311-
expression = f"({self.expression} {op} {value2[()]})"
2312-
else:
2313-
operand_to_key = {id(v): k for k, v in value1.operands.items()}
2314-
try:
2315-
op_name = operand_to_key[id(value2)]
2316-
except KeyError:
2317-
op_name = f"o{len(self.operands)}"
2318-
new_operands = {op_name: value2}
2319-
expression = f"({self.expression} {op} {op_name})"
2320-
self.operands = value1.operands
2321-
else:
2322-
if np.isscalar(value1):
2323-
expression = f"({value1} {op} {self.expression})"
2324-
elif hasattr(value1, "shape") and value1.shape == ():
2325-
expression = f"({value1[()]} {op} {self.expression})"
2276+
try:
2277+
value1, op, value2 = new_op
2278+
# The new expression and operands
2279+
expression = None
2280+
new_operands = {}
2281+
# where() handling requires evaluating the expression prior to merge.
2282+
# This is different from reductions, where the expression is evaluated
2283+
# and returned a NumPy array (for usability convenience).
2284+
# We do things like this to enable the fusion of operations like
2285+
# `a.where(0, 1).sum()`.
2286+
# Another possibility would have been to always evaluate where() and produce
2287+
# an NDArray, but that would have been less efficient for the case above.
2288+
if hasattr(value1, "_where_args"):
2289+
value1 = value1.compute()
2290+
if hasattr(value2, "_where_args"):
2291+
value2 = value2.compute()
2292+
2293+
if not isinstance(value1, LazyExpr) and not isinstance(value2, LazyExpr):
2294+
# We converted some of the operands to NDArray (where() handling above)
2295+
new_operands = {"o0": value1, "o1": value2}
2296+
expression = f"(o0 {op} o1)"
2297+
return self._new_expr(expression, new_operands, guess=False, out=None, where=None)
2298+
elif isinstance(value1, LazyExpr) and isinstance(value2, LazyExpr):
2299+
# Expression fusion
2300+
# Fuse operands in expressions and detect duplicates
2301+
new_operands, dup_op = fuse_operands(value1.operands, value2.operands)
2302+
# Take expression 2 and rebase the operands while removing duplicates
2303+
new_expr = fuse_expressions(value2.expression, len(value1.operands), dup_op)
2304+
expression = f"({self.expression} {op} {new_expr})"
2305+
elif isinstance(value1, LazyExpr):
2306+
if op == "~":
2307+
expression = f"({op}{self.expression})"
2308+
elif np.isscalar(value2):
2309+
expression = f"({self.expression} {op} {value2})"
2310+
elif hasattr(value2, "shape") and value2.shape == ():
2311+
expression = f"({self.expression} {op} {value2[()]})"
2312+
else:
2313+
operand_to_key = {id(v): k for k, v in value1.operands.items()}
2314+
try:
2315+
op_name = operand_to_key[id(value2)]
2316+
except KeyError:
2317+
op_name = f"o{len(self.operands)}"
2318+
new_operands = {op_name: value2}
2319+
expression = f"({self.expression} {op} {op_name})"
2320+
self.operands = value1.operands
23262321
else:
2327-
operand_to_key = {id(v): k for k, v in value2.operands.items()}
2328-
try:
2329-
op_name = operand_to_key[id(value1)]
2330-
except KeyError:
2331-
op_name = f"o{len(value2.operands)}"
2332-
new_operands = {op_name: value1}
2333-
if op == "[]": # syntactic sugar for slicing
2334-
expression = f"({op_name}[{self.expression}])"
2322+
if np.isscalar(value1):
2323+
expression = f"({value1} {op} {self.expression})"
2324+
elif hasattr(value1, "shape") and value1.shape == ():
2325+
expression = f"({value1[()]} {op} {self.expression})"
23352326
else:
2336-
expression = f"({op_name} {op} {self.expression})"
2337-
self.operands = value2.operands
2338-
blosc2._disable_overloaded_equal = original_disable
2339-
# Return a new expression
2340-
operands = self.operands | new_operands
2341-
return self._new_expr(expression, operands, guess=False, out=None, where=None)
2327+
operand_to_key = {id(v): k for k, v in value2.operands.items()}
2328+
try:
2329+
op_name = operand_to_key[id(value1)]
2330+
except KeyError:
2331+
op_name = f"o{len(value2.operands)}"
2332+
new_operands = {op_name: value1}
2333+
if op == "[]": # syntactic sugar for slicing
2334+
expression = f"({op_name}[{self.expression}])"
2335+
else:
2336+
expression = f"({op_name} {op} {self.expression})"
2337+
self.operands = value2.operands
2338+
# Return a new expression
2339+
operands = self.operands | new_operands
2340+
return self._new_expr(expression, operands, guess=False, out=None, where=None)
2341+
finally:
2342+
blosc2._disable_overloaded_equal = prev_flag
23422343

23432344
@property
23442345
def dtype(self):

0 commit comments

Comments
 (0)